irithyll_core/loss/softmax.rs
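
//! Multiclass log loss with one raw score per class: each output is treated
//! as an independent binary problem (sigmoid plus binary cross-entropy), so
//! the per-output gradient and hessian match plain logistic loss. Despite the
//! `Softmax` name, the outputs are not jointly normalized here; presumably any
//! cross-class normalization happens downstream of `predict_transform`.
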
use super::Loss;
use crate::math;

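/// Multiclass classification loss with `n_classes` parallel per-class
/// outputs, each scored as an independent binary classifier.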
#[derive(Debug, Clone, Copy)]
pub struct SoftmaxLoss {
    pub n_classes: usize,
}

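/// Numerically stable logistic sigmoid. Branching on the sign of `x` means
/// `math::exp` is only ever called with a non-positive argument, so it can
/// underflow to 0.0 but never overflow.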
#[inline]
fn sigmoid(x: f64) -> f64 {
    if x >= 0.0 {
        let z = math::exp(-x);
        1.0 / (1.0 + z)
    } else {
        let z = math::exp(x);
        z / (1.0 + z)
    }
}

impl Loss for SoftmaxLoss {
    #[inline]
    fn n_outputs(&self) -> usize {
        self.n_classes
    }

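    /// First derivative of `loss` with respect to the raw score:
    /// `sigmoid(prediction) - 1.0` when `target == 1.0` (this output's class
    /// is the true one), `sigmoid(prediction)` otherwise.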
    #[inline]
    fn gradient(&self, target: f64, prediction: f64) -> f64 {
        let indicator = if target == 1.0 { 1.0 } else { 0.0 };
        sigmoid(prediction) - indicator
    }

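    /// Second derivative of `loss` with respect to the raw score,
    /// `p * (1 - p)`, floored at `1e-16` so it stays strictly positive when
    /// the sigmoid saturates (e.g. so Newton-style updates that divide by the
    /// hessian remain finite).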
    #[inline]
    fn hessian(&self, _target: f64, prediction: f64) -> f64 {
        let p = sigmoid(prediction);
        (p * (1.0 - p)).max(1e-16)
    }

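    /// Binary cross-entropy on this output's sigmoid probability, with the
    /// probability clamped away from 0 and 1 so neither `ln` call can return
    /// negative infinity.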
    fn loss(&self, target: f64, prediction: f64) -> f64 {
        let indicator = if target == 1.0 { 1.0 } else { 0.0 };
        let p = sigmoid(prediction).clamp(1e-15, 1.0 - 1e-15);
        -indicator * math::ln(p) - (1.0 - indicator) * math::ln(1.0 - p)
    }

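    /// Maps a raw score to a probability for one output. Each class is
    /// transformed independently, so the resulting probabilities are not
    /// guaranteed to sum to 1 across classes.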
    #[inline]
    fn predict_transform(&self, raw: f64) -> f64 {
        sigmoid(raw)
    }

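    /// Every class score starts at a raw 0.0 (probability 0.5 per output),
    /// regardless of the target distribution.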
    fn initial_prediction(&self, _targets: &[f64]) -> f64 {
        0.0
    }

    fn loss_type(&self) -> Option<super::LossType> {
        Some(super::LossType::Softmax {
            n_classes: self.n_classes,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::math;

    const EPS: f64 = 1e-10;

    #[test]
    fn test_n_outputs() {
        let loss = SoftmaxLoss { n_classes: 5 };
        assert_eq!(loss.n_outputs(), 5);
    }

    #[test]
    fn test_n_outputs_binary() {
        let loss = SoftmaxLoss { n_classes: 2 };
        assert_eq!(loss.n_outputs(), 2);
    }

    #[test]
    fn test_gradient_correct_class() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let g = loss.gradient(1.0, 0.0);
        assert!((g - (-0.5)).abs() < EPS);
    }

    #[test]
    fn test_gradient_wrong_class() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let g = loss.gradient(0.0, 0.0);
        assert!((g - 0.5).abs() < EPS);
    }

    #[test]
    fn test_gradient_confident_correct() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let g = loss.gradient(1.0, 5.0);
        assert!(g < 0.0);
        assert!(g > -0.01);
    }

    #[test]
    fn test_hessian_positive() {
        let loss = SoftmaxLoss { n_classes: 3 };
        assert!(loss.hessian(0.0, 0.0) > 0.0);
        assert!(loss.hessian(1.0, 5.0) > 0.0);
        assert!(loss.hessian(0.0, -5.0) > 0.0);
        assert!(loss.hessian(1.0, 100.0) > 0.0);
    }

    #[test]
    fn test_hessian_max_at_zero() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let h_zero = loss.hessian(0.0, 0.0);
        let h_large = loss.hessian(0.0, 5.0);
        assert!((h_zero - 0.25).abs() < EPS);
        assert!(h_large < h_zero);
    }

    #[test]
    fn test_loss_value_at_zero() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let l1 = loss.loss(1.0, 0.0);
        let l0 = loss.loss(0.0, 0.0);
        let ln2 = math::ln(2.0);
        assert!((l1 - ln2).abs() < 1e-8);
        assert!((l0 - ln2).abs() < 1e-8);
    }

    #[test]
    fn test_loss_decreases_with_correct_prediction() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let l_zero = loss.loss(1.0, 0.0);
        let l_positive = loss.loss(1.0, 3.0);
        assert!(l_positive < l_zero);
    }
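
    // Sanity check on the gradient/hessian pair: assuming a consumer applies
    // Newton-style updates `pred - g / h`, a single such step on one sample
    // should reduce the loss.
    #[test]
    fn test_newton_step_reduces_loss() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let target = 1.0;
        let mut pred = -2.0;
        let before = loss.loss(target, pred);
        pred -= loss.gradient(target, pred) / loss.hessian(target, pred);
        assert!(loss.loss(target, pred) < before);
    }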

    #[test]
    fn test_predict_transform_is_sigmoid() {
        let loss = SoftmaxLoss { n_classes: 3 };
        assert!((loss.predict_transform(0.0) - 0.5).abs() < EPS);
        assert!(loss.predict_transform(10.0) > 0.99);
        assert!(loss.predict_transform(-10.0) < 0.01);
    }

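    // The two branches of the stable sigmoid should agree:
    // sigmoid(x) + sigmoid(-x) == 1, even for inputs far into saturation.
    #[test]
    fn test_sigmoid_branches_are_symmetric() {
        let loss = SoftmaxLoss { n_classes: 3 };
        for &x in &[0.5, 2.0, 30.0, 800.0] {
            let sum = loss.predict_transform(x) + loss.predict_transform(-x);
            assert!((sum - 1.0).abs() < EPS);
        }
    }
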
    #[test]
    fn test_initial_prediction_is_zero() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let targets = [0.0, 1.0, 2.0, 1.0, 0.0];
        assert!((loss.initial_prediction(&targets)).abs() < EPS);
        assert!((loss.initial_prediction(&[])).abs() < EPS);
    }

    #[test]
    fn test_gradient_is_derivative_of_loss() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let target = 1.0;
        let pred = 1.5;
        let h = 1e-7;
        let numerical = (loss.loss(target, pred + h) - loss.loss(target, pred - h)) / (2.0 * h);
        let analytical = loss.gradient(target, pred);
        assert!(
            (numerical - analytical).abs() < 1e-5,
            "numerical={numerical}, analytical={analytical}"
        );

        let target = 0.0;
        let pred = -0.5;
        let numerical = (loss.loss(target, pred + h) - loss.loss(target, pred - h)) / (2.0 * h);
        let analytical = loss.gradient(target, pred);
        assert!(
            (numerical - analytical).abs() < 1e-5,
            "numerical={numerical}, analytical={analytical}"
        );
    }
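
    // Same centered-difference idea as above, one derivative higher: the
    // hessian should match the numerical derivative of the gradient.
    #[test]
    fn test_hessian_is_derivative_of_gradient() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let h = 1e-7;
        for &pred in &[-1.5, 0.0, 2.5] {
            let numerical =
                (loss.gradient(0.0, pred + h) - loss.gradient(0.0, pred - h)) / (2.0 * h);
            let analytical = loss.hessian(0.0, pred);
            assert!(
                (numerical - analytical).abs() < 1e-5,
                "numerical={numerical}, analytical={analytical}"
            );
        }
    }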
}