Skip to main content

oxiz_solver/
shared_terms.rs

1//! Shared Terms Management for Theory Combination.
2#![allow(dead_code)] // Under development
3//!
4//! Manages terms that appear in multiple theories, enabling efficient
5//! equality sharing in Nelson-Oppen combination.
6//!
7//! ## Architecture
8//!
9//! - **Term Index**: Fast lookup of shared terms
10//! - **Theory Subscriptions**: Theories register interest in terms
11//! - **Notification System**: Propagate equality information between theories
12//!
13//! ## References
14//!
15//! - Nelson & Oppen: "Simplification by Cooperating Decision Procedures" (1979)
16//! - Z3's `smt/theory_combine.cpp`
17
18use oxiz_core::TermId;
19use rustc_hash::{FxHashMap, FxHashSet};
20
/// Theory identifier.
///
/// A plain index. The manager stores these in per-term sets to record which
/// theories use each term.
pub type TheoryId = usize;
23
24/// Equality between two terms.
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
26pub struct Equality {
27    /// Left-hand side term.
28    pub lhs: TermId,
29    /// Right-hand side term.
30    pub rhs: TermId,
31}
32
33impl Equality {
34    /// Create a new equality.
35    pub fn new(lhs: TermId, rhs: TermId) -> Self {
36        // Normalize: smaller TermId first
37        if lhs.raw() <= rhs.raw() {
38            Self { lhs, rhs }
39        } else {
40            Self { lhs: rhs, rhs: lhs }
41        }
42    }
43}
44
/// Per-term bookkeeping kept by the shared terms manager.
#[derive(Debug, Clone)]
struct SharedTermInfo {
    /// Theories that have registered this term. The term counts as shared once
    /// this set holds more than one theory.
    theories: FxHashSet<TheoryId>,
    /// Representative term in the equivalence class.
    ///
    /// NOTE(review): in this file the field is only written when the entry is
    /// created (set to the term itself). It is never read, and it is not
    /// updated when classes are merged; the union-find `parent` map is what
    /// actually tracks representatives.
    representative: TermId,
}
53
/// Tuning knobs for the shared terms manager.
#[derive(Clone, Debug)]
pub struct SharedTermsConfig {
    /// Whether equality notifications should be batched.
    ///
    /// NOTE(review): the manager code in this file does not read this flag;
    /// confirm before relying on it.
    pub enable_batching: bool,
    /// Number of pending equalities at which the manager forces a flush.
    pub max_batch_size: usize,
}

impl Default for SharedTermsConfig {
    /// Batching enabled, with a forced flush once 1000 equalities are pending.
    fn default() -> Self {
        let enable_batching = true;
        let max_batch_size = 1000;
        Self {
            enable_batching,
            max_batch_size,
        }
    }
}
71
/// Running counters collected by the shared terms manager.
///
/// All counters start at zero (`Default`).
#[derive(Default, Clone, Debug)]
pub struct SharedTermsStats {
    /// Distinct terms that have been registered.
    pub terms_registered: u64,
    /// Theory subscriptions recorded through term registration.
    pub subscriptions: u64,
    /// Equalities queued for propagation, i.e. merges of previously distinct classes.
    pub equalities_propagated: u64,
    /// Non-empty batches of pending equalities that were flushed.
    pub batches_sent: u64,
}
84
/// Shared terms manager for theory combination.
///
/// Records which theories use which terms. Maintains equivalence classes of
/// terms as a union-find over `TermId`s, and queues the equalities produced by
/// class merges so they can be flushed in batches.
#[derive(Debug)]
pub struct SharedTermsManager {
    /// Configuration (batching options and batch size limit).
    config: SharedTermsConfig,
    /// Per-term information, keyed by the registered term.
    terms: FxHashMap<TermId, SharedTermInfo>,
    /// Union-find parent links. A term with no entry is the root of its own class.
    parent: FxHashMap<TermId, TermId>,
    /// Equalities queued by class merges and not yet flushed.
    pending_equalities: Vec<Equality>,
    /// Theories subscribed to each term.
    ///
    /// NOTE(review): this mirrors `terms[t].theories` (registration updates
    /// both) and is not read anywhere in this file. Consider dropping one copy.
    subscriptions: FxHashMap<TermId, FxHashSet<TheoryId>>,
    /// Running statistics.
    stats: SharedTermsStats,
}
101
102impl SharedTermsManager {
103    /// Create a new shared terms manager.
104    pub fn new(config: SharedTermsConfig) -> Self {
105        Self {
106            config,
107            terms: FxHashMap::default(),
108            parent: FxHashMap::default(),
109            pending_equalities: Vec::new(),
110            subscriptions: FxHashMap::default(),
111            stats: SharedTermsStats::default(),
112        }
113    }
114
115    /// Create with default configuration.
116    pub fn default_config() -> Self {
117        Self::new(SharedTermsConfig::default())
118    }
119
120    /// Register a shared term.
121    pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
122        let entry = self.terms.entry(term).or_insert_with(|| {
123            self.stats.terms_registered += 1;
124            SharedTermInfo {
125                theories: FxHashSet::default(),
126                representative: term,
127            }
128        });
129
130        entry.theories.insert(theory);
131        self.stats.subscriptions += 1;
132
133        // Also track subscriptions separately for fast lookup
134        self.subscriptions.entry(term).or_default().insert(theory);
135    }
136
137    /// Check if a term is shared between multiple theories.
138    pub fn is_shared(&self, term: TermId) -> bool {
139        self.terms
140            .get(&term)
141            .map(|info| info.theories.len() > 1)
142            .unwrap_or(false)
143    }
144
145    /// Get theories that use a term.
146    pub fn get_theories(&self, term: TermId) -> Vec<TheoryId> {
147        self.terms
148            .get(&term)
149            .map(|info| info.theories.iter().copied().collect())
150            .unwrap_or_default()
151    }
152
153    /// Assert equality between two terms.
154    ///
155    /// This merges their equivalence classes and queues notifications.
156    pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) {
157        let lhs_rep = self.find(lhs);
158        let rhs_rep = self.find(rhs);
159
160        if lhs_rep == rhs_rep {
161            return; // Already equal
162        }
163
164        // Union: make lhs_rep point to rhs_rep
165        self.parent.insert(lhs_rep, rhs_rep);
166
167        // Queue equality for propagation
168        let equality = Equality::new(lhs, rhs);
169        self.pending_equalities.push(equality);
170        self.stats.equalities_propagated += 1;
171
172        // Check if should flush batch
173        if self.pending_equalities.len() >= self.config.max_batch_size {
174            self.flush_equalities();
175        }
176    }
177
178    /// Find representative of equivalence class (with path compression).
179    fn find(&mut self, term: TermId) -> TermId {
180        if let Some(&parent) = self.parent.get(&term)
181            && parent != term
182        {
183            let root = self.find(parent);
184            self.parent.insert(term, root); // Path compression
185            return root;
186        }
187
188        term
189    }
190
191    /// Check if two terms are in the same equivalence class.
192    pub fn are_equal(&mut self, lhs: TermId, rhs: TermId) -> bool {
193        self.find(lhs) == self.find(rhs)
194    }
195
196    /// Get pending equalities to propagate.
197    pub fn get_pending_equalities(&self) -> &[Equality] {
198        &self.pending_equalities
199    }
200
201    /// Flush pending equalities (send to theories).
202    pub fn flush_equalities(&mut self) {
203        if !self.pending_equalities.is_empty() {
204            self.stats.batches_sent += 1;
205            self.pending_equalities.clear();
206        }
207    }
208
209    /// Get all shared terms.
210    pub fn get_shared_terms(&self) -> Vec<TermId> {
211        self.terms
212            .iter()
213            .filter(|(_, info)| info.theories.len() > 1)
214            .map(|(&term, _)| term)
215            .collect()
216    }
217
218    /// Get statistics.
219    pub fn stats(&self) -> &SharedTermsStats {
220        &self.stats
221    }
222
223    /// Reset manager state.
224    pub fn reset(&mut self) {
225        self.terms.clear();
226        self.parent.clear();
227        self.pending_equalities.clear();
228        self.subscriptions.clear();
229        self.stats = SharedTermsStats::default();
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `TermId` from a raw index.
    fn term(id: u32) -> TermId {
        TermId::new(id)
    }

    #[test]
    fn test_manager_creation() {
        let manager = SharedTermsManager::default_config();
        assert_eq!(manager.stats().terms_registered, 0);
        assert!(manager.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_register_term() {
        let mut manager = SharedTermsManager::default_config();

        manager.register_term(term(1), 0); // Theory 0
        manager.register_term(term(1), 1); // Theory 1

        assert!(manager.is_shared(term(1)));
        assert_eq!(manager.get_theories(term(1)).len(), 2);
        assert_eq!(manager.stats().terms_registered, 1);
    }

    #[test]
    fn test_single_theory_term_is_not_shared() {
        let mut manager = SharedTermsManager::default_config();

        // Registering the same (term, theory) pair twice keeps one theory.
        manager.register_term(term(1), 0);
        manager.register_term(term(1), 0);

        assert!(!manager.is_shared(term(1)));
        assert_eq!(manager.get_theories(term(1)), vec![0]);
        assert_eq!(manager.stats().terms_registered, 1);
    }

    #[test]
    fn test_unknown_term() {
        let manager = SharedTermsManager::default_config();

        assert!(!manager.is_shared(term(42)));
        assert!(manager.get_theories(term(42)).is_empty());
        assert!(manager.get_shared_terms().is_empty());
    }

    #[test]
    fn test_get_shared_terms_filters_single_theory_terms() {
        let mut manager = SharedTermsManager::default_config();

        manager.register_term(term(1), 0);
        manager.register_term(term(1), 1);
        manager.register_term(term(2), 0);

        assert_eq!(manager.get_shared_terms(), vec![term(1)]);
    }

    #[test]
    fn test_equality() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));

        assert!(manager.are_equal(term(1), term(2)));
        assert_eq!(manager.get_pending_equalities().len(), 1);
    }

    #[test]
    fn test_pending_equality_is_normalized() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(5), term(2));

        assert_eq!(
            manager.get_pending_equalities(),
            &[Equality::new(term(2), term(5))]
        );
    }

    #[test]
    fn test_equality_transitivity() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        manager.assert_equality(term(2), term(3));

        assert!(manager.are_equal(term(1), term(3)));
        assert!(!manager.are_equal(term(1), term(4)));
    }

    #[test]
    fn test_redundant_equalities_are_not_queued() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        manager.assert_equality(term(2), term(1)); // same class already
        manager.assert_equality(term(2), term(3));
        manager.assert_equality(term(1), term(3)); // implied by transitivity

        assert_eq!(manager.get_pending_equalities().len(), 2);
        assert_eq!(manager.stats().equalities_propagated, 2);
    }

    #[test]
    fn test_long_chain_transitivity() {
        let mut manager = SharedTermsManager::default_config();

        for i in 0..200 {
            manager.assert_equality(term(i), term(i + 1));
        }

        assert!(manager.are_equal(term(0), term(200)));
        // A second query after path compression must give the same answer.
        assert!(manager.are_equal(term(0), term(200)));
        assert!(!manager.are_equal(term(0), term(201)));
    }

    #[test]
    fn test_flush_equalities() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        assert_eq!(manager.get_pending_equalities().len(), 1);

        manager.flush_equalities();
        assert_eq!(manager.get_pending_equalities().len(), 0);
        assert_eq!(manager.stats().batches_sent, 1);

        // Flushing an empty queue does not count as a batch.
        manager.flush_equalities();
        assert_eq!(manager.stats().batches_sent, 1);
    }

    #[test]
    fn test_batch_auto_flush_at_max_size() {
        let mut manager = SharedTermsManager::new(SharedTermsConfig {
            enable_batching: true,
            max_batch_size: 2,
        });

        manager.assert_equality(term(1), term(2));
        assert_eq!(manager.get_pending_equalities().len(), 1);

        manager.assert_equality(term(3), term(4));
        assert!(manager.get_pending_equalities().is_empty());
        assert_eq!(manager.stats().batches_sent, 1);
    }

    #[test]
    fn test_reset_clears_everything() {
        let mut manager = SharedTermsManager::default_config();

        manager.register_term(term(1), 0);
        manager.register_term(term(1), 1);
        manager.assert_equality(term(1), term(2));

        manager.reset();

        assert!(!manager.is_shared(term(1)));
        assert!(manager.get_pending_equalities().is_empty());
        assert!(!manager.are_equal(term(1), term(2)));
        assert_eq!(manager.stats().terms_registered, 0);
        assert_eq!(manager.stats().equalities_propagated, 0);
    }
}