mod enumerator;
mod parser;
pub use self::parser::ParseError;
use crossbeam_channel::bounded;
use itertools::Itertools;
use polytype::{Type, TypeScheme};
use rand::distributions::{Distribution, Uniform};
use rand::seq::SliceRandom;
use rand::Rng;
use rayon::prelude::*;
use rayon::spawn;
use std::cmp;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::{ECFrontier, Task, EC, GP};
/// A probabilistic context-free grammar.
///
/// Nonterminals are `polytype` [`Type`]s; each nonterminal maps to the list
/// of weighted [`Rule`]s that can expand it.
#[derive(Debug, Clone)]
pub struct Grammar {
    /// The nonterminal that enumeration and parsing start from.
    pub start: Type,
    /// Rules grouped by the nonterminal (type) they expand. A rule's index
    /// within its bucket is what [`AppliedRule`] stores.
    pub rules: HashMap<Type, Vec<Rule>>,
}
impl Grammar {
    /// Create a grammar from a start symbol and a set of rules.
    ///
    /// Rules are bucketed by the nonterminal they expand: the return type
    /// for arrow productions, otherwise the production type itself. Each
    /// rule's `logprob` field is treated here as an unnormalized weight —
    /// it is moved into log-space and then normalized per nonterminal.
    pub fn new(start: Type, all_rules: Vec<Rule>) -> Self {
        let mut rules = HashMap::new();
        for mut rule in all_rules {
            let nt = if let Some(ret) = rule.production.returns() {
                ret.clone()
            } else {
                rule.production.clone()
            };
            // Convert the caller-supplied weight into log-space.
            rule.logprob = rule.logprob.ln();
            rules.entry(nt).or_default().push(rule)
        }
        let mut g = Grammar { start, rules };
        g.normalize();
        g
    }
    /// Enumerate expressions from the start symbol, with their log-priors.
    ///
    /// See [`Grammar::enumerate_nonterminal`].
    pub fn enumerate(&self) -> Box<dyn Iterator<Item = (AppliedRule, f64)>> {
        self.enumerate_nonterminal(self.start.clone())
    }
    /// Enumerate expressions rooted at `nonterminal`, yielding each
    /// expression together with its log-prior.
    ///
    /// Enumeration runs on a background rayon task and streams through a
    /// bounded channel; dropping the returned iterator makes the sender's
    /// `send` fail, which terminates the background enumeration.
    pub fn enumerate_nonterminal(
        &self,
        nonterminal: Type,
    ) -> Box<dyn Iterator<Item = (AppliedRule, f64)>> {
        let (tx, rx) = bounded(1);
        let g = self.clone();
        spawn(move || {
            // `send` errs once the receiver is gone — that's the stop signal.
            // (`tx` is moved into this closure; no clone is needed.)
            let termination_condition = &mut |expr, logprior| tx.send((expr, logprior)).is_err();
            enumerator::new(&g, nonterminal, termination_condition)
        });
        Box::new(rx.into_iter())
    }
    /// Re-estimate rule probabilities from observed `sentences`.
    ///
    /// Every rule's counter starts at `params.pseudocounts` (additive
    /// smoothing); usages are tallied in parallel over the sentences, and
    /// the grammar is re-normalized from the counts.
    pub fn update_parameters(&mut self, params: &EstimationParams, sentences: &[AppliedRule]) {
        // Seed a counter per rule, keyed by nonterminal.
        let mut counts: HashMap<Type, Vec<AtomicUsize>> = HashMap::new();
        for (nt, rs) in &self.rules {
            counts.insert(
                nt.clone(),
                (0..rs.len())
                    .map(|_| AtomicUsize::new(params.pseudocounts as usize))
                    .collect(),
            );
        }
        let counts = Arc::new(counts);
        sentences
            .par_iter()
            .for_each(|ar| update_counts(ar, &counts));
        // All rayon workers have finished, so the Arc is uniquely owned.
        for (nt, cs) in Arc::try_unwrap(counts).unwrap() {
            // Hoist the bucket lookup out of the inner loop.
            let rules = self.rules.get_mut(&nt).unwrap();
            for (i, c) in cs.into_iter().enumerate() {
                rules[i].logprob = (c.into_inner() as f64).ln();
            }
        }
        self.normalize();
    }
    /// Evaluate an expression bottom-up: children are evaluated first, then
    /// the rule's name and evaluated arguments are passed to `evaluator`.
    ///
    /// Returns the first error produced by `evaluator`, if any.
    pub fn eval<V, E, F>(&self, ar: &AppliedRule, evaluator: &F) -> Result<V, E>
    where
        F: Fn(&str, &[V]) -> Result<V, E>,
    {
        let args =
            ar.2.iter()
                .map(|ar| self.eval(ar, evaluator))
                .collect::<Result<Vec<V>, E>>()?;
        evaluator(self.rules[&ar.0][ar.1].name, &args)
    }
    /// Sample an expression of type `tp` from the grammar's distribution.
    pub fn sample<R: Rng>(&self, tp: &Type, rng: &mut R) -> AppliedRule {
        enumerator::sample(self, tp, rng)
    }
    /// Log-likelihood of an expression: the chosen rule's logprob plus the
    /// likelihoods of all subtrees.
    pub fn likelihood(&self, ar: &AppliedRule) -> f64 {
        self.rules[&ar.0][ar.1].logprob + ar.2.iter().map(|ar| self.likelihood(ar)).sum::<f64>()
    }
    /// Parse a string starting from the grammar's start symbol.
    pub fn parse(&self, inp: &str) -> Result<AppliedRule, ParseError> {
        self.parse_nonterminal(inp, self.start.clone())
    }
    /// Parse a string as an expansion of the given nonterminal.
    pub fn parse_nonterminal(
        &self,
        inp: &str,
        nonterminal: Type,
    ) -> Result<AppliedRule, ParseError> {
        parser::parse(self, inp, nonterminal)
    }
    /// Render an expression as `name(arg1,arg2,…)` for arrow productions,
    /// or just `name` for terminals.
    pub fn display(&self, ar: &AppliedRule) -> String {
        let r = &self.rules[&ar.0][ar.1];
        if r.production.as_arrow().is_some() {
            let args = ar.2.iter().map(|ar| self.display(ar)).join(",");
            format!("{}({})", r.name, args)
        } else {
            r.name.to_string()
        }
    }
    /// Normalize each nonterminal's logprobs so they sum to one in
    /// probability space, using max-shifted log-sum-exp for numerical
    /// stability.
    ///
    /// NOTE(review): an empty rule list (or all-`-inf` logprobs) would make
    /// `z` NaN — assumed not to occur; verify upstream invariants.
    fn normalize(&mut self) {
        for rs in self.rules.values_mut() {
            let lp_largest = rs
                .iter()
                .fold(f64::NEG_INFINITY, |acc, r| acc.max(r.logprob));
            let z = lp_largest
                + rs.iter()
                    .map(|r| (r.logprob - lp_largest).exp())
                    .sum::<f64>()
                    .ln();
            for r in rs {
                r.logprob -= z;
            }
        }
    }
}
/// Parameters for grammar parameter estimation
/// (see [`Grammar::update_parameters`] and the EC `compress` impl).
pub struct EstimationParams {
    /// Additive smoothing count given to every rule before observed usages
    /// are tallied.
    pub pseudocounts: u64,
}
impl Default for EstimationParams {
fn default() -> Self {
EstimationParams { pseudocounts: 1 }
}
}
/// Exploration-compression support. Enumeration delegates to the
/// enumerator; "compression" here only re-estimates rule probabilities
/// from frontier usage counts — the grammar's rule set never changes.
impl<Observation: ?Sized> EC<Observation> for Grammar {
    type Expression = AppliedRule;
    type Params = EstimationParams;
    // Enumerate expressions for a monotype; PCFG enumeration has no notion
    // of quantified type variables, so polytypes are rejected outright.
    fn enumerate<F>(&self, tp: TypeScheme, termination_condition: F)
    where
        F: FnMut(Self::Expression, f64) -> bool,
    {
        match tp {
            TypeScheme::Monotype(tp) => enumerator::new(self, tp, termination_condition),
            _ => panic!("PCFGs can't handle polytypes"),
        }
    }
    fn compress(
        &self,
        params: &Self::Params,
        _tasks: &[impl Task<Observation, Representation = Self, Expression = Self::Expression>],
        frontiers: Vec<ECFrontier<Self::Expression>>,
    ) -> (Self, Vec<ECFrontier<Self::Expression>>) {
        // Seed a counter per rule (keyed by nonterminal) with the smoothing
        // pseudocount.
        let mut counts: HashMap<Type, Vec<AtomicUsize>> = HashMap::new();
        for (nt, rs) in &self.rules {
            counts.insert(
                nt.clone(),
                (0..rs.len())
                    .map(|_| AtomicUsize::new(params.pseudocounts as usize))
                    .collect(),
            );
        }
        // Tally rule usages across all frontier expressions in parallel;
        // atomic counters make the shared table safe to update.
        let counts = Arc::new(counts);
        frontiers
            .par_iter()
            .flat_map(|f| &f.0)
            .for_each(|(ar, _, _)| update_counts(ar, &counts));
        let mut g = self.clone();
        // All rayon workers are done, so the Arc is uniquely owned again.
        for (nt, cs) in Arc::try_unwrap(counts).unwrap() {
            for (i, c) in cs.into_iter().enumerate() {
                g.rules.get_mut(&nt).unwrap()[i].logprob = (c.into_inner() as f64).ln();
            }
        }
        g.normalize();
        (g, frontiers)
    }
}
/// Parameters for the genetic-programming operators (the [`GP`] impl).
///
/// The three `mutation_*` fields are relative weights; they are summed and
/// sampled from in `mutate`, so they need not sum to one.
pub struct GeneticParams {
    /// Depth bias for random node/subtree selection: a child's selection
    /// weight is scaled by this factor per level of depth.
    pub progeny_factor: f64,
    /// Weight of point mutation (swap a rule for another rule with the
    /// same production type).
    pub mutation_point: f64,
    /// Weight of subtree mutation (resample an entire subtree).
    pub mutation_subtree: f64,
    /// Weight of reproduction (return the program unchanged).
    pub mutation_reproduction: f64,
}
impl Default for GeneticParams {
fn default() -> GeneticParams {
GeneticParams {
progeny_factor: 2f64,
mutation_point: 0.45,
mutation_subtree: 0.45,
mutation_reproduction: 0.1,
}
}
}
/// Genetic-programming operators over PCFG parse trees. The observation
/// type is `()` — fitness comes from elsewhere (e.g. a [`Task`] oracle).
impl GP<()> for Grammar {
    type Expression = AppliedRule;
    type Params = GeneticParams;
    // Initial population: `pop_size` independent samples of the target type.
    fn genesis<R: Rng>(
        &self,
        _params: &Self::Params,
        rng: &mut R,
        pop_size: usize,
        tp: &TypeScheme,
    ) -> Vec<Self::Expression> {
        let tp = match *tp {
            TypeScheme::Monotype(ref tp) => tp,
            _ => panic!("PCFGs can't handle polytypes"),
        };
        (0..pop_size).map(|_| self.sample(tp, rng)).collect()
    }
    // Pick one of three operators with probability proportional to the
    // corresponding `mutation_*` weight.
    fn mutate<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        prog: &Self::Expression,
        _obs: &(),
    ) -> Vec<Self::Expression> {
        let tot = params.mutation_point + params.mutation_subtree + params.mutation_reproduction;
        match Uniform::from(0f64..tot).sample(rng) {
            // Point mutation: at a random node, replace the chosen rule
            // with a different rule that has the same production type (if
            // one exists), keeping the node's children intact.
            x if x < params.mutation_point => {
                vec![mutate_random_node(params, prog.clone(), rng, |ar, rng| {
                    let rule = &self.rules[&ar.0][ar.1];
                    let mut candidates: Vec<_> = self.rules[&ar.0]
                        .iter()
                        .enumerate()
                        .filter(|&(i, r)| r.production == rule.production && i != ar.1)
                        .map(|(i, _)| i)
                        .collect();
                    if candidates.is_empty() {
                        // No alternative rule — leave the node unchanged.
                        ar
                    } else {
                        // Shuffle and take the first: a uniform random pick.
                        candidates.shuffle(rng);
                        AppliedRule(ar.0, candidates[0], ar.2)
                    }
                })]
            }
            // Subtree mutation: resample an entire subtree of the same type.
            x if x < params.mutation_point + params.mutation_subtree => {
                vec![mutate_random_node(params, prog.clone(), rng, |ar, rng| {
                    self.sample(&ar.0, rng)
                })]
            }
            // Reproduction: pass the program through unchanged.
            _ => vec![prog.clone()],
        }
    }
    // Crossover is symmetric: each parent donates a subtree to the other.
    fn crossover<R: Rng>(
        &self,
        params: &Self::Params,
        rng: &mut R,
        parent1: &Self::Expression,
        parent2: &Self::Expression,
        _obs: &(),
    ) -> Vec<Self::Expression> {
        vec![
            crossover_random_node(params, parent1, parent2, rng),
            crossover_random_node(params, parent2, parent1, rng),
        ]
    }
}
/// A parse tree node: `(nonterminal type, index of the chosen rule within
/// that nonterminal's bucket in [`Grammar::rules`], subtrees for each
/// argument)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AppliedRule(pub Type, pub usize, pub Vec<AppliedRule>);
/// A single production rule of the grammar.
#[derive(Debug, Clone)]
pub struct Rule {
    /// The terminal symbol / operator name this rule produces.
    pub name: &'static str,
    /// The production's type; arrow types take arguments, others are
    /// terminals.
    pub production: Type,
    /// Log-probability of the rule within its nonterminal. Note that
    /// [`Grammar::new`] takes the raw value stored here, applies `ln`, and
    /// normalizes — so callers of [`Rule::new`] supply a plain weight.
    pub logprob: f64,
}
impl Rule {
pub fn new(name: &'static str, production: Type, logprob: f64) -> Self {
Rule {
name,
production,
logprob,
}
}
}
impl Ord for Rule {
    // Rules are ordered by log-probability. `partial_cmp` is `None` only
    // when a logprob is NaN, which the `expect` treats as a bug.
    //
    // NOTE(review): this ordering is inconsistent with the `PartialEq`/`Eq`
    // impls below (which compare name + production and ignore probability),
    // so `a == b` does not imply `a.cmp(&b) == Ordering::Equal`. Fine for
    // weight-sorting, but don't rely on Ord/Eq agreeing.
    fn cmp(&self, other: &Rule) -> cmp::Ordering {
        self.logprob
            .partial_cmp(&other.logprob)
            .expect("logprob for rule is not finite")
    }
}
impl PartialOrd for Rule {
    /// Delegates to the total order defined by [`Ord`] (comparison by
    /// log-probability), as the canonical `PartialOrd` pattern requires.
    fn partial_cmp(&self, other: &Rule) -> Option<cmp::Ordering> {
        let ordering = self.cmp(other);
        Some(ordering)
    }
}
impl PartialEq for Rule {
    /// Two rules are equal when they share a name and a production type;
    /// the probability is deliberately ignored.
    fn eq(&self, other: &Rule) -> bool {
        (self.name, &self.production) == (other.name, &other.production)
    }
}
impl Eq for Rule {}
/// Recursively tally one usage of every rule appearing in `ar`.
///
/// Takes a plain `&HashMap` rather than `&Arc<HashMap>` — the function only
/// reads through the map, and call sites passing `&Arc<…>` still work via
/// deref coercion. `Relaxed` ordering suffices because the counters are
/// independent and are only read back after all updating threads finish
/// (via `Arc::try_unwrap` at the call sites).
fn update_counts(ar: &AppliedRule, counts: &HashMap<Type, Vec<AtomicUsize>>) {
    counts[&ar.0][ar.1].fetch_add(1, Ordering::Relaxed);
    for child in &ar.2 {
        update_counts(child, counts);
    }
}
/// Build a [`Task`] that scores an [`AppliedRule`] by evaluating it with
/// `evaluator` and comparing the result against the expected `output`
/// (oracle yields `0.0` on a match, `-inf` otherwise).
pub fn task_by_evaluation<'a, V, E, F>(
    evaluator: &'a F,
    output: &'a V,
    tp: Type,
) -> impl Task<V, Representation = Grammar, Expression = AppliedRule> + 'a
where
    E: 'a,
    V: PartialEq + Sync + 'a,
    F: Fn(&str, &[V]) -> Result<V, E> + Sync + 'a,
{
    // PCFG tasks are always monomorphic.
    let tp = TypeScheme::Monotype(tp);
    PcfgTask {
        tp,
        evaluator,
        output,
    }
}
/// Internal [`Task`] implementation returned by [`task_by_evaluation`].
struct PcfgTask<'a, V, F> {
    // Callable that evaluates a rule name with evaluated arguments.
    evaluator: F,
    // Expected value the evaluated expression must equal.
    output: &'a V,
    // Always a `Monotype` — constructed that way in `task_by_evaluation`.
    tp: TypeScheme,
}
impl<V, E, F> Task<V> for PcfgTask<'_, V, F>
where
    V: PartialEq + Sync,
    F: Fn(&str, &[V]) -> Result<V, E> + Sync,
{
    type Representation = Grammar;
    type Expression = AppliedRule;
    /// Log-likelihood of the expression solving the task: `0.0` when
    /// evaluation succeeds and matches the expected output, negative
    /// infinity on a mismatch or evaluation error.
    fn oracle(&self, g: &Grammar, ar: &AppliedRule) -> f64 {
        match g.eval(ar, &self.evaluator) {
            Ok(ref o) if o == self.output => 0f64,
            _ => f64::NEG_INFINITY,
        }
    }
    fn tp(&self) -> &TypeScheme {
        &self.tp
    }
    fn observation(&self) -> &V {
        self.output
    }
}
use self::gp::{crossover_random_node, mutate_random_node};
mod gp {
    //! Helpers for the genetic-programming operators: weighted random node
    //! selection, subtree mutation, and subtree crossover.
    use super::{AppliedRule, GeneticParams};
    use polytype::Type;
    use rand::distributions::{Distribution, Uniform};
    use rand::Rng;
    /// Pick a random node of `ar` and replace the subtree rooted there with
    /// `mutation` applied to it.
    ///
    /// Selection is weighted via [`WeightedAppliedRule`]: every node
    /// contributes weight 1 at its own level, and each level's weights are
    /// scaled by `progeny_factor`, so a node's effective chance scales with
    /// `progeny_factor^depth` (for a factor > 1, deeper nodes are
    /// proportionally more likely).
    pub fn mutate_random_node<R, F>(
        params: &GeneticParams,
        ar: AppliedRule,
        rng: &mut R,
        mutation: F,
    ) -> AppliedRule
    where
        R: Rng,
        F: Fn(AppliedRule, &mut R) -> AppliedRule,
    {
        let mut arc = WeightedAppliedRule::new(params, ar);
        // Draw a point in [0, total weight); the walk below maps it to a
        // unique node of the tree.
        let mut selection = Uniform::from(0.0..arc.2).sample(rng);
        {
            let mut cur = &mut arc;
            // `selection <= 1.0` means the current node itself is chosen
            // (each node's own share is 1 in its local scale). Otherwise,
            // subtract the node's share, undo the `progeny_factor` scaling
            // applied in `WeightedAppliedRule::new`, and descend into the
            // child whose weight interval contains the remainder.
            while selection > 1.0 {
                selection -= 1.0;
                selection /= params.progeny_factor;
                let prev = cur;
                cur = prev
                    .3
                    .iter_mut()
                    .find(|arc| {
                        if selection > arc.2 {
                            selection -= arc.2;
                            false
                        } else {
                            true
                        }
                    })
                    .unwrap();
            }
            // Mutate the selected subtree and recompute its weights before
            // splicing it back in.
            let inp = AppliedRule::from(cur.clone());
            let mutated = mutation(inp, rng);
            *cur = WeightedAppliedRule::new(params, mutated);
        }
        AppliedRule::from(arc)
    }
    /// Crossover: replace a random node of `parent1` with a subtree of
    /// `parent2` that has the same nonterminal type, drawn with weight
    /// `progeny_factor^depth`. If `parent2` has no type-compatible subtree,
    /// the node is left unchanged.
    pub fn crossover_random_node<R: Rng>(
        params: &GeneticParams,
        parent1: &AppliedRule,
        parent2: &AppliedRule,
        rng: &mut R,
    ) -> AppliedRule {
        mutate_random_node(params, parent1.clone(), rng, |ar, rng| {
            let mut viables = Vec::new();
            fetch_subtrees_with_type(params, parent2, &ar.0, 0, &mut viables);
            if viables.is_empty() {
                ar
            } else {
                // Weighted draw over the candidate subtrees.
                let total = viables.iter().map(|&(weight, _)| weight).sum();
                let mut idx = Uniform::from(0f64..total).sample(rng);
                viables
                    .into_iter()
                    .find(|&(weight, _)| {
                        if idx > weight {
                            idx -= weight;
                            false
                        } else {
                            true
                        }
                    })
                    .unwrap()
                    .1
                    .clone()
            }
        })
    }
    /// Collect every subtree of `ar` whose nonterminal equals `tp`, paired
    /// with a depth-based weight of `progeny_factor^depth`.
    fn fetch_subtrees_with_type<'a>(
        params: &GeneticParams,
        ar: &'a AppliedRule,
        tp: &Type,
        depth: usize,
        viables: &mut Vec<(f64, &'a AppliedRule)>,
    ) {
        if &ar.0 == tp {
            viables.push((params.progeny_factor.powf(depth as f64), ar))
        }
        for ar in &ar.2 {
            fetch_subtrees_with_type(params, ar, tp, depth + 1, viables)
        }
    }
    /// An [`AppliedRule`] annotated with a selection weight (field `.2`):
    /// leaves weigh 1, interior nodes weigh
    /// `1 + progeny_factor * sum(child weights)`.
    #[derive(Debug, Clone)]
    struct WeightedAppliedRule(Type, usize, f64, Vec<WeightedAppliedRule>);
    impl WeightedAppliedRule {
        // Recursively compute the weight annotation bottom-up.
        fn new(params: &GeneticParams, ar: AppliedRule) -> Self {
            if ar.2.is_empty() {
                WeightedAppliedRule(ar.0, ar.1, 1.0, vec![])
            } else {
                let children: Vec<_> =
                    ar.2.into_iter()
                        .map(|ar| WeightedAppliedRule::new(params, ar))
                        .collect();
                let children_weight: f64 = children.iter().map(|arc| arc.2).sum();
                let weight = 1.0 + params.progeny_factor * children_weight;
                WeightedAppliedRule(ar.0, ar.1, weight, children)
            }
        }
    }
    impl From<WeightedAppliedRule> for AppliedRule {
        // Strip the weight annotations back off, recursively.
        fn from(arc: WeightedAppliedRule) -> Self {
            if arc.3.is_empty() {
                AppliedRule(arc.0, arc.1, vec![])
            } else {
                let children = arc.3.into_iter().map(Self::from).collect();
                AppliedRule(arc.0, arc.1, children)
            }
        }
    }
}