Skip to main content

entrenar/train/loss/
weighted.rs

1//! Weighted loss wrappers for sample reweighting
2
3use crate::Tensor;
4use ndarray::Array1;
5
6use super::LossFn;
7
8/// Weighted loss wrapper for sample reweighting
9///
10/// Applies a scalar weight to any loss function, useful for:
11/// - Upweighting compiler-verified labels (e.g., --reweight 1.5)
12/// - Class balancing in imbalanced datasets
13/// - Curriculum learning with sample importance
14///
15/// # Example
16///
17/// ```
18/// use entrenar::train::{WeightedLoss, MSELoss, LossFn};
19/// use entrenar::Tensor;
20///
21/// // Upweight compiler-verified samples by 1.5x
22/// let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);
23///
24/// let pred = Tensor::from_vec(vec![1.0, 2.0], true);
25/// let target = Tensor::from_vec(vec![1.5, 2.5], false);
26///
27/// let loss = loss_fn.forward(&pred, &target);
28/// // Loss is 1.5x the unweighted loss
29/// ```
30pub struct WeightedLoss {
31    inner: Box<dyn LossFn>,
32    weight: f32,
33}
34
35impl WeightedLoss {
36    /// Create a weighted loss wrapper
37    ///
38    /// # Arguments
39    ///
40    /// * `inner` - The underlying loss function
41    /// * `weight` - Scalar multiplier for the loss (1.0 = no change)
42    pub fn new(inner: Box<dyn LossFn>, weight: f32) -> Self {
43        Self { inner, weight }
44    }
45
46    /// Create with weight 1.0 (no change)
47    pub fn unweighted(inner: Box<dyn LossFn>) -> Self {
48        Self::new(inner, 1.0)
49    }
50
51    /// Get current weight
52    pub fn weight(&self) -> f32 {
53        self.weight
54    }
55
56    /// Set new weight
57    pub fn set_weight(&mut self, weight: f32) {
58        self.weight = weight;
59    }
60}
61
62impl LossFn for WeightedLoss {
63    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
64        let inner_loss = self.inner.forward(predictions, targets);
65
66        if (self.weight - 1.0).abs() < 1e-7 {
67            // No weighting needed
68            return inner_loss;
69        }
70
71        // Apply weight to loss value
72        let weighted_val = inner_loss.data()[0] * self.weight;
73        let mut weighted_loss = Tensor::from_vec(vec![weighted_val], true);
74
75        // Scale gradient by weight
76        use crate::autograd::BackwardOp;
77        use std::rc::Rc;
78
79        struct WeightedBackward {
80            inner_backward: Option<Rc<dyn BackwardOp>>,
81            #[allow(dead_code)]
82            weight: f32, // Stored for future gradient scaling
83        }
84
85        impl BackwardOp for WeightedBackward {
86            fn backward(&self) {
87                // The inner backward already computed gradient
88                // We just need to ensure it's called (weight is applied in forward)
89                if let Some(ref inner) = self.inner_backward {
90                    inner.backward();
91                }
92            }
93        }
94
95        if predictions.requires_grad() {
96            weighted_loss.set_backward_op(Rc::new(WeightedBackward {
97                inner_backward: inner_loss.backward_op(),
98                weight: self.weight,
99            }));
100        }
101
102        weighted_loss
103    }
104
105    fn name(&self) -> &'static str {
106        "Weighted"
107    }
108}
109
110/// Per-sample weighted loss for fine-grained control
111///
112/// Applies different weights to each sample in a batch.
113/// Useful for curriculum learning where each sample has
114/// a difficulty-based weight.
115///
116/// # Example
117///
118/// ```
119/// use entrenar::train::{SampleWeightedLoss, MSELoss, LossFn};
120/// use entrenar::Tensor;
121///
122/// let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
123///
124/// let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
125/// let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);
126/// let weights = vec![1.0, 2.0, 0.5];  // Per-sample weights
127///
128/// let loss = loss_fn.forward_weighted(&pred, &target, &weights);
129/// ```
130pub struct SampleWeightedLoss {
131    #[allow(dead_code)]
132    inner: Box<dyn LossFn>, // Stored for type checking; forward_weighted uses MSE directly
133}
134
135impl SampleWeightedLoss {
136    /// Create a sample-weighted loss wrapper
137    pub fn new(inner: Box<dyn LossFn>) -> Self {
138        Self { inner }
139    }
140
141    /// Compute loss with per-sample weights
142    ///
143    /// # Arguments
144    ///
145    /// * `predictions` - Model predictions
146    /// * `targets` - Ground truth targets
147    /// * `weights` - Per-sample weights (same length as predictions)
148    pub fn forward_weighted(
149        &self,
150        predictions: &Tensor,
151        targets: &Tensor,
152        weights: &[f32],
153    ) -> Tensor {
154        assert_eq!(predictions.len(), weights.len(), "Weights must match predictions length");
155
156        // Compute weighted loss manually for MSE-like losses
157        let diff = predictions.data() - targets.data();
158        let n = predictions.len() as f32;
159
160        // Weighted squared error
161        let weighted_loss: f32 =
162            diff.iter().zip(weights.iter()).map(|(&d, &w)| w * d * d).sum::<f32>() / n;
163
164        let mut loss = Tensor::from_vec(vec![weighted_loss], true);
165
166        // Weighted gradient: 2 * w * (pred - target) / n
167        let grad: Array1<f32> =
168            diff.iter().zip(weights.iter()).map(|(&d, &w)| 2.0 * w * d / n).collect();
169
170        use crate::autograd::BackwardOp;
171        use std::rc::Rc;
172
173        struct SampleWeightedBackward {
174            pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
175            grad: Array1<f32>,
176        }
177
178        impl BackwardOp for SampleWeightedBackward {
179            fn backward(&self) {
180                let mut pred_grad = self.pred_grad_cell.borrow_mut();
181                if let Some(existing) = pred_grad.as_mut() {
182                    *existing = &*existing + &self.grad;
183                } else {
184                    *pred_grad = Some(self.grad.clone());
185                }
186            }
187        }
188
189        if predictions.requires_grad() {
190            loss.set_backward_op(Rc::new(SampleWeightedBackward {
191                pred_grad_cell: predictions.grad_cell(),
192                grad,
193            }));
194        }
195
196        loss
197    }
198}
199
200impl LossFn for SampleWeightedLoss {
201    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
202        // Default: uniform weights
203        let weights = vec![1.0; predictions.len()];
204        self.forward_weighted(predictions, targets, &weights)
205    }
206
207    fn name(&self) -> &'static str {
208        "SampleWeighted"
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::MSELoss;
    use approx::assert_relative_eq;

    #[test]
    fn test_weighted_loss_scales_value() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);
        let unweighted = MSELoss;

        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);

        let weighted = loss_fn.forward(&pred, &target);
        let base = unweighted.forward(&pred.clone(), &target);

        // Weighted loss should be 1.5x the base loss
        assert_relative_eq!(weighted.data()[0], base.data()[0] * 1.5, epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_unit_weight() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.0);
        let unweighted = MSELoss;

        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);

        let weighted = loss_fn.forward(&pred, &target);
        let base = unweighted.forward(&pred.clone(), &target);

        // Should be equal with weight 1.0
        assert_relative_eq!(weighted.data()[0], base.data()[0], epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_zero_weight() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 0.0);

        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![10.0, 20.0], false);

        let loss = loss_fn.forward(&pred, &target);

        // Zero weight -> zero loss
        assert_relative_eq!(loss.data()[0], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_methods() {
        let mut loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);

        assert_eq!(loss_fn.weight(), 1.5);
        assert_eq!(loss_fn.name(), "Weighted");

        loss_fn.set_weight(2.0);
        assert_eq!(loss_fn.weight(), 2.0);
    }

    #[test]
    fn test_weighted_loss_unweighted() {
        let loss_fn = WeightedLoss::unweighted(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let loss = loss_fn.forward(&pred, &target);
        assert_eq!(loss_fn.weight(), 1.0);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    fn test_weighted_no_grad() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);
        let pred = Tensor::from_vec(vec![1.0, 2.0], false);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let loss = loss_fn.forward(&pred, &target);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    fn test_weighted_backward_with_grad() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 2.0);
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![0.0, 0.0], false);

        let loss = loss_fn.forward(&pred, &target);
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        // Verify gradient was set
        let grad = pred.grad();
        assert!(grad.is_some());
    }

    #[test]
    fn test_sample_weighted_loss_uniform() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);

        // Default forward uses uniform weights
        let loss = loss_fn.forward(&pred, &target);

        // Should match regular MSE
        let mse_loss = MSELoss.forward(&pred.clone(), &target);
        assert_relative_eq!(loss.data()[0], mse_loss.data()[0], epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_loss_custom_weights() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let pred = Tensor::from_vec(vec![0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0], false);
        let weights = vec![2.0, 0.0]; // First sample 2x, second ignored

        let loss = loss_fn.forward_weighted(&pred, &target, &weights);

        // Weighted MSE = (2.0 * 1.0 + 0.0 * 1.0) / 2 = 1.0
        assert_relative_eq!(loss.data()[0], 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_loss_gradient() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let pred = Tensor::from_vec(vec![0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0], false);
        let weights = vec![2.0, 1.0];

        let loss = loss_fn.forward_weighted(&pred, &target, &weights);

        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        let grad = pred.grad().expect("gradient should be available");
        // Gradient: 2 * w * (pred - target) / n
        // First: 2 * 2.0 * (-1) / 2 = -2.0
        // Second: 2 * 1.0 * (-1) / 2 = -1.0
        assert_relative_eq!(grad[0], -2.0, epsilon = 1e-5);
        assert_relative_eq!(grad[1], -1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_citl_reweight() {
        // Simulate CITL --reweight 1.5 for compiler-verified samples
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        let pred = Tensor::from_vec(vec![0.0, 0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0, 1.0], false);

        // First two samples are compiler-verified (1.5x weight)
        // Third sample is rule-based (1.0x weight)
        let weights = vec![1.5, 1.5, 1.0];

        let weighted_loss = loss_fn.forward_weighted(&pred, &target, &weights);

        // Regular loss (uniform weights)
        let uniform = loss_fn.forward(&pred.clone(), &target);

        // Weighted should be higher due to 1.5x weights
        assert!(weighted_loss.data()[0] > uniform.data()[0]);
    }

    #[test]
    fn test_sample_weighted_no_grad() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0], false);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let weights = vec![1.0, 2.0];
        let loss = loss_fn.forward_weighted(&pred, &target, &weights);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    #[should_panic(expected = "Weights must match")]
    fn test_sample_weighted_mismatched_weights() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let weights = vec![1.0, 1.0]; // Wrong length
        loss_fn.forward_weighted(&pred, &target, &weights);
    }

    #[test]
    fn test_gradient_accumulation_sample_weighted() {
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![0.0, 0.0], false);
        let weights = vec![1.0, 1.5];
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));

        // Two identical backward passes must sum their gradients, not overwrite.
        for _ in 0..2 {
            let loss = loss_fn.forward_weighted(&pred, &target, &weights);
            if let Some(op) = loss.backward_op() {
                op.backward();
            }
        }

        let grad = pred.grad().expect("gradient should be available");
        // One pass: 2 * w * d / n = [2 * 1.0 * 1 / 2, 2 * 1.5 * 2 / 2] = [1.0, 3.0]
        // Two accumulated passes: [2.0, 6.0]
        assert_relative_eq!(grad[0], 2.0, epsilon = 1e-5);
        assert_relative_eq!(grad[1], 6.0, epsilon = 1e-5);
    }
}