irithyll_core/loss/
logistic.rs1use super::Loss;
11use crate::math;
12
/// Logistic (binary cross-entropy) loss, computed on raw, pre-sigmoid scores.
///
/// The type carries no data: it is a zero-sized marker whose behaviour lives
/// entirely in its [`Loss`] implementation.
#[derive(Clone, Copy, Debug)]
pub struct LogisticLoss;
19
20#[inline]
25fn sigmoid(x: f64) -> f64 {
26 if x >= 0.0 {
27 let z = math::exp(-x);
28 1.0 / (1.0 + z)
29 } else {
30 let z = math::exp(x);
31 z / (1.0 + z)
32 }
33}
34
35impl Loss for LogisticLoss {
36 #[inline]
37 fn n_outputs(&self) -> usize {
38 1
39 }
40
41 #[inline]
42 fn gradient(&self, target: f64, prediction: f64) -> f64 {
43 sigmoid(prediction) - target
44 }
45
46 #[inline]
47 fn hessian(&self, _target: f64, prediction: f64) -> f64 {
48 let p = sigmoid(prediction);
49 (p * (1.0 - p)).max(1e-16)
50 }
51
52 fn loss(&self, target: f64, prediction: f64) -> f64 {
53 let p = sigmoid(prediction).clamp(1e-15, 1.0 - 1e-15);
54 -target * math::ln(p) - (1.0 - target) * math::ln(1.0 - p)
55 }
56
57 #[inline]
58 fn predict_transform(&self, raw: f64) -> f64 {
59 sigmoid(raw)
60 }
61
62 fn initial_prediction(&self, targets: &[f64]) -> f64 {
63 if targets.is_empty() {
64 return 0.0;
65 }
66 let sum: f64 = targets.iter().sum();
67 let mean = (sum / targets.len() as f64).clamp(1e-7, 1.0 - 1e-7);
68 math::ln(mean / (1.0 - mean))
70 }
71
72 fn loss_type(&self) -> Option<super::LossType> {
73 Some(super::LossType::Logistic)
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for values expected to match up to rounding.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "actual={actual}, expected={expected}, tol={tol}"
        );
    }

    #[test]
    fn test_n_outputs() {
        let outputs = LogisticLoss.n_outputs();
        assert_eq!(outputs, 1);
    }

    #[test]
    fn test_sigmoid_basic() {
        // Midpoint and both saturated tails.
        assert_close(sigmoid(0.0), 0.5, EPS);
        assert_close(sigmoid(100.0), 1.0, EPS);
        assert_close(sigmoid(-100.0), 0.0, EPS);

        // Symmetry: sigmoid(x) + sigmoid(-x) == 1.
        let x = 2.5;
        assert_close(sigmoid(x) + sigmoid(-x), 1.0, EPS);
    }

    #[test]
    fn test_gradient_target_1_pred_0() {
        let grad = LogisticLoss.gradient(1.0, 0.0);
        assert_close(grad, -0.5, EPS);
    }

    #[test]
    fn test_gradient_target_0_pred_0() {
        let grad = LogisticLoss.gradient(0.0, 0.0);
        assert_close(grad, 0.5, EPS);
    }

    #[test]
    fn test_gradient_perfect_prediction() {
        // A confident, correct score should leave almost no gradient.
        let grad = LogisticLoss.gradient(1.0, 20.0);
        assert_close(grad, 0.0, 1e-6);
    }

    #[test]
    fn test_hessian_positive() {
        let objective = LogisticLoss;
        let cases = [(0.0, 0.0), (1.0, 5.0), (0.0, -5.0), (1.0, 100.0)];
        for (target, prediction) in cases {
            assert!(objective.hessian(target, prediction) > 0.0);
        }
    }

    #[test]
    fn test_hessian_max_at_pred_zero() {
        let objective = LogisticLoss;
        let at_zero = objective.hessian(0.0, 0.0);
        let at_five = objective.hessian(0.0, 5.0);

        assert_close(at_zero, 0.25, EPS);
        assert!(at_five < at_zero);
    }

    #[test]
    fn test_loss_value() {
        // At a raw score of 0 the probability is 0.5, so either label costs ln 2.
        let objective = LogisticLoss;
        let expected = math::ln(2.0);

        assert_close(objective.loss(1.0, 0.0), expected, 1e-8);
        assert_close(objective.loss(0.0, 0.0), expected, 1e-8);
    }

    #[test]
    fn test_predict_transform_is_sigmoid() {
        let objective = LogisticLoss;

        assert_close(objective.predict_transform(0.0), 0.5, EPS);
        assert!(objective.predict_transform(10.0) > 0.99);
        assert!(objective.predict_transform(-10.0) < 0.01);
    }

    #[test]
    fn test_initial_prediction_balanced() {
        // Mean 0.5 gives log-odds of 0.
        let labels = [0.0, 1.0, 0.0, 1.0];
        let start = LogisticLoss.initial_prediction(&labels);
        assert_close(start, 0.0, EPS);
    }

    #[test]
    fn test_initial_prediction_skewed() {
        // Mean 0.75 gives log-odds of ln(0.75 / 0.25).
        let labels = [1.0, 1.0, 1.0, 0.0];
        let start = LogisticLoss.initial_prediction(&labels);
        let expected = math::ln(0.75 / 0.25);
        assert_close(start, expected, 1e-8);
    }

    #[test]
    fn test_initial_prediction_empty() {
        let start = LogisticLoss.initial_prediction(&[]);
        assert_close(start, 0.0, EPS);
    }

    #[test]
    fn test_gradient_is_derivative_of_loss() {
        // Central finite difference of the loss versus the analytical gradient.
        let objective = LogisticLoss;
        let target = 1.0;
        let pred = 1.5;
        let step = 1e-7;

        let loss_above = objective.loss(target, pred + step);
        let loss_below = objective.loss(target, pred - step);
        let numerical = (loss_above - loss_below) / (2.0 * step);
        let analytical = objective.gradient(target, pred);

        assert!(
            (numerical - analytical).abs() < 1e-5,
            "numerical={numerical}, analytical={analytical}"
        );
    }
}