use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::Constant;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex_array::arrays::scalar_fn::ExactScalarFn;
use vortex_array::dtype::DType;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::PType;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use crate::matcher::AnyTensor;
use crate::matcher::TensorMatch;
use crate::scalar_fns::l2_denorm::L2Denorm;
/// Checks that `input_dtype` is an [`AnyTensor`] extension type whose elements are floats.
///
/// On success, returns the [`TensorMatch`] view, which borrows from `input_dtype`.
///
/// # Errors
///
/// Returns an error when any of the following holds:
/// - `input_dtype` is not an extension type;
/// - its extension metadata does not match [`AnyTensor`];
/// - the tensor's element primitive type is not a floating-point type.
pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult<TensorMatch<'_>> {
    let Some(ext) = input_dtype.as_extension_opt() else {
        return Err(vortex_err!("expected an extension type, got {input_dtype}"));
    };
    let Some(tensor_match) = ext.metadata_opt::<AnyTensor>() else {
        return Err(vortex_err!("expected an `AnyTensor`, got {input_dtype}"));
    };
    let element_ptype = tensor_match.element_ptype();
    vortex_ensure!(
        element_ptype.is_float(),
        "expected a float element dtype, got {element_ptype}"
    );
    Ok(tensor_match)
}
/// Row-addressable view over a flattened primitive element buffer.
///
/// Row `i` is the run of `list_size` elements starting at offset `i * stride`. A `stride`
/// of `0` makes every row index resolve to the same (first) row.
pub struct FlatElements {
    /// Number of elements in each row.
    list_size: usize,
    /// Distance, in elements, between the starts of consecutive rows.
    stride: usize,
    /// The flattened element buffer that rows are sliced from.
    elems: PrimitiveArray,
}
impl FlatElements {
    /// Returns the primitive type of the underlying element buffer.
    #[must_use]
    pub fn ptype(&self) -> PType {
        self.elems.ptype()
    }

    /// Returns row `i` as a slice of exactly `list_size` elements of type `T`.
    ///
    /// The row starts at element offset `i * stride`.
    ///
    /// # Panics
    ///
    /// Panics if the row extends past the end of the element buffer. `T` is presumably
    /// required to match [`Self::ptype`] (this depends on `PrimitiveArray::as_slice`;
    /// confirm).
    #[must_use]
    pub fn row<T: NativePType>(&self, i: usize) -> &[T] {
        let offset = i * self.stride;
        let elements = self.elems.as_slice::<T>();
        &elements[offset..][..self.list_size]
    }
}
/// Executes tensor `storage` into a flat primitive element buffer that can be read by row.
///
/// If `storage` is a [`Constant`] array, only a single row is materialized and the
/// returned stride is `0`, so every row index resolves to that one row. Otherwise the
/// whole storage is executed into a [`FixedSizeListArray`], and the stride is
/// `list_size`.
///
/// # Errors
///
/// Returns an error if executing the storage into a [`FixedSizeListArray`], or its
/// elements into a [`PrimitiveArray`], fails.
pub fn extract_flat_elements(
    storage: &ArrayRef,
    list_size: usize,
    ctx: &mut ExecutionCtx,
) -> VortexResult<FlatElements> {
    let (source, stride) = match storage.as_opt::<Constant>() {
        // Materialize just one copy of the constant row; stride 0 makes all rows alias it.
        Some(constant) => (
            ConstantArray::new(constant.scalar().clone(), 1).into_array(),
            0,
        ),
        None => (storage.clone(), list_size),
    };

    let fsl: FixedSizeListArray = source.execute(ctx)?;
    let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;

    Ok(FlatElements {
        elems,
        stride,
        list_size,
    })
}
/// Splits an [`L2Denorm`] scalar-function array into its `(normalized, norms)` children.
///
/// The normalized array is child `0` and the norms are child `1`.
///
/// # Panics
///
/// Panics if `array` is not a scalar-function array wrapping [`L2Denorm`], or if either
/// child is missing.
pub fn extract_l2_denorm_children(array: &ArrayRef) -> (ArrayRef, ArrayRef) {
    let denorm = array
        .as_opt::<ExactScalarFn<L2Denorm>>()
        .vortex_expect("expected ScalarFnArray wrapping L2Denorm");
    let normalized = denorm
        .nth_child(0)
        .vortex_expect("L2Denorm missing normalized array");
    let norms = denorm.nth_child(1).vortex_expect("L2Denorm missing norms");
    (normalized, norms)
}
#[cfg(test)]
pub mod test_helpers {
    //! Constructors for small tensor/vector test arrays and a float comparison helper.

    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::ExtensionArray;
    use vortex_array::arrays::FixedSizeListArray;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_ensure;
    use vortex_error::vortex_err;

    use crate::fixed_shape::FixedShapeTensor;
    use crate::fixed_shape::FixedShapeTensorMetadata;
    use crate::vector::Vector;

    /// Absolute tolerance used by [`assert_close`].
    const TOLERANCE: f64 = 1e-10;

    /// Returns how many rows of `list_size` elements fit exactly into `len` elements.
    ///
    /// Rejects a zero `list_size` (which would divide by zero) and ragged input (which
    /// would otherwise silently drop trailing elements).
    fn row_count_for(len: usize, list_size: usize) -> VortexResult<usize> {
        vortex_ensure!(list_size > 0, "list size must be non-zero");
        vortex_ensure!(
            len % list_size == 0,
            "{len} elements do not divide evenly into rows of {list_size}"
        );
        Ok(len / list_size)
    }

    /// Builds storage that repeats a non-nullable `f64` fixed-size-list scalar `len` times.
    fn constant_f64_list_storage(elements: &[f64], len: usize) -> ArrayRef {
        let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let children: Vec<Scalar> = elements
            .iter()
            .map(|&v| Scalar::primitive(v, Nullability::NonNullable))
            .collect();
        let storage_scalar =
            Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
        ConstantArray::new(storage_scalar, len).into_array()
    }

    /// Builds a non-nullable `FixedShapeTensor` array of `f64` from row-major `elements`.
    ///
    /// Each row holds `product(shape)` elements. An empty shape gives one element per row.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the row size does not fit in `u32`;
    /// - `elements.len()` is not a multiple of the row size;
    /// - the extension dtype rejects the metadata/storage pair.
    pub fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult<ArrayRef> {
        // NOTE(review): `max(1)` maps shapes containing a zero dimension to a row size of 1,
        // which no longer matches `product(shape)`; presumably such shapes are unused —
        // confirm before relying on them.
        let flat_len = shape.iter().product::<usize>().max(1);
        let row_count = row_count_for(elements.len(), flat_len)?;
        let list_size = u32::try_from(flat_len)
            .map_err(|_| vortex_err!("tensor list size {flat_len} does not fit in u32"))?;

        let elems: ArrayRef = Buffer::copy_from(elements).into_array();
        let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count);
        let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
        let ext_dtype =
            ExtDType::<FixedShapeTensor>::try_new(metadata, fsl.dtype().clone())?.erased();
        Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
    }

    /// Builds a non-nullable `Vector` array of `f64`, where each row holds `dim` elements.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `dim` is zero;
    /// - `elements.len()` is not a multiple of `dim`;
    /// - the extension dtype rejects the storage.
    pub fn vector_array(dim: u32, elements: &[f64]) -> VortexResult<ArrayRef> {
        // u32 -> usize is a lossless widening on all supported targets.
        let row_count = row_count_for(elements.len(), dim as usize)?;

        let elems: ArrayRef = Buffer::copy_from(elements).into_array();
        let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count);
        let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
        Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
    }

    /// Builds a `FixedShapeTensor` array whose `len` rows all equal `elements`.
    ///
    /// The storage is a constant array.
    ///
    /// # Errors
    ///
    /// Returns an error if the extension dtype rejects the metadata/storage pair.
    pub fn constant_tensor_array(
        shape: &[usize],
        elements: &[f64],
        len: usize,
    ) -> VortexResult<ArrayRef> {
        let storage = constant_f64_list_storage(elements, len);
        let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
        let ext_dtype =
            ExtDType::<FixedShapeTensor>::try_new(metadata, storage.dtype().clone())?.erased();
        Ok(ExtensionArray::new(ext_dtype, storage).into_array())
    }

    /// Builds a `Vector` array whose `len` rows all equal `elements`.
    ///
    /// The storage is a constant array.
    ///
    /// # Errors
    ///
    /// Returns an error if the extension dtype rejects the storage.
    pub fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult<ArrayRef> {
        let storage = constant_f64_list_storage(elements, len);
        let ext_dtype =
            ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
        Ok(ExtensionArray::new(ext_dtype, storage).into_array())
    }

    /// Asserts that `actual` and `expected` agree element-wise.
    ///
    /// Two elements agree when any of the following holds:
    /// - they are exactly equal (this covers same-signed infinities);
    /// - both are NaN;
    /// - their absolute difference is below [`TOLERANCE`].
    ///
    /// # Panics
    ///
    /// Panics if the lengths differ or if any element pair does not agree.
    #[track_caller]
    pub fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {} elements, expected {}",
            actual.len(),
            expected.len()
        );
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            // Exact equality must be checked first: `inf - inf` is NaN, which would
            // otherwise fail the tolerance comparison for matching infinities.
            if a == e || (a.is_nan() && e.is_nan()) {
                continue;
            }
            let diff = (a - e).abs();
            assert!(
                diff < TOLERANCE,
                "element {i}: got {a}, expected {e} (diff = {diff})"
            );
        }
    }
}