// datafusion_functions_aggregate_common/aggregate/avg_distinct/decimal.rs

use arrow::{
    array::{ArrayRef, ArrowNumericType},
    datatypes::{
        i256, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType,
    },
};
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr_common::accumulator::Accumulator;
use std::fmt::Debug;
use std::mem::size_of_val;

use crate::aggregate::sum_distinct::DistinctSumAccumulator;
use crate::utils::DecimalAverager;
31
32#[derive(Debug)]
35pub struct DecimalDistinctAvgAccumulator<T: DecimalType + Debug> {
36 sum_accumulator: DistinctSumAccumulator<T>,
37 sum_scale: i8,
38 target_precision: u8,
39 target_scale: i8,
40}
41
42impl<T: DecimalType + Debug> DecimalDistinctAvgAccumulator<T> {
43 pub fn with_decimal_params(
44 sum_scale: i8,
45 target_precision: u8,
46 target_scale: i8,
47 ) -> Self {
48 let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale);
49
50 Self {
51 sum_accumulator: DistinctSumAccumulator::new(&data_type),
52 sum_scale,
53 target_precision,
54 target_scale,
55 }
56 }
57}
58
59impl<T: DecimalType + ArrowNumericType + Debug> Accumulator
60 for DecimalDistinctAvgAccumulator<T>
61{
62 fn state(&mut self) -> Result<Vec<ScalarValue>> {
63 self.sum_accumulator.state()
64 }
65
66 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
67 self.sum_accumulator.update_batch(values)
68 }
69
70 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
71 self.sum_accumulator.merge_batch(states)
72 }
73
74 fn evaluate(&mut self) -> Result<ScalarValue> {
75 if self.sum_accumulator.distinct_count() == 0 {
76 return ScalarValue::new_primitive::<T>(
77 None,
78 &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
79 );
80 }
81
82 let sum_scalar = self.sum_accumulator.evaluate()?;
83
84 match sum_scalar {
85 ScalarValue::Decimal32(Some(sum), _, _) => {
86 let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
87 self.sum_scale,
88 self.target_precision,
89 self.target_scale,
90 )?;
91 let avg = decimal_averager
92 .avg(sum, self.sum_accumulator.distinct_count() as i32)?;
93 Ok(ScalarValue::Decimal32(
94 Some(avg),
95 self.target_precision,
96 self.target_scale,
97 ))
98 }
99 ScalarValue::Decimal64(Some(sum), _, _) => {
100 let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
101 self.sum_scale,
102 self.target_precision,
103 self.target_scale,
104 )?;
105 let avg = decimal_averager
106 .avg(sum, self.sum_accumulator.distinct_count() as i64)?;
107 Ok(ScalarValue::Decimal64(
108 Some(avg),
109 self.target_precision,
110 self.target_scale,
111 ))
112 }
113 ScalarValue::Decimal128(Some(sum), _, _) => {
114 let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
115 self.sum_scale,
116 self.target_precision,
117 self.target_scale,
118 )?;
119 let avg = decimal_averager
120 .avg(sum, self.sum_accumulator.distinct_count() as i128)?;
121 Ok(ScalarValue::Decimal128(
122 Some(avg),
123 self.target_precision,
124 self.target_scale,
125 ))
126 }
127 ScalarValue::Decimal256(Some(sum), _, _) => {
128 let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
129 self.sum_scale,
130 self.target_precision,
131 self.target_scale,
132 )?;
133 let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128;
136 let count: i256 = i256::from_i128(distinct_cnt);
137 let avg = decimal_averager.avg(sum, count)?;
138 Ok(ScalarValue::Decimal256(
139 Some(avg),
140 self.target_precision,
141 self.target_scale,
142 ))
143 }
144
145 _ => unreachable!("Unsupported decimal type: {:?}", sum_scalar),
146 }
147 }
148
149 fn size(&self) -> usize {
150 let fixed_size = size_of_val(self);
151
152 fixed_size + self.sum_accumulator.size()
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array,
    };
    use std::sync::Arc;

    #[test]
    fn test_decimal32_distinct_avg_accumulator() -> Result<()> {
        let precision = 5_u8;
        let scale = 2_i8;
        let array = Decimal32Array::from(vec![
            Some(10_00),
            Some(12_50),
            Some(17_50),
            Some(20_00),
            Some(20_00),
            Some(30_00),
            None,
            None,
        ])
        .with_precision_and_scale(precision, scale)?;

        let mut accumulator =
            DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
                scale, 9, 6,
            );
        accumulator.update_batch(&[Arc::new(array)])?;

        let result = accumulator.evaluate()?;
        let expected_result = ScalarValue::Decimal32(Some(18000000), 9, 6);
        assert_eq!(result, expected_result);

        Ok(())
    }

    #[test]
    fn test_decimal64_distinct_avg_accumulator() -> Result<()> {
        let precision = 10_u8;
        let scale = 4_i8;
        let array = Decimal64Array::from(vec![
            Some(100_0000),
            Some(125_0000),
            Some(175_0000),
            Some(200_0000),
            Some(200_0000),
            Some(300_0000),
            None,
            None,
        ])
        .with_precision_and_scale(precision, scale)?;

        let mut accumulator =
            DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
                scale, 14, 8,
            );
        accumulator.update_batch(&[Arc::new(array)])?;

        let result = accumulator.evaluate()?;
        let expected_result = ScalarValue::Decimal64(Some(180_00000000), 14, 8);
        assert_eq!(result, expected_result);

        Ok(())
    }

    #[test]
    fn test_decimal128_distinct_avg_accumulator() -> Result<()> {
        let precision = 10_u8;
        let scale = 4_i8;
        let array = Decimal128Array::from(vec![
            Some(100_0000),
            Some(125_0000),
            Some(175_0000),
            Some(200_0000),
            Some(200_0000),
            Some(300_0000),
            None,
            None,
        ])
        .with_precision_and_scale(precision, scale)?;

        let mut accumulator =
            DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
                scale, 14, 8,
            );
        accumulator.update_batch(&[Arc::new(array)])?;

        let result = accumulator.evaluate()?;
        let expected_result = ScalarValue::Decimal128(Some(180_00000000), 14, 8);
        assert_eq!(result, expected_result);

        Ok(())
    }

    #[test]
    fn test_decimal256_distinct_avg_accumulator() -> Result<()> {
        let precision = 50_u8;
        let scale = 2_i8;

        let array = Decimal256Array::from(vec![
            Some(i256::from_i128(10_000)),
            Some(i256::from_i128(12_500)),
            Some(i256::from_i128(17_500)),
            Some(i256::from_i128(20_000)),
            Some(i256::from_i128(20_000)),
            Some(i256::from_i128(30_000)),
            None,
            None,
        ])
        .with_precision_and_scale(precision, scale)?;

        let mut accumulator =
            DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
                scale, 54, 6,
            );
        accumulator.update_batch(&[Arc::new(array)])?;

        let result = accumulator.evaluate()?;
        let expected_result =
            ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6);
        assert_eq!(result, expected_result);

        Ok(())
    }

    /// With no input at all, and with input made only of NULLs, the result
    /// must be a NULL of the target decimal type.
    #[test]
    fn test_decimal128_distinct_avg_of_no_values_is_null() -> Result<()> {
        let mut accumulator =
            DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
                4, 14, 8,
            );
        let expected_result = ScalarValue::Decimal128(None, 14, 8);

        assert_eq!(accumulator.evaluate()?, expected_result);

        let only_nulls = Decimal128Array::from(vec![None::<i128>, None])
            .with_precision_and_scale(10, 4)?;
        accumulator.update_batch(&[Arc::new(only_nulls)])?;
        assert_eq!(accumulator.evaluate()?, expected_result);

        Ok(())
    }

    /// Merging a partial state must deduplicate values that appear in more
    /// than one partition.
    #[test]
    fn test_decimal128_distinct_avg_merge_deduplicates() -> Result<()> {
        let precision = 10_u8;
        let scale = 2_i8;
        let partition_a = Decimal128Array::from(vec![Some(10_00), Some(20_00)])
            .with_precision_and_scale(precision, scale)?;
        let partition_b =
            Decimal128Array::from(vec![Some(20_00), Some(30_00), None])
                .with_precision_and_scale(precision, scale)?;

        let mut accumulator_a =
            DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
                scale, 14, 6,
            );
        accumulator_a.update_batch(&[Arc::new(partition_a)])?;

        let mut accumulator_b =
            DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
                scale, 14, 6,
            );
        accumulator_b.update_batch(&[Arc::new(partition_b)])?;

        let state_a = accumulator_a
            .state()?
            .iter()
            .map(|value| value.to_array())
            .collect::<Result<Vec<_>>>()?;
        accumulator_b.merge_batch(&state_a)?;

        // Distinct values {10.00, 20.00, 30.00}: 60.00 / 3 = 20.000000.
        let expected_result = ScalarValue::Decimal128(Some(20_000000), 14, 6);
        assert_eq!(accumulator_b.evaluate()?, expected_result);

        Ok(())
    }
}