use std::fmt::Display;
use std::fmt::Formatter;
use itertools::Itertools;
use vortex_buffer::Alignment;
use vortex_buffer::BitBufferMut;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_buffer::ByteBuffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_panic;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::Array;
use crate::array::ArrayParts;
use crate::array::TypedArrayRef;
use crate::array::child_to_validity;
use crate::array::validity_to_child;
use crate::arrays::Decimal;
use crate::arrays::DecimalArray;
use crate::arrays::PrimitiveArray;
use crate::buffer::BufferHandle;
use crate::dtype::BigCast;
use crate::dtype::DType;
use crate::dtype::DecimalDType;
use crate::dtype::DecimalType;
use crate::dtype::IntegerPType;
use crate::dtype::NativeDecimalType;
use crate::dtype::Nullability;
use crate::match_each_decimal_value_type;
use crate::match_each_integer_ptype;
use crate::patches::Patches;
use crate::validity::Validity;
/// Position of the validity child within a decimal array's slot vector.
pub(super) const VALIDITY_SLOT: usize = 0;
/// Number of child slots a decimal array carries; the validity slot is the only (and last) one.
pub(super) const NUM_SLOTS: usize = VALIDITY_SLOT + 1;
/// Display names of the slots, indexed by slot position.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
/// Type-specific payload of a decimal array: the raw value bytes plus the metadata needed to
/// interpret them.
///
/// Validity is not stored here; it is kept in the array's slots (built by
/// `DecimalData::make_slots`).
#[derive(Clone, Debug)]
pub struct DecimalData {
    /// Logical decimal type (precision and scale) of the stored values.
    pub(super) decimal_dtype: DecimalDType,
    /// Backing bytes; each value occupies `values_type.byte_width()` bytes.
    pub(super) values: BufferHandle,
    /// Physical storage type of each value; determines element width and required alignment.
    pub(super) values_type: DecimalType,
}
impl Display for DecimalData {
    /// Writes a one-line summary of the decimal metadata; the value bytes are not printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            decimal_dtype,
            values_type,
            ..
        } = self;
        write!(
            f,
            "decimal_dtype: {}, values_type: {}",
            decimal_dtype, values_type
        )
    }
}
/// Owned decomposition of a decimal array, as produced by `Array<Decimal>::into_data_parts`.
///
/// Unlike [`DecimalData`], this also carries the array's [`Validity`].
#[derive(Clone, Debug)]
pub struct DecimalDataParts {
    /// Logical decimal type (precision and scale).
    pub decimal_dtype: DecimalDType,
    /// Backing bytes holding the values.
    pub values: BufferHandle,
    /// Physical storage type of each value.
    pub values_type: DecimalType,
    /// Validity of the array's elements.
    pub validity: Validity,
}
/// Convenience accessors available on any typed reference to a decimal array.
///
/// The decimal dtype and nullability are read from the array's [`DType`], validity from its
/// slots, and the physical buffer information from the underlying [`DecimalData`].
pub trait DecimalArrayExt: TypedArrayRef<Decimal> {
    /// Returns the logical decimal type (precision and scale) of the array.
    ///
    /// # Panics
    ///
    /// Panics if the array's dtype is not [`DType::Decimal`].
    fn decimal_dtype(&self) -> DecimalDType {
        let DType::Decimal(decimal_dtype, _) = self.as_ref().dtype() else {
            unreachable!("DecimalArrayExt requires a decimal dtype")
        };
        *decimal_dtype
    }

    /// Returns the nullability carried by the array's decimal dtype.
    ///
    /// # Panics
    ///
    /// Panics if the array's dtype is not [`DType::Decimal`].
    fn nullability(&self) -> Nullability {
        let DType::Decimal(_, nullability) = self.as_ref().dtype() else {
            unreachable!("DecimalArrayExt requires a decimal dtype")
        };
        *nullability
    }

    /// Returns the child stored in the validity slot, if any.
    fn validity_child(&self) -> Option<&ArrayRef> {
        let slots = self.as_ref().slots();
        slots[VALIDITY_SLOT].as_ref()
    }

    /// Returns the array's validity by combining the validity slot with the dtype's nullability.
    fn validity(&self) -> Validity {
        let slot = &self.as_ref().slots()[VALIDITY_SLOT];
        child_to_validity(slot, self.nullability())
    }

    /// Returns the physical storage type of the values.
    fn values_type(&self) -> DecimalType {
        self.values_type
    }

    /// Returns the decimal precision.
    fn precision(&self) -> u8 {
        let dtype = self.decimal_dtype();
        dtype.precision()
    }

    /// Returns the decimal scale.
    fn scale(&self) -> i8 {
        let dtype = self.decimal_dtype();
        dtype.scale()
    }

    /// Returns the handle to the raw value buffer.
    fn buffer_handle(&self) -> &BufferHandle {
        &self.values
    }

    /// Returns the values as a typed buffer; panics under the same conditions as
    /// [`DecimalData::buffer`].
    fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
        DecimalData::buffer::<T>(self)
    }
}
/// Every typed decimal array reference gets the extension accessors automatically.
impl<A> DecimalArrayExt for A
where
    A: TypedArrayRef<Decimal>,
{
}
impl DecimalData {
    /// Builds the slot vector for a decimal array of `len` elements: a single entry holding the
    /// (optional) validity child derived from `validity`.
    pub(super) fn make_slots(validity: &Validity, len: usize) -> Vec<Option<ArrayRef>> {
        let validity_slot = validity_to_child(validity, len);
        vec![validity_slot]
    }

    /// Creates decimal data from a typed host buffer.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::try_new`] would return an error.
    pub fn new<T: NativeDecimalType>(buffer: Buffer<T>, decimal_dtype: DecimalDType) -> Self {
        let result = Self::try_new(buffer, decimal_dtype);
        result.vortex_expect("DecimalArray construction failed")
    }

    /// Creates decimal data from an untyped buffer handle.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::try_new_handle`] would return an error.
    pub fn new_handle(
        values: BufferHandle,
        values_type: DecimalType,
        decimal_dtype: DecimalDType,
    ) -> Self {
        let result = Self::try_new_handle(values, values_type, decimal_dtype);
        result.vortex_expect("DecimalArray construction failed")
    }

    /// Fallibly creates decimal data from a typed host buffer; the values type is taken from `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if the resulting buffer fails the checks of [`Self::try_new_handle`].
    pub fn try_new<T: NativeDecimalType>(
        buffer: Buffer<T>,
        decimal_dtype: DecimalDType,
    ) -> VortexResult<Self> {
        Self::try_new_handle(
            BufferHandle::new_host(buffer.into_byte_buffer()),
            T::DECIMAL_TYPE,
            decimal_dtype,
        )
    }

    /// Fallibly creates decimal data from an untyped buffer handle.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidArgument` error if the handle's byte length is not a multiple of
    /// `values_type.byte_width()`, or if the buffer is not aligned for `values_type`.
    pub fn try_new_handle(
        values: BufferHandle,
        values_type: DecimalType,
        decimal_dtype: DecimalDType,
    ) -> VortexResult<Self> {
        Self::validate(&values, values_type)?;
        // SAFETY: `validate` has just checked the length and alignment requirements.
        let data = unsafe { Self::new_unchecked_handle(values, values_type, decimal_dtype) };
        Ok(data)
    }

    /// Creates decimal data from a typed host buffer without any validation.
    ///
    /// # Safety
    ///
    /// The caller must guarantee the buffer passes the checks performed by [`Self::try_new`]
    /// (length a multiple of the element width, alignment suitable for `T`).
    pub unsafe fn new_unchecked<T: NativeDecimalType>(
        buffer: Buffer<T>,
        decimal_dtype: DecimalDType,
    ) -> Self {
        let handle = BufferHandle::new_host(buffer.into_byte_buffer());
        // SAFETY: the requirements are forwarded from this function's own contract.
        unsafe { Self::new_unchecked_handle(handle, T::DECIMAL_TYPE, decimal_dtype) }
    }

    /// Creates decimal data from an untyped buffer handle without any validation.
    ///
    /// # Safety
    ///
    /// `values` must have a byte length that is a multiple of `values_type.byte_width()` and be
    /// aligned for `values_type`, i.e. it must pass the checks in [`Self::try_new_handle`].
    /// No checks are performed here.
    pub unsafe fn new_unchecked_handle(
        values: BufferHandle,
        values_type: DecimalType,
        decimal_dtype: DecimalDType,
    ) -> Self {
        Self {
            values,
            values_type,
            decimal_dtype,
        }
    }

    /// Checks that `buffer` holds a whole number of `values_type` elements and is aligned for
    /// that type.
    fn validate(buffer: &BufferHandle, values_type: DecimalType) -> VortexResult<()> {
        let width = values_type.byte_width();
        vortex_ensure!(
            buffer.len().is_multiple_of(width),
            InvalidArgument: "decimal buffer size {} is not divisible by element width {}",
            buffer.len(),
            width,
        );
        // The alignment requirement depends on the concrete native type behind `values_type`.
        match_each_decimal_value_type!(values_type, |D| {
            vortex_ensure!(
                buffer.is_aligned_to(Alignment::of::<D>()),
                InvalidArgument: "decimal buffer alignment {:?} is invalid for values type {:?}",
                buffer.alignment(),
                D::DECIMAL_TYPE,
            );
            Ok::<(), vortex_error::VortexError>(())
        })
    }

    /// Creates decimal data directly from a host byte buffer without any validation.
    ///
    /// # Safety
    ///
    /// Same contract as [`Self::new_unchecked_handle`]: `byte_buffer` must satisfy the length
    /// and alignment requirements for `values_type`.
    pub unsafe fn new_unchecked_from_byte_buffer(
        byte_buffer: ByteBuffer,
        values_type: DecimalType,
        decimal_dtype: DecimalDType,
    ) -> Self {
        let handle = BufferHandle::new_host(byte_buffer);
        // SAFETY: the requirements are forwarded from this function's own contract.
        unsafe { Self::new_unchecked_handle(handle, values_type, decimal_dtype) }
    }

    /// Returns the number of values: the buffer's byte length divided by the element width.
    pub fn len(&self) -> usize {
        let total_bytes = self.values.len();
        total_bytes / self.values_type.byte_width()
    }

    /// Returns `true` when there are no values.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the handle to the raw value buffer.
    pub fn buffer_handle(&self) -> &BufferHandle {
        &self.values
    }

    /// Returns the values reinterpreted as a typed buffer of `T`.
    ///
    /// # Panics
    ///
    /// Panics if `T` does not correspond to [`Self::values_type`].
    /// NOTE(review): the buffer is accessed via `as_host`, which presumably also panics for
    /// non-host buffers — confirm.
    pub fn buffer<T: NativeDecimalType>(&self) -> Buffer<T> {
        let requested = T::DECIMAL_TYPE;
        if self.values_type != requested {
            vortex_panic!(
                "Cannot extract Buffer<{:?}> for DecimalArray with values_type {:?}",
                requested,
                self.values_type,
            );
        }
        let bytes = self.values.as_host().clone();
        Buffer::<T>::from_byte_buffer(bytes)
    }

    /// Returns the physical storage type of the values.
    pub fn values_type(&self) -> DecimalType {
        self.values_type
    }

    /// Returns the logical decimal type (precision and scale).
    pub fn decimal_dtype(&self) -> DecimalDType {
        self.decimal_dtype
    }

    /// Returns the decimal precision.
    pub fn precision(&self) -> u8 {
        self.decimal_dtype().precision()
    }

    /// Returns the decimal scale.
    pub fn scale(&self) -> i8 {
        self.decimal_dtype().scale()
    }
}
impl Array<Decimal> {
pub fn into_data_parts(self) -> DecimalDataParts {
let validity = DecimalArrayExt::validity(&self);
let decimal_dtype = DecimalArrayExt::decimal_dtype(&self);
let data = self.into_data();
DecimalDataParts {
decimal_dtype,
values: data.values,
values_type: data.values_type,
validity,
}
}
}
impl Array<Decimal> {
pub fn new<T: NativeDecimalType>(
buffer: Buffer<T>,
decimal_dtype: DecimalDType,
validity: Validity,
) -> Self {
Self::try_new(buffer, decimal_dtype, validity)
.vortex_expect("DecimalArray construction failed")
}
pub unsafe fn new_unchecked<T: NativeDecimalType>(
buffer: Buffer<T>,
decimal_dtype: DecimalDType,
validity: Validity,
) -> Self {
let dtype = DType::Decimal(decimal_dtype, validity.nullability());
let len = buffer.len();
let slots = DecimalData::make_slots(&validity, len);
let data = unsafe { DecimalData::new_unchecked(buffer, decimal_dtype) };
unsafe {
Array::from_parts_unchecked(
ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
)
}
}
pub fn try_new<T: NativeDecimalType>(
buffer: Buffer<T>,
decimal_dtype: DecimalDType,
validity: Validity,
) -> VortexResult<Self> {
let dtype = DType::Decimal(decimal_dtype, validity.nullability());
let len = buffer.len();
let slots = DecimalData::make_slots(&validity, len);
let data = DecimalData::try_new(buffer, decimal_dtype)?;
Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
}
#[expect(
clippy::same_name_method,
reason = "intentionally named from_iter like Iterator::from_iter"
)]
pub fn from_iter<T: NativeDecimalType, I: IntoIterator<Item = T>>(
iter: I,
decimal_dtype: DecimalDType,
) -> Self {
Self::new(
BufferMut::from_iter(iter).freeze(),
decimal_dtype,
Validity::NonNullable,
)
}
pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
iter: I,
decimal_dtype: DecimalDType,
) -> Self {
let iter = iter.into_iter();
let mut values = BufferMut::with_capacity(iter.size_hint().0);
let mut validity = BitBufferMut::with_capacity(values.capacity());
for value in iter {
match value {
Some(value) => {
values.push(value);
validity.append(true);
}
None => {
values.push(T::default());
validity.append(false);
}
}
}
Self::new(
values.freeze(),
decimal_dtype,
Validity::from(validity.freeze()),
)
}
pub fn new_handle(
values: BufferHandle,
values_type: DecimalType,
decimal_dtype: DecimalDType,
validity: Validity,
) -> Self {
Self::try_new_handle(values, values_type, decimal_dtype, validity)
.vortex_expect("DecimalArray construction failed")
}
pub fn try_new_handle(
values: BufferHandle,
values_type: DecimalType,
decimal_dtype: DecimalDType,
validity: Validity,
) -> VortexResult<Self> {
let dtype = DType::Decimal(decimal_dtype, validity.nullability());
let len = values.len() / values_type.byte_width();
let slots = DecimalData::make_slots(&validity, len);
let data = DecimalData::try_new_handle(values, values_type, decimal_dtype)?;
Array::try_from_parts(ArrayParts::new(Decimal, dtype, len, data).with_slots(slots))
}
pub unsafe fn new_unchecked_handle(
values: BufferHandle,
values_type: DecimalType,
decimal_dtype: DecimalDType,
validity: Validity,
) -> Self {
let dtype = DType::Decimal(decimal_dtype, validity.nullability());
let len = values.len() / values_type.byte_width();
let slots = DecimalData::make_slots(&validity, len);
let data = unsafe { DecimalData::new_unchecked_handle(values, values_type, decimal_dtype) };
unsafe {
Array::from_parts_unchecked(
ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
)
}
}
#[allow(
clippy::cognitive_complexity,
reason = "patching depends on both patch and value physical types"
)]
pub fn patch(self, patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
let offset = patches.offset();
let dtype = self.dtype().clone();
let len = self.len();
let patch_indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
let patch_values = patches.values().clone().execute::<DecimalArray>(ctx)?;
let patch_validity = patch_values.validity()?;
let patched_validity = self.validity()?.patch(
self.len(),
offset,
&patch_indices.clone().into_array(),
&patch_validity,
ctx,
)?;
assert_eq!(self.decimal_dtype(), patch_values.decimal_dtype());
let data = self.into_data();
let data = match_each_integer_ptype!(patch_indices.ptype(), |I| {
let patch_indices = patch_indices.as_slice::<I>();
match_each_decimal_value_type!(patch_values.values_type(), |PatchDVT| {
let patch_values = patch_values.buffer::<PatchDVT>();
match_each_decimal_value_type!(data.values_type(), |ValuesDVT| {
let buffer = data.buffer::<ValuesDVT>().into_mut();
patch_typed(
buffer,
data.decimal_dtype(),
patch_indices,
offset,
patch_values,
)
})
})
});
let slots = DecimalData::make_slots(&patched_validity, len);
Ok(unsafe {
Array::from_parts_unchecked(
ArrayParts::new(Decimal, dtype, len, data).with_slots(slots),
)
})
}
}
/// Writes each patch value, converted into `ValuesDVT`, into `buffer` at its patch position,
/// then wraps the buffer as [`DecimalData`].
///
/// Each index has `patch_indices_offset` subtracted before it is used to index `buffer`.
///
/// # Panics
///
/// Panics if `ValuesDVT` is not a compatible value type for `decimal_dtype`, if the numbers of
/// indices and values differ (`zip_eq`), if a rebased index is out of bounds, or if a patch
/// value cannot be converted into `ValuesDVT`.
fn patch_typed<I, ValuesDVT, PatchDVT>(
    mut buffer: BufferMut<ValuesDVT>,
    decimal_dtype: DecimalDType,
    patch_indices: &[I],
    patch_indices_offset: usize,
    patch_values: Buffer<PatchDVT>,
) -> DecimalData
where
    I: IntegerPType,
    PatchDVT: NativeDecimalType,
    ValuesDVT: NativeDecimalType,
{
    let target_type = ValuesDVT::DECIMAL_TYPE;
    if !target_type.is_compatible_decimal_value_type(decimal_dtype) {
        vortex_panic!(
            "patch_typed: {:?} cannot represent every value in {}.",
            target_type,
            decimal_dtype
        )
    }

    let updates = patch_indices.iter().zip_eq(patch_values.into_iter());
    for (idx, value) in updates {
        // Convert first, then compute the slot, mirroring the evaluation order of `a[i] = v`.
        let converted = <ValuesDVT as BigCast>::from(value).vortex_expect(
            "values of a given DecimalDType are representable in all compatible NativeDecimalType",
        );
        let position: usize = idx.as_() - patch_indices_offset;
        buffer[position] = converted;
    }

    DecimalData::new(buffer.freeze(), decimal_dtype)
}