// minicas_core/rules/mod.rs

//! Defines the syntax and processing for the rules engine used by the `minicas_crs` crate.
2
3use crate::ast::{AstNode, Node, NodeInner};
4use crate::pred::Predicate;
5use serde::{Deserialize, Deserializer};
6
7mod pred_spec;
8pub use pred_spec::PredSpec;
9
10/// Describes a rule test, which consists of the expected AST
11/// before and after applying a rule.
12#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
13#[serde(deny_unknown_fields)]
14pub struct TestSpec {
15    pub from: String,
16    pub to: String,
17}
18
19impl TryFrom<TestSpec> for (Node, Node) {
20    type Error = String;
21
22    fn try_from(s: TestSpec) -> Result<Self, Self::Error> {
23        Ok((s.from.as_str().try_into()?, s.to.as_str().try_into()?))
24    }
25}
26
27fn deserialize_ast_str<'de, D>(deserializer: D) -> Result<Node, D::Error>
28where
29    D: Deserializer<'de>,
30{
31    let buf = String::deserialize(deserializer)?;
32
33    Node::try_from(buf.as_str()).map_err(serde::de::Error::custom)
34}
35
/// Describes the action to perform when a rule matches.
///
/// The enum is `untagged`: serde tries the variants in declaration order and
/// keeps the first whose shape fits the input, so the variant order below is
/// significant and should not be changed casually.
///
/// NOTE(review): serde-derived structs also accept a sequence of their fields,
/// so a bare two-element list such as `[[0], [1]]` looks like it would be read
/// as `Replace { from: [0], to: [1] }` before `Swap` is ever tried — confirm
/// before relying on a `swap = [[a], [b]]` shorthand.
#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(deny_unknown_fields, untagged)]
pub enum RuleActionSpec {
    /// Overwrite a part of the AST with a node elsewhere in the AST.
    ///
    /// Written as a table with `from` and `to` keys, e.g.
    /// `replace.from = [0]` / `replace.to = []`.
    Replace {
        /// The indexing sequence to fetch the node to be written.
        ///
        /// See [AstNode::get] for more details on indexing.
        from: Vec<usize>,
        /// The indexing sequence the node should be written to.
        ///
        /// See [AstNode::get] for more details on indexing.
        to: Vec<usize>,
    },
    /// Swaps two nodes in the AST.
    ///
    /// Written as a table with a single `swap` key holding the two index
    /// paths, e.g. `action.swap = [[0], [1]]`.
    Swap { swap: (Vec<usize>, Vec<usize>) },
    /// Constructs a replacement node by starting from some base
    /// and performing actions against it from the original node.
    Rewrite {
        /// Expression text parsed into the starting node.
        #[serde(deserialize_with = "deserialize_ast_str")]
        base: Node,
        /// `(from, to)` path pairs: the node at `from` in the original AST is
        /// copied to `to` in the node built from `base`.
        replace: Vec<(Vec<usize>, Vec<usize>)>,
        /// Paths in the rebuilt node passed to `crate::ast::fold` after all
        /// replacements; empty when omitted.
        #[serde(default)]
        fold: Vec<Vec<usize>>,
    },
    /// Perform multiple actions in order.
    ///
    /// Written as an array of actions, e.g. `actions = [{from = [0], to = []}]`.
    Multi(Vec<Self>),
}
65
66impl RuleActionSpec {
67    /// Applies the action to the given node.
68    // TODO: Real error type.
69    pub fn apply<N: AstNode + From<Node>>(&self, n: &mut N) -> Result<(), ()> {
70        match self {
71            Self::Replace { from, to } => {
72                let from = n.get(from.iter().map(|i| *i)).ok_or(())?.clone();
73                let to = n.get_mut(to.iter().map(|i| *i)).ok_or(())?;
74                *to = from;
75            }
76            Self::Swap { swap: (a, b) } => {
77                let mut tmp = NodeInner::new_const(false);
78                let a_mut = n.get_mut(a.iter().map(|i| *i)).ok_or(())?;
79                std::mem::swap(&mut tmp, a_mut);
80
81                let b_mut = n.get_mut(b.iter().map(|i| *i)).ok_or(())?;
82                std::mem::swap(&mut tmp, b_mut);
83
84                let a_mut = n.get_mut(a.iter().map(|i| *i)).ok_or(())?;
85                std::mem::swap(&mut tmp, a_mut);
86            }
87            Self::Rewrite {
88                base,
89                replace,
90                fold,
91            } => {
92                let mut out: N = N::from(base.clone());
93
94                for (from, to) in replace {
95                    let from = n.get(from.iter().map(|i| *i)).ok_or(())?.clone();
96                    let to = out.get_mut(to.iter().map(|i| *i)).ok_or(())?;
97                    *to = from;
98                }
99                for path in fold {
100                    let target = out.get_mut(path.iter().map(|i| *i)).ok_or(())?;
101                    crate::ast::fold(target).map_err(|_| ())?;
102                }
103
104                *n = out;
105            }
106            Self::Multi(v) => {
107                for r in v.iter() {
108                    r.apply(n)?;
109                }
110            }
111        }
112        Ok(())
113    }
114}
115
116/// Describes metadata about a rule.
117#[derive(Deserialize, Debug, Default, Clone, PartialEq, Eq)]
118#[serde(deny_unknown_fields)]
119pub struct MetaSpec {
120    #[serde(default)]
121    pub is_simplify: bool,
122    #[serde(default)]
123    /// Indicates that the rule describes an equivalence that isn't
124    /// necessarily a simplification, but might be a relevant form
125    /// for algebraic transformations.
126    pub alt_form: bool,
127}
128
129/// Describes how a rule is specified in a rule file.
130#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
131#[serde(deny_unknown_fields)]
132pub struct RuleSpec {
133    #[serde(default)]
134    pub meta: MetaSpec,
135
136    #[serde(alias = "match")]
137    pub predicate: PredSpec,
138
139    #[serde(
140        alias = "replace",
141        alias = "actions",
142        alias = "swap",
143        alias = "rewrite"
144    )]
145    pub action: RuleActionSpec,
146
147    #[serde(default)]
148    pub tests: Vec<TestSpec>,
149}
150
151/// Describes the schema for a rule file.
152#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
153#[serde(deny_unknown_fields)]
154pub struct RuleFileSpec {
155    pub rules: Vec<RuleSpec>,
156}
157
158/// An AST update rule.
159#[derive(Debug, Clone, PartialEq, Eq)]
160pub struct Rule {
161    pub meta: MetaSpec,
162    pred: Predicate,
163    action: RuleActionSpec,
164    tests: Vec<(Node, Node)>,
165}
166
167impl Rule {
168    /// Applies the rule to the specified node. True is returned if
169    /// the rule matched & an action was taken.
170    pub fn eval<N: AstNode + From<Node>>(&self, n: &mut N) -> Result<bool, ()> {
171        if self.pred.matches(n) {
172            self.action.apply(n)?;
173            Ok(true)
174        } else {
175            Ok(false)
176        }
177    }
178
179    /// Run the self tests.
180    pub fn self_test(&self) -> Result<(), (usize, String)> {
181        for (i, (from, to)) in self.tests.iter().enumerate() {
182            let mut n = from.clone();
183            match self.eval(&mut n) {
184                Ok(true) => {}
185                Ok(false) => return Err((i, "predicate did not match".to_string())),
186                Err(()) => return Err((i, "evaluation failed".to_string())),
187            }
188            if &n != to {
189                return Err((i, format!("got result {}", n)));
190            }
191        }
192        Ok(())
193    }
194}
195
196impl TryFrom<RuleSpec> for Rule {
197    type Error = String;
198
199    fn try_from(s: RuleSpec) -> Result<Self, Self::Error> {
200        Ok(Self {
201            meta: s.meta,
202            pred: s.predicate.try_into()?,
203            action: s.action,
204            tests: s
205                .tests
206                .into_iter()
207                .map(|s| <TestSpec as TryInto<(Node, Node)>>::try_into(s))
208                .collect::<Result<Vec<(Node, Node)>, String>>()?,
209        })
210    }
211}
212
// Unit tests covering rule-spec parsing, spec-to-rule conversion, each
// action kind, and the embedded self-test runner.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::BinaryOp;
    use crate::pred::PredicateOp;
    use crate::TyValue;
    use toml::de;

    // The `match` and `replace` aliases deserialize into `predicate` and a
    // `Replace` action, and `tests` entries are kept as raw expression text.
    #[test]
    fn parse_spec() {
        assert_eq!(
            de::from_str::<RuleSpec>(
                r#"
            match.op = '/'
			match.rhs = {const = '1'}

			replace.from = [0] # using the numerator (first lhs)
			replace.to = []    # overwrite self

			tests = [
			    {from = '12 / 1', to = '12'},
			]
            "#
            ),
            Ok(RuleSpec {
                meta: MetaSpec::default(),
                predicate: PredSpec {
                    op: Some("/".to_string()),
                    rhs: Some(Box::new(PredSpec {
                        const_val: Some("1".to_string()),
                        ..PredSpec::default()
                    })),
                    ..PredSpec::default()
                },
                action: RuleActionSpec::Replace {
                    from: vec![0],
                    to: vec![],
                },
                tests: vec![TestSpec {
                    from: "12 / 1".to_string(),
                    to: "12".to_string(),
                }],
            })
        );
    }

    // Converting the same spec into a `Rule` compiles the predicate and
    // parses each test's expression text into nodes.
    #[test]
    fn convert_spec() {
        assert_eq!(
            de::from_str::<RuleSpec>(
                r#"
            match.op = '/'
			match.rhs = {const = '1'}

			replace.from = [0] # using the numerator (first lhs)
			replace.to = []    # overwrite self

			tests = [
			    {from = '12 / 1', to = '12'},
			]
            "#
            )
            .unwrap()
            .try_into(),
            Ok(Rule {
                meta: MetaSpec::default(),
                pred: Predicate {
                    op: Some(PredicateOp::Binary(BinaryOp::Div)),
                    children: vec![
                        None,
                        Some(Predicate {
                            const_value: Some(TyValue::from(1)),
                            ..Predicate::op(PredicateOp::Const)
                        })
                    ],
                    ..Predicate::default()
                },
                action: RuleActionSpec::Replace {
                    from: vec![0],
                    to: vec![],
                },
                tests: vec![(
                    Node::try_from("12 / 1").unwrap(),
                    Node::try_from("12").unwrap()
                ),],
            })
        );
    }

    // A `Replace` with an empty `to` path overwrites the matched root node.
    #[test]
    fn apply_replace() {
        let rule: Rule = de::from_str::<RuleSpec>(
            r#"
            match.op = '/'
			match.rhs = {const = '1'}

			replace.from = [0] # using the numerator (first lhs)
			replace.to = []    # overwrite self
            "#,
        )
        .unwrap()
        .try_into()
        .unwrap();

        let mut ast = Node::try_from("12 / 1").unwrap();
        assert_eq!(rule.eval(&mut ast), Ok(true));
        assert_eq!(ast, Node::try_from("12").unwrap());
    }

    // A `Swap` written as `action.swap` exchanges the two operands.
    #[test]
    fn apply_swap() {
        let rule: Rule = de::from_str::<RuleSpec>(
            r#"
            match.op = '*'
            match.lhs = {not_op = 'const'}
            match.rhs = {op = 'const'}

            action.swap = [[0], [1]]
            "#,
        )
        .unwrap()
        .try_into()
        .unwrap();

        let mut ast = Node::try_from("x * 2").unwrap();
        assert_eq!(rule.eval(&mut ast), Ok(true));
        assert_eq!(ast, Node::try_from("2x").unwrap());
    }

    // A `Rewrite` starts from `base` and copies nodes from the matched AST
    // into the listed paths of the new node.
    #[test]
    fn apply_rewrite() {
        let rule: Rule = de::from_str::<RuleSpec>(
            r#"
            match.op = '-'
            match.lhs = {op = 'neg'}
            match.rhs = {not_op = 'neg'}

            [rewrite]
            base = "-(1 + 1)"
            replace = [
                [[0, 0], [0, 0]],
                [[1], [0, 1]],
            ]
            "#,
        )
        .unwrap()
        .try_into()
        .unwrap();

        let mut ast = Node::try_from("-12 - 2").unwrap();
        assert_eq!(rule.eval(&mut ast), Ok(true));
        assert_eq!(ast, Node::try_from("-(12 + 2)").unwrap());
    }

    // `self_test` passes for a correct expectation, then reports the failing
    // test index and the actual result once the expectation is corrupted.
    #[test]
    fn selftest() {
        let mut rule: Rule = de::from_str::<RuleSpec>(
            r#"
            match.op = '/'
			match.rhs = {const = '1'}

			actions = [
                {from = [0], to = []},
            ]

			tests = [
			    {from = '12 / 1', to = '12'},
			]
            "#,
        )
        .unwrap()
        .try_into()
        .unwrap();

        assert_eq!(rule.self_test(), Ok(()));

        // mess up the test result
        rule.tests[0].1 = Node::try_from("4x").unwrap();
        assert_eq!(rule.self_test(), Err((0, "got result 12".to_string())));
    }
}