use crate::cube::{shift_lit, Cube};
use crate::frames::{Frame, FrameSequence};
use crate::session::{self, PdrSession};
use crate::phase::*;
use warp_types_bmc::TransitionSystem;
use warp_types_sat::bcp::ClauseDb;
use warp_types_sat::literal::Lit;
use warp_types_sat::solver::{solve_watched_budget, SolveResult};
/// Outcome of a [`check`] run.
#[derive(Debug)]
pub enum PdrResult {
/// Clause propagation reported convergence of the frame sequence.
Safe {
/// Frame index returned by the frame sequence's convergence check.
invariant_frame: usize,
},
/// A chain of states ending in a property violation was found.
CounterexampleFound {
/// Number of steps in the trace: `0` when the violation is found by the
/// initial check, otherwise `trace.len() - 1`.
depth: u32,
/// One `Vec<bool>` state assignment per step of the counterexample.
trace: Vec<Vec<bool>>,
},
/// The `max_frames` iteration limit was reached without either of the
/// other two outcomes.
Exhausted {
/// Length of the frame sequence when the search stopped.
frames_explored: usize,
},
}
/// A pending proof obligation handled by `block_cube`: a cube that has to be
/// blocked at a given frame level.
struct Obligation {
/// The cube that must be blocked.
cube: Cube,
/// Frame level at which `cube` must be blocked.
level: usize,
/// Index (within the obligation queue) of the obligation this cube was
/// found as a predecessor of; `None` for the root obligation.
parent: Option<usize>,
}
/// Runs the PDR main loop on `sys`.
///
/// The procedure is:
/// 1. Ask [`check_initiation`] whether the initial clauses together with the
///    negated property are satisfiable; if so, report a depth-0 counterexample.
/// 2. Seed the frame sequence with a frame holding the initial clauses plus
///    one empty frame.
/// 3. For at most `max_frames` rounds: block every CTI found at the frontier,
///    then propagate clauses; stop with `Safe` if propagation reports
///    convergence, otherwise append a fresh empty frame.
///
/// `conflict_budget` is forwarded unchanged to every solver call.
///
/// Returns [`PdrResult::Exhausted`] when the round limit is hit without a
/// verdict.
pub fn check(
    sys: &TransitionSystem,
    max_frames: u32,
    conflict_budget: u64,
) -> PdrResult {
    let state_var_count = sys.num_state_vars as usize;

    session::with_session(|init: PdrSession<'_, Init>| {
        let modeled = init.build_model();

        // Depth-0 violation: the trace is just the state-variable prefix of
        // the satisfying assignment.
        if let Some(assignment) = check_initiation(sys, conflict_budget) {
            let initial_state = assignment[..state_var_count].to_vec();
            let _cex = modeled.check_counterexample();
            return PdrResult::CounterexampleFound {
                depth: 0,
                trace: vec![initial_state],
            };
        }

        // Frame 0 carries the initial clauses; frame 1 starts out empty.
        let mut frames = FrameSequence::new();
        let mut seed_clauses: Vec<Vec<Lit>> = Vec::new();
        for clause in sys.initial.iter() {
            seed_clauses.push(clause.lits.clone());
        }
        frames.push(Frame::from_clauses(seed_clauses));
        frames.push(Frame::new());

        for _ in 0..max_frames {
            let frontier = frames.frontier();

            // Keep blocking until the frontier no longer yields a CTI.
            while let Some(bad_cube) = find_cti(sys, &frames, frontier, conflict_budget) {
                let outcome = block_cube(sys, &mut frames, bad_cube, frontier, conflict_budget);
                if let BlockResult::Counterexample(trace) = outcome {
                    let _cex = modeled.check_counterexample();
                    return PdrResult::CounterexampleFound {
                        depth: trace.len() as u32 - 1,
                        trace,
                    };
                }
            }

            if let Some(inv_frame) = propagate_clauses(sys, &mut frames, conflict_budget) {
                let _safe = modeled.check_safe();
                return PdrResult::Safe {
                    invariant_frame: inv_frame,
                };
            }

            frames.push(Frame::new());
        }

        let _exhausted = modeled.check_exhausted();
        PdrResult::Exhausted {
            frames_explored: frames.len(),
        }
    })
}
/// Result of one `block_cube` call.
enum BlockResult {
/// The obligation queue was emptied without reaching an initial state.
Blocked,
/// A level-0 obligation was consistent with the initial clauses; carries
/// the trace rebuilt from the obligation chain.
Counterexample(Vec<Vec<bool>>),
}
/// Checks whether the initial clauses are satisfiable together with the
/// negated property.
///
/// The query uses `num_state_vars` state variables plus one auxiliary
/// variable per property clause. Returns the satisfying assignment when the
/// solver answers `Sat`; every other solver result yields `None`.
fn check_initiation(sys: &TransitionSystem, conflict_budget: u64) -> Option<Vec<bool>> {
    let state_vars = sys.num_state_vars;
    let var_count = state_vars + sys.property.len() as u32;

    let mut db = ClauseDb::new();
    for init_clause in sys.initial.iter() {
        db.add_clause(init_clause.lits.clone());
    }
    add_negated_property(&mut db, sys, 0, state_vars);

    let (outcome, _) = solve_watched_budget(db, var_count, conflict_budget);
    if let SolveResult::Sat(model) = outcome {
        Some(model)
    } else {
        None
    }
}
/// Looks for a state that satisfies the clauses of frame `level` and violates
/// the property.
///
/// Returns the cube built from the satisfying assignment (restricted to the
/// `num_state_vars` state variables via `Cube::from_assignment`) when the
/// solver answers `Sat`; every other solver result yields `None`.
fn find_cti(
    sys: &TransitionSystem,
    frames: &FrameSequence,
    level: usize,
    conflict_budget: u64,
) -> Option<Cube> {
    let state_vars = sys.num_state_vars;
    // One auxiliary variable per property clause sits above the state variables.
    let var_count = state_vars + sys.property.len() as u32;

    let mut db = ClauseDb::new();
    add_frame_clauses(&mut db, frames.frame(level), 0);
    add_negated_property(&mut db, sys, 0, state_vars);

    let (outcome, _) = solve_watched_budget(db, var_count, conflict_budget);
    if let SolveResult::Sat(model) = outcome {
        Some(Cube::from_assignment(&model, state_vars))
    } else {
        None
    }
}
/// Searches frame `level` for a state with a transition into `cube`.
///
/// The query is: clauses of frame `level` (unshifted), the transition
/// clauses, and one unit clause per literal of `cube` shifted by
/// `num_state_vars`. It is solved over `2 * num_state_vars` variables.
///
/// Returns the cube built from the first `num_state_vars` variables of the
/// model when the solver answers `Sat`; every other solver result yields
/// `None`.
fn check_predecessor(
    sys: &TransitionSystem,
    frames: &FrameSequence,
    cube: &Cube,
    level: usize,
    conflict_budget: u64,
) -> Option<Cube> {
    let state_vars = sys.num_state_vars;

    let mut db = ClauseDb::new();
    add_frame_clauses(&mut db, frames.frame(level), 0);
    for step in &sys.transition {
        db.add_clause(step.lits.clone());
    }

    // Pin the shifted copy of the cube with unit clauses.
    let target = cube.shift(state_vars);
    for lit in &target.lits {
        db.add_clause(vec![*lit]);
    }

    let (outcome, _) = solve_watched_budget(db, 2 * state_vars, conflict_budget);
    if let SolveResult::Sat(model) = outcome {
        Some(Cube::from_assignment(&model, state_vars))
    } else {
        None
    }
}
/// Tries to block `cube` at frame `level`, working through a queue of proof
/// obligations.
///
/// Each round picks the obligation with the lowest level:
/// * at level 0, a cube consistent with the initial clauses ends the search
///   with a counterexample trace; otherwise the obligation is dropped;
/// * above level 0, a predecessor found in the frame below becomes a new
///   obligation one level lower, with its `parent` set to the current index;
/// * if no predecessor is found, the cube is generalized, the resulting
///   clause is recorded through `add_blocked_clause`, and the obligation is
///   removed.
///
/// Returns `BlockResult::Blocked` once the queue is empty.
///
/// NOTE(review): `parent` stores a position in the queue and `Vec::remove`
/// shifts later entries. Every pushed obligation is one level below its
/// parent, so the selected entry looks like it is always the last one —
/// confirm before changing the selection or removal strategy.
fn block_cube(
    sys: &TransitionSystem,
    frames: &mut FrameSequence,
    cube: Cube,
    level: usize,
    conflict_budget: u64,
) -> BlockResult {
    let state_vars = sys.num_state_vars;
    let mut pending: Vec<Obligation> = Vec::new();
    pending.push(Obligation {
        cube,
        level,
        parent: None,
    });

    while let Some(idx) = find_min_level(&pending) {
        let depth = pending[idx].level;

        if depth == 0 {
            if is_initial_reachable(sys, &pending[idx].cube, conflict_budget) {
                let trace = reconstruct_trace(&pending, idx, state_vars);
                return BlockResult::Counterexample(trace);
            }
            pending.remove(idx);
            continue;
        }

        let below = depth - 1;
        let found = check_predecessor(sys, frames, &pending[idx].cube, below, conflict_budget);

        if let Some(pred_cube) = found {
            // The predecessor has to be dealt with first, one level down.
            pending.push(Obligation {
                cube: pred_cube,
                level: below,
                parent: Some(idx),
            });
            continue;
        }

        // No predecessor: learn a clause and discharge the obligation.
        let learned = generalize(sys, frames, &pending[idx].cube, depth, conflict_budget);
        frames.add_blocked_clause(depth, learned);
        pending.remove(idx);
    }

    BlockResult::Blocked
}
/// Returns the index of the obligation with the smallest `level`, or `None`
/// for an empty queue. On ties the earliest entry wins.
fn find_min_level(queue: &[Obligation]) -> Option<usize> {
    queue
        .iter()
        .enumerate()
        // `min_by_key` yields the first of several equal minima.
        .min_by_key(|(_, obligation)| obligation.level)
        .map(|(idx, _)| idx)
}
/// Reports whether `cube` is consistent with the initial clauses.
///
/// Builds a query from the initial clauses plus one unit clause per literal
/// of `cube`, solved over `num_state_vars` variables. Returns `true` only
/// when the solver answers `Sat`.
fn is_initial_reachable(
    sys: &TransitionSystem,
    cube: &Cube,
    conflict_budget: u64,
) -> bool {
    let mut db = ClauseDb::new();
    for init_clause in sys.initial.iter() {
        db.add_clause(init_clause.lits.clone());
    }
    for lit in &cube.lits {
        db.add_clause(vec![*lit]);
    }

    let outcome = solve_watched_budget(db, sys.num_state_vars, conflict_budget).0;
    matches!(outcome, SolveResult::Sat(_))
}
/// Builds a trace by walking `parent` links from `start_idx` until an
/// obligation without a parent is reached.
///
/// Each visited obligation contributes `cube.to_state_vec(num_state_vars)`;
/// the entry for `start_idx` comes first.
fn reconstruct_trace(
    obligations: &[Obligation],
    start_idx: usize,
    num_state_vars: u32,
) -> Vec<Vec<bool>> {
    let mut states = Vec::new();
    let mut cursor = start_idx;
    loop {
        let node = &obligations[cursor];
        states.push(node.cube.to_state_vec(num_state_vars));
        match node.parent {
            Some(next) => cursor = next,
            None => break,
        }
    }
    states
}
fn generalize(
sys: &TransitionSystem,
frames: &FrameSequence,
cube: &Cube,
level: usize,
conflict_budget: u64,
) -> Vec<Lit> {
let mut reduced_lits = cube.lits.clone();
let mut i = 0;
while i < reduced_lits.len() {
let mut candidate = reduced_lits.clone();
candidate.remove(i);
if candidate.is_empty() {
i += 1;
continue;
}
let candidate_cube = Cube::new(candidate.clone());
if level > 0 && check_predecessor(sys, frames, &candidate_cube, level - 1, conflict_budget).is_none()
{
reduced_lits = candidate;
} else {
i += 1;
}
}
Cube::new(reduced_lits).negate()
}
/// Pushes clauses forward through the frame sequence and then asks it for
/// convergence.
///
/// For every level in `1..frontier`, each clause of that frame that
/// `is_clause_inductive` accepts is added to frame `level + 1` unless an
/// equal clause (same literal codes, in any order) is already there.
///
/// Returns whatever `FrameSequence::check_convergence` reports.
fn propagate_clauses(
    sys: &TransitionSystem,
    frames: &mut FrameSequence,
    conflict_budget: u64,
) -> Option<usize> {
    let n = sys.num_state_vars;
    let frontier = frames.frontier();

    for level in 1..frontier {
        // Snapshot the clauses so the frame sequence can be mutated below.
        let clauses_to_check: Vec<Vec<Lit>> = frames.frame(level).clauses().to_vec();

        for clause in &clauses_to_check {
            if !is_clause_inductive(sys, frames, clause, level, n, conflict_budget) {
                continue;
            }

            // Order-independent key for this clause, computed once rather
            // than once per clause of the next frame.
            let mut wanted: Vec<u32> = clause.iter().map(|l| l.code()).collect();
            wanted.sort_unstable();

            let already_present = frames.frame(level + 1).clauses().iter().any(|existing| {
                // Clauses of different length cannot be equal.
                if existing.len() != wanted.len() {
                    return false;
                }
                let mut have: Vec<u32> = existing.iter().map(|l| l.code()).collect();
                have.sort_unstable();
                have == wanted
            });

            if !already_present {
                frames.frame_mut(level + 1).add_clause(clause.clone());
            }
        }
    }

    frames.check_convergence()
}
/// Checks whether `clause` is preserved by one transition step out of frame
/// `level`.
///
/// The query contains the clauses of frame `level`, `clause` itself, the
/// transition clauses, and a unit clause for the complement of every literal
/// of `clause` shifted by `n`. It is solved over `2 * n` variables. Returns
/// `true` only when the solver answers `Unsat`.
fn is_clause_inductive(
    sys: &TransitionSystem,
    frames: &FrameSequence,
    clause: &[Lit],
    level: usize,
    n: u32,
    conflict_budget: u64,
) -> bool {
    let mut db = ClauseDb::new();

    // Unshifted side: the frame plus the clause under test.
    add_frame_clauses(&mut db, frames.frame(level), 0);
    db.add_clause(clause.to_vec());

    for step in &sys.transition {
        db.add_clause(step.lits.clone());
    }

    // Shifted side: every literal of the clause is forced false.
    let negated_shifted = clause.iter().map(|&lit| shift_lit(lit.complement(), n));
    for unit in negated_shifted {
        db.add_clause(vec![unit]);
    }

    let var_count = 2 * n;
    let (outcome, _) = solve_watched_budget(db, var_count, conflict_budget);
    matches!(outcome, SolveResult::Unsat)
}
/// Copies every clause of `frame` into `db`, passing each literal through
/// `shift_lit` with `offset`.
fn add_frame_clauses(db: &mut ClauseDb, frame: &Frame, offset: u32) {
    for clause in frame.clauses() {
        let mut moved: Vec<Lit> = Vec::new();
        for &lit in clause.iter() {
            moved.push(shift_lit(lit, offset));
        }
        db.add_clause(moved);
    }
}
/// Adds an encoding of "some property clause is violated" to `db`.
///
/// Property clause `i` gets the auxiliary variable `tseitin_base + i`:
/// * one clause holding the positive literal of every auxiliary variable
///   (at least one of them is true);
/// * for each literal `l` of property clause `i`, the binary clause
///   `(!aux_i, !l)`, with `l` first passed through `shift_lit` using
///   `prop_offset`.
///
/// NOTE(review): when `sys.property` has no clauses, nothing is added to
/// `db` at all, so this encoding places no constraint on the query —
/// confirm that callers never pass an empty property.
fn add_negated_property(db: &mut ClauseDb, sys: &TransitionSystem, prop_offset: u32, tseitin_base: u32) {
    let aux_count = sys.property.len() as u32;

    let mut activation: Vec<Lit> = Vec::new();
    for offset in 0..aux_count {
        activation.push(Lit::pos(tseitin_base + offset));
    }
    if !activation.is_empty() {
        db.add_clause(activation);
    }

    for (idx, prop_clause) in sys.property.iter().enumerate() {
        let aux_var = tseitin_base + idx as u32;
        for &lit in &prop_clause.lits {
            let falsified = shift_lit(lit, prop_offset).complement();
            db.add_clause(vec![Lit::neg(aux_var), falsified]);
        }
    }
}