Skip to main content

oxiz_solver/
shared_terms.rs

1//! Shared Terms Management for Theory Combination.
2#![allow(dead_code)] // Under development
3//!
4//! Manages terms that appear in multiple theories, enabling efficient
5//! equality sharing in Nelson-Oppen combination.
6//!
7//! ## Architecture
8//!
9//! - **Term Index**: Fast lookup of shared terms
10//! - **Theory Subscriptions**: Theories register interest in terms
11//! - **Notification System**: Propagate equality information between theories
12//!
13//! ## References
14//!
15//! - Nelson & Oppen: "Simplification by Cooperating Decision Procedures" (1979)
16//! - Z3's `smt/theory_combine.cpp`
17
18#[allow(unused_imports)]
19use crate::prelude::*;
20use oxiz_core::TermId;
21
/// Identifier of a theory participating in the combination.
///
/// This is a plain `usize`; the code in this file only stores and compares it.
pub type TheoryId = usize;
24
25/// Equality between two terms.
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
27pub struct Equality {
28    /// Left-hand side term.
29    pub lhs: TermId,
30    /// Right-hand side term.
31    pub rhs: TermId,
32}
33
34impl Equality {
35    /// Create a new equality.
36    pub fn new(lhs: TermId, rhs: TermId) -> Self {
37        // Normalize: smaller TermId first
38        if lhs.raw() <= rhs.raw() {
39            Self { lhs, rhs }
40        } else {
41            Self { lhs: rhs, rhs: lhs }
42        }
43    }
44}
45
46/// Information about a shared term.
47#[derive(Debug, Clone)]
48struct SharedTermInfo {
49    /// Theories that use this term.
50    theories: FxHashSet<TheoryId>,
51    /// Representative term in equivalence class.
52    representative: TermId,
53}
54
/// Tunable settings for the shared terms manager.
#[derive(Debug, Clone)]
pub struct SharedTermsConfig {
    /// Whether notifications should be batched.
    pub enable_batching: bool,
    /// Number of pending equalities at which a flush is forced.
    pub max_batch_size: usize,
}

impl Default for SharedTermsConfig {
    /// Batching enabled, with a forced flush once 1000 equalities are pending.
    fn default() -> Self {
        let enable_batching = true;
        let max_batch_size = 1000;
        SharedTermsConfig {
            enable_batching,
            max_batch_size,
        }
    }
}
72
/// Counters describing the activity of the shared terms manager.
///
/// Every counter starts at zero (`Default`).
#[derive(Debug, Clone, Default)]
pub struct SharedTermsStats {
    /// How many distinct terms have been registered.
    pub terms_registered: u64,
    /// How many theory subscriptions have been recorded.
    pub subscriptions: u64,
    /// How many equalities have been queued for propagation.
    pub equalities_propagated: u64,
    /// How many non-empty batches have been flushed.
    pub batches_sent: u64,
}
85
86/// Shared terms manager for theory combination.
87#[derive(Debug)]
88pub struct SharedTermsManager {
89    /// Configuration.
90    config: SharedTermsConfig,
91    /// Shared term information.
92    terms: FxHashMap<TermId, SharedTermInfo>,
93    /// Equality classes (union-find).
94    parent: FxHashMap<TermId, TermId>,
95    /// Pending equalities to propagate.
96    pending_equalities: Vec<Equality>,
97    /// Theories subscribed to each term.
98    subscriptions: FxHashMap<TermId, FxHashSet<TheoryId>>,
99    /// Statistics.
100    stats: SharedTermsStats,
101}
102
103impl SharedTermsManager {
104    /// Create a new shared terms manager.
105    pub fn new(config: SharedTermsConfig) -> Self {
106        Self {
107            config,
108            terms: FxHashMap::default(),
109            parent: FxHashMap::default(),
110            pending_equalities: Vec::new(),
111            subscriptions: FxHashMap::default(),
112            stats: SharedTermsStats::default(),
113        }
114    }
115
116    /// Create with default configuration.
117    pub fn default_config() -> Self {
118        Self::new(SharedTermsConfig::default())
119    }
120
121    /// Register a shared term.
122    pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
123        let entry = self.terms.entry(term).or_insert_with(|| {
124            self.stats.terms_registered += 1;
125            SharedTermInfo {
126                theories: FxHashSet::default(),
127                representative: term,
128            }
129        });
130
131        entry.theories.insert(theory);
132        self.stats.subscriptions += 1;
133
134        // Also track subscriptions separately for fast lookup
135        self.subscriptions.entry(term).or_default().insert(theory);
136    }
137
138    /// Check if a term is shared between multiple theories.
139    pub fn is_shared(&self, term: TermId) -> bool {
140        self.terms
141            .get(&term)
142            .map(|info| info.theories.len() > 1)
143            .unwrap_or(false)
144    }
145
146    /// Get theories that use a term.
147    pub fn get_theories(&self, term: TermId) -> Vec<TheoryId> {
148        self.terms
149            .get(&term)
150            .map(|info| info.theories.iter().copied().collect())
151            .unwrap_or_default()
152    }
153
154    /// Assert equality between two terms.
155    ///
156    /// This merges their equivalence classes and queues notifications.
157    pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) {
158        let lhs_rep = self.find(lhs);
159        let rhs_rep = self.find(rhs);
160
161        if lhs_rep == rhs_rep {
162            return; // Already equal
163        }
164
165        // Union: make lhs_rep point to rhs_rep
166        self.parent.insert(lhs_rep, rhs_rep);
167
168        // Queue equality for propagation
169        let equality = Equality::new(lhs, rhs);
170        self.pending_equalities.push(equality);
171        self.stats.equalities_propagated += 1;
172
173        // Check if should flush batch
174        if self.pending_equalities.len() >= self.config.max_batch_size {
175            self.flush_equalities();
176        }
177    }
178
179    /// Find representative of equivalence class (with path compression).
180    fn find(&mut self, term: TermId) -> TermId {
181        if let Some(&parent) = self.parent.get(&term)
182            && parent != term
183        {
184            let root = self.find(parent);
185            self.parent.insert(term, root); // Path compression
186            return root;
187        }
188
189        term
190    }
191
192    /// Check if two terms are in the same equivalence class.
193    pub fn are_equal(&mut self, lhs: TermId, rhs: TermId) -> bool {
194        self.find(lhs) == self.find(rhs)
195    }
196
197    /// Get pending equalities to propagate.
198    pub fn get_pending_equalities(&self) -> &[Equality] {
199        &self.pending_equalities
200    }
201
202    /// Flush pending equalities (send to theories).
203    pub fn flush_equalities(&mut self) {
204        if !self.pending_equalities.is_empty() {
205            self.stats.batches_sent += 1;
206            self.pending_equalities.clear();
207        }
208    }
209
210    /// Get all shared terms.
211    pub fn get_shared_terms(&self) -> Vec<TermId> {
212        self.terms
213            .iter()
214            .filter(|(_, info)| info.theories.len() > 1)
215            .map(|(&term, _)| term)
216            .collect()
217    }
218
219    /// Get statistics.
220    pub fn stats(&self) -> &SharedTermsStats {
221        &self.stats
222    }
223
224    /// Reset manager state.
225    pub fn reset(&mut self) {
226        self.terms.clear();
227        self.parent.clear();
228        self.pending_equalities.clear();
229        self.subscriptions.clear();
230        self.stats = SharedTermsStats::default();
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `TermId` from a raw number.
    fn term(id: u32) -> TermId {
        TermId::new(id)
    }

    /// A freshly created manager has registered nothing.
    #[test]
    fn test_manager_creation() {
        let fresh = SharedTermsManager::default_config();
        let stats = fresh.stats();
        assert_eq!(stats.terms_registered, 0);
    }

    /// A term registered by two theories is reported as shared.
    #[test]
    fn test_register_term() {
        let mut mgr = SharedTermsManager::default_config();
        let shared = term(1);

        // Theories 0 and 1 both register the same term.
        for theory in 0..2 {
            mgr.register_term(shared, theory);
        }

        assert!(mgr.is_shared(shared));
        let users = mgr.get_theories(shared);
        assert_eq!(users.len(), 2);
    }

    /// Asserting an equality merges the classes and queues one notification.
    #[test]
    fn test_equality() {
        let mut mgr = SharedTermsManager::default_config();
        let (a, b) = (term(1), term(2));

        mgr.assert_equality(a, b);

        assert!(mgr.are_equal(a, b));
        let queued = mgr.get_pending_equalities();
        assert_eq!(queued.len(), 1);
    }

    /// Equality is transitive across two assertions.
    #[test]
    fn test_equality_transitivity() {
        let mut mgr = SharedTermsManager::default_config();
        let (a, b, c) = (term(1), term(2), term(3));

        mgr.assert_equality(a, b);
        mgr.assert_equality(b, c);

        assert!(mgr.are_equal(a, c));
    }

    /// Flushing empties the pending queue.
    #[test]
    fn test_flush_equalities() {
        let mut mgr = SharedTermsManager::default_config();

        mgr.assert_equality(term(1), term(2));
        let before = mgr.get_pending_equalities().len();
        assert_eq!(before, 1);

        mgr.flush_equalities();
        let after = mgr.get_pending_equalities().len();
        assert_eq!(after, 0);
    }
}
289}