use crate::error::{OptimizeError, OptimizeResult};
use super::lift_project::{LiftProjectConfig, LiftProjectCut, LiftProjectGenerator};
/// Cutting-plane driver that wraps a [`LiftProjectGenerator`] and keeps a pool
/// of lift-and-project cuts accumulated across separation rounds.
///
/// Each call to [`add_cuts_to_lp`](Self::add_cuts_to_lp) asks the generator for
/// cuts at the current LP point and keeps only those whose violation exceeds
/// `config.cut_violation_tol`. The accepted cuts are appended to the pool, which
/// can then be folded into the constraint system with
/// [`build_augmented_system`](Self::build_augmented_system).
pub struct LiftProjectMipSolver {
    /// Solver configuration; the generator is built from its own clone of it.
    config: LiftProjectConfig,
    /// Cut separator constructed from `config`.
    generator: LiftProjectGenerator,
    /// Every violated cut accepted so far, in insertion order. No de-duplication is performed.
    cut_pool: Vec<LiftProjectCut>,
    /// Number of calls to `add_cuts_to_lp`, including calls that yielded no cuts or
    /// returned an error from the generator.
    iterations: usize,
    /// Running count of violated cuts accepted into the pool; not reduced by
    /// `clear_cut_pool` or `purge_non_violated_cuts`.
    total_cuts_generated: usize,
}

impl LiftProjectMipSolver {
    /// Creates a solver with an empty cut pool and zeroed counters.
    pub fn new(config: LiftProjectConfig) -> Self {
        let generator = LiftProjectGenerator::new(config.clone());
        LiftProjectMipSolver {
            config,
            generator,
            cut_pool: Vec::new(),
            iterations: 0,
            total_cuts_generated: 0,
        }
    }

    /// Creates a solver using `LiftProjectConfig::default()`.
    pub fn default_solver() -> Self {
        LiftProjectMipSolver::new(LiftProjectConfig::default())
    }

    /// Runs one separation round at `x_bar` and adds the violated cuts to the pool.
    ///
    /// Cuts returned by the generator are kept only when
    /// `cut_violation(cut, x_bar) > config.cut_violation_tol`. The kept cuts are
    /// appended to the pool and also returned to the caller.
    ///
    /// # Errors
    /// Propagates any error returned by the generator's `generate_cuts`. The
    /// iteration counter has already been incremented when that happens.
    pub fn add_cuts_to_lp(
        &mut self,
        a: &[Vec<f64>],
        b: &[f64],
        x_bar: &[f64],
        integer_vars: &[usize],
    ) -> OptimizeResult<Vec<LiftProjectCut>> {
        // Counted up front, so failed or empty rounds still register as iterations.
        self.iterations += 1;
        let new_cuts = self.generator.generate_cuts(a, b, x_bar, integer_vars)?;

        let tol = self.config.cut_violation_tol;
        let generator = &self.generator;
        let violated: Vec<LiftProjectCut> = new_cuts
            .into_iter()
            .filter(|c| generator.cut_violation(c, x_bar) > tol)
            .collect();

        self.total_cuts_generated += violated.len();
        self.cut_pool.extend_from_slice(&violated);
        Ok(violated)
    }

    /// Number of cuts currently held in the pool.
    pub fn cut_pool_size(&self) -> usize {
        self.cut_pool.len()
    }

    /// Read-only view of the pooled cuts, in insertion order.
    pub fn cut_pool(&self) -> &[LiftProjectCut] {
        &self.cut_pool
    }

    /// Removes every cut from the pool. Counters are left unchanged.
    pub fn clear_cut_pool(&mut self) {
        self.cut_pool.clear();
    }

    /// Drops every pooled cut whose violation at `x_new` does not exceed
    /// `config.cut_violation_tol`. This is the same acceptance test used by
    /// `add_cuts_to_lp`.
    pub fn purge_non_violated_cuts(&mut self, x_new: &[f64]) {
        // Borrow the generator and copy the tolerance separately, so the closure
        // does not alias the mutable borrow of `cut_pool`.
        let generator = &self.generator;
        let tol = self.config.cut_violation_tol;
        self.cut_pool
            .retain(|c| generator.cut_violation(c, x_new) > tol);
    }

    /// Number of calls made to `add_cuts_to_lp` so far.
    pub fn iterations(&self) -> usize {
        self.iterations
    }

    /// Total number of violated cuts accepted across all calls to `add_cuts_to_lp`.
    pub fn total_cuts_generated(&self) -> usize {
        self.total_cuts_generated
    }

    /// The configuration this solver was built with.
    pub fn config(&self) -> &LiftProjectConfig {
        &self.config
    }

    /// Violation of `cut` at `x_bar`, delegated to the underlying generator.
    pub fn cut_violation(&self, cut: &LiftProjectCut, x_bar: &[f64]) -> f64 {
        self.generator.cut_violation(cut, x_bar)
    }

    /// Returns copies of `a` and `b` with one extra row per pooled cut appended.
    ///
    /// Each cut is appended as the row `-pi` with right-hand side `-pi0`. This
    /// presumably turns a `pi·x >= pi0` cut into the `<=` row form of `a`;
    /// NOTE(review): confirm against `LiftProjectCut`'s sign convention.
    ///
    /// The column count is taken from `a[0]`. When `a` is empty, it comes from
    /// the first pooled cut, or is 0 if the pool is empty too.
    ///
    /// # Errors
    /// Returns `OptimizeError::InvalidInput` in three cases:
    /// - `a.len() != b.len()`;
    /// - some row of `a` has a different length than `a[0]`;
    /// - some pooled cut's `pi` length differs from the column count.
    pub fn build_augmented_system(
        &self,
        a: &[Vec<f64>],
        b: &[f64],
    ) -> OptimizeResult<(Vec<Vec<f64>>, Vec<f64>)> {
        if a.len() != b.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "Constraint matrix has {} rows but b has {} entries",
                a.len(),
                b.len()
            )));
        }

        let n = match a.first() {
            Some(row) => row.len(),
            // No original rows: infer the width from the pool so a cuts-only system still works.
            None => self.cut_pool.first().map_or(0, |c| c.pi.len()),
        };

        // Validate every row, not just row 0; a ragged `a` would otherwise yield
        // an augmented system whose rows disagree in width.
        if let Some(i) = a.iter().position(|row| row.len() != n) {
            return Err(OptimizeError::InvalidInput(format!(
                "Constraint matrix row {} has {} columns but row 0 has {}",
                i,
                a[i].len(),
                n
            )));
        }

        let total_rows = a.len() + self.cut_pool.len();
        let mut a_aug: Vec<Vec<f64>> = Vec::with_capacity(total_rows);
        a_aug.extend_from_slice(a);
        let mut b_aug: Vec<f64> = Vec::with_capacity(total_rows);
        b_aug.extend_from_slice(b);

        for cut in &self.cut_pool {
            if cut.pi.len() != n {
                return Err(OptimizeError::InvalidInput(format!(
                    "Cut has {} coefficients but constraint matrix has {} columns",
                    cut.pi.len(),
                    n
                )));
            }
            a_aug.push(cut.pi.iter().map(|&p| -p).collect());
            b_aug.push(-cut.pi0);
        }

        Ok((a_aug, b_aug))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture for most tests:
    /// - a one-row system with `a = [[1, 1]]` and `b = [1]`;
    /// - a fractional point `x_bar = (0.4, 0.6)`;
    /// - both variables marked integer.
    fn make_fractional_lp() -> (Vec<Vec<f64>>, Vec<f64>, Vec<f64>, Vec<usize>) {
        let a = vec![vec![1.0, 1.0]];
        let b = vec![1.0];
        let x_bar = vec![0.4, 0.6];
        let ivars = vec![0, 1];
        (a, b, x_bar, ivars)
    }

    #[test]
    fn test_add_cuts_increases_pool_size() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        assert_eq!(solver.cut_pool_size(), 0);
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert_eq!(solver.cut_pool_size(), cuts.len());
        assert!(solver.cut_pool_size() > 0, "Expected cuts to be generated");
    }

    #[test]
    fn test_add_cuts_returns_violated_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        for cut in &cuts {
            let v = solver.cut_violation(cut, &x_bar);
            assert!(
                v > solver.config().cut_violation_tol,
                "Returned cut should be violated at x_bar, got v={}",
                v
            );
        }
    }

    #[test]
    fn test_add_cuts_empty_for_integer_solution() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, _, ivars) = make_fractional_lp();
        let x_int = vec![1.0, 0.0];
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_int, &ivars).unwrap();
        assert!(cuts.is_empty());
        assert_eq!(solver.cut_pool_size(), 0);
        assert_eq!(solver.total_cuts_generated(), 0);
        assert_eq!(solver.iterations(), 1);
    }

    #[test]
    fn test_clear_cut_pool_resets_size() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert!(solver.cut_pool_size() > 0);
        let total_before = solver.total_cuts_generated();
        solver.clear_cut_pool();
        assert_eq!(solver.cut_pool_size(), 0);
        // Clearing the pool must not rewrite history.
        assert_eq!(solver.total_cuts_generated(), total_before);
    }

    #[test]
    fn test_iterations_counter_increments() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        assert_eq!(solver.iterations(), 0);
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert_eq!(solver.iterations(), 1);
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert_eq!(solver.iterations(), 2);
    }

    #[test]
    fn test_total_cuts_generated_accumulates() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        let first = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let after_first = solver.total_cuts_generated();
        assert_eq!(after_first, first.len());
        let second = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let after_second = solver.total_cuts_generated();
        assert_eq!(after_second, after_first + second.len());
    }

    #[test]
    fn test_pool_accumulates_across_calls() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let size_after_first = solver.cut_pool_size();
        let second = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let size_after_second = solver.cut_pool_size();
        assert_eq!(size_after_second, size_after_first + second.len());
    }

    #[test]
    fn test_purge_non_violated_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let size_before = solver.cut_pool_size();
        let x_int = vec![1.0, 0.0];
        solver.purge_non_violated_cuts(&x_int);
        let size_after = solver.cut_pool_size();
        assert!(
            size_after <= size_before,
            "Pool should not grow after purge"
        );
        let tol = solver.config().cut_violation_tol;
        for cut in solver.cut_pool() {
            assert!(
                solver.cut_violation(cut, &x_int) > tol,
                "Purge kept a cut that is not violated at the new point"
            );
        }
    }

    #[test]
    fn test_build_augmented_system_appends_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let n_original = a.len();
        let n_cuts = solver.cut_pool_size();
        let (a_aug, b_aug) = solver.build_augmented_system(&a, &b).unwrap();
        assert_eq!(a_aug.len(), n_original + n_cuts);
        assert_eq!(b_aug.len(), n_original + n_cuts);
    }

    #[test]
    fn test_build_augmented_system_negates_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let (a_aug, b_aug) = solver.build_augmented_system(&a, &b).unwrap();
        let n_orig = a.len();
        for (k, cut) in solver.cut_pool().iter().enumerate() {
            let row = &a_aug[n_orig + k];
            let rhs = b_aug[n_orig + k];
            for (j, (&aug_coeff, &pi_k)) in row.iter().zip(cut.pi.iter()).enumerate() {
                assert!(
                    (aug_coeff - (-pi_k)).abs() < 1e-12,
                    "Augmented row coeff [{}][{}] = {} but expected {}",
                    k,
                    j,
                    aug_coeff,
                    -pi_k
                );
            }
            assert!(
                (rhs - (-cut.pi0)).abs() < 1e-12,
                "Augmented RHS = {} but expected {}",
                rhs,
                -cut.pi0
            );
        }
    }

    #[test]
    fn test_build_augmented_system_error_on_mismatched_a_b() {
        let solver = LiftProjectMipSolver::default_solver();
        let a = vec![vec![1.0, 1.0], vec![0.0, 1.0]];
        let b = vec![1.0];
        let result = solver.build_augmented_system(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_cut_pool_accessor_matches_pool_size() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert_eq!(solver.cut_pool().len(), solver.cut_pool_size());
    }

    #[test]
    fn test_config_accessor() {
        let config = LiftProjectConfig {
            max_cuts: 7,
            cut_violation_tol: 1e-5,
            ..Default::default()
        };
        let solver = LiftProjectMipSolver::new(config);
        assert_eq!(solver.config().max_cuts, 7);
        assert!((solver.config().cut_violation_tol - 1e-5).abs() < 1e-12);
    }

    #[test]
    fn test_multiple_constraint_rows_generate_more_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let b = vec![0.8, 0.8, 1.2];
        let x_bar = vec![0.4, 0.5];
        let ivars = vec![0, 1];
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert!(!cuts.is_empty());
    }

    #[test]
    fn test_solver_handles_no_integer_vars_gracefully() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let a = vec![vec![1.0, 1.0]];
        let b = vec![1.0];
        let x_bar = vec![0.4, 0.6];
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_bar, &[]).unwrap();
        assert!(cuts.is_empty());
        assert_eq!(solver.cut_pool_size(), 0);
    }
}