1use oxiz_core::ast::TermId;
10use rustc_hash::FxHashMap;
11
12pub struct BacktrackManager {
14 decision_stack: Vec<DecisionLevel>,
16 current_level: usize,
18 trail: Vec<Assignment>,
20 var_to_trail: FxHashMap<TermId, usize>,
22 undo_stack: Vec<UndoAction>,
24 checkpoints: Vec<Checkpoint>,
26 stats: BacktrackStats,
28}
29
30#[derive(Debug, Clone)]
32pub struct DecisionLevel {
33 pub level: usize,
35 pub decision: Option<Assignment>,
37 pub trail_start: usize,
39}
40
41#[derive(Debug, Clone)]
43pub struct Assignment {
44 pub var: TermId,
46 pub value: bool,
48 pub level: usize,
50 pub reason: Option<ReasonClause>,
52}
53
54#[derive(Debug, Clone)]
56pub enum ReasonClause {
57 BooleanClause(Vec<TermId>),
59 TheoryPropagation(TheoryReason),
61}
62
63#[derive(Debug, Clone)]
65pub struct TheoryReason {
66 pub theory_id: usize,
68 pub explanation: Vec<TermId>,
70}
71
72#[derive(Debug, Clone)]
74pub enum UndoAction {
75 UndoEquality(TermId, TermId),
77 UndoBound(TermId, BoundUpdate),
79 UndoArrayStore(TermId),
81 TheoryUndo(usize, Vec<TermId>),
83}
84
/// Previous bounds of a term, stored inside `UndoAction::UndoBound` so a bound
/// change can be reverted.
///
/// Derives `PartialEq`/`Eq` so recorded updates can be compared by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BoundUpdate {
    /// Lower bound before the update.
    // NOTE(review): `None` presumably means "no lower bound" - confirm with
    // the theory that records these updates.
    pub old_lower: Option<i64>,
    /// Upper bound before the update (`None` presumably meaning "no upper
    /// bound" - same caveat as above).
    pub old_upper: Option<i64>,
}
93
/// Snapshot of the manager's positions, taken by
/// `BacktrackManager::push_checkpoint`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Checkpoint {
    /// Decision level that was current when the checkpoint was taken.
    pub level: usize,
    /// Length of the assignment trail at that moment.
    pub trail_size: usize,
    /// Length of the undo stack at that moment.
    pub undo_size: usize,
}
104
/// Counters maintained by `BacktrackManager`.
///
/// All counters start at zero (`Default`). Derives `PartialEq`/`Eq` so two
/// snapshots can be compared directly, e.g. in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BacktrackStats {
    /// Number of decisions made.
    pub decisions: usize,
    /// Number of propagation requests.
    pub propagations: usize,
    /// Number of backtracks that actually lowered the decision level.
    pub backtracks: usize,
    /// Number of restarts requested.
    pub restarts: usize,
    /// Number of undo actions applied.
    pub undos: usize,
    /// Highest decision level reached so far.
    pub max_level: usize,
}
121
122impl BacktrackManager {
123 pub fn new() -> Self {
125 Self {
126 decision_stack: vec![DecisionLevel {
127 level: 0,
128 decision: None,
129 trail_start: 0,
130 }],
131 current_level: 0,
132 trail: Vec::new(),
133 var_to_trail: FxHashMap::default(),
134 undo_stack: Vec::new(),
135 checkpoints: Vec::new(),
136 stats: BacktrackStats::default(),
137 }
138 }
139
140 pub fn decide(&mut self, var: TermId, value: bool) -> Result<(), String> {
142 self.current_level += 1;
143 self.stats.decisions += 1;
144
145 if self.current_level > self.stats.max_level {
146 self.stats.max_level = self.current_level;
147 }
148
149 let assignment = Assignment {
150 var,
151 value,
152 level: self.current_level,
153 reason: None, };
155
156 let trail_start = self.trail.len();
158
159 self.decision_stack.push(DecisionLevel {
161 level: self.current_level,
162 decision: Some(assignment.clone()),
163 trail_start,
164 });
165
166 self.assign(assignment)?;
168
169 Ok(())
170 }
171
172 pub fn propagate(
174 &mut self,
175 var: TermId,
176 value: bool,
177 reason: ReasonClause,
178 ) -> Result<(), String> {
179 self.stats.propagations += 1;
180
181 let assignment = Assignment {
182 var,
183 value,
184 level: self.current_level,
185 reason: Some(reason),
186 };
187
188 self.assign(assignment)?;
189
190 Ok(())
191 }
192
193 fn assign(&mut self, assignment: Assignment) -> Result<(), String> {
195 if let Some(&trail_idx) = self.var_to_trail.get(&assignment.var) {
197 let existing = &self.trail[trail_idx];
198 if existing.value != assignment.value {
199 return Err(format!(
200 "Conflict: variable {:?} already assigned differently",
201 assignment.var
202 ));
203 }
204 return Ok(());
206 }
207
208 let trail_idx = self.trail.len();
210 self.var_to_trail.insert(assignment.var, trail_idx);
211
212 self.trail.push(assignment);
214
215 Ok(())
216 }
217
218 pub fn backtrack(&mut self, target_level: usize) -> Result<(), String> {
220 if target_level > self.current_level {
221 return Err("Cannot backtrack to higher level".to_string());
222 }
223
224 if target_level == self.current_level {
225 return Ok(()); }
227
228 self.stats.backtracks += 1;
229
230 let target_trail_pos = if target_level + 1 < self.decision_stack.len() {
233 self.decision_stack[target_level + 1].trail_start
234 } else {
235 self.trail.len()
237 };
238
239 while self.trail.len() > target_trail_pos {
241 if let Some(assignment) = self.trail.pop() {
242 self.var_to_trail.remove(&assignment.var);
243 }
244 }
245
246 self.undo_to_level(target_level)?;
248
249 self.decision_stack.truncate(target_level + 1);
251 self.current_level = target_level;
252
253 Ok(())
254 }
255
256 pub fn restart(&mut self) -> Result<(), String> {
258 self.stats.restarts += 1;
259 self.backtrack(0)
260 }
261
262 pub fn record_undo(&mut self, action: UndoAction) {
264 self.undo_stack.push(action);
265 }
266
267 fn undo_to_level(&mut self, _target_level: usize) -> Result<(), String> {
269 let mut undos_to_apply = Vec::new();
273
274 while let Some(undo) = self.undo_stack.pop() {
277 undos_to_apply.push(undo);
278
279 if undos_to_apply.len() > 100 {
281 break;
282 }
283 }
284
285 for undo in undos_to_apply {
287 self.apply_undo(undo)?;
288 self.stats.undos += 1;
289 }
290
291 Ok(())
292 }
293
294 fn apply_undo(&mut self, _action: UndoAction) -> Result<(), String> {
296 Ok(())
299 }
300
301 pub fn push_checkpoint(&mut self) {
303 self.checkpoints.push(Checkpoint {
304 level: self.current_level,
305 trail_size: self.trail.len(),
306 undo_size: self.undo_stack.len(),
307 });
308 }
309
310 pub fn pop_checkpoint(&mut self) -> Result<(), String> {
312 if let Some(checkpoint) = self.checkpoints.pop() {
313 self.backtrack(checkpoint.level)?;
314
315 self.undo_stack.truncate(checkpoint.undo_size);
317
318 Ok(())
319 } else {
320 Err("No checkpoint to pop".to_string())
321 }
322 }
323
324 pub fn current_level(&self) -> usize {
326 self.current_level
327 }
328
329 pub fn is_assigned(&self, var: TermId) -> bool {
331 self.var_to_trail.contains_key(&var)
332 }
333
334 pub fn get_assignment(&self, var: TermId) -> Option<&Assignment> {
336 if let Some(&trail_idx) = self.var_to_trail.get(&var) {
337 self.trail.get(trail_idx)
338 } else {
339 None
340 }
341 }
342
343 pub fn current_assignments(&self) -> &[Assignment] {
345 &self.trail
346 }
347
348 pub fn get_decision(&self, level: usize) -> Option<&Assignment> {
350 if level < self.decision_stack.len() {
351 self.decision_stack[level].decision.as_ref()
352 } else {
353 None
354 }
355 }
356
357 pub fn stats(&self) -> &BacktrackStats {
359 &self.stats
360 }
361
362 pub fn reset_stats(&mut self) {
364 self.stats = BacktrackStats::default();
365 }
366}
367
368impl Default for BacktrackManager {
369 fn default() -> Self {
370 Self::new()
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh manager starts at level 0 with no decisions counted.
    #[test]
    fn test_backtrack_manager() {
        let manager = BacktrackManager::new();

        assert_eq!(manager.current_level(), 0);
        assert_eq!(manager.stats.decisions, 0);
    }

    /// A decision opens level 1, is counted, and assigns the variable.
    #[test]
    fn test_decide() {
        let mut manager = BacktrackManager::new();
        let x = TermId::from(1);

        manager.decide(x, true).unwrap();

        assert_eq!(manager.current_level(), 1);
        assert_eq!(manager.stats.decisions, 1);
        assert!(manager.is_assigned(x));
    }

    /// A propagation is counted and assigns the propagated variable.
    #[test]
    fn test_propagate() {
        let mut manager = BacktrackManager::new();

        let decided = TermId::from(1);
        manager.decide(decided, true).unwrap();

        let implied = TermId::from(2);
        let why = ReasonClause::BooleanClause(vec![decided]);
        manager.propagate(implied, false, why).unwrap();

        assert_eq!(manager.stats.propagations, 1);
        assert!(manager.is_assigned(implied));
    }

    /// Backtracking one level keeps the lower decision and drops the upper one.
    #[test]
    fn test_backtrack() {
        let mut manager = BacktrackManager::new();

        let lower = TermId::from(1);
        manager.decide(lower, true).unwrap();

        let upper = TermId::from(2);
        manager.decide(upper, false).unwrap();

        assert_eq!(manager.current_level(), 2);

        manager.backtrack(1).unwrap();

        assert_eq!(manager.current_level(), 1);
        assert!(manager.is_assigned(lower));
        assert!(!manager.is_assigned(upper));
    }

    /// A restart returns to level 0, is counted, and empties the trail.
    #[test]
    fn test_restart() {
        let mut manager = BacktrackManager::new();

        manager.decide(TermId::from(1), true).unwrap();
        manager.decide(TermId::from(2), false).unwrap();

        manager.restart().unwrap();

        assert_eq!(manager.current_level(), 0);
        assert_eq!(manager.stats.restarts, 1);
        assert_eq!(manager.trail.len(), 0);
    }

    /// Popping a checkpoint undoes the decision made after it was pushed.
    #[test]
    fn test_checkpoint() {
        let mut manager = BacktrackManager::new();

        manager.decide(TermId::from(1), true).unwrap();
        manager.push_checkpoint();

        manager.decide(TermId::from(2), false).unwrap();

        manager.pop_checkpoint().unwrap();

        assert_eq!(manager.current_level(), 1);
        assert!(!manager.is_assigned(TermId::from(2)));
    }

    /// The stored assignment of a decision has its value, its level and no reason.
    #[test]
    fn test_get_assignment() {
        let mut manager = BacktrackManager::new();
        let x = TermId::from(1);

        manager.decide(x, true).unwrap();

        let stored = manager.get_assignment(x).unwrap();
        assert!(stored.value);
        assert_eq!(stored.level, 1);
        assert!(stored.reason.is_none());
    }
}