use crate::error::{OptimizeError, OptimizeResult};
/// Order in which [`LiftProjectGenerator`] visits the fractional integer variables of an LP point.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum VariableSelectionStrategy {
    /// Visit variables in decreasing `min(x̄_j, 1 − x̄_j)` order. This is the default.
    #[default]
    MostFractional,
    /// Visit variables in the order they appear in the caller's `integer_vars` list.
    FirstFractional,
    /// Keep the single most violated cut of each variable and return the overall deepest ones.
    DeepestCut,
}
/// Tuning parameters for [`LiftProjectGenerator`].
#[derive(Debug, Clone)]
pub struct LiftProjectConfig {
    /// Upper bound on the number of cuts a single `generate_cuts` call returns.
    pub max_cuts: usize,
    /// How fractional variables are ordered or chosen for separation.
    pub variable_selection: VariableSelectionStrategy,
    /// A candidate cut is kept only when its `violation` is strictly greater than this value.
    pub cut_violation_tol: f64,
    /// NOTE(review): not read by `LiftProjectGenerator` in this module; `ls_strengthen` has to be
    /// called explicitly. Confirm whether any other code consults this flag.
    pub ls_strengthening: bool,
    /// A value `v` counts as fractional when `int_tol < v < 1 − int_tol`.
    pub int_tol: f64,
    /// Maximum number of leading rows of the augmented system examined per variable.
    pub max_rows_per_var: usize,
}

impl Default for LiftProjectConfig {
    /// At most 50 cuts, most-fractional selection, violation tolerance `1e-6`, no strengthening,
    /// integrality tolerance `1e-8`, and at most 1000 rows examined per variable.
    fn default() -> Self {
        Self {
            variable_selection: VariableSelectionStrategy::MostFractional,
            max_cuts: 50,
            max_rows_per_var: 1000,
            int_tol: 1e-8,
            cut_violation_tol: 1e-6,
            ls_strengthening: false,
        }
    }
}
/// A linear inequality with coefficients `pi` and right-hand side `pi0`, plus provenance data.
///
/// [`cut_satisfied_at_integer`] reads the inequality as `pi · x ≥ pi0`, while
/// [`LiftProjectGenerator::cut_violation`] evaluates `pi · x − pi0`.
#[derive(Clone, Debug)]
pub struct LiftProjectCut {
    /// Coefficient vector; the generator copies it from the source row.
    pub pi: Vec<f64>,
    /// Right-hand side of the inequality.
    pub pi0: f64,
    /// Index of the fractional variable the cut was generated for.
    pub source_var: usize,
    /// Index of the source row in the augmented system (rows of `a` first, then bound rows).
    pub source_row: usize,
    /// `pi · x̄ − pi0` at the point the cut was generated for.
    pub violation: f64,
}
/// Cut generator driven by a [`LiftProjectConfig`].
///
/// Derives `Debug` and `Clone` like its configuration, so it can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct LiftProjectGenerator {
    config: LiftProjectConfig,
}
impl LiftProjectGenerator {
    /// Creates a generator that uses `config` for every call.
    pub fn new(config: LiftProjectConfig) -> Self {
        LiftProjectGenerator { config }
    }

    /// Creates a generator using [`LiftProjectConfig::default`].
    pub fn default_generator() -> Self {
        LiftProjectGenerator::new(LiftProjectConfig::default())
    }

    /// Generates cuts for the LP point `x_bar` from the rows of `a·x ≤ b`.
    ///
    /// The system is first augmented with the bounds `x_k ≤ 1` and `−x_k ≤ 0` for every distinct,
    /// in-range index `k` of `integer_vars`. Out-of-range or repeated indices are ignored.
    ///
    /// A variable is a candidate when `int_tol < x_bar[k] < 1 − int_tol`. Candidates are handled
    /// according to [`VariableSelectionStrategy`]:
    /// * `MostFractional`: visit in decreasing `min(x̄_k, 1 − x̄_k)` order (ties keep input
    ///   order). Collect every cut each variable yields until `max_cuts` is reached.
    /// * `FirstFractional`: visit in `integer_vars` order, using the same collection rule.
    /// * `DeepestCut`: take the single largest-`violation` cut per variable over all examined
    ///   rows (ties resolve to the later row), then keep the `max_cuts` largest.
    ///
    /// At most `max_rows_per_var` leading rows of the augmented system are examined per variable.
    /// The result is sorted by decreasing `violation`. See the NOTE on `bcc_cut_from_row` about
    /// the validity of the produced inequalities.
    ///
    /// # Errors
    /// Returns `OptimizeError::InvalidInput` in any of these cases:
    /// * `x_bar` is empty or contains a non-finite value;
    /// * `a` and `b` have different lengths;
    /// * a row of `a` does not have `x_bar.len()` entries.
    pub fn generate_cuts(
        &self,
        a: &[Vec<f64>],
        b: &[f64],
        x_bar: &[f64],
        integer_vars: &[usize],
    ) -> OptimizeResult<Vec<LiftProjectCut>> {
        Self::validate_inputs(a, b, x_bar)?;
        let n = x_bar.len();

        // Keep the first occurrence of each in-range index. A repeated index would otherwise add
        // duplicate bound rows and be separated twice, producing duplicate cuts.
        let mut seen = vec![false; n];
        let int_vars: Vec<usize> = integer_vars
            .iter()
            .copied()
            .filter(|&k| k < n && !std::mem::replace(&mut seen[k], true))
            .collect();
        let mut fractional_vars: Vec<usize> = int_vars
            .iter()
            .copied()
            .filter(|&k| self.is_fractional(x_bar[k]))
            .collect();
        if fractional_vars.is_empty() || self.config.max_cuts == 0 {
            return Ok(Vec::new());
        }

        let (a_aug, b_aug) = build_augmented_constraints(a, b, x_bar, &int_vars);
        let cuts = match self.config.variable_selection {
            VariableSelectionStrategy::MostFractional
            | VariableSelectionStrategy::FirstFractional => {
                if self.config.variable_selection == VariableSelectionStrategy::MostFractional {
                    // Stable sort: variables equally far from an integer keep their input order.
                    let distance = |k: usize| x_bar[k].min(1.0 - x_bar[k]);
                    fractional_vars.sort_by(|&p, &q| distance(q).total_cmp(&distance(p)));
                }
                let mut cuts: Vec<LiftProjectCut> = Vec::new();
                for &j in &fractional_vars {
                    if cuts.len() >= self.config.max_cuts {
                        break;
                    }
                    self.append_cuts_for_var(&a_aug, &b_aug, x_bar, j, &mut cuts)?;
                }
                cuts.sort_by(|c1, c2| c2.violation.total_cmp(&c1.violation));
                cuts
            }
            VariableSelectionStrategy::DeepestCut => {
                // Every examined row is considered per variable. Capping the per-variable pool at
                // `max_cuts` would hide deeper cuts that appear in later rows.
                let mut cuts: Vec<LiftProjectCut> = fractional_vars
                    .iter()
                    .filter_map(|&j| self.deepest_cut_for_var(&a_aug, &b_aug, x_bar, j))
                    .collect();
                cuts.sort_by(|c1, c2| c2.violation.total_cmp(&c1.violation));
                cuts.truncate(self.config.max_cuts);
                cuts
            }
        };
        Ok(cuts)
    }

    /// Returns the variable the configured strategy would pick, or `None` if no entry of
    /// `integer_vars` is an in-range fractional variable.
    ///
    /// `FirstFractional` takes the first fractional entry. `MostFractional` and `DeepestCut` take
    /// the entry with the largest `min(x̄_j, 1 − x̄_j)`; on ties the first such entry wins.
    pub fn select_variable(&self, x_bar: &[f64], integer_vars: &[usize]) -> Option<usize> {
        let n = x_bar.len();
        let mut candidates = integer_vars
            .iter()
            .copied()
            .filter(|&j| j < n && self.is_fractional(x_bar[j]));
        match self.config.variable_selection {
            VariableSelectionStrategy::FirstFractional => candidates.next(),
            VariableSelectionStrategy::MostFractional | VariableSelectionStrategy::DeepestCut => {
                let mut best: Option<(usize, f64)> = None;
                for j in candidates {
                    let dist = x_bar[j].min(1.0 - x_bar[j]);
                    let improves = match best {
                        None => true,
                        Some((_, best_dist)) => dist > best_dist,
                    };
                    if improves {
                        best = Some((j, dist));
                    }
                }
                best.map(|(j, _)| j)
            }
        }
    }

    /// Returns the most violated cut for variable `j` (the first one on ties).
    ///
    /// Candidate rows are those of `a·x ≤ b` plus the bounds `x_j ≤ 1` (row index `a.len()`) and
    /// `−x_j ≤ 0` (row index `a.len() + 1`). Bound rows of other variables have a zero coefficient
    /// on `x_j` and can never produce a cut for it, so they are not built.
    ///
    /// Returns `None` in any of these cases:
    /// * the inputs fail the checks `generate_cuts` performs;
    /// * `j` is out of range;
    /// * `x_bar[j]` is not fractional;
    /// * no row yields a cut above `cut_violation_tol`.
    pub fn generate_cut_for_var(
        &self,
        a: &[Vec<f64>],
        b: &[f64],
        x_bar: &[f64],
        j: usize,
    ) -> Option<LiftProjectCut> {
        Self::validate_inputs(a, b, x_bar).ok()?;
        let f_j = *x_bar.get(j)?;
        if !self.is_fractional(f_j) {
            return None;
        }
        let (a_aug, b_aug) = build_augmented_constraints(a, b, x_bar, &[j]);
        self.best_cut_from_rows(&a_aug, &b_aug, x_bar, j)
    }

    /// Evaluates `π·x − π0` over the common prefix of `cut.pi` and `x_bar`.
    ///
    /// For cuts produced by this generator it equals the stored `violation` (up to rounding) at
    /// the point they were generated for.
    pub fn cut_violation(&self, cut: &LiftProjectCut, x_bar: &[f64]) -> f64 {
        cut.pi
            .iter()
            .zip(x_bar.iter())
            .map(|(&pi_k, &xk)| pi_k * xk)
            .sum::<f64>()
            - cut.pi0
    }

    /// Checks shared by `generate_cuts` and `generate_cut_for_var`:
    /// * `x_bar` is non-empty and finite;
    /// * `a` and `b` have the same number of rows;
    /// * every row of `a` has `x_bar.len()` entries.
    fn validate_inputs(a: &[Vec<f64>], b: &[f64], x_bar: &[f64]) -> OptimizeResult<()> {
        let n = x_bar.len();
        if n == 0 {
            return Err(OptimizeError::InvalidInput(
                "x_bar must be non-empty".to_string(),
            ));
        }
        if a.len() != b.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "Constraint matrix has {} rows but b has {} entries",
                a.len(),
                b.len()
            )));
        }
        if let Some((i, row)) = a.iter().enumerate().find(|(_, row)| row.len() != n) {
            return Err(OptimizeError::InvalidInput(format!(
                "Row {} has {} columns but x_bar has {} components",
                i,
                row.len(),
                n
            )));
        }
        if let Some(k) = x_bar.iter().position(|v| !v.is_finite()) {
            return Err(OptimizeError::InvalidInput(format!(
                "x_bar[{}] = {} is not finite",
                k, x_bar[k]
            )));
        }
        Ok(())
    }

    /// `true` when `value` lies strictly inside `(int_tol, 1 − int_tol)`.
    fn is_fractional(&self, value: f64) -> bool {
        value > self.config.int_tol && value < 1.0 - self.config.int_tol
    }

    /// Lazily yields the cut (if any) each of the first `max_rows_per_var` rows gives for `x_j`.
    fn row_cuts<'a>(
        &'a self,
        a: &'a [Vec<f64>],
        b: &'a [f64],
        x_bar: &'a [f64],
        j: usize,
    ) -> impl Iterator<Item = LiftProjectCut> + 'a {
        a.iter()
            .zip(b)
            .enumerate()
            .take(self.config.max_rows_per_var)
            .filter_map(move |(i, (row, &bi))| self.bcc_cut_from_row(row, bi, x_bar, j, i))
    }

    /// Builds the candidate cut `row·x ≥ row·x̄ − violation` from the single row `row·x ≤ bi`.
    ///
    /// Let `r = bi − row·x̄` be the row's slack and `f = x̄_j`. Then `violation = r·f/(1−f)`
    /// when the row's coefficient on `x_j` is positive, and `r·(1−f)/f` otherwise. The row is
    /// skipped in any of these cases:
    /// * the coefficient on `x_j` is (near) zero;
    /// * `violation` is not finite (NaN/inf data);
    /// * `violation` does not exceed `cut_violation_tol`.
    ///
    /// NOTE(review): by construction `cut_violation` is `+violation` at `x̄`. So under the
    /// `π·x ≥ π0` reading used by `cut_satisfied_at_integer`, the inequality is *satisfied* at
    /// `x̄` rather than violated, and it is not guaranteed valid for integer points. For example,
    /// `x0 + x1 ≤ 2` with `x̄ = (0.4, 0.6)` and `j = 0` yields `x0 + x1 ≥ 1/3`, which excludes the
    /// feasible point `(0, 0)`. Verify the derivation before using these cuts in a solver.
    fn bcc_cut_from_row(
        &self,
        row: &[f64],
        bi: f64,
        x_bar: &[f64],
        j: usize,
        row_index: usize,
    ) -> Option<LiftProjectCut> {
        let a_ij = *row.get(j)?;
        if a_ij.abs() < 1e-12 {
            return None;
        }
        let dot_ax: f64 = row.iter().zip(x_bar.iter()).map(|(&aik, &xk)| aik * xk).sum();
        let r_i = bi - dot_ax;
        let f_j = x_bar[j];
        let violation = if a_ij > 0.0 {
            r_i * f_j / (1.0 - f_j)
        } else {
            r_i * (1.0 - f_j) / f_j
        };
        // A plain `violation <= tol` test lets NaN through; infinite values give meaningless rhs.
        if !violation.is_finite() || violation <= self.config.cut_violation_tol {
            return None;
        }
        let pi0 = dot_ax - violation;
        Some(LiftProjectCut {
            pi: row.to_vec(),
            pi0,
            source_var: j,
            source_row: row_index,
            violation,
        })
    }

    /// Most violated cut over the examined rows; on ties the earliest row wins.
    fn best_cut_from_rows(
        &self,
        a: &[Vec<f64>],
        b: &[f64],
        x_bar: &[f64],
        j: usize,
    ) -> Option<LiftProjectCut> {
        self.row_cuts(a, b, x_bar, j)
            .reduce(|best, cut| if cut.violation > best.violation { cut } else { best })
    }

    /// Most violated cut over the examined rows; on ties the latest row wins.
    fn deepest_cut_for_var(
        &self,
        a: &[Vec<f64>],
        b: &[f64],
        x_bar: &[f64],
        j: usize,
    ) -> Option<LiftProjectCut> {
        self.row_cuts(a, b, x_bar, j)
            .reduce(|best, cut| if cut.violation >= best.violation { cut } else { best })
    }

    /// Appends the cuts of variable `j`, in row order, until `out` holds `max_cuts` cuts.
    ///
    /// Does nothing when `x_bar[j]` is not fractional.
    fn append_cuts_for_var(
        &self,
        a: &[Vec<f64>],
        b: &[f64],
        x_bar: &[f64],
        j: usize,
        out: &mut Vec<LiftProjectCut>,
    ) -> OptimizeResult<()> {
        if !self.is_fractional(x_bar[j]) {
            return Ok(());
        }
        let room = self.config.max_cuts.saturating_sub(out.len());
        out.extend(self.row_cuts(a, b, x_bar, j).take(room));
        Ok(())
    }
}
/// Returns a copy of the system `a·x ≤ b` extended with bound rows for selected variables.
///
/// For each index `k` of `integer_vars` that is below `x_bar.len()`, the rows `x_k ≤ 1` and
/// `−x_k ≤ 0` are appended, upper bound first, in `integer_vars` order. A repeated index
/// therefore produces repeated rows. `x_bar` is only used for its length.
fn build_augmented_constraints(
    a: &[Vec<f64>],
    b: &[f64],
    x_bar: &[f64],
    integer_vars: &[usize],
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let dim = x_bar.len();
    let unit_row = |col: usize, sign: f64| {
        let mut coeffs = vec![0.0; dim];
        coeffs[col] = sign;
        coeffs
    };

    let mut rows = a.to_vec();
    let mut rhs = b.to_vec();
    for col in integer_vars.iter().copied().filter(|&col| col < dim) {
        rows.push(unit_row(col, 1.0));
        rhs.push(1.0);
        rows.push(unit_row(col, -1.0));
        rhs.push(0.0);
    }
    (rows, rhs)
}
/// Heuristically raises coefficients of `cut` for integer variables whose LP value exceeds `x̄_j`.
///
/// The rule is applied to each distinct `k` in `integer_vars` that satisfies all of:
/// * `k < cut.pi.len()` and `k < x_bar.len()`;
/// * `k ≠ j`;
/// * `π_k > 0` and `x̄_k > x̄_j`.
///
/// With `t = π_k·(x̄_k − x̄_j)·x̄_j`, it sets `π_k ← π_k + t/(x̄_k + 1e-12)` and `π0 ← π0 + t`.
/// When `j` is out of range, `x̄_j` is taken as `0.5`.
///
/// The returned `violation` is recomputed as `π'·x̄ − π0'` over the common prefix of `π'` and
/// `x_bar`, the same quantity `LiftProjectGenerator::cut_violation` evaluates.
///
/// NOTE(review): the validity of this lifting rule is not established here; confirm it against
/// the intended strengthening procedure before relying on the result.
pub fn ls_strengthen(
    cut: &LiftProjectCut,
    x_bar: &[f64],
    integer_vars: &[usize],
    j: usize,
) -> LiftProjectCut {
    let n = cut.pi.len();
    let f_j = x_bar.get(j).copied().unwrap_or(0.5);
    let mut new_pi = cut.pi.clone();
    let mut delta_pi0 = 0.0_f64;
    // Each index is lifted at most once. A repeated index would otherwise raise `pi0` again
    // while assigning the same coefficient, over-tightening the right-hand side.
    let mut lifted = vec![false; n];
    for &k in integer_vars {
        if k >= n || k == j || std::mem::replace(&mut lifted[k], true) {
            continue;
        }
        let x_k = match x_bar.get(k) {
            Some(&v) => v,
            None => continue,
        };
        let pi_k = cut.pi[k];
        if pi_k > 0.0 && x_k > f_j {
            let tightening = pi_k * (x_k - f_j) * f_j;
            delta_pi0 += tightening;
            new_pi[k] = pi_k + tightening / (x_k + 1e-12);
        }
    }
    let pi0 = cut.pi0 + delta_pi0;
    // Recompute rather than subtract `delta_pi0`: the raised coefficients also change `π·x̄`.
    let violation = new_pi
        .iter()
        .zip(x_bar.iter())
        .map(|(&p, &x)| p * x)
        .sum::<f64>()
        - pi0;
    LiftProjectCut {
        pi: new_pi,
        pi0,
        source_var: cut.source_var,
        source_row: cut.source_row,
        violation,
    }
}
/// Returns `true` when `x` satisfies `cut.pi · x ≥ cut.pi0` within an absolute tolerance of `1e-9`.
///
/// The dot product runs over the common prefix of `cut.pi` and `x`. A NaN in the product
/// makes the comparison `false`.
pub fn cut_satisfied_at_integer(cut: &LiftProjectCut, x: &[f64]) -> bool {
    const FEASIBILITY_TOL: f64 = 1e-9;
    let lhs = cut
        .pi
        .iter()
        .zip(x)
        .map(|(coef, val)| coef * val)
        .sum::<f64>();
    let threshold = cut.pi0 - FEASIBILITY_TOL;
    lhs >= threshold
}
#[cfg(test)]
mod tests {
    //! Unit tests for the lift-and-project generator.
    //!
    //! Locals are named `generator` rather than `gen`, which is a reserved keyword from the Rust
    //! 2024 edition onward.
    use super::*;

    /// Single constraint `x0 + x1 ≤ 1`.
    fn simple_constraints() -> (Vec<Vec<f64>>, Vec<f64>) {
        let a = vec![vec![1.0, 1.0]];
        let b = vec![1.0];
        (a, b)
    }

    fn simple_x_bar() -> Vec<f64> {
        vec![0.5, 0.5]
    }

    fn simple_integer_vars() -> Vec<usize> {
        vec![0, 1]
    }

    #[test]
    fn test_most_fractional_selects_closest_to_half() {
        let config = LiftProjectConfig {
            variable_selection: VariableSelectionStrategy::MostFractional,
            ..Default::default()
        };
        let generator = LiftProjectGenerator::new(config);
        let x_bar = vec![0.2, 0.5, 0.4];
        let integer_vars = vec![0, 1, 2];
        let selected = generator.select_variable(&x_bar, &integer_vars);
        assert_eq!(selected, Some(1));
    }

    #[test]
    fn test_first_fractional_selects_first() {
        let config = LiftProjectConfig {
            variable_selection: VariableSelectionStrategy::FirstFractional,
            ..Default::default()
        };
        let generator = LiftProjectGenerator::new(config);
        let x_bar = vec![0.8, 0.5, 0.3];
        let integer_vars = vec![0, 1, 2];
        let selected = generator.select_variable(&x_bar, &integer_vars);
        assert_eq!(selected, Some(0));
    }

    #[test]
    fn test_select_variable_returns_none_when_all_integer() {
        let config = LiftProjectConfig::default();
        let generator = LiftProjectGenerator::new(config);
        let x_bar = vec![0.0, 1.0, 0.0, 1.0];
        let integer_vars = vec![0, 1, 2, 3];
        assert_eq!(generator.select_variable(&x_bar, &integer_vars), None);
    }

    #[test]
    fn test_select_variable_skips_continuous_vars() {
        let config = LiftProjectConfig {
            variable_selection: VariableSelectionStrategy::MostFractional,
            ..Default::default()
        };
        let generator = LiftProjectGenerator::new(config);
        let x_bar = vec![0.3, 0.5, 0.4];
        let integer_vars = vec![2];
        let selected = generator.select_variable(&x_bar, &integer_vars);
        assert_eq!(selected, Some(2));
    }

    #[test]
    fn test_generate_cuts_empty_for_integer_solution() {
        let generator = LiftProjectGenerator::default_generator();
        let (a, b) = simple_constraints();
        let x_bar = vec![1.0, 0.0];
        let integer_vars = simple_integer_vars();
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &integer_vars).unwrap();
        assert!(cuts.is_empty());
    }

    #[test]
    fn test_generate_cuts_empty_for_no_integer_vars() {
        let generator = LiftProjectGenerator::default_generator();
        let (a, b) = simple_constraints();
        let x_bar = vec![0.5, 0.5];
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &[]).unwrap();
        assert!(cuts.is_empty());
    }

    #[test]
    fn test_cuts_violated_at_x_bar() {
        let generator = LiftProjectGenerator::default_generator();
        let (a, b) = simple_constraints();
        let x_bar = simple_x_bar();
        let integer_vars = simple_integer_vars();
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &integer_vars).unwrap();
        assert!(!cuts.is_empty(), "Expected at least one cut");
        for cut in &cuts {
            let violation = generator.cut_violation(cut, &x_bar);
            assert!(
                violation > generator.config.cut_violation_tol,
                "Cut should be violated at x̄, got violation = {}",
                violation
            );
        }
    }

    #[test]
    fn test_cut_violation_is_positive_at_x_bar() {
        let generator = LiftProjectGenerator::default_generator();
        let (a, b) = simple_constraints();
        let x_bar = simple_x_bar();
        let cut = generator
            .generate_cut_for_var(&a, &b, &x_bar, 0)
            .expect("Should generate a cut for variable 0");
        assert!(
            cut.violation > 0.0,
            "violation field should be positive, got {}",
            cut.violation
        );
        let v2 = generator.cut_violation(&cut, &x_bar);
        assert!(
            (cut.violation - v2).abs() < 1e-12,
            "violation field and cut_violation must agree: {} vs {}",
            cut.violation,
            v2
        );
    }

    #[test]
    fn test_cuts_satisfied_at_zero_vector() {
        let generator = LiftProjectGenerator::default_generator();
        let (a, b) = simple_constraints();
        let x_bar = simple_x_bar();
        let integer_vars = simple_integer_vars();
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &integer_vars).unwrap();
        let zero = vec![0.0, 0.0];
        for cut in &cuts {
            let dot: f64 = cut.pi.iter().zip(zero.iter()).map(|(&p, &x)| p * x).sum();
            assert!(
                cut_satisfied_at_integer(cut, &zero),
                "Cut should hold at x=(0,0): π·x={:.6} ≥ π₀={:.6}",
                dot,
                cut.pi0
            );
        }
    }

    #[test]
    fn test_cuts_satisfied_at_ones_vector() {
        // NOTE(review): only (1,1) is checked. For this data the structural row yields
        // `x0 + x1 ≥ 1/3`, which fails at the feasible point (0,0).
        let a = vec![vec![1.0, 1.0]];
        let b = vec![2.0];
        let x_bar = vec![0.4, 0.6];
        let generator = LiftProjectGenerator::default_generator();
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &[0, 1]).unwrap();
        let ones = vec![1.0, 1.0];
        for cut in &cuts {
            assert!(
                cut_satisfied_at_integer(cut, &ones),
                "Cut should hold at x=(1,1): π·x={:.6} ≥ π₀={:.6}",
                cut.pi.iter().zip(ones.iter()).map(|(&p, &x)| p * x).sum::<f64>(),
                cut.pi0
            );
        }
    }

    #[test]
    fn test_cuts_satisfied_at_unit_vectors() {
        let generator = LiftProjectGenerator::default_generator();
        let (a, b) = simple_constraints();
        let x_bar = simple_x_bar();
        let integer_vars = simple_integer_vars();
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &integer_vars).unwrap();
        let e0 = vec![1.0, 0.0];
        let e1 = vec![0.0, 1.0];
        for cut in &cuts {
            assert!(cut_satisfied_at_integer(cut, &e0), "Cut should hold at e0");
            assert!(cut_satisfied_at_integer(cut, &e1), "Cut should hold at e1");
        }
    }

    #[test]
    fn test_max_cuts_limits_output() {
        let config = LiftProjectConfig {
            max_cuts: 1,
            ..Default::default()
        };
        let generator = LiftProjectGenerator::new(config);
        let a = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 1.0],
        ];
        let b = vec![1.0, 1.0, 1.5, 2.0];
        let x_bar = vec![0.4, 0.6];
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &[0, 1]).unwrap();
        assert!(
            cuts.len() <= 1,
            "Expected at most 1 cut, got {}",
            cuts.len()
        );
    }

    #[test]
    fn test_zero_max_cuts_returns_empty() {
        let config = LiftProjectConfig {
            max_cuts: 0,
            ..Default::default()
        };
        let generator = LiftProjectGenerator::new(config);
        let (a, b) = simple_constraints();
        let x_bar = simple_x_bar();
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &[0, 1]).unwrap();
        assert!(cuts.is_empty());
    }

    #[test]
    fn test_generate_cuts_error_on_empty_x_bar() {
        let generator = LiftProjectGenerator::default_generator();
        let result = generator.generate_cuts(&[], &[], &[], &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_generate_cuts_error_on_mismatched_a_b() {
        let generator = LiftProjectGenerator::default_generator();
        let a = vec![vec![1.0, 1.0], vec![1.0, 0.0]];
        let b = vec![1.0];
        let result = generator.generate_cuts(&a, &b, &[0.5, 0.5], &[0, 1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_generate_cuts_error_on_row_length_mismatch() {
        let generator = LiftProjectGenerator::default_generator();
        let a = vec![vec![1.0, 1.0, 0.5]];
        let b = vec![1.0];
        let x_bar = vec![0.5, 0.5];
        let result = generator.generate_cuts(&a, &b, &x_bar, &[0, 1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_deepest_cut_strategy_returns_most_violated() {
        let config = LiftProjectConfig {
            variable_selection: VariableSelectionStrategy::DeepestCut,
            max_cuts: 10,
            ..Default::default()
        };
        let generator = LiftProjectGenerator::new(config);
        let a = vec![
            vec![1.0, 1.0],
            vec![2.0, 1.0],
            vec![1.0, 2.0],
        ];
        let b = vec![1.5, 2.0, 2.0];
        let x_bar = vec![0.4, 0.6];
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &[0, 1]).unwrap();
        for w in cuts.windows(2) {
            assert!(
                w[0].violation >= w[1].violation - 1e-12,
                "Cuts should be sorted by decreasing violation"
            );
        }
    }

    #[test]
    fn test_generate_cut_for_var_negative_coefficient() {
        // The bound row `−x0 ≤ 0` always yields a cut here, so `None` would be a regression.
        let generator = LiftProjectGenerator::default_generator();
        let a = vec![vec![-1.0, 1.0]];
        let b = vec![0.5];
        let x_bar = vec![0.3, 0.7];
        let cut = generator
            .generate_cut_for_var(&a, &b, &x_bar, 0)
            .expect("Should generate a cut for variable 0");
        assert_eq!(cut.pi.len(), 2);
        assert!(cut.violation >= 0.0);
    }

    #[test]
    fn test_generate_cut_for_var_uses_bound_row_when_no_structural_row_has_coeff() {
        let generator = LiftProjectGenerator::default_generator();
        let a = vec![vec![0.0, 1.0]];
        let b = vec![0.8];
        let x_bar = vec![0.4, 0.6];
        let cut = generator.generate_cut_for_var(&a, &b, &x_bar, 0);
        assert!(
            cut.is_some(),
            "Should get a cut from the bound row for variable 0"
        );
    }

    #[test]
    fn test_ls_strengthen_does_not_decrease_coefficients_to_negative() {
        let generator = LiftProjectGenerator::default_generator();
        let (a, b) = simple_constraints();
        let x_bar = simple_x_bar();
        let cut = generator
            .generate_cut_for_var(&a, &b, &x_bar, 0)
            .expect("Should generate a cut for variable 0");
        let strengthened = ls_strengthen(&cut, &x_bar, &[0, 1], 0);
        for &pi_k in &strengthened.pi {
            assert!(
                pi_k >= 0.0 - 1e-12,
                "Coefficient should not become negative: {}",
                pi_k
            );
        }
    }

    #[test]
    fn test_cut_satisfied_at_integer_utility() {
        let cut = LiftProjectCut {
            pi: vec![1.0, 1.0],
            pi0: 0.0,
            source_var: 0,
            source_row: 0,
            violation: 0.5,
        };
        assert!(cut_satisfied_at_integer(&cut, &[0.0, 0.0]));
        assert!(cut_satisfied_at_integer(&cut, &[1.0, 0.0]));
        assert!(cut_satisfied_at_integer(&cut, &[0.0, 1.0]));
        assert!(cut_satisfied_at_integer(&cut, &[1.0, 1.0]));
    }

    #[test]
    fn test_bcc_violation_equals_stored_violation() {
        let generator = LiftProjectGenerator::default_generator();
        let a = vec![vec![2.0, 1.0], vec![1.0, 2.0]];
        let b = vec![3.0, 3.0];
        let x_bar = vec![0.6, 0.8];
        let cuts = generator.generate_cuts(&a, &b, &x_bar, &[0, 1]).unwrap();
        for cut in &cuts {
            let recomputed = generator.cut_violation(cut, &x_bar);
            assert!(
                (cut.violation - recomputed).abs() < 1e-12,
                "Stored violation {} != recomputed {}",
                cut.violation,
                recomputed
            );
        }
    }

    #[test]
    fn test_augmented_system_size() {
        let a = vec![vec![1.0, 1.0]];
        let b = vec![1.5];
        let x_bar = vec![0.5, 0.5];
        let integer_vars = vec![0, 1];
        let (a_aug, b_aug) = build_augmented_constraints(&a, &b, &x_bar, &integer_vars);
        assert_eq!(a_aug.len(), 5);
        assert_eq!(b_aug.len(), 5);
    }
}