use super::Gas;
use crate::candidate::Candidate;
#[mockall_double::double]
use crate::rando::Rando;
use std::sync::{
atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering},
Arc, RwLock,
};
/// Live, shareable view of a running [`Gas::cycle`].
///
/// Every field is an `Arc`, so a handle can be cloned and read from another
/// thread (e.g. for a status display) while `cycle` writes to it.
#[cfg_attr(test, allow(dead_code))]
pub struct CycleProgress<const N: usize, const NSYMS: usize> {
    /// Index of the generation `cycle` is currently processing.
    pub iteration: Arc<AtomicUsize>,
    /// Total score of the current top candidate, rounded to an integer.
    pub score: Arc<AtomicIsize>,
    /// Violation count of the current top candidate.
    pub violations: Arc<AtomicUsize>,
    /// Completion percentage of the final (stagnated) sampling phase.
    pub progress: Arc<AtomicUsize>,
    /// Most recently published top candidate; holds the cycle winner once `cycle` returns.
    pub top: Arc<RwLock<Candidate<N, NSYMS>>>,
    /// Interrupt flag polled once per generation; when set, `cycle` stops early.
    pub sigint: Arc<AtomicBool>,
    /// Number of candidates collected in the seed pool during seeding.
    pub seed_pool_size: Arc<AtomicUsize>,
    /// Violation count of the first seed-pool entry during seeding.
    pub diversity_violations: Arc<AtomicUsize>,
}
#[cfg_attr(test, allow(dead_code))]
impl<const N: usize, const NSYMS: usize> CycleProgress<N, NSYMS> {
    /// Creates a progress tracker with every counter set to zero.
    ///
    /// `top` starts out as the candidate built from an all-zero chromosome for
    /// `gas`. `sigint` is shared with the caller (the `Arc` is cloned, not the
    /// flag), so setting the caller's flag is observed by the cycle.
    pub fn new(gas: &Gas<N, NSYMS>, sigint: &Arc<AtomicBool>) -> CycleProgress<N, NSYMS> {
        CycleProgress {
            iteration: Arc::new(AtomicUsize::new(0)),
            score: Arc::new(AtomicIsize::new(0)),
            violations: Arc::new(AtomicUsize::new(0)),
            progress: Arc::new(AtomicUsize::new(0)),
            seed_pool_size: Arc::new(AtomicUsize::new(0)),
            diversity_violations: Arc::new(AtomicUsize::new(0)),
            top: Arc::new(RwLock::new(Candidate::from_chromosone(gas, [0; N]))),
            sigint: Arc::clone(sigint),
        }
    }

    /// Writes a one-line status prefix to stderr, followed by a tab and no newline.
    ///
    /// Format: `[iteration: score/violations/seed_pool_size-diversity_violations progress%]`.
    /// Counters are read with relaxed ordering, so the values are a best-effort
    /// snapshot and may not all come from the same generation.
    pub fn eprint(&self) {
        eprint!(
            "[{}: {}/{}/{}-{} {}%]\t",
            self.iteration.load(Ordering::Relaxed),
            self.score.load(Ordering::Relaxed),
            self.violations.load(Ordering::Relaxed),
            self.seed_pool_size.load(Ordering::Relaxed),
            self.diversity_violations.load(Ordering::Relaxed),
            self.progress.load(Ordering::Relaxed),
        );
    }
}

/// Cloning produces another handle onto the *same* shared state: every field's
/// `Arc` is cloned, so updates made through one handle are visible through all.
impl<const N: usize, const NSYMS: usize> Clone for CycleProgress<N, NSYMS> {
    fn clone(&self) -> Self {
        CycleProgress {
            iteration: Arc::clone(&self.iteration),
            score: Arc::clone(&self.score),
            violations: Arc::clone(&self.violations),
            progress: Arc::clone(&self.progress),
            seed_pool_size: Arc::clone(&self.seed_pool_size),
            diversity_violations: Arc::clone(&self.diversity_violations),
            top: Arc::clone(&self.top),
            sigint: Arc::clone(&self.sigint),
        }
    }
}
impl<const N: usize, const NSYMS: usize> Gas<N, NSYMS> {
    /// Runs one complete optimisation cycle and returns its winner.
    ///
    /// The cycle moves through three phases:
    ///
    /// 1. **Seeding**: random populations are evolved until their top candidate
    ///    has zero violations, or has kept the same violation count for more than
    ///    `VIOLATIONS_STAGNATION_THRESHOLD` generations while being no worse than
    ///    the first seed-pool entry. Such a candidate is added to the seed pool.
    ///    A seed with fewer violations than the pool's first entry empties the
    ///    pool first. A fresh random population is then started. Once the pool
    ///    holds `population_size` candidates it becomes the population.
    /// 2. **Running**: the seeded population is evolved while the best total
    ///    score is tracked. When a fast EMA of the top score falls below a slow
    ///    EMA, and the top violation count has been unchanged for more than the
    ///    threshold, the run is marked as stagnated.
    /// 3. **Stagnated**: evolution continues for about `SAMPLING_LENGTH` times
    ///    the iteration at which stagnation was detected. Distinct candidates
    ///    that beat the best Running score are collected, until the winner list
    ///    is as long as the population.
    ///
    /// The loop also stops after `MAX_ITERATIONS` generations, or as soon as
    /// `progress.sigint` is set. The collected winners then go through
    /// `self.final_tournament`. The result is published to `progress.top` and
    /// returned. `progress` counters are refreshed every generation so another
    /// thread can display them.
    ///
    /// # Panics
    ///
    /// Panics if `self.population_size` is zero, or if the `progress.top` lock
    /// is poisoned when the final winner is written.
    #[cfg_attr(test, allow(dead_code))]
    pub fn cycle(&self, progress: &mut CycleProgress<N, NSYMS>) -> Candidate<N, NSYMS> {
        // Smoothing factors of the fast and slow exponential moving averages of
        // the top score; a fast EMA below the slow one signals a flattening score.
        const EMA_FAST_CONST: f64 = 0.995;
        const EMA_SLOW_CONST: f64 = 0.9995;
        // Generations with an unchanged top violation count after which the
        // search is considered stuck at that count.
        const VIOLATIONS_STAGNATION_THRESHOLD: usize = 100;
        // Length of the Stagnated phase, as a multiple of the stagnation iteration.
        const SAMPLING_LENGTH: usize = 3;
        // Hard cap on generations per cycle.
        const MAX_ITERATIONS: usize = 2 << 20;

        enum State {
            Seeding,
            Running,
            Stagnated,
        }

        assert!(
            self.population_size > 0,
            "Gas::cycle requires a non-zero population_size"
        );

        let score_weights = self.fitness.weights();
        let mut rng = Rando::new();
        let mut population = Vec::<Candidate<N, NSYMS>>::with_capacity(self.population_size);
        population.extend((0..self.population_size).map(|_| Candidate::new(self, &mut rng)));

        let initial_score = population[0].total_score(&score_weights);
        // NOTE(review): the EMAs are not updated during Seeding, so Running starts
        // them from the score of the initial random population. Confirm this
        // warm-up behaviour is intended.
        let mut ema_fast = initial_score;
        let mut ema_slow = initial_score;
        let mut cur_violations = population[0].violations;
        let mut n_cur_violations = 1usize;
        // Start below any reachable score so that the first Running generation
        // always records its top candidate, even when scores are zero or negative.
        let mut best_score = f64::NEG_INFINITY;
        let mut stagnation_iteration = 1usize;
        let mut state = State::Seeding;

        // winners[0] is overwritten with the best Running candidate. Later entries
        // are the stagnation-point candidate and the Stagnated-phase samples.
        let mut winners = Vec::<Candidate<N, NSYMS>>::with_capacity(population.len());
        winners.push(population[0].clone());
        // NOTE(review): the initial random candidate seeds the pool so that
        // seed_pool[0] is always valid. It is only evicted by a seed with strictly
        // fewer violations.
        let mut seed_pool = vec![population[0].clone()];

        for i in 0..MAX_ITERATIONS {
            progress.iteration.store(i, Ordering::Relaxed);
            population = self.generation(&population, &mut rng, &score_weights);
            let ts = population[0].total_score(&score_weights);
            progress.score.store(ts.round() as isize, Ordering::Relaxed);
            progress
                .violations
                .store(population[0].violations, Ordering::Relaxed);
            // Best-effort publication: skip this generation if the lock is busy.
            if let Ok(mut top) = progress.top.try_write() {
                *top = population[0].clone();
            }

            match state {
                State::Seeding => {
                    let top_violations = population[0].violations;
                    let is_seed = top_violations == 0
                        || (n_cur_violations > VIOLATIONS_STAGNATION_THRESHOLD
                            && top_violations <= seed_pool[0].violations);
                    if is_seed {
                        if top_violations < seed_pool[0].violations {
                            seed_pool.clear();
                        }
                        seed_pool.push(population[0].clone());
                        // `>=` (not `==`): the pool starts non-empty, so with a
                        // population_size of 1 an equality test would never fire.
                        if seed_pool.len() >= self.population_size {
                            population = seed_pool.clone();
                            state = State::Running;
                        } else {
                            population.clear();
                            population.extend(
                                (0..self.population_size).map(|_| Candidate::new(self, &mut rng)),
                            );
                            cur_violations = population[0].violations;
                            n_cur_violations = 0;
                        }
                    } else if top_violations < cur_violations {
                        cur_violations = top_violations;
                        n_cur_violations = 1;
                    } else if top_violations == cur_violations {
                        n_cur_violations += 1;
                    }
                    progress
                        .seed_pool_size
                        .store(seed_pool.len(), Ordering::Relaxed);
                    progress
                        .diversity_violations
                        .store(seed_pool[0].violations, Ordering::Relaxed);
                }
                State::Running => {
                    if ts > best_score {
                        best_score = ts;
                        winners[0] = population[0].clone();
                    }
                    ema_fast = ema_fast * EMA_FAST_CONST + ts * (1.0 - EMA_FAST_CONST);
                    ema_slow = ema_slow * EMA_SLOW_CONST + ts * (1.0 - EMA_SLOW_CONST);
                    if population[0].violations == cur_violations {
                        n_cur_violations += 1;
                    } else {
                        cur_violations = population[0].violations;
                        n_cur_violations = 1;
                    }
                    if ema_fast < ema_slow && n_cur_violations > VIOLATIONS_STAGNATION_THRESHOLD {
                        state = State::Stagnated;
                        stagnation_iteration = i;
                        winners.push(population[0].clone());
                    }
                }
                State::Stagnated => {
                    // best_score is frozen at the Running best, so only candidates
                    // that beat it, and are not already present, are sampled.
                    let is_new_winner = ts > best_score
                        && !winners
                            .iter()
                            .any(|w| w.chromosone == population[0].chromosone);
                    if is_new_winner {
                        winners.push(population[0].clone());
                        if winners.len() >= population.len() {
                            break;
                        }
                    }
                    if i > (SAMPLING_LENGTH + 1) * stagnation_iteration {
                        break;
                    }
                    // stagnation_iteration >= 1 here: Running never happens at i == 0.
                    let time_pct = (i - stagnation_iteration) * 100
                        / (SAMPLING_LENGTH * stagnation_iteration);
                    let winners_pct = winners.len() * 100 / population.len();
                    progress
                        .progress
                        .store(time_pct.max(winners_pct), Ordering::Relaxed);
                }
            }

            if progress.sigint.load(Ordering::Relaxed) {
                break;
            }
        }

        let (winner, _) = self
            .final_tournament
            .run(&winners, &mut rng, &score_weights);
        *progress
            .top
            .write()
            .expect("CycleProgress::top lock poisoned") = winner.clone();
        winner
    }
}