gen-rs 0.2.3

an experimental library for probabilistic programming in Rust.
Documentation
use std::rc::Weak;
use nalgebra::DMatrix;
use gen_rs::{GLOBAL_RNG, Distribution, mvnormal, GenFn, Trace, GfDiff};
use super::model::PointedBuffer;
use super::types_2d::{Point,Bounds};


/// Proposal generative function that draws a new latent point from a
/// multivariate normal centred on the latent point stored in a previous trace.
pub struct DriftProposal {
    /// Covariance matrix handed to `mvnormal` as the second distribution
    /// parameter whenever a new latent is sampled or scored.
    pub drift_cov: DMatrix<f64>
}

/// Arguments accepted by [`DriftProposal`]: a weak reference to the previous
/// model trace (upgraded on every call to read its latent and remaining data)
/// paired with a unit placeholder.
pub type DriftProposalArgs = (Weak<Trace<Bounds,PointedBuffer,Point>>,());

impl GenFn<DriftProposalArgs,PointedBuffer,()> for DriftProposal {

    fn simulate(&self, args: DriftProposalArgs) -> Trace<DriftProposalArgs,PointedBuffer,()> {
        let prev_trace = args.0.upgrade().unwrap();
        let mut choices = (None, prev_trace.data.1.clone());

        GLOBAL_RNG.with_borrow_mut(|rng| {
            let new_latent = mvnormal.random(rng, (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()));
            choices.0 = Some(new_latent);
        });
        let logp = mvnormal.logpdf(&choices.0.clone().unwrap(), (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()));

        Trace::new(args, choices, (), logp)
    }

    fn generate(&self, args: DriftProposalArgs, constraints: PointedBuffer) -> (Trace<DriftProposalArgs,PointedBuffer,()>, f64) {
        let prev_trace = args.0.upgrade().unwrap();
        let mut choices = (None, prev_trace.data.1.clone());

        let new_latent: Point;
        let logp: f64;
        let mut weight = 0.;
        match constraints.0 {
            Some(latent_constraint) => {
                new_latent = latent_constraint;
                logp = mvnormal.logpdf(&new_latent, (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()));
                weight = logp;
            }
            None => {
                new_latent = GLOBAL_RNG.with_borrow_mut(|rng| {
                    mvnormal.random(rng, (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()))
                });
                logp = mvnormal.logpdf(&new_latent, (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()));
            }
        }
        choices.0 = Some(new_latent);

        (Trace::new(args, choices, (), logp), weight)
    }

    fn update(&self, _: Trace<DriftProposalArgs,PointedBuffer,()>, _: DriftProposalArgs, _: GfDiff, _: PointedBuffer) -> (Trace<DriftProposalArgs,PointedBuffer,()>, PointedBuffer, f64) {
        panic!("not implemented")
    }
}