use crate::branch_node::QuboBBNode;
use crate::branchbound::BBSolver;
use crate::graph_utils::get_all_disconnected_graphs;
use crate::preprocess::preprocess_qubo;
use ndarray::Array1;
use smolprng::{JsfLarge, PRNG};
use std::collections::HashMap;
/// Rule used to choose the next variable to branch on in the QUBO branch-and-bound.
///
/// Each variant is dispatched by [`BranchStrategy::make_branch`] to the function named in its
/// doc comment.
#[derive(Copy, Clone)]
pub enum BranchStrategy {
    /// Lowest-index variable that is not fixed (`first_not_fixed`).
    FirstNotFixed,
    /// Unfixed variable whose relaxed solution value is closest to 0.5 (`most_violated`).
    MostViolated,
    /// Pseudo-random unfixed variable, seeded from the solver seed plus `nodes_visited`
    /// (`random`).
    Random,
    /// Unfixed variable maximizing the product of its diagonal-only flip estimates from
    /// `compute_strong_branch` (`worst_approximation`).
    WorstApproximation,
    /// Pairwise variant of `WorstApproximation` over off-diagonal entries between unfixed
    /// variables (`worst_approximation_second_order`).
    WorstApproximation2,
    /// Unfixed variable with the largest count of stored matrix entries whose endpoints are both
    /// unfixed (`most_edges`).
    MostEdges,
    /// Unfixed variable with the largest sum of absolute matrix entries to unfixed variables
    /// (`largest_edges`).
    LargestEdges,
    /// Variable whose two preprocessed branches fix the most variables (`most_fixed`).
    MostFixed,
    /// Strong branching: lower-bounds both children of every unfixed variable
    /// (`full_strong_branching`).
    FullStrongBranching,
    /// Strong branching restricted to the unfixed variables that score highest under the
    /// `WorstApproximation` measure (`partial_strong_branching`).
    PartialStrongBranching,
    /// Pseudo-randomly delegates to `LargestEdges`, `MostEdges` or `WorstApproximation` at each
    /// call (`round_robin`).
    RoundRobin,
    /// Unfixed variable with the largest diagonal coefficient (`largest_diag`).
    LargestDiag,
    /// Unfixed variable maximizing edge weight times `min(x, 1 - x)` of its relaxed value
    /// (`moving_edges`). The variant name's spelling is part of the public API.
    MoveingEdges,
    /// Variable whose fixing yields the most disconnected subgraphs, as counted by
    /// `get_all_disconnected_graphs` (`connected_components`).
    ConnectedComponents,
}
/// Outcome of a branching decision made by a [`BranchStrategy`].
pub struct BranchResult {
    /// Index of the variable to branch on.
    pub branch_variable: usize,
    /// Variable index -> fixed value pairs discovered while choosing `branch_variable`.
    /// Only `most_fixed` and the strong-branching strategies populate this map; every other
    /// strategy in this file returns it empty.
    pub found_fixed_vars: HashMap<usize, usize>,
}
impl BranchStrategy {
pub fn make_branch(self, bb_solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
let branch_result = match self {
Self::FirstNotFixed => first_not_fixed(bb_solver, node),
Self::MostViolated => most_violated(bb_solver, node),
Self::Random => random(bb_solver, node),
Self::WorstApproximation => worst_approximation(bb_solver, node),
Self::WorstApproximation2 => worst_approximation_second_order(bb_solver, node),
Self::MostEdges => most_edges(bb_solver, node),
Self::LargestEdges => largest_edges(bb_solver, node),
Self::MostFixed => most_fixed(bb_solver, node),
Self::FullStrongBranching => full_strong_branching(bb_solver, node),
Self::PartialStrongBranching => partial_strong_branching(bb_solver, node),
Self::RoundRobin => round_robin(bb_solver, node),
Self::LargestDiag => largest_diag(bb_solver, node),
Self::MoveingEdges => moving_edges(bb_solver, node),
Self::ConnectedComponents => connected_components(bb_solver, node),
};
assert!(
!node
.fixed_variables
.contains_key(&branch_result.branch_variable),
"Branching on a fixed variable"
);
branch_result
}
}
/// Picks the unfixed variable that, once fixed to 0, leaves the largest number of disconnected
/// subgraphs (the length of the collection returned by `get_all_disconnected_graphs`).
///
/// Ties keep the lowest index. If no candidate produces any subgraph (e.g. fixing the last
/// unfixed variable), the lowest-index unfixed variable is returned. Returning index 0 here would
/// be wrong because it may already be fixed, which would trip the assertion in `make_branch`.
fn connected_components(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let num_x = solver.qubo.num_x();
    let is_free = |k: usize| !node.fixed_variables.contains_key(&k);
    let mut selected_variable = (0..num_x).find(|&k| is_free(k)).unwrap_or(0);
    let mut max_components = 0;
    for i in (0..num_x).filter(|&k| is_free(k)) {
        // Only the 0-branch is probed. NOTE(review): this assumes the component count does not
        // depend on the fixed value — confirm against get_all_disconnected_graphs.
        let mut list_0 = node.fixed_variables.clone();
        list_0.insert(i, 0);
        let num_components = get_all_disconnected_graphs(&solver.qubo, &list_0).len();
        if num_components > max_components {
            max_components = num_components;
            selected_variable = i;
        }
    }
    BranchResult {
        branch_variable: selected_variable,
        found_fixed_vars: HashMap::new(),
    }
}
/// Picks the unfixed variable with the most stored entries of `Q` shared with unfixed variables.
///
/// Every stored entry `(i, j)` whose endpoints are both unfixed adds one to the count of `i` and
/// one to the count of `j`, so a stored diagonal entry counts twice for its variable. Ties keep
/// the lowest index. If every unfixed variable has a count of zero, the lowest-index unfixed
/// variable is returned. Index 0 is not used as the fallback because it may already be fixed.
fn most_edges(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let num_x = solver.qubo.num_x();
    let is_free = |k: usize| !node.fixed_variables.contains_key(&k);
    let mut edge_count = Array1::<usize>::zeros(num_x);
    for (_, (i, j)) in &solver.qubo.q {
        if is_free(i) && is_free(j) {
            edge_count[i] += 1;
            edge_count[j] += 1;
        }
    }
    let mut index_max_edges = (0..num_x).find(|&k| is_free(k)).unwrap_or(0);
    let mut max_edges = 0;
    for i in (0..num_x).filter(|&k| is_free(k)) {
        if edge_count[i] > max_edges {
            max_edges = edge_count[i];
            index_max_edges = i;
        }
    }
    BranchResult {
        branch_variable: index_max_edges,
        found_fixed_vars: HashMap::new(),
    }
}
/// Picks the unfixed variable with the largest total `|Q_ij|` over stored entries whose
/// endpoints are both unfixed.
///
/// Each such entry adds its absolute value to both endpoints. Ties keep the lowest index. If no
/// unfixed variable has positive weight (or all weights are NaN), the lowest-index unfixed
/// variable is returned. Index 0 is not used as the fallback because it may already be fixed.
fn largest_edges(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let num_x = solver.qubo.num_x();
    let is_free = |k: usize| !node.fixed_variables.contains_key(&k);
    let mut edge_weight = Array1::<f64>::zeros(num_x);
    for (&value, (i, j)) in &solver.qubo.q {
        if is_free(i) && is_free(j) {
            edge_weight[i] += value.abs();
            edge_weight[j] += value.abs();
        }
    }
    let mut index_max_edges = (0..num_x).find(|&k| is_free(k)).unwrap_or(0);
    let mut max_edge_weight = 0.0;
    for i in (0..num_x).filter(|&k| is_free(k)) {
        if edge_weight[i] > max_edge_weight {
            max_edge_weight = edge_weight[i];
            index_max_edges = i;
        }
    }
    BranchResult {
        branch_variable: index_max_edges,
        found_fixed_vars: HashMap::new(),
    }
}
/// Picks the variable whose two branches fix the most variables after preprocessing, and collects
/// fixings implied by both branches.
///
/// Each unfixed variable `i` is probed unless an earlier probe already implied it. The probe
/// takes the node's fixings plus every fixing found so far, extends them with `i = 0` and
/// `i = 1`, and runs both through `preprocess_qubo`. Any variable not fixed in `node` that both
/// results fix to the same value is added to `found_fixed_vars`. The score of `i` is the smaller
/// of the two fixed-set sizes.
///
/// The branch variable is the first highest-scoring probed variable that is not in
/// `found_fixed_vars`. A variable probed early can be implied by a later probe; such a variable's
/// value is already known, so it is no longer chosen. If nothing was probed, index 0 is returned,
/// as before.
pub fn most_fixed(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let mut found_fixed_vars: HashMap<usize, usize> = HashMap::new();
    // (score, variable) for every probed variable, in probe order.
    let mut probed: Vec<(usize, usize)> = Vec::new();
    for i in 0..solver.qubo.num_x() {
        if node.fixed_variables.contains_key(&i) || found_fixed_vars.contains_key(&i) {
            continue;
        }
        let mut list_0 = node.fixed_variables.clone();
        let mut list_1 = node.fixed_variables.clone();
        list_0.insert(i, 0);
        list_1.insert(i, 1);
        for (&key, &value) in &found_fixed_vars {
            list_0.insert(key, value);
            list_1.insert(key, value);
        }
        let fixed_0 = preprocess_qubo(&solver.qubo_pp_form, &list_0, true);
        let fixed_1 = preprocess_qubo(&solver.qubo_pp_form, &list_1, true);
        // A value forced in both children is forced in the current node as well.
        for (&key, &value) in &fixed_0 {
            if !node.fixed_variables.contains_key(&key) && fixed_1.get(&key) == Some(&value) {
                found_fixed_vars.insert(key, value);
            }
        }
        probed.push((fixed_0.len().min(fixed_1.len()), i));
    }
    let mut best: Option<(usize, usize)> = None;
    for &(score, i) in &probed {
        if found_fixed_vars.contains_key(&i) {
            continue;
        }
        let better = match best {
            None => true,
            Some((top, _)) => score > top,
        };
        if better {
            best = Some((score, i));
        }
    }
    BranchResult {
        branch_variable: best.map_or(0, |(_, i)| i),
        found_fixed_vars,
    }
}
/// Returns the lowest-index variable that is not fixed at `node`.
///
/// # Panics
///
/// Panics with "No variable to branch on" when every variable is fixed.
pub fn first_not_fixed(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let candidate = (0..solver.qubo.num_x()).find(|k| !node.fixed_variables.contains_key(k));
    match candidate {
        Some(branch_variable) => BranchResult {
            branch_variable,
            found_fixed_vars: HashMap::new(),
        },
        None => panic!("No variable to branch on"),
    }
}
/// Picks the unfixed variable with the largest diagonal coefficient `Q_ii`.
///
/// Ties keep the lowest index. Returns 0 if no unfixed variable has a diagonal value greater
/// than negative infinity.
pub fn largest_diag(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let (_, branch_variable) = (0..solver.qubo.num_x())
        .filter(|k| !node.fixed_variables.contains_key(k))
        .map(|k| (solver.qubo.q[[k, k]], k))
        .fold((f64::NEG_INFINITY, 0), |best, candidate| {
            if candidate.0 > best.0 {
                candidate
            } else {
                best
            }
        });
    BranchResult {
        branch_variable,
        found_fixed_vars: HashMap::new(),
    }
}
/// Picks the unfixed variable whose relaxed value is closest to 0.5 (the "most fractional" one).
///
/// Ties keep the highest index, because the comparison uses `<=`. The initial bound is infinite,
/// so an unfixed variable is always returned, even if relaxed values stray outside [0, 1]. The
/// previous sentinel of 1.0 rejected values outside [-0.5, 1.5] and fell back to index 0, which
/// may already be fixed. If all distances are NaN, the first unfixed variable is returned.
pub fn most_violated(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let num_x = solver.qubo.num_x();
    let is_free = |k: usize| !node.fixed_variables.contains_key(&k);
    let mut most_violated = f64::INFINITY;
    let mut index_most_violated = (0..num_x).find(|&k| is_free(k)).unwrap_or(0);
    for i in (0..num_x).filter(|&k| is_free(k)) {
        let violation = (node.solution[i] - 0.5).abs();
        if violation <= most_violated {
            most_violated = violation;
            index_most_violated = i;
        }
    }
    BranchResult {
        branch_variable: index_most_violated,
        found_fixed_vars: HashMap::new(),
    }
}
/// Picks a branch variable by examining pairs of unfixed variables linked by an off-diagonal
/// entry of `Q`.
///
/// For every stored off-diagonal entry `Q_ij` between two unfixed variables, the 2x2 form
/// `f(x, y) = Q_ii x^2 + Q_jj y^2 + 2 Q_ij x y` is evaluated at the displacements from the
/// relaxed solution to each of the four binary corners. The pair with the largest product of
/// absolute values wins. Of its two variables, the one with the larger
/// `|zero_flip| * |one_flip|` (from `compute_strong_branch`) is returned; on equality the second
/// index `j` is returned.
///
/// If no off-diagonal entry links two unfixed variables, the lowest-index unfixed variable is
/// returned. Index 0 is not used here because it may already be fixed.
pub fn worst_approximation_second_order(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let (zero_flip, one_flip) = compute_strong_branch(solver, node);
    let is_fixed = |k: usize| node.fixed_variables.contains_key(&k);
    let mut worst_approximation = f64::NEG_INFINITY;
    let mut index_worst_approximation = (0..solver.qubo.num_x())
        .find(|&k| !is_fixed(k))
        .unwrap_or(0);
    for (&value, (i, j)) in &solver.qubo.q {
        if i == j || is_fixed(i) || is_fixed(j) {
            continue;
        }
        let q_ii = solver.qubo.q[[i, i]];
        let q_jj = solver.qubo.q[[j, j]];
        let q_ij = value;
        let obj = |x: f64, y: f64| -> f64 { q_ii * x * x + q_jj * y * y + 2.0 * q_ij * x * y };
        let (x_i, x_j) = (node.solution[i], node.solution[j]);
        let flip_00 = obj(-x_i, -x_j);
        let flip_01 = obj(-x_i, 1.0 - x_j);
        let flip_10 = obj(1.0 - x_i, -x_j);
        let flip_11 = obj(1.0 - x_i, 1.0 - x_j);
        let pair_score = flip_00.abs() * flip_01.abs() * flip_10.abs() * flip_11.abs();
        if pair_score > worst_approximation {
            worst_approximation = pair_score;
            let i_approx = zero_flip[i].abs() * one_flip[i].abs();
            let j_approx = zero_flip[j].abs() * one_flip[j].abs();
            index_worst_approximation = if i_approx > j_approx { i } else { j };
        }
    }
    BranchResult {
        branch_variable: index_worst_approximation,
        found_fixed_vars: HashMap::new(),
    }
}
/// Strong branching over every unfixed variable, in index order.
///
/// See [`strong_branch_on_candidates`] for the evaluation and selection rules.
///
/// # Panics
///
/// Panics if every variable is already fixed.
pub fn full_strong_branching(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let unfixed_variables: Vec<usize> = (0..solver.qubo.num_x())
        .filter(|i| !node.fixed_variables.contains_key(i))
        .collect();
    strong_branch_on_candidates(solver, node, &unfixed_variables)
}

/// Strong branching restricted to the [`PARTIAL_STRONG_BRANCHING_CANDIDATES`] unfixed variables
/// with the highest `|zero_flip| * |one_flip|` score from `compute_strong_branch`.
///
/// Candidates are evaluated in descending score order; the stable sort keeps ties in index order.
/// See [`strong_branch_on_candidates`] for the evaluation and selection rules.
///
/// # Panics
///
/// Panics if every variable is already fixed.
pub fn partial_strong_branching(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let (zero_flip, one_flip) = compute_strong_branch(solver, node);
    let score = |k: usize| zero_flip[k].abs() * one_flip[k].abs();
    let mut candidates: Vec<usize> = (0..solver.qubo.num_x())
        .filter(|i| !node.fixed_variables.contains_key(i))
        .collect();
    candidates.sort_by(|&a, &b| score(b).total_cmp(&score(a)));
    candidates.truncate(PARTIAL_STRONG_BRANCHING_CANDIDATES);
    strong_branch_on_candidates(solver, node, &candidates)
}

/// Number of candidates evaluated by [`partial_strong_branching`].
const PARTIAL_STRONG_BRANCHING_CANDIDATES: usize = 25;

/// Shared strong-branching loop over `candidates` (variables unfixed in `node`, in evaluation
/// order).
///
/// Each candidate `i` is evaluated as follows, unless an earlier candidate already implied it:
/// - Build both children: the node's fixings plus every fixing found so far, with `i = 0` in one
///   child and `i = 1` in the other.
/// - Run each child through `preprocess_qubo` and lower-bound it with the solver's subproblem
///   solver.
/// - Variables that both preprocessed children fix to the same value are recorded as found
///   fixings.
/// - If one child's bound is at least `best_solution_value`, `i` is fixed to the other value. It
///   is then not scored, because branching on a variable whose value is already known is wasted
///   work.
/// - Otherwise `i` scores `min(bound_0, bound_1)`.
///
/// The first candidate with the highest score is returned. If no candidate was scored, the first
/// candidate not in the found fixings is returned, or else the first candidate.
///
/// # Panics
///
/// Panics if `candidates` is empty.
fn strong_branch_on_candidates(
    solver: &BBSolver,
    node: &QuboBBNode,
    candidates: &[usize],
) -> BranchResult {
    let first_candidate = *candidates.first().expect("No variable to branch on");
    let mut found_fixes: HashMap<usize, usize> = HashMap::new();
    let mut best_score = f64::NEG_INFINITY;
    let mut best_variable: Option<usize> = None;
    for &i in candidates {
        if found_fixes.contains_key(&i) {
            continue;
        }
        let mut base = node.fixed_variables.clone();
        base.extend(found_fixes.iter().map(|(&k, &v)| (k, v)));
        let mut list_0 = base.clone();
        list_0.insert(i, 0);
        let mut list_1 = base;
        list_1.insert(i, 1);
        let node_0 = QuboBBNode {
            lower_bound: 0.0,
            fixed_variables: preprocess_qubo(&solver.qubo_pp_form, &list_0, true),
            solution: node.solution.clone(),
        };
        let node_1 = QuboBBNode {
            lower_bound: 0.0,
            fixed_variables: preprocess_qubo(&solver.qubo_pp_form, &list_1, true),
            solution: node.solution.clone(),
        };
        let bound_0 = solver
            .subproblem_solver
            .solve_lower_bound(solver, &node_0, None)
            .0;
        let bound_1 = solver
            .subproblem_solver
            .solve_lower_bound(solver, &node_1, None)
            .0;
        // A value forced in both children is forced in the current node as well.
        for (&key, &value) in &node_0.fixed_variables {
            if !node.fixed_variables.contains_key(&key)
                && node_1.fixed_variables.get(&key) == Some(&value)
            {
                found_fixes.insert(key, value);
            }
        }
        // A child that cannot beat the incumbent forces the opposite value.
        if bound_0 >= solver.best_solution_value {
            found_fixes.insert(i, 1);
            continue;
        }
        if bound_1 >= solver.best_solution_value {
            found_fixes.insert(i, 0);
            continue;
        }
        let score = bound_0.min(bound_1);
        if score > best_score {
            best_score = score;
            best_variable = Some(i);
        }
    }
    let branch_variable = best_variable
        .or_else(|| {
            candidates
                .iter()
                .copied()
                .find(|k| !found_fixes.contains_key(k))
        })
        .unwrap_or(first_candidate);
    BranchResult {
        branch_variable,
        found_fixed_vars: found_fixes,
    }
}
/// Picks an unfixed variable pseudo-randomly.
///
/// A `JsfLarge` generator seeded with `options.seed + nodes_visited` draws a start index. The scan
/// then walks forward from it, wrapping around to 0, and returns the first unfixed variable.
///
/// # Panics
///
/// Panics with "No Variable to branch on" when every variable is fixed. It also panics
/// (remainder by zero) when the problem has no variables.
pub fn random(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let num_vars = solver.qubo.num_x();
    let seed = solver.options.seed as u64 + solver.nodes_visited as u64;
    let mut rng = PRNG {
        generator: JsfLarge::from(seed),
    };
    let start = usize::try_from(rng.gen_u64() % num_vars as u64).unwrap();
    let pick = (start..num_vars)
        .chain(0..start)
        .find(|k| !node.fixed_variables.contains_key(k));
    match pick {
        Some(branch_variable) => BranchResult {
            branch_variable,
            found_fixed_vars: HashMap::new(),
        },
        None => panic!("No Variable to branch on"),
    }
}
/// Picks the unfixed variable maximizing `|zero_flip[i]| * |one_flip[i]|`, where both flip
/// estimates come from `compute_strong_branch`.
///
/// Ties keep the lowest index. Returns 0 if no unfixed variable has a score greater than
/// negative infinity.
pub fn worst_approximation(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let (zero_flip, one_flip) = compute_strong_branch(solver, node);
    let (_, branch_variable) = (0..solver.qubo.num_x())
        .filter(|k| !node.fixed_variables.contains_key(k))
        .fold((f64::NEG_INFINITY, 0), |(top, top_index), k| {
            let gain = zero_flip[k].abs() * one_flip[k].abs();
            if gain > top {
                (gain, k)
            } else {
                (top, top_index)
            }
        });
    BranchResult {
        branch_variable,
        found_fixed_vars: HashMap::new(),
    }
}
/// Diagonal-only estimates of moving each variable from its relaxed value to 0 and to 1.
///
/// For each variable `k` with relaxed value `x`, the first array holds `Q_kk * x * x` and the
/// second holds `Q_kk * (1 - x) * (1 - x)`. Both arrays cover every variable, fixed or not.
pub fn compute_strong_branch(solver: &BBSolver, node: &QuboBBNode) -> (Array1<f64>, Array1<f64>) {
    let n = solver.qubo.num_x();
    let diag = |k: usize| -> f64 { solver.qubo.q[[k, k]] };
    let zero_result = Array1::from_shape_fn(n, |k| {
        let x = node.solution[k];
        diag(k) * x * x
    });
    let one_result = Array1::from_shape_fn(n, |k| {
        let gap = 1.0 - node.solution[k];
        diag(k) * gap * gap
    });
    (zero_result, one_result)
}
/// Picks the unfixed variable maximizing `edge_weight * min(x, 1 - x)`.
///
/// `edge_weight` is the sum of `|Q_ij|` over stored entries whose endpoints are both unfixed.
/// `x` is the variable's relaxed value, and `min(x, 1 - x)` is its distance to the nearest
/// binary value when `x` lies in [0, 1]. Ties keep the lowest index.
///
/// The search starts from negative infinity and falls back to the lowest-index unfixed variable.
/// The old sentinel of -10 could reject every score when a relaxed value fell outside [0, 1],
/// which returned index 0 even if it was already fixed.
pub fn moving_edges(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let num_x = solver.qubo.num_x();
    let is_free = |k: usize| !node.fixed_variables.contains_key(&k);
    let mut edge_size = Array1::<f64>::zeros(num_x);
    for (&value, (i, j)) in &solver.qubo.q {
        if is_free(i) && is_free(j) {
            edge_size[i] += value.abs();
            edge_size[j] += value.abs();
        }
    }
    let mut best_score = f64::NEG_INFINITY;
    let mut index_max_edges = (0..num_x).find(|&k| is_free(k)).unwrap_or(0);
    for i in (0..num_x).filter(|&k| is_free(k)) {
        let movement = node.solution[i].min(1.0 - node.solution[i]);
        let score = edge_size[i] * movement;
        if score > best_score {
            best_score = score;
            index_max_edges = i;
        }
    }
    BranchResult {
        branch_variable: index_max_edges,
        found_fixed_vars: HashMap::new(),
    }
}
/// Delegates pseudo-randomly to `largest_edges`, `most_edges` or `worst_approximation`.
///
/// The generator seed is the sum of the node's fixed-variable indices plus
/// `options.seed + nodes_solved`. The same node at the same solver state therefore always
/// picks the same delegate.
pub fn round_robin(solver: &BBSolver, node: &QuboBBNode) -> BranchResult {
    let fixed_index_sum = node.fixed_variables.keys().sum::<usize>() as u64;
    let solver_offset = solver.options.seed as u64 + solver.nodes_solved as u64;
    let mut rng = PRNG {
        generator: JsfLarge::from(fixed_index_sum + solver_offset),
    };
    let choice = rng.gen_u64() % 3;
    if choice == 0 {
        largest_edges(solver, node)
    } else if choice == 1 {
        most_edges(solver, node)
    } else if choice == 2 {
        worst_approximation(solver, node)
    } else {
        panic!("Random branch selection failed")
    }
}