use alloc::format;
use burn::tensor::linalg::cosine_similarity;
use burn_core as burn;
use crate::loss::reduction::Reduction;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Ignored, Module};
use burn::tensor::{Int, Tensor, activation::relu, backend::Backend};
#[derive(Config, Debug)]
pub struct CosineEmbeddingLossConfig {
#[config(default = 0.0)]
pub margin: f32,
#[config(default = "Reduction::Mean")]
pub reduction: Reduction,
}
impl CosineEmbeddingLossConfig {
pub fn init(&self) -> CosineEmbeddingLoss {
CosineEmbeddingLoss {
margin: self.margin,
reduction: Ignored(self.reduction.clone()),
}
}
}
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct CosineEmbeddingLoss {
pub margin: f32,
pub reduction: Ignored<Reduction>,
}
impl Default for CosineEmbeddingLoss {
fn default() -> Self {
CosineEmbeddingLossConfig::new().init()
}
}
impl ModuleDisplay for CosineEmbeddingLoss {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("margin", &self.margin)
.add("reduction", format!("{:?}", &self.reduction.0).as_str())
.optional()
}
}
impl CosineEmbeddingLoss {
pub fn new() -> Self {
CosineEmbeddingLossConfig::new().init()
}
pub fn forward<B: Backend>(
&self,
input1: Tensor<B, 2>,
input2: Tensor<B, 2>,
target: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
let tensor = self.forward_no_reduction(input1, input2, target);
match &self.reduction.0 {
Reduction::Mean => tensor.mean(),
Reduction::Sum => tensor.sum(),
other => panic!("{other:?} reduction is not supported"),
}
}
pub fn forward_no_reduction<B: Backend>(
&self,
input1: Tensor<B, 2>,
input2: Tensor<B, 2>,
target: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
self.assertions(&input1, &input2, &target);
let cos_sim = cosine_similarity(input1, input2, 1, None);
let cos_sim: Tensor<B, 1> = cos_sim.squeeze_dim(1);
let mut loss = cos_sim.zeros_like();
let similar_mask = target.clone().equal_elem(1);
let similar_loss = cos_sim.clone().neg().add_scalar(1);
loss = loss.mask_where(similar_mask, similar_loss);
let dissimilar_mask = target.equal_elem(-1);
let dissimilar_loss = relu(cos_sim.clone().sub_scalar(self.margin));
loss = loss.mask_where(dissimilar_mask, dissimilar_loss);
loss
}
fn assertions<B: Backend>(
&self,
input1: &Tensor<B, 2>,
input2: &Tensor<B, 2>,
target: &Tensor<B, 1, Int>,
) {
let [batch_size1, dim1] = input1.dims();
let [batch_size2, dim2] = input2.dims();
let [batch_size_target] = target.dims();
assert_eq!(
batch_size1, batch_size2,
"Batch size of input1 ({batch_size1}) must match batch size of input2 ({batch_size2})"
);
assert_eq!(
dim1, dim2,
"Embedding dimension of input1 ({dim1}) must match embedding dimension of input2 ({dim2})"
);
assert_eq!(
batch_size1, batch_size_target,
"Batch size of inputs ({batch_size1}) must match batch size of target ({batch_size_target})"
);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
#[test]
fn cosine_embedding_loss_positive_target() {
let device = Default::default();
let input1 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let input2 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 1]), &device);
let loss = CosineEmbeddingLossConfig::new().init();
let loss_no_reduction =
loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());
let loss_sum = loss.forward(input1, input2, target);
let expected_no_reduction = TensorData::from([0.0, 0.0]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());
let expected_mean = TensorData::from([0.0]);
loss_mean
.into_data()
.assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
let expected_sum = TensorData::from([0.0]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected_sum, Tolerance::default());
}
#[test]
fn cosine_embedding_loss_negative_target() {
let device = Default::default();
let input1 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let input2 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([-1, -1]), &device);
let loss = CosineEmbeddingLossConfig::new().init();
let loss_no_reduction =
loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());
let loss_sum_config = CosineEmbeddingLossConfig::new().with_reduction(Reduction::Sum);
let loss_sum =
loss_sum_config
.init()
.forward(input1.clone(), input2.clone(), target.clone());
let expected_no_reduction = TensorData::from([1.0, 1.0]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());
let expected_mean = TensorData::from([1.0]);
loss_mean
.into_data()
.assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
let expected_sum = TensorData::from([2.0]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected_sum, Tolerance::default());
let loss_with_margin = CosineEmbeddingLossConfig::new().with_margin(0.5).init();
let loss_with_margin = loss_with_margin.forward(input1, input2, target);
let expected = TensorData::from([0.5]);
loss_with_margin
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
#[test]
fn cosine_embedding_loss_mixed_targets() {
let device = Default::default();
let input1 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let input2 = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
&device,
);
let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, -1]), &device);
let loss = CosineEmbeddingLossConfig::new().init();
let loss_no_reduction =
loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
let loss_mean = loss.forward(input1, input2, target);
let expected_no_reduction = TensorData::from([0.0, 1.0]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());
let expected_mean = TensorData::from([0.5]);
loss_mean
.into_data()
.assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
}
#[test]
fn display() {
let config = CosineEmbeddingLossConfig::new().with_margin(0.5);
let loss = config.init();
assert_eq!(
alloc::format!("{loss}"),
"CosineEmbeddingLoss {margin: 0.5, reduction: Mean}"
);
}
}