burn_tensor/tensor/linalg/cosine_similarity.rs
1use crate::ElementConversion;
2use crate::FloatDType;
3use crate::backend::Backend;
4use crate::tensor::Tensor;
5
6use super::l2_norm;
7
8/// Computes the cosine similarity between two tensors along a specified dimension.
9///
10/// Calculates the cosine of the angle between inputs as their dot product divided
11/// by the product of their L2 norms.
12///
13/// # Arguments
14///
15/// * `x1` - First input tensor
16/// * `x2` - Second input tensor
17/// * `dim` - Dimension along which to compute the similarity
18/// (negative indices allowed: -1 for last dimension)
19/// * `eps` - Small value to avoid division by zero (default: dtype's smallest positive normal)
20///
21/// # Returns
22///
23/// Tensor containing the cosine similarity between x1 and x2
24pub fn cosine_similarity<B: Backend, const D: usize>(
25 x1: Tensor<B, D>,
26 x2: Tensor<B, D>,
27 dim: i32,
28 eps: Option<B::FloatElem>,
29) -> Tensor<B, D> {
30 let eps = eps.unwrap_or_else(|| {
31 let min_positive = x1
32 .dtype()
33 .finfo()
34 .unwrap_or(FloatDType::F32.finfo())
35 .min_positive;
36 B::FloatElem::from_elem(min_positive)
37 });
38
39 // Convert negative dimension to positive
40 let dim_idx = if dim < 0 { D as i32 + dim } else { dim } as usize;
41
42 // Compute dot product: sum(x1 * x2) along the specified dimension
43 let dot_product = (x1.clone() * x2.clone()).sum_dim(dim_idx);
44
45 // Compute L2 norms: ||x1|| and ||x2||
46 let norm_x1 = l2_norm(x1, dim_idx);
47 let norm_x2 = l2_norm(x2, dim_idx);
48
49 // Calculate the denominator (product of the norms) with epsilon to avoid division by zero
50 let denominator = norm_x1.clamp_min(eps) * norm_x2.clamp_min(eps);
51
52 // Return the cosine similarity (dot product divided by the product of norms)
53 dot_product / denominator
54}