datafusion_functions_aggregate/
covariance.rs1use arrow::datatypes::FieldRef;
21use arrow::{
22 array::{ArrayRef, Float64Array, UInt64Array},
23 compute::kernels::cast,
24 datatypes::{DataType, Field},
25};
26use datafusion_common::{
27 Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err,
28};
29use datafusion_expr::{
30 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
31 function::{AccumulatorArgs, StateFieldsArgs},
32 type_coercion::aggregates::NUMERICS,
33 utils::format_state_name,
34};
35use datafusion_functions_aggregate_common::stats::StatsType;
36use datafusion_macros::user_doc;
37use std::fmt::Debug;
38use std::mem::size_of_val;
39use std::sync::Arc;
40
// Invokes the `make_udaf_expr_and_func!` macro (defined outside this file) for
// `CovarianceSample`, passing the function name `covar_samp`, the argument
// names `y x`, a one-line description, and the identifier `covar_samp_udaf`.
// NOTE(review): presumably this generates the `covar_samp(y, x)` expression
// builder and a `covar_samp_udaf()` accessor -- confirm against the macro.
make_udaf_expr_and_func!(
    CovarianceSample,
    covar_samp,
    y x,
    "Computes the sample covariance.",
    covar_samp_udaf
);
48
// Invokes the `make_udaf_expr_and_func!` macro (defined outside this file) for
// `CovariancePopulation`, passing the function name `covar_pop`, the argument
// names `y x`, a one-line description, and the identifier `covar_pop_udaf`.
// NOTE(review): presumably this generates the `covar_pop(y, x)` expression
// builder and a `covar_pop_udaf()` accessor -- confirm against the macro.
make_udaf_expr_and_func!(
    CovariancePopulation,
    covar_pop,
    y x,
    "Computes the population covariance.",
    covar_pop_udaf
);
56
57#[user_doc(
58 doc_section(label = "Statistical Functions"),
59 description = "Returns the sample covariance of a set of number pairs.",
60 syntax_example = "covar_samp(expression1, expression2)",
61 sql_example = r#"```sql
62> SELECT covar_samp(column1, column2) FROM table_name;
63+-----------------------------------+
64| covar_samp(column1, column2) |
65+-----------------------------------+
66| 8.25 |
67+-----------------------------------+
68```"#,
69 standard_argument(name = "expression1", prefix = "First"),
70 standard_argument(name = "expression2", prefix = "Second")
71)]
72#[derive(PartialEq, Eq, Hash)]
73pub struct CovarianceSample {
74 signature: Signature,
75 aliases: Vec<String>,
76}
77
78impl Debug for CovarianceSample {
79 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
80 f.debug_struct("CovarianceSample")
81 .field("name", &self.name())
82 .field("signature", &self.signature)
83 .finish()
84 }
85}
86
87impl Default for CovarianceSample {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl CovarianceSample {
94 pub fn new() -> Self {
95 Self {
96 aliases: vec![String::from("covar")],
97 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
98 }
99 }
100}
101
102impl AggregateUDFImpl for CovarianceSample {
103 fn as_any(&self) -> &dyn std::any::Any {
104 self
105 }
106
107 fn name(&self) -> &str {
108 "covar_samp"
109 }
110
111 fn signature(&self) -> &Signature {
112 &self.signature
113 }
114
115 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
116 if !arg_types[0].is_numeric() {
117 return plan_err!("Covariance requires numeric input types");
118 }
119
120 Ok(DataType::Float64)
121 }
122
123 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
124 let name = args.name;
125 Ok(vec![
126 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
127 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
128 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
129 Field::new(
130 format_state_name(name, "algo_const"),
131 DataType::Float64,
132 true,
133 ),
134 ]
135 .into_iter()
136 .map(Arc::new)
137 .collect())
138 }
139
140 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
141 Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
142 }
143
144 fn aliases(&self) -> &[String] {
145 &self.aliases
146 }
147
148 fn documentation(&self) -> Option<&Documentation> {
149 self.doc()
150 }
151}
152
153#[user_doc(
154 doc_section(label = "Statistical Functions"),
155 description = "Returns the sample covariance of a set of number pairs.",
156 syntax_example = "covar_samp(expression1, expression2)",
157 sql_example = r#"```sql
158> SELECT covar_samp(column1, column2) FROM table_name;
159+-----------------------------------+
160| covar_samp(column1, column2) |
161+-----------------------------------+
162| 8.25 |
163+-----------------------------------+
164```"#,
165 standard_argument(name = "expression1", prefix = "First"),
166 standard_argument(name = "expression2", prefix = "Second")
167)]
168#[derive(PartialEq, Eq, Hash)]
169pub struct CovariancePopulation {
170 signature: Signature,
171}
172
173impl Debug for CovariancePopulation {
174 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
175 f.debug_struct("CovariancePopulation")
176 .field("name", &self.name())
177 .field("signature", &self.signature)
178 .finish()
179 }
180}
181
182impl Default for CovariancePopulation {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl CovariancePopulation {
189 pub fn new() -> Self {
190 Self {
191 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
192 }
193 }
194}
195
196impl AggregateUDFImpl for CovariancePopulation {
197 fn as_any(&self) -> &dyn std::any::Any {
198 self
199 }
200
201 fn name(&self) -> &str {
202 "covar_pop"
203 }
204
205 fn signature(&self) -> &Signature {
206 &self.signature
207 }
208
209 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
210 if !arg_types[0].is_numeric() {
211 return plan_err!("Covariance requires numeric input types");
212 }
213
214 Ok(DataType::Float64)
215 }
216
217 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
218 let name = args.name;
219 Ok(vec![
220 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
221 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
222 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
223 Field::new(
224 format_state_name(name, "algo_const"),
225 DataType::Float64,
226 true,
227 ),
228 ]
229 .into_iter()
230 .map(Arc::new)
231 .collect())
232 }
233
234 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
235 Ok(Box::new(CovarianceAccumulator::try_new(
236 StatsType::Population,
237 )?))
238 }
239
240 fn documentation(&self) -> Option<&Documentation> {
241 self.doc()
242 }
243}
244
/// Running state for computing the covariance of a stream of value pairs.
///
/// The state is updated one pair at a time (single pass) and can be merged
/// with the state of another accumulator, so partial results computed
/// independently can be combined.
#[derive(Debug)]
pub struct CovarianceAccumulator {
    /// Running sum of co-deviations; `evaluate` divides it by `count`
    /// (population) or `count - 1` (sample) to obtain the covariance.
    algo_const: f64,
    /// Running mean of the first argument over the accumulated pairs.
    mean1: f64,
    /// Running mean of the second argument over the accumulated pairs.
    mean2: f64,
    /// Number of pairs accumulated; pairs with a null on either side are
    /// skipped and not counted.
    count: u64,
    /// Selects the sample or population divisor used by `evaluate`.
    stats_type: StatsType,
}
266
267impl CovarianceAccumulator {
268 pub fn try_new(s_type: StatsType) -> Result<Self> {
270 Ok(Self {
271 algo_const: 0_f64,
272 mean1: 0_f64,
273 mean2: 0_f64,
274 count: 0_u64,
275 stats_type: s_type,
276 })
277 }
278
279 pub fn get_count(&self) -> u64 {
280 self.count
281 }
282
283 pub fn get_mean1(&self) -> f64 {
284 self.mean1
285 }
286
287 pub fn get_mean2(&self) -> f64 {
288 self.mean2
289 }
290
291 pub fn get_algo_const(&self) -> f64 {
292 self.algo_const
293 }
294}
295
296impl Accumulator for CovarianceAccumulator {
297 fn state(&mut self) -> Result<Vec<ScalarValue>> {
298 Ok(vec![
299 ScalarValue::from(self.count),
300 ScalarValue::from(self.mean1),
301 ScalarValue::from(self.mean2),
302 ScalarValue::from(self.algo_const),
303 ])
304 }
305
306 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307 let values1 = &cast(&values[0], &DataType::Float64)?;
308 let values2 = &cast(&values[1], &DataType::Float64)?;
309
310 let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
311 let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
312
313 for i in 0..values1.len() {
314 let value1 = if values1.is_valid(i) {
315 arr1.next()
316 } else {
317 None
318 };
319 let value2 = if values2.is_valid(i) {
320 arr2.next()
321 } else {
322 None
323 };
324
325 if value1.is_none() || value2.is_none() {
326 continue;
327 }
328
329 let value1 = unwrap_or_internal_err!(value1);
330 let value2 = unwrap_or_internal_err!(value2);
331 let new_count = self.count + 1;
332 let delta1 = value1 - self.mean1;
333 let new_mean1 = delta1 / new_count as f64 + self.mean1;
334 let delta2 = value2 - self.mean2;
335 let new_mean2 = delta2 / new_count as f64 + self.mean2;
336 let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
337
338 self.count += 1;
339 self.mean1 = new_mean1;
340 self.mean2 = new_mean2;
341 self.algo_const = new_c;
342 }
343
344 Ok(())
345 }
346
347 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
348 let values1 = &cast(&values[0], &DataType::Float64)?;
349 let values2 = &cast(&values[1], &DataType::Float64)?;
350 let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
351 let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
352
353 for i in 0..values1.len() {
354 let value1 = if values1.is_valid(i) {
355 arr1.next()
356 } else {
357 None
358 };
359 let value2 = if values2.is_valid(i) {
360 arr2.next()
361 } else {
362 None
363 };
364
365 if value1.is_none() || value2.is_none() {
366 continue;
367 }
368
369 let value1 = unwrap_or_internal_err!(value1);
370 let value2 = unwrap_or_internal_err!(value2);
371
372 let new_count = self.count - 1;
373 let delta1 = self.mean1 - value1;
374 let new_mean1 = delta1 / new_count as f64 + self.mean1;
375 let delta2 = self.mean2 - value2;
376 let new_mean2 = delta2 / new_count as f64 + self.mean2;
377 let new_c = self.algo_const - delta1 * (new_mean2 - value2);
378
379 self.count -= 1;
380 self.mean1 = new_mean1;
381 self.mean2 = new_mean2;
382 self.algo_const = new_c;
383 }
384
385 Ok(())
386 }
387
388 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
389 let counts = downcast_value!(states[0], UInt64Array);
390 let means1 = downcast_value!(states[1], Float64Array);
391 let means2 = downcast_value!(states[2], Float64Array);
392 let cs = downcast_value!(states[3], Float64Array);
393
394 for i in 0..counts.len() {
395 let c = counts.value(i);
396 if c == 0_u64 {
397 continue;
398 }
399 let new_count = self.count + c;
400 let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
401 + means1.value(i) * c as f64 / new_count as f64;
402 let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
403 + means2.value(i) * c as f64 / new_count as f64;
404 let delta1 = self.mean1 - means1.value(i);
405 let delta2 = self.mean2 - means2.value(i);
406 let new_c = self.algo_const
407 + cs.value(i)
408 + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
409
410 self.count = new_count;
411 self.mean1 = new_mean1;
412 self.mean2 = new_mean2;
413 self.algo_const = new_c;
414 }
415 Ok(())
416 }
417
418 fn evaluate(&mut self) -> Result<ScalarValue> {
419 let count = match self.stats_type {
420 StatsType::Population => self.count,
421 StatsType::Sample => {
422 if self.count > 0 {
423 self.count - 1
424 } else {
425 self.count
426 }
427 }
428 };
429
430 if count == 0 {
431 Ok(ScalarValue::Float64(None))
432 } else {
433 Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
434 }
435 }
436
437 fn size(&self) -> usize {
438 size_of_val(self)
439 }
440}