use half::f16;
use prost::Message;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::Constant;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFn;
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex_array::arrays::primitive::PrimitiveArrayExt;
use vortex_array::arrays::scalar_fn::ExactScalarFn;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::arrays::scalar_fn::ScalarFnArrayView;
use vortex_array::dtype::DType;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::PType;
use vortex_array::dtype::proto::dtype as pb;
use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_session::VortexSession;
use crate::matcher::AnyTensor;
use crate::matcher::TensorMatch;
use crate::scalar_fns::l2_denorm::L2Denorm;
/// Multiplier applied to the machine epsilon term in [`unit_norm_tolerance`].
pub(crate) const SAFETY_FACTOR: usize = 10;
/// Computes `SAFETY_FACTOR * ε * sqrt(dimensions)`, where `ε` is the machine epsilon of the
/// float type named by `element_ptype`.
///
/// # Panics
///
/// Panics if `element_ptype` is not one of `F16`, `F32` or `F64`.
pub fn unit_norm_tolerance(element_ptype: PType, dimensions: usize) -> f64 {
    let epsilon = match element_ptype {
        PType::F16 => f16::EPSILON.to_f64_const(),
        PType::F32 => f64::from(f32::EPSILON),
        PType::F64 => f64::EPSILON,
        _ => unreachable!("unit_norm_tolerance requires a float ptype, got {element_ptype:?}"),
    };
    // Keep the multiplication order `(factor * epsilon) * sqrt(dimensions)` so the rounding of
    // the result does not change.
    (SAFETY_FACTOR as f64) * epsilon * (dimensions as f64).sqrt()
}
/// Returns children 0 and 1 of an `L2Denorm` scalar-function array as `(normalized, norms)`.
///
/// # Panics
///
/// Panics (through `vortex_expect`) if `array` is not an exact `L2Denorm` scalar-function array
/// or if either child is missing.
pub fn extract_l2_denorm_children(array: &ArrayRef) -> (ArrayRef, ArrayRef) {
    let view = array
        .as_opt::<ExactScalarFn<L2Denorm>>()
        .vortex_expect("expected ScalarFnArray wrapping L2Denorm");
    let normalized = view
        .nth_child(0)
        .vortex_expect("L2Denorm missing normalized array");
    let norms = view.nth_child(1).vortex_expect("L2Denorm missing norms");
    (normalized, norms)
}
/// Checks that `input_dtype` is an extension dtype carrying `AnyTensor` metadata whose element
/// ptype is a float, and returns the resulting [`TensorMatch`].
///
/// # Errors
///
/// Returns an error if `input_dtype` is not an extension type, if its metadata does not match
/// `AnyTensor`, or if the element ptype is not a float.
pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult<TensorMatch<'_>> {
    let Some(ext) = input_dtype.as_extension_opt() else {
        return Err(vortex_err!("expected an extension type, got {input_dtype}"));
    };
    let Some(tensor_match) = ext.metadata_opt::<AnyTensor>() else {
        return Err(vortex_err!("expected an `AnyTensor`, got {input_dtype}"));
    };

    let ptype = tensor_match.element_ptype();
    vortex_ensure!(
        ptype.is_float(),
        "expected a float element dtype, got {ptype}",
    );

    Ok(tensor_match)
}
/// Validates the two input dtypes of a binary tensor expression.
///
/// The dtypes must be equal ignoring nullability; `lhs` is then validated with
/// [`validate_tensor_float_input`] and its [`TensorMatch`] is returned.
///
/// # Errors
///
/// Returns an error if the dtypes differ (ignoring nullability) or if `lhs` fails
/// [`validate_tensor_float_input`].
pub fn validate_binary_tensor_float_inputs<'a>(
    lhs: &'a DType,
    rhs: &DType,
) -> VortexResult<TensorMatch<'a>> {
    let same_dtype = lhs.eq_ignore_nullability(rhs);
    vortex_ensure!(
        same_dtype,
        "binary tensor expression expects inputs to have the same dtype, got {lhs} and {rhs}"
    );

    // Only `lhs` is inspected from here on; `rhs` was just checked to have the same dtype
    // (ignoring nullability).
    validate_tensor_float_input(lhs)
}
/// Flat primitive elements plus the information needed to slice them into fixed-size rows.
///
/// Constructed by [`extract_flat_elements`].
pub struct FlatElements {
    /// Primitive array holding the element values that [`FlatElements::row`] slices into.
    elems: PrimitiveArray,
    /// Number of elements in each row returned by [`FlatElements::row`].
    list_size: usize,
    /// When `true`, [`FlatElements::row`] ignores the requested index and always reads row 0.
    is_constant: bool,
}
impl FlatElements {
    /// Returns the ptype reported by the underlying primitive element array.
    #[must_use]
    pub fn ptype(&self) -> PType {
        let Self { elems, .. } = self;
        elems.ptype()
    }

    /// Returns row `i` as a slice of `list_size` elements of type `T`.
    ///
    /// When the elements were extracted from constant storage, row 0 is returned for every `i`.
    ///
    /// # Panics
    ///
    /// Panics if the requested row lies outside the element slice.
    #[must_use]
    pub fn row<T: NativePType>(&self, i: usize) -> &[T] {
        let width = self.list_size;
        let effective_row = if self.is_constant { 0 } else { i };
        let values = self.elems.as_slice::<T>();
        // Slice in two steps: first drop everything before the row, then keep `width` elements.
        let from_row_start = &values[effective_row * width..];
        &from_row_start[..width]
    }
}
/// Executes tensor `storage` into a flat primitive element array wrapped in [`FlatElements`].
///
/// If `storage` is `Constant`-encoded, a length-1 `ConstantArray` built from its scalar is
/// executed instead, and the result is flagged as constant. Otherwise `storage` itself is
/// executed. `list_size` is stored as given; it is not compared against the executed array.
///
/// # Errors
///
/// Returns an error if execution into a `FixedSizeListArray` or `PrimitiveArray` fails, or if
/// the executed elements are nullable.
pub fn extract_flat_elements(
    storage: &ArrayRef,
    list_size: usize,
    ctx: &mut ExecutionCtx,
) -> VortexResult<FlatElements> {
    let (source, is_constant) = match storage.as_opt::<Constant>() {
        Some(constant) => {
            // Materialize a single row rather than one row per logical element.
            let one_row = ConstantArray::new(constant.scalar().clone(), 1).into_array();
            (one_row, true)
        }
        None => (storage.clone(), false),
    };

    let fsl: FixedSizeListArray = source.execute(ctx)?;
    let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;

    let elems_nullable = elems.nullability().is_nullable();
    vortex_ensure!(
        !elems_nullable,
        "tensor storage elements must be non-nullable, got {}",
        elems.dtype(),
    );

    Ok(FlatElements {
        elems,
        list_size,
        is_constant,
    })
}
/// Primitive elements of a single row, as produced by [`extract_constant_flat_row`].
pub struct FlatRow {
    /// Primitive array exposed through [`FlatRow::as_slice`].
    elems: PrimitiveArray,
}
impl FlatRow {
    /// Returns the ptype reported by the underlying primitive element array.
    #[must_use]
    pub fn ptype(&self) -> PType {
        let Self { elems } = self;
        elems.ptype()
    }

    /// Returns all stored elements as a slice of `T`.
    #[must_use]
    pub fn as_slice<T: NativePType>(&self) -> &[T] {
        let Self { elems } = self;
        elems.as_slice::<T>()
    }
}
/// Extracts the elements of the scalar behind `Constant`-encoded tensor `storage` as a
/// [`FlatRow`].
///
/// A length-1 `ConstantArray` is built from the storage scalar and executed into a
/// `FixedSizeListArray`, whose elements are then executed into a `PrimitiveArray`.
///
/// # Panics
///
/// Panics (through `vortex_expect`) if `storage` is not `Constant`-encoded.
///
/// # Errors
///
/// Returns an error if either execution step fails or if the executed elements are nullable.
pub fn extract_constant_flat_row(
    storage: &ArrayRef,
    ctx: &mut ExecutionCtx,
) -> VortexResult<FlatRow> {
    let scalar = storage
        .as_opt::<Constant>()
        .vortex_expect("extract_constant_flat_row requires Constant-backed storage")
        .scalar()
        .clone();

    let one_row = ConstantArray::new(scalar, 1).into_array();
    let fsl: FixedSizeListArray = one_row.execute(ctx)?;
    let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;

    let elems_nullable = elems.nullability().is_nullable();
    vortex_ensure!(
        !elems_nullable,
        "tensor storage elements must be non-nullable, got {}",
        elems.dtype(),
    );

    Ok(FlatRow { elems })
}
/// Protobuf message recording the dtypes of the two children of a binary tensor operation.
///
/// Both fields are optional on the wire; `decode_children` rejects metadata where either one
/// is missing.
#[derive(Clone, prost::Message)]
pub(crate) struct BinaryTensorOpMetadata {
    /// Dtype of child 0 (the left-hand side), protobuf tag 1.
    #[prost(message, optional, tag = "1")]
    pub(crate) lhs_dtype: Option<pb::DType>,
    /// Dtype of child 1 (the right-hand side), protobuf tag 2.
    #[prost(message, optional, tag = "2")]
    pub(crate) rhs_dtype: Option<pb::DType>,
}
impl BinaryTensorOpMetadata {
    /// Encodes the dtypes of children 0 and 1 of the scalar-function array behind `view` as
    /// protobuf bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if either child dtype cannot be converted to its protobuf form.
    pub(crate) fn encode_from_view<V: ScalarFnVTable>(
        view: &ScalarFnArrayView<V>,
    ) -> VortexResult<Vec<u8>> {
        let array = view.as_::<ScalarFn>();
        let metadata = Self {
            lhs_dtype: Some(array.child_at(0).dtype().try_into()?),
            rhs_dtype: Some(array.child_at(1).dtype().try_into()?),
        };
        Ok(metadata.encode_to_vec())
    }

    /// Decodes `metadata` and uses the recorded dtypes to fetch the two children of length `len`
    /// from `children`, returning them as `[lhs, rhs]`.
    ///
    /// # Errors
    ///
    /// Returns an error if the bytes do not decode, if either dtype is missing or cannot be
    /// converted from protobuf, if the dtypes fail [`validate_binary_tensor_float_inputs`], or if
    /// fetching either child fails.
    pub(crate) fn decode_children(
        metadata: &[u8],
        len: usize,
        children: &dyn vortex_array::serde::ArrayChildren,
        session: &VortexSession,
    ) -> VortexResult<Vec<ArrayRef>> {
        let Self {
            lhs_dtype,
            rhs_dtype,
        } = Self::decode(metadata)
            .map_err(|e| vortex_err!("Failed to decode BinaryTensorOpMetadata: {e}"))?;

        let lhs_pb = lhs_dtype.ok_or_else(|| vortex_err!("metadata missing lhs_dtype"))?;
        let rhs_pb = rhs_dtype.ok_or_else(|| vortex_err!("metadata missing rhs_dtype"))?;

        let lhs_dtype = DType::from_proto(&lhs_pb, session)?;
        let rhs_dtype = DType::from_proto(&rhs_pb, session)?;
        // Reject dtypes that are not valid binary tensor inputs before touching the children.
        validate_binary_tensor_float_inputs(&lhs_dtype, &rhs_dtype)?;

        Ok(vec![
            children.get(0, &lhs_dtype, len)?,
            children.get(1, &rhs_dtype, len)?,
        ])
    }
}
#[cfg(test)]
pub mod test_helpers {
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::dtype::DType;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::extension::ExtDType;
use vortex_array::scalar::PValue;
use vortex_array::scalar::Scalar;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_error::VortexResult;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::types::fixed_shape_tensor::FixedShapeTensor;
use crate::types::fixed_shape_tensor::FixedShapeTensorMetadata;
use crate::types::vector::Vector;
/// Builds a non-nullable `FixedSizeListArray` from a flat slice of elements.
///
/// The row count is `elements.len() / list_size` (integer division).
///
/// # Panics
///
/// Panics if `list_size` is zero, because the row count is computed by dividing by it.
fn flat_fsl<T: NativePType>(elements: &[T], list_size: u32) -> ArrayRef {
    let rows = elements.len() / list_size as usize;
    let values: ArrayRef = Buffer::copy_from(elements).into_array();
    let fsl = FixedSizeListArray::new(values, list_size, Validity::NonNullable, rows);
    fsl.into_array()
}
/// Builds a non-nullable fixed-size-list scalar whose children are the given elements, each
/// wrapped as a non-nullable primitive scalar.
fn fsl_scalar<T: NativePType + Into<PValue>>(elements: &[T]) -> Scalar {
    let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable);
    let mut children: Vec<Scalar> = Vec::with_capacity(elements.len());
    for &value in elements {
        children.push(Scalar::primitive(value, Nullability::NonNullable));
    }
    Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable)
}
/// Builds a `FixedShapeTensor` extension array with the given `shape` over flat `elements`.
///
/// The storage list size is the product of `shape`, raised to at least 1.
///
/// # Panics
///
/// Panics if the list size does not fit in a `u32`.
pub fn tensor_array<T: NativePType>(shape: &[usize], elements: &[T]) -> VortexResult<ArrayRef> {
    let flat_len: usize = shape.iter().product();
    let list_size = u32::try_from(flat_len.max(1)).unwrap();
    let storage = flat_fsl(elements, list_size);

    let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
    let storage_dtype = storage.dtype().clone();
    let ext_dtype = ExtDType::<FixedShapeTensor>::try_new(metadata, storage_dtype)?.erased();

    Ok(ExtensionArray::new(ext_dtype, storage).into_array())
}
/// Builds a `Vector` array from flat `elements`, using `dim` as the fixed-size-list size of the
/// storage.
pub fn vector_array<T: NativePType>(dim: u32, elements: &[T]) -> VortexResult<ArrayRef> {
    let storage = flat_fsl(elements, dim);
    Vector::try_new_vector_array(storage)
}
/// Builds a `FixedShapeTensor` extension array of `len` rows whose storage is a `ConstantArray`
/// repeating the fixed-size-list scalar built from `elements`.
pub fn constant_tensor_array<T: NativePType + Into<PValue>>(
    shape: &[usize],
    elements: &[T],
    len: usize,
) -> VortexResult<ArrayRef> {
    let row = fsl_scalar(elements);
    let storage = ConstantArray::new(row, len).into_array();

    let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
    let storage_dtype = storage.dtype().clone();
    let ext_dtype = ExtDType::<FixedShapeTensor>::try_new(metadata, storage_dtype)?.erased();

    Ok(ExtensionArray::new(ext_dtype, storage).into_array())
}
/// Builds a `ConstantArray` of `len` rows whose scalar is a `Vector` extension scalar wrapping
/// the fixed-size-list scalar built from `elements`.
pub fn literal_vector_array<T: NativePType + Into<PValue>>(
    elements: &[T],
    len: usize,
) -> ArrayRef {
    use vortex_array::EmptyMetadata;

    let storage_scalar = fsl_scalar(elements);
    let vector_scalar = Scalar::extension::<Vector>(EmptyMetadata, storage_scalar);
    ConstantArray::new(vector_scalar, len).into_array()
}
/// Builds an `L2Denorm` array from a tensor of `normalized_elements` with the given `shape` and
/// a non-nullable primitive array of `norms`.
pub fn l2_denorm_array<T: NativePType>(
    shape: &[usize],
    normalized_elements: &[T],
    norms: &[T],
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
    let normalized = tensor_array(shape, normalized_elements)?;
    let norms_buffer = Buffer::copy_from(norms);
    let norms = PrimitiveArray::new(norms_buffer, Validity::NonNullable).into_array();
    let denorm = L2Denorm::try_new_array(normalized, norms, ctx)?;
    Ok(denorm.into_array())
}
/// Asserts that two `f64` slices have the same length and are element-wise close.
///
/// A pair of elements is accepted when any of the following holds:
/// - the values compare exactly equal (this covers matching infinities),
/// - both values are NaN,
/// - the absolute difference is below `1e-10`.
///
/// # Panics
///
/// Panics if the lengths differ or if any pair of elements is not accepted.
#[track_caller]
pub fn assert_close(actual: &[f64], expected: &[f64]) {
    assert_eq!(
        actual.len(),
        expected.len(),
        "length mismatch: got {} elements, expected {}",
        actual.len(),
        expected.len()
    );
    for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
        // Exact equality must be checked first: for equal infinities `a - e` is NaN, which
        // would fail the tolerance comparison below even though the values match.
        if a == e || (a.is_nan() && e.is_nan()) {
            continue;
        }
        let diff = (a - e).abs();
        assert!(
            diff < 1e-10,
            "element {i}: got {a}, expected {e} (diff = {})",
            diff
        );
    }
}
}