datafusion_functions_aggregate_common/aggregate/avg_distinct/
decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::{
19    array::{ArrayRef, ArrowNumericType},
20    datatypes::{
21        i256, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType,
22    },
23};
24use datafusion_common::{Result, ScalarValue};
25use datafusion_expr_common::accumulator::Accumulator;
26use std::fmt::Debug;
27use std::mem::size_of_val;
28
29use crate::aggregate::sum_distinct::DistinctSumAccumulator;
30use crate::utils::DecimalAverager;
31
32/// Generic implementation of `AVG DISTINCT` for Decimal types.
33/// Handles both all Arrow decimal types (32, 64, 128 and 256 bits).
34#[derive(Debug)]
35pub struct DecimalDistinctAvgAccumulator<T: DecimalType + Debug> {
36    sum_accumulator: DistinctSumAccumulator<T>,
37    sum_scale: i8,
38    target_precision: u8,
39    target_scale: i8,
40}
41
42impl<T: DecimalType + Debug> DecimalDistinctAvgAccumulator<T> {
43    pub fn with_decimal_params(
44        sum_scale: i8,
45        target_precision: u8,
46        target_scale: i8,
47    ) -> Self {
48        let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale);
49
50        Self {
51            sum_accumulator: DistinctSumAccumulator::new(&data_type),
52            sum_scale,
53            target_precision,
54            target_scale,
55        }
56    }
57}
58
59impl<T: DecimalType + ArrowNumericType + Debug> Accumulator
60    for DecimalDistinctAvgAccumulator<T>
61{
62    fn state(&mut self) -> Result<Vec<ScalarValue>> {
63        self.sum_accumulator.state()
64    }
65
66    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
67        self.sum_accumulator.update_batch(values)
68    }
69
70    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
71        self.sum_accumulator.merge_batch(states)
72    }
73
74    fn evaluate(&mut self) -> Result<ScalarValue> {
75        if self.sum_accumulator.distinct_count() == 0 {
76            return ScalarValue::new_primitive::<T>(
77                None,
78                &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
79            );
80        }
81
82        let sum_scalar = self.sum_accumulator.evaluate()?;
83
84        match sum_scalar {
85            ScalarValue::Decimal32(Some(sum), _, _) => {
86                let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
87                    self.sum_scale,
88                    self.target_precision,
89                    self.target_scale,
90                )?;
91                let avg = decimal_averager
92                    .avg(sum, self.sum_accumulator.distinct_count() as i32)?;
93                Ok(ScalarValue::Decimal32(
94                    Some(avg),
95                    self.target_precision,
96                    self.target_scale,
97                ))
98            }
99            ScalarValue::Decimal64(Some(sum), _, _) => {
100                let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
101                    self.sum_scale,
102                    self.target_precision,
103                    self.target_scale,
104                )?;
105                let avg = decimal_averager
106                    .avg(sum, self.sum_accumulator.distinct_count() as i64)?;
107                Ok(ScalarValue::Decimal64(
108                    Some(avg),
109                    self.target_precision,
110                    self.target_scale,
111                ))
112            }
113            ScalarValue::Decimal128(Some(sum), _, _) => {
114                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
115                    self.sum_scale,
116                    self.target_precision,
117                    self.target_scale,
118                )?;
119                let avg = decimal_averager
120                    .avg(sum, self.sum_accumulator.distinct_count() as i128)?;
121                Ok(ScalarValue::Decimal128(
122                    Some(avg),
123                    self.target_precision,
124                    self.target_scale,
125                ))
126            }
127            ScalarValue::Decimal256(Some(sum), _, _) => {
128                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
129                    self.sum_scale,
130                    self.target_precision,
131                    self.target_scale,
132                )?;
133                // `distinct_count` returns `u64`, but `avg` expects `i256`
134                // first convert `u64` to `i128`, then convert `i128` to `i256` to avoid overflow
135                let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128;
136                let count: i256 = i256::from_i128(distinct_cnt);
137                let avg = decimal_averager.avg(sum, count)?;
138                Ok(ScalarValue::Decimal256(
139                    Some(avg),
140                    self.target_precision,
141                    self.target_scale,
142                ))
143            }
144
145            _ => unreachable!("Unsupported decimal type: {:?}", sum_scalar),
146        }
147    }
148
149    fn size(&self) -> usize {
150        let fixed_size = size_of_val(self);
151
152        // Account for the size of the sum_accumulator with its contained values
153        fixed_size + self.sum_accumulator.size()
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array,
    };
    use std::sync::Arc;

    // Each test feeds five distinct values, one duplicate and two NULLs, so
    // the divisor is always 5 and the duplicate / NULLs must be ignored.

    /// Distinct values 10.00, 12.50, 17.50, 20.00, 30.00 average to 18.
    #[test]
    fn test_decimal32_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (5_u8, 2_i8);
        let raw: Vec<Option<i32>> = vec![
            Some(10_00),
            Some(12_50),
            Some(17_50),
            Some(20_00),
            Some(20_00),
            Some(30_00),
            None,
            None,
        ];
        let input = Decimal32Array::from(raw)
            .with_precision_and_scale(input_precision, input_scale)?;

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
                input_scale,
                9,
                6,
            );
        acc.update_batch(&[Arc::new(input)])?;

        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal32(Some(18000000), 9, 6)
        );
        Ok(())
    }

    /// Distinct values 100, 125, 175, 200, 300 (scale 4) average to 180.
    #[test]
    fn test_decimal64_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (10_u8, 4_i8);
        let raw: Vec<Option<i64>> = vec![
            Some(100_0000),
            Some(125_0000),
            Some(175_0000),
            Some(200_0000),
            Some(200_0000),
            Some(300_0000),
            None,
            None,
        ];
        let input = Decimal64Array::from(raw)
            .with_precision_and_scale(input_precision, input_scale)?;

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
                input_scale,
                14,
                8,
            );
        acc.update_batch(&[Arc::new(input)])?;

        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal64(Some(180_00000000), 14, 8)
        );
        Ok(())
    }

    /// Distinct values 100, 125, 175, 200, 300 (scale 4) average to 180.
    #[test]
    fn test_decimal128_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (10_u8, 4_i8);
        let raw: Vec<Option<i128>> = vec![
            Some(100_0000),
            Some(125_0000),
            Some(175_0000),
            Some(200_0000),
            Some(200_0000),
            Some(300_0000),
            None,
            None,
        ];
        let input = Decimal128Array::from(raw)
            .with_precision_and_scale(input_precision, input_scale)?;

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
                input_scale,
                14,
                8,
            );
        acc.update_batch(&[Arc::new(input)])?;

        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal128(Some(180_00000000), 14, 8)
        );
        Ok(())
    }

    /// Distinct values 100, 125, 175, 200, 300 (scale 2) average to 180.
    #[test]
    fn test_decimal256_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (50_u8, 2_i8);
        let raw: Vec<Option<i256>> = vec![
            Some(i256::from_i128(10_000)),
            Some(i256::from_i128(12_500)),
            Some(i256::from_i128(17_500)),
            Some(i256::from_i128(20_000)),
            Some(i256::from_i128(20_000)),
            Some(i256::from_i128(30_000)),
            None,
            None,
        ];
        let input = Decimal256Array::from(raw)
            .with_precision_and_scale(input_precision, input_scale)?;

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
                input_scale,
                54,
                6,
            );
        acc.update_batch(&[Arc::new(input)])?;

        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6)
        );
        Ok(())
    }
}