Skip to main content

minicas_core/pred/
mod.rs

1//! Predicate rules/matching for AST nodes.
2use crate::ast::{AstNode, BinaryOp, CmpOp, NodeInner, UnaryOp};
3use crate::{Path, TyValue};
4
5/// Describes a predicate on the operation.
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
7pub enum PredicateOp {
8    Unary(UnaryOp),
9    Binary(BinaryOp),
10    Piecewise,
11    Const,
12    Var,
13}
14
15impl PredicateOp {
16    /// Indicates if the operation matches that of the given [AstNode].
17    pub fn matches<N: AstNode>(&self, n: &N) -> bool {
18        match (self, n.as_inner()) {
19            (Self::Unary(po), NodeInner::Unary(n)) => po.eq(&n.op),
20            (Self::Binary(po), NodeInner::Binary(n)) => po.eq(&n.op),
21            (Self::Piecewise, NodeInner::Piecewise(_)) => true,
22            (Self::Const, NodeInner::Const(_)) => true,
23            (Self::Var, NodeInner::Var(_)) => true,
24            _ => false,
25        }
26    }
27}
28
29impl TryFrom<&str> for PredicateOp {
30    type Error = ();
31
32    fn try_from(s: &str) -> Result<Self, Self::Error> {
33        match s {
34            "const" => Ok(Self::Const),
35            "var" => Ok(Self::Var),
36            "piecewise" => Ok(Self::Piecewise),
37
38            "neg" => Ok(Self::Unary(UnaryOp::Negate)),
39            "abs" => Ok(Self::Unary(UnaryOp::Abs)),
40
41            "pow" => Ok(Self::Binary(BinaryOp::Pow)),
42            "root" => Ok(Self::Binary(BinaryOp::Root)),
43            "pm" | "±" => Ok(Self::Binary(BinaryOp::PlusOrMinus)),
44            "-" => Ok(Self::Binary(BinaryOp::Sub)),
45            "+" => Ok(Self::Binary(BinaryOp::Add)),
46            "/" => Ok(Self::Binary(BinaryOp::Div)),
47            "*" => Ok(Self::Binary(BinaryOp::Mul)),
48            "min" => Ok(Self::Binary(BinaryOp::Min)),
49            "max" => Ok(Self::Binary(BinaryOp::Max)),
50            "==" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::Equals))),
51            "<" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::LessThan(false)))),
52            "<=" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::LessThan(true)))),
53            ">" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::GreaterThan(false)))),
54            ">=" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::GreaterThan(true)))),
55            _ => Err(()),
56        }
57    }
58}
59
60/// Describes a predicate on an AST node.
61#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
62pub struct Predicate {
63    /// Match only on nodes performing the given operation.
64    pub op: Option<PredicateOp>,
65    /// Match only on nodes which are NOT the given operation.
66    pub not_op: Option<PredicateOp>,
67
68    /// Match if the node is a constant with some value.
69    pub const_value: Option<TyValue>,
70
71    /// Match if the descendant nodes specified by the two paths are equal.
72    pub equivalent: Vec<(Path, Path)>,
73
74    /// Match if the descendant nodes have a certain number of children.
75    ///
76    /// For piecewise functions, this describes the number of conditional arms.
77    pub arity: Option<usize>,
78
79    /// Match only on nodes whos children match the given predicates respectively.
80    ///
81    /// A `None` value in some position means to skip considering the child in that
82    /// place, for instance a `[None, Some(...)]` skips evaluation of a first operand.
83    pub children: Vec<Option<Self>>,
84}
85
86impl Predicate {
87    /// Creates a predicate matching only on the specified op.
88    pub fn op(op: PredicateOp) -> Self {
89        Self {
90            op: Some(op),
91            ..Default::default()
92        }
93    }
94
95    /// Creates a predicate matching only the specified children.
96    ///
97    /// A `None` value in some position means to skip considering the child in that
98    /// place, for instance a `[None, Some(...)]` skips evaluation of a first operand.
99    pub fn children(children: Vec<Option<Self>>) -> Self {
100        Self {
101            children,
102            ..Default::default()
103        }
104    }
105
106    /// Indicates if the predicate matches a given [AstNode].
107    pub fn matches<N: AstNode>(&self, n: &N) -> bool {
108        if !self.op.map(|po| po.matches(n)).unwrap_or(true) {
109            return false;
110        }
111        if self.not_op.map(|po| po.matches(n)).unwrap_or(false) {
112            return false;
113        }
114        match (self.const_value.as_ref(), n.as_inner()) {
115            (None, _) => {}
116            (Some(v), NodeInner::Const(c)) => {
117                if c.value() != v {
118                    return false;
119                }
120            }
121            (Some(_), _) => {
122                return false;
123            }
124        }
125        if let Some(arity) = self.arity {
126            if !match (arity, n.as_inner()) {
127                (0, NodeInner::Const(_) | NodeInner::Var(_)) => true,
128                (_, NodeInner::Const(_) | NodeInner::Var(_)) => false,
129                (2, NodeInner::Binary(_)) => true,
130                (_, NodeInner::Binary(_)) => false,
131                (1, NodeInner::Unary(_)) => true,
132                (_, NodeInner::Unary(_)) => false,
133                (a, NodeInner::Piecewise(p)) => a == 2 * p.iter_branches().count(),
134            } {
135                return false;
136            }
137        }
138
139        for (l, r) in self.equivalent.iter() {
140            let (l, r) = (n.get(l.iter()), n.get(r.iter()));
141            if let (Some(l), Some(r)) = (l, r) {
142                if l != r {
143                    return false;
144                }
145            } else {
146                return false;
147            }
148        }
149
150        if self.children.len() > 0 {
151            // TODO: piecewise shouldnt be special-cased, but supported via
152            // iter_children() path.
153            if let NodeInner::Piecewise(_) = n.as_inner() {
154                let all_meets = self.children.iter().enumerate().all(|(i, pc)| {
155                    if let Some(pc) = pc {
156                        if let Some(c) = n.get(Path::with_next(i).iter()) {
157                            pc.matches(c)
158                        } else {
159                            false
160                        }
161                    } else {
162                        true
163                    }
164                });
165                if !all_meets {
166                    return false;
167                }
168            } else {
169                if !self.children.iter().zip(n.iter_children()).all(|(pc, c)| {
170                    if let Some(pc) = pc {
171                        pc.matches(c)
172                    } else {
173                        true
174                    }
175                }) {
176                    return false;
177                }
178            }
179        }
180
181        true
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Node;

    /// Parses `src` into a [`Node`], panicking if the expression is invalid.
    fn parse(src: &str) -> Node {
        Node::try_from(src).unwrap()
    }

    /// A child slot constrained only by the operation it performs.
    fn slot(op: PredicateOp) -> Option<Predicate> {
        Some(Predicate::op(op))
    }

    /// A child slot constrained only by its own children.
    fn nested(children: Vec<Option<Predicate>>) -> Option<Predicate> {
        Some(Predicate::children(children))
    }

    /// A predicate requiring an addition with the given child constraints.
    fn add_with(children: Vec<Option<Predicate>>) -> Predicate {
        Predicate {
            op: Some(PredicateOp::Binary(BinaryOp::Add)),
            children,
            ..Default::default()
        }
    }

    #[test]
    fn predicate_op_matches() {
        let add = PredicateOp::Binary(BinaryOp::Add);
        assert!(add.matches(&parse("3 + 5")));
        assert!(!add.matches(&parse("3 - 5")));
        assert!(!add.matches(&parse("-5")));

        assert!(PredicateOp::Unary(UnaryOp::Negate).matches(&parse("-5")));
        assert!(!PredicateOp::Const.matches(&parse("-5")));
        assert!(PredicateOp::Const.matches(&parse("5")));
        assert!(PredicateOp::Piecewise.matches(&parse("{otherwise 5}")));
    }

    #[test]
    fn equivalent_matches() {
        let same_operands = Predicate {
            equivalent: vec![(vec![0].into(), vec![1].into())],
            ..Default::default()
        };
        assert!(same_operands.matches(&parse("x + x")));
        assert!(!same_operands.matches(&parse("x + 2x")));

        let operand_reused_deeper = Predicate {
            equivalent: vec![(vec![0].into(), vec![1, 0].into())],
            ..Default::default()
        };
        assert!(operand_reused_deeper.matches(&parse("a * (a + 1)")));
    }

    #[test]
    fn children_matches() {
        let constant = PredicateOp::Const;
        let negate = PredicateOp::Unary(UnaryOp::Negate);
        let mul = PredicateOp::Binary(BinaryOp::Mul);
        let less_than = PredicateOp::Binary(BinaryOp::Cmp(CmpOp::LessThan(false)));

        // No constraints, or only skipped positions, match anything.
        assert!(Predicate::children(vec![]).matches(&parse("3 + 5")));
        assert!(Predicate::children(vec![None, None]).matches(&parse("3 + 5")));

        assert!(add_with(vec![slot(constant)]).matches(&parse("3 + 5")));

        assert!(!add_with(vec![slot(constant), slot(constant)]).matches(&parse("5 + 3 * 4")));
        assert!(add_with(vec![slot(constant), slot(mul)]).matches(&parse("5 + 3 * 4")));
        assert!(Predicate::children(vec![slot(constant), None]).matches(&parse("5 + 3 * 5")));

        assert!(!Predicate::children(vec![slot(negate)]).matches(&parse("3 + 5")));
        assert!(Predicate::children(vec![slot(negate)]).matches(&parse("-3 + 5")));

        // Test matching piecewise function arms / else case
        assert!(Predicate::children(vec![slot(negate)]).matches(&parse("{otherwise -2}")));
        assert!(
            Predicate::children(vec![slot(constant), slot(less_than), slot(negate)])
                .matches(&parse("{1 if x < 0; otherwise -2}"))
        );

        // Test some deeper nesting
        assert!(add_with(vec![
            nested(vec![slot(constant)]),
            nested(vec![slot(constant), slot(constant)]),
        ])
        .matches(&parse("-4 + 2 * 3")));
        assert!(!Predicate::children(vec![
            nested(vec![slot(constant)]),
            nested(vec![slot(negate)]),
        ])
        .matches(&parse("-4 + 2 * 3")));
    }

    #[test]
    fn not_op() {
        let not_mul = Predicate {
            not_op: Some(PredicateOp::Binary(BinaryOp::Mul)),
            ..Default::default()
        };
        assert!(not_mul.matches(&parse("3 + 5")));

        let not_var = Predicate {
            not_op: Some(PredicateOp::Var),
            ..Default::default()
        };
        assert!(!not_var.matches(&parse("x")));
    }

    #[test]
    fn const_value() {
        /// A predicate requiring a constant equal to `value`.
        fn constant_eq(value: TyValue) -> Predicate {
            Predicate {
                const_value: Some(value),
                ..Default::default()
            }
        }

        // Wrong type of constant.
        assert!(!constant_eq(TyValue::Bool(true)).matches(&parse("3")));
        // The root node is an addition, not a constant.
        assert!(!constant_eq(TyValue::from(3.5)).matches(&parse("3.5 + 2")));
        assert!(constant_eq(TyValue::from(3)).matches(&parse("3")));
        assert!(!constant_eq(TyValue::from(4)).matches(&parse("3")));
    }
}