// irithyll_core/loss/logistic.rs

//! Logistic loss for binary classification.
//!
//! L(y, f) = -y * ln(sigma(f)) - (1-y) * ln(1 - sigma(f))
//!
//! where sigma(f) = 1 / (1 + exp(-f)) is the sigmoid function,
//! y in {0, 1} is the binary target, and f is the raw model output (logit).
//!
//! Also known as binary cross-entropy or log loss.

use super::Loss;
use crate::math;

13/// Logistic (binary cross-entropy) loss.
14///
15/// Targets must be `0.0` or `1.0`. Predictions are raw logits (unbounded).
16/// The sigmoid transform maps logits to probabilities in `[0, 1]`.
17#[derive(Debug, Clone, Copy)]
18pub struct LogisticLoss;
19
20/// Numerically stable sigmoid: 1 / (1 + exp(-x)).
21///
22/// Handles large positive and negative inputs without overflow by branching
23/// on the sign of x.
24#[inline]
25fn sigmoid(x: f64) -> f64 {
26    if x >= 0.0 {
27        let z = math::exp(-x);
28        1.0 / (1.0 + z)
29    } else {
30        let z = math::exp(x);
31        z / (1.0 + z)
32    }
33}
34
35impl Loss for LogisticLoss {
36    #[inline]
37    fn n_outputs(&self) -> usize {
38        1
39    }
40
41    #[inline]
42    fn gradient(&self, target: f64, prediction: f64) -> f64 {
43        sigmoid(prediction) - target
44    }
45
46    #[inline]
47    fn hessian(&self, _target: f64, prediction: f64) -> f64 {
48        let p = sigmoid(prediction);
49        (p * (1.0 - p)).max(1e-16)
50    }
51
52    fn loss(&self, target: f64, prediction: f64) -> f64 {
53        let p = sigmoid(prediction).clamp(1e-15, 1.0 - 1e-15);
54        -target * math::ln(p) - (1.0 - target) * math::ln(1.0 - p)
55    }
56
57    #[inline]
58    fn predict_transform(&self, raw: f64) -> f64 {
59        sigmoid(raw)
60    }
61
62    fn initial_prediction(&self, targets: &[f64]) -> f64 {
63        if targets.is_empty() {
64            return 0.0;
65        }
66        let sum: f64 = targets.iter().sum();
67        let mean = (sum / targets.len() as f64).clamp(1e-7, 1.0 - 1e-7);
68        // log-odds: ln(p / (1 - p))
69        math::ln(mean / (1.0 - mean))
70    }
71
72    fn loss_type(&self) -> Option<super::LossType> {
73        Some(super::LossType::Logistic)
74    }
75}
76
77#[cfg(test)]
78mod tests {
79    use super::*;
80
81    const EPS: f64 = 1e-10;
82
83    #[test]
84    fn test_n_outputs() {
85        assert_eq!(LogisticLoss.n_outputs(), 1);
86    }
87
88    #[test]
89    fn test_sigmoid_basic() {
90        assert!((sigmoid(0.0) - 0.5).abs() < EPS);
91        assert!((sigmoid(100.0) - 1.0).abs() < EPS);
92        assert!(sigmoid(-100.0).abs() < EPS);
93        let x = 2.5;
94        assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < EPS);
95    }
96
97    #[test]
98    fn test_gradient_target_1_pred_0() {
99        let loss = LogisticLoss;
100        let g = loss.gradient(1.0, 0.0);
101        assert!((g - (-0.5)).abs() < EPS);
102    }
103
104    #[test]
105    fn test_gradient_target_0_pred_0() {
106        let loss = LogisticLoss;
107        let g = loss.gradient(0.0, 0.0);
108        assert!((g - 0.5).abs() < EPS);
109    }
110
111    #[test]
112    fn test_gradient_perfect_prediction() {
113        let loss = LogisticLoss;
114        let g = loss.gradient(1.0, 20.0);
115        assert!(g.abs() < 1e-6);
116    }
117
118    #[test]
119    fn test_hessian_positive() {
120        let loss = LogisticLoss;
121        assert!(loss.hessian(0.0, 0.0) > 0.0);
122        assert!(loss.hessian(1.0, 5.0) > 0.0);
123        assert!(loss.hessian(0.0, -5.0) > 0.0);
124        assert!(loss.hessian(1.0, 100.0) > 0.0);
125    }
126
127    #[test]
128    fn test_hessian_max_at_pred_zero() {
129        let loss = LogisticLoss;
130        let h_zero = loss.hessian(0.0, 0.0);
131        let h_five = loss.hessian(0.0, 5.0);
132        assert!((h_zero - 0.25).abs() < EPS);
133        assert!(h_five < h_zero);
134    }
135
136    #[test]
137    fn test_loss_value() {
138        let loss = LogisticLoss;
139        let l1 = loss.loss(1.0, 0.0);
140        let l0 = loss.loss(0.0, 0.0);
141        let ln2 = math::ln(2.0);
142        assert!((l1 - ln2).abs() < 1e-8);
143        assert!((l0 - ln2).abs() < 1e-8);
144    }
145
146    #[test]
147    fn test_predict_transform_is_sigmoid() {
148        let loss = LogisticLoss;
149        assert!((loss.predict_transform(0.0) - 0.5).abs() < EPS);
150        assert!(loss.predict_transform(10.0) > 0.99);
151        assert!(loss.predict_transform(-10.0) < 0.01);
152    }
153
154    #[test]
155    fn test_initial_prediction_balanced() {
156        let loss = LogisticLoss;
157        let targets = [0.0, 1.0, 0.0, 1.0];
158        assert!(loss.initial_prediction(&targets).abs() < EPS);
159    }
160
161    #[test]
162    fn test_initial_prediction_skewed() {
163        let loss = LogisticLoss;
164        let targets = [1.0, 1.0, 1.0, 0.0];
165        let init = loss.initial_prediction(&targets);
166        let expected = math::ln(0.75 / 0.25);
167        assert!((init - expected).abs() < 1e-8);
168    }
169
170    #[test]
171    fn test_initial_prediction_empty() {
172        let loss = LogisticLoss;
173        assert!((loss.initial_prediction(&[])).abs() < EPS);
174    }
175
176    #[test]
177    fn test_gradient_is_derivative_of_loss() {
178        let loss = LogisticLoss;
179        let target = 1.0;
180        let pred = 1.5;
181        let h = 1e-7;
182        let numerical = (loss.loss(target, pred + h) - loss.loss(target, pred - h)) / (2.0 * h);
183        let analytical = loss.gradient(target, pred);
184        assert!(
185            (numerical - analytical).abs() < 1e-5,
186            "numerical={numerical}, analytical={analytical}"
187        );
188    }
189}