use std::sync::Arc;
use crate::filter::filter;
use crate::interleave::interleave;
use ahash::RandomState;
use arrow_array::builder::BooleanBufferBuilder;
use arrow_array::types::{
ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
LargeUtf8Type, Utf8Type,
};
use arrow_array::{
AnyDictionaryArray, Array, ArrayRef, ArrowNativeTypeOp, BooleanArray, DictionaryArray,
GenericByteArray, PrimitiveArray, downcast_dictionary_array,
};
use arrow_array::{cast::AsArray, downcast_primitive};
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
use arrow_schema::{ArrowError, DataType};
pub fn garbage_collect_dictionary<K: ArrowDictionaryKeyType>(
dictionary: &DictionaryArray<K>,
) -> Result<DictionaryArray<K>, ArrowError> {
let keys = dictionary.keys();
let values = dictionary.values();
let mask = dictionary.occupancy();
if mask.count_set_bits() == values.len() {
return Ok(dictionary.clone());
}
let mut key_remap = vec![K::Native::ZERO; values.len()];
for (new_idx, old_idx) in mask.set_indices().enumerate() {
key_remap[old_idx] = K::Native::from_usize(new_idx)
.expect("new index should fit in K::Native, as old index was in range");
}
let new_keys = keys.unary(|key| {
key_remap
.get(key.as_usize())
.copied()
.unwrap_or(K::Native::ZERO)
});
let values = filter(dictionary.values(), &BooleanArray::new(mask, None))?;
DictionaryArray::try_new(new_keys, values)
}
pub fn garbage_collect_any_dictionary(
dictionary: &dyn AnyDictionaryArray,
) -> Result<ArrayRef, ArrowError> {
let dictionary = &*dictionary.slice(0, dictionary.len());
downcast_dictionary_array!(
dictionary => garbage_collect_dictionary(dictionary).map(|dict| Arc::new(dict) as ArrayRef),
_ => unreachable!("have a dictionary array")
)
}
struct Interner<'a, V> {
state: RandomState,
buckets: Vec<Option<InternerBucket<'a, V>>>,
shift: u32,
}
type InternerBucket<'a, V> = (Option<&'a [u8]>, V);
impl<'a, V> Interner<'a, V> {
fn new(capacity: usize) -> Self {
let shift = (capacity as u64 + 128).leading_zeros();
let num_buckets = (u64::MAX >> shift) as usize;
let buckets = (0..num_buckets.saturating_add(1)).map(|_| None).collect();
Self {
state: RandomState::with_seeds(0, 0, 0, 0),
buckets,
shift,
}
}
fn intern<F: FnOnce() -> Result<V, E>, E>(
&mut self,
new: Option<&'a [u8]>,
f: F,
) -> Result<&V, E> {
let hash = self.state.hash_one(new);
let bucket_idx = hash >> self.shift;
Ok(match &mut self.buckets[bucket_idx as usize] {
Some((current, v)) => {
if *current != new {
*v = f()?;
*current = new;
}
v
}
slot => &slot.insert((new, f()?)).1,
})
}
}
pub(crate) struct MergedDictionaries<K: ArrowDictionaryKeyType> {
pub key_mappings: Vec<Vec<K::Native>>,
pub values: ArrayRef,
}
fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
match (a.as_bytes_opt::<T>(), b.as_bytes_opt::<T>()) {
(Some(a), Some(b)) => {
let values_eq = a.values().ptr_eq(b.values()) && a.offsets().ptr_eq(b.offsets());
match (a.nulls(), b.nulls()) {
(Some(a), Some(b)) => values_eq && a.inner().ptr_eq(b.inner()),
(None, None) => values_eq,
_ => false,
}
}
_ => false,
}
}
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
dictionaries: &[&DictionaryArray<K>],
len: usize,
) -> (bool, bool) {
use DataType::*;
let first_values = dictionaries[0].values().as_ref();
let ptr_eq: PtrEq = match first_values.data_type() {
Utf8 => bytes_ptr_eq::<Utf8Type>,
LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
Binary => bytes_ptr_eq::<BinaryType>,
LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
dt => {
if !dt.is_primitive() {
return (
false,
K::Native::from_usize(dictionaries.iter().map(|d| d.values().len()).sum())
.is_none(),
);
}
|a, b| a.to_data().ptr_eq(&b.to_data())
}
};
let mut single_dictionary = true;
let mut total_values = first_values.len();
for dict in dictionaries.iter().skip(1) {
let values = dict.values().as_ref();
total_values += values.len();
if single_dictionary {
single_dictionary = ptr_eq(first_values, values)
}
}
let overflow = K::Native::from_usize(total_values).is_none();
let values_exceed_length = total_values >= len;
(
!single_dictionary && (overflow || values_exceed_length),
overflow,
)
}
pub(crate) fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
dictionaries: &[&DictionaryArray<K>],
masks: Option<&[BooleanBuffer]>,
) -> Result<MergedDictionaries<K>, ArrowError> {
let mut num_values = 0;
let mut values_arrays = Vec::with_capacity(dictionaries.len());
let mut value_slices = Vec::with_capacity(dictionaries.len());
for (idx, dictionary) in dictionaries.iter().enumerate() {
let mask = masks.and_then(|m| m.get(idx));
let key_mask_owned;
let key_mask = match (dictionary.nulls(), mask) {
(Some(n), None) => Some(n.inner()),
(None, Some(n)) => Some(n),
(Some(n), Some(m)) => {
key_mask_owned = n.inner() & m;
Some(&key_mask_owned)
}
(None, None) => None,
};
let keys = dictionary.keys().values();
let values = dictionary.values().as_ref();
let values_mask = compute_values_mask(keys, key_mask, values.len());
let masked_values = get_masked_values(values, &values_mask);
num_values += masked_values.len();
value_slices.push(masked_values);
values_arrays.push(values)
}
let mut interner = Interner::new(num_values);
let mut indices = Vec::with_capacity(num_values);
let key_mappings = dictionaries
.iter()
.enumerate()
.zip(value_slices)
.map(|((dictionary_idx, dictionary), values)| {
let zero = K::Native::from_usize(0).unwrap();
let mut mapping = vec![zero; dictionary.values().len()];
for (value_idx, value) in values {
mapping[value_idx] =
*interner.intern(value, || match K::Native::from_usize(indices.len()) {
Some(idx) => {
indices.push((dictionary_idx, value_idx));
Ok(idx)
}
None => Err(ArrowError::DictionaryKeyOverflowError),
})?;
}
Ok(mapping)
})
.collect::<Result<Vec<_>, ArrowError>>()?;
Ok(MergedDictionaries {
key_mappings,
values: interleave(&values_arrays, &indices)?,
})
}
fn compute_values_mask<K: ArrowNativeType>(
keys: &ScalarBuffer<K>,
mask: Option<&BooleanBuffer>,
max_key: usize,
) -> BooleanBuffer {
let mut builder = BooleanBufferBuilder::new(max_key);
builder.advance(max_key);
match mask {
Some(n) => n
.set_indices()
.for_each(|idx| builder.set_bit(keys[idx].as_usize(), true)),
None => keys
.iter()
.for_each(|k| builder.set_bit(k.as_usize(), true)),
}
builder.finish()
}
fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
array: &'a PrimitiveArray<T>,
mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)>
where
T::Native: ToByteSlice,
{
let mut out = Vec::with_capacity(mask.count_set_bits());
let values = array.values();
for idx in mask.set_indices() {
out.push((
idx,
array.is_valid(idx).then_some(values[idx].to_byte_slice()),
))
}
out
}
macro_rules! masked_primitive_to_bytes_helper {
($t:ty, $array:expr, $mask:expr) => {
masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask)
};
}
fn get_masked_values<'a>(
array: &'a dyn Array,
mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)> {
downcast_primitive! {
array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
_ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
}
}
fn masked_bytes<'a, T: ByteArrayType>(
array: &'a GenericByteArray<T>,
mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)> {
let mut out = Vec::with_capacity(mask.count_set_bits());
for idx in mask.set_indices() {
out.push((
idx,
array.is_valid(idx).then_some(array.value(idx).as_ref()),
))
}
out
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::cast::as_string_array;
use arrow_array::types::Int8Type;
use arrow_array::types::Int32Type;
use arrow_array::{DictionaryArray, Int8Array, Int32Array, StringArray};
use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer};
use std::sync::Arc;
#[test]
fn test_garbage_collect_i32_dictionary() {
let values = StringArray::from_iter_values(["a", "b", "c", "d"]);
let keys = Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]);
let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));
let gc = garbage_collect_dictionary(&dict).unwrap();
let expected_values = StringArray::from_iter_values(["a", "b", "d"]);
let expected_keys = Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]);
let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));
assert_eq!(gc, expected);
}
#[test]
fn test_garbage_collect_any_dictionary() {
let values = StringArray::from_iter_values(["a", "b", "c", "d"]);
let keys = Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]);
let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));
let gc = garbage_collect_any_dictionary(&dict).unwrap();
let expected_values = StringArray::from_iter_values(["a", "b", "d"]);
let expected_keys = Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]);
let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));
assert_eq!(gc.as_ref(), &expected);
}
#[test]
fn test_garbage_collect_with_nulls() {
let values = StringArray::from_iter_values(["a", "b", "c"]);
let keys = Int8Array::from(vec![Some(2), None, Some(0)]);
let dict = DictionaryArray::<Int8Type>::new(keys, Arc::new(values));
let gc = garbage_collect_dictionary(&dict).unwrap();
let expected_values = StringArray::from_iter_values(["a", "c"]);
let expected_keys = Int8Array::from(vec![Some(1), None, Some(0)]);
let expected = DictionaryArray::<Int8Type>::new(expected_keys, Arc::new(expected_values));
assert_eq!(gc, expected);
}
#[test]
fn test_garbage_collect_empty_dictionary() {
let values = StringArray::from_iter_values::<&str, _>([]);
let keys = Int32Array::from_iter_values([]);
let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));
let gc = garbage_collect_dictionary(&dict).unwrap();
assert_eq!(gc, dict);
}
#[test]
fn test_garbage_collect_dictionary_all_unreferenced() {
let values = StringArray::from_iter_values(["a", "b", "c"]);
let keys = Int32Array::from(vec![None, None, None]);
let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));
let gc = garbage_collect_dictionary(&dict).unwrap();
let expected_values = StringArray::from_iter_values::<&str, _>([]);
let expected_keys = Int32Array::from(vec![None, None, None]);
let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));
assert_eq!(gc, expected);
}
#[test]
fn test_merge_strings() {
let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "a", "b", "d", "c", "e"]);
let b = DictionaryArray::<Int32Type>::from_iter(["c", "f", "c", "d", "a", "d"]);
let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
let values = as_string_array(merged.values.as_ref());
let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
assert_eq!(&actual, &["a", "b", "d", "c", "e", "f"]);
assert_eq!(merged.key_mappings.len(), 2);
assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 3, 4]);
assert_eq!(&merged.key_mappings[1], &[3, 5, 2, 0]);
let a_slice = a.slice(1, 4);
let merged = merge_dictionary_values(&[&a_slice, &b], None).unwrap();
let values = as_string_array(merged.values.as_ref());
let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
assert_eq!(&actual, &["a", "b", "d", "c", "f"]);
assert_eq!(merged.key_mappings.len(), 2);
assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 0, 0]);
assert_eq!(&merged.key_mappings[1], &[3, 4, 2, 0]);
let a_mask = BooleanBuffer::from_iter([false, true, false, true, true, false, false]);
let b_mask = BooleanBuffer::new_set(b.len());
let merged = merge_dictionary_values(&[&a, &b], Some(&[a_mask, b_mask])).unwrap();
let values = as_string_array(merged.values.as_ref());
let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
assert_eq!(&actual, &["b", "d", "c", "f", "a"]);
assert_eq!(merged.key_mappings.len(), 2);
assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 0, 0]);
assert_eq!(&merged.key_mappings[1], &[2, 3, 1, 4]);
}
#[test]
fn test_merge_nulls() {
let buffer = Buffer::from(b"helloworldbingohelloworld");
let offsets = OffsetBuffer::from_lengths([5, 5, 5, 5, 5]);
let nulls = NullBuffer::from(vec![true, false, true, true, true]);
let values = StringArray::new(offsets, buffer, Some(nulls));
let key_values = vec![1, 2, 3, 1, 8, 2, 3];
let key_nulls = NullBuffer::from(vec![true, true, false, true, false, true, true]);
let keys = Int32Array::new(key_values.into(), Some(key_nulls));
let a = DictionaryArray::new(keys, Arc::new(values));
let b = DictionaryArray::new(Int32Array::new_null(10), Arc::new(StringArray::new_null(0)));
let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
let expected = StringArray::from(vec![None, Some("bingo"), Some("hello")]);
assert_eq!(merged.values.as_ref(), &expected);
assert_eq!(merged.key_mappings.len(), 2);
assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 2, 0]);
assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]);
}
#[test]
fn test_merge_keys_smaller() {
let values = StringArray::from_iter_values(["a", "b"]);
let keys = Int32Array::from_iter_values([1]);
let a = DictionaryArray::new(keys, Arc::new(values));
let merged = merge_dictionary_values(&[&a], None).unwrap();
let expected = StringArray::from(vec!["b"]);
assert_eq!(merged.values.as_ref(), &expected);
}
}