pub use cobre_solver::ffi::{HIGHS_BASIS_STATUS_BASIC, HIGHS_BASIS_STATUS_LOWER};
use cobre_solver::Basis;
use crate::cut::pool::CutPool;
/// Extends `basis.row_status` with one entry per active cut in `pool`,
/// placed after the first `base_row_count` (non-cut) rows.
///
/// Entries that already exist in `basis.row_status` are left untouched; only
/// the missing trailing cut rows are added. Each newly added cut row is
/// classified by its slack at the given point:
///
/// ```text
/// cut_value = intercept + Σ coefficients[i] * state[i]
/// slack     = theta_value - cut_value
/// ```
///
/// * `slack <= tolerance` (tight or violated) → [`HIGHS_BASIS_STATUS_LOWER`]
/// * otherwise → [`HIGHS_BASIS_STATUS_BASIC`]
///
/// The cut rows are matched, in order, to the cuts yielded by
/// `pool.active_cuts()`. The first `row_status.len() - base_row_count` of those
/// cuts are treated as already padded and skipped.
/// NOTE(review): this assumes the active-cut iteration order is stable between
/// calls for the same basis — confirm against `CutPool`.
///
/// # Arguments
///
/// * `basis` – basis whose `row_status` is extended in place.
/// * `pool` – cut pool supplying intercepts and coefficients of active cuts.
/// * `state` – point at which cuts are evaluated; should have
///   `pool.state_dimension` entries.
/// * `theta_value` – value compared against each cut's value to get its slack.
/// * `base_row_count` – number of leading non-cut rows in the basis.
/// * `tolerance` – slack threshold at or below which a cut is considered tight.
///
/// # Returns
///
/// `(tight_count, slack_count)` for the cut rows added by *this* call only.
/// Returns `(0, 0)` without modifying `basis` when `row_status` already has at
/// least `base_row_count + pool.active_count()` entries.
///
/// # Panics
///
/// In debug builds, panics if `state.len() != pool.state_dimension` or if
/// `basis.row_status` is shorter than `base_row_count`. In release builds a
/// `row_status` shorter than `base_row_count` is extended with
/// [`HIGHS_BASIS_STATUS_BASIC`] for the missing base rows as well.
pub fn pad_basis_for_cuts(
    basis: &mut Basis,
    pool: &CutPool,
    state: &[f64],
    theta_value: f64,
    base_row_count: usize,
    tolerance: f64,
) -> (usize, usize) {
    debug_assert_eq!(
        state.len(),
        pool.state_dimension,
        "state length {} != pool.state_dimension {}",
        state.len(),
        pool.state_dimension,
    );
    debug_assert!(
        basis.row_status.len() >= base_row_count,
        "basis.row_status.len() {} < base_row_count {}",
        basis.row_status.len(),
        base_row_count,
    );

    let target_len = base_row_count + pool.active_count();
    let old_len = basis.row_status.len();
    if old_len >= target_len {
        return (0, 0);
    }

    // New rows default to BASIC; only tight/violated cuts are switched to LOWER below.
    basis
        .row_status
        .resize(target_len, HIGHS_BASIS_STATUS_BASIC);

    // Cut rows already present from an earlier call keep whatever status they had.
    let already_padded_cuts = old_len.saturating_sub(base_row_count);

    let mut tight_count: usize = 0;
    let mut slack_count: usize = 0;
    for (cut_index, (_slot, intercept, coefficients)) in
        pool.active_cuts().enumerate().skip(already_padded_cuts)
    {
        let cut_value: f64 = intercept
            + coefficients
                .iter()
                .zip(state)
                .map(|(c, x)| c * x)
                .sum::<f64>();
        let slack = theta_value - cut_value;
        if slack <= tolerance {
            basis.row_status[base_row_count + cut_index] = HIGHS_BASIS_STATUS_LOWER;
            tight_count += 1;
        } else {
            slack_count += 1;
        }
    }
    (tight_count, slack_count)
}
#[cfg(test)]
mod tests {
    use cobre_solver::Basis;

    use super::{HIGHS_BASIS_STATUS_BASIC, HIGHS_BASIS_STATUS_LOWER, pad_basis_for_cuts};
    use crate::cut::pool::CutPool;

    /// Builds a pool containing `cuts` as `(intercept, coefficients)` pairs,
    /// added in order, each with a `state_dim`-length coefficient vector.
    fn make_pool_with_cuts(cuts: &[(f64, Vec<f64>)], state_dim: usize) -> CutPool {
        let mut pool = CutPool::new(cuts.len().max(1) * 10, state_dim, 1, 0);
        for (i, (intercept, coeffs)) in cuts.iter().enumerate() {
            pool.add_cut(i as u64, 0, *intercept, coeffs);
        }
        pool
    }

    #[test]
    fn test_tight_and_slack_cuts_get_correct_status() {
        let pool = make_pool_with_cuts(
            &[(10.0, vec![1.0]), (20.0, vec![2.0]), (30.0, vec![3.0])],
            1,
        );
        let mut basis = Basis::new(5, 2);
        let (tight, slack) = pad_basis_for_cuts(&mut basis, &pool, &[5.0], 25.0, 2, 1e-7);
        assert_eq!(basis.row_status.len(), 5, "basis must grow to base+active");
        assert_eq!(
            basis.row_status[2], HIGHS_BASIS_STATUS_BASIC,
            "cut 0 slack=10 → BASIC"
        );
        assert_eq!(
            basis.row_status[3], HIGHS_BASIS_STATUS_LOWER,
            "cut 1 slack=-5 → NONBASIC_LOWER"
        );
        assert_eq!(
            basis.row_status[4], HIGHS_BASIS_STATUS_LOWER,
            "cut 2 slack=-20 → NONBASIC_LOWER"
        );
        assert_eq!(tight, 2, "two tight/violated cuts");
        assert_eq!(slack, 1, "one slack cut");
    }

    #[test]
    fn test_exactly_tight_cut_is_nonbasic_lower() {
        let pool = make_pool_with_cuts(&[(5.0, vec![1.0, 2.0])], 2);
        let mut basis = Basis::new(3, 0);
        let (tight, slack) = pad_basis_for_cuts(&mut basis, &pool, &[1.0, 1.0], 8.0, 0, 1e-7);
        assert_eq!(basis.row_status.len(), 1);
        assert_eq!(basis.row_status[0], HIGHS_BASIS_STATUS_LOWER);
        assert_eq!(tight, 1);
        assert_eq!(slack, 0);
    }

    #[test]
    fn test_empty_pool_is_noop() {
        let pool = CutPool::new(10, 2, 1, 0);
        let mut basis = Basis::new(3, 2);
        basis.row_status[0] = HIGHS_BASIS_STATUS_LOWER;
        basis.row_status[1] = HIGHS_BASIS_STATUS_BASIC;
        let (tight, slack) = pad_basis_for_cuts(&mut basis, &pool, &[1.0, 1.0], 5.0, 2, 1e-7);
        assert_eq!(basis.row_status.len(), 2, "row_status unchanged");
        assert_eq!(basis.row_status[0], HIGHS_BASIS_STATUS_LOWER);
        assert_eq!(basis.row_status[1], HIGHS_BASIS_STATUS_BASIC);
        assert_eq!(tight, 0);
        assert_eq!(slack, 0);
    }

    #[test]
    fn test_already_padded_basis_is_noop() {
        let pool = make_pool_with_cuts(&[(10.0, vec![1.0]), (20.0, vec![2.0])], 1);
        let mut basis = Basis::new(3, 4);
        basis.row_status[2] = HIGHS_BASIS_STATUS_LOWER;
        basis.row_status[3] = HIGHS_BASIS_STATUS_LOWER;
        let (tight, slack) = pad_basis_for_cuts(&mut basis, &pool, &[1.0], 5.0, 2, 1e-7);
        assert_eq!(basis.row_status.len(), 4, "row_status unchanged");
        assert_eq!(
            basis.row_status[2], HIGHS_BASIS_STATUS_LOWER,
            "prior status preserved"
        );
        assert_eq!(
            basis.row_status[3], HIGHS_BASIS_STATUS_LOWER,
            "prior status preserved"
        );
        assert_eq!(tight, 0);
        assert_eq!(slack, 0);
    }

    /// A basis that already covers some (but not all) cut rows must keep the
    /// existing rows and classify/count only the newly appended ones.
    #[test]
    fn test_partially_padded_basis_only_fills_missing_cut_rows() {
        let pool = make_pool_with_cuts(
            &[(10.0, vec![1.0]), (20.0, vec![2.0]), (30.0, vec![3.0])],
            1,
        );
        // Two base rows plus one cut row padded on an earlier call.
        let mut basis = Basis::new(5, 3);
        basis.row_status[0] = HIGHS_BASIS_STATUS_LOWER;
        // At theta=35 cut 0 has slack=20 and would be BASIC; LOWER marks it as prior state.
        basis.row_status[2] = HIGHS_BASIS_STATUS_LOWER;
        let (tight, slack) = pad_basis_for_cuts(&mut basis, &pool, &[5.0], 35.0, 2, 1e-7);
        assert_eq!(basis.row_status.len(), 5, "basis must grow to base+active");
        assert_eq!(
            basis.row_status[0], HIGHS_BASIS_STATUS_LOWER,
            "base row untouched"
        );
        assert_eq!(
            basis.row_status[2], HIGHS_BASIS_STATUS_LOWER,
            "already-padded cut row preserved"
        );
        assert_eq!(
            basis.row_status[3], HIGHS_BASIS_STATUS_BASIC,
            "cut 1 slack=5 → BASIC"
        );
        assert_eq!(
            basis.row_status[4], HIGHS_BASIS_STATUS_LOWER,
            "cut 2 slack=-10 → NONBASIC_LOWER"
        );
        assert_eq!(tight, 1, "only newly padded cuts are counted");
        assert_eq!(slack, 1, "only newly padded cuts are counted");
    }

    #[test]
    fn test_all_slack_cuts_get_basic() {
        let pool = make_pool_with_cuts(&[(1.0, vec![1.0]), (2.0, vec![2.0])], 1);
        let mut basis = Basis::new(3, 2);
        let (tight, slack) = pad_basis_for_cuts(&mut basis, &pool, &[1.0], 1000.0, 2, 1e-7);
        assert_eq!(basis.row_status.len(), 4);
        assert_eq!(basis.row_status[2], HIGHS_BASIS_STATUS_BASIC);
        assert_eq!(basis.row_status[3], HIGHS_BASIS_STATUS_BASIC);
        assert_eq!(tight, 0);
        assert_eq!(slack, 2);
    }

    #[test]
    fn test_negative_slack_is_tight() {
        let pool = make_pool_with_cuts(&[(100.0, vec![10.0])], 1);
        let mut basis = Basis::new(3, 1);
        let (tight, slack) = pad_basis_for_cuts(&mut basis, &pool, &[5.0], 1.0, 1, 1e-7);
        assert_eq!(basis.row_status.len(), 2);
        assert_eq!(basis.row_status[1], HIGHS_BASIS_STATUS_LOWER);
        assert_eq!(tight, 1);
        assert_eq!(slack, 0);
    }
}