use core::fmt;
use std::{
ops::{Add, AddAssign, Mul},
random::RandomSource,
sync::{Arc, RwLock, mpmc, mpsc},
thread::{self, ScopedJoinHandle},
};
#[allow(unused_imports)]
use crate::continuous::ContinuousTrainer;
use crate::{
Gate, PopulationStats, TrainingReportStrategy, num_cpus, random_choice_weighted_mapped_by_key,
};
/// Multi-threaded genetic trainer whose children inherit a "velocity" built
/// from their ancestors' mutations and fitness changes.
///
/// Evaluation work is pushed through a pipeline of scoped worker threads (see
/// `InertialTrainer::new`); results are merged into [`Self::gene_pool`].
pub struct InertialTrainer<'scope, G: InertialGenome> {
    /// Evaluated genomes, kept sorted by descending `current_fitness` (best
    /// first) and capped at `population_size` by the receiver thread.
    pub gene_pool: Arc<RwLock<Vec<RankedGenome<G>>>>,
    /// Number of jobs submitted so far, seed genomes included.
    pub children_created: usize,
    /// Mutation rate supplied at construction.
    /// NOTE(review): not read by the training code in this file — confirm intent.
    pub mutation_rate: f32,
    /// Scale applied to a parent's last fitness change when building the
    /// acceleration for its child.
    pub inertia: f32,
    /// Multiplier applied to every child's updated velocity.
    pub damping: f32,
    /// The only sender feeding the first pipeline stage; dropping it lets the
    /// pipeline wind down.
    work_submission: mpmc::Sender<FirstStageJob<G>>,
    // The join handles below are never read. They are only held by the
    // trainer; scoped threads are joined automatically when their scope ends.
    #[allow(unused)]
    first_stage_worker_pool: Vec<ScopedJoinHandle<'scope, ()>>,
    #[allow(unused)]
    second_stage_worker_pool: Vec<ScopedJoinHandle<'scope, ()>>,
    #[allow(unused)]
    receiver_thread: ScopedJoinHandle<'scope, ()>,
    /// Target (and maximum) number of genomes kept in `gene_pool`.
    population_size: usize,
    /// Count of submitted jobs whose results have not reached the pool yet.
    in_flight: Gate<usize>,
}
impl<'scope, G> InertialTrainer<'scope, G>
where
    G: InertialGenome,
{
    /// Builds a trainer and spawns its evaluation pipeline on `scope`.
    ///
    /// The pipeline consists of `num_cpus()` first-stage workers, `num_cpus()`
    /// second-stage workers and a single receiver thread that merges every
    /// evaluated genome into `gene_pool`. The trainer owns the only
    /// `work_submission` sender, so dropping the trainer closes the first-stage
    /// channel; each later stage then ends once its upstream senders are gone.
    ///
    /// * `population_size` - maximum number of genomes kept in `gene_pool` and
    ///   the target size used by [`Self::seed`].
    /// * `mutation_rate` - stored in the public `mutation_rate` field.
    ///   NOTE(review): nothing in this impl reads it; confirm whether mutations
    ///   from `G::create_mutation` were meant to be scaled by it.
    /// * `inertia` - scales a parent's last fitness change when computing the
    ///   acceleration in [`Self::train_custom`].
    /// * `damping` - multiplies the updated velocity in [`Self::train_custom`].
    /// * `scope` - thread scope the workers are spawned in.
    pub fn new(
        population_size: usize,
        mutation_rate: f32,
        inertia: f32,
        damping: f32,
        scope: &'scope thread::Scope<'scope, '_>,
    ) -> Self
    where
        G: 'scope + Send + Sync,
    {
        let in_flight = Gate::new(0);
        let (work_submission, inbox) = mpmc::channel();
        let (first_stage_outbox, second_stage_inbox) = mpmc::channel();
        let (outbox, work_reception) = mpsc::channel();
        let gene_pool = Arc::new(RwLock::new(Vec::new()));
        let first_stage_worker_pool = (0..num_cpus())
            .map(|_| {
                let inbox = inbox.clone();
                let outbox = outbox.clone();
                let first_stage_outbox = first_stage_outbox.clone();
                scope.spawn(move || Self::first_stage_worker(inbox, outbox, first_stage_outbox))
            })
            .collect();
        let second_stage_worker_pool = (0..num_cpus())
            .map(|_| {
                let second_stage_inbox = second_stage_inbox.clone();
                let outbox = outbox.clone();
                scope.spawn(move || Self::second_stage_worker(second_stage_inbox, outbox))
            })
            .collect();
        let receiver_thread = {
            let gene_pool = gene_pool.clone();
            let in_flight = in_flight.clone();
            scope.spawn(move || {
                Self::work_receiver_thread(work_reception, gene_pool, population_size, in_flight)
            })
        };
        // The original channel endpoints created above are dropped when this
        // function returns, so only the workers' clones keep the channels open.
        Self {
            gene_pool,
            children_created: 0,
            mutation_rate,
            inertia,
            damping,
            work_submission,
            first_stage_worker_pool,
            second_stage_worker_pool,
            receiver_thread,
            population_size,
            in_flight,
        }
    }

    /// First pipeline stage; runs until every `FirstStageJob` sender is dropped.
    ///
    /// * `NewGene` jobs are evaluated once and sent straight to the receiver.
    /// * `FirstStageAncestorEvaluation` jobs record the gene's current fitness
    ///   as `prior_fitness`, apply `mutation_to_apply`, and forward the result
    ///   to the second stage for re-evaluation.
    ///
    /// # Panics
    /// If the receiver or the second stage has hung up.
    fn first_stage_worker(
        inbox: mpmc::Receiver<FirstStageJob<G>>,
        outbox: mpsc::Sender<RankedGenome<G>>,
        first_stage_outbox: mpmc::Sender<SecondStageJob<G>>,
    ) {
        for job in inbox {
            match job {
                FirstStageJob::NewGene(gene) => {
                    let mut ranked_gene: RankedGenome<G> = gene.into();
                    ranked_gene.eval();
                    outbox.send(ranked_gene).unwrap();
                }
                FirstStageJob::FirstStageAncestorEvaluation {
                    mut gene,
                    mutation_to_apply,
                    velocity,
                } => {
                    // Fitness before the fresh mutation; the second stage
                    // measures it again afterwards.
                    let prior_fitness = gene.fitness();
                    gene.apply_mutation(&mutation_to_apply);
                    first_stage_outbox
                        .send(SecondStageJob {
                            gene,
                            prior_fitness,
                            recent_mutation: mutation_to_apply,
                            velocity,
                        })
                        .unwrap();
                }
            }
        }
    }

    /// Second pipeline stage: evaluates the mutated child and sends the fully
    /// ranked genome to the receiver. Runs until the first stage hangs up.
    ///
    /// # Panics
    /// If the receiver has hung up.
    fn second_stage_worker(
        second_stage_inbox: mpmc::Receiver<SecondStageJob<G>>,
        outbox: mpsc::Sender<RankedGenome<G>>,
    ) {
        for SecondStageJob {
            gene,
            prior_fitness,
            recent_mutation,
            velocity,
        } in second_stage_inbox
        {
            let current_fitness = gene.fitness();
            let ranked_gene = RankedGenome {
                gene,
                current_fitness,
                prior_fitness,
                velocity,
                recent_mutation,
            };
            outbox.send(ranked_gene).unwrap();
        }
    }

    /// Merges evaluated genomes into `gene_pool` until every sender is gone.
    ///
    /// The pool is kept sorted by descending `current_fitness` (best first,
    /// ordered with `f32::total_cmp`) and truncated to `max_population_size`.
    /// `in_flight` is decremented once per received genome.
    ///
    /// NOTE(review): under `total_cmp` a positive NaN sorts above every number
    /// and a negative NaN below, so a NaN fitness may rank as "best".
    fn work_receiver_thread(
        work_reception: mpsc::Receiver<RankedGenome<G>>,
        gene_pool: Arc<RwLock<Vec<RankedGenome<G>>>>,
        max_population_size: usize,
        in_flight: Gate<usize>,
    ) {
        for ranked_gene in work_reception {
            let mut gene_pool = gene_pool.write().unwrap();
            // The comparator is reversed (new vs. existing) to keep the pool in
            // descending order. Ties may land at any equal position.
            let insert_index = gene_pool
                .binary_search_by(|rg| ranked_gene.current_fitness.total_cmp(&rg.current_fitness))
                .unwrap_or_else(|i| i);
            gene_pool.insert(insert_index, ranked_gene);
            // Drop the weakest genomes beyond capacity (no-op when within it).
            gene_pool.truncate(max_population_size);
            in_flight.update(|x| *x = x.saturating_sub(1));
        }
    }

    /// Counts `job` as a child, marks it in flight, and queues it for the
    /// first stage.
    ///
    /// # Panics
    /// If the first-stage workers have all exited.
    fn submit_job(&mut self, job: FirstStageJob<G>) {
        self.children_created += 1;
        self.in_flight.update(|x| x.add_assign(1));
        self.work_submission.send(job).unwrap();
    }

    /// Submits enough freshly generated genomes to top `gene_pool` up to
    /// `population_size`.
    ///
    /// The pool size is read once at call time; genomes still in flight are not
    /// counted. Each submission increments `children_created`.
    pub fn seed<R>(&mut self, rng: &mut R)
    where
        R: RandomSource,
    {
        let current_gene_pool_size = self.gene_pool.read().unwrap().len();
        for _ in current_gene_pool_size..self.population_size {
            self.submit_job(FirstStageJob::NewGene(G::generate(rng)));
        }
    }

    /// Trains with [`default_reporting_strategy`] (one report every
    /// `population_size` children) and returns the best genome.
    ///
    /// `child_count` counts every submitted job, including seed genomes and
    /// children from earlier calls. The criterion `child_count <= num_children`
    /// is evaluated after each submission, so training stops at the first
    /// submission that pushes the count past `num_children`; at least one child
    /// is always submitted.
    /// NOTE(review): starting below `num_children`, this submits up to
    /// `num_children + 1` jobs in total — confirm whether `<` was intended.
    ///
    /// # Panics
    /// See [`Self::train_custom`].
    pub fn train<R>(&mut self, num_children: usize, rng: &mut R) -> G
    where
        R: RandomSource,
        G: Clone + Send + Sync + 'scope,
    {
        self.train_custom(
            |x| x.child_count <= num_children,
            Some(default_reporting_strategy(self.population_size)),
            rng,
        )
    }

    /// Seeds the pool, then keeps breeding children until `train_criteria`
    /// returns `false`, and returns a clone of the best genome.
    ///
    /// Each iteration does the following:
    /// 1. Sample a parent from the pool.
    /// 2. Update its velocity: `(velocity + recent_mutation * delta * inertia) * damping`,
    ///    where `delta` is the parent's last fitness change.
    /// 3. Apply the velocity to a copy of the parent's gene.
    /// 4. Submit that gene together with a fresh random mutation.
    /// 5. Call the optional reporting strategy.
    /// 6. Finally, consult `train_criteria` with the current metrics.
    ///
    /// All in-flight work is awaited before returning.
    ///
    /// # Panics
    /// If the gene pool is empty after seeding (e.g. `population_size == 0`),
    /// if a lock is poisoned, or if the worker pipeline has shut down.
    pub fn train_custom<R>(
        &mut self,
        mut train_criteria: impl FnMut(TrainingCriteriaMetrics) -> bool,
        mut reporting_strategy: Option<
            TrainingReportStrategy<
                impl FnMut(TrainingCriteriaMetrics) -> bool,
                impl FnMut(TrainingStats),
            >,
        >,
        rng: &mut R,
    ) -> G
    where
        R: RandomSource,
        G: Clone,
    {
        self.seed(rng);
        // Let every seed genome land in the pool before sampling parents.
        self.in_flight.wait_while(|x| *x > 0);
        loop {
            let parent = {
                let gene_pool = self.gene_pool.read().unwrap();
                let min_fitness = gene_pool
                    .iter()
                    .map(|x| x.current_fitness)
                    .min_by(|a, b| a.total_cmp(b))
                    .expect("gene pool is empty; population_size must be at least 1");
                // Presumably the helper weights each genome by the mapped key
                // (`current_fitness - min_fitness`), so the weakest gets weight
                // zero — confirm against its definition.
                random_choice_weighted_mapped_by_key(
                    &gene_pool,
                    rng,
                    |x| x - min_fitness,
                    |x| x.current_fitness,
                )
                .clone()
            };
            // Momentum update: the parent's last fitness change, scaled by
            // `inertia`, weights its last mutation into an acceleration that is
            // added to the inherited velocity and then damped.
            let delta_fitness = (parent.current_fitness - parent.prior_fitness) * self.inertia;
            let acceleration = parent.recent_mutation * delta_fitness;
            let velocity = (parent.velocity + acceleration) * self.damping;
            let mut gene = parent.gene;
            gene.apply_mutation(&velocity);
            let mutation_to_apply = G::create_mutation(rng);
            self.submit_job(FirstStageJob::FirstStageAncestorEvaluation {
                gene,
                mutation_to_apply,
                velocity,
            });
            let metrics = self.metrics();
            if let Some(reporting_strategy) = &mut reporting_strategy {
                if (reporting_strategy.should_report)(metrics) {
                    (reporting_strategy.report_callback)(self.stats());
                }
            }
            if !train_criteria(metrics) {
                break;
            }
        }
        // Drain outstanding children so the returned genome accounts for them.
        self.in_flight.wait_while(|x| *x > 0);
        self.gene_pool
            .read()
            .unwrap()
            .first()
            .expect("gene pool is empty; population_size must be at least 1")
            .gene
            .clone()
    }

    /// Snapshot of the pool's best, worst and median fitness plus the number of
    /// submitted children.
    ///
    /// # Panics
    /// If the gene pool is empty (e.g. before the first seeded genome arrives).
    pub fn metrics(&self) -> TrainingCriteriaMetrics {
        let gene_pool = self.gene_pool.read().unwrap();
        let best = gene_pool
            .first()
            .expect("metrics() requires a non-empty gene pool");
        let worst = gene_pool
            .last()
            .expect("metrics() requires a non-empty gene pool");
        TrainingCriteriaMetrics {
            max_fitness: best.current_fitness,
            min_fitness: worst.current_fitness,
            median_fitness: gene_pool[gene_pool.len() / 2].current_fitness,
            child_count: self.children_created,
        }
    }

    /// Collects the current fitness of every pooled genome into
    /// [`TrainingStats`], along with the number of submitted children.
    pub fn stats(&self) -> TrainingStats {
        TrainingStats {
            population_stats: self
                .gene_pool
                .read()
                .unwrap()
                .iter()
                .map(|x| x.current_fitness)
                .collect(),
            child_count: self.children_created,
        }
    }
}
/// Summary of the gene pool passed to training criteria and reporting checks.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrainingCriteriaMetrics {
    /// Fitness of the best (first) genome in the pool.
    pub max_fitness: f32,
    /// Fitness of the worst (last) genome in the pool.
    pub min_fitness: f32,
    /// Fitness of the genome at index `len / 2` of the pool.
    pub median_fitness: f32,
    /// Number of jobs submitted so far, seed genomes included.
    pub child_count: usize,
}
/// Population snapshot handed to a [`TrainingReportStrategy`] report callback.
#[derive(Debug, Clone, Copy)]
pub struct TrainingStats {
    /// Statistics built from the `current_fitness` of every pooled genome.
    pub population_stats: PopulationStats,
    /// Number of jobs submitted so far, seed genomes included.
    pub child_count: usize,
}
impl fmt::Display for TrainingStats {
    /// Renders as `child #<child_count> <population_stats>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            population_stats,
            child_count,
        } = self;
        write!(f, "child #{} {}", child_count, population_stats)
    }
}
/// Reporting strategy that prints [`TrainingStats`] whenever `child_count` is a
/// multiple of `n`.
///
/// An interval of `n == 0` means "never report" rather than panicking on a
/// remainder by zero.
pub fn default_reporting_strategy(
    n: usize,
) -> TrainingReportStrategy<impl FnMut(TrainingCriteriaMetrics) -> bool, impl FnMut(TrainingStats)>
{
    TrainingReportStrategy {
        // `checked_rem` yields `None` when `n == 0`, which never equals `Some(0)`.
        should_report: move |m: TrainingCriteriaMetrics| m.child_count.checked_rem(n) == Some(0),
        report_callback: |s| println!("{s}"),
    }
}
/// Work item consumed by the first pipeline stage.
enum FirstStageJob<G: InertialGenome> {
    /// A freshly generated genome that only needs one fitness evaluation.
    NewGene(G),
    /// A child derived from a parent. As built by `train_custom`, `gene`
    /// already has `velocity` applied; the worker measures its fitness, then
    /// applies `mutation_to_apply`.
    FirstStageAncestorEvaluation {
        gene: G,
        mutation_to_apply: G::MutationVector,
        velocity: G::MutationVector,
    },
}
/// Work item for the second pipeline stage: a mutated child awaiting its
/// post-mutation fitness evaluation.
struct SecondStageJob<G: InertialGenome> {
    /// The child after `recent_mutation` was applied.
    gene: G,
    /// Fitness measured before `recent_mutation` was applied.
    prior_fitness: f32,
    /// The mutation just applied to `gene`.
    recent_mutation: G::MutationVector,
    /// Velocity carried over to the child.
    velocity: G::MutationVector,
}
/// A genome plus the bookkeeping the trainer uses to rank it and to derive
/// momentum for its children.
#[derive(Debug, Clone)]
pub struct RankedGenome<G: InertialGenome> {
    gene: G,
    /// Fitness after `recent_mutation` was applied; the pool is ranked by it.
    current_fitness: f32,
    /// Fitness before `recent_mutation` was applied (0.0 for seed genomes).
    prior_fitness: f32,
    /// Velocity this genome inherited.
    velocity: G::MutationVector,
    /// The last random mutation applied to this genome.
    recent_mutation: G::MutationVector,
}
impl<G: InertialGenome> RankedGenome<G> {
    /// Recomputes `current_fitness` from the wrapped genome.
    fn eval(&mut self) {
        let score = G::fitness(&self.gene);
        self.current_fitness = score;
    }
}
impl<G: InertialGenome> From<G> for RankedGenome<G> {
    /// Wraps a genome with both fitness values at 0.0 and `Default` velocity
    /// and recent mutation. The first-stage worker fills in `current_fitness`
    /// afterwards via `eval`.
    fn from(gene: G) -> Self {
        Self {
            gene,
            current_fitness: 0.0,
            prior_fitness: 0.0,
            velocity: Default::default(),
            recent_mutation: Default::default(),
        }
    }
}
/// A genome that [`InertialTrainer`] can evolve.
///
/// Mutations are vectors that can be summed and scaled by an `f32`, which is
/// what lets the trainer accumulate a velocity from past mutations.
pub trait InertialGenome {
    /// A change that can be applied to a genome. Its `Default` value is used as
    /// the initial velocity and initial recent mutation of seed genomes.
    type MutationVector: Clone
        + Default
        + Send
        + Sync
        + Add<Output = Self::MutationVector>
        + Mul<f32, Output = Self::MutationVector>;

    /// Creates a new genome from `rng`; the trainer uses this when seeding.
    fn generate<R: RandomSource>(rng: &mut R) -> Self;

    /// Draws a random mutation from `rng`.
    fn create_mutation<R: RandomSource>(rng: &mut R) -> Self::MutationVector;

    /// Applies `mutation` to this genome in place.
    fn apply_mutation(&mut self, mutation: &Self::MutationVector);

    /// Scores the genome; the trainer ranks higher values as better.
    fn fitness(&self) -> f32;
}