use crate::ElementConversion;
use crate::FloatDType;
use crate::backend::Backend;
use crate::tensor::Tensor;
use super::l2_norm;
/// Computes the cosine similarity between `x1` and `x2` along dimension `dim`.
///
/// The result is `sum(x1 * x2, dim) / (max(||x1||, eps) * max(||x2||, eps))`. The L2 norms
/// are computed with [`l2_norm`] along the same dimension. Each norm is clamped from below
/// by `eps` separately, so a zero vector does not cause a division by zero.
///
/// # Arguments
///
/// * `x1`, `x2` - Input tensors. They are multiplied element-wise, so their shapes must be
///   compatible for `*`.
/// * `dim` - Dimension to reduce over. Negative values count from the end (`-1` is the last
///   dimension), so the valid range is `-D..D`.
/// * `eps` - Lower bound applied to each norm. When `None`, it defaults to the
///   `min_positive` value of `x1`'s float dtype info. If that dtype provides no float info,
///   the `F32` value is used instead.
///
/// # Returns
///
/// A tensor of the same rank `D`. The reduced dimension is kept, presumably with size 1,
/// because `sum_dim` keeps the rank.
///
/// # Panics
///
/// Panics if `dim` is outside `-D..D`.
pub fn cosine_similarity<B: Backend, const D: usize>(
    x1: Tensor<B, D>,
    x2: Tensor<B, D>,
    dim: i32,
    eps: Option<B::FloatElem>,
) -> Tensor<B, D> {
    // Validate before normalizing. Without this check, a too-negative `dim` wraps to a huge
    // `usize` and a too-large one passes straight through. Both would then fail deep inside
    // the reduction with an unrelated message.
    let rank = D as i32;
    assert!(
        (-rank..rank).contains(&dim),
        "cosine_similarity: dim {dim} is out of range for a tensor of rank {D}"
    );
    // For `dim` in `-rank..rank`, `rem_euclid` maps negative indices onto `0..rank`.
    let dim_idx = dim.rem_euclid(rank) as usize;

    let eps = eps.unwrap_or_else(|| {
        // Fall back to F32's float info when x1's dtype does not provide any.
        let min_positive = x1
            .dtype()
            .finfo()
            .unwrap_or(FloatDType::F32.finfo())
            .min_positive;
        B::FloatElem::from_elem(min_positive)
    });

    let dot_product = (x1.clone() * x2.clone()).sum_dim(dim_idx);
    let norm_x1 = l2_norm(x1, dim_idx);
    let norm_x2 = l2_norm(x2, dim_idx);
    // Clamp each norm separately so a zero-norm input cannot make the denominator zero.
    let denominator = norm_x1.clamp_min(eps) * norm_x2.clamp_min(eps);
    dot_product / denominator
}