use crossbeam_channel::bounded;
use polytype::TypeSchema;
use rayon::prelude::*;
use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use Task;
/// Settings controlling the exploration (enumeration) phase of explore–compress.
#[derive(Clone, Debug)]
pub struct ECParams {
    /// Maximum number of solutions kept per task; once a task's frontier holds this many
    /// solutions, that task stops being checked against new expressions.
    pub frontier_limit: usize,
    /// Optional wall-clock limit applied to each enumeration run.
    pub search_limit_timeout: Option<Duration>,
    /// Optional cap on description length (`-log prior`); enumeration is asked to stop once an
    /// expression's description length exceeds it.
    pub search_limit_description_length: Option<f64>,
}
/// A representation that can be learned with explore–compress (EC).
///
/// Implementors provide `enumerate` and `compress`. Exploring tasks and running a full EC
/// iteration are provided on top of those two methods.
pub trait EC: Send + Sync + Sized {
    /// The expressions enumerated by this representation and stored in frontiers.
    type Expression: Clone + Send + Sync;
    /// Parameters passed through to `compress`.
    type Params;

    /// Enumerates expressions of type `tp`, handing each expression and its log prior to
    /// `termination_condition`. Enumeration is expected to stop once that returns `true`.
    fn enumerate<F>(&self, tp: TypeSchema, termination_condition: F)
    where
        F: Fn(Self::Expression, f64) -> bool + Send + Sync;

    /// Builds a new representation from the `frontiers` found for `tasks`, returning it
    /// together with a set of frontiers.
    fn compress<O: Sync>(
        &self,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
        frontiers: Vec<ECFrontier<Self>>,
    ) -> (Self, Vec<ECFrontier<Self>>);

    /// Runs one EC iteration: `explore` the tasks with this representation, then `compress`.
    fn ec<O: Sync>(
        &self,
        ecparams: &ECParams,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
    ) -> (Self, Vec<ECFrontier<Self>>) {
        let frontiers = self.explore(ecparams, tasks);
        report_exploration(&frontiers);
        self.compress(params, tasks, frontiers)
    }

    /// Like `ec`, but explores each task with its own representation produced by
    /// `recognizer`, which must return exactly one representation per task, in task order.
    ///
    /// # Panics
    ///
    /// With the default `explore_with_recognition`, panics if `recognizer` returns a different
    /// number of representations than there are tasks.
    fn ec_with_recognition<O: Sync, R>(
        &self,
        ecparams: &ECParams,
        params: &Self::Params,
        tasks: &[Task<Self, Self::Expression, O>],
        recognizer: R,
    ) -> (Self, Vec<ECFrontier<Self>>)
    where
        R: FnOnce(&Self, &[Task<Self, Self::Expression, O>]) -> Vec<Self>,
    {
        let recognized = recognizer(self, tasks);
        let frontiers = self.explore_with_recognition(ecparams, tasks, &recognized);
        report_exploration(&frontiers);
        self.compress(params, tasks, frontiers)
    }

    /// Searches for solutions to every task, returning one frontier per task in task order.
    ///
    /// Tasks that share a type are enumerated together in one run. Distinct types are
    /// enumerated in parallel.
    fn explore<O: Sync>(
        &self,
        ec_params: &ECParams,
        tasks: &[Task<Self, Self::Expression, O>],
    ) -> Vec<ECFrontier<Self>> {
        // Group tasks (with their original indices) by type so each type is enumerated once.
        let mut tps = HashMap::new();
        for (i, task) in tasks.iter().enumerate() {
            tps.entry(&task.tp).or_insert_with(Vec::new).push((i, task))
        }
        let mut results: Vec<ECFrontier<Self>> =
            (0..tasks.len()).map(|_| ECFrontier::default()).collect();
        {
            // Scoped so the mutable borrow of `results` ends before it is returned.
            let slots = Mutex::new(&mut results);
            tps.into_par_iter()
                .flat_map(|(tp, tasks)| enumerate_solutions(self, ec_params, tp.clone(), tasks))
                .for_each(|(i, frontier)| {
                    let mut slots = slots.lock().unwrap();
                    slots[i] = frontier
                });
        }
        results
    }

    /// Searches for solutions to each task using the matching entry of `representations`,
    /// returning one frontier per task in task order.
    ///
    /// # Panics
    ///
    /// Panics if `representations` and `tasks` differ in length; otherwise the parallel `zip`
    /// would silently drop the unmatched tasks.
    fn explore_with_recognition<O: Sync>(
        &self,
        ec_params: &ECParams,
        tasks: &[Task<Self, Self::Expression, O>],
        representations: &[Self],
    ) -> Vec<ECFrontier<Self>> {
        assert_eq!(
            tasks.len(),
            representations.len(),
            "explore_with_recognition needs exactly one representation per task"
        );
        tasks
            .par_iter()
            .zip(representations)
            .enumerate()
            .map(|(i, (t, repr))| {
                // A single task goes in, so exactly one (index, frontier) pair comes back.
                enumerate_solutions(repr, ec_params, t.tp.clone(), vec![(i, t)])
                    .pop()
                    .unwrap()
                    .1
            })
            .collect()
    }
}

/// Reports how many frontiers were explored and how many are non-empty, on stderr, when the
/// `verbose` feature is enabled.
fn report_exploration<L: EC>(frontiers: &[ECFrontier<L>]) {
    if cfg!(feature = "verbose") {
        eprintln!(
            "EXPLORE-COMPRESS: explored {} frontiers with {} hits",
            frontiers.len(),
            frontiers.iter().filter(|f| !f.is_empty()).count()
        )
    }
}
/// Enumerates expressions of type `tp` with `repr`, collecting the solutions to `tasks`.
///
/// `tasks` pairs each task with the caller's index for it. The returned vector carries the
/// same indices, in the same order, each alongside that task's frontier.
///
/// An expression joins a task's frontier when the task's oracle gives it a finite log
/// likelihood. A task stops being checked once its frontier holds `params.frontier_limit`
/// solutions. The termination condition reports `true` when any of these holds:
/// - every task is full;
/// - `params.search_limit_timeout` has elapsed;
/// - an expression's description length (`-logprior`) exceeds
///   `params.search_limit_description_length`.
///
/// # Panics
///
/// Panics if a lock is poisoned, or if `repr.enumerate` keeps the termination condition alive
/// after it returns.
fn enumerate_solutions<L, X, O: Sync>(
    repr: &L,
    params: &ECParams,
    tp: TypeSchema,
    tasks: Vec<(usize, &Task<L, X, O>)>,
) -> Vec<(usize, ECFrontier<L>)>
where
    X: Send + Sync + Clone,
    L: EC<Expression = X>,
{
    // One entry per task: (caller's index, the task while it still wants solutions, frontier).
    // Entries are never removed; a full task has its `Option` set to `None`.
    let frontiers: Vec<_> = tasks
        .into_iter()
        .map(|(j, t)| (j, Some(t), ECFrontier::default()))
        .collect();
    let frontiers = Arc::new(RwLock::new(frontiers));

    // Timeout: a helper thread posts to a one-slot channel once the duration has elapsed.
    let mut timeout_complete: Box<dyn Fn() -> bool + Send + Sync> = Box::new(|| false);
    let (tx, rx) = bounded(1);
    if let Some(duration) = params.search_limit_timeout {
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver is gone if enumeration already finished; ignore that.
            tx.send(()).unwrap_or(())
        });
        timeout_complete = Box::new(move || rx.try_recv().is_ok());
    }

    // Description-length limit, checked against the expression currently being offered.
    let mut dl_complete: Box<dyn Fn(f64) -> bool + Send + Sync> = Box::new(|_| false);
    if let Some(dl) = params.search_limit_description_length {
        dl_complete = Box::new(move |logprior| -logprior > dl);
    }

    // Sticky flag: once any call decides to stop, every later call reports `true` immediately.
    let is_terminated = RwLock::new(false);

    let termination_condition = {
        let frontiers = Arc::clone(&frontiers);
        move |expr: X, logprior: f64| {
            {
                if *is_terminated.read().unwrap() {
                    return true;
                }
            }

            // Test the expression against every task that still wants solutions.
            let hits: Vec<_> = frontiers
                .read()
                .expect("enumeration frontiers poisoned")
                .iter()
                .enumerate()
                .filter_map(|(i, &(_, ref ot, _))| ot.as_ref().map(|t| (i, t)))
                .filter_map(|(i, t)| {
                    let l = (t.oracle)(repr, &expr);
                    if l.is_finite() {
                        Some((i, expr.clone(), logprior, l))
                    } else {
                        None
                    }
                })
                .collect();

            if !hits.is_empty() {
                let mut frontiers = frontiers.write().expect("enumeration frontiers poisoned");
                for (i, expr, logprior, l) in hits {
                    // A concurrent call may have filled this frontier after our read lock was
                    // released; never push past the limit.
                    if frontiers[i].1.is_none() {
                        continue;
                    }
                    frontiers[i].2.push(expr, logprior, l);
                    if frontiers[i].2.len() >= params.frontier_limit {
                        frontiers[i].1 = None
                    }
                }
            }

            let mut is_terminated = is_terminated.write().unwrap();
            // Entries are only marked done, never removed, so "every task is full" means every
            // entry's task slot is `None` (vacuously true when there were no tasks).
            let all_tasks_done = frontiers
                .read()
                .expect("enumeration frontiers poisoned")
                .iter()
                .all(|entry| entry.1.is_none());
            if *is_terminated || all_tasks_done || timeout_complete() || dl_complete(logprior) {
                *is_terminated = true;
                true
            } else {
                false
            }
        }
    };

    repr.enumerate(tp, termination_condition);
    if let Ok(l) = Arc::try_unwrap(frontiers) {
        let frontiers = l.into_inner().expect("enumeration frontiers poisoned");
        frontiers.into_iter().map(|(j, _, f)| (j, f)).collect()
    } else {
        panic!("enumeration lifetime exceeded its scope")
    }
}
#[derive(Clone, Debug)]
pub struct ECFrontier<L: EC>(pub Vec<(L::Expression, f64, f64)>);
impl<L: EC> ECFrontier<L> {
    /// Appends a solution with its log prior and log likelihood.
    pub fn push(&mut self, expr: L::Expression, log_prior: f64, log_likelihood: f64) {
        self.0.push((expr, log_prior, log_likelihood))
    }

    /// Returns the entry with the largest `log_prior + log_likelihood`.
    ///
    /// Returns `None` when the frontier has no entry with a comparable score.
    ///
    /// Entries whose score is NaN are skipped instead of causing a panic. When several entries
    /// tie for the maximum, the last of them is returned (`Iterator::max_by` semantics).
    pub fn best_solution(&self) -> Option<&(L::Expression, f64, f64)> {
        self.0
            .iter()
            .filter(|&&(_, p, l)| !(p + l).is_nan())
            .max_by(|&&(_, xp, xl), &&(_, yp, yl)| {
                (xp + xl)
                    .partial_cmp(&(yp + yl))
                    .expect("non-NaN scores are always comparable")
            })
    }
}
impl<L: EC> Default for ECFrontier<L> {
    /// Creates a frontier holding no solutions.
    fn default() -> Self {
        ECFrontier(Vec::new())
    }
}
impl<L: EC> Deref for ECFrontier<L> {
type Target = Vec<(L::Expression, f64, f64)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<L: EC> DerefMut for ECFrontier<L> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}