// datafusion_comet_spark_expr/agg_funcs/covariance.rs
use arrow::datatypes::FieldRef;
use arrow::{
    array::{ArrayRef, Float64Array},
    compute::cast,
    datatypes::{DataType, Field},
};
use datafusion::common::{
    downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::type_coercion::aggregates::NUMERICS;
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion::physical_expr::expressions::format_state_name;
use datafusion::physical_expr::expressions::StatsType;
use std::any::Any;
use std::sync::Arc;

/// Spark-compatible covariance aggregate function (UDAF) over two
/// Float64 columns; `stats_type` selects the sample (`n - 1`) vs.
/// population (`n`) divisor applied in `evaluate`.
#[derive(Debug, Clone)]
pub struct Covariance {
    // Name used for display and for the generated state-field names.
    name: String,
    // Exactly two arguments, each of any numeric type (cast to Float64 later).
    signature: Signature,
    // Sample vs. population normalization.
    stats_type: StatsType,
    // When true, a zero divisor yields NULL instead of NaN.
    null_on_divide_by_zero: bool,
}
48
49impl Covariance {
50 pub fn new(
52 name: impl Into<String>,
53 data_type: DataType,
54 stats_type: StatsType,
55 null_on_divide_by_zero: bool,
56 ) -> Self {
57 assert!(matches!(data_type, DataType::Float64));
59 Self {
60 name: name.into(),
61 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
62 stats_type,
63 null_on_divide_by_zero,
64 }
65 }
66}
67
68impl AggregateUDFImpl for Covariance {
69 fn as_any(&self) -> &dyn Any {
71 self
72 }
73
74 fn name(&self) -> &str {
75 &self.name
76 }
77
78 fn signature(&self) -> &Signature {
79 &self.signature
80 }
81
82 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
83 Ok(DataType::Float64)
84 }
85 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
86 Ok(ScalarValue::Float64(None))
87 }
88
89 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
90 Ok(Box::new(CovarianceAccumulator::try_new(
91 self.stats_type,
92 self.null_on_divide_by_zero,
93 )?))
94 }
95
96 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
97 Ok(vec![
98 Arc::new(Field::new(
99 format_state_name(&self.name, "count"),
100 DataType::Float64,
101 true,
102 )),
103 Arc::new(Field::new(
104 format_state_name(&self.name, "mean1"),
105 DataType::Float64,
106 true,
107 )),
108 Arc::new(Field::new(
109 format_state_name(&self.name, "mean2"),
110 DataType::Float64,
111 true,
112 )),
113 Arc::new(Field::new(
114 format_state_name(&self.name, "algo_const"),
115 DataType::Float64,
116 true,
117 )),
118 ])
119 }
120}
121
/// Running state for the single-pass (online) covariance computation.
#[derive(Debug)]
pub struct CovarianceAccumulator {
    // Running co-moment C = Σ (x_i − mean1)(y_i − mean2); the final
    // covariance is C divided by the chosen count (see `evaluate`).
    algo_const: f64,
    // Running mean of the first input column.
    mean1: f64,
    // Running mean of the second input column.
    mean2: f64,
    // Number of rows where both inputs were non-null; kept as f64 so it can
    // participate directly in the floating-point update formulas.
    count: f64,
    // Sample (`n - 1`) vs. population (`n`) divisor selection.
    stats_type: StatsType,
    // When true, a zero divisor produces NULL rather than NaN.
    null_on_divide_by_zero: bool,
}
132
133impl CovarianceAccumulator {
134 pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
136 Ok(Self {
137 algo_const: 0_f64,
138 mean1: 0_f64,
139 mean2: 0_f64,
140 count: 0_f64,
141 stats_type: s_type,
142 null_on_divide_by_zero,
143 })
144 }
145
146 pub fn get_count(&self) -> f64 {
147 self.count
148 }
149
150 pub fn get_mean1(&self) -> f64 {
151 self.mean1
152 }
153
154 pub fn get_mean2(&self) -> f64 {
155 self.mean2
156 }
157
158 pub fn get_algo_const(&self) -> f64 {
159 self.algo_const
160 }
161}
162
impl Accumulator for CovarianceAccumulator {
    /// Serializes the running state in the order declared by
    /// `Covariance::state_fields`: count, mean1, mean2, algo_const.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean1),
            ScalarValue::from(self.mean2),
            ScalarValue::from(self.algo_const),
        ])
    }

    /// Folds a batch of (x, y) pairs into the running statistics with a
    /// single-pass online update. Rows where either side is NULL are
    /// skipped entirely.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // The signature admits any numeric type; normalize both columns to f64.
        let values1 = &cast(&values[0], &DataType::Float64)?;
        let values2 = &cast(&values[1], &DataType::Float64)?;

        // `flatten()` yields only the non-null values; validity is checked
        // per row below so each iterator is advanced exactly once per valid
        // slot and stays aligned with row index `i`.
        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();

        for i in 0..values1.len() {
            let value1 = if values1.is_valid(i) {
                arr1.next()
            } else {
                None
            };
            let value2 = if values2.is_valid(i) {
                arr2.next()
            } else {
                None
            };

            // A NULL on either side means the pair contributes nothing.
            if value1.is_none() || value2.is_none() {
                continue;
            }

            let value1 = unwrap_or_internal_err!(value1);
            let value2 = unwrap_or_internal_err!(value2);
            // Online update: shift each mean by delta/new_count, then grow the
            // co-moment using the OLD delta1 and the NEW mean2 — this exact
            // pairing is what keeps the recurrence correct; do not reorder.
            let new_count = self.count + 1.0;
            let delta1 = value1 - self.mean1;
            let new_mean1 = delta1 / new_count + self.mean1;
            let delta2 = value2 - self.mean2;
            let new_mean2 = delta2 / new_count + self.mean2;
            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;

            self.count += 1.0;
            self.mean1 = new_mean1;
            self.mean2 = new_mean2;
            self.algo_const = new_c;
        }

        Ok(())
    }

    /// Removes a batch of previously-accumulated (x, y) pairs, inverting the
    /// `update_batch` recurrence (used for sliding-window aggregation).
    /// NULL-skipping mirrors `update_batch`.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values1 = &cast(&values[0], &DataType::Float64)?;
        let values2 = &cast(&values[1], &DataType::Float64)?;
        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();

        for i in 0..values1.len() {
            let value1 = if values1.is_valid(i) {
                arr1.next()
            } else {
                None
            };
            let value2 = if values2.is_valid(i) {
                arr2.next()
            } else {
                None
            };

            if value1.is_none() || value2.is_none() {
                continue;
            }

            let value1 = unwrap_or_internal_err!(value1);
            let value2 = unwrap_or_internal_err!(value2);

            // Inverse recurrence: note the deltas are sign-flipped relative
            // to update_batch (mean - value rather than value - mean).
            // NOTE(review): retracting when count == 1 makes new_count 0 and
            // divides by zero (non-finite means) — presumably callers never
            // retract more rows than were added; verify against window exec.
            let new_count = self.count - 1.0;
            let delta1 = self.mean1 - value1;
            let new_mean1 = delta1 / new_count + self.mean1;
            let delta2 = self.mean2 - value2;
            let new_mean2 = delta2 / new_count + self.mean2;
            let new_c = self.algo_const - delta1 * (new_mean2 - value2);

            self.count -= 1.0;
            self.mean1 = new_mean1;
            self.mean2 = new_mean2;
            self.algo_const = new_c;
        }

        Ok(())
    }

    /// Combines serialized partial states (count, mean1, mean2, algo_const)
    /// from other accumulators into this one.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Column order matches `state()` / `state_fields()`.
        let counts = downcast_value!(states[0], Float64Array);
        let means1 = downcast_value!(states[1], Float64Array);
        let means2 = downcast_value!(states[2], Float64Array);
        let cs = downcast_value!(states[3], Float64Array);

        for i in 0..counts.len() {
            let c = counts.value(i);
            // An empty partial state contributes nothing; skipping also avoids
            // a 0/0 in the combined-mean formulas when both sides are empty.
            if c == 0.0 {
                continue;
            }
            // Pairwise combination: weighted means, and co-moments merged with
            // the cross-term delta1 * delta2 * n_a * n_b / (n_a + n_b).
            // Deltas intentionally use the PRE-merge means.
            let new_count = self.count + c;
            let new_mean1 = self.mean1 * self.count / new_count + means1.value(i) * c / new_count;
            let new_mean2 = self.mean2 * self.count / new_count + means2.value(i) * c / new_count;
            let delta1 = self.mean1 - means1.value(i);
            let delta2 = self.mean2 - means2.value(i);
            let new_c =
                self.algo_const + cs.value(i) + delta1 * delta2 * self.count * c / new_count;

            self.count = new_count;
            self.mean1 = new_mean1;
            self.mean2 = new_mean2;
            self.algo_const = new_c;
        }
        Ok(())
    }

    /// Produces the final covariance: co-moment divided by the divisor chosen
    /// by `stats_type`. No input rows yield NULL; a single-row sample
    /// covariance yields NULL or NaN depending on `null_on_divide_by_zero`.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        if self.count == 0.0 {
            return Ok(ScalarValue::Float64(None));
        }

        let count = match self.stats_type {
            StatsType::Population => self.count,
            // Sample covariance uses Bessel's correction (n - 1).
            StatsType::Sample if self.count > 1.0 => self.count - 1.0,
            // count == 1: the (n - 1) divisor would be zero.
            StatsType::Sample => {
                return if self.null_on_divide_by_zero {
                    Ok(ScalarValue::Float64(None))
                } else {
                    Ok(ScalarValue::Float64(Some(f64::NAN)))
                };
            }
        };

        Ok(ScalarValue::Float64(Some(self.algo_const / count)))
    }

    /// In-memory size of the accumulator (all fields are inline, no heap).
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}