use super::*;
/// Outcome of one outer REML/LAML evaluation, as produced by `reml_laml_evaluate`
/// (and passed through unchanged by `try_tangent_projected_evaluate`).
#[derive(Debug)]
pub struct RemlLamlResult {
/// Scalar objective value — presumably the outer cost minimized over ρ; TODO confirm.
pub cost: f64,
/// Optional diagnostic. The name suggests the residual energy of an
/// implicit-function-theorem (IFT) solve — NOTE(review): confirm semantics.
pub ift_residual_energy: Option<f64>,
/// Optional vector step; presumably a polishing correction for the inner
/// solution. NOTE(review): confirm meaning and dimension.
pub inner_polish_step: Option<Array1<f64>>,
/// Outer gradient, when computed (presumably `None` for value-only eval
/// modes — TODO confirm).
pub gradient: Option<Array1<f64>>,
/// Outer Hessian information in the rho optimizer's `HessianResult` form.
pub hessian: crate::solver::rho_optimizer::HessianResult,
/// Optional per-ρ-mode response columns. NOTE(review): confirm layout and
/// space against the producer.
pub rho_mode_response_cols: Option<Array2<f64>>,
/// Optional per-ext-coordinate response columns; same caveats as
/// `rho_mode_response_cols`.
pub ext_mode_response_cols: Option<Array2<f64>>,
}
use crate::solver::estimate::smooth_floor_dp;
/// Smallest denominator allowed when dividing by a residual degrees-of-freedom
/// count (e.g. the `n - nullspace_dim` floor in `efs_profiling`).
pub(crate) const DENOM_RIDGE: f64 = 1.0e-8;
/// Multiplies a smoothing parameter `lambda` by the solution's
/// `rho_curvature_scale` factor.
#[inline]
pub(crate) fn rho_curvature_lambda(solution: &InnerSolution<'_>, lambda: f64) -> f64 {
    let curvature_scale = solution.rho_curvature_scale;
    curvature_scale * lambda
}
/// Wraps an owned penalty coordinate, multiplied by `scale`, as a shared
/// [`HyperOperator`].
///
/// Matrix-vector products go through the coordinate's view-based kernels
/// (`apply_penalty_view_into` / `scaled_add_penalty_view`). Dense
/// materialization uses `scaled_dense_matrix`. The operator reports itself as
/// non-implicit.
pub(crate) fn penalty_coord_to_operator(
    coord: PenaltyCoordinate,
    scale: f64,
) -> Arc<dyn HyperOperator> {
    /// `scale * S` for a single owned penalty coordinate `S`.
    struct ScaledPenaltyOperator {
        coord: PenaltyCoordinate,
        scale: f64,
    }

    impl ScaledPenaltyOperator {
        /// Allocates a zeroed buffer of `v`'s length and writes `scale * S v` into it.
        fn apply_alloc(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
            let mut result = Array1::<f64>::zeros(v.len());
            self.coord
                .apply_penalty_view_into(v, self.scale, result.view_mut());
            result
        }
    }

    impl HyperOperator for ScaledPenaltyOperator {
        fn dim(&self) -> usize {
            self.coord.dim()
        }

        fn is_implicit(&self) -> bool {
            false
        }

        fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
            self.apply_alloc(v.view())
        }

        fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
            self.apply_alloc(v)
        }

        fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
            self.coord.apply_penalty_view_into(v, self.scale, out);
        }

        fn scaled_add_mul_vec(
            &self,
            v: ArrayView1<'_, f64>,
            scale: f64,
            out: ArrayViewMut1<'_, f64>,
        ) {
            // A zero multiplier contributes nothing, so skip the kernel call.
            if scale != 0.0 {
                self.coord
                    .scaled_add_penalty_view(v, scale * self.scale, out);
            }
        }

        fn to_dense(&self) -> Array2<f64> {
            self.coord.scaled_dense_matrix(self.scale)
        }

        fn as_any(&self) -> &(dyn std::any::Any + 'static) {
            self
        }
    }

    Arc::new(ScaledPenaltyOperator { coord, scale })
}
pub(crate) fn penalty_total_drift_result(
coord: &PenaltyCoordinate,
scale: f64,
correction: Option<&DriftDerivResult>,
) -> DriftDerivResult {
match correction {
Some(DriftDerivResult::Dense(corr)) => {
if coord.uses_operator_fast_path() {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: Some(corr.clone()),
operators: vec![penalty_coord_to_operator(coord.clone(), scale)],
dim_hint: coord.dim(),
}))
} else {
let mut dense = coord.scaled_dense_matrix(scale);
dense += corr;
DriftDerivResult::Dense(dense)
}
}
Some(DriftDerivResult::Operator(corr_op)) => {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: if coord.uses_operator_fast_path() {
None
} else {
Some(coord.scaled_dense_matrix(scale))
},
operators: {
let mut ops = vec![Arc::clone(corr_op)];
if coord.uses_operator_fast_path() {
ops.push(penalty_coord_to_operator(coord.clone(), scale));
}
ops
},
dim_hint: coord.dim(),
}))
}
None => {
if coord.uses_operator_fast_path() {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: None,
operators: vec![penalty_coord_to_operator(coord.clone(), scale)],
dim_hint: coord.dim(),
}))
} else {
DriftDerivResult::Dense(coord.scaled_dense_matrix(scale))
}
}
}
}
/// Collects the operator-valued parts of a hyper-coordinate drift, in order:
/// the block-local operator (cloned into a fresh `Arc`) first, then the
/// explicit shared operator. The dense part is not included.
pub(crate) fn hyper_coord_drift_operators(drift: &HyperCoordDrift) -> Vec<Arc<dyn HyperOperator>> {
    let block_local_op: Option<Arc<dyn HyperOperator>> = match drift.block_local.as_ref() {
        Some(local) => Some(Arc::new(local.clone())),
        None => None,
    };
    let explicit_op: Option<Arc<dyn HyperOperator>> = match drift.operator.as_ref() {
        Some(op) => Some(Arc::clone(op)),
        None => None,
    };
    block_local_op.into_iter().chain(explicit_op).collect()
}
/// Returns the drift as one shared operator, or `None` when the drift has no
/// operator parts at all.
///
/// A lone operator with no dense part is returned directly. Otherwise the
/// operators, plus any dense part, are wrapped in a `CompositeHyperOperator`
/// sized by `dim_hint`. Note that a purely dense drift yields `None` here.
pub(crate) fn hyper_coord_drift_operator_arc(
    drift: &HyperCoordDrift,
    dim_hint: usize,
) -> Option<Arc<dyn HyperOperator>> {
    let mut operators = hyper_coord_drift_operators(drift);
    match (operators.len(), drift.dense.is_none()) {
        (0, _) => None,
        (1, true) => operators.pop(),
        _ => Some(Arc::new(CompositeHyperOperator {
            dense: drift.dense.clone(),
            operators,
            dim_hint,
        })),
    }
}
/// Packs a dense part and a list of operators into a [`DriftDerivResult`].
///
/// - No operators: `Dense`, using the dense part or a `dim_hint × dim_hint`
///   zero matrix.
/// - Exactly one operator and no dense part: that operator unwrapped.
/// - Anything else: a `CompositeHyperOperator` holding both.
pub(crate) fn drift_parts_into_result(
    dense: Option<Array2<f64>>,
    mut operators: Vec<Arc<dyn HyperOperator>>,
    dim_hint: usize,
) -> DriftDerivResult {
    match (operators.len(), dense) {
        (0, dense) => DriftDerivResult::Dense(
            dense.unwrap_or_else(|| Array2::<f64>::zeros((dim_hint, dim_hint))),
        ),
        (1, None) => DriftDerivResult::Operator(operators.pop().expect("single operator drift")),
        (_, dense) => DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
            dense,
            operators,
            dim_hint,
        })),
    }
}
/// Splits `drift + correction` into a dense part and an operator list.
///
/// A dense correction is summed into the drift's dense part, or becomes it if
/// the drift has none. An operator correction is appended after the drift's own
/// operators.
pub(crate) fn hyper_coord_total_drift_parts(
    drift: &HyperCoordDrift,
    correction: Option<&DriftDerivResult>,
) -> (Option<Array2<f64>>, Vec<Arc<dyn HyperOperator>>) {
    let mut operators = hyper_coord_drift_operators(drift);
    let dense = match correction {
        None => drift.dense.clone(),
        Some(DriftDerivResult::Operator(extra_op)) => {
            operators.push(Arc::clone(extra_op));
            drift.dense.clone()
        }
        Some(DriftDerivResult::Dense(extra)) => Some(match drift.dense.as_ref() {
            Some(base) => {
                let mut summed = base.clone();
                summed += extra;
                summed
            }
            None => extra.clone(),
        }),
    };
    (dense, operators)
}
/// Builds the total drift `drift + correction` as a single [`DriftDerivResult`]
/// (see `hyper_coord_total_drift_parts` and `drift_parts_into_result`).
pub(crate) fn hyper_coord_total_drift_result(
    drift: &HyperCoordDrift,
    correction: Option<&DriftDerivResult>,
    dim_hint: usize,
) -> DriftDerivResult {
    let parts = hyper_coord_total_drift_parts(drift, correction);
    drift_parts_into_result(parts.0, parts.1, dim_hint)
}
/// Effective quadratic weight for an EFS update of one smoothing parameter.
///
/// Returns `2·a_i` under fixed dispersion, and `2·dp_cgrad·a_i / phi` under a
/// profiled Gaussian scale.
#[inline]
pub(crate) fn efs_q_eff(a_i: f64, dispersion: &DispersionHandling, dp_cgrad: f64, phi: f64) -> f64 {
    match *dispersion {
        DispersionHandling::Fixed { .. } => 2.0 * a_i,
        DispersionHandling::ProfiledGaussian => 2.0 * dp_cgrad * a_i / phi,
    }
}
pub(crate) fn gamma_precision_rate_for_rho(
prior: &crate::types::RhoPrior,
idx: usize,
) -> Option<f64> {
match prior {
crate::types::RhoPrior::GammaPrecision { rate, .. } => Some(*rate),
crate::types::RhoPrior::Independent(priors) => {
priors.get(idx).and_then(|prior| match prior {
crate::types::RhoPrior::GammaPrecision { rate, .. } => Some(*rate),
_ => None,
})
}
_ => None,
}
}
/// Adds the Gamma-precision prior's contribution `2·rate·lambda` to an EFS
/// weight. The rate must exist, be finite and be strictly positive; otherwise
/// `base_q_eff` is returned unchanged.
#[inline]
pub(crate) fn efs_q_eff_with_gamma_rate(
    base_q_eff: f64,
    lambda: f64,
    prior: &crate::types::RhoPrior,
    idx: usize,
) -> f64 {
    gamma_precision_rate_for_rho(prior, idx)
        .filter(|rate| rate.is_finite() && *rate > 0.0)
        .map_or(base_q_eff, |rate| base_q_eff + 2.0 * rate * lambda)
}
/// EFS step in log-λ space derived from a gradient entry.
///
/// Returns `ln(1 − 2·g_full / q_eff)` clamped to `±EFS_MAX_STEP`. When that
/// ratio is not positive, returns `-EFS_MAX_STEP`. Returns `None` when `q_eff`
/// is non-finite or non-positive, or when `g_full` is non-finite.
#[inline]
pub(crate) fn efs_log_step_from_grad(q_eff: f64, g_full: f64) -> Option<f64> {
    let usable = q_eff.is_finite() && q_eff > 0.0 && g_full.is_finite();
    if !usable {
        return None;
    }
    let ratio = 1.0 - 2.0 * g_full / q_eff;
    let step = if ratio > 0.0 {
        ratio.ln().clamp(-EFS_MAX_STEP, EFS_MAX_STEP)
    } else {
        -EFS_MAX_STEP
    };
    Some(step)
}
/// Dispersion value and floor-gradient factor used by the EFS update.
///
/// - Fixed dispersion: `(phi, 0.0)`.
/// - Profiled Gaussian: the smoothed-floored value of `-2·ℓ + penalty_quadratic`
///   divided by `max(n − nullspace_dim, DENOM_RIDGE)`. This is paired with the
///   second value returned by `smooth_floor_dp` (its `dp_cgrad`; presumably the
///   floor's derivative — confirm).
#[inline]
pub(crate) fn efs_profiling(solution: &InnerSolution<'_>) -> (f64, f64) {
    match &solution.dispersion {
        DispersionHandling::Fixed { phi, .. } => (*phi, 0.0),
        DispersionHandling::ProfiledGaussian => {
            let raw_dp = -2.0 * solution.log_likelihood + solution.penalty_quadratic;
            let (floored_dp, floored_dp_grad, _) =
                smooth_floor_dp(raw_dp, solution.dp_floor_scale);
            let residual_dof = solution.n_observations as f64 - solution.nullspace_dim;
            let denom = residual_dof.max(DENOM_RIDGE);
            (floored_dp / denom, floored_dp_grad)
        }
    }
}
/// Cross trace between two drift terms. Each term is given either as an
/// operator or as a cached dense matrix.
///
/// Dispatches to the `HessianOperator` kernel matching the representation
/// pair: operator/operator, matrix/operator (the dense side always passed as
/// the matrix), or matrix/matrix. Operators take precedence when both forms are
/// supplied.
///
/// # Panics
/// If a side has no operator and its dense matrix was not cached.
pub(crate) fn trace_hinv_cached_drift_cross(
    hop: &dyn HessianOperator,
    left_dense: Option<&Array2<f64>>,
    left_op: Option<&dyn HyperOperator>,
    right_dense: Option<&Array2<f64>>,
    right_op: Option<&dyn HyperOperator>,
) -> f64 {
    let left_matrix = || left_dense.expect("left dense drift should be cached");
    let right_matrix = || right_dense.expect("right dense drift should be cached");
    match (left_op, right_op) {
        (Some(l), Some(r)) => hop.trace_hinv_operator_cross(l, r),
        (Some(l), None) => hop.trace_hinv_matrix_operator_cross(right_matrix(), l),
        (None, Some(r)) => hop.trace_hinv_matrix_operator_cross(left_matrix(), r),
        (None, None) => hop.trace_hinv_product_cross(left_matrix(), right_matrix()),
    }
}
/// One entry of the outer gradient:
/// `penalty_term + ½·trace_logdet_i − ½·ld_s_i`.
///
/// `penalty_term` is `a_i` under fixed dispersion, or
/// `dp_cgrad·a_i / profiled_scale` when the Gaussian scale is profiled. The
/// `log|H|` and `log|S|` contributions are included only when the matching
/// flag is set.
#[inline]
pub(crate) fn outer_gradient_entry(
    a_i: f64,
    trace_logdet_i: f64,
    ld_s_i: f64,
    dispersion: &DispersionHandling,
    dp_cgrad: f64,
    profiled_scale: f64,
    incl_logdet_h: bool,
    incl_logdet_s: bool,
) -> f64 {
    let quadratic = match dispersion {
        DispersionHandling::Fixed { .. } => a_i,
        DispersionHandling::ProfiledGaussian => dp_cgrad * a_i / profiled_scale,
    };
    let half_trace = incl_logdet_h.then(|| 0.5 * trace_logdet_i).unwrap_or(0.0);
    let half_ld_s = incl_logdet_s.then(|| 0.5 * ld_s_i).unwrap_or(0.0);
    quadratic + half_trace - half_ld_s
}
/// One entry `(i, j)` of the outer Hessian.
///
/// The quadratic part starts from `pair_a − g_i_dot_v_j`. When the Gaussian
/// scale is profiled, it becomes
/// `c'·q/φ + 2·(c''·ν·φ − c'²)·a_i·a_j / (ν·φ²)`, with `c'`/`c''` the profiled
/// `dp_cgrad`/`dp_cgrad2`. Optionally `½·(cross_trace + h2_trace)` (log|H|) and
/// `−½·pair_ld_s` (log|S|) are added.
#[inline]
pub(crate) fn outer_hessian_entry(
    a_i: f64,
    a_j: f64,
    g_i_dot_v_j: f64,
    pair_a: f64,
    cross_trace: f64,
    h2_trace: f64,
    pair_ld_s: f64,
    profiled_phi: f64,
    profiled_nu: f64,
    profiled_dp_cgrad: f64,
    profiled_dp_cgrad2: f64,
    is_profiled: bool,
    incl_logdet_h: bool,
    incl_logdet_s: bool,
) -> f64 {
    let raw_quadratic = pair_a - g_i_dot_v_j;
    let quadratic = if !is_profiled {
        raw_quadratic
    } else {
        let (phi, nu) = (profiled_phi, profiled_nu);
        let (c1, c2) = (profiled_dp_cgrad, profiled_dp_cgrad2);
        // Evaluation order matches the closed form term by term.
        let first_order = c1 * raw_quadratic / phi;
        let second_order = 2.0 * (c2 * nu * phi - c1 * c1) * a_i * a_j / (nu * phi * phi);
        first_order + second_order
    };
    let logdet_h_part = incl_logdet_h
        .then(|| 0.5 * (cross_trace + h2_trace))
        .unwrap_or(0.0);
    let logdet_s_part = incl_logdet_s.then(|| -0.5 * pair_ld_s).unwrap_or(0.0);
    quadratic + logdet_h_part + logdet_s_part
}
/// Basis for the null space of the active-constraint matrix `a_act`
/// (`k_act × p`). This is the tangent space in which β may move without
/// violating the active constraints.
///
/// The basis columns are the eigenvectors of `a_actᵀ a_act` whose eigenvalues
/// are at or below `positive_eigenvalue_threshold`. The leading `null_count`
/// columns are taken, which assumes `eigh` returns eigenvalues in ascending
/// order — NOTE(review): confirm for the `faer_ndarray` wrapper.
///
/// Returns `None` in these cases:
/// - there are no constraint rows;
/// - the eigendecomposition fails or yields non-contiguous eigenvalues;
/// - there is no null direction (full column rank);
/// - every direction is null.
pub(crate) fn compute_active_constraint_tangent_basis(a_act: &Array2<f64>) -> Option<Array2<f64>> {
    if a_act.nrows() == 0 {
        return None;
    }
    let p = a_act.ncols();
    let gram = a_act.t().dot(a_act);
    let (evals, evecs) = gram.eigh(faer::Side::Lower).ok()?;
    let spectrum = evals.as_slice()?;
    let cutoff = positive_eigenvalue_threshold(spectrum);
    let null_count = spectrum.iter().copied().filter(|&s| s <= cutoff).count();
    let has_proper_null_space = null_count != 0 && null_count != p;
    has_proper_null_space.then(|| evecs.slice(ndarray::s![.., ..null_count]).to_owned())
}
/// Materializes `coord` at unit scale as a dense `p × p` matrix.
///
/// # Panics
/// With "penalty coord dim mismatch" if the materialized matrix is not `p × p`.
pub(crate) fn materialize_penalty_coord_dense(coord: &PenaltyCoordinate, p: usize) -> Array2<f64> {
    let dense = coord.scaled_dense_matrix(1.0);
    let (rows, cols) = dense.dim();
    assert_eq!(rows, p, "penalty coord dim mismatch");
    assert_eq!(cols, p, "penalty coord dim mismatch");
    dense
}
/// Reassembles a dense `p × p` matrix `V·diag(σ)·Vᵀ` from `op`'s eigenvectors.
///
/// For each active eigenpair with regularized eigenvalue `r ≠ 0`,
/// `σ = r − ε²/r` with `ε = √(machine ε)·max(p, 1)`. Inactive pairs and `r = 0`
/// give `σ = 0`. This presumably undoes the regularization applied when `op`
/// was built — NOTE(review): confirm `ε` matches that construction. The final
/// product uses `fast_abt` (presumably `A·Bᵀ`).
pub(crate) fn assemble_h_raw_dense(op: &DenseSpectralOperator) -> Array2<f64> {
    let p = op.n_dim;
    if p == 0 {
        return Array2::<f64>::zeros((0, 0));
    }
    let eps = f64::EPSILON.sqrt() * (p as f64).max(1.0);
    let eps_sq = eps * eps;
    let raw_eigenvalue = |j: usize| -> f64 {
        if !op.active_mask[j] {
            return 0.0;
        }
        let reg = op.reg_eigenvalues[j];
        if reg == 0.0 {
            0.0
        } else {
            reg - eps_sq / reg
        }
    };

    let mut scaled = op.eigenvectors.clone();
    for j in 0..p {
        match raw_eigenvalue(j) {
            // Unit scaling leaves the column untouched.
            s if s == 1.0 => {}
            // Exact zero fill (not `v * 0.0`, which would keep NaN/-0.0).
            s if s == 0.0 => scaled.column_mut(j).fill(0.0),
            s => scaled.column_mut(j).mapv_inplace(|v| v * s),
        }
    }
    crate::faer_ndarray::fast_abt(&scaled, &op.eigenvectors)
}
/// Hessian operator restricted to the column space of a tangent basis `z`
/// (`p × m`), while still acting on `p`-dimensional vectors and matrices.
///
/// `h_t_op` holds the projected Hessian `zᵀ H z`. Solves map through it as
/// `z · h_t⁻¹ · zᵀ`. Traces project their argument to `zᵀ A z` first. `z` is
/// expected to have orthonormal columns, as produced from `eigh` eigenvectors
/// by `compute_active_constraint_tangent_basis`.
pub(crate) struct TangentProjectedHessianOperator {
    pub(crate) z: Array2<f64>,
    pub(crate) h_t_op: DenseSpectralOperator,
}

impl TangentProjectedHessianOperator {
    /// Projects a `p × p` matrix into tangent coordinates: `zᵀ · a · z`.
    fn to_tangent(&self, a: &Array2<f64>) -> Array2<f64> {
        self.z.t().dot(a).dot(&self.z)
    }
}

impl HessianOperator for TangentProjectedHessianOperator {
    fn dim(&self) -> usize {
        self.z.nrows()
    }

    fn active_rank(&self) -> usize {
        self.h_t_op.active_rank()
    }

    fn logdet(&self) -> f64 {
        self.h_t_op.logdet()
    }

    fn is_dense(&self) -> bool {
        self.h_t_op.is_dense()
    }

    fn logdet_traces_match_hinv_kernel(&self) -> bool {
        self.h_t_op.logdet_traces_match_hinv_kernel()
    }

    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
        let tangent_rhs = self.z.t().dot(rhs);
        let tangent_sol = self.h_t_op.solve(&tangent_rhs);
        self.z.dot(&tangent_sol)
    }

    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
        let tangent_rhs = self.z.t().dot(rhs);
        let tangent_sol = self.h_t_op.solve_multi(&tangent_rhs);
        self.z.dot(&tangent_sol)
    }

    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
        self.h_t_op.trace_hinv_product(&self.to_tangent(a))
    }

    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
        self.h_t_op.trace_logdet_gradient(&self.to_tangent(a))
    }
}
/// Pseudo-log-determinant of the tangent-projected total penalty, with its
/// first and second derivatives with respect to `ρ_k = ln λ_k`.
///
/// Forms `M_k = zᵀ S_k z`, with each `S_k` materialized densely at `p × p`, and
/// `S_T = Σ_k λ_k M_k`. It then eigendecomposes `S_T` and, with `S_T⁺` the
/// pseudo-inverse over eigenvalues above `positive_eigenvalue_threshold`,
/// returns:
/// - `value`: `exact_pseudo_logdet` of the spectrum at that threshold;
/// - `first[k] = λ_k · tr(S_T⁺ M_k)`;
/// - `second[k][l] = δ_kl · first[k] − λ_k λ_l · tr(S_T⁺ M_k S_T⁺ M_l)`.
///
/// Only the first `lambdas.len()` entries of `penalty_coords` are used.
///
/// # Errors
/// - fewer penalty coordinates than smoothing parameters;
/// - the eigendecomposition of `S_T` fails or returns non-contiguous eigenvalues.
///
/// # Panics
/// If a used penalty coordinate does not materialize to `p × p`, or `z` does not
/// have `p` rows (ndarray `dot` shape mismatch).
pub(crate) fn tangent_penalty_logdet(
    z: &Array2<f64>,
    penalty_coords: &[PenaltyCoordinate],
    lambdas: &[f64],
    p: usize,
) -> Result<PenaltyLogdetDerivs, String> {
    let m = z.ncols();
    let k = lambdas.len();
    // Every smoothing parameter needs a penalty coordinate. Previously a short
    // `penalty_coords` caused an index panic below, even though this function
    // reports failures through `Result`.
    if penalty_coords.len() < k {
        return Err(format!(
            "tangent penalty logdet: {} smoothing parameters but only {} penalty coordinates",
            k,
            penalty_coords.len()
        ));
    }
    // M_k = zᵀ S_k z for each penalty that is actually weighted by a lambda.
    let zsz: Vec<Array2<f64>> = penalty_coords
        .iter()
        .take(k)
        .map(|c| {
            let s_k_full = materialize_penalty_coord_dense(c, p);
            z.t().dot(&s_k_full).dot(z)
        })
        .collect();
    // S_T = Σ_k λ_k M_k.
    let mut s_t = Array2::<f64>::zeros((m, m));
    for (&lambda, m_k) in lambdas.iter().zip(&zsz) {
        s_t.scaled_add(lambda, m_k);
    }
    let (evals, evecs) = s_t
        .eigh(faer::Side::Lower)
        .map_err(|e| format!("tangent S eigendecomposition failed: {e}"))?;
    let evals_slice = evals.as_slice().ok_or_else(|| {
        "tangent S eigendecomposition returned non-contiguous eigenvalues".to_string()
    })?;
    let threshold = positive_eigenvalue_threshold(evals_slice);
    let value = exact_pseudo_logdet(evals_slice, threshold);
    // S_T⁺ = Σ_{ev_j > threshold} u_j u_jᵀ / ev_j (NaN eigenvalues are excluded).
    let mut s_t_plus = Array2::<f64>::zeros((m, m));
    for (j, &ev) in evals_slice.iter().enumerate() {
        if ev > threshold {
            let inv = 1.0 / ev;
            for r in 0..m {
                let factor = evecs[[r, j]] * inv;
                for c in 0..m {
                    s_t_plus[[r, c]] += factor * evecs[[c, j]];
                }
            }
        }
    }
    let mut first = Array1::<f64>::zeros(k);
    for (k_idx, m_k) in zsz.iter().enumerate() {
        first[k_idx] = lambdas[k_idx] * trace_matrix_product(&s_t_plus, m_k);
    }
    // d²/dρ_k dρ_l: the diagonal picks up the first derivative (from
    // d λ_k / d ρ_k = λ_k); the cross term comes from differentiating S_T⁺.
    let mut second = Array2::<f64>::zeros((k, k));
    for k_idx in 0..k {
        second[[k_idx, k_idx]] += first[k_idx];
    }
    let s_plus_zsz: Vec<Array2<f64>> = zsz.iter().map(|m_k| s_t_plus.dot(m_k)).collect();
    for k_idx in 0..k {
        for l_idx in 0..=k_idx {
            let cross = trace_matrix_product(&s_plus_zsz[k_idx], &s_plus_zsz[l_idx]);
            let entry = -lambdas[k_idx] * lambdas[l_idx] * cross;
            second[[k_idx, l_idx]] += entry;
            if l_idx != k_idx {
                second[[l_idx, k_idx]] += entry;
            }
        }
    }
    Ok(PenaltyLogdetDerivs {
        value,
        first,
        second: Some(second),
    })
}
/// Non-owning adapter that forwards [`HessianDerivativeProvider`] calls to a
/// borrowed provider.
///
/// It lets a derived `InnerSolution`, which stores a boxed provider, reuse the
/// parent solution's provider without moving or cloning it. See the
/// `Box::new(BorrowedDerivProvider(..))` constructions in
/// `try_tangent_projected_evaluate`.
///
/// NOTE(review): any trait method with a default body that is not overridden
/// here runs that default instead of the inner provider's override. Keep this
/// list in sync with the trait.
pub(crate) struct BorrowedDerivProvider<'a>(&'a dyn HessianDerivativeProvider);

impl<'a> BorrowedDerivProvider<'a> {
    /// The provider every call is forwarded to.
    #[inline]
    fn target(&self) -> &'a dyn HessianDerivativeProvider {
        self.0
    }
}

impl<'a> HessianDerivativeProvider for BorrowedDerivProvider<'a> {
    fn has_corrections(&self) -> bool {
        self.target().has_corrections()
    }

    fn has_batched_hessian_derivative_corrections(&self) -> bool {
        self.target().has_batched_hessian_derivative_corrections()
    }

    fn has_batched_hessian_second_derivative_corrections(&self) -> bool {
        self.target()
            .has_batched_hessian_second_derivative_corrections()
    }

    fn hessian_derivative_correction(
        &self,
        v: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        self.target().hessian_derivative_correction(v)
    }

    fn hessian_derivative_correction_result(
        &self,
        v: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        self.target().hessian_derivative_correction_result(v)
    }

    fn hessian_derivative_corrections_result(
        &self,
        vs: &[Array1<f64>],
    ) -> Result<Vec<Option<DriftDerivResult>>, String> {
        self.target().hessian_derivative_corrections_result(vs)
    }

    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        self.target()
            .hessian_second_derivative_correction(v_k, v_l, u_kl)
    }

    fn hessian_second_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        self.target()
            .hessian_second_derivative_correction_result(v_k, v_l, u_kl)
    }

    fn hessian_second_derivative_corrections_result(
        &self,
        triples: &[(Array1<f64>, Array1<f64>, Array1<f64>)],
    ) -> Result<Vec<Option<DriftDerivResult>>, String> {
        self.target()
            .hessian_second_derivative_corrections_result(triples)
    }

    fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
        self.target().outer_hessian_derivative_kernel()
    }

    fn family_outer_hessian_operator(
        &self,
    ) -> Option<Arc<dyn crate::solver::rho_optimizer::OuterHessianOperator>> {
        self.target().family_outer_hessian_operator()
    }

    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        self.target().scalar_glm_ingredients()
    }
}
/// Evaluates the outer REML/LAML objective on the tangent space of the active
/// linear constraints, if any constraint rows are active.
///
/// Returns:
/// - `Ok(None)` when `solution` has no active constraint rows. The caller
///   evaluates the objective itself.
/// - `Ok(Some(..))` with the result of `reml_laml_evaluate` on one of:
///   * a "frozen" copy of `solution` with `active_constraints: None`, when
///     `compute_active_constraint_tangent_basis` finds no usable tangent basis;
///   * a projected copy. Its Hessian operator applies `zᵀ H z` (wrapped to act
///     in `p`-space) and its penalty log-determinant is computed on `zᵀ S z`.
///     Operator-backed Firth terms are re-evaluated on `z`, and the Hessian
///     log-det correction is rescaled by `m / p`.
///
/// # Errors
/// - `active_constraints.a` has a column count different from `β`'s length;
/// - `mode == ValueGradientHessian` with ext coordinates and ext pair callbacks
///   (their p-space second-drift objects are not re-projected). This is checked
///   before the expensive projection work, so when several failures coincide
///   this error is reported first;
/// - assembling the dense Hessian, or eigendecomposing the projected Hessian or
///   penalty, fails;
/// - any error from `reml_laml_evaluate`.
pub(crate) fn try_tangent_projected_evaluate(
    solution: &InnerSolution<'_>,
    rho: &[f64],
    mode: EvalMode,
    prior_cost_gradient: Option<(f64, Array1<f64>, Option<Array2<f64>>)>,
) -> Result<Option<RemlLamlResult>, String> {
    let block = match solution.active_constraints.as_ref() {
        Some(b) if b.a.nrows() > 0 => b,
        _ => return Ok(None),
    };
    let p = solution.beta.len();
    if block.a.ncols() != p {
        return Err(format!(
            "active_constraints.a has {} columns but β is {}-dim",
            block.a.ncols(),
            p
        ));
    }
    let z = match compute_active_constraint_tangent_basis(&block.a) {
        Some(z) => z,
        None => {
            // No proper tangent subspace: evaluate the full-space objective
            // with the constraints dropped.
            // NOTE(review): the ext pair callbacks, `contracted_psi_second_order`
            // and `fixed_drift_deriv` are not carried over here, even though this
            // path keeps the full p-space Hessian. Confirm downstream Hessian
            // assembly tolerates their absence.
            let frozen = InnerSolution {
                log_likelihood: solution.log_likelihood,
                penalty_quadratic: solution.penalty_quadratic,
                hessian_op: Arc::clone(&solution.hessian_op),
                beta: solution.beta.clone(),
                penalty_coords: solution.penalty_coords.clone(),
                penalty_logdet: solution.penalty_logdet.clone(),
                deriv_provider: Box::new(BorrowedDerivProvider(solution.deriv_provider.as_ref())),
                firth: solution.firth.clone(),
                hessian_logdet_correction: solution.hessian_logdet_correction,
                penalty_subspace_trace: solution.penalty_subspace_trace.clone(),
                rho_curvature_scale: solution.rho_curvature_scale,
                rho_prior: solution.rho_prior.clone(),
                n_observations: solution.n_observations,
                nullspace_dim: solution.nullspace_dim,
                gaussian_weight_log_sum_half: solution.gaussian_weight_log_sum_half,
                dp_floor_scale: solution.dp_floor_scale,
                dispersion: solution.dispersion.clone(),
                ext_coords: solution.ext_coords.clone(),
                ext_coord_pair_fn: None,
                rho_ext_pair_fn: None,
                contracted_psi_second_order: None,
                fixed_drift_deriv: None,
                barrier_config: solution.barrier_config.clone(),
                kkt_residual: None,
                active_constraints: None,
                stochastic_trace_state: solution.stochastic_trace_state.clone(),
            };
            let result = reml_laml_evaluate(&frozen, rho, mode, prior_cost_gradient)?;
            return Ok(Some(result));
        }
    };
    // Reject the unsupported combination before the O(p³) dense assembly and
    // eigendecompositions below, whose results would otherwise be discarded.
    if mode == EvalMode::ValueGradientHessian
        && !solution.ext_coords.is_empty()
        && (solution.ext_coord_pair_fn.is_some() || solution.rho_ext_pair_fn.is_some())
    {
        return Err(
            "active constraints + ext_coords + mode=ValueGradientHessian not yet supported; \
             fall back to ValueAndGradient. The ext-coord pair callbacks return p-space \
             second-drift objects that the tangent hessian wrapper does not re-project."
                .to_string(),
        );
    }
    let h_full = solution
        .hessian_op
        .assemble_h_dense_for_tangent_projection()?;
    let h_t = z.t().dot(&h_full).dot(&z);
    let h_t_op = DenseSpectralOperator::from_symmetric(&h_t)
        .map_err(|e| format!("tangent H eigendecomposition failed: {e}"))?;
    let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
    let projected_logdet = tangent_penalty_logdet(&z, &solution.penalty_coords, &lambdas, p)?;
    let projected_kkt = solution.kkt_residual.clone();
    let m_tangent = z.ncols();
    let wrapper = TangentProjectedHessianOperator {
        z: z.clone(),
        h_t_op,
    };
    // Rescale the Hessian log-det correction by the tangent-to-full dimension
    // ratio (the `p == 0` guard avoids a 0/0).
    let projected_hlogdet_correction = if p == 0 {
        0.0
    } else {
        solution.hessian_logdet_correction * (m_tangent as f64 / p as f64)
    };
    // Operator-backed Firth terms are re-evaluated on the tangent basis; other
    // Firth terms are carried over unchanged.
    let projected_firth = solution
        .firth
        .as_ref()
        .map(|term| match term.operator_arc() {
            Some(op_arc) => {
                let projected_value = op_arc.jeffreys_logdet_projected(z.view());
                ExactJeffreysTerm::with_projected_value(op_arc, projected_value)
            }
            None => term.clone(),
        });
    // NOTE(review): `nullspace_dim` is passed through unprojected; for profiled
    // dispersion the tangent penalty null-space dimension may differ — confirm.
    let projected = InnerSolution {
        log_likelihood: solution.log_likelihood,
        penalty_quadratic: solution.penalty_quadratic,
        hessian_op: Arc::new(wrapper),
        beta: solution.beta.clone(),
        penalty_coords: solution.penalty_coords.clone(),
        penalty_logdet: projected_logdet,
        deriv_provider: Box::new(BorrowedDerivProvider(solution.deriv_provider.as_ref())),
        firth: projected_firth,
        hessian_logdet_correction: projected_hlogdet_correction,
        penalty_subspace_trace: None,
        rho_curvature_scale: solution.rho_curvature_scale,
        rho_prior: solution.rho_prior.clone(),
        n_observations: solution.n_observations,
        nullspace_dim: solution.nullspace_dim,
        gaussian_weight_log_sum_half: solution.gaussian_weight_log_sum_half,
        dp_floor_scale: solution.dp_floor_scale,
        dispersion: solution.dispersion.clone(),
        ext_coords: solution.ext_coords.clone(),
        ext_coord_pair_fn: None,
        rho_ext_pair_fn: None,
        contracted_psi_second_order: None,
        fixed_drift_deriv: None,
        barrier_config: solution.barrier_config.clone(),
        kkt_residual: projected_kkt,
        active_constraints: None,
        stochastic_trace_state: solution.stochastic_trace_state.clone(),
    };
    let result = reml_laml_evaluate(&projected, rho, mode, prior_cost_gradient)?;
    Ok(Some(result))
}