use crate::basis::{BasisManager, LuBasis};
use crate::sparse::CscMatrix;
/// Computes the simplex dual vector `y` satisfying `y^T B = c_B` into `y_out`.
///
/// Slot `i` is first loaded with the cost of the `i`-th basic column,
/// `c[basis[i]]`; the buffer is then transformed in place by
/// `basis_mgr.btran_dense`. Any previous contents of `y_out` are ignored.
///
/// `y_out` is expected to have one slot per basic variable
/// (checked only in debug builds).
pub(super) fn compute_dual_vars_into(
    c: &[f64],
    basis_mgr: &mut LuBasis,
    basis: &[usize],
    y_out: &mut [f64],
) {
    debug_assert_eq!(y_out.len(), basis.len());
    // Gather c_B: walk the output buffer and pull the matching basic cost.
    y_out
        .iter_mut()
        .enumerate()
        .for_each(|(row, dst)| *dst = c[basis[row]]);
    // Backward transformation with the factored basis turns c_B into y.
    basis_mgr.btran_dense(y_out);
}
/// Allocating convenience wrapper around [`compute_dual_vars_into`].
///
/// Returns a freshly allocated length-`m` dual vector `y` with `y^T B = c_B`.
/// `m` is expected to equal `basis.len()` (checked only in debug builds).
pub(super) fn compute_dual_vars(
    c: &[f64],
    basis_mgr: &mut LuBasis,
    basis: &[usize],
    m: usize,
) -> Vec<f64> {
    debug_assert_eq!(basis.len(), m);
    let mut duals = vec![0.0f64; m];
    compute_dual_vars_into(c, basis_mgr, basis, &mut duals);
    duals
}
/// Computes simplex reduced costs `d_j = c_j - y^T a_j` for columns `0..n_price`.
///
/// The dual vector `y` is first computed into the caller-owned scratch buffer
/// `y_buf` via [`compute_dual_vars_into`], so on return `y_buf` holds the duals.
/// For every priced column `j`:
/// * if `is_basic[j]` is true, `rc_out[j]` is set to exactly `0.0`
///   (column `j` of `a` is not read);
/// * otherwise `rc_out[j] = c[j] - sum_k y[row_k] * val_k` over the stored
///   entries of column `j`.
///
/// `rc_out` is expected to have length `n_price` (checked in debug builds).
///
/// # Panics
/// Panics if `rc_out` or `is_basic` is shorter than `n_price`, or if
/// `a.get_column(j)` fails for a nonbasic priced column `j`.
// Scratch buffers are passed explicitly so hot pricing loops can reuse them
// without allocating; that is what pushes the argument count past clippy's limit.
#[allow(clippy::too_many_arguments)]
pub(super) fn compute_reduced_costs_into(
    a: &CscMatrix,
    c: &[f64],
    basis_mgr: &mut LuBasis,
    is_basic: &[bool],
    n_price: usize,
    basis: &[usize],
    y_buf: &mut [f64],
    rc_out: &mut [f64],
) {
    debug_assert_eq!(rc_out.len(), n_price);
    compute_dual_vars_into(c, basis_mgr, basis, y_buf);
    // The duals are only read from here on.
    let y: &[f64] = y_buf;
    // Slicing once up front hoists the bounds checks out of the loop.
    let priced = rc_out[..n_price].iter_mut().zip(&is_basic[..n_price]);
    for (j, (rc, &basic)) in priced.enumerate() {
        if basic {
            *rc = 0.0;
            continue;
        }
        let (rows, vals) = a
            .get_column(j)
            .expect("every priced column index j < n_price must be a column of A");
        // Left-to-right fold from 0.0 matches the original `ya += ...` loop bit-for-bit.
        let ya = rows
            .iter()
            .zip(vals.iter())
            .fold(0.0_f64, |acc, (&row, &v)| acc + y[row] * v);
        *rc = c[j] - ya;
    }
}
/// Allocating convenience wrapper around [`compute_reduced_costs_into`].
///
/// Allocates a length-`m` dual scratch buffer and a length-`n_price` output,
/// fills the output with reduced costs (exactly `0.0` for basic columns), and
/// returns it. The dual buffer is discarded.
pub(super) fn compute_reduced_costs(
    a: &CscMatrix,
    c: &[f64],
    basis_mgr: &mut LuBasis,
    is_basic: &[bool],
    n_price: usize,
    m: usize,
    basis: &[usize],
) -> Vec<f64> {
    let (mut dual_scratch, mut rc) = (vec![0.0f64; m], vec![0.0f64; n_price]);
    compute_reduced_costs_into(
        a,
        c,
        basis_mgr,
        is_basic,
        n_price,
        basis,
        &mut dual_scratch,
        &mut rc,
    );
    rc
}
/// Objective value contributed by the basic variables: `sum_i c[basis[i]] * x_b[i]`.
///
/// `basis[i]` is the column index of the `i`-th basic variable and `x_b[i]` is
/// its value. Lengths must match (checked only in debug builds). In release
/// builds the sum runs over the shorter of the two slices.
pub(super) fn basic_obj(c: &[f64], basis: &[usize], x_b: &[f64]) -> f64 {
    debug_assert_eq!(basis.len(), x_b.len());
    let weighted = x_b.iter().zip(basis).map(|(&value, &col)| c[col] * value);
    weighted.sum::<f64>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::LuBasis;
    use crate::sparse::CscMatrix;

    /// Builds a small test problem whose first `m` columns are the identity.
    ///
    /// Each extra column `j` in `m..n` has a single entry `2.0` in row
    /// `(j - m) % m`. The returned basis is `0..m` (the identity columns) and
    /// the cost vector is `c[j] = j + 1`. The `% m` requires `m > 0` whenever
    /// `n > m`.
    /// NOTE(review): assumes `from_triplets(.., m, n)` means `m` rows by `n` columns.
    fn make_identity_plus(n: usize, m: usize) -> (CscMatrix, Vec<f64>, Vec<usize>) {
        // Parallel (row, col, value) triplet lists.
        let mut rows: Vec<usize> = Vec::new();
        let mut cols: Vec<usize> = Vec::new();
        let mut vals: Vec<f64> = Vec::new();
        // Identity block: columns 0..m.
        for j in 0..m {
            rows.push(j);
            cols.push(j);
            vals.push(1.0);
        }
        // Extra columns: one 2.0 each, cycling through the rows.
        for j in m..n {
            rows.push((j - m) % m);
            cols.push(j);
            vals.push(2.0);
        }
        let a = CscMatrix::from_triplets(&rows, &cols, &vals, m, n).unwrap();
        let basis: Vec<usize> = (0..m).collect();
        let c: Vec<f64> = (0..n).map(|j| (j as f64) + 1.0).collect();
        (a, c, basis)
    }

    // With B = I the system y^T B = c_B gives y = c_B = c[0..m].
    #[test]
    fn dual_vars_identity_basis_returns_c_b() {
        let m = 4;
        let (a, c, basis) = make_identity_plus(m + 3, m);
        let mut bm = LuBasis::new(&a, &basis, 32).unwrap();
        let y = compute_dual_vars(&c, &mut bm, &basis, m);
        for i in 0..m {
            assert!((y[i] - c[i]).abs() < 1e-12, "y[{}] = {} expected {}", i, y[i], c[i]);
        }
    }

    // The `_into` variant must match the allocating one and must not depend on
    // whatever the output buffer held before (pre-filled with 999.0, then -42.0).
    #[test]
    fn dual_vars_into_matches_allocating_and_is_reuse_safe() {
        let m = 4;
        let (a, c, basis) = make_identity_plus(m + 3, m);
        let mut bm = LuBasis::new(&a, &basis, 32).unwrap();
        let y_alloc = compute_dual_vars(&c, &mut bm, &basis, m);
        let mut y_into = vec![999.0f64; m];
        compute_dual_vars_into(&c, &mut bm, &basis, &mut y_into);
        for i in 0..m {
            assert!((y_into[i] - y_alloc[i]).abs() < 1e-14);
        }
        // Dirty the buffer again and reuse it: the result must be unchanged.
        for slot in y_into.iter_mut() { *slot = -42.0; }
        compute_dual_vars_into(&c, &mut bm, &basis, &mut y_into);
        for i in 0..m {
            assert!((y_into[i] - y_alloc[i]).abs() < 1e-14);
        }
    }

    // With y = c_B and column j = 2 * e_{(j-m)%m}, the closed form is
    // r_j = c_j - 2 * c[(j - m) % m] for nonbasic j, and exactly 0 for basic j.
    #[test]
    fn reduced_costs_identity_basis_match_closed_form() {
        let m = 3;
        let n = m + 3;
        let (a, c, basis) = make_identity_plus(n, m);
        let is_basic: Vec<bool> = (0..n).map(|j| j < m).collect();
        let mut bm = LuBasis::new(&a, &basis, 32).unwrap();
        let r = compute_reduced_costs(&a, &c, &mut bm, &is_basic, n, m, &basis);
        for j in 0..m {
            assert_eq!(r[j], 0.0);
        }
        for j in m..n {
            let expected = c[j] - 2.0 * c[(j - m) % m];
            assert!((r[j] - expected).abs() < 1e-12, "r[{}] = {} expected {}", j, r[j], expected);
        }
    }

    // `rc_out` is pre-filled with 123.456 to prove every slot is overwritten,
    // including basic slots, which must be reset to exactly 0.0.
    #[test]
    fn reduced_costs_into_matches_allocating_and_clears_basic_slots() {
        let m = 3;
        let n = m + 3;
        let (a, c, basis) = make_identity_plus(n, m);
        let is_basic: Vec<bool> = (0..n).map(|j| j < m).collect();
        let mut bm = LuBasis::new(&a, &basis, 32).unwrap();
        let r_alloc = compute_reduced_costs(&a, &c, &mut bm, &is_basic, n, m, &basis);
        let mut y_buf = vec![0.0f64; m];
        let mut rc_out = vec![123.456f64; n];
        compute_reduced_costs_into(&a, &c, &mut bm, &is_basic, n, &basis, &mut y_buf, &mut rc_out);
        for j in 0..n {
            assert!((rc_out[j] - r_alloc[j]).abs() < 1e-14, "j={}", j);
        }
        for j in 0..m {
            assert_eq!(rc_out[j], 0.0, "basic slot {} not zeroed", j);
        }
    }

    // Zero costs give zero duals, so every reduced cost must vanish.
    #[test]
    fn reduced_costs_zero_cost_yields_zero_vector() {
        let m = 3;
        let n = m + 2;
        let (a, _c, basis) = make_identity_plus(n, m);
        let c = vec![0.0f64; n];
        let is_basic: Vec<bool> = (0..n).map(|j| j < m).collect();
        let mut bm = LuBasis::new(&a, &basis, 32).unwrap();
        let r = compute_reduced_costs(&a, &c, &mut bm, &is_basic, n, m, &basis);
        for &rj in &r {
            assert!(rj.abs() < 1e-14, "r = {:?} should be all zero", r);
        }
    }

    // With a permuted basis order [2, 0, 1] on the identity, the duals must
    // satisfy y^T a_{basis[i]} = c[basis[i]] for each basis position i. This
    // catches code that indexes costs by position instead of by basis[i].
    #[test]
    fn dual_vars_permuted_basis_uses_basis_indexing() {
        let m = 3;
        let n = m;
        let rows: Vec<usize> = (0..m).collect();
        let cols: Vec<usize> = (0..m).collect();
        let vals: Vec<f64> = vec![1.0; m];
        let a = CscMatrix::from_triplets(&rows, &cols, &vals, m, n).unwrap();
        let basis = vec![2usize, 0, 1];
        let c = vec![10.0, 20.0, 30.0];
        let mut bm = LuBasis::new(&a, &basis, 32).unwrap();
        let y = compute_dual_vars(&c, &mut bm, &basis, m);
        for i in 0..m {
            // Sparse dot product y^T a_{basis[i]}.
            let (rs, vs) = a.get_column(basis[i]).unwrap();
            let mut dot = 0.0;
            for (k, &row) in rs.iter().enumerate() {
                dot += y[row] * vs[k];
            }
            assert!((dot - c[basis[i]]).abs() < 1e-12,
                "y^T a_{{basis[{}]}} = {} expected {}", i, dot, c[basis[i]]);
        }
    }

    // c = [1..6], basis = 0..4, x_b = [1, 2, 3, 4] gives 1 + 4 + 9 + 16 = 30.
    #[test]
    fn basic_obj_identity_basis() {
        let m = 4;
        let (_a, c, basis) = make_identity_plus(m + 2, m);
        let x_b = vec![1.0, 2.0, 3.0, 4.0];
        let obj = basic_obj(&c, &basis, &x_b);
        let expected: f64 = (0..m).map(|i| c[basis[i]] * x_b[i]).sum();
        assert!((obj - expected).abs() < 1e-14);
        assert!((obj - 30.0).abs() < 1e-14);
    }

    // Costs are looked up via basis[i]: 30*1 + 10*2 + 20*3 = 110.
    #[test]
    fn basic_obj_permuted_basis() {
        let basis = vec![2usize, 0, 1];
        let c = vec![10.0, 20.0, 30.0];
        let x_b = vec![1.0, 2.0, 3.0];
        let obj = basic_obj(&c, &basis, &x_b);
        assert!((obj - 110.0).abs() < 1e-14);
    }

    // An empty basis contributes nothing.
    #[test]
    fn basic_obj_empty_basis() {
        let c = vec![1.0, 2.0, 3.0];
        let basis: Vec<usize> = vec![];
        let x_b: Vec<f64> = vec![];
        assert_eq!(basic_obj(&c, &basis, &x_b), 0.0);
    }

    // Negative basic values keep their sign: -1 + 4 - 9 = -6.
    #[test]
    fn basic_obj_negative_x_b_signs_preserved() {
        let c = vec![1.0, 2.0, 3.0];
        let basis = vec![0usize, 1, 2];
        let x_b = vec![-1.0, 2.0, -3.0];
        assert!((basic_obj(&c, &basis, &x_b) - (-6.0)).abs() < 1e-14);
    }
}