use crate::{SolverResult, VectorType, DEFAULT_ITERMAX, DEFAULT_TOL};
use nalgebra::{allocator::Allocator, ComplexField, DefaultAllocator, Dim, Scalar, UniformNorm};
use num_traits::Float;
use rand::rng;
use rand_distr::{Distribution, Normal, StandardNormal};
use std::marker::PhantomData;
/// Number of candidate points drawn per iteration unless overridden with
/// `with_sample_size`; used by `CrossEntropy::new`.
const DEFAULT_SAMPLE_SIZE: usize = 100;
/// Number of best-ranked samples used to refit the sampling distribution unless
/// overridden with `with_importance_selection_size`; used by `CrossEntropy::new`.
const DEFAULT_IMPORTANCE_SELECTION_SIZE: usize = 10;
/// Derivative-free minimizer that repeatedly samples candidate points from
/// independent per-component normal distributions, ranks them by the objective
/// value, and refits the distributions to the best-ranked samples.
///
/// Type parameters:
/// * `T` - scalar type of the coordinates and of the objective value.
/// * `D` - dimension of the search vector.
/// * `F` - objective function, taking a vector by value and returning a `T`.
pub struct CrossEntropy<T, D, F>
where
T: Float + Scalar,
D: Dim,
F: Fn(VectorType<T, D>) -> T,
DefaultAllocator: Allocator<D>,
{
/// Objective function; samples with smaller values are ranked first.
f: F,
/// Initial per-component standard deviation. When `None`, `solve` starts
/// from a vector filled with ones.
std_dev: Option<VectorType<T, D>>,
/// Sampling stops once the uniform norm of the standard deviations is no
/// longer greater than this value.
tolerance: T,
/// Upper bound for the iteration counter used by `solve`.
iter_max: usize,
/// Number of samples drawn in every iteration.
sample_size: usize,
/// Number of best-ranked samples used to refit the distributions.
importance_selection_size: usize,
/// Zero-sized marker carrying the dimension type `D`.
d_phantom: PhantomData<D>,
}
impl<T, D, F> CrossEntropy<T, D, F>
where
    T: Float + Scalar + ComplexField<RealField = T>,
    D: Dim,
    F: Fn(VectorType<T, D>) -> T,
    DefaultAllocator: Allocator<D>,
    StandardNormal: Distribution<T>,
{
    /// Creates a solver for the objective `f` with default settings.
    ///
    /// Defaults: tolerance `DEFAULT_TOL`, iteration bound `DEFAULT_ITERMAX`,
    /// `DEFAULT_SAMPLE_SIZE` samples per iteration, of which the best
    /// `DEFAULT_IMPORTANCE_SELECTION_SIZE` are used to refit the sampling
    /// distribution. No initial standard deviation is set, so `solve` starts
    /// from a standard deviation of one in every component.
    ///
    /// # Panics
    ///
    /// Panics if `DEFAULT_TOL` cannot be converted to `T`.
    pub fn new(f: F) -> Self {
        Self {
            f,
            std_dev: None,
            tolerance: T::from(DEFAULT_TOL).unwrap(),
            iter_max: DEFAULT_ITERMAX,
            sample_size: DEFAULT_SAMPLE_SIZE,
            importance_selection_size: DEFAULT_IMPORTANCE_SELECTION_SIZE,
            d_phantom: PhantomData,
        }
    }

    /// Sets the threshold on the uniform norm of the standard deviations below
    /// which `solve` stops iterating.
    pub fn with_tol(&mut self, tolerance: T) -> &mut Self {
        self.tolerance = tolerance;
        self
    }

    /// Sets the upper bound for the iteration counter used by `solve`.
    pub fn with_iter_max(&mut self, iter_max: usize) -> &mut Self {
        self.iter_max = iter_max;
        self
    }

    /// Sets how many samples are drawn in every iteration.
    pub fn with_sample_size(&mut self, sample_size: usize) -> &mut Self {
        self.sample_size = sample_size;
        self
    }

    /// Sets how many of the best-ranked samples are used to refit the sampling
    /// distribution in every iteration.
    pub fn with_importance_selection_size(
        &mut self,
        importance_selection_size: usize,
    ) -> &mut Self {
        self.importance_selection_size = importance_selection_size;
        self
    }

    /// Sets the initial per-component standard deviation of the sampling
    /// distribution.
    pub fn with_std_dev(&mut self, std_dev: VectorType<T, D>) -> &mut Self {
        self.std_dev = Some(std_dev);
        self
    }

    /// Searches for a minimizer of the objective, starting from `x0`.
    ///
    /// Every iteration draws `sample_size` points from independent normal
    /// distributions (one per component), ranks them by objective value in
    /// ascending order, and replaces the per-component mean and sample standard
    /// deviation with those of the best `importance_selection_size` points
    /// (capped at `sample_size`). Iteration stops when the uniform norm of the
    /// standard deviations is no longer greater than the tolerance, or when the
    /// iteration counter, which starts at 1, reaches `iter_max`.
    ///
    /// Samples whose objective value is NaN are ranked last. If no sample can
    /// be selected (`sample_size` or `importance_selection_size` is zero), `x0`
    /// is returned unchanged. With a single selected sample the standard
    /// deviation collapses to zero after the first update.
    ///
    /// Returns the final mean vector.
    ///
    /// # Panics
    ///
    /// Panics if a normal distribution cannot be built from the current mean
    /// and standard deviation (for example a non-finite standard deviation), or
    /// if the number of selected samples cannot be converted to `T`.
    pub fn solve(&self, x0: VectorType<T, D>) -> SolverResult<VectorType<T, D>> {
        let mut mus = x0.clone();
        let mut sigmas = self.std_dev.clone().unwrap_or_else(|| {
            let mut ones = x0.clone();
            ones.fill(T::one());
            ones
        });

        // The number of samples that actually take part in the refit: asking
        // for more elites than there are samples yields only `sample_size`.
        let elite_count = self.importance_selection_size.min(self.sample_size);
        if elite_count == 0 {
            // Nothing to fit the distribution to; the start point is the best
            // available answer.
            return Ok(mus);
        }
        let mean_divisor =
            T::from_usize(elite_count).expect("elite count must be representable in T");
        // Sample standard deviation (n - 1); clamped to 1 so that a single
        // elite yields a deviation of zero instead of 0 / 0.
        let variance_divisor = T::from_usize(elite_count.saturating_sub(1).max(1))
            .expect("elite count must be representable in T");

        let mut generator = rng();
        let mut iter = 1;
        while sigmas.apply_norm(&UniformNorm) > self.tolerance && iter < self.iter_max {
            let distributions: Vec<_> = mus
                .iter()
                .zip(sigmas.iter())
                .map(|(&mu, &sigma)| {
                    Normal::new(mu, sigma)
                        .expect("sampling distribution requires a valid standard deviation")
                })
                .collect();

            let mut x_fx_pairs = Vec::with_capacity(self.sample_size);
            for _ in 0..self.sample_size {
                // `x0` only serves as a correctly sized template here; every
                // component is overwritten by a fresh draw.
                let mut sample_x = x0.clone();
                for (x, dist) in sample_x.iter_mut().zip(distributions.iter()) {
                    *x = dist.sample(&mut generator);
                }
                let fx = (self.f)(sample_x.clone());
                x_fx_pairs.push((sample_x, fx));
            }

            // Ascending by objective value; `partial_cmp` only fails when a NaN
            // is involved, in which case the NaN is ordered after everything.
            x_fx_pairs.sort_unstable_by(|(_, a), (_, b)| {
                a.partial_cmp(b)
                    .unwrap_or_else(|| Float::is_nan(*a).cmp(&Float::is_nan(*b)))
            });
            let elites = &x_fx_pairs[..elite_count];

            // The distributions above hold copies of the old parameters, so the
            // mean and deviation vectors can be updated in place.
            for (i, (mu, sigma)) in mus.iter_mut().zip(sigmas.iter_mut()).enumerate() {
                let mut mean = T::zero();
                for (v, _) in elites {
                    mean += v[i];
                }
                mean /= mean_divisor;

                let mut variance = T::zero();
                for (v, _) in elites {
                    variance += Float::powi(v[i] - mean, 2);
                }
                variance /= variance_divisor;

                *mu = mean;
                *sigma = Float::sqrt(variance);
            }
            iter += 1;
        }
        Ok(mus)
    }
}