1use crate::ast::{AstNode, BinaryOp, CmpOp, NodeInner, UnaryOp};
3use crate::{Path, TyValue};
4
/// The operator or node kind that a [`Predicate`] can require of an AST node.
///
/// `Unary` and `Binary` carry the exact operator the node must have; the other
/// variants only check which kind of node it is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PredicateOp {
    /// A unary node whose operator equals the wrapped [`UnaryOp`].
    Unary(UnaryOp),
    /// A binary node whose operator equals the wrapped [`BinaryOp`].
    Binary(BinaryOp),
    /// Any piecewise node.
    Piecewise,
    /// Any constant node.
    Const,
    /// Any variable node.
    Var,
}
14
15impl PredicateOp {
16 pub fn matches<N: AstNode>(&self, n: &N) -> bool {
18 match (self, n.as_inner()) {
19 (Self::Unary(po), NodeInner::Unary(n)) => po.eq(&n.op),
20 (Self::Binary(po), NodeInner::Binary(n)) => po.eq(&n.op),
21 (Self::Piecewise, NodeInner::Piecewise(_)) => true,
22 (Self::Const, NodeInner::Const(_)) => true,
23 (Self::Var, NodeInner::Var(_)) => true,
24 _ => false,
25 }
26 }
27}
28
29impl TryFrom<&str> for PredicateOp {
30 type Error = ();
31
32 fn try_from(s: &str) -> Result<Self, Self::Error> {
33 match s {
34 "const" => Ok(Self::Const),
35 "var" => Ok(Self::Var),
36 "piecewise" => Ok(Self::Piecewise),
37
38 "neg" => Ok(Self::Unary(UnaryOp::Negate)),
39 "abs" => Ok(Self::Unary(UnaryOp::Abs)),
40
41 "pow" => Ok(Self::Binary(BinaryOp::Pow)),
42 "root" => Ok(Self::Binary(BinaryOp::Root)),
43 "pm" | "±" => Ok(Self::Binary(BinaryOp::PlusOrMinus)),
44 "-" => Ok(Self::Binary(BinaryOp::Sub)),
45 "+" => Ok(Self::Binary(BinaryOp::Add)),
46 "/" => Ok(Self::Binary(BinaryOp::Div)),
47 "*" => Ok(Self::Binary(BinaryOp::Mul)),
48 "min" => Ok(Self::Binary(BinaryOp::Min)),
49 "max" => Ok(Self::Binary(BinaryOp::Max)),
50 "==" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::Equals))),
51 "<" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::LessThan(false)))),
52 "<=" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::LessThan(true)))),
53 ">" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::GreaterThan(false)))),
54 ">=" => Ok(Self::Binary(BinaryOp::Cmp(CmpOp::GreaterThan(true)))),
55 _ => Err(()),
56 }
57 }
58}
59
/// A structural pattern over an AST node, checked by [`Predicate::matches`].
///
/// Every field is an optional constraint. `None` or an empty `Vec` imposes
/// nothing, so `Predicate::default()` matches any node. A node matches only if
/// it satisfies every constraint that is set.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Predicate {
    /// If set, the node must match this operator/kind.
    pub op: Option<PredicateOp>,
    /// If set, the node must *not* match this operator/kind.
    pub not_op: Option<PredicateOp>,

    /// If set, the node must be a constant whose value equals this one.
    pub const_value: Option<TyValue>,

    /// Pairs of paths, resolved from the node, that must both lead to sub-nodes
    /// comparing equal. A path that does not resolve fails the match.
    pub equivalent: Vec<(Path, Path)>,

    /// If set, must equal the node's arity. Arity is 0 for constants and
    /// variables, 1 for unary nodes, 2 for binary nodes, and twice the branch
    /// count for piecewise nodes.
    pub arity: Option<usize>,

    /// Positional predicates for the node's children. Entry `i` constrains
    /// child `i` and is matched recursively; `None` leaves that child
    /// unconstrained.
    pub children: Vec<Option<Self>>,
}
85
86impl Predicate {
87 pub fn op(op: PredicateOp) -> Self {
89 Self {
90 op: Some(op),
91 ..Default::default()
92 }
93 }
94
95 pub fn children(children: Vec<Option<Self>>) -> Self {
100 Self {
101 children,
102 ..Default::default()
103 }
104 }
105
106 pub fn matches<N: AstNode>(&self, n: &N) -> bool {
108 if !self.op.map(|po| po.matches(n)).unwrap_or(true) {
109 return false;
110 }
111 if self.not_op.map(|po| po.matches(n)).unwrap_or(false) {
112 return false;
113 }
114 match (self.const_value.as_ref(), n.as_inner()) {
115 (None, _) => {}
116 (Some(v), NodeInner::Const(c)) => {
117 if c.value() != v {
118 return false;
119 }
120 }
121 (Some(_), _) => {
122 return false;
123 }
124 }
125 if let Some(arity) = self.arity {
126 if !match (arity, n.as_inner()) {
127 (0, NodeInner::Const(_) | NodeInner::Var(_)) => true,
128 (_, NodeInner::Const(_) | NodeInner::Var(_)) => false,
129 (2, NodeInner::Binary(_)) => true,
130 (_, NodeInner::Binary(_)) => false,
131 (1, NodeInner::Unary(_)) => true,
132 (_, NodeInner::Unary(_)) => false,
133 (a, NodeInner::Piecewise(p)) => a == 2 * p.iter_branches().count(),
134 } {
135 return false;
136 }
137 }
138
139 for (l, r) in self.equivalent.iter() {
140 let (l, r) = (n.get(l.iter()), n.get(r.iter()));
141 if let (Some(l), Some(r)) = (l, r) {
142 if l != r {
143 return false;
144 }
145 } else {
146 return false;
147 }
148 }
149
150 if self.children.len() > 0 {
151 if let NodeInner::Piecewise(_) = n.as_inner() {
154 let all_meets = self.children.iter().enumerate().all(|(i, pc)| {
155 if let Some(pc) = pc {
156 if let Some(c) = n.get(Path::with_next(i).iter()) {
157 pc.matches(c)
158 } else {
159 false
160 }
161 } else {
162 true
163 }
164 });
165 if !all_meets {
166 return false;
167 }
168 } else {
169 if !self.children.iter().zip(n.iter_children()).all(|(pc, c)| {
170 if let Some(pc) = pc {
171 pc.matches(c)
172 } else {
173 true
174 }
175 }) {
176 return false;
177 }
178 }
179 }
180
181 true
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Node;

    /// Parses `src` into a [`Node`], panicking on malformed test input.
    fn node(src: &str) -> Node {
        Node::try_from(src).expect("test expression should parse")
    }

    /// Shorthand for a child slot constrained only by `op`.
    fn child(op: PredicateOp) -> Option<Predicate> {
        Some(Predicate::op(op))
    }

    #[test]
    fn predicate_op_matches() {
        let add = PredicateOp::Binary(BinaryOp::Add);
        assert!(add.matches(&node("3 + 5")));
        assert!(!add.matches(&node("3 - 5")));
        assert!(!add.matches(&node("-5")));

        assert!(PredicateOp::Unary(UnaryOp::Negate).matches(&node("-5")));
        assert!(!PredicateOp::Const.matches(&node("-5")));
        assert!(PredicateOp::Const.matches(&node("5")));
        assert!(PredicateOp::Piecewise.matches(&node("{otherwise 5}")));
    }

    #[test]
    fn predicate_op_parse() {
        assert_eq!(PredicateOp::try_from("const"), Ok(PredicateOp::Const));
        assert_eq!(
            PredicateOp::try_from("neg"),
            Ok(PredicateOp::Unary(UnaryOp::Negate))
        );
        assert_eq!(
            PredicateOp::try_from("pm"),
            Ok(PredicateOp::Binary(BinaryOp::PlusOrMinus))
        );
        assert_eq!(PredicateOp::try_from("±"), PredicateOp::try_from("pm"));
        assert_eq!(
            PredicateOp::try_from("<="),
            Ok(PredicateOp::Binary(BinaryOp::Cmp(CmpOp::LessThan(true))))
        );
        assert_eq!(PredicateOp::try_from("nope"), Err(()));
    }

    #[test]
    fn default_matches_anything() {
        for src in ["x", "5", "-5", "3 + 5", "{otherwise 5}"] {
            assert!(
                Predicate::default().matches(&node(src)),
                "default predicate should match {:?}",
                src
            );
        }
    }

    #[test]
    fn arity_matches() {
        let arity = |a| Predicate {
            arity: Some(a),
            ..Default::default()
        };
        assert!(arity(0).matches(&node("x")));
        assert!(arity(0).matches(&node("5")));
        assert!(!arity(0).matches(&node("3 + 5")));
        assert!(arity(1).matches(&node("-5")));
        assert!(!arity(1).matches(&node("3 + 5")));
        assert!(arity(2).matches(&node("3 + 5")));
    }

    #[test]
    fn equivalent_matches() {
        let first_equals_second = Predicate {
            equivalent: vec![(vec![0].into(), vec![1].into())],
            ..Default::default()
        };
        assert!(first_equals_second.matches(&node("x + x")));
        assert!(!first_equals_second.matches(&node("x + 2x")));

        let nested = Predicate {
            equivalent: vec![(vec![0].into(), vec![1, 0].into())],
            ..Default::default()
        };
        assert!(nested.matches(&node("a * (a + 1)")));
    }

    #[test]
    fn children_matches() {
        let add = Some(PredicateOp::Binary(BinaryOp::Add));
        let negate = PredicateOp::Unary(UnaryOp::Negate);

        // An empty or all-`None` child list places no constraint.
        assert!(Predicate::children(vec![]).matches(&node("3 + 5")));
        assert!(Predicate::children(vec![None, None]).matches(&node("3 + 5")));

        assert!(Predicate {
            op: add,
            children: vec![child(PredicateOp::Const)],
            ..Default::default()
        }
        .matches(&node("3 + 5")));

        assert!(!Predicate {
            op: add,
            children: vec![child(PredicateOp::Const), child(PredicateOp::Const)],
            ..Default::default()
        }
        .matches(&node("5 + 3 * 4")));
        assert!(Predicate {
            op: add,
            children: vec![
                child(PredicateOp::Const),
                child(PredicateOp::Binary(BinaryOp::Mul)),
            ],
            ..Default::default()
        }
        .matches(&node("5 + 3 * 4")));
        assert!(Predicate::children(vec![child(PredicateOp::Const), None])
            .matches(&node("5 + 3 * 5")));

        assert!(!Predicate::children(vec![child(negate)]).matches(&node("3 + 5")));
        assert!(Predicate::children(vec![child(negate)]).matches(&node("-3 + 5")));

        // Piecewise children are addressed by position as well.
        assert!(Predicate::children(vec![child(negate)]).matches(&node("{otherwise -2}")));
        assert!(Predicate::children(vec![
            child(PredicateOp::Const),
            child(PredicateOp::Binary(BinaryOp::Cmp(CmpOp::LessThan(false)))),
            child(negate),
        ])
        .matches(&node("{1 if x < 0; otherwise -2}")));

        // Child predicates nest.
        assert!(Predicate {
            op: add,
            children: vec![
                Some(Predicate::children(vec![child(PredicateOp::Const)])),
                Some(Predicate::children(vec![
                    child(PredicateOp::Const),
                    child(PredicateOp::Const),
                ])),
            ],
            ..Default::default()
        }
        .matches(&node("-4 + 2 * 3")));
        assert!(!Predicate::children(vec![
            Some(Predicate::children(vec![child(PredicateOp::Const)])),
            Some(Predicate::children(vec![child(negate)])),
        ])
        .matches(&node("-4 + 2 * 3")));
    }

    #[test]
    fn not_op() {
        let not = |op| Predicate {
            not_op: Some(op),
            ..Default::default()
        };
        assert!(not(PredicateOp::Binary(BinaryOp::Mul)).matches(&node("3 + 5")));
        assert!(!not(PredicateOp::Var).matches(&node("x")));
    }

    #[test]
    fn const_value() {
        let value = |v: TyValue| Predicate {
            const_value: Some(v),
            ..Default::default()
        };
        assert!(!value(TyValue::Bool(true)).matches(&node("3")));
        assert!(!value(TyValue::from(3.5)).matches(&node("3.5 + 2")));
        assert!(value(TyValue::from(3)).matches(&node("3")));
        assert!(!value(TyValue::from(4)).matches(&node("3")));
    }
}