use std::mem::swap;
use partial_ref::{partial, split_borrow, PartialRef};
use vec_mut_scan::VecMutScan;
use varisat_formula::{lit::LitIdx, Lit, Var};
use varisat_internal_proof::{clause_hash, lit_hash, ClauseHash};
use crate::{
clause::ClauseRef,
context::{parts::*, Context},
prop::{Conflict, Reason},
};
/// State and results of conflict analysis.
///
/// The buffers are cleared at the start of each analysis rather than reallocated, so their
/// capacity is reused between conflicts.
#[derive(Default)]
pub struct AnalyzeConflict {
    /// Literals of the clause being derived. After an analysis at a decision level above 0,
    /// the first entry is the single literal of the current decision level.
    clause: Vec<Lit>,
    /// Number of flagged current-level variables that still have to be resolved away.
    current_level_count: usize,
    /// Per-variable marks indexed by variable index: set for variables of the clause under
    /// construction and for variables found redundant during minimization.
    var_flags: Vec<bool>,
    /// Variables whose `var_flags` entry has to be reset once the analysis is finished.
    to_clean: Vec<Var>,
    /// Long clauses used as the conflict or as reasons during the analysis.
    involved: Vec<ClauseRef>,
    /// Hashes of the clauses used, collected only when the proof requires them.
    clause_hashes: Vec<ClauseHash>,
    /// `(depth, hash)` pairs of reason clauses used by minimization; sorted and merged into
    /// `clause_hashes` afterwards.
    unordered_clause_hashes: Vec<(LitIdx, ClauseHash)>,
    /// Depth-first search stack used by clause minimization.
    stack: Vec<Lit>,
}

impl AnalyzeConflict {
    /// Resize the per-variable flag storage to `count` variables; new entries start unset.
    pub fn set_var_count(&mut self, count: usize) {
        self.var_flags.resize_with(count, bool::default);
    }

    /// The clause derived by the most recent analysis.
    pub fn clause(&self) -> &[Lit] {
        self.clause.as_slice()
    }

    /// Long clauses that took part in the most recent analysis.
    pub fn involved(&self) -> &[ClauseRef] {
        self.involved.as_slice()
    }

    /// Clause hashes recorded by the most recent analysis (empty unless the proof needs them).
    pub fn clause_hashes(&self) -> &[ClauseHash] {
        self.clause_hashes.as_slice()
    }
}
/// Analyze a conflict and derive a learned clause (first-UIP scheme followed by clause
/// minimization).
///
/// The derived clause is stored in the `AnalyzeConflictP` part. When the conflict happens at
/// a decision level above 0:
/// * the clause's first literal is the negation of the last current-level literal met on the
///   trail, i.e. the only literal of the clause at the current decision level;
/// * if the clause has more than one literal, its second literal is one with the highest
///   decision level among the remaining literals.
///
/// Each variable above level 0 that `add_literal` newly flags (from the conflict or from a
/// reason resolved on) is bumped in VSIDS. VSIDS is decayed once at the end. Long clauses used
/// are recorded in `involved`. If the proof requires clause hashes, the hashes of the clauses
/// used are recorded in `clause_hashes`.
///
/// Returns the decision level to backtrack to. This is the level of the clause's second
/// literal, or 0 if the clause has at most one literal.
///
/// If the conflict happens at decision level 0, the function returns 0 right away. The clause
/// is left empty (only the conflict's hash may have been recorded) and VSIDS is not touched.
pub fn analyze_conflict<'a>(
    mut ctx: partial!(
        Context<'a>,
        mut AnalyzeConflictP,
        mut VsidsP,
        ClauseAllocP,
        ImplGraphP,
        ProofP<'a>,
        TrailP,
    ),
    conflict: Conflict,
) -> usize {
    // Separate shared borrow of the clause allocator. It is used below to obtain the literals
    // of the conflict and of reasons while other parts of `ctx` are mutated.
    split_borrow!(lit_ctx = &(ClauseAllocP) ctx);

    // Reset the results of the previous analysis. `clear` keeps the buffers' capacity.
    {
        let analyze = ctx.part_mut(AnalyzeConflictP);
        analyze.clause.clear();
        analyze.involved.clear();
        analyze.clause_hashes.clear();
        analyze.unordered_clause_hashes.clear();
        analyze.current_level_count = 0;
    }

    let conflict_lits = conflict.lits(&lit_ctx);

    // The conflict clause's hash is pushed first. It ends up last after the final `reverse`.
    if ctx.part(ProofP).clause_hashes_required() {
        ctx.part_mut(AnalyzeConflictP)
            .clause_hashes
            .push(clause_hash(conflict_lits));
    }

    // At decision level 0 the analysis stops here, with an empty clause.
    if ctx.part(TrailP).current_level() == 0 {
        return 0;
    }

    // Lower-level literals go straight into the clause. Current-level literals are only
    // flagged and counted.
    for &lit in conflict_lits {
        add_literal(ctx.borrow(), lit);
    }

    if let Conflict::Long(cref) = conflict {
        ctx.part_mut(AnalyzeConflictP).involved.push(cref);
    }

    // Borrow the trail separately so it can be iterated while the rest of `ctx` is mutated.
    split_borrow!(ctx_trail = &(TrailP) ctx);

    // Walk the trail from the most recent assignment backwards. Each flagged variable met here
    // is expected to be a current-level literal still to be resolved away; this assumes
    // current-level assignments are the most recent trail entries.
    for &lit in ctx_trail.part(TrailP).trail().iter().rev() {
        let analyze = ctx.part_mut(AnalyzeConflictP);
        let lit_present = &mut analyze.var_flags[lit.index()];
        if *lit_present {
            // Clearing the flag here is how current-level flags get reset; they are never put
            // into `to_clean`.
            *lit_present = false;
            analyze.current_level_count -= 1;
            if analyze.current_level_count == 0 {
                // `lit` is the last current-level literal left (the first UIP). Its negation
                // becomes the asserting literal and is moved to the front of the clause.
                analyze.clause.push(!lit);
                let end = analyze.clause.len() - 1;
                analyze.clause.swap(0, end);
                break;
            } else {
                // Resolve on `lit`: replace it by the literals of its reason.
                let (graph, mut ctx) = ctx.split_part(ImplGraphP);
                let reason = graph.reason(lit.var());
                let lits = reason.lits(&lit_ctx);
                if ctx.part(ProofP).clause_hashes_required() && !reason.is_unit() {
                    // NOTE(review): presumably yields the hash of the whole reason clause
                    // (its other literals plus the propagated `lit`). This assumes
                    // `clause_hash` combines literal hashes by XOR; confirm in
                    // varisat_internal_proof.
                    let hash = clause_hash(lits) ^ lit_hash(lit);
                    ctx.part_mut(AnalyzeConflictP).clause_hashes.push(hash);
                }
                for &lit in lits {
                    add_literal(ctx.borrow(), lit);
                }
                if let Reason::Long(cref) = reason {
                    ctx.part_mut(AnalyzeConflictP).involved.push(*cref);
                }
            }
        }
    }

    minimize_clause(ctx.borrow());

    let (analyze, mut ctx) = ctx.split_part_mut(AnalyzeConflictP);

    if ctx.part(ProofP).clause_hashes_required() {
        // Order minimization hashes by decreasing depth. After sorting, entries with equal
        // depth are adjacent, so `dedup_by_key` keeps one entry per depth.
        analyze
            .unordered_clause_hashes
            .sort_unstable_by_key(|&(depth, _)| !depth);
        analyze
            .unordered_clause_hashes
            .dedup_by_key(|&mut (depth, _)| depth);
        analyze.clause_hashes.extend(
            analyze
                .unordered_clause_hashes
                .iter()
                .map(|&(_, hash)| hash),
        );
        // After reversing, the order is: minimization hashes by increasing depth, then the
        // resolution-step hashes in trail order, then the conflict clause's hash last.
        analyze.clause_hashes.reverse();
    }

    // Reset the flags of lower-level clause variables and of variables marked during
    // minimization.
    for var in analyze.to_clean.drain(..) {
        analyze.var_flags[var.index()] = false;
    }

    // Find the highest decision level among clause[1..]. Move a literal of that level to
    // position 1; that level is the backtrack target.
    let mut backtrack_to = 0;

    if analyze.clause.len() > 1 {
        let (prefix, rest) = analyze.clause.split_at_mut(2);
        let lit_1 = &mut prefix[1];
        backtrack_to = ctx.part(ImplGraphP).level(lit_1.var());
        for lit in rest.iter_mut() {
            let lit_level = ctx.part(ImplGraphP).level(lit.var());
            if lit_level > backtrack_to {
                backtrack_to = lit_level;
                swap(lit_1, lit);
            }
        }
    }

    ctx.part_mut(VsidsP).decay();

    backtrack_to
}
/// Add `lit`, a literal of the conflict or of a reason being resolved, to the clause under
/// construction.
///
/// Literals assigned at decision level 0, and variables that are already flagged, are
/// ignored. Otherwise the variable is bumped in VSIDS and flagged. What happens next depends
/// on the literal's level:
/// * current decision level: the literal is only counted in `current_level_count`, because
///   it is resolved away later;
/// * lower level: the literal is appended to the clause and its variable is queued in
///   `to_clean` so the flag can be reset afterwards.
fn add_literal(
    mut ctx: partial!(
        Context,
        mut AnalyzeConflictP,
        mut VsidsP,
        ImplGraphP,
        TrailP
    ),
    lit: Lit,
) {
    let (analyze, mut ctx) = ctx.split_part_mut(AnalyzeConflictP);
    let level = ctx.part(ImplGraphP).level(lit.var());

    // Level-0 assignments never enter the clause, and a variable is only processed once.
    if level == 0 || analyze.var_flags[lit.index()] {
        return;
    }

    ctx.part_mut(VsidsP).bump(lit.var());
    analyze.var_flags[lit.index()] = true;

    let at_current_level = level == ctx.part(TrailP).current_level();
    if !at_current_level {
        analyze.clause.push(lit);
        analyze.to_clean.push(lit.var());
        return;
    }

    analyze.current_level_count += 1;
}
/// Lossy set of decision levels packed into a single `u64`.
///
/// Level `l` is represented by bit `l % 64`. `test` therefore never returns `false` for a
/// level that was added. It may return `true` for a level that was never added, when that
/// level shares a bit with an added one.
#[derive(Default)]
struct LevelAbstraction {
    bits: u64,
}

impl LevelAbstraction {
    /// The single bit standing for `level`; levels congruent modulo 64 share it.
    fn mask(level: usize) -> u64 {
        1u64 << (level & 63)
    }

    /// Record `level` as a member of the set.
    pub fn add(&mut self, level: usize) {
        self.bits |= Self::mask(level);
    }

    /// Whether `level` may be a member. Returns `true` for every added level and possibly
    /// for other levels that share its bit.
    pub fn test(&self, level: usize) -> bool {
        (self.bits & Self::mask(level)) != 0
    }
}
/// Remove redundant literals from the clause in `AnalyzeConflictP` (recursive clause
/// minimization).
///
/// The first literal is always kept, and so is any literal whose variable has a
/// `Reason::Unit` reason. Every other literal `lit` is tested with a depth-first search over
/// the reasons, starting from `!lit`. `lit` is removed if the search only reaches literals
/// that are assigned at level 0 or whose variable is already flagged in `var_flags`. Flagged
/// variables are the lower-level clause literals and variables marked by an earlier
/// successful search. The search aborts, keeping `lit`, as soon as it meets an unflagged
/// literal above level 0 that either:
/// * has a `Reason::Unit` reason, or
/// * lies at a level rejected by the `LevelAbstraction` of the clause's levels.
///
/// Variables flagged by a successful search stay flagged and are recorded in `to_clean` for
/// the caller to reset. Flags and `unordered_clause_hashes` entries added by an aborted search
/// are rolled back. When the proof requires clause hashes, each visited non-unit reason
/// contributes a `(depth, hash)` pair to `unordered_clause_hashes`.
fn minimize_clause<'a>(
    mut ctx: partial!(
        Context<'a>,
        mut AnalyzeConflictP,
        mut VsidsP,
        ClauseAllocP,
        ImplGraphP,
        ProofP<'a>,
        TrailP,
    ),
) {
    let (analyze, mut ctx) = ctx.split_part_mut(AnalyzeConflictP);
    split_borrow!(lit_ctx = &(ClauseAllocP) ctx);
    let impl_graph = ctx.part(ImplGraphP);

    // Summary of the decision levels present in the clause. The search gives up on reason
    // literals whose level is not in this set.
    let mut involved_levels = LevelAbstraction::default();

    for &lit in analyze.clause.iter() {
        involved_levels.add(impl_graph.level(lit.var()));
    }

    // `VecMutScan` allows removing literals in place while iterating. Dropping an item
    // without calling `remove` keeps it.
    let mut scan = VecMutScan::new(&mut analyze.clause);

    // The first literal is always kept.
    scan.next();

    'next_lit: while let Some(lit) = scan.next() {
        // Literals with a `Reason::Unit` reason are kept.
        if impl_graph.reason(lit.var()) == &Reason::Unit {
            continue;
        }

        // Start the DFS from `!lit`, whose reason is inspected first.
        analyze.stack.clear();
        analyze.stack.push(!*lit);

        // Rollback points: flags set and hashes pushed beyond these are undone on abort.
        let top = analyze.to_clean.len();
        let hashes_top = analyze.unordered_clause_hashes.len();

        while let Some(lit) = analyze.stack.pop() {
            let reason = impl_graph.reason(lit.var());
            let lits = reason.lits(&lit_ctx);

            if ctx.part(ProofP).clause_hashes_required() && !reason.is_unit() {
                let depth = impl_graph.depth(lit.var()) as LitIdx;
                let hash = clause_hash(lits) ^ lit_hash(lit);
                analyze.unordered_clause_hashes.push((depth, hash));
            }

            for &reason_lit in lits {
                let reason_level = impl_graph.level(reason_lit.var());

                // Only unflagged literals above level 0 need further exploration.
                if !analyze.var_flags[reason_lit.index()] && reason_level > 0 {
                    if impl_graph.reason(reason_lit.var()) == &Reason::Unit
                        || !involved_levels.test(reason_level)
                    {
                        // Abort: undo only what this search changed, then keep `lit`.
                        for lit in analyze.to_clean.drain(top..) {
                            analyze.var_flags[lit.index()] = false;
                        }
                        analyze.unordered_clause_hashes.truncate(hashes_top);
                        continue 'next_lit;
                    } else {
                        // Mark the variable as visited and explore its reason later.
                        analyze.var_flags[reason_lit.index()] = true;
                        analyze.to_clean.push(reason_lit.var());
                        analyze.stack.push(!reason_lit);
                    }
                }
            }
        }

        // Every path ended in a flagged or level-0 literal, so `lit` is redundant.
        lit.remove();
    }
}