use crossbeam_channel::bounded;
use polytype::TypeScheme;
use rayon::prelude::*;
use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use crate::Task;
/// Limits applied while enumerating candidate solutions for tasks.
pub struct ECParams {
/// Once a task's frontier holds at least this many solutions, the search stops
/// checking newly enumerated expressions against that task.
pub frontier_limit: usize,
/// If set, enumeration is flagged for termination once this much time has elapsed
/// since the search for a group of tasks was started.
pub search_limit_timeout: Option<Duration>,
/// If set, enumeration is flagged for termination upon reaching an expression whose
/// negated log-prior is greater than this value.
pub search_limit_description_length: Option<f64>,
}
/// Exploration-compression ("EC") over a representation of some domain.
///
/// An implementor supplies `enumerate` and `compress`. The remaining methods are
/// provided and combine those two into a search-then-compress step over a set of tasks.
pub trait EC<Observation: ?Sized>: Sync + Sized {
    /// Expressions handed out by enumeration and stored in frontiers.
    type Expression: Clone + Send + Sync;
    /// Extra, implementation-defined parameters forwarded to `compress`.
    type Params;

    /// Enumerates expressions for the type `tp`, handing each one together with an
    /// `f64` to `termination_condition`.
    ///
    /// The callback built in this file treats the `f64` as the expression's log-prior
    /// and returns `true` once the search should stop; implementors are presumably
    /// expected to stop enumerating at that point.
    fn enumerate<F>(&self, tp: TypeScheme, termination_condition: F)
    where
        F: Fn(Self::Expression, f64) -> bool + Sync;

    /// Builds a new representation from the given `frontiers`, returning it along with
    /// the frontiers that go with it. What "compression" means is up to the implementor.
    fn compress(
        &self,
        params: &Self::Params,
        tasks: &[impl Task<Observation, Representation = Self, Expression = Self::Expression>],
        frontiers: Vec<ECFrontier<Self::Expression>>,
    ) -> (Self, Vec<ECFrontier<Self::Expression>>);

    /// Runs one EC step: `explore` the tasks, then `compress` the resulting frontiers.
    ///
    /// When the `verbose` feature is enabled, a one-line summary of the exploration is
    /// written to standard error.
    fn ec(
        &self,
        ecparams: &ECParams,
        params: &Self::Params,
        tasks: &[impl Task<Observation, Representation = Self, Expression = Self::Expression>],
    ) -> (Self, Vec<ECFrontier<Self::Expression>>) {
        let explored = self.explore(ecparams, tasks);
        if cfg!(feature = "verbose") {
            let hit_count = explored
                .iter()
                .filter(|frontier| !frontier.is_empty())
                .count();
            eprintln!(
                "EXPLORE-COMPRESS: explored {} frontiers with {} hits",
                explored.len(),
                hit_count
            );
        }
        self.compress(params, tasks, explored)
    }

    /// Like `ec`, but exploration for each task uses a task-specific representation.
    ///
    /// `recognizer` is called once with this representation and all tasks; the
    /// representations it returns are paired with the tasks by position.
    fn ec_with_recognition<T, R>(
        &self,
        ecparams: &ECParams,
        params: &Self::Params,
        tasks: &[T],
        recognizer: R,
    ) -> (Self, Vec<ECFrontier<Self::Expression>>)
    where
        T: Task<Observation, Representation = Self, Expression = Self::Expression>,
        R: FnOnce(&Self, &[T]) -> Vec<Self>,
    {
        let per_task_reprs = recognizer(self, tasks);
        let explored = self.explore_with_recognition(ecparams, tasks, &per_task_reprs);
        self.compress(params, tasks, explored)
    }

    /// Searches for solutions to every task, returning one frontier per task in the
    /// same order as `tasks`.
    ///
    /// Tasks that share a type are searched together in a single enumeration, and the
    /// distinct types are processed in parallel.
    fn explore(
        &self,
        ec_params: &ECParams,
        tasks: &[impl Task<Observation, Representation = Self, Expression = Self::Expression>],
    ) -> Vec<ECFrontier<Self::Expression>> {
        // Bucket the tasks by type, remembering each task's position in `tasks`.
        let mut by_type = HashMap::new();
        for (idx, task) in tasks.iter().enumerate() {
            by_type
                .entry(task.tp())
                .or_insert_with(Vec::new)
                .push((idx, task));
        }

        // Start with an empty frontier in every slot; slots are overwritten below.
        let mut results: Vec<ECFrontier<Self::Expression>> =
            tasks.iter().map(|_| ECFrontier::default()).collect();
        {
            // The parallel workers write into `results` through this lock. The inner
            // scope ends the mutable borrow before `results` is returned.
            let sink = Mutex::new(&mut results);
            by_type
                .into_par_iter()
                .flat_map(|(tp, group)| enumerate_solutions(self, ec_params, tp.clone(), group))
                .for_each(|(idx, frontier)| {
                    let mut slots = sink.lock().unwrap();
                    slots[idx] = frontier;
                });
        }
        results
    }

    /// Like `explore`, but task `i` is searched using `representations[i]` instead of
    /// `self`.
    ///
    /// Tasks are searched in parallel, one enumeration per task. Tasks and
    /// representations are paired with rayon's `zip`, which stops at the shorter
    /// input: if `representations` has fewer entries than `tasks`, fewer frontiers
    /// than tasks are returned.
    fn explore_with_recognition(
        &self,
        ec_params: &ECParams,
        tasks: &[impl Task<Observation, Representation = Self, Expression = Self::Expression>],
        representations: &[Self],
    ) -> Vec<ECFrontier<Self::Expression>> {
        tasks
            .par_iter()
            .zip(representations)
            .enumerate()
            .map(|(idx, (task, repr))| {
                let mut solved =
                    enumerate_solutions(repr, ec_params, task.tp().clone(), vec![(idx, task)]);
                // Exactly one task went in, so exactly one (id, frontier) pair comes out.
                let (_, frontier) = solved.pop().unwrap();
                frontier
            })
            .collect()
    }
}
/// Enumerates expressions of type `tp` from `repr`, scores each one against every
/// still-unsolved task in `tasks`, and returns one frontier per task.
///
/// Each element of `tasks` is `(id, task)`; the returned vector carries the same ids in
/// the same order, each paired with the frontier collected for that task.
///
/// Enumeration is asked to stop as soon as any of the following holds:
/// - every task has gathered at least `params.frontier_limit` solutions (or there were
///   no tasks to begin with),
/// - `params.search_limit_timeout` has elapsed,
/// - an expression's negated log-prior exceeds `params.search_limit_description_length`.
///
/// # Panics
///
/// Panics if a lock was poisoned by a panicking enumeration thread, or if the
/// termination callback is still alive after `enumerate` returns.
fn enumerate_solutions<Observation, L, T>(
    repr: &L,
    params: &ECParams,
    tp: TypeScheme,
    tasks: Vec<(usize, &T)>,
) -> Vec<(usize, ECFrontier<L::Expression>)>
where
    Observation: ?Sized,
    L: EC<Observation>,
    T: Task<Observation, Representation = L, Expression = L::Expression>,
{
    // Each entry is (caller's task id, the task while it is still unsolved, frontier).
    // A finished task has its middle slot set to `None`; entries are never removed.
    let frontiers: Vec<_> = tasks
        .into_iter()
        .map(|(j, t)| (j, Some(t), ECFrontier::default()))
        .collect();
    let frontiers = Arc::new(RwLock::new(frontiers));

    // Timeout: a helper thread sleeps for the whole duration and then sends one
    // message. The thread is detached, so it outlives this call if the search
    // finishes early.
    let mut timeout_complete: Box<dyn Fn() -> bool + Send + Sync> = Box::new(|| false);
    if let Some(duration) = params.search_limit_timeout {
        let (tx, rx) = bounded(1);
        thread::spawn(move || {
            thread::sleep(duration);
            tx.send(()).unwrap_or(())
        });
        timeout_complete = Box::new(move || rx.try_recv().is_ok());
    }

    // Description-length limit, measured as the negated log-prior.
    let mut dl_complete: Box<dyn Fn(f64) -> bool + Send + Sync> = Box::new(|_| false);
    if let Some(dl) = params.search_limit_description_length {
        dl_complete = Box::new(move |logprior| -logprior > dl);
    }

    // Sticky flag: once any stop condition fires, every later call reports `true`.
    // This matters for the timeout, whose message can only be received once.
    let is_terminated = Arc::new(RwLock::new(false));

    let termination_condition = {
        let frontiers = Arc::clone(&frontiers);
        move |expr: L::Expression, logprior: f64| {
            if *is_terminated.read().unwrap() {
                return true;
            }

            // Score the expression against every unsolved task under the read lock.
            let hits: Vec<_> = frontiers
                .read()
                .expect("enumeration frontiers poisoned")
                .iter()
                .enumerate()
                .filter_map(|(i, (_, ot, _))| ot.as_ref().map(|t| (i, t)))
                .filter_map(|(i, t)| {
                    let l = t.oracle(repr, &expr);
                    if l.is_finite() {
                        Some((i, expr.clone(), logprior, l))
                    } else {
                        None
                    }
                })
                .collect();

            if !hits.is_empty() {
                let mut frontiers = frontiers.write().expect("enumeration frontiers poisoned");
                for (i, expr, logprior, l) in hits {
                    // Another thread may have filled this frontier between our read
                    // and this write; don't push past the limit.
                    if frontiers[i].1.is_none() {
                        continue;
                    }
                    frontiers[i].2.push(expr, logprior, l);
                    if frontiers[i].2.len() >= params.frontier_limit {
                        frontiers[i].1 = None;
                    }
                }
            }

            // Finished tasks are marked `None` rather than removed, so "nothing left
            // to solve" means every slot is `None` (trivially true for no tasks).
            let all_tasks_done = frontiers
                .read()
                .expect("enumeration frontiers poisoned")
                .iter()
                .all(|(_, ot, _)| ot.is_none());

            let mut is_terminated = is_terminated.write().unwrap();
            if *is_terminated || all_tasks_done || timeout_complete() || dl_complete(logprior) {
                *is_terminated = true;
                true
            } else {
                false
            }
        }
    };

    repr.enumerate(tp, termination_condition);

    // The callback (and its `Arc` clone) was consumed by `enumerate`, so this is
    // expected to be the only remaining handle.
    match Arc::try_unwrap(frontiers) {
        Ok(lock) => lock
            .into_inner()
            .expect("enumeration frontiers poisoned")
            .into_iter()
            .map(|(j, _, f)| (j, f))
            .collect(),
        Err(_) => panic!("enumeration lifetime exceeded its scope"),
    }
}
/// The solutions collected for a single task.
///
/// Each entry is `(expression, log-prior, log-likelihood)`.
#[derive(Clone, Debug)]
pub struct ECFrontier<Expression>(pub Vec<(Expression, f64, f64)>);
impl<Expression> ECFrontier<Expression> {
    /// Appends a solution with its log-prior and log-likelihood.
    pub fn push(&mut self, expr: Expression, log_prior: f64, log_likelihood: f64) {
        self.0.push((expr, log_prior, log_likelihood))
    }

    /// Returns the entry with the largest `log_prior + log_likelihood`, or `None` if the
    /// frontier is empty.
    ///
    /// When several entries tie for the largest score, the last of them is returned.
    /// Entries whose score is NaN rank below every other entry instead of causing a
    /// panic, so a NaN-scored entry is only returned when all scores are NaN.
    pub fn best_solution(&self) -> Option<&(Expression, f64, f64)> {
        self.0.iter().max_by(|(_, xp, xl), (_, yp, yl)| {
            let (x, y) = (xp + xl, yp + yl);
            // `partial_cmp` is `None` only when a NaN is involved; order NaN lowest.
            x.partial_cmp(&y)
                .unwrap_or_else(|| y.is_nan().cmp(&x.is_nan()))
        })
    }
}
/// An empty frontier. No `Default` bound is placed on `Expression`.
impl<Expression> Default for ECFrontier<Expression> {
    fn default() -> Self {
        ECFrontier(Vec::new())
    }
}
impl<Expression> Deref for ECFrontier<Expression> {
type Target = Vec<(Expression, f64, f64)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<Expression> DerefMut for ECFrontier<Expression> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}