use crate::{
ConditionDifferentiableDistribution, DependentJoint, Distribution, DistributionError, Event,
IndependentJoint, RandomVariable, SampleableDistribution, ValueDifferentiableDistribution,
ValuedDistribution,
};
use opensrdk_linear_algebra::{Matrix, MatrixError, Vector};
use rand::prelude::*;
use std::{
fmt::Debug,
marker::PhantomData,
ops::{BitAnd, Mul},
};
/// A `ValuedDistribution` bundled with a closure that yields a `Matrix` for a
/// given `T2` value.
///
/// The extra closure is what `ln_diff_value` combines with the wrapped
/// distribution's own log-density gradient.
///
/// # Type parameters
///
/// - `D`: the wrapped distribution, with value type `T1` and condition type `U`.
/// - `T1`: value type of the wrapped distribution.
/// - `T2`: value type exposed by this wrapper.
/// - `U`: condition type shared by the wrapper and the wrapped distribution.
/// - `F`: fallible mapping `&T2 -> T1` owned by the inner `ValuedDistribution`.
/// - `G`: mapping `&T2 -> Matrix`. Presumably the Jacobian of `F`; the expected
///   shape and orientation are not visible here — TODO confirm against callers.
#[derive(Clone)]
pub struct ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
    D: Distribution<Value = T1, Condition = U>,
    T1: RandomVariable,
    T2: RandomVariable,
    U: Event,
    F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
    G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
    // The wrapped distribution together with its `&T2 -> T1` value mapping.
    valued_distribution: ValuedDistribution<D, T1, T2, U, F>,
    // Closure evaluated at the `T2` point when differentiating by value.
    value_diff: G,
    // Zero-sized marker for `T2`.
    phantom: PhantomData<T2>,
}
impl<D, T1, T2, U, F, G> ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
D: Distribution<Value = T1, Condition = U>,
T1: RandomVariable,
T2: RandomVariable,
U: Event,
F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
pub fn new(valued_distribution: ValuedDistribution<D, T1, T2, U, F>, value_diff: G) -> Self {
Self {
valued_distribution,
value_diff,
phantom: PhantomData,
}
}
}
impl<D, T1, T2, U, F, G> Debug for ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
    D: Distribution<Value = T1, Condition = U>,
    T1: RandomVariable,
    T2: RandomVariable,
    U: Event,
    F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
    G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
    /// Formats the wrapper as its type name followed by the wrapped
    /// distribution.
    ///
    /// Only the innermost distribution is printed; the two closures are not
    /// part of the output. The inner distribution is always rendered with the
    /// pretty (`{:#?}`) form, regardless of the flags on `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Label the output with this type's own name so it is distinguishable
        // from a plain `ValuedDistribution` in logs.
        write!(
            f,
            "ValueDifferentiableValuedDistribution {{ distribution: {:#?} }}",
            self.valued_distribution.distribution
        )
    }
}
impl<D, T1, T2, U, F, G> Distribution for ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
    D: Distribution<Value = T1, Condition = U>,
    T1: RandomVariable,
    T2: RandomVariable,
    U: Event,
    F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
    G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
    type Value = T2;
    type Condition = U;

    /// Evaluates the wrapped distribution's `p_kernel` at the image of `x`
    /// under the inner value mapping.
    ///
    /// # Errors
    ///
    /// Returns the error from the value mapping if it fails for `x`, or the
    /// error from the wrapped distribution's `p_kernel`.
    fn p_kernel(
        &self,
        x: &Self::Value,
        theta: &Self::Condition,
    ) -> Result<f64, crate::DistributionError> {
        let inner = &self.valued_distribution;
        // Map the outer value to the wrapped distribution's value type first.
        let mapped = (inner.value)(x)?;
        inner.distribution.p_kernel(&mapped, theta)
    }
}
impl<D, T1, T2, U, Rhs, TRhs, F, G> Mul<Rhs>
    for ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
    D: Distribution<Value = T1, Condition = U>,
    T1: RandomVariable,
    T2: RandomVariable,
    U: Event,
    Rhs: Distribution<Value = TRhs, Condition = U>,
    TRhs: RandomVariable,
    F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
    G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
    type Output = IndependentJoint<Self, Rhs, T2, TRhs, U>;

    /// `self * rhs` pairs this distribution with another distribution that
    /// shares the same condition type `U`, producing an `IndependentJoint`.
    fn mul(self, rhs: Rhs) -> IndependentJoint<Self, Rhs, T2, TRhs, U> {
        IndependentJoint::new(self, rhs)
    }
}
impl<D, T1, T2, U, Rhs, URhs, F, G> BitAnd<Rhs>
    for ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
    D: Distribution<Value = T1, Condition = U>,
    T1: RandomVariable,
    T2: RandomVariable,
    U: Event,
    Rhs: Distribution<Value = U, Condition = URhs>,
    URhs: RandomVariable,
    F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
    G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
    type Output = DependentJoint<Self, Rhs, T2, U, URhs>;

    /// `self & rhs` pairs this distribution with a distribution whose value
    /// type is this distribution's condition type `U`, producing a
    /// `DependentJoint`.
    fn bitand(self, rhs: Rhs) -> DependentJoint<Self, Rhs, T2, U, URhs> {
        DependentJoint::new(self, rhs)
    }
}
impl<D, T1, T2, U, F, G> ValueDifferentiableDistribution
    for ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
    D: Distribution<Value = T1, Condition = U> + ValueDifferentiableDistribution,
    T1: RandomVariable,
    T2: RandomVariable,
    U: Event,
    F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
    G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
    /// Differentiates the log-density with respect to the outer value `x`.
    ///
    /// The wrapped distribution's `ln_diff_value` is evaluated at the mapped
    /// value, laid out as a row matrix, and multiplied by the transpose of the
    /// matrix returned by `value_diff(x)`. The product is flattened back into
    /// a `Vec<f64>`.
    ///
    /// NOTE(review): this looks like a chain-rule product, which would make
    /// `value_diff` the Jacobian of the value mapping. The required matrix
    /// shape/orientation, and what the product does on a dimension mismatch,
    /// are not visible from here — confirm against callers.
    ///
    /// # Errors
    ///
    /// Returns the error from the value mapping if it fails for `x`, or the
    /// error from the wrapped distribution's `ln_diff_value`.
    fn ln_diff_value(
        &self,
        x: &Self::Value,
        theta: &Self::Condition,
    ) -> Result<Vec<f64>, DistributionError> {
        let inner = &self.valued_distribution;
        let mapped = (inner.value)(x)?;
        // Propagate the inner failure to the caller instead of panicking.
        let f = inner.distribution.ln_diff_value(&mapped, theta)?;
        let g = (self.value_diff)(x);
        let diff_mat = f.row_mat() * g.t();
        Ok(diff_mat.vec())
    }
}
impl<D, T1, T2, U, F, G> ConditionDifferentiableDistribution
    for ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
    D: Distribution<Value = T1, Condition = U> + ConditionDifferentiableDistribution,
    T1: RandomVariable,
    T2: RandomVariable,
    U: Event,
    F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
    G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
    /// Differentiates the log-density with respect to the condition `theta`.
    ///
    /// The condition is passed through unchanged, so this is the wrapped
    /// distribution's `ln_diff_condition` evaluated at the mapped value of `x`.
    ///
    /// # Errors
    ///
    /// Returns the error from the value mapping if it fails for `x`, or the
    /// error from the wrapped distribution's `ln_diff_condition`.
    fn ln_diff_condition(
        &self,
        x: &Self::Value,
        theta: &Self::Condition,
    ) -> Result<Vec<f64>, DistributionError> {
        let inner = &self.valued_distribution;
        let mapped = (inner.value)(x)?;
        // Return the inner result as-is so failures reach the caller rather
        // than panicking.
        inner.distribution.ln_diff_condition(&mapped, theta)
    }
}
impl<D, T1, T2, U, F, G> SampleableDistribution
    for ValueDifferentiableValuedDistribution<D, T1, T2, U, F, G>
where
    D: Distribution<Value = T1, Condition = U>,
    T1: RandomVariable,
    T2: RandomVariable,
    U: Event,
    F: Fn(&T2) -> Result<T1, DistributionError> + Clone + Send + Sync,
    G: Fn(&T2) -> Matrix + Clone + Send + Sync,
{
    /// Sampling is not implemented for this wrapper yet.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!()`; neither argument is used.
    fn sample(
        &self,
        _theta: &Self::Condition,
        _rng: &mut dyn RngCore,
    ) -> Result<Self::Value, crate::DistributionError> {
        todo!()
    }
}