burn_nn/loss/cosine_embedding.rs

use alloc::format;

use burn::tensor::linalg::cosine_similarity;

use burn_core as burn;

use crate::loss::reduction::Reduction;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Ignored, Module};
use burn::tensor::{Int, Tensor, activation::relu, backend::Backend};

/// Configuration for CosineEmbeddingLoss.
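///
/// # Example
///
/// A minimal construction sketch (the margin and reduction values here are
/// illustrative, not the defaults):
///
/// ```ignore
/// let loss = CosineEmbeddingLossConfig::new()
///     .with_margin(0.5)
///     .with_reduction(Reduction::Sum)
///     .init();
/// ```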
#[derive(Config, Debug)]
pub struct CosineEmbeddingLossConfig {
    /// Margin for negative samples.
    #[config(default = 0.0)]
    pub margin: f32,

    /// Specifies the reduction to apply to the output.
    #[config(default = "Reduction::Mean")]
    pub reduction: Reduction,
}

impl CosineEmbeddingLossConfig {
    /// Initialize CosineEmbeddingLoss.
    pub fn init(&self) -> CosineEmbeddingLoss {
        CosineEmbeddingLoss {
            margin: self.margin,
            reduction: Ignored(self.reduction.clone()),
        }
    }
}

/// Cosine embedding loss between two tensors.
///
/// Measures how similar two batches of embeddings are using cosine
/// similarity, given a target of `1` (the pair should be similar) or `-1`
/// (the pair should be dissimilar):
///
/// - `target == 1`: `loss = 1 - cos_sim(input1, input2)`
/// - `target == -1`: `loss = max(0, cos_sim(input1, input2) - margin)`
///
/// Commonly used for learning embeddings or measuring similarity.
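///
/// # Example
///
/// A minimal usage sketch; `B` stands for any `Backend` implementation, and
/// the input bindings are assumed to exist with the shapes shown:
///
/// ```ignore
/// // input1, input2: [batch_size, embedding_dim]; target: [batch_size] of 1 / -1.
/// let loss_fn = CosineEmbeddingLoss::new();
/// let loss: Tensor<B, 1> = loss_fn.forward(input1, input2, target);
/// ```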
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct CosineEmbeddingLoss {
    /// Margin value. Default: 0.0
    pub margin: f32,

    /// Reduction method
    pub reduction: Ignored<Reduction>,
}

impl Default for CosineEmbeddingLoss {
    fn default() -> Self {
        CosineEmbeddingLossConfig::new().init()
    }
}

impl ModuleDisplay for CosineEmbeddingLoss {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("margin", &self.margin)
            .add("reduction", format!("{:?}", &self.reduction.0).as_str())
            .optional()
    }
}

impl CosineEmbeddingLoss {
    /// Creates a new instance
    pub fn new() -> Self {
        CosineEmbeddingLossConfig::new().init()
    }

    /// Compute loss with reduction.
    ///
    /// # Shapes
    ///
    /// - input1: ``[batch_size, embedding_dim]``
    /// - input2: ``[batch_size, embedding_dim]``
    /// - target: ``[batch_size]`` with values 1 or -1
    ///
    /// # Returns
    ///
    /// Loss tensor of shape ``[1]``
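    ///
    /// # Panics
    ///
    /// Panics if the configured reduction is neither `Reduction::Mean` nor
    /// `Reduction::Sum`.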
    pub fn forward<B: Backend>(
        &self,
        input1: Tensor<B, 2>,
        input2: Tensor<B, 2>,
        target: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let tensor = self.forward_no_reduction(input1, input2, target);
        match &self.reduction.0 {
            Reduction::Mean => tensor.mean(),
            Reduction::Sum => tensor.sum(),
            other => panic!("{other:?} reduction is not supported"),
        }
    }

    /// Compute loss without applying reduction.
    ///
    /// # Arguments
    ///
    /// * `input1` - First input tensor of shape ``[batch_size, embedding_dim]``
    /// * `input2` - Second input tensor of shape ``[batch_size, embedding_dim]``
    /// * `target` - Target tensor of shape ``[batch_size]`` with values 1 or -1
    ///
    /// # Returns
    ///
    /// Tensor of per-element losses with shape ``[batch_size]``
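    ///
    /// For example, an identical pair has cosine similarity `1`, so its
    /// per-element loss is `1 - 1 = 0` when `target == 1` and
    /// `max(0, 1 - margin)` when `target == -1`.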
    pub fn forward_no_reduction<B: Backend>(
        &self,
        input1: Tensor<B, 2>,
        input2: Tensor<B, 2>,
        target: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        self.assertions(&input1, &input2, &target);

        // cos_sim shape: [batch_size, 1]
        let cos_sim = cosine_similarity(input1, input2, 1, None);
        // cos_sim shape: [batch_size]
        let cos_sim: Tensor<B, 1> = cos_sim.squeeze_dim(1);

        let mut loss = cos_sim.zeros_like();

        // Similar pairs (target == 1) - Formula: L = 1 - cos_sim
        let similar_mask = target.clone().equal_elem(1);
        let similar_loss = cos_sim.clone().neg().add_scalar(1);
        loss = loss.mask_where(similar_mask, similar_loss);

        // Dissimilar pairs (target == -1) - Formula: L = max(0, cos_sim - margin)
        let dissimilar_mask = target.equal_elem(-1);
        let dissimilar_loss = relu(cos_sim.clone().sub_scalar(self.margin));
        loss = loss.mask_where(dissimilar_mask, dissimilar_loss);

        // return loss shape: [batch_size]
        loss
    }

    fn assertions<B: Backend>(
        &self,
        input1: &Tensor<B, 2>,
        input2: &Tensor<B, 2>,
        target: &Tensor<B, 1, Int>,
    ) {
        let [batch_size1, dim1] = input1.dims();
        let [batch_size2, dim2] = input2.dims();
        let [batch_size_target] = target.dims();

        assert_eq!(
            batch_size1, batch_size2,
            "Batch size of input1 ({batch_size1}) must match batch size of input2 ({batch_size2})"
        );

        assert_eq!(
            dim1, dim2,
            "Embedding dimension of input1 ({dim1}) must match embedding dimension of input2 ({dim2})"
        );

        assert_eq!(
            batch_size1, batch_size_target,
            "Batch size of inputs ({batch_size1}) must match batch size of target ({batch_size_target})"
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn cosine_embedding_loss_positive_target() {
        let device = Default::default();

        // Two identical vectors should have cosine similarity of 1
        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let input2 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        // Target 1 means that inputs should be similar
        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 1]), &device);

        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());

        // Create a loss with Sum reduction for testing; the default Mean
        // reduction would not exercise Sum at all.
        let loss_sum_config = CosineEmbeddingLossConfig::new().with_reduction(Reduction::Sum);
        let loss_sum = loss_sum_config.init().forward(input1, input2, target);

        // For identical vectors, 1 - cos_sim = 1 - 1 = 0
        let expected_no_reduction = TensorData::from([0.0, 0.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([0.0]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());

        let expected_sum = TensorData::from([0.0]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());
    }

    #[test]
    fn cosine_embedding_loss_negative_target() {
        let device = Default::default();

        // Two identical vectors should have cosine similarity of 1
        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let input2 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        // Target -1 means that inputs should be dissimilar
        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([-1, -1]), &device);

        // With margin 0.0, max(0, cos_sim - margin) = max(0, 1 - 0) = 1
        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());

        // Create a loss with Sum reduction for testing
        let loss_sum_config = CosineEmbeddingLossConfig::new().with_reduction(Reduction::Sum);
        let loss_sum =
            loss_sum_config
                .init()
                .forward(input1.clone(), input2.clone(), target.clone());

        let expected_no_reduction = TensorData::from([1.0, 1.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([1.0]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());

        let expected_sum = TensorData::from([2.0]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());

        // With margin 0.5, max(0, cos_sim - margin) = max(0, 1 - 0.5) = 0.5
        let loss_with_margin = CosineEmbeddingLossConfig::new().with_margin(0.5).init();
        let loss_with_margin = loss_with_margin.forward(input1, input2, target);

        let expected = TensorData::from([0.5]);
        loss_with_margin
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn cosine_embedding_loss_mixed_targets() {
        let device = Default::default();

        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let input2 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        // Mixed targets
        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, -1]), &device);

        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1, input2, target);

        let expected_no_reduction = TensorData::from([0.0, 1.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([0.5]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = CosineEmbeddingLossConfig::new().with_margin(0.5);
        let loss = config.init();

        assert_eq!(
            alloc::format!("{loss}"),
            "CosineEmbeddingLoss {margin: 0.5, reduction: Mean}"
        );
    }
}
317}