datafusion_functions_aggregate_common/aggregate/avg_distinct/
decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use arrow::{
    array::{ArrayRef, ArrowNumericType},
    datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType},
};
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr_common::accumulator::Accumulator;
use std::fmt::Debug;
use std::mem::size_of_val;

use crate::aggregate::sum_distinct::DistinctSumAccumulator;
use crate::utils::DecimalAverager;
29
30/// Generic implementation of `AVG DISTINCT` for Decimal types.
31/// Handles both Decimal128Type and Decimal256Type.
32#[derive(Debug)]
33pub struct DecimalDistinctAvgAccumulator<T: DecimalType + Debug> {
34    sum_accumulator: DistinctSumAccumulator<T>,
35    sum_scale: i8,
36    target_precision: u8,
37    target_scale: i8,
38}
39
40impl<T: DecimalType + Debug> DecimalDistinctAvgAccumulator<T> {
41    pub fn with_decimal_params(
42        sum_scale: i8,
43        target_precision: u8,
44        target_scale: i8,
45    ) -> Self {
46        let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale);
47
48        Self {
49            sum_accumulator: DistinctSumAccumulator::new(&data_type),
50            sum_scale,
51            target_precision,
52            target_scale,
53        }
54    }
55}
56
57impl<T: DecimalType + ArrowNumericType + Debug> Accumulator
58    for DecimalDistinctAvgAccumulator<T>
59{
60    fn state(&mut self) -> Result<Vec<ScalarValue>> {
61        self.sum_accumulator.state()
62    }
63
64    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
65        self.sum_accumulator.update_batch(values)
66    }
67
68    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
69        self.sum_accumulator.merge_batch(states)
70    }
71
72    fn evaluate(&mut self) -> Result<ScalarValue> {
73        if self.sum_accumulator.distinct_count() == 0 {
74            return ScalarValue::new_primitive::<T>(
75                None,
76                &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
77            );
78        }
79
80        let sum_scalar = self.sum_accumulator.evaluate()?;
81
82        match sum_scalar {
83            ScalarValue::Decimal128(Some(sum), _, _) => {
84                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
85                    self.sum_scale,
86                    self.target_precision,
87                    self.target_scale,
88                )?;
89                let avg = decimal_averager
90                    .avg(sum, self.sum_accumulator.distinct_count() as i128)?;
91                Ok(ScalarValue::Decimal128(
92                    Some(avg),
93                    self.target_precision,
94                    self.target_scale,
95                ))
96            }
97            ScalarValue::Decimal256(Some(sum), _, _) => {
98                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
99                    self.sum_scale,
100                    self.target_precision,
101                    self.target_scale,
102                )?;
103                // `distinct_count` returns `u64`, but `avg` expects `i256`
104                // first convert `u64` to `i128`, then convert `i128` to `i256` to avoid overflow
105                let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128;
106                let count: i256 = i256::from_i128(distinct_cnt);
107                let avg = decimal_averager.avg(sum, count)?;
108                Ok(ScalarValue::Decimal256(
109                    Some(avg),
110                    self.target_precision,
111                    self.target_scale,
112                ))
113            }
114
115            _ => unreachable!("Unsupported decimal type: {:?}", sum_scalar),
116        }
117    }
118
119    fn size(&self) -> usize {
120        let fixed_size = size_of_val(self);
121
122        // Account for the size of the sum_accumulator with its contained values
123        fixed_size + self.sum_accumulator.size()
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Decimal128Array, Decimal256Array};
    use std::sync::Arc;

    /// Distinct inputs 100, 125, 175, 200 and 300 (scale 4) average to 180;
    /// the duplicate 200 and the nulls must not contribute.
    #[test]
    fn test_decimal128_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (10_u8, 4_i8);
        let raw: Vec<Option<i128>> = vec![
            Some(100_0000),
            Some(125_0000),
            Some(175_0000),
            Some(200_0000),
            Some(200_0000),
            Some(300_0000),
            None,
            None,
        ];
        let input = Decimal128Array::from(raw)
            .with_precision_and_scale(input_precision, input_scale)?;

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
                input_scale,
                14,
                8,
            );
        acc.update_batch(&[Arc::new(input)])?;

        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal128(Some(180_00000000), 14, 8)
        );

        Ok(())
    }

    /// Same scenario as the 128-bit test, using 256-bit decimals at scale 2
    /// and a result at scale 6.
    #[test]
    fn test_decimal256_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (50_u8, 2_i8);
        let raw: Vec<Option<i256>> = [
            Some(10_000),
            Some(12_500),
            Some(17_500),
            Some(20_000),
            Some(20_000),
            Some(30_000),
            None,
            None,
        ]
        .into_iter()
        .map(|v: Option<i128>| v.map(i256::from_i128))
        .collect();
        let input = Decimal256Array::from(raw)
            .with_precision_and_scale(input_precision, input_scale)?;

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
                input_scale,
                54,
                6,
            );
        acc.update_batch(&[Arc::new(input)])?;

        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6)
        );

        Ok(())
    }
}