use std::marker::PhantomData;
use crate::{
integrators::{OrthotopeRandomIntegrator, DEFAULT_SAMPLE_COUNT},
MeanVariance, SolverError, SolverResult, VectorType,
};
use nalgebra::{allocator::Allocator, ComplexField, DefaultAllocator, Dim, Scalar};
use num_traits::Float;
use rand::Rng;
use rand_distr::{uniform::SampleUniform, Distribution, Uniform};
/// Default cap on the number of recursive bisection levels.
const DEFAULT_MAXIMUM_DEPTH: usize = 5;
/// Default per-dimension minimum sample count below which a region is not split.
const DEFAULT_MINIMUM_DIMENSIONAL_SAMPLE_COUNT: usize = 32;
/// Default MISER sample-allocation exponent `alpha`.
const DEFAULT_ALPHA: f64 = 2.0;
/// Default jitter of the bisection point, as a fraction of the region extent.
const DEFAULT_DITHER: f64 = 0.05;
/// MISER recursive stratified-sampling Monte Carlo integrator.
///
/// Bundles the integrand `f` with the tuning knobs read by the recursive
/// integration routine. `V` is the point type handed to `f` and `T` is the
/// scalar type it returns.
pub struct MISER<F, V, T> {
    /// Integrand evaluated at every sample point.
    f: F,
    /// Total sample budget handed to the top-level region.
    sample_count: usize,
    /// Per-dimension minimum sample count: a region whose budget falls below
    /// `minimum_dimensional_sample_count * dimensions` is sampled directly.
    minimum_dimensional_sample_count: usize,
    /// Maximum number of bisection levels before a region is sampled directly.
    maximum_depth: usize,
    /// Allocation exponent: a split's budget is shared between the two halves
    /// in proportion to `variance^(1 / (1 + alpha))`.
    alpha: T,
    /// Half-width, as a fraction of the region extent, of the random jitter
    /// applied around the centre of each candidate bisection point.
    dither: T,
    /// Zero-sized marker for the point type `V`.
    v_phantom: PhantomData<V>,
    /// Zero-sized marker for the scalar type `T`.
    t_phantom: PhantomData<T>,
}
impl<F, V, T> MISER<F, V, T>
where
    T: Float,
    F: Fn(V) -> T,
{
    /// Creates an integrator for `f` with the default configuration:
    /// `DEFAULT_SAMPLE_COUNT` samples, a maximum depth of `DEFAULT_MAXIMUM_DEPTH`,
    /// `DEFAULT_MINIMUM_DIMENSIONAL_SAMPLE_COUNT` samples per dimension,
    /// `alpha = DEFAULT_ALPHA` and `dither = DEFAULT_DITHER`.
    ///
    /// # Panics
    ///
    /// Panics if `DEFAULT_ALPHA` or `DEFAULT_DITHER` cannot be converted to `T`.
    /// This is not expected for `f32`/`f64`.
    pub fn new(f: F) -> Self {
        Self {
            f,
            sample_count: DEFAULT_SAMPLE_COUNT,
            maximum_depth: DEFAULT_MAXIMUM_DEPTH,
            alpha: T::from(DEFAULT_ALPHA)
                .expect("DEFAULT_ALPHA must be representable in the floating-point type T"),
            dither: T::from(DEFAULT_DITHER)
                .expect("DEFAULT_DITHER must be representable in the floating-point type T"),
            minimum_dimensional_sample_count: DEFAULT_MINIMUM_DIMENSIONAL_SAMPLE_COUNT,
            t_phantom: PhantomData,
            v_phantom: PhantomData,
        }
    }

    /// Sets the total sample budget given to the top-level region.
    pub fn with_sample_count(&mut self, sample_count: usize) -> &mut Self {
        self.sample_count = sample_count;
        self
    }

    /// Sets the per-dimension minimum sample count.
    ///
    /// A region whose budget is below this value times the number of
    /// dimensions is sampled directly instead of being subdivided. Leaf regions
    /// also draw at least that many samples.
    pub fn with_minimum_dimensional_sample_count(
        &mut self,
        minimum_dimensional_sample_count: usize,
    ) -> &mut Self {
        self.minimum_dimensional_sample_count = minimum_dimensional_sample_count;
        self
    }

    /// Sets the maximum number of recursive bisection levels.
    pub fn with_maximum_depth(&mut self, maximum_depth: usize) -> &mut Self {
        self.maximum_depth = maximum_depth;
        self
    }

    /// Sets the allocation exponent `alpha`.
    ///
    /// The budget of a split is divided between the two halves in proportion
    /// to `variance^(1 / (1 + alpha))`.
    pub fn with_alpha(&mut self, alpha: T) -> &mut Self {
        self.alpha = alpha;
        self
    }

    /// Sets the relative jitter of the bisection point.
    ///
    /// Candidate mid-points are drawn at `0.5 ± dither` of a region's extent.
    /// A negative value makes the jitter range `[-dither, dither]` empty, which
    /// is reported as an error when integrating.
    pub fn with_dither(&mut self, dither: T) -> &mut Self {
        self.dither = dither;
        self
    }
}
/// MISER integration over the axis-aligned box `[from, to]`, starting from the
/// configured maximum depth and sample budget.
impl<F, T, D> OrthotopeRandomIntegrator<VectorType<T, D>, T> for MISER<F, VectorType<T, D>, T>
where
    D: Dim,
    T: Float + Scalar + ComplexField<RealField = T> + SampleUniform,
    F: Fn(VectorType<T, D>) -> T,
    DefaultAllocator: Allocator<D>,
{
    fn integrate_with_rng(
        &self,
        mut from: VectorType<T, D>,
        mut to: VectorType<T, D>,
        rng: &mut impl rand::Rng,
    ) -> SolverResult<MeanVariance<T>> {
        // The recursion narrows `from`/`to` in place for each sub-region and
        // restores them afterwards, so it operates on locals owned by this call.
        let depth_budget = self.maximum_depth;
        let sample_budget = self.sample_count;
        miser_recurse(self, &mut from, &mut to, rng, depth_budget, sample_budget)
    }
}
/// Builds a uniform distribution over the closed interval `[lower, upper]`.
///
/// Any construction error reported by `rand` (for example an empty range) is
/// turned into `SolverError::ExternalError` carrying that error's message.
fn uniform<T: SampleUniform>(lower: T, upper: T) -> SolverResult<Uniform<T>> {
    Uniform::new_inclusive(lower, upper)
        .map_err(|rand_error| rand_error.to_string())
        .map_err(SolverError::ExternalError)
}
/// Recursively estimates the integral of `miser.f` over the box `[from, to]`
/// using MISER recursive stratified sampling.
///
/// A region is sampled directly (plain Monte Carlo, with the mean scaled by
/// the region volume) in any of these cases:
/// - it has no dimensions;
/// - the depth budget is exhausted;
/// - its sample budget is below `minimum_dimensional_sample_count * dimensions`.
///
/// Otherwise the region is split:
/// 1. Every axis is trial-bisected at a dithered mid-point.
/// 2. The axis whose halves have the smallest summed variance is chosen.
/// 3. The sample budget is divided between the halves in proportion to
///    `variance^(1 / (1 + alpha))`.
/// 4. Both halves are integrated recursively.
///
/// `from` and `to` are narrowed in place for the recursive calls and restored
/// before a successful return.
///
/// # Errors
///
/// Returns an error in these cases:
/// - A sampling distribution cannot be built (see `uniform`), e.g. because of a
///   negative `dither` or an inverted bound.
/// - A sample count cannot be converted to or from `T`
///   (`SolverError::TypeConversionError`).
/// - The Welford accumulator reports an error.
fn miser_recurse<F, T, D>(
    miser: &MISER<F, VectorType<T, D>, T>,
    from: &mut VectorType<T, D>,
    to: &mut VectorType<T, D>,
    rng: &mut impl Rng,
    depth_remaining: usize,
    sample_count: usize,
) -> SolverResult<MeanVariance<T>>
where
    T: Float + Scalar + ComplexField<RealField = T> + SampleUniform,
    D: Dim,
    F: Fn(VectorType<T, D>) -> T,
    DefaultAllocator: Allocator<D>,
{
    let half = T::from(0.5).ok_or(SolverError::TypeConversionError)?;
    let dimensions = from.len();
    let minimum_sample_count = miser.minimum_dimensional_sample_count * dimensions;
    // One uniform sampler per axis, covering this region's full extent.
    let samplers = from
        .iter()
        .zip(to.iter())
        .map(|(&lower, &upper)| uniform(lower, upper))
        .collect::<Result<Vec<_>, _>>()?;

    // Leaf: integrate directly. A zero-dimensional region is always a leaf,
    // which also keeps `dimensions - 1` below from underflowing.
    if dimensions == 0 || depth_remaining == 0 || sample_count < minimum_sample_count {
        let sample_count = sample_count.max(minimum_sample_count);
        let volume = from
            .iter()
            .zip(to.iter())
            .map(|(&lower, &upper)| upper - lower)
            .fold(T::one(), T::mul);
        let mut vector_sampler = || {
            let mut vector = from.clone();
            for (element, sampler) in vector.iter_mut().zip(&samplers) {
                *element = sampler.sample(rng);
            }
            vector
        };
        return MeanVariance::from_iterator_using_welford(
            &mut (0..sample_count).map(|_| (miser.f)(vector_sampler())),
            false,
        )
        .map(|result| result.scale_mean(volume));
    }

    let sample_count = sample_count.max(2);
    debug_assert!(2 <= sample_count && sample_count >= minimum_sample_count);
    let mut best_variance = T::max_value();
    let mut best_lower_variance = T::max_value();
    let mut best_upper_variance = T::max_value();
    // Fallback split (random axis, exact centre). It is used only if no candidate
    // beats the initial `max_value` variance, e.g. every trial variance is
    // infinite or NaN.
    let mut best_dimension = uniform(0, dimensions - 1)?.sample(rng);
    let mut best_mid = from[best_dimension] + (to[best_dimension] - from[best_dimension]) * half;
    let dither_sampler = uniform(-miser.dither, miser.dither)?;
    // NOTE(review): each candidate axis draws `sample_count` points per half.
    // Exploration therefore costs `2 * dimensions * sample_count` evaluations
    // per node, on top of the recursive budget. Classic MISER spends only a
    // small fraction of the budget on pre-sampling. Kept as-is to preserve
    // behaviour.
    for d in 0..dimensions {
        let delta = to[d] - from[d];
        let mid = from[d] + delta * (half + dither_sampler.sample(rng));
        let lower_d_sampler = uniform(from[d], mid)?;
        let upper_d_sampler = uniform(mid, to[d])?;
        // Draws a point in the full region, except that axis `d` is restricted
        // to the given half.
        let mut sampler = |d_sampler: &Uniform<T>| {
            let mut vector = from.clone();
            for (inner_d, element) in vector.iter_mut().enumerate() {
                if inner_d == d {
                    *element = d_sampler.sample(rng);
                } else {
                    *element = samplers[inner_d].sample(rng);
                }
            }
            vector
        };
        let variance_lower = MeanVariance::from_iterator_using_welford(
            &mut (0..sample_count).map(|_| (miser.f)(sampler(&lower_d_sampler))),
            true,
        )?
        .variance;
        let variance_upper = MeanVariance::from_iterator_using_welford(
            &mut (0..sample_count).map(|_| (miser.f)(sampler(&upper_d_sampler))),
            true,
        )?
        .variance;
        let total_variance = variance_lower + variance_upper;
        if total_variance < best_variance {
            best_variance = total_variance;
            best_lower_variance = variance_lower;
            best_upper_variance = variance_upper;
            best_dimension = d;
            best_mid = mid;
        }
    }

    let old_from = from[best_dimension];
    let old_to = to[best_dimension];
    // Divide the budget in proportion to variance^(1 / (1 + alpha)), which is
    // sigma^(2 / (1 + alpha)).
    let beta = T::one() / (T::one() + miser.alpha);
    let lower_weight = Float::powf(best_lower_variance, beta);
    let upper_weight = Float::powf(best_upper_variance, beta);
    let total_weight = lower_weight + upper_weight;
    // Two cases make the ratio NaN, which cannot be converted to a sample count:
    // - both halves look flat (zero variance, e.g. a constant integrand): 0/0;
    // - the weights are non-finite: inf/inf.
    // In either case, split the budget evenly instead.
    let lower_sample_fraction = if total_weight > T::zero() && Float::is_finite(total_weight) {
        lower_weight / total_weight
    } else {
        half
    };
    let sample_count_as_t = T::from(sample_count).ok_or(SolverError::TypeConversionError)?;
    // The upper half gets whatever the lower half does not, so truncation loses
    // no samples. `min` guards against rounding in `T` overshooting the budget.
    let lower_sample_count = (sample_count_as_t * lower_sample_fraction)
        .to_usize()
        .ok_or(SolverError::TypeConversionError)?
        .min(sample_count);
    let upper_sample_count = sample_count - lower_sample_count;

    // Lower half: [from, best_mid] along the chosen axis.
    to[best_dimension] = best_mid;
    let lower_result = miser_recurse(
        miser,
        from,
        to,
        rng,
        depth_remaining - 1,
        lower_sample_count,
    )?;
    to[best_dimension] = old_to;
    // Upper half: [best_mid, to] along the chosen axis.
    from[best_dimension] = best_mid;
    let upper_result = miser_recurse(
        miser,
        from,
        to,
        rng,
        depth_remaining - 1,
        upper_sample_count,
    )?;
    from[best_dimension] = old_from;
    Ok(lower_result.add(&upper_result))
}