use crate::sparse::CscMatrix;
/// Relative pivot threshold for the crash: an accepted pivot must be at least
/// this fraction of the largest absolute entry in its column.
const CRASH_PIVOT_REL: f64 = 0.1;
/// Absolute pivot threshold for the crash: columns whose largest absolute entry
/// is below this are skipped, and no accepted pivot may be smaller than this.
const CRASH_PIVOT_ABS: f64 = 1e-8;

/// Replaces artificial basis slots with structural columns where possible.
///
/// * Rows with `needs_artificial_in[i] == false` keep `initial_basis_in[i]`
///   unchanged, and that column is excluded from the search.
/// * Candidate columns are `0..n_shifted`, visited through `LtsfState`, which
///   hands out the column touching the fewest still-uncovered rows first. Each
///   candidate is visited at most once; when visited it may claim one
///   still-uncovered row as its pivot row (acceptance rules are documented on
///   `select_pivot_row`).
///
/// Returns `(basis, needs_artificial, num_artificial)`, where `num_artificial`
/// is the number of rows still flagged in `needs_artificial`. When no row needs
/// an artificial on entry, the inputs are returned unchanged with a count of 0.
///
/// NOTE(review): a column is accepted even when it touches several uncovered
/// rows, and the rows it does not pivot on stay eligible for later columns. The
/// returned basis is therefore not guaranteed to be triangular and can be
/// singular (e.g. when two rows of `A` are identical, any full cover is
/// singular). Presumably the caller's factorization detects and repairs this —
/// confirm before relying on the crash basis being nonsingular.
///
/// # Panics
///
/// Panics if the column recorded for a covered row is not a valid column index
/// of `a`, or (when any row needs an artificial) if `n_shifted > a.ncols`.
pub(crate) fn compute_crash_basis(
    a: &CscMatrix,
    b: &[f64],
    m: usize,
    n_shifted: usize,
    initial_basis_in: &[usize],
    needs_artificial_in: &[bool],
) -> (Vec<usize>, Vec<bool>, usize) {
    debug_assert_eq!(initial_basis_in.len(), m);
    debug_assert_eq!(needs_artificial_in.len(), m);
    debug_assert_eq!(b.len(), m);

    let mut basis = initial_basis_in.to_vec();
    let mut needs_artificial = needs_artificial_in.to_vec();

    // A row is covered when it already has a non-artificial basic column; those
    // columns are taken and must not be offered to the crash again.
    let mut row_covered: Vec<bool> = needs_artificial.iter().map(|&v| !v).collect();
    let mut col_used = vec![false; a.ncols];
    for (i, &covered) in row_covered.iter().enumerate() {
        if covered {
            col_used[basis[i]] = true;
        }
    }

    if !needs_artificial.contains(&true) {
        return (basis, needs_artificial, 0);
    }

    // LtsfState slices `col_used[..n_shifted]` and walks columns 0..n_shifted;
    // fail early with a readable message instead of an opaque index panic.
    debug_assert!(
        n_shifted <= a.ncols,
        "n_shifted ({}) exceeds the number of columns of A ({})",
        n_shifted,
        a.ncols
    );

    let mut state = LtsfState::new(a, n_shifted, &row_covered, &col_used);
    // Every popped column is marked consumed by the state, so it is either
    // placed in the basis here or never reconsidered.
    while let Some(j) = state.pop_min_active_column() {
        if let Some(row) = select_pivot_row(a, b, j, &row_covered) {
            basis[row] = j;
            needs_artificial[row] = false;
            row_covered[row] = true;
            // Lower the active counts of the other columns touching `row`.
            state.cover_row(row);
        }
    }

    let num_artificial = needs_artificial.iter().filter(|&&v| v).count();
    (basis, needs_artificial, num_artificial)
}

/// Chooses the pivot row for candidate column `j`, or `None` if it has none.
///
/// Let `col_max` be the largest `|a_kj|` over the whole column (covered rows
/// included). The column is rejected outright when `col_max < CRASH_PIVOT_ABS`.
/// Otherwise an entry `a_ij` is acceptable when row `i` is not yet covered,
/// `|a_ij| >= max(CRASH_PIVOT_REL * col_max, CRASH_PIVOT_ABS)`, and `a_ij` has
/// the same sign as `b[i]` (any sign is accepted when `b[i] == 0`). The
/// acceptable entry with the largest magnitude wins; on ties the first one in
/// storage order is kept.
fn select_pivot_row(a: &CscMatrix, b: &[f64], j: usize, row_covered: &[bool]) -> Option<usize> {
    let entries = a.col_ptr[j]..a.col_ptr[j + 1];

    // `f64::max` returns the non-NaN operand, so NaN entries are ignored exactly
    // as a strict `>` scan would ignore them.
    let col_max_abs = entries
        .clone()
        .map(|k| a.values[k].abs())
        .fold(0.0_f64, f64::max);
    if col_max_abs < CRASH_PIVOT_ABS {
        return None;
    }
    let pivot_min = (CRASH_PIVOT_REL * col_max_abs).max(CRASH_PIVOT_ABS);

    let mut best_row = None;
    let mut best_abs = 0.0_f64;
    for k in entries {
        let row = a.row_ind[k];
        if row_covered[row] {
            continue;
        }
        let val = a.values[k];
        let abs = val.abs();
        if abs < pivot_min {
            continue;
        }
        // Require a_ij and b_i to share a sign (so b_i / a_ij > 0), presumably
        // so the crashed variable starts non-negative.
        let bi = b[row];
        if bi != 0.0 && val.signum() != bi.signum() {
            continue;
        }
        // Strict `>`: the first row wins ties, and NaN entries are never chosen.
        if abs > best_abs {
            best_abs = abs;
            best_row = Some(row);
        }
    }
    best_row
}
/// Bucket priority queue that orders the crash's candidate columns by their
/// *active count*: the number of their nonzeros lying in still-uncovered rows.
///
/// Columns sit in `buckets[count]`. When a row becomes covered, every live
/// column touching it moves down one bucket; the entry it leaves behind is
/// stale and is discarded lazily when popped.
struct LtsfState {
    /// Row-wise index: the columns touching row `r` are
    /// `row_cols[row_ptr[r]..row_ptr[r + 1]]`.
    row_ptr: Vec<usize>,
    row_cols: Vec<usize>,
    /// Current active count of every candidate column `j < n_shifted`.
    col_active: Vec<usize>,
    /// `buckets[k]` holds columns queued when their active count was `k`
    /// (possibly stale). Bucket 0 is never filled.
    buckets: Vec<Vec<usize>>,
    /// No non-empty bucket has an index below this.
    min_k: usize,
    /// Columns that can no longer be returned by `pop_min_active_column`.
    col_consumed: Vec<bool>,
}

impl LtsfState {
    /// Builds the queue over candidate columns `0..n_shifted`.
    ///
    /// Columns flagged in `col_used` start consumed with a count of 0. Every
    /// other column is queued in the bucket of its active count with respect to
    /// `row_covered`, unless that count is 0. The row-wise index covers all of
    /// `0..n_shifted`, used columns included.
    fn new(a: &CscMatrix, n_shifted: usize, row_covered: &[bool], col_used: &[bool]) -> Self {
        let nrows = a.nrows;
        let entries = |j: usize| a.col_ptr[j]..a.col_ptr[j + 1];

        let col_active: Vec<usize> = (0..n_shifted)
            .map(|j| {
                if col_used[j] {
                    0
                } else {
                    entries(j)
                        .filter(|&k| !row_covered[a.row_ind[k]])
                        .count()
                }
            })
            .collect();
        let max_k = col_active.iter().copied().max().unwrap_or(0);

        // Used columns carry a count of 0, so the `cnt > 0` test skips them too.
        let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); max_k + 1];
        for (j, &cnt) in col_active.iter().enumerate() {
            if cnt > 0 {
                buckets[cnt].push(j);
            }
        }

        // Counting-sort transpose: tally per row, prefix-sum, then scatter.
        let mut row_ptr = vec![0usize; nrows + 1];
        for j in 0..n_shifted {
            for k in entries(j) {
                row_ptr[a.row_ind[k] + 1] += 1;
            }
        }
        for r in 0..nrows {
            row_ptr[r + 1] += row_ptr[r];
        }
        let mut next = row_ptr[..nrows].to_vec();
        let mut row_cols = vec![0usize; row_ptr[nrows]];
        for j in 0..n_shifted {
            for k in entries(j) {
                let slot = &mut next[a.row_ind[k]];
                row_cols[*slot] = j;
                *slot += 1;
            }
        }

        Self {
            row_ptr,
            row_cols,
            col_active,
            buckets,
            min_k: 1,
            col_consumed: col_used[..n_shifted].to_vec(),
        }
    }

    /// Returns an unconsumed column with the smallest current active count and
    /// marks it consumed, or `None` when no such column remains. Among columns
    /// in the same bucket, the most recently queued one comes out first.
    fn pop_min_active_column(&mut self) -> Option<usize> {
        loop {
            while matches!(self.buckets.get(self.min_k), Some(bucket) if bucket.is_empty()) {
                self.min_k += 1;
            }
            // Past the last bucket: nothing is left to hand out.
            let j = self.buckets.get_mut(self.min_k)?.pop()?;
            let stale = self.col_consumed[j] || self.col_active[j] != self.min_k;
            if !stale {
                self.col_consumed[j] = true;
                return Some(j);
            }
        }
    }

    /// Marks row `r` as covered. Every unconsumed column touching `r` loses one
    /// active count. Columns reaching zero are consumed; the others are
    /// re-queued in their new bucket.
    ///
    /// Presumably this must be called only for rows that were uncovered at
    /// construction, and at most once each. Otherwise counts are decremented
    /// below their true value and may underflow.
    fn cover_row(&mut self, r: usize) {
        let (lo, hi) = (self.row_ptr[r], self.row_ptr[r + 1]);
        for &j in &self.row_cols[lo..hi] {
            if self.col_consumed[j] {
                continue;
            }
            let remaining = self.col_active[j] - 1;
            self.col_active[j] = remaining;
            if remaining == 0 {
                // No active rows left: the column can never pivot again.
                self.col_consumed[j] = true;
            } else {
                self.buckets[remaining].push(j);
                self.min_k = self.min_k.min(remaining);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CscMatrix;
    use std::collections::HashSet;

    /// Runs `compute_crash_basis` with one row per entry of `b` (`m = b.len()`).
    fn crash(
        a: &CscMatrix,
        b: &[f64],
        n_shifted: usize,
        initial: &[usize],
        needs: &[bool],
    ) -> (Vec<usize>, Vec<bool>, usize) {
        compute_crash_basis(a, b, b.len(), n_shifted, initial, needs)
    }

    /// Asserts that no column index appears twice in `basis`.
    fn assert_distinct_columns(basis: &[usize]) {
        let mut seen = HashSet::new();
        for &c in basis {
            assert!(seen.insert(c), "duplicate column in basis: {:?}", basis);
        }
    }

    #[test]
    fn diagonal_crash_eliminates_all_artificials() {
        // 3x3 identity with a positive right-hand side: each column covers its row.
        let a = CscMatrix::from_triplets(&[0, 1, 2], &[0, 1, 2], &[1.0, 1.0, 1.0], 3, 3).unwrap();
        let (basis, needs_out, num_art) = crash(&a, &[1.0, 2.0, 3.0], 3, &[0, 0, 0], &[true; 3]);
        assert_eq!(num_art, 0, "全行被覆可能");
        assert_eq!(basis, vec![0, 1, 2]);
        assert_eq!(needs_out, vec![false; 3]);
    }

    #[test]
    fn small_pivot_column_rejected() {
        // The only entry (1e-12) is below the absolute pivot threshold.
        let a = CscMatrix::from_triplets(&[0], &[0], &[1e-12], 1, 1).unwrap();
        let (_, needs_out, num_art) = crash(&a, &[1.0], 1, &[0], &[true]);
        assert_eq!(num_art, 1, "tiny pivot は被覆しない");
        assert_eq!(needs_out, vec![true]);
    }

    #[test]
    fn covered_rows_kept_as_is() {
        // Row 0 is already covered by column 2; only columns 0 and 1 are candidates.
        let a = CscMatrix::from_triplets(&[0, 1, 1, 0], &[0, 0, 1, 2], &[1.0, 2.0, 0.5, 1.0], 2, 3)
            .unwrap();
        let (basis, needs_out, num_art) = crash(&a, &[1.0, 1.0], 2, &[2, 0], &[false, true]);
        assert_eq!(num_art, 0);
        assert_eq!(basis[0], 2, "行 0 の slack basis 維持");
        assert!(basis[1] == 0 || basis[1] == 1, "行 1 は構造列で被覆");
        assert_eq!(needs_out, vec![false, false]);
    }

    #[test]
    fn partial_coverage() {
        // 2x1 matrix: the single column covers row 0; row 1 has no entries.
        let a = CscMatrix::from_triplets(&[0], &[0], &[1.0], 2, 1).unwrap();
        let (basis, needs_out, num_art) = crash(&a, &[1.0, 1.0], 1, &[0, 0], &[true, true]);
        assert_eq!(num_art, 1);
        assert_eq!(basis[0], 0);
        assert!(needs_out[1], "行 1 は artificial 必要");
    }

    #[test]
    fn sign_mismatch_rejected() {
        // Positive pivot against a negative right-hand side.
        let a = CscMatrix::from_triplets(&[0], &[0], &[1.0], 1, 1).unwrap();
        let (_, _, num_art) = crash(&a, &[-1.0], 1, &[0], &[true]);
        assert_eq!(num_art, 1, "符号不一致行は被覆しない");
    }

    #[test]
    fn sign_match_accepted() {
        // Negative pivot with a negative right-hand side.
        let a = CscMatrix::from_triplets(&[0], &[0], &[-1.0], 1, 1).unwrap();
        let (basis, _, num_art) = crash(&a, &[-1.0], 1, &[0], &[true]);
        assert_eq!(num_art, 0);
        assert_eq!(basis[0], 0);
    }

    #[test]
    fn ltsf_singleton_chase_covers_all_rows() {
        // col0 -> rows {0,1,2}, col1 -> rows {0,1}, col2 -> row {0}.
        let rows = vec![0, 1, 2, 0, 1, 0];
        let cols = vec![0, 0, 0, 1, 1, 2];
        let vals = vec![10.0, 1.0, 1.0, 10.0, 1.0, 10.0];
        let a = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        let (basis, needs_out, num_art) = crash(&a, &[1.0; 3], 3, &[0; 3], &[true; 3]);
        assert_eq!(
            num_art, 0,
            "LTSF should chase singletons and cover all rows"
        );
        assert_eq!(needs_out, vec![false; 3]);
        assert_distinct_columns(&basis);
    }

    #[test]
    fn ltsf_dynamic_repriority_full_cover() {
        // NOTE(review): rows 1 and 2 of this matrix are identical (5 in col 0,
        // 3 in cols 1 and 2), so the full cover asserted below is necessarily
        // a singular basis.
        let rows = vec![0, 1, 2, 3, 0, 1, 2, 1, 2, 3, 0, 3];
        let cols = vec![0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 4];
        let vals = vec![5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0];
        let a = CscMatrix::from_triplets(&rows, &cols, &vals, 4, 5).unwrap();
        let (basis, _, num_art) = crash(&a, &[1.0; 4], 5, &[0; 4], &[true; 4]);
        assert_eq!(
            num_art, 0,
            "dynamic re-priority should cover all 4 rows; basis={:?}",
            basis
        );
        assert_distinct_columns(&basis);
    }

    #[test]
    fn ltsf_basis_columns_unique_and_in_range() {
        // Column 0 is dense; columns 1..=5 each touch two cyclically adjacent rows.
        let rows = vec![0, 1, 2, 3, 4, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0];
        let cols = vec![0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5];
        let vals = vec![1.0; 15];
        let a = CscMatrix::from_triplets(&rows, &cols, &vals, 5, 6).unwrap();
        let (basis, needs_out, num_art) = crash(&a, &[1.0; 5], 6, &[0; 5], &[true; 5]);

        // Only rows that were crashed carry a meaningful structural column.
        let mut seen = HashSet::new();
        for (i, &c) in basis.iter().enumerate().filter(|&(i, _)| !needs_out[i]) {
            assert!(c < 6, "basis[{}]={} out of range", i, c);
            assert!(seen.insert(c), "duplicate basis column {}", c);
        }
        assert!(
            num_art <= 1,
            "LTSF should cover ≥ 4 rows out of 5; got num_art={}",
            num_art
        );
    }
}