mod lexicon;
pub mod parser;
mod rewrite;
pub use self::lexicon::{GeneticParams, Lexicon};
pub use self::parser::{
parse_context, parse_lexicon, parse_rule, parse_rulecontext, parse_templates, parse_trs,
};
pub use self::rewrite::TRS;
use crate::Task;
use polytype;
use polytype::TypeScheme;
use serde::{Deserialize, Serialize};
use std::fmt;
use term_rewriting::{Rule, TRSError};
/// Errors reported while typing rules or terms, e.g. by [`task_by_rewrite`]
/// when inferring a type scheme for its data fails.
#[derive(Clone, Debug)]
pub enum TypeError {
    /// Two types could not be unified.
    Unification(polytype::UnificationError),
    /// An operator could not be found.
    OpNotFound,
    /// A variable could not be found.
    VarNotFound,
}
/// Wraps a unification failure as [`TypeError::Unification`], so `?` can
/// lift `polytype` errors into `TypeError`.
impl From<polytype::UnificationError> for TypeError {
    fn from(err: polytype::UnificationError) -> Self {
        Self::Unification(err)
    }
}
/// Human-readable message for each [`TypeError`] variant.
impl fmt::Display for TypeError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Only the unification arm needs formatting; the others are fixed text.
        let message = match self {
            TypeError::Unification(inner) => {
                return write!(f, "unification error: {}", inner);
            }
            TypeError::OpNotFound => "operator not found",
            TypeError::VarNotFound => "variable not found",
        };
        f.write_str(message)
    }
}
impl ::std::error::Error for TypeError {
fn description(&self) -> &'static str {
"type error"
}
}
/// Reasons an attempt to sample can fail.
#[derive(Clone, Debug)]
pub enum SampleError {
    /// A typing problem occurred while sampling.
    TypeError(TypeError),
    /// An error reported by `term_rewriting`.
    TRSError(TRSError),
    /// A size limit was hit; the fields are `(size, max_size)`.
    SizeExceeded(usize, usize),
    /// Every available option was tried without success.
    OptionsExhausted,
    /// A subterm could not be sampled.
    Subterm,
}
impl From<TypeError> for SampleError {
fn from(e: TypeError) -> SampleError {
SampleError::TypeError(e)
}
}
impl From<TRSError> for SampleError {
fn from(e: TRSError) -> SampleError {
SampleError::TRSError(e)
}
}
/// Lifts a unification failure straight into a [`SampleError`], nesting it as
/// `SampleError::TypeError(TypeError::Unification(..))`.
impl From<polytype::UnificationError> for SampleError {
    fn from(err: polytype::UnificationError) -> Self {
        Self::TypeError(TypeError::from(err))
    }
}
impl fmt::Display for SampleError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
SampleError::TypeError(ref e) => write!(f, "type error: {}", e),
SampleError::TRSError(ref e) => write!(f, "TRS error: {}", e),
SampleError::SizeExceeded(size, max_size) => {
write!(f, "size {} exceeded maximum of {}", size, max_size)
}
SampleError::OptionsExhausted => write!(f, "failed to sample (options exhausted)"),
SampleError::Subterm => write!(f, "cannot sample subterm"),
}
}
}
impl ::std::error::Error for SampleError {
fn description(&self) -> &'static str {
"sample error"
}
}
/// Parameters forwarded to `TRS::posterior` by the task that
/// [`task_by_rewrite`] builds.
///
/// `PartialEq` is derived so that configurations can be compared directly
/// (every field supports it). `Eq` is not derivable because of the `f64` fields.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelParams {
    /// Probability-like weight consumed by the posterior; `0.0` by default.
    /// NOTE(review): exact meaning is defined in `TRS::posterior` — confirm.
    pub p_partial: f64,
    /// Probability-like weight consumed by the posterior; `0.0` by default.
    /// NOTE(review): exact meaning is defined in `TRS::posterior` — confirm.
    pub p_observe: f64,
    /// Step limit used during scoring; `50` by default.
    /// NOTE(review): presumably the maximum number of rewrite steps — confirm.
    pub max_steps: usize,
    /// Optional size limit; `Some(500)` by default.
    /// NOTE(review): `None` presumably means "unbounded" — confirm.
    pub max_size: Option<usize>,
}
impl Default for ModelParams {
fn default() -> ModelParams {
ModelParams {
p_partial: 0.0,
p_observe: 0.0,
max_steps: 50,
max_size: Some(500),
}
}
}
/// Build a [`Task`] that scores [`TRS`] hypotheses against the rules in `data`.
///
/// The task's type scheme comes from `lex.infer_rules`, which runs over `data`
/// against a clone of the lexicon's `ctx` rather than the lexicon's own copy.
/// The task's oracle scores a hypothesis `h` as `-h.posterior(data, params)`.
/// The task hands back `observation` unchanged.
///
/// # Errors
///
/// Returns a [`TypeError`] if inferring a type for `data` fails.
///
/// # Panics
///
/// Panics if the lexicon's lock is poisoned.
pub fn task_by_rewrite<'a, O: Sync + 'a>(
    data: &'a [Rule],
    params: ModelParams,
    lex: &Lexicon,
    observation: O,
) -> Result<impl Task<O, Representation = Lexicon, Expression = TRS> + 'a, TypeError> {
    // Take a snapshot of the context, then release the read guard before
    // `infer_rules` touches the lexicon again.
    let guard = lex.0.read().expect("poisoned lexicon");
    let mut ctx = guard.ctx.clone();
    drop(guard);

    let tp = lex.infer_rules(data, &mut ctx)?;
    let task = TrsTask {
        data,
        params,
        tp,
        observation,
    };
    Ok(task)
}
/// The concrete task returned by [`task_by_rewrite`]. It scores a [`TRS`]
/// hypothesis against a borrowed set of rules.
struct TrsTask<'a, O> {
    /// Rules each hypothesis is scored against.
    data: &'a [Rule],
    /// Parameters forwarded to `TRS::posterior`.
    params: ModelParams,
    /// Type scheme inferred from `data` when the task was built.
    tp: TypeScheme,
    /// Value returned as-is by `Task::observation`.
    observation: O,
}
/// Exposes a [`TrsTask`] through the crate's [`Task`] interface, using
/// [`Lexicon`] as the representation and [`TRS`] as the expression type.
impl<'a, O> Task<O> for TrsTask<'a, O>
where
    O: Sync,
{
    type Representation = Lexicon;
    type Expression = TRS;

    /// Returns the negation of `TRS::posterior` for `h` over the task's rules.
    /// The lexicon argument is not consulted.
    fn oracle(&self, _lexicon: &Lexicon, h: &TRS) -> f64 {
        let posterior = h.posterior(self.data, self.params);
        -posterior
    }

    /// The type scheme inferred for the task's rules.
    fn tp(&self) -> &TypeScheme {
        &self.tp
    }

    /// The observation supplied when the task was built.
    fn observation(&self) -> &O {
        &self.observation
    }
}