use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::thread;
use crossbeam_channel::bounded;
use polytype::TypeSchema;
use rayon::prelude::*;
use Task;
/// Limits that bound the enumeration performed during exploration.
#[derive(Clone, Debug)]
pub struct ECParams {
    /// A task stops accepting solutions once its frontier holds this many entries.
    pub frontier_limit: usize,
    /// Wall-clock budget for each enumeration run (one run per group of tasks
    /// searched together). `None` means no time limit.
    pub search_limit_timeout: Option<Duration>,
    /// Enumeration stops once an enumerated expression's description length
    /// (`-log_prior`) exceeds this value. `None` means no length limit.
    pub search_limit_description_length: Option<f64>,
}
/// A program representation (for example a grammar) that can take part in the
/// explore–compress loop: enumerate expressions, score them against tasks, and
/// compress the resulting frontiers into an updated representation.
pub trait EC: Send + Sync + Sized {
    /// The expressions this representation enumerates.
    type Expression: Clone + Send + Sync;
    /// Extra parameters consumed by [`EC::compress`].
    type Params;

    /// Yields expressions of type `tp`, each paired with its log prior.
    ///
    /// The search loop stops at the first expression whose description length
    /// (`-log_prior`) exceeds `ECParams::search_limit_description_length`. That
    /// limit therefore only acts as a cutoff if expressions arrive in
    /// non-increasing log-prior order.
    fn enumerate<'a>(
        &'a self,
        tp: TypeSchema,
    ) -> Box<Iterator<Item = (Self::Expression, f64)> + 'a>;

    /// Builds an updated representation from the frontiers found by
    /// exploration, returning it together with a vector of frontiers.
    fn compress<O: Sync>(
        &self,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
        frontiers: Vec<ECFrontier<Self>>,
    ) -> (Self, Vec<ECFrontier<Self>>);

    /// Runs one explore–compress iteration: [`EC::explore`] followed by
    /// [`EC::compress`].
    fn ec<O: Sync>(
        &self,
        ecparams: &ECParams,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
    ) -> (Self, Vec<ECFrontier<Self>>) {
        let frontiers = self.explore(ecparams, tasks);
        self.compress(params, tasks, frontiers)
    }

    /// Like [`EC::ec`], but explores each task with its own representation.
    ///
    /// `recognizer` is called once with `self` and `tasks`. It must return one
    /// representation per task, in task order. Exploration uses those
    /// representations, while compression still starts from `self`.
    fn ec_with_recognition<O: Sync, R>(
        &self,
        ecparams: &ECParams,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
        recognizer: R,
    ) -> (Self, Vec<ECFrontier<Self>>)
    where
        R: FnOnce(&Self, &[Task<Self, Self::Expression, O>]) -> Vec<Self>,
    {
        let recognized = recognizer(self, tasks);
        let frontiers = self.explore_with_recognition(ecparams, tasks, &recognized);
        self.compress(params, tasks, frontiers)
    }

    /// Enumerates solutions for every task using this representation.
    ///
    /// Tasks are grouped by type so that each distinct type is enumerated only
    /// once, and the groups are searched in parallel. The result has one
    /// frontier per task, in the same order as `tasks`. Tasks for which nothing
    /// was found get an empty frontier.
    fn explore<O: Sync>(
        &self,
        ec_params: &ECParams,
        tasks: &[Task<Self, Self::Expression, O>],
    ) -> Vec<ECFrontier<Self>> {
        // Group tasks by type, remembering each task's index in `tasks`.
        let mut tps = HashMap::new();
        for (i, task) in tasks.iter().enumerate() {
            tps.entry(&task.tp).or_insert_with(Vec::new).push((i, task))
        }
        // Let rayon combine the per-group results instead of funnelling every
        // write through a shared lock.
        let found: Vec<(usize, ECFrontier<Self>)> = tps
            .into_par_iter()
            .flat_map(|(tp, group)| enumerate_solutions(self, ec_params, tp.clone(), group))
            .collect();
        let mut results: Vec<ECFrontier<Self>> =
            (0..tasks.len()).map(|_| ECFrontier::default()).collect();
        // Each task index belongs to exactly one group, so each slot is written once.
        for (i, frontier) in found {
            results[i] = frontier;
        }
        results
    }

    /// Enumerates solutions for each task using the matching entry of
    /// `representations`, searching the tasks in parallel.
    ///
    /// Returns one frontier per task, in the same order as `tasks`.
    ///
    /// # Panics
    ///
    /// Panics if `representations.len() != tasks.len()`.
    fn explore_with_recognition<O: Sync>(
        &self,
        ec_params: &ECParams,
        tasks: &[Task<Self, Self::Expression, O>],
        representations: &[Self],
    ) -> Vec<ECFrontier<Self>> {
        // `zip` would silently drop tasks beyond the shorter slice, leaving the
        // returned frontiers shorter than `tasks`.
        assert_eq!(
            tasks.len(),
            representations.len(),
            "explore_with_recognition: expected one representation per task"
        );
        tasks
            .par_iter()
            .zip(representations)
            .enumerate()
            .map(|(i, (t, repr))| {
                enumerate_solutions(repr, ec_params, t.tp.clone(), vec![(i, t)])
                    .pop()
                    .expect("enumerate_solutions returns one frontier per input task")
                    .1
            })
            .collect()
    }
}
/// Enumerates expressions of type `tp` from `repr` and scores each one against
/// every still-open task in `tasks`.
///
/// `tasks` pairs each task with its index in the caller's task list. The
/// returned vector pairs that same index with the frontier found for the task,
/// in the same order as `tasks`. A candidate is added to a task's frontier when
/// the task's oracle returns a finite log likelihood for it.
///
/// Enumeration stops as soon as any of the following holds:
/// - every task has collected `params.frontier_limit` solutions;
/// - `params.search_limit_timeout` has elapsed (checked after each candidate);
/// - the next candidate's description length (`-log_prior`) exceeds
///   `params.search_limit_description_length`. Such a candidate is not evaluated.
fn enumerate_solutions<L, X, O: Sync>(
    repr: &L,
    params: &ECParams,
    tp: TypeSchema,
    tasks: Vec<(usize, &Task<L, X, O>)>,
) -> Vec<(usize, ECFrontier<L>)>
where
    X: Send + Sync + Clone,
    L: EC<Expression = X>,
{
    // One frontier per task, tagged with the caller's task index.
    let mut frontiers: Vec<(usize, ECFrontier<L>)> = tasks
        .iter()
        .map(|&(j, _)| (j, ECFrontier::default()))
        .collect();
    // Tasks still accepting solutions, keyed by their position in `frontiers`.
    let mut open: Vec<_> = tasks
        .into_iter()
        .enumerate()
        .map(|(i, (_, t))| (i, t))
        .collect();

    // The timer thread signals once the timeout elapses. It is detached; if
    // this function has already returned, its send fails and is ignored.
    let timeout_rx = params.search_limit_timeout.map(|duration| {
        let (tx, rx) = bounded(1);
        thread::spawn(move || {
            thread::sleep(duration);
            tx.send(()).unwrap_or(());
        });
        rx
    });
    let timed_out = || timeout_rx.as_ref().map_or(false, |rx| rx.try_recv().is_ok());
    let exceeds_dl = |log_prior: f64| {
        params
            .search_limit_description_length
            .map_or(false, |dl| -log_prior > dl)
    };

    for (expr, log_prior) in repr.enumerate(tp) {
        // Check the length limit before spending oracle calls on the candidate,
        // so over-limit expressions never reach a frontier.
        if exceeds_dl(log_prior) {
            break;
        }
        let evaluations: Vec<_> = open
            .par_iter()
            .map(|&(i, t)| (i, t, (t.oracle)(repr, &expr)))
            .collect();
        open = evaluations
            .into_iter()
            .filter_map(|(i, t, l)| {
                if l.is_finite() {
                    frontiers[i].1.push(expr.clone(), log_prior, l);
                    if frontiers[i].1.len() >= params.frontier_limit {
                        // This task's frontier is full; stop evaluating it.
                        return None;
                    }
                }
                Some((i, t))
            })
            .collect();
        if open.is_empty() || timed_out() {
            break;
        }
    }
    frontiers
}
#[derive(Clone, Debug)]
pub struct ECFrontier<L: EC>(pub Vec<(L::Expression, f64, f64)>);
impl<L: EC> ECFrontier<L> {
    /// Appends a solution together with its log prior and log likelihood.
    pub fn push(&mut self, expr: L::Expression, log_prior: f64, log_likelihood: f64) {
        self.0.push((expr, log_prior, log_likelihood))
    }

    /// Returns the entry with the highest `log_prior + log_likelihood`, or
    /// `None` if the frontier is empty. Among tied entries, the last one wins.
    ///
    /// Entries whose combined score is NaN are skipped instead of causing a
    /// panic. If every entry scores NaN, `None` is returned.
    pub fn best_solution(&self) -> Option<&(L::Expression, f64, f64)> {
        self.0
            .iter()
            .filter(|&&(_, lp, ll)| !(lp + ll).is_nan())
            .max_by(|&&(_, xp, xl), &&(_, yp, yl)| {
                (xp + xl)
                    .partial_cmp(&(yp + yl))
                    .expect("non-NaN f64 values are totally ordered")
            })
    }
}
/// An empty frontier: no solutions recorded yet.
impl<L: EC> Default for ECFrontier<L> {
    fn default() -> Self {
        let entries = Vec::new();
        ECFrontier(entries)
    }
}
impl<L: EC> Deref for ECFrontier<L> {
type Target = Vec<(L::Expression, f64, f64)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<L: EC> DerefMut for ECFrontier<L> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}