/// Euler–Mascheroni constant γ (≈ 0.5772156649), rounded to `f32` precision.
///
/// Used by `bessel_side_length`, which shifts each side by `2 * γ * t_int`.
pub const EULER_GAMMA: f32 = 0.577_215_7;
// NOTE(review): not referenced anywhere in this file; judging by the name it is
// presumably a containment-probability cut-off used by the tensor backends — confirm.
/// Threshold that is only compiled when the `ndarray-backend` or `candle-backend`
/// feature is enabled.
#[cfg(any(feature = "ndarray-backend", feature = "candle-backend"))]
pub(crate) const BOUNDARY_CONTAINMENT_THRESHOLD: f32 = 0.99;
/// Numerically stable softplus with sharpness `beta`:
/// `softplus(x, beta) = ln(1 + exp(beta * x)) / beta`.
///
/// Larger `beta` makes the curve approach `max(x, 0)` more closely. `beta` is
/// expected to be positive; this is not checked.
///
/// The result is strictly positive until `exp(beta * x)` itself underflows
/// (`beta * x` below roughly -103 in `f32`). There is no artificial cut-off to
/// zero, so the function is continuous and monotone across its whole domain.
pub fn softplus(x: f32, beta: f32) -> f32 {
    let bx = beta * x;
    if bx > 20.0 {
        // ln(1 + e^bx) = bx + ln(1 + e^-bx). The second term is below 2.1e-9 here,
        // which is far below f32 resolution relative to bx, so return x as-is.
        x
    } else {
        // For bx <= 20, exp(bx) <= ~4.9e8 cannot overflow. `ln_1p` keeps full
        // relative precision when exp(bx) is tiny (very negative bx), so no
        // separate "return 0" branch is needed. Such a branch would introduce a
        // jump and discard the small but positive true value.
        bx.exp().ln_1p() / beta
    }
}
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// Evaluated as `max(a, b) + ln_1p(exp(-|a - b|))`. The exponential's argument
/// is never positive, so it cannot overflow, and `ln_1p` keeps full precision
/// when `a` and `b` are far apart.
///
/// Special values:
/// - any NaN input yields NaN;
/// - if either input is `+inf`, the result is `+inf`;
/// - if both inputs are `-inf`, the result is `-inf` (the log of zero).
pub fn stable_logsumexp(a: f32, b: f32) -> f32 {
    // `f32::max` ignores NaN, so check for NaN explicitly to keep it propagating.
    if a.is_nan() || b.is_nan() {
        return f32::NAN;
    }
    let m = a.max(b);
    // An infinite maximum fully determines the result. Returning it here also
    // avoids `inf - inf = NaN` in the difference below.
    if m.is_infinite() {
        return m;
    }
    m + (-(a - b).abs()).exp().ln_1p()
}
/// Temperature-scaled log-sum-exp of two coordinates:
/// `temperature * ln(exp(z_a / temperature) + exp(z_b / temperature))`.
///
/// Despite the name, this is a smooth *maximum*. For positive temperatures it
/// is never below `max(z_a, z_b)`, and it approaches that hard maximum as the
/// temperature goes to zero. The temperature is not clamped here, so a zero
/// temperature divides by zero.
pub fn gumbel_lse_min(z_a: f32, z_b: f32, temperature: f32) -> f32 {
    let (scaled_a, scaled_b) = (z_a / temperature, z_b / temperature);
    stable_logsumexp(scaled_a, scaled_b) * temperature
}
/// Smooth *minimum* of two coordinates at the given temperature.
///
/// Defined through the identity `min(a, b) = -max(-a, -b)` on top of
/// `gumbel_lse_min`, the smooth maximum. For positive temperatures the result
/// is never above `min(z_a, z_b)`.
pub fn gumbel_lse_max(z_a: f32, z_b: f32, temperature: f32) -> f32 {
    let (neg_a, neg_b) = (-z_a, -z_b);
    let smooth_max_of_negated = gumbel_lse_min(neg_a, neg_b, temperature);
    -smooth_max_of_negated
}
/// Soft side length of one box dimension from its lower coordinate `z` and
/// upper coordinate `big_z`.
///
/// Computes `softplus(big_z - z - 2 * EULER_GAMMA * t_int, 1 / t_vol)`, i.e.
/// `t_vol * ln(1 + exp((big_z - z - 2γ·t_int) / t_vol))`. The result is
/// non-negative, and it stays positive even when `big_z <= z` unless the
/// exponential underflows.
// NOTE(review): the name suggests the Bessel-approximate expected side length
// from Gumbel box embeddings (Dasgupta et al., 2020) — confirm against the paper.
pub fn bessel_side_length(z: f32, big_z: f32, t_int: f32, t_vol: f32) -> f32 {
    let beta = 1.0 / t_vol;
    let gap = big_z - z - 2.0 * EULER_GAMMA * t_int;
    softplus(gap, beta)
}
/// Log-volume and volume of a box whose sides are computed with `bessel_side_length`.
///
/// `mins[i]` and `maxs[i]` are the lower and upper coordinates of dimension `i`.
/// The log-volume is the sum over dimensions of `ln(side_i + 1e-13)`. The tiny
/// epsilon keeps a side that evaluates to zero from producing `-inf`; it
/// contributes about `-29.9` instead.
///
/// Returns `(log_volume, volume)` with `volume = exp(log_volume)`. The volume can
/// underflow to `0` or overflow to `inf` in high dimensions even when the
/// log-volume is finite. An empty box (zero dimensions) yields `(0.0, 1.0)`.
///
/// # Panics
///
/// In debug builds, panics if `mins.len() != maxs.len()`. In release builds the
/// extra coordinates of the longer slice are ignored.
pub fn bessel_log_volume(mins: &[f32], maxs: &[f32], t_int: f32, t_vol: f32) -> (f32, f32) {
    const EPS: f32 = 1e-13;
    // `zip` would silently drop unmatched coordinates and return the volume of a
    // lower-dimensional box. Surface that caller bug in debug builds.
    debug_assert_eq!(
        mins.len(),
        maxs.len(),
        "bessel_log_volume: mins and maxs must have the same number of dimensions"
    );
    let log_vol: f32 = mins
        .iter()
        .zip(maxs.iter())
        .map(|(&z, &big_z)| (bessel_side_length(z, big_z, t_int, t_vol) + EPS).ln())
        .sum();
    (log_vol, log_vol.exp())
}
/// Smallest temperature allowed by `clamp_temperature_default`.
pub const MIN_TEMPERATURE: f32 = 1e-3;
/// Largest temperature allowed by `clamp_temperature_default`.
pub const MAX_TEMPERATURE: f32 = 10.0;
/// Restricts `temp` to the closed range `[MIN_TEMPERATURE, MAX_TEMPERATURE]`.
///
/// A NaN temperature is returned unchanged, following `f32::clamp`.
pub(crate) fn clamp_temperature_default(temp: f32) -> f32 {
    let (lo, hi) = (MIN_TEMPERATURE, MAX_TEMPERATURE);
    f32::clamp(temp, lo, hi)
}
/// Logistic sigmoid `1 / (1 + exp(-x))`, evaluated without overflow.
///
/// Only `exp(-|x|)` is ever computed. That value lies in `[0, 1]`, so neither
/// branch can overflow or form `inf / inf`. Very negative inputs return `0.0`
/// once the exponential underflows, very positive inputs return `1.0`, and NaN
/// propagates.
pub fn stable_sigmoid(x: f32) -> f32 {
    let decay = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + decay)
    } else {
        // Here decay == exp(x), giving the algebraically equal form e^x / (1 + e^x).
        decay / (1.0 + decay)
    }
}
/// Soft membership of the point `x` in the interval `[min, max]`.
///
/// Each boundary is softened by a logistic sigmoid whose scale is `temp`,
/// clamped to `[MIN_TEMPERATURE, MAX_TEMPERATURE]`. The result is the product
/// `sigmoid((x - min) / T) * sigmoid((max - x) / T)`, which lies in `[0, 1]`.
pub fn gumbel_membership_prob(x: f32, min: f32, max: f32, temp: f32) -> f32 {
    let scale = clamp_temperature_default(temp);
    // Soft indicator that x sits above the lower bound.
    let above_lower = stable_sigmoid((x - min) / scale);
    // Soft indicator that x sits below the upper bound.
    let below_upper = stable_sigmoid((max - x) / scale);
    above_lower * below_upper
}
/// Transforms a uniform variate `u` into a standard Gumbel sample `-ln(-ln(u))`.
///
/// `u` is first clamped to `[epsilon, 1 - epsilon]` so that neither logarithm
/// sees `0`. That interval is further restricted to
/// `[f32::MIN_POSITIVE, largest f32 below 1.0]`. Without this, an `epsilon`
/// smaller than about `6e-8` (or zero/negative) lets `u` reach exactly `0.0` or
/// `1.0` (`1.0 - epsilon` rounds to `1.0` in `f32`), and the sample becomes
/// infinite. A NaN `u` propagates to NaN.
///
/// # Panics
///
/// Panics if `epsilon > 0.5`, because the clamp interval is then empty.
pub fn sample_gumbel(u: f32, epsilon: f32) -> f32 {
    // Largest f32 strictly below 1.0 (1 - 2^-24, exactly representable).
    const ONE_BELOW: f32 = 1.0 - f32::EPSILON / 2.0;
    // `max`/`min` ignore NaN, so a NaN epsilon falls back to these safe bounds.
    let lo = epsilon.max(f32::MIN_POSITIVE);
    let hi = (1.0 - epsilon).min(ONE_BELOW);
    let u_clamped = u.clamp(lo, hi);
    -(-u_clamped.ln()).ln()
}
/// Squashes a Gumbel sample into the interval between `min` and `max`.
///
/// With `T` the temperature clamped to `[MIN_TEMPERATURE, MAX_TEMPERATURE]`,
/// the result is `min + (max - min) * (tanh(gumbel / T) + 1) / 2`. A zero sample
/// lands on the midpoint. Because the interpolation weight lies in `[0, 1]`, the
/// result stays between `min` and `max`. NaN inputs propagate.
pub fn map_gumbel_to_bounds(gumbel: f32, min: f32, max: f32, temp: f32) -> f32 {
    let scale = clamp_temperature_default(temp);
    // tanh maps to [-1, 1]; the affine shift below maps that onto [0, 1].
    let unit = ((gumbel / scale).tanh() + 1.0) / 2.0;
    let span = max - min;
    min + span * unit.clamp(0.0, 1.0)
}
/// Computes a box volume in log space from an iterator of side lengths.
///
/// Returns `(log_volume, volume)`, where `log_volume` is the sum of the sides'
/// natural logarithms and `volume = exp(log_volume)`. Summing logs avoids the
/// underflow a direct product would hit in high dimensions.
///
/// A side at or below `1e-10` is treated as zero-width. The result is then
/// `(f32::NEG_INFINITY, 0.0)`, and the iterator is not consumed past that side.
/// An empty iterator yields `(0.0, 1.0)`. A NaN side makes both outputs NaN.
pub fn log_space_volume<I>(mut side_lengths: I) -> (f32, f32)
where
    I: Iterator<Item = f32>,
{
    const EPSILON: f32 = 1e-10;

    // `None` signals a degenerate side and short-circuits the fold. The check is
    // written as `<=` so that a NaN side is accumulated rather than treated as zero.
    let log_total = side_lengths.try_fold(0.0_f32, |acc, len| {
        if len <= EPSILON {
            None
        } else {
            Some(acc + len.ln())
        }
    });

    match log_total {
        Some(log_sum) => (log_sum, log_sum.exp()),
        None => (f32::NEG_INFINITY, 0.0),
    }
}
// Unit tests and `proptest` property tests for the scalar helpers in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clamp_temperature() {
        // Out-of-range temperatures snap to the nearest bound; in-range values pass through.
        assert_eq!(clamp_temperature_default(0.0), MIN_TEMPERATURE);
        assert_eq!(clamp_temperature_default(100.0), MAX_TEMPERATURE);
        assert_eq!(clamp_temperature_default(1.0), 1.0);
    }

    #[test]
    fn test_stable_sigmoid() {
        // Extreme inputs must stay within (0, 1] without overflowing; sigmoid(0) is 0.5.
        assert!(stable_sigmoid(-100.0) > 0.0);
        assert!(stable_sigmoid(100.0) <= 1.0);
        assert!(stable_sigmoid(100.0) > 0.99);
        assert!((stable_sigmoid(0.0) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_gumbel_membership_prob() {
        // A point inside the unit interval at unit temperature gets a valid probability.
        let prob = gumbel_membership_prob(0.5, 0.0, 1.0, 1.0);
        assert!(prob > 0.0 && prob <= 1.0);
        // A point well outside the interval at low temperature has small membership.
        let prob_out = gumbel_membership_prob(2.0, 0.0, 1.0, 0.001);
        assert!(prob_out < 0.5);
        // The requested 1e-4 is below MIN_TEMPERATURE and is clamped up to it. The
        // midpoint is still far enough from both bounds to be a near-certain member.
        let prob_hard = gumbel_membership_prob(0.5, 0.0, 1.0, 0.0001);
        assert!((prob_hard - 1.0).abs() < 1e-3);
    }

    #[test]
    fn test_sample_gumbel() {
        // The epsilon clamp should keep the double logarithm finite, including at u = 0 and u = 1.
        let g = sample_gumbel(0.5, 1e-7);
        assert!(g.is_finite());
        let g0 = sample_gumbel(0.0, 1e-7);
        assert!(g0.is_finite());
        let g1 = sample_gumbel(1.0, 1e-7);
        assert!(g1.is_finite());
    }

    #[test]
    fn test_map_gumbel_to_bounds() {
        // A zero Gumbel sample must land inside [min, max].
        let value = map_gumbel_to_bounds(0.0, 0.0, 1.0, 1.0);
        assert!((0.0..=1.0).contains(&value));
    }

    #[test]
    fn test_log_space_volume() {
        // 1 * 2 * 0.5 = 1, so the log-volume is 0 and the volume is 1.
        let side_lengths = [1.0, 2.0, 0.5];
        let (log_vol, vol) = log_space_volume(side_lengths.iter().copied());
        assert!((vol - 1.0).abs() < 1e-6);
        assert!((log_vol - 0.0).abs() < 1e-6);
        // Twenty sides of 0.1 (volume 1e-20): the log stays finite and the volume in (0, 1).
        let many_small = [0.1; 20];
        let (log_vol_hd, vol_hd) = log_space_volume(many_small.iter().copied());
        assert!(log_vol_hd.is_finite());
        assert!(vol_hd > 0.0);
        assert!(vol_hd < 1.0);
        // A zero-width side collapses the box to (-inf, 0).
        let with_zero = [1.0, 0.0, 2.0];
        let (log_vol_zero, vol_zero) = log_space_volume(with_zero.iter().copied());
        assert_eq!(log_vol_zero, f32::NEG_INFINITY);
        assert_eq!(vol_zero, 0.0);
    }

    #[test]
    fn test_softplus_basic() {
        // softplus(0) = ln 2; large inputs are ~identity; very negative inputs are ~0;
        // the output is never negative.
        assert!((softplus(0.0, 1.0) - 0.693).abs() < 0.01);
        assert!((softplus(100.0, 1.0) - 100.0).abs() < 0.01);
        assert!(softplus(-100.0, 1.0) < 1e-10);
        assert!(softplus(-5.0, 1.0) >= 0.0);
        assert!(softplus(0.0, 1.0) >= 0.0);
        assert!(softplus(5.0, 1.0) >= 0.0);
    }

    #[test]
    fn test_softplus_with_beta() {
        // Both sharp and soft settings stay positive; check the closed form
        // ln(1 + e^(beta * x)) / beta at x = 3, beta = 2.
        let sharp = softplus(0.5, 10.0);
        let soft = softplus(0.5, 0.1);
        assert!(sharp > 0.0);
        assert!(soft > 0.0);
        let expected = (1.0_f32 / 2.0) * (1.0 + (2.0 * 3.0_f32).exp()).ln();
        assert!((softplus(3.0, 2.0) - expected).abs() < 0.01);
    }

    #[test]
    fn test_stable_logsumexp() {
        // ln(e^0 + e^0) = ln 2; -inf acts as the identity; the result is >= max(a, b);
        // the function is symmetric; large and small magnitudes stay finite.
        assert!((stable_logsumexp(0.0, 0.0) - 0.693).abs() < 0.01);
        assert!((stable_logsumexp(5.0, f32::NEG_INFINITY) - 5.0).abs() < 1e-6);
        assert!(stable_logsumexp(3.0, 5.0) >= 5.0);
        assert!(stable_logsumexp(-1.0, -3.0) >= -1.0);
        assert!((stable_logsumexp(2.0, 7.0) - stable_logsumexp(7.0, 2.0)).abs() < 1e-6);
        assert!(stable_logsumexp(100.0, 100.0).is_finite());
        assert!(stable_logsumexp(-100.0, -100.0).is_finite());
    }

    #[test]
    fn test_gumbel_lse_min_is_smooth_max() {
        // Despite its name, gumbel_lse_min upper-bounds max(a, b) and approaches it
        // as the temperature goes to zero.
        assert!(gumbel_lse_min(1.0, 3.0, 1.0) >= 3.0);
        assert!(gumbel_lse_min(5.0, 2.0, 1.0) >= 5.0);
        let hard_approx = gumbel_lse_min(1.0, 3.0, 0.01);
        assert!((hard_approx - 3.0).abs() < 0.05, "got {hard_approx}");
    }

    #[test]
    fn test_gumbel_lse_max_is_smooth_min() {
        // Mirror image: gumbel_lse_max lower-bounds... i.e. stays below min(a, b) and
        // approaches it as the temperature goes to zero.
        assert!(gumbel_lse_max(5.0, 3.0, 1.0) <= 3.0);
        assert!(gumbel_lse_max(2.0, 7.0, 1.0) <= 2.0);
        let hard_approx = gumbel_lse_max(5.0, 3.0, 0.01);
        assert!((hard_approx - 3.0).abs() < 0.05, "got {hard_approx}");
    }

    #[test]
    fn test_gumbel_lse_identity() {
        // gumbel_lse_max(a, b, t) must equal -gumbel_lse_min(-a, -b, t).
        let a = 3.0;
        let b = 7.0;
        let t = 1.5;
        let via_max = gumbel_lse_max(a, b, t);
        let via_min = -gumbel_lse_min(-a, -b, t);
        assert!((via_max - via_min).abs() < 1e-5, "{via_max} vs {via_min}");
    }

    #[test]
    fn test_bessel_side_length_basic() {
        // At low temperatures the soft side of [0, 10] is close to the hard width 10.
        let sl = bessel_side_length(0.0, 10.0, 0.01, 0.01);
        assert!(
            (sl - 10.0).abs() < 0.1,
            "large box at low T should have sl ~ 10, got {sl}"
        );
        // A degenerate side (z == Z) at unit temperatures still has a strictly positive length.
        let sl_zero = bessel_side_length(5.0, 5.0, 1.0, 1.0);
        assert!(
            sl_zero > 0.0,
            "zero hard side should have positive Bessel side, got {sl_zero}"
        );
    }

    #[test]
    fn test_bessel_log_volume_basic() {
        // A 5 x 5 box at low temperature has a finite log-volume and a volume near 25.
        let (log_v, v) = bessel_log_volume(&[0.0, 0.0], &[5.0, 5.0], 0.01, 0.01);
        assert!(v > 0.0, "volume should be positive");
        assert!(log_v.is_finite(), "log volume should be finite");
        assert!(
            (v - 25.0).abs() < 2.0,
            "at low T, vol should be ~25, got {v}"
        );
    }

    #[test]
    fn test_euler_gamma_value() {
        // Sanity-check the hard-coded constant against 0.5772.
        assert!((EULER_GAMMA - 0.5772).abs() < 0.001);
    }

    // Randomized property tests. These need the external `proptest` crate
    // (presumably a dev-dependency — confirm in Cargo.toml).
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // softplus is non-decreasing in x for beta in [0.01, 10), with 1e-6 slack for rounding.
            #[test]
            fn prop_softplus_monotone(
                a in -50.0f32..50.0f32,
                delta in 0.0f32..50.0f32,
                beta in 0.01f32..10.0f32,
            ) {
                let b = a + delta;
                let sa = softplus(a, beta);
                let sb = softplus(b, beta);
                prop_assert!(
                    sb >= sa - 1e-6,
                    "softplus({a}, {beta})={sa} > softplus({b}, {beta})={sb}, violates monotonicity"
                );
            }
        }

        proptest! {
            // log-sum-exp never falls below the larger argument (1e-6 slack for rounding).
            #[test]
            fn prop_logsumexp_ge_max(
                a in -100.0f32..100.0f32,
                b in -100.0f32..100.0f32,
            ) {
                let lse = stable_logsumexp(a, b);
                let m = a.max(b);
                prop_assert!(
                    lse >= m - 1e-6,
                    "stable_logsumexp({a}, {b})={lse} < max={m}"
                );
            }
        }

        proptest! {
            // Membership is a finite probability in [0, 1] for any non-empty interval
            // and any temperature in [0.1, 10).
            #[test]
            fn prop_gumbel_membership_prob_bounds(
                x in -10.0f32..10.0f32,
                min_val in -10.0f32..10.0f32,
                width in 0.01f32..20.0f32,
                temp in 0.1f32..10.0f32,
            ) {
                let max_val = min_val + width;
                let p = gumbel_membership_prob(x, min_val, max_val, temp);
                prop_assert!(
                    (0.0..=1.0).contains(&p),
                    "gumbel_membership_prob({x}, {min_val}, {max_val}, {temp}) = {p} not in [0, 1]"
                );
                prop_assert!(p.is_finite(),
                    "gumbel_membership_prob must be finite, got {p}");
            }
        }
    }
}