use arrow_buffer::BooleanBufferBuilder;
use itertools::Itertools;
use vortex_buffer::{Buffer, BufferMut, ByteBuffer};
use vortex_dtype::{DType, DecimalDType, IntegerPType, match_each_integer_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_ensure, vortex_panic};
use vortex_scalar::{BigCast, DecimalValueType, NativeDecimalType, match_each_decimal_value_type};
use crate::ToCanonical;
use crate::arrays::is_compatible_decimal_value_type;
use crate::patches::Patches;
use crate::stats::ArrayStats;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
/// An array of fixed-point decimal values.
///
/// The element values are kept as untyped bytes in `values` and are interpreted according to
/// `values_type`. The logical precision, scale and nullability live in `dtype`.
#[derive(Clone, Debug)]
pub struct DecimalArray {
    /// Logical type of the array. `new_unchecked` always sets this to `DType::Decimal`, taking
    /// the nullability from the validity.
    pub(super) dtype: DType,
    /// Raw bytes of the element values. Reinterpret them as `Buffer<T>` only for a `T` whose
    /// `T::VALUES_TYPE` equals `values_type`.
    pub(super) values: ByteBuffer,
    /// Physical representation of each element stored in `values`.
    pub(super) values_type: DecimalValueType,
    /// Which elements are valid (non-null).
    pub(super) validity: Validity,
    /// Statistics for this array; `new_unchecked` initialises it with `Default::default()`.
    pub(super) stats_set: ArrayStats,
}
impl DecimalArray {
    /// Creates a new [`DecimalArray`] from a typed value buffer, a decimal dtype and a validity.
    ///
    /// # Panics
    ///
    /// Panics if [`DecimalArray::try_new`] rejects the arguments, i.e. when `validity` reports a
    /// length that differs from `buffer.len()`.
    pub fn new<T: NativeDecimalType>(
        buffer: Buffer<T>,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> Self {
        Self::try_new(buffer, decimal_dtype, validity)
            .vortex_expect("DecimalArray construction failed")
    }

    /// Creates a new [`DecimalArray`], validating the arguments first.
    ///
    /// # Errors
    ///
    /// Returns an error if [`DecimalArray::validate`] fails, i.e. when `validity` reports a
    /// length that differs from `buffer.len()`.
    pub fn try_new<T: NativeDecimalType>(
        buffer: Buffer<T>,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> VortexResult<Self> {
        Self::validate(&buffer, &validity)?;
        // SAFETY: `validate` just succeeded. It is the same check that `new_unchecked`
        // re-runs under debug assertions.
        Ok(unsafe { Self::new_unchecked(buffer, decimal_dtype, validity) })
    }

    /// Creates a new [`DecimalArray`] without validating the arguments in release builds.
    ///
    /// The physical values type is taken from `T::VALUES_TYPE`. The nullability of the resulting
    /// `DType::Decimal` is taken from `validity.nullability()`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that the arguments pass [`DecimalArray::validate`]: if
    /// `validity` has a length, it must equal `buffer.len()`. In debug builds this is re-checked
    /// and a violation panics.
    pub unsafe fn new_unchecked<T: NativeDecimalType>(
        buffer: Buffer<T>,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> Self {
        #[cfg(debug_assertions)]
        Self::validate(&buffer, &validity)
            .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");

        Self {
            values: buffer.into_byte_buffer(),
            values_type: T::VALUES_TYPE,
            dtype: DType::Decimal(decimal_dtype, validity.nullability()),
            validity,
            stats_set: Default::default(),
        }
    }

    /// Checks that `buffer` and `validity` agree on the number of elements.
    ///
    /// Only validities that report a length (via `maybe_len`) are checked.
    ///
    /// # Errors
    ///
    /// Returns an error naming both lengths when they differ.
    pub fn validate<T: NativeDecimalType>(
        buffer: &Buffer<T>,
        validity: &Validity,
    ) -> VortexResult<()> {
        if let Some(len) = validity.maybe_len() {
            vortex_ensure!(
                buffer.len() == len,
                "Buffer and validity length mismatch: buffer={}, validity={}",
                buffer.len(),
                len,
            );
        }
        Ok(())
    }

    /// Returns a clone of the untyped bytes backing the element values.
    pub fn byte_buffer(&self) -> ByteBuffer {
        self.values.clone()
    }

    /// Returns the element values reinterpreted as a typed [`Buffer<T>`].
    ///
    /// # Panics
    ///
    /// Panics if `T::VALUES_TYPE` differs from [`DecimalArray::values_type`].
    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
        if self.values_type != T::VALUES_TYPE {
            vortex_panic!(
                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
                T::VALUES_TYPE,
                self.values_type,
            );
        }
        Buffer::<T>::from_byte_buffer(self.values.clone())
    }

    /// Returns the [`DecimalDType`] (precision and scale) of this array.
    ///
    /// # Panics
    ///
    /// Panics if the stored dtype is not `DType::Decimal`. `new_unchecked` always stores a
    /// decimal dtype.
    pub fn decimal_dtype(&self) -> DecimalDType {
        if let DType::Decimal(decimal_dtype, _) = self.dtype {
            decimal_dtype
        } else {
            vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
        }
    }

    /// Returns the [`DecimalValueType`] that describes how the element values are physically stored.
    pub fn values_type(&self) -> DecimalValueType {
        self.values_type
    }

    /// Returns the precision of this array's decimal dtype.
    pub fn precision(&self) -> u8 {
        self.decimal_dtype().precision()
    }

    /// Returns the scale of this array's decimal dtype.
    pub fn scale(&self) -> i8 {
        self.decimal_dtype().scale()
    }

    /// Builds a [`DecimalArray`] from an iterator of optional values.
    ///
    /// `Some(v)` stores `v` and marks the slot valid. `None` stores `T::default()` as a
    /// placeholder and marks the slot invalid.
    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
        iter: I,
        decimal_dtype: DecimalDType,
    ) -> Self {
        let iter = iter.into_iter();
        let mut values = BufferMut::with_capacity(iter.size_hint().0);
        let mut validity = BooleanBufferBuilder::new(values.capacity());

        for i in iter {
            match i {
                None => {
                    validity.append(false);
                    values.push(T::default());
                }
                Some(e) => {
                    validity.append(true);
                    values.push(e);
                }
            }
        }

        Self::new(
            values.freeze(),
            decimal_dtype,
            Validity::from(validity.finish()),
        )
    }

    /// Returns a new array in which the positions named by `patches` take the patch values and
    /// the patch validity.
    ///
    /// The patch values may use a different [`DecimalValueType`] than this array. Each one is
    /// converted into this array's values type.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the patch values' decimal dtype differs from this array's;
    /// - the number of patch indices and patch values differ;
    /// - this array's values type cannot represent every value of its decimal dtype;
    /// - a patch value cannot be converted into this array's values type.
    // The cognitive-complexity lint is presumably triggered by the three nested `match_each_*`
    // dispatch macros below, which expand into one arm per type combination.
    #[allow(clippy::cognitive_complexity)]
    pub fn patch(self, patches: &Patches) -> Self {
        let decimal_dtype = self.decimal_dtype();
        let values_type = self.values_type();
        let offset = patches.offset();
        let patch_indices = patches.indices().to_primitive();
        let patch_values = patches.values().to_decimal();

        // Check the precondition before doing any work on the validity or the values.
        assert_eq!(decimal_dtype, patch_values.decimal_dtype());

        let patched_validity = self.validity().clone().patch(
            self.len(),
            offset,
            patch_indices.as_ref(),
            patch_values.validity(),
        );

        // Move the bytes out of `self` (which is consumed) instead of cloning them via
        // `buffer()`. No second reference is then kept alive, which presumably lets `into_mut`
        // reuse the allocation instead of copying it.
        let values = self.values;

        match_each_integer_ptype!(patch_indices.ptype(), |I| {
            let patch_indices = patch_indices.as_slice::<I>();
            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
                let patch_values = patch_values.buffer::<PatchDVT>();
                match_each_decimal_value_type!(values_type, |ValuesDVT| {
                    // `ValuesDVT` is dispatched from `self.values_type`, so this reinterpretation
                    // satisfies the same type check that `buffer()` performs.
                    let buffer = Buffer::<ValuesDVT>::from_byte_buffer(values).into_mut();
                    patch_typed(
                        buffer,
                        decimal_dtype,
                        patch_indices,
                        offset,
                        patch_values,
                        patched_validity,
                    )
                })
            })
        })
    }
}
/// Writes every patch value into `buffer` and wraps the result as a [`DecimalArray`].
///
/// Each patch value is written at its index minus `patch_indices_offset`, after conversion into
/// `ValuesDVT` via [`BigCast`]. The resulting array uses `patched_validity`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `ValuesDVT` cannot represent every value of `decimal_dtype`;
/// - `patch_indices` and `patch_values` have different lengths;
/// - an adjusted index is out of bounds;
/// - a patch value does not fit into `ValuesDVT`.
fn patch_typed<I, ValuesDVT, PatchDVT>(
    mut buffer: BufferMut<ValuesDVT>,
    decimal_dtype: DecimalDType,
    patch_indices: &[I],
    patch_indices_offset: usize,
    patch_values: Buffer<PatchDVT>,
    patched_validity: Validity,
) -> DecimalArray
where
    I: IntegerPType,
    PatchDVT: NativeDecimalType,
    ValuesDVT: NativeDecimalType,
{
    let target_type = ValuesDVT::VALUES_TYPE;
    let representable = is_compatible_decimal_value_type(target_type, decimal_dtype);
    if !representable {
        vortex_panic!(
            "patch_typed: {:?} cannot represent every value in {}.",
            target_type,
            decimal_dtype
        )
    }

    let updates = patch_indices.iter().zip_eq(patch_values.into_iter());
    for (raw_index, patch_value) in updates {
        let position = raw_index.as_() - patch_indices_offset;
        let converted = <ValuesDVT as BigCast>::from(patch_value).vortex_expect(
            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
        );
        buffer[position] = converted;
    }

    let frozen = buffer.freeze();
    DecimalArray::new(frozen, decimal_dtype, patched_validity)
}