use std::{
fmt,
ops::AddAssign,
random::RandomSource,
sync::{Arc, RwLock, mpmc, mpsc},
thread::{self, ScopedJoinHandle},
};
use crate::{
Gate, Genome, PopulationStats, TrainingReportStrategy, num_cpus, random_choice_weighted_mapped,
random_f32,
};
/// A genetic-algorithm trainer that breeds one child at a time and evaluates fitness on
/// background threads spawned inside a `thread::Scope`.
///
/// Worker threads compute each submitted gene's fitness. A dedicated receiver thread merges
/// the results into `gene_pool`, keeping it sorted best-first and capped at the population size.
pub struct ContinuousTrainer<'scope, G> {
    /// Evaluated genes paired with their fitness, sorted by descending fitness.
    pub gene_pool: Arc<RwLock<Vec<(G, f32)>>>,
    /// Rate handed to `Genome::mutate` when a child is bred by mutation.
    pub mutation_rate: f32,
    /// Threshold compared against `random_f32` to choose crossbreeding over mutation
    /// (presumably a probability in `[0, 1]` — confirm against `random_f32`'s range).
    pub reproduction_type_proportion: f32,
    /// Number of genes submitted for evaluation so far, seeded genes included.
    pub children_created: usize,
    /// Sending half of the rendezvous channel that feeds genes to the workers.
    work_submission: mpmc::Sender<G>,
    // Never read: the enclosing `thread::Scope` joins these threads on its own.
    #[allow(unused)]
    worker_pool: Vec<ScopedJoinHandle<'scope, ()>>,
    // Never read, for the same reason as `worker_pool`.
    #[allow(unused)]
    receiver_thread: ScopedJoinHandle<'scope, ()>,
    /// Maximum number of genes retained in `gene_pool`.
    population_size: usize,
    /// Count of submitted genes whose results have not yet been merged into the pool.
    in_flight: Gate<usize>,
}
impl<'scope, G> ContinuousTrainer<'scope, G> {
    /// Creates a trainer and spawns its background threads on `scope`.
    ///
    /// One evaluation worker is spawned per CPU reported by `num_cpus()`, plus a single
    /// receiver thread. The receiver inserts evaluated genes into `gene_pool`, keeping the
    /// pool sorted by descending fitness and capped at `population_size` entries.
    ///
    /// The threads finish once this trainer, which owns the only job sender, is dropped.
    pub fn new(
        population_size: usize,
        mutation_rate: f32,
        reproduction_type_proportion: f32,
        scope: &'scope thread::Scope<'scope, '_>,
    ) -> Self
    where
        G: Genome + 'scope + Send + Sync,
    {
        let in_flight = Gate::new(0);
        // Bound 0 makes this a rendezvous channel: `submit_job` blocks until a worker takes
        // the gene, which throttles breeding to the evaluation throughput.
        let (work_submission, inbox) = mpmc::sync_channel(0);
        let (outbox, work_reception) = mpsc::channel();
        let gene_pool = Arc::new(RwLock::new(Vec::new()));
        let worker_pool = (0..num_cpus())
            .map(|_| {
                let inbox = inbox.clone();
                let outbox = outbox.clone();
                scope.spawn(move || Self::worker_thread(inbox, outbox))
            })
            .collect();
        let receiver_thread = {
            let gene_pool = gene_pool.clone();
            let in_flight = in_flight.clone();
            scope.spawn(move || {
                Self::work_receiver_thread(work_reception, gene_pool, population_size, in_flight)
            })
        };
        Self {
            gene_pool,
            work_submission,
            worker_pool,
            receiver_thread,
            mutation_rate,
            population_size,
            in_flight,
            children_created: 0,
            reproduction_type_proportion,
        }
    }

    /// Worker loop: computes each received gene's fitness and forwards `(gene, fitness)` to
    /// the receiver thread. Returns once every job sender has been dropped.
    fn worker_thread(inbox: mpmc::Receiver<G>, outbox: mpsc::Sender<(G, f32)>)
    where
        G: Genome,
    {
        for gene in inbox {
            let fitness = gene.fitness();
            outbox.send((gene, fitness)).unwrap();
        }
    }

    /// Receiver loop: merges each evaluated gene into `gene_pool` and marks it as done.
    ///
    /// The pool stays sorted by descending fitness (under `f32::total_cmp`) and is truncated
    /// to `max_population_size`, so the weakest genes fall off the end.
    fn work_receiver_thread(
        work_reception: mpsc::Receiver<(G, f32)>,
        gene_pool: Arc<RwLock<Vec<(G, f32)>>>,
        max_population_size: usize,
        in_flight: Gate<usize>,
    ) {
        for (gene, score) in work_reception {
            let mut gene_pool = gene_pool.write().unwrap();
            // The comparator is deliberately reversed (`score` vs. element), so the search
            // treats the pool as descending: best gene at index 0, worst at the end.
            let insert_index = gene_pool
                .binary_search_by(|x| score.total_cmp(&x.1))
                .unwrap_or_else(|i| i);
            gene_pool.insert(insert_index, (gene, score));
            gene_pool.truncate(max_population_size);
            in_flight.update(|x| *x = x.saturating_sub(1));
        }
    }

    /// Queues `gene` for fitness evaluation.
    ///
    /// Increments `children_created` and the in-flight counter, then blocks until a worker
    /// thread accepts the gene.
    pub fn submit_job(&mut self, gene: G) {
        self.children_created += 1;
        // Count the job as in flight *before* handing it off, so `wait_while` can never
        // observe zero while this gene is still being evaluated.
        self.in_flight.update(|x| x.add_assign(1));
        self.work_submission.send(gene).unwrap();
    }

    /// Submits freshly generated genes until the pool size plus this batch reaches
    /// `population_size`.
    ///
    /// Only genes already in the pool are counted; jobs still in flight are not taken into
    /// account.
    pub fn seed<R>(&mut self, rng: &mut R)
    where
        R: RandomSource,
        G: Genome,
    {
        let current_gene_pool_size = self.gene_pool.read().unwrap().len();
        for _ in current_gene_pool_size..self.population_size {
            self.submit_job(G::generate(rng));
        }
    }

    /// Trains until `children_created` reaches `num_children` and returns a clone of the
    /// fittest gene.
    ///
    /// Progress is reported through `default_reporting_strategy(population_size)`.
    /// `children_created` is cumulative across calls and includes seeded genes. At least
    /// one child is always bred per call (see [`Self::train_custom`]).
    pub fn train<R>(&mut self, num_children: usize, rng: &mut R) -> G
    where
        R: RandomSource,
        G: Clone + Genome + Send + Sync + 'scope,
    {
        self.train_custom(
            // `<` rather than `<=`: the criterion is checked after each submission, so `<=`
            // would breed one child more than requested.
            |x| x.child_count < num_children,
            Some(default_reporting_strategy(self.population_size)),
            rng,
        )
    }

    /// Runs the breeding loop and returns a clone of the fittest gene.
    ///
    /// The pool is first topped up with [`Self::seed`], and all outstanding evaluations are
    /// awaited. Each iteration then:
    /// 1. breeds one child, either by crossbreeding two parents (when
    ///    `random_f32(rng) < reproduction_type_proportion`) or by mutating a clone of one
    ///    parent;
    /// 2. submits the child;
    /// 3. calls `train_criteria` with the resulting metrics, stopping when it returns `false`.
    ///
    /// Because the criterion is checked after submission, at least one child is always bred.
    /// If `reporting_strategy` is `Some`, its `report_callback` receives [`Self::stats`]
    /// whenever its `should_report` accepts the current metrics.
    ///
    /// # Panics
    ///
    /// Panics if the gene pool is empty after seeding (e.g. `population_size == 0`).
    pub fn train_custom<R>(
        &mut self,
        mut train_criteria: impl FnMut(TrainingCriteriaMetrics) -> bool,
        mut reporting_strategy: Option<
            TrainingReportStrategy<
                impl FnMut(TrainingCriteriaMetrics) -> bool,
                impl FnMut(TrainingStats),
            >,
        >,
        rng: &mut R,
    ) -> G
    where
        R: RandomSource,
        G: Clone + Genome + Send + Sync + 'scope,
    {
        self.seed(rng);
        self.in_flight.wait_while(|x| *x > 0);
        loop {
            let new_child = {
                let gene_pool = self.gene_pool.read().unwrap();
                // The pool is kept sorted by descending fitness, so its last entry holds the
                // minimum (the same invariant `metrics` relies on).
                let min_fitness = gene_pool
                    .last()
                    .expect("gene pool is empty; population_size must be greater than zero")
                    .1;
                let should_crossbreed = random_f32(rng) < self.reproduction_type_proportion;
                // Weights are shifted by the pool minimum (assumes the chooser feeds each
                // entry's fitness to this mapper — confirm), so the worst gene maps to zero.
                let mut choose_parent =
                    || random_choice_weighted_mapped(&gene_pool, rng, |x| x - min_fitness);
                if should_crossbreed {
                    let mother = choose_parent();
                    let father = choose_parent();
                    mother.crossbreed(father, rng)
                } else {
                    let mut new_child = choose_parent().clone();
                    new_child.mutate(self.mutation_rate, rng);
                    new_child
                }
            };
            self.submit_job(new_child);
            let metrics = self.metrics();
            if let Some(strategy) = reporting_strategy.as_mut() {
                if (strategy.should_report)(metrics) {
                    (strategy.report_callback)(self.stats());
                }
            }
            if !train_criteria(metrics) {
                break;
            }
        }
        // Let every submitted child be evaluated and merged before picking the winner.
        self.in_flight.wait_while(|x| *x > 0);
        self.gene_pool.read().unwrap().first().unwrap().0.clone()
    }

    /// Returns the pool's max/min/median fitness and the number of children created so far.
    ///
    /// # Panics
    ///
    /// Panics if the gene pool is empty.
    pub fn metrics(&self) -> TrainingCriteriaMetrics {
        let gene_pool = self.gene_pool.read().unwrap();
        // The pool is sorted by descending fitness: best first, worst last.
        TrainingCriteriaMetrics {
            max_fitness: gene_pool.first().unwrap().1,
            min_fitness: gene_pool.last().unwrap().1,
            median_fitness: gene_pool[gene_pool.len() / 2].1,
            child_count: self.children_created,
        }
    }

    /// Collects the fitness of every gene currently in the pool into a [`TrainingStats`].
    pub fn stats(&self) -> TrainingStats {
        TrainingStats {
            population_stats: self.gene_pool.read().unwrap().iter().map(|x| x.1).collect(),
            child_count: self.children_created,
        }
    }
}
/// Summary of the gene pool handed to training-criteria and reporting callbacks.
///
/// Derives `PartialEq` so snapshots can be compared directly (e.g. in tests or in
/// criteria that detect stagnation).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrainingCriteriaMetrics {
    /// Fitness of the best gene (first entry of the descending-sorted pool).
    pub max_fitness: f32,
    /// Fitness of the worst gene (last entry of the descending-sorted pool).
    pub min_fitness: f32,
    /// Fitness of the entry at index `len / 2` of the descending-sorted pool.
    pub median_fitness: f32,
    /// Number of genes submitted for evaluation so far, seeded genes included.
    pub child_count: usize,
}
/// Snapshot of the pool's fitness statistics and the running child count, as passed to
/// reporting callbacks.
#[derive(Debug, Clone, Copy)]
pub struct TrainingStats {
    /// Statistics collected from the fitness scores currently in the gene pool.
    pub population_stats: PopulationStats,
    /// Number of genes submitted for evaluation so far, seeded genes included.
    pub child_count: usize,
}
/// Renders as `child #<child_count> <population_stats>`.
impl fmt::Display for TrainingStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { population_stats, child_count } = self;
        write!(f, "child #{child_count} {population_stats}")
    }
}
/// Builds the reporting strategy used by `ContinuousTrainer::train`.
///
/// Prints the [`TrainingStats`] to stdout whenever `child_count` is a multiple of `n`.
/// An interval of `n == 0` disables reporting instead of panicking on a division by zero.
pub fn default_reporting_strategy(
    n: usize,
) -> TrainingReportStrategy<impl FnMut(TrainingCriteriaMetrics) -> bool, impl FnMut(TrainingStats)>
{
    TrainingReportStrategy {
        // `checked_rem` yields `None` for `n == 0`, which never equals `Some(0)`.
        should_report: move |m: TrainingCriteriaMetrics| m.child_count.checked_rem(n) == Some(0),
        report_callback: |s| println!("{s}"),
    }
}