use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex};
use polytype::Type;
use rayon::prelude::*;
use Task;
/// Limits that bound the exploration (search) phase.
pub struct ECParams {
    /// Number of solutions after which a task stops collecting more.
    pub frontier_limit: usize,
    /// Budget of candidate expressions a single search may evaluate.
    pub search_limit: usize,
}
/// A program representation that can enumerate candidate expressions and be
/// rebuilt ("compressed") from the solutions found for a set of tasks.
pub trait EC: Send + Sync + Sized {
    /// The expression type produced by enumeration and judged by task oracles.
    type Expression: Clone + Send + Sync;
    /// Representation-specific parameters passed to [`EC::compress`].
    type Params;

    /// Enumerates expressions of type `tp`, each paired with an `f64` that the
    /// search code treats as the expression's log-prior.
    fn enumerate<'a>(&'a self, tp: Type) -> Box<Iterator<Item = (Self::Expression, f64)> + 'a>;

    /// Builds an updated representation from `tasks` and their `frontiers`,
    /// returning it together with a vector of frontiers.
    fn compress<O: Sync>(
        &self,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
        frontiers: Vec<ECFrontier<Self>>,
    ) -> (Self, Vec<ECFrontier<Self>>);

    /// One explore-then-compress iteration using `self` for every task.
    fn ec<O: Sync>(
        &self,
        ecparams: &ECParams,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
    ) -> (Self, Vec<ECFrontier<Self>>) {
        let frontiers = self.explore(ecparams, tasks);
        self.compress(params, tasks, frontiers)
    }

    /// One explore-then-compress iteration where `recognizer` supplies a
    /// separate representation for each task to search with. Compression
    /// still starts from `self`.
    ///
    /// # Panics
    /// Panics if `recognizer` does not return exactly one representation per task.
    fn ec_with_recognition<O: Sync, R>(
        &self,
        ecparams: &ECParams,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
        recognizer: R,
    ) -> (Self, Vec<ECFrontier<Self>>)
    where
        R: FnOnce(&Self, &[Task<Self, Self::Expression, O>]) -> Vec<Self>,
    {
        let recognized = recognizer(self, tasks);
        let frontiers = self.explore_with_recognition(ecparams, tasks, &recognized);
        self.compress(params, tasks, frontiers)
    }

    /// Searches for solutions to every task using `self`.
    ///
    /// Tasks are grouped by requested type, so each distinct type is enumerated
    /// once. The groups are searched in parallel. The result holds one frontier
    /// per task, in the same order as `tasks`.
    fn explore<O: Sync>(
        &self,
        ec_params: &ECParams,
        tasks: &[Task<Self, Self::Expression, O>],
    ) -> Vec<ECFrontier<Self>> {
        let mut tps = HashMap::new();
        for (i, task) in tasks.iter().enumerate() {
            tps.entry(&task.tp).or_insert_with(Vec::new).push((i, task))
        }
        let mut results: Vec<ECFrontier<Self>> =
            (0..tasks.len()).map(|_| ECFrontier::default()).collect();
        {
            // Each (index, frontier) pair lands in its task's slot; the lock
            // serialises writes coming from different type groups.
            let mutex = Arc::new(Mutex::new(&mut results));
            tps.into_par_iter()
                .flat_map(|(tp, group)| enumerate_solutions(self, ec_params, tp.clone(), group))
                .for_each(move |(i, frontier)| {
                    let mut slots = mutex.lock().unwrap();
                    slots[i] = frontier
                });
        }
        results
    }

    /// Searches for solutions to each task using its own entry in
    /// `representations`. Tasks are processed in parallel, and the result is
    /// in task order.
    ///
    /// # Panics
    /// Panics if `representations.len() != tasks.len()`.
    fn explore_with_recognition<O: Sync>(
        &self,
        ec_params: &ECParams,
        tasks: &[Task<Self, Self::Expression, O>],
        representations: &[Self],
    ) -> Vec<ECFrontier<Self>> {
        // `zip` would silently truncate to the shorter side, leaving the
        // returned frontiers misaligned with `tasks`.
        assert_eq!(
            tasks.len(),
            representations.len(),
            "explore_with_recognition requires exactly one representation per task"
        );
        tasks
            .par_iter()
            .zip(representations)
            .enumerate()
            .map(|(i, (t, repr))| {
                enumerate_solutions(repr, ec_params, t.tp.clone(), vec![(i, t)])
                    .pop()
                    .expect("enumerate_solutions yields one frontier per input task")
                    .1
            })
            .collect()
    }
}
/// Enumerates expressions of type `tp` from `repr` and records, for each of
/// `tasks`, every expression its oracle scores with a finite log-likelihood.
///
/// Each element of `tasks` pairs a task with a caller-side index. The result
/// holds exactly one `(index, frontier)` entry per input task, in input order.
///
/// A task stops being searched once its frontier reaches
/// `params.frontier_limit` solutions. The whole search stops when no task
/// remains open, when `params.search_limit` expressions have been evaluated,
/// or when enumeration runs out. If either limit is zero, or there are no
/// tasks, nothing is enumerated and every frontier is empty.
fn enumerate_solutions<L, X, O: Sync>(
    repr: &L,
    params: &ECParams,
    tp: Type,
    tasks: Vec<(usize, &Task<L, X, O>)>,
) -> Vec<(usize, ECFrontier<L>)>
where
    X: Send + Sync + Clone,
    L: EC<Expression = X>,
{
    let mut frontiers: Vec<(usize, ECFrontier<L>)> = tasks
        .iter()
        .map(|&(j, _)| (j, ECFrontier::default()))
        .collect();
    // Without this guard, a zero limit would still evaluate (and keep) one
    // expression, because the limits are only checked after an evaluation.
    if tasks.is_empty() || params.frontier_limit == 0 || params.search_limit == 0 {
        return frontiers;
    }
    // Re-key the open tasks by their position in `frontiers`.
    let mut open: Vec<(usize, &Task<L, X, O>)> = tasks
        .into_iter()
        .enumerate()
        .map(|(i, (_, t))| (i, t))
        .collect();
    let mut searched = 0;
    for (expr, log_prior) in repr.enumerate(tp) {
        // Score this candidate against every still-open task in parallel.
        let evaluations: Vec<_> = open
            .par_iter()
            .map(|&(i, t)| (i, t, (t.oracle)(repr, &expr)))
            .collect();
        open = evaluations
            .into_iter()
            .filter_map(|(i, t, log_likelihood)| {
                // A non-finite likelihood means "not a solution"; keep searching.
                if !log_likelihood.is_finite() {
                    return Some((i, t));
                }
                let frontier = &mut frontiers[i].1;
                frontier.push(expr.clone(), log_prior, log_likelihood);
                if frontier.len() < params.frontier_limit {
                    Some((i, t))
                } else {
                    None
                }
            })
            .collect();
        if open.is_empty() {
            break;
        }
        searched += 1;
        if searched >= params.search_limit {
            break;
        }
    }
    frontiers
}
/// The solutions found for a single task.
///
/// Each entry is `(expression, log_prior, log_likelihood)`.
#[derive(Clone, Debug)]
pub struct ECFrontier<L: EC>(pub Vec<(L::Expression, f64, f64)>);
impl<L: EC> ECFrontier<L> {
    /// Appends a solution together with its log-prior and log-likelihood.
    pub fn push(&mut self, expr: L::Expression, log_prior: f64, log_likelihood: f64) {
        self.0.push((expr, log_prior, log_likelihood))
    }

    /// Returns the entry with the highest `log_prior + log_likelihood`.
    ///
    /// Entries whose combined score is NaN cannot be ranked and are skipped.
    /// The result is `None` when the frontier is empty or holds only such
    /// entries. If several entries tie for the best score, the last one wins.
    pub fn best_solution(&self) -> Option<&(L::Expression, f64, f64)> {
        self.0
            .iter()
            .filter(|&&(_, log_prior, log_likelihood)| !(log_prior + log_likelihood).is_nan())
            .max_by(|&&(_, xp, xl), &&(_, yp, yl)| {
                (xp + xl)
                    .partial_cmp(&(yp + yl))
                    .expect("NaN scores were filtered out before comparison")
            })
    }
}
impl<L> Default for ECFrontier<L>
where
    L: EC,
{
    /// Creates a frontier that holds no solutions.
    fn default() -> Self {
        ECFrontier(Vec::new())
    }
}
impl<L: EC> Deref for ECFrontier<L> {
type Target = Vec<(L::Expression, f64, f64)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<L: EC> DerefMut for ECFrontier<L> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}