use itertools::Itertools;
use polytype::TypeScheme;
use rand::{distributions::Distribution, distributions::WeightedIndex, seq::SliceRandom, Rng};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use crate::utils::weighted_permutation;
use crate::Task;
/// Strategy used to merge freshly generated, scored children into the
/// population at the end of each generation.
///
/// Every variant can be deserialized either by its canonical name or by a
/// lowercase alias (e.g. `"Drift"` or `"drift"`).
#[derive(Deserialize, Serialize)]
pub enum GPSelection {
    /// Insert each child at its sorted position in the (ascending-by-fitness)
    /// population, evicting the last member so the size stays fixed. A child
    /// that sorts after every existing member is discarded.
    #[serde(alias = "deterministic")]
    Deterministic,
    /// Re-score every existing member and blend the scores as
    /// `alpha * old + (1 - alpha) * new`, then draw a population of the
    /// original size, without replacement, from members plus children, using
    /// fitness values as sampling weights.
    #[serde(alias = "drift")]
    Drift(f64),
    /// Pool members and children, sort ascending by fitness and keep the first
    /// `ceil(pop_size * p)` entries; fill the remaining slots by
    /// fitness-weighted sampling without replacement. `p` is expected in
    /// `[0, 1]`.
    #[serde(alias = "hybrid")]
    Hybrid(f64),
    /// Refill the whole population by fitness-weighted sampling without
    /// replacement from members plus children.
    #[serde(alias = "probabilistic")]
    Probabilistic,
    /// Replace the population with `pop_size` fitness-weighted draws, with
    /// replacement, taken from the children only; current members are not
    /// candidates.
    #[serde(alias = "resample")]
    Resample,
}
impl GPSelection {
    /// Scores `children` with `oracle` and merges them into `population`
    /// according to this selection strategy.
    ///
    /// `population` holds `(expression, fitness)` pairs. After the call its
    /// length equals its length before the call, except where a sampling helper
    /// returns fewer items than requested.
    ///
    /// Per-variant behavior:
    /// - `Drift(alpha)` re-evaluates `oracle` on every existing member and
    ///   blends the scores as `alpha * old + (1 - alpha) * new` before
    ///   resampling from members plus children.
    /// - `Resample` samples with replacement from the children only.
    /// - `Deterministic` inserts children in sorted order and does not use
    ///   `rng`.
    /// - `Hybrid(p)` keeps the `ceil(pop_size * p)` lowest-fitness entries,
    ///   clamped to `pop_size`, and samples the rest.
    /// - `Probabilistic` samples the whole population.
    pub(crate) fn update_population<R: Rng, X: Clone>(
        &self,
        population: &mut Vec<(X, f64)>,
        children: Vec<X>,
        oracle: impl Fn(&X) -> f64,
        rng: &mut R,
    ) {
        let mut scored_children = children
            .into_iter()
            .map(|child| {
                let fitness = oracle(&child);
                (child, fitness)
            })
            .collect_vec();
        match self {
            GPSelection::Drift(alpha) => {
                // Existing members are re-scored so their fitness can drift.
                for (p, old_fitness) in population.iter_mut() {
                    let new_fitness = oracle(p);
                    *old_fitness = *alpha * *old_fitness + (1.0 - alpha) * new_fitness;
                }
                let pop_size = population.len();
                population.extend(scored_children);
                *population = sample_without_replacement(population, pop_size, rng);
            }
            GPSelection::Resample => {
                let pop_size = population.len();
                *population = sample_with_replacement(&scored_children, pop_size, rng);
            }
            GPSelection::Deterministic => {
                for child in scored_children {
                    sorted_place(child, population);
                }
            }
            GPSelection::Hybrid(_) | GPSelection::Probabilistic => {
                let pop_size = population.len();
                let mut options = Vec::with_capacity(pop_size + scored_children.len());
                options.append(population);
                options.append(&mut scored_children);
                let mut sample_size = pop_size;
                if let GPSelection::Hybrid(det_proportion) = self {
                    // Clamp so a proportion above 1.0 cannot underflow
                    // `sample_size` or make `split_off` go out of bounds.
                    let n_best = ((pop_size as f64 * det_proportion).ceil() as usize).min(pop_size);
                    sample_size -= n_best;
                    // An ascending sort makes the first `n_best` entries the
                    // lowest-fitness ones; those are kept deterministically.
                    options.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
                    let rest = options.split_off(n_best);
                    *population = options;
                    options = rest;
                }
                // The result is the deterministic prefix (if any) followed by
                // the separately sorted sample; the whole vector is not
                // re-sorted.
                population.append(&mut sample_pop(rng, options, sample_size));
            }
        }
    }
}
/// Hyperparameters controlling a genetic-programming run.
#[derive(Deserialize, Serialize)]
pub struct GPParams {
    /// How scored children are merged into the population after each
    /// generation (used by `GP::evolve`).
    pub selection: GPSelection,
    /// Number of expressions requested from `GP::genesis` by `GP::init`.
    pub population_size: usize,
    /// Number of contestants per tournament when choosing parents; 1 means a
    /// single uniformly random pick.
    pub tournament_size: usize,
    /// Probability of producing offspring by mutation rather than crossover.
    /// It is passed to `Rng::gen_bool`, which panics for values outside
    /// `[0, 1]`.
    pub mutation_prob: f64,
    /// Number of children generated (after truncation) per call to
    /// `GP::evolve`.
    pub n_delta: usize,
}
/// A program representation that can be evolved with genetic programming.
///
/// Implementors supply `genesis`, `mutate` and `crossover`. Parent selection,
/// population initialization and the per-generation loop are provided as
/// default methods built on those three operators.
pub trait GP<Observation: ?Sized> {
    /// Candidate program type; entries are cloned when populations are
    /// resampled.
    type Expression: Clone;
    /// Representation-specific parameters passed to every operator.
    type Params;

    /// Creates initial expressions of type `tp`.
    ///
    /// `pop_size` is the requested number of expressions; `init` passes
    /// `GPParams::population_size`.
    fn genesis<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        pop_size: usize,
        tp: &TypeScheme,
    ) -> Vec<Self::Expression>;

    /// Derives zero or more mutants of `prog`, given the task's observation.
    fn mutate<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        prog: &Self::Expression,
        obs: &Observation,
    ) -> Vec<Self::Expression>;

    /// Combines `parent1` and `parent2` into zero or more children.
    fn crossover<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        parent1: &Self::Expression,
        parent2: &Self::Expression,
        obs: &Observation,
    ) -> Vec<Self::Expression>;

    /// Picks a parent from `population` by tournament selection.
    ///
    /// With `tournament_size == 1` one uniformly random member is returned.
    /// Otherwise `tournament_size` distinct members (or all of them, if there
    /// are fewer) are drawn, and the one with the greatest fitness value wins.
    ///
    /// # Panics
    /// Panics if no contestant can be drawn (empty population or a tournament
    /// size of 0), or if two compared fitness values include NaN.
    fn tournament<'a, R: Rng>(
        &self,
        rng: &mut R,
        tournament_size: usize,
        population: &'a [(Self::Expression, f64)],
    ) -> &'a Self::Expression {
        let winner = match tournament_size {
            1 => population.choose(rng),
            many => population
                .choose_multiple(rng, many)
                .max_by(|left, right| left.1.partial_cmp(&right.1).expect("found NaN")),
        };
        winner
            .map(|entry| &entry.0)
            .expect("tournament cannot select winner from no contestants")
    }

    /// Builds the starting population.
    ///
    /// It generates expressions with `genesis`, scores each one with
    /// `task.oracle`, and returns the pairs sorted by ascending fitness.
    ///
    /// # Panics
    /// Panics if sorting compares a NaN fitness.
    fn init<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        gpparams: &GPParams,
        task: &impl Task<Observation, Representation = Self, Expression = Self::Expression>,
    ) -> Vec<(Self::Expression, f64)> {
        let candidates = self.genesis(params, rng, gpparams.population_size, task.tp());
        let mut scored = Vec::with_capacity(candidates.len());
        for expr in candidates {
            let fitness = task.oracle(self, &expr);
            scored.push((expr, fitness));
        }
        scored.sort_by(|left, right| left.1.partial_cmp(&right.1).expect("found NaN"));
        scored
    }

    /// Hook that `evolve` calls on each freshly produced batch of `offspring`
    /// before the batch is added to the pending children.
    ///
    /// Implementors may remove or modify entries. The default leaves
    /// `offspring` untouched.
    fn validate_offspring(
        &self,
        _params: &Self::Params,
        _population: &[(Self::Expression, f64)],
        _children: &[Self::Expression],
        _offspring: &mut Vec<Self::Expression>,
    ) {
    }

    /// Runs one generation in place.
    ///
    /// Each round first decides between mutation (probability
    /// `gpparams.mutation_prob`) and crossover, then picks the parent(s) by
    /// tournament. The resulting offspring go through `validate_offspring`.
    /// Rounds repeat until at least `n_delta` children exist; the children are
    /// then truncated to `n_delta` and merged via `gpparams.selection`.
    ///
    /// The loop only advances when a round contributes offspring, so it will
    /// not terminate if the operators keep yielding nothing.
    fn evolve<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        gpparams: &GPParams,
        task: &impl Task<Observation, Representation = Self, Expression = Self::Expression>,
        population: &mut Vec<(Self::Expression, f64)>,
    ) {
        let target = gpparams.n_delta;
        let contest = gpparams.tournament_size;
        let mut children = Vec::with_capacity(target);
        while children.len() < target {
            let use_mutation = rng.gen_bool(gpparams.mutation_prob);
            let mut offspring = if use_mutation {
                let parent = self.tournament(rng, contest, population);
                self.mutate(params, rng, parent, task.observation())
            } else {
                let mother = self.tournament(rng, contest, population);
                let father = self.tournament(rng, contest, population);
                self.crossover(params, rng, mother, father, task.observation())
            };
            self.validate_offspring(params, population, &children, &mut offspring);
            children.append(&mut offspring);
        }
        children.truncate(target);
        let selector = &gpparams.selection;
        selector.update_population(population, children, |child| task.oracle(self, child), rng);
    }
}
/// Draws up to `sample_size` distinct entries from `options`, using each
/// entry's score as its sampling weight.
///
/// Returns the drawn entries sorted by ascending score. Incomparable scores
/// are treated as equal.
///
/// # Panics
/// Panics with "bad weight" if `choose_multiple_weighted` rejects the scores
/// as weights.
fn sample_pop<T: Clone, R: Rng>(
    rng: &mut R,
    options: Vec<(T, f64)>,
    sample_size: usize,
) -> Vec<(T, f64)> {
    let chosen = options
        .choose_multiple_weighted(rng, sample_size, |entry| entry.1)
        .expect("bad weight");
    let mut picked: Vec<(T, f64)> = chosen.cloned().collect();
    picked.sort_by(|left, right| {
        left.1
            .partial_cmp(&right.1)
            .unwrap_or(Ordering::Equal)
    });
    picked
}
/// Draws `sample_size` entries from `options` via `weighted_permutation`,
/// passing each entry's score as its weight.
///
/// Returns the drawn entries sorted by ascending score. Incomparable scores
/// are treated as equal.
fn sample_without_replacement<R: Rng, T: Clone>(
    options: &[(T, f64)],
    sample_size: usize,
    rng: &mut R,
) -> Vec<(T, f64)> {
    let weights: Vec<f64> = options.iter().map(|entry| entry.1).collect();
    let mut chosen = weighted_permutation(rng, options, &weights, Some(sample_size));
    chosen.sort_by(|left, right| {
        left.1
            .partial_cmp(&right.1)
            .unwrap_or(Ordering::Equal)
    });
    chosen
}
/// Draws `sample_size` entries from `options` with replacement, choosing each
/// entry with probability proportional to its score.
///
/// Returns the draws sorted by ascending score. Incomparable scores are
/// treated as equal. When `sample_size` is 0, it returns an empty vector
/// without inspecting `options`.
///
/// # Panics
/// Panics if at least one draw is requested but `WeightedIndex::new` rejects
/// the scores. That happens when `options` is empty, any score is negative or
/// NaN, or all scores are zero.
fn sample_with_replacement<R: Rng, T: Clone>(
    options: &[(T, f64)],
    sample_size: usize,
    rng: &mut R,
) -> Vec<(T, f64)> {
    if sample_size == 0 {
        // Nothing to draw: building the distribution would needlessly panic on
        // empty or all-zero `options`.
        return Vec::new();
    }
    let dist = WeightedIndex::new(options.iter().map(|(_, weight)| *weight))
        .expect("sample_with_replacement needs non-empty options with valid weights");
    let mut sample: Vec<(T, f64)> = (0..sample_size)
        .map(|_| options[dist.sample(rng)].clone())
        .collect();
    sample.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
    sample
}
/// Inserts `child` into `pop`, which must already be sorted by ascending
/// fitness, without changing `pop`'s length.
///
/// The child goes to its sorted position and the last (highest-fitness)
/// member is evicted. A child that sorts strictly after every member is
/// dropped, and so is any child offered to an empty `pop`.
///
/// # Panics
/// Panics with "found NaN" if a comparison made during the search involves
/// NaN.
fn sorted_place<T>(child: (T, f64), pop: &mut Vec<(T, f64)>) {
    let slot = pop
        .binary_search_by(|probe| probe.1.partial_cmp(&child.1).expect("found NaN"))
        .unwrap_or_else(|insertion_point| insertion_point);
    if slot >= pop.len() {
        return;
    }
    pop.pop();
    pop.insert(slot, child);
}