datafusion_functions_aggregate_common/aggregate/avg_distinct/
decimal.rs1use arrow::{
19 array::{ArrayRef, ArrowNumericType},
20 datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType},
21};
22use datafusion_common::{Result, ScalarValue};
23use datafusion_expr_common::accumulator::Accumulator;
24use std::fmt::Debug;
25use std::mem::size_of_val;
26
27use crate::aggregate::sum_distinct::DistinctSumAccumulator;
28use crate::utils::DecimalAverager;
29
30#[derive(Debug)]
33pub struct DecimalDistinctAvgAccumulator<T: DecimalType + Debug> {
34 sum_accumulator: DistinctSumAccumulator<T>,
35 sum_scale: i8,
36 target_precision: u8,
37 target_scale: i8,
38}
39
40impl<T: DecimalType + Debug> DecimalDistinctAvgAccumulator<T> {
41 pub fn with_decimal_params(
42 sum_scale: i8,
43 target_precision: u8,
44 target_scale: i8,
45 ) -> Self {
46 let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale);
47
48 Self {
49 sum_accumulator: DistinctSumAccumulator::new(&data_type),
50 sum_scale,
51 target_precision,
52 target_scale,
53 }
54 }
55}
56
57impl<T: DecimalType + ArrowNumericType + Debug> Accumulator
58 for DecimalDistinctAvgAccumulator<T>
59{
60 fn state(&mut self) -> Result<Vec<ScalarValue>> {
61 self.sum_accumulator.state()
62 }
63
64 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
65 self.sum_accumulator.update_batch(values)
66 }
67
68 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
69 self.sum_accumulator.merge_batch(states)
70 }
71
72 fn evaluate(&mut self) -> Result<ScalarValue> {
73 if self.sum_accumulator.distinct_count() == 0 {
74 return ScalarValue::new_primitive::<T>(
75 None,
76 &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
77 );
78 }
79
80 let sum_scalar = self.sum_accumulator.evaluate()?;
81
82 match sum_scalar {
83 ScalarValue::Decimal128(Some(sum), _, _) => {
84 let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
85 self.sum_scale,
86 self.target_precision,
87 self.target_scale,
88 )?;
89 let avg = decimal_averager
90 .avg(sum, self.sum_accumulator.distinct_count() as i128)?;
91 Ok(ScalarValue::Decimal128(
92 Some(avg),
93 self.target_precision,
94 self.target_scale,
95 ))
96 }
97 ScalarValue::Decimal256(Some(sum), _, _) => {
98 let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
99 self.sum_scale,
100 self.target_precision,
101 self.target_scale,
102 )?;
103 let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128;
106 let count: i256 = i256::from_i128(distinct_cnt);
107 let avg = decimal_averager.avg(sum, count)?;
108 Ok(ScalarValue::Decimal256(
109 Some(avg),
110 self.target_precision,
111 self.target_scale,
112 ))
113 }
114
115 _ => unreachable!("Unsupported decimal type: {:?}", sum_scalar),
116 }
117 }
118
119 fn size(&self) -> usize {
120 let fixed_size = size_of_val(self);
121
122 fixed_size + self.sum_accumulator.size()
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Decimal128Array, Decimal256Array};
    use std::sync::Arc;

    /// Input at scale 4 holds 100, 125, 175, 200 (twice), 300 and two nulls;
    /// the expected result is 180 at scale 8, i.e. 900 / 5.
    #[test]
    fn test_decimal128_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (10_u8, 4_i8);
        let raw: Vec<Option<i128>> = vec![
            Some(100_0000),
            Some(125_0000),
            Some(175_0000),
            Some(200_0000),
            Some(200_0000),
            Some(300_0000),
            None,
            None,
        ];
        let input: ArrayRef = Arc::new(
            Decimal128Array::from(raw)
                .with_precision_and_scale(input_precision, input_scale)?,
        );

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
                input_scale,
                14,
                8,
            );
        acc.update_batch(&[input])?;

        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal128(Some(180_00000000), 14, 8)
        );

        Ok(())
    }

    /// Same data set as the 128-bit test, stored at scale 2 in a 256-bit
    /// decimal array; the expected result is 180 at scale 6.
    #[test]
    fn test_decimal256_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (50_u8, 2_i8);
        let raw: [Option<i128>; 8] = [
            Some(10_000),
            Some(12_500),
            Some(17_500),
            Some(20_000),
            Some(20_000),
            Some(30_000),
            None,
            None,
        ];
        let widened: Vec<Option<i256>> = raw
            .iter()
            .copied()
            .map(|value| value.map(i256::from_i128))
            .collect();
        let input: ArrayRef = Arc::new(
            Decimal256Array::from(widened)
                .with_precision_and_scale(input_precision, input_scale)?,
        );

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
                input_scale,
                54,
                6,
            );
        acc.update_batch(&[input])?;

        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6)
        );

        Ok(())
    }
}