1use std::cmp::Ordering;
13use std::collections::BinaryHeap;
14
15use numra_core::Scalar;
16
17use crate::error::OptimError;
18use crate::lp::{simplex_solve, LPOptions};
19use crate::problem::VarType;
20use crate::types::{OptimResult, OptimStatus};
21
22#[derive(Clone, Debug)]
24pub struct MILPOptions<S: Scalar> {
25 pub max_nodes: usize,
27 pub int_tol: S,
29 pub lp_tol: S,
31 pub gap_tol: S,
33 pub verbose: bool,
35}
36
37impl<S: Scalar> Default for MILPOptions<S> {
38 fn default() -> Self {
39 Self {
40 max_nodes: 100_000,
41 int_tol: S::from_f64(1e-6),
42 lp_tol: S::from_f64(1e-10),
43 gap_tol: S::from_f64(1e-8),
44 verbose: false,
45 }
46 }
47}
48
49#[derive(Clone, Debug)]
51struct BbNode<S: Scalar> {
52 lb: Vec<S>,
54 ub: Vec<S>,
56 lp_bound: S,
58 depth: usize,
60}
61
62#[derive(Clone, Debug)]
65struct OrderedNode<S: Scalar>(BbNode<S>);
66
67impl<S: Scalar> PartialEq for OrderedNode<S> {
68 fn eq(&self, other: &Self) -> bool {
69 self.0.lp_bound == other.0.lp_bound
70 }
71}
72
73impl<S: Scalar> Eq for OrderedNode<S> {}
74
75impl<S: Scalar> PartialOrd for OrderedNode<S> {
76 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
77 Some(self.cmp(other))
78 }
79}
80
81impl<S: Scalar> Ord for OrderedNode<S> {
82 fn cmp(&self, other: &Self) -> Ordering {
83 other
85 .0
86 .lp_bound
87 .partial_cmp(&self.0.lp_bound)
88 .unwrap_or(Ordering::Equal)
89 }
90}
91
92#[allow(clippy::too_many_arguments)]
110pub fn milp_solve<
111 S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
112>(
113 c: &[S],
114 a_ineq: &[Vec<S>],
115 b_ineq: &[S],
116 a_eq: &[Vec<S>],
117 b_eq: &[S],
118 var_types: &[VarType],
119 bounds: &[Option<(S, S)>],
120 opts: &MILPOptions<S>,
121) -> Result<OptimResult<S>, OptimError> {
122 let start = std::time::Instant::now();
123 let n = c.len();
124
125 if var_types.len() != n || bounds.len() != n {
126 return Err(OptimError::DimensionMismatch {
127 expected: n,
128 actual: if var_types.len() > bounds.len() {
129 var_types.len()
130 } else {
131 bounds.len()
132 },
133 });
134 }
135
136 let int_indices: Vec<usize> = (0..n)
138 .filter(|&i| var_types[i] == VarType::Integer || var_types[i] == VarType::Binary)
139 .collect();
140
141 let mut init_lb = vec![S::ZERO; n];
143 let mut init_ub = vec![S::INFINITY; n];
144
145 for i in 0..n {
146 if let Some((lo, hi)) = bounds[i] {
147 init_lb[i] = if lo > S::ZERO { lo } else { S::ZERO }; init_ub[i] = hi;
149 }
150 if var_types[i] == VarType::Binary {
152 init_lb[i] = S::ZERO;
153 init_ub[i] = S::ONE;
154 }
155 }
156
157 if int_indices.is_empty() {
159 let result = solve_lp_relaxation(c, a_ineq, b_ineq, a_eq, b_eq, &init_lb, &init_ub, opts)?;
160 let mut res = result;
161 res.wall_time_secs = start.elapsed().as_secs_f64();
162 return Ok(res);
163 }
164
165 let root_result =
167 match solve_lp_relaxation(c, a_ineq, b_ineq, a_eq, b_eq, &init_lb, &init_ub, opts) {
168 Ok(r) => r,
169 Err(OptimError::LPInfeasible) => return Err(OptimError::MILPInfeasible),
170 Err(e) => return Err(e),
171 };
172
173 if is_integer_feasible(&root_result.x, &int_indices, opts.int_tol) {
175 let mut res = root_result;
176 res.message = "Optimal (LP relaxation is integer feasible)".into();
177 res.wall_time_secs = start.elapsed().as_secs_f64();
178 return Ok(res);
179 }
180
181 let root_node = BbNode {
183 lb: init_lb,
184 ub: init_ub,
185 lp_bound: root_result.f,
186 depth: 0,
187 };
188
189 let mut heap = BinaryHeap::new();
190 heap.push(OrderedNode(root_node));
191
192 let mut best_obj = S::INFINITY;
193 let mut best_x: Option<Vec<S>> = None;
194 let mut nodes_explored: usize = 0;
195
196 while let Some(OrderedNode(node)) = heap.pop() {
197 nodes_explored += 1;
198
199 if nodes_explored > opts.max_nodes {
200 break;
201 }
202
203 if node.lp_bound >= best_obj - opts.gap_tol {
205 continue;
206 }
207
208 let lp_result =
210 match solve_lp_relaxation(c, a_ineq, b_ineq, a_eq, b_eq, &node.lb, &node.ub, opts) {
211 Ok(r) => r,
212 Err(OptimError::LPInfeasible) => continue, Err(_) => continue, };
215
216 if lp_result.f >= best_obj - opts.gap_tol {
218 continue;
219 }
220
221 if is_integer_feasible(&lp_result.x, &int_indices, opts.int_tol) {
223 if lp_result.f < best_obj {
224 best_obj = lp_result.f;
225 let mut x_rounded = lp_result.x.clone();
227 for &i in &int_indices {
228 x_rounded[i] = x_rounded[i].round();
229 }
230 best_x = Some(x_rounded);
231 if opts.verbose {
232 eprintln!(
233 "MILP: new incumbent obj={:.6} at node {}",
234 best_obj.to_f64(),
235 nodes_explored
236 );
237 }
238 }
239 continue;
240 }
241
242 let branch_var = select_branching_variable(&lp_result.x, &int_indices, opts.int_tol);
244 if let Some(bvar) = branch_var {
245 let val = lp_result.x[bvar];
246 let floor_val = val.floor();
247 let ceil_val = val.ceil();
248
249 if floor_val >= node.lb[bvar] {
251 let mut left_ub = node.ub.clone();
252 left_ub[bvar] = floor_val;
253 let left_node = BbNode {
254 lb: node.lb.clone(),
255 ub: left_ub,
256 lp_bound: lp_result.f, depth: node.depth + 1,
258 };
259 heap.push(OrderedNode(left_node));
260 }
261
262 if ceil_val <= node.ub[bvar] {
264 let mut right_lb = node.lb.clone();
265 right_lb[bvar] = ceil_val;
266 let right_node = BbNode {
267 lb: right_lb,
268 ub: node.ub.clone(),
269 lp_bound: lp_result.f,
270 depth: node.depth + 1,
271 };
272 heap.push(OrderedNode(right_node));
273 }
274 }
275 }
276
277 match best_x {
278 Some(x) => {
279 let f_val: S = c.iter().zip(x.iter()).map(|(&ci, &xi)| ci * xi).sum();
280 Ok(OptimResult {
281 x,
282 f: f_val,
283 grad: c.to_vec(),
284 iterations: nodes_explored,
285 n_feval: nodes_explored,
286 n_geval: 0,
287 converged: true,
288 message: format!(
289 "Optimal integer solution found after {} nodes",
290 nodes_explored
291 ),
292 status: OptimStatus::FunctionConverged,
293 history: Vec::new(),
294 lambda_eq: Vec::new(),
295 lambda_ineq: Vec::new(),
296 active_bounds: Vec::new(),
297 constraint_violation: S::ZERO,
298 wall_time_secs: start.elapsed().as_secs_f64(),
299 pareto: None,
300 sensitivity: None,
301 })
302 }
303 None => Err(OptimError::MILPInfeasible),
304 }
305}
306
307fn is_integer_feasible<S: Scalar>(x: &[S], int_indices: &[usize], tol: S) -> bool {
309 int_indices
310 .iter()
311 .all(|&i| (x[i] - x[i].round()).abs() < tol)
312}
313
314fn select_branching_variable<S: Scalar>(x: &[S], int_indices: &[usize], tol: S) -> Option<usize> {
317 let mut best_idx = None;
318 let mut best_frac_dist = S::ZERO; let half = S::from_f64(0.5);
320
321 for &i in int_indices {
322 let frac = x[i] - x[i].floor();
323 if frac < tol || frac > S::ONE - tol {
324 continue; }
326 let dist = (frac - half).abs();
328 let score = half - dist; if score > best_frac_dist {
330 best_frac_dist = score;
331 best_idx = Some(i);
332 }
333 }
334
335 best_idx
336}
337
338#[allow(clippy::too_many_arguments)]
344fn solve_lp_relaxation<
345 S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
346>(
347 c: &[S],
348 a_ineq: &[Vec<S>],
349 b_ineq: &[S],
350 a_eq: &[Vec<S>],
351 b_eq: &[S],
352 lb: &[S],
353 ub: &[S],
354 opts: &MILPOptions<S>,
355) -> Result<OptimResult<S>, OptimError> {
356 let n = c.len();
357
358 let obj_offset: S = c.iter().zip(lb.iter()).map(|(&ci, &lbi)| ci * lbi).sum();
362
363 let mut new_a_ineq: Vec<Vec<S>> = Vec::with_capacity(a_ineq.len() + n);
370 let mut new_b_ineq: Vec<S> = Vec::with_capacity(b_ineq.len() + n);
371
372 for (i, row) in a_ineq.iter().enumerate() {
373 let shift: S = row.iter().zip(lb.iter()).map(|(&a, &l)| a * l).sum();
374 new_a_ineq.push(row.clone());
375 new_b_ineq.push(b_ineq[i] - shift);
376 }
377
378 let mut new_a_eq: Vec<Vec<S>> = Vec::with_capacity(a_eq.len());
380 let mut new_b_eq: Vec<S> = Vec::with_capacity(b_eq.len());
381
382 for (i, row) in a_eq.iter().enumerate() {
383 let shift: S = row.iter().zip(lb.iter()).map(|(&a, &l)| a * l).sum();
384 new_a_eq.push(row.clone());
385 new_b_eq.push(b_eq[i] - shift);
386 }
387
388 for i in 0..n {
390 if ub[i].is_finite() {
391 let effective_ub = ub[i] - lb[i];
392 if effective_ub < -opts.lp_tol {
393 return Err(OptimError::LPInfeasible);
395 }
396 let mut row = vec![S::ZERO; n];
397 row[i] = S::ONE;
398 new_a_ineq.push(row);
399 new_b_ineq.push(effective_ub);
400 }
401 }
402
403 let lp_opts = LPOptions {
404 max_iter: 10_000,
405 tol: opts.lp_tol,
406 verbose: false,
407 };
408
409 let mut result = simplex_solve(c, &new_a_ineq, &new_b_ineq, &new_a_eq, &new_b_eq, &lp_opts)?;
410
411 for (xi, &lbi) in result.x.iter_mut().zip(lb.iter()) {
413 *xi += lbi;
414 }
415
416 result.f = c
418 .iter()
419 .zip(result.x.iter())
420 .map(|(&ci, &xi)| ci * xi)
421 .sum();
422
423 let _ = obj_offset;
426
427 Ok(result)
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing objectives and integrality in tests.
    const TOL: f64 = 1e-6;

    fn default_opts() -> MILPOptions<f64> {
        MILPOptions::default()
    }

    /// Panic unless every component of `x` lies within `TOL` of an integer.
    fn assert_integral(x: &[f64]) {
        for (i, xi) in x.iter().enumerate() {
            assert!(
                (xi - xi.round()).abs() < TOL,
                "x[{}]={} is not integer",
                i,
                xi
            );
        }
    }

    #[test]
    fn test_milp_no_integer_vars() {
        // Pure LP: max x + y  s.t. x + y <= 4, x <= 3, y <= 3.
        let c = vec![-1.0, -1.0];
        let a_ineq = vec![vec![1.0, 1.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let b_ineq = vec![4.0, 3.0, 3.0];
        let var_types = vec![VarType::Continuous, VarType::Continuous];
        let bounds = vec![None, None];
        let opts = default_opts();

        let result =
            milp_solve(&c, &a_ineq, &b_ineq, &[], &[], &var_types, &bounds, &opts).unwrap();
        assert!(result.converged);
        assert!(
            (result.f - (-4.0)).abs() < TOL,
            "expected f ~ -4.0, got {}",
            result.f
        );
    }

    #[test]
    fn test_milp_all_integer() {
        // max 3x + 5y  s.t. x + 2y <= 6, 2x + y <= 8, x, y >= 0 integer.
        // The LP optimum (10/3, 4/3) is fractional. The integer optimum is
        // (2, 2) with value 16.
        let c = vec![-3.0, -5.0];
        let a_ineq = vec![vec![1.0, 2.0], vec![2.0, 1.0]];
        let b_ineq = vec![6.0, 8.0];
        let var_types = vec![VarType::Integer, VarType::Integer];
        let bounds = vec![None, None];
        let opts = default_opts();

        let result =
            milp_solve(&c, &a_ineq, &b_ineq, &[], &[], &var_types, &bounds, &opts).unwrap();
        assert!(result.converged);
        assert!(
            (result.f - (-16.0)).abs() < TOL,
            "expected obj -16, got {}",
            result.f
        );
        assert_integral(&result.x);
        let lhs1: f64 = result.x[0] + 2.0 * result.x[1];
        let lhs2: f64 = 2.0 * result.x[0] + result.x[1];
        assert!(lhs1 <= 6.0 + TOL, "constraint 1 violated: {}", lhs1);
        assert!(lhs2 <= 8.0 + TOL, "constraint 2 violated: {}", lhs2);
    }

    #[test]
    fn test_milp_binary_knapsack() {
        // Items (weight, value): (3, 6), (4, 5), (2, 4); capacity 7.
        // The best packing takes items 0 and 1: weight 7, value 11.
        let c = vec![-6.0, -5.0, -4.0];
        let a_ineq = vec![vec![3.0, 4.0, 2.0]];
        let b_ineq = vec![7.0];
        let var_types = vec![VarType::Binary, VarType::Binary, VarType::Binary];
        let bounds = vec![Some((0.0, 1.0)), Some((0.0, 1.0)), Some((0.0, 1.0))];
        let opts = default_opts();

        let result =
            milp_solve(&c, &a_ineq, &b_ineq, &[], &[], &var_types, &bounds, &opts).unwrap();
        assert!(result.converged);
        assert!(
            (result.f - (-11.0)).abs() < TOL,
            "expected obj -11, got {}",
            result.f
        );
        for (i, xi) in result.x.iter().enumerate() {
            assert!(
                (xi - 0.0).abs() < TOL || (xi - 1.0).abs() < TOL,
                "x[{}]={} is not binary",
                i,
                xi
            );
        }
        let weight: f64 = 3.0 * result.x[0] + 4.0 * result.x[1] + 2.0 * result.x[2];
        assert!(
            weight <= 7.0 + TOL,
            "knapsack capacity violated: {}",
            weight
        );
    }

    #[test]
    fn test_milp_mixed_integer() {
        // max x0 + x1  s.t. x0 + x1 <= 3.5, x0 integer, x1 continuous.
        // The continuous variable absorbs the fractional part, so the optimum is 3.5.
        let c = vec![-1.0, -1.0];
        let a_ineq = vec![vec![1.0, 1.0]];
        let b_ineq = vec![3.5];
        let var_types = vec![VarType::Integer, VarType::Continuous];
        let bounds = vec![None, None];
        let opts = default_opts();

        let result =
            milp_solve(&c, &a_ineq, &b_ineq, &[], &[], &var_types, &bounds, &opts).unwrap();
        assert!(result.converged);
        assert!(
            (result.x[0] - result.x[0].round()).abs() < TOL,
            "x[0]={} is not integer",
            result.x[0]
        );
        assert!(
            (result.f - (-3.5)).abs() < TOL,
            "expected obj -3.5, got {}",
            result.f
        );
        assert!(
            result.x[0] + result.x[1] <= 3.5 + TOL,
            "constraint violated: {}",
            result.x[0] + result.x[1]
        );
    }

    #[test]
    fn test_milp_infeasible() {
        // x + y = 1 and x + y = 2 cannot both hold.
        let c = vec![-1.0, -1.0];
        let a_eq = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
        let b_eq = vec![1.0, 2.0];
        let var_types = vec![VarType::Integer, VarType::Integer];
        let bounds = vec![None, None];
        let opts = default_opts();

        let result = milp_solve(&c, &[], &[], &a_eq, &b_eq, &var_types, &bounds, &opts);
        assert!(result.is_err(), "expected infeasible, got {:?}", result);
        match result.unwrap_err() {
            OptimError::MILPInfeasible | OptimError::LPInfeasible => {}
            e => panic!("expected MILPInfeasible or LPInfeasible, got {:?}", e),
        }
    }

    #[test]
    fn test_milp_with_equality() {
        // max x + 2y  s.t. x + y = 3, integer: all weight goes to y.
        let c = vec![-1.0, -2.0];
        let a_eq = vec![vec![1.0, 1.0]];
        let b_eq = vec![3.0];
        let var_types = vec![VarType::Integer, VarType::Integer];
        let bounds = vec![None, None];
        let opts = default_opts();

        let result = milp_solve(&c, &[], &[], &a_eq, &b_eq, &var_types, &bounds, &opts).unwrap();
        assert!(result.converged);
        assert!(
            (result.f - (-6.0)).abs() < TOL,
            "expected obj=-6, got {}",
            result.f
        );
        assert!(
            (result.x[0] - 0.0).abs() < TOL,
            "expected x[0]=0, got {}",
            result.x[0]
        );
        assert!(
            (result.x[1] - 3.0).abs() < TOL,
            "expected x[1]=3, got {}",
            result.x[1]
        );
    }

    #[test]
    fn test_milp_dimension_mismatch() {
        // `bounds` is one entry short of the number of variables.
        let c = vec![-1.0, -1.0];
        let var_types = vec![VarType::Integer, VarType::Integer];
        let bounds = vec![None];
        let opts = default_opts();

        match milp_solve(&c, &[], &[], &[], &[], &var_types, &bounds, &opts) {
            Err(OptimError::DimensionMismatch { expected, .. }) => assert_eq!(expected, 2),
            other => panic!("expected DimensionMismatch, got {:?}", other),
        }
    }
}