// reinforcex 0.0.4
//
// Deep Reinforcement Learning Framework
use super::base_distribution::BaseDistribution;
use tch::{Kind, Tensor};

/// Gaussian distribution described by a mean tensor and a variance tensor.
///
/// `var` stores variances rather than standard deviations: `sample` takes its
/// square root before scaling the noise, and `log_prob` divides the squared
/// deviation by it directly.
pub struct GaussianDistribution {
    /// Element-wise mean.
    mean: Tensor,
    /// Element-wise variance; `new` asserts it has the same size as `mean`.
    var: Tensor,
}

// SAFETY: not established in this file. These impls assert that a
// `GaussianDistribution` (two `Tensor` fields) may be shared by reference
// across threads and moved between threads.
// NOTE(review): confirm that concurrent `&self` access to the wrapped
// `tch::Tensor` values is actually sound before relying on `Sync`.
unsafe impl Sync for GaussianDistribution {}
unsafe impl Send for GaussianDistribution {}

impl GaussianDistribution {
    pub fn new(mean: Tensor, var: Tensor) -> Self {
        assert_eq!(mean.size(), var.size(), "mean and var must have same shape");
        GaussianDistribution { mean, var }
    }
}

/// [`BaseDistribution`] implementation for a Gaussian parameterised by
/// element-wise `mean` and `var` tensors.
///
/// Every reduction in this impl sums over the last axis (`-1`) without keeping
/// it and requests `Kind::Float` as the dtype of the result.
impl BaseDistribution for GaussianDistribution {
    /// Returns borrowed references to `(mean, var)`.
    fn params(&self) -> (&Tensor, &Tensor) {
        (&self.mean, &self.var)
    }

    /// Computes the KL divergence from `self` to `q`, summed over the last axis.
    ///
    /// Per element this evaluates
    /// `0.5 * (ln(var_q / var_p) + (var_p + (mean_p - mean_q)^2) / var_q - 1)`
    /// where `p` is `self`.
    ///
    /// Assumes `q.params()` yields the `(mean, var)` of another Gaussian; the
    /// concrete type of `q` is not checked here.
    fn kl(&self, q: &Box<dyn BaseDistribution>) -> Tensor {
        let (q_mean, q_var) = q.params();
        // Squared difference of the means.
        let mean_diff = (&self.mean - q_mean).pow_tensor_scalar(2.0);
        // ln(var_q) - ln(var_p)
        let term1 = q_var.log() - &self.var.log();
        // (var_p + (mean_p - mean_q)^2) / var_q
        let term2 = (&self.var + mean_diff) / q_var;
        0.5 * (term1 + term2 - 1.0).sum_dim_intlist([-1].as_ref(), false, Kind::Float)
    }

    /// Computes the differential entropy, summed over the last axis:
    /// `0.5 * D * (ln(2π) + 1) + 0.5 * Σ ln(var)`.
    ///
    /// `D` is the size of the last axis of `mean`, i.e. the same axis the
    /// log-variance sum reduces over.
    fn entropy(&self) -> Tensor {
        // Read the dimensionality from the axis that is actually reduced (-1)
        // instead of a hard-coded axis 1, so 1-D parameters do not panic and
        // rank > 2 parameters use the right axis. A 0-d tensor counts as D = 1.
        let dim = self.mean.size().last().copied().unwrap_or(1);
        let log_term = 0.5 * ((2.0 * std::f64::consts::PI).ln() + 1.0);
        log_term * dim as f64
            + 0.5
                * self
                    .var
                    .log()
                    .sum_dim_intlist([-1].as_ref(), false, Kind::Float)
    }

    /// Draws `mean + sqrt(var) * noise`, with `noise` from `Tensor::randn_like`.
    ///
    /// The returned tensor is detached.
    fn sample(&self) -> Tensor {
        let std = self.var.sqrt();
        let noise = Tensor::randn_like(&self.mean);
        (&self.mean + &std * noise).detach()
    }

    /// Returns `exp(log_prob(x))`.
    fn prob(&self, x: &Tensor) -> Tensor {
        self.log_prob(x).exp()
    }

    /// Computes the log-density of `x`, summed over the last axis.
    ///
    /// Per element this evaluates
    /// `-0.5 * (ln(2π) + ln(var) + (x - mean)^2 / var)`.
    fn log_prob(&self, x: &Tensor) -> Tensor {
        // Squared deviation from the mean.
        let diff = (x - &self.mean).pow_tensor_scalar(2.0);
        let log_prob_each_dim: Tensor =
            -0.5 * ((2.0 * std::f64::consts::PI).ln() + &self.var.log() + &diff / &self.var);
        log_prob_each_dim.sum_dim_intlist([-1].as_ref(), false, Kind::Float)
    }

    /// Returns a boxed distribution built from detached shallow clones of
    /// `mean` and `var`.
    fn copy(&self) -> Box<dyn BaseDistribution> {
        Box::new(Self::new(
            self.mean.shallow_clone().detach(),
            self.var.shallow_clone().detach(),
        ))
    }

    /// Returns a shallow clone of `mean`.
    fn most_probable(&self) -> Tensor {
        self.mean.shallow_clone()
    }

    /// Replaces `mean` and `var` with their detached counterparts.
    fn detach(&mut self) {
        self.mean = self.mean.detach();
        self.var = self.var.detach();
    }

    /// Overwrites `mean` and `var` with the parameters of `others`,
    /// concatenated along axis 0.
    ///
    /// NOTE(review): the parameters held by `self` before the call are not part
    /// of the concatenation — confirm callers expect that.
    /// NOTE(review): an empty `others` hands `Tensor::cat` an empty list, which
    /// presumably panics — confirm.
    fn concat(&mut self, others: Vec<Box<dyn BaseDistribution>>) {
        let means = others
            .iter()
            .map(|d| d.params().0.shallow_clone())
            .collect::<Vec<Tensor>>();
        let vars = others
            .iter()
            .map(|d| d.params().1.shallow_clone())
            .collect::<Vec<Tensor>>();
        self.mean = Tensor::cat(&means, 0);
        self.var = Tensor::cat(&vars, 0);
    }

    /// Returns `Tensor::new()`; nothing is computed from `mean` or `var`.
    fn all_prob(&self) -> Tensor {
        Tensor::new()
    }

    /// Returns `Tensor::new()`; nothing is computed from `mean` or `var`.
    fn all_log_prob(&self) -> Tensor {
        Tensor::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tch::Tensor;

    /// Builds a `[1, n]` tensor whose single row holds `values`.
    fn row(values: &[f64]) -> Tensor {
        Tensor::from_slice(values).view([1, values.len() as i64])
    }

    /// `params` hands back the tensors given to `new`.
    #[test]
    fn test_new_and_params() {
        let mean = row(&[0.0, 0.0, 0.0]);
        let var = row(&[1.0, 1.0, 1.0]);
        let gaussian = GaussianDistribution::new(mean.shallow_clone(), var.shallow_clone());

        let (mean_out, var_out) = gaussian.params();
        assert_eq!(mean_out, &mean);
        assert_eq!(var_out, &var);
    }

    /// The most probable value equals the mean.
    #[test]
    fn test_most_probable() {
        let mean = row(&[0.0, 1.0]);
        let gaussian = GaussianDistribution::new(mean.shallow_clone(), row(&[1.0, 4.0]));

        assert_eq!(gaussian.most_probable(), mean);
    }

    /// A sample has the same shape as the parameters.
    #[test]
    fn test_sample() {
        let gaussian = GaussianDistribution::new(row(&[0.0, 1.0]), row(&[1.0, 4.0]));

        assert_eq!(gaussian.sample().size(), vec![1, 2]);
    }

    /// Log-density of three points under a standard normal, summed.
    #[test]
    fn test_log_prob() {
        let gaussian = GaussianDistribution::new(row(&[0.0]), row(&[1.0]));

        let log_prob = gaussian.log_prob(&row(&[1.0, 2.0, 3.0]));
        let expected_log_prob: f64 = -9.7568156;
        assert!((log_prob.double_value(&[]) - expected_log_prob).abs() < 1e-6);
    }

    /// KL divergence between two one-dimensional Gaussians.
    #[test]
    fn test_kl_divergence() {
        let gaussian_p = GaussianDistribution::new(row(&[0.0]), row(&[1.0]));
        let gaussian_q: Box<dyn BaseDistribution> =
            Box::new(GaussianDistribution::new(row(&[0.8]), row(&[1.5])));

        let kl_div = gaussian_p.kl(&gaussian_q);
        let expected_kl: f64 = 0.249399221;
        assert!((kl_div.double_value(&[]) - expected_kl).abs() < 1e-6);
    }

    /// Entropy of a unit-variance Gaussian scales with the number of dimensions.
    #[test]
    fn test_entropy() {
        for n in [1usize, 2] {
            let gaussian = GaussianDistribution::new(row(&vec![0.0; n]), row(&vec![1.0; n]));

            let entropy = gaussian.entropy();
            let expected_entropy: f64 = 1.418938533 * n as f64;
            assert!((entropy.double_value(&[]) - expected_entropy).abs() < 1e-6);
        }
    }
}