use ahash::RandomState;
use arrow::array::cast::AsArray;
use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
use std::fmt::Debug;
use std::sync::Arc;
use crate::binary_map::OutputType;
/// A set of distinct byte-view values, implemented as a thin wrapper around
/// [`ArrowBytesViewMap`] whose payload type is `()`.
#[derive(Debug)]
pub struct ArrowBytesViewSet(ArrowBytesViewMap<()>);

impl ArrowBytesViewSet {
    /// Creates an empty set for the given `output_type`.
    pub fn new(output_type: OutputType) -> Self {
        let inner = ArrowBytesViewMap::new(output_type);
        Self(inner)
    }

    /// Inserts every value of `values` that is not already present.
    pub fn insert(&mut self, values: &ArrayRef) {
        // A set stores no per-value data: the payload is the unit value and
        // there is nothing to do when a payload is observed.
        self.0.insert_if_new(values, |_value| (), |_payload| ());
    }

    /// Returns the current contents, leaving behind a fresh, empty set with
    /// the same output type.
    pub fn take(&mut self) -> Self {
        let fresh = Self::new(self.0.output_type);
        std::mem::replace(self, fresh)
    }

    /// Consumes the set and returns its contents as an array.
    pub fn into_state(self) -> ArrayRef {
        let Self(inner) = self;
        inner.into_state()
    }

    /// Number of distinct values in the set; a null, if present, counts once.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when the set holds neither a value nor a null.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Number of distinct non-null values in the set.
    pub fn non_null_len(&self) -> usize {
        self.0.non_null_len()
    }

    /// Memory accounted to the set, as reported by the wrapped map.
    pub fn size(&self) -> usize {
        self.0.size()
    }
}
/// Hash map from distinct byte-view values to a payload of type `V`.
///
/// Each distinct value is appended once to `builder`; hash table entries
/// refer to their value by its index in that builder.
pub struct ArrowBytesViewMap<V>
where
V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
/// Selects which input array type `insert_if_new` accepts and which array
/// type `into_state` produces.
output_type: OutputType,
/// Hash table holding one [`Entry`] per distinct non-null value.
map: hashbrown::hash_table::HashTable<Entry<V>>,
/// Size accounting for `map`, updated on insertion and included in `size()`.
map_size: usize,
/// Holds every distinct value (and the null slot, once a null is seen) in
/// the order it was first inserted.
builder: GenericByteViewBuilder<BinaryViewType>,
/// Hasher state used to compute the hashes of incoming values.
random_state: RandomState,
/// Scratch buffer for the hashes of the batch being inserted; cleared and
/// resized on every insertion call.
hashes_buffer: Vec<u64>,
/// Payload of the null value and its index in `builder`, or `None` if no
/// null has been inserted yet.
null: Option<(V, usize)>,
}
/// Capacity requested for the hash table when a new `ArrowBytesViewMap` is
/// created.
const INITIAL_MAP_CAPACITY: usize = 512;
impl<V> ArrowBytesViewMap<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Creates an empty map for the given `output_type`.
    pub fn new(output_type: OutputType) -> Self {
        let map = hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY);
        Self {
            output_type,
            map,
            map_size: 0,
            builder: GenericByteViewBuilder::new(),
            random_state: RandomState::new(),
            hashes_buffer: Vec::new(),
            null: None,
        }
    }

    /// Returns the current contents, leaving behind a fresh, empty map with
    /// the same output type.
    pub fn take(&mut self) -> Self {
        let fresh = Self::new(self.output_type);
        std::mem::replace(self, fresh)
    }

    /// Inserts each value of `values` that is not yet in the map.
    ///
    /// * `make_payload_fn` is called once for every value seen for the first
    ///   time (with `None` for null) and returns the payload stored for it.
    /// * `observe_payload_fn` is called once per input row, in row order, with
    ///   the payload associated with that row's value.
    ///
    /// # Panics
    ///
    /// Panics if the data type of `values` does not match the map's output
    /// type, or if the output type is neither `BinaryView` nor `Utf8View`.
    pub fn insert_if_new<MP, OP>(
        &mut self,
        values: &ArrayRef,
        make_payload_fn: MP,
        observe_payload_fn: OP,
    ) where
        MP: FnMut(Option<&[u8]>) -> V,
        OP: FnMut(V),
    {
        match self.output_type {
            OutputType::Utf8View => {
                assert!(matches!(values.data_type(), DataType::Utf8View));
                self.insert_if_new_inner::<MP, OP, StringViewType>(
                    values,
                    make_payload_fn,
                    observe_payload_fn,
                );
            }
            OutputType::BinaryView => {
                assert!(matches!(values.data_type(), DataType::BinaryView));
                self.insert_if_new_inner::<MP, OP, BinaryViewType>(
                    values,
                    make_payload_fn,
                    observe_payload_fn,
                );
            }
            _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
        }
    }

    /// Type-specialised worker behind [`Self::insert_if_new`]; `B` is the
    /// byte-view type that `values` is downcast to.
    fn insert_if_new_inner<MP, OP, B>(
        &mut self,
        values: &ArrayRef,
        mut make_payload_fn: MP,
        mut observe_payload_fn: OP,
    ) where
        MP: FnMut(Option<&[u8]>) -> V,
        OP: FnMut(V),
        B: ByteViewType,
    {
        // Hash the whole batch up front into the reusable scratch buffer.
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(values.len(), 0);
        create_hashes(&[Arc::clone(values)], &self.random_state, batch_hashes)
            .unwrap();

        let view_array = values.as_byte_view::<B>();
        assert_eq!(view_array.len(), batch_hashes.len());

        for (row, &hash) in view_array.iter().zip(batch_hashes.iter()) {
            let payload = match row {
                // Null: kept outside the hash table, in the dedicated `null` slot.
                None => match self.null {
                    Some((existing, _null_index)) => existing,
                    None => {
                        let created = make_payload_fn(None);
                        // The null occupies the next builder slot.
                        let null_index = self.builder.len();
                        self.builder.append_null();
                        self.null = Some((created, null_index));
                        created
                    }
                },
                Some(row) => {
                    let bytes: &[u8] = row.as_ref();

                    // Look for an entry whose stored bytes equal this value.
                    let builder = &self.builder;
                    let existing = self
                        .map
                        .find_mut(hash, |candidate| {
                            let stored = builder.get_value(candidate.view_idx);
                            stored.len() == bytes.len() && stored == bytes
                        })
                        .map(|entry| entry.payload);

                    match existing {
                        Some(found) => found,
                        None => {
                            let created = make_payload_fn(Some(bytes));
                            // Capture the index *before* appending: it is the
                            // slot the new value is about to occupy.
                            let view_idx = self.builder.len();
                            self.builder.append_value(bytes);
                            self.map.insert_accounted(
                                Entry {
                                    view_idx,
                                    hash,
                                    payload: created,
                                },
                                |entry| entry.hash,
                                &mut self.map_size,
                            );
                            created
                        }
                    }
                }
            };
            observe_payload_fn(payload);
        }
    }

    /// Consumes the map and returns the distinct values, in first-seen order,
    /// as an array of the configured output type.
    ///
    /// # Panics
    ///
    /// Panics if the output type is neither `BinaryView` nor `Utf8View`.
    pub fn into_state(self) -> ArrayRef {
        let Self {
            output_type,
            mut builder,
            ..
        } = self;

        match output_type {
            OutputType::Utf8View => {
                let binary = builder.finish();
                // SAFETY: with a `Utf8View` output type, `insert_if_new` only
                // accepts arrays whose data type is `Utf8View`, so every value
                // stored in the builder was copied from a string view array
                // (which is expected to contain valid UTF-8).
                let strings = unsafe { binary.to_string_view_unchecked() };
                Arc::new(strings)
            }
            OutputType::BinaryView => Arc::new(builder.finish()),
            _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"),
        }
    }

    /// Number of distinct values in the map; a null, if present, counts once.
    pub fn len(&self) -> usize {
        self.non_null_len() + usize::from(self.null.is_some())
    }

    /// Returns `true` when the map holds neither a value nor a null.
    pub fn is_empty(&self) -> bool {
        self.null.is_none() && self.map.is_empty()
    }

    /// Number of distinct non-null values in the map.
    pub fn non_null_len(&self) -> usize {
        self.map.len()
    }

    /// Memory accounted to this map: the tracked hash table size plus the
    /// allocated size of the value builder and of the hash scratch buffer.
    pub fn size(&self) -> usize {
        let table_size = self.map_size;
        let builder_size = self.builder.allocated_size();
        let hashes_size = self.hashes_buffer.allocated_size();
        table_size + builder_size + hashes_size
    }
}
/// Manual `Debug` implementation that prints a placeholder instead of the
/// hash table contents.
impl<V> Debug for ArrowBytesViewMap<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report the type under its real name; the previous label
        // ("ArrowBytesMap") named a different type.
        f.debug_struct("ArrowBytesViewMap")
            // The table itself is summarised by a placeholder.
            .field("map", &"<map>")
            .field("map_size", &self.map_size)
            .field("view_builder", &self.builder)
            .field("random_state", &self.random_state)
            .field("hashes_buffer", &self.hashes_buffer)
            // Null state is part of the map's contents, so show it as well.
            .field("null", &self.null)
            .finish()
    }
}
/// One hash table entry of an `ArrowBytesViewMap`, describing a single
/// distinct non-null value.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct Entry<V>
where
V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
/// Index of the value inside the map's builder.
view_idx: usize,
/// Hash of the value, stored so the table can re-hash entries without
/// reading the value bytes again.
hash: u64,
/// Payload associated with the value.
payload: V,
}
#[cfg(test)]
mod tests {
    use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray};
    use datafusion_common::HashMap;

    use super::*;

    /// Asserts that `set` contains exactly `expected`, in that order.
    fn assert_set(set: ArrowBytesViewSet, expected: &[Option<&str>]) {
        let state = set.into_state();
        let actual: Vec<Option<&str>> = state.as_string_view().iter().collect();
        assert_eq!(actual, expected);
    }

    /// Builds a `Utf8View` set and feeds it one all-null array of `rows` rows.
    fn set_from_nulls(rows: usize) -> ArrowBytesViewSet {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let nulls: ArrayRef = Arc::new(StringViewArray::new_null(rows));
        set.insert(&nulls);
        set
    }

    /// Builds a `Utf8View` set from a single string view array.
    fn set_from_strings(strings: StringViewArray) -> ArrowBytesViewSet {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(strings);
        set.insert(&array);
        set
    }

    /// Sum of the byte lengths of all non-null strings in `array`.
    fn total_len(array: &StringViewArray) -> usize {
        array.iter().flatten().map(str::len).sum()
    }

    #[test]
    fn string_view_set_empty() {
        let set = set_from_nulls(0);
        assert_eq!(set.len(), 0);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[]);
    }

    #[test]
    fn string_view_set_one_null() {
        let set = set_from_nulls(1);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn string_view_set_many_null() {
        // Many nulls collapse into a single null entry.
        let set = set_from_nulls(11);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn test_string_view_set_basic() {
        // Mix of short and long strings, the empty string, nulls and
        // duplicates.
        let input = StringViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCAABB"),
            Some(""),
            Some("cbcxx"),
            None,
            Some("AAAAAAAA"),
            Some("BBBBBQBBBAAA"),
            Some("a"),
            Some("cbcxx"),
            Some("b"),
            Some("cbcxx"),
            Some(""),
            None,
            Some("BBBBBQBBBAAA"),
            Some("BBBBBQBBBAAA"),
            Some("AAAAAAAA"),
            Some("CXCCCCCCCCAABB"),
        ]);

        // Each distinct value appears once, in first-seen order.
        let expected = [
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCAABB"),
            Some(""),
            Some("cbcxx"),
            None,
            Some("AAAAAAAA"),
            Some("BBBBBQBBBAAA"),
        ];
        assert_set(set_from_strings(input), &expected);
    }

    // Despite the name, every input here is valid UTF-8; the test covers
    // multi-byte (non-ASCII) strings.
    #[test]
    fn test_string_set_non_utf8() {
        let input = StringViewArray::from(vec![
            Some("a"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
            Some("🔥"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
        ]);

        let expected = [
            Some("a"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
        ];
        assert_set(set_from_strings(input), &expected);
    }

    #[test]
    fn test_binary_set() {
        let input: Vec<Option<&[u8]>> = vec![
            Some(b"a"),
            Some(b"CXCCCCCCCCCCCCC"),
            None,
            Some(b"CXCCCCCCCCCCCCC"),
        ];
        let values: ArrayRef = Arc::new(BinaryViewArray::from(input));

        let distinct: Vec<Option<&[u8]>> =
            vec![Some(b"a"), Some(b"CXCCCCCCCCCCCCC"), None];
        let expected: ArrayRef = Arc::new(GenericByteViewArray::from(distinct));

        let mut set = ArrowBytesViewSet::new(OutputType::BinaryView);
        set.insert(&values);
        assert_eq!(&set.into_state(), &expected);
    }

    #[test]
    fn test_string_set_memory_usage() {
        let strings1 = StringViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCCCC"),
            Some("AAAAAAAA"),
            Some("BBBBBQBBB"),
        ]);
        let total_strings1_len = total_len(&strings1);
        let values1: ArrayRef = Arc::new(strings1);

        // Much larger strings than in the first batch.
        let strings2 = StringViewArray::from(vec![
            "FOO".repeat(1000),
            "BAR larger than 12 bytes.".repeat(100_000),
            "more unique.".repeat(1000),
            "more unique2.".repeat(1000),
            "FOO".repeat(3000),
        ]);
        let total_strings2_len = total_len(&strings2);
        let values2: ArrayRef = Arc::new(strings2);

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let size_empty = set.size();

        set.insert(&values1);
        let size_after_values1 = set.size();
        assert!(size_empty < size_after_values1);
        assert!(
            size_after_values1 > total_strings1_len,
            "expect {size_after_values1} to be more than {total_strings1_len}"
        );
        assert!(size_after_values1 < total_strings1_len + total_strings2_len);

        // Re-inserting the same values must not grow the set.
        set.insert(&values1);
        assert_eq!(set.size(), size_after_values1);
        assert_eq!(set.len(), 5);

        // New values make it grow.
        set.insert(&values2);
        let size_after_values2 = set.size();
        assert!(size_after_values2 > size_after_values1);
        assert_eq!(set.len(), 10);
    }

    /// Payload used by the map tests: the order in which a value was first
    /// inserted.
    #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
    struct TestPayload {
        index: usize,
    }

    /// Pairs an `ArrowBytesViewMap` with a straightforward reference model
    /// (`strings` + `indexes`) so the two can be checked against each other.
    struct TestMap {
        map: ArrowBytesViewMap<TestPayload>,
        /// Distinct strings in first-seen order.
        strings: Vec<Option<String>>,
        /// Position of each distinct string within `strings`.
        indexes: HashMap<Option<String>, usize>,
    }

    impl TestMap {
        fn new() -> Self {
            Self {
                map: ArrowBytesViewMap::new(OutputType::Utf8View),
                strings: Vec::new(),
                indexes: HashMap::new(),
            }
        }

        /// Inserts `strings` into both the reference model and the map under
        /// test, then checks that they saw the same new values and indexes.
        fn insert(&mut self, strings: &[Option<&str>]) {
            let arr: ArrayRef = Arc::new(StringViewArray::from(strings.to_vec()));

            // Index the map will hand to the next new value; taken before the
            // reference model is updated for this batch.
            let mut next_index = self.indexes.len();

            // Reference model.
            let mut actual_new_strings = Vec::new();
            let mut actual_seen_indexes = Vec::new();
            for value in strings.iter().copied() {
                let key = value.map(str::to_owned);
                let index = match self.indexes.get(&key).copied() {
                    Some(known) => known,
                    None => {
                        let fresh = self.strings.len();
                        actual_new_strings.push(key.clone());
                        self.strings.push(key.clone());
                        self.indexes.insert(key, fresh);
                        fresh
                    }
                };
                actual_seen_indexes.push(index);
            }

            // Map under test.
            let mut seen_new_strings = Vec::new();
            let mut seen_indexes = Vec::new();
            let make_payload = |bytes: Option<&[u8]>| {
                let value = bytes.map(|b| {
                    String::from_utf8(b.to_vec()).expect("Non utf8 string")
                });
                seen_new_strings.push(value);
                let index = next_index;
                next_index += 1;
                TestPayload { index }
            };
            let observe_payload = |payload: TestPayload| seen_indexes.push(payload.index);
            self.map.insert_if_new(&arr, make_payload, observe_payload);

            assert_eq!(actual_seen_indexes, seen_indexes);
            assert_eq!(actual_new_strings, seen_new_strings);
        }

        /// Consumes the test map and returns the map's state after checking
        /// it against the reference model.
        fn into_array(self) -> ArrayRef {
            let Self { map, strings, .. } = self;

            let actual = map.into_state();
            let expected: ArrayRef = Arc::new(StringViewArray::from(strings));
            assert_eq!(&actual, &expected);
            actual
        }
    }

    #[test]
    fn test_map() {
        let input = vec![
            Some("A"),
            Some("bcdefghijklmnop1234567"),
            Some("X"),
            Some("Y"),
            None,
            Some("qrstuvqxyzhjwya"),
            Some("✨🔥"),
            Some("🔥"),
            Some("🔥🔥🔥🔥🔥🔥"),
        ];

        let mut test_map = TestMap::new();
        test_map.insert(&input);
        // Inserting the same batch again must not add anything new.
        test_map.insert(&input);

        let expected_output: ArrayRef = Arc::new(StringViewArray::from(input));
        assert_eq!(&test_map.into_array(), &expected_output);
    }
}