use std::marker::PhantomData;
use crate::{
encodings::{pb, Monotone},
instances::BasicVarManager,
solvers::{SolveIncremental, SolveStats, SolverResult},
types::{Assignment, Lit, TernaryVal},
};
/// A strategy for optimizing a weighted objective on top of an incremental SAT solver.
pub trait Solve {
    /// The SAT solver type this strategy drives.
    type Solver: SolveIncremental + SolveStats;

    /// Runs the strategy on the clauses already loaded into `solver`.
    ///
    /// `objective` is a list of weighted literals `(lit, weight)`. On success the returned pair
    /// holds a solution and its objective value. `None` means no solution was found. What exactly
    /// is guaranteed about the returned solution (e.g. optimality) is up to the implementation.
    fn solve(
        solver: &mut Self::Solver,
        objective: &[(Lit, usize)],
    ) -> Option<(Assignment, usize)>;
}
/// Computes the objective value of `sol`.
///
/// The value is the sum of the weights of all literals in `obj` that `sol` assigns `True`.
/// Literals that are `False` or unassigned contribute nothing.
fn objective_value(obj: &[(Lit, usize)], sol: &Assignment) -> usize {
    obj.iter()
        .filter(|(lit, _)| sol.lit_value(*lit) == TernaryVal::True)
        .map(|&(_, weight)| weight)
        .sum()
}
/// Solution-improving search strategy, selected purely at the type level.
///
/// The struct stores no data. `Solver` is the SAT solver type and `PbEnc` the pseudo-Boolean
/// encoding type; both are only recorded through [`PhantomData`]. The actual search is provided
/// by this type's [`Solve`] implementation, which is where the trait bounds on both parameters
/// are required.
#[derive(Debug)]
pub struct SolutionImprovingSearch<Solver, PbEnc> {
    // Marker for the solver type; never holds a value.
    slv: PhantomData<Solver>,
    // Marker for the PB encoding type; never holds a value.
    enc: PhantomData<PbEnc>,
}
impl<Solver, PbEnc> Solve for SolutionImprovingSearch<Solver, PbEnc>
where
    Solver: SolveIncremental + SolveStats,
    PbEnc: FromIterator<(Lit, usize)> + pb::BoundUpperIncremental + Monotone,
{
    type Solver = Solver;

    /// Minimizes `objective` by repeatedly solving and tightening an upper bound.
    ///
    /// After every satisfying call, the objective value `v` of the current solution is computed.
    /// The PB encoding of the objective is then extended and enforced so that any further
    /// solution must have value at most `v - 1`. The search stops when the solver reports
    /// unsatisfiability or when a solution of value `0` is found.
    ///
    /// Returns the last solution found together with its value, or `None` if the very first
    /// solver call is unsatisfiable. If the solver reports no variables at all,
    /// `Some((Assignment::default(), 0))` is returned without calling the solver.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the solver returns an error;
    /// - a solution cannot be retrieved;
    /// - adding the encoding's clauses or unit literals fails;
    /// - the encoding rejects the requested bound;
    /// - the solver reports [`SolverResult::Interrupted`].
    fn solve(
        solver: &mut Self::Solver,
        objective: &[(Lit, usize)],
    ) -> Option<(Assignment, usize)> {
        let max_var = match solver.max_var() {
            Some(var) => var,
            None => return Some((Assignment::default(), 0)),
        };
        // Encoding variables are allocated after the solver's existing ones so they cannot clash.
        let mut var_manager = BasicVarManager::from_next_free(max_var + 1);
        let mut pb_enc = objective.iter().copied().collect::<PbEnc>();
        let mut best: Option<(Assignment, usize)> = None;

        loop {
            match solver.solve().expect("solver error while solving") {
                SolverResult::Sat => {}
                SolverResult::Unsat => return best,
                SolverResult::Interrupted => unreachable!(),
            }

            let assignment = solver.solution(max_var).expect("failed getting solution");
            let cost = objective_value(objective, &assignment);
            best = Some((assignment, cost));
            // Weights are `usize`, so no solution can score below zero: stop early.
            if cost == 0 {
                return best;
            }

            // Force the next solution to be strictly better than the current one.
            let bound = cost - 1;
            pb_enc
                .encode_ub(bound..cost, solver, &mut var_manager)
                .expect("error adding clauses to solver");
            let units = pb_enc.enforce_ub(bound).expect("invalid encoding usage");
            for unit in units {
                solver
                    .add_unit(unit)
                    .expect("error adding clause to solver");
            }
        }
    }
}