use partial_ref::{partial, PartialRef};
use varisat_formula::{lit::LitIdx, Lit, Var};
use crate::{
context::{parts::*, Context},
decision::make_available,
};
use super::Reason;
/// Current partial assignment of every variable, together with saved phases.
#[derive(Default)]
pub struct Assignment {
    /// Value of each variable, indexed by variable index; `None` marks an unassigned variable.
    assignment: Vec<Option<bool>>,
    /// For each variable, whether it was `true` when `backtrack` last unassigned it.
    /// Newly added variables start out as `false`.
    last_value: Vec<bool>,
}
/// Returns whether `a` and `b` hold the same `Option<bool>` value.
///
/// Gives the same result as `a == b`, but compares the raw one-byte representations.
/// NOTE(review): presumably chosen to avoid branches of the derived `PartialEq` in hot
/// propagation code — confirm with a benchmark before replacing with `a == b`.
pub fn fast_option_eq(a: Option<bool>, b: Option<bool>) -> bool {
    // SAFETY: `transmute` only compiles when `Option<bool>` and `u8` have the same size, so each
    // call reinterprets exactly one initialized byte. A one-byte niche-filled enum has no padding
    // and stores each of `Some(false)`, `Some(true)` and `None` as a single distinct byte value,
    // so byte equality coincides with value equality.
    let (lhs, rhs) = unsafe {
        (
            std::mem::transmute::<Option<bool>, u8>(a),
            std::mem::transmute::<Option<bool>, u8>(b),
        )
    };
    lhs == rhs
}
impl Assignment {
    /// Resizes the per-variable tables to hold `count` variables.
    ///
    /// Newly added variables are unassigned and have a saved value of `false`.
    pub fn set_var_count(&mut self, count: usize) {
        self.last_value.resize(count, false);
        self.assignment.resize(count, None);
    }

    /// Value of every variable, indexed by variable index (`None` = unassigned).
    pub fn assignment(&self) -> &[Option<bool>] {
        &self.assignment[..]
    }

    /// Current value of `var`, or `None` if it is unassigned.
    pub fn var_value(&self, var: Var) -> Option<bool> {
        let index = var.index();
        self.assignment[index]
    }

    /// Saved value of `var` from the last time `backtrack` unassigned it.
    pub fn last_var_value(&self, var: Var) -> bool {
        let index = var.index();
        self.last_value[index]
    }

    /// Truth value of `lit` under the current assignment, or `None` if its variable is unassigned.
    pub fn lit_value(&self, lit: Lit) -> Option<bool> {
        let negated = lit.is_negative();
        // Flipping a bool when the literal is negated is the same as `value ^ negated`.
        self.assignment[lit.index()].map(|value| value != negated)
    }

    /// Returns whether `lit` is currently assigned true.
    pub fn lit_is_true(&self, lit: Lit) -> bool {
        let expected = Some(lit.is_positive());
        fast_option_eq(self.assignment[lit.index()], expected)
    }

    /// Returns whether `lit` is currently assigned false.
    pub fn lit_is_false(&self, lit: Lit) -> bool {
        let expected = Some(lit.is_negative());
        fast_option_eq(self.assignment[lit.index()], expected)
    }

    /// Returns whether the variable of `lit` is currently unassigned.
    pub fn lit_is_unk(&self, lit: Lit) -> bool {
        let current = self.assignment[lit.index()];
        fast_option_eq(current, None)
    }

    /// Assigns the variable of `lit` so that `lit` becomes true.
    pub fn assign_lit(&mut self, lit: Lit) {
        let index = lit.index();
        self.assignment[index] = Some(lit.is_positive());
    }

    /// Marks `var` as unassigned without touching its saved value.
    pub fn unassign_var(&mut self, var: Var) {
        let index = var.index();
        self.assignment[index] = None;
    }

    /// Overwrites the value of `var` with `assignment`.
    pub fn set_var(&mut self, var: Var, assignment: Option<bool>) {
        let index = var.index();
        self.assignment[index] = assignment;
    }
}
/// Assigned literals in assignment order, with decision-level boundaries and a propagation queue.
#[derive(Default)]
pub struct Trail {
    /// Every assigned literal, in the order it was enqueued.
    trail: Vec<Lit>,
    /// Position in `trail` of the next literal to hand out from the propagation queue.
    queue_head_pos: usize,
    /// For each decision level, the length `trail` had when that level was started.
    decisions: Vec<LitIdx>,
    /// Number of top-level literals discarded from `trail` by `clear`.
    units_removed: usize,
}
impl Trail {
    /// Next literal waiting to be propagated, or `None` if every trail literal was dequeued.
    pub fn queue_head(&self) -> Option<Lit> {
        // `Lit` is `Copy`, so `copied` states the intent more precisely than `cloned`.
        self.trail.get(self.queue_head_pos).copied()
    }

    /// Removes and returns the next literal waiting to be propagated, if any.
    pub fn pop_queue(&mut self) -> Option<Lit> {
        let head = self.queue_head();
        if head.is_some() {
            self.queue_head_pos += 1;
        }
        head
    }

    /// Rewinds the propagation queue so the whole trail will be handed out again.
    pub fn reset_queue(&mut self) {
        self.queue_head_pos = 0;
    }

    /// All assigned literals, in assignment order.
    pub fn trail(&self) -> &[Lit] {
        &self.trail
    }

    /// Discards all (top-level) literals from the trail, counting them in `units_removed`.
    ///
    /// # Panics
    ///
    /// Panics if any decision level is open; only level-0 assignments may be cleared.
    pub fn clear(&mut self) {
        assert!(self.decisions.is_empty());
        self.units_removed += self.trail.len();
        self.trail.clear();
        self.queue_head_pos = 0;
    }

    /// Opens a new decision level that starts at the current end of the trail.
    pub fn new_decision_level(&mut self) {
        // NOTE(review): `as LitIdx` truncates if the trail ever exceeds `LitIdx::MAX`; this
        // presumably relies on the variable count fitting in `LitIdx` — confirm.
        self.decisions.push(self.trail.len() as LitIdx)
    }

    /// Current decision level (0 means no decision has been made).
    pub fn current_level(&self) -> usize {
        self.decisions.len()
    }

    /// Number of level-0 assignments: those still on the trail plus those removed by `clear`.
    pub fn top_level_assignment_count(&self) -> usize {
        // Level 0 ends where the first decision level begins; with no decisions the whole
        // trail belongs to level 0.
        let on_trail = self
            .decisions
            .first()
            .map_or(self.trail.len(), |&len| len as usize);
        on_trail + self.units_removed
    }

    /// Returns whether every literal on the trail has been handed out for propagation.
    pub fn fully_propagated(&self) -> bool {
        self.queue_head_pos == self.trail.len()
    }
}
/// Assigns `lit` to true, appends it to the trail and records its implication-graph node.
///
/// The node for `lit`'s variable receives `reason`, the current decision level and the trail
/// length after pushing `lit` (its depth). In debug builds, asserts that `lit` was unassigned.
pub fn enqueue_assignment(
    mut ctx: partial!(Context, mut AssignmentP, mut ImplGraphP, mut TrailP),
    lit: Lit,
    reason: Reason,
) {
    {
        let assignment = ctx.part_mut(AssignmentP);
        debug_assert!(assignment.lit_value(lit).is_none());
        assignment.assign_lit(lit);
    }

    let (trail, mut ctx) = ctx.split_part_mut(TrailP);
    trail.trail.push(lit);

    let level = trail.decisions.len() as LitIdx;
    let depth = trail.trail.len() as LitIdx;

    let node = &mut ctx.part_mut(ImplGraphP).nodes[lit.index()];
    node.reason = reason;
    node.level = level;
    node.depth = depth;
}
/// Undoes every assignment made at decision levels above `level`.
///
/// Requests for a level at or above the current decision level are a no-op. Each removed
/// literal's variable is passed to `make_available`, its value is saved in `last_value`
/// (`true` only if it was assigned `true`), and it is unassigned. The propagation queue head
/// is moved to the new end of the trail.
pub fn backtrack(
    mut ctx: partial!(Context, mut AssignmentP, mut TrailP, mut VsidsP),
    level: usize,
) {
    let (assignment, mut ctx) = ctx.split_part_mut(AssignmentP);
    let (trail, mut ctx) = ctx.split_part_mut(TrailP);

    let new_trail_len = match trail.decisions.get(level) {
        Some(&start) => start as usize,
        None => return,
    };

    trail.queue_head_pos = new_trail_len;
    trail.decisions.truncate(level);

    // Draining visits the removed suffix in trail order and leaves the trail truncated.
    for lit in trail.trail.drain(new_trail_len..) {
        make_available(ctx.borrow(), lit.var());
        let index = lit.index();
        assignment.last_value[index] = assignment.assignment[index] == Some(true);
        assignment.assignment[index] = None;
    }
}
/// Restarts all the way to decision level 0.
///
/// Calls `full_restart` on the assumptions part first, then backtracks to level 0.
pub fn full_restart(
    mut ctx: partial!(
        Context,
        mut AssignmentP,
        mut AssumptionsP,
        mut TrailP,
        mut VsidsP,
    ),
) {
    let assumptions = ctx.part_mut(AssumptionsP);
    assumptions.full_restart();

    let top_level = 0;
    backtrack(ctx.borrow(), top_level);
}
/// Restarts search while keeping the decision levels reported by `assumption_levels`.
///
/// Backtracks to the level returned by the assumptions part; no-op if that is not below the
/// current decision level.
pub fn restart(
    mut ctx: partial!(
        Context,
        mut AssignmentP,
        mut TrailP,
        mut VsidsP,
        AssumptionsP
    ),
) {
    // Read the target level before taking the mutable borrow that `backtrack` needs.
    let assumption_level = ctx.part(AssumptionsP).assumption_levels();
    backtrack(ctx.borrow(), assumption_level);
}