#[allow(unused_imports)]
use crate::prelude::*;
/// Identifier of a term. `TheoryModel` also uses it to represent the value assigned to a term.
pub type TermId = u32;
/// Identifier of a theory taking part in the combination.
pub type TheoryId = u32;
/// Decision level; `ConvexityHandler` starts at 0 and raises it via `push_decision_level`.
pub type DecisionLevel = u32;
/// An equality `lhs = rhs` between two terms.
///
/// [`Equality::new`] always stores the smaller id in `lhs`, so `a = b` and
/// `b = a` compare and hash identically. The fields are public, so a value
/// built with a struct literal is not necessarily in this canonical order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equality {
    /// The smaller term id (when built with `new`).
    pub lhs: TermId,
    /// The larger term id (when built with `new`).
    pub rhs: TermId,
}

impl Equality {
    /// Builds the canonical equality between the two terms, placing the
    /// smaller id in `lhs` regardless of argument order.
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        let (low, high) = if rhs < lhs { (rhs, lhs) } else { (lhs, rhs) };
        Self { lhs: low, rhs: high }
    }
}
/// A disjunction `d1 ∨ d2 ∨ …` of equalities reported by a theory.
#[derive(Debug, Clone)]
pub struct EqualityDisjunction {
    /// The alternative equalities.
    pub disjuncts: Vec<Equality>,
    /// Theory that produced the disjunction.
    pub theory: TheoryId,
    /// Decision level the disjunction belongs to.
    pub level: DecisionLevel,
}

impl EqualityDisjunction {
    /// Bundles `disjuncts` with the producing `theory` and decision `level`.
    pub fn new(disjuncts: Vec<Equality>, theory: TheoryId, level: DecisionLevel) -> Self {
        Self {
            theory,
            level,
            disjuncts,
        }
    }

    /// Returns `true` when the disjunction has exactly one disjunct.
    pub fn is_unit(&self) -> bool {
        matches!(self.disjuncts.as_slice(), [_])
    }

    /// Returns the only disjunct of a unit disjunction, or `None` when the
    /// disjunction is empty or has several disjuncts.
    pub fn get_unit(&self) -> Option<Equality> {
        match self.disjuncts.as_slice() {
            [single] => Some(*single),
            _ => None,
        }
    }
}
/// Convexity classification that can be registered for a theory with
/// `ConvexityHandler::register_theory`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvexityProperty {
    /// The theory is convex; the only variant for which `ConvexityHandler::is_convex` returns `true`.
    Convex,
    /// The theory is not convex.
    NonConvex,
    /// Convexity has not been determined; `ConvexityHandler::is_convex` treats it like `NonConvex`.
    Unknown,
}
/// A candidate model reported by one theory.
#[derive(Debug, Clone)]
pub struct TheoryModel {
    /// Theory that produced the model.
    pub theory: TheoryId,
    /// Value assigned to each term; values are represented as `TermId`s.
    pub assignments: FxHashMap<TermId, TermId>,
    /// Equalities recorded with `add_equality`, in insertion order (duplicates kept).
    pub equalities: Vec<Equality>,
}

impl TheoryModel {
    /// Creates a model for `theory` with no assignments and no equalities.
    pub fn new(theory: TheoryId) -> Self {
        let assignments = FxHashMap::default();
        let equalities = Vec::new();
        Self {
            theory,
            assignments,
            equalities,
        }
    }

    /// Assigns `value` to `term`, overwriting any earlier assignment.
    pub fn add_assignment(&mut self, term: TermId, value: TermId) {
        self.assignments.insert(term, value);
    }

    /// Looks up the value assigned to `term`, if any.
    pub fn get_assignment(&self, term: TermId) -> Option<TermId> {
        let value = self.assignments.get(&term)?;
        Some(*value)
    }

    /// Appends `eq` to the recorded equalities.
    pub fn add_equality(&mut self, eq: Equality) {
        self.equalities.push(eq);
    }
}
/// Tuning options for `ConvexityHandler`.
#[derive(Debug, Clone)]
pub struct ConvexityConfig {
    /// NOTE(review): not read by any code visible here; the splitting policy
    /// is selected by `split_strategy`.
    pub model_based_splitting: bool,
    /// Case-split budget: once `ConvexityStats::case_splits` reaches this
    /// value, `process_disjunctions` reports an error on the next non-unit disjunction.
    pub max_case_splits: usize,
    /// Whether a case split whose cases are all exhausted is recorded as a learned constraint.
    pub conflict_driven_learning: bool,
    /// How a non-unit disjunction is handled.
    pub split_strategy: CaseSplitStrategy,
    /// Whether duplicate disjuncts are removed when a disjunction is added.
    pub simplify_disjunctions: bool,
}

impl Default for ConvexityConfig {
    /// Model-based splitting with learning and simplification enabled and a
    /// budget of 100 case splits.
    fn default() -> Self {
        let split_strategy = CaseSplitStrategy::ModelBased;
        Self {
            split_strategy,
            max_case_splits: 100,
            model_based_splitting: true,
            conflict_driven_learning: true,
            simplify_disjunctions: true,
        }
    }
}
/// How `ConvexityHandler::process_disjunctions` handles a non-unit disjunction.
///
/// The splitting strategies assume the first disjunct; the remaining
/// disjuncts are tried in order by `ConvexityHandler::backtrack_case_split`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaseSplitStrategy {
    /// Split on the disjunction, assuming its first disjunct.
    Exhaustive,
    /// Split on the disjunction, assuming its first disjunct, and count a
    /// model-based decision. NOTE(review): no model is consulted when choosing the case.
    ModelBased,
    /// Currently handled exactly like `ModelBased`.
    Heuristic,
    /// Do not split; the disjunction is put back into the pending queue.
    Lazy,
}
/// Counters describing the work done by a `ConvexityHandler`.
#[derive(Debug, Clone, Default)]
pub struct ConvexityStats {
    /// Incremented once per `ConvexityHandler::add_disjunction` call, i.e. when
    /// a disjunction is queued, not when it is resolved.
    pub disjunctions_processed: u64,
    /// Case splits attempted; compared against `ConvexityConfig::max_case_splits`.
    pub case_splits: u64,
    /// Case splits made through the model-based path (`ModelBased` and `Heuristic` strategies).
    pub model_based_decisions: u64,
    /// Case splits whose every case was exhausted during `backtrack_case_split`.
    pub case_split_conflicts: u64,
    /// Constraints recorded by conflict-driven learning.
    pub learned_constraints: u64,
}
/// Queues equality disjunctions reported by theories and resolves them one
/// equality at a time, by unit propagation or by case splitting with backtracking.
pub struct ConvexityHandler {
    /// Simplification, splitting-strategy, and budget settings.
    config: ConvexityConfig,
    /// Work counters.
    stats: ConvexityStats,
    /// Convexity property registered for each theory.
    theory_properties: FxHashMap<TheoryId, ConvexityProperty>,
    /// Disjunctions awaiting processing; consumed from the front.
    pending_disjunctions: VecDeque<EqualityDisjunction>,
    /// Open case splits; the most recent one is last.
    case_split_stack: Vec<CaseSplit>,
    /// Disjuncts of case splits whose every case was exhausted (when learning is enabled).
    learned: Vec<Vec<Equality>>,
    /// Current decision level.
    decision_level: DecisionLevel,
}
/// An open case split over one disjunction.
#[derive(Debug, Clone)]
struct CaseSplit {
    /// Decision level at which the split was opened; `backtrack` drops splits above its target.
    level: DecisionLevel,
    /// The disjunction being split on.
    disjunction: EqualityDisjunction,
    /// Indices into `disjunction.disjuncts` that have already been assumed.
    tried_cases: FxHashSet<usize>,
    /// Index of the disjunct assumed most recently (written but not read in the visible code).
    current_case: Option<usize>,
}
impl ConvexityHandler {
pub fn new() -> Self {
Self::with_config(ConvexityConfig::default())
}
pub fn with_config(config: ConvexityConfig) -> Self {
Self {
config,
stats: ConvexityStats::default(),
theory_properties: FxHashMap::default(),
pending_disjunctions: VecDeque::new(),
case_split_stack: Vec::new(),
learned: Vec::new(),
decision_level: 0,
}
}
pub fn stats(&self) -> &ConvexityStats {
&self.stats
}
pub fn register_theory(&mut self, theory: TheoryId, property: ConvexityProperty) {
self.theory_properties.insert(theory, property);
}
pub fn is_convex(&self, theory: TheoryId) -> bool {
matches!(
self.theory_properties.get(&theory),
Some(ConvexityProperty::Convex)
)
}
pub fn add_disjunction(&mut self, disjunction: EqualityDisjunction) {
if self.config.simplify_disjunctions
&& let Some(simplified) = self.simplify_disjunction(&disjunction)
{
self.pending_disjunctions.push_back(simplified);
self.stats.disjunctions_processed += 1;
return;
}
self.pending_disjunctions.push_back(disjunction);
self.stats.disjunctions_processed += 1;
}
fn simplify_disjunction(
&self,
disjunction: &EqualityDisjunction,
) -> Option<EqualityDisjunction> {
let mut unique_disjuncts = Vec::new();
let mut seen = FxHashSet::default();
for &eq in &disjunction.disjuncts {
if seen.insert(eq) {
unique_disjuncts.push(eq);
}
}
if unique_disjuncts.len() == disjunction.disjuncts.len() {
return None; }
Some(EqualityDisjunction::new(
unique_disjuncts,
disjunction.theory,
disjunction.level,
))
}
pub fn process_disjunctions(&mut self) -> Result<Option<Equality>, String> {
while let Some(disjunction) = self.pending_disjunctions.pop_front() {
if let Some(eq) = disjunction.get_unit() {
return Ok(Some(eq));
}
if self.stats.case_splits >= self.config.max_case_splits as u64 {
return Err("Maximum case splits exceeded".to_string());
}
match self.config.split_strategy {
CaseSplitStrategy::ModelBased => {
return self.model_based_split(&disjunction);
}
CaseSplitStrategy::Exhaustive => {
return self.exhaustive_split(&disjunction);
}
CaseSplitStrategy::Heuristic => {
return self.heuristic_split(&disjunction);
}
CaseSplitStrategy::Lazy => {
self.pending_disjunctions.push_back(disjunction);
continue;
}
}
}
Ok(None)
}
fn model_based_split(
&mut self,
disjunction: &EqualityDisjunction,
) -> Result<Option<Equality>, String> {
self.stats.case_splits += 1;
self.stats.model_based_decisions += 1;
if let Some(&eq) = disjunction.disjuncts.first() {
let split = CaseSplit {
level: self.decision_level,
disjunction: disjunction.clone(),
tried_cases: {
let mut set = FxHashSet::default();
set.insert(0);
set
},
current_case: Some(0),
};
self.case_split_stack.push(split);
return Ok(Some(eq));
}
Err("Empty disjunction".to_string())
}
fn exhaustive_split(
&mut self,
disjunction: &EqualityDisjunction,
) -> Result<Option<Equality>, String> {
self.stats.case_splits += 1;
if let Some((i, &eq)) = disjunction.disjuncts.iter().enumerate().next() {
let split = CaseSplit {
level: self.decision_level,
disjunction: disjunction.clone(),
tried_cases: {
let mut set = FxHashSet::default();
set.insert(i);
set
},
current_case: Some(i),
};
self.case_split_stack.push(split);
return Ok(Some(eq));
}
Err("Empty disjunction".to_string())
}
fn heuristic_split(
&mut self,
disjunction: &EqualityDisjunction,
) -> Result<Option<Equality>, String> {
self.model_based_split(disjunction)
}
pub fn backtrack_case_split(&mut self) -> Result<Option<Equality>, String> {
while let Some(mut split) = self.case_split_stack.pop() {
for (i, &eq) in split.disjunction.disjuncts.iter().enumerate() {
if !split.tried_cases.contains(&i) {
split.tried_cases.insert(i);
split.current_case = Some(i);
self.case_split_stack.push(split);
return Ok(Some(eq));
}
}
if self.config.conflict_driven_learning {
self.learn_conflict(&split.disjunction);
}
self.stats.case_split_conflicts += 1;
}
Ok(None) }
fn learn_conflict(&mut self, disjunction: &EqualityDisjunction) {
self.learned.push(disjunction.disjuncts.clone());
self.stats.learned_constraints += 1;
}
pub fn push_decision_level(&mut self) {
self.decision_level += 1;
}
pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
if level > self.decision_level {
return Err("Cannot backtrack to future level".to_string());
}
self.case_split_stack.retain(|split| split.level <= level);
let pending: Vec<_> = self.pending_disjunctions.drain(..).collect();
for disjunction in pending {
if disjunction.level <= level {
self.pending_disjunctions.push_back(disjunction);
}
}
self.decision_level = level;
Ok(())
}
pub fn learned_constraints(&self) -> &[Vec<Equality>] {
&self.learned
}
pub fn clear(&mut self) {
self.pending_disjunctions.clear();
self.case_split_stack.clear();
self.learned.clear();
self.decision_level = 0;
}
pub fn reset_stats(&mut self) {
self.stats = ConvexityStats::default();
}
pub fn has_pending(&self) -> bool {
!self.pending_disjunctions.is_empty()
}
pub fn pending_count(&self) -> usize {
self.pending_disjunctions.len()
}
}
impl Default for ConvexityHandler {
fn default() -> Self {
Self::new()
}
}
pub struct ModelBasedCombination {
models: FxHashMap<TheoryId, TheoryModel>,
derived_equalities: Vec<Equality>,
}
impl ModelBasedCombination {
pub fn new() -> Self {
Self {
models: FxHashMap::default(),
derived_equalities: Vec::new(),
}
}
pub fn add_model(&mut self, model: TheoryModel) {
self.models.insert(model.theory, model);
}
pub fn combine_models(&mut self) -> Result<Vec<Equality>, String> {
self.derived_equalities.clear();
let mut all_terms = FxHashSet::default();
for model in self.models.values() {
for &term in model.assignments.keys() {
all_terms.insert(term);
}
}
for &term1 in &all_terms {
for &term2 in &all_terms {
if term1 >= term2 {
continue;
}
let mut all_agree = true;
for model in self.models.values() {
if let (Some(val1), Some(val2)) =
(model.get_assignment(term1), model.get_assignment(term2))
&& val1 != val2
{
all_agree = false;
break;
}
}
if all_agree {
self.derived_equalities.push(Equality::new(term1, term2));
}
}
}
Ok(self.derived_equalities.clone())
}
pub fn clear(&mut self) {
self.models.clear();
self.derived_equalities.clear();
}
}
impl Default for ModelBasedCombination {
fn default() -> Self {
Self::new()
}
}
/// Unit propagation over equality disjunctions.
///
/// Unit disjunctions are queued as equalities to propagate; all others are
/// stored and can be shrunk with `simplify_with_equality`.
pub struct DisjunctiveReasoning {
    /// Stored disjunctions whose number of disjuncts is not exactly one.
    disjunctions: Vec<EqualityDisjunction>,
    /// Equalities waiting to be returned by `propagate_units`, oldest first.
    unit_queue: VecDeque<Equality>,
}

impl DisjunctiveReasoning {
    /// Creates an engine with no stored disjunctions and an empty unit queue.
    pub fn new() -> Self {
        Self {
            unit_queue: VecDeque::new(),
            disjunctions: Vec::new(),
        }
    }

    /// Adds a disjunction. A unit disjunction's equality is queued for
    /// propagation; any other disjunction (including an empty one) is stored.
    pub fn add_disjunction(&mut self, disjunction: EqualityDisjunction) {
        match disjunction.get_unit() {
            Some(unit) => self.unit_queue.push_back(unit),
            None => self.disjunctions.push(disjunction),
        }
    }

    /// Drains and returns every queued unit equality, oldest first.
    pub fn propagate_units(&mut self) -> Vec<Equality> {
        self.unit_queue.drain(..).collect()
    }

    /// Removes every occurrence of `eq` from the stored disjunctions.
    ///
    /// A disjunction left with exactly one disjunct moves to the unit queue.
    /// One left with none is discarded. The rest stay stored in their original order.
    ///
    /// NOTE(review): removing `eq` from a disjunction is the simplification for
    /// `eq` being *false*. If callers pass equalities known to hold,
    /// disjunctions containing `eq` are already satisfied and should be dropped
    /// instead; propagating the leftover disjunct would then be unsound.
    /// Likewise, a disjunction emptied here may signal a conflict that
    /// `has_conflict` does not report. Confirm the intended contract.
    pub fn simplify_with_equality(&mut self, eq: Equality) {
        let mut kept = Vec::with_capacity(self.disjunctions.len());
        for disjunction in self.disjunctions.drain(..) {
            let remaining: Vec<Equality> = disjunction
                .disjuncts
                .iter()
                .copied()
                .filter(|&candidate| candidate != eq)
                .collect();
            match remaining.len() {
                0 => {}
                1 => self.unit_queue.push_back(remaining[0]),
                _ => kept.push(EqualityDisjunction::new(
                    remaining,
                    disjunction.theory,
                    disjunction.level,
                )),
            }
        }
        self.disjunctions = kept;
    }

    /// Always returns `false`; this type does not currently track conflicts.
    pub fn has_conflict(&self) -> bool {
        false
    }

    /// Discards all stored disjunctions and queued unit equalities.
    pub fn clear(&mut self) {
        self.unit_queue.clear();
        self.disjunctions.clear();
    }
}

impl Default for DisjunctiveReasoning {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_equality_disjunction() {
        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(3, 4);
        let disj = EqualityDisjunction::new(vec![eq1, eq2], 0, 0);
        assert!(!disj.is_unit());
    }

    #[test]
    fn test_equality_is_canonical() {
        assert_eq!(Equality::new(2, 1), Equality::new(1, 2));
        let eq = Equality::new(9, 4);
        assert_eq!((eq.lhs, eq.rhs), (4, 9));
    }

    #[test]
    fn test_unit_disjunction() {
        let eq = Equality::new(1, 2);
        let disj = EqualityDisjunction::new(vec![eq], 0, 0);
        assert!(disj.is_unit());
        assert_eq!(disj.get_unit(), Some(eq));
    }

    #[test]
    fn test_handler_creation() {
        let handler = ConvexityHandler::new();
        assert_eq!(handler.stats().disjunctions_processed, 0);
    }

    #[test]
    fn test_register_theory() {
        let mut handler = ConvexityHandler::new();
        handler.register_theory(0, ConvexityProperty::Convex);
        assert!(handler.is_convex(0));
        handler.register_theory(1, ConvexityProperty::Unknown);
        assert!(!handler.is_convex(1));
        assert!(!handler.is_convex(2));
    }

    #[test]
    fn test_add_disjunction() {
        let mut handler = ConvexityHandler::new();
        let disj = EqualityDisjunction::new(vec![Equality::new(1, 2)], 0, 0);
        handler.add_disjunction(disj);
        assert_eq!(handler.pending_count(), 1);
    }

    #[test]
    fn test_process_unit_disjunction() {
        let mut handler = ConvexityHandler::new();
        let eq = Equality::new(1, 2);
        let disj = EqualityDisjunction::new(vec![eq], 0, 0);
        handler.add_disjunction(disj);
        let result = handler.process_disjunctions();
        assert!(result.is_ok());
        assert_eq!(result.ok().flatten(), Some(eq));
    }

    #[test]
    fn test_model_based_combination() {
        let mut mbc = ModelBasedCombination::new();
        let mut model1 = TheoryModel::new(0);
        model1.add_assignment(1, 10);
        model1.add_assignment(2, 10);
        mbc.add_model(model1);
        let equalities = mbc.combine_models().expect("Combination failed");
        assert!(!equalities.is_empty());
    }

    #[test]
    fn test_model_based_combination_disagreement() {
        let mut mbc = ModelBasedCombination::new();
        let mut model1 = TheoryModel::new(0);
        model1.add_assignment(1, 10);
        model1.add_assignment(2, 20);
        mbc.add_model(model1);
        let equalities = mbc.combine_models().expect("Combination failed");
        assert!(equalities.is_empty());
    }

    #[test]
    fn test_disjunctive_reasoning() {
        let mut dr = DisjunctiveReasoning::new();
        let eq = Equality::new(1, 2);
        let disj = EqualityDisjunction::new(vec![eq], 0, 0);
        dr.add_disjunction(disj);
        let propagated = dr.propagate_units();
        assert_eq!(propagated.len(), 1);
        assert_eq!(propagated[0], eq);
    }

    #[test]
    fn test_simplify_disjunction() {
        let mut handler = ConvexityHandler::new();
        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(1, 2);
        let disj = EqualityDisjunction::new(vec![eq1, eq2], 0, 0);
        handler.add_disjunction(disj);
        assert!(handler.has_pending());
        assert_eq!(handler.pending_count(), 1);
        // Deduplication leaves a unit disjunction, resolved without a case split.
        assert_eq!(handler.process_disjunctions(), Ok(Some(eq1)));
        assert_eq!(handler.stats().case_splits, 0);
    }

    #[test]
    fn test_backtrack() {
        let mut handler = ConvexityHandler::new();
        handler.push_decision_level();
        let disj = EqualityDisjunction::new(vec![Equality::new(1, 2)], 0, 1);
        handler.add_disjunction(disj);
        handler.backtrack(0).expect("Backtrack failed");
        assert_eq!(handler.pending_count(), 0);
    }

    #[test]
    fn test_backtrack_to_future_level_fails() {
        let mut handler = ConvexityHandler::new();
        assert!(handler.backtrack(1).is_err());
    }

    #[test]
    fn test_case_split() {
        let mut handler = ConvexityHandler::new();
        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(3, 4);
        let disj = EqualityDisjunction::new(vec![eq1, eq2], 0, 0);
        handler.add_disjunction(disj);
        let result = handler.process_disjunctions();
        assert!(result.is_ok());
        assert!(result.ok().flatten().is_some());
    }

    #[test]
    fn test_backtrack_case_split_tries_remaining_cases() {
        let mut handler = ConvexityHandler::new();
        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(3, 4);
        handler.add_disjunction(EqualityDisjunction::new(vec![eq1, eq2], 0, 0));
        assert_eq!(handler.process_disjunctions(), Ok(Some(eq1)));
        assert_eq!(handler.backtrack_case_split(), Ok(Some(eq2)));
        assert_eq!(handler.backtrack_case_split(), Ok(None));
        assert_eq!(handler.stats().case_splits, 1);
        assert_eq!(handler.stats().case_split_conflicts, 1);
        assert_eq!(handler.learned_constraints().len(), 1);
        assert_eq!(handler.learned_constraints()[0], vec![eq1, eq2]);
    }
}