1use crate::Tensor;
4use ndarray::Array1;
5
6use super::LossFn;
7
8pub struct WeightedLoss {
31 inner: Box<dyn LossFn>,
32 weight: f32,
33}
34
35impl WeightedLoss {
36 pub fn new(inner: Box<dyn LossFn>, weight: f32) -> Self {
43 Self { inner, weight }
44 }
45
46 pub fn unweighted(inner: Box<dyn LossFn>) -> Self {
48 Self::new(inner, 1.0)
49 }
50
51 pub fn weight(&self) -> f32 {
53 self.weight
54 }
55
56 pub fn set_weight(&mut self, weight: f32) {
58 self.weight = weight;
59 }
60}
61
62impl LossFn for WeightedLoss {
63 fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
64 let inner_loss = self.inner.forward(predictions, targets);
65
66 if (self.weight - 1.0).abs() < 1e-7 {
67 return inner_loss;
69 }
70
71 let weighted_val = inner_loss.data()[0] * self.weight;
73 let mut weighted_loss = Tensor::from_vec(vec![weighted_val], true);
74
75 use crate::autograd::BackwardOp;
77 use std::rc::Rc;
78
79 struct WeightedBackward {
80 inner_backward: Option<Rc<dyn BackwardOp>>,
81 #[allow(dead_code)]
82 weight: f32, }
84
85 impl BackwardOp for WeightedBackward {
86 fn backward(&self) {
87 if let Some(ref inner) = self.inner_backward {
90 inner.backward();
91 }
92 }
93 }
94
95 if predictions.requires_grad() {
96 weighted_loss.set_backward_op(Rc::new(WeightedBackward {
97 inner_backward: inner_loss.backward_op(),
98 weight: self.weight,
99 }));
100 }
101
102 weighted_loss
103 }
104
105 fn name(&self) -> &'static str {
106 "Weighted"
107 }
108}
109
110pub struct SampleWeightedLoss {
131 #[allow(dead_code)]
132 inner: Box<dyn LossFn>, }
134
135impl SampleWeightedLoss {
136 pub fn new(inner: Box<dyn LossFn>) -> Self {
138 Self { inner }
139 }
140
141 pub fn forward_weighted(
149 &self,
150 predictions: &Tensor,
151 targets: &Tensor,
152 weights: &[f32],
153 ) -> Tensor {
154 assert_eq!(predictions.len(), weights.len(), "Weights must match predictions length");
155
156 let diff = predictions.data() - targets.data();
158 let n = predictions.len() as f32;
159
160 let weighted_loss: f32 =
162 diff.iter().zip(weights.iter()).map(|(&d, &w)| w * d * d).sum::<f32>() / n;
163
164 let mut loss = Tensor::from_vec(vec![weighted_loss], true);
165
166 let grad: Array1<f32> =
168 diff.iter().zip(weights.iter()).map(|(&d, &w)| 2.0 * w * d / n).collect();
169
170 use crate::autograd::BackwardOp;
171 use std::rc::Rc;
172
173 struct SampleWeightedBackward {
174 pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
175 grad: Array1<f32>,
176 }
177
178 impl BackwardOp for SampleWeightedBackward {
179 fn backward(&self) {
180 let mut pred_grad = self.pred_grad_cell.borrow_mut();
181 if let Some(existing) = pred_grad.as_mut() {
182 *existing = &*existing + &self.grad;
183 } else {
184 *pred_grad = Some(self.grad.clone());
185 }
186 }
187 }
188
189 if predictions.requires_grad() {
190 loss.set_backward_op(Rc::new(SampleWeightedBackward {
191 pred_grad_cell: predictions.grad_cell(),
192 grad,
193 }));
194 }
195
196 loss
197 }
198}
199
200impl LossFn for SampleWeightedLoss {
201 fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
202 let weights = vec![1.0; predictions.len()];
204 self.forward_weighted(predictions, targets, &weights)
205 }
206
207 fn name(&self) -> &'static str {
208 "SampleWeighted"
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::MSELoss;
    use approx::assert_relative_eq;

    #[test]
    fn test_weighted_loss_scales_value() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);
        let unweighted = MSELoss;

        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);

        let weighted = loss_fn.forward(&pred, &target);
        let base = unweighted.forward(&pred, &target);

        assert_relative_eq!(weighted.data()[0], base.data()[0] * 1.5, epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_unit_weight() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.0);
        let unweighted = MSELoss;

        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);

        let weighted = loss_fn.forward(&pred, &target);
        let base = unweighted.forward(&pred, &target);

        assert_relative_eq!(weighted.data()[0], base.data()[0], epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_zero_weight() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 0.0);

        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![10.0, 20.0], false);

        let loss = loss_fn.forward(&pred, &target);

        assert_relative_eq!(loss.data()[0], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_methods() {
        let mut loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);

        assert_eq!(loss_fn.weight(), 1.5);
        assert_eq!(loss_fn.name(), "Weighted");

        loss_fn.set_weight(2.0);
        assert_eq!(loss_fn.weight(), 2.0);
    }

    #[test]
    fn test_weighted_loss_unweighted() {
        let loss_fn = WeightedLoss::unweighted(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let loss = loss_fn.forward(&pred, &target);
        assert_eq!(loss_fn.weight(), 1.0);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    fn test_weighted_no_grad() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);
        let pred = Tensor::from_vec(vec![1.0, 2.0], false);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let loss = loss_fn.forward(&pred, &target);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    fn test_weighted_backward_with_grad() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 2.0);
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![0.0, 0.0], false);

        let loss = loss_fn.forward(&pred, &target);
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        let grad = pred.grad();
        assert!(grad.is_some());
    }

    #[test]
    fn test_sample_weighted_loss_uniform() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);

        let loss = loss_fn.forward(&pred, &target);

        // Uniform weights must reproduce plain MSE.
        let mse_loss = MSELoss.forward(&pred, &target);
        assert_relative_eq!(loss.data()[0], mse_loss.data()[0], epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_loss_custom_weights() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let pred = Tensor::from_vec(vec![0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0], false);
        let weights = vec![2.0, 0.0];

        let loss = loss_fn.forward_weighted(&pred, &target, &weights);

        // (2.0 * 1.0 + 0.0 * 1.0) / 2 = 1.0
        assert_relative_eq!(loss.data()[0], 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_loss_gradient() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let pred = Tensor::from_vec(vec![0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0], false);
        let weights = vec![2.0, 1.0];

        let loss = loss_fn.forward_weighted(&pred, &target, &weights);

        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        // dL/dp_i = 2 * w_i * (p_i - t_i) / n
        let grad = pred.grad().expect("gradient should be available");
        assert_relative_eq!(grad[0], -2.0, epsilon = 1e-5);
        assert_relative_eq!(grad[1], -1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_citl_reweight() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let pred = Tensor::from_vec(vec![0.0, 0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0, 1.0], false);

        let weights = vec![1.5, 1.5, 1.0];

        let weighted_loss = loss_fn.forward_weighted(&pred, &target, &weights);
        let uniform = loss_fn.forward(&pred, &target);

        // Up-weighting some samples must raise the loss above the uniform baseline.
        assert!(weighted_loss.data()[0] > uniform.data()[0]);
    }

    #[test]
    fn test_sample_weighted_no_grad() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0], false);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let weights = vec![1.0, 2.0];
        let loss = loss_fn.forward_weighted(&pred, &target, &weights);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    #[should_panic(expected = "Weights must match")]
    fn test_sample_weighted_mismatched_weights() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let weights = vec![1.0, 1.0];
        loss_fn.forward_weighted(&pred, &target, &weights);
    }

    #[test]
    fn test_gradient_accumulation_sample_weighted() {
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![0.0, 0.0], false);
        let weights = vec![1.0, 1.5];
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let loss1 = loss_fn.forward_weighted(&pred, &target, &weights);
        if let Some(op) = loss1.backward_op() {
            op.backward();
        }

        let loss2 = loss_fn.forward_weighted(&pred, &target, &weights);
        if let Some(op) = loss2.backward_op() {
            op.backward();
        }

        // One pass gives 2 * w * d / n = [1.0, 3.0]; two passes must sum, not overwrite.
        let grad = pred.grad().expect("gradient should be available");
        assert_relative_eq!(grad[0], 2.0, epsilon = 1e-5);
        assert_relative_eq!(grad[1], 6.0, epsilon = 1e-5);
    }
}