1use arrow::array::Float64Array;
21use arrow::datatypes::FieldRef;
22use arrow::{
23 array::{ArrayRef, UInt64Array},
24 compute::cast,
25 datatypes::DataType,
26 datatypes::Field,
27};
28use datafusion_common::{
29 downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result,
30 ScalarValue,
31};
32use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
33use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
34use datafusion_expr::type_coercion::aggregates::NUMERICS;
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::{
37 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
38};
39use std::any::Any;
40use std::fmt::Debug;
41use std::mem::size_of_val;
42use std::sync::{Arc, LazyLock};
43
// Declares the public API for one regression statistic in two steps:
//   1. `make_udaf_expr!` — presumably the `$EXPR_FN(expr_y, expr_x)` expression
//      helper, documented with "Compute a linear regression of type [<variant>]";
//   2. `create_func!` — presumably the `$AGGREGATE_UDF_FN` accessor, backed by a
//      `Regr` whose SQL name is the stringified `$EXPR_FN` (e.g. "regr_slope").
// Both helper macros are defined elsewhere in the crate.
macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        make_udaf_expr!(
            $EXPR_FN,
            expr_y expr_x,
            concat!(
                "Compute a linear regression of type [",
                stringify!($REGR_TYPE),
                "]"
            ),
            $AGGREGATE_UDF_FN
        );
        create_func!(
            $EXPR_FN,
            $AGGREGATE_UDF_FN,
            Regr::new($REGR_TYPE, stringify!($EXPR_FN))
        );
    };
}
50
51make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
52make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
53make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
54make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
55make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
56make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
57make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
58make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
59make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
60
61pub struct Regr {
62 signature: Signature,
63 regr_type: RegrType,
64 func_name: &'static str,
65}
66
67impl Debug for Regr {
68 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
69 f.debug_struct("regr")
70 .field("name", &self.name())
71 .field("signature", &self.signature)
72 .finish()
73 }
74}
75
76impl Regr {
77 pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
78 Self {
79 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
80 regr_type,
81 func_name,
82 }
83 }
84}
85
/// Selects which linear-regression statistic a [`Regr`] aggregate reports.
///
/// `y` is the first (dependent) argument and `x` the second (independent)
/// argument; only rows where both are non-null contribute.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum RegrType {
    /// Slope `k` of the least-squares line `y = k*x + b` (`regr_slope`).
    Slope,
    /// Intercept `b` of the least-squares line `y = k*x + b` (`regr_intercept`).
    Intercept,
    /// Number of pairs where both `y` and `x` are non-null (`regr_count`).
    Count,
    /// Square of the correlation coefficient of `y` and `x` (`regr_r2`).
    R2,
    /// Average of `x` over the non-null pairs (`regr_avgx`).
    AvgX,
    /// Average of `y` over the non-null pairs (`regr_avgy`).
    AvgY,
    /// Sum of squared deviations of `x` from its mean (`regr_sxx`).
    SXX,
    /// Sum of squared deviations of `y` from its mean (`regr_syy`).
    SYY,
    /// Sum of products of the `x` and `y` deviations from their means (`regr_sxy`).
    SXY,
}
128
129impl RegrType {
130 fn documentation(&self) -> Option<&Documentation> {
132 get_regr_docs().get(self)
133 }
134}
135
136static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
137 let mut hash_map = HashMap::new();
138 hash_map.insert(
139 RegrType::Slope,
140 Documentation::builder(
141 DOC_SECTION_STATISTICAL,
142 "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
143 Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
144
145 "regr_slope(expression_y, expression_x)")
146 .with_standard_argument("expression_y", Some("Dependent variable"))
147 .with_standard_argument("expression_x", Some("Independent variable"))
148 .build()
149 );
150
151 hash_map.insert(
152 RegrType::Intercept,
153 Documentation::builder(
154 DOC_SECTION_STATISTICAL,
155 "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
156 this function returns b.",
157
158 "regr_intercept(expression_y, expression_x)")
159 .with_standard_argument("expression_y", Some("Dependent variable"))
160 .with_standard_argument("expression_x", Some("Independent variable"))
161 .build()
162 );
163
164 hash_map.insert(
165 RegrType::Count,
166 Documentation::builder(
167 DOC_SECTION_STATISTICAL,
168 "Counts the number of non-null paired data points.",
169 "regr_count(expression_y, expression_x)",
170 )
171 .with_standard_argument("expression_y", Some("Dependent variable"))
172 .with_standard_argument("expression_x", Some("Independent variable"))
173 .build(),
174 );
175
176 hash_map.insert(
177 RegrType::R2,
178 Documentation::builder(
179 DOC_SECTION_STATISTICAL,
180 "Computes the square of the correlation coefficient between the independent and dependent variables.",
181
182 "regr_r2(expression_y, expression_x)")
183 .with_standard_argument("expression_y", Some("Dependent variable"))
184 .with_standard_argument("expression_x", Some("Independent variable"))
185 .build()
186 );
187
188 hash_map.insert(
189 RegrType::AvgX,
190 Documentation::builder(
191 DOC_SECTION_STATISTICAL,
192 "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
193
194 "regr_avgx(expression_y, expression_x)")
195 .with_standard_argument("expression_y", Some("Dependent variable"))
196 .with_standard_argument("expression_x", Some("Independent variable"))
197 .build()
198 );
199
200 hash_map.insert(
201 RegrType::AvgY,
202 Documentation::builder(
203 DOC_SECTION_STATISTICAL,
204 "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
205
206 "regr_avgy(expression_y, expression_x)")
207 .with_standard_argument("expression_y", Some("Dependent variable"))
208 .with_standard_argument("expression_x", Some("Independent variable"))
209 .build()
210 );
211
212 hash_map.insert(
213 RegrType::SXX,
214 Documentation::builder(
215 DOC_SECTION_STATISTICAL,
216 "Computes the sum of squares of the independent variable.",
217 "regr_sxx(expression_y, expression_x)",
218 )
219 .with_standard_argument("expression_y", Some("Dependent variable"))
220 .with_standard_argument("expression_x", Some("Independent variable"))
221 .build(),
222 );
223
224 hash_map.insert(
225 RegrType::SYY,
226 Documentation::builder(
227 DOC_SECTION_STATISTICAL,
228 "Computes the sum of squares of the dependent variable.",
229 "regr_syy(expression_y, expression_x)",
230 )
231 .with_standard_argument("expression_y", Some("Dependent variable"))
232 .with_standard_argument("expression_x", Some("Independent variable"))
233 .build(),
234 );
235
236 hash_map.insert(
237 RegrType::SXY,
238 Documentation::builder(
239 DOC_SECTION_STATISTICAL,
240 "Computes the sum of products of paired data points.",
241 "regr_sxy(expression_y, expression_x)",
242 )
243 .with_standard_argument("expression_y", Some("Dependent variable"))
244 .with_standard_argument("expression_x", Some("Independent variable"))
245 .build(),
246 );
247 hash_map
248});
249fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
250 &DOCUMENTATION
251}
252
253impl AggregateUDFImpl for Regr {
254 fn as_any(&self) -> &dyn Any {
255 self
256 }
257
258 fn name(&self) -> &str {
259 self.func_name
260 }
261
262 fn signature(&self) -> &Signature {
263 &self.signature
264 }
265
266 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
267 if !arg_types[0].is_numeric() {
268 return plan_err!("Covariance requires numeric input types");
269 }
270
271 if matches!(self.regr_type, RegrType::Count) {
272 Ok(DataType::UInt64)
273 } else {
274 Ok(DataType::Float64)
275 }
276 }
277
278 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
279 Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
280 }
281
282 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
283 Ok(vec![
284 Field::new(
285 format_state_name(args.name, "count"),
286 DataType::UInt64,
287 true,
288 ),
289 Field::new(
290 format_state_name(args.name, "mean_x"),
291 DataType::Float64,
292 true,
293 ),
294 Field::new(
295 format_state_name(args.name, "mean_y"),
296 DataType::Float64,
297 true,
298 ),
299 Field::new(
300 format_state_name(args.name, "m2_x"),
301 DataType::Float64,
302 true,
303 ),
304 Field::new(
305 format_state_name(args.name, "m2_y"),
306 DataType::Float64,
307 true,
308 ),
309 Field::new(
310 format_state_name(args.name, "algo_const"),
311 DataType::Float64,
312 true,
313 ),
314 ]
315 .into_iter()
316 .map(Arc::new)
317 .collect())
318 }
319
320 fn documentation(&self) -> Option<&Documentation> {
321 self.regr_type.documentation()
322 }
323}
324
/// Streaming accumulator shared by every `regr_*` aggregate.
///
/// Only rows where both `y` and `x` are non-null are counted. The fields hold
/// running moments maintained with Welford-style online updates; `evaluate`
/// turns them into the statistic selected by `regr_type`.
#[derive(Debug)]
pub struct RegrAccumulator {
    /// Number of non-null `(y, x)` pairs accumulated so far.
    count: u64,
    /// Running mean of `x` over the accumulated pairs.
    mean_x: f64,
    /// Running mean of `y` over the accumulated pairs.
    mean_y: f64,
    /// Sum of squared deviations of `x` from `mean_x` (reported by `regr_sxx`).
    m2_x: f64,
    /// Sum of squared deviations of `y` from `mean_y` (reported by `regr_syy`).
    m2_y: f64,
    /// Co-moment: sum of products of the `x` and `y` deviations (reported by `regr_sxy`).
    algo_const: f64,
    /// Which statistic `evaluate` produces.
    regr_type: RegrType,
}
374
375impl RegrAccumulator {
376 pub fn try_new(regr_type: &RegrType) -> Result<Self> {
378 Ok(Self {
379 count: 0_u64,
380 mean_x: 0_f64,
381 mean_y: 0_f64,
382 m2_x: 0_f64,
383 m2_y: 0_f64,
384 algo_const: 0_f64,
385 regr_type: regr_type.clone(),
386 })
387 }
388}
389
390impl Accumulator for RegrAccumulator {
391 fn state(&mut self) -> Result<Vec<ScalarValue>> {
392 Ok(vec![
393 ScalarValue::from(self.count),
394 ScalarValue::from(self.mean_x),
395 ScalarValue::from(self.mean_y),
396 ScalarValue::from(self.m2_x),
397 ScalarValue::from(self.m2_y),
398 ScalarValue::from(self.algo_const),
399 ])
400 }
401
402 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
403 let values_y = &cast(&values[0], &DataType::Float64)?;
405 let values_x = &cast(&values[1], &DataType::Float64)?;
406
407 let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
408 let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
409
410 for i in 0..values_y.len() {
411 let value_y = if values_y.is_valid(i) {
413 arr_y.next()
414 } else {
415 None
416 };
417 let value_x = if values_x.is_valid(i) {
418 arr_x.next()
419 } else {
420 None
421 };
422 if value_y.is_none() || value_x.is_none() {
423 continue;
424 }
425
426 let value_y = unwrap_or_internal_err!(value_y);
428 let value_x = unwrap_or_internal_err!(value_x);
429
430 self.count += 1;
431 let delta_x = value_x - self.mean_x;
432 let delta_y = value_y - self.mean_y;
433 self.mean_x += delta_x / self.count as f64;
434 self.mean_y += delta_y / self.count as f64;
435 let delta_x_2 = value_x - self.mean_x;
436 let delta_y_2 = value_y - self.mean_y;
437 self.m2_x += delta_x * delta_x_2;
438 self.m2_y += delta_y * delta_y_2;
439 self.algo_const += delta_x * (value_y - self.mean_y);
440 }
441
442 Ok(())
443 }
444
445 fn supports_retract_batch(&self) -> bool {
446 true
447 }
448
449 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
450 let values_y = &cast(&values[0], &DataType::Float64)?;
451 let values_x = &cast(&values[1], &DataType::Float64)?;
452
453 let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
454 let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
455
456 for i in 0..values_y.len() {
457 let value_y = if values_y.is_valid(i) {
459 arr_y.next()
460 } else {
461 None
462 };
463 let value_x = if values_x.is_valid(i) {
464 arr_x.next()
465 } else {
466 None
467 };
468 if value_y.is_none() || value_x.is_none() {
469 continue;
470 }
471
472 let value_y = unwrap_or_internal_err!(value_y);
474 let value_x = unwrap_or_internal_err!(value_x);
475
476 if self.count > 1 {
477 self.count -= 1;
478 let delta_x = value_x - self.mean_x;
479 let delta_y = value_y - self.mean_y;
480 self.mean_x -= delta_x / self.count as f64;
481 self.mean_y -= delta_y / self.count as f64;
482 let delta_x_2 = value_x - self.mean_x;
483 let delta_y_2 = value_y - self.mean_y;
484 self.m2_x -= delta_x * delta_x_2;
485 self.m2_y -= delta_y * delta_y_2;
486 self.algo_const -= delta_x * (value_y - self.mean_y);
487 } else {
488 self.count = 0;
489 self.mean_x = 0.0;
490 self.m2_x = 0.0;
491 self.m2_y = 0.0;
492 self.mean_y = 0.0;
493 self.algo_const = 0.0;
494 }
495 }
496
497 Ok(())
498 }
499
500 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
501 let count_arr = downcast_value!(states[0], UInt64Array);
502 let mean_x_arr = downcast_value!(states[1], Float64Array);
503 let mean_y_arr = downcast_value!(states[2], Float64Array);
504 let m2_x_arr = downcast_value!(states[3], Float64Array);
505 let m2_y_arr = downcast_value!(states[4], Float64Array);
506 let algo_const_arr = downcast_value!(states[5], Float64Array);
507
508 for i in 0..count_arr.len() {
509 let count_b = count_arr.value(i);
510 if count_b == 0_u64 {
511 continue;
512 }
513 let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
514 self.count,
515 self.mean_x,
516 self.mean_y,
517 self.m2_x,
518 self.m2_y,
519 self.algo_const,
520 );
521 let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
522 count_b,
523 mean_x_arr.value(i),
524 mean_y_arr.value(i),
525 m2_x_arr.value(i),
526 m2_y_arr.value(i),
527 algo_const_arr.value(i),
528 );
529
530 let count_ab = count_a + count_b;
539 let (count_a, count_b) = (count_a as f64, count_b as f64);
540 let d_x = mean_x_b - mean_x_a;
541 let d_y = mean_y_b - mean_y_a;
542 let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
543 let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
544 let m2_x_ab =
545 m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
546 let m2_y_ab =
547 m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
548 let algo_const_ab = algo_const_a
549 + algo_const_b
550 + d_x * d_y * count_a * count_b / count_ab as f64;
551
552 self.count = count_ab;
553 self.mean_x = mean_x_ab;
554 self.mean_y = mean_y_ab;
555 self.m2_x = m2_x_ab;
556 self.m2_y = m2_y_ab;
557 self.algo_const = algo_const_ab;
558 }
559 Ok(())
560 }
561
562 fn evaluate(&mut self) -> Result<ScalarValue> {
563 let cov_pop_x_y = self.algo_const / self.count as f64;
564 let var_pop_x = self.m2_x / self.count as f64;
565 let var_pop_y = self.m2_y / self.count as f64;
566
567 let nullif_or_stat = |cond: bool, stat: f64| {
568 if cond {
569 Ok(ScalarValue::Float64(None))
570 } else {
571 Ok(ScalarValue::Float64(Some(stat)))
572 }
573 };
574
575 match self.regr_type {
576 RegrType::Slope => {
577 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
579 nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
580 }
581 RegrType::Intercept => {
582 let slope = cov_pop_x_y / var_pop_x;
583 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
585 nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
586 }
587 RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
588 RegrType::R2 => {
589 let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
591 nullif_or_stat(
592 nullif_cond,
593 (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
594 )
595 }
596 RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
597 RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
598 RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
599 RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
600 RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
601 }
602 }
603
604 fn size(&self) -> usize {
605 size_of_val(self)
606 }
607}