use crate::qubo::Qubo;
use ndarray::Array1;
use rayon::prelude::*;
use crate::branch_node::QuboBBNode;
use crate::branch_stratagy::{BranchResult, BranchStrategy};
use crate::branch_subproblem::{get_sub_problem_solver, SubProblemSolver};
use crate::branchbound_utils::{check_integer_feasibility, get_current_time};
use crate::branchboundlogger::SolverOutputLogger;
use crate::lower_bound::li_lower_bound;
use crate::preprocess;
use crate::preprocess::preprocess_qubo;
use crate::solver_options::SolverOptions;
use std::collections::BinaryHeap;
/// Branch-and-bound solver state for a single [`Qubo`] problem.
///
/// Construct with `BBSolver::new`, optionally seed an incumbent with `warm_start`, then call
/// `solve`, which returns the best solution found and its objective value.
pub struct BBSolver {
/// Problem being solved; `new` stores the `convex_symmetric_form()` of its input here.
pub qubo: Qubo,
/// Result of `preprocess::shift_qubo` applied to `qubo`; this is the form passed to
/// `preprocess_qubo`.
pub qubo_pp_form: Qubo,
/// Best (lowest objective) assignment found so far; starts as all zeros.
pub best_solution: Array1<usize>,
/// Objective value paired with `best_solution`; starts at `0.0`.
pub best_solution_value: f64,
/// Open nodes waiting to be processed. `pop` yields the greatest node under `QuboBBNode`'s
/// ordering, which is defined outside this file.
pub nodes: BinaryHeap<QuboBBNode>,
/// Incremented for every `NodeLoggingAction::Processed` and `NodeLoggingAction::Solved`.
pub nodes_processed: usize,
/// Incremented for every `NodeLoggingAction::Solved`.
pub nodes_solved: usize,
/// Incremented for every node popped from `nodes` by `get_next_node`.
pub nodes_visited: usize,
/// Timestamp from `get_current_time()`; reset by `solve` after the root node is solved and
/// compared against `options.max_time` in `termination_condition`.
pub time_start: f64,
/// Copied from `options.branch_strategy`; consulted by `make_branch`.
pub branch_strategy: BranchStrategy,
/// Relaxation solver chosen by `get_sub_problem_solver`; called by `solve_node`.
pub subproblem_solver: Box<dyn SubProblemSolver + Sync>,
/// Solver configuration. `solve` overwrites `options.fixed_variables` with the preprocessed
/// set.
pub options: SolverOptions,
/// When `true`, `termination_condition` reports termination. Starts `false`; no code visible
/// in this section sets it.
pub early_stop: bool,
/// Output logger, created from `options.verbose`.
pub solver_logger: SolverOutputLogger,
}
/// A deferred change to solver state, produced while examining a node and applied afterwards
/// by `BBSolver::apply_events`.
pub enum Event {
/// A candidate assignment and its objective value; adopted only if the value is strictly
/// lower than the current best.
UpdateBestSolution(Array1<usize>, f64),
/// The two children of a branched node (branch variable fixed to 0, then to 1). Each child is
/// queued only if its lower bound does not exceed the current best value.
AddBranches(QuboBBNode, QuboBBNode),
/// Nothing to do.
Nill,
}
/// Which node counters `BBSolver::apply_logging_action` should advance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeLoggingAction {
    /// A node was popped from the queue: bumps `nodes_visited`.
    Visited,
    /// A node was handled without solving its relaxation: bumps `nodes_processed`.
    Processed,
    /// A node was handled to completion: bumps both `nodes_processed` and `nodes_solved`.
    Solved,
}
/// Verdict returned by `BBSolver::can_prune_action` for a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PruneAction {
    /// The node needs no further exploration: either its lower bound exceeds the incumbent
    /// value or every variable is already fixed.
    Prune,
    /// The node must still be explored.
    Dont,
}
/// Result of `BBSolver::process_node` for one node.
///
/// `process_node` takes `&self`, so it reports what should change here and `solve` applies it
/// afterwards.
pub struct ProcessNodeState {
/// Verdict obtained from `can_prune_action` for the node. `solve` does not read this field.
pub prune_action: PruneAction,
/// State changes to hand to `apply_events`, in order.
pub events: Vec<Event>,
/// Counter update to hand to `apply_logging_action`.
pub logging: NodeLoggingAction,
}
/// Outcome of a solve, tagged by whether the result is optimal.
///
/// NOTE(review): nothing visible in this file constructs or returns this type (`solve` returns
/// a plain tuple). The payload is presumably (solution, objective value) — confirm with callers.
pub enum SolverResult {
/// Result reported as optimal.
OptimalSolution(Array1<f64>, f64),
/// Result reported as not proven optimal.
SubOptimalSolution(Array1<f64>, f64),
}
impl BBSolver {
/// Builds a solver for `qubo`, configured by `options`.
///
/// The problem is first rewritten with `convex_symmetric_form()`, and a shifted copy from
/// `preprocess::shift_qubo` is kept for preprocessing. The incumbent starts as the all-zeros
/// vector with value `0.0`, the node queue is empty, and every counter is zero.
pub fn new(qubo: Qubo, options: SolverOptions) -> Self {
    let qubo = qubo.convex_symmetric_form();
    let best_solution = Array1::zeros(qubo.num_x());
    let subproblem_solver = get_sub_problem_solver(&qubo, &options.sub_problem_solver);
    let time_start = get_current_time();
    let qubo_pp_form = preprocess::shift_qubo(&qubo);
    // Copy the settings needed below before `options` is moved into the struct.
    let verbosity = options.verbose;
    let branch_strategy = options.branch_strategy;

    Self {
        qubo,
        qubo_pp_form,
        best_solution,
        best_solution_value: 0.0,
        nodes: BinaryHeap::new(),
        nodes_processed: 0,
        nodes_solved: 0,
        nodes_visited: 0,
        time_start,
        branch_strategy,
        subproblem_solver,
        options,
        early_stop: false,
        solver_logger: SolverOutputLogger::new(verbosity),
    }
}
/// Offers `initial_solution` as a starting incumbent.
///
/// The assignment is evaluated on `self.qubo` and adopted only if its objective is strictly
/// lower than the current best value.
pub fn warm_start(&mut self, initial_solution: Array1<usize>) {
    let objective = self.qubo.eval_usize(&initial_solution);
    self.update_solution_if_better(&initial_solution, objective);
}
/// Runs the branch-and-bound search and returns `(best_solution, best_solution_value)`.
///
/// The root node is preprocessed, solved once and queued. Until `termination_condition`
/// holds, a batch of nodes is then taken from the queue, processed in parallel, and the
/// resulting events and counter updates are applied sequentially.
pub fn solve(&mut self) -> (Array1<usize>, f64) {
    // Tighten the user-supplied fixings and write them back into the options.
    let root_fixed = preprocess_qubo(&self.qubo_pp_form, &self.options.fixed_variables, true);
    self.options.fixed_variables.clone_from(&root_fixed);

    // The root starts at the all-0.5 point with no bound, then takes the relaxation's result.
    let mut root = QuboBBNode {
        lower_bound: f64::NEG_INFINITY,
        solution: 0.5 * Array1::ones(self.qubo.num_x()),
        fixed_variables: root_fixed,
    };
    let (root_bound, root_relaxed) = self.solve_node(&root);
    root.lower_bound = root_bound;
    root.solution = root_relaxed;
    self.nodes.push(root);

    // The clock restarts here, so the root solve is not counted against `max_time`.
    self.time_start = get_current_time();
    self.solver_logger.output_header(self);
    if self.best_solution_value < 0.0 {
        // A negative incumbent means something better than the initial zero vector is held.
        self.solver_logger.output_warm_start_info(self);
    }

    while !self.termination_condition() {
        let batch = self.get_next_nodes(self.options.threads);
        // Processing only needs `&self`; all mutation happens in the loop below.
        let outcomes: Vec<_> = batch
            .par_iter()
            .map(|node| self.process_node(node))
            .collect();
        for outcome in outcomes {
            self.apply_events(outcome.events);
            self.apply_logging_action(outcome.logging);
        }
        self.solver_logger.generate_output_line(self);
    }

    self.solver_logger.generate_exit_line(self);
    (self.best_solution.clone(), self.best_solution_value)
}
/// Decides whether `node` can be discarded without further exploration.
///
/// Returns `(Prune, Nill)` when the node's lower bound exceeds the incumbent value.
/// Returns `(Prune, UpdateBestSolution(..))` when every variable is fixed; the event carries
/// the resulting assignment and its objective. Otherwise returns `(Dont, Nill)`.
pub fn can_prune_action(&self, node: &QuboBBNode) -> (PruneAction, Event) {
    // Bound test: nothing below this node can beat the incumbent.
    if node.lower_bound > self.best_solution_value {
        return (PruneAction::Prune, Event::Nill);
    }

    let num_x = self.qubo.num_x();
    if node.fixed_variables.len() != num_x {
        return (PruneAction::Dont, Event::Nill);
    }

    // Every variable is fixed, so the node is a single point: read it off and evaluate it.
    let mut assignment = Array1::zeros(num_x);
    for (&var, &bit) in &node.fixed_variables {
        assignment[var] = bit;
    }
    let objective = self.qubo.eval_usize(&assignment);
    (
        PruneAction::Prune,
        Event::UpdateBestSolution(assignment, objective),
    )
}
/// Advances the node counters according to `action`.
///
/// `Visited` bumps `nodes_visited`. `Processed` bumps `nodes_processed`. `Solved` bumps both
/// `nodes_processed` and `nodes_solved`.
pub fn apply_logging_action(&mut self, action: NodeLoggingAction) {
    // (visited, processed, solved) increments for each action.
    let (visited, processed, solved) = match action {
        NodeLoggingAction::Visited => (1, 0, 0),
        NodeLoggingAction::Processed => (0, 1, 0),
        NodeLoggingAction::Solved => (0, 1, 1),
    };
    self.nodes_visited += visited;
    self.nodes_processed += processed;
    self.nodes_solved += solved;
}
/// Examines one node without mutating the solver and reports what should change.
///
/// In order:
/// 1. Re-run preprocessing on the node's fixings and raise its lower bound to at least the
///    `li_lower_bound` value.
/// 2. If `can_prune_action` says prune, return its event, logged as `Processed`.
/// 3. If the node's stored solution is integer feasible, return it as a candidate.
/// 4. Otherwise solve the relaxation, ask the branching strategy for a variable, and merge
///    any fixings it discovered. If that fixes everything, return the resulting point.
/// 5. Otherwise run the heuristic and return the two children plus the heuristic candidate.
///
/// Steps 3 to 5 are logged as `Solved`. The children carry the bound returned by
/// `solve_node`, not `node.lower_bound`.
pub fn process_node(&self, node: &QuboBBNode) -> ProcessNodeState {
    let mut current = node.clone();
    current.fixed_variables =
        preprocess_qubo(&self.qubo_pp_form, &current.fixed_variables, true);

    let li_bound = li_lower_bound(&self.qubo, &current.fixed_variables);
    current.lower_bound = current.lower_bound.max(li_bound);

    let (prune_action, prune_event) = self.can_prune_action(&current);
    if let PruneAction::Prune = prune_action {
        return ProcessNodeState {
            prune_action,
            events: vec![prune_event],
            logging: NodeLoggingAction::Processed,
        };
    }

    let (integral, rounded) = check_integer_feasibility(&current);
    if integral {
        let objective = self.qubo.eval_usize(&rounded);
        return ProcessNodeState {
            prune_action,
            events: vec![Event::UpdateBestSolution(rounded, objective)],
            logging: NodeLoggingAction::Solved,
        };
    }

    let (relaxation_bound, relaxed) = self.solve_node(&current);
    current.solution.clone_from(&relaxed);

    // The branching strategy may also report variables whose value it could determine.
    let branch_result = self.make_branch(&current);
    for (&var, &bit) in &branch_result.found_fixed_vars {
        current.fixed_variables.insert(var, bit);
    }

    let num_x = self.qubo.num_x();
    if current.fixed_variables.len() == num_x {
        // Nothing is left to branch on: the fixings describe a complete assignment.
        let mut assignment = Array1::zeros(num_x);
        for (&var, &bit) in &current.fixed_variables {
            assignment[var] = bit;
        }
        let objective = self.qubo.eval_usize(&assignment);
        return ProcessNodeState {
            prune_action,
            events: vec![Event::UpdateBestSolution(assignment, objective)],
            logging: NodeLoggingAction::Solved,
        };
    }

    let (heuristic_solution, heuristic_value) =
        self.options.heuristic.make_heuristic(self, &current);
    let (zero_branch, one_branch) = Self::branch(
        current,
        branch_result.branch_variable,
        relaxation_bound,
        relaxed,
    );

    ProcessNodeState {
        prune_action,
        events: vec![
            Event::AddBranches(zero_branch, one_branch),
            Event::UpdateBestSolution(heuristic_solution, heuristic_value),
        ],
        logging: NodeLoggingAction::Solved,
    }
}
/// Applies `events` to the solver in order.
///
/// `UpdateBestSolution` goes through `update_solution_if_better`. `AddBranches` queues each
/// child (zero branch first) whose lower bound does not exceed the current best value.
/// `Nill` does nothing.
pub fn apply_events(&mut self, events: Vec<Event>) {
    for event in events {
        match event {
            Event::Nill => {}
            Event::UpdateBestSolution(candidate, objective) => {
                self.update_solution_if_better(&candidate, objective);
            }
            Event::AddBranches(zero_branch, one_branch) => {
                for child in [zero_branch, one_branch] {
                    if child.lower_bound <= self.best_solution_value {
                        self.nodes.push(child);
                    }
                }
            }
        }
    }
}
/// Adopts `solution` as the incumbent when `solution_value` is strictly lower than the
/// current best value, then drops every queued node whose lower bound exceeds the new value.
///
/// Does nothing when the candidate is not strictly better.
pub fn update_solution_if_better(&mut self, solution: &Array1<usize>, solution_value: f64) {
    if solution_value < self.best_solution_value {
        self.best_solution_value = solution_value;
        self.best_solution.clone_from(solution);

        // Keep only the nodes that could still contain something at least as good.
        let cutoff = self.best_solution_value;
        self.nodes.retain(|open| open.lower_bound <= cutoff);
    }
}
/// Pops nodes until one is found that still needs exploring, and returns it.
///
/// Every popped node counts as visited. A node that `can_prune_action` prunes is dropped,
/// after any solution it reports has been offered to the incumbent. Returns `None` once the
/// queue is exhausted.
pub fn get_next_node(&mut self) -> Option<QuboBBNode> {
    while let Some(candidate) = self.nodes.pop() {
        self.apply_logging_action(NodeLoggingAction::Visited);

        let (verdict, event) = self.can_prune_action(&candidate);
        if let Event::UpdateBestSolution(assignment, objective) = event {
            self.update_solution_if_better(&assignment, objective);
        }

        match verdict {
            PruneAction::Dont => return Some(candidate),
            PruneAction::Prune => {}
        }
    }
    None
}
/// Collects up to `n` unpruned nodes from the queue via `get_next_node`.
///
/// Fewer than `n` nodes are returned when the queue runs out first. A request for zero nodes
/// is treated as a request for one, so the solve loop always makes progress while nodes
/// remain.
///
/// The previous loop condition (`nodes.len() <= n`) returned up to `n + 1` nodes, so every
/// batch was one node larger than the configured thread count.
pub fn get_next_nodes(&mut self, n: usize) -> Vec<QuboBBNode> {
    let limit = n.max(1);
    // The queue length bounds the batch size, so it also bounds the capacity worth reserving.
    let mut nodes = Vec::with_capacity(limit.min(self.nodes.len()));
    while nodes.len() < limit {
        match self.get_next_node() {
            Some(node) => nodes.push(node),
            None => break,
        }
    }
    nodes
}
/// Returns `true` when the search should stop.
///
/// That is when the time since `time_start` exceeds `options.max_time`, the node queue is
/// empty, or `early_stop` is set.
pub fn termination_condition(&self) -> bool {
    let elapsed = get_current_time() - self.time_start;
    let out_of_time = elapsed > self.options.max_time;
    out_of_time || self.nodes.is_empty() || self.early_stop
}
/// Asks the configured `branch_strategy` how to branch on `node`.
///
/// `process_node` reads `branch_variable` and `found_fixed_vars` from the returned
/// `BranchResult`.
pub fn make_branch(&self, node: &QuboBBNode) -> BranchResult {
self.branch_strategy.make_branch(self, node)
}
/// Splits `node` into two children on variable `branch_id`.
///
/// Both children inherit the parent's fixings and are given `lower_bound` and `solution`.
/// The first returned node additionally fixes `branch_id` to 0, the second fixes it to 1.
pub fn branch(
    node: QuboBBNode,
    branch_id: usize,
    lower_bound: f64,
    solution: Array1<f64>,
) -> (QuboBBNode, QuboBBNode) {
    // Set the shared fields once on the parent, then clone it to make the sibling.
    let mut one_branch = node;
    one_branch.lower_bound = lower_bound;
    one_branch.solution = solution;

    let mut zero_branch = one_branch.clone();
    zero_branch.fixed_variables.insert(branch_id, 0);
    one_branch.fixed_variables.insert(branch_id, 1);

    (zero_branch, one_branch)
}
/// Solves `node` with the configured sub-problem solver.
///
/// Delegates to `solve_lower_bound`, passing `None` as its last argument. Callers in this
/// file store the returned pair as the node's lower bound and solution.
pub fn solve_node(&self, node: &QuboBBNode) -> (f64, Array1<f64>) {
self.subproblem_solver.solve_lower_bound(self, node, None)
}
}
#[cfg(test)]
mod tests {
use crate::branch_stratagy::BranchStrategy;
use crate::branch_subproblem::SubProblemSelection;
use crate::preprocess::preprocess_qubo;
use crate::qubo::Qubo;
use crate::solver_options::SolverOptions;
use crate::tests::make_test_prng;
use crate::{branchbound, local_search};
use ndarray::Array1;
use sprs::CsMat;
use std::collections::HashMap;
/// Solver options shared by the tests: `SolverOptions::new()` with verbosity 1, a 1000.0
/// time limit and 20 threads.
pub fn get_default_solver_options() -> SolverOptions {
    let mut defaults = SolverOptions::new();
    defaults.threads = 20;
    defaults.max_time = 1000.0;
    defaults.verbose = 1;
    defaults
}
/// Solves a 3-variable problem with an identity matrix and an all-negative linear term, and
/// checks that the solver sets every variable to 1.
#[test]
pub fn branch_bound_test() {
    let mut rng = make_test_prng();
    let quadratic = CsMat::eye(3);
    let linear = Array1::from_vec(vec![-1.1, -2.0, -3.0]);
    let problem = Qubo::new_with_c(quadratic, linear);

    let initial = local_search::particle_swarm_search(&problem, 100, 1000, &mut rng);

    let mut solver = branchbound::BBSolver::new(problem, SolverOptions::new());
    solver.warm_start(initial);
    solver.solve();

    let expected = Array1::from_vec(vec![1, 1, 1]);
    assert_eq!(solver.best_solution, expected);
}
/// Solves the `gka2b` instance with every branching strategy and sub-problem solver, and
/// checks each result against the objective of the reference assignment below.
#[test]
pub fn test_gka2b_solve() {
    let problem = Qubo::read_qubo("test_data/gka2b.qubo");
    let reference = Array1::from_vec(vec![
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
        0, 0,
    ]);
    solve_qubo_with_all_permutations(&problem, &reference);
}
/// Solves the `gka1b` instance with every branching strategy and sub-problem solver, and
/// checks each result against the objective of the reference assignment below.
#[test]
pub fn test_gka1b_solve() {
    let problem = Qubo::read_qubo("test_data/gka1b.qubo");
    let reference = Array1::from_vec(vec![
        0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
    ]);
    solve_qubo_with_all_permutations(&problem, &reference);
}
/// Solves the `gka6a` instance with every branching strategy and sub-problem solver, and
/// checks each result against the objective of the reference assignment below.
#[test]
pub fn test_gka6a_solve() {
    let problem = Qubo::read_qubo("test_data/gka6a.qubo");
    let reference = Array1::from_vec(vec![
        0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        0, 1,
    ]);
    solve_qubo_with_all_permutations(&problem, &reference);
}
/// Solves the `gka7a` instance with every branching strategy and sub-problem solver, and
/// checks each result against the objective of the reference assignment below.
#[test]
pub fn test_gka7a_solve() {
    let problem = Qubo::read_qubo("test_data/gka7a.qubo");
    let reference = Array1::from_vec(vec![
        0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        0, 1,
    ]);
    solve_qubo_with_all_permutations(&problem, &reference);
}
/// Solves `qubo` once for every (branching strategy, sub-problem solver) pair and checks each
/// run against the objective value of `sol_val`.
///
/// The strategy and solver lists are fixed, so they are plain arrays rather than heap-allocated
/// vectors. The loop variables are already references and are passed on without re-borrowing.
pub fn solve_qubo_with_all_permutations(qubo: &Qubo, sol_val: &Array1<usize>) {
    let branch_options = [
        BranchStrategy::FirstNotFixed,
        BranchStrategy::MostViolated,
        BranchStrategy::Random,
        BranchStrategy::WorstApproximation,
        BranchStrategy::WorstApproximation2,
        BranchStrategy::MostEdges,
        BranchStrategy::LargestEdges,
        BranchStrategy::MostFixed,
        BranchStrategy::FullStrongBranching,
        BranchStrategy::PartialStrongBranching,
        BranchStrategy::LargestDiag,
        BranchStrategy::MoveingEdges,
        BranchStrategy::RoundRobin,
    ];
    let sub_problem_solvers = [
        SubProblemSelection::ClarabelQP,
        SubProblemSelection::ClarabelLP,
        SubProblemSelection::HerculesCDQP,
    ];
    for branch in &branch_options {
        for sup_problem_solver in &sub_problem_solvers {
            setup_and_solve_problem(branch, sup_problem_solver, qubo, sol_val);
        }
    }
}
/// Solves `qubo` with the given branching strategy and sub-problem solver, and asserts that
/// the objective found is within `1E-5` of the objective of `true_sol`.
///
/// The run uses the shared test options with output silenced, fixings from a preprocessing
/// pass with an empty starting set, and a particle-swarm warm start. Both objectives are
/// evaluated on the solver's own copy of the problem.
pub fn setup_and_solve_problem(
    branch: &BranchStrategy,
    sup_problem_solver: &SubProblemSelection,
    qubo: &Qubo,
    true_sol: &Array1<usize>,
) {
    let mut prng = make_test_prng();
    // `qubo` and `true_sol` are already references, so they are passed through as-is.
    let fixed_variables = preprocess_qubo(qubo, &HashMap::new(), false);
    let guess = local_search::particle_swarm_search(qubo, 10, 100, &mut prng);

    let mut options = get_default_solver_options();
    options.branch_strategy = *branch;
    // Moved rather than cloned: the map is not used again.
    options.fixed_variables = fixed_variables;
    options.sub_problem_solver = *sup_problem_solver;
    options.verbose = 0;

    let mut solver = branchbound::BBSolver::new(qubo.clone(), options);
    solver.warm_start(guess);
    let (_, sol_value) = solver.solve();

    let actual_obj = solver.qubo.eval_usize(true_sol);
    assert!(
        (sol_value - actual_obj).abs() <= 1E-5,
        "solver returned objective {sol_value}, expected {actual_obj}"
    );
}
}