use vortex_array::dtype::DType;
use vortex_array::dtype::PType;
use vortex_array::dtype::extension::ExtDTypeRef;
use vortex_array::dtype::extension::Matcher;
use vortex_error::VortexExpect;
use vortex_error::vortex_panic;
use crate::fixed_shape::FixedShapeTensor;
use crate::fixed_shape::FixedShapeTensorMetadata;
/// [`Matcher`] that matches any extension dtype whose extension type is [`FixedShapeTensor`],
/// whatever its logical shape.
///
/// A successful match yields a [`FixedShapeTensorMatcherMetadata`] describing the tensor.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AnyFixedShapeTensor;
/// The result of matching an extension dtype with `AnyFixedShapeTensor`.
///
/// Borrows the `FixedShapeTensor` metadata from the matched extension dtype for `'a`, and
/// stores the element type and list size read from its `FixedSizeList` storage dtype.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FixedShapeTensorMatcherMetadata<'a> {
/// The `FixedShapeTensor` extension metadata, borrowed from the matched dtype.
metadata: &'a FixedShapeTensorMetadata,
/// Primitive type of the storage list's (non-nullable) elements.
element_ptype: PType,
/// Fixed list size of the `FixedSizeList` storage dtype, widened from `u32` to `usize`.
flat_list_size: usize,
}
impl Matcher for AnyFixedShapeTensor {
    type Match<'a> = FixedShapeTensorMatcherMetadata<'a>;

    /// Matches any extension dtype whose extension type is `FixedShapeTensor`.
    ///
    /// Returns `None` when `ext_dtype` is some other extension type. On a match, the returned
    /// value borrows the tensor metadata from `ext_dtype` and records the element `PType` and
    /// list size of its `FixedSizeList` storage dtype.
    ///
    /// # Panics
    ///
    /// Panics if a `FixedShapeTensor` dtype has no metadata, if its storage dtype is not a
    /// `FixedSizeList`, or if the list's element dtype is non-primitive or nullable.
    fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
        let is_fixed_shape_tensor = ext_dtype.is::<FixedShapeTensor>();
        if !is_fixed_shape_tensor {
            return None;
        }

        // Once the extension type is known to be `FixedShapeTensor`, missing metadata or an
        // unexpected storage layout is an invariant violation (panic), not a non-match.
        let tensor_metadata = ext_dtype
            .metadata_opt::<FixedShapeTensor>()
            .vortex_expect("`FixedShapeTensor` type somehow did not have metadata");

        let (element_dtype, list_size) = match ext_dtype.storage_dtype() {
            DType::FixedSizeList(element_dtype, list_size, _) => (element_dtype, *list_size),
            _ => {
                vortex_panic!(
                    "`FixedShapeTensor` type somehow did not have a `FixedSizeList` storage type"
                )
            }
        };

        assert!(element_dtype.is_primitive(), "element dtype must be primitive");
        assert!(!element_dtype.is_nullable(), "element dtype must be non-nullable");

        let element_ptype = element_dtype.as_ptype();
        // The storage list size is a `u32`; widen it for callers that index with `usize`.
        let flat_list_size = list_size as usize;

        Some(FixedShapeTensorMatcherMetadata {
            metadata: tensor_metadata,
            element_ptype,
            flat_list_size,
        })
    }
}
impl<'a> FixedShapeTensorMatcherMetadata<'a> {
    /// Returns the `FixedShapeTensor` extension metadata of the matched dtype.
    ///
    /// The reference carries the lifetime `'a` of the matched extension dtype rather than the
    /// lifetime of `&self`, so it stays valid after this (`Copy`) match value is dropped.
    pub fn metadata(&self) -> &'a FixedShapeTensorMetadata {
        self.metadata
    }

    /// Returns the primitive type of the tensor elements, taken from the element dtype of the
    /// `FixedSizeList` storage dtype.
    pub fn element_ptype(&self) -> PType {
        self.element_ptype
    }

    /// Returns the fixed list size of the `FixedSizeList` storage dtype, i.e. the number of
    /// elements stored for each tensor value.
    pub fn list_size(&self) -> usize {
        self.flat_list_size
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_error::VortexResult;

    use super::*;
    use crate::vector::Vector;

    /// Builds a non-nullable `FixedSizeList` dtype of `list_size` non-nullable `element_ptype`
    /// values: the storage layout the matcher expects.
    fn tensor_storage_dtype(element_ptype: PType, list_size: u32) -> DType {
        let element_dtype = Arc::new(DType::Primitive(element_ptype, Nullability::NonNullable));
        DType::FixedSizeList(element_dtype, list_size, Nullability::NonNullable)
    }

    #[test]
    fn matches_fixed_shape_tensor_dtype_metadata() -> VortexResult<()> {
        // A 2x3x4 tensor is stored as lists of 24 elements.
        let shape_metadata = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
        let storage_dtype = tensor_storage_dtype(PType::F32, 24);
        let tensor_dtype =
            ExtDType::<FixedShapeTensor>::try_new(shape_metadata, storage_dtype)?.erased();

        let matched = tensor_dtype.metadata::<AnyFixedShapeTensor>();
        assert_eq!(matched.element_ptype(), PType::F32);
        assert_eq!(matched.list_size(), 24);
        assert_eq!(matched.metadata().logical_shape(), &[2, 3, 4]);
        Ok(())
    }

    #[test]
    fn does_not_match_vector() -> VortexResult<()> {
        // `Vector` is built over the same storage dtype here, so only the extension-type check
        // can reject it.
        let vector_dtype =
            ExtDType::<Vector>::try_new(EmptyMetadata, tensor_storage_dtype(PType::F32, 24))?
                .erased();

        let matched = vector_dtype.metadata_opt::<AnyFixedShapeTensor>();
        assert!(matched.is_none());
        Ok(())
    }
}