use super::activation::sigmoid;
use super::LossFunction;
#[derive(Debug, Clone, Copy, Default)]
pub struct BinaryLogLoss {
eps: f32,
}
impl BinaryLogLoss {
    /// Builds a loss with the standard stability epsilon (`1e-7`).
    pub fn new() -> Self {
        Self::with_eps(1e-7)
    }

    /// Builds a loss with a caller-supplied stability epsilon.
    pub fn with_eps(eps: f32) -> Self {
        Self { eps }
    }

    /// Converts a raw model score (logit) into a probability in `(0, 1)`
    /// via the logistic function.
    #[inline]
    pub fn to_probability(&self, raw: f32) -> f32 {
        sigmoid(raw)
    }

    /// Thresholds a probability into a hard class label:
    /// `1` when `prob >= threshold`, `0` otherwise.
    #[inline]
    pub fn to_class(&self, prob: f32, threshold: f32) -> u32 {
        u32::from(prob >= threshold)
    }
}
impl LossFunction for BinaryLogLoss {
    /// Binary cross-entropy: `-(y ln(p) + (1-y) ln(1-p))` with
    /// `p = sigmoid(prediction)`. `p` is clamped to `[eps, 1-eps]` so the
    /// logarithms stay finite even when the sigmoid saturates to 0 or 1.
    #[inline]
    fn loss(&self, target: f32, prediction: f32) -> f32 {
        let p = sigmoid(prediction).clamp(self.eps, 1.0 - self.eps);
        -(target * p.ln() + (1.0 - target) * (1.0 - p).ln())
    }

    /// First derivative of the loss w.r.t. the raw score: `p - y`.
    #[inline]
    fn gradient(&self, target: f32, prediction: f32) -> f32 {
        sigmoid(prediction) - target
    }

    /// Second derivative w.r.t. the raw score: `p(1-p)`, floored at `eps`
    /// so boosting updates that divide by the hessian never hit zero.
    #[inline]
    fn hessian(&self, _target: f32, prediction: f32) -> f32 {
        let p = sigmoid(prediction);
        let h = p * (1.0 - p);
        h.max(self.eps)
    }

    /// Computes gradient and hessian together, sharing a single sigmoid
    /// evaluation. Equivalent to calling `gradient` and `hessian` separately.
    #[inline]
    fn gradient_hessian(&self, target: f32, prediction: f32) -> (f32, f32) {
        let p = sigmoid(prediction);
        let g = p - target;
        let h = (p * (1.0 - p)).max(self.eps);
        (g, h)
    }

    /// Fills `gradients` and `hessians` element-wise for a batch.
    ///
    /// All four slices must have equal length (checked in debug builds).
    /// Uses zipped iterators instead of index loops so the hot loop carries
    /// no per-element bounds checks and can vectorize.
    fn compute_gradients(
        &self,
        targets: &[f32],
        predictions: &[f32],
        gradients: &mut [f32],
        hessians: &mut [f32],
    ) {
        debug_assert_eq!(targets.len(), predictions.len());
        debug_assert_eq!(targets.len(), gradients.len());
        debug_assert_eq!(targets.len(), hessians.len());
        let rows = targets
            .iter()
            .zip(predictions)
            .zip(gradients.iter_mut())
            .zip(hessians.iter_mut());
        for (((&target, &prediction), grad), hess) in rows {
            let p = sigmoid(prediction);
            *grad = p - target;
            *hess = (p * (1.0 - p)).max(self.eps);
        }
    }

    /// Best constant raw score for the batch: the log-odds of the positive
    /// rate, `ln(p / (1-p))`. Targets above 0.5 count as positive; the rate
    /// is clamped to `[eps, 1-eps]` so all-positive / all-negative batches
    /// stay finite. Returns 0.0 (probability 0.5) for an empty batch.
    fn initial_prediction(&self, targets: &[f32]) -> f32 {
        if targets.is_empty() {
            return 0.0;
        }
        let positive: f32 = targets.iter().filter(|&&t| t > 0.5).count() as f32;
        let total = targets.len() as f32;
        let p = (positive / total).clamp(self.eps, 1.0 - self.eps);
        (p / (1.0 - p)).ln()
    }

    /// Stable identifier for configuration / serialization.
    fn name(&self) -> &'static str {
        "binary_logloss"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_loss() {
        let ll = BinaryLogLoss::new();
        // Confident, correct scores cost almost nothing.
        assert!(ll.loss(1.0, 10.0) < 0.001);
        assert!(ll.loss(0.0, -10.0) < 0.001);
        // Confident, wrong scores are heavily penalized.
        assert!(ll.loss(1.0, -10.0) > 5.0);
        assert!(ll.loss(0.0, 10.0) > 5.0);
    }

    #[test]
    fn test_gradient() {
        let ll = BinaryLogLoss::new();
        // At a raw score of 0 the probability is 0.5, so g = 0.5 - target.
        assert!((ll.gradient(1.0, 0.0) - (-0.5)).abs() < 1e-6);
        assert!((ll.gradient(0.0, 0.0) - 0.5).abs() < 1e-6);
        // Saturated correct prediction: gradient vanishes.
        assert!(ll.gradient(1.0, 10.0).abs() < 0.001);
        // Saturated wrong prediction: gradient approaches 1.
        assert!((ll.gradient(0.0, 10.0) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_hessian() {
        let ll = BinaryLogLoss::new();
        // p(1-p) peaks at 0.25 when p = 0.5.
        assert!((ll.hessian(0.0, 0.0) - 0.25).abs() < 1e-6);
        // It shrinks toward zero as the sigmoid saturates...
        assert!(ll.hessian(0.0, 10.0) < 0.01);
        assert!(ll.hessian(0.0, -10.0) < 0.01);
        // ...but the eps floor keeps it strictly positive.
        assert!(ll.hessian(0.0, 100.0) > 0.0);
    }

    #[test]
    fn test_initial_prediction() {
        let ll = BinaryLogLoss::new();
        // Balanced batch -> log-odds of 0.5 -> 0.
        let balanced = vec![0.0, 1.0, 0.0, 1.0];
        assert!(ll.initial_prediction(&balanced).abs() < 1e-6);
        // 3/4 positive -> ln(0.75 / 0.25) = ln(3).
        let mostly_pos = vec![1.0, 1.0, 1.0, 0.0];
        assert!((ll.initial_prediction(&mostly_pos) - 3.0_f32.ln()).abs() < 1e-5);
        // 1/4 positive -> ln(0.25 / 0.75) = ln(1/3).
        let mostly_neg = vec![0.0, 0.0, 0.0, 1.0];
        assert!((ll.initial_prediction(&mostly_neg) - (1.0_f32 / 3.0).ln()).abs() < 1e-5);
    }

    #[test]
    fn test_to_probability() {
        let ll = BinaryLogLoss::new();
        assert!((ll.to_probability(0.0) - 0.5).abs() < 1e-6);
        assert!(ll.to_probability(10.0) > 0.99);
        assert!(ll.to_probability(-10.0) < 0.01);
    }

    #[test]
    fn test_to_class() {
        let ll = BinaryLogLoss::new();
        assert_eq!(ll.to_class(0.6, 0.5), 1);
        assert_eq!(ll.to_class(0.4, 0.5), 0);
        // The threshold itself maps to the positive class.
        assert_eq!(ll.to_class(0.5, 0.5), 1);
        assert_eq!(ll.to_class(0.6, 0.7), 0);
        assert_eq!(ll.to_class(0.8, 0.7), 1);
    }

    #[test]
    fn test_numerical_stability() {
        let ll = BinaryLogLoss::new();
        // Even wildly saturated scores must yield finite loss/grad and a
        // strictly positive hessian (thanks to the eps clamp/floor).
        for &pred in &[-1000.0_f32, -100.0, 100.0, 1000.0] {
            let l = ll.loss(0.0, pred);
            let g = ll.gradient(0.0, pred);
            let h = ll.hessian(0.0, pred);
            assert!(l.is_finite(), "Loss not finite for pred={}", pred);
            assert!(g.is_finite(), "Gradient not finite for pred={}", pred);
            assert!(h.is_finite(), "Hessian not finite for pred={}", pred);
            assert!(h > 0.0, "Hessian not positive for pred={}", pred);
        }
    }
}