oxiz_solver/
shared_terms.rs1#![allow(dead_code)] #[allow(unused_imports)]
19use crate::prelude::*;
20use oxiz_core::TermId;
21
22pub type TheoryId = usize;
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
27pub struct Equality {
28 pub lhs: TermId,
30 pub rhs: TermId,
32}
33
34impl Equality {
35 pub fn new(lhs: TermId, rhs: TermId) -> Self {
37 if lhs.raw() <= rhs.raw() {
39 Self { lhs, rhs }
40 } else {
41 Self { lhs: rhs, rhs: lhs }
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
48struct SharedTermInfo {
49 theories: FxHashSet<TheoryId>,
51 representative: TermId,
53}
54
/// Tuning knobs for [`SharedTermsManager`].
#[derive(Clone, Debug)]
pub struct SharedTermsConfig {
    /// Batching switch; defaults to `true`.
    ///
    /// NOTE(review): the manager code in this file does not appear to read
    /// this flag — confirm the intended semantics.
    pub enable_batching: bool,
    /// Queue length at which pending equalities are flushed; defaults to `1000`.
    pub max_batch_size: usize,
}

impl Default for SharedTermsConfig {
    /// Batching enabled, flushing every 1000 equalities.
    fn default() -> Self {
        const DEFAULT_MAX_BATCH_SIZE: usize = 1000;
        Self {
            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
            enable_batching: true,
        }
    }
}
72
/// Counters maintained by [`SharedTermsManager`]; every counter starts at zero.
#[derive(Default, Debug, Clone)]
pub struct SharedTermsStats {
    /// Number of distinct terms that have been registered.
    pub terms_registered: u64,
    /// Count of term/theory registrations made through `register_term`.
    pub subscriptions: u64,
    /// Number of asserted equalities that merged two previously distinct classes.
    pub equalities_propagated: u64,
    /// Number of non-empty batches flushed.
    pub batches_sent: u64,
}
85
86#[derive(Debug)]
88pub struct SharedTermsManager {
89 config: SharedTermsConfig,
91 terms: FxHashMap<TermId, SharedTermInfo>,
93 parent: FxHashMap<TermId, TermId>,
95 pending_equalities: Vec<Equality>,
97 subscriptions: FxHashMap<TermId, FxHashSet<TheoryId>>,
99 stats: SharedTermsStats,
101}
102
103impl SharedTermsManager {
104 pub fn new(config: SharedTermsConfig) -> Self {
106 Self {
107 config,
108 terms: FxHashMap::default(),
109 parent: FxHashMap::default(),
110 pending_equalities: Vec::new(),
111 subscriptions: FxHashMap::default(),
112 stats: SharedTermsStats::default(),
113 }
114 }
115
116 pub fn default_config() -> Self {
118 Self::new(SharedTermsConfig::default())
119 }
120
121 pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
123 let entry = self.terms.entry(term).or_insert_with(|| {
124 self.stats.terms_registered += 1;
125 SharedTermInfo {
126 theories: FxHashSet::default(),
127 representative: term,
128 }
129 });
130
131 entry.theories.insert(theory);
132 self.stats.subscriptions += 1;
133
134 self.subscriptions.entry(term).or_default().insert(theory);
136 }
137
138 pub fn is_shared(&self, term: TermId) -> bool {
140 self.terms
141 .get(&term)
142 .map(|info| info.theories.len() > 1)
143 .unwrap_or(false)
144 }
145
146 pub fn get_theories(&self, term: TermId) -> Vec<TheoryId> {
148 self.terms
149 .get(&term)
150 .map(|info| info.theories.iter().copied().collect())
151 .unwrap_or_default()
152 }
153
154 pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) {
158 let lhs_rep = self.find(lhs);
159 let rhs_rep = self.find(rhs);
160
161 if lhs_rep == rhs_rep {
162 return; }
164
165 self.parent.insert(lhs_rep, rhs_rep);
167
168 let equality = Equality::new(lhs, rhs);
170 self.pending_equalities.push(equality);
171 self.stats.equalities_propagated += 1;
172
173 if self.pending_equalities.len() >= self.config.max_batch_size {
175 self.flush_equalities();
176 }
177 }
178
179 fn find(&mut self, term: TermId) -> TermId {
181 if let Some(&parent) = self.parent.get(&term)
182 && parent != term
183 {
184 let root = self.find(parent);
185 self.parent.insert(term, root); return root;
187 }
188
189 term
190 }
191
192 pub fn are_equal(&mut self, lhs: TermId, rhs: TermId) -> bool {
194 self.find(lhs) == self.find(rhs)
195 }
196
197 pub fn get_pending_equalities(&self) -> &[Equality] {
199 &self.pending_equalities
200 }
201
202 pub fn flush_equalities(&mut self) {
204 if !self.pending_equalities.is_empty() {
205 self.stats.batches_sent += 1;
206 self.pending_equalities.clear();
207 }
208 }
209
210 pub fn get_shared_terms(&self) -> Vec<TermId> {
212 self.terms
213 .iter()
214 .filter(|(_, info)| info.theories.len() > 1)
215 .map(|(&term, _)| term)
216 .collect()
217 }
218
219 pub fn stats(&self) -> &SharedTermsStats {
221 &self.stats
222 }
223
224 pub fn reset(&mut self) {
226 self.terms.clear();
227 self.parent.clear();
228 self.pending_equalities.clear();
229 self.subscriptions.clear();
230 self.stats = SharedTermsStats::default();
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a term id in tests.
    fn term(id: u32) -> TermId {
        TermId::new(id)
    }

    #[test]
    fn test_manager_creation() {
        let manager = SharedTermsManager::default_config();
        assert_eq!(manager.stats().terms_registered, 0);
        assert!(manager.get_pending_equalities().is_empty());
        assert!(manager.get_shared_terms().is_empty());
    }

    #[test]
    fn test_register_term() {
        let mut manager = SharedTermsManager::default_config();

        manager.register_term(term(1), 0);
        manager.register_term(term(1), 1);

        assert!(manager.is_shared(term(1)));
        assert_eq!(manager.get_theories(term(1)).len(), 2);
    }

    #[test]
    fn test_single_theory_is_not_shared() {
        let mut manager = SharedTermsManager::default_config();
        manager.register_term(term(1), 0);

        assert!(!manager.is_shared(term(1)));
        assert_eq!(manager.get_theories(term(1)), vec![0]);

        // Unknown terms are neither shared nor attached to any theory.
        assert!(!manager.is_shared(term(2)));
        assert!(manager.get_theories(term(2)).is_empty());
    }

    #[test]
    fn test_get_shared_terms() {
        let mut manager = SharedTermsManager::default_config();
        manager.register_term(term(1), 0);
        manager.register_term(term(1), 1);
        manager.register_term(term(2), 0);

        assert_eq!(manager.get_shared_terms(), vec![term(1)]);
    }

    #[test]
    fn test_equality_is_order_independent() {
        let forward = Equality::new(term(5), term(2));
        let backward = Equality::new(term(2), term(5));

        assert_eq!(forward, backward);
        assert!(forward.lhs.raw() <= forward.rhs.raw());
    }

    #[test]
    fn test_equality() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));

        assert!(manager.are_equal(term(1), term(2)));
        assert_eq!(manager.get_pending_equalities().len(), 1);
    }

    #[test]
    fn test_equality_transitivity() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        manager.assert_equality(term(2), term(3));

        assert!(manager.are_equal(term(1), term(3)));
    }

    #[test]
    fn test_redundant_equality_is_not_queued() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        manager.assert_equality(term(2), term(3));
        // Already implied by the two equalities above.
        manager.assert_equality(term(3), term(1));

        assert_eq!(manager.get_pending_equalities().len(), 2);
        assert_eq!(manager.stats().equalities_propagated, 2);
    }

    #[test]
    fn test_flush_equalities() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        assert_eq!(manager.get_pending_equalities().len(), 1);

        manager.flush_equalities();
        assert_eq!(manager.get_pending_equalities().len(), 0);
        assert_eq!(manager.stats().batches_sent, 1);
    }

    #[test]
    fn test_flush_empty_queue_sends_no_batch() {
        let mut manager = SharedTermsManager::default_config();

        manager.flush_equalities();

        assert_eq!(manager.stats().batches_sent, 0);
    }

    #[test]
    fn test_auto_flush_at_max_batch_size() {
        let config = SharedTermsConfig {
            enable_batching: true,
            max_batch_size: 2,
        };
        let mut manager = SharedTermsManager::new(config);

        manager.assert_equality(term(1), term(2));
        assert_eq!(manager.get_pending_equalities().len(), 1);

        // Reaching `max_batch_size` flushes the queue automatically.
        manager.assert_equality(term(3), term(4));
        assert!(manager.get_pending_equalities().is_empty());
        assert_eq!(manager.stats().batches_sent, 1);
    }

    #[test]
    fn test_reset() {
        let mut manager = SharedTermsManager::default_config();
        manager.register_term(term(1), 0);
        manager.register_term(term(1), 1);
        manager.assert_equality(term(1), term(2));

        manager.reset();

        assert!(!manager.is_shared(term(1)));
        assert!(!manager.are_equal(term(1), term(2)));
        assert!(manager.get_pending_equalities().is_empty());
        assert_eq!(manager.stats().terms_registered, 0);
        assert_eq!(manager.stats().equalities_propagated, 0);
    }
}