oxiz_solver/
shared_terms.rs1#![allow(dead_code)] use oxiz_core::TermId;
19use rustc_hash::{FxHashMap, FxHashSet};
20
21pub type TheoryId = usize;
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
26pub struct Equality {
27 pub lhs: TermId,
29 pub rhs: TermId,
31}
32
33impl Equality {
34 pub fn new(lhs: TermId, rhs: TermId) -> Self {
36 if lhs.raw() <= rhs.raw() {
38 Self { lhs, rhs }
39 } else {
40 Self { lhs: rhs, rhs: lhs }
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
47struct SharedTermInfo {
48 theories: FxHashSet<TheoryId>,
50 representative: TermId,
52}
53
/// Tuning options for [`SharedTermsManager`].
#[derive(Debug, Clone)]
pub struct SharedTermsConfig {
    /// Whether equalities should be batched.
    // NOTE(review): no code in this file reads this flag; the flush threshold
    // below applies regardless of its value. Confirm the intended semantics.
    pub enable_batching: bool,
    /// Number of pending equalities at which the manager flushes its queue.
    pub max_batch_size: usize,
}

impl Default for SharedTermsConfig {
    /// Batching enabled with a threshold of 1000 pending equalities.
    fn default() -> Self {
        let enable_batching = true;
        let max_batch_size = 1000;
        SharedTermsConfig {
            enable_batching,
            max_batch_size,
        }
    }
}
71
/// Running counters maintained by [`SharedTermsManager`].
///
/// All counters start at zero (`Default`).
#[derive(Debug, Clone, Default)]
pub struct SharedTermsStats {
    /// Distinct terms that have been registered.
    pub terms_registered: u64,
    /// Subscription counter maintained by `register_term`.
    pub subscriptions: u64,
    /// Equalities queued by `assert_equality` (redundant ones are not counted).
    pub equalities_propagated: u64,
    /// Flushes that cleared a non-empty pending queue.
    pub batches_sent: u64,
}
84
85#[derive(Debug)]
87pub struct SharedTermsManager {
88 config: SharedTermsConfig,
90 terms: FxHashMap<TermId, SharedTermInfo>,
92 parent: FxHashMap<TermId, TermId>,
94 pending_equalities: Vec<Equality>,
96 subscriptions: FxHashMap<TermId, FxHashSet<TheoryId>>,
98 stats: SharedTermsStats,
100}
101
102impl SharedTermsManager {
103 pub fn new(config: SharedTermsConfig) -> Self {
105 Self {
106 config,
107 terms: FxHashMap::default(),
108 parent: FxHashMap::default(),
109 pending_equalities: Vec::new(),
110 subscriptions: FxHashMap::default(),
111 stats: SharedTermsStats::default(),
112 }
113 }
114
115 pub fn default_config() -> Self {
117 Self::new(SharedTermsConfig::default())
118 }
119
120 pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
122 let entry = self.terms.entry(term).or_insert_with(|| {
123 self.stats.terms_registered += 1;
124 SharedTermInfo {
125 theories: FxHashSet::default(),
126 representative: term,
127 }
128 });
129
130 entry.theories.insert(theory);
131 self.stats.subscriptions += 1;
132
133 self.subscriptions.entry(term).or_default().insert(theory);
135 }
136
137 pub fn is_shared(&self, term: TermId) -> bool {
139 self.terms
140 .get(&term)
141 .map(|info| info.theories.len() > 1)
142 .unwrap_or(false)
143 }
144
145 pub fn get_theories(&self, term: TermId) -> Vec<TheoryId> {
147 self.terms
148 .get(&term)
149 .map(|info| info.theories.iter().copied().collect())
150 .unwrap_or_default()
151 }
152
153 pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) {
157 let lhs_rep = self.find(lhs);
158 let rhs_rep = self.find(rhs);
159
160 if lhs_rep == rhs_rep {
161 return; }
163
164 self.parent.insert(lhs_rep, rhs_rep);
166
167 let equality = Equality::new(lhs, rhs);
169 self.pending_equalities.push(equality);
170 self.stats.equalities_propagated += 1;
171
172 if self.pending_equalities.len() >= self.config.max_batch_size {
174 self.flush_equalities();
175 }
176 }
177
178 fn find(&mut self, term: TermId) -> TermId {
180 if let Some(&parent) = self.parent.get(&term)
181 && parent != term
182 {
183 let root = self.find(parent);
184 self.parent.insert(term, root); return root;
186 }
187
188 term
189 }
190
191 pub fn are_equal(&mut self, lhs: TermId, rhs: TermId) -> bool {
193 self.find(lhs) == self.find(rhs)
194 }
195
196 pub fn get_pending_equalities(&self) -> &[Equality] {
198 &self.pending_equalities
199 }
200
201 pub fn flush_equalities(&mut self) {
203 if !self.pending_equalities.is_empty() {
204 self.stats.batches_sent += 1;
205 self.pending_equalities.clear();
206 }
207 }
208
209 pub fn get_shared_terms(&self) -> Vec<TermId> {
211 self.terms
212 .iter()
213 .filter(|(_, info)| info.theories.len() > 1)
214 .map(|(&term, _)| term)
215 .collect()
216 }
217
218 pub fn stats(&self) -> &SharedTermsStats {
220 &self.stats
221 }
222
223 pub fn reset(&mut self) {
225 self.terms.clear();
226 self.parent.clear();
227 self.pending_equalities.clear();
228 self.subscriptions.clear();
229 self.stats = SharedTermsStats::default();
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a term id in tests.
    fn term(id: u32) -> TermId {
        TermId::new(id)
    }

    #[test]
    fn test_manager_creation() {
        let manager = SharedTermsManager::default_config();
        assert_eq!(manager.stats().terms_registered, 0);
    }

    #[test]
    fn test_register_term() {
        let mut manager = SharedTermsManager::default_config();

        manager.register_term(term(1), 0);
        manager.register_term(term(1), 1);
        assert!(manager.is_shared(term(1)));
        assert_eq!(manager.get_theories(term(1)).len(), 2);
    }

    #[test]
    fn test_single_theory_term_is_not_shared() {
        let mut manager = SharedTermsManager::default_config();

        manager.register_term(term(1), 0);
        manager.register_term(term(1), 0);

        assert!(!manager.is_shared(term(1)));
        assert!(!manager.is_shared(term(2)));
        assert_eq!(manager.get_theories(term(1)), vec![0]);
        assert!(manager.get_theories(term(2)).is_empty());
        assert_eq!(manager.stats().terms_registered, 1);
    }

    #[test]
    fn test_get_shared_terms() {
        let mut manager = SharedTermsManager::default_config();

        manager.register_term(term(1), 0);
        manager.register_term(term(1), 1);
        manager.register_term(term(2), 0);

        assert_eq!(manager.get_shared_terms(), vec![term(1)]);
    }

    #[test]
    fn test_equality() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));

        assert!(manager.are_equal(term(1), term(2)));
        assert_eq!(manager.get_pending_equalities().len(), 1);
    }

    #[test]
    fn test_pending_equality_is_normalized() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(5), term(3));

        assert_eq!(
            manager.get_pending_equalities().to_vec(),
            vec![Equality::new(term(3), term(5))]
        );
    }

    #[test]
    fn test_redundant_equality_is_not_queued() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        manager.assert_equality(term(2), term(1));
        manager.assert_equality(term(1), term(1));

        assert_eq!(manager.get_pending_equalities().len(), 1);
        assert_eq!(manager.stats().equalities_propagated, 1);
    }

    #[test]
    fn test_equality_transitivity() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        manager.assert_equality(term(2), term(3));

        assert!(manager.are_equal(term(1), term(3)));
    }

    #[test]
    fn test_equality_chain() {
        let mut manager = SharedTermsManager::default_config();

        for id in 1..100 {
            manager.assert_equality(term(id), term(id + 1));
        }

        assert!(manager.are_equal(term(1), term(100)));
        assert!(!manager.are_equal(term(1), term(101)));
    }

    #[test]
    fn test_flush_equalities() {
        let mut manager = SharedTermsManager::default_config();

        manager.assert_equality(term(1), term(2));
        assert_eq!(manager.get_pending_equalities().len(), 1);

        manager.flush_equalities();
        assert_eq!(manager.get_pending_equalities().len(), 0);
    }

    #[test]
    fn test_flush_empty_queue_does_not_count_batch() {
        let mut manager = SharedTermsManager::default_config();

        manager.flush_equalities();

        assert_eq!(manager.stats().batches_sent, 0);
    }

    #[test]
    fn test_auto_flush_at_max_batch_size() {
        let mut manager = SharedTermsManager::new(SharedTermsConfig {
            enable_batching: true,
            max_batch_size: 2,
        });

        manager.assert_equality(term(1), term(2));
        assert_eq!(manager.get_pending_equalities().len(), 1);

        manager.assert_equality(term(3), term(4));
        assert_eq!(manager.get_pending_equalities().len(), 0);
        assert_eq!(manager.stats().batches_sent, 1);
    }

    #[test]
    fn test_reset() {
        let mut manager = SharedTermsManager::default_config();

        manager.register_term(term(1), 0);
        manager.register_term(term(1), 1);
        manager.assert_equality(term(1), term(2));
        manager.reset();

        assert!(!manager.is_shared(term(1)));
        assert!(!manager.are_equal(term(1), term(2)));
        assert!(manager.get_pending_equalities().is_empty());
        assert!(manager.get_shared_terms().is_empty());
        assert_eq!(manager.stats().terms_registered, 0);
        assert_eq!(manager.stats().equalities_propagated, 0);
    }
}