mod enumerator;
mod parser;
pub use self::parser::ParseError;
use std::cmp;
use std::collections::HashMap;
use std::f64;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use itertools::Itertools;
use polytype::{Type, TypeSchema};
use rand::Rng;
use rand::distributions::Range;
use rayon::prelude::*;
use {ECFrontier, Task, EC, GP};
/// A probabilistic context-free grammar: a start nonterminal plus, for each
/// nonterminal `Type`, the rules that can expand it.
#[derive(Debug, Clone)]
pub struct Grammar {
    // The start nonterminal used by `enumerate`/`parse`.
    pub start: Type,
    // Rules grouped by the nonterminal they produce. After construction,
    // `Rule::logprob` holds a log-probability normalized within each key
    // (see `Grammar::new`/`normalize`).
    pub rules: HashMap<Type, Vec<Rule>>,
}
impl Grammar {
    /// Build a grammar from a start nonterminal and a flat list of rules.
    ///
    /// Rules are grouped by the nonterminal they produce: the return type of
    /// an arrow production, or the production type itself otherwise. The
    /// incoming `logprob` fields are treated as raw (unnormalized)
    /// probabilities — they are converted with `ln` here and then normalized
    /// per nonterminal.
    pub fn new(start: Type, all_rules: Vec<Rule>) -> Self {
        let mut rules = HashMap::new();
        for mut rule in all_rules {
            // The nonterminal this rule expands.
            let nt = if let Some(ret) = rule.production.returns() {
                ret.clone()
            } else {
                rule.production.clone()
            };
            // Despite the field name, the input is a probability/weight at
            // this point; convert to log-space.
            rule.logprob = rule.logprob.ln();
            rules.entry(nt).or_insert_with(Vec::new).push(rule)
        }
        let mut g = Grammar { start, rules };
        g.normalize();
        g
    }
    /// Enumerate expressions from the start nonterminal, paired with their
    /// log-likelihoods.
    pub fn enumerate<'a>(&'a self) -> Box<Iterator<Item = (AppliedRule, f64)> + 'a> {
        self.enumerate_nonterminal(self.start.clone())
    }
    /// Enumerate expressions for an arbitrary nonterminal `tp`, paired with
    /// their log-likelihoods.
    pub fn enumerate_nonterminal<'a>(
        &'a self,
        tp: Type,
    ) -> Box<Iterator<Item = (AppliedRule, f64)> + 'a> {
        enumerator::new(self, tp)
    }
    /// Re-estimate rule probabilities from observed `sentences`.
    ///
    /// Every rule's counter starts at `params.pseudocounts` (additive
    /// smoothing), rule uses are tallied in parallel over the sentences, and
    /// the counts become new normalized log-probabilities.
    pub fn update_parameters(&mut self, params: &EstimationParams, sentences: &[AppliedRule]) {
        // One atomic counter per rule, keyed by nonterminal, seeded with the
        // pseudocount so unused rules keep nonzero probability.
        let mut counts: HashMap<Type, Vec<AtomicUsize>> = HashMap::new();
        for (nt, rs) in &self.rules {
            counts.insert(
                nt.clone(),
                (0..rs.len())
                    .map(|_| AtomicUsize::new(params.pseudocounts as usize))
                    .collect(),
            );
        }
        let counts = Arc::new(counts);
        sentences
            .par_iter()
            .for_each(|ar| update_counts(ar, &counts));
        // The parallel pass above has completed, so this Arc is uniquely
        // owned and `try_unwrap` cannot fail.
        for (nt, cs) in Arc::try_unwrap(counts).unwrap() {
            for (i, c) in cs.into_iter().enumerate() {
                self.rules.get_mut(&nt).unwrap()[i].logprob = (c.into_inner() as f64).ln();
            }
        }
        self.normalize();
    }
    /// Evaluate a parse tree bottom-up: each node's children are evaluated
    /// first, then `evaluator` is applied to the node's rule name and the
    /// child values.
    pub fn eval<V, F>(&self, ar: &AppliedRule, evaluator: &F) -> V
    where
        F: Fn(&str, &[V]) -> V,
    {
        let args: Vec<V> = ar.2.iter().map(|ar| self.eval(ar, evaluator)).collect();
        evaluator(self.rules[&ar.0][ar.1].name, &args)
    }
    /// Sample an expression of nonterminal `tp` according to the grammar's
    /// rule probabilities.
    pub fn sample<R: Rng>(&self, tp: &Type, rng: &mut R) -> AppliedRule {
        enumerator::sample(self, tp, rng)
    }
    /// Log-likelihood of a parse tree: the sum of the log-probabilities of
    /// every rule used in it.
    pub fn likelihood(&self, ar: &AppliedRule) -> f64 {
        self.rules[&ar.0][ar.1].logprob + ar.2.iter().map(|ar| self.likelihood(ar)).sum::<f64>()
    }
    /// Parse `inp` starting from the grammar's start nonterminal.
    pub fn parse(&self, inp: &str) -> Result<AppliedRule, ParseError> {
        self.parse_nonterminal(inp, self.start.clone())
    }
    /// Parse `inp` starting from an arbitrary `nonterminal`.
    pub fn parse_nonterminal(
        &self,
        inp: &str,
        nonterminal: Type,
    ) -> Result<AppliedRule, ParseError> {
        parser::parse(self, inp, nonterminal)
    }
    /// Render a parse tree as `name(arg,arg,…)` for arrow productions, or
    /// bare `name` for terminals.
    pub fn display(&self, ar: &AppliedRule) -> String {
        let r = &self.rules[&ar.0][ar.1];
        if r.production.as_arrow().is_some() {
            let args = ar.2.iter().map(|ar| self.display(ar)).join(",");
            format!("{}({})", r.name, args)
        } else {
            format!("{}", r.name)
        }
    }
    /// Normalize each nonterminal's rules so their probabilities sum to one,
    /// working in log-space via the log-sum-exp trick (subtracting the max
    /// before exponentiating for numerical stability).
    fn normalize(&mut self) {
        for rs in self.rules.values_mut() {
            let lp_largest = rs.iter()
                .fold(f64::NEG_INFINITY, |acc, r| acc.max(r.logprob));
            // z = log(sum(exp(logprob))), computed stably.
            let z = lp_largest
                + rs.iter()
                    .map(|r| (r.logprob - lp_largest).exp())
                    .sum::<f64>()
                    .ln();
            for r in rs {
                r.logprob -= z;
            }
        }
    }
}
/// Parameters for PCFG parameter estimation (`Grammar::update_parameters`
/// and the `EC` compression step).
pub struct EstimationParams {
    // Additive-smoothing count given to every rule before observed uses are
    // tallied, so unused rules keep nonzero probability.
    pub pseudocounts: u64,
}
impl Default for EstimationParams {
fn default() -> Self {
EstimationParams { pseudocounts: 1 }
}
}
impl EC for Grammar {
    type Expression = AppliedRule;
    type Params = EstimationParams;
    /// Enumerate expressions for the given type schema.
    ///
    /// # Panics
    ///
    /// Panics if `tp` is not a monotype — PCFG nonterminals are monotypes.
    fn enumerate<'a>(
        &'a self,
        tp: TypeSchema,
    ) -> Box<Iterator<Item = (Self::Expression, f64)> + 'a> {
        match tp {
            TypeSchema::Monotype(tp) => self.enumerate_nonterminal(tp),
            _ => panic!("PCFGs can't handle polytypes"),
        }
    }
    /// For a PCFG, "compression" is parameter re-estimation: rule uses are
    /// counted (in parallel) over every expression in every frontier, seeded
    /// with `params.pseudocounts`, and turned into new normalized
    /// log-probabilities. Grammar structure and the frontiers themselves are
    /// returned unchanged.
    fn compress<O: Sync>(
        &self,
        params: &Self::Params,
        _tasks: &[Task<Self, Self::Expression, O>],
        frontiers: Vec<ECFrontier<Self>>,
    ) -> (Self, Vec<ECFrontier<Self>>) {
        // One atomic counter per rule, keyed by nonterminal, initialized
        // with the smoothing pseudocount.
        let mut counts: HashMap<Type, Vec<AtomicUsize>> = HashMap::new();
        for (nt, rs) in &self.rules {
            counts.insert(
                nt.clone(),
                (0..rs.len())
                    .map(|_| AtomicUsize::new(params.pseudocounts as usize))
                    .collect(),
            );
        }
        let counts = Arc::new(counts);
        frontiers
            .par_iter()
            .flat_map(|f| &f.0)
            .for_each(|&(ref ar, _, _)| update_counts(ar, &counts));
        let mut g = self.clone();
        // The parallel pass has completed, so the Arc is uniquely owned and
        // `try_unwrap` cannot fail.
        for (nt, cs) in Arc::try_unwrap(counts).unwrap() {
            for (i, c) in cs.into_iter().enumerate() {
                g.rules.get_mut(&nt).unwrap()[i].logprob = (c.into_inner() as f64).ln();
            }
        }
        g.normalize();
        (g, frontiers)
    }
}
/// Parameters for the genetic-programming (`GP`) implementation.
pub struct GeneticParams {
    // NOTE(review): not read anywhere in this file — `crossover` discards
    // its params. Confirm whether depth-limited crossover is still planned.
    pub max_crossover_depth: u32,
    // The three fields below are relative (unnormalized) weights for the
    // mutation operators chosen in `GP::mutate`.
    pub mutation_point: f64,
    pub mutation_subtree: f64,
    pub mutation_reproduction: f64,
}
impl GP for Grammar {
    type Expression = AppliedRule;
    type Params = GeneticParams;
    /// Seed a population by sampling `pop_size` expressions from the grammar.
    ///
    /// # Panics
    ///
    /// Panics if `tp` is not a monotype.
    fn genesis<R: Rng>(
        &self,
        _params: &Self::Params,
        rng: &mut R,
        pop_size: usize,
        tp: &TypeSchema,
    ) -> Vec<Self::Expression> {
        let tp = match *tp {
            TypeSchema::Monotype(ref tp) => tp,
            _ => panic!("PCFGs can't handle polytypes"),
        };
        (0..pop_size).map(|_| self.sample(tp, rng)).collect()
    }
    /// Mutate `prog` with one of three operators, chosen with probability
    /// proportional to the corresponding `GeneticParams` weight: point
    /// mutation, subtree regeneration, or reproduction (no change).
    fn mutate<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        prog: &Self::Expression,
    ) -> Self::Expression {
        let tot = params.mutation_point + params.mutation_subtree + params.mutation_reproduction;
        match Range::sample_single(0f64, tot, rng) {
            // Point mutation: swap the rule at a node for a different rule
            // with an identical production type, keeping the arguments.
            x if x < params.mutation_point => mutate_random_node(prog.clone(), rng, |ar, rng| {
                let rule = &self.rules[&ar.0][ar.1];
                let mut candidates: Vec<_> = self.rules[&ar.0]
                    .iter()
                    .enumerate()
                    .filter(|&(i, r)| r.production == rule.production && i != ar.1)
                    .map(|(i, _)| i)
                    .collect();
                if candidates.is_empty() {
                    // No interchangeable rule exists; leave the node as-is.
                    ar
                } else {
                    rng.shuffle(&mut candidates);
                    AppliedRule(ar.0, candidates[0], ar.2)
                }
            }),
            // Subtree mutation: resample a whole subtree of the same type.
            x if x < params.mutation_point + params.mutation_subtree => {
                mutate_random_node(prog.clone(), rng, |ar, rng| self.sample(&ar.0, rng))
            }
            // Reproduction: return the program unchanged.
            _ => prog.clone(),
        }
    }
    /// NOTE(review): crossover currently ignores `params` (including
    /// `max_crossover_depth`) and just returns clones of both parents —
    /// confirm whether genuine subtree crossover is intended here.
    fn crossover<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        parent1: &Self::Expression,
        parent2: &Self::Expression,
    ) -> Vec<Self::Expression> {
        let _ = (rng, params);
        vec![parent1.clone(), parent2.clone()]
    }
}
/// A parse tree: the nonterminal type being expanded, the index of the rule
/// applied (into `Grammar::rules[&type]`), and the subtrees for the rule's
/// arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AppliedRule(pub Type, pub usize, pub Vec<AppliedRule>);
/// A production rule of the grammar.
#[derive(Debug, Clone)]
pub struct Rule {
    // Name shown by `display`, matched by the parser, and passed to
    // evaluators.
    pub name: &'static str,
    // The rule's type; an arrow type's arguments are the rule's children.
    pub production: Type,
    // After `Grammar::new`/`normalize` this is a normalized
    // log-probability; `Rule::new` callers supply a raw weight (see
    // `Grammar::new`, which applies `ln`).
    pub logprob: f64,
}
impl Rule {
pub fn new(name: &'static str, production: Type, logprob: f64) -> Self {
Rule {
name,
production,
logprob,
}
}
}
impl Ord for Rule {
    /// Total order on rules by `logprob` (via `PartialOrd` below).
    ///
    /// # Panics
    ///
    /// Panics if either `logprob` is NaN (the only case where
    /// `partial_cmp` on `f64` returns `None`).
    ///
    /// NOTE(review): this orders by `logprob` while `PartialEq` compares
    /// `name`/`production`, so `Ord` is not consistent with `Eq` as the
    /// `Ord` contract expects — confirm no caller (e.g. sorting/dedup)
    /// relies on that consistency.
    fn cmp(&self, other: &Rule) -> cmp::Ordering {
        self.partial_cmp(other)
            .expect("logprob for rule is not finite")
    }
}
impl PartialOrd for Rule {
    /// Compare rules by `logprob` only; `None` exactly when a `logprob`
    /// is NaN.
    fn partial_cmp(&self, other: &Rule) -> Option<cmp::Ordering> {
        self.logprob.partial_cmp(&other.logprob)
    }
}
impl PartialEq for Rule {
    /// Two rules are equal when they share a name and a production type;
    /// `logprob` is deliberately not part of rule identity.
    fn eq(&self, other: &Rule) -> bool {
        (self.name, &self.production) == (other.name, &other.production)
    }
}
// Marker: rule equality (by name and production) is a total equivalence.
impl Eq for Rule {}
/// Increment, for every node of the parse tree `ar`, the counter for the
/// rule used at that node. Counters are keyed by nonterminal and indexed by
/// rule number, mirroring `Grammar::rules`.
///
/// `Ordering::Relaxed` suffices: callers only read the counters after the
/// parallel counting pass has fully completed.
///
/// Takes `&HashMap` rather than `&Arc<HashMap>` (the `&Arc<T>` parameter
/// anti-pattern); existing call sites passing `&Arc<…>` still work through
/// deref coercion. The needless `'a` lifetime and `move` are also gone.
fn update_counts(ar: &AppliedRule, counts: &HashMap<Type, Vec<AtomicUsize>>) {
    counts[&ar.0][ar.1].fetch_add(1, Ordering::Relaxed);
    for arg in &ar.2 {
        update_counts(arg, counts);
    }
}
/// Build a [`Task`] whose oracle evaluates a candidate expression with
/// `evaluator` and scores it `0.0` when the result equals `output`, and
/// negative infinity otherwise.
pub fn task_by_evaluation<'a, V, F>(
    evaluator: &'a F,
    output: &'a V,
    tp: Type,
) -> Task<'a, Grammar, AppliedRule, &'a V>
where
    V: PartialEq + Clone + Sync + Debug + 'a,
    F: Fn(&str, &[V]) -> V + Sync + 'a,
{
    let oracle = Box::new(move |g: &Grammar, ar: &AppliedRule| {
        // All-or-nothing scoring: exact match or log-probability of zero.
        let actual = g.eval(ar, evaluator);
        if *output == actual {
            0f64
        } else {
            f64::NEG_INFINITY
        }
    });
    Task {
        oracle,
        observation: output,
        tp: TypeSchema::Monotype(tp),
    }
}
fn mutate_random_node<R, F>(ar: AppliedRule, rng: &mut R, mutation: F) -> AppliedRule
where
R: Rng,
F: Fn(AppliedRule, &mut R) -> AppliedRule,
{
mutation(ar, rng)
}