use crate::error::OptimError;
use crate::types::{IterationRecord, OptimResult, OptimStatus};
use numra_core::Scalar;
use numra_linalg::{DenseMatrix, Matrix};
/// Tuning parameters for [`simplex_solve`].
#[derive(Clone, Debug)]
pub struct LPOptions<S: Scalar> {
    /// Maximum number of pivots allowed in each simplex phase before the solver gives up
    /// with an "max iterations reached" error.
    pub max_iter: usize,
    /// Numerical tolerance used when pricing columns, in the ratio test, when judging
    /// phase-1 feasibility, and when snapping tiny negative basic values to zero.
    pub tol: S,
    /// Verbosity flag. NOTE(review): not consulted by the simplex routines in this file.
    pub verbose: bool,
}

impl<S: Scalar> Default for LPOptions<S> {
    /// `max_iter = 10_000`, `tol = 1e-10`, `verbose = false`.
    fn default() -> Self {
        let tol = S::from_f64(1e-10);
        LPOptions {
            max_iter: 10_000,
            tol,
            verbose: false,
        }
    }
}
pub fn simplex_solve<
S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
>(
c: &[S],
a_ineq: &[Vec<S>],
b_ineq: &[S],
a_eq: &[Vec<S>],
b_eq: &[S],
opts: &LPOptions<S>,
) -> Result<OptimResult<S>, OptimError> {
let start = std::time::Instant::now();
let n_orig = c.len();
let m_ineq = a_ineq.len();
let m_eq = a_eq.len();
let m = m_ineq + m_eq;
if m == 0 && n_orig > 0 {
if c.iter().any(|&ci| ci < -opts.tol) {
return Err(OptimError::Unbounded);
}
let x = vec![S::ZERO; n_orig];
return Ok(OptimResult::unconstrained(
x,
S::ZERO,
c.to_vec(),
0,
0,
0,
true,
"Optimal at origin (no constraints)".into(),
OptimStatus::FunctionConverged,
)
.with_wall_time(start));
}
let n_slack = m_ineq;
let n_total = n_orig + n_slack;
let mut a_rows: Vec<Vec<S>> = Vec::with_capacity(m);
let mut b_std: Vec<S> = Vec::with_capacity(m);
for i in 0..m_eq {
if a_eq[i].len() != n_orig {
return Err(OptimError::DimensionMismatch {
expected: n_orig,
actual: a_eq[i].len(),
});
}
let mut row = a_eq[i].clone();
row.resize(n_total, S::ZERO);
let mut bi = b_eq[i];
if bi < S::ZERO {
for r in row.iter_mut() {
*r = -*r;
}
bi = -bi;
}
a_rows.push(row);
b_std.push(bi);
}
for i in 0..m_ineq {
if a_ineq[i].len() != n_orig {
return Err(OptimError::DimensionMismatch {
expected: n_orig,
actual: a_ineq[i].len(),
});
}
let mut row = a_ineq[i].clone();
row.resize(n_total, S::ZERO);
row[n_orig + i] = S::ONE;
let mut bi = b_ineq[i];
if bi < S::ZERO {
for r in row.iter_mut() {
*r = -*r;
}
bi = -bi;
}
a_rows.push(row);
b_std.push(bi);
}
let mut c_std = c.to_vec();
c_std.resize(n_total, S::ZERO);
let mut need_art = vec![false; m];
for item in need_art.iter_mut().take(m_eq) {
*item = true;
}
let has_artificials = need_art.iter().any(|&x| x);
let mut basis: Vec<usize> = Vec::with_capacity(m);
if has_artificials {
let n_phase1 = n_total + m;
let mut a_ph1: Vec<Vec<S>> = Vec::with_capacity(m);
for (i, row) in a_rows.iter().enumerate() {
let mut ph1_row = row.clone();
ph1_row.resize(n_phase1, S::ZERO);
if need_art[i] {
ph1_row[n_total + i] = S::ONE;
}
a_ph1.push(ph1_row);
}
let mut c_ph1 = vec![S::ZERO; n_phase1];
for i in 0..m {
if need_art[i] {
c_ph1[n_total + i] = S::ONE;
}
}
let mut basis_ph1 = Vec::with_capacity(m);
for i in 0..m_eq {
basis_ph1.push(n_total + i); }
for i in 0..m_ineq {
basis_ph1.push(n_orig + i); }
let ph1_result = simplex_core(&a_ph1, &mut b_std, &c_ph1, &mut basis_ph1, opts)?;
if ph1_result > opts.tol {
return Err(OptimError::LPInfeasible);
}
for &bj in &basis_ph1 {
if bj >= n_total {
let row = basis_ph1.iter().position(|&x| x == bj).unwrap();
if b_std[row].abs() > opts.tol {
return Err(OptimError::LPInfeasible);
}
}
}
basis = basis_ph1;
for row in 0..m {
if basis[row] >= n_total {
for (j, val) in a_rows[row].iter().enumerate().take(n_total) {
if !basis.contains(&j) && val.abs() > opts.tol {
basis[row] = j;
break;
}
}
}
}
} else {
for i in 0..m_ineq {
basis.push(n_orig + i);
}
}
let _opt_val = simplex_core(&a_rows, &mut b_std, &c_std, &mut basis, opts)?;
let mut x = vec![S::ZERO; n_total];
for (row, &bj) in basis.iter().enumerate() {
if bj < n_total {
x[bj] = b_std[row];
}
}
let x_orig: Vec<S> = x[..n_orig].to_vec();
let f_val: S = c
.iter()
.zip(x_orig.iter())
.map(|(&ci, &xi)| ci * xi)
.sum::<S>();
let mut b_mat = DenseMatrix::<S>::zeros(m, m);
for (col, &bj) in basis.iter().enumerate() {
for (i, a_row) in a_rows.iter().enumerate().take(m) {
b_mat.set(i, col, a_row[bj]);
}
}
let c_b: Vec<S> = basis.iter().map(|&j| c_std[j]).collect();
let mut bt = DenseMatrix::<S>::zeros(m, m);
for i in 0..m {
for j in 0..m {
bt.set(i, j, b_mat.get(j, i));
}
}
let pi = bt.solve(&c_b).unwrap_or_else(|_| vec![S::ZERO; m]);
let mut lambda_eq = Vec::with_capacity(m_eq);
let mut lambda_ineq = Vec::with_capacity(m_ineq);
for item in pi.iter().take(m_eq) {
lambda_eq.push(*item);
}
for i in 0..m_ineq {
lambda_ineq.push(pi[m_eq + i]);
}
let history = vec![IterationRecord {
iteration: 0,
objective: f_val,
gradient_norm: S::ZERO,
step_size: S::ZERO,
constraint_violation: S::ZERO,
}];
Ok((OptimResult {
lambda_eq,
lambda_ineq,
history,
..OptimResult::unconstrained(
x_orig,
f_val,
c.to_vec(),
0,
0,
0,
true,
"Optimal solution found".into(),
OptimStatus::FunctionConverged,
)
})
.with_wall_time(start))
}
/// Runs the revised primal simplex method starting from a given basis.
///
/// Inputs:
/// * `a_rows` is the `m × n` constraint matrix, stored row-major.
/// * `c` holds the `n` costs.
/// * `basis[p]` is the column that is basic at position `p`.
/// * `b[p]` is taken as the current value of that basic variable. The caller must supply
///   a starting basis for which this holds (for example `B = I`).
///
/// `b` and `basis` are updated in place and describe the final basis on return.
///
/// Each iteration does the following:
/// 1. Solves `Bᵀπ = c_B`.
/// 2. Prices the non-basic columns with Dantzig's rule: the most negative reduced cost
///    below `-tol` enters.
/// 3. Solves `B d = a_q` for the entering column.
/// 4. Runs a minimum-ratio test over entries `d[p] > tol`; the first minimum wins.
/// 5. After the pivot, snaps basic values in `(-tol, 0)` to zero.
///
/// # Returns
/// The objective `Σ c[basis[p]] · b[p]` at the first basis with no improving column.
///
/// # Errors
/// * `OptimError::SingularMatrix` if a basis system cannot be solved.
/// * `OptimError::Unbounded` if the entering column has no entry above `tol`.
/// * `OptimError::Other` if `opts.max_iter` iterations pass without reaching optimality.
fn simplex_core<
    S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField,
>(
    a_rows: &[Vec<S>],
    b: &mut [S],
    c: &[S],
    basis: &mut [usize],
    opts: &LPOptions<S>,
) -> Result<S, OptimError> {
    let m = a_rows.len();
    let n = c.len();

    for _ in 0..opts.max_iter {
        // Assemble B (column p = basic column basis[p]) and its transpose in one pass.
        let mut basis_mat = DenseMatrix::<S>::zeros(m, m);
        let mut basis_mat_t = DenseMatrix::<S>::zeros(m, m);
        for (pos, &col) in basis.iter().enumerate() {
            for (row, coeffs) in a_rows.iter().enumerate() {
                let v = coeffs[col];
                basis_mat.set(row, pos, v);
                basis_mat_t.set(pos, row, v);
            }
        }

        // Simplex multipliers: Bᵀπ = c_B.
        let costs_b: Vec<S> = basis.iter().map(|&col| c[col]).collect();
        let pi = basis_mat_t
            .solve(&costs_b)
            .map_err(|_| OptimError::SingularMatrix)?;

        // Dantzig pricing: the most negative reduced cost strictly below -tol enters.
        let mut threshold = -opts.tol;
        let mut entering = None;
        for j in (0..n).filter(|j| !basis.contains(j)) {
            let reduced = pi
                .iter()
                .zip(a_rows)
                .fold(c[j], |acc, (&p, row)| acc - p * row[j]);
            if reduced < threshold {
                threshold = reduced;
                entering = Some(j);
            }
        }
        let q = match entering {
            Some(q) => q,
            None => {
                // No improving column: the current basis is optimal.
                return Ok(basis
                    .iter()
                    .zip(b.iter())
                    .map(|(&col, &val)| c[col] * val)
                    .sum::<S>());
            }
        };

        // Direction of change of the basic variables: B d = a_q.
        let column: Vec<S> = a_rows.iter().map(|row| row[q]).collect();
        let dir = basis_mat
            .solve(&column)
            .map_err(|_| OptimError::SingularMatrix)?;

        // Minimum-ratio test over positive direction entries; first minimum wins.
        let mut step = S::INFINITY;
        let mut pivot_pos = None;
        for (pos, (&val, &dp)) in b.iter().zip(&dir).enumerate() {
            if dp > opts.tol {
                let ratio = val / dp;
                if ratio < step {
                    step = ratio;
                    pivot_pos = Some(pos);
                }
            }
        }
        let pivot_pos = pivot_pos.ok_or(OptimError::Unbounded)?;

        // Move along the edge: the entering variable takes value `step` at pivot_pos.
        for (pos, (val, &dp)) in b.iter_mut().zip(&dir).enumerate() {
            if pos == pivot_pos {
                *val = step;
            } else {
                *val -= step * dp;
            }
        }
        basis[pivot_pos] = q;

        // Snap round-off negatives back to zero to keep the basis primal feasible.
        b.iter_mut()
            .filter(|v| **v < S::ZERO && **v > -opts.tol)
            .for_each(|v| *v = S::ZERO);
    }

    Err(OptimError::Other(format!(
        "simplex: max iterations ({}) reached",
        opts.max_iter
    )))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// max x + y  s.t.  x + y <= 4, x <= 3, y <= 3, x, y >= 0  (optimum value 4 on an edge).
    #[test]
    fn test_lp_simple_2d() {
        let c = vec![-1.0, -1.0];
        let a_ineq = vec![vec![1.0, 1.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let b_ineq = vec![4.0, 3.0, 3.0];
        let opts = LPOptions::default();
        let result = simplex_solve(&c, &a_ineq, &b_ineq, &[], &[], &opts).unwrap();
        assert!(result.converged, "LP did not converge: {}", result.message);
        assert!(
            (result.f - (-4.0)).abs() < 1e-8,
            "f={}, expected -4",
            result.f
        );
        // Two-sided check: the optimal face is x + y = 4 exactly.
        let sum = result.x[0] + result.x[1];
        assert!((sum - 4.0).abs() < 1e-8, "x0 + x1 = {}, expected 4", sum);
    }

    /// min x + 2y  s.t.  x + y = 3  →  x = 3, y = 0.
    #[test]
    fn test_lp_with_equality() {
        let c = vec![1.0, 2.0];
        let a_eq = vec![vec![1.0, 1.0]];
        let b_eq = vec![3.0];
        let opts = LPOptions::default();
        let result = simplex_solve(&c, &[], &[], &a_eq, &b_eq, &opts).unwrap();
        assert!(result.converged, "LP did not converge: {}", result.message);
        assert!((result.f - 3.0).abs() < 1e-8, "f={}", result.f);
        assert!((result.x[0] - 3.0).abs() < 1e-8);
        assert!(result.x[1].abs() < 1e-8);
    }

    /// min -x with only x >= 0 is unbounded below.
    #[test]
    fn test_lp_unbounded() {
        let c = vec![-1.0];
        let opts = LPOptions::default();
        let result = simplex_solve(&c, &[], &[], &[], &[], &opts);
        assert!(matches!(result, Err(OptimError::Unbounded)));
    }

    /// x + y = 1 and x + y = 2 cannot both hold.
    #[test]
    fn test_lp_infeasible() {
        let c2 = vec![1.0, 1.0];
        let a_eq2 = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
        let b_eq2 = vec![1.0, 2.0];
        let opts = LPOptions::default();
        let result = simplex_solve(&c2, &[], &[], &a_eq2, &b_eq2, &opts);
        assert!(matches!(result, Err(OptimError::LPInfeasible)));
    }

    /// max 5x + 4y + 3z  s.t.  6x + 4y + 2z <= 240, 3x + 2y + 5z <= 270.
    /// Optimal vertex y = 41.25, z = 37.5, x = 0 with value 277.5.
    #[test]
    fn test_lp_3d_production() {
        let c = vec![-5.0, -4.0, -3.0];
        let a_ineq = vec![vec![6.0, 4.0, 2.0], vec![3.0, 2.0, 5.0]];
        let b_ineq = vec![240.0, 270.0];
        let opts = LPOptions::default();
        let result = simplex_solve(&c, &a_ineq, &b_ineq, &[], &[], &opts).unwrap();
        assert!(result.converged, "LP did not converge: {}", result.message);
        assert!(
            (result.f - (-277.5)).abs() < 1e-6,
            "f={}, expected -277.5",
            result.f
        );
    }

    /// min -x - y  s.t.  x + y <= 1: the single constraint has shadow price of magnitude 1.
    #[test]
    fn test_lp_dual_variables() {
        let c = vec![-1.0, -1.0];
        let a_ineq = vec![vec![1.0, 1.0]];
        let b_ineq = vec![1.0];
        let opts = LPOptions::default();
        let result = simplex_solve(&c, &a_ineq, &b_ineq, &[], &[], &opts).unwrap();
        assert!(result.converged);
        assert!((result.f - (-1.0)).abs() < 1e-8, "f={}", result.f);
        assert_eq!(result.lambda_ineq.len(), 1);
        assert!(
            (result.lambda_ineq[0].abs() - 1.0).abs() < 1e-8,
            "lambda_ineq[0]={}",
            result.lambda_ineq[0]
        );
    }

    /// A constraint row whose length differs from c.len() is rejected, not indexed.
    #[test]
    fn test_lp_dimension_mismatch() {
        let c = vec![1.0, 1.0];
        let a_ineq = vec![vec![1.0]];
        let b_ineq = vec![1.0];
        let opts = LPOptions::default();
        let result = simplex_solve(&c, &a_ineq, &b_ineq, &[], &[], &opts);
        assert!(matches!(
            result,
            Err(OptimError::DimensionMismatch {
                expected: 2,
                actual: 1
            })
        ));
    }
}