oxiz_theories/
array_eager_expand.rs1use oxiz_core::ast::{TermId, TermManager};
42use oxiz_core::sort::SortId;
43use rustc_hash::FxHashMap;
44
/// Tuning knobs that decide when [`EagerArrayExpander::should_expand`] approves
/// expanding an array.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EagerExpandConfig {
    /// Master switch: when `false`, `should_expand` always returns `false`.
    pub enabled: bool,
    /// Largest index-domain size (inclusive) that is eligible for expansion.
    pub max_domain_size: usize,
    /// Upper bound on the number of arrays expanded; `should_expand` refuses
    /// once this many arrays have already been expanded.
    pub max_arrays: usize,
    /// Minimum number of recorded accesses (inclusive) before an array becomes
    /// eligible for expansion.
    pub access_threshold: usize,
}

impl Default for EagerExpandConfig {
    /// Enabled, domains of at most 10 indices, at most 100 arrays, and at
    /// least 3 recorded accesses.
    fn default() -> Self {
        Self {
            enabled: true,
            max_domain_size: 10,
            max_arrays: 100,
            access_threshold: 3,
        }
    }
}
68
/// Counters describing the work done by an [`EagerArrayExpander`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EagerExpandStats {
    /// Number of arrays that have been expanded (each array counts once).
    pub arrays_expanded: u64,
    /// Number of element variables created across all expansions.
    pub elements_created: u64,
    /// Number of `select` terms simplified. Not updated by any code in this
    /// module; presumably maintained by callers — confirm.
    pub selects_simplified: u64,
    /// Number of `store` terms simplified. Not updated by any code in this
    /// module; presumably maintained by callers — confirm.
    pub stores_simplified: u64,
    /// Cumulative wall-clock time, in microseconds, spent expanding arrays.
    pub expansion_time_us: u64,
}
83
84#[derive(Debug, Clone)]
86pub struct ExpandedArray {
87 pub array_term: TermId,
89 pub domain: Vec<i64>,
91 pub elements: FxHashMap<i64, TermId>,
93 pub element_sort: SortId,
95}
96
97pub struct EagerArrayExpander {
99 config: EagerExpandConfig,
101 stats: EagerExpandStats,
103 expanded: FxHashMap<TermId, ExpandedArray>,
105 access_counts: FxHashMap<TermId, usize>,
107}
108
109impl EagerArrayExpander {
110 pub fn new() -> Self {
112 Self::with_config(EagerExpandConfig::default())
113 }
114
115 pub fn with_config(config: EagerExpandConfig) -> Self {
117 Self {
118 config,
119 stats: EagerExpandStats::default(),
120 expanded: FxHashMap::default(),
121 access_counts: FxHashMap::default(),
122 }
123 }
124
125 pub fn stats(&self) -> &EagerExpandStats {
127 &self.stats
128 }
129
130 pub fn reset_stats(&mut self) {
132 self.stats = EagerExpandStats::default();
133 }
134
135 pub fn record_access(&mut self, array: TermId) {
137 *self.access_counts.entry(array).or_insert(0) += 1;
138 }
139
140 pub fn should_expand(&self, array: TermId, domain_size: usize) -> bool {
142 if !self.config.enabled {
143 return false;
144 }
145
146 if domain_size > self.config.max_domain_size {
147 return false;
148 }
149
150 if self.expanded.len() >= self.config.max_arrays {
151 return false;
152 }
153
154 if self.expanded.contains_key(&array) {
155 return false; }
157
158 let access_count = self.access_counts.get(&array).copied().unwrap_or(0);
160 access_count >= self.config.access_threshold
161 }
162
163 pub fn expand_array(
167 &mut self,
168 array: TermId,
169 domain: Vec<i64>,
170 element_sort: SortId,
171 tm: &mut TermManager,
172 ) -> Result<(), String> {
173 if self.expanded.contains_key(&array) {
174 return Ok(()); }
176
177 let start = std::time::Instant::now();
178
179 let mut elements = FxHashMap::default();
181
182 for &index_val in &domain {
183 let element_name = format!("array_{}_{}", array.raw(), index_val);
185 let element_var = tm.mk_var(&element_name, element_sort);
186 elements.insert(index_val, element_var);
187 }
188
189 self.stats.elements_created += elements.len() as u64;
190
191 let expanded_array = ExpandedArray {
192 array_term: array,
193 domain: domain.clone(),
194 elements,
195 element_sort,
196 };
197
198 self.expanded.insert(array, expanded_array);
199 self.stats.arrays_expanded += 1;
200 self.stats.expansion_time_us += start.elapsed().as_micros() as u64;
201
202 Ok(())
203 }
204
205 pub fn get_element(&self, array: TermId, index: i64) -> Option<TermId> {
207 self.expanded
208 .get(&array)
209 .and_then(|exp| exp.elements.get(&index).copied())
210 }
211
212 pub fn is_expanded(&self, array: TermId) -> bool {
214 self.expanded.contains_key(&array)
215 }
216
217 pub fn get_expanded(&self, array: TermId) -> Option<&ExpandedArray> {
219 self.expanded.get(&array)
220 }
221
222 pub fn clear(&mut self) {
224 self.expanded.clear();
225 self.access_counts.clear();
226 }
227}
228
229impl Default for EagerArrayExpander {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `n` accesses to `array` on `expander`.
    fn record_n(expander: &mut EagerArrayExpander, array: TermId, n: usize) {
        for _ in 0..n {
            expander.record_access(array);
        }
    }

    #[test]
    fn test_eager_expand_config_default() {
        let config = EagerExpandConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_domain_size, 10);
        assert_eq!(config.max_arrays, 100);
        assert_eq!(config.access_threshold, 3);
    }

    #[test]
    fn test_eager_expander_creation() {
        let expander = EagerArrayExpander::new();
        assert_eq!(expander.stats().arrays_expanded, 0);
    }

    #[test]
    fn test_record_access() {
        let mut expander = EagerArrayExpander::new();
        let array = TermId::new(1);

        record_n(&mut expander, array, 3);

        assert_eq!(expander.access_counts.get(&array), Some(&3));
    }

    #[test]
    fn test_should_expand_small_domain() {
        let mut expander = EagerArrayExpander::new();
        let array = TermId::new(1);

        // Reach the default access threshold (3).
        record_n(&mut expander, array, 3);

        assert!(expander.should_expand(array, 5));
        // Larger than the default max_domain_size (10).
        assert!(!expander.should_expand(array, 100));
    }

    #[test]
    fn test_should_expand_below_threshold() {
        let mut expander = EagerArrayExpander::new();
        let array = TermId::new(1);

        record_n(&mut expander, array, 2);

        assert!(!expander.should_expand(array, 5));
    }

    #[test]
    fn test_should_expand_domain_size_boundary() {
        let mut expander = EagerArrayExpander::new();
        let array = TermId::new(1);
        record_n(&mut expander, array, 3);

        // max_domain_size is inclusive.
        assert!(expander.should_expand(array, 10));
        assert!(!expander.should_expand(array, 11));
    }

    #[test]
    fn test_should_expand_disabled() {
        let config = EagerExpandConfig {
            enabled: false,
            ..Default::default()
        };
        let expander = EagerArrayExpander::with_config(config);

        assert!(!expander.should_expand(TermId::new(1), 5));
    }

    #[test]
    fn test_should_expand_already_expanded() {
        let mut expander = EagerArrayExpander::new();
        let mut tm = TermManager::new();
        let array = TermId::new(1);
        let sort = tm.sorts.int_sort;

        record_n(&mut expander, array, 3);
        expander
            .expand_array(array, vec![0, 1], sort, &mut tm)
            .unwrap();

        assert!(!expander.should_expand(array, 2));
    }

    #[test]
    fn test_should_expand_respects_max_arrays() {
        let config = EagerExpandConfig {
            max_arrays: 1,
            ..Default::default()
        };
        let mut expander = EagerArrayExpander::with_config(config);
        let mut tm = TermManager::new();
        let sort = tm.sorts.int_sort;
        let first = TermId::new(1);
        let second = TermId::new(2);

        expander
            .expand_array(first, vec![0], sort, &mut tm)
            .unwrap();
        record_n(&mut expander, second, 3);

        assert!(!expander.should_expand(second, 1));
    }

    #[test]
    fn test_is_expanded() {
        let mut expander = EagerArrayExpander::new();
        let mut tm = TermManager::new();
        let array = TermId::new(1);
        let sort = tm.sorts.int_sort;

        assert!(!expander.is_expanded(array));

        expander
            .expand_array(array, vec![0, 1, 2], sort, &mut tm)
            .unwrap();

        assert!(expander.is_expanded(array));
        assert_eq!(expander.stats().arrays_expanded, 1);
        assert_eq!(expander.stats().elements_created, 3);
    }

    #[test]
    fn test_expand_array_is_idempotent() {
        let mut expander = EagerArrayExpander::new();
        let mut tm = TermManager::new();
        let array = TermId::new(1);
        let sort = tm.sorts.int_sort;

        expander
            .expand_array(array, vec![0, 1, 2], sort, &mut tm)
            .unwrap();
        let first_element = expander.get_element(array, 0);

        // A second call must not replace or extend the existing expansion.
        expander
            .expand_array(array, vec![5, 6], sort, &mut tm)
            .unwrap();

        assert_eq!(expander.get_element(array, 0), first_element);
        assert!(expander.get_element(array, 5).is_none());
        assert_eq!(expander.stats().arrays_expanded, 1);
        assert_eq!(expander.stats().elements_created, 3);
    }

    #[test]
    fn test_get_element() {
        let mut expander = EagerArrayExpander::new();
        let mut tm = TermManager::new();
        let array = TermId::new(1);
        let sort = tm.sorts.int_sort;

        expander
            .expand_array(array, vec![0, 1, 2], sort, &mut tm)
            .unwrap();

        assert!(expander.get_element(array, 0).is_some());
        assert!(expander.get_element(array, 1).is_some());
        assert!(expander.get_element(array, 2).is_some());
        // Index outside the domain.
        assert!(expander.get_element(array, 3).is_none());
        // Array that was never expanded.
        assert!(expander.get_element(TermId::new(2), 0).is_none());
    }

    #[test]
    fn test_stats_reset() {
        let mut expander = EagerArrayExpander::new();

        expander.stats.arrays_expanded = 10;
        expander.stats.elements_created = 50;

        expander.reset_stats();

        assert_eq!(expander.stats().arrays_expanded, 0);
        assert_eq!(expander.stats().elements_created, 0);
    }

    #[test]
    fn test_clear() {
        let mut expander = EagerArrayExpander::new();
        let mut tm = TermManager::new();
        let array = TermId::new(1);
        let sort = tm.sorts.int_sort;

        expander
            .expand_array(array, vec![0, 1, 2], sort, &mut tm)
            .unwrap();
        expander.record_access(array);

        assert!(!expander.expanded.is_empty());
        assert!(!expander.access_counts.is_empty());

        expander.clear();

        assert!(expander.expanded.is_empty());
        assert!(expander.access_counts.is_empty());
        // `clear` leaves statistics alone.
        assert_eq!(expander.stats().arrays_expanded, 1);
    }
}