use std::marker::PhantomData;
use rand_distr::uniform::SampleUniform;
use crate::core::constraint::BoxConstraints;
use crate::core::executor::OptimizationResult;
use crate::core::inner::{InnerExecutor, WarmStart};
use crate::core::math::{SampleUniformBox, Scalar, ScaleInPlace, ScaledAdd, VectorLen};
use crate::core::problem::{CostFunction, Problem};
use crate::core::solver::Solver;
use crate::core::state::{BasicPopulationState, CountsMirror, State};
use crate::core::termination::{MaxCostEvals, TerminationCriterion, TerminationReason};
use crate::solver::cma_es::sort_population_ascending;
use crate::solver::cma_inject::MemeticInner;
use crate::solver::de::De;
/// Memetic hybrid solver: a differential-evolution (DE) outer loop whose
/// population members are periodically polished by an inner local solver.
///
/// * `I` — the inner (local) solver used for refinement.
/// * `V` — the parameter-vector type shared by the outer and inner solvers.
/// * `F` — the scalar type (defaults to `f64`).
pub struct DeInject<I, V, F = f64>
where
    F: Scalar,
    I: MemeticInner<V, F>,
{
    /// Outer DE engine that drives each generation.
    de: De<F>,
    /// Executor wrapping the inner solver; runs one local search per refined member.
    inner: InnerExecutor<I::State, I>,
    /// Refinement only happens on iterations that are a multiple of this value.
    refine_every: u64,
    /// Maximum number of population members refined per refinement step.
    k: usize,
    /// Marks the parameter type `V` without storing a value of it.
    _phantom: PhantomData<V>,
}
impl<I, V, F> DeInject<I, V, F>
where
    F: Scalar,
    I: MemeticInner<V, F>,
    I::State: CountsMirror,
{
    /// Iteration cap applied to each inner local-search run unless overridden
    /// with `with_inner_max_iter`.
    const DEFAULT_INNER_MAX_ITER: u64 = 50;

    /// Wraps the DE solver `de` with `inner` as the local refinement solver.
    ///
    /// Defaults: `k = 1`, refinement on every iteration (`refine_every = 1`),
    /// and an inner iteration cap of 50.
    #[must_use]
    pub fn with_inner_solver(de: De<F>, inner: I) -> Self {
        Self {
            de,
            inner: InnerExecutor::new(inner).max_iter(Self::DEFAULT_INNER_MAX_ITER),
            k: 1,
            refine_every: 1,
            _phantom: PhantomData,
        }
    }

    /// Sets the maximum number of population members refined per refinement step.
    ///
    /// # Panics
    ///
    /// Panics if `k == 0`.
    #[must_use]
    pub fn with_k(mut self, k: usize) -> Self {
        assert!(k >= 1, "DeInject requires k >= 1, got {}", k);
        self.k = k;
        self
    }

    /// Refines only on iterations whose index is a multiple of `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`, which would otherwise cause a modulo-by-zero later.
    #[must_use]
    pub fn with_refine_every(mut self, n: u64) -> Self {
        assert!(n >= 1, "DeInject requires refine_every >= 1, got {}", n);
        self.refine_every = n;
        self
    }

    /// Sets the iteration cap used for each inner local-search run.
    #[must_use]
    pub fn with_inner_max_iter(mut self, n: u64) -> Self {
        // Moving out of a field and reassigning it is fine here: `DeInject`
        // has no `Drop` impl, so no full destructure/rebuild is needed.
        self.inner = self.inner.max_iter(n);
        self
    }

    /// Forwards `criterion` to the inner executor's `terminate_on`, so it
    /// applies to every inner local-search run.
    ///
    /// NOTE(review): whether this adds to or replaces previously registered
    /// criteria depends on `InnerExecutor::terminate_on` — confirm there.
    #[must_use]
    pub fn inner_terminate_on<C>(mut self, criterion: C) -> Self
    where
        C: TerminationCriterion<I::State> + 'static,
    {
        self.inner = self.inner.terminate_on(criterion);
        self
    }
}
impl<I, V, F> DeInject<I, V, F>
where
    F: Scalar,
    I: MemeticInner<V, F>,
    I::State: CountsMirror,
    MaxCostEvals: TerminationCriterion<I::State> + 'static,
{
    /// Limits each inner local-search run to `evals` cost evaluations.
    ///
    /// Shorthand for `inner_terminate_on(MaxCostEvals(evals))`.
    #[must_use]
    pub fn with_ls_intensity(self, evals: u64) -> Self {
        let budget = MaxCostEvals(evals);
        self.inner_terminate_on(budget)
    }
}
impl<P, I, V, F> Solver<P, BasicPopulationState<V, F>> for DeInject<I, V, F>
where
F: Scalar + SampleUniform,
P: CostFunction<Param = V, Output = F> + BoxConstraints<Param = V>,
I: MemeticInner<V, F> + Solver<P, <I as WarmStart<V>>::State, Error = P::Error>,
I::State: State<Param = V, Float = F> + CountsMirror,
V: VectorLen
+ Clone
+ SampleUniformBox
+ ScaledAdd<F>
+ ScaleInPlace<F>
+ std::ops::Index<usize, Output = F>
+ std::ops::IndexMut<usize, Output = F>,
De<F>: Solver<P, BasicPopulationState<V, F>, Error = P::Error>,
{
type Error = P::Error;
fn init(
&mut self,
problem: &mut Problem<P>,
state: BasicPopulationState<V, F>,
) -> Result<BasicPopulationState<V, F>, Self::Error> {
self.de.init(problem, state)
}
fn next_iter(
&mut self,
problem: &mut Problem<P>,
state: BasicPopulationState<V, F>,
) -> Result<(BasicPopulationState<V, F>, Option<TerminationReason>), Self::Error> {
let (mut state, reason) = self.de.next_iter(problem, state)?;
if let Some(r) = reason {
return Ok((state, Some(r)));
}
if state.iter() % self.refine_every != 0 {
return Ok((state, None));
}
let lo = problem.inner().lower().clone();
let hi = problem.inner().upper().clone();
let refine = self.k.min(state.candidates.len());
for i in 0..refine {
let x_seed = state.candidates[i].clone();
let c_orig = state.costs[i];
let inner_state = self.inner.solver().seed_scaled(&x_seed, F::one());
let inner_result: OptimizationResult<I::State> =
self.inner.run(problem, inner_state)?;
if inner_result.reason.is_failure() {
return Ok((state, Some(inner_result.reason)));
}
let mut x_refined = inner_result.state.param().clone();
let n = x_refined.vec_len();
for j in 0..n {
if x_refined[j] < lo[j] {
x_refined[j] = lo[j];
} else if x_refined[j] > hi[j] {
x_refined[j] = hi[j];
}
}
let c_refined = problem.cost(&x_refined)?;
if c_refined < c_orig {
state.candidates[i] = x_refined;
state.costs[i] = c_refined;
}
}
if refine > 0 {
sort_population_ascending(&mut state.candidates, &mut state.costs);
}
Ok((state, None))
}
fn terminate(&self, state: &BasicPopulationState<V, F>) -> Option<TerminationReason> {
<De<F> as Solver<P, BasicPopulationState<V, F>>>::terminate(&self.de, state)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::nelder_mead::{NelderMead, Unbounded};

    /// Concrete `DeInject` instantiation exercised by the builder tests.
    type Subject = DeInject<NelderMead<Unbounded, f64>, Vec<f64>, f64>;

    /// Builds a `DeInject` around a default Nelder–Mead inner solver.
    fn subject() -> Subject {
        Subject::with_inner_solver(De::new(0), NelderMead::new())
    }

    #[test]
    #[should_panic(expected = "DeInject requires k >= 1")]
    fn with_k_zero_panics() {
        let _ = subject().with_k(0);
    }

    #[test]
    #[should_panic(expected = "DeInject requires refine_every >= 1")]
    fn with_refine_every_zero_panics() {
        let _ = subject().with_refine_every(0);
    }
}