use std::fmt::Formatter;
use num_traits::Zero;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::scalar_fn::ExactScalarFn;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::expr::Expression;
use vortex_array::expr::and;
use vortex_array::match_each_float_ptype;
use vortex_array::scalar_fn::Arity;
use vortex_array::scalar_fn::ChildName;
use vortex_array::scalar_fn::ExecutionArgs;
use vortex_array::scalar_fn::ScalarFn;
use vortex_array::scalar_fn::ScalarFnId;
use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use crate::scalar_fns::ApproxOptions;
use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_norm::L2Norm;
use crate::utils::extract_l2_denorm_children;
use crate::utils::validate_tensor_float_input;
/// Scalar function that computes the row-wise cosine similarity of two arrays.
///
/// The type itself is a stateless marker; per-call configuration is carried by
/// [`ApproxOptions`], which is the `Options` type of its `ScalarFnVTable`
/// implementation.
#[derive(Clone)]
pub struct CosineSimilarity;
impl CosineSimilarity {
pub fn new(options: &ApproxOptions) -> ScalarFn<CosineSimilarity> {
ScalarFn::new(CosineSimilarity, options.clone())
}
pub fn try_new_array(
options: &ApproxOptions,
lhs: ArrayRef,
rhs: ArrayRef,
len: usize,
) -> VortexResult<ScalarFnArray> {
ScalarFnArray::try_new(CosineSimilarity::new(options).erased(), vec![lhs, rhs], len)
}
}
impl ScalarFnVTable for CosineSimilarity {
    type Options = ApproxOptions;

    /// Identifier under which this scalar function is known.
    fn id(&self) -> ScalarFnId {
        const ID: &str = "vortex.tensor.cosine_similarity";
        ScalarFnId::from(ID)
    }

    /// Cosine similarity is strictly binary.
    fn arity(&self, _options: &Self::Options) -> Arity {
        Arity::Exact(2)
    }

    /// Names the two children `lhs` (index 0) and `rhs` (index 1).
    ///
    /// # Panics
    ///
    /// Panics for any other index.
    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
        let name = match child_idx {
            0 => "lhs",
            1 => "rhs",
            _ => unreachable!("CosineSimilarity must have exactly two children"),
        };
        ChildName::from(name)
    }

    /// Renders the expression as `cosine_similarity(<lhs>, <rhs>)`.
    fn fmt_sql(
        &self,
        _options: &Self::Options,
        expr: &Expression,
        f: &mut Formatter<'_>,
    ) -> std::fmt::Result {
        f.write_str("cosine_similarity(")?;
        for idx in 0..2 {
            if idx > 0 {
                f.write_str(", ")?;
            }
            expr.child(idx).fmt_sql(f)?;
        }
        f.write_str(")")
    }

    /// Derives the output dtype: a primitive of the inputs' element ptype, nullable
    /// when either input is nullable.
    ///
    /// # Errors
    ///
    /// Fails when the two input dtypes differ (ignoring nullability) or when
    /// `validate_tensor_float_input` rejects the left-hand dtype.
    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
        let (lhs, rhs) = (&arg_dtypes[0], &arg_dtypes[1]);
        vortex_ensure!(
            lhs.eq_ignore_nullability(rhs),
            "CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
        );

        // The dtypes match up to nullability, so only one side is validated.
        let element_ptype = validate_tensor_float_input(lhs)?.element_ptype();
        let any_nullable = lhs.is_nullable() || rhs.is_nullable();
        Ok(DType::Primitive(
            element_ptype,
            Nullability::from(any_nullable),
        ))
    }

    /// Evaluates `dot(lhs, rhs) / (‖lhs‖ · ‖rhs‖)` for every row.
    ///
    /// When one or both arguments are `L2Denorm` scalar-function arrays the work is
    /// delegated to `execute_one_denorm` / `execute_both_denorm`. Otherwise the
    /// inner product and both L2 norms are executed as separate arrays and divided
    /// element-wise; a zero denominator yields `0` instead of a division result.
    fn execute(
        &self,
        options: &Self::Options,
        args: &dyn ExecutionArgs,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let lhs_ref = args.get(0)?;
        let rhs_ref = args.get(1)?;
        let len = args.row_count();

        // Dispatch to the specialised paths; the single-denorm helper always
        // receives the `L2Denorm` argument first.
        let lhs_is_denorm = lhs_ref.is::<ExactScalarFn<L2Denorm>>();
        let rhs_is_denorm = rhs_ref.is::<ExactScalarFn<L2Denorm>>();
        match (lhs_is_denorm, rhs_is_denorm) {
            (true, true) => {
                return self.execute_both_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
            }
            (true, false) => {
                return self.execute_one_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
            }
            (false, true) => {
                return self.execute_one_denorm(options, &rhs_ref, &lhs_ref, len, ctx);
            }
            (false, false) => {}
        }

        // A row is valid only when it is valid in both inputs.
        let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;

        // Build all three lazy arrays before executing any of them.
        let lhs_norm_fn = L2Norm::try_new_array(options, lhs_ref.clone(), len)?;
        let rhs_norm_fn = L2Norm::try_new_array(options, rhs_ref.clone(), len)?;
        let dot_fn = InnerProduct::try_new_array(options, lhs_ref, rhs_ref, len)?;

        let dot: PrimitiveArray = dot_fn.into_array().execute(ctx)?;
        let norm_l: PrimitiveArray = lhs_norm_fn.into_array().execute(ctx)?;
        let norm_r: PrimitiveArray = rhs_norm_fn.into_array().execute(ctx)?;

        match_each_float_ptype!(dot.ptype(), |T| {
            let dots = dot.as_slice::<T>();
            let norms_l = norm_l.as_slice::<T>();
            let norms_r = norm_r.as_slice::<T>();

            let quotients: Buffer<T> = (0..len)
                .map(|row| match norms_l[row] * norms_r[row] {
                    // Zero-norm guard: report 0 rather than dividing by zero.
                    denom if denom == T::zero() => T::zero(),
                    denom => dots[row] / denom,
                })
                .collect();

            // SAFETY: `quotients` is collected from `0..len`, so it holds exactly
            // `len` values. `validity` is the conjunction of both inputs'
            // validities, presumably also of length `len` (the row count) —
            // confirm against the contract of `PrimitiveArray::new_unchecked`.
            Ok(unsafe { PrimitiveArray::new_unchecked(quotients, validity) }.into_array())
        })
    }

    /// Validity of the result: the conjunction of both children's validity expressions.
    fn validity(
        &self,
        _options: &Self::Options,
        expression: &Expression,
    ) -> VortexResult<Option<Expression>> {
        let combined = and(
            expression.child(0).validity()?,
            expression.child(1).validity()?,
        );
        Ok(Some(combined))
    }

    /// Always reports `false`.
    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
        false
    }

    /// Always reports `false`.
    fn is_fallible(&self, _options: &Self::Options) -> bool {
        false
    }
}
impl CosineSimilarity {
    /// Path used when both arguments are `L2Denorm` arrays.
    ///
    /// Takes the first child of each `L2Denorm` (presumably the unit-normalized
    /// vectors — verify against `extract_l2_denorm_children`) and returns their
    /// lazy inner product without dividing by any norm. Unless the combined
    /// validity is `Validity::NonNullable`, the result is masked with it.
    fn execute_both_denorm(
        &self,
        options: &ApproxOptions,
        lhs_ref: &ArrayRef,
        rhs_ref: &ArrayRef,
        len: usize,
        _ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let combined = lhs_ref.validity()?.and(rhs_ref.validity()?)?;

        let (unit_lhs, _) = extract_l2_denorm_children(lhs_ref);
        let (unit_rhs, _) = extract_l2_denorm_children(rhs_ref);
        let inner = InnerProduct::try_new_array(options, unit_lhs, unit_rhs, len)?;
        let dot = inner.into_array();

        match combined {
            Validity::NonNullable => Ok(dot),
            nullable => dot.mask(nullable.to_array(len)),
        }
    }

    /// Path used when exactly one argument (`denorm_ref`) is an `L2Denorm` array.
    ///
    /// Computes the inner product of the `L2Denorm`'s first child with `plain_ref`
    /// and divides each row by the L2 norm of `plain_ref` only; rows whose plain
    /// norm is zero produce `0`.
    fn execute_one_denorm(
        &self,
        options: &ApproxOptions,
        denorm_ref: &ArrayRef,
        plain_ref: &ArrayRef,
        len: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;
        let (unit_vectors, _) = extract_l2_denorm_children(denorm_ref);

        // Construct both lazy arrays first, then execute them in the same order.
        let dot_fn = InnerProduct::try_new_array(options, unit_vectors, plain_ref.clone(), len)?;
        let norm_fn = L2Norm::try_new_array(options, plain_ref.clone(), len)?;
        let dot: PrimitiveArray = dot_fn.into_array().execute(ctx)?;
        let plain_norm: PrimitiveArray = norm_fn.into_array().execute(ctx)?;

        match_each_float_ptype!(dot.ptype(), |T| {
            let dots = dot.as_slice::<T>();
            let norms = plain_norm.as_slice::<T>();

            let quotients: Buffer<T> = (0..len)
                .map(|row| {
                    let norm = norms[row];
                    if norm != T::zero() {
                        dots[row] / norm
                    } else {
                        // Zero-norm guard: report 0 rather than dividing by zero.
                        T::zero()
                    }
                })
                .collect();

            // SAFETY: `quotients` is collected from `0..len`, so it holds exactly
            // `len` values. `validity` combines both inputs' validities and is
            // presumably of the same length — confirm against the contract of
            // `PrimitiveArray::new_unchecked`.
            Ok(unsafe { PrimitiveArray::new_unchecked(quotients, validity) }.into_array())
        })
    }
}
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::MaskedArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::ScalarFnArray;
    use vortex_array::scalar_fn::ScalarFn;
    use vortex_array::session::ArraySession;
    use vortex_array::validity::Validity;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::scalar_fns::ApproxOptions;
    use crate::scalar_fns::cosine_similarity::CosineSimilarity;
    use crate::scalar_fns::l2_denorm::L2Denorm;
    use crate::utils::test_helpers::assert_close;
    use crate::utils::test_helpers::constant_tensor_array;
    use crate::utils::test_helpers::constant_vector_array;
    use crate::utils::test_helpers::tensor_array;
    use crate::utils::test_helpers::vector_array;

    /// Session shared by every test; execution contexts are created from it.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Executes exact-mode cosine similarity over `lhs` and `rhs` and returns the
    /// resulting `f64` values (validity is not inspected).
    fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
        let cosine = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
        let lazy = ScalarFnArray::try_new(cosine, vec![lhs, rhs], len)?;
        let mut ctx = SESSION.create_execution_ctx();
        let executed: PrimitiveArray = lazy.into_array().execute(&mut ctx)?;
        Ok(executed.as_slice::<f64>().to_vec())
    }

    /// Identical unit vectors give 1, orthogonal unit vectors give 0.
    #[test]
    fn unit_vectors_1d() -> VortexResult<()> {
        let lhs_elems: &[f64] = &[
            1.0, 0.0, 0.0, // row 0
            0.0, 1.0, 0.0, // row 1
        ];
        let rhs_elems: &[f64] = &[
            1.0, 0.0, 0.0, // row 0
            1.0, 0.0, 0.0, // row 1
        ];
        let lhs = tensor_array(&[3], lhs_elems)?;
        let rhs = tensor_array(&[3], rhs_elems)?;

        let actual = eval_cosine_similarity(lhs, rhs, 2)?;
        assert_close(&actual, &[1.0, 0.0]);
        Ok(())
    }

    /// Single-row inputs covering opposite, non-unit and zero-norm vectors.
    #[rstest]
    #[case::opposite(&[3], &[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0], &[-1.0])]
    #[case::non_unit(&[2], &[3.0, 4.0], &[4.0, 3.0], &[0.96])]
    #[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[0.0])]
    fn single_row(
        #[case] shape: &[usize],
        #[case] lhs_elems: &[f64],
        #[case] rhs_elems: &[f64],
        #[case] expected: &[f64],
    ) -> VortexResult<()> {
        let lhs = tensor_array(shape, lhs_elems)?;
        let rhs = tensor_array(shape, rhs_elems)?;

        let actual = eval_cosine_similarity(lhs, rhs, 1)?;
        assert_close(&actual, expected);
        Ok(())
    }

    /// A higher-rank tensor compared with itself has similarity 1.
    #[rstest]
    #[case::matrix_2d(
        &[2, 3],
        &[
            1.0, 0.0, 0.0, // row 0
            0.0, 0.0, 0.0, // row 1
        ],
    )]
    #[case::tensor_3d(&[2, 2, 2], &[1.0; 8])]
    fn self_similarity(#[case] shape: &[usize], #[case] elements: &[f64]) -> VortexResult<()> {
        let lhs = tensor_array(shape, elements)?;
        let rhs = tensor_array(shape, elements)?;

        let actual = eval_cosine_similarity(lhs, rhs, 1)?;
        assert_close(&actual, &[1.0]);
        Ok(())
    }

    /// Shape `[]` tensors: each row is a single scalar, so only the sign matters.
    #[test]
    fn scalar_0d() -> VortexResult<()> {
        let lhs = tensor_array(&[], &[5.0, 3.0])?;
        let rhs = tensor_array(&[], &[5.0, -3.0])?;

        let actual = eval_cosine_similarity(lhs, rhs, 2)?;
        assert_close(&actual, &[1.0, -1.0]);
        Ok(())
    }

    /// Five rows compared against a clone of themselves all give 1.
    #[test]
    fn many_rows() -> VortexResult<()> {
        let elems: &[f64] = &[
            1.0, 2.0, 3.0, 4.0, // row 0
            0.0, 1.0, 0.0, 0.0, // row 1
            5.0, 0.0, 5.0, 0.0, // row 2
            1.0, 1.0, 1.0, 1.0, // row 3
            0.0, 0.0, 0.0, 7.0, // row 4
        ];
        let lhs = tensor_array(&[4], elems)?;
        let rhs = lhs.clone();

        let actual = eval_cosine_similarity(lhs, rhs, 5)?;
        assert_close(&actual, &[1.0, 1.0, 1.0, 1.0, 1.0]);
        Ok(())
    }

    /// Tensor data compared against a constant query tensor.
    #[test]
    fn constant_query_tensor() -> VortexResult<()> {
        let elems: &[f64] = &[
            1.0, 0.0, 0.0, // row 0
            0.0, 1.0, 0.0, // row 1
            0.0, 0.0, 1.0, // row 2
            1.0, 0.0, 0.0, // row 3
        ];
        let data = tensor_array(&[3], elems)?;
        let query = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 4)?;

        let actual = eval_cosine_similarity(data, query, 4)?;
        assert_close(&actual, &[1.0, 0.0, 0.0, 1.0]);
        Ok(())
    }

    /// Same scenario as `unit_vectors_1d`, built with `vector_array`.
    #[test]
    fn vector_unit_vectors() -> VortexResult<()> {
        let lhs_elems: &[f64] = &[
            1.0, 0.0, 0.0, // row 0
            0.0, 1.0, 0.0, // row 1
        ];
        let rhs_elems: &[f64] = &[
            1.0, 0.0, 0.0, // row 0
            1.0, 0.0, 0.0, // row 1
        ];
        let lhs = vector_array(3, lhs_elems)?;
        let rhs = vector_array(3, rhs_elems)?;

        let actual = eval_cosine_similarity(lhs, rhs, 2)?;
        assert_close(&actual, &[1.0, 0.0]);
        Ok(())
    }

    /// Vector data compared against a constant query vector.
    #[test]
    fn vector_constant_query() -> VortexResult<()> {
        let elems: &[f64] = &[
            1.0, 0.0, 0.0, // row 0
            0.0, 1.0, 0.0, // row 1
            0.0, 0.0, 1.0, // row 2
            1.0, 0.0, 0.0, // row 3
        ];
        let data = vector_array(3, elems)?;
        let query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?;

        let actual = eval_cosine_similarity(data, query, 4)?;
        assert_close(&actual, &[1.0, 0.0, 0.0, 1.0]);
        Ok(())
    }

    /// A null row on one side makes the corresponding output row null.
    #[test]
    fn null_input_row() -> VortexResult<()> {
        let lhs = tensor_array(&[2], &[3.0, 4.0, 1.0, 0.0])?;
        let rhs_values = tensor_array(&[2], &[3.0, 4.0, 0.0, 1.0])?;
        let row_validity = Validity::from_iter([true, false]);
        let rhs = MaskedArray::try_new(rhs_values, row_validity)?.into_array();

        let cosine = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
        let lazy = ScalarFnArray::try_new(cosine, vec![lhs, rhs], 2)?;
        let mut ctx = SESSION.create_execution_ctx();
        let executed: PrimitiveArray = lazy.into_array().execute(&mut ctx)?;

        assert!(executed.is_valid(0)?);
        assert!(!executed.is_valid(1)?);
        let first = executed.as_slice::<f64>()[0];
        assert_close(&[first], &[1.0]);
        Ok(())
    }

    /// Builds an `L2Denorm` array from already-normalized tensor elements and one
    /// norm per row; the row count is `norms.len()`.
    fn l2_denorm_array(
        shape: &[usize],
        normalized_elements: &[f64],
        norms: &[f64],
    ) -> VortexResult<ArrayRef> {
        let row_count = norms.len();
        let unit_vectors = tensor_array(shape, normalized_elements)?;
        let norm_column = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
        let mut ctx = SESSION.create_execution_ctx();

        let denorm = L2Denorm::try_new_array(
            &ApproxOptions::Exact,
            unit_vectors,
            norm_column,
            row_count,
            &mut ctx,
        )?;
        Ok(denorm.into_array())
    }

    /// Two identical `L2Denorm` inputs give similarity 1 on every row.
    #[test]
    fn both_denorm_self_similarity() -> VortexResult<()> {
        let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
        let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;

        let actual = eval_cosine_similarity(lhs, rhs, 2)?;
        assert_close(&actual, &[1.0, 1.0]);
        Ok(())
    }

    /// Orthogonal `L2Denorm` inputs give 0 regardless of their norms.
    #[test]
    fn both_denorm_orthogonal() -> VortexResult<()> {
        let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?;
        let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?;

        let actual = eval_cosine_similarity(lhs, rhs, 1)?;
        assert_close(&actual, &[0.0]);
        Ok(())
    }

    /// A zero-norm row (all-zero normalized vector) yields 0.
    #[test]
    fn both_denorm_zero_norm() -> VortexResult<()> {
        let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?;
        let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;

        let actual = eval_cosine_similarity(lhs, rhs, 2)?;
        assert_close(&actual, &[1.0, 0.0]);
        Ok(())
    }

    /// `L2Denorm` on the left, plain tensor on the right.
    #[test]
    fn one_side_denorm_lhs() -> VortexResult<()> {
        let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
        let rhs = tensor_array(&[2], &[3.0, 4.0])?;

        let actual = eval_cosine_similarity(lhs, rhs, 1)?;
        assert_close(&actual, &[1.0]);
        Ok(())
    }

    /// Plain tensor on the left, `L2Denorm` on the right.
    #[test]
    fn one_side_denorm_rhs() -> VortexResult<()> {
        let lhs = tensor_array(&[2], &[1.0, 0.0])?;
        let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;

        let actual = eval_cosine_similarity(lhs, rhs, 1)?;
        assert_close(&actual, &[0.6]);
        Ok(())
    }

    /// A null norm inside an `L2Denorm` input makes that output row null.
    #[test]
    fn both_denorm_null_norms() -> VortexResult<()> {
        let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;

        let rhs_unit_vectors = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
        let rhs_norms = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
        let mut ctx = SESSION.create_execution_ctx();
        let rhs = L2Denorm::try_new_array(
            &ApproxOptions::Exact,
            rhs_unit_vectors,
            rhs_norms,
            2,
            &mut ctx,
        )?
        .into_array();

        let cosine = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
        let lazy = ScalarFnArray::try_new(cosine, vec![lhs, rhs], 2)?;
        let executed: PrimitiveArray = lazy.into_array().execute(&mut ctx)?;

        assert!(executed.is_valid(0)?);
        assert!(!executed.is_valid(1)?);
        let first = executed.as_slice::<f64>()[0];
        assert_close(&[first], &[1.0]);
        Ok(())
    }
}