use std::cmp::Ordering;
use std::collections::BinaryHeap;
use numra_core::Scalar;
use crate::error::OptimError;
use crate::lp::{simplex_solve, LPOptions};
use crate::problem::VarType;
use crate::types::{OptimResult, OptimStatus};
/// Tuning knobs for the branch-and-bound search run by `milp_solve`.
#[derive(Clone, Debug)]
pub struct MILPOptions<S: Scalar> {
    /// Upper limit on how many nodes the search takes from its queue before stopping.
    pub max_nodes: usize,
    /// An integer variable counts as integral when it lies strictly within
    /// `int_tol` of the nearest integer.
    pub int_tol: S,
    /// Tolerance forwarded to the LP solver for every relaxation; also the slack
    /// allowed before crossed variable bounds are declared infeasible.
    pub lp_tol: S,
    /// Pruning margin: a node is discarded once its bound is no better than the
    /// incumbent objective minus `gap_tol`.
    pub gap_tol: S,
    /// When `true`, every new incumbent is reported on standard error.
    pub verbose: bool,
}

impl<S: Scalar> Default for MILPOptions<S> {
    /// Quiet search with a budget of 100 000 nodes and tolerances
    /// `int_tol = 1e-6`, `lp_tol = 1e-10`, `gap_tol = 1e-8`.
    fn default() -> Self {
        let tol = S::from_f64;
        Self {
            verbose: false,
            gap_tol: tol(1e-8),
            lp_tol: tol(1e-10),
            int_tol: tol(1e-6),
            max_nodes: 100_000,
        }
    }
}
/// One subproblem of the branch-and-bound tree: the original problem restricted
/// to the per-variable box `[lb, ub]`.
#[derive(Clone, Debug)]
struct BbNode<S: Scalar> {
/// Lower bound of every variable in this subproblem.
lb: Vec<S>,
/// Upper bound of every variable in this subproblem.
ub: Vec<S>,
/// LP objective inherited when the node was created (the root's own relaxation
/// for the root, the parent's relaxation for children); used to order and prune nodes.
lp_bound: S,
/// Distance from the root (0 for the root, parent depth + 1 for children).
depth: usize,
}
/// Newtype that orders [`BbNode`]s for best-first search.
///
/// `BinaryHeap` is a max-heap, so the comparison is reversed: the node with the
/// *smallest* `lp_bound` compares as the greatest and is popped first.
#[derive(Clone, Debug)]
struct OrderedNode<S: Scalar>(BbNode<S>);

impl<S: Scalar> PartialEq for OrderedNode<S> {
    /// Equality is defined through [`Ord::cmp`] so that `==` and the ordering can
    /// never disagree (comparing the bounds with `==` directly would make a node
    /// with an incomparable bound, e.g. NaN, unequal to itself while `cmp`
    /// reports `Equal`).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<S: Scalar> Eq for OrderedNode<S> {}

impl<S: Scalar> PartialOrd for OrderedNode<S> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<S: Scalar> Ord for OrderedNode<S> {
    /// Reversed comparison of `lp_bound` (`other` against `self`); bounds that
    /// cannot be compared are treated as equal.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .0
            .lp_bound
            .partial_cmp(&self.0.lp_bound)
            .unwrap_or(Ordering::Equal)
    }
}
#[allow(clippy::too_many_arguments)]
pub fn milp_solve<
S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
>(
c: &[S],
a_ineq: &[Vec<S>],
b_ineq: &[S],
a_eq: &[Vec<S>],
b_eq: &[S],
var_types: &[VarType],
bounds: &[Option<(S, S)>],
opts: &MILPOptions<S>,
) -> Result<OptimResult<S>, OptimError> {
let start = std::time::Instant::now();
let n = c.len();
if var_types.len() != n || bounds.len() != n {
return Err(OptimError::DimensionMismatch {
expected: n,
actual: if var_types.len() > bounds.len() {
var_types.len()
} else {
bounds.len()
},
});
}
let int_indices: Vec<usize> = (0..n)
.filter(|&i| var_types[i] == VarType::Integer || var_types[i] == VarType::Binary)
.collect();
let mut init_lb = vec![S::ZERO; n];
let mut init_ub = vec![S::INFINITY; n];
for i in 0..n {
if let Some((lo, hi)) = bounds[i] {
init_lb[i] = if lo > S::ZERO { lo } else { S::ZERO }; init_ub[i] = hi;
}
if var_types[i] == VarType::Binary {
init_lb[i] = S::ZERO;
init_ub[i] = S::ONE;
}
}
if int_indices.is_empty() {
let result = solve_lp_relaxation(c, a_ineq, b_ineq, a_eq, b_eq, &init_lb, &init_ub, opts)?;
let mut res = result;
res.wall_time_secs = start.elapsed().as_secs_f64();
return Ok(res);
}
let root_result =
match solve_lp_relaxation(c, a_ineq, b_ineq, a_eq, b_eq, &init_lb, &init_ub, opts) {
Ok(r) => r,
Err(OptimError::LPInfeasible) => return Err(OptimError::MILPInfeasible),
Err(e) => return Err(e),
};
if is_integer_feasible(&root_result.x, &int_indices, opts.int_tol) {
let mut res = root_result;
res.message = "Optimal (LP relaxation is integer feasible)".into();
res.wall_time_secs = start.elapsed().as_secs_f64();
return Ok(res);
}
let root_node = BbNode {
lb: init_lb,
ub: init_ub,
lp_bound: root_result.f,
depth: 0,
};
let mut heap = BinaryHeap::new();
heap.push(OrderedNode(root_node));
let mut best_obj = S::INFINITY;
let mut best_x: Option<Vec<S>> = None;
let mut nodes_explored: usize = 0;
while let Some(OrderedNode(node)) = heap.pop() {
nodes_explored += 1;
if nodes_explored > opts.max_nodes {
break;
}
if node.lp_bound >= best_obj - opts.gap_tol {
continue;
}
let lp_result =
match solve_lp_relaxation(c, a_ineq, b_ineq, a_eq, b_eq, &node.lb, &node.ub, opts) {
Ok(r) => r,
Err(OptimError::LPInfeasible) => continue, Err(_) => continue, };
if lp_result.f >= best_obj - opts.gap_tol {
continue;
}
if is_integer_feasible(&lp_result.x, &int_indices, opts.int_tol) {
if lp_result.f < best_obj {
best_obj = lp_result.f;
let mut x_rounded = lp_result.x.clone();
for &i in &int_indices {
x_rounded[i] = x_rounded[i].round();
}
best_x = Some(x_rounded);
if opts.verbose {
eprintln!(
"MILP: new incumbent obj={:.6} at node {}",
best_obj.to_f64(),
nodes_explored
);
}
}
continue;
}
let branch_var = select_branching_variable(&lp_result.x, &int_indices, opts.int_tol);
if let Some(bvar) = branch_var {
let val = lp_result.x[bvar];
let floor_val = val.floor();
let ceil_val = val.ceil();
if floor_val >= node.lb[bvar] {
let mut left_ub = node.ub.clone();
left_ub[bvar] = floor_val;
let left_node = BbNode {
lb: node.lb.clone(),
ub: left_ub,
lp_bound: lp_result.f, depth: node.depth + 1,
};
heap.push(OrderedNode(left_node));
}
if ceil_val <= node.ub[bvar] {
let mut right_lb = node.lb.clone();
right_lb[bvar] = ceil_val;
let right_node = BbNode {
lb: right_lb,
ub: node.ub.clone(),
lp_bound: lp_result.f,
depth: node.depth + 1,
};
heap.push(OrderedNode(right_node));
}
}
}
match best_x {
Some(x) => {
let f_val: S = c.iter().zip(x.iter()).map(|(&ci, &xi)| ci * xi).sum();
Ok(OptimResult {
x,
f: f_val,
grad: c.to_vec(),
iterations: nodes_explored,
n_feval: nodes_explored,
n_geval: 0,
converged: true,
message: format!(
"Optimal integer solution found after {} nodes",
nodes_explored
),
status: OptimStatus::FunctionConverged,
history: Vec::new(),
lambda_eq: Vec::new(),
lambda_ineq: Vec::new(),
active_bounds: Vec::new(),
constraint_violation: S::ZERO,
wall_time_secs: start.elapsed().as_secs_f64(),
pareto: None,
sensitivity: None,
})
}
None => Err(OptimError::MILPInfeasible),
}
}
/// Reports whether every variable listed in `int_indices` sits strictly within
/// `tol` of its nearest integer in `x`.
///
/// An empty `int_indices` yields `true`. Indices must be valid for `x`.
fn is_integer_feasible<S: Scalar>(x: &[S], int_indices: &[usize], tol: S) -> bool {
    for &idx in int_indices {
        let value = x[idx];
        let off_by = (value - value.round()).abs();
        let within = off_by < tol;
        if !within {
            return false;
        }
    }
    true
}
/// Chooses the integer variable to branch on: the one whose fractional part is
/// closest to one half.
///
/// Variables whose fractional part is below `tol` or above `1 - tol` are treated
/// as already integral and skipped. On a tie the earliest entry of
/// `int_indices` wins. Returns `None` when no candidate remains.
fn select_branching_variable<S: Scalar>(x: &[S], int_indices: &[usize], tol: S) -> Option<usize> {
    let half = S::from_f64(0.5);
    let near_one = S::ONE - tol;

    let (winner, _) = int_indices.iter().fold(
        (None::<usize>, S::ZERO),
        |(leader, leader_score), &idx| {
            let frac = x[idx] - x[idx].floor();
            let already_integral = frac < tol || frac > near_one;
            // Score is 0.5 at a fractional part of exactly one half and falls
            // towards 0 as the value approaches an integer.
            let score = half - (frac - half).abs();
            if !already_integral && score > leader_score {
                (Some(idx), score)
            } else {
                (leader, leader_score)
            }
        },
    );
    winner
}
/// Solves the LP relaxation of the problem restricted to the box `[lb, ub]`.
///
/// The variables are shifted by their lower bounds (`x = x' + lb`), so every
/// right-hand side is reduced by `row · lb`, and each finite upper bound becomes
/// an extra row `x'_i <= ub_i - lb_i`. The shifted problem goes to
/// `simplex_solve`; the solution is then shifted back and the objective is
/// recomputed as `c · x` in the original variables.
/// (Presumably the shift is needed because `simplex_solve` works on
/// non-negative variables; confirm against its documentation.)
///
/// `lb` and `ub` must have one entry per variable, and `b_ineq` / `b_eq` one
/// entry per constraint row.
///
/// # Errors
///
/// Returns `OptimError::LPInfeasible` when some finite upper bound lies more
/// than `opts.lp_tol` below its lower bound; otherwise any error from
/// `simplex_solve` is passed through.
#[allow(clippy::too_many_arguments)]
fn solve_lp_relaxation<
    S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
>(
    c: &[S],
    a_ineq: &[Vec<S>],
    b_ineq: &[S],
    a_eq: &[Vec<S>],
    b_eq: &[S],
    lb: &[S],
    ub: &[S],
    opts: &MILPOptions<S>,
) -> Result<OptimResult<S>, OptimError> {
    let n = c.len();

    // Inequalities: same rows, right-hand side moved by `row · lb`.
    let mut new_a_ineq: Vec<Vec<S>> = Vec::with_capacity(a_ineq.len() + n);
    let mut new_b_ineq: Vec<S> = Vec::with_capacity(b_ineq.len() + n);
    for (i, row) in a_ineq.iter().enumerate() {
        let shift: S = row.iter().zip(lb.iter()).map(|(&a, &l)| a * l).sum();
        new_a_ineq.push(row.clone());
        new_b_ineq.push(b_ineq[i] - shift);
    }

    // Equalities: same treatment.
    let new_a_eq: Vec<Vec<S>> = a_eq.to_vec();
    let mut new_b_eq: Vec<S> = Vec::with_capacity(b_eq.len());
    for (i, row) in a_eq.iter().enumerate() {
        let shift: S = row.iter().zip(lb.iter()).map(|(&a, &l)| a * l).sum();
        new_b_eq.push(b_eq[i] - shift);
    }

    // Finite upper bounds become explicit rows on the shifted variables.
    for i in 0..n {
        if ub[i].is_finite() {
            let effective_ub = ub[i] - lb[i];
            if effective_ub < -opts.lp_tol {
                return Err(OptimError::LPInfeasible);
            }
            let mut row = vec![S::ZERO; n];
            row[i] = S::ONE;
            new_a_ineq.push(row);
            // A gap in [-lp_tol, 0) has just been accepted as feasible, so it must
            // not reach the solver as a negative right-hand side.
            new_b_ineq.push(if effective_ub < S::ZERO {
                S::ZERO
            } else {
                effective_ub
            });
        }
    }

    let lp_opts = LPOptions {
        max_iter: 10_000,
        tol: opts.lp_tol,
        verbose: false,
    };
    let mut result = simplex_solve(c, &new_a_ineq, &new_b_ineq, &new_a_eq, &new_b_eq, &lp_opts)?;

    // Undo the shift, then evaluate the objective at the original-space point.
    for (xi, &lbi) in result.x.iter_mut().zip(lb.iter()) {
        *xi += lbi;
    }
    result.f = c
        .iter()
        .zip(result.x.iter())
        .map(|(&ci, &xi)| ci * xi)
        .sum();
    Ok(result)
}
// Unit tests for `milp_solve`: each test builds a small problem from literals and
// checks the returned objective, integrality and constraint satisfaction.
#[cfg(test)]
mod tests {
use super::*;
// Solver options used by every test.
fn default_opts() -> MILPOptions<f64> {
MILPOptions::default()
}
// No integer variables: minimize -x0 - x1 with x0 + x1 <= 4, x0 <= 3, x1 <= 3.
// Expects the objective to be -4.
#[test]
fn test_milp_no_integer_vars() {
let c = vec![-1.0, -1.0];
let a_ineq = vec![vec![1.0, 1.0], vec![1.0, 0.0], vec![0.0, 1.0]];
let b_ineq = vec![4.0, 3.0, 3.0];
let var_types = vec![VarType::Continuous, VarType::Continuous];
let bounds = vec![None, None];
let opts = default_opts();
let result =
milp_solve(&c, &a_ineq, &b_ineq, &[], &[], &var_types, &bounds, &opts).unwrap();
assert!(result.converged);
assert!(
(result.f - (-4.0)).abs() < 1e-6,
"expected f ~ -4.0, got {}",
result.f
);
}
// Two integer variables: minimize -3*x0 - 5*x1 with x0 + 2*x1 <= 6 and 2*x0 + x1 <= 8.
// Checks an objective of at most -15, integrality, and both constraints.
#[test]
fn test_milp_all_integer() {
let c = vec![-3.0, -5.0];
let a_ineq = vec![vec![1.0, 2.0], vec![2.0, 1.0]];
let b_ineq = vec![6.0, 8.0];
let var_types = vec![VarType::Integer, VarType::Integer];
let bounds = vec![None, None];
let opts = default_opts();
let result =
milp_solve(&c, &a_ineq, &b_ineq, &[], &[], &var_types, &bounds, &opts).unwrap();
assert!(result.converged);
assert!(
result.f <= -15.0 + 1e-6,
"expected obj <= -15, got {}",
result.f
);
for (i, xi) in result.x.iter().enumerate() {
assert!(
(xi - xi.round()).abs() < 1e-6,
"x[{}]={} is not integer",
i,
xi
);
}
let lhs1: f64 = result.x[0] + 2.0 * result.x[1];
let lhs2: f64 = 2.0 * result.x[0] + result.x[1];
assert!(lhs1 <= 6.0 + 1e-6, "constraint 1 violated: {}", lhs1);
assert!(lhs2 <= 8.0 + 1e-6, "constraint 2 violated: {}", lhs2);
}
// 0/1 knapsack written as a minimization: values 6, 5, 4 (negated), weights 3, 4, 2,
// capacity 7. Checks an objective of at most -10, binary values, and the capacity.
#[test]
fn test_milp_binary_knapsack() {
let c = vec![-6.0, -5.0, -4.0];
let a_ineq = vec![vec![3.0, 4.0, 2.0]];
let b_ineq = vec![7.0];
let var_types = vec![VarType::Binary, VarType::Binary, VarType::Binary];
let bounds = vec![Some((0.0, 1.0)), Some((0.0, 1.0)), Some((0.0, 1.0))];
let opts = default_opts();
let result =
milp_solve(&c, &a_ineq, &b_ineq, &[], &[], &var_types, &bounds, &opts).unwrap();
assert!(result.converged);
assert!(
result.f <= -10.0 + 1e-6,
"expected obj <= -10, got {}",
result.f
);
for (i, xi) in result.x.iter().enumerate() {
assert!(
(xi - 0.0).abs() < 1e-6 || (xi - 1.0).abs() < 1e-6,
"x[{}]={} is not binary",
i,
xi
);
}
let weight: f64 = 3.0 * result.x[0] + 4.0 * result.x[1] + 2.0 * result.x[2];
assert!(
weight <= 7.0 + 1e-6,
"knapsack capacity violated: {}",
weight
);
}
// One integer and one continuous variable: minimize -x0 - x1 with x0 + x1 <= 3.5.
// Checks that x0 is integral and the objective is at most -3.4.
#[test]
fn test_milp_mixed_integer() {
let c = vec![-1.0, -1.0];
let a_ineq = vec![vec![1.0, 1.0]];
let b_ineq = vec![3.5];
let var_types = vec![VarType::Integer, VarType::Continuous];
let bounds = vec![None, None];
let opts = default_opts();
let result =
milp_solve(&c, &a_ineq, &b_ineq, &[], &[], &var_types, &bounds, &opts).unwrap();
assert!(result.converged);
assert!(
(result.x[0] - result.x[0].round()).abs() < 1e-6,
"x[0]={} is not integer",
result.x[0]
);
assert!(result.f <= -3.4, "expected obj <= -3.4, got {}", result.f);
}
// Contradictory equalities (x0 + x1 = 1 and x0 + x1 = 2): the solve must fail with
// either MILPInfeasible or LPInfeasible.
#[test]
fn test_milp_infeasible() {
let c = vec![-1.0, -1.0];
let a_eq = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
let b_eq = vec![1.0, 2.0];
let var_types = vec![VarType::Integer, VarType::Integer];
let bounds = vec![None, None];
let opts = default_opts();
let result = milp_solve(&c, &[], &[], &a_eq, &b_eq, &var_types, &bounds, &opts);
assert!(result.is_err(), "expected infeasible, got {:?}", result);
match result.unwrap_err() {
OptimError::MILPInfeasible | OptimError::LPInfeasible => {}
e => panic!("expected MILPInfeasible or LPInfeasible, got {:?}", e),
}
}
// Single equality x0 + x1 = 3 with objective -x0 - 2*x1 over integers.
// Expects x = (0, 3) and an objective of -6.
#[test]
fn test_milp_with_equality() {
let c = vec![-1.0, -2.0];
let a_eq = vec![vec![1.0, 1.0]];
let b_eq = vec![3.0];
let var_types = vec![VarType::Integer, VarType::Integer];
let bounds = vec![None, None];
let opts = default_opts();
let result = milp_solve(&c, &[], &[], &a_eq, &b_eq, &var_types, &bounds, &opts).unwrap();
assert!(result.converged);
assert!(
(result.f - (-6.0)).abs() < 1e-6,
"expected obj=-6, got {}",
result.f
);
assert!(
(result.x[0] - 0.0).abs() < 1e-6,
"expected x[0]=0, got {}",
result.x[0]
);
assert!(
(result.x[1] - 3.0).abs() < 1e-6,
"expected x[1]=3, got {}",
result.x[1]
);
}
}