use super::*;
/// Largest magnitude allowed for a single log-scale outer step.
///
/// `compute_hybrid_efs_update` clamps every psi (non-penalty-like) step to
/// `[-EFS_MAX_STEP, EFS_MAX_STEP]`. NOTE(review): it may also bound the scalar EFS steps
/// inside `efs_log_step_from_grad`, which is not visible here — confirm.
pub(crate) const EFS_MAX_STEP: f64 = 5.0;
/// Computes one extended Fellner–Schall (EFS) update, expressed as log-scale steps.
///
/// The coordinate layout of both `gradient` and the returned vector is
/// `[rho_0, .., rho_{k-1}, ext_0, .., ext_{m-1}]`, where `k = rho.len()` and
/// `m = solution.ext_coords.len()`.
///
/// * Each rho coordinate uses the penalty quadratic of its smoothing parameter
///   `lambda = exp(rho)` to form an effective `q_eff`. That value is adjusted for the rho
///   prior by `efs_q_eff_with_gamma_rate`, and the step comes from `efs_log_step_from_grad`.
/// * Penalty-like extended coordinates use their own `a` value to form `q_eff`, with no
///   prior adjustment.
/// * Non-penalty-like extended coordinates never move; their step is 0.0.
/// * Any coordinate whose step helper yields `None` also gets 0.0.
///
/// # Panics
/// Panics if `gradient.len() != k + m`, or if the penalty-quadratic atom cannot be built
/// from the solution's penalty layout.
pub fn compute_efs_update(solution: &InnerSolution<'_>, rho: &[f64], gradient: &[f64]) -> Vec<f64> {
    let k = rho.len();
    let ext_dim = solution.ext_coords.len();
    assert_eq!(
        gradient.len(),
        k + ext_dim,
        "compute_efs_update: gradient length {} != n_rho({k}) + n_ext({ext_dim})",
        gradient.len(),
    );

    let (profiled_scale, dp_cgrad) = efs_profiling(solution);
    let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
    let atom = crate::solver::estimate::reml::atoms::PenaltyQuadAtom::from_penalty_coords(
        &lambdas,
        &solution.penalty_coords,
        &solution.beta,
    )
    .expect("EFS penalty-quadratic atom must match InnerSolution penalty layout");

    // Smoothing-parameter block: q_eff carries the rho-prior (gamma rate) correction.
    let rho_steps = lambdas.iter().enumerate().map(|(idx, &lambda)| {
        let a_i = atom.rho_frozen_d1(idx);
        let base = efs_q_eff(a_i, &solution.dispersion, dp_cgrad, profiled_scale);
        let q_eff = efs_q_eff_with_gamma_rate(base, lambda, &solution.rho_prior, idx);
        efs_log_step_from_grad(q_eff, gradient[idx]).unwrap_or(0.0)
    });

    // Extended block: only penalty-like coordinates have an EFS fixed point.
    let ext_steps = solution
        .ext_coords
        .iter()
        .enumerate()
        .map(|(ext_idx, coord)| {
            if !coord.is_penalty_like {
                return 0.0;
            }
            let q_eff = efs_q_eff(coord.a, &solution.dispersion, dp_cgrad, profiled_scale);
            efs_log_step_from_grad(q_eff, gradient[k + ext_idx]).unwrap_or(0.0)
        });

    rho_steps.chain(ext_steps).collect()
}
/// Convergence diagnostics for one single-loop EFS iteration, as filled in by
/// `efs_single_loop_diagnostics`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct EfsSingleLoopDiagnostics {
/// `max(gradient_residual, inner_residual)`: one combined measure of how far the iterate is
/// from a consistent fixed point.
pub(crate) bias_proxy: f64,
/// `||g - g_efs|| / (1 + ||g||)`, where `g_efs = 0.5 * q_eff * (1 - exp(step))` is the
/// gradient implied by the proposed EFS steps. Coordinates without a usable `q_eff` or step
/// contribute their full gradient entry; non-penalty-like extended coordinates are excluded.
pub(crate) gradient_residual: f64,
/// Caller-supplied inner-solver residual; non-finite or negative inputs are stored as `+inf`.
pub(crate) inner_residual: f64,
/// Euclidean norm of the full outer gradient (rho and all extended coordinates).
pub(crate) gradient_norm: f64,
/// Largest absolute entry of the proposed step vector.
pub(crate) step_inf_norm: f64,
}
/// Measures how self-consistent a single-loop EFS iterate is.
///
/// For every coordinate that has an EFS fixed point, the proposed `step` implies a gradient
/// `g_efs = 0.5 * q_eff * (1 - exp(step))`. Those coordinates are the rho coordinates and the
/// penalty-like extended coordinates. The squared gaps `g - g_efs` are accumulated, or `g²`
/// when `q_eff` is non-finite or non-positive or the step is non-finite. The result is
/// normalised by `1 + ||g||`.
///
/// Non-penalty-like extended coordinates contribute only to `gradient_norm` and
/// `step_inf_norm`. A non-finite or negative `inner_residual` is reported as `+inf`.
///
/// # Panics
/// Panics if `gradient` or `steps` does not have `rho.len() + solution.ext_coords.len()`
/// entries, or if the penalty-quadratic atom cannot be built.
pub(crate) fn efs_single_loop_diagnostics(
    solution: &InnerSolution<'_>,
    rho: &[f64],
    gradient: &[f64],
    steps: &[f64],
    inner_residual: f64,
) -> EfsSingleLoopDiagnostics {
    /// Squared gap between an observed gradient entry `g` and the gradient implied by `step`;
    /// falls back to `g * g` when the implied gradient is undefined.
    fn mismatch_sq(q_eff: f64, g: f64, step: f64) -> f64 {
        let gap = if q_eff.is_finite() && q_eff > 0.0 && step.is_finite() {
            g - 0.5 * q_eff * (1.0 - step.exp())
        } else {
            g
        };
        gap * gap
    }

    let n_rho = rho.len();
    let n_total = n_rho + solution.ext_coords.len();
    assert_eq!(gradient.len(), n_total);
    assert_eq!(steps.len(), n_total);

    let (profiled_scale, dp_cgrad) = efs_profiling(solution);
    let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
    let atom = crate::solver::estimate::reml::atoms::PenaltyQuadAtom::from_penalty_coords(
        &lambdas,
        &solution.penalty_coords,
        &solution.beta,
    )
    .expect("EFS diagnostics penalty-quadratic atom must match InnerSolution penalty layout");

    // Both norms cover every coordinate, penalty-like or not, in index order.
    let grad_sq = gradient.iter().fold(0.0_f64, |acc, &g| acc + g * g);
    let step_inf_norm = steps.iter().fold(0.0_f64, |acc, &s| acc.max(s.abs()));

    let rho_terms = lambdas.iter().enumerate().map(|(idx, &lambda)| {
        let base = efs_q_eff(
            atom.rho_frozen_d1(idx),
            &solution.dispersion,
            dp_cgrad,
            profiled_scale,
        );
        let q_eff = efs_q_eff_with_gamma_rate(base, lambda, &solution.rho_prior, idx);
        mismatch_sq(q_eff, gradient[idx], steps[idx])
    });
    // Psi (non-penalty-like) coordinates have no EFS fixed point, so they are left out here.
    let ext_terms = solution
        .ext_coords
        .iter()
        .enumerate()
        .filter(|(_, coord)| coord.is_penalty_like)
        .map(|(ext_idx, coord)| {
            let idx = n_rho + ext_idx;
            let q_eff = efs_q_eff(coord.a, &solution.dispersion, dp_cgrad, profiled_scale);
            mismatch_sq(q_eff, gradient[idx], steps[idx])
        });
    let residual_sq = rho_terms
        .chain(ext_terms)
        .fold(0.0_f64, |acc, term| acc + term);

    let gradient_norm = grad_sq.sqrt();
    let gradient_residual = residual_sq.sqrt() / (1.0 + gradient_norm);
    let inner_residual = match inner_residual {
        r if r.is_finite() && r >= 0.0 => r,
        _ => f64::INFINITY,
    };

    EfsSingleLoopDiagnostics {
        bias_proxy: gradient_residual.max(inner_residual),
        gradient_residual,
        inner_residual,
        gradient_norm,
        step_inf_norm,
    }
}
pub(crate) const PSI_GRAM_PINV_TOL: f64 = 1e-8;
pub(crate) const PSI_INITIAL_ALPHA: f64 = 1.0;
pub(crate) const HYBRID_EFS_SCALAR_PAR_THRESHOLD: usize = 8;
pub(crate) const HYBRID_EFS_GRAM_PAIR_PAR_THRESHOLD: usize = 24;
pub(crate) const HYBRID_EFS_PSI_DRIFT_PAR_THRESHOLD: usize = 8;
/// Output of `compute_hybrid_efs_update`.
///
/// Derives `Clone` and `Debug` so callers can log or snapshot the update. This follows the
/// Rust API guidelines for public types and matches `EfsSingleLoopDiagnostics`.
#[derive(Clone, Debug)]
pub struct HybridEfsResult {
    /// Proposed step for every outer coordinate, laid out as `[rho..., ext...]`.
    pub steps: Vec<f64>,
    /// Global indices (into `steps`) of the non-penalty-like ("psi") extended coordinates.
    pub psi_indices: Vec<usize>,
    /// Gradient entries of the psi coordinates, in the same order as `psi_indices`.
    pub psi_gradient: Vec<f64>,
}
/// Computes a hybrid extended Fellner–Schall (EFS) step for every outer coordinate.
///
/// `gradient` and the returned `steps` share the layout
/// `[rho_0..rho_{k-1}, ext_0..ext_{ext_dim-1}]`.
///
/// * rho coordinates and penalty-like extended ("tau") coordinates get the scalar EFS
///   log-step from `efs_log_step_from_grad`. A coordinate whose step is `None` keeps 0.0.
/// * Non-penalty-like extended ("psi") coordinates get the step
///   `-PSI_INITIAL_ALPHA * pinv(G) * g_psi`, with each entry clamped to `±EFS_MAX_STEP`.
///   `G` is a symmetric Gram matrix of pairwise drift trace crosses, presumably
///   `tr(H^{-1} D_d H^{-1} D_e)` — confirm against `trace_hinv_cached_drift_cross`.
///
/// `psi_indices` holds the global indices of the psi coordinates, and `psi_gradient` holds
/// their gradient entries in the same order.
///
/// Work is spread over rayon only when the relevant size threshold is met and the caller is
/// not already on a rayon worker thread (`rayon::current_thread_index().is_none()`).
///
/// # Panics
/// Panics if `gradient.len() != rho.len() + solution.ext_coords.len()`, or if the
/// penalty-quadratic atom cannot be built from the solution's penalty layout.
pub fn compute_hybrid_efs_update(
solution: &InnerSolution<'_>,
rho: &[f64],
gradient: &[f64],
) -> HybridEfsResult {
let k = rho.len();
let hop = &*solution.hessian_op;
let ext_dim = solution.ext_coords.len();
let total = k + ext_dim;
let mut steps = vec![0.0; total];
let (profiled_scale, dp_cgrad) = efs_profiling(solution);
assert_eq!(
gradient.len(),
total,
"compute_hybrid_efs_update: gradient length {} != n_rho({k}) + n_ext({ext_dim})",
gradient.len(),
);
let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
let penalty_quad_atom =
crate::solver::estimate::reml::atoms::PenaltyQuadAtom::from_penalty_coords(
&lambdas,
&solution.penalty_coords,
&solution.beta,
)
.expect("hybrid EFS penalty-quadratic atom must match InnerSolution penalty layout");
// Scalar EFS steps for the rho coordinates. The parallel and serial branches compute the
// same (idx, step) pairs; only the scheduling differs.
let rho_candidates: Vec<(usize, Option<f64>)> =
if k >= HYBRID_EFS_SCALAR_PAR_THRESHOLD && rayon::current_thread_index().is_none() {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
(0..k)
.into_par_iter()
.map(|idx| {
let lambda = lambdas[idx];
let a_i = penalty_quad_atom.rho_frozen_d1(idx);
let q_eff = efs_q_eff_with_gamma_rate(
efs_q_eff(a_i, &solution.dispersion, dp_cgrad, profiled_scale),
lambda,
&solution.rho_prior,
idx,
);
(idx, efs_log_step_from_grad(q_eff, gradient[idx]))
})
.collect()
} else {
(0..k)
.map(|idx| {
let lambda = lambdas[idx];
let a_i = penalty_quad_atom.rho_frozen_d1(idx);
let q_eff = efs_q_eff_with_gamma_rate(
efs_q_eff(a_i, &solution.dispersion, dp_cgrad, profiled_scale),
lambda,
&solution.rho_prior,
idx,
);
(idx, efs_log_step_from_grad(q_eff, gradient[idx]))
})
.collect()
};
for (idx, candidate) in rho_candidates {
if let Some(step) = candidate {
steps[idx] = step;
}
}
// Partition the extended coordinates. Penalty-like ones ("tau") take scalar EFS steps; the
// rest ("psi") take a Gram-preconditioned step. Psi positions are kept both as indices into
// `ext_coords` (local) and as indices into `steps`/`gradient` (global).
let mut psi_local_indices: Vec<usize> = Vec::new(); let mut psi_global_indices: Vec<usize> = Vec::new(); let mut tau_local_indices: Vec<usize> = Vec::new();
for (ext_idx, coord) in solution.ext_coords.iter().enumerate() {
let g_idx = k + ext_idx;
if coord.is_penalty_like {
tau_local_indices.push(ext_idx);
} else {
psi_local_indices.push(ext_idx);
psi_global_indices.push(g_idx);
}
}
// Scalar EFS steps for tau coordinates. Unlike rho, there is no gamma-rate prior correction.
let tau_candidates: Vec<(usize, Option<f64>)> = if tau_local_indices.len()
>= HYBRID_EFS_SCALAR_PAR_THRESHOLD
&& rayon::current_thread_index().is_none()
{
use rayon::iter::{IntoParallelIterator, ParallelIterator};
// NOTE(review): `.to_vec()` copies the index list only to get an owned parallel iterator.
tau_local_indices
.to_vec()
.into_par_iter()
.map(|ext_idx| {
let coord = &solution.ext_coords[ext_idx];
let g_idx = k + ext_idx;
let q_eff = efs_q_eff(coord.a, &solution.dispersion, dp_cgrad, profiled_scale);
(g_idx, efs_log_step_from_grad(q_eff, gradient[g_idx]))
})
.collect()
} else {
tau_local_indices
.iter()
.map(|&ext_idx| {
let coord = &solution.ext_coords[ext_idx];
let g_idx = k + ext_idx;
let q_eff = efs_q_eff(coord.a, &solution.dispersion, dp_cgrad, profiled_scale);
(g_idx, efs_log_step_from_grad(q_eff, gradient[g_idx]))
})
.collect()
};
for (g_idx, candidate) in tau_candidates {
if let Some(step) = candidate {
steps[g_idx] = step;
}
}
let psi_gradient: Vec<f64> = psi_global_indices.iter().map(|&gi| gradient[gi]).collect();
let n_psi = psi_local_indices.len();
if n_psi > 0 {
// Single psi coordinate: G is a scalar, computed deterministically. NOTE(review): this path
// never uses the stochastic trace estimator, even for large operator-backed drifts.
if n_psi == 1 {
let li = psi_local_indices[0];
let drift = &solution.ext_coords[li].drift;
let op = hyper_coord_drift_operator_arc(drift, hop.dim());
// Materialise the dense drift only when no operator form is available.
let dense = op.is_none().then(|| drift.materialize());
let gram = if let Some(dense_hop) = hop.as_dense_spectral() {
let projected = if let Some(op) = op.as_ref() {
dense_hop.projected_operator(&dense_hop.w_factor, op.as_ref())
} else {
dense_hop
.projected_matrix(dense.as_ref().expect("dense drift should be cached"))
};
dense_hop.trace_projected_cross(&projected, &projected)
} else {
trace_hinv_cached_drift_cross(
hop,
dense.as_ref(),
op.as_deref(),
dense.as_ref(),
op.as_deref(),
)
};
// Same rule as the scalar case of `pseudoinverse_times_vec`: an absolute threshold, then
// `-alpha * g / gram`. A too-small Gram leaves the step at 0.0.
if gram.abs() >= PSI_GRAM_PINV_TOL.max(1e-30) {
let global_idx = psi_global_indices[0];
let raw_step = -PSI_INITIAL_ALPHA * psi_gradient[0] / gram;
steps[global_idx] = raw_step.clamp(-EFS_MAX_STEP, EFS_MAX_STEP);
}
return HybridEfsResult {
steps,
psi_indices: psi_global_indices,
psi_gradient,
};
}
// Two or more psi coordinates: build the full n_psi x n_psi Gram matrix.
let total_p = hop.dim();
let any_psi_operator = psi_local_indices.iter().any(|&li| {
let drift = &solution.ext_coords[li].drift;
drift.uses_operator_fast_path()
});
// Use the stochastic trace estimator only when at least one drift takes the operator fast
// path, the problem dimension exceeds STOCHASTIC_TRACE_DIM_THRESHOLD, and the Hessian
// operator reports that it prefers stochastic estimation.
let use_stochastic_psi_gram = any_psi_operator
&& total_p > STOCHASTIC_TRACE_DIM_THRESHOLD
&& hop.prefers_stochastic_trace_estimation();
let gram = if use_stochastic_psi_gram {
// Split the drifts into operator-backed ones and dense matrices.
// `coord_has_operator[i]` records which list coordinate i went into.
let mut dense_mats = Vec::new();
let mut coord_has_operator = Vec::with_capacity(n_psi);
let mut operator_arcs: Vec<Arc<dyn HyperOperator>> = Vec::new();
for &li in &psi_local_indices {
let coord = &solution.ext_coords[li];
if let Some(op) = hyper_coord_drift_operator_arc(&coord.drift, hop.dim()) {
coord_has_operator.push(true);
operator_arcs.push(op);
} else {
coord_has_operator.push(false);
dense_mats.push(coord.drift.materialize());
}
}
let generic_ops: Vec<&dyn HyperOperator> =
operator_arcs.iter().map(|op| op.as_ref()).collect();
// NOTE(review): `impl_ops` keeps only operators that downcast to ImplicitHyperOperator, so
// its positions need not line up with `generic_ops`; confirm the estimator expects that.
let impl_ops: Vec<&ImplicitHyperOperator> = generic_ops
.iter()
.filter_map(|&op| as_implicit(op))
.collect();
stochastic_trace_hinv_crosses(
hop,
&dense_mats,
&coord_has_operator,
&generic_ops,
&impl_ops,
)
} else {
// Deterministic Gram: evaluate only the upper triangle and mirror it.
let mut gram = ndarray::Array2::<f64>::zeros((n_psi, n_psi));
let parallel_psi_drifts = n_psi >= HYBRID_EFS_PSI_DRIFT_PAR_THRESHOLD
&& rayon::current_thread_index().is_none();
// Operator form of each drift, where one exists.
let drift_ops: Vec<Option<Arc<dyn HyperOperator>>> = if parallel_psi_drifts {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
(0..n_psi)
.into_par_iter()
.map(|idx| {
let drift = &solution.ext_coords[psi_local_indices[idx]].drift;
hyper_coord_drift_operator_arc(drift, hop.dim())
})
.collect()
} else {
psi_local_indices
.iter()
.map(|&li| {
let drift = &solution.ext_coords[li].drift;
hyper_coord_drift_operator_arc(drift, hop.dim())
})
.collect()
};
// Dense matrices, materialised only for drifts without an operator form.
let dense_drifts: Vec<Option<Array2<f64>>> = if parallel_psi_drifts {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
(0..n_psi)
.into_par_iter()
.map(|idx| {
let drift = &solution.ext_coords[psi_local_indices[idx]].drift;
drift_ops[idx].is_none().then(|| drift.materialize())
})
.collect()
} else {
psi_local_indices
.iter()
.enumerate()
.map(|(idx, &li)| {
let drift = &solution.ext_coords[li].drift;
drift_ops[idx].is_none().then(|| drift.materialize())
})
.collect()
};
let pair_count = n_psi * (n_psi + 1) / 2;
let parallel_gram_pairs = pair_count >= HYBRID_EFS_GRAM_PAIR_PAR_THRESHOLD
&& rayon::current_thread_index().is_none();
if let Some(dense_hop) = hop.as_dense_spectral() {
// Dense spectral Hessian: project each drift once, then take pairwise trace crosses of
// the projections. Dense drifts are projected directly; operator drifts are batched.
let mut projected_drifts: Vec<Option<Array2<f64>>> =
(0..n_psi).map(|_| None).collect();
let mut op_terms: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();
for idx in 0..n_psi {
if let Some(op) = drift_ops[idx].as_ref() {
op_terms.push((idx, 1.0, op.as_ref()));
} else {
projected_drifts[idx] = Some(
dense_hop.projected_matrix(
dense_drifts[idx]
.as_ref()
.expect("dense drift should be cached"),
),
);
}
}
if !op_terms.is_empty() {
let batched = projected_operator_terms_batched(
n_psi,
&op_terms,
&dense_hop.w_factor,
&dense_hop.projected_factor_cache,
);
for (idx, _, _) in &op_terms {
projected_drifts[*idx] = Some(batched[*idx].clone());
}
}
let projected_drifts: Vec<Array2<f64>> = projected_drifts
.into_iter()
.map(|m| m.expect("projected drift filled"))
.collect();
if parallel_gram_pairs {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let pair_count = n_psi * (n_psi + 1) / 2;
let pair_values: Vec<(usize, usize, f64)> = (0..pair_count)
.into_par_iter()
.map(|pair_idx| {
let (d, e) = upper_triangle_pair_from_index(pair_idx, n_psi);
let val = dense_hop
.trace_projected_cross(&projected_drifts[d], &projected_drifts[e]);
(d, e, val)
})
.collect();
for (d, e, val) in pair_values {
gram[[d, e]] = val;
gram[[e, d]] = val;
}
} else {
for d in 0..n_psi {
for e in d..n_psi {
let val = dense_hop
.trace_projected_cross(&projected_drifts[d], &projected_drifts[e]);
gram[[d, e]] = val;
gram[[e, d]] = val;
}
}
}
} else if parallel_gram_pairs {
// Generic Hessian operator: trace cross per upper-triangle pair, in parallel.
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let pair_count = n_psi * (n_psi + 1) / 2;
let pair_values: Vec<(usize, usize, f64)> = (0..pair_count)
.into_par_iter()
.map(|pair_idx| {
let (d, e) = upper_triangle_pair_from_index(pair_idx, n_psi);
let val = trace_hinv_cached_drift_cross(
hop,
dense_drifts[d].as_ref(),
drift_ops[d].as_deref(),
dense_drifts[e].as_ref(),
drift_ops[e].as_deref(),
);
(d, e, val)
})
.collect();
for (d, e, val) in pair_values {
gram[[d, e]] = val;
gram[[e, d]] = val;
}
} else {
// Generic Hessian operator: trace cross per upper-triangle pair, serially.
for d in 0..n_psi {
for e in d..n_psi {
let val = trace_hinv_cached_drift_cross(
hop,
dense_drifts[d].as_ref(),
drift_ops[d].as_deref(),
dense_drifts[e].as_ref(),
drift_ops[e].as_deref(),
);
gram[[d, e]] = val;
gram[[e, d]] = val;
}
}
}
gram
};
// Solve G * delta = g_psi in the (positive-part) pseudoinverse sense, then step against
// delta, clamping each entry to ±EFS_MAX_STEP.
let delta_psi = pseudoinverse_times_vec(&gram, &psi_gradient, PSI_GRAM_PINV_TOL);
let alpha = PSI_INITIAL_ALPHA;
for (psi_idx, &global_idx) in psi_global_indices.iter().enumerate() {
let raw_step = -alpha * delta_psi[psi_idx];
steps[global_idx] = raw_step.clamp(-EFS_MAX_STEP, EFS_MAX_STEP);
}
}
HybridEfsResult {
steps,
psi_indices: psi_global_indices,
psi_gradient,
}
}
/// Returns `pinv(gram) * v` for a symmetric Gram matrix, discarding weak or negative modes.
///
/// * Dimension 0 returns an empty vector.
/// * Dimension 1 returns `v[0] / gram[[0, 0]]`, or 0 when `|gram[[0, 0]]| < max(tol, 1e-30)`.
///   This is an absolute test, so a sufficiently large negative scalar is still inverted.
/// * Dimension 2 or more eigen-decomposes `gram` with `symmetric_eigen`. It keeps only
///   eigenpairs whose eigenvalue exceeds `tol * max(0, lambda_max)`, then sums
///   `(q_i · v / lambda_i) q_i` over them.
///
/// # Panics
/// Panics if `gram.nrows() != v.len()`. For dimension 2 or more, `symmetric_eigen` also
/// panics on a non-square `gram`.
pub(crate) fn pseudoinverse_times_vec(
    gram: &ndarray::Array2<f64>,
    v: &[f64],
    tol: f64,
) -> ndarray::Array1<f64> {
    let dim = gram.nrows();
    assert_eq!(dim, v.len(), "pseudoinverse_times_vec dimension mismatch");
    match dim {
        0 => ndarray::Array1::zeros(0),
        1 => {
            let scalar = gram[[0, 0]];
            if scalar.abs() < tol.max(1e-30) {
                ndarray::Array1::zeros(1)
            } else {
                ndarray::Array1::from_vec(vec![v[0] / scalar])
            }
        }
        _ => {
            let (eigenvalues, eigenvectors) = symmetric_eigen(gram);
            // Relative cutoff against the largest non-negative eigenvalue.
            let largest = eigenvalues.iter().copied().fold(0.0_f64, f64::max);
            let cutoff = tol * largest;
            let mut out = ndarray::Array1::<f64>::zeros(dim);
            for (col, &lambda) in eigenvalues.iter().enumerate() {
                if lambda > cutoff {
                    // Coordinate of v along eigenvector `col`, then back-projection.
                    let coeff: f64 = (0..dim)
                        .map(|row| eigenvectors[[row, col]] * v[row])
                        .sum();
                    let weight = coeff / lambda;
                    for row in 0..dim {
                        out[row] += weight * eigenvectors[[row, col]];
                    }
                }
            }
            out
        }
    }
}
/// Eigen-decomposes a real symmetric matrix with the cyclic Jacobi method.
///
/// Returns `(eigenvalues, eigenvectors)`. `eigenvalues[i]` pairs with column `i` of the
/// orthogonal matrix `eigenvectors`, so `a ≈ V · diag(eigenvalues) · Vᵀ`. Eigenvalues come
/// out in diagonal order and are not sorted. Only symmetric input is meaningful, because
/// every rotation assumes `a[[p, q]] == a[[q, p]]`.
///
/// Sweeping stops when the squared Frobenius norm of the strict upper triangle drops below
/// `TOL²`, or after `MAX_SWEEPS` sweeps.
///
/// # Panics
/// Panics if `a` is not square.
pub(crate) fn symmetric_eigen(a: &ndarray::Array2<f64>) -> (Vec<f64>, ndarray::Array2<f64>) {
    const MAX_SWEEPS: usize = 100;
    const TOL: f64 = 1e-15;
    // Off-diagonal entries below this magnitude are treated as already annihilated.
    const PAIR_SKIP_TOL: f64 = TOL * 0.01;
    // Beyond this |tau| the exact tangent formula is replaced by its asymptote. That keeps
    // `tau * tau` far from overflow and loses no precision.
    const TAU_DIAGONAL_THRESHOLD: f64 = 1e15;

    let n = a.nrows();
    assert_eq!(n, a.ncols(), "symmetric_eigen requires square matrix");
    let mut work = a.clone();
    let mut v = ndarray::Array2::<f64>::eye(n);

    for _ in 0..MAX_SWEEPS {
        let mut off_diag_sq = 0.0;
        for i in 0..n {
            for j in (i + 1)..n {
                off_diag_sq += work[[i, j]] * work[[i, j]];
            }
        }
        if off_diag_sq < TOL * TOL {
            break;
        }
        for p in 0..n {
            for q in (p + 1)..n {
                let apq = work[[p, q]];
                if apq.abs() < PAIR_SKIP_TOL {
                    continue;
                }
                let app = work[[p, p]];
                let aqq = work[[q, q]];
                let tau = (aqq - app) / (2.0 * apq);
                // t = tan(theta) is the smaller-magnitude root of t^2 + 2*tau*t - 1 = 0.
                let t = if tau.abs() > TAU_DIAGONAL_THRESHOLD {
                    // For huge |tau| the root equals 1 / (2 * tau) = apq / (aqq - app) to full
                    // precision. Skipping the pair instead would leave `apq` in place. Then the
                    // off-diagonal norm could never fall below TOL, and every remaining sweep
                    // would be wasted.
                    apq / (aqq - app)
                } else {
                    let sign_tau = if tau >= 0.0 { 1.0 } else { -1.0 };
                    sign_tau / (tau.abs() + (1.0 + tau * tau).sqrt())
                };
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = t * c;
                work[[p, p]] = app - t * apq;
                work[[q, q]] = aqq + t * apq;
                work[[p, q]] = 0.0;
                work[[q, p]] = 0.0;
                // Rotate rows/columns p and q of the remaining entries, keeping symmetry.
                for r in 0..n {
                    if r == p || r == q {
                        continue;
                    }
                    let wrp = work[[r, p]];
                    let wrq = work[[r, q]];
                    work[[r, p]] = c * wrp - s * wrq;
                    work[[p, r]] = work[[r, p]];
                    work[[r, q]] = s * wrp + c * wrq;
                    work[[q, r]] = work[[r, q]];
                }
                // Accumulate the rotation into the eigenvector matrix.
                for r in 0..n {
                    let vrp = v[[r, p]];
                    let vrq = v[[r, q]];
                    v[[r, p]] = c * vrp - s * vrq;
                    v[[r, q]] = s * vrp + c * vrq;
                }
            }
        }
    }
    let eigenvalues: Vec<f64> = (0..n).map(|i| work[[i, i]]).collect();
    (eigenvalues, v)
}