use vortex_array::dtype::PType;
use vortex_array::dtype::extension::ExtDTypeRef;
use vortex_array::dtype::extension::Matcher;
use crate::fixed_shape::AnyFixedShapeTensor;
use crate::fixed_shape::FixedShapeTensorMatcherMetadata;
use crate::vector::AnyVector;
use crate::vector::VectorMatcherMetadata;
/// Marker type whose [`Matcher`] implementation accepts an extension dtype
/// when it carries either [`AnyFixedShapeTensor`] or [`AnyVector`] metadata.
pub struct AnyTensor;
/// Outcome of a successful [`AnyTensor`] match: records which of the two
/// supported extension kinds was found, together with that kind's metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMatch<'a> {
/// The extension dtype yielded metadata through [`AnyFixedShapeTensor`].
FixedShapeTensor(FixedShapeTensorMatcherMetadata<'a>),
/// The extension dtype yielded metadata through [`AnyVector`].
Vector(VectorMatcherMetadata),
}
impl TensorMatch<'_> {
pub fn element_ptype(self) -> PType {
match self {
Self::FixedShapeTensor(metadata) => metadata.element_ptype(),
Self::Vector(metadata) => metadata.element_ptype(),
}
}
pub fn list_size(self) -> usize {
match self {
Self::FixedShapeTensor(metadata) => metadata.list_size(),
Self::Vector(metadata) => metadata.dimensions() as usize,
}
}
}
impl Matcher for AnyTensor {
    type Match<'a> = TensorMatch<'a>;

    /// Attempts to interpret `ext_dtype` as one of the supported tensor-like
    /// extension types.
    ///
    /// The fixed-shape tensor metadata is probed first; only if that lookup
    /// yields nothing is the vector metadata probed. Returns `None` when
    /// neither lookup produces metadata.
    fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
        ext_dtype
            .metadata_opt::<AnyFixedShapeTensor>()
            .map(TensorMatch::FixedShapeTensor)
            // Lazy fallback: the vector lookup only runs when the first probe missed.
            .or_else(|| {
                ext_dtype
                    .metadata_opt::<AnyVector>()
                    .map(TensorMatch::Vector)
            })
    }
}