mod constant;
mod cosine;
mod loguniform;
mod normal;
mod uniform;
pub use constant::*;
pub use cosine::*;
pub use loguniform::*;
pub use normal::*;
pub use uniform::*;
use crate::{Density, SamplingMode, domain::Domain, macros::tval};
use derive_more::IntoIterator;
use nalgebra::{
DefaultAllocator, Dim, OVector, RealField, SVector, U1, VectorView, allocator::Allocator,
};
use rand::RngExt;
use rand_distr::{Distribution, StandardNormal, uniform::SampleUniform};
use serde::{Deserialize, Serialize};
use std::{f64, fmt::Debug, iter::repeat_with};
// Dispatches over every `UnivariateDensity` variant.
//
// `$density` is the value (or reference) being matched, `$binding` is the
// pattern that captures the wrapped concrete density, and `$action` is the
// expression evaluated in each arm. A macro is used instead of a function
// because the captured value has a different concrete type in every arm.
macro_rules! match_univariate {
    ($density:expr, $binding:pat, $action:expr) => {
        match $density {
            UnivariateDensity::Constant($binding) => $action,
            UnivariateDensity::Cosine($binding) => $action,
            UnivariateDensity::Loguniform($binding) => $action,
            UnivariateDensity::Normal($binding) => $action,
            UnivariateDensity::Uniform($binding) => $action,
        }
    };
}
/// A one-dimensional density, stored as one of the concrete density types
/// exported by this module.
///
/// Serialized by serde in adjacently tagged form: the variant name is written
/// under the `"type"` key and the wrapped density under the `"content"` key.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(tag = "type", content = "content")]
pub enum UnivariateDensity<T: RealField> {
    /// Wraps a [`ConstantDensity`].
    Constant(ConstantDensity<T>),
    /// Wraps a [`CosineDensity`].
    Cosine(CosineDensity<T>),
    /// Wraps a [`LogUniformDensity`].
    Loguniform(LogUniformDensity<T>),
    /// Wraps a [`NormalDensity`].
    Normal(NormalDensity<T>),
    /// Wraps a [`UniformDensity`].
    Uniform(UniformDensity<T>),
}
impl<T> From<ConstantDensity<T>> for UnivariateDensity<T>
where
T: RealField,
{
fn from(value: ConstantDensity<T>) -> Self {
Self::Constant(value)
}
}
impl<T> From<CosineDensity<T>> for UnivariateDensity<T>
where
T: RealField,
{
fn from(value: CosineDensity<T>) -> Self {
Self::Cosine(value)
}
}
impl<T> From<LogUniformDensity<T>> for UnivariateDensity<T>
where
T: RealField,
{
fn from(value: LogUniformDensity<T>) -> Self {
Self::Loguniform(value)
}
}
impl<T> From<NormalDensity<T>> for UnivariateDensity<T>
where
T: RealField,
{
fn from(value: NormalDensity<T>) -> Self {
Self::Normal(value)
}
}
impl<T> From<UniformDensity<T>> for UnivariateDensity<T>
where
T: RealField,
{
fn from(value: UniformDensity<T>) -> Self {
Self::Uniform(value)
}
}
/// Borrowing accessors: each returns the wrapped density when `self` holds the
/// matching variant and `None` for every other variant.
impl<T: RealField> UnivariateDensity<T> {
    /// Returns the wrapped [`ConstantDensity`], if this is the `Constant` variant.
    pub fn as_constant(&self) -> Option<&ConstantDensity<T>> {
        if let Self::Constant(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Returns the wrapped [`CosineDensity`], if this is the `Cosine` variant.
    pub fn as_cosine(&self) -> Option<&CosineDensity<T>> {
        if let Self::Cosine(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Returns the wrapped [`LogUniformDensity`], if this is the `Loguniform` variant.
    pub fn as_loguniform(&self) -> Option<&LogUniformDensity<T>> {
        if let Self::Loguniform(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Returns the wrapped [`NormalDensity`], if this is the `Normal` variant.
    pub fn as_normal(&self) -> Option<&NormalDensity<T>> {
        if let Self::Normal(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Returns the wrapped [`UniformDensity`], if this is the `Uniform` variant.
    pub fn as_uniform(&self) -> Option<&UniformDensity<T>> {
        if let Self::Uniform(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
/// [`Density`] for a borrowed [`UnivariateDensity`]: every method forwards to
/// the concrete density stored in the active variant.
impl<T> Density<T, U1> for &UnivariateDensity<T>
where
    T: RealField + SampleUniform,
    StandardNormal: Distribution<T>,
{
    /// Evaluates the wrapped density at `sample` and returns its result unchanged.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        match_univariate!(
            self,
            inner,
            Density::<T, U1>::density::<RStride, CStride>(&inner, sample)
        )
    }

    /// Rebuilds a one-dimensional domain from the first minimum and maximum
    /// entries of the wrapped density's own domain.
    fn domain(&self) -> Domain<T, U1> {
        let bounds = match_univariate!(self, inner, {
            let lower = inner.domain().minimum_values()[0].clone();
            let upper = inner.domain().maximum_values()[0].clone();
            (lower, upper)
        });
        Domain::new_mdomain(OVector::from_element_generic(U1, U1, bounds))
    }

    /// Draws one value from the wrapped density using `mode`; returns `None`
    /// when the wrapped density yields no draw.
    fn sample(&self, rng: &mut impl RngExt, mode: &SamplingMode) -> Option<OVector<T, U1>> {
        let value = match_univariate!(self, inner, {
            Density::<T, U1>::sample(&inner, rng, mode)?[0].clone()
        });
        Some(OVector::from([value]))
    }

    /// Endless iterator of draws, each made with [`SamplingMode::SingleAttempt`];
    /// an item is `None` when that attempt produced no draw.
    fn sample_iter(&self, rng: &mut impl RngExt) -> impl Iterator<Item = Option<OVector<T, U1>>> {
        repeat_with(move || {
            let attempt = match_univariate!(self, inner, {
                Density::<T, U1>::sample(&inner, rng, &SamplingMode::SingleAttempt)
            });
            attempt.map(|drawn| OVector::from([drawn[0].clone()]))
        })
    }
}
/// A `D`-dimensional density stored as a vector of [`UnivariateDensity`]
/// components, one per dimension.
///
/// The derived `IntoIterator` impls (owned and by reference) iterate over the
/// wrapped vector of components.
#[derive(Clone, Debug, Deserialize, IntoIterator, Serialize)]
// The serde bounds are written out explicitly: (de)serialization is required
// of the wrapped vector type as a whole.
#[serde(bound(serialize = "OVector<UnivariateDensity<T>, D>: Serialize"))]
#[serde(bound(deserialize = "OVector<UnivariateDensity<T>, D>: Deserialize<'de>"))]
pub struct MultivariateDensity<T, D>(#[into_iterator(owned, ref)] OVector<UnivariateDensity<T>, D>)
where
T: RealField,
D: Dim,
DefaultAllocator: Allocator<D>;
impl<T, D> MultivariateDensity<T, D>
where
T: RealField,
D: Dim,
DefaultAllocator: Allocator<D>,
{
pub fn new(domains: OVector<UnivariateDensity<T>, D>) -> Self {
Self(domains)
}
}
/// [`Density`] for a borrowed [`MultivariateDensity`]: the components are
/// evaluated and sampled one dimension at a time and then combined.
impl<T, D> Density<T, D> for &MultivariateDensity<T, D>
where
    T: RealField + SampleUniform,
    D: Dim,
    StandardNormal: Distribution<T>,
    DefaultAllocator: Allocator<D>,
{
    /// Returns the product of the component densities at `sample`.
    ///
    /// Yields `None` when `sample` is not contained in [`Self::domain`].
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, D, RStride, CStride>,
    ) -> Option<T> {
        if !self.domain().contains(sample) {
            return None;
        }
        let mut likelihood = T::one();
        for (component, coordinate) in self.0.iter().zip(sample.iter()) {
            // Each component expects a one-element vector view.
            let point = SVector::from([coordinate.clone()]);
            let factor = match_univariate!(component, inner, {
                Density::<T, U1>::density::<U1, U1>(&inner, &point.as_view())
            });
            // NOTE(review): a component that returns `None` here (despite the
            // domain check above) turns the whole product into NaN rather than
            // `None` — confirm this is the intended signal.
            likelihood *= factor.unwrap_or(tval!(f64::NAN, f64));
        }
        Some(likelihood)
    }

    /// Builds the joint domain from per-component `(minimum, maximum)` pairs.
    ///
    /// A constant component uses its constant for both bounds; a normal
    /// component contributes its optional bounds as they are.
    fn domain(&self) -> Domain<T, D> {
        let n_dim = self.0.shape_generic().0;
        Domain::new_mdomain(OVector::from_iterator_generic(
            n_dim,
            U1,
            self.0.iter().map(|component| match component {
                UnivariateDensity::Constant(inner) => {
                    (Some(inner.constant()), Some(inner.constant()))
                }
                UnivariateDensity::Cosine(inner) => (Some(inner.minimum()), Some(inner.maximum())),
                UnivariateDensity::Loguniform(inner) => {
                    (Some(inner.minimum()), Some(inner.maximum()))
                }
                UnivariateDensity::Normal(inner) => (inner.minimum(), inner.maximum()),
                UnivariateDensity::Uniform(inner) => {
                    (Some(inner.minimum()), Some(inner.maximum()))
                }
            }),
        ))
    }

    /// Draws one value per component, in order, using `mode`.
    ///
    /// Returns `None` as soon as a component yields no draw; later components
    /// are then not sampled.
    fn sample(&self, rng: &mut impl RngExt, mode: &SamplingMode) -> Option<OVector<T, D>> {
        let n_dim = self.0.shape_generic().0;
        let mut draw = OVector::<T, D>::zeros_generic(n_dim, U1);
        for (slot, component) in draw.iter_mut().zip(self.0.iter()) {
            *slot = match_univariate!(component, inner, {
                Density::<T, U1>::sample(&inner, rng, mode)?[0].clone()
            });
        }
        Some(draw)
    }

    /// Endless iterator of joint draws made with [`SamplingMode::SingleAttempt`].
    ///
    /// Every component is sampled for every item, even after an earlier
    /// component has failed, so the RNG is advanced by all components whether
    /// or not the item is `None`. An item is `None` when at least one
    /// component produced no draw.
    fn sample_iter(&self, rng: &mut impl RngExt) -> impl Iterator<Item = Option<OVector<T, D>>> {
        let n_dim = self.0.shape_generic().0;
        repeat_with(move || {
            let mut complete = true;
            let draw = OVector::<T, D>::from_iterator_generic(
                n_dim,
                U1,
                self.into_iter().map(|component| {
                    match component.sample(rng, &SamplingMode::SingleAttempt) {
                        Some(value) => value[0].clone(),
                        None => {
                            // Placeholder only; the vector is discarded below.
                            complete = false;
                            T::zero()
                        }
                    }
                }),
            );
            complete.then_some(draw)
        })
    }
}
// Unit tests for the univariate densities and for `MultivariateDensity`.
// All random draws use a seeded `Xoshiro256PlusPlus`, so the expected sample
// values below depend on the exact order of sampling calls.
#[cfg(test)]
mod tests {
use super::*;
use approx::ulps_eq;
use nalgebra::{OVector, SVector, U2, U5};
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256PlusPlus;
// Checks each concrete density at known points, then a five-component
// multivariate density: its value inside the domain, `None` outside it, a
// seeded sample, and that a further sample lies inside the domain.
#[test]
fn test_multivariate_density_independent_evaluation() {
// Two constant densities evaluated at their own constants give equal values.
let const_a = ConstantDensity::new(1.0);
let const_b = ConstantDensity::new(2.0);
assert!(ulps_eq!(
(&const_a)
.density::<U1, U1>(&OVector::from([1.0]).as_view())
.unwrap(),
(&const_b)
.density::<U1, U1>(&OVector::from([2.0]).as_view())
.unwrap()
));
// Cosine density: 1.0 at 0.0, and 0.955336489125606 (numerically cos(0.3)) at 0.3.
let cosine = CosineDensity::new(-0.1, 0.35).unwrap();
assert!(ulps_eq!(
(&cosine)
.density::<U1, U1>(&OVector::from([0.0]).as_view())
.unwrap(),
1.0
));
assert!(ulps_eq!(
(&cosine)
.density::<U1, U1>(&OVector::from([0.3]).as_view())
.unwrap(),
0.955336489125606
));
// NOTE(review): this assertion is identical to the previous one — possibly a
// different evaluation point was intended; confirm.
assert!(ulps_eq!(
(&cosine)
.density::<U1, U1>(&OVector::from([0.3]).as_view())
.unwrap(),
0.955336489125606
));
// Log-uniform density: the value at 1.0 is twice the value at 2.0.
let loguniform = LogUniformDensity::new(0.5, 2.5).unwrap();
assert!(ulps_eq!(
(&loguniform)
.density::<U1, U1>(&OVector::from([1.0]).as_view())
.unwrap(),
2.0 * (&loguniform)
.density::<U1, U1>(&OVector::from([2.0]).as_view())
.unwrap()
));
// Normal density built with (1.0, 1.0) and no bounds, checked at 1.0 and 2.0.
let normal = NormalDensity::new(1.0, 1.0, None, None).unwrap();
assert!(ulps_eq!(
(&normal)
.density::<U1, U1>(&OVector::from([1.0]).as_view())
.unwrap(),
0.3989422804014327
));
assert!(ulps_eq!(
(&normal)
.density::<U1, U1>(&OVector::from([2.0]).as_view())
.unwrap(),
0.24197072451914337
));
// Uniform density on (0.0, 1.0) evaluates to 1.0 at 0.5.
let uniform = UniformDensity::new(0.0, 1.0).unwrap();
assert!(ulps_eq!(
(&uniform)
.density::<U1, U1>(&OVector::from([0.5]).as_view())
.unwrap(),
1.0
));
// Five-component multivariate density, one component of each kind.
// (Despite the name `uvpdf`, this is a `MultivariateDensity`.)
let uvpdf = &MultivariateDensity::new(SVector::from([
ConstantDensity::new(1.0).into(),
CosineDensity::new(0.1, 0.2).unwrap().into(),
LogUniformDensity::new(0.1, 0.5).unwrap().into(),
NormalDensity::new(0.1, 0.25, Some(-0.5), Some(1.5))
.unwrap()
.into(),
UniformDensity::new(1.0, 2.0).unwrap().into(),
]));
// Joint density at a point inside the domain.
assert!(ulps_eq!(
uvpdf
.density::<U1, U5>(&SVector::from([1.0f64, 0.15, 0.15, 0.2, 1.5]).as_view())
.unwrap(),
6.033325,
epsilon = 1e-5,
max_ulps = 5
));
// The second coordinate (0.05) is below the cosine component's 0.1, so the
// joint density is expected to be `None`.
assert!(
uvpdf
.density::<U1, U5>(&SVector::from([1.0, 0.05, 0.2, 0.15, 1.5]).as_view())
.is_none()
);
// Seeded draw compared against recorded values.
let mut rng = Xoshiro256PlusPlus::seed_from_u64(1);
assert!(ulps_eq!(
uvpdf
.sample(&mut rng, &SamplingMode::UntilValid { max_attempts: 100 })
.unwrap(),
OVector::from([1.0, 0.1810371, 0.33281568, -0.37896788, 1.7462168,]),
epsilon = 1e-5,
max_ulps = 5
));
// A second draw from the same RNG must lie inside the joint domain.
assert!(
uvpdf.domain().contains::<U1, U5>(
&uvpdf
.sample(&mut rng, &SamplingMode::UntilValid { max_attempts: 100 })
.unwrap()
.as_view()
)
);
}
// Draws 100 items from `sample_iter` on an unbounded normal x uniform(0, 1)
// density and checks that none were rejected, that the per-component sample
// means are near 0 and 0.5, and that the uniform component stays in [0, 1].
#[test]
fn test_multivariate_sample_iter_independent_marginals() {
let mut rng = Xoshiro256PlusPlus::seed_from_u64(42);
let mvpdf = &MultivariateDensity::new(SVector::from([
NormalDensity::new(0.0, 1.0, None, None).unwrap().into(),
UniformDensity::new(0.0, 1.0).unwrap().into(),
]));
// `flatten` drops `None` items; the length check below therefore asserts
// that all 100 attempts succeeded.
let samples: Vec<_> = mvpdf.sample_iter(&mut rng).take(100).flatten().collect();
assert_eq!(samples.len(), 100);
let mean_0: f64 = samples.iter().map(|s| s[0]).sum::<f64>() / samples.len() as f64;
let mean_1: f64 = samples.iter().map(|s| s[1]).sum::<f64>() / samples.len() as f64;
assert!(mean_0.abs() < 0.3);
assert!((mean_1 - 0.5).abs() < 0.15);
for sample in &samples {
assert!(sample[1] >= 0.0 && sample[1] <= 1.0);
}
}
// Every successful item from `sample_iter` on the five-component density
// must lie inside the joint domain, and at least one item must succeed.
#[test]
fn test_multivariate_sample_iter_5d() {
let mut rng = Xoshiro256PlusPlus::seed_from_u64(123);
let mvpdf = &MultivariateDensity::new(SVector::from([
ConstantDensity::new(1.0).into(),
CosineDensity::new(0.1, 0.2).unwrap().into(),
LogUniformDensity::new(0.1, 0.5).unwrap().into(),
NormalDensity::new(0.1, 0.25, Some(-0.5), Some(1.5))
.unwrap()
.into(),
UniformDensity::new(1.0, 2.0).unwrap().into(),
]));
let samples: Vec<_> = mvpdf.sample_iter(&mut rng).take(50).flatten().collect();
assert!(!samples.is_empty());
for sample in &samples {
assert!(mvpdf.domain().contains::<U1, U5>(&sample.as_view()));
}
}
// With a normal component bounded to (-1, 1), single-attempt sampling is
// expected to produce some `None` items among 200 attempts; every `Some`
// item must lie inside the joint domain.
#[test]
fn test_multivariate_sample_iter_domain_enforcement() {
let mut rng = Xoshiro256PlusPlus::seed_from_u64(42);
let mvpdf = &MultivariateDensity::new(SVector::from([
NormalDensity::new(0.0, 1.0, Some(-1.0), Some(1.0))
.unwrap()
.into(),
UniformDensity::new(0.0, 1.0).unwrap().into(),
]));
let results: Vec<_> = mvpdf.sample_iter(&mut rng).take(200).collect();
let none_count = results.iter().filter(|r| r.is_none()).count();
assert!(
none_count > 0,
"Expected some rejections due to domain constraints"
);
for result in results.iter().flatten() {
assert!(mvpdf.domain().contains::<U1, U2>(&result.as_view()));
}
}
}