use itertools::{repeat_n, Itertools};
use polytype::{Context as TypeContext, Type, TypeScheme, Variable as TypeVar};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::iter;
use std::sync::{Arc, RwLock};
use term_rewriting::{
Atom, Context, Operator, Place, Rule, RuleContext, Signature, Term, Variable, TRS as UntypedTRS,
};
use super::{SampleError, TypeError, TRS};
use crate::utils::{logsumexp, weighted_permutation};
use crate::GP;
/// Parameters controlling the genetic-programming search over TRSs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneticParams {
    /// Number of offspring produced by each crossover.
    pub n_crosses: usize,
    /// Maximum size budget when sampling terms/rules.
    pub max_sample_size: usize,
    /// Probability of adding (rather than deleting) a rule during mutation.
    pub p_add: f64,
    /// Probability of keeping a non-background rule during crossover.
    pub p_keep: f64,
    /// Relative weights for (variable, constant, operator) atoms when sampling.
    pub atom_weights: (f64, f64, f64),
}
/// A cheaply-cloneable, shareable handle to a [`Lex`]: the signature,
/// operator/variable type schemes, background rules, and rule templates
/// used for typed sampling and inference.
///
/// Clones share the same underlying `Lex` behind an `Arc<RwLock<_>>`;
/// equality (see `PartialEq`) is pointer identity on that allocation.
#[derive(Clone)]
pub struct Lexicon(pub(crate) Arc<RwLock<Lex>>);
impl Lexicon {
/// Construct a `Lexicon` from operator descriptions, registering each
/// operator in a fresh signature. Starts with no variables, no background
/// rules, and no templates.
pub fn new(
    operators: Vec<(u32, Option<String>, TypeScheme)>,
    deterministic: bool,
    ctx: TypeContext,
) -> Lexicon {
    let mut signature = Signature::default();
    // Register each operator with the signature, keeping its scheme.
    let ops: Vec<TypeScheme> = operators
        .into_iter()
        .map(|(id, name, scheme)| {
            signature.new_op(id, name);
            scheme
        })
        .collect();
    Lexicon(Arc::new(RwLock::new(Lex {
        ops,
        vars: Vec::new(),
        signature,
        background: Vec::new(),
        templates: Vec::new(),
        deterministic,
        ctx,
    })))
}
/// Build a `Lexicon` directly from an existing signature plus its
/// associated schemes, background rules, and templates.
pub fn from_signature(
    signature: Signature,
    ops: Vec<TypeScheme>,
    vars: Vec<TypeScheme>,
    background: Vec<Rule>,
    templates: Vec<RuleContext>,
    deterministic: bool,
    ctx: TypeContext,
) -> Lexicon {
    let lex = Lex {
        signature,
        ops,
        vars,
        background,
        templates,
        deterministic,
        ctx,
    };
    Lexicon(Arc::new(RwLock::new(lex)))
}
/// Look up an operator in the signature by name and arity, if one exists.
pub fn has_op(&self, name: Option<&str>, arity: u32) -> Option<Operator> {
    let lex = self.0.read().expect("poisoned lexicon");
    for op in lex.signature.operators() {
        if op.arity() == arity && op.name().as_deref() == name {
            return Some(op);
        }
    }
    None
}
/// Type variables that occur free in the lexicon's schemes.
pub fn free_vars(&self) -> Vec<TypeVar> {
    let lex = self.0.read().expect("poisoned lexicon");
    lex.free_vars()
}
/// A snapshot of the lexicon's current type context.
pub fn context(&self) -> TypeContext {
    let lex = self.0.read().expect("poisoned lexicon");
    lex.ctx.clone()
}
/// Infer the type scheme of a `Context` under `ctx`.
///
/// Takes a read lock rather than the write lock the original acquired:
/// the underlying `Lex::infer_context` only needs `&self`, so shared
/// access suffices and concurrent readers are not blocked.
pub fn infer_context(
    &self,
    context: &Context,
    ctx: &mut TypeContext,
) -> Result<TypeScheme, TypeError> {
    let lex = self.0.read().expect("poisoned lexicon");
    lex.infer_context(context, ctx)
}
/// Infer the type scheme of a `RuleContext` under `ctx`.
///
/// Uses a read lock (not the original's write lock): `Lex::infer_rulecontext`
/// takes `&self`, so exclusive access is unnecessary.
pub fn infer_rulecontext(
    &self,
    context: &RuleContext,
    ctx: &mut TypeContext,
) -> Result<TypeScheme, TypeError> {
    let lex = self.0.read().expect("poisoned lexicon");
    lex.infer_rulecontext(context, ctx)
}
/// Infer the type scheme of a single rule, discarding the per-side schemes.
///
/// Uses a read lock (not the original's write lock): `Lex::infer_rule`
/// takes `&self`, so exclusive access is unnecessary.
pub fn infer_rule(&self, rule: &Rule, ctx: &mut TypeContext) -> Result<TypeScheme, TypeError> {
    let lex = self.0.read().expect("poisoned lexicon");
    lex.infer_rule(rule, ctx).map(|(scheme, _, _)| scheme)
}
/// Infer a single scheme consistent with all of `rules`.
///
/// Uses a read lock (not the original's write lock): `Lex::infer_rules`
/// takes `&self`, so exclusive access is unnecessary.
pub fn infer_rules(
    &self,
    rules: &[Rule],
    ctx: &mut TypeContext,
) -> Result<TypeScheme, TypeError> {
    let lex = self.0.read().expect("poisoned lexicon");
    lex.infer_rules(rules, ctx)
}
/// The type scheme recorded for `op`, if it belongs to this lexicon.
///
/// Uses a read lock (not the original's write lock): `Lex::op_tp`
/// takes `&self`, so exclusive access is unnecessary.
pub fn infer_op(&self, op: &Operator) -> Result<TypeScheme, TypeError> {
    self.0.read().expect("poisoned lexicon").op_tp(op)
}
/// Sample a `Term` of type `scheme`, delegating to `Lex::sample_term`
/// with an initial size of 0.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
pub fn sample_term<R: Rng>(
    &mut self,
    rng: &mut R,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    variable: bool,
    max_size: usize,
) -> Result<Term, SampleError> {
    self.0
        .write()
        .expect("poisoned lexicon")
        .sample_term(rng, scheme, ctx, atom_weights, invent, variable, max_size, 0)
}
/// Sample a `Term` by filling the holes of `context`, delegating to
/// `Lex::sample_term_from_context` with an initial size of 0.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
pub fn sample_term_from_context<R: Rng>(
    &mut self,
    rng: &mut R,
    context: &Context,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    variable: bool,
    max_size: usize,
) -> Result<Term, SampleError> {
    self.0
        .write()
        .expect("poisoned lexicon")
        .sample_term_from_context(
            rng,
            context,
            ctx,
            atom_weights,
            invent,
            variable,
            max_size,
            0,
        )
}
/// Sample a `Rule` of type `scheme`, delegating to `Lex::sample_rule`
/// with an initial size of 0.
pub fn sample_rule<R: Rng>(
    &mut self,
    rng: &mut R,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    max_size: usize,
) -> Result<Rule, SampleError> {
    self.0
        .write()
        .expect("poisoned lexicon")
        .sample_rule(rng, scheme, ctx, atom_weights, invent, max_size, 0)
}
/// Sample a `Rule` by filling the holes of `context`, delegating to
/// `Lex::sample_rule_from_context` with an initial size of 0.
pub fn sample_rule_from_context<R: Rng>(
    &mut self,
    rng: &mut R,
    context: RuleContext,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    max_size: usize,
) -> Result<Rule, SampleError> {
    // Fixed typo in the panic message ("posioned" -> "poisoned"),
    // matching every other lock access in this file.
    let mut lex = self.0.write().expect("poisoned lexicon");
    lex.sample_rule_from_context(rng, context, ctx, atom_weights, invent, max_size, 0)
}
/// Log-prior of `term` at type `scheme` under this lexicon's
/// generative model.
pub fn logprior_term(
    &self,
    term: &Term,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
) -> Result<f64, SampleError> {
    // Fixed typo in the panic message ("posioned" -> "poisoned"),
    // matching every other lock access in this file.
    let lex = self.0.read().expect("poisoned lexicon");
    lex.logprior_term(term, scheme, ctx, atom_weights, invent)
}
/// Log-prior of `rule` under this lexicon's generative model.
pub fn logprior_rule(
    &self,
    rule: &Rule,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
) -> Result<f64, SampleError> {
    self.0
        .read()
        .expect("poisoned lexicon")
        .logprior_rule(rule, scheme, ctx, atom_weights, invent)
}
/// Log-prior of an untyped TRS under this lexicon's generative model.
pub fn logprior_utrs(
    &self,
    utrs: &UntypedTRS,
    schemes: &[TypeScheme],
    p_rule: f64,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
) -> Result<f64, SampleError> {
    self.0
        .read()
        .expect("poisoned lexicon")
        .logprior_utrs(utrs, schemes, p_rule, ctx, atom_weights, invent)
}
/// Combine the learned (non-background) rules of two TRSs sharing this
/// lexicon into one TRS, made deterministic if the lexicon demands it.
///
/// # Panics
/// Panics if the two TRSs do not share a lexicon, or if the merged rules
/// cannot be pushed onto the new TRS.
pub fn combine<R: Rng>(&self, rng: &mut R, trs1: &TRS, trs2: &TRS) -> Result<TRS, TypeError> {
    assert_eq!(trs1.lex, trs2.lex);
    // Background rules occupy the tail of each TRS's rule list; slice them
    // off so only learned rules are merged.
    let background_size = trs1
        .lex
        .0
        .read()
        .expect("poisoned lexicon")
        .background
        .len();
    let rules1 = trs1.utrs.rules[..(trs1.utrs.len() - background_size)].to_vec();
    let rules2 = trs2.utrs.rules[..(trs2.utrs.len() - background_size)].to_vec();
    let ctx = &self.0.read().expect("poisoned lexicon").ctx;
    let mut trs = TRS::new(&trs1.lex, rules1, ctx)?;
    // Was a bare `unwrap()` crammed onto the same line as the `if` below;
    // give the panic a message and separate the statements.
    trs.utrs
        .pushes(rules2)
        .expect("combining well-typed TRSs should yield pushable rules");
    if self.0.read().expect("poisoned lexicon").deterministic {
        make_deterministic(&mut trs.utrs, rng);
    }
    Ok(trs)
}
}
impl fmt::Debug for Lexicon {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Format the guarded `Lex` itself rather than the `Result` returned
        // by `read()` — the original printed "Lexicon(Ok(..))". A poisoned
        // lock is reported instead of panicking inside Debug.
        match self.0.read() {
            Ok(lex) => write!(f, "Lexicon({:?})", &*lex),
            Err(_) => write!(f, "Lexicon(<poisoned>)"),
        }
    }
}
impl PartialEq for Lexicon {
fn eq(&self, other: &Lexicon) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}
impl fmt::Display for Lexicon {
    /// Delegates to the underlying `Lex`'s `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let lex = self.0.read().expect("poisoned lexicon");
        write!(f, "{}", *lex)
    }
}
/// The internal state of a [`Lexicon`].
///
/// `ops` and `vars` are parallel (by index) to `signature.operators()` and
/// `signature.variables()` respectively — see `op_tp`/`var_tp`.
#[derive(Debug, Clone)]
pub(crate) struct Lex {
    // Type schemes for each operator in the signature, by index.
    pub(crate) ops: Vec<TypeScheme>,
    // Type schemes for each variable in the signature, by index.
    pub(crate) vars: Vec<TypeScheme>,
    pub(crate) signature: Signature,
    // Fixed rules every TRS built from this lexicon includes.
    pub(crate) background: Vec<Rule>,
    // Rule templates used when sampling new rules.
    pub(crate) templates: Vec<RuleContext>,
    // Whether TRSs built from this lexicon are made deterministic.
    pub(crate) deterministic: bool,
    pub(crate) ctx: TypeContext,
}
impl fmt::Display for Lex {
    /// Lists the signature (operators then variables with their schemes),
    /// the background rules, the templates, and the determinism flag.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "Signature:")?;
        let operators = self.signature.operators();
        for (op, scheme) in operators.iter().zip(self.ops.iter()) {
            writeln!(f, "{}: {}", op.display(), scheme)?;
        }
        let variables = self.signature.variables();
        for (var, scheme) in variables.iter().zip(self.vars.iter()) {
            writeln!(f, "{}: {}", var.display(), scheme)?;
        }
        writeln!(f, "\nBackground: {}", self.background.len())?;
        for rule in self.background.iter() {
            writeln!(f, "{}", rule.pretty())?;
        }
        writeln!(f, "\nTemplates: {}", self.templates.len())?;
        for template in self.templates.iter() {
            writeln!(f, "{}", template.pretty())?;
        }
        writeln!(f, "\nDeterministic: {}", self.deterministic)
    }
}
impl Lex {
/// Free type variables across all variable and operator schemes,
/// deduplicated in first-seen order (variables' schemes first).
fn free_vars(&self) -> Vec<TypeVar> {
    self.vars
        .iter()
        .chain(self.ops.iter())
        .flat_map(TypeScheme::free_vars)
        .unique()
        .collect()
}
/// Free type variables after applying the substitution in `ctx`,
/// deduplicated in first-seen order.
fn free_vars_applied(&self, ctx: &TypeContext) -> Vec<TypeVar> {
    let mut out = Vec::new();
    for fv in self.free_vars() {
        for var in Type::Variable(fv).apply(ctx).vars() {
            if !out.contains(&var) {
                out.push(var);
            }
        }
    }
    out
}
/// Invent a fresh variable of type `tp`, keeping `self.vars` parallel to
/// `signature.variables()`.
fn invent_variable(&mut self, tp: &Type) -> Variable {
    self.vars.push(TypeScheme::Monotype(tp.clone()));
    self.signature.new_var(None)
}
/// Check whether `atom` can produce a value of type `tp` under `ctx`.
///
/// On success, unifies the atom's return type with `tp` (mutating `ctx`)
/// and returns the argument types the atom then requires (empty for
/// variables and constants).
fn fit_atom(
    &self,
    atom: &Atom,
    tp: &Type,
    ctx: &mut TypeContext,
) -> Result<Vec<Type>, SampleError> {
    let atom_tp = self.instantiate_atom(atom, ctx)?;
    // A non-function atom has no `returns()`; unify the whole type instead.
    ctx.unify(atom_tp.returns().unwrap_or(&atom_tp), tp)?;
    Ok(atom_tp
        .args()
        .map(|args| args.into_iter().cloned().collect())
        // `unwrap_or_default()` instead of `unwrap_or_else(Vec::new)`
        // (clippy::unwrap_or_else_default).
        .unwrap_or_default())
}
/// Build a term headed by `atom`, recursively sampling a subterm for each
/// of `arg_types`. If any subterm fails, the type context is rolled back
/// to its state on entry before the error is returned.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
fn place_atom<R: Rng>(
    &mut self,
    rng: &mut R,
    atom: &Atom,
    arg_types: Vec<Type>,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    max_size: usize,
    size: usize,
    vars: &mut Vec<Variable>,
) -> Result<Term, SampleError> {
    let mut size = size;
    match *atom {
        // A variable is a complete term on its own.
        Atom::Variable(ref v) => Ok(Term::Variable(v.clone())),
        Atom::Operator(ref op) => {
            size += 1; // count the operator node itself
            let orig_ctx = ctx.clone(); // snapshot for rollback on failure
            let mut args = Vec::with_capacity(arg_types.len());
            let can_be_variable = true; // argument positions may be variables
            for arg_tp in arg_types {
                let subtype = arg_tp.apply(ctx);
                let arg_scheme = TypeScheme::Monotype(arg_tp);
                let result = self
                    .sample_term_internal(
                        rng,
                        &arg_scheme,
                        ctx,
                        atom_weights,
                        invent,
                        can_be_variable,
                        max_size,
                        size,
                        vars,
                    )
                    .map_err(|_| SampleError::Subterm)
                    .and_then(|subterm| {
                        // Re-check: the sampled subterm's inferred type must
                        // still unify with the expected argument type.
                        let tp = self.infer_term(&subterm, ctx)?.instantiate_owned(ctx);
                        ctx.unify_fast(subtype, tp)?;
                        Ok(subterm)
                    });
                match result {
                    Ok(subterm) => {
                        // Grow the running size so later arguments respect
                        // the overall budget.
                        size += subterm.size();
                        args.push(subterm);
                    }
                    Err(e) => {
                        // Roll back any unifications made along the way.
                        *ctx = orig_ctx;
                        return Err(e);
                    }
                }
            }
            Ok(Term::Application {
                op: op.clone(),
                args,
            })
        }
    }
}
/// Instantiate `atom`'s scheme with fresh variables and apply the current
/// substitution.
fn instantiate_atom(&self, atom: &Atom, ctx: &mut TypeContext) -> Result<Type, TypeError> {
    let scheme = self.infer_atom(atom)?;
    let mut tp = scheme.instantiate_owned(ctx);
    tp.apply_mut(ctx);
    Ok(tp)
}
/// The scheme recorded for variable `v` (parallel-indexed with the
/// signature's variable list).
fn var_tp(&self, v: &Variable) -> Result<TypeScheme, TypeError> {
    self.signature
        .variables()
        .iter()
        .position(|x| x == v)
        .map(|idx| self.vars[idx].clone())
        .ok_or(TypeError::VarNotFound)
}
/// The scheme recorded for operator `o` (parallel-indexed with the
/// signature's operator list).
fn op_tp(&self, o: &Operator) -> Result<TypeScheme, TypeError> {
    self.signature
        .operators()
        .iter()
        .position(|x| x == o)
        .map(|idx| self.ops[idx].clone())
        .ok_or(TypeError::OpNotFound)
}
/// The scheme recorded for `atom`, dispatching on operator vs. variable.
fn infer_atom(&self, atom: &Atom) -> Result<TypeScheme, TypeError> {
    match atom {
        Atom::Operator(o) => self.op_tp(o),
        Atom::Variable(v) => self.var_tp(v),
    }
}
/// Infer `term`'s type and generalize it over everything not free in the
/// lexicon itself.
pub fn infer_term(&self, term: &Term, ctx: &mut TypeContext) -> Result<TypeScheme, TypeError> {
    let tp = self.infer_term_internal(term, ctx)?.apply(ctx);
    let lex_vars = self.free_vars_applied(ctx);
    Ok(tp.generalize(&lex_vars))
}
/// Infer the (un-generalized) type of `term`, unifying into `ctx` as it goes.
fn infer_term_internal(&self, term: &Term, ctx: &mut TypeContext) -> Result<Type, TypeError> {
    if let Term::Application { ref op, ref args } = *term {
        if op.arity() > 0 {
            // The operator's instantiated (function) type.
            let head_type = self.instantiate_atom(&Atom::from(op.clone()), ctx)?;
            // Argument types plus a fresh return-type variable, assembled
            // into a function type to unify against the head.
            let body_type = {
                let mut pre_types = Vec::with_capacity(args.len() + 1);
                for a in args {
                    pre_types.push(self.infer_term_internal(a, ctx)?);
                }
                pre_types.push(ctx.new_variable());
                Type::from(pre_types)
            };
            ctx.unify(&head_type, &body_type)?;
            return Ok(head_type.returns().unwrap_or(&head_type).apply(ctx));
        }
    }
    // Constants (arity-0 operators) and variables: just instantiate.
    self.instantiate_atom(&term.head(), ctx)
}
/// Infer a `Context`'s type and generalize it over everything not free in
/// the lexicon itself.
pub fn infer_context(
    &self,
    context: &Context,
    ctx: &mut TypeContext,
) -> Result<TypeScheme, TypeError> {
    let mut tps = HashMap::new();
    let tp = self.infer_context_internal(context, ctx, vec![], &mut tps)?;
    let lex_vars = self.free_vars_applied(ctx);
    Ok(tp.apply(ctx).generalize(&lex_vars))
}
/// Infer the type of `context`, recording the type inferred at every
/// `Place` (a path of argument indices) into `tps`. Holes receive fresh
/// type variables.
fn infer_context_internal(
    &self,
    context: &Context,
    ctx: &mut TypeContext,
    place: Place,
    tps: &mut HashMap<Place, Type>,
) -> Result<Type, TypeError> {
    let tp = match *context {
        // A hole can be anything: give it a fresh variable.
        Context::Hole => ctx.new_variable(),
        Context::Variable(ref v) => self.instantiate_atom(&Atom::from(v.clone()), ctx)?,
        Context::Application { ref op, ref args } => {
            let head_type = self.instantiate_atom(&Atom::from(op.clone()), ctx)?;
            // Mirror `infer_term_internal`: arg types plus a fresh return
            // variable, unified against the operator's type.
            let body_type = {
                let mut pre_types = Vec::with_capacity(args.len() + 1);
                for (i, a) in args.iter().enumerate() {
                    // Extend the place with this argument's index.
                    let mut new_place = place.clone();
                    new_place.push(i);
                    pre_types.push(self.infer_context_internal(a, ctx, new_place, tps)?);
                }
                pre_types.push(ctx.new_variable());
                Type::from(pre_types)
            };
            ctx.unify(&head_type, &body_type)?;
            head_type.returns().unwrap_or(&head_type).apply(ctx)
        }
    };
    tps.insert(place, tp.clone());
    Ok(tp)
}
/// Infer a rule's type: the LHS type unified with every RHS type, then
/// generalized. Returns `(rule scheme, LHS scheme, per-RHS schemes)`.
pub fn infer_rule(
    &self,
    r: &Rule,
    ctx: &mut TypeContext,
) -> Result<(TypeScheme, TypeScheme, Vec<TypeScheme>), TypeError> {
    let lhs_scheme = self.infer_term(&r.lhs, ctx)?;
    let lhs_type = lhs_scheme.instantiate(ctx);
    let mut rhs_types = Vec::with_capacity(r.rhs.len());
    let mut rhs_schemes = Vec::with_capacity(r.rhs.len());
    for rhs in &r.rhs {
        let rhs_scheme = self.infer_term(rhs, ctx)?;
        rhs_types.push(rhs_scheme.instantiate(ctx));
        rhs_schemes.push(rhs_scheme);
    }
    // Every RHS must have the same type as the LHS.
    for rhs_type in rhs_types {
        ctx.unify(&lhs_type, &rhs_type)?;
    }
    let lex_vars = self.free_vars_applied(ctx);
    let rule_scheme = lhs_type.apply(ctx).generalize(&lex_vars);
    Ok((rule_scheme, lhs_scheme, rhs_schemes))
}
/// Infer a single scheme consistent with all of `rules`: each rule's type
/// is unified against a shared fresh variable, which is then generalized.
pub fn infer_rules(
    &self,
    rules: &[Rule],
    ctx: &mut TypeContext,
) -> Result<TypeScheme, TypeError> {
    let tp = ctx.new_variable();
    let mut rule_tps = vec![];
    for rule in rules.iter() {
        let rule_tp = self.infer_rule(rule, ctx)?.0;
        rule_tps.push(rule_tp.instantiate(ctx));
    }
    // Unify all rule types with the shared variable.
    for rule_tp in rule_tps {
        ctx.unify(&tp, &rule_tp)?;
    }
    let lex_vars = self.free_vars_applied(ctx);
    Ok(tp.apply(ctx).generalize(&lex_vars))
}
/// Infer a `RuleContext`'s type and generalize it over everything not
/// free in the lexicon itself.
pub fn infer_rulecontext(
    &self,
    context: &RuleContext,
    ctx: &mut TypeContext,
) -> Result<TypeScheme, TypeError> {
    let mut tps = HashMap::new();
    let tp = self.infer_rulecontext_internal(context, ctx, &mut tps)?;
    let lex_vars = self.free_vars_applied(ctx);
    Ok(tp.apply(ctx).generalize(&lex_vars))
}
/// Infer the type of a `RuleContext`, recording subcontext types in `tps`.
/// The LHS is rooted at place `[0]`; the i-th RHS at place `[i + 1]`.
fn infer_rulecontext_internal(
    &self,
    context: &RuleContext,
    ctx: &mut TypeContext,
    tps: &mut HashMap<Place, Type>,
) -> Result<Type, TypeError> {
    let lhs_type = self.infer_context_internal(&context.lhs, ctx, vec![0], tps)?;
    let rhs_types = context
        .rhs
        .iter()
        .enumerate()
        .map(|(i, rhs)| self.infer_context_internal(rhs, ctx, vec![i + 1], tps))
        .collect::<Result<Vec<Type>, _>>()?;
    // Each RHS must unify with the LHS type.
    for rhs_type in rhs_types {
        ctx.unify(&lhs_type, &rhs_type)?;
    }
    Ok(lhs_type.apply(ctx))
}
/// Type-check every rule of `utrs`, failing on the first ill-typed rule.
pub fn infer_utrs(&self, utrs: &UntypedTRS, ctx: &mut TypeContext) -> Result<(), TypeError> {
    utrs.rules
        .iter()
        .try_for_each(|rule| self.infer_rule(rule, ctx).map(|_| ()))
}
/// Sample a `Term` of type `scheme`, starting with no variables in scope.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
pub fn sample_term<R: Rng>(
    &mut self,
    rng: &mut R,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    variable: bool,
    max_size: usize,
    size: usize,
) -> Result<Term, SampleError> {
    let mut vars = Vec::new();
    self.sample_term_internal(
        rng,
        scheme,
        ctx,
        atom_weights,
        invent,
        variable,
        max_size,
        size,
        &mut vars,
    )
}
/// Core sampling loop: pick a type-compatible atom via `prepare_option`,
/// then recursively fill in its arguments via `place_atom`. `vars` tracks
/// the variables in scope (and grows when one is invented).
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
pub fn sample_term_internal<R: Rng>(
    &mut self,
    rng: &mut R,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    variable: bool,
    max_size: usize,
    size: usize,
    vars: &mut Vec<Variable>,
) -> Result<Term, SampleError> {
    // Enforce the size budget before doing any work.
    if size >= max_size {
        return Err(SampleError::SizeExceeded(size, max_size));
    }
    let tp = scheme.instantiate(ctx);
    let (atom, arg_types) =
        self.prepare_option(rng, vars, atom_weights, invent, variable, &tp, ctx)?;
    self.place_atom(
        rng,
        &atom,
        arg_types,
        ctx,
        atom_weights,
        invent,
        max_size,
        size,
        vars,
    )
}
/// Choose an atom compatible with `tp`, possibly inventing a fresh
/// variable, by trying candidates in a weighted random order.
///
/// In the candidate list, `None` is a placeholder meaning "invent a fresh
/// variable"; weights are `vw` for variables (incl. invention), `cw` for
/// constants (0-arity operators), and `ow` for other operators.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
fn prepare_option<R: Rng>(
    &mut self,
    rng: &mut R,
    vars: &mut Vec<Variable>,
    (vw, cw, ow): (f64, f64, f64),
    invent: bool,
    variable: bool,
    tp: &Type,
    ctx: &mut TypeContext,
) -> Result<(Atom, Vec<Type>), SampleError> {
    let ops = self.signature.operators();
    let mut options: Vec<_> = ops.into_iter().map(|o| Some(Atom::Operator(o))).collect();
    if variable {
        options.extend(vars.iter().cloned().map(|v| Some(Atom::Variable(v))));
        if invent {
            options.push(None);
        }
    }
    // `|o|` instead of the original `|ref o|` (clippy::toplevel_ref_arg);
    // match ergonomics bind through the reference identically.
    let weights: Vec<_> = options
        .iter()
        .map(|o| match o {
            None | Some(Atom::Variable(_)) => vw,
            Some(Atom::Operator(o)) if o.arity() == 0 => cw,
            Some(Atom::Operator(_)) => ow,
        })
        .collect();
    for option in weighted_permutation(rng, &options, &weights, None) {
        let atom = option.unwrap_or_else(|| {
            // Invention: create the variable only when actually chosen.
            let new_var = self.invent_variable(tp);
            vars.push(new_var.clone());
            Atom::Variable(new_var)
        });
        if let Ok(arg_types) = self.fit_atom(&atom, tp, ctx) {
            return Ok((atom, arg_types));
        }
    }
    Err(SampleError::OptionsExhausted)
}
/// Sample a `Term` by filling each hole of `context` with a freshly
/// sampled subterm of the appropriate type.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
pub fn sample_term_from_context<R: Rng>(
    &mut self,
    rng: &mut R,
    context: &Context,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    variable: bool,
    max_size: usize,
    size: usize,
) -> Result<Term, SampleError> {
    let mut map = HashMap::new();
    let mut context = context.clone();
    let hole_places = context.holes();
    // Populate `map` with the type at every place, holes included.
    self.infer_context_internal(&context, ctx, vec![], &mut map)?;
    let lex_vars = self.free_vars_applied(ctx);
    let mut context_vars = context.variables();
    for p in &hole_places {
        let scheme = &map[p].apply(ctx).generalize(&lex_vars);
        let subterm = self.sample_term_internal(
            rng,
            scheme,
            ctx,
            atom_weights,
            invent,
            variable,
            max_size,
            size,
            &mut context_vars,
        )?;
        // BUG FIX: `Context::replace` returns the updated context rather
        // than mutating in place. The original discarded its result, so
        // holes were never filled and `to_term()` below would fail.
        // This mirrors the correct pattern in `sample_rule_from_context`.
        context = context
            .replace(p, Context::from(subterm))
            .ok_or(SampleError::Subterm)?;
    }
    context.to_term().or(Err(SampleError::Subterm))
}
/// Sample a rule: an LHS (never a bare variable, may invent variables if
/// `invent`) and one RHS (no invention, may be a bare variable), retrying
/// with `self` and `ctx` rolled back until `Rule::new` accepts the pair.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
pub fn sample_rule<R: Rng>(
    &mut self,
    rng: &mut R,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    max_size: usize,
    size: usize,
) -> Result<Rule, SampleError> {
    // Snapshots for rollback: sampling may invent variables (mutating
    // `self`) and unify types (mutating `ctx`).
    let orig_self = self.clone();
    let orig_ctx = ctx.clone();
    loop {
        let mut vars = vec![];
        // LHS: `variable` is false — the whole LHS may not be a variable.
        let lhs = self.sample_term_internal(
            rng,
            scheme,
            ctx,
            atom_weights,
            invent,
            false,
            max_size,
            size,
            &mut vars,
        )?;
        // RHS: no invention, but reusing LHS variables is allowed.
        let rhs = self.sample_term_internal(
            rng,
            scheme,
            ctx,
            atom_weights,
            false,
            true,
            max_size,
            size,
            &mut vars,
        )?;
        // `Rule::new` returns `None` for ill-formed rules; reset both
        // snapshots and try again.
        if let Some(rule) = Rule::new(lhs, vec![rhs]) {
            return Ok(rule);
        } else {
            *self = orig_self.clone();
            *ctx = orig_ctx.clone();
        }
    }
}
/// Sample a `Rule` by filling the holes of `context`.
///
/// Holes inside the LHS (places starting with 0) may invent variables
/// when `invent` is set; only the entire LHS (place `[0]`) may itself be
/// a bare variable.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::too_many_arguments))]
pub fn sample_rule_from_context<R: Rng>(
    &mut self,
    rng: &mut R,
    mut context: RuleContext,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
    max_size: usize,
    size: usize,
) -> Result<Rule, SampleError> {
    let mut map = HashMap::new();
    let hole_places = context.holes();
    let mut context_vars = context.variables();
    // Populate `map` with the type at every place, holes included.
    self.infer_rulecontext_internal(&context, ctx, &mut map)?;
    for p in &hole_places {
        let scheme = TypeScheme::Monotype(map[p].apply(ctx));
        let can_invent = p[0] == 0 && invent;
        // Compare against a fixed-size array: the original's
        // `p == &vec![0]` heap-allocated a Vec for every hole.
        let can_be_variable = *p == [0];
        let subterm = self.sample_term_internal(
            rng,
            &scheme,
            ctx,
            atom_weights,
            can_invent,
            can_be_variable,
            max_size,
            size,
            &mut context_vars,
        )?;
        context = context
            .replace(p, Context::from(subterm))
            .ok_or(SampleError::Subterm)?;
    }
    context.to_rule().or(Err(SampleError::Subterm))
}
/// Log-probability of producing `term` at type `scheme` under the same
/// process used by `sample_term` (class weights vw/cw/ow for variables,
/// constants, and operators, uniform within a class).
fn logprior_term(
    &self,
    term: &Term,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    (vw, cw, ow): (f64, f64, f64),
    invent: bool,
) -> Result<f64, SampleError> {
    let tp = scheme.instantiate(ctx);
    // Partition the type-compatible atoms into variables / constants /
    // operators, remembering the argument types each would need.
    let (mut vs, mut cs, mut os) = (vec![], vec![], vec![]);
    let atoms = self.signature.atoms();
    for atom in &atoms {
        if let Ok(arg_types) = self.fit_atom(atom, &tp, ctx) {
            match *atom {
                Atom::Variable(_) => vs.push((Some(atom.clone()), arg_types)),
                Atom::Operator(ref o) if o.arity() == 0 => {
                    cs.push((Some(atom.clone()), arg_types))
                }
                Atom::Operator(_) => os.push((Some(atom.clone()), arg_types)),
            }
        }
    }
    // Invention adds one more "variable" option; if the term's head is a
    // variable unknown to the signature, it must have been invented.
    if invent {
        if let Term::Variable(ref v) = *term {
            if !atoms.contains(&Atom::Variable(v.clone())) {
                vs.push((Some(Atom::Variable(v.clone())), vec![]));
            }
        } else {
            vs.push((None, vec![]));
        }
    }
    // Normalize the class weights, then spread each class's mass uniformly
    // over its members.
    let z = vw + cw + ow;
    let (vw, cw, ow) = (vw / z, cw / z, ow / z);
    let vlp = vw.ln() - (vs.len() as f64).ln();
    let clp = cw.ln() - (cs.len() as f64).ln();
    let olp = ow.ln() - (os.len() as f64).ln();
    let mut options = vec![];
    options.append(&mut vs);
    options.append(&mut cs);
    options.append(&mut os);
    match options
        .into_iter()
        .find(|(o, _)| o == &Some(term.head()))
        .map(|(o, arg_types)| (o.unwrap(), arg_types))
    {
        Some((Atom::Variable(_), _)) => Ok(vlp),
        Some((Atom::Operator(_), ref arg_types)) if arg_types.is_empty() => Ok(clp),
        Some((Atom::Operator(_), arg_types)) => {
            // Operator head: its own log-prob plus each argument's,
            // re-checking that each argument's inferred type still unifies.
            let mut lp = olp;
            for (subterm, mut arg_tp) in term.args().iter().zip(arg_types) {
                arg_tp.apply_mut(ctx);
                let arg_scheme = TypeScheme::Monotype(arg_tp.clone());
                lp += self.logprior_term(subterm, &arg_scheme, ctx, (vw, cw, ow), invent)?;
                let final_type = self.infer_term(subterm, ctx)?.instantiate_owned(ctx);
                if ctx.unify(&arg_tp, &final_type).is_err() {
                    // Ill-typed subterm: this term has zero probability.
                    return Ok(f64::NEG_INFINITY);
                }
            }
            Ok(lp)
        }
        // The head does not fit the required type at all.
        None => Ok(f64::NEG_INFINITY),
    }
}
/// Log-probability of a rule: each RHS contributes its own log-prob plus
/// the LHS log-prob.
fn logprior_rule(
    &self,
    rule: &Rule,
    scheme: &TypeScheme,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
) -> Result<f64, SampleError> {
    let mut lp = 0.0;
    let lp_lhs = self.logprior_term(&rule.lhs, scheme, ctx, atom_weights, invent)?;
    for rhs in &rule.rhs {
        // No invention on the RHS, matching `sample_rule`.
        let tmp_lp = self.logprior_term(rhs, scheme, ctx, atom_weights, false)?;
        // NOTE(review): `lp_lhs` is added once per RHS — apparently scoring
        // a multi-RHS rule as several clauses each repeating the LHS;
        // confirm this is intended rather than double-counting.
        lp += tmp_lp + lp_lhs;
    }
    Ok(lp)
}
/// Log-probability of a TRS: `ln(p_rule)` per clause, plus, for each
/// non-background rule, the log-sum-exp of its prior over `schemes`.
fn logprior_utrs(
    &self,
    utrs: &UntypedTRS,
    schemes: &[TypeScheme],
    p_rule: f64,
    ctx: &mut TypeContext,
    atom_weights: (f64, f64, f64),
    invent: bool,
) -> Result<f64, SampleError> {
    let p_n_rules = p_rule.ln() * (utrs.clauses().len() as f64);
    let mut p_rules = 0.0;
    for rule in &utrs.rules {
        // Background rules are given, not sampled; skip them.
        if self.background.contains(rule) {
            continue;
        }
        // Marginalize the rule's prior over the candidate schemes.
        let mut rule_ps = vec![];
        for scheme in schemes {
            let tmp_lp = self.logprior_rule(rule, scheme, ctx, atom_weights, invent)?;
            rule_ps.push(tmp_lp);
        }
        p_rules += logsumexp(&rule_ps);
    }
    Ok(p_n_rules + p_rules)
}
}
impl GP<[Rule]> for Lexicon {
type Expression = TRS;
type Params = GeneticParams;
/// Create an initial population: `pop_size` copies of the background-only
/// TRS, each extended with one successfully sampled rule.
///
/// # Panics
/// Panics if the background knowledge itself fails to type-check.
fn genesis<R: Rng>(
    &self,
    params: &Self::Params,
    rng: &mut R,
    pop_size: usize,
    _tp: &TypeScheme,
) -> Vec<Self::Expression> {
    let trs = TRS::new(
        self,
        Vec::new(),
        &self.0.read().expect("poisoned lexicon").ctx,
    );
    match trs {
        Ok(mut trs) => {
            if self.0.read().expect("poisoned lexicon").deterministic {
                make_deterministic(&mut trs.utrs, rng);
            }
            let templates = self.0.read().expect("poisoned lexicon").templates.clone();
            repeat_n(trs, pop_size)
                .map(|trs| loop {
                    // Retry until a rule samples successfully.
                    if let Ok(new_trs) = trs.add_rule(
                        &templates,
                        params.atom_weights,
                        params.max_sample_size,
                        rng,
                    ) {
                        return new_trs;
                    }
                })
                .collect()
        }
        Err(err) => {
            // The background rules cannot be typed: nothing sensible to do.
            let lex = self.0.read().expect("poisoned lexicon");
            let background_trs = UntypedTRS::new(lex.background.clone());
            panic!(
                "invalid background knowledge {}: {}",
                background_trs.display(),
                err
            )
        }
    }
}
/// Mutate by either adding a sampled rule or deleting an existing one,
/// retrying until one of the two operations succeeds.
fn mutate<R: Rng>(
    &self,
    params: &Self::Params,
    rng: &mut R,
    trs: &Self::Expression,
    _obs: &[Rule],
) -> Vec<Self::Expression> {
    loop {
        // NOTE(review): non-short-circuiting `|` means `gen_bool` advances
        // the RNG even when `trs.is_empty()` is already true — confirm
        // whether that is intended before changing it to `||`.
        if trs.is_empty() | rng.gen_bool(params.p_add) {
            let templates = self.0.read().expect("poisoned lexicon").templates.clone();
            if let Ok(new_trs) =
                trs.add_rule(&templates, params.atom_weights, params.max_sample_size, rng)
            {
                return vec![new_trs];
            }
        } else if let Ok(new_trs) = trs.delete_rule(rng) {
            return vec![new_trs];
        }
    }
}
/// Cross two parents: combine their learned rules, then produce
/// `n_crosses` offspring, each keeping every background rule and each
/// learned rule independently with probability `p_keep`.
///
/// # Panics
/// Panics if the combined TRS fails to type-check.
fn crossover<R: Rng>(
    &self,
    params: &Self::Params,
    rng: &mut R,
    parent1: &Self::Expression,
    parent2: &Self::Expression,
    _obs: &[Rule],
) -> Vec<Self::Expression> {
    let trs = self
        .combine(rng, parent1, parent2)
        .expect("poorly-typed TRS in crossover");
    // Snapshot the background once: the original re-acquired the lexicon's
    // read lock inside `retain` for every rule of every cross.
    let background = self.0.read().expect("poisoned lexicon").background.clone();
    iter::repeat(trs)
        .take(params.n_crosses)
        .update(|trs| {
            trs.utrs
                .rules
                .retain(|r| background.contains(r) || rng.gen_bool(params.p_keep))
        })
        .collect()
}
/// Remove offspring that are alpha-equivalent to an existing population
/// member or a previous child, then deduplicate the survivors (keeping
/// the first of each alpha-equivalence class).
fn validate_offspring(
    &self,
    _params: &Self::Params,
    population: &[(Self::Expression, f64)],
    children: &[Self::Expression],
    offspring: &mut Vec<Self::Expression>,
) {
    // `&&` short-circuits, skipping the children scan when the population
    // scan already matched (the original's `&` always evaluated both; the
    // predicates are pure, so observable behavior is unchanged).
    offspring.retain(|x| {
        !population
            .iter()
            .any(|p| UntypedTRS::alphas(&p.0.utrs, &x.utrs))
            && !children
                .iter()
                .any(|c| UntypedTRS::alphas(&c.utrs, &x.utrs))
    });
    // Deduplicate up to alpha-equivalence, preserving order.
    let mut deduped: Vec<Self::Expression> = Vec::with_capacity(offspring.len());
    for x in offspring.iter() {
        if !deduped.iter().any(|a| UntypedTRS::alphas(&a.utrs, &x.utrs)) {
            deduped.push(x.clone());
        }
    }
    *offspring = deduped;
}
}
/// Make `utrs` deterministic, choosing among alternatives with `rng`,
/// and return whatever `UntypedTRS::make_deterministic` reports
/// (presumably whether anything changed — confirm in term_rewriting).
fn make_deterministic<R: Rng>(utrs: &mut UntypedTRS, rng: &mut R) -> bool {
    // Adapter bridging this crate's `rand::Rng` to the (differently
    // versioned) `rand_core::RngCore` that term_rewriting expects.
    struct CompatRng<'a, R: Rng>(&'a mut R);
    impl<'a, R: Rng> rand_core::RngCore for CompatRng<'a, R> {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }
        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }
        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest)
        }
        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
            // The two rand versions have incompatible error types; collapse
            // any failure into a generic error rather than convert it.
            self.0.try_fill_bytes(dest).map_err(|_| {
                rand_core::Error::new(rand_core::ErrorKind::Unexpected, "error behind rand compat")
            })
        }
    }
    utrs.make_deterministic(&mut CompatRng(rng))
}