use nalgebra::DVector;
use std::ops::Deref;
/// Selects how the initial weight vector is generated.
///
/// The default is [`Weights::Negative`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Weights {
    /// Only the leading `lambda / 2` ranks receive a weight, computed as
    /// `ln((lambda + 1) / 2) - ln(rank)`; all of these are strictly positive.
    Positive,
    /// Every one of the `lambda` ranks receives a weight from the same logarithmic formula as
    /// `Positive`, so ranks beyond the first `lambda / 2` end up with non-positive weights.
    #[default]
    Negative,
    /// The leading `lambda / 2` ranks all receive the same weight of `1.0`.
    Uniform,
}
/// Raw (not yet rescaled) weight vector together with the inputs it was derived from.
#[derive(Clone, Debug)]
pub(super) struct InitialWeights {
/// Raw weights, one entry per rank, in rank order.
weights: DVector<f64>,
/// The `Weights` variant the vector was generated for.
setting: Weights,
/// Size of the leading group of weights; set to `lambda / 2` by `new`.
mu: usize,
}
impl InitialWeights {
pub fn new(lambda: usize, setting: Weights) -> Self {
let mu = lambda / 2;
let weights: DVector<f64> = match setting {
Weights::Uniform => vec![1.0; mu],
Weights::Positive => (1..=mu)
.map(|i| ((lambda as f64 + 1.0) / 2.0).ln() - (i as f64).ln())
.collect::<Vec<_>>(),
Weights::Negative => (1..=lambda)
.map(|i| ((lambda as f64 + 1.0) / 2.0).ln() - (i as f64).ln())
.collect::<Vec<_>>(),
}
.into();
Self {
weights,
setting,
mu,
}
}
pub fn mu(&self) -> usize {
self.mu
}
pub fn mu_eff(&self) -> f64 {
self.weights.iter().take(self.mu).sum::<f64>().powi(2)
/ self
.weights
.iter()
.take(self.mu)
.map(|w| w.powi(2))
.sum::<f64>()
}
pub fn mu_eff_minus(&self) -> Option<f64> {
if self.weights.len() > self.mu {
Some(
self.weights.iter().skip(self.mu).sum::<f64>().powi(2)
/ self
.weights
.iter()
.skip(self.mu)
.map(|w| w.powi(2))
.sum::<f64>(),
)
} else {
None
}
}
pub fn finalize(self, dim: usize, c1: f64, cmu: f64) -> FinalWeights {
let mu_eff = self.mu_eff();
let mu_eff_minus = self.mu_eff_minus();
let mut weights = self.weights;
let sum_positive_weights = weights.iter().filter(|w| **w > 0.0).sum::<f64>();
for w in &mut weights {
if *w > 0.0 {
*w /= sum_positive_weights;
}
}
if let Some(mu_eff_minus) = mu_eff_minus {
let a_mu = 1.0 + c1 / cmu;
let a_mu_eff = 1.0 + (2.0 * mu_eff_minus) / (mu_eff + 2.0);
let a_pos_def = (1.0 - c1 - cmu) / (dim as f64 * cmu);
let a = a_mu.min(a_mu_eff.min(a_pos_def));
let sum_negative_weights = weights.iter().filter(|w| **w < 0.0).sum::<f64>().abs();
for w in &mut weights {
if *w < 0.0 {
*w = *w * a / sum_negative_weights;
}
}
}
FinalWeights {
weights,
setting: self.setting,
}
}
}
/// Rescaled weight vector as returned by `InitialWeights::finalize`.
#[derive(Clone, Debug)]
pub struct FinalWeights {
/// The rescaled weights, in the same order as the raw vector they were derived from.
weights: DVector<f64>,
/// The `Weights` variant the vector was originally generated for.
setting: Weights,
}
impl FinalWeights {
/// Returns the `Weights` variant these weights were generated for.
pub fn setting(&self) -> Weights {
self.setting
}
}
/// Gives read-only access to the underlying vector, so `DVector` methods such as `iter` can be
/// called directly on a `FinalWeights`.
impl Deref for FinalWeights {
type Target = DVector<f64>;
/// Borrows the wrapped weight vector.
fn deref(&self) -> &Self::Target {
&self.weights
}
}
#[cfg(test)]
mod tests {
    use assert_approx_eq::assert_approx_eq;

    use super::*;

    /// With `Weights::Positive`, every weight is strictly positive both before and after
    /// finalization, and the finalized weights sum to one.
    #[test]
    fn test_weights_positive() {
        for lambda in 4..200 {
            let raw = InitialWeights::new(lambda, Weights::Positive);
            assert!(raw.weights.iter().all(|&w| w > 0.0));

            let finalized = raw.finalize(6, 0.2, 0.8);
            assert!(finalized.weights.iter().all(|&w| w > 0.0));

            let total: f64 = finalized.iter().sum();
            assert_approx_eq!(total, 1.0, 1e-12);
        }
    }

    /// With `Weights::Negative`, the first `mu` weights are strictly positive and the rest are
    /// non-positive, before and after finalization; the first `mu` finalized weights sum to one.
    #[test]
    fn test_weights_negative() {
        for lambda in 4..200 {
            let raw = InitialWeights::new(lambda, Weights::Negative);
            let mu = raw.mu();

            // Sign pattern of the raw weights.
            assert!(raw.weights.iter().take(mu).all(|&w| w > 0.0));
            assert!(raw.weights.iter().skip(mu).all(|&w| w <= 0.0));

            let finalized = raw.finalize(4, 0.5, 0.5);

            // The leading group is normalized and the sign pattern survives finalization.
            let leading_total: f64 = finalized.iter().take(mu).sum();
            assert_approx_eq!(leading_total, 1.0, 1e-12);
            assert!(finalized.iter().take(mu).all(|&w| w > 0.0));
            assert!(finalized.iter().skip(mu).all(|&w| w <= 0.0));
        }
    }
}