Skip to main content

cnvx_lp/
primal_simplex.rs

1use std::ops::Neg;
2
3// FIXME: Replace with better solving techniques.
4use cnvx_core::*;
5use cnvx_math::{DenseMatrix, Matrix, matrix::SparseMatrix};
6
/// A simplex solver for linear programs (LPs).
///
/// The solver owns the simplex working state built from a borrowed model and
/// exposes its tuning knobs (`tolerance`, `max_iter`, `logging`) as public
/// fields. `solve` reads `tolerance` and `max_iter` at call time, so they may be
/// adjusted after construction.
///
/// # Examples
///
/// ```rust
/// # use cnvx_core::*;
/// # use cnvx_lp::PrimalSimplexSolver;
/// let mut model = Model::new();
/// let x = model.add_var().finish();
/// model += x.geq(0.0);
/// model += x.leq(10.0);
/// model.add_objective(Objective::maximize(x * 2.0).name("maximize_x"));
///
/// let mut solver = PrimalSimplexSolver::new(&model);
/// let solution = solver.solve().unwrap();
/// println!("Solution value: {}", solution.value(x));
/// ```
pub struct PrimalSimplexSolver<'model> {
    // Internal state of the simplex algorithm, including the tableau and current solution.
    // Wrapped in `State` so the same solver can hold either matrix backend.
    state: State<'model>,
    /// The numerical tolerance used for feasibility and optimality checks.
    ///
    /// Passed unchanged to the state's `solve_lp` on every call to `solve`.
    pub tolerance: f64,
    /// The maximum number of simplex iterations before terminating with an error.
    pub max_iter: usize,
    /// Whether to log iteration details during the simplex algorithm.
    ///
    /// In `solve` this controls the final summary line printed to stdout.
    pub logging: bool,
}
34
35impl<'model> Solver<'model> for PrimalSimplexSolver<'model> {
36    fn new(model: &'model Model) -> Self {
37        Self {
38            state: State::Dense(PrimalSimplexState::new(model)),
39            tolerance: 1e-8,
40            max_iter: 1000,
41            logging: false,
42        }
43    }
44
45    fn solve(&mut self) -> Result<Solution, SolveError> {
46        // crate::validate::check_lp(self.state.model)?;
47        match &self.state {
48            State::Dense(s) => crate::validate::check_lp(s.model)?,
49            State::Sparse(s) => crate::validate::check_lp(s.model)?,
50        }
51
52        // let (values, obj) = self.state.solve_lp(self.max_iter, self.tolerance)?;
53        let (values, obj) = match &mut self.state {
54            State::Dense(s) => s.solve_lp(self.max_iter, self.tolerance)?,
55            State::Sparse(s) => s.solve_lp(self.max_iter, self.tolerance)?,
56        };
57
58        if self.logging {
59            match &self.state {
60                State::Dense(s) => println!(
61                    "Simplex finished with status {:?} in {} iterations. Objective value: {}",
62                    s.status, s.iteration, obj
63                ),
64                State::Sparse(s) => println!(
65                    "Simplex finished with status {:?} in {} iterations. Objective value: {}",
66                    s.status, s.iteration, obj
67                ),
68            }
69        }
70
71        let status = match &self.state {
72            State::Dense(s) => s.status.clone(),
73            State::Sparse(s) => s.status.clone(),
74        };
75
76        Ok(Solution { values, objective_value: Some(obj), status })
77    }
78
79    fn get_objective_value(&self) -> f64 {
80        match &self.state {
81            State::Dense(s) => s.objective,
82            State::Sparse(s) => s.objective,
83        }
84    }
85
86    fn get_solution(&self) -> Vec<f64> {
87        // TODO: reconstruct full solution vector from x_b and non-basic vars
88        vec![]
89    }
90}
91
/// Simplex state specialised by matrix backend.
///
/// Lets `PrimalSimplexSolver` hold either representation behind one field.
// `dead_code` is allowed because, in the code visible here, only the `Dense`
// variant is ever constructed (see `PrimalSimplexSolver::new`).
#[allow(dead_code)]
enum State<'model> {
    /// State whose matrices are `DenseMatrix` values.
    Dense(PrimalSimplexState<'model, DenseMatrix>),
    /// State whose matrices are `SparseMatrix` values.
    Sparse(PrimalSimplexState<'model, SparseMatrix>),
}
97
/// Internal state for the simplex algorithm.
///
/// Tracks the current basis, non-basis variables, solution vector, objective value,
/// and the LP tableau.
///
/// The objective vector `c` is stored so that the algorithm always maximises:
/// `new` negates the coefficients of a minimisation objective and records that
/// fact in `minimise`, and `extract_solution` flips the sign back.
#[derive(Clone)]
pub struct PrimalSimplexState<'model, A: Matrix> {
    /// Reference to the LP model being solved.
    pub model: &'model Model,
    /// Current iteration count of the simplex algorithm.
    ///
    /// Not reset between phases: `run_simplex` resumes counting from this value.
    pub iteration: usize,

    /// Indices of basis variables in the tableau.
    ///
    /// Entry `i` is the column of `a` that is basic in row `i`. Empty until
    /// `init_basis` (or `setup_phase1`) has run.
    pub basis: Vec<usize>,
    /// Indices of non-basis variables in the tableau.
    pub non_basis: Vec<usize>,
    /// Values of the basic variables.
    ///
    /// `x_b[i]` is the value of the variable `basis[i]`.
    pub x_b: Vec<f64>,

    /// Constraint matrix `A`.
    ///
    /// Temporarily replaced by an augmented matrix while phase 1 runs.
    pub a: A,
    /// Right-hand side vector `b`.
    pub b: Vec<f64>,
    /// Objective coefficients vector `c`.
    ///
    /// Temporarily replaced by the phase-1 objective while phase 1 runs.
    pub c: Vec<f64>,

    /// Current objective value.
    ///
    /// Computed from `c`, so for a minimisation problem this is the negated value.
    pub objective: f64,
    /// Solution status after solving (Optimal, Infeasible, Unbounded, etc.).
    pub status: SolveStatus,

    /// Whether the LP is a minimization problem.
    minimise: bool,

    /// Whether to log iteration details during the simplex algorithm.
    logging: bool,

    /// Interval at which to log iteration details if logging is enabled.
    log_interval: usize,
}
137
138impl<'model, A: Matrix> PrimalSimplexState<'model, A> {
139    /// Initialize a new simplex state from a given `Model`.
140    ///
141    /// Constructs the tableau, sets up artificial variables for inequalities, and
142    /// computes the objective coefficients based on the problem's sense (min/max).
143    pub fn new(model: &'model Model) -> Self {
144        let n_vars = model.vars().len();
145        let n_cons = model.constraints().len();
146
147        let mut b = vec![0.0; n_cons];
148
149        let mut n_total = n_vars;
150        for cons in model.constraints().iter() {
151            match cons.cmp {
152                Cmp::Leq | Cmp::Geq => n_total += 1,
153                Cmp::Eq => {}
154            }
155        }
156
157        let mut a = A::new(n_cons, n_total);
158        let mut c = vec![0.0; n_total];
159
160        let minimise =
161            model.objective().map(|o| o.sense == Sense::Minimize).unwrap_or(false);
162
163        if let Some(obj) = model.objective() {
164            for term in &obj.expr.terms {
165                c[term.var.0] = match obj.sense {
166                    Sense::Maximize => term.coeff,
167                    Sense::Minimize => -term.coeff,
168                };
169            }
170        }
171
172        let mut extra_idx = n_vars;
173        for (i, cons) in model.constraints().iter().enumerate() {
174            b[i] = cons.rhs;
175            for term in &cons.expr.terms {
176                a.set(i, term.var.0, term.coeff);
177            }
178            match cons.cmp {
179                Cmp::Leq => {
180                    a.set(i, extra_idx, 1.0);
181                    extra_idx += 1;
182                }
183                Cmp::Geq => {
184                    a.set(i, extra_idx, -1.0);
185                    extra_idx += 1;
186                }
187                Cmp::Eq => {}
188            }
189        }
190
191        Self {
192            model,
193            iteration: 0,
194            basis: Vec::new(),
195            non_basis: (0..n_vars).collect(),
196            x_b: vec![0.0; n_cons],
197            a,
198            b,
199            c,
200            objective: 0.0,
201            status: SolveStatus::NotSolved,
202            minimise,
203            // FIXME: For now always enable logging
204            logging: true,
205            log_interval: 100,
206        }
207    }
208
209    /// Solve the LP using the simplex method.
210    ///
211    /// Performs a two-phase simplex if necessary (phase 1 for feasibility, phase 2 for optimality).
212    ///
213    /// Returns the solution vector and the objective value.
214    pub fn solve_lp(
215        &mut self,
216        max_iter: usize,
217        tol: f64,
218    ) -> Result<(Vec<f64>, f64), SolveError> {
219        self.init_basis();
220        let orig_n = self.a.cols();
221
222        if self.try_phase2(max_iter, tol)? {
223            return Ok(self.extract_solution(orig_n));
224        }
225
226        self.phase1(orig_n, max_iter, tol)?;
227        self.phase2(max_iter, tol)?;
228
229        Ok(self.extract_solution(orig_n))
230    }
231
232    /// Attempt to directly run phase 2 if the initial basis is feasible.
233    fn try_phase2(&mut self, max_iter: usize, tol: f64) -> Result<bool, SolveError> {
234        let mut bmat = self.build_bmat();
235        match self.compute_basic_solution(&mut bmat) {
236            Ok(xb) if xb.iter().all(|&v| v >= -tol) => {
237                self.x_b = xb;
238                self.remove_artificial_from_basis(&mut bmat, self.a.cols())
239                    .map_err(SolveError::InvalidModel)?;
240                self.run_simplex(&mut bmat, max_iter, tol)?;
241                Ok(true)
242            }
243            _ => Ok(false),
244        }
245    }
246
247    /// Phase 1 of the two-phase simplex method to remove artificial variables.
248    fn phase1(
249        &mut self,
250        orig_n: usize,
251        max_iter: usize,
252        tol: f64,
253    ) -> Result<(), SolveError> {
254        let (orig_a, orig_c, mut bmat) = self.setup_phase1(orig_n);
255        self.run_simplex(&mut bmat, max_iter, tol)?;
256
257        let sum_art: f64 = self
258            .basis
259            .iter()
260            .enumerate()
261            .map(|(i, &v)| self.c[v] * self.x_b[i])
262            .sum::<f64>()
263            .neg();
264
265        if sum_art > tol {
266            self.status = SolveStatus::Infeasible;
267            return Ok(());
268        }
269
270        self.remove_artificial_from_basis(&mut bmat, orig_n)
271            .map_err(SolveError::InvalidModel)?;
272
273        self.a = orig_a;
274        self.c = orig_c;
275        let mut used = vec![false; orig_n];
276        for &b in &self.basis {
277            if b < orig_n {
278                used[b] = true;
279            }
280        }
281        self.non_basis = (0..orig_n).filter(|&j| !used[j]).collect();
282        Ok(())
283    }
284
285    /// Phase 2 of the simplex method to optimize the LP.
286    fn phase2(&mut self, max_iter: usize, tol: f64) -> Result<(), SolveError> {
287        let mut bmat = self.build_bmat();
288        self.run_simplex(&mut bmat, max_iter, tol)
289    }
290
291    /// Initialize the basis using slack, surplus, and identity columns.
292    pub fn init_basis(&mut self) {
293        let m = self.a.rows();
294        let n = self.a.cols();
295
296        let mut basis = vec![None; m];
297        let mut used = vec![false; n];
298
299        for (j, used_j) in used.iter_mut().enumerate().take(n) {
300            let mut one_row = None;
301            let mut ok = true;
302            for i in 0..m {
303                let v = self.a.get(i, j);
304                if v.abs() > 1e-12 {
305                    if (v - 1.0).abs() < 1e-12 {
306                        if one_row.is_some() {
307                            ok = false;
308                            break;
309                        }
310                        one_row = Some(i);
311                    } else {
312                        ok = false;
313                        break;
314                    }
315                }
316            }
317            if ok && one_row.is_some_and(|r| basis[r].is_none()) {
318                let r = one_row.unwrap();
319                basis[r] = Some(j);
320                *used_j = true;
321            }
322        }
323
324        if basis.iter().all(|b| b.is_some()) {
325            self.basis = basis.into_iter().map(|b| b.unwrap()).collect();
326            self.non_basis = (0..n).filter(|j| !used[*j]).collect();
327        } else {
328            self.basis = (0..m).collect();
329            self.non_basis = (m..n).collect();
330        }
331    }
332
333    /// Build the current basis matrix `B` from the full tableau `A`.
334    pub fn build_bmat(&self) -> A {
335        let m = self.a.rows();
336        let mut bmat = A::new(m, m);
337        for i in 0..m {
338            for j in 0..m {
339                bmat.set(i, j, self.a.get(i, self.basis[j]));
340            }
341        }
342        bmat
343    }
344
345    /// Compute the values of the basic variables by solving `B x_B = b`.
346    pub fn compute_basic_solution(&self, bmat: &mut A) -> Result<Vec<f64>, String> {
347        let mut xb = self.b.clone();
348        bmat.mldivide(&mut xb).map_err(|e| format!("gauss failed: {e}"))?;
349        Ok(xb)
350    }
351
352    /// Run the main simplex iteration loop.
353    fn run_simplex(
354        &mut self,
355        bmat: &mut A,
356        max_iter: usize,
357        tol: f64,
358    ) -> Result<(), SolveError> {
359        let current_iter = self.iteration;
360        for iter in current_iter..max_iter {
361            self.iteration = iter;
362
363            let pi = self.compute_duals(bmat)?;
364            let Some((nb_pos, entering)) = self.choose_entering(&pi, tol) else {
365                self.status = SolveStatus::Optimal;
366                return Ok(());
367            };
368
369            let d = self.compute_direction(bmat, entering)?;
370            let Some((leave_row, theta)) = self.choose_leaving(&d, tol) else {
371                self.status = SolveStatus::Unbounded;
372                return Ok(());
373            };
374
375            self.update_primal(&d, leave_row, theta);
376            self.pivot(bmat, nb_pos, leave_row, entering);
377            self.update_objective();
378
379            if self.logging && (iter + 1) % self.log_interval == 0 {
380                println!(
381                    "Iteration {:>4}: Objective = {:>12.6}",
382                    iter + 1,
383                    if self.minimise { -self.objective } else { self.objective }
384                );
385            }
386        }
387
388        Err(SolveError::Other("max iterations reached".into()))
389    }
390
391    /// Compute dual variables for the current basis.
392    fn compute_duals(&self, bmat: &A) -> Result<Vec<f64>, SolveError> {
393        let m = bmat.rows();
394        let mut pi = (0..m).map(|i| self.c[self.basis[i]]).collect::<Vec<_>>();
395
396        let mut bt = A::new(m, m);
397        for i in 0..m {
398            for j in 0..m {
399                bt.set(i, j, bmat.get(j, i));
400            }
401        }
402
403        bt.mldivide(&mut pi)
404            .map_err(|e| SolveError::Other(format!("dual solve failed: {e}")))?;
405
406        Ok(pi)
407    }
408
409    /// Choose entering variable using reduced costs.
410    fn choose_entering(&self, pi: &[f64], tol: f64) -> Option<(usize, usize)> {
411        self.non_basis
412            .iter()
413            .enumerate()
414            .filter_map(|(pos, &j)| {
415                let rc = self.c[j]
416                    - (0..pi.len()).map(|i| pi[i] * self.a.get(i, j)).sum::<f64>();
417                (rc > tol).then_some((pos, j, rc))
418            })
419            .max_by(|a, b| a.2.partial_cmp(&b.2).unwrap())
420            .map(|(pos, j, _)| (pos, j))
421    }
422
423    /// Compute the simplex direction `d = B^{-1} A_j`.
424    fn compute_direction(
425        &self,
426        bmat: &mut A,
427        entering: usize,
428    ) -> Result<Vec<f64>, SolveError> {
429        let mut d = (0..bmat.rows()).map(|i| self.a.get(i, entering)).collect::<Vec<_>>();
430
431        bmat.mldivide(&mut d)
432            .map_err(|e| SolveError::Other(format!("direction solve failed: {e}")))?;
433
434        Ok(d)
435    }
436
437    /// Choose leaving variable using minimum ratio test.
438    fn choose_leaving(&self, d: &[f64], tol: f64) -> Option<(usize, f64)> {
439        (0..d.len())
440            .filter(|&i| d[i] > tol)
441            .map(|i| (i, self.x_b[i] / d[i]))
442            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
443    }
444
445    /// Update the primal solution vector `x_B` after a pivot.
446    fn update_primal(&mut self, d: &[f64], leave: usize, theta: f64) {
447        for (xi, di) in self.x_b.iter_mut().zip(d.iter()) {
448            *xi -= theta * di;
449            if (*xi).abs() < 1e-12 {
450                *xi = 0.0;
451            }
452        }
453        self.x_b[leave] = theta;
454    }
455
456    /// Perform pivot operations on the basis and non-basis sets.
457    fn pivot(
458        &mut self,
459        bmat: &mut A,
460        enter_pos: usize,
461        leave_row: usize,
462        entering: usize,
463    ) {
464        let leaving = self.basis[leave_row];
465        self.basis[leave_row] = entering;
466        self.non_basis[enter_pos] = leaving;
467
468        for i in 0..bmat.rows() {
469            bmat.set(i, leave_row, self.a.get(i, entering));
470        }
471    }
472
473    /// Update the current objective value.
474    fn update_objective(&mut self) {
475        self.objective = self
476            .basis
477            .iter()
478            .enumerate()
479            .map(|(i, &v)| self.c[v] * self.x_b[i])
480            .sum();
481    }
482
483    /// Prepare the LP for phase 1 of two-phase simplex by adding artificial variables.
484    pub fn setup_phase1(&mut self, orig_n: usize) -> (A, Vec<f64>, A) {
485        let m = self.a.rows();
486        let n = self.a.cols();
487
488        let mut a_aug = A::new(m, n + m);
489        let mut b_aug = self.b.clone();
490
491        for (i, bval) in b_aug.iter_mut().enumerate().take(m) {
492            if *bval < 0.0 {
493                *bval = -*bval;
494                for j in 0..n {
495                    a_aug.set(i, j, -self.a.get(i, j));
496                }
497            } else {
498                for j in 0..n {
499                    a_aug.set(i, j, self.a.get(i, j));
500                }
501            }
502
503            for j in 0..m {
504                a_aug.set(i, n + j, if i == j { 1.0 } else { 0.0 });
505            }
506        }
507
508        let mut c_aug = vec![0.0; n + m];
509        for j in 0..m {
510            c_aug[n + j] = -1.0;
511        }
512
513        let orig_a = self.a.clone();
514        let orig_c = self.c.clone();
515
516        self.a = a_aug;
517        self.c = c_aug;
518        self.basis = (orig_n..orig_n + m).collect();
519        self.non_basis = (0..orig_n).collect();
520        self.x_b = b_aug;
521
522        let mut bmat = A::new(m, m);
523        for i in 0..m {
524            for j in 0..m {
525                bmat.set(i, j, self.a.get(i, self.basis[j]));
526            }
527        }
528
529        (orig_a, orig_c, bmat)
530    }
531
532    /// Remove artificial variables from the basis once feasibility is established.
533    pub fn remove_artificial_from_basis(
534        &mut self,
535        bmat: &mut A,
536        orig_n: usize,
537    ) -> Result<(), String> {
538        let m = bmat.rows();
539        for row in 0..m {
540            if self.basis[row] >= orig_n {
541                let mut pivot = None;
542                for (nb_pos, &j) in self.non_basis.iter().enumerate() {
543                    if j < orig_n && self.a.get(row, j).abs() > 1e-12 {
544                        pivot = Some((nb_pos, j));
545                        break;
546                    }
547                }
548
549                if let Some((nb_pos, j)) = pivot {
550                    let leaving = self.basis[row];
551                    self.basis[row] = j;
552                    self.non_basis[nb_pos] = leaving;
553                    for i in 0..m {
554                        bmat.set(i, row, self.a.get(i, j));
555                    }
556                } else if self.x_b[row].abs() > 1e-12 {
557                    return Err(
558                        "artificial variable left in basis with non-zero value".into()
559                    );
560                } else {
561                    for (nb_pos, &j) in self.non_basis.iter().enumerate() {
562                        if j < orig_n && self.a.get(row, j).abs() < 1e-12 {
563                            let leaving = self.basis[row];
564                            self.basis[row] = j;
565                            self.non_basis[nb_pos] = leaving;
566                            for i in 0..m {
567                                bmat.set(i, row, self.a.get(i, j));
568                            }
569                            break;
570                        }
571                    }
572                }
573            }
574        }
575        Ok(())
576    }
577
578    /// Extract the final solution and objective value.
579    pub fn extract_solution(&self, orig_n: usize) -> (Vec<f64>, f64) {
580        let m = self.a.rows();
581        let mut sol = vec![0.0; orig_n];
582
583        for i in 0..m {
584            if self.basis[i] < orig_n {
585                sol[self.basis[i]] = self.x_b[i];
586            }
587        }
588
589        let mut obj = self
590            .basis
591            .iter()
592            .enumerate()
593            .filter(|(_, v)| **v < orig_n)
594            .map(|(i, v)| self.c[*v] * self.x_b[i])
595            .sum::<f64>();
596
597        if self.minimise {
598            obj = -obj;
599        }
600
601        (sol, obj)
602    }
603}