polar_core/partial/simplify.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt,
4};
5
6use crate::{
7    bindings::Bindings,
8    error::RuntimeError,
9    filter::singleton,
10    folder::{fold_term, Folder},
11    terms::*,
12};
13
14use super::partial::{invert_operation, FALSE, TRUE};
15
16type Result<T> = core::result::Result<T, RuntimeError>;
17
/// Performance-counter switch for the simplifier.
///
/// Flip to `true` while investigating simplifier performance; it must stay
/// `false` in committed code (enforced by `test_debug_off`).
const TRACK_PERF: bool = false;

/// Debug-logging switch for the simplifier.
///
/// Flip to `true` to have `simplify_debug!` / `if_debug!` emit trace output on
/// stderr; it must stay `false` in committed code (enforced by `test_debug_off`).
const SIMPLIFY_DEBUG: bool = false;
24
/// Run the wrapped statements only when `SIMPLIFY_DEBUG` is enabled.
///
/// The guard is an ordinary `if` on a `const`, so the wrapped code is still
/// type-checked when debugging is disabled.
macro_rules! if_debug {
    ($($body:tt)*) => {
        if SIMPLIFY_DEBUG {
            $($body)*
        }
    };
}

/// `eprintln!`-style trace logging that only fires when `SIMPLIFY_DEBUG` is on.
macro_rules! simplify_debug {
    ($($fmt:tt)*) => {
        if_debug!(eprintln!($($fmt)*))
    };
}
38
39enum MaybeDrop {
40    Keep,
41    Drop,
42    Bind(Symbol, Term),
43    Check(Symbol, Term),
44}
45
46struct VariableSubber {
47    this_var: Symbol,
48}
49
50impl VariableSubber {
51    pub fn new(this_var: Symbol) -> Self {
52        Self { this_var }
53    }
54}
55
56impl Folder for VariableSubber {
57    fn fold_variable(&mut self, v: Symbol) -> Symbol {
58        if v == self.this_var {
59            sym!("_this")
60        } else {
61            v
62        }
63    }
64
65    fn fold_rest_variable(&mut self, v: Symbol) -> Symbol {
66        if v == self.this_var {
67            sym!("_this")
68        } else {
69            v
70        }
71    }
72}
73
74/// Substitute `sym!("_this")` for a variable in a partial.
75pub fn sub_this(this: Symbol, term: Term) -> Term {
76    if term.as_symbol().map(|s| s == &this).unwrap_or(false) {
77        return term;
78    }
79    fold_term(term, &mut VariableSubber::new(this))
80}
81
82/// Turn `_this = x` into `x` when it's ground.
83fn simplify_trivial_constraint(this: Symbol, term: Term) -> Term {
84    use {Operator::*, Value::*};
85    match term.value() {
86        Expression(o) if o.operator == Unify => {
87            let left = &o.args[0];
88            let right = &o.args[1];
89            match (left.value(), right.value()) {
90                (Variable(v), Variable(w))
91                | (Variable(v), RestVariable(w))
92                | (RestVariable(v), Variable(w))
93                | (RestVariable(v), RestVariable(w))
94                    if v == w =>
95                {
96                    TRUE.into()
97                }
98                (Variable(l), _) | (RestVariable(l), _) if l == &this && right.is_ground() => {
99                    right.clone()
100                }
101                (_, Variable(r)) | (_, RestVariable(r)) if r == &this && left.is_ground() => {
102                    left.clone()
103                }
104                _ => term,
105            }
106        }
107        _ => term,
108    }
109}
110
111pub fn simplify_partial(
112    var: &Symbol,
113    mut term: Term,
114    output_vars: HashSet<Symbol>,
115    track_performance: bool,
116) -> (Term, Option<PerfCounters>) {
117    let mut simplifier = Simplifier::new(output_vars, track_performance);
118    simplify_debug!("*** simplify partial {:?}", var);
119    simplifier.simplify_partial(&mut term);
120    term = simplify_trivial_constraint(var.clone(), term);
121    simplify_debug!("simplify partial done {:?}, {}", var, term);
122    if matches!(term.value(), Value::Expression(e) if e.operator != Operator::And) {
123        (op!(And, term).into(), simplifier.perf_counters())
124    } else {
125        (term, simplifier.perf_counters())
126    }
127}
128
129pub fn simplify_bindings(bindings: Bindings) -> Option<Bindings> {
130    simplify_bindings_opt(bindings, true)
131        .expect("unexpected error thrown by the simplifier when simplifying all bindings")
132}
133
134/// Simplify the values of the bindings to be returned to the host language.
135///
136/// - For partials, simplify the constraint expressions.
137/// - For non-partials, deep deref. TODO(ap/gj): deep deref.
138pub fn simplify_bindings_opt(bindings: Bindings, all: bool) -> Result<Option<Bindings>> {
139    let mut perf = PerfCounters::new(TRACK_PERF);
140    simplify_debug!("simplify bindings");
141
142    if_debug! {
143        eprintln!("before simplified");
144        for (k, v) in bindings.iter() {
145            eprintln!("{:?} {}", k, v);
146        }
147    }
148
149    let mut unsatisfiable = false;
150    let mut simplify_var = |bindings: &Bindings, var: &Symbol, value: &Term| match value.value() {
151        Value::Expression(o) => {
152            assert_eq!(o.operator, Operator::And);
153            let output_vars = if all {
154                singleton(var.clone())
155            } else {
156                bindings
157                    .keys()
158                    .filter(|v| !v.is_temporary_var())
159                    .cloned()
160                    .collect::<HashSet<_>>()
161            };
162
163            let (simplified, p) = simplify_partial(var, value.clone(), output_vars, TRACK_PERF);
164            if let Some(p) = p {
165                perf.merge(p);
166            }
167
168            match simplified.as_expression() {
169                Ok(o) if o == &FALSE => unsatisfiable = true,
170                _ => (),
171            }
172            simplified
173        }
174        Value::Variable(v) | Value::RestVariable(v)
175            if v.is_temporary_var()
176                && bindings.contains_key(v)
177                && matches!(
178                    bindings[v].value(),
179                    Value::Variable(_) | Value::RestVariable(_)
180                ) =>
181        {
182            bindings[v].clone()
183        }
184        _ => value.clone(),
185    };
186
187    simplify_debug!("simplify bindings {}", if all { "all" } else { "not all" });
188
189    let mut simplified_bindings = HashMap::new();
190    for (var, value) in &bindings {
191        if !var.is_temporary_var() || all {
192            let simplified = simplify_var(&bindings, var, value);
193            simplified_bindings.insert(var.clone(), simplified);
194        } else if let Value::Expression(e) = value.value() {
195            if e.variables().iter().all(|v| v.is_temporary_var()) {
196                return Err(RuntimeError::UnhandledPartial {
197                    var: var.clone(),
198                    term: value.clone(),
199                });
200            }
201        }
202    }
203
204    if unsatisfiable {
205        Ok(None)
206    } else {
207        if_debug! {
208            eprintln!("after simplified");
209            for (k, v) in simplified_bindings.iter() {
210                eprintln!("{:?} {}", k, v);
211            }
212        }
213
214        Ok(Some(simplified_bindings))
215    }
216}
217
218#[derive(Clone, Default)]
219pub struct PerfCounters {
220    enabled: bool,
221
222    // Map of number simplifier loops by term to simplify.
223    simplify_term: HashMap<Term, u64>,
224    preprocess_and: HashMap<Term, u64>,
225
226    acc_simplify_term: u64,
227    acc_preprocess_and: u64,
228}
229
230impl fmt::Display for PerfCounters {
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        writeln!(f, "perf {{")?;
233        writeln!(f, "simplify term")?;
234        for (term, ncalls) in self.simplify_term.iter() {
235            writeln!(f, "\t{}: {}", term, ncalls)?;
236        }
237
238        writeln!(f, "preprocess and")?;
239
240        for (term, ncalls) in self.preprocess_and.iter() {
241            writeln!(f, "\t{}: {}", term, ncalls)?;
242        }
243
244        writeln!(f, "}}")
245    }
246}
247
248impl PerfCounters {
249    fn new(enabled: bool) -> Self {
250        Self {
251            enabled,
252            ..Default::default()
253        }
254    }
255
256    fn preprocess_and(&mut self) {
257        if !self.enabled {
258            return;
259        }
260
261        self.acc_preprocess_and += 1;
262    }
263
264    fn simplify_term(&mut self) {
265        if !self.enabled {
266            return;
267        }
268
269        self.acc_simplify_term += 1;
270    }
271
272    fn finish_acc(&mut self, term: Term) {
273        if !self.enabled {
274            return;
275        }
276
277        self.simplify_term
278            .insert(term.clone(), self.acc_simplify_term);
279        self.preprocess_and.insert(term, self.acc_preprocess_and);
280        self.acc_preprocess_and = 0;
281        self.acc_simplify_term = 0;
282    }
283
284    fn merge(&mut self, other: PerfCounters) {
285        if !self.enabled {
286            return;
287        }
288
289        self.simplify_term.extend(other.simplify_term);
290        self.preprocess_and.extend(other.preprocess_and);
291    }
292
293    pub fn is_enabled(&self) -> bool {
294        self.enabled
295    }
296}
297
298#[derive(Clone)]
299pub struct Simplifier {
300    bindings: Bindings,
301    output_vars: HashSet<Symbol>,
302    seen: HashSet<Term>,
303
304    counters: PerfCounters,
305}
306
307type TermSimplifier = dyn Fn(&mut Simplifier, &mut Term);
308
309impl Simplifier {
310    pub fn new(output_vars: HashSet<Symbol>, track_performance: bool) -> Self {
311        Self {
312            bindings: Bindings::new(),
313            output_vars,
314            seen: HashSet::new(),
315            counters: PerfCounters::new(track_performance),
316        }
317    }
318
319    fn perf_counters(&mut self) -> Option<PerfCounters> {
320        if !self.counters.is_enabled() {
321            return None;
322        }
323
324        let mut counter = PerfCounters::new(true);
325        std::mem::swap(&mut self.counters, &mut counter);
326        Some(counter)
327    }
328
329    pub fn bind(&mut self, var: Symbol, value: Term) {
330        // We do not allow rebindings.
331        if !self.is_bound(&var) {
332            self.bindings.insert(var, self.deref(&value));
333        }
334    }
335
336    pub fn deref(&self, term: &Term) -> Term {
337        match term.value() {
338            Value::Variable(var) | Value::RestVariable(var) => {
339                self.bindings.get(var).unwrap_or(term).clone()
340            }
341            _ => term.clone(),
342        }
343    }
344
345    fn is_bound(&self, var: &Symbol) -> bool {
346        self.bindings.contains_key(var)
347    }
348
349    fn is_output(&self, t: &Term) -> bool {
350        match t.value() {
351            Value::Variable(v) | Value::RestVariable(v) => self.output_vars.contains(v),
352            _ => false,
353        }
354    }
355
356    /// Determine whether to keep, drop, bind or conditionally bind a unification operation.
357    ///
358    /// Returns:
359    /// - Keep: to indicate that the operation should not be removed
360    /// - Drop: to indicate the operation should be removed with no new bindings
361    /// - Bind(var, val) to indicate that the operation should be removed, and var should be
362    ///                  bound to val.
363    /// - Check(var, val) To indicate that the operation should be removed and var should
364    ///                   be bound to val *if* var is referenced elsewhere in the expression.
365    ///
366    /// Params:
367    ///     constraint: The constraint to consider removing from its parent.
368    fn maybe_bind_constraint(&mut self, constraint: &Operation) -> MaybeDrop {
369        match constraint.operator {
370            // X and X is always true, so drop.
371            Operator::And if constraint.args.is_empty() => MaybeDrop::Drop,
372
373            // Choose a unification to maybe drop.
374            Operator::Unify | Operator::Eq => {
375                let left = &constraint.args[0];
376                let right = &constraint.args[1];
377
378                if left == right {
379                    // The sides are exactly equal, so drop.
380                    MaybeDrop::Drop
381                } else {
382                    // Maybe bind one side to the other.
383                    match (left.value(), right.value()) {
384                        // Always keep unifications of two output variables (x = y).
385                        (Value::Variable(_), Value::Variable(_))
386                            if self.is_output(left) && self.is_output(right) =>
387                        {
388                            MaybeDrop::Keep
389                        }
390                        // Replace non-output variable l with right.
391                        (Value::Variable(l), _) if !self.is_bound(l) && !self.is_output(left) => {
392                            simplify_debug!("*** 1");
393                            MaybeDrop::Bind(l.clone(), right.clone())
394                        }
395                        // Replace non-output variable r with left.
396                        (_, Value::Variable(r)) if !self.is_bound(r) && !self.is_output(right) => {
397                            simplify_debug!("*** 2");
398                            MaybeDrop::Bind(r.clone(), left.clone())
399                        }
400                        // Replace unbound variable with ground value.
401                        (Value::Variable(var), val) if val.is_ground() && !self.is_bound(var) => {
402                            simplify_debug!("*** 3");
403                            MaybeDrop::Check(var.clone(), right.clone())
404                        }
405                        // Replace unbound variable with ground value.
406                        (val, Value::Variable(var)) if val.is_ground() && !self.is_bound(var) => {
407                            simplify_debug!("*** 4");
408                            MaybeDrop::Check(var.clone(), left.clone())
409                        }
410                        // Keep everything else.
411                        _ => MaybeDrop::Keep,
412                    }
413                }
414            }
415            _ => MaybeDrop::Keep,
416        }
417    }
418
419    /// Perform simplification of variable names in an operation by eliminating unification
420    /// operations to express an operation in terms of output variables only.
421    ///
422    /// Also inverts negation operations.
423    ///
424    /// May require multiple calls to perform all eliminiations.
425    pub fn simplify_operation_variables(
426        &mut self,
427        o: &mut Operation,
428        simplify_term: &TermSimplifier,
429    ) {
430        fn toss_trivial_unifies(args: &mut TermList) {
431            args.retain(|c| {
432                let o = c.as_expression().unwrap();
433                match o.operator {
434                    Operator::Unify | Operator::Eq => {
435                        assert_eq!(o.args.len(), 2);
436                        let left = &o.args[0];
437                        let right = &o.args[1];
438                        left != right
439                    }
440                    _ => true,
441                }
442            });
443        }
444
445        if o.operator == Operator::And || o.operator == Operator::Or {
446            toss_trivial_unifies(&mut o.args);
447        }
448
449        match o.operator {
450            // Zero-argument conjunctions & disjunctions represent constants
451            // TRUE and FALSE, respectively. We do not simplify them.
452            Operator::And | Operator::Or if o.args.is_empty() => (),
453
454            // Replace one-argument conjunctions & disjunctions with their argument.
455            Operator::And | Operator::Or if o.args.len() == 1 => {
456                if let Value::Expression(operation) = o.args[0].value() {
457                    *o = operation.clone();
458                    self.simplify_operation_variables(o, simplify_term);
459                }
460            }
461
462            // Non-trivial conjunctions. Choose unification constraints
463            // to make bindings from and throw away; fold the rest.
464            Operator::And if o.args.len() > 1 => {
465                // Compute which constraints to keep.
466                let mut keep = o.args.iter().map(|_| true).collect::<Vec<bool>>();
467                let mut references = o.args.iter().map(|_| false).collect::<Vec<bool>>();
468                for (i, arg) in o.args.iter().enumerate() {
469                    match self.maybe_bind_constraint(arg.as_expression().unwrap()) {
470                        MaybeDrop::Keep => (),
471                        MaybeDrop::Drop => keep[i] = false,
472                        MaybeDrop::Bind(var, value) => {
473                            keep[i] = false;
474                            simplify_debug!("bind {:?}, {}", var, value);
475                            self.bind(var, value);
476                        }
477                        MaybeDrop::Check(var, value) => {
478                            simplify_debug!("check {}, {}", var, value);
479                            for (j, arg) in o.args.iter().enumerate() {
480                                if j != i && arg.contains_variable(&var) {
481                                    simplify_debug!("check bind {}, {} ref: {}", var, value, j);
482                                    self.bind(var, value);
483                                    keep[i] = false;
484
485                                    // record that this term references var and must be kept.
486                                    references[j] = true;
487                                    break;
488                                }
489                            }
490                        }
491                    }
492                }
493
494                // Drop the rest.
495                let mut i = 0;
496                o.args.retain(|_| {
497                    i += 1;
498                    keep[i - 1] || references[i - 1]
499                });
500
501                // Simplify the survivors.
502                for arg in &mut o.args {
503                    simplify_term(self, arg);
504                }
505            }
506
507            // Negation. Simplify the negated term, saving & restoring the
508            // current bindings because bindings may not leak out of a negation.
509            Operator::Not => {
510                assert_eq!(o.args.len(), 1);
511                let mut simplified = o.args[0].clone();
512                let mut simplifier = self.clone();
513                simplifier.simplify_partial(&mut simplified);
514                *o = invert_operation(
515                    simplified
516                        .as_expression()
517                        .expect("a simplified expression")
518                        .clone(),
519                )
520            }
521
522            // Default case.
523            _ => {
524                for arg in &mut o.args {
525                    simplify_term(self, arg);
526                }
527            }
528        }
529    }
530
531    /// Deduplicate an operation by removing terms that are mirrors or duplicates
532    /// of other terms.
533    pub fn deduplicate_operation(&mut self, o: &mut Operation, simplify_term: &TermSimplifier) {
534        fn preprocess_and(args: &mut TermList) {
535            // HashSet of term hash values used to deduplicate. We use hash values
536            // to avoid cloning to insert terms.
537            let mut seen: HashSet<u64> = HashSet::with_capacity(args.len());
538            args.retain(|a| {
539                let o = a.as_expression().unwrap();
540                o != &TRUE // trivial
541                    && !seen.contains(&Term::from(o.mirror()).hash_value()) // reflection
542                    && seen.insert(a.hash_value()) // duplicate
543            });
544        }
545
546        if o.operator == Operator::And {
547            self.counters.preprocess_and();
548            preprocess_and(&mut o.args);
549        }
550
551        match o.operator {
552            Operator::And | Operator::Or if o.args.is_empty() => (),
553
554            // Replace one-argument conjunctions & disjunctions with their argument.
555            Operator::And | Operator::Or if o.args.len() == 1 => {
556                if let Value::Expression(operation) = o.args[0].value() {
557                    *o = operation.clone();
558                    self.deduplicate_operation(o, simplify_term);
559                }
560            }
561
562            // Default case.
563            _ => {
564                for arg in &mut o.args {
565                    simplify_term(self, arg);
566                }
567            }
568        }
569    }
570
571    /// Simplify a term `term` in place by calling the simplification
572    /// function `simplify_operation` on any Expression in that term.
573    ///
574    /// `simplify_operation` should perform simplification operations in-place
575    /// on the operation argument. To recursively simplify sub-terms in that operation,
576    /// it must call the passed TermSimplifier.
577    pub fn simplify_term<F>(&mut self, term: &mut Term, simplify_operation: F)
578    where
579        F: Fn(&mut Self, &mut Operation, &TermSimplifier) + 'static + Clone,
580    {
581        if self.seen.contains(term) {
582            return;
583        }
584        let orig = term.clone();
585        self.seen.insert(term.clone());
586
587        let de = self.deref(term);
588        *term = de;
589
590        match term.mut_value() {
591            Value::Dictionary(dict) => {
592                for (_, v) in dict.fields.iter_mut() {
593                    self.simplify_term(v, simplify_operation.clone());
594                }
595            }
596            Value::Call(call) => {
597                for arg in call.args.iter_mut() {
598                    self.simplify_term(arg, simplify_operation.clone());
599                }
600                if let Some(kwargs) = &mut call.kwargs {
601                    for (_, v) in kwargs.iter_mut() {
602                        self.simplify_term(v, simplify_operation.clone());
603                    }
604                }
605            }
606            Value::List(list) => {
607                for elem in list.iter_mut() {
608                    self.simplify_term(elem, simplify_operation.clone());
609                }
610            }
611            Value::Expression(operation) => {
612                let so = simplify_operation.clone();
613                let cont = move |s: &mut Self, term: &mut Term| {
614                    s.simplify_term(term, simplify_operation.clone())
615                };
616                so(self, operation, &cont);
617            }
618            _ => (),
619        }
620
621        if let Ok(sym) = orig.as_symbol() {
622            if term.contains_variable(sym) {
623                *term = orig.clone()
624            }
625        }
626        self.seen.remove(&orig);
627    }
628
629    /// Simplify a partial until quiescence.
630    pub fn simplify_partial(&mut self, term: &mut Term) {
631        // TODO(ap): This does not handle hash collisions.
632        let mut last = term.hash_value();
633        let mut nbindings = self.bindings.len();
634        loop {
635            simplify_debug!("simplify loop {}", term);
636            self.counters.simplify_term();
637
638            self.simplify_term(term, Simplifier::simplify_operation_variables);
639            let now = term.hash_value();
640            if last == now && self.bindings.len() == nbindings {
641                break;
642            }
643            last = now;
644            nbindings = self.bindings.len();
645        }
646
647        self.simplify_term(term, Simplifier::deduplicate_operation);
648
649        self.counters.finish_acc(term.clone());
650    }
651}
652
#[cfg(test)]
mod test {
    use super::*;

    /// Ensure that debug flags are false. Do not remove this test. It is here
    /// to ensure we don't release with debug logs or performance tracking enabled.
    #[test]
    fn test_debug_off() {
        let flags = [("SIMPLIFY_DEBUG", SIMPLIFY_DEBUG), ("TRACK_PERF", TRACK_PERF)];
        for (name, enabled) in flags.iter() {
            assert!(!*enabled, "{} must be false before release", name);
        }
    }

    /// `x = x.x and x matches X{}`: the output variable `x` must not be
    /// substituted into its own definition, so the partial comes back unchanged.
    #[test]
    fn test_simplify_circular_dot_with_isa() {
        let dot = term!(op!(Dot, var!("x"), str!("x")));
        let unify = term!(op!(Unify, var!("x"), dot));
        let isa = term!(op!(Isa, var!("x"), term!(pattern!(instance!("X")))));
        let partial = term!(op!(And, unify, isa));
        let expected = partial.clone();

        let outputs: HashSet<Symbol> = vec![sym!("x")].into_iter().collect();
        let (simplified, _) = simplify_partial(&sym!("x"), partial, outputs, false);

        assert_eq!(simplified, expected);
    }
}