use crate::builder::{ArrayBuilder, PrimitiveBuilder, StringBuilder};
use crate::types::ArrowDictionaryKeyType;
use crate::{Array, ArrayRef, DictionaryArray, StringArray};
use arrow_buffer::ArrowNativeType;
use arrow_schema::{ArrowError, DataType};
use hashbrown::hash_map::RawEntryMut;
use hashbrown::HashMap;
use std::any::Any;
use std::sync::Arc;
/// Builder for a [`DictionaryArray`] with `Utf8` values, which stores each
/// distinct appended string once and records a key (its index in the values)
/// per appended slot.
#[derive(Debug)]
pub struct StringDictionaryBuilder<K>
where
K: ArrowDictionaryKeyType,
{
/// Hash state used for all string hashing. `dedup` carries no hasher of its
/// own (its hasher type is `()`), so every lookup and insert supplies hashes
/// computed with this state.
state: ahash::RandomState,
/// Keys of the distinct values already stored in `values_builder`. Each key is
/// an index into `values_builder`; equality is decided by comparing the bytes
/// stored at that index rather than by storing the strings a second time.
dedup: HashMap<K::Native, (), ()>,
/// Builder for the keys (one per appended slot, possibly null).
keys_builder: PrimitiveBuilder<K>,
/// Builder for the dictionary values.
values_builder: StringBuilder,
}
impl<K> Default for StringDictionaryBuilder<K>
where
K: ArrowDictionaryKeyType,
{
fn default() -> Self {
Self::new()
}
}
impl<K> StringDictionaryBuilder<K>
where
    K: ArrowDictionaryKeyType,
{
    /// Creates a new, empty `StringDictionaryBuilder`.
    pub fn new() -> Self {
        let keys_builder = PrimitiveBuilder::new();
        let values_builder = StringBuilder::new();
        Self {
            state: Default::default(),
            // Presize the dedup table to match the keys builder's initial capacity.
            dedup: HashMap::with_capacity_and_hasher(keys_builder.capacity(), ()),
            keys_builder,
            values_builder,
        }
    }

    /// Creates a new, empty `StringDictionaryBuilder` with preallocated storage.
    ///
    /// - `keys_capacity`: number of array slots (keys) to reserve
    /// - `value_capacity`: number of distinct dictionary values to reserve
    /// - `string_capacity`: number of string bytes to reserve for those values
    pub fn with_capacity(
        keys_capacity: usize,
        value_capacity: usize,
        string_capacity: usize,
    ) -> Self {
        Self {
            state: Default::default(),
            // The dedup table holds one entry per distinct value, so reserve
            // `value_capacity` entries up front (as `new` and
            // `new_with_dictionary` also presize it) instead of growing from empty.
            dedup: HashMap::with_capacity_and_hasher(value_capacity, ()),
            keys_builder: PrimitiveBuilder::with_capacity(keys_capacity),
            values_builder: StringBuilder::with_capacity(value_capacity, string_capacity),
        }
    }

    /// Creates a builder whose dictionary values start as a copy of
    /// `dictionary_values`. Strings appended later that already occur there
    /// reuse their existing index as the key.
    ///
    /// Null entries of `dictionary_values` are copied as null values and are
    /// never matched by `append`. If a string occurs more than once, the index
    /// of its first occurrence is the key used for it.
    ///
    /// `keys_capacity` is the number of array slots (keys) to reserve.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::DictionaryKeyOverflowError`] if the index of a
    /// non-null entry does not fit in `K::Native`.
    pub fn new_with_dictionary(
        keys_capacity: usize,
        dictionary_values: &StringArray,
    ) -> Result<Self, ArrowError> {
        let state = ahash::RandomState::default();
        let dict_len = dictionary_values.len();
        let mut dedup = HashMap::with_capacity_and_hasher(dict_len, ());
        let values_len = dictionary_values.value_data().len();
        let mut values_builder = StringBuilder::with_capacity(dict_len, values_len);
        for (idx, maybe_value) in dictionary_values.iter().enumerate() {
            match maybe_value {
                Some(value) => {
                    let hash = state.hash_one(value.as_bytes());
                    let key = K::Native::from_usize(idx)
                        .ok_or(ArrowError::DictionaryKeyOverflowError)?;
                    let entry = dedup.raw_entry_mut().from_hash(
                        hash,
                        |existing: &K::Native| {
                            value.as_bytes() == get_bytes(&values_builder, existing)
                        },
                    );
                    // Only the first occurrence of a string is registered. Later
                    // duplicates are still copied below so that every index in
                    // `values_builder` lines up with `dictionary_values`.
                    if let RawEntryMut::Vacant(vacant) = entry {
                        vacant.insert_with_hasher(hash, key, (), |existing| {
                            state.hash_one(get_bytes(&values_builder, existing))
                        });
                    }
                    values_builder.append_value(value);
                }
                None => values_builder.append_null(),
            }
        }
        Ok(Self {
            state,
            dedup,
            keys_builder: PrimitiveBuilder::with_capacity(keys_capacity),
            values_builder,
        })
    }
}
impl<K> ArrayBuilder for StringDictionaryBuilder<K>
where
K: ArrowDictionaryKeyType,
{
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
self
}
fn len(&self) -> usize {
self.keys_builder.len()
}
fn is_empty(&self) -> bool {
self.keys_builder.is_empty()
}
fn finish(&mut self) -> ArrayRef {
Arc::new(self.finish())
}
fn finish_cloned(&self) -> ArrayRef {
Arc::new(self.finish_cloned())
}
}
impl<K> StringDictionaryBuilder<K>
where
    K: ArrowDictionaryKeyType,
{
    /// Appends a string and returns the dictionary key it maps to.
    ///
    /// If an equal string is already in the dictionary values (from an earlier
    /// `append` or from `new_with_dictionary`), its existing key is reused.
    /// Otherwise the string is added to the values, and its index there becomes
    /// the new key.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::DictionaryKeyOverflowError`] if the index of a new
    /// distinct value does not fit in `K::Native`. In that case the builder is
    /// left unchanged: no key and no value is appended.
    pub fn append(&mut self, value: impl AsRef<str>) -> Result<K::Native, ArrowError> {
        let value = value.as_ref();
        let state = &self.state;
        let storage = &mut self.values_builder;
        // `dedup` has no hasher of its own: hash with `state`, and compare a
        // candidate key by reading the bytes it indexes in `storage`.
        let hash = state.hash_one(value.as_bytes());
        let entry = self
            .dedup
            .raw_entry_mut()
            .from_hash(hash, |existing| value.as_bytes() == get_bytes(storage, existing));
        let key = match entry {
            RawEntryMut::Occupied(entry) => *entry.into_key(),
            RawEntryMut::Vacant(entry) => {
                let index = storage.len();
                // Convert the index *before* storing the value, so an overflow
                // error does not leave an orphaned string in the dictionary values.
                let key = K::Native::from_usize(index)
                    .ok_or(ArrowError::DictionaryKeyOverflowError)?;
                storage.append_value(value);
                // If the table has to grow, the hasher recomputes each existing
                // key's hash from the bytes it indexes in `storage`.
                *entry
                    .insert_with_hasher(hash, key, (), |existing| {
                        state.hash_one(get_bytes(storage, existing))
                    })
                    .0
            }
        };
        self.keys_builder.append_value(key);
        Ok(key)
    }

    /// Appends a null slot. Only the keys builder is touched; no value is stored.
    #[inline]
    pub fn append_null(&mut self) {
        self.keys_builder.append_null()
    }

    /// Builds the `DictionaryArray` from the appended keys and values.
    ///
    /// Clears the dedup table and calls `finish` on both inner builders. Those
    /// builders are presumably reset by their own `finish` (not shown here), so
    /// the builder can then be reused.
    pub fn finish(&mut self) -> DictionaryArray<K> {
        self.dedup.clear();
        let values = self.values_builder.finish();
        let keys = self.keys_builder.finish();
        let data_type =
            DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(DataType::Utf8));
        let builder = keys
            .into_data()
            .into_builder()
            .data_type(data_type)
            .child_data(vec![values.into_data()]);
        // SAFETY: each non-null key was appended by `append`, which returns only
        // indices of strings already stored in `values_builder`, so every key is
        // in bounds for `values`. `data_type` is a dictionary of `K` keys over
        // `Utf8`, matching the key and value data.
        DictionaryArray::from(unsafe { builder.build_unchecked() })
    }

    /// Builds the `DictionaryArray` from clones of the current keys and values,
    /// leaving the builder intact so that appending can continue.
    pub fn finish_cloned(&self) -> DictionaryArray<K> {
        let values = self.values_builder.finish_cloned();
        let keys = self.keys_builder.finish_cloned();
        let data_type =
            DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(DataType::Utf8));
        let builder = keys
            .into_data()
            .into_builder()
            .data_type(data_type)
            .child_data(vec![values.into_data()]);
        // SAFETY: same invariant as `finish`; every non-null key indexes a string
        // in `values_builder`, and `data_type` matches the key and value data.
        DictionaryArray::from(unsafe { builder.build_unchecked() })
    }
}
/// Returns the raw UTF-8 bytes of the value at index `key` in `values`,
/// sliced out of the builder's value buffer using its offsets.
///
/// Panics if `key` is not a valid value index in `values`.
fn get_bytes<'a, K: ArrowNativeType>(values: &'a StringBuilder, key: &K) -> &'a [u8] {
    let position = key.as_usize();
    let bounds = values.offsets_slice();
    let end = bounds[position + 1].as_usize();
    let start = bounds[position].as_usize();
    &values.values_slice()[start..end]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Array;
    use crate::array::Int8Array;
    use crate::types::{Int16Type, Int8Type};

    /// Appends `entries` in order, using `append_null` for `None`.
    fn append_all<K: ArrowDictionaryKeyType>(
        builder: &mut StringDictionaryBuilder<K>,
        entries: &[Option<&str>],
    ) {
        for &entry in entries {
            match entry {
                Some(s) => {
                    builder.append(s).unwrap();
                }
                None => builder.append_null(),
            }
        }
    }

    /// Downcasts the dictionary values of `array` to a `StringArray`.
    fn string_values<K: ArrowDictionaryKeyType>(array: &DictionaryArray<K>) -> &StringArray {
        array
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
    }

    #[test]
    fn test_string_dictionary_builder() {
        let mut builder = StringDictionaryBuilder::<Int8Type>::new();
        append_all(
            &mut builder,
            &[Some("abc"), None, Some("def"), Some("def"), Some("abc")],
        );
        let dict = builder.finish();

        let expected_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]);
        assert_eq!(dict.keys(), &expected_keys);

        let values = string_values(&dict);
        assert_eq!(values.value(0), "abc");
        assert_eq!(values.value(1), "def");
    }

    #[test]
    fn test_string_dictionary_builder_finish_cloned() {
        let mut builder = StringDictionaryBuilder::<Int8Type>::new();
        append_all(
            &mut builder,
            &[Some("abc"), None, Some("def"), Some("def"), Some("abc")],
        );

        let snapshot = builder.finish_cloned();
        assert_eq!(
            snapshot.keys(),
            &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)])
        );
        let snapshot_values = string_values(&snapshot);
        assert_eq!(snapshot_values.value(0), "abc");
        assert_eq!(snapshot_values.value(1), "def");

        // `finish_cloned` keeps the builder's state, so keys continue from it.
        append_all(&mut builder, &[Some("abc"), Some("ghi"), Some("def")]);
        let finished = builder.finish();
        let expected_keys = Int8Array::from(vec![
            Some(0),
            None,
            Some(1),
            Some(1),
            Some(0),
            Some(0),
            Some(2),
            Some(1),
        ]);
        assert_eq!(finished.keys(), &expected_keys);

        let finished_values = string_values(&finished);
        assert_eq!(finished_values.value(0), "abc");
        assert_eq!(finished_values.value(1), "def");
        assert_eq!(finished_values.value(2), "ghi");
    }

    #[test]
    fn test_string_dictionary_builder_with_existing_dictionary() {
        let dictionary = StringArray::from(vec![None, Some("def"), Some("abc")]);
        let mut builder =
            StringDictionaryBuilder::<Int8Type>::new_with_dictionary(6, &dictionary)
                .unwrap();
        append_all(
            &mut builder,
            &[
                Some("abc"),
                None,
                Some("def"),
                Some("def"),
                Some("abc"),
                Some("ghi"),
            ],
        );
        let dict = builder.finish();

        let expected_keys =
            Int8Array::from(vec![Some(2), None, Some(1), Some(1), Some(2), Some(3)]);
        assert_eq!(dict.keys(), &expected_keys);

        let values = string_values(&dict);
        assert!(!values.is_valid(0));
        assert_eq!(values.value(1), "def");
        assert_eq!(values.value(2), "abc");
        assert_eq!(values.value(3), "ghi");
    }

    #[test]
    fn test_string_dictionary_builder_with_reserved_null_value() {
        let entries: Vec<Option<&str>> = vec![None];
        let dictionary = StringArray::from(entries);
        let mut builder =
            StringDictionaryBuilder::<Int16Type>::new_with_dictionary(4, &dictionary)
                .unwrap();
        append_all(&mut builder, &[Some("abc"), None, Some("def"), Some("abc")]);
        let dict = builder.finish();

        assert!(dict.is_null(1));
        assert!(!dict.is_valid(1));

        let keys = dict.keys();
        assert_eq!(keys.value(0), 1);
        assert!(keys.is_null(1));
        assert_eq!(keys.value(1), 0);
        assert_eq!(keys.value(2), 2);
        assert_eq!(keys.value(3), 1);
    }
}