datafusion_functions_extra/common/mode/
bytes.rs1use arrow::array::ArrayAccessor;
19use arrow::array::ArrayIter;
20use arrow::array::ArrayRef;
21use arrow::array::AsArray;
22use arrow::datatypes::DataType;
23use datafusion::arrow;
24use datafusion::common::cast::as_primitive_array;
25use datafusion::common::cast::as_string_array;
26use datafusion::error::Result;
27use datafusion::logical_expr::Accumulator;
28use datafusion::scalar::ScalarValue;
29use std::collections::HashMap;
30
31#[derive(Debug)]
32pub struct BytesModeAccumulator {
33 value_counts: HashMap<String, i64>,
34 data_type: DataType,
35}
36
37impl BytesModeAccumulator {
38 pub fn new(data_type: &DataType) -> Self {
39 Self {
40 value_counts: HashMap::new(),
41 data_type: data_type.clone(),
42 }
43 }
44
45 fn update_counts<'a, V>(&mut self, array: V)
46 where
47 V: ArrayAccessor<Item = &'a str>,
48 {
49 for value in ArrayIter::new(array).flatten() {
50 let key = value;
51 if let Some(count) = self.value_counts.get_mut(key) {
52 *count += 1;
53 } else {
54 self.value_counts.insert(key.to_string(), 1);
55 }
56 }
57 }
58}
59
60impl Accumulator for BytesModeAccumulator {
61 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
62 if values.is_empty() {
63 return Ok(());
64 }
65
66 match &self.data_type {
67 DataType::Utf8View => {
68 let array = values[0].as_string_view();
69 self.update_counts(array);
70 }
71 _ => {
72 let array = values[0].as_string::<i32>();
73 self.update_counts(array);
74 }
75 };
76
77 Ok(())
78 }
79
80 fn state(&mut self) -> Result<Vec<ScalarValue>> {
81 let values: Vec<ScalarValue> = self
82 .value_counts
83 .keys()
84 .map(|key| ScalarValue::Utf8(Some(key.to_string())))
85 .collect();
86
87 let frequencies: Vec<ScalarValue> = self
88 .value_counts
89 .values()
90 .map(|&count| ScalarValue::Int64(Some(count)))
91 .collect();
92
93 let values_scalar = ScalarValue::new_list_nullable(&values, &DataType::Utf8);
94 let frequencies_scalar = ScalarValue::new_list_nullable(&frequencies, &DataType::Int64);
95
96 Ok(vec![
97 ScalarValue::List(values_scalar),
98 ScalarValue::List(frequencies_scalar),
99 ])
100 }
101
102 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
103 if states.is_empty() {
104 return Ok(());
105 }
106
107 let values_array = as_string_array(&states[0])?;
108 let counts_array = as_primitive_array::<arrow::datatypes::Int64Type>(&states[1])?;
109
110 for (i, value_option) in values_array.iter().enumerate() {
111 if let Some(value) = value_option {
112 let count = counts_array.value(i);
113 let entry = self.value_counts.entry(value.to_string()).or_insert(0);
114 *entry += count;
115 }
116 }
117
118 Ok(())
119 }
120
121 fn evaluate(&mut self) -> Result<ScalarValue> {
122 if self.value_counts.is_empty() {
123 return match &self.data_type {
124 DataType::Utf8View => Ok(ScalarValue::Utf8View(None)),
125 _ => Ok(ScalarValue::Utf8(None)),
126 };
127 }
128
129 let mode = self
130 .value_counts
131 .iter()
132 .max_by(|a, b| {
133 a.1.cmp(b.1)
135 .then_with(|| b.0.cmp(a.0))
137 })
138 .map(|(value, _)| value.to_string());
139
140 match mode {
141 Some(result) => match &self.data_type {
142 DataType::Utf8View => Ok(ScalarValue::Utf8View(Some(result))),
143 _ => Ok(ScalarValue::Utf8(Some(result))),
144 },
145 None => match &self.data_type {
146 DataType::Utf8View => Ok(ScalarValue::Utf8View(None)),
147 _ => Ok(ScalarValue::Utf8(None)),
148 },
149 }
150 }
151
152 fn size(&self) -> usize {
153 self.value_counts.capacity() * std::mem::size_of::<(String, i64)>() + std::mem::size_of_val(&self.data_type)
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, GenericByteViewArray, StringArray};
    use std::sync::Arc;

    /// Builds a `Utf8` array from optional string slices.
    fn utf8_array(items: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(items.to_vec()))
    }

    /// Builds a `Utf8View` array from optional string slices.
    fn utf8_view_array(items: &[Option<&str>]) -> ArrayRef {
        Arc::new(GenericByteViewArray::from(items.to_vec()))
    }

    /// Feeds `input` as a single batch into a fresh accumulator and returns
    /// the evaluated mode.
    fn mode_of(data_type: &DataType, input: ArrayRef) -> Result<ScalarValue> {
        let mut acc = BytesModeAccumulator::new(data_type);
        acc.update_batch(&[input])?;
        acc.evaluate()
    }

    #[test]
    fn test_mode_accumulator_single_mode_utf8() -> Result<()> {
        let input = utf8_array(&[
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
            Some("apple"),
        ]);

        let result = mode_of(&DataType::Utf8, input)?;

        assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_tie_utf8() -> Result<()> {
        let input = utf8_array(&[
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
        ]);

        let result = mode_of(&DataType::Utf8, input)?;

        assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_all_nulls_utf8() -> Result<()> {
        let input = utf8_array(&[None, None, None]);

        let result = mode_of(&DataType::Utf8, input)?;

        assert_eq!(result, ScalarValue::Utf8(None));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_with_nulls_utf8() -> Result<()> {
        let input = utf8_array(&[
            Some("apple"),
            None,
            Some("banana"),
            Some("apple"),
            None,
            None,
            None,
            Some("banana"),
        ]);

        let result = mode_of(&DataType::Utf8, input)?;

        assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_single_mode_utf8view() -> Result<()> {
        let input = utf8_view_array(&[
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
            Some("apple"),
        ]);

        let result = mode_of(&DataType::Utf8View, input)?;

        assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_tie_utf8view() -> Result<()> {
        let input = utf8_view_array(&[
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("orange"),
            Some("banana"),
        ]);

        let result = mode_of(&DataType::Utf8View, input)?;

        assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string())));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_all_nulls_utf8view() -> Result<()> {
        let input = utf8_view_array(&[None, None, None]);

        let result = mode_of(&DataType::Utf8View, input)?;

        assert_eq!(result, ScalarValue::Utf8View(None));
        Ok(())
    }

    #[test]
    fn test_mode_accumulator_with_nulls_utf8view() -> Result<()> {
        let input = utf8_view_array(&[
            Some("apple"),
            None,
            Some("banana"),
            Some("apple"),
            None,
            None,
            None,
            Some("banana"),
        ]);

        let result = mode_of(&DataType::Utf8View, input)?;

        assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string())));
        Ok(())
    }
}