Skip to main content

entrenar/train/loss/mse/
huber_loss.rs

1//! Huber Loss (Smooth L1 Loss)
2//!
3//! Forward scalar computation delegates to [`aprender::loss::huber_loss`].
4//! Gradient computation (backward) is entrenar's autograd concern.
5
6use aprender::primitives::Vector;
7
8use crate::autograd::BackwardOp;
9use crate::Tensor;
10use ndarray::Array1;
11use std::rc::Rc;
12
13use crate::train::loss::LossFn;
14
/// Huber Loss (Smooth L1 Loss)
///
/// Combines MSE for small errors and MAE for large errors,
/// making it robust to outliers.
///
/// Per element, with `error = prediction - target`:
///
/// - `|error| <= delta`: `L = 0.5 * error^2`
/// - `|error| >  delta`: `L = delta * (|error| - 0.5 * delta)`
///
/// The type is a thin wrapper around the threshold, so it is cheap to copy,
/// compare and print.
///
/// # Example
///
/// ```
/// use entrenar::train::{HuberLoss, LossFn};
/// use entrenar::Tensor;
///
/// let loss_fn = HuberLoss::new(1.0);
/// let pred = Tensor::from_vec(vec![1.0, 2.0, 10.0], true);  // 10.0 is outlier
/// let target = Tensor::from_vec(vec![1.5, 2.5, 0.0], false);
///
/// let loss = loss_fn.forward(&pred, &target);
/// assert!(loss.data()[0] > 0.0);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HuberLoss {
    /// Threshold at which the loss switches from quadratic to linear.
    delta: f32,
}
40
41impl HuberLoss {
42    /// Create Huber loss with given delta threshold
43    pub fn new(delta: f32) -> Self {
44        assert!(delta > 0.0, "delta must be positive");
45        Self { delta }
46    }
47
48    /// Create Huber loss with default delta = 1.0
49    pub fn default_delta() -> Self {
50        Self::new(1.0)
51    }
52}
53
54impl Default for HuberLoss {
55    fn default() -> Self {
56        Self::new(1.0)
57    }
58}
59
60impl LossFn for HuberLoss {
61    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
62        assert_eq!(
63            predictions.len(),
64            targets.len(),
65            "Predictions and targets must have same length"
66        );
67
68        // Delegate forward scalar to aprender
69        let pred_vec =
70            Vector::from_slice(predictions.data().as_slice().expect("contiguous tensor data"));
71        let tgt_vec =
72            Vector::from_slice(targets.data().as_slice().expect("contiguous tensor data"));
73        let mean_loss = aprender::loss::huber_loss(&pred_vec, &tgt_vec, self.delta);
74
75        let mut loss = Tensor::from_vec(vec![mean_loss], true);
76
77        // Gradient computation is entrenar's autograd concern
78        let diff = predictions.data() - targets.data();
79        let n = predictions.len() as f32;
80        let delta = self.delta;
81
82        // Compute gradient per element
83        // d(Huber)/d(pred) = error if |error| <= delta
84        //                  = delta * sign(error) if |error| > delta
85        let grad: Array1<f32> = diff
86            .iter()
87            .map(|&d| {
88                let abs_d = d.abs();
89                if abs_d <= delta {
90                    d / n
91                } else {
92                    delta * d.signum() / n
93                }
94            })
95            .collect();
96
97        struct HuberBackward {
98            pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
99            grad: Array1<f32>,
100        }
101
102        impl BackwardOp for HuberBackward {
103            fn backward(&self) {
104                let mut pred_grad = self.pred_grad_cell.borrow_mut();
105                if let Some(existing) = pred_grad.as_mut() {
106                    *existing = &*existing + &self.grad;
107                } else {
108                    *pred_grad = Some(self.grad.clone());
109                }
110            }
111        }
112
113        if predictions.requires_grad() {
114            loss.set_backward_op(Rc::new(HuberBackward {
115                pred_grad_cell: predictions.grad_cell(),
116                grad,
117            }));
118        }
119
120        loss
121    }
122
123    fn name(&self) -> &'static str {
124        "Huber"
125    }
126}
127
128/// Smooth L1 Loss (alias for HuberLoss with delta=1.0)
129///
130/// Equivalent to HuberLoss with delta=1.0, commonly used in
131/// object detection (e.g., Faster R-CNN).
132pub type SmoothL1Loss = HuberLoss;