// irithyll_core/loss/softmax.rs
//! Softmax cross-entropy loss for multi-class classification.
//!
//! In gradient boosting, multi-class classification is typically handled by
//! training a separate committee of trees for each class. Each committee
//! member sees binary targets (1.0 if the sample belongs to that class,
//! 0.0 otherwise) and uses logistic-style gradients.
//!
//! The full softmax normalization across classes happens at the ensemble
//! level after all committees produce their raw outputs. Within each
//! committee, the per-class loss reduces to binary logistic form.

use super::Loss;
use crate::math;
14
/// Multi-class cross-entropy (softmax) loss.
///
/// Drives one committee of trees per class: each committee trains with
/// logistic-style gradients against one-hot encoded targets, and the
/// cross-class softmax normalization is applied at the ensemble level.
#[derive(Debug, Clone, Copy)]
pub struct SoftmaxLoss {
    /// How many distinct classes the classification problem has.
    pub n_classes: usize,
}
24
25/// Numerically stable sigmoid: 1 / (1 + exp(-x)).
26#[inline]
27fn sigmoid(x: f64) -> f64 {
28    if x >= 0.0 {
29        let z = math::exp(-x);
30        1.0 / (1.0 + z)
31    } else {
32        let z = math::exp(x);
33        z / (1.0 + z)
34    }
35}
36
37impl Loss for SoftmaxLoss {
38    #[inline]
39    fn n_outputs(&self) -> usize {
40        self.n_classes
41    }
42
43    #[inline]
44    fn gradient(&self, target: f64, prediction: f64) -> f64 {
45        // Per-committee binary logistic gradient.
46        // target is 1.0 if this sample belongs to this committee's class, else 0.0.
47        let indicator = if target == 1.0 { 1.0 } else { 0.0 };
48        sigmoid(prediction) - indicator
49    }
50
51    #[inline]
52    fn hessian(&self, _target: f64, prediction: f64) -> f64 {
53        let p = sigmoid(prediction);
54        (p * (1.0 - p)).max(1e-16)
55    }
56
57    fn loss(&self, target: f64, prediction: f64) -> f64 {
58        // Binary cross-entropy for this committee.
59        let indicator = if target == 1.0 { 1.0 } else { 0.0 };
60        let p = sigmoid(prediction).clamp(1e-15, 1.0 - 1e-15);
61        -indicator * math::ln(p) - (1.0 - indicator) * math::ln(1.0 - p)
62    }
63
64    #[inline]
65    fn predict_transform(&self, raw: f64) -> f64 {
66        // Per-committee sigmoid. Full softmax normalization across classes
67        // is handled by the ensemble.
68        sigmoid(raw)
69    }
70
71    fn initial_prediction(&self, _targets: &[f64]) -> f64 {
72        // Start each class committee from zero (no prior bias).
73        // The boosting loop will learn class priors through early trees.
74        0.0
75    }
76
77    fn loss_type(&self) -> Option<super::LossType> {
78        Some(super::LossType::Softmax {
79            n_classes: self.n_classes,
80        })
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math;

    const EPS: f64 = 1e-10;

    #[test]
    fn test_n_outputs() {
        assert_eq!(SoftmaxLoss { n_classes: 5 }.n_outputs(), 5);
    }

    #[test]
    fn test_n_outputs_binary() {
        assert_eq!(SoftmaxLoss { n_classes: 2 }.n_outputs(), 2);
    }

    #[test]
    fn test_gradient_correct_class() {
        // sigmoid(0) = 0.5 and the indicator is 1.0, so the gradient is -0.5.
        let g = SoftmaxLoss { n_classes: 3 }.gradient(1.0, 0.0);
        assert!((g + 0.5).abs() < EPS);
    }

    #[test]
    fn test_gradient_wrong_class() {
        // sigmoid(0) = 0.5 and the indicator is 0.0, so the gradient is 0.5.
        let g = SoftmaxLoss { n_classes: 3 }.gradient(0.0, 0.0);
        assert!((g - 0.5).abs() < EPS);
    }

    #[test]
    fn test_gradient_confident_correct() {
        // A confident, correct prediction yields a small negative gradient.
        let g = SoftmaxLoss { n_classes: 3 }.gradient(1.0, 5.0);
        assert!(g < 0.0);
        assert!(g > -0.01);
    }

    #[test]
    fn test_hessian_positive() {
        let loss = SoftmaxLoss { n_classes: 3 };
        for &(t, p) in &[(0.0, 0.0), (1.0, 5.0), (0.0, -5.0), (1.0, 100.0)] {
            assert!(loss.hessian(t, p) > 0.0);
        }
    }

    #[test]
    fn test_hessian_max_at_zero() {
        // p(1-p) peaks at 0.25 when the raw score is zero.
        let loss = SoftmaxLoss { n_classes: 3 };
        let h_zero = loss.hessian(0.0, 0.0);
        assert!((h_zero - 0.25).abs() < EPS);
        assert!(loss.hessian(0.0, 5.0) < h_zero);
    }

    #[test]
    fn test_loss_value_at_zero() {
        // With p = 0.5 both branches of the cross-entropy equal ln(2).
        let loss = SoftmaxLoss { n_classes: 3 };
        let ln2 = math::ln(2.0);
        assert!((loss.loss(1.0, 0.0) - ln2).abs() < 1e-8);
        assert!((loss.loss(0.0, 0.0) - ln2).abs() < 1e-8);
    }

    #[test]
    fn test_loss_decreases_with_correct_prediction() {
        let loss = SoftmaxLoss { n_classes: 3 };
        assert!(loss.loss(1.0, 3.0) < loss.loss(1.0, 0.0));
    }

    #[test]
    fn test_predict_transform_is_sigmoid() {
        let loss = SoftmaxLoss { n_classes: 3 };
        assert!((loss.predict_transform(0.0) - 0.5).abs() < EPS);
        assert!(loss.predict_transform(10.0) > 0.99);
        assert!(loss.predict_transform(-10.0) < 0.01);
    }

    #[test]
    fn test_initial_prediction_is_zero() {
        let loss = SoftmaxLoss { n_classes: 3 };
        assert!(loss.initial_prediction(&[0.0, 1.0, 2.0, 1.0, 0.0]).abs() < EPS);
        assert!(loss.initial_prediction(&[]).abs() < EPS);
    }

    #[test]
    fn test_gradient_is_derivative_of_loss() {
        // Central finite difference of the loss must match the analytical
        // gradient for both a positive and a negative case.
        let loss = SoftmaxLoss { n_classes: 3 };
        let h = 1e-7;
        for &(target, pred) in &[(1.0, 1.5), (0.0, -0.5)] {
            let numerical =
                (loss.loss(target, pred + h) - loss.loss(target, pred - h)) / (2.0 * h);
            let analytical = loss.gradient(target, pred);
            assert!(
                (numerical - analytical).abs() < 1e-5,
                "numerical={numerical}, analytical={analytical}"
            );
        }
    }
}
199}