use crate::literal::{Lit, Var};
#[allow(unused_imports)]
use crate::prelude::*;
/// Strategy for choosing the polarity of a decision variable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PhaseMode {
    /// Reuse the phase most recently recorded for the variable.
    Saved,
    /// Prefer the variable's target phase when it is enabled and confident
    /// enough, otherwise fall back to the saved phase.
    Target,
    /// Pick a phase without consulting saved or target information.
    Random,
}
/// Counters describing how phases have been chosen by the selector.
#[derive(Clone, Debug, Default)]
pub struct TargetPhaseStats {
    /// Number of phase queries answered with the target phase.
    pub target_used: u64,
    /// Number of phase queries answered with the saved phase, including
    /// target-mode queries that fell back to it.
    pub saved_used: u64,
    /// Number of phase queries made in random mode.
    pub random_used: u64,
    /// Number of target updates that left a variable's confidence above 1.0.
    pub target_updates: u64,
}
/// Per-variable phase selector combining saved phases with confidence-weighted
/// target phases.
#[derive(Debug, Clone)]
pub struct TargetPhaseSelector {
    /// Last phase recorded for each variable via `save_phase`.
    saved_phase: Vec<bool>,
    /// Last target phase assigned to each variable via `set_target`.
    target_phase: Vec<bool>,
    /// Whether each variable's target phase is currently enabled.
    use_target: Vec<bool>,
    /// Multiplicative factor applied to every confidence by `decay_confidence`.
    decay: f64,
    /// Accumulated confidence in each variable's target phase.
    confidence: Vec<f64>,
    /// Usage counters exposed through `stats`.
    stats: TargetPhaseStats,
}
impl TargetPhaseSelector {
    /// Confidence that must be exceeded in `set_target` before a variable's
    /// target phase is enabled.
    const ENABLE_THRESHOLD: f64 = 1.0;
    /// Confidence an enabled target must exceed to be used by `get_phase`;
    /// `decay_confidence` disables targets whose confidence drops below it.
    const ACTIVE_THRESHOLD: f64 = 0.5;
    /// Confidence added by `on_conflict_literal`.
    const CONFLICT_BOOST: f64 = 0.5;
    /// Confidence added per literal by `on_learned_clause`.
    const LEARNED_CLAUSE_BOOST: f64 = 0.2;
    /// Learned clauses longer than this are ignored by `on_learned_clause`.
    const MAX_LEARNED_CLAUSE_LEN: usize = 5;

    /// Creates a selector for `num_vars` variables.
    ///
    /// Every saved and target phase starts as `false`, no target is enabled,
    /// and all confidences start at `0.0`. `decay` is the factor that
    /// `decay_confidence` multiplies each confidence by.
    // NOTE(review): `decay` is not validated; a value >= 1.0 means confidence
    // never shrinks.
    pub fn new(num_vars: usize, decay: f64) -> Self {
        Self {
            saved_phase: vec![false; num_vars],
            target_phase: vec![false; num_vars],
            use_target: vec![false; num_vars],
            decay,
            confidence: vec![0.0; num_vars],
            stats: TargetPhaseStats::default(),
        }
    }

    /// Grows or shrinks per-variable storage to `num_vars` entries.
    ///
    /// Newly added variables are initialised exactly as in `new`.
    pub fn resize(&mut self, num_vars: usize) {
        self.saved_phase.resize(num_vars, false);
        self.target_phase.resize(num_vars, false);
        self.use_target.resize(num_vars, false);
        self.confidence.resize(num_vars, 0.0);
    }

    /// Returns the phase-usage counters accumulated so far.
    #[must_use]
    pub fn stats(&self) -> &TargetPhaseStats {
        &self.stats
    }

    /// Records `phase` as the saved phase of `var`.
    ///
    /// # Panics
    /// Panics if `var.index()` is not smaller than the number of tracked variables.
    pub fn save_phase(&mut self, var: Var, phase: bool) {
        self.saved_phase[var.index()] = phase;
    }

    /// Sets the target phase of `var` to `phase` and adds `confidence_boost`
    /// to its confidence.
    ///
    /// When the accumulated confidence exceeds 1.0 the target is enabled and
    /// `target_updates` is incremented. This happens on every call that leaves
    /// the confidence above the threshold, not only on the first.
    ///
    /// # Panics
    /// Panics if `var.index()` is not smaller than the number of tracked variables.
    pub fn set_target(&mut self, var: Var, phase: bool, confidence_boost: f64) {
        let idx = var.index();
        self.target_phase[idx] = phase;
        // NOTE(review): confidence keeps accumulating even when `phase` differs
        // from the previous target; confirm this is intended.
        self.confidence[idx] += confidence_boost;
        if self.confidence[idx] > Self::ENABLE_THRESHOLD {
            self.use_target[idx] = true;
            self.stats.target_updates += 1;
        }
    }

    /// Boosts the target for `lit`'s variable towards `lit`'s sign.
    pub fn on_conflict_literal(&mut self, lit: Lit) {
        self.set_target(lit.var(), lit.sign(), Self::CONFLICT_BOOST);
    }

    /// Boosts targets towards the signs of a learned clause's literals.
    ///
    /// Only short clauses (at most 5 literals) are considered; longer clauses
    /// are ignored.
    pub fn on_learned_clause(&mut self, clause: &[Lit]) {
        if clause.len() > Self::MAX_LEARNED_CLAUSE_LEN {
            return;
        }
        for &lit in clause {
            self.set_target(lit.var(), lit.sign(), Self::LEARNED_CLAUSE_BOOST);
        }
    }

    /// Multiplies every confidence by the decay factor.
    ///
    /// A target whose confidence falls below 0.5 is disabled. It must climb back
    /// above 1.0 through `set_target` before it is used again, rather than being
    /// silently re-used once a small boost lifts it over 0.5.
    pub fn decay_confidence(&mut self) {
        let decay = self.decay;
        for (conf, enabled) in self.confidence.iter_mut().zip(self.use_target.iter_mut()) {
            *conf *= decay;
            if *conf < Self::ACTIVE_THRESHOLD {
                *enabled = false;
            }
        }
    }

    /// Returns the phase to use for `var` under `mode`, updating `stats`.
    ///
    /// In `Target` mode the target phase is used only when it is enabled and
    /// its confidence exceeds 0.5; otherwise the saved phase is returned and
    /// counted as a saved-phase use.
    ///
    /// # Panics
    /// Panics if `var.index()` is not smaller than the number of tracked variables
    /// (for `Saved` and `Target` modes).
    pub fn get_phase(&mut self, var: Var, mode: PhaseMode) -> bool {
        let idx = var.index();
        match mode {
            PhaseMode::Saved => {
                self.stats.saved_used += 1;
                self.saved_phase[idx]
            }
            PhaseMode::Target => {
                if self.use_target[idx] && self.confidence[idx] > Self::ACTIVE_THRESHOLD {
                    self.stats.target_used += 1;
                    self.target_phase[idx]
                } else {
                    self.stats.saved_used += 1;
                    self.saved_phase[idx]
                }
            }
            PhaseMode::Random => {
                self.stats.random_used += 1;
                // Not actually random: even-indexed variables get `true`,
                // odd-indexed ones `false`.
                (idx & 1) == 0
            }
        }
    }

    /// Disables every target and resets all confidences to `0.0`.
    ///
    /// Saved phases, target phases and stats are left untouched.
    pub fn reset_targets(&mut self) {
        self.use_target.fill(false);
        self.confidence.fill(0.0);
    }

    /// Returns the current confidence in `var`'s target phase.
    ///
    /// # Panics
    /// Panics if `var.index()` is not smaller than the number of tracked variables.
    #[must_use]
    pub fn get_confidence(&self, var: Var) -> f64 {
        self.confidence[var.index()]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ten-variable selector with the given decay factor.
    fn selector_with_decay(decay: f64) -> TargetPhaseSelector {
        TargetPhaseSelector::new(10, decay)
    }

    #[test]
    fn test_target_phase_creation() {
        let sel = selector_with_decay(0.95);
        for len in [sel.saved_phase.len(), sel.target_phase.len(), sel.confidence.len()] {
            assert_eq!(len, 10);
        }
    }

    #[test]
    fn test_save_phase() {
        let mut sel = selector_with_decay(0.95);
        let v0 = Var::new(0);
        sel.save_phase(v0, true);
        assert!(sel.saved_phase[v0.index()]);
        assert!(sel.get_phase(v0, PhaseMode::Saved));
    }

    #[test]
    fn test_target_phase() {
        let mut sel = selector_with_decay(0.95);
        let v0 = Var::new(0);
        sel.set_target(v0, true, 2.0);
        assert!(sel.get_phase(v0, PhaseMode::Target));
        assert!(sel.get_confidence(v0) > 1.0);
    }

    #[test]
    fn test_confidence_decay() {
        let mut sel = selector_with_decay(0.5);
        let v0 = Var::new(0);
        sel.set_target(v0, true, 2.0);
        let before = sel.get_confidence(v0);
        sel.decay_confidence();
        let after = sel.get_confidence(v0);
        assert!(after < before);
        assert!((after - before * 0.5).abs() < 0.001);
    }

    #[test]
    fn test_on_conflict_literal() {
        let mut sel = selector_with_decay(0.95);
        let conflict_lit = Lit::pos(Var::new(0));
        sel.on_conflict_literal(conflict_lit);
        assert!(sel.get_confidence(conflict_lit.var()) > 0.0);
    }

    #[test]
    fn test_reset_targets() {
        let mut sel = selector_with_decay(0.95);
        let v0 = Var::new(0);
        sel.set_target(v0, true, 2.0);
        assert!(sel.get_confidence(v0) > 0.0);
        sel.reset_targets();
        assert_eq!(sel.get_confidence(v0), 0.0);
    }

    #[test]
    fn test_stats() {
        let mut sel = selector_with_decay(0.95);
        let v0 = Var::new(0);
        sel.get_phase(v0, PhaseMode::Saved);
        assert_eq!(sel.stats().saved_used, 1);
        sel.get_phase(v0, PhaseMode::Random);
        assert_eq!(sel.stats().random_used, 1);
    }
}