use crate::constraint::{ConstraintEnum, VarId};
use crate::domain::Domain;
use crate::ordering::{self, Ordering};
use crate::solver::ac3;
use crate::solver::backtrack::Solution;
use crate::solver::propagate;
use crate::solver::SearchContext;
use crate::variable::Variable;
use crate::Pruning;
/// Tuning knobs for [`branch_and_bound`].
pub struct OptimizeConfig {
    /// Propagation run after each tentative assignment
    /// (none, forward checking, AC-3, or AC-FC).
    pub pruning: Pruning,
    /// Variable-selection heuristic handed to `ordering::select_variable`.
    pub ordering: Ordering,
    /// Upper limit on the number of solutions returned, best first.
    pub max_solutions: usize,
    /// Forwarded unchanged to `ordering::select_variable`; presumably
    /// per-constraint weights for weighted heuristics — confirm there.
    pub constraint_weights: Vec<f64>,
    /// Forwarded unchanged to `ordering::select_variable`; presumably the
    /// constraint indices touching each variable — confirm there.
    pub var_constraint_ids: Vec<Vec<usize>>,
    /// When `true` the search maximizes total cost instead of minimizing it.
    pub maximize: bool,
    /// Optional cap on explored search nodes; once reached the search stops
    /// and `SolveStats::budget_exceeded` is set.
    pub node_budget: Option<u64>,
}
/// Prices values of a domain for branch-and-bound optimization.
///
/// `min_cost` / `max_cost` stand in for the cost of variables that are not
/// yet assigned when the optimistic bound is computed (minimum when
/// minimizing, maximum when maximizing). They should therefore bracket every
/// `cost` the domain can still produce, or valid subtrees may be pruned.
pub trait DomainCostEval<D: Domain> {
    /// Cost of picking `val` out of `domain`.
    fn cost(&self, domain: &D, val: &D::Value) -> f64;

    /// Smallest cost any value currently in `domain` can have.
    fn min_cost(&self, domain: &D) -> f64;

    /// Largest cost any value currently in `domain` can have.
    fn max_cost(&self, domain: &D) -> f64;
}
/// Cost evaluator that prices every value at zero.
///
/// Useful when only constraint soft penalties should drive the objective.
pub struct ZeroCost;

impl<D: Domain> DomainCostEval<D> for ZeroCost {
    /// Every value is free.
    #[inline]
    fn cost(&self, _domain: &D, _val: &D::Value) -> f64 {
        0.0
    }

    /// No value is cheaper than zero.
    #[inline]
    fn min_cost(&self, _domain: &D) -> f64 {
        0.0
    }

    /// No value is dearer than zero.
    #[inline]
    fn max_cost(&self, _domain: &D) -> f64 {
        0.0
    }
}
/// Cost evaluator that delegates to a [`crate::domain::CostDomain`]'s own
/// per-value costs.
pub struct CostDomainEval;

impl<D: crate::domain::CostDomain> DomainCostEval<D> for CostDomainEval {
    /// The domain's own cost for `val`.
    #[inline]
    fn cost(&self, domain: &D, val: &D::Value) -> f64 {
        domain.cost(val)
    }

    /// Delegates to `CostDomain::min_cost`.
    #[inline]
    fn min_cost(&self, domain: &D) -> f64 {
        domain.min_cost()
    }

    /// Scans every value in the domain and returns the highest cost.
    ///
    /// An empty domain yields `f64::NEG_INFINITY`; NaN costs are skipped by
    /// `f64::max`.
    #[inline]
    fn max_cost(&self, domain: &D) -> f64 {
        let mut highest = f64::NEG_INFINITY;
        for value in domain.values() {
            highest = highest.max(domain.cost(&value));
        }
        highest
    }
}
/// A complete solution together with its total cost.
struct ScoredSolution<D: Domain> {
    /// One value per variable, in variable order.
    solution: Solution<D>,
    /// Raw total cost (never negated, even when maximizing).
    cost: f64,
}
/// Mutable state threaded through the branch-and-bound recursion.
struct BranchBoundState<D: Domain> {
    /// Complete solutions recorded so far, each with its raw cost.
    scored: Vec<ScoredSolution<D>>,
    /// Pruning threshold in effective units (cost negated when maximizing):
    /// a node whose optimistic bound is not strictly below it is cut.
    /// Starts at `f64::INFINITY`, i.e. nothing is pruned initially.
    best_cost: f64,
}
/// Runs depth-first branch-and-bound and returns up to
/// `config.max_solutions` complete solutions, best first.
///
/// "Best" means lowest total cost, or highest when `config.maximize` is set.
/// Total cost is the sum of `cost_eval` costs of the chosen values plus every
/// constraint's soft penalty. Search counters are accumulated in `stats`.
/// If `config.node_budget` runs out, whatever solutions were found before the
/// cutoff are still returned.
pub fn branch_and_bound<D: Domain>(
    variables: &mut [Variable<D>],
    constraints: &[ConstraintEnum<D>],
    adjacency: &crate::adjacency::Adjacency,
    config: &OptimizeConfig,
    stats: &mut crate::SolveStats,
    cost_eval: &dyn DomainCostEval<D>,
) -> Vec<Solution<D>>
where
    D::Value: PartialEq,
{
    let num_vars = variables.len();
    let mut assignment: Vec<Option<D::Value>> = vec![None; num_vars];
    let mut stack: Vec<VarId> = (0..num_vars as u32).collect();
    let mut ctx = SearchContext { variables, constraints, adjacency, stats };
    let mut bb = BranchBoundState {
        scored: Vec::new(),
        best_cost: f64::INFINITY,
    };
    bb_recurse(
        &mut ctx,
        config,
        cost_eval,
        &mut assignment,
        &mut stack,
        &mut bb,
        0,
    );

    // Rank on a minimization-scale key with NaN mapped to +inf, so a NaN cost
    // always sorts last in either direction. `total_cmp` keeps the comparator
    // a total order, which `sort_by` requires (it may panic otherwise).
    // The sort is stable, so equal costs keep discovery order.
    let maximize = config.maximize;
    let rank = |cost: f64| {
        let effective = if maximize { -cost } else { cost };
        if effective.is_nan() {
            f64::INFINITY
        } else {
            effective
        }
    };
    bb.scored.sort_by(|a, b| rank(a.cost).total_cmp(&rank(b.cost)));
    bb.scored.truncate(config.max_solutions);
    bb.scored.into_iter().map(|s| s.solution).collect()
}
/// Total cost of a (possibly partial) assignment.
///
/// Adds the `cost_eval` cost of every assigned value, then the soft penalty
/// reported by every constraint. Unassigned slots contribute nothing, and
/// every constraint is consulted whether or not its scope is fully assigned.
fn assignment_cost<D: Domain>(
    assignment: &[Option<D::Value>],
    variables: &[Variable<D>],
    constraints: &[ConstraintEnum<D>],
    cost_eval: &dyn DomainCostEval<D>,
) -> f64
where
    D::Value: PartialEq,
{
    let value_costs = assignment.iter().enumerate().filter_map(|(idx, slot)| {
        slot.as_ref()
            .map(|val| cost_eval.cost(&variables[idx].domain, val))
    });
    let penalties = constraints.iter().map(|c| c.soft_penalty(assignment));
    // One left fold from 0.0 reproduces a running total's summation order
    // exactly (float addition is not associative).
    value_costs
        .chain(penalties)
        .fold(0.0, |total, term| total + term)
}
/// Optimistic estimate of the total cost reachable from a partial assignment.
///
/// Assigned variables contribute their actual `cost_eval` cost. Unassigned
/// ones contribute their current domain's `max_cost` when `maximize` is set,
/// `min_cost` otherwise. Soft penalties are added only for constraints whose
/// entire scope is already assigned.
///
/// NOTE(review): omitting the penalties of undecided constraints keeps this an
/// admissible bound for minimization only if penalties are never negative, and
/// for maximization only if they are never positive. Confirm what
/// `soft_penalty` can return.
fn optimistic_bound<D: Domain>(
    assignment: &[Option<D::Value>],
    variables: &[Variable<D>],
    constraints: &[ConstraintEnum<D>],
    cost_eval: &dyn DomainCostEval<D>,
    maximize: bool,
) -> f64
where
    D::Value: PartialEq,
{
    let per_variable = assignment.iter().enumerate().map(|(idx, slot)| {
        let domain = &variables[idx].domain;
        match slot {
            Some(val) => cost_eval.cost(domain, val),
            None if maximize => cost_eval.max_cost(domain),
            None => cost_eval.min_cost(domain),
        }
    });
    let decided_penalties = constraints
        .iter()
        .filter(|c| {
            let scope = c.scope();
            scope.iter().all(|&v| assignment[v as usize].is_some())
        })
        .map(|c| c.soft_penalty(assignment));
    // Single left fold: same summation order as a running accumulator.
    per_variable
        .chain(decided_penalties)
        .fold(0.0, |acc, term| acc + term)
}
fn bb_recurse<D: Domain>(
ctx: &mut SearchContext<'_, D>,
config: &OptimizeConfig,
cost_eval: &dyn DomainCostEval<D>,
assignment: &mut Vec<Option<D::Value>>,
stack: &mut Vec<VarId>,
bb: &mut BranchBoundState<D>,
depth: usize,
) -> bool
where
D::Value: PartialEq,
{
if stack.is_empty() {
let cost = assignment_cost(assignment, ctx.variables, ctx.constraints, cost_eval);
let effective_cost = if config.maximize { -cost } else { cost };
if effective_cost < bb.best_cost {
bb.best_cost = effective_cost;
}
let sol: Solution<D> = assignment
.iter()
.map(|v| v.as_ref().unwrap().clone())
.collect();
bb.scored.push(ScoredSolution { solution: sol, cost });
return false;
}
if let Some(budget) = config.node_budget
&& ctx.stats.nodes_explored >= budget
{
ctx.stats.budget_exceeded = true;
return true;
}
ctx.stats.nodes_explored += 1;
let ob = optimistic_bound(
assignment, ctx.variables, ctx.constraints, cost_eval, config.maximize,
);
let effective_ob = if config.maximize { -ob } else { ob };
if effective_ob >= bb.best_cost {
return false;
}
let idx = ordering::select_variable(
stack,
ctx.variables,
config.ordering,
&config.constraint_weights,
&config.var_constraint_ids,
)
.unwrap();
let var = stack.swap_remove(idx);
let mut values: Vec<_> = ctx.variables[var as usize].domain.iter().collect();
{
let domain = &ctx.variables[var as usize].domain;
if config.maximize {
values.sort_by(|a, b| {
let ca = cost_eval.cost(domain, b);
let cb = cost_eval.cost(domain, a);
ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
});
} else {
values.sort_by(|a, b| {
let ca = cost_eval.cost(domain, a);
let cb = cost_eval.cost(domain, b);
ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
});
}
}
for val in values {
assignment[var as usize] = Some(val.clone());
ctx.variables[var as usize].restrict_to(&val, depth);
let mut valid = true;
for &ci in ctx.adjacency.constraints_for(var) {
let ci = ci as usize;
let scope = ctx.constraints[ci].scope();
if scope.iter().all(|&v| assignment[v as usize].is_some())
&& !ctx.constraints[ci].check(assignment)
{
valid = false;
break;
}
}
if valid {
let dwo = match config.pruning {
Pruning::None => false,
Pruning::ForwardChecking => propagate::forward_check(
var,
ctx.variables,
ctx.constraints,
ctx.adjacency,
assignment.as_mut_slice(),
ctx.stats,
depth,
),
Pruning::Ac3 => ac3::ac3_from_variable(
var, ctx.variables, ctx.constraints, ctx.adjacency, assignment, ctx.stats, depth,
),
Pruning::AcFc => propagate::ac_fc(
var,
ctx.variables,
ctx.constraints,
ctx.adjacency,
assignment.as_mut_slice(),
ctx.stats,
depth,
),
};
if !dwo
&& bb_recurse(
ctx,
config,
cost_eval,
assignment,
stack,
bb,
depth + 1,
)
{
return true;
}
}
ctx.stats.backtracks += 1;
assignment[var as usize] = None;
for v in ctx.variables.iter_mut() {
v.restore(depth);
}
}
stack.push(var);
false
}