use burn_core as burn;

use alloc::format;

use crate::loss::reduction::Reduction;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
use burn::tensor::linalg::cosine_similarity;
use burn::tensor::{Int, Tensor, activation::relu, backend::Backend};

/// Configuration to create a [cosine embedding loss](CosineEmbeddingLoss) instance.
#[derive(Config, Debug)]
pub struct CosineEmbeddingLossConfig {
    /// The margin applied to dissimilar (target == -1) pairs. Default: 0.0
    #[config(default = 0.0)]
    pub margin: f32,

    /// The reduction applied to the loss: `Mean` (default) or `Sum`.
    #[config(default = "Reduction::Mean")]
    pub reduction: Reduction,
}

impl CosineEmbeddingLossConfig {
    /// Initialize a [cosine embedding loss](CosineEmbeddingLoss) from this configuration.
    pub fn init(&self) -> CosineEmbeddingLoss {
        CosineEmbeddingLoss {
            margin: self.margin,
            reduction: Ignored(self.reduction.clone()),
        }
    }
}

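/// Cosine embedding loss.
///
/// Measures how well two batches of embeddings match a target of 1 (similar)
/// or -1 (dissimilar), using cosine similarity along the embedding dimension.
///
/// # Example
///
/// A minimal usage sketch; the input tensors here are assumed, illustrative values:
///
/// ```ignore
/// let loss_fn = CosineEmbeddingLossConfig::new().with_margin(0.2).init();
/// // input1, input2: [batch_size, embedding_dim]; target: [batch_size] of 1 / -1 entries
/// let loss = loss_fn.forward(input1, input2, target);
/// ```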
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct CosineEmbeddingLoss {
    /// The margin applied to dissimilar (target == -1) pairs.
    pub margin: f32,

    /// The reduction applied by [forward](CosineEmbeddingLoss::forward).
    pub reduction: Ignored<Reduction>,
}

impl Default for CosineEmbeddingLoss {
    fn default() -> Self {
        CosineEmbeddingLossConfig::new().init()
    }
}

impl ModuleDisplay for CosineEmbeddingLoss {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("margin", &self.margin)
            .add("reduction", format!("{:?}", &self.reduction.0).as_str())
            .optional()
    }
}

impl CosineEmbeddingLoss {
    /// Create the criterion with default settings.
    pub fn new() -> Self {
        CosineEmbeddingLossConfig::new().init()
    }

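    /// Compute the criterion on the input tensors and apply the configured reduction.
    ///
    /// # Shapes
    ///
    /// - input1: `[batch_size, embedding_dim]`
    /// - input2: `[batch_size, embedding_dim]`
    /// - target: `[batch_size]`, with each entry equal to 1 (similar) or -1 (dissimilar)
    /// - output: `[1]`
    ///
    /// # Panics
    ///
    /// Panics if the configured reduction is neither `Mean` nor `Sum`.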
    pub fn forward<B: Backend>(
        &self,
        input1: Tensor<B, 2>,
        input2: Tensor<B, 2>,
        target: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let tensor = self.forward_no_reduction(input1, input2, target);
        match &self.reduction.0 {
            Reduction::Mean => tensor.mean(),
            Reduction::Sum => tensor.sum(),
            other => panic!("{other:?} reduction is not supported"),
        }
    }

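    /// Compute the criterion on the input tensors without reduction.
    ///
    /// For each pair, the loss is `1 - cos(input1, input2)` when the target is 1,
    /// and `max(0, cos(input1, input2) - margin)` when the target is -1. Entries
    /// that are neither 1 nor -1 contribute a loss of 0.
    ///
    /// # Shapes
    ///
    /// - input1: `[batch_size, embedding_dim]`
    /// - input2: `[batch_size, embedding_dim]`
    /// - target: `[batch_size]`
    /// - output: `[batch_size]`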
    pub fn forward_no_reduction<B: Backend>(
        &self,
        input1: Tensor<B, 2>,
        input2: Tensor<B, 2>,
        target: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        self.assertions(&input1, &input2, &target);

        // Cosine similarity along the embedding dimension: [batch_size, 1]
        let cos_sim = cosine_similarity(input1, input2, 1, None);
        // Squeeze to [batch_size]
        let cos_sim: Tensor<B, 1> = cos_sim.squeeze_dim(1);

        let mut loss = cos_sim.zeros_like();

        // target == 1: loss = 1 - cos_sim
        let similar_mask = target.clone().equal_elem(1);
        let similar_loss = cos_sim.clone().neg().add_scalar(1);
        loss = loss.mask_where(similar_mask, similar_loss);

        // target == -1: loss = max(0, cos_sim - margin)
        let dissimilar_mask = target.equal_elem(-1);
        let dissimilar_loss = relu(cos_sim.sub_scalar(self.margin));
        loss = loss.mask_where(dissimilar_mask, dissimilar_loss);

        loss
    }

    fn assertions<B: Backend>(
        &self,
        input1: &Tensor<B, 2>,
        input2: &Tensor<B, 2>,
        target: &Tensor<B, 1, Int>,
    ) {
        let [batch_size1, dim1] = input1.dims();
        let [batch_size2, dim2] = input2.dims();
        let [batch_size_target] = target.dims();

        assert_eq!(
            batch_size1, batch_size2,
            "Batch size of input1 ({batch_size1}) must match batch size of input2 ({batch_size2})"
        );

        assert_eq!(
            dim1, dim2,
            "Embedding dimension of input1 ({dim1}) must match embedding dimension of input2 ({dim2})"
        );

        assert_eq!(
            batch_size1, batch_size_target,
            "Batch size of inputs ({batch_size1}) must match batch size of target ({batch_size_target})"
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn cosine_embedding_loss_positive_target() {
        let device = Default::default();

        // Two identical batches of unit vectors: cosine similarity is 1 for each pair.
        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let input2 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 1]), &device);

        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());

        let loss_sum_config = CosineEmbeddingLossConfig::new().with_reduction(Reduction::Sum);
        let loss_sum = loss_sum_config.init().forward(input1, input2, target);

        let expected_no_reduction = TensorData::from([0.0, 0.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([0.0]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());

        let expected_sum = TensorData::from([0.0]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());
    }

    #[test]
    fn cosine_embedding_loss_negative_target() {
        let device = Default::default();

        // Identical pairs (cosine similarity 1) labelled as dissimilar.
        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let input2 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([-1, -1]), &device);

        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());

        let loss_sum_config = CosineEmbeddingLossConfig::new().with_reduction(Reduction::Sum);
        let loss_sum = loss_sum_config
            .init()
            .forward(input1.clone(), input2.clone(), target.clone());

        let expected_no_reduction = TensorData::from([1.0, 1.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([1.0]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());

        let expected_sum = TensorData::from([2.0]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());

        // A positive margin is subtracted from the similarity before the hinge.
        let loss_with_margin = CosineEmbeddingLossConfig::new().with_margin(0.5).init();
        let loss_with_margin = loss_with_margin.forward(input1, input2, target);

        let expected = TensorData::from([0.5]);
        loss_with_margin
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn cosine_embedding_loss_mixed_targets() {
        let device = Default::default();

        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let input2 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, -1]), &device);

        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction =
            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());
        let loss_mean = loss.forward(input1, input2, target);

        let expected_no_reduction = TensorData::from([0.0, 1.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());

        let expected_mean = TensorData::from([0.5]);
        loss_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());
    }
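
    // Added sketch (not in the original tests): orthogonal embeddings have cosine
    // similarity 0, so a target of 1 yields a loss of 1 and a target of -1 yields
    // max(0, 0 - margin) = 0 with the default margin.
    #[test]
    fn cosine_embedding_loss_orthogonal_inputs() {
        let device = Default::default();

        let input1 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),
            &device,
        );

        let input2 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.0, 1.0], [1.0, 0.0]]),
            &device,
        );

        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, -1]), &device);

        let loss = CosineEmbeddingLossConfig::new().init();
        let loss_no_reduction = loss.forward_no_reduction(input1, input2, target);

        let expected = TensorData::from([1.0, 0.0]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }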

    #[test]
    fn display() {
        let config = CosineEmbeddingLossConfig::new().with_margin(0.5);
        let loss = config.init();

        assert_eq!(
            alloc::format!("{loss}"),
            "CosineEmbeddingLoss {margin: 0.5, reduction: Mean}"
        );
    }
}
317}