use bytes::bytes_dict_builder;
use primitive::primitive_dict_builder;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_panic;
use crate::ArrayRef;
use crate::DynArray;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::DictArray;
use crate::arrays::PrimitiveVTable;
use crate::arrays::VarBinVTable;
use crate::arrays::VarBinViewVTable;
use crate::dtype::PType;
use crate::match_each_native_ptype;
mod bytes;
mod primitive;
/// Limits handed to the dictionary builders created by [`dict_encoder`].
///
/// The limits are only forwarded from this module; how they are enforced is
/// up to the individual builder implementations.
#[derive(Debug, Clone)]
pub struct DictConstraints {
    /// Byte budget for the dictionary.
    // NOTE(review): presumably bounds the size of the dictionary values — confirm
    // against the primitive/bytes builders.
    pub max_bytes: usize,
    /// Length budget for the dictionary.
    // NOTE(review): presumably bounds the number of distinct values — confirm
    // against the primitive/bytes builders.
    pub max_len: usize,
}
/// Constraints with both limits set to `usize::MAX`, i.e. no practical bound.
///
/// [`dict_encode`] uses this when the whole input is expected to be encoded.
pub const UNCONSTRAINED: DictConstraints = DictConstraints {
    max_len: usize::MAX,
    max_bytes: usize::MAX,
};
/// A stateful dictionary encoder that can be moved across threads (`Send`).
///
/// Within this module an encoder is driven by calling [`DictEncoder::encode`]
/// to obtain codes and then [`DictEncoder::reset`] to obtain the array used as
/// the dictionary values.
pub trait DictEncoder: Send {
    /// Encodes `array` and returns the resulting array; callers in this module
    /// convert that result to a primitive array and use it as the codes.
    // NOTE(review): the length check in `dict_encode` suggests an
    // implementation may encode only part of the input when constraints are
    // hit — confirm against the implementations.
    fn encode(&mut self, array: &ArrayRef) -> ArrayRef;

    /// Returns the array that `dict_encode_with_constraints` pairs with the
    /// codes as the dictionary values.
    // NOTE(review): presumably also clears the accumulated state so the
    // encoder can be reused — confirm against the implementations.
    fn reset(&mut self) -> ArrayRef;

    /// Reports a primitive type for this encoder.
    // NOTE(review): presumably the type of the codes produced by `encode` —
    // confirm against the implementations.
    fn codes_ptype(&self) -> PType;
}
/// Builds the [`DictEncoder`] matching the encoding of `array`.
///
/// Primitive arrays get a primitive builder specialised on their native
/// element type; `VarBinView` and `VarBin` arrays get a bytes builder created
/// from the array's dtype. `constraints` is forwarded to whichever builder is
/// chosen.
///
/// # Panics
///
/// Panics if `array` is neither a primitive, `VarBinView`, nor `VarBin` array.
pub fn dict_encoder(array: &ArrayRef, constraints: &DictConstraints) -> Box<dyn DictEncoder> {
    // Primitive input: dispatch on the native type so the builder is monomorphised.
    if let Some(primitive) = array.as_opt::<PrimitiveVTable>() {
        return match_each_native_ptype!(primitive.ptype(), |P| {
            primitive_dict_builder::<P>(primitive.dtype().nullability(), constraints)
        });
    }

    // Both variable-length binary encodings share the bytes builder.
    if let Some(view) = array.as_opt::<VarBinViewVTable>() {
        return bytes_dict_builder(view.dtype().clone(), constraints);
    }
    if let Some(varbin) = array.as_opt::<VarBinVTable>() {
        return bytes_dict_builder(varbin.dtype().clone(), constraints);
    }

    vortex_panic!("Can only encode primitive or varbin/view arrays")
}
/// Dictionary-encodes `array` using a builder limited by `constraints`.
///
/// The codes returned by the encoder are converted to a primitive array and
/// narrowed before being paired with the encoder's values in a [`DictArray`].
///
/// # Errors
///
/// Returns an error if narrowing the codes fails.
///
/// # Panics
///
/// Panics if `array` has an encoding that [`dict_encoder`] does not support.
pub fn dict_encode_with_constraints(
    array: &ArrayRef,
    constraints: &DictConstraints,
) -> VortexResult<DictArray> {
    let mut builder = dict_encoder(array, constraints);

    // Encode first, then take the values: the codes must be produced before
    // the builder is reset.
    let narrowed_codes = builder.encode(array).to_primitive().narrow()?;
    let codes = narrowed_codes.into_array();
    let values = builder.reset();

    // SAFETY: NOTE(review) — relies on the `DictEncoder` implementation having
    // produced `codes` that are valid for `values`, and on every value being
    // referenced by some code; neither is checked here. Confirm against the
    // builder implementations.
    let dict = unsafe { DictArray::new_unchecked(codes, values).set_all_values_referenced(true) };
    Ok(dict)
}
/// Dictionary-encodes all of `array` with no size constraints.
///
/// # Errors
///
/// Returns an error if the underlying encode fails, or if the resulting
/// dictionary array does not have the same length as `array`.
///
/// # Panics
///
/// Panics if `array` has an encoding that [`dict_encoder`] does not support.
pub fn dict_encode(array: &ArrayRef) -> VortexResult<DictArray> {
    let encoded = dict_encode_with_constraints(array, &UNCONSTRAINED)?;

    // The result must cover every input element; a shorter result means part
    // of the input was left unencoded.
    let encoded_len = encoded.len();
    let expected_len = array.len();
    if encoded_len != expected_len {
        vortex_bail!(
            "must have encoded all {} elements, but only encoded {}",
            expected_len,
            encoded_len,
        );
    }

    Ok(encoded)
}