use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use crate::array::{Array, ArrayRef, ArrowPrimitiveType, DictionaryArray};
use crate::datatypes::{ArrowNativeType, DataType, ToByteSlice};
use crate::error::{ArrowError, Result};
use super::ArrayBuilder;
use super::PrimitiveBuilder;
#[derive(Debug)]
pub struct PrimitiveDictionaryBuilder<K, V>
where
K: ArrowPrimitiveType,
V: ArrowPrimitiveType,
{
keys_builder: PrimitiveBuilder<K>,
values_builder: PrimitiveBuilder<V>,
map: HashMap<Box<[u8]>, K::Native>,
}
impl<K, V> PrimitiveDictionaryBuilder<K, V>
where
K: ArrowPrimitiveType,
V: ArrowPrimitiveType,
{
pub fn new(
keys_builder: PrimitiveBuilder<K>,
values_builder: PrimitiveBuilder<V>,
) -> Self {
Self {
keys_builder,
values_builder,
map: HashMap::new(),
}
}
}
impl<K, V> ArrayBuilder for PrimitiveDictionaryBuilder<K, V>
where
K: ArrowPrimitiveType,
V: ArrowPrimitiveType,
{
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
self
}
fn len(&self) -> usize {
self.keys_builder.len()
}
fn is_empty(&self) -> bool {
self.keys_builder.is_empty()
}
fn finish(&mut self) -> ArrayRef {
Arc::new(self.finish())
}
}
impl<K, V> PrimitiveDictionaryBuilder<K, V>
where
K: ArrowPrimitiveType,
V: ArrowPrimitiveType,
{
#[inline]
pub fn append(&mut self, value: V::Native) -> Result<K::Native> {
if let Some(&key) = self.map.get(value.to_byte_slice()) {
self.keys_builder.append_value(key);
Ok(key)
} else {
let key = K::Native::from_usize(self.values_builder.len())
.ok_or(ArrowError::DictionaryKeyOverflowError)?;
self.values_builder.append_value(value);
self.keys_builder.append_value(key as K::Native);
self.map.insert(value.to_byte_slice().into(), key);
Ok(key)
}
}
#[inline]
pub fn append_null(&mut self) {
self.keys_builder.append_null()
}
pub fn finish(&mut self) -> DictionaryArray<K> {
self.map.clear();
let values = self.values_builder.finish();
let keys = self.keys_builder.finish();
let data_type =
DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE));
let builder = keys
.into_data()
.into_builder()
.data_type(data_type)
.child_data(vec![values.into_data()]);
DictionaryArray::from(unsafe { builder.build_unchecked() })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::array::Array;
use crate::array::UInt32Array;
use crate::array::UInt8Array;
use crate::datatypes::UInt32Type;
use crate::datatypes::UInt8Type;
#[test]
fn test_primitive_dictionary_builder() {
let key_builder = PrimitiveBuilder::<UInt8Type>::new(3);
let value_builder = PrimitiveBuilder::<UInt32Type>::new(2);
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
builder.append(12345678).unwrap();
builder.append_null();
builder.append(22345678).unwrap();
let array = builder.finish();
assert_eq!(
array.keys(),
&UInt8Array::from(vec![Some(0), None, Some(1)])
);
let av = array.values();
let ava: &UInt32Array = av.as_any().downcast_ref::<UInt32Array>().unwrap();
let avs: &[u32] = ava.values();
assert!(!array.is_null(0));
assert!(array.is_null(1));
assert!(!array.is_null(2));
assert_eq!(avs, &[12345678, 22345678]);
}
#[test]
#[should_panic(expected = "DictionaryKeyOverflowError")]
fn test_primitive_dictionary_overflow() {
let key_builder = PrimitiveBuilder::<UInt8Type>::new(257);
let value_builder = PrimitiveBuilder::<UInt32Type>::new(257);
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
for i in 0..256 {
builder.append(i + 1000).unwrap();
}
builder.append(1257).unwrap();
}
}