1use rustc_hash::{FxHashMap, FxHashSet};
34use std::collections::VecDeque;
35
/// Identifier of a term, as used in [`Equality`] operands and [`TheoryModel`] assignments.
pub type TermId = u32;

/// Identifier of a theory (see `ConvexityHandler::register_theory`).
pub type TheoryId = u32;

/// Decision level; `ConvexityHandler::backtrack` discards state recorded above the target level.
pub type DecisionLevel = u32;
44
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
47pub struct Equality {
48 pub lhs: TermId,
50 pub rhs: TermId,
52}
53
54impl Equality {
55 pub fn new(lhs: TermId, rhs: TermId) -> Self {
57 if lhs <= rhs {
58 Self { lhs, rhs }
59 } else {
60 Self { lhs: rhs, rhs: lhs }
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
67pub struct EqualityDisjunction {
68 pub disjuncts: Vec<Equality>,
70 pub theory: TheoryId,
72 pub level: DecisionLevel,
74}
75
76impl EqualityDisjunction {
77 pub fn new(disjuncts: Vec<Equality>, theory: TheoryId, level: DecisionLevel) -> Self {
79 Self {
80 disjuncts,
81 theory,
82 level,
83 }
84 }
85
86 pub fn is_unit(&self) -> bool {
88 self.disjuncts.len() == 1
89 }
90
91 pub fn get_unit(&self) -> Option<Equality> {
93 if self.is_unit() {
94 self.disjuncts.first().copied()
95 } else {
96 None
97 }
98 }
99}
100
/// Convexity classification registered for a theory via `ConvexityHandler::register_theory`.
///
/// Only `Convex` makes `ConvexityHandler::is_convex` return `true`; `NonConvex`,
/// `Unknown`, and unregistered theories all report `false`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvexityProperty {
    /// The theory is classified as convex.
    Convex,
    /// The theory is classified as non-convex.
    NonConvex,
    /// The theory's convexity has not been determined.
    Unknown,
}
111
112#[derive(Debug, Clone)]
114pub struct TheoryModel {
115 pub theory: TheoryId,
117 pub assignments: FxHashMap<TermId, TermId>,
119 pub equalities: Vec<Equality>,
121}
122
123impl TheoryModel {
124 pub fn new(theory: TheoryId) -> Self {
126 Self {
127 theory,
128 assignments: FxHashMap::default(),
129 equalities: Vec::new(),
130 }
131 }
132
133 pub fn add_assignment(&mut self, term: TermId, value: TermId) {
135 self.assignments.insert(term, value);
136 }
137
138 pub fn get_assignment(&self, term: TermId) -> Option<TermId> {
140 self.assignments.get(&term).copied()
141 }
142
143 pub fn add_equality(&mut self, eq: Equality) {
145 self.equalities.push(eq);
146 }
147}
148
149#[derive(Debug, Clone)]
151pub struct ConvexityConfig {
152 pub model_based_splitting: bool,
154
155 pub max_case_splits: usize,
157
158 pub conflict_driven_learning: bool,
160
161 pub split_strategy: CaseSplitStrategy,
163
164 pub simplify_disjunctions: bool,
166}
167
168impl Default for ConvexityConfig {
169 fn default() -> Self {
170 Self {
171 model_based_splitting: true,
172 max_case_splits: 100,
173 conflict_driven_learning: true,
174 split_strategy: CaseSplitStrategy::ModelBased,
175 simplify_disjunctions: true,
176 }
177 }
178}
179
/// How `ConvexityHandler::process_disjunctions` treats a non-unit disjunction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaseSplitStrategy {
    /// Open a case split assuming the first disjunct; `backtrack_case_split`
    /// later tries the remaining ones.
    Exhaustive,
    /// Like `Exhaustive`, but also counts a model-based decision.
    /// NOTE(review): no model is consulted in this module — the first disjunct is chosen.
    ModelBased,
    /// Currently delegates to the `ModelBased` split.
    Heuristic,
    /// Do not split; the disjunction is pushed back onto the pending queue.
    Lazy,
}
192
/// Counters maintained by `ConvexityHandler`.
#[derive(Debug, Clone, Default)]
pub struct ConvexityStats {
    /// Disjunctions passed to `add_disjunction`.
    pub disjunctions_processed: u64,
    /// Case splits counted by the split routines; compared against `max_case_splits`.
    pub case_splits: u64,
    /// Splits made through the model-based (and heuristic) strategy.
    pub model_based_decisions: u64,
    /// Splits popped by `backtrack_case_split` after all their disjuncts were tried.
    pub case_split_conflicts: u64,
    /// Entries appended to the learned list by `learn_conflict`.
    pub learned_constraints: u64,
}
207
/// Queues disjunctions of equalities and resolves non-unit ones by case splitting,
/// tracking per-theory convexity classifications along the way.
pub struct ConvexityHandler {
    /// Configuration supplied at construction.
    config: ConvexityConfig,

    /// Counters updated as disjunctions are added, split, and backtracked.
    stats: ConvexityStats,

    /// Convexity classification per registered theory.
    theory_properties: FxHashMap<TheoryId, ConvexityProperty>,

    /// Disjunctions awaiting `process_disjunctions`, consumed from the front.
    pending_disjunctions: VecDeque<EqualityDisjunction>,

    /// Open case splits; the most recent one is last.
    case_split_stack: Vec<CaseSplit>,

    /// Disjunct lists recorded by `learn_conflict`.
    learned: Vec<Vec<Equality>>,

    /// Current decision level (raised by `push_decision_level`, lowered by `backtrack`).
    decision_level: DecisionLevel,
}

/// One open case split over a disjunction.
#[derive(Debug, Clone)]
struct CaseSplit {
    /// Decision level at which the split was opened.
    level: DecisionLevel,

    /// The disjunction being split on.
    disjunction: EqualityDisjunction,

    /// Indices of disjuncts that have already been assumed.
    tried_cases: FxHashSet<usize>,

    /// Index of the disjunct currently assumed.
    current_case: Option<usize>,
}
244
245impl ConvexityHandler {
246 pub fn new() -> Self {
248 Self::with_config(ConvexityConfig::default())
249 }
250
251 pub fn with_config(config: ConvexityConfig) -> Self {
253 Self {
254 config,
255 stats: ConvexityStats::default(),
256 theory_properties: FxHashMap::default(),
257 pending_disjunctions: VecDeque::new(),
258 case_split_stack: Vec::new(),
259 learned: Vec::new(),
260 decision_level: 0,
261 }
262 }
263
264 pub fn stats(&self) -> &ConvexityStats {
266 &self.stats
267 }
268
269 pub fn register_theory(&mut self, theory: TheoryId, property: ConvexityProperty) {
271 self.theory_properties.insert(theory, property);
272 }
273
274 pub fn is_convex(&self, theory: TheoryId) -> bool {
276 matches!(
277 self.theory_properties.get(&theory),
278 Some(ConvexityProperty::Convex)
279 )
280 }
281
282 pub fn add_disjunction(&mut self, disjunction: EqualityDisjunction) {
284 if self.config.simplify_disjunctions
285 && let Some(simplified) = self.simplify_disjunction(&disjunction)
286 {
287 self.pending_disjunctions.push_back(simplified);
288 self.stats.disjunctions_processed += 1;
289 return;
290 }
291
292 self.pending_disjunctions.push_back(disjunction);
293 self.stats.disjunctions_processed += 1;
294 }
295
296 fn simplify_disjunction(
298 &self,
299 disjunction: &EqualityDisjunction,
300 ) -> Option<EqualityDisjunction> {
301 let mut unique_disjuncts = Vec::new();
303 let mut seen = FxHashSet::default();
304
305 for &eq in &disjunction.disjuncts {
306 if seen.insert(eq) {
307 unique_disjuncts.push(eq);
308 }
309 }
310
311 if unique_disjuncts.len() == disjunction.disjuncts.len() {
312 return None; }
314
315 Some(EqualityDisjunction::new(
316 unique_disjuncts,
317 disjunction.theory,
318 disjunction.level,
319 ))
320 }
321
322 pub fn process_disjunctions(&mut self) -> Result<Option<Equality>, String> {
324 while let Some(disjunction) = self.pending_disjunctions.pop_front() {
325 if let Some(eq) = disjunction.get_unit() {
327 return Ok(Some(eq));
328 }
329
330 if self.stats.case_splits >= self.config.max_case_splits as u64 {
332 return Err("Maximum case splits exceeded".to_string());
333 }
334
335 match self.config.split_strategy {
336 CaseSplitStrategy::ModelBased => {
337 return self.model_based_split(&disjunction);
338 }
339 CaseSplitStrategy::Exhaustive => {
340 return self.exhaustive_split(&disjunction);
341 }
342 CaseSplitStrategy::Heuristic => {
343 return self.heuristic_split(&disjunction);
344 }
345 CaseSplitStrategy::Lazy => {
346 self.pending_disjunctions.push_back(disjunction);
348 continue;
349 }
350 }
351 }
352
353 Ok(None)
354 }
355
356 fn model_based_split(
358 &mut self,
359 disjunction: &EqualityDisjunction,
360 ) -> Result<Option<Equality>, String> {
361 self.stats.case_splits += 1;
362 self.stats.model_based_decisions += 1;
363
364 if let Some(&eq) = disjunction.disjuncts.first() {
366 let split = CaseSplit {
368 level: self.decision_level,
369 disjunction: disjunction.clone(),
370 tried_cases: {
371 let mut set = FxHashSet::default();
372 set.insert(0);
373 set
374 },
375 current_case: Some(0),
376 };
377
378 self.case_split_stack.push(split);
379 return Ok(Some(eq));
380 }
381
382 Err("Empty disjunction".to_string())
383 }
384
385 fn exhaustive_split(
387 &mut self,
388 disjunction: &EqualityDisjunction,
389 ) -> Result<Option<Equality>, String> {
390 self.stats.case_splits += 1;
391
392 if let Some((i, &eq)) = disjunction.disjuncts.iter().enumerate().next() {
394 let split = CaseSplit {
395 level: self.decision_level,
396 disjunction: disjunction.clone(),
397 tried_cases: {
398 let mut set = FxHashSet::default();
399 set.insert(i);
400 set
401 },
402 current_case: Some(i),
403 };
404
405 self.case_split_stack.push(split);
406 return Ok(Some(eq));
407 }
408
409 Err("Empty disjunction".to_string())
410 }
411
412 fn heuristic_split(
414 &mut self,
415 disjunction: &EqualityDisjunction,
416 ) -> Result<Option<Equality>, String> {
417 self.model_based_split(disjunction)
419 }
420
421 pub fn backtrack_case_split(&mut self) -> Result<Option<Equality>, String> {
423 while let Some(mut split) = self.case_split_stack.pop() {
424 for (i, &eq) in split.disjunction.disjuncts.iter().enumerate() {
426 if !split.tried_cases.contains(&i) {
427 split.tried_cases.insert(i);
428 split.current_case = Some(i);
429 self.case_split_stack.push(split);
430 return Ok(Some(eq));
431 }
432 }
433
434 if self.config.conflict_driven_learning {
436 self.learn_conflict(&split.disjunction);
437 }
438
439 self.stats.case_split_conflicts += 1;
440 }
441
442 Ok(None) }
444
445 fn learn_conflict(&mut self, disjunction: &EqualityDisjunction) {
447 self.learned.push(disjunction.disjuncts.clone());
449 self.stats.learned_constraints += 1;
450 }
451
452 pub fn push_decision_level(&mut self) {
454 self.decision_level += 1;
455 }
456
457 pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
459 if level > self.decision_level {
460 return Err("Cannot backtrack to future level".to_string());
461 }
462
463 self.case_split_stack.retain(|split| split.level <= level);
465
466 let pending: Vec<_> = self.pending_disjunctions.drain(..).collect();
468 for disjunction in pending {
469 if disjunction.level <= level {
470 self.pending_disjunctions.push_back(disjunction);
471 }
472 }
473
474 self.decision_level = level;
475 Ok(())
476 }
477
478 pub fn learned_constraints(&self) -> &[Vec<Equality>] {
480 &self.learned
481 }
482
483 pub fn clear(&mut self) {
485 self.pending_disjunctions.clear();
486 self.case_split_stack.clear();
487 self.learned.clear();
488 self.decision_level = 0;
489 }
490
491 pub fn reset_stats(&mut self) {
493 self.stats = ConvexityStats::default();
494 }
495
496 pub fn has_pending(&self) -> bool {
498 !self.pending_disjunctions.is_empty()
499 }
500
501 pub fn pending_count(&self) -> usize {
503 self.pending_disjunctions.len()
504 }
505}
506
507impl Default for ConvexityHandler {
508 fn default() -> Self {
509 Self::new()
510 }
511}
512
513pub struct ModelBasedCombination {
515 models: FxHashMap<TheoryId, TheoryModel>,
517
518 derived_equalities: Vec<Equality>,
520}
521
522impl ModelBasedCombination {
523 pub fn new() -> Self {
525 Self {
526 models: FxHashMap::default(),
527 derived_equalities: Vec::new(),
528 }
529 }
530
531 pub fn add_model(&mut self, model: TheoryModel) {
533 self.models.insert(model.theory, model);
534 }
535
536 pub fn combine_models(&mut self) -> Result<Vec<Equality>, String> {
538 self.derived_equalities.clear();
539
540 let mut all_terms = FxHashSet::default();
542
543 for model in self.models.values() {
544 for &term in model.assignments.keys() {
545 all_terms.insert(term);
546 }
547 }
548
549 for &term1 in &all_terms {
551 for &term2 in &all_terms {
552 if term1 >= term2 {
553 continue;
554 }
555
556 let mut all_agree = true;
558
559 for model in self.models.values() {
560 if let (Some(val1), Some(val2)) =
561 (model.get_assignment(term1), model.get_assignment(term2))
562 && val1 != val2
563 {
564 all_agree = false;
565 break;
566 }
567 }
568
569 if all_agree {
570 self.derived_equalities.push(Equality::new(term1, term2));
571 }
572 }
573 }
574
575 Ok(self.derived_equalities.clone())
576 }
577
578 pub fn clear(&mut self) {
580 self.models.clear();
581 self.derived_equalities.clear();
582 }
583}
584
585impl Default for ModelBasedCombination {
586 fn default() -> Self {
587 Self::new()
588 }
589}
590
591pub struct DisjunctiveReasoning {
593 disjunctions: Vec<EqualityDisjunction>,
595
596 unit_queue: VecDeque<Equality>,
598}
599
600impl DisjunctiveReasoning {
601 pub fn new() -> Self {
603 Self {
604 disjunctions: Vec::new(),
605 unit_queue: VecDeque::new(),
606 }
607 }
608
609 pub fn add_disjunction(&mut self, disjunction: EqualityDisjunction) {
611 if disjunction.is_unit() {
612 if let Some(eq) = disjunction.get_unit() {
613 self.unit_queue.push_back(eq);
614 }
615 } else {
616 self.disjunctions.push(disjunction);
617 }
618 }
619
620 pub fn propagate_units(&mut self) -> Vec<Equality> {
622 let mut propagated = Vec::new();
623
624 while let Some(eq) = self.unit_queue.pop_front() {
625 propagated.push(eq);
626 }
627
628 propagated
629 }
630
631 pub fn simplify_with_equality(&mut self, eq: Equality) {
633 let mut simplified = Vec::new();
634
635 for disjunction in self.disjunctions.drain(..) {
636 let mut new_disjuncts = Vec::new();
637
638 for &disjunct in &disjunction.disjuncts {
639 if disjunct != eq {
641 new_disjuncts.push(disjunct);
642 }
643 }
644
645 if !new_disjuncts.is_empty() {
646 let new_disjunction =
647 EqualityDisjunction::new(new_disjuncts, disjunction.theory, disjunction.level);
648
649 if new_disjunction.is_unit() {
650 if let Some(unit_eq) = new_disjunction.get_unit() {
651 self.unit_queue.push_back(unit_eq);
652 }
653 } else {
654 simplified.push(new_disjunction);
655 }
656 }
657 }
658
659 self.disjunctions = simplified;
660 }
661
662 pub fn has_conflict(&self) -> bool {
664 false }
666
667 pub fn clear(&mut self) {
669 self.disjunctions.clear();
670 self.unit_queue.clear();
671 }
672}
673
674impl Default for DisjunctiveReasoning {
675 fn default() -> Self {
676 Self::new()
677 }
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a disjunction owned by theory 0 at decision level `level`.
    fn disjunction_at(disjuncts: Vec<Equality>, level: DecisionLevel) -> EqualityDisjunction {
        EqualityDisjunction::new(disjuncts, 0, level)
    }

    #[test]
    fn test_equality_disjunction() {
        let disj = disjunction_at(vec![Equality::new(1, 2), Equality::new(3, 4)], 0);
        assert!(!disj.is_unit());
    }

    #[test]
    fn test_unit_disjunction() {
        let only = Equality::new(1, 2);
        let disj = disjunction_at(vec![only], 0);

        assert!(disj.is_unit());
        assert_eq!(disj.get_unit(), Some(only));
    }

    #[test]
    fn test_handler_creation() {
        assert_eq!(ConvexityHandler::new().stats().disjunctions_processed, 0);
    }

    #[test]
    fn test_register_theory() {
        let mut handler = ConvexityHandler::new();
        handler.register_theory(0, ConvexityProperty::Convex);

        assert!(handler.is_convex(0));
    }

    #[test]
    fn test_add_disjunction() {
        let mut handler = ConvexityHandler::new();
        handler.add_disjunction(disjunction_at(vec![Equality::new(1, 2)], 0));

        assert_eq!(handler.pending_count(), 1);
    }

    #[test]
    fn test_process_unit_disjunction() {
        let mut handler = ConvexityHandler::new();
        let only = Equality::new(1, 2);
        handler.add_disjunction(disjunction_at(vec![only], 0));

        let outcome = handler.process_disjunctions();
        assert!(outcome.is_ok());
        assert_eq!(outcome.ok().flatten(), Some(only));
    }

    #[test]
    fn test_model_based_combination() {
        let mut model = TheoryModel::new(0);
        model.add_assignment(1, 10);
        model.add_assignment(2, 10);

        let mut combination = ModelBasedCombination::new();
        combination.add_model(model);

        let derived = combination.combine_models().expect("Combination failed");
        assert!(!derived.is_empty());
    }

    #[test]
    fn test_disjunctive_reasoning() {
        let only = Equality::new(1, 2);
        let mut reasoning = DisjunctiveReasoning::new();
        reasoning.add_disjunction(disjunction_at(vec![only], 0));

        let propagated = reasoning.propagate_units();
        assert_eq!(propagated.len(), 1);
        assert_eq!(propagated[0], only);
    }

    #[test]
    fn test_simplify_disjunction() {
        let mut handler = ConvexityHandler::new();
        let duplicated = Equality::new(1, 2);
        handler.add_disjunction(disjunction_at(vec![duplicated, duplicated], 0));

        assert!(handler.has_pending());
    }

    #[test]
    fn test_backtrack() {
        let mut handler = ConvexityHandler::new();
        handler.push_decision_level();
        handler.add_disjunction(disjunction_at(vec![Equality::new(1, 2)], 1));

        handler.backtrack(0).expect("Backtrack failed");
        assert_eq!(handler.pending_count(), 0);
    }

    #[test]
    fn test_case_split() {
        let mut handler = ConvexityHandler::new();
        handler.add_disjunction(disjunction_at(
            vec![Equality::new(1, 2), Equality::new(3, 4)],
            0,
        ));

        let outcome = handler.process_disjunctions();
        assert!(outcome.is_ok());
        assert!(outcome.ok().flatten().is_some());
    }
}