mod enumerator;
mod eval;
mod compression;
mod parser;
pub use self::compression::CompressionParams;
pub use self::eval::LispEvaluator;
pub use self::parser::ParseError;
use std::collections::{HashMap, VecDeque};
use std::f64;
use std::fmt::{self, Debug};
use std::rc::Rc;
use polytype::{Context, Type, UnificationError};
use {ECFrontier, Task, EC};
#[derive(Debug, Clone)]
pub struct Language {
pub primitives: Vec<(String, Type, f64)>,
pub invented: Vec<(Expression, Type, f64)>,
pub variable_logprob: f64,
}
impl Language {
pub fn uniform(primitives: Vec<(&str, Type)>) -> Self {
let primitives = primitives
.into_iter()
.map(|(s, t)| (String::from(s), t, 0f64))
.collect();
Language {
primitives,
invented: vec![],
variable_logprob: 0f64,
}
}
pub fn infer(&self, expr: &Expression) -> Result<Type, InferenceError> {
let mut ctx = Context::default();
let env = VecDeque::new();
let mut indices = HashMap::new();
expr.infer(self, &mut ctx, &env, &mut indices)
}
pub fn enumerate<'a>(&'a self, tp: Type) -> Box<Iterator<Item = (Expression, f64)> + 'a> {
enumerator::new(self, tp)
}
pub fn compress<O: Sync>(
&self,
params: &CompressionParams,
tasks: &[Task<Language, Expression, O>],
frontiers: Vec<ECFrontier<Self>>,
) -> (Self, Vec<ECFrontier<Self>>) {
compression::induce(self, params, tasks, frontiers)
}
pub fn eval<V, F>(&self, expr: &Expression, evaluator: &F, inps: &[V]) -> Option<V>
where
F: Fn(&str, &[V]) -> V,
V: Clone + PartialEq + Debug,
{
eval::simple_eval(self, expr, evaluator, inps)
}
pub fn likelihood(&self, request: &Type, expr: &Expression) -> f64 {
enumerator::likelihood(self, request, expr)
}
pub fn invent(
&mut self,
expr: Expression,
log_probability: f64,
) -> Result<usize, InferenceError> {
let tp = self.infer(&expr)?;
self.invented.push((expr, tp, log_probability));
Ok(self.invented.len() - 1)
}
pub fn strip_invented(&self, expr: &Expression) -> Expression {
expr.strip_invented(&self.invented)
}
pub fn parse(&self, inp: &str) -> Result<Expression, ParseError> {
parser::parse(self, inp)
}
pub fn display(&self, expr: &Expression) -> String {
expr.show(self, false)
}
pub fn lispify(&self, expr: &Expression, conversions: &HashMap<String, String>) -> String {
expr.as_lisp(self, false, conversions, 0)
}
fn candidates(
&self,
request: &Type,
ctx: &Context,
env: &VecDeque<Type>,
) -> Vec<(f64, Expression, Type, Context)> {
let prims = self.primitives
.iter()
.enumerate()
.map(|(i, &(_, ref tp, p))| (p, tp, true, Expression::Primitive(i)));
let invented = self.invented
.iter()
.enumerate()
.map(|(i, &(_, ref tp, p))| (p, tp, true, Expression::Invented(i)));
let indices = env.iter()
.enumerate()
.map(|(i, tp)| (self.variable_logprob, tp, false, Expression::Index(i)));
let mut cands: Vec<_> = prims
.chain(invented)
.chain(indices)
.filter_map(|(p, tp, instantiate, expr)| {
let mut ctx = ctx.clone();
let itp;
let tp = if instantiate {
itp = tp.instantiate_indep(&mut ctx);
&itp
} else {
tp
};
let ret = if let Type::Arrow(ref arrow) = *tp {
arrow.returns()
} else {
&tp
};
ctx.unify(ret, request).ok().map(|_| {
let tp = tp.apply(&ctx);
(p, expr, tp, ctx)
})
})
.collect();
let log_n_indexed = (cands
.iter()
.filter(|&&(_, ref expr, _, _)| match *expr {
Expression::Index(_) => true,
_ => false,
})
.count() as f64)
.ln();
for mut c in &mut cands {
if let Expression::Index(_) = c.1 {
c.0 -= log_n_indexed
}
}
let p_largest = cands
.iter()
.map(|&(p, _, _, _)| p)
.fold(f64::NEG_INFINITY, f64::max);
let z = p_largest
+ cands
.iter()
.map(|&(p, _, _, _)| (p - p_largest).exp())
.sum::<f64>()
.ln();
for mut c in &mut cands {
c.0 -= z;
}
cands
}
}
impl EC for Language {
type Expression = Expression;
type Params = CompressionParams;
fn enumerate<'a>(&'a self, tp: Type) -> Box<Iterator<Item = (Expression, f64)> + 'a> {
self.enumerate(tp)
}
fn compress<O: Sync>(
&self,
params: &Self::Params,
tasks: &[Task<Self, Self::Expression, O>],
frontiers: Vec<ECFrontier<Self>>,
) -> (Self, Vec<ECFrontier<Self>>) {
self.compress(params, tasks, frontiers)
}
}
/// A lambda-calculus expression over a [`Language`], with variables written as
/// de Bruijn indices.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum Expression {
    /// The primitive at this position in `Language::primitives`.
    Primitive(usize),
    /// Application of the first expression (function) to the second (argument).
    Application(Box<Expression>, Box<Expression>),
    /// A lambda abstraction; inside the body, `Index(0)` refers to this binder.
    Abstraction(Box<Expression>),
    /// A de Bruijn index: `0` is the innermost enclosing abstraction.
    Index(usize),
    /// The invention at this position in `Language::invented`.
    Invented(usize),
}
impl Expression {
    /// Infers this expression's type, recording unifications in `ctx`.
    ///
    /// `env` holds the types of the enclosing binders, front = innermost (index 0).
    /// Indices past the end of `env` are free; `indices` maps each free variable
    /// (by its index relative to the end of `env`) to the type variable assigned to
    /// it, so repeated uses share one type.
    fn infer(
        &self,
        dsl: &Language,
        mut ctx: &mut Context,
        env: &VecDeque<Type>,
        indices: &mut HashMap<usize, Type>,
    ) -> Result<Type, InferenceError> {
        match *self {
            Expression::Primitive(num) => if let Some(prim) = dsl.primitives.get(num) {
                Ok(prim.1.instantiate_indep(ctx))
            } else {
                Err(InferenceError::BadExpression(format!(
                    "primitive does not exist: {}",
                    num
                )))
            },
            Expression::Application(ref f, ref x) => {
                let f_tp = f.infer(dsl, &mut ctx, env, indices)?;
                let x_tp = x.infer(dsl, &mut ctx, env, indices)?;
                let ret_tp = ctx.new_variable();
                ctx.unify(&f_tp, &arrow![x_tp, ret_tp.clone()])?;
                Ok(ret_tp.apply(ctx))
            }
            Expression::Abstraction(ref body) => {
                let arg_tp = ctx.new_variable();
                let mut env = env.clone();
                // The new binder becomes index 0 for the body.
                env.push_front(arg_tp.clone());
                let ret_tp = body.infer(dsl, &mut ctx, &env, indices)?;
                Ok(arrow![arg_tp, ret_tp].apply(ctx))
            }
            Expression::Index(i) => {
                if i < env.len() {
                    Ok(env[i].apply(ctx))
                } else {
                    // Free variable: reuse (or create) its type variable.
                    Ok(indices
                        .entry(i - env.len())
                        .or_insert_with(|| ctx.new_variable())
                        .apply(ctx))
                }
            }
            Expression::Invented(num) => if let Some(inv) = dsl.invented.get(num) {
                Ok(inv.1.instantiate_indep(ctx))
            } else {
                Err(InferenceError::BadExpression(format!(
                    "invention does not exist: {}",
                    num
                )))
            },
        }
    }

    /// Returns a copy with every `Invented` node replaced, recursively, by its
    /// definition from `invented`.
    ///
    /// # Panics
    /// If an `Invented` index is out of range for `invented`.
    fn strip_invented(&self, invented: &[(Expression, Type, f64)]) -> Expression {
        match *self {
            Expression::Application(ref f, ref x) => Expression::Application(
                Box::new(f.strip_invented(invented)),
                Box::new(x.strip_invented(invented)),
            ),
            Expression::Abstraction(ref body) => {
                Expression::Abstraction(Box::new(body.strip_invented(invented)))
            }
            Expression::Invented(num) => invented[num].0.strip_invented(invented),
            _ => self.clone(),
        }
    }

    /// Adds `offset` to every free de Bruijn index in this expression, in place.
    ///
    /// Returns `false` when a negative `offset` would move a free index onto a
    /// binder inside the expression (or below zero) — i.e. the expression uses a
    /// variable the shift removes. On failure the expression may be left
    /// partially shifted.
    fn shift(&mut self, offset: i64) -> bool {
        self.shift_internal(offset, 0)
    }

    /// Worker for `shift`; `depth` counts the abstractions entered so far, so
    /// indices below `depth` are bound inside the expression and left untouched.
    fn shift_internal(&mut self, offset: i64, depth: usize) -> bool {
        match *self {
            Expression::Index(ref mut i) => {
                if *i < depth {
                    true
                } else if offset >= 0 {
                    *i += offset as usize;
                    true
                } else {
                    // A free variable must stay free: if its shifted index fell
                    // below `depth` it would be captured by a binder inside the
                    // expression, so that counts as failure too.
                    match i.checked_sub((-offset) as usize) {
                        Some(ni) if ni >= depth => {
                            *i = ni;
                            true
                        }
                        _ => false,
                    }
                }
            }
            Expression::Application(ref mut f, ref mut x) => {
                f.shift_internal(offset, depth) && x.shift_internal(offset, depth)
            }
            Expression::Abstraction(ref mut body) => body.shift_internal(offset, depth + 1),
            _ => true,
        }
    }

    /// Renders as Lisp-style source. `is_function` is true when this expression is
    /// in function position of an application (so no extra parentheses are added),
    /// and `depth` is the number of enclosing abstractions. Primitive names are
    /// mapped through `conversions` and inventions are inlined.
    fn as_lisp(
        &self,
        dsl: &Language,
        is_function: bool,
        conversions: &HashMap<String, String>,
        depth: usize,
    ) -> String {
        match *self {
            Expression::Primitive(num) => {
                let name = &dsl.primitives[num].0;
                conversions.get(name).unwrap_or(name).to_string()
            }
            Expression::Application(ref f, ref x) => {
                let f_lisp = f.as_lisp(dsl, true, conversions, depth);
                let x_lisp = x.as_lisp(dsl, false, conversions, depth);
                if is_function {
                    format!("{} {}", f_lisp, x_lisp)
                } else {
                    format!("({} {})", f_lisp, x_lisp)
                }
            }
            Expression::Abstraction(ref body) => {
                // The binder at depth `d` is named with letter `97 + d` ('a', 'b', ...).
                // NOTE(review): past depth 25 this leaves the alphabet, and the u8
                // addition overflows for very deep nesting.
                let var = (97 + depth as u8) as char;
                format!(
                    "(λ ({}) {})",
                    var,
                    body.as_lisp(dsl, false, conversions, depth + 1)
                )
            }
            Expression::Index(i) => {
                // Index `i` at `depth` names the binder introduced at depth
                // `depth - 1 - i`, i.e. letter `96 + depth - i`.
                // NOTE(review): free indices have no letter — `i == depth` yields '`'
                // and `i > depth` underflows (panics in debug builds).
                let var = (96 + (depth - i) as u8) as char;
                format!("{}", var)
            }
            Expression::Invented(num) => {
                dsl.invented[num]
                    .0
                    .as_lisp(dsl, false, conversions, depth)
            }
        }
    }

    /// Renders with `λ` abstractions, `$i` variables, and `#` in front of each
    /// inlined invention body. `is_function` suppresses the outer parentheses of an
    /// application in function position.
    fn show(&self, dsl: &Language, is_function: bool) -> String {
        match *self {
            Expression::Primitive(num) => dsl.primitives[num].0.clone(),
            Expression::Application(ref f, ref x) => if is_function {
                format!("{} {}", f.show(dsl, true), x.show(dsl, false))
            } else {
                format!("({} {})", f.show(dsl, true), x.show(dsl, false))
            },
            Expression::Abstraction(ref body) => format!("(λ {})", body.show(dsl, false)),
            Expression::Index(i) => format!("${}", i),
            Expression::Invented(num) => {
                format!("#{}", dsl.invented[num].0.show(dsl, false))
            }
        }
    }
}
pub fn task_by_simple_evaluation<'a, V, F>(
simple_evaluator: &'a F,
tp: Type,
examples: &'a [(Vec<V>, V)],
) -> Task<'a, Language, Expression, &'a [(Vec<V>, V)]>
where
V: PartialEq + Clone + Sync + Debug + 'a,
F: Fn(&str, &[V]) -> V + Sync + 'a,
{
let oracle = Box::new(move |dsl: &Language, expr: &Expression| {
let success = examples.iter().all(|&(ref inps, ref out)| {
if let Some(ref o) = dsl.eval(expr, simple_evaluator, inps) {
o == out
} else {
false
}
});
if success {
0f64
} else {
f64::NEG_INFINITY
}
});
Task {
oracle,
observation: examples,
tp,
}
}
/// A singly-linked list whose tails are shared through `Rc`: prepending wraps the
/// existing list instead of copying it.
#[derive(Debug, Clone)]
struct LinkedList<T: Clone>(Option<(T, Rc<LinkedList<T>>)>);
impl<T: Clone> LinkedList<T> {
    /// Returns a new list with `v` at the front and `lst` (shared, not copied) as
    /// the rest.
    fn prepend(lst: &Rc<LinkedList<T>>, v: T) -> Rc<LinkedList<T>> {
        let tail = Rc::clone(lst);
        Rc::new(LinkedList(Some((v, tail))))
    }

    /// Clones the elements, front to back, into a `VecDeque`.
    fn as_vecdeque(&self) -> VecDeque<T> {
        let mut items = VecDeque::new();
        let mut node: &LinkedList<T> = self;
        while let Some((ref head, ref rest)) = node.0 {
            items.push_back(head.clone());
            node = &**rest;
        }
        items
    }

    /// Counts the elements by walking the whole list.
    fn len(&self) -> usize {
        let mut count = 0;
        let mut node: &LinkedList<T> = self;
        while let Some((_, ref rest)) = node.0 {
            count += 1;
            node = &**rest;
        }
        count
    }
}
impl<T: Clone> Default for LinkedList<T> {
    /// The empty list.
    fn default() -> Self {
        LinkedList(None)
    }
}
impl<T: Clone> ::std::ops::Index<usize> for LinkedList<T> {
    type Output = T;

    /// Returns the element at position `i` (0 = front), walking from the front.
    ///
    /// # Panics
    /// With "index out of bounds" when the list has `i` or fewer elements.
    fn index(&self, i: usize) -> &Self::Output {
        let mut node: &LinkedList<T> = self;
        let mut position = 0;
        while let Some((ref head, ref rest)) = node.0 {
            if position == i {
                return head;
            }
            position += 1;
            node = &**rest;
        }
        panic!("index out of bounds");
    }
}
/// Why a type could not be inferred for an [`Expression`].
#[derive(Debug, Clone)]
pub enum InferenceError {
    /// The expression is malformed (e.g. it names a primitive or invention that
    /// does not exist); the message describes the problem.
    BadExpression(String),
    /// Unification of two types failed.
    Unify(UnificationError),
}
impl From<UnificationError> for InferenceError {
    /// Wraps a unification failure, letting `?` propagate it during inference.
    fn from(cause: UnificationError) -> InferenceError {
        InferenceError::Unify(cause)
    }
}
impl fmt::Display for InferenceError {
    fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            InferenceError::Unify(ref cause) => {
                write!(out, "could not unify to infer type: {}", cause)
            }
            InferenceError::BadExpression(ref detail) => {
                write!(out, "invalid expression: '{}'", detail)
            }
        }
    }
}
impl ::std::error::Error for InferenceError {
    fn description(&self) -> &str {
        "could not infer type"
    }
}