Skip to main content

oxiz_math/
simplex.rs

1//! Simplex algorithm for linear arithmetic.
2//!
3//! This module implements the Simplex algorithm for solving linear programming
4//! problems and checking satisfiability of linear real arithmetic constraints.
5//!
6//! The implementation follows the two-phase simplex method:
7//! - Phase 1: Find a basic feasible solution (or prove infeasibility)
8//! - Phase 2: Optimize the objective function (or detect unboundedness)
9//!
10//! Reference: Z3's `math/simplex/` directory and standard LP textbooks.
11
12use crate::fast_rational::FastRational;
13#[allow(unused_imports)]
14use crate::prelude::*;
15use core::fmt;
16use num_bigint::BigInt;
17use num_traits::{One, Signed, Zero};
18
/// Identifier of a simplex variable (basic or non-basic).
///
/// `SimplexTableau::fresh_var` hands these out sequentially, starting at 0.
pub type VarId = u32;

/// Identifier of an externally supplied constraint; used to explain conflicts.
pub type ConstraintId = u32;
24
/// Kind of bound asserted on a single variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundType {
    /// Lower bound: x >= value
    Lower,
    /// Upper bound: x <= value
    Upper,
    /// Equality: x = value (acts as both a lower and an upper bound)
    Equal,
}
35
/// A bound `var (>=|<=|=) value` on a single variable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Bound {
    /// The variable this bound applies to.
    pub var: VarId,
    /// The type of bound (lower, upper, or equality).
    pub bound_type: BoundType,
    /// The bound value.
    pub value: FastRational,
}
46
47impl Bound {
48    /// Create a lower bound: var >= value.
49    pub fn lower(var: VarId, value: FastRational) -> Self {
50        Self {
51            var,
52            bound_type: BoundType::Lower,
53            value,
54        }
55    }
56
57    /// Create an upper bound: var <= value.
58    pub fn upper(var: VarId, value: FastRational) -> Self {
59        Self {
60            var,
61            bound_type: BoundType::Upper,
62            value,
63        }
64    }
65
66    /// Create an equality bound: var = value.
67    pub fn equal(var: VarId, value: FastRational) -> Self {
68        Self {
69            var,
70            bound_type: BoundType::Equal,
71            value,
72        }
73    }
74}
75
/// Role a variable currently plays in the tableau.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VarClass {
    /// Basic variable: owns a row and appears on its left-hand side.
    Basic,
    /// Non-basic variable: may only appear on the right-hand side of rows.
    NonBasic,
}
84
/// A row in the simplex tableau represents: basic_var = constant + sum(coeff_i * nonbasic_var_i).
///
/// Inside a `SimplexTableau`, the keys of `coeffs` are expected to be non-basic
/// variables only (`SimplexTableau::add_row` rejects basic ones; `Row::is_valid`
/// checks the same property).
#[derive(Debug, Clone)]
pub struct Row {
    /// The basic variable for this row (implicit coefficient 1 on the left-hand side).
    pub basic_var: VarId,
    /// The constant term.
    pub constant: FastRational,
    /// Coefficients for non-basic variables: var_id -> coefficient.
    /// Zero coefficients are normally removed by the constructors and `add_row`.
    pub coeffs: FxHashMap<VarId, FastRational>,
}
95
96impl Row {
97    /// Create a new row with a basic variable.
98    pub fn new(basic_var: VarId) -> Self {
99        Self {
100            basic_var,
101            constant: FastRational::zero(),
102            coeffs: FxHashMap::default(),
103        }
104    }
105
106    /// Create a row representing: basic_var = constant + sum(coeffs).
107    pub fn from_expr(
108        basic_var: VarId,
109        constant: FastRational,
110        coeffs: FxHashMap<VarId, FastRational>,
111    ) -> Self {
112        let mut row = Self {
113            basic_var,
114            constant,
115            coeffs: FxHashMap::default(),
116        };
117        for (var, coeff) in coeffs {
118            if !coeff.is_zero() {
119                row.coeffs.insert(var, coeff);
120            }
121        }
122        row
123    }
124
125    /// Get the value of the basic variable given values of non-basic variables.
126    pub fn eval(&self, non_basic_values: &FxHashMap<VarId, FastRational>) -> FastRational {
127        let mut value = self.constant.clone();
128        for (var, coeff) in &self.coeffs {
129            if let Some(val) = non_basic_values.get(var) {
130                value += coeff * val;
131            }
132        }
133        value
134    }
135
136    /// Add a multiple of another row to this row.
137    /// self += multiplier * other
138    pub fn add_row(&mut self, multiplier: &FastRational, other: &Row) {
139        if multiplier.is_zero() {
140            return;
141        }
142
143        self.constant += multiplier * &other.constant;
144
145        for (var, coeff) in &other.coeffs {
146            let new_coeff = self
147                .coeffs
148                .get(var)
149                .cloned()
150                .unwrap_or_else(FastRational::zero)
151                + multiplier * coeff;
152            if new_coeff.is_zero() {
153                self.coeffs.remove(var);
154            } else {
155                self.coeffs.insert(*var, new_coeff);
156            }
157        }
158    }
159
160    /// Multiply the row by a scalar.
161    pub fn scale(&mut self, scalar: &FastRational) {
162        if scalar.is_zero() {
163            self.constant = FastRational::zero();
164            self.coeffs.clear();
165            return;
166        }
167
168        self.constant *= scalar;
169        for coeff in self.coeffs.values_mut() {
170            *coeff *= scalar;
171        }
172    }
173
174    /// Substitute a non-basic variable using another row.
175    /// If var appears in this row, replace it using: var = row.constant + sum(row.coeffs).
176    pub fn substitute(&mut self, var: VarId, row: &Row) {
177        if let Some(coeff) = self.coeffs.remove(&var) {
178            // Add coeff * row to self
179            self.add_row(&coeff, row);
180        }
181    }
182
183    /// Check if the row is valid (no basic variable in RHS).
184    pub fn is_valid(&self, basic_vars: &FxHashSet<VarId>) -> bool {
185        for var in self.coeffs.keys() {
186            if basic_vars.contains(var) {
187                return false;
188            }
189        }
190        true
191    }
192
193    /// Normalize the row by dividing by the GCD of coefficients.
194    pub fn normalize(&mut self) {
195        if self.coeffs.is_empty() {
196            return;
197        }
198
199        // Compute GCD of all coefficients
200        let mut gcd: Option<BigInt> = None;
201        for coeff in self.coeffs.values() {
202            if !coeff.is_zero() {
203                let num = coeff.numer().clone();
204                gcd = Some(match gcd {
205                    None => num.abs(),
206                    Some(g) => gcd_bigint(g, num.abs()),
207                });
208            }
209        }
210
211        if let Some(g) = gcd
212            && !g.is_one()
213        {
214            let divisor = FastRational::from_integer(g);
215            self.scale(&(FastRational::one() / divisor));
216        }
217    }
218}
219
220impl fmt::Display for Row {
221    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222        write!(f, "x{} = {}", self.basic_var, self.constant)?;
223        for (var, coeff) in &self.coeffs {
224            if coeff.is_positive() {
225                write!(f, " + {}*x{}", coeff, var)?;
226            } else {
227                write!(f, " - {}*x{}", -coeff, var)?;
228            }
229        }
230        Ok(())
231    }
232}
233
234/// Compute GCD of two BigInts using Euclidean algorithm.
235fn gcd_bigint(mut a: BigInt, mut b: BigInt) -> BigInt {
236    while !b.is_zero() {
237        let t = &a % &b;
238        a = b;
239        b = t;
240    }
241    a.abs()
242}
243
/// Result of a simplex operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SimplexResult {
    /// The system is satisfiable.
    Sat,
    /// The system is unsatisfiable.
    Unsat,
    /// The objective is unbounded.
    Unbounded,
    /// Undecided: returned by `SimplexTableau::check` / `check_dual` when their
    /// iteration budget is exhausted before feasibility is settled.
    Unknown,
}
256
/// Explanation of an infeasibility found by the tableau.
#[derive(Debug, Clone)]
pub struct Conflict {
    /// Ids of the constraints involved in the conflict (may be empty when the
    /// conflict stems from an internal inconsistency rather than from bounds).
    pub constraints: Vec<ConstraintId>,
}
263
/// The Simplex tableau.
///
/// Each row expresses one basic variable as an affine combination of non-basic
/// variables (see [`Row`]). `basic_vars` and `non_basic_vars` split the known
/// variables, and `assignment` holds the current value of every variable; the
/// values of basic variables are recomputed from their rows.
#[derive(Debug, Clone)]
pub struct SimplexTableau {
    /// Rows of the tableau, keyed by their basic variable.
    rows: FxHashMap<VarId, Row>,
    /// Set of basic variables (those that own a row).
    basic_vars: FxHashSet<VarId>,
    /// Set of non-basic variables.
    non_basic_vars: FxHashSet<VarId>,
    /// Current assignment to all variables.
    assignment: FxHashMap<VarId, FastRational>,
    /// Lower bounds for variables (absent = unbounded below).
    lower_bounds: FxHashMap<VarId, FastRational>,
    /// Upper bounds for variables (absent = unbounded above).
    upper_bounds: FxHashMap<VarId, FastRational>,
    /// Constraint ids passed to `add_bound` for each variable, in call order;
    /// used to build conflict explanations.
    var_to_constraints: FxHashMap<VarId, Vec<ConstraintId>>,
    /// Id that the next call to `fresh_var` will return.
    next_var_id: VarId,
    /// Use Bland's rule (smallest index first) when selecting pivots, to prevent cycling.
    use_blands_rule: bool,
}
286
287impl SimplexTableau {
288    /// Create a new empty tableau.
289    pub fn new() -> Self {
290        Self {
291            rows: FxHashMap::default(),
292            basic_vars: FxHashSet::default(),
293            non_basic_vars: FxHashSet::default(),
294            assignment: FxHashMap::default(),
295            lower_bounds: FxHashMap::default(),
296            upper_bounds: FxHashMap::default(),
297            var_to_constraints: FxHashMap::default(),
298            next_var_id: 0,
299            use_blands_rule: true,
300        }
301    }
302
303    /// Create a fresh variable.
304    pub fn fresh_var(&mut self) -> VarId {
305        let id = self.next_var_id;
306        self.next_var_id += 1;
307        self.non_basic_vars.insert(id);
308        self.assignment.insert(id, FastRational::zero());
309        id
310    }
311
312    /// Add a variable with initial bounds.
313    pub fn add_var(&mut self, lower: Option<FastRational>, upper: Option<FastRational>) -> VarId {
314        let var = self.fresh_var();
315        if let Some(lb) = lower {
316            self.lower_bounds.insert(var, lb.clone());
317            self.assignment.insert(var, lb);
318        }
319        if let Some(ub) = upper {
320            self.upper_bounds.insert(var, ub);
321        }
322        var
323    }
324
325    /// Add a row to the tableau.
326    /// The row represents: basic_var = constant + sum(coeffs).
327    pub fn add_row(&mut self, row: Row) -> Result<(), String> {
328        let basic_var = row.basic_var;
329
330        // Ensure basic_var is not already basic
331        if self.basic_vars.contains(&basic_var) {
332            return Err(format!("Variable x{} is already basic", basic_var));
333        }
334
335        // Ensure all non-basic vars in coeffs are indeed non-basic
336        for var in row.coeffs.keys() {
337            if self.basic_vars.contains(var) {
338                return Err(format!("Variable x{} appears in RHS but is basic", var));
339            }
340            self.non_basic_vars.insert(*var);
341        }
342
343        // Move basic_var from non-basic to basic
344        self.non_basic_vars.remove(&basic_var);
345        self.basic_vars.insert(basic_var);
346
347        // Compute initial value for basic_var
348        let value = row.eval(&self.assignment);
349        self.assignment.insert(basic_var, value);
350
351        self.rows.insert(basic_var, row);
352        Ok(())
353    }
354
355    /// Add a bound constraint.
356    pub fn add_bound(
357        &mut self,
358        var: VarId,
359        bound_type: BoundType,
360        value: FastRational,
361        constraint_id: ConstraintId,
362    ) -> Result<(), Conflict> {
363        self.var_to_constraints
364            .entry(var)
365            .or_default()
366            .push(constraint_id);
367
368        match bound_type {
369            BoundType::Lower => {
370                let current_lb = self.lower_bounds.get(&var);
371                if let Some(lb) = current_lb {
372                    if &value > lb {
373                        self.lower_bounds.insert(var, value.clone());
374                    }
375                } else {
376                    self.lower_bounds.insert(var, value.clone());
377                }
378
379                // Check for immediate conflict
380                if let Some(ub) = self.upper_bounds.get(&var)
381                    && &value > ub
382                {
383                    return Err(Conflict {
384                        constraints: vec![constraint_id],
385                    });
386                }
387            }
388            BoundType::Upper => {
389                let current_ub = self.upper_bounds.get(&var);
390                if let Some(ub) = current_ub {
391                    if &value < ub {
392                        self.upper_bounds.insert(var, value.clone());
393                    }
394                } else {
395                    self.upper_bounds.insert(var, value.clone());
396                }
397
398                // Check for immediate conflict
399                if let Some(lb) = self.lower_bounds.get(&var)
400                    && &value < lb
401                {
402                    return Err(Conflict {
403                        constraints: vec![constraint_id],
404                    });
405                }
406            }
407            BoundType::Equal => {
408                self.lower_bounds.insert(var, value.clone());
409                self.upper_bounds.insert(var, value.clone());
410
411                // Check existing bounds
412                if let Some(lb) = self.lower_bounds.get(&var)
413                    && &value < lb
414                {
415                    return Err(Conflict {
416                        constraints: vec![constraint_id],
417                    });
418                }
419                if let Some(ub) = self.upper_bounds.get(&var)
420                    && &value > ub
421                {
422                    return Err(Conflict {
423                        constraints: vec![constraint_id],
424                    });
425                }
426            }
427        }
428
429        Ok(())
430    }
431
432    /// Get the current assignment to a variable.
433    pub fn get_value(&self, var: VarId) -> Option<&FastRational> {
434        self.assignment.get(&var)
435    }
436
437    /// Check if a basic variable violates its bounds.
438    fn violates_bounds(&self, var: VarId) -> bool {
439        if let Some(val) = self.assignment.get(&var) {
440            if let Some(lb) = self.lower_bounds.get(&var)
441                && val < lb
442            {
443                return true;
444            }
445            if let Some(ub) = self.upper_bounds.get(&var)
446                && val > ub
447            {
448                return true;
449            }
450        }
451        false
452    }
453
454    /// Find a basic variable that violates its bounds.
455    fn find_violating_basic_var(&self) -> Option<VarId> {
456        if self.use_blands_rule {
457            // Use Bland's rule: pick the smallest index
458            self.basic_vars
459                .iter()
460                .filter(|&&var| self.violates_bounds(var))
461                .min()
462                .copied()
463        } else {
464            self.basic_vars
465                .iter()
466                .find(|&&var| self.violates_bounds(var))
467                .copied()
468        }
469    }
470
471    /// Find a non-basic variable to pivot with.
472    /// Returns (non_basic_var, improving_direction).
473    fn find_pivot_non_basic(
474        &self,
475        basic_var: VarId,
476        target_increase: bool,
477    ) -> Option<(VarId, bool)> {
478        let row = self.rows.get(&basic_var)?;
479
480        let mut candidates = Vec::new();
481
482        for (nb_var, coeff) in &row.coeffs {
483            let current_val = self.assignment.get(nb_var)?;
484            let lb = self.lower_bounds.get(nb_var);
485            let ub = self.upper_bounds.get(nb_var);
486
487            // Determine if we can increase or decrease nb_var
488            let can_increase = ub.is_none_or(|bound| bound > current_val);
489            let can_decrease = lb.is_none_or(|bound| bound < current_val);
490
491            // If coeff > 0: increasing nb_var increases basic_var
492            // If coeff < 0: increasing nb_var decreases basic_var
493            let increases_basic = coeff.is_positive();
494
495            if target_increase {
496                // We want to increase basic_var
497                if increases_basic && can_increase {
498                    candidates.push((*nb_var, true));
499                } else if !increases_basic && can_decrease {
500                    candidates.push((*nb_var, false));
501                }
502            } else {
503                // We want to decrease basic_var
504                if increases_basic && can_decrease {
505                    candidates.push((*nb_var, false));
506                } else if !increases_basic && can_increase {
507                    candidates.push((*nb_var, true));
508                }
509            }
510        }
511
512        if candidates.is_empty() {
513            return None;
514        }
515
516        // Use Bland's rule: pick smallest index
517        if self.use_blands_rule {
518            candidates.sort_by_key(|(var, _)| *var);
519        }
520
521        Some(candidates[0])
522    }
523
524    /// Perform a pivot operation.
525    /// Swap basic_var (currently basic) with non_basic_var (currently non-basic).
526    pub fn pivot(&mut self, basic_var: VarId, non_basic_var: VarId) -> Result<(), String> {
527        // Get the row for basic_var
528        let row = self
529            .rows
530            .get(&basic_var)
531            .ok_or_else(|| format!("No row for basic variable x{}", basic_var))?;
532
533        // Get coefficient of non_basic_var in this row
534        let coeff = row
535            .coeffs
536            .get(&non_basic_var)
537            .ok_or_else(|| {
538                format!(
539                    "Non-basic variable x{} not in row for x{}",
540                    non_basic_var, basic_var
541                )
542            })?
543            .clone();
544
545        if coeff.is_zero() {
546            return Err(format!("Coefficient of x{} is zero", non_basic_var));
547        }
548
549        // Solve for non_basic_var in terms of basic_var
550        // Old: basic_var = constant + coeff * non_basic_var + sum(...)
551        // New: non_basic_var = (basic_var - constant - sum(...)) / coeff
552
553        let mut new_row = Row::new(non_basic_var);
554        new_row.constant = -&row.constant / &coeff;
555        new_row
556            .coeffs
557            .insert(basic_var, FastRational::one() / &coeff);
558
559        for (var, c) in &row.coeffs {
560            if var != &non_basic_var {
561                new_row.coeffs.insert(*var, -c / &coeff);
562            }
563        }
564
565        // Substitute non_basic_var in all other rows
566        let rows_to_update: Vec<VarId> = self
567            .rows
568            .keys()
569            .filter(|&&v| v != basic_var)
570            .copied()
571            .collect();
572
573        for row_var in rows_to_update {
574            if let Some(r) = self.rows.get_mut(&row_var) {
575                r.substitute(non_basic_var, &new_row);
576            }
577        }
578
579        // Remove old row and add new row
580        self.rows.remove(&basic_var);
581        self.rows.insert(non_basic_var, new_row);
582
583        // Update basic/non-basic sets
584        self.basic_vars.remove(&basic_var);
585        self.basic_vars.insert(non_basic_var);
586        self.non_basic_vars.remove(&non_basic_var);
587        self.non_basic_vars.insert(basic_var);
588
589        // Update assignments
590        self.update_assignment();
591
592        Ok(())
593    }
594
595    /// Update the assignment based on current tableau.
596    fn update_assignment(&mut self) {
597        // Evaluate all basic variables
598        for (basic_var, row) in &self.rows {
599            let value = row.eval(&self.assignment);
600            self.assignment.insert(*basic_var, value);
601        }
602    }
603
604    /// Check feasibility and fix violations using pivoting.
605    pub fn check(&mut self) -> Result<SimplexResult, Conflict> {
606        let max_iterations = 10000;
607        let mut iterations = 0;
608
609        while let Some(violating_var) = self.find_violating_basic_var() {
610            iterations += 1;
611            if iterations > max_iterations {
612                return Ok(SimplexResult::Unknown);
613            }
614
615            let current_val = self
616                .assignment
617                .get(&violating_var)
618                .cloned()
619                .ok_or_else(|| Conflict {
620                    constraints: vec![],
621                })?;
622
623            let lb = self.lower_bounds.get(&violating_var);
624            let ub = self.upper_bounds.get(&violating_var);
625
626            // Determine if we need to increase or decrease the variable
627            let need_increase = lb.is_some_and(|l| &current_val < l);
628            let need_decrease = ub.is_some_and(|u| &current_val > u);
629
630            if !need_increase && !need_decrease {
631                continue;
632            }
633
634            // Find a non-basic variable to pivot with
635            if let Some((nb_var, _direction)) =
636                self.find_pivot_non_basic(violating_var, need_increase)
637            {
638                // Compute the new value for nb_var
639                let target_value = if need_increase {
640                    lb.cloned().unwrap_or_else(FastRational::zero)
641                } else {
642                    ub.cloned().unwrap_or_else(FastRational::zero)
643                };
644
645                // Update nb_var to move basic_var to target
646                let row = self.rows.get(&violating_var).ok_or_else(|| Conflict {
647                    constraints: vec![],
648                })?;
649                let coeff = row.coeffs.get(&nb_var).cloned().ok_or_else(|| Conflict {
650                    constraints: vec![],
651                })?;
652
653                let delta = &target_value - &current_val;
654                let nb_delta = &delta / &coeff;
655                let current_nb = self
656                    .assignment
657                    .get(&nb_var)
658                    .cloned()
659                    .ok_or_else(|| Conflict {
660                        constraints: vec![],
661                    })?;
662                let new_nb = current_nb + nb_delta;
663
664                // Clamp to bounds
665                let new_nb = if let Some(lb) = self.lower_bounds.get(&nb_var) {
666                    new_nb.max(lb.clone())
667                } else {
668                    new_nb
669                };
670                let new_nb = if let Some(ub) = self.upper_bounds.get(&nb_var) {
671                    new_nb.min(ub.clone())
672                } else {
673                    new_nb
674                };
675
676                self.assignment.insert(nb_var, new_nb);
677                self.update_assignment();
678            } else {
679                // No pivot found - problem is infeasible or unbounded
680                let constraints = self
681                    .var_to_constraints
682                    .get(&violating_var)
683                    .cloned()
684                    .unwrap_or_default();
685                return Err(Conflict { constraints });
686            }
687        }
688
689        Ok(SimplexResult::Sat)
690    }
691
692    /// Dual simplex algorithm for Linear Programming.
693    ///
694    /// The dual simplex maintains dual feasibility (all reduced costs non-negative)
695    /// while working toward primal feasibility (all variables within bounds).
696    ///
697    /// This is particularly useful for:
698    /// - Reoptimization after adding constraints
699    /// - Branch-and-bound in integer programming
700    /// - Problems that naturally start dual-feasible
701    ///
702    /// Reference: Standard LP textbooks and Z3's dual simplex implementation.
703    pub fn check_dual(&mut self) -> Result<SimplexResult, Conflict> {
704        let max_iterations = 10000;
705        let mut iterations = 0;
706
707        // Dual simplex loop: restore primal feasibility
708        while let Some(leaving_var) = self.find_violating_basic_var() {
709            iterations += 1;
710            if iterations > max_iterations {
711                return Ok(SimplexResult::Unknown);
712            }
713
714            // Get the row for the leaving variable
715            let row = match self.rows.get(&leaving_var) {
716                Some(r) => r.clone(),
717                None => continue,
718            };
719
720            let current_val = self
721                .assignment
722                .get(&leaving_var)
723                .cloned()
724                .unwrap_or_else(FastRational::zero);
725
726            let lb = self.lower_bounds.get(&leaving_var);
727            let ub = self.upper_bounds.get(&leaving_var);
728
729            // Determine which bound is violated
730            let violated_lower = lb.is_some_and(|l| &current_val < l);
731            let violated_upper = ub.is_some_and(|u| &current_val > u);
732
733            if !violated_lower && !violated_upper {
734                continue;
735            }
736
737            // For dual simplex, find entering variable using dual pricing rule
738            let entering_var = self.find_entering_var_dual(&row, violated_lower)?;
739
740            // Pivot leaving_var out, entering_var in
741            self.pivot(leaving_var, entering_var)
742                .map_err(|_| Conflict {
743                    constraints: self
744                        .var_to_constraints
745                        .get(&leaving_var)
746                        .cloned()
747                        .unwrap_or_default(),
748                })?;
749        }
750
751        Ok(SimplexResult::Sat)
752    }
753
754    /// Find the entering variable for dual simplex using dual pricing rule.
755    ///
756    /// For dual simplex:
757    /// - If leaving var violates lower bound: choose entering var with negative coeff
758    /// - If leaving var violates upper bound: choose entering var with positive coeff
759    /// - Among candidates, choose one that maintains dual feasibility
760    fn find_entering_var_dual(&self, row: &Row, violated_lower: bool) -> Result<VarId, Conflict> {
761        let mut best_var = None;
762        let mut best_ratio: Option<FastRational> = None;
763
764        // Iterate through non-basic variables in the row
765        for (&nb_var, coeff) in &row.coeffs {
766            // For dual simplex pricing:
767            // - If violated_lower, we need coeff < 0 (to increase basic var)
768            // - If violated_upper, we need coeff > 0 (to decrease basic var)
769            let sign_ok = if violated_lower {
770                coeff.is_negative()
771            } else {
772                coeff.is_positive()
773            };
774
775            if !sign_ok || coeff.is_zero() {
776                continue;
777            }
778
779            // Compute the dual ratio for maintaining dual feasibility
780            // This is a simplified version; full implementation would compute reduced costs
781            let ratio = coeff.abs();
782
783            match &best_ratio {
784                None => {
785                    best_ratio = Some(ratio);
786                    best_var = Some(nb_var);
787                }
788                Some(current_best) => {
789                    // Choose the variable with smallest ratio (steepest descent in dual)
790                    // Use Bland's rule for tie-breaking if enabled
791                    if &ratio < current_best {
792                        best_ratio = Some(ratio);
793                        best_var = Some(nb_var);
794                    } else if self.use_blands_rule && &ratio == current_best {
795                        // Bland's rule: choose smaller index
796                        // best_var is guaranteed Some when best_ratio is Some
797                        if best_var.is_none_or(|bv| nb_var < bv) {
798                            best_var = Some(nb_var);
799                        }
800                    }
801                }
802            }
803        }
804
805        best_var.ok_or(Conflict {
806            constraints: vec![],
807        })
808    }
809
810    /// Get all variables.
811    pub fn vars(&self) -> Vec<VarId> {
812        let mut vars: Vec<VarId> = self
813            .basic_vars
814            .iter()
815            .chain(self.non_basic_vars.iter())
816            .copied()
817            .collect();
818        vars.sort_unstable();
819        vars
820    }
821
822    /// Get all basic variables.
823    pub fn basic_vars(&self) -> Vec<VarId> {
824        let mut vars: Vec<VarId> = self.basic_vars.iter().copied().collect();
825        vars.sort_unstable();
826        vars
827    }
828
829    /// Get all non-basic variables.
830    pub fn non_basic_vars(&self) -> Vec<VarId> {
831        let mut vars: Vec<VarId> = self.non_basic_vars.iter().copied().collect();
832        vars.sort_unstable();
833        vars
834    }
835
836    /// Get the number of rows in the tableau.
837    pub fn num_rows(&self) -> usize {
838        self.rows.len()
839    }
840
841    /// Get the number of variables.
842    pub fn num_vars(&self) -> usize {
843        self.basic_vars.len() + self.non_basic_vars.len()
844    }
845
846    /// Enable or disable Bland's anti-cycling rule.
847    pub fn set_blands_rule(&mut self, enable: bool) {
848        self.use_blands_rule = enable;
849    }
850
851    /// Get the current model (satisfying assignment).
852    /// Returns None if the system is not known to be satisfiable.
853    pub fn get_model(&self) -> Option<FxHashMap<VarId, FastRational>> {
854        // Check if all variables satisfy their bounds
855        for (var, val) in &self.assignment {
856            if let Some(lb) = self.lower_bounds.get(var)
857                && val < lb
858            {
859                return None;
860            }
861            if let Some(ub) = self.upper_bounds.get(var)
862                && val > ub
863            {
864                return None;
865            }
866        }
867        Some(self.assignment.clone())
868    }
869
870    /// Check if the current assignment satisfies all bounds.
871    pub fn is_feasible(&self) -> bool {
872        for (var, val) in &self.assignment {
873            if let Some(lb) = self.lower_bounds.get(var)
874                && val < lb
875            {
876                return false;
877            }
878            if let Some(ub) = self.upper_bounds.get(var)
879                && val > ub
880            {
881                return false;
882            }
883        }
884        true
885    }
886
887    /// Find a variable that violates its bounds, if any.
888    pub fn find_violated_bound(&self) -> Option<VarId> {
889        for (var, val) in &self.assignment {
890            if let Some(lb) = self.lower_bounds.get(var)
891                && val < lb
892            {
893                return Some(*var);
894            }
895            if let Some(ub) = self.upper_bounds.get(var)
896                && val > ub
897            {
898                return Some(*var);
899            }
900        }
901        None
902    }
903
904    /// Get all constraints associated with a variable.
905    pub fn get_constraints(&self, var: VarId) -> Vec<ConstraintId> {
906        self.var_to_constraints
907            .get(&var)
908            .cloned()
909            .unwrap_or_default()
910    }
911
912    /// Extract a minimal conflicting core from constraints.
913    /// This is a simple implementation that returns all constraints involved.
914    /// A more sophisticated version would compute a true minimal unsat core.
915    pub fn get_unsat_core(&self, conflict: &Conflict) -> Vec<ConstraintId> {
916        conflict.constraints.clone()
917    }
918
919    /// Theory propagation: deduce new bounds from existing constraints.
920    /// Returns a list of (var, bound_type, value) tuples representing deduced bounds.
921    pub fn propagate(&self) -> Vec<(VarId, BoundType, FastRational)> {
922        let mut propagated = Vec::new();
923
924        // For each row: basic_var = constant + sum(coeff * non_basic_var)
925        // We can deduce bounds on basic_var from bounds on non_basic vars
926        for row in self.rows.values() {
927            let basic_var = row.basic_var;
928
929            // Compute lower bound: min value of basic_var
930            // basic_var >= constant + sum(min(coeff * lb, coeff * ub))
931            let mut lower_bound = row.constant.clone();
932            let mut has_finite_lower = true;
933
934            for (var, coeff) in &row.coeffs {
935                if coeff.is_positive() {
936                    // Positive coeff: use lower bound of var
937                    if let Some(lb) = self.lower_bounds.get(var) {
938                        lower_bound += coeff * lb;
939                    } else {
940                        has_finite_lower = false;
941                        break;
942                    }
943                } else {
944                    // Negative coeff: use upper bound of var
945                    if let Some(ub) = self.upper_bounds.get(var) {
946                        lower_bound += coeff * ub;
947                    } else {
948                        has_finite_lower = false;
949                        break;
950                    }
951                }
952            }
953
954            if has_finite_lower {
955                // Check if this is a tighter bound
956                if let Some(current_lb) = self.lower_bounds.get(&basic_var) {
957                    if &lower_bound > current_lb {
958                        propagated.push((basic_var, BoundType::Lower, lower_bound.clone()));
959                    }
960                } else {
961                    propagated.push((basic_var, BoundType::Lower, lower_bound.clone()));
962                }
963            }
964
965            // Compute upper bound: max value of basic_var
966            let mut upper_bound = row.constant.clone();
967            let mut has_finite_upper = true;
968
969            for (var, coeff) in &row.coeffs {
970                if coeff.is_positive() {
971                    // Positive coeff: use upper bound of var
972                    if let Some(ub) = self.upper_bounds.get(var) {
973                        upper_bound += coeff * ub;
974                    } else {
975                        has_finite_upper = false;
976                        break;
977                    }
978                } else {
979                    // Negative coeff: use lower bound of var
980                    if let Some(lb) = self.lower_bounds.get(var) {
981                        upper_bound += coeff * lb;
982                    } else {
983                        has_finite_upper = false;
984                        break;
985                    }
986                }
987            }
988
989            if has_finite_upper {
990                // Check if this is a tighter bound
991                if let Some(current_ub) = self.upper_bounds.get(&basic_var) {
992                    if &upper_bound < current_ub {
993                        propagated.push((basic_var, BoundType::Upper, upper_bound.clone()));
994                    }
995                } else {
996                    propagated.push((basic_var, BoundType::Upper, upper_bound.clone()));
997                }
998            }
999        }
1000
1001        propagated
1002    }
1003
1004    /// Assert a linear constraint: sum(coeffs) + constant {<=, >=, ==} 0.
1005    /// Returns a new slack variable if needed, and the constraint ID.
1006    pub fn assert_constraint(
1007        &mut self,
1008        coeffs: FxHashMap<VarId, FastRational>,
1009        constant: FastRational,
1010        bound_type: BoundType,
1011        constraint_id: ConstraintId,
1012    ) -> Result<VarId, Conflict> {
1013        // Create a fresh slack variable for the constraint
1014        let slack_var = self.fresh_var();
1015
1016        // slack_var = sum(coeffs) + constant
1017        let row = Row::from_expr(slack_var, constant, coeffs);
1018        self.add_row(row).map_err(|_| Conflict {
1019            constraints: vec![constraint_id],
1020        })?;
1021
1022        // Add appropriate bound on slack_var based on bound_type
1023        let zero = FastRational::zero();
1024        match bound_type {
1025            BoundType::Lower => {
1026                // sum(coeffs) + constant >= 0  =>  slack_var >= 0
1027                self.add_bound(slack_var, BoundType::Lower, zero, constraint_id)?;
1028            }
1029            BoundType::Upper => {
1030                // sum(coeffs) + constant <= 0  =>  slack_var <= 0
1031                self.add_bound(slack_var, BoundType::Upper, zero, constraint_id)?;
1032            }
1033            BoundType::Equal => {
1034                // sum(coeffs) + constant == 0  =>  slack_var == 0
1035                self.add_bound(slack_var, BoundType::Equal, zero.clone(), constraint_id)?;
1036                // Equal means both lower and upper bound
1037                self.add_bound(slack_var, BoundType::Lower, zero.clone(), constraint_id)?;
1038                self.add_bound(slack_var, BoundType::Upper, zero, constraint_id)?;
1039            }
1040        }
1041
1042        Ok(slack_var)
1043    }
1044}
1045
1046impl Default for SimplexTableau {
1047    fn default() -> Self {
1048        Self::new()
1049    }
1050}
1051
1052impl fmt::Display for SimplexTableau {
1053    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1054        writeln!(f, "Simplex Tableau:")?;
1055        writeln!(f, "  Basic variables: {:?}", self.basic_vars())?;
1056        writeln!(f, "  Non-basic variables: {:?}", self.non_basic_vars())?;
1057        writeln!(f, "  Rows:")?;
1058        for row in self.rows.values() {
1059            writeln!(f, "    {}", row)?;
1060        }
1061        writeln!(f, "  Assignment:")?;
1062        for var in self.vars() {
1063            if let Some(val) = self.assignment.get(&var) {
1064                write!(f, "    x{} = {}", var, val)?;
1065                if let Some(lb) = self.lower_bounds.get(&var) {
1066                    write!(f, " (>= {})", lb)?;
1067                }
1068                if let Some(ub) = self.upper_bounds.get(&var) {
1069                    write!(f, " (<= {})", ub)?;
1070                }
1071                writeln!(f)?;
1072            }
1073        }
1074        Ok(())
1075    }
1076}
1077
#[cfg(test)]
mod tests {
    use super::*;

    fn rat(n: i64) -> FastRational {
        FastRational::from_integer(BigInt::from(n))
    }

    /// Build a tableau over fresh variables `x0, x1, x2` containing the single
    /// row `x2 = 10 - x0 - x1`. Together with `x2 >= 0`, this encodes
    /// `x0 + x1 <= 10`. No bounds are added.
    fn sum_at_most_ten() -> (SimplexTableau, VarId, VarId, VarId) {
        let mut tableau = SimplexTableau::new();
        let x0 = tableau.fresh_var();
        let x1 = tableau.fresh_var();
        let x2 = tableau.fresh_var();

        let mut row = Row::new(x2);
        row.constant = rat(10);
        row.coeffs.insert(x0, rat(-1));
        row.coeffs.insert(x1, rat(-1));
        tableau.add_row(row).expect("test operation should succeed");

        (tableau, x0, x1, x2)
    }

    /// Add a bound on `var`, panicking if the tableau reports a conflict.
    fn bound(
        tableau: &mut SimplexTableau,
        var: VarId,
        kind: BoundType,
        value: i64,
        id: ConstraintId,
    ) {
        tableau
            .add_bound(var, kind, rat(value), id)
            .expect("test operation should succeed");
    }

    #[test]
    fn test_row_eval() {
        let mut row = Row::new(0);
        row.constant = rat(5);
        row.coeffs.insert(1, rat(2));
        row.coeffs.insert(2, rat(3));

        let mut values = FxHashMap::default();
        values.insert(1, rat(1));
        values.insert(2, rat(2));

        // 5 + 2*1 + 3*2 = 5 + 2 + 6 = 13
        assert_eq!(row.eval(&values), rat(13));
    }

    #[test]
    fn test_row_add() {
        let mut row1 = Row::new(0);
        row1.constant = rat(5);
        row1.coeffs.insert(1, rat(2));

        let mut row2 = Row::new(0);
        row2.constant = rat(3);
        row2.coeffs.insert(1, rat(1));
        row2.coeffs.insert(2, rat(4));

        // row1 += 2 * row2
        // row1 = 5 + 2*x1 + 2*(3 + x1 + 4*x2)
        // row1 = 5 + 2*x1 + 6 + 2*x1 + 8*x2
        // row1 = 11 + 4*x1 + 8*x2
        row1.add_row(&rat(2), &row2);
        assert_eq!(row1.constant, rat(11));
        assert_eq!(row1.coeffs.get(&1), Some(&rat(4)));
        assert_eq!(row1.coeffs.get(&2), Some(&rat(8)));
    }

    #[test]
    fn test_simplex_basic() {
        // x0 + x1 <= 10 with x0, x1, x2 >= 0
        let (mut tableau, x0, x1, x2) = sum_at_most_ten();
        bound(&mut tableau, x0, BoundType::Lower, 0, 0);
        bound(&mut tableau, x1, BoundType::Lower, 0, 1);
        bound(&mut tableau, x2, BoundType::Lower, 0, 2);

        let result = tableau.check().expect("test operation should succeed");
        assert_eq!(result, SimplexResult::Sat);
        assert!(tableau.get_model().is_some());
    }

    #[test]
    fn test_simplex_infeasible() {
        let mut tableau = SimplexTableau::new();

        let x = tableau.fresh_var();

        // x >= 5 and x <= 3 (conflicting bounds)
        tableau
            .add_bound(x, BoundType::Lower, rat(5), 0)
            .expect("test operation should succeed");
        let result = tableau.add_bound(x, BoundType::Upper, rat(3), 1);

        assert!(result.is_err());
    }

    #[test]
    fn test_simplex_pivot() {
        let mut tableau = SimplexTableau::new();

        let x0 = tableau.fresh_var();
        let x1 = tableau.fresh_var();
        let x2 = tableau.fresh_var();

        // x2 = 10 - 2*x0 - 3*x1
        let mut row = Row::new(x2);
        row.constant = rat(10);
        row.coeffs.insert(x0, rat(-2));
        row.coeffs.insert(x1, rat(-3));

        tableau.add_row(row).expect("test operation should succeed");

        // Pivot x2 and x0
        tableau
            .pivot(x2, x0)
            .expect("test operation should succeed");

        // After pivot: x0 = (10 - x2 - 3*x1) / 2 = 5 - x2/2 - 3*x1/2
        assert!(tableau.basic_vars.contains(&x0));
        assert!(tableau.non_basic_vars.contains(&x2));

        let new_row = tableau.rows.get(&x0).expect("key should exist in map");
        assert_eq!(new_row.constant, rat(5));
    }

    #[test]
    fn test_simplex_dual() {
        // x0 + x1 <= 10 with x0, x1, x2 >= 0, solved with dual simplex
        let (mut tableau, x0, x1, x2) = sum_at_most_ten();
        bound(&mut tableau, x0, BoundType::Lower, 0, 0);
        bound(&mut tableau, x1, BoundType::Lower, 0, 1);
        bound(&mut tableau, x2, BoundType::Lower, 0, 2);

        let result = tableau.check_dual().expect("test operation should succeed");
        assert_eq!(result, SimplexResult::Sat);

        // Verify the solution is feasible
        assert!(tableau.is_feasible());
    }

    #[test]
    fn test_simplex_dual_with_bounds() {
        // 0 <= x0 <= 5, 0 <= x1 <= 5, x2 >= 0, solved with dual simplex
        let (mut tableau, x0, x1, x2) = sum_at_most_ten();
        bound(&mut tableau, x0, BoundType::Lower, 0, 0);
        bound(&mut tableau, x0, BoundType::Upper, 5, 1);
        bound(&mut tableau, x1, BoundType::Lower, 0, 2);
        bound(&mut tableau, x1, BoundType::Upper, 5, 3);
        bound(&mut tableau, x2, BoundType::Lower, 0, 4);

        let result = tableau.check_dual().expect("test operation should succeed");
        assert_eq!(result, SimplexResult::Sat);

        // Verify the solution is feasible
        assert!(tableau.is_feasible());
    }

    #[test]
    fn test_propagate_bounds_on_basic_var() {
        let (mut tableau, x0, x1, x2) = sum_at_most_ten();

        // Only lower bounds on x0, x1: the maximum of x2 = 10 - x0 - x1 is
        // known (10), but its minimum needs upper bounds that are missing.
        bound(&mut tableau, x0, BoundType::Lower, 0, 0);
        bound(&mut tableau, x1, BoundType::Lower, 0, 1);
        assert_eq!(
            tableau.propagate(),
            vec![(x2, BoundType::Upper, rat(10))]
        );

        // With x0, x1 <= 5 the minimum 10 - 5 - 5 = 0 becomes derivable too.
        bound(&mut tableau, x0, BoundType::Upper, 5, 2);
        bound(&mut tableau, x1, BoundType::Upper, 5, 3);
        assert_eq!(
            tableau.propagate(),
            vec![
                (x2, BoundType::Lower, rat(0)),
                (x2, BoundType::Upper, rat(10)),
            ]
        );

        // A deduced bound that is not strictly tighter is not reported.
        bound(&mut tableau, x2, BoundType::Lower, 0, 4);
        assert_eq!(
            tableau.propagate(),
            vec![(x2, BoundType::Upper, rat(10))]
        );
    }

    #[test]
    fn test_get_unsat_core_returns_conflict_constraints() {
        let tableau = SimplexTableau::new();
        let conflict = Conflict {
            constraints: vec![3, 7],
        };
        assert_eq!(tableau.get_unsat_core(&conflict), vec![3, 7]);
    }
}