use itertools::Itertools;
use polytype::TypeSchema;
use rand::{distributions::Distribution, distributions::WeightedIndex, seq::IteratorRandom, Rng};
use std::cmp::Ordering;
use utils::{logsumexp, weighted_sample};
use Task;
/// Strategy used to form the next generation from the current population and
/// the freshly produced children.
///
/// Throughout this module a *smaller* fitness value is treated as better:
/// populations are kept in ascending fitness order and sampling weights are
/// `exp(-fitness)`.
///
/// Every variant also accepts its lowercase name when deserialized.
#[derive(Deserialize, Serialize)]
pub enum GPSelection {
    /// Insert each child into the (ascending-sorted) population, evicting the
    /// current last/worst member whenever the child's fitness is strictly
    /// smaller than it. The population size never changes.
    #[serde(alias = "deterministic")]
    Deterministic,
    /// Blend each member's stored fitness with a fresh oracle evaluation as
    /// `alpha * old + (1 - alpha) * new`, then draw the next generation
    /// without replacement from population ∪ children, weighted by
    /// `exp(-fitness)`. The payload is `alpha`.
    #[serde(alias = "drift")]
    Drift(f64),
    /// Keep the best `ceil(p * population_size)` candidates from
    /// population ∪ children and fill the remaining slots the way
    /// [`GPSelection::Probabilistic`] does. The payload is `p`.
    #[serde(alias = "hybrid")]
    Hybrid(f64),
    /// Draw one whole subset of `population_size` candidates from
    /// population ∪ children, each subset weighted by
    /// `exp(-(sum of its members' fitnesses))`.
    #[serde(alias = "probabilistic")]
    Probabilistic,
    /// Replace the population with `population_size` draws *with replacement*
    /// from the children only, weighted by `exp(-fitness)`.
    #[serde(alias = "resample")]
    Resample,
}
impl GPSelection {
    /// Forms the next generation in place from `population` and `children`.
    ///
    /// Every child is first scored with `oracle` (smaller fitness is better).
    /// How the scored children and the current members are combined depends
    /// on the variant of `self`. In every variant the resulting population has
    /// the same length as before (given enough candidates to draw from).
    ///
    /// `Probabilistic` and `Hybrid` enumerate every candidate subset, so their
    /// cost grows combinatorially with the number of candidates.
    ///
    /// # Panics
    ///
    /// Panics under `Resample` when `children` is empty (there is nothing to
    /// sample from), and under `Deterministic` when a NaN fitness is compared.
    pub(crate) fn update_population<'a, R: Rng, X: Clone + Send + Sync>(
        &self,
        population: &mut Vec<(X, f64)>,
        children: Vec<X>,
        oracle: Box<dyn Fn(&X) -> f64 + Send + Sync + 'a>,
        rng: &mut R,
    ) {
        let mut scored_children = children
            .into_iter()
            .map(|child| {
                let fitness = oracle(&child);
                (child, fitness)
            })
            .collect_vec();
        match self {
            GPSelection::Drift(alpha) => {
                // Re-score existing members, exponentially blending the stored
                // fitness with a fresh evaluation.
                for (member, old_fitness) in population.iter_mut() {
                    let new_fitness = oracle(member);
                    *old_fitness = *alpha * *old_fitness + (1.0 - *alpha) * new_fitness;
                }
                let pop_size = population.len();
                population.extend(scored_children);
                *population = sample_without_replacement(population, pop_size, rng);
            }
            GPSelection::Resample => {
                // Only the children are candidates; the old population is discarded.
                let pop_size = population.len();
                *population = sample_with_replacement(&mut scored_children, pop_size, rng);
            }
            GPSelection::Deterministic => {
                for child in scored_children {
                    sorted_place(child, population);
                }
            }
            GPSelection::Hybrid(_) | GPSelection::Probabilistic => {
                let pop_size = population.len();
                let mut options = Vec::with_capacity(pop_size + scored_children.len());
                options.append(population);
                options.append(&mut scored_children);
                let mut sample_size = pop_size;
                if let GPSelection::Hybrid(det_proportion) = self {
                    // Clamp so a proportion above 1.0 can neither underflow
                    // `sample_size` nor split past the end of `options`.
                    let n_best =
                        ((pop_size as f64 * det_proportion).ceil() as usize).min(pop_size);
                    sample_size -= n_best;
                    options.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
                    // The best `n_best` survive deterministically; the rest are sampled.
                    let rest = options.split_off(n_best);
                    *population = options;
                    options = rest;
                }
                population.append(&mut sample_pop(options, sample_size));
            }
        }
    }
}
/// Settings for the generic evolution loop in [`GP::init`] and [`GP::evolve`].
#[derive(Deserialize, Serialize)]
pub struct GPParams {
/// How the next generation is formed at the end of `GP::evolve`.
pub selection: GPSelection,
/// Number of expressions requested from `GP::genesis` when `GP::init`
/// builds the initial population.
pub population_size: usize,
/// Number of contestants drawn in each `GP::tournament`; 1 means a parent
/// is chosen uniformly at random.
pub tournament_size: usize,
/// Probability that a batch of offspring comes from mutating one parent
/// rather than crossing over two (passed to `Rng::gen_bool`).
pub mutation_prob: f64,
/// Exact number of children produced per call to `GP::evolve`.
pub n_delta: usize,
}
/// A representation suitable for **genetic programming**.
///
/// Implementors supply [`genesis`](GP::genesis), [`mutate`](GP::mutate) and
/// [`crossover`](GP::crossover); the provided methods build and evolve a
/// population of `(expression, fitness)` pairs. Fitness comes from the task's
/// `oracle`, and a *smaller* fitness is treated as better throughout
/// (populations are kept in ascending fitness order).
pub trait GP: Send + Sync + Sized {
    /// The program representation being evolved.
    type Expression: Clone + Send + Sync;
    /// Implementation-specific parameters forwarded to every operator.
    type Params;
    /// Task observation forwarded to `mutate` and `crossover`.
    type Observation: Clone + Send + Sync;

    /// Creates initial expressions for a task whose type is `tp`.
    ///
    /// [`init`](GP::init) calls this with `pop_size = gpparams.population_size`.
    fn genesis<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        pop_size: usize,
        tp: &TypeSchema,
    ) -> Vec<Self::Expression>;

    /// Produces offspring by mutating `prog`.
    ///
    /// Any number of expressions (including none) may be returned; all of
    /// them are offered as children in [`evolve`](GP::evolve).
    fn mutate<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        prog: &Self::Expression,
        obs: &Self::Observation,
    ) -> Vec<Self::Expression>;

    /// Produces offspring by recombining `parent1` and `parent2`.
    ///
    /// Any number of expressions (including none) may be returned; all of
    /// them are offered as children in [`evolve`](GP::evolve).
    fn crossover<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        parent1: &Self::Expression,
        parent2: &Self::Expression,
        obs: &Self::Observation,
    ) -> Vec<Self::Expression>;

    /// Selects a parent from `population` by tournament.
    ///
    /// With `tournament_size == 1` a member is drawn uniformly at random.
    /// Otherwise `tournament_size` distinct members are drawn uniformly (all
    /// of them if the population is smaller) and the one with the *smallest*
    /// fitness wins, consistent with the "smaller is better" convention.
    ///
    /// # Panics
    ///
    /// Panics if `population` is empty, if `tournament_size` is 0, or if two
    /// contestants' fitnesses cannot be compared (NaN).
    fn tournament<'a, R: Rng>(
        &self,
        rng: &mut R,
        tournament_size: usize,
        population: &'a [(Self::Expression, f64)],
    ) -> &'a Self::Expression {
        if tournament_size == 1 {
            &population[rng.gen_range(0, population.len())].0
        } else {
            (0..population.len())
                .choose_multiple(rng, tournament_size)
                .into_iter()
                .map(|i| &population[i])
                // Smaller fitness is better, so the winner is the minimum.
                .min_by(|&&(_, ref x), &&(_, ref y)| x.partial_cmp(y).expect("found NaN"))
                .map(|&(ref expr, _)| expr)
                .expect("tournament cannot select winner from no contestants")
        }
    }

    /// Generates and scores an initial population, sorted by ascending fitness.
    ///
    /// # Panics
    ///
    /// Panics if two fitness values cannot be compared (NaN).
    fn init<R: Rng, O: Sync>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        gpparams: &GPParams,
        task: &Task<Self, Self::Expression, O>,
    ) -> Vec<(Self::Expression, f64)> {
        let exprs = self.genesis(params, rng, gpparams.population_size, &task.tp);
        exprs
            .into_iter()
            .map(|expr| {
                let l = (task.oracle)(self, &expr);
                (expr, l)
            })
            .sorted_by(|&(_, ref x), &(_, ref y)| x.partial_cmp(y).expect("found NaN"))
            .collect()
    }

    /// Hook for filtering or editing a freshly produced batch of `offspring`
    /// in place before it is appended to `children` in [`evolve`](GP::evolve).
    ///
    /// The default implementation leaves `offspring` untouched.
    fn validate_offspring(
        &self,
        _params: &Self::Params,
        _population: &[(Self::Expression, f64)],
        _children: &[Self::Expression],
        _offspring: &mut Vec<Self::Expression>,
    ) {
    }

    /// Runs one generation of evolution, updating `population` in place.
    ///
    /// Exactly `gpparams.n_delta` children are collected. Each batch comes
    /// from mutating one tournament-selected parent with probability
    /// `gpparams.mutation_prob`, or otherwise from crossing over two. The
    /// children are then handed to `gpparams.selection`, which scores them
    /// with the task's oracle and forms the next population.
    fn evolve<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        gpparams: &GPParams,
        task: &Task<Self, Self::Expression, Self::Observation>,
        population: &mut Vec<(Self::Expression, f64)>,
    ) {
        let mut children = Vec::with_capacity(gpparams.n_delta);
        // NOTE: if the operators (after validation) keep yielding no offspring,
        // this loop never terminates.
        while children.len() < gpparams.n_delta {
            let mut offspring = if rng.gen_bool(gpparams.mutation_prob) {
                let parent = self.tournament(rng, gpparams.tournament_size, population);
                self.mutate(params, rng, parent, &task.observation)
            } else {
                let parent1 = self.tournament(rng, gpparams.tournament_size, population);
                let parent2 = self.tournament(rng, gpparams.tournament_size, population);
                self.crossover(params, rng, parent1, parent2, &task.observation)
            };
            self.validate_offspring(params, population, &children, &mut offspring);
            children.append(&mut offspring);
        }
        // The last batch may overshoot the target.
        children.truncate(gpparams.n_delta);
        gpparams.selection.update_population(
            population,
            children,
            Box::new(|child| (task.oracle)(self, child)),
            rng,
        );
    }
}
/// Draws one `sample_size`-element subset of `options` at random.
///
/// Every subset (in `itertools` `combinations` order) receives the log-weight
/// `-(sum of its members' fitnesses)`. The weights are normalized with
/// `logsumexp`, and one subset index is chosen via `weighted_sample`.
///
/// NOTE(review): all `C(options.len(), sample_size)` subsets are enumerated
/// twice, so the cost grows combinatorially. Randomness comes from
/// `weighted_sample` rather than a caller-supplied RNG — presumably a
/// thread-local RNG; confirm in `utils`.
fn sample_pop<T: Clone>(options: Vec<(T, f64)>, sample_size: usize) -> Vec<(T, f64)> {
    let log_weights: Vec<f64> = options
        .iter()
        .map(|&(_, score)| score)
        .combinations(sample_size)
        .map(|subset| -subset.into_iter().sum::<f64>())
        .collect();
    let subset_ids: Vec<usize> = (0..log_weights.len()).collect();

    let normalizer = logsumexp(&log_weights);
    let probabilities: Vec<f64> = log_weights
        .iter()
        .map(|lw| (lw - normalizer).exp())
        .collect();

    let chosen = *weighted_sample(&subset_ids, &probabilities);
    // Re-enumerate in the same order to materialize only the chosen subset.
    options
        .into_iter()
        .combinations(sample_size)
        .nth(chosen)
        .unwrap()
}
/// Draws `sample_size` distinct entries from `options` and returns them
/// sorted by ascending fitness.
///
/// Each draw picks among the entries not yet chosen with probability
/// proportional to `exp(-fitness)`. The weights are computed relative to the
/// smallest remaining fitness, as `exp(-(fitness - min))`. This gives the same
/// probabilities as the raw form, but cannot overflow to infinity for very
/// negative fitnesses or underflow to all zeros for large ones. Either of
/// those made `WeightedIndex::new` fail.
///
/// # Panics
///
/// Panics if `sample_size > options.len()`, or if the remaining fitnesses
/// yield no valid weight (e.g. a NaN or `-inf` fitness).
fn sample_without_replacement<R: Rng, T: Clone>(
    options: &mut Vec<(T, f64)>,
    sample_size: usize,
    rng: &mut R,
) -> Vec<(T, f64)> {
    // Indices into `options` that have not been drawn yet.
    let mut remaining: Vec<usize> = (0..options.len()).collect();
    let mut sample = Vec::with_capacity(sample_size);
    for _ in 0..sample_size {
        // Shift by the best remaining fitness so the top candidate has weight 1.
        let best = remaining
            .iter()
            .map(|&i| options[i].1)
            .fold(std::f64::INFINITY, f64::min);
        let dist = WeightedIndex::new(remaining.iter().map(|&i| (best - options[i].1).exp()))
            .expect("sample_without_replacement: no valid sampling weights");
        let chosen = remaining.swap_remove(dist.sample(rng));
        sample.push(options[chosen].clone());
    }
    sample.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
    sample
}
/// Draws `sample_size` entries from `options` with replacement and returns
/// them sorted by ascending fitness.
///
/// Each draw picks an entry with probability proportional to `exp(-fitness)`.
/// The weights are computed as `exp(-(fitness - min))`. This is the same
/// distribution, but it cannot overflow to infinity or underflow to all
/// zeros, either of which made `WeightedIndex::new` fail.
///
/// # Panics
///
/// Panics if `options` is empty, or if the fitnesses yield no valid weight
/// (e.g. a NaN or `-inf` fitness).
fn sample_with_replacement<R: Rng, T: Clone>(
    options: &mut Vec<(T, f64)>,
    sample_size: usize,
    rng: &mut R,
) -> Vec<(T, f64)> {
    // Shift by the best fitness so the top candidate has weight exactly 1.
    let best = options
        .iter()
        .map(|&(_, fitness)| fitness)
        .fold(std::f64::INFINITY, f64::min);
    let dist = WeightedIndex::new(options.iter().map(|&(_, fitness)| (best - fitness).exp()))
        .expect("sample_with_replacement: no valid sampling weights");
    let mut sample = Vec::with_capacity(sample_size);
    for _ in 0..sample_size {
        sample.push(options[dist.sample(rng)].clone());
    }
    sample.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
    sample
}
/// Inserts `child` into `pop`, which must be sorted by ascending fitness,
/// keeping `pop` at its current length.
///
/// The insertion point is found by binary search: it is the position after
/// every member whose fitness is `<=` the child's. If that position lies
/// inside `pop`, the child ranks strictly ahead of the last (worst) member.
/// In that case the last member is dropped and the child is inserted there.
/// Otherwise `pop` is left unchanged. An empty `pop` stays empty.
///
/// # Panics
///
/// Panics with "found NaN" if a compared fitness is NaN.
fn sorted_place<T>(child: (T, f64), pop: &mut Vec<(T, f64)>) {
    let len = pop.len();
    if len == 0 {
        return;
    }
    let idx = {
        // `true` when `entry` does not rank behind the child.
        let not_greater = |entry: &(T, f64)| {
            entry.1.partial_cmp(&child.1).expect("found NaN") != Ordering::Greater
        };
        // Invariant: `base + size <= len`, so every index below is in bounds.
        let mut base = 0usize;
        let mut size = len;
        while size > 1 {
            let half = size / 2;
            let mid = base + half;
            if not_greater(&pop[mid]) {
                base = mid;
            }
            size -= half;
        }
        base + not_greater(&pop[base]) as usize
    };
    if idx < len {
        pop.pop();
        pop.insert(idx, child);
    }
}