1use std::{
2 collections::{HashMap, HashSet},
3 fmt,
4};
5
6use crate::{
7 bindings::Bindings,
8 error::RuntimeError,
9 filter::singleton,
10 folder::{fold_term, Folder},
11 terms::*,
12};
13
14use super::partial::{invert_operation, FALSE, TRUE};
15
16type Result<T> = core::result::Result<T, RuntimeError>;
17
/// Compile-time switch for performance tracking. `simplify_bindings_opt` passes it to
/// `PerfCounters::new` and to `simplify_partial`. The `test_debug_off` test asserts that
/// it is `false`.
const TRACK_PERF: bool = false;
21
/// Compile-time switch for the debug output emitted through `if_debug!` and
/// `simplify_debug!`. The `test_debug_off` test asserts that it is `false`.
const SIMPLIFY_DEBUG: bool = false;
24
/// Wraps the given tokens in `if SIMPLIFY_DEBUG { ... }`, so they only execute when the
/// `SIMPLIFY_DEBUG` constant is `true`.
macro_rules! if_debug {
    ($($body:tt)*) => {
        if SIMPLIFY_DEBUG {
            $($body)*
        }
    };
}
32
/// Forwards its arguments to `eprintln!`, guarded by `if_debug!`; the arguments use the
/// same syntax as `eprintln!`.
macro_rules! simplify_debug {
    ($($fmt_args:tt)*) => {
        if_debug!(eprintln!($($fmt_args)*))
    };
}
38
/// Verdict returned by `Simplifier::maybe_bind_constraint` for one constraint of a
/// conjunction; `Simplifier::simplify_operation_variables` acts on it.
enum MaybeDrop {
    /// Leave the constraint in the conjunction.
    Keep,
    /// Remove the constraint from the conjunction.
    Drop,
    /// Remove the constraint and bind the variable to the term.
    Bind(Symbol, Term),
    /// Bind the variable to the term and remove the constraint, but only if another
    /// argument of the same conjunction mentions the variable.
    Check(Symbol, Term),
}
45
/// `Folder` that rewrites occurrences of one variable to `_this`.
struct VariableSubber {
    /// The variable whose occurrences are replaced.
    this_var: Symbol,
}
49
50impl VariableSubber {
51 pub fn new(this_var: Symbol) -> Self {
52 Self { this_var }
53 }
54}
55
56impl Folder for VariableSubber {
57 fn fold_variable(&mut self, v: Symbol) -> Symbol {
58 if v == self.this_var {
59 sym!("_this")
60 } else {
61 v
62 }
63 }
64
65 fn fold_rest_variable(&mut self, v: Symbol) -> Symbol {
66 if v == self.this_var {
67 sym!("_this")
68 } else {
69 v
70 }
71 }
72}
73
74pub fn sub_this(this: Symbol, term: Term) -> Term {
76 if term.as_symbol().map(|s| s == &this).unwrap_or(false) {
77 return term;
78 }
79 fold_term(term, &mut VariableSubber::new(this))
80}
81
82fn simplify_trivial_constraint(this: Symbol, term: Term) -> Term {
84 use {Operator::*, Value::*};
85 match term.value() {
86 Expression(o) if o.operator == Unify => {
87 let left = &o.args[0];
88 let right = &o.args[1];
89 match (left.value(), right.value()) {
90 (Variable(v), Variable(w))
91 | (Variable(v), RestVariable(w))
92 | (RestVariable(v), Variable(w))
93 | (RestVariable(v), RestVariable(w))
94 if v == w =>
95 {
96 TRUE.into()
97 }
98 (Variable(l), _) | (RestVariable(l), _) if l == &this && right.is_ground() => {
99 right.clone()
100 }
101 (_, Variable(r)) | (_, RestVariable(r)) if r == &this && left.is_ground() => {
102 left.clone()
103 }
104 _ => term,
105 }
106 }
107 _ => term,
108 }
109}
110
111pub fn simplify_partial(
112 var: &Symbol,
113 mut term: Term,
114 output_vars: HashSet<Symbol>,
115 track_performance: bool,
116) -> (Term, Option<PerfCounters>) {
117 let mut simplifier = Simplifier::new(output_vars, track_performance);
118 simplify_debug!("*** simplify partial {:?}", var);
119 simplifier.simplify_partial(&mut term);
120 term = simplify_trivial_constraint(var.clone(), term);
121 simplify_debug!("simplify partial done {:?}, {}", var, term);
122 if matches!(term.value(), Value::Expression(e) if e.operator != Operator::And) {
123 (op!(And, term).into(), simplifier.perf_counters())
124 } else {
125 (term, simplifier.perf_counters())
126 }
127}
128
129pub fn simplify_bindings(bindings: Bindings) -> Option<Bindings> {
130 simplify_bindings_opt(bindings, true)
131 .expect("unexpected error thrown by the simplifier when simplifying all bindings")
132}
133
/// Simplify the values in `bindings`.
///
/// * `all == true`: every binding is simplified, and each variable is the sole output
///   variable of its own partial.
/// * `all == false`: only non-temporary variables are simplified, with all non-temporary
///   keys of `bindings` as output variables. Temporary variables are left out of the
///   result.
///
/// Returns `Ok(None)` if any simplified binding equals `FALSE`, otherwise
/// `Ok(Some(simplified_bindings))`.
///
/// # Errors
///
/// With `all == false`, returns `RuntimeError::UnhandledPartial` for a temporary variable
/// bound to an expression whose variables are all temporary.
///
/// # Panics
///
/// Panics if a binding that gets simplified is an expression whose operator is not `And`.
pub fn simplify_bindings_opt(bindings: Bindings, all: bool) -> Result<Option<Bindings>> {
    let mut perf = PerfCounters::new(TRACK_PERF);
    simplify_debug!("simplify bindings");

    if_debug! {
        eprintln!("before simplified");
        for (k, v) in bindings.iter() {
            eprintln!("{:?} {}", k, v);
        }
    }

    // Set by the closure below as soon as one binding simplifies to FALSE.
    let mut unsatisfiable = false;
    // Simplifies a single binding; mutably captures `perf` and `unsatisfiable`.
    let mut simplify_var = |bindings: &Bindings, var: &Symbol, value: &Term| match value.value() {
        Value::Expression(o) => {
            assert_eq!(o.operator, Operator::And);
            // Variables that must survive simplification of this partial.
            let output_vars = if all {
                singleton(var.clone())
            } else {
                bindings
                    .keys()
                    .filter(|v| !v.is_temporary_var())
                    .cloned()
                    .collect::<HashSet<_>>()
            };

            let (simplified, p) = simplify_partial(var, value.clone(), output_vars, TRACK_PERF);
            if let Some(p) = p {
                perf.merge(p);
            }

            match simplified.as_expression() {
                Ok(o) if o == &FALSE => unsatisfiable = true,
                _ => (),
            }
            simplified
        }
        // A temporary variable that is itself bound to another variable: follow that
        // binding one step.
        Value::Variable(v) | Value::RestVariable(v)
            if v.is_temporary_var()
                && bindings.contains_key(v)
                && matches!(
                    bindings[v].value(),
                    Value::Variable(_) | Value::RestVariable(_)
                ) =>
        {
            bindings[v].clone()
        }
        // Anything else is copied through unchanged.
        _ => value.clone(),
    };

    simplify_debug!("simplify bindings {}", if all { "all" } else { "not all" });

    let mut simplified_bindings = HashMap::new();
    for (var, value) in &bindings {
        if !var.is_temporary_var() || all {
            let simplified = simplify_var(&bindings, var, value);
            simplified_bindings.insert(var.clone(), simplified);
        } else if let Value::Expression(e) = value.value() {
            // Only reached for temporary variables when `all` is false.
            if e.variables().iter().all(|v| v.is_temporary_var()) {
                return Err(RuntimeError::UnhandledPartial {
                    var: var.clone(),
                    term: value.clone(),
                });
            }
        }
    }

    if unsatisfiable {
        Ok(None)
    } else {
        if_debug! {
            eprintln!("after simplified");
            for (k, v) in simplified_bindings.iter() {
                eprintln!("{:?} {}", k, v);
            }
        }

        Ok(Some(simplified_bindings))
    }
}
217
/// Call counters for the simplifier. All recording methods are no-ops unless `enabled`.
#[derive(Clone, Default)]
pub struct PerfCounters {
    /// Whether counting is active.
    enabled: bool,

    /// Per-term totals, filled by `finish_acc` from `acc_simplify_term`.
    simplify_term: HashMap<Term, u64>,
    /// Per-term totals, filled by `finish_acc` from `acc_preprocess_and`.
    preprocess_and: HashMap<Term, u64>,

    /// Running count of `simplify_term()` calls since the last `finish_acc`.
    acc_simplify_term: u64,
    /// Running count of `preprocess_and()` calls since the last `finish_acc`.
    acc_preprocess_and: u64,
}
229
230impl fmt::Display for PerfCounters {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 writeln!(f, "perf {{")?;
233 writeln!(f, "simplify term")?;
234 for (term, ncalls) in self.simplify_term.iter() {
235 writeln!(f, "\t{}: {}", term, ncalls)?;
236 }
237
238 writeln!(f, "preprocess and")?;
239
240 for (term, ncalls) in self.preprocess_and.iter() {
241 writeln!(f, "\t{}: {}", term, ncalls)?;
242 }
243
244 writeln!(f, "}}")
245 }
246}
247
248impl PerfCounters {
249 fn new(enabled: bool) -> Self {
250 Self {
251 enabled,
252 ..Default::default()
253 }
254 }
255
256 fn preprocess_and(&mut self) {
257 if !self.enabled {
258 return;
259 }
260
261 self.acc_preprocess_and += 1;
262 }
263
264 fn simplify_term(&mut self) {
265 if !self.enabled {
266 return;
267 }
268
269 self.acc_simplify_term += 1;
270 }
271
272 fn finish_acc(&mut self, term: Term) {
273 if !self.enabled {
274 return;
275 }
276
277 self.simplify_term
278 .insert(term.clone(), self.acc_simplify_term);
279 self.preprocess_and.insert(term, self.acc_preprocess_and);
280 self.acc_preprocess_and = 0;
281 self.acc_simplify_term = 0;
282 }
283
284 fn merge(&mut self, other: PerfCounters) {
285 if !self.enabled {
286 return;
287 }
288
289 self.simplify_term.extend(other.simplify_term);
290 self.preprocess_and.extend(other.preprocess_and);
291 }
292
293 pub fn is_enabled(&self) -> bool {
294 self.enabled
295 }
296}
297
/// State carried through the simplification of one partial.
#[derive(Clone)]
pub struct Simplifier {
    /// Variable bindings recorded by `bind` and consulted by `deref`.
    bindings: Bindings,
    /// Variables reported as output variables by `is_output`.
    output_vars: HashSet<Symbol>,
    /// Terms currently being processed by `simplify_term`; a term already in this set is
    /// skipped.
    seen: HashSet<Term>,

    /// Performance counters for this simplifier.
    counters: PerfCounters,
}
306
/// Callback that simplifies a single term in place; operation simplifiers receive one so
/// they can recurse into their arguments.
type TermSimplifier = dyn Fn(&mut Simplifier, &mut Term);
308
309impl Simplifier {
310 pub fn new(output_vars: HashSet<Symbol>, track_performance: bool) -> Self {
311 Self {
312 bindings: Bindings::new(),
313 output_vars,
314 seen: HashSet::new(),
315 counters: PerfCounters::new(track_performance),
316 }
317 }
318
319 fn perf_counters(&mut self) -> Option<PerfCounters> {
320 if !self.counters.is_enabled() {
321 return None;
322 }
323
324 let mut counter = PerfCounters::new(true);
325 std::mem::swap(&mut self.counters, &mut counter);
326 Some(counter)
327 }
328
329 pub fn bind(&mut self, var: Symbol, value: Term) {
330 if !self.is_bound(&var) {
332 self.bindings.insert(var, self.deref(&value));
333 }
334 }
335
336 pub fn deref(&self, term: &Term) -> Term {
337 match term.value() {
338 Value::Variable(var) | Value::RestVariable(var) => {
339 self.bindings.get(var).unwrap_or(term).clone()
340 }
341 _ => term.clone(),
342 }
343 }
344
    /// Whether `var` already has an entry in this simplifier's bindings.
    fn is_bound(&self, var: &Symbol) -> bool {
        self.bindings.contains_key(var)
    }
348
349 fn is_output(&self, t: &Term) -> bool {
350 match t.value() {
351 Value::Variable(v) | Value::RestVariable(v) => self.output_vars.contains(v),
352 _ => false,
353 }
354 }
355
356 fn maybe_bind_constraint(&mut self, constraint: &Operation) -> MaybeDrop {
369 match constraint.operator {
370 Operator::And if constraint.args.is_empty() => MaybeDrop::Drop,
372
373 Operator::Unify | Operator::Eq => {
375 let left = &constraint.args[0];
376 let right = &constraint.args[1];
377
378 if left == right {
379 MaybeDrop::Drop
381 } else {
382 match (left.value(), right.value()) {
384 (Value::Variable(_), Value::Variable(_))
386 if self.is_output(left) && self.is_output(right) =>
387 {
388 MaybeDrop::Keep
389 }
390 (Value::Variable(l), _) if !self.is_bound(l) && !self.is_output(left) => {
392 simplify_debug!("*** 1");
393 MaybeDrop::Bind(l.clone(), right.clone())
394 }
395 (_, Value::Variable(r)) if !self.is_bound(r) && !self.is_output(right) => {
397 simplify_debug!("*** 2");
398 MaybeDrop::Bind(r.clone(), left.clone())
399 }
400 (Value::Variable(var), val) if val.is_ground() && !self.is_bound(var) => {
402 simplify_debug!("*** 3");
403 MaybeDrop::Check(var.clone(), right.clone())
404 }
405 (val, Value::Variable(var)) if val.is_ground() && !self.is_bound(var) => {
407 simplify_debug!("*** 4");
408 MaybeDrop::Check(var.clone(), left.clone())
409 }
410 _ => MaybeDrop::Keep,
412 }
413 }
414 }
415 _ => MaybeDrop::Keep,
416 }
417 }
418
419 pub fn simplify_operation_variables(
426 &mut self,
427 o: &mut Operation,
428 simplify_term: &TermSimplifier,
429 ) {
430 fn toss_trivial_unifies(args: &mut TermList) {
431 args.retain(|c| {
432 let o = c.as_expression().unwrap();
433 match o.operator {
434 Operator::Unify | Operator::Eq => {
435 assert_eq!(o.args.len(), 2);
436 let left = &o.args[0];
437 let right = &o.args[1];
438 left != right
439 }
440 _ => true,
441 }
442 });
443 }
444
445 if o.operator == Operator::And || o.operator == Operator::Or {
446 toss_trivial_unifies(&mut o.args);
447 }
448
449 match o.operator {
450 Operator::And | Operator::Or if o.args.is_empty() => (),
453
454 Operator::And | Operator::Or if o.args.len() == 1 => {
456 if let Value::Expression(operation) = o.args[0].value() {
457 *o = operation.clone();
458 self.simplify_operation_variables(o, simplify_term);
459 }
460 }
461
462 Operator::And if o.args.len() > 1 => {
465 let mut keep = o.args.iter().map(|_| true).collect::<Vec<bool>>();
467 let mut references = o.args.iter().map(|_| false).collect::<Vec<bool>>();
468 for (i, arg) in o.args.iter().enumerate() {
469 match self.maybe_bind_constraint(arg.as_expression().unwrap()) {
470 MaybeDrop::Keep => (),
471 MaybeDrop::Drop => keep[i] = false,
472 MaybeDrop::Bind(var, value) => {
473 keep[i] = false;
474 simplify_debug!("bind {:?}, {}", var, value);
475 self.bind(var, value);
476 }
477 MaybeDrop::Check(var, value) => {
478 simplify_debug!("check {}, {}", var, value);
479 for (j, arg) in o.args.iter().enumerate() {
480 if j != i && arg.contains_variable(&var) {
481 simplify_debug!("check bind {}, {} ref: {}", var, value, j);
482 self.bind(var, value);
483 keep[i] = false;
484
485 references[j] = true;
487 break;
488 }
489 }
490 }
491 }
492 }
493
494 let mut i = 0;
496 o.args.retain(|_| {
497 i += 1;
498 keep[i - 1] || references[i - 1]
499 });
500
501 for arg in &mut o.args {
503 simplify_term(self, arg);
504 }
505 }
506
507 Operator::Not => {
510 assert_eq!(o.args.len(), 1);
511 let mut simplified = o.args[0].clone();
512 let mut simplifier = self.clone();
513 simplifier.simplify_partial(&mut simplified);
514 *o = invert_operation(
515 simplified
516 .as_expression()
517 .expect("a simplified expression")
518 .clone(),
519 )
520 }
521
522 _ => {
524 for arg in &mut o.args {
525 simplify_term(self, arg);
526 }
527 }
528 }
529 }
530
531 pub fn deduplicate_operation(&mut self, o: &mut Operation, simplify_term: &TermSimplifier) {
534 fn preprocess_and(args: &mut TermList) {
535 let mut seen: HashSet<u64> = HashSet::with_capacity(args.len());
538 args.retain(|a| {
539 let o = a.as_expression().unwrap();
540 o != &TRUE && !seen.contains(&Term::from(o.mirror()).hash_value()) && seen.insert(a.hash_value()) });
544 }
545
546 if o.operator == Operator::And {
547 self.counters.preprocess_and();
548 preprocess_and(&mut o.args);
549 }
550
551 match o.operator {
552 Operator::And | Operator::Or if o.args.is_empty() => (),
553
554 Operator::And | Operator::Or if o.args.len() == 1 => {
556 if let Value::Expression(operation) = o.args[0].value() {
557 *o = operation.clone();
558 self.deduplicate_operation(o, simplify_term);
559 }
560 }
561
562 _ => {
564 for arg in &mut o.args {
565 simplify_term(self, arg);
566 }
567 }
568 }
569 }
570
571 pub fn simplify_term<F>(&mut self, term: &mut Term, simplify_operation: F)
578 where
579 F: Fn(&mut Self, &mut Operation, &TermSimplifier) + 'static + Clone,
580 {
581 if self.seen.contains(term) {
582 return;
583 }
584 let orig = term.clone();
585 self.seen.insert(term.clone());
586
587 let de = self.deref(term);
588 *term = de;
589
590 match term.mut_value() {
591 Value::Dictionary(dict) => {
592 for (_, v) in dict.fields.iter_mut() {
593 self.simplify_term(v, simplify_operation.clone());
594 }
595 }
596 Value::Call(call) => {
597 for arg in call.args.iter_mut() {
598 self.simplify_term(arg, simplify_operation.clone());
599 }
600 if let Some(kwargs) = &mut call.kwargs {
601 for (_, v) in kwargs.iter_mut() {
602 self.simplify_term(v, simplify_operation.clone());
603 }
604 }
605 }
606 Value::List(list) => {
607 for elem in list.iter_mut() {
608 self.simplify_term(elem, simplify_operation.clone());
609 }
610 }
611 Value::Expression(operation) => {
612 let so = simplify_operation.clone();
613 let cont = move |s: &mut Self, term: &mut Term| {
614 s.simplify_term(term, simplify_operation.clone())
615 };
616 so(self, operation, &cont);
617 }
618 _ => (),
619 }
620
621 if let Ok(sym) = orig.as_symbol() {
622 if term.contains_variable(sym) {
623 *term = orig.clone()
624 }
625 }
626 self.seen.remove(&orig);
627 }
628
629 pub fn simplify_partial(&mut self, term: &mut Term) {
631 let mut last = term.hash_value();
633 let mut nbindings = self.bindings.len();
634 loop {
635 simplify_debug!("simplify loop {}", term);
636 self.counters.simplify_term();
637
638 self.simplify_term(term, Simplifier::simplify_operation_variables);
639 let now = term.hash_value();
640 if last == now && self.bindings.len() == nbindings {
641 break;
642 }
643 last = now;
644 nbindings = self.bindings.len();
645 }
646
647 self.simplify_term(term, Simplifier::deduplicate_operation);
648
649 self.counters.finish_acc(term.clone());
650 }
651}
652
#[cfg(test)]
mod test {
    use super::*;

    /// Guards against committing with debug output or performance tracking switched on:
    /// both compile-time flags must be `false`.
    #[test]
    #[allow(clippy::bool_assert_comparison)]
    fn test_debug_off() {
        assert_eq!(SIMPLIFY_DEBUG, false);
        assert_eq!(TRACK_PERF, false);
    }

    /// With `x` as the output variable, the circular constraint `x = x.x` combined with
    /// an `Isa` check on `x` must come out of the simplifier unchanged.
    #[test]
    fn test_simplify_circular_dot_with_isa() {
        // Build `x = x.x and x matches X`.
        let op = term!(op!(Dot, var!("x"), str!("x")));
        let op = term!(op!(Unify, var!("x"), op));
        let op = term!(op!(
            And,
            op,
            term!(op!(Isa, var!("x"), term!(pattern!(instance!("X")))))
        ));
        let mut vs: HashSet<Symbol> = HashSet::new();
        vs.insert(sym!("x"));
        let (x, _) = simplify_partial(&sym!("x"), op, vs, false);
        // The expected term has the same structure as the input.
        assert_eq!(
            x,
            term!(op!(
                And,
                term!(op!(Unify, var!("x"), term!(op!(Dot, var!("x"), str!("x"))))),
                term!(op!(Isa, var!("x"), term!(pattern!(instance!("X")))))
            ))
        );
    }
}