use vortex_array::dtype::DType;
use vortex_array::dtype::PType;
use vortex_array::dtype::extension::ExtDTypeRef;
use vortex_array::dtype::extension::Matcher;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_panic;
use crate::vector::Vector;
/// Extension-type matcher that recognizes any [`Vector`] extension dtype, whatever its
/// element type or dimensionality, and yields its [`VectorMatcherMetadata`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AnyVector;
/// Element type and dimensionality extracted from a matched `Vector` extension dtype.
///
/// Instances are produced by the [`AnyVector`] matcher, or built directly with
/// [`VectorMatcherMetadata::try_new`], which only accepts float element types.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct VectorMatcherMetadata {
    /// Primitive type of every vector element (a float type when built via `try_new`).
    element_ptype: PType,
    /// Number of elements in each vector.
    dimensions: u32,
}
impl Matcher for AnyVector {
    type Match<'a> = VectorMatcherMetadata;

    /// Returns the element type and dimensionality of `ext_dtype` if it is a [`Vector`]
    /// extension dtype, or `None` for any other extension type.
    ///
    /// # Panics
    ///
    /// Panics if `ext_dtype` is a [`Vector`] whose storage dtype is not a `FixedSizeList`
    /// of non-nullable float elements.
    fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
        if !ext_dtype.is::<Vector>() {
            return None;
        }

        // NOTE(review): the `Vector` extension type is expected to guarantee this storage
        // shape (presumably validated when the dtype is constructed — confirm in
        // `crate::vector`), so a mismatch is treated as a broken invariant and panics
        // instead of being reported as a non-match.
        let DType::FixedSizeList(element_dtype, list_size, _) = ext_dtype.storage_dtype() else {
            vortex_panic!("`Vector` type somehow did not have a `FixedSizeList` storage type")
        };

        // Validate the element dtype up front so the panic message names what was found,
        // before `as_ptype` below extracts the primitive type.
        if !element_dtype.is_float() {
            vortex_panic!("`Vector` element dtype must be a float type, got {element_dtype}");
        }
        if element_dtype.is_nullable() {
            vortex_panic!("`Vector` element dtype must be non-nullable, got {element_dtype}");
        }

        let metadata = VectorMatcherMetadata::try_new(element_dtype.as_ptype(), *list_size)
            .vortex_expect("`Vector` element dtype was already checked to be a float type");
        Some(metadata)
    }
}
impl VectorMatcherMetadata {
    /// Creates metadata for vectors of `dimensions` elements of type `element_ptype`.
    ///
    /// # Errors
    ///
    /// Returns an error if `element_ptype` is not a floating-point type.
    pub fn try_new(element_ptype: PType, dimensions: u32) -> VortexResult<Self> {
        vortex_ensure!(
            element_ptype.is_float(),
            "`Vector` element type must be a float type, got {element_ptype}"
        );
        Ok(Self {
            element_ptype,
            dimensions,
        })
    }

    /// Returns the primitive type of each vector element; always a float type.
    pub fn element_ptype(&self) -> PType {
        self.element_ptype
    }

    /// Returns the number of elements in each vector.
    pub fn dimensions(&self) -> u32 {
        self.dimensions
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_error::VortexResult;
    use super::*;
    use crate::fixed_shape::FixedShapeTensor;
    use crate::fixed_shape::FixedShapeTensorMetadata;

    /// Builds a non-nullable `FixedSizeList` of `dimensions` non-nullable `element_ptype`
    /// primitives, the storage shape that `AnyVector` expects for a `Vector` dtype.
    fn vector_storage_dtype(element_ptype: PType, dimensions: u32) -> DType {
        DType::FixedSizeList(
            Arc::new(DType::Primitive(element_ptype, Nullability::NonNullable)),
            dimensions,
            Nullability::NonNullable,
        )
    }

    #[test]
    fn matches_vector_dtype_metadata() -> VortexResult<()> {
        let ext_dtype =
            ExtDType::<Vector>::try_new(EmptyMetadata, vector_storage_dtype(PType::F32, 256))?
                .erased();
        let metadata = ext_dtype.metadata::<AnyVector>();
        assert_eq!(metadata.element_ptype(), PType::F32);
        assert_eq!(metadata.dimensions(), 256);
        Ok(())
    }

    #[test]
    fn does_not_match_fixed_shape_tensor() -> VortexResult<()> {
        let ext_dtype = ExtDType::<FixedShapeTensor>::try_new(
            FixedShapeTensorMetadata::new(vec![16, 16]),
            vector_storage_dtype(PType::F32, 256),
        )?
        .erased();
        assert!(ext_dtype.metadata_opt::<AnyVector>().is_none());
        Ok(())
    }

    #[test]
    fn metadata_accepts_float_element_ptype() -> VortexResult<()> {
        let metadata = VectorMatcherMetadata::try_new(PType::F64, 3)?;
        assert_eq!(metadata.element_ptype(), PType::F64);
        assert_eq!(metadata.dimensions(), 3);
        Ok(())
    }

    #[test]
    fn metadata_rejects_non_float_element_ptype() {
        assert!(VectorMatcherMetadata::try_new(PType::I32, 3).is_err());
        assert!(VectorMatcherMetadata::try_new(PType::U8, 3).is_err());
    }
}