use itertools::Itertools;
use vortex_buffer::BitBufferMut;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_buffer::ByteBuffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_panic;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::PrimitiveArray;
use crate::buffer::BufferHandle;
use crate::dtype::BigCast;
use crate::dtype::DType;
use crate::dtype::DecimalDType;
use crate::dtype::DecimalType;
use crate::dtype::IntegerPType;
use crate::dtype::NativeDecimalType;
use crate::match_each_decimal_value_type;
use crate::match_each_integer_ptype;
use crate::patches::Patches;
use crate::stats::ArrayStats;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
/// An array of decimal values.
///
/// The raw value bytes live in `values` and are interpreted as elements of the physical
/// storage type `values_type`; the logical decimal type (precision and scale) lives in
/// `dtype`.
#[derive(Clone, Debug)]
pub struct DecimalArray {
/// Logical type. The constructors set this to `DType::Decimal`, with the nullability
/// taken from `validity`.
pub(super) dtype: DType,
/// Raw bytes of the values, laid out as consecutive `values_type` elements.
pub(super) values: BufferHandle,
/// Physical storage type of each element in `values`.
pub(super) values_type: DecimalType,
/// Per-element validity of the array.
pub(super) validity: Validity,
/// Statistics for the array; the constructors initialize it with `Default::default()`.
pub(super) stats_set: ArrayStats,
}
/// The constituent parts of a [`DecimalArray`], as returned by [`DecimalArray::into_parts`].
#[derive(Debug)]
pub struct DecimalArrayParts {
    /// Logical decimal type (precision and scale) of the array.
    pub decimal_dtype: DecimalDType,
    /// Raw bytes of the values, laid out as consecutive `values_type` elements.
    pub values: BufferHandle,
    /// Physical storage type of each element in `values`.
    pub values_type: DecimalType,
    /// Per-element validity of the array.
    pub validity: Validity,
}
impl DecimalArray {
    /// Creates a new [`DecimalArray`] from a typed buffer of values.
    ///
    /// The physical storage type is taken from `T::DECIMAL_TYPE`.
    ///
    /// # Panics
    ///
    /// Panics if the arguments are rejected by [`DecimalArray::try_new`].
    pub fn new<T: NativeDecimalType>(
        buffer: Buffer<T>,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> Self {
        Self::try_new(buffer, decimal_dtype, validity)
            .vortex_expect("DecimalArray construction failed")
    }

    /// Creates a new [`DecimalArray`] from an untyped [`BufferHandle`] whose elements are
    /// stored as `values_type`.
    ///
    /// # Panics
    ///
    /// Panics if the arguments are rejected by [`DecimalArray::try_new_handle`].
    pub fn new_handle(
        values: BufferHandle,
        values_type: DecimalType,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> Self {
        Self::try_new_handle(values, values_type, decimal_dtype, validity)
            .vortex_expect("DecimalArray construction failed")
    }

    /// Fallible version of [`DecimalArray::new`].
    ///
    /// # Errors
    ///
    /// Returns an error under the same conditions as [`DecimalArray::try_new_handle`], with
    /// `values_type` set to `T::DECIMAL_TYPE`.
    pub fn try_new<T: NativeDecimalType>(
        buffer: Buffer<T>,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> VortexResult<Self> {
        let values = BufferHandle::new_host(buffer.into_byte_buffer());
        let values_type = T::DECIMAL_TYPE;
        Self::try_new_handle(values, values_type, decimal_dtype, validity)
    }

    /// Fallible version of [`DecimalArray::new_handle`].
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// - `validity` has a known length and the byte length of `values` is not exactly
    ///   `values_type.byte_width()` times that length.
    /// - `validity` has no known length and the byte length of `values` is not a multiple of
    ///   `values_type.byte_width()`.
    pub fn try_new_handle(
        values: BufferHandle,
        values_type: DecimalType,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> VortexResult<Self> {
        Self::validate(&values, values_type, &validity)?;
        // SAFETY: the arguments were validated immediately above.
        Ok(unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype, validity) })
    }

    /// Creates a new [`DecimalArray`] from a typed buffer without validating the arguments.
    ///
    /// # Safety
    ///
    /// The caller must ensure the arguments would be accepted by [`DecimalArray::try_new`].
    /// They are only checked, with a panic on failure, when debug assertions are enabled.
    pub unsafe fn new_unchecked<T: NativeDecimalType>(
        buffer: Buffer<T>,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> Self {
        // SAFETY: the validity invariants are forwarded to the caller of this function.
        unsafe {
            Self::new_unchecked_handle(
                BufferHandle::new_host(buffer.into_byte_buffer()),
                T::DECIMAL_TYPE,
                decimal_dtype,
                validity,
            )
        }
    }

    /// Creates a new [`DecimalArray`] from an untyped [`BufferHandle`] without validating the
    /// arguments.
    ///
    /// The resulting dtype is `DType::Decimal(decimal_dtype, validity.nullability())`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the arguments would be accepted by
    /// [`DecimalArray::try_new_handle`]. They are only checked, with a panic on failure, when
    /// debug assertions are enabled.
    pub unsafe fn new_unchecked_handle(
        values: BufferHandle,
        values_type: DecimalType,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> Self {
        #[cfg(debug_assertions)]
        {
            Self::validate(&values, values_type, &validity)
                .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");
        }
        Self {
            values,
            values_type,
            dtype: DType::Decimal(decimal_dtype, validity.nullability()),
            validity,
            stats_set: Default::default(),
        }
    }

    /// Checks that `buffer` holds a whole number of `values_type` elements, and exactly one per
    /// slot when `validity` knows the array length.
    fn validate(
        buffer: &BufferHandle,
        values_type: DecimalType,
        validity: &Validity,
    ) -> VortexResult<()> {
        let byte_width = values_type.byte_width();
        match validity.maybe_len() {
            Some(validity_len) => {
                let expected_len = byte_width * validity_len;
                vortex_ensure!(
                    buffer.len() == expected_len,
                    InvalidArgument: "expected buffer of size {} bytes, was {} bytes",
                    expected_len,
                    buffer.len(),
                );
            }
            None => {
                // Without a validity length there is nothing to compare against, but a
                // trailing partial element can never be read back as a value, so reject it.
                vortex_ensure!(
                    buffer.len() % byte_width == 0,
                    InvalidArgument: "buffer of {} bytes is not a multiple of the {} byte width of {:?}",
                    buffer.len(),
                    byte_width,
                    values_type,
                );
            }
        }
        Ok(())
    }

    /// Creates a new [`DecimalArray`] from a host [`ByteBuffer`] without validating the
    /// arguments.
    ///
    /// # Safety
    ///
    /// The caller must ensure the arguments would be accepted by
    /// [`DecimalArray::try_new_handle`] when `byte_buffer` is wrapped in a host
    /// [`BufferHandle`].
    pub unsafe fn new_unchecked_from_byte_buffer(
        byte_buffer: ByteBuffer,
        values_type: DecimalType,
        decimal_dtype: DecimalDType,
        validity: Validity,
    ) -> Self {
        // SAFETY: the validity invariants are forwarded to the caller of this function.
        unsafe {
            Self::new_unchecked_handle(
                BufferHandle::new_host(byte_buffer),
                values_type,
                decimal_dtype,
                validity,
            )
        }
    }

    /// Decomposes the array into its [`DecimalArrayParts`].
    ///
    /// The array's statistics are not part of the result and are dropped.
    pub fn into_parts(self) -> DecimalArrayParts {
        let decimal_dtype = self.dtype.into_decimal_opt().vortex_expect("cannot fail");
        DecimalArrayParts {
            decimal_dtype,
            values: self.values,
            values_type: self.values_type,
            validity: self.validity,
        }
    }

    /// Returns the handle to the raw value bytes.
    pub fn buffer_handle(&self) -> &BufferHandle {
        &self.values
    }

    /// Returns the values as a typed [`Buffer<T>`], built from the handle's host buffer
    /// (obtained via `BufferHandle::as_host`).
    ///
    /// # Panics
    ///
    /// Panics if `T::DECIMAL_TYPE` differs from [`DecimalArray::values_type`].
    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
        if self.values_type != T::DECIMAL_TYPE {
            vortex_panic!(
                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
                T::DECIMAL_TYPE,
                self.values_type,
            );
        }
        Buffer::<T>::from_byte_buffer(self.values.as_host().clone())
    }

    /// Returns the logical decimal type (precision and scale) of the array.
    ///
    /// # Panics
    ///
    /// Panics if the array's dtype is not `DType::Decimal`.
    pub fn decimal_dtype(&self) -> DecimalDType {
        if let DType::Decimal(decimal_dtype, _) = self.dtype {
            decimal_dtype
        } else {
            vortex_panic!("Expected Decimal dtype, got {:?}", self.dtype)
        }
    }

    /// Returns the physical storage type of each value.
    pub fn values_type(&self) -> DecimalType {
        self.values_type
    }

    /// Returns the precision of the array's decimal type.
    pub fn precision(&self) -> u8 {
        self.decimal_dtype().precision()
    }

    /// Returns the scale of the array's decimal type.
    pub fn scale(&self) -> i8 {
        self.decimal_dtype().scale()
    }

    /// Builds a non-nullable [`DecimalArray`] from an iterator of values.
    ///
    /// # Panics
    ///
    /// Panics if the arguments are rejected by [`DecimalArray::try_new`].
    pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
        iter: I,
        decimal_dtype: DecimalDType,
    ) -> Self {
        let iter = iter.into_iter();
        Self::new(
            BufferMut::from_iter(iter).freeze(),
            decimal_dtype,
            Validity::NonNullable,
        )
    }

    /// Builds a [`DecimalArray`] from an iterator of optional values.
    ///
    /// `None` entries are stored as `T::default()` and marked invalid in the validity.
    ///
    /// # Panics
    ///
    /// Panics if the arguments are rejected by [`DecimalArray::try_new`].
    pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
        iter: I,
        decimal_dtype: DecimalDType,
    ) -> Self {
        let iter = iter.into_iter();
        let mut values = BufferMut::with_capacity(iter.size_hint().0);
        let mut validity = BitBufferMut::with_capacity(values.capacity());
        for i in iter {
            match i {
                None => {
                    validity.append(false);
                    values.push(T::default());
                }
                Some(e) => {
                    validity.append(true);
                    values.push(e);
                }
            }
        }
        Self::new(
            values.freeze(),
            decimal_dtype,
            Validity::from(validity.freeze()),
        )
    }

    /// Consumes this array and returns an array in which the positions described by `patches`
    /// take their value and validity from the patch values.
    ///
    /// The patch indices and values are executed into a [`PrimitiveArray`] and a
    /// [`DecimalArray`], respectively. Each index is shifted down by `patches.offset()` before
    /// being applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the patch values have a different [`DecimalDType`] than this array,
    /// or if executing the patch arrays or patching the validity fails.
    ///
    /// # Panics
    ///
    /// Panics if the numbers of patch indices and patch values differ, if a shifted index is
    /// out of bounds, or if a patch value cannot be cast into this array's storage type.
    #[expect(
        clippy::cognitive_complexity,
        reason = "complexity from nested match_each_* macros"
    )]
    pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
        let offset = patches.offset();
        let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
        let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
        // Reject mismatched dtypes as an error (not a panic) before doing any patching work.
        vortex_ensure!(
            self.decimal_dtype() == patch_values.decimal_dtype(),
            InvalidArgument: "cannot patch DecimalArray of {} with values of {}",
            self.decimal_dtype(),
            patch_values.decimal_dtype(),
        );
        let patched_validity = self.validity().clone().patch(
            self.len(),
            offset,
            &patch_indices.clone().into_array(),
            patch_values.validity(),
            ctx,
        )?;
        Ok(match_each_integer_ptype!(patch_indices.ptype(), |I| {
            let patch_indices = patch_indices.as_slice::<I>();
            match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
                let patch_values = patch_values.buffer::<PatchDVT>();
                match_each_decimal_value_type!(self.values_type(), |ValuesDVT| {
                    let buffer = self.buffer::<ValuesDVT>().into_mut();
                    patch_typed(
                        buffer,
                        self.decimal_dtype(),
                        patch_indices,
                        offset,
                        patch_values,
                        patched_validity,
                    )
                })
            })
        }))
    }
}
/// Overwrites `buffer` at each position in `patch_indices`, shifted down by
/// `patch_indices_offset`, with the matching element of `patch_values` cast into
/// `ValuesDVT`. The result is wrapped in a [`DecimalArray`] carrying `patched_validity`.
///
/// # Panics
///
/// Panics in the following cases:
/// - `ValuesDVT` is not a compatible storage type for `decimal_dtype`.
/// - `patch_indices` and `patch_values` have different lengths.
/// - A shifted index is out of bounds.
/// - A patch value cannot be cast into `ValuesDVT`.
fn patch_typed<I, ValuesDVT, PatchDVT>(
    mut buffer: BufferMut<ValuesDVT>,
    decimal_dtype: DecimalDType,
    patch_indices: &[I],
    patch_indices_offset: usize,
    patch_values: Buffer<PatchDVT>,
    patched_validity: Validity,
) -> DecimalArray
where
    I: IntegerPType,
    PatchDVT: NativeDecimalType,
    ValuesDVT: NativeDecimalType,
{
    let storage_type = ValuesDVT::DECIMAL_TYPE;
    if !storage_type.is_compatible_decimal_value_type(decimal_dtype) {
        vortex_panic!(
            "patch_typed: {:?} cannot represent every value in {}.",
            storage_type,
            decimal_dtype
        )
    }

    patch_indices
        .iter()
        .zip_eq(patch_values.into_iter())
        .for_each(|(index, replacement)| {
            let position: usize = index.as_() - patch_indices_offset;
            // Every value of `decimal_dtype` fits in a compatible storage type, so the cast
            // is expected to succeed.
            buffer[position] = <ValuesDVT as BigCast>::from(replacement).vortex_expect(
                "values of a given DecimalDType are representable in all compatible NativeDecimalType",
            );
        });

    DecimalArray::new(buffer.freeze(), decimal_dtype, patched_validity)
}