use std::fmt::Debug;
use std::ops::Range;
use itertools::Itertools as _;
use vortex_buffer::BitBuffer;
use vortex_error::VortexExpect as _;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_mask::Mask;
use vortex_mask::MaskValues;
use crate::ArrayRef;
use crate::Canonical;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::VortexSessionExecute;
use crate::arrays::BoolArray;
use crate::arrays::ChunkedArray;
use crate::arrays::ConstantArray;
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::optimizer::ArrayOptimizer;
use crate::patches::Patches;
use crate::scalar::Scalar;
use crate::scalar_fn::fns::binary::Binary;
use crate::scalar_fn::fns::operators::Operator;
/// Describes which positions of an array hold a value and which are null.
///
/// The three unit variants describe every position at once without storing
/// per-element data; [`Validity::Array`] carries one boolean per position.
#[derive(Clone)]
pub enum Validity {
/// The array cannot contain nulls: every position is valid and
/// [`Validity::nullability`] reports `Nullability::NonNullable`.
NonNullable,
/// The array is nullable, but every position is valid.
AllValid,
/// Every position is null.
AllInvalid,
/// Per-position validity: a boolean array in which `true` marks a valid
/// position and `false` marks a null one.
Array(ArrayRef),
}
/// Compact diagnostic rendering: the unit variants print their own name, while
/// the array-backed variant prints as `SomeValid(...)` with the array's values.
impl Debug for Validity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            // Only the array variant needs real formatting; handle it and leave.
            Self::Array(arr) => return write!(f, "SomeValid({})", arr.display_values()),
            Self::NonNullable => "NonNullable",
            Self::AllValid => "AllValid",
            Self::AllInvalid => "AllInvalid",
        };
        f.write_str(label)
    }
}
impl Validity {
    /// Executes the validity so that an array-backed validity is in canonical form.
    ///
    /// The constant variants are returned untouched. A [`Validity::Array`] has its
    /// array executed to [`Canonical`] and converted back into an [`ArrayRef`].
    ///
    /// # Errors
    ///
    /// Returns the error produced while executing the underlying array.
    pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
        match self {
            Validity::Array(array) => {
                let canonical = array.execute::<Canonical>(ctx)?;
                Ok(Validity::Array(canonical.into_array()))
            }
            constant @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
                Ok(constant)
            }
        }
    }
}
impl Validity {
/// Non-nullable boolean dtype.
///
/// NOTE(review): presumably the dtype expected of the array stored in
/// [`Validity::Array`] — confirm against callers.
pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
/// Materialises this validity as a boolean array.
///
/// The constant variants become a [`ConstantArray`] of `len` elements (`true`
/// for `NonNullable`/`AllValid`, `false` for `AllInvalid`). An array-backed
/// validity returns a clone of its array; `len` is not consulted in that case.
pub fn to_array(&self, len: usize) -> ArrayRef {
    let fill = match self {
        Self::Array(mask) => return mask.clone(),
        Self::AllInvalid => false,
        Self::NonNullable | Self::AllValid => true,
    };
    ConstantArray::new(fill, len).into_array()
}
/// Consumes the validity and returns the backing array, if there is one.
///
/// Returns `None` for every constant variant.
#[inline]
pub fn into_array(self) -> Option<ArrayRef> {
    match self {
        Self::Array(mask) => Some(mask),
        Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
    }
}
/// Borrows the backing array, if there is one.
///
/// Returns `None` for every constant variant.
#[inline]
pub fn as_array(&self) -> Option<&ArrayRef> {
    match self {
        Self::Array(mask) => Some(mask),
        Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
    }
}
/// Returns the [`Nullability`] implied by this validity.
///
/// Only [`Validity::NonNullable`] maps to `NonNullable`; every other variant is
/// `Nullable`, including `AllValid`.
#[inline]
pub fn nullability(&self) -> Nullability {
    match self {
        Self::NonNullable => Nullability::NonNullable,
        Self::AllValid | Self::AllInvalid | Self::Array(_) => Nullability::Nullable,
    }
}
/// Returns `true` when the variant alone guarantees that nothing is null.
///
/// This is a check on the variant only: a [`Validity::Array`] always yields
/// `false`, because its contents are not inspected.
#[inline]
pub fn no_nulls(&self) -> bool {
    match self {
        Self::NonNullable | Self::AllValid => true,
        Self::AllInvalid | Self::Array(_) => false,
    }
}
/// Widens this validity with another nullability.
///
/// A `Nullable` argument converts the validity with [`Validity::into_nullable`];
/// a `NonNullable` argument leaves it unchanged.
#[inline]
pub fn union_nullability(self, nullability: Nullability) -> Self {
    if matches!(nullability, Nullability::Nullable) {
        self.into_nullable()
    } else {
        self
    }
}
/// Reports whether the element at `index` is valid (non-null).
///
/// The constant variants answer without looking at `index`. For an array-backed
/// validity the boolean at `index` is read through a freshly created legacy
/// execution context.
///
/// # Errors
///
/// Returns the error raised while reading the scalar at `index` from the
/// validity array, instead of panicking on it.
///
/// # Panics
///
/// Panics if the scalar read from the validity array is itself null; a validity
/// array is expected to be non-nullable.
#[inline]
pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
    match self {
        Self::NonNullable | Self::AllValid => Ok(true),
        Self::AllInvalid => Ok(false),
        Self::Array(mask) => {
            let mut ctx = LEGACY_SESSION.create_execution_ctx();
            // The signature is fallible, so surface lookup failures to the caller.
            let entry = mask.execute_scalar(index, &mut ctx)?;
            Ok(entry
                .as_bool()
                .value()
                .vortex_expect("Validity must be non-nullable"))
        }
    }
}
/// Reports whether the element at `index` is null; the negation of
/// [`Validity::is_valid`].
///
/// # Errors
///
/// Forwards any error returned by [`Validity::is_valid`].
#[inline]
pub fn is_null(&self, index: usize) -> VortexResult<bool> {
    self.is_valid(index).map(|valid| !valid)
}
/// Restricts the validity to `range`.
///
/// The constant variants carry no per-element data, so they are cloned as they
/// are; an array-backed validity slices its array.
///
/// # Errors
///
/// Returns the error produced by slicing the underlying array.
#[inline]
pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
    let Self::Array(mask) = self else {
        return Ok(self.clone());
    };
    Ok(Self::Array(mask.slice(range)?))
}
/// Computes the validity that results from taking `indices` from this validity.
///
/// * `AllInvalid` stays `AllInvalid`; the indices are not consulted.
/// * `NonNullable` yields the validity of `indices` unchanged.
/// * `AllValid` yields the validity of `indices`, made nullable.
/// * `Array` takes from the validity array and fills the nulls of the result
///   with `false`.
///
/// # Errors
///
/// Returns any error from reading the validity of `indices`, from the take, or
/// from the null fill.
pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
    match self {
        Self::AllInvalid => Ok(Self::AllInvalid),
        Self::NonNullable => indices.validity(),
        Self::AllValid => {
            // The result must stay nullable, so a `NonNullable` index validity is
            // reported as `AllValid`; everything else passes through untouched.
            let index_validity = indices.validity()?;
            Ok(index_validity.into_nullable())
        }
        Self::Array(mask) => {
            let taken = mask.take(indices.clone())?;
            let filled = taken.fill_null(Scalar::from(false))?;
            Ok(Self::Array(filled))
        }
    }
}
/// Logically inverts the validity.
///
/// `AllValid` and `AllInvalid` swap, an array-backed validity negates its array,
/// and `NonNullable` is returned unchanged.
///
/// # Errors
///
/// Returns the error produced by negating the underlying array.
pub fn not(&self) -> VortexResult<Self> {
    let inverted = match self {
        Validity::Array(mask) => Validity::Array(mask.not()?),
        Validity::AllInvalid => Validity::AllValid,
        Validity::AllValid => Validity::AllInvalid,
        Validity::NonNullable => Validity::NonNullable,
    };
    Ok(inverted)
}
/// Keeps only the positions selected by `mask`.
///
/// The constant variants are cloned as they are; an array-backed validity
/// filters its array with a clone of `mask`.
///
/// # Errors
///
/// Returns the error produced by filtering the underlying array.
pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
    let Validity::Array(values) = self else {
        return Ok(self.clone());
    };
    let kept = values.filter(mask.clone())?;
    Ok(Validity::Array(kept))
}
/// Converts the validity into a [`Mask`] of `length` elements.
///
/// Unlike [`Validity::execute_mask`], this does not check that an array-backed
/// validity has exactly `length` elements.
///
/// # Errors
///
/// Returns the error produced by executing the underlying array into a mask.
#[deprecated(note = "Use execute_mask")]
pub fn to_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
    let mask = match self {
        Self::Array(values) => values.clone().execute::<Mask>(ctx)?,
        Self::AllInvalid => Mask::new_false(length),
        Self::NonNullable | Self::AllValid => Mask::new_true(length),
    };
    Ok(mask)
}
/// Converts the validity into a [`Mask`] of `length` elements.
///
/// The constant variants map directly to `Mask::AllTrue` / `Mask::AllFalse`.
/// An array-backed validity is executed into a mask.
///
/// # Errors
///
/// Returns the error produced by executing the underlying array into a mask.
///
/// # Panics
///
/// Panics if an array-backed validity does not have exactly `length` elements.
pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
    match self {
        Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
        Self::AllInvalid => Ok(Mask::AllFalse(length)),
        Self::Array(values) => {
            // The message names this method; it previously referred to a
            // `to_logical` function that this file does not define.
            assert_eq!(
                values.len(),
                length,
                "Validity::Array length must equal execute_mask's length argument: {}, {}.",
                values.len(),
                length,
            );
            values.clone().execute::<Mask>(ctx)
        }
    }
}
/// Compares two validities.
///
/// Two array-backed validities are equal when the masks they execute to are
/// equal. Any other pair is equal only when both sides are the same variant, so
/// `AllValid` never equals `NonNullable`, and a constant never equals an array
/// even if that array happens to be uniform.
///
/// # Errors
///
/// Returns the error produced by executing either array into a mask.
pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
    if let (Validity::Array(lhs), Validity::Array(rhs)) = (self, other) {
        let lhs_mask = lhs.clone().execute::<Mask>(ctx)?;
        let rhs_mask = rhs.clone().execute::<Mask>(ctx)?;
        return Ok(lhs_mask == rhs_mask);
    }
    // At most one side is an array here, so equality reduces to "same variant".
    Ok(std::mem::discriminant(self) == std::mem::discriminant(other))
}
/// Combines two validities with a logical AND: a position is valid only when it
/// is valid on both sides.
///
/// * `AllInvalid` on either side gives `AllInvalid`.
/// * Two `NonNullable` inputs give `NonNullable`.
/// * An array combined with `AllValid`/`NonNullable` gives that array unchanged.
/// * Any other array-free mix of `AllValid`/`NonNullable` gives `AllValid`.
/// * Two arrays give a `Binary` AND expression over both, run through `optimize`.
///
/// # Errors
///
/// Returns the error produced while building or optimizing the AND expression.
#[inline]
pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
    let combined = match (self, rhs) {
        (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
        (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
        (Validity::Array(lhs), Validity::Array(rhs)) => {
            let len = lhs.len();
            let conjunction = Binary.try_new_array(len, Operator::And, [lhs, rhs])?;
            Validity::Array(conjunction.optimize()?)
        }
        (Validity::Array(mask), Validity::AllValid | Validity::NonNullable)
        | (Validity::AllValid | Validity::NonNullable, Validity::Array(mask)) => {
            Validity::Array(mask)
        }
        // `(NonNullable, NonNullable)` was matched above, so at least one side
        // is `AllValid` by the time this arm is reached.
        (
            Validity::AllValid | Validity::NonNullable,
            Validity::AllValid | Validity::NonNullable,
        ) => Validity::AllValid,
    };
    Ok(combined)
}
/// Overwrites the validity at the given `indices` with the values of `patches`.
///
/// * `len` — number of elements used to materialise a constant receiver; also
///   passed to `Patches::new`.
/// * `indices_offset` — passed through to `Patches::new`.
/// * `indices` — the positions to overwrite.
/// * `patches` — the validity of the patch values; a constant variant is
///   materialised with one entry per index.
///
/// Identical constant validities on both sides return immediately without
/// building any array. Otherwise both sides are materialised as `BoolArray`s and
/// the result is array-backed.
///
/// # Errors
///
/// Fails when exactly one of the two validities is `NonNullable`, and forwards
/// any error from executing the arrays, from `Patches::new`, or from applying
/// the patches.
pub fn patch(
    self,
    len: usize,
    indices_offset: usize,
    indices: &ArrayRef,
    patches: &Validity,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Self> {
    match (&self, patches) {
        (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
        (Validity::NonNullable, _) => {
            vortex_bail!("Can't patch a non-nullable validity with nullable validity")
        }
        (_, Validity::NonNullable) => {
            vortex_bail!("Can't patch a nullable validity with non-nullable validity")
        }
        (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
        (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
        _ => {}
    };

    // Every pair involving `NonNullable` has returned or bailed above, so no
    // further guard is needed. The variant shares the all-set arm below only to
    // keep the matches exhaustive.
    let source = match self {
        Validity::NonNullable | Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
        Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
        Validity::Array(values) => values.execute::<BoolArray>(ctx)?,
    };

    let patch_count = indices.len();
    let patch_values = match patches {
        Validity::NonNullable | Validity::AllValid => {
            BoolArray::from(BitBuffer::new_set(patch_count))
        }
        Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(patch_count)),
        Validity::Array(values) => values.clone().execute::<BoolArray>(ctx)?,
    };

    let patch_set = Patches::new(
        len,
        indices_offset,
        indices.clone(),
        patch_values.into_array(),
        None,
    )?;
    Ok(Self::Array(source.patch(&patch_set, ctx)?.into_array()))
}
/// Converts the validity into its nullable form.
///
/// `NonNullable` becomes `AllValid`; every other variant is returned as it is.
#[inline]
pub fn into_nullable(self) -> Validity {
    if matches!(self, Self::NonNullable) {
        Self::AllValid
    } else {
        self
    }
}
/// Tries to convert the validity into `NonNullable`.
///
/// Returns `Some(Validity::NonNullable)` when `len` is zero (whatever the
/// variant), when the validity is `NonNullable` or `AllValid`, or when the
/// minimum of an array-backed validity is `true`. Returns `None` otherwise.
///
/// # Panics
///
/// Panics if the minimum of the validity array cannot be computed.
#[inline]
pub fn into_non_nullable(self, len: usize, ctx: &mut ExecutionCtx) -> Option<Validity> {
    // An empty array has no nulls, so the variant does not matter.
    if len == 0 {
        return Some(Validity::NonNullable);
    }
    match self {
        Self::NonNullable | Self::AllValid => Some(Self::NonNullable),
        Self::AllInvalid => None,
        Self::Array(mask) => {
            // A minimum of `true` means the array holds no `false` entry.
            let all_valid = mask
                .statistics()
                .compute_min::<bool>(ctx)
                .vortex_expect("validity array must support min");
            all_valid.then_some(Self::NonNullable)
        }
    }
}
/// Tries to convert the validity into `NonNullable` without inspecting any
/// array contents.
///
/// Returns `Ok(Some(NonNullable))` when `len` is zero or the validity is
/// `NonNullable`/`AllValid`, and `Ok(None)` for an array-backed validity, whose
/// answer is not decided here.
///
/// # Errors
///
/// Returns an `InvalidArgument` error for `AllInvalid` with a non-zero `len`.
#[inline]
pub fn trivial_into_non_nullable(self, len: usize) -> VortexResult<Option<Validity>> {
    if len == 0 {
        return Ok(Some(Validity::NonNullable));
    }
    match self {
        Self::NonNullable | Self::AllValid => Ok(Some(Self::NonNullable)),
        Self::Array(_) => Ok(None),
        Self::AllInvalid => {
            Err(vortex_err!(InvalidArgument: "Cannot cast AllInvalid to NonNullable"))
        }
    }
}
/// Casts the validity to the requested `nullability`.
///
/// A cast to `Nullable` goes through [`Validity::into_nullable`]; a cast to
/// `NonNullable` goes through [`Validity::into_non_nullable`].
///
/// # Errors
///
/// Returns an `InvalidArgument` error when a cast to `NonNullable` is requested
/// but [`Validity::into_non_nullable`] returns `None`.
#[inline]
pub fn cast_nullability(
    self,
    nullability: Nullability,
    len: usize,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Validity> {
    match nullability {
        Nullability::Nullable => Ok(self.into_nullable()),
        Nullability::NonNullable => match self.into_non_nullable(len, ctx) {
            Some(validity) => Ok(validity),
            None => Err(
                vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type."),
            ),
        },
    }
}
/// Casts the validity to the requested `nullability` without inspecting any
/// array contents.
///
/// A cast to `Nullable` always produces `Some`. A cast to `NonNullable`
/// delegates to [`Validity::trivial_into_non_nullable`], so it may yield
/// `Ok(None)` for an array-backed validity.
///
/// # Errors
///
/// Forwards the error from [`Validity::trivial_into_non_nullable`].
#[inline]
pub fn trivial_cast_nullability(
    self,
    nullability: Nullability,
    len: usize,
) -> VortexResult<Option<Validity>> {
    if matches!(nullability, Nullability::NonNullable) {
        self.trivial_into_non_nullable(len)
    } else {
        Ok(Some(self.into_nullable()))
    }
}
#[inline]
pub fn maybe_len(&self) -> Option<usize> {
match self {
Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
Self::Array(a) => Some(a.len()),
}
}
}
/// Builds a validity from a bit buffer in which set bits mark valid positions.
///
/// A buffer whose bits are all set, including an empty one, becomes `AllValid`;
/// a non-empty buffer with no set bits becomes `AllInvalid`; anything else is
/// wrapped in a `BoolArray`.
impl From<BitBuffer> for Validity {
    #[inline]
    fn from(value: BitBuffer) -> Self {
        match value.true_count() {
            // Checked first so that an empty buffer is reported as `AllValid`.
            set_bits if set_bits == value.len() => Self::AllValid,
            0 => Self::AllInvalid,
            _ => Self::Array(BoolArray::from(value).into_array()),
        }
    }
}
/// Collects a sequence of masks into one nullable validity.
impl FromIterator<Mask> for Validity {
    #[inline]
    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
        // First collect the masks into a single mask, then convert that.
        let combined: Mask = iter.into_iter().collect();
        Validity::from_mask(combined, Nullability::Nullable)
    }
}
/// Collects booleans (`true` = valid) into a validity by way of a [`BitBuffer`].
impl FromIterator<bool> for Validity {
    #[inline]
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        let bits = BitBuffer::from_iter(iter);
        Self::from(bits)
    }
}
impl From<Nullability> for Validity {
#[inline]
fn from(value: Nullability) -> Self {
Validity::from(&value)
}
}
/// Maps a nullability to the validity with no nulls of that nullability:
/// `NonNullable` to `Validity::NonNullable`, `Nullable` to `Validity::AllValid`.
impl From<&Nullability> for Validity {
    #[inline]
    fn from(value: &Nullability) -> Self {
        match value {
            Nullability::Nullable => Self::AllValid,
            Nullability::NonNullable => Self::NonNullable,
        }
    }
}
impl Validity {
    /// Concatenates a sequence of `(validity, length)` pairs into one validity.
    ///
    /// Returns `None` for an empty input. When every entry is the same constant
    /// variant, that variant is returned directly. In every other case (mixed
    /// variants, or all array-backed) each entry is materialised with
    /// [`Validity::to_array`] and the pieces are wrapped in a [`ChunkedArray`].
    pub fn concat(validities: Vec<(Validity, usize)>) -> Option<Self> {
        let mut kinds = validities
            .iter()
            .map(|(validity, _)| std::mem::discriminant(validity))
            .unique();
        let first_kind = kinds.next()?;
        let is_uniform = kinds.next().is_none();

        if is_uniform {
            for constant in [
                Validity::AllValid,
                Validity::AllInvalid,
                Validity::NonNullable,
            ] {
                if first_kind == std::mem::discriminant(&constant) {
                    return Some(constant);
                }
            }
        }

        let chunks = validities
            .into_iter()
            .map(|(validity, len)| validity.to_array(len))
            .collect();
        // SAFETY: NOTE(review) — the unchecked constructor is given a non-nullable
        // bool dtype. The chunks are boolean constants built by `to_array` or the
        // validity arrays supplied by the caller, which are presumed to have that
        // dtype already. Confirm against `ChunkedArray::new_unchecked`'s contract.
        let chunked = unsafe {
            ChunkedArray::new_unchecked(chunks, DType::Bool(Nullability::NonNullable))
        };
        Some(Validity::Array(chunked.into_array()))
    }
}
impl Validity {
    /// Builds a validity from a bit buffer in which set bits mark valid positions.
    ///
    /// A buffer whose bits are all set, including an empty one, becomes the
    /// no-nulls validity for `nullability` (`NonNullable` or `AllValid`). A
    /// non-empty buffer with no set bits becomes `AllInvalid`. Anything else is
    /// wrapped in a non-nullable `BoolArray`.
    ///
    /// `nullability` is not checked against the buffer contents here: a buffer
    /// containing unset bits produces `AllInvalid` or `Array` even when
    /// `NonNullable` was requested.
    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
        // Count once; both comparisons below need the same number.
        let true_count = buffer.true_count();
        if true_count == buffer.len() {
            nullability.into()
        } else if true_count == 0 {
            Validity::AllInvalid
        } else {
            Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
        }
    }

    /// Builds a validity from a [`Mask`] in which `true` marks valid positions.
    ///
    /// `Mask::AllTrue` becomes the no-nulls validity for `nullability`,
    /// `Mask::AllFalse` becomes `AllInvalid`, and `Mask::Values` is converted
    /// into an array-backed validity.
    ///
    /// # Panics
    ///
    /// Panics if `nullability` is `NonNullable` and `mask` is anything other
    /// than `Mask::AllTrue`.
    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
        assert!(
            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
            "NonNullable validity must be AllValid",
        );
        match mask {
            // Same mapping as `From<Nullability>`: NonNullable -> NonNullable,
            // Nullable -> AllValid.
            Mask::AllTrue(_) => nullability.into(),
            Mask::AllFalse(_) => Validity::AllInvalid,
            Mask::Values(values) => Validity::Array(values.into_array()),
        }
    }
}
/// Converts a mask into a boolean array: a constant array for the uniform
/// masks, and the array built from the mask values otherwise.
impl IntoArray for Mask {
    #[inline]
    fn into_array(self) -> ArrayRef {
        let (fill, len) = match self {
            Self::Values(values) => return values.into_array(),
            Self::AllTrue(len) => (true, len),
            Self::AllFalse(len) => (false, len),
        };
        ConstantArray::new(fill, len).into_array()
    }
}
/// Wraps a clone of the mask's bit buffer in a non-nullable `BoolArray`.
impl IntoArray for &MaskValues {
    #[inline]
    fn into_array(self) -> ArrayRef {
        let bits = self.bit_buffer().clone();
        BoolArray::new(bits, Validity::NonNullable).into_array()
    }
}
// Unit tests for `Validity::patch`, `Validity::from_mask` and `Validity::take`.
#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_buffer::Buffer;
use vortex_buffer::buffer;
use vortex_mask::Mask;
use crate::ArrayRef;
use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::VortexSessionExecute;
use crate::arrays::PrimitiveArray;
use crate::dtype::Nullability;
use crate::validity::BoolArray;
use crate::validity::Validity;
// Patches positions 2 and 4 of a length-5 validity and compares the outcome with
// the expected validity through `mask_eq`.
// Case layout: (receiver, length, patch positions, validity of the patch values,
// expected result).
#[rstest]
// Identical constants on both sides stay constant.
#[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
// A constant receiver patched with something different becomes array-backed.
#[case(
Validity::AllValid,
5,
&[2, 4],
Validity::AllInvalid,
Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
)]
#[case(
Validity::AllValid,
5,
&[2, 4],
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
)]
#[case(
Validity::AllInvalid,
5,
&[2, 4],
Validity::AllValid,
Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
)]
#[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
#[case(
Validity::AllInvalid,
5,
&[2, 4],
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
)]
// Array-backed receivers: only positions 2 and 4 are overwritten.
#[case(
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
5,
&[2, 4],
Validity::AllValid,
Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
)]
#[case(
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
5,
&[2, 4],
Validity::AllInvalid,
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
)]
#[case(
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
5,
&[2, 4],
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
)]
fn patch_validity(
#[case] validity: Validity,
#[case] len: usize,
#[case] positions: &[u64],
#[case] patches: Validity,
#[case] expected: Validity,
) {
// The patch positions are supplied as a non-nullable primitive array.
let indices =
PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
let mut ctx = LEGACY_SESSION.create_execution_ctx();
assert!(
validity
.patch(len, 0, &indices, &patches, &mut ctx,)
.unwrap()
.mask_eq(&expected, &mut ctx)
.unwrap()
);
}
// NOTE(review): with a `NonNullable` receiver and `AllInvalid` patches, `patch`
// returns its nullability-mismatch error before any index is looked at, so the
// panic asserted here comes from `unwrap()` on that error, not from a bounds check
// on index 4.
#[test]
#[should_panic]
fn out_of_bounds_patch() {
let mut ctx = LEGACY_SESSION.create_execution_ctx();
Validity::NonNullable
.patch(
2,
0,
&buffer![4].into_array(),
&Validity::AllInvalid,
&mut ctx,
)
.unwrap();
}
// `from_mask` asserts that a `NonNullable` validity is only built from an all-true
// mask; an all-false mask must panic.
#[test]
#[should_panic]
fn into_validity_nullable() {
Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
}
// Same assertion as above, exercised with a mask built from mixed values.
#[test]
#[should_panic]
fn into_validity_nullable_array() {
Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
}
// Takes two indices from a constant validity and checks the result against the
// expected validity through `mask_eq`.
// Case layout: (receiver, indices array, expected result).
#[rstest]
// `AllValid` receiver: the result follows the validity of the indices, kept nullable.
#[case(
Validity::AllValid,
PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
Validity::from_iter(vec![true, false])
)]
#[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
#[case(
Validity::AllValid,
PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
Validity::AllInvalid
)]
// `NonNullable` receiver: the result is the validity of the indices as it is.
#[case(
Validity::NonNullable,
PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
Validity::from_iter(vec![true, false])
)]
#[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
#[case(
Validity::NonNullable,
PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
Validity::AllInvalid
)]
fn validity_take(
#[case] validity: Validity,
#[case] indices: ArrayRef,
#[case] expected: Validity,
) {
let mut ctx = LEGACY_SESSION.create_execution_ctx();
assert!(
validity
.take(&indices)
.unwrap()
.mask_eq(&expected, &mut ctx)
.unwrap()
);
}
}