1use std::fmt;
2use std::ops::{BitAnd, BitOr, Not};
3
4#[allow(unused_imports)]
5use erg_common::log;
6use erg_common::set::Set;
7use erg_common::traits::{Immutable, LimitedDisplay, StructuralEq};
8use erg_common::{fmt_option, set, Str};
9
10use super::free::{Constraint, HasLevel};
11use super::typaram::TyParam;
12use super::value::ValueObj;
13use super::{SharedFrees, Type};
14
// Empty marker impl: `Predicate` opts into `erg_common::traits::Immutable`
// without overriding any items.
impl Immutable for Predicate {}
16
/// A logical predicate over type-level values (`ValueObj`s and `TyParam`s).
///
/// Comparisons come in two flavours: the simple ones (`Equal`,
/// `GreaterEqual`, `LessEqual`, `NotEqual`) compare a named subject variable
/// (`lhs`) with a type parameter, while the `General*` ones compare two
/// arbitrary sub-predicates.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum Predicate {
    /// A literal value; `Predicate::TRUE` / `Predicate::FALSE` are
    /// `Value(ValueObj::Bool(..))`.
    Value(ValueObj),
    /// A named constant or variable, referenced by name.
    Const(Str),
    /// A call on `receiver` (or on its method `name`, when present) with
    /// the arguments `args`.
    Call {
        receiver: TyParam,
        name: Option<Str>,
        args: Vec<TyParam>,
    },
    /// Attribute access `receiver.name`.
    Attr {
        receiver: TyParam,
        name: Str,
    },
    /// `lhs == rhs`, where `lhs` names the subject variable.
    Equal {
        lhs: Str,
        rhs: TyParam,
    },
    /// `lhs >= rhs`, where `lhs` names the subject variable.
    GreaterEqual {
        lhs: Str,
        rhs: TyParam,
    },
    /// `lhs <= rhs`, where `lhs` names the subject variable.
    LessEqual {
        lhs: Str,
        rhs: TyParam,
    },
    /// `lhs != rhs`, where `lhs` names the subject variable.
    NotEqual {
        lhs: Str,
        rhs: TyParam,
    },
    /// `lhs == rhs` between two arbitrary sub-predicates.
    GeneralEqual {
        lhs: Box<Predicate>,
        rhs: Box<Predicate>,
    },
    /// `lhs <= rhs` between two arbitrary sub-predicates.
    GeneralLessEqual {
        lhs: Box<Predicate>,
        rhs: Box<Predicate>,
    },
    /// `lhs >= rhs` between two arbitrary sub-predicates.
    GeneralGreaterEqual {
        lhs: Box<Predicate>,
        rhs: Box<Predicate>,
    },
    /// `lhs != rhs` between two arbitrary sub-predicates.
    GeneralNotEqual {
        lhs: Box<Predicate>,
        rhs: Box<Predicate>,
    },
    /// Disjunction of a set of predicates.
    Or(Set<Predicate>),
    /// Binary conjunction `lhs and rhs`.
    And(Box<Predicate>, Box<Predicate>),
    /// Negation of the inner predicate.
    Not(Box<Predicate>),
    /// Default variant, rendered as `<failure>`.
    /// NOTE(review): presumably marks a predicate that could not be built or
    /// evaluated — confirm intent.
    #[default]
    Failure,
}
72
73impl fmt::Display for Predicate {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 match self {
76 Self::Value(v) => write!(f, "{v}"),
77 Self::Const(c) => write!(f, "{c}"),
78 Self::Call {
79 receiver,
80 name,
81 args,
82 } => {
83 write!(
84 f,
85 "{receiver}{}({})",
86 fmt_option!(pre ".", name),
87 args.iter()
88 .map(|a| format!("{a}"))
89 .collect::<Vec<_>>()
90 .join(", ")
91 )
92 }
93 Self::Attr { receiver, name } => write!(f, "{receiver}.{name}"),
94 Self::Equal { lhs, rhs } => write!(f, "{lhs} == {rhs}"),
95 Self::GreaterEqual { lhs, rhs } => write!(f, "{lhs} >= {rhs}"),
96 Self::LessEqual { lhs, rhs } => write!(f, "{lhs} <= {rhs}"),
97 Self::NotEqual { lhs, rhs } => write!(f, "{lhs} != {rhs}"),
98 Self::GeneralEqual { lhs, rhs } => write!(f, "{lhs} == {rhs}"),
99 Self::GeneralLessEqual { lhs, rhs } => write!(f, "{lhs} <= {rhs}"),
100 Self::GeneralGreaterEqual { lhs, rhs } => write!(f, "{lhs} >= {rhs}"),
101 Self::GeneralNotEqual { lhs, rhs } => write!(f, "{lhs} != {rhs}"),
102 Self::Or(preds) => {
103 write!(f, "(")?;
104 for (i, pred) in preds.iter().enumerate() {
105 if i != 0 {
106 write!(f, " or ")?;
107 }
108 write!(f, "{pred}")?;
109 }
110 write!(f, ")")
111 }
112 Self::And(l, r) => write!(f, "({l}) and ({r})"),
113 Self::Not(p) => write!(f, "not ({p})"),
114 Self::Failure => write!(f, "<failure>"),
115 }
116 }
117}
118
119impl LimitedDisplay for Predicate {
120 fn limited_fmt<W: std::fmt::Write>(&self, f: &mut W, limit: isize) -> std::fmt::Result {
121 if limit == 0 {
122 return write!(f, "...");
123 }
124 match self {
125 Self::Value(v) => v.limited_fmt(f, limit),
126 Self::Const(c) => write!(f, "{c}"),
127 Self::Call {
129 receiver,
130 name,
131 args,
132 } => {
133 write!(
134 f,
135 "{receiver}{}({})",
136 fmt_option!(pre ".", name),
137 args.iter()
138 .map(|a| format!("{a}"))
139 .collect::<Vec<_>>()
140 .join(", ")
141 )
142 }
143 Self::Attr { receiver, name } => write!(f, "{receiver}.{name}"),
144 Self::Equal { lhs, rhs } => {
145 write!(f, "{lhs} == ")?;
146 rhs.limited_fmt(f, limit - 1)
147 }
148 Self::GreaterEqual { lhs, rhs } => {
149 write!(f, "{lhs} >= ")?;
150 rhs.limited_fmt(f, limit - 1)
151 }
152 Self::LessEqual { lhs, rhs } => {
153 write!(f, "{lhs} <= ")?;
154 rhs.limited_fmt(f, limit - 1)
155 }
156 Self::NotEqual { lhs, rhs } => {
157 write!(f, "{lhs} != ")?;
158 rhs.limited_fmt(f, limit - 1)
159 }
160 Self::GeneralEqual { lhs, rhs } => {
161 lhs.limited_fmt(f, limit - 1)?;
162 write!(f, " == ")?;
163 rhs.limited_fmt(f, limit - 1)
164 }
165 Self::GeneralLessEqual { lhs, rhs } => {
166 lhs.limited_fmt(f, limit - 1)?;
167 write!(f, " <= ")?;
168 rhs.limited_fmt(f, limit - 1)
169 }
170 Self::GeneralGreaterEqual { lhs, rhs } => {
171 lhs.limited_fmt(f, limit - 1)?;
172 write!(f, " >= ")?;
173 rhs.limited_fmt(f, limit - 1)
174 }
175 Self::GeneralNotEqual { lhs, rhs } => {
176 lhs.limited_fmt(f, limit - 1)?;
177 write!(f, " != ")?;
178 rhs.limited_fmt(f, limit - 1)
179 }
180 Self::Or(preds) => {
181 write!(f, "(")?;
182 for (i, pred) in preds.iter().enumerate() {
183 if i != 0 {
184 write!(f, " or ")?;
185 }
186 pred.limited_fmt(f, limit - 1)?;
187 }
188 write!(f, ")")
189 }
190 Self::And(l, r) => {
191 write!(f, "(")?;
192 l.limited_fmt(f, limit - 1)?;
193 write!(f, ") and (")?;
194 r.limited_fmt(f, limit - 1)?;
195 write!(f, ")")
196 }
197 Self::Not(p) => {
198 write!(f, "not (")?;
199 p.limited_fmt(f, limit - 1)?;
200 write!(f, ")")
201 }
202 Self::Failure => write!(f, "<failure>"),
203 }
204 }
205}
206
207impl StructuralEq for Predicate {
208 fn structural_eq(&self, other: &Self) -> bool {
209 match (self, other) {
210 (Self::Equal { rhs, .. }, Self::Equal { rhs: r2, .. })
211 | (Self::NotEqual { rhs, .. }, Self::NotEqual { rhs: r2, .. })
212 | (Self::GreaterEqual { rhs, .. }, Self::GreaterEqual { rhs: r2, .. })
213 | (Self::LessEqual { rhs, .. }, Self::LessEqual { rhs: r2, .. }) => {
214 rhs.structural_eq(r2)
215 }
216 (Self::GeneralEqual { lhs, rhs }, Self::GeneralEqual { lhs: l, rhs: r })
217 | (Self::GeneralLessEqual { lhs, rhs }, Self::GeneralLessEqual { lhs: l, rhs: r })
218 | (
219 Self::GeneralGreaterEqual { lhs, rhs },
220 Self::GeneralGreaterEqual { lhs: l, rhs: r },
221 )
222 | (Self::GeneralNotEqual { lhs, rhs }, Self::GeneralNotEqual { lhs: l, rhs: r }) => {
223 lhs.structural_eq(l) && rhs.structural_eq(r)
224 }
225 (
226 Self::Attr { receiver, name },
227 Self::Attr {
228 receiver: r,
229 name: n,
230 },
231 ) => receiver.structural_eq(r) && name == n,
232 (
233 Self::Call {
234 receiver,
235 name,
236 args,
237 },
238 Self::Call {
239 receiver: r,
240 name: n,
241 args: a,
242 },
243 ) => {
244 receiver.structural_eq(r)
245 && name == n
246 && args.iter().zip(a.iter()).all(|(l, r)| l.structural_eq(r))
247 }
248 (Self::Or(self_ors), Self::Or(other_ors)) => {
249 if self_ors.len() != other_ors.len() {
250 return false;
251 }
252 for l_val in self_ors.iter() {
253 if other_ors.get_by(l_val, |l, r| l.structural_eq(r)).is_none() {
254 return false;
255 }
256 }
257 true
258 }
259 (Self::And(_, _), Self::And(_, _)) => {
260 let self_ands = self.ands();
261 let other_ands = other.ands();
262 if self_ands.len() != other_ands.len() {
263 return false;
264 }
265 for l_val in self_ands.iter() {
266 if other_ands
267 .get_by(l_val, |l, r| l.structural_eq(r))
268 .is_none()
269 {
270 return false;
271 }
272 }
273 true
274 }
275 (Self::Not(p1), Self::Not(p2)) => p1.structural_eq(p2),
276 _ => self == other,
277 }
278 }
279}
280
281impl HasLevel for Predicate {
282 fn level(&self) -> Option<usize> {
283 match self {
284 Self::Value(_) | Self::Const(_) | Self::Failure => None,
285 Self::Equal { rhs, .. }
286 | Self::GreaterEqual { rhs, .. }
287 | Self::LessEqual { rhs, .. }
288 | Self::NotEqual { rhs, .. } => rhs.level(),
289 Self::GeneralEqual { lhs, rhs }
290 | Self::GeneralLessEqual { lhs, rhs }
291 | Self::GeneralGreaterEqual { lhs, rhs }
292 | Self::GeneralNotEqual { lhs, rhs } => {
293 lhs.level().zip(rhs.level()).map(|(a, b)| a.min(b))
294 }
295 Self::Or(preds) => preds.iter().filter_map(|p| p.level()).min(),
296 Self::And(lhs, rhs) => lhs.level().zip(rhs.level()).map(|(a, b)| a.min(b)),
297 Self::Not(p) => p.level(),
298 Self::Call { receiver, args, .. } => receiver
299 .level()
300 .zip(args.iter().map(|a| a.level().unwrap_or(usize::MAX)).min())
301 .map(|(a, b)| a.min(b)),
302 Self::Attr { receiver, .. } => receiver.level(),
303 }
304 }
305
306 fn set_level(&self, level: usize) {
307 match self {
308 Self::Value(_) | Self::Const(_) | Self::Failure => {}
309 Self::Call { receiver, args, .. } => {
310 receiver.set_level(level);
311 for arg in args {
312 arg.set_level(level);
313 }
314 }
315 Self::Attr { receiver, .. } => {
316 receiver.set_level(level);
317 }
318 Self::Equal { rhs, .. }
319 | Self::GreaterEqual { rhs, .. }
320 | Self::LessEqual { rhs, .. }
321 | Self::NotEqual { rhs, .. } => {
322 rhs.set_level(level);
323 }
324 Self::GeneralEqual { lhs, rhs }
325 | Self::GeneralLessEqual { lhs, rhs }
326 | Self::GeneralGreaterEqual { lhs, rhs }
327 | Self::GeneralNotEqual { lhs, rhs } => {
328 lhs.set_level(level);
329 rhs.set_level(level);
330 }
331 Self::Or(preds) => {
332 for pred in preds {
333 pred.set_level(level);
334 }
335 }
336 Self::And(lhs, rhs) => {
337 lhs.set_level(level);
338 rhs.set_level(level);
339 }
340 Self::Not(p) => {
341 p.set_level(level);
342 }
343 }
344 }
345}
346
347impl BitAnd for Predicate {
348 type Output = Self;
349
350 fn bitand(self, rhs: Self) -> Self::Output {
351 Self::and(self, rhs)
352 }
353}
354
355impl BitOr for Predicate {
356 type Output = Self;
357
358 fn bitor(self, rhs: Self) -> Self::Output {
359 Self::or(self, rhs)
360 }
361}
362
363impl Not for Predicate {
364 type Output = Self;
365
366 fn not(self) -> Self::Output {
367 self.invert()
368 }
369}
370
371impl Predicate {
372 pub const TRUE: Predicate = Predicate::Value(ValueObj::Bool(true));
373 pub const FALSE: Predicate = Predicate::Value(ValueObj::Bool(false));
374
375 pub fn general_eq(lhs: Predicate, rhs: Predicate) -> Self {
376 Self::GeneralEqual {
377 lhs: Box::new(lhs),
378 rhs: Box::new(rhs),
379 }
380 }
381
382 pub fn general_ge(lhs: Predicate, rhs: Predicate) -> Self {
383 Self::GeneralGreaterEqual {
384 lhs: Box::new(lhs),
385 rhs: Box::new(rhs),
386 }
387 }
388
389 pub fn general_le(lhs: Predicate, rhs: Predicate) -> Self {
390 Self::GeneralLessEqual {
391 lhs: Box::new(lhs),
392 rhs: Box::new(rhs),
393 }
394 }
395
396 pub fn general_ne(lhs: Predicate, rhs: Predicate) -> Self {
397 Self::GeneralNotEqual {
398 lhs: Box::new(lhs),
399 rhs: Box::new(rhs),
400 }
401 }
402
403 pub fn call(receiver: TyParam, name: Option<Str>, args: Vec<TyParam>) -> Self {
404 Self::Call {
405 receiver,
406 name,
407 args,
408 }
409 }
410
411 pub fn attr(receiver: TyParam, name: Str) -> Self {
412 Self::Attr { receiver, name }
413 }
414
415 pub const fn eq(lhs: Str, rhs: TyParam) -> Self {
416 Self::Equal { lhs, rhs }
417 }
418 pub const fn ne(lhs: Str, rhs: TyParam) -> Self {
419 Self::NotEqual { lhs, rhs }
420 }
421 pub const fn ge(lhs: Str, rhs: TyParam) -> Self {
423 Self::GreaterEqual { lhs, rhs }
424 }
425
426 pub fn gt(lhs: Str, rhs: TyParam) -> Self {
428 Self::and(Self::ge(lhs.clone(), rhs.clone()), Self::ne(lhs, rhs))
429 }
430
431 pub const fn le(lhs: Str, rhs: TyParam) -> Self {
433 Self::LessEqual { lhs, rhs }
434 }
435
436 pub fn lt(lhs: Str, rhs: TyParam) -> Self {
438 Self::and(Self::le(lhs.clone(), rhs.clone()), Self::ne(lhs, rhs))
439 }
440
441 pub fn and(lhs: Predicate, rhs: Predicate) -> Self {
442 match (lhs, rhs) {
443 (Predicate::Value(ValueObj::Bool(true)), p) => p,
444 (p, Predicate::Value(ValueObj::Bool(true))) => p,
445 (Predicate::Value(ValueObj::Bool(false)), _)
446 | (_, Predicate::Value(ValueObj::Bool(false))) => Predicate::FALSE,
447 (Predicate::And(l, r), other) | (other, Predicate::And(l, r)) => {
448 if l.as_ref() == &other {
449 *r & other
450 } else if r.as_ref() == &other {
451 *l & other
452 } else {
453 Self::And(Box::new(Self::And(l, r)), Box::new(other))
454 }
455 }
456 (p1, p2) => {
457 if p1 == p2 {
458 p1
459 } else {
460 Self::And(Box::new(p1), Box::new(p2))
461 }
462 }
463 }
464 }
465
466 pub fn or(lhs: Predicate, rhs: Predicate) -> Self {
467 match (lhs, rhs) {
468 (Predicate::Value(ValueObj::Bool(true)), _)
469 | (_, Predicate::Value(ValueObj::Bool(true))) => Predicate::TRUE,
470 (Predicate::Value(ValueObj::Bool(false)), p) => p,
471 (p, Predicate::Value(ValueObj::Bool(false))) => p,
472 (Predicate::Or(l), Predicate::Or(r)) => Self::Or(l.union(&r)),
473 (Predicate::Or(mut preds), other) | (other, Predicate::Or(mut preds)) => {
474 preds.insert(other);
475 Self::Or(preds)
476 }
477 (
479 Predicate::Equal { lhs, rhs },
480 Predicate::GreaterEqual {
481 lhs: lhs2,
482 rhs: rhs2,
483 },
484 ) if lhs == lhs2 && rhs == rhs2 => Self::ge(lhs, rhs),
485 (p1, p2) => {
486 if p1 == p2 {
487 p1
488 } else {
489 Self::Or(set! { p1, p2 })
490 }
491 }
492 }
493 }
494
495 pub fn is_equal(&self) -> bool {
496 matches!(self, Self::Equal { .. })
497 }
498
499 pub fn consist_of_equal(&self) -> bool {
500 match self {
501 Self::Equal { .. } => true,
502 Self::Or(preds) => preds.iter().all(|p| p.consist_of_equal()),
503 Self::And(lhs, rhs) => lhs.consist_of_equal() && rhs.consist_of_equal(),
504 Self::Not(pred) => pred.consist_of_equal(),
505 _ => false,
506 }
507 }
508
509 pub fn ands(&self) -> Set<&Predicate> {
510 match self {
511 Self::And(lhs, rhs) => {
512 let mut set = lhs.ands();
513 set.extend(rhs.ands());
514 set
515 }
516 _ => set! { self },
517 }
518 }
519
520 pub fn ors(&self) -> Set<&Predicate> {
521 match self {
522 Self::Or(preds) => preds.iter().collect(),
523 _ => set! { self },
524 }
525 }
526
527 pub fn subject(&self) -> Option<&str> {
528 match self {
529 Self::Equal { lhs, .. }
530 | Self::LessEqual { lhs, .. }
531 | Self::GreaterEqual { lhs, .. }
532 | Self::NotEqual { lhs, .. } => Some(&lhs[..]),
533 Self::Or(preds) => {
534 let mut iter = preds.iter();
535 let first = iter.next()?;
536 let subject = first.subject()?;
537 for pred in iter {
538 if subject != pred.subject()? {
539 return None;
540 }
541 }
542 Some(subject)
543 }
544 Self::And(lhs, rhs) => {
545 let l = lhs.subject();
546 let r = rhs.subject();
547 if l != r {
548 log!(err "{l:?} != {r:?}");
549 None
550 } else {
551 l
552 }
553 }
554 Self::Not(pred) => pred.subject(),
555 _ => None,
556 }
557 }
558
559 pub fn change_subject_name(self, name: Str) -> Self {
560 match self {
561 Self::Equal { rhs, .. } => Self::eq(name, rhs),
562 Self::GreaterEqual { rhs, .. } => Self::ge(name, rhs),
563 Self::LessEqual { rhs, .. } => Self::le(name, rhs),
564 Self::NotEqual { rhs, .. } => Self::ne(name, rhs),
565 Self::And(lhs, rhs) => Self::and(
566 lhs.change_subject_name(name.clone()),
567 rhs.change_subject_name(name),
568 ),
569 Self::Or(preds) => Self::Or(
570 preds
571 .iter()
572 .cloned()
573 .map(|p| p.change_subject_name(name.clone()))
574 .collect(),
575 ),
576 Self::Not(pred) => Self::not(pred.change_subject_name(name)),
577 Self::GeneralEqual { lhs, rhs } => Self::general_eq(
578 lhs.change_subject_name(name.clone()),
579 rhs.change_subject_name(name),
580 ),
581 Self::GeneralGreaterEqual { lhs, rhs } => Self::general_ge(
582 lhs.change_subject_name(name.clone()),
583 rhs.change_subject_name(name),
584 ),
585 Self::GeneralLessEqual { lhs, rhs } => Self::general_le(
586 lhs.change_subject_name(name.clone()),
587 rhs.change_subject_name(name),
588 ),
589 Self::GeneralNotEqual { lhs, rhs } => Self::general_ne(
590 lhs.change_subject_name(name.clone()),
591 rhs.change_subject_name(name),
592 ),
593 Self::Value(_)
594 | Self::Const(_)
595 | Self::Call { .. }
596 | Self::Attr { .. }
597 | Self::Failure => self,
598 }
599 }
600
601 pub fn substitute(self, var: &str, tp: &TyParam) -> Self {
602 match self {
603 Self::Equal { lhs, .. } if lhs == var => Self::eq(lhs, tp.clone()),
604 Self::GreaterEqual { lhs, .. } if lhs == var => Self::ge(lhs, tp.clone()),
605 Self::LessEqual { lhs, .. } if lhs == var => Self::le(lhs, tp.clone()),
606 Self::NotEqual { lhs, .. } if lhs == var => Self::ne(lhs, tp.clone()),
607 Self::Equal { lhs, rhs } => Self::eq(lhs, rhs.substitute(var, tp)),
608 Self::GreaterEqual { lhs, rhs } => Self::ge(lhs, rhs.substitute(var, tp)),
609 Self::LessEqual { lhs, rhs } => Self::le(lhs, rhs.substitute(var, tp)),
610 Self::NotEqual { lhs, rhs } => Self::ne(lhs, rhs.substitute(var, tp)),
611 Self::And(lhs, rhs) => Self::and(lhs.substitute(var, tp), rhs.substitute(var, tp)),
612 Self::Or(preds) => Self::Or(preds.into_iter().map(|p| p.substitute(var, tp)).collect()),
613 Self::Not(pred) => Self::not(pred.substitute(var, tp)),
614 Self::GeneralEqual { lhs, rhs } => {
615 Self::general_eq(lhs.substitute(var, tp), rhs.substitute(var, tp))
616 }
617 Self::GeneralGreaterEqual { lhs, rhs } => {
618 Self::general_ge(lhs.substitute(var, tp), rhs.substitute(var, tp))
619 }
620 Self::GeneralLessEqual { lhs, rhs } => {
621 Self::general_le(lhs.substitute(var, tp), rhs.substitute(var, tp))
622 }
623 Self::GeneralNotEqual { lhs, rhs } => {
624 Self::general_ne(lhs.substitute(var, tp), rhs.substitute(var, tp))
625 }
626 Self::Call {
627 receiver,
628 name,
629 args,
630 } => {
631 let receiver = receiver.substitute(var, tp);
632 let mut new_args = vec![];
633 for arg in args {
634 new_args.push(arg.substitute(var, tp));
635 }
636 Self::Call {
637 receiver,
638 name,
639 args: new_args,
640 }
641 }
642 Self::Attr { receiver, name } => Self::Attr {
643 receiver: receiver.substitute(var, tp),
644 name,
645 },
646 Self::Value(_) | Self::Const(_) | Self::Failure => self,
647 }
648 }
649
650 pub fn mentions(&self, name: &str) -> bool {
651 match self {
652 Self::Const(n) => &n[..] == name,
653 Self::Equal { lhs, .. }
654 | Self::LessEqual { lhs, .. }
655 | Self::GreaterEqual { lhs, .. }
656 | Self::NotEqual { lhs, .. } => &lhs[..] == name,
657 Self::GeneralEqual { lhs, rhs }
658 | Self::GeneralLessEqual { lhs, rhs }
659 | Self::GeneralGreaterEqual { lhs, rhs }
660 | Self::GeneralNotEqual { lhs, rhs } => lhs.mentions(name) || rhs.mentions(name),
661 Self::Not(pred) => pred.mentions(name),
662 Self::And(lhs, rhs) => lhs.mentions(name) || rhs.mentions(name),
663 Self::Or(preds) => preds.iter().any(|p| p.mentions(name)),
664 _ => false,
665 }
666 }
667
668 pub fn can_be_false(&self) -> Option<bool> {
669 match self {
670 Self::Value(l) => Some(matches!(l, ValueObj::Bool(false))),
671 Self::Const(_) => None,
672 Self::Or(preds) => {
673 for pred in preds {
674 if pred.can_be_false()? {
675 return Some(true);
676 }
677 }
678 Some(false)
679 }
680 Self::And(lhs, rhs) => Some(lhs.can_be_false()? && rhs.can_be_false()?),
681 Self::Not(pred) => Some(!pred.can_be_false()?),
682 _ => Some(true),
683 }
684 }
685
686 pub fn qvars(&self) -> Set<(Str, Constraint)> {
687 match self {
688 Self::Const(_) | Self::Failure => set! {},
689 Self::Value(val) => val.qvars(),
690 Self::Call { receiver, args, .. } => {
691 let mut set = receiver.qvars();
692 for arg in args {
693 set.extend(arg.qvars());
694 }
695 set
696 }
697 Self::Attr { receiver, .. } => receiver.qvars(),
698 Self::Equal { rhs, .. }
699 | Self::GreaterEqual { rhs, .. }
700 | Self::LessEqual { rhs, .. }
701 | Self::NotEqual { rhs, .. } => rhs.qvars(),
702 Self::GeneralEqual { lhs, rhs }
703 | Self::GeneralLessEqual { lhs, rhs }
704 | Self::GeneralGreaterEqual { lhs, rhs }
705 | Self::GeneralNotEqual { lhs, rhs } => {
706 lhs.qvars().concat(rhs.qvars()).into_iter().collect()
707 }
708 Self::And(lhs, rhs) => lhs.qvars().concat(rhs.qvars()),
709 Self::Or(preds) => preds.iter().fold(set! {}, |acc, p| acc.union(&p.qvars())),
710 Self::Not(pred) => pred.qvars(),
711 }
712 }
713
714 pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool {
715 match self {
716 Self::Const(_) | Self::Failure => false,
717 Self::Value(val) => val.has_type_satisfies(f),
718 Self::Call { receiver, args, .. } => {
719 receiver.has_type_satisfies(f) || args.iter().any(|a| a.has_type_satisfies(f))
720 }
721 Self::Attr { receiver, .. } => receiver.has_type_satisfies(f),
722 Self::Equal { rhs, .. }
723 | Self::GreaterEqual { rhs, .. }
724 | Self::LessEqual { rhs, .. }
725 | Self::NotEqual { rhs, .. } => rhs.has_type_satisfies(f),
726 Self::GeneralEqual { lhs, rhs }
727 | Self::GeneralLessEqual { lhs, rhs }
728 | Self::GeneralGreaterEqual { lhs, rhs }
729 | Self::GeneralNotEqual { lhs, rhs } => {
730 lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f)
731 }
732 Self::And(lhs, rhs) => lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f),
733 Self::Or(preds) => preds.iter().any(|p| p.has_type_satisfies(f)),
734 Self::Not(pred) => pred.has_type_satisfies(f),
735 }
736 }
737
738 pub fn has_qvar(&self) -> bool {
739 match self {
740 Self::Const(_) | Self::Failure => false,
741 Self::Value(val) => val.has_qvar(),
742 Self::Call { receiver, args, .. } => {
743 receiver.has_qvar() || args.iter().any(|a| a.has_qvar())
744 }
745 Self::Attr { receiver, .. } => receiver.has_qvar(),
746 Self::Equal { rhs, .. }
747 | Self::GreaterEqual { rhs, .. }
748 | Self::LessEqual { rhs, .. }
749 | Self::NotEqual { rhs, .. } => rhs.has_qvar(),
750 Self::GeneralEqual { lhs, rhs }
751 | Self::GeneralLessEqual { lhs, rhs }
752 | Self::GeneralGreaterEqual { lhs, rhs }
753 | Self::GeneralNotEqual { lhs, rhs } => lhs.has_qvar() || rhs.has_qvar(),
754 Self::And(lhs, rhs) => lhs.has_qvar() || rhs.has_qvar(),
755 Self::Or(preds) => preds.iter().any(|p| p.has_qvar()),
756 Self::Not(pred) => pred.has_qvar(),
757 }
758 }
759
760 pub fn has_unbound_var(&self) -> bool {
761 match self {
762 Self::Const(_) | Self::Failure => false,
763 Self::Value(val) => val.has_unbound_var(),
764 Self::Call { receiver, args, .. } => {
765 receiver.has_unbound_var() || args.iter().any(|a| a.has_unbound_var())
766 }
767 Self::Attr { receiver, .. } => receiver.has_unbound_var(),
768 Self::Equal { rhs, .. }
769 | Self::GreaterEqual { rhs, .. }
770 | Self::LessEqual { rhs, .. }
771 | Self::NotEqual { rhs, .. } => rhs.has_unbound_var(),
772 Self::GeneralEqual { lhs, rhs }
773 | Self::GeneralLessEqual { lhs, rhs }
774 | Self::GeneralGreaterEqual { lhs, rhs }
775 | Self::GeneralNotEqual { lhs, rhs } => lhs.has_unbound_var() || rhs.has_unbound_var(),
776 Self::And(lhs, rhs) => lhs.has_unbound_var() || rhs.has_unbound_var(),
777 Self::Or(preds) => preds.iter().any(|p| p.has_unbound_var()),
778 Self::Not(pred) => pred.has_unbound_var(),
779 }
780 }
781
782 pub fn has_undoable_linked_var(&self) -> bool {
783 match self {
784 Self::Const(_) | Self::Failure => false,
785 Self::Value(val) => val.has_undoable_linked_var(),
786 Self::Call { receiver, args, .. } => {
787 receiver.has_undoable_linked_var()
788 || args.iter().any(|a| a.has_undoable_linked_var())
789 }
790 Self::Attr { receiver, .. } => receiver.has_undoable_linked_var(),
791 Self::Equal { rhs, .. }
792 | Self::GreaterEqual { rhs, .. }
793 | Self::LessEqual { rhs, .. }
794 | Self::NotEqual { rhs, .. } => rhs.has_undoable_linked_var(),
795 Self::GeneralEqual { lhs, rhs }
796 | Self::GeneralLessEqual { lhs, rhs }
797 | Self::GeneralGreaterEqual { lhs, rhs }
798 | Self::GeneralNotEqual { lhs, rhs } => {
799 lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var()
800 }
801 Self::And(lhs, rhs) => lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var(),
802 Self::Or(preds) => preds.iter().any(|p| p.has_undoable_linked_var()),
803 Self::Not(pred) => pred.has_undoable_linked_var(),
804 }
805 }
806
807 pub fn min_max<'a>(
808 &'a self,
809 min: Option<&'a TyParam>,
810 max: Option<&'a TyParam>,
811 ) -> (Option<&'a TyParam>, Option<&'a TyParam>) {
812 match self {
813 Predicate::Equal { rhs: _, .. } => todo!(),
814 Predicate::LessEqual { rhs, .. } => (
816 min,
817 max.map(|l: &TyParam| match l.cheap_cmp(rhs) {
818 Some(c) if c.is_ge() => l,
819 Some(_) => rhs,
820 _ => l,
821 })
822 .or(Some(rhs)),
823 ),
824 Predicate::GreaterEqual { rhs, .. } => (
826 min.map(|l: &TyParam| match l.cheap_cmp(rhs) {
827 Some(c) if c.is_le() => l,
828 Some(_) => rhs,
829 _ => l,
830 })
831 .or(Some(rhs)),
832 max,
833 ),
834 Predicate::And(_l, _r) => todo!(),
835 _ => todo!(),
836 }
837 }
838
839 pub fn typarams(&self) -> Vec<&TyParam> {
840 match self {
841 Self::Value(_) | Self::Const(_) | Self::Attr { .. } | Self::Failure => vec![],
842 Self::Call { args, .. } => {
844 let mut vec = vec![];
845 vec.extend(args);
846 vec
847 }
848 Self::Equal { rhs, .. }
849 | Self::GreaterEqual { rhs, .. }
850 | Self::LessEqual { rhs, .. }
851 | Self::NotEqual { rhs, .. } => vec![rhs],
852 Self::GeneralEqual { .. }
853 | Self::GeneralLessEqual { .. }
854 | Self::GeneralGreaterEqual { .. }
855 | Self::GeneralNotEqual { .. } => vec![],
856 Self::And(lhs, rhs) => lhs.typarams().into_iter().chain(rhs.typarams()).collect(),
857 Self::Or(preds) => preds.iter().flat_map(|p| p.typarams()).collect(),
858 Self::Not(pred) => pred.typarams(),
859 }
860 }
861
862 pub fn invert(self) -> Self {
863 match self {
864 Self::Value(ValueObj::Bool(b)) => Self::Value(ValueObj::Bool(!b)),
865 Self::Equal { lhs, rhs } => Self::ne(lhs, rhs),
866 Self::GreaterEqual { lhs, rhs } => Self::lt(lhs, rhs),
867 Self::LessEqual { lhs, rhs } => Self::gt(lhs, rhs),
868 Self::NotEqual { lhs, rhs } => Self::eq(lhs, rhs),
869 Self::GeneralEqual { lhs, rhs } => Self::GeneralNotEqual { lhs, rhs },
870 Self::GeneralLessEqual { lhs, rhs } => Self::GeneralGreaterEqual { lhs, rhs },
871 Self::GeneralGreaterEqual { lhs, rhs } => Self::GeneralLessEqual { lhs, rhs },
872 Self::GeneralNotEqual { lhs, rhs } => Self::GeneralEqual { lhs, rhs },
873 Self::Not(pred) => *pred,
874 other => Self::Not(Box::new(other)),
875 }
876 }
877
878 pub fn possible_tps(&self) -> Vec<&TyParam> {
879 match self {
880 Self::Or(preds) => preds.iter().flat_map(|p| p.possible_tps()).collect(),
881 Self::Equal { rhs, .. } => vec![rhs],
882 _ => vec![],
883 }
884 }
885
886 pub fn possible_values(&self) -> Vec<&ValueObj> {
887 match self {
888 Self::Equal {
890 rhs: TyParam::Value(value),
891 ..
892 } => vec![value],
893 Self::Or(preds) => preds.iter().flat_map(|p| p.possible_values()).collect(),
894 _ => vec![],
895 }
896 }
897
898 pub fn variables(&self) -> Set<Str> {
899 match self {
900 Self::Value(_) | Self::Failure => set! {},
901 Self::Const(name) => set! { name.clone() },
902 Self::Call { receiver, args, .. } => {
903 let mut set = receiver.variables();
904 for arg in args {
905 set.extend(arg.variables());
906 }
907 set
908 }
909 Self::Attr { receiver, .. } => receiver.variables(),
910 Self::Equal { rhs, .. }
911 | Self::GreaterEqual { rhs, .. }
912 | Self::LessEqual { rhs, .. }
913 | Self::NotEqual { rhs, .. } => rhs.variables(),
914 Self::GeneralEqual { lhs, rhs }
915 | Self::GeneralLessEqual { lhs, rhs }
916 | Self::GeneralGreaterEqual { lhs, rhs }
917 | Self::GeneralNotEqual { lhs, rhs } => lhs.variables().concat(rhs.variables()),
918 Self::And(lhs, rhs) => lhs.variables().concat(rhs.variables()),
919 Self::Or(preds) => preds
920 .iter()
921 .fold(set! {}, |acc, p| acc.union(&p.variables())),
922 Self::Not(pred) => pred.variables(),
923 }
924 }
925
926 pub fn contains_value(&self, value: &ValueObj) -> bool {
927 match self {
928 Self::Value(v) => v.contains(value),
929 Self::Const(_) => false,
930 Self::Call { receiver, args, .. } => {
931 receiver.contains_value(value) || args.iter().any(|a| a.contains_value(value))
932 }
933 Self::Attr { receiver, .. } => receiver.contains_value(value),
934 Self::Equal { rhs, .. }
935 | Self::GreaterEqual { rhs, .. }
936 | Self::LessEqual { rhs, .. }
937 | Self::NotEqual { rhs, .. } => rhs.contains_value(value),
938 Self::GeneralEqual { lhs, rhs }
939 | Self::GeneralLessEqual { lhs, rhs }
940 | Self::GeneralGreaterEqual { lhs, rhs }
941 | Self::GeneralNotEqual { lhs, rhs } => {
942 lhs.contains_value(value) || rhs.contains_value(value)
943 }
944 Self::And(lhs, rhs) => lhs.contains_value(value) || rhs.contains_value(value),
945 Self::Or(preds) => preds.iter().any(|p| p.contains_value(value)),
946 Self::Not(pred) => pred.contains_value(value),
947 Self::Failure => false,
948 }
949 }
950
951 pub fn contains_tp(&self, tp: &TyParam) -> bool {
952 match self {
953 Self::Value(v) => v.contains_tp(tp),
954 Self::Call { receiver, args, .. } => {
955 receiver.contains_tp(tp) || args.iter().any(|a| a.contains_tp(tp))
956 }
957 Self::Attr { receiver, .. } => receiver.contains_tp(tp),
958 Self::Equal { rhs, .. }
959 | Self::GreaterEqual { rhs, .. }
960 | Self::LessEqual { rhs, .. }
961 | Self::NotEqual { rhs, .. } => rhs.contains_tp(tp),
962 Self::GeneralEqual { lhs, rhs }
963 | Self::GeneralLessEqual { lhs, rhs }
964 | Self::GeneralGreaterEqual { lhs, rhs }
965 | Self::GeneralNotEqual { lhs, rhs } => lhs.contains_tp(tp) || rhs.contains_tp(tp),
966 Self::And(lhs, rhs) => lhs.contains_tp(tp) || rhs.contains_tp(tp),
967 Self::Or(preds) => preds.iter().any(|p| p.contains_tp(tp)),
968 Self::Not(pred) => pred.contains_tp(tp),
969 Self::Failure | Self::Const(_) => false,
970 }
971 }
972
973 pub fn contains_t(&self, t: &Type) -> bool {
974 match self {
975 Self::Value(v) => v.contains_type(t),
976 Self::Call { receiver, args, .. } => {
977 receiver.contains_type(t) || args.iter().any(|a| a.contains_type(t))
978 }
979 Self::Attr { receiver, .. } => receiver.contains_type(t),
980 Self::Equal { rhs, .. }
981 | Self::GreaterEqual { rhs, .. }
982 | Self::LessEqual { rhs, .. }
983 | Self::NotEqual { rhs, .. } => rhs.contains_type(t),
984 Self::GeneralEqual { lhs, rhs }
985 | Self::GeneralLessEqual { lhs, rhs }
986 | Self::GeneralGreaterEqual { lhs, rhs }
987 | Self::GeneralNotEqual { lhs, rhs } => lhs.contains_t(t) || rhs.contains_t(t),
988 Self::And(lhs, rhs) => lhs.contains_t(t) || rhs.contains_t(t),
989 Self::Or(preds) => preds.iter().any(|p| p.contains_t(t)),
990 Self::Not(pred) => pred.contains_t(t),
991 Self::Const(_) | Self::Failure => false,
992 }
993 }
994
995 pub fn _replace_tp(self, target: &TyParam, to: &TyParam, tvs: &SharedFrees) -> Self {
996 self.map_tp(&mut |tp| tp._replace(target, to, tvs), tvs)
997 }
998
999 pub fn replace_tp(self, target: &TyParam, to: &TyParam) -> Self {
1000 self.map_tp(&mut |tp| tp.replace(target, to), &SharedFrees::new())
1001 }
1002
1003 pub fn _replace_t(self, target: &Type, to: &Type, tvs: &SharedFrees) -> Self {
1004 self.map_t(&mut |t| t._replace(target, to, tvs), tvs)
1005 }
1006
1007 pub fn dereference(&mut self) {
1008 *self = std::mem::take(self).map_t(
1009 &mut |mut t| {
1010 t.dereference();
1011 t
1012 },
1013 &SharedFrees::new(),
1014 );
1015 }
1016
1017 pub fn map_t(self, f: &mut impl FnMut(Type) -> Type, tvs: &SharedFrees) -> Self {
1018 match self {
1019 Self::Value(val) => Self::Value(val.map_t(f)),
1020 Self::Const(_) => self,
1021 Self::Call {
1022 receiver,
1023 args,
1024 name,
1025 } => Self::Call {
1026 receiver: receiver.map_t(f, tvs),
1027 args: args.into_iter().map(|a| a.map_t(f, tvs)).collect(),
1028 name,
1029 },
1030 Self::Attr { receiver, name } => Self::Attr {
1031 receiver: receiver.map_t(f, tvs),
1032 name,
1033 },
1034 Self::Equal { lhs, rhs } => Self::Equal {
1035 lhs,
1036 rhs: rhs.map_t(f, tvs),
1037 },
1038 Self::GreaterEqual { lhs, rhs } => Self::GreaterEqual {
1039 lhs,
1040 rhs: rhs.map_t(f, tvs),
1041 },
1042 Self::LessEqual { lhs, rhs } => Self::LessEqual {
1043 lhs,
1044 rhs: rhs.map_t(f, tvs),
1045 },
1046 Self::NotEqual { lhs, rhs } => Self::NotEqual {
1047 lhs,
1048 rhs: rhs.map_t(f, tvs),
1049 },
1050 Self::GeneralEqual { lhs, rhs } => Self::GeneralEqual {
1051 lhs: Box::new(lhs.map_t(f, tvs)),
1052 rhs: Box::new(rhs.map_t(f, tvs)),
1053 },
1054 Self::GeneralLessEqual { lhs, rhs } => Self::GeneralLessEqual {
1055 lhs: Box::new(lhs.map_t(f, tvs)),
1056 rhs: Box::new(rhs.map_t(f, tvs)),
1057 },
1058 Self::GeneralGreaterEqual { lhs, rhs } => Self::GeneralGreaterEqual {
1059 lhs: Box::new(lhs.map_t(f, tvs)),
1060 rhs: Box::new(rhs.map_t(f, tvs)),
1061 },
1062 Self::GeneralNotEqual { lhs, rhs } => Self::GeneralNotEqual {
1063 lhs: Box::new(lhs.map_t(f, tvs)),
1064 rhs: Box::new(rhs.map_t(f, tvs)),
1065 },
1066 Self::And(lhs, rhs) => {
1067 Self::And(Box::new(lhs.map_t(f, tvs)), Box::new(rhs.map_t(f, tvs)))
1068 }
1069 Self::Or(preds) => Self::Or(preds.into_iter().map(|p| p.map_t(f, tvs)).collect()),
1070 Self::Not(pred) => Self::Not(Box::new(pred.map_t(f, tvs))),
1071 Self::Failure => self,
1072 }
1073 }
1074
1075 pub fn map_tp(self, f: &mut impl FnMut(TyParam) -> TyParam, tvs: &SharedFrees) -> Self {
1076 match self {
1077 Self::Value(val) => Self::Value(val.map_tp(f, tvs)),
1078 Self::Const(_) => self,
1079 Self::Call {
1080 receiver,
1081 args,
1082 name,
1083 } => Self::Call {
1084 receiver: receiver.map(f, tvs),
1085 args: args.into_iter().map(|a| a.map(f, tvs)).collect(),
1086 name,
1087 },
1088 Self::Attr { receiver, name } => Self::Attr {
1089 receiver: receiver.map(f, tvs),
1090 name,
1091 },
1092 Self::Equal { lhs, rhs } => Self::Equal {
1093 lhs,
1094 rhs: rhs.map(f, tvs),
1095 },
1096 Self::GreaterEqual { lhs, rhs } => Self::GreaterEqual {
1097 lhs,
1098 rhs: rhs.map(f, tvs),
1099 },
1100 Self::LessEqual { lhs, rhs } => Self::LessEqual {
1101 lhs,
1102 rhs: rhs.map(f, tvs),
1103 },
1104 Self::NotEqual { lhs, rhs } => Self::NotEqual {
1105 lhs,
1106 rhs: rhs.map(f, tvs),
1107 },
1108 Self::GeneralEqual { lhs, rhs } => Self::GeneralEqual {
1109 lhs: Box::new(lhs.map_tp(f, tvs)),
1110 rhs: Box::new(rhs.map_tp(f, tvs)),
1111 },
1112 Self::GeneralLessEqual { lhs, rhs } => Self::GeneralLessEqual {
1113 lhs: Box::new(lhs.map_tp(f, tvs)),
1114 rhs: Box::new(rhs.map_tp(f, tvs)),
1115 },
1116 Self::GeneralGreaterEqual { lhs, rhs } => Self::GeneralGreaterEqual {
1117 lhs: Box::new(lhs.map_tp(f, tvs)),
1118 rhs: Box::new(rhs.map_tp(f, tvs)),
1119 },
1120 Self::GeneralNotEqual { lhs, rhs } => Self::GeneralNotEqual {
1121 lhs: Box::new(lhs.map_tp(f, tvs)),
1122 rhs: Box::new(rhs.map_tp(f, tvs)),
1123 },
1124 Self::And(lhs, rhs) => {
1125 Self::And(Box::new(lhs.map_tp(f, tvs)), Box::new(rhs.map_tp(f, tvs)))
1126 }
1127 Self::Or(preds) => Self::Or(preds.into_iter().map(|p| p.map_tp(f, tvs)).collect()),
1128 Self::Not(pred) => Self::Not(Box::new(pred.map_tp(f, tvs))),
1129 Self::Failure => self,
1130 }
1131 }
1132
1133 pub fn try_map_tp<E>(
1134 self,
1135 f: &mut impl FnMut(TyParam) -> Result<TyParam, E>,
1136 tvs: &SharedFrees,
1137 ) -> Result<Self, E> {
1138 match self {
1139 Self::Value(val) => Ok(Self::Value(val.try_map_tp(f, tvs)?)),
1140 Self::Call {
1141 receiver,
1142 args,
1143 name,
1144 } => Ok(Self::Call {
1145 receiver: f(receiver)?,
1146 args: args.into_iter().map(f).collect::<Result<_, E>>()?,
1147 name,
1148 }),
1149 Self::Attr { receiver, name } => Ok(Self::Attr {
1150 receiver: f(receiver)?,
1151 name,
1152 }),
1153 Self::Equal { lhs, rhs } => Ok(Self::Equal { lhs, rhs: f(rhs)? }),
1154 Self::GreaterEqual { lhs, rhs } => Ok(Self::GreaterEqual { lhs, rhs: f(rhs)? }),
1155 Self::LessEqual { lhs, rhs } => Ok(Self::LessEqual { lhs, rhs: f(rhs)? }),
1156 Self::NotEqual { lhs, rhs } => Ok(Self::NotEqual { lhs, rhs: f(rhs)? }),
1157 Self::GeneralEqual { lhs, rhs } => Ok(Self::GeneralEqual {
1158 lhs: Box::new(lhs.try_map_tp(f, tvs)?),
1159 rhs: Box::new(rhs.try_map_tp(f, tvs)?),
1160 }),
1161 Self::GeneralLessEqual { lhs, rhs } => Ok(Self::GeneralLessEqual {
1162 lhs: Box::new(lhs.try_map_tp(f, tvs)?),
1163 rhs: Box::new(rhs.try_map_tp(f, tvs)?),
1164 }),
1165 Self::GeneralGreaterEqual { lhs, rhs } => Ok(Self::GeneralGreaterEqual {
1166 lhs: Box::new(lhs.try_map_tp(f, tvs)?),
1167 rhs: Box::new(rhs.try_map_tp(f, tvs)?),
1168 }),
1169 Self::GeneralNotEqual { lhs, rhs } => Ok(Self::GeneralNotEqual {
1170 lhs: Box::new(lhs.try_map_tp(f, tvs)?),
1171 rhs: Box::new(rhs.try_map_tp(f, tvs)?),
1172 }),
1173 Self::And(lhs, rhs) => Ok(Self::And(
1174 Box::new(lhs.try_map_tp(f, tvs)?),
1175 Box::new(rhs.try_map_tp(f, tvs)?),
1176 )),
1177 Self::Or(preds) => Ok(Self::Or(
1178 preds
1179 .into_iter()
1180 .map(|p| p.try_map_tp(f, tvs))
1181 .collect::<Result<_, E>>()?,
1182 )),
1183 Self::Not(pred) => Ok(Self::Not(Box::new(pred.try_map_tp(f, tvs)?))),
1184 Self::Failure | Self::Const(_) => Ok(self),
1185 }
1186 }
1187}