use super::{clamp_float_finite, clamp_float_min};
use crate::utils::distributions::ArrayDistribution;
use once_cell::sync::OnceCell;
use tch::{Reduction, Tensor};
pub struct Bernoulli {
logits: Tensor,
probs: OnceCell<Tensor>,
}
impl Bernoulli {
#[must_use]
pub fn new(logits: Tensor) -> Self {
Self {
logits,
probs: OnceCell::new(),
}
}
}
impl ArrayDistribution<Tensor, Tensor> for Bernoulli {
fn batch_shape(&self) -> Vec<usize> {
self.logits
.size()
.iter()
.map(|&s| s.try_into().unwrap()) .collect()
}
fn element_shape(&self) -> Vec<usize> {
Vec::new()
}
fn sample(&self) -> Tensor {
self.probs
.get_or_init(|| self.logits.sigmoid())
.bernoulli()
.greater(0.5) }
fn log_probs(&self, elements: &Tensor) -> Tensor {
-clamp_float_finite(&self.logits)
.unwrap()
.binary_cross_entropy_with_logits::<&Tensor>(
&elements.to_kind(self.logits.kind()),
None,
None,
Reduction::None,
)
}
fn entropy(&self) -> Tensor {
clamp_float_finite(&self.logits)
.unwrap()
.binary_cross_entropy_with_logits::<&Tensor>(
self.probs.get_or_init(|| self.logits.sigmoid()),
None,
None,
Reduction::None,
)
}
fn kl_divergence_from(&self, other: &Self) -> Tensor {
let cross_entropy = clamp_float_min(&other.logits)
.unwrap()
.binary_cross_entropy_with_logits::<&Tensor>(
self.probs.get_or_init(|| self.logits.sigmoid()),
None,
None,
Reduction::None,
);
cross_entropy - self.entropy()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tch::{IndexOp, Kind};
#[test]
fn batch_shape_0d() {
let logits = Tensor::from(2.0);
let d = Bernoulli::new(logits);
assert_eq!(d.batch_shape(), [] as [usize; 0]);
}
#[test]
fn batch_shape_1d() {
let logits = Tensor::of_slice(&[-2.0, 0.0, 1.0]);
let d = Bernoulli::new(logits);
assert_eq!(d.batch_shape(), [3]);
}
#[test]
fn element_shape() {
let logits = Tensor::of_slice(&[-2.0, 0.0, 1.0]);
let d = Bernoulli::new(logits);
assert_eq!(d.element_shape(), [] as [usize; 0]);
}
#[test]
fn sample() {
let logits = Tensor::of_slice(&[f32::NEG_INFINITY, -1000.0, 0.0, 1000.0, f32::INFINITY]);
let d = Bernoulli::new(logits);
let samples = d.sample();
assert_eq!(samples.size(), [5]);
assert_eq!(samples.kind(), Kind::Bool);
let samples_vec: Vec<bool> = samples.into();
assert!(!samples_vec[0]);
assert!(!samples_vec[1]);
assert!(samples_vec[3]);
assert!(samples_vec[4]);
}
#[test]
fn log_probs() {
#[allow(clippy::cast_possible_truncation)]
fn log_sigmoid(logit: f64) -> f32 {
-(-logit).exp().ln_1p() as f32
}
let logits = Tensor::of_slice(&[f32::NEG_INFINITY, -2.0, 0.0, 1.0, 1.0, f32::INFINITY]);
let d = Bernoulli::new(logits);
let log_probs = d.log_probs(&Tensor::of_slice(&[true, true, true, true, false, true]));
assert_eq!(log_probs.size(), [6]);
let expected = [
f32::NEG_INFINITY,
log_sigmoid(-2.0),
(0.5f32).ln(),
log_sigmoid(1.0),
log_sigmoid(-1.0),
0.0,
];
assert!(
log_probs
.i(1..)
.allclose(&Tensor::of_slice(&expected[1..]), 1e-5, 1e-8, false),
"\nlog_probs: {log_probs:?}\nexpected: {expected:?}\n"
);
assert!(f32::from(log_probs.i(0)) <= f32::MIN / 2.0);
}
#[test]
fn entropies() {
#[allow(clippy::cast_possible_truncation)]
fn logit_entropy(logit: f64) -> f32 {
let p = (1.0 + (-logit).exp()).recip();
let h = -p * p.ln() - (1.0 - p) * (1.0 - p).ln();
h as f32
}
let logits = Tensor::of_slice(&[f32::NEG_INFINITY, -2.0, 0.0, 1.0, f32::INFINITY]);
let d = Bernoulli::new(logits);
let entropies = d.entropy();
assert_eq!(entropies.size(), [5]);
let expected = [
0.0,
logit_entropy(-2.0),
-(0.5f32).ln(),
logit_entropy(1.0),
0.0,
];
assert!(
entropies.allclose(&Tensor::of_slice(&expected), 1e-5, 1e-8, false),
"\nentropies: {entropies:?}\nexpected: {expected:?}\n"
);
}
#[test]
fn kl_divergence() {
#[allow(clippy::cast_possible_truncation)]
fn kl_div(lp: f64, lq: f64) -> f32 {
let p = (1.0 + (-lp).exp()).recip();
let q = (1.0 + (-lq).exp()).recip();
let kl = p * (p / q).ln() + (1.0 - p) * ((1.0 - p) / (1.0 - q)).ln();
kl as f32
}
let logit_pairs = [
(0.0, 0.0),
(1.0, 2.0),
(2.0, 1.0),
(0.0, f32::INFINITY),
(0.0f32, f32::NEG_INFINITY),
];
let p = Bernoulli::new(Tensor::of_slice(
&logit_pairs.iter().map(|(p, _)| *p).collect::<Vec<_>>(),
));
let q = Bernoulli::new(Tensor::of_slice(
&logit_pairs.iter().map(|(_, q)| *q).collect::<Vec<_>>(),
));
let kl_divs = p.kl_divergence_from(&q);
assert_eq!(kl_divs.size(), [logit_pairs.len() as i64]);
let expected = [
0.0,
kl_div(1.0, 2.0),
kl_div(2.0, 1.0),
f32::INFINITY,
f32::INFINITY,
];
assert!(
kl_divs
.i(..3)
.allclose(&Tensor::of_slice(&expected[..3]), 1e-5, 1e-8, false),
"\nkl_divs: {kl_divs:?}\nexpected: {expected:?}\n"
);
assert!(f32::from(kl_divs.i(3)) >= f32::MAX / 2.0);
assert!(f32::from(kl_divs.i(4)) >= f32::MAX / 2.0);
}
}