use ahash::RandomState;
use arrow::array::{
Array, ArrayRef, BufferBuilder, GenericBinaryArray, GenericStringArray,
NullBufferBuilder, OffsetSizeTrait,
cast::AsArray,
types::{ByteArrayType, GenericBinaryType, GenericStringType},
};
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow::datatypes::DataType;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
use std::any::type_name;
use std::fmt::Debug;
use std::mem::{size_of, swap};
use std::ops::Range;
use std::sync::Arc;
/// Selects which kind of Arrow array a map/set in this file produces.
///
/// The maps in this file only handle `Utf8` and `Binary`; passing one of the
/// view variants to them hits an `unreachable!` ("View types should use
/// `ArrowBytesViewMap`").
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputType {
/// Output is a `GenericStringArray` (inputs must be `Utf8` or `LargeUtf8`).
Utf8,
/// String view output; not handled by `ArrowBytesMap`.
Utf8View,
/// Output is a `GenericBinaryArray` (inputs must be `Binary` or `LargeBinary`).
Binary,
/// Binary view output; not handled by `ArrowBytesMap`.
BinaryView,
}
/// A set of distinct byte/string values (at most one null is tracked).
///
/// Thin wrapper around [`ArrowBytesMap`] whose payload type is `()`.
/// `O` is the offset type of the arrays being inserted and produced.
#[derive(Debug)]
pub struct ArrowBytesSet<O: OffsetSizeTrait>(ArrowBytesMap<O, ()>);
impl<O: OffsetSizeTrait> ArrowBytesSet<O> {
    /// Creates an empty set whose contents will be emitted as `output_type`.
    pub fn new(output_type: OutputType) -> Self {
        let inner = ArrowBytesMap::new(output_type);
        Self(inner)
    }

    /// Returns the current contents and leaves a fresh, empty set (with the
    /// same output type) behind in `self`.
    pub fn take(&mut self) -> Self {
        let Self(inner) = self;
        Self(inner.take())
    }

    /// Adds every value of `values` that is not already in the set.
    pub fn insert(&mut self, values: &ArrayRef) {
        // A set stores no per-value data: the payload is `()` and there is
        // nothing to do when a payload is observed.
        self.0
            .insert_if_new(values, |_value| (), |_payload| ());
    }

    /// Consumes the set and returns its distinct values as a single array.
    pub fn into_state(self) -> ArrayRef {
        let Self(inner) = self;
        inner.into_state()
    }

    /// Number of distinct values, counting the null (if one was seen) once.
    pub fn len(&self) -> usize {
        let Self(inner) = self;
        inner.len()
    }

    /// `true` when neither a value nor a null has been inserted.
    pub fn is_empty(&self) -> bool {
        let Self(inner) = self;
        inner.is_empty()
    }

    /// Number of distinct non-null values.
    pub fn non_null_len(&self) -> usize {
        let Self(inner) = self;
        inner.non_null_len()
    }

    /// Memory footprint in bytes, as reported by the underlying map.
    pub fn size(&self) -> usize {
        let Self(inner) = self;
        inner.size()
    }
}
/// Map from distinct byte/string values to a caller-defined payload `V`.
///
/// Distinct values are appended to `buffer` / `offsets` in the order they are
/// first seen, so the stored data can be turned directly into an Arrow array
/// by `into_state`. `O` is the offset type of the input and output arrays.
pub struct ArrowBytesMap<O, V>
where
O: OffsetSizeTrait,
V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
/// Which array type `into_state` builds.
output_type: OutputType,
/// Hash table of per-value headers (hash, length, location, payload).
map: hashbrown::hash_table::HashTable<Entry<O, V>>,
/// Running byte count maintained by `insert_accounted`; included in `size()`.
map_size: usize,
/// Bytes of every distinct non-null value, concatenated in insertion order.
buffer: BufferBuilder<u8>,
/// Starts with `O::default()`, then one entry per stored value (or null)
/// holding the length of `buffer` after that value was appended.
offsets: Vec<O>,
/// Hasher state passed to `create_hashes`.
random_state: RandomState,
/// Scratch space for the hashes of the batch currently being inserted;
/// cleared and resized on every insert.
hashes_buffer: Vec<u64>,
/// Set the first time a null is inserted: `(payload, index)` where `index`
/// is the position of the null among the stored values.
null: Option<(V, usize)>,
}
/// Capacity the hash table is created with in `ArrowBytesMap::new`.
const INITIAL_MAP_CAPACITY: usize = 128;
/// Capacity (in `u8` elements) the value buffer is created with in
/// `ArrowBytesMap::new`.
pub const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024;
impl<O: OffsetSizeTrait, V> ArrowBytesMap<O, V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Creates an empty map that will emit arrays of `output_type`.
    pub fn new(output_type: OutputType) -> Self {
        Self {
            output_type,
            map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
            map_size: 0,
            buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY),
            // An offsets buffer has one more entry than there are values;
            // seed it with the leading `O::default()`.
            offsets: vec![O::default()],
            random_state: RandomState::new(),
            hashes_buffer: vec![],
            null: None,
        }
    }

    /// Returns the current contents and leaves a freshly created map (with
    /// the same output type) behind in `self`.
    pub fn take(&mut self) -> Self {
        let mut new_self = Self::new(self.output_type);
        swap(self, &mut new_self);
        new_self
    }

    /// Inserts each value of `values` that is not yet present and reports
    /// the payload associated with every input row.
    ///
    /// * `make_payload_fn` is called once for each value seen for the first
    ///   time (`None` for null) and returns the payload stored for it.
    /// * `observe_payload_fn` is called once per input row, in order, with
    ///   the payload of that row's value (new or pre-existing).
    ///
    /// # Panics
    ///
    /// * if the data type of `values` does not match the map's output type,
    /// * if the map was created with a view output type,
    /// * if the accumulated value bytes no longer fit in offset type `O`.
    pub fn insert_if_new<MP, OP>(
        &mut self,
        values: &ArrayRef,
        make_payload_fn: MP,
        observe_payload_fn: OP,
    ) where
        MP: FnMut(Option<&[u8]>) -> V,
        OP: FnMut(V),
    {
        // NOTE: the `assert!` expressions below are intentionally left
        // untouched; their stringified form is the panic message and tests
        // match on it.
        match self.output_type {
            OutputType::Binary => {
                assert!(matches!(
                    values.data_type(),
                    DataType::Binary | DataType::LargeBinary
                ));
                self.insert_if_new_inner::<MP, OP, GenericBinaryType<O>>(
                    values,
                    make_payload_fn,
                    observe_payload_fn,
                )
            }
            OutputType::Utf8 => {
                assert!(matches!(
                    values.data_type(),
                    DataType::Utf8 | DataType::LargeUtf8
                ));
                self.insert_if_new_inner::<MP, OP, GenericStringType<O>>(
                    values,
                    make_payload_fn,
                    observe_payload_fn,
                )
            }
            OutputType::Utf8View | OutputType::BinaryView => {
                unreachable!("View types should use `ArrowBytesViewMap`")
            }
        }
    }

    /// Type-specialised body of [`Self::insert_if_new`]; `B` is the byte
    /// array type `values` is downcast to.
    fn insert_if_new_inner<MP, OP, B>(
        &mut self,
        values: &ArrayRef,
        mut make_payload_fn: MP,
        mut observe_payload_fn: OP,
    ) where
        MP: FnMut(Option<&[u8]>) -> V,
        OP: FnMut(V),
        B: ByteArrayType,
    {
        // Step 1: hash the whole batch into the reusable scratch buffer.
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(values.len(), 0);
        create_hashes([values], &self.random_state, batch_hashes)
            .unwrap();

        // Step 2: look up / insert row by row.
        let values = values.as_bytes::<B>();
        assert_eq!(values.len(), batch_hashes.len());

        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
            // Null is tracked outside the hash table, at most once.
            let Some(value) = value else {
                let payload = match self.null {
                    Some((payload, _null_index)) => payload,
                    None => {
                        let payload = make_payload_fn(None);
                        let null_index = self.offsets.len() - 1;
                        // The null occupies a zero-length slot in the output.
                        let offset = self.buffer.len();
                        self.offsets.push(O::usize_as(offset));
                        self.null = Some((payload, null_index));
                        payload
                    }
                };
                observe_payload_fn(payload);
                continue;
            };

            let value: &[u8] = value.as_ref();
            let value_len = O::usize_as(value.len());

            // Look the value up. `offset_or_inline` is what a new header
            // would store for it: the packed bytes for short values, the
            // start offset in `buffer` otherwise.
            let (existing, offset_or_inline) = if value.len() <= SHORT_VALUE_LEN {
                // Short value: pack the bytes into a usize so equality needs
                // no buffer access. The length is compared as well because
                // e.g. "" and "\0" pack to the same integer.
                let inline =
                    value.iter().fold(0usize, |acc, &x| (acc << 8) | x as usize);
                let found = self
                    .map
                    .find_mut(hash, |header| {
                        header.hash == hash
                            && header.len == value_len
                            && header.offset_or_inline == inline
                    })
                    .map(|entry| entry.payload);
                (found, inline)
            } else {
                let offset = self.buffer.len();
                let found = self
                    .map
                    .find_mut(hash, |header| {
                        // The length check is required for soundness, not
                        // just speed: it rejects short-value headers, whose
                        // `offset_or_inline` holds packed bytes rather than
                        // a buffer offset.
                        if header.hash != hash || header.len != value_len {
                            return false;
                        }
                        // SAFETY: `header.len == value_len > SHORT_VALUE_LEN`,
                        // so this header was created by the long-value path,
                        // which recorded `buffer.len()` as the offset and then
                        // appended exactly `len` bytes. The buffer is only
                        // ever appended to, so the range is in bounds.
                        let existing_value = unsafe {
                            self.buffer.as_slice().get_unchecked(header.range())
                        };
                        value == existing_value
                    })
                    .map(|entry| entry.payload);
                (found, offset)
            };

            let payload = match existing {
                Some(payload) => payload,
                None => {
                    // Short values are appended too, so that they appear in
                    // the output array.
                    self.buffer.append_slice(value);
                    self.offsets.push(O::usize_as(self.buffer.len()));
                    let payload = make_payload_fn(Some(value));
                    let new_header = Entry {
                        hash,
                        len: value_len,
                        offset_or_inline,
                        payload,
                    };
                    self.map.insert_accounted(
                        new_header,
                        |header| header.hash,
                        &mut self.map_size,
                    );
                    payload
                }
            };
            observe_payload_fn(payload);
        }

        // Checked once per batch rather than per value: the accumulated
        // bytes must still be addressable with offset type `O`.
        if O::from_usize(self.buffer.len()).is_none() {
            panic!(
                "Put {} bytes in buffer, more than can be represented by a {}",
                self.buffer.len(),
                type_name::<O>()
            );
        }
    }

    /// Consumes the map and returns the distinct values, in first-seen
    /// order, as a single array of the configured output type.
    ///
    /// # Panics
    ///
    /// Panics if the map was created with a view output type.
    pub fn into_state(self) -> ArrayRef {
        let Self {
            output_type,
            map: _,
            map_size: _,
            offsets,
            mut buffer,
            random_state: _,
            hashes_buffer: _,
            null,
        } = self;

        // At most one slot is null; build a validity mask only in that case.
        let nulls = null.map(|(_payload, null_index)| {
            let num_values = offsets.len() - 1;
            single_null_buffer(num_values, null_index)
        });

        // SAFETY: every pushed offset is the buffer length at that moment
        // and the buffer only grows, so offsets are non-decreasing;
        // `insert_if_new_inner` panics if the buffer length stops fitting
        // in `O`.
        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
        let values = buffer.finish();

        match output_type {
            OutputType::Binary => {
                // SAFETY: offsets are valid for `values` (see above).
                Arc::new(unsafe {
                    GenericBinaryArray::new_unchecked(offsets, values, nulls)
                })
            }
            OutputType::Utf8 => {
                // SAFETY: offsets are valid (see above), and every stored
                // value was copied whole from an array asserted to be
                // `Utf8`/`LargeUtf8` in `insert_if_new`.
                Arc::new(unsafe {
                    GenericStringArray::new_unchecked(offsets, values, nulls)
                })
            }
            OutputType::Utf8View | OutputType::BinaryView => {
                unreachable!("View types should use `ArrowBytesViewMap`")
            }
        }
    }

    /// Number of distinct values, counting the null (if seen) once.
    pub fn len(&self) -> usize {
        self.non_null_len() + usize::from(self.null.is_some())
    }

    /// `true` when neither a value nor a null has been inserted.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty() && self.null.is_none()
    }

    /// Number of distinct non-null values.
    pub fn non_null_len(&self) -> usize {
        self.map.len()
    }

    /// Memory footprint in bytes: accounted hash-table size plus the
    /// capacity of the value buffer, the offsets and the hash scratch buffer.
    pub fn size(&self) -> usize {
        self.map_size
            + self.buffer.capacity() * size_of::<u8>()
            + self.offsets.allocated_size()
            + self.hashes_buffer.allocated_size()
    }
}
/// Builds a validity buffer of `num_values` slots in which only the slot at
/// `null_index` is null.
fn single_null_buffer(num_values: usize, null_index: usize) -> NullBuffer {
    let mut validity = NullBufferBuilder::new(num_values);
    // valid slots before the null
    validity.append_n_non_nulls(null_index);
    validity.append_null();
    // valid slots after the null
    let trailing_valid = num_values - null_index - 1;
    validity.append_n_non_nulls(trailing_valid);
    // A null was appended above, so a buffer is expected here.
    validity.finish().unwrap()
}
impl<O: OffsetSizeTrait, V> Debug for ArrowBytesMap<O, V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Formats the map for diagnostics; the hash table contents are
    /// replaced by the placeholder `"<map>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ArrowBytesMap");
        out.field("map", &"<map>");
        out.field("map_size", &self.map_size);
        out.field("buffer", &self.buffer);
        out.field("random_state", &self.random_state);
        out.field("hashes_buffer", &self.hashes_buffer);
        out.finish()
    }
}
/// Values of at most this many bytes (the size of a `usize`) are packed into
/// `Entry::offset_or_inline` and compared without reading the value buffer.
const SHORT_VALUE_LEN: usize = size_of::<usize>();
/// Header stored in the hash table for one distinct non-null value.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct Entry<O, V>
where
O: OffsetSizeTrait,
V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
/// Hash of the value, as computed by `create_hashes`.
hash: u64,
/// For values of at most `SHORT_VALUE_LEN` bytes: the bytes packed into a
/// `usize` (first byte most significant). For longer values: the offset
/// in the value buffer at which the value starts.
offset_or_inline: usize,
/// Length of the value in bytes.
len: O,
/// Caller-defined payload associated with the value.
payload: V,
}
impl<O, V> Entry<O, V>
where
    O: OffsetSizeTrait,
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Byte range `offset_or_inline .. offset_or_inline + len`.
    ///
    /// Only meaningful when `offset_or_inline` holds a buffer offset, i.e.
    /// for values longer than `SHORT_VALUE_LEN`.
    #[inline(always)]
    fn range(&self) -> Range<usize> {
        let start = self.offset_or_inline;
        let end = start + self.len.as_usize();
        start..end
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{BinaryArray, LargeBinaryArray, StringArray};
    use std::collections::HashMap;

    /// Inserting an array with no rows leaves the set empty.
    #[test]
    fn string_set_empty() {
        let input: ArrayRef = Arc::new(StringArray::new_null(0));
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
        set.insert(&input);
        assert_eq!(set.len(), 0);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[]);
    }

    /// A single null is stored as one (null) entry.
    #[test]
    fn string_set_one_null() {
        let input: ArrayRef = Arc::new(StringArray::new_null(1));
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
        set.insert(&input);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    /// Repeated nulls collapse into one entry.
    #[test]
    fn string_set_many_null() {
        let input: ArrayRef = Arc::new(StringArray::new_null(11));
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
        set.insert(&input);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn string_set_basic_i32() {
        test_string_set_basic::<i32>();
    }

    #[test]
    fn string_set_basic_i64() {
        test_string_set_basic::<i64>();
    }

    /// Mix of short (<= 8 bytes), long, empty and null values with
    /// duplicates; the set must keep first-seen order.
    fn test_string_set_basic<O: OffsetSizeTrait>() {
        let input = GenericStringArray::<O>::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCC"), // 10 bytes
            Some(""),
            Some("cbcxx"), // 5 bytes
            None,
            Some("AAAAAAAA"),  // 8 bytes
            Some("BBBBBQBBB"), // 9 bytes
            Some("a"),
            Some("cbcxx"),
            Some("b"),
            Some("cbcxx"),
            Some(""),
            None,
            Some("BBBBBQBBB"),
            Some("BBBBBQBBB"),
            Some("AAAAAAAA"),
            Some("CXCCCCCCCC"),
        ]);
        let input: ArrayRef = Arc::new(input);

        let mut set = ArrowBytesSet::<O>::new(OutputType::Utf8);
        set.insert(&input);

        let distinct = [
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCC"),
            Some(""),
            Some("cbcxx"),
            None,
            Some("AAAAAAAA"),
            Some("BBBBBQBBB"),
        ];
        assert_set(set, &distinct);
    }

    #[test]
    fn string_set_non_utf8_32() {
        test_string_set_non_utf8::<i32>();
    }

    #[test]
    fn string_set_non_utf8_64() {
        test_string_set_non_utf8::<i64>();
    }

    /// Multi-byte (non-ASCII) strings are deduplicated as well.
    fn test_string_set_non_utf8<O: OffsetSizeTrait>() {
        let input = GenericStringArray::<O>::from(vec![
            Some("a"),
            Some("✨🔥"),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
            Some("🔥"),
            Some("✨🔥"),
        ]);
        let input: ArrayRef = Arc::new(input);

        let mut set = ArrowBytesSet::<O>::new(OutputType::Utf8);
        set.insert(&input);

        let distinct = [
            Some("a"),
            Some("✨🔥"),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
        ];
        assert_set(set, &distinct);
    }

    /// Asserts that the set's output array holds exactly `expected`, in order.
    fn assert_set<O: OffsetSizeTrait>(set: ArrowBytesSet<O>, expected: &[Option<&str>]) {
        let array = set.into_state();
        let strings = array.as_string::<O>();
        let state: Vec<_> = strings.iter().collect();
        assert_eq!(state, expected);
    }

    #[test]
    fn test_binary_set() {
        let values: ArrayRef = Arc::new(BinaryArray::from_opt_vec(vec![
            Some(b"a"),
            Some(b"CXCCCCCCCC"),
            None,
            Some(b"CXCCCCCCCC"),
        ]));
        let expected: ArrayRef = Arc::new(BinaryArray::from_opt_vec(vec![
            Some(b"a"),
            Some(b"CXCCCCCCCC"),
            None,
        ]));

        let mut set = ArrowBytesSet::<i32>::new(OutputType::Binary);
        set.insert(&values);
        let actual = set.into_state();
        assert_eq!(&actual, &expected);
    }

    #[test]
    fn test_large_binary_set() {
        let values: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![
            Some(b"a"),
            Some(b"CXCCCCCCCC"),
            None,
            Some(b"CXCCCCCCCC"),
        ]));
        let expected: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![
            Some(b"a"),
            Some(b"CXCCCCCCCC"),
            None,
        ]));

        let mut set = ArrowBytesSet::<i64>::new(OutputType::Binary);
        set.insert(&values);
        let actual = set.into_state();
        assert_eq!(&actual, &expected);
    }

    /// Binary input into a Utf8 set trips the data type assertion.
    #[test]
    #[should_panic(
        expected = "matches!(values.data_type(), DataType::Utf8 | DataType::LargeUtf8)"
    )]
    fn test_mismatched_types() {
        let values: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![Some(b"a")]));
        let mut set = ArrowBytesSet::<i64>::new(OutputType::Utf8);
        set.insert(&values);
    }

    /// 64-bit offsets into a set with 32-bit offsets must panic.
    #[test]
    #[should_panic]
    fn test_mismatched_sizes() {
        let values: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![Some(b"a")]));
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Binary);
        set.insert(&values);
    }

    /// Accumulating more bytes than an i32 offset can address panics.
    #[test]
    #[should_panic(
        expected = "Put 2147483648 bytes in buffer, more than can be represented by a i32"
    )]
    fn test_string_overflow() {
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
        for value in ["a", "b", "c"] {
            // each distinct value is 2^30 bytes long
            let arr: ArrayRef =
                Arc::new(StringArray::from_iter_values([value.repeat(1 << 30)]));
            set.insert(&arr);
        }
    }

    #[test]
    fn test_string_set_memory_usage() {
        let strings1 = GenericStringArray::<i32>::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCC"), // 10 bytes
            Some("AAAAAAAA"),   // 8 bytes
            Some("BBBBBQBBB"),  // 9 bytes
        ]);
        let total_strings1_len: usize =
            strings1.iter().map(|s| s.map_or(0, str::len)).sum();
        let values1: ArrayRef = Arc::new(strings1);

        // much larger strings
        let strings2 = GenericStringArray::<i32>::from(vec![
            "FOO".repeat(1000),
            "BAR".repeat(2000),
            "BAZ".repeat(3000),
        ]);
        let total_strings2_len: usize =
            strings2.iter().map(|s| s.map_or(0, str::len)).sum();
        let values2: ArrayRef = Arc::new(strings2);

        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
        let size_empty = set.size();

        set.insert(&values1);
        let size_after_values1 = set.size();
        assert!(size_empty < size_after_values1);
        assert!(
            size_after_values1 > total_strings1_len,
            "expect {size_after_values1} to be more than {total_strings1_len}"
        );
        assert!(size_after_values1 < total_strings1_len + total_strings2_len);

        // re-inserting the same values must not change the reported size
        set.insert(&values1);
        assert_eq!(set.size(), size_after_values1);

        // the large strings must be accounted for
        set.insert(&values2);
        let size_after_values2 = set.size();
        assert!(size_after_values2 > size_after_values1);
        assert!(size_after_values2 > total_strings1_len + total_strings2_len);
    }

    #[test]
    fn test_map() {
        let input = vec![
            // Note: mix of short and long values
            Some("A"),
            Some("bcdefghijklmnop"),
            Some("X"),
            Some("Y"),
            None,
            Some("qrstuvqxyzhjwya"),
            Some("✨🔥"),
            Some("🔥"),
            Some("🔥🔥🔥🔥🔥🔥"),
        ];

        let mut test_map = TestMap::new();
        test_map.insert(&input);
        // the second pass must only observe existing payloads
        test_map.insert(&input);

        let expected_output: ArrayRef = Arc::new(StringArray::from(input));
        let actual_output = test_map.into_array();
        assert_eq!(&actual_output, &expected_output);
    }

    /// Payload stored by `TestMap`: the index assigned to a value when it
    /// was first inserted.
    #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
    struct TestPayload {
        index: usize,
    }

    /// Wraps an `ArrowBytesMap` together with a plain `HashMap`-based model
    /// and checks on every insert that both agree.
    struct TestMap {
        map: ArrowBytesMap<i32, TestPayload>,
        // model: distinct values in insertion order
        strings: Vec<Option<String>>,
        // model: value -> position in `strings`
        indexes: HashMap<Option<String>, usize>,
    }

    impl Debug for TestMap {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let mut out = f.debug_struct("TestMap");
            out.field("map", &"...");
            out.field("strings", &self.strings);
            out.field("indexes", &self.indexes);
            out.finish()
        }
    }

    impl TestMap {
        fn new() -> Self {
            Self {
                map: ArrowBytesMap::new(OutputType::Utf8),
                strings: vec![],
                indexes: HashMap::new(),
            }
        }

        /// Inserts `strings` into both the model and the map, asserting that
        /// the map creates payloads for the same new values and reports the
        /// same index for every row as the model does.
        fn insert(&mut self, strings: &[Option<&str>]) {
            let arr: ArrayRef = Arc::new(StringArray::from(strings.to_vec()));

            // Index the map hands out next; taken before the model is updated.
            let mut next_index = self.indexes.len();

            // Drive the model.
            let mut actual_new_strings = vec![];
            let mut actual_seen_indexes = vec![];
            for value in strings {
                let value = value.map(|s| s.to_string());
                let index = match self.indexes.get(&value).copied() {
                    Some(index) => index,
                    None => {
                        actual_new_strings.push(value.clone());
                        let index = self.strings.len();
                        self.strings.push(value.clone());
                        self.indexes.insert(value, index);
                        index
                    }
                };
                actual_seen_indexes.push(index);
            }

            // Drive the map under test.
            let mut seen_new_strings = vec![];
            let mut seen_indexes = vec![];
            self.map.insert_if_new(
                &arr,
                |bytes| {
                    let value = bytes.map(|bytes| {
                        String::from_utf8(bytes.to_vec()).expect("Non utf8 string")
                    });
                    let index = next_index;
                    next_index += 1;
                    seen_new_strings.push(value);
                    TestPayload { index }
                },
                |payload| {
                    seen_indexes.push(payload.index);
                },
            );

            assert_eq!(actual_seen_indexes, seen_indexes);
            assert_eq!(actual_new_strings, seen_new_strings);
        }

        /// Consumes the map, checks its output against the model and
        /// returns the output array.
        fn into_array(self) -> ArrayRef {
            let Self {
                map,
                strings,
                indexes: _,
            } = self;

            let arr = map.into_state();
            let expected: ArrayRef = Arc::new(StringArray::from(strings));
            assert_eq!(&arr, &expected);
            arr
        }
    }
}