use super::*;
use crate::estimate::reml::outer_eval;
/// Borrowed bundle of per-coordinate quantities handed to REML derivative code.
///
/// Every field is a borrow (or an optional borrow) tied to the single lifetime
/// `'ws`, so the workspace never owns any data. The field descriptions below are
/// inferred from the field names and should be confirmed against the site that
/// builds this struct.
pub(crate) struct RemlDerivativeWorkspace<'ws> {
    /// Smoothing weights used by curvature terms (presumably λ = exp(ρ) — confirm).
    pub curvature_lambdas: &'ws [f64],
    /// Per-ρ-coordinate `A_k β` vectors for the penalty part (name-based; confirm).
    pub rho_penalty_a_k_betas: &'ws [Array1<f64>],
    /// Per-ρ-coordinate `A_k β` vectors for the curvature part (name-based; confirm).
    pub rho_curvature_a_k_betas: &'ws [Array1<f64>],
    /// Optional per-ρ-coordinate `v_k` vectors; `None` when the caller supplies none.
    pub rho_v_ks: Option<&'ws [Array1<f64>]>,
    /// Optional per-extended-coordinate `v_i` vectors; `None` when the caller supplies none.
    pub ext_v_is: Option<&'ws [Array1<f64>]>,
    /// Optional drift-derivative result per coordinate
    /// (presumably `None` means "no correction for this coordinate" — confirm).
    pub coord_corrections: &'ws [Option<DriftDerivResult>],
}
/// θ-derivative correction terms computed from a KKT residual by
/// `compute_kkt_residual_theta_corrections`.
#[derive(Debug, Clone)]
pub(crate) struct KktThetaCorrections {
    /// One gradient correction per θ coordinate; coordinates flagged as active
    /// are left at exactly zero.
    pub(crate) gradient: Array1<f64>,
    /// Symmetric `m × m` Hessian correction; `Some` only when the caller asked
    /// for the Hessian.
    pub(crate) hessian: Option<Array2<f64>>,
}
/// Computes θ-gradient (and optionally θ-Hessian) correction terms driven by a
/// KKT residual.
///
/// Write `K(v) = solve_kkt_residual_kernel(hop, subspace, v)`. This is a solve
/// against the Hessian operator, restricted to the penalty subspace when
/// `subspace` is `Some`. Also write `r_i = score_derivs[i]` and
/// `A_i v = drift_apply(i, v)`. With `q = K(residual)`, the gradient is
///
/// ```text
/// g_i = -r_iᵀ q + ½ qᵀ A_i q
/// ```
///
/// When `include_hessian` is set, let `s_j = K(r_j)` and `q_j = K(r_j - A_j q)`.
/// The returned Hessian is the symmetrized part of
///
/// ```text
/// E_ij = r_iᵀ s_j - r_iᵀ q_j + q_jᵀ (A_i q)
/// ```
///
/// Coordinates flagged in `active` (the active-bound mask) get zero gradient
/// entries and zero Hessian rows/columns. They are never passed to `drift_apply`,
/// and their `score_derivs` entries are never read.
///
/// # Errors
/// * If `score_derivs` is empty, an empty result is returned before any validation.
/// * `DimensionMismatch` if `active.len() != score_derivs.len()`, or if
///   `residual`, an inactive `score_derivs[i]`, or a `drift_apply` output does not
///   have length `hop.dim()`.
/// * `NonFiniteValue` if a gradient ingredient or a Hessian entry is non-finite.
pub(crate) fn compute_kkt_residual_theta_corrections<F>(
    hop: &dyn HessianOperator,
    subspace: Option<&PenaltySubspaceTrace>,
    score_derivs: &[Array1<f64>],
    drift_apply: F,
    residual: &Array1<f64>,
    include_hessian: bool,
    active: &[bool],
) -> Result<KktThetaCorrections, String>
where
    F: Fn(usize, &Array1<f64>) -> Array1<f64>,
{
    let m = score_derivs.len();
    if m == 0 {
        return Ok(KktThetaCorrections {
            gradient: Array1::zeros(0),
            hessian: include_hessian.then(|| Array2::zeros((0, 0))),
        });
    }
    if active.len() != m {
        return Err(RemlError::DimensionMismatch {
            reason: format!(
                "KKT theta correction active-bound mask mismatch: mask={} coords={}",
                active.len(),
                m
            ),
        }
        .into());
    }
    let dim = hop.dim();
    if residual.len() != dim {
        return Err(RemlError::DimensionMismatch {
            reason: format!(
                "KKT residual dimension mismatch: residual={} Hessian dim={}",
                residual.len(),
                dim
            ),
        }
        .into());
    }
    // Only inactive coordinates are ever read below, so active ones may carry
    // placeholder vectors of any length. Checking here turns what would be an
    // ndarray `dot` panic into a reportable error.
    for (idx, deriv) in score_derivs.iter().enumerate() {
        if !active[idx] && deriv.len() != dim {
            return Err(RemlError::DimensionMismatch {
                reason: format!(
                    "KKT theta correction score derivative mismatch at coord {idx}: \
                     len={} Hessian dim={dim}",
                    deriv.len()
                ),
            }
            .into());
        }
    }

    let q = solve_kkt_residual_kernel(hop, subspace, residual);
    let mut gradient = Array1::<f64>::zeros(m);
    // A_i q per coordinate (zeros for active ones); reused by the Hessian terms.
    let mut a_i_qs = Vec::with_capacity(m);
    for idx in 0..m {
        if active[idx] {
            a_i_qs.push(Array1::<f64>::zeros(dim));
            continue;
        }
        let a_i_q = drift_apply(idx, &q);
        if a_i_q.len() != dim {
            return Err(RemlError::DimensionMismatch {
                reason: format!(
                    "KKT theta correction drift output mismatch at coord {idx}: \
                     len={} Hessian dim={dim}",
                    a_i_q.len()
                ),
            }
            .into());
        }
        let linear = score_derivs[idx].dot(&q);
        let quadratic = q.dot(&a_i_q);
        if !linear.is_finite() || !quadratic.is_finite() {
            return Err(RemlError::NonFiniteValue {
                reason: format!(
                    "KKT theta correction produced non-finite gradient ingredients at coord \
                     {idx}: linear={linear} quadratic={quadratic}"
                ),
            }
            .into());
        }
        gradient[idx] = -linear + 0.5 * quadratic;
        a_i_qs.push(a_i_q);
    }

    let hessian = if include_hessian {
        // s_j = K(r_j) and q_j = K(r_j - A_j q); zeros for active coordinates.
        let mut a_solutions = Vec::with_capacity(m);
        let mut q_derivs = Vec::with_capacity(m);
        for idx in 0..m {
            if active[idx] {
                a_solutions.push(Array1::<f64>::zeros(dim));
                q_derivs.push(Array1::<f64>::zeros(dim));
                continue;
            }
            a_solutions.push(solve_kkt_residual_kernel(hop, subspace, &score_derivs[idx]));
            let rhs = &score_derivs[idx] - &a_i_qs[idx];
            q_derivs.push(solve_kkt_residual_kernel(hop, subspace, &rhs));
        }
        let entry = |i: usize, j: usize| -> f64 {
            if active[i] || active[j] {
                return 0.0;
            }
            let cancel_exact_kkt_profile_term = score_derivs[i].dot(&a_solutions[j]);
            cancel_exact_kkt_profile_term - score_derivs[i].dot(&q_derivs[j])
                + q_derivs[j].dot(&a_i_qs[i])
        };
        let mut h = Array2::<f64>::zeros((m, m));
        for i in 0..m {
            for j in i..m {
                // E is not symmetric in general; store its symmetric part.
                let raw = if i == j {
                    entry(i, j)
                } else {
                    0.5 * (entry(i, j) + entry(j, i))
                };
                if !raw.is_finite() {
                    return Err(RemlError::NonFiniteValue {
                        reason: format!(
                            "KKT theta correction produced non-finite Hessian entry ({i}, {j}): \
                             {raw}"
                        ),
                    }
                    .into());
                }
                h[[i, j]] = raw;
                if i != j {
                    h[[j, i]] = raw;
                }
            }
        }
        Some(h)
    } else {
        None
    };
    Ok(KktThetaCorrections { gradient, hessian })
}
/// Applies the solve used for KKT-residual corrections to `rhs`.
///
/// * `subspace == None`: delegates to `hop.solve(rhs)`.
/// * `subspace == Some(trace)`: works inside the penalty subspace spanned by
///   `trace.u_s`. It maps `rhs` through `fast_atv(&u_s, ·)`, applies
///   `trace.h_proj_inverse`, and maps back with `fast_av(&u_s, ·)`. This is
///   presumably `U_s · H_proj⁻¹ · U_sᵀ · rhs` — confirm the `fast_*` conventions.
pub(crate) fn solve_kkt_residual_kernel(
    hop: &dyn HessianOperator,
    subspace: Option<&PenaltySubspaceTrace>,
    rhs: &Array1<f64>,
) -> Array1<f64> {
    match subspace {
        None => hop.solve(rhs),
        Some(trace) => {
            let reduced = crate::faer_ndarray::fast_atv(&trace.u_s, rhs);
            crate::faer_ndarray::fast_av(&trace.u_s, &trace.h_proj_inverse.dot(&reduced))
        }
    }
}
pub(crate) fn active_upper_rho_mask(rho: &[f64]) -> Vec<bool> {
let latest_theta = outer_eval::latest_outer_theta_for_ift();
let matching_outer_theta = latest_theta.as_ref().is_some_and(|theta| {
theta.len() >= rho.len()
&& theta
.iter()
.take(rho.len())
.zip(rho.iter())
.all(|(&recorded, ¤t)| recorded.to_bits() == current.to_bits())
});
let upper_bounds = matching_outer_theta
.then(outer_eval::latest_outer_rho_upper_bounds_for_ift)
.flatten();
rho.iter()
.enumerate()
.map(|(idx, &value)| {
let upper = upper_bounds
.as_ref()
.and_then(|bounds| bounds.get(idx))
.copied()
.unwrap_or(crate::solver::estimate::RHO_BOUND);
upper.is_finite() && value >= upper - 1.0e-8
})
.collect()
}