use crate::lpsolver::basis::Basis;
use crate::lpsolver::matrix::Matrix;
use crate::lpsolver::types::{LpConfig, LpError, LpProblem, LpSolution, LpStatus};
/// Dual simplex solver for linear programs described by `LpProblem`.
///
/// The solver itself only stores its configuration; all per-solve state
/// (standard-form matrix, basis, iterates) lives inside `solve`.
pub struct DualSimplex {
/// Solver settings. This file reads `max_iterations`, `timeout_ms`,
/// `max_memory_mb` and `feasibility_tol` from it and forwards the whole
/// configuration to `Basis::factorize`.
config: LpConfig,
}
impl DualSimplex {
pub fn new(config: LpConfig) -> Self {
Self { config }
}
/// Estimates the solver's working memory in mebibytes (bytes / 1024²).
///
/// The estimate is the sum of:
/// * the storage reported by `a.memory_bytes()`,
/// * two dense `rows × rows` blocks of `f64` (presumably the factors of
///   the basis matrix — TODO confirm against `Basis`),
/// * `rows` entries of `usize`,
/// * `rows + cols` entries of `f64`.
///
/// `_basis` is currently unused; it is kept so the signature stays stable.
///
/// All arithmetic is carried out in `f64`. The byte count contains a
/// `rows²` term, which overflows `usize` on 32-bit targets for problems
/// with only a few tens of thousands of rows; an overflow would panic in
/// debug builds and silently wrap in release builds, defeating the memory
/// limit this estimate is compared against.
fn estimate_memory_mb(&self, a: &Matrix, _basis: &Basis) -> f64 {
    const BYTES_PER_MB: f64 = 1024.0 * 1024.0;

    let f64_bytes = std::mem::size_of::<f64>() as f64;
    let usize_bytes = std::mem::size_of::<usize>() as f64;
    let rows = a.rows as f64;
    let cols = a.cols as f64;

    let total_bytes = a.memory_bytes() as f64
        + 2.0 * rows * rows * f64_bytes
        + rows * usize_bytes
        + (rows + cols) * f64_bytes;

    total_bytes / BYTES_PER_MB
}
pub fn solve(&mut self, problem: &LpProblem) -> Result<LpSolution, LpError> {
let start_time = std::time::Instant::now();
problem.validate()?;
let (a, c, n_total) = self.to_standard_form(problem);
let b = &problem.b;
let n_vars = problem.n_vars;
let m = problem.n_constraints;
let mut basis = if let Some(ref basic_indices) = problem.basic_indices {
let basic = basic_indices.clone();
let nonbasic: Vec<usize> = (0..n_total)
.filter(|idx| !basic.contains(idx))
.collect();
let mut b = Basis::from_indices(basic, nonbasic);
b.factorize(&a, &self.config)?;
b
} else {
return Err(LpError::NumericalInstability);
};
let max_iterations = self.config.max_iterations;
for iterations in 0..max_iterations {
if iterations % 100 == 0 {
if let Some(timeout_ms) = self.config.timeout_ms {
let elapsed = start_time.elapsed().as_millis() as u64;
if elapsed > timeout_ms {
return Err(LpError::TimeoutExceeded {
elapsed_ms: elapsed,
limit_ms: timeout_ms,
});
}
}
if let Some(limit_mb) = self.config.max_memory_mb {
let usage_mb = self.estimate_memory_mb(&a, &basis) as u64;
if usage_mb > limit_mb {
return Err(LpError::MemoryExceeded {
usage_mb,
limit_mb,
});
}
}
}
let x_basic = basis.solve_basic(b)?;
let mut x = vec![0.0; n_total];
for (i, &var_idx) in basis.basic.iter().enumerate() {
x[var_idx] = x_basic[i];
}
if basis.is_primal_feasible(&x_basic, self.config.feasibility_tol) {
let objective = basis.objective_value(&c, &x);
return Ok(LpSolution::new(
LpStatus::Optimal,
objective,
x[..n_vars].to_vec(),
iterations,
basis.basic.clone(),
));
}
let leaving_idx = self.find_leaving_variable(&x_basic)?;
let mut dual_direction = vec![0.0; n_total];
let mut unit_vec = vec![0.0; m];
unit_vec[leaving_idx] = 1.0;
let pi_row = basis.lu.as_ref()
.ok_or(LpError::NumericalInstability)?
.solve_transpose(&unit_vec)?;
for j in 0..n_total {
let a_col = a.col(j);
dual_direction[j] = pi_row.iter()
.zip(a_col.iter())
.map(|(pi, a_ij)| pi * a_ij)
.sum();
}
let entering = self.find_entering_variable(
&basis,
&a,
&c,
&dual_direction,
)?;
let entering_nonbasic_idx = basis.nonbasic.iter()
.position(|&idx| idx == entering)
.ok_or(LpError::NumericalInstability)?;
basis.swap(entering_nonbasic_idx, leaving_idx);
basis.factorize(&a, &self.config)?;
}
Err(LpError::NumericalInstability)
}
/// Converts `problem` to equality form by appending one slack column per
/// constraint.
///
/// Returns `(A, c, n_total)` where `A` is the `n_constraints × n_total`
/// matrix `[problem.a | I]`, `c` is `problem.c` followed by one zero cost
/// per slack column, and `n_total = n_vars + n_constraints`.
fn to_standard_form(&self, problem: &LpProblem) -> (Matrix, Vec<f64>, usize) {
    let structural = problem.n_vars;
    let rows = problem.n_constraints;
    let total_cols = structural + rows;

    let mut matrix = Matrix::zeros(rows, total_cols);
    for row in 0..rows {
        // Original constraint coefficients occupy the leading columns...
        for col in 0..structural {
            matrix.set(row, col, problem.a[row][col]);
        }
        // ...followed by this row's slack coefficient on the diagonal.
        matrix.set(row, structural + row, 1.0);
    }

    // Slack columns carry zero objective cost.
    let mut costs = problem.c.clone();
    costs.extend(std::iter::repeat(0.0).take(rows));

    (matrix, costs, total_cols)
}
/// Selects the leaving row: the position in `x_basic` holding the most
/// negative value that lies strictly below `-feasibility_tol`.
///
/// Ties keep the earliest position. Returns
/// `LpError::NumericalInstability` when no entry is below the threshold.
fn find_leaving_variable(&self, x_basic: &[f64]) -> Result<usize, LpError> {
    let threshold = -self.config.feasibility_tol;

    // Carry (best row so far, lowest value so far); the strict comparison
    // means a later equal value never displaces an earlier one.
    let (chosen, _lowest) = x_basic.iter().enumerate().fold(
        (None, threshold),
        |(best, lowest), (row, &value)| {
            if value < lowest {
                (Some(row), value)
            } else {
                (best, lowest)
            }
        },
    );

    chosen.ok_or(LpError::NumericalInstability)
}
/// Dual ratio test: selects the nonbasic column that enters the basis.
///
/// `dual_direction[j]` is the entry of the leaving row of the tableau in
/// column `j`. The leaving basic variable is negative, and after the pivot
/// the entering variable takes the value `x_leaving / dual_direction[j]`,
/// so only columns with a strictly negative pivot entry can restore
/// feasibility. Among those, the column with the smallest ratio
/// `reduced_cost[j] / |dual_direction[j]|` is chosen; ties keep the first
/// candidate in `basis.nonbasic` order.
///
/// NOTE(review): as in the previous implementation, candidates whose ratio
/// is negative are skipped, i.e. reduced costs are expected to be
/// non-negative at a dual-feasible basis. Confirm this matches the sign
/// convention of `Basis::compute_reduced_costs`.
///
/// # Errors
///
/// Propagates errors from `compute_reduced_costs` and returns
/// `LpError::NumericalInstability` when no column qualifies.
fn find_entering_variable(
    &self,
    basis: &Basis,
    a: &Matrix,
    c: &[f64],
    dual_direction: &[f64],
) -> Result<usize, LpError> {
    let reduced_costs = basis.compute_reduced_costs(a, c)?;
    let pivot_tol = self.config.feasibility_tol;

    let mut best_idx = None;
    let mut best_ratio = f64::INFINITY;
    for &j in &basis.nonbasic {
        let d_j = dual_direction[j];
        // A positive (or near-zero) pivot entry would leave the entering
        // variable negative or blow up numerically, so it is not eligible.
        if d_j < -pivot_tol {
            let ratio = reduced_costs[j] / -d_j;
            if ratio >= 0.0 && ratio < best_ratio {
                best_ratio = ratio;
                best_idx = Some(j);
            }
        }
    }

    best_idx.ok_or(LpError::NumericalInstability)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lpsolver::simplex_primal::PrimalSimplex;

    /// Solves a one-constraint problem with the primal simplex, then feeds
    /// the resulting basis to the dual simplex on a two-constraint variant.
    /// The dual solve may succeed or report numerical instability; any
    /// other error fails the test.
    #[test]
    fn test_dual_simplex_warmstart() {
        let base_problem = LpProblem::new(
            2,
            1,
            vec![1.0, 1.0],
            vec![vec![1.0, 1.0]],
            vec![5.0],
            vec![0.0, 0.0],
            vec![f64::INFINITY, f64::INFINITY],
        );

        let mut primal = PrimalSimplex::new(LpConfig::default());
        let primal_solution = primal.solve(&base_problem).unwrap();
        assert_eq!(primal_solution.status, LpStatus::Optimal);
        assert!((primal_solution.objective - 5.0).abs() < 1e-6);

        // Same objective, with an additional constraint row.
        let extended_problem = LpProblem {
            n_vars: 2,
            n_constraints: 2,
            c: vec![1.0, 1.0],
            a: vec![vec![1.0, 1.0], vec![1.0, 1.0]],
            b: vec![5.0, 4.0],
            lower_bounds: vec![0.0, 0.0],
            upper_bounds: vec![f64::INFINITY, f64::INFINITY],
            basic_indices: Some(primal_solution.basic_indices.clone()),
        };

        let mut dual = DualSimplex::new(LpConfig::default());
        match dual.solve(&extended_problem) {
            Ok(_) | Err(LpError::NumericalInstability) => {}
            Err(e) => panic!("Unexpected error: {:?}", e),
        }
    }

    /// The constructor stores the configuration it is given unchanged.
    #[test]
    fn test_dual_simplex_structure() {
        let expected = LpConfig::default();
        let solver = DualSimplex::new(expected.clone());

        assert_eq!(solver.config.max_iterations, expected.max_iterations);
        assert_eq!(solver.config.feasibility_tol, expected.feasibility_tol);
        assert_eq!(solver.config.optimality_tol, expected.optimality_tol);
    }
}