use std::fmt::Display;
use std::fmt::Formatter;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use crate::ArrayRef;
use crate::array::Array;
use crate::array::ArrayParts;
use crate::array::TypedArrayRef;
use crate::arrays::Extension;
use crate::dtype::DType;
use crate::dtype::extension::ExtDTypeRef;
/// Index of the storage child array within an extension array's slots.
pub(super) const STORAGE_SLOT: usize = 0;
/// Number of child slots an extension array carries (just the storage child).
pub(super) const NUM_SLOTS: usize = 1;
/// Name of each slot, indexed by slot position.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["storage"];
/// Array-specific data for an extension array: the extension dtype it was built with.
///
/// The storage child is not stored here; `Array::<Extension>` keeps it in the
/// `STORAGE_SLOT` slot of the array.
#[derive(Clone, Debug)]
pub struct ExtensionData {
    /// The extension dtype. Its storage dtype is expected to equal the dtype of the
    /// storage child (checked by `try_new`, debug-asserted by `new_unchecked`).
    pub(super) ext_dtype: ExtDTypeRef,
}
impl Display for ExtensionData {
    /// Writes the data as `ext_dtype: <extension dtype>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let ext_dtype = &self.ext_dtype;
        write!(f, "ext_dtype: {ext_dtype}")
    }
}
impl ExtensionData {
pub fn new(ext_dtype: ExtDTypeRef, storage_dtype: &DType) -> Self {
Self::try_new(ext_dtype, storage_dtype).vortex_expect("Failed to create `ExtensionArray`")
}
pub fn try_new(ext_dtype: ExtDTypeRef, storage_dtype: &DType) -> VortexResult<Self> {
assert_eq!(
ext_dtype.storage_dtype(),
storage_dtype,
"ExtensionArray: storage_dtype must match storage array DType",
);
Ok(unsafe { Self::new_unchecked(ext_dtype, storage_dtype) })
}
pub unsafe fn new_unchecked(ext_dtype: ExtDTypeRef, storage_dtype: &DType) -> Self {
debug_assert_eq!(
ext_dtype.storage_dtype(),
storage_dtype,
"ExtensionArray: storage_dtype must match storage array DType",
);
Self { ext_dtype }
}
pub fn ext_dtype(&self) -> &ExtDTypeRef {
&self.ext_dtype
}
}
/// Convenience accessors available on any typed reference to an extension array.
pub trait ExtensionArrayExt: TypedArrayRef<Extension> {
    /// Returns the storage child array kept in the `STORAGE_SLOT` slot.
    ///
    /// # Panics
    ///
    /// Panics if that slot is missing or holds `None`.
    fn storage_array(&self) -> &ArrayRef {
        let slots = self.as_ref().slots();
        slots[STORAGE_SLOT]
            .as_ref()
            .vortex_expect("ExtensionArray storage slot")
    }
}

impl<T> ExtensionArrayExt for T where T: TypedArrayRef<Extension> {}
impl Array<Extension> {
    /// Creates an extension array of type `ext_dtype` backed by `storage_array`.
    ///
    /// Delegates to [`Self::try_new`] so the construction logic lives in one place.
    ///
    /// # Panics
    ///
    /// Panics if `storage_array.dtype()` does not match `ext_dtype`'s storage dtype.
    pub fn new(ext_dtype: ExtDTypeRef, storage_array: ArrayRef) -> Self {
        Self::try_new(ext_dtype, storage_array).vortex_expect("Failed to create `ExtensionArray`")
    }

    /// Fallibly creates an extension array of type `ext_dtype` backed by `storage_array`.
    ///
    /// The resulting array has the same length as `storage_array`, and stores it in the
    /// single storage slot.
    ///
    /// # Errors
    ///
    /// Returns whatever error `ExtensionData::try_new` reports when `storage_array.dtype()`
    /// does not match `ext_dtype`'s storage dtype.
    pub fn try_new(ext_dtype: ExtDTypeRef, storage_array: ArrayRef) -> VortexResult<Self> {
        // Validate first so a mismatch returns before any other work is done.
        let data = ExtensionData::try_new(ext_dtype.clone(), storage_array.dtype())?;
        let dtype = DType::Extension(ext_dtype);
        let len = storage_array.len();
        // SAFETY: the storage dtype was validated against `ext_dtype` above, `len` is taken
        // from the storage child, and the slot vector has exactly one entry (the storage
        // child at `STORAGE_SLOT`). NOTE(review): assumes these are the invariants
        // `from_parts_unchecked` relies on — confirm against its `# Safety` docs.
        Ok(unsafe {
            Array::from_parts_unchecked(
                ArrayParts::new(Extension, dtype, len, data).with_slots(vec![Some(storage_array)]),
            )
        })
    }
}