use crate::Symbol;
use crate::{
util::pretty_print, Analysis, ENodeOrVar, HashMap, HashSet, Id, Language, PatternAst, RecExpr,
Rewrite, Subst, UnionFind, Var,
};
use std::fmt::{self, Debug, Display, Formatter};
use std::rc::Rc;
use symbolic_expressions::Sexp;
/// Why two explain nodes are connected by an edge in the explanation forest.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum Justification {
    /// The edge was created by the rewrite rule with this name; the name is used
    /// to look the rule up when checking proofs.
    Rule(Symbol),
    /// The two nodes were unioned because their children are equal; the proof
    /// is rebuilt by explaining each pair of children.
    Congruence,
}
/// One e-node in the explanation forest, stored in `Explain::explainfind`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct ExplainNode<L: Language> {
    /// The (uncanonical) e-node this entry describes; its children are explain-node ids.
    node: L,
    /// Parent pointer in the explanation tree; equal to `current` for a root.
    next: Id,
    /// This entry's own id (the id passed to `Explain::add`).
    current: Id,
    /// Reason for the edge from this node to `next`.
    justification: Justification,
    /// Node whose existence caused this node to be added; a node pointing at
    /// itself ends the existence chain.
    existance_node: Id,
    /// Whether the rule on the edge to `next` is applied left-to-right when
    /// going from this node to `next`.
    is_rewrite_forward: bool,
}
/// Bookkeeping that records why e-nodes were merged, so that equalities can
/// later be explained as sequences of rewrites.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct Explain<L: Language> {
    /// Explanation forest, indexed by `usize::from(id)`.
    explainfind: Vec<ExplainNode<L>>,
    /// Maps each uncanonical e-node that was added to the id it was given.
    #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
    uncanon_memo: HashMap<L, Id>,
}
/// A proof as a sequence of tree-shaped terms; consecutive terms differ by one rewrite.
/// Subterms are reference-counted so identical sub-proofs can be shared.
pub type TreeExplanation<L> = Vec<Rc<TreeTerm<L>>>;
/// A proof as a sequence of whole terms, each step carrying at most one rewrite.
pub type FlatExplanation<L> = Vec<FlatTerm<L>>;
/// Memo for `Explain::explain_adjacent`, keyed by the `(current, next)` edge.
type ExplainCache<L> = HashMap<(Id, Id), Rc<TreeTerm<L>>>;
/// An explanation of why two terms are equal (or why a term exists).
pub struct Explanation<L: Language> {
    /// The explanation in tree form, as produced by `Explain`.
    pub explanation_trees: TreeExplanation<L>,
    /// Lazily computed flat form; filled in by `make_flat_explanation`.
    flat_explanation: Option<FlatExplanation<L>>,
}
impl<L: Language + Display> Display for Explanation<L> {
    /// Writes the tree-form explanation (`get_sexp`) pretty-printed at width 100.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut buf = String::new();
        pretty_print(&mut buf, &self.get_sexp(), 100, 0).unwrap();
        f.write_str(buf.as_str())
    }
}
impl<L: Language + Display> Explanation<L> {
    /// Renders the flattened explanation, one step per line.
    pub fn get_flat_string(&mut self) -> String {
        let lines = self.get_flat_strings();
        lines.join("\n")
    }

    /// Renders the tree-form explanation (same as the `Display` output).
    pub fn get_string(&self) -> String {
        self.to_string()
    }

    /// Renders the tree-form explanation with repeated sub-proofs let-bound.
    pub fn get_string_with_let(&self) -> String {
        let mut rendered = String::new();
        pretty_print(&mut rendered, &self.get_sexp_with_let(), 100, 0).unwrap();
        rendered
    }

    /// Renders each step of the flattened explanation as its own string.
    pub fn get_flat_strings(&mut self) -> Vec<String> {
        let flat = self.make_flat_explanation();
        flat.iter().map(|step| step.to_string()).collect()
    }

    /// Builds `(Explanation t1 t2 ...)` from the tree-form explanation.
    pub fn get_sexp(&self) -> Sexp {
        let header = Sexp::String("Explanation".to_string());
        Sexp::List(
            std::iter::once(header)
                .chain(self.explanation_trees.iter().map(|tree| tree.get_sexp()))
                .collect(),
        )
    }

    /// Like `get_sexp`, but every non-leaf sub-proof reached more than once
    /// (by pointer identity) is bound once as `v_N` in nested `let` forms.
    pub fn get_sexp_with_let(&self) -> Sexp {
        // Pass 1: collect repeated sub-proofs, children before parents.
        let mut seen: HashSet<*const TreeTerm<L>> = Default::default();
        let mut repeated: Vec<Rc<TreeTerm<L>>> = Vec::new();
        for tree in &self.explanation_trees {
            self.find_to_let_bind(Rc::clone(tree), &mut seen, &mut repeated);
        }

        // Pass 2: name each repeated sub-proof once; earlier names may be used
        // inside later bodies.
        let mut names: HashMap<*const TreeTerm<L>, Sexp> = Default::default();
        let mut lets: Vec<(Sexp, Sexp)> = Default::default();
        for term in repeated {
            let key: *const TreeTerm<L> = &*term;
            if names.contains_key(&key) {
                continue;
            }
            let name = Sexp::String(format!("v_{}", lets.len()));
            let body = term.get_sexp_with_bindings(&names);
            lets.push((name.clone(), body));
            names.insert(key, name);
        }

        let mut items = vec![Sexp::String("Explanation".to_string())];
        items.extend(self.explanation_trees.iter().map(|tree| {
            let key: *const TreeTerm<L> = &**tree;
            names
                .get(&key)
                .cloned()
                .unwrap_or_else(|| tree.get_sexp_with_bindings(&names))
        }));

        // Wrap from the last binding outward, so the first binding is outermost.
        lets.into_iter()
            .rev()
            .fold(Sexp::List(items), |body, (name, value)| {
                Sexp::List(vec![
                    Sexp::String("let".to_string()),
                    Sexp::List(vec![name, value]),
                    body,
                ])
            })
    }

    /// Post-order walk that appends a non-leaf term to `to_let_bind` each time it
    /// is reached again after its first visit (identity tracked in `shared`).
    fn find_to_let_bind(
        &self,
        term: Rc<TreeTerm<L>>,
        shared: &mut HashSet<*const TreeTerm<L>>,
        to_let_bind: &mut Vec<Rc<TreeTerm<L>>>,
    ) {
        for child in term.child_proofs.iter().flatten() {
            self.find_to_let_bind(Rc::clone(child), shared, to_let_bind);
        }
        if term.child_proofs.is_empty() {
            return;
        }
        let key: *const TreeTerm<L> = &*term;
        let first_visit = shared.insert(key);
        if !first_visit {
            to_let_bind.push(term);
        }
    }

    /// Converts each step of the flattened explanation to an s-expression.
    pub fn get_flat_sexps(&mut self) -> Vec<Sexp> {
        let flat = self.make_flat_explanation();
        flat.iter().map(|step| step.get_sexp()).collect()
    }
}
impl<L: Language> Explanation<L> {
    /// Wraps a tree-form explanation; the flat form is computed on demand.
    pub fn new(explanation_trees: TreeExplanation<L>) -> Explanation<L> {
        Explanation {
            explanation_trees,
            flat_explanation: None,
        }
    }

    /// Returns the flattened explanation, computing and caching it on first use.
    pub fn make_flat_explanation(&mut self) -> &FlatExplanation<L> {
        // Borrow the trees separately so the closure does not capture all of `self`.
        let trees = &self.explanation_trees;
        self.flat_explanation
            .get_or_insert_with(|| TreeTerm::flatten_proof(trees))
    }

    /// Checks every step of the flattened proof against `rules`.
    ///
    /// The first step must carry no rewrite. Every later step must carry exactly
    /// one direction of rewrite, and that rewrite must turn the previous step into
    /// it. Steps whose rule is not in `rules` are accepted unchecked. An empty
    /// proof is accepted.
    ///
    /// # Panics
    /// Panics (via `assert!`) when any of the conditions above is violated.
    pub fn check_proof<'a, R, N: Analysis<L>>(&mut self, rules: R)
    where
        R: IntoIterator<Item = &'a Rewrite<L, N>>,
        L: 'a,
        N: 'a,
    {
        let rules: Vec<&Rewrite<L, N>> = rules.into_iter().collect();
        let rule_table = Explain::make_rule_table(rules.as_slice());
        self.make_flat_explanation();
        // Re-borrow immutably so `self.check_rewrite_at` can be called below.
        let flat_explanation = self
            .flat_explanation
            .as_ref()
            .expect("flat explanation was just computed");

        if let Some(first) = flat_explanation.first() {
            assert!(!first.has_rewrite_forward());
            assert!(!first.has_rewrite_backward());
        }

        for step in flat_explanation.windows(2) {
            let (current, next) = (&step[0], &step[1]);
            let has_forward = next.has_rewrite_forward();
            let has_backward = next.has_rewrite_backward();
            assert!(has_forward ^ has_backward);
            assert!(self.check_rewrite_at(current, next, &rule_table, has_forward));
        }
    }

    /// Finds the position where `next` carries a rewrite in the given direction
    /// and checks it against `table`. Unknown rules are accepted. Without a rule
    /// at this level, recurses pairwise into the children.
    fn check_rewrite_at<N: Analysis<L>>(
        &self,
        current: &FlatTerm<L>,
        next: &FlatTerm<L>,
        table: &HashMap<Symbol, &Rewrite<L, N>>,
        is_forward: bool,
    ) -> bool {
        if is_forward {
            if let Some(rule_name) = &next.forward_rule {
                return table
                    .get(rule_name)
                    .map_or(true, |rule| Self::check_rewrite(current, next, rule));
            }
        } else if let Some(rule_name) = &next.backward_rule {
            // A backward step is checked by applying the rule from `next` to `current`.
            return table
                .get(rule_name)
                .map_or(true, |rule| Self::check_rewrite(next, current, rule));
        }

        current
            .children
            .iter()
            .zip(next.children.iter())
            .all(|(left, right)| self.check_rewrite_at(left, right, table, is_forward))
    }

    /// Returns whether rewriting `current` with `rewrite` yields `next`.
    /// Rules without a pattern-AST searcher and applier cannot be checked and are
    /// accepted.
    fn check_rewrite<'a, N: Analysis<L>>(
        current: &'a FlatTerm<L>,
        next: &'a FlatTerm<L>,
        rewrite: &Rewrite<L, N>,
    ) -> bool {
        if let Some(lhs) = rewrite.searcher.get_pattern_ast() {
            if let Some(rhs) = rewrite.applier.get_pattern_ast() {
                if current.rewrite(lhs, rhs) != *next {
                    return false;
                }
            }
        }
        true
    }
}
/// A term inside a tree-form explanation. Each child holds its own proof, and
/// the term may be annotated with the rule that produced it.
#[derive(Debug, Clone)]
pub struct TreeTerm<L: Language> {
    /// The operator (e-node) at the root of this term.
    pub node: L,
    /// Rule applied right-to-left to reach this term, if any.
    pub backward_rule: Option<Symbol>,
    /// Rule applied left-to-right to reach this term, if any.
    pub forward_rule: Option<Symbol>,
    /// One proof per child of `node`: the sequence of terms that child passes through.
    pub child_proofs: Vec<TreeExplanation<L>>,
}
impl<L: Language> TreeTerm<L> {
    /// Creates a term with the given child proofs and no rewrite annotation.
    pub fn new(node: L, child_proofs: Vec<TreeExplanation<L>>) -> TreeTerm<L> {
        TreeTerm {
            node,
            backward_rule: None,
            forward_rule: None,
            child_proofs,
        }
    }

    /// Flattens a sequence of tree terms into one flat proof.
    ///
    /// When a tree's first flat step carries no rewrite and a previous step
    /// exists, the two denote the same term. The previous step is popped and its
    /// rewrite annotations are merged onto the new first step, so the term is not
    /// listed twice.
    fn flatten_proof(proof: &[Rc<TreeTerm<L>>]) -> FlatExplanation<L> {
        let mut flat_proof: FlatExplanation<L> = vec![];
        for tree in proof {
            let mut explanation = tree.flatten_explanation();
            if !flat_proof.is_empty()
                && !explanation[0].has_rewrite_forward()
                && !explanation[0].has_rewrite_backward()
            {
                let last = flat_proof.pop().unwrap();
                explanation[0].combine_rewrites(&last);
            }
            flat_proof.extend(explanation);
        }
        flat_proof
    }

    /// Flattens this term into a sequence of whole terms, rewriting one child at a time.
    ///
    /// The first step is `node` applied to each child's starting term. Then, for
    /// each child in order, one step is emitted per remaining step of that
    /// child's proof, with the other children held at their current
    /// representatives. This term's own rule annotations go on the first step.
    pub fn flatten_explanation(&self) -> FlatExplanation<L> {
        let mut proof = vec![];
        let mut child_proofs = vec![];
        // Current (rewrite-free) term for each child, updated as children are processed.
        let mut representative_terms = vec![];
        for child_explanation in &self.child_proofs {
            let flat_proof = TreeTerm::flatten_proof(child_explanation);
            representative_terms.push(flat_proof[0].remove_rewrites());
            child_proofs.push(flat_proof);
        }
        proof.push(FlatTerm::new(
            self.node.clone(),
            representative_terms.clone(),
        ));
        for (i, child_proof) in child_proofs.iter().enumerate() {
            // Swap the stripped representative in the latest step for the
            // annotated first step of child i, so its rewrite annotation is kept.
            proof.last_mut().unwrap().children[i] = child_proof[0].clone();
            for child in child_proof.iter().skip(1) {
                let mut children = vec![];
                for (j, rep_term) in representative_terms.iter().enumerate() {
                    if j == i {
                        children.push(child.clone());
                    } else {
                        children.push(rep_term.clone());
                    }
                }
                proof.push(FlatTerm::new(self.node.clone(), children));
            }
            // Later children see child i in its final (rewrite-free) form.
            representative_terms[i] = child_proof.last().unwrap().remove_rewrites();
        }
        proof[0].backward_rule = self.backward_rule;
        proof[0].forward_rule = self.forward_rule;
        proof
    }
}
/// A whole term in a flat explanation. At most one rewrite annotation appears
/// anywhere in the term.
#[derive(Debug, Clone, Eq)]
pub struct FlatTerm<L: Language> {
    /// The operator (e-node) at the root of this term.
    pub node: L,
    /// Rule applied right-to-left at this position to reach this step, if any.
    pub backward_rule: Option<Symbol>,
    /// Rule applied left-to-right at this position to reach this step, if any.
    pub forward_rule: Option<Symbol>,
    /// One flat term per child of `node`.
    pub children: FlatExplanation<L>,
}
impl<L: Language + Display> Display for FlatTerm<L> {
    /// Writes the compact (single-line) s-expression form of this term.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let rendered = self.get_sexp().to_string();
        f.write_str(&rendered)
    }
}
impl<L: Language> PartialEq for FlatTerm<L> {
    /// Structural equality that ignores rewrite annotations.
    ///
    /// Nodes are compared with `Language::matches`, and children are compared
    /// pairwise. `zip` stops at the shorter list, so extra children on one side
    /// are not compared.
    fn eq(&self, other: &FlatTerm<L>) -> bool {
        self.node.matches(&other.node)
            && self
                .children
                .iter()
                .zip(other.children.iter())
                .all(|(mine, theirs)| mine == theirs)
    }
}
impl<L: Language> FlatTerm<L> {
fn remove_rewrites(&self) -> FlatTerm<L> {
FlatTerm::new(
self.node.clone(),
self.children
.iter()
.map(|child| child.remove_rewrites())
.collect(),
)
}
fn combine_rewrites(&mut self, other: &FlatTerm<L>) {
if other.forward_rule.is_some() {
assert!(self.forward_rule.is_none());
self.forward_rule = other.forward_rule;
}
if other.backward_rule.is_some() {
assert!(self.backward_rule.is_none());
self.backward_rule = other.backward_rule;
}
for (left, right) in self.children.iter_mut().zip(other.children.iter()) {
left.combine_rewrites(right);
}
}
}
impl<L: Language> Default for Explain<L> {
fn default() -> Self {
Self::new()
}
}
impl<L: Language + Display> FlatTerm<L> {
    /// Converts this term to an s-expression.
    ///
    /// A leaf is just its operator string; otherwise the result is `(op child...)`.
    /// A backward rule wraps the result as `(Rewrite<= rule ...)`. A forward rule
    /// then wraps it as `(Rewrite=> rule ...)`.
    pub fn get_sexp(&self) -> Sexp {
        let head = Sexp::String(self.node.to_string());
        let mut sexp = if self.node.is_leaf() {
            head
        } else {
            Sexp::List(
                std::iter::once(head)
                    .chain(self.children.iter().map(|child| child.get_sexp()))
                    .collect(),
            )
        };

        // Order matters: backward is applied first, so forward ends up outermost.
        let annotations = [
            ("Rewrite<=", self.backward_rule),
            ("Rewrite=>", self.forward_rule),
        ];
        for &(tag, rule) in annotations.iter() {
            if let Some(name) = rule {
                sexp = Sexp::List(vec![
                    Sexp::String(tag.to_string()),
                    Sexp::String(name.to_string()),
                    sexp,
                ]);
            }
        }
        sexp
    }
}
impl<L: Language + Display> Display for TreeTerm<L> {
    /// Writes this term's s-expression pretty-printed with a width of 80.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        const WIDTH: usize = 80;
        let mut out = String::new();
        pretty_print(&mut out, &self.get_sexp(), WIDTH, 1).unwrap();
        f.write_str(&out)
    }
}
impl<L: Language + Display> TreeTerm<L> {
pub fn get_sexp(&self) -> Sexp {
self.get_sexp_with_bindings(&Default::default())
}
pub(crate) fn get_sexp_with_bindings(
&self,
bindings: &HashMap<*const TreeTerm<L>, Sexp>,
) -> Sexp {
let op = Sexp::String(self.node.to_string());
let mut expr = if self.node.is_leaf() {
op
} else {
let mut vec = vec![op];
for child in &self.child_proofs {
assert!(!child.is_empty());
if child.len() == 1 {
if let Some(existing) = bindings.get(&(&*child[0] as *const TreeTerm<L>)) {
vec.push(existing.clone());
} else {
vec.push(child[0].get_sexp_with_bindings(bindings));
}
} else {
let mut child_expressions = vec![Sexp::String("Explanation".to_string())];
for child_explanation in child.iter() {
if let Some(existing) =
bindings.get(&(&**child_explanation as *const TreeTerm<L>))
{
child_expressions.push(existing.clone());
} else {
child_expressions
.push(child_explanation.get_sexp_with_bindings(bindings));
}
}
vec.push(Sexp::List(child_expressions));
}
}
Sexp::List(vec)
};
if let Some(rule_name) = &self.backward_rule {
expr = Sexp::List(vec![
Sexp::String("Rewrite<=".to_string()),
Sexp::String((*rule_name).to_string()),
expr,
]);
}
if let Some(rule_name) = &self.forward_rule {
expr = Sexp::List(vec![
Sexp::String("Rewrite=>".to_string()),
Sexp::String((*rule_name).to_string()),
expr,
]);
}
expr
}
}
impl<L: Language> FlatTerm<L> {
    /// Creates a term with the given children and no rewrite annotation.
    pub fn new(node: L, children: FlatExplanation<L>) -> FlatTerm<L> {
        FlatTerm {
            node,
            backward_rule: None,
            forward_rule: None,
            children,
        }
    }

    /// Applies `lhs => rhs` at the root of this term and returns the instantiated `rhs`.
    ///
    /// The root of each pattern is its last node.
    ///
    /// # Panics
    /// Panics if this term does not match `lhs`. This includes the case where a
    /// variable that occurs more than once in `lhs` matches structurally
    /// different subterms. Also panics if `rhs` uses a variable not bound by
    /// `lhs`.
    pub fn rewrite(&self, lhs: &PatternAst<L>, rhs: &PatternAst<L>) -> FlatTerm<L> {
        let lhs_nodes = lhs.as_ref().as_ref();
        let rhs_nodes = rhs.as_ref().as_ref();
        let mut bindings = Default::default();
        self.make_bindings(lhs_nodes, lhs_nodes.len() - 1, &mut bindings);
        FlatTerm::from_pattern(rhs_nodes, rhs_nodes.len() - 1, &bindings)
    }

    /// Returns whether this term or any subterm carries a forward rule.
    pub fn has_rewrite_forward(&self) -> bool {
        self.forward_rule.is_some()
            || self
                .children
                .iter()
                .any(|child| child.has_rewrite_forward())
    }

    /// Returns whether this term or any subterm carries a backward rule.
    pub fn has_rewrite_backward(&self) -> bool {
        self.backward_rule.is_some()
            || self
                .children
                .iter()
                .any(|child| child.has_rewrite_backward())
    }

    /// Instantiates `pattern[location]`, replacing variables with clones of their bindings.
    fn from_pattern(
        pattern: &[ENodeOrVar<L>],
        location: usize,
        bindings: &HashMap<Var, &FlatTerm<L>>,
    ) -> FlatTerm<L> {
        match &pattern[location] {
            ENodeOrVar::Var(var) => (*bindings.get(var).unwrap()).clone(),
            ENodeOrVar::ENode(node) => {
                let children = node.fold(vec![], |mut acc, child| {
                    acc.push(FlatTerm::from_pattern(
                        pattern,
                        usize::from(child),
                        bindings,
                    ));
                    acc
                });
                FlatTerm::new(node.clone(), children)
            }
        }
    }

    /// Matches `pattern[location]` against this term, recording variable bindings.
    ///
    /// # Panics
    /// Panics if an operator does not match. Also panics if a repeated variable
    /// would be bound to a term that is not equal (per `PartialEq`) to its
    /// earlier binding.
    fn make_bindings<'a>(
        &'a self,
        pattern: &[ENodeOrVar<L>],
        location: usize,
        bindings: &mut HashMap<Var, &'a FlatTerm<L>>,
    ) {
        match &pattern[location] {
            ENodeOrVar::Var(var) => {
                if let Some(existing) = bindings.get(var) {
                    // Non-linear patterns such as `(- ?a ?a)` require every
                    // occurrence to match the same term. Silently overwriting the
                    // binding would let the checker accept invalid rewrites.
                    assert!(
                        *existing == self,
                        "Invalid proof: variable {:?} is bound to two different terms",
                        var
                    );
                } else {
                    bindings.insert(*var, self);
                }
            }
            ENodeOrVar::ENode(node) => {
                assert!(node.matches(&self.node));
                for (i, &child) in node.children().iter().enumerate() {
                    self.children[i].make_bindings(pattern, usize::from(child), bindings);
                }
            }
        }
    }
}
impl<L: Language> Explain<L> {
    /// Rebuilds the term rooted at `node_id` as a rewrite-free `TreeTerm`. Each
    /// child proof is a single step holding that child's term.
    fn node_to_explanation(&self, node_id: Id) -> TreeTerm<L> {
        let node = self.explainfind[usize::from(node_id)].node.clone();
        let children = node.fold(vec![], |mut sofar, child| {
            sofar.push(vec![Rc::new(self.node_to_explanation(child))]);
            sofar
        });
        TreeTerm::new(node, children)
    }

    /// Rebuilds the term rooted at `node_id` as a rewrite-free `FlatTerm`.
    fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm<L> {
        let node = self.explainfind[usize::from(node_id)].node.clone();
        let children = node.fold(vec![], |mut sofar, child| {
            sofar.push(self.node_to_flat_explanation(child));
            sofar
        });
        FlatTerm::new(node, children)
    }

    /// Indexes `rules` by name. A later rule with a duplicate name replaces the earlier one.
    fn make_rule_table<'a, N: Analysis<L>>(
        rules: &[&'a Rewrite<L, N>],
    ) -> HashMap<Symbol, &'a Rewrite<L, N>> {
        let mut table: HashMap<Symbol, &'a Rewrite<L, N>> = Default::default();
        for r in rules {
            table.insert(r.name, r);
        }
        table
    }

    /// Checks every rule-justified edge of the explanation forest against `rules`.
    ///
    /// Returns `false` as soon as an edge's rewrite does not reproduce its
    /// endpoint. Congruence edges and rules missing from `rules` are not
    /// checked.
    ///
    /// # Panics
    /// Panics if any existence chain contains a cycle.
    pub fn check_each_explain<N: Analysis<L>>(&self, rules: &[&Rewrite<L, N>]) -> bool {
        let rule_table = Explain::make_rule_table(rules);
        for (i, explain_node) in self.explainfind.iter().enumerate() {
            // Follow the existence chain to its self-referencing end, detecting cycles.
            let mut existance = i;
            let mut seen_existance: HashSet<usize> = Default::default();
            loop {
                seen_existance.insert(existance);
                let next = usize::from(self.explainfind[existance].existance_node);
                if existance == next {
                    break;
                }
                existance = next;
                if seen_existance.contains(&existance) {
                    panic!("Cycle in existance!");
                }
            }

            if explain_node.next != Id::from(i) {
                if let Justification::Rule(rule_name) = &explain_node.justification {
                    if let Some(rule) = rule_table.get(rule_name) {
                        // Only build the terms when a rule will actually be checked.
                        let mut current_explanation = self.node_to_flat_explanation(Id::from(i));
                        let mut next_explanation =
                            self.node_to_flat_explanation(explain_node.next);
                        if !explain_node.is_rewrite_forward {
                            std::mem::swap(&mut current_explanation, &mut next_explanation);
                        }
                        if !Explanation::check_rewrite(
                            &current_explanation,
                            &next_explanation,
                            rule,
                        ) {
                            return false;
                        }
                    }
                }
            }
        }
        true
    }

    /// Creates an empty explanation store.
    pub fn new() -> Self {
        Explain {
            explainfind: vec![],
            uncanon_memo: Default::default(),
        }
    }

    /// Records that `node` exists because of `existance_node`.
    pub(crate) fn set_existance_reason(&mut self, node: Id, existance_node: Id) {
        self.explainfind[usize::from(node)].existance_node = existance_node;
    }

    /// Registers `node` under id `set` in `uncanon_memo` and appends an explain
    /// node whose `next` and `current` are `set`. Returns `set`. Presumably `set`
    /// equals the new entry's index in `explainfind`; confirm at call sites.
    pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id {
        self.uncanon_memo.insert(node.clone(), set);
        self.explainfind.push(ExplainNode {
            node,
            justification: Justification::Congruence,
            next: set,
            current: set,
            existance_node,
            is_rewrite_forward: false,
        });
        set
    }

    /// Adds a concrete expression by treating it as a variable-free pattern.
    fn add_expr(
        &mut self,
        expr: &RecExpr<L>,
        memo: &HashMap<L, Id>,
        unionfind: &mut UnionFind,
    ) -> Id {
        let nodes: Vec<ENodeOrVar<L>> = expr
            .as_ref()
            .iter()
            .map(|node| ENodeOrVar::ENode(node.clone()))
            .collect();
        let pattern = PatternAst::from(nodes);
        self.add_match(&pattern, &Default::default(), memo, unionfind)
    }

    /// Ensures every node of `pattern`, instantiated by `subst`, exists in the
    /// explanation forest. Returns the id of the pattern's root (its last node).
    ///
    /// `new_ids` tracks canonical ids; `match_ids` tracks the uncanonical ids of
    /// the nodes in the forest. A missing uncanonical node is added and unioned,
    /// by congruence, with the existing e-node it is congruent to.
    ///
    /// # Panics
    /// Panics if the canonical form of a pattern node is not in `memo`.
    pub(crate) fn add_match(
        &mut self,
        pattern: &PatternAst<L>,
        subst: &Subst,
        memo: &HashMap<L, Id>,
        unionfind: &mut UnionFind,
    ) -> Id {
        let nodes = pattern.as_ref().as_ref();
        let mut new_ids = Vec::with_capacity(nodes.len());
        let mut match_ids = Vec::with_capacity(nodes.len());
        for node in nodes {
            match node {
                ENodeOrVar::Var(var) => {
                    // `find` already yields the canonical id, so both lists use it.
                    let bottom_id = unionfind.find(subst[*var]);
                    new_ids.push(bottom_id);
                    match_ids.push(bottom_id);
                }
                ENodeOrVar::ENode(pattern_node) => {
                    let node = pattern_node
                        .clone()
                        .map_children(|i| new_ids[usize::from(i)]);
                    let new_congruent_node = pattern_node
                        .clone()
                        .map_children(|i| match_ids[usize::from(i)]);
                    if let Some(existing_id) = self.uncanon_memo.get(&new_congruent_node) {
                        new_ids.push(unionfind.find(*existing_id));
                        match_ids.push(*existing_id);
                    } else {
                        let congruent_id = *memo.get(&node).unwrap_or_else(|| {
                            panic!("Internal error! Pattern did not exist for substitution.");
                        });
                        let congruent_class = unionfind.find(congruent_id);
                        new_ids.push(congruent_class);
                        assert!(
                            node == self.explainfind[usize::from(congruent_id)]
                                .node
                                .clone()
                                .map_children(|id| unionfind.find(id))
                        );
                        let new_congruent_id =
                            self.add(new_congruent_node, unionfind.make_set(), congruent_id);
                        match_ids.push(new_congruent_id);
                        unionfind.union(congruent_class, new_congruent_id);
                        self.union(
                            new_congruent_id,
                            congruent_id,
                            Justification::Congruence,
                            false,
                        );
                    }
                }
            }
        }
        *match_ids
            .last()
            .expect("pattern must contain at least one node")
    }

    /// Re-roots `node`'s explanation tree at `node` by reversing the `next` edges
    /// on its path to the old root. Each edge's justification moves along with
    /// it, and its direction is flipped.
    fn make_leader(&mut self, node: Id) {
        let next = self.explainfind[usize::from(node)].next;
        if next != node {
            self.make_leader(next);
            self.explainfind[usize::from(next)].justification =
                self.explainfind[usize::from(node)].justification.clone();
            self.explainfind[usize::from(next)].is_rewrite_forward =
                !self.explainfind[usize::from(node)].is_rewrite_forward;
            self.explainfind[usize::from(next)].next = node;
        }
    }

    /// Adds the edge `node1 -> node2` with `justification`, oriented forward.
    /// If `new_rhs` is set, `node1` also becomes the existence reason of `node2`.
    pub(crate) fn union(
        &mut self,
        node1: Id,
        node2: Id,
        justification: Justification,
        new_rhs: bool,
    ) {
        if new_rhs {
            self.set_existance_reason(node2, node1)
        }
        self.make_leader(node1);
        self.explainfind[usize::from(node1)].next = node2;
        self.explainfind[usize::from(node1)].justification = justification;
        self.explainfind[usize::from(node1)].is_rewrite_forward = true;
    }

    /// Explains why `left` equals `right` instantiated with `subst`.
    pub(crate) fn explain_matches(
        &mut self,
        left: &RecExpr<L>,
        right: &PatternAst<L>,
        subst: &Subst,
        memo: &HashMap<L, Id>,
        unionfind: &mut UnionFind,
    ) -> Explanation<L> {
        let left_added = self.add_expr(left, memo, unionfind);
        let right_added = self.add_match(right, subst, memo, unionfind);
        let mut cache = Default::default();
        Explanation::new(self.explain_enodes(left_added, right_added, &mut cache))
    }

    /// Explains why expressions `left` and `right` are equal.
    pub(crate) fn explain_equivalence(
        &mut self,
        left: &RecExpr<L>,
        right: &RecExpr<L>,
        memo: &HashMap<L, Id>,
        unionfind: &mut UnionFind,
    ) -> Explanation<L> {
        let left_added = self.add_expr(left, memo, unionfind);
        let right_added = self.add_expr(right, memo, unionfind);
        let mut cache = Default::default();
        Explanation::new(self.explain_enodes(left_added, right_added, &mut cache))
    }

    /// Explains why expression `left` exists in the e-graph.
    pub(crate) fn explain_existance(
        &mut self,
        left: &RecExpr<L>,
        memo: &HashMap<L, Id>,
        unionfind: &mut UnionFind,
    ) -> Explanation<L> {
        let left_added = self.add_expr(left, memo, unionfind);
        let mut cache = Default::default();
        Explanation::new(self.explain_enode_existance(
            left_added,
            Rc::new(self.node_to_explanation(left_added)),
            &mut cache,
        ))
    }

    /// Explains why `left` instantiated with `subst` exists in the e-graph.
    pub(crate) fn explain_existance_pattern(
        &mut self,
        left: &PatternAst<L>,
        subst: &Subst,
        memo: &HashMap<L, Id>,
        unionfind: &mut UnionFind,
    ) -> Explanation<L> {
        let left_added = self.add_match(left, subst, memo, unionfind);
        let mut cache = Default::default();
        Explanation::new(self.explain_enode_existance(
            left_added,
            Rc::new(self.node_to_explanation(left_added)),
            &mut cache,
        ))
    }

    /// Walks both `next` chains in lockstep and returns the first node reached by both.
    ///
    /// # Panics
    /// Panics if both chains reach their roots without meeting.
    fn common_ancestor(&self, mut left: Id, mut right: Id) -> Id {
        let mut seen_left: HashSet<Id> = Default::default();
        let mut seen_right: HashSet<Id> = Default::default();
        loop {
            seen_left.insert(left);
            if seen_right.contains(&left) {
                return left;
            }
            seen_right.insert(right);
            if seen_left.contains(&right) {
                return right;
            }
            let next_left = self.explainfind[usize::from(left)].next;
            let next_right = self.explainfind[usize::from(right)].next;
            assert!(next_left != left || next_right != right);
            left = next_left;
            right = next_right;
        }
    }

    /// Collects the explain nodes on the `next` path from `node` up to, but
    /// excluding, `ancestor`.
    fn get_nodes(&self, mut node: Id, ancestor: Id) -> Vec<&ExplainNode<L>> {
        if node == ancestor {
            return vec![];
        }
        let mut nodes = vec![];
        loop {
            let next = self.explainfind[usize::from(node)].next;
            nodes.push(&self.explainfind[usize::from(node)]);
            if next == ancestor {
                return nodes;
            }
            assert!(next != node);
            node = next;
        }
    }

    /// Builds a proof that ends with `rest_of_proof` and explains why `node` exists.
    ///
    /// If `node` is its own existence reason, the proof is `node`'s term followed
    /// by `rest_of_proof`. If the existence node is adjacent via a `next` edge,
    /// one rewrite step is prepended. Otherwise `node` is a child of the
    /// existence node, and `rest_of_proof` is spliced into that child's proof.
    fn explain_enode_existance(
        &self,
        node: Id,
        rest_of_proof: Rc<TreeTerm<L>>,
        cache: &mut ExplainCache<L>,
    ) -> TreeExplanation<L> {
        let graphnode = &self.explainfind[usize::from(node)];
        let existance = graphnode.existance_node;
        let existance_node = &self.explainfind[usize::from(existance)];
        if existance == node {
            return vec![Rc::new(self.node_to_explanation(node)), rest_of_proof];
        }

        if graphnode.next == existance || existance_node.next == node {
            let direction;
            let justification;
            if graphnode.next == existance {
                direction = !graphnode.is_rewrite_forward;
                justification = &graphnode.justification;
            } else {
                direction = existance_node.is_rewrite_forward;
                justification = &existance_node.justification;
            }
            return self.explain_enode_existance(
                existance,
                self.explain_adjacent(existance, node, direction, justification, cache),
                cache,
            );
        }

        let mut new_rest_of_proof = self.node_to_explanation(existance);
        let index_of_child = existance_node
            .node
            .children()
            .iter()
            .position(|&child| child == node)
            .expect("existence node must have `node` as a child");
        new_rest_of_proof.child_proofs[index_of_child].push(rest_of_proof);
        self.explain_enode_existance(existance, Rc::new(new_rest_of_proof), cache)
    }

    /// Explains `left == right`: `left`'s term, then one step per edge on the
    /// path through their common ancestor.
    fn explain_enodes(
        &self,
        left: Id,
        right: Id,
        cache: &mut ExplainCache<L>,
    ) -> TreeExplanation<L> {
        let mut proof = vec![Rc::new(self.node_to_explanation(left))];
        let ancestor = self.common_ancestor(left, right);
        let left_nodes = self.get_nodes(left, ancestor);
        let right_nodes = self.get_nodes(right, ancestor);
        for (i, node) in left_nodes
            .iter()
            .chain(right_nodes.iter().rev())
            .enumerate()
        {
            let mut direction = node.is_rewrite_forward;
            let mut next = node.next;
            let mut current = node.current;
            // Edges on the right path are walked against their `next` orientation.
            if i >= left_nodes.len() {
                direction = !direction;
                std::mem::swap(&mut next, &mut current);
            }
            proof.push(self.explain_adjacent(current, next, direction, &node.justification, cache));
        }
        proof
    }

    /// Produces the proof step going from `current` to the adjacent `next`,
    /// memoized by `(current, next)`.
    ///
    /// A rule edge yields `next`'s term annotated with the rule in
    /// `rule_direction`. A congruence edge yields `current`'s operator with a
    /// sub-proof for each pair of children.
    fn explain_adjacent(
        &self,
        current: Id,
        next: Id,
        rule_direction: bool,
        justification: &Justification,
        cache: &mut ExplainCache<L>,
    ) -> Rc<TreeTerm<L>> {
        let fingerprint = (current, next);
        if let Some(answer) = cache.get(&fingerprint) {
            return answer.clone();
        }
        let term = match justification {
            Justification::Rule(name) => {
                let mut rewritten = self.node_to_explanation(next);
                if rule_direction {
                    rewritten.forward_rule = Some(*name);
                } else {
                    rewritten.backward_rule = Some(*name);
                }
                Rc::new(rewritten)
            }
            Justification::Congruence => {
                let current_node = &self.explainfind[usize::from(current)].node;
                let next_node = &self.explainfind[usize::from(next)].node;
                assert!(current_node.matches(next_node));
                let mut subproofs = vec![];
                for (left_child, right_child) in current_node
                    .children()
                    .iter()
                    .zip(next_node.children().iter())
                {
                    subproofs.push(self.explain_enodes(*left_child, *right_child, cache));
                }
                Rc::new(TreeTerm::new(current_node.clone(), subproofs))
            }
        };
        cache.insert(fingerprint, term.clone());
        term
    }
}