mod compression;
mod enumerator;
mod eval;
mod parser;
pub use self::compression::CompressionParams;
pub use self::eval::{
Evaluator, LazyEvaluator, LiftedFunction, LiftedLazyFunction, SimpleEvaluator,
};
pub use self::parser::ParseError;
use crossbeam_channel::bounded;
use polytype::{Context, Type, TypeSchema, UnificationError};
use rayon::spawn;
use std::collections::{HashMap, VecDeque};
use std::error::Error;
use std::f64;
use std::fmt;
use std::ops::Index;
use std::rc::Rc;
use std::sync::Arc;
use {ECFrontier, Task, EC};
/// A typed domain-specific language: weighted primitives and inventions from which
/// lambda-calculus `Expression`s are built, enumerated, evaluated and compressed.
#[derive(Debug, Clone)]
pub struct Language {
    /// `(name, type, weight)` for each primitive; `Expression::Primitive(i)` refers to entry `i`.
    /// The weight is treated as an unnormalized log-probability when candidates are scored.
    pub primitives: Vec<(String, TypeSchema, f64)>,
    /// `(body, type, weight)` for each invention; `Expression::Invented(i)` refers to entry `i`.
    /// The weight is treated the same way as a primitive's.
    pub invented: Vec<(Expression, TypeSchema, f64)>,
    /// Unnormalized log-weight shared by all variables (de Bruijn indices) that fit a request.
    pub variable_logprob: f64,
    /// Sorted, duplicate-free `(primitive, arg_index, arg_primitive)` triples: applying
    /// `primitive` with `arg_primitive` at argument position `arg_index` is disallowed.
    pub symmetry_violations: Vec<(usize, usize, usize)>,
}
impl Language {
    /// Creates a language in which every primitive has weight `0` (uniform in log space).
    ///
    /// Variables also get weight `0`; there are no inventions and no symmetry violations.
    pub fn uniform(primitives: Vec<(&str, TypeSchema)>) -> Self {
        let primitives = primitives
            .into_iter()
            .map(|(s, t)| (String::from(s), t, 0f64))
            .collect();
        Language {
            primitives,
            invented: vec![],
            variable_logprob: 0f64,
            symmetry_violations: Vec::new(),
        }
    }

    /// Infers the type of `expr` and generalizes the result into a `TypeSchema`.
    ///
    /// Indices not bound by an abstraction inside `expr` are given fresh type variables.
    ///
    /// # Errors
    /// Returns an [`InferenceError`] for an out-of-range primitive or invention, or when
    /// unification fails.
    pub fn infer(&self, expr: &Expression) -> Result<TypeSchema, InferenceError> {
        let mut ctx = Context::default();
        let env = VecDeque::new();
        let mut indices = HashMap::new();
        expr.infer(self, &mut ctx, &env, &mut indices)
            .map(|t| t.generalize(&Context::default()))
    }

    /// Enumerates expressions of type `tp` in a background rayon task, yielding
    /// `(expression, log-prior)` pairs through the returned iterator.
    ///
    /// The enumeration runs on a clone of this language. Dropping the iterator makes the next
    /// send fail, which the termination condition reports as `true` so enumeration can stop.
    pub fn enumerate(&self, tp: TypeSchema) -> Box<Iterator<Item = (Expression, f64)>> {
        // A capacity-1 channel bounds how far the producer can run ahead of the consumer.
        let (tx, rx) = bounded(1);
        let dsl = self.clone();
        spawn(move || {
            let termination_condition = |expr, logprior| tx.send((expr, logprior)).is_err();
            enumerator::run(&dsl, tp, termination_condition)
        });
        Box::new(rx.into_iter())
    }

    /// Compresses `frontiers` into a new language plus rewritten frontiers.
    ///
    /// This delegates to `compression::induce`.
    pub fn compress<O: Sync>(
        &self,
        params: &CompressionParams,
        tasks: &[Task<Language, Expression, O>],
        frontiers: Vec<ECFrontier<Self>>,
    ) -> (Self, Vec<ECFrontier<Self>>) {
        compression::induce(self, params, tasks, frontiers)
    }

    /// Evaluates `expr` on the inputs `inps` with `evaluator`.
    ///
    /// The evaluator is wrapped in an `Arc`; use [`Language::eval_arc`] to reuse a shared one.
    pub fn eval<V, E>(&self, expr: &Expression, evaluator: E, inps: &[V]) -> Result<V, E::Error>
    where
        V: Clone + PartialEq + Send + Sync,
        E: Evaluator<Space = V>,
    {
        eval::eval(self, expr, &Arc::new(evaluator), inps)
    }

    /// Like [`Language::eval`], but takes an already-shared evaluator.
    pub fn eval_arc<V, E>(
        &self,
        expr: &Expression,
        evaluator: &Arc<E>,
        inps: &[V],
    ) -> Result<V, E::Error>
    where
        V: Clone + PartialEq + Send + Sync,
        E: Evaluator<Space = V>,
    {
        eval::eval(self, expr, evaluator, inps)
    }

    /// Evaluates `expr` on `inps` with a [`LazyEvaluator`]; the evaluator is wrapped in an `Arc`.
    pub fn lazy_eval<V, E>(
        &self,
        expr: &Expression,
        evaluator: E,
        inps: &[V],
    ) -> Result<V, E::Error>
    where
        V: Clone + PartialEq + Send + Sync,
        E: LazyEvaluator<Space = V>,
    {
        eval::lazy_eval(self, expr, &Arc::new(evaluator), inps)
    }

    /// Like [`Language::lazy_eval`], but takes an already-shared evaluator.
    pub fn lazy_eval_arc<V, E>(
        &self,
        expr: &Expression,
        evaluator: &Arc<E>,
        inps: &[V],
    ) -> Result<V, E::Error>
    where
        V: Clone + PartialEq + Send + Sync,
        E: LazyEvaluator<Space = V>,
    {
        eval::lazy_eval(self, expr, evaluator, inps)
    }

    /// Scores `expr` against the `request`ed type using `enumerator::likelihood`.
    ///
    /// The result is presumably a log-likelihood under this language; confirm in `enumerator`.
    pub fn likelihood(&self, request: &TypeSchema, expr: &Expression) -> f64 {
        enumerator::likelihood(self, request, expr)
    }

    /// Adds `expr` as a new invention with weight `log_probability`.
    ///
    /// Returns its index, so `Expression::Invented(index)` refers to it.
    ///
    /// # Errors
    /// Returns the [`InferenceError`] if `expr` cannot be typed; the language is then unchanged.
    pub fn invent(
        &mut self,
        expr: Expression,
        log_probability: f64,
    ) -> Result<usize, InferenceError> {
        let tp = self.infer(&expr)?;
        self.invented.push((expr, tp, log_probability));
        Ok(self.invented.len() - 1)
    }

    /// Records that `arg` must not be used at argument position `arg_index` of `primitive`.
    ///
    /// The list stays sorted and free of duplicates.
    pub fn add_symmetry_violation(&mut self, primitive: usize, arg_index: usize, arg: usize) {
        let x = (primitive, arg_index, arg);
        // `Err(i)` means "not present; insert at i to stay sorted".
        if let Err(i) = self.symmetry_violations.binary_search(&x) {
            self.symmetry_violations.insert(i, x)
        }
    }

    /// Reports whether passing `x` as argument `index` of `f` is a recorded symmetry violation.
    ///
    /// Only primitive functions are checked. For an application argument, the head of its
    /// application spine is used, and only if that head is a primitive.
    pub fn violates_symmetry(&self, f: &Expression, index: usize, x: &Expression) -> bool {
        match (f, x) {
            (&Expression::Primitive(f), &Expression::Primitive(x)) => {
                let x = (f, index, x);
                self.symmetry_violations.binary_search(&x).is_ok()
            }
            (&Expression::Primitive(f), &Expression::Application(ref x, _)) => {
                // Walk down the function side to the head of the application spine.
                let mut z: &Expression = &**x;
                while let Expression::Application(ref x, _) = *z {
                    z = x
                }
                if let Expression::Primitive(x) = *z {
                    let x = (f, index, x);
                    self.symmetry_violations.binary_search(&x).is_ok()
                } else {
                    false
                }
            }
            _ => false,
        }
    }

    /// Replaces every invention in `expr` with its (recursively stripped) body.
    pub fn strip_invented(&self, expr: &Expression) -> Expression {
        expr.strip_invented(&self.invented)
    }

    /// Parses `inp` into an expression over this language.
    pub fn parse(&self, inp: &str) -> Result<Expression, ParseError> {
        parser::parse(self, inp)
    }

    /// Renders `expr` for humans: primitive names, `$i` indices, `(λ ...)` abstractions and
    /// `#`-prefixed inventions.
    pub fn display(&self, expr: &Expression) -> String {
        expr.show(self, false)
    }

    /// Renders `expr` as Lisp, renaming primitives through `conversions` where an entry exists.
    pub fn lispify(&self, expr: &Expression, conversions: &HashMap<String, String>) -> String {
        expr.as_lisp(self, false, conversions, 0)
    }

    /// Lists every primitive, invention and in-scope variable that can head an expression of
    /// type `request`, each with a normalized log-probability.
    ///
    /// An item qualifies when its type, or the return type reported by `Type::returns`,
    /// unifies with `request` under a clone of `ctx`. Variable `i` of `env` becomes
    /// `Expression::Index(i)`.
    ///
    /// Each entry is `(log-probability, expression, type after unification, context after
    /// unification)`. The exponentials of the log-probabilities sum to 1.
    fn candidates(
        &self,
        request: &Type,
        ctx: &Context,
        env: &VecDeque<Type>,
    ) -> Vec<(f64, Expression, Type, Context)> {
        let mut cands = Vec::with_capacity(self.primitives.len() + self.invented.len() + env.len());
        let prims = self
            .primitives
            .iter()
            .enumerate()
            .map(|(i, &(_, ref tp, p))| (p, tp, Expression::Primitive(i)));
        let invented = self
            .invented
            .iter()
            .enumerate()
            .map(|(i, &(_, ref tp, p))| (p, tp, Expression::Invented(i)));
        for (p, tp, expr) in prims.chain(invented) {
            let mut ctx = ctx.clone();
            let mut tp = tp.clone().instantiate_owned(&mut ctx);
            // Scoped so the borrow of `tp` ends before `apply_mut` below.
            let unifies = {
                let ret = tp.returns().unwrap_or(&tp);
                ctx.unify_fast(ret.clone(), request.clone()).is_ok()
            };
            if unifies {
                tp.apply_mut(&ctx);
                cands.push((p, expr, tp, ctx))
            }
        }

        let indexed_start = cands.len();
        for (i, tp) in env.iter().enumerate() {
            let mut ctx = ctx.clone();
            let ret = tp.returns().unwrap_or(tp);
            if ctx.unify_fast(ret.clone(), request.clone()).is_ok() {
                let mut tp = tp.clone();
                tp.apply_mut(&ctx);
                cands.push((self.variable_logprob, Expression::Index(i), tp, ctx))
            }
        }

        // Split the shared variable weight evenly among the variables that fit.
        let log_n_indexed = ((cands.len() - indexed_start) as f64).ln();
        for c in &mut cands[indexed_start..] {
            c.0 -= log_n_indexed
        }

        // Normalize with a numerically stable log-sum-exp (subtract the maximum first).
        let p_largest = cands
            .iter()
            .map(|&(p, _, _, _)| p)
            .fold(f64::NEG_INFINITY, f64::max);
        let z = p_largest
            + cands
                .iter()
                .map(|&(p, _, _, _)| (p - p_largest).exp())
                .sum::<f64>()
                .ln();
        for c in &mut cands {
            c.0 -= z;
        }
        cands
    }
}
impl EC for Language {
type Expression = Expression;
type Params = CompressionParams;
fn enumerate<F>(&self, tp: TypeSchema, termination_condition: F)
where
F: Fn(Expression, f64) -> bool + Send + Sync,
{
enumerator::run(self, tp, termination_condition)
}
fn compress<O: Sync>(
&self,
params: &Self::Params,
tasks: &[Task<Self, Self::Expression, O>],
frontiers: Vec<ECFrontier<Self>>,
) -> (Self, Vec<ECFrontier<Self>>) {
self.compress(params, tasks, frontiers)
}
}
/// A lambda-calculus expression over a [`Language`], with variables as de Bruijn indices.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Expression {
    /// The primitive stored at this position in `Language::primitives`.
    Primitive(usize),
    /// Applies the first expression (the function) to the second (the argument).
    Application(Box<Expression>, Box<Expression>),
    /// A lambda; inside the body, `Index(0)` refers to the argument it binds.
    Abstraction(Box<Expression>),
    /// A de Bruijn variable: `0` is the innermost enclosing abstraction.
    Index(usize),
    /// The invention stored at this position in `Language::invented`.
    Invented(usize),
}
impl Expression {
fn infer(
&self,
dsl: &Language,
mut ctx: &mut Context,
env: &VecDeque<Type>,
indices: &mut HashMap<usize, Type>,
) -> Result<Type, InferenceError> {
match *self {
Expression::Primitive(num) => if let Some(prim) = dsl.primitives.get(num as usize) {
Ok(prim.1.clone().instantiate_owned(ctx))
} else {
Err(InferenceError::InvalidPrimitive(num))
},
Expression::Application(ref f, ref x) => {
let f_tp = f.infer(dsl, &mut ctx, env, indices)?;
let x_tp = x.infer(dsl, &mut ctx, env, indices)?;
let ret_tp = ctx.new_variable();
ctx.unify(&f_tp, &Type::arrow(x_tp, ret_tp.clone()))?;
Ok(ret_tp.apply(ctx))
}
Expression::Abstraction(ref body) => {
let arg_tp = ctx.new_variable();
let mut env = env.clone();
env.push_front(arg_tp.clone());
let ret_tp = body.infer(dsl, &mut ctx, &env, indices)?;
let mut tp = Type::arrow(arg_tp, ret_tp);
tp.apply_mut(ctx);
Ok(tp)
}
Expression::Index(i) => {
if (i as usize) < env.len() {
let mut tp = env[i as usize].clone();
tp.apply_mut(ctx);
Ok(tp)
} else {
let mut tp = indices
.entry(i - env.len())
.or_insert_with(|| ctx.new_variable())
.clone();
tp.apply_mut(ctx);
Ok(tp)
}
}
Expression::Invented(num) => if let Some(inv) = dsl.invented.get(num as usize) {
Ok(inv.1.clone().instantiate_owned(ctx))
} else {
Err(InferenceError::InvalidInvention(num))
},
}
}
fn strip_invented(&self, invented: &[(Expression, TypeSchema, f64)]) -> Expression {
match *self {
Expression::Application(ref f, ref x) => Expression::Application(
Box::new(f.strip_invented(invented)),
Box::new(x.strip_invented(invented)),
),
Expression::Abstraction(ref body) => {
Expression::Abstraction(Box::new(body.strip_invented(invented)))
}
Expression::Invented(num) => invented[num].0.strip_invented(invented),
_ => self.clone(),
}
}
fn shift(&mut self, offset: i64) -> bool {
self.shift_internal(offset, 0)
}
fn shift_internal(&mut self, offset: i64, depth: usize) -> bool {
match *self {
Expression::Index(ref mut i) => {
if *i < depth {
true
} else if offset >= 0 {
*i += offset as usize;
true
} else if let Some(ni) = i.checked_sub((-offset) as usize) {
*i = ni;
true
} else {
false
}
}
Expression::Application(ref mut f, ref mut x) => {
f.shift_internal(offset, depth) && x.shift_internal(offset, depth)
}
Expression::Abstraction(ref mut body) => body.shift_internal(offset, depth + 1),
_ => true,
}
}
fn as_lisp(
&self,
dsl: &Language,
is_function: bool,
conversions: &HashMap<String, String>,
depth: usize,
) -> String {
match *self {
Expression::Primitive(num) => {
let name = &dsl.primitives[num as usize].0;
conversions.get(name).unwrap_or(name).to_string()
}
Expression::Application(ref f, ref x) => {
let f_lisp = f.as_lisp(dsl, true, conversions, depth);
let x_lisp = x.as_lisp(dsl, false, conversions, depth);
if is_function {
format!("{} {}", f_lisp, x_lisp)
} else {
format!("({} {})", f_lisp, x_lisp)
}
}
Expression::Abstraction(ref body) => {
let var = (97 + depth as u8) as char;
format!(
"(λ ({}) {})",
var,
body.as_lisp(dsl, false, conversions, depth + 1)
)
}
Expression::Index(i) => {
let var = (96 + (depth - i) as u8) as char;
format!("{}", var)
}
Expression::Invented(num) => {
dsl.invented[num as usize]
.0
.as_lisp(dsl, false, conversions, depth)
}
}
}
fn show(&self, dsl: &Language, is_function: bool) -> String {
match *self {
Expression::Primitive(num) => dsl.primitives[num as usize].0.clone(),
Expression::Application(ref f, ref x) => if is_function {
format!("{} {}", f.show(dsl, true), x.show(dsl, false))
} else {
format!("({} {})", f.show(dsl, true), x.show(dsl, false))
},
Expression::Abstraction(ref body) => format!("(λ {})", body.show(dsl, false)),
Expression::Index(i) => format!("${}", i),
Expression::Invented(num) => {
format!("#{}", dsl.invented[num as usize].0.show(dsl, false))
}
}
}
}
/// Builds a [`Task`] whose oracle checks a candidate expression against every example using
/// `evaluator` (through [`Language::eval_arc`]).
///
/// The oracle returns `0` when every example's inputs evaluate, without error, to the expected
/// output, and negative infinity otherwise (presumably read as a log-likelihood). The examples
/// themselves become the task's observation.
pub fn task_by_evaluation<'a, E, V>(
    evaluator: E,
    tp: TypeSchema,
    examples: &'a [(Vec<V>, V)],
) -> Task<'a, Language, Expression, &'a [(Vec<V>, V)]>
where
    E: Evaluator<Space = V> + Send + 'a,
    V: PartialEq + Clone + Send + Sync + 'a,
{
    let shared = Arc::new(evaluator);
    let oracle = Box::new(move |dsl: &Language, expr: &Expression| {
        let solves_all = examples.iter().all(|&(ref inps, ref out)| {
            match dsl.eval_arc(expr, &shared, inps) {
                Ok(o) => o == *out,
                // An evaluation error counts as a failed example.
                Err(_) => false,
            }
        });
        if !solves_all {
            return f64::NEG_INFINITY;
        }
        0f64
    });
    Task {
        oracle,
        observation: examples,
        tp,
    }
}
/// Builds a [`Task`] whose oracle checks a candidate expression against every example using
/// the lazy `evaluator` (through [`Language::lazy_eval_arc`]).
///
/// The oracle returns `0` when every example's inputs evaluate, without error, to the expected
/// output, and negative infinity otherwise (presumably read as a log-likelihood). The examples
/// themselves become the task's observation.
pub fn task_by_lazy_evaluation<'a, E, V>(
    evaluator: E,
    tp: TypeSchema,
    examples: &'a [(Vec<V>, V)],
) -> Task<'a, Language, Expression, &'a [(Vec<V>, V)]>
where
    E: LazyEvaluator<Space = V> + Send + 'a,
    V: PartialEq + Clone + Send + Sync + 'a,
{
    let shared = Arc::new(evaluator);
    let oracle = Box::new(move |dsl: &Language, expr: &Expression| {
        let solves_all = examples.iter().all(|&(ref inps, ref out)| {
            match dsl.lazy_eval_arc(expr, &shared, inps) {
                Ok(o) => o == *out,
                // An evaluation error counts as a failed example.
                Err(_) => false,
            }
        });
        if !solves_all {
            return f64::NEG_INFINITY;
        }
        0f64
    });
    Task {
        oracle,
        observation: examples,
        tp,
    }
}
/// A persistent singly linked list whose tails are shared through `Rc`.
///
/// `prepend` builds a new head on top of an existing list without copying it, so many lists
/// can share a common suffix. Element `0` is the most recently prepended value.
#[derive(Debug, Clone)]
struct LinkedList<T>(Option<(T, Rc<LinkedList<T>>)>);

impl<T: Clone> LinkedList<T> {
    /// Returns a new list with `v` at the front and `lst` as its shared tail.
    fn prepend(lst: &Rc<LinkedList<T>>, v: T) -> Rc<LinkedList<T>> {
        Rc::new(LinkedList(Some((v, lst.clone()))))
    }

    /// Copies the elements, front to back, into a `VecDeque`.
    fn as_vecdeque(&self) -> VecDeque<T> {
        let mut out = VecDeque::new();
        let mut node: &LinkedList<T> = self;
        while let Some((ref v, ref next)) = node.0 {
            out.push_back(v.clone());
            node = &**next;
        }
        out
    }

    /// Counts the elements by walking the whole list (O(n)).
    fn len(&self) -> usize {
        let mut n = 0;
        let mut node: &LinkedList<T> = self;
        while let Some((_, ref next)) = node.0 {
            n += 1;
            node = &**next;
        }
        n
    }
}

impl<T> Default for LinkedList<T> {
    /// The empty list.
    fn default() -> Self {
        LinkedList(None)
    }
}

impl<T> Index<usize> for LinkedList<T> {
    type Output = T;

    /// Returns the `i`-th element from the front, walking the list (O(i)).
    ///
    /// # Panics
    /// Panics with "index out of bounds" if the list has `i` or fewer elements.
    fn index(&self, i: usize) -> &Self::Output {
        let mut node: &LinkedList<T> = self;
        let mut remaining = i;
        while let Some((ref v, ref next)) = node.0 {
            if remaining == 0 {
                return v;
            }
            remaining -= 1;
            node = &**next;
        }
        panic!("index out of bounds");
    }
}
/// Why type inference for an [`Expression`] failed.
#[derive(Clone, Debug)]
pub enum InferenceError {
    /// An `Expression::Primitive` index with no entry in `Language::primitives`.
    InvalidPrimitive(usize),
    /// An `Expression::Invented` index with no entry in `Language::invented`.
    InvalidInvention(usize),
    /// Unification of two types failed.
    Unify(UnificationError),
}

impl From<UnificationError> for InferenceError {
    /// Wraps a unification failure so `?` can propagate it from inference code.
    fn from(err: UnificationError) -> Self {
        InferenceError::Unify(err)
    }
}

impl fmt::Display for InferenceError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            &InferenceError::InvalidPrimitive(n) => write!(f, "primitive {} not in Language", n),
            &InferenceError::InvalidInvention(n) => write!(f, "invention {} not in Language", n),
            &InferenceError::Unify(ref err) => {
                write!(f, "could not unify to infer type: {}", err)
            }
        }
    }
}

impl Error for InferenceError {
    fn description(&self) -> &str {
        "could not infer type"
    }
}