use ahash::RandomState;
use arrow_array::cast::AsArray;
use arrow_array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer};
use datafusion_common::cast::as_list_array;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::ScalarValue;
use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
use datafusion_expr::Accumulator;
use std::fmt::Debug;
use std::mem;
use std::ops::Range;
use std::sync::Arc;
/// `COUNT(DISTINCT ...)` accumulator for string columns whose Arrow offset
/// type is `O` (`i32` or `i64`).
///
/// Distinct non-null values are tracked in an [`SSOStringHashSet`].
#[derive(Debug)]
pub(super) struct StringDistinctCountAccumulator<O: OffsetSizeTrait>(SSOStringHashSet<O>);

impl<O: OffsetSizeTrait> StringDistinctCountAccumulator<O> {
    /// Creates an accumulator that has not seen any values yet.
    pub(super) fn new() -> Self {
        let distinct_values = SSOStringHashSet::<O>::new();
        Self(distinct_values)
    }
}
impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
let set = std::mem::take(&mut self.0);
let arr = set.into_state();
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
if values.is_empty() {
return Ok(());
}
self.0.insert(&values[0]);
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);
let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
self.0.insert(&list);
};
Ok(())
})
}
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
}
fn size(&self) -> usize {
std::mem::size_of_val(self) + self.0.size()
}
}
/// Longest string, in bytes, that is compared through its bytes packed inline
/// into a header's `usize` rather than through the value buffer. This is
/// however many bytes fit in a `usize` on the target.
const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
/// Hash-table entry describing one distinct string.
///
/// For strings of at most `SHORT_STRING_LEN` bytes, `offset_or_inline` holds
/// the string's bytes packed into a `usize`. For longer strings it holds the
/// byte offset at which the string starts in the set's value buffer.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct SSOStringHeader {
    /// Hash of the string's value.
    hash: u64,
    /// Packed bytes (short strings) or start offset into the buffer (long strings).
    offset_or_inline: usize,
    /// Length of the string in bytes.
    len: usize,
}

impl SSOStringHeader {
    /// Byte range `offset_or_inline .. offset_or_inline + len`.
    ///
    /// Only meaningful for long strings, where `offset_or_inline` is an offset
    /// rather than packed bytes.
    fn range(&self) -> Range<usize> {
        let start = self.offset_or_inline;
        let end = start + self.len;
        start..end
    }
}
/// Set of distinct strings using a "small string optimization" (SSO).
///
/// Every distinct value is appended to `buffer`/`offsets`, so the set can be
/// turned directly into a string array. The hash table itself stores only
/// [`SSOStringHeader`]s:
/// - short strings are compared via bytes packed inline in the header;
/// - longer strings are compared by reading their bytes back out of `buffer`.
struct SSOStringHashSet<O> {
    /// Seed shared by every batch, so hashes are comparable across batches.
    random_state: RandomState,
    /// Scratch space for per-batch hashes, reused between calls.
    hashes_buffer: Vec<u64>,
    /// One header per distinct value, found by the value's hash.
    map: hashbrown::raw::RawTable<SSOStringHeader>,
    /// Bytes allocated by `map`, as tracked by `insert_accounted`.
    map_size: usize,
    /// Concatenated bytes of every distinct value, in insertion order.
    buffer: BufferBuilder<u8>,
    /// Offsets into `buffer`; starts with a single zero entry.
    offsets: Vec<O>,
}
impl<O: OffsetSizeTrait> Default for SSOStringHashSet<O> {
fn default() -> Self {
Self::new()
}
}
impl<O: OffsetSizeTrait> SSOStringHashSet<O> {
    /// Creates an empty set.
    ///
    /// `offsets` starts with a single zero entry so that, together with
    /// `buffer`, it always forms a valid offsets/values pair.
    fn new() -> Self {
        Self {
            map: hashbrown::raw::RawTable::new(),
            map_size: 0,
            buffer: BufferBuilder::new(0),
            offsets: vec![O::default()],
            random_state: RandomState::new(),
            hashes_buffer: vec![],
        }
    }

    /// Inserts every non-null string in `values` that is not already present.
    ///
    /// Strings of at most [`SHORT_STRING_LEN`] bytes are compared through their
    /// bytes packed into the header's `usize`. Longer strings are compared
    /// against the bytes previously appended to `buffer`.
    ///
    /// # Panics
    ///
    /// Panics in two cases:
    /// - `values` is not a string array with offset type `O`;
    /// - the total size of the distinct values no longer fits in `O`.
    fn insert(&mut self, values: &ArrayRef) {
        // Step 1: hash the whole batch at once, reusing the scratch buffer.
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(values.len(), 0);
        create_hashes(&[values.clone()], &self.random_state, batch_hashes)
            // create_hashes only fails for unsupported types; strings are supported.
            .unwrap();

        // Step 2: insert each value that is not already in the set.
        let values = values.as_string::<O>();
        // Every value must pair with exactly one hash.
        assert_eq!(values.len(), batch_hashes.len());

        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
            // COUNT DISTINCT ignores nulls.
            let Some(value) = value else {
                continue;
            };
            // From here on work with raw bytes, not chars.
            let value = value.as_bytes();
            let len = value.len();

            // Short strings get an inline form: their bytes packed into a usize.
            // Long strings have none and are compared via `buffer`.
            let inline = (len <= SHORT_STRING_LEN)
                .then(|| value.iter().fold(0usize, |acc, &b| acc << 8 | b as usize));

            let buffer = &self.buffer;
            let already_present = self
                .map
                .get(hash, |header| {
                    // Whether a string is inline depends only on its length, so
                    // equal lengths mean both sides use the same representation.
                    if header.len != len {
                        return false;
                    }
                    match inline {
                        Some(inline) => header.offset_or_inline == inline,
                        None => {
                            // SAFETY: `header` describes a long string, so
                            // `header.range()` is exactly where its bytes were
                            // appended. `buffer` is only ever appended to, so
                            // that range is still in bounds.
                            let existing =
                                unsafe { buffer.as_slice().get_unchecked(header.range()) };
                            existing == value
                        }
                    }
                })
                .is_some();
            if already_present {
                continue;
            }

            // New value: append its bytes so it appears in the output array.
            let offset = self.buffer.len();
            self.buffer.append_slice(value);
            let end = O::from_usize(self.buffer.len()).expect(
                "total size of distinct strings exceeds the maximum offset of the string type",
            );
            self.offsets.push(end);

            let new_header = SSOStringHeader {
                hash,
                len,
                // Short strings are compared by their packed bytes; long strings
                // remember where their bytes start in `buffer`.
                offset_or_inline: inline.unwrap_or(offset),
            };
            self.map
                .insert_accounted(new_header, |header| header.hash, &mut self.map_size);
        }
    }

    /// Consumes the set, returning its distinct values as a string array with
    /// no nulls, in insertion order.
    fn into_state(self) -> ArrayRef {
        let Self {
            offsets,
            mut buffer,
            ..
        } = self;
        let offsets: ScalarBuffer<O> = offsets.into();
        let values = buffer.finish();
        // SAFETY:
        // - `offsets` starts at 0, and each later entry is the buffer length
        //   after appending one more value, so the offsets are non-decreasing
        //   and within `values`.
        // - Every appended slice came from a `&str`, so each value is valid UTF-8.
        let array = unsafe {
            GenericStringArray::new_unchecked(OffsetBuffer::new(offsets), values, None)
        };
        Arc::new(array)
    }

    /// Number of distinct values inserted so far.
    fn len(&self) -> usize {
        self.map.len()
    }

    /// Bytes held by this set (excluding the struct itself).
    ///
    /// This is the hash table's accounted size plus the capacities of the
    /// value buffer, the offsets and the hash scratch buffer.
    fn size(&self) -> usize {
        self.map_size
            + self.buffer.capacity() * std::mem::size_of::<u8>()
            + self.offsets.capacity() * std::mem::size_of::<O>()
            + self.hashes_buffer.capacity() * std::mem::size_of::<u64>()
    }
}
impl<O: OffsetSizeTrait> Debug for SSOStringHashSet<O> {
    /// Formats the set for diagnostics.
    ///
    /// The hash table is shown only as a placeholder. `offsets` is included
    /// because the bytes in `buffer` cannot be split into values without it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SSOStringHashSet")
            .field("map", &"<map>")
            .field("map_size", &self.map_size)
            .field("buffer", &self.buffer)
            .field("offsets", &self.offsets)
            .field("random_state", &self.random_state)
            .field("hashes_buffer", &self.hashes_buffer)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ArrayRef;
    use arrow_array::StringArray;

    #[test]
    fn string_set_empty() {
        for values in [StringArray::new_null(0), StringArray::new_null(11)] {
            let mut set = SSOStringHashSet::<i32>::new();
            let array: ArrayRef = Arc::new(values);
            set.insert(&array);
            assert_set(set, &[]);
        }
    }

    #[test]
    fn string_set_basic_i32() {
        test_string_set_basic::<i32>();
    }

    #[test]
    fn string_set_basic_i64() {
        test_string_set_basic::<i64>();
    }

    fn test_string_set_basic<O: OffsetSizeTrait>() {
        let values = GenericStringArray::<O>::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCC"), // 10 bytes: compared via the buffer
            Some(""),
            Some("cbcxx"), // 5 bytes: compared inline
            None,
            Some("AAAAAAAA"),  // 8 bytes: inline on 64-bit targets
            Some("BBBBBQBBB"), // 9 bytes: compared via the buffer
            Some("a"),
            Some("cbcxx"),
            Some("b"),
            Some("cbcxx"),
            Some(""),
            None,
            Some("BBBBBQBBB"),
            Some("BBBBBQBBB"),
            Some("AAAAAAAA"),
            Some("CXCCCCCCCC"),
        ]);
        let mut set = SSOStringHashSet::<O>::new();
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        assert_set(
            set,
            &[
                Some(""),
                Some("AAAAAAAA"),
                Some("BBBBBQBBB"),
                Some("CXCCCCCCCC"),
                Some("a"),
                Some("b"),
                Some("cbcxx"),
            ],
        );
    }

    #[test]
    fn string_set_non_utf8_32() {
        test_string_set_non_utf8::<i32>();
    }

    #[test]
    fn string_set_non_utf8_64() {
        test_string_set_non_utf8::<i64>();
    }

    fn test_string_set_non_utf8<O: OffsetSizeTrait>() {
        let values = GenericStringArray::<O>::from(vec![
            Some("a"),
            Some("✨🔥"),
            Some("🔥"),
            Some("✨✨✨"),
            Some("foobarbaz"),
            Some("🔥"),
            Some("✨🔥"),
        ]);
        let mut set = SSOStringHashSet::<O>::new();
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        assert_set(
            set,
            &[
                Some("a"),
                Some("foobarbaz"),
                Some("✨✨✨"),
                Some("✨🔥"),
                Some("🔥"),
            ],
        );
    }

    /// Inline encodings collide when strings differ only by leading NUL bytes
    /// ("" and "\0" both pack to 0; "a" and "\0a" both pack to 0x61).
    /// The stored length must keep them distinct.
    #[test]
    fn string_set_inline_collisions() {
        let values = GenericStringArray::<i32>::from(vec![
            Some(""),
            Some("\0"),
            Some("a"),
            Some("\0a"),
            Some("\0\0a"),
            Some("a"),
            Some("\0"),
        ]);
        let mut set = SSOStringHashSet::<i32>::new();
        let array: ArrayRef = Arc::new(values);
        set.insert(&array);
        assert_eq!(set.len(), 5);
        assert_set(
            set,
            &[Some(""), Some("\0"), Some("\0\0a"), Some("\0a"), Some("a")],
        );
    }

    /// Values already inserted by an earlier batch are not added again.
    #[test]
    fn string_set_multiple_batches() {
        let batch1: ArrayRef =
            Arc::new(StringArray::from(vec!["short", "a much longer string"]));
        let batch2: ArrayRef = Arc::new(StringArray::from(vec![
            "a much longer string",
            "short",
            "new",
        ]));
        let mut set = SSOStringHashSet::<i32>::new();
        set.insert(&batch1);
        set.insert(&batch2);
        assert_eq!(set.len(), 3);
        assert_set(
            set,
            &[Some("a much longer string"), Some("new"), Some("short")],
        );
    }

    #[test]
    fn accumulator_counts_distinct_non_null() {
        let mut acc = StringDistinctCountAccumulator::<i32>::new();
        let batch: ArrayRef = Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("x"),
            Some("a longer value"),
        ]));
        acc.update_batch(&[batch]).unwrap();
        assert_eq!(acc.evaluate().unwrap(), ScalarValue::Int64(Some(2)));
    }

    #[test]
    fn accumulator_state_merge_round_trip() {
        let mut partial = StringDistinctCountAccumulator::<i32>::new();
        let batch: ArrayRef =
            Arc::new(StringArray::from(vec!["x", "y", "a longer value"]));
        partial.update_batch(&[batch]).unwrap();
        let state = partial.state().unwrap();
        let ScalarValue::List(list) = &state[0] else {
            panic!("expected a list state, got {state:?}");
        };
        let list: ArrayRef = list.clone();

        let mut merged = StringDistinctCountAccumulator::<i32>::new();
        let other: ArrayRef = Arc::new(StringArray::from(vec!["y", "z"]));
        merged.update_batch(&[other]).unwrap();
        merged.merge_batch(&[list]).unwrap();
        assert_eq!(merged.evaluate().unwrap(), ScalarValue::Int64(Some(4)));
    }

    /// Converts the set into its state array and compares it (sorted) to `expected`.
    fn assert_set<O: OffsetSizeTrait>(
        set: SSOStringHashSet<O>,
        expected: &[Option<&str>],
    ) {
        let strings = set.into_state();
        let strings = strings.as_string::<O>();
        let mut state = strings.into_iter().collect::<Vec<_>>();
        state.sort();
        assert_eq!(state, expected);
    }

    #[test]
    fn test_string_set_memory_usage() {
        let strings1 = GenericStringArray::<i32>::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCC"), // 10 bytes
            Some("AAAAAAAA"),   // 8 bytes
            Some("BBBBBQBBB"),  // 9 bytes
        ]);
        let total_strings1_len = strings1
            .iter()
            .map(|s| s.map(|s| s.len()).unwrap_or(0))
            .sum::<usize>();
        let values1: ArrayRef = Arc::new(strings1);

        let strings2 = GenericStringArray::<i32>::from(vec![
            "FOO".repeat(1000),
            "BAR".repeat(2000),
            "BAZ".repeat(3000),
        ]);
        let total_strings2_len = strings2
            .iter()
            .map(|s| s.map(|s| s.len()).unwrap_or(0))
            .sum::<usize>();
        let values2: ArrayRef = Arc::new(strings2);

        let mut set = SSOStringHashSet::<i32>::new();
        let size_empty = set.size();

        set.insert(&values1);
        let size_after_values1 = set.size();
        assert!(size_empty < size_after_values1);
        assert!(
            size_after_values1 > total_strings1_len,
            "expect {size_after_values1} to be more than {total_strings1_len}"
        );
        assert!(size_after_values1 < total_strings1_len + total_strings2_len);

        // Re-inserting the same values must not grow the set.
        set.insert(&values1);
        assert_eq!(set.size(), size_after_values1);

        set.insert(&values2);
        let size_after_values2 = set.size();
        assert!(size_after_values2 > size_after_values1);
        assert!(size_after_values2 > total_strings1_len + total_strings2_len);
    }
}