use super::solver_options::BranchingHeuristic;
use crate::cdcl::clause_db::ClauseIndex;
use solhop_types::{LBool, Lit, Var};
/// Heuristic-specific bookkeeping used to pick the next branching variable.
///
/// Every `Vec` field is indexed by `Var::index()` and receives one entry per
/// variable registered through `VarManager::new_var`.
enum InternalBranchStats {
    /// Activity-based scoring.
    Vsids {
        /// Reciprocal of the configured decay factor; `var_inc` is multiplied
        /// by this after each recorded learnt clause.
        var_decay: f64,
        /// Amount added to a variable's activity each time it is bumped.
        var_inc: f64,
        /// Per-variable activity; the unassigned variable with the largest
        /// value is chosen for branching.
        activity: Vec<f64>,
    },
    /// Learning-rate-based scoring.
    Lrb {
        /// Number of conflict analyses performed so far.
        learnt_counter: usize,
        /// EMA step size: starts at 0.4 and shrinks by 1e-6 per conflict
        /// analysis while it is still above 0.06.
        alpha: f64,
        /// Value of `learnt_counter` when each variable was last assigned.
        assigned: Vec<usize>,
        /// How often each variable was reported as participating since its
        /// last assignment.
        participated: Vec<usize>,
        /// How often each variable was reported in the reasoned set since its
        /// last assignment.
        reasoned: Vec<usize>,
        /// Per-variable exponential moving average of the reward; used as the
        /// branching score.
        ema: Vec<f64>,
    },
}
/// Per-variable assignment state (all vectors indexed by `Var::index()`)
/// together with the statistics of the selected branching heuristic.
pub struct VarManager {
    /// Current value of each variable; `LBool::Undef` while unassigned.
    assigns: Vec<LBool>,
    /// Level passed with the most recent assignment; -1 while unassigned.
    level: Vec<i32>,
    /// Clause index recorded with the most recent assignment, if any;
    /// `None` for fresh or reset variables.
    reason: Vec<Option<ClauseIndex>>,
    /// Heuristic-specific statistics chosen at construction time.
    stats: InternalBranchStats,
}
impl VarManager {
    /// Initial LRB step size (`alpha`).
    const LRB_ALPHA_INITIAL: f64 = 0.4;
    /// The LRB step size stops shrinking once it is no longer above this value.
    const LRB_ALPHA_MIN: f64 = 0.06;
    /// Amount subtracted from the LRB step size after each conflict analysis.
    const LRB_ALPHA_STEP: f64 = 1e-6;
    /// Factor applied to the EMA of every unassigned variable after each conflict analysis.
    const LRB_UNASSIGNED_DECAY: f64 = 0.95;
    /// VSIDS activity above which all activities and the increment are rescaled.
    const VSIDS_RESCALE_LIMIT: f64 = 1e100;
    /// Factor by which VSIDS activities and the increment are rescaled.
    const VSIDS_RESCALE_FACTOR: f64 = 1e-100;

    /// Creates an empty manager (no variables) that scores variables with `bh`.
    pub fn new(bh: BranchingHeuristic) -> Self {
        let stats = match bh {
            BranchingHeuristic::Vsids { var_inc, var_decay } => InternalBranchStats::Vsids {
                activity: vec![],
                var_inc,
                // Stored as the reciprocal: `after_record_learnt_clause` multiplies
                // `var_inc` by it, so (for a configured decay below 1) later bumps
                // weigh more than earlier ones.
                var_decay: 1.0 / var_decay,
            },
            BranchingHeuristic::Lrb => InternalBranchStats::Lrb {
                alpha: Self::LRB_ALPHA_INITIAL,
                learnt_counter: 0,
                ema: vec![],
                assigned: vec![],
                participated: vec![],
                reasoned: vec![],
            },
        };
        VarManager {
            assigns: vec![],
            reason: vec![],
            level: vec![],
            stats,
        }
    }

    /// Number of variables registered so far.
    pub fn n_vars(&self) -> usize {
        self.assigns.len()
    }

    /// Registers a new variable and returns it.
    ///
    /// Variables are numbered consecutively from 0. The new variable starts
    /// unassigned (`LBool::Undef`, level -1, no reason) with zeroed heuristic
    /// statistics.
    pub fn new_var(&mut self) -> Var {
        let v = Var::new(self.n_vars());
        self.reason.push(None);
        self.assigns.push(LBool::Undef);
        self.level.push(-1);
        match &mut self.stats {
            InternalBranchStats::Vsids { activity, .. } => {
                activity.push(0.0);
            }
            InternalBranchStats::Lrb {
                ema,
                assigned,
                participated,
                reasoned,
                ..
            } => {
                ema.push(0.0);
                assigned.push(0);
                participated.push(0);
                reasoned.push(0);
            }
        }
        v
    }

    /// Current value of variable `x`.
    pub fn value(&self, x: Var) -> LBool {
        self.assigns[x.index()]
    }

    /// Current value of literal `p`: the negation of its variable's value when
    /// `p.sign()` is true, the variable's value otherwise.
    pub fn value_lit(&self, p: Lit) -> LBool {
        let v = self.value(p.var());
        if p.sign() {
            !v
        } else {
            v
        }
    }

    /// Records the outcome of one conflict analysis for LRB (no-op under VSIDS).
    ///
    /// Increments the conflict counter, bumps `participated` once per entry of
    /// `participating_variables` (duplicates count repeatedly) and `reasoned`
    /// once per member of `reasoned_variables`, shrinks the step size, and
    /// decays the EMA of every currently unassigned variable.
    pub fn after_conflict_analysis(
        &mut self,
        participating_variables: Vec<Var>,
        reasoned_variables: std::collections::HashSet<Var>,
    ) {
        match &mut self.stats {
            InternalBranchStats::Vsids { .. } => {}
            InternalBranchStats::Lrb {
                alpha,
                learnt_counter,
                ema,
                participated,
                reasoned,
                ..
            } => {
                *learnt_counter += 1;
                for v in participating_variables {
                    participated[v.index()] += 1;
                }
                if *alpha > Self::LRB_ALPHA_MIN {
                    *alpha -= Self::LRB_ALPHA_STEP;
                }
                for v in reasoned_variables {
                    reasoned[v.index()] += 1;
                }
                // `ema` and `assigns` always have one entry per variable (see `new_var`).
                for (score, value) in ema.iter_mut().zip(&self.assigns) {
                    if *value == LBool::Undef {
                        *score *= Self::LRB_UNASSIGNED_DECAY;
                    }
                }
            }
        }
    }

    /// Returns the unassigned variable with the highest score.
    ///
    /// The score is the VSIDS activity or the LRB EMA, depending on the
    /// heuristic. Ties go to the highest-indexed candidate.
    ///
    /// # Panics
    ///
    /// Panics if no variable is unassigned (including when there are no
    /// variables), or if a candidate's score is NaN.
    pub fn select_var(&self) -> Var {
        let scores: &[f64] = match &self.stats {
            InternalBranchStats::Vsids { activity, .. } => activity.as_slice(),
            InternalBranchStats::Lrb { ema, .. } => ema.as_slice(),
        };
        let best = (0..self.n_vars())
            .filter(|&i| self.value(Var::new(i)) == LBool::Undef)
            // `max_by` returns the last of several equal maxima.
            .max_by(|&x, &y| {
                scores[x]
                    .partial_cmp(&scores[y])
                    .expect("branching scores must not be NaN")
            })
            .expect("select_var called with no unassigned variable");
        Var::new(best)
    }

    /// Bumps the VSIDS activity of every variable in `ps` by the current
    /// increment (no-op under LRB).
    ///
    /// When an activity exceeds 1e100, every activity and the increment are
    /// multiplied by 1e-100.
    pub fn after_learnt_clause(&mut self, ps: &[Lit]) {
        match &mut self.stats {
            InternalBranchStats::Vsids {
                activity, var_inc, ..
            } => {
                for p in ps {
                    let x = p.var().index();
                    activity[x] += *var_inc;
                    if activity[x] > Self::VSIDS_RESCALE_LIMIT {
                        for act in activity.iter_mut() {
                            *act *= Self::VSIDS_RESCALE_FACTOR;
                        }
                        *var_inc *= Self::VSIDS_RESCALE_FACTOR;
                    }
                }
            }
            InternalBranchStats::Lrb { .. } => {}
        }
    }

    /// Grows the VSIDS increment by the stored (reciprocal) decay factor
    /// (no-op under LRB).
    pub fn after_record_learnt_clause(&mut self) {
        match &mut self.stats {
            InternalBranchStats::Vsids {
                var_inc, var_decay, ..
            } => {
                *var_inc *= *var_decay;
            }
            InternalBranchStats::Lrb { .. } => {}
        }
    }

    /// Clause index recorded with `var`'s current assignment, if any.
    pub fn get_reason(&self, var: Var) -> Option<ClauseIndex> {
        self.reason[var.index()]
    }

    /// Sets `var`'s value, level and reason, updating LRB statistics.
    ///
    /// Under LRB, assigning a value (anything but `LBool::Undef`) records the
    /// current conflict count and clears the participation and reason
    /// counters. Unassigning a variable that was assigned folds its reward
    /// `(participated + reasoned) / interval` into its EMA, where `interval`
    /// is the number of conflict analyses since the assignment. An
    /// unassignment of an already-unassigned variable leaves the EMA alone.
    /// Without this guard, the stale counters would be credited a second time.
    pub fn update(&mut self, var: Var, value: LBool, level: i32, reason: Option<ClauseIndex>) {
        let i = var.index();
        let was_assigned = self.assigns[i] != LBool::Undef;
        match &mut self.stats {
            InternalBranchStats::Vsids { .. } => {}
            InternalBranchStats::Lrb {
                alpha,
                learnt_counter,
                ema,
                assigned,
                participated,
                reasoned,
            } => {
                if value != LBool::Undef {
                    assigned[i] = *learnt_counter;
                    participated[i] = 0;
                    reasoned[i] = 0;
                } else if was_assigned {
                    // `assigned[i]` was taken from the monotonically increasing
                    // counter, so this subtraction cannot underflow.
                    let interval = *learnt_counter - assigned[i];
                    if interval > 0 {
                        let interval = interval as f64;
                        let participation_rate = participated[i] as f64 / interval;
                        let reason_side_rate = reasoned[i] as f64 / interval;
                        let prev_ema = ema[i];
                        ema[i] = (1.0 - *alpha) * prev_ema
                            + *alpha * (participation_rate + reason_side_rate);
                    }
                }
            }
        }
        self.assigns[i] = value;
        self.level[i] = level;
        self.reason[i] = reason;
    }

    /// Unassigns `var`: value `LBool::Undef`, level -1, no reason.
    pub fn reset(&mut self, var: Var) {
        self.update(var, LBool::Undef, -1, None);
    }

    /// One boolean per variable: `true` exactly when it is assigned `LBool::True`.
    pub fn model(&self) -> Vec<bool> {
        self.assigns.iter().map(|&x| x == LBool::True).collect()
    }

    /// Level recorded with `var`'s current assignment (-1 while unassigned).
    pub fn get_level(&self, var: Var) -> i32 {
        self.level[var.index()]
    }
}