use vortex_array::dtype::PType;
use vortex_array::dtype::extension::ExtDTypeRef;
use vortex_array::dtype::extension::Matcher;
use crate::types::fixed_shape_tensor::AnyFixedShapeTensor;
use crate::types::fixed_shape_tensor::FixedShapeTensorMatcherMetadata;
use crate::types::vector::AnyVector;
use crate::types::vector::VectorMatcherMetadata;
/// A [`Matcher`] that accepts any tensor-like extension dtype.
///
/// Matching tries the fixed-shape tensor matcher first and then the vector matcher,
/// producing a [`TensorMatch`] that records which of the two succeeded.
pub struct AnyTensor;
/// Outcome of successfully matching an extension dtype with [`AnyTensor`].
///
/// Each variant carries the matcher metadata produced by the underlying
/// matcher that accepted the dtype.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorMatch<'a> {
    /// The dtype was accepted by [`AnyFixedShapeTensor`].
    FixedShapeTensor(FixedShapeTensorMatcherMetadata<'a>),
    /// The dtype was accepted by [`AnyVector`].
    Vector(VectorMatcherMetadata),
}
impl TensorMatch<'_> {
pub fn element_ptype(self) -> PType {
match self {
Self::FixedShapeTensor(metadata) => metadata.element_ptype(),
Self::Vector(metadata) => metadata.element_ptype(),
}
}
pub fn list_size(self) -> u32 {
match self {
Self::FixedShapeTensor(metadata) => metadata.flat_list_size(),
Self::Vector(metadata) => metadata.dimensions(),
}
}
}
impl Matcher for AnyTensor {
    type Match<'a> = TensorMatch<'a>;

    /// Matches `ext_dtype` as either a fixed-shape tensor or a vector.
    ///
    /// The fixed-shape tensor lookup runs first; the vector lookup runs only when
    /// that one yields nothing. Returns `None` if neither matcher accepts the dtype.
    fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
        ext_dtype
            .metadata_opt::<AnyFixedShapeTensor>()
            .map(TensorMatch::FixedShapeTensor)
            // `or_else` is lazy, so the vector lookup is skipped on a fixed-shape hit.
            .or_else(|| {
                ext_dtype
                    .metadata_opt::<AnyVector>()
                    .map(TensorMatch::Vector)
            })
    }
}