use ndarray::ArrayView1;
use crate::util;
use crate::vec_simd::dot;
/// Activation magnitude beyond which `logistic_function` skips the exponential
/// and returns exactly `1.0` (activation strictly above this value) or exactly
/// `0.0` (activation strictly below its negation).
const LOGISTIC_ROUND_ACTIVATION: f32 = 10.0;
/// Logistic loss for a single `(u, v)` pair with a binary `label`.
///
/// The activation is the dot product of `u` and `v` (computed with
/// `vec_simd::dot`). It is squashed by `logistic_function` into a
/// probability `p`.
///
/// Returns `(loss, grad)` where:
/// * `loss` is `-safe_ln(p)` for a `true` label and `-safe_ln(1 - p)` for a
///   `false` label;
/// * `grad` is `target - p`, with the label read as `1.0` / `0.0`. Where
///   `safe_ln` behaves like `ln`, this is the *negative* derivative of the
///   loss with respect to the dot product, i.e. it points in the direction
///   that decreases the loss.
///
/// NOTE(review): `util::safe_ln` is defined elsewhere. Presumably it guards
/// against `ln(0)` when `p` saturates to exactly `0.0` or `1.0` — confirm.
pub fn log_logistic_loss(u: ArrayView1<f32>, v: ArrayView1<f32>, label: bool) -> (f32, f32) {
    let prob = logistic_function(dot(u, v));

    // Numeric target for the label, and the probability the model assigns to that label.
    let (target, prob_of_label): (f32, f32) = if label {
        (1.0, prob)
    } else {
        (0.0, 1.0 - prob)
    };

    let loss = -util::safe_ln(prob_of_label);
    (loss, target - prob)
}
/// Logistic (sigmoid) function `1 / (1 + e^(-a))`.
///
/// When `a` is strictly greater than `LOGISTIC_ROUND_ACTIVATION`, the result
/// is snapped to exactly `1.0`. When `a` is strictly less than its negation,
/// the result is snapped to exactly `0.0`. In both cases the exponential is
/// not evaluated.
///
/// A NaN input fails both comparisons and propagates as NaN.
fn logistic_function(a: f32) -> f32 {
    match a {
        x if x > LOGISTIC_ROUND_ACTIVATION => 1.0,
        x if x < -LOGISTIC_ROUND_ACTIVATION => 0.0,
        x => {
            let denominator = 1.0 + (-x).exp();
            1.0 / denominator
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{log_logistic_loss, logistic_function};
    use crate::util::{all_close, close};
    use ndarray::Array1;

    #[test]
    fn logistic_function_test() {
        // (activation, expected sigmoid). The ±11 entries exceed the rounding
        // threshold, so they should come back as exactly 0.0 / 1.0.
        let table: &[(f32, f32)] = &[
            (-11.0, 0.0),
            (-5.0, 0.00669),
            (-4.0, 0.01799),
            (-3.0, 0.04743),
            (-2.0, 0.11920),
            (-1.0, 0.26894),
            (0.0, 0.5),
            (1.0, 0.73106),
            (2.0, 0.88080),
            (3.0, 0.95257),
            (4.0, 0.982014),
            (5.0, 0.99331),
            (11.0, 1.0),
        ];
        let (activations, expected): (Vec<f32>, Vec<f32>) = table.iter().copied().unzip();
        let outputs: Vec<f32> = activations.iter().map(|&x| logistic_function(x)).collect();
        assert!(all_close(expected.as_slice(), outputs.as_slice(), 1e-5));
    }

    #[test]
    fn log_logistic_loss_test() {
        let vec6 = |xs: [f32; 6]| Array1::from_shape_vec((6,), xs.to_vec()).unwrap();
        let a = vec6([1., 1., 1., 0., 0., 0.]);
        let a_orth = vec6([0., 0., 0., 1., 1., 1.]);
        let a_opp = vec6([-1., -1., -1., 0., 0., 0.]);

        // (second vector, label, expected loss, expected gradient); the first vector is always `a`.
        let cases = [
            (&a_orth, true, 0.69312, 0.5),
            (&a_orth, false, 0.69312, -0.5),
            (&a, true, 0.04858, 0.04742),
            (&a_opp, false, 0.04858, -0.04743),
            (&a_opp, true, 3.04838, 0.95257),
        ];
        for &(other, label, want_loss, want_grad) in cases.iter() {
            let (loss, gradient) = log_logistic_loss(a.view(), other.view(), label);
            assert!(close(loss, want_loss, 1e-5));
            assert!(close(gradient, want_grad, 1e-5));
        }
    }
}