use super::*;
pub(crate) fn compute_adjoint_z_c(
ing: &ScalarGlmIngredients<'_>,
hop: &dyn HessianOperator,
leverage: &Array1<f64>,
subspace: Option<&PenaltySubspaceTrace>,
) -> Result<Array1<f64>, String> {
let mut weighted = Array1::<f64>::zeros(ing.c_array.len());
Zip::from(&mut weighted)
.and(ing.c_array)
.and(leverage)
.for_each(|w, &c, &h| *w = c * h);
let v = ing.x.transpose_vector_multiply(&weighted);
match subspace {
Some(kernel) => Ok(kernel.apply_pseudo_inverse(&v)),
None => Ok(hop.solve(&v)),
}
}
pub(crate) fn compute_fourth_derivative_trace(
ing: &ScalarGlmIngredients<'_>,
v_k: &Array1<f64>,
v_l: &Array1<f64>,
leverage: &Array1<f64>,
) -> Result<Option<f64>, String> {
let Some(d_array) = ing.d_array else {
return Ok(None);
};
let x_vk = ing.x.matrixvectormultiply(v_k);
let x_vl = ing.x.matrixvectormultiply(v_l);
let mut acc = 0.0;
Zip::from(d_array)
.and(&x_vk)
.and(&x_vl)
.and(leverage)
.for_each(|&d, &xvk, &xvl, &h| acc += d * xvk * xvl * h);
Ok(Some(acc))
}
/// Batched form of the fourth-derivative trace: evaluates all `modes × modes`
/// entries at once.
///
/// Every mode is mapped through `ing.x.matrixvectormultiply` into one column of
/// an `n × t` matrix. A copy of that matrix has row `i` scaled by
/// `d_array[i] * leverage[i]`, and the result is
/// `fast_atb(columns, scaled_columns)`.
///
/// Returns `Ok(None)` when `ing.d_array` is absent and an empty `0 × 0` matrix
/// when `modes` is empty.
///
/// # Errors
/// Returns `RemlError::DimensionMismatch` (converted into `String`) when
/// `d_array` or `leverage` disagree with `ing.c_array` in length, or when any
/// `X v` product does not have that length.
pub(crate) fn compute_fourth_derivative_trace_matrix(
    ing: &ScalarGlmIngredients<'_>,
    modes: &[&Array1<f64>],
    leverage: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    let d_array = match ing.d_array {
        Some(values) => values,
        None => return Ok(None),
    };
    let n = ing.c_array.len();
    let mode_count = modes.len();
    if mode_count == 0 {
        return Ok(Some(Array2::zeros((0, 0))));
    }

    let lengths_agree = d_array.len() == n && leverage.len() == n;
    if !lengths_agree {
        let reason = format!(
            "fourth-derivative trace shape mismatch: c={}, d={}, leverage={}",
            n,
            d_array.len(),
            leverage.len()
        );
        return Err(RemlError::DimensionMismatch { reason }.into());
    }

    // Column j holds X * modes[j].
    let mut projected = Array2::<f64>::zeros((n, mode_count));
    for (j, mode) in modes.iter().enumerate() {
        let x_v = ing.x.matrixvectormultiply(mode);
        if x_v.len() != n {
            let reason = format!(
                "fourth-derivative trace Xv length mismatch for mode {j}: got {}, expected {n}",
                x_v.len()
            );
            return Err(RemlError::DimensionMismatch { reason }.into());
        }
        projected.column_mut(j).assign(&x_v);
    }

    // Scale each observation row by its combined d * leverage weight.
    let mut scaled = projected.clone();
    Zip::from(scaled.rows_mut())
        .and(d_array)
        .and(leverage)
        .for_each(|mut row, &d, &h| {
            let factor = d * h;
            row.mapv_inplace(|entry| entry * factor);
        });

    let traces = crate::faer_ndarray::fast_atb(&projected, &scaled);
    Ok(Some(traces))
}
/// Evaluates the trace contribution of the Hessian-derivative correction for one
/// coordinate pair.
///
/// * Returns `0.0` immediately when `effective_deriv` reports no corrections.
/// * When an adjoint vector `adjoint_z_c` is available and no penalty `subspace`
///   is in use, the result is `rhs · z_c` plus a fourth-derivative term. That term
///   is `precomputed_fourth_trace` if given; otherwise it is computed from
///   `glm_ingredients` and `leverage` when both are present, and is `0.0`
///   otherwise (including when the computation yields `None`).
/// * In every other case the correction is obtained from
///   `hessian_second_derivative_correction_result(v_i, v_j, hop.solve(rhs))` and
///   traced through the subspace kernel when present, else through `hop`.
///   A missing correction contributes `0.0`.
///
/// # Errors
/// Propagates errors from `compute_fourth_derivative_trace` and from
/// `hessian_second_derivative_correction_result`.
pub(crate) fn compute_ift_correction_trace(
    hop: &dyn HessianOperator,
    rhs: &Array1<f64>,
    v_i: &Array1<f64>,
    v_j: &Array1<f64>,
    effective_deriv: &dyn HessianDerivativeProvider,
    adjoint_z_c: Option<&Array1<f64>>,
    glm_ingredients: Option<&ScalarGlmIngredients<'_>>,
    leverage: Option<&Array1<f64>>,
    precomputed_fourth_trace: Option<f64>,
    subspace: Option<&PenaltySubspaceTrace>,
) -> Result<f64, String> {
    if !effective_deriv.has_corrections() {
        return Ok(0.0);
    }

    // Adjoint shortcut: only valid when no penalty subspace is involved.
    if let (Some(z_c), None) = (adjoint_z_c, subspace) {
        let c_trace = rhs.dot(z_c);
        let d_trace = match (precomputed_fourth_trace, glm_ingredients, leverage) {
            (Some(trace), _, _) => trace,
            (None, Some(ing), Some(h_g)) => {
                compute_fourth_derivative_trace(ing, v_i, v_j, h_g)?.unwrap_or(0.0)
            }
            (None, _, _) => 0.0,
        };
        return Ok(c_trace + d_trace);
    }

    // General path: solve for the mode response, then trace the correction.
    let u = hop.solve(rhs);
    let Some(correction) =
        effective_deriv.hessian_second_derivative_correction_result(v_i, v_j, &u)?
    else {
        return Ok(0.0);
    };

    let trace = match subspace {
        Some(kernel) => match correction {
            DriftDerivResult::Dense(matrix) => kernel.trace_projected_logdet(&matrix),
            DriftDerivResult::Operator(op) => kernel.trace_operator(op.as_ref()),
        },
        None => correction.trace_logdet(hop),
    };
    Ok(trace)
}
/// Sums the trace contributions of the fixed-drift derivative for both orderings
/// of a coordinate pair.
///
/// For side `i` the drift callback is evaluated at `(ext_i, beta_j)`, and for
/// side `j` at `(ext_j, beta_i)`. A side contributes only when its `*_depends`
/// flag is set, its external index is present, a callback was supplied, and the
/// callback returns a result. Each result is traced through the subspace kernel
/// when `subspace` is given, otherwise through `hop`.
///
/// Returns `0.0` when nothing contributes.
pub(crate) fn compute_drift_deriv_traces(
    hop: &dyn HessianOperator,
    b_i_depends: bool,
    b_j_depends: bool,
    ext_i: Option<usize>,
    ext_j: Option<usize>,
    beta_i: &Array1<f64>,
    beta_j: &Array1<f64>,
    fixed_drift_deriv: Option<&FixedDriftDerivFn>,
    subspace: Option<&PenaltySubspaceTrace>,
) -> f64 {
    // Without a callback neither side can contribute.
    let Some(drift_fn) = fixed_drift_deriv else {
        return 0.0;
    };

    let trace_of = |result: DriftDerivResult| -> f64 {
        match (subspace, result) {
            (Some(kernel), DriftDerivResult::Dense(matrix)) => {
                kernel.trace_projected_logdet(&matrix)
            }
            (Some(kernel), DriftDerivResult::Operator(op)) => kernel.trace_operator(op.as_ref()),
            (None, DriftDerivResult::Dense(matrix)) => hop.trace_logdet_gradient(&matrix),
            (None, DriftDerivResult::Operator(op)) => hop.trace_logdet_operator(op.as_ref()),
        }
    };

    // Side i pairs with beta_j and vice versa; the i-side is evaluated first.
    let sides = [
        (b_i_depends, ext_i, beta_j),
        (b_j_depends, ext_j, beta_i),
    ];

    let mut total = 0.0;
    for (depends, ext, beta) in sides {
        if !depends {
            continue;
        }
        let Some(coord) = ext else {
            continue;
        };
        if let Some(result) = drift_fn(coord, beta) {
            total += trace_of(result);
        }
    }
    total
}
/// Traces a single second-order base term given either as an operator or as a
/// dense matrix.
///
/// `b_operator` takes precedence over `b_mat` when both are available. A dense
/// `b_mat` with zero rows and no operator yields `0.0`. The trace goes through
/// the penalty-subspace kernel when `subspace` is supplied, otherwise through
/// `hop`.
pub(crate) fn compute_base_h2_trace(
    hop: &dyn HessianOperator,
    b_mat: &Array2<f64>,
    b_operator: Option<&dyn HyperOperator>,
    subspace: Option<&PenaltySubspaceTrace>,
) -> f64 {
    // Nothing to trace: no operator and an empty dense matrix.
    if b_operator.is_none() && b_mat.nrows() == 0 {
        return 0.0;
    }

    match (subspace, b_operator) {
        (Some(kernel), Some(op)) => kernel.trace_operator(op),
        (Some(kernel), None) => kernel.trace_projected_logdet(b_mat),
        (None, Some(op)) => hop.trace_logdet_operator(op),
        (None, None) => hop.trace_logdet_gradient(b_mat),
    }
}
/// Computes the base second-order trace for every pair in `pairs`, choosing a
/// batched strategy where one is available.
///
/// The output has one entry per pair, in input order. Strategies are tried in
/// this order:
/// 1. **Penalty subspace** (`subspace` is `Some`): dense terms go through
///    `dense_trace_projected_factor`; operator terms are gathered by
///    `collect_projected_trace_terms` and evaluated in one batched call.
/// 2. **Exact dense spectral operator** (`hop.as_exact_dense_spectral()`): dense
///    terms use `hop.trace_logdet_gradient`; operator terms are batched against
///    the spectral factor and its cache.
/// 3. **Stochastic estimation** (when `hop` both prefers it and reports that
///    log-det traces match the inverse-Hessian kernel): all non-empty terms are
///    handed to a `StochasticTraceEstimator`, sharing `trace_state` if provided.
/// 4. **Fallback**: each pair is traced individually via `compute_base_h2_trace`.
///
/// Pairs with neither an operator nor a non-empty dense matrix contribute `0.0`.
pub(crate) fn compute_base_h2_traces(
    hop: &dyn HessianOperator,
    pairs: &[&HyperCoordPair],
    subspace: Option<&PenaltySubspaceTrace>,
    trace_state: Option<Arc<Mutex<StochasticTraceState>>>,
) -> Vec<f64> {
    let pair_count = pairs.len();
    if pair_count == 0 {
        return Vec::new();
    }

    // Strategy 1: projected traces inside the penalty subspace.
    if let Some(kernel) = subspace {
        let factor = penalty_subspace_trace_factor(kernel);
        let cache = ProjectedFactorCache::default();
        let mut traces = vec![0.0_f64; pair_count];
        let mut pending_ops: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();
        for (slot, pair) in pairs.iter().enumerate() {
            match pair.b_operator.as_deref() {
                Some(op) => {
                    collect_projected_trace_terms(
                        slot,
                        1.0,
                        op,
                        &factor,
                        &mut traces,
                        &mut pending_ops,
                    );
                }
                None if pair.b_mat.nrows() > 0 => {
                    traces[slot] = dense_trace_projected_factor(&pair.b_mat, &factor);
                }
                None => {}
            }
        }
        if !pending_ops.is_empty() {
            let batched =
                trace_projected_operator_terms_batched(pair_count, &pending_ops, &factor, &cache);
            for (slot, contribution) in batched.into_iter().enumerate() {
                traces[slot] += contribution;
            }
        }
        return traces;
    }

    // From here on `subspace` is known to be `None`.

    // Strategy 2: exact dense spectral operator with batched operator terms.
    if let Some(ds) = hop.as_exact_dense_spectral() {
        let mut traces = vec![0.0_f64; pair_count];
        let mut pending_ops: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();
        for (slot, pair) in pairs.iter().enumerate() {
            match pair.b_operator.as_deref() {
                Some(op) => pending_ops.push((slot, 1.0, op)),
                None if pair.b_mat.nrows() > 0 => {
                    traces[slot] = hop.trace_logdet_gradient(&pair.b_mat);
                }
                None => {}
            }
        }
        if !pending_ops.is_empty() {
            let batched = trace_projected_operator_terms_batched(
                pair_count,
                &pending_ops,
                &ds.g_factor,
                &ds.projected_factor_cache,
            );
            for (slot, contribution) in batched.into_iter().enumerate() {
                traces[slot] += contribution;
            }
        }
        return traces;
    }

    // Strategy 3: stochastic trace estimation over all non-empty terms at once.
    if hop.prefers_stochastic_trace_estimation() && hop.logdet_traces_match_hinv_kernel() {
        let mut traces = vec![0.0; pair_count];
        let mut dense_mats: Vec<&Array2<f64>> = Vec::new();
        let mut dense_targets = Vec::new();
        let mut operators: Vec<&dyn HyperOperator> = Vec::new();
        let mut operator_targets = Vec::new();
        for (slot, pair) in pairs.iter().enumerate() {
            match pair.b_operator.as_deref() {
                Some(op) => {
                    operator_targets.push(slot);
                    operators.push(op);
                }
                None if pair.b_mat.nrows() > 0 => {
                    dense_targets.push(slot);
                    dense_mats.push(&pair.b_mat);
                }
                None => {}
            }
        }
        let has_work = !dense_mats.is_empty() || !operators.is_empty();
        if has_work {
            let estimator = match trace_state {
                Some(state) => StochasticTraceEstimator::with_shared_trace_state(
                    StochasticTraceConfig::default(),
                    state,
                ),
                None => StochasticTraceEstimator::with_defaults(),
            };
            let estimates = estimator.estimate_traces_with_operators(hop, &dense_mats, &operators);
            // Estimates are laid out dense-first, then operators.
            for (local, &slot) in dense_targets.iter().enumerate() {
                traces[slot] = estimates[local];
            }
            let operator_base = dense_mats.len();
            for (local, &slot) in operator_targets.iter().enumerate() {
                traces[slot] = estimates[operator_base + local];
            }
        }
        return traces;
    }

    // Strategy 4: trace each pair on its own.
    pairs
        .iter()
        .map(|pair| compute_base_h2_trace(hop, &pair.b_mat, pair.b_operator.as_deref(), subspace))
        .collect()
}
/// Cross trace between a dense matrix and a drift derivative, dispatching on how
/// the drift is represented.
///
/// A dense drift is forwarded to `hop.trace_logdet_hessian_cross`; an
/// operator-backed drift goes to `hop.trace_logdet_hessian_cross_matrix_operator`.
pub(crate) fn trace_logdet_hessian_cross_dense_drift(
    hop: &dyn HessianOperator,
    dense: &Array2<f64>,
    drift: &DriftDerivResult,
) -> f64 {
    match drift {
        DriftDerivResult::Operator(drift_op) => {
            hop.trace_logdet_hessian_cross_matrix_operator(dense, drift_op.as_ref())
        }
        DriftDerivResult::Dense(drift_mat) => hop.trace_logdet_hessian_cross(dense, drift_mat),
    }
}
/// Builds the symmetric matrix of pairwise cross traces over all drifts.
///
/// Drifts are ordered as `dense_drifts` followed by `ext_drifts`. Each one is
/// first brought into the eigenbasis of `dense_hop`: dense matrices via
/// `rotate_to_eigenbasis`, operator-backed drifts via one batched call to
/// `projected_operator_terms_batched`. Entry `(i, j)` is then
/// `trace_logdet_hessian_cross_rotated` of the two rotated drifts; only the upper
/// triangle is computed and mirrored.
///
/// # Panics
/// Panics if an external drift ends up without a rotated representation.
pub(crate) fn trace_logdet_hessian_crosses_dense_spectral_drifts(
    dense_hop: &DenseSpectralOperator,
    dense_drifts: &[Array2<f64>],
    ext_drifts: &[DriftDerivResult],
) -> Array2<f64> {
    let ext_count = ext_drifts.len();
    let total = dense_drifts.len() + ext_count;

    let mut rotated: Vec<Array2<f64>> = Vec::with_capacity(total);
    rotated.extend(
        dense_drifts
            .iter()
            .map(|matrix| dense_hop.rotate_to_eigenbasis(matrix)),
    );

    // Dense external drifts are rotated right away; operator drifts are queued
    // so they can be projected together.
    let mut ext_slots: Vec<Option<Array2<f64>>> = vec![None; ext_count];
    let mut queued_ops: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();
    for (slot, drift) in ext_drifts.iter().enumerate() {
        match drift {
            DriftDerivResult::Operator(operator) => {
                queued_ops.push((slot, 1.0, operator.as_ref()));
            }
            DriftDerivResult::Dense(matrix) => {
                ext_slots[slot] = Some(dense_hop.rotate_to_eigenbasis(matrix));
            }
        }
    }

    if !queued_ops.is_empty() {
        let batched = projected_operator_terms_batched(
            ext_count,
            &queued_ops,
            &dense_hop.eigenvectors,
            &dense_hop.projected_factor_cache,
        );
        for &(slot, _, _) in &queued_ops {
            ext_slots[slot] = Some(batched[slot].clone());
        }
    }

    rotated.extend(
        ext_slots
            .into_iter()
            .map(|entry| entry.expect("every ext drift contributes a rotation")),
    );

    // Fill the upper triangle and mirror it; the diagonal write is idempotent.
    let mut crosses = Array2::<f64>::zeros((total, total));
    for (i, left) in rotated.iter().enumerate() {
        for (j, right) in rotated.iter().enumerate().skip(i) {
            let value = dense_hop.trace_logdet_hessian_cross_rotated(left, right);
            crosses[[i, j]] = value;
            crosses[[j, i]] = value;
        }
    }
    crosses
}
/// Decides whether the stochastic log-det / inverse-Hessian kernel may be used.
///
/// All of the following must hold, checked in this order:
/// * `total_p` exceeds `STOCHASTIC_TRACE_DIM_THRESHOLD`;
/// * `hop` prefers stochastic trace estimation;
/// * `hop` reports that log-det traces match the inverse-Hessian kernel;
/// * `incl_logdet_h` is set.
#[inline]
pub(crate) fn can_use_stochastic_logdet_hinv_kernel(
    hop: &dyn HessianOperator,
    total_p: usize,
    incl_logdet_h: bool,
) -> bool {
    if total_p <= STOCHASTIC_TRACE_DIM_THRESHOLD {
        return false;
    }
    if !hop.prefers_stochastic_trace_estimation() {
        return false;
    }
    if !hop.logdet_traces_match_hinv_kernel() {
        return false;
    }
    incl_logdet_h
}