1use std::ops::Neg;
2
3use cnvx_core::*;
5use cnvx_math::{DenseMatrix, Matrix, matrix::SparseMatrix};
6
7pub struct PrimalSimplexSolver<'model> {
25 state: State<'model>,
27 pub tolerance: f64,
29 pub max_iter: usize,
31 pub logging: bool,
33}
34
35impl<'model> Solver<'model> for PrimalSimplexSolver<'model> {
36 fn new(model: &'model Model) -> Self {
37 Self {
38 state: State::Dense(PrimalSimplexState::new(model)),
39 tolerance: 1e-8,
40 max_iter: 1000,
41 logging: false,
42 }
43 }
44
45 fn solve(&mut self) -> Result<Solution, SolveError> {
46 match &self.state {
48 State::Dense(s) => crate::validate::check_lp(s.model)?,
49 State::Sparse(s) => crate::validate::check_lp(s.model)?,
50 }
51
52 let (values, obj) = match &mut self.state {
54 State::Dense(s) => s.solve_lp(self.max_iter, self.tolerance)?,
55 State::Sparse(s) => s.solve_lp(self.max_iter, self.tolerance)?,
56 };
57
58 if self.logging {
59 match &self.state {
60 State::Dense(s) => println!(
61 "Simplex finished with status {:?} in {} iterations. Objective value: {}",
62 s.status, s.iteration, obj
63 ),
64 State::Sparse(s) => println!(
65 "Simplex finished with status {:?} in {} iterations. Objective value: {}",
66 s.status, s.iteration, obj
67 ),
68 }
69 }
70
71 let status = match &self.state {
72 State::Dense(s) => s.status.clone(),
73 State::Sparse(s) => s.status.clone(),
74 };
75
76 Ok(Solution { values, objective_value: Some(obj), status })
77 }
78
79 fn get_objective_value(&self) -> f64 {
80 match &self.state {
81 State::Dense(s) => s.objective,
82 State::Sparse(s) => s.objective,
83 }
84 }
85
86 fn get_solution(&self) -> Vec<f64> {
87 vec![]
89 }
90}
91
92#[allow(dead_code)]
93enum State<'model> {
94 Dense(PrimalSimplexState<'model, DenseMatrix>),
95 Sparse(PrimalSimplexState<'model, SparseMatrix>),
96}
97
98#[derive(Clone)]
103pub struct PrimalSimplexState<'model, A: Matrix> {
104 pub model: &'model Model,
106 pub iteration: usize,
108
109 pub basis: Vec<usize>,
111 pub non_basis: Vec<usize>,
113 pub x_b: Vec<f64>,
115
116 pub a: A,
118 pub b: Vec<f64>,
120 pub c: Vec<f64>,
122
123 pub objective: f64,
125 pub status: SolveStatus,
127
128 minimise: bool,
130
131 logging: bool,
133
134 log_interval: usize,
136}
137
138impl<'model, A: Matrix> PrimalSimplexState<'model, A> {
139 pub fn new(model: &'model Model) -> Self {
144 let n_vars = model.vars().len();
145 let n_cons = model.constraints().len();
146
147 let mut b = vec![0.0; n_cons];
148
149 let mut n_total = n_vars;
150 for cons in model.constraints().iter() {
151 match cons.cmp {
152 Cmp::Leq | Cmp::Geq => n_total += 1,
153 Cmp::Eq => {}
154 }
155 }
156
157 let mut a = A::new(n_cons, n_total);
158 let mut c = vec![0.0; n_total];
159
160 let minimise =
161 model.objective().map(|o| o.sense == Sense::Minimize).unwrap_or(false);
162
163 if let Some(obj) = model.objective() {
164 for term in &obj.expr.terms {
165 c[term.var.0] = match obj.sense {
166 Sense::Maximize => term.coeff,
167 Sense::Minimize => -term.coeff,
168 };
169 }
170 }
171
172 let mut extra_idx = n_vars;
173 for (i, cons) in model.constraints().iter().enumerate() {
174 b[i] = cons.rhs;
175 for term in &cons.expr.terms {
176 a.set(i, term.var.0, term.coeff);
177 }
178 match cons.cmp {
179 Cmp::Leq => {
180 a.set(i, extra_idx, 1.0);
181 extra_idx += 1;
182 }
183 Cmp::Geq => {
184 a.set(i, extra_idx, -1.0);
185 extra_idx += 1;
186 }
187 Cmp::Eq => {}
188 }
189 }
190
191 Self {
192 model,
193 iteration: 0,
194 basis: Vec::new(),
195 non_basis: (0..n_vars).collect(),
196 x_b: vec![0.0; n_cons],
197 a,
198 b,
199 c,
200 objective: 0.0,
201 status: SolveStatus::NotSolved,
202 minimise,
203 logging: true,
205 log_interval: 100,
206 }
207 }
208
209 pub fn solve_lp(
215 &mut self,
216 max_iter: usize,
217 tol: f64,
218 ) -> Result<(Vec<f64>, f64), SolveError> {
219 self.init_basis();
220 let orig_n = self.a.cols();
221
222 if self.try_phase2(max_iter, tol)? {
223 return Ok(self.extract_solution(orig_n));
224 }
225
226 self.phase1(orig_n, max_iter, tol)?;
227 self.phase2(max_iter, tol)?;
228
229 Ok(self.extract_solution(orig_n))
230 }
231
232 fn try_phase2(&mut self, max_iter: usize, tol: f64) -> Result<bool, SolveError> {
234 let mut bmat = self.build_bmat();
235 match self.compute_basic_solution(&mut bmat) {
236 Ok(xb) if xb.iter().all(|&v| v >= -tol) => {
237 self.x_b = xb;
238 self.remove_artificial_from_basis(&mut bmat, self.a.cols())
239 .map_err(SolveError::InvalidModel)?;
240 self.run_simplex(&mut bmat, max_iter, tol)?;
241 Ok(true)
242 }
243 _ => Ok(false),
244 }
245 }
246
247 fn phase1(
249 &mut self,
250 orig_n: usize,
251 max_iter: usize,
252 tol: f64,
253 ) -> Result<(), SolveError> {
254 let (orig_a, orig_c, mut bmat) = self.setup_phase1(orig_n);
255 self.run_simplex(&mut bmat, max_iter, tol)?;
256
257 let sum_art: f64 = self
258 .basis
259 .iter()
260 .enumerate()
261 .map(|(i, &v)| self.c[v] * self.x_b[i])
262 .sum::<f64>()
263 .neg();
264
265 if sum_art > tol {
266 self.status = SolveStatus::Infeasible;
267 return Ok(());
268 }
269
270 self.remove_artificial_from_basis(&mut bmat, orig_n)
271 .map_err(SolveError::InvalidModel)?;
272
273 self.a = orig_a;
274 self.c = orig_c;
275 let mut used = vec![false; orig_n];
276 for &b in &self.basis {
277 if b < orig_n {
278 used[b] = true;
279 }
280 }
281 self.non_basis = (0..orig_n).filter(|&j| !used[j]).collect();
282 Ok(())
283 }
284
285 fn phase2(&mut self, max_iter: usize, tol: f64) -> Result<(), SolveError> {
287 let mut bmat = self.build_bmat();
288 self.run_simplex(&mut bmat, max_iter, tol)
289 }
290
291 pub fn init_basis(&mut self) {
293 let m = self.a.rows();
294 let n = self.a.cols();
295
296 let mut basis = vec![None; m];
297 let mut used = vec![false; n];
298
299 for (j, used_j) in used.iter_mut().enumerate().take(n) {
300 let mut one_row = None;
301 let mut ok = true;
302 for i in 0..m {
303 let v = self.a.get(i, j);
304 if v.abs() > 1e-12 {
305 if (v - 1.0).abs() < 1e-12 {
306 if one_row.is_some() {
307 ok = false;
308 break;
309 }
310 one_row = Some(i);
311 } else {
312 ok = false;
313 break;
314 }
315 }
316 }
317 if ok && one_row.is_some_and(|r| basis[r].is_none()) {
318 let r = one_row.unwrap();
319 basis[r] = Some(j);
320 *used_j = true;
321 }
322 }
323
324 if basis.iter().all(|b| b.is_some()) {
325 self.basis = basis.into_iter().map(|b| b.unwrap()).collect();
326 self.non_basis = (0..n).filter(|j| !used[*j]).collect();
327 } else {
328 self.basis = (0..m).collect();
329 self.non_basis = (m..n).collect();
330 }
331 }
332
333 pub fn build_bmat(&self) -> A {
335 let m = self.a.rows();
336 let mut bmat = A::new(m, m);
337 for i in 0..m {
338 for j in 0..m {
339 bmat.set(i, j, self.a.get(i, self.basis[j]));
340 }
341 }
342 bmat
343 }
344
345 pub fn compute_basic_solution(&self, bmat: &mut A) -> Result<Vec<f64>, String> {
347 let mut xb = self.b.clone();
348 bmat.mldivide(&mut xb).map_err(|e| format!("gauss failed: {e}"))?;
349 Ok(xb)
350 }
351
352 fn run_simplex(
354 &mut self,
355 bmat: &mut A,
356 max_iter: usize,
357 tol: f64,
358 ) -> Result<(), SolveError> {
359 let current_iter = self.iteration;
360 for iter in current_iter..max_iter {
361 self.iteration = iter;
362
363 let pi = self.compute_duals(bmat)?;
364 let Some((nb_pos, entering)) = self.choose_entering(&pi, tol) else {
365 self.status = SolveStatus::Optimal;
366 return Ok(());
367 };
368
369 let d = self.compute_direction(bmat, entering)?;
370 let Some((leave_row, theta)) = self.choose_leaving(&d, tol) else {
371 self.status = SolveStatus::Unbounded;
372 return Ok(());
373 };
374
375 self.update_primal(&d, leave_row, theta);
376 self.pivot(bmat, nb_pos, leave_row, entering);
377 self.update_objective();
378
379 if self.logging && (iter + 1) % self.log_interval == 0 {
380 println!(
381 "Iteration {:>4}: Objective = {:>12.6}",
382 iter + 1,
383 if self.minimise { -self.objective } else { self.objective }
384 );
385 }
386 }
387
388 Err(SolveError::Other("max iterations reached".into()))
389 }
390
391 fn compute_duals(&self, bmat: &A) -> Result<Vec<f64>, SolveError> {
393 let m = bmat.rows();
394 let mut pi = (0..m).map(|i| self.c[self.basis[i]]).collect::<Vec<_>>();
395
396 let mut bt = A::new(m, m);
397 for i in 0..m {
398 for j in 0..m {
399 bt.set(i, j, bmat.get(j, i));
400 }
401 }
402
403 bt.mldivide(&mut pi)
404 .map_err(|e| SolveError::Other(format!("dual solve failed: {e}")))?;
405
406 Ok(pi)
407 }
408
409 fn choose_entering(&self, pi: &[f64], tol: f64) -> Option<(usize, usize)> {
411 self.non_basis
412 .iter()
413 .enumerate()
414 .filter_map(|(pos, &j)| {
415 let rc = self.c[j]
416 - (0..pi.len()).map(|i| pi[i] * self.a.get(i, j)).sum::<f64>();
417 (rc > tol).then_some((pos, j, rc))
418 })
419 .max_by(|a, b| a.2.partial_cmp(&b.2).unwrap())
420 .map(|(pos, j, _)| (pos, j))
421 }
422
423 fn compute_direction(
425 &self,
426 bmat: &mut A,
427 entering: usize,
428 ) -> Result<Vec<f64>, SolveError> {
429 let mut d = (0..bmat.rows()).map(|i| self.a.get(i, entering)).collect::<Vec<_>>();
430
431 bmat.mldivide(&mut d)
432 .map_err(|e| SolveError::Other(format!("direction solve failed: {e}")))?;
433
434 Ok(d)
435 }
436
437 fn choose_leaving(&self, d: &[f64], tol: f64) -> Option<(usize, f64)> {
439 (0..d.len())
440 .filter(|&i| d[i] > tol)
441 .map(|i| (i, self.x_b[i] / d[i]))
442 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
443 }
444
445 fn update_primal(&mut self, d: &[f64], leave: usize, theta: f64) {
447 for (xi, di) in self.x_b.iter_mut().zip(d.iter()) {
448 *xi -= theta * di;
449 if (*xi).abs() < 1e-12 {
450 *xi = 0.0;
451 }
452 }
453 self.x_b[leave] = theta;
454 }
455
456 fn pivot(
458 &mut self,
459 bmat: &mut A,
460 enter_pos: usize,
461 leave_row: usize,
462 entering: usize,
463 ) {
464 let leaving = self.basis[leave_row];
465 self.basis[leave_row] = entering;
466 self.non_basis[enter_pos] = leaving;
467
468 for i in 0..bmat.rows() {
469 bmat.set(i, leave_row, self.a.get(i, entering));
470 }
471 }
472
473 fn update_objective(&mut self) {
475 self.objective = self
476 .basis
477 .iter()
478 .enumerate()
479 .map(|(i, &v)| self.c[v] * self.x_b[i])
480 .sum();
481 }
482
483 pub fn setup_phase1(&mut self, orig_n: usize) -> (A, Vec<f64>, A) {
485 let m = self.a.rows();
486 let n = self.a.cols();
487
488 let mut a_aug = A::new(m, n + m);
489 let mut b_aug = self.b.clone();
490
491 for (i, bval) in b_aug.iter_mut().enumerate().take(m) {
492 if *bval < 0.0 {
493 *bval = -*bval;
494 for j in 0..n {
495 a_aug.set(i, j, -self.a.get(i, j));
496 }
497 } else {
498 for j in 0..n {
499 a_aug.set(i, j, self.a.get(i, j));
500 }
501 }
502
503 for j in 0..m {
504 a_aug.set(i, n + j, if i == j { 1.0 } else { 0.0 });
505 }
506 }
507
508 let mut c_aug = vec![0.0; n + m];
509 for j in 0..m {
510 c_aug[n + j] = -1.0;
511 }
512
513 let orig_a = self.a.clone();
514 let orig_c = self.c.clone();
515
516 self.a = a_aug;
517 self.c = c_aug;
518 self.basis = (orig_n..orig_n + m).collect();
519 self.non_basis = (0..orig_n).collect();
520 self.x_b = b_aug;
521
522 let mut bmat = A::new(m, m);
523 for i in 0..m {
524 for j in 0..m {
525 bmat.set(i, j, self.a.get(i, self.basis[j]));
526 }
527 }
528
529 (orig_a, orig_c, bmat)
530 }
531
532 pub fn remove_artificial_from_basis(
534 &mut self,
535 bmat: &mut A,
536 orig_n: usize,
537 ) -> Result<(), String> {
538 let m = bmat.rows();
539 for row in 0..m {
540 if self.basis[row] >= orig_n {
541 let mut pivot = None;
542 for (nb_pos, &j) in self.non_basis.iter().enumerate() {
543 if j < orig_n && self.a.get(row, j).abs() > 1e-12 {
544 pivot = Some((nb_pos, j));
545 break;
546 }
547 }
548
549 if let Some((nb_pos, j)) = pivot {
550 let leaving = self.basis[row];
551 self.basis[row] = j;
552 self.non_basis[nb_pos] = leaving;
553 for i in 0..m {
554 bmat.set(i, row, self.a.get(i, j));
555 }
556 } else if self.x_b[row].abs() > 1e-12 {
557 return Err(
558 "artificial variable left in basis with non-zero value".into()
559 );
560 } else {
561 for (nb_pos, &j) in self.non_basis.iter().enumerate() {
562 if j < orig_n && self.a.get(row, j).abs() < 1e-12 {
563 let leaving = self.basis[row];
564 self.basis[row] = j;
565 self.non_basis[nb_pos] = leaving;
566 for i in 0..m {
567 bmat.set(i, row, self.a.get(i, j));
568 }
569 break;
570 }
571 }
572 }
573 }
574 }
575 Ok(())
576 }
577
578 pub fn extract_solution(&self, orig_n: usize) -> (Vec<f64>, f64) {
580 let m = self.a.rows();
581 let mut sol = vec![0.0; orig_n];
582
583 for i in 0..m {
584 if self.basis[i] < orig_n {
585 sol[self.basis[i]] = self.x_b[i];
586 }
587 }
588
589 let mut obj = self
590 .basis
591 .iter()
592 .enumerate()
593 .filter(|(_, v)| **v < orig_n)
594 .map(|(i, v)| self.c[*v] * self.x_b[i])
595 .sum::<f64>();
596
597 if self.minimise {
598 obj = -obj;
599 }
600
601 (sol, obj)
602 }
603}