programinduction/trs/
mod.rs1mod lexicon;
45pub mod parser;
46mod rewrite;
47pub use self::lexicon::{GeneticParams, Lexicon};
48pub use self::parser::{
49 parse_context, parse_lexicon, parse_rule, parse_rulecontext, parse_templates, parse_trs,
50};
51pub use self::rewrite::TRS;
52use crate::Task;
53
54use polytype;
55use polytype::TypeScheme;
56use serde::{Deserialize, Serialize};
57use std::fmt;
58use term_rewriting::{Rule, TRSError};
59
60#[derive(Debug, Clone)]
61pub enum TypeError {
63 Unification(polytype::UnificationError),
64 OpNotFound,
65 VarNotFound,
66}
67impl From<polytype::UnificationError> for TypeError {
68 fn from(e: polytype::UnificationError) -> TypeError {
69 TypeError::Unification(e)
70 }
71}
72impl fmt::Display for TypeError {
73 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74 match *self {
75 TypeError::Unification(ref e) => write!(f, "unification error: {}", e),
76 TypeError::OpNotFound => write!(f, "operator not found"),
77 TypeError::VarNotFound => write!(f, "variable not found"),
78 }
79 }
80}
81impl ::std::error::Error for TypeError {
82 fn description(&self) -> &'static str {
83 "type error"
84 }
85}
86
87#[derive(Debug, Clone)]
88pub enum SampleError {
90 TypeError(TypeError),
91 TRSError(TRSError),
92 SizeExceeded(usize, usize),
93 OptionsExhausted,
94 Subterm,
95}
96impl From<TypeError> for SampleError {
97 fn from(e: TypeError) -> SampleError {
98 SampleError::TypeError(e)
99 }
100}
101impl From<TRSError> for SampleError {
102 fn from(e: TRSError) -> SampleError {
103 SampleError::TRSError(e)
104 }
105}
106impl From<polytype::UnificationError> for SampleError {
107 fn from(e: polytype::UnificationError) -> SampleError {
108 SampleError::TypeError(TypeError::Unification(e))
109 }
110}
111impl fmt::Display for SampleError {
112 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113 match *self {
114 SampleError::TypeError(ref e) => write!(f, "type error: {}", e),
115 SampleError::TRSError(ref e) => write!(f, "TRS error: {}", e),
116 SampleError::SizeExceeded(size, max_size) => {
117 write!(f, "size {} exceeded maximum of {}", size, max_size)
118 }
119 SampleError::OptionsExhausted => write!(f, "failed to sample (options exhausted)"),
120 SampleError::Subterm => write!(f, "cannot sample subterm"),
121 }
122 }
123}
124impl ::std::error::Error for SampleError {
125 fn description(&self) -> &'static str {
126 "sample error"
127 }
128}
129
130#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
132pub struct ModelParams {
133 pub p_partial: f64,
136 pub p_observe: f64,
139 pub max_steps: usize,
141 pub max_size: Option<usize>,
144}
145impl Default for ModelParams {
146 fn default() -> ModelParams {
147 ModelParams {
148 p_partial: 0.0,
149 p_observe: 0.0,
150 max_steps: 50,
151 max_size: Some(500),
152 }
153 }
154}
155
156pub fn task_by_rewrite<'a, O: Sync + 'a>(
169 data: &'a [Rule],
170 params: ModelParams,
171 lex: &Lexicon,
172 observation: O,
173) -> Result<impl Task<O, Representation = Lexicon, Expression = TRS> + 'a, TypeError> {
174 let mut ctx = lex.0.read().expect("poisoned lexicon").ctx.clone();
175 let tp = lex.infer_rules(data, &mut ctx)?;
176 Ok(TrsTask {
177 data,
178 params,
179 tp,
180 observation,
181 })
182}
183
184struct TrsTask<'a, O> {
185 data: &'a [Rule],
186 params: ModelParams,
187 tp: TypeScheme,
188 observation: O,
189}
190impl<'a, O: Sync> Task<O> for TrsTask<'a, O> {
191 type Representation = Lexicon;
192 type Expression = TRS;
193
194 fn oracle(&self, _: &Lexicon, h: &TRS) -> f64 {
195 -h.posterior(self.data, self.params)
196 }
197 fn tp(&self) -> &TypeScheme {
198 &self.tp
199 }
200 fn observation(&self) -> &O {
201 &self.observation
202 }
203}