Skip to main content

oxiz_math/
simplex.rs

1//! Simplex algorithm for linear arithmetic.
2//!
3//! This module implements the Simplex algorithm for solving linear programming
4//! problems and checking satisfiability of linear real arithmetic constraints.
5//!
6//! The implementation follows the two-phase simplex method:
7//! - Phase 1: Find a basic feasible solution (or prove infeasibility)
8//! - Phase 2: Optimize the objective function (or detect unboundedness)
9//!
10//! Reference: Z3's `math/simplex/` directory and standard LP textbooks.
11
12use num_bigint::BigInt;
13use num_rational::BigRational;
14use num_traits::{One, Signed, Zero};
15use rustc_hash::{FxHashMap, FxHashSet};
16use std::fmt;
17
/// Identifier of a variable (a column) of the simplex tableau.
pub type VarId = u32;

/// Identifier of an external constraint; these ids are what conflicts report.
pub type ConstraintId = u32;
23
/// Which kind of bound is being asserted on a variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundType {
    /// `x >= value`.
    Lower,
    /// `x <= value`.
    Upper,
    /// `x == value`, i.e. a lower and an upper bound at the same value.
    Equal,
}
34
35/// A bound on a variable.
36#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct Bound {
38    /// The variable this bound applies to.
39    pub var: VarId,
40    /// The type of bound (lower, upper, or equality).
41    pub bound_type: BoundType,
42    /// The bound value.
43    pub value: BigRational,
44}
45
46impl Bound {
47    /// Create a lower bound: var >= value.
48    pub fn lower(var: VarId, value: BigRational) -> Self {
49        Self {
50            var,
51            bound_type: BoundType::Lower,
52            value,
53        }
54    }
55
56    /// Create an upper bound: var <= value.
57    pub fn upper(var: VarId, value: BigRational) -> Self {
58        Self {
59            var,
60            bound_type: BoundType::Upper,
61            value,
62        }
63    }
64
65    /// Create an equality bound: var = value.
66    pub fn equal(var: VarId, value: BigRational) -> Self {
67        Self {
68            var,
69            bound_type: BoundType::Equal,
70            value,
71        }
72    }
73}
74
/// Role a variable currently plays in the tableau.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VarClass {
    /// Owns a row: it is the left-hand side of that row's equation.
    Basic,
    /// Owns no row; it may appear on right-hand sides.
    NonBasic,
}
83
84/// A row in the simplex tableau represents: basic_var = constant + sum(coeff_i * nonbasic_var_i).
85#[derive(Debug, Clone)]
86pub struct Row {
87    /// The basic variable for this row.
88    pub basic_var: VarId,
89    /// The constant term.
90    pub constant: BigRational,
91    /// Coefficients for non-basic variables: var_id -> coefficient.
92    pub coeffs: FxHashMap<VarId, BigRational>,
93}
94
95impl Row {
96    /// Create a new row with a basic variable.
97    pub fn new(basic_var: VarId) -> Self {
98        Self {
99            basic_var,
100            constant: BigRational::zero(),
101            coeffs: FxHashMap::default(),
102        }
103    }
104
105    /// Create a row representing: basic_var = constant + sum(coeffs).
106    pub fn from_expr(
107        basic_var: VarId,
108        constant: BigRational,
109        coeffs: FxHashMap<VarId, BigRational>,
110    ) -> Self {
111        let mut row = Self {
112            basic_var,
113            constant,
114            coeffs: FxHashMap::default(),
115        };
116        for (var, coeff) in coeffs {
117            if !coeff.is_zero() {
118                row.coeffs.insert(var, coeff);
119            }
120        }
121        row
122    }
123
124    /// Get the value of the basic variable given values of non-basic variables.
125    pub fn eval(&self, non_basic_values: &FxHashMap<VarId, BigRational>) -> BigRational {
126        let mut value = self.constant.clone();
127        for (var, coeff) in &self.coeffs {
128            if let Some(val) = non_basic_values.get(var) {
129                value += coeff * val;
130            }
131        }
132        value
133    }
134
135    /// Add a multiple of another row to this row.
136    /// self += multiplier * other
137    pub fn add_row(&mut self, multiplier: &BigRational, other: &Row) {
138        if multiplier.is_zero() {
139            return;
140        }
141
142        self.constant += multiplier * &other.constant;
143
144        for (var, coeff) in &other.coeffs {
145            let new_coeff = self
146                .coeffs
147                .get(var)
148                .cloned()
149                .unwrap_or_else(BigRational::zero)
150                + multiplier * coeff;
151            if new_coeff.is_zero() {
152                self.coeffs.remove(var);
153            } else {
154                self.coeffs.insert(*var, new_coeff);
155            }
156        }
157    }
158
159    /// Multiply the row by a scalar.
160    pub fn scale(&mut self, scalar: &BigRational) {
161        if scalar.is_zero() {
162            self.constant = BigRational::zero();
163            self.coeffs.clear();
164            return;
165        }
166
167        self.constant *= scalar;
168        for coeff in self.coeffs.values_mut() {
169            *coeff *= scalar;
170        }
171    }
172
173    /// Substitute a non-basic variable using another row.
174    /// If var appears in this row, replace it using: var = row.constant + sum(row.coeffs).
175    pub fn substitute(&mut self, var: VarId, row: &Row) {
176        if let Some(coeff) = self.coeffs.remove(&var) {
177            // Add coeff * row to self
178            self.add_row(&coeff, row);
179        }
180    }
181
182    /// Check if the row is valid (no basic variable in RHS).
183    pub fn is_valid(&self, basic_vars: &FxHashSet<VarId>) -> bool {
184        for var in self.coeffs.keys() {
185            if basic_vars.contains(var) {
186                return false;
187            }
188        }
189        true
190    }
191
192    /// Normalize the row by dividing by the GCD of coefficients.
193    pub fn normalize(&mut self) {
194        if self.coeffs.is_empty() {
195            return;
196        }
197
198        // Compute GCD of all coefficients
199        let mut gcd: Option<BigInt> = None;
200        for coeff in self.coeffs.values() {
201            if !coeff.is_zero() {
202                let num = coeff.numer().clone();
203                gcd = Some(match gcd {
204                    None => num.abs(),
205                    Some(g) => gcd_bigint(g, num.abs()),
206                });
207            }
208        }
209
210        if let Some(g) = gcd
211            && !g.is_one()
212        {
213            let divisor = BigRational::from_integer(g);
214            self.scale(&(BigRational::one() / divisor));
215        }
216    }
217}
218
219impl fmt::Display for Row {
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221        write!(f, "x{} = {}", self.basic_var, self.constant)?;
222        for (var, coeff) in &self.coeffs {
223            if coeff.is_positive() {
224                write!(f, " + {}*x{}", coeff, var)?;
225            } else {
226                write!(f, " - {}*x{}", -coeff, var)?;
227            }
228        }
229        Ok(())
230    }
231}
232
233/// Compute GCD of two BigInts using Euclidean algorithm.
234fn gcd_bigint(mut a: BigInt, mut b: BigInt) -> BigInt {
235    while !b.is_zero() {
236        let t = &a % &b;
237        a = b;
238        b = t;
239    }
240    a.abs()
241}
242
/// Verdict returned by the simplex procedures.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SimplexResult {
    /// All bounds can be satisfied.
    Sat,
    /// The bounds cannot be satisfied together.
    Unsat,
    /// The objective has no finite optimum.
    Unbounded,
    /// No verdict; for example, the iteration limit was reached.
    Unknown,
}
255
256/// Explanation for why a constraint caused unsatisfiability.
257#[derive(Debug, Clone)]
258pub struct Conflict {
259    /// The constraints involved in the conflict.
260    pub constraints: Vec<ConstraintId>,
261}
262
263/// The Simplex tableau.
264#[derive(Debug, Clone)]
265pub struct SimplexTableau {
266    /// Rows of the tableau, indexed by basic variable.
267    rows: FxHashMap<VarId, Row>,
268    /// Set of basic variables.
269    basic_vars: FxHashSet<VarId>,
270    /// Set of non-basic variables.
271    non_basic_vars: FxHashSet<VarId>,
272    /// Current assignment to all variables.
273    assignment: FxHashMap<VarId, BigRational>,
274    /// Lower bounds for variables.
275    lower_bounds: FxHashMap<VarId, BigRational>,
276    /// Upper bounds for variables.
277    upper_bounds: FxHashMap<VarId, BigRational>,
278    /// Mapping from variables to the constraints that bound them.
279    var_to_constraints: FxHashMap<VarId, Vec<ConstraintId>>,
280    /// Next fresh variable ID.
281    next_var_id: VarId,
282    /// Use Bland's rule to prevent cycling.
283    use_blands_rule: bool,
284}
285
286impl SimplexTableau {
287    /// Create a new empty tableau.
288    pub fn new() -> Self {
289        Self {
290            rows: FxHashMap::default(),
291            basic_vars: FxHashSet::default(),
292            non_basic_vars: FxHashSet::default(),
293            assignment: FxHashMap::default(),
294            lower_bounds: FxHashMap::default(),
295            upper_bounds: FxHashMap::default(),
296            var_to_constraints: FxHashMap::default(),
297            next_var_id: 0,
298            use_blands_rule: true,
299        }
300    }
301
302    /// Create a fresh variable.
303    pub fn fresh_var(&mut self) -> VarId {
304        let id = self.next_var_id;
305        self.next_var_id += 1;
306        self.non_basic_vars.insert(id);
307        self.assignment.insert(id, BigRational::zero());
308        id
309    }
310
311    /// Add a variable with initial bounds.
312    pub fn add_var(&mut self, lower: Option<BigRational>, upper: Option<BigRational>) -> VarId {
313        let var = self.fresh_var();
314        if let Some(lb) = lower {
315            self.lower_bounds.insert(var, lb.clone());
316            self.assignment.insert(var, lb);
317        }
318        if let Some(ub) = upper {
319            self.upper_bounds.insert(var, ub);
320        }
321        var
322    }
323
324    /// Add a row to the tableau.
325    /// The row represents: basic_var = constant + sum(coeffs).
326    pub fn add_row(&mut self, row: Row) -> Result<(), String> {
327        let basic_var = row.basic_var;
328
329        // Ensure basic_var is not already basic
330        if self.basic_vars.contains(&basic_var) {
331            return Err(format!("Variable x{} is already basic", basic_var));
332        }
333
334        // Ensure all non-basic vars in coeffs are indeed non-basic
335        for var in row.coeffs.keys() {
336            if self.basic_vars.contains(var) {
337                return Err(format!("Variable x{} appears in RHS but is basic", var));
338            }
339            self.non_basic_vars.insert(*var);
340        }
341
342        // Move basic_var from non-basic to basic
343        self.non_basic_vars.remove(&basic_var);
344        self.basic_vars.insert(basic_var);
345
346        // Compute initial value for basic_var
347        let value = row.eval(&self.assignment);
348        self.assignment.insert(basic_var, value);
349
350        self.rows.insert(basic_var, row);
351        Ok(())
352    }
353
354    /// Add a bound constraint.
355    pub fn add_bound(
356        &mut self,
357        var: VarId,
358        bound_type: BoundType,
359        value: BigRational,
360        constraint_id: ConstraintId,
361    ) -> Result<(), Conflict> {
362        self.var_to_constraints
363            .entry(var)
364            .or_default()
365            .push(constraint_id);
366
367        match bound_type {
368            BoundType::Lower => {
369                let current_lb = self.lower_bounds.get(&var);
370                if let Some(lb) = current_lb {
371                    if &value > lb {
372                        self.lower_bounds.insert(var, value.clone());
373                    }
374                } else {
375                    self.lower_bounds.insert(var, value.clone());
376                }
377
378                // Check for immediate conflict
379                if let Some(ub) = self.upper_bounds.get(&var)
380                    && &value > ub
381                {
382                    return Err(Conflict {
383                        constraints: vec![constraint_id],
384                    });
385                }
386            }
387            BoundType::Upper => {
388                let current_ub = self.upper_bounds.get(&var);
389                if let Some(ub) = current_ub {
390                    if &value < ub {
391                        self.upper_bounds.insert(var, value.clone());
392                    }
393                } else {
394                    self.upper_bounds.insert(var, value.clone());
395                }
396
397                // Check for immediate conflict
398                if let Some(lb) = self.lower_bounds.get(&var)
399                    && &value < lb
400                {
401                    return Err(Conflict {
402                        constraints: vec![constraint_id],
403                    });
404                }
405            }
406            BoundType::Equal => {
407                self.lower_bounds.insert(var, value.clone());
408                self.upper_bounds.insert(var, value.clone());
409
410                // Check existing bounds
411                if let Some(lb) = self.lower_bounds.get(&var)
412                    && &value < lb
413                {
414                    return Err(Conflict {
415                        constraints: vec![constraint_id],
416                    });
417                }
418                if let Some(ub) = self.upper_bounds.get(&var)
419                    && &value > ub
420                {
421                    return Err(Conflict {
422                        constraints: vec![constraint_id],
423                    });
424                }
425            }
426        }
427
428        Ok(())
429    }
430
431    /// Get the current assignment to a variable.
432    pub fn get_value(&self, var: VarId) -> Option<&BigRational> {
433        self.assignment.get(&var)
434    }
435
436    /// Check if a basic variable violates its bounds.
437    fn violates_bounds(&self, var: VarId) -> bool {
438        if let Some(val) = self.assignment.get(&var) {
439            if let Some(lb) = self.lower_bounds.get(&var)
440                && val < lb
441            {
442                return true;
443            }
444            if let Some(ub) = self.upper_bounds.get(&var)
445                && val > ub
446            {
447                return true;
448            }
449        }
450        false
451    }
452
453    /// Find a basic variable that violates its bounds.
454    fn find_violating_basic_var(&self) -> Option<VarId> {
455        if self.use_blands_rule {
456            // Use Bland's rule: pick the smallest index
457            self.basic_vars
458                .iter()
459                .filter(|&&var| self.violates_bounds(var))
460                .min()
461                .copied()
462        } else {
463            self.basic_vars
464                .iter()
465                .find(|&&var| self.violates_bounds(var))
466                .copied()
467        }
468    }
469
470    /// Find a non-basic variable to pivot with.
471    /// Returns (non_basic_var, improving_direction).
472    fn find_pivot_non_basic(
473        &self,
474        basic_var: VarId,
475        target_increase: bool,
476    ) -> Option<(VarId, bool)> {
477        let row = self.rows.get(&basic_var)?;
478
479        let mut candidates = Vec::new();
480
481        for (nb_var, coeff) in &row.coeffs {
482            let current_val = self.assignment.get(nb_var)?;
483            let lb = self.lower_bounds.get(nb_var);
484            let ub = self.upper_bounds.get(nb_var);
485
486            // Determine if we can increase or decrease nb_var
487            let can_increase = ub.is_none_or(|bound| bound > current_val);
488            let can_decrease = lb.is_none_or(|bound| bound < current_val);
489
490            // If coeff > 0: increasing nb_var increases basic_var
491            // If coeff < 0: increasing nb_var decreases basic_var
492            let increases_basic = coeff.is_positive();
493
494            if target_increase {
495                // We want to increase basic_var
496                if increases_basic && can_increase {
497                    candidates.push((*nb_var, true));
498                } else if !increases_basic && can_decrease {
499                    candidates.push((*nb_var, false));
500                }
501            } else {
502                // We want to decrease basic_var
503                if increases_basic && can_decrease {
504                    candidates.push((*nb_var, false));
505                } else if !increases_basic && can_increase {
506                    candidates.push((*nb_var, true));
507                }
508            }
509        }
510
511        if candidates.is_empty() {
512            return None;
513        }
514
515        // Use Bland's rule: pick smallest index
516        if self.use_blands_rule {
517            candidates.sort_by_key(|(var, _)| *var);
518        }
519
520        Some(candidates[0])
521    }
522
523    /// Perform a pivot operation.
524    /// Swap basic_var (currently basic) with non_basic_var (currently non-basic).
525    pub fn pivot(&mut self, basic_var: VarId, non_basic_var: VarId) -> Result<(), String> {
526        // Get the row for basic_var
527        let row = self
528            .rows
529            .get(&basic_var)
530            .ok_or_else(|| format!("No row for basic variable x{}", basic_var))?;
531
532        // Get coefficient of non_basic_var in this row
533        let coeff = row
534            .coeffs
535            .get(&non_basic_var)
536            .ok_or_else(|| {
537                format!(
538                    "Non-basic variable x{} not in row for x{}",
539                    non_basic_var, basic_var
540                )
541            })?
542            .clone();
543
544        if coeff.is_zero() {
545            return Err(format!("Coefficient of x{} is zero", non_basic_var));
546        }
547
548        // Solve for non_basic_var in terms of basic_var
549        // Old: basic_var = constant + coeff * non_basic_var + sum(...)
550        // New: non_basic_var = (basic_var - constant - sum(...)) / coeff
551
552        let mut new_row = Row::new(non_basic_var);
553        new_row.constant = -&row.constant / &coeff;
554        new_row
555            .coeffs
556            .insert(basic_var, BigRational::one() / &coeff);
557
558        for (var, c) in &row.coeffs {
559            if var != &non_basic_var {
560                new_row.coeffs.insert(*var, -c / &coeff);
561            }
562        }
563
564        // Substitute non_basic_var in all other rows
565        let rows_to_update: Vec<VarId> = self
566            .rows
567            .keys()
568            .filter(|&&v| v != basic_var)
569            .copied()
570            .collect();
571
572        for row_var in rows_to_update {
573            if let Some(r) = self.rows.get_mut(&row_var) {
574                r.substitute(non_basic_var, &new_row);
575            }
576        }
577
578        // Remove old row and add new row
579        self.rows.remove(&basic_var);
580        self.rows.insert(non_basic_var, new_row);
581
582        // Update basic/non-basic sets
583        self.basic_vars.remove(&basic_var);
584        self.basic_vars.insert(non_basic_var);
585        self.non_basic_vars.remove(&non_basic_var);
586        self.non_basic_vars.insert(basic_var);
587
588        // Update assignments
589        self.update_assignment();
590
591        Ok(())
592    }
593
594    /// Update the assignment based on current tableau.
595    fn update_assignment(&mut self) {
596        // Evaluate all basic variables
597        for (basic_var, row) in &self.rows {
598            let value = row.eval(&self.assignment);
599            self.assignment.insert(*basic_var, value);
600        }
601    }
602
603    /// Check feasibility and fix violations using pivoting.
604    pub fn check(&mut self) -> Result<SimplexResult, Conflict> {
605        let max_iterations = 10000;
606        let mut iterations = 0;
607
608        while let Some(violating_var) = self.find_violating_basic_var() {
609            iterations += 1;
610            if iterations > max_iterations {
611                return Ok(SimplexResult::Unknown);
612            }
613
614            let current_val = self
615                .assignment
616                .get(&violating_var)
617                .cloned()
618                .ok_or_else(|| Conflict {
619                    constraints: vec![],
620                })?;
621
622            let lb = self.lower_bounds.get(&violating_var);
623            let ub = self.upper_bounds.get(&violating_var);
624
625            // Determine if we need to increase or decrease the variable
626            let need_increase = lb.is_some_and(|l| &current_val < l);
627            let need_decrease = ub.is_some_and(|u| &current_val > u);
628
629            if !need_increase && !need_decrease {
630                continue;
631            }
632
633            // Find a non-basic variable to pivot with
634            if let Some((nb_var, _direction)) =
635                self.find_pivot_non_basic(violating_var, need_increase)
636            {
637                // Compute the new value for nb_var
638                let target_value = if need_increase {
639                    lb.cloned().unwrap_or_else(BigRational::zero)
640                } else {
641                    ub.cloned().unwrap_or_else(BigRational::zero)
642                };
643
644                // Update nb_var to move basic_var to target
645                let row = self.rows.get(&violating_var).ok_or_else(|| Conflict {
646                    constraints: vec![],
647                })?;
648                let coeff = row.coeffs.get(&nb_var).cloned().ok_or_else(|| Conflict {
649                    constraints: vec![],
650                })?;
651
652                let delta = &target_value - &current_val;
653                let nb_delta = &delta / &coeff;
654                let current_nb = self
655                    .assignment
656                    .get(&nb_var)
657                    .cloned()
658                    .ok_or_else(|| Conflict {
659                        constraints: vec![],
660                    })?;
661                let new_nb = current_nb + nb_delta;
662
663                // Clamp to bounds
664                let new_nb = if let Some(lb) = self.lower_bounds.get(&nb_var) {
665                    new_nb.max(lb.clone())
666                } else {
667                    new_nb
668                };
669                let new_nb = if let Some(ub) = self.upper_bounds.get(&nb_var) {
670                    new_nb.min(ub.clone())
671                } else {
672                    new_nb
673                };
674
675                self.assignment.insert(nb_var, new_nb);
676                self.update_assignment();
677            } else {
678                // No pivot found - problem is infeasible or unbounded
679                let constraints = self
680                    .var_to_constraints
681                    .get(&violating_var)
682                    .cloned()
683                    .unwrap_or_default();
684                return Err(Conflict { constraints });
685            }
686        }
687
688        Ok(SimplexResult::Sat)
689    }
690
691    /// Dual simplex algorithm for Linear Programming.
692    ///
693    /// The dual simplex maintains dual feasibility (all reduced costs non-negative)
694    /// while working toward primal feasibility (all variables within bounds).
695    ///
696    /// This is particularly useful for:
697    /// - Reoptimization after adding constraints
698    /// - Branch-and-bound in integer programming
699    /// - Problems that naturally start dual-feasible
700    ///
701    /// Reference: Standard LP textbooks and Z3's dual simplex implementation.
702    pub fn check_dual(&mut self) -> Result<SimplexResult, Conflict> {
703        let max_iterations = 10000;
704        let mut iterations = 0;
705
706        // Dual simplex loop: restore primal feasibility
707        while let Some(leaving_var) = self.find_violating_basic_var() {
708            iterations += 1;
709            if iterations > max_iterations {
710                return Ok(SimplexResult::Unknown);
711            }
712
713            // Get the row for the leaving variable
714            let row = match self.rows.get(&leaving_var) {
715                Some(r) => r.clone(),
716                None => continue,
717            };
718
719            let current_val = self
720                .assignment
721                .get(&leaving_var)
722                .cloned()
723                .unwrap_or_else(BigRational::zero);
724
725            let lb = self.lower_bounds.get(&leaving_var);
726            let ub = self.upper_bounds.get(&leaving_var);
727
728            // Determine which bound is violated
729            let violated_lower = lb.is_some_and(|l| &current_val < l);
730            let violated_upper = ub.is_some_and(|u| &current_val > u);
731
732            if !violated_lower && !violated_upper {
733                continue;
734            }
735
736            // For dual simplex, find entering variable using dual pricing rule
737            let entering_var = self.find_entering_var_dual(&row, violated_lower)?;
738
739            // Pivot leaving_var out, entering_var in
740            self.pivot(leaving_var, entering_var)
741                .map_err(|_| Conflict {
742                    constraints: self
743                        .var_to_constraints
744                        .get(&leaving_var)
745                        .cloned()
746                        .unwrap_or_default(),
747                })?;
748        }
749
750        Ok(SimplexResult::Sat)
751    }
752
753    /// Find the entering variable for dual simplex using dual pricing rule.
754    ///
755    /// For dual simplex:
756    /// - If leaving var violates lower bound: choose entering var with negative coeff
757    /// - If leaving var violates upper bound: choose entering var with positive coeff
758    /// - Among candidates, choose one that maintains dual feasibility
759    fn find_entering_var_dual(&self, row: &Row, violated_lower: bool) -> Result<VarId, Conflict> {
760        let mut best_var = None;
761        let mut best_ratio: Option<BigRational> = None;
762
763        // Iterate through non-basic variables in the row
764        for (&nb_var, coeff) in &row.coeffs {
765            // For dual simplex pricing:
766            // - If violated_lower, we need coeff < 0 (to increase basic var)
767            // - If violated_upper, we need coeff > 0 (to decrease basic var)
768            let sign_ok = if violated_lower {
769                coeff.is_negative()
770            } else {
771                coeff.is_positive()
772            };
773
774            if !sign_ok || coeff.is_zero() {
775                continue;
776            }
777
778            // Compute the dual ratio for maintaining dual feasibility
779            // This is a simplified version; full implementation would compute reduced costs
780            let ratio = coeff.abs();
781
782            match &best_ratio {
783                None => {
784                    best_ratio = Some(ratio);
785                    best_var = Some(nb_var);
786                }
787                Some(current_best) => {
788                    // Choose the variable with smallest ratio (steepest descent in dual)
789                    // Use Bland's rule for tie-breaking if enabled
790                    if &ratio < current_best {
791                        best_ratio = Some(ratio);
792                        best_var = Some(nb_var);
793                    } else if self.use_blands_rule && &ratio == current_best {
794                        // Bland's rule: choose smaller index
795                        // best_var is guaranteed Some when best_ratio is Some
796                        if best_var.is_none_or(|bv| nb_var < bv) {
797                            best_var = Some(nb_var);
798                        }
799                    }
800                }
801            }
802        }
803
804        best_var.ok_or(Conflict {
805            constraints: vec![],
806        })
807    }
808
809    /// Get all variables.
810    pub fn vars(&self) -> Vec<VarId> {
811        let mut vars: Vec<VarId> = self
812            .basic_vars
813            .iter()
814            .chain(self.non_basic_vars.iter())
815            .copied()
816            .collect();
817        vars.sort_unstable();
818        vars
819    }
820
821    /// Get all basic variables.
822    pub fn basic_vars(&self) -> Vec<VarId> {
823        let mut vars: Vec<VarId> = self.basic_vars.iter().copied().collect();
824        vars.sort_unstable();
825        vars
826    }
827
828    /// Get all non-basic variables.
829    pub fn non_basic_vars(&self) -> Vec<VarId> {
830        let mut vars: Vec<VarId> = self.non_basic_vars.iter().copied().collect();
831        vars.sort_unstable();
832        vars
833    }
834
835    /// Get the number of rows in the tableau.
836    pub fn num_rows(&self) -> usize {
837        self.rows.len()
838    }
839
840    /// Get the number of variables.
841    pub fn num_vars(&self) -> usize {
842        self.basic_vars.len() + self.non_basic_vars.len()
843    }
844
845    /// Enable or disable Bland's anti-cycling rule.
846    pub fn set_blands_rule(&mut self, enable: bool) {
847        self.use_blands_rule = enable;
848    }
849
850    /// Get the current model (satisfying assignment).
851    /// Returns None if the system is not known to be satisfiable.
852    pub fn get_model(&self) -> Option<FxHashMap<VarId, BigRational>> {
853        // Check if all variables satisfy their bounds
854        for (var, val) in &self.assignment {
855            if let Some(lb) = self.lower_bounds.get(var)
856                && val < lb
857            {
858                return None;
859            }
860            if let Some(ub) = self.upper_bounds.get(var)
861                && val > ub
862            {
863                return None;
864            }
865        }
866        Some(self.assignment.clone())
867    }
868
869    /// Check if the current assignment satisfies all bounds.
870    pub fn is_feasible(&self) -> bool {
871        for (var, val) in &self.assignment {
872            if let Some(lb) = self.lower_bounds.get(var)
873                && val < lb
874            {
875                return false;
876            }
877            if let Some(ub) = self.upper_bounds.get(var)
878                && val > ub
879            {
880                return false;
881            }
882        }
883        true
884    }
885
886    /// Find a variable that violates its bounds, if any.
887    pub fn find_violated_bound(&self) -> Option<VarId> {
888        for (var, val) in &self.assignment {
889            if let Some(lb) = self.lower_bounds.get(var)
890                && val < lb
891            {
892                return Some(*var);
893            }
894            if let Some(ub) = self.upper_bounds.get(var)
895                && val > ub
896            {
897                return Some(*var);
898            }
899        }
900        None
901    }
902
903    /// Get all constraints associated with a variable.
904    pub fn get_constraints(&self, var: VarId) -> Vec<ConstraintId> {
905        self.var_to_constraints
906            .get(&var)
907            .cloned()
908            .unwrap_or_default()
909    }
910
911    /// Extract a minimal conflicting core from constraints.
912    /// This is a simple implementation that returns all constraints involved.
913    /// A more sophisticated version would compute a true minimal unsat core.
914    pub fn get_unsat_core(&self, conflict: &Conflict) -> Vec<ConstraintId> {
915        conflict.constraints.clone()
916    }
917
918    /// Theory propagation: deduce new bounds from existing constraints.
919    /// Returns a list of (var, bound_type, value) tuples representing deduced bounds.
920    pub fn propagate(&self) -> Vec<(VarId, BoundType, BigRational)> {
921        let mut propagated = Vec::new();
922
923        // For each row: basic_var = constant + sum(coeff * non_basic_var)
924        // We can deduce bounds on basic_var from bounds on non_basic vars
925        for row in self.rows.values() {
926            let basic_var = row.basic_var;
927
928            // Compute lower bound: min value of basic_var
929            // basic_var >= constant + sum(min(coeff * lb, coeff * ub))
930            let mut lower_bound = row.constant.clone();
931            let mut has_finite_lower = true;
932
933            for (var, coeff) in &row.coeffs {
934                if coeff.is_positive() {
935                    // Positive coeff: use lower bound of var
936                    if let Some(lb) = self.lower_bounds.get(var) {
937                        lower_bound += coeff * lb;
938                    } else {
939                        has_finite_lower = false;
940                        break;
941                    }
942                } else {
943                    // Negative coeff: use upper bound of var
944                    if let Some(ub) = self.upper_bounds.get(var) {
945                        lower_bound += coeff * ub;
946                    } else {
947                        has_finite_lower = false;
948                        break;
949                    }
950                }
951            }
952
953            if has_finite_lower {
954                // Check if this is a tighter bound
955                if let Some(current_lb) = self.lower_bounds.get(&basic_var) {
956                    if &lower_bound > current_lb {
957                        propagated.push((basic_var, BoundType::Lower, lower_bound.clone()));
958                    }
959                } else {
960                    propagated.push((basic_var, BoundType::Lower, lower_bound.clone()));
961                }
962            }
963
964            // Compute upper bound: max value of basic_var
965            let mut upper_bound = row.constant.clone();
966            let mut has_finite_upper = true;
967
968            for (var, coeff) in &row.coeffs {
969                if coeff.is_positive() {
970                    // Positive coeff: use upper bound of var
971                    if let Some(ub) = self.upper_bounds.get(var) {
972                        upper_bound += coeff * ub;
973                    } else {
974                        has_finite_upper = false;
975                        break;
976                    }
977                } else {
978                    // Negative coeff: use lower bound of var
979                    if let Some(lb) = self.lower_bounds.get(var) {
980                        upper_bound += coeff * lb;
981                    } else {
982                        has_finite_upper = false;
983                        break;
984                    }
985                }
986            }
987
988            if has_finite_upper {
989                // Check if this is a tighter bound
990                if let Some(current_ub) = self.upper_bounds.get(&basic_var) {
991                    if &upper_bound < current_ub {
992                        propagated.push((basic_var, BoundType::Upper, upper_bound.clone()));
993                    }
994                } else {
995                    propagated.push((basic_var, BoundType::Upper, upper_bound.clone()));
996                }
997            }
998        }
999
1000        propagated
1001    }
1002
1003    /// Assert a linear constraint: sum(coeffs) + constant {<=, >=, ==} 0.
1004    /// Returns a new slack variable if needed, and the constraint ID.
1005    pub fn assert_constraint(
1006        &mut self,
1007        coeffs: FxHashMap<VarId, BigRational>,
1008        constant: BigRational,
1009        bound_type: BoundType,
1010        constraint_id: ConstraintId,
1011    ) -> Result<VarId, Conflict> {
1012        // Create a fresh slack variable for the constraint
1013        let slack_var = self.fresh_var();
1014
1015        // slack_var = sum(coeffs) + constant
1016        let row = Row::from_expr(slack_var, constant, coeffs);
1017        self.add_row(row).map_err(|_| Conflict {
1018            constraints: vec![constraint_id],
1019        })?;
1020
1021        // Add appropriate bound on slack_var based on bound_type
1022        let zero = BigRational::zero();
1023        match bound_type {
1024            BoundType::Lower => {
1025                // sum(coeffs) + constant >= 0  =>  slack_var >= 0
1026                self.add_bound(slack_var, BoundType::Lower, zero, constraint_id)?;
1027            }
1028            BoundType::Upper => {
1029                // sum(coeffs) + constant <= 0  =>  slack_var <= 0
1030                self.add_bound(slack_var, BoundType::Upper, zero, constraint_id)?;
1031            }
1032            BoundType::Equal => {
1033                // sum(coeffs) + constant == 0  =>  slack_var == 0
1034                self.add_bound(slack_var, BoundType::Equal, zero.clone(), constraint_id)?;
1035                // Equal means both lower and upper bound
1036                self.add_bound(slack_var, BoundType::Lower, zero.clone(), constraint_id)?;
1037                self.add_bound(slack_var, BoundType::Upper, zero, constraint_id)?;
1038            }
1039        }
1040
1041        Ok(slack_var)
1042    }
1043}
1044
1045impl Default for SimplexTableau {
1046    fn default() -> Self {
1047        Self::new()
1048    }
1049}
1050
1051impl fmt::Display for SimplexTableau {
1052    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1053        writeln!(f, "Simplex Tableau:")?;
1054        writeln!(f, "  Basic variables: {:?}", self.basic_vars())?;
1055        writeln!(f, "  Non-basic variables: {:?}", self.non_basic_vars())?;
1056        writeln!(f, "  Rows:")?;
1057        for row in self.rows.values() {
1058            writeln!(f, "    {}", row)?;
1059        }
1060        writeln!(f, "  Assignment:")?;
1061        for var in self.vars() {
1062            if let Some(val) = self.assignment.get(&var) {
1063                write!(f, "    x{} = {}", var, val)?;
1064                if let Some(lb) = self.lower_bounds.get(&var) {
1065                    write!(f, " (>= {})", lb)?;
1066                }
1067                if let Some(ub) = self.upper_bounds.get(&var) {
1068                    write!(f, " (<= {})", ub)?;
1069                }
1070                writeln!(f)?;
1071            }
1072        }
1073        Ok(())
1074    }
1075}
1076
#[cfg(test)]
mod tests {
    use super::*;

    fn rat(n: i64) -> BigRational {
        BigRational::from_integer(BigInt::from(n))
    }

    /// Build the rational `num / den`.
    fn frac(num: i64, den: i64) -> BigRational {
        BigRational::new(BigInt::from(num), BigInt::from(den))
    }

    /// Value of `var` in `model`, treating a missing entry as zero.
    fn value_of(model: &FxHashMap<VarId, BigRational>, var: VarId) -> BigRational {
        model.get(&var).cloned().unwrap_or_else(BigRational::zero)
    }

    #[test]
    fn test_row_eval() {
        let mut row = Row::new(0);
        row.constant = rat(5);
        row.coeffs.insert(1, rat(2));
        row.coeffs.insert(2, rat(3));

        let mut values = FxHashMap::default();
        values.insert(1, rat(1));
        values.insert(2, rat(2));

        // 5 + 2*1 + 3*2 = 5 + 2 + 6 = 13
        assert_eq!(row.eval(&values), rat(13));
    }

    #[test]
    fn test_row_add() {
        let mut row1 = Row::new(0);
        row1.constant = rat(5);
        row1.coeffs.insert(1, rat(2));

        let mut row2 = Row::new(0);
        row2.constant = rat(3);
        row2.coeffs.insert(1, rat(1));
        row2.coeffs.insert(2, rat(4));

        // row1 += 2 * row2
        // row1 = 5 + 2*x1 + 2*(3 + x1 + 4*x2)
        // row1 = 5 + 2*x1 + 6 + 2*x1 + 8*x2
        // row1 = 11 + 4*x1 + 8*x2
        row1.add_row(&rat(2), &row2);
        assert_eq!(row1.constant, rat(11));
        assert_eq!(row1.coeffs.get(&1), Some(&rat(4)));
        assert_eq!(row1.coeffs.get(&2), Some(&rat(8)));
    }

    #[test]
    fn test_simplex_basic() {
        let mut tableau = SimplexTableau::new();

        // Variables: x0, x1
        // Constraint: x0 + x1 <= 10
        // Introduce slack: x2 = 10 - x0 - x1

        let x0 = tableau.fresh_var();
        let x1 = tableau.fresh_var();
        let x2 = tableau.fresh_var();

        // x2 = 10 - x0 - x1
        let mut row = Row::new(x2);
        row.constant = rat(10);
        row.coeffs.insert(x0, rat(-1));
        row.coeffs.insert(x1, rat(-1));

        tableau.add_row(row).unwrap();

        // Bounds: x0 >= 0, x1 >= 0, x2 >= 0
        tableau.add_bound(x0, BoundType::Lower, rat(0), 0).unwrap();
        tableau.add_bound(x1, BoundType::Lower, rat(0), 1).unwrap();
        tableau.add_bound(x2, BoundType::Lower, rat(0), 2).unwrap();

        let result = tableau.check().unwrap();
        assert_eq!(result, SimplexResult::Sat);

        // A Sat result must leave the tableau feasible, with a model available.
        assert!(tableau.is_feasible());
        assert!(tableau.get_model().is_some());
    }

    #[test]
    fn test_simplex_infeasible() {
        let mut tableau = SimplexTableau::new();

        let x = tableau.fresh_var();

        // x >= 5 and x <= 3 (conflicting bounds)
        tableau.add_bound(x, BoundType::Lower, rat(5), 0).unwrap();
        let result = tableau.add_bound(x, BoundType::Upper, rat(3), 1);

        assert!(result.is_err());
    }

    #[test]
    fn test_simplex_pivot() {
        let mut tableau = SimplexTableau::new();

        let x0 = tableau.fresh_var();
        let x1 = tableau.fresh_var();
        let x2 = tableau.fresh_var();

        // x2 = 10 - 2*x0 - 3*x1
        let mut row = Row::new(x2);
        row.constant = rat(10);
        row.coeffs.insert(x0, rat(-2));
        row.coeffs.insert(x1, rat(-3));

        tableau.add_row(row).unwrap();

        // Pivot x2 and x0
        tableau.pivot(x2, x0).unwrap();

        // After pivot: x0 = (10 - x2 - 3*x1) / 2 = 5 - x2/2 - 3*x1/2
        assert!(tableau.basic_vars.contains(&x0));
        assert!(tableau.non_basic_vars.contains(&x2));

        let new_row = tableau.rows.get(&x0).unwrap();
        assert_eq!(new_row.constant, rat(5));
        assert_eq!(new_row.coeffs.get(&x2), Some(&frac(-1, 2)));
        assert_eq!(new_row.coeffs.get(&x1), Some(&frac(-3, 2)));
    }

    #[test]
    fn test_simplex_dual() {
        let mut tableau = SimplexTableau::new();

        // Test dual simplex with a simple LP
        // Variables: x0, x1
        // Constraint: x0 + x1 <= 10
        // Bounds: x0 >= 0, x1 >= 0

        let x0 = tableau.fresh_var();
        let x1 = tableau.fresh_var();
        let x2 = tableau.fresh_var(); // slack variable

        // x2 = 10 - x0 - x1
        let mut row = Row::new(x2);
        row.constant = rat(10);
        row.coeffs.insert(x0, rat(-1));
        row.coeffs.insert(x1, rat(-1));

        tableau.add_row(row).unwrap();

        // Bounds: x0 >= 0, x1 >= 0, x2 >= 0
        tableau.add_bound(x0, BoundType::Lower, rat(0), 0).unwrap();
        tableau.add_bound(x1, BoundType::Lower, rat(0), 1).unwrap();
        tableau.add_bound(x2, BoundType::Lower, rat(0), 2).unwrap();

        // Use dual simplex
        let result = tableau.check_dual().unwrap();
        assert_eq!(result, SimplexResult::Sat);

        // Verify the solution is feasible
        assert!(tableau.is_feasible());
    }

    #[test]
    fn test_simplex_dual_with_bounds() {
        let mut tableau = SimplexTableau::new();

        // Test dual simplex with tighter bounds
        let x0 = tableau.fresh_var();
        let x1 = tableau.fresh_var();
        let x2 = tableau.fresh_var();

        // x2 = 10 - x0 - x1
        let mut row = Row::new(x2);
        row.constant = rat(10);
        row.coeffs.insert(x0, rat(-1));
        row.coeffs.insert(x1, rat(-1));

        tableau.add_row(row).unwrap();

        // Bounds: 0 <= x0 <= 5, 0 <= x1 <= 5, x2 >= 0
        tableau.add_bound(x0, BoundType::Lower, rat(0), 0).unwrap();
        tableau.add_bound(x0, BoundType::Upper, rat(5), 1).unwrap();
        tableau.add_bound(x1, BoundType::Lower, rat(0), 2).unwrap();
        tableau.add_bound(x1, BoundType::Upper, rat(5), 3).unwrap();
        tableau.add_bound(x2, BoundType::Lower, rat(0), 4).unwrap();

        // Use dual simplex
        let result = tableau.check_dual().unwrap();
        assert_eq!(result, SimplexResult::Sat);

        // Verify the solution is feasible
        assert!(tableau.is_feasible());
    }

    #[test]
    fn test_propagate_bounds_from_row() {
        let mut tableau = SimplexTableau::new();

        let x0 = tableau.fresh_var();
        let x1 = tableau.fresh_var();
        let x2 = tableau.fresh_var();

        // x2 = 10 - x0 - x1
        let mut row = Row::new(x2);
        row.constant = rat(10);
        row.coeffs.insert(x0, rat(-1));
        row.coeffs.insert(x1, rat(-1));
        tableau.add_row(row).unwrap();

        // 0 <= x0 <= 5, 0 <= x1 <= 5; x2 itself carries no bounds.
        tableau.add_bound(x0, BoundType::Lower, rat(0), 0).unwrap();
        tableau.add_bound(x0, BoundType::Upper, rat(5), 1).unwrap();
        tableau.add_bound(x1, BoundType::Lower, rat(0), 2).unwrap();
        tableau.add_bound(x1, BoundType::Upper, rat(5), 3).unwrap();

        // Interval reasoning: x2 >= 10 - 5 - 5 = 0 and x2 <= 10 - 0 - 0 = 10.
        let deduced = tableau.propagate();
        assert!(deduced.contains(&(x2, BoundType::Lower, rat(0))));
        assert!(deduced.contains(&(x2, BoundType::Upper, rat(10))));
    }

    #[test]
    fn test_assert_constraint_upper() {
        let mut tableau = SimplexTableau::new();

        let x = tableau.fresh_var();
        let y = tableau.fresh_var();
        tableau.add_bound(x, BoundType::Lower, rat(0), 0).unwrap();
        tableau.add_bound(y, BoundType::Lower, rat(0), 1).unwrap();

        // x + y - 10 <= 0
        let mut coeffs = FxHashMap::default();
        coeffs.insert(x, rat(1));
        coeffs.insert(y, rat(1));
        let slack = tableau
            .assert_constraint(coeffs, rat(-10), BoundType::Upper, 2)
            .unwrap();
        assert_ne!(slack, x);
        assert_ne!(slack, y);

        assert_eq!(tableau.check().unwrap(), SimplexResult::Sat);
        let model = tableau
            .get_model()
            .expect("a Sat tableau should yield a model");
        assert!(value_of(&model, x) + value_of(&model, y) <= rat(10));
    }
}