datafusion_functions_extra/common/mode/
bytes.rs1use datafusion::arrow::array::AsArray;
19use datafusion::{arrow, common, error, logical_expr, scalar};
20use std::{collections, mem};
21
22#[derive(Debug)]
23pub struct BytesModeAccumulator {
24 value_counts: collections::HashMap<String, i64>,
25 data_type: arrow::datatypes::DataType,
26}
27
28impl BytesModeAccumulator {
29 pub fn new(data_type: &arrow::datatypes::DataType) -> Self {
30 Self {
31 value_counts: collections::HashMap::new(),
32 data_type: data_type.clone(),
33 }
34 }
35
36 fn update_counts<'a, V>(&mut self, array: V)
37 where
38 V: arrow::array::ArrayAccessor<Item = &'a str>,
39 {
40 for value in arrow::array::ArrayIter::new(array).flatten() {
41 let key = value;
42 if let Some(count) = self.value_counts.get_mut(key) {
43 *count += 1;
44 } else {
45 self.value_counts.insert(key.to_string(), 1);
46 }
47 }
48 }
49}
50
51impl logical_expr::Accumulator for BytesModeAccumulator {
52 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> error::Result<()> {
53 if values.is_empty() {
54 return Ok(());
55 }
56
57 match &self.data_type {
58 arrow::datatypes::DataType::Utf8View => {
59 let array = values[0].as_string_view();
60 self.update_counts(array);
61 }
62 _ => {
63 let array = values[0].as_string::<i32>();
64 self.update_counts(array);
65 }
66 };
67
68 Ok(())
69 }
70
71 fn state(&mut self) -> error::Result<Vec<scalar::ScalarValue>> {
72 let values: Vec<scalar::ScalarValue> = self
73 .value_counts
74 .keys()
75 .map(|key| scalar::ScalarValue::Utf8(Some(key.to_string())))
76 .collect();
77
78 let frequencies: Vec<scalar::ScalarValue> = self
79 .value_counts
80 .values()
81 .map(|&count| scalar::ScalarValue::Int64(Some(count)))
82 .collect();
83
84 let values_scalar =
85 scalar::ScalarValue::new_list_nullable(&values, &arrow::datatypes::DataType::Utf8);
86 let frequencies_scalar = scalar::ScalarValue::new_list_nullable(
87 &frequencies,
88 &arrow::datatypes::DataType::Int64,
89 );
90
91 Ok(vec![
92 scalar::ScalarValue::List(values_scalar),
93 scalar::ScalarValue::List(frequencies_scalar),
94 ])
95 }
96
97 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> error::Result<()> {
98 if states.is_empty() {
99 return Ok(());
100 }
101
102 let values_array = arrow::array::as_string_array(&states[0]);
103 let counts_array =
104 common::cast::as_primitive_array::<arrow::datatypes::Int64Type>(&states[1])?;
105
106 for (i, value_option) in values_array.iter().enumerate() {
107 if let Some(value) = value_option {
108 let count = counts_array.value(i);
109 let entry = self.value_counts.entry(value.to_string()).or_insert(0);
110 *entry += count;
111 }
112 }
113
114 Ok(())
115 }
116
117 fn evaluate(&mut self) -> error::Result<scalar::ScalarValue> {
118 if self.value_counts.is_empty() {
119 return match &self.data_type {
120 arrow::datatypes::DataType::Utf8View => Ok(scalar::ScalarValue::Utf8View(None)),
121 _ => Ok(scalar::ScalarValue::Utf8(None)),
122 };
123 }
124
125 let mode = self
126 .value_counts
127 .iter()
128 .max_by(|a, b| {
129 a.1.cmp(b.1)
131 .then_with(|| b.0.cmp(a.0))
133 })
134 .map(|(value, _)| value.to_string());
135
136 match mode {
137 Some(result) => match &self.data_type {
138 arrow::datatypes::DataType::Utf8View => {
139 Ok(scalar::ScalarValue::Utf8View(Some(result)))
140 }
141 _ => Ok(scalar::ScalarValue::Utf8(Some(result))),
142 },
143 None => match &self.data_type {
144 arrow::datatypes::DataType::Utf8View => Ok(scalar::ScalarValue::Utf8View(None)),
145 _ => Ok(scalar::ScalarValue::Utf8(None)),
146 },
147 }
148 }
149
150 fn size(&self) -> usize {
151 self.value_counts.capacity() * mem::size_of::<(String, i64)>()
152 + mem::size_of_val(&self.data_type)
153 }
154}
155
#[cfg(test)]
mod tests {

    use super::*;

    use datafusion::logical_expr::Accumulator;
    use std::sync;

    /// Builds a `Utf8` array from optional string slices.
    fn utf8(items: Vec<Option<&str>>) -> arrow::array::ArrayRef {
        sync::Arc::new(arrow::array::StringArray::from(items))
    }

    /// Builds a `Utf8View` array from optional string slices.
    fn utf8_view(items: Vec<Option<&str>>) -> arrow::array::ArrayRef {
        sync::Arc::new(arrow::array::StringViewArray::from(items))
    }

    /// Feeds `input` as a single batch into a fresh accumulator of
    /// `data_type` and returns the evaluated mode.
    fn mode_of(
        data_type: arrow::datatypes::DataType,
        input: arrow::array::ArrayRef,
    ) -> error::Result<scalar::ScalarValue> {
        let mut acc = BytesModeAccumulator::new(&data_type);
        acc.update_batch(&[input])?;
        acc.evaluate()
    }

    /// Owned `Some("apple")`, the expected mode in most tests.
    fn apple() -> Option<String> {
        Some("apple".to_string())
    }

    #[test]
    fn test_mode_accumulator_single_mode_utf8() -> error::Result<()> {
        let input = utf8(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
            Some("apple"),
        ]);

        let result = mode_of(arrow::datatypes::DataType::Utf8, input)?;

        assert_eq!(result, scalar::ScalarValue::Utf8(apple()));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_tie_utf8() -> error::Result<()> {
        let input = utf8(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
        ]);

        let result = mode_of(arrow::datatypes::DataType::Utf8, input)?;

        assert_eq!(result, scalar::ScalarValue::Utf8(apple()));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_all_nulls_utf8() -> error::Result<()> {
        let input = utf8(vec![None, None, None]);

        let result = mode_of(arrow::datatypes::DataType::Utf8, input)?;

        assert_eq!(result, scalar::ScalarValue::Utf8(None));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_with_nulls_utf8() -> error::Result<()> {
        let input = utf8(vec![
            Some("apple"),
            None,
            Some("banana"),
            Some("apple"),
            None,
            None,
            None,
            Some("banana"),
        ]);

        let result = mode_of(arrow::datatypes::DataType::Utf8, input)?;

        assert_eq!(result, scalar::ScalarValue::Utf8(apple()));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_single_mode_utf8view() -> error::Result<()> {
        let input = utf8_view(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
            Some("apple"),
        ]);

        let result = mode_of(arrow::datatypes::DataType::Utf8View, input)?;

        assert_eq!(result, scalar::ScalarValue::Utf8View(apple()));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_tie_utf8view() -> error::Result<()> {
        let input = utf8_view(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
        ]);

        let result = mode_of(arrow::datatypes::DataType::Utf8View, input)?;

        assert_eq!(result, scalar::ScalarValue::Utf8View(apple()));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_all_nulls_utf8view() -> error::Result<()> {
        let input = utf8_view(vec![None, None, None]);

        let result = mode_of(arrow::datatypes::DataType::Utf8View, input)?;

        assert_eq!(result, scalar::ScalarValue::Utf8View(None));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_with_nulls_utf8view() -> error::Result<()> {
        let input = utf8_view(vec![
            Some("apple"),
            None,
            Some("banana"),
            Some("apple"),
            None,
            None,
            None,
            Some("banana"),
        ]);

        let result = mode_of(arrow::datatypes::DataType::Utf8View, input)?;

        assert_eq!(result, scalar::ScalarValue::Utf8View(apple()));
        Ok(())
    }
}