#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
mod domain;
mod macros;
mod multinormal;
mod multivariate;
mod particle;
#[cfg(feature = "pyo3")]
pub mod pytypes;
pub use domain::Domain;
pub use multinormal::MultivariateNormalDensity;
pub use multivariate::{
ConstantDensity, CosineDensity, LogUniformDensity, MultivariateDensity, NormalDensity,
UniformDensity, UnivariateDensity,
};
pub use particle::ParticleDensity;
use nalgebra::{DefaultAllocator, Dim, OVector, RealField, VectorView, allocator::Allocator};
use rand::RngExt;
use serde::{Deserialize, Serialize};
/// A probability density over a `D`-dimensional space with scalar type `T`.
///
/// Implementors can evaluate the density at a point, report the [`Domain`]
/// on which it is defined, and draw random samples from it.
pub trait Density<T, D>
where
    T: RealField,
    D: Dim,
{
    /// Evaluates the density at `sample`.
    ///
    /// The view may have arbitrary row/column strides (e.g. a column of a
    /// larger matrix), so callers do not need to copy into an owned vector.
    /// Returns `None` when no value can be produced
    /// (NOTE(review): presumably when `sample` lies outside
    /// [`Density::domain`] — confirm against implementors).
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, D, RStride, CStride>,
    ) -> Option<T>;

    /// Returns the domain on which this density is defined.
    fn domain(&self) -> Domain<T, D>
    where
        DefaultAllocator: Allocator<D>;

    /// Draws a single sample using `rng`, following the retry policy `mode`.
    ///
    /// `None` signals that no sample was produced under the policy of `mode`.
    fn sample(&self, rng: &mut impl RngExt, mode: &SamplingMode) -> Option<OVector<T, D>>
    where
        DefaultAllocator: Allocator<D>;

    /// Returns an iterator of samples drawn with `rng`.
    ///
    /// Each item is `None` if the corresponding draw failed. Which
    /// [`SamplingMode`] is used and whether the iterator is finite are left to
    /// the implementor (NOTE(review): confirm against implementors).
    fn sample_iter(&self, rng: &mut impl RngExt) -> impl Iterator<Item = Option<OVector<T, D>>>
    where
        DefaultAllocator: Allocator<D>;
}
/// A [`Density`] that samples by rejection: it draws unconstrained candidates
/// and keeps the first one that lies inside [`Density::domain`].
pub trait RejectionSampler<T, D>: Density<T, D>
where
    T: RealField,
    D: Dim,
{
    /// Draws one unconstrained candidate.
    ///
    /// The candidate may lie outside [`Density::domain`]; filtering is done by
    /// [`RejectionSampler::rejection_sample`].
    fn generate_candidate(&self, rng: &mut impl RngExt) -> OVector<T, D>
    where
        DefaultAllocator: Allocator<D>;

    /// Draws candidates until one lies inside the domain, subject to `mode`.
    ///
    /// At least one candidate is always drawn. Per mode:
    /// - [`SamplingMode::SingleAttempt`]: one candidate; `None` if it is rejected.
    /// - [`SamplingMode::UntilValid`]: at most `max_attempts` candidates (but at
    ///   least one); `None` if all are rejected.
    /// - [`SamplingMode::UntilValidOrClamp`]: like `UntilValid`, but once the
    ///   budget is exhausted the last rejected candidate is clamped into the
    ///   domain and returned instead of `None`.
    /// - [`SamplingMode::UntilValidNoLimit`]: keeps drawing until a candidate is
    ///   accepted; never returns `None`, but does not terminate if no candidate
    ///   is ever accepted.
    fn rejection_sample(&self, rng: &mut impl RngExt, mode: &SamplingMode) -> Option<OVector<T, D>>
    where
        DefaultAllocator: Allocator<D>,
    {
        // The domain does not change between draws; build it once rather than
        // once per candidate.
        let domain = self.domain();
        // Number of candidates drawn so far, including the current one.
        // Saturating so `UntilValidNoLimit` can never overflow the counter.
        let mut attempts: usize = 0;

        loop {
            let candidate = self.generate_candidate(rng);
            attempts = attempts.saturating_add(1);

            if domain.contains(&candidate.as_view()) {
                return Some(candidate);
            }

            match mode {
                SamplingMode::SingleAttempt => return None,
                SamplingMode::UntilValid { max_attempts } => {
                    if attempts >= *max_attempts {
                        return None;
                    }
                }
                SamplingMode::UntilValidOrClamp { max_attempts } => {
                    if attempts >= *max_attempts {
                        return Some(domain.clamp(&candidate.as_view()));
                    }
                }
                SamplingMode::UntilValidNoLimit => {}
            }
        }
    }
}
/// Retry policy used when drawing samples whose candidates may be rejected,
/// e.g. by [`RejectionSampler::rejection_sample`].
///
/// The [`Default`] is `UntilValid { max_attempts: 512 }`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum SamplingMode {
    /// Draw a single candidate and give up if it is rejected.
    SingleAttempt,
    /// Keep drawing until a candidate is accepted, giving up once the
    /// attempt budget is exhausted.
    UntilValid {
        /// Attempt budget; once exhausted, sampling gives up.
        max_attempts: usize,
    },
    /// Like `UntilValid`, but when the attempt budget is exhausted the last
    /// rejected candidate is clamped into the domain instead of giving up.
    UntilValidOrClamp {
        /// Attempt budget; once exhausted, the last candidate is clamped.
        max_attempts: usize,
    },
    /// Keep drawing until a candidate is accepted, with no attempt limit.
    UntilValidNoLimit,
}
impl Default for SamplingMode {
fn default() -> Self {
SamplingMode::UntilValid { max_attempts: 512 }
}
}