1use ahash::RandomState;
23use arrow::array::cast::AsArray;
24use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
25use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
26use datafusion_common::hash_utils::create_hashes;
27use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
28use std::fmt::Debug;
29use std::sync::Arc;
30
31use crate::binary_map::OutputType;
32
/// A set of distinct byte-view values (`Utf8View` / `BinaryView`).
///
/// This is a thin wrapper around [`ArrowBytesViewMap`] whose payload type is
/// `()`, i.e. only the values themselves are tracked.
#[derive(Debug)]
pub struct ArrowBytesViewSet(ArrowBytesViewMap<()>);
37
38impl ArrowBytesViewSet {
39 pub fn new(output_type: OutputType) -> Self {
40 Self(ArrowBytesViewMap::new(output_type))
41 }
42
43 pub fn insert(&mut self, values: &ArrayRef) {
45 fn make_payload_fn(_value: Option<&[u8]>) {}
46 fn observe_payload_fn(_payload: ()) {}
47 self.0
48 .insert_if_new(values, make_payload_fn, observe_payload_fn);
49 }
50
51 pub fn take(&mut self) -> Self {
54 let mut new_self = Self::new(self.0.output_type);
55 std::mem::swap(self, &mut new_self);
56 new_self
57 }
58
59 pub fn into_state(self) -> ArrayRef {
63 self.0.into_state()
64 }
65
66 pub fn len(&self) -> usize {
68 self.0.len()
69 }
70
71 pub fn is_empty(&self) -> bool {
72 self.0.is_empty()
73 }
74
75 pub fn non_null_len(&self) -> usize {
77 self.0.non_null_len()
78 }
79
80 pub fn size(&self) -> usize {
83 self.0.size()
84 }
85}
86
/// Maps distinct byte-view values (`Utf8View` / `BinaryView`) to a
/// caller-defined payload of type `V`.
///
/// Each distinct value is appended once to an internal view builder, and a
/// hash table entry records where it lives together with its payload.
pub struct ArrowBytesViewMap<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Which view type `insert_if_new` accepts and `into_state` produces.
    output_type: OutputType,
    /// Hash table of entries; each entry points at a value held in `builder`.
    map: hashbrown::hash_table::HashTable<Entry<V>>,
    /// Size accounting for `map`, updated via `insert_accounted` and
    /// included in `size()`.
    map_size: usize,

    /// Holds each distinct value (and the null, if any) in first-seen order.
    /// Always a binary-view builder; `into_state` converts to a string view
    /// when `output_type` is `Utf8View`.
    builder: GenericByteViewBuilder<BinaryViewType>,
    /// Hasher state passed to `create_hashes` for every inserted batch.
    random_state: RandomState,
    /// Scratch buffer for per-batch hashes, reused across insert calls.
    hashes_buffer: Vec<u64>,
    /// Payload and `builder` index of the null value, once a null has been
    /// inserted; `None` until then.
    null: Option<(V, usize)>,
}
139
/// Capacity requested for the hash table when a new `ArrowBytesViewMap` is
/// created.
const INITIAL_MAP_CAPACITY: usize = 512;
142
143impl<V> ArrowBytesViewMap<V>
144where
145 V: Debug + PartialEq + Eq + Clone + Copy + Default,
146{
147 pub fn new(output_type: OutputType) -> Self {
148 Self {
149 output_type,
150 map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
151 map_size: 0,
152 builder: GenericByteViewBuilder::new(),
153 random_state: RandomState::new(),
154 hashes_buffer: vec![],
155 null: None,
156 }
157 }
158
159 pub fn take(&mut self) -> Self {
162 let mut new_self = Self::new(self.output_type);
163 std::mem::swap(self, &mut new_self);
164 new_self
165 }
166
167 pub fn insert_if_new<MP, OP>(
194 &mut self,
195 values: &ArrayRef,
196 make_payload_fn: MP,
197 observe_payload_fn: OP,
198 ) where
199 MP: FnMut(Option<&[u8]>) -> V,
200 OP: FnMut(V),
201 {
202 match self.output_type {
204 OutputType::BinaryView => {
205 assert!(matches!(values.data_type(), DataType::BinaryView));
206 self.insert_if_new_inner::<MP, OP, BinaryViewType>(
207 values,
208 make_payload_fn,
209 observe_payload_fn,
210 )
211 }
212 OutputType::Utf8View => {
213 assert!(matches!(values.data_type(), DataType::Utf8View));
214 self.insert_if_new_inner::<MP, OP, StringViewType>(
215 values,
216 make_payload_fn,
217 observe_payload_fn,
218 )
219 }
220 _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
221 };
222 }
223
224 fn insert_if_new_inner<MP, OP, B>(
233 &mut self,
234 values: &ArrayRef,
235 mut make_payload_fn: MP,
236 mut observe_payload_fn: OP,
237 ) where
238 MP: FnMut(Option<&[u8]>) -> V,
239 OP: FnMut(V),
240 B: ByteViewType,
241 {
242 let batch_hashes = &mut self.hashes_buffer;
244 batch_hashes.clear();
245 batch_hashes.resize(values.len(), 0);
246 create_hashes(&[Arc::clone(values)], &self.random_state, batch_hashes)
247 .unwrap();
250
251 let values = values.as_byte_view::<B>();
253
254 assert_eq!(values.len(), batch_hashes.len());
256
257 for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
258 let Some(value) = value else {
260 let payload = if let Some(&(payload, _offset)) = self.null.as_ref() {
261 payload
262 } else {
263 let payload = make_payload_fn(None);
264 let null_index = self.builder.len();
265 self.builder.append_null();
266 self.null = Some((payload, null_index));
267 payload
268 };
269 observe_payload_fn(payload);
270 continue;
271 };
272
273 let value: &[u8] = value.as_ref();
275
276 let entry = self.map.find_mut(hash, |header| {
277 let v = self.builder.get_value(header.view_idx);
278
279 if v.len() != value.len() {
280 return false;
281 }
282
283 v == value
284 });
285
286 let payload = if let Some(entry) = entry {
287 entry.payload
288 } else {
289 let payload = make_payload_fn(Some(value));
291
292 let inner_view_idx = self.builder.len();
293 let new_header = Entry {
294 view_idx: inner_view_idx,
295 hash,
296 payload,
297 };
298
299 self.builder.append_value(value);
300
301 self.map
302 .insert_accounted(new_header, |h| h.hash, &mut self.map_size);
303 payload
304 };
305 observe_payload_fn(payload);
306 }
307 }
308
309 pub fn into_state(self) -> ArrayRef {
316 let mut builder = self.builder;
317 match self.output_type {
318 OutputType::BinaryView => {
319 let array = builder.finish();
320
321 Arc::new(array)
322 }
323 OutputType::Utf8View => {
324 let array = builder.finish();
329 let array = unsafe { array.to_string_view_unchecked() };
330 Arc::new(array)
331 }
332 _ => {
333 unreachable!("Utf8/Binary should use `ArrowBytesMap`")
334 }
335 }
336 }
337
338 pub fn len(&self) -> usize {
340 self.non_null_len() + self.null.map(|_| 1).unwrap_or(0)
341 }
342
343 pub fn is_empty(&self) -> bool {
345 self.map.is_empty() && self.null.is_none()
346 }
347
348 pub fn non_null_len(&self) -> usize {
350 self.map.len()
351 }
352
353 pub fn size(&self) -> usize {
356 self.map_size
357 + self.builder.allocated_size()
358 + self.hashes_buffer.allocated_size()
359 }
360}
361
362impl<V> Debug for ArrowBytesViewMap<V>
363where
364 V: Debug + PartialEq + Eq + Clone + Copy + Default,
365{
366 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 f.debug_struct("ArrowBytesMap")
368 .field("map", &"<map>")
369 .field("map_size", &self.map_size)
370 .field("view_builder", &self.builder)
371 .field("random_state", &self.random_state)
372 .field("hashes_buffer", &self.hashes_buffer)
373 .finish()
374 }
375}
376
/// One hash table entry of an `ArrowBytesViewMap`, describing a single
/// distinct non-null value.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct Entry<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Index of the value inside the map's view builder.
    view_idx: usize,

    /// Hash computed for the value when it was inserted; also supplied to the
    /// table as the entry's hash on insertion.
    hash: u64,

    /// Caller-defined payload associated with the value.
    payload: V,
}
391
#[cfg(test)]
mod tests {
    use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray};
    use datafusion_common::HashMap;

    use super::*;

    /// Asserts that `set` emits exactly `expected`, in order, as a string
    /// view array.
    fn assert_set(set: ArrowBytesViewSet, expected: &[Option<&str>]) {
        let strings = set.into_state();
        let strings = strings.as_string_view();
        let state = strings.into_iter().collect::<Vec<_>>();
        assert_eq!(state, expected);
    }

    #[test]
    fn string_view_set_empty() {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(0));
        set.insert(&array);
        assert_eq!(set.len(), 0);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[]);
    }

    #[test]
    fn string_view_set_one_null() {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(1));
        set.insert(&array);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn string_view_set_many_null() {
        // Many nulls collapse into a single null entry.
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(11));
        set.insert(&array);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn test_string_view_set_basic() {
        // Mix of values of at most 12 bytes, longer values, the empty
        // string, nulls and repeats.
        let values = GenericByteViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCAABB"),
            Some(""),
            Some("cbcxx"),
            None,
            Some("AAAAAAAA"),
            Some("BBBBBQBBBAAA"),
            Some("a"),
            Some("cbcxx"),
            Some("b"),
            Some("cbcxx"),
            Some(""),
            None,
            Some("BBBBBQBBBAAA"),
            Some("BBBBBQBBBAAA"),
            Some("AAAAAAAA"),
            Some("CXCCCCCCCCAABB"),
        ]);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        // Distinct values come back in first-seen order.
        assert_set(
            set,
            &[
                Some("a"),
                Some("b"),
                Some("CXCCCCCCCCAABB"),
                Some(""),
                Some("cbcxx"),
                None,
                Some("AAAAAAAA"),
                Some("BBBBBQBBBAAA"),
            ],
        );
    }

    #[test]
    fn test_string_set_non_utf8() {
        // Note: despite the test name, the inputs are valid UTF-8; what is
        // exercised here is multi-byte (non-ASCII) content.
        let values = GenericByteViewArray::from(vec![
            Some("a"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
            Some("🔥"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
        ]);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        assert_set(
            set,
            &[
                Some("a"),
                Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
                Some("🔥"),
                Some("✨✨✨"),
                Some("foobarbaz"),
            ],
        );
    }

    #[test]
    fn test_binary_set() {
        let v: Vec<Option<&[u8]>> = vec![
            Some(b"a"),
            Some(b"CXCCCCCCCCCCCCC"),
            None,
            Some(b"CXCCCCCCCCCCCCC"),
        ];
        let values: ArrayRef = Arc::new(BinaryViewArray::from(v));

        let expected: Vec<Option<&[u8]>> =
            vec![Some(b"a"), Some(b"CXCCCCCCCCCCCCC"), None];
        let expected: ArrayRef = Arc::new(GenericByteViewArray::from(expected));

        let mut set = ArrowBytesViewSet::new(OutputType::BinaryView);
        set.insert(&values);
        assert_eq!(&set.into_state(), &expected);
    }

    #[test]
    fn test_string_set_memory_usage() {
        let strings1 = StringViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCCCC"),
            Some("AAAAAAAA"),
            Some("BBBBBQBBB"),
        ]);
        let total_strings1_len = strings1.iter().flatten().map(str::len).sum::<usize>();
        // `strings1` is already a `StringViewArray`; no conversion needed.
        let values1: ArrayRef = Arc::new(strings1);

        let strings2 = StringViewArray::from(vec![
            "FOO".repeat(1000),
            "BAR larger than 12 bytes.".repeat(100_000),
            "more unique.".repeat(1000),
            "more unique2.".repeat(1000),
            "FOO".repeat(3000),
        ]);
        let total_strings2_len = strings2.iter().flatten().map(str::len).sum::<usize>();
        let values2: ArrayRef = Arc::new(strings2);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let size_empty = set.size();

        set.insert(&values1);
        let size_after_values1 = set.size();
        assert!(size_empty < size_after_values1);
        assert!(
            size_after_values1 > total_strings1_len,
            "expect {size_after_values1} to be more than {total_strings1_len}"
        );
        assert!(size_after_values1 < total_strings1_len + total_strings2_len);

        // Re-inserting the same values must not grow the set.
        set.insert(&values1);
        assert_eq!(set.size(), size_after_values1);
        assert_eq!(set.len(), 5);

        // Inserting new values must grow it.
        set.insert(&values2);
        let size_after_values2 = set.size();
        assert!(size_after_values2 > size_after_values1);

        assert_eq!(set.len(), 10);
    }

    /// Payload used by `TestMap`: the order in which a value was first seen.
    #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
    struct TestPayload {
        index: usize,
    }

    /// Wraps an `ArrowBytesViewMap` together with a straightforward
    /// reference model (`strings` + `indexes`) to compare it against.
    struct TestMap {
        map: ArrowBytesViewMap<TestPayload>,
        /// Distinct values in first-seen order, per the reference model.
        strings: Vec<Option<String>>,
        /// Value -> position in `strings`, per the reference model.
        indexes: HashMap<Option<String>, usize>,
    }

    impl TestMap {
        fn new() -> Self {
            Self {
                map: ArrowBytesViewMap::new(OutputType::Utf8View),
                strings: vec![],
                indexes: HashMap::new(),
            }
        }

        /// Inserts `strings` into both the reference model and the map, and
        /// checks that the map's callbacks saw the same new values and
        /// indexes as the model.
        fn insert(&mut self, strings: &[Option<&str>]) {
            let string_array = StringViewArray::from(strings.to_vec());
            let arr: ArrayRef = Arc::new(string_array);

            // Taken before the model is updated: the index the map should
            // hand out for the first new value in this batch.
            let mut next_index = self.indexes.len();
            let mut actual_new_strings = vec![];
            let mut actual_seen_indexes = vec![];
            for item in strings {
                let key = item.map(|s| s.to_string());
                let index = match self.indexes.get(&key).copied() {
                    Some(index) => index,
                    None => {
                        actual_new_strings.push(key.clone());
                        let index = self.strings.len();
                        self.strings.push(key.clone());
                        self.indexes.insert(key, index);
                        index
                    }
                };
                actual_seen_indexes.push(index);
            }

            let mut seen_new_strings = vec![];
            let mut seen_indexes = vec![];
            self.map.insert_if_new(
                &arr,
                |s| {
                    let value = s
                        .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string"));
                    let index = next_index;
                    next_index += 1;
                    seen_new_strings.push(value);
                    TestPayload { index }
                },
                |payload| {
                    seen_indexes.push(payload.index);
                },
            );

            assert_eq!(actual_seen_indexes, seen_indexes);
            assert_eq!(actual_new_strings, seen_new_strings);
        }

        /// Consumes the test map, checks the emitted array against the
        /// reference model, and returns it.
        fn into_array(self) -> ArrayRef {
            let Self {
                map,
                strings,
                indexes: _,
            } = self;

            let arr = map.into_state();
            let expected: ArrayRef = Arc::new(StringViewArray::from(strings));
            assert_eq!(&arr, &expected);
            arr
        }
    }

    #[test]
    fn test_map() {
        let input = vec![
            Some("A"),
            Some("bcdefghijklmnop1234567"),
            Some("X"),
            Some("Y"),
            None,
            Some("qrstuvqxyzhjwya"),
            Some("✨🔥"),
            Some("🔥"),
            Some("🔥🔥🔥🔥🔥🔥"),
        ];

        let mut test_map = TestMap::new();
        test_map.insert(&input);
        // A second insert of the same values must not add anything.
        test_map.insert(&input);
        let expected_output: ArrayRef = Arc::new(StringViewArray::from(input));
        assert_eq!(&test_map.into_array(), &expected_output);
    }
}