use std::sync::Arc;
use arrow_buffer::BooleanBuffer;
use vortex_buffer::{Buffer, buffer};
use vortex_dtype::{DType, Nullability, match_each_native_ptype};
use vortex_error::VortexExpect;
use vortex_scalar::{
BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, StructScalar,
Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
};
use crate::arrays::binary_view::BinaryView;
use crate::arrays::constant::ConstantArray;
use crate::arrays::primitive::PrimitiveArray;
use crate::arrays::{
BoolArray, ConstantVTable, DecimalArray, ExtensionArray, FixedSizeListArray, ListViewArray,
NullArray, StructArray, VarBinViewArray, smallest_decimal_value_type,
};
use crate::builders::builder_with_capacity;
use crate::validity::Validity;
use crate::vtable::CanonicalVTable;
use crate::{Canonical, IntoArray};
impl CanonicalVTable<ConstantVTable> for ConstantVTable {
fn canonicalize(array: &ConstantArray) -> Canonical {
let scalar = array.scalar();
let validity = match array.dtype().nullability() {
Nullability::NonNullable => Validity::NonNullable,
Nullability::Nullable => match scalar.is_null() {
true => Validity::AllInvalid,
false => Validity::AllValid,
},
};
match array.dtype() {
DType::Null => Canonical::Null(NullArray::new(array.len())),
DType::Bool(..) => Canonical::Bool(BoolArray::from_bool_buffer(
if BoolScalar::try_from(scalar)
.vortex_expect("must be bool")
.value()
.unwrap_or_default()
{
BooleanBuffer::new_set(array.len())
} else {
BooleanBuffer::new_unset(array.len())
},
validity,
)),
DType::Primitive(ptype, ..) => {
match_each_native_ptype!(ptype, |P| {
Canonical::Primitive(PrimitiveArray::new(
if scalar.is_valid() {
Buffer::full(
P::try_from(scalar)
.vortex_expect("Couldn't unwrap scalar to primitive"),
array.len(),
)
} else {
Buffer::zeroed(array.len())
},
validity,
))
})
}
DType::Decimal(decimal_type, ..) => {
let size = smallest_decimal_value_type(decimal_type);
let decimal = scalar.as_decimal();
let Some(value) = decimal.decimal_value() else {
let all_null = match_each_decimal_value_type!(size, |D| {
unsafe {
DecimalArray::new_unchecked(
Buffer::<D>::zeroed(array.len()),
*decimal_type,
validity,
)
}
});
return Canonical::Decimal(all_null);
};
let decimal_array = match_each_decimal_value!(value, |value| {
unsafe {
DecimalArray::new_unchecked(
Buffer::full(value, array.len()),
*decimal_type,
validity,
)
}
});
Canonical::Decimal(decimal_array)
}
DType::Utf8(_) => {
let value = Utf8Scalar::try_from(scalar)
.vortex_expect("Must be a utf8 scalar")
.value();
let const_value = value.as_ref().map(|v| v.as_bytes());
Canonical::VarBinView(constant_canonical_byte_view(
const_value,
array.dtype(),
array.len(),
))
}
DType::Binary(_) => {
let value = BinaryScalar::try_from(scalar)
.vortex_expect("must be a binary scalar")
.value();
let const_value = value.as_ref().map(|v| v.as_slice());
Canonical::VarBinView(constant_canonical_byte_view(
const_value,
array.dtype(),
array.len(),
))
}
DType::Struct(struct_dtype, _) => {
let value = StructScalar::try_from(scalar).vortex_expect("must be struct");
let fields: Vec<_> = match value.fields() {
Some(fields) => fields
.into_iter()
.map(|s| ConstantArray::new(s, array.len()).into_array())
.collect(),
None => {
assert!(validity.all_invalid(array.len()));
struct_dtype
.fields()
.map(|dt| {
let scalar = Scalar::default_value(dt);
ConstantArray::new(scalar, array.len()).into_array()
})
.collect()
}
};
Canonical::Struct(unsafe {
StructArray::new_unchecked(fields, struct_dtype.clone(), array.len(), validity)
})
}
DType::List(..) => Canonical::List(constant_canonical_list_array(scalar, array.len())),
DType::FixedSizeList(element_dtype, list_size, _) => {
let value = ListScalar::try_from(scalar).vortex_expect("must be list");
Canonical::FixedSizeList(constant_canonical_fixed_size_list_array(
value.elements(),
element_dtype,
*list_size,
value.dtype().nullability(),
array.len(),
))
}
DType::Extension(ext_dtype) => {
let s = ExtScalar::try_from(scalar).vortex_expect("must be an extension scalar");
let storage_scalar = s.storage();
let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
}
}
}
}
fn constant_canonical_byte_view(
scalar_bytes: Option<&[u8]>,
dtype: &DType,
len: usize,
) -> VarBinViewArray {
match scalar_bytes {
None => {
let views = buffer![BinaryView::from(0_u128); len];
unsafe {
VarBinViewArray::new_unchecked(
views,
Default::default(),
dtype.clone(),
Validity::AllInvalid,
)
}
}
Some(scalar_bytes) => {
let view = BinaryView::make_view(scalar_bytes, 0, 0);
let mut buffers = Vec::new();
if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
buffers.push(Buffer::copy_from(scalar_bytes));
}
let views = buffer![view; len];
unsafe {
VarBinViewArray::new_unchecked(
views,
Arc::from(buffers),
dtype.clone(),
Validity::from(dtype.nullability()),
)
}
}
}
}
fn constant_canonical_list_array(scalar: &Scalar, len: usize) -> ListViewArray {
let list = ListScalar::try_from(scalar).vortex_expect("must be list");
let elements = if let Some(elements) = list.elements() {
let mut builder = builder_with_capacity(
list.dtype()
.as_list_element_opt()
.vortex_expect("list scalar somehow did not have a list DType"),
list.len(),
);
for scalar in &elements {
builder
.append_scalar(scalar)
.vortex_expect("list element scalar was invalid");
}
builder.finish()
} else {
Canonical::empty(list.element_dtype()).into_array()
};
let validity = if scalar.dtype().is_nullable() {
if list.is_null() {
Validity::AllInvalid
} else {
Validity::AllValid
}
} else {
debug_assert!(!list.is_null());
Validity::NonNullable
};
let offsets = ConstantArray::new::<u64>(0, len).into_array();
let sizes = ConstantArray::new::<u64>(list.len() as u64, len).into_array();
debug_assert!(!offsets.dtype().is_nullable());
debug_assert!(!sizes.dtype().is_nullable());
unsafe { ListViewArray::new_unchecked(elements, offsets, sizes, validity) }
}
fn constant_canonical_fixed_size_list_array(
values: Option<Vec<Scalar>>,
element_dtype: &DType,
list_size: u32,
list_nullability: Nullability,
len: usize,
) -> FixedSizeListArray {
match values {
None => {
let elements_len = list_size as usize * len;
let mut element_builder = builder_with_capacity(element_dtype, elements_len);
element_builder.append_defaults(elements_len);
let elements = element_builder.finish();
unsafe {
FixedSizeListArray::new_unchecked(elements, list_size, Validity::AllInvalid, len)
}
}
Some(values) => {
let mut elements_builder = builder_with_capacity(element_dtype, len * values.len());
for _ in 0..len {
for v in &values {
elements_builder
.append_scalar(v)
.vortex_expect("must be a same dtype");
}
}
let elements = elements_builder.finish();
let validity = Validity::from(list_nullability);
unsafe { FixedSizeListArray::new_unchecked(elements, list_size, validity, len) }
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use enum_iterator::all;
use itertools::Itertools;
use vortex_dtype::half::f16;
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;
use crate::arrays::{ConstantArray, ListViewRebuildMode};
use crate::canonical::ToCanonical;
use crate::stats::{Stat, StatsProvider};
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
use crate::{Array, IntoArray};
#[test]
fn test_canonicalize_null() {
let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
let actual = const_null.to_null();
assert_eq!(actual.len(), 42);
assert_eq!(actual.scalar_at(33), Scalar::null(DType::Null));
}
#[test]
fn test_canonicalize_const_str() {
let const_array = ConstantArray::new("four".to_string(), 4);
let canonical = const_array.to_varbinview();
assert_eq!(canonical.len(), 4);
for i in 0..=3 {
assert_eq!(canonical.scalar_at(i), "four".into());
}
}
#[test]
fn test_canonicalize_propagates_stats() {
let scalar = Scalar::bool(true, Nullability::NonNullable);
let const_array = ConstantArray::new(scalar, 4).into_array();
let stats = const_array
.statistics()
.compute_all(&all::<Stat>().collect_vec())
.unwrap();
let canonical = const_array.to_canonical();
let canonical_stats = canonical.as_ref().statistics();
let stats_ref = stats.as_typed_ref(canonical.as_ref().dtype());
for stat in all::<Stat>() {
if stat.dtype(canonical.as_ref().dtype()).is_none() {
continue;
}
assert_eq!(
canonical_stats.get(stat),
stats_ref.get(stat),
"stat mismatch {stat}"
);
}
}
#[test]
fn test_canonicalize_scalar_values() {
let f16_value = f16::from_f32(5.722046e-6);
let f16_scalar = Scalar::primitive(f16_value, Nullability::NonNullable);
let const_array = ConstantArray::new(f16_scalar.clone(), 1).into_array();
let canonical_const = const_array.to_primitive();
assert_eq!(canonical_const.scalar_at(0), f16_scalar);
}
#[test]
fn test_canonicalize_lists() {
let list_scalar = Scalar::list(
Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
vec![1u64.into(), 2u64.into()],
Nullability::NonNullable,
);
let const_array = ConstantArray::new(list_scalar, 2).into_array();
let canonical_const = const_array.to_listview();
let list_array = canonical_const.rebuild(ListViewRebuildMode::MakeZeroCopyToList);
assert_eq!(
list_array.elements().to_primitive().as_slice::<u64>(),
[1u64, 2, 1, 2]
);
assert_eq!(
list_array.offsets().to_primitive().as_slice::<u64>(),
[0u64, 2]
);
assert_eq!(
list_array.sizes().to_primitive().as_slice::<u64>(),
[2u64, 2]
);
}
#[test]
fn test_canonicalize_empty_list() {
let list_scalar = Scalar::list(
Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
vec![],
Nullability::NonNullable,
);
let const_array = ConstantArray::new(list_scalar, 2).into_array();
let canonical_const = const_array.to_listview();
assert!(canonical_const.elements().to_primitive().is_empty());
assert_eq!(
canonical_const.offsets().to_primitive().as_slice::<u64>(),
[0u64, 0]
);
assert_eq!(
canonical_const.sizes().to_primitive().as_slice::<u64>(),
[0u64, 0]
);
}
#[test]
fn test_canonicalize_null_list() {
let list_scalar = Scalar::null(DType::List(
Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
Nullability::Nullable,
));
let const_array = ConstantArray::new(list_scalar, 2).into_array();
let canonical_const = const_array.to_listview();
assert!(canonical_const.elements().to_primitive().is_empty());
assert_eq!(
canonical_const.offsets().to_primitive().as_slice::<u64>(),
[0u64, 0]
);
assert_eq!(
canonical_const.sizes().to_primitive().as_slice::<u64>(),
[0u64, 0]
);
}
#[test]
fn test_canonicalize_nullable_struct() {
let array = ConstantArray::new(
Scalar::null(DType::struct_(
[(
"non_null_field",
DType::Primitive(PType::I8, Nullability::NonNullable),
)],
Nullability::Nullable,
)),
3,
);
let struct_array = array.to_struct();
assert_eq!(struct_array.len(), 3);
assert_eq!(struct_array.valid_count(), 0);
let field = struct_array.field_by_name("non_null_field").unwrap();
assert_eq!(
field.dtype(),
&DType::Primitive(PType::I8, Nullability::NonNullable)
);
}
#[test]
fn test_canonicalize_fixed_size_list_non_null() {
let fsl_scalar = Scalar::fixed_size_list(
Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
vec![
Scalar::primitive(10i32, Nullability::NonNullable),
Scalar::primitive(20i32, Nullability::NonNullable),
Scalar::primitive(30i32, Nullability::NonNullable),
],
Nullability::NonNullable,
);
let const_array = ConstantArray::new(fsl_scalar, 4).into_array();
let canonical = const_array.to_fixed_size_list();
assert_eq!(canonical.len(), 4);
assert_eq!(canonical.list_size(), 3);
assert_eq!(canonical.validity(), &Validity::NonNullable);
for i in 0..4 {
let list = canonical.fixed_size_list_elements_at(i);
let list_primitive = list.to_primitive();
assert_eq!(list_primitive.as_slice::<i32>(), [10, 20, 30]);
}
}
#[test]
fn test_canonicalize_fixed_size_list_nullable() {
let fsl_scalar = Scalar::fixed_size_list(
Arc::new(DType::Primitive(PType::F64, Nullability::NonNullable)),
vec![
Scalar::primitive(1.5f64, Nullability::NonNullable),
Scalar::primitive(2.5f64, Nullability::NonNullable),
],
Nullability::Nullable,
);
let const_array = ConstantArray::new(fsl_scalar, 3).into_array();
let canonical = const_array.to_fixed_size_list();
assert_eq!(canonical.len(), 3);
assert_eq!(canonical.list_size(), 2);
assert_eq!(canonical.validity(), &Validity::AllValid);
let elements = canonical.elements().to_primitive();
assert_eq!(elements.as_slice::<f64>(), [1.5, 2.5, 1.5, 2.5, 1.5, 2.5]);
}
#[test]
fn test_canonicalize_fixed_size_list_null() {
let fsl_scalar = Scalar::null(DType::FixedSizeList(
Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
4,
Nullability::Nullable,
));
let const_array = ConstantArray::new(fsl_scalar, 5).into_array();
let canonical = const_array.to_fixed_size_list();
assert_eq!(canonical.len(), 5);
assert_eq!(canonical.list_size(), 4);
assert_eq!(canonical.validity(), &Validity::AllInvalid);
let elements = canonical.elements().to_primitive();
assert_eq!(elements.len(), 20); assert!(elements.as_slice::<u64>().iter().all(|&x| x == 0));
}
#[test]
fn test_canonicalize_fixed_size_list_empty() {
let fsl_scalar = Scalar::fixed_size_list(
Arc::new(DType::Primitive(PType::I8, Nullability::NonNullable)),
vec![],
Nullability::NonNullable,
);
let const_array = ConstantArray::new(fsl_scalar, 10).into_array();
let canonical = const_array.to_fixed_size_list();
assert_eq!(canonical.len(), 10);
assert_eq!(canonical.list_size(), 0);
assert_eq!(canonical.validity(), &Validity::NonNullable);
assert!(canonical.elements().is_empty());
}
#[test]
fn test_canonicalize_fixed_size_list_nested() {
let fsl_scalar = Scalar::fixed_size_list(
Arc::new(DType::Utf8(Nullability::NonNullable)),
vec![Scalar::from("hello"), Scalar::from("world")],
Nullability::NonNullable,
);
let const_array = ConstantArray::new(fsl_scalar, 2).into_array();
let canonical = const_array.to_fixed_size_list();
assert_eq!(canonical.len(), 2);
assert_eq!(canonical.list_size(), 2);
let elements = canonical.elements().to_varbinview();
assert_eq!(elements.scalar_at(0), "hello".into());
assert_eq!(elements.scalar_at(1), "world".into());
assert_eq!(elements.scalar_at(2), "hello".into());
assert_eq!(elements.scalar_at(3), "world".into());
}
#[test]
fn test_canonicalize_fixed_size_list_single_element() {
let fsl_scalar = Scalar::fixed_size_list(
Arc::new(DType::Primitive(PType::I16, Nullability::NonNullable)),
vec![Scalar::primitive(42i16, Nullability::NonNullable)],
Nullability::NonNullable,
);
let const_array = ConstantArray::new(fsl_scalar, 1).into_array();
let canonical = const_array.to_fixed_size_list();
assert_eq!(canonical.len(), 1);
assert_eq!(canonical.list_size(), 1);
let elements = canonical.elements().to_primitive();
assert_eq!(elements.as_slice::<i16>(), [42]);
}
#[test]
fn test_canonicalize_fixed_size_list_with_null_elements() {
let fsl_scalar = Scalar::fixed_size_list(
Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
vec![
Scalar::primitive(100i32, Nullability::Nullable),
Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
Scalar::primitive(200i32, Nullability::Nullable),
],
Nullability::NonNullable,
);
let const_array = ConstantArray::new(fsl_scalar, 3).into_array();
let canonical = const_array.to_fixed_size_list();
assert_eq!(canonical.len(), 3);
assert_eq!(canonical.list_size(), 3);
assert_eq!(canonical.validity(), &Validity::NonNullable);
let elements = canonical.elements().to_primitive();
assert_eq!(elements.as_slice::<i32>()[0], 100);
assert_eq!(elements.as_slice::<i32>()[1], 0); assert_eq!(elements.as_slice::<i32>()[2], 200);
let element_validity = elements.validity();
assert!(element_validity.is_valid(0));
assert!(!element_validity.is_valid(1));
assert!(element_validity.is_valid(2));
assert!(element_validity.is_valid(3));
assert!(!element_validity.is_valid(4));
assert!(element_validity.is_valid(5));
}
#[test]
fn test_canonicalize_fixed_size_list_large() {
let fsl_scalar = Scalar::fixed_size_list(
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
vec![
Scalar::primitive(1u8, Nullability::NonNullable),
Scalar::primitive(2u8, Nullability::NonNullable),
Scalar::primitive(3u8, Nullability::NonNullable),
Scalar::primitive(4u8, Nullability::NonNullable),
Scalar::primitive(5u8, Nullability::NonNullable),
],
Nullability::NonNullable,
);
let const_array = ConstantArray::new(fsl_scalar, 1000).into_array();
let canonical = const_array.to_fixed_size_list();
assert_eq!(canonical.len(), 1000);
assert_eq!(canonical.list_size(), 5);
let elements = canonical.elements().to_primitive();
assert_eq!(elements.len(), 5000);
for i in 0..1000 {
let base = i * 5;
assert_eq!(elements.as_slice::<u8>()[base], 1);
assert_eq!(elements.as_slice::<u8>()[base + 1], 2);
assert_eq!(elements.as_slice::<u8>()[base + 2], 3);
assert_eq!(elements.as_slice::<u8>()[base + 3], 4);
assert_eq!(elements.as_slice::<u8>()[base + 4], 5);
}
}
}