use crate::hvm::net_trees_mut;
use super::{tree_children, tree_children_mut};
use core::ops::RangeFrom;
use hvm::ast::{Net, Tree};
use std::collections::HashMap;
/// Eta-reduces `net` in place.
///
/// A constructor `(a b)` is replaced by the plain variable `a` when the other
/// occurrences of `a` and `b` are, in the same order, the two children of another
/// constructor with the same label. An all-eraser constructor `(* *)` is replaced by `*`.
///
/// Works in two passes over the net's trees, visited in the same order both times:
/// the first numbers every node in preorder and pairs up variable occurrences; the
/// second uses that numbering to find and rewrite reducible constructors.
pub fn eta_reduce_hvm_net(net: &mut Net) {
    // The walker borrows variable names out of `net`, so keep it in its own scope and
    // only let the node table escape before `net` is borrowed mutably again.
    let nodes = {
        let mut walker = Phase1::default();
        for tree in net_trees_mut(net) {
            walker.walk_tree(tree);
        }
        walker.nodes
    };

    let mut reducer = Phase2 { nodes, index: 0.. };
    for tree in net_trees_mut(net) {
        reducer.reduce_tree(tree);
    }
}
/// Classification of a single tree node, recorded by `Phase1` in preorder (one entry
/// per node across all trees of the net) and read back by `Phase2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeType {
    /// A binary constructor; the payload is its label (`0` for `Tree::Con`, `1` for `Tree::Dup`).
    Ctr(u16),
    /// A variable occurrence whose partner occurrence has been found; the payload is the
    /// signed preorder offset from this node to that partner.
    Var(isize),
    /// An eraser node.
    Era,
    /// Any other kind of node; its children are still walked, but it is never reduced itself.
    Other,
    /// A variable occurrence whose partner has not (yet) been seen by `Phase1`.
    Hole,
}
/// First pass: records one [`NodeType`] per tree node, in preorder, and links the two
/// occurrences of each variable to each other by their signed preorder offset.
#[derive(Default, Debug)]
struct Phase1<'a> {
    /// Preorder index of the first occurrence of each variable name seen so far.
    vars: HashMap<&'a str, usize>,
    /// Classification of every visited node, indexed by preorder position.
    nodes: Vec<NodeType>,
}

impl<'a> Phase1<'a> {
    /// Appends `tree` and then, recursively, each of its children to `self.nodes`.
    ///
    /// A first variable occurrence is recorded as `NodeType::Hole`. When the same name
    /// shows up again, both entries become `NodeType::Var` holding the offset to the
    /// other occurrence.
    fn walk_tree(&mut self, tree: &'a Tree) {
        let position = self.nodes.len();
        match tree {
            Tree::Con { fst, snd } | Tree::Dup { fst, snd } => {
                let label = if matches!(tree, Tree::Con { .. }) { 0 } else { 1 };
                self.nodes.push(NodeType::Ctr(label));
                self.walk_tree(fst);
                self.walk_tree(snd);
            }
            Tree::Var { nam } => match self.vars.get(&**nam).copied() {
                Some(first) => {
                    // `position` is not pushed yet, so it is strictly after `first`.
                    let distance = position as isize - first as isize;
                    self.nodes[first] = NodeType::Var(distance);
                    self.nodes.push(NodeType::Var(-distance));
                }
                None => {
                    self.vars.insert(nam, position);
                    self.nodes.push(NodeType::Hole);
                }
            },
            Tree::Era => self.nodes.push(NodeType::Era),
            _ => {
                self.nodes.push(NodeType::Other);
                for child in tree_children(tree) {
                    self.walk_tree(child);
                }
            }
        }
    }
}
/// Second pass: revisits the trees in the same preorder as `Phase1` and rewrites
/// eta-reducible constructors in place.
struct Phase2 {
    /// Node classifications produced by `Phase1`, indexed by preorder position.
    nodes: Vec<NodeType>,
    /// Hands out the preorder index of each node as it is visited.
    index: RangeFrom<usize>,
}

impl Phase2 {
    /// Reduces both children of the constructor `tree` (at preorder index `idx`), then
    /// tries to reduce `tree` itself:
    ///
    /// - `(a b)`, where the partner occurrences of `a` and `b` are (in the same order)
    ///   the children of another constructor with the same label, becomes `a`;
    /// - `(* *)` becomes `*`.
    ///
    /// Returns what the (possibly rewritten) node looks like to its parent: the reduced
    /// variable's offset, `NodeType::Era`, or the node's original classification.
    ///
    /// # Panics
    /// If `tree` is not a `Tree::Con` or `Tree::Dup`.
    fn reduce_ctr(&mut self, tree: &mut Tree, idx: usize) -> NodeType {
        let (fst, snd) = match tree {
            Tree::Con { fst, snd } | Tree::Dup { fst, snd } => (fst, snd),
            _ => unreachable!("reduce_ctr called on a non-constructor node"),
        };
        let fst_typ = self.reduce_tree(fst);
        let snd_typ = self.reduce_tree(snd);
        match (fst_typ, snd_typ) {
            // Equal offsets mean the partners of both children sit side by side in the
            // same relative layout; the mirrored constructor must then be at `idx + off`.
            (NodeType::Var(off_lft), NodeType::Var(off_rgt)) if off_lft == off_rgt => {
                // `idx + off_lft` can be negative when the partners start at the very
                // beginning of the net (e.g. `x & y ~ (x y)`). There is no constructor
                // there, so decline the reduction instead of wrapping via `as usize` and
                // indexing out of bounds. (No isize overflow: both terms are bounded by
                // `self.nodes.len()`.)
                let mirror = usize::try_from(idx as isize + off_lft)
                    .ok()
                    .and_then(|pos| self.nodes.get(pos).copied());
                if mirror == Some(self.nodes[idx]) {
                    // `fst_typ` is `Var` only for a variable or a constructor already
                    // rewritten into one, so `fst` is a `Tree::Var` here.
                    let Tree::Var { nam } = fst.as_mut() else { unreachable!() };
                    *tree = Tree::Var { nam: std::mem::take(nam) };
                    return NodeType::Var(off_lft);
                }
            }
            // A constructor whose both ports are erased is itself just an eraser.
            (NodeType::Era, NodeType::Era) => {
                *tree = Tree::Era;
                return NodeType::Era;
            }
            _ => {}
        }
        self.nodes[idx]
    }

    /// Reduces `tree` and all of its descendants in place, consuming one preorder index
    /// per visited node, and returns what the node looks like to its parent.
    fn reduce_tree(&mut self, tree: &mut Tree) -> NodeType {
        let idx = self.index.next().expect("`RangeFrom` never runs out of indices");
        match tree {
            Tree::Con { .. } | Tree::Dup { .. } => self.reduce_ctr(tree, idx),
            _ => {
                for child in tree_children_mut(tree) {
                    self.reduce_tree(child);
                }
                self.nodes[idx]
            }
        }
    }
}