use crate::{Expr, Name};
use std::collections::{HashMap, HashSet};
use super::functions::{mk_congr_theorem, TermIdx};
#[allow(dead_code)]
/// An ordered list of literal string substitutions.
///
/// Pairs are applied one after another in insertion order, so text produced
/// by an earlier pair can be rewritten again by a later one.
pub struct FlatSubstitution {
    pairs: Vec<(String, String)>,
}
#[allow(dead_code)]
impl FlatSubstitution {
    /// Creates a substitution with no pairs.
    pub fn new() -> Self {
        FlatSubstitution { pairs: vec![] }
    }
    /// Appends a `from -> to` pair; it runs after every previously added pair.
    pub fn add(&mut self, from: impl Into<String>, to: impl Into<String>) {
        let pair = (from.into(), to.into());
        self.pairs.push(pair);
    }
    /// Applies each pair in order with `str::replace` and returns the result.
    pub fn apply(&self, s: &str) -> String {
        self.pairs
            .iter()
            .fold(s.to_owned(), |acc, (from, to)| {
                acc.replace(from.as_str(), to.as_str())
            })
    }
    /// Number of registered pairs.
    pub fn len(&self) -> usize {
        self.pairs.len()
    }
    /// True when no pairs are registered.
    pub fn is_empty(&self) -> bool {
        self.pairs.is_empty()
    }
}
/// Wraps a [`CongruenceClosure`] and records [`CongrClosureStats`] for every
/// equality added through it.
pub struct InstrumentedCC {
    pub cc: CongruenceClosure,
    pub stats: CongrClosureStats,
}
impl InstrumentedCC {
    /// Creates an empty closure with zeroed statistics.
    pub fn new() -> Self {
        Self {
            cc: CongruenceClosure::new(),
            stats: CongrClosureStats::new(),
        }
    }
    /// Asserts `e1 = e2` in the underlying closure and updates the statistics.
    ///
    /// * `equalities_added` is incremented on every call.
    /// * `unions` grows by the number of equivalence classes that disappeared
    ///   during the call: the direct merge plus any merges made by congruence.
    /// * `congruences_propagated` grows by the merges beyond the direct one.
    /// * `apps_tracked` is set to the number of registered applications.
    pub fn add_equality(&mut self, e1: Expr, e2: Expr) {
        // Register both sides before taking the "before" snapshot. Otherwise
        // brand-new terms appear only in the "after" count and mask the merge.
        // For example, the very first `a = b` goes from 0 to 1 classes and was
        // previously never recorded as a union.
        self.cc.register(&e1);
        self.cc.register(&e2);
        let direct_merge = self.cc.find(&e1) != self.cc.find(&e2);
        let before = self.cc.num_classes();
        self.cc.add_equality(e1, e2);
        let after = self.cc.num_classes();
        let merges = before.saturating_sub(after);
        self.stats.unions += merges;
        self.stats.congruences_propagated += merges.saturating_sub(usize::from(direct_merge));
        self.stats.equalities_added += 1;
        self.stats.apps_tracked = self.cc.apps.len();
    }
    /// Delegates to [`CongruenceClosure::are_equal`].
    pub fn are_equal(&mut self, e1: &Expr, e2: &Expr) -> bool {
        self.cc.are_equal(e1, e2)
    }
    /// Clears the closure and resets all statistics to zero.
    pub fn reset(&mut self) {
        self.cc.clear();
        self.stats = CongrClosureStats::new();
    }
    /// Number of distinct classes among the terms known to the closure.
    pub fn num_classes(&mut self) -> usize {
        self.cc.num_classes()
    }
}
/// A simple equivalence-class store over [`Expr`] values.
///
/// Each class is an [`ENode`]. Class ids are indices into `nodes` and are
/// *not* stable: merging removes the higher-indexed class, which shifts every
/// later class down by one.
pub struct EGraph {
    nodes: Vec<ENode>,
}
impl EGraph {
    /// Creates an empty graph.
    pub fn new() -> Self {
        EGraph { nodes: vec![] }
    }
    /// Returns the id of the class containing `expr`. If no class contains
    /// it, a new singleton class is appended and its id is returned.
    pub fn add_expr(&mut self, expr: Expr) -> usize {
        if let Some(existing) = self.find_class(&expr) {
            return existing;
        }
        self.nodes.push(ENode::singleton(expr));
        self.nodes.len() - 1
    }
    /// Linear scan for the first class whose members include `expr`.
    pub fn find_class(&self, expr: &Expr) -> Option<usize> {
        self.nodes.iter().position(|node| node.contains(expr))
    }
    /// Merges classes `id1` and `id2` into the lower-indexed one.
    ///
    /// Each moved member is recorded with `proof` when one is given;
    /// otherwise it keeps the proof it already had.
    ///
    /// # Panics
    /// Panics if either id is out of range.
    pub fn merge_classes(&mut self, id1: usize, id2: usize, proof: Option<Expr>) {
        if id1 == id2 {
            return;
        }
        // Remove the higher index so the lower index still names the same class.
        let survivor = id1.min(id2);
        let absorbed = self.nodes.remove(id1.max(id2));
        let target = &mut self.nodes[survivor];
        for (member, old_proof) in absorbed.members.into_iter().zip(absorbed.proofs) {
            target.add_member(member, proof.clone().or(old_proof));
        }
    }
    /// True when both expressions are present and share a class.
    pub fn are_equal(&self, e1: &Expr, e2: &Expr) -> bool {
        matches!(
            (self.find_class(e1), self.find_class(e2)),
            (Some(left), Some(right)) if left == right
        )
    }
    /// Inserts both expressions (if needed) and merges their classes.
    pub fn add_equality(&mut self, e1: Expr, e2: Expr, proof: Option<Expr>) {
        let (left, right) = (self.add_expr(e1), self.add_expr(e2));
        if left == right {
            return;
        }
        self.merge_classes(left, right, proof);
    }
    /// Number of classes currently stored.
    pub fn num_classes(&self) -> usize {
        self.nodes.len()
    }
    /// Returns the class with the given id, if it exists.
    pub fn get_class(&self, id: usize) -> Option<&ENode> {
        self.nodes.get(id)
    }
    /// Returns the representative of the class containing `expr`, if any.
    pub fn representative(&self, expr: &Expr) -> Option<&Expr> {
        let id = self.find_class(expr)?;
        self.get_class(id).map(|node| &node.repr)
    }
    /// Removes every class.
    pub fn clear(&mut self) {
        self.nodes.clear();
    }
}
/// A proof tree for an equality derived during congruence closure.
#[derive(Debug, Clone)]
pub enum CongrProof {
    /// Reflexivity on the given expression.
    Refl(Expr),
    /// Symmetry applied to an inner proof.
    Symm(Box<CongrProof>),
    /// Transitivity chaining two proofs.
    Trans(Box<CongrProof>, Box<CongrProof>),
    /// Congruence built from two sub-proofs (presumably the function part and
    /// the argument part; confirm against the code that constructs it).
    Congr(Box<CongrProof>, Box<CongrProof>),
    /// A named hypothesis.
    Hyp(Name),
}
impl CongrProof {
    /// Height of the proof tree; leaves (`Refl`, `Hyp`) have depth 0.
    pub fn depth(&self) -> usize {
        match self {
            Self::Refl(_) | Self::Hyp(_) => 0,
            Self::Symm(inner) => inner.depth() + 1,
            Self::Trans(left, right) | Self::Congr(left, right) => {
                std::cmp::max(left.depth(), right.depth()) + 1
            }
        }
    }
    /// True only for a top-level `Refl` node.
    pub fn is_refl(&self) -> bool {
        matches!(self, Self::Refl(_))
    }
    /// Counts `Hyp` leaves in the tree, with repeats.
    pub fn hypothesis_count(&self) -> usize {
        match self {
            Self::Refl(_) => 0,
            Self::Hyp(_) => 1,
            Self::Symm(inner) => inner.hypothesis_count(),
            Self::Trans(left, right) | Self::Congr(left, right) => {
                let lhs = left.hypothesis_count();
                let rhs = right.hypothesis_count();
                lhs + rhs
            }
        }
    }
    /// Removes double symmetries (`Symm(Symm(p))` becomes `p`) throughout the tree.
    pub fn simplify(self) -> Self {
        match self {
            Self::Symm(boxed) => match *boxed {
                Self::Symm(doubly_flipped) => doubly_flipped.simplify(),
                single => Self::Symm(Box::new(single.simplify())),
            },
            Self::Trans(left, right) => {
                let (l, r) = (left.simplify(), right.simplify());
                Self::Trans(Box::new(l), Box::new(r))
            }
            Self::Congr(left, right) => {
                let (l, r) = (left.simplify(), right.simplify());
                Self::Congr(Box::new(l), Box::new(r))
            }
            Self::Refl(e) => Self::Refl(e),
            Self::Hyp(n) => Self::Hyp(n),
        }
    }
}
#[allow(dead_code)]
/// Cursor state over a borrowed slice.
///
/// Holds the slice, a position that starts at 0, and a window size.
pub struct WindowIterator<'a, T> {
    pub(super) data: &'a [T],
    pub(super) pos: usize,
    pub(super) window: usize,
}
#[allow(dead_code)]
impl<'a, T> WindowIterator<'a, T> {
    /// Creates a cursor at position 0 over `data` with the given window size.
    pub fn new(data: &'a [T], window: usize) -> Self {
        WindowIterator {
            window,
            pos: 0,
            data,
        }
    }
}
#[allow(dead_code)]
#[allow(missing_docs)]
/// A named textual rewrite `lhs → rhs`, optionally marked as conditional.
pub struct RewriteRule {
    pub name: String,
    pub lhs: String,
    pub rhs: String,
    pub conditional: bool,
}
#[allow(dead_code)]
impl RewriteRule {
    /// Builds a rule with `conditional == false`.
    pub fn unconditional(
        name: impl Into<String>,
        lhs: impl Into<String>,
        rhs: impl Into<String>,
    ) -> Self {
        RewriteRule {
            name: name.into(),
            lhs: lhs.into(),
            rhs: rhs.into(),
            conditional: false,
        }
    }
    /// Builds a rule with `conditional == true`.
    pub fn conditional(
        name: impl Into<String>,
        lhs: impl Into<String>,
        rhs: impl Into<String>,
    ) -> Self {
        RewriteRule {
            conditional: true,
            ..Self::unconditional(name, lhs, rhs)
        }
    }
    /// Renders the rule as `name: lhs → rhs`.
    pub fn display(&self) -> String {
        let (name, lhs, rhs) = (&self.name, &self.lhs, &self.rhs);
        format!("{}: {} → {}", name, lhs, rhs)
    }
}
#[allow(dead_code)]
/// A character trie that counts how many recorded strings pass through each prefix.
pub struct PrefixCounter {
    children: std::collections::HashMap<char, PrefixCounter>,
    /// Number of recorded strings that have this node's path as a prefix.
    count: usize,
}
#[allow(dead_code)]
impl PrefixCounter {
    /// Creates an empty counter; the root counts every recorded string.
    pub fn new() -> Self {
        Self {
            children: std::collections::HashMap::new(),
            count: 0,
        }
    }
    /// Records `s`, incrementing the count of the root and of every prefix node.
    pub fn record(&mut self, s: &str) {
        self.count += 1;
        let mut node = self;
        for c in s.chars() {
            // `or_default()` would require `PrefixCounter: Default`, which this
            // type does not implement; construct children explicitly instead.
            node = node
                .children
                .entry(c)
                .or_insert_with(PrefixCounter::new);
            node.count += 1;
        }
    }
    /// Number of recorded strings starting with `prefix`; 0 if none do.
    /// The empty prefix matches every recorded string.
    pub fn count_with_prefix(&self, prefix: &str) -> usize {
        prefix
            .chars()
            .try_fold(self, |node, c| node.children.get(&c))
            .map_or(0, |node| node.count)
    }
}
#[allow(dead_code)]
/// A vector that always holds at least one element (`head`).
pub struct NonEmptyVec<T> {
    head: T,
    tail: Vec<T>,
}
#[allow(dead_code)]
impl<T> NonEmptyVec<T> {
    /// Creates a vector containing only `val`.
    pub fn singleton(val: T) -> Self {
        NonEmptyVec {
            tail: vec![],
            head: val,
        }
    }
    /// Appends `val` at the end.
    pub fn push(&mut self, val: T) {
        self.tail.push(val)
    }
    /// The first element; always present.
    pub fn first(&self) -> &T {
        &self.head
    }
    /// The last element: the last pushed value, or `head` if nothing was pushed.
    pub fn last(&self) -> &T {
        match self.tail.last() {
            Some(newest) => newest,
            None => &self.head,
        }
    }
    /// Total number of elements (at least 1).
    pub fn len(&self) -> usize {
        self.tail.len() + 1
    }
    /// Always false: the head element guarantees non-emptiness.
    pub fn is_empty(&self) -> bool {
        false
    }
    /// Borrows every element in order, head first.
    pub fn to_vec(&self) -> Vec<&T> {
        std::iter::once(&self.head).chain(self.tail.iter()).collect()
    }
}
#[allow(dead_code)]
/// Running sum and mean over the most recent `capacity` values.
///
/// Values are stored in a ring buffer. The sum is updated incrementally by
/// subtracting the value being overwritten and adding the new one.
pub struct SlidingSum {
    window: Vec<f64>,
    capacity: usize,
    /// Next ring-buffer slot to overwrite.
    pos: usize,
    sum: f64,
    /// Number of values currently in the window (saturates at `capacity`).
    count: usize,
}
#[allow(dead_code)]
impl SlidingSum {
    /// Creates an empty window holding at most `capacity` values.
    ///
    /// A zero capacity is allowed; such a window ignores every pushed value.
    pub fn new(capacity: usize) -> Self {
        Self {
            window: vec![0.0; capacity],
            capacity,
            pos: 0,
            sum: 0.0,
            count: 0,
        }
    }
    /// Adds `val`, evicting the oldest value once the window is full.
    pub fn push(&mut self, val: f64) {
        // A zero-capacity window can hold nothing. Without this guard the
        // indexing below panics and `% self.capacity` divides by zero.
        if self.capacity == 0 {
            return;
        }
        let oldest = self.window[self.pos];
        self.sum -= oldest;
        self.sum += val;
        self.window[self.pos] = val;
        self.pos = (self.pos + 1) % self.capacity;
        if self.count < self.capacity {
            self.count += 1;
        }
    }
    /// Sum of the values currently in the window.
    pub fn sum(&self) -> f64 {
        self.sum
    }
    /// Mean of the values in the window, or `None` if the window is empty.
    pub fn mean(&self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.sum / self.count as f64)
        }
    }
    /// Number of values currently in the window.
    pub fn count(&self) -> usize {
        self.count
    }
}
#[allow(dead_code)]
/// An insertion-ordered collection of [`RewriteRule`]s.
pub struct RewriteRuleSet {
    rules: Vec<RewriteRule>,
}
#[allow(dead_code)]
impl RewriteRuleSet {
    /// Creates an empty rule set.
    pub fn new() -> Self {
        RewriteRuleSet { rules: vec![] }
    }
    /// Appends `rule`. Duplicate names are allowed.
    pub fn add(&mut self, rule: RewriteRule) {
        self.rules.push(rule)
    }
    /// Number of rules stored.
    pub fn len(&self) -> usize {
        self.rules.len()
    }
    /// True when no rules are stored.
    pub fn is_empty(&self) -> bool {
        self.rules.is_empty()
    }
    /// Rules whose `conditional` flag is set, in insertion order.
    pub fn conditional_rules(&self) -> Vec<&RewriteRule> {
        self.rules
            .iter()
            .filter(|rule| rule.conditional)
            .collect()
    }
    /// Rules whose `conditional` flag is clear, in insertion order.
    pub fn unconditional_rules(&self) -> Vec<&RewriteRule> {
        self.rules
            .iter()
            .filter(|rule| !rule.conditional)
            .collect()
    }
    /// The first rule named `name`, if any.
    pub fn get(&self, name: &str) -> Option<&RewriteRule> {
        self.rules.iter().find(|rule| rule.name == name)
    }
}
#[allow(dead_code)]
/// A stack of path components rendered with `/` separators.
///
/// This is unrelated to `std::path::PathBuf`; components are plain strings.
pub struct PathBuf {
    components: Vec<String>,
}
#[allow(dead_code)]
impl PathBuf {
    /// Creates an empty path.
    pub fn new() -> Self {
        PathBuf { components: vec![] }
    }
    /// Appends one component.
    pub fn push(&mut self, comp: impl Into<String>) {
        let component: String = comp.into();
        self.components.push(component);
    }
    /// Removes the last component; does nothing when empty.
    pub fn pop(&mut self) {
        let _ = self.components.pop();
    }
    /// Joins the components with `/` (no leading or trailing separator).
    pub fn as_str(&self) -> String {
        self.components.join("/")
    }
    /// Number of components.
    pub fn depth(&self) -> usize {
        self.components.len()
    }
    /// Removes every component.
    pub fn clear(&mut self) {
        self.components.clear()
    }
}
#[allow(dead_code)]
/// A node in a configuration tree.
///
/// A node is either a leaf with a value or a section (no value) that holds
/// child nodes.
pub struct ConfigNode {
    key: String,
    value: Option<String>,
    children: Vec<ConfigNode>,
}
#[allow(dead_code)]
impl ConfigNode {
    /// Creates a leaf node holding `value`.
    pub fn leaf(key: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            value: Some(value.into()),
            children: Vec::new(),
        }
    }
    /// Creates a section node with no value and no children.
    pub fn section(key: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            value: None,
            children: Vec::new(),
        }
    }
    /// Appends a child node.
    pub fn add_child(&mut self, child: ConfigNode) {
        self.children.push(child);
    }
    /// This node's key.
    pub fn key(&self) -> &str {
        &self.key
    }
    /// This node's value, if it is a leaf.
    pub fn value(&self) -> Option<&str> {
        self.value.as_deref()
    }
    /// Number of direct children.
    pub fn num_children(&self) -> usize {
        self.children.len()
    }
    /// Resolves a dotted path such as `root.db.host`, starting at this node.
    ///
    /// The first path segment must equal this node's key. Each following
    /// segment selects the first child that resolves the rest of the path.
    /// Returns the value of the final node, or `None` if a segment does not
    /// match or the final node has no value.
    pub fn lookup(&self, path: &str) -> Option<&str> {
        // The root is matched exactly like any descendant, so share one walker
        // instead of keeping two copies of the same logic.
        self.lookup_relative(path)
    }
    /// Recursive worker for [`lookup`](Self::lookup).
    fn lookup_relative(&self, path: &str) -> Option<&str> {
        let (head, tail) = match path.split_once('.') {
            Some((head, rest)) => (head, Some(rest)),
            None => (path, None),
        };
        if head != self.key {
            return None;
        }
        match tail {
            None => self.value.as_deref(),
            Some(rest) => self
                .children
                .iter()
                .find_map(|child| child.lookup_relative(rest)),
        }
    }
}
#[allow(dead_code)]
/// A token-bucket rate limiter refilled lazily from wall-clock time.
pub struct TokenBucket {
    capacity: u64,
    tokens: u64,
    /// Tokens added per elapsed whole millisecond.
    refill_per_ms: u64,
    /// Time up to which refills have already been credited.
    last_refill: std::time::Instant,
}
#[allow(dead_code)]
impl TokenBucket {
    /// Creates a full bucket.
    pub fn new(capacity: u64, refill_per_ms: u64) -> Self {
        Self {
            capacity,
            tokens: capacity,
            refill_per_ms,
            last_refill: std::time::Instant::now(),
        }
    }
    /// Refills, then takes `n` tokens if that many are available.
    ///
    /// Returns `true` on success. On failure no tokens are taken.
    pub fn try_consume(&mut self, n: u64) -> bool {
        self.refill();
        if self.tokens >= n {
            self.tokens -= n;
            true
        } else {
            false
        }
    }
    /// Credits tokens for every whole millisecond since `last_refill`.
    fn refill(&mut self) {
        let now = std::time::Instant::now();
        // Saturate instead of truncating the u128 millisecond count.
        let elapsed_ms = now
            .duration_since(self.last_refill)
            .as_millis()
            .min(u128::from(u64::MAX)) as u64;
        if elapsed_ms == 0 {
            return;
        }
        // Saturating arithmetic: a large rate or a long idle period must clamp
        // to capacity rather than overflow (which panics in debug builds).
        let new_tokens = elapsed_ms.saturating_mul(self.refill_per_ms);
        self.tokens = self.tokens.saturating_add(new_tokens).min(self.capacity);
        // Advance by whole milliseconds only, so the sub-millisecond remainder
        // carries into the next refill. Setting `last_refill = now` silently
        // dropped it and under-counted refills over time.
        self.last_refill = self
            .last_refill
            .checked_add(std::time::Duration::from_millis(elapsed_ms))
            .unwrap_or(now);
    }
    /// Token count as of the last refill. This does not trigger a refill.
    pub fn available(&self) -> u64 {
        self.tokens
    }
    /// Maximum number of tokens the bucket can hold.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }
}
#[allow(dead_code)]
/// A value with its full update history; version 0 is the initial value.
pub struct VersionedRecord<T: Clone> {
    history: Vec<T>,
}
#[allow(dead_code)]
impl<T: Clone> VersionedRecord<T> {
    /// Starts the history with `initial` as version 0.
    pub fn new(initial: T) -> Self {
        VersionedRecord {
            history: vec![initial],
        }
    }
    /// Records `val` as the next version.
    pub fn update(&mut self, val: T) {
        self.history.push(val)
    }
    /// The most recent version.
    pub fn current(&self) -> &T {
        let latest = self.history.last();
        latest.expect("VersionedRecord history is always non-empty after construction")
    }
    /// The value at version `n`, if it exists.
    pub fn at_version(&self, n: usize) -> Option<&T> {
        self.history.get(n)
    }
    /// Index of the current version (0 before any update).
    pub fn version(&self) -> usize {
        let total = self.history.len();
        total - 1
    }
    /// True once at least one update has been recorded.
    pub fn has_history(&self) -> bool {
        self.history.len() >= 2
    }
}
#[allow(dead_code)]
/// Paired [`StatSummary`] accumulators for values observed before and after a transform.
pub struct TransformStat {
    before: StatSummary,
    after: StatSummary,
}
#[allow(dead_code)]
impl TransformStat {
    /// Creates two empty summaries.
    pub fn new() -> Self {
        TransformStat {
            after: StatSummary::new(),
            before: StatSummary::new(),
        }
    }
    /// Adds a sample to the "before" summary.
    pub fn record_before(&mut self, v: f64) {
        self.before.record(v)
    }
    /// Adds a sample to the "after" summary.
    pub fn record_after(&mut self, v: f64) {
        self.after.record(v)
    }
    /// Ratio `mean(after) / mean(before)`.
    ///
    /// Returns `None` if either side has no samples, or if the "before" mean
    /// is within `f64::EPSILON` of zero.
    pub fn mean_ratio(&self) -> Option<f64> {
        let before = self.before.mean()?;
        let after = self.after.mean()?;
        if before.abs() < f64::EPSILON {
            None
        } else {
            Some(after / before)
        }
    }
}
/// Memoizes [`CongruenceTheorem`]s by `(function name, argument count)`.
#[derive(Debug, Default)]
pub struct CongrLemmaCache {
    entries: std::collections::HashMap<(Name, usize), CongruenceTheorem>,
}
impl CongrLemmaCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self::default()
    }
    /// Looks up a cached theorem without computing it.
    pub fn get(&self, name: &Name, num_args: usize) -> Option<&CongruenceTheorem> {
        let key = (name.clone(), num_args);
        self.entries.get(&key)
    }
    /// Stores `thm` under its own `(fn_name, num_args)`, replacing any previous entry.
    pub fn insert(&mut self, thm: CongruenceTheorem) {
        self.entries
            .insert((thm.fn_name.clone(), thm.num_args), thm);
    }
    /// Returns the cached theorem, computing it with `mk_congr_theorem` on a miss.
    pub fn get_or_compute(&mut self, name: Name, num_args: usize) -> &CongruenceTheorem {
        let key = (name.clone(), num_args);
        self.entries
            .entry(key)
            .or_insert_with(move || mk_congr_theorem(name, num_args))
    }
    /// Number of cached theorems.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// Drops every cached theorem.
    pub fn clear(&mut self) {
        self.entries.clear()
    }
    /// True when nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
#[allow(dead_code)]
/// Streaming count / sum / min / max accumulator for `f64` samples.
pub struct StatSummary {
    count: u64,
    sum: f64,
    min: f64,
    max: f64,
}
#[allow(dead_code)]
impl StatSummary {
    /// Creates an empty summary. `min`/`max` start at +∞ / −∞.
    pub fn new() -> Self {
        StatSummary {
            min: f64::INFINITY,
            max: f64::NEG_INFINITY,
            sum: 0.0,
            count: 0,
        }
    }
    /// Adds one sample. NaN samples still count toward `count` and `sum`,
    /// but never replace `min` or `max`, because the comparisons are false.
    pub fn record(&mut self, val: f64) {
        self.sum += val;
        self.count += 1;
        if val > self.max {
            self.max = val;
        }
        if val < self.min {
            self.min = val;
        }
    }
    /// Arithmetic mean, or `None` when no samples were recorded.
    pub fn mean(&self) -> Option<f64> {
        (self.count != 0).then(|| self.sum / self.count as f64)
    }
    /// Smallest sample, or `None` when empty.
    pub fn min(&self) -> Option<f64> {
        (self.count != 0).then(|| self.min)
    }
    /// Largest sample, or `None` when empty.
    pub fn max(&self) -> Option<f64> {
        (self.count != 0).then(|| self.max)
    }
    /// Number of samples recorded.
    pub fn count(&self) -> u64 {
        self.count
    }
}
#[allow(dead_code)]
/// Measures elapsed wall-clock time and records split points in milliseconds.
pub struct Stopwatch {
    start: std::time::Instant,
    splits: Vec<f64>,
}
#[allow(dead_code)]
impl Stopwatch {
    /// Starts timing now, with no splits recorded.
    pub fn start() -> Self {
        Stopwatch {
            splits: vec![],
            start: std::time::Instant::now(),
        }
    }
    /// Records the current elapsed time as a split.
    pub fn split(&mut self) {
        let at = self.elapsed_ms();
        self.splits.push(at);
    }
    /// Milliseconds since `start()`, with sub-millisecond precision.
    pub fn elapsed_ms(&self) -> f64 {
        let elapsed = self.start.elapsed();
        elapsed.as_secs_f64() * 1000.0
    }
    /// All recorded splits, oldest first.
    pub fn splits(&self) -> &[f64] {
        self.splits.as_slice()
    }
    /// Number of recorded splits.
    pub fn num_splits(&self) -> usize {
        self.splits.len()
    }
}
#[allow(dead_code)]
/// A value of one of two types.
pub enum Either2<A, B> {
    First(A),
    Second(B),
}
#[allow(dead_code)]
impl<A, B> Either2<A, B> {
    /// True for the `First` variant.
    pub fn is_first(&self) -> bool {
        !self.is_second()
    }
    /// True for the `Second` variant.
    pub fn is_second(&self) -> bool {
        matches!(self, Either2::Second(_))
    }
    /// Extracts the `First` payload, discarding a `Second`.
    pub fn first(self) -> Option<A> {
        if let Either2::First(a) = self {
            Some(a)
        } else {
            None
        }
    }
    /// Extracts the `Second` payload, discarding a `First`.
    pub fn second(self) -> Option<B> {
        if let Either2::Second(b) = self {
            Some(b)
        } else {
            None
        }
    }
    /// Applies `f` to a `First` payload; a `Second` passes through unchanged.
    pub fn map_first<C, F: FnOnce(A) -> C>(self, f: F) -> Either2<C, B> {
        match self {
            Either2::Second(b) => Either2::Second(b),
            Either2::First(a) => Either2::First(f(a)),
        }
    }
}
#[allow(dead_code)]
/// A `Cell` that accepts only its first write; later writes are rejected.
pub struct WriteOnce<T> {
    value: std::cell::Cell<Option<T>>,
}
#[allow(dead_code)]
impl<T: Copy> WriteOnce<T> {
    /// Creates an unwritten cell.
    pub fn new() -> Self {
        WriteOnce {
            value: std::cell::Cell::new(None),
        }
    }
    /// Stores `val` if nothing has been written yet.
    /// Returns whether the write happened.
    pub fn write(&self, val: T) -> bool {
        match self.value.get() {
            Some(_) => false,
            None => {
                self.value.set(Some(val));
                true
            }
        }
    }
    /// The stored value, if written.
    pub fn read(&self) -> Option<T> {
        self.value.get()
    }
    /// True once a value has been stored.
    pub fn is_written(&self) -> bool {
        self.read().is_some()
    }
}
#[allow(dead_code)]
#[allow(missing_docs)]
/// A binary decision tree.
///
/// Each branch tests whether `ctx[key] == val`; each leaf names an action.
pub enum DecisionNode {
    Leaf(String),
    Branch {
        key: String,
        val: String,
        yes_branch: Box<DecisionNode>,
        no_branch: Box<DecisionNode>,
    },
}
#[allow(dead_code)]
impl DecisionNode {
    /// Walks the tree against `ctx` and returns the chosen leaf's action.
    ///
    /// A key missing from `ctx` is treated as the empty string.
    pub fn evaluate(&self, ctx: &std::collections::HashMap<String, String>) -> &str {
        let mut node = self;
        loop {
            match node {
                DecisionNode::Leaf(action) => return action.as_str(),
                DecisionNode::Branch {
                    key,
                    val,
                    yes_branch,
                    no_branch,
                } => {
                    let actual = ctx.get(key).map_or("", String::as_str);
                    node = if actual == val.as_str() {
                        &**yes_branch
                    } else {
                        &**no_branch
                    };
                }
            }
        }
    }
    /// Longest number of branch nodes on any root-to-leaf path (a leaf has depth 0).
    pub fn depth(&self) -> usize {
        match self {
            DecisionNode::Branch {
                yes_branch,
                no_branch,
                ..
            } => std::cmp::max(yes_branch.depth(), no_branch.depth()) + 1,
            DecisionNode::Leaf(_) => 0,
        }
    }
}
#[allow(dead_code)]
/// Reachability queries over a directed graph on nodes `0..n`.
pub struct TransitiveClosure {
    adj: Vec<Vec<usize>>,
    n: usize,
}
#[allow(dead_code)]
impl TransitiveClosure {
    /// Creates a graph with `n` nodes and no edges.
    pub fn new(n: usize) -> Self {
        TransitiveClosure {
            n,
            adj: vec![Vec::new(); n],
        }
    }
    /// Adds `from -> to`. Edges whose source is out of range are ignored.
    /// Out-of-range targets are stored but never visited.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        if let Some(out_edges) = self.adj.get_mut(from) {
            out_edges.push(to);
        }
    }
    /// Sorted list of in-range nodes reachable from `start`, including
    /// `start` itself. The list is empty if `start` is out of range.
    pub fn reachable_from(&self, start: usize) -> Vec<usize> {
        let mut seen = vec![false; self.n];
        let mut frontier = vec![start];
        while let Some(node) = frontier.pop() {
            if node >= self.n || seen[node] {
                continue;
            }
            seen[node] = true;
            frontier.extend(self.adj[node].iter().copied());
        }
        seen.iter()
            .enumerate()
            .filter(|&(_, &hit)| hit)
            .map(|(idx, _)| idx)
            .collect()
    }
    /// True if `to` is in `reachable_from(from)`.
    pub fn can_reach(&self, from: usize, to: usize) -> bool {
        // `reachable_from` returns an ascending list, so binary search is valid.
        self.reachable_from(from).binary_search(&to).is_ok()
    }
}
#[allow(dead_code)]
/// A directed graph on nodes `0..n`, stored as adjacency lists.
///
/// Acyclicity is not enforced; `topological_sort` reports a cycle by
/// returning `None`.
pub struct SimpleDag {
    edges: Vec<Vec<usize>>,
}
#[allow(dead_code)]
impl SimpleDag {
    /// Creates `n` nodes with no edges.
    pub fn new(n: usize) -> Self {
        Self {
            edges: vec![Vec::new(); n],
        }
    }
    /// Adds `from -> to`. Edges whose source is out of range are ignored;
    /// out-of-range targets are stored but never traversed.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        if from < self.edges.len() {
            self.edges[from].push(to);
        }
    }
    /// Direct successors of `node` (empty when `node` is out of range).
    pub fn successors(&self, node: usize) -> &[usize] {
        self.edges.get(node).map(|v| v.as_slice()).unwrap_or(&[])
    }
    /// True if `to` is reachable from `from`. A node always reaches itself.
    pub fn can_reach(&self, from: usize, to: usize) -> bool {
        let mut visited = vec![false; self.edges.len()];
        self.dfs(from, target_or(to), &mut visited)
    }
    /// Depth-first search from `cur` for `target`, marking nodes in `visited`.
    ///
    /// Uses an explicit stack rather than recursion, so long paths (such as
    /// chains of hundreds of thousands of nodes) cannot overflow the call stack.
    fn dfs(&self, cur: usize, target: usize, visited: &mut [bool]) -> bool {
        let mut stack = vec![cur];
        while let Some(node) = stack.pop() {
            if node == target {
                return true;
            }
            if node >= visited.len() || visited[node] {
                continue;
            }
            visited[node] = true;
            stack.extend_from_slice(self.successors(node));
        }
        false
    }
    /// Kahn's algorithm.
    ///
    /// Returns a topological order of all nodes, or `None` if the graph has a
    /// cycle. Out-of-range edge targets are ignored.
    pub fn topological_sort(&self) -> Option<Vec<usize>> {
        let n = self.edges.len();
        let mut in_degree = vec![0usize; n];
        for succs in &self.edges {
            for &s in succs {
                if s < n {
                    in_degree[s] += 1;
                }
            }
        }
        let mut queue: std::collections::VecDeque<usize> =
            (0..n).filter(|&i| in_degree[i] == 0).collect();
        let mut order = Vec::with_capacity(n);
        while let Some(node) = queue.pop_front() {
            order.push(node);
            for &s in self.successors(node) {
                if s < n {
                    in_degree[s] -= 1;
                    if in_degree[s] == 0 {
                        queue.push_back(s);
                    }
                }
            }
        }
        if order.len() == n {
            Some(order)
        } else {
            None
        }
    }
    /// Number of nodes.
    pub fn num_nodes(&self) -> usize {
        self.edges.len()
    }
}
/// Identity helper that keeps `can_reach` readable; returns `to` unchanged.
#[inline]
fn target_or(to: usize) -> usize {
    to
}
/// How a single argument position is treated by a [`CongruenceTheorem`].
///
/// `Eq` and `HEq` positions are the ones that contribute an equality
/// hypothesis (see `CongruenceTheorem::num_eq_hypotheses`). The other variants
/// presumably mark fixed, cast-adjusted, and subsingleton positions; confirm
/// against the code that builds these theorems.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CongrArgKind {
    Fixed,
    Eq,
    HEq,
    Cast,
    Subsingle,
}
/// Union-find based congruence closure over [`Expr`] terms.
///
/// Asserted equalities are merged with union-by-rank and path compression.
/// Every application `Expr::App(f, a)` registered through `add_equality` is
/// tracked. Whenever two tracked applications have equal heads and equal
/// arguments, they are merged too, repeatedly until no new congruences remain.
pub struct CongruenceClosure {
    /// Union-find parent pointers; a class root maps to itself.
    parent: HashMap<Expr, Expr>,
    /// Union-by-rank heuristic; only meaningful for roots.
    rank: HashMap<Expr, u32>,
    /// Equalities queued for merging.
    pending: Vec<(Expr, Expr)>,
    /// Every registered application `Expr::App(f, a)`, stored as `(f, a)`.
    apps: HashSet<(Expr, Expr)>,
    /// Proof terms recorded by `add_equality_with_proof`, keyed by the pair as given.
    proofs: HashMap<(Expr, Expr), Expr>,
}
impl CongruenceClosure {
    /// Creates an empty closure.
    pub fn new() -> Self {
        Self {
            parent: HashMap::new(),
            rank: HashMap::new(),
            pending: Vec::new(),
            apps: HashSet::new(),
            proofs: HashMap::new(),
        }
    }
    /// Returns the representative of `expr`'s class and compresses the path to it.
    ///
    /// An `expr` never seen before becomes a fresh singleton class. Unlike
    /// `add_equality`, this does not register its subterms or its application.
    pub fn find(&mut self, expr: &Expr) -> Expr {
        let parent = match self.parent.get(expr).cloned() {
            Some(parent) => parent,
            None => {
                self.parent.insert(expr.clone(), expr.clone());
                self.rank.insert(expr.clone(), 0);
                return expr.clone();
            }
        };
        if &parent == expr {
            return parent;
        }
        let root = self.find(&parent);
        // Path compression: point `expr` straight at its root.
        self.parent.insert(expr.clone(), root.clone());
        root
    }
    /// Merges the classes of `e1` and `e2` by rank. Does nothing if they are already equal.
    fn union(&mut self, e1: &Expr, e2: &Expr) {
        let r1 = self.find(e1);
        let r2 = self.find(e2);
        if r1 == r2 {
            return;
        }
        let rank1 = *self.rank.get(&r1).unwrap_or(&0);
        let rank2 = *self.rank.get(&r2).unwrap_or(&0);
        if rank1 < rank2 {
            self.parent.insert(r1, r2);
        } else if rank1 > rank2 {
            self.parent.insert(r2, r1);
        } else {
            self.parent.insert(r2, r1.clone());
            self.rank.insert(r1, rank1 + 1);
        }
    }
    /// Adds `expr` and, recursively, its subterms. Each application is
    /// recorded in `apps`.
    fn register(&mut self, expr: &Expr) {
        if !self.parent.contains_key(expr) {
            self.parent.insert(expr.clone(), expr.clone());
            self.rank.insert(expr.clone(), 0);
        }
        if let Expr::App(f, a) = expr {
            self.apps.insert(((**f).clone(), (**a).clone()));
            self.register(f);
            self.register(a);
        }
    }
    /// Asserts `e1 = e2` and closes the relation under congruence.
    pub fn add_equality(&mut self, e1: Expr, e2: Expr) {
        self.register(&e1);
        self.register(&e2);
        self.pending.push((e1, e2));
        self.process_pending();
    }
    /// Like `add_equality`, and also records `proof` for the pair `(e1, e2)`.
    pub fn add_equality_with_proof(&mut self, e1: Expr, e2: Expr, proof: Expr) {
        self.proofs.insert((e1.clone(), e2.clone()), proof);
        self.add_equality(e1, e2);
    }
    /// Drains the pending queue, then re-derives congruences until a fixpoint.
    ///
    /// The congruence sweep runs even when no union happened in this call.
    /// Applications registered by this call may already be congruent to
    /// existing ones, and previously that congruence was missed.
    fn process_pending(&mut self) {
        loop {
            while let Some((e1, e2)) = self.pending.pop() {
                self.union(&e1, &e2);
            }
            let congruent = self.congruent_app_pairs();
            if congruent.is_empty() {
                break;
            }
            // Each round contains at least one pair that is not yet merged,
            // so the number of classes strictly decreases and the loop terminates.
            self.pending.extend(congruent);
        }
    }
    /// Finds registered applications that are congruent (same head class and
    /// same argument class) but not yet in the same class.
    ///
    /// Uses a signature table keyed by `(find(f), find(a))`, so one sweep is
    /// linear in the number of applications rather than quadratic.
    fn congruent_app_pairs(&mut self) -> Vec<(Expr, Expr)> {
        let apps: Vec<(Expr, Expr)> = self.apps.iter().cloned().collect();
        let mut signatures: HashMap<(Expr, Expr), Expr> = HashMap::with_capacity(apps.len());
        let mut pairs = Vec::new();
        for (f, a) in apps {
            let signature = (self.find(&f), self.find(&a));
            let app = Expr::App(Box::new(f), Box::new(a));
            match signatures.get(&signature) {
                Some(seen) => {
                    if self.find(seen) != self.find(&app) {
                        pairs.push((seen.clone(), app));
                    }
                }
                None => {
                    signatures.insert(signature, app);
                }
            }
        }
        pairs
    }
    /// True if the two expressions share a class, or are applications whose
    /// heads and arguments are pairwise equal (checked recursively).
    ///
    /// Unknown expressions are added as singleton classes as a side effect of `find`.
    pub fn are_equal(&mut self, e1: &Expr, e2: &Expr) -> bool {
        if self.find(e1) == self.find(e2) {
            return true;
        }
        match (e1, e2) {
            (Expr::App(f1, a1), Expr::App(f2, a2)) => {
                self.are_equal(f1, f2) && self.are_equal(a1, a2)
            }
            _ => false,
        }
    }
    /// All known terms in `expr`'s class, in no particular order.
    pub fn get_class(&mut self, expr: &Expr) -> Vec<Expr> {
        let root = self.find(expr);
        let keys: Vec<Expr> = self.parent.keys().cloned().collect();
        keys.into_iter().filter(|e| self.find(e) == root).collect()
    }
    /// Number of distinct classes among all known terms.
    pub fn num_classes(&mut self) -> usize {
        let keys: Vec<Expr> = self.parent.keys().cloned().collect();
        let mut roots = HashSet::new();
        for k in &keys {
            roots.insert(self.find(k));
        }
        roots.len()
    }
    /// The proof recorded for `(e1, e2)` or `(e2, e1)`, if any.
    pub fn get_proof(&self, e1: &Expr, e2: &Expr) -> Option<&Expr> {
        self.proofs
            .get(&(e1.clone(), e2.clone()))
            .or_else(|| self.proofs.get(&(e2.clone(), e1.clone())))
    }
    /// Removes every term, application, pending equality and proof.
    pub fn clear(&mut self) {
        self.parent.clear();
        self.rank.clear();
        self.pending.clear();
        self.apps.clear();
        self.proofs.clear();
    }
}
/// An application `fn_idx arg_idx` in index form, as consumed by `FlatCC`.
///
/// `result_idx` is the term the application evaluates to, when known.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct FlatApp {
    pub fn_idx: TermIdx,
    pub arg_idx: TermIdx,
    pub result_idx: Option<TermIdx>,
}
#[allow(dead_code)]
/// A free list of cleared `String`s, so their allocations can be reused.
pub struct StringPool {
    free: Vec<String>,
}
#[allow(dead_code)]
impl StringPool {
    /// Creates an empty pool.
    pub fn new() -> Self {
        StringPool { free: vec![] }
    }
    /// Returns a pooled (empty, possibly pre-allocated) string, or a new one.
    pub fn take(&mut self) -> String {
        match self.free.pop() {
            Some(recycled) => recycled,
            None => String::new(),
        }
    }
    /// Clears `s`, keeping its capacity, and stores it for reuse.
    pub fn give(&mut self, s: String) {
        let mut buf = s;
        buf.clear();
        self.free.push(buf);
    }
    /// Number of strings waiting in the pool.
    pub fn free_count(&self) -> usize {
        self.free.len()
    }
}
#[allow(dead_code)]
/// A fixed logical-length vector that stores only entries differing from `T::default()`.
///
/// Indices are not bounds-checked against the logical length.
pub struct SparseVec<T: Default + Clone + PartialEq> {
    entries: std::collections::HashMap<usize, T>,
    default_: T,
    logical_len: usize,
}
#[allow(dead_code)]
impl<T: Default + Clone + PartialEq> SparseVec<T> {
    /// Creates a vector of logical length `len` where every slot holds the default.
    pub fn new(len: usize) -> Self {
        SparseVec {
            logical_len: len,
            default_: T::default(),
            entries: std::collections::HashMap::new(),
        }
    }
    /// Writes `val` at `idx`. Writing the default value removes the stored entry.
    pub fn set(&mut self, idx: usize, val: T) {
        if val == self.default_ {
            self.entries.remove(&idx);
            return;
        }
        self.entries.insert(idx, val);
    }
    /// The value at `idx`, or the default when nothing is stored there.
    pub fn get(&self, idx: usize) -> &T {
        match self.entries.get(&idx) {
            Some(stored) => stored,
            None => &self.default_,
        }
    }
    /// Logical length given at construction.
    pub fn len(&self) -> usize {
        self.logical_len
    }
    /// True when the logical length is zero.
    pub fn is_empty(&self) -> bool {
        self.logical_len == 0
    }
    /// Number of explicitly stored (non-default) entries.
    pub fn nnz(&self) -> usize {
        self.entries.len()
    }
}
#[allow(dead_code)]
/// A tiny reverse-Polish integer calculator backed by a `Vec<i64>` stack.
pub struct StackCalc {
    stack: Vec<i64>,
}
#[allow(dead_code)]
impl StackCalc {
    /// Creates an empty calculator.
    pub fn new() -> Self {
        StackCalc { stack: vec![] }
    }
    /// Pushes an operand.
    pub fn push(&mut self, n: i64) {
        self.stack.push(n)
    }
    /// Pops the top two values and returns `(second_from_top, top)`.
    ///
    /// Panics with `msg` if fewer than two values are available.
    fn pop_operands(&mut self, msg: &str) -> (i64, i64) {
        let rhs = self.stack.pop().expect(msg);
        let lhs = self.stack.pop().expect(msg);
        (lhs, rhs)
    }
    /// Replaces the top two values `a b` with `a + b`.
    pub fn add(&mut self) {
        let (lhs, rhs) = self.pop_operands("stack must have at least two values for add");
        self.stack.push(lhs + rhs);
    }
    /// Replaces the top two values `a b` with `a - b`.
    pub fn sub(&mut self) {
        let (lhs, rhs) = self.pop_operands("stack must have at least two values for sub");
        self.stack.push(lhs - rhs);
    }
    /// Replaces the top two values `a b` with `a * b`.
    pub fn mul(&mut self) {
        let (lhs, rhs) = self.pop_operands("stack must have at least two values for mul");
        self.stack.push(lhs * rhs);
    }
    /// The top value, if any.
    pub fn peek(&self) -> Option<i64> {
        self.stack.last().copied()
    }
    /// Number of values on the stack.
    pub fn depth(&self) -> usize {
        self.stack.len()
    }
}
/// Description of a congruence lemma for `fn_name` applied to `num_args` arguments.
#[derive(Debug, Clone)]
pub struct CongruenceTheorem {
    pub fn_name: Name,
    pub num_args: usize,
    /// One entry per argument position.
    pub arg_kinds: Vec<CongrArgKind>,
    pub proof: Option<Expr>,
    pub ty: Option<Expr>,
}
impl CongruenceTheorem {
    /// Builds a theorem with no proof or type. `num_args` is `arg_kinds.len()`.
    pub fn new(fn_name: Name, arg_kinds: Vec<CongrArgKind>) -> Self {
        CongruenceTheorem {
            num_args: arg_kinds.len(),
            fn_name,
            arg_kinds,
            proof: None,
            ty: None,
        }
    }
    /// True if any argument position is `Eq` or `HEq`.
    pub fn has_eq_args(&self) -> bool {
        self.num_eq_hypotheses() > 0
    }
    /// Number of argument positions that are `Eq` or `HEq`.
    pub fn num_eq_hypotheses(&self) -> usize {
        self.arg_kinds
            .iter()
            .filter(|kind| matches!(**kind, CongrArgKind::Eq | CongrArgKind::HEq))
            .count()
    }
}
#[allow(dead_code)]
/// A LIFO stack of focused items; the most recently focused item is current.
pub struct FocusStack<T> {
    items: Vec<T>,
}
#[allow(dead_code)]
impl<T> FocusStack<T> {
    /// Creates an empty stack.
    pub fn new() -> Self {
        FocusStack { items: vec![] }
    }
    /// Makes `item` the current focus.
    pub fn focus(&mut self, item: T) {
        self.items.push(item)
    }
    /// Removes and returns the current focus.
    pub fn blur(&mut self) -> Option<T> {
        self.items.pop()
    }
    /// The current focus, if any.
    pub fn current(&self) -> Option<&T> {
        self.items.last()
    }
    /// Number of stacked items.
    pub fn depth(&self) -> usize {
        self.items.len()
    }
    /// True when nothing is focused.
    pub fn is_empty(&self) -> bool {
        self.depth() == 0
    }
}
#[allow(dead_code)]
/// A string-to-string key/value store.
pub struct Fixture {
    data: std::collections::HashMap<String, String>,
}
#[allow(dead_code)]
impl Fixture {
    /// Creates an empty fixture.
    pub fn new() -> Self {
        Fixture {
            data: Default::default(),
        }
    }
    /// Inserts or overwrites `key`.
    pub fn set(&mut self, key: impl Into<String>, val: impl Into<String>) {
        let (k, v) = (key.into(), val.into());
        self.data.insert(k, v);
    }
    /// The value for `key`, if present.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.data.get(key).map(String::as_str)
    }
    /// Number of stored keys.
    pub fn len(&self) -> usize {
        self.data.len()
    }
    /// True when no keys are stored.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
/// An equality hypothesis `lhs = rhs`, either homogeneous or heterogeneous.
#[derive(Debug, Clone)]
pub struct CongrHypothesis {
    pub lhs: Expr,
    pub rhs: Expr,
    /// True for a heterogeneous equality.
    pub is_heq: bool,
}
impl CongrHypothesis {
    /// Builds a homogeneous hypothesis (`is_heq == false`).
    pub fn eq(lhs: Expr, rhs: Expr) -> Self {
        CongrHypothesis {
            is_heq: false,
            lhs,
            rhs,
        }
    }
    /// Builds a heterogeneous hypothesis (`is_heq == true`).
    pub fn heq(lhs: Expr, rhs: Expr) -> Self {
        CongrHypothesis {
            is_heq: true,
            ..Self::eq(lhs, rhs)
        }
    }
    /// True when both sides compare equal.
    pub fn is_trivial(&self) -> bool {
        self.lhs == self.rhs
    }
}
pub struct FlatCC {
parent: Vec<TermIdx>,
rank: Vec<u32>,
apps: Vec<FlatApp>,
}
impl FlatCC {
pub fn new(n: usize) -> Self {
Self {
parent: (0..n).collect(),
rank: vec![0; n],
apps: Vec::new(),
}
}
pub fn find(&mut self, i: TermIdx) -> TermIdx {
if self.parent[i] != i {
self.parent[i] = self.find(self.parent[i]);
}
self.parent[i]
}
pub fn union(&mut self, i: TermIdx, j: TermIdx) {
let ri = self.find(i);
let rj = self.find(j);
if ri == rj {
return;
}
if self.rank[ri] < self.rank[rj] {
self.parent[ri] = rj;
} else if self.rank[ri] > self.rank[rj] {
self.parent[rj] = ri;
} else {
self.parent[rj] = ri;
self.rank[ri] += 1;
}
}
pub fn add_node(&mut self) -> TermIdx {
let idx = self.parent.len();
self.parent.push(idx);
self.rank.push(0);
idx
}
pub fn add_app(&mut self, app: FlatApp) {
self.apps.push(app);
}
pub fn are_equal(&mut self, i: TermIdx, j: TermIdx) -> bool {
self.find(i) == self.find(j)
}
pub fn num_nodes(&self) -> usize {
self.parent.len()
}
pub fn num_apps(&self) -> usize {
self.apps.len()
}
pub fn propagate_congruences(&mut self) {
let len = self.apps.len();
for i in 0..len {
for j in (i + 1)..len {
let a = self.apps[i];
let b = self.apps[j];
if self.are_equal(a.fn_idx, b.fn_idx) && self.are_equal(a.arg_idx, b.arg_idx) {
if let (Some(ra), Some(rb)) = (a.result_idx, b.result_idx) {
self.union(ra, rb);
}
}
}
}
}
}
/// One equivalence class: a representative plus its members.
///
/// `proofs[i]` is the optional proof recorded for `members[i]`.
#[derive(Debug, Clone)]
pub struct ENode {
    pub repr: Expr,
    pub members: Vec<Expr>,
    pub proofs: Vec<Option<Expr>>,
}
impl ENode {
    /// A class containing only `expr`, which is also its representative.
    pub fn singleton(expr: Expr) -> Self {
        ENode {
            members: vec![expr.clone()],
            proofs: vec![None],
            repr: expr,
        }
    }
    /// Adds `expr` with `proof` unless it is already a member.
    pub fn add_member(&mut self, expr: Expr, proof: Option<Expr>) {
        if self.contains(&expr) {
            return;
        }
        self.members.push(expr);
        self.proofs.push(proof);
    }
    /// Linear membership test.
    pub fn contains(&self, expr: &Expr) -> bool {
        self.members.iter().any(|member| member == expr)
    }
    /// Number of members.
    pub fn size(&self) -> usize {
        self.members.len()
    }
}
#[allow(dead_code)]
/// A function address stored as a plain integer, together with its arity and name.
pub struct RawFnPtr {
    ptr: usize,
    arity: usize,
    name: String,
}
#[allow(dead_code)]
impl RawFnPtr {
    /// Wraps the given raw address, arity and name.
    pub fn new(ptr: usize, arity: usize, name: impl Into<String>) -> Self {
        let name: String = name.into();
        RawFnPtr { name, arity, ptr }
    }
    /// Declared number of arguments.
    pub fn arity(&self) -> usize {
        self.arity
    }
    /// Human-readable name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
    /// The raw address value.
    pub fn raw(&self) -> usize {
        self.ptr
    }
}
#[allow(dead_code)]
/// An insertion-ordered set of string labels (linear-time lookups).
pub struct LabelSet {
    labels: Vec<String>,
}
#[allow(dead_code)]
impl LabelSet {
    /// Creates an empty set.
    pub fn new() -> Self {
        LabelSet { labels: vec![] }
    }
    /// Adds `label` unless it is already present.
    pub fn add(&mut self, label: impl Into<String>) {
        let label: String = label.into();
        if self.has(&label) {
            return;
        }
        self.labels.push(label);
    }
    /// True if `label` is present.
    pub fn has(&self, label: &str) -> bool {
        self.labels.iter().any(|existing| existing.as_str() == label)
    }
    /// Number of distinct labels.
    pub fn count(&self) -> usize {
        self.labels.len()
    }
    /// All labels in insertion order.
    pub fn all(&self) -> &[String] {
        self.labels.as_slice()
    }
}
#[allow(dead_code)]
/// A binary min-heap stored in a `Vec` (the children of `i` are at `2i+1` and `2i+2`).
pub struct MinHeap<T: Ord> {
    data: Vec<T>,
}
#[allow(dead_code)]
impl<T: Ord> MinHeap<T> {
    /// Creates an empty heap.
    pub fn new() -> Self {
        MinHeap { data: vec![] }
    }
    /// Inserts `val` in O(log n).
    pub fn push(&mut self, val: T) {
        let slot = self.data.len();
        self.data.push(val);
        self.sift_up(slot);
    }
    /// Removes and returns the smallest element in O(log n).
    pub fn pop(&mut self) -> Option<T> {
        let last = self.data.len().checked_sub(1)?;
        self.data.swap(0, last);
        let smallest = self.data.pop();
        if !self.data.is_empty() {
            self.sift_down(0);
        }
        smallest
    }
    /// The smallest element without removing it.
    pub fn peek(&self) -> Option<&T> {
        self.data.first()
    }
    /// Number of elements.
    pub fn len(&self) -> usize {
        self.data.len()
    }
    /// True when the heap holds nothing.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
    /// Moves the element at `child` up while it is strictly smaller than its parent.
    fn sift_up(&mut self, mut child: usize) {
        while let Some(parent) = child.checked_sub(1).map(|c| c / 2) {
            if !(self.data[child] < self.data[parent]) {
                return;
            }
            self.data.swap(child, parent);
            child = parent;
        }
    }
    /// Moves the element at `parent` down until neither child is strictly smaller.
    fn sift_down(&mut self, mut parent: usize) {
        let len = self.data.len();
        loop {
            let mut target = parent;
            // Left child first, then right, matching tie-break order.
            for child in [2 * parent + 1, 2 * parent + 2] {
                if child < len && self.data[child] < self.data[target] {
                    target = child;
                }
            }
            if target == parent {
                return;
            }
            self.data.swap(parent, target);
            parent = target;
        }
    }
}
/// Counters describing congruence-closure activity. All counters start at zero.
#[derive(Clone, Debug, Default)]
pub struct CongrClosureStats {
    pub equalities_added: usize,
    pub congruences_propagated: usize,
    pub unions: usize,
    pub apps_tracked: usize,
}
impl CongrClosureStats {
    /// All-zero statistics; identical to `Default::default()`.
    pub fn new() -> Self {
        CongrClosureStats {
            equalities_added: 0,
            congruences_propagated: 0,
            unions: 0,
            apps_tracked: 0,
        }
    }
}
#[allow(dead_code)]
/// A map stored as a `Vec` of `(key, value)` pairs kept sorted by key.
///
/// Lookups use binary search; inserting a new key shifts later entries.
pub struct SmallMap<K: Ord + Clone, V: Clone> {
    entries: Vec<(K, V)>,
}
#[allow(dead_code)]
impl<K: Ord + Clone, V: Clone> SmallMap<K, V> {
    /// Creates an empty map.
    pub fn new() -> Self {
        SmallMap { entries: vec![] }
    }
    /// Inserts `key -> val`, overwriting the value if the key already exists.
    pub fn insert(&mut self, key: K, val: V) {
        let slot = self.entries.binary_search_by(|(k, _)| k.cmp(&key));
        match slot {
            Ok(found) => self.entries[found].1 = val,
            Err(insert_at) => self.entries.insert(insert_at, (key, val)),
        }
    }
    /// The value for `key`, if present.
    pub fn get(&self, key: &K) -> Option<&V> {
        let found = self.entries.binary_search_by(|(k, _)| k.cmp(key)).ok()?;
        Some(&self.entries[found].1)
    }
    /// Number of entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// True when the map is empty.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
    /// Keys in ascending order.
    pub fn keys(&self) -> Vec<&K> {
        self.entries.iter().map(|entry| &entry.0).collect()
    }
    /// Values in ascending key order.
    pub fn values(&self) -> Vec<&V> {
        self.entries.iter().map(|entry| &entry.1).collect()
    }
}