use std::ops::Neg;
use cnvx_core::*;
use cnvx_math::{DenseMatrix, Matrix, matrix::SparseMatrix};
/// Primal simplex LP solver working on a borrowed [`Model`].
///
/// The public settings may be changed after construction and before solving.
pub struct PrimalSimplexSolver<'model> {
    /// Maximum number of simplex iterations, shared by both phases of a solve.
    pub max_iter: usize,
    /// Numerical tolerance for reduced-cost, ratio-test and feasibility checks.
    pub tolerance: f64,
    /// Whether to print progress information.
    pub logging: bool,
    /// Simplex working data (dense or sparse matrix storage).
    state: State<'model>,
}
impl<'model> Solver<'model> for PrimalSimplexSolver<'model> {
fn new(model: &'model Model) -> Self {
Self {
state: State::Dense(PrimalSimplexState::new(model)),
tolerance: 1e-8,
max_iter: 1000,
logging: false,
}
}
fn solve(&mut self) -> Result<Solution, SolveError> {
match &self.state {
State::Dense(s) => crate::validate::check_lp(s.model)?,
State::Sparse(s) => crate::validate::check_lp(s.model)?,
}
let (values, obj) = match &mut self.state {
State::Dense(s) => s.solve_lp(self.max_iter, self.tolerance)?,
State::Sparse(s) => s.solve_lp(self.max_iter, self.tolerance)?,
};
if self.logging {
match &self.state {
State::Dense(s) => println!(
"Simplex finished with status {:?} in {} iterations. Objective value: {}",
s.status, s.iteration, obj
),
State::Sparse(s) => println!(
"Simplex finished with status {:?} in {} iterations. Objective value: {}",
s.status, s.iteration, obj
),
}
}
let status = match &self.state {
State::Dense(s) => s.status.clone(),
State::Sparse(s) => s.status.clone(),
};
Ok(Solution { values, objective_value: Some(obj), status })
}
fn get_objective_value(&self) -> f64 {
match &self.state {
State::Dense(s) => s.objective,
State::Sparse(s) => s.objective,
}
}
fn get_solution(&self) -> Vec<f64> {
vec![]
}
}
/// Storage-specific simplex state owned by [`PrimalSimplexSolver`].
// `PrimalSimplexSolver::new` only ever builds the `Dense` variant, so `Sparse` is never
// constructed here; the allow silences the resulting dead-code warning.
#[allow(dead_code)]
enum State<'model> {
    /// Simplex working data backed by a sparse matrix.
    Sparse(PrimalSimplexState<'model, SparseMatrix>),
    /// Simplex working data backed by a dense matrix.
    Dense(PrimalSimplexState<'model, DenseMatrix>),
}
/// Working data of the primal simplex method, generic over the matrix storage `A`.
///
/// The LP is held as `max c^T x` subject to `A x = b`. `A` already contains one
/// slack/surplus column per inequality constraint. During phase 1, `a` and `c`
/// temporarily hold the augmented problem with artificial columns.
#[derive(Clone)]
pub struct PrimalSimplexState<'model, A: Matrix> {
    /// Model the data was built from.
    pub model: &'model Model,
    /// Constraint matrix, including slack/surplus columns.
    pub a: A,
    /// Right-hand side, one entry per constraint row.
    pub b: Vec<f64>,
    /// Objective coefficients in maximisation form.
    pub c: Vec<f64>,
    /// Column index of the basic variable in each basis position.
    pub basis: Vec<usize>,
    /// Column indices of the non-basic variables.
    pub non_basis: Vec<usize>,
    /// Values of the basic variables, aligned with `basis`.
    pub x_b: Vec<f64>,
    /// Simplex iteration counter.
    pub iteration: usize,
    /// Current objective value in maximisation form.
    pub objective: f64,
    /// Outcome of the most recent solve.
    pub status: SolveStatus,
    /// True when the model minimises (the coefficients in `c` are then negated).
    minimise: bool,
    /// Whether to print a progress line every `log_interval` iterations.
    logging: bool,
    /// Number of iterations between progress lines.
    log_interval: usize,
}
impl<'model, A: Matrix> PrimalSimplexState<'model, A> {
    /// Builds the simplex data for `model`.
    ///
    /// Each `<=` constraint receives a slack column with coefficient `+1` and each `>=`
    /// constraint a surplus column with coefficient `-1`; equality rows get no extra
    /// column. The objective is stored in maximisation form: the coefficients of a
    /// minimisation objective are negated. Repeated terms for the same variable within one
    /// expression are summed.
    pub fn new(model: &'model Model) -> Self {
        let n_vars = model.vars().len();
        let n_cons = model.constraints().len();
        let mut b = vec![0.0; n_cons];
        let mut n_total = n_vars;
        for cons in model.constraints().iter() {
            match cons.cmp {
                Cmp::Leq | Cmp::Geq => n_total += 1,
                Cmp::Eq => {}
            }
        }
        let mut a = A::new(n_cons, n_total);
        let mut c = vec![0.0; n_total];
        let minimise =
            model.objective().map(|o| o.sense == Sense::Minimize).unwrap_or(false);
        if let Some(obj) = model.objective() {
            for term in &obj.expr.terms {
                // Accumulate so duplicate terms add up instead of overwriting each other.
                c[term.var.0] += match obj.sense {
                    Sense::Maximize => term.coeff,
                    Sense::Minimize => -term.coeff,
                };
            }
        }
        let mut extra_idx = n_vars;
        for (i, cons) in model.constraints().iter().enumerate() {
            b[i] = cons.rhs;
            for term in &cons.expr.terms {
                let j = term.var.0;
                let current = a.get(i, j);
                a.set(i, j, current + term.coeff);
            }
            match cons.cmp {
                Cmp::Leq => {
                    a.set(i, extra_idx, 1.0);
                    extra_idx += 1;
                }
                Cmp::Geq => {
                    a.set(i, extra_idx, -1.0);
                    extra_idx += 1;
                }
                Cmp::Eq => {}
            }
        }
        Self {
            model,
            iteration: 0,
            basis: Vec::new(),
            non_basis: (0..n_vars).collect(),
            x_b: vec![0.0; n_cons],
            a,
            b,
            c,
            objective: 0.0,
            status: SolveStatus::NotSolved,
            minimise,
            // Progress output is opt-in; a library should not write to stdout unasked.
            logging: false,
            log_interval: 100,
        }
    }

    /// Solves the LP with the two-phase primal simplex method.
    ///
    /// If the starting basis chosen by [`Self::init_basis`] is already feasible, phase 2
    /// runs directly. Otherwise phase 1 searches for a feasible basis, and phase 2 runs
    /// only when one exists. On return, `self.status` is `Optimal`, `Unbounded` or
    /// `Infeasible`.
    ///
    /// Returns the values of all structural and slack/surplus columns, together with the
    /// objective value in the model's own sense.
    ///
    /// # Errors
    /// Fails when a linear solve fails, when `max_iter` is reached, or when phase 1
    /// cannot drive an artificial variable out of the basis.
    pub fn solve_lp(
        &mut self,
        max_iter: usize,
        tol: f64,
    ) -> Result<(Vec<f64>, f64), SolveError> {
        // Start every solve from a clean slate. A repeated call must not inherit the
        // previous iteration count (which would eat into `max_iter`) or its status.
        self.iteration = 0;
        self.status = SolveStatus::NotSolved;
        self.init_basis();
        let orig_n = self.a.cols();
        if self.try_phase2(max_iter, tol)? {
            return Ok(self.extract_solution(orig_n));
        }
        // Phase 2 must not run on an infeasible problem: it would overwrite the
        // `Infeasible` status with the result of optimising a meaningless basis.
        if self.phase1(orig_n, max_iter, tol)? {
            self.phase2(max_iter, tol)?;
        }
        Ok(self.extract_solution(orig_n))
    }

    /// Runs phase 2 directly from the basis chosen by `init_basis`, if that basis is
    /// usable and primal feasible (all basic values `>= -tol`).
    ///
    /// Returns `Ok(false)` without iterating when the basis cannot be used.
    fn try_phase2(&mut self, max_iter: usize, tol: f64) -> Result<bool, SolveError> {
        let (m, n) = (self.a.rows(), self.a.cols());
        // The fallback basis of `init_basis` names columns `0..m`, which do not all exist
        // when there are more rows than columns.
        if self.basis.len() != m || self.basis.iter().any(|&j| j >= n) {
            return Ok(false);
        }
        let mut bmat = self.build_bmat();
        match self.compute_basic_solution(&mut bmat) {
            Ok(xb) if xb.iter().all(|&v| v >= -tol) => {
                self.x_b = xb;
                self.run_simplex(&mut bmat, max_iter, tol)?;
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// Phase 1: finds a feasible basis of the original problem via artificial variables.
    ///
    /// The original `a` and `c` are always restored afterwards, including on
    /// infeasibility and on error. Returns `Ok(true)` when a feasible basis made only of
    /// original columns was found, with `non_basis` rebuilt for phase 2. Returns
    /// `Ok(false)` (status `Infeasible`) otherwise.
    fn phase1(
        &mut self,
        orig_n: usize,
        max_iter: usize,
        tol: f64,
    ) -> Result<bool, SolveError> {
        let (orig_a, orig_c, mut bmat) = self.setup_phase1(orig_n);
        let outcome = self.run_phase1_simplex(&mut bmat, orig_n, max_iter, tol);
        self.a = orig_a;
        self.c = orig_c;
        if !outcome? {
            return Ok(false);
        }
        let mut in_basis = vec![false; orig_n];
        for &j in &self.basis {
            if j < orig_n {
                in_basis[j] = true;
            }
        }
        self.non_basis = (0..orig_n).filter(|&j| !in_basis[j]).collect();
        Ok(true)
    }

    /// Iterates the phase-1 problem to optimality and checks the artificial sum.
    ///
    /// If feasible, pivots all artificial variables out of the basis (while `self.a` is
    /// still the augmented matrix) and returns `Ok(true)`. If infeasible, sets the status
    /// to `Infeasible` and returns `Ok(false)`.
    fn run_phase1_simplex(
        &mut self,
        bmat: &mut A,
        orig_n: usize,
        max_iter: usize,
        tol: f64,
    ) -> Result<bool, SolveError> {
        self.run_simplex(bmat, max_iter, tol)?;
        // Phase-1 costs are -1 on artificials and 0 elsewhere, so this is the sum of the
        // artificial values.
        let sum_art: f64 = self
            .basis
            .iter()
            .enumerate()
            .map(|(i, &v)| self.c[v] * self.x_b[i])
            .sum::<f64>()
            .neg();
        if sum_art > tol {
            self.status = SolveStatus::Infeasible;
            return Ok(false);
        }
        self.remove_artificial_from_basis(bmat, orig_n)
            .map_err(SolveError::InvalidModel)?;
        Ok(true)
    }

    /// Phase 2: optimises the original objective from the current (feasible) basis.
    fn phase2(&mut self, max_iter: usize, tol: f64) -> Result<(), SolveError> {
        let mut bmat = self.build_bmat();
        self.run_simplex(&mut bmat, max_iter, tol)
    }

    /// Chooses a starting basis.
    ///
    /// A column qualifies as a unit column if its only non-zero entry (|v| > 1e-12) is a 1.
    /// If every row can be assigned a distinct unit column, those columns form the basis.
    /// Otherwise the basis falls back to columns `0..m`. `try_phase2` rejects that
    /// fallback when it names non-existent columns.
    pub fn init_basis(&mut self) {
        let m = self.a.rows();
        let n = self.a.cols();
        let mut basis: Vec<Option<usize>> = vec![None; m];
        let mut used = vec![false; n];
        for (j, used_j) in used.iter_mut().enumerate() {
            let mut one_row = None;
            let mut ok = true;
            for i in 0..m {
                let v = self.a.get(i, j);
                if v.abs() > 1e-12 {
                    if (v - 1.0).abs() < 1e-12 && one_row.is_none() {
                        one_row = Some(i);
                    } else {
                        ok = false;
                        break;
                    }
                }
            }
            if let (true, Some(r)) = (ok, one_row) {
                if basis[r].is_none() {
                    basis[r] = Some(j);
                    *used_j = true;
                }
            }
        }
        if basis.iter().all(Option::is_some) {
            self.basis = basis.into_iter().flatten().collect();
            self.non_basis = (0..n).filter(|&j| !used[j]).collect();
        } else {
            self.basis = (0..m).collect();
            self.non_basis = (m..n).collect();
        }
    }

    /// Builds the `m x m` basis matrix, whose column `j` is `A[:, basis[j]]`.
    pub fn build_bmat(&self) -> A {
        let m = self.a.rows();
        let mut bmat = A::new(m, m);
        for i in 0..m {
            for j in 0..m {
                bmat.set(i, j, self.a.get(i, self.basis[j]));
            }
        }
        bmat
    }

    /// Solves `B x_B = b` for the basic variable values.
    ///
    /// # Errors
    /// Returns a message when the linear solve fails.
    pub fn compute_basic_solution(&self, bmat: &mut A) -> Result<Vec<f64>, String> {
        let mut xb = self.b.clone();
        bmat.mldivide(&mut xb).map_err(|e| format!("gauss failed: {e}"))?;
        Ok(xb)
    }

    /// Core revised-simplex loop, continuing from `self.iteration` up to `max_iter`.
    ///
    /// Sets the status to `Optimal` (no improving column) or `Unbounded` (no blocking
    /// row), and refreshes `self.objective` on termination so it is never stale.
    ///
    /// # Errors
    /// Fails when a linear solve fails or the iteration limit is reached.
    fn run_simplex(
        &mut self,
        bmat: &mut A,
        max_iter: usize,
        tol: f64,
    ) -> Result<(), SolveError> {
        let start = self.iteration;
        for iter in start..max_iter {
            self.iteration = iter;
            let pi = self.compute_duals(bmat)?;
            let Some((nb_pos, entering)) = self.choose_entering(&pi, tol) else {
                self.status = SolveStatus::Optimal;
                self.update_objective();
                return Ok(());
            };
            let d = self.compute_direction(bmat, entering)?;
            let Some((leave_row, theta)) = self.choose_leaving(&d, tol) else {
                self.status = SolveStatus::Unbounded;
                self.update_objective();
                return Ok(());
            };
            self.update_primal(&d, leave_row, theta);
            self.pivot(bmat, nb_pos, leave_row, entering);
            self.update_objective();
            if self.logging && (iter + 1) % self.log_interval == 0 {
                println!(
                    "Iteration {:>4}: Objective = {:>12.6}",
                    iter + 1,
                    if self.minimise { -self.objective } else { self.objective }
                );
            }
        }
        Err(SolveError::Other("max iterations reached".into()))
    }

    /// Solves `B^T pi = c_B` for the simplex multipliers.
    fn compute_duals(&self, bmat: &A) -> Result<Vec<f64>, SolveError> {
        let m = bmat.rows();
        let mut pi = (0..m).map(|i| self.c[self.basis[i]]).collect::<Vec<_>>();
        let mut bt = A::new(m, m);
        for i in 0..m {
            for j in 0..m {
                bt.set(i, j, bmat.get(j, i));
            }
        }
        bt.mldivide(&mut pi)
            .map_err(|e| SolveError::Other(format!("dual solve failed: {e}")))?;
        Ok(pi)
    }

    /// Dantzig pricing: picks the non-basic column with the largest reduced cost above
    /// `tol`. Returns `(position in non_basis, column index)`, or `None` if none improves.
    fn choose_entering(&self, pi: &[f64], tol: f64) -> Option<(usize, usize)> {
        self.non_basis
            .iter()
            .enumerate()
            .filter_map(|(pos, &j)| {
                let rc = self.c[j]
                    - (0..pi.len()).map(|i| pi[i] * self.a.get(i, j)).sum::<f64>();
                (rc > tol).then_some((pos, j, rc))
            })
            // `total_cmp` cannot panic on NaN, unlike `partial_cmp(..).unwrap()`.
            .max_by(|a, b| a.2.total_cmp(&b.2))
            .map(|(pos, j, _)| (pos, j))
    }

    /// Solves `B d = A[:, entering]` for the change of the basic variables.
    fn compute_direction(
        &self,
        bmat: &mut A,
        entering: usize,
    ) -> Result<Vec<f64>, SolveError> {
        let mut d = (0..bmat.rows()).map(|i| self.a.get(i, entering)).collect::<Vec<_>>();
        bmat.mldivide(&mut d)
            .map_err(|e| SolveError::Other(format!("direction solve failed: {e}")))?;
        Ok(d)
    }

    /// Minimum-ratio test over rows with `d[i] > tol`. Returns `(leaving row, step)`, or
    /// `None` when no row blocks the step (unbounded direction).
    fn choose_leaving(&self, d: &[f64], tol: f64) -> Option<(usize, f64)> {
        (0..d.len())
            .filter(|&i| d[i] > tol)
            .map(|i| (i, self.x_b[i] / d[i]))
            .min_by(|a, b| a.1.total_cmp(&b.1))
    }

    /// Moves the basic values by `theta` along `-d`, snapping near-zero values to zero.
    /// The entering variable, placed at row `leave`, takes the value `theta`.
    fn update_primal(&mut self, d: &[f64], leave: usize, theta: f64) {
        for (xi, di) in self.x_b.iter_mut().zip(d.iter()) {
            *xi -= theta * di;
            if (*xi).abs() < 1e-12 {
                *xi = 0.0;
            }
        }
        self.x_b[leave] = theta;
    }

    /// Swaps `entering` into basis position `leave_row` and replaces the matching column
    /// of `bmat`.
    fn pivot(
        &mut self,
        bmat: &mut A,
        enter_pos: usize,
        leave_row: usize,
        entering: usize,
    ) {
        let leaving = self.basis[leave_row];
        self.basis[leave_row] = entering;
        self.non_basis[enter_pos] = leaving;
        for i in 0..bmat.rows() {
            bmat.set(i, leave_row, self.a.get(i, entering));
        }
    }

    /// Recomputes `objective = c_B · x_B` (maximisation form).
    fn update_objective(&mut self) {
        self.objective = self
            .basis
            .iter()
            .enumerate()
            .map(|(i, &v)| self.c[v] * self.x_b[i])
            .sum();
    }

    /// Installs the phase-1 problem and returns `(original A, original c, basis matrix)`.
    ///
    /// Steps:
    /// - Rows with a negative right-hand side are negated, so the artificial start is
    ///   feasible.
    /// - An identity block of `m` artificial columns is appended.
    /// - Costs are `-1` on the artificials and `0` elsewhere.
    /// - The artificials become the basis, with `x_b = |b|`.
    pub fn setup_phase1(&mut self, orig_n: usize) -> (A, Vec<f64>, A) {
        let m = self.a.rows();
        let n = self.a.cols();
        let mut a_aug = A::new(m, n + m);
        let mut b_aug = self.b.clone();
        for (i, bval) in b_aug.iter_mut().enumerate() {
            if *bval < 0.0 {
                *bval = -*bval;
                for j in 0..n {
                    a_aug.set(i, j, -self.a.get(i, j));
                }
            } else {
                for j in 0..n {
                    a_aug.set(i, j, self.a.get(i, j));
                }
            }
            for j in 0..m {
                a_aug.set(i, n + j, if i == j { 1.0 } else { 0.0 });
            }
        }
        let mut c_aug = vec![0.0; n + m];
        for cost in c_aug.iter_mut().skip(n) {
            *cost = -1.0;
        }
        let orig_a = self.a.clone();
        let orig_c = self.c.clone();
        self.a = a_aug;
        self.c = c_aug;
        self.basis = (orig_n..orig_n + m).collect();
        self.non_basis = (0..orig_n).collect();
        self.x_b = b_aug;
        let bmat = self.build_bmat();
        (orig_a, orig_c, bmat)
    }

    /// Replaces every basic artificial variable (column index `>= orig_n`) by a
    /// non-basic original column.
    ///
    /// A column `j` may replace the artificial at basis position `row` only if the
    /// tableau entry `(B^{-1} A_j)[row]` is non-zero; otherwise the new basis would be
    /// singular. The swap is a degenerate pivot because the artificial is expected to sit
    /// at zero after a successful phase 1, so `x_b` is left unchanged.
    ///
    /// # Errors
    /// Returns a message when a linear solve fails, when an artificial with a non-zero
    /// value cannot be removed, or when no original column can replace it (the
    /// constraint rows are linearly dependent).
    pub fn remove_artificial_from_basis(
        &mut self,
        bmat: &mut A,
        orig_n: usize,
    ) -> Result<(), String> {
        // Smallest |tableau entry| accepted as a pivot element.
        const PIVOT_TOL: f64 = 1e-9;
        let m = bmat.rows();
        for row in 0..m {
            if self.basis[row] < orig_n {
                continue;
            }
            let mut pivot = None;
            for (nb_pos, &j) in self.non_basis.iter().enumerate() {
                if j >= orig_n {
                    continue;
                }
                let mut d: Vec<f64> = (0..m).map(|i| self.a.get(i, j)).collect();
                bmat.mldivide(&mut d).map_err(|e| format!("gauss failed: {e}"))?;
                if d[row].abs() > PIVOT_TOL {
                    pivot = Some((nb_pos, j));
                    break;
                }
            }
            match pivot {
                Some((nb_pos, j)) => self.pivot(bmat, nb_pos, row, j),
                None if self.x_b[row].abs() > 1e-12 => {
                    return Err(
                        "artificial variable left in basis with non-zero value".into()
                    );
                }
                None => {
                    return Err(
                        "artificial variable cannot leave the basis: constraints are linearly dependent"
                            .into(),
                    );
                }
            }
        }
        Ok(())
    }

    /// Maps the basic values onto the first `orig_n` columns (non-basic columns are zero)
    /// and returns them together with the objective in the model's own sense.
    pub fn extract_solution(&self, orig_n: usize) -> (Vec<f64>, f64) {
        let m = self.a.rows();
        let mut sol = vec![0.0; orig_n];
        for i in 0..m {
            if self.basis[i] < orig_n {
                sol[self.basis[i]] = self.x_b[i];
            }
        }
        let mut obj = self
            .basis
            .iter()
            .enumerate()
            .filter(|(_, v)| **v < orig_n)
            .map(|(i, v)| self.c[*v] * self.x_b[i])
            .sum::<f64>();
        if self.minimise {
            obj = -obj;
        }
        (sol, obj)
    }
}