use super::clamp_float_min;
use crate::utils::distributions::ArrayDistribution;
use tch::{Kind, Tensor};
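/// One or more categorical distributions, parameterized by normalized log
/// probabilities over the last dimension.
///
/// A minimal usage sketch (ignored doctest; assumes `ArrayDistribution` is in
/// scope so that `sample` is available):
///
/// ```ignore
/// let logits = Tensor::of_slice(&[0.1_f32, 0.2, 0.7]);
/// let d = Categorical::new(&logits);
/// let index = d.sample(); // scalar tensor holding 0, 1, or 2
/// ```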
#[derive(Debug, PartialEq)]
pub struct Categorical {
log_probs: Tensor,
}
impl Categorical {
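    /// Creates categorical distributions from unnormalized log probabilities
    /// (logits). The inputs are normalized with `log_softmax` over the last
    /// dimension, so they do not need to sum to one.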
#[must_use]
pub fn new(unnormalized_log_probs: &Tensor) -> Self {
Self {
log_probs: unnormalized_log_probs.log_softmax(-1, Kind::Float),
}
}
}
impl ArrayDistribution<Tensor, Tensor> for Categorical {
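    // Every dimension except the last (category) dimension is a batch dimension.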
fn batch_shape(&self) -> Vec<usize> {
        self.log_probs
            .size()
            .split_last()
            .unwrap()
            .1
            .iter()
            .map(|&s| s.try_into().unwrap())
            .collect()
}
fn element_shape(&self) -> Vec<usize> {
Vec::new()
}
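    // Draws one category index per distribution. Note that `multinomial`
    // accepts only 1-D or 2-D probability tensors, so this supports at most
    // one batch dimension.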
fn sample(&self) -> Tensor {
self.log_probs.exp().multinomial(1, true).squeeze_dim(-1)
}
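    // Looks up the log probability of each given category index along the
    // last dimension.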
fn log_probs(&self, elements: &Tensor) -> Tensor {
self.log_probs
.gather(-1, &elements.unsqueeze(-1), false)
.squeeze_dim(-1)
}
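    // Entropy: -sum(p * log(p)). Clamping log(p) to a finite minimum makes the
    // p = 0 terms contribute 0 instead of the NaN produced by 0 * -inf.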
fn entropy(&self) -> Tensor {
let clamped_log_probs = clamp_float_min(&self.log_probs)
.map_err(|kind| format!("log_probs must be f32 or f64, not {:?}", kind))
.unwrap();
-(clamped_log_probs * self.log_probs.exp()).sum_dim_intlist(&[-1], false, Kind::Float)
}
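    // KL(self || other) = sum(p_self * (log(p_self) - log(p_other))).
    // The log-probability difference is clamped so that p_self = 0 terms
    // contribute 0 rather than NaN.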
fn kl_divergence_from(&self, other: &Self) -> Tensor {
let clamped_rel_log_probs = clamp_float_min(&(&self.log_probs - &other.log_probs))
.map_err(|kind| format!("log_probs must be f32 or f64, not {:?}", kind))
.unwrap();
(clamped_rel_log_probs * self.log_probs.exp()).sum_dim_intlist(&[-1], false, Kind::Float)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tch::IndexOp;
#[test]
fn batch_shape_0d() {
let log_probs = Tensor::of_slice(&[0.0_f32, 1.0, 0.0]).log();
let d = Categorical::new(&log_probs);
assert_eq!(d.batch_shape(), [] as [usize; 0]);
}
#[test]
fn batch_shape_1d() {
        let log_probs = Tensor::of_slice(&[
            0.0_f32, 1.0, 0.0, //
            0.2, 0.3, 0.5,
        ])
        .reshape(&[2, 3])
        .log();
let d = Categorical::new(&log_probs);
assert_eq!(d.batch_shape(), [2]);
}
#[test]
fn element_shape() {
        let log_probs = Tensor::of_slice(&[
            0.0_f32, 1.0, 0.0, //
            0.2, 0.3, 0.5,
        ])
        .reshape(&[2, 3])
        .log();
let d = Categorical::new(&log_probs);
assert_eq!(d.element_shape(), [] as [usize; 0]);
}
#[test]
fn sample() {
        let log_probs = Tensor::of_slice(&[
            1.0_f32, 0.0, 0.0, //
            0.0, 1.0, 0.0, //
            0.0, 0.0, 1.0, //
            0.3, 0.3, 0.4,
        ])
        .reshape(&[4, 3])
        .log();
let d = Categorical::new(&log_probs);
let samples = d.sample();
assert_eq!(samples.size(), [4]);
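        // The first three rows are deterministic; the last is random over {0, 1, 2}.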
assert_eq!(samples.i(..3), Tensor::of_slice(&[0_i64, 1, 2]));
assert!((0..3).contains(&i64::from(samples.i(3))));
}
#[test]
fn log_probs() {
        let log_probs = Tensor::of_slice(&[
            f32::NEG_INFINITY, 0.0, f32::NEG_INFINITY, //
            f32::NEG_INFINITY, 0.0, f32::NEG_INFINITY, //
            f32::NEG_INFINITY, 0.0, 0.0, //
            f32::NEG_INFINITY, 0.0, 0.0, //
            -1.0, 0.0, 1.0, //
            -1.0, 0.0, 1.0, //
            -1.0, 0.0, 1.0, //
            0.0, 0.0, 0.0,
        ])
        .reshape(&[-1, 3]);
let distribution = Categorical::new(&log_probs);
let elements = Tensor::of_slice(&[1_i64, 0, 2, 0, 0, 1, 2, 0]);
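        // Normalizer for the [-1, 0, 1] rows: ln(e^-1 + e^0 + e^1).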
let log_normalizer = f32::ln(f32::exp(-1.0) + 1.0 + f32::exp(1.0));
let expected = Tensor::of_slice(&[
0.0,
f32::NEG_INFINITY,
-f32::ln(2.0),
f32::NEG_INFINITY,
-1.0 - log_normalizer,
-log_normalizer,
1.0 - log_normalizer,
f32::ln(3.0_f32.recip()),
]);
let actual = distribution.log_probs(&elements);
assert!(
Into::<bool>::into(expected.isclose(&actual, 1e-6, 1e-6, false).all()),
"expected: {:?}\nactual: {:?}",
expected,
actual
);
}
#[test]
fn entropies() {
        let log_probs = Tensor::of_slice(&[
            f32::NEG_INFINITY, 0.0, f32::NEG_INFINITY, //
            f32::NEG_INFINITY, 0.0, 0.0, //
            0.0, 0.0, 0.0, //
            0.1_f32.ln(), 0.3_f32.ln(), 0.6_f32.ln(),
        ])
        .reshape(&[-1, 3]);
let distribution = Categorical::new(&log_probs);
let actual = distribution.entropy();
let expected = Tensor::of_slice(&[
0.0,
-(0.5_f32.ln()),
-(3.0_f32.recip().ln()),
-0.1 * 0.1_f32.ln() - 0.3 * 0.3_f32.ln() - 0.6 * 0.6_f32.ln(),
]);
assert!(
Into::<bool>::into(expected.isclose(&actual, 1e-6, 1e-6, false).all()),
"expected: {:?}\nactual: {:?}",
expected,
actual
);
}
#[test]
fn kl_divergence() {
        let log_probs_a = Tensor::of_slice(&[
            0.2_f32, 0.3, 0.5, //
            0.2, 0.3, 0.5, //
            0.0, 1.0, 0.0,
        ])
        .reshape(&[3, 3])
        .log();
let distribution_a = Categorical::new(&log_probs_a);
        let log_probs_b = Tensor::of_slice(&[
            0.2_f32, 0.3, 0.5, //
            0.7, 0.2, 0.1, //
            0.2, 0.3, 0.5,
        ])
        .reshape(&[3, 3])
        .log();
let distribution_b = Categorical::new(&log_probs_b);
let actual = distribution_a.kl_divergence_from(&distribution_b);
let expected = Tensor::of_slice(&[
0.0_f32,
0.2 * (0.2_f32 / 0.7).ln() + 0.3 * (0.3_f32 / 0.2).ln() + 0.5 * (0.5_f32 / 0.1).ln(),
(1.0_f32 / 0.3).ln(),
]);
assert!(
Into::<bool>::into(expected.isclose(&actual, 1e-6, 1e-6, false).all()),
"expected: {:?}\nactual: {:?}",
expected,
actual
);
}
}