1use crate::binary_map::OutputType;
23use ahash::RandomState;
24use arrow::array::cast::AsArray;
25use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
26use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
27use datafusion_common::hash_utils::create_hashes;
28use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
29use std::fmt::Debug;
30use std::sync::Arc;
31
/// Set of distinct string or binary values whose contents can be emitted as a
/// `StringViewArray` or `BinaryViewArray` (see [`ArrowBytesViewSet::into_state`]).
///
/// Thin wrapper around [`ArrowBytesViewMap`] with a `()` payload.
#[derive(Debug)]
pub struct ArrowBytesViewSet(ArrowBytesViewMap<()>);
36
37impl ArrowBytesViewSet {
38 pub fn new(output_type: OutputType) -> Self {
39 Self(ArrowBytesViewMap::new(output_type))
40 }
41
42 pub fn insert(&mut self, values: &ArrayRef) {
44 fn make_payload_fn(_value: Option<&[u8]>) {}
45 fn observe_payload_fn(_payload: ()) {}
46 self.0
47 .insert_if_new(values, make_payload_fn, observe_payload_fn);
48 }
49
50 pub fn take(&mut self) -> Self {
53 let mut new_self = Self::new(self.0.output_type);
54 std::mem::swap(self, &mut new_self);
55 new_self
56 }
57
58 pub fn into_state(self) -> ArrayRef {
62 self.0.into_state()
63 }
64
65 pub fn len(&self) -> usize {
67 self.0.len()
68 }
69
70 pub fn is_empty(&self) -> bool {
71 self.0.is_empty()
72 }
73
74 pub fn non_null_len(&self) -> usize {
76 self.0.non_null_len()
77 }
78
79 pub fn size(&self) -> usize {
82 self.0.size()
83 }
84}
85
/// Map from distinct string or binary values, read from `StringViewArray` /
/// `BinaryViewArray` inputs, to a caller-supplied payload `V`.
///
/// Every distinct non-null value is appended once to `builder`, and `map` holds
/// one [`Entry`] per distinct non-null value pointing at its index in
/// `builder`. Null is tracked separately in `null`.
pub struct ArrowBytesViewMap<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Selects whether `into_state` emits a `BinaryViewArray` or `StringViewArray`.
    output_type: OutputType,
    /// Hash table of entries, located by each value's hash (`Entry::hash`).
    map: hashbrown::hash_table::HashTable<Entry<V>>,
    /// Running memory-accounting total for `map`, updated by
    /// `insert_accounted`; starts at 0 in `new`.
    map_size: usize,

    /// Distinct values (and the null slot, if any) in first-insertion order.
    builder: GenericByteViewBuilder<BinaryViewType>,
    /// Seed state used to hash incoming values.
    random_state: RandomState,
    /// Scratch buffer reused across batches to hold per-row hashes.
    hashes_buffer: Vec<u64>,
    /// Payload and builder index of the null value, once a null has been seen.
    null: Option<(V, usize)>,
}
138
/// Entry capacity requested for the hash table in `ArrowBytesViewMap::new`.
const INITIAL_MAP_CAPACITY: usize = 512;
141
142impl<V> ArrowBytesViewMap<V>
143where
144 V: Debug + PartialEq + Eq + Clone + Copy + Default,
145{
146 pub fn new(output_type: OutputType) -> Self {
147 Self {
148 output_type,
149 map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
150 map_size: 0,
151 builder: GenericByteViewBuilder::new(),
152 random_state: RandomState::new(),
153 hashes_buffer: vec![],
154 null: None,
155 }
156 }
157
158 pub fn take(&mut self) -> Self {
161 let mut new_self = Self::new(self.output_type);
162 std::mem::swap(self, &mut new_self);
163 new_self
164 }
165
166 pub fn insert_if_new<MP, OP>(
193 &mut self,
194 values: &ArrayRef,
195 make_payload_fn: MP,
196 observe_payload_fn: OP,
197 ) where
198 MP: FnMut(Option<&[u8]>) -> V,
199 OP: FnMut(V),
200 {
201 match self.output_type {
203 OutputType::BinaryView => {
204 assert!(matches!(values.data_type(), DataType::BinaryView));
205 self.insert_if_new_inner::<MP, OP, BinaryViewType>(
206 values,
207 make_payload_fn,
208 observe_payload_fn,
209 )
210 }
211 OutputType::Utf8View => {
212 assert!(matches!(values.data_type(), DataType::Utf8View));
213 self.insert_if_new_inner::<MP, OP, StringViewType>(
214 values,
215 make_payload_fn,
216 observe_payload_fn,
217 )
218 }
219 _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
220 };
221 }
222
223 fn insert_if_new_inner<MP, OP, B>(
232 &mut self,
233 values: &ArrayRef,
234 mut make_payload_fn: MP,
235 mut observe_payload_fn: OP,
236 ) where
237 MP: FnMut(Option<&[u8]>) -> V,
238 OP: FnMut(V),
239 B: ByteViewType,
240 {
241 let batch_hashes = &mut self.hashes_buffer;
243 batch_hashes.clear();
244 batch_hashes.resize(values.len(), 0);
245 create_hashes([values], &self.random_state, batch_hashes)
246 .unwrap();
249
250 let values = values.as_byte_view::<B>();
252
253 assert_eq!(values.len(), batch_hashes.len());
255
256 for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
257 let Some(value) = value else {
259 let payload = if let Some(&(payload, _offset)) = self.null.as_ref() {
260 payload
261 } else {
262 let payload = make_payload_fn(None);
263 let null_index = self.builder.len();
264 self.builder.append_null();
265 self.null = Some((payload, null_index));
266 payload
267 };
268 observe_payload_fn(payload);
269 continue;
270 };
271
272 let value: &[u8] = value.as_ref();
274
275 let entry = self.map.find_mut(hash, |header| {
276 let v = self.builder.get_value(header.view_idx);
277
278 if v.len() != value.len() {
279 return false;
280 }
281
282 v == value
283 });
284
285 let payload = if let Some(entry) = entry {
286 entry.payload
287 } else {
288 let payload = make_payload_fn(Some(value));
290
291 let inner_view_idx = self.builder.len();
292 let new_header = Entry {
293 view_idx: inner_view_idx,
294 hash,
295 payload,
296 };
297
298 self.builder.append_value(value);
299
300 self.map
301 .insert_accounted(new_header, |h| h.hash, &mut self.map_size);
302 payload
303 };
304 observe_payload_fn(payload);
305 }
306 }
307
308 pub fn into_state(self) -> ArrayRef {
315 let mut builder = self.builder;
316 match self.output_type {
317 OutputType::BinaryView => {
318 let array = builder.finish();
319
320 Arc::new(array)
321 }
322 OutputType::Utf8View => {
323 let array = builder.finish();
328 let array = unsafe { array.to_string_view_unchecked() };
329 Arc::new(array)
330 }
331 _ => {
332 unreachable!("Utf8/Binary should use `ArrowBytesMap`")
333 }
334 }
335 }
336
337 pub fn len(&self) -> usize {
339 self.non_null_len() + self.null.map(|_| 1).unwrap_or(0)
340 }
341
342 pub fn is_empty(&self) -> bool {
344 self.map.is_empty() && self.null.is_none()
345 }
346
347 pub fn non_null_len(&self) -> usize {
349 self.map.len()
350 }
351
352 pub fn size(&self) -> usize {
355 self.map_size
356 + self.builder.allocated_size()
357 + self.hashes_buffer.allocated_size()
358 }
359}
360
361impl<V> Debug for ArrowBytesViewMap<V>
362where
363 V: Debug + PartialEq + Eq + Clone + Copy + Default,
364{
365 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 f.debug_struct("ArrowBytesMap")
367 .field("map", &"<map>")
368 .field("map_size", &self.map_size)
369 .field("view_builder", &self.builder)
370 .field("random_state", &self.random_state)
371 .field("hashes_buffer", &self.hashes_buffer)
372 .finish()
373 }
374}
375
/// One hash table entry describing a distinct non-null value stored in the
/// map's builder.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct Entry<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Index of the value in the builder (the position it was appended at).
    view_idx: usize,

    /// Hash of the value computed with the map's `random_state`. It is also
    /// the hash the table uses for this entry.
    hash: u64,

    /// Payload produced by `make_payload_fn` when the value was first inserted.
    payload: V,
}
390
#[cfg(test)]
mod tests {
    use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray};
    use datafusion_common::HashMap;

    use super::*;

    /// Asserts that the set's `into_state` output, read as a string view
    /// array, equals `expected` element by element and in order.
    fn assert_set(set: ArrowBytesViewSet, expected: &[Option<&str>]) {
        let strings = set.into_state();
        let strings = strings.as_string_view();
        let state = strings.into_iter().collect::<Vec<_>>();
        assert_eq!(state, expected);
    }

    #[test]
    fn string_view_set_empty() {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(0));
        set.insert(&array);
        assert_eq!(set.len(), 0);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[]);
    }

    #[test]
    fn string_view_set_one_null() {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(1));
        set.insert(&array);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn string_view_set_many_null() {
        // Many nulls collapse into a single null entry.
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(11));
        set.insert(&array);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn test_string_view_set_basic() {
        // Values of various lengths (0 to 14 bytes) with duplicates and two
        // nulls; the set keeps each distinct value once, in first-seen order.
        let values = GenericByteViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCAABB"),
            Some(""),
            Some("cbcxx"),
            None,
            Some("AAAAAAAA"),
            Some("BBBBBQBBBAAA"),
            Some("a"),
            Some("cbcxx"),
            Some("b"),
            Some("cbcxx"),
            Some(""),
            None,
            Some("BBBBBQBBBAAA"),
            Some("BBBBBQBBBAAA"),
            Some("AAAAAAAA"),
            Some("CXCCCCCCCCAABB"),
        ]);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        // Distinct values in order of first occurrence.
        assert_set(
            set,
            &[
                Some("a"),
                Some("b"),
                Some("CXCCCCCCCCAABB"),
                Some(""),
                Some("cbcxx"),
                None,
                Some("AAAAAAAA"),
                Some("BBBBBQBBBAAA"),
            ],
        );
    }

    #[test]
    fn test_string_set_non_utf8() {
        // Despite the name, these inputs are valid UTF-8; they exercise
        // multi-byte (non-ASCII) characters and repeated values.
        let values = GenericByteViewArray::from(vec![
            Some("a"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
            Some("🔥"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
        ]);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        // Distinct values in order of first occurrence.
        assert_set(
            set,
            &[
                Some("a"),
                Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
                Some("🔥"),
                Some("✨✨✨"),
                Some("foobarbaz"),
            ],
        );
    }

    #[test]
    fn test_binary_set() {
        // Binary output mode: a duplicate is dropped and the null is kept once.
        let v: Vec<Option<&[u8]>> = vec![
            Some(b"a"),
            Some(b"CXCCCCCCCCCCCCC"),
            None,
            Some(b"CXCCCCCCCCCCCCC"),
        ];
        let values: ArrayRef = Arc::new(BinaryViewArray::from(v));

        let expected: Vec<Option<&[u8]>> =
            vec![Some(b"a"), Some(b"CXCCCCCCCCCCCCC"), None];
        let expected: ArrayRef = Arc::new(GenericByteViewArray::from(expected));

        let mut set = ArrowBytesViewSet::new(OutputType::BinaryView);
        set.insert(&values);
        assert_eq!(&set.into_state(), &expected);
    }

    #[test]
    fn test_string_set_memory_usage() {
        let strings1 = StringViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCCCC"),
            Some("AAAAAAAA"),
            Some("BBBBBQBBB"),
        ]);
        let total_strings1_len = strings1
            .iter()
            .map(|s| s.map(|s| s.len()).unwrap_or(0))
            .sum::<usize>();
        let values1: ArrayRef = Arc::new(StringViewArray::from(strings1));

        // Five large distinct values, so inserting them must grow `size()`.
        let strings2 = StringViewArray::from(vec![
            "FOO".repeat(1000),
            "BAR larger than 12 bytes.".repeat(100_000),
            "more unique.".repeat(1000),
            "more unique2.".repeat(1000),
            "FOO".repeat(3000),
        ]);
        let total_strings2_len = strings2
            .iter()
            .map(|s| s.map(|s| s.len()).unwrap_or(0))
            .sum::<usize>();
        let values2: ArrayRef = Arc::new(StringViewArray::from(strings2));

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let size_empty = set.size();

        set.insert(&values1);
        let size_after_values1 = set.size();
        assert!(size_empty < size_after_values1);
        assert!(
            size_after_values1 > total_strings1_len,
            "expect {size_after_values1} to be more than {total_strings1_len}"
        );
        assert!(size_after_values1 < total_strings1_len + total_strings2_len);

        // Inserting the same values again must not change the reported size.
        set.insert(&values1);
        assert_eq!(set.size(), size_after_values1);
        assert_eq!(set.len(), 5);

        set.insert(&values2);
        let size_after_values2 = set.size();
        assert!(size_after_values2 > size_after_values1);

        assert_eq!(set.len(), 10);
    }

    #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
    struct TestPayload {
        // Index assigned to the value when it was first inserted.
        index: usize,
    }

    /// Pairs an `ArrowBytesViewMap` with a simple reference model (`strings`
    /// and `indexes`). Each insert can then be checked against the
    /// expected payloads.
    struct TestMap {
        map: ArrowBytesViewMap<TestPayload>,
        // Distinct values in first-insertion order (the expected final state).
        strings: Vec<Option<String>>,
        // Expected payload index for each distinct value.
        indexes: HashMap<Option<String>, usize>,
    }

    impl TestMap {
        /// Creates an empty Utf8View map and an empty reference model.
        fn new() -> Self {
            Self {
                map: ArrowBytesViewMap::new(OutputType::Utf8View),
                strings: vec![],
                indexes: HashMap::new(),
            }
        }

        /// Inserts `strings` into both the reference model and the real map.
        ///
        /// It then asserts two things. First, the map asked for payloads for
        /// exactly the values the model considers new, in the same order.
        /// Second, the map reported the expected index for every row.
        fn insert(&mut self, strings: &[Option<&str>]) {
            let string_array = StringViewArray::from(strings.to_vec());
            let arr: ArrayRef = Arc::new(string_array);

            let mut next_index = self.indexes.len();
            // Note: the `actual_*` vectors hold the reference model's
            // (expected) results; `seen_*` below hold what the map reports.
            let mut actual_new_strings = vec![];
            let mut actual_seen_indexes = vec![];
            for str in strings {
                let str = str.map(|s| s.to_string());
                let index = self.indexes.get(&str).cloned().unwrap_or_else(|| {
                    actual_new_strings.push(str.clone());
                    let index = self.strings.len();
                    self.strings.push(str.clone());
                    self.indexes.insert(str, index);
                    index
                });
                actual_seen_indexes.push(index);
            }

            // Run the real map, recording what its two callbacks receive.
            let mut seen_new_strings = vec![];
            let mut seen_indexes = vec![];
            self.map.insert_if_new(
                &arr,
                |s| {
                    let value = s
                        .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string"));
                    let index = next_index;
                    next_index += 1;
                    seen_new_strings.push(value);
                    TestPayload { index }
                },
                |payload| {
                    seen_indexes.push(payload.index);
                },
            );

            assert_eq!(actual_seen_indexes, seen_indexes);
            assert_eq!(actual_new_strings, seen_new_strings);
        }

        /// Consumes the test map and asserts that the map's final state equals
        /// the model's distinct values in insertion order. Returns that state.
        fn into_array(self) -> ArrayRef {
            let Self {
                map,
                strings,
                indexes: _,
            } = self;

            let arr = map.into_state();
            let expected: ArrayRef = Arc::new(StringViewArray::from(strings));
            assert_eq!(&arr, &expected);
            arr
        }
    }

    #[test]
    fn test_map() {
        let input = vec![
            Some("A"),
            Some("bcdefghijklmnop1234567"),
            Some("X"),
            Some("Y"),
            None,
            Some("qrstuvqxyzhjwya"),
            Some("✨🔥"),
            Some("🔥"),
            Some("🔥🔥🔥🔥🔥🔥"),
        ];

        let mut test_map = TestMap::new();
        test_map.insert(&input);
        // Second insert of identical input: every value is already present.
        test_map.insert(&input);
        let expected_output: ArrayRef = Arc::new(StringViewArray::from(input));
        assert_eq!(&test_map.into_array(), &expected_output);
    }
}