use arrow::array::Float64Array;
use arrow::datatypes::FieldRef;
use arrow::{
    array::{ArrayRef, UInt64Array},
    compute::cast,
    datatypes::DataType,
    datatypes::Field,
};
use datafusion_common::{
    downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result,
    ScalarValue,
};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
};
use std::any::Any;
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::size_of_val;
use std::sync::{Arc, LazyLock};

macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
        create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
    }
}

make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);

pub struct Regr {
    signature: Signature,
    regr_type: RegrType,
    func_name: &'static str,
}

impl Debug for Regr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("regr")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .finish()
    }
}

impl Regr {
    pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
        Self {
            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
            regr_type,
            func_name,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Hash, Eq)]
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
    /// Slope of the fitted regression line (`regr_slope`).
    Slope,
    /// Y-intercept of the fitted regression line (`regr_intercept`).
    Intercept,
    /// Number of non-null input pairs (`regr_count`).
    Count,
    /// Square of the correlation coefficient (`regr_r2`).
    R2,
    /// Average of the independent variable (`regr_avgx`).
    AvgX,
    /// Average of the dependent variable (`regr_avgy`).
    AvgY,
    /// Sum of squared deviations of the independent variable (`regr_sxx`).
    SXX,
    /// Sum of squared deviations of the dependent variable (`regr_syy`).
    SYY,
    /// Sum of products of paired deviations (`regr_sxy`).
    SXY,
}

impl RegrType {
    fn documentation(&self) -> Option<&Documentation> {
        get_regr_docs().get(self)
    }
}

static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
    let mut hash_map = HashMap::new();
    hash_map.insert(
        RegrType::Slope,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
            Given input columns Y and X, regr_slope(Y, X) returns the slope (k in Y = k*X + b) of the line fitted by minimizing the residual sum of squares (RSS).",
            "regr_slope(expression_y, expression_x)")
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build()
    );

    hash_map.insert(
        RegrType::Intercept,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
            this function returns b.",
            "regr_intercept(expression_y, expression_x)")
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build()
    );

    hash_map.insert(
        RegrType::Count,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Counts the number of non-null paired data points.",
            "regr_count(expression_y, expression_x)",
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );

    hash_map.insert(
        RegrType::R2,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the square of the correlation coefficient between the independent and dependent variables.",
            "regr_r2(expression_y, expression_x)")
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build()
    );

    hash_map.insert(
        RegrType::AvgX,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
            "regr_avgx(expression_y, expression_x)")
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build()
    );

    hash_map.insert(
        RegrType::AvgY,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
            "regr_avgy(expression_y, expression_x)")
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build()
    );

    hash_map.insert(
        RegrType::SXX,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of squares of the independent variable.",
            "regr_sxx(expression_y, expression_x)",
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );

    hash_map.insert(
        RegrType::SYY,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of squares of the dependent variable.",
            "regr_syy(expression_y, expression_x)",
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );

    hash_map.insert(
        RegrType::SXY,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of products of paired data points.",
            "regr_sxy(expression_y, expression_x)",
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );

    hash_map
});

fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
    &DOCUMENTATION
}

impl AggregateUDFImpl for Regr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.func_name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        if !arg_types[0].is_numeric() {
            return plan_err!("Regression functions require numeric input types");
        }

        if matches!(self.regr_type, RegrType::Count) {
            Ok(DataType::UInt64)
        } else {
            Ok(DataType::Float64)
        }
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        Ok(vec![
            Field::new(
                format_state_name(args.name, "count"),
                DataType::UInt64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "mean_x"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "mean_y"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "m2_x"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "m2_y"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "algo_const"),
                DataType::Float64,
                true,
            ),
        ]
        .into_iter()
        .map(Arc::new)
        .collect())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.regr_type.documentation()
    }

    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
        let Some(other) = other.as_any().downcast_ref::<Self>() else {
            return false;
        };
        let Self {
            signature,
            regr_type,
            func_name,
        } = self;
        signature == &other.signature
            && regr_type == &other.regr_type
            && func_name == &other.func_name
    }

    fn hash_value(&self) -> u64 {
        let Self {
            signature,
            regr_type,
            func_name,
        } = self;
        let mut hasher = DefaultHasher::new();
        std::any::type_name::<Self>().hash(&mut hasher);
        signature.hash(&mut hasher);
        regr_type.hash(&mut hasher);
        func_name.hash(&mut hasher);
        hasher.finish()
    }
}

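/// `RegrAccumulator` keeps a single running state that is sufficient to
/// evaluate every `RegrType` at the end of aggregation:
///
/// * `count`      - number of non-null (y, x) pairs seen so far
/// * `mean_x`     - running mean of x
/// * `mean_y`     - running mean of y
/// * `m2_x`       - running sum of squared deviations of x from `mean_x`
/// * `m2_y`       - running sum of squared deviations of y from `mean_y`
/// * `algo_const` - running sum of products of deviations, i.e. Σ (x - mean_x)(y - mean_y)
///
/// The means and second moments are maintained with a Welford-style single-pass
/// (online) update (see `update_batch`), which avoids the numerical cancellation
/// of the naive "sum of squares minus square of sums" formulation.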
#[derive(Debug)]
pub struct RegrAccumulator {
    count: u64,
    mean_x: f64,
    mean_y: f64,
    m2_x: f64,
    m2_y: f64,
    algo_const: f64,
    regr_type: RegrType,
}

impl RegrAccumulator {
    pub fn try_new(regr_type: &RegrType) -> Result<Self> {
        Ok(Self {
            count: 0_u64,
            mean_x: 0_f64,
            mean_y: 0_f64,
            m2_x: 0_f64,
            m2_y: 0_f64,
            algo_const: 0_f64,
            regr_type: regr_type.clone(),
        })
    }
}

impl Accumulator for RegrAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean_x),
            ScalarValue::from(self.mean_y),
            ScalarValue::from(self.m2_x),
            ScalarValue::from(self.m2_y),
            ScalarValue::from(self.algo_const),
        ])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values_y = &cast(&values[0], &DataType::Float64)?;
        let values_x = &cast(&values[1], &DataType::Float64)?;

        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();

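        // `arr_y` / `arr_x` only yield the non-null values, so consult the
        // validity bitmaps to keep both iterators aligned with row `i` and
        // skip any pair in which either side is NULL.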
        for i in 0..values_y.len() {
            let value_y = if values_y.is_valid(i) {
                arr_y.next()
            } else {
                None
            };
            let value_x = if values_x.is_valid(i) {
                arr_x.next()
            } else {
                None
            };
            if value_y.is_none() || value_x.is_none() {
                continue;
            }

            let value_y = unwrap_or_internal_err!(value_y);
            let value_x = unwrap_or_internal_err!(value_x);

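            // Welford-style single-pass update: nudge each mean toward the new
            // value, then accumulate the second moments (m2_x, m2_y) and the
            // co-moment (algo_const) from deviations taken before and after
            // the mean update.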
            self.count += 1;
            let delta_x = value_x - self.mean_x;
            let delta_y = value_y - self.mean_y;
            self.mean_x += delta_x / self.count as f64;
            self.mean_y += delta_y / self.count as f64;
            let delta_x_2 = value_x - self.mean_x;
            let delta_y_2 = value_y - self.mean_y;
            self.m2_x += delta_x * delta_x_2;
            self.m2_y += delta_y * delta_y_2;
            self.algo_const += delta_x * (value_y - self.mean_y);
        }

        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values_y = &cast(&values[0], &DataType::Float64)?;
        let values_x = &cast(&values[1], &DataType::Float64)?;

        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();

        for i in 0..values_y.len() {
            let value_y = if values_y.is_valid(i) {
                arr_y.next()
            } else {
                None
            };
            let value_x = if values_x.is_valid(i) {
                arr_x.next()
            } else {
                None
            };
            if value_y.is_none() || value_x.is_none() {
                continue;
            }

            let value_y = unwrap_or_internal_err!(value_y);
            let value_x = unwrap_or_internal_err!(value_x);

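            // Reverse the incremental update for this pair (used when values
            // slide out of a window frame); once the last pair is removed,
            // reset the state to the empty-accumulator values.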
            if self.count > 1 {
                self.count -= 1;
                let delta_x = value_x - self.mean_x;
                let delta_y = value_y - self.mean_y;
                self.mean_x -= delta_x / self.count as f64;
                self.mean_y -= delta_y / self.count as f64;
                let delta_x_2 = value_x - self.mean_x;
                let delta_y_2 = value_y - self.mean_y;
                self.m2_x -= delta_x * delta_x_2;
                self.m2_y -= delta_y * delta_y_2;
                self.algo_const -= delta_x * (value_y - self.mean_y);
            } else {
                self.count = 0;
                self.mean_x = 0.0;
                self.m2_x = 0.0;
                self.m2_y = 0.0;
                self.mean_y = 0.0;
                self.algo_const = 0.0;
            }
        }

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let count_arr = downcast_value!(states[0], UInt64Array);
        let mean_x_arr = downcast_value!(states[1], Float64Array);
        let mean_y_arr = downcast_value!(states[2], Float64Array);
        let m2_x_arr = downcast_value!(states[3], Float64Array);
        let m2_y_arr = downcast_value!(states[4], Float64Array);
        let algo_const_arr = downcast_value!(states[5], Float64Array);

        for i in 0..count_arr.len() {
            let count_b = count_arr.value(i);
            if count_b == 0_u64 {
                continue;
            }
            let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
                self.count,
                self.mean_x,
                self.mean_y,
                self.m2_x,
                self.m2_y,
                self.algo_const,
            );
            let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
                count_b,
                mean_x_arr.value(i),
                mean_y_arr.value(i),
                m2_x_arr.value(i),
                m2_y_arr.value(i),
                algo_const_arr.value(i),
            );

            let count_ab = count_a + count_b;
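            // Combine the two partial states with the standard parallel merge
            // rules for mean / M2 / co-moment (cf. Chan et al.):
            //   mean_ab = mean_a + d * n_b / n_ab
            //   m2_ab   = m2_a + m2_b + d^2 * n_a * n_b / n_ab
            // where d is the difference between the two partial means.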
            let (count_a, count_b) = (count_a as f64, count_b as f64);
            let d_x = mean_x_b - mean_x_a;
            let d_y = mean_y_b - mean_y_a;
            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
            let m2_x_ab =
                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
            let m2_y_ab =
                m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
            let algo_const_ab = algo_const_a
                + algo_const_b
                + d_x * d_y * count_a * count_b / count_ab as f64;

            self.count = count_ab;
            self.mean_x = mean_x_ab;
            self.mean_y = mean_y_ab;
            self.m2_x = m2_x_ab;
            self.m2_y = m2_y_ab;
            self.algo_const = algo_const_ab;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let cov_pop_x_y = self.algo_const / self.count as f64;
        let var_pop_x = self.m2_x / self.count as f64;
        let var_pop_y = self.m2_y / self.count as f64;

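        // Derive the requested statistic from the population covariance and
        // variances:
        //   slope     = cov(x, y) / var(x)
        //   intercept = mean_y - slope * mean_x
        //   r2        = cov(x, y)^2 / (var(x) * var(y))
        // Slope, intercept and r2 return NULL when fewer than two pairs were
        // seen or the relevant variance is zero; the remaining statistics
        // return NULL only when no pairs were seen, and regr_count always
        // returns the count.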
        let nullif_or_stat = |cond: bool, stat: f64| {
            if cond {
                Ok(ScalarValue::Float64(None))
            } else {
                Ok(ScalarValue::Float64(Some(stat)))
            }
        };

        match self.regr_type {
            RegrType::Slope => {
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
            }
            RegrType::Intercept => {
                let slope = cov_pop_x_y / var_pop_x;
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
            }
            RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
            RegrType::R2 => {
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
                nullif_or_stat(
                    nullif_cond,
                    (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
                )
            }
            RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
            RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
            RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
            RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
            RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
        }
    }

    fn size(&self) -> usize {
        size_of_val(self)
    }
}