datafusion_physical_expr_common/
binary_map.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ArrowBytesMap`] and [`ArrowBytesSet`] for storing maps/sets of values from
19//! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray.
20
21use ahash::RandomState;
22use arrow::array::{
23    cast::AsArray,
24    types::{ByteArrayType, GenericBinaryType, GenericStringType},
25    Array, ArrayRef, BufferBuilder, GenericBinaryArray, GenericStringArray,
26    NullBufferBuilder, OffsetSizeTrait,
27};
28use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
29use arrow::datatypes::DataType;
30use datafusion_common::hash_utils::create_hashes;
31use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
32use std::any::type_name;
33use std::fmt::Debug;
34use std::mem::{size_of, swap};
35use std::ops::Range;
36use std::sync::Arc;
37
/// Selects the Arrow array family produced when a map / set is emitted.
///
/// The offset-based variants ([`OutputType::Utf8`] and
/// [`OutputType::Binary`]) are the ones handled by the map in this file; the
/// view variants are rejected with a message pointing at `ArrowBytesViewMap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputType {
    /// Emit a `StringArray` or `LargeStringArray`
    Utf8,
    /// Emit a `StringViewArray`
    Utf8View,
    /// Emit a `BinaryArray` or `LargeBinaryArray`
    Binary,
    /// Emit a `BinaryViewArray`
    BinaryView,
}
50
51/// HashSet optimized for storing string or binary values that can produce that
52/// the final set as a GenericStringArray with minimal copies.
53#[derive(Debug)]
54pub struct ArrowBytesSet<O: OffsetSizeTrait>(ArrowBytesMap<O, ()>);
55
56impl<O: OffsetSizeTrait> ArrowBytesSet<O> {
57    pub fn new(output_type: OutputType) -> Self {
58        Self(ArrowBytesMap::new(output_type))
59    }
60
61    /// Return the contents of this set and replace it with a new empty
62    /// set with the same output type
63    pub fn take(&mut self) -> Self {
64        Self(self.0.take())
65    }
66
67    /// Inserts each value from `values` into the set
68    pub fn insert(&mut self, values: &ArrayRef) {
69        fn make_payload_fn(_value: Option<&[u8]>) {}
70        fn observe_payload_fn(_payload: ()) {}
71        self.0
72            .insert_if_new(values, make_payload_fn, observe_payload_fn);
73    }
74
75    /// Converts this set into a `StringArray`/`LargeStringArray` or
76    /// `BinaryArray`/`LargeBinaryArray` containing each distinct value that
77    /// was interned. This is done without copying the values.
78    pub fn into_state(self) -> ArrayRef {
79        self.0.into_state()
80    }
81
82    /// Returns the total number of distinct values (including nulls) seen so far
83    pub fn len(&self) -> usize {
84        self.0.len()
85    }
86
87    pub fn is_empty(&self) -> bool {
88        self.0.is_empty()
89    }
90
91    /// returns the total number of distinct values (not including nulls) seen so far
92    pub fn non_null_len(&self) -> usize {
93        self.0.non_null_len()
94    }
95
96    /// Return the total size, in bytes, of memory used to store the data in
97    /// this set, not including `self`
98    pub fn size(&self) -> usize {
99        self.0.size()
100    }
101}
102
103/// Optimized map for storing Arrow "bytes" types (`String`, `LargeString`,
104/// `Binary`, and `LargeBinary`) values that can produce the set of keys on
105/// output as `GenericBinaryArray` without copies.
106///
/// Equivalent to `HashMap<String, V>` but with better performance if you need
/// to emit the keys as an Arrow `StringArray` / `BinaryArray`. For other
/// purposes it is the same as a `HashMap<String, V>`.
110///
111/// # Generic Arguments
112///
113/// * `O`: OffsetSize (String/LargeString)
114/// * `V`: payload type
115///
116/// # Description
117///
118/// This is a specialized HashMap with the following properties:
119///
120/// 1. Optimized for storing and emitting Arrow byte types  (e.g.
121///    `StringArray` / `BinaryArray`) very efficiently by minimizing copying of
122///    the string values themselves, both when inserting and when emitting the
123///    final array.
124///
125///
126/// 2. Retains the insertion order of entries in the final array. The values are
127///    in the same order as they were inserted.
128///
129/// Note this structure can be used as a `HashSet` by specifying the value type
130/// as `()`, as is done by [`ArrowBytesSet`].
131///
132/// This map is used by the special `COUNT DISTINCT` aggregate function to
133/// store the distinct values, and by the `GROUP BY` operator to store
134/// group values when they are a single string array.
135///
136/// # Example
137///
138/// The following diagram shows how the map would store the four strings
139/// "Foo", NULL, "Bar", "TheQuickBrownFox":
140///
141/// * `hashtable` stores entries for each distinct string that has been
142///   inserted. The entries contain the payload as well as information about the
143///   value (either an offset or the actual bytes, see `Entry` docs for more
144///   details)
145///
146/// * `offsets` stores offsets into `buffer` for each distinct string value,
147///   following the same convention as the offsets in a `StringArray` or
148///   `LargeStringArray`.
149///
150/// * `buffer` stores the actual byte data
151///
152/// * `null`: stores the index and payload of the null value, in this case the
153///   second value (index 1)
154///
155/// ```text
156/// ┌───────────────────────────────────┐    ┌─────┐    ┌────┐
157/// │                ...                │    │  0  │    │FooB│
/// │ ┌──────────────────────────────┐  │    │  3  │    │arTh│
/// │ │      <Entry for "Bar">       │  │    │  3  │    │eQui│
/// │ │            len: 3            │  │    │  6  │    │ckBr│
/// │ │   offset_or_inline: "Bar"    │  │    │ 22  │    │ownF│
162/// │ │         payload:...          │  │    │     │    │ox  │
163/// │ └──────────────────────────────┘  │    │     │    │    │
164/// │                ...                │    └─────┘    └────┘
165/// │ ┌──────────────────────────────┐  │
166/// │ │<Entry for "TheQuickBrownFox">│  │    offsets    buffer
167/// │ │           len: 16            │  │
168/// │ │     offset_or_inline: 6      │  │    ┌───────────────┐
169/// │ │         payload: ...         │  │    │    Some(1)    │
170/// │ └──────────────────────────────┘  │    │ payload: ...  │
171/// │                ...                │    └───────────────┘
172/// └───────────────────────────────────┘
173///                                              null
174///               HashTable
175/// ```
176///
177/// # Entry Format
178///
/// Each entry stored in an [`ArrowBytesMap`] represents a value whose bytes
/// are compared either against a copy stored inline in the entry or against
/// the bytes stored in the buffer.
///
/// This helps the case where there are many short (at most 8 bytes on 64-bit
/// platforms) strings that are the same (e.g. "MA", "CA", "NY", "TX", etc).
184///
185/// ```text
186///                                                                ┌──────────────────┐
187///                                                  ─ ─ ─ ─ ─ ─ ─▶│...               │
188///                                                 │              │TheQuickBrownFox  │
189///                                                                │...               │
190///                                                 │              │                  │
191///                                                                └──────────────────┘
192///                                                 │               buffer of u8
193///
194///                                                 │
195///                        ┌────────────────┬───────────────┬───────────────┐
196///  Storing               │                │ starting byte │  length, in   │
197///  "TheQuickBrownFox"    │   hash value   │   offset in   │  bytes (not   │
198///  (long string)         │                │    buffer     │  characters)  │
199///                        └────────────────┴───────────────┴───────────────┘
200///                              8 bytes          8 bytes       4 or 8
201///
202///
203///                         ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐
204/// Storing "foobar"        │               │ │ │ │ │ │ │ │ │  length, in   │
205/// (short string)          │  hash value   │?│?│f│o│o│b│a│r│  bytes (not   │
206///                         │               │ │ │ │ │ │ │ │ │  characters)  │
207///                         └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘
208///                              8 bytes         8 bytes        4 or 8
209/// ```
pub struct ArrowBytesMap<O, V>
where
    O: OffsetSizeTrait,
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Should the output be String or Binary?
    output_type: OutputType,
    /// Underlying hash table holding one entry per distinct non-null value
    map: hashbrown::hash_table::HashTable<Entry<O, V>>,
    /// Size of `map` in bytes, as tracked by `insert_accounted` on each
    /// insertion
    map_size: usize,
    /// In progress arrow `Buffer` containing the bytes of all distinct values
    buffer: BufferBuilder<u8>,
    /// Offsets into `buffer` for each distinct value. These offsets are used
    /// directly to create the final `GenericBinaryArray`. The `i`th string is
    /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values
    /// are stored as a zero length string.
    offsets: Vec<O>,
    /// random state used to generate hashes
    random_state: RandomState,
    /// buffer that stores hash values (reused across batches to save allocations)
    hashes_buffer: Vec<u64>,
    /// `(payload, null_index)` for the 'null' value, if any
    /// NOTE null_index is the logical index in the final array, not the index
    /// in the buffer
    null: Option<(V, usize)>,
}
237
/// Number of entries the hash table is sized for when a map is created
const INITIAL_MAP_CAPACITY: usize = 128;
/// Number of bytes the value buffer is sized for when a map is created
pub const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024;
242impl<O: OffsetSizeTrait, V> ArrowBytesMap<O, V>
243where
244    V: Debug + PartialEq + Eq + Clone + Copy + Default,
245{
246    pub fn new(output_type: OutputType) -> Self {
247        Self {
248            output_type,
249            map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
250            map_size: 0,
251            buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY),
252            offsets: vec![O::default()], // first offset is always 0
253            random_state: RandomState::new(),
254            hashes_buffer: vec![],
255            null: None,
256        }
257    }
258
259    /// Return the contents of this map and replace it with a new empty map with
260    /// the same output type
261    pub fn take(&mut self) -> Self {
262        let mut new_self = Self::new(self.output_type);
263        swap(self, &mut new_self);
264        new_self
265    }
266
267    /// Inserts each value from `values` into the map, invoking `payload_fn` for
268    /// each value if *not* already present, deferring the allocation of the
269    /// payload until it is needed.
270    ///
271    /// Note that this is different than a normal map that would replace the
272    /// existing entry
273    ///
274    /// # Arguments:
275    ///
276    /// `values`: array whose values are inserted
277    ///
278    /// `make_payload_fn`:  invoked for each value that is not already present
279    /// to create the payload, in order of the values in `values`
280    ///
281    /// `observe_payload_fn`: invoked once, for each value in `values`, that was
282    /// already present in the map, with corresponding payload value.
283    ///
284    /// # Returns
285    ///
286    /// The payload value for the entry, either the existing value or
287    /// the newly inserted value
288    ///
289    /// # Safety:
290    ///
291    /// Note that `make_payload_fn` and `observe_payload_fn` are only invoked
292    /// with valid values from `values`, not for the `NULL` value.
293    pub fn insert_if_new<MP, OP>(
294        &mut self,
295        values: &ArrayRef,
296        make_payload_fn: MP,
297        observe_payload_fn: OP,
298    ) where
299        MP: FnMut(Option<&[u8]>) -> V,
300        OP: FnMut(V),
301    {
302        // Sanity array type
303        match self.output_type {
304            OutputType::Binary => {
305                assert!(matches!(
306                    values.data_type(),
307                    DataType::Binary | DataType::LargeBinary
308                ));
309                self.insert_if_new_inner::<MP, OP, GenericBinaryType<O>>(
310                    values,
311                    make_payload_fn,
312                    observe_payload_fn,
313                )
314            }
315            OutputType::Utf8 => {
316                assert!(matches!(
317                    values.data_type(),
318                    DataType::Utf8 | DataType::LargeUtf8
319                ));
320                self.insert_if_new_inner::<MP, OP, GenericStringType<O>>(
321                    values,
322                    make_payload_fn,
323                    observe_payload_fn,
324                )
325            }
326            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
327        };
328    }
329
    /// Generic version of [`Self::insert_if_new`] that handles `ByteArrayType`
    /// (both String and Binary)
    ///
    /// Note this is the only function that is generic on [`ByteArrayType`], which
    /// avoids having to template the entire structure, making the code
    /// simpler to understand and reducing code bloat due to duplication.
    ///
    /// See comments on `insert_if_new` for more details
    ///
    /// # Panics
    ///
    /// Panics if, after processing `values`, the number of bytes in the
    /// buffer can not be represented by the offset type `O`.
    fn insert_if_new_inner<MP, OP, B>(
        &mut self,
        values: &ArrayRef,
        mut make_payload_fn: MP,
        mut observe_payload_fn: OP,
    ) where
        MP: FnMut(Option<&[u8]>) -> V,
        OP: FnMut(V),
        B: ByteArrayType,
    {
        // step 1: compute hashes for the whole batch, reusing `hashes_buffer`
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(values.len(), 0);
        create_hashes(&[Arc::clone(values)], &self.random_state, batch_hashes)
            // hash is supported for all types and create_hashes only
            // returns errors for unsupported types
            .unwrap();

        // step 2: insert each value into the set, if not already present
        let values = values.as_bytes::<B>();

        // Ensure lengths are equivalent
        assert_eq!(values.len(), batch_hashes.len());

        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
            // handle null value: it is tracked in `self.null`, not in the
            // hash table
            let Some(value) = value else {
                let payload = if let Some(&(payload, _offset)) = self.null.as_ref() {
                    payload
                } else {
                    let payload = make_payload_fn(None);
                    // logical index of the null in the output array
                    let null_index = self.offsets.len() - 1;
                    // nulls need a zero length in the offset buffer
                    let offset = self.buffer.len();
                    self.offsets.push(O::usize_as(offset));
                    self.null = Some((payload, null_index));
                    payload
                };
                observe_payload_fn(payload);
                continue;
            };

            // get the value as bytes
            let value: &[u8] = value.as_ref();
            let value_len = O::usize_as(value.len());

            // value is "small": it fits in a `usize` and is compared inline
            let payload = if value.len() <= SHORT_VALUE_LEN {
                // pack the bytes into a usize, first byte most significant
                let inline = value.iter().fold(0usize, |acc, &x| (acc << 8) | x as usize);

                // is the value already present in the set?
                let entry = self.map.find_mut(hash, |header| {
                    // compare value if hashes match
                    if header.len != value_len {
                        return false;
                    }
                    // value is stored inline so no need to consult buffer
                    // (this is the "small string optimization")
                    inline == header.offset_or_inline
                });

                if let Some(entry) = entry {
                    entry.payload
                }
                // if no existing entry, make a new one
                else {
                    // Put the small values into buffer and offsets so they
                    // appear in the output array, but store the actual bytes
                    // inline for comparison
                    self.buffer.append_slice(value);
                    self.offsets.push(O::usize_as(self.buffer.len()));
                    let payload = make_payload_fn(Some(value));
                    let new_header = Entry {
                        hash,
                        len: value_len,
                        offset_or_inline: inline,
                        payload,
                    };
                    self.map.insert_accounted(
                        new_header,
                        |header| header.hash,
                        &mut self.map_size,
                    );
                    payload
                }
            }
            // value is not "small"
            else {
                // Check if the value is already present in the set
                let entry = self.map.find_mut(hash, |header| {
                    // compare value if hashes match. The length check also
                    // rules out entries that store their bytes inline, since
                    // those are never longer than `SHORT_VALUE_LEN`
                    if header.len != value_len {
                        return false;
                    }
                    // Need to compare the bytes in the buffer
                    // SAFETY: buffer is only appended to, and we correctly inserted values and offsets
                    let existing_value =
                        unsafe { self.buffer.as_slice().get_unchecked(header.range()) };
                    value == existing_value
                });

                if let Some(entry) = entry {
                    entry.payload
                }
                // if no existing entry, make a new one
                else {
                    // Put the large values into buffer and offsets so they
                    // appear in the output array, and store the starting
                    // offset so the bytes can be compared if needed
                    let offset = self.buffer.len(); // offset of start for data
                    self.buffer.append_slice(value);
                    self.offsets.push(O::usize_as(self.buffer.len()));

                    let payload = make_payload_fn(Some(value));
                    let new_header = Entry {
                        hash,
                        len: value_len,
                        offset_or_inline: offset,
                        payload,
                    };
                    self.map.insert_accounted(
                        new_header,
                        |header| header.hash,
                        &mut self.map_size,
                    );
                    payload
                }
            };
            // reported for every non-null value, whether new or existing
            observe_payload_fn(payload);
        }
        // Check for overflow in offsets (if more data was sent than can be represented)
        if O::from_usize(self.buffer.len()).is_none() {
            panic!(
                "Put {} bytes in buffer, more than can be represented by a {}",
                self.buffer.len(),
                type_name::<O>()
            );
        }
    }
478
479    /// Converts this set into a `StringArray`, `LargeStringArray`,
480    /// `BinaryArray`, or `LargeBinaryArray` containing each distinct value
481    /// that was inserted. This is done without copying the values.
482    ///
483    /// The values are guaranteed to be returned in the same order in which
484    /// they were first seen.
485    pub fn into_state(self) -> ArrayRef {
486        let Self {
487            output_type,
488            map: _,
489            map_size: _,
490            offsets,
491            mut buffer,
492            random_state: _,
493            hashes_buffer: _,
494            null,
495        } = self;
496
497        // Only make a `NullBuffer` if there was a null value
498        let nulls = null.map(|(_payload, null_index)| {
499            let num_values = offsets.len() - 1;
500            single_null_buffer(num_values, null_index)
501        });
502        // SAFETY: the offsets were constructed correctly in `insert_if_new` --
503        // monotonically increasing, overflows were checked.
504        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
505        let values = buffer.finish();
506
507        match output_type {
508            OutputType::Binary => {
509                // SAFETY: the offsets were constructed correctly
510                Arc::new(unsafe {
511                    GenericBinaryArray::new_unchecked(offsets, values, nulls)
512                })
513            }
514            OutputType::Utf8 => {
515                // SAFETY:
516                // 1. the offsets were constructed safely
517                //
518                // 2. we asserted the input arrays were all the correct type and
519                // thus since all the values that went in were valid (e.g. utf8)
520                // so are all the values that come out
521                Arc::new(unsafe {
522                    GenericStringArray::new_unchecked(offsets, values, nulls)
523                })
524            }
525            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
526        }
527    }
528
529    /// Total number of entries (including null, if present)
530    pub fn len(&self) -> usize {
531        self.non_null_len() + self.null.map(|_| 1).unwrap_or(0)
532    }
533
534    /// Is the set empty?
535    pub fn is_empty(&self) -> bool {
536        self.map.is_empty() && self.null.is_none()
537    }
538
539    /// Number of non null entries
540    pub fn non_null_len(&self) -> usize {
541        self.map.len()
542    }
543
544    /// Return the total size, in bytes, of memory used to store the data in
545    /// this set, not including `self`
546    pub fn size(&self) -> usize {
547        self.map_size
548            + self.buffer.capacity() * size_of::<u8>()
549            + self.offsets.allocated_size()
550            + self.hashes_buffer.allocated_size()
551    }
552}
553
554/// Returns a `NullBuffer` with a single null value at the given index
555fn single_null_buffer(num_values: usize, null_index: usize) -> NullBuffer {
556    let mut null_builder = NullBufferBuilder::new(num_values);
557    null_builder.append_n_non_nulls(null_index);
558    null_builder.append_null();
559    null_builder.append_n_non_nulls(num_values - null_index - 1);
560    // SAFETY: inner builder must be constructed
561    null_builder.finish().unwrap()
562}
563
564impl<O: OffsetSizeTrait, V> Debug for ArrowBytesMap<O, V>
565where
566    V: Debug + PartialEq + Eq + Clone + Copy + Default,
567{
568    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
569        f.debug_struct("ArrowBytesMap")
570            .field("map", &"<map>")
571            .field("map_size", &self.map_size)
572            .field("buffer", &self.buffer)
573            .field("random_state", &self.random_state)
574            .field("hashes_buffer", &self.hashes_buffer)
575            .finish()
576    }
577}
578
/// Maximum length, in bytes, of a value whose bytes are stored inline in the
/// hash table entry (the number of bytes in a `usize`)
const SHORT_VALUE_LEN: usize = size_of::<usize>();
581
582/// Entry in the hash table -- see [`ArrowBytesMap`] for more details
583#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
584struct Entry<O, V>
585where
586    O: OffsetSizeTrait,
587    V: Debug + PartialEq + Eq + Clone + Copy + Default,
588{
589    /// hash of the value (stored to avoid recomputing it in hash table check)
590    hash: u64,
591    /// if len =< [`SHORT_VALUE_LEN`]: the data inlined
592    /// if len > [`SHORT_VALUE_LEN`], the offset of where the data starts
593    offset_or_inline: usize,
594    /// length of the value, in bytes (use O here so we use only i32 for
595    /// strings, rather 64 bit usize)
596    len: O,
597    /// value stored by the entry
598    payload: V,
599}
600
601impl<O, V> Entry<O, V>
602where
603    O: OffsetSizeTrait,
604    V: Debug + PartialEq + Eq + Clone + Copy + Default,
605{
606    /// returns self.offset..self.offset + self.len
607    #[inline(always)]
608    fn range(&self) -> Range<usize> {
609        self.offset_or_inline..self.offset_or_inline + self.len.as_usize()
610    }
611}
612
613#[cfg(test)]
614mod tests {
615    use super::*;
616    use arrow::array::{BinaryArray, LargeBinaryArray, StringArray};
617    use std::collections::HashMap;
618
619    #[test]
620    fn string_set_empty() {
621        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
622        let array: ArrayRef = Arc::new(StringArray::new_null(0));
623        set.insert(&array);
624        assert_eq!(set.len(), 0);
625        assert_eq!(set.non_null_len(), 0);
626        assert_set(set, &[]);
627    }
628
629    #[test]
630    fn string_set_one_null() {
631        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
632        let array: ArrayRef = Arc::new(StringArray::new_null(1));
633        set.insert(&array);
634        assert_eq!(set.len(), 1);
635        assert_eq!(set.non_null_len(), 0);
636        assert_set(set, &[None]);
637    }
638
639    #[test]
640    fn string_set_many_null() {
641        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
642        let array: ArrayRef = Arc::new(StringArray::new_null(11));
643        set.insert(&array);
644        assert_eq!(set.len(), 1);
645        assert_eq!(set.non_null_len(), 0);
646        assert_set(set, &[None]);
647    }
648
    // Runs the shared basic scenario with 32-bit offsets
    #[test]
    fn string_set_basic_i32() {
        test_string_set_basic::<i32>();
    }
653
    // Runs the shared basic scenario with 64-bit offsets
    #[test]
    fn string_set_basic_i64() {
        test_string_set_basic::<i64>();
    }
658
659    fn test_string_set_basic<O: OffsetSizeTrait>() {
660        // basic test for mixed small and large string values
661        let values = GenericStringArray::<O>::from(vec![
662            Some("a"),
663            Some("b"),
664            Some("CXCCCCCCCC"), // 10 bytes
665            Some(""),
666            Some("cbcxx"), // 5 bytes
667            None,
668            Some("AAAAAAAA"),  // 8 bytes
669            Some("BBBBBQBBB"), // 9 bytes
670            Some("a"),
671            Some("cbcxx"),
672            Some("b"),
673            Some("cbcxx"),
674            Some(""),
675            None,
676            Some("BBBBBQBBB"),
677            Some("BBBBBQBBB"),
678            Some("AAAAAAAA"),
679            Some("CXCCCCCCCC"),
680        ]);
681
682        let mut set = ArrowBytesSet::<O>::new(OutputType::Utf8);
683        let array: ArrayRef = Arc::new(values);
684        set.insert(&array);
685        // values mut appear be in the order they were inserted
686        assert_set(
687            set,
688            &[
689                Some("a"),
690                Some("b"),
691                Some("CXCCCCCCCC"),
692                Some(""),
693                Some("cbcxx"),
694                None,
695                Some("AAAAAAAA"),
696                Some("BBBBBQBBB"),
697            ],
698        );
699    }
700
    // Runs the shared multi-byte character scenario with 32-bit offsets
    #[test]
    fn string_set_non_utf8_32() {
        test_string_set_non_utf8::<i32>();
    }
705
    // Runs the shared multi-byte character scenario with 64-bit offsets
    #[test]
    fn string_set_non_utf8_64() {
        test_string_set_non_utf8::<i64>();
    }
710
711    fn test_string_set_non_utf8<O: OffsetSizeTrait>() {
712        // basic test for mixed small and large string values
713        let values = GenericStringArray::<O>::from(vec![
714            Some("a"),
715            Some("✨🔥"),
716            Some("🔥"),
717            Some("✨✨✨"),
718            Some("foobarbaz"),
719            Some("🔥"),
720            Some("✨🔥"),
721        ]);
722
723        let mut set = ArrowBytesSet::<O>::new(OutputType::Utf8);
724        let array: ArrayRef = Arc::new(values);
725        set.insert(&array);
726        // strings mut appear be in the order they were inserted
727        assert_set(
728            set,
729            &[
730                Some("a"),
731                Some("✨🔥"),
732                Some("🔥"),
733                Some("✨✨✨"),
734                Some("foobarbaz"),
735            ],
736        );
737    }
738
739    // asserts that the set contains the expected strings, in the same order
740    fn assert_set<O: OffsetSizeTrait>(set: ArrowBytesSet<O>, expected: &[Option<&str>]) {
741        let strings = set.into_state();
742        let strings = strings.as_string::<O>();
743        let state = strings.into_iter().collect::<Vec<_>>();
744        assert_eq!(state, expected);
745    }
746
747    // Test use of binary output type
748    #[test]
749    fn test_binary_set() {
750        let values: ArrayRef = Arc::new(BinaryArray::from_opt_vec(vec![
751            Some(b"a"),
752            Some(b"CXCCCCCCCC"),
753            None,
754            Some(b"CXCCCCCCCC"),
755        ]));
756
757        let expected: ArrayRef = Arc::new(BinaryArray::from_opt_vec(vec![
758            Some(b"a"),
759            Some(b"CXCCCCCCCC"),
760            None,
761        ]));
762
763        let mut set = ArrowBytesSet::<i32>::new(OutputType::Binary);
764        set.insert(&values);
765        assert_eq!(&set.into_state(), &expected);
766    }
767
768    // Test use of binary output type
769    #[test]
770    fn test_large_binary_set() {
771        let values: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![
772            Some(b"a"),
773            Some(b"CXCCCCCCCC"),
774            None,
775            Some(b"CXCCCCCCCC"),
776        ]));
777
778        let expected: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![
779            Some(b"a"),
780            Some(b"CXCCCCCCCC"),
781            None,
782        ]));
783
784        let mut set = ArrowBytesSet::<i64>::new(OutputType::Binary);
785        set.insert(&values);
786        assert_eq!(&set.into_state(), &expected);
787    }
788
789    #[test]
790    #[should_panic(
791        expected = "matches!(values.data_type(), DataType::Utf8 | DataType::LargeUtf8)"
792    )]
793    fn test_mismatched_types() {
794        // inserting binary into a set that expects strings should panic
795        let values: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![Some(b"a")]));
796
797        let mut set = ArrowBytesSet::<i64>::new(OutputType::Utf8);
798        set.insert(&values);
799    }
800
801    #[test]
802    #[should_panic]
803    fn test_mismatched_sizes() {
804        // inserting large strings into a set that expects small should panic
805        let values: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![Some(b"a")]));
806
807        let mut set = ArrowBytesSet::<i32>::new(OutputType::Binary);
808        set.insert(&values);
809    }
810
811    // put more than 2GB in a string set and expect it to panic
812    #[test]
813    #[should_panic(
814        expected = "Put 2147483648 bytes in buffer, more than can be represented by a i32"
815    )]
816    fn test_string_overflow() {
817        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
818        for value in ["a", "b", "c"] {
819            // 1GB strings, so 3rd is over 2GB and should panic
820            let arr: ArrayRef =
821                Arc::new(StringArray::from_iter_values([value.repeat(1 << 30)]));
822            set.insert(&arr);
823        }
824    }
825
826    // inserting strings into the set does not increase reported memory
827    #[test]
828    fn test_string_set_memory_usage() {
829        let strings1 = GenericStringArray::<i32>::from(vec![
830            Some("a"),
831            Some("b"),
832            Some("CXCCCCCCCC"), // 10 bytes
833            Some("AAAAAAAA"),   // 8 bytes
834            Some("BBBBBQBBB"),  // 9 bytes
835        ]);
836        let total_strings1_len = strings1
837            .iter()
838            .map(|s| s.map(|s| s.len()).unwrap_or(0))
839            .sum::<usize>();
840        let values1: ArrayRef = Arc::new(GenericStringArray::<i32>::from(strings1));
841
842        // Much larger strings in strings2
843        let strings2 = GenericStringArray::<i32>::from(vec![
844            "FOO".repeat(1000),
845            "BAR".repeat(2000),
846            "BAZ".repeat(3000),
847        ]);
848        let total_strings2_len = strings2
849            .iter()
850            .map(|s| s.map(|s| s.len()).unwrap_or(0))
851            .sum::<usize>();
852        let values2: ArrayRef = Arc::new(GenericStringArray::<i32>::from(strings2));
853
854        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
855        let size_empty = set.size();
856
857        set.insert(&values1);
858        let size_after_values1 = set.size();
859        assert!(size_empty < size_after_values1);
860        assert!(
861            size_after_values1 > total_strings1_len,
862            "expect {size_after_values1} to be more than {total_strings1_len}"
863        );
864        assert!(size_after_values1 < total_strings1_len + total_strings2_len);
865
866        // inserting the same strings should not affect the size
867        set.insert(&values1);
868        assert_eq!(set.size(), size_after_values1);
869
870        // inserting the large strings should increase the reported size
871        set.insert(&values2);
872        let size_after_values2 = set.size();
873        assert!(size_after_values2 > size_after_values1);
874        assert!(size_after_values2 > total_strings1_len + total_strings2_len);
875    }
876
877    #[test]
878    fn test_map() {
879        let input = vec![
880            // Note mix of short/long strings
881            Some("A"),
882            Some("bcdefghijklmnop"),
883            Some("X"),
884            Some("Y"),
885            None,
886            Some("qrstuvqxyzhjwya"),
887            Some("✨🔥"),
888            Some("🔥"),
889            Some("🔥🔥🔥🔥🔥🔥"),
890        ];
891
892        let mut test_map = TestMap::new();
893        test_map.insert(&input);
894        test_map.insert(&input); // put it in twice
895        let expected_output: ArrayRef = Arc::new(StringArray::from(input));
896        assert_eq!(&test_map.into_array(), &expected_output);
897    }
898
899    #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
900    struct TestPayload {
901        // store the string value to check against input
902        index: usize, // store the index of the string (each new string gets the next sequential input)
903    }
904
905    /// Wraps an [`ArrowBytesMap`], validating its invariants
906    struct TestMap {
907        map: ArrowBytesMap<i32, TestPayload>,
908        // stores distinct strings seen, in order
909        strings: Vec<Option<String>>,
910        // map strings to index in strings
911        indexes: HashMap<Option<String>, usize>,
912    }
913
914    impl Debug for TestMap {
915        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
916            f.debug_struct("TestMap")
917                .field("map", &"...")
918                .field("strings", &self.strings)
919                .field("indexes", &self.indexes)
920                .finish()
921        }
922    }
923
924    impl TestMap {
925        /// creates a map with TestPayloads for the given strings and then
926        /// validates the payloads
927        fn new() -> Self {
928            Self {
929                map: ArrowBytesMap::new(OutputType::Utf8),
930                strings: vec![],
931                indexes: HashMap::new(),
932            }
933        }
934
935        /// Inserts strings into the map
936        fn insert(&mut self, strings: &[Option<&str>]) {
937            let string_array = StringArray::from(strings.to_vec());
938            let arr: ArrayRef = Arc::new(string_array);
939
940            let mut next_index = self.indexes.len();
941            let mut actual_new_strings = vec![];
942            let mut actual_seen_indexes = vec![];
943            // update self with new values, keeping track of newly added values
944            for str in strings {
945                let str = str.map(|s| s.to_string());
946                let index = self.indexes.get(&str).cloned().unwrap_or_else(|| {
947                    actual_new_strings.push(str.clone());
948                    let index = self.strings.len();
949                    self.strings.push(str.clone());
950                    self.indexes.insert(str, index);
951                    index
952                });
953                actual_seen_indexes.push(index);
954            }
955
956            // insert the values into the map, recording what we did
957            let mut seen_new_strings = vec![];
958            let mut seen_indexes = vec![];
959            self.map.insert_if_new(
960                &arr,
961                |s| {
962                    let value = s
963                        .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string"));
964                    let index = next_index;
965                    next_index += 1;
966                    seen_new_strings.push(value);
967                    TestPayload { index }
968                },
969                |payload| {
970                    seen_indexes.push(payload.index);
971                },
972            );
973
974            assert_eq!(actual_seen_indexes, seen_indexes);
975            assert_eq!(actual_new_strings, seen_new_strings);
976        }
977
978        /// Call `self.map.into_array()` validating that the strings are in the same
979        /// order as they were inserted
980        fn into_array(self) -> ArrayRef {
981            let Self {
982                map,
983                strings,
984                indexes: _,
985            } = self;
986
987            let arr = map.into_state();
988            let expected: ArrayRef = Arc::new(StringArray::from(strings));
989            assert_eq!(&arr, &expected);
990            arr
991        }
992    }
993}