use std::collections::VecDeque;
use crate::core::math::{
ComponentMulAssign, MatrixFromDiagonal, MatrixIdentity, Scalar, ScaleInPlace, VectorLen,
};
use crate::core::problem::EvalCounts;
use crate::core::state::{CountsMirror, PopulationState, State};
/// State for a CMA-ES optimizer run.
///
/// Generic over the vector type `V`, the matrix type `M`, and the scalar type `F`
/// (defaults to `f64`). Build with [`CmaEsState::new`] and optionally reshape the
/// initial distribution with [`CmaEsState::with_stds`].
pub struct CmaEsState<V, M, F = f64> {
/// Sampled candidate points, exposed through `PopulationState::candidates`.
/// `update_best` pairs `candidates[0]` with `costs[0]`.
pub(crate) candidates: Vec<V>,
/// Costs of `candidates`, exposed through `PopulationState::costs`.
/// `update_best` treats `costs[0]` as the best sample — presumably the solver keeps
/// this sorted ascending; TODO confirm.
pub(crate) costs: Vec<F>,
/// Distribution mean; returned by `mean()` and `State::param`.
pub(crate) m: V,
/// Cost at the mean; `None` until the mean has been evaluated (`State::cost` panics on `None`).
pub(crate) m_cost: Option<F>,
/// Global step size; `new` requires it to be `> 0`.
pub(crate) sigma: F,
/// Conventional CMA-ES step-size evolution path (p_σ); zero after `new`.
pub(crate) p_sigma: V,
/// Conventional CMA-ES covariance evolution path (p_c); zero after `new`.
pub(crate) p_c: V,
/// Covariance matrix: identity after `new`, `diag(stds²)` after `with_stds`.
pub(crate) c: M,
/// Identity after `new`; presumably the eigenvector basis of `c` — confirm against the update code.
pub(crate) b: M,
/// Per-axis scale factors: all ones after `new`, equal to `stds` after `with_stds`.
/// Presumably the square roots of the eigenvalues of `c`; `max_axis_std` returns its maximum.
pub(crate) d: V,
/// Set to the elementwise reciprocal of `d` by both `new` and `with_stds`.
pub(crate) d_inv: V,
/// Generation counter; 0 after `new`.
pub(crate) generation: u64,
/// Best point seen so far; `None` until `update_best` first records one
/// (`State::best_param` panics while `None`).
pub(crate) best_param: Option<V>,
/// Cost of `best_param`; `+inf` until a best is recorded.
pub(crate) best_cost: F,
/// Value of `iter` when the best was last improved.
pub(crate) best_iter: u64,
/// Value of `cost_evals` when the best was last improved.
pub(crate) best_cost_evals: u64,
/// Iteration counter, incremented by `State::increment_iter`.
pub(crate) iter: u64,
/// Cost-evaluation count; overwritten by `CountsMirror::mirror` with `total_work()`.
pub(crate) cost_evals: u64,
/// Optional bound-penalty bookkeeping; `None` after `new`.
pub(crate) penalty: Option<BoundPenalty<V, F>>,
}
/// Bookkeeping for a bound-constraint penalty attached to `CmaEsState::penalty`.
///
/// NOTE(review): none of these fields is read in the visible part of this file; the
/// descriptions below are inferred from names and should be confirmed.
pub(crate) struct BoundPenalty<V, F = f64> {
/// Per-coordinate vector — presumably the penalty weights; TODO confirm.
pub(crate) gamma: V,
/// Whether `gamma` has been initialized — presumably set once initial weights are chosen; TODO confirm.
pub(crate) weights_initialized: bool,
/// Queue of past scalar values — presumably a cost-statistic history used to adapt `gamma`; TODO confirm.
pub(crate) hist: VecDeque<F>,
/// Presumably the unpenalized costs of the current candidates; TODO confirm.
pub(crate) raw_costs: Vec<F>,
}
impl<V, M, F> CmaEsState<V, M, F>
where
    // `ScaleInPlace<F>` is no longer needed by the body but is kept so the set of
    // accepted `V` types (and thus the public API) is unchanged.
    V: VectorLen + Clone + ScaleInPlace<F> + std::ops::IndexMut<usize, Output = F>,
    M: MatrixIdentity,
    F: Scalar,
{
    /// Creates a state centred on `mean` with global step size `sigma`.
    ///
    /// After construction:
    /// - the covariance `c` and basis `b` are the `n × n` identity;
    /// - the per-axis scales `d` and `d_inv` are all ones;
    /// - both evolution paths are exactly zero;
    /// - no mean cost or best point is recorded yet;
    /// - all counters are 0.
    ///
    /// # Panics
    ///
    /// Panics if `sigma > 0` does not hold (NaN included) or if `mean` is empty.
    pub fn new(mean: V, sigma: F) -> Self {
        assert!(
            sigma > F::zero(),
            "CmaEsState requires sigma > 0, got {:?}",
            sigma
        );
        let n = mean.vec_len();
        assert!(n >= 1, "CmaEsState requires a non-empty mean");
        // `mean` is cloned only to get vectors of the right length; every entry is
        // overwritten below. Zeros are assigned explicitly rather than by scaling the
        // clone by zero: `0 * x` is NaN for infinite/NaN `x` and `-0` for negative `x`,
        // so scaling would leak the mean's values into the evolution paths.
        let mut p_sigma = mean.clone();
        let mut d = mean.clone();
        for i in 0..n {
            p_sigma[i] = F::zero();
            d[i] = F::one();
        }
        let p_c = p_sigma.clone();
        let d_inv = d.clone();
        Self {
            candidates: Vec::new(),
            costs: Vec::new(),
            m: mean,
            m_cost: None,
            sigma,
            p_sigma,
            p_c,
            c: M::identity(n),
            b: M::identity(n),
            d,
            d_inv,
            generation: 0,
            best_param: None,
            best_cost: F::infinity(),
            best_iter: 0,
            best_cost_evals: 0,
            iter: 0,
            cost_evals: 0,
            penalty: None,
        }
    }
}
impl<V, M, F> CmaEsState<V, M, F>
where
    V: VectorLen + Clone + ComponentMulAssign + std::ops::IndexMut<usize, Output = F>,
    M: MatrixFromDiagonal<V>,
    F: Scalar,
{
    /// Reshapes the initial search distribution with per-coordinate standard deviations.
    ///
    /// Sets the covariance to `diag(stds[i]²)`, `d` to `stds`, and `d_inv` to the
    /// elementwise reciprocal of `stds`. The mean and `sigma` are untouched.
    ///
    /// # Panics
    ///
    /// - Panics if `stds` and the mean differ in length.
    /// - Panics on the first entry for which `stds[i] > 0` does not hold (NaN included).
    pub fn with_stds(mut self, stds: V) -> Self {
        let dim = self.m.vec_len();
        let given = stds.vec_len();
        assert_eq!(
            given,
            dim,
            "CmaEsState::with_stds requires stds.len() == mean.len(), got {} vs {}",
            given,
            dim
        );

        // Locate the first entry that is not strictly positive; NaN compares false and
        // is therefore rejected as well.
        let first_invalid = (0..dim)
            .map(|i| stds[i] > F::zero())
            .position(|positive| !positive);
        if let Some(bad) = first_invalid {
            panic!(
                "CmaEsState::with_stds requires every std > 0, got stds[{}] = {:?}",
                bad, stds[bad]
            );
        }

        let mut variances = stds.clone();
        variances.component_mul_assign(&stds);
        self.c = M::from_diagonal(&variances);

        for axis in 0..dim {
            let s = stds[axis];
            self.d[axis] = s;
            self.d_inv[axis] = F::one() / s;
        }
        self
    }
}
impl<V, M, F> CmaEsState<V, M, F>
where
    V: VectorLen + std::ops::Index<usize, Output = F>,
    F: Scalar,
{
    /// Borrows the current distribution mean.
    pub fn mean(&self) -> &V {
        &self.m
    }

    /// Returns the current global step size.
    pub fn sigma(&self) -> F {
        self.sigma
    }

    /// Largest entry of the per-axis scale vector `d`, floored at zero.
    ///
    /// Entries for which `entry > current_max` is false are skipped, so NaN entries
    /// are ignored.
    pub(crate) fn max_axis_std(&self) -> F {
        (0..self.d.vec_len())
            .map(|axis| self.d[axis])
            .fold(F::zero(), |largest, entry| {
                if entry > largest {
                    entry
                } else {
                    largest
                }
            })
    }
}
impl<V: Clone, M, F: Scalar> State for CmaEsState<V, M, F> {
    type Param = V;
    type Float = F;

    /// Current iteration count.
    fn iter(&self) -> u64 {
        self.iter
    }

    /// Advances the iteration counter by one.
    fn increment_iter(&mut self) {
        self.iter += 1;
    }

    /// Number of cost evaluations, as last mirrored from the problem's counters.
    fn cost_evals(&self) -> u64 {
        self.cost_evals
    }

    /// The current parameter is the distribution mean.
    fn param(&self) -> &V {
        &self.m
    }

    /// Cost at the distribution mean.
    ///
    /// # Panics
    ///
    /// Panics if the mean has not been evaluated yet.
    fn cost(&self) -> F {
        self.m_cost
            .expect("CmaEsState::cost read before Solver::init evaluated the mean")
    }

    /// Best point recorded so far.
    ///
    /// # Panics
    ///
    /// Panics if `update_best` has not yet recorded any point.
    fn best_param(&self) -> &V {
        self.best_param
            .as_ref()
            .expect("CmaEsState::best_param read before Solver::init populated it")
    }

    /// Cost of the best point; `+inf` before any point is recorded.
    fn best_cost(&self) -> F {
        self.best_cost
    }

    /// Iteration at which the best point was last improved.
    fn best_iter(&self) -> u64 {
        self.best_iter
    }

    /// Cost-evaluation count at which the best point was last improved.
    fn best_cost_evals(&self) -> u64 {
        self.best_cost_evals
    }

    /// Updates the best point from the first sampled candidate and the mean.
    ///
    /// The first recorded point is accepted unconditionally so that `best_param`
    /// becomes available. After that, a point replaces the incumbent when it is
    /// strictly cheaper. A non-NaN cost also replaces a NaN incumbent.
    ///
    /// `costs[0]` / `candidates[0]` are treated as the best sample; this presumably
    /// relies on the solver keeping the population sorted by cost — TODO confirm.
    fn update_best(&mut self) {
        // `x <= +inf` holds for every non-NaN value (including -inf and +inf), so this is
        // a NaN test that needs nothing beyond the ordering `Scalar` already provides.
        let is_number = |x: F| x <= F::infinity();
        // Plain `<` is always false against a NaN incumbent. Without the second clause,
        // one NaN cost would freeze the best point for the rest of the run.
        let beats = |candidate: F, incumbent: F| {
            candidate < incumbent || (is_number(candidate) && !is_number(incumbent))
        };

        if let Some(&best_sample_cost) = self.costs.first() {
            if self.best_param.is_none() || beats(best_sample_cost, self.best_cost) {
                self.best_param = Some(self.candidates[0].clone());
                self.best_cost = best_sample_cost;
                self.best_iter = self.iter;
                self.best_cost_evals = self.cost_evals;
            }
        }
        if let Some(mc) = self.m_cost {
            if self.best_param.is_none() || beats(mc, self.best_cost) {
                self.best_param = Some(self.m.clone());
                self.best_cost = mc;
                self.best_iter = self.iter;
                self.best_cost_evals = self.cost_evals;
            }
        }
    }

    /// Forgets the recorded best point and restores the initial `+inf` cost.
    fn reset_best(&mut self) {
        self.best_param = None;
        self.best_cost = F::infinity();
        self.best_iter = 0;
        self.best_cost_evals = 0;
    }
}
impl<V, M, F> CountsMirror for CmaEsState<V, M, F>
where
    Self: State,
{
    /// Overwrites `cost_evals` with `delta.total_work()`.
    ///
    /// NOTE(review): the value is assigned, not added, despite the parameter name
    /// `delta`. Confirm that `EvalCounts` carries cumulative totals here.
    fn mirror(&mut self, delta: &EvalCounts) {
        let total = delta.total_work();
        self.cost_evals = total;
    }
}
impl<V, M, F> PopulationState for CmaEsState<V, M, F>
where
    V: Clone,
    F: Scalar,
{
    /// Currently stored candidate points, borrowed as a slice.
    fn candidates(&self) -> &[V] {
        self.candidates.as_slice()
    }

    /// Currently stored candidate costs, borrowed as a slice.
    fn costs(&self) -> &[F] {
        self.costs.as_slice()
    }
}