use std::ops::Not;
use arrow_buffer::BooleanBuffer;
use vortex_array::arrays::{BoolArray, ConstantArray};
use vortex_array::compute::{Operator, cast, compare, mask, take};
use vortex_array::validity::Validity;
use vortex_array::vtable::CanonicalVTable;
use vortex_array::{Array, ArrayRef, Canonical, IntoArray, ToCanonical};
use vortex_dtype::{DType, Nullability};
use vortex_error::{VortexExpect, VortexResult};
use vortex_mask::{AllOr, Mask};
use vortex_scalar::Scalar;
use crate::{DictArray, DictVTable};
impl CanonicalVTable<DictVTable> for DictVTable {
    /// Decodes a dictionary array into its canonical form by resolving every code against the
    /// dictionary values.
    ///
    /// Boolean dictionaries are routed through `dict_bool_take`. `Utf8`/`Binary` dictionaries
    /// have their values canonicalized before the lookup. Every other dtype takes directly from
    /// the values as stored.
    ///
    /// # Panics
    ///
    /// Panics if the boolean fast path or the underlying `take` returns an error.
    fn canonicalize(array: &DictArray) -> Canonical {
        let codes = array.codes();
        let gathered = match array.dtype() {
            DType::Bool(_) => {
                return dict_bool_take(array)
                    .vortex_expect("Canonicalizing dict bool array shouldn't fail");
            }
            DType::Utf8(_) | DType::Binary(_) => {
                // NOTE(review): only string/binary values are canonicalized before the take;
                // the reason is not visible in this file — confirm before changing.
                let canonical_values: ArrayRef = array.values().to_canonical().into_array();
                take(&canonical_values, codes)
            }
            _ => take(array.values(), codes),
        };
        gathered
            .vortex_expect("taking codes from dictionary values shouldn't fail")
            .to_canonical()
    }
}
/// Canonicalizes a dictionary whose values are booleans, avoiding a generic `take` when the
/// dictionary contents allow a cheaper answer.
///
/// The dictionary values are scanned for entries that are both valid and `true`:
///
/// * Every value is null: the result is a constant null array of `codes.len()` elements.
/// * No valid `true` entry: the result is an all-`false` buffer. Its validity is copied from the
///   codes when all values are valid. Otherwise it is the value validity gathered through the
///   codes.
/// * Exactly one valid `true` entry at index `code`: the result is `codes == code`. Positions
///   whose code refers to a null value are masked out.
/// * Two or more valid `true` entries: falls back to a plain `take` of the values by the codes.
///
/// # Errors
///
/// Propagates errors from the underlying `compare`, `cast`, `mask`, `take` and validity `take`
/// operations (including the generic fallback, which previously panicked instead).
fn dict_bool_take(dict_array: &DictArray) -> VortexResult<Canonical> {
    let values = dict_array.values();
    let codes = dict_array.codes();
    let result_nullability = dict_array.dtype().nullability();

    let bool_values = values.to_bool();
    // Validity of the dictionary *values* (not of the codes).
    let result_validity = bool_values.validity_mask();

    // Every dictionary value is null, so every output position is null regardless of the codes.
    if matches!(result_validity, Mask::AllFalse(_)) {
        return Ok(ConstantArray::new(
            Scalar::null(DType::Bool(Nullability::Nullable)),
            codes.len(),
        )
        .to_canonical());
    }

    let bool_buffer = bool_values.boolean_buffer();

    // Indices of the first two dictionary values that are both valid and `true`. Knowing whether
    // there are zero, one, or several of them is enough to pick a strategy below.
    let (first_match, second_match) = match result_validity.boolean_buffer() {
        AllOr::All => {
            let mut indices_iter = bool_buffer.set_indices();
            (indices_iter.next(), indices_iter.next())
        }
        AllOr::None => (None, None),
        AllOr::Some(v) => {
            let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
            (indices_iter.next(), indices_iter.next())
        }
    };

    Ok(match (first_match, second_match) {
        // No valid `true` value: every output is `false` or null.
        (None, _) => match result_validity {
            // All values valid: nulls can only come from the codes themselves.
            Mask::AllTrue(_) => BoolArray::from_bool_buffer(
                BooleanBuffer::new_unset(codes.len()),
                Validity::copy_from_array(codes).union_nullability(result_nullability),
            )
            .to_canonical(),
            // Some values null: gather the value validity through the codes.
            Mask::Values(_) => BoolArray::from_bool_buffer(
                BooleanBuffer::new_unset(codes.len()),
                Validity::from_mask(result_validity, result_nullability).take(codes)?,
            )
            .to_canonical(),
            Mask::AllFalse(_) => unreachable!("all-null dictionary values are handled above"),
        },
        // Exactly one valid `true` value: the output is true exactly where the code points at it.
        (Some(code), None) => {
            let code_matches = compare(
                codes,
                &cast(
                    ConstantArray::new(code, codes.len()).as_ref(),
                    codes.dtype(),
                )?,
                Operator::Eq,
            )?;
            match result_validity {
                Mask::AllTrue(_) => {
                    cast(&code_matches, &DType::Bool(result_nullability))?.to_canonical()
                }
                Mask::Values(rv) => {
                    // Inverse of the value validity gathered through the codes: set where the
                    // referenced dictionary value is null, which is where `mask` should null
                    // the output.
                    let invalid_positions =
                        take(BoolArray::from(rv.boolean_buffer().clone()).as_ref(), codes)?
                            .to_bool()
                            .boolean_buffer()
                            .not();
                    mask(&code_matches, &Mask::from_buffer(invalid_positions))?.to_canonical()
                }
                Mask::AllFalse(_) => unreachable!("all-null dictionary values are handled above"),
            }
        }
        // Two or more valid `true` values: no shortcut, gather the values directly.
        _ => take(bool_values.as_ref(), codes)?.to_canonical(),
    })
}