use std::hash::BuildHasher;
use std::sync::Arc;
use arrow_buffer::NullBufferBuilder;
use vortex_array::accessor::ArrayAccessor;
use vortex_array::arrays::binary_view::BinaryView;
use vortex_array::arrays::{PrimitiveArray, VarBinVTable, VarBinViewArray, VarBinViewVTable};
use vortex_array::validity::Validity;
use vortex_array::{Array, ArrayRef, IntoArray};
use vortex_buffer::{BufferMut, ByteBufferMut};
use vortex_dtype::{DType, UnsignedPType};
use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_panic};
use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashTable, HashTableEntry, RandomState};
use super::DictConstraints;
use crate::builders::DictEncoder;
/// Incremental dictionary builder for variable-length byte values.
///
/// Every distinct value (a null counts as one distinct value) is stored once as a
/// [`BinaryView`]. `Code` is the integer type of the codes that index those entries.
pub struct BytesDictBuilder<Code> {
    /// Hash table holding the codes assigned so far; entries are hashed and compared through
    /// the bytes they refer to. Wrapped in `Option` so it can be moved out of `self` while a
    /// batch of values is being encoded and put back afterwards.
    lookup: Option<HashTable<Code>>,
    /// One view per dictionary entry, in code order.
    views: BufferMut<BinaryView>,
    /// Backing bytes for the entries that are not inlined into their view.
    values: ByteBufferMut,
    /// Validity of each dictionary entry, in code order.
    values_nulls: NullBufferBuilder,
    /// Hasher used for both incoming values and already-stored entries.
    hasher: RandomState,
    /// DType that encoded arrays must have; also the dtype of the dictionary values.
    dtype: DType,
    /// Upper bound on the dictionary size in bytes (views plus backing bytes).
    max_dict_bytes: usize,
    /// Upper bound on the number of dictionary entries.
    max_dict_len: usize,
}
/// Builds a boxed [`BytesDictBuilder`] for `dtype`, limited by `constraints`.
///
/// The code type is the narrowest of `u8`, `u16`, `u32` whose maximum value is at least
/// `constraints.max_len`; if none is wide enough, `u64` is used.
pub fn bytes_dict_builder(dtype: DType, constraints: &DictConstraints) -> Box<dyn DictEncoder> {
    let max_len = constraints.max_len as u64;

    if max_len <= u8::MAX as u64 {
        Box::new(BytesDictBuilder::<u8>::new(dtype, constraints))
    } else if max_len <= u16::MAX as u64 {
        Box::new(BytesDictBuilder::<u16>::new(dtype, constraints))
    } else if max_len <= u32::MAX as u64 {
        Box::new(BytesDictBuilder::<u32>::new(dtype, constraints))
    } else {
        Box::new(BytesDictBuilder::<u64>::new(dtype, constraints))
    }
}
impl<Code: UnsignedPType> BytesDictBuilder<Code> {
    /// Creates an empty builder for values of `dtype`.
    ///
    /// `constraints.max_bytes` and `constraints.max_len` bound the dictionary's size in bytes
    /// and in number of entries respectively.
    pub fn new(dtype: DType, constraints: &DictConstraints) -> Self {
        Self {
            lookup: Some(HashTable::new()),
            views: BufferMut::<BinaryView>::empty(),
            values: BufferMut::empty(),
            values_nulls: NullBufferBuilder::new(0),
            hasher: DefaultHashBuilder::default(),
            dtype,
            max_dict_bytes: constraints.max_bytes,
            max_dict_len: constraints.max_len,
        }
    }

    /// Current dictionary size in bytes: all views plus the backing bytes of the entries that
    /// are not inlined.
    fn dict_bytes(&self) -> usize {
        self.views.len() * size_of::<BinaryView>() + self.values.len()
    }

    /// Returns the bytes of dictionary entry `idx`, or `None` if that entry is the null entry.
    #[inline]
    fn lookup_bytes(&self, idx: usize) -> Option<&[u8]> {
        self.values_nulls.is_valid(idx).then(|| {
            let bin_view = &self.views[idx];
            if bin_view.is_inlined() {
                bin_view.as_inlined().value()
            } else {
                &self.values[bin_view.as_view().as_range()]
            }
        })
    }

    /// Returns the code for `val`, adding it to the dictionary if it is not present yet.
    ///
    /// Returns `None`, leaving the dictionary untouched, when `val` is new and adding it would
    /// exceed `max_dict_len` entries or `max_dict_bytes` bytes.
    ///
    /// # Panics
    ///
    /// Panics if the length of the backing byte buffer does not fit in a `u32`, or if the new
    /// code does not fit in `Code`.
    #[inline]
    fn encode_value(&mut self, lookup: &mut HashTable<Code>, val: Option<&[u8]>) -> Option<Code> {
        match lookup.entry(
            self.hasher.hash_one(val),
            |idx| val == self.lookup_bytes(idx.as_()),
            |idx| self.hasher.hash_one(self.lookup_bytes(idx.as_())),
        ) {
            HashTableEntry::Occupied(occupied) => Some(*occupied.get()),
            HashTableEntry::Vacant(vacant) => {
                if self.views.len() >= self.max_dict_len {
                    return None;
                }

                // Codes are handed out densely, in insertion order.
                let next_code = self.views.len();
                match val {
                    None => {
                        // The null entry still occupies one view slot, which `dict_bytes`
                        // counts, so it has to pass the same byte budget check as any other
                        // entry.
                        if self.dict_bytes() + size_of::<BinaryView>() > self.max_dict_bytes {
                            return None;
                        }

                        self.views.push(BinaryView::default());
                        self.values_nulls.append_null();
                    }
                    Some(val) => {
                        // All out-of-line bytes live in the single `values` buffer (index 0),
                        // starting at its current end.
                        let view = BinaryView::make_view(
                            val,
                            0,
                            u32::try_from(self.values.len()).vortex_unwrap(),
                        );
                        let additional_bytes = if view.is_inlined() {
                            size_of::<BinaryView>()
                        } else {
                            size_of::<BinaryView>() + val.len()
                        };

                        if self.dict_bytes() + additional_bytes > self.max_dict_bytes {
                            return None;
                        }

                        self.views.push(view);
                        self.values_nulls.append_non_null();
                        if !view.is_inlined() {
                            self.values.extend_from_slice(val);
                        }
                    }
                }

                let next_code = Code::from_usize(next_code).unwrap_or_else(|| {
                    vortex_panic!("{next_code} has to fit into {}", Code::PTYPE)
                });
                Some(*vacant.insert(next_code).get())
            }
        }
    }

    /// Encodes the values yielded by `accessor` and returns their codes as a non-nullable
    /// primitive array.
    ///
    /// Encoding stops at the first value that does not fit into the dictionary, so the returned
    /// array may hold fewer than `len` codes. Nulls are not tracked in the codes; they are
    /// represented by a null dictionary entry.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `accessor.with_iterator`. The lookup table is put back
    /// into `self` before the error is returned, so the builder stays usable.
    fn encode_bytes<A: ArrayAccessor<[u8]>>(
        &mut self,
        accessor: &A,
        len: usize,
    ) -> VortexResult<ArrayRef> {
        // Move the table out of `self` so `encode_value` can borrow `self` mutably alongside it.
        let mut local_lookup = self.lookup.take().vortex_expect("Must have a lookup dict");
        let mut codes: BufferMut<Code> = BufferMut::with_capacity(len);

        let iterated = accessor.with_iterator(|it| {
            for value in it {
                let Some(code) = self.encode_value(&mut local_lookup, value) else {
                    break;
                };
                // SAFETY: capacity for `len` codes was reserved above; this assumes `accessor`
                // yields at most `len` values.
                unsafe { codes.push_unchecked(code) }
            }
        });

        // Restore the table before looking at the result: returning early on an error would
        // otherwise leave `self.lookup` empty and make the next call panic.
        self.lookup = Some(local_lookup);
        iterated?;

        Ok(PrimitiveArray::new(codes, Validity::NonNullable).into_array())
    }
}
impl<Code: UnsignedPType> DictEncoder for BytesDictBuilder<Code> {
    /// Encodes `array` against this builder's dictionary and returns the resulting codes.
    ///
    /// # Errors
    ///
    /// Fails if `array`'s dtype differs from the builder's dtype, or if `array` is neither a
    /// `VarBinView` nor a `VarBin` array.
    fn encode(&mut self, array: &dyn Array) -> VortexResult<ArrayRef> {
        if array.dtype() != &self.dtype {
            vortex_bail!(
                "Array DType {} does not match builder dtype {}",
                array.dtype(),
                self.dtype
            );
        }

        let len = array.len();

        if let Some(varbinview) = array.as_opt::<VarBinViewVTable>() {
            return self.encode_bytes(varbinview, len);
        }
        if let Some(varbin) = array.as_opt::<VarBinVTable>() {
            return self.encode_bytes(varbin, len);
        }

        vortex_bail!("Can only dictionary encode VarBin and VarBinView arrays");
    }

    /// Returns the dictionary values accumulated so far as a `VarBinViewArray`.
    ///
    /// The builder's buffers are cloned, so the builder itself is left unchanged.
    fn values(&mut self) -> VortexResult<ArrayRef> {
        let frozen_views = self.views.clone().freeze();
        let frozen_bytes = self.values.clone().freeze();
        let entry_dtype = self.dtype.clone();
        let entry_validity = Validity::from_null_buffer(
            self.values_nulls.finish_cloned(),
            self.dtype.nullability(),
        );

        // SAFETY: every view was created by this builder against the single data buffer passed
        // here (buffer index 0), and one validity bit was appended per view. This is presumably
        // what `new_unchecked` requires; confirm against its documented contract.
        let dict_values = unsafe {
            VarBinViewArray::new_unchecked(
                frozen_views,
                Arc::from([frozen_bytes]),
                entry_dtype,
                entry_validity,
            )
        };

        Ok(dict_values.into_array())
    }
}
#[cfg(test)]
mod test {
    use std::str;

    use vortex_array::ToCanonical;
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::VarBinArray;

    use crate::builders::dict_encode;

    /// Distinct strings get codes in first-seen order and repeats reuse their code.
    #[test]
    fn encode_varbin() {
        let arr = VarBinArray::from(vec!["hello", "world", "hello", "again", "world"]);
        let dict = dict_encode(arr.as_ref()).unwrap();
        assert_eq!(
            dict.codes().to_primitive().as_slice::<u8>(),
            &[0, 1, 0, 2, 1]
        );
        dict.values()
            .to_varbinview()
            .with_iterator(|iter| {
                assert_eq!(
                    iter.flatten()
                        // Checked conversion: corrupted dictionary bytes must fail the test
                        // instead of causing undefined behaviour.
                        .map(|b| str::from_utf8(b).expect("dictionary value must be valid UTF-8"))
                        .collect::<Vec<_>>(),
                    vec!["hello", "world", "again"]
                );
            })
            .unwrap();
    }

    /// Nulls are deduplicated into a single null dictionary entry with its own code.
    #[test]
    fn encode_varbin_nulls() {
        let arr: VarBinArray = vec![
            Some("hello"),
            None,
            Some("world"),
            Some("hello"),
            None,
            Some("again"),
            Some("world"),
            None,
        ]
        .into_iter()
        .collect();
        let dict = dict_encode(arr.as_ref()).unwrap();
        assert_eq!(
            dict.codes().to_primitive().as_slice::<u8>(),
            &[0, 1, 2, 0, 1, 3, 2, 1]
        );
        dict.values()
            .to_varbinview()
            .with_iterator(|iter| {
                assert_eq!(
                    iter.map(|b| {
                        b.map(|v| {
                            str::from_utf8(v).expect("dictionary value must be valid UTF-8")
                        })
                    })
                    .collect::<Vec<_>>(),
                    vec![Some("hello"), None, Some("world"), Some("again")]
                );
            })
            .unwrap();
    }

    /// Heavily repeated values still produce one dictionary entry per distinct value.
    #[test]
    fn repeated_values() {
        let arr = VarBinArray::from(vec!["a", "a", "b", "b", "a", "b", "a", "b"]);
        let dict = dict_encode(arr.as_ref()).unwrap();
        dict.values()
            .to_varbinview()
            .with_iterator(|iter| {
                assert_eq!(
                    iter.flatten()
                        .map(|b| str::from_utf8(b).expect("dictionary value must be valid UTF-8"))
                        .collect::<Vec<_>>(),
                    vec!["a", "b"]
                );
            })
            .unwrap();
        assert_eq!(
            dict.codes().to_primitive().as_slice::<u8>(),
            &[0, 0, 1, 1, 0, 1, 0, 1]
        );
    }
}