programinduction/trs/
mod.rs

1//! (representation) Polymorphically-typed term rewriting system.
2//!
3//! An evaluatable first-order [Term Rewriting System][0] (TRS) with a [Hindley-Milner type
4//! system][1].
5//!
//! [0]: https://en.wikipedia.org/wiki/Rewriting#Term_rewriting_systems
//!      "Wikipedia - Term Rewriting Systems"
//! [1]: https://en.wikipedia.org/wiki/Hindley–Milner_type_system
//!      "Wikipedia - Hindley-Milner Type System"
10//!
11//! # Example
12//!
13//! ```
14//! use polytype::{ptp, tp, Context as TypeContext};
15//! use programinduction::trs::{TRS, Lexicon};
16//! use term_rewriting::{Signature, parse_rule};
17//!
18//! let mut sig = Signature::default();
19//!
20//! let mut ops = vec![];
21//! sig.new_op(2, Some("PLUS".to_string()));
22//! ops.push(ptp![@arrow[tp!(int), tp!(int), tp!(int)]]);
23//! sig.new_op(1, Some("SUCC".to_string()));
24//! ops.push(ptp![@arrow[tp!(int), tp!(int)]]);
25//! sig.new_op(0, Some("ZERO".to_string()));
26//! ops.push(ptp![int]);
27//!
28//! let rules = vec![
29//!     parse_rule(&mut sig, "PLUS(x_ ZERO) = x_").expect("parsed rule"),
30//!     parse_rule(&mut sig, "PLUS(x_ SUCC(y_)) = SUCC(PLUS(x_ y_))").expect("parsed rule"),
31//! ];
32//!
33//! let vars = vec![
34//!     ptp![int],
35//!     ptp![int],
36//!     ptp![int],
37//! ];
38//!
39//! let lexicon = Lexicon::from_signature(sig, ops, vars, vec![], vec![], false, TypeContext::default());
40//!
41//! let trs = TRS::new(&lexicon, rules, &lexicon.context());
42//! ```
43
44mod lexicon;
45pub mod parser;
46mod rewrite;
47pub use self::lexicon::{GeneticParams, Lexicon};
48pub use self::parser::{
49    parse_context, parse_lexicon, parse_rule, parse_rulecontext, parse_templates, parse_trs,
50};
51pub use self::rewrite::TRS;
52use crate::Task;
53
54use polytype;
55use polytype::TypeScheme;
56use serde::{Deserialize, Serialize};
57use std::fmt;
58use term_rewriting::{Rule, TRSError};
59
60#[derive(Debug, Clone)]
61/// The error type for type inference.
62pub enum TypeError {
63    Unification(polytype::UnificationError),
64    OpNotFound,
65    VarNotFound,
66}
67impl From<polytype::UnificationError> for TypeError {
68    fn from(e: polytype::UnificationError) -> TypeError {
69        TypeError::Unification(e)
70    }
71}
72impl fmt::Display for TypeError {
73    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74        match *self {
75            TypeError::Unification(ref e) => write!(f, "unification error: {}", e),
76            TypeError::OpNotFound => write!(f, "operator not found"),
77            TypeError::VarNotFound => write!(f, "variable not found"),
78        }
79    }
80}
81impl ::std::error::Error for TypeError {
82    fn description(&self) -> &'static str {
83        "type error"
84    }
85}
86
87#[derive(Debug, Clone)]
88/// The error type for sampling operations.
89pub enum SampleError {
90    TypeError(TypeError),
91    TRSError(TRSError),
92    SizeExceeded(usize, usize),
93    OptionsExhausted,
94    Subterm,
95}
96impl From<TypeError> for SampleError {
97    fn from(e: TypeError) -> SampleError {
98        SampleError::TypeError(e)
99    }
100}
101impl From<TRSError> for SampleError {
102    fn from(e: TRSError) -> SampleError {
103        SampleError::TRSError(e)
104    }
105}
106impl From<polytype::UnificationError> for SampleError {
107    fn from(e: polytype::UnificationError) -> SampleError {
108        SampleError::TypeError(TypeError::Unification(e))
109    }
110}
111impl fmt::Display for SampleError {
112    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113        match *self {
114            SampleError::TypeError(ref e) => write!(f, "type error: {}", e),
115            SampleError::TRSError(ref e) => write!(f, "TRS error: {}", e),
116            SampleError::SizeExceeded(size, max_size) => {
117                write!(f, "size {} exceeded maximum of {}", size, max_size)
118            }
119            SampleError::OptionsExhausted => write!(f, "failed to sample (options exhausted)"),
120            SampleError::Subterm => write!(f, "cannot sample subterm"),
121        }
122    }
123}
124impl ::std::error::Error for SampleError {
125    fn description(&self) -> &'static str {
126        "sample error"
127    }
128}
129
130/// Parameters for a TRS-based probabilistic model.
131#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
132pub struct ModelParams {
133    /// How much partial credit is given for incorrect answers; it should be a
134    /// probability (i.e. in [0, 1]).
135    pub p_partial: f64,
136    /// The (non-log) probability of generating observations at arbitrary
137    /// evaluation steps (i.e. not just normal forms). Typically 0.0.
138    pub p_observe: f64,
139    /// The number of evaluation steps you would like to explore in the trace.
140    pub max_steps: usize,
141    /// The largest term that will be considered for evaluation. `None` will
142    /// evaluate all terms.
143    pub max_size: Option<usize>,
144}
145impl Default for ModelParams {
146    fn default() -> ModelParams {
147        ModelParams {
148            p_partial: 0.0,
149            p_observe: 0.0,
150            max_steps: 50,
151            max_size: Some(500),
152        }
153    }
154}
155
156/// Construct a [`Task`] evaluating [`TRS`]s (constructed from a [`Lexicon`])
157/// using rewriting of inputs to outputs.
158///
159/// Each [`term_rewriting::Rule`] in `data` must have a single RHS term. The
160/// resulting [`Task`] checks whether each datum's LHS gets rewritten to its RHS
161/// under a [`TRS`] within the constraints specified by the [`ModelParams`].
162///
163/// [`Lexicon`]: struct.Lexicon.html
164/// [`ModelParams`]: struct.ModelParams.html
165/// [`term_rewriting::Rule`]: https://docs.rs/term_rewriting/~0.3/term_rewriting/struct.Rule.html
166/// [`Task`]: ../struct.Task.html
167/// [`TRS`]: struct.TRS.html
168pub fn task_by_rewrite<'a, O: Sync + 'a>(
169    data: &'a [Rule],
170    params: ModelParams,
171    lex: &Lexicon,
172    observation: O,
173) -> Result<impl Task<O, Representation = Lexicon, Expression = TRS> + 'a, TypeError> {
174    let mut ctx = lex.0.read().expect("poisoned lexicon").ctx.clone();
175    let tp = lex.infer_rules(data, &mut ctx)?;
176    Ok(TrsTask {
177        data,
178        params,
179        tp,
180        observation,
181    })
182}
183
184struct TrsTask<'a, O> {
185    data: &'a [Rule],
186    params: ModelParams,
187    tp: TypeScheme,
188    observation: O,
189}
190impl<'a, O: Sync> Task<O> for TrsTask<'a, O> {
191    type Representation = Lexicon;
192    type Expression = TRS;
193
194    fn oracle(&self, _: &Lexicon, h: &TRS) -> f64 {
195        -h.posterior(self.data, self.params)
196    }
197    fn tp(&self) -> &TypeScheme {
198        &self.tp
199    }
200    fn observation(&self) -> &O {
201        &self.observation
202    }
203}