use core::f64::consts::LN_2;
use logp::{
cross_entropy_nats, entropy_nats, entropy_unchecked, jensen_shannon_divergence, kl_divergence,
kl_divergence_gaussians, mutual_information, mutual_information_ksg, Error, KsgVariant,
};
const TOL: f64 = 1e-9;
#[test]
fn entropy_empty_is_error() {
assert!(matches!(entropy_nats(&[], TOL), Err(Error::Empty)));
}
#[test]
fn kl_empty_is_error() {
assert!(matches!(kl_divergence(&[], &[], TOL), Err(Error::Empty)));
}
#[test]
fn cross_entropy_empty_is_error() {
assert!(matches!(
cross_entropy_nats(&[], &[], TOL),
Err(Error::Empty)
));
}
#[test]
fn entropy_with_zero_entry_is_exactly_zero() {
let h = entropy_nats(&[0.0, 1.0], TOL).unwrap();
assert_eq!(h, 0.0);
}
#[test]
fn kl_skips_zero_p_entries() {
let kl = kl_divergence(&[0.0, 1.0], &[0.5, 0.5], TOL).unwrap();
assert!((kl - LN_2).abs() < 1e-15, "kl={kl}");
}
#[test]
fn cross_entropy_skips_zero_p_entries() {
let h = cross_entropy_nats(&[0.0, 1.0], &[0.5, 0.5], TOL).unwrap();
assert!((h - LN_2).abs() < 1e-15, "h={h}");
let h = cross_entropy_nats(&[0.0, 1.0], &[0.0, 1.0], TOL).unwrap();
assert_eq!(h, 0.0);
}
#[test]
fn kl_zero_q_on_p_support_is_domain_error() {
assert!(matches!(
kl_divergence(&[0.5, 0.5], &[1.0, 0.0], TOL),
Err(Error::Domain(_))
));
}
#[test]
fn js_with_disjoint_supports_is_exactly_ln2() {
let js = jensen_shannon_divergence(&[1.0, 0.0], &[0.0, 1.0], TOL).unwrap();
assert!((js - LN_2).abs() < 1e-12, "js={js}");
}
#[test]
fn entropy_unnormalized_is_not_normalized_error() {
match entropy_nats(&[0.3, 0.3], TOL) {
Err(Error::NotNormalized { sum }) => assert!((sum - 0.6).abs() < 1e-12, "sum={sum}"),
other => panic!("expected NotNormalized, got {other:?}"),
}
}
#[test]
fn entropy_negative_entry_reports_index_and_value() {
match entropy_nats(&[1.5, -0.5], TOL) {
Err(Error::Negative { idx, value }) => {
assert_eq!(idx, 1);
assert!((value - (-0.5)).abs() < 1e-15, "value={value}");
}
other => panic!("expected Negative, got {other:?}"),
}
}
#[test]
fn entropy_nan_entry_is_non_finite_error() {
assert!(matches!(
entropy_nats(&[f64::NAN, 1.0], TOL),
Err(Error::NonFinite { idx: 0, .. })
));
}
#[test]
fn kl_length_mismatch_reports_lengths() {
assert!(matches!(
kl_divergence(&[0.5, 0.5], &[0.3, 0.3, 0.4], TOL),
Err(Error::LengthMismatch(2, 3))
));
}
#[test]
fn entropy_unchecked_empty_is_zero() {
assert_eq!(entropy_unchecked(&[]), 0.0);
}
#[test]
fn entropy_unchecked_unnormalized_goes_negative() {
let h = entropy_unchecked(&[2.0]);
let expected = -(2.0 * 2.0_f64.ln());
assert!((h - expected).abs() < 1e-15, "h={h}");
}
#[test]
fn kl_gaussians_empty_is_zero() {
let kl = kl_divergence_gaussians(&[], &[], &[], &[]).unwrap();
assert_eq!(kl, 0.0);
}
#[test]
fn kl_gaussians_zero_std_is_domain_error() {
assert!(matches!(
kl_divergence_gaussians(&[0.0], &[0.0], &[0.0], &[1.0]),
Err(Error::Domain(_))
));
}
#[test]
fn mutual_information_shape_mismatch() {
let p_xy = [0.25, 0.25, 0.25, 0.25];
assert!(matches!(
mutual_information(&p_xy, 2, 3, TOL),
Err(Error::LengthMismatch(4, 6))
));
assert!(matches!(
mutual_information(&p_xy, 0, 4, TOL),
Err(Error::Domain(_))
));
}
#[test]
fn ksg_rejects_empty_input_and_bad_k() {
let empty: Vec<Vec<f64>> = vec![];
assert!(matches!(
mutual_information_ksg(&empty, &empty, 3, KsgVariant::Alg1),
Err(Error::Domain(_))
));
let x: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64]).collect();
let y = x.clone();
assert!(matches!(
mutual_information_ksg(&x, &y, 0, KsgVariant::Alg1),
Err(Error::Domain(_))
));
assert!(matches!(
mutual_information_ksg(&x, &y, 5, KsgVariant::Alg2),
Err(Error::Domain(_))
));
assert!(matches!(
mutual_information_ksg(&x, &y[..4], 2, KsgVariant::Alg1),
Err(Error::LengthMismatch(5, 4))
));
}
#[test]
fn cross_entropy_rejects_length_mismatch() {
let p = [0.5, 0.5];
let q = [0.3, 0.3, 0.4];
assert!(matches!(
logp::cross_entropy_nats(&p, &q, 1e-9),
Err(logp::Error::LengthMismatch(2, 3))
));
}
#[test]
fn renyi_tsallis_divergence_reject_nonpositive_alpha() {
let p = [0.5, 0.5];
let q = [0.4, 0.6];
for alpha in [0.0, -1.0, f64::NAN] {
assert!(matches!(
logp::renyi_divergence(&p, &q, alpha, 1e-9),
Err(logp::Error::InvalidAlpha { .. })
));
assert!(matches!(
logp::tsallis_divergence(&p, &q, alpha, 1e-9),
Err(logp::Error::InvalidAlpha { .. })
));
}
}