use crate::ElementConversion;
use crate::backend::Backend;
use crate::tensor::Tensor;
use super::l2_norm;
/// Fallback epsilon used by [`cosine_similarity`] when the caller passes `None` for `eps`.
///
/// Each norm in the denominator is clamped to at least this value before the division.
pub const DEFAULT_EPSILON: f64 = 1e-8;
/// Computes the cosine similarity between two tensors along one dimension.
///
/// The result is the element-wise product of `x1` and `x2` summed along `dim`, divided by the
/// product of the two norms returned by `l2_norm` along the same dimension. Each norm is
/// clamped to a minimum of `eps` before the two are multiplied.
///
/// # Arguments
///
/// * `x1` - First input tensor.
/// * `x2` - Second input tensor.
/// * `dim` - Dimension to reduce over. Negative values are offset by the rank `D`
///   (so `-1` maps to `D - 1`).
/// * `eps` - Lower bound applied to each norm; `None` selects [`DEFAULT_EPSILON`].
///
/// # Returns
///
/// A tensor of rank `D` holding the similarity values.
pub fn cosine_similarity<B: Backend, const D: usize>(
    x1: Tensor<B, D>,
    x2: Tensor<B, D>,
    dim: i32,
    eps: Option<B::FloatElem>,
) -> Tensor<B, D> {
    let min_norm = match eps {
        Some(value) => value,
        None => B::FloatElem::from_elem(DEFAULT_EPSILON),
    };

    // Map a negative dimension onto its positive counterpart.
    // NOTE(review): a `dim` below `-D` stays negative here and wraps in the cast below;
    // presumably the downstream tensor ops reject the resulting index - confirm.
    let signed_axis = if dim < 0 { D as i32 + dim } else { dim };
    let axis = signed_axis as usize;

    // The inputs are cloned for the product because `l2_norm` takes them by value afterwards.
    let numerator = (x1.clone() * x2.clone()).sum_dim(axis);

    let x1_norm = l2_norm(x1, axis);
    let x2_norm = l2_norm(x2, axis);

    // Clamp each norm separately so neither factor of the denominator falls below `min_norm`.
    let x1_bounded = x1_norm.clamp_min(min_norm);
    let x2_bounded = x2_norm.clamp_min(min_norm);

    numerator / (x1_bounded * x2_bounded)
}