datafusion_functions_extra/common/mode/
bytes.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::ArrayAccessor;
19use arrow::array::ArrayIter;
20use arrow::array::ArrayRef;
21use arrow::array::AsArray;
22use arrow::datatypes::DataType;
23use datafusion::arrow;
24use datafusion::common::cast::as_primitive_array;
25use datafusion::common::cast::as_string_array;
26use datafusion::error::Result;
27use datafusion::logical_expr::Accumulator;
28use datafusion::scalar::ScalarValue;
29use std::collections::HashMap;
30
31#[derive(Debug)]
32pub struct BytesModeAccumulator {
33    value_counts: HashMap<String, i64>,
34    data_type: DataType,
35}
36
37impl BytesModeAccumulator {
38    pub fn new(data_type: &DataType) -> Self {
39        Self {
40            value_counts: HashMap::new(),
41            data_type: data_type.clone(),
42        }
43    }
44
45    fn update_counts<'a, V>(&mut self, array: V)
46    where
47        V: ArrayAccessor<Item = &'a str>,
48    {
49        for value in ArrayIter::new(array).flatten() {
50            let key = value;
51            if let Some(count) = self.value_counts.get_mut(key) {
52                *count += 1;
53            } else {
54                self.value_counts.insert(key.to_string(), 1);
55            }
56        }
57    }
58}
59
60impl Accumulator for BytesModeAccumulator {
61    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
62        if values.is_empty() {
63            return Ok(());
64        }
65
66        match &self.data_type {
67            DataType::Utf8View => {
68                let array = values[0].as_string_view();
69                self.update_counts(array);
70            }
71            _ => {
72                let array = values[0].as_string::<i32>();
73                self.update_counts(array);
74            }
75        };
76
77        Ok(())
78    }
79
80    fn state(&mut self) -> Result<Vec<ScalarValue>> {
81        let values: Vec<ScalarValue> = self
82            .value_counts
83            .keys()
84            .map(|key| ScalarValue::Utf8(Some(key.to_string())))
85            .collect();
86
87        let frequencies: Vec<ScalarValue> = self
88            .value_counts
89            .values()
90            .map(|&count| ScalarValue::Int64(Some(count)))
91            .collect();
92
93        let values_scalar = ScalarValue::new_list_nullable(&values, &DataType::Utf8);
94        let frequencies_scalar = ScalarValue::new_list_nullable(&frequencies, &DataType::Int64);
95
96        Ok(vec![
97            ScalarValue::List(values_scalar),
98            ScalarValue::List(frequencies_scalar),
99        ])
100    }
101
102    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
103        if states.is_empty() {
104            return Ok(());
105        }
106
107        let values_array = as_string_array(&states[0])?;
108        let counts_array = as_primitive_array::<arrow::datatypes::Int64Type>(&states[1])?;
109
110        for (i, value_option) in values_array.iter().enumerate() {
111            if let Some(value) = value_option {
112                let count = counts_array.value(i);
113                let entry = self.value_counts.entry(value.to_string()).or_insert(0);
114                *entry += count;
115            }
116        }
117
118        Ok(())
119    }
120
121    fn evaluate(&mut self) -> Result<ScalarValue> {
122        if self.value_counts.is_empty() {
123            return match &self.data_type {
124                DataType::Utf8View => Ok(ScalarValue::Utf8View(None)),
125                _ => Ok(ScalarValue::Utf8(None)),
126            };
127        }
128
129        let mode = self
130            .value_counts
131            .iter()
132            .max_by(|a, b| {
133                // First compare counts
134                a.1.cmp(b.1)
135                    // If counts are equal, compare keys in reverse order to get the maximum string
136                    .then_with(|| b.0.cmp(a.0))
137            })
138            .map(|(value, _)| value.to_string());
139
140        match mode {
141            Some(result) => match &self.data_type {
142                DataType::Utf8View => Ok(ScalarValue::Utf8View(Some(result))),
143                _ => Ok(ScalarValue::Utf8(Some(result))),
144            },
145            None => match &self.data_type {
146                DataType::Utf8View => Ok(ScalarValue::Utf8View(None)),
147                _ => Ok(ScalarValue::Utf8(None)),
148            },
149        }
150    }
151
152    fn size(&self) -> usize {
153        self.value_counts.capacity() * std::mem::size_of::<(String, i64)>() + std::mem::size_of_val(&self.data_type)
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, GenericByteViewArray, StringArray};
    use std::sync::Arc;

    /// Builds a `Utf8` column from optional string slices.
    fn utf8_column(items: Vec<Option<&str>>) -> ArrayRef {
        let column: ArrayRef = Arc::new(StringArray::from(items));
        column
    }

    /// Builds a `Utf8View` column from optional string slices.
    fn utf8view_column(items: Vec<Option<&str>>) -> ArrayRef {
        let column: ArrayRef = Arc::new(GenericByteViewArray::from(items));
        column
    }

    /// Feeds a single batch to a fresh accumulator and returns its result.
    fn mode_of(data_type: DataType, column: ArrayRef) -> Result<ScalarValue> {
        let mut accumulator = BytesModeAccumulator::new(&data_type);
        accumulator.update_batch(&[column])?;
        accumulator.evaluate()
    }

    /// Input with one clear winner ("apple" appears three times).
    fn single_mode_input() -> Vec<Option<&'static str>> {
        vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
            Some("apple"),
        ]
    }

    /// Input where "apple" and "banana" both appear twice.
    fn tie_input() -> Vec<Option<&'static str>> {
        vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
        ]
    }

    /// Input that contains only nulls.
    fn all_nulls_input() -> Vec<Option<&'static str>> {
        vec![None, None, None]
    }

    /// Input where nulls outnumber every real value.
    fn with_nulls_input() -> Vec<Option<&'static str>> {
        vec![
            Some("apple"),
            None,
            Some("banana"),
            Some("apple"),
            None,
            None,
            None,
            Some("banana"),
        ]
    }

    #[test]
    fn test_mode_accumulator_single_mode_utf8() -> Result<()> {
        let result = mode_of(DataType::Utf8, utf8_column(single_mode_input()))?;

        assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_tie_utf8() -> Result<()> {
        let result = mode_of(DataType::Utf8, utf8_column(tie_input()))?;

        assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_all_nulls_utf8() -> Result<()> {
        let result = mode_of(DataType::Utf8, utf8_column(all_nulls_input()))?;

        assert_eq!(result, ScalarValue::Utf8(None));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_with_nulls_utf8() -> Result<()> {
        let result = mode_of(DataType::Utf8, utf8_column(with_nulls_input()))?;

        assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_single_mode_utf8view() -> Result<()> {
        let result = mode_of(DataType::Utf8View, utf8view_column(single_mode_input()))?;

        assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_tie_utf8view() -> Result<()> {
        let result = mode_of(DataType::Utf8View, utf8view_column(tie_input()))?;

        assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_all_nulls_utf8view() -> Result<()> {
        let result = mode_of(DataType::Utf8View, utf8view_column(all_nulls_input()))?;

        assert_eq!(result, ScalarValue::Utf8View(None));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_with_nulls_utf8view() -> Result<()> {
        let result = mode_of(DataType::Utf8View, utf8view_column(with_nulls_input()))?;

        assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string())));
        Ok(())
    }
}