1use crate::binary_map::OutputType;
21use ahash::RandomState;
22use arrow::array::NullBufferBuilder;
23use arrow::array::cast::AsArray;
24use arrow::array::{Array, ArrayRef, BinaryViewArray, ByteView, make_view};
25use arrow::buffer::{Buffer, ScalarBuffer};
26use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
27use datafusion_common::hash_utils::create_hashes;
28use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
29use std::fmt::Debug;
30use std::mem::size_of;
31use std::sync::Arc;
32
33#[derive(Debug)]
36pub struct ArrowBytesViewSet(ArrowBytesViewMap<()>);
37
38impl ArrowBytesViewSet {
39 pub fn new(output_type: OutputType) -> Self {
40 Self(ArrowBytesViewMap::new(output_type))
41 }
42
43 pub fn insert(&mut self, values: &ArrayRef) {
45 fn make_payload_fn(_value: Option<&[u8]>) {}
46 fn observe_payload_fn(_payload: ()) {}
47 self.0
48 .insert_if_new(values, make_payload_fn, observe_payload_fn);
49 }
50
51 pub fn take(&mut self) -> Self {
54 let mut new_self = Self::new(self.0.output_type);
55 std::mem::swap(self, &mut new_self);
56 new_self
57 }
58
59 pub fn into_state(self) -> ArrayRef {
63 self.0.into_state()
64 }
65
66 pub fn len(&self) -> usize {
68 self.0.len()
69 }
70
71 pub fn is_empty(&self) -> bool {
72 self.0.is_empty()
73 }
74
75 pub fn non_null_len(&self) -> usize {
77 self.0.non_null_len()
78 }
79
80 pub fn size(&self) -> usize {
83 self.0.size()
84 }
85}
86
/// Size limit (2 MiB) of the in-progress data buffer; appending a value that
/// would push the buffer past this limit starts a new buffer instead.
const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 << 20;
119
120pub struct ArrowBytesViewMap<V>
121where
122 V: Debug + PartialEq + Eq + Clone + Copy + Default,
123{
124 output_type: OutputType,
126 map: hashbrown::hash_table::HashTable<Entry<V>>,
128 map_size: usize,
130
131 views: Vec<u128>,
133 in_progress: Vec<u8>,
135 completed: Vec<Buffer>,
137 nulls: NullBufferBuilder,
139
140 random_state: RandomState,
142 hashes_buffer: Vec<u64>,
144 null: Option<(V, usize)>,
148}
149
/// Capacity, in entries, that the hash table of a new map is created with (512).
const INITIAL_MAP_CAPACITY: usize = 1 << 9;
152
153impl<V> ArrowBytesViewMap<V>
154where
155 V: Debug + PartialEq + Eq + Clone + Copy + Default,
156{
157 pub fn new(output_type: OutputType) -> Self {
158 Self {
159 output_type,
160 map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
161 map_size: 0,
162 views: Vec::new(),
163 in_progress: Vec::new(),
164 completed: Vec::new(),
165 nulls: NullBufferBuilder::new(0),
166 random_state: RandomState::new(),
167 hashes_buffer: vec![],
168 null: None,
169 }
170 }
171
172 pub fn take(&mut self) -> Self {
175 let mut new_self = Self::new(self.output_type);
176 std::mem::swap(self, &mut new_self);
177 new_self
178 }
179
180 pub fn insert_if_new<MP, OP>(
207 &mut self,
208 values: &ArrayRef,
209 make_payload_fn: MP,
210 observe_payload_fn: OP,
211 ) where
212 MP: FnMut(Option<&[u8]>) -> V,
213 OP: FnMut(V),
214 {
215 match self.output_type {
217 OutputType::BinaryView => {
218 assert!(matches!(values.data_type(), DataType::BinaryView));
219 self.insert_if_new_inner::<MP, OP, BinaryViewType>(
220 values,
221 make_payload_fn,
222 observe_payload_fn,
223 )
224 }
225 OutputType::Utf8View => {
226 assert!(matches!(values.data_type(), DataType::Utf8View));
227 self.insert_if_new_inner::<MP, OP, StringViewType>(
228 values,
229 make_payload_fn,
230 observe_payload_fn,
231 )
232 }
233 _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
234 };
235 }
236
237 fn insert_if_new_inner<MP, OP, B>(
246 &mut self,
247 values: &ArrayRef,
248 mut make_payload_fn: MP,
249 mut observe_payload_fn: OP,
250 ) where
251 MP: FnMut(Option<&[u8]>) -> V,
252 OP: FnMut(V),
253 B: ByteViewType,
254 {
255 let batch_hashes = &mut self.hashes_buffer;
257 batch_hashes.clear();
258 batch_hashes.resize(values.len(), 0);
259 create_hashes([values], &self.random_state, batch_hashes)
260 .unwrap();
263
264 let values = values.as_byte_view::<B>();
266
267 let input_views = values.views();
269
270 assert_eq!(values.len(), self.hashes_buffer.len());
272
273 for i in 0..values.len() {
274 let view_u128 = input_views[i];
275 let hash = self.hashes_buffer[i];
276
277 if values.is_null(i) {
279 let payload = if let Some(&(payload, _offset)) = self.null.as_ref() {
280 payload
281 } else {
282 let payload = make_payload_fn(None);
283 let null_index = self.views.len();
284 self.views.push(0);
285 self.nulls.append_null();
286 self.null = Some((payload, null_index));
287 payload
288 };
289 observe_payload_fn(payload);
290 continue;
291 }
292
293 let len = view_u128 as u32;
295
296 let maybe_payload = {
298 let completed = &self.completed;
300 let in_progress = &self.in_progress;
301
302 self.map
303 .find(hash, |header| {
304 if header.hash != hash {
305 return false;
306 }
307
308 if len <= 12 {
310 return header.view == view_u128;
311 }
312
313 let stored_prefix = (header.view >> 32) as u32;
315 let input_prefix = (view_u128 >> 32) as u32;
316 if stored_prefix != input_prefix {
317 return false;
318 }
319
320 let byte_view = ByteView::from(header.view);
322 let stored_len = byte_view.length as usize;
323 let buffer_index = byte_view.buffer_index as usize;
324 let offset = byte_view.offset as usize;
325
326 let stored_value = if buffer_index < completed.len() {
327 &completed[buffer_index].as_slice()
328 [offset..offset + stored_len]
329 } else {
330 &in_progress[offset..offset + stored_len]
331 };
332 let input_value: &[u8] = values.value(i).as_ref();
333 stored_value == input_value
334 })
335 .map(|entry| entry.payload)
336 };
337
338 let payload = if let Some(payload) = maybe_payload {
339 payload
340 } else {
341 let value: &[u8] = values.value(i).as_ref();
343 let payload = make_payload_fn(Some(value));
344
345 let new_view = self.append_value(value);
347 let new_header = Entry {
348 view: new_view,
349 hash,
350 payload,
351 };
352
353 self.map
354 .insert_accounted(new_header, |h| h.hash, &mut self.map_size);
355 payload
356 };
357 observe_payload_fn(payload);
358 }
359 }
360
361 pub fn into_state(mut self) -> ArrayRef {
368 if !self.in_progress.is_empty() {
370 let flushed = std::mem::take(&mut self.in_progress);
371 self.completed.push(Buffer::from_vec(flushed));
372 }
373
374 let null_buffer = self.nulls.finish();
376
377 let views = ScalarBuffer::from(self.views);
378 let array =
379 unsafe { BinaryViewArray::new_unchecked(views, self.completed, null_buffer) };
380
381 match self.output_type {
382 OutputType::BinaryView => Arc::new(array),
383 OutputType::Utf8View => {
384 let array = unsafe { array.to_string_view_unchecked() };
386 Arc::new(array)
387 }
388 _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"),
389 }
390 }
391
392 fn append_value(&mut self, value: &[u8]) -> u128 {
394 let len = value.len();
395 let view = if len <= 12 {
396 make_view(value, 0, 0)
397 } else {
398 if self.in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE {
400 let flushed = std::mem::replace(
401 &mut self.in_progress,
402 Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE),
403 );
404 self.completed.push(Buffer::from_vec(flushed));
405 }
406
407 let buffer_index = self.completed.len() as u32;
408 let offset = self.in_progress.len() as u32;
409 self.in_progress.extend_from_slice(value);
410
411 make_view(value, buffer_index, offset)
412 };
413
414 self.views.push(view);
415 self.nulls.append_non_null();
416 view
417 }
418
419 pub fn len(&self) -> usize {
421 self.non_null_len() + self.null.map(|_| 1).unwrap_or(0)
422 }
423
424 pub fn is_empty(&self) -> bool {
426 self.map.is_empty() && self.null.is_none()
427 }
428
429 pub fn non_null_len(&self) -> usize {
431 self.map.len()
432 }
433
434 pub fn size(&self) -> usize {
437 let views_size = self.views.len() * size_of::<u128>();
438 let in_progress_size = self.in_progress.capacity();
439 let completed_size: usize = self.completed.iter().map(|b| b.len()).sum();
440 let nulls_size = self.nulls.allocated_size();
441
442 self.map_size
443 + views_size
444 + in_progress_size
445 + completed_size
446 + nulls_size
447 + self.hashes_buffer.allocated_size()
448 }
449}
450
451impl<V> Debug for ArrowBytesViewMap<V>
452where
453 V: Debug + PartialEq + Eq + Clone + Copy + Default,
454{
455 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456 f.debug_struct("ArrowBytesMap")
457 .field("map", &"<map>")
458 .field("map_size", &self.map_size)
459 .field("views_len", &self.views.len())
460 .field("completed_buffers", &self.completed.len())
461 .field("random_state", &self.random_state)
462 .field("hashes_buffer", &self.hashes_buffer)
463 .finish()
464 }
465}
466
/// Hash table entry describing one distinct non-null value.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct Entry<V: Debug + PartialEq + Eq + Clone + Copy + Default> {
    /// View of the stored value: the value itself when it is at most 12 bytes
    /// long, otherwise its length, prefix, buffer index and offset.
    view: u128,

    /// Hash of the value, kept so the table can rehash entries and reject
    /// mismatches without reading the value's bytes.
    hash: u64,

    /// Caller-supplied payload associated with the value.
    payload: V,
}
488
#[cfg(test)]
mod tests {
    use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray};
    use datafusion_common::HashMap;

    use super::*;

    /// Consumes `set` and checks that its contents, in order, equal `expected`.
    fn assert_set(set: ArrowBytesViewSet, expected: &[Option<&str>]) {
        let state = set.into_state();
        let actual: Vec<Option<&str>> = state.as_string_view().iter().collect();
        assert_eq!(actual, expected);
    }

    /// Builds a `Utf8View` set and inserts an array of `count` nulls into it.
    fn utf8_set_of_nulls(count: usize) -> ArrowBytesViewSet {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let nulls: ArrayRef = Arc::new(StringViewArray::new_null(count));
        set.insert(&nulls);
        set
    }

    #[test]
    fn string_view_set_empty() {
        let set = utf8_set_of_nulls(0);
        assert_eq!(set.len(), 0);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[]);
    }

    #[test]
    fn string_view_set_one_null() {
        let set = utf8_set_of_nulls(1);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn string_view_set_many_null() {
        // Many nulls still collapse into a single null entry.
        let set = utf8_set_of_nulls(11);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn test_string_view_set_basic() {
        // Mix of inline (<= 12 bytes) and out-of-line values, with repeats.
        let input = GenericByteViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCAABB"),
            Some(""),
            Some("cbcxx"),
            None,
            Some("AAAAAAAA"),
            Some("BBBBBQBBBAAA"),
            Some("a"),
            Some("cbcxx"),
            Some("b"),
            Some("cbcxx"),
            Some(""),
            None,
            Some("BBBBBQBBBAAA"),
            Some("BBBBBQBBBAAA"),
            Some("AAAAAAAA"),
            Some("CXCCCCCCCCAABB"),
        ]);
        let input: ArrayRef = Arc::new(input);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        set.insert(&input);

        // Distinct values come back in first-seen order.
        let expected = [
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCAABB"),
            Some(""),
            Some("cbcxx"),
            None,
            Some("AAAAAAAA"),
            Some("BBBBBQBBBAAA"),
        ];
        assert_set(set, &expected);
    }

    #[test]
    fn test_string_set_non_utf8() {
        // Values containing multi-byte (non-ASCII) characters.
        let long_value = "✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥";
        let input = GenericByteViewArray::from(vec![
            Some("a"),
            Some(long_value),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
            Some("🔥"),
            Some(long_value),
        ]);
        let input: ArrayRef = Arc::new(input);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        set.insert(&input);

        let expected = [
            Some("a"),
            Some(long_value),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
        ];
        assert_set(set, &expected);
    }

    #[test]
    fn test_binary_set() {
        let input: Vec<Option<&[u8]>> = vec![
            Some(b"a"),
            Some(b"CXCCCCCCCCCCCCC"),
            None,
            Some(b"CXCCCCCCCCCCCCC"),
        ];
        let input: ArrayRef = Arc::new(BinaryViewArray::from(input));

        let distinct: Vec<Option<&[u8]>> =
            vec![Some(b"a"), Some(b"CXCCCCCCCCCCCCC"), None];
        let expected: ArrayRef = Arc::new(GenericByteViewArray::from(distinct));

        let mut set = ArrowBytesViewSet::new(OutputType::BinaryView);
        set.insert(&input);
        assert_eq!(&set.into_state(), &expected);
    }

    #[test]
    fn test_string_set_memory_usage() {
        let strings1 = StringViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCCCC"),
            Some("AAAAAAAA"),
            Some("BBBBBQBBB"),
        ]);
        let total_strings1_len: usize = strings1.iter().flatten().map(str::len).sum();
        let values1: ArrayRef = Arc::new(strings1);

        // Much larger values, all different from the first batch.
        let strings2 = StringViewArray::from(vec![
            "FOO".repeat(1000),
            "BAR larger than 12 bytes.".repeat(100_000),
            "more unique.".repeat(1000),
            "more unique2.".repeat(1000),
            "FOO".repeat(3000),
        ]);
        let total_strings2_len: usize = strings2.iter().flatten().map(str::len).sum();
        let values2: ArrayRef = Arc::new(strings2);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let size_empty = set.size();

        set.insert(&values1);
        let size_after_values1 = set.size();
        assert!(size_empty < size_after_values1);
        assert!(
            size_after_values1 > total_strings1_len,
            "expect {size_after_values1} to be more than {total_strings1_len}"
        );
        assert!(size_after_values1 < total_strings1_len + total_strings2_len);

        // Inserting the same values again must not grow the set.
        set.insert(&values1);
        assert_eq!(set.size(), size_after_values1);
        assert_eq!(set.len(), 5);

        // New values make the set grow.
        set.insert(&values2);
        let size_after_values2 = set.size();
        assert!(size_after_values2 > size_after_values1);

        assert_eq!(set.len(), 10);
    }

    /// Payload stored in the map under test: the index assigned to a value.
    #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
    struct TestPayload {
        index: usize,
    }

    /// Pairs an `ArrowBytesViewMap` with a straightforward reference model so
    /// that the two can be compared after every insert.
    struct TestMap {
        /// Map under test.
        map: ArrowBytesViewMap<TestPayload>,
        /// Reference: distinct values in first-seen order.
        strings: Vec<Option<String>>,
        /// Reference: value -> its position in `strings`.
        indexes: HashMap<Option<String>, usize>,
    }

    impl TestMap {
        /// Creates an empty `Utf8View` map with an empty reference model.
        fn new() -> Self {
            Self {
                map: ArrowBytesViewMap::new(OutputType::Utf8View),
                strings: vec![],
                indexes: HashMap::new(),
            }
        }

        /// Inserts `strings` into both the map and the reference model and
        /// checks that both report the same new values and the same indexes.
        fn insert(&mut self, strings: &[Option<&str>]) {
            let arr: ArrayRef = Arc::new(StringViewArray::from(strings.to_vec()));

            // Read before the reference model is updated below: this is the
            // index the first new value of this batch must receive.
            let mut next_index = self.indexes.len();

            // Expected results according to the reference model.
            let mut actual_new_strings = vec![];
            let mut actual_seen_indexes = vec![];
            for value in strings {
                let owned = value.map(|s| s.to_string());
                let index = match self.indexes.get(&owned).copied() {
                    Some(existing) => existing,
                    None => {
                        let fresh = self.strings.len();
                        actual_new_strings.push(owned.clone());
                        self.strings.push(owned.clone());
                        self.indexes.insert(owned, fresh);
                        fresh
                    }
                };
                actual_seen_indexes.push(index);
            }

            // Results reported by the map through its two callbacks.
            let mut seen_new_strings = vec![];
            let mut seen_indexes = vec![];
            self.map.insert_if_new(
                &arr,
                |bytes| {
                    let value = bytes.map(|b| {
                        String::from_utf8(b.to_vec()).expect("Non utf8 string")
                    });
                    seen_new_strings.push(value);
                    let payload = TestPayload { index: next_index };
                    next_index += 1;
                    payload
                },
                |payload| seen_indexes.push(payload.index),
            );

            assert_eq!(actual_seen_indexes, seen_indexes);
            assert_eq!(actual_new_strings, seen_new_strings);
        }

        /// Consumes the map, checks its state against the reference model and
        /// returns the state.
        fn into_array(self) -> ArrayRef {
            let Self {
                map,
                strings,
                indexes: _,
            } = self;

            let actual = map.into_state();
            let expected: ArrayRef = Arc::new(StringViewArray::from(strings));
            assert_eq!(&actual, &expected);
            actual
        }
    }

    #[test]
    fn test_map() {
        let input = vec![
            Some("A"),
            Some("bcdefghijklmnop1234567"),
            Some("X"),
            Some("Y"),
            None,
            Some("qrstuvqxyzhjwya"),
            Some("✨🔥"),
            Some("🔥"),
            Some("🔥🔥🔥🔥🔥🔥"),
        ];

        let mut test_map = TestMap::new();
        test_map.insert(&input);
        // The second insert must find every value already present.
        test_map.insert(&input);

        let expected_output: ArrayRef = Arc::new(StringViewArray::from(input));
        assert_eq!(&test_map.into_array(), &expected_output);
    }
}