#[allow(unused_imports)]
use crate::prelude::*;
use core::cmp::Ordering;
/// Identifier of a term.
pub type TermId = u32;
/// Identifier of a theory (the manager uses `0` for terms first seen through
/// `assert_equality` without prior registration).
pub type TheoryId = u32;
/// Depth of the decision stack; `0` is the base level.
pub type DecisionLevel = u32;
/// An equality between two terms.
///
/// Build it with [`Equality::new`], which orders the terms so that
/// `lhs <= rhs`; that way `a = b` and `b = a` compare and hash identically.
/// The fields are public, so a struct literal can bypass that normalization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equality {
    /// The smaller term id (when built via `new`).
    pub lhs: TermId,
    /// The larger term id (when built via `new`).
    pub rhs: TermId,
}
impl Equality {
    /// Creates the canonical equality between `lhs` and `rhs`.
    ///
    /// The smaller id is always stored in `lhs`, so the argument order does
    /// not matter: `Equality::new(a, b) == Equality::new(b, a)`.
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        Self {
            lhs: lhs.min(rhs),
            rhs: lhs.max(rhs),
        }
    }
}
/// How an e-class turns its members into interface equalities.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GenerationStrategy {
    /// Every pair of members (quadratic in the class size).
    Eager,
    /// Generates nothing when asked.
    Lazy,
    /// One equality per non-representative member, each against the
    /// representative.
    Minimal,
    /// Currently produces exactly what `Minimal` produces.
    Incremental,
    /// `Eager` for classes with at most two members, `Minimal` otherwise.
    Adaptive,
}
/// Scheduling priority of an interface equality.
///
/// Ordered lexicographically: higher `level`, then higher `relevancy`, then
/// *lower* `decision_level` ranks greater (see the `Ord` impl).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EqualityPriority {
    /// Coarse priority tier (generation in this file always uses 100).
    pub level: u32,
    /// Relevancy score (50 at generation; may be overwritten by
    /// `InterfaceEqualityManager::update_relevancy`).
    pub relevancy: u32,
    /// Decision level at which the equality was generated.
    pub decision_level: DecisionLevel,
}
impl Ord for EqualityPriority {
    /// Compares by `level`, then `relevancy` (both "bigger is greater"), then
    /// by `decision_level` with the operands swapped, so an equality from a
    /// shallower decision level ranks greater.
    fn cmp(&self, other: &Self) -> Ordering {
        let mine = (self.level, self.relevancy, other.decision_level);
        let theirs = (other.level, other.relevancy, self.decision_level);
        mine.cmp(&theirs)
    }
}
impl PartialOrd for EqualityPriority {
    /// Total order; always delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// An equality between shared terms that must be communicated to theories.
#[derive(Debug, Clone)]
pub struct InterfaceEquality {
    /// The (normalized) pair of terms being equated.
    pub equality: Equality,
    /// Theories of the e-class the equality was generated from, cloned at
    /// generation time.
    pub theories: FxHashSet<TheoryId>,
    /// Scheduling priority; the only key used by this type's `Ord` impl.
    pub priority: EqualityPriority,
    /// `true` for representative-based (minimal) generation, `false` for
    /// eager all-pairs generation.
    pub is_necessary: bool,
    /// Generation counter value at the time the equality was produced.
    pub timestamp: u64,
}
// NOTE(review): `==` looks only at `equality`, while `cmp` (below) looks only
// at `priority`. Two values can be `==` without comparing `Ordering::Equal`,
// and vice versa, which breaks the consistency `Ord` requires with `Eq`.
// `BinaryHeap` relies solely on `Ord`, so queue behaviour is fine; confirm no
// caller (sorting + dedup, `BTreeSet`, ...) relies on the two agreeing.
impl PartialEq for InterfaceEquality {
    /// Equal when both relate the same pair of terms, ignoring priority,
    /// theories, necessity and timestamp.
    fn eq(&self, other: &Self) -> bool {
        self.equality == other.equality
    }
}
impl Eq for InterfaceEquality {}
impl Ord for InterfaceEquality {
    /// Orders purely by [`EqualityPriority`], so a max-heap pops the
    /// highest-priority equality first.
    fn cmp(&self, other: &Self) -> Ordering {
        self.priority.cmp(&other.priority)
    }
}
impl PartialOrd for InterfaceEquality {
    /// Delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// A set of terms known to be equal, plus the theories that use them.
#[derive(Debug, Clone)]
pub struct InterfaceEClass {
    /// Term that minimal generation equates every other member with.
    pub representative: TermId,
    /// All terms in the class (always includes `representative`).
    pub members: FxHashSet<TermId>,
    /// Theories that registered at least one member; more than one means the
    /// class is shared.
    pub theories: FxHashSet<TheoryId>,
    /// How this class turns its members into interface equalities.
    pub strategy: GenerationStrategy,
}
impl InterfaceEClass {
    /// Creates a singleton class whose only member (and representative) is
    /// `representative`, owned by `theory`, using the `Minimal` strategy.
    fn new(representative: TermId, theory: TheoryId) -> Self {
        let mut class = Self {
            representative,
            members: FxHashSet::default(),
            theories: FxHashSet::default(),
            strategy: GenerationStrategy::Minimal,
        };
        class.add_term(representative, theory);
        class
    }

    /// Adds `term` as a member and records `theory` as a user of the class.
    fn add_term(&mut self, term: TermId, theory: TheoryId) {
        self.members.insert(term);
        self.theories.insert(theory);
    }

    /// Absorbs every member and theory of `other`; `self`'s representative
    /// and strategy are left unchanged.
    fn merge(&mut self, other: &InterfaceEClass) {
        other.members.iter().for_each(|&term| {
            self.members.insert(term);
        });
        other.theories.iter().for_each(|&theory| {
            self.theories.insert(theory);
        });
    }

    /// `true` when more than one theory uses this class.
    fn is_shared(&self) -> bool {
        self.theories.len() >= 2
    }

    /// Produces this class's equalities according to its `strategy`.
    fn generate_equalities(
        &self,
        timestamp: u64,
        decision_level: DecisionLevel,
    ) -> Vec<InterfaceEquality> {
        match self.strategy {
            GenerationStrategy::Lazy => Vec::new(),
            GenerationStrategy::Eager => self.generate_eager(timestamp, decision_level),
            GenerationStrategy::Minimal => self.generate_minimal(timestamp, decision_level),
            GenerationStrategy::Incremental => {
                self.generate_incremental(timestamp, decision_level)
            }
            GenerationStrategy::Adaptive => self.generate_adaptive(timestamp, decision_level),
        }
    }

    /// Builds one interface equality between `a` and `b` that carries this
    /// class's theories and the fixed generation priority (level 100,
    /// relevancy 50).
    fn build_equality(
        &self,
        a: TermId,
        b: TermId,
        is_necessary: bool,
        timestamp: u64,
        decision_level: DecisionLevel,
    ) -> InterfaceEquality {
        InterfaceEquality {
            equality: Equality::new(a, b),
            theories: self.theories.clone(),
            priority: EqualityPriority {
                level: 100,
                relevancy: 50,
                decision_level,
            },
            is_necessary,
            timestamp,
        }
    }

    /// All pairwise equalities between members (n*(n-1)/2 results), flagged
    /// as not necessary.
    fn generate_eager(
        &self,
        timestamp: u64,
        decision_level: DecisionLevel,
    ) -> Vec<InterfaceEquality> {
        let terms: Vec<TermId> = self.members.iter().copied().collect();
        let mut out = Vec::new();
        for (i, &a) in terms.iter().enumerate() {
            for &b in &terms[i + 1..] {
                out.push(self.build_equality(a, b, false, timestamp, decision_level));
            }
        }
        out
    }

    /// One equality per non-representative member against the
    /// representative, flagged as necessary.
    fn generate_minimal(
        &self,
        timestamp: u64,
        decision_level: DecisionLevel,
    ) -> Vec<InterfaceEquality> {
        let rep = self.representative;
        self.members
            .iter()
            .copied()
            .filter(|&term| term != rep)
            .map(|term| self.build_equality(term, rep, true, timestamp, decision_level))
            .collect()
    }

    /// Currently identical to [`Self::generate_minimal`].
    fn generate_incremental(
        &self,
        timestamp: u64,
        decision_level: DecisionLevel,
    ) -> Vec<InterfaceEquality> {
        self.generate_minimal(timestamp, decision_level)
    }

    /// Eager generation for classes of at most two members, minimal otherwise.
    fn generate_adaptive(
        &self,
        timestamp: u64,
        decision_level: DecisionLevel,
    ) -> Vec<InterfaceEquality> {
        if self.members.len() > 2 {
            self.generate_minimal(timestamp, decision_level)
        } else {
            self.generate_eager(timestamp, decision_level)
        }
    }
}
/// Tuning knobs for [`InterfaceEqualityManager`].
#[derive(Debug, Clone)]
pub struct InterfaceEqualityConfig {
    /// Generation strategy meant for newly created e-classes.
    /// NOTE(review): confirm the manager applies it when it creates classes.
    pub default_strategy: GenerationStrategy,
    /// When `false`, `minimize_equalities` does nothing.
    pub enable_minimization: bool,
    /// NOTE(review): not read by any code in this file.
    pub enable_priority: bool,
    /// Maximum number of equalities returned by one `get_pending_batch` call.
    pub max_batch_size: usize,
    /// When `false`, `update_relevancy` does nothing.
    pub track_relevancy: bool,
    /// NOTE(review): not read by any code in this file;
    /// `GenerationStrategy::Adaptive` uses a fixed two-member cutoff.
    pub adaptive_threshold: usize,
}
impl Default for InterfaceEqualityConfig {
    /// Minimal generation, with minimization, priority and relevancy tracking
    /// all enabled, batches of up to 1000 equalities and an adaptive threshold
    /// of 10.
    fn default() -> Self {
        let default_strategy = GenerationStrategy::Minimal;
        let max_batch_size = 1000;
        let adaptive_threshold = 10;
        Self {
            enable_minimization: true,
            enable_priority: true,
            track_relevancy: true,
            default_strategy,
            max_batch_size,
            adaptive_threshold,
        }
    }
}
/// Counters maintained by [`InterfaceEqualityManager`].
#[derive(Debug, Clone, Default)]
pub struct InterfaceEqualityStats {
    /// Equalities newly queued (duplicates of already-generated ones excluded).
    pub equalities_generated: u64,
    /// Pending equalities discarded by `minimize_equalities`.
    pub equalities_minimized: u64,
    /// Newly queued equalities coming from `Eager` classes.
    pub eager_generations: u64,
    /// Newly queued equalities coming from `Lazy` classes.
    /// NOTE(review): `Lazy` classes generate nothing, so this stays 0.
    pub lazy_generations: u64,
    /// Newly queued equalities coming from `Minimal` classes.
    pub minimal_generations: u64,
    /// E-classes ever created; not decremented when classes merge.
    pub eclasses: u64,
    /// Non-empty batches handed out by `get_pending_batch` / `get_all_pending`.
    pub batches_sent: u64,
}
/// Tracks which terms are shared between theories, groups equal terms into
/// e-classes and queues the interface equalities the theories must exchange.
pub struct InterfaceEqualityManager {
    /// Configuration supplied at construction.
    config: InterfaceEqualityConfig,
    /// Counters updated by the manager's operations.
    stats: InterfaceEqualityStats,
    /// Index into `eclasses` of the class each known term currently belongs to.
    term_to_eclass: FxHashMap<TermId, usize>,
    /// Every e-class ever created. A class absorbed by a merge stays in the
    /// vector but is no longer referenced from `term_to_eclass`.
    eclasses: Vec<InterfaceEClass>,
    /// Equalities waiting to be handed out (max-heap: greatest priority first).
    pending: BinaryHeap<InterfaceEquality>,
    /// Equalities queued so far, used to avoid queuing duplicates; entries are
    /// removed only by `backtrack` (for discarded pending ones) and `clear`.
    generated: FxHashSet<Equality>,
    /// Counter stamped onto generated equalities; advanced once per
    /// generation pass.
    timestamp: u64,
    /// Current decision level.
    decision_level: DecisionLevel,
    /// Last relevancy score recorded per term by `update_relevancy`.
    relevancy: FxHashMap<TermId, u32>,
    /// Presumably equalities recorded per decision level.
    /// NOTE(review): nothing in this file inserts into this map; it is only
    /// pruned by `backtrack` and emptied by `clear`.
    history: FxHashMap<DecisionLevel, Vec<Equality>>,
}
impl InterfaceEqualityManager {
    /// Creates a manager using [`InterfaceEqualityConfig::default`].
    pub fn new() -> Self {
        Self::with_config(InterfaceEqualityConfig::default())
    }

    /// Creates an empty manager (no terms, nothing pending, decision level 0)
    /// that uses `config`.
    pub fn with_config(config: InterfaceEqualityConfig) -> Self {
        Self {
            config,
            stats: InterfaceEqualityStats::default(),
            term_to_eclass: FxHashMap::default(),
            eclasses: Vec::new(),
            pending: BinaryHeap::new(),
            generated: FxHashSet::default(),
            timestamp: 0,
            decision_level: 0,
            relevancy: FxHashMap::default(),
            history: FxHashMap::default(),
        }
    }

    /// Returns the statistics accumulated since creation or the last
    /// [`reset_stats`](Self::reset_stats).
    pub fn stats(&self) -> &InterfaceEqualityStats {
        &self.stats
    }

    /// Records that `theory` uses `term`.
    ///
    /// A known term gets `theory` added to its e-class. An unknown term gets a
    /// new singleton e-class using `config.default_strategy`. Registration
    /// alone never generates equalities.
    pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
        match self.term_to_eclass.get(&term) {
            Some(&eclass_id) => self.eclasses[eclass_id].add_term(term, theory),
            None => {
                self.create_eclass(term, theory);
            }
        }
    }

    /// Merges the e-classes of `lhs` and `rhs`.
    ///
    /// Terms that were never registered first get singleton e-classes
    /// attributed to theory `0`. The class with fewer members is folded into
    /// the other (on a tie, `rhs`'s class is absorbed). The absorbed class
    /// stays in `eclasses` but is no longer referenced from `term_to_eclass`.
    /// If the merged class is shared by several theories, its equalities are
    /// generated and queued.
    ///
    /// # Errors
    /// Propagates any error from generating the merged class's equalities.
    pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
        let lhs_class = self.find_or_create(lhs);
        let rhs_class = self.find_or_create(rhs);
        if lhs_class == rhs_class {
            return Ok(());
        }
        // Union by size: fewer `term_to_eclass` entries need re-pointing.
        let (small, large) =
            if self.eclasses[lhs_class].members.len() < self.eclasses[rhs_class].members.len() {
                (lhs_class, rhs_class)
            } else {
                (rhs_class, lhs_class)
            };
        // Cloned so `large` can be mutated while `small`'s members are read.
        let small_eclass = self.eclasses[small].clone();
        self.eclasses[large].merge(&small_eclass);
        for &term in &small_eclass.members {
            self.term_to_eclass.insert(term, large);
        }
        if self.eclasses[large].is_shared() {
            self.generate_equalities_for_class(large)?;
        }
        Ok(())
    }

    /// Returns the e-class index of `term`, creating a singleton class
    /// attributed to theory `0` if the term is unknown.
    fn find_or_create(&mut self, term: TermId) -> usize {
        match self.term_to_eclass.get(&term) {
            Some(&eclass_id) => eclass_id,
            None => self.create_eclass(term, 0),
        }
    }

    /// Creates a singleton e-class for `term` owned by `theory`, applying the
    /// configured default strategy, and returns its index.
    fn create_eclass(&mut self, term: TermId, theory: TheoryId) -> usize {
        let eclass_id = self.eclasses.len();
        let mut eclass = InterfaceEClass::new(term, theory);
        eclass.strategy = self.config.default_strategy;
        self.eclasses.push(eclass);
        self.term_to_eclass.insert(term, eclass_id);
        self.stats.eclasses += 1;
        eclass_id
    }

    /// Generates the equalities of e-class `eclass_id` with its strategy and
    /// queues those never generated before. Advances the timestamp once per
    /// successful call.
    ///
    /// # Errors
    /// Returns `"Invalid eclass ID"` when `eclass_id` is out of range.
    fn generate_equalities_for_class(&mut self, eclass_id: usize) -> Result<(), String> {
        let eclass = self
            .eclasses
            .get(eclass_id)
            .ok_or_else(|| "Invalid eclass ID".to_string())?;
        let strategy = eclass.strategy;
        let equalities = eclass.generate_equalities(self.timestamp, self.decision_level);
        for eq in equalities {
            // `insert` returns false when the equality was already generated.
            if !self.generated.insert(eq.equality) {
                continue;
            }
            self.pending.push(eq);
            self.stats.equalities_generated += 1;
            match strategy {
                GenerationStrategy::Eager => self.stats.eager_generations += 1,
                GenerationStrategy::Lazy => self.stats.lazy_generations += 1,
                GenerationStrategy::Minimal => self.stats.minimal_generations += 1,
                GenerationStrategy::Incremental | GenerationStrategy::Adaptive => {}
            }
        }
        self.timestamp += 1;
        Ok(())
    }

    /// Pops up to `config.max_batch_size` pending equalities, highest
    /// priority first. Counts a sent batch only when it is non-empty.
    pub fn get_pending_batch(&mut self) -> Vec<InterfaceEquality> {
        let limit = self.config.max_batch_size.min(self.pending.len());
        let batch: Vec<_> = (0..limit).filter_map(|_| self.pending.pop()).collect();
        if !batch.is_empty() {
            self.stats.batches_sent += 1;
        }
        batch
    }

    /// Pops every pending equality, highest priority first. Counts a sent
    /// batch only when it is non-empty.
    pub fn get_all_pending(&mut self) -> Vec<InterfaceEquality> {
        let mut all = Vec::with_capacity(self.pending.len());
        while let Some(eq) = self.pending.pop() {
            all.push(eq);
        }
        if !all.is_empty() {
            self.stats.batches_sent += 1;
        }
        all
    }

    /// Sets the generation strategy of the e-class containing `term`.
    ///
    /// # Errors
    /// Returns `"Term not registered"` when `term` belongs to no e-class.
    pub fn set_strategy(
        &mut self,
        term: TermId,
        strategy: GenerationStrategy,
    ) -> Result<(), String> {
        let eclass_id = self
            .term_to_eclass
            .get(&term)
            .ok_or("Term not registered")?;
        self.eclasses[*eclass_id].strategy = strategy;
        Ok(())
    }

    /// Drops pending equalities that are implied by other pending ones.
    ///
    /// Pending equalities are visited from highest to lowest priority. Each is
    /// kept only if its two terms are not already connected by previously
    /// kept equalities (a spanning forest of the pending equality graph).
    /// This preserves every equality derivable from the original pending set.
    /// Dropped equalities stay in `generated`, so they are not regenerated.
    /// No-op when `config.enable_minimization` is false.
    pub fn minimize_equalities(&mut self) {
        if !self.config.enable_minimization {
            return;
        }
        // Ascending by `Ord`; reversed so higher priorities are kept first.
        let ordered = core::mem::take(&mut self.pending).into_sorted_vec();
        let before = ordered.len();
        let mut parent: FxHashMap<TermId, TermId> = FxHashMap::default();
        for eq in ordered.into_iter().rev() {
            let lhs_root = Self::forest_root(&mut parent, eq.equality.lhs);
            let rhs_root = Self::forest_root(&mut parent, eq.equality.rhs);
            if lhs_root == rhs_root {
                // Already implied by equalities kept earlier in this pass.
                continue;
            }
            parent.insert(lhs_root, rhs_root);
            self.pending.push(eq);
        }
        self.stats.equalities_minimized += (before - self.pending.len()) as u64;
    }

    /// Returns the root of `term` in the union-find forest `parent` (a term
    /// with no entry is its own root), compressing the visited path.
    fn forest_root(parent: &mut FxHashMap<TermId, TermId>, term: TermId) -> TermId {
        let mut root = term;
        while let Some(&next) = parent.get(&root) {
            root = next;
        }
        let mut node = term;
        while node != root {
            match parent.insert(node, root) {
                Some(next) => node = next,
                None => break,
            }
        }
        root
    }

    /// Records `score` for `term` and applies it as the relevancy of every
    /// pending equality mentioning `term`. No-op when
    /// `config.track_relevancy` is false.
    ///
    /// NOTE(review): stored scores are not applied to equalities generated
    /// later; generation always uses relevancy 50.
    pub fn update_relevancy(&mut self, term: TermId, score: u32) {
        if !self.config.track_relevancy {
            return;
        }
        self.relevancy.insert(term, score);
        let mut entries = core::mem::take(&mut self.pending).into_vec();
        for eq in entries
            .iter_mut()
            .filter(|eq| eq.equality.lhs == term || eq.equality.rhs == term)
        {
            eq.priority.relevancy = score;
        }
        // Priorities may have changed: rebuild the heap (O(n)).
        self.pending = BinaryHeap::from(entries);
    }

    /// Enters a new decision level.
    pub fn push_decision_level(&mut self) {
        self.decision_level += 1;
    }

    /// Returns to decision level `level`.
    ///
    /// Pending equalities generated above `level` are discarded and forgotten
    /// in `generated`, so they may be regenerated later.
    /// NOTE(review): e-class merges, relevancy scores and `generated` entries
    /// for equalities already handed out are not undone.
    ///
    /// # Errors
    /// Returns `"Cannot backtrack to future level"` when `level` is above the
    /// current decision level.
    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
        if level > self.decision_level {
            return Err("Cannot backtrack to future level".to_string());
        }
        let (kept, discarded): (Vec<_>, Vec<_>) = core::mem::take(&mut self.pending)
            .into_vec()
            .into_iter()
            .partition(|eq| eq.priority.decision_level <= level);
        for eq in &discarded {
            self.generated.remove(&eq.equality);
        }
        self.pending = BinaryHeap::from(kept);
        self.history.retain(|&l, _| l <= level);
        self.decision_level = level;
        Ok(())
    }

    /// Forgets all terms, classes and equalities and returns to level 0.
    /// Statistics are kept; see [`reset_stats`](Self::reset_stats).
    pub fn clear(&mut self) {
        self.term_to_eclass.clear();
        self.eclasses.clear();
        self.pending.clear();
        self.generated.clear();
        self.timestamp = 0;
        self.decision_level = 0;
        self.relevancy.clear();
        self.history.clear();
    }

    /// Zeroes all statistics counters.
    pub fn reset_stats(&mut self) {
        self.stats = InterfaceEqualityStats::default();
    }

    /// Number of equalities waiting to be handed out.
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// `true` if `eq` has been generated and not forgotten by backtracking.
    pub fn is_generated(&self, eq: &Equality) -> bool {
        self.generated.contains(eq)
    }

    /// Generates equalities for every live, shared e-class.
    ///
    /// # Errors
    /// Propagates errors from equality generation.
    pub fn force_generate_all(&mut self) -> Result<(), String> {
        for eclass_id in 0..self.eclasses.len() {
            let eclass = &self.eclasses[eclass_id];
            // Classes absorbed by `assert_equality` remain in `eclasses`, but
            // their representative now maps to the surviving class; skip them.
            let is_live =
                self.term_to_eclass.get(&eclass.representative) == Some(&eclass_id);
            if is_live && eclass.is_shared() {
                self.generate_equalities_for_class(eclass_id)?;
            }
        }
        Ok(())
    }

    /// Returns the e-class currently containing `term`, if any.
    pub fn get_eclass(&self, term: TermId) -> Option<&InterfaceEClass> {
        self.term_to_eclass
            .get(&term)
            .and_then(|&id| self.eclasses.get(id))
    }

    /// Returns the representative of `term`'s e-class, if `term` is known.
    pub fn get_representative(&self, term: TermId) -> Option<TermId> {
        self.get_eclass(term).map(|ec| ec.representative)
    }

    /// `true` when both terms are known and in the same e-class.
    pub fn are_equal(&self, lhs: TermId, rhs: TermId) -> bool {
        if let (Some(&lhs_class), Some(&rhs_class)) =
            (self.term_to_eclass.get(&lhs), self.term_to_eclass.get(&rhs))
        {
            lhs_class == rhs_class
        } else {
            false
        }
    }
}
impl Default for InterfaceEqualityManager {
fn default() -> Self {
Self::new()
}
}
/// Hands out queued interface equalities one at a time or in batches, in the
/// order dictated by its [`SchedulingPolicy`].
pub struct EqualityScheduler {
    /// Max-heap used by every policy except `Fifo`; ordered by
    /// `InterfaceEquality`'s `Ord` impl.
    scheduled: BinaryHeap<InterfaceEquality>,
    /// Queue used by the `Fifo` policy. Each entry carries a `Reverse`d,
    /// unique insertion number, so the max-heap yields the oldest entry
    /// first. Because the numbers never tie, the equality itself is never
    /// compared.
    fifo: BinaryHeap<(core::cmp::Reverse<u64>, InterfaceEquality)>,
    /// Insertion number for the next `Fifo` entry.
    next_seq: u64,
    /// Policy fixed at construction.
    policy: SchedulingPolicy,
}
/// Order in which an [`EqualityScheduler`] releases equalities.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicy {
    /// Oldest scheduled equality first.
    Fifo,
    /// Greatest priority first (per `InterfaceEquality`'s `Ord`).
    Priority,
    /// Currently handled exactly like `Priority`.
    Relevancy,
    /// NOTE(review): currently handled exactly like `Priority`; no rotation
    /// between theories is implemented.
    RoundRobin,
}
impl EqualityScheduler {
    /// Creates an empty scheduler that uses `policy` for its whole lifetime.
    pub fn new(policy: SchedulingPolicy) -> Self {
        Self {
            scheduled: BinaryHeap::new(),
            fifo: BinaryHeap::new(),
            next_seq: 0,
            policy,
        }
    }

    /// Queues `equality` for later retrieval.
    pub fn schedule(&mut self, equality: InterfaceEquality) {
        match self.policy {
            SchedulingPolicy::Fifo => {
                let seq = self.next_seq;
                self.next_seq += 1;
                self.fifo.push((core::cmp::Reverse(seq), equality));
            }
            SchedulingPolicy::Priority
            | SchedulingPolicy::Relevancy
            | SchedulingPolicy::RoundRobin => self.scheduled.push(equality),
        }
    }

    /// Removes and returns the next equality under the policy, or `None` when
    /// nothing is queued. Other queued equalities are left untouched.
    // Kept as an inherent `next` for API compatibility; this type is not an
    // `Iterator`.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Option<InterfaceEquality> {
        match self.policy {
            SchedulingPolicy::Fifo => self.fifo.pop().map(|(_, eq)| eq),
            SchedulingPolicy::Priority
            | SchedulingPolicy::Relevancy
            | SchedulingPolicy::RoundRobin => self.scheduled.pop(),
        }
    }

    /// Removes and returns up to `size` equalities, in the order successive
    /// [`next`](Self::next) calls would return them.
    pub fn next_batch(&mut self, size: usize) -> Vec<InterfaceEquality> {
        core::iter::from_fn(|| self.next()).take(size).collect()
    }

    /// Discards everything queued and restarts FIFO numbering.
    pub fn clear(&mut self) {
        self.scheduled.clear();
        self.fifo.clear();
        self.next_seq = 0;
    }
}
pub struct EqualityMinimizer {
parent: FxHashMap<TermId, TermId>,
rank: FxHashMap<TermId, usize>,
}
impl EqualityMinimizer {
pub fn new() -> Self {
Self {
parent: FxHashMap::default(),
rank: FxHashMap::default(),
}
}
pub fn add_equality(&mut self, eq: Equality) {
let lhs_root = self.find(eq.lhs);
let rhs_root = self.find(eq.rhs);
if lhs_root == rhs_root {
return;
}
let lhs_rank = self.rank.get(&lhs_root).copied().unwrap_or(0);
let rhs_rank = self.rank.get(&rhs_root).copied().unwrap_or(0);
if lhs_rank < rhs_rank {
self.parent.insert(lhs_root, rhs_root);
} else if lhs_rank > rhs_rank {
self.parent.insert(rhs_root, lhs_root);
} else {
self.parent.insert(lhs_root, rhs_root);
self.rank.insert(rhs_root, rhs_rank + 1);
}
}
fn find(&mut self, mut term: TermId) -> TermId {
let mut path = Vec::new();
while let Some(&parent) = self.parent.get(&term) {
if parent == term {
break;
}
path.push(term);
term = parent;
}
for node in path {
self.parent.insert(node, term);
}
term
}
pub fn is_redundant(&mut self, eq: &Equality) -> bool {
self.find(eq.lhs) == self.find(eq.rhs)
}
pub fn minimize(&mut self, equalities: Vec<Equality>) -> Vec<Equality> {
let mut minimal = Vec::new();
for eq in equalities {
if !self.is_redundant(&eq) {
self.add_equality(eq);
minimal.push(eq);
}
}
minimal
}
pub fn clear(&mut self) {
self.parent.clear();
self.rank.clear();
}
}
impl Default for EqualityMinimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_equality_new_normalizes() {
        // Argument order must not matter for a canonical equality.
        assert_eq!(Equality::new(3, 1), Equality::new(1, 3));
        assert_eq!(Equality::new(3, 1).lhs, 1);
    }
    #[test]
    fn test_interface_eclass() {
        let mut eclass = InterfaceEClass::new(1, 0);
        eclass.add_term(2, 1);
        assert_eq!(eclass.members.len(), 2);
        // Two distinct theories make the class shared.
        assert!(eclass.is_shared());
    }
    #[test]
    fn test_minimal_generation() {
        let mut eclass = InterfaceEClass::new(1, 0);
        eclass.add_term(2, 0);
        eclass.add_term(3, 0);
        let equalities = eclass.generate_minimal(0, 0);
        // One equality per non-representative member, each against rep 1.
        assert_eq!(equalities.len(), 2);
        assert!(equalities
            .iter()
            .all(|eq| eq.equality.lhs == 1 && eq.is_necessary));
    }
    #[test]
    fn test_manager_creation() {
        let manager = InterfaceEqualityManager::new();
        assert_eq!(manager.stats().equalities_generated, 0);
    }
    #[test]
    fn test_register_term() {
        let mut manager = InterfaceEqualityManager::new();
        manager.register_term(1, 0);
        manager.register_term(1, 1);
        // Re-registering a known term must not create a second class.
        assert_eq!(manager.stats().eclasses, 1);
    }
    #[test]
    fn test_assert_equality() {
        let mut manager = InterfaceEqualityManager::new();
        manager.register_term(1, 0);
        manager.register_term(2, 1);
        manager.assert_equality(1, 2).expect("Assert failed");
        assert!(manager.are_equal(1, 2));
    }
    #[test]
    fn test_get_pending() {
        let mut manager = InterfaceEqualityManager::new();
        manager.register_term(1, 0);
        manager.register_term(2, 1);
        manager.register_term(1, 1);
        manager.assert_equality(1, 2).expect("Assert failed");
        let pending = manager.get_all_pending();
        assert!(!pending.is_empty());
    }
    #[test]
    fn test_minimization() {
        let mut manager = InterfaceEqualityManager::new();
        for i in 1..=5 {
            manager.register_term(i, 0);
            manager.register_term(i, 1);
        }
        for i in 2..=5 {
            manager.assert_equality(1, i).expect("Assert failed");
        }
        manager.minimize_equalities();
        let pending = manager.get_all_pending();
        assert!(pending.len() <= 4);
    }
    #[test]
    fn test_scheduler() {
        let mut scheduler = EqualityScheduler::new(SchedulingPolicy::Priority);
        let eq = InterfaceEquality {
            equality: Equality::new(1, 2),
            theories: FxHashSet::default(),
            priority: EqualityPriority {
                level: 100,
                relevancy: 50,
                decision_level: 0,
            },
            is_necessary: true,
            timestamp: 0,
        };
        scheduler.schedule(eq);
        assert!(scheduler.next().is_some());
    }
    #[test]
    fn test_minimizer() {
        let mut minimizer = EqualityMinimizer::new();
        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(2, 3);
        let eq3 = Equality::new(1, 3);
        minimizer.add_equality(eq1);
        minimizer.add_equality(eq2);
        assert!(minimizer.is_redundant(&eq3));
    }
    #[test]
    fn test_backtrack() {
        let mut manager = InterfaceEqualityManager::new();
        manager.push_decision_level();
        manager.register_term(1, 0);
        manager.backtrack(0).expect("Backtrack failed");
        // Level 1 no longer exists, so "backtracking" up to it is rejected.
        assert!(manager.backtrack(1).is_err());
        // Term registration is not undone by backtracking.
        assert!(manager.get_eclass(1).is_some());
        assert_eq!(manager.pending_count(), 0);
    }
    #[test]
    fn test_set_strategy() {
        let mut manager = InterfaceEqualityManager::new();
        manager.register_term(1, 0);
        manager
            .set_strategy(1, GenerationStrategy::Eager)
            .expect("Set strategy failed");
        let eclass = manager.get_eclass(1).expect("No eclass");
        assert_eq!(eclass.strategy, GenerationStrategy::Eager);
    }
}