use crate::binary_map::OutputType;
use ahash::RandomState;
use arrow::array::NullBufferBuilder;
use arrow::array::cast::AsArray;
use arrow::array::{Array, ArrayRef, BinaryViewArray, ByteView, make_view};
use arrow::buffer::{Buffer, ScalarBuffer};
use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
use std::fmt::Debug;
use std::mem::size_of;
use std::sync::Arc;
/// A set of the distinct values seen in `StringViewArray`s or
/// `BinaryViewArray`s, including at most one null.
///
/// This is an [`ArrowBytesViewMap`] whose payload is `()`.
#[derive(Debug)]
pub struct ArrowBytesViewSet(ArrowBytesViewMap<()>);
impl ArrowBytesViewSet {
pub fn new(output_type: OutputType) -> Self {
Self(ArrowBytesViewMap::new(output_type))
}
pub fn insert(&mut self, values: &ArrayRef) {
fn make_payload_fn(_value: Option<&[u8]>) {}
fn observe_payload_fn(_payload: ()) {}
self.0
.insert_if_new(values, make_payload_fn, observe_payload_fn);
}
pub fn take(&mut self) -> Self {
let mut new_self = Self::new(self.0.output_type);
std::mem::swap(self, &mut new_self);
new_self
}
pub fn into_state(self) -> ArrayRef {
self.0.into_state()
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn non_null_len(&self) -> usize {
self.0.non_null_len()
}
pub fn size(&self) -> usize {
self.0.size()
}
}
/// Size limit, in bytes (2 MiB), used when deciding whether to start a new
/// data buffer for non-inlined values (values longer than 12 bytes).
const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 << 20;
/// Deduplicates the values of `StringViewArray`s / `BinaryViewArray`s,
/// associating each distinct value (and null, at most once) with a payload
/// of type `V`.
///
/// The distinct values are stored in first-seen order and can be turned back
/// into an array with [`Self::into_state`].
pub struct ArrowBytesViewMap<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Whether values are produced as `BinaryView` or `Utf8View`.
    output_type: OutputType,
    /// One entry per distinct non-null value: its view, hash and payload.
    map: hashbrown::hash_table::HashTable<Entry<V>>,
    /// Bytes accounted for `map` by `insert_accounted`.
    map_size: usize,
    /// Views of the distinct values in insertion order; the null slot, if
    /// any, holds a zero view.
    views: Vec<u128>,
    /// Data buffer currently receiving non-inlined (longer than 12 bytes) values.
    in_progress: Vec<u8>,
    /// Data buffers already flushed from `in_progress`.
    completed: Vec<Buffer>,
    /// Validity of each entry in `views`.
    nulls: NullBufferBuilder,
    /// Hash state used to hash incoming values.
    random_state: RandomState,
    /// Scratch space for the hashes of the current input batch, reused across calls.
    hashes_buffer: Vec<u64>,
    /// Payload of the null value and its index in `views`, once a null was seen.
    null: Option<(V, usize)>,
}
/// Number of entries the hash table is pre-sized for when a map is created.
const INITIAL_MAP_CAPACITY: usize = 1 << 9;
impl<V> ArrowBytesViewMap<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Creates an empty map producing arrays of `output_type`.
    ///
    /// Only `OutputType::BinaryView` and `OutputType::Utf8View` are supported;
    /// other output types panic in [`Self::insert_if_new`] / [`Self::into_state`].
    pub fn new(output_type: OutputType) -> Self {
        Self {
            output_type,
            map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
            map_size: 0,
            views: Vec::new(),
            in_progress: Vec::new(),
            completed: Vec::new(),
            nulls: NullBufferBuilder::new(0),
            random_state: RandomState::new(),
            hashes_buffer: vec![],
            null: None,
        }
    }

    /// Moves the current contents out, leaving an empty map with the same
    /// output type in their place.
    pub fn take(&mut self) -> Self {
        let mut new_self = Self::new(self.output_type);
        std::mem::swap(self, &mut new_self);
        new_self
    }

    /// Inserts every value of `values` that is not already present.
    ///
    /// For each row, in order:
    /// * if the value (or null) has not been seen, `make_payload_fn` is called
    ///   with it (`None` for null) and the returned payload is stored;
    /// * `observe_payload_fn` is then called with the payload associated with
    ///   the row's value, whether it was just created or already present.
    ///
    /// # Panics
    /// Panics if the data type of `values` does not match the map's output
    /// type, or if the output type is not a view type.
    pub fn insert_if_new<MP, OP>(
        &mut self,
        values: &ArrayRef,
        make_payload_fn: MP,
        observe_payload_fn: OP,
    ) where
        MP: FnMut(Option<&[u8]>) -> V,
        OP: FnMut(V),
    {
        match self.output_type {
            OutputType::BinaryView => {
                assert!(matches!(values.data_type(), DataType::BinaryView));
                self.insert_if_new_inner::<MP, OP, BinaryViewType>(
                    values,
                    make_payload_fn,
                    observe_payload_fn,
                )
            }
            OutputType::Utf8View => {
                assert!(matches!(values.data_type(), DataType::Utf8View));
                self.insert_if_new_inner::<MP, OP, StringViewType>(
                    values,
                    make_payload_fn,
                    observe_payload_fn,
                )
            }
            _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
        };
    }

    /// Implementation of [`Self::insert_if_new`] for the view type `B`.
    fn insert_if_new_inner<MP, OP, B>(
        &mut self,
        values: &ArrayRef,
        mut make_payload_fn: MP,
        mut observe_payload_fn: OP,
    ) where
        MP: FnMut(Option<&[u8]>) -> V,
        OP: FnMut(V),
        B: ByteViewType,
    {
        // Hash the whole batch up front, reusing the scratch buffer.
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(values.len(), 0);
        // NOTE(review): hashing is not expected to fail for view arrays; any
        // error is turned into a panic here.
        create_hashes([values], &self.random_state, batch_hashes)
            .unwrap();

        let values = values.as_byte_view::<B>();
        let input_views = values.views();
        assert_eq!(values.len(), self.hashes_buffer.len());

        for i in 0..values.len() {
            let view_u128 = input_views[i];
            let hash = self.hashes_buffer[i];

            // All nulls share one slot, tracked outside the hash table.
            if values.is_null(i) {
                let payload = if let Some(&(payload, _offset)) = self.null.as_ref() {
                    payload
                } else {
                    let payload = make_payload_fn(None);
                    let null_index = self.views.len();
                    self.views.push(0);
                    self.nulls.append_null();
                    self.null = Some((payload, null_index));
                    payload
                };
                observe_payload_fn(payload);
                continue;
            }

            // The low 32 bits of a view hold the value's length.
            let len = view_u128 as u32;

            let maybe_payload = {
                let completed = &self.completed;
                let in_progress = &self.in_progress;
                self.map
                    .find(hash, |entry| {
                        if entry.hash != hash {
                            return false;
                        }
                        // Compare lengths first. Besides being a cheap filter,
                        // this guarantees both views use the same layout:
                        // without it, a 64-bit hash collision between an inline
                        // (<= 12 byte) value and a longer value sharing its
                        // 4-byte prefix would decode the inline bytes as a
                        // buffer index/offset and could panic out of bounds.
                        if entry.view as u32 != len {
                            return false;
                        }
                        // Inline values live entirely in the view. This relies
                        // on unused inline bytes being zero, as Arrow requires
                        // and `make_view` produces.
                        if len <= 12 {
                            return entry.view == view_u128;
                        }
                        // Non-inline: bits 32..64 hold the 4-byte prefix.
                        let stored_prefix = (entry.view >> 32) as u32;
                        let input_prefix = (view_u128 >> 32) as u32;
                        if stored_prefix != input_prefix {
                            return false;
                        }
                        // Prefix matched: compare the full bytes. A buffer index
                        // past `completed` refers to the in-progress buffer.
                        let byte_view = ByteView::from(entry.view);
                        let start = byte_view.offset as usize;
                        let end = start + byte_view.length as usize;
                        let stored_value = match completed.get(byte_view.buffer_index as usize)
                        {
                            Some(buffer) => &buffer.as_slice()[start..end],
                            None => &in_progress[start..end],
                        };
                        let input_value: &[u8] = values.value(i).as_ref();
                        stored_value == input_value
                    })
                    .map(|entry| entry.payload)
            };

            let payload = match maybe_payload {
                Some(payload) => payload,
                None => {
                    let value: &[u8] = values.value(i).as_ref();
                    let payload = make_payload_fn(Some(value));
                    let new_view = self.append_value(value);
                    let new_entry = Entry {
                        view: new_view,
                        hash,
                        payload,
                    };
                    self.map
                        .insert_accounted(new_entry, |e| e.hash, &mut self.map_size);
                    payload
                }
            };
            observe_payload_fn(payload);
        }
    }

    /// Converts the map into an array of its distinct values in first-seen
    /// order. If a null was inserted, it appears once at that position.
    ///
    /// # Panics
    /// Panics if the output type is not `BinaryView` or `Utf8View`.
    pub fn into_state(mut self) -> ArrayRef {
        // Flush the partially filled buffer so every view refers to `completed`.
        if !self.in_progress.is_empty() {
            let flushed = std::mem::take(&mut self.in_progress);
            self.completed.push(Buffer::from_vec(flushed));
        }
        let null_buffer = self.nulls.finish();
        let views = ScalarBuffer::from(self.views);
        // SAFETY: every non-null view was built by `make_view` in `append_value`,
        // either inlining its bytes or pointing at the bytes just copied into the
        // buffer that now sits at that index in `completed` (flushed above). The
        // only other view is the zero view of the null slot, which is masked
        // by `null_buffer`.
        let array =
            unsafe { BinaryViewArray::new_unchecked(views, self.completed, null_buffer) };
        match self.output_type {
            OutputType::BinaryView => Arc::new(array),
            OutputType::Utf8View => {
                // SAFETY: in Utf8View mode `insert_if_new` only accepts Utf8View
                // input, whose values are valid UTF-8, and the stored bytes are
                // exact copies of those values.
                let array = unsafe { array.to_string_view_unchecked() };
                Arc::new(array)
            }
            _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"),
        }
    }

    /// Stores `value` (known to be new) and returns its view.
    ///
    /// Values of at most 12 bytes are inlined in the view. Longer values are
    /// copied into `in_progress`. If `in_progress` already holds data and the
    /// value would push it past `BYTE_VIEW_MAX_BLOCK_SIZE`, it is first
    /// flushed to `completed`.
    fn append_value(&mut self, value: &[u8]) -> u128 {
        let len = value.len();
        let view = if len <= 12 {
            make_view(value, 0, 0)
        } else {
            // Never flush an empty buffer. When a single value exceeds the
            // block size, flushing would only add an empty buffer to the output.
            if !self.in_progress.is_empty()
                && self.in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE
            {
                let flushed = std::mem::replace(
                    &mut self.in_progress,
                    Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE),
                );
                self.completed.push(Buffer::from_vec(flushed));
            }
            // `in_progress` becomes `completed[completed.len()]` once flushed.
            let buffer_index = self.completed.len() as u32;
            let offset = self.in_progress.len() as u32;
            self.in_progress.extend_from_slice(value);
            make_view(value, buffer_index, offset)
        };
        self.views.push(view);
        self.nulls.append_non_null();
        view
    }

    /// Number of distinct values, counting null as one value if present.
    pub fn len(&self) -> usize {
        self.non_null_len() + usize::from(self.null.is_some())
    }

    /// Returns `true` if no value (not even null) has been inserted.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty() && self.null.is_none()
    }

    /// Number of distinct non-null values.
    pub fn non_null_len(&self) -> usize {
        self.map.len()
    }

    /// Approximate number of bytes allocated by this map.
    pub fn size(&self) -> usize {
        // Count `views` by capacity, like `in_progress` and `hashes_buffer`,
        // so spare allocated capacity is included.
        let views_size = self.views.capacity() * size_of::<u128>();
        let in_progress_size = self.in_progress.capacity();
        let completed_size: usize = self.completed.iter().map(|b| b.len()).sum();
        let nulls_size = self.nulls.allocated_size();
        self.map_size
            + views_size
            + in_progress_size
            + completed_size
            + nulls_size
            + self.hashes_buffer.allocated_size()
    }
}
impl<V> Debug for ArrowBytesViewMap<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// Prints a summary of the map. The hash table itself is shown as
    /// `"<map>"` rather than listing one entry per distinct value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report this type's real name (previously "ArrowBytesMap", the name
        // of the non-view map).
        f.debug_struct("ArrowBytesViewMap")
            .field("map", &"<map>")
            .field("map_size", &self.map_size)
            .field("views_len", &self.views.len())
            .field("completed_buffers", &self.completed.len())
            .field("null", &self.null)
            .field("random_state", &self.random_state)
            .field("hashes_buffer", &self.hashes_buffer)
            .finish()
    }
}
/// A hash table entry describing one distinct non-null value.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct Entry<V>
where
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
{
    /// View of the value, as returned by `append_value`. Values of at most 12
    /// bytes are inlined; longer ones reference the map's data buffers.
    view: u128,
    /// Hash of the value. It is checked first during lookups and is reused
    /// when the table re-hashes its entries.
    hash: u64,
    /// Payload associated with the value.
    payload: V,
}
#[cfg(test)]
mod tests {
    use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray};
    use datafusion_common::HashMap;

    use super::*;

    /// Asserts that `set` holds exactly `expected`, in insertion order.
    fn assert_set(set: ArrowBytesViewSet, expected: &[Option<&str>]) {
        let strings = set.into_state();
        let strings = strings.as_string_view();
        let state = strings.into_iter().collect::<Vec<_>>();
        assert_eq!(state, expected);
    }

    #[test]
    fn string_view_set_empty() {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(0));
        set.insert(&array);
        assert_eq!(set.len(), 0);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[]);
    }

    #[test]
    fn string_view_set_one_null() {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(1));
        set.insert(&array);
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn string_view_set_many_null() {
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(StringViewArray::new_null(11));
        set.insert(&array);
        // all nulls collapse into a single entry
        assert_eq!(set.len(), 1);
        assert_eq!(set.non_null_len(), 0);
        assert_set(set, &[None]);
    }

    #[test]
    fn test_string_view_set_basic() {
        // mix of inlined (<= 12 bytes) and non-inlined values, with duplicates
        let values = GenericByteViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCAABB"), // 14 bytes
            Some(""),
            Some("cbcxx"), // 5 bytes
            None,
            Some("AAAAAAAA"),     // 8 bytes
            Some("BBBBBQBBBAAA"), // 12 bytes
            Some("a"),
            Some("cbcxx"),
            Some("b"),
            Some("cbcxx"),
            Some(""),
            None,
            Some("BBBBBQBBBAAA"),
            Some("BBBBBQBBBAAA"),
            Some("AAAAAAAA"),
            Some("CXCCCCCCCCAABB"),
        ]);
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        // values come out in the order they were first seen
        assert_set(
            set,
            &[
                Some("a"),
                Some("b"),
                Some("CXCCCCCCCCAABB"),
                Some(""),
                Some("cbcxx"),
                None,
                Some("AAAAAAAA"),
                Some("BBBBBQBBBAAA"),
            ],
        );
    }

    #[test]
    fn test_string_set_non_utf8() {
        // despite the name, all values are valid UTF-8: this exercises
        // multi-byte characters in both inlined and non-inlined values
        let values = GenericByteViewArray::from(vec![
            Some("a"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
            Some("🔥"),
            Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
        ]);
        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        assert_set(
            set,
            &[
                Some("a"),
                Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"),
                Some("🔥"),
                Some("✨✨✨"),
                Some("foobarbaz"),
            ],
        );
    }

    #[test]
    fn test_binary_set() {
        let v: Vec<Option<&[u8]>> = vec![
            Some(b"a"),
            Some(b"CXCCCCCCCCCCCCC"),
            None,
            Some(b"CXCCCCCCCCCCCCC"),
        ];
        let values: ArrayRef = Arc::new(BinaryViewArray::from(v));
        let expected: Vec<Option<&[u8]>> =
            vec![Some(b"a"), Some(b"CXCCCCCCCCCCCCC"), None];
        let expected: ArrayRef = Arc::new(GenericByteViewArray::from(expected));
        let mut set = ArrowBytesViewSet::new(OutputType::BinaryView);
        set.insert(&values);
        assert_eq!(&set.into_state(), &expected);
    }

    #[test]
    fn test_binary_view_set_shared_prefix() {
        // every value shares the same 4-byte prefix, mixing inlined and
        // non-inlined values, so lookups must also compare lengths and bytes
        let v: Vec<Option<&[u8]>> = vec![
            Some(b"abcd"),
            Some(b"abcdefghijklmnop"),
            Some(b"abcdefghijklmnoq"),
            Some(b"abcdefghijkl"),
            Some(b"abcd"),
            Some(b"abcdefghijklmnop"),
            Some(b"abcdefghijklmnoq"),
            Some(b"abcdefghijkl"),
        ];
        let values: ArrayRef = Arc::new(BinaryViewArray::from(v));
        let expected: Vec<Option<&[u8]>> = vec![
            Some(b"abcd"),
            Some(b"abcdefghijklmnop"),
            Some(b"abcdefghijklmnoq"),
            Some(b"abcdefghijkl"),
        ];
        let expected: ArrayRef = Arc::new(BinaryViewArray::from(expected));
        let mut set = ArrowBytesViewSet::new(OutputType::BinaryView);
        set.insert(&values);
        assert_eq!(set.len(), 4);
        assert_eq!(&set.into_state(), &expected);
    }

    #[test]
    fn test_string_view_set_many_buffers() {
        // enough distinct non-inlined bytes to need more than one
        // `BYTE_VIEW_MAX_BLOCK_SIZE` data buffer
        let value_len = 100_000;
        let num_values = BYTE_VIEW_MAX_BLOCK_SIZE / value_len + 5;
        let strings: Vec<String> = (0..num_values)
            .map(|i| format!("{i:06}{}", "x".repeat(value_len)))
            .collect();
        let values: ArrayRef = Arc::new(StringViewArray::from(strings.clone()));

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        set.insert(&values);
        // inserting the same values again must not add anything
        set.insert(&values);
        assert_eq!(set.len(), num_values);

        let expected: ArrayRef = Arc::new(StringViewArray::from(strings));
        assert_eq!(&set.into_state(), &expected);
    }

    #[test]
    fn test_string_set_memory_usage() {
        let strings1 = StringViewArray::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCCCCC"), // longer than 12 bytes: not inlined
            Some("AAAAAAAA"),      // 8 bytes
            Some("BBBBBQBBB"),     // 9 bytes
        ]);
        let total_strings1_len = strings1
            .iter()
            .map(|s| s.map(|s| s.len()).unwrap_or(0))
            .sum::<usize>();
        let values1: ArrayRef = Arc::new(StringViewArray::from(strings1));

        // much larger strings in strings2
        let strings2 = StringViewArray::from(vec![
            "FOO".repeat(1000),
            "BAR larger than 12 bytes.".repeat(100_000),
            "more unique.".repeat(1000),
            "more unique2.".repeat(1000),
            "FOO".repeat(3000),
        ]);
        let total_strings2_len = strings2
            .iter()
            .map(|s| s.map(|s| s.len()).unwrap_or(0))
            .sum::<usize>();
        let values2: ArrayRef = Arc::new(StringViewArray::from(strings2));

        let mut set = ArrowBytesViewSet::new(OutputType::Utf8View);
        let size_empty = set.size();

        set.insert(&values1);
        let size_after_values1 = set.size();
        assert!(size_empty < size_after_values1);
        assert!(
            size_after_values1 > total_strings1_len,
            "expect {size_after_values1} to be more than {total_strings1_len}"
        );
        assert!(size_after_values1 < total_strings1_len + total_strings2_len);

        // inserting the same strings again must not grow the set
        set.insert(&values1);
        assert_eq!(set.size(), size_after_values1);
        assert_eq!(set.len(), 5);

        set.insert(&values2);
        let size_after_values2 = set.size();
        assert!(size_after_values2 > size_after_values1);
        assert_eq!(set.len(), 10);
    }

    #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
    struct TestPayload {
        // index in the string array
        index: usize,
    }

    /// Wraps an `ArrowBytesViewMap` and checks its behavior against a
    /// reference `HashMap` implementation.
    struct TestMap {
        map: ArrowBytesViewMap<TestPayload>,
        // strings in insertion order
        strings: Vec<Option<String>>,
        // map of string to index in `strings`
        indexes: HashMap<Option<String>, usize>,
    }

    impl TestMap {
        fn new() -> Self {
            Self {
                map: ArrowBytesViewMap::new(OutputType::Utf8View),
                strings: vec![],
                indexes: HashMap::new(),
            }
        }

        /// Inserts `strings` into both maps and checks that the payloads
        /// created and observed by the map match the reference.
        fn insert(&mut self, strings: &[Option<&str>]) {
            let string_array = StringViewArray::from(strings.to_vec());
            let arr: ArrayRef = Arc::new(string_array);

            let mut next_index = self.indexes.len();
            let mut actual_new_strings = vec![];
            let mut actual_seen_indexes = vec![];
            // update the reference map
            for str in strings {
                let str = str.map(|s| s.to_string());
                let index = self.indexes.get(&str).cloned().unwrap_or_else(|| {
                    actual_new_strings.push(str.clone());
                    let index = self.strings.len();
                    self.strings.push(str.clone());
                    self.indexes.insert(str, index);
                    index
                });
                actual_seen_indexes.push(index);
            }

            // update the map under test
            let mut seen_new_strings = vec![];
            let mut seen_indexes = vec![];
            self.map.insert_if_new(
                &arr,
                |s| {
                    let value = s
                        .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string"));
                    let index = next_index;
                    next_index += 1;
                    seen_new_strings.push(value);
                    TestPayload { index }
                },
                |payload| {
                    seen_indexes.push(payload.index);
                },
            );

            assert_eq!(actual_seen_indexes, seen_indexes);
            assert_eq!(actual_new_strings, seen_new_strings);
        }

        /// Calls `into_state` on the map and checks it against the reference.
        fn into_array(self) -> ArrayRef {
            let Self {
                map,
                strings,
                indexes: _,
            } = self;

            let arr = map.into_state();
            let expected: ArrayRef = Arc::new(StringViewArray::from(strings));
            assert_eq!(&arr, &expected);
            arr
        }
    }

    #[test]
    fn test_map() {
        let input = vec![
            // Note: mix of inlined and non-inlined, unicode, and nulls
            Some("A"),
            Some("bcdefghijklmnop1234567"),
            Some("X"),
            Some("Y"),
            None,
            Some("qrstuvqxyzhjwya"),
            Some("✨🔥"),
            Some("🔥"),
            Some("🔥🔥🔥🔥🔥🔥"),
        ];

        let mut test_map = TestMap::new();
        test_map.insert(&input);
        // inserting the same values again must reuse the existing payloads
        test_map.insert(&input);

        let expected_output: ArrayRef = Arc::new(StringViewArray::from(input));
        assert_eq!(&test_map.into_array(), &expected_output);
    }
}