1use crate::proof::{Proof, ProofNodeId};
7use rustc_hash::FxHashMap;
8use std::collections::VecDeque;
9
10#[derive(Debug)]
17pub struct IncrementalProofBuilder {
18 proof: Proof,
20 scopes: Vec<ScopeFrame>,
22 clause_map: FxHashMap<u32, ProofNodeId>,
24 enabled: bool,
26}
27
28#[derive(Debug, Clone)]
30struct ScopeFrame {
31 #[allow(dead_code)]
34 proof_size: usize,
35 clause_mappings: Vec<(u32, ProofNodeId)>,
37}
38
39impl IncrementalProofBuilder {
40 #[must_use]
42 pub fn new() -> Self {
43 Self {
44 proof: Proof::new(),
45 scopes: Vec::new(),
46 clause_map: FxHashMap::default(),
47 enabled: true,
48 }
49 }
50
51 pub fn enable(&mut self) {
53 self.enabled = true;
54 }
55
56 pub fn disable(&mut self) {
58 self.enabled = false;
59 }
60
61 #[must_use]
63 pub const fn is_enabled(&self) -> bool {
64 self.enabled
65 }
66
67 pub fn push_scope(&mut self) {
69 self.scopes.push(ScopeFrame {
70 proof_size: self.proof.len(),
71 clause_mappings: Vec::new(),
72 });
73 }
74
75 pub fn pop_scope(&mut self) {
77 if let Some(frame) = self.scopes.pop() {
78 for (clause_id, _) in &frame.clause_mappings {
80 self.clause_map.remove(clause_id);
81 }
82
83 }
86 }
87
88 #[must_use]
90 pub fn scope_level(&self) -> usize {
91 self.scopes.len()
92 }
93
94 pub fn record_axiom(&mut self, clause_id: u32, conclusion: impl Into<String>) -> ProofNodeId {
96 if !self.enabled {
97 return ProofNodeId(0); }
99
100 let node_id = self.proof.add_axiom(conclusion);
101 self.clause_map.insert(clause_id, node_id);
102
103 if let Some(scope) = self.scopes.last_mut() {
105 scope.clause_mappings.push((clause_id, node_id));
106 }
107
108 node_id
109 }
110
111 pub fn record_inference(
113 &mut self,
114 clause_id: u32,
115 rule: impl Into<String>,
116 premise_ids: &[u32],
117 conclusion: impl Into<String>,
118 ) -> ProofNodeId {
119 if !self.enabled {
120 return ProofNodeId(0); }
122
123 let premises: Vec<ProofNodeId> = premise_ids
125 .iter()
126 .filter_map(|id| self.clause_map.get(id).copied())
127 .collect();
128
129 let node_id = self.proof.add_inference(rule, premises, conclusion);
130 self.clause_map.insert(clause_id, node_id);
131
132 if let Some(scope) = self.scopes.last_mut() {
134 scope.clause_mappings.push((clause_id, node_id));
135 }
136
137 node_id
138 }
139
140 pub fn record_inference_with_args(
142 &mut self,
143 clause_id: u32,
144 rule: impl Into<String>,
145 premise_ids: &[u32],
146 args: Vec<String>,
147 conclusion: impl Into<String>,
148 ) -> ProofNodeId {
149 if !self.enabled {
150 return ProofNodeId(0); }
152
153 let premises: Vec<ProofNodeId> = premise_ids
154 .iter()
155 .filter_map(|id| self.clause_map.get(id).copied())
156 .collect();
157
158 let node_id = self
159 .proof
160 .add_inference_with_args(rule, premises, args, conclusion);
161 self.clause_map.insert(clause_id, node_id);
162
163 if let Some(scope) = self.scopes.last_mut() {
164 scope.clause_mappings.push((clause_id, node_id));
165 }
166
167 node_id
168 }
169
170 #[must_use]
172 pub fn get_clause_proof(&self, clause_id: u32) -> Option<ProofNodeId> {
173 self.clause_map.get(&clause_id).copied()
174 }
175
176 #[must_use]
178 pub const fn proof(&self) -> &Proof {
179 &self.proof
180 }
181
182 pub fn take_proof(&mut self) -> Proof {
184 self.clause_map.clear();
185 self.scopes.clear();
186 std::mem::take(&mut self.proof)
187 }
188
189 pub fn reset(&mut self) {
191 self.proof.clear();
192 self.scopes.clear();
193 self.clause_map.clear();
194 }
195
196 #[must_use]
198 pub fn stats(&self) -> IncrementalStats {
199 IncrementalStats {
200 total_nodes: self.proof.len(),
201 scope_level: self.scopes.len(),
202 clause_mappings: self.clause_map.len(),
203 enabled: self.enabled,
204 }
205 }
206}
207
208impl Default for IncrementalProofBuilder {
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
/// Snapshot of an [`IncrementalProofBuilder`]'s size and state, as produced
/// by its `stats` method.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly (e.g. to
/// check that a push/pop cycle left the builder unchanged).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IncrementalStats {
    /// Number of nodes in the proof.
    pub total_nodes: usize,
    /// Number of currently open scopes.
    pub scope_level: usize,
    /// Number of clause ids that currently map to a proof node.
    pub clause_mappings: usize,
    /// Whether recording was enabled when the snapshot was taken.
    pub enabled: bool,
}
226
227#[derive(Debug)]
231pub struct ProofRecorder {
232 builder: IncrementalProofBuilder,
234 next_clause_id: u32,
236 queue: VecDeque<RecordedStep>,
238 batch_mode: bool,
240}
241
/// One proof step captured by [`ProofRecorder`], held in owned form so it can
/// be queued and replayed into the builder later.
#[derive(Debug, Clone)]
enum RecordedStep {
    /// An input clause; replayed via `IncrementalProofBuilder::record_axiom`.
    Axiom {
        clause_id: u32,
        conclusion: String,
    },
    /// A derived clause; replayed via
    /// `IncrementalProofBuilder::record_inference_with_args`.
    Inference {
        clause_id: u32,
        rule: String,
        /// Clause ids of the premises (not proof node ids).
        premises: Vec<u32>,
        conclusion: String,
        args: Vec<String>,
    },
}
257
258#[cfg(feature = "profiling")]
259use oxiz_core::profiling::{ProfilingCategory, ScopedTimer};
260
261impl ProofRecorder {
262 #[must_use]
264 pub fn new() -> Self {
265 Self {
266 builder: IncrementalProofBuilder::new(),
267 next_clause_id: 0,
268 queue: VecDeque::new(),
269 batch_mode: false,
270 }
271 }
272
273 pub fn enable_batch_mode(&mut self) {
275 self.batch_mode = true;
276 }
277
278 pub fn disable_batch_mode(&mut self) {
280 self.batch_mode = false;
281 }
282
283 pub fn alloc_clause_id(&mut self) -> u32 {
285 let id = self.next_clause_id;
286 self.next_clause_id += 1;
287 id
288 }
289
290 pub fn record_input(&mut self, conclusion: impl Into<String>) -> u32 {
292 let clause_id = self.alloc_clause_id();
293 self.record_step(RecordedStep::Axiom {
294 clause_id,
295 conclusion: conclusion.into(),
296 });
297
298 clause_id
299 }
300
301 pub fn record_derived(
303 &mut self,
304 rule: impl Into<String>,
305 premises: &[u32],
306 conclusion: impl Into<String>,
307 ) -> u32 {
308 let clause_id = self.alloc_clause_id();
309 self.record_step(RecordedStep::Inference {
310 clause_id,
311 rule: rule.into(),
312 premises: premises.to_vec(),
313 conclusion: conclusion.into(),
314 args: Vec::new(),
315 });
316
317 clause_id
318 }
319
320 fn record_step(&mut self, step: RecordedStep) {
321 #[cfg(feature = "profiling")]
322 let _timer = ScopedTimer::new(ProfilingCategory::ProofGeneration);
323 if self.batch_mode {
324 self.queue.push_back(step);
325 return;
326 }
327
328 match step {
329 RecordedStep::Axiom {
330 clause_id,
331 conclusion,
332 } => {
333 self.builder.record_axiom(clause_id, conclusion);
334 }
335 RecordedStep::Inference {
336 clause_id,
337 rule,
338 premises,
339 conclusion,
340 args,
341 } => {
342 self.builder
343 .record_inference_with_args(clause_id, rule, &premises, args, conclusion);
344 }
345 }
346 }
347
348 pub fn flush(&mut self) {
350 while let Some(step) = self.queue.pop_front() {
351 match step {
352 RecordedStep::Axiom {
353 clause_id,
354 conclusion,
355 } => {
356 self.builder.record_axiom(clause_id, conclusion);
357 }
358 RecordedStep::Inference {
359 clause_id,
360 rule,
361 premises,
362 conclusion,
363 args,
364 } => {
365 self.builder
366 .record_inference_with_args(clause_id, rule, &premises, args, conclusion);
367 }
368 }
369 }
370 }
371
372 #[must_use]
374 pub const fn builder(&self) -> &IncrementalProofBuilder {
375 &self.builder
376 }
377
378 pub fn builder_mut(&mut self) -> &mut IncrementalProofBuilder {
380 &mut self.builder
381 }
382
383 pub fn take_proof(&mut self) -> Proof {
385 self.flush();
386 self.builder.take_proof()
387 }
388}
389
390impl Default for ProofRecorder {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_incremental_builder_creation() {
        let builder = IncrementalProofBuilder::new();
        assert!(builder.is_enabled());
        assert_eq!(builder.scope_level(), 0);
        assert_eq!(builder.proof().len(), 0);
        assert_eq!(builder.stats().clause_mappings, 0);
    }

    #[test]
    fn test_record_axiom() {
        let mut builder = IncrementalProofBuilder::new();
        let clause_id = 1;
        builder.record_axiom(clause_id, "p");

        assert!(builder.get_clause_proof(clause_id).is_some());
        assert_eq!(builder.proof().len(), 1);
    }

    #[test]
    fn test_record_inference() {
        let mut builder = IncrementalProofBuilder::new();
        let (c1, c2, c3) = (1, 2, 3);

        builder.record_axiom(c1, "p");
        builder.record_axiom(c2, "q");
        builder.record_inference(c3, "and", &[c1, c2], "p /\\ q");

        assert_eq!(builder.proof().len(), 3);
        assert!(builder.get_clause_proof(c3).is_some());
    }

    #[test]
    fn test_record_inference_with_args() {
        let mut builder = IncrementalProofBuilder::new();
        builder.record_axiom(1, "p");
        builder.record_inference_with_args(2, "weaken", &[1], vec!["q".to_string()], "p \\/ q");

        assert_eq!(builder.proof().len(), 2);
        assert!(builder.get_clause_proof(2).is_some());
    }

    #[test]
    fn test_push_pop_scope() {
        let mut builder = IncrementalProofBuilder::new();
        let base = 0;
        builder.record_axiom(base, "base");

        builder.push_scope();
        assert_eq!(builder.scope_level(), 1);

        let c1 = 1;
        builder.record_axiom(c1, "p");

        builder.pop_scope();
        assert_eq!(builder.scope_level(), 0);
        // Mappings created inside the scope are gone; older ones survive.
        assert!(builder.get_clause_proof(c1).is_none());
        assert!(builder.get_clause_proof(base).is_some());
    }

    #[test]
    fn test_pop_without_scope_is_noop() {
        let mut builder = IncrementalProofBuilder::new();
        builder.record_axiom(1, "p");

        builder.pop_scope();

        assert_eq!(builder.scope_level(), 0);
        assert!(builder.get_clause_proof(1).is_some());
    }

    #[test]
    fn test_disabled_recording() {
        let mut builder = IncrementalProofBuilder::new();
        builder.disable();
        assert!(!builder.is_enabled());

        let c1 = 1;
        builder.record_axiom(c1, "p");

        assert_eq!(builder.proof().len(), 0);
        assert!(builder.get_clause_proof(c1).is_none());

        // Re-enabling resumes recording.
        builder.enable();
        builder.record_axiom(c1, "p");
        assert_eq!(builder.proof().len(), 1);
    }

    #[test]
    fn test_take_proof() {
        let mut builder = IncrementalProofBuilder::new();
        builder.record_axiom(1, "p");
        builder.push_scope();
        builder.record_axiom(2, "q");

        let proof = builder.take_proof();
        assert_eq!(proof.len(), 2);
        assert_eq!(builder.proof().len(), 0);
        assert_eq!(builder.scope_level(), 0);
        assert!(builder.get_clause_proof(1).is_none());
    }

    #[test]
    fn test_reset() {
        let mut builder = IncrementalProofBuilder::new();
        builder.record_axiom(1, "p");
        builder.push_scope();
        builder.disable();

        builder.reset();

        assert_eq!(builder.proof().len(), 0);
        assert_eq!(builder.scope_level(), 0);
        assert!(builder.get_clause_proof(1).is_none());
        // reset does not touch the enabled flag.
        assert!(!builder.is_enabled());
    }

    #[test]
    fn test_proof_recorder() {
        let mut recorder = ProofRecorder::new();

        let c1 = recorder.record_input("p");
        let c2 = recorder.record_input("q");
        let c3 = recorder.record_derived("and", &[c1, c2], "p /\\ q");

        assert!(c3 > c2);

        let proof = recorder.take_proof();
        assert_eq!(proof.len(), 3);
    }

    #[test]
    fn test_proof_recorder_batch() {
        let mut recorder = ProofRecorder::new();
        recorder.enable_batch_mode();

        let _c1 = recorder.record_input("p");
        let _c2 = recorder.record_input("q");

        // Nothing reaches the builder until the queue is flushed.
        assert_eq!(recorder.builder().proof().len(), 0);

        recorder.flush();

        assert_eq!(recorder.builder().proof().len(), 2);
        assert_eq!(recorder.take_proof().len(), 2);
    }

    #[test]
    fn test_recorder_take_proof_flushes_queue() {
        let mut recorder = ProofRecorder::new();
        recorder.enable_batch_mode();
        recorder.record_input("p");
        recorder.record_input("q");

        assert_eq!(recorder.take_proof().len(), 2);
    }

    #[test]
    fn test_incremental_stats() {
        let mut builder = IncrementalProofBuilder::new();
        builder.record_axiom(1, "p");
        builder.push_scope();
        builder.record_axiom(2, "q");

        let stats = builder.stats();
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.scope_level, 1);
        assert_eq!(stats.clause_mappings, 2);
        assert!(stats.enabled);
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        /// Short variable-like names such as `a`, `b3`, `z17`.
        fn var_name() -> impl Strategy<Value = String> {
            // A `&str` is itself a regex strategy yielding `String`.
            "[a-z][0-9]*"
        }

        proptest! {
            /// A new conclusion adds a node; a repeated one does not.
            #[test]
            fn prop_record_axiom_increases_size(
                conclusions in prop::collection::vec(var_name(), 1..10)
            ) {
                let mut builder = IncrementalProofBuilder::new();
                let mut seen = std::collections::HashSet::new();

                for (i, conclusion) in conclusions.iter().enumerate() {
                    let initial_len = builder.proof().len();
                    builder.record_axiom(i as u32 + 1, conclusion);

                    if seen.insert(conclusion.clone()) {
                        prop_assert!(builder.proof().len() > initial_len);
                    } else {
                        prop_assert_eq!(builder.proof().len(), initial_len);
                    }
                }
            }

            /// Scope level follows pushes and pops exactly.
            #[test]
            fn prop_scope_level_tracking(depth in 0..10_usize) {
                let mut builder = IncrementalProofBuilder::new();
                prop_assert_eq!(builder.scope_level(), 0);

                for i in 0..depth {
                    builder.push_scope();
                    prop_assert_eq!(builder.scope_level(), i + 1);
                }

                for i in (0..depth).rev() {
                    builder.pop_scope();
                    prop_assert_eq!(builder.scope_level(), i);
                }
            }

            /// Popping a scope drops exactly the mappings created inside it.
            #[test]
            fn prop_pop_removes_scope_mappings(
                base_count in 1..4_usize,
                scope_count in 1..4_usize
            ) {
                let mut builder = IncrementalProofBuilder::new();

                for i in 0..base_count {
                    builder.record_axiom(i as u32 + 1, format!("base{}", i));
                }
                let base_mappings = builder.stats().clause_mappings;

                builder.push_scope();
                // Offset by 100 so scoped ids never collide with base ids.
                for i in 0..scope_count {
                    builder.record_axiom((base_count + i) as u32 + 100, format!("scope{}", i));
                }

                builder.pop_scope();
                prop_assert_eq!(builder.stats().clause_mappings, base_mappings);
            }

            /// A disabled builder never grows its proof.
            #[test]
            fn prop_disabled_no_recording(
                conclusions in prop::collection::vec(var_name(), 1..10)
            ) {
                let mut builder = IncrementalProofBuilder::new();
                builder.disable();

                for (i, conclusion) in conclusions.iter().enumerate() {
                    builder.record_axiom(i as u32 + 1, conclusion);
                }

                prop_assert_eq!(builder.proof().len(), 0);
                prop_assert_eq!(builder.stats().clause_mappings, 0);
            }

            /// `stats` agrees with the builder's observable state.
            #[test]
            fn prop_stats_consistency(
                conclusions in prop::collection::vec(var_name(), 1..8),
                scope_depth in 0..5_usize
            ) {
                let mut builder = IncrementalProofBuilder::new();

                for _ in 0..scope_depth {
                    builder.push_scope();
                }

                for (i, conclusion) in conclusions.iter().enumerate() {
                    builder.record_axiom(i as u32 + 1, conclusion);
                }

                let stats = builder.stats();
                prop_assert_eq!(stats.scope_level, scope_depth);
                prop_assert!(stats.enabled);
                prop_assert_eq!(stats.total_nodes, builder.proof().len());
                // Every clause id is distinct, so each one has its own mapping.
                prop_assert_eq!(stats.clause_mappings, conclusions.len());
            }

            /// `take_proof` hands over the whole proof and leaves it empty.
            #[test]
            fn prop_take_proof_empties(
                conclusions in prop::collection::vec(var_name(), 1..8)
            ) {
                let mut builder = IncrementalProofBuilder::new();

                for (i, conclusion) in conclusions.iter().enumerate() {
                    builder.record_axiom(i as u32 + 1, conclusion);
                }

                let proof_len = builder.proof().len();
                let taken = builder.take_proof();

                prop_assert_eq!(taken.len(), proof_len);
                prop_assert_eq!(builder.proof().len(), 0);
            }

            /// Allocated clause ids strictly increase.
            #[test]
            fn prop_recorder_id_ordering(count in 2..8_usize) {
                let mut recorder = ProofRecorder::new();
                let mut last_id = None;

                for i in 0..count {
                    let id = recorder.record_input(format!("unique{}", i));
                    if let Some(prev_id) = last_id {
                        prop_assert!(id > prev_id);
                    }
                    last_id = Some(id);
                }
            }

            /// Batch mode defers all work until `flush`.
            #[test]
            fn prop_batch_mode_delays(
                conclusions in prop::collection::vec(var_name(), 1..8)
            ) {
                let mut recorder = ProofRecorder::new();
                recorder.enable_batch_mode();

                for conclusion in &conclusions {
                    recorder.record_input(conclusion);
                }

                prop_assert_eq!(recorder.builder().proof().len(), 0);

                recorder.flush();

                // Duplicate conclusions share a proof node.
                let unique_count = conclusions.iter().collect::<std::collections::HashSet<_>>().len();
                prop_assert_eq!(recorder.builder().proof().len(), unique_count);
            }

            /// A derived clause gets an id newer than its premises.
            #[test]
            fn prop_derived_dependencies(
                conclusions in prop::collection::vec(var_name(), 2..6)
            ) {
                let mut recorder = ProofRecorder::new();
                let clause_ids: Vec<u32> = conclusions
                    .iter()
                    .map(|conclusion| recorder.record_input(conclusion))
                    .collect();

                // The strategy guarantees at least two inputs.
                let derived_id = recorder.record_derived("and", &clause_ids[..2], "derived");

                prop_assert!(derived_id > clause_ids[0]);
                prop_assert!(derived_id > clause_ids[1]);
            }
        }
    }
}