// datafusion_comet_spark_expr/agg_funcs/avg.rs
use arrow::array::{
    builder::PrimitiveBuilder,
    cast::AsArray,
    types::{Float64Type, Int64Type},
    Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
};
use arrow::compute::sum;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion::common::{not_impl_err, Result, ScalarValue};
use datafusion::logical_expr::{
    type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl, EmitTo,
    GroupsAccumulator, ReversedUDAF, Signature,
};
use datafusion::physical_expr::expressions::format_state_name;
use std::{any::Any, sync::Arc};

use arrow::array::ArrowNativeTypeOp;
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::Volatility::Immutable;
use DataType::*;
38
39#[derive(Debug, Clone)]
41pub struct Avg {
42 name: String,
43 signature: Signature,
44 input_data_type: DataType,
46 result_data_type: DataType,
47}
48
49impl Avg {
50 pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
52 let result_data_type = avg_return_type("avg", &data_type).unwrap();
53
54 Self {
55 name: name.into(),
56 signature: Signature::user_defined(Immutable),
57 input_data_type: data_type,
58 result_data_type,
59 }
60 }
61}
62
63impl AggregateUDFImpl for Avg {
64 fn as_any(&self) -> &dyn Any {
66 self
67 }
68
69 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
70 match (&self.input_data_type, &self.result_data_type) {
72 (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
73 _ => not_impl_err!(
74 "AvgAccumulator for ({} --> {})",
75 self.input_data_type,
76 self.result_data_type
77 ),
78 }
79 }
80
81 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
82 Ok(vec![
83 Arc::new(Field::new(
84 format_state_name(&self.name, "sum"),
85 self.input_data_type.clone(),
86 true,
87 )),
88 Arc::new(Field::new(
89 format_state_name(&self.name, "count"),
90 DataType::Int64,
91 true,
92 )),
93 ])
94 }
95
96 fn name(&self) -> &str {
97 &self.name
98 }
99
100 fn reverse_expr(&self) -> ReversedUDAF {
101 ReversedUDAF::Identical
102 }
103
104 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
105 true
106 }
107
108 fn create_groups_accumulator(
109 &self,
110 _args: AccumulatorArgs,
111 ) -> Result<Box<dyn GroupsAccumulator>> {
112 match (&self.input_data_type, &self.result_data_type) {
114 (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
115 &self.input_data_type,
116 |sum: f64, count: i64| Ok(sum / count as f64),
117 ))),
118
119 _ => not_impl_err!(
120 "AvgGroupsAccumulator for ({} --> {})",
121 self.input_data_type,
122 self.result_data_type
123 ),
124 }
125 }
126
127 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
128 Ok(ScalarValue::Float64(None))
129 }
130
131 fn signature(&self) -> &Signature {
132 &self.signature
133 }
134
135 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
136 avg_return_type(self.name(), &arg_types[0])
137 }
138}
139
/// Row-at-a-time accumulator for `avg` over `Float64` input.
///
/// `sum` remains `None` until `update_batch` runs, or until `merge_batch` sees a
/// non-null partial sum. `count` is the number of non-null values seen
/// (including merged partial counts).
#[derive(Default, Debug)]
pub struct AvgAccumulator {
    /// Running sum of the non-null inputs.
    sum: Option<f64>,
    /// Number of non-null inputs folded into `sum`.
    count: i64,
}
146
147impl Accumulator for AvgAccumulator {
148 fn state(&mut self) -> Result<Vec<ScalarValue>> {
149 Ok(vec![
150 ScalarValue::Float64(self.sum),
151 ScalarValue::from(self.count),
152 ])
153 }
154
155 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
156 let values = values[0].as_primitive::<Float64Type>();
157 self.count += (values.len() - values.null_count()) as i64;
158 let v = self.sum.get_or_insert(0.);
159 if let Some(x) = sum(values) {
160 *v += x;
161 }
162 Ok(())
163 }
164
165 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
166 self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
168
169 if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
171 let v = self.sum.get_or_insert(0.);
172 *v += x;
173 }
174 Ok(())
175 }
176
177 fn evaluate(&mut self) -> Result<ScalarValue> {
178 if self.count == 0 {
179 Ok(ScalarValue::Float64(None))
182 } else {
183 Ok(ScalarValue::Float64(
184 self.sum.map(|f| f / self.count as f64),
185 ))
186 }
187 }
188
189 fn size(&self) -> usize {
190 std::mem::size_of_val(self)
191 }
192}
193
194#[derive(Debug)]
200struct AvgGroupsAccumulator<T, F>
201where
202 T: ArrowNumericType + Send,
203 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
204{
205 return_data_type: DataType,
207
208 counts: Vec<i64>,
210
211 sums: Vec<T::Native>,
213
214 avg_fn: F,
216}
217
218impl<T, F> AvgGroupsAccumulator<T, F>
219where
220 T: ArrowNumericType + Send,
221 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
222{
223 pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
224 Self {
225 return_data_type: return_data_type.clone(),
226 counts: vec![],
227 sums: vec![],
228 avg_fn,
229 }
230 }
231}
232
233impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
234where
235 T: ArrowNumericType + Send,
236 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
237{
238 fn update_batch(
239 &mut self,
240 values: &[ArrayRef],
241 group_indices: &[usize],
242 _opt_filter: Option<&arrow::array::BooleanArray>,
243 total_num_groups: usize,
244 ) -> Result<()> {
245 assert_eq!(values.len(), 1, "single argument to update_batch");
246 let values = values[0].as_primitive::<T>();
247 let data = values.values();
248
249 self.counts.resize(total_num_groups, 0);
251 self.sums.resize(total_num_groups, T::default_value());
252
253 let iter = group_indices.iter().zip(data.iter());
254 if values.null_count() == 0 {
255 for (&group_index, &value) in iter {
256 let sum = &mut self.sums[group_index];
257 *sum = (*sum).add_wrapping(value);
258 self.counts[group_index] += 1;
259 }
260 } else {
261 for (idx, (&group_index, &value)) in iter.enumerate() {
262 if values.is_null(idx) {
263 continue;
264 }
265 let sum = &mut self.sums[group_index];
266 *sum = (*sum).add_wrapping(value);
267
268 self.counts[group_index] += 1;
269 }
270 }
271
272 Ok(())
273 }
274
275 fn merge_batch(
276 &mut self,
277 values: &[ArrayRef],
278 group_indices: &[usize],
279 _opt_filter: Option<&arrow::array::BooleanArray>,
280 total_num_groups: usize,
281 ) -> Result<()> {
282 assert_eq!(values.len(), 2, "two arguments to merge_batch");
283 let partial_sums = values[0].as_primitive::<T>();
285 let partial_counts = values[1].as_primitive::<Int64Type>();
286 self.counts.resize(total_num_groups, 0);
288 let iter1 = group_indices.iter().zip(partial_counts.values().iter());
289 for (&group_index, &partial_count) in iter1 {
290 self.counts[group_index] += partial_count;
291 }
292
293 self.sums.resize(total_num_groups, T::default_value());
295 let iter2 = group_indices.iter().zip(partial_sums.values().iter());
296 for (&group_index, &new_value) in iter2 {
297 let sum = &mut self.sums[group_index];
298 *sum = sum.add_wrapping(new_value);
299 }
300
301 Ok(())
302 }
303
304 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
305 let counts = emit_to.take_needed(&mut self.counts);
306 let sums = emit_to.take_needed(&mut self.sums);
307 let mut builder = PrimitiveBuilder::<T>::with_capacity(sums.len());
308 let iter = sums.into_iter().zip(counts);
309
310 for (sum, count) in iter {
311 if count != 0 {
312 builder.append_value((self.avg_fn)(sum, count)?)
313 } else {
314 builder.append_null();
315 }
316 }
317 let array: PrimitiveArray<T> = builder.finish();
318
319 Ok(Arc::new(array))
320 }
321
322 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
324 let counts = emit_to.take_needed(&mut self.counts);
325 let counts = Int64Array::new(counts.into(), None);
326
327 let sums = emit_to.take_needed(&mut self.sums);
328 let sums = PrimitiveArray::<T>::new(sums.into(), None)
329 .with_data_type(self.return_data_type.clone());
330
331 Ok(vec![
332 Arc::new(sums) as ArrayRef,
333 Arc::new(counts) as ArrayRef,
334 ])
335 }
336
337 fn size(&self) -> usize {
338 self.counts.capacity() * std::mem::size_of::<i64>()
339 + self.sums.capacity() * std::mem::size_of::<T>()
340 }
341}