datafusion_functions_aggregate/
covariance.rs1use arrow::datatypes::FieldRef;
21use arrow::{
22 array::{ArrayRef, Float64Array, UInt64Array},
23 compute::kernels::cast,
24 datatypes::{DataType, Field},
25};
26use datafusion_common::{
27 downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
28 ScalarValue,
29};
30use datafusion_expr::{
31 function::{AccumulatorArgs, StateFieldsArgs},
32 type_coercion::aggregates::NUMERICS,
33 utils::format_state_name,
34 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38use std::fmt::Debug;
39use std::mem::size_of_val;
40use std::sync::Arc;
41
42make_udaf_expr_and_func!(
43 CovarianceSample,
44 covar_samp,
45 y x,
46 "Computes the sample covariance.",
47 covar_samp_udaf
48);
49
50make_udaf_expr_and_func!(
51 CovariancePopulation,
52 covar_pop,
53 y x,
54 "Computes the population covariance.",
55 covar_pop_udaf
56);
57
58#[user_doc(
59 doc_section(label = "Statistical Functions"),
60 description = "Returns the sample covariance of a set of number pairs.",
61 syntax_example = "covar_samp(expression1, expression2)",
62 sql_example = r#"```sql
63> SELECT covar_samp(column1, column2) FROM table_name;
64+-----------------------------------+
65| covar_samp(column1, column2) |
66+-----------------------------------+
67| 8.25 |
68+-----------------------------------+
69```"#,
70 standard_argument(name = "expression1", prefix = "First"),
71 standard_argument(name = "expression2", prefix = "Second")
72)]
73#[derive(PartialEq, Eq, Hash)]
74pub struct CovarianceSample {
75 signature: Signature,
76 aliases: Vec<String>,
77}
78
79impl Debug for CovarianceSample {
80 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
81 f.debug_struct("CovarianceSample")
82 .field("name", &self.name())
83 .field("signature", &self.signature)
84 .finish()
85 }
86}
87
88impl Default for CovarianceSample {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl CovarianceSample {
95 pub fn new() -> Self {
96 Self {
97 aliases: vec![String::from("covar")],
98 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
99 }
100 }
101}
102
103impl AggregateUDFImpl for CovarianceSample {
104 fn as_any(&self) -> &dyn std::any::Any {
105 self
106 }
107
108 fn name(&self) -> &str {
109 "covar_samp"
110 }
111
112 fn signature(&self) -> &Signature {
113 &self.signature
114 }
115
116 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
117 if !arg_types[0].is_numeric() {
118 return plan_err!("Covariance requires numeric input types");
119 }
120
121 Ok(DataType::Float64)
122 }
123
124 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
125 let name = args.name;
126 Ok(vec![
127 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
128 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
129 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
130 Field::new(
131 format_state_name(name, "algo_const"),
132 DataType::Float64,
133 true,
134 ),
135 ]
136 .into_iter()
137 .map(Arc::new)
138 .collect())
139 }
140
141 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
142 Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
143 }
144
145 fn aliases(&self) -> &[String] {
146 &self.aliases
147 }
148
149 fn documentation(&self) -> Option<&Documentation> {
150 self.doc()
151 }
152}
153
154#[user_doc(
155 doc_section(label = "Statistical Functions"),
156 description = "Returns the sample covariance of a set of number pairs.",
157 syntax_example = "covar_samp(expression1, expression2)",
158 sql_example = r#"```sql
159> SELECT covar_samp(column1, column2) FROM table_name;
160+-----------------------------------+
161| covar_samp(column1, column2) |
162+-----------------------------------+
163| 8.25 |
164+-----------------------------------+
165```"#,
166 standard_argument(name = "expression1", prefix = "First"),
167 standard_argument(name = "expression2", prefix = "Second")
168)]
169#[derive(PartialEq, Eq, Hash)]
170pub struct CovariancePopulation {
171 signature: Signature,
172}
173
174impl Debug for CovariancePopulation {
175 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
176 f.debug_struct("CovariancePopulation")
177 .field("name", &self.name())
178 .field("signature", &self.signature)
179 .finish()
180 }
181}
182
183impl Default for CovariancePopulation {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
189impl CovariancePopulation {
190 pub fn new() -> Self {
191 Self {
192 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
193 }
194 }
195}
196
197impl AggregateUDFImpl for CovariancePopulation {
198 fn as_any(&self) -> &dyn std::any::Any {
199 self
200 }
201
202 fn name(&self) -> &str {
203 "covar_pop"
204 }
205
206 fn signature(&self) -> &Signature {
207 &self.signature
208 }
209
210 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
211 if !arg_types[0].is_numeric() {
212 return plan_err!("Covariance requires numeric input types");
213 }
214
215 Ok(DataType::Float64)
216 }
217
218 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
219 let name = args.name;
220 Ok(vec![
221 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
222 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
223 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
224 Field::new(
225 format_state_name(name, "algo_const"),
226 DataType::Float64,
227 true,
228 ),
229 ]
230 .into_iter()
231 .map(Arc::new)
232 .collect())
233 }
234
235 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
236 Ok(Box::new(CovarianceAccumulator::try_new(
237 StatsType::Population,
238 )?))
239 }
240
241 fn documentation(&self) -> Option<&Documentation> {
242 self.doc()
243 }
244}
245
246#[derive(Debug)]
260pub struct CovarianceAccumulator {
261 algo_const: f64,
262 mean1: f64,
263 mean2: f64,
264 count: u64,
265 stats_type: StatsType,
266}
267
268impl CovarianceAccumulator {
269 pub fn try_new(s_type: StatsType) -> Result<Self> {
271 Ok(Self {
272 algo_const: 0_f64,
273 mean1: 0_f64,
274 mean2: 0_f64,
275 count: 0_u64,
276 stats_type: s_type,
277 })
278 }
279
280 pub fn get_count(&self) -> u64 {
281 self.count
282 }
283
284 pub fn get_mean1(&self) -> f64 {
285 self.mean1
286 }
287
288 pub fn get_mean2(&self) -> f64 {
289 self.mean2
290 }
291
292 pub fn get_algo_const(&self) -> f64 {
293 self.algo_const
294 }
295}
296
297impl Accumulator for CovarianceAccumulator {
298 fn state(&mut self) -> Result<Vec<ScalarValue>> {
299 Ok(vec![
300 ScalarValue::from(self.count),
301 ScalarValue::from(self.mean1),
302 ScalarValue::from(self.mean2),
303 ScalarValue::from(self.algo_const),
304 ])
305 }
306
307 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
308 let values1 = &cast(&values[0], &DataType::Float64)?;
309 let values2 = &cast(&values[1], &DataType::Float64)?;
310
311 let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
312 let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
313
314 for i in 0..values1.len() {
315 let value1 = if values1.is_valid(i) {
316 arr1.next()
317 } else {
318 None
319 };
320 let value2 = if values2.is_valid(i) {
321 arr2.next()
322 } else {
323 None
324 };
325
326 if value1.is_none() || value2.is_none() {
327 continue;
328 }
329
330 let value1 = unwrap_or_internal_err!(value1);
331 let value2 = unwrap_or_internal_err!(value2);
332 let new_count = self.count + 1;
333 let delta1 = value1 - self.mean1;
334 let new_mean1 = delta1 / new_count as f64 + self.mean1;
335 let delta2 = value2 - self.mean2;
336 let new_mean2 = delta2 / new_count as f64 + self.mean2;
337 let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
338
339 self.count += 1;
340 self.mean1 = new_mean1;
341 self.mean2 = new_mean2;
342 self.algo_const = new_c;
343 }
344
345 Ok(())
346 }
347
348 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
349 let values1 = &cast(&values[0], &DataType::Float64)?;
350 let values2 = &cast(&values[1], &DataType::Float64)?;
351 let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
352 let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
353
354 for i in 0..values1.len() {
355 let value1 = if values1.is_valid(i) {
356 arr1.next()
357 } else {
358 None
359 };
360 let value2 = if values2.is_valid(i) {
361 arr2.next()
362 } else {
363 None
364 };
365
366 if value1.is_none() || value2.is_none() {
367 continue;
368 }
369
370 let value1 = unwrap_or_internal_err!(value1);
371 let value2 = unwrap_or_internal_err!(value2);
372
373 let new_count = self.count - 1;
374 let delta1 = self.mean1 - value1;
375 let new_mean1 = delta1 / new_count as f64 + self.mean1;
376 let delta2 = self.mean2 - value2;
377 let new_mean2 = delta2 / new_count as f64 + self.mean2;
378 let new_c = self.algo_const - delta1 * (new_mean2 - value2);
379
380 self.count -= 1;
381 self.mean1 = new_mean1;
382 self.mean2 = new_mean2;
383 self.algo_const = new_c;
384 }
385
386 Ok(())
387 }
388
389 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
390 let counts = downcast_value!(states[0], UInt64Array);
391 let means1 = downcast_value!(states[1], Float64Array);
392 let means2 = downcast_value!(states[2], Float64Array);
393 let cs = downcast_value!(states[3], Float64Array);
394
395 for i in 0..counts.len() {
396 let c = counts.value(i);
397 if c == 0_u64 {
398 continue;
399 }
400 let new_count = self.count + c;
401 let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
402 + means1.value(i) * c as f64 / new_count as f64;
403 let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
404 + means2.value(i) * c as f64 / new_count as f64;
405 let delta1 = self.mean1 - means1.value(i);
406 let delta2 = self.mean2 - means2.value(i);
407 let new_c = self.algo_const
408 + cs.value(i)
409 + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
410
411 self.count = new_count;
412 self.mean1 = new_mean1;
413 self.mean2 = new_mean2;
414 self.algo_const = new_c;
415 }
416 Ok(())
417 }
418
419 fn evaluate(&mut self) -> Result<ScalarValue> {
420 let count = match self.stats_type {
421 StatsType::Population => self.count,
422 StatsType::Sample => {
423 if self.count > 0 {
424 self.count - 1
425 } else {
426 self.count
427 }
428 }
429 };
430
431 if count == 0 {
432 Ok(ScalarValue::Float64(None))
433 } else {
434 Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
435 }
436 }
437
438 fn size(&self) -> usize {
439 size_of_val(self)
440 }
441}