use crate::{
Density, SamplingMode, UnivariateDensity, domain::Domain, macros::tval,
univariate::match_univariate,
};
use derive_more::IntoIterator;
use nalgebra::{
DefaultAllocator, Dim, OVector, RealField, SVector, Scalar, U1, VectorView,
allocator::Allocator,
};
use rand::RngExt;
use rand_distr::{Distribution, StandardNormal, uniform::SampleUniform};
use serde::{Deserialize, Serialize};
use std::{f64, fmt::Debug, iter::repeat_with};
/// A multivariate density assembled from one univariate marginal density per dimension.
///
/// The marginals are stored in dimension order. Iterating the value (by value or by
/// reference) yields those marginals in that same order.
#[derive(Clone, Debug, Deserialize, IntoIterator, Serialize)]
#[serde(bound(
    serialize = "OVector<UnivariateDensity<T>, D>: Serialize",
    deserialize = "OVector<UnivariateDensity<T>, D>: Deserialize<'de>"
))]
pub struct MultivariateDensity<T, D>(
    #[into_iterator(owned, ref)] OVector<UnivariateDensity<T>, D>,
)
where
    D: Dim,
    T: Scalar,
    DefaultAllocator: Allocator<D>;
impl<T, D> MultivariateDensity<T, D>
where
T: RealField,
D: Dim,
DefaultAllocator: Allocator<D>,
{
pub fn new(domains: OVector<UnivariateDensity<T>, D>) -> Self {
Self(domains)
}
pub fn marginals(&self) -> &OVector<UnivariateDensity<T>, D> {
&self.0
}
}
impl<T, D> Density<T, D> for MultivariateDensity<T, D>
where
    T: RealField + SampleUniform,
    D: Dim,
    StandardNormal: Distribution<T>,
    DefaultAllocator: Allocator<D>,
{
    /// Evaluates the joint density at `sample`.
    ///
    /// Returns `None` when `sample` lies outside [`Density::domain`]. Otherwise the
    /// result is the product of each marginal density evaluated at the matching
    /// component of `sample`, so the marginals are treated as independent.
    ///
    /// A marginal that yields `None` despite passing the domain check contributes a
    /// NaN factor instead of being skipped.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, D, RStride, CStride>,
    ) -> Option<T> {
        if !self.domain().contains(sample) {
            return None;
        }
        let joint = self
            .0
            .iter()
            .zip(sample.iter())
            .fold(T::one(), |acc, (uvpdf, value)| {
                // Each marginal is a `Density<T, U1>`, so the scalar component is
                // wrapped in a length-1 vector view before evaluation.
                let point = SVector::from([value.clone()]);
                let marginal = match_univariate!(uvpdf, pdf, {
                    Density::<T, U1>::density::<U1, U1>(&pdf, &point.as_view())
                })
                .unwrap_or(tval!(f64::NAN, f64));
                acc * marginal
            });
        Some(joint)
    }

    /// Returns the axis-aligned domain formed from each marginal's bounds.
    ///
    /// Bounds are chosen per marginal type:
    /// - A constant marginal collapses to the degenerate interval `[c, c]`.
    /// - A normal marginal forwards its optional `minimum`/`maximum` unchanged, so a
    ///   side may be `None`.
    /// - Every other marginal uses its `minimum`/`maximum` as finite bounds.
    fn domain(&self) -> Domain<T, D> {
        let bounds = self.0.iter().map(|uvpdf| match uvpdf {
            UnivariateDensity::Constant(pdf) => (Some(pdf.constant()), Some(pdf.constant())),
            UnivariateDensity::Cosine(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
            UnivariateDensity::Lognormal(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
            UnivariateDensity::Loguniform(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
            UnivariateDensity::Normal(pdf) => (pdf.minimum(), pdf.maximum()),
            UnivariateDensity::Uniform(pdf) => (Some(pdf.minimum()), Some(pdf.maximum())),
        });
        Domain::new_mdomain(OVector::from_iterator_generic(
            self.0.shape_generic().0,
            U1,
            bounds,
        ))
    }

    /// Returns the vector of marginal means, one entry per dimension.
    fn mean(&self) -> OVector<T, D> {
        OVector::from_iterator_generic(
            self.0.shape_generic().0,
            U1,
            self.0
                .iter()
                .map(|uvpdf| match_univariate!(uvpdf, pdf, { pdf.mean() })[0].clone()),
        )
    }

    /// Draws one vector by sampling each marginal in dimension order with `mode`.
    ///
    /// Returns `None` as soon as a marginal fails to produce a sample. The remaining
    /// marginals are then not sampled.
    fn sample(&self, rng: &mut impl RngExt, mode: &SamplingMode) -> Option<OVector<T, D>> {
        let mut draw = OVector::<T, D>::zeros_generic(self.0.shape_generic().0, U1);
        for (slot, uvpdf) in draw.iter_mut().zip(self.0.iter()) {
            let marginal = match_univariate!(uvpdf, pdf, {
                Density::<T, U1>::sample(&pdf, rng, mode)
            })?;
            *slot = marginal[0].clone();
        }
        Some(draw)
    }

    /// Returns an endless iterator of draws.
    ///
    /// Each marginal is sampled with [`SamplingMode::SingleAttempt`]. An item is
    /// `None` when any marginal failed for that draw.
    fn sample_iter(&self, rng: &mut impl RngExt) -> impl Iterator<Item = Option<OVector<T, D>>> {
        let n_dim = self.0.shape_generic().0;
        repeat_with(move || {
            // All marginals are drawn before any failure is inspected (no
            // short-circuit), so every marginal's sampler runs on each draw.
            let marginal_draws = OVector::<Option<SVector<T, 1>>, D>::from_iterator_generic(
                n_dim,
                U1,
                self.0
                    .iter()
                    .map(|uvpdf| uvpdf.sample(rng, &SamplingMode::SingleAttempt)),
            );
            let mut draw = OVector::<T, D>::zeros_generic(n_dim, U1);
            for (slot, marginal) in draw.iter_mut().zip(marginal_draws.iter()) {
                *slot = marginal.as_ref()?[0].clone();
            }
            Some(draw)
        })
    }

    /// Returns the vector of marginal variances, one entry per dimension.
    fn variance(&self) -> OVector<T, D> {
        OVector::from_iterator_generic(
            self.0.shape_generic().0,
            U1,
            self.0
                .iter()
                .map(|uvpdf| match_univariate!(uvpdf, pdf, { pdf.variance() })[0].clone()),
        )
    }
}