use std::fmt::Debug;
use scirs2_core::num_traits::{Float, FromPrimitive};
use crate::error::{OptimizeError, OptimizeResult};
/// Tunable parameters for the CDCL-style branching state.
#[derive(Clone, Debug)]
pub struct CdclConfig {
    /// Maximum number of learned clauses kept; `record_conflict` drops the
    /// oldest clauses once the store grows beyond this size.
    pub max_clauses: usize,
    /// Limit on clauses learned per conflict.
    /// NOTE(review): no code visible in this file reads this field — confirm usage.
    pub max_learned_per_conflict: usize,
    /// Multiplicative factor applied to every variable activity by
    /// `decay_activities`.
    pub decay: f64,
    /// Amount added to the activity of each variable appearing in a conflict.
    pub activity_bump: f64,
    /// Clauses whose activity is below this value are removed by
    /// `prune_inactive_clauses`.
    pub min_activity_threshold: f64,
}

impl Default for CdclConfig {
    /// Returns the stock configuration: 10 000 clauses, 3 learned clauses per
    /// conflict, decay `0.95`, bump `1.0`, pruning threshold `1e-12`.
    fn default() -> Self {
        CdclConfig {
            min_activity_threshold: 1e-12,
            activity_bump: 1.0,
            decay: 0.95,
            max_learned_per_conflict: 3,
            max_clauses: 10_000,
        }
    }
}
/// One branching decision: a variable index paired with the integer value
/// chosen for it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BranchingDecision {
    /// Index of the variable the decision refers to.
    pub var_index: usize,
    /// Integer value assigned by the decision (the tests in this file use 0 and 1).
    pub value: i32,
}

impl BranchingDecision {
    /// Builds a decision assigning `value` to the variable at `var_index`.
    pub fn new(var_index: usize, value: i32) -> Self {
        BranchingDecision { var_index, value }
    }
}
/// A clause learned from a conflict: a list of `(variable index, value)`
/// literals plus an activity score.
#[derive(Clone, Debug)]
pub struct LearnedClause {
    /// The `(var_index, value)` pairs making up the clause.
    pub literals: Vec<(usize, i32)>,
    /// Activity score of the clause; starts at `1.0`.
    pub activity: f64,
}

impl LearnedClause {
    /// Wraps `literals` in a clause with the initial activity of `1.0`.
    pub fn new(literals: Vec<(usize, i32)>) -> Self {
        let activity = 1.0;
        LearnedClause { literals, activity }
    }

    /// Number of literals in the clause.
    pub fn len(&self) -> usize {
        let Self { literals, .. } = self;
        literals.len()
    }

    /// `true` when the clause holds no literals.
    pub fn is_empty(&self) -> bool {
        let Self { literals, .. } = self;
        literals.is_empty()
    }

    /// Returns `true` when every literal of `self` also appears in `other`.
    ///
    /// An empty clause therefore subsumes every clause.
    pub fn subsumes(&self, other: &LearnedClause) -> bool {
        for literal in &self.literals {
            if !other.literals.contains(literal) {
                return false;
            }
        }
        true
    }
}
/// Mutable state of the CDCL-style branching heuristic, generic over the
/// floating-point type `F` used for variable activities (defaults to `f64`).
#[derive(Clone, Debug)]
pub struct CdclBranchingState<F = f64> {
    /// Number of variables tracked; `new` rejects zero.
    pub n_vars: usize,
    /// Stack of decisions, manipulated by `push_decision` / `pop_decision`.
    pub decisions: Vec<BranchingDecision>,
    /// Clauses accumulated by `record_conflict`.
    pub learned_clauses: Vec<LearnedClause>,
    /// Per-variable activity scores; `new` allocates one zeroed entry per variable.
    pub activity: Vec<F>,
    /// Parameters controlling bumping, decay and clause-store limits.
    pub config: CdclConfig,
}
impl<F> CdclBranchingState<F>
where
    F: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::MulAssign,
{
    /// Creates an empty state for `n_vars` variables with all activities at zero.
    ///
    /// # Errors
    ///
    /// Returns [`OptimizeError::InvalidInput`] when `n_vars` is zero.
    pub fn new(n_vars: usize, config: CdclConfig) -> OptimizeResult<Self> {
        if n_vars == 0 {
            return Err(OptimizeError::InvalidInput(
                "n_vars must be positive".into(),
            ));
        }
        Ok(Self {
            n_vars,
            decisions: Vec::new(),
            learned_clauses: Vec::new(),
            activity: vec![F::zero(); n_vars],
            config,
        })
    }

    /// Picks the fractional variable with the highest activity.
    ///
    /// A value counts as fractional when it lies strictly inside
    /// `(1e-6, 1 - 1e-6)`. Ties are resolved in favour of the lowest index.
    /// Returns `None` when `lp_solution.len() != n_vars` or when no entry is
    /// fractional.
    ///
    /// NOTE(review): the fractionality test only recognises values inside the
    /// unit interval — presumably variables are 0/1; confirm against callers.
    pub fn select_branching_var(&self, lp_solution: &[F]) -> Option<usize> {
        if lp_solution.len() != self.n_vars {
            return None;
        }
        // If the tolerance cannot be represented in `F`, fall back to zero.
        let frac_tol = F::from_f64(1e-6).unwrap_or_else(F::zero);
        let upper = F::one() - frac_tol;

        let mut best_idx: Option<usize> = None;
        let mut best_activity = F::neg_infinity();
        // Zipping (instead of indexing `activity`) cannot panic even if a
        // caller resized the public `activity` vector.
        for (i, (&val, &act)) in lp_solution.iter().zip(self.activity.iter()).enumerate() {
            let is_fractional = val > frac_tol && val < upper;
            if is_fractional && act > best_activity {
                best_activity = act;
                best_idx = Some(i);
            }
        }
        best_idx
    }

    /// Learns from an infeasible set of decisions.
    ///
    /// Steps, in order:
    /// 1. build a clause from the `(var_index, value)` pairs of the decisions;
    /// 2. bump the activity of every in-range variable in the clause by
    ///    `config.activity_bump`;
    /// 3. store the clause unless an existing clause already subsumes it;
    /// 4. drop the oldest clauses if the store exceeds `config.max_clauses`;
    /// 5. decay all variable activities.
    ///
    /// An empty `infeasible_decisions` slice is a no-op.
    pub fn record_conflict(&mut self, infeasible_decisions: &[BranchingDecision]) {
        if infeasible_decisions.is_empty() {
            return;
        }
        let clause = LearnedClause::new(
            infeasible_decisions
                .iter()
                .map(|d| (d.var_index, d.value))
                .collect(),
        );

        let bump = F::from_f64(self.config.activity_bump).unwrap_or_else(F::one);
        let n_vars = self.n_vars;
        for &(var_idx, _) in &clause.literals {
            if var_idx >= n_vars {
                continue;
            }
            // Checked access: `activity` is public and may not match `n_vars`.
            if let Some(act) = self.activity.get_mut(var_idx) {
                *act += bump;
            }
        }

        let already_covered = self
            .learned_clauses
            .iter()
            .any(|existing| existing.subsumes(&clause));
        if !already_covered {
            self.learned_clauses.push(clause);
        }

        // Evict from the front, i.e. the oldest clauses first.
        let excess = self
            .learned_clauses
            .len()
            .saturating_sub(self.config.max_clauses);
        if excess > 0 {
            self.learned_clauses.drain(..excess);
        }

        self.decay_activities();
    }

    /// Returns `true` when some non-empty learned clause has *all* of its
    /// literals present in `current_decisions`.
    pub fn apply_clauses(&self, current_decisions: &[BranchingDecision]) -> bool {
        self.learned_clauses
            .iter()
            .filter(|clause| !clause.is_empty())
            .any(|clause| {
                clause.literals.iter().all(|&(var_idx, value)| {
                    current_decisions
                        .iter()
                        .any(|d| d.var_index == var_idx && d.value == value)
                })
            })
    }

    /// Multiplies every variable activity by `config.decay`.
    ///
    /// If the factor cannot be represented in `F`, one is used, so the
    /// activities are left unchanged.
    pub fn decay_activities(&mut self) {
        let decay = F::from_f64(self.config.decay).unwrap_or_else(F::one);
        for act in &mut self.activity {
            *act *= decay;
        }
    }

    /// Pushes `decision` onto the decision stack.
    pub fn push_decision(&mut self, decision: BranchingDecision) {
        self.decisions.push(decision);
    }

    /// Pops the most recent decision, or returns `None` if the stack is empty.
    pub fn pop_decision(&mut self) -> Option<BranchingDecision> {
        self.decisions.pop()
    }

    /// Number of clauses currently stored.
    pub fn n_learned_clauses(&self) -> usize {
        self.learned_clauses.len()
    }

    /// Removes clauses whose activity is below `config.min_activity_threshold`.
    ///
    /// NOTE(review): no method in this impl changes a clause's `activity`
    /// after creation, so with the default threshold this keeps every clause
    /// unless callers adjust the scores themselves.
    pub fn prune_inactive_clauses(&mut self) {
        let threshold = self.config.min_activity_threshold;
        self.learned_clauses
            .retain(|c| c.activity >= threshold);
    }
}
/// Selectable branching strategies, generic over the activity float type `F`
/// (defaults to `f64`).
///
/// The enum is `#[non_exhaustive]`, so downstream `match` expressions need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug)]
pub enum BranchingStrategy<F = f64> {
    /// Data-free variant; its selection rule is implemented outside this view.
    MostFractional,
    /// Data-free variant; its selection rule is implemented outside this view.
    StrongBranching,
    /// Activity/clause-driven branching carrying its own [`CdclBranchingState`].
    Cdcl(CdclBranchingState<F>),
}
#[cfg(test)]
mod tests {
    //! Unit tests for the CDCL branching state, clauses and configuration.

    use super::*;

    type F = f64;

    /// Builds a state with `n` variables and the default configuration.
    fn make_state(n: usize) -> CdclBranchingState<F> {
        CdclBranchingState::new(n, CdclConfig::default()).unwrap()
    }

    #[test]
    fn test_new_state_correct_size() {
        let state = make_state(5);
        assert_eq!(state.n_vars, 5);
        assert_eq!(state.activity.len(), 5);
        assert!(state.decisions.is_empty());
        assert!(state.learned_clauses.is_empty());
    }

    #[test]
    fn test_cdcl_config_defaults() {
        let cfg = CdclConfig::default();
        assert_eq!(cfg.max_clauses, 10_000);
        assert_eq!(cfg.max_learned_per_conflict, 3);
        assert!((cfg.decay - 0.95).abs() < 1e-12);
        assert!((cfg.activity_bump - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_select_branching_var_picks_highest_activity() {
        let mut state = make_state(4);
        state.activity[0] = 0.1;
        state.activity[1] = 0.5;
        state.activity[2] = 1.0;
        state.activity[3] = 0.3;
        // Every entry is fractional, so activity alone decides.
        let lp_sol = vec![0.5, 0.7, 0.3, 0.6];
        let selected = state.select_branching_var(&lp_sol);
        assert_eq!(selected, Some(2), "should pick var 2 (highest activity)");
    }

    #[test]
    fn test_select_branching_var_skips_integral() {
        let mut state = make_state(3);
        state.activity = vec![10.0, 0.5, 0.1];
        // Var 0 has the highest activity but an integral LP value.
        let lp_sol = vec![1.0, 0.4, 0.6];
        let selected = state.select_branching_var(&lp_sol);
        assert_eq!(selected, Some(1));
    }

    #[test]
    fn test_record_conflict_creates_clause_correct_length() {
        let mut state = make_state(4);
        let decisions = vec![
            BranchingDecision::new(0, 1),
            BranchingDecision::new(2, 0),
        ];
        state.record_conflict(&decisions);
        assert!(!state.learned_clauses.is_empty());
        let clause = &state.learned_clauses[0];
        assert_eq!(clause.len(), 2, "clause should have 2 literals");
    }

    #[test]
    fn test_record_conflict_bumps_activity() {
        let mut state = make_state(3);
        let decisions = vec![
            BranchingDecision::new(0, 1),
            BranchingDecision::new(1, 0),
        ];
        state.record_conflict(&decisions);
        // The bump (1.0) is applied first, then a single decay step (0.95).
        let expected = 1.0 * 0.95;
        assert!(
            (state.activity[0] - expected).abs() < 1e-9,
            "activity[0] = {}",
            state.activity[0]
        );
        assert!(
            (state.activity[1] - expected).abs() < 1e-9,
            "activity[1] = {}",
            state.activity[1]
        );
        assert!(state.activity[2] < state.activity[0]);
    }

    #[test]
    fn test_apply_clauses_no_violation() {
        let mut state = make_state(3);
        state.learned_clauses.push(LearnedClause::new(vec![(0, 1), (1, 0)]));
        // Only one of the two literals is matched by the trail.
        let current = vec![BranchingDecision::new(0, 1)];
        assert!(!state.apply_clauses(&current), "clause should not fire");
    }

    #[test]
    fn test_apply_clauses_violation() {
        let mut state = make_state(3);
        state.learned_clauses.push(LearnedClause::new(vec![(0, 1), (1, 0)]));
        let current = vec![
            BranchingDecision::new(0, 1),
            BranchingDecision::new(1, 0),
        ];
        assert!(state.apply_clauses(&current), "clause should fire");
    }

    #[test]
    fn test_apply_clauses_empty_trail_no_violation() {
        let mut state = make_state(3);
        state.learned_clauses.push(LearnedClause::new(vec![(0, 1)]));
        assert!(!state.apply_clauses(&[]), "empty trail cannot satisfy any clause");
    }

    #[test]
    fn test_decay_activities_reduces_all() {
        let mut state = make_state(3);
        state.activity = vec![1.0, 2.0, 0.5];
        let before: Vec<f64> = state.activity.clone();
        state.decay_activities();
        for (a, b) in before.iter().zip(state.activity.iter()) {
            assert!(b < a || (*a == 0.0 && *b == 0.0));
        }
    }

    #[test]
    fn test_learned_clause_subsumption() {
        let short = LearnedClause::new(vec![(0, 1)]);
        let long = LearnedClause::new(vec![(0, 1), (1, 0)]);
        assert!(short.subsumes(&long), "shorter clause should subsume longer");
        assert!(!long.subsumes(&short), "longer clause should not subsume shorter");
    }

    #[test]
    fn test_duplicate_clause_not_added() {
        let mut state = make_state(3);
        let decisions = vec![BranchingDecision::new(0, 1), BranchingDecision::new(1, 0)];
        state.record_conflict(&decisions);
        let count_before = state.learned_clauses.len();
        // A superset of the first conflict is subsumed by the stored clause.
        let decisions2 = vec![
            BranchingDecision::new(0, 1),
            BranchingDecision::new(1, 0),
            BranchingDecision::new(2, 1),
        ];
        state.record_conflict(&decisions2);
        assert_eq!(
            state.learned_clauses.len(),
            count_before,
            "subsumed clause should not be added"
        );
    }
}