use crate::error::{OptimizeError, OptimizeResult};
/// Assemble the (n+m+p)×(n+m+p) KKT differential matrix
///
/// ```text
/// [ Q        Gᵀ          Aᵀ ]
/// [ D(λ)G    D(Gx)       0  ]
/// [ A        0           0  ]
/// ```
///
/// where n = |x|, m = |lam|, p = |nu|. The middle diagonal holds (Gx)_i
/// only; callers subtract h via `adjust_complementarity_diagonal` to get
/// the complementarity term D(Gx − h). Entries missing from `q`, `g`, or
/// `a` (short rows / few rows) are treated as zero.
pub fn compute_kkt_jacobian(
    q: &[Vec<f64>],
    g: &[Vec<f64>],
    a: &[Vec<f64>],
    x: &[f64],
    lam: &[f64],
    nu: &[f64],
) -> Vec<Vec<f64>> {
    let n = x.len();
    let m = lam.len();
    let p = nu.len();
    let dim = n + m + p;
    let mut jac = vec![vec![0.0; dim]; dim];

    // Fetch mat[r][c], treating out-of-range entries as zero.
    let entry = |mat: &[Vec<f64>], r: usize, c: usize| -> f64 {
        mat.get(r).and_then(|row| row.get(c)).copied().unwrap_or(0.0)
    };

    // Top-left n×n block: the Hessian Q.
    for row in 0..n {
        for col in 0..n {
            jac[row][col] = entry(q, row, col);
        }
    }
    // Top-middle n×m block: Gᵀ.
    for col in 0..m {
        for row in 0..n {
            jac[row][n + col] = entry(g, col, row);
        }
    }
    // Top-right n×p block: Aᵀ.
    for col in 0..p {
        for row in 0..n {
            jac[row][n + m + col] = entry(a, col, row);
        }
    }
    // Middle-left m×n block: each G row scaled by its multiplier, D(λ)G.
    for row in 0..m {
        for col in 0..n {
            jac[n + row][col] = lam[row] * entry(g, row, col);
        }
    }
    // Middle diagonal: (Gx)_i; `adjust_complementarity_diagonal` subtracts h_i.
    for row in 0..m {
        let dot: f64 = match g.get(row) {
            Some(g_row) => g_row.iter().zip(x.iter()).map(|(gv, xv)| gv * xv).sum(),
            None => 0.0,
        };
        jac[n + row][n + row] = dot;
    }
    // Bottom-left p×n block: A.
    for row in 0..p {
        for col in 0..n {
            jac[n + m + row][col] = entry(a, row, col);
        }
    }
    jac
}
/// Subtract h_i from the i-th complementarity diagonal entry, turning the
/// (Gx)_i values written by `compute_kkt_jacobian` into (Gx − h)_i.
/// `n` is the primal dimension, i.e. the row/column offset of the
/// inequality block. Panics if `jac` is smaller than `n + h.len()`.
pub fn adjust_complementarity_diagonal(jac: &mut [Vec<f64>], h: &[f64], n: usize) {
    for i in 0..h.len() {
        let d = n + i;
        jac[d][d] -= h[i];
    }
}
/// Solve the dense linear system `mat · z = rhs` by Gaussian elimination
/// with partial (row) pivoting followed by back substitution.
///
/// # Errors
/// * `InvalidInput` if `mat` is not a square `rhs.len() × rhs.len()`
///   matrix — including ragged rows, which previously caused an index
///   panic mid-elimination instead of a clean error.
/// * `ComputationError` if the best pivot in some column is below 1e-30
///   in magnitude, i.e. the matrix is numerically singular.
pub fn solve_implicit_system(mat: &[Vec<f64>], rhs: &[f64]) -> OptimizeResult<Vec<f64>> {
    let n = rhs.len();
    if mat.len() != n {
        return Err(OptimizeError::InvalidInput(format!(
            "KKT matrix rows ({}) != rhs length ({})",
            mat.len(),
            n
        )));
    }
    // Reject ragged rows up front; a short row would otherwise panic
    // during elimination when indexed at column `col`.
    if let Some((i, row)) = mat.iter().enumerate().find(|(_, r)| r.len() != n) {
        return Err(OptimizeError::InvalidInput(format!(
            "KKT matrix row {} has length {} (expected {})",
            i,
            row.len(),
            n
        )));
    }
    // Build the augmented matrix [mat | rhs].
    let mut aug: Vec<Vec<f64>> = mat
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let mut r = row.clone();
            r.push(rhs[i]);
            r
        })
        .collect();
    // Forward elimination with partial pivoting for numerical stability.
    for col in 0..n {
        let mut max_val = aug[col][col].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let v = aug[row][col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        if max_val < 1e-30 {
            return Err(OptimizeError::ComputationError(
                "Singular KKT matrix in implicit differentiation".to_string(),
            ));
        }
        if max_row != col {
            aug.swap(col, max_row);
        }
        let pivot = aug[col][col];
        for row in (col + 1)..n {
            let factor = aug[row][col] / pivot;
            for j in col..=n {
                let val = aug[col][j];
                aug[row][j] -= factor * val;
            }
        }
    }
    // Back substitution on the upper-triangular system.
    let mut solution = vec![0.0; n];
    for i in (0..n).rev() {
        let mut sum = aug[i][n];
        for j in (i + 1)..n {
            sum -= aug[i][j] * solution[j];
        }
        let diag = aug[i][i];
        // Defensive: pivoting above already guarantees a non-tiny diagonal.
        if diag.abs() < 1e-30 {
            return Err(OptimizeError::ComputationError(
                "Zero diagonal in back substitution".to_string(),
            ));
        }
        solution[i] = sum / diag;
    }
    Ok(solution)
}
/// Solve `mat · z = rhs` for each right-hand side in `rhs_cols`,
/// short-circuiting on the first failure.
pub fn solve_implicit_system_multi(
    mat: &[Vec<f64>],
    rhs_cols: &[Vec<f64>],
) -> OptimizeResult<Vec<Vec<f64>>> {
    let mut solutions = Vec::with_capacity(rhs_cols.len());
    for rhs in rhs_cols {
        solutions.push(solve_implicit_system(mat, rhs)?);
    }
    Ok(solutions)
}
/// Return the indices i (over `h`) whose inequality is tight at `x`,
/// i.e. |h_i − g_i·x| ≤ tol. Missing `g` rows contribute a zero dot
/// product; short rows are dotted over their available entries.
pub fn identify_active_constraints(g: &[Vec<f64>], h: &[f64], x: &[f64], tol: f64) -> Vec<usize> {
    (0..h.len())
        .filter(|&i| {
            let gx = g
                .get(i)
                .map(|row| row.iter().zip(x.iter()).map(|(gv, xv)| gv * xv).sum::<f64>())
                .unwrap_or(0.0);
            (h[i] - gx).abs() <= tol
        })
        .collect()
}
/// Gather the rows of `g` and entries of `h` named by `active`.
/// Indices out of range are silently skipped, so the two outputs can be
/// shorter than `active` (each is filtered independently).
pub fn extract_active_constraints(
    g: &[Vec<f64>],
    h: &[f64],
    active: &[usize],
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let mut rows = Vec::with_capacity(active.len());
    let mut bounds = Vec::with_capacity(active.len());
    for &idx in active {
        if let Some(row) = g.get(idx) {
            rows.push(row.clone());
        }
        if let Some(&bound) = h.get(idx) {
            bounds.push(bound);
        }
    }
    (rows, bounds)
}
/// Differentiate a QP solution through its KKT conditions (OptNet-style):
/// assemble the KKT differential matrix, solve the TRANSPOSED system with
/// rhs [-dl_dx; 0; 0] for the adjoint (dx, dlam, dnu), then form the
/// gradients of the loss w.r.t. the problem data (Q, c, G, h, A, b).
///
/// Dimensions: n = |x| primal variables, m = |lam| inequalities,
/// p = |nu| equalities; `dl_dx` is the loss gradient at the solution `x`.
pub fn compute_full_implicit_gradient(
q: &[Vec<f64>],
g: &[Vec<f64>],
h: &[f64],
a: &[Vec<f64>],
x: &[f64],
lam: &[f64],
nu: &[f64],
dl_dx: &[f64],
) -> OptimizeResult<super::types::ImplicitGradient> {
let n = x.len();
let m = lam.len();
let p = nu.len();
let dim = n + m + p;
// Rows: [Q Gᵀ Aᵀ; D(λ)G D(Gx) 0; A 0 0]; the h-subtraction below turns
// the middle diagonal into the complementarity term D(Gx − h).
let mut kkt = compute_kkt_jacobian(q, g, a, x, lam, nu);
adjust_complementarity_diagonal(&mut kkt, h, n);
// The adjoint method requires the transpose of the KKT Jacobian.
let mut kkt_t = vec![vec![0.0; dim]; dim];
for i in 0..dim {
for j in 0..dim {
kkt_t[i][j] = kkt[j][i];
}
}
// Only the primal block of the right-hand side carries the loss gradient.
let mut rhs = vec![0.0; dim];
for i in 0..n {
rhs[i] = -dl_dx[i];
}
let dz = solve_implicit_system(&kkt_t, &rhs)?;
// Split the adjoint into primal / inequality / equality components.
let dx = &dz[..n];
let dlam = &dz[n..n + m];
let dnu = &dz[n + m..];
// Gradient w.r.t. the linear cost c is the primal adjoint itself.
let dl_dc = dx.to_vec();
// NOTE(review): dl_dh omits a lam[i] factor while dl_dg below includes
// one on its dlam term; OptNet eq. (8) applies D(λ) to both
// (∇_h ℓ = −D(λ)dλ). Confirm which scaling convention is intended.
let dl_dh: Vec<f64> = dlam.iter().map(|&v| -v).collect();
let dl_db: Vec<f64> = dnu.iter().map(|&v| -v).collect();
// Symmetrized outer product: ∇_Q ℓ = ½(dx·xᵀ + x·dxᵀ).
let mut dl_dq = vec![vec![0.0; n]; n];
for i in 0..n {
for j in 0..n {
dl_dq[i][j] = 0.5 * (dx[i] * x[j] + dx[j] * x[i]);
}
}
// ∇_G ℓ: λ_i·dx_j from the stationarity rows plus λ_i·dλ_i·x_j from the
// λ-scaled complementarity rows.
let mut dl_dg = vec![vec![0.0; n]; m];
for i in 0..m {
for j in 0..n {
dl_dg[i][j] = dx[j] * lam[i] + dlam[i] * lam[i] * x[j];
}
}
// ∇_A ℓ = ν·dxᵀ + dν·xᵀ (equality rows are unscaled).
let mut dl_da = vec![vec![0.0; n]; p];
for i in 0..p {
for j in 0..n {
dl_da[i][j] = dx[j] * nu[i] + dnu[i] * x[j];
}
}
// The equality-gradient block is only populated when equalities exist.
Ok(super::types::ImplicitGradient {
dl_dq: Some(dl_dq),
dl_dc,
dl_dg: Some(dl_dg),
dl_dh,
dl_da: if p > 0 { Some(dl_da) } else { None },
dl_db,
})
}
pub fn compute_active_set_implicit_gradient(
q: &[Vec<f64>],
g: &[Vec<f64>],
h: &[f64],
a: &[Vec<f64>],
x: &[f64],
lam: &[f64],
nu: &[f64],
dl_dx: &[f64],
active_tol: f64,
) -> OptimizeResult<super::types::ImplicitGradient> {
let m = lam.len();
let active = identify_active_constraints(g, h, x, active_tol);
let (g_active, h_active) = extract_active_constraints(g, h, &active);
let lam_active: Vec<f64> = active
.iter()
.filter_map(|&i| if i < m { Some(lam[i]) } else { None })
.collect();
let grad =
compute_full_implicit_gradient(q, &g_active, &h_active, a, x, &lam_active, nu, dl_dx)?;
let m_full = lam.len();
let n = x.len();
let mut dl_dh_full = vec![0.0; m_full];
for (idx, &ai) in active.iter().enumerate() {
if ai < m_full && idx < grad.dl_dh.len() {
dl_dh_full[ai] = grad.dl_dh[idx];
}
}
let dl_dg_full = if let Some(ref dg) = grad.dl_dg {
let mut full = vec![vec![0.0; n]; m_full];
for (idx, &ai) in active.iter().enumerate() {
if ai < m_full && idx < dg.len() {
full[ai] = dg[idx].clone();
}
}
Some(full)
} else {
None
};
Ok(super::types::ImplicitGradient {
dl_dq: grad.dl_dq,
dl_dc: grad.dl_dc,
dl_dg: dl_dg_full,
dl_dh: dl_dh_full,
dl_da: grad.dl_da,
dl_db: grad.dl_db,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kkt_jacobian_structure_2x2() {
        // 2-variable QP with one inequality row and no equalities.
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let g = vec![vec![1.0, 1.0]];
        let a: Vec<Vec<f64>> = Vec::new();
        let x = vec![0.25, 0.25];
        let lam = vec![0.5];
        let nu: Vec<f64> = Vec::new();
        let jac = compute_kkt_jacobian(&q, &g, &a, &x, &lam, &nu);
        // dim = n + m + p = 2 + 1 + 0.
        assert_eq!(jac.len(), 3);
        assert_eq!(jac[0].len(), 3);
        let close = |v: f64, want: f64| (v - want).abs() < 1e-12;
        // Hessian block Q on the top-left diagonal.
        assert!(close(jac[0][0], 2.0));
        assert!(close(jac[1][1], 2.0));
        // Gᵀ column.
        assert!(close(jac[0][2], 1.0));
        assert!(close(jac[1][2], 1.0));
        // diag(λ)·G row.
        assert!(close(jac[2][0], 0.5));
        assert!(close(jac[2][1], 0.5));
    }

    #[test]
    fn test_active_constraint_identification() {
        // Constraints 0 and 2 are tight (g·x == h); constraint 1 has slack 1.5.
        let g = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let h = vec![1.0, 2.0, 1.5];
        let x = vec![1.0, 0.5];
        assert_eq!(identify_active_constraints(&g, &h, &x, 1e-6), vec![0, 2]);
    }

    #[test]
    fn test_solve_implicit_system_simple() {
        // 2x + y = 5, x + 3y = 7  =>  x = 1.6, y = 1.8.
        let mat = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let sol = solve_implicit_system(&mat, &[5.0, 7.0]).expect("solve failed");
        assert!((sol[0] - 1.6).abs() < 1e-10);
        assert!((sol[1] - 1.8).abs() < 1e-10);
    }

    #[test]
    fn test_solve_singular_matrix() {
        // Second row is twice the first: rank-deficient, must be rejected.
        let mat = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert!(solve_implicit_system(&mat, &[3.0, 6.0]).is_err());
    }

    #[test]
    fn test_extract_active_constraints() {
        let g = vec![vec![1.0], vec![2.0], vec![3.0]];
        let h = vec![10.0, 20.0, 30.0];
        let (ga, ha) = extract_active_constraints(&g, &h, &[0, 2]);
        assert_eq!(ga, vec![vec![1.0], vec![3.0]]);
        assert_eq!(ha, vec![10.0, 30.0]);
    }

    #[test]
    fn test_empty_constraints() {
        // With no constraints the KKT matrix reduces to Q itself.
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let jac = compute_kkt_jacobian(&q, &[], &[], &[1.0, 2.0], &[], &[]);
        assert_eq!(jac.len(), 2);
        assert!((jac[0][0] - 2.0).abs() < 1e-12);
    }
}