use std::{
fmt,
random::RandomSource,
sync::{mpmc, mpsc},
thread::{self, ScopedJoinHandle},
};
use crate::{
Genome, PopulationStats, TrainingReportStrategy, bounds, num_cpus, random_choice_weighted,
random_f32,
};
#[allow(unused_imports)]
use crate::continuous::ContinuousTrainer;
#[allow(unused_imports)]
use crate::continuous;
/// Generational genetic-algorithm trainer whose fitness evaluations run on a pool of scoped
/// worker threads.
///
/// Each [`step`](StochasticTrainer::step) scores the unscored genes, prunes the pool with a
/// selection strategy and refills it to `population_size` with new children.
pub struct StochasticTrainer<'scope, G> {
    /// The population: each genome paired with its fitness, or `None` until `eval` scores it.
    pub gene_pool: Vec<(G, Option<f32>)>,
    /// Generation counter, incremented at the start of every `step`.
    pub generation: usize,
    /// Passed to `Genome::mutate` when a child is produced by mutating a clone of one parent.
    pub mutation_rate: f32,
    /// Threshold compared against `random_f32` in `reproduce`: a child is crossbred from two
    /// parents when the draw is below it, otherwise it is a mutated clone of one parent.
    pub reproduction_type_proportion: f32,
    // These handles are never read or joined: `thread::scope` joins every worker when the
    // scope ends. The field only keeps the handles owned by the trainer, hence the allow.
    #[allow(unused)]
    worker_pool: Vec<ScopedJoinHandle<'scope, ()>>,
    /// Sends `(gene_pool index, genome)` evaluation jobs to the workers.
    work_submission: mpmc::Sender<(usize, G)>,
    /// Receives `(gene_pool index, fitness)` results back from the workers.
    work_reception: mpsc::Receiver<(usize, f32)>,
    /// Size that `reproduce` refills `gene_pool` to.
    population_size: usize,
}
impl<'scope, G> StochasticTrainer<'scope, G> {
    /// Creates a trainer holding `population_size` freshly generated (unscored) genomes and
    /// spawns its fitness-evaluation workers inside `scope`.
    ///
    /// One worker is spawned per CPU reported by `num_cpus()`, and always at least one. Each
    /// worker evaluates the genomes that [`eval`](Self::eval) sends it. It stops once the
    /// trainer is dropped, because that closes the work channel the worker loop iterates over.
    pub fn new<R>(
        population_size: usize,
        mutation_rate: f32,
        reproduction_type_proportion: f32,
        rng: &mut R,
        scope: &'scope thread::Scope<'scope, '_>,
    ) -> Self
    where
        G: Genome + Send + 'scope,
        R: RandomSource,
    {
        let (work_submission, inbox) = mpmc::channel();
        let (outbox, work_reception) = mpsc::channel();
        let gene_pool = (0..population_size)
            .map(|_| (G::generate(rng), None))
            .collect();
        // Guard against `num_cpus()` reporting zero. With no workers, nothing would receive
        // `eval`'s jobs or send back its results, so every evaluation would fail.
        let worker_count = num_cpus().max(1);
        let worker_pool = (0..worker_count)
            .map(|_| {
                let inbox = inbox.clone();
                let outbox = outbox.clone();
                scope.spawn(move || StochasticTrainer::<G>::worker_thread(inbox, outbox))
            })
            .collect();
        Self {
            gene_pool,
            worker_pool,
            work_submission,
            work_reception,
            generation: 0,
            population_size,
            mutation_rate,
            reproduction_type_proportion,
        }
    }

    /// Worker loop: computes `fitness()` for every `(id, genome)` job received and sends back
    /// `(id, fitness)`. It returns when the work channel closes or results can no longer be
    /// delivered.
    fn worker_thread<I>(inbox: mpmc::Receiver<(I, G)>, outbox: mpsc::Sender<(I, f32)>)
    where
        G: Genome,
    {
        for (id, gene) in inbox {
            let fitness = gene.fitness();
            // The receiver only goes away when the trainer is dropped, so nobody wants this
            // result any more. Exit quietly instead of panicking: an unjoined scoped thread that
            // panics makes the enclosing `thread::scope` call panic.
            if outbox.send((id, fitness)).is_err() {
                break;
            }
        }
    }

    /// Scores every gene whose score is still `None`.
    ///
    /// Each unscored genome is cloned and sent to the worker pool. The call then blocks until
    /// the same number of results has come back. Genes that already have a score are not
    /// re-evaluated.
    ///
    /// # Panics
    /// Panics if no worker is left to receive work.
    ///
    /// NOTE(review): if some (but not all) workers panic inside `fitness`, their results never
    /// arrive and this call can block indefinitely.
    pub fn eval(&mut self)
    where
        G: Clone,
    {
        let mut pending = 0;
        for (index, (gene, score)) in self.gene_pool.iter().enumerate() {
            if score.is_none() {
                self.work_submission
                    .send((index, gene.clone()))
                    .expect("all fitness workers have stopped receiving work");
                pending += 1;
            }
        }
        for (index, fitness) in self.work_reception.iter().take(pending) {
            self.gene_pool[index].1 = Some(fitness);
        }
    }

    /// Fitness of every gene that has been evaluated, in `gene_pool` order.
    pub fn scores(&self) -> impl Iterator<Item = f32> {
        self.gene_pool.iter().filter_map(|(_, s)| *s)
    }

    /// Maps `score` onto `[0, 1]` relative to the population's `min_score` and `score_range`.
    ///
    /// The range is degenerate when every scored gene is tied or only one gene is scored. The
    /// plain division would then give `0.0 / 0.0 = NaN`. NaN would make the default selection
    /// strategy prune every gene, and would hand NaN weights to `random_choice_weighted`.
    /// Ties carry no ranking information, so each gene is placed at the midpoint instead.
    fn percentile(score: f32, min_score: f32, score_range: f32) -> f32 {
        if score_range > 0.0 {
            (score - min_score) / score_range
        } else {
            0.5
        }
    }

    /// Removes scored genes that `selection_strategy` rejects.
    ///
    /// The strategy receives each gene's percentile (see [`percentile`](Self::percentile)) and
    /// returns `true` to keep the gene. Unscored genes are always kept.
    ///
    /// # Panics
    /// Panics if `bounds` fails on the current scores (presumably when no gene is scored).
    pub fn prune<R>(&mut self, mut selection_strategy: impl FnMut(f32, &mut R) -> bool, rng: &mut R)
    where
        R: RandomSource,
    {
        let (min_score, max_score) = bounds(self.scores())
            .expect("could not compute score bounds (no scored genes to prune against?)");
        let score_range = max_score - min_score;
        self.gene_pool.retain(|(_, score)| match *score {
            Some(score) => {
                selection_strategy(Self::percentile(score, min_score, score_range), rng)
            }
            // Genes `eval` has not scored yet have no percentile to judge them by.
            None => true,
        });
    }

    /// Refills `gene_pool` to `population_size` with new, unscored children.
    ///
    /// A child is crossbred from two parents when `random_f32` falls below
    /// `reproduction_type_proportion`. Otherwise it is a clone of one parent mutated with
    /// `mutation_rate`. Parents are drawn with `random_choice_weighted`, weighted by their
    /// score percentile. Unscored genes get weight zero, and children created during this call
    /// are never chosen as parents.
    ///
    /// # Panics
    /// Panics if `bounds` fails on the current scores (presumably when no gene is scored).
    pub fn reproduce<R>(&mut self, rng: &mut R)
    where
        R: RandomSource,
        G: Genome + Clone,
    {
        let (min_score, max_score) = bounds(self.scores())
            .expect("could not compute score bounds (no scored parents to reproduce from?)");
        let score_range = max_score - min_score;
        // Children are only ever appended, so these indices keep pointing at the parents.
        let percentile_pairs: Vec<(usize, f32)> = self
            .gene_pool
            .iter()
            .enumerate()
            .map(|(i, (_, score))| {
                let weight = score.map_or(0.0, |s| Self::percentile(s, min_score, score_range));
                (i, weight)
            })
            .collect();
        while self.gene_pool.len() < self.population_size {
            let should_crossbreed = random_f32(rng) < self.reproduction_type_proportion;
            let mut choose_parent_index = || random_choice_weighted(&percentile_pairs, rng);
            let new_child = if should_crossbreed {
                let &mother_index = (choose_parent_index)();
                let &father_index = (choose_parent_index)();
                self.gene_pool[mother_index]
                    .0
                    .crossbreed(&self.gene_pool[father_index].0, rng)
            } else {
                let &parent_index = (choose_parent_index)();
                let mut new_child = self.gene_pool[parent_index].0.clone();
                new_child.mutate(self.mutation_rate, rng);
                new_child
            };
            self.gene_pool.push((new_child, None));
        }
    }

    /// Runs one generation: evaluate, optionally snapshot statistics, prune, reproduce.
    ///
    /// `generation` is incremented first, so the first step is generation 1. When
    /// `reporting_strategy.should_report` accepts the metrics, statistics are captured right
    /// after evaluation, while every gene is scored. They are handed to `report_callback` only
    /// at the end of the step.
    pub fn step<R>(
        &mut self,
        selection_strategy: impl FnMut(f32, &mut R) -> bool,
        mut reporting_strategy: Option<
            &mut TrainingReportStrategy<
                impl FnMut(TrainingCriteriaMetrics) -> bool,
                impl FnMut(TrainingStats),
            >,
        >,
        rng: &mut R,
    ) where
        G: Clone + Genome,
        R: RandomSource,
    {
        self.generation += 1;
        self.eval();
        let stats = reporting_strategy
            .as_mut()
            .and_then(|s| (s.should_report)(self.metrics()).then(|| self.stats()));
        self.prune(selection_strategy, rng);
        self.reproduce(rng);
        if let (Some(stats), Some(reporting_strategy)) = (stats, &mut reporting_strategy) {
            (reporting_strategy.report_callback)(stats);
        }
    }

    /// Trains for `generations` more generations with the default selection and reporting
    /// strategies, then returns a clone of the best scored genome.
    ///
    /// The count is relative to the current `generation`, so repeated calls keep training. At
    /// least one generation always runs, even for `generations == 0`, because `train_custom`
    /// checks its criteria only after each step.
    pub fn train<R>(&mut self, generations: usize, rng: &mut R) -> G
    where
        G: Clone + Genome,
        R: RandomSource,
    {
        // `step` bumps the counter before the criteria is consulted. After the k-th step of
        // this call the counter reads `start + k`, so stopping once it reaches
        // `start + generations` runs exactly `generations` steps (minimum one, see above).
        let last_generation = self.generation.saturating_add(generations);
        self.train_custom(
            default_selection_strategy,
            |metrics| metrics.generation < last_generation,
            Some(default_reporting_strategy()),
            rng,
        )
    }

    /// Repeatedly calls [`step`](Self::step) while `training_criteria` returns `true`, then
    /// returns a clone of the highest-scoring evaluated genome.
    ///
    /// The criteria is checked after each step, so at least one step always runs. Children
    /// created by the final step are not scored yet, so they are not candidates for the result.
    ///
    /// # Panics
    /// Panics if no gene in the final pool has a score.
    pub fn train_custom<R>(
        &mut self,
        mut selection_strategy: impl FnMut(f32, &mut R) -> bool,
        mut training_criteria: impl FnMut(TrainingCriteriaMetrics) -> bool,
        mut reporting_strategy: Option<
            TrainingReportStrategy<
                impl FnMut(TrainingCriteriaMetrics) -> bool,
                impl FnMut(TrainingStats),
            >,
        >,
        rng: &mut R,
    ) -> G
    where
        G: Clone + Genome,
        R: RandomSource,
    {
        loop {
            self.step::<R>(&mut selection_strategy, reporting_strategy.as_mut(), rng);
            if !training_criteria(self.metrics()) {
                break;
            }
        }
        self.gene_pool
            .iter()
            .filter_map(|(gene, score)| score.map(|s| (gene, s)))
            .max_by(|a, b| a.1.total_cmp(&b.1))
            .map(|(gene, _)| gene.clone())
            .expect("training finished with no scored genes to choose from")
    }

    /// Snapshot of the current scores and generation counter.
    pub fn stats(&self) -> TrainingStats {
        TrainingStats {
            population_stats: self.scores().collect(),
            generation: self.generation,
        }
    }

    /// Metrics passed to training-criteria and `should_report` callbacks.
    pub fn metrics(&self) -> TrainingCriteriaMetrics {
        TrainingCriteriaMetrics {
            generation: self.generation,
        }
    }
}
/// Progress information handed to training-criteria and `should_report` callbacks.
///
/// It is a plain value snapshot, so it derives the usual value-type traits and callers can
/// copy, compare, print or store it freely.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct TrainingCriteriaMetrics {
    /// The trainer's generation counter at the moment the metrics were taken.
    pub generation: usize,
}
/// Snapshot of a population's fitness statistics at a given generation, as handed to a
/// reporting callback.
#[derive(Debug, Clone, Copy)]
pub struct TrainingStats {
    pub population_stats: PopulationStats,
    pub generation: usize,
}
impl fmt::Display for TrainingStats {
    /// Renders as `generation <n> <population stats>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            population_stats,
            generation,
        } = self;
        write!(f, "generation {generation} {population_stats}")
    }
}
/// Reporting strategy that reports every generation by printing its [`TrainingStats`] to
/// standard output.
pub fn default_reporting_strategy()
-> TrainingReportStrategy<impl FnMut(TrainingCriteriaMetrics) -> bool, impl FnMut(TrainingStats)> {
    let should_report = |_: TrainingCriteriaMetrics| true;
    let report_callback = |stats: TrainingStats| println!("{stats}");
    TrainingReportStrategy {
        should_report,
        report_callback,
    }
}
/// Default selection strategy for `prune`: keeps a gene when its percentile (`score`,
/// normally in `[0, 1]`) exceeds a fresh `random_f32` draw.
///
/// If `random_f32` is uniform on `[0, 1)`, each gene survives with probability equal to its
/// percentile. A NaN percentile never survives.
pub fn default_selection_strategy<R>(score: f32, rng: &mut R) -> bool
where
    R: RandomSource,
{
    let threshold = random_f32(rng);
    threshold < score
}