use pounce_common::types::{Index, Number};
use crate::block_solve::{lu_factor_partial_pivot, lu_solve};
/// Outcome of [`project_inequalities`]: one [`ProjectedRow`] per coupled
/// inequality row, plus an aggregate flag.
#[derive(Debug, Clone)]
pub struct ProjectionResult {
/// `true` iff every entry of `per_row` has `implied == true`
/// (vacuously `true` when `coupled_ineq_rows` was empty).
pub all_implied: bool,
/// Per-row projections, in the same order as the `coupled_ineq_rows`
/// argument of [`project_inequalities`].
pub per_row: Vec<ProjectedRow>,
}
/// One coupled inequality row after the equality block has been eliminated,
/// i.e. `rhs_l <= coef · x <= rhs_u`, together with the range `coef · x` can
/// take over the variable bounds.
#[derive(Debug, Clone)]
pub struct ProjectedRow {
/// The row index `r` exactly as it appeared in `coupled_ineq_rows`.
pub inner_row: usize,
/// Projected coefficients, indexed by original variable number (length
/// `n_vars`). Entries for the eliminated block columns are always zero.
pub coef: Vec<Number>,
/// Projected lower side: `g_l[r]` minus the linearisation constant of row
/// `r` and minus the block columns' contribution at the particular solution.
pub rhs_l: Number,
/// Projected upper side, built like `rhs_l` but from `g_u[r]`.
pub rhs_u: Number,
/// Minimum of `coef · x` over the box `[x_l, x_u]`; `-∞` if any variable
/// with a nonzero coefficient has a non-finite bound.
pub activity_lo: Number,
/// Maximum of `coef · x` over the box `[x_l, x_u]`; `+∞` if any variable
/// with a nonzero coefficient has a non-finite bound.
pub activity_hi: Number,
/// `rhs_l <= activity_lo && activity_hi <= rhs_u`: the bounds alone
/// guarantee the projected (linearised) inequality.
pub implied: bool,
}
/// Eliminates a square equality block from a set of coupled inequality rows
/// and checks, row by row, whether each projected inequality is implied by
/// the variable bounds.
///
/// Every constraint is linearised at `x_probe`:
/// `g_r(x) ≈ const_r + Σ_c J[r, c] · x_c`, where
/// `const_r = g_at_probe[r] − Σ_c J[r, c] · x_probe[c]`.
///
/// The rows in `block_rows` are treated as equalities `g_r(x) = g_l[r]`; only
/// `g_l` is read for them. They are solved for the columns in `block_cols`:
/// `x_B = p + M · x_Y` with `p = J_B⁻¹ (g_l − const)` and `M = −J_B⁻¹ J_Y`.
/// Each row `r` of `coupled_ineq_rows` (`g_l[r] ≤ g_r(x) ≤ g_u[r]`) is then
/// rewritten in the remaining variables as `rhs_l ≤ coef · x ≤ rhs_u`. Its
/// activity range over the box `[x_l, x_u]` is compared against
/// `[rhs_l, rhs_u]`.
///
/// The Jacobian is given as triplets (`jac_irow`, `jac_jcol`, `jac_values`).
/// The indices are converted from 1-based when `one_based` is set.
/// Triplets whose column falls outside `0..n_vars` are ignored.
/// Duplicate `(row, col)` triplets are summed.
///
/// Bounds are tested with `is_finite`. Sentinels such as `±1e19` therefore
/// count as finite and enter the activity sums as ordinary numbers.
///
/// # Returns
/// `None` in any of these cases:
/// - `block_rows` is empty;
/// - `block_rows` and `block_cols` differ in length;
/// - a block column is `>= n_vars`;
/// - a block row has no Jacobian entries;
/// - `lu_factor_partial_pivot` reports an error for the block matrix (for
///   example, when it is singular).
///
/// # Panics
/// Panics if a row index in `block_rows`/`coupled_ineq_rows` is out of range
/// for `g_l`, `g_u` or `g_at_probe`. Also panics if `jac_jcol`/`jac_values`
/// are shorter than `jac_irow`, or if `x_l`, `x_u` or `x_probe` are shorter
/// than `n_vars`.
#[allow(clippy::too_many_arguments)] // flat bound and Jacobian-triplet arrays are taken directly
pub fn project_inequalities(
    block_rows: &[usize],
    block_cols: &[usize],
    coupled_ineq_rows: &[usize],
    n_vars: usize,
    x_l: &[Number],
    x_u: &[Number],
    g_l: &[Number],
    g_u: &[Number],
    jac_irow: &[Index],
    jac_jcol: &[Index],
    jac_values: &[Number],
    g_at_probe: &[Number],
    x_probe: &[Number],
    one_based: bool,
) -> Option<ProjectionResult> {
    let k = block_rows.len();
    if k == 0 || k != block_cols.len() {
        return None;
    }
    // A block column outside `0..n_vars` would otherwise panic when indexing
    // `col_to_block_pos` below.
    if block_cols.iter().any(|&c| c >= n_vars) {
        return None;
    }

    // Group the Jacobian triplets by (0-based) row.
    let to_zero_based = |idx: Index| -> usize {
        if one_based {
            (idx as isize - 1) as usize
        } else {
            idx as usize
        }
    };
    let mut by_row: std::collections::HashMap<usize, Vec<(usize, Number)>> =
        std::collections::HashMap::new();
    for (kk, &ir) in jac_irow.iter().enumerate() {
        let i = to_zero_based(ir);
        let j = to_zero_based(jac_jcol[kk]);
        if j >= n_vars {
            continue;
        }
        by_row.entry(i).or_default().push((j, jac_values[kk]));
    }

    let mut col_to_block_pos: Vec<Option<usize>> = vec![None; n_vars];
    for (pos, &c) in block_cols.iter().enumerate() {
        col_to_block_pos[c] = Some(pos);
    }

    // Assemble the row-major block matrix J_B and the right-hand side
    // g_l − const for the block equalities.
    let mut j_block = vec![0.0; k * k];
    let mut const_block = vec![0.0; k];
    for (i_block, &r) in block_rows.iter().enumerate() {
        let entries = by_row.get(&r)?;
        for &(c, v) in entries {
            if let Some(j_pos) = col_to_block_pos[c] {
                // Accumulate so duplicate triplets are summed, matching how
                // they enter the linearisation constant.
                j_block[i_block * k + j_pos] += v;
            }
        }
        const_block[i_block] = g_l[r] - linearization_offset(entries, x_probe, g_at_probe[r]);
    }

    // Particular solution p = J_B⁻¹ (g_l − const).
    let mut j_block_lu = j_block;
    let piv = lu_factor_partial_pivot(&mut j_block_lu, k).ok()?;
    let mut p = const_block;
    lu_solve(&j_block_lu, &piv, &mut p, k);

    // Columns of M = −J_B⁻¹ J_Y, one per non-block variable that appears in
    // some block row.
    let mut m: std::collections::HashMap<usize, Vec<Number>> = std::collections::HashMap::new();
    for (i_block, &r) in block_rows.iter().enumerate() {
        let entries = by_row.get(&r).expect("every block row was checked above");
        for &(c, v) in entries {
            if col_to_block_pos[c].is_none() {
                m.entry(c).or_insert_with(|| vec![0.0; k])[i_block] += v;
            }
        }
    }
    for col in m.values_mut() {
        lu_solve(&j_block_lu, &piv, col, k);
        for v in col.iter_mut() {
            *v = -*v;
        }
    }

    let mut per_row = Vec::with_capacity(coupled_ineq_rows.len());
    let mut all_implied = true;
    for &r in coupled_ineq_rows {
        let entries: &[(usize, Number)] = by_row.get(&r).map(Vec::as_slice).unwrap_or(&[]);
        let const_r = linearization_offset(entries, x_probe, g_at_probe[r]);
        let gl = g_l[r] - const_r;
        let gu = g_u[r] - const_r;

        // Split the row into its block part a_B and its free part a_Y (the
        // latter accumulated directly into `coef`).
        let mut a_b = vec![0.0; k];
        let mut coef = vec![0.0; n_vars];
        for &(c, v) in entries {
            match col_to_block_pos[c] {
                Some(pos) => a_b[pos] += v,
                None => coef[c] += v,
            }
        }

        let a_b_dot_p = a_b.iter().zip(&p).fold(0.0, |acc, (a, pv)| acc + a * pv);
        let rhs_l = gl - a_b_dot_p;
        let rhs_u = gu - a_b_dot_p;

        // coef = a_Y + a_B · M
        for (&c, m_col) in &m {
            coef[c] += a_b.iter().zip(m_col).fold(0.0, |acc, (a, mv)| acc + a * mv);
        }

        let (activity_lo, activity_hi) = box_activity_range(&coef, x_l, x_u);
        let implied = rhs_l <= activity_lo && activity_hi <= rhs_u;
        all_implied &= implied;
        per_row.push(ProjectedRow {
            inner_row: r,
            coef,
            rhs_l,
            rhs_u,
            activity_lo,
            activity_hi,
            implied,
        });
    }
    Some(ProjectionResult {
        all_implied,
        per_row,
    })
}

/// Constant term of the linearisation at `x_probe` of a row whose value
/// there is `g_r`: `g_r − Σ v · x_probe[c]` over the row's `(c, v)` entries.
fn linearization_offset(entries: &[(usize, Number)], x_probe: &[Number], g_r: Number) -> Number {
    let mut sum_jx = 0.0;
    for &(c, v) in entries {
        sum_jx += v * x_probe[c];
    }
    g_r - sum_jx
}

/// Range `[lo, hi]` of `Σ_c coef[c] · x_c` over the box `x_l ≤ x ≤ x_u`.
///
/// Zero coefficients are skipped. If any variable with a nonzero
/// coefficient has a non-finite lower or upper bound, `(−∞, +∞)` is returned.
fn box_activity_range(coef: &[Number], x_l: &[Number], x_u: &[Number]) -> (Number, Number) {
    let mut lo: Number = 0.0;
    let mut hi: Number = 0.0;
    for (c, &v) in coef.iter().enumerate() {
        if v == 0.0 {
            continue;
        }
        let (xl, xu) = (x_l[c], x_u[c]);
        if !xl.is_finite() || !xu.is_finite() {
            return (Number::NEG_INFINITY, Number::INFINITY);
        }
        if v > 0.0 {
            lo += v * xl;
            hi += v * xu;
        } else {
            lo += v * xu;
            hi += v * xl;
        }
    }
    (lo, hi)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn projection_implied_singleton() {
        let result = project_inequalities(
            &[0],
            &[0],
            &[1],
            2,
            &[-1e19, 0.0],
            &[1e19, 1.0],
            &[3.0, -10.0],
            &[3.0, 10.0],
            &[0, 1],
            &[0, 0],
            &[1.0, 1.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            false,
        )
        .expect("non-singular");
        assert!(result.all_implied);
        let row = &result.per_row[0];
        assert_eq!(row.coef[0], 0.0);
        assert!(row.coef[1].abs() < 1e-12);
        assert!((row.rhs_l - (-13.0)).abs() < 1e-12);
        assert!((row.rhs_u - 7.0).abs() < 1e-12);
        assert!(row.implied);
    }

    #[test]
    fn projection_not_implied_tight() {
        let result = project_inequalities(
            &[0],
            &[0],
            &[1],
            2,
            &[-1e19, 0.0],
            &[1e19, 1.0],
            &[5.0, -1.0],
            &[5.0, 1.0],
            &[0, 1],
            &[0, 0],
            &[1.0, 1.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            false,
        )
        .expect("non-singular");
        assert!(!result.all_implied);
        assert!(!result.per_row[0].implied);
    }

    #[test]
    fn projection_2x2_block_implied() {
        let result = project_inequalities(
            &[0, 1],
            &[0, 1],
            &[2],
            3,
            &[-1e19, -1e19, 0.0],
            &[1e19, 1e19, 1.0],
            &[3.0, 1.0, 0.0],
            &[3.0, 1.0, 100.0],
            &[0, 0, 1, 1, 2, 2],
            &[0, 1, 0, 1, 0, 1],
            &[1.0, 1.0, 1.0, -1.0, 1.0, 1.0],
            &[0.0, 0.0, 0.0],
            &[0.0, 0.0, 0.0],
            false,
        )
        .expect("non-singular");
        assert!(result.all_implied);
    }

    #[test]
    fn projection_unbounded_var_in_y_blocks_admit() {
        // Note: 1e19 is finite, so this rejects on magnitude (activity_hi ≈ 1e19
        // exceeds rhs_u), not via the non-finite-bound branch. See
        // `projection_infinite_bound_gives_unbounded_activity` for that branch.
        let result = project_inequalities(
            &[0],
            &[0],
            &[1],
            2,
            &[-1e19, 0.0],
            &[1e19, 1e19],
            &[3.0, -1e19],
            &[3.0, 100.0],
            &[0, 1, 1],
            &[0, 0, 1],
            &[1.0, 1.0, 1.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            false,
        )
        .expect("non-singular");
        assert!(!result.all_implied);
    }

    #[test]
    fn projection_infinite_bound_gives_unbounded_activity() {
        let result = project_inequalities(
            &[0],
            &[0],
            &[1],
            2,
            &[-1e19, 0.0],
            &[1e19, Number::INFINITY],
            &[3.0, -1e19],
            &[3.0, 100.0],
            &[0, 1, 1],
            &[0, 0, 1],
            &[1.0, 1.0, 1.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            false,
        )
        .expect("non-singular");
        let row = &result.per_row[0];
        assert_eq!(row.activity_lo, Number::NEG_INFINITY);
        assert_eq!(row.activity_hi, Number::INFINITY);
        assert!(!row.implied);
        assert!(!result.all_implied);
    }

    #[test]
    fn projection_eliminates_block_variable_into_coefficients() {
        // Block row 0: x0 + 2*x1 = 4, solved for x0 = 4 - 2*x1.
        // Coupled row 1: 1 <= x0 <= 5  becomes  -3 <= -2*x1 <= 1, with x1 in [0, 1].
        let result = project_inequalities(
            &[0],
            &[0],
            &[1],
            2,
            &[-1e19, 0.0],
            &[1e19, 1.0],
            &[4.0, 1.0],
            &[4.0, 5.0],
            &[0, 0, 1],
            &[0, 1, 0],
            &[1.0, 2.0, 1.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            false,
        )
        .expect("non-singular");
        let row = &result.per_row[0];
        assert_eq!(row.inner_row, 1);
        assert_eq!(row.coef[0], 0.0);
        assert!((row.coef[1] - (-2.0)).abs() < 1e-12);
        assert!((row.rhs_l - (-3.0)).abs() < 1e-12);
        assert!((row.rhs_u - 1.0).abs() < 1e-12);
        assert!((row.activity_lo - (-2.0)).abs() < 1e-12);
        assert!(row.activity_hi.abs() < 1e-12);
        assert!(row.implied);
        assert!(result.all_implied);
    }

    #[test]
    fn projection_one_based_indices_match_zero_based() {
        let result = project_inequalities(
            &[0],
            &[0],
            &[1],
            2,
            &[-1e19, 0.0],
            &[1e19, 1.0],
            &[3.0, -10.0],
            &[3.0, 10.0],
            &[1, 2],
            &[1, 1],
            &[1.0, 1.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            true,
        )
        .expect("non-singular");
        let row = &result.per_row[0];
        assert!((row.rhs_l - (-13.0)).abs() < 1e-12);
        assert!((row.rhs_u - 7.0).abs() < 1e-12);
        assert!(row.implied);
    }

    #[test]
    fn projection_singular_block() {
        let r = project_inequalities(
            &[0, 1],
            &[0, 1],
            &[],
            2,
            &[-1e19, -1e19],
            &[1e19, 1e19],
            &[0.0, 0.0],
            &[0.0, 0.0],
            &[0, 0, 1, 1],
            &[0, 1, 0, 1],
            &[1.0, 2.0, 2.0, 4.0],
            &[0.0, 0.0],
            &[0.0, 0.0],
            false,
        );
        assert!(r.is_none());
    }
}