use crate::distributions::{MultivarNormDist, NormDist};
use crate::splitter::Splitter;
use crate::triangular::Triangular;
use futures::stream::{FuturesOrdered, StreamExt};
use rand::distributions::{Distribution, Standard, Uniform};
use rand::seq::SliceRandom;
use rand::Rng;
use rayon::prelude::*;
use std::cmp::Ordering;
use std::future::Future;
use std::ops::RangeInclusive;
/// Search range for a single parameter of the optimization problem.
///
/// A `Finite` range is turned into an inclusive uniform distribution and an
/// `Infinite` range into a `NormDist` when the solver builds its sampling
/// distributions.
#[derive(Clone, Copy, Debug)]
pub enum SearchRange {
    /// Bounded range; both end points are inclusive.
    Finite {
        /// Lower end of the range.
        low: f64,
        /// Upper end of the range.
        high: f64,
    },
    /// Unbounded range, described by a center and a spread.
    Infinite {
        /// Average passed to `NormDist::new`.
        average: f64,
        /// Standard deviation passed to `NormDist::new`.
        stddev: f64,
    },
}
/// Allows writing a finite search range as `low..=high`.
impl From<RangeInclusive<f64>> for SearchRange {
    /// Converts an inclusive range into [`SearchRange::Finite`], taking the
    /// range's start as `low` and its end as `high`. No validation is done.
    fn from(range: RangeInclusive<f64>) -> Self {
        let (low, high) = range.into_inner();
        SearchRange::Finite { low, high }
    }
}
/// Sampling distribution for a single parameter, derived from a
/// [`SearchRange`].
#[derive(Debug, Clone)]
enum SearchDist {
    /// Uniform distribution built from a finite range (end points inclusive).
    Finite(Uniform<f64>),
    /// `NormDist` built from the average and standard deviation of an
    /// infinite range.
    Infinite(NormDist<f64>),
}
impl From<SearchRange> for SearchDist {
    /// Builds the sampling distribution that matches a search range:
    /// an inclusive uniform distribution for `Finite`, a `NormDist` for
    /// `Infinite`.
    fn from(search_range: SearchRange) -> SearchDist {
        match search_range {
            SearchRange::Finite {
                low: lower,
                high: upper,
            } => {
                let uniform = Uniform::new_inclusive(lower, upper);
                Self::Finite(uniform)
            }
            SearchRange::Infinite {
                average: mean,
                stddev: sigma,
            } => {
                let normal = NormDist::new(mean, sigma);
                Self::Infinite(normal)
            }
        }
    }
}
impl Distribution<f64> for SearchDist {
    /// Draws one value by delegating to the wrapped distribution.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        match *self {
            SearchDist::Finite(ref uniform) => uniform.sample(rng),
            SearchDist::Infinite(ref normal) => normal.sample(rng),
        }
    }
}
/// A candidate solution that can be held and ranked by the solver.
pub trait Specimen {
    /// Returns the parameter vector of this specimen.
    fn params(&self) -> &[f64];

    /// Compares the cost of this specimen with the cost of `other`.
    ///
    /// The solver sorts its population with this ordering and discards from
    /// the end, so `Ordering::Less` means "better".
    fn cmp_cost(&self, other: &Self) -> Ordering;

    /// Returns the sum of the squared differences between the parameters of
    /// `self` and `other` (no square root is taken).
    ///
    /// Parameters are paired by position; if one specimen has more
    /// parameters than the other, the surplus is ignored.
    fn params_dist(&self, other: &Self) -> f64 {
        let mine = self.params();
        let theirs = other.params();
        mine.iter()
            .zip(theirs)
            .map(|(&lhs, &rhs)| {
                let delta = lhs - rhs;
                delta.powf(2.0)
            })
            .sum()
    }
}
/// Minimal specimen: a parameter vector together with its cost.
#[derive(Debug, Clone)]
pub struct BasicSpecimen {
    /// Parameter values of this specimen.
    pub params: Vec<f64>,
    /// Cost associated with `params`.
    pub cost: f64,
}
impl Specimen for BasicSpecimen {
    /// Returns the stored parameter vector as a slice.
    fn params(&self) -> &[f64] {
        self.params.as_slice()
    }

    /// Orders two specimens by their `cost` field.
    ///
    /// Ordinary numbers compare numerically. A NaN cost ranks after every
    /// non-NaN cost, and two NaN costs compare as equal.
    fn cmp_cost(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (self.cost, other.cost);
        match lhs.partial_cmp(&rhs) {
            Some(ordering) => ordering,
            // `partial_cmp` only fails when at least one side is NaN;
            // `false < true` places the NaN side last.
            None => lhs.is_nan().cmp(&rhs.is_nan()),
        }
    }
}
/// Optimization solver holding a population of specimens of type `S` and a
/// constructor `C` that turns a parameter vector into a new value.
#[derive(Debug, Clone)]
pub struct Solver<S, C> {
    /// Search range of every parameter; its length is the problem dimension.
    search_space: Vec<SearchRange>,
    /// Sampling distribution of every parameter, derived from `search_space`.
    search_dists: Vec<SearchDist>,
    /// Closure invoked with a parameter vector to build a new value.
    constructor: C,
    /// Number of groups the parameters are split into during recombination.
    division_count: usize,
    /// `dim / division_count`, rounded up (0 when the dimension is 0).
    min_population: usize,
    /// Whether `specimens` is currently sorted by cost.
    is_sorted: bool,
    /// Current population; sorted lazily, best specimen first.
    specimens: Vec<S>,
}
impl<S, C> Solver<S, C> {
    /// Creates a solver for the given search space.
    ///
    /// `constructor` is stored and later called with a parameter vector to
    /// build new values. The solver starts with an empty population and a
    /// division count of 1.
    pub fn new<T>(search_space: Vec<SearchRange>, constructor: C) -> Self
    where
        C: Fn(Vec<f64>) -> T + Sync,
    {
        let search_dists = search_space
            .iter()
            .copied()
            .map(SearchDist::from)
            .collect::<Vec<SearchDist>>();
        let mut solver = Self {
            search_space,
            search_dists,
            constructor,
            // Placeholders; `set_division_count` below assigns the real values.
            division_count: 0,
            min_population: 0,
            is_sorted: true,
            specimens: Vec::new(),
        };
        solver.set_division_count(1);
        solver
    }

    /// Returns the dimension of the problem (number of parameters).
    pub fn dim(&self) -> usize {
        self.search_space.len()
    }

    /// Sets the number of divisions, clamped to `1..=dim`, and updates the
    /// minimum population to `dim / division_count` rounded up.
    ///
    /// For a zero-dimensional problem the division count is 1 and the
    /// minimum population is 0, whatever value is passed.
    pub fn set_division_count(&mut self, division_count: usize) {
        let dim = self.dim();
        if dim == 0 {
            self.division_count = 1;
            self.min_population = 0;
            return;
        }
        let clamped = division_count.max(1).min(dim);
        self.division_count = clamped;
        // Integer form of ceil(dim / clamped).
        self.min_population = (dim - 1) / clamped + 1;
    }

    /// Sets the division count to `dim` raised to `speed_factor`, rounded to
    /// the nearest integer: `0.0` gives one division, `1.0` gives `dim`.
    ///
    /// # Panics
    ///
    /// Panics if `speed_factor` is not within `0.0..=1.0` (NaN included).
    pub fn set_speed_factor(&mut self, speed_factor: f64) {
        assert!(
            (0.0..=1.0).contains(&speed_factor),
            "speed_factor must be between 0.0 and 1.0"
        );
        let dim = self.dim() as f64;
        let division_count = dim.powf(speed_factor).round() as usize;
        self.set_division_count(division_count);
    }

    /// Sets the division count to `dim / max_division_size`, rounded up; the
    /// result is clamped by [`Solver::set_division_count`].
    pub fn set_max_division_size(&mut self, max_division_size: usize) {
        let ratio = self.dim() as f64 / max_division_size as f64;
        self.set_division_count(ratio.ceil() as usize);
    }

    /// Returns the current number of divisions.
    pub fn division_count(&self) -> usize {
        self.division_count
    }

    /// Returns the minimum population computed for the current division
    /// count.
    pub fn min_population(&self) -> usize {
        self.min_population
    }
}
impl<S, C> Solver<S, C>
where
    S: Specimen + Send,
{
    /// Sorts the population by cost (best first) unless it is already
    /// known to be sorted.
    fn sort(&mut self) {
        if self.is_sorted {
            return;
        }
        self.specimens.par_sort_by(|a, b| a.cmp_cost(b));
        self.is_sorted = true;
    }

    /// Adds specimens to the population. Sorting is deferred until needed.
    pub fn extend_specimens<I: IntoIterator<Item = S>>(&mut self, iter: I) {
        self.is_sorted = false;
        self.specimens.extend(iter);
    }

    /// Adds specimens and then drops the worst ones, so the population
    /// keeps the size it had before the call.
    pub fn replace_worst_specimens<I: IntoIterator<Item = S>>(&mut self, iter: I) {
        let original_len = self.specimens.len();
        self.extend_specimens(iter);
        self.truncate_specimens(original_len);
    }

    /// Asynchronous variant of [`Solver::extend_specimens`]: awaits the
    /// given futures and appends their results in the order the futures
    /// were supplied.
    pub async fn extend_specimens_async<F, I>(&mut self, iter: I)
    where
        F: Future<Output = S> + Send,
        I: IntoIterator<Item = F>,
    {
        let mut pending = FuturesOrdered::from_iter(iter);
        self.specimens.reserve(pending.len());
        self.is_sorted = false;
        while let Some(specimen) = pending.next().await {
            self.specimens.push(specimen);
        }
    }

    /// Asynchronous variant of [`Solver::replace_worst_specimens`].
    pub async fn replace_worst_specimens_async<F, I>(&mut self, iter: I)
    where
        F: Future<Output = S> + Send,
        I: IntoIterator<Item = F>,
    {
        let original_len = self.specimens.len();
        self.extend_specimens_async(iter).await;
        self.truncate_specimens(original_len);
    }

    /// Keeps only the `count` best specimens.
    pub fn truncate_specimens(&mut self, count: usize) {
        self.sort();
        self.specimens.truncate(count);
    }

    /// Returns `true` when the best and the worst specimen compare as equal
    /// in cost, or when the population is empty.
    pub fn converged(&mut self) -> bool {
        if self.specimens.is_empty() {
            return true;
        }
        self.sort();
        let best = &self.specimens[0];
        let worst = &self.specimens[self.specimens.len() - 1];
        best.cmp_cost(worst) == Ordering::Equal
    }

    /// Returns the population, sorted by cost (best first).
    pub fn specimens(&mut self) -> &[S] {
        self.sort();
        &self.specimens
    }

    /// Returns mutable access to the sorted population.
    ///
    /// The caller may modify it arbitrarily, so the population is marked
    /// as unsorted afterwards.
    pub fn specimens_mut(&mut self) -> &mut Vec<S> {
        self.sort();
        self.is_sorted = false;
        &mut self.specimens
    }

    /// Consumes the solver and returns the population, sorted by cost.
    pub fn into_specimens(mut self) -> Vec<S> {
        self.sort();
        self.specimens
    }

    /// Consumes the solver and returns the best specimen.
    ///
    /// # Panics
    ///
    /// Panics if the population is empty.
    pub fn into_specimen(self) -> S {
        let mut sorted = self.into_specimens().into_iter();
        sorted.next().expect("solver contains no specimen")
    }
}
impl<S, C, T> Solver<S, C>
where
    S: Specimen + Send + Sync,
    C: Fn(Vec<f64>) -> T + Sync,
    T: Send,
{
    /// Builds one value through the constructor from parameters that are
    /// sampled independently from the per-parameter search distributions.
    fn random_specimen<R>(&self, rng: &mut R) -> T
    where
        R: Rng + ?Sized,
    {
        (self.constructor)(
            self.search_dists
                .iter()
                .map(|dist| dist.sample(rng))
                .collect::<Vec<f64>>(),
        )
    }

    /// Builds `count` values from randomly sampled parameters, in parallel.
    ///
    /// The results are returned and NOT added to the population.
    pub fn random_specimens(&self, count: usize) -> Vec<T> {
        (0..count)
            .into_par_iter()
            .map_init(|| rand::thread_rng(), |rng, _| self.random_specimen(rng))
            .collect()
    }

    /// Builds `children_count` new values by recombining the current
    /// population. The results are returned and NOT added to the population.
    ///
    /// Steps:
    /// 1. The parameter indices are split into groups by a `Splitter`
    ///    (NOTE(review): presumably a random partition into
    ///    `division_count` groups; confirm against `crate::splitter`).
    /// 2. For every group, the per-parameter averages and the covariances
    ///    over the whole population are computed (both divided by the
    ///    population size) and used to build a `MultivarNormDist`.
    /// 3. Every child samples each group's distribution, merges the groups
    ///    back into a full parameter vector, and blends it with a randomly
    ///    chosen parent: `factor1 * parent + (1 - factor1) * sample`, where
    ///    `factor1 = 2 * u^(1 / local_factor)` and `u` is a `Standard`
    ///    sample.
    /// 4. If any blended parameter leaves a finite search range (or is NaN),
    ///    the child is replaced by a completely random specimen.
    ///
    /// A `local_factor` that is not greater than zero makes the exponent
    /// infinite, which drives `factor1` to zero for `u < 1`, so the parent
    /// has no influence.
    ///
    /// # Panics
    ///
    /// Panics if the population is empty and `children_count` is non-zero.
    pub fn recombined_specimens(&mut self, children_count: usize, local_factor: f64) -> Vec<T> {
        self.sort();
        let total_count = self.specimens.len();
        let total_weight = total_count as f64;
        let splitter = Splitter::new(&mut rand::thread_rng(), self.dim(), self.division_count());
        // Average of each parameter, organized per group.
        let sub_averages = splitter
            .groups()
            .par_iter()
            .map(|group| {
                (0..group.len())
                    .into_par_iter()
                    .map(|i| {
                        // `group[i]` is the parameter's index in the full vector.
                        let i_orig = group[i];
                        self.specimens
                            .par_iter()
                            .map(|specimen| specimen.params()[i_orig])
                            .sum::<f64>()
                            / total_weight
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();
        // One multivariate distribution per group, built from the group's
        // averages and covariances.
        let sub_dists = splitter
            .groups()
            .par_iter()
            .zip(sub_averages.into_par_iter())
            .map(|(group, averages)| {
                let covariances = Triangular::<f64>::par_new(group.len(), |(i, j)| {
                    let i_orig = group[i];
                    let j_orig = group[j];
                    self.specimens
                        .iter()
                        .map(|specimen| {
                            let a = specimen.params()[i_orig] - averages[i];
                            let b = specimen.params()[j_orig] - averages[j];
                            a * b
                        })
                        .sum::<f64>()
                        / total_weight
                });
                MultivarNormDist::new(averages, covariances)
            })
            .collect::<Vec<_>>();
        let local_exp = if local_factor > 0.0 {
            1.0 / local_factor
        } else {
            f64::INFINITY
        };
        (0..children_count)
            .into_par_iter()
            .map_init(
                || rand::thread_rng(),
                |rng, _| {
                    let param_groups_iters: Box<[_]> = sub_dists
                        .iter()
                        .map(|dist| dist.sample(rng).into_iter())
                        .collect();
                    let mut params: Vec<_> = splitter.merge(param_groups_iters).collect();
                    let specimen = self
                        .specimens
                        .choose(rng)
                        .expect("solver contains no specimen");
                    let parent_params = specimen.params();
                    let factor1: f64 = Standard.sample(rng);
                    let factor1 = 2.0 * factor1.powf(local_exp);
                    let factor2: f64 = 1.0 - factor1;
                    for i in 0..params.len() {
                        params[i] = factor1 * parent_params[i] + factor2 * params[i];
                        if let SearchRange::Finite { low, high } = self.search_space[i] {
                            // A NaN is never contained, so it is rejected too.
                            if !(low..=high).contains(&params[i]) {
                                return self.random_specimen(rng);
                            }
                        }
                    }
                    (self.constructor)(params)
                },
            )
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::{BasicSpecimen, SearchRange, Solver, Specimen as _};
    use rand::{thread_rng, Rng};

    /// Minimizes the squared distance to a random goal inside the search
    /// space and checks that the best specimen ends up close to that goal.
    #[test]
    fn test_solver() {
        const PARAMCNT: usize = 3;
        let mut rng = thread_rng();
        let ranges = vec![-1.0..=1.0; PARAMCNT];
        let search_space: Vec<SearchRange> =
            ranges.iter().cloned().map(SearchRange::from).collect();
        // Scale by 0.75 so the goal stays away from the range borders.
        let goals: Vec<f64> = ranges
            .iter()
            .cloned()
            .map(|range| rng.gen_range(range) * 0.75)
            .collect();
        let mut solver = Solver::new(search_space, |params: Vec<f64>| {
            let cost = params
                .iter()
                .zip(goals.iter())
                .fold(0.0_f64, |acc, (param, goal)| {
                    acc + (param - goal) * (param - goal)
                });
            BasicSpecimen { params, cost }
        });
        let initial_specimens = solver.random_specimens(200);
        solver.extend_specimens(initial_specimens);
        for _ in 0..1000 {
            let children = solver.recombined_specimens(10, 0.0);
            solver.replace_worst_specimens(children);
        }
        let best = solver.specimens.first().unwrap();
        for (param, goal) in best.params().iter().zip(goals.iter()) {
            assert!((param - goal).abs() < 0.01);
        }
    }
}