use crate::DType;
use super::{TensorLinProgResult, TensorLinearConstraints, simplex_impl};
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::linprog::LinProgOptions;
use numr::ops::{CompareOps, ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
#[derive(Debug, Clone)]
pub struct TensorMilpResult<R: Runtime<DType = DType>> {
pub x: Tensor<R>,
pub fun: f64,
pub success: bool,
pub nodes: usize,
pub gap: f64,
pub message: String,
}
#[derive(Debug, Clone)]
pub struct MilpOptionsInternal {
pub max_nodes: usize,
pub int_tol: f64,
pub gap_tol: f64,
pub lp_options: LinProgOptions,
}
pub fn milp_impl<R, C>(
client: &C,
c: &Tensor<R>,
constraints: &TensorLinearConstraints<R>,
integrality: &Tensor<R>,
options: &MilpOptionsInternal,
) -> OptimizeResult<TensorMilpResult<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
let n = c.shape()[0];
if n == 0 {
return Err(OptimizeError::InvalidInput {
context: "milp: empty objective vector".to_string(),
});
}
let int_data: Vec<f64> = integrality.to_vec();
if int_data.len() != n {
return Err(OptimizeError::InvalidInput {
context: format!(
"milp: integrality has {} elements, expected {}",
int_data.len(),
n
),
});
}
let (base_lower, base_upper) = match (&constraints.lower_bounds, &constraints.upper_bounds) {
(Some(lb), Some(ub)) => (lb.to_vec(), ub.to_vec()),
_ => (vec![0.0; n], vec![f64::INFINITY; n]),
};
struct BnBNode {
lower: Vec<f64>,
upper: Vec<f64>,
lower_bound: f64,
}
let mut stack: Vec<BnBNode> = vec![BnBNode {
lower: base_lower,
upper: base_upper,
lower_bound: f64::NEG_INFINITY,
}];
let mut best_solution: Option<Vec<f64>> = None;
let mut best_objective = f64::INFINITY;
let mut nodes_explored = 0;
let device = c.device();
while let Some(node) = stack.pop() {
nodes_explored += 1;
if nodes_explored > options.max_nodes {
break;
}
if node.lower_bound >= best_objective - options.gap_tol {
continue;
}
let node_lower = Tensor::<R>::from_slice(&node.lower, &[n], device);
let node_upper = Tensor::<R>::from_slice(&node.upper, &[n], device);
let node_constraints = TensorLinearConstraints {
a_ub: constraints.a_ub.clone(),
b_ub: constraints.b_ub.clone(),
a_eq: constraints.a_eq.clone(),
b_eq: constraints.b_eq.clone(),
lower_bounds: Some(node_lower),
upper_bounds: Some(node_upper),
};
let lp_result: TensorLinProgResult<R> =
match simplex_impl(client, c, &node_constraints, &options.lp_options) {
Ok(r) => r,
Err(_) => continue,
};
if !lp_result.success {
continue;
}
if lp_result.fun >= best_objective - options.gap_tol {
continue;
}
let x_data: Vec<f64> = lp_result.x.to_vec();
let mut is_integer_feasible = true;
let mut branch_var = None;
let mut max_fractionality = 0.0;
for (i, (&is_int, &xi)) in int_data.iter().zip(x_data.iter()).enumerate() {
if is_int > 0.5 {
let frac = xi - xi.floor();
let fractionality = frac.min(1.0 - frac);
if fractionality > options.int_tol {
is_integer_feasible = false;
if fractionality > max_fractionality {
max_fractionality = fractionality;
branch_var = Some(i);
}
}
}
}
if is_integer_feasible {
if lp_result.fun < best_objective {
best_objective = lp_result.fun;
best_solution = Some(x_data);
}
continue;
}
if let Some(var) = branch_var {
let xi = x_data[var];
let floor_val = xi.floor();
let ceil_val = xi.ceil();
let mut left_upper = node.upper.clone();
left_upper[var] = left_upper[var].min(floor_val);
if node.lower[var] <= left_upper[var] {
stack.push(BnBNode {
lower: node.lower.clone(),
upper: left_upper,
lower_bound: lp_result.fun,
});
}
let mut right_lower = node.lower.clone();
right_lower[var] = right_lower[var].max(ceil_val);
if right_lower[var] <= node.upper[var] {
stack.push(BnBNode {
lower: right_lower,
upper: node.upper,
lower_bound: lp_result.fun,
});
}
}
}
match best_solution {
Some(x) => {
let x_rounded: Vec<f64> = x
.iter()
.zip(int_data.iter())
.map(|(&xi, &is_int)| if is_int > 0.5 { xi.round() } else { xi })
.collect();
let c_data: Vec<f64> = c.to_vec();
let fun: f64 = x_rounded
.iter()
.zip(c_data.iter())
.map(|(&xi, &ci)| xi * ci)
.sum();
let gap = if best_objective.abs() > 1e-10 {
(best_objective - fun).abs() / best_objective.abs()
} else {
0.0
};
let x_tensor = Tensor::<R>::from_slice(&x_rounded, &[n], device);
Ok(TensorMilpResult {
x: x_tensor,
fun,
success: true,
nodes: nodes_explored,
gap,
message: "Optimal solution found".to_string(),
})
}
None => {
let zeros = vec![0.0; n];
let x_tensor = Tensor::<R>::from_slice(&zeros, &[n], device);
Ok(TensorMilpResult {
x: x_tensor,
fun: f64::INFINITY,
success: false,
nodes: nodes_explored,
gap: f64::INFINITY,
message: if nodes_explored >= options.max_nodes {
"Maximum nodes reached without finding feasible solution".to_string()
} else {
"No feasible integer solution found".to_string()
},
})
}
}
}