Skip to main content

burn_tensor/tensor/linalg/
cosine_similarity.rs

1use crate::ElementConversion;
2use crate::FloatDType;
3use crate::backend::Backend;
4use crate::tensor::Tensor;
5
6use super::l2_norm;
7
8/// Computes the cosine similarity between two tensors along a specified dimension.
9///
10/// Calculates the cosine of the angle between inputs as their dot product divided
11/// by the product of their L2 norms.
12///
13/// # Arguments
14///
15/// * `x1` - First input tensor
16/// * `x2` - Second input tensor
17/// * `dim` - Dimension along which to compute the similarity
18///   (negative indices allowed: -1 for last dimension)
19/// * `eps` - Small value to avoid division by zero (default: dtype's smallest positive normal)
20///
21/// # Returns
22///
23/// Tensor containing the cosine similarity between x1 and x2
24pub fn cosine_similarity<B: Backend, const D: usize>(
25    x1: Tensor<B, D>,
26    x2: Tensor<B, D>,
27    dim: i32,
28    eps: Option<B::FloatElem>,
29) -> Tensor<B, D> {
30    let eps = eps.unwrap_or_else(|| {
31        let min_positive = x1
32            .dtype()
33            .finfo()
34            .unwrap_or(FloatDType::F32.finfo())
35            .min_positive;
36        B::FloatElem::from_elem(min_positive)
37    });
38
39    // Convert negative dimension to positive
40    let dim_idx = if dim < 0 { D as i32 + dim } else { dim } as usize;
41
42    // Compute dot product: sum(x1 * x2) along the specified dimension
43    let dot_product = (x1.clone() * x2.clone()).sum_dim(dim_idx);
44
45    // Compute L2 norms: ||x1|| and ||x2||
46    let norm_x1 = l2_norm(x1, dim_idx);
47    let norm_x2 = l2_norm(x2, dim_idx);
48
49    // Calculate the denominator (product of the norms) with epsilon to avoid division by zero
50    let denominator = norm_x1.clamp_min(eps) * norm_x2.clamp_min(eps);
51
52    // Return the cosine similarity (dot product divided by the product of norms)
53    dot_product / denominator
54}