use std::fmt::Formatter;
use std::sync::Arc;
use num_traits::Float;
use prost::Message;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::Constant;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::Dict;
use vortex_array::arrays::Extension;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::FixedSizeList;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::ScalarFnVTable as ScalarFnArrayEncoding;
use vortex_array::arrays::dict::DictArraySlotsExt;
use vortex_array::arrays::extension::ExtensionArrayExt;
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex_array::arrays::scalar_fn::ExactScalarFn;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::arrays::scalar_fn::ScalarFnArrayView;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable;
use vortex_array::dtype::DType;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::PType;
use vortex_array::dtype::extension::ExtDType;
use vortex_array::dtype::proto::dtype as pb;
use vortex_array::expr::Expression;
use vortex_array::expr::and;
use vortex_array::extension::EmptyMetadata;
use vortex_array::match_each_float_ptype;
use vortex_array::scalar::Scalar;
use vortex_array::scalar_fn::Arity;
use vortex_array::scalar_fn::ChildName;
use vortex_array::scalar_fn::EmptyOptions;
use vortex_array::scalar_fn::ExecutionArgs;
use vortex_array::scalar_fn::ScalarFn;
use vortex_array::scalar_fn::ScalarFnId;
use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_array::serde::ArrayChildren;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_session::VortexSession;
use crate::matcher::AnyTensor;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::sorf_transform::SorfMatrix;
use crate::scalar_fns::sorf_transform::SorfTransform;
use crate::utils::extract_flat_elements;
use crate::utils::extract_l2_denorm_children;
use crate::vector::Vector;
/// Scalar function computing the row-wise inner (dot) product of two tensor-typed arrays.
///
/// Registered under the id `vortex.tensor.inner_product`; both children must share the same
/// `AnyTensor` extension dtype (ignoring nullability) with a float element type.
#[derive(Clone)]
pub struct InnerProduct;

impl InnerProduct {
    /// Returns a [`ScalarFn`] handle for [`InnerProduct`] with its (empty) options.
    pub fn new() -> ScalarFn<InnerProduct> {
        ScalarFn::new(Self, EmptyOptions)
    }

    /// Builds a lazy [`ScalarFnArray`] evaluating `inner_product(lhs, rhs)` over `len` rows.
    ///
    /// # Errors
    /// Returns whatever error `ScalarFnArray::try_new` reports for the given children — the
    /// tests in this file expect non-extension and mismatched operand dtypes to be rejected.
    pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<ScalarFnArray> {
        let children = vec![lhs, rhs];
        let scalar_fn = Self::new().erased();
        ScalarFnArray::try_new(scalar_fn, children, len)
    }
}
impl ScalarFnVTable for InnerProduct {
type Options = EmptyOptions;
fn id(&self) -> ScalarFnId {
ScalarFnId::from("vortex.tensor.inner_product")
}
fn arity(&self, _options: &Self::Options) -> Arity {
Arity::Exact(2)
}
fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
match child_idx {
0 => ChildName::from("lhs"),
1 => ChildName::from("rhs"),
_ => unreachable!("InnerProduct must have exactly two children"),
}
}
fn fmt_sql(
&self,
_options: &Self::Options,
expr: &Expression,
f: &mut Formatter<'_>,
) -> std::fmt::Result {
write!(f, "inner_product(")?;
expr.child(0).fmt_sql(f)?;
write!(f, ", ")?;
expr.child(1).fmt_sql(f)?;
write!(f, ")")
}
fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
let lhs = &arg_dtypes[0];
let rhs = &arg_dtypes[1];
vortex_ensure!(
lhs.eq_ignore_nullability(rhs),
"InnerProduct requires both inputs to have the same dtype, got {lhs} and {rhs}"
);
let lhs_ext = lhs
.as_extension_opt()
.ok_or_else(|| vortex_err!("InnerProduct lhs must be an extension type, got {lhs}"))?;
vortex_ensure!(
lhs_ext.is::<AnyTensor>(),
"InnerProduct inputs must be an `AnyTensor`, got {lhs}"
);
let tensor_match = lhs_ext
.metadata_opt::<AnyTensor>()
.ok_or_else(|| vortex_err!("InnerProduct inputs must be an `AnyTensor`, got {lhs}"))?;
let ptype = tensor_match.element_ptype();
vortex_ensure!(
ptype.is_float(),
"InnerProduct element dtype must be a float primitive, got {ptype}"
);
let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable());
Ok(DType::Primitive(ptype, nullability))
}
fn execute(
&self,
_options: &Self::Options,
args: &dyn ExecutionArgs,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let mut lhs_ref = args.get(0)?;
let mut rhs_ref = args.get(1)?;
let len = args.row_count();
{
let lhs_is_denorm = lhs_ref.is::<ExactScalarFn<L2Denorm>>();
let rhs_is_denorm = rhs_ref.is::<ExactScalarFn<L2Denorm>>();
if lhs_is_denorm && rhs_is_denorm {
return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx);
} else if lhs_is_denorm || rhs_is_denorm {
if rhs_is_denorm {
(lhs_ref, rhs_ref) = (rhs_ref, lhs_ref);
}
return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx);
}
}
if let Some(rewritten) = self.try_execute_sorf_constant(&lhs_ref, &rhs_ref, len, ctx)? {
return Ok(rewritten);
}
if let Some(result) = self.try_execute_dict_constant(&lhs_ref, &rhs_ref, len, ctx)? {
return Ok(result);
}
let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
let lhs: ExtensionArray = lhs_ref.execute(ctx)?;
let rhs: ExtensionArray = rhs_ref.execute(ctx)?;
let ext = lhs.dtype().as_extension();
let tensor_match = ext
.metadata_opt::<AnyTensor>()
.vortex_expect("we already validated this in `return_dtype`");
let dimensions = tensor_match.list_size();
let lhs_storage = lhs.storage_array();
let rhs_storage = rhs.storage_array();
let lhs_flat = extract_flat_elements(lhs_storage, dimensions, ctx)?;
let rhs_flat = extract_flat_elements(rhs_storage, dimensions, ctx)?;
match_each_float_ptype!(lhs_flat.ptype(), |T| {
let buffer: Buffer<T> = (0..len)
.map(|i| inner_product_row(lhs_flat.row::<T>(i), rhs_flat.row::<T>(i)))
.collect();
Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
})
}
fn validity(
&self,
_options: &Self::Options,
expression: &Expression,
) -> VortexResult<Option<Expression>> {
let lhs_validity = expression.child(0).validity()?;
let rhs_validity = expression.child(1).validity()?;
Ok(Some(and(lhs_validity, rhs_validity)))
}
fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
false
}
fn is_fallible(&self, _options: &Self::Options) -> bool {
false
}
}
/// Protobuf metadata serialized alongside a binary tensor scalar-function array.
///
/// Stores the dtypes of both children so they can be rebuilt with the correct dtype when the
/// array is deserialized (see `decode_children`). The scalar function name is supplied by the
/// caller, so this type can back any two-operand tensor function.
#[derive(Clone, prost::Message)]
pub(crate) struct BinaryTensorOpMetadata {
    /// Protobuf-encoded dtype of child 0 (`lhs`); required when decoding.
    #[prost(message, optional, tag = "1")]
    pub(crate) lhs_dtype: Option<pb::DType>,
    /// Protobuf-encoded dtype of child 1 (`rhs`); required when decoding.
    #[prost(message, optional, tag = "2")]
    pub(crate) rhs_dtype: Option<pb::DType>,
}
impl BinaryTensorOpMetadata {
    /// Records the dtypes of children 0 and 1 of `view` and returns the encoded protobuf bytes.
    ///
    /// # Errors
    /// Propagates any failure converting a child dtype into its protobuf representation.
    pub(crate) fn encode_from_view<V: ScalarFnVTable>(
        view: &ScalarFnArrayView<V>,
    ) -> VortexResult<Vec<u8>> {
        let array = view.as_::<ScalarFnArrayEncoding>();
        let encoded = Self {
            lhs_dtype: Some(array.child_at(0).dtype().try_into()?),
            rhs_dtype: Some(array.child_at(1).dtype().try_into()?),
        };
        Ok(encoded.encode_to_vec())
    }

    /// Decodes `metadata` and loads the two children with their recorded dtypes.
    ///
    /// `scalar_fn_name` is only used to prefix error messages.
    ///
    /// # Errors
    /// Fails if the bytes are not valid metadata, if either dtype is missing or cannot be
    /// converted from protobuf, if the two dtypes differ (ignoring nullability), or if a child
    /// cannot be loaded from `children`.
    pub(crate) fn decode_children(
        metadata: &[u8],
        len: usize,
        children: &dyn ArrayChildren,
        session: &VortexSession,
        scalar_fn_name: &str,
    ) -> VortexResult<Vec<ArrayRef>> {
        let decoded = Self::decode(metadata)
            .map_err(|e| vortex_err!("Failed to decode BinaryTensorOpMetadata: {e}"))?;

        // Check presence of both fields before converting either one.
        let Some(lhs_pb) = decoded.lhs_dtype.as_ref() else {
            return Err(vortex_err!("{scalar_fn_name} metadata missing lhs_dtype"));
        };
        let Some(rhs_pb) = decoded.rhs_dtype.as_ref() else {
            return Err(vortex_err!("{scalar_fn_name} metadata missing rhs_dtype"));
        };

        let lhs_dtype = DType::from_proto(lhs_pb, session)?;
        let rhs_dtype = DType::from_proto(rhs_pb, session)?;
        vortex_ensure!(
            lhs_dtype.eq_ignore_nullability(&rhs_dtype),
            "{scalar_fn_name} operand dtype mismatch: {lhs_dtype} vs {rhs_dtype}"
        );

        Ok(vec![
            children.get(0, &lhs_dtype, len)?,
            children.get(1, &rhs_dtype, len)?,
        ])
    }
}
impl ScalarFnArrayVTable for InnerProduct {
    /// Serializes the operand dtypes via [`BinaryTensorOpMetadata`]; metadata is always emitted.
    fn serialize(
        &self,
        view: &ScalarFnArrayView<Self>,
        _session: &VortexSession,
    ) -> VortexResult<Option<Vec<u8>>> {
        let bytes = BinaryTensorOpMetadata::encode_from_view(view)?;
        Ok(Some(bytes))
    }

    /// Rebuilds the two children from serialized metadata; `InnerProduct` carries no options.
    fn deserialize(
        &self,
        _dtype: &DType,
        len: usize,
        metadata: &[u8],
        children: &dyn ArrayChildren,
        session: &VortexSession,
    ) -> VortexResult<ScalarFnArrayParts<Self>> {
        let children = BinaryTensorOpMetadata::decode_children(
            metadata,
            len,
            children,
            session,
            "InnerProduct",
        )?;
        Ok(ScalarFnArrayParts {
            options: EmptyOptions,
            children,
        })
    }
}
impl InnerProduct {
    /// Inner product when *both* operands are `L2Denorm` arrays.
    ///
    /// Each side is split into its `(normalized, norms)` children. Row `i` of the result is
    /// `norms_l[i] * norms_r[i] * inner_product(normalized_l, normalized_r)[i]`, where the inner
    /// product of the normalized children is executed recursively. Validity is the conjunction
    /// of the two operands' validities.
    fn execute_both_denorm(
        &self,
        lhs_ref: &ArrayRef,
        rhs_ref: &ArrayRef,
        len: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
        let (normalized_l, norms_l) = extract_l2_denorm_children(lhs_ref);
        let (normalized_r, norms_r) = extract_l2_denorm_children(rhs_ref);
        let norms_l: PrimitiveArray = norms_l.execute(ctx)?;
        let norms_r: PrimitiveArray = norms_r.execute(ctx)?;
        let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)?
            .into_array()
            .execute(ctx)?;
        // NOTE(review): assumes the norm columns share the dot product's float ptype (the
        // `as_slice::<T>` calls rely on it) — presumably enforced by `L2Denorm`; confirm.
        match_each_float_ptype!(dot.ptype(), |T| {
            let dots = dot.as_slice::<T>();
            let nl = norms_l.as_slice::<T>();
            let nr = norms_r.as_slice::<T>();
            let buffer: Buffer<T> = (0..len).map(|i| nl[i] * nr[i] * dots[i]).collect();
            // SAFETY: `buffer` has `len` entries and `validity` comes from the `len`-row inputs.
            Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
        })
    }

    /// Inner product when exactly one operand (`denorm_ref`) is an `L2Denorm` array.
    ///
    /// Row `i` of the result is `norms[i] * inner_product(normalized, plain)[i]`, with the inner
    /// product of `normalized` and `plain_ref` executed recursively. Validity is the
    /// conjunction of the two operands' validities.
    fn execute_one_denorm(
        &self,
        denorm_ref: &ArrayRef,
        plain_ref: &ArrayRef,
        len: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;
        let (normalized, norms) = extract_l2_denorm_children(denorm_ref);
        let denorm_norms: PrimitiveArray = norms.execute(ctx)?;
        let dot: PrimitiveArray = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?
            .into_array()
            .execute(ctx)?;
        match_each_float_ptype!(dot.ptype(), |T| {
            let dots = dot.as_slice::<T>();
            let ns = denorm_norms.as_slice::<T>();
            let buffer: Buffer<T> = (0..len).map(|i| ns[i] * dots[i]).collect();
            // SAFETY: `buffer` has `len` entries and `validity` comes from the `len`-row inputs.
            Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
        })
    }

    /// Rewrite for an `f32` `SorfTransform` operand against a non-null constant vector.
    ///
    /// Instead of executing the transform for every row, the constant query is zero-padded to
    /// the next power of two, rotated once with `SorfMatrix::rotate`, wrapped as a new constant
    /// `Vector` of the padded width, and the inner product is taken directly against the
    /// transform's child. The tests in this file check the result numerically against decoding
    /// with `inverse_rotate` followed by a plain dot product.
    ///
    /// Returns `Ok(None)` when the operands do not match this shape.
    fn try_execute_sorf_constant(
        &self,
        lhs_ref: &ArrayRef,
        rhs_ref: &ArrayRef,
        len: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let (sorf_view, const_ref) =
            if let Some(view) = lhs_ref.as_opt::<ExactScalarFn<SorfTransform>>() {
                (view, rhs_ref)
            } else if let Some(view) = rhs_ref.as_opt::<ExactScalarFn<SorfTransform>>() {
                (view, lhs_ref)
            } else {
                return Ok(None);
            };
        if sorf_view.options.element_ptype != PType::F32 {
            return Ok(None);
        }
        let Some(const_ext) = const_ref.as_opt::<Extension>() else {
            return Ok(None);
        };
        let const_storage = const_ext.storage_array();
        let Some(const_backing) = const_storage.as_opt::<Constant>() else {
            return Ok(None);
        };
        if const_backing.scalar().is_null() {
            return Ok(None);
        }
        let dim = sorf_view.options.dimension as usize;
        let num_rounds = sorf_view.options.num_rounds as usize;
        let seed = sorf_view.options.seed;
        let padded_dim = dim.next_power_of_two();
        let flat = extract_flat_elements(const_storage, dim, ctx)?;
        if flat.ptype() != PType::F32 {
            return Ok(None);
        }
        // Zero-pad the query to the transform's power-of-two width, then rotate it once.
        let mut padded_query = vec![0.0f32; padded_dim];
        padded_query[..dim].copy_from_slice(flat.row::<f32>(0));
        let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?;
        let mut rotated_query = vec![0.0f32; padded_dim];
        rotation.rotate(&padded_query, &mut rotated_query);
        // Re-wrap the rotated query as a constant `Vector<FSL<f32, padded_dim>>`, keeping the
        // original storage nullability.
        let storage_fsl_nullability = const_storage.dtype().nullability();
        let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let children: Vec<Scalar> = rotated_query
            .into_iter()
            .map(|v| Scalar::primitive(v, Nullability::NonNullable))
            .collect();
        let fsl_scalar =
            Scalar::fixed_size_list(element_dtype.clone(), children, storage_fsl_nullability);
        let new_storage = ConstantArray::new(fsl_scalar, len).into_array();
        let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim fits u32");
        let new_fsl_dtype = DType::FixedSizeList(
            Arc::new(element_dtype),
            padded_dim_u32,
            storage_fsl_nullability,
        );
        let new_ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, new_fsl_dtype)?.erased();
        let new_constant = ExtensionArray::new(new_ext_dtype, new_storage).into_array();
        let sorf_child = sorf_view
            .nth_child(0)
            .vortex_expect("SorfTransform must have exactly one child");
        let rewritten = InnerProduct::try_new_array(sorf_child, new_constant, len)?
            .into_array()
            .execute(ctx)?;
        Ok(Some(rewritten))
    }

    /// Tries the dict-vs-constant kernel with the operands in either order.
    fn try_execute_dict_constant(
        &self,
        lhs_ref: &ArrayRef,
        rhs_ref: &ArrayRef,
        len: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        if let Some(result) = self.try_execute_dict_constant_oriented(lhs_ref, rhs_ref, len, ctx)? {
            return Ok(Some(result));
        }
        self.try_execute_dict_constant_oriented(rhs_ref, lhs_ref, len, ctx)
    }

    /// Fast path for `Extension<FSL<Dict<u8 codes, f32 values>>>` against a non-null constant
    /// `f32` query vector.
    ///
    /// Each row's result is `sum_j q[j] * values[codes[row * list_size + j]]`, computed without
    /// materialising the dictionary-decoded elements. Returns `Ok(None)` (so the caller falls
    /// back to a slower path) whenever the operands do not have exactly this shape.
    fn try_execute_dict_constant_oriented(
        &self,
        dict_candidate: &ArrayRef,
        const_candidate: &ArrayRef,
        len: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let Some(dict_ext) = dict_candidate.as_opt::<Extension>() else {
            return Ok(None);
        };
        let Some(fsl) = dict_ext.storage_array().as_opt::<FixedSizeList>() else {
            return Ok(None);
        };
        let Some(dict) = fsl.elements().as_opt::<Dict>() else {
            return Ok(None);
        };
        let Some(const_ext) = const_candidate.as_opt::<Extension>() else {
            return Ok(None);
        };
        let const_storage = const_ext.storage_array();
        let Some(const_backing) = const_storage.as_opt::<Constant>() else {
            return Ok(None);
        };
        if const_backing.scalar().is_null() {
            return Ok(None);
        }

        // Inspect dtypes *before* executing the children, so unsupported layouts (e.g. `u16`
        // codes) fall through without decoding arrays that would just be thrown away.
        // Codes must be non-nullable: the physical value in a null code slot is not guaranteed
        // to be a valid index into `values`, and the kernel indexes `values` with every code.
        if !matches!(
            dict.codes().dtype(),
            DType::Primitive(PType::U8, Nullability::NonNullable)
        ) {
            return Ok(None);
        }
        if !matches!(dict.values().dtype(), DType::Primitive(PType::F32, _)) {
            return Ok(None);
        }

        let codes_prim: PrimitiveArray = dict.codes().clone().execute(ctx)?;
        let values_prim: PrimitiveArray = dict.values().clone().execute(ctx)?;

        let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize");
        let flat = extract_flat_elements(const_storage, padded_dim, ctx)?;
        if flat.ptype() != PType::F32 {
            return Ok(None);
        }

        let validity = dict_candidate
            .validity()?
            .and(const_candidate.validity()?)?;
        if len == 0 {
            let empty = PrimitiveArray::empty::<f32>(validity.nullability());
            return Ok(Some(empty.into_array()));
        }

        let q: &[f32] = flat.row::<f32>(0);
        let codes: &[u8] = codes_prim.as_slice::<u8>();
        let values: &[f32] = values_prim.as_slice::<f32>();
        // The kernel walks `codes` in rows of `padded_dim` and zips each row with `q`. A size
        // mismatch would silently give wrong values or a wrong-length result, so defer to the
        // generic path instead of relying on debug-only assertions.
        if q.len() != padded_dim || len.checked_mul(padded_dim) != Some(codes.len()) {
            return Ok(None);
        }

        let out = execute_dict_constant_inner_product(q, values, codes, len, padded_dim);
        // SAFETY: `out` holds one value per row (`codes.len() == len * padded_dim` was checked
        // above) and `validity` comes from the two `len`-row operands.
        let result = unsafe { PrimitiveArray::new_unchecked(out.freeze(), validity) }.into_array();
        Ok(Some(result))
    }
}
/// Dot product of two equally sized rows, summed left to right starting from zero.
///
/// If the slices differ in length, the extra trailing elements of the longer one are ignored.
fn inner_product_row<T: Float + NativePType>(a: &[T], b: &[T]) -> T {
    let mut total = T::zero();
    for (&x, &y) in a.iter().zip(b) {
        total = total + x * y;
    }
    total
}
/// Row-wise inner product of a dictionary-encoded `f32` matrix with a single `f32` query.
///
/// `codes` is a row-major `num_rows x dim` matrix of indices into `values`; the result for row
/// `r` is `sum_j q[j] * values[codes[r * dim + j]]`. Each row keeps `PARTIAL_SUMS` independent
/// fused-multiply-add accumulators (presumably to shorten the floating-point dependency chain)
/// and folds the `dim % PARTIAL_SUMS` tail columns into the first accumulator.
///
/// # Panics
/// Panics if `codes.len() != num_rows * dim`, or if any code is out of bounds for `values`.
fn execute_dict_constant_inner_product(
    q: &[f32],
    values: &[f32],
    codes: &[u8],
    num_rows: usize,
    dim: usize,
) -> BufferMut<f32> {
    const PARTIAL_SUMS: usize = 8;

    // A hard check rather than a debug assertion: the `push_unchecked` calls below are only
    // sound if `codes` contains no more than `num_rows` rows.
    assert_eq!(
        num_rows.checked_mul(dim),
        Some(codes.len()),
        "dictionary codes must hold exactly num_rows * dim entries"
    );

    let mut out = BufferMut::<f32>::with_capacity(num_rows);

    if dim == 0 {
        // `chunks_exact(0)` panics; an empty dot product is zero for every row.
        for _ in 0..num_rows {
            // SAFETY: exactly `num_rows` pushes into a buffer with capacity `num_rows`.
            unsafe { out.push_unchecked(0.0) };
        }
        return out;
    }

    for row_codes in codes.chunks_exact(dim) {
        let mut acc = [0.0f32; PARTIAL_SUMS];
        let code_chunks = row_codes.chunks_exact(PARTIAL_SUMS);
        let q_chunks = q.chunks_exact(PARTIAL_SUMS);
        let code_rem = code_chunks.remainder();
        let q_rem = q_chunks.remainder();
        for (cc, qd) in code_chunks.zip(q_chunks) {
            for i in 0..PARTIAL_SUMS {
                acc[i] = qd[i].mul_add(values[cc[i] as usize], acc[i]);
            }
        }
        for (&code, &q_val) in code_rem.iter().zip(q_rem.iter()) {
            acc[0] = q_val.mul_add(values[code as usize], acc[0]);
        }
        // SAFETY: the assertion above gives `codes.len() / dim == num_rows`, so this loop
        // pushes exactly `num_rows` values into a buffer with capacity `num_rows`.
        unsafe { out.push_unchecked(acc.iter().sum::<f32>()) };
    }
    out
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_array::ArrayPlugin;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::MaskedArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::tests::SESSION;
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::tensor_array;
use crate::utils::test_helpers::vector_array;
fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
let scalar_fn = InnerProduct::new().erased();
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?;
let mut ctx = SESSION.create_execution_ctx();
let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
Ok(prim.as_slice::<f64>().to_vec())
}
#[rstest]
#[case::orthogonal(&[2], &[1.0, 0.0], &[0.0, 1.0], &[0.0])]
#[case::parallel(&[2], &[3.0, 4.0], &[3.0, 4.0], &[25.0])]
#[case::antiparallel(&[2], &[1.0, 2.0], &[-1.0, -2.0], &[-5.0])]
#[case::scaled(&[2], &[2.0, 0.0], &[3.0, 0.0], &[6.0])]
fn single_row(
#[case] shape: &[usize],
#[case] lhs_elems: &[f64],
#[case] rhs_elems: &[f64],
#[case] expected: &[f64],
) -> VortexResult<()> {
let lhs = tensor_array(shape, lhs_elems)?;
let rhs = tensor_array(shape, rhs_elems)?;
assert_close(&eval_inner_product(lhs, rhs, 1)?, expected);
Ok(())
}
#[test]
fn multiple_rows() -> VortexResult<()> {
let lhs = tensor_array(
&[3],
&[
1.0, 0.0, 0.0, 3.0, 4.0, 0.0, 1.0, 1.0, 1.0, ],
)?;
let rhs = tensor_array(
&[3],
&[
0.0, 1.0, 0.0, 3.0, 4.0, 0.0, 2.0, 2.0, 2.0, ],
)?;
assert_close(&eval_inner_product(lhs, rhs, 3)?, &[0.0, 25.0, 6.0]);
Ok(())
}
#[test]
fn vector_inner_product() -> VortexResult<()> {
let lhs = vector_array(
2,
&[
3.0, 4.0, 1.0, 0.0, ],
)?;
let rhs = vector_array(
2,
&[
3.0, 4.0, 0.0, 1.0, ],
)?;
assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]);
Ok(())
}
#[test]
fn null_input_row() -> VortexResult<()> {
let lhs = tensor_array(&[2], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])?;
let rhs = tensor_array(&[2], &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0])?;
let lhs = MaskedArray::try_new(lhs, Validity::from_iter([true, false, true]))?.into_array();
let scalar_fn = InnerProduct::new().erased();
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;
let mut ctx = SESSION.create_execution_ctx();
let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
assert!(prim.is_valid(0, &mut ctx)?);
assert!(!prim.is_valid(1, &mut ctx)?);
assert!(prim.is_valid(2, &mut ctx)?);
assert_close(&[prim.as_slice::<f64>()[0]], &[23.0]);
assert_close(&[prim.as_slice::<f64>()[2]], &[127.0]);
Ok(())
}
#[test]
fn rejects_non_extension_dtype() {
let lhs = PrimitiveArray::from_iter([1.0_f64, 2.0]).into_array();
let rhs = PrimitiveArray::from_iter([3.0_f64, 4.0]).into_array();
let result = InnerProduct::try_new_array(lhs, rhs, 2);
assert!(result.is_err());
}
#[test]
fn rejects_mismatched_dtypes() -> VortexResult<()> {
let lhs = tensor_array(&[2], &[1.0_f64, 2.0])?;
let rhs = vector_array(2, &[3.0_f64, 4.0])?;
let result = InnerProduct::try_new_array(lhs, rhs, 1);
assert!(result.is_err());
Ok(())
}
fn l2_denorm_array(
shape: &[usize],
normalized_elements: &[f64],
norms: &[f64],
) -> VortexResult<ArrayRef> {
use vortex_array::IntoArray;
let len = norms.len();
let normalized = tensor_array(shape, normalized_elements)?;
let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
let mut ctx = SESSION.create_execution_ctx();
Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array())
}
#[test]
fn both_denorm() -> VortexResult<()> {
let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0])?;
assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]);
Ok(())
}
#[test]
fn both_denorm_multiple_rows() -> VortexResult<()> {
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0])?;
assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]);
Ok(())
}
#[test]
fn one_side_denorm_lhs() -> VortexResult<()> {
let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
let rhs = tensor_array(&[2], &[1.0, 2.0])?;
assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]);
Ok(())
}
#[test]
fn one_side_denorm_rhs() -> VortexResult<()> {
let lhs = tensor_array(&[2], &[1.0, 2.0])?;
let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]);
Ok(())
}
#[test]
fn both_denorm_null_norms() -> VortexResult<()> {
let normalized_l = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
let norms_l = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
let mut ctx = SESSION.create_execution_ctx();
let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array();
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
let scalar_fn = InnerProduct::new().erased();
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
assert!(prim.is_valid(0, &mut ctx)?);
assert!(!prim.is_valid(1, &mut ctx)?);
assert_close(&[prim.as_slice::<f64>()[0]], &[25.0]);
Ok(())
}
#[rstest]
#[case::vector(
vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
vector_array(3, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap(),
2,
)]
#[case::fixed_shape_tensor(
tensor_array(&[2], &[1.0, 2.0, 3.0, 4.0]).unwrap(),
tensor_array(&[2], &[5.0, 6.0, 7.0, 8.0]).unwrap(),
2,
)]
fn serde_round_trip(
#[case] lhs: ArrayRef,
#[case] rhs: ArrayRef,
#[case] len: usize,
) -> VortexResult<()> {
let original = InnerProduct::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array();
let plugin = ScalarFnArrayPlugin::new(InnerProduct);
let metadata = plugin
.serialize(&original, &SESSION)?
.expect("InnerProduct serialize must produce metadata");
let children = vec![lhs, rhs];
let recovered = plugin.deserialize(
original.dtype(),
original.len(),
&metadata,
&[],
&children,
&SESSION,
)?;
assert_eq!(recovered.dtype(), original.dtype());
assert_eq!(recovered.len(), original.len());
assert_eq!(recovered.encoding_id(), original.encoding_id());
Ok(())
}
#[allow(
clippy::cast_possible_truncation,
reason = "tests build small fixtures with deterministic in-range indices"
)]
mod constant_query_optimizations {
use std::sync::LazyLock;
use rstest::rstest;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::dict::DictArray;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::PType;
use vortex_array::dtype::extension::ExtDType;
use vortex_array::extension::EmptyMetadata;
use vortex_array::scalar::Scalar;
use vortex_array::session::ArraySession;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_error::VortexResult;
use vortex_session::VortexSession;
use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::sorf_transform::SorfMatrix;
use crate::scalar_fns::sorf_transform::SorfOptions;
use crate::scalar_fns::sorf_transform::SorfTransform;
use crate::vector::Vector;
static SESSION: LazyLock<VortexSession> =
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
fn vector_f32(dim: u32, elements: &[f32]) -> VortexResult<ArrayRef> {
let row_count = elements.len() / dim as usize;
let elems: ArrayRef = Buffer::copy_from(elements).into_array();
let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count);
let ext_dtype =
ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
}
fn constant_vector_f32(elements: &[f32], len: usize) -> VortexResult<ArrayRef> {
let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
let children: Vec<Scalar> = elements
.iter()
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
.collect();
let storage_scalar =
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
let storage = ConstantArray::new(storage_scalar, len).into_array();
let ext_dtype =
ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
}
fn dict_vector_f32(list_size: u32, codes: &[u8], values: &[f32]) -> VortexResult<ArrayRef> {
let num_rows = codes.len() / list_size as usize;
let codes_arr =
PrimitiveArray::new::<u8>(Buffer::copy_from(codes), Validity::NonNullable)
.into_array();
let values_arr =
PrimitiveArray::new::<f32>(Buffer::copy_from(values), Validity::NonNullable)
.into_array();
let dict = DictArray::try_new(codes_arr, values_arr)?;
let fsl = FixedSizeListArray::try_new(
dict.into_array(),
list_size,
Validity::NonNullable,
num_rows,
)?;
let ext_dtype =
ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
}
fn eval_ip_f32(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f32>> {
let scalar_fn = InnerProduct::new().erased();
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?;
let mut ctx = SESSION.create_execution_ctx();
let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
Ok(prim.as_slice::<f32>().to_vec())
}
fn assert_close_f32(actual: &[f32], expected: &[f32], tol: f32) {
assert_eq!(actual.len(), expected.len(), "length mismatch");
for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
assert!(
(a - e).abs() < tol,
"row {i}: got {a}, expected {e} (diff = {})",
(a - e).abs()
);
}
}
fn build_sorf_with_dict_child(
dim: u32,
num_rows: usize,
seed: u64,
num_rounds: u8,
) -> VortexResult<(ArrayRef, Vec<u8>, Vec<f32>, usize)> {
let padded_dim = (dim as usize).next_power_of_two();
let values: Vec<f32> = vec![-1.5, -1.0, -0.5, -0.1, 0.1, 0.5, 1.0, 1.5];
let codes: Vec<u8> = (0..num_rows * padded_dim)
.map(|i| (i as u8) % (values.len() as u8))
.collect();
let padded_vector = dict_vector_f32(padded_dim as u32, &codes, &values)?;
let sorf_options = SorfOptions {
seed,
num_rounds,
dimension: dim,
element_ptype: PType::F32,
};
let sorf =
SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array();
Ok((sorf, codes, values, padded_dim))
}
fn decode_sorf_dict(
codes: &[u8],
values: &[f32],
padded_dim: usize,
dim: usize,
num_rows: usize,
seed: u64,
num_rounds: u8,
) -> VortexResult<Vec<f32>> {
let rotation = SorfMatrix::try_new(seed, dim, num_rounds as usize)?;
let mut padded = vec![0.0f32; padded_dim];
let mut rotated = vec![0.0f32; padded_dim];
let mut out = Vec::with_capacity(num_rows * dim);
for row in 0..num_rows {
for j in 0..padded_dim {
padded[j] = values[codes[row * padded_dim + j] as usize];
}
rotation.inverse_rotate(&padded, &mut rotated);
out.extend_from_slice(&rotated[..dim]);
}
Ok(out)
}
fn naive_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
#[test]
fn case1_sorf_lhs_constant_rhs_padded_gt_dim() -> VortexResult<()> {
let dim: u32 = 100;
let num_rows = 7usize;
let seed = 42u64;
let num_rounds = 3u8;
let padded_dim = (dim as usize).next_power_of_two();
assert!(padded_dim > dim as usize, "test must exercise padding");
let (sorf_lhs, codes, values, padded_dim_computed) =
build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
assert_eq!(padded_dim_computed, padded_dim);
let query_elems: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
let const_rhs = constant_vector_f32(&query_elems, num_rows)?;
let decoded = decode_sorf_dict(
&codes,
&values,
padded_dim,
dim as usize,
num_rows,
seed,
num_rounds,
)?;
let expected: Vec<f32> = (0..num_rows)
.map(|i| {
naive_dot(
&decoded[i * dim as usize..(i + 1) * dim as usize],
&query_elems,
)
})
.collect();
let actual = eval_ip_f32(sorf_lhs, const_rhs, num_rows)?;
assert_close_f32(&actual, &expected, 1e-3);
Ok(())
}
#[test]
fn case1_constant_lhs_sorf_rhs_mirrored() -> VortexResult<()> {
let dim: u32 = 100;
let num_rows = 5usize;
let seed = 7u64;
let num_rounds = 3u8;
let (sorf, codes, values, padded_dim) =
build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
let query_elems: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.2).cos()).collect();
let const_lhs = constant_vector_f32(&query_elems, num_rows)?;
let decoded = decode_sorf_dict(
&codes,
&values,
padded_dim,
dim as usize,
num_rows,
seed,
num_rounds,
)?;
let expected: Vec<f32> = (0..num_rows)
.map(|i| {
naive_dot(
&decoded[i * dim as usize..(i + 1) * dim as usize],
&query_elems,
)
})
.collect();
let actual = eval_ip_f32(const_lhs, sorf, num_rows)?;
assert_close_f32(&actual, &expected, 1e-3);
Ok(())
}
#[test]
fn case1_padded_equals_dim() -> VortexResult<()> {
let dim: u32 = 128;
let num_rows = 4usize;
let seed = 11u64;
let num_rounds = 3u8;
let (sorf, codes, values, padded_dim) =
build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
assert_eq!(padded_dim, dim as usize);
let query_elems: Vec<f32> = (0..dim).map(|i| i as f32 * 0.01 - 0.5).collect();
let const_rhs = constant_vector_f32(&query_elems, num_rows)?;
let decoded = decode_sorf_dict(
&codes,
&values,
padded_dim,
dim as usize,
num_rows,
seed,
num_rounds,
)?;
let expected: Vec<f32> = (0..num_rows)
.map(|i| {
naive_dot(
&decoded[i * dim as usize..(i + 1) * dim as usize],
&query_elems,
)
})
.collect();
let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
assert_close_f32(&actual, &expected, 1e-3);
Ok(())
}
#[test]
fn case1_empty_len_zero() -> VortexResult<()> {
let dim: u32 = 100;
let num_rows = 0usize;
let seed = 42u64;
let num_rounds = 3u8;
let (sorf, _codes, _values, _padded_dim) =
build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
let query_elems: Vec<f32> = vec![0.0; dim as usize];
let const_rhs = constant_vector_f32(&query_elems, num_rows)?;
let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
assert_eq!(actual.len(), 0);
Ok(())
}
#[test]
fn case2_dict_lhs_constant_rhs_matches_naive() -> VortexResult<()> {
let list_size: u32 = 8;
let num_rows = 10usize;
let values: Vec<f32> = vec![-1.0, -0.5, -0.25, -0.1, 0.1, 0.25, 0.5, 1.0];
let codes: Vec<u8> = (0..num_rows * list_size as usize)
.map(|i| (i as u8) % (values.len() as u8))
.collect();
let dict_lhs = dict_vector_f32(list_size, &codes, &values)?;
let query: Vec<f32> = (0..list_size).map(|i| (i as f32 + 1.0) * 0.3).collect();
let const_rhs = constant_vector_f32(&query, num_rows)?;
let expected: Vec<f32> = (0..num_rows)
.map(|row| {
let mut acc = 0.0f32;
for j in 0..list_size as usize {
let k = codes[row * list_size as usize + j] as usize;
acc += query[j] * values[k];
}
acc
})
.collect();
let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
assert_close_f32(&actual, &expected, 1e-5);
Ok(())
}
#[test]
fn case2_constant_lhs_dict_rhs_mirrored() -> VortexResult<()> {
let list_size: u32 = 4;
let num_rows = 6usize;
let values: Vec<f32> = vec![0.1, 0.4, 0.7, 1.0];
let codes: Vec<u8> = (0..num_rows * list_size as usize)
.map(|i| ((i * 3) as u8) % (values.len() as u8))
.collect();
let dict_rhs = dict_vector_f32(list_size, &codes, &values)?;
let query: Vec<f32> = vec![0.5, -1.0, 2.5, -0.25];
let const_lhs = constant_vector_f32(&query, num_rows)?;
let expected: Vec<f32> = (0..num_rows)
.map(|row| {
let mut acc = 0.0f32;
for j in 0..list_size as usize {
let k = codes[row * list_size as usize + j] as usize;
acc += query[j] * values[k];
}
acc
})
.collect();
let actual = eval_ip_f32(const_lhs, dict_rhs, num_rows)?;
assert_close_f32(&actual, &expected, 1e-5);
Ok(())
}
#[test]
fn case2_u16_codes_falls_through() -> VortexResult<()> {
    // A dictionary whose codes are u16 rather than u8; the inner product must still
    // be correct on whatever path handles non-u8 codes.
    let list_size: u32 = 4;
    let stride = list_size as usize;
    let num_rows = 3usize;
    // More entries than a u8 code can address.
    let num_values = 300usize;
    let values: Vec<f32> = (0..num_values).map(|i| i as f32 * 0.01).collect();
    // Stride through the dictionary (37 is coprime with 300) so that some codes land
    // above u8::MAX; sequential codes 0..12 would all fit in a u8 and never touch
    // the upper part of the dictionary.
    let codes_u16: Vec<u16> = (0..num_rows * stride)
        .map(|i| ((i * 37) % num_values) as u16)
        .collect();
    assert!(
        codes_u16.iter().any(|&c| c > u16::from(u8::MAX)),
        "test should exercise codes outside the u8 range"
    );

    let codes_arr =
        PrimitiveArray::new::<u16>(Buffer::copy_from(&codes_u16), Validity::NonNullable)
            .into_array();
    let values_arr =
        PrimitiveArray::new::<f32>(Buffer::copy_from(&values), Validity::NonNullable)
            .into_array();
    let dict = DictArray::try_new(codes_arr, values_arr)?;
    let fsl = FixedSizeListArray::try_new(
        dict.into_array(),
        list_size,
        Validity::NonNullable,
        num_rows,
    )?;
    let ext_dtype =
        ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
    let dict_lhs = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array();

    // Query [1.0, 2.0, ..., list_size].
    let query: Vec<f32> = (1..=list_size).map(|i| i as f32).collect();
    let const_rhs = constant_vector_f32(&query, num_rows)?;

    // Reference computed from the very codes fed to the array, so the two cannot
    // drift apart if the code formula changes.
    let expected: Vec<f32> = codes_u16
        .chunks_exact(stride)
        .map(|row_codes| {
            row_codes
                .iter()
                .zip(&query)
                .fold(0.0f32, |acc, (&code, &q)| acc + q * values[usize::from(code)])
        })
        .collect();

    let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
    // Dot products here reach ~20 in magnitude, so allow for a few f32 ULPs of
    // summation-order difference.
    assert_close_f32(&actual, &expected, 1e-4);
    Ok(())
}
#[test]
fn case2_plain_fsl_falls_through() -> VortexResult<()> {
    // Plain (non-dictionary) vectors dotted with a constant query.
    let dim: u32 = 4;
    let width = dim as usize;
    let num_rows = 3usize;
    let lhs_elems: Vec<f32> = (0..num_rows * width).map(|i| i as f32 * 0.25).collect();
    let plain_lhs = vector_f32(dim, &lhs_elems)?;

    let query: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let const_rhs = constant_vector_f32(&query, num_rows)?;

    // lhs_elems holds exactly num_rows rows of `width` elements each.
    let expected: Vec<f32> = lhs_elems
        .chunks_exact(width)
        .map(|row| naive_dot(row, &query))
        .collect();

    let actual = eval_ip_f32(plain_lhs, const_rhs, num_rows)?;
    assert_close_f32(&actual, &expected, 1e-5);
    Ok(())
}
#[test]
fn case2_empty_len_zero() -> VortexResult<()> {
    // Edge case: zero rows of dictionary-encoded vectors should yield an empty result.
    let list_size: u32 = 4;
    let num_rows = 0usize;
    let codes: Vec<u8> = Vec::new();
    let values: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0];
    let dict_lhs = dict_vector_f32(list_size, &codes, &values)?;

    let query: Vec<f32> = vec![0.0; list_size as usize];
    let const_rhs = constant_vector_f32(&query, num_rows)?;

    let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
    assert_eq!(actual.len(), 0);
    Ok(())
}
#[test]
fn end_to_end_sorf_plus_dict_cosine_path() -> VortexResult<()> {
    // SORF-transformed vectors with a dictionary-encoded child, dotted with a
    // constant query and compared against a plain-Rust decode of the same data.
    let dim: u32 = 100;
    let width = dim as usize;
    let num_rows = 9usize;
    let seed = 99u64;
    let num_rounds = 3u8;
    let (sorf, codes, values, padded_dim) =
        build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;

    // Smooth, bounded query: 0.4 * sin(0.15 * i).
    let query_elems: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.15).sin() * 0.4).collect();
    let const_rhs = constant_vector_f32(&query_elems, num_rows)?;

    let decoded = decode_sorf_dict(
        &codes,
        &values,
        padded_dim,
        width,
        num_rows,
        seed,
        num_rounds,
    )?;
    let expected: Vec<f32> = (0..num_rows)
        .map(|row| {
            let start = row * width;
            naive_dot(&decoded[start..start + width], &query_elems)
        })
        .collect();

    let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
    assert_close_f32(&actual, &expected, 1e-3);
    Ok(())
}
/// Minimal deterministic xorshift64 PRNG so the randomized tests are reproducible
/// without an external RNG dependency.
struct XorShift64(u64);

impl XorShift64 {
    /// Offset added to every seed, so that small seeds such as 0 still start
    /// from a non-zero state.
    const SEED_OFFSET: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator from `seed`.
    ///
    /// Zero is a fixed point of the xorshift step: a zero state would emit 0
    /// forever. The single seed that maps to a zero state is therefore remapped to
    /// `SEED_OFFSET`. Every other seed produces exactly the same stream as before.
    fn new(seed: u64) -> Self {
        let state = seed.wrapping_add(Self::SEED_OFFSET);
        Self(if state == 0 { Self::SEED_OFFSET } else { state })
    }

    /// Advances the state with the classic 13/7/17 xorshift64 shifts and returns
    /// the new state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Returns an `f32` in `[-1.0, 1.0)`.
    ///
    /// It takes the top 24 bits of the next `u64` (exactly representable in an
    /// `f32` mantissa), scales them into `[0, 1)`, then maps the result to `[-1, 1)`.
    fn next_f32(&mut self) -> f32 {
        let bits = (self.next_u64() >> 40) as u32;
        let unit = bits as f32 / (1u32 << 24) as f32;
        unit * 2.0 - 1.0
    }
}
#[test]
fn case2_large_u8_codebook_direct_lookup() -> VortexResult<()> {
    // A codebook with more than 8 centroids but still addressable by u8 codes.
    // NOTE(review): the >8 bound presumably steps past a small-codebook special
    // case; confirm against the kernel.
    let list_size: u32 = 16;
    let stride = list_size as usize;
    let num_rows = 20usize;
    let num_centroids = 200usize;
    assert!(num_centroids > 8 && num_centroids <= 256);

    // Draw order (values, codes, query) is part of the test's fixed data.
    let mut rng = XorShift64::new(0xDEAD_BEEF);
    let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
    let codes: Vec<u8> = (0..num_rows * stride)
        .map(|_| (rng.next_u64() % num_centroids as u64) as u8)
        .collect();
    let dict_lhs = dict_vector_f32(list_size, &codes, &values)?;
    let query: Vec<f32> = (0..list_size).map(|_| rng.next_f32()).collect();
    let const_rhs = constant_vector_f32(&query, num_rows)?;

    let expected: Vec<f32> = codes
        .chunks_exact(stride)
        .map(|row_codes| {
            row_codes
                .iter()
                .zip(&query)
                .fold(0.0f32, |acc, (&code, &q)| acc + q * values[code as usize])
        })
        .collect();

    let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
    assert_close_f32(&actual, &expected, 1e-4);
    Ok(())
}
#[rstest]
#[case::small_no_pad(128, 11, 1, 1)]
#[case::small_no_pad_rounds3(128, 23, 1_234, 3)]
#[case::small_padded(100, 17, 42, 3)]
#[case::mid_padded(200, 13, 2024, 3)]
#[case::mid_power_of_two(256, 31, 7, 3)]
#[case::larger_padded(300, 9, 99, 3)]
#[case::max_rounds(128, 5, 31_415, 5)]
fn case1_sorf_random_sweep(
    #[case] dim: u32,
    #[case] num_rows: usize,
    #[case] seed: u64,
    #[case] num_rounds: u8,
) -> VortexResult<()> {
    let width = dim as usize;
    let (sorf, codes, values, padded_dim) =
        build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;

    // The query comes from a perturbed seed, so it differs from the data stream.
    let query: Vec<f32> = {
        let mut rng = XorShift64::new(seed ^ 0xABCD_1234);
        (0..dim).map(|_| rng.next_f32()).collect()
    };
    let const_rhs = constant_vector_f32(&query, num_rows)?;

    let decoded = decode_sorf_dict(
        &codes,
        &values,
        padded_dim,
        width,
        num_rows,
        seed,
        num_rounds,
    )?;
    let expected: Vec<f32> = (0..num_rows)
        .map(|row| {
            let start = row * width;
            naive_dot(&decoded[start..start + width], &query)
        })
        .collect();

    let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
    assert_close_f32(&actual, &expected, 1e-2);
    Ok(())
}
#[rstest]
#[case::small(4, 7, 8)]
#[case::medium(16, 50, 64)]
#[case::larger(32, 100, 150)]
#[case::very_large_codebook(8, 25, 250)]
fn case2_random_sweep(
    #[case] list_size: u32,
    #[case] num_rows: usize,
    #[case] num_centroids: usize,
) -> VortexResult<()> {
    assert!(num_centroids <= 256, "u8 codes cap at 256 centroids");
    let stride = list_size as usize;

    // Seed derived from the case shape, so each parameterization gets distinct data.
    let mut rng = XorShift64::new(u64::from(list_size) * 31 + num_rows as u64);
    let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
    let codes: Vec<u8> = (0..num_rows * stride)
        .map(|_| (rng.next_u64() % num_centroids as u64) as u8)
        .collect();
    let dict_lhs = dict_vector_f32(list_size, &codes, &values)?;
    let query: Vec<f32> = (0..list_size).map(|_| rng.next_f32()).collect();
    let const_rhs = constant_vector_f32(&query, num_rows)?;

    let expected: Vec<f32> = codes
        .chunks_exact(stride)
        .map(|row_codes| {
            row_codes
                .iter()
                .zip(&query)
                .fold(0.0f32, |acc, (&code, &q)| acc + q * values[code as usize])
        })
        .collect();

    let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
    assert_close_f32(&actual, &expected, 1e-4);
    Ok(())
}
#[test]
fn end_to_end_dim128_rows64_bit6_regression() -> VortexResult<()> {
    // Regression scenario: 128-dim vectors, 64 rows, and a 64-entry (6-bit) codebook
    // under a SORF transform. It is checked with an absolute bound and with a
    // magnitude-scaled relative bound.
    let dim: u32 = 128;
    let width = dim as usize;
    let num_rows = 64usize;
    let seed = 0xFACE_F00D;
    let num_rounds = 3u8;
    let num_centroids = 64usize;
    let padded_dim = width.next_power_of_two();

    // Draw order (values, codes, query) is part of the test's fixed data.
    let mut rng = XorShift64::new(seed);
    let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
    let codes: Vec<u8> = (0..num_rows * padded_dim)
        .map(|_| (rng.next_u64() % num_centroids as u64) as u8)
        .collect();
    let padded_vector = dict_vector_f32(padded_dim as u32, &codes, &values)?;

    let sorf = SorfTransform::try_new_array(
        &SorfOptions {
            seed,
            num_rounds,
            dimension: dim,
            element_ptype: PType::F32,
        },
        padded_vector,
        num_rows,
    )?
    .into_array();

    let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
    let const_rhs = constant_vector_f32(&query, num_rows)?;

    let decoded = decode_sorf_dict(
        &codes,
        &values,
        padded_dim,
        width,
        num_rows,
        seed,
        num_rounds,
    )?;
    let expected: Vec<f32> = (0..num_rows)
        .map(|row| {
            let start = row * width;
            naive_dot(&decoded[start..start + width], &query)
        })
        .collect();

    let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
    assert_close_f32(&actual, &expected, 1e-2);

    // The error is divided by |expected|, floored at 1.0. Small dot products are
    // therefore held to an absolute 1e-3 bound, large ones to a relative 1e-3 bound.
    for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
        let rel = (a - e).abs() / e.abs().max(1.0);
        assert!(
            rel < 1e-3,
            "row {i}: rel err {rel} too large (a={a}, e={e})"
        );
    }
    Ok(())
}
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
fn case1_various_num_rounds(#[case] num_rounds: u8) -> VortexResult<()> {
    // Fixed shape and seed; only the number of SORF rounds varies per case.
    let dim: u32 = 128;
    let width = dim as usize;
    let num_rows = 8usize;
    let seed = 0x1234_5678;
    let (sorf, codes, values, padded_dim) =
        build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;

    let query: Vec<f32> = {
        let mut rng = XorShift64::new(seed ^ (num_rounds as u64));
        (0..dim).map(|_| rng.next_f32()).collect()
    };
    let const_rhs = constant_vector_f32(&query, num_rows)?;

    let decoded = decode_sorf_dict(
        &codes,
        &values,
        padded_dim,
        width,
        num_rows,
        seed,
        num_rounds,
    )?;
    let expected: Vec<f32> = (0..num_rows)
        .map(|row| {
            let start = row * width;
            naive_dot(&decoded[start..start + width], &query)
        })
        .collect();

    let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
    assert_close_f32(&actual, &expected, 1e-2);
    Ok(())
}
#[test]
fn end_to_end_constant_lhs_sorf_rhs_mirrored() -> VortexResult<()> {
    // Mirror image of the SORF end-to-end case: the constant query is on the left
    // and the SORF/dict vectors are on the right.
    let dim: u32 = 256;
    let width = dim as usize;
    let num_rows = 12usize;
    let seed = 0xBEEF_CAFE;
    let num_rounds = 3u8;
    let (sorf, codes, values, padded_dim) =
        build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;

    let query: Vec<f32> = {
        let mut rng = XorShift64::new(seed);
        (0..dim).map(|_| rng.next_f32()).collect()
    };
    let const_lhs = constant_vector_f32(&query, num_rows)?;

    let decoded = decode_sorf_dict(
        &codes,
        &values,
        padded_dim,
        width,
        num_rows,
        seed,
        num_rounds,
    )?;
    let expected: Vec<f32> = (0..num_rows)
        .map(|row| {
            let start = row * width;
            naive_dot(&decoded[start..start + width], &query)
        })
        .collect();

    let actual = eval_ip_f32(const_lhs, sorf, num_rows)?;
    assert_close_f32(&actual, &expected, 1e-2);
    Ok(())
}
}
}