datafusion_functions_extra/common/mode/
bytes.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::arrow::array::AsArray;
19use datafusion::{arrow, common, error, logical_expr, scalar};
20use std::{collections, mem};
21
/// Accumulator that tallies how often each non-null string value occurs so the most
/// frequent one (the mode) can be reported.
#[derive(Debug)]
pub struct BytesModeAccumulator {
    /// Occurrence count for every distinct string value seen so far.
    value_counts: collections::HashMap<String, i64>,
    /// Data type supplied at construction; it selects how input arrays are downcast and
    /// which scalar variant the result is wrapped in.
    data_type: arrow::datatypes::DataType,
}
27
28impl BytesModeAccumulator {
29    pub fn new(data_type: &arrow::datatypes::DataType) -> Self {
30        Self {
31            value_counts: collections::HashMap::new(),
32            data_type: data_type.clone(),
33        }
34    }
35
36    fn update_counts<'a, V>(&mut self, array: V)
37    where
38        V: arrow::array::ArrayAccessor<Item = &'a str>,
39    {
40        for value in arrow::array::ArrayIter::new(array).flatten() {
41            let key = value;
42            if let Some(count) = self.value_counts.get_mut(key) {
43                *count += 1;
44            } else {
45                self.value_counts.insert(key.to_string(), 1);
46            }
47        }
48    }
49}
50
51impl logical_expr::Accumulator for BytesModeAccumulator {
52    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> error::Result<()> {
53        if values.is_empty() {
54            return Ok(());
55        }
56
57        match &self.data_type {
58            arrow::datatypes::DataType::Utf8View => {
59                let array = values[0].as_string_view();
60                self.update_counts(array);
61            }
62            _ => {
63                let array = values[0].as_string::<i32>();
64                self.update_counts(array);
65            }
66        };
67
68        Ok(())
69    }
70
71    fn state(&mut self) -> error::Result<Vec<scalar::ScalarValue>> {
72        let values: Vec<scalar::ScalarValue> = self
73            .value_counts
74            .keys()
75            .map(|key| scalar::ScalarValue::Utf8(Some(key.to_string())))
76            .collect();
77
78        let frequencies: Vec<scalar::ScalarValue> = self
79            .value_counts
80            .values()
81            .map(|&count| scalar::ScalarValue::Int64(Some(count)))
82            .collect();
83
84        let values_scalar =
85            scalar::ScalarValue::new_list_nullable(&values, &arrow::datatypes::DataType::Utf8);
86        let frequencies_scalar = scalar::ScalarValue::new_list_nullable(
87            &frequencies,
88            &arrow::datatypes::DataType::Int64,
89        );
90
91        Ok(vec![
92            scalar::ScalarValue::List(values_scalar),
93            scalar::ScalarValue::List(frequencies_scalar),
94        ])
95    }
96
97    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> error::Result<()> {
98        if states.is_empty() {
99            return Ok(());
100        }
101
102        let values_array = arrow::array::as_string_array(&states[0]);
103        let counts_array =
104            common::cast::as_primitive_array::<arrow::datatypes::Int64Type>(&states[1])?;
105
106        for (i, value_option) in values_array.iter().enumerate() {
107            if let Some(value) = value_option {
108                let count = counts_array.value(i);
109                let entry = self.value_counts.entry(value.to_string()).or_insert(0);
110                *entry += count;
111            }
112        }
113
114        Ok(())
115    }
116
117    fn evaluate(&mut self) -> error::Result<scalar::ScalarValue> {
118        if self.value_counts.is_empty() {
119            return match &self.data_type {
120                arrow::datatypes::DataType::Utf8View => Ok(scalar::ScalarValue::Utf8View(None)),
121                _ => Ok(scalar::ScalarValue::Utf8(None)),
122            };
123        }
124
125        let mode = self
126            .value_counts
127            .iter()
128            .max_by(|a, b| {
129                // First compare counts
130                a.1.cmp(b.1)
131                    // If counts are equal, compare keys in reverse order to get the maximum string
132                    .then_with(|| b.0.cmp(a.0))
133            })
134            .map(|(value, _)| value.to_string());
135
136        match mode {
137            Some(result) => match &self.data_type {
138                arrow::datatypes::DataType::Utf8View => {
139                    Ok(scalar::ScalarValue::Utf8View(Some(result)))
140                }
141                _ => Ok(scalar::ScalarValue::Utf8(Some(result))),
142            },
143            None => match &self.data_type {
144                arrow::datatypes::DataType::Utf8View => Ok(scalar::ScalarValue::Utf8View(None)),
145                _ => Ok(scalar::ScalarValue::Utf8(None)),
146            },
147        }
148    }
149
150    fn size(&self) -> usize {
151        self.value_counts.capacity() * mem::size_of::<(String, i64)>()
152            + mem::size_of_val(&self.data_type)
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::logical_expr::Accumulator;
    use std::sync;

    /// Wraps `items` in a `Utf8` string array.
    fn utf8_batch(items: Vec<Option<&str>>) -> arrow::array::ArrayRef {
        sync::Arc::new(arrow::array::StringArray::from(items))
    }

    /// Wraps `items` in a `Utf8View` string-view array.
    fn utf8view_batch(items: Vec<Option<&str>>) -> arrow::array::ArrayRef {
        sync::Arc::new(arrow::array::StringViewArray::from(items))
    }

    /// Feeds `batch` to a fresh accumulator of `data_type` and returns what it evaluates to.
    fn mode_of(
        data_type: &arrow::datatypes::DataType,
        batch: arrow::array::ArrayRef,
    ) -> error::Result<scalar::ScalarValue> {
        let mut accumulator = BytesModeAccumulator::new(data_type);
        accumulator.update_batch(&[batch])?;
        accumulator.evaluate()
    }

    #[test]
    fn test_mode_accumulator_single_mode_utf8() -> error::Result<()> {
        let batch = utf8_batch(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
            Some("apple"),
        ]);

        let result = mode_of(&arrow::datatypes::DataType::Utf8, batch)?;

        assert_eq!(result, scalar::ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_tie_utf8() -> error::Result<()> {
        // "apple" and "banana" both occur twice; the smaller string is expected.
        let batch = utf8_batch(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
        ]);

        let result = mode_of(&arrow::datatypes::DataType::Utf8, batch)?;

        assert_eq!(result, scalar::ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_all_nulls_utf8() -> error::Result<()> {
        let batch = utf8_batch(vec![None, None, None]);

        let result = mode_of(&arrow::datatypes::DataType::Utf8, batch)?;

        assert_eq!(result, scalar::ScalarValue::Utf8(None));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_with_nulls_utf8() -> error::Result<()> {
        // Nulls outnumber every string but must not be counted.
        let batch = utf8_batch(vec![
            Some("apple"),
            None,
            Some("banana"),
            Some("apple"),
            None,
            None,
            None,
            Some("banana"),
        ]);

        let result = mode_of(&arrow::datatypes::DataType::Utf8, batch)?;

        assert_eq!(result, scalar::ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_single_mode_utf8view() -> error::Result<()> {
        let batch = utf8view_batch(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
            Some("apple"),
        ]);

        let result = mode_of(&arrow::datatypes::DataType::Utf8View, batch)?;

        assert_eq!(
            result,
            scalar::ScalarValue::Utf8View(Some("apple".to_string()))
        );
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_tie_utf8view() -> error::Result<()> {
        // "apple" and "banana" both occur twice; the smaller string is expected.
        let batch = utf8view_batch(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
        ]);

        let result = mode_of(&arrow::datatypes::DataType::Utf8View, batch)?;

        assert_eq!(
            result,
            scalar::ScalarValue::Utf8View(Some("apple".to_string()))
        );
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_all_nulls_utf8view() -> error::Result<()> {
        let batch = utf8view_batch(vec![None, None, None]);

        let result = mode_of(&arrow::datatypes::DataType::Utf8View, batch)?;

        assert_eq!(result, scalar::ScalarValue::Utf8View(None));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_with_nulls_utf8view() -> error::Result<()> {
        // Nulls outnumber every string but must not be counted.
        let batch = utf8view_batch(vec![
            Some("apple"),
            None,
            Some("banana"),
            Some("apple"),
            None,
            None,
            None,
            Some("banana"),
        ]);

        let result = mode_of(&arrow::datatypes::DataType::Utf8View, batch)?;

        assert_eq!(
            result,
            scalar::ScalarValue::Utf8View(Some("apple".to_string()))
        );
        Ok(())
    }
}