Skip to main content

oxiz_solver/combination/
convexity.rs

1//! Convexity Checking and Handling for Theory Combination.
2//!
3//! This module provides convexity analysis for theories in combination:
4//! - Convexity checking for theory solvers
5//! - Non-convex theory handling strategies
6//! - Model-based case analysis for non-convex theories
7//! - Disjunctive reasoning
8//!
9//! ## Convexity
10//!
11//! A theory T is **convex** if for any conjunction of literals C and
12//! disjunction of equalities (t1 = s1) ∨ ... ∨ (tn = sn):
13//!
14//! If C ∧ T ⊨ (t1 = s1) ∨ ... ∨ (tn = sn),
15//! then C ∧ T ⊨ (ti = si) for some i.
16//!
//! - **Convex theories**: Equality, Uninterpreted Functions, Linear Arithmetic (rationals)
//! - **Non-convex theories**: Integer Arithmetic, Bit-vectors
19//!
20//! ## Non-Convex Theory Handling
21//!
22//! For non-convex theories, we must handle disjunctions explicitly:
23//! - Case splitting on equality disjunctions
24//! - Model-based theory combination
25//! - Conflict-driven learning to prune search space
26//!
27//! ## References
28//!
29//! - Nelson & Oppen (1979): "Simplification by Cooperating Decision Procedures"
//! - Tinelli & Harandi (1996): "A New Correctness Proof of the Nelson-Oppen Combination Procedure"
31//! - Z3's `smt/theory_opt.cpp`
32
33use rustc_hash::{FxHashMap, FxHashSet};
34use std::collections::VecDeque;
35
/// Numeric identifier of a term.
pub type TermId = u32;

/// Numeric identifier of a theory solver taking part in the combination.
pub type TheoryId = u32;

/// Decision level; the handlers in this module start at level 0.
pub type DecisionLevel = u32;
44
45/// Equality between terms.
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
47pub struct Equality {
48    /// Left-hand side.
49    pub lhs: TermId,
50    /// Right-hand side.
51    pub rhs: TermId,
52}
53
54impl Equality {
55    /// Create new equality.
56    pub fn new(lhs: TermId, rhs: TermId) -> Self {
57        if lhs <= rhs {
58            Self { lhs, rhs }
59        } else {
60            Self { lhs: rhs, rhs: lhs }
61        }
62    }
63}
64
65/// Disjunction of equalities.
66#[derive(Debug, Clone)]
67pub struct EqualityDisjunction {
68    /// Disjuncts (equalities).
69    pub disjuncts: Vec<Equality>,
70    /// Source theory.
71    pub theory: TheoryId,
72    /// Decision level.
73    pub level: DecisionLevel,
74}
75
76impl EqualityDisjunction {
77    /// Create new disjunction.
78    pub fn new(disjuncts: Vec<Equality>, theory: TheoryId, level: DecisionLevel) -> Self {
79        Self {
80            disjuncts,
81            theory,
82            level,
83        }
84    }
85
86    /// Check if disjunction is unit (single disjunct).
87    pub fn is_unit(&self) -> bool {
88        self.disjuncts.len() == 1
89    }
90
91    /// Get unit disjunct if disjunction is unit.
92    pub fn get_unit(&self) -> Option<Equality> {
93        if self.is_unit() {
94            self.disjuncts.first().copied()
95        } else {
96            None
97        }
98    }
99}
100
/// Whether a theory is convex in the Nelson-Oppen sense (see the module docs).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvexityProperty {
    /// Any entailed disjunction of equalities entails one of its disjuncts.
    Convex,
    /// Entailed disjunctions may need explicit case splitting.
    NonConvex,
    /// Convexity is not known, or depends on the theory fragment in use.
    Unknown,
}
111
112/// Theory model for case splitting.
113#[derive(Debug, Clone)]
114pub struct TheoryModel {
115    /// Theory identifier.
116    pub theory: TheoryId,
117    /// Variable assignments.
118    pub assignments: FxHashMap<TermId, TermId>,
119    /// Implied equalities.
120    pub equalities: Vec<Equality>,
121}
122
123impl TheoryModel {
124    /// Create new model.
125    pub fn new(theory: TheoryId) -> Self {
126        Self {
127            theory,
128            assignments: FxHashMap::default(),
129            equalities: Vec::new(),
130        }
131    }
132
133    /// Add assignment.
134    pub fn add_assignment(&mut self, term: TermId, value: TermId) {
135        self.assignments.insert(term, value);
136    }
137
138    /// Get assignment.
139    pub fn get_assignment(&self, term: TermId) -> Option<TermId> {
140        self.assignments.get(&term).copied()
141    }
142
143    /// Add implied equality.
144    pub fn add_equality(&mut self, eq: Equality) {
145        self.equalities.push(eq);
146    }
147}
148
149/// Configuration for convexity handling.
150#[derive(Debug, Clone)]
151pub struct ConvexityConfig {
152    /// Enable model-based case splitting.
153    pub model_based_splitting: bool,
154
155    /// Maximum case splits.
156    pub max_case_splits: usize,
157
158    /// Enable conflict-driven learning.
159    pub conflict_driven_learning: bool,
160
161    /// Case split strategy.
162    pub split_strategy: CaseSplitStrategy,
163
164    /// Enable disjunction simplification.
165    pub simplify_disjunctions: bool,
166}
167
168impl Default for ConvexityConfig {
169    fn default() -> Self {
170        Self {
171            model_based_splitting: true,
172            max_case_splits: 100,
173            conflict_driven_learning: true,
174            split_strategy: CaseSplitStrategy::ModelBased,
175            simplify_disjunctions: true,
176        }
177    }
178}
179
/// Policy used by [`ConvexityHandler`] to split non-unit disjunctions coming
/// from non-convex theories.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaseSplitStrategy {
    /// Enumerate every case of the disjunction.
    Exhaustive,
    /// Let a model guide which case is tried.
    ModelBased,
    /// Pick the case by heuristics.
    Heuristic,
    /// Postpone splitting for as long as possible.
    Lazy,
}
192
/// Work counters maintained by [`ConvexityHandler`].
#[derive(Debug, Clone, Default)]
pub struct ConvexityStats {
    /// Disjunctions handed to `add_disjunction`.
    pub disjunctions_processed: u64,
    /// Case splits started.
    pub case_splits: u64,
    /// Splits decided by the model-based (or heuristic) strategy.
    pub model_based_decisions: u64,
    /// Splits whose every disjunct was tried and failed.
    pub case_split_conflicts: u64,
    /// Constraints learned from such exhausted splits.
    pub learned_constraints: u64,
}
207
208/// Convexity checker and handler.
209pub struct ConvexityHandler {
210    /// Configuration.
211    config: ConvexityConfig,
212
213    /// Statistics.
214    stats: ConvexityStats,
215
216    /// Theory convexity properties.
217    theory_properties: FxHashMap<TheoryId, ConvexityProperty>,
218
219    /// Pending disjunctions.
220    pending_disjunctions: VecDeque<EqualityDisjunction>,
221
222    /// Case split stack.
223    case_split_stack: Vec<CaseSplit>,
224
225    /// Learned constraints.
226    learned: Vec<Vec<Equality>>,
227
228    /// Current decision level.
229    decision_level: DecisionLevel,
230}
231
232/// Case split record.
233#[derive(Debug, Clone)]
234struct CaseSplit {
235    /// Decision level where split was made.
236    level: DecisionLevel,
237    /// Disjunction being split.
238    disjunction: EqualityDisjunction,
239    /// Cases already tried.
240    tried_cases: FxHashSet<usize>,
241    /// Current case being explored.
242    current_case: Option<usize>,
243}
244
245impl ConvexityHandler {
246    /// Create new handler.
247    pub fn new() -> Self {
248        Self::with_config(ConvexityConfig::default())
249    }
250
251    /// Create with configuration.
252    pub fn with_config(config: ConvexityConfig) -> Self {
253        Self {
254            config,
255            stats: ConvexityStats::default(),
256            theory_properties: FxHashMap::default(),
257            pending_disjunctions: VecDeque::new(),
258            case_split_stack: Vec::new(),
259            learned: Vec::new(),
260            decision_level: 0,
261        }
262    }
263
264    /// Get statistics.
265    pub fn stats(&self) -> &ConvexityStats {
266        &self.stats
267    }
268
269    /// Register theory with convexity property.
270    pub fn register_theory(&mut self, theory: TheoryId, property: ConvexityProperty) {
271        self.theory_properties.insert(theory, property);
272    }
273
274    /// Check if theory is convex.
275    pub fn is_convex(&self, theory: TheoryId) -> bool {
276        matches!(
277            self.theory_properties.get(&theory),
278            Some(ConvexityProperty::Convex)
279        )
280    }
281
282    /// Add disjunction to process.
283    pub fn add_disjunction(&mut self, disjunction: EqualityDisjunction) {
284        if self.config.simplify_disjunctions
285            && let Some(simplified) = self.simplify_disjunction(&disjunction)
286        {
287            self.pending_disjunctions.push_back(simplified);
288            self.stats.disjunctions_processed += 1;
289            return;
290        }
291
292        self.pending_disjunctions.push_back(disjunction);
293        self.stats.disjunctions_processed += 1;
294    }
295
296    /// Simplify disjunction.
297    fn simplify_disjunction(
298        &self,
299        disjunction: &EqualityDisjunction,
300    ) -> Option<EqualityDisjunction> {
301        // Remove duplicate disjuncts
302        let mut unique_disjuncts = Vec::new();
303        let mut seen = FxHashSet::default();
304
305        for &eq in &disjunction.disjuncts {
306            if seen.insert(eq) {
307                unique_disjuncts.push(eq);
308            }
309        }
310
311        if unique_disjuncts.len() == disjunction.disjuncts.len() {
312            return None; // No simplification
313        }
314
315        Some(EqualityDisjunction::new(
316            unique_disjuncts,
317            disjunction.theory,
318            disjunction.level,
319        ))
320    }
321
322    /// Process pending disjunctions.
323    pub fn process_disjunctions(&mut self) -> Result<Option<Equality>, String> {
324        while let Some(disjunction) = self.pending_disjunctions.pop_front() {
325            // If unit, return the single equality
326            if let Some(eq) = disjunction.get_unit() {
327                return Ok(Some(eq));
328            }
329
330            // Non-unit disjunction: perform case split
331            if self.stats.case_splits >= self.config.max_case_splits as u64 {
332                return Err("Maximum case splits exceeded".to_string());
333            }
334
335            match self.config.split_strategy {
336                CaseSplitStrategy::ModelBased => {
337                    return self.model_based_split(&disjunction);
338                }
339                CaseSplitStrategy::Exhaustive => {
340                    return self.exhaustive_split(&disjunction);
341                }
342                CaseSplitStrategy::Heuristic => {
343                    return self.heuristic_split(&disjunction);
344                }
345                CaseSplitStrategy::Lazy => {
346                    // Defer splitting
347                    self.pending_disjunctions.push_back(disjunction);
348                    continue;
349                }
350            }
351        }
352
353        Ok(None)
354    }
355
356    /// Model-based case split.
357    fn model_based_split(
358        &mut self,
359        disjunction: &EqualityDisjunction,
360    ) -> Result<Option<Equality>, String> {
361        self.stats.case_splits += 1;
362        self.stats.model_based_decisions += 1;
363
364        // Choose first disjunct (model would guide this choice)
365        if let Some(&eq) = disjunction.disjuncts.first() {
366            // Record case split
367            let split = CaseSplit {
368                level: self.decision_level,
369                disjunction: disjunction.clone(),
370                tried_cases: {
371                    let mut set = FxHashSet::default();
372                    set.insert(0);
373                    set
374                },
375                current_case: Some(0),
376            };
377
378            self.case_split_stack.push(split);
379            return Ok(Some(eq));
380        }
381
382        Err("Empty disjunction".to_string())
383    }
384
385    /// Exhaustive case split.
386    fn exhaustive_split(
387        &mut self,
388        disjunction: &EqualityDisjunction,
389    ) -> Result<Option<Equality>, String> {
390        self.stats.case_splits += 1;
391
392        // Try first untried case
393        if let Some((i, &eq)) = disjunction.disjuncts.iter().enumerate().next() {
394            let split = CaseSplit {
395                level: self.decision_level,
396                disjunction: disjunction.clone(),
397                tried_cases: {
398                    let mut set = FxHashSet::default();
399                    set.insert(i);
400                    set
401                },
402                current_case: Some(i),
403            };
404
405            self.case_split_stack.push(split);
406            return Ok(Some(eq));
407        }
408
409        Err("Empty disjunction".to_string())
410    }
411
412    /// Heuristic-based split.
413    fn heuristic_split(
414        &mut self,
415        disjunction: &EqualityDisjunction,
416    ) -> Result<Option<Equality>, String> {
417        // Use model-based for now (could be enhanced with better heuristics)
418        self.model_based_split(disjunction)
419    }
420
421    /// Backtrack case split on conflict.
422    pub fn backtrack_case_split(&mut self) -> Result<Option<Equality>, String> {
423        while let Some(mut split) = self.case_split_stack.pop() {
424            // Try next untried case
425            for (i, &eq) in split.disjunction.disjuncts.iter().enumerate() {
426                if !split.tried_cases.contains(&i) {
427                    split.tried_cases.insert(i);
428                    split.current_case = Some(i);
429                    self.case_split_stack.push(split);
430                    return Ok(Some(eq));
431                }
432            }
433
434            // All cases tried, learn conflict
435            if self.config.conflict_driven_learning {
436                self.learn_conflict(&split.disjunction);
437            }
438
439            self.stats.case_split_conflicts += 1;
440        }
441
442        Ok(None) // No more cases to try
443    }
444
445    /// Learn conflict from exhausted disjunction.
446    fn learn_conflict(&mut self, disjunction: &EqualityDisjunction) {
447        // Learn that this disjunction is unsatisfiable
448        self.learned.push(disjunction.disjuncts.clone());
449        self.stats.learned_constraints += 1;
450    }
451
452    /// Push decision level.
453    pub fn push_decision_level(&mut self) {
454        self.decision_level += 1;
455    }
456
457    /// Backtrack to decision level.
458    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
459        if level > self.decision_level {
460            return Err("Cannot backtrack to future level".to_string());
461        }
462
463        // Remove case splits above this level
464        self.case_split_stack.retain(|split| split.level <= level);
465
466        // Remove disjunctions above this level
467        let pending: Vec<_> = self.pending_disjunctions.drain(..).collect();
468        for disjunction in pending {
469            if disjunction.level <= level {
470                self.pending_disjunctions.push_back(disjunction);
471            }
472        }
473
474        self.decision_level = level;
475        Ok(())
476    }
477
478    /// Get learned constraints.
479    pub fn learned_constraints(&self) -> &[Vec<Equality>] {
480        &self.learned
481    }
482
483    /// Clear all state.
484    pub fn clear(&mut self) {
485        self.pending_disjunctions.clear();
486        self.case_split_stack.clear();
487        self.learned.clear();
488        self.decision_level = 0;
489    }
490
491    /// Reset statistics.
492    pub fn reset_stats(&mut self) {
493        self.stats = ConvexityStats::default();
494    }
495
496    /// Check if there are pending disjunctions.
497    pub fn has_pending(&self) -> bool {
498        !self.pending_disjunctions.is_empty()
499    }
500
501    /// Get number of pending disjunctions.
502    pub fn pending_count(&self) -> usize {
503        self.pending_disjunctions.len()
504    }
505}
506
507impl Default for ConvexityHandler {
508    fn default() -> Self {
509        Self::new()
510    }
511}
512
513/// Model-based theory combination for non-convex theories.
514pub struct ModelBasedCombination {
515    /// Theory models.
516    models: FxHashMap<TheoryId, TheoryModel>,
517
518    /// Equalities derived from models.
519    derived_equalities: Vec<Equality>,
520}
521
522impl ModelBasedCombination {
523    /// Create new model-based combination.
524    pub fn new() -> Self {
525        Self {
526            models: FxHashMap::default(),
527            derived_equalities: Vec::new(),
528        }
529    }
530
531    /// Add theory model.
532    pub fn add_model(&mut self, model: TheoryModel) {
533        self.models.insert(model.theory, model);
534    }
535
536    /// Combine models to derive interface equalities.
537    pub fn combine_models(&mut self) -> Result<Vec<Equality>, String> {
538        self.derived_equalities.clear();
539
540        // Collect all terms from all models
541        let mut all_terms = FxHashSet::default();
542
543        for model in self.models.values() {
544            for &term in model.assignments.keys() {
545                all_terms.insert(term);
546            }
547        }
548
549        // Check consistency and derive equalities
550        for &term1 in &all_terms {
551            for &term2 in &all_terms {
552                if term1 >= term2 {
553                    continue;
554                }
555
556                // Check if all models agree that term1 = term2
557                let mut all_agree = true;
558
559                for model in self.models.values() {
560                    if let (Some(val1), Some(val2)) =
561                        (model.get_assignment(term1), model.get_assignment(term2))
562                        && val1 != val2
563                    {
564                        all_agree = false;
565                        break;
566                    }
567                }
568
569                if all_agree {
570                    self.derived_equalities.push(Equality::new(term1, term2));
571                }
572            }
573        }
574
575        Ok(self.derived_equalities.clone())
576    }
577
578    /// Clear all models.
579    pub fn clear(&mut self) {
580        self.models.clear();
581        self.derived_equalities.clear();
582    }
583}
584
585impl Default for ModelBasedCombination {
586    fn default() -> Self {
587        Self::new()
588    }
589}
590
591/// Disjunctive reasoning engine.
592pub struct DisjunctiveReasoning {
593    /// Active disjunctions.
594    disjunctions: Vec<EqualityDisjunction>,
595
596    /// Unit propagation queue.
597    unit_queue: VecDeque<Equality>,
598}
599
600impl DisjunctiveReasoning {
601    /// Create new disjunctive reasoning engine.
602    pub fn new() -> Self {
603        Self {
604            disjunctions: Vec::new(),
605            unit_queue: VecDeque::new(),
606        }
607    }
608
609    /// Add disjunction.
610    pub fn add_disjunction(&mut self, disjunction: EqualityDisjunction) {
611        if disjunction.is_unit() {
612            if let Some(eq) = disjunction.get_unit() {
613                self.unit_queue.push_back(eq);
614            }
615        } else {
616            self.disjunctions.push(disjunction);
617        }
618    }
619
620    /// Propagate unit disjunctions.
621    pub fn propagate_units(&mut self) -> Vec<Equality> {
622        let mut propagated = Vec::new();
623
624        while let Some(eq) = self.unit_queue.pop_front() {
625            propagated.push(eq);
626        }
627
628        propagated
629    }
630
631    /// Simplify disjunctions given an equality.
632    pub fn simplify_with_equality(&mut self, eq: Equality) {
633        let mut simplified = Vec::new();
634
635        for disjunction in self.disjunctions.drain(..) {
636            let mut new_disjuncts = Vec::new();
637
638            for &disjunct in &disjunction.disjuncts {
639                // Check if disjunct is satisfied by eq
640                if disjunct != eq {
641                    new_disjuncts.push(disjunct);
642                }
643            }
644
645            if !new_disjuncts.is_empty() {
646                let new_disjunction =
647                    EqualityDisjunction::new(new_disjuncts, disjunction.theory, disjunction.level);
648
649                if new_disjunction.is_unit() {
650                    if let Some(unit_eq) = new_disjunction.get_unit() {
651                        self.unit_queue.push_back(unit_eq);
652                    }
653                } else {
654                    simplified.push(new_disjunction);
655                }
656            }
657        }
658
659        self.disjunctions = simplified;
660    }
661
662    /// Check for conflicts (empty disjunctions).
663    pub fn has_conflict(&self) -> bool {
664        false // Simplified: would check for empty disjunctions
665    }
666
667    /// Clear all disjunctions.
668    pub fn clear(&mut self) {
669        self.disjunctions.clear();
670        self.unit_queue.clear();
671    }
672}
673
674impl Default for DisjunctiveReasoning {
675    fn default() -> Self {
676        Self::new()
677    }
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_equality_disjunction() {
        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(3, 4);

        let disj = EqualityDisjunction::new(vec![eq1, eq2], 0, 0);
        assert!(!disj.is_unit());
    }

    #[test]
    fn test_unit_disjunction() {
        let eq = Equality::new(1, 2);
        let disj = EqualityDisjunction::new(vec![eq], 0, 0);

        assert!(disj.is_unit());
        assert_eq!(disj.get_unit(), Some(eq));
    }

    #[test]
    fn test_equality_is_orientation_independent() {
        assert_eq!(Equality::new(4, 3), Equality::new(3, 4));
        assert_eq!(Equality::new(4, 3), Equality { lhs: 3, rhs: 4 });
    }

    #[test]
    fn test_handler_creation() {
        let handler = ConvexityHandler::new();
        assert_eq!(handler.stats().disjunctions_processed, 0);
    }

    #[test]
    fn test_register_theory() {
        let mut handler = ConvexityHandler::new();
        handler.register_theory(0, ConvexityProperty::Convex);

        assert!(handler.is_convex(0));
    }

    #[test]
    fn test_unregistered_or_non_convex_theory_is_not_convex() {
        let mut handler = ConvexityHandler::new();
        handler.register_theory(1, ConvexityProperty::NonConvex);
        handler.register_theory(2, ConvexityProperty::Unknown);

        assert!(!handler.is_convex(0));
        assert!(!handler.is_convex(1));
        assert!(!handler.is_convex(2));
    }

    #[test]
    fn test_add_disjunction() {
        let mut handler = ConvexityHandler::new();
        let disj = EqualityDisjunction::new(vec![Equality::new(1, 2)], 0, 0);

        handler.add_disjunction(disj);
        assert_eq!(handler.pending_count(), 1);
    }

    #[test]
    fn test_process_unit_disjunction() {
        let mut handler = ConvexityHandler::new();
        let eq = Equality::new(1, 2);
        let disj = EqualityDisjunction::new(vec![eq], 0, 0);

        handler.add_disjunction(disj);

        let result = handler.process_disjunctions();
        assert!(result.is_ok());
        assert_eq!(result.ok().flatten(), Some(eq));
    }

    #[test]
    fn test_model_based_combination() {
        let mut mbc = ModelBasedCombination::new();

        let mut model1 = TheoryModel::new(0);
        model1.add_assignment(1, 10);
        model1.add_assignment(2, 10);

        mbc.add_model(model1);

        let equalities = mbc.combine_models().expect("Combination failed");
        assert!(!equalities.is_empty());
    }

    #[test]
    fn test_disjunctive_reasoning() {
        let mut dr = DisjunctiveReasoning::new();

        let eq = Equality::new(1, 2);
        let disj = EqualityDisjunction::new(vec![eq], 0, 0);

        dr.add_disjunction(disj);

        let propagated = dr.propagate_units();
        assert_eq!(propagated.len(), 1);
        assert_eq!(propagated[0], eq);
    }

    #[test]
    fn test_disjunctive_reasoning_keeps_non_unit() {
        let mut dr = DisjunctiveReasoning::new();
        let disj =
            EqualityDisjunction::new(vec![Equality::new(1, 2), Equality::new(3, 4)], 0, 0);

        dr.add_disjunction(disj);

        assert!(dr.propagate_units().is_empty());
        assert!(!dr.has_conflict());
    }

    #[test]
    fn test_simplify_disjunction() {
        let mut handler = ConvexityHandler::new();

        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(1, 2); // Duplicate

        let disj = EqualityDisjunction::new(vec![eq1, eq2], 0, 0);
        handler.add_disjunction(disj);

        // The duplicate is dropped on insertion, so the queued disjunction is a
        // unit and is propagated without any case split.
        assert_eq!(handler.pending_count(), 1);
        let result = handler.process_disjunctions().expect("processing failed");
        assert_eq!(result, Some(eq1));
        assert_eq!(handler.stats().case_splits, 0);
    }

    #[test]
    fn test_backtrack() {
        let mut handler = ConvexityHandler::new();

        handler.push_decision_level();
        let disj = EqualityDisjunction::new(vec![Equality::new(1, 2)], 0, 1);
        handler.add_disjunction(disj);

        handler.backtrack(0).expect("Backtrack failed");
        assert_eq!(handler.pending_count(), 0);
    }

    #[test]
    fn test_backtrack_to_future_level_fails() {
        let mut handler = ConvexityHandler::new();
        assert!(handler.backtrack(1).is_err());
    }

    #[test]
    fn test_case_split() {
        let mut handler = ConvexityHandler::new();

        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(3, 4);
        let disj = EqualityDisjunction::new(vec![eq1, eq2], 0, 0);

        handler.add_disjunction(disj);

        let result = handler.process_disjunctions();
        assert!(result.is_ok());
        assert!(result.ok().flatten().is_some());
    }

    #[test]
    fn test_backtrack_case_split_tries_remaining_cases() {
        let mut handler = ConvexityHandler::new();

        let eq1 = Equality::new(1, 2);
        let eq2 = Equality::new(3, 4);
        handler.add_disjunction(EqualityDisjunction::new(vec![eq1, eq2], 0, 0));

        assert_eq!(handler.process_disjunctions().expect("split failed"), Some(eq1));
        assert_eq!(handler.stats().case_splits, 1);

        assert_eq!(handler.backtrack_case_split().expect("retry failed"), Some(eq2));
        assert_eq!(handler.backtrack_case_split().expect("exhaustion failed"), None);

        assert_eq!(handler.learned_constraints().len(), 1);
        assert_eq!(handler.stats().case_split_conflicts, 1);
    }

    #[test]
    fn test_max_case_splits_exceeded() {
        let config = ConvexityConfig {
            max_case_splits: 0,
            ..ConvexityConfig::default()
        };
        let mut handler = ConvexityHandler::with_config(config);

        let disj =
            EqualityDisjunction::new(vec![Equality::new(1, 2), Equality::new(3, 4)], 0, 0);
        handler.add_disjunction(disj);

        assert!(handler.process_disjunctions().is_err());
    }
}