datafusion_functions_aggregate/
covariance.rs1use arrow::array::ArrayRef;
21use arrow::datatypes::{DataType, Field, FieldRef};
22use datafusion_common::cast::{as_float64_array, as_uint64_array};
23use datafusion_common::{Result, ScalarValue};
24use datafusion_expr::{
25 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
26 function::{AccumulatorArgs, StateFieldsArgs},
27 utils::format_state_name,
28};
29use datafusion_functions_aggregate_common::stats::StatsType;
30use datafusion_macros::user_doc;
31use std::fmt::Debug;
32use std::mem::size_of_val;
33use std::sync::Arc;
34
// Invokes `make_udaf_expr_and_func!` for the sample-covariance aggregate,
// passing the `CovarianceSample` implementation type, the name `covar_samp`,
// the argument names `y x`, a one-line description, and the identifier
// `covar_samp_udaf`.
// NOTE(review): the macro is defined outside this file; presumably it emits an
// expression-builder function `covar_samp` and an accessor `covar_samp_udaf`
// for the registered UDAF — confirm against the macro definition.
make_udaf_expr_and_func!(
    CovarianceSample,
    covar_samp,
    y x,
    "Computes the sample covariance.",
    covar_samp_udaf
);
42
// Invokes `make_udaf_expr_and_func!` for the population-covariance aggregate,
// passing the `CovariancePopulation` implementation type, the name `covar_pop`,
// the argument names `y x`, a one-line description, and the identifier
// `covar_pop_udaf`.
// NOTE(review): the macro is defined outside this file; presumably it emits an
// expression-builder function `covar_pop` and an accessor `covar_pop_udaf`
// for the registered UDAF — confirm against the macro definition.
make_udaf_expr_and_func!(
    CovariancePopulation,
    covar_pop,
    y x,
    "Computes the population covariance.",
    covar_pop_udaf
);
50
/// Aggregate UDF definition for `covar_samp`, the sample covariance of two
/// `Float64` inputs.
///
/// The user-facing documentation is declared through the `user_doc` attribute
/// below and returned from `documentation()` via `self.doc()`.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the sample covariance of a set of number pairs.",
    syntax_example = "covar_samp(expression1, expression2)",
    sql_example = r#"```sql
> SELECT covar_samp(column1, column2) FROM table_name;
+-----------------------------------+
| covar_samp(column1, column2) |
+-----------------------------------+
| 8.25 |
+-----------------------------------+
```"#,
    standard_argument(name = "expression1", prefix = "First"),
    standard_argument(name = "expression2", prefix = "Second")
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct CovarianceSample {
    /// Argument signature accepted by the function (set in `new`).
    signature: Signature,
    /// Alternative names under which the function is exposed (set in `new`).
    aliases: Vec<String>,
}
71
72impl Default for CovarianceSample {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl CovarianceSample {
79 pub fn new() -> Self {
80 Self {
81 aliases: vec![String::from("covar")],
82 signature: Signature::exact(
83 vec![DataType::Float64, DataType::Float64],
84 Volatility::Immutable,
85 ),
86 }
87 }
88}
89
90impl AggregateUDFImpl for CovarianceSample {
91 fn as_any(&self) -> &dyn std::any::Any {
92 self
93 }
94
95 fn name(&self) -> &str {
96 "covar_samp"
97 }
98
99 fn signature(&self) -> &Signature {
100 &self.signature
101 }
102
103 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
104 Ok(DataType::Float64)
105 }
106
107 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
108 let name = args.name;
109 Ok(vec![
110 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
111 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
112 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
113 Field::new(
114 format_state_name(name, "algo_const"),
115 DataType::Float64,
116 true,
117 ),
118 ]
119 .into_iter()
120 .map(Arc::new)
121 .collect())
122 }
123
124 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
125 Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
126 }
127
128 fn aliases(&self) -> &[String] {
129 &self.aliases
130 }
131
132 fn documentation(&self) -> Option<&Documentation> {
133 self.doc()
134 }
135}
136
137#[user_doc(
138 doc_section(label = "Statistical Functions"),
139 description = "Returns the sample covariance of a set of number pairs.",
140 syntax_example = "covar_samp(expression1, expression2)",
141 sql_example = r#"```sql
142> SELECT covar_samp(column1, column2) FROM table_name;
143+-----------------------------------+
144| covar_samp(column1, column2) |
145+-----------------------------------+
146| 8.25 |
147+-----------------------------------+
148```"#,
149 standard_argument(name = "expression1", prefix = "First"),
150 standard_argument(name = "expression2", prefix = "Second")
151)]
152#[derive(PartialEq, Eq, Hash, Debug)]
153pub struct CovariancePopulation {
154 signature: Signature,
155}
156
157impl Default for CovariancePopulation {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
163impl CovariancePopulation {
164 pub fn new() -> Self {
165 Self {
166 signature: Signature::exact(
167 vec![DataType::Float64, DataType::Float64],
168 Volatility::Immutable,
169 ),
170 }
171 }
172}
173
174impl AggregateUDFImpl for CovariancePopulation {
175 fn as_any(&self) -> &dyn std::any::Any {
176 self
177 }
178
179 fn name(&self) -> &str {
180 "covar_pop"
181 }
182
183 fn signature(&self) -> &Signature {
184 &self.signature
185 }
186
187 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
188 Ok(DataType::Float64)
189 }
190
191 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
192 let name = args.name;
193 Ok(vec![
194 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
195 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
196 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
197 Field::new(
198 format_state_name(name, "algo_const"),
199 DataType::Float64,
200 true,
201 ),
202 ]
203 .into_iter()
204 .map(Arc::new)
205 .collect())
206 }
207
208 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
209 Ok(Box::new(CovarianceAccumulator::try_new(
210 StatsType::Population,
211 )?))
212 }
213
214 fn documentation(&self) -> Option<&Documentation> {
215 self.doc()
216 }
217}
218
/// Running state for computing the covariance of two `Float64` columns.
///
/// The state is maintained incrementally (one pair at a time) in
/// `update_batch`/`retract_batch` and combined across partial states in
/// `merge_batch`; `evaluate` turns it into the final covariance.
#[derive(Debug)]
pub struct CovarianceAccumulator {
    /// Running co-moment: the accumulated sum of products of deviations from
    /// the means. `evaluate` divides it by `count` (population) or
    /// `count - 1` (sample).
    algo_const: f64,
    /// Running mean of the first input column.
    mean1: f64,
    /// Running mean of the second input column.
    mean2: f64,
    /// Number of pairs accumulated so far (pairs with a null on either side
    /// are not counted).
    count: u64,
    /// Whether `evaluate` produces the sample or the population statistic.
    stats_type: StatsType,
}
240
241impl CovarianceAccumulator {
242 pub fn try_new(s_type: StatsType) -> Result<Self> {
244 Ok(Self {
245 algo_const: 0_f64,
246 mean1: 0_f64,
247 mean2: 0_f64,
248 count: 0_u64,
249 stats_type: s_type,
250 })
251 }
252
253 pub fn get_count(&self) -> u64 {
254 self.count
255 }
256
257 pub fn get_mean1(&self) -> f64 {
258 self.mean1
259 }
260
261 pub fn get_mean2(&self) -> f64 {
262 self.mean2
263 }
264
265 pub fn get_algo_const(&self) -> f64 {
266 self.algo_const
267 }
268}
269
270impl Accumulator for CovarianceAccumulator {
271 fn state(&mut self) -> Result<Vec<ScalarValue>> {
272 Ok(vec![
273 ScalarValue::from(self.count),
274 ScalarValue::from(self.mean1),
275 ScalarValue::from(self.mean2),
276 ScalarValue::from(self.algo_const),
277 ])
278 }
279
280 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
281 let values1 = as_float64_array(&values[0])?;
282 let values2 = as_float64_array(&values[1])?;
283
284 for (value1, value2) in values1.iter().zip(values2) {
285 let (value1, value2) = match (value1, value2) {
286 (Some(a), Some(b)) => (a, b),
287 _ => continue,
288 };
289
290 let new_count = self.count + 1;
291 let delta1 = value1 - self.mean1;
292 let new_mean1 = delta1 / new_count as f64 + self.mean1;
293 let delta2 = value2 - self.mean2;
294 let new_mean2 = delta2 / new_count as f64 + self.mean2;
295 let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
296
297 self.count += 1;
298 self.mean1 = new_mean1;
299 self.mean2 = new_mean2;
300 self.algo_const = new_c;
301 }
302
303 Ok(())
304 }
305
306 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307 let values1 = as_float64_array(&values[0])?;
308 let values2 = as_float64_array(&values[1])?;
309
310 for (value1, value2) in values1.iter().zip(values2) {
311 let (value1, value2) = match (value1, value2) {
312 (Some(a), Some(b)) => (a, b),
313 _ => continue,
314 };
315
316 let new_count = self.count - 1;
317 let delta1 = self.mean1 - value1;
318 let new_mean1 = delta1 / new_count as f64 + self.mean1;
319 let delta2 = self.mean2 - value2;
320 let new_mean2 = delta2 / new_count as f64 + self.mean2;
321 let new_c = self.algo_const - delta1 * (new_mean2 - value2);
322
323 self.count -= 1;
324 self.mean1 = new_mean1;
325 self.mean2 = new_mean2;
326 self.algo_const = new_c;
327 }
328
329 Ok(())
330 }
331
332 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
333 let counts = as_uint64_array(&states[0])?;
334 let means1 = as_float64_array(&states[1])?;
335 let means2 = as_float64_array(&states[2])?;
336 let cs = as_float64_array(&states[3])?;
337
338 for i in 0..counts.len() {
339 let c = counts.value(i);
340 if c == 0_u64 {
341 continue;
342 }
343 let new_count = self.count + c;
344 let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
345 + means1.value(i) * c as f64 / new_count as f64;
346 let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
347 + means2.value(i) * c as f64 / new_count as f64;
348 let delta1 = self.mean1 - means1.value(i);
349 let delta2 = self.mean2 - means2.value(i);
350 let new_c = self.algo_const
351 + cs.value(i)
352 + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
353
354 self.count = new_count;
355 self.mean1 = new_mean1;
356 self.mean2 = new_mean2;
357 self.algo_const = new_c;
358 }
359 Ok(())
360 }
361
362 fn evaluate(&mut self) -> Result<ScalarValue> {
363 let count = match self.stats_type {
364 StatsType::Population => self.count,
365 StatsType::Sample => {
366 if self.count > 0 {
367 self.count - 1
368 } else {
369 self.count
370 }
371 }
372 };
373
374 if count == 0 {
375 Ok(ScalarValue::Float64(None))
376 } else {
377 Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
378 }
379 }
380
381 fn size(&self) -> usize {
382 size_of_val(self)
383 }
384}