use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Zip};
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::faer_ndarray::FaerEigh;
use crate::linalg::matrix::DesignMatrix;
/// Backend-agnostic interface to a factorized (penalized) Hessian `H`.
///
/// Implementations expose the log-determinant, linear solves against `H`, and
/// the various `tr(H^{-1} …)` trace terms the outer optimizer needs. Default
/// methods route implicit (matrix-free) `HyperOperator` arguments through a
/// stochastic Hutch++ estimator on larger problems instead of materializing
/// dense matrices.
pub trait HessianOperator: Send + Sync {
    /// Log-determinant of `H`.
    fn logdet(&self) -> f64;
    /// Trace term `tr(H^{-1} a)` for a dense matrix `a`.
    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64;
    /// Downcast hook for callers needing exact spectral data; `None` unless
    /// the backend is a dense spectral decomposition.
    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        None
    }
    /// `tr(H^{-1} B)` where `B` is given as an operator.
    ///
    /// Implicit operators on problems of dimension >= 128 go through a
    /// Hutch++ stochastic estimate (probe counts scaled from the dimension);
    /// otherwise the operator is densified, with a warning when the
    /// materialization was avoidable.
    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
        if op.is_implicit() && self.dim() >= 128 {
            let mut config = StochasticTraceConfig::default();
            // Sketch size grows with dimension but is clamped to [4, 16].
            let sketch = (self.dim() / 32).clamp(4, 16);
            config.hutchpp_sketch_dim = Some(sketch);
            config.n_probes_max = (sketch * 4).max(32);
            config.n_probes_min = sketch.max(8);
            return hutchpp_estimate_trace_hinv_operator(self, op, &config);
        }
        if op.is_implicit() {
            log::warn!(
                "trace_hinv_operator: materializing implicit HyperOperator — \
                backend should provide a matrix-free override"
            );
        }
        self.trace_hinv_product(&op.to_dense())
    }
    /// `tr(H^{-1} H_k)` where `H_k` is `a_k` plus an optional dense
    /// third-derivative correction; the trace is linear so the correction
    /// contributes additively.
    fn trace_hinv_h_k(
        &self,
        a_k: &Array2<f64>,
        third_deriv_correction: Option<&Array2<f64>>,
    ) -> f64 {
        let base = self.trace_hinv_product(a_k);
        match third_deriv_correction {
            Some(c) => base + self.trace_hinv_product(c),
            None => base,
        }
    }
    /// Solves `H x = rhs`.
    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64>;
    /// Solves `H X = rhs` column-wise.
    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64>;
    /// Solve used inside stochastic trace probes; backends may provide a
    /// cheaper, tolerance-relaxed solve. Default ignores the tolerance.
    fn stochastic_trace_solve(&self, rhs: &Array1<f64>, _rel_tol: f64) -> Array1<f64> {
        self.solve(rhs)
    }
    /// Multi-RHS variant of `stochastic_trace_solve`.
    fn stochastic_trace_solve_multi(&self, rhs: &Array2<f64>, _rel_tol: f64) -> Array2<f64> {
        self.solve_multi(rhs)
    }
    /// Cross trace via two multi-solves (i.e. `tr(H^{-1} a H^{-1} b)`,
    /// assuming `trace_matrix_product(X, Y) == tr(XY)`; helper defined
    /// elsewhere). When `a` and `b` are the same allocation only one solve
    /// is performed.
    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let solved_a = self.solve_multi(a);
        if std::ptr::eq(a, b) {
            return trace_matrix_product(&solved_a, &solved_a);
        }
        let solved_b = self.solve_multi(b);
        trace_matrix_product(&solved_a, &solved_b)
    }
    /// Cross trace with a dense left factor and an operator right factor.
    /// Implicit right operands on large problems take the Hutch++ path, with
    /// the dense side wrapped so both operands present the operator interface.
    fn trace_hinv_matrix_operator_cross(
        &self,
        matrix: &Array2<f64>,
        op: &dyn HyperOperator,
    ) -> f64 {
        if op.is_implicit() && self.dim() >= 128 {
            let mut config = StochasticTraceConfig::default();
            let sketch = (self.dim() / 32).clamp(4, 16);
            config.hutchpp_sketch_dim = Some(sketch);
            config.n_probes_max = (sketch * 4).max(32);
            config.n_probes_min = sketch.max(8);
            let lhs = DenseMatrixHyperOperator {
                matrix: matrix.clone(),
            };
            return hutchpp_estimate_trace_hinv_operator_cross(self, &lhs, op, &config);
        }
        if op.is_implicit() {
            log::warn!(
                "trace_hinv_matrix_operator_cross: materializing implicit HyperOperator — \
                backend should provide a matrix-free override"
            );
        }
        self.trace_hinv_product_cross(matrix, &op.to_dense())
    }
    /// Cross trace with both sides given as operators. The same trait object
    /// passed twice is detected by data-pointer equality (the cast to
    /// `*const ()` drops the vtable) and routed to the cheaper squared
    /// estimator.
    fn trace_hinv_operator_cross(
        &self,
        left: &dyn HyperOperator,
        right: &dyn HyperOperator,
    ) -> f64 {
        let l_implicit = left.is_implicit();
        let r_implicit = right.is_implicit();
        if (l_implicit || r_implicit) && self.dim() >= 128 {
            let mut config = StochasticTraceConfig::default();
            let sketch = (self.dim() / 32).clamp(4, 16);
            config.hutchpp_sketch_dim = Some(sketch);
            config.n_probes_max = (sketch * 4).max(32);
            config.n_probes_min = sketch.max(8);
            if std::ptr::eq(
                left as *const dyn HyperOperator as *const (),
                right as *const dyn HyperOperator as *const (),
            ) {
                return hutchpp_estimate_trace_hinv_op_squared(self, left, &config);
            }
            return hutchpp_estimate_trace_hinv_operator_cross(self, left, right, &config);
        }
        if l_implicit || r_implicit {
            log::warn!(
                "trace_hinv_operator_cross: materializing implicit HyperOperator(s) — \
                backend should provide a matrix-free override"
            );
        }
        self.trace_hinv_product_cross(&left.to_dense(), &right.to_dense())
    }
    /// Gradient-side log-det trace; defaults to `tr(H^{-1} a)`. Backends whose
    /// log-det kernel differs from `H^{-1}` (see
    /// `logdet_traces_match_hinv_kernel`) are expected to override this.
    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
        self.trace_hinv_product(a)
    }
    /// Per-row diagonal `x_i^T H^{-1} x_i` over the design matrix, computed in
    /// row chunks sized to roughly 64K floats to bound peak memory.
    fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
        // Only valid when the log-det trace kernel coincides with H^{-1}.
        debug_assert!(self.logdet_traces_match_hinv_kernel());
        let n = x.nrows();
        let p = x.ncols();
        let block = {
            const TARGET_CHUNK_FLOATS: usize = 1 << 16;
            (TARGET_CHUNK_FLOATS / p.max(1)).clamp(1, n.max(1))
        };
        let mut h = Array1::<f64>::zeros(n);
        let mut start = 0usize;
        while start < n {
            let end = (start + block).min(n);
            let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
                panic!("xt_logdet_kernel_x_diagonal: row chunk failed: {err}")
            });
            // Solve H Z = X_chunk^T once per chunk, then h_i = x_i · z_i.
            let chunk_t = rows.t().to_owned();
            let z_chunk = self.solve_multi(&chunk_t);
            for i in 0..(end - start) {
                let mut acc = 0.0;
                for j in 0..p {
                    acc += rows[[i, j]] * z_chunk[[j, i]];
                }
                h[start + i] = acc;
            }
            start = end;
        }
        h
    }
    /// Log-det-kernel trace of an operator. The stochastic path additionally
    /// requires the kernel to match `H^{-1}` so Hutch++ probes are valid.
    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
        if op.is_implicit() && self.dim() >= 128 && self.logdet_traces_match_hinv_kernel() {
            let mut config = StochasticTraceConfig::default();
            let sketch = (self.dim() / 32).clamp(4, 16);
            config.hutchpp_sketch_dim = Some(sketch);
            config.n_probes_max = (sketch * 4).max(32);
            config.n_probes_min = sketch.max(8);
            return hutchpp_estimate_trace_hinv_operator(self, op, &config);
        }
        if op.is_implicit() {
            log::warn!(
                "trace_logdet_operator: materializing implicit HyperOperator — \
                backend should provide a matrix-free override"
            );
        }
        self.trace_logdet_gradient(&op.to_dense())
    }
    /// Log-det-kernel analogue of `trace_hinv_h_k`.
    fn trace_logdet_h_k(
        &self,
        a_k: &Array2<f64>,
        third_deriv_correction: Option<&Array2<f64>>,
    ) -> f64 {
        let base = self.trace_logdet_gradient(a_k);
        match third_deriv_correction {
            Some(c) => base + self.trace_logdet_gradient(c),
            None => base,
        }
    }
    /// Same as `trace_logdet_h_k` but with the main term as an operator.
    fn trace_logdet_h_k_operator(
        &self,
        b_k: &dyn HyperOperator,
        third_deriv_correction: Option<&Array2<f64>>,
    ) -> f64 {
        let base = self.trace_logdet_operator(b_k);
        match third_deriv_correction {
            Some(c) => base + self.trace_logdet_gradient(c),
            None => base,
        }
    }
    /// Log-det trace of a scaled block occupying rows/cols `start..end`.
    /// The default embeds the block into a full p×p zero matrix; backends
    /// with block structure can override to avoid the embedding.
    fn trace_logdet_block_local(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let p = self.dim();
        let mut full = Array2::<f64>::zeros((p, p));
        let bs = end - start;
        for i in 0..bs {
            for j in 0..bs {
                full[[start + i, start + j]] = scale * block[[i, j]];
            }
        }
        self.trace_logdet_gradient(&full)
    }
    /// `tr(H^{-1} ·)` version of `trace_logdet_block_local`.
    fn trace_hinv_block_local(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let p = self.dim();
        let mut full = Array2::<f64>::zeros((p, p));
        let bs = end - start;
        for i in 0..bs {
            for j in 0..bs {
                full[[start + i, start + j]] = scale * block[[i, j]];
            }
        }
        self.trace_hinv_product(&full)
    }
    /// Squared cross trace of an embedded scaled block with itself.
    fn trace_hinv_block_local_cross(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let p = self.dim();
        let bs = end - start;
        let mut full = Array2::<f64>::zeros((p, p));
        for i in 0..bs {
            for j in 0..bs {
                full[[start + i, start + j]] = scale * block[[i, j]];
            }
        }
        self.trace_hinv_product_cross(&full, &full)
    }
    /// Second-derivative cross term `-tr(H^{-1} h_i H^{-1} h_j)`; identical
    /// inputs (same allocation) reuse a single solve. The argument order in
    /// the general branch relies on the cyclic symmetry of the trace.
    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
        let y_i = self.solve_multi(h_i);
        if std::ptr::eq(h_i, h_j) {
            return -trace_matrix_product(&y_i, &y_i);
        }
        let y_j = self.solve_multi(h_j);
        -trace_matrix_product(&y_j, &y_i)
    }
    /// Mixed dense/operator form; default densifies the operator side.
    fn trace_logdet_hessian_cross_matrix_operator(
        &self,
        h_i: &Array2<f64>,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        self.trace_logdet_hessian_cross(h_i, &h_j.to_dense())
    }
    /// Operator/operator form; default densifies both sides.
    fn trace_logdet_hessian_cross_operator(
        &self,
        h_i: &dyn HyperOperator,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        self.trace_logdet_hessian_cross(&h_i.to_dense(), &h_j.to_dense())
    }
    /// Full symmetric matrix of pairwise cross terms; only the upper triangle
    /// is computed and then mirrored.
    fn trace_logdet_hessian_crosses(&self, matrices: &[&Array2<f64>]) -> Array2<f64> {
        let n = matrices.len();
        let mut out = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            for j in i..n {
                let value = self.trace_logdet_hessian_cross(matrices[i], matrices[j]);
                out[[i, j]] = value;
                out[[j, i]] = value;
            }
        }
        out
    }
    /// Effective (numerical) rank of the factorization.
    fn active_rank(&self) -> usize;
    /// Dimension `p` of the Hessian.
    fn dim(&self) -> usize;
    /// Whether this backend stores a dense factorization.
    fn is_dense(&self) -> bool {
        false
    }
    /// Hint to callers that stochastic trace estimation pays off here.
    fn prefers_stochastic_trace_estimation(&self) -> bool {
        self.is_dense()
    }
    /// True when the log-det trace kernel is exactly `H^{-1}` (the default);
    /// the stochastic shortcuts above are only taken in that case.
    fn logdet_traces_match_hinv_kernel(&self) -> bool {
        true
    }
    /// Optional downcast to a dense spectral backend.
    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        None
    }
}
/// Supplies derivatives of the Hessian with respect to smoothing directions
/// (`v_k`), used for exact gradients/Hessians of the outer criterion.
pub trait HessianDerivativeProvider: Send + Sync {
    /// First-order drift correction for direction `v_k`; `Ok(None)` means no
    /// correction is needed for this family.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String>;
    /// Same as `hessian_derivative_correction` but allows an operator-valued
    /// (matrix-free) result; the default wraps the dense result.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        Ok(self
            .hessian_derivative_correction(v_k)?
            .map(DriftDerivResult::Dense))
    }
    /// Batched variant; the default just maps the per-direction method.
    fn hessian_derivative_corrections_result(
        &self,
        v_ks: &[Array1<f64>],
    ) -> Result<Vec<Option<DriftDerivResult>>, String> {
        v_ks.iter()
            .map(|v_k| self.hessian_derivative_correction_result(v_k))
            .collect()
    }
    /// Whether the batched method is genuinely faster than the per-direction
    /// loop (the default batched impl is not).
    fn has_batched_hessian_derivative_corrections(&self) -> bool {
        false
    }
    /// Second-order correction for directions `v_k`, `v_l` with mixed
    /// direction `u_kl`; the default assumes none.
    fn hessian_second_derivative_correction(
        &self,
        _: &Array1<f64>,
        _: &Array1<f64>,
        _: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        Ok(None)
    }
    /// Operator-capable wrapper over the second-order correction.
    fn hessian_second_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        Ok(self
            .hessian_second_derivative_correction(v_k, v_l, u_kl)?
            .map(DriftDerivResult::Dense))
    }
    /// Fast predicate: does this provider ever return `Some` corrections?
    fn has_corrections(&self) -> bool;
    /// Borrowed per-observation GLM ingredients when the provider is a scalar
    /// single-predictor GLM; enables specialized outer kernels.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }
    /// Owned derivative kernel for the outer loop; the default derives it
    /// from the scalar-GLM ingredients when available.
    fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
        self.scalar_glm_ingredients()
            .map(OuterHessianDerivativeKernel::from_scalar_glm)
    }
    /// Optional family-specific outer-Hessian operator.
    fn family_outer_hessian_operator(
        &self,
    ) -> Option<Arc<dyn crate::solver::outer_strategy::OuterHessianOperator>> {
        None
    }
}
/// Borrowed per-observation ingredients of a scalar single-predictor GLM
/// (see `SinglePredictorGlmDerivatives` for how they are used).
pub struct ScalarGlmIngredients<'a> {
    /// Per-observation weight used in first-order drift corrections.
    pub c_array: &'a Array1<f64>,
    /// Optional per-observation weight used in second-order corrections.
    pub d_array: Option<&'a Array1<f64>>,
    /// The (transformed) design matrix.
    pub x: &'a DesignMatrix,
}
/// Owned description of how outer-Hessian drift derivatives are produced:
/// trivially absent (Gaussian), from scalar-GLM ingredients, or via opaque
/// callbacks supplied by the backend.
#[derive(Clone)]
pub enum OuterHessianDerivativeKernel {
    /// Constant Hessian — no drift corrections.
    Gaussian,
    /// Owned copies of the scalar-GLM ingredients.
    ScalarGlm {
        c_array: Array1<f64>,
        d_array: Option<Array1<f64>>,
        x: DesignMatrix,
    },
    /// Backend-supplied closures for first- and second-order corrections.
    Callback {
        first: Arc<dyn Fn(&Array1<f64>) -> Result<Option<DriftDerivResult>, String> + Send + Sync>,
        second: Arc<
            dyn Fn(&Array1<f64>, &Array1<f64>) -> Result<Option<DriftDerivResult>, String>
                + Send
                + Sync,
        >,
    },
}
impl OuterHessianDerivativeKernel {
fn from_scalar_glm(ingredients: ScalarGlmIngredients<'_>) -> Self {
Self::ScalarGlm {
c_array: ingredients.c_array.clone(),
d_array: ingredients.d_array.cloned(),
x: ingredients.x.clone(),
}
}
}
/// Derivative provider for the Gaussian family: its corrections are always
/// absent (see the impl below, which returns `Ok(None)` unconditionally).
pub struct GaussianDerivatives;
impl HessianDerivativeProvider for GaussianDerivatives {
    /// Never any drift correction for the Gaussian family.
    fn hessian_derivative_correction(
        &self,
        _v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        Ok(None)
    }
    fn has_corrections(&self) -> bool {
        false
    }
    /// Advertise the trivial Gaussian kernel to the outer loop.
    fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
        Some(OuterHessianDerivativeKernel::Gaussian)
    }
}
/// Owned per-observation data for scalar single-predictor GLM drift
/// corrections (the owned counterpart of `ScalarGlmIngredients`).
pub struct SinglePredictorGlmDerivatives {
    /// Per-observation weight entering first-order corrections.
    pub c_array: Array1<f64>,
    /// Optional per-observation weight entering second-order corrections.
    pub d_array: Option<Array1<f64>>,
    /// The (transformed) design matrix.
    pub x_transformed: DesignMatrix,
}
impl HessianDerivativeProvider for SinglePredictorGlmDerivatives {
    /// First-order drift: builds per-observation weights `w_i = -c_i (X v_k)_i`
    /// and reduces them through the design matrix's fused `X^T diag(w) X`.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let x_v = self.x_transformed.matrixvectormultiply(v_k);
        // Reuse the X·v buffer in place: each entry becomes -c_i * (X v_k)_i.
        let mut neg_c_xv = x_v;
        Zip::from(&mut neg_c_xv)
            .and(&self.c_array)
            .par_for_each(|xv_i, &c_i| *xv_i *= -c_i);
        let result = self
            .x_transformed
            .compute_xtwx(&neg_c_xv)
            .map_err(|e| format!("hessian_derivative_correction xtwx: {e}"))?;
        Ok(Some(result))
    }
    /// Second-order drift: weights `c_i (X u_kl)_i + d_i (X v_k)_i (X v_l)_i`
    /// (the `d` term only when fourth-derivative data is present), again
    /// reduced through `X^T diag(w) X`.
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let x_vk = self.x_transformed.matrixvectormultiply(v_k);
        let x_vl = self.x_transformed.matrixvectormultiply(v_l);
        let x_ukl = self.x_transformed.matrixvectormultiply(u_kl);
        let n = self.x_transformed.nrows();
        let mut weights = Array1::zeros(n);
        Zip::from(&mut weights)
            .and(&self.c_array)
            .and(&x_ukl)
            .par_for_each(|w, &c, &xu| *w = c * xu);
        if let Some(ref d_array) = self.d_array {
            Zip::from(&mut weights)
                .and(d_array)
                .and(&x_vk)
                .and(&x_vl)
                .par_for_each(|w, &d, &xvk, &xvl| *w += d * xvk * xvl);
        }
        let result = self
            .x_transformed
            .compute_xtwx(&weights)
            .map_err(|e| format!("hessian_second_derivative_correction xtwx: {e}"))?;
        Ok(Some(result))
    }
    fn has_corrections(&self) -> bool {
        true
    }
    /// Exposes borrowed ingredients so the outer loop can build its own kernel.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        Some(ScalarGlmIngredients {
            c_array: &self.c_array,
            d_array: self.d_array.as_ref(),
            x: &self.x_transformed,
        })
    }
}
/// GLM derivative provider that additionally subtracts the Firth
/// (Jeffreys-penalty) Hessian contribution supplied by the shared operator.
pub struct FirthAwareGlmDerivatives {
    /// The plain GLM corrections.
    pub(super) base: SinglePredictorGlmDerivatives,
    /// Shared Firth operator supplying the penalty-Hessian directions.
    pub(super) firth_op: std::sync::Arc<super::FirthDenseOperator>,
}
impl HessianDerivativeProvider for FirthAwareGlmDerivatives {
    /// Base GLM drift minus the Firth Hessian contribution in direction `v_k`.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let base_corr = self.base.hessian_derivative_correction(v_k)?;
        // NOTE(review): the direction enters as -X v_k — presumably matching
        // `direction_from_deta`'s sign convention; confirm against
        // FirthDenseOperator.
        let deta_k: Array1<f64> =
            crate::faer_ndarray::fast_av(&self.firth_op.x_dense, v_k).mapv(|v| -v);
        let dir_k = self.firth_op.direction_from_deta(deta_k);
        let firth_corr = self.firth_op.hphi_direction(&dir_k);
        match base_corr {
            Some(mut bc) => {
                bc -= &firth_corr;
                Ok(Some(bc))
            }
            None => Ok(Some(-firth_corr)),
        }
    }
    /// Second-order analogue: subtracts both the first-order Firth term along
    /// the mixed direction `u_kl` and the second-order Firth term along
    /// (`v_k`, `v_l`).
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let base_corr = self
            .base
            .hessian_second_derivative_correction(v_k, v_l, u_kl)?;
        // NOTE(review): `deta_kl` is NOT negated while `deta_k`/`deta_l` are —
        // presumably intentional (u_kl already carries the sign); verify.
        let deta_kl: Array1<f64> = crate::faer_ndarray::fast_av(&self.firth_op.x_dense, u_kl);
        let dir_kl = self.firth_op.direction_from_deta(deta_kl);
        let firth_first = self.firth_op.hphi_direction(&dir_kl);
        let deta_k: Array1<f64> =
            crate::faer_ndarray::fast_av(&self.firth_op.x_dense, v_k).mapv(|v| -v);
        let dir_k = self.firth_op.direction_from_deta(deta_k);
        let deta_l: Array1<f64> =
            crate::faer_ndarray::fast_av(&self.firth_op.x_dense, v_l).mapv(|v| -v);
        let dir_l = self.firth_op.direction_from_deta(deta_l);
        let p = v_k.len();
        // Applying the second-derivative operator to the identity
        // materializes it as a dense p×p matrix.
        let eye = Array2::<f64>::eye(p);
        let firth_second = self
            .firth_op
            .hphisecond_direction_apply(&dir_k, &dir_l, &eye);
        let mut result = match base_corr {
            Some(bc) => bc,
            None => Array2::zeros((p, p)),
        };
        result -= &firth_first;
        result -= &firth_second;
        Ok(Some(result))
    }
    fn has_corrections(&self) -> bool {
        true
    }
    /// Opts out of the scalar-GLM fast path — presumably because the base
    /// ingredients alone would omit the Firth terms; confirm with callers.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }
}
/// Handle for evaluating the exact Jeffreys-prior log-determinant penalty via
/// the shared Firth operator.
#[derive(Clone)]
pub struct ExactJeffreysTerm {
    operator: std::sync::Arc<super::FirthDenseOperator>,
}
impl ExactJeffreysTerm {
pub(crate) fn new(operator: std::sync::Arc<super::FirthDenseOperator>) -> Self {
Self { operator }
}
#[inline]
pub(crate) fn value(&self) -> f64 {
self.operator.jeffreys_logdet()
}
}
/// Log-barrier configuration for simple per-coefficient lower bounds
/// `beta[idx] >= lower_bounds[ci]`, with strength `tau`.
#[derive(Clone, Debug)]
pub struct BarrierConfig {
    /// Barrier strength multiplier.
    pub tau: f64,
    /// Coefficient indices that carry a lower bound.
    pub constrained_indices: Vec<usize>,
    /// Lower bound per constrained index (parallel to `constrained_indices`).
    pub lower_bounds: Vec<f64>,
}
impl BarrierConfig {
    /// Extracts simple lower-bound constraints from a general linear
    /// inequality system `A beta >= b`: only rows whose sole non-negligible
    /// entry is a single `+1` coefficient qualify. Returns `None` when no
    /// such row exists (or no constraints were given). Barrier strength
    /// defaults to `1e-6`.
    pub fn from_constraints(
        constraints: Option<&crate::pirls::LinearInequalityConstraints>,
    ) -> Option<Self> {
        let constraints = constraints?;
        let mut indices = Vec::new();
        let mut lower_bounds = Vec::new();
        for i in 0..constraints.a.nrows() {
            let row = constraints.a.row(i);
            // A row is "simple" when exactly one entry is +1 (within 1e-14)
            // and everything else is numerically zero.
            let mut single_col = None;
            let mut is_simple = true;
            for (j, &val) in row.iter().enumerate() {
                if val.abs() < 1e-14 {
                    continue;
                }
                if single_col.is_none() && (val - 1.0).abs() < 1e-14 {
                    single_col = Some(j);
                } else {
                    is_simple = false;
                    break;
                }
            }
            if is_simple {
                if let Some(col) = single_col {
                    indices.push(col);
                    lower_bounds.push(constraints.b[i]);
                }
            }
        }
        if indices.is_empty() {
            None
        } else {
            Some(BarrierConfig {
                tau: 1e-6,
                constrained_indices: indices,
                lower_bounds,
            })
        }
    }
    /// Slack `beta[idx] - lb` per constraint; `None` when any slack is
    /// non-positive (the point is infeasible).
    pub fn slacks(&self, beta: &Array1<f64>) -> Option<Vec<f64>> {
        self.constrained_indices
            .iter()
            .zip(&self.lower_bounds)
            .map(|(&idx, &lb)| {
                let delta = beta[idx] - lb;
                if delta > 0.0 { Some(delta) } else { None }
            })
            .collect()
    }
    /// Adds the barrier curvature `tau / s^2` onto the constrained diagonal
    /// entries of `h`; errors at an infeasible point.
    pub fn add_barrier_hessian_diagonal(
        &self,
        h: &mut Array2<f64>,
        beta: &Array1<f64>,
    ) -> Result<(), String> {
        let slacks = self
            .slacks(beta)
            .ok_or_else(|| "Barrier: infeasible point (slack ≤ 0)".to_string())?;
        for (&idx, s) in self.constrained_indices.iter().zip(&slacks) {
            h[[idx, idx]] += self.tau / (s * s);
        }
        Ok(())
    }
    /// Barrier cost `-tau * Σ ln(s_i)`; errors at an infeasible point.
    pub fn barrier_cost(&self, beta: &Array1<f64>) -> Result<f64, String> {
        let slacks = self
            .slacks(beta)
            .ok_or_else(|| "Barrier: infeasible point (slack ≤ 0)".to_string())?;
        let log_sum: f64 = slacks.iter().map(|&d| d.ln()).sum();
        Ok(-self.tau * log_sum)
    }
    /// True when the largest barrier curvature `tau / s^2` exceeds
    /// `threshold * ref_diag`; infeasible points always count as significant.
    pub fn barrier_curvature_is_significant(
        &self,
        beta: &Array1<f64>,
        ref_diag: f64,
        threshold: f64,
    ) -> bool {
        match self.slacks(beta) {
            None => true,
            Some(slacks) => {
                let max_barrier_curv = slacks
                    .iter()
                    .map(|&d| self.tau / (d * d))
                    .fold(0.0_f64, f64::max);
                max_barrier_curv > threshold * ref_diag
            }
        }
    }
}
/// Decorates an inner derivative provider with the barrier's own third- and
/// fourth-derivative diagonal terms, using slacks frozen at construction.
pub struct BarrierDerivativeProvider<'a> {
    /// Provider whose corrections the barrier terms are added to.
    inner: &'a dyn HessianDerivativeProvider,
    /// Barrier strength (copied from the config).
    tau: f64,
    /// Indices carrying a lower bound (borrowed from the config).
    constrained_indices: &'a [usize],
    /// Slacks evaluated at the construction-time `beta`.
    slacks: Vec<f64>,
    /// Problem dimension (length of `beta`).
    p: usize,
}
impl<'a> BarrierDerivativeProvider<'a> {
    /// Freezes the slacks at `beta`; fails when the point violates a bound.
    pub fn new(
        inner: &'a dyn HessianDerivativeProvider,
        config: &'a BarrierConfig,
        beta: &Array1<f64>,
    ) -> Result<Self, String> {
        match config.slacks(beta) {
            None => Err("BarrierDerivativeProvider: infeasible point".to_string()),
            Some(slacks) => Ok(Self {
                inner,
                tau: config.tau,
                constrained_indices: &config.constrained_indices,
                slacks,
                p: beta.len(),
            }),
        }
    }
    /// Diagonal third-derivative term `-2 tau u_idx / s^3` per constraint.
    fn barrier_correction(&self, u: &Array1<f64>) -> Array2<f64> {
        let mut out = Array2::zeros((self.p, self.p));
        for (&idx, s) in self.constrained_indices.iter().zip(&self.slacks) {
            let inv_cube = 1.0 / s.powi(3);
            out[[idx, idx]] = -2.0 * self.tau * u[idx] * inv_cube;
        }
        out
    }
    /// Diagonal fourth-derivative term `6 tau u_idx v_idx / s^4` per constraint.
    fn barrier_second_correction(&self, u: &Array1<f64>, v: &Array1<f64>) -> Array2<f64> {
        let mut out = Array2::zeros((self.p, self.p));
        for (&idx, s) in self.constrained_indices.iter().zip(&self.slacks) {
            let inv_4 = 1.0 / s.powi(4);
            out[[idx, idx]] = 6.0 * self.tau * u[idx] * v[idx] * inv_4;
        }
        out
    }
}
impl HessianDerivativeProvider for BarrierDerivativeProvider<'_> {
    /// Inner first-order correction plus the barrier's cubic diagonal term.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        // NOTE(review): the direction enters the cubic term with a flipped
        // sign — presumably matching the outer solver's drift convention;
        // confirm.
        let neg_v_k = v_k.mapv(|x| -x);
        let barrier_corr = self.barrier_correction(&neg_v_k);
        match self.inner.hessian_derivative_correction(v_k)? {
            Some(mut ic) => {
                ic += &barrier_corr;
                Ok(Some(ic))
            }
            None => Ok(Some(barrier_corr)),
        }
    }
    /// Operator-capable variant: dense inner results absorb the barrier term
    /// directly; operator results are wrapped in a composite whose dense part
    /// carries the barrier diagonal.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let neg_v_k = v_k.mapv(|x| -x);
        let barrier_corr = self.barrier_correction(&neg_v_k);
        match self.inner.hessian_derivative_correction_result(v_k)? {
            Some(DriftDerivResult::Dense(mut dense)) => {
                dense += &barrier_corr;
                Ok(Some(DriftDerivResult::Dense(dense)))
            }
            Some(DriftDerivResult::Operator(operator)) => Ok(Some(DriftDerivResult::Operator(
                Arc::new(CompositeHyperOperator {
                    dense: Some(barrier_corr),
                    operators: vec![operator],
                    dim_hint: self.p,
                }),
            ))),
            None => Ok(Some(DriftDerivResult::Dense(barrier_corr))),
        }
    }
    /// Inner second-order correction plus both barrier terms (cubic along
    /// `u_kl` and quartic along `v_k`, `v_l`).
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let barrier_total =
            &self.barrier_correction(u_kl) + &self.barrier_second_correction(v_k, v_l);
        match self
            .inner
            .hessian_second_derivative_correction(v_k, v_l, u_kl)?
        {
            Some(mut ic) => {
                ic += &barrier_total;
                Ok(Some(ic))
            }
            None => Ok(Some(barrier_total)),
        }
    }
    /// Operator-capable variant of the second-order correction; wrapping
    /// mirrors `hessian_derivative_correction_result`.
    fn hessian_second_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let barrier_total =
            &self.barrier_correction(u_kl) + &self.barrier_second_correction(v_k, v_l);
        match self
            .inner
            .hessian_second_derivative_correction_result(v_k, v_l, u_kl)?
        {
            Some(DriftDerivResult::Dense(mut dense)) => {
                dense += &barrier_total;
                Ok(Some(DriftDerivResult::Dense(dense)))
            }
            Some(DriftDerivResult::Operator(operator)) => Ok(Some(DriftDerivResult::Operator(
                Arc::new(CompositeHyperOperator {
                    dense: Some(barrier_total),
                    operators: vec![operator],
                    dim_hint: self.p,
                }),
            ))),
            None => Ok(Some(DriftDerivResult::Dense(barrier_total))),
        }
    }
    /// The barrier itself always contributes corrections.
    fn has_corrections(&self) -> bool {
        true
    }
    /// Opts out of the scalar-GLM fast path: the barrier terms would be lost.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }
}
/// Per-hyperparameter coordinate data consumed by the outer criterion.
/// Exact field semantics are defined by the solver that populates this — the
/// names follow its math notation; only structural facts are noted here.
pub struct HyperCoord {
    /// Scalar term for this coordinate.
    pub a: f64,
    /// Vector (gradient-side) term.
    pub g: Array1<f64>,
    /// Drift information (type defined elsewhere in this module/crate).
    pub drift: HyperCoordDrift,
    /// Scalar log-determinant term for the penalty.
    pub ld_s: f64,
    /// Whether the `B` operator depends on the current beta.
    pub b_depends_on_beta: bool,
    /// Whether this coordinate behaves like a penalty term.
    pub is_penalty_like: bool,
    /// Optional Firth-specific gradient contribution.
    pub firth_g: Option<Array1<f64>>,
    /// Optional fixed linear predictor used by the TK term — confirm usage at
    /// call sites.
    pub tk_eta_fixed: Option<Array1<f64>>,
    /// Optional fixed design used by the TK term — confirm usage at call sites.
    pub tk_x_fixed: Option<Array2<f64>>,
}
/// Pairwise (second-order) coordinate data: a scalar, a vector, and a `B`
/// term given either densely or as an operator.
pub struct HyperCoordPair {
    /// Scalar term.
    pub a: f64,
    /// Vector term.
    pub g: Array1<f64>,
    /// Dense `B` matrix (may be 0×0 when `b_operator` is used instead).
    pub b_mat: Array2<f64>,
    /// Optional matrix-free `B` operator.
    pub b_operator: Option<Box<dyn HyperOperator>>,
    /// Penalty log-determinant term.
    pub ld_s: f64,
}
impl HyperCoordPair {
pub fn zero() -> Self {
Self {
a: 0.0,
g: Array1::zeros(0),
b_mat: Array2::zeros((0, 0)),
b_operator: None,
ld_s: 0.0,
}
}
}
/// Result of a drift-derivative computation: either an explicit dense matrix
/// or a matrix-free operator.
pub enum DriftDerivResult {
    Dense(Array2<f64>),
    Operator(Arc<dyn HyperOperator>),
}
impl std::fmt::Debug for DriftDerivResult {
    /// Compact debug form: matrix shape for `Dense`, a placeholder for the
    /// opaque `Operator` variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Dense(matrix) => {
                let mut tuple = f.debug_tuple("Dense");
                tuple.field(&format_args!("{}x{}", matrix.nrows(), matrix.ncols()));
                tuple.finish()
            }
            Self::Operator(_) => {
                let mut tuple = f.debug_tuple("Operator");
                tuple.field(&"<hyper-operator>");
                tuple.finish()
            }
        }
    }
}
impl DriftDerivResult {
    /// Converts either representation into a shareable operator handle.
    pub fn into_operator(self) -> Arc<dyn HyperOperator> {
        match self {
            Self::Operator(operator) => operator,
            Self::Dense(matrix) => Arc::new(DenseMatrixHyperOperator { matrix }),
        }
    }
    /// Log-det trace of this derivative via the Hessian operator, dispatching
    /// on the stored form.
    pub fn trace_logdet(&self, hop: &dyn HessianOperator) -> f64 {
        match self {
            Self::Dense(m) => hop.trace_logdet_gradient(m),
            Self::Operator(op) => hop.trace_logdet_operator(op.as_ref()),
        }
    }
    /// Matrix-vector product with the stored derivative.
    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense(m) => m.dot(v),
            Self::Operator(op) => op.mul_vec(v),
        }
    }
    /// Cross trace term between two derivative results.
    pub fn trace_logdet_hessian_cross(&self, rhs: &Self, hop: &dyn HessianOperator) -> f64 {
        match (self, rhs) {
            (Self::Dense(l), Self::Dense(r)) => hop.trace_logdet_hessian_cross(l, r),
            (Self::Operator(l), Self::Operator(r)) => {
                hop.trace_logdet_hessian_cross_operator(l.as_ref(), r.as_ref())
            }
            (Self::Dense(l), Self::Operator(r)) => {
                hop.trace_logdet_hessian_cross_matrix_operator(l, r.as_ref())
            }
            // Arguments swapped on purpose: only a (matrix, operator) entry
            // point exists, and the trace is symmetric in its factors.
            (Self::Operator(l), Self::Dense(r)) => {
                hop.trace_logdet_hessian_cross_matrix_operator(r, l.as_ref())
            }
        }
    }
}
/// Callback yielding a drift-derivative result for a coordinate index and a
/// direction vector, or `None` when that coordinate has no drift (exact
/// contract defined at the call sites).
pub type FixedDriftDerivFn =
    Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
/// A p×p linear operator appearing inside `tr(H^{-1} B …)` terms. May be
/// dense or implicit (matrix-free); the trait provides dense fallbacks for
/// everything beyond a mat-vec.
pub trait HyperOperator: Send + Sync {
    /// Operator dimension `p`.
    fn dim(&self) -> usize;
    /// `B v` for an owned vector.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
    /// View-based mat-vec; the default copies the view.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        self.mul_vec(&v.to_owned())
    }
    /// Writes `B v` into `out`; the default goes through `mul_vec_view`.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        out.assign(&self.mul_vec_view(v));
    }
    /// `B F` column by column. Columns run in parallel unless this call is
    /// already on a rayon worker (avoids nested parallelism).
    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        let p = factor.nrows();
        let k = factor.ncols();
        let mut out = Array2::<f64>::zeros((p, k));
        if rayon::current_thread_index().is_some() {
            // Already inside a rayon pool: stay sequential.
            for col in 0..k {
                let bv = out.column_mut(col);
                self.mul_vec_into(factor.column(col), bv);
            }
            return out;
        }
        let cols: Vec<Array1<f64>> = (0..k)
            .into_par_iter()
            .map(|col| {
                let mut bv = Array1::<f64>::zeros(p);
                self.mul_vec_into(factor.column(col), bv.view_mut());
                bv
            })
            .collect();
        for (col, bv) in cols.into_iter().enumerate() {
            out.column_mut(col).assign(&bv);
        }
        out
    }
    /// `tr(F^T B F)` computed element-wise as Σ_ij F_ij (B F)_ij.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        let op_factor = self.mul_mat(factor);
        factor
            .iter()
            .zip(op_factor.iter())
            .map(|(&f, &bf)| f * bf)
            .sum()
    }
    /// Cache-aware variant; the default ignores the cache.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        _cache: &ProjectedFactorCache,
    ) -> f64 {
        self.trace_projected_factor(factor)
    }
    /// Full projected matrix `F^T B F`.
    fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
        let op_factor = self.mul_mat(factor);
        factor.t().dot(&op_factor)
    }
    /// Writes columns `start..start + out.ncols()` of `B` into `out` by
    /// applying the operator to unit basis vectors (one buffer, toggled in
    /// place).
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        let cols = out.ncols();
        let dim = out.nrows();
        debug_assert!(start + cols <= dim);
        let mut basis = Array1::<f64>::zeros(dim);
        for local_col in 0..cols {
            let global_col = start + local_col;
            basis[global_col] = 1.0;
            self.mul_vec_into(basis.view(), out.column_mut(local_col));
            basis[global_col] = 0.0;
        }
    }
    /// `out += scale * B v`; a no-op when `scale == 0`.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        let mut work = Array1::<f64>::zeros(out.len());
        self.mul_vec_into(v, work.view_mut());
        out.scaled_add(scale, &work);
    }
    /// Bilinear form `u^T B v`.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let mut bv = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v.view(), bv.view_mut());
        u.dot(&bv)
    }
    /// View-based bilinear form.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        let mut bv = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v, bv.view_mut());
        u.dot(&bv)
    }
    /// Hint that `bilinear_view` has a fast specialized path.
    fn has_fast_bilinear_view(&self) -> bool {
        false
    }
    /// Materializes the operator as a dense matrix (may be expensive).
    fn to_dense(&self) -> Array2<f64>;
    /// True when the operator is matrix-free and `to_dense` is costly.
    fn is_implicit(&self) -> bool;
    /// Downcast to the implicit implementation, if this is one.
    fn as_implicit(&self) -> Option<&ImplicitHyperOperator> {
        None
    }
    /// `(block, start, end)` of the nonzero block for operators with
    /// block-local structure; `None` otherwise.
    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
        None
    }
}
/// Cache key for a projected factor: design identity, the factor's memory
/// layout (pointer, shape, strides), plus a two-lane content fingerprint so
/// a reused allocation with different values does not alias an old entry.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProjectedFactorKey {
    design_id: usize,
    factor_ptr: usize,
    rows: usize,
    cols: usize,
    row_stride: isize,
    col_stride: isize,
    value_hash: u64,
    value_hash2: u64,
}
impl ProjectedFactorKey {
    /// Builds the key from a factor view: layout fields are read off the view
    /// and the contents are fingerprinted.
    pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
        let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
        let strides = factor.strides();
        Self {
            design_id,
            factor_ptr: factor.as_ptr() as usize,
            rows: factor.nrows(),
            cols: factor.ncols(),
            row_stride: strides[0],
            col_stride: strides[1],
            value_hash,
            value_hash2,
        }
    }
}
/// Order-sensitive two-lane fingerprint of the factor's values: an
/// FNV-1a-style lane salted with the element index, plus an independent
/// rotate-multiply lane. Used by `ProjectedFactorKey` to detect content
/// changes at the same address.
fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
    // Seeds: the FNV offset basis and the 64-bit golden-ratio constant.
    let seeds = (0xcbf2_9ce4_8422_2325_u64, 0x9e37_79b1_85eb_ca87_u64);
    factor
        .iter()
        .enumerate()
        .fold(seeds, |(h1, h2), (idx, value)| {
            let bits = value.to_bits();
            // Salt the raw bits with the position so permutations hash apart.
            let mixed = bits.wrapping_add((idx as u64).wrapping_mul(0x517c_c1b7_2722_0a95));
            (
                (h1 ^ mixed).wrapping_mul(0x0000_0100_0000_01b3),
                (h2 ^ bits.rotate_left((idx & 63) as u32))
                    .wrapping_mul(0x94d0_49bb_1331_11eb)
                    .rotate_left(27),
            )
        })
}
/// Thread-safe, byte-budgeted cache of projected-factor matrices, evicting
/// least-recently-used entries.
pub struct ProjectedFactorCache {
    inner: Mutex<ProjectedFactorCacheInner>,
}
/// Mutex-guarded cache state.
struct ProjectedFactorCacheInner {
    entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
    /// Monotonic counter used as a logical clock for LRU ordering.
    next_seq: u64,
    /// Sum of `bytes` over all entries.
    total_bytes: usize,
    /// Soft cap on `total_bytes`; see `get_or_insert_with` for semantics.
    budget_bytes: usize,
}
/// One cached matrix plus its accounting metadata.
struct ProjectedFactorEntry {
    value: Arc<Array2<f64>>,
    /// Payload size used for budget accounting.
    bytes: usize,
    /// Logical timestamp of the last hit (for LRU eviction).
    last_used: u64,
}
impl Default for ProjectedFactorCache {
fn default() -> Self {
Self::with_budget(Self::DEFAULT_BUDGET_BYTES)
}
}
impl ProjectedFactorCache {
    /// Default budget: 2 GiB.
    pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;
    /// Creates an empty cache with the given byte budget.
    pub fn with_budget(budget_bytes: usize) -> Self {
        Self {
            inner: Mutex::new(ProjectedFactorCacheInner {
                entries: HashMap::new(),
                next_seq: 0,
                total_bytes: 0,
                budget_bytes,
            }),
        }
    }
    /// Returns the cached value for `key`, computing and inserting it on a
    /// miss. The lock is released while `compute` runs (double-checked
    /// pattern), so concurrent misses may compute redundantly but only one
    /// result is kept.
    ///
    /// NOTE(review): a value larger than the whole budget skips eviction but
    /// is still inserted, temporarily overshooting the budget until later
    /// evictions remove it; `budget_bytes == 0` disables eviction entirely —
    /// confirm both are intended.
    pub fn get_or_insert_with(
        &self,
        key: ProjectedFactorKey,
        compute: impl FnOnce() -> Array2<f64>,
    ) -> Arc<Array2<f64>> {
        {
            // First check under the lock: on a hit, bump recency and return.
            let mut inner = self
                .inner
                .lock()
                .expect("projected factor cache lock poisoned");
            inner.next_seq += 1;
            let now = inner.next_seq;
            if let Some(entry) = inner.entries.get_mut(&key) {
                entry.last_used = now;
                return entry.value.clone();
            }
        }
        // Miss: compute outside the lock.
        let computed = Arc::new(compute());
        let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
        let mut inner = self
            .inner
            .lock()
            .expect("projected factor cache lock poisoned");
        inner.next_seq += 1;
        let now = inner.next_seq;
        // Re-check: another thread may have inserted while we computed.
        if let Some(entry) = inner.entries.get_mut(&key) {
            entry.last_used = now;
            return entry.value.clone();
        }
        if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
            // Evict least-recently-used entries until the new value fits.
            while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
                && !inner.entries.is_empty()
            {
                let Some(oldest_key) = inner
                    .entries
                    .iter()
                    .min_by_key(|(_, e)| e.last_used)
                    .map(|(k, _)| *k)
                else {
                    break;
                };
                if let Some(removed) = inner.entries.remove(&oldest_key) {
                    inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
                }
            }
        }
        inner.entries.insert(
            key,
            ProjectedFactorEntry {
                value: computed.clone(),
                bytes,
                last_used: now,
            },
        );
        inner.total_bytes = inner.total_bytes.saturating_add(bytes);
        computed
    }
    /// Number of cached entries.
    ///
    /// NOTE(review): a poisoned lock reports 0 here, while
    /// `get_or_insert_with` panics on poison — confirm the inconsistency is
    /// deliberate.
    pub fn len(&self) -> usize {
        self.inner
            .lock()
            .map(|inner| inner.entries.len())
            .unwrap_or(0)
    }
    /// Total payload bytes currently accounted for (0 on a poisoned lock).
    pub fn total_bytes(&self) -> usize {
        self.inner
            .lock()
            .map(|inner| inner.total_bytes)
            .unwrap_or(0)
    }
    /// True when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Trivial `HyperOperator` backed by an explicit dense matrix.
#[derive(Clone)]
pub struct DenseMatrixHyperOperator {
    pub matrix: Array2<f64>,
}
impl HyperOperator for DenseMatrixHyperOperator {
    fn dim(&self) -> usize {
        self.matrix.nrows()
    }
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        // Same computation as the view path; delegate to keep one code path.
        self.mul_vec_view(v.view())
    }
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        self.matrix.dot(&v)
    }
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
        dense_matvec_into(&self.matrix, v, out);
    }
    fn scaled_add_mul_vec(&self, v: ArrayView1<'_, f64>, scale: f64, out: ArrayViewMut1<'_, f64>) {
        dense_matvec_scaled_add_into(&self.matrix, v, scale, out);
    }
    /// Multiplying by basis columns just copies the matching matrix columns.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        let cols = start..start + out.ncols();
        debug_assert!(cols.end <= self.matrix.ncols());
        out.assign(&self.matrix.slice(ndarray::s![.., cols]));
    }
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        self.bilinear_view(v.view(), u.view())
    }
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        dense_bilinear(&self.matrix, v, u)
    }
    fn to_dense(&self) -> Array2<f64> {
        self.matrix.clone()
    }
    fn is_implicit(&self) -> bool {
        false
    }
}
/// Sum of an optional dense part and any number of operator parts, applied as
/// `B = dense + Σ operators`.
#[derive(Clone)]
pub struct CompositeHyperOperator {
    pub dense: Option<Array2<f64>>,
    pub operators: Vec<Arc<dyn HyperOperator>>,
    /// Dimension to report; the parts themselves may be empty.
    pub dim_hint: usize,
}
/// Sums `trace_projected_factor` over `operators`, batching implicit operators
/// that share the same backing data (same derivative tensor, design, weight
/// diagonal, and dimension — compared by `Arc` pointer) so the shared `X·F`
/// product is formed only once per group.
fn composite_trace_implicit_batched(
    operators: &[Arc<dyn HyperOperator>],
    factor: &Array2<f64>,
    cache: Option<&ProjectedFactorCache>,
) -> f64 {
    let mut trace = 0.0;
    // Each inner Vec is a group of operator indices that can share one X·F.
    let mut group_starts: Vec<Vec<usize>> = Vec::new();
    let mut handled = vec![false; operators.len()];
    for (i, op) in operators.iter().enumerate() {
        if handled[i] {
            continue;
        }
        // Non-implicit operators stay unhandled here; they are summed in the
        // fallback loop at the bottom.
        let Some(impl_i) = op.as_implicit() else {
            continue;
        };
        let mut group = vec![i];
        handled[i] = true;
        for j in (i + 1)..operators.len() {
            if handled[j] {
                continue;
            }
            if let Some(impl_j) = operators[j].as_implicit() {
                // Same shared buffers (by pointer) and same p ⇒ batchable.
                if Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
                    && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
                    && Arc::ptr_eq(&impl_i.w_diag, &impl_j.w_diag)
                    && impl_i.p == impl_j.p
                {
                    group.push(j);
                    handled[j] = true;
                }
            }
        }
        group_starts.push(group);
    }
    for group in &group_starts {
        if group.len() >= 2 {
            // Batched path: compute (or fetch from cache) X·F once, then
            // evaluate every group member's axis against it.
            let lead = operators[group[0]].as_implicit().unwrap();
            let xf = match cache {
                Some(c) => lead.cached_xf(factor, c),
                None => Arc::new(lead.compute_xf(factor)),
            };
            let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
                .iter()
                .map(|&k| {
                    let op = operators[k].as_implicit().unwrap();
                    (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
                })
                .collect();
            let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
            trace += values.iter().sum::<f64>();
        } else {
            // Singleton implicit group: nothing to share.
            let op = &operators[group[0]];
            trace += match cache {
                Some(c) => op.trace_projected_factor_cached(factor, c),
                None => op.trace_projected_factor(factor),
            };
        }
    }
    // Fallback: everything not marked handled (the non-implicit operators).
    for (i, op) in operators.iter().enumerate() {
        if handled[i] {
            continue;
        }
        trace += match cache {
            Some(c) => op.trace_projected_factor_cached(factor, c),
            None => op.trace_projected_factor(factor),
        };
    }
    trace
}
impl HyperOperator for CompositeHyperOperator {
    /// Dimension comes from the stored hint (the parts may be empty).
    fn dim(&self) -> usize {
        self.dim_hint
    }
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v.view(), out.view_mut());
        out
    }
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v, out.view_mut());
        out
    }
    /// Accumulates `dense·v` plus every part's `B_i v` into `out`; a lone
    /// operator with no dense part is forwarded directly (no zero-fill).
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].mul_vec_into(v, out);
            return;
        }
        out.fill(0.0);
        if let Some(dense) = self.dense.as_ref() {
            dense_matvec_into(dense, v, out.view_mut());
        }
        for op in &self.operators {
            op.scaled_add_mul_vec(v, 1.0, out.view_mut());
        }
    }
    /// Sums each part's basis-column block; `work` is a scratch buffer every
    /// part fully overwrites before it is added.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].mul_basis_columns_into(start, out);
            return;
        }
        out.fill(0.0);
        let cols = out.ncols();
        let end = start + cols;
        if let Some(dense) = self.dense.as_ref() {
            out += &dense.slice(ndarray::s![.., start..end]);
        }
        let mut work = Array2::<f64>::zeros((out.nrows(), cols));
        for op in &self.operators {
            op.mul_basis_columns_into(start, work.view_mut());
            out += &work;
        }
    }
    /// `out += scale * (dense + Σ B_i) v`; no-op when `scale == 0`.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].scaled_add_mul_vec(v, scale, out);
            return;
        }
        if let Some(dense) = self.dense.as_ref() {
            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
        }
        for op in &self.operators {
            op.scaled_add_mul_vec(v, scale, out.view_mut());
        }
    }
    /// Sum of the parts' matrix products.
    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].mul_mat(factor);
        }
        let p = factor.nrows();
        let k = factor.ncols();
        let mut out = Array2::<f64>::zeros((p, k));
        if let Some(dense) = self.dense.as_ref() {
            out += &dense.dot(factor);
        }
        for op in &self.operators {
            out += &op.mul_mat(factor);
        }
        out
    }
    /// Dense part handled element-wise; operator parts go through the batched
    /// implicit-sharing helper (no cache).
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].trace_projected_factor(factor);
        }
        let mut trace = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            let dense_factor = dense.dot(factor);
            trace += factor
                .iter()
                .zip(dense_factor.iter())
                .map(|(&f, &bf)| f * bf)
                .sum::<f64>();
        }
        trace += composite_trace_implicit_batched(&self.operators, factor, None);
        trace
    }
    /// Same as `trace_projected_factor` but threads the shared X·F cache
    /// through to the batched helper.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].trace_projected_factor_cached(factor, cache);
        }
        let mut trace = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            let dense_factor = dense.dot(factor);
            trace += factor
                .iter()
                .zip(dense_factor.iter())
                .map(|(&f, &bf)| f * bf)
                .sum::<f64>();
        }
        trace += composite_trace_implicit_batched(&self.operators, factor, Some(cache));
        trace
    }
fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
let mut total = 0.0;
if let Some(dense) = self.dense.as_ref() {
total += dense_bilinear(dense, v.view(), u.view());
}
for op in &self.operators {
total += op.bilinear(v, u);
}
total
}
fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
let mut total = 0.0;
if let Some(dense) = self.dense.as_ref() {
total += dense_bilinear(dense, v, u);
}
for op in &self.operators {
total += op.bilinear_view(v, u);
}
total
}
fn to_dense(&self) -> Array2<f64> {
let mut out = self
.dense
.clone()
.unwrap_or_else(|| Array2::<f64>::zeros((self.dim_hint, self.dim_hint)));
for op in &self.operators {
out += &op.to_dense();
}
out
}
fn is_implicit(&self) -> bool {
self.operators.iter().any(|op| op.is_implicit())
}
}
/// Drift operator that is nonzero only on one contiguous coordinate block:
/// the dense `local` matrix acts on indices `start..end` of a `total_dim`
/// vector and is zero elsewhere (see `to_dense` in the `HyperOperator` impl).
#[derive(Clone)]
pub struct BlockLocalDrift {
    /// Dense `(end - start) x (end - start)` block.
    pub local: Array2<f64>,
    /// First global index covered by the block (inclusive).
    pub start: usize,
    /// One past the last global index covered (exclusive).
    pub end: usize,
    /// Length of the full coordinate vector this operator acts on.
    pub total_dim: usize,
}
impl HyperOperator for BlockLocalDrift {
    fn dim(&self) -> usize {
        self.total_dim
    }

    /// Allocating matvec; only indices `start..end` of the result can be
    /// nonzero.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::zeros(v.len());
        self.mul_vec_into(v.view(), out.view_mut());
        out
    }

    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut out = Array1::zeros(v.len());
        self.mul_vec_into(v, out.view_mut());
        out
    }

    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        out.fill(0.0);
        // Only the `start..end` slices of `v` and `out` participate.
        let v_block = v.slice(ndarray::s![self.start..self.end]);
        let out_block = out.slice_mut(ndarray::s![self.start..self.end]);
        dense_matvec_into(&self.local, v_block, out_block);
    }

    /// Writes global columns `start..start + out.ncols()` of the dense form.
    /// Columns outside `self.start..self.end` are zero, so only the overlap
    /// of the requested range with the block is assigned.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        out.fill(0.0);
        let global_end = start + out.ncols();
        // Intersect the requested column range with this operator's block.
        let col_start = start.max(self.start);
        let col_end = global_end.min(self.end);
        if col_start >= col_end {
            return;
        }
        // Translate global column indices into `local` / `out` coordinates.
        let local_col_start = col_start - self.start;
        let local_col_end = col_end - self.start;
        let out_col_start = col_start - start;
        let out_col_end = col_end - start;
        out.slice_mut(ndarray::s![
            self.start..self.end,
            out_col_start..out_col_end
        ])
        .assign(
            &self
                .local
                .slice(ndarray::s![.., local_col_start..local_col_end]),
        );
    }

    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        let v_block = v.slice(ndarray::s![self.start..self.end]);
        let out_block = out.slice_mut(ndarray::s![self.start..self.end]);
        dense_matvec_scaled_add_into(&self.local, v_block, scale, out_block);
    }

    /// `u^T A v`, restricted to the active block.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let v_block = v.slice(ndarray::s![self.start..self.end]);
        let u_block = u.slice(ndarray::s![self.start..self.end]);
        u_block.dot(&self.local.dot(&v_block))
    }

    /// Allocation-free variant of `bilinear`: a manual row loop instead of
    /// `dot`, which would allocate an intermediate vector.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        let v_block = v.slice(ndarray::s![self.start..self.end]);
        let u_block = u.slice(ndarray::s![self.start..self.end]);
        let mut total = 0.0;
        for row in 0..self.local.nrows() {
            let mut row_dot = 0.0;
            for col in 0..self.local.ncols() {
                row_dot += self.local[[row, col]] * v_block[col];
            }
            total += u_block[row] * row_dot;
        }
        total
    }

    /// Embeds `local` into a `total_dim x total_dim` zero matrix.
    fn to_dense(&self) -> Array2<f64> {
        let p = self.total_dim;
        let mut out = Array2::zeros((p, p));
        out.slice_mut(ndarray::s![self.start..self.end, self.start..self.end])
            .assign(&self.local);
        out
    }

    fn is_implicit(&self) -> bool {
        false
    }

    /// Exposes the block data so callers can special-case block-local drift.
    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
        Some((&self.local, self.start, self.end))
    }
}
/// Sum-of-parts representation of a drift term: any combination of a full
/// dense matrix, a block-local dense matrix, and a matrix-free operator.
/// Each part that is present contributes additively (see `materialize`).
pub struct HyperCoordDrift {
    /// Optional full dense component.
    pub dense: Option<Array2<f64>>,
    /// Optional component confined to one contiguous coordinate block.
    pub block_local: Option<BlockLocalDrift>,
    /// Optional matrix-free component.
    pub operator: Option<Arc<dyn HyperOperator>>,
}
impl HyperCoordDrift {
    /// Empty drift: no dense, block-local, or operator component.
    pub fn none() -> Self {
        Self {
            dense: None,
            block_local: None,
            operator: None,
        }
    }

    /// Drift given entirely by a dense matrix.
    pub fn from_dense(dense: Array2<f64>) -> Self {
        Self {
            dense: Some(dense),
            block_local: None,
            operator: None,
        }
    }

    /// Drift given entirely by a matrix-free operator.
    pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
        Self {
            dense: None,
            block_local: None,
            operator: Some(operator),
        }
    }

    /// Combines an optional dense part with an optional operator. An empty
    /// dense matrix is dropped when an operator is present, so that
    /// `infer_dim` (which prefers the dense part) does not report zero.
    pub fn from_parts(
        dense: Option<Array2<f64>>,
        operator: Option<Arc<dyn HyperOperator>>,
    ) -> Self {
        let dense = dense.filter(|mat| operator.is_none() || !mat.is_empty());
        Self {
            dense,
            block_local: None,
            operator,
        }
    }

    /// Drift with a block-local dense component and optionally an operator.
    pub fn from_block_local_and_operator(
        local: Array2<f64>,
        start: usize,
        end: usize,
        total_dim: usize,
        operator: Option<Arc<dyn HyperOperator>>,
    ) -> Self {
        let block = BlockLocalDrift {
            local,
            start,
            end,
            total_dim,
        };
        Self {
            dense: None,
            block_local: Some(block),
            operator,
        }
    }

    /// Whether a matrix-free operator component is present.
    pub fn has_operator(&self) -> bool {
        self.operator.is_some()
    }

    /// True when at least one component avoids a full dense representation.
    pub fn uses_operator_fast_path(&self) -> bool {
        self.operator.is_some() || self.block_local.is_some()
    }

    /// Borrows the operator component, if any.
    pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
        self.operator.as_deref()
    }

    /// Assembles the full dense drift matrix as the sum of all components.
    pub fn materialize(&self) -> Array2<f64> {
        let p = self.infer_dim();
        if p == 0 {
            return Array2::zeros((0, 0));
        }
        let mut full = match self.dense.as_ref() {
            Some(d) => d.clone(),
            None => Array2::zeros((p, p)),
        };
        if let Some(block) = self.block_local.as_ref() {
            full.slice_mut(ndarray::s![block.start..block.end, block.start..block.end])
                .scaled_add(1.0, &block.local);
        }
        if let Some(op) = self.operator.as_ref() {
            full += &op.to_dense();
        }
        full
    }

    /// Allocating matvec across all components.
    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut result = Array1::zeros(v.len());
        self.scaled_add_apply(v.view(), 1.0, &mut result);
        result
    }

    /// Accumulates `scale * drift * v` into `out`, component by component
    /// (dense first, then block-local, then operator).
    pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
        debug_assert_eq!(v.len(), out.len());
        if scale == 0.0 {
            return;
        }
        if let Some(full) = self.dense.as_ref() {
            dense_matvec_scaled_add_into(full, v, scale, out.view_mut());
        }
        if let Some(block) = self.block_local.as_ref() {
            let v_part = v.slice(ndarray::s![block.start..block.end]);
            let out_part = out.slice_mut(ndarray::s![block.start..block.end]);
            dense_matvec_scaled_add_into(&block.local, v_part, scale, out_part);
        }
        if let Some(op) = self.operator.as_ref() {
            op.scaled_add_mul_vec(v, scale, out.view_mut());
        }
    }

    /// Coordinate dimension, preferring dense, then operator, then block.
    fn infer_dim(&self) -> usize {
        self.dense
            .as_ref()
            .map(|d| d.nrows())
            .or_else(|| self.operator.as_ref().map(|op| op.dim()))
            .or_else(|| self.block_local.as_ref().map(|bl| bl.total_dim))
            .unwrap_or(0)
    }
}
/// Matrix-free hyperparameter-derivative operator built from an implicit
/// design-matrix derivative along one smoothing axis. Its action combines
/// derivative/design cross terms, the penalty derivative `s_psi`, and an
/// optional per-observation correction (see `mul_vec_into`/`bilinear_view`).
pub struct ImplicitHyperOperator {
    /// Matrix-free derivative of the design w.r.t. the hyperparameter.
    pub implicit_deriv: std::sync::Arc<crate::terms::basis::ImplicitDesignPsiDerivative>,
    /// Smoothing axis this operator differentiates along.
    pub axis: usize,
    /// Shared design matrix X, mapping p coefficients to n_obs observations.
    pub x_design: std::sync::Arc<DesignMatrix>,
    /// Observation weights (length n_obs).
    pub w_diag: std::sync::Arc<Array1<f64>>,
    /// Dense penalty derivative term (p x p).
    pub s_psi: Array2<f64>,
    /// Coefficient dimension p.
    pub p: usize,
    /// Optional per-observation weights c (length n_obs) for the
    /// `X^T diag(c) X` correction term; `None` disables the correction.
    pub c_x_psi_beta: Option<std::sync::Arc<Array1<f64>>>,
}
impl HyperOperator for ImplicitHyperOperator {
    fn dim(&self) -> usize {
        self.p
    }

    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.p);
        self.mul_vec_into(v.view(), out.view_mut());
        out
    }

    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.p);
        self.mul_vec_into(v, out.view_mut());
        out
    }

    /// Matrix-free matvec: computes `X v` once, then delegates to the shared
    /// kernel that assembles all terms (derivative cross terms, penalty, and
    /// the optional `c` correction).
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
        debug_assert_eq!(v.len(), self.p);
        let n_obs = self.w_diag.len();
        // Scratch buffers in observation space (n_obs) and coefficient space (p).
        let mut x_v = Array1::<f64>::zeros(n_obs);
        let mut n_work = Array1::<f64>::zeros(n_obs);
        let mut p_work = Array1::<f64>::zeros(self.p);
        design_matrix_apply_view_into(&self.x_design, v, x_v.view_mut());
        self.matvec_with_shared_xz_into(&x_v, v, out, n_work.view_mut(), p_work.view_mut());
    }

    /// Materializes columns `start..start + out.ncols()` of the operator by
    /// assembling each column term-by-term: penalty column, derivative
    /// transpose term, design transpose term, and the optional `c` correction.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        let cols = out.ncols();
        debug_assert!(start + cols <= self.p);
        let n_obs = self.w_diag.len();
        // Reused scratch: unit vector, design column, derivative column,
        // weighted observation vector, and a coefficient-space accumulator.
        let mut basis = Array1::<f64>::zeros(self.p);
        let mut x_col = Array1::<f64>::zeros(n_obs);
        let mut dx_col = Array1::<f64>::zeros(n_obs);
        let mut weighted = Array1::<f64>::zeros(n_obs);
        let mut term = Array1::<f64>::zeros(self.p);
        for local_col in 0..cols {
            let global_col = start + local_col;
            let mut out_col = out.column_mut(local_col);
            // Start from the penalty column S_psi[:, j].
            out_col.assign(&self.s_psi.column(global_col));
            design_matrix_column_into(&self.x_design, global_col, x_col.view_mut());
            // weighted = W * X[:, j] (parallel elementwise product).
            Zip::from(weighted.view_mut())
                .and(self.w_diag.view())
                .and(x_col.view())
                .par_for_each(|dst, &w, &x| *dst = w * x);
            term.assign(
                &self
                    .implicit_deriv
                    .transpose_mul(self.axis, &weighted.view())
                    .expect("radial scalar evaluation failed during implicit hyper transpose_mul"),
            );
            out_col += &term;
            // Derivative of the design applied to the unit basis vector e_j.
            basis[global_col] = 1.0;
            dx_col.assign(
                &self
                    .implicit_deriv
                    .forward_mul(self.axis, &basis.view())
                    .expect("radial scalar evaluation failed during implicit hyper forward_mul"),
            );
            basis[global_col] = 0.0;
            // weighted = W * dX[:, j], then fold back through X^T.
            Zip::from(weighted.view_mut())
                .and(self.w_diag.view())
                .and(dx_col.view())
                .par_for_each(|dst, &w, &dx| *dst = w * dx);
            design_matrix_transpose_apply_view_into(
                &self.x_design,
                weighted.view(),
                term.view_mut(),
            );
            out_col += &term;
            // Optional X^T diag(c) X correction (no-op when c is absent).
            self.accumulate_c_correction_xt_into(
                x_col.view(),
                weighted.view_mut(),
                term.view_mut(),
                out_col,
            );
        }
    }

    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        self.bilinear_view(v.view(), u.view())
    }

    /// `u^T A v` without materializing A: symmetric design/derivative cross
    /// terms, plus the `c` correction and the dense penalty part.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        debug_assert_eq!(v.len(), self.p);
        debug_assert_eq!(u.len(), self.p);
        let x_v = design_matrix_apply_view(&self.x_design, v);
        let x_u = design_matrix_apply_view(&self.x_design, u);
        let dx_v = self
            .implicit_deriv
            .forward_mul(self.axis, &v)
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        let dx_u = self
            .implicit_deriv
            .forward_mul(self.axis, &u)
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        let w = &*self.w_diag;
        let mut design = 0.0;
        for i in 0..w.len() {
            // Both cross terms: (dX v)^T W (X u) + (dX u)^T W (X v).
            design += dx_v[i] * w[i] * x_u[i];
            design += dx_u[i] * w[i] * x_v[i];
        }
        design += self.c_correction_bilinear(&x_v, &x_u);
        let penalty = dense_bilinear(&self.s_psi, v, u);
        design + penalty
    }

    /// Dense materialization via p unit-basis matvecs — O(p) operator
    /// applications; intended for small p or diagnostics.
    fn to_dense(&self) -> Array2<f64> {
        let p = self.p;
        let mut out = Array2::<f64>::zeros((p, p));
        let mut ei = Array1::<f64>::zeros(p);
        for j in 0..p {
            ei[j] = 1.0;
            self.mul_vec_into(ei.view(), out.column_mut(j));
            ei[j] = 0.0;
        }
        out
    }

    fn is_implicit(&self) -> bool {
        true
    }

    fn as_implicit(&self) -> Option<&ImplicitHyperOperator> {
        Some(self)
    }

    /// Projected trace via a freshly computed `X * factor` product.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let n_obs = self.w_diag.len();
        let rank = factor.ncols();
        if rank == 0 || n_obs == 0 {
            return 0.0;
        }
        let xf = self.compute_xf(factor);
        self.trace_projected_factor_with_xf(factor, xf.view())
    }

    /// Projected trace reusing `X * factor` from the shared cache.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let n_obs = self.w_diag.len();
        let rank = factor.ncols();
        if rank == 0 || n_obs == 0 {
            return 0.0;
        }
        let xf = self.cached_xf(factor, cache);
        self.trace_projected_factor_with_xf(factor, xf.view())
    }
}
impl ImplicitHyperOperator {
    /// Computes `X * factor` in row chunks sized to roughly 8 MiB of f64
    /// working set, so design rows stay cache/memory friendly.
    fn compute_xf(&self, factor: &Array2<f64>) -> Array2<f64> {
        let n_obs = self.w_diag.len();
        let rank = factor.ncols();
        let mut xf = Array2::<f64>::zeros((n_obs, rank));
        const TARGET_BYTES: usize = 8 * 1024 * 1024;
        // Rows per chunk: budget / bytes-per-row, clamped to [512, n_obs].
        let chunk_rows = (TARGET_BYTES / ((self.p + rank).max(1) * 8))
            .max(512)
            .min(n_obs);
        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let rows = self
                .x_design
                .try_row_chunk(start..end)
                .unwrap_or_else(|err| {
                    panic!("ImplicitHyperOperator::compute_xf row chunk failed: {err}")
                });
            let block = crate::faer_ndarray::fast_ab(&rows, factor);
            xf.slice_mut(ndarray::s![start..end, ..]).assign(&block);
            start = end;
        }
        xf
    }

    /// Memoized `compute_xf`, keyed by the design-matrix identity (the Arc
    /// pointer address) plus the factor contents.
    fn cached_xf(&self, factor: &Array2<f64>, cache: &ProjectedFactorCache) -> Arc<Array2<f64>> {
        let design_id = Arc::as_ptr(&self.x_design) as usize;
        let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
        cache.get_or_insert_with(key, || self.compute_xf(factor))
    }

    /// Single-axis projected trace given a precomputed `xf = X * factor`.
    /// Accumulates the design cross term and the optional `c` correction in
    /// row chunks, then adds the penalty part; the cross term is doubled
    /// because it appears symmetrically.
    fn trace_projected_factor_with_xf(&self, factor: &Array2<f64>, xf: ArrayView2<'_, f64>) -> f64 {
        let rank = factor.ncols();
        let n_obs = self.w_diag.len();
        debug_assert_eq!(xf.dim(), (n_obs, rank));
        let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
        const TARGET_BYTES: usize = 8 * 1024 * 1024;
        let chunk_rows = (TARGET_BYTES / ((self.p + rank).max(1) * 8))
            .max(512)
            .min(n_obs);
        let w = self.w_diag.as_ref();
        let c_opt = self.c_x_psi_beta.as_ref().map(|arc| arc.as_ref());
        let mut design_total = 0.0_f64;
        let mut correction_total = 0.0_f64;
        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let chunk_n = end - start;
            let xf_chunk = xf.slice(ndarray::s![start..end, ..]);
            let kd_chunk = self
                .implicit_deriv
                .row_chunk_first_raw(self.axis, start..end)
                .expect("radial scalar evaluation failed during implicit hyper forward_mul_matrix");
            // dxf = (derivative rows) * (unprojected factor) for this chunk.
            let dxf_chunk = crate::faer_ndarray::fast_ab(&kd_chunk, &u_knot);
            for i_local in 0..chunk_n {
                let i = start + i_local;
                let w_i = w[i];
                let dxf_row = dxf_chunk.row(i_local);
                let xf_row = xf_chunk.row(i_local);
                for k in 0..rank {
                    design_total += dxf_row[k] * w_i * xf_row[k];
                }
                if let Some(c) = c_opt {
                    let c_i = c[i];
                    for k in 0..rank {
                        let v = xf_row[k];
                        correction_total += c_i * v * v;
                    }
                }
            }
            start = end;
        }
        // Penalty part: sum_ij F_ij (S_psi F)_ij.
        let s_f = self.s_psi.dot(factor);
        let penalty: f64 = factor.iter().zip(s_f.iter()).map(|(&f, &s)| f * s).sum();
        2.0 * design_total + correction_total + penalty
    }

    /// Batched variant of `trace_projected_factor_with_xf` that amortizes the
    /// shared `xf` and per-chunk derivative rows across several axes sharing
    /// the same design matrix. Returns one trace per `(axis, s_psi, c)` entry.
    fn trace_projected_factor_all_axes_with_xf<'a>(
        &self,
        factor: &Array2<f64>,
        xf: ArrayView2<'_, f64>,
        axes: &[(usize, &'a Array2<f64>, Option<&'a Array1<f64>>)],
    ) -> Vec<f64> {
        let n_axes = axes.len();
        if n_axes == 0 {
            return Vec::new();
        }
        let rank = factor.ncols();
        let n_obs = self.w_diag.len();
        debug_assert_eq!(xf.dim(), (n_obs, rank));
        let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
        const TARGET_BYTES: usize = 8 * 1024 * 1024;
        let chunk_rows = (TARGET_BYTES / ((self.p + rank).max(1) * 8))
            .max(512)
            .min(n_obs);
        let w = self.w_diag.as_ref();
        // One running design/correction accumulator per axis ("slot").
        let mut design_totals = vec![0.0_f64; n_axes];
        let mut correction_totals = vec![0.0_f64; n_axes];
        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let chunk_n = end - start;
            let xf_chunk = xf.slice(ndarray::s![start..end, ..]);
            // Derivative rows for all axes at once for this row chunk.
            let kd_all = self
                .implicit_deriv
                .row_chunk_first_raw_all_axes(start..end)
                .expect("radial scalar evaluation failed during implicit hyper batched trace");
            for (slot, (axis, _, c_opt)) in axes.iter().enumerate() {
                let kd_chunk = &kd_all[*axis];
                let dxf_chunk = crate::faer_ndarray::fast_ab(kd_chunk, &u_knot);
                let mut design_total = design_totals[slot];
                let mut correction_total = correction_totals[slot];
                for i_local in 0..chunk_n {
                    let i = start + i_local;
                    let w_i = w[i];
                    let dxf_row = dxf_chunk.row(i_local);
                    let xf_row = xf_chunk.row(i_local);
                    for k in 0..rank {
                        design_total += dxf_row[k] * w_i * xf_row[k];
                    }
                    if let Some(c) = c_opt {
                        let c_i = c[i];
                        for k in 0..rank {
                            let v = xf_row[k];
                            correction_total += c_i * v * v;
                        }
                    }
                }
                design_totals[slot] = design_total;
                correction_totals[slot] = correction_total;
            }
            start = end;
        }
        // Finish each axis with its own penalty term.
        let mut out = Vec::with_capacity(n_axes);
        for (slot, (_axis, s_psi, _)) in axes.iter().enumerate() {
            let s_f = s_psi.dot(factor);
            let penalty: f64 = factor.iter().zip(s_f.iter()).map(|(&f, &s)| f * s).sum();
            out.push(2.0 * design_totals[slot] + correction_totals[slot] + penalty);
        }
        out
    }

    /// Adds `X^T diag(c) x_col` into `out_col` using caller-provided scratch
    /// buffers; no-op when the `c` correction is absent.
    fn accumulate_c_correction_xt_into(
        &self,
        x_col: ArrayView1<'_, f64>,
        mut n_work: ArrayViewMut1<'_, f64>,
        mut p_work: ArrayViewMut1<'_, f64>,
        mut out_col: ArrayViewMut1<'_, f64>,
    ) {
        let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() else {
            return;
        };
        let c = c_x_psi_beta.as_ref();
        debug_assert_eq!(x_col.len(), c.len());
        debug_assert_eq!(n_work.len(), c.len());
        debug_assert_eq!(p_work.len(), self.p);
        for i in 0..c.len() {
            n_work[i] = c[i] * x_col[i];
        }
        design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
        out_col += &p_work;
    }

    /// `(X v)^T diag(c) (X u)`; zero when the correction is absent.
    fn c_correction_bilinear(&self, x_v: &Array1<f64>, x_u: &Array1<f64>) -> f64 {
        let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() else {
            return 0.0;
        };
        x_v.iter()
            .zip(x_u.iter())
            .zip(c_x_psi_beta.iter())
            .map(|((&xv, &xu), &c)| xv * c * xu)
            .sum()
    }

    /// Bilinear form when the caller already has `x_vec = X z` and
    /// `y_vec = X u` (avoids recomputing the design products).
    pub fn bilinear_with_shared_x(
        &self,
        x_vec: &Array1<f64>,
        y_vec: &Array1<f64>,
        z: &Array1<f64>,
        u: &Array1<f64>,
    ) -> f64 {
        let dx_z = self
            .implicit_deriv
            .forward_mul(self.axis, &z.view())
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        let dx_u = self
            .implicit_deriv
            .forward_mul(self.axis, &u.view())
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        let mut design = 0.0f64;
        let w = &*self.w_diag;
        for i in 0..x_vec.len() {
            let wi = w[i];
            design += dx_z[i] * wi * y_vec[i];
            design += dx_u[i] * wi * x_vec[i];
        }
        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
            let c = c_x_psi_beta.as_ref();
            for i in 0..x_vec.len() {
                design += y_vec[i] * c[i] * x_vec[i];
            }
        }
        let penalty = dense_bilinear(&self.s_psi, z.view(), u.view());
        design + penalty
    }

    /// Allocating matvec given a precomputed `x_vec = X z`.
    pub fn matvec_with_shared_xz(&self, x_vec: &Array1<f64>, z: &Array1<f64>) -> Array1<f64> {
        // term1 = dX^T (W x_vec)
        let w_x_vec = &*self.w_diag * x_vec;
        let term1 = self
            .implicit_deriv
            .transpose_mul(self.axis, &w_x_vec.view())
            .expect("radial scalar evaluation failed during implicit hyper transpose_mul");
        // term2 = X^T (W dX z)
        let dx_z = self
            .implicit_deriv
            .forward_mul(self.axis, &z.view())
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        let w_dx_z = &*self.w_diag * &dx_z;
        let term2 = self.x_design.transpose_vector_multiply(&w_dx_z);
        // term3 = S_psi z
        let term3 = self.s_psi.dot(z);
        let mut out = term1 + term2 + term3;
        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
            let weighted = c_x_psi_beta.as_ref() * x_vec;
            out += &self.x_design.transpose_vector_multiply(&weighted);
        }
        out
    }

    /// Buffer-reusing variant of `matvec_with_shared_xz`: writes into `out`
    /// using caller-provided observation- and coefficient-space scratch.
    pub fn matvec_with_shared_xz_into(
        &self,
        x_vec: &Array1<f64>,
        z: ArrayView1<'_, f64>,
        mut out: ArrayViewMut1<'_, f64>,
        mut n_work: ArrayViewMut1<'_, f64>,
        mut p_work: ArrayViewMut1<'_, f64>,
    ) {
        debug_assert_eq!(z.len(), self.p);
        debug_assert_eq!(out.len(), self.p);
        debug_assert_eq!(n_work.len(), self.w_diag.len());
        debug_assert_eq!(p_work.len(), self.p);
        let w = &*self.w_diag;
        // out = dX^T (W x_vec)
        for i in 0..w.len() {
            n_work[i] = w[i] * x_vec[i];
        }
        let term1 = self
            .implicit_deriv
            .transpose_mul(self.axis, &n_work.view())
            .expect("radial scalar evaluation failed during implicit hyper transpose_mul");
        out.assign(&term1);
        // out += X^T (W dX z)
        let dx_z = self
            .implicit_deriv
            .forward_mul(self.axis, &z)
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        for i in 0..w.len() {
            n_work[i] = w[i] * dx_z[i];
        }
        design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
        out += &p_work;
        // out += S_psi z
        dense_matvec_into(&self.s_psi, z, p_work.view_mut());
        out += &p_work;
        // out += X^T diag(c) x_vec, when the correction is present.
        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
            let c = c_x_psi_beta.as_ref();
            for i in 0..w.len() {
                n_work[i] = c[i] * x_vec[i];
            }
            design_matrix_transpose_apply_view_into(
                &self.x_design,
                n_work.view(),
                p_work.view_mut(),
            );
            out += &p_work;
        }
    }
}
/// Directional hyperparameter-derivative operator built from an explicit
/// design-matrix directional derivative (`x_tau`) rather than an implicit one.
pub struct SparseDirectionalHyperOperator {
    /// Directional derivative of the design matrix.
    pub x_tau: super::HyperDesignDerivative,
    /// Design matrix X.
    pub x_design: DesignMatrix,
    /// Observation weights.
    pub w_diag: std::sync::Arc<Array1<f64>>,
    /// Dense penalty derivative along this direction (p x p).
    pub s_tau: Array2<f64>,
    /// Optional per-observation weights for the `X^T diag(c) X` correction.
    pub c_x_tau_beta: Option<Array1<f64>>,
    /// Optional Firth correction term; subtracted in `mul_vec`.
    pub firth_hphi_tau_partial: Option<Array2<f64>>,
    /// Coefficient dimension.
    pub p: usize,
}
impl HyperOperator for SparseDirectionalHyperOperator {
    fn dim(&self) -> usize {
        self.p
    }

    /// Applies the directional-derivative operator:
    /// `X_tau^T W X v + X^T W X_tau v + S_tau v`, plus the optional
    /// `X^T diag(c) X v` correction, minus the optional Firth partial term.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        debug_assert_eq!(v.len(), self.p);
        let x_v = self.x_design.matrixvectormultiply(v);
        let w_x_v = &*self.w_diag * &x_v;
        let term1 = self
            .x_tau
            .transpose_mul_original(&w_x_v)
            .expect("SparseDirectionalHyperOperator transpose product should be shape-consistent");
        let x_tau_v = self
            .x_tau
            .forward_mul_original(v)
            .expect("SparseDirectionalHyperOperator forward product should be shape-consistent");
        let w_x_tau_v = &*self.w_diag * &x_tau_v;
        let term2 = self.x_design.transpose_vector_multiply(&w_x_tau_v);
        let term3 = self.s_tau.dot(v);
        let mut out = term1 + term2 + term3;
        if let Some(c_x_tau_beta) = self.c_x_tau_beta.as_ref() {
            let weighted = c_x_tau_beta * &x_v;
            out += &self.x_design.transpose_vector_multiply(&weighted);
        }
        if let Some(hphi_tau_partial) = self.firth_hphi_tau_partial.as_ref() {
            // Firth correction enters with a minus sign.
            out -= &hphi_tau_partial.dot(v);
        }
        out
    }

    /// Dense materialization via p unit-basis applications.
    /// NOTE(review): relies on the trait's default `mul_vec_into` (not
    /// visible in this chunk); confirm it forwards to `mul_vec` here.
    fn to_dense(&self) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((self.p, self.p));
        let mut basis = Array1::<f64>::zeros(self.p);
        for j in 0..self.p {
            basis[j] = 1.0;
            self.mul_vec_into(basis.view(), out.column_mut(j));
            basis[j] = 0.0;
        }
        out
    }

    fn is_implicit(&self) -> bool {
        false
    }
}
/// A penalty log-determinant value bundled with its first- and (optionally)
/// second-order derivatives. The differentiation variables are defined by
/// the producing code, which is not visible in this chunk.
#[derive(Clone, Debug)]
pub struct PenaltyLogdetDerivs {
    /// The log-determinant value itself.
    pub value: f64,
    /// First-derivative (gradient) vector.
    pub first: Array1<f64>,
    /// Second-derivative (Hessian) matrix; `None` when not computed.
    pub second: Option<Array2<f64>>,
}
/// The penalty matrix for one smoothing coordinate, in whichever structured
/// form is cheapest: a dense square root, a block-confined square root, or a
/// diagonal Kronecker-marginal spectrum.
#[derive(Clone, Debug)]
pub enum PenaltyCoordinate {
    /// Full penalty given by its root R, so the penalty is S = R^T R
    /// (see `scaled_dense_matrix`).
    DenseRoot(Array2<f64>),
    /// Root acting only on coefficients `start..end` of a `total_dim` vector.
    BlockRoot {
        root: Array2<f64>,
        start: usize,
        end: usize,
        total_dim: usize,
    },
    /// Penalty that is diagonal in the row-major tensor-product ordering:
    /// the eigenvalues of the marginal at `dim_index`, replicated across the
    /// other marginal dimensions.
    KroneckerMarginal {
        /// Eigenvalues of each marginal penalty.
        eigenvalues: Vec<Array1<f64>>,
        /// Which marginal this coordinate penalizes.
        dim_index: usize,
        /// Sizes of all marginals in tensor-product order.
        marginal_dims: Vec<usize>,
        /// Product of marginal sizes; equals the coefficient dimension.
        total_dim: usize,
    },
}
impl PenaltyCoordinate {
    /// Wraps a dense penalty square root R (penalty S = R^T R).
    pub fn from_dense_root(root: Array2<f64>) -> Self {
        Self::DenseRoot(root)
    }

    /// Wraps a root acting only on coefficients `start..end` of a
    /// `total_dim` vector; `root` must have `end - start` columns.
    pub fn from_block_root(root: Array2<f64>, start: usize, end: usize, total_dim: usize) -> Self {
        assert_eq!(root.ncols(), end.saturating_sub(start));
        assert!(end <= total_dim);
        Self::BlockRoot {
            root,
            start,
            end,
            total_dim,
        }
    }

    /// Penalty rank: row count of the root, or (Kronecker case) the number
    /// of nonzero marginal eigenvalues times the sizes of the remaining
    /// marginals.
    pub fn rank(&self) -> usize {
        match self {
            Self::DenseRoot(root) | Self::BlockRoot { root, .. } => root.nrows(),
            Self::KroneckerMarginal {
                eigenvalues,
                dim_index,
                ..
            } => {
                // Marginal eigenvalues with |v| <= 1e-12 count as zero.
                let nz = eigenvalues[*dim_index]
                    .iter()
                    .filter(|&&v| v.abs() > 1e-12)
                    .count();
                let other: usize = eigenvalues
                    .iter()
                    .enumerate()
                    .filter(|&(j, _)| j != *dim_index)
                    .map(|(_, e)| e.len())
                    .product::<usize>()
                    .max(1);
                nz * other
            }
        }
    }

    /// Total coefficient dimension the penalty acts on.
    pub fn dim(&self) -> usize {
        match self {
            Self::DenseRoot(root) => root.ncols(),
            Self::BlockRoot { total_dim, .. } | Self::KroneckerMarginal { total_dim, .. } => {
                *total_dim
            }
        }
    }

    /// Structured variants avoid materializing the full dense penalty.
    pub fn uses_operator_fast_path(&self) -> bool {
        matches!(
            self,
            Self::BlockRoot { .. } | Self::KroneckerMarginal { .. }
        )
    }

    /// Applies the root R to `beta`. Only valid for root-based variants;
    /// the Kronecker form stores a spectrum, not a root.
    fn apply_root(&self, beta: &Array1<f64>) -> Array1<f64> {
        debug_assert_eq!(beta.len(), self.dim());
        match self {
            Self::DenseRoot(root) => root.dot(beta),
            Self::BlockRoot {
                root, start, end, ..
            } => root.dot(&beta.slice(ndarray::s![*start..*end])),
            Self::KroneckerMarginal { .. } => {
                panic!(
                    "apply_root not supported for KroneckerMarginal; use apply_penalty directly"
                );
            }
        }
    }

    /// Allocating form of `scale * S beta`.
    pub fn apply_penalty(&self, beta: &Array1<f64>, scale: f64) -> Array1<f64> {
        debug_assert_eq!(beta.len(), self.dim());
        let mut out = Array1::<f64>::zeros(self.dim());
        self.apply_penalty_view_into(beta.view(), scale, out.view_mut());
        out
    }

    /// Writes `scale * S beta` into `out` (zero-fills first).
    pub fn apply_penalty_view_into(
        &self,
        beta: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        debug_assert_eq!(beta.len(), self.dim());
        debug_assert_eq!(out.len(), self.dim());
        out.fill(0.0);
        self.scaled_add_penalty_view(beta, scale, out);
    }

    /// Accumulates `scale * S beta` into `out` without zero-filling.
    pub fn scaled_add_penalty_view(
        &self,
        beta: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        debug_assert_eq!(beta.len(), self.dim());
        debug_assert_eq!(out.len(), self.dim());
        if scale == 0.0 {
            return;
        }
        match self {
            // Root-based variants: out += scale * R^T (R beta).
            // The inner match only re-destructures the same variants, so the
            // `_ => unreachable!()` arm cannot be hit.
            Self::DenseRoot(_) | Self::BlockRoot { .. } => match self {
                Self::DenseRoot(root) => {
                    let mut root_beta = Array1::<f64>::zeros(root.nrows());
                    dense_matvec_into(root, beta, root_beta.view_mut());
                    dense_transpose_matvec_scaled_add_into(
                        root,
                        root_beta.view(),
                        scale,
                        out.view_mut(),
                    );
                }
                Self::BlockRoot {
                    root,
                    start,
                    end,
                    total_dim: _,
                } => {
                    let beta_block = beta.slice(ndarray::s![*start..*end]);
                    let mut root_beta = Array1::<f64>::zeros(root.nrows());
                    dense_matvec_into(root, beta_block, root_beta.view_mut());
                    let out_block = out.slice_mut(ndarray::s![*start..*end]);
                    dense_transpose_matvec_scaled_add_into(
                        root,
                        root_beta.view(),
                        scale,
                        out_block,
                    );
                }
                _ => unreachable!(),
            },
            Self::KroneckerMarginal {
                eigenvalues,
                dim_index,
                marginal_dims,
                total_dim,
            } => {
                // Row-major tensor layout: index = outer * (q_k * stride_k)
                // + j * stride_k + inner, where j runs over the penalized
                // marginal; the penalty is diagonal with eigenvalue eigs[j].
                let k = *dim_index;
                let q_k = marginal_dims[k];
                let stride_k: usize = marginal_dims[k + 1..]
                    .iter()
                    .copied()
                    .product::<usize>()
                    .max(1);
                let outer_size: usize =
                    marginal_dims[..k].iter().copied().product::<usize>().max(1);
                let inner_size = stride_k;
                let eigs = &eigenvalues[k];
                debug_assert_eq!(
                    outer_size * q_k * stride_k,
                    *total_dim,
                    "KroneckerMarginal dimension mismatch in apply"
                );
                for outer in 0..outer_size {
                    for j in 0..q_k {
                        let mu = eigs[j] * scale;
                        if mu == 0.0 {
                            continue;
                        }
                        let base = outer * q_k * stride_k + j * stride_k;
                        for inner in 0..inner_size {
                            let idx = base + inner;
                            out[idx] += mu * beta[idx];
                        }
                    }
                }
            }
        }
    }

    /// Quadratic form `scale * beta^T S beta`.
    pub fn quadratic(&self, beta: &Array1<f64>, scale: f64) -> f64 {
        match self {
            // ||R beta||^2 scaled.
            Self::DenseRoot(_) | Self::BlockRoot { .. } => {
                let root_beta = self.apply_root(beta);
                scale * root_beta.dot(&root_beta)
            }
            Self::KroneckerMarginal {
                eigenvalues,
                dim_index,
                marginal_dims,
                ..
            } => {
                let k = *dim_index;
                let q_k = marginal_dims[k];
                let stride_k: usize = marginal_dims[k + 1..]
                    .iter()
                    .copied()
                    .product::<usize>()
                    .max(1);
                let outer_size: usize =
                    marginal_dims[..k].iter().copied().product::<usize>().max(1);
                let inner_size = stride_k;
                let eigs = &eigenvalues[k];
                let mut sum = 0.0;
                for outer in 0..outer_size {
                    for j in 0..q_k {
                        let mu = eigs[j];
                        if mu == 0.0 {
                            continue;
                        }
                        let base = outer * q_k * stride_k + j * stride_k;
                        for inner in 0..inner_size {
                            let v = beta[base + inner];
                            sum += mu * v * v;
                        }
                    }
                }
                sum * scale
            }
        }
    }

    /// Full dense `scale * S` (embedding block/Kronecker structure).
    pub fn scaled_dense_matrix(&self, scale: f64) -> Array2<f64> {
        match self {
            Self::DenseRoot(root) => {
                let mut out = root.t().dot(root);
                out *= scale;
                out
            }
            Self::BlockRoot {
                root,
                start,
                end,
                total_dim,
            } => {
                let mut out = Array2::<f64>::zeros((*total_dim, *total_dim));
                let mut block = root.t().dot(root);
                block *= scale;
                out.slice_mut(ndarray::s![*start..*end, *start..*end])
                    .assign(&block);
                out
            }
            Self::KroneckerMarginal {
                eigenvalues,
                dim_index,
                marginal_dims,
                total_dim,
            } => {
                let k = *dim_index;
                let q_k = marginal_dims[k];
                let stride_k: usize = marginal_dims[k + 1..]
                    .iter()
                    .copied()
                    .product::<usize>()
                    .max(1);
                let outer_size: usize =
                    marginal_dims[..k].iter().copied().product::<usize>().max(1);
                let eigs = &eigenvalues[k];
                debug_assert_eq!(
                    outer_size * q_k * stride_k,
                    *total_dim,
                    "KroneckerMarginal dimension mismatch in to_dense"
                );
                // The Kronecker penalty is diagonal in this ordering.
                let mut out = Array2::<f64>::zeros((*total_dim, *total_dim));
                for outer in 0..outer_size {
                    for j in 0..q_k {
                        let mu = eigs[j] * scale;
                        let base = outer * q_k * stride_k + j * stride_k;
                        for inner in 0..stride_k {
                            let idx = base + inner;
                            out[[idx, idx]] = mu;
                        }
                    }
                }
                out
            }
        }
    }

    /// Dense block plus its `(start, end)` extent within the full dimension.
    /// Kronecker penalties fall back to the full dense matrix over `0..dim`.
    pub fn scaled_block_local(&self, scale: f64) -> (Array2<f64>, usize, usize) {
        match self {
            Self::DenseRoot(root) => {
                let mut out = root.t().dot(root);
                out *= scale;
                let p = out.nrows();
                (out, 0, p)
            }
            Self::BlockRoot {
                root, start, end, ..
            } => {
                let mut block = root.t().dot(root);
                block *= scale;
                (block, *start, *end)
            }
            Self::KroneckerMarginal { total_dim, .. } => {
                let mat = self.scaled_dense_matrix(scale);
                (mat, 0, *total_dim)
            }
        }
    }

    /// Whether `scaled_block_local` can return a proper sub-block.
    pub fn is_block_local(&self) -> bool {
        matches!(
            self,
            Self::BlockRoot { .. } | Self::KroneckerMarginal { .. }
        )
    }

    /// Allocating `scale * S v` using each variant's structure.
    pub fn scaled_matvec(&self, v: &Array1<f64>, scale: f64) -> Array1<f64> {
        match self {
            Self::DenseRoot(root) => {
                let root_v = root.dot(v);
                let mut out = root.t().dot(&root_v);
                out *= scale;
                out
            }
            Self::BlockRoot {
                root, start, end, ..
            } => {
                let mut out = Array1::zeros(v.len());
                let v_block = v.slice(ndarray::s![*start..*end]);
                let root_v = root.dot(&v_block);
                let mut block_result = root.t().dot(&root_v);
                block_result *= scale;
                out.slice_mut(ndarray::s![*start..*end])
                    .assign(&block_result);
                out
            }
            Self::KroneckerMarginal { .. } => {
                self.apply_penalty(v, scale)
            }
        }
    }

    /// `scale * tr(S M)` without forming `S M`: for roots this is
    /// `scale * sum_ij (R M)_ij R_ij`; the Kronecker case sums eigenvalue-
    /// weighted diagonal entries of `m`.
    pub fn trace_with_dense(&self, m: &Array2<f64>, scale: f64) -> f64 {
        match self {
            Self::DenseRoot(root) => {
                let rm = root.dot(m);
                scale
                    * rm.iter()
                        .zip(root.iter())
                        .map(|(&a, &b)| a * b)
                        .sum::<f64>()
            }
            Self::BlockRoot {
                root, start, end, ..
            } => {
                let m_block = m.slice(ndarray::s![*start..*end, *start..*end]);
                let rm = root.dot(&m_block);
                scale
                    * rm.iter()
                        .zip(root.iter())
                        .map(|(&a, &b)| a * b)
                        .sum::<f64>()
            }
            Self::KroneckerMarginal {
                eigenvalues,
                dim_index,
                marginal_dims,
                ..
            } => {
                let k = *dim_index;
                let q_k = marginal_dims[k];
                let stride_k: usize = marginal_dims[k + 1..]
                    .iter()
                    .copied()
                    .product::<usize>()
                    .max(1);
                let outer_size: usize =
                    marginal_dims[..k].iter().copied().product::<usize>().max(1);
                let eigs = &eigenvalues[k];
                let mut trace = 0.0;
                for outer in 0..outer_size {
                    for j in 0..q_k {
                        let mu = eigs[j];
                        let base = outer * q_k * stride_k + j * stride_k;
                        for inner in 0..stride_k {
                            let idx = base + inner;
                            trace += mu * m[[idx, idx]];
                        }
                    }
                }
                trace * scale
            }
        }
    }

    /// Borrows this coordinate as a `HyperOperator` with a fixed scale and
    /// an optional additive dense correction.
    pub fn scaled_operator<'a>(
        &'a self,
        scale: f64,
        dense_correction: Option<&'a Array2<f64>>,
    ) -> PenaltyHyperOperator<'a> {
        PenaltyHyperOperator {
            coord: self,
            scale,
            dense_correction,
        }
    }
}
/// Borrowed adapter exposing a scaled `PenaltyCoordinate`, plus an optional
/// additive dense correction, as a `HyperOperator`.
pub struct PenaltyHyperOperator<'a> {
    /// The structured penalty being wrapped.
    coord: &'a PenaltyCoordinate,
    /// Multiplier applied to the penalty.
    scale: f64,
    /// Optional dense matrix added on top of the scaled penalty.
    dense_correction: Option<&'a Array2<f64>>,
}
/// `HyperOperator` view of `scale * S (+ dense_correction)`.
impl HyperOperator for PenaltyHyperOperator<'_> {
    fn dim(&self) -> usize {
        self.coord.dim()
    }

    /// Allocating matvec; identical to `mul_vec_view` on the borrowed data.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        self.mul_vec_view(v.view())
    }

    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut result = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v, result.view_mut());
        result
    }

    /// Writes `scale * S v` into `out`, then adds the correction if present.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        self.coord
            .apply_penalty_view_into(v, self.scale, out.view_mut());
        if let Some(extra) = self.dense_correction {
            dense_matvec_scaled_add_into(extra, v, 1.0, out.view_mut());
        }
    }

    /// Accumulates `scale_arg * (scale * S + correction) v` into `out`.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        self.coord
            .scaled_add_penalty_view(v, scale * self.scale, out.view_mut());
        if let Some(extra) = self.dense_correction {
            dense_matvec_scaled_add_into(extra, v, scale, out.view_mut());
        }
    }

    /// Full dense `scale * S (+ correction)`.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = self.coord.scaled_dense_matrix(self.scale);
        if let Some(extra) = self.dense_correction {
            dense += extra;
        }
        dense
    }

    fn is_implicit(&self) -> bool {
        false
    }
}
pub(crate) fn exact_intersection_nullity(
penalties: &[Array2<f64>],
nullspace_dims: &[usize],
) -> usize {
if penalties.is_empty() || nullspace_dims.is_empty() {
return 0;
}
if penalties.len() != nullspace_dims.len() {
return 0;
}
if nullspace_dims.iter().any(|&m| m == 0) {
return 0;
}
if penalties.len() == 1 {
return nullspace_dims[0];
}
let p = penalties[0].nrows();
let (_, vecs0) = match penalties[0].eigh(faer::Side::Lower) {
Ok(ev) => ev,
Err(_) => return 0,
};
let m0 = nullspace_dims[0].min(p);
let mut n_basis = Array2::<f64>::zeros((p, m0));
for col in 0..m0 {
for row in 0..p {
n_basis[[row, col]] = vecs0[[row, col]];
}
}
const SHARED_DIR_THRESHOLD: f64 = 0.99;
for k in 1..penalties.len() {
let current_dim = n_basis.ncols();
if current_dim == 0 {
return 0;
}
let (_, vecs_k) = match penalties[k].eigh(faer::Side::Lower) {
Ok(ev) => ev,
Err(_) => return 0,
};
let mk = nullspace_dims[k].min(p);
let mut nk_basis = Array2::<f64>::zeros((p, mk));
for col in 0..mk {
for row in 0..p {
nk_basis[[row, col]] = vecs_k[[row, col]];
}
}
let m_mat = n_basis.t().dot(&nk_basis);
let (u_opt, s, _) = match crate::faer_ndarray::FaerSvd::svd(&m_mat, true, false) {
Ok(usv) => usv,
Err(_) => return 0,
};
let u = match u_opt {
Some(u) => u,
None => return 0,
};
let shared: Vec<usize> = s
.iter()
.enumerate()
.filter(|(_, sv)| **sv > SHARED_DIR_THRESHOLD)
.map(|(i, _)| i)
.collect();
if shared.is_empty() {
return 0;
}
let mut n_new = Array2::<f64>::zeros((p, shared.len()));
for (new_col, &orig_col) in shared.iter().enumerate() {
for row in 0..p {
let mut val = 0.0;
for j in 0..current_dim {
val += n_basis[[row, j]] * u[[j, orig_col]];
}
n_new[[row, new_col]] = val;
}
}
n_basis = n_new;
}
n_basis.ncols()
}
/// Cutoff below which an eigenvalue is treated as numerically zero.
///
/// Scales machine epsilon by the problem size and the largest absolute
/// eigenvalue (floored at 1.0), with a fixed safety factor of 100.
pub(crate) fn positive_eigenvalue_threshold(eigenvalues: &[f64]) -> f64 {
    const SAFETY: f64 = 100.0;
    let mut largest = 0.0_f64;
    for &ev in eigenvalues {
        largest = largest.max(ev.abs());
    }
    let largest = largest.max(1.0);
    SAFETY * (eigenvalues.len() as f64) * f64::EPSILON * largest
}
/// Log pseudo-determinant: sum of `ln` over eigenvalues strictly greater
/// than `threshold`; values at or below the threshold are ignored.
pub(crate) fn exact_pseudo_logdet(eigenvalues: &[f64], threshold: f64) -> f64 {
    let mut logdet = 0.0;
    for &ev in eigenvalues {
        if ev > threshold {
            logdet += ev.ln();
        }
    }
    logdet
}
/// Projected-subspace data for cheap trace computations: traces of
/// H⁻¹-like products are evaluated inside the r-dimensional subspace
/// spanned by the columns of `u_s` instead of the full parameter space.
#[derive(Clone, Debug)]
pub struct PenaltySubspaceTrace {
// Subspace basis, p × r (p = full parameter dimension; shapes are checked
// by the debug asserts in `xt_projected_kernel_x_diagonal`).
pub u_s: Array2<f64>,
// r × r matrix traced against reduced operators — presumably
// (U_sᵀ H U_s)⁻¹ as the name suggests; TODO confirm at the construction site.
pub h_proj_inverse: Array2<f64>,
}
impl PenaltySubspaceTrace {
pub fn trace_projected_logdet(&self, a: &Array2<f64>) -> f64 {
self.trace_projected_logdet_reduced(&self.reduce(a))
}
pub fn reduce(&self, a: &Array2<f64>) -> Array2<f64> {
let u_s_t_a = crate::faer_ndarray::fast_atb(&self.u_s, a);
crate::faer_ndarray::fast_ab(&u_s_t_a, &self.u_s)
}
pub fn trace_projected_logdet_reduced(&self, r_mat: &Array2<f64>) -> f64 {
let mut trace = 0.0;
let r = self.h_proj_inverse.nrows();
for i in 0..r {
for j in 0..r {
trace += self.h_proj_inverse[[i, j]] * r_mat[[j, i]];
}
}
trace
}
pub fn trace_projected_logdet_cross_reduced(&self, ra: &Array2<f64>, rb: &Array2<f64>) -> f64 {
let left = self.h_proj_inverse.dot(ra);
let right = self.h_proj_inverse.dot(rb);
let r = left.nrows();
let mut trace = 0.0;
for i in 0..r {
for j in 0..r {
trace += left[[i, j]] * right[[j, i]];
}
}
trace
}
pub fn reduce_operator(&self, a: &dyn HyperOperator) -> Array2<f64> {
let au = a.mul_mat(&self.u_s);
crate::faer_ndarray::fast_atb(&self.u_s, &au)
}
pub fn trace_operator(&self, a: &dyn HyperOperator) -> f64 {
self.trace_projected_logdet_reduced(&self.reduce_operator(a))
}
pub fn xt_projected_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
let n = x.nrows();
let p = x.ncols();
let r = self.u_s.ncols();
debug_assert_eq!(self.u_s.nrows(), p);
debug_assert_eq!(self.h_proj_inverse.nrows(), r);
debug_assert_eq!(self.h_proj_inverse.ncols(), r);
let block = {
const TARGET_CHUNK_FLOATS: usize = 1 << 16;
(TARGET_CHUNK_FLOATS / p.max(1)).clamp(1, n.max(1))
};
let mut h = Array1::<f64>::zeros(n);
let mut start = 0usize;
while start < n {
let end = (start + block).min(n);
let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
panic!("xt_projected_kernel_x_diagonal: row chunk failed: {err}")
});
let z_chunk = crate::faer_ndarray::fast_ab(&rows.to_owned(), &self.u_s);
for i in 0..(end - start) {
let row_z = z_chunk.row(i);
let mut acc = 0.0;
for a in 0..r {
let mut inner = 0.0;
for b in 0..r {
inner += self.h_proj_inverse[[a, b]] * row_z[b];
}
acc += row_z[a] * inner;
}
h[start + i] = acc;
}
start = end;
}
h
}
}
/// How the dispersion parameter φ enters the outer criterion.
#[derive(Clone, Debug)]
pub enum DispersionHandling {
// Gaussian case: φ is profiled out of the deviance analytically (see
// `reml_laml_evaluate` and `efs_profiling`).
ProfiledGaussian,
// φ is known and fixed; the flags choose whether the log|H| and log|S|
// terms participate in the cost and its derivatives.
Fixed {
phi: f64,
include_logdet_h: bool,
include_logdet_s: bool,
},
}
/// Everything the outer REML/LAML evaluation needs from one inner fit.
pub struct InnerSolution<'dp> {
// Log-likelihood of the fitted inner model.
pub log_likelihood: f64,
// Penalty quadratic form at the fit; combines with -2·log_likelihood into
// the deviance under profiled Gaussian dispersion.
pub penalty_quadratic: f64,
// Penalized Hessian, used through logdet/solve/trace operations.
pub hessian_op: Arc<dyn HessianOperator>,
// Fitted coefficient vector.
pub beta: Array1<f64>,
// Per-smooth penalty coordinates, one per ρ entry.
pub penalty_coords: Vec<PenaltyCoordinate>,
// log|S| value and its derivative data in ρ (`.value` and `.first` are
// used in `reml_laml_evaluate`).
pub penalty_logdet: PenaltyLogdetDerivs,
// Supplies family-specific Hessian-derivative corrections.
pub deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
// Additive cost correction (applied when log|H| is included) …
pub tk_correction: f64,
// … and its optional gradient added onto the ρ block of the gradient.
pub tk_gradient: Option<Array1<f64>>,
// Optional exact Jeffreys (Firth) term, subtracted from the cost.
pub firth: Option<ExactJeffreysTerm>,
// Additive adjustment folded into log|H|.
pub hessian_logdet_correction: f64,
// Optional projected-subspace shortcut for the log|H| trace terms.
pub penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
// Multiplier applied to λ inside curvature (trace) terms; see
// `rho_curvature_lambda`.
pub rho_curvature_scale: f64,
pub n_observations: usize,
// Effective nullspace dimension (may be fractional).
pub nullspace_dim: f64,
// How the dispersion φ is handled (profiled vs fixed).
pub dispersion: DispersionHandling,
// Extra hyperparameter coordinates beyond the ρ penalties.
pub ext_coords: Vec<HyperCoord>,
// Pairwise second-derivative callbacks (ext×ext and ρ×ext respectively) —
// consumed by the outer Hessian assembly, beyond this chunk's view.
pub ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
pub rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
// Derivative hook for fixed drift terms — NOTE(review): not used in this
// chunk; confirm semantics against the outer Hessian code.
pub fixed_drift_deriv: Option<FixedDriftDerivFn>,
// Optional feasibility barrier added to the outer cost.
pub barrier_config: Option<BarrierConfig>,
}
/// Builder for [`InnerSolution`]; fields mirror that struct one-for-one,
/// except `nullspace_dim_override`, which (when set) replaces the default
/// nullspace dimension derived from penalty ranks in `build`.
pub struct InnerSolutionBuilder<'dp> {
log_likelihood: f64,
penalty_quadratic: f64,
hessian_op: Arc<dyn HessianOperator>,
beta: Array1<f64>,
penalty_coords: Vec<PenaltyCoordinate>,
penalty_logdet: PenaltyLogdetDerivs,
n_observations: usize,
dispersion: DispersionHandling,
deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
tk_correction: f64,
tk_gradient: Option<Array1<f64>>,
firth: Option<ExactJeffreysTerm>,
hessian_logdet_correction: f64,
penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
rho_curvature_scale: f64,
// When `Some`, `build` uses this instead of p − Σ rank(S_k).
nullspace_dim_override: Option<f64>,
ext_coords: Vec<HyperCoord>,
ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
fixed_drift_deriv: Option<FixedDriftDerivFn>,
barrier_config: Option<BarrierConfig>,
}
impl<'dp> InnerSolutionBuilder<'dp> {
pub fn new(
log_likelihood: f64,
penalty_quadratic: f64,
beta: Array1<f64>,
n_observations: usize,
hessian_op: Arc<dyn HessianOperator>,
penalty_coords: Vec<PenaltyCoordinate>,
penalty_logdet: PenaltyLogdetDerivs,
dispersion: DispersionHandling,
) -> Self {
Self {
log_likelihood,
penalty_quadratic,
hessian_op,
beta,
penalty_coords,
penalty_logdet,
n_observations,
dispersion,
deriv_provider: Box::new(GaussianDerivatives),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
nullspace_dim_override: None,
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
}
}
pub fn deriv_provider(mut self, p: Box<dyn HessianDerivativeProvider + 'dp>) -> Self {
self.deriv_provider = p;
self
}
pub fn tk(mut self, correction: f64, gradient: Option<Array1<f64>>) -> Self {
self.tk_correction = correction;
self.tk_gradient = gradient;
self
}
pub fn firth(mut self, op: Option<std::sync::Arc<super::FirthDenseOperator>>) -> Self {
self.firth = op.map(ExactJeffreysTerm::new);
self
}
pub fn hessian_logdet_correction(mut self, correction: f64) -> Self {
self.hessian_logdet_correction = correction;
self
}
pub fn penalty_subspace_trace(mut self, kernel: Option<Arc<PenaltySubspaceTrace>>) -> Self {
self.penalty_subspace_trace = kernel;
self
}
pub fn rho_curvature_scale(mut self, scale: f64) -> Self {
self.rho_curvature_scale = scale;
self
}
pub fn nullspace_dim_override(mut self, dim: f64) -> Self {
self.nullspace_dim_override = Some(dim);
self
}
pub fn ext_coords(mut self, coords: Vec<HyperCoord>) -> Self {
self.ext_coords = coords;
self
}
pub fn ext_coord_pair_fn(
mut self,
f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
) -> Self {
self.ext_coord_pair_fn = Some(f);
self
}
pub fn rho_ext_pair_fn(
mut self,
f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
) -> Self {
self.rho_ext_pair_fn = Some(f);
self
}
pub fn fixed_drift_deriv(mut self, f: FixedDriftDerivFn) -> Self {
self.fixed_drift_deriv = Some(f);
self
}
pub fn barrier_config(mut self, config: Option<BarrierConfig>) -> Self {
self.barrier_config = config;
self
}
pub fn build(self) -> InnerSolution<'dp> {
let nullspace_dim = self.nullspace_dim_override.unwrap_or_else(|| {
let total_p = self.beta.len();
let penalty_rank: usize = self
.penalty_coords
.iter()
.map(PenaltyCoordinate::rank)
.sum();
total_p.saturating_sub(penalty_rank) as f64
});
InnerSolution {
log_likelihood: self.log_likelihood,
penalty_quadratic: self.penalty_quadratic,
hessian_op: self.hessian_op,
beta: self.beta,
penalty_coords: self.penalty_coords,
penalty_logdet: self.penalty_logdet,
deriv_provider: self.deriv_provider,
tk_correction: self.tk_correction,
tk_gradient: self.tk_gradient,
firth: self.firth,
hessian_logdet_correction: self.hessian_logdet_correction,
penalty_subspace_trace: self.penalty_subspace_trace,
rho_curvature_scale: self.rho_curvature_scale,
n_observations: self.n_observations,
nullspace_dim,
dispersion: self.dispersion,
ext_coords: self.ext_coords,
ext_coord_pair_fn: self.ext_coord_pair_fn,
rho_ext_pair_fn: self.rho_ext_pair_fn,
fixed_drift_deriv: self.fixed_drift_deriv,
barrier_config: self.barrier_config,
}
}
}
/// Which derivative orders `reml_laml_evaluate` should produce.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EvalMode {
// Cost only: gradient is `None` and the Hessian result is `Unavailable`.
ValueOnly,
// Cost plus gradient.
ValueAndGradient,
// Cost, gradient, and the outer Hessian.
ValueGradientHessian,
}
/// Output bundle of `reml_laml_evaluate`.
pub struct RemlLamlResult {
// Outer REML/LAML objective value.
pub cost: f64,
// Gradient over (ρ, ext) coordinates; `None` when only the value was requested.
pub gradient: Option<Array1<f64>>,
// Outer Hessian in whatever representation the evaluation route produced.
pub hessian: crate::solver::outer_strategy::HessianResult,
}
use crate::solver::estimate::smooth_floor_dp;
// Floor applied to the profiled-dispersion denominator
// (n_observations − nullspace_dim) so the scale estimate never divides by ~0.
const DENOM_RIDGE: f64 = 1e-8;
// Penalty drift vector for one coordinate: applies `coord`'s penalty at
// weight `lambda` to the fitted coefficients (the S_k(λ)·β term consumed by
// the outer gradient in `reml_laml_evaluate`).
fn penalty_a_k_beta(coord: &PenaltyCoordinate, beta: &Array1<f64>, lambda: f64) -> Array1<f64> {
coord.apply_penalty(beta, lambda)
}
// Penalty quadratic form for one coordinate at weight `lambda`
// (βᵀ·S_k(λ)·β per the coordinate's `quadratic` implementation).
fn penalty_a_k_quadratic(coord: &PenaltyCoordinate, beta: &Array1<f64>, lambda: f64) -> f64 {
coord.quadratic(beta, lambda)
}
// λ used for the log|H| curvature (trace) terms: the raw λ rescaled by the
// solution-level `rho_curvature_scale` factor.
#[inline]
fn rho_curvature_lambda(solution: &InnerSolution<'_>, lambda: f64) -> f64 {
solution.rho_curvature_scale * lambda
}
/// Wrap an owned, scaled `PenaltyCoordinate` as a `HyperOperator`
/// (owned counterpart of `PenaltyHyperOperator`, without a correction term).
fn penalty_coord_to_operator(coord: PenaltyCoordinate, scale: f64) -> Arc<dyn HyperOperator> {
    struct OwnedScaledPenalty {
        coord: PenaltyCoordinate,
        scale: f64,
    }
    impl HyperOperator for OwnedScaledPenalty {
        fn dim(&self) -> usize {
            self.coord.dim()
        }
        // Owned-vector product routes through the view-based path.
        fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
            self.mul_vec_view(v.view())
        }
        fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
            let mut product = Array1::<f64>::zeros(v.len());
            self.mul_vec_into(v, product.view_mut());
            product
        }
        fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
            self.coord.apply_penalty_view_into(v, self.scale, out);
        }
        fn scaled_add_mul_vec(
            &self,
            v: ArrayView1<'_, f64>,
            scale: f64,
            out: ArrayViewMut1<'_, f64>,
        ) {
            if scale != 0.0 {
                self.coord
                    .scaled_add_penalty_view(v, scale * self.scale, out);
            }
        }
        fn to_dense(&self) -> Array2<f64> {
            self.coord.scaled_dense_matrix(self.scale)
        }
        fn is_implicit(&self) -> bool {
            false
        }
    }
    Arc::new(OwnedScaledPenalty { coord, scale })
}
fn penalty_total_drift_result(
coord: &PenaltyCoordinate,
scale: f64,
correction: Option<&DriftDerivResult>,
) -> DriftDerivResult {
match correction {
Some(DriftDerivResult::Dense(corr)) => {
if coord.uses_operator_fast_path() {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: Some(corr.clone()),
operators: vec![penalty_coord_to_operator(coord.clone(), scale)],
dim_hint: coord.dim(),
}))
} else {
let mut dense = coord.scaled_dense_matrix(scale);
dense += corr;
DriftDerivResult::Dense(dense)
}
}
Some(DriftDerivResult::Operator(corr_op)) => {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: if coord.uses_operator_fast_path() {
None
} else {
Some(coord.scaled_dense_matrix(scale))
},
operators: {
let mut ops = vec![Arc::clone(corr_op)];
if coord.uses_operator_fast_path() {
ops.push(penalty_coord_to_operator(coord.clone(), scale));
}
ops
},
dim_hint: coord.dim(),
}))
}
None => {
if coord.uses_operator_fast_path() {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: None,
operators: vec![penalty_coord_to_operator(coord.clone(), scale)],
dim_hint: coord.dim(),
}))
} else {
DriftDerivResult::Dense(coord.scaled_dense_matrix(scale))
}
}
}
}
/// Collect the operator-form pieces of a drift (block-local part first,
/// then the generic operator), skipping whichever is absent.
fn hyper_coord_drift_operators(drift: &HyperCoordDrift) -> Vec<Arc<dyn HyperOperator>> {
    let block_local = drift
        .block_local
        .as_ref()
        .map(|b| Arc::new(b.clone()) as Arc<dyn HyperOperator>);
    let generic = drift.operator.as_ref().map(Arc::clone);
    block_local.into_iter().chain(generic).collect()
}
/// Fold a drift into a single operator: `None` when there are no operator
/// parts, the lone operator when it is the only piece, otherwise a composite
/// carrying the dense part alongside the operators.
fn hyper_coord_drift_operator_arc(
    drift: &HyperCoordDrift,
    dim_hint: usize,
) -> Option<Arc<dyn HyperOperator>> {
    let mut operators = hyper_coord_drift_operators(drift);
    match (drift.dense.is_none(), operators.len()) {
        (_, 0) => None,
        (true, 1) => operators.pop(),
        _ => Some(Arc::new(CompositeHyperOperator {
            dense: drift.dense.clone(),
            operators,
            dim_hint,
        })),
    }
}
/// Pack split drift parts into a `DriftDerivResult`: dense-only (or empty →
/// zero matrix), a single bare operator, or a composite of both.
fn drift_parts_into_result(
    dense: Option<Array2<f64>>,
    mut operators: Vec<Arc<dyn HyperOperator>>,
    dim_hint: usize,
) -> DriftDerivResult {
    match (dense, operators.len()) {
        (dense, 0) => DriftDerivResult::Dense(
            dense.unwrap_or_else(|| Array2::<f64>::zeros((dim_hint, dim_hint))),
        ),
        (None, 1) => DriftDerivResult::Operator(operators.pop().expect("single operator drift")),
        (dense, _) => DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
            dense,
            operators,
            dim_hint,
        })),
    }
}
/// Split a drift plus optional correction into (dense part, operator parts):
/// a dense correction is summed into (or becomes) the dense part, while an
/// operator correction is appended after the drift's own operators.
fn hyper_coord_total_drift_parts(
    drift: &HyperCoordDrift,
    correction: Option<&DriftDerivResult>,
) -> (Option<Array2<f64>>, Vec<Arc<dyn HyperOperator>>) {
    let mut dense = drift.dense.clone();
    let mut operators = hyper_coord_drift_operators(drift);
    match correction {
        Some(DriftDerivResult::Dense(matrix)) => match dense.as_mut() {
            Some(total) => *total += matrix,
            None => dense = Some(matrix.clone()),
        },
        Some(DriftDerivResult::Operator(operator)) => operators.push(Arc::clone(operator)),
        None => {}
    }
    (dense, operators)
}
// Combine a hyper-coordinate's drift with an optional derivative correction
// and pack the pieces into a single `DriftDerivResult`.
fn hyper_coord_total_drift_result(
drift: &HyperCoordDrift,
correction: Option<&DriftDerivResult>,
dim_hint: usize,
) -> DriftDerivResult {
let (dense, operators) = hyper_coord_total_drift_parts(drift, correction);
drift_parts_into_result(dense, operators, dim_hint)
}
/// Effective EFS denominator: 2·a_i, additionally rescaled by dp_cgrad/φ
/// when the Gaussian dispersion is profiled.
#[inline]
fn efs_q_eff(a_i: f64, dispersion: &DispersionHandling, dp_cgrad: f64, phi: f64) -> f64 {
    match dispersion {
        DispersionHandling::Fixed { .. } => 2.0 * a_i,
        DispersionHandling::ProfiledGaussian => 2.0 * dp_cgrad * a_i / phi,
    }
}
/// Log-scale EFS update from the effective denominator and the full gradient.
/// Returns `None` when the inputs cannot yield a meaningful step; a
/// non-positive update ratio saturates at the maximum downhill step.
#[inline]
fn efs_log_step_from_grad(q_eff: f64, g_full: f64) -> Option<f64> {
    let usable = q_eff.is_finite() && q_eff > 0.0 && g_full.is_finite();
    if !usable {
        return None;
    }
    let ratio = 1.0 - 2.0 * g_full / q_eff;
    let step = if ratio > 0.0 {
        ratio.ln().clamp(-EFS_MAX_STEP, EFS_MAX_STEP)
    } else {
        -EFS_MAX_STEP
    };
    Some(step)
}
/// (φ, dp_cgrad) pair for EFS: the profiled Gaussian case estimates φ from
/// the floored deviance over the effective residual dof; fixed dispersion
/// returns the given φ with zero gradient factor.
#[inline]
fn efs_profiling(solution: &InnerSolution<'_>) -> (f64, f64) {
    match &solution.dispersion {
        DispersionHandling::Fixed { phi, .. } => (*phi, 0.0),
        DispersionHandling::ProfiledGaussian => {
            let deviance = -2.0 * solution.log_likelihood + solution.penalty_quadratic;
            let (dp_c, dp_cgrad, _) = smooth_floor_dp(deviance);
            let dof = (solution.n_observations as f64 - solution.nullspace_dim).max(DENOM_RIDGE);
            (dp_c / dof, dp_cgrad)
        }
    }
}
/// Cross trace tr(H⁻¹·L·H⁻¹·R) where each side is either a cached dense
/// matrix or an operator; dispatches to the matching `HessianOperator` route.
/// Panics if a side has neither representation cached.
fn trace_hinv_cached_drift_cross(
    hop: &dyn HessianOperator,
    left_dense: Option<&Array2<f64>>,
    left_op: Option<&dyn HyperOperator>,
    right_dense: Option<&Array2<f64>>,
    right_op: Option<&dyn HyperOperator>,
) -> f64 {
    if let (Some(left), Some(right)) = (left_op, right_op) {
        return hop.trace_hinv_operator_cross(left, right);
    }
    if let Some(left) = left_op {
        let right = right_dense.expect("right dense drift should be cached");
        return hop.trace_hinv_matrix_operator_cross(right, left);
    }
    let left = left_dense.expect("left dense drift should be cached");
    if let Some(right) = right_op {
        return hop.trace_hinv_matrix_operator_cross(left, right);
    }
    let right = right_dense.expect("right dense drift should be cached");
    hop.trace_hinv_product_cross(left, right)
}
#[inline]
fn dense_matvec_into(
matrix: &Array2<f64>,
x: ArrayView1<'_, f64>,
mut out: ArrayViewMut1<'_, f64>,
) {
debug_assert_eq!(matrix.ncols(), x.len());
debug_assert_eq!(matrix.nrows(), out.len());
for (row, out_value) in matrix.rows().into_iter().zip(out.iter_mut()) {
*out_value = row.dot(&x);
}
}
#[inline]
fn dense_matvec_scaled_add_into(
matrix: &Array2<f64>,
x: ArrayView1<'_, f64>,
scale: f64,
mut out: ArrayViewMut1<'_, f64>,
) {
debug_assert_eq!(matrix.ncols(), x.len());
debug_assert_eq!(matrix.nrows(), out.len());
if scale == 0.0 {
return;
}
for (row, out_value) in matrix.rows().into_iter().zip(out.iter_mut()) {
*out_value += scale * row.dot(&x);
}
}
// out = Mᵀ·x for a dense matrix: zero the output, then accumulate through the
// scaled-add kernel (which skips rows whose coefficient is exactly zero).
#[inline]
fn dense_transpose_matvec_into(
matrix: &Array2<f64>,
x: ArrayView1<'_, f64>,
mut out: ArrayViewMut1<'_, f64>,
) {
debug_assert_eq!(matrix.nrows(), x.len());
debug_assert_eq!(matrix.ncols(), out.len());
out.fill(0.0);
dense_transpose_matvec_scaled_add_into(matrix, x, 1.0, out);
}
#[inline]
fn dense_transpose_matvec_scaled_add_into(
matrix: &Array2<f64>,
x: ArrayView1<'_, f64>,
scale: f64,
mut out: ArrayViewMut1<'_, f64>,
) {
debug_assert_eq!(matrix.nrows(), x.len());
debug_assert_eq!(matrix.ncols(), out.len());
if scale == 0.0 {
return;
}
for (row, x_value) in matrix.rows().into_iter().zip(x.iter().copied()) {
let row_scale = scale * x_value;
if row_scale == 0.0 {
continue;
}
for (out_value, entry) in out.iter_mut().zip(row.iter().copied()) {
*out_value += row_scale * entry;
}
}
}
/// Bilinear form uᵀ·M·v, accumulated row by row (same summation order as a
/// plain nested loop).
#[inline]
fn dense_bilinear(matrix: &Array2<f64>, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
    debug_assert_eq!(matrix.ncols(), v.len());
    debug_assert_eq!(matrix.nrows(), u.len());
    matrix
        .rows()
        .into_iter()
        .zip(u.iter())
        .map(|(row, &ui)| ui * row.dot(&v))
        .sum()
}
// Convenience wrapper: allocate the result and compute `design · vector`
// through the storage-aware in-place kernel.
fn design_matrix_apply_view(design: &DesignMatrix, vector: ArrayView1<'_, f64>) -> Array1<f64> {
let mut output = Array1::<f64>::zeros(design.nrows());
design_matrix_apply_view_into(design, vector, output.view_mut());
output
}
/// Extract column `col` of the design matrix into `output`.
///
/// Fast paths: direct copy for dense storage and a sparse scatter for CSC
/// storage; otherwise falls back to multiplying by a standard basis vector.
fn design_matrix_column_into(
    design: &DesignMatrix,
    col: usize,
    mut output: ArrayViewMut1<'_, f64>,
) {
    debug_assert!(col < design.ncols());
    debug_assert_eq!(design.nrows(), output.len());
    if let Some(dense) = design.as_dense() {
        output.assign(&dense.column(col));
    } else if let Some(sparse) = design.as_sparse() {
        let matrix = sparse.as_ref();
        let (symbolic, values) = matrix.parts();
        let col_ptr = symbolic.col_ptr();
        let row_idx = symbolic.row_idx();
        output.fill(0.0);
        for nz in col_ptr[col]..col_ptr[col + 1] {
            output[row_idx[nz]] = values[nz];
        }
    } else {
        // Opaque backend: multiplying by e_col recovers the column.
        let mut unit = Array1::<f64>::zeros(design.ncols());
        unit[col] = 1.0;
        output.assign(&design.matrixvectormultiply(&unit));
    }
}
/// output = design · vector, choosing the cheapest route for the backing
/// storage: dense matvec, CSC column scatter, or the opaque multiply.
fn design_matrix_apply_view_into(
    design: &DesignMatrix,
    vector: ArrayView1<'_, f64>,
    mut output: ArrayViewMut1<'_, f64>,
) {
    debug_assert_eq!(design.ncols(), vector.len());
    debug_assert_eq!(design.nrows(), output.len());
    if let Some(dense) = design.as_dense() {
        dense_matvec_into(dense, vector, output);
    } else if let Some(sparse) = design.as_sparse() {
        let matrix = sparse.as_ref();
        let (symbolic, values) = matrix.parts();
        let col_ptr = symbolic.col_ptr();
        let row_idx = symbolic.row_idx();
        output.fill(0.0);
        for (col, &coeff) in vector.iter().enumerate().take(matrix.ncols()) {
            // Zero coefficients contribute nothing; skip their columns.
            if coeff == 0.0 {
                continue;
            }
            for nz in col_ptr[col]..col_ptr[col + 1] {
                output[row_idx[nz]] += values[nz] * coeff;
            }
        }
    } else {
        output.assign(&design.matrixvectormultiply(&vector.to_owned()));
    }
}
/// output = designᵀ · vector, with storage-aware fast paths (dense transpose
/// matvec, CSC column dot-products, or the opaque transpose multiply).
fn design_matrix_transpose_apply_view_into(
    design: &DesignMatrix,
    vector: ArrayView1<'_, f64>,
    mut output: ArrayViewMut1<'_, f64>,
) {
    debug_assert_eq!(design.nrows(), vector.len());
    debug_assert_eq!(design.ncols(), output.len());
    if let Some(dense) = design.as_dense() {
        dense_transpose_matvec_into(dense, vector, output);
    } else if let Some(sparse) = design.as_sparse() {
        let matrix = sparse.as_ref();
        let (symbolic, values) = matrix.parts();
        let col_ptr = symbolic.col_ptr();
        let row_idx = symbolic.row_idx();
        // Each output entry is one column dot-product; no pre-zeroing is
        // needed because every column is written exactly once.
        for col in 0..matrix.ncols() {
            let dot = (col_ptr[col]..col_ptr[col + 1])
                .map(|nz| values[nz] * vector[row_idx[nz]])
                .sum::<f64>();
            output[col] = dot;
        }
    } else {
        output.assign(&design.transpose_vector_multiply(&vector.to_owned()));
    }
}
/// tr(L·R) without forming the product; both matrices must be square and of
/// identical shape (checked in debug builds).
#[inline]
fn trace_matrix_product(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
    debug_assert_eq!(left.nrows(), left.ncols());
    debug_assert_eq!(left.raw_dim(), right.raw_dim());
    let mut acc = 0.0;
    for (i, left_row) in left.rows().into_iter().enumerate() {
        for (j, &l_ij) in left_row.iter().enumerate() {
            acc += l_ij * right[[j, i]];
        }
    }
    acc
}
/// One component of the outer (REML/LAML) gradient.
///
/// Combines the penalty-quadratic term (rescaled by dp_cgrad/φ under
/// profiled Gaussian dispersion), plus half the log|H| trace derivative and
/// minus half the log|S| derivative — each gated by its include flag.
#[inline]
fn outer_gradient_entry(
    a_i: f64,
    trace_logdet_i: f64,
    ld_s_i: f64,
    dispersion: &DispersionHandling,
    dp_cgrad: f64,
    profiled_scale: f64,
    incl_logdet_h: bool,
    incl_logdet_s: bool,
) -> f64 {
    let penalty_term = match dispersion {
        DispersionHandling::Fixed { .. } => a_i,
        DispersionHandling::ProfiledGaussian => dp_cgrad * a_i / profiled_scale,
    };
    let mut entry = penalty_term;
    if incl_logdet_h {
        entry += 0.5 * trace_logdet_i;
    }
    if incl_logdet_s {
        entry -= 0.5 * ld_s_i;
    }
    entry
}
/// One entry of the outer (REML/LAML) Hessian.
///
/// `q` is the penalty/deviance curvature (with extra chain-rule terms when
/// the Gaussian dispersion is profiled), `0.5·(cross + h2)` the log|H|
/// second-derivative trace pair, and `-0.5·pair_ld_s` the log|S| second
/// derivative; the latter two are gated by the include flags.
#[inline]
fn outer_hessian_entry(
    a_i: f64,
    a_j: f64,
    g_i_dot_v_j: f64,
    pair_a: f64,
    cross_trace: f64,
    h2_trace: f64,
    pair_ld_s: f64,
    profiled_phi: f64,
    profiled_nu: f64,
    profiled_dp_cgrad: f64,
    profiled_dp_cgrad2: f64,
    is_profiled: bool,
    incl_logdet_h: bool,
    incl_logdet_s: bool,
) -> f64 {
    let q_raw = pair_a - g_i_dot_v_j;
    let q = if !is_profiled {
        q_raw
    } else {
        // Profiled Gaussian dispersion: chain-rule corrections from φ = D/ν.
        let curvature = profiled_dp_cgrad2 * profiled_nu * profiled_phi
            - profiled_dp_cgrad * profiled_dp_cgrad;
        profiled_dp_cgrad * q_raw / profiled_phi
            + 2.0 * curvature * a_i * a_j / (profiled_nu * profiled_phi * profiled_phi)
    };
    let mut entry = q;
    if incl_logdet_h {
        entry += 0.5 * (cross_trace + h2_trace);
    }
    if incl_logdet_s {
        entry -= 0.5 * pair_ld_s;
    }
    entry
}
pub fn reml_laml_evaluate(
solution: &InnerSolution<'_>,
rho: &[f64],
mode: EvalMode,
prior_cost_gradient: Option<(f64, Array1<f64>, Option<Array2<f64>>)>,
) -> Result<RemlLamlResult, String> {
let cost_phase_start = std::time::Instant::now();
let k = rho.len();
let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
let curvature_lambdas: Vec<f64> = lambdas
.iter()
.copied()
.map(|lambda| rho_curvature_lambda(solution, lambda))
.collect();
let hop = &*solution.hessian_op;
let log_det_h = hop.logdet() + solution.hessian_logdet_correction;
let log_det_s = solution.penalty_logdet.value;
let (cost, profiled_scale, dp_cgrad, _dp_cgrad2) = match &solution.dispersion {
DispersionHandling::ProfiledGaussian => {
let dp_raw = -2.0 * solution.log_likelihood + solution.penalty_quadratic;
let (dp_c, dp_cgrad, dp_cgrad2) = smooth_floor_dp(dp_raw);
let denom = (solution.n_observations as f64 - solution.nullspace_dim).max(DENOM_RIDGE);
let phi = dp_c / denom;
let cost = dp_c / (2.0 * phi)
+ 0.5 * (log_det_h - log_det_s)
+ (denom / 2.0) * (2.0 * std::f64::consts::PI * phi).ln();
(cost, phi, dp_cgrad, dp_cgrad2)
}
DispersionHandling::Fixed {
phi,
include_logdet_h,
include_logdet_s,
} => {
let logdet_pair_h = if *include_logdet_h { log_det_h } else { 0.0 };
let logdet_pair_s = if *include_logdet_s { log_det_s } else { 0.0 };
let cost_logdet_diff = 0.5 * (logdet_pair_h - logdet_pair_s);
let mut cost =
cost_logdet_diff + (-solution.log_likelihood) + 0.5 * solution.penalty_quadratic;
if *include_logdet_h {
cost += solution.tk_correction
- solution
.firth
.as_ref()
.map_or(0.0, ExactJeffreysTerm::value);
}
(cost, *phi, 0.0, 0.0)
}
};
let mut cost = match &prior_cost_gradient {
Some((pc, _, _)) => cost + pc,
None => cost,
};
if let Some(ref barrier_cfg) = solution.barrier_config {
match barrier_cfg.barrier_cost(&solution.beta) {
Ok(bc) => cost += bc,
Err(e) => {
log::warn!("Barrier cost skipped (infeasible): {e}");
}
}
}
if !cost.is_finite() {
return Err(format!(
"REML/LAML cost is non-finite ({cost}); check inner solver convergence"
));
}
if mode == EvalMode::ValueOnly {
return Ok(RemlLamlResult {
cost,
gradient: None,
hessian: crate::solver::outer_strategy::HessianResult::Unavailable,
});
}
log::info!(
"[STAGE] reml_laml cost_only_done k={} ext_dim={} dim={} elapsed={:.3}s",
k,
solution.ext_coords.len(),
hop.dim(),
cost_phase_start.elapsed().as_secs_f64(),
);
let barrier_deriv_holder: Option<BarrierDerivativeProvider<'_>> = if let Some(ref barrier_cfg) =
solution.barrier_config
{
match BarrierDerivativeProvider::new(&*solution.deriv_provider, barrier_cfg, &solution.beta)
{
Ok(bdp) => Some(bdp),
Err(e) => {
log::warn!("BarrierDerivativeProvider skipped (infeasible): {e}");
None
}
}
} else {
None
};
let effective_deriv: &dyn HessianDerivativeProvider = match barrier_deriv_holder {
Some(ref bdp) => bdp,
None => &*solution.deriv_provider,
};
let (incl_logdet_h, incl_logdet_s) = match &solution.dispersion {
DispersionHandling::ProfiledGaussian => (true, true),
DispersionHandling::Fixed {
include_logdet_h,
include_logdet_s,
..
} => (*include_logdet_h, *include_logdet_s),
};
let ext_dim = solution.ext_coords.len();
let mut grad = Array1::zeros(k + ext_dim);
let rho_penalty_a_k_betas: Vec<Array1<f64>> = (0..k)
.into_par_iter()
.map(|idx| penalty_a_k_beta(&solution.penalty_coords[idx], &solution.beta, lambdas[idx]))
.collect();
let rho_curvature_a_k_betas: Vec<Array1<f64>> = (0..k)
.into_par_iter()
.map(|idx| {
penalty_a_k_beta(
&solution.penalty_coords[idx],
&solution.beta,
curvature_lambdas[idx],
)
})
.collect();
let need_family_corrections = effective_deriv.has_corrections();
let rho_v_ks: Option<Vec<Array1<f64>>> = if need_family_corrections {
Some(
rho_curvature_a_k_betas
.par_iter()
.map(|a_k_beta| hop.solve(a_k_beta))
.collect(),
)
} else {
None
};
let ext_v_is: Vec<Array1<f64>> = solution
.ext_coords
.par_iter()
.map(|coord| hop.solve(&coord.g))
.collect();
let coord_corrections: Vec<Option<DriftDerivResult>> = if need_family_corrections {
let rho_vs = rho_v_ks
.as_ref()
.expect("rho mode responses required for Hessian corrections");
let mut correction_vs = Vec::with_capacity(k + ext_dim);
correction_vs.extend(rho_vs.iter().cloned());
correction_vs.extend(ext_v_is.iter().cloned());
let correction_work = solution
.n_observations
.saturating_mul(hop.dim())
.saturating_mul((k + ext_dim).max(1));
let correction_parallel_work_limit = if hop.dim() <= 512 {
1_000_000_000
} else {
64_000_000
};
let parallel_corrections = correction_work <= correction_parallel_work_limit;
if effective_deriv.has_batched_hessian_derivative_corrections() {
log::info!(
"[STAGE] reml_laml coord_corrections mode=batched k={} ext_dim={} n={} dim={} work={}",
k,
ext_dim,
solution.n_observations,
hop.dim(),
correction_work
);
effective_deriv.hessian_derivative_corrections_result(&correction_vs)?
} else if parallel_corrections {
correction_vs
.par_iter()
.map(|v_k| effective_deriv.hessian_derivative_correction_result(v_k))
.collect::<Result<Vec<_>, _>>()?
} else {
log::info!(
"[STAGE] reml_laml coord_corrections mode=serial k={} ext_dim={} n={} dim={} work={}",
k,
ext_dim,
solution.n_observations,
hop.dim(),
correction_work
);
correction_vs
.iter()
.map(|v_k| effective_deriv.hessian_derivative_correction_result(v_k))
.collect::<Result<Vec<_>, _>>()?
}
} else {
(0..(k + ext_dim)).map(|_| None).collect()
};
if coord_corrections.len() != k + ext_dim {
return Err(format!(
"REML/LAML derivative correction count mismatch: got {}, expected {}",
coord_corrections.len(),
k + ext_dim
));
}
let rho_corrections = &coord_corrections[..k];
let ext_corrections = &coord_corrections[k..];
let total_p = hop.dim();
let use_stochastic_traces = can_use_stochastic_logdet_hinv_kernel(hop, total_p, incl_logdet_h)
&& solution.penalty_subspace_trace.is_none();
let stochastic_trace_values: Option<Vec<f64>> = if use_stochastic_traces {
let mut dense_matrices: Vec<Array2<f64>> = Vec::with_capacity(k + ext_dim);
let mut operators: Vec<Arc<dyn HyperOperator>> = Vec::new();
let mut coord_has_operator = Vec::with_capacity(k + ext_dim);
for idx in 0..k {
match penalty_total_drift_result(
&solution.penalty_coords[idx],
curvature_lambdas[idx],
rho_corrections[idx].as_ref(),
) {
DriftDerivResult::Dense(matrix) => {
dense_matrices.push(matrix);
coord_has_operator.push(false);
}
DriftDerivResult::Operator(op) => {
operators.push(op);
coord_has_operator.push(true);
}
}
}
for (ext_idx, coord) in solution.ext_coords.iter().enumerate() {
let correction = ext_corrections[ext_idx].as_ref();
match hyper_coord_total_drift_result(&coord.drift, correction, hop.dim()) {
DriftDerivResult::Dense(matrix) => {
dense_matrices.push(matrix);
coord_has_operator.push(false);
}
DriftDerivResult::Operator(op) => {
operators.push(op);
coord_has_operator.push(true);
}
}
}
let dense_refs: Vec<&Array2<f64>> = dense_matrices.iter().collect();
let generic_ops: Vec<&dyn HyperOperator> = operators.iter().map(|op| op.as_ref()).collect();
let implicit_ops: Vec<&ImplicitHyperOperator> =
operators.iter().filter_map(|op| op.as_implicit()).collect();
let raw_traces = if generic_ops.is_empty() {
stochastic_trace_hinv_products(hop, StochasticTraceTargets::Dense(&dense_refs))
} else if generic_ops.len() == implicit_ops.len() {
stochastic_trace_hinv_products(
hop,
StochasticTraceTargets::Structural {
dense_matrices: &dense_refs,
implicit_ops: &implicit_ops,
},
)
} else {
stochastic_trace_hinv_products(
hop,
StochasticTraceTargets::Mixed {
dense_matrices: &dense_refs,
operators: &generic_ops,
},
)
};
let mut result = Vec::with_capacity(k + ext_dim);
let n_dense_total = coord_has_operator.iter().filter(|&&b| !b).count();
let mut dense_cursor = 0usize;
let mut operator_cursor = n_dense_total;
for &has_operator in &coord_has_operator {
if has_operator {
result.push(raw_traces[operator_cursor]);
operator_cursor += 1;
} else {
result.push(raw_traces[dense_cursor]);
dense_cursor += 1;
}
}
Some(result)
} else {
None
};
let rho_grad_entries: Vec<(usize, f64)> = (0..k)
.into_par_iter()
.map(|idx| {
let coord = &solution.penalty_coords[idx];
let a_k_beta = &rho_penalty_a_k_betas[idx];
let a_i = 0.5 * solution.beta.dot(a_k_beta);
let trace_logdet_i = if !incl_logdet_h {
0.0
} else if let Some(ref stoch_traces) = stochastic_trace_values {
stoch_traces[idx]
} else if let Some(kernel) = solution.penalty_subspace_trace.as_ref() {
let drift = penalty_total_drift_result(
coord,
curvature_lambdas[idx],
rho_corrections[idx].as_ref(),
);
match drift {
DriftDerivResult::Dense(matrix) => kernel.trace_projected_logdet(&matrix),
DriftDerivResult::Operator(op) => kernel.trace_operator(op.as_ref()),
}
} else if coord.is_block_local() && rho_corrections[idx].is_none() {
let (block, start, end) = coord.scaled_block_local(1.0);
hop.trace_logdet_block_local(&block, curvature_lambdas[idx], start, end)
} else {
penalty_total_drift_result(
coord,
curvature_lambdas[idx],
rho_corrections[idx].as_ref(),
)
.trace_logdet(hop)
};
let value = outer_gradient_entry(
a_i,
trace_logdet_i,
solution.penalty_logdet.first[idx],
&solution.dispersion,
dp_cgrad,
profiled_scale,
incl_logdet_h,
incl_logdet_s,
);
log::trace!(
"[RHO-GRAD] idx={} value={:+.6e} a_i={:+.6e} trace_logdet={:+.6e} ld_s_first={:+.6e} incl_h={} incl_s={}",
idx, value, a_i, trace_logdet_i, solution.penalty_logdet.first[idx], incl_logdet_h, incl_logdet_s
);
(idx, value)
})
.collect();
for (idx, value) in rho_grad_entries {
grad[idx] = value;
}
let ext_grad_entries: Result<Vec<(usize, f64)>, String> = (0..ext_dim)
.into_par_iter()
.map(|ext_idx| {
let coord = &solution.ext_coords[ext_idx];
let ext_coord_start = std::time::Instant::now();
let grad_idx = k + ext_idx;
let trace_logdet_i = if !incl_logdet_h {
0.0
} else if let Some(ref stoch_traces) = stochastic_trace_values {
stoch_traces[k + ext_idx]
} else {
let correction = ext_corrections[ext_idx].as_ref();
let drift = hyper_coord_total_drift_result(&coord.drift, correction, hop.dim());
match (&solution.penalty_subspace_trace, drift) {
(Some(kernel), DriftDerivResult::Dense(matrix)) => {
kernel.trace_projected_logdet(&matrix)
}
(Some(kernel), DriftDerivResult::Operator(op)) => {
kernel.trace_operator(op.as_ref())
}
(None, DriftDerivResult::Dense(matrix)) => hop.trace_logdet_h_k(&matrix, None),
(None, DriftDerivResult::Operator(op)) => {
hop.trace_logdet_operator(op.as_ref())
}
}
};
let value = outer_gradient_entry(
coord.a,
trace_logdet_i,
coord.ld_s,
&solution.dispersion,
dp_cgrad,
profiled_scale,
incl_logdet_h,
incl_logdet_s,
);
log::trace!(
"[EXT-GRAD] ext_idx={} value={:+.6e} coord.a={:+.6e} trace_logdet={:+.6e} ld_s={:+.6e} incl_h={} incl_s={}",
ext_idx, value, coord.a, trace_logdet_i, coord.ld_s, incl_logdet_h, incl_logdet_s
);
log::info!(
"[STAGE] reml_laml ext_coord_trace ext_idx={} elapsed={:.3}s",
ext_idx,
ext_coord_start.elapsed().as_secs_f64(),
);
Ok((grad_idx, value))
})
.collect();
for (idx, value) in ext_grad_entries? {
grad[idx] = value;
}
if let Some(tk_grad) = &solution.tk_gradient {
{
let mut sl = grad.slice_mut(ndarray::s![..k]);
sl += tk_grad;
}
}
if let Some((_, ref pg, _)) = prior_cost_gradient {
{
let mut sl = grad.slice_mut(ndarray::s![..k]);
sl += pg;
}
}
if let Some((idx, value)) = grad.iter().enumerate().find(|(_, v)| !v.is_finite()) {
return Err(format!(
"REML/LAML gradient contains non-finite entry at index {idx}: {value}"
));
}
let hessian = if mode == EvalMode::ValueGradientHessian {
if let Some(family_op) = effective_deriv.family_outer_hessian_operator() {
let n_obs = effective_deriv
.scalar_glm_ingredients()
.map(|ing| ing.x.nrows())
.unwrap_or(solution.n_observations);
let p_dim = hop.dim();
let k_outer = k + solution.ext_coords.len();
log::info!(
"[OUTER hessian-route] choice=operator reason=family_op \
n={n_obs} p={p_dim} k={k_outer} \
callback_kernel=false subspace_trace={subspace} \
scale_prefers_operator=irrelevant",
subspace = solution.penalty_subspace_trace.is_some(),
);
if family_op.dim() != k_outer {
return Err(format!(
"family outer Hessian operator dimension mismatch: got {}, expected {}",
family_op.dim(),
k_outer
));
}
let assembly_start = std::time::Instant::now();
let mut hessian = crate::solver::outer_strategy::HessianResult::Operator(family_op);
if let Some((_, _, Some(ref ph))) = prior_cost_gradient {
hessian.add_rho_block_dense(ph)?;
}
log::info!(
"[OUTER hessian-elapsed] choice=operator reason=family_op \
n={n_obs} p={p_dim} k={k_outer} elapsed={:.3}s",
assembly_start.elapsed().as_secs_f64(),
);
return Ok(RemlLamlResult {
cost,
gradient: Some(grad),
hessian,
});
}
let hessian_kernel = effective_deriv.outer_hessian_derivative_kernel();
let n_obs = effective_deriv
.scalar_glm_ingredients()
.map(|ing| ing.x.nrows())
.unwrap_or(solution.n_observations);
let p_dim = hop.dim();
let k_outer = k + solution.ext_coords.len();
let callback_operator_kernel = matches!(
hessian_kernel,
Some(OuterHessianDerivativeKernel::Callback { .. })
);
let large_p = p_dim >= MATRIX_FREE_OUTER_HESSIAN_DIM_THRESHOLD;
let large_n_and_moderate_p = n_obs >= MATRIX_FREE_OUTER_HESSIAN_LARGE_N_THRESHOLD
&& p_dim >= MATRIX_FREE_OUTER_HESSIAN_DIM_AT_LARGE_N;
let large_linear_work =
n_obs.saturating_mul(p_dim) >= MATRIX_FREE_OUTER_HESSIAN_NP_THRESHOLD;
let large_k = k_outer >= MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD;
let scale_prefers_operator = prefer_outer_hessian_operator(n_obs, p_dim, k_outer);
let has_subspace_trace = solution.penalty_subspace_trace.is_some();
let use_operator =
hessian_kernel.is_some() && use_outer_hessian_operator_path(n_obs, p_dim, k_outer);
let route_reason = if hessian_kernel.is_none() {
"kernel_absent"
} else if has_subspace_trace && scale_prefers_operator {
"subspace_projected_operator"
} else if large_k {
"large_k"
} else if large_p {
"large_p"
} else if large_n_and_moderate_p {
"large_n_moderate_p"
} else if large_linear_work {
"large_linear_work"
} else {
"below_crossover"
};
let route_choice = if use_operator { "operator" } else { "dense" };
log::info!(
"[OUTER hessian-route] choice={route_choice} reason={route_reason} \
n={n_obs} p={p_dim} k={k_outer} \
callback_kernel={callback_operator_kernel} subspace_trace={has_subspace_trace} \
scale_prefers_operator={scale_prefers_operator}"
);
let assembly_start = std::time::Instant::now();
let result = if use_operator {
let coord_vs_for_hessian = rho_v_ks.as_ref().map(|rho_vs| {
let mut all = Vec::with_capacity(k + ext_dim);
all.extend(rho_vs.iter().cloned());
all.extend(ext_v_is.iter().cloned());
all
});
match build_outer_hessian_operator(
solution,
&lambdas,
effective_deriv,
hessian_kernel.expect("checked is_some above"),
coord_vs_for_hessian.as_deref(),
Some(&coord_corrections),
) {
Ok(op) => {
let mut hessian =
crate::solver::outer_strategy::HessianResult::Operator(Arc::new(op));
if let Some((_, _, Some(ref ph))) = prior_cost_gradient {
hessian.add_rho_block_dense(ph)?;
}
hessian
}
Err(err) if is_hessian_unavailable(&err) => {
log::warn!("{err}");
crate::solver::outer_strategy::HessianResult::Unavailable
}
Err(err) => return Err(err),
}
} else {
let reml_workspace = RemlDerivativeWorkspace {
curvature_lambdas: &curvature_lambdas,
rho_penalty_a_k_betas: &rho_penalty_a_k_betas,
rho_curvature_a_k_betas: &rho_curvature_a_k_betas,
rho_v_ks: rho_v_ks.as_deref(),
coord_corrections: &coord_corrections,
};
match compute_outer_hessian(
solution,
rho,
&lambdas,
hop,
effective_deriv,
Some(&reml_workspace),
) {
Ok(mut h) => {
if let Some((_, _, Some(ref ph))) = prior_cost_gradient {
let mut sl = h.slice_mut(ndarray::s![..k, ..k]);
sl += ph;
}
crate::solver::outer_strategy::HessianResult::Analytic(h)
}
Err(err) if is_hessian_unavailable(&err) => {
log::warn!("{err}");
crate::solver::outer_strategy::HessianResult::Unavailable
}
Err(err) => return Err(err),
}
};
log::info!(
"[OUTER hessian-elapsed] choice={route_choice} reason={route_reason} \
n={n_obs} p={p_dim} k={k_outer} elapsed={:.3}s",
assembly_start.elapsed().as_secs_f64(),
);
result
} else {
crate::solver::outer_strategy::HessianResult::Unavailable
};
Ok(RemlLamlResult {
cost,
gradient: Some(grad),
hessian,
})
}
/// Prefix used to tag recoverable "Hessian unavailable" errors so callers can
/// tell them apart from hard failures.
const HESSIAN_UNAVAILABLE_PREFIX: &str = "outer Hessian unavailable:";
/// Coefficient dimension at which the matrix-free outer-Hessian path is used.
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_DIM_THRESHOLD: usize = 512;
/// Observation count considered "large n" for routing purposes.
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_LARGE_N_THRESHOLD: usize = 50_000;
/// Minimum coefficient dimension that, combined with large n, triggers the
/// operator path.
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_DIM_AT_LARGE_N: usize = 32;
/// n*p product above which the linear-algebra work is considered large.
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_NP_THRESHOLD: usize = 4_000_000;
/// Outer-coordinate count above which the operator path is preferred.
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD: usize = 32;
/// Returns true when the problem dimensions favor the matrix-free (operator)
/// outer-Hessian representation over dense assembly.
pub(crate) fn prefer_outer_hessian_operator(n: usize, p: usize, k: usize) -> bool {
    p >= MATRIX_FREE_OUTER_HESSIAN_DIM_THRESHOLD
        || (n >= MATRIX_FREE_OUTER_HESSIAN_LARGE_N_THRESHOLD
            && p >= MATRIX_FREE_OUTER_HESSIAN_DIM_AT_LARGE_N)
        || n.saturating_mul(p) >= MATRIX_FREE_OUTER_HESSIAN_NP_THRESHOLD
        || k >= MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD
}
/// Routing predicate for the operator path; currently identical to
/// [`prefer_outer_hessian_operator`].
pub(crate) fn use_outer_hessian_operator_path(n: usize, p: usize, k: usize) -> bool {
    prefer_outer_hessian_operator(n, p, k)
}
/// Returns true when `error` is a recoverable "Hessian unavailable" message.
fn is_hessian_unavailable(error: &str) -> bool {
    error.starts_with(HESSIAN_UNAVAILABLE_PREFIX)
}
/// Computes the adjoint vector z_c = H^{-1} Xᵀ (c ∘ h), where `c` is the
/// per-observation array from the GLM ingredients and `h` the leverage
/// diagonal. Used by the third-derivative fast path of the IFT correction.
fn compute_adjoint_z_c(
    ing: &ScalarGlmIngredients<'_>,
    hop: &dyn HessianOperator,
    leverage: &Array1<f64>,
) -> Result<Array1<f64>, String> {
    // Elementwise product c_g * h_g over observations.
    let n = ing.c_array.len();
    let weighted = Array1::from_shape_fn(n, |g| ing.c_array[g] * leverage[g]);
    // Pull back through the design matrix and solve against the Hessian.
    let rhs = ing.x.transpose_vector_multiply(&weighted);
    Ok(hop.solve(&rhs))
}
/// Accumulates the fourth-derivative contribution
/// sum_g d_g * (X v_k)_g * (X v_l)_g * h_g.
///
/// Returns `Ok(None)` when the ingredients carry no fourth-derivative array.
fn compute_fourth_derivative_trace(
    ing: &ScalarGlmIngredients<'_>,
    v_k: &Array1<f64>,
    v_l: &Array1<f64>,
    leverage: &Array1<f64>,
) -> Result<Option<f64>, String> {
    let Some(d_array) = ing.d_array else {
        return Ok(None);
    };
    // Project both direction vectors through the design matrix once.
    let proj = |v: &Array1<f64>| ing.x.matrixvectormultiply(v);
    let (proj_k, proj_l) = (proj(v_k), proj(v_l));
    let mut total = 0.0;
    Zip::from(d_array)
        .and(&proj_k)
        .and(&proj_l)
        .and(leverage)
        .for_each(|&d, &pk, &pl, &h| total += d * pk * pl * h);
    Ok(Some(total))
}
/// Batched form of the fourth-derivative trace: for `t` direction vectors
/// builds the t×t matrix with entries
/// sum_g d_g * (X m_i)_g * (X m_j)_g * h_g via one Xᵀ-style product.
///
/// Returns `Ok(None)` without a fourth-derivative array, and an empty 0×0
/// matrix when no modes are supplied. Errors on shape mismatches.
fn compute_fourth_derivative_trace_matrix(
    ing: &ScalarGlmIngredients<'_>,
    modes: &[&Array1<f64>],
    leverage: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    let Some(d_array) = ing.d_array else {
        return Ok(None);
    };
    let n = ing.c_array.len();
    let t = modes.len();
    if t == 0 {
        return Ok(Some(Array2::zeros((0, 0))));
    }
    if d_array.len() != n || leverage.len() != n {
        return Err(format!(
            "fourth-derivative trace shape mismatch: c={}, d={}, leverage={}",
            n,
            d_array.len(),
            leverage.len()
        ));
    }
    // Column j of `x_modes` holds X * modes[j].
    let mut x_modes = Array2::<f64>::zeros((n, t));
    for (j, mode) in modes.iter().enumerate() {
        let projected = ing.x.matrixvectormultiply(mode);
        if projected.len() != n {
            return Err(format!(
                "fourth-derivative trace Xv length mismatch for mode {j}: got {}, expected {n}",
                projected.len()
            ));
        }
        x_modes.column_mut(j).assign(&projected);
    }
    // Scale each observation row of the copy by d_g * h_g, then contract:
    // result = (X modes)ᵀ diag(d ∘ h) (X modes). Lengths were checked above,
    // so the zipped iteration covers all n rows.
    let mut weighted = x_modes.clone();
    for ((mut row, &d), &h) in weighted
        .rows_mut()
        .into_iter()
        .zip(d_array.iter())
        .zip(leverage.iter())
    {
        let factor = d * h;
        row *= factor;
    }
    Ok(Some(crate::faer_ndarray::fast_atb(&x_modes, &weighted)))
}
/// Trace of the implicit-function-theorem correction for one outer-Hessian
/// pair.
///
/// `rhs` is the assembled right-hand side for the pair; `v_i`/`v_j` are the
/// already-solved direction vectors for the two coordinates. When an adjoint
/// vector is available and no subspace projection is requested, the
/// third-derivative term reduces to `rhs · z_c` plus an (optionally
/// precomputed) fourth-derivative term. Otherwise a full solve `H u = rhs`
/// is performed and the provider's second-derivative correction is traced
/// through either the subspace kernel or the Hessian operator.
///
/// Returns `Ok(0.0)` when the derivative provider declares no corrections.
fn compute_ift_correction_trace(
    hop: &dyn HessianOperator,
    rhs: &Array1<f64>,
    v_i: &Array1<f64>,
    v_j: &Array1<f64>,
    effective_deriv: &dyn HessianDerivativeProvider,
    adjoint_z_c: Option<&Array1<f64>>,
    glm_ingredients: Option<&ScalarGlmIngredients<'_>>,
    leverage: Option<&Array1<f64>>,
    precomputed_fourth_trace: Option<f64>,
    subspace: Option<&PenaltySubspaceTrace>,
) -> Result<f64, String> {
    // No corrections declared: the IFT term vanishes.
    if !effective_deriv.has_corrections() {
        return Ok(0.0);
    }
    // Fast path: adjoint available and no subspace projection.
    if let (Some(z_c), None) = (adjoint_z_c, subspace) {
        let c_trace = rhs.dot(z_c);
        // Fourth-derivative term: prefer the batched precomputed value,
        // fall back to a per-pair computation when ingredients are present.
        let d_trace = if let Some(trace) = precomputed_fourth_trace {
            trace
        } else {
            match (glm_ingredients, leverage) {
                (Some(ing), Some(h_g)) => {
                    compute_fourth_derivative_trace(ing, v_i, v_j, h_g)?.unwrap_or(0.0)
                }
                _ => 0.0,
            }
        };
        Ok(c_trace + d_trace)
    } else {
        // General path: solve against the Hessian, then trace the resulting
        // second-derivative correction (projected when a subspace is given).
        let u = hop.solve(rhs);
        if let Some(correction) =
            effective_deriv.hessian_second_derivative_correction_result(v_i, v_j, &u)?
        {
            if let Some(kernel) = subspace {
                match correction {
                    DriftDerivResult::Dense(matrix) => Ok(kernel.trace_projected_logdet(&matrix)),
                    DriftDerivResult::Operator(op) => Ok(kernel.trace_operator(op.as_ref())),
                }
            } else {
                Ok(correction.trace_logdet(hop))
            }
        } else {
            Ok(0.0)
        }
    }
}
/// Sums the drift-derivative traces contributed by coordinates whose drift
/// depends on beta. For each dependent side with an external index and a
/// drift-derivative callback, evaluates the callback against the *other*
/// side's beta direction and traces the result (projected through the
/// subspace kernel when one is supplied).
fn compute_drift_deriv_traces(
    hop: &dyn HessianOperator,
    b_i_depends: bool,
    b_j_depends: bool,
    ext_i: Option<usize>,
    ext_j: Option<usize>,
    beta_i: &Array1<f64>,
    beta_j: &Array1<f64>,
    fixed_drift_deriv: Option<&FixedDriftDerivFn>,
    subspace: Option<&PenaltySubspaceTrace>,
) -> f64 {
    // Without a callback neither side can contribute anything.
    let Some(drift_fn) = fixed_drift_deriv else {
        return 0.0;
    };
    // Trace a single derivative result, routing through the subspace kernel
    // when projection is requested.
    let trace_of = |result: DriftDerivResult| -> f64 {
        match (subspace, result) {
            (Some(kernel), DriftDerivResult::Dense(matrix)) => {
                kernel.trace_projected_logdet(&matrix)
            }
            (Some(kernel), DriftDerivResult::Operator(op)) => kernel.trace_operator(op.as_ref()),
            (None, DriftDerivResult::Dense(matrix)) => hop.trace_logdet_gradient(&matrix),
            (None, DriftDerivResult::Operator(op)) => hop.trace_logdet_operator(op.as_ref()),
        }
    };
    let mut total = 0.0;
    if b_i_depends {
        if let Some(ei) = ext_i {
            if let Some(result) = drift_fn(ei, beta_j) {
                total += trace_of(result);
            }
        }
    }
    if b_j_depends {
        if let Some(ej) = ext_j {
            if let Some(result) = drift_fn(ej, beta_i) {
                total += trace_of(result);
            }
        }
    }
    total
}
/// Traces one second-derivative "B" term, dispatching on whether the term is
/// an implicit operator or a dense matrix, and on whether a subspace kernel
/// should project the trace. An empty dense matrix contributes zero.
fn compute_base_h2_trace(
    hop: &dyn HessianOperator,
    b_mat: &Array2<f64>,
    b_operator: Option<&dyn HyperOperator>,
    subspace: Option<&PenaltySubspaceTrace>,
) -> f64 {
    match (subspace, b_operator) {
        (Some(kernel), Some(op)) => kernel.trace_operator(op),
        (Some(kernel), None) if b_mat.nrows() > 0 => kernel.trace_projected_logdet(b_mat),
        (None, Some(op)) => hop.trace_logdet_operator(op),
        (None, None) if b_mat.nrows() > 0 => hop.trace_logdet_gradient(b_mat),
        _ => 0.0,
    }
}
/// Traces all "B" terms for a set of coordinate pairs.
///
/// When the backend prefers stochastic trace estimation (and no subspace
/// projection is requested), the dense and operator terms are batched into a
/// single estimator call; otherwise each pair is traced individually via
/// [`compute_base_h2_trace`].
fn compute_base_h2_traces(
    hop: &dyn HessianOperator,
    pairs: &[&HyperCoordPair],
    subspace: Option<&PenaltySubspaceTrace>,
) -> Vec<f64> {
    if pairs.is_empty() {
        return Vec::new();
    }
    let batched_stochastic_ok = subspace.is_none()
        && hop.prefers_stochastic_trace_estimation()
        && hop.logdet_traces_match_hinv_kernel();
    if !batched_stochastic_ok {
        // Exact per-pair path.
        return pairs
            .iter()
            .map(|pair| {
                compute_base_h2_trace(hop, &pair.b_mat, pair.b_operator.as_deref(), subspace)
            })
            .collect();
    }
    // Batched stochastic path: split pairs into dense-matrix and operator
    // inputs, remembering which output slot each one maps back to.
    let mut traces = vec![0.0; pairs.len()];
    let mut dense_refs: Vec<&Array2<f64>> = Vec::new();
    let mut dense_slots: Vec<usize> = Vec::new();
    let mut op_refs: Vec<&dyn HyperOperator> = Vec::new();
    let mut op_slots: Vec<usize> = Vec::new();
    for (idx, pair) in pairs.iter().enumerate() {
        match pair.b_operator.as_deref() {
            Some(op) => {
                op_slots.push(idx);
                op_refs.push(op);
            }
            None if pair.b_mat.nrows() > 0 => {
                dense_slots.push(idx);
                dense_refs.push(&pair.b_mat);
            }
            // Empty dense term: leave the slot at 0.0.
            None => {}
        }
    }
    if !(dense_refs.is_empty() && op_refs.is_empty()) {
        let estimator = StochasticTraceEstimator::with_defaults();
        // The estimator returns dense values first, then operator values.
        let values = estimator.estimate_traces_with_operators(hop, &dense_refs, &op_refs);
        for (local, &slot) in dense_slots.iter().enumerate() {
            traces[slot] = values[local];
        }
        let offset = dense_refs.len();
        for (local, &slot) in op_slots.iter().enumerate() {
            traces[slot] = values[offset + local];
        }
    }
    traces
}
/// Cross trace between one dense derivative matrix and one drift-derivative
/// result, picking the matrix-matrix or matrix-operator backend routine
/// depending on the drift's representation.
fn trace_logdet_hessian_cross_dense_drift(
    hop: &dyn HessianOperator,
    dense: &Array2<f64>,
    drift: &DriftDerivResult,
) -> f64 {
    match drift {
        DriftDerivResult::Operator(operator) => {
            hop.trace_logdet_hessian_cross_matrix_operator(dense, operator.as_ref())
        }
        DriftDerivResult::Dense(matrix) => hop.trace_logdet_hessian_cross(dense, matrix),
    }
}
/// All pairwise cross traces for a dense spectral operator.
///
/// Every derivative (dense penalty matrices first, then external drifts) is
/// rotated into the operator's eigenbasis once; the symmetric trace matrix is
/// then formed from the rotated representations.
fn trace_logdet_hessian_crosses_dense_spectral_drifts(
    dense_hop: &DenseSpectralOperator,
    dense_drifts: &[Array2<f64>],
    ext_drifts: &[DriftDerivResult],
) -> Array2<f64> {
    let rotated: Vec<_> = dense_drifts
        .iter()
        .map(|matrix| dense_hop.rotate_to_eigenbasis(matrix))
        .chain(ext_drifts.iter().map(|drift| match drift {
            DriftDerivResult::Dense(matrix) => dense_hop.rotate_to_eigenbasis(matrix),
            DriftDerivResult::Operator(operator) => {
                dense_hop.projected_operator(&dense_hop.eigenvectors, operator.as_ref())
            }
        }))
        .collect();
    let total = rotated.len();
    let mut traces = Array2::<f64>::zeros((total, total));
    // Fill the upper triangle and mirror; the diagonal write is harmlessly
    // repeated when i == j.
    for i in 0..total {
        for j in i..total {
            let value = dense_hop.trace_logdet_hessian_cross_rotated(&rotated[i], &rotated[j]);
            traces[[i, j]] = value;
            traces[[j, i]] = value;
        }
    }
    traces
}
#[inline]
/// Whether the stochastic log-det/H^{-1} kernel path applies: the problem
/// must be large enough (p > 500), the backend must prefer stochastic trace
/// estimation with a matching kernel, and log|H| must be part of the cost.
/// Backend queries are evaluated in the same order as the original
/// short-circuit chain.
fn can_use_stochastic_logdet_hinv_kernel(
    hop: &dyn HessianOperator,
    total_p: usize,
    incl_logdet_h: bool,
) -> bool {
    if total_p <= 500 {
        return false;
    }
    hop.prefers_stochastic_trace_estimation()
        && hop.logdet_traces_match_hinv_kernel()
        && incl_logdet_h
}
/// Per-coordinate quantities already computed by the gradient pass, lent to
/// `compute_outer_hessian` so it can skip recomputation.
pub(crate) struct RemlDerivativeWorkspace<'a> {
    // Curvature-adjusted lambda per rho coordinate (see rho_curvature_lambda).
    pub curvature_lambdas: &'a [f64],
    // penalty_a_k_beta(coord, beta, lambdas[k]) per rho coordinate.
    pub rho_penalty_a_k_betas: &'a [Array1<f64>],
    // penalty_a_k_beta(coord, beta, curvature_lambdas[k]) per rho coordinate.
    pub rho_curvature_a_k_betas: &'a [Array1<f64>],
    // Cached H^{-1}(A_k beta) solves per rho coordinate, when available.
    pub rho_v_ks: Option<&'a [Array1<f64>]>,
    // Drift-derivative corrections per rho coordinate, when available.
    pub coord_corrections: &'a [Option<DriftDerivResult>],
}
/// Assembles the dense outer Hessian over the rho (penalty) coordinates and
/// the external hyper-coordinates of the REML/LAML objective.
///
/// `workspace`, when present, supplies quantities already computed by the
/// gradient pass (curvature lambdas, A_k·beta products, cached solves, drift
/// corrections) so they are not recomputed here. Returns an error when the
/// penalty second derivatives are missing or when the assembled matrix
/// contains non-finite entries (after logging diagnostics about which inputs
/// were non-finite).
fn compute_outer_hessian(
    solution: &InnerSolution<'_>,
    rho: &[f64],
    lambdas: &[f64],
    hop: &dyn HessianOperator,
    effective_deriv: &dyn HessianDerivativeProvider,
    workspace: Option<&RemlDerivativeWorkspace<'_>>,
) -> Result<Array2<f64>, String> {
    let k = rho.len();
    let ext_dim = solution.ext_coords.len();
    let total = k + ext_dim;
    let mut hess = Array2::zeros((total, total));
    // Curvature lambdas: reuse the workspace values or recompute from scratch.
    let curvature_lambdas_storage: Option<Vec<f64>> = if workspace.is_some() {
        None
    } else {
        Some(
            lambdas
                .iter()
                .copied()
                .map(|lambda| rho_curvature_lambda(solution, lambda))
                .collect(),
        )
    };
    let curvature_lambdas: &[f64] = match workspace {
        Some(ws) => ws.curvature_lambdas,
        None => curvature_lambdas_storage
            .as_deref()
            .expect("curvature_lambdas_storage populated when workspace is None"),
    };
    // Which log-determinant terms participate in the objective.
    let (incl_logdet_h, incl_logdet_s) = match &solution.dispersion {
        DispersionHandling::ProfiledGaussian => (true, true),
        DispersionHandling::Fixed {
            include_logdet_h,
            include_logdet_s,
            ..
        } => (*include_logdet_h, *include_logdet_s),
    };
    let det2 = solution.penalty_logdet.second.as_ref().ok_or_else(|| {
        "Outer Hessian requested but penalty second derivatives not provided".to_string()
    })?;
    // Profiled-dispersion quantities; trivial placeholders in the fixed case.
    let (profiled_phi, profiled_nu, profiled_dp_cgrad, profiled_dp_cgrad2, is_profiled) =
        match &solution.dispersion {
            DispersionHandling::ProfiledGaussian => {
                let dp_raw = -2.0 * solution.log_likelihood + solution.penalty_quadratic;
                let (dp_c, dp_cgrad, dp_cgrad2) = smooth_floor_dp(dp_raw);
                let nu = (solution.n_observations as f64 - solution.nullspace_dim).max(DENOM_RIDGE);
                let phi_hat = dp_c / nu;
                (phi_hat, nu, dp_cgrad, dp_cgrad2, true)
            }
            _ => (1.0, 1.0, 1.0, 0.0, false),
        };
    // A_k·beta products at the penalty lambdas (reused from workspace if set).
    let penalty_a_k_betas_storage: Option<Vec<Array1<f64>>> = if workspace.is_some() {
        None
    } else {
        Some(
            (0..k)
                .map(|idx| {
                    penalty_a_k_beta(&solution.penalty_coords[idx], &solution.beta, lambdas[idx])
                })
                .collect(),
        )
    };
    // A_k·beta products at the curvature lambdas.
    let curvature_a_k_betas_storage: Option<Vec<Array1<f64>>> = if workspace.is_some() {
        None
    } else {
        Some(
            (0..k)
                .map(|idx| {
                    penalty_a_k_beta(
                        &solution.penalty_coords[idx],
                        &solution.beta,
                        curvature_lambdas[idx],
                    )
                })
                .collect(),
        )
    };
    let penalty_a_k_betas: &[Array1<f64>] = match workspace {
        Some(ws) => ws.rho_penalty_a_k_betas,
        None => penalty_a_k_betas_storage.as_deref().expect("storage set"),
    };
    let curvature_a_k_betas: &[Array1<f64>] = match workspace {
        Some(ws) => ws.rho_curvature_a_k_betas,
        None => curvature_a_k_betas_storage.as_deref().expect("storage set"),
    };
    // v_k = H^{-1}(A_k beta): reuse cached solves when the workspace has them.
    let v_ks_storage: Option<Vec<Array1<f64>>> = match workspace.and_then(|ws| ws.rho_v_ks) {
        Some(_) => None,
        None => Some(
            curvature_a_k_betas
                .iter()
                .map(|a_k_beta| hop.solve(a_k_beta))
                .collect(),
        ),
    };
    let v_ks: &[Array1<f64>] = match workspace.and_then(|ws| ws.rho_v_ks) {
        Some(vs) => vs,
        None => v_ks_storage.as_deref().expect("storage set"),
    };
    // a_k = 0.5 * betaᵀ A_k beta per rho coordinate.
    let rho_a_vals: Vec<f64> = (0..k)
        .map(|idx| 0.5 * solution.beta.dot(&penalty_a_k_betas[idx]))
        .collect();
    // Dense A_k matrices, and H_k = A_k plus the (optional) derivative
    // correction. Workspace dense corrections are reused; operator-shaped
    // corrections are re-derived through the provider.
    let mut a_k_matrices: Vec<Array2<f64>> = Vec::with_capacity(k);
    let mut h_k_matrices: Vec<Array2<f64>> = Vec::with_capacity(k);
    for idx in 0..k {
        let mut a_k = solution.penalty_coords[idx].scaled_dense_matrix(curvature_lambdas[idx]);
        a_k_matrices.push(a_k.clone());
        let correction: Option<Array2<f64>> = match workspace {
            Some(ws) => match ws.coord_corrections[idx].as_ref() {
                Some(DriftDerivResult::Dense(matrix)) => Some(matrix.clone()),
                Some(DriftDerivResult::Operator(_)) => {
                    if effective_deriv.has_corrections() {
                        effective_deriv.hessian_derivative_correction(&v_ks[idx])?
                    } else {
                        None
                    }
                }
                None => None,
            },
            None => {
                if effective_deriv.has_corrections() {
                    effective_deriv.hessian_derivative_correction(&v_ks[idx])?
                } else {
                    None
                }
            }
        };
        if let Some(corr) = correction {
            a_k += &corr;
        }
        h_k_matrices.push(a_k);
    }
    // GLM ingredients, leverage diagonal, and the adjoint vector feeding the
    // third-derivative fast path (only needed when log|H| is included).
    let glm_ingredients = effective_deriv.scalar_glm_ingredients();
    let leverage = if incl_logdet_h {
        glm_ingredients
            .as_ref()
            .map(|ing| hop.xt_logdet_kernel_x_diagonal(ing.x))
    } else {
        None
    };
    let adjoint_z_c = if incl_logdet_h {
        match (glm_ingredients.as_ref(), leverage.as_ref()) {
            (Some(ing), Some(h_g)) => Some(compute_adjoint_z_c(ing, hop, h_g)?),
            _ => None,
        }
    } else {
        None
    };
    // Stochastic cross traces are only worthwhile when some external drift is
    // implicit and the backend supports the matching kernel.
    let any_ext_implicit = solution.ext_coords.iter().any(|c| {
        c.drift.operator_ref().map_or(false, |op| {
            c.drift.uses_operator_fast_path() && op.is_implicit()
        })
    });
    let total_p = hop.dim();
    let use_stochastic_cross_traces = any_ext_implicit
        && can_use_stochastic_logdet_hinv_kernel(hop, total_p, incl_logdet_h)
        && !effective_deriv.has_corrections()
        && solution.penalty_subspace_trace.is_none();
    // Per-external-coordinate solves and total drift-derivative results.
    let mut ext_v: Vec<Array1<f64>> = Vec::with_capacity(ext_dim);
    let mut ext_h_drifts: Vec<DriftDerivResult> = Vec::with_capacity(ext_dim);
    for coord in solution.ext_coords.iter() {
        let v_i = hop.solve(&coord.g);
        let correction = if effective_deriv.has_corrections() {
            effective_deriv.hessian_derivative_correction_result(&v_i)?
        } else {
            None
        };
        let h_i = hyper_coord_total_drift_result(&coord.drift, correction.as_ref(), hop.dim());
        ext_v.push(v_i);
        ext_h_drifts.push(h_i);
    }
    // Batched fourth-derivative traces across all modes (rho then ext),
    // feeding the per-pair fast path below.
    let fourth_trace_matrix =
        if incl_logdet_h && solution.penalty_subspace_trace.is_none() && adjoint_z_c.is_some() {
            match (glm_ingredients.as_ref(), leverage.as_ref()) {
                (Some(ing), Some(h_g)) if ing.d_array.is_some() => {
                    let modes = v_ks.iter().chain(ext_v.iter()).collect::<Vec<_>>();
                    compute_fourth_derivative_trace_matrix(ing, &modes, h_g)?
                }
                _ => None,
            }
        } else {
            None
        };
    // Stochastic estimate of all H^{-1} cross traces, splitting coordinates
    // into dense-matrix and operator inputs.
    let stochastic_cross_traces: Option<Array2<f64>> = if use_stochastic_cross_traces {
        let total_coords = k + ext_dim;
        let mut dense_mats: Vec<Array2<f64>> = Vec::new();
        let mut coord_has_operator: Vec<bool> = Vec::with_capacity(total_coords);
        let mut operator_arcs: Vec<Arc<dyn HyperOperator>> = Vec::new();
        for idx in 0..k {
            dense_mats.push(h_k_matrices[idx].clone());
            coord_has_operator.push(false);
        }
        for drift in &ext_h_drifts {
            match drift {
                DriftDerivResult::Dense(matrix) => {
                    dense_mats.push(matrix.clone());
                    coord_has_operator.push(false);
                }
                DriftDerivResult::Operator(operator) => {
                    operator_arcs.push(Arc::clone(operator));
                    coord_has_operator.push(true);
                }
            }
        }
        let generic_ops: Vec<&dyn HyperOperator> =
            operator_arcs.iter().map(|op| op.as_ref()).collect();
        let impl_ops: Vec<&ImplicitHyperOperator> = generic_ops
            .iter()
            .filter_map(|op| op.as_implicit())
            .collect();
        Some(stochastic_trace_hinv_crosses(
            hop,
            &dense_mats,
            &coord_has_operator,
            &generic_ops,
            &impl_ops,
        ))
    } else {
        None
    };
    let subspace = solution.penalty_subspace_trace.as_deref();
    // Project every derivative into the subspace once, if a kernel is given.
    let reduced_h_drifts: Option<Vec<Array2<f64>>> = subspace.map(|kernel| {
        let mut reduced = Vec::with_capacity(k + ext_dim);
        for matrix in &h_k_matrices {
            reduced.push(kernel.reduce(matrix));
        }
        for drift in &ext_h_drifts {
            let reduced_drift = match drift {
                DriftDerivResult::Dense(matrix) => kernel.reduce(matrix),
                DriftDerivResult::Operator(operator) => kernel.reduce_operator(operator.as_ref()),
            };
            reduced.push(reduced_drift);
        }
        reduced
    });
    // Exact cross traces, chosen by backend capability:
    //   subspace kernel (parallel reduced pairs) > dense spectral operator
    //   > generic pairwise dispatch. Skipped entirely when the stochastic
    //   estimates above are in effect or log|H| is excluded.
    let exact_logdet_cross_traces = if incl_logdet_h && stochastic_cross_traces.is_none() {
        if let (Some(kernel), Some(reduced)) = (subspace, reduced_h_drifts.as_ref()) {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let n = reduced.len();
            let pairs: Vec<(usize, usize)> =
                (0..n).flat_map(|i| (i..n).map(move |j| (i, j))).collect();
            let pair_values: Vec<(usize, usize, f64)> = pairs
                .into_par_iter()
                .map(|(i, j)| {
                    let value =
                        -kernel.trace_projected_logdet_cross_reduced(&reduced[i], &reduced[j]);
                    (i, j, value)
                })
                .collect();
            let mut out = Array2::<f64>::zeros((n, n));
            for (i, j, value) in pair_values {
                out[[i, j]] = value;
                if i != j {
                    out[[j, i]] = value;
                }
            }
            Some(out)
        } else if let Some(dense_hop) = hop.as_exact_dense_spectral() {
            Some(trace_logdet_hessian_crosses_dense_spectral_drifts(
                dense_hop,
                &h_k_matrices,
                &ext_h_drifts,
            ))
        } else {
            let total_coords = k + ext_dim;
            let mut out = Array2::<f64>::zeros((total_coords, total_coords));
            for ii in 0..total_coords {
                for jj in ii..total_coords {
                    // Dispatch on whether each side is a rho (dense) or an
                    // external (drift) coordinate.
                    let value = match (ii < k, jj < k) {
                        (true, true) => {
                            hop.trace_logdet_hessian_cross(&h_k_matrices[ii], &h_k_matrices[jj])
                        }
                        (true, false) => trace_logdet_hessian_cross_dense_drift(
                            hop,
                            &h_k_matrices[ii],
                            &ext_h_drifts[jj - k],
                        ),
                        (false, true) => trace_logdet_hessian_cross_dense_drift(
                            hop,
                            &h_k_matrices[jj],
                            &ext_h_drifts[ii - k],
                        ),
                        (false, false) => ext_h_drifts[ii - k]
                            .trace_logdet_hessian_cross(&ext_h_drifts[jj - k], hop),
                    };
                    out[[ii, jj]] = value;
                    if ii != jj {
                        out[[jj, ii]] = value;
                    }
                }
            }
            Some(out)
        }
    } else {
        None
    };
    // --- rho-rho block ---
    for kk in 0..k {
        for ll in kk..k {
            let pair_a = if kk == ll { rho_a_vals[kk] } else { 0.0 };
            // Note: stochastic estimates carry the opposite sign convention.
            let cross_trace = if let Some(ref exact) = exact_logdet_cross_traces {
                exact[[kk, ll]]
            } else if let Some(ref sct) = stochastic_cross_traces {
                -sct[[kk, ll]]
            } else {
                hop.trace_logdet_hessian_cross(&h_k_matrices[kk], &h_k_matrices[ll])
            };
            // Diagonal base trace of A_k, via the cheapest applicable route.
            let base = if kk == ll {
                if let Some(kernel) = subspace {
                    kernel.trace_projected_logdet(&a_k_matrices[kk])
                } else if solution.penalty_coords[kk].is_block_local() {
                    let (block, start, end) = solution.penalty_coords[kk].scaled_block_local(1.0);
                    hop.trace_logdet_block_local(&block, curvature_lambdas[kk], start, end)
                } else {
                    hop.trace_logdet_gradient(&a_k_matrices[kk])
                }
            } else {
                0.0
            };
            // RHS for the IFT correction of this pair.
            let mut rhs = h_k_matrices[ll].dot(&v_ks[kk]);
            rhs += &solution.penalty_coords[kk].scaled_matvec(&v_ks[ll], curvature_lambdas[kk]);
            if kk == ll {
                rhs -= &curvature_a_k_betas[kk];
            }
            let correction = compute_ift_correction_trace(
                hop,
                &rhs,
                &v_ks[kk],
                &v_ks[ll],
                effective_deriv,
                adjoint_z_c.as_ref(),
                glm_ingredients.as_ref(),
                leverage.as_ref(),
                fourth_trace_matrix.as_ref().map(|trace| trace[[kk, ll]]),
                subspace,
            )?;
            let h_kl_trace = base + correction;
            let h_val = outer_hessian_entry(
                rho_a_vals[kk],
                rho_a_vals[ll],
                penalty_a_k_betas[ll].dot(&v_ks[kk]),
                pair_a,
                cross_trace,
                h_kl_trace,
                det2[[kk, ll]],
                profiled_phi,
                profiled_nu,
                profiled_dp_cgrad,
                profiled_dp_cgrad2,
                is_profiled,
                incl_logdet_h,
                incl_logdet_s,
            );
            hess[[kk, ll]] = h_val;
            if kk != ll {
                hess[[ll, kk]] = h_val;
            }
        }
    }
    // --- rho-ext block (only when the pair callback is available) ---
    if let Some(ref rho_ext_fn) = solution.rho_ext_pair_fn {
        for rho_idx in 0..k {
            for ext_idx in 0..ext_dim {
                let pair = rho_ext_fn(rho_idx, ext_idx);
                let a_ext = solution.ext_coords[ext_idx].a;
                let (cross_trace, h2_trace) = if incl_logdet_h {
                    let cross_trace = if let Some(ref exact) = exact_logdet_cross_traces {
                        exact[[rho_idx, k + ext_idx]]
                    } else if let Some(ref sct) = stochastic_cross_traces {
                        -sct[[rho_idx, k + ext_idx]]
                    } else {
                        trace_logdet_hessian_cross_dense_drift(
                            hop,
                            &h_k_matrices[rho_idx],
                            &ext_h_drifts[ext_idx],
                        )
                    };
                    let mut rhs = -&pair.g;
                    rhs += &solution.penalty_coords[rho_idx]
                        .scaled_matvec(&ext_v[ext_idx], curvature_lambdas[rho_idx]);
                    let beta_rho = v_ks[rho_idx].mapv(|value| -value);
                    rhs += &ext_h_drifts[ext_idx].apply(&v_ks[rho_idx]);
                    let base = compute_base_h2_trace(
                        hop,
                        &pair.b_mat,
                        pair.b_operator.as_deref(),
                        subspace,
                    );
                    let beta_ext = ext_v[ext_idx].mapv(|value| -value);
                    // The rho side never depends on beta through the drift,
                    // hence the `false` / `None` for the i-side.
                    let m_terms = compute_drift_deriv_traces(
                        hop,
                        false, solution.ext_coords[ext_idx].b_depends_on_beta,
                        None,
                        Some(ext_idx),
                        &beta_rho,
                        &beta_ext,
                        solution.fixed_drift_deriv.as_ref(),
                        subspace,
                    );
                    let correction = compute_ift_correction_trace(
                        hop,
                        &rhs,
                        &v_ks[rho_idx],
                        &ext_v[ext_idx],
                        effective_deriv,
                        adjoint_z_c.as_ref(),
                        glm_ingredients.as_ref(),
                        leverage.as_ref(),
                        fourth_trace_matrix
                            .as_ref()
                            .map(|trace| trace[[rho_idx, k + ext_idx]]),
                        subspace,
                    )?;
                    (cross_trace, base + m_terms + correction)
                } else {
                    (0.0, 0.0)
                };
                let h_val = outer_hessian_entry(
                    rho_a_vals[rho_idx],
                    a_ext,
                    penalty_a_k_betas[rho_idx].dot(&ext_v[ext_idx]),
                    pair.a,
                    cross_trace,
                    h2_trace,
                    pair.ld_s,
                    profiled_phi,
                    profiled_nu,
                    profiled_dp_cgrad,
                    profiled_dp_cgrad2,
                    is_profiled,
                    incl_logdet_h,
                    incl_logdet_s,
                );
                hess[[rho_idx, k + ext_idx]] = h_val;
                hess[[k + ext_idx, rho_idx]] = h_val;
            }
        }
    }
    // --- ext-ext block (only when the pair callback is available) ---
    if let Some(ref ext_pair_fn) = solution.ext_coord_pair_fn {
        for ii in 0..ext_dim {
            for jj in ii..ext_dim {
                let pair = ext_pair_fn(ii, jj);
                let coord_i = &solution.ext_coords[ii];
                let coord_j = &solution.ext_coords[jj];
                let (cross_trace, h2_trace) = if incl_logdet_h {
                    let cross_trace = if let Some(ref exact) = exact_logdet_cross_traces {
                        exact[[k + ii, k + jj]]
                    } else if let Some(ref sct) = stochastic_cross_traces {
                        -sct[[k + ii, k + jj]]
                    } else {
                        ext_h_drifts[ii].trace_logdet_hessian_cross(&ext_h_drifts[jj], hop)
                    };
                    let mut rhs = -&pair.g;
                    coord_i
                        .drift
                        .scaled_add_apply(ext_v[jj].view(), 1.0, &mut rhs);
                    rhs += &ext_h_drifts[jj].apply(&ext_v[ii]);
                    let base = compute_base_h2_trace(
                        hop,
                        &pair.b_mat,
                        pair.b_operator.as_deref(),
                        subspace,
                    );
                    let beta_i = ext_v[ii].mapv(|value| -value);
                    let beta_j = ext_v[jj].mapv(|value| -value);
                    let m_terms = compute_drift_deriv_traces(
                        hop,
                        coord_i.b_depends_on_beta,
                        coord_j.b_depends_on_beta,
                        Some(ii),
                        Some(jj),
                        &beta_i,
                        &beta_j,
                        solution.fixed_drift_deriv.as_ref(),
                        subspace,
                    );
                    let correction = compute_ift_correction_trace(
                        hop,
                        &rhs,
                        &ext_v[ii],
                        &ext_v[jj],
                        effective_deriv,
                        adjoint_z_c.as_ref(),
                        glm_ingredients.as_ref(),
                        leverage.as_ref(),
                        fourth_trace_matrix
                            .as_ref()
                            .map(|trace| trace[[k + ii, k + jj]]),
                        subspace,
                    )?;
                    let h2 = base + m_terms + correction;
                    // Diagnostics: locate and log the first non-finite input
                    // before the assembled matrix is rejected below.
                    let g_dot_v = coord_i.g.dot(&ext_v[jj]);
                    let pair_g_finite = pair.g.iter().all(|v| v.is_finite());
                    let b_mat_finite = pair.b_mat.iter().all(|v| v.is_finite());
                    let ext_vi_finite = ext_v[ii].iter().all(|v| v.is_finite());
                    let ext_vj_finite = ext_v[jj].iter().all(|v| v.is_finite());
                    let any_non_finite = !cross_trace.is_finite()
                        || !base.is_finite()
                        || !m_terms.is_finite()
                        || !correction.is_finite()
                        || !h2.is_finite()
                        || !pair.a.is_finite()
                        || !pair.ld_s.is_finite()
                        || !g_dot_v.is_finite()
                        || !pair_g_finite
                        || !b_mat_finite;
                    if any_non_finite {
                        let mut first_bad_b_mat = None;
                        if !b_mat_finite {
                            'outer: for r in 0..pair.b_mat.nrows() {
                                for c in 0..pair.b_mat.ncols() {
                                    if !pair.b_mat[[r, c]].is_finite() {
                                        first_bad_b_mat = Some((r, c, pair.b_mat[[r, c]]));
                                        break 'outer;
                                    }
                                }
                            }
                        }
                        let mut first_bad_pair_g = None;
                        if !pair_g_finite {
                            for (idx, value) in pair.g.iter().enumerate() {
                                if !value.is_finite() {
                                    first_bad_pair_g = Some((idx, *value));
                                    break;
                                }
                            }
                        }
                        log::warn!(
                            "[OUTER ext-ext non-finite] ({},{}): cross_trace={} base={} m_terms={} correction={} pair.a={} pair.ld_s={} g.dot(v_jj)={} pair_g_finite={} first_bad_pair_g={:?} b_mat_finite={} first_bad_b_mat={:?} b_operator_present={} b_mat_dim={}x{} ext_v[ii]_finite={} ext_v[jj]_finite={} coord_i.b_depends_on_beta={} coord_j.b_depends_on_beta={}",
                            ii,
                            jj,
                            cross_trace,
                            base,
                            m_terms,
                            correction,
                            pair.a,
                            pair.ld_s,
                            g_dot_v,
                            pair_g_finite,
                            first_bad_pair_g,
                            b_mat_finite,
                            first_bad_b_mat,
                            pair.b_operator.is_some(),
                            pair.b_mat.nrows(),
                            pair.b_mat.ncols(),
                            ext_vi_finite,
                            ext_vj_finite,
                            coord_i.b_depends_on_beta,
                            coord_j.b_depends_on_beta,
                        );
                    }
                    (cross_trace, h2)
                } else {
                    (0.0, 0.0)
                };
                let h_val = outer_hessian_entry(
                    coord_i.a,
                    coord_j.a,
                    coord_i.g.dot(&ext_v[jj]),
                    pair.a,
                    cross_trace,
                    h2_trace,
                    pair.ld_s,
                    profiled_phi,
                    profiled_nu,
                    profiled_dp_cgrad,
                    profiled_dp_cgrad2,
                    is_profiled,
                    incl_logdet_h,
                    incl_logdet_s,
                );
                hess[[k + ii, k + jj]] = h_val;
                if ii != jj {
                    hess[[k + jj, k + ii]] = h_val;
                }
            }
        }
    }
    // Final validation: if any entry is non-finite, log which intermediate
    // quantities were bad and reject the whole matrix.
    if hess.iter().any(|v| !v.is_finite()) {
        let report_finite = |name: &str, value: f64, ii: usize, jj: usize| {
            if !value.is_finite() {
                log::warn!(
                    "[OUTER non-finite] {} at ({}, {}) = {}",
                    name,
                    ii,
                    jj,
                    value,
                );
            }
        };
        for kk in 0..k {
            report_finite("rho_a_vals[kk]", rho_a_vals[kk], kk, kk);
            for entry in penalty_a_k_betas[kk].iter() {
                if !entry.is_finite() {
                    log::warn!(
                        "[OUTER non-finite] penalty_a_k_betas[{}] has non-finite",
                        kk
                    );
                    break;
                }
            }
            for entry in v_ks[kk].iter() {
                if !entry.is_finite() {
                    log::warn!("[OUTER non-finite] v_ks[{}] has non-finite", kk);
                    break;
                }
            }
        }
        if let Some(ref exact) = exact_logdet_cross_traces {
            for ii in 0..exact.nrows() {
                for jj in 0..exact.ncols() {
                    report_finite("exact_logdet_cross_traces", exact[[ii, jj]], ii, jj);
                }
            }
        }
        if let Some(ref sct) = stochastic_cross_traces {
            for ii in 0..sct.nrows() {
                for jj in 0..sct.ncols() {
                    report_finite("stochastic_cross_traces", sct[[ii, jj]], ii, jj);
                }
            }
        }
        if let Some(ref h_g) = leverage {
            for entry in h_g.iter() {
                if !entry.is_finite() {
                    log::warn!("[OUTER non-finite] leverage h^G has non-finite entries");
                    break;
                }
            }
        }
        if let Some(ref z_c) = adjoint_z_c {
            for entry in z_c.iter() {
                if !entry.is_finite() {
                    log::warn!("[OUTER non-finite] adjoint_z_c has non-finite entries");
                    break;
                }
            }
        }
        for ii in 0..total {
            for jj in 0..total {
                report_finite("hess", hess[[ii, jj]], ii, jj);
            }
        }
        return Err(
            "Outer Hessian contains non-finite entries; exact higher-order derivatives are invalid"
                .to_string(),
        );
    }
    Ok(hess)
}
/// A first-order drift derivative stored as an optional dense part plus a
/// list of implicit operator terms; the total action is their sum.
struct StoredFirstDrift {
    // Dense component of the drift, if any.
    dense: Option<Array2<f64>>,
    // Eigenbasis-rotated copy of the dense part, if precomputed.
    // NOTE(review): not read by the methods visible in this file section —
    // confirm usage at call sites.
    dense_rotated: Option<Array2<f64>>,
    // Implicit operator components, applied additively.
    operators: Vec<Arc<dyn HyperOperator>>,
}
impl StoredFirstDrift {
fn from_parts(
dense: Option<Array2<f64>>,
dense_rotated: Option<Array2<f64>>,
operators: Vec<Arc<dyn HyperOperator>>,
) -> Self {
Self {
dense,
dense_rotated,
operators,
}
}
fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
debug_assert_eq!(v.len(), out.len());
if scale == 0.0 {
return;
}
if let Some(matrix) = self.dense.as_ref() {
dense_matvec_scaled_add_into(matrix, v, scale, out.view_mut());
}
if !self.operators.is_empty() {
for op in &self.operators {
op.scaled_add_mul_vec(v, scale, out.view_mut());
}
}
}
fn apply_dot(&self, v: ArrayView1<'_, f64>, test: ArrayView1<'_, f64>) -> f64 {
debug_assert_eq!(v.len(), test.len());
let mut total = 0.0;
if let Some(matrix) = self.dense.as_ref() {
total += dense_bilinear(matrix, v, test);
}
for op in &self.operators {
total += op.bilinear_view(v, test);
}
total
}
}
/// Adapter that exposes a borrowed [`StoredFirstDrift`] as a `HyperOperator`
/// without cloning its dense or operator parts.
struct BorrowedStoredDriftOperator<'a> {
    // The borrowed drift whose dense + operator terms define the action.
    drift: &'a StoredFirstDrift,
    // Dimension reported by `dim()`; the drift itself carries no dimension.
    dim_hint: usize,
}
impl HyperOperator for BorrowedStoredDriftOperator<'_> {
    /// Ambient dimension, as supplied at construction.
    fn dim(&self) -> usize {
        self.dim_hint
    }
    /// Owned-vector entry point; forwards to the view-based variant.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        self.mul_vec_view(v.view())
    }
    /// Allocates a zeroed result and accumulates the drift action into it.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut result = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v, result.view_mut());
        result
    }
    /// Overwrites `out` with `drift · v` (dense part first, then each
    /// matrix-free operator added on top).
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        out.fill(0.0);
        if let Some(dense) = self.drift.dense.as_ref() {
            dense_matvec_into(dense, v, out.view_mut());
        }
        for op in self.drift.operators.iter() {
            op.scaled_add_mul_vec(v, 1.0, out.view_mut());
        }
    }
    /// Accumulates `scale * (drift · v)` into `out`; a zero scale does nothing.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale != 0.0 {
            if let Some(dense) = self.drift.dense.as_ref() {
                dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
            }
            for op in self.drift.operators.iter() {
                op.scaled_add_mul_vec(v, scale, out.view_mut());
            }
        }
    }
    /// Owned-vector bilinear form; forwards to the view-based variant.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        self.bilinear_view(v.view(), u.view())
    }
    /// Bilinear form `u · (drift v)`, delegated to the stored drift.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        self.drift.apply_dot(v, u)
    }
    /// Materializes the full drift as a dense matrix (dense part, if any,
    /// plus every operator's dense form).
    fn to_dense(&self) -> Array2<f64> {
        let mut accum = match self.drift.dense.clone() {
            Some(matrix) => matrix,
            None => Array2::<f64>::zeros((self.dim_hint, self.dim_hint)),
        };
        for op in self.drift.operators.iter() {
            accum += &op.to_dense();
        }
        accum
    }
    /// Implicit exactly when at least one matrix-free operator is present.
    fn is_implicit(&self) -> bool {
        !self.drift.operators.is_empty()
    }
}
/// Lazy weighted sum `Σ_i w_i · Op_i` of hyper-operators. Zero-weight terms
/// are skipped at application time rather than being removed.
pub(crate) struct WeightedHyperOperator {
    // (weight, operator) pairs making up the linear combination.
    pub(crate) terms: Vec<(f64, Arc<dyn HyperOperator>)>,
    // Dimension reported by `dim()` and used for dense materialization.
    pub(crate) dim_hint: usize,
}
impl HyperOperator for WeightedHyperOperator {
    // Ambient dimension, as supplied at construction.
    fn dim(&self) -> usize {
        self.dim_hint
    }
    // Owned-vector entry point; allocates and delegates to `mul_vec_into`.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v.view(), out.view_mut());
        out
    }
    // View-based entry point; allocates and delegates to `mul_vec_into`.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v, out.view_mut());
        out
    }
    // Overwrites `out` with the weighted sum of the terms' actions on `v`.
    // Single nonzero term gets a fast path: apply it directly, then scale in
    // place only when the weight is not exactly 1.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        let mut nonzero_terms = self.terms.iter().filter(|(weight, _)| *weight != 0.0);
        if let Some((weight, op)) = nonzero_terms.next()
            && nonzero_terms.next().is_none()
        {
            op.mul_vec_into(v, out.view_mut());
            if *weight != 1.0 {
                out.mapv_inplace(|value| *weight * value);
            }
            return;
        }
        out.fill(0.0);
        for (weight, op) in &self.terms {
            if *weight != 0.0 {
                op.scaled_add_mul_vec(v, *weight, out.view_mut());
            }
        }
    }
    // Multi-column variant of `mul_vec_into` acting on standard-basis columns
    // starting at `start`. Same single-term fast path as above.
    // NOTE(review): the fallback reuses one `work` buffer across terms — this
    // assumes every term's `mul_basis_columns_into` fully overwrites `work`
    // rather than accumulating into it; confirm against all operator impls.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        let mut nonzero_terms = self.terms.iter().filter(|(weight, _)| *weight != 0.0);
        if let Some((weight, op)) = nonzero_terms.next()
            && nonzero_terms.next().is_none()
        {
            op.mul_basis_columns_into(start, out.view_mut());
            if *weight != 1.0 {
                out.mapv_inplace(|value| *weight * value);
            }
            return;
        }
        out.fill(0.0);
        let mut work = Array2::<f64>::zeros((out.nrows(), out.ncols()));
        for (weight, op) in &self.terms {
            if *weight == 0.0 {
                continue;
            }
            op.mul_basis_columns_into(start, work.view_mut());
            out.scaled_add(*weight, &work);
        }
    }
    // Accumulates `scale * (Σ w_i Op_i) v` into `out`; the per-term scale is
    // folded into the weight so each operator is applied exactly once.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        for (weight, op) in &self.terms {
            let combined = scale * *weight;
            if combined != 0.0 {
                op.scaled_add_mul_vec(v, combined, out.view_mut());
            }
        }
    }
    // Weighted sum of the terms' bilinear forms (owned-vector variant).
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        self.terms
            .iter()
            .filter(|(weight, _)| *weight != 0.0)
            .map(|(weight, op)| weight * op.bilinear(v, u))
            .sum()
    }
    // Weighted sum of the terms' bilinear forms (view variant).
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        self.terms
            .iter()
            .filter(|(weight, _)| *weight != 0.0)
            .map(|(weight, op)| weight * op.bilinear_view(v, u))
            .sum()
    }
    // Weighted sum of each term's projected-factor trace.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        self.terms
            .iter()
            .filter(|(weight, _)| *weight != 0.0)
            .map(|(weight, op)| weight * op.trace_projected_factor(factor))
            .sum()
    }
    // Cached variant of `trace_projected_factor`; the cache is shared across
    // terms so repeated projections of `factor` are not recomputed.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        self.terms
            .iter()
            .filter(|(weight, _)| *weight != 0.0)
            .map(|(weight, op)| weight * op.trace_projected_factor_cached(factor, cache))
            .sum()
    }
    // Materializes the weighted sum as a dense matrix.
    fn to_dense(&self) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((self.dim_hint, self.dim_hint));
        for (weight, op) in &self.terms {
            if *weight != 0.0 {
                out.scaled_add(*weight, &op.to_dense());
            }
        }
        out
    }
    // Implicit if any term is implicit (even a zero-weight one).
    fn is_implicit(&self) -> bool {
        self.terms.iter().any(|(_, op)| op.is_implicit())
    }
}
/// Per-coordinate data for one outer-optimization direction (either a penalty
/// log-smoothing parameter or an external hyperparameter).
struct OuterHessianCoord {
    // Scalar quadratic term for this coordinate (0.5·βᵀA_kβ for penalties,
    // the ext coordinate's own `a` otherwise — see build_outer_hessian_operator).
    a: f64,
    // Gradient-side vector for this coordinate (curvature A_kβ for penalties).
    g: Array1<f64>,
    // Mode response v = H⁻¹g for this coordinate.
    v: Array1<f64>,
    // Drift including higher-order correction terms.
    total_drift: StoredFirstDrift,
    // Drift without correction terms.
    base_drift: StoredFirstDrift,
    // Position within the external coordinates, or None for penalty coords.
    ext_index: Option<usize>,
    // Whether the drift's B-term depends on β (affects trace computations).
    b_depends_on_beta: bool,
}
impl OuterHessianCoord {
    /// True when this coordinate is an external hyperparameter rather than a
    /// penalty log-smoothing parameter.
    fn is_ext(&self) -> bool {
        self.ext_index.is_some()
    }
}
/// Matrix-free outer Hessian over all outer coordinates (penalty and
/// external). All pairwise quantities are precomputed as dense
/// `total × total` tables so `matvec` reduces to row-dot products plus
/// per-coordinate correction traces.
struct UnifiedOuterHessianOperator {
    // Inner Hessian backend used for solves and traces.
    hop: Arc<dyn HessianOperator>,
    // One entry per outer coordinate (penalties first, then ext coords).
    coords: Vec<OuterHessianCoord>,
    // Pairwise scalar `a` terms.
    pair_a: Array2<f64>,
    // Pairwise second derivatives of log|S| (penalty log-determinant).
    pair_ld_s: Array2<f64>,
    // Pairwise g·v table.
    g_dot_v: Array2<f64>,
    // Pairwise gradient vectors where available (diagonal and mixed pairs).
    pair_g: Vec<Vec<Option<Array1<f64>>>>,
    // Pairwise base log|H| second-derivative traces.
    base_h2: Array2<f64>,
    // Pairwise drift-derivative traces.
    m_pair_trace: Array2<f64>,
    // Pairwise cross traces of the log|H| Hessian; None when log|H| excluded.
    cross_trace: Option<Array2<f64>>,
    // Profiled-dispersion scalars (phi-hat, effective dof nu, and the first
    // and second derivatives of the smoothed deviance floor).
    profiled_phi: f64,
    profiled_nu: f64,
    profiled_dp_cgrad: f64,
    profiled_dp_cgrad2: f64,
    // Whether the Gaussian dispersion is profiled out.
    is_profiled: bool,
    // Whether the objective includes log|H| / log|S| terms.
    incl_logdet_h: bool,
    incl_logdet_s: bool,
    // Kernel describing how higher-order likelihood derivatives are supplied.
    kernel: OuterHessianDerivativeKernel,
    // Optional reduced subspace used to cheapen trace computations.
    subspace: Option<Arc<PenaltySubspaceTrace>>,
    // Cached adjoint vector for the scalar-GLM third-derivative trace.
    adjoint_z_c: Option<Array1<f64>>,
    // Cached leverage diagonal h^G for the scalar-GLM kernel.
    leverage: Option<Array1<f64>>,
    // Precomputed fourth-derivative trace table (scalar-GLM with d_array).
    fourth_trace: Option<Array2<f64>>,
    // Signed mode responses handed to the callback kernel's second derivative.
    callback_second_modes: Option<Vec<Array1<f64>>>,
}
impl UnifiedOuterHessianOperator {
    /// Builds the signed combination Σ_j ±α_j v_j used by the correction
    /// traces: external coordinates enter with a negated sign relative to
    /// penalty coordinates.
    fn signed_mode_combo_for_correction(&self, alpha: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.hop.dim());
        for (j, coord) in self.coords.iter().enumerate() {
            if alpha[j] == 0.0 {
                continue;
            }
            if coord.is_ext() {
                out.scaled_add(-alpha[j], &coord.v);
            } else {
                out.scaled_add(alpha[j], &coord.v);
            }
        }
        out
    }
    /// Dot product of the (row, col) pair right-hand side with `test`,
    /// evaluated without materializing the RHS vector:
    /// total_drift_col(v_row)·test + base_drift_row(v_col)·test − g_pair·test.
    fn pair_rhs_dot(&self, row: usize, col: usize, test: ArrayView1<'_, f64>) -> f64 {
        let row_coord = &self.coords[row];
        let col_coord = &self.coords[col];
        let pair_g_dot = self.pair_g[row][col]
            .as_ref()
            .map(|pair_g| pair_g.dot(&test))
            .unwrap_or(0.0);
        col_coord.total_drift.apply_dot(row_coord.v.view(), test)
            + row_coord.base_drift.apply_dot(col_coord.v.view(), test)
            - pair_g_dot
    }
    /// Accumulates `scale` times the (row, col) pair right-hand side into
    /// `out`; mirrors `pair_rhs_dot` term for term.
    fn scaled_add_pair_rhs(&self, row: usize, col: usize, scale: f64, out: &mut Array1<f64>) {
        if scale == 0.0 {
            return;
        }
        let row_coord = &self.coords[row];
        let col_coord = &self.coords[col];
        col_coord
            .total_drift
            .scaled_add_apply(row_coord.v.view(), scale, out);
        row_coord
            .base_drift
            .scaled_add_apply(col_coord.v.view(), scale, out);
        if let Some(pair_g) = self.pair_g[row][col].as_ref() {
            out.scaled_add(-scale, pair_g);
        }
    }
    /// Forms Σ_j α_j · pair_rhs(idx, j) as a dense vector.
    fn pair_rhs_combo(&self, idx: usize, alpha: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.hop.dim());
        for j in 0..alpha.len() {
            if alpha[j] != 0.0 {
                self.scaled_add_pair_rhs(idx, j, alpha[j], &mut out);
            }
        }
        out
    }
    /// Third/fourth-derivative correction trace for the scalar-GLM kernel at
    /// coordinate `idx` against direction `alpha`. Requires the cached
    /// adjoint vector `z_c` and leverage diagonal `h^G`; errors if the kernel
    /// is not scalar or a cache is missing.
    fn scalar_correction_trace(
        &self,
        idx: usize,
        alpha: &Array1<f64>,
        v_i: &Array1<f64>,
        m_alpha: &Array1<f64>,
    ) -> Result<f64, String> {
        let OuterHessianDerivativeKernel::ScalarGlm {
            c_array,
            d_array,
            x,
        } = &self.kernel
        else {
            return Err("scalar correction requested for non-scalar kernel".to_string());
        };
        let z_c = self.adjoint_z_c.as_ref().ok_or_else(|| {
            "missing adjoint trace cache for scalar outer Hessian operator".to_string()
        })?;
        let ingredients = ScalarGlmIngredients {
            c_array,
            d_array: d_array.as_ref(),
            x,
        };
        let h_g = self.leverage.as_ref().ok_or_else(|| {
            "missing leverage cache for scalar outer Hessian operator".to_string()
        })?;
        // Third-derivative (C) part: pair RHS vectors contracted against the
        // cached adjoint z_c.
        let mut c_trace = 0.0;
        for (j, &alpha_j) in alpha.iter().enumerate() {
            if alpha_j == 0.0 {
                continue;
            }
            c_trace += alpha_j * self.pair_rhs_dot(idx, j, z_c.view());
        }
        // Fourth-derivative (D) part: use the precomputed table when present,
        // otherwise fall back to an on-the-fly trace (None → contributes 0).
        let d_trace = if let Some(trace) = self.fourth_trace.as_ref() {
            let mut combo = 0.0;
            for (j, &alpha_j) in alpha.iter().enumerate() {
                if alpha_j != 0.0 {
                    combo += alpha_j * trace[[idx, j]];
                }
            }
            combo
        } else {
            compute_fourth_derivative_trace(&ingredients, v_i, m_alpha, h_g)?.unwrap_or(0.0)
        };
        Ok(c_trace + d_trace)
    }
    /// Correction trace for the callback kernel: solves H u = rhs, asks the
    /// callbacks for the first/second-derivative terms, and traces their sum
    /// (through the subspace when available).
    /// NOTE(review): if `first` yields a term but `second` yields None the
    /// whole correction is reported as 0, discarding term1 — confirm this
    /// all-or-nothing short-circuit is intended.
    fn callback_correction_trace(
        &self,
        rhs: &Array1<f64>,
        second_v: &Array1<f64>,
        neg_m_alpha: &Array1<f64>,
    ) -> Result<f64, String> {
        let OuterHessianDerivativeKernel::Callback { first, second } = &self.kernel else {
            return Err("callback correction requested for non-callback kernel".to_string());
        };
        let u = self.hop.solve(rhs);
        let Some(term1) = first(&u)? else {
            return Ok(0.0);
        };
        let Some(term2) = second(neg_m_alpha, second_v)? else {
            return Ok(0.0);
        };
        let combined = CompositeHyperOperator {
            dense: None,
            operators: vec![term1.into_operator(), term2.into_operator()],
            dim_hint: self.hop.dim(),
        };
        if let Some(subspace) = self.subspace.as_deref() {
            Ok(subspace.trace_operator(&combined))
        } else {
            Ok(self.hop.trace_logdet_operator(&combined))
        }
    }
}
impl crate::solver::outer_strategy::OuterHessianOperator for UnifiedOuterHessianOperator {
    // Number of outer coordinates (penalties + externals).
    fn dim(&self) -> usize {
        self.coords.len()
    }
    /// Applies the outer Hessian to `alpha`. Each output entry is assembled
    /// from row-dots into the precomputed pairwise tables plus (when log|H|
    /// is included) a per-row higher-order correction trace; rows are
    /// evaluated in parallel with rayon.
    fn matvec(&self, alpha: &Array1<f64>) -> Result<Array1<f64>, String> {
        if alpha.len() != self.coords.len() {
            return Err(format!(
                "outer Hessian alpha length mismatch: got {}, expected {}",
                alpha.len(),
                self.coords.len()
            ));
        }
        // Scalar combination Σ_i α_i a_i shared by every output row.
        let mut a_alpha = 0.0;
        for (idx, coord) in self.coords.iter().enumerate() {
            if alpha[idx] != 0.0 {
                a_alpha += alpha[idx] * coord.a;
            }
        }
        // Signed mode combination used by the correction terms; the negated
        // copy is only needed by the callback kernel.
        let correction_m_alpha = self.signed_mode_combo_for_correction(alpha);
        let callback_neg_m_alpha =
            matches!(self.kernel, OuterHessianDerivativeKernel::Callback { .. })
                .then(|| -&correction_m_alpha);
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        let values: Result<Vec<f64>, String> = (0..self.coords.len())
            .into_par_iter()
            .map(|idx| {
                let coord = &self.coords[idx];
                // Row-dot products against the precomputed pairwise tables.
                let pair_a = self.pair_a.row(idx).dot(alpha);
                let pair_ld_s = self.pair_ld_s.row(idx).dot(alpha);
                let g_dot_v_alpha = self.g_dot_v.row(idx).dot(alpha);
                let base_h2 = self.base_h2.row(idx).dot(alpha);
                let m_terms = self.m_pair_trace.row(idx).dot(alpha);
                let cross_trace = match self.cross_trace.as_ref() {
                    Some(ct) => ct.row(idx).dot(alpha),
                    None => 0.0,
                };
                // Higher-order correction: kernel-specific, only when the
                // objective includes log|H|.
                let correction = if self.incl_logdet_h {
                    match &self.kernel {
                        OuterHessianDerivativeKernel::Gaussian => 0.0,
                        OuterHessianDerivativeKernel::ScalarGlm { .. } => {
                            self.scalar_correction_trace(idx, alpha, &coord.v, &correction_m_alpha)?
                        }
                        OuterHessianDerivativeKernel::Callback { .. } => {
                            let second_v = &self
                                .callback_second_modes
                                .as_ref()
                                .expect("callback second modes")[idx];
                            let rhs = self.pair_rhs_combo(idx, alpha);
                            self.callback_correction_trace(
                                &rhs,
                                second_v,
                                callback_neg_m_alpha
                                    .as_ref()
                                    .expect("callback negated mode"),
                            )?
                        }
                    }
                } else {
                    0.0
                };
                Ok(outer_hessian_entry(
                    coord.a,
                    a_alpha,
                    g_dot_v_alpha,
                    pair_a,
                    cross_trace,
                    base_h2 + m_terms + correction,
                    pair_ld_s,
                    self.profiled_phi,
                    self.profiled_nu,
                    self.profiled_dp_cgrad,
                    self.profiled_dp_cgrad2,
                    self.is_profiled,
                    self.incl_logdet_h,
                    self.incl_logdet_s,
                ))
            })
            .collect();
        Ok(Array1::from_vec(values?))
    }
}
/// Assembles a [`UnifiedOuterHessianOperator`] from an inner solution:
/// computes per-coordinate data (a, g, v, drifts), the pairwise tables
/// (g·v, pair_a, pair_ld_s, base_h2, m_pair_trace, cross_trace), and the
/// kernel-specific caches (leverage, adjoint z_c, fourth-derivative traces).
/// `precomputed_coord_vs` / `precomputed_coord_corrections` let the caller
/// reuse mode responses / drift corrections computed elsewhere.
fn build_outer_hessian_operator(
    solution: &InnerSolution<'_>,
    lambdas: &[f64],
    effective_deriv: &dyn HessianDerivativeProvider,
    kernel: OuterHessianDerivativeKernel,
    precomputed_coord_vs: Option<&[Array1<f64>]>,
    precomputed_coord_corrections: Option<&[Option<DriftDerivResult>]>,
) -> Result<UnifiedOuterHessianOperator, String> {
    let hop = Arc::clone(&solution.hessian_op);
    let k = lambdas.len();
    let ext_dim = solution.ext_coords.len();
    let total = k + ext_dim;
    // Curvature lambdas may differ from the raw lambdas (e.g. floored/smoothed).
    let curvature_lambdas: Vec<f64> = lambdas
        .iter()
        .copied()
        .map(|lambda| rho_curvature_lambda(solution, lambda))
        .collect();
    // Which log-determinant terms participate in the outer objective.
    let (incl_logdet_h, incl_logdet_s) = match &solution.dispersion {
        DispersionHandling::ProfiledGaussian => (true, true),
        DispersionHandling::Fixed {
            include_logdet_h,
            include_logdet_s,
            ..
        } => (*include_logdet_h, *include_logdet_s),
    };
    let det2 = solution.penalty_logdet.second.as_ref().ok_or_else(|| {
        "Outer Hessian requested but penalty second derivatives not provided".to_string()
    })?;
    // Profiled-Gaussian dispersion scalars (identity values otherwise).
    let (profiled_phi, profiled_nu, profiled_dp_cgrad, profiled_dp_cgrad2, is_profiled) =
        match &solution.dispersion {
            DispersionHandling::ProfiledGaussian => {
                let dp_raw = -2.0 * solution.log_likelihood + solution.penalty_quadratic;
                let (dp_c, dp_cgrad, dp_cgrad2) = smooth_floor_dp(dp_raw);
                let nu = (solution.n_observations as f64 - solution.nullspace_dim).max(DENOM_RIDGE);
                let phi_hat = dp_c / nu;
                (phi_hat, nu, dp_cgrad, dp_cgrad2, true)
            }
            _ => (1.0, 1.0, 1.0, 0.0, false),
        };
    // λ_k·A_k·β vectors at the raw and curvature lambdas, in parallel.
    let rho_penalty_a_k_betas: Vec<Array1<f64>> = (0..k)
        .into_par_iter()
        .map(|idx| penalty_a_k_beta(&solution.penalty_coords[idx], &solution.beta, lambdas[idx]))
        .collect();
    let rho_curvature_a_k_betas: Vec<Array1<f64>> = (0..k)
        .into_par_iter()
        .map(|idx| {
            penalty_a_k_beta(
                &solution.penalty_coords[idx],
                &solution.beta,
                curvature_lambdas[idx],
            )
        })
        .collect();
    let subspace = solution.penalty_subspace_trace.as_deref();
    let dispatch_solve = |v: &Array1<f64>| -> Array1<f64> { hop.solve(v) };
    // Mode responses v_i = H⁻¹ g_i: validate the caller's precomputed set, or
    // solve for penalties then ext coordinates.
    let coord_vs_storage;
    let coord_vs: &[Array1<f64>] = if let Some(precomputed) = precomputed_coord_vs {
        if precomputed.len() != total {
            return Err(format!(
                "outer Hessian precomputed mode-response count mismatch: got {}, expected {}",
                precomputed.len(),
                total
            ));
        }
        precomputed
    } else {
        let mut owned: Vec<Array1<f64>> = rho_curvature_a_k_betas
            .par_iter()
            .map(dispatch_solve)
            .collect();
        owned.extend(
            solution
                .ext_coords
                .par_iter()
                .map(|coord| dispatch_solve(&coord.g))
                .collect::<Vec<_>>(),
        );
        coord_vs_storage = owned;
        &coord_vs_storage
    };
    // Drift-derivative corrections: precomputed, batched, per-vector, or none.
    let coord_corrections_storage;
    let coord_corrections: &[Option<DriftDerivResult>] = if let Some(precomputed) =
        precomputed_coord_corrections
    {
        if precomputed.len() != total {
            return Err(format!(
                "outer Hessian precomputed correction count mismatch: got {}, expected {}",
                precomputed.len(),
                total
            ));
        }
        precomputed
    } else if effective_deriv.has_corrections() {
        if effective_deriv.has_batched_hessian_derivative_corrections() {
            log::info!(
                "[STAGE] outer_hessian coord_corrections mode=batched k={} ext_dim={} n={} dim={}",
                k,
                ext_dim,
                solution.n_observations,
                hop.dim()
            );
            coord_corrections_storage =
                effective_deriv.hessian_derivative_corrections_result(coord_vs)?;
        } else {
            coord_corrections_storage = coord_vs
                .par_iter()
                .map(|v_i| effective_deriv.hessian_derivative_correction_result(v_i))
                .collect::<Result<Vec<_>, _>>()?;
        }
        &coord_corrections_storage
    } else {
        coord_corrections_storage = (0..total).map(|_| None).collect::<Vec<_>>();
        &coord_corrections_storage
    };
    // Per-coordinate records: penalty coordinates first…
    let mut coords = Vec::with_capacity(total);
    for idx in 0..k {
        let coord = &solution.penalty_coords[idx];
        let penalty_a_k_beta_vec = rho_penalty_a_k_betas[idx].clone();
        let curvature_a_k_beta = rho_curvature_a_k_betas[idx].clone();
        let v_k = coord_vs[idx].clone();
        let correction = coord_corrections[idx].as_ref();
        // Total drift carries the correction; base drift does not.
        let mut total_dense = None;
        let mut total_operators = Vec::new();
        match penalty_total_drift_result(coord, curvature_lambdas[idx], correction) {
            DriftDerivResult::Dense(matrix) => total_dense = Some(matrix),
            DriftDerivResult::Operator(op) => total_operators.push(op),
        }
        let mut base_dense = None;
        let mut base_operators = Vec::new();
        match penalty_total_drift_result(coord, curvature_lambdas[idx], None) {
            DriftDerivResult::Dense(matrix) => base_dense = Some(matrix),
            DriftDerivResult::Operator(op) => base_operators.push(op),
        }
        // Pre-rotate the dense drift into the eigenbasis when the Hessian
        // backend exposes a dense spectral form.
        let dense_rotated = match (hop.as_dense_spectral(), total_dense.as_ref()) {
            (Some(dense_hop), Some(matrix)) => Some(dense_hop.rotate_to_eigenbasis(matrix)),
            _ => None,
        };
        let a_i = 0.5 * solution.beta.dot(&penalty_a_k_beta_vec);
        coords.push(OuterHessianCoord {
            a: a_i,
            g: curvature_a_k_beta,
            v: v_k,
            total_drift: StoredFirstDrift::from_parts(total_dense, dense_rotated, total_operators),
            base_drift: StoredFirstDrift::from_parts(base_dense, None, base_operators),
            ext_index: None,
            b_depends_on_beta: false,
        });
    }
    // …then external coordinates.
    for (ext_idx, coord) in solution.ext_coords.iter().enumerate() {
        let coord_idx = k + ext_idx;
        let v_i = coord_vs[coord_idx].clone();
        let correction = coord_corrections[coord_idx].as_ref();
        let (total_dense, total_operators) =
            hyper_coord_total_drift_parts(&coord.drift, correction);
        let (base_dense, base_operators) = hyper_coord_total_drift_parts(&coord.drift, None);
        let dense_rotated = match (hop.as_dense_spectral(), total_dense.as_ref()) {
            (Some(dense_hop), Some(matrix)) => Some(dense_hop.rotate_to_eigenbasis(matrix)),
            _ => None,
        };
        coords.push(OuterHessianCoord {
            a: coord.a,
            g: coord.g.clone(),
            v: v_i,
            total_drift: StoredFirstDrift::from_parts(total_dense, dense_rotated, total_operators),
            base_drift: StoredFirstDrift::from_parts(base_dense, None, base_operators),
            ext_index: Some(ext_idx),
            b_depends_on_beta: coord.b_depends_on_beta,
        });
    }
    // Pairwise tables, filled symmetrically.
    let mut pair_a = Array2::<f64>::zeros((total, total));
    let mut pair_ld_s = Array2::<f64>::zeros((total, total));
    let mut g_dot_v = Array2::<f64>::zeros((total, total));
    let mut pair_g = vec![vec![None; total]; total];
    let mut base_h2 = Array2::<f64>::zeros((total, total));
    let mut m_pair_trace = Array2::<f64>::zeros((total, total));
    // g·v table: a penalty coordinate contributes its raw λA_kβ vector dotted
    // with the other coordinate's mode; ext–ext pairs use g directly.
    // NOTE(review): storing the value at both (ii,jj) and (jj,ii) assumes the
    // pairing is symmetric — confirm for mixed penalty/ext pairs.
    for ii in 0..total {
        for jj in ii..total {
            let value = match (coords[ii].ext_index, coords[jj].ext_index) {
                (None, None) => {
                    let rho_j = jj;
                    rho_penalty_a_k_betas[rho_j].dot(&coords[ii].v)
                }
                (None, Some(_)) => {
                    let rho_i = ii;
                    rho_penalty_a_k_betas[rho_i].dot(&coords[jj].v)
                }
                (Some(_), None) => {
                    let rho_j = jj;
                    rho_penalty_a_k_betas[rho_j].dot(&coords[ii].v)
                }
                (Some(_), Some(_)) => coords[ii].g.dot(&coords[jj].v),
            };
            g_dot_v[[ii, jj]] = value;
            g_dot_v[[jj, ii]] = value;
        }
    }
    // Penalty block of the log|S| second derivatives comes straight from det2.
    for ii in 0..k {
        for jj in ii..k {
            pair_ld_s[[ii, jj]] = det2[[ii, jj]];
            if ii != jj {
                pair_ld_s[[jj, ii]] = det2[[ii, jj]];
            }
        }
    }
    // Penalty diagonal: a, g, and the base log|H| trace (via subspace,
    // block-local fast path, or full dense gradient).
    for idx in 0..k {
        pair_a[[idx, idx]] = coords[idx].a;
        pair_g[idx][idx] = Some(coords[idx].g.clone());
        let base = if let Some(kernel) = subspace {
            let a_k = solution.penalty_coords[idx].scaled_dense_matrix(curvature_lambdas[idx]);
            kernel.trace_projected_logdet(&a_k)
        } else if solution.penalty_coords[idx].is_block_local() {
            let (block, start, end) = solution.penalty_coords[idx].scaled_block_local(1.0);
            hop.trace_logdet_block_local(&block, curvature_lambdas[idx], start, end)
        } else {
            let a_k = solution.penalty_coords[idx].scaled_dense_matrix(curvature_lambdas[idx]);
            hop.trace_logdet_gradient(&a_k)
        };
        base_h2[[idx, idx]] = base;
    }
    // Mixed penalty×ext pairs, supplied by a caller callback.
    if let Some(rho_ext_fn) = solution.rho_ext_pair_fn.as_ref() {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        let pairs: Vec<(usize, usize)> = (0..k)
            .flat_map(|rho_idx| (0..ext_dim).map(move |ext_idx| (rho_idx, ext_idx)))
            .collect();
        let entries: Vec<(usize, usize, HyperCoordPair)> = pairs
            .into_par_iter()
            .map(|(rho_idx, ext_idx)| {
                let pair = rho_ext_fn(rho_idx, ext_idx);
                (rho_idx, ext_idx, pair)
            })
            .collect();
        let pair_refs: Vec<&HyperCoordPair> = entries.iter().map(|(_, _, pair)| pair).collect();
        let bases = compute_base_h2_traces(hop.as_ref(), &pair_refs, subspace);
        for ((rho_idx, ext_idx, pair), base) in entries.into_iter().zip(bases.into_iter()) {
            let row = rho_idx;
            let col = k + ext_idx;
            pair_a[[row, col]] = pair.a;
            pair_a[[col, row]] = pair.a;
            pair_ld_s[[row, col]] = pair.ld_s;
            pair_ld_s[[col, row]] = pair.ld_s;
            pair_g[row][col] = Some(pair.g.clone());
            pair_g[col][row] = Some(pair.g);
            base_h2[[row, col]] = base;
            base_h2[[col, row]] = base;
        }
    }
    // Ext×ext pairs (upper triangle, including the diagonal).
    if let Some(ext_pair_fn) = solution.ext_coord_pair_fn.as_ref() {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        let pairs: Vec<(usize, usize)> = (0..ext_dim)
            .flat_map(|ii| (ii..ext_dim).map(move |jj| (ii, jj)))
            .collect();
        let entries: Vec<(usize, usize, HyperCoordPair)> = pairs
            .into_par_iter()
            .map(|(ii, jj)| {
                let pair = ext_pair_fn(ii, jj);
                (ii, jj, pair)
            })
            .collect();
        let pair_refs: Vec<&HyperCoordPair> = entries.iter().map(|(_, _, pair)| pair).collect();
        let bases = compute_base_h2_traces(hop.as_ref(), &pair_refs, subspace);
        for ((ii, jj, pair), base) in entries.into_iter().zip(bases.into_iter()) {
            let row = k + ii;
            let col = k + jj;
            pair_a[[row, col]] = pair.a;
            pair_a[[col, row]] = pair.a;
            pair_ld_s[[row, col]] = pair.ld_s;
            pair_ld_s[[col, row]] = pair.ld_s;
            let g_pair = pair.g.clone();
            pair_g[row][col] = Some(g_pair.clone());
            pair_g[col][row] = Some(g_pair);
            base_h2[[row, col]] = base;
            base_h2[[col, row]] = base;
        }
    }
    // Drift-derivative traces for every pair (upper triangle, in parallel).
    {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        let pairs: Vec<(usize, usize)> = (0..total)
            .flat_map(|ii| (ii..total).map(move |jj| (ii, jj)))
            .collect();
        let entries: Vec<((usize, usize), f64)> = pairs
            .into_par_iter()
            .map(|(ii, jj)| {
                // The traces are evaluated against the negated modes.
                let beta_i = coords[ii].v.mapv(|value| -value);
                let beta_j = coords[jj].v.mapv(|value| -value);
                let trace = compute_drift_deriv_traces(
                    hop.as_ref(),
                    coords[ii].b_depends_on_beta,
                    coords[jj].b_depends_on_beta,
                    coords[ii].ext_index,
                    coords[jj].ext_index,
                    &beta_i,
                    &beta_j,
                    solution.fixed_drift_deriv.as_ref(),
                    subspace,
                );
                ((ii, jj), trace)
            })
            .collect();
        for ((ii, jj), trace) in entries {
            m_pair_trace[[ii, jj]] = trace;
            m_pair_trace[[jj, ii]] = trace;
        }
    }
    // Cross traces of the log|H| Hessian — four strategies, tried in order:
    // projected subspace, stochastic estimation, dense spectral rotation,
    // and a generic dense/operator mixed fallback.
    let cross_trace: Option<Array2<f64>> = if incl_logdet_h {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        let dense_hop_opt = hop.as_dense_spectral();
        if let Some(kernel) = subspace {
            // Strategy 1: reduce every total drift into the subspace once,
            // then take pairwise projected traces.
            let reduced: Vec<Array2<f64>> = coords
                .iter()
                .map(|coord| {
                    let mut out = Array2::<f64>::zeros((
                        kernel.h_proj_inverse.nrows(),
                        kernel.h_proj_inverse.ncols(),
                    ));
                    if let Some(matrix) = coord.total_drift.dense.as_ref() {
                        out += &kernel.reduce(matrix);
                    }
                    for op in &coord.total_drift.operators {
                        out += &kernel.reduce_operator(op.as_ref());
                    }
                    out
                })
                .collect();
            let pairs: Vec<(usize, usize)> = (0..total)
                .flat_map(|ii| (ii..total).map(move |jj| (ii, jj)))
                .collect();
            let pair_values: Vec<((usize, usize), f64)> = pairs
                .into_par_iter()
                .map(|(ii, jj)| {
                    let value =
                        -kernel.trace_projected_logdet_cross_reduced(&reduced[ii], &reduced[jj]);
                    ((ii, jj), value)
                })
                .collect();
            let mut ct = Array2::<f64>::zeros((total, total));
            for ((ii, jj), value) in pair_values {
                if !value.is_finite() {
                    return Err(format!(
                        "outer Hessian operator projected cross_trace[{ii}, {jj}] is non-finite ({value})"
                    ));
                }
                ct[[ii, jj]] = value;
                if ii != jj {
                    ct[[jj, ii]] = value;
                }
            }
            Some(ct)
        } else if hop.prefers_stochastic_trace_estimation() && hop.logdet_traces_match_hinv_kernel()
        {
            // Strategy 2: stochastic second-order trace estimation over the
            // borrowed drift operators (sign flipped afterwards).
            let bundled: Vec<BorrowedStoredDriftOperator<'_>> = coords
                .iter()
                .map(|coord| BorrowedStoredDriftOperator {
                    drift: &coord.total_drift,
                    dim_hint: hop.dim(),
                })
                .collect();
            let op_refs: Vec<&dyn HyperOperator> =
                bundled.iter().map(|op| op as &dyn HyperOperator).collect();
            let estimator = StochasticTraceEstimator::for_outer_hessian(hop.dim(), total);
            let no_dense: [&Array2<f64>; 0] = [];
            let mut ct = estimator.estimate_second_order_traces_with_operators(
                hop.as_ref(),
                &no_dense,
                &op_refs,
            );
            ct.mapv_inplace(|value| -value);
            Some(ct)
        } else if let Some(dense_hop) = dense_hop_opt {
            // Strategy 3: rotate each drift into the eigenbasis (reusing the
            // cached rotation when available) and trace pairwise.
            let rotated: Vec<Array2<f64>> = coords
                .iter()
                .map(|coord| {
                    let mut projected =
                        coord.total_drift.dense_rotated.clone().unwrap_or_else(|| {
                            Array2::<f64>::zeros((dense_hop.n_dim, dense_hop.n_dim))
                        });
                    for op in &coord.total_drift.operators {
                        projected +=
                            &dense_hop.projected_operator(&dense_hop.eigenvectors, op.as_ref());
                    }
                    projected
                })
                .collect();
            let mut ct = Array2::<f64>::zeros((total, total));
            for ii in 0..total {
                for jj in ii..total {
                    let value =
                        dense_hop.trace_logdet_hessian_cross_rotated(&rotated[ii], &rotated[jj]);
                    if !value.is_finite() {
                        return Err(format!(
                            "outer Hessian operator cross_trace[{ii}, {jj}] is non-finite ({value})"
                        ));
                    }
                    ct[[ii, jj]] = value;
                    if ii != jj {
                        ct[[jj, ii]] = value;
                    }
                }
            }
            Some(ct)
        } else {
            // Strategy 4: generic fallback, combining dense×dense,
            // dense×operator, and operator×operator contributions per pair.
            let pairs: Vec<(usize, usize)> = (0..total)
                .flat_map(|ii| (ii..total).map(move |jj| (ii, jj)))
                .collect();
            let pair_values: Vec<((usize, usize), f64)> = pairs
                .into_par_iter()
                .map(|(ii, jj)| {
                    let left = &coords[ii].total_drift;
                    let right = &coords[jj].total_drift;
                    let mut value = 0.0;
                    if let (Some(left_dense), Some(right_dense)) =
                        (left.dense.as_ref(), right.dense.as_ref())
                    {
                        if let (Some(dense_hop), Some(left_rot), Some(right_rot)) = (
                            dense_hop_opt,
                            left.dense_rotated.as_ref(),
                            right.dense_rotated.as_ref(),
                        ) {
                            value +=
                                dense_hop.trace_logdet_hessian_cross_rotated(left_rot, right_rot);
                        } else {
                            value += hop.trace_logdet_hessian_cross(left_dense, right_dense);
                        }
                    }
                    if let Some(left_dense) = left.dense.as_ref() {
                        for op in &right.operators {
                            value -= hop.trace_hinv_matrix_operator_cross(left_dense, op.as_ref());
                        }
                    }
                    if let Some(right_dense) = right.dense.as_ref() {
                        for op in &left.operators {
                            value -= hop.trace_hinv_matrix_operator_cross(right_dense, op.as_ref());
                        }
                    }
                    if !left.operators.is_empty() && !right.operators.is_empty() {
                        let left_bundle = WeightedHyperOperator {
                            terms: left
                                .operators
                                .iter()
                                .map(|op| (1.0, Arc::clone(op)))
                                .collect(),
                            dim_hint: hop.dim(),
                        };
                        let right_bundle = WeightedHyperOperator {
                            terms: right
                                .operators
                                .iter()
                                .map(|op| (1.0, Arc::clone(op)))
                                .collect(),
                            dim_hint: hop.dim(),
                        };
                        value -= hop.trace_hinv_operator_cross(&left_bundle, &right_bundle);
                    }
                    ((ii, jj), value)
                })
                .collect();
            let mut ct = Array2::<f64>::zeros((total, total));
            for ((ii, jj), value) in pair_values {
                if !value.is_finite() {
                    return Err(format!(
                        "outer Hessian operator cross_trace[{ii}, {jj}] is non-finite ({value})"
                    ));
                }
                ct[[ii, jj]] = value;
                if ii != jj {
                    ct[[jj, ii]] = value;
                }
            }
            Some(ct)
        }
    } else {
        None
    };
    // Kernel-specific caches, only needed when log|H| is included.
    let leverage = if incl_logdet_h {
        match &kernel {
            OuterHessianDerivativeKernel::Gaussian => None,
            OuterHessianDerivativeKernel::ScalarGlm { x, .. } => match subspace {
                Some(s) => Some(s.xt_projected_kernel_x_diagonal(x)),
                None => Some(hop.xt_logdet_kernel_x_diagonal(x)),
            },
            OuterHessianDerivativeKernel::Callback { .. } => None,
        }
    } else {
        None
    };
    let adjoint_z_c = if incl_logdet_h {
        match (&kernel, leverage.as_ref()) {
            (
                OuterHessianDerivativeKernel::ScalarGlm {
                    c_array,
                    d_array,
                    x,
                },
                Some(h_g),
            ) => Some(compute_adjoint_z_c(
                &ScalarGlmIngredients {
                    c_array,
                    d_array: d_array.as_ref(),
                    x,
                },
                hop.as_ref(),
                h_g,
            )?),
            _ => None,
        }
    } else {
        None
    };
    // Callback kernels receive the modes with sign flipped for penalty
    // coordinates (ext modes pass through unchanged).
    let callback_second_modes = matches!(kernel, OuterHessianDerivativeKernel::Callback { .. })
        .then(|| {
            coords
                .iter()
                .map(|coord| {
                    if coord.is_ext() {
                        coord.v.clone()
                    } else {
                        -&coord.v
                    }
                })
                .collect::<Vec<_>>()
        });
    // Precompute the fourth-derivative trace table when the scalar kernel has
    // a d_array (otherwise it is evaluated lazily per matvec).
    let fourth_trace = if incl_logdet_h && adjoint_z_c.is_some() {
        match (&kernel, leverage.as_ref()) {
            (
                OuterHessianDerivativeKernel::ScalarGlm {
                    c_array,
                    d_array: Some(d_array),
                    x,
                },
                Some(h_g),
            ) => {
                let modes = coords.iter().map(|coord| &coord.v).collect::<Vec<_>>();
                compute_fourth_derivative_trace_matrix(
                    &ScalarGlmIngredients {
                        c_array,
                        d_array: Some(d_array),
                        x,
                    },
                    &modes,
                    h_g,
                )?
            }
            _ => None,
        }
    } else {
        None
    };
    Ok(UnifiedOuterHessianOperator {
        hop,
        coords,
        pair_a,
        pair_ld_s,
        g_dot_v,
        pair_g,
        base_h2,
        m_pair_trace,
        cross_trace,
        profiled_phi,
        profiled_nu,
        profiled_dp_cgrad,
        profiled_dp_cgrad2,
        is_profiled,
        incl_logdet_h,
        incl_logdet_s,
        kernel,
        subspace: solution.penalty_subspace_trace.clone(),
        adjoint_z_c,
        leverage,
        fourth_trace,
        callback_second_modes,
    })
}
/// Maximum magnitude of a single EFS step in log-smoothing-parameter space.
const EFS_MAX_STEP: f64 = 5.0;
/// Computes one EFS (extended Fellner–Schall style) update step per outer
/// coordinate from the current gradient.
///
/// `rho` holds the penalty log-smoothing parameters; `gradient` covers the
/// penalty coordinates followed by the external ones. Coordinates for which
/// no step can be formed (and non-penalty-like ext coords) stay at 0.
pub fn compute_efs_update(solution: &InnerSolution<'_>, rho: &[f64], gradient: &[f64]) -> Vec<f64> {
    let k = rho.len();
    let ext_dim = solution.ext_coords.len();
    let total = k + ext_dim;
    debug_assert_eq!(
        gradient.len(),
        total,
        "compute_efs_update: gradient length {} != n_rho({k}) + n_ext({ext_dim})",
        gradient.len(),
    );
    let (profiled_scale, dp_cgrad) = efs_profiling(solution);
    let mut steps = vec![0.0; total];
    // Penalty coordinates: quadratic term at the current lambda drives q_eff.
    for (idx, &log_lambda) in rho.iter().enumerate() {
        let coord = &solution.penalty_coords[idx];
        let lambda = log_lambda.exp();
        let a_i = 0.5 * penalty_a_k_quadratic(coord, &solution.beta, lambda);
        let q_eff = efs_q_eff(a_i, &solution.dispersion, dp_cgrad, profiled_scale);
        if let Some(step) = efs_log_step_from_grad(q_eff, gradient[idx]) {
            steps[idx] = step;
        }
    }
    // External coordinates: only penalty-like ones receive an EFS step.
    for (ext_idx, coord) in solution.ext_coords.iter().enumerate() {
        if coord.is_penalty_like {
            let g_idx = k + ext_idx;
            let q_eff = efs_q_eff(coord.a, &solution.dispersion, dp_cgrad, profiled_scale);
            if let Some(step) = efs_log_step_from_grad(q_eff, gradient[g_idx]) {
                steps[g_idx] = step;
            }
        }
    }
    steps
}
// Pseudo-inverse tolerance for the psi Gram matrix: gram values smaller than
// this (in absolute value) are treated as singular and produce no step.
const PSI_GRAM_PINV_TOL: f64 = 1e-8;
// Initial step-size multiplier for psi (non-penalty-like) coordinates.
const PSI_INITIAL_ALPHA: f64 = 1.0;
// Minimum problem sizes before the hybrid EFS update parallelizes with rayon.
const HYBRID_EFS_SCALAR_PAR_THRESHOLD: usize = 8;
const HYBRID_EFS_GRAM_PAIR_PAR_THRESHOLD: usize = 24;
const HYBRID_EFS_PSI_DRIFT_PAR_THRESHOLD: usize = 8;
/// Result of a hybrid EFS update: the step vector over all outer
/// coordinates plus the psi (non-penalty-like ext) subproblem data.
pub struct HybridEfsResult {
    // Step per outer coordinate (penalties first, then ext coordinates).
    pub steps: Vec<f64>,
    // Global indices of the psi coordinates within `steps`/the gradient.
    pub psi_indices: Vec<usize>,
    // Gradient entries restricted to the psi coordinates.
    pub psi_gradient: Vec<f64>,
}
pub fn compute_hybrid_efs_update(
solution: &InnerSolution<'_>,
rho: &[f64],
gradient: &[f64],
) -> HybridEfsResult {
let k = rho.len();
let hop = &*solution.hessian_op;
let ext_dim = solution.ext_coords.len();
let total = k + ext_dim;
let mut steps = vec![0.0; total];
let (profiled_scale, dp_cgrad) = efs_profiling(solution);
debug_assert_eq!(
gradient.len(),
total,
"compute_hybrid_efs_update: gradient length {} != n_rho({k}) + n_ext({ext_dim})",
gradient.len(),
);
let rho_candidates: Vec<(usize, Option<f64>)> =
if k >= HYBRID_EFS_SCALAR_PAR_THRESHOLD && rayon::current_thread_index().is_none() {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
(0..k)
.into_par_iter()
.map(|idx| {
let coord = &solution.penalty_coords[idx];
let lambda = rho[idx].exp();
let a_i = 0.5 * penalty_a_k_quadratic(coord, &solution.beta, lambda);
let q_eff = efs_q_eff(a_i, &solution.dispersion, dp_cgrad, profiled_scale);
(idx, efs_log_step_from_grad(q_eff, gradient[idx]))
})
.collect()
} else {
(0..k)
.map(|idx| {
let coord = &solution.penalty_coords[idx];
let lambda = rho[idx].exp();
let a_i = 0.5 * penalty_a_k_quadratic(coord, &solution.beta, lambda);
let q_eff = efs_q_eff(a_i, &solution.dispersion, dp_cgrad, profiled_scale);
(idx, efs_log_step_from_grad(q_eff, gradient[idx]))
})
.collect()
};
for (idx, candidate) in rho_candidates {
if let Some(step) = candidate {
steps[idx] = step;
}
}
let mut psi_local_indices: Vec<usize> = Vec::new(); let mut psi_global_indices: Vec<usize> = Vec::new(); let mut tau_local_indices: Vec<usize> = Vec::new();
for (ext_idx, coord) in solution.ext_coords.iter().enumerate() {
let g_idx = k + ext_idx;
if coord.is_penalty_like {
tau_local_indices.push(ext_idx);
} else {
psi_local_indices.push(ext_idx);
psi_global_indices.push(g_idx);
}
}
let tau_candidates: Vec<(usize, Option<f64>)> = if tau_local_indices.len()
>= HYBRID_EFS_SCALAR_PAR_THRESHOLD
&& rayon::current_thread_index().is_none()
{
use rayon::iter::{IntoParallelIterator, ParallelIterator};
tau_local_indices
.iter()
.copied()
.collect::<Vec<_>>()
.into_par_iter()
.map(|ext_idx| {
let coord = &solution.ext_coords[ext_idx];
let g_idx = k + ext_idx;
let q_eff = efs_q_eff(coord.a, &solution.dispersion, dp_cgrad, profiled_scale);
(g_idx, efs_log_step_from_grad(q_eff, gradient[g_idx]))
})
.collect()
} else {
tau_local_indices
.iter()
.map(|&ext_idx| {
let coord = &solution.ext_coords[ext_idx];
let g_idx = k + ext_idx;
let q_eff = efs_q_eff(coord.a, &solution.dispersion, dp_cgrad, profiled_scale);
(g_idx, efs_log_step_from_grad(q_eff, gradient[g_idx]))
})
.collect()
};
for (g_idx, candidate) in tau_candidates {
if let Some(step) = candidate {
steps[g_idx] = step;
}
}
let psi_gradient: Vec<f64> = psi_global_indices.iter().map(|&gi| gradient[gi]).collect();
let n_psi = psi_local_indices.len();
if n_psi > 0 {
if n_psi == 1 {
let li = psi_local_indices[0];
let drift = &solution.ext_coords[li].drift;
let op = hyper_coord_drift_operator_arc(drift, hop.dim());
let dense = op.is_none().then(|| drift.materialize());
let gram = if let Some(dense_hop) = hop.as_dense_spectral() {
let projected = if let Some(op) = op.as_ref() {
dense_hop.projected_operator(&dense_hop.w_factor, op.as_ref())
} else {
dense_hop
.projected_matrix(dense.as_ref().expect("dense drift should be cached"))
};
dense_hop.trace_projected_cross(&projected, &projected)
} else {
trace_hinv_cached_drift_cross(
hop,
dense.as_ref(),
op.as_deref(),
dense.as_ref(),
op.as_deref(),
)
};
if gram.abs() >= PSI_GRAM_PINV_TOL.max(1e-30) {
let global_idx = psi_global_indices[0];
let raw_step = -PSI_INITIAL_ALPHA * psi_gradient[0] / gram;
steps[global_idx] = raw_step.clamp(-EFS_MAX_STEP, EFS_MAX_STEP);
}
return HybridEfsResult {
steps,
psi_indices: psi_global_indices,
psi_gradient,
};
}
let total_p = hop.dim();
let any_psi_operator = psi_local_indices.iter().any(|&li| {
let drift = &solution.ext_coords[li].drift;
drift.uses_operator_fast_path()
});
let use_stochastic_psi_gram =
any_psi_operator && total_p > 500 && hop.prefers_stochastic_trace_estimation();
let gram = if use_stochastic_psi_gram {
let mut dense_mats = Vec::new();
let mut coord_has_operator = Vec::with_capacity(n_psi);
let mut operator_arcs: Vec<Arc<dyn HyperOperator>> = Vec::new();
for &li in &psi_local_indices {
let coord = &solution.ext_coords[li];
if let Some(op) = hyper_coord_drift_operator_arc(&coord.drift, hop.dim()) {
coord_has_operator.push(true);
operator_arcs.push(op);
} else {
coord_has_operator.push(false);
dense_mats.push(coord.drift.materialize());
}
}
let generic_ops: Vec<&dyn HyperOperator> =
operator_arcs.iter().map(|op| op.as_ref()).collect();
let impl_ops: Vec<&ImplicitHyperOperator> = generic_ops
.iter()
.filter_map(|op| op.as_implicit())
.collect();
stochastic_trace_hinv_crosses(
hop,
&dense_mats,
&coord_has_operator,
&generic_ops,
&impl_ops,
)
} else {
let mut gram = ndarray::Array2::<f64>::zeros((n_psi, n_psi));
let parallel_psi_drifts = n_psi >= HYBRID_EFS_PSI_DRIFT_PAR_THRESHOLD
&& rayon::current_thread_index().is_none();
let drift_ops: Vec<Option<Arc<dyn HyperOperator>>> = if parallel_psi_drifts {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
(0..n_psi)
.into_par_iter()
.map(|idx| {
let drift = &solution.ext_coords[psi_local_indices[idx]].drift;
hyper_coord_drift_operator_arc(drift, hop.dim())
})
.collect()
} else {
psi_local_indices
.iter()
.map(|&li| {
let drift = &solution.ext_coords[li].drift;
hyper_coord_drift_operator_arc(drift, hop.dim())
})
.collect()
};
let dense_drifts: Vec<Option<Array2<f64>>> = if parallel_psi_drifts {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
(0..n_psi)
.into_par_iter()
.map(|idx| {
let drift = &solution.ext_coords[psi_local_indices[idx]].drift;
drift_ops[idx].is_none().then(|| drift.materialize())
})
.collect()
} else {
psi_local_indices
.iter()
.enumerate()
.map(|(idx, &li)| {
let drift = &solution.ext_coords[li].drift;
drift_ops[idx].is_none().then(|| drift.materialize())
})
.collect()
};
let pair_count = n_psi * (n_psi + 1) / 2;
let parallel_gram_pairs = pair_count >= HYBRID_EFS_GRAM_PAIR_PAR_THRESHOLD
&& rayon::current_thread_index().is_none();
if let Some(dense_hop) = hop.as_dense_spectral() {
let projected_drifts: Vec<Array2<f64>> = if parallel_psi_drifts {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
(0..n_psi)
.into_par_iter()
.map(|idx| {
if let Some(op) = drift_ops[idx].as_ref() {
dense_hop.projected_operator(&dense_hop.w_factor, op.as_ref())
} else {
dense_hop.projected_matrix(
dense_drifts[idx]
.as_ref()
.expect("dense drift should be cached"),
)
}
})
.collect()
} else {
(0..n_psi)
.map(|idx| {
if let Some(op) = drift_ops[idx].as_ref() {
dense_hop.projected_operator(&dense_hop.w_factor, op.as_ref())
} else {
dense_hop.projected_matrix(
dense_drifts[idx]
.as_ref()
.expect("dense drift should be cached"),
)
}
})
.collect()
};
if parallel_gram_pairs {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let pairs: Vec<(usize, usize)> = (0..n_psi)
.flat_map(|d| (d..n_psi).map(move |e| (d, e)))
.collect();
let pair_values: Vec<(usize, usize, f64)> = pairs
.into_par_iter()
.map(|(d, e)| {
let val = dense_hop
.trace_projected_cross(&projected_drifts[d], &projected_drifts[e]);
(d, e, val)
})
.collect();
for (d, e, val) in pair_values {
gram[[d, e]] = val;
gram[[e, d]] = val;
}
} else {
for d in 0..n_psi {
for e in d..n_psi {
let val = dense_hop
.trace_projected_cross(&projected_drifts[d], &projected_drifts[e]);
gram[[d, e]] = val;
gram[[e, d]] = val;
}
}
}
} else if parallel_gram_pairs {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let pairs: Vec<(usize, usize)> = (0..n_psi)
.flat_map(|d| (d..n_psi).map(move |e| (d, e)))
.collect();
let pair_values: Vec<(usize, usize, f64)> = pairs
.into_par_iter()
.map(|(d, e)| {
let val = trace_hinv_cached_drift_cross(
hop,
dense_drifts[d].as_ref(),
drift_ops[d].as_deref(),
dense_drifts[e].as_ref(),
drift_ops[e].as_deref(),
);
(d, e, val)
})
.collect();
for (d, e, val) in pair_values {
gram[[d, e]] = val;
gram[[e, d]] = val;
}
} else {
for d in 0..n_psi {
for e in d..n_psi {
let val = trace_hinv_cached_drift_cross(
hop,
dense_drifts[d].as_ref(),
drift_ops[d].as_deref(),
dense_drifts[e].as_ref(),
drift_ops[e].as_deref(),
);
gram[[d, e]] = val;
gram[[e, d]] = val;
}
}
}
gram
};
let delta_psi = pseudoinverse_times_vec(&gram, &psi_gradient, PSI_GRAM_PINV_TOL);
let alpha = PSI_INITIAL_ALPHA;
for (psi_idx, &global_idx) in psi_global_indices.iter().enumerate() {
let raw_step = -alpha * delta_psi[psi_idx];
steps[global_idx] = raw_step.clamp(-EFS_MAX_STEP, EFS_MAX_STEP);
}
}
HybridEfsResult {
steps,
psi_indices: psi_global_indices,
psi_gradient,
}
}
/// Applies the Moore–Penrose pseudo-inverse of the (symmetric, expected-PSD)
/// Gram matrix to `v`, via the in-file Jacobi eigendecomposition.
///
/// Eigenvalues at or below `tol * λ_max` are treated as null directions and
/// contribute nothing. The 1×1 and empty cases are handled without an
/// eigendecomposition.
fn pseudoinverse_times_vec(
    gram: &ndarray::Array2<f64>,
    v: &[f64],
    tol: f64,
) -> ndarray::Array1<f64> {
    let n = gram.nrows();
    assert_eq!(n, v.len(), "pseudoinverse_times_vec dimension mismatch");
    match n {
        0 => ndarray::Array1::zeros(0),
        1 => {
            // Scalar case: guard against division by a (near-)zero entry.
            let g = gram[[0, 0]];
            if g.abs() < tol.max(1e-30) {
                ndarray::Array1::zeros(1)
            } else {
                ndarray::Array1::from_vec(vec![v[0] / g])
            }
        }
        _ => {
            let (eigenvalues, eigenvectors) = symmetric_eigen(gram);
            // Relative cutoff against the largest eigenvalue (clamped ≥ 0).
            let max_eval = eigenvalues.iter().fold(0.0_f64, |acc, &e| acc.max(e));
            let cutoff = tol * max_eval;
            // qt_v[i] = ⟨u_i, v⟩, the coefficient of v along eigenvector i.
            let qt_v: Vec<f64> = (0..n)
                .map(|col| {
                    eigenvectors
                        .column(col)
                        .iter()
                        .zip(v.iter())
                        .map(|(&q, &vv)| q * vv)
                        .sum()
                })
                .collect();
            // Accumulate Σ (⟨u_i, v⟩ / λ_i) u_i over well-conditioned modes.
            let mut out = ndarray::Array1::zeros(n);
            for (col, (&lambda, &proj)) in eigenvalues.iter().zip(qt_v.iter()).enumerate() {
                if lambda > cutoff {
                    let scale = proj / lambda;
                    for row in 0..n {
                        out[row] += scale * eigenvectors[[row, col]];
                    }
                }
            }
            out
        }
    }
}
/// Cyclic Jacobi eigendecomposition of a symmetric matrix.
///
/// Returns `(eigenvalues, eigenvectors)` where column `i` of the returned
/// matrix is the eigenvector paired with `eigenvalues[i]`. Eigenvalues are
/// NOT sorted. Intended for the small Gram matrices built by the callers in
/// this file; cost is O(n³) per sweep with up to `max_sweeps` sweeps.
fn symmetric_eigen(a: &ndarray::Array2<f64>) -> (Vec<f64>, ndarray::Array2<f64>) {
    let n = a.nrows();
    assert_eq!(n, a.ncols(), "symmetric_eigen requires square matrix");
    // `work` is driven toward diagonal form in place; `v` accumulates the
    // applied rotations and ends up holding the eigenvectors.
    let mut work = a.clone();
    let mut v = ndarray::Array2::<f64>::eye(n);
    let max_sweeps = 100;
    let tol = 1e-15;
    let mut sweep = 0;
    while sweep < max_sweeps {
        // Convergence check: squared Frobenius norm of the strict upper
        // triangle (the matrix is kept symmetric throughout).
        let mut off_diag_sq = 0.0;
        for i in 0..n {
            for j in (i + 1)..n {
                off_diag_sq += work[[i, j]] * work[[i, j]];
            }
        }
        if off_diag_sq < tol * tol {
            break;
        }
        // One cyclic sweep over all (p, q) pivots.
        for p in 0..n {
            for q in (p + 1)..n {
                let apq = work[[p, q]];
                // Already (numerically) annihilated — skip the rotation.
                if apq.abs() < tol * 0.01 {
                    continue;
                }
                let app = work[[p, p]];
                let aqq = work[[q, q]];
                // Classic Jacobi angle: t = sign(tau) / (|tau| + sqrt(1+tau²)).
                let tau = (aqq - app) / (2.0 * apq);
                let t = if tau.abs() > 1e15 {
                    // Rotation angle would underflow to ~0; skip this pivot.
                    continue;
                } else {
                    let sign_tau = if tau >= 0.0 { 1.0 } else { -1.0 };
                    sign_tau / (tau.abs() + (1.0 + tau * tau).sqrt())
                };
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = t * c;
                // Closed-form updates for the pivot 2×2 block; (p,q) is
                // annihilated exactly by construction.
                work[[p, p]] = app - t * apq;
                work[[q, q]] = aqq + t * apq;
                work[[p, q]] = 0.0;
                work[[q, p]] = 0.0;
                // Rotate the remaining rows/columns, preserving symmetry.
                for r in 0..n {
                    if r == p || r == q {
                        continue;
                    }
                    let wrp = work[[r, p]];
                    let wrq = work[[r, q]];
                    work[[r, p]] = c * wrp - s * wrq;
                    work[[p, r]] = work[[r, p]];
                    work[[r, q]] = s * wrp + c * wrq;
                    work[[q, r]] = work[[r, q]];
                }
                // Accumulate the rotation into the eigenvector matrix.
                for r in 0..n {
                    let vrp = v[[r, p]];
                    let vrq = v[[r, q]];
                    v[[r, p]] = c * vrp - s * vrq;
                    v[[r, q]] = s * vrp + c * vrq;
                }
            }
        }
        sweep += 1;
    }
    // Diagonal of the (near-)diagonalized matrix holds the eigenvalues.
    let eigenvalues: Vec<f64> = (0..n).map(|i| work[[i, i]]).collect();
    (eigenvalues, v)
}
/// Diagnostic payload describing an outer Hessian that failed the inertia
/// gate (material negative curvature on the free, non-bound subspace).
#[derive(Debug, Clone)]
pub struct OuterHessianIndefinite {
    /// Most negative eigenvalue found on the free subspace.
    pub min_eigenvalue: f64,
    /// Hyperparameter indices pinned at their bounds when the gate fired.
    pub active_constraints: Vec<usize>,
    /// Hyperparameter vector at the inspected point (empty when unknown).
    pub theta: Vec<f64>,
    /// Outer-gradient norm supplied by the caller (may be NaN).
    pub gradient_norm: f64,
    /// Frobenius norm of the full outer Hessian.
    pub hessian_norm: f64,
    /// Static remediation guidance appended to the rendered error message.
    pub suggested_action: &'static str,
}
impl OuterHessianIndefinite {
    /// Number of hyperparameters captured in the diagnostic snapshot.
    fn theta_dimension(&self) -> usize {
        self.theta.len()
    }
}
/// Failure modes of the corrected-covariance computations in this module.
#[derive(Debug, Clone)]
pub enum CorrectedCovarianceError {
    /// Outer Hessian dimensions disagree with the hyperparameter count.
    ShapeMismatch(String),
    /// A dense symmetric eigendecomposition failed.
    EigendecompositionFailed(String),
    /// Outer Hessian has material negative curvature on the free subspace.
    Indefinite(OuterHessianIndefinite),
}
impl core::fmt::Display for CorrectedCovarianceError {
    // Human-readable rendering; the `Indefinite` arm surfaces the whole
    // diagnostic payload plus the static remediation text.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::ShapeMismatch(msg) => write!(f, "shape mismatch: {msg}"),
            Self::EigendecompositionFailed(msg) => write!(f, "eigendecomposition failed: {msg}"),
            Self::Indefinite(d) => write!(
                f,
                "outer Hessian indefinite on free subspace (min eigenvalue = {:.3e}, \
                ||H||_F = {:.3e}, ||g||_2 = {:.3e}, active = {:?}, theta = {:?}); {}",
                d.min_eigenvalue,
                d.hessian_norm,
                d.gradient_norm,
                d.active_constraints,
                d.theta,
                d.suggested_action,
            ),
        }
    }
}
// Marker impl: the Display and Debug impls above satisfy the contract.
impl std::error::Error for CorrectedCovarianceError {}
/// Result of the full corrected-covariance computation.
#[derive(Debug, Clone)]
pub struct CorrectedCovariance {
    /// Symmetrized corrected covariance V = H⁻¹ + J V_θ Jᵀ.
    pub matrix: Array2<f64>,
    /// Hyperparameter indices detected at their bounds (excluded from V_θ).
    pub active_constraints: Vec<usize>,
    /// Free directions whose outer-Hessian eigenvalue fell below tolerance.
    pub rank_deficient_directions: Vec<usize>,
}
impl CorrectedCovariance {
    /// True when the fit carried structural findings worth surfacing:
    /// bound-active hyperparameters or rank-deficient outer directions.
    fn has_structural_diagnostics(&self) -> bool {
        let no_findings =
            self.active_constraints.is_empty() && self.rank_deficient_directions.is_empty();
        !no_findings
    }
}
/// Remediation guidance embedded into `OuterHessianIndefinite` diagnostics.
const INDEFINITE_SUGGESTED_ACTION: &str = "refit with a tighter outer tolerance, verify the inspected objective is the true \
    REML/LAML cost rather than a surrogate, and audit recent active-set transitions";
/// Indices of hyperparameters sitting (within `tol`) at the ±RHO_BOUND box
/// constraint. Returns an empty list when `theta` is absent or its length
/// disagrees with `q`, since bounds cannot then be attributed to coordinates.
fn detect_active_theta_bounds(theta: Option<&[f64]>, q: usize) -> Vec<usize> {
    let theta = match theta {
        Some(t) if t.len() == q => t,
        _ => return Vec::new(),
    };
    let bound = crate::solver::estimate::RHO_BOUND;
    let tol = 1e-8;
    let mut active = Vec::new();
    for (i, &value) in theta.iter().enumerate() {
        // Within `tol` of either bound counts as pinned.
        if value.abs() >= bound - tol {
            active.push(i);
        }
    }
    active
}
/// Restricts detected active bound indices to the leading `rho_len`
/// coordinates: only the rho block is box-constrained, while extended
/// hyperparameters are free (they still count toward the total length check).
fn active_bound_indices_for_theta(
    theta: Option<&[f64]>,
    rho_len: usize,
    ext_len: usize,
) -> Vec<usize> {
    // Bounds are detected over the full hyperparameter vector (rho + ext)...
    let q = rho_len + ext_len;
    let mut active = detect_active_theta_bounds(theta, q);
    // ...but only rho coordinates can actually sit on a bound.
    // (The redundant `let _ = ext_len;` suppression was removed: `ext_len`
    // is genuinely used above when forming `q`.)
    active.retain(|&i| i < rho_len);
    active
}
/// Inverts the outer Hessian restricted to the "free" coordinates (those not
/// pinned at a bound), after an inertia gate: any eigenvalue materially below
/// zero aborts with an `Indefinite` diagnostic instead of producing a
/// meaningless covariance.
///
/// Returns the pseudo-inverse embedded back into the full `q × q` frame
/// (active rows/columns are left zero) plus the global indices of free
/// directions whose eigenvalue fell below the positive tolerance (treated as
/// rank-deficient and inverted to zero).
fn projected_inverse_with_inertia_gate(
    outer_hessian: &Array2<f64>,
    active: &[usize],
    theta_for_diag: Option<&[f64]>,
    gradient_norm: f64,
) -> Result<(Array2<f64>, Vec<usize>), CorrectedCovarianceError> {
    let q = outer_hessian.nrows();
    // Mark the bound-active coordinates; out-of-range indices are ignored.
    let mut is_active = vec![false; q];
    for &i in active {
        if i < q {
            is_active[i] = true;
        }
    }
    let free: Vec<usize> = (0..q).filter(|i| !is_active[*i]).collect();
    let qf = free.len();
    // Frobenius norm of the full Hessian sets the scale for both tolerances.
    let h_norm = outer_hessian.iter().map(|v| v * v).sum::<f64>().sqrt();
    let mut v_theta_full = Array2::<f64>::zeros((q, q));
    if qf == 0 {
        // Everything is pinned: the covariance contribution is exactly zero.
        return Ok((v_theta_full, Vec::new()));
    }
    // Extract the free-subspace principal submatrix H_ff.
    let mut h_ff = Array2::<f64>::zeros((qf, qf));
    for (a, &ia) in free.iter().enumerate() {
        for (b, &ib) in free.iter().enumerate() {
            h_ff[[a, b]] = outer_hessian[[ia, ib]];
        }
    }
    let (evals, evecs) = h_ff.eigh(faer::Side::Lower).map_err(|e| {
        CorrectedCovarianceError::EigendecompositionFailed(format!("projected outer Hessian: {e}"))
    })?;
    // Scale-relative tolerance for "materially negative" curvature.
    let eps = f64::EPSILON;
    let neg_tol = 8.0 * eps * (q.max(1) as f64) * h_norm.max(1.0);
    let min_eig = evals.iter().copied().fold(f64::INFINITY, f64::min);
    if min_eig < -neg_tol {
        // Inertia gate failed: package the full diagnostic and bail out.
        let diagnostic = OuterHessianIndefinite {
            min_eigenvalue: min_eig,
            active_constraints: active.to_vec(),
            theta: theta_for_diag.map(|t| t.to_vec()).unwrap_or_default(),
            gradient_norm,
            hessian_norm: h_norm,
            suggested_action: INDEFINITE_SUGGESTED_ACTION,
        };
        let _theta_dimension = diagnostic.theta_dimension();
        return Err(CorrectedCovarianceError::Indefinite(diagnostic));
    }
    // Eigenvalues within `pos_tol` of zero are treated as null directions.
    let pos_tol = 8.0 * eps * (q.max(1) as f64) * h_norm.max(1.0);
    let mut v_theta_ff = Array2::<f64>::zeros((qf, qf));
    let mut rank_deficient_free: Vec<usize> = Vec::new();
    // Reconstruct the pseudo-inverse as Σ u_j u_jᵀ / σ_j over kept modes,
    // filling only the upper triangle and mirroring (symmetry).
    for j in 0..qf {
        let sigma = evals[j];
        if sigma.abs() <= pos_tol {
            rank_deficient_free.push(j);
            continue;
        }
        let inv_sigma = 1.0 / sigma;
        let u = evecs.column(j);
        for a in 0..qf {
            let ua = inv_sigma * u[a];
            for b in a..qf {
                let val = ua * u[b];
                v_theta_ff[[a, b]] += val;
                if a != b {
                    v_theta_ff[[b, a]] += val;
                }
            }
        }
    }
    // Scatter the free-subspace inverse back into the full q×q frame.
    for (a, &ia) in free.iter().enumerate() {
        for (b, &ib) in free.iter().enumerate() {
            v_theta_full[[ia, ib]] = v_theta_ff[[a, b]];
        }
    }
    // Map local (free-subspace) eigen indices back to global coordinates.
    let rank_deficient_directions: Vec<usize> =
        rank_deficient_free.into_iter().map(|j| free[j]).collect();
    Ok((v_theta_full, rank_deficient_directions))
}
/// Convenience wrapper around `compute_corrected_covariance_with_constraints`
/// with no bound information: structural findings are logged at debug level
/// and only the corrected covariance matrix is returned.
pub fn compute_corrected_covariance(
    v_ks: &[Array1<f64>],
    ext_v: &[Array1<f64>],
    outer_hessian: &Array2<f64>,
    hop: &dyn HessianOperator,
) -> Result<Array2<f64>, CorrectedCovarianceError> {
    let fit = compute_corrected_covariance_with_constraints(
        v_ks,
        ext_v,
        outer_hessian,
        hop,
        None,
        f64::NAN,
    )?;
    if fit.has_structural_diagnostics() {
        log::debug!(
            "corrected covariance diagnostics: active_constraints={:?} rank_deficient_directions={:?}",
            fit.active_constraints,
            fit.rank_deficient_directions
        );
    }
    Ok(fit.matrix)
}
/// Hyperparameter-uncertainty-corrected covariance
/// V = H⁻¹ + J V_θ Jᵀ, where column k of J holds the negated sensitivity
/// vector −v_k (rho block first, then the extended block) and V_θ is the
/// inertia-gated pseudo-inverse of the outer Hessian on the free subspace.
///
/// * `theta_at_optimum` — used only to detect bound-active rho coordinates;
///   pass `None` to skip constraint detection.
/// * `gradient_norm` — forwarded into the indefinite diagnostic (may be NaN).
///
/// # Errors
/// `ShapeMismatch` when `outer_hessian` is not q×q; `Indefinite` and
/// `EigendecompositionFailed` propagate from the inertia gate.
pub fn compute_corrected_covariance_with_constraints(
    v_ks: &[Array1<f64>],
    ext_v: &[Array1<f64>],
    outer_hessian: &Array2<f64>,
    hop: &dyn HessianOperator,
    theta_at_optimum: Option<&[f64]>,
    gradient_norm: f64,
) -> Result<CorrectedCovariance, CorrectedCovarianceError> {
    let p = hop.dim();
    let q = v_ks.len() + ext_v.len();
    if q == 0 {
        // No hyperparameters: the corrected covariance is just H⁻¹.
        let eye = Array2::eye(p);
        return Ok(CorrectedCovariance {
            matrix: hop.solve_multi(&eye),
            active_constraints: Vec::new(),
            rank_deficient_directions: Vec::new(),
        });
    }
    if outer_hessian.nrows() != q || outer_hessian.ncols() != q {
        return Err(CorrectedCovarianceError::ShapeMismatch(format!(
            "compute_corrected_covariance: outer Hessian dimension ({}, {}) does not match \
            total hyperparameter count q = {} (rho: {}, ext: {})",
            outer_hessian.nrows(),
            outer_hessian.ncols(),
            q,
            v_ks.len(),
            ext_v.len(),
        )));
    }
    // J: column k is the negated sensitivity vector for hyperparameter k.
    let mut j_alpha = Array2::zeros((p, q));
    for (col, v) in v_ks.iter().enumerate() {
        for row in 0..p {
            j_alpha[[row, col]] = -v[row];
        }
    }
    for (i, v) in ext_v.iter().enumerate() {
        let col = v_ks.len() + i;
        for row in 0..p {
            j_alpha[[row, col]] = -v[row];
        }
    }
    let active = active_bound_indices_for_theta(theta_at_optimum, v_ks.len(), ext_v.len());
    let (v_theta, rank_deficient_directions) = projected_inverse_with_inertia_gate(
        outer_hessian,
        &active,
        theta_at_optimum,
        gradient_norm,
    )?;
    // Correction term J V_θ Jᵀ, added to H⁻¹.
    let j_v_theta = j_alpha.dot(&v_theta);
    let correction = j_v_theta.dot(&j_alpha.t());
    let eye = Array2::eye(p);
    let mut h_inv = hop.solve_multi(&eye);
    h_inv += &correction;
    // Symmetrize to scrub floating-point asymmetry from the solve/products.
    enforce_symmetry_inplace(&mut h_inv);
    Ok(CorrectedCovariance {
        matrix: h_inv,
        active_constraints: active,
        rank_deficient_directions,
    })
}
/// Convenience wrapper around
/// `compute_corrected_covariance_diagonal_with_constraints` with no bound
/// information: structural findings are logged at debug level and only the
/// corrected diagonal is returned.
pub fn compute_corrected_covariance_diagonal(
    v_ks: &[Array1<f64>],
    ext_v: &[Array1<f64>],
    outer_hessian: &Array2<f64>,
    hop: &dyn HessianOperator,
) -> Result<Array1<f64>, CorrectedCovarianceError> {
    let result = compute_corrected_covariance_diagonal_with_constraints(
        v_ks,
        ext_v,
        outer_hessian,
        hop,
        None,
        f64::NAN,
    )?;
    if result.has_structural_diagnostics() {
        log::debug!(
            "corrected covariance diagonal diagnostics: active_constraints={:?} rank_deficient_directions={:?}",
            result.active_constraints,
            result.rank_deficient_directions
        );
    }
    Ok(result.diagonal)
}
/// Result of the diagonal-only corrected-covariance computation.
#[derive(Debug, Clone)]
pub struct CorrectedCovarianceDiagonal {
    /// diag(H⁻¹ + J V_θ Jᵀ), one entry per coefficient.
    pub diagonal: Array1<f64>,
    /// Hyperparameter indices detected at their bounds (excluded from V_θ).
    pub active_constraints: Vec<usize>,
    /// Free directions whose outer-Hessian eigenvalue fell below tolerance.
    pub rank_deficient_directions: Vec<usize>,
}
impl CorrectedCovarianceDiagonal {
    /// Whether any structural findings (active bounds or rank-deficient
    /// outer directions) were recorded alongside the diagonal.
    fn has_structural_diagnostics(&self) -> bool {
        !(self.active_constraints.is_empty() && self.rank_deficient_directions.is_empty())
    }
}
/// Diagonal-only variant of `compute_corrected_covariance_with_constraints`:
/// returns diag(H⁻¹ + J V_θ Jᵀ) without materializing the p×p matrix.
///
/// diag(H⁻¹) is obtained one unit-vector solve at a time; the correction is
/// added as squared row norms of J·V_θ^{1/2}, where the square-root factor
/// comes from an eigendecomposition of V_θ with non-positive modes dropped.
pub fn compute_corrected_covariance_diagonal_with_constraints(
    v_ks: &[Array1<f64>],
    ext_v: &[Array1<f64>],
    outer_hessian: &Array2<f64>,
    hop: &dyn HessianOperator,
    theta_at_optimum: Option<&[f64]>,
    gradient_norm: f64,
) -> Result<CorrectedCovarianceDiagonal, CorrectedCovarianceError> {
    let p = hop.dim();
    let q = v_ks.len() + ext_v.len();
    // diag(H⁻¹): solve against each basis vector and keep only entry i.
    let mut diag = Array1::zeros(p);
    for i in 0..p {
        let mut e_i = Array1::zeros(p);
        e_i[i] = 1.0;
        let h_inv_ei = hop.solve(&e_i);
        diag[i] = h_inv_ei[i];
    }
    if q == 0 {
        // No hyperparameters: nothing to correct for.
        return Ok(CorrectedCovarianceDiagonal {
            diagonal: diag,
            active_constraints: Vec::new(),
            rank_deficient_directions: Vec::new(),
        });
    }
    if outer_hessian.nrows() != q || outer_hessian.ncols() != q {
        return Err(CorrectedCovarianceError::ShapeMismatch(format!(
            "compute_corrected_covariance_diagonal: outer Hessian dimension ({}, {}) \
            does not match q = {}",
            outer_hessian.nrows(),
            outer_hessian.ncols(),
            q,
        )));
    }
    let active = active_bound_indices_for_theta(theta_at_optimum, v_ks.len(), ext_v.len());
    let (v_theta_full, rank_deficient_directions) = projected_inverse_with_inertia_gate(
        outer_hessian,
        &active,
        theta_at_optimum,
        gradient_norm,
    )?;
    // V_θ^{1/2}: columns are eigenvectors scaled by √eigenvalue; modes with
    // non-positive eigenvalue contribute nothing to the correction.
    let (sym_evals, sym_evecs) = v_theta_full
        .eigh(faer::Side::Lower)
        .map_err(|e| CorrectedCovarianceError::EigendecompositionFailed(e.to_string()))?;
    let mut v_theta_sqrt = Array2::<f64>::zeros((q, q));
    for j in 0..q {
        let s = sym_evals[j];
        if s <= 0.0 {
            continue;
        }
        let scale = s.sqrt();
        for row in 0..q {
            v_theta_sqrt[[row, j]] = sym_evecs[[row, j]] * scale;
        }
    }
    // J: column k is the negated sensitivity vector for hyperparameter k
    // (rho block first, then the extended block).
    let mut j_alpha = Array2::zeros((p, q));
    for (col, v) in v_ks.iter().enumerate() {
        for row in 0..p {
            j_alpha[[row, col]] = -v[row];
        }
    }
    for (i, v) in ext_v.iter().enumerate() {
        let col = v_ks.len() + i;
        for row in 0..p {
            j_alpha[[row, col]] = -v[row];
        }
    }
    // diag(J V_θ Jᵀ)_i = ||row i of J·V_θ^{1/2}||².
    let m = j_alpha.dot(&v_theta_sqrt);
    for i in 0..p {
        let mut row_norm_sq = 0.0;
        for j in 0..m.ncols() {
            row_norm_sq += m[[i, j]] * m[[i, j]];
        }
        diag[i] += row_norm_sq;
    }
    Ok(CorrectedCovarianceDiagonal {
        diagonal: diag,
        active_constraints: active,
        rank_deficient_directions,
    })
}
/// Symmetrizes a square matrix in place by averaging each off-diagonal pair,
/// scrubbing floating-point asymmetry left by prior products and solves.
fn enforce_symmetry_inplace(m: &mut Array2<f64>) {
    let n = m.nrows();
    for row in 0..n {
        for col in (row + 1)..n {
            let mean = (m[[row, col]] + m[[col, row]]) * 0.5;
            m[[row, col]] = mean;
            m[[col, row]] = mean;
        }
    }
}
#[inline]
/// Smooth positive surrogate for an eigenvalue:
/// φ(σ) = (σ + √(σ² + 4ε²)) / 2, which tends to σ for σ ≫ ε and to
/// ε²·O(1/|σ|) for σ ≪ −ε, never crossing zero.
pub(crate) fn spectral_regularize(sigma: f64, epsilon: f64) -> f64 {
    // disc = √(σ² + 4ε²), computed overflow-safely via hypot.
    let disc = sigma.hypot(2.0 * epsilon);
    if sigma < 0.0 {
        // Algebraically equal to (σ + disc)/2, rewritten to avoid the
        // catastrophic cancellation that form suffers for large negative σ.
        return (2.0 * epsilon * epsilon) / (disc - sigma);
    }
    0.5 * sigma + 0.5 * disc
}
#[inline]
/// Dimension-scaled regularization level: √(machine ε) times the number of
/// eigenvalues, floored at one so an empty spectrum still yields √ε.
pub(crate) fn spectral_epsilon(eigenvalues: &[f64]) -> f64 {
    let scale = eigenvalues.len().max(1) as f64;
    f64::EPSILON.sqrt() * scale
}
/// How small/negative eigenvalues are treated when forming the pseudo
/// log-determinant in `DenseSpectralOperator`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PseudoLogdetMode {
    /// Keep every mode, smoothly regularized via `spectral_regularize`.
    Smooth,
    /// Drop modes with eigenvalue ≤ ε entirely (hard pseudo-inverse).
    HardPseudo,
}
impl Default for PseudoLogdetMode {
    // `Smooth` is the default: differentiable and keeps full rank.
    fn default() -> Self {
        Self::Smooth
    }
}
/// Dense `HessianOperator` backed by a full symmetric eigendecomposition,
/// with smooth (or hard) regularization of small/negative eigenvalues. All
/// derived factors and kernels are precomputed at construction so trait
/// methods reduce to matrix products and elementwise contractions.
pub struct DenseSpectralOperator {
    // Regularized eigenvalues λ̃_j = spectral_regularize(σ_j, ε).
    reg_eigenvalues: Vec<f64>,
    // Which eigen-modes participate (all true in `Smooth` mode).
    active_mask: Vec<bool>,
    // Eigenvector matrix U (columns are eigenvectors).
    eigenvectors: Array2<f64>,
    // W with columns u_j / √λ̃_j, so W·Wᵀ is the (pseudo-)inverse.
    w_factor: Array2<f64>,
    // Kernel 1/(λ̃_a · λ̃_b) used by the trace_hinv_*_cross family.
    hinv_cross_kernel: Array2<f64>,
    // Factor G with G·Gᵀ equal to the log-det gradient kernel.
    g_factor: Array2<f64>,
    // Second-derivative kernel of the smoothed log-det.
    logdet_hessian_kernel: Array2<f64>,
    // Σ ln λ̃_j over active modes, computed once at construction.
    cached_logdet: f64,
    // Memoizes projected-factor products for repeated HyperOperator queries.
    // NOTE(review): cache semantics live in `ProjectedFactorCache` (not in view).
    projected_factor_cache: ProjectedFactorCache,
    n_dim: usize,
}
impl DenseSpectralOperator {
    /// Eigendecomposes `h` using the default `Smooth` pseudo-logdet mode.
    pub fn from_symmetric(h: &Array2<f64>) -> Result<Self, String> {
        Self::from_symmetric_with_mode(h, PseudoLogdetMode::Smooth)
    }

    /// Eigendecomposes the symmetric matrix `h` once and precomputes every
    /// factor/kernel the `HessianOperator` implementation needs:
    ///
    /// * `w_factor` — column j is u_j / √λ̃_j (inactive columns zero).
    /// * `hinv_cross_kernel[[a, b]]` = 1 / (λ̃_a · λ̃_b).
    /// * `g_factor` — column j is u_j · √(1/√(σ_j² + 4ε²)); that scalar is
    ///   the derivative of ln φ(σ) for the smoothing φ in
    ///   `spectral_regularize`, so G·Gᵀ is the log-det gradient kernel.
    /// * `logdet_hessian_kernel` — second-derivative kernel of the smoothed
    ///   log-det (analytic limit on the diagonal, divided-difference form
    ///   off-diagonal).
    /// * `cached_logdet` — Σ ln λ̃_j over active modes.
    ///
    /// # Errors
    /// `h` not square, or the eigendecomposition fails.
    pub fn from_symmetric_with_mode(
        h: &Array2<f64>,
        mode: PseudoLogdetMode,
    ) -> Result<Self, String> {
        use faer::Side;
        let n = h.nrows();
        if n != h.ncols() {
            return Err(format!(
                "HessianOperator: expected square matrix, got {}×{}",
                n,
                h.ncols()
            ));
        }
        let (eigenvalues, eigenvectors) = h
            .eigh(Side::Lower)
            .map_err(|e| format!("Eigendecomposition failed: {e}"))?;
        let epsilon = spectral_epsilon(eigenvalues.as_slice().unwrap());
        // Active set: everything in Smooth mode; only σ > ε in HardPseudo.
        let active: Vec<bool> = match mode {
            PseudoLogdetMode::Smooth => vec![true; n],
            PseudoLogdetMode::HardPseudo => eigenvalues.iter().map(|&s| s > epsilon).collect(),
        };
        let reg_eigenvalues: Vec<f64> = eigenvalues
            .iter()
            .map(|&sigma| spectral_regularize(sigma, epsilon))
            .collect();
        // W: columns u_j / √λ̃_j; inactive columns remain zero.
        let mut w_factor = Array2::zeros((n, n));
        for j in 0..n {
            if !active[j] {
                continue;
            }
            let scale = 1.0 / reg_eigenvalues[j].sqrt();
            for row in 0..n {
                w_factor[[row, j]] = eigenvectors[[row, j]] * scale;
            }
        }
        // Kernel 1/(λ̃_a λ̃_b) over active mode pairs.
        let mut hinv_cross_kernel = Array2::zeros((n, n));
        for a in 0..n {
            if !active[a] {
                continue;
            }
            let inv_ra = 1.0 / reg_eigenvalues[a];
            for b in 0..n {
                if !active[b] {
                    continue;
                }
                hinv_cross_kernel[[a, b]] = inv_ra / reg_eigenvalues[b];
            }
        }
        let four_eps_sq = 4.0 * epsilon * epsilon;
        // G: columns u_j · √φ', with φ' = 1/√(σ² + 4ε²) = d(ln λ̃)/dσ.
        let mut g_factor = Array2::zeros((n, n));
        for j in 0..n {
            if !active[j] {
                continue;
            }
            let sigma = eigenvalues[j];
            let phi_prime = 1.0 / (sigma * sigma + four_eps_sq).sqrt();
            let scale = phi_prime.sqrt();
            for row in 0..n {
                g_factor[[row, j]] = eigenvectors[[row, j]] * scale;
            }
        }
        // Second-derivative kernel of the smoothed log-det.
        let mut logdet_hessian_kernel = Array2::zeros((n, n));
        let sqrt_disc: Vec<f64> = eigenvalues
            .iter()
            .map(|&s| (s * s + four_eps_sq).sqrt())
            .collect();
        for a in 0..n {
            if !active[a] {
                continue;
            }
            let sigma_a = eigenvalues[a];
            let sqrt_a = sqrt_disc[a];
            for b in 0..n {
                if !active[b] {
                    continue;
                }
                logdet_hessian_kernel[[a, b]] = if a == b {
                    // Diagonal: analytic limit −σ / disc³.
                    -sigma_a / (sqrt_a * sqrt_a * sqrt_a)
                } else {
                    // Off-diagonal: symmetric divided-difference form.
                    let sigma_b = eigenvalues[b];
                    let sqrt_b = sqrt_disc[b];
                    -(sigma_a + sigma_b) / (sqrt_a * sqrt_b * (sqrt_a + sqrt_b))
                };
            }
        }
        // log|H̃| restricted to active modes.
        let cached_logdet: f64 = reg_eigenvalues
            .iter()
            .zip(active.iter())
            .filter_map(|(&v, &act)| if act { Some(v.ln()) } else { None })
            .sum();
        Ok(Self {
            reg_eigenvalues,
            active_mask: active,
            eigenvectors,
            w_factor,
            hinv_cross_kernel,
            g_factor,
            logdet_hessian_kernel,
            cached_logdet,
            projected_factor_cache: ProjectedFactorCache::default(),
            n_dim: n,
        })
    }

    /// Uᵀ·M·U — expresses `matrix` in the cached eigenbasis.
    #[inline]
    fn rotate_to_eigenbasis(&self, matrix: &Array2<f64>) -> Array2<f64> {
        self.eigenvectors.t().dot(matrix).dot(&self.eigenvectors)
    }

    /// Factor G with G·Gᵀ equal to the log-det gradient kernel.
    pub fn logdet_gradient_factor(&self) -> &Array2<f64> {
        &self.g_factor
    }

    /// tr(H⁻¹ A H⁻¹ B) for inputs already rotated into the eigenbasis:
    /// contracts against the 1/(λ̃_a λ̃_b) kernel.
    #[inline]
    fn trace_hinv_product_cross_rotated(&self, a_rot: &Array2<f64>, b_rot: &Array2<f64>) -> f64 {
        let mut result = 0.0;
        for a in 0..self.n_dim {
            for b in 0..self.n_dim {
                result += self.hinv_cross_kernel[[a, b]] * a_rot[[a, b]] * b_rot[[b, a]];
            }
        }
        result
    }

    /// tr(H⁻¹ A H⁻¹ B); rotates only once when `a` and `b` alias.
    #[inline]
    fn trace_hinv_product_cross_dense(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let a_rot = self.rotate_to_eigenbasis(a);
        if std::ptr::eq(a, b) {
            return self.trace_hinv_product_cross_rotated(&a_rot, &a_rot);
        }
        let b_rot = self.rotate_to_eigenbasis(b);
        self.trace_hinv_product_cross_rotated(&a_rot, &b_rot)
    }

    /// Wᵀ·M·W — projects `matrix` through the inverse square-root factor.
    #[inline]
    fn projected_matrix(&self, matrix: &Array2<f64>) -> Array2<f64> {
        self.w_factor.t().dot(matrix).dot(&self.w_factor)
    }

    /// Delegates Fᵀ·Op·F to the operator, timing it when info logging is on.
    #[inline]
    fn projected_operator(&self, factor: &Array2<f64>, op: &dyn HyperOperator) -> Array2<f64> {
        if log::log_enabled!(log::Level::Info) {
            let start = std::time::Instant::now();
            let result = op.projected_matrix(factor);
            let signature = format!(
                "DenseSpectralOperator::projected_operator dim={} rank={} implicit={}",
                self.n_dim,
                factor.ncols(),
                op.is_implicit(),
            );
            dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
            result
        } else {
            op.projected_matrix(factor)
        }
    }

    /// tr(L·R) for two projected (square, same-shape) matrices.
    #[inline]
    fn trace_projected_cross(&self, left: &Array2<f64>, right: &Array2<f64>) -> f64 {
        let mut result = 0.0;
        for a in 0..left.nrows() {
            for b in 0..left.ncols() {
                result += left[[a, b]] * right[[b, a]];
            }
        }
        result
    }

    /// Contracts two eigenbasis-rotated curvature matrices against the
    /// log-det second-derivative kernel.
    #[inline]
    fn trace_logdet_hessian_cross_rotated(
        &self,
        h_i_rot: &Array2<f64>,
        h_j_rot: &Array2<f64>,
    ) -> f64 {
        let mut result = 0.0;
        for a in 0..self.n_dim {
            for b in 0..self.n_dim {
                result += self.logdet_hessian_kernel[[a, b]] * h_i_rot[[a, b]] * h_j_rot[[b, a]];
            }
        }
        result
    }
}
/// Rate-limited stage timing logger for `DenseSpectralOperator` hot paths.
///
/// Consecutive calls with the same `signature` are aggregated into a single
/// state record: a heartbeat line is emitted at exponentially growing call
/// counts (2, 4, 8, …) instead of one line per call. When the signature
/// changes, a final summary for the previous signature is flushed (only if
/// it actually repeated) and a fresh record is started.
fn dense_spectral_stage_log(signature: &str, elapsed_s: f64) {
    use std::sync::Mutex;
    // Aggregated timing stats for the most recent distinct signature.
    struct Repeat {
        signature: String,
        count: u64,
        total: f64,
        min: f64,
        max: f64,
        // Next `count` at which a heartbeat is printed (doubles each time).
        next_heartbeat: u64,
    }
    static REPEAT: Mutex<Option<Repeat>> = Mutex::new(None);
    let mut guard = match REPEAT.lock() {
        Ok(g) => g,
        // Logging must not die on a poisoned lock; keep using the inner state.
        Err(poisoned) => poisoned.into_inner(),
    };
    if let Some(state) = guard.as_mut() {
        if state.signature == signature {
            // Same stage as last time: fold in and maybe emit a heartbeat.
            state.count += 1;
            state.total += elapsed_s;
            if elapsed_s < state.min {
                state.min = elapsed_s;
            }
            if elapsed_s > state.max {
                state.max = elapsed_s;
            }
            if state.count >= state.next_heartbeat {
                log::info!(
                    "[STAGE] {} (×{} so far, total={:.3}s min={:.3}s max={:.3}s avg={:.3}s)",
                    state.signature,
                    state.count,
                    state.total,
                    state.min,
                    state.max,
                    state.total / state.count as f64,
                );
                state.next_heartbeat = state.next_heartbeat.saturating_mul(2);
            }
            return;
        }
        // Signature changed: flush a final summary for the previous one
        // (a single call was already logged when its record was created).
        if state.count > 1 {
            log::info!(
                "[STAGE] {} final ×{} total={:.3}s min={:.3}s max={:.3}s avg={:.3}s",
                state.signature,
                state.count,
                state.total,
                state.min,
                state.max,
                state.total / state.count as f64,
            );
        }
    }
    // First occurrence of this signature: log it and start a new record.
    log::info!("[STAGE] {} elapsed={:.3}s", signature, elapsed_s);
    *guard = Some(Repeat {
        signature: signature.to_string(),
        count: 1,
        total: elapsed_s,
        min: elapsed_s,
        max: elapsed_s,
        next_heartbeat: 2,
    });
}
impl HessianOperator for DenseSpectralOperator {
    // Cached Σ ln λ̃ over active modes, computed at construction.
    fn logdet(&self) -> f64 {
        self.cached_logdet
    }

    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        Some(self)
    }

    // tr(A H⁻¹) = Σ (A·W) ⊙ W, never forming the explicit inverse.
    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
        let aw = a.dot(&self.w_factor);
        aw.iter()
            .zip(self.w_factor.iter())
            .map(|(&a, &w)| a * w)
            .sum()
    }

    // H⁻¹·rhs via projection onto each active eigenvector (inactive modes
    // contribute nothing, i.e. a pseudo-inverse solve in HardPseudo mode).
    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
        let mut result = Array1::zeros(self.n_dim);
        for j in 0..self.n_dim {
            if !self.active_mask[j] {
                continue;
            }
            let u = self.eigenvectors.column(j);
            let coeff = u.dot(rhs) / self.reg_eigenvalues[j];
            for row in 0..self.n_dim {
                result[row] += coeff * u[row];
            }
        }
        result
    }

    // Multi-RHS solve: rotate, scale active rows by 1/λ̃ (zero out inactive
    // rows), rotate back.
    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
        let mut projected = self.eigenvectors.t().dot(rhs);
        for j in 0..self.n_dim {
            if self.active_mask[j] {
                let scale = 1.0 / self.reg_eigenvalues[j];
                projected.row_mut(j).mapv_inplace(|value| value * scale);
            } else {
                projected.row_mut(j).fill(0.0);
            }
        }
        self.eigenvectors.dot(&projected)
    }

    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        self.trace_hinv_product_cross_dense(a, b)
    }

    // tr(H⁻¹·Op) via the operator's cached W-projection; timed when info
    // logging is enabled.
    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
        if log::log_enabled!(log::Level::Info) {
            let start = std::time::Instant::now();
            let result = op.trace_projected_factor_cached(&self.w_factor, &self.projected_factor_cache);
            let signature = format!(
                "DenseSpectralOperator::trace_hinv_operator dim={} rank={} implicit={}",
                self.n_dim,
                self.w_factor.ncols(),
                op.is_implicit(),
            );
            dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
            result
        } else {
            op.trace_projected_factor_cached(&self.w_factor, &self.projected_factor_cache)
        }
    }

    // Mixed cross trace: project the dense matrix and the operator through W
    // separately, then contract the two projections.
    fn trace_hinv_matrix_operator_cross(
        &self,
        matrix: &Array2<f64>,
        op: &dyn HyperOperator,
    ) -> f64 {
        let left = self.w_factor.t().dot(matrix).dot(&self.w_factor);
        let right = self.projected_operator(&self.w_factor, op);
        self.trace_projected_cross(&left, &right)
    }

    // Operator/operator cross trace; projects `right` only when it is a
    // distinct object from `left`. Timed when info logging is enabled.
    fn trace_hinv_operator_cross(
        &self,
        left: &dyn HyperOperator,
        right: &dyn HyperOperator,
    ) -> f64 {
        if log::log_enabled!(log::Level::Info) {
            let start = std::time::Instant::now();
            let left_proj = self.projected_operator(&self.w_factor, left);
            let result = if std::ptr::addr_eq(left, right) {
                self.trace_projected_cross(&left_proj, &left_proj)
            } else {
                let right_proj = self.projected_operator(&self.w_factor, right);
                self.trace_projected_cross(&left_proj, &right_proj)
            };
            let signature = format!(
                "DenseSpectralOperator::trace_hinv_operator_cross dim={} rank={} left_implicit={} right_implicit={}",
                self.n_dim,
                self.w_factor.ncols(),
                left.is_implicit(),
                right.is_implicit(),
            );
            dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
            result
        } else {
            let left_proj = self.projected_operator(&self.w_factor, left);
            if std::ptr::addr_eq(left, right) {
                self.trace_projected_cross(&left_proj, &left_proj)
            } else {
                let right_proj = self.projected_operator(&self.w_factor, right);
                self.trace_projected_cross(&left_proj, &right_proj)
            }
        }
    }

    // Gradient trace against the log-det kernel: Σ (A·G) ⊙ G.
    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
        let ag = a.dot(&self.g_factor);
        ag.iter()
            .zip(self.g_factor.iter())
            .map(|(&a, &g)| a * g)
            .sum()
    }

    // h_i = ||x_i·G||² per design row, computed in row chunks sized toward
    // an ~8 MiB working set so X·G is never materialized in full.
    fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
        let n = x.nrows();
        let p = x.ncols();
        let rank = self.g_factor.ncols();
        let mut h = Array1::<f64>::zeros(n);
        if n == 0 || p == 0 || rank == 0 {
            return h;
        }
        let chunk_rows = {
            const TARGET_BYTES: usize = 8 * 1024 * 1024;
            (TARGET_BYTES / ((p + rank).max(1) * 8)).max(512).min(n)
        };
        let mut start = 0usize;
        while start < n {
            let end = (start + chunk_rows).min(n);
            let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
                panic!("xt_logdet_kernel_x_diagonal: row chunk failed: {err}")
            });
            let xg = crate::faer_ndarray::fast_ab(&rows, &self.g_factor);
            for (local, row) in xg.outer_iter().enumerate() {
                h[start + local] = row.iter().map(|v| v * v).sum();
            }
            start = end;
        }
        h
    }

    // Localized gradient trace: contracts `block` against the [start, end)
    // row slice of G, scaled by `scale`.
    fn trace_logdet_block_local(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let g_block = self.g_factor.slice(ndarray::s![start..end, ..]);
        let ag = block.dot(&g_block);
        scale
            * ag.iter()
                .zip(g_block.iter())
                .map(|(&a, &g)| a * g)
                .sum::<f64>()
    }

    // Localized H⁻¹ trace: same pattern as above but against W.
    fn trace_hinv_block_local(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let w_block = self.w_factor.slice(ndarray::s![start..end, ..]);
        let aw = block.dot(&w_block);
        scale
            * aw.iter()
                .zip(w_block.iter())
                .map(|(&a, &w)| a * w)
                .sum::<f64>()
    }

    // Localized cross trace: forms M = W_blockᵀ·B·W_block and returns
    // scale² · Σ M², i.e. scale² · tr(M·Mᵀ).
    // NOTE(review): this equals tr((H⁻¹B)²) on the block only when M is
    // symmetric (symmetric `block`) — confirm callers guarantee that.
    fn trace_hinv_block_local_cross(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let w_block = self.w_factor.slice(ndarray::s![start..end, ..]);
        let bw = block.dot(&w_block);
        let m = w_block.t().dot(&bw);
        let scale_sq = scale * scale;
        scale_sq * m.iter().map(|&v| v * v).sum::<f64>()
    }

    // tr(logdet-kernel · Op) via the operator's cached G-projection; timed
    // when info logging is enabled.
    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
        if log::log_enabled!(log::Level::Info) {
            let start = std::time::Instant::now();
            let result = op.trace_projected_factor_cached(&self.g_factor, &self.projected_factor_cache);
            let signature = format!(
                "DenseSpectralOperator::trace_logdet_operator dim={} rank={} implicit={}",
                self.n_dim,
                self.g_factor.ncols(),
                op.is_implicit(),
            );
            dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
            result
        } else {
            op.trace_projected_factor_cached(&self.g_factor, &self.projected_factor_cache)
        }
    }

    // Second-derivative cross trace; rotates once when the inputs alias.
    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
        let hp_i = self.rotate_to_eigenbasis(h_i);
        if std::ptr::eq(h_i, h_j) {
            return self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_i);
        }
        let hp_j = self.rotate_to_eigenbasis(h_j);
        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
    }

    // Mixed matrix/operator variant: the operator side is rotated via the
    // raw eigenvector basis (not W).
    fn trace_logdet_hessian_cross_matrix_operator(
        &self,
        h_i: &Array2<f64>,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        let hp_i = self.rotate_to_eigenbasis(h_i);
        let hp_j = self.projected_operator(&self.eigenvectors, h_j);
        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
    }

    // Operator/operator variant; skips the second projection when aliased.
    fn trace_logdet_hessian_cross_operator(
        &self,
        h_i: &dyn HyperOperator,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        let hp_i = self.projected_operator(&self.eigenvectors, h_i);
        if std::ptr::addr_eq(h_i, h_j) {
            return self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_i);
        }
        let hp_j = self.projected_operator(&self.eigenvectors, h_j);
        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
    }

    // All pairwise cross traces for a batch of matrices; each matrix is
    // rotated exactly once and the symmetric result is mirrored.
    fn trace_logdet_hessian_crosses(&self, matrices: &[&Array2<f64>]) -> Array2<f64> {
        let n = matrices.len();
        let rotated = matrices
            .iter()
            .map(|matrix| self.rotate_to_eigenbasis(matrix))
            .collect::<Vec<_>>();
        let mut out = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            for j in i..n {
                let value = self.trace_logdet_hessian_cross_rotated(&rotated[i], &rotated[j]);
                out[[i, j]] = value;
                out[[j, i]] = value;
            }
        }
        out
    }

    // NOTE(review): reports the full dimension even though `active_mask`
    // can deactivate modes in HardPseudo mode — confirm this is intentional
    // rather than `active_mask.iter().filter(|&&a| a).count()`.
    fn active_rank(&self) -> usize {
        self.n_dim
    }

    fn dim(&self) -> usize {
        self.n_dim
    }

    fn is_dense(&self) -> bool {
        true
    }

    // Exact spectral traces are available; no stochastic estimation needed.
    fn prefers_stochastic_trace_estimation(&self) -> bool {
        false
    }

    // The log-det kernel (G) differs from the H⁻¹ kernel (W) here.
    fn logdet_traces_match_hinv_kernel(&self) -> bool {
        false
    }

    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        Some(self)
    }
}
/// `HessianOperator` backed by an exact sparse Cholesky factorization,
/// optionally augmented with a Takahashi selected inverse for fast
/// structured-trace queries.
pub struct SparseCholeskyOperator {
    // Shared sparse factor used for all solves.
    factor: std::sync::Arc<crate::linalg::sparse_exact::SparseExactFactor>,
    // Selected inverse of H; attached via `with_takahashi` when available.
    takahashi: Option<std::sync::Arc<crate::linalg::sparse_exact::TakahashiInverse>>,
    // log|H| supplied by the caller at construction.
    cached_logdet: f64,
    n_dim: usize,
}
impl SparseCholeskyOperator {
pub fn new(
factor: std::sync::Arc<crate::linalg::sparse_exact::SparseExactFactor>,
logdet_h: f64,
dim: usize,
) -> Self {
Self {
factor,
takahashi: None,
cached_logdet: logdet_h,
n_dim: dim,
}
}
pub fn with_takahashi(
mut self,
taka: std::sync::Arc<crate::linalg::sparse_exact::TakahashiInverse>,
) -> Self {
self.takahashi = Some(taka);
self
}
const OPERATOR_SOLVE_CHUNK: usize = 64;
    /// tr(Z · block) for a dense square `block` anchored at row/column
    /// `start` on the diagonal, where `Z` is the Takahashi selected inverse.
    /// Exploits symmetry by folding the (i, j) and (j, i) entries into a
    /// single selected-inverse lookup, and skips near-zero entries so no
    /// lookup is issued for values that cannot contribute.
    fn takahashi_block_trace(
        taka: &crate::linalg::sparse_exact::TakahashiInverse,
        block: &Array2<f64>,
        start: usize,
    ) -> f64 {
        debug_assert_eq!(block.nrows(), block.ncols());
        let mut trace = 0.0;
        for i in 0..block.nrows() {
            let diag = block[[i, i]];
            if diag.abs() > 1e-30 {
                trace += taka.get(start + i, start + i) * diag;
            }
            for j in (i + 1)..block.ncols() {
                // Z[i,j] multiplies both block[i,j] and block[j,i].
                let pair = block[[i, j]] + block[[j, i]];
                if pair.abs() > 1e-30 {
                    trace += taka.get(start + i, start + j) * pair;
                }
            }
        }
        trace
    }
    /// Z_block · block: left-multiplies a dense square `block` (anchored at
    /// `start`) by the corresponding diagonal sub-block of the Takahashi
    /// selected inverse. Queries only diagonal and strict-upper entries of
    /// Z, applying each off-diagonal value to both symmetric positions, and
    /// skips near-zero entries to avoid wasted lookups.
    fn takahashi_left_multiply_block(
        taka: &crate::linalg::sparse_exact::TakahashiInverse,
        block: &Array2<f64>,
        start: usize,
    ) -> Array2<f64> {
        let dim = block.nrows();
        let mut out = Array2::<f64>::zeros((dim, dim));
        for i in 0..dim {
            let z_diag = taka.get(start + i, start + i);
            if z_diag.abs() > 1e-30 {
                for k in 0..dim {
                    out[[i, k]] += z_diag * block[[i, k]];
                }
            }
            for j in (i + 1)..dim {
                let z = taka.get(start + i, start + j);
                if z.abs() <= 1e-30 {
                    continue;
                }
                // Symmetric scatter: Z[i,j] feeds row i from row j and vice versa.
                for k in 0..dim {
                    out[[i, k]] += z * block[[j, k]];
                    out[[j, k]] += z * block[[i, k]];
                }
            }
        }
        out
    }
fn trace_hinv_operator_exact(&self, op: &dyn HyperOperator) -> f64 {
let (range_start, range_end) = op
.block_local_data()
.map(|(_, start, end)| (start, end))
.unwrap_or((0, self.n_dim));
let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
let mut trace = 0.0_f64;
let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
let mut start = range_start;
while start < range_end {
let end = (start + chunk).min(range_end);
let cols = end - start;
op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
let diagonal_sum = if cols == chunk {
crate::linalg::sparse_exact::solve_sparse_spdmulti_diagonal_sum(
&self.factor,
&rhs_block,
start,
)
} else {
let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
crate::linalg::sparse_exact::solve_sparse_spdmulti_diagonal_sum(
&self.factor,
&rhs_view,
start,
)
};
trace += diagonal_sum.unwrap_or_else(|e| {
panic!("SparseCholeskyOperator exact trace_hinv_operator solve failed: {e}")
});
start = end;
}
trace
}
fn solve_operator_column_range_rows_exact(
&self,
op: &dyn HyperOperator,
col_start: usize,
col_end: usize,
row_start: usize,
row_end: usize,
) -> Result<Array2<f64>, String> {
let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
let cols_total = col_end - col_start;
let rows_total = row_end - row_start;
let mut solved = Array2::<f64>::zeros((rows_total, cols_total));
let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
let mut start = col_start;
while start < col_end {
let end = (start + chunk).min(col_end);
let cols = end - start;
op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
let solved_block = if cols == chunk {
crate::linalg::sparse_exact::solve_sparse_spdmulti_rows(
&self.factor,
&rhs_block,
row_start,
row_end,
)
} else {
let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
crate::linalg::sparse_exact::solve_sparse_spdmulti_rows(
&self.factor,
&rhs_view,
row_start,
row_end,
)
}
.map_err(|e| {
format!(
"SparseCholeskyOperator::solve_operator_column_range_rows_exact multi-solve failed: {e}"
)
})?;
solved
.slice_mut(ndarray::s![.., start - col_start..end - col_start])
.assign(&solved_block);
start = end;
}
Ok(solved)
}
fn fill_scaled_block_columns(
block: &Array2<f64>,
scale: f64,
block_start: usize,
local_col_start: usize,
cols: usize,
mut rhs_block: ndarray::ArrayViewMut2<'_, f64>,
) {
let block_end = block_start + block.nrows();
let source = block.slice(ndarray::s![.., local_col_start..local_col_start + cols]);
let mut target = rhs_block.slice_mut(ndarray::s![block_start..block_end, ..cols]);
if scale == 1.0 {
target.assign(&source);
} else {
Zip::from(target)
.and(source)
.for_each(|dst, &value| *dst = scale * value);
}
}
fn trace_hinv_block_local_exact(
&self,
block: &Array2<f64>,
scale: f64,
start: usize,
end: usize,
) -> f64 {
if scale == 0.0 {
return 0.0;
}
debug_assert_eq!(block.nrows(), end - start);
let t_start = std::time::Instant::now();
let block_size = end - start;
let chunk = Self::OPERATOR_SOLVE_CHUNK.min(block_size.max(1));
let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
let mut trace = 0.0;
let mut local_col_start = 0usize;
while local_col_start < block_size {
let cols = (block_size - local_col_start).min(chunk);
Self::fill_scaled_block_columns(
block,
scale,
start,
local_col_start,
cols,
rhs_block.view_mut(),
);
let diagonal_sum = if cols == chunk {
crate::linalg::sparse_exact::solve_sparse_spdmulti_diagonal_sum(
&self.factor,
&rhs_block,
start + local_col_start,
)
} else {
let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
crate::linalg::sparse_exact::solve_sparse_spdmulti_diagonal_sum(
&self.factor,
&rhs_view,
start + local_col_start,
)
};
trace += diagonal_sum.unwrap_or_else(|e| {
panic!("SparseCholeskyOperator exact block-local trace solve failed: {e}")
});
local_col_start += cols;
}
let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
if elapsed_ms > 100.0 {
log::info!(
"[REML-trace] block_local_exact | n_dim={} | block={} | {:.1}ms",
self.n_dim,
block_size,
elapsed_ms
);
}
trace
}
fn solve_block_local_rows_exact(
&self,
block: &Array2<f64>,
scale: f64,
start: usize,
end: usize,
) -> Result<Array2<f64>, String> {
debug_assert_eq!(block.nrows(), end - start);
let block_size = end - start;
let chunk = Self::OPERATOR_SOLVE_CHUNK.min(block_size.max(1));
let mut solved = Array2::<f64>::zeros((block_size, block_size));
if scale == 0.0 {
return Ok(solved);
}
let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
let mut local_col_start = 0usize;
while local_col_start < block_size {
let cols = (block_size - local_col_start).min(chunk);
Self::fill_scaled_block_columns(
block,
scale,
start,
local_col_start,
cols,
rhs_block.view_mut(),
);
let solved_block = if cols == chunk {
crate::linalg::sparse_exact::solve_sparse_spdmulti_rows(
&self.factor,
&rhs_block,
start,
end,
)
} else {
let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
crate::linalg::sparse_exact::solve_sparse_spdmulti_rows(
&self.factor,
&rhs_view,
start,
end,
)
}
.map_err(|e| {
format!(
"SparseCholeskyOperator::solve_block_local_rows_exact multi-solve failed: {e}"
)
})?;
solved
.slice_mut(ndarray::s![.., local_col_start..local_col_start + cols])
.assign(&solved_block);
local_col_start += cols;
}
Ok(solved)
}
fn trace_hinv_block_local_cross_exact(
&self,
block: &Array2<f64>,
scale: f64,
start: usize,
end: usize,
) -> f64 {
let t_start = std::time::Instant::now();
let solved = self
.solve_block_local_rows_exact(block, scale, start, end)
.unwrap_or_else(|e| {
panic!("SparseCholeskyOperator exact block-local cross solve failed: {e}")
});
let result = trace_matrix_product(&solved, &solved);
let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
if elapsed_ms > 100.0 {
log::info!(
"[REML-trace] block_local_cross_exact | n_dim={} | block={} | {:.1}ms",
self.n_dim,
end - start,
elapsed_ms
);
}
result
}
fn trace_hinv_matrix_operator_cross_exact(
&self,
matrix: &Array2<f64>,
op: &dyn HyperOperator,
) -> f64 {
if let Some((_, range_start, range_end)) = op.block_local_data()
&& range_end - range_start < self.n_dim
{
return self.trace_hinv_matrix_block_operator_cross_exact(
matrix,
op,
range_start,
range_end,
);
}
let solved_matrix = self.solve_multi(matrix);
let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
let mut trace = 0.0_f64;
let (range_start, range_end) = op
.block_local_data()
.map(|(_, start, end)| (start, end))
.unwrap_or((0, self.n_dim));
let mut start = range_start;
while start < range_end {
let end = (start + chunk).min(range_end);
let cols = end - start;
op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
let solved_op = if cols == chunk {
crate::linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_block)
} else {
let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
crate::linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
};
let solved_op = solved_op.unwrap_or_else(|e| {
panic!("SparseCholeskyOperator exact matrix/operator cross solve failed: {e}")
});
for local_col in 0..cols {
let matrix_row = start + local_col;
for row in 0..self.n_dim {
trace += solved_matrix[[matrix_row, row]] * solved_op[[row, local_col]];
}
}
start = end;
}
trace
}
fn trace_hinv_matrix_block_operator_cross_exact(
&self,
matrix: &Array2<f64>,
op: &dyn HyperOperator,
range_start: usize,
range_end: usize,
) -> f64 {
let t_start = std::time::Instant::now();
let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
let mut op_rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
let mut eye_rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
let mut trace = 0.0_f64;
let mut start = range_start;
while start < range_end {
let end = (start + chunk).min(range_end);
let cols = end - start;
op.mul_basis_columns_into(start, op_rhs_block.slice_mut(ndarray::s![.., ..cols]));
eye_rhs_block.fill(0.0);
for local_col in 0..cols {
eye_rhs_block[[start + local_col, local_col]] = 1.0;
}
let solved_op = if cols == chunk {
crate::linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &op_rhs_block)
} else {
let rhs_view = op_rhs_block.slice(ndarray::s![.., ..cols]);
crate::linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
};
let solved_op = solved_op.unwrap_or_else(|e| {
panic!(
"SparseCholeskyOperator exact matrix/block-operator cross operator solve failed: {e}"
)
});
let solved_eye = if cols == chunk {
crate::linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &eye_rhs_block)
} else {
let rhs_view = eye_rhs_block.slice(ndarray::s![.., ..cols]);
crate::linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
};
let solved_eye = solved_eye.unwrap_or_else(|e| {
panic!(
"SparseCholeskyOperator exact matrix/block-operator cross identity solve failed: {e}"
)
});
let selected_rows_t = matrix.t().dot(&solved_eye);
for local_col in 0..cols {
for row in 0..self.n_dim {
trace += selected_rows_t[[row, local_col]] * solved_op[[row, local_col]];
}
}
start = end;
}
let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
if elapsed_ms > 100.0 {
log::info!(
"[REML-trace] matrix_block_op_cross_exact | n_dim={} | block={} | {:.1}ms",
self.n_dim,
range_end - range_start,
elapsed_ms
);
}
trace
}
fn trace_hinv_operator_cross_exact(
&self,
left: &dyn HyperOperator,
right: &dyn HyperOperator,
) -> f64 {
let (left_start, left_end) = left
.block_local_data()
.map(|(_, start, end)| (start, end))
.unwrap_or((0, self.n_dim));
let (right_start, right_end) = right
.block_local_data()
.map(|(_, start, end)| (start, end))
.unwrap_or((0, self.n_dim));
let solved_left = self
.solve_operator_column_range_rows_exact(
left,
left_start,
left_end,
right_start,
right_end,
)
.unwrap_or_else(|e| {
panic!("SparseCholeskyOperator exact operator cross left solve failed: {e}")
});
let same_operator =
std::ptr::addr_eq(left, right) && left_start == right_start && left_end == right_end;
let solved_right = if same_operator {
None
} else {
Some(
self.solve_operator_column_range_rows_exact(
right,
right_start,
right_end,
left_start,
left_end,
)
.unwrap_or_else(|e| {
panic!("SparseCholeskyOperator exact operator cross right solve failed: {e}")
}),
)
};
let right_cols = right_end - right_start;
let mut trace = 0.0;
for left_col in 0..(left_end - left_start) {
for right_col in 0..right_cols {
let right_value = match solved_right.as_ref() {
Some(solved) => solved[[left_col, right_col]],
None => solved_left[[left_col, right_col]],
};
trace += solved_left[[right_col, left_col]] * right_value;
}
}
trace
}
}
impl HessianOperator for SparseCholeskyOperator {
fn logdet(&self) -> f64 {
// Precomputed at factorization time; no work here.
self.cached_logdet
}
fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
// Fast path: with a Takahashi selected inverse, tr(H⁻¹·A) reduces to a
// sparse symmetric contraction — off-diagonal Z entries multiply the
// folded pair a[i,j] + a[j,i]; negligible entries are skipped.
if let Some(ref taka) = self.takahashi {
let mut trace = 0.0;
for i in 0..a.nrows() {
let a_ii = a[[i, i]];
if a_ii.abs() > 1e-30 {
trace += taka.get(i, i) * a_ii;
}
for j in (i + 1)..a.ncols() {
let pair = a[[i, j]] + a[[j, i]];
if pair.abs() > 1e-30 {
trace += taka.get(i, j) * pair;
}
}
}
return trace;
}
// Fallback: full multi-RHS solve, then sum the diagonal of H⁻¹·A.
crate::linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, a)
.unwrap_or_else(|e| {
panic!("SparseCholeskyOperator exact trace_hinv_product solve failed: {e}")
})
.diag()
.sum()
}
fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
if let Some(ref taka) = self.takahashi {
// Block-local operators hit the selected-inverse block trace directly.
if let Some((local, start, end)) = op.block_local_data() {
debug_assert_eq!(local.nrows(), end - start);
return Self::takahashi_block_trace(taka, local, start);
}
// Explicit operators can be materialized and routed through the
// Takahashi-backed trace_hinv_product fast path.
if !op.is_implicit() {
let dense = op.to_dense();
return self.trace_hinv_product(&dense);
}
}
self.trace_hinv_operator_exact(op)
}
fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
// For this backend the logdet-trace kernel coincides with the H⁻¹ trace.
self.trace_hinv_operator(op)
}
fn trace_hinv_block_local(
&self,
block: &Array2<f64>,
scale: f64,
start: usize,
end: usize,
) -> f64 {
if let Some(ref taka) = self.takahashi {
debug_assert_eq!(block.nrows(), end - start);
// Scale factors out linearly of the trace.
return scale * Self::takahashi_block_trace(taka, block, start);
}
self.trace_hinv_block_local_exact(block, scale, start, end)
}
fn trace_hinv_block_local_cross(
&self,
block: &Array2<f64>,
scale: f64,
start: usize,
end: usize,
) -> f64 {
if let Some(ref taka) = self.takahashi {
debug_assert_eq!(block.nrows(), end - start);
let za = Self::takahashi_left_multiply_block(taka, block, start);
// Quadratic in the block, so the scale enters squared.
return scale * scale * trace_matrix_product(&za, &za);
}
self.trace_hinv_block_local_cross_exact(block, scale, start, end)
}
fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
crate::linalg::sparse_exact::solve_sparse_spd(&self.factor, rhs)
.unwrap_or_else(|e| panic!("SparseCholeskyOperator exact solve failed: {e}"))
}
fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
crate::linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, rhs)
.unwrap_or_else(|e| panic!("SparseCholeskyOperator exact multi-solve failed: {e}"))
}
fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
let solved_a = self.solve_multi(a);
// Pointer-equal inputs share one solve.
if std::ptr::eq(a, b) {
return trace_matrix_product(&solved_a, &solved_a);
}
let solved_b = self.solve_multi(b);
trace_matrix_product(&solved_a, &solved_b)
}
fn trace_hinv_matrix_operator_cross(
&self,
matrix: &Array2<f64>,
op: &dyn HyperOperator,
) -> f64 {
self.trace_hinv_matrix_operator_cross_exact(matrix, op)
}
fn trace_hinv_operator_cross(
&self,
left: &dyn HyperOperator,
right: &dyn HyperOperator,
) -> f64 {
if let Some(ref taka) = self.takahashi {
// Selected-inverse fast path only applies when both operators are
// block-local over the identical range.
if let (Some((a_local, a_start, a_end)), Some((b_local, b_start, b_end))) =
(left.block_local_data(), right.block_local_data())
{
if a_start == b_start && a_end == b_end {
let za = Self::takahashi_left_multiply_block(taka, a_local, a_start);
if std::ptr::addr_eq(left, right) {
return trace_matrix_product(&za, &za);
}
let zb = Self::takahashi_left_multiply_block(taka, b_local, b_start);
// tr(ZA·ZB) = Σ (ZA) ∘ (ZB)ᵀ, computed elementwise.
return (&za * &zb.t()).sum();
}
}
}
self.trace_hinv_operator_cross_exact(left, right)
}
fn trace_logdet_hessian_cross_matrix_operator(
&self,
h_i: &Array2<f64>,
h_j: &dyn HyperOperator,
) -> f64 {
// Second-derivative logdet cross term is the negated H⁻¹ cross trace.
-self.trace_hinv_matrix_operator_cross(h_i, h_j)
}
fn trace_logdet_hessian_cross_operator(
&self,
h_i: &dyn HyperOperator,
h_j: &dyn HyperOperator,
) -> f64 {
-self.trace_hinv_operator_cross(h_i, h_j)
}
fn active_rank(&self) -> usize {
// SPD factorization implies full rank.
self.n_dim
}
fn dim(&self) -> usize {
self.n_dim
}
}
/// Hessian operator for a fully coupled (non-block-diagonal) joint Hessian,
/// implemented as a thin wrapper over a dense spectral decomposition.
pub struct BlockCoupledOperator {
// Dense eigendecomposition backing every trait method.
inner: DenseSpectralOperator,
}
impl BlockCoupledOperator {
/// Test-only convenience constructor using the smooth pseudo-logdet mode.
#[cfg(test)]
pub fn from_joint_hessian(joint_hessian: &Array2<f64>) -> Result<Self, String> {
Self::from_joint_hessian_with_mode(joint_hessian, PseudoLogdetMode::Smooth)
}
/// Eigendecomposes the symmetric joint Hessian under the given
/// pseudo-logdet mode; fails with context if the decomposition fails.
pub fn from_joint_hessian_with_mode(
joint_hessian: &Array2<f64>,
mode: PseudoLogdetMode,
) -> Result<Self, String> {
let inner = DenseSpectralOperator::from_symmetric_with_mode(joint_hessian, mode)
.map_err(|e| format!("BlockCoupledOperator eigendecomposition: {e}"))?;
Ok(Self { inner })
}
}
// Pure delegation: every trait method forwards to the inner
// DenseSpectralOperator, which owns the actual spectral algebra.
impl HessianOperator for BlockCoupledOperator {
fn logdet(&self) -> f64 {
self.inner.logdet()
}
fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
self.inner.as_exact_dense_spectral()
}
fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
self.inner.trace_hinv_product(a)
}
fn trace_hinv_h_k(
&self,
a_k: &Array2<f64>,
third_deriv_correction: Option<&Array2<f64>>,
) -> f64 {
self.inner.trace_hinv_h_k(a_k, third_deriv_correction)
}
fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
self.inner.trace_logdet_gradient(a)
}
fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
self.inner.xt_logdet_kernel_x_diagonal(x)
}
fn trace_logdet_h_k(
&self,
a_k: &Array2<f64>,
third_deriv_correction: Option<&Array2<f64>>,
) -> f64 {
self.inner.trace_logdet_h_k(a_k, third_deriv_correction)
}
fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
self.inner.trace_logdet_operator(op)
}
fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
self.inner.trace_logdet_hessian_cross(h_i, h_j)
}
fn trace_logdet_hessian_crosses(&self, matrices: &[&Array2<f64>]) -> Array2<f64> {
self.inner.trace_logdet_hessian_crosses(matrices)
}
fn trace_hinv_block_local_cross(
&self,
block: &Array2<f64>,
scale: f64,
start: usize,
end: usize,
) -> f64 {
self.inner
.trace_hinv_block_local_cross(block, scale, start, end)
}
fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
self.inner.solve(rhs)
}
fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
self.inner.solve_multi(rhs)
}
fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
self.inner.trace_hinv_product_cross(a, b)
}
fn trace_hinv_matrix_operator_cross(
&self,
matrix: &Array2<f64>,
op: &dyn HyperOperator,
) -> f64 {
self.inner.trace_hinv_matrix_operator_cross(matrix, op)
}
fn trace_hinv_operator_cross(
&self,
left: &dyn HyperOperator,
right: &dyn HyperOperator,
) -> f64 {
self.inner.trace_hinv_operator_cross(left, right)
}
fn active_rank(&self) -> usize {
self.inner.active_rank()
}
fn dim(&self) -> usize {
self.inner.dim()
}
fn is_dense(&self) -> bool {
true
}
fn prefers_stochastic_trace_estimation(&self) -> bool {
false
}
fn logdet_traces_match_hinv_kernel(&self) -> bool {
false
}
fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
Some(&self.inner)
}
}
/// Hessian operator defined only by a matvec closure (`apply` = H·v).
/// Exact algebra is obtained by lazily materializing a dense spectral
/// decomposition from `dim` basis matvecs, subject to a memory budget.
pub struct MatrixFreeSpdOperator {
// Matrix-vector product closure; shared so the operator is Send + Sync.
apply: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
// Lazily computed log|H| (from the materialized spectral operator).
cached_logdet: crate::resource::RayonSafeOnce<f64>,
// Dimension of the operator.
n_dim: usize,
// Lazily materialized dense spectral operator; None if materialization
// exceeded the budget or produced non-finite columns.
dense_spectral: crate::resource::RayonSafeOnce<Option<DenseSpectralOperator>>,
}
impl MatrixFreeSpdOperator {
// Hard cap (512 MiB) on the memory estimated for dense materialization.
const EXACT_DENSE_SPECTRAL_MAX_BYTES: usize = 512 * 1024 * 1024;
// Multiplier for how many n×n f64 arrays materialization is budgeted to
// need simultaneously — NOTE(review): assumed worst case; confirm against
// DenseSpectralOperator internals.
const EXACT_DENSE_SPECTRAL_ARRAYS: usize = 6;
/// Wraps a matvec closure into a matrix-free SPD operator of size `dim`.
pub fn new<F>(dim: usize, apply: F) -> Self
where
F: Fn(&Array1<f64>) -> Array1<f64> + Send + Sync + 'static,
{
let apply = Arc::new(apply);
Self {
apply,
cached_logdet: crate::resource::RayonSafeOnce::new(),
n_dim: dim,
dense_spectral: crate::resource::RayonSafeOnce::new(),
}
}
/// Estimated bytes needed to materialize; None on arithmetic overflow.
fn exact_dense_spectral_bytes(&self) -> Option<usize> {
self.n_dim
.checked_mul(self.n_dim)?
.checked_mul(std::mem::size_of::<f64>())?
.checked_mul(Self::EXACT_DENSE_SPECTRAL_ARRAYS)
}
/// Returns true iff the materialization estimate fits the byte cap;
/// logs an error (with GiB figures) otherwise.
fn exact_dense_spectral_budget_ok(&self) -> bool {
match self.exact_dense_spectral_bytes() {
Some(bytes) if bytes <= Self::EXACT_DENSE_SPECTRAL_MAX_BYTES => true,
Some(bytes) => {
log::error!(
"MatrixFreeSpdOperator exact dense spectral materialization requires {:.2} GiB \
for dim={}, exceeding the {:.2} GiB cap",
bytes as f64 / (1024.0 * 1024.0 * 1024.0),
self.n_dim,
Self::EXACT_DENSE_SPECTRAL_MAX_BYTES as f64 / (1024.0 * 1024.0 * 1024.0),
);
false
}
None => {
log::error!(
"MatrixFreeSpdOperator exact dense spectral byte count overflow for dim={}",
self.n_dim
);
false
}
}
}
/// Reconstructs H column-by-column via basis matvecs, symmetrizes it,
/// and eigendecomposes. Returns None when over budget, when a matvec
/// returns the wrong length or non-finite values, or when the
/// decomposition fails.
fn materialize_dense_operator(&self) -> Option<DenseSpectralOperator> {
if !self.exact_dense_spectral_budget_ok() {
return None;
}
let materialize_start = std::time::Instant::now();
let mut matrix = Array2::<f64>::zeros((self.n_dim, self.n_dim));
let mut basis = Array1::<f64>::zeros(self.n_dim);
for j in 0..self.n_dim {
// Probe with e_j, then reset for the next column.
basis[j] = 1.0;
let col = (self.apply)(&basis);
basis[j] = 0.0;
if col.len() != self.n_dim || !col.iter().all(|v| v.is_finite()) {
return None;
}
matrix.column_mut(j).assign(&col);
}
// Force exact symmetry to guard against matvec round-off asymmetry.
for i in 0..self.n_dim {
for j in (i + 1)..self.n_dim {
let avg = 0.5 * (matrix[[i, j]] + matrix[[j, i]]);
matrix[[i, j]] = avg;
matrix[[j, i]] = avg;
}
}
let result = DenseSpectralOperator::from_symmetric(&matrix).ok();
log::info!(
"[STAGE] matrix_free_spd materialize n_dim={} matvec_count={} elapsed={:.3}s",
self.n_dim,
self.n_dim,
materialize_start.elapsed().as_secs_f64(),
);
result
}
/// Lazily materializes (once) and borrows the dense spectral operator.
fn dense_spectral(&self) -> Option<&DenseSpectralOperator> {
self.dense_spectral
.get_or_init(|| self.materialize_dense_operator())
.as_ref()
}
/// Like `dense_spectral` but panics when materialization is unavailable —
/// callers below require exact algebra.
fn exact_dense_spectral(&self) -> &DenseSpectralOperator {
self.dense_spectral().expect(
"MatrixFreeSpdOperator exact REML algebra requires dense spectral materialization within the configured budget",
)
}
}
// All exact algebra forwards to the lazily materialized dense spectral
// operator; methods panic (via `exact_dense_spectral`) when materialization
// is not possible within the memory budget.
impl HessianOperator for MatrixFreeSpdOperator {
fn logdet(&self) -> f64 {
// Computed once from the materialized operator, then cached.
*self
.cached_logdet
.get_or_init(|| self.exact_dense_spectral().logdet())
}
fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
Some(self.exact_dense_spectral())
}
fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
self.exact_dense_spectral().trace_hinv_product(a)
}
fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
self.exact_dense_spectral().trace_hinv_operator(op)
}
fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
self.exact_dense_spectral().trace_hinv_product_cross(a, b)
}
fn trace_hinv_matrix_operator_cross(
&self,
matrix: &Array2<f64>,
op: &dyn HyperOperator,
) -> f64 {
self.exact_dense_spectral()
.trace_hinv_matrix_operator_cross(matrix, op)
}
fn trace_hinv_operator_cross(
&self,
left: &dyn HyperOperator,
right: &dyn HyperOperator,
) -> f64 {
self.exact_dense_spectral()
.trace_hinv_operator_cross(left, right)
}
fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
// Timed because this can trigger the one-off materialization.
let trace_start = std::time::Instant::now();
let result = self.exact_dense_spectral().trace_logdet_operator(op);
log::info!(
"[STAGE] matrix_free_spd trace_logdet_operator implicit={} dim={} elapsed={:.3}s",
op.is_implicit(),
op.dim(),
trace_start.elapsed().as_secs_f64(),
);
result
}
fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
self.exact_dense_spectral().solve(rhs)
}
fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
self.exact_dense_spectral().solve_multi(rhs)
}
fn stochastic_trace_solve(&self, rhs: &Array1<f64>, rel_tol: f64) -> Array1<f64> {
// Exact solve; the tolerance is irrelevant for this backend.
let _ = rel_tol;
self.solve(rhs)
}
fn stochastic_trace_solve_multi(&self, rhs: &Array2<f64>, rel_tol: f64) -> Array2<f64> {
let _ = rel_tol;
self.solve_multi(rhs)
}
fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
self.exact_dense_spectral()
.trace_logdet_hessian_cross(h_i, h_j)
}
fn trace_logdet_hessian_cross_matrix_operator(
&self,
h_i: &Array2<f64>,
h_j: &dyn HyperOperator,
) -> f64 {
self.exact_dense_spectral()
.trace_logdet_hessian_cross_matrix_operator(h_i, h_j)
}
fn trace_logdet_hessian_cross_operator(
&self,
h_i: &dyn HyperOperator,
h_j: &dyn HyperOperator,
) -> f64 {
self.exact_dense_spectral()
.trace_logdet_hessian_cross_operator(h_i, h_j)
}
fn trace_logdet_hessian_crosses(&self, matrices: &[&Array2<f64>]) -> Array2<f64> {
self.exact_dense_spectral()
.trace_logdet_hessian_crosses(matrices)
}
fn active_rank(&self) -> usize {
self.n_dim
}
fn dim(&self) -> usize {
self.n_dim
}
fn is_dense(&self) -> bool {
true
}
fn prefers_stochastic_trace_estimation(&self) -> bool {
false
}
fn logdet_traces_match_hinv_kernel(&self) -> bool {
false
}
fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
self.dense_spectral()
}
}
/// Computes a `rank × n` root `R` of a symmetric PSD penalty matrix `S`
/// (so that `Rᵀ·R` reconstructs `S` on its range), discarding eigenpairs
/// whose eigenvalue falls below a relative tolerance.
///
/// Errors if `S` is not square or the eigendecomposition fails; a 0×0
/// input yields a 0×0 result.
pub fn penalty_matrix_root(s: &Array2<f64>) -> Result<Array2<f64>, String> {
    use faer::Side;
    let n = s.nrows();
    if s.ncols() != n {
        return Err(format!(
            "penalty_matrix_root: expected square matrix, got {}×{}",
            n,
            s.ncols()
        ));
    }
    if n == 0 {
        return Ok(Array2::zeros((0, 0)));
    }
    let (eigenvalues, eigenvectors) = s
        .eigh(Side::Lower)
        .map_err(|e| format!("penalty_matrix_root eigendecomposition failed: {e}"))?;
    // Relative cutoff scaled by the largest eigenvalue, floored at 1e-12 so
    // a numerically zero S still gets a positive tolerance.
    let max_ev = eigenvalues.iter().copied().fold(0.0_f64, f64::max);
    let tol = (n.max(1) as f64) * f64::EPSILON * max_ev.max(1e-12);
    let active: Vec<usize> = (0..eigenvalues.len())
        .filter(|&i| eigenvalues[i] > tol)
        .collect();
    let mut r = Array2::zeros((active.len(), n));
    for (out_row, &idx) in active.iter().enumerate() {
        // Row = sqrt(λ) · (eigenvector for λ), read from column `idx`.
        let scale = eigenvalues[idx].sqrt();
        r.row_mut(out_row)
            .iter_mut()
            .zip(eigenvectors.column(idx).iter())
            .for_each(|(dst, &v)| *dst = scale * v);
    }
    Ok(r)
}
/// Computes the penalty pseudo-logdet, its first derivatives, and its second
/// derivatives with respect to per-block log smoothing parameters (ρ), then
/// assembles the per-block results into one block-diagonal layout.
///
/// * `per_block_rho` — log smoothing parameters, one vector per block.
/// * `per_block_penalties` — penalty matrices per block (parallel to rho).
/// * `per_block_nullspace_dims` — optional per-penalty nullspace dimensions;
///   when complete for a block, an exact intersection nullity is supplied.
/// * `ridge` — ridge value forwarded to the pseudo-logdet construction.
///
/// Returns an error if any block's pseudo-logdet construction fails.
pub fn compute_block_penalty_logdet_derivs(
per_block_rho: &[Array1<f64>],
per_block_penalties: &[&[Array2<f64>]],
per_block_nullspace_dims: &[&[usize]],
ridge: f64,
) -> Result<PenaltyLogdetDerivs, String> {
use super::penalty_logdet::PenaltyPseudologdet;
let total_k: usize = per_block_rho.iter().map(|r| r.len()).sum();
// Starting index of each block's parameters in the concatenated layout.
let block_offsets: Vec<usize> = per_block_rho
.iter()
.scan(0usize, |at, rho| {
let current = *at;
*at += rho.len();
Some(current)
})
.collect();
// Per-block result: value plus first/second ρ-derivatives, tagged with
// the block's offset into the assembled output.
struct BlockPenaltyLogdetResult {
offset: usize,
value: f64,
first: Array1<f64>,
second: Array2<f64>,
}
let compute_block = |(b, block_rho): (usize, &Array1<f64>)| {
let penalties = per_block_penalties[b];
let kb = block_rho.len();
// Blocks with no penalties contribute zeros of the right shape.
if penalties.is_empty() || kb == 0 {
return Ok(BlockPenaltyLogdetResult {
offset: block_offsets[b],
value: 0.0,
first: Array1::zeros(kb),
second: Array2::zeros((kb, kb)),
});
}
let lambdas: Vec<f64> = block_rho.iter().map(|&r| r.exp()).collect();
let block_nullspace_dims = if b < per_block_nullspace_dims.len() {
per_block_nullspace_dims[b]
} else {
&[]
};
// Only pass an exact structural nullity when nullspace dims are
// provided for every penalty in the block.
let structural_nullity =
if !block_nullspace_dims.is_empty() && block_nullspace_dims.len() == penalties.len() {
Some(exact_intersection_nullity(penalties, block_nullspace_dims))
} else {
None
};
let pld = PenaltyPseudologdet::from_components_with_nullity(
penalties,
&lambdas,
ridge,
structural_nullity,
)
.map_err(|e| format!("penalty logdet failed for block {b}: {e}"))?;
let value = pld.value();
let (first, second) = pld.rho_derivatives(penalties, &lambdas);
Ok(BlockPenaltyLogdetResult {
offset: block_offsets[b],
value,
first,
second,
})
};
// Run sequentially when already on a rayon worker thread (avoids nested
// pool parallelism); otherwise process blocks in parallel.
let block_results: Vec<BlockPenaltyLogdetResult> = if rayon::current_thread_index().is_some() {
per_block_rho
.iter()
.enumerate()
.map(compute_block)
.collect::<Result<Vec<_>, String>>()?
} else {
per_block_rho
.par_iter()
.enumerate()
.map(compute_block)
.collect::<Result<Vec<_>, String>>()?
};
// Assemble: sum the values; scatter first derivatives and block-diagonal
// second derivatives at each block's offset.
let mut log_det_total = 0.0;
let mut first = Array1::zeros(total_k);
let mut second = Array2::zeros((total_k, total_k));
for block in block_results {
log_det_total += block.value;
let kb = block.first.len();
for k in 0..kb {
first[block.offset + k] = block.first[k];
}
for k in 0..kb {
for l in 0..kb {
second[[block.offset + k, block.offset + l]] = block.second[[k, l]];
}
}
}
Ok(PenaltyLogdetDerivs {
value: log_det_total,
first,
second: Some(second),
})
}
/// Tuning knobs for stochastic (Hutchinson-style) trace estimation.
#[derive(Clone, Debug)]
pub struct StochasticTraceConfig {
// Minimum number of probes before convergence is checked.
pub n_probes_min: usize,
// Hard cap on the number of probes.
pub n_probes_max: usize,
// Relative convergence tolerance applied to the running probe
// statistics — NOTE(review): exact criterion lives in check_convergence.
pub relative_tol: f64,
// Relative floor used in the convergence test — TODO confirm semantics.
pub tau_rel: f64,
// Tolerance forwarded to the operator's stochastic_trace_solve.
pub solve_rel_tol: f64,
// Seed for the deterministic Rademacher probe RNG.
pub seed: u64,
// Optional Hutch++ sketch dimension; None selects plain Hutchinson.
pub hutchpp_sketch_dim: Option<usize>,
}
impl Default for StochasticTraceConfig {
// Conservative general-purpose defaults: up to 200 probes, 1% relative
// tolerance, tight solve tolerance, fixed seed for reproducibility.
fn default() -> Self {
Self {
n_probes_min: 10,
n_probes_max: 200,
relative_tol: 0.01,
tau_rel: 1e-8,
solve_rel_tol: 1e-8,
seed: 0xCAFE_BABE,
hutchpp_sketch_dim: None,
}
}
}
impl StochasticTraceConfig {
    /// Preset tuned for outer-Hessian trace estimation: large problems
    /// (high dimension or many coordinates) get fewer probes and looser
    /// tolerances to bound cost.
    fn outer_hessian(dim: usize, n_coords: usize) -> Self {
        let large_problem = dim >= 512 || n_coords >= 4;
        // Size-dependent knobs bound together to keep the branches in sync.
        let (n_probes_min, n_probes_max, relative_tol, solve_rel_tol) = if large_problem {
            (4, 8, 0.12, 1e-4)
        } else {
            (6, 24, 0.05, 1e-5)
        };
        Self {
            n_probes_min,
            n_probes_max,
            relative_tol,
            tau_rel: 1e-3,
            solve_rel_tol,
            seed: 0xC0A5_7ACE,
            hutchpp_sketch_dim: None,
        }
    }
}
/// Driver for stochastic trace estimation against a `HessianOperator`,
/// parameterized by a `StochasticTraceConfig`.
pub struct StochasticTraceEstimator {
config: StochasticTraceConfig,
}
/// The set of trace targets handled in one probe pass; traces are ordered
/// dense matrices first, then operators.
enum StochasticTraceTargets<'a> {
// Dense matrices only.
Dense(&'a [&'a Array2<f64>]),
// Dense matrices followed by generic hyper-operators.
Mixed {
dense_matrices: &'a [&'a Array2<f64>],
operators: &'a [&'a dyn HyperOperator],
},
// Dense matrices plus implicit operators that share one design matrix,
// allowing X·z / X·w to be computed once per probe.
Structural {
dense_matrices: &'a [&'a Array2<f64>],
implicit_ops: &'a [&'a ImplicitHyperOperator],
},
}
impl StochasticTraceTargets<'_> {
    /// Total number of trace targets (dense matrices plus operators).
    fn len(&self) -> usize {
        let (dense, ops) = match self {
            Self::Dense(matrices) => (matrices.len(), 0),
            Self::Mixed {
                dense_matrices,
                operators,
            } => (dense_matrices.len(), operators.len()),
            Self::Structural {
                dense_matrices,
                implicit_ops,
            } => (dense_matrices.len(), implicit_ops.len()),
        };
        dense + ops
    }
}
impl StochasticTraceEstimator {
/// Builds an estimator with an explicit configuration.
pub fn new(config: StochasticTraceConfig) -> Self {
Self { config }
}
/// Builds an estimator with the default configuration.
pub fn with_defaults() -> Self {
Self::new(StochasticTraceConfig::default())
}
/// Builds an estimator using the outer-Hessian preset (fewer probes and
/// looser tolerances for large problems).
fn for_outer_hessian(dim: usize, n_coords: usize) -> Self {
Self::new(StochasticTraceConfig::outer_hessian(dim, n_coords))
}
/// Core Hutchinson loop for vector-valued estimates. For each Rademacher
/// probe z it computes w = H⁻¹z (via the operator's stochastic solve),
/// asks the caller to fill per-coordinate probe values, and maintains a
/// Welford running mean/variance per coordinate. Convergence is checked
/// every 4 probes once `n_probes_min` is reached; returns the means.
fn estimate_from_probe_batch<F>(
&self,
hop: &dyn HessianOperator,
n_coords: usize,
mut evaluate_probe: F,
) -> Vec<f64>
where
F: FnMut(&Array1<f64>, &Array1<f64>, &mut [f64]),
{
if n_coords == 0 {
return Vec::new();
}
let p = hop.dim();
if p == 0 {
// Degenerate operator: every trace is zero.
return vec![0.0; n_coords];
}
let mut means = vec![0.0_f64; n_coords];
let mut m2s = vec![0.0_f64; n_coords];
let mut probe_values = vec![0.0_f64; n_coords];
// Deterministic RNG so repeated runs produce identical estimates.
let mut rng_state = Xoshiro256SS::from_seed(self.config.seed);
let check_interval = 4;
let mut z = Array1::<f64>::zeros(p);
for m in 0..self.config.n_probes_max {
rademacher_probe_into(z.view_mut(), &mut rng_state);
let w = hop.stochastic_trace_solve(&z, self.config.solve_rel_tol);
evaluate_probe(&z, &w, &mut probe_values);
// Welford's online update of mean and M2 per coordinate.
for k in 0..n_coords {
let q_k = probe_values[k];
let count = (m + 1) as f64;
let delta = q_k - means[k];
means[k] += delta / count;
let delta2 = q_k - means[k];
m2s[k] += delta * delta2;
}
let n_done = m + 1;
if n_done >= self.config.n_probes_min && n_done % check_interval == 0 {
if self.check_convergence(n_done, &means, &m2s) {
break;
}
}
}
means
}
/// Matrix-valued counterpart of `estimate_from_probe_batch`: the caller
/// fills an (n_coords × n_coords) probe matrix per Rademacher probe and a
/// Welford mean/variance is tracked per entry. The accumulated mean is
/// symmetrized before returning (the estimand is symmetric; individual
/// probes are not).
fn estimate_matrix_from_probe_batch<F>(
&self,
hop: &dyn HessianOperator,
n_coords: usize,
mut evaluate_probe: F,
) -> Array2<f64>
where
F: FnMut(&Array1<f64>, &mut Array2<f64>),
{
if n_coords == 0 {
return Array2::zeros((0, 0));
}
let p = hop.dim();
if p == 0 {
return Array2::zeros((n_coords, n_coords));
}
let mut means = Array2::<f64>::zeros((n_coords, n_coords));
let mut m2s = Array2::<f64>::zeros((n_coords, n_coords));
let mut probe_values = Array2::<f64>::zeros((n_coords, n_coords));
let mut rng_state = Xoshiro256SS::from_seed(self.config.seed);
let check_interval = 4;
let mut z = Array1::<f64>::zeros(p);
for m in 0..self.config.n_probes_max {
rademacher_probe_into(z.view_mut(), &mut rng_state);
probe_values.fill(0.0);
evaluate_probe(&z, &mut probe_values);
// Welford update per matrix entry.
let count = (m + 1) as f64;
for d in 0..n_coords {
for e in 0..n_coords {
let q = probe_values[[d, e]];
let delta = q - means[[d, e]];
means[[d, e]] += delta / count;
let delta2 = q - means[[d, e]];
m2s[[d, e]] += delta * delta2;
}
}
let n_done = m + 1;
if n_done >= self.config.n_probes_min
&& n_done % check_interval == 0
&& self.check_matrix_convergence(n_done, &means, &m2s)
{
break;
}
}
// Symmetrize the estimate: average each off-diagonal pair.
for d in 0..n_coords {
for e in (d + 1)..n_coords {
let avg = 0.5 * (means[[d, e]] + means[[e, d]]);
means[[d, e]] = avg;
means[[e, d]] = avg;
}
}
means
}
/// Estimates tr(H⁻¹·Aₖ) for every target Aₖ in one shared probe pass,
/// dispatching on the target kind. Each probe contributes zᵀ·Aₖ·w with
/// w = H⁻¹z; operators may provide a faster bilinear form, and the
/// structural path computes X·z and X·w once per probe for all implicit
/// operators sharing one design matrix.
fn estimate_hinv_traces(
&self,
hop: &dyn HessianOperator,
targets: StochasticTraceTargets<'_>,
) -> Vec<f64> {
let n_coords = targets.len();
if n_coords == 0 {
return Vec::new();
}
match targets {
StochasticTraceTargets::Dense(matrices) => {
// Scratch vector reused for every Aₖ·w product.
let mut a_w = Array1::<f64>::zeros(hop.dim());
self.estimate_from_probe_batch(hop, n_coords, |z, w, probe_values| {
for k in 0..matrices.len() {
dense_matvec_into(matrices[k], w.view(), a_w.view_mut());
probe_values[k] = z.dot(&a_w);
}
})
}
StochasticTraceTargets::Mixed {
dense_matrices,
operators,
} => {
let mut a_w = Array1::<f64>::zeros(hop.dim());
self.estimate_from_probe_batch(hop, n_coords, |z, w, probe_values| {
for k in 0..dense_matrices.len() {
dense_matvec_into(dense_matrices[k], w.view(), a_w.view_mut());
probe_values[k] = z.dot(&a_w);
}
let dense_count = dense_matrices.len();
for (oi, op) in operators.iter().enumerate() {
let k = dense_count + oi;
// Prefer the operator's direct bilinear form when it is
// cheaper than materializing Op·w.
if op.has_fast_bilinear_view() {
probe_values[k] = op.bilinear_view(w.view(), z.view());
} else {
op.mul_vec_into(w.view(), a_w.view_mut());
probe_values[k] = z.dot(&a_w);
}
}
})
}
StochasticTraceTargets::Structural {
dense_matrices,
implicit_ops,
} => {
// No implicit operators: degrade to the Mixed path.
if implicit_ops.is_empty() {
let no_ops: [&dyn HyperOperator; 0] = [];
return self.estimate_hinv_traces(
hop,
StochasticTraceTargets::Mixed {
dense_matrices,
operators: &no_ops,
},
);
}
// All implicit operators are assumed to share the first one's
// design matrix — NOTE(review): confirm this invariant at call
// sites constructing Structural targets.
let x_design = implicit_ops[0].x_design.clone();
let mut x_vec = Array1::<f64>::zeros(x_design.nrows());
let mut y_vec = Array1::<f64>::zeros(x_design.nrows());
let mut a_w = Array1::<f64>::zeros(hop.dim());
self.estimate_from_probe_batch(hop, n_coords, |z, w, probe_values| {
// X·z and X·w computed once, shared by all implicit ops.
design_matrix_apply_view_into(x_design.as_ref(), z.view(), x_vec.view_mut());
design_matrix_apply_view_into(x_design.as_ref(), w.view(), y_vec.view_mut());
for k in 0..dense_matrices.len() {
dense_matvec_into(dense_matrices[k], w.view(), a_w.view_mut());
probe_values[k] = z.dot(&a_w);
}
let dense_count = dense_matrices.len();
for (oi, op) in implicit_ops.iter().enumerate() {
let k = dense_count + oi;
probe_values[k] = op.bilinear_with_shared_x(&x_vec, &y_vec, z, w);
}
})
}
}
}
/// Estimate tr(H^{-1} A) for a single dense matrix `A`.
pub fn estimate_single_trace(&self, hop: &dyn HessianOperator, matrix: &Array2<f64>) -> f64 {
    // Wrap the lone target so it can ride the batched Dense path.
    let singleton: [&Array2<f64>; 1] = [matrix];
    let traces = self.estimate_hinv_traces(hop, StochasticTraceTargets::Dense(&singleton));
    traces[0]
}
/// Estimate tr(H^{-1} A_k) for each dense matrix in `matrices`.
pub fn estimate_traces(
    &self,
    hop: &dyn HessianOperator,
    matrices: &[&Array2<f64>],
) -> Vec<f64> {
    let targets = StochasticTraceTargets::Dense(matrices);
    self.estimate_hinv_traces(hop, targets)
}
/// Estimate tr(H^{-1} A_k) for a mixed batch: dense matrices followed by
/// matrix-free operators (results are ordered dense-first).
pub fn estimate_traces_with_operators(
    &self,
    hop: &dyn HessianOperator,
    dense_matrices: &[&Array2<f64>],
    operators: &[&dyn HyperOperator],
) -> Vec<f64> {
    let targets = StochasticTraceTargets::Mixed {
        dense_matrices,
        operators,
    };
    self.estimate_hinv_traces(hop, targets)
}
/// Estimate tr(H^{-1} A_k) for dense matrices plus implicit (structural)
/// operators that share a design matrix (results are ordered dense-first).
pub fn estimate_traces_structural(
    &self,
    hop: &dyn HessianOperator,
    dense_matrices: &[&Array2<f64>],
    implicit_ops: &[&ImplicitHyperOperator],
) -> Vec<f64> {
    let targets = StochasticTraceTargets::Structural {
        dense_matrices,
        implicit_ops,
    };
    self.estimate_hinv_traces(hop, targets)
}
/// Estimate the full matrix of second-order cross traces for a batch of
/// dense matrices and implicit operators. Entry (d, e) accumulates probe
/// values of the form (A_d H^{-1} z) · (H^{-1} A_e z); targets are ordered
/// dense-first, implicit ops after. Returns a `total x total` matrix whose
/// off-diagonal symmetrization happens in the probe-batch driver.
pub fn estimate_second_order_traces(
    &self,
    hop: &dyn HessianOperator,
    dense_matrices: &[&Array2<f64>],
    implicit_ops: &[&ImplicitHyperOperator],
) -> Array2<f64> {
    let n_dense = dense_matrices.len();
    let n_ops = implicit_ops.len();
    let total = n_dense + n_ops;
    if total == 0 {
        return Array2::zeros((0, 0));
    }
    let p = hop.dim();
    if p == 0 {
        return Array2::zeros((total, total));
    }
    if total == 1 {
        // Single-target case has cheaper specialized paths.
        let value = if n_dense == 1 {
            self.estimate_second_order_single_dense(hop, dense_matrices[0])
        } else {
            self.estimate_second_order_single_implicit(hop, implicit_ops[0])
        };
        return Array2::from_elem((1, 1), value);
    }
    // Design matrix taken from the first implicit op and shared by all of
    // them — assumes every implicit op uses the same design (TODO confirm).
    let x_design = if n_ops > 0 {
        Some(implicit_ops[0].x_design.clone())
    } else {
        None
    };
    // Scratch buffers reused across probes.
    let mut q_columns = Array2::zeros((p, total));
    let mut dense_a_u: Vec<Array1<f64>> = (0..n_dense).map(|_| Array1::zeros(p)).collect();
    let n_obs = implicit_ops.first().map(|op| op.w_diag.len()).unwrap_or(0);
    let mut x_vec = Array1::<f64>::zeros(n_obs);
    let mut y_vec = Array1::<f64>::zeros(n_obs);
    let mut x_r: Vec<Array1<f64>> = (0..total).map(|_| Array1::zeros(n_obs)).collect();
    // Per-implicit-op quantities that depend only on the current probe,
    // computed once and reused for every pairing (d, e).
    struct ImplicitSecondOrderScratch {
        w_dx_u: Array1<f64>,
        w_y: Array1<f64>,
        u_s: Array1<f64>,
    }
    self.estimate_matrix_from_probe_batch(hop, total, |z, probe_values| {
        // u = H^{-1} z for this probe.
        let u = hop.stochastic_trace_solve(z, self.config.solve_rel_tol);
        if let Some(ref x) = x_design {
            design_matrix_apply_view_into(x.as_ref(), z.view(), x_vec.view_mut());
        }
        {
            // Fill column e of Q with A_e z, in parallel over targets;
            // implicit columns allocate their own per-thread work buffers.
            use ndarray::Axis;
            use ndarray::parallel::prelude::*;
            q_columns
                .axis_iter_mut(Axis(1))
                .into_par_iter()
                .enumerate()
                .for_each(|(e, q_col)| {
                    if e < n_dense {
                        dense_matvec_into(dense_matrices[e], z.view(), q_col);
                    } else {
                        let op = implicit_ops[e - n_dense];
                        let mut n_work = Array1::<f64>::zeros(n_obs);
                        let mut p_work = Array1::<f64>::zeros(p);
                        op.matvec_with_shared_xz_into(
                            &x_vec,
                            z.view(),
                            q_col,
                            n_work.view_mut(),
                            p_work.view_mut(),
                        );
                    }
                });
        }
        // R = H^{-1} Q, one multi-RHS solve for all targets.
        let r = hop.stochastic_trace_solve_multi(&q_columns, self.config.solve_rel_tol);
        if let Some(ref x) = x_design {
            design_matrix_apply_view_into(x.as_ref(), u.view(), y_vec.view_mut());
        }
        for d in 0..n_dense {
            dense_matvec_into(dense_matrices[d], u.view(), dense_a_u[d].view_mut());
        }
        if let Some(ref x) = x_design {
            // X R_e for each column, needed by the implicit-op bilinear form.
            use rayon::prelude::*;
            x_r.par_iter_mut().enumerate().for_each(|(e, x_r_e)| {
                design_matrix_apply_view_into(x.as_ref(), r.column(e), x_r_e.view_mut());
            });
        }
        let implicit_scratch: Vec<ImplicitSecondOrderScratch> = {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            (0..n_ops)
                .into_par_iter()
                .map(|idx| {
                    let op = implicit_ops[idx];
                    let dx_u = op
                        .implicit_deriv
                        .forward_mul(op.axis, &u.view())
                        .expect(
                            "radial scalar evaluation failed during implicit derivative forward_mul",
                        );
                    let w = &*op.w_diag;
                    let mut w_dx_u = Array1::<f64>::zeros(n_obs);
                    let mut w_y = Array1::<f64>::zeros(n_obs);
                    for i in 0..w.len() {
                        w_dx_u[i] = w[i] * dx_u[i];
                        w_y[i] = w[i] * y_vec[i];
                    }
                    // Penalty part: u^T S_psi, precomputed for the dot with R_e.
                    let mut u_s = Array1::<f64>::zeros(p);
                    dense_transpose_matvec_into(&op.s_psi, u.view(), u_s.view_mut());
                    ImplicitSecondOrderScratch { w_dx_u, w_y, u_s }
                })
                .collect()
        };
        // Evaluate every (d, e) pairing in parallel, then write back
        // sequentially (probe_values cannot be shared mutably).
        let pairs: Vec<(usize, usize)> = (0..total)
            .flat_map(|d| (0..total).map(move |e| (d, e)))
            .collect();
        let pair_values: Vec<(usize, usize, f64)> = {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            pairs
                .into_par_iter()
                .map(|(d, e)| {
                    let r_e = r.column(e);
                    let val = if d < n_dense {
                        // Dense row: (A_d u) · R_e.
                        dense_a_u[d].dot(&r_e)
                    } else {
                        let oi = d - n_dense;
                        let op = implicit_ops[oi];
                        let scratch = &implicit_scratch[oi];
                        let x_re = &x_r[e];
                        let dx_re = op
                            .implicit_deriv
                            .forward_mul(op.axis, &r_e)
                            .expect(
                                "radial scalar evaluation failed during implicit derivative forward_mul",
                            );
                        // Design contribution assembled from the shared
                        // probe-level scratch plus the optional c-term.
                        let mut design_val = 0.0f64;
                        for i in 0..scratch.w_dx_u.len() {
                            design_val += scratch.w_dx_u[i] * x_re[i];
                            design_val += scratch.w_y[i] * dx_re[i];
                        }
                        if let Some(c_x_psi_beta) = op.c_x_psi_beta.as_ref() {
                            let c = c_x_psi_beta.as_ref();
                            for i in 0..scratch.w_dx_u.len() {
                                design_val += y_vec[i] * c[i] * x_re[i];
                            }
                        }
                        let penalty_val = scratch.u_s.dot(&r_e);
                        design_val + penalty_val
                    };
                    (d, e, val)
                })
                .collect()
        };
        for (d, e, val) in pair_values {
            probe_values[[d, e]] = val;
        }
    })
}
/// Second-order cross-trace matrix for dense matrices plus generic
/// matrix-free operators. Entry (d, e) accumulates (A_d u) · (H^{-1} A_e z)
/// with u = H^{-1} z; the mirror entry (e, d) is computed explicitly with
/// the roles of d and e swapped rather than assumed symmetric per probe.
pub fn estimate_second_order_traces_with_operators(
    &self,
    hop: &dyn HessianOperator,
    dense_matrices: &[&Array2<f64>],
    operators: &[&dyn HyperOperator],
) -> Array2<f64> {
    let n_dense = dense_matrices.len();
    let n_ops = operators.len();
    let total = n_dense + n_ops;
    if total == 0 {
        return Array2::zeros((0, 0));
    }
    let p = hop.dim();
    if p == 0 {
        return Array2::zeros((total, total));
    }
    if total == 1 {
        // Cheaper single-target specializations.
        let value = if n_dense == 1 {
            self.estimate_second_order_single_dense(hop, dense_matrices[0])
        } else {
            self.estimate_second_order_single_operator(hop, operators[0])
        };
        return Array2::from_elem((1, 1), value);
    }
    // Per-probe scratch: Q holds A_e z, a_u_columns holds A_e u.
    let mut q_columns = Array2::zeros((p, total));
    let mut a_u_columns = Array2::zeros((p, total));
    self.estimate_matrix_from_probe_batch(hop, total, |z, probe_values| {
        let u = hop.stochastic_trace_solve(z, self.config.solve_rel_tol);
        for e in 0..n_dense {
            dense_matvec_into(dense_matrices[e], z.view(), q_columns.column_mut(e));
            dense_matvec_into(dense_matrices[e], u.view(), a_u_columns.column_mut(e));
        }
        for (oi, op) in operators.iter().enumerate() {
            let e = n_dense + oi;
            op.mul_vec_into(z.view(), q_columns.column_mut(e));
            op.mul_vec_into(u.view(), a_u_columns.column_mut(e));
        }
        // R = H^{-1} Q in one multi-RHS solve.
        let r = hop.stochastic_trace_solve_multi(&q_columns, self.config.solve_rel_tol);
        for d in 0..total {
            let a_d_u = a_u_columns.column(d);
            for e in d..total {
                let r_e = r.column(e);
                let val = a_d_u.dot(&r_e);
                probe_values[[d, e]] = val;
                if d != e {
                    // Mirror entry evaluated independently (the two probe
                    // values differ per probe; only expectations agree).
                    let r_d = r.column(d);
                    let val_sym = a_u_columns.column(e).dot(&r_d);
                    probe_values[[e, d]] = val_sym;
                }
            }
        }
    })
}
/// Single-target second-order trace for a dense matrix. When a Hutch++
/// sketch dimension is configured, delegates to the variance-reduced
/// squared-trace estimator; otherwise runs the plain probe loop with
/// probe value u^T A r where u = H^{-1} z and r = H^{-1} A z.
fn estimate_second_order_single_dense(
    &self,
    hop: &dyn HessianOperator,
    matrix: &Array2<f64>,
) -> f64 {
    let p = hop.dim();
    if p == 0 {
        return 0.0;
    }
    if self.config.hutchpp_sketch_dim.is_some() {
        // Wrap the dense matrix so the Hutch++ path can apply it
        // through the HyperOperator interface.
        let op = DenseMatrixHyperOperator {
            matrix: matrix.clone(),
        };
        return hutchpp_estimate_trace_hinv_op_squared(hop, &op, &self.config);
    }
    // Scratch for A z, reused across probes.
    let mut q = Array1::<f64>::zeros(p);
    self.estimate_matrix_from_probe_batch(hop, 1, |z, probe_values| {
        let u = hop.stochastic_trace_solve(z, self.config.solve_rel_tol);
        dense_matvec_into(matrix, z.view(), q.view_mut());
        let r = hop.stochastic_trace_solve(&q, self.config.solve_rel_tol);
        probe_values[[0, 0]] = dense_bilinear(matrix, u.view(), r.view());
    })[[0, 0]]
}
/// Single-target second-order trace for an implicit operator. Mirrors the
/// dense path but evaluates the operator's bilinear form from its
/// structural pieces (design matrix, implicit derivative, diagonal
/// weights, optional c-term, and penalty S_psi) instead of a matvec.
fn estimate_second_order_single_implicit(
    &self,
    hop: &dyn HessianOperator,
    op: &ImplicitHyperOperator,
) -> f64 {
    let p = hop.dim();
    if p == 0 {
        return 0.0;
    }
    if self.config.hutchpp_sketch_dim.is_some() {
        // Variance-reduced route when a sketch dimension is configured.
        return hutchpp_estimate_trace_hinv_op_squared(hop, op, &self.config);
    }
    // Observation-space and coefficient-space scratch, reused per probe.
    let n_obs = op.w_diag.len();
    let mut x_z = Array1::<f64>::zeros(n_obs);
    let mut x_u = Array1::<f64>::zeros(n_obs);
    let mut x_r = Array1::<f64>::zeros(n_obs);
    let mut n_work = Array1::<f64>::zeros(n_obs);
    let mut p_work = Array1::<f64>::zeros(p);
    let mut q = Array1::<f64>::zeros(p);
    self.estimate_matrix_from_probe_batch(hop, 1, |z, probe_values| {
        // u = H^{-1} z, q = A z (via the shared-Xz matvec), r = H^{-1} q.
        let u = hop.stochastic_trace_solve(z, self.config.solve_rel_tol);
        design_matrix_apply_view_into(&op.x_design, z.view(), x_z.view_mut());
        op.matvec_with_shared_xz_into(
            &x_z,
            z.view(),
            q.view_mut(),
            n_work.view_mut(),
            p_work.view_mut(),
        );
        let r = hop.stochastic_trace_solve(&q, self.config.solve_rel_tol);
        design_matrix_apply_view_into(&op.x_design, u.view(), x_u.view_mut());
        design_matrix_apply_view_into(&op.x_design, r.view(), x_r.view_mut());
        let dx_u = op
            .implicit_deriv
            .forward_mul(op.axis, &u.view())
            .expect("radial scalar evaluation failed during implicit derivative forward_mul");
        let dx_r = op
            .implicit_deriv
            .forward_mul(op.axis, &r.view())
            .expect("radial scalar evaluation failed during implicit derivative forward_mul");
        let w = &*op.w_diag;
        // Symmetrized design contribution over observations.
        let mut value = 0.0;
        for i in 0..w.len() {
            let wi = w[i];
            value += dx_u[i] * wi * x_r[i];
            value += x_u[i] * wi * dx_r[i];
        }
        if let Some(c_x_psi_beta) = op.c_x_psi_beta.as_ref() {
            // Optional curvature correction term.
            let c = c_x_psi_beta.as_ref();
            for i in 0..w.len() {
                value += x_u[i] * c[i] * x_r[i];
            }
        }
        // Penalty contribution r^T S_psi u.
        value += dense_bilinear(&op.s_psi, r.view(), u.view());
        probe_values[[0, 0]] = value;
    })[[0, 0]]
}
/// Single-target second-order trace for a generic matrix-free operator:
/// each probe contributes (A H^{-1} z) · (H^{-1} A z).
fn estimate_second_order_single_operator(
    &self,
    hop: &dyn HessianOperator,
    op: &dyn HyperOperator,
) -> f64 {
    let dim = hop.dim();
    if dim == 0 {
        return 0.0;
    }
    // Scratch reused per probe: A z and A (H^{-1} z).
    let mut a_z = Array1::<f64>::zeros(dim);
    let mut a_hinv_z = Array1::<f64>::zeros(dim);
    let result = self.estimate_matrix_from_probe_batch(hop, 1, |z, probe_values| {
        let hinv_z = hop.stochastic_trace_solve(z, self.config.solve_rel_tol);
        op.mul_vec_into(z.view(), a_z.view_mut());
        op.mul_vec_into(hinv_z.view(), a_hinv_z.view_mut());
        let hinv_a_z = hop.stochastic_trace_solve(&a_z, self.config.solve_rel_tol);
        probe_values[[0, 0]] = a_hinv_z.dot(&hinv_a_z);
    });
    result[[0, 0]]
}
/// Stopping rule for the scalar-target probe loop: converged when every
/// coordinate's standard error of the running mean (from the Welford M2
/// accumulators) is within the configured relative tolerance.
fn check_convergence(&self, n: usize, means: &[f64], m2s: &[f64]) -> bool {
    if n < 2 {
        return false;
    }
    let n_f = n as f64;
    let sqrt_n = n_f.sqrt();
    means.iter().zip(m2s).all(|(&mean, &m2)| {
        let variance = m2 / (n_f - 1.0);
        let std_dev = variance.max(0.0).sqrt();
        // tau_rel floors the denominator so near-zero means don't demand
        // an absurdly tight absolute error.
        let denom = sqrt_n * mean.abs().max(self.config.tau_rel);
        let rel_err = std_dev / denom;
        // Negated comparison keeps the original NaN behavior: a NaN
        // relative error does not block convergence.
        !(rel_err > self.config.relative_tol)
    })
}
/// Stopping rule for the matrix-target probe loop: converged when the
/// standard error of every entry's running mean is within tolerance.
/// NOTE(review): unlike `check_convergence`, the denominator floor here is
/// scaled by the largest |mean| across the whole matrix (so small entries
/// are judged relative to the dominant one) — confirm this asymmetry with
/// the scalar variant is intentional.
fn check_matrix_convergence(&self, n: usize, means: &Array2<f64>, m2s: &Array2<f64>) -> bool {
    if n < 2 {
        return false;
    }
    let sqrt_n = (n as f64).sqrt();
    let n_f = n as f64;
    // Global scale: max(|mean|, 1) * tau_rel over all entries.
    let scale_floor = means
        .iter()
        .fold(0.0_f64, |acc, &value| acc.max(value.abs()))
        .max(1.0)
        * self.config.tau_rel;
    for ((d, e), &mean) in means.indexed_iter() {
        let variance = m2s[[d, e]] / (n_f - 1.0);
        let std_dev = variance.max(0.0).sqrt();
        let denom = sqrt_n * mean.abs().max(scale_floor);
        let rel_err = std_dev / denom;
        if rel_err > self.config.relative_tol {
            return false;
        }
    }
    true
}
}
/// Free-function entry point: estimate tr(H^{-1} A_k) for `targets` with a
/// default-configured estimator, dispatching to the matching batch method.
fn stochastic_trace_hinv_products(
    hop: &dyn HessianOperator,
    targets: StochasticTraceTargets<'_>,
) -> Vec<f64> {
    let estimator = StochasticTraceEstimator::with_defaults();
    match targets {
        // Dedicated path for a lone dense target.
        StochasticTraceTargets::Dense(matrices) if matrices.len() == 1 => {
            vec![estimator.estimate_single_trace(hop, matrices[0])]
        }
        StochasticTraceTargets::Dense(matrices) => estimator.estimate_traces(hop, matrices),
        StochasticTraceTargets::Mixed {
            dense_matrices,
            operators,
        } => estimator.estimate_traces_with_operators(hop, dense_matrices, operators),
        StochasticTraceTargets::Structural {
            dense_matrices,
            implicit_ops,
        } => estimator.estimate_traces_structural(hop, dense_matrices, implicit_ops),
    }
}
/// Estimate the full cross-trace matrix for a set of penalty coordinates,
/// then permute the result from the estimator's dense-first ordering back
/// into the caller's original coordinate order (given by
/// `coord_has_operator`, where `true` marks an operator-backed coordinate).
fn stochastic_trace_hinv_crosses<'a>(
    hop: &dyn HessianOperator,
    dense_matrices: &'a [Array2<f64>],
    coord_has_operator: &[bool],
    generic_ops: &[&'a dyn HyperOperator],
    implicit_ops: &[&'a ImplicitHyperOperator],
) -> Array2<f64> {
    let estimator =
        StochasticTraceEstimator::for_outer_hessian(hop.dim(), coord_has_operator.len());
    let dense_refs: Vec<&Array2<f64>> = dense_matrices.iter().collect();
    // NOTE(review): the implicit path is chosen whenever the two operator
    // lists have equal length — presumably the caller guarantees that only
    // one of `generic_ops` / `implicit_ops` is populated and that equal
    // lengths imply the implicit set covers all operator coordinates;
    // confirm against the call sites.
    let raw_cross = if generic_ops.len() == implicit_ops.len() {
        estimator.estimate_second_order_traces(hop, &dense_refs, implicit_ops)
    } else {
        estimator.estimate_second_order_traces_with_operators(hop, &dense_refs, generic_ops)
    };
    // Build the original-index -> raw-index permutation: dense coordinates
    // occupy the leading raw slots, operator coordinates the trailing ones.
    let total_coords = coord_has_operator.len();
    let n_dense_total = coord_has_operator.iter().filter(|&&b| !b).count();
    let mut original_to_raw = Vec::with_capacity(total_coords);
    let mut dense_cursor = 0usize;
    let mut operator_cursor = n_dense_total;
    for &has_operator in coord_has_operator {
        if has_operator {
            original_to_raw.push(operator_cursor);
            operator_cursor += 1;
        } else {
            original_to_raw.push(dense_cursor);
            dense_cursor += 1;
        }
    }
    // Apply the permutation on both axes.
    let mut mapped = Array2::zeros((total_coords, total_coords));
    for i in 0..total_coords {
        for j in 0..total_coords {
            mapped[[i, j]] = raw_cross[[original_to_raw[i], original_to_raw[j]]];
        }
    }
    mapped
}
/// xoshiro256** pseudo-random generator (Blackman & Vigna), kept local so
/// probe generation is reproducible without pulling in an RNG crate.
struct Xoshiro256SS {
    s: [u64; 4],
}
impl Xoshiro256SS {
    /// Expand a 64-bit seed into the 256-bit state via SplitMix64 (as the
    /// xoshiro authors recommend). The all-zero state is a fixed point of
    /// the update, so it is replaced by a fixed nonzero state.
    fn from_seed(seed: u64) -> Self {
        let mut sm_state = seed;
        let mut state = [0u64; 4];
        for word in state.iter_mut() {
            *word = splitmix64(&mut sm_state);
        }
        if state == [0, 0, 0, 0] {
            state = [1, 0, 0, 0];
        }
        Self { s: state }
    }
    /// Standard xoshiro256** step: starred output scrambler, then the
    /// linear xorshift/rotate state transition.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        let out = self.s[1].wrapping_mul(5).rotate_left(7).wrapping_mul(9);
        let t = self.s[1] << 17;
        self.s[2] ^= self.s[0];
        self.s[3] ^= self.s[1];
        self.s[1] ^= self.s[2];
        self.s[0] ^= self.s[3];
        self.s[2] ^= t;
        self.s[3] = self.s[3].rotate_left(45);
        out
    }
}
/// SplitMix64 step: advances `state` and returns the next mixed output.
#[inline]
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}
/// Fill `z` with a Rademacher (±1) probe, spending one PRNG bit per entry
/// and drawing a fresh 64-bit word only when the current one is exhausted.
fn rademacher_probe_into(mut z: ArrayViewMut1<'_, f64>, rng: &mut Xoshiro256SS) {
    let mut word: u64 = 0;
    let mut bits_left = 0u32;
    for slot in z.iter_mut() {
        if bits_left == 0 {
            word = rng.next_u64();
            bits_left = 64;
        }
        // Low bit picks the sign: 0 -> +1, 1 -> -1.
        *slot = if word & 1 == 0 { 1.0 } else { -1.0 };
        word >>= 1;
        bits_left -= 1;
    }
}
/// Orthonormalize the columns of `y` into `q` by modified Gram–Schmidt,
/// dropping columns whose residual norm falls below a tolerance relative
/// to the largest input column norm. Returns the numerical rank, i.e. the
/// number of orthonormal columns written to `q` (remaining columns of `q`
/// stay zero).
fn modified_gram_schmidt(y: &Array2<f64>, q: &mut Array2<f64>) -> usize {
    let rows = y.nrows();
    let cols = y.ncols();
    debug_assert_eq!(q.dim(), (rows, cols));
    q.fill(0.0);
    if rows == 0 || cols == 0 {
        return 0;
    }
    // Largest column norm sets the scale for the drop tolerance.
    let mut max_norm: f64 = 0.0;
    for j in 0..cols {
        let col = y.column(j);
        let col_norm = col.dot(&col).sqrt();
        if col_norm > max_norm {
            max_norm = col_norm;
        }
    }
    let drop_tol = (max_norm * 1.0e-12).max(f64::MIN_POSITIVE);
    let mut rank = 0usize;
    for j in 0..cols {
        let mut residual = y.column(j).to_owned();
        // Project out each previously accepted direction in turn (MGS, not
        // classical GS, for numerical stability).
        for k in 0..rank {
            let basis = q.column(k);
            let coeff = basis.dot(&residual);
            if coeff != 0.0 {
                residual.scaled_add(-coeff, &basis);
            }
        }
        let res_norm = residual.dot(&residual).sqrt();
        if !res_norm.is_finite() || res_norm <= drop_tol {
            // Numerically dependent (or broken) column: skip it.
            continue;
        }
        let inv_norm = 1.0 / res_norm;
        residual.iter_mut().for_each(|value| *value *= inv_norm);
        q.column_mut(rank).assign(&residual);
        rank += 1;
    }
    rank
}
/// Hutch++-style estimator of tr(H^{-1} A) for a matrix-free operator `A`.
///
/// Phase 1 sketches the range of H^{-1}A with Rademacher probes and
/// orthonormalizes it into Q; phase 2 computes the exact trace over that
/// subspace; phase 3 runs plain Hutchinson on probes deflated against Q,
/// stopping early once the running standard error is within tolerance.
///
/// Cleanups vs. the previous revision (behavior unchanged): the residual
/// budget line `residual_budget_max.max(residual_min)` was a no-op because
/// `residual_min` is clamped by `residual_budget_max`; the `let _ = m;`
/// before `break` was dead, and the loop variable now matches the sibling
/// Hutch++ functions (`for _ in ...`).
pub(crate) fn hutchpp_estimate_trace_hinv_operator<H: HessianOperator + ?Sized>(
    hop: &H,
    op: &dyn HyperOperator,
    config: &StochasticTraceConfig,
) -> f64 {
    let p = hop.dim();
    debug_assert_eq!(op.dim(), p, "Hutch++: operator dim mismatch");
    if p == 0 {
        return 0.0;
    }
    let sketch_dim = config.hutchpp_sketch_dim.unwrap_or(0).min(p);
    let mut rng_state = Xoshiro256SS::from_seed(config.seed);
    // Phase 1: Y = H^{-1} A Z for Rademacher probes Z; orthonormalize into
    // Q (the numerical rank may fall below sketch_dim).
    let mut q = Array2::<f64>::zeros((p, sketch_dim));
    let mut q_rank = 0usize;
    if sketch_dim > 0 {
        let mut y = Array2::<f64>::zeros((p, sketch_dim));
        let mut z = Array1::<f64>::zeros(p);
        let mut mz = Array1::<f64>::zeros(p);
        for j in 0..sketch_dim {
            rademacher_probe_into(z.view_mut(), &mut rng_state);
            op.mul_vec_into(z.view(), mz.view_mut());
            let w = hop.stochastic_trace_solve(&mz, config.solve_rel_tol);
            y.column_mut(j).assign(&w);
        }
        q_rank = modified_gram_schmidt(&y, &mut q);
    }
    // Phase 2: exact trace over the sketched subspace, sum_j q_j^T H^{-1} A q_j.
    let mut t_low = 0.0;
    if q_rank > 0 {
        let mut mq = Array1::<f64>::zeros(p);
        for j in 0..q_rank {
            let qcol = q.column(j).to_owned();
            op.mul_vec_into(qcol.view(), mq.view_mut());
            let w = hop.stochastic_trace_solve(&mq, config.solve_rel_tol);
            t_low += qcol.dot(&w);
        }
    }
    // Phase 3: Hutchinson on the deflated residual. Each accepted sketch
    // direction cost two solves, so charge 2 * q_rank against the probe
    // budget. residual_min <= residual_budget by construction.
    let used = 2 * q_rank;
    let residual_budget = config.n_probes_max.saturating_sub(used);
    let residual_min = config.n_probes_min.min(residual_budget);
    if residual_budget == 0 {
        return t_low;
    }
    let mut sum = 0.0;
    let mut sum_sq = 0.0;
    let mut count = 0usize;
    let mut z = Array1::<f64>::zeros(p);
    let mut z_tilde = Array1::<f64>::zeros(p);
    let mut mz = Array1::<f64>::zeros(p);
    let check_interval = 4usize;
    for _ in 0..residual_budget {
        rademacher_probe_into(z.view_mut(), &mut rng_state);
        // Deflate the probe against the captured subspace: z~ = (I - QQ^T) z.
        z_tilde.assign(&z);
        for j in 0..q_rank {
            let qcol = q.column(j);
            let proj = qcol.dot(&z);
            if proj != 0.0 {
                z_tilde.scaled_add(-proj, &qcol);
            }
        }
        op.mul_vec_into(z_tilde.view(), mz.view_mut());
        let w = hop.stochastic_trace_solve(&mz, config.solve_rel_tol);
        let q_val = z_tilde.dot(&w);
        sum += q_val;
        sum_sq += q_val * q_val;
        count += 1;
        // Early exit once the running standard error is within tolerance.
        if count >= residual_min && count % check_interval == 0 && count >= 2 {
            let n = count as f64;
            let mean = sum / n;
            let var = (sum_sq - n * mean * mean) / (n - 1.0).max(1.0);
            if var.is_finite() && var >= 0.0 {
                let stderr = (var / n).sqrt();
                let denom = mean.abs().max(config.tau_rel);
                if stderr / denom <= config.relative_tol {
                    break;
                }
            }
        }
    }
    let mean_residual = if count > 0 { sum / count as f64 } else { 0.0 };
    t_low + mean_residual
}
/// Hutch++-style estimator of tr((H^{-1} A)^2) for a matrix-free operator
/// `A`. Same three-phase structure as `hutchpp_estimate_trace_hinv_operator`,
/// but every application uses B = H^{-1} A H^{-1} A (two operator products
/// and two solves per probe).
///
/// Cleanup vs. the previous revision (behavior unchanged): the residual
/// budget line `residual_budget_max.max(residual_min)` was a no-op because
/// `residual_min` is clamped by `residual_budget_max`.
pub(crate) fn hutchpp_estimate_trace_hinv_op_squared<H: HessianOperator + ?Sized>(
    hop: &H,
    op: &dyn HyperOperator,
    config: &StochasticTraceConfig,
) -> f64 {
    let p = hop.dim();
    debug_assert_eq!(op.dim(), p, "Hutch++ squared: operator dim mismatch");
    if p == 0 {
        return 0.0;
    }
    let sketch_dim = config.hutchpp_sketch_dim.unwrap_or(0).min(p);
    let mut rng_state = Xoshiro256SS::from_seed(config.seed);
    // One application of B = H^{-1} A H^{-1} A; `tmp` is caller-provided
    // scratch so the two operator products reuse one buffer.
    let apply_b_squared = |hop: &H,
                           op: &dyn HyperOperator,
                           input: ArrayView1<'_, f64>,
                           tmp: &mut Array1<f64>|
     -> Array1<f64> {
        op.mul_vec_into(input, tmp.view_mut());
        let mid = hop.stochastic_trace_solve(tmp, config.solve_rel_tol);
        op.mul_vec_into(mid.view(), tmp.view_mut());
        hop.stochastic_trace_solve(tmp, config.solve_rel_tol)
    };
    // Phase 1: sketch the range of B and orthonormalize into Q.
    let mut q = Array2::<f64>::zeros((p, sketch_dim));
    let mut q_rank = 0usize;
    if sketch_dim > 0 {
        let mut y = Array2::<f64>::zeros((p, sketch_dim));
        let mut z = Array1::<f64>::zeros(p);
        let mut tmp = Array1::<f64>::zeros(p);
        for j in 0..sketch_dim {
            rademacher_probe_into(z.view_mut(), &mut rng_state);
            let w = apply_b_squared(hop, op, z.view(), &mut tmp);
            y.column_mut(j).assign(&w);
        }
        q_rank = modified_gram_schmidt(&y, &mut q);
    }
    // Phase 2: exact trace over the sketched subspace.
    let mut t_low = 0.0;
    if q_rank > 0 {
        let mut tmp = Array1::<f64>::zeros(p);
        for j in 0..q_rank {
            let qcol = q.column(j).to_owned();
            let w = apply_b_squared(hop, op, qcol.view(), &mut tmp);
            t_low += qcol.dot(&w);
        }
    }
    // Phase 3: Hutchinson on deflated probes within the remaining budget.
    // residual_min <= residual_budget by construction.
    let used = 2 * q_rank;
    let residual_budget = config.n_probes_max.saturating_sub(used);
    let residual_min = config.n_probes_min.min(residual_budget);
    if residual_budget == 0 {
        return t_low;
    }
    let mut sum = 0.0;
    let mut sum_sq = 0.0;
    let mut count = 0usize;
    let mut z = Array1::<f64>::zeros(p);
    let mut z_tilde = Array1::<f64>::zeros(p);
    let mut tmp = Array1::<f64>::zeros(p);
    let check_interval = 4usize;
    for _ in 0..residual_budget {
        rademacher_probe_into(z.view_mut(), &mut rng_state);
        // Deflate: z~ = (I - QQ^T) z.
        z_tilde.assign(&z);
        for j in 0..q_rank {
            let qcol = q.column(j);
            let proj = qcol.dot(&z);
            if proj != 0.0 {
                z_tilde.scaled_add(-proj, &qcol);
            }
        }
        let w = apply_b_squared(hop, op, z_tilde.view(), &mut tmp);
        let q_val = z_tilde.dot(&w);
        sum += q_val;
        sum_sq += q_val * q_val;
        count += 1;
        // Early exit once the running standard error is within tolerance.
        if count >= residual_min && count % check_interval == 0 && count >= 2 {
            let n = count as f64;
            let mean = sum / n;
            let var = (sum_sq - n * mean * mean) / (n - 1.0).max(1.0);
            if var.is_finite() && var >= 0.0 {
                let stderr = (var / n).sqrt();
                let denom = mean.abs().max(config.tau_rel);
                if stderr / denom <= config.relative_tol {
                    break;
                }
            }
        }
    }
    let mean_residual = if count > 0 { sum / count as f64 } else { 0.0 };
    t_low + mean_residual
}
/// Hutch++-style estimator of the cross trace tr(H^{-1} L H^{-1} R) for two
/// matrix-free operators. Each application is M = H^{-1} L H^{-1} R; the
/// structure mirrors `hutchpp_estimate_trace_hinv_op_squared`.
///
/// Cleanup vs. the previous revision (behavior unchanged): the residual
/// budget line `residual_budget_max.max(residual_min)` was a no-op because
/// `residual_min` is clamped by `residual_budget_max`.
pub(crate) fn hutchpp_estimate_trace_hinv_operator_cross<H: HessianOperator + ?Sized>(
    hop: &H,
    left: &dyn HyperOperator,
    right: &dyn HyperOperator,
    config: &StochasticTraceConfig,
) -> f64 {
    let p = hop.dim();
    debug_assert_eq!(left.dim(), p, "cross trace: left operator dim mismatch");
    debug_assert_eq!(right.dim(), p, "cross trace: right operator dim mismatch");
    if p == 0 {
        return 0.0;
    }
    let sketch_dim = config.hutchpp_sketch_dim.unwrap_or(0).min(p);
    let mut rng_state = Xoshiro256SS::from_seed(config.seed);
    // One application of M = H^{-1} L H^{-1} R (right first, then left).
    let apply_m = |hop: &H, x: ArrayView1<'_, f64>, tmp: &mut Array1<f64>| -> Array1<f64> {
        right.mul_vec_into(x, tmp.view_mut());
        let mid = hop.stochastic_trace_solve(tmp, config.solve_rel_tol);
        left.mul_vec_into(mid.view(), tmp.view_mut());
        hop.stochastic_trace_solve(tmp, config.solve_rel_tol)
    };
    // Phase 1: sketch the range of M and orthonormalize into Q.
    let mut q = Array2::<f64>::zeros((p, sketch_dim));
    let mut q_rank = 0usize;
    if sketch_dim > 0 {
        let mut y = Array2::<f64>::zeros((p, sketch_dim));
        let mut z = Array1::<f64>::zeros(p);
        let mut tmp = Array1::<f64>::zeros(p);
        for j in 0..sketch_dim {
            rademacher_probe_into(z.view_mut(), &mut rng_state);
            let w = apply_m(hop, z.view(), &mut tmp);
            y.column_mut(j).assign(&w);
        }
        q_rank = modified_gram_schmidt(&y, &mut q);
    }
    // Phase 2: exact trace over the sketched subspace.
    let mut t_low = 0.0;
    if q_rank > 0 {
        let mut tmp = Array1::<f64>::zeros(p);
        for j in 0..q_rank {
            let qcol = q.column(j).to_owned();
            let w = apply_m(hop, qcol.view(), &mut tmp);
            t_low += qcol.dot(&w);
        }
    }
    // Phase 3: Hutchinson on deflated probes within the remaining budget.
    // residual_min <= residual_budget by construction.
    let used = 2 * q_rank;
    let residual_budget = config.n_probes_max.saturating_sub(used);
    let residual_min = config.n_probes_min.min(residual_budget);
    if residual_budget == 0 {
        return t_low;
    }
    let mut sum = 0.0;
    let mut sum_sq = 0.0;
    let mut count = 0usize;
    let mut z = Array1::<f64>::zeros(p);
    let mut z_tilde = Array1::<f64>::zeros(p);
    let mut tmp = Array1::<f64>::zeros(p);
    let check_interval = 4usize;
    for _ in 0..residual_budget {
        rademacher_probe_into(z.view_mut(), &mut rng_state);
        // Deflate: z~ = (I - QQ^T) z.
        z_tilde.assign(&z);
        for j in 0..q_rank {
            let qcol = q.column(j);
            let proj = qcol.dot(&z);
            if proj != 0.0 {
                z_tilde.scaled_add(-proj, &qcol);
            }
        }
        let w = apply_m(hop, z_tilde.view(), &mut tmp);
        let q_val = z_tilde.dot(&w);
        sum += q_val;
        sum_sq += q_val * q_val;
        count += 1;
        // Early exit once the running standard error is within tolerance.
        if count >= residual_min && count % check_interval == 0 && count >= 2 {
            let n = count as f64;
            let mean = sum / n;
            let var = (sum_sq - n * mean * mean) / (n - 1.0).max(1.0);
            if var.is_finite() && var >= 0.0 {
                let stderr = (var / n).sqrt();
                let denom = mean.abs().max(config.tau_rel);
                if stderr / denom <= config.relative_tol {
                    break;
                }
            }
        }
    }
    let mean_residual = if count > 0 { sum / count as f64 } else { 0.0 };
    t_low + mean_residual
}
#[cfg(test)]
mod tests {
use super::*;
use crate::solver::estimate::DP_FLOOR;
use approx::assert_relative_eq;
use ndarray::array;
// Builds a distinct cache key per `seed` (seed also salts the value hashes).
fn make_factor_key(seed: u64) -> ProjectedFactorKey {
    ProjectedFactorKey {
        design_id: 1,
        factor_ptr: seed as usize,
        rows: 1,
        cols: 1,
        row_stride: 1,
        col_stride: 1,
        value_hash: seed,
        value_hash2: seed.wrapping_mul(31),
    }
}
// With a budget of exactly two entries, touching `a` before inserting `c`
// must make `b` the LRU victim: `a` and `c` survive, `b` is recomputed.
#[test]
fn projected_factor_cache_lru_evicts_oldest_under_budget() {
    let entry_floats = 32usize;
    let entry_bytes = entry_floats * std::mem::size_of::<f64>();
    let cache = ProjectedFactorCache::with_budget(entry_bytes * 2);
    let make = |seed: u64| -> Array2<f64> { Array2::from_elem((4, 8), seed as f64) };
    let _a = cache.get_or_insert_with(make_factor_key(1), || make(1));
    let _b = cache.get_or_insert_with(make_factor_key(2), || make(2));
    assert_eq!(cache.len(), 2);
    assert_eq!(cache.total_bytes(), entry_bytes * 2);
    // Touch `a` so `b` becomes the least-recently used entry.
    let _a_again = cache.get_or_insert_with(make_factor_key(1), || make(1));
    let _c = cache.get_or_insert_with(make_factor_key(3), || make(3));
    assert_eq!(cache.len(), 2);
    assert_eq!(cache.total_bytes(), entry_bytes * 2);
    // Seed 99 closures only run on a cache miss, so hits keep old values.
    let post_a = cache.get_or_insert_with(make_factor_key(1), || make(99));
    let post_c = cache.get_or_insert_with(make_factor_key(3), || make(99));
    assert_eq!(post_a[[0, 0]], 1.0, "a survived eviction");
    assert_eq!(post_c[[0, 0]], 3.0, "c is the freshly inserted entry");
    let post_b = cache.get_or_insert_with(make_factor_key(2), || make(99));
    assert_eq!(
        post_b[[0, 0]],
        99.0,
        "b was evicted; recompute closure runs"
    );
}
// A zero budget disables eviction entirely: every insert is retained.
#[test]
fn projected_factor_cache_zero_budget_disables_eviction() {
    let cache = ProjectedFactorCache::with_budget(0);
    for seed in 0..16 {
        let _ = cache.get_or_insert_with(make_factor_key(seed), || {
            Array2::from_elem((8, 8), seed as f64)
        });
    }
    assert_eq!(cache.len(), 16);
}
// An entry larger than the whole budget is still cached.
#[test]
fn projected_factor_cache_oversize_entry_is_cached_unconditionally() {
    let cache = ProjectedFactorCache::with_budget(8);
    let huge = cache.get_or_insert_with(make_factor_key(1), || Array2::from_elem((4, 4), 1.0));
    assert_eq!(huge[[0, 0]], 1.0);
    assert_eq!(cache.len(), 1);
}
// Test double: an outer-Hessian operator backed by a fixed dense matrix,
// recognizable in assertions by its sentinel contents.
struct SentinelOuterHessianOperator {
    matrix: Array2<f64>,
}
impl crate::solver::outer_strategy::OuterHessianOperator for SentinelOuterHessianOperator {
    fn dim(&self) -> usize {
        self.matrix.nrows()
    }
    fn matvec(&self, v: &Array1<f64>) -> Result<Array1<f64>, String> {
        Ok(self.matrix.dot(v))
    }
    fn is_cheap_to_materialize(&self) -> bool {
        true
    }
}
// Test double: a derivative provider that offers nothing except a
// family-supplied outer-Hessian operator.
struct FamilyOperatorOnlyDerivatives {
    op: Arc<dyn crate::solver::outer_strategy::OuterHessianOperator>,
}
impl HessianDerivativeProvider for FamilyOperatorOnlyDerivatives {
    fn hessian_derivative_correction(
        &self,
        _: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        Ok(None)
    }
    fn has_corrections(&self) -> bool {
        false
    }
    fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
        None
    }
    fn family_outer_hessian_operator(
        &self,
    ) -> Option<Arc<dyn crate::solver::outer_strategy::OuterHessianOperator>> {
        Some(Arc::clone(&self.op))
    }
}
// When the derivative provider supplies a family outer-Hessian operator,
// the ValueGradientHessian evaluation must route through it (operator
// Hessian result) and materialize to the sentinel matrix.
#[test]
fn value_gradient_hessian_prefers_family_supplied_outer_operator() {
    let hop = Arc::new(DenseSpectralOperator::from_symmetric(&Array2::eye(2)).unwrap());
    let family_matrix = array![[42.0]];
    let family_operator = Arc::new(SentinelOuterHessianOperator {
        matrix: family_matrix.clone(),
    });
    let deriv_provider = FamilyOperatorOnlyDerivatives {
        op: family_operator,
    };
    // Minimal inner solution wired to the sentinel derivative provider.
    let solution = InnerSolution {
        log_likelihood: -1.25,
        penalty_quadratic: 0.4,
        hessian_op: hop,
        beta: array![0.5, -0.25],
        penalty_coords: vec![PenaltyCoordinate::from_dense_root(Array2::eye(2))],
        penalty_logdet: PenaltyLogdetDerivs {
            value: 0.0,
            first: array![0.0],
            second: Some(array![[0.0]]),
        },
        deriv_provider: Box::new(deriv_provider),
        tk_correction: 0.0,
        tk_gradient: None,
        firth: None,
        hessian_logdet_correction: 0.0,
        penalty_subspace_trace: None,
        rho_curvature_scale: 1.0,
        n_observations: 2,
        nullspace_dim: 0.0,
        dispersion: DispersionHandling::Fixed {
            phi: 1.0,
            include_logdet_h: true,
            include_logdet_s: true,
        },
        ext_coords: Vec::new(),
        ext_coord_pair_fn: None,
        rho_ext_pair_fn: None,
        fixed_drift_deriv: None,
        barrier_config: None,
    };
    let result = reml_laml_evaluate(&solution, &[0.0], EvalMode::ValueGradientHessian, None)
        .expect("family outer operator evaluation");
    let crate::solver::outer_strategy::HessianResult::Operator(op) = result.hessian else {
        panic!("expected family-supplied operator Hessian route");
    };
    let dense = op.materialize_dense().expect("sentinel materialization");
    assert_eq!(dense, family_matrix);
}
// Diagonal H: logdet is the sum of log eigenvalues, tr(H^{-1}) the sum of
// reciprocals (1/2 + 1/5 = 0.7), and solve divides componentwise.
#[test]
fn test_dense_spectral_operator_simple() {
    let h = Array2::from_diag(&array![2.0, 5.0]);
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let expected_logdet = 2.0_f64.ln() + 5.0_f64.ln();
    assert!((op.logdet() - expected_logdet).abs() < 1e-12);
    let id = Array2::eye(2);
    let trace = op.trace_hinv_product(&id);
    assert!((trace - 0.7).abs() < 1e-12);
    let rhs = array![1.0, 1.0];
    let sol = op.solve(&rhs);
    assert!((sol[0] - 0.5).abs() < 1e-12);
    assert!((sol[1] - 0.2).abs() < 1e-12);
    assert_eq!(sol.len(), 2);
}
// solve_multi must agree column-by-column with repeated single solves.
#[test]
fn test_dense_spectral_operator_solve_multi_matches_column_solves() {
    let h = array![[4.0, 1.0, 0.5], [1.0, 3.0, 0.25], [0.5, 0.25, 2.0],];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let rhs = array![[1.0, -1.0], [0.5, 2.0], [3.0, 0.25],];
    let multi = op.solve_multi(&rhs);
    for col in 0..rhs.ncols() {
        let single = op.solve(&rhs.column(col).to_owned());
        for row in 0..rhs.nrows() {
            let err = (multi[[row, col]] - single[row]).abs();
            assert!(
                err < 1e-12,
                "solve_multi mismatch at ({row}, {col}): multi={}, single={}",
                multi[[row, col]],
                single[row]
            );
        }
    }
}
// Cross trace tr(H^{-1} A H^{-1} B) checked against the explicit
// elementwise-product-and-sum of the two solved matrices.
#[test]
fn test_dense_spectral_operator_cross_trace_matches_column_solves() {
    let h = array![[4.0, 1.0, 0.5], [1.0, 3.0, 0.25], [0.5, 0.25, 2.0],];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let a = array![[1.0, 0.2, -0.1], [0.2, 2.0, 0.3], [-0.1, 0.3, 0.5],];
    let b = array![[0.5, -0.4, 0.1], [-0.4, 1.5, 0.25], [0.1, 0.25, 0.75],];
    let expected = (&op.solve_multi(&a).t() * &op.solve_multi(&b)).sum();
    let exact = op.trace_hinv_product_cross(&a, &b);
    assert_relative_eq!(exact, expected, epsilon = 1e-12, max_relative = 1e-12);
}
// Mixed (dense x operator) and operator x operator cross traces must match
// the all-dense reference formulas.
#[test]
fn test_dense_spectral_operator_operator_cross_matches_dense_formula() {
    let h = array![[5.0, 0.5, 0.25], [0.5, 3.5, 0.2], [0.25, 0.2, 2.5],];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let dense = array![[1.0, 0.1, -0.2], [0.1, 0.75, 0.3], [-0.2, 0.3, 1.25],];
    let other = array![[0.6, -0.3, 0.15], [-0.3, 1.1, 0.05], [0.15, 0.05, 0.9],];
    let other_op = DenseMatrixHyperOperator {
        matrix: other.clone(),
    };
    let expected = op.trace_hinv_product_cross(&dense, &other);
    let mixed = op.trace_hinv_matrix_operator_cross(&dense, &other_op);
    let operator = op.trace_hinv_operator_cross(&other_op, &other_op);
    let operator_expected = op.trace_hinv_product_cross(&other, &other);
    assert_relative_eq!(mixed, expected, epsilon = 1e-12, max_relative = 1e-12);
    assert_relative_eq!(
        operator,
        operator_expected,
        epsilon = 1e-12,
        max_relative = 1e-12
    );
}
// Combining an operator drift with a dense correction must equal the
// trace-logdet gradient of the summed dense matrix.
#[test]
fn test_hyper_coord_total_drift_result_keeps_operator_and_dense_correction() {
    let h = array![[4.0, 0.25], [0.25, 3.0],];
    let hop = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let base = array![[1.0, 0.2], [0.2, 0.5],];
    let corr = array![[0.3, -0.1], [-0.1, 0.4],];
    let drift = HyperCoordDrift::from_operator(Arc::new(DenseMatrixHyperOperator {
        matrix: base.clone(),
    }));
    let correction = DriftDerivResult::Dense(corr.clone());
    let combined = hyper_coord_total_drift_result(&drift, Some(&correction), h.nrows());
    let expected = hop.trace_logdet_gradient(&(&base + &corr));
    assert_relative_eq!(
        combined.trace_logdet(&hop),
        expected,
        epsilon = 1e-12,
        max_relative = 1e-12
    );
}
// The eigenbasis-rotated cross-Hessian path must agree with the direct one.
#[test]
fn test_dense_spectral_operator_rotated_logdet_cross_matches_dense_path() {
    let h = array![[4.0, 0.5, 0.2], [0.5, 2.5, 0.3], [0.2, 0.3, 1.75],];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let a = array![[0.8, 0.2, -0.1], [0.2, 1.4, 0.35], [-0.1, 0.35, 0.9],];
    let b = array![[1.2, -0.25, 0.05], [-0.25, 0.7, 0.15], [0.05, 0.15, 0.6],];
    let a_rot = op.rotate_to_eigenbasis(&a);
    let b_rot = op.rotate_to_eigenbasis(&b);
    let direct = op.trace_logdet_hessian_cross(&a, &b);
    let rotated = op.trace_logdet_hessian_cross_rotated(&a_rot, &b_rot);
    assert_relative_eq!(rotated, direct, epsilon = 1e-12, max_relative = 1e-12);
}
// The streaming adjoint path must match the dense reference
// H^{-1} X^T diag(c) h on random data (X, c, and an SPD H built as
// M^T M + p*I so the solve is well-conditioned).
#[test]
fn test_compute_adjoint_z_c_streaming_matches_dense_reference() {
    let n = 64usize;
    let p = 8usize;
    let mut rng = Xoshiro256SS::from_seed(0x5EED_C0FFEE_u64);
    // Uniform draw in [-1, 1) from the top 53 bits of the PRNG word.
    let unit = |rng: &mut Xoshiro256SS| {
        let bits = rng.next_u64() >> 11;
        (bits as f64) / ((1u64 << 53) as f64) * 2.0 - 1.0
    };
    let mut x_data = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        for j in 0..p {
            x_data[[i, j]] = unit(&mut rng);
        }
    }
    let mut c_array = Array1::<f64>::zeros(n);
    for i in 0..n {
        c_array[i] = unit(&mut rng);
    }
    let mut m = Array2::<f64>::zeros((p, p));
    for i in 0..p {
        for j in 0..p {
            m[[i, j]] = unit(&mut rng);
        }
    }
    // SPD Hessian: M^T M plus a diagonal shift of p.
    let mut h = m.t().dot(&m);
    for i in 0..p {
        h[[i, i]] += p as f64;
    }
    let hop = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let x = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x_data.clone()));
    let ing = ScalarGlmIngredients {
        c_array: &c_array,
        d_array: None,
        x: &x,
    };
    // Dense leverage-like vector h_i = x_i^T H^{-1} x_i.
    let z_full = hop.solve_multi(&x_data.t().to_owned());
    let mut h_dense = Array1::<f64>::zeros(n);
    for i in 0..n {
        let mut acc = 0.0;
        for j in 0..p {
            acc += x_data[[i, j]] * z_full[[j, i]];
        }
        h_dense[i] = acc;
    }
    let streamed = compute_adjoint_z_c(&ing, &hop, &h_dense).expect("adjoint path");
    // Reference: H^{-1} X^T (c .* h).
    let mut t = h_dense.clone();
    Zip::from(&mut t)
        .and(&c_array)
        .for_each(|t_i, &c_i| *t_i *= c_i);
    let v = x_data.t().dot(&t);
    let reference = hop.solve(&v);
    for k in 0..p {
        assert_relative_eq!(
            streamed[k],
            reference[k],
            epsilon = 1e-12,
            max_relative = 1e-12
        );
    }
}
#[test]
// The batched Gram matrix of fourth-derivative traces must agree, entry by
// entry, with the scalar pairwise formula evaluated over the same modes.
fn fourth_derivative_trace_matrix_matches_scalar_pair_formula() {
let n = 37usize;
let p = 5usize;
// t = number of probe modes (Gram matrix is t x t).
let t = 4usize;
let mut rng = Xoshiro256SS::from_seed(0xF047_ACE5_u64);
// Uniform draw in (-1, 1) built from the top 53 bits of the generator.
let unit = |rng: &mut Xoshiro256SS| {
let bits = rng.next_u64() >> 11;
(bits as f64) / ((1u64 << 53) as f64) * 2.0 - 1.0
};
// Fill order (x_data, then the interleaved per-observation arrays, then the
// modes) fixes the RNG stream — do not reorder.
let mut x_data = Array2::<f64>::zeros((n, p));
for i in 0..n {
for j in 0..p {
x_data[[i, j]] = unit(&mut rng);
}
}
let mut c_array = Array1::<f64>::zeros(n);
let mut d_array = Array1::<f64>::zeros(n);
let mut leverage = Array1::<f64>::zeros(n);
for i in 0..n {
c_array[i] = unit(&mut rng);
d_array[i] = unit(&mut rng);
// Leverages kept bounded away from zero (>= 0.25).
leverage[i] = 0.25 + unit(&mut rng).abs();
}
let x = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x_data));
let ing = ScalarGlmIngredients {
c_array: &c_array,
d_array: Some(&d_array),
x: &x,
};
let mut modes = Vec::with_capacity(t);
for _ in 0..t {
let mut mode = Array1::<f64>::zeros(p);
for j in 0..p {
mode[j] = unit(&mut rng);
}
modes.push(mode);
}
let mode_refs = modes.iter().collect::<Vec<_>>();
// Batched path under test; outer Option is the Result, inner Option encodes
// whether the d-array was available.
let gram = compute_fourth_derivative_trace_matrix(&ing, &mode_refs, &leverage)
.expect("batched fourth trace")
.expect("d-array is present");
// Scalar reference, one (i, j) mode pair at a time.
for i in 0..t {
for j in 0..t {
let scalar = compute_fourth_derivative_trace(&ing, &modes[i], &modes[j], &leverage)
.expect("scalar fourth trace")
.expect("d-array is present");
assert_relative_eq!(gram[[i, j]], scalar, epsilon = 1e-10, max_relative = 1e-10);
}
}
}
#[test]
// End-to-end agreement check: the matrix-free outer Hessian operator
// (materialized entry-by-entry and applied as an HVP) must match the dense
// compute_outer_hessian assembly when operator-backed drifts and extended
// GLM pair corrections are active.
fn operator_hessian_matches_dense_with_operator_drifts_and_extended_glm_corrections() {
// Deliberately ill-conditioned diagonal Hessian (1e-7 vs 2.7).
let h = array![[1.0e-7, 0.0], [0.0, 2.7]];
let hop = Arc::new(DenseSpectralOperator::from_symmetric(&h).unwrap());
let beta = array![0.4, -0.7];
let penalty_root = array![[1.2, 0.1], [0.0, 0.8]];
let ext_drift = array![[0.45, -0.15], [-0.15, 0.35]];
let x = array![[1.0, 0.2], [-0.4, 1.1], [0.7, -0.8]];
let c_array = array![0.31, -0.27, 0.19];
let d_array = array![0.17, -0.11, 0.23];
let deriv_provider = SinglePredictorGlmDerivatives {
c_array,
d_array: Some(d_array),
x_transformed: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x)),
};
let solution = InnerSolution {
log_likelihood: -2.3,
penalty_quadratic: 0.6,
hessian_op: hop.clone(),
beta,
penalty_coords: vec![PenaltyCoordinate::from_dense_root(penalty_root)],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: array![0.4],
second: Some(array![[0.13]]),
},
deriv_provider: Box::new(deriv_provider),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
n_observations: 3,
nullspace_dim: 0.0,
dispersion: DispersionHandling::Fixed {
phi: 1.0,
include_logdet_h: true,
include_logdet_s: true,
},
// One extended coordinate whose drift is operator-backed rather than a
// dense matrix — this exercises the operator dispatch path.
ext_coords: vec![HyperCoord {
a: -0.21,
g: array![0.33, -0.42],
drift: HyperCoordDrift::from_operator(Arc::new(DenseMatrixHyperOperator {
matrix: ext_drift,
})),
ld_s: 0.07,
b_depends_on_beta: false,
is_penalty_like: false,
firth_g: None,
tk_eta_fixed: None,
tk_x_fixed: None,
}],
// Constant pair corrections for (ext, ext) and (rho, ext) index pairs.
ext_coord_pair_fn: Some(Box::new(|_, _| HyperCoordPair {
a: 0.09,
g: array![0.16, -0.12],
b_mat: array![[0.08, 0.03], [0.03, -0.04]],
b_operator: None,
ld_s: -0.05,
})),
rho_ext_pair_fn: Some(Box::new(|_, _| HyperCoordPair {
a: -0.14,
g: array![-0.18, 0.22],
b_mat: array![[0.05, -0.02], [-0.02, 0.07]],
b_operator: None,
ld_s: 0.04,
})),
fixed_drift_deriv: None,
barrier_config: None,
};
let rho: Vec<f64> = vec![0.2_f64];
let lambdas: Vec<f64> = rho.iter().map(|value| value.exp()).collect();
// Dense reference assembly.
let dense = compute_outer_hessian(
&solution,
&rho,
&lambdas,
solution.hessian_op.as_ref(),
solution.deriv_provider.as_ref(),
None,
)
.unwrap();
let kernel = solution
.deriv_provider
.outer_hessian_derivative_kernel()
.unwrap();
// Operator under test.
let operator = build_outer_hessian_operator(
&solution,
&lambdas,
solution.deriv_provider.as_ref(),
kernel,
None,
None,
)
.unwrap();
let materialized =
crate::solver::outer_strategy::OuterHessianOperator::materialize_dense(&operator)
.unwrap();
// Entry-wise comparison with a mixed absolute/relative tolerance.
for row in 0..dense.nrows() {
for col in 0..dense.ncols() {
let materialized_entry = materialized[[row, col]];
let dense_entry = dense[[row, col]];
let tolerance = 1e-10_f64.max(1e-10 * dense_entry.abs());
assert!(
(materialized_entry - dense_entry).abs() <= tolerance,
"outer Hessian operator mismatch at ({row}, {col}): materialized={materialized_entry}, dense={dense_entry}"
);
}
}
// A single HVP must also match the dense matrix-vector product.
let alpha = array![0.37, -0.58];
let hvp = crate::solver::outer_strategy::OuterHessianOperator::matvec(&operator, &alpha)
.expect("operator HVP");
let dense_hvp = dense.dot(&alpha);
for i in 0..hvp.len() {
let tolerance = 1e-10_f64.max(1e-10 * dense_hvp[i].abs());
assert!(
(hvp[i] - dense_hvp[i]).abs() <= tolerance,
"outer Hessian HVP mismatch at {i}: operator={}, dense={}",
hvp[i],
dense_hvp[i]
);
}
}
#[test]
// Checks two PenaltySubspaceTrace shortcuts against dense references built
// from the projected kernel K = U_S (H_proj)^{-1} U_Sᵀ: the projected
// leverage diagonal and the adjoint trace identity.
fn subspace_projected_leverage_and_adjoint_shortcut_match_dense() {
// U_S spans the first two of four coordinates.
let u_s = array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]];
let det = 3.0_f64 * 5.0 - 0.1 * 0.1;
// Explicit 2x2 inverse of [[3.0, 0.1], [0.1, 5.0]].
let h_proj_inverse = array![[5.0 / det, -0.1 / det], [-0.1 / det, 3.0 / det]];
let subspace = PenaltySubspaceTrace {
u_s: u_s.clone(),
h_proj_inverse: h_proj_inverse.clone(),
};
let x_data = array![
[1.0, 0.2, 0.5, 0.3],
[1.0, 1.1, -0.2, 0.4],
[1.0, -0.8, 0.7, -0.1],
[1.0, 0.5, 0.3, 0.6]
];
let c = array![0.31_f64, -0.27, 0.19, -0.11];
// Dense kernel used as the reference throughout.
let k_dense = u_s.dot(&h_proj_inverse).dot(&u_s.t());
let n = x_data.nrows();
let x_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x_data.clone()));
// Shortcut 1: per-observation projected leverage x_iᵀ K x_i.
let h_g_proj = subspace.xt_projected_kernel_x_diagonal(&x_design);
assert_eq!(h_g_proj.len(), n);
for i in 0..n {
let row = x_data.row(i).to_owned();
let kx = k_dense.dot(&row);
assert_relative_eq!(h_g_proj[i], row.dot(&kx), epsilon = 1e-12);
}
let probes = [
array![0.6_f64, -0.4, 0.0, 0.0],
array![0.0_f64, 0.0, 0.5, 0.7],
array![0.3_f64, -0.1, 0.4, -0.2],
array![1.0_f64, 1.0, 1.0, 1.0],
];
// Shortcut 2: for each probe u, the projected-logdet trace of
// Xᵀ diag(c ∘ Xu) X must equal the adjoint form uᵀ Xᵀ (c ∘ h_g_proj)
// built from the leverage diagonal above.
for u in probes.iter() {
let xu = x_data.dot(u);
let mut weighted_x = x_data.clone();
for i in 0..n {
let w = c[i] * xu[i];
for j in 0..weighted_x.ncols() {
weighted_x[[i, j]] *= w;
}
}
let c_u_dense = x_data.t().dot(&weighted_x);
let lhs = subspace.trace_projected_logdet(&c_u_dense);
let mut weighted = Array1::<f64>::zeros(n);
for i in 0..n {
weighted[i] = c[i] * h_g_proj[i];
}
let rhs = u.dot(&x_data.t().dot(&weighted));
assert_relative_eq!(lhs, rhs, epsilon = 1e-12, max_relative = 1e-12);
}
}
#[test]
// With an active penalty-null subspace trace (nullspace_dim = 2), the outer
// Hessian operator must reproduce the dense assembly both when materialized
// and when applied to several directions, including null-like alphas.
fn outer_hessian_operator_matvec_matches_dense_subspace_with_null_alpha() {
let h = array![
[3.0, 0.1, 0.0, 0.0],
[0.1, 5.0, 0.05, 0.0],
[0.0, 0.05, 7.0, 0.15],
[0.0, 0.0, 0.15, 11.0]
];
let hop = Arc::new(DenseSpectralOperator::from_symmetric(&h).unwrap());
// Subspace spans the first two coordinates; h_proj_inverse is the explicit
// inverse of the leading 2x2 block [[3.0, 0.1], [0.1, 5.0]].
let u_s = array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]];
let det = 3.0_f64 * 5.0 - 0.1 * 0.1;
let h_proj_inverse = array![[5.0 / det, -0.1 / det], [-0.1 / det, 3.0 / det]];
let penalty_root_0 = array![[0.7, 0.3, 0.6, 0.0]];
let penalty_root_1 = array![[0.2, 0.5, 0.0, 0.4]];
let x = array![
[1.0, 0.2, 0.5, 0.3],
[1.0, 1.1, -0.2, 0.4],
[1.0, -0.8, 0.7, -0.1],
[1.0, 0.5, 0.3, 0.6]
];
let c_array = array![0.31, -0.27, 0.19, -0.11];
let d_array = array![0.17, -0.11, 0.23, 0.07];
let deriv_provider = SinglePredictorGlmDerivatives {
c_array,
d_array: Some(d_array),
x_transformed: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x)),
};
let logdet_h_proj = det.ln();
let beta = array![0.4, -0.7, 0.2, 0.1];
let solution = InnerSolution {
log_likelihood: -2.3,
penalty_quadratic: 0.6,
hessian_op: hop.clone(),
beta,
penalty_coords: vec![
PenaltyCoordinate::from_dense_root(penalty_root_0),
PenaltyCoordinate::from_dense_root(penalty_root_1),
],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: array![0.4, -0.2],
second: Some(array![[0.13, 0.02], [0.02, 0.09]]),
},
deriv_provider: Box::new(deriv_provider),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
// Swaps the full logdet for the projected one in the cost.
hessian_logdet_correction: logdet_h_proj - hop.logdet(),
penalty_subspace_trace: Some(Arc::new(PenaltySubspaceTrace {
u_s,
h_proj_inverse,
})),
rho_curvature_scale: 1.0,
n_observations: 4,
nullspace_dim: 2.0,
dispersion: DispersionHandling::Fixed {
phi: 1.0,
include_logdet_h: true,
include_logdet_s: true,
},
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
};
let rho: Vec<f64> = vec![0.2_f64, -0.1];
let lambdas: Vec<f64> = rho.iter().map(|value| value.exp()).collect();
// Dense reference assembly.
let dense = compute_outer_hessian(
&solution,
&rho,
&lambdas,
solution.hessian_op.as_ref(),
solution.deriv_provider.as_ref(),
None,
)
.unwrap();
let kernel = solution
.deriv_provider
.outer_hessian_derivative_kernel()
.unwrap();
// Operator under test.
let operator = build_outer_hessian_operator(
&solution,
&lambdas,
solution.deriv_provider.as_ref(),
kernel,
None,
None,
)
.unwrap();
let materialized =
crate::solver::outer_strategy::OuterHessianOperator::materialize_dense(&operator)
.unwrap();
for row in 0..dense.nrows() {
for col in 0..dense.ncols() {
assert_relative_eq!(
materialized[[row, col]],
dense[[row, col]],
epsilon = 1e-12,
max_relative = 1e-12
);
}
}
// HVPs along axis, mixed, and differencing directions.
let alphas = [
array![1.0, 0.0],
array![0.0, 1.0],
array![1.0, 1.0],
array![1.0, -1.0],
array![0.7, -0.3],
];
for alpha in alphas.iter() {
let hvp = crate::solver::outer_strategy::OuterHessianOperator::matvec(&operator, alpha)
.expect("operator HVP");
let dense_hvp = dense.dot(alpha);
for i in 0..hvp.len() {
assert_relative_eq!(hvp[i], dense_hvp[i], epsilon = 1e-12, max_relative = 1e-12);
}
}
}
#[test]
// Combines a penalty-null subspace trace with operator-backed ext drifts and
// pair corrections: the materialized outer Hessian operator must still match
// the dense assembly entry by entry.
fn projected_operator_hessian_matches_dense_subspace_trace() {
let h = array![[3.0, 0.2], [0.2, 5.0]];
let hop = Arc::new(DenseSpectralOperator::from_symmetric(&h).unwrap());
let beta = array![0.4, -0.7];
// Penalty acts only on the second coordinate; the first is null space.
let penalty_root = array![[0.0, 1.0]];
let ext_drift = array![[0.45, -0.15], [-0.15, 0.35]];
let x = array![[1.0, 0.2], [1.0, 1.1], [1.0, -0.8], [1.0, 0.5]];
let c_array = array![0.31, -0.27, 0.19, -0.11];
let d_array = array![0.17, -0.11, 0.23, 0.07];
let deriv_provider = SinglePredictorGlmDerivatives {
c_array,
d_array: Some(d_array),
x_transformed: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x)),
};
// Projected Hessian is the scalar h[1, 1] block.
let h_proj = h[[1, 1]];
let solution = InnerSolution {
log_likelihood: -2.3,
penalty_quadratic: 0.6,
hessian_op: hop.clone(),
beta,
penalty_coords: vec![PenaltyCoordinate::from_dense_root(penalty_root)],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: array![0.4],
second: Some(array![[0.13]]),
},
deriv_provider: Box::new(deriv_provider),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
// Swaps the full logdet for the projected scalar logdet.
hessian_logdet_correction: h_proj.ln() - hop.logdet(),
penalty_subspace_trace: Some(Arc::new(PenaltySubspaceTrace {
u_s: array![[0.0], [1.0]],
h_proj_inverse: array![[1.0 / h_proj]],
})),
rho_curvature_scale: 1.0,
n_observations: 4,
nullspace_dim: 1.0,
dispersion: DispersionHandling::Fixed {
phi: 1.0,
include_logdet_h: true,
include_logdet_s: true,
},
// Operator-backed extended coordinate plus constant pair corrections, as
// in the non-projected variant of this test.
ext_coords: vec![HyperCoord {
a: -0.21,
g: array![0.33, -0.42],
drift: HyperCoordDrift::from_operator(Arc::new(DenseMatrixHyperOperator {
matrix: ext_drift,
})),
ld_s: 0.07,
b_depends_on_beta: false,
is_penalty_like: false,
firth_g: None,
tk_eta_fixed: None,
tk_x_fixed: None,
}],
ext_coord_pair_fn: Some(Box::new(|_, _| HyperCoordPair {
a: 0.09,
g: array![0.16, -0.12],
b_mat: array![[0.08, 0.03], [0.03, -0.04]],
b_operator: None,
ld_s: -0.05,
})),
rho_ext_pair_fn: Some(Box::new(|_, _| HyperCoordPair {
a: -0.14,
g: array![-0.18, 0.22],
b_mat: array![[0.05, -0.02], [-0.02, 0.07]],
b_operator: None,
ld_s: 0.04,
})),
fixed_drift_deriv: None,
barrier_config: None,
};
let rho: Vec<f64> = vec![0.2_f64];
let lambdas: Vec<f64> = rho.iter().map(|value| value.exp()).collect();
// Dense reference assembly.
let dense = compute_outer_hessian(
&solution,
&rho,
&lambdas,
solution.hessian_op.as_ref(),
solution.deriv_provider.as_ref(),
None,
)
.unwrap();
let kernel = solution
.deriv_provider
.outer_hessian_derivative_kernel()
.unwrap();
// Operator under test.
let operator = build_outer_hessian_operator(
&solution,
&lambdas,
solution.deriv_provider.as_ref(),
kernel,
None,
None,
)
.unwrap();
let materialized =
crate::solver::outer_strategy::OuterHessianOperator::materialize_dense(&operator)
.unwrap();
for row in 0..dense.nrows() {
for col in 0..dense.ncols() {
assert_relative_eq!(
materialized[[row, col]],
dense[[row, col]],
epsilon = 1e-10,
max_relative = 1e-10
);
}
}
}
#[test]
// Routing check: once the number of penalty coordinates reaches
// MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD and a subspace trace is present,
// evaluation must return an operator-valued Hessian rather than dense.
fn subspace_trace_large_k_routes_to_projected_operator() {
let h = array![[3.0, 0.2], [0.2, 5.0]];
let hop = Arc::new(DenseSpectralOperator::from_symmetric(&h).unwrap());
// A single root, replicated k times below to cross the threshold.
let pcoord = PenaltyCoordinate::from_dense_root(array![[0.0, 1.0]]);
let k = MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD;
let x = array![[1.0, 0.2], [1.0, 1.1], [1.0, -0.8], [1.0, 0.5]];
let deriv_provider = SinglePredictorGlmDerivatives {
c_array: array![0.31, -0.27, 0.19, -0.11],
d_array: Some(array![0.17, -0.11, 0.23, 0.07]),
x_transformed: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x)),
};
let h_proj = h[[1, 1]];
let solution = InnerSolution {
log_likelihood: -2.3,
penalty_quadratic: 0.6,
hessian_op: hop.clone(),
beta: array![0.4, -0.7],
penalty_coords: vec![pcoord; k],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: Array1::zeros(k),
second: Some(Array2::zeros((k, k))),
},
deriv_provider: Box::new(deriv_provider),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
// Swaps the full logdet for the projected scalar logdet.
hessian_logdet_correction: h_proj.ln() - hop.logdet(),
penalty_subspace_trace: Some(Arc::new(PenaltySubspaceTrace {
u_s: array![[0.0], [1.0]],
h_proj_inverse: array![[1.0 / h_proj]],
})),
rho_curvature_scale: 1.0,
n_observations: 4,
nullspace_dim: 1.0,
dispersion: DispersionHandling::Fixed {
phi: 1.0,
include_logdet_h: true,
include_logdet_s: true,
},
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
};
let rho = vec![0.0_f64; k];
let result =
reml_laml_evaluate(&solution, &rho, EvalMode::ValueGradientHessian, None).unwrap();
// Only the dispatch decision is asserted here, not numerical values.
assert!(
matches!(
result.hessian,
crate::solver::outer_strategy::HessianResult::Operator(_)
),
"large-k subspace-trace case should use projected outer Hessian operator"
);
}
#[test]
fn test_dense_spectral_operator_singular() {
    // Rank-one matrix: eigenvalues are exactly 0 and 2, so the operator must
    // regularize the zero mode rather than blow up.
    let h = array![[1.0, 1.0], [1.0, 1.0]];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    // logdet should agree with the log of the regularized spectrum.
    let epsilon = spectral_epsilon(&[0.0, 2.0]);
    let expected_logdet =
        spectral_regularize(0.0, epsilon).ln() + spectral_regularize(2.0, epsilon).ln();
    assert!((op.logdet() - expected_logdet).abs() < 1e-10);
    // tr(H^{-1} I) stays finite despite the zero eigenvalue.
    assert!(op.trace_hinv_product(&Array2::eye(2)).is_finite());
}
#[test]
fn test_spectral_regularize_stays_finite_in_extreme_tails() {
    let epsilon = 1e-8;
    // Deep negative tail: must clamp to something positive and finite.
    let large_negative = spectral_regularize(-1e16, epsilon);
    let negative_ok = large_negative.is_finite() && large_negative > 0.0;
    assert!(
        negative_ok,
        "large negative sigma should regularize to a positive finite value, got {large_negative}"
    );
    // Near-overflow positive tail: must not overflow to infinity.
    let large_positive = spectral_regularize(1e308, epsilon);
    let positive_ok = large_positive.is_finite() && large_positive > 0.0;
    assert!(
        positive_ok,
        "large positive sigma should stay finite, got {large_positive}"
    );
}
#[test]
fn test_smooth_floor_dp() {
    // Far above the floor, the transform is approximately the identity with
    // unit slope.
    let (value, slope, _) = smooth_floor_dp(1.0);
    assert!((value - 1.0).abs() < 1e-6);
    assert!((slope - 1.0).abs() < 1e-6);
    // Exactly at the floor: output is lifted above it and the slope is close
    // to one half.
    let (value, slope, _) = smooth_floor_dp(DP_FLOOR);
    assert!(value > DP_FLOOR);
    assert!((slope - 0.5).abs() < 0.1);
    // Below the floor, the output never drops under DP_FLOOR.
    let (value, _, _) = smooth_floor_dp(0.0);
    assert!(value >= DP_FLOOR);
}
#[test]
fn test_gaussian_derivatives_has_no_corrections() {
    let provider = GaussianDerivatives;
    // The Gaussian family advertises no Hessian-derivative corrections...
    assert!(!provider.has_corrections());
    // ...and the correction query itself succeeds with None.
    let correction = provider
        .hessian_derivative_correction(&array![1.0, 2.0])
        .unwrap();
    assert!(correction.is_none());
}
#[test]
fn gaussian_derivatives_advertise_exact_outer_hvp_kernel() {
    // The Gaussian family must expose its dedicated exact outer-Hessian
    // derivative kernel.
    let kernel = GaussianDerivatives.outer_hessian_derivative_kernel();
    assert!(matches!(kernel, Some(OuterHessianDerivativeKernel::Gaussian)));
}
#[test]
fn standard_gam_large_n_gaussian_prefers_operator_when_dense_work_is_large() {
    // A large-n problem (320k observations, 42 coefficients, 6 penalties)
    // should route to the matrix-free operator path...
    let prefers_operator = prefer_outer_hessian_operator(320_000, 42, 6);
    assert!(prefers_operator);
    // ...which is viable because the Gaussian family supplies an exact
    // outer-Hessian derivative kernel.
    let kernel = GaussianDerivatives.outer_hessian_derivative_kernel();
    assert!(matches!(kernel, Some(OuterHessianDerivativeKernel::Gaussian)));
}
#[test]
// For a profiled-Gaussian model with two penalties, the Gaussian-kernel outer
// Hessian operator, when materialized, must match the dense assembly.
fn gaussian_outer_hessian_operator_matches_dense_assembly() {
let h = array![[2.4, 0.2], [0.2, 1.7]];
let hop = Arc::new(DenseSpectralOperator::from_symmetric(&h).unwrap());
let beta = array![0.35, -0.55];
let penalty_root_0 = array![[1.0, 0.2], [0.0, 0.4]];
let penalty_root_1 = array![[0.3, -0.1], [0.0, 0.9]];
let solution = InnerSolution {
log_likelihood: -8.0,
penalty_quadratic: 0.9,
hessian_op: hop.clone(),
beta,
penalty_coords: vec![
PenaltyCoordinate::from_dense_root(penalty_root_0),
PenaltyCoordinate::from_dense_root(penalty_root_1),
],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: array![0.8, 0.6],
second: Some(array![[0.11, 0.03], [0.03, 0.17]]),
},
deriv_provider: Box::new(GaussianDerivatives),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
// Large n mirrors the regime where the operator path is preferred.
n_observations: 320_000,
nullspace_dim: 1.0,
dispersion: DispersionHandling::ProfiledGaussian,
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
};
let rho: Vec<f64> = vec![0.2_f64, -0.4_f64];
let lambdas: Vec<f64> = rho.iter().map(|value| value.exp()).collect();
// Dense reference assembly.
let dense = compute_outer_hessian(
&solution,
&rho,
&lambdas,
solution.hessian_op.as_ref(),
solution.deriv_provider.as_ref(),
None,
)
.unwrap();
let kernel = solution
.deriv_provider
.outer_hessian_derivative_kernel()
.unwrap();
// Operator under test, then materialized entry-by-entry.
let operator = build_outer_hessian_operator(
&solution,
&lambdas,
solution.deriv_provider.as_ref(),
kernel,
None,
None,
)
.unwrap();
let materialized =
crate::solver::outer_strategy::OuterHessianOperator::materialize_dense(&operator)
.unwrap();
// Mixed absolute/relative tolerance per entry.
for row in 0..dense.nrows() {
for col in 0..dense.ncols() {
let expected = dense[[row, col]];
let actual = materialized[[row, col]];
let tolerance = 1e-10_f64.max(1e-10 * expected.abs());
assert!(
(actual - expected).abs() <= tolerance,
"Gaussian outer Hessian operator mismatch at ({row}, {col}): materialized={actual}, dense={expected}"
);
}
}
}
#[test]
// Regression test for the EFS update formula on a 1-D problem: at the
// optimum (zero gradient) the step must be exactly zero, and off-optimum
// the step must follow ln(1 - 2 g / q_eff).
fn efs_step_is_zero_at_scalar_optimum() {
let lambda = 1.0 / 3.0;
let beta_hat = 1.5_f64;
// Scalar penalized Hessian H = 1 + lambda.
let h = Array2::from_shape_vec((1, 1), vec![1.0 + lambda]).unwrap();
let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
let penalty_root = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
let solution = InnerSolution {
log_likelihood: 0.0,
penalty_quadratic: 0.0,
hessian_op: Arc::new(op),
beta: array![beta_hat],
penalty_coords: vec![PenaltyCoordinate::from_dense_root(penalty_root)],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: array![1.0],
second: None,
},
deriv_provider: Box::new(GaussianDerivatives),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
n_observations: 10,
nullspace_dim: 0.0,
dispersion: DispersionHandling::Fixed {
phi: 1.0,
include_logdet_h: true,
include_logdet_s: true,
},
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
};
let rho = [lambda.ln()];
// At the optimum the outer gradient is zero by definition.
let gradient_at_optimum = [0.0_f64];
let steps = compute_efs_update(&solution, &rho, &gradient_at_optimum);
assert_eq!(steps.len(), 1);
assert!(
steps[0].abs() < 1e-12,
"EFS step at scalar optimum should be exactly 0, got {} (old buggy formula returned ~+5)",
steps[0]
);
// Off-optimum: q_eff = lambda * beta^2 and the step should be
// ln(1 - 2 g / q_eff) for a small positive gradient g.
let q_eff = lambda * beta_hat * beta_hat; let g_off = 0.1_f64;
let steps_off = compute_efs_update(&solution, &rho, &[g_off]);
let expected = (1.0_f64 - 2.0 * g_off / q_eff).ln();
assert!(
(steps_off[0] - expected).abs() < 1e-12,
"off-optimum EFS step {} != expected {}",
steps_off[0],
expected
);
}
#[test]
// Exercises efs_log_step_from_grad across regimes: canonical recovery,
// augmented optimum, zero gradient, over-correction clamping, near-singular
// input, and invalid-argument rejection.
fn efs_log_step_from_grad_recovers_canonical_form() {
// (q_eff, target) pairs with gradient g = (q_eff - target)/2 should recover
// the canonical step ln(target / q_eff), clamped to +/- EFS_MAX_STEP.
let cases = [
(1.0_f64, 0.5),
(2.0, 1.5),
(0.75, 0.75),
(4.0, 0.1),
(1.0, 0.999),
];
for (q_eff, target) in cases {
let g_base = (q_eff - target) / 2.0;
let universal = efs_log_step_from_grad(q_eff, g_base).unwrap();
let canonical = (target / q_eff).ln().clamp(-EFS_MAX_STEP, EFS_MAX_STEP);
assert!(
(universal - canonical).abs() < 1e-12,
"universal {universal} ≠ canonical {canonical} at q={q_eff}, t={target}"
);
}
// At the optimum of an augmented problem (extra gradient term g_extra) the
// full gradient vanishes and the step must be zero.
let target = 0.6_f64;
let g_extra = -0.7_f64;
let augmented_q = target - 2.0 * g_extra;
let g_full_at_aug_opt = (augmented_q - target) / 2.0 + g_extra;
assert!(g_full_at_aug_opt.abs() < 1e-12);
let s_at_opt = efs_log_step_from_grad(augmented_q, g_full_at_aug_opt).unwrap();
assert!(
s_at_opt.abs() < 1e-12,
"Δρ at augmented optimum != 0: {s_at_opt}"
);
// Spot checks: stable regime and zero gradient.
let s = efs_log_step_from_grad(2.0, 0.75).expect("stable regime");
assert!((s - (0.25_f64).ln()).abs() < 1e-12);
let s = efs_log_step_from_grad(0.75, 0.0).expect("zero gradient");
assert!(s.abs() < 1e-12);
// Over-correction (2g >= q_eff) must clamp to -EFS_MAX_STEP.
for &(q_eff, g) in &[(1.0_f64, 0.6), (2.0, 1.5), (0.5, 1e6)] {
let s = efs_log_step_from_grad(q_eff, g).expect("over-correction");
assert!((s - (-EFS_MAX_STEP)).abs() < 1e-12);
}
// NOTE(review): the `|| s.is_finite()` arm makes this disjunction pass for
// ANY finite s, so only the `s <= 0.0` check below is binding here —
// consider tightening this assertion.
let s = efs_log_step_from_grad(1.0, 0.5 - 1e-30).expect("near-singular");
assert!((s - (-EFS_MAX_STEP)).abs() < 1e-12 || s == 0.5 * (-EFS_MAX_STEP) || s.is_finite());
assert!(s <= 0.0);
// Invalid inputs (non-positive q_eff, NaN/infinite arguments) return None.
assert!(efs_log_step_from_grad(0.0, 0.0).is_none());
assert!(efs_log_step_from_grad(-1.0, 0.0).is_none());
assert!(efs_log_step_from_grad(f64::NAN, 0.0).is_none());
assert!(efs_log_step_from_grad(1.0, f64::NAN).is_none());
assert!(efs_log_step_from_grad(1.0, f64::INFINITY).is_none());
}
#[test]
fn dense_spectral_block_local_cross_trace_matches_dense() {
    let h = array![[4.0, 1.0, 0.5], [1.0, 3.0, 0.25], [0.5, 0.25, 2.0],];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let block = array![[1.5, 0.4], [0.4, 0.7]];
    let scale = 1.7_f64;
    // Embed the scaled 2x2 block in the top-left corner of a full 3x3 matrix
    // (offset 0, size 2 — matching the arguments passed below).
    let mut a_full = Array2::<f64>::zeros((3, 3));
    for i in 0..2 {
        for j in 0..2 {
            a_full[[i, j]] = block[[i, j]] * scale;
        }
    }
    // Dense reference for tr((H^{-1}A)(H^{-1}A)) using the elementwise
    // identity tr(M M) = sum(Mᵀ ∘ M).
    let hinva = op.solve_multi(&a_full);
    let expected = (&hinva.t() * &hinva).sum();
    // Block-local fast path under test.
    let got = op.trace_hinv_block_local_cross(&block, scale, 0, 2);
    assert!(
        (got - expected).abs() < 1e-10,
        "block-local cross trace = {got}, expected = {expected} (delta {})",
        got - expected
    );
}
#[test]
// Smoke test: reml_laml_evaluate on a minimal profiled-Gaussian solution
// returns a finite cost, and requests/omits the gradient per EvalMode.
fn test_reml_laml_evaluate_gaussian_basic() {
let h = Array2::from_diag(&array![10.0, 8.0]);
let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
let solution = InnerSolution {
log_likelihood: -5.0, penalty_quadratic: 2.0,
hessian_op: Arc::new(op),
beta: array![1.0, 0.5],
penalty_coords: vec![PenaltyCoordinate::from_dense_root(
Array2::eye(2), )],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: array![1.0],
second: None,
},
deriv_provider: Box::new(GaussianDerivatives),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
n_observations: 100,
nullspace_dim: 0.0,
dispersion: DispersionHandling::ProfiledGaussian,
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
};
let rho = [0.0];
// ValueOnly: finite cost, no gradient.
let result = reml_laml_evaluate(&solution, &rho, EvalMode::ValueOnly, None).unwrap();
assert!(result.cost.is_finite());
assert!(result.gradient.is_none());
// ValueAndGradient: finite cost plus a finite, length-1 gradient.
let result = reml_laml_evaluate(&solution, &rho, EvalMode::ValueAndGradient, None).unwrap();
assert!(result.cost.is_finite());
assert!(result.gradient.is_some());
let grad = result.gradient.unwrap();
assert_eq!(grad.len(), 1);
assert!(grad[0].is_finite());
}
#[test]
// With every other cost contribution zeroed out (identity Hessian, no
// penalties, logdet(S) excluded), the fixed-dispersion cost should reduce to
// exactly minus the Jeffreys log-determinant supplied by the Firth operator.
fn fixed_dispersion_firth_cost_subtracts_jeffreys_term() {
let x = array![[1.0, 0.0], [1.0, 1.0], [1.0, -1.0]];
let eta = array![0.0, 0.4, -0.2];
let firth_op = std::sync::Arc::new(
super::super::FirthDenseOperator::build(&x, &eta).expect("firth operator"),
);
// Reference value pulled from the same operator the solution carries.
let firth_value = firth_op.jeffreys_logdet();
let solution = InnerSolution {
log_likelihood: 0.0,
penalty_quadratic: 0.0,
// Identity Hessian: logdet(H) contributes zero.
hessian_op: Arc::new(DenseSpectralOperator::from_symmetric(&Array2::eye(2)).unwrap()),
beta: Array1::zeros(2),
penalty_coords: Vec::new(),
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: Array1::zeros(0),
second: None,
},
deriv_provider: Box::new(GaussianDerivatives),
tk_correction: 0.0,
tk_gradient: None,
firth: Some(ExactJeffreysTerm::new(firth_op)),
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
n_observations: x.nrows(),
nullspace_dim: 0.0,
dispersion: DispersionHandling::Fixed {
phi: 1.0,
include_logdet_h: true,
include_logdet_s: false,
},
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
};
// No penalties, so rho is empty.
let result = reml_laml_evaluate(&solution, &[], EvalMode::ValueOnly, None).unwrap();
assert_relative_eq!(result.cost, -firth_value, epsilon = 1e-12);
}
/// Test double: an outer-Hessian operator backed by an explicit dense matrix.
struct FixedOuterHessianOperator {
    matrix: Array2<f64>,
}
impl crate::solver::outer_strategy::OuterHessianOperator for FixedOuterHessianOperator {
    fn dim(&self) -> usize {
        self.matrix.nrows()
    }
    fn matvec(&self, v: &Array1<f64>) -> Result<Array1<f64>, String> {
        // Reject dimension mismatches with a descriptive error.
        let expected = self.dim();
        if v.len() == expected {
            Ok(self.matrix.dot(v))
        } else {
            Err(format!(
                "fixed test outer Hessian dimension mismatch: got {}, expected {}",
                v.len(),
                expected
            ))
        }
    }
    fn is_cheap_to_materialize(&self) -> bool {
        // The dense matrix is already in hand, so materialization is free.
        true
    }
}
/// Test double: a derivative provider that hands out a pre-built family
/// outer-Hessian operator and panics if pairwise assembly is attempted.
struct FamilyOperatorDerivatives {
    op: Arc<dyn crate::solver::outer_strategy::OuterHessianOperator>,
}
impl HessianDerivativeProvider for FamilyOperatorDerivatives {
    fn has_corrections(&self) -> bool {
        false
    }
    // The next two methods must never be reached when the family operator
    // short-circuit is taken; panicking makes a regression loud.
    fn hessian_derivative_correction(
        &self,
        _: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        panic!("family operator dispatch should not request pairwise first derivatives")
    }
    fn hessian_second_derivative_correction(
        &self,
        _: &Array1<f64>,
        _: &Array1<f64>,
        _: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        panic!("family operator dispatch should not request pairwise second derivatives")
    }
    fn family_outer_hessian_operator(
        &self,
    ) -> Option<Arc<dyn crate::solver::outer_strategy::OuterHessianOperator>> {
        Some(self.op.clone())
    }
}
#[test]
// When the derivative provider supplies a family outer-Hessian operator,
// evaluation must return it directly instead of assembling dense pairwise
// derivatives (the provider panics if the dense path is touched).
fn family_outer_hessian_operator_short_circuits_dense_pairwise_assembly() {
// 1x1 operator with the single entry 2.5.
let supplied = array![[2.5]];
let provider_op: Arc<dyn crate::solver::outer_strategy::OuterHessianOperator> =
Arc::new(FixedOuterHessianOperator {
matrix: supplied.clone(),
});
let solution = InnerSolution {
log_likelihood: 0.0,
penalty_quadratic: 0.4,
hessian_op: Arc::new(DenseSpectralOperator::from_symmetric(&array![[3.0]]).unwrap()),
beta: array![0.2],
penalty_coords: vec![PenaltyCoordinate::from_dense_root(array![[1.0]])],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: array![1.0],
second: Some(array![[0.0]]),
},
deriv_provider: Box::new(FamilyOperatorDerivatives { op: provider_op }),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
n_observations: 1,
nullspace_dim: 0.0,
dispersion: DispersionHandling::Fixed {
phi: 1.0,
include_logdet_h: true,
include_logdet_s: true,
},
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
};
let result =
reml_laml_evaluate(&solution, &[0.0], EvalMode::ValueGradientHessian, None).unwrap();
// Must be the operator variant, not a dense Hessian.
let crate::solver::outer_strategy::HessianResult::Operator(op) = result.hessian else {
panic!("expected family-supplied operator Hessian");
};
assert_eq!(op.dim(), 1);
// HVP: 2.5 * 4.0 = 10.0 confirms the supplied matrix is the one in use.
let hv = op.matvec(&array![4.0]).unwrap();
assert_relative_eq!(hv[0], 10.0, epsilon = 1e-12);
let dense = op.materialize_dense().unwrap();
assert_relative_eq!(dense[[0, 0]], supplied[[0, 0]], epsilon = 1e-12);
}
/// Test double: a derivative provider that always returns the same
/// first-derivative correction matrix.
struct FixedCorrectionDerivatives {
    correction: Array2<f64>,
}
impl HessianDerivativeProvider for FixedCorrectionDerivatives {
    fn has_corrections(&self) -> bool {
        true
    }
    fn hessian_derivative_correction(
        &self,
        _: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        // Constant correction, independent of the query point.
        Ok(Some(self.correction.clone()))
    }
}
// Builds a 2-D fixture whose Hessian depends on rho both linearly (3 + 4*rho
// in the unpenalized coordinate) and through lambda = exp(rho) in the
// penalized coordinate. The fixed correction matrix [[4, 0], [0, 0]] is the
// rho-derivative of the linear term, and the subspace trace projects onto the
// penalized (second) coordinate only. Used by the projected-gradient tests
// below, which also vary rho to take finite differences.
fn build_projected_rho_gradient_solution(rho: f64) -> InnerSolution<'static> {
let lambda = rho.exp();
let h = array![[3.0 + 4.0 * rho, 0.0], [0.0, 5.0 + lambda],];
let full_logdet = h[[0, 0]].ln() + h[[1, 1]].ln();
let projected_logdet = h[[1, 1]].ln();
InnerSolution {
log_likelihood: 0.0,
penalty_quadratic: 0.0,
hessian_op: Arc::new(
DenseSpectralOperator::from_symmetric_with_mode(&h, PseudoLogdetMode::HardPseudo)
.unwrap(),
),
beta: Array1::zeros(2),
// Penalty acts on the second coordinate only.
penalty_coords: vec![PenaltyCoordinate::from_dense_root(array![[0.0, 1.0]])],
penalty_logdet: PenaltyLogdetDerivs {
value: 0.0,
first: array![0.0],
second: None,
},
// dH/drho of the 3 + 4*rho entry.
deriv_provider: Box::new(FixedCorrectionDerivatives {
correction: array![[4.0, 0.0], [0.0, 0.0]],
}),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
// Replaces the full logdet with the projected (penalized-block) logdet.
hessian_logdet_correction: projected_logdet - full_logdet,
penalty_subspace_trace: Some(Arc::new(PenaltySubspaceTrace {
u_s: array![[0.0], [1.0]],
h_proj_inverse: array![[1.0 / h[[1, 1]]]],
})),
rho_curvature_scale: 1.0,
n_observations: 10,
nullspace_dim: 1.0,
dispersion: DispersionHandling::Fixed {
phi: 1.0,
include_logdet_h: true,
include_logdet_s: false,
},
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
}
}
#[test]
// The analytic rho gradient must (a) match a central finite difference of the
// cost and (b) differ from the full-space trace, proving the projected logdet
// kernel is actually used when a subspace trace is available.
fn test_rho_gradient_uses_projected_logdet_kernel_when_available() {
let rho = 0.0;
let result = reml_laml_evaluate(
&build_projected_rho_gradient_solution(rho),
&[rho],
EvalMode::ValueAndGradient,
None,
)
.unwrap();
let analytic = result.gradient.expect("gradient")[0];
// Central finite difference: rebuild the rho-dependent fixture at rho ± eps.
let eps = 1e-6;
let rho_plus = rho + eps;
let cost_plus = reml_laml_evaluate(
&build_projected_rho_gradient_solution(rho_plus),
&[rho_plus],
EvalMode::ValueOnly,
None,
)
.unwrap()
.cost;
let rho_minus = rho - eps;
let cost_minus = reml_laml_evaluate(
&build_projected_rho_gradient_solution(rho_minus),
&[rho_minus],
EvalMode::ValueOnly,
None,
)
.unwrap()
.cost;
let fd = (cost_plus - cost_minus) / (2.0 * eps);
assert_relative_eq!(analytic, fd, epsilon = 1e-8, max_relative = 1e-8);
// Full-space trace at rho = 0: 4/(3 + 4*0) + exp(0)/(5 + exp(0)) = 4/3 + 1/6.
// The projected gradient must NOT reduce to half of this value.
let full_space_trace = 4.0 / 3.0 + 1.0 / 6.0;
assert!(
(analytic - 0.5 * full_space_trace).abs() > 0.5,
"projected rho trace should exclude the null-space leakage term"
);
}
#[test]
fn test_rho_corrections_serial_large_work_case_stays_finite() {
    // Reuse the projected-gradient fixture, but inflate the observation count
    // so the rho-correction evaluation takes its large-work branch; the
    // result must still be finite and well-formed.
    let rho_values = [0.0_f64];
    let mut solution = build_projected_rho_gradient_solution(rho_values[0]);
    solution.n_observations = 40_000_000;
    let result = reml_laml_evaluate(&solution, &rho_values, EvalMode::ValueAndGradient, None)
        .expect("serial rho correction evaluation");
    assert!(result.cost.is_finite());
    let gradient = result.gradient.expect("gradient");
    assert_eq!(gradient.len(), 1);
    assert!(gradient[0].is_finite());
}
// Builds a 3-coefficient, two-penalty profiled-Gaussian fixture from explicit
// sufficient statistics (XᵀX, Xᵀy, yᵀy). The penalty logdet value is computed
// exactly from the eigenvalues of S(rho) = λ1 S1 + λ2 S2, while its first
// derivatives are obtained by central finite differences in rho.
fn build_gaussian_test_solution(rho: &[f64]) -> InnerSolution<'_> {
let p = 3; let n = 50;
let xtx = array![[10.0, 2.0, 1.0], [2.0, 8.0, 0.5], [1.0, 0.5, 6.0],];
// S1 penalizes the first two coordinates, S2 the third.
let s1 = array![[1.0, 0.2, 0.0], [0.2, 1.0, 0.0], [0.0, 0.0, 0.0],];
let s2 = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0],];
let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
// Penalized Hessian H = XᵀX + λ1 S1 + λ2 S2.
let mut h = xtx.clone();
h.scaled_add(lambdas[0], &s1);
h.scaled_add(lambdas[1], &s2);
let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
let xty = array![5.0, 3.0, 2.0];
// beta solves the penalized normal equations H beta = Xᵀy.
let beta = op.solve(&xty);
let r1 = penalty_matrix_root(&s1).unwrap();
let r2 = penalty_matrix_root(&s2).unwrap();
let penalty_quad =
lambdas[0] * beta.dot(&s1.dot(&beta)) + lambdas[1] * beta.dot(&s2.dot(&beta));
// Gaussian deviance from the sufficient statistics.
let yty = 20.0;
let deviance = yty - 2.0 * beta.dot(&xty) + beta.dot(&xtx.dot(&beta));
let log_likelihood = -0.5 * deviance;
// Exact pseudo-logdet of S(rho) via its eigenvalues.
let mut s_total = Array2::zeros((p, p));
s_total.scaled_add(lambdas[0], &s1);
s_total.scaled_add(lambdas[1], &s2);
let (s_eigs, _) = s_total.eigh(faer::Side::Lower).unwrap();
let threshold = positive_eigenvalue_threshold(s_eigs.as_slice().unwrap());
let log_det_s = exact_pseudo_logdet(s_eigs.as_slice().unwrap(), threshold);
// First derivatives of log|S|+ w.r.t. each rho[k] by central differences.
let mut det1 = Array1::zeros(rho.len());
let eps = 1e-7;
for k in 0..rho.len() {
let mut rho_plus = rho.to_vec();
rho_plus[k] += eps;
let lambdas_plus: Vec<f64> = rho_plus.iter().map(|&r| r.exp()).collect();
let mut s_plus = Array2::zeros((p, p));
s_plus.scaled_add(lambdas_plus[0], &s1);
s_plus.scaled_add(lambdas_plus[1], &s2);
let (s_eigs_plus, _) = s_plus.eigh(faer::Side::Lower).unwrap();
let threshold_plus = positive_eigenvalue_threshold(s_eigs_plus.as_slice().unwrap());
let log_det_s_plus =
exact_pseudo_logdet(s_eigs_plus.as_slice().unwrap(), threshold_plus);
let mut rho_minus = rho.to_vec();
rho_minus[k] -= eps;
let lambdas_minus: Vec<f64> = rho_minus.iter().map(|&r| r.exp()).collect();
let mut s_minus = Array2::zeros((p, p));
s_minus.scaled_add(lambdas_minus[0], &s1);
s_minus.scaled_add(lambdas_minus[1], &s2);
let (s_eigs_minus, _) = s_minus.eigh(faer::Side::Lower).unwrap();
let threshold_minus = positive_eigenvalue_threshold(s_eigs_minus.as_slice().unwrap());
let log_det_s_minus =
exact_pseudo_logdet(s_eigs_minus.as_slice().unwrap(), threshold_minus);
det1[k] = (log_det_s_plus - log_det_s_minus) / (2.0 * eps);
}
InnerSolution {
log_likelihood,
penalty_quadratic: penalty_quad,
hessian_op: Arc::new(op),
beta,
penalty_coords: vec![
PenaltyCoordinate::from_dense_root(r1),
PenaltyCoordinate::from_dense_root(r2),
],
penalty_logdet: PenaltyLogdetDerivs {
value: log_det_s,
first: det1,
second: None,
},
deriv_provider: Box::new(GaussianDerivatives),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
n_observations: n,
nullspace_dim: 0.0,
dispersion: DispersionHandling::ProfiledGaussian,
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
barrier_config: None,
}
}
/// Builds a p = 520 Gaussian `InnerSolution` with a diagonal XᵀX and an
/// identity (ridge) penalty. The diagonal structure keeps exact values cheap
/// while the dimension exercises the dense spectral operator at large p.
fn build_large_dense_spectral_gaussian_solution(rho: f64) -> InnerSolution<'static> {
    let p = 520usize;
    let n = 2 * p;
    let lambda = rho.exp();
    // Strictly increasing positive diagonal keeps H well conditioned.
    let xtx_diag = Array1::from_shape_fn(p, |i| 5.0 + 0.01 * (i as f64));
    let xtx = Array2::from_diag(&xtx_diag);
    let penalty = Array2::<f64>::eye(p);
    // Penalized Hessian H = XᵀX + λI.
    let mut h = xtx.clone();
    h.scaled_add(lambda, &penalty);
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let xty = Array1::from_shape_fn(p, |i| 1.0 + 0.002 * (i as f64));
    let beta = op.solve(&xty);
    // Ridge penalty: βᵀ(λI)β = λ‖β‖².
    let penalty_quad = lambda * beta.dot(&beta);
    let yty = 10.0 * (p as f64);
    let deviance = yty - 2.0 * beta.dot(&xty) + beta.dot(&xtx.dot(&beta));
    let log_likelihood = -0.5 * deviance;
    InnerSolution {
        log_likelihood,
        penalty_quadratic: penalty_quad,
        hessian_op: Arc::new(op),
        beta,
        penalty_coords: vec![PenaltyCoordinate::from_dense_root(Array2::<f64>::eye(p))],
        // For S = e^ρ·I: log|S| = p·ρ and d log|S|/dρ = p, exactly.
        penalty_logdet: PenaltyLogdetDerivs {
            value: (p as f64) * rho,
            first: array![p as f64],
            second: None,
        },
        deriv_provider: Box::new(GaussianDerivatives),
        // All optional correction/extension channels disabled for this fixture.
        tk_correction: 0.0,
        tk_gradient: None,
        firth: None,
        hessian_logdet_correction: 0.0,
        penalty_subspace_trace: None,
        rho_curvature_scale: 1.0,
        n_observations: n,
        nullspace_dim: 0.0,
        dispersion: DispersionHandling::ProfiledGaussian,
        ext_coords: Vec::new(),
        ext_coord_pair_fn: None,
        rho_ext_pair_fn: None,
        fixed_drift_deriv: None,
        barrier_config: None,
    }
}
#[test]
fn test_gaussian_reml_fd_vs_analytic_gradient() {
    // Analytic outer REML/LAML gradient vs central finite differences of the
    // cost. The inner solution is rebuilt at each shifted rho so the FD
    // captures the full dependence of the cost on the smoothing parameters.
    let rho = vec![1.0, -0.5];
    let solution = build_gaussian_test_solution(&rho);
    let result = reml_laml_evaluate(&solution, &rho, EvalMode::ValueAndGradient, None).unwrap();
    let analytic_grad = result.gradient.unwrap();
    let eps = 1e-5;
    let mut fd_grad = Array1::zeros(rho.len());
    for k in 0..rho.len() {
        let mut rho_plus = rho.clone();
        rho_plus[k] += eps;
        let sol_plus = build_gaussian_test_solution(&rho_plus);
        let cost_plus = reml_laml_evaluate(&sol_plus, &rho_plus, EvalMode::ValueOnly, None)
            .unwrap()
            .cost;
        let mut rho_minus = rho.clone();
        rho_minus[k] -= eps;
        let sol_minus = build_gaussian_test_solution(&rho_minus);
        let cost_minus = reml_laml_evaluate(&sol_minus, &rho_minus, EvalMode::ValueOnly, None)
            .unwrap()
            .cost;
        fd_grad[k] = (cost_plus - cost_minus) / (2.0 * eps);
    }
    for k in 0..rho.len() {
        // Relative error with a +1 floor so near-zero gradients don't blow
        // the ratio up.
        let abs_err = (analytic_grad[k] - fd_grad[k]).abs();
        let rel_err = abs_err / (1.0 + analytic_grad[k].abs());
        assert!(
            rel_err < 1e-4,
            "Gradient mismatch at k={}: analytic={:.8e}, fd={:.8e}, rel_err={:.3e}",
            k,
            analytic_grad[k],
            fd_grad[k],
            rel_err,
        );
    }
}
#[test]
fn test_stochastic_trace_estimator_accuracy() {
    // Stochastic estimates of tr(H⁻¹A_k) for two small symmetric matrices
    // (no Hutch++ sketch: hutchpp_sketch_dim = None) must land within 5% of
    // the exact spectral values.
    let h = array![[4.0, 1.0, 0.5], [1.0, 3.0, 0.2], [0.5, 0.2, 2.0],];
    let a1 = array![[1.0, 0.3, 0.0], [0.3, 0.5, 0.1], [0.0, 0.1, 0.2],];
    let a2 = array![[0.2, 0.0, 0.1], [0.0, 1.0, 0.4], [0.1, 0.4, 0.8],];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let exact1 = op.trace_hinv_product(&a1);
    let exact2 = op.trace_hinv_product(&a2);
    // Fixed seed keeps the probe sequence (and hence the test) deterministic.
    let config = StochasticTraceConfig {
        n_probes_min: 50,
        n_probes_max: 200,
        relative_tol: 0.005,
        tau_rel: 1e-10,
        solve_rel_tol: 1e-8,
        seed: 42,
        hutchpp_sketch_dim: None,
    };
    let estimator = StochasticTraceEstimator::new(config);
    let matrices: Vec<&Array2<f64>> = vec![&a1, &a2];
    let estimates = estimator.estimate_traces(&op, &matrices);
    let rel_err1 = (estimates[0] - exact1).abs() / exact1.abs().max(1e-10);
    let rel_err2 = (estimates[1] - exact2).abs() / exact2.abs().max(1e-10);
    assert!(
        rel_err1 < 0.05,
        "Stochastic trace 1: est={:.6}, exact={:.6}, rel_err={:.4}",
        estimates[0],
        exact1,
        rel_err1,
    );
    assert!(
        rel_err2 < 0.05,
        "Stochastic trace 2: est={:.6}, exact={:.6}, rel_err={:.4}",
        estimates[1],
        exact2,
        rel_err2,
    );
}
#[test]
fn modified_gram_schmidt_orthonormalizes_well_conditioned_input() {
    // Upper-triangular input with well-separated columns: MGS must keep all
    // four columns and leave Q orthonormal.
    let y = array![
        [1.0, 2.0, 0.5, 3.0],
        [0.0, 1.0, 0.5, 1.5],
        [0.0, 0.0, 1.0, 0.5],
        [0.0, 0.0, 0.0, 1.0],
    ];
    let mut q = Array2::<f64>::zeros(y.dim());
    let rank = modified_gram_schmidt(&y, &mut q);
    assert_eq!(rank, 4, "well-conditioned input should retain full rank");
    // Check QᵀQ = I entry by entry over the retained columns.
    for j in 0..rank {
        let col_j = q.column(j);
        for k in 0..rank {
            let dot = col_j.dot(&q.column(k));
            let expected = if j == k { 1.0 } else { 0.0 };
            assert!(
                (dot - expected).abs() < 1e-12,
                "QᵀQ off-identity at ({j},{k}): got {dot}",
            );
        }
    }
}
#[test]
fn modified_gram_schmidt_drops_redundant_columns() {
    // Column 2 duplicates column 0 and column 3 is twice column 1, so only
    // two independent directions survive.
    let y = array![
        [1.0, 2.0, 1.0, 4.0],
        [0.0, 1.0, 0.0, 2.0],
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
    ];
    let mut q = Array2::<f64>::zeros(y.dim());
    let rank = modified_gram_schmidt(&y, &mut q);
    assert_eq!(
        rank, 2,
        "two duplicate columns plus a zero-extension should drop to rank 2"
    );
    // The surviving columns must still form an orthonormal set.
    for j in 0..rank {
        for k in 0..rank {
            let inner = q.column(j).dot(&q.column(k));
            let target = if j == k { 1.0 } else { 0.0 };
            assert!((inner - target).abs() < 1e-12);
        }
    }
}
#[test]
fn hutchpp_estimate_trace_hinv_operator_matches_exact_within_tolerance() {
    // Block-diagonal SPD Hessian and a symmetric operator M with mild
    // cross-block coupling.
    let h = array![
        [4.0, 1.0, 0.5, 0.0, 0.0, 0.0],
        [1.0, 3.0, 0.2, 0.0, 0.0, 0.0],
        [0.5, 0.2, 2.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 5.0, 0.7, 0.1],
        [0.0, 0.0, 0.0, 0.7, 4.0, 0.3],
        [0.0, 0.0, 0.0, 0.1, 0.3, 3.0],
    ];
    let m = array![
        [1.0, 0.3, 0.0, 0.1, 0.0, 0.0],
        [0.3, 0.5, 0.1, 0.0, 0.2, 0.0],
        [0.0, 0.1, 0.2, 0.0, 0.0, 0.05],
        [0.1, 0.0, 0.0, 0.8, 0.2, 0.0],
        [0.0, 0.2, 0.0, 0.2, 0.6, 0.1],
        [0.0, 0.0, 0.05, 0.0, 0.1, 0.4],
    ];
    let hop = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let m_op = DenseMatrixHyperOperator { matrix: m.clone() };
    // Exact tr(H⁻¹M) from the dense spectral factorization.
    let exact = hop.trace_hinv_product(&m);
    let config = StochasticTraceConfig {
        n_probes_min: 12,
        n_probes_max: 64,
        relative_tol: 0.005,
        tau_rel: 1e-10,
        solve_rel_tol: 1e-10,
        seed: 0xABCDEF,
        hutchpp_sketch_dim: Some(3),
    };
    let est = hutchpp_estimate_trace_hinv_operator(&hop, &m_op, &config);
    let rel_err = (est - exact).abs() / exact.abs().max(1e-10);
    assert!(
        rel_err < 0.05,
        "Hutch++ trace est={est:.6} exact={exact:.6} rel_err={rel_err:.4}"
    );
    // Plain Hutchinson baseline: identical probe budget (`config` already has
    // n_probes_max = 64 — the previous redundant reassignment was removed),
    // only the low-rank sketch is disabled.
    let mut config_plain = config.clone();
    config_plain.hutchpp_sketch_dim = None;
    let est_plain = hutchpp_estimate_trace_hinv_operator(&hop, &m_op, &config_plain);
    let rel_err_plain = (est_plain - exact).abs() / exact.abs().max(1e-10);
    assert!(
        rel_err <= rel_err_plain * 2.0 + 0.01,
        "Hutch++ ({rel_err:.4}) should be competitive with Hutchinson ({rel_err_plain:.4})"
    );
}
#[test]
fn hutchpp_estimate_trace_hinv_op_squared_matches_exact() {
    // Hutch++ estimate of tr((H⁻¹A)²) against the exact dense cross trace,
    // plus a check that the estimator's wired entry point reproduces the
    // free-function result bitwise (same seed and config).
    let h = array![
        [4.0, 1.0, 0.5, 0.0, 0.0, 0.0],
        [1.0, 3.0, 0.2, 0.0, 0.0, 0.0],
        [0.5, 0.2, 2.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 5.0, 0.7, 0.1],
        [0.0, 0.0, 0.0, 0.7, 4.0, 0.3],
        [0.0, 0.0, 0.0, 0.1, 0.3, 3.0],
    ];
    let a = array![
        [1.0, 0.3, 0.0, 0.1, 0.0, 0.0],
        [0.3, 0.5, 0.1, 0.0, 0.2, 0.0],
        [0.0, 0.1, 0.2, 0.0, 0.0, 0.05],
        [0.1, 0.0, 0.0, 0.8, 0.2, 0.0],
        [0.0, 0.2, 0.0, 0.2, 0.6, 0.1],
        [0.0, 0.0, 0.05, 0.0, 0.1, 0.4],
    ];
    let hop = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let a_op = DenseMatrixHyperOperator { matrix: a.clone() };
    // Exact reference: tr(H⁻¹A · H⁻¹A).
    let exact = hop.trace_hinv_product_cross(&a, &a);
    let config = StochasticTraceConfig {
        n_probes_min: 16,
        n_probes_max: 96,
        relative_tol: 0.005,
        tau_rel: 1e-10,
        solve_rel_tol: 1e-10,
        seed: 0xC0FFEE,
        hutchpp_sketch_dim: Some(3),
    };
    let est = hutchpp_estimate_trace_hinv_op_squared(&hop, &a_op, &config);
    let rel_err = (est - exact).abs() / exact.abs().max(1e-10);
    assert!(
        rel_err < 0.05,
        "Hutch++ tr((H⁻¹A)²) est={est:.6} exact={exact:.6} rel_err={rel_err:.4}"
    );
    // The wired estimator path must take exactly the same code path with the
    // same seed/config, so it has to agree with `est` to 1e-12.
    let estimator = StochasticTraceEstimator::new(config.clone());
    let est_wired = estimator.estimate_second_order_single_dense(&hop, &a);
    let rel_err_wired = (est_wired - exact).abs() / exact.abs().max(1e-10);
    assert!(
        rel_err_wired < 0.05,
        "wired Hutch++ second-order est={est_wired:.6} exact={exact:.6} rel_err={rel_err_wired:.4}"
    );
    assert!(
        (est_wired - est).abs() <= 1e-12,
        "wired path must call hutchpp_estimate_trace_hinv_op_squared with the same seed/config"
    );
}
#[test]
fn hutchpp_estimate_trace_hinv_operator_cross_matches_exact() {
    // Hutch++ cross-trace estimate tr(H⁻¹A · H⁻¹B) with two distinct
    // symmetric operators, compared against the exact dense contraction.
    let h = array![
        [4.0, 1.0, 0.5, 0.0, 0.0, 0.0],
        [1.0, 3.0, 0.2, 0.0, 0.0, 0.0],
        [0.5, 0.2, 2.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 5.0, 0.7, 0.1],
        [0.0, 0.0, 0.0, 0.7, 4.0, 0.3],
        [0.0, 0.0, 0.0, 0.1, 0.3, 3.0],
    ];
    let a = array![
        [1.0, 0.3, 0.0, 0.1, 0.0, 0.0],
        [0.3, 0.5, 0.1, 0.0, 0.2, 0.0],
        [0.0, 0.1, 0.2, 0.0, 0.0, 0.05],
        [0.1, 0.0, 0.0, 0.8, 0.2, 0.0],
        [0.0, 0.2, 0.0, 0.2, 0.6, 0.1],
        [0.0, 0.0, 0.05, 0.0, 0.1, 0.4],
    ];
    let b = array![
        [0.5, 0.0, 0.1, 0.0, 0.05, 0.0],
        [0.0, 0.7, 0.0, 0.2, 0.0, 0.1],
        [0.1, 0.0, 0.4, 0.0, 0.15, 0.0],
        [0.0, 0.2, 0.0, 0.6, 0.0, 0.05],
        [0.05, 0.0, 0.15, 0.0, 0.3, 0.0],
        [0.0, 0.1, 0.0, 0.05, 0.0, 0.5],
    ];
    let hop = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let a_op = DenseMatrixHyperOperator { matrix: a.clone() };
    let b_op = DenseMatrixHyperOperator { matrix: b.clone() };
    let exact = hop.trace_hinv_product_cross(&a, &b);
    let config = StochasticTraceConfig {
        n_probes_min: 16,
        n_probes_max: 128,
        relative_tol: 0.005,
        tau_rel: 1e-10,
        solve_rel_tol: 1e-10,
        seed: 0xDEAD_BEEF,
        hutchpp_sketch_dim: Some(3),
    };
    let est = hutchpp_estimate_trace_hinv_operator_cross(&hop, &a_op, &b_op, &config);
    let rel_err = (est - exact).abs() / exact.abs().max(1e-10);
    // Slightly looser tolerance than the same-operator variants: the cross
    // estimator has two independent stochastic factors.
    assert!(
        rel_err < 0.07,
        "Hutch++ cross trace est={est:.6} exact={exact:.6} rel_err={rel_err:.4}"
    );
}
#[test]
fn trace_hinv_operator_cross_default_routes_implicit_to_hutchpp() {
    // An implicit operator at p ≥ 128 triggers the stochastic Hutch++ routing
    // in the default trait implementations (see `HessianOperator`); compare
    // those estimates against exact dense contractions.
    let p = 200usize;
    // Tridiagonal SPD Hessian.
    let mut h = Array2::<f64>::zeros((p, p));
    for i in 0..p {
        h[[i, i]] = 5.0 + (i as f64) * 0.01;
        if i + 1 < p {
            h[[i, i + 1]] = 0.2;
            h[[i + 1, i]] = 0.2;
        }
    }
    // Symmetric banded test matrix A.
    let mut a = Array2::<f64>::zeros((p, p));
    for i in 0..p {
        a[[i, i]] = 1.0 + 0.005 * (i as f64);
        if i + 2 < p {
            a[[i, i + 2]] = 0.1;
            a[[i + 2, i]] = 0.1;
        }
    }
    let hop = DenseSpectralOperator::from_symmetric(&h).unwrap();
    // Wraps a dense matrix but reports itself implicit, forcing the
    // stochastic path of the default trait methods.
    struct ImplicitDense(Array2<f64>);
    impl HyperOperator for ImplicitDense {
        fn dim(&self) -> usize {
            self.0.nrows()
        }
        fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
            let mut out = Array1::<f64>::zeros(self.0.nrows());
            dense_matvec_into(&self.0, v.view(), out.view_mut());
            out
        }
        fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
            dense_matvec_into(&self.0, v, out);
        }
        fn to_dense(&self) -> Array2<f64> {
            self.0.clone()
        }
        fn is_implicit(&self) -> bool {
            true
        }
    }
    let a_op = ImplicitDense(a.clone());
    let exact = hop.trace_hinv_product_cross(&a, &a);
    let est_same = hop.trace_hinv_operator_cross(&a_op, &a_op);
    assert!(est_same.is_finite(), "cross trace must be finite");
    let rel_err_same = (est_same - exact).abs() / exact.abs().max(1e-10);
    assert!(
        rel_err_same < 0.10,
        "default same-op cross routing est={est_same:.6} exact={exact:.6} rel_err={rel_err_same:.4}"
    );
    // Distinct second operator B to cover the a ≠ b branch.
    let mut b = Array2::<f64>::zeros((p, p));
    for i in 0..p {
        b[[i, i]] = 0.6 + 0.003 * (i as f64);
        if i + 1 < p {
            b[[i, i + 1]] = 0.05;
            b[[i + 1, i]] = 0.05;
        }
    }
    let b_op = ImplicitDense(b.clone());
    let exact_ab = hop.trace_hinv_product_cross(&a, &b);
    let est_ab = hop.trace_hinv_operator_cross(&a_op, &b_op);
    assert!(est_ab.is_finite(), "cross trace (a,b) must be finite");
    let rel_err_ab = (est_ab - exact_ab).abs() / exact_ab.abs().max(1e-10);
    assert!(
        rel_err_ab < 0.10,
        "default distinct-op cross routing est={est_ab:.6} exact={exact_ab:.6} rel_err={rel_err_ab:.4}"
    );
    // Same exact reference as above — reuse it instead of repeating the
    // p×p solve-and-contract.
    let exact_ma = exact_ab;
    let est_ma = hop.trace_hinv_matrix_operator_cross(&a, &b_op);
    assert!(est_ma.is_finite(), "matrix-op cross trace must be finite");
    let rel_err_ma = (est_ma - exact_ma).abs() / exact_ma.abs().max(1e-10);
    assert!(
        rel_err_ma < 0.10,
        "default matrix-operator cross routing est={est_ma:.6} exact={exact_ma:.6} rel_err={rel_err_ma:.4}"
    );
}
#[test]
fn dense_spectral_large_p_outer_gradient_matches_finite_difference() {
    // Single-penalty ridge problem at p = 520: the analytic outer gradient
    // must match a central finite difference of the cost, with the inner
    // solution rebuilt at each shifted rho.
    let rho = 0.2;
    let solution = build_large_dense_spectral_gaussian_solution(rho);
    let result =
        reml_laml_evaluate(&solution, &[rho], EvalMode::ValueAndGradient, None).unwrap();
    let analytic = result.gradient.expect("gradient")[0];
    let eps = 1e-5;
    let rho_plus = rho + eps;
    let solution_plus = build_large_dense_spectral_gaussian_solution(rho_plus);
    let cost_plus = reml_laml_evaluate(&solution_plus, &[rho_plus], EvalMode::ValueOnly, None)
        .unwrap()
        .cost;
    let rho_minus = rho - eps;
    let solution_minus = build_large_dense_spectral_gaussian_solution(rho_minus);
    let cost_minus =
        reml_laml_evaluate(&solution_minus, &[rho_minus], EvalMode::ValueOnly, None)
            .unwrap()
            .cost;
    let fd = (cost_plus - cost_minus) / (2.0 * eps);
    // Relative error with a +1 floor, as in the small-p gradient test.
    let rel_err = (analytic - fd).abs() / (1.0 + analytic.abs());
    assert!(
        rel_err < 2e-4,
        "large-p dense spectral gradient mismatch: analytic={analytic:.8e}, fd={fd:.8e}, rel_err={rel_err:.3e}"
    );
}
#[test]
fn dense_spectral_logdet_traces_do_not_claim_hinv_kernel_equivalence() {
    // Exact dense backends must not advertise the stochastic H⁻¹-kernel
    // shortcut: no stochastic-trace preference, no kernel-equivalence claim,
    // and the gate helper must reject them.
    let h = array![[4.0, 1.0], [1.0, 3.0]];
    let spectral = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let coupled = BlockCoupledOperator::from_joint_hessian(&h).unwrap();
    assert!(!spectral.prefers_stochastic_trace_estimation());
    assert!(!coupled.prefers_stochastic_trace_estimation());
    assert!(!spectral.logdet_traces_match_hinv_kernel());
    assert!(!coupled.logdet_traces_match_hinv_kernel());
    assert!(!can_use_stochastic_logdet_hinv_kernel(&spectral, 1024, true));
    assert!(!can_use_stochastic_logdet_hinv_kernel(&coupled, 1024, true));
}
#[test]
fn dense_spectral_hinv_cross_matches_solve_contraction() {
    // tr(H⁻¹A · H⁻¹B) must equal the explicit contraction of the two solved
    // multi-RHS systems X = H⁻¹A and Y = H⁻¹B, i.e. tr(X·Y).
    let h = array![[4.0, 1.0, 0.5], [1.0, 3.0, 0.25], [0.5, 0.25, 2.0],];
    let a = array![[1.0, 0.2, 0.1], [0.2, 0.5, 0.0], [0.1, 0.0, 0.3],];
    let b = array![[0.3, 0.1, 0.0], [0.1, 0.8, 0.2], [0.0, 0.2, 0.6],];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let exact = op.trace_hinv_product_cross(&a, &b);
    let lhs = op.solve_multi(&a);
    let rhs = op.solve_multi(&b);
    // tr(X·Y) = Σ_{i,j} X[j,i]·Y[i,j], accumulated entrywise.
    let mut reference = 0.0;
    for i in 0..lhs.nrows() {
        for j in 0..lhs.ncols() {
            reference += lhs[[j, i]] * rhs[[i, j]];
        }
    }
    assert_relative_eq!(exact, reference, epsilon = 1e-10, max_relative = 1e-10);
}
#[test]
fn dense_spectral_batched_logdet_crosses_match_pairwise() {
    // The batched cross-trace kernel must agree entrywise with the pairwise
    // single-cross calls over every (i, j) combination.
    let h = array![[4.0, 1.0, 0.5], [1.0, 3.0, 0.25], [0.5, 0.25, 2.0],];
    let h1 = array![[1.0, 0.2, 0.1], [0.2, 0.5, 0.0], [0.1, 0.0, 0.3],];
    let h2 = array![[0.3, 0.1, 0.0], [0.1, 0.8, 0.2], [0.0, 0.2, 0.6],];
    let h3 = array![[0.7, 0.0, 0.2], [0.0, 0.4, 0.1], [0.2, 0.1, 0.9],];
    let op = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let mats = [&h1, &h2, &h3];
    let batched = op.trace_logdet_hessian_crosses(&mats);
    for (i, &left) in mats.iter().enumerate() {
        for (j, &right) in mats.iter().enumerate() {
            let pairwise = op.trace_logdet_hessian_cross(left, right);
            assert_relative_eq!(
                batched[[i, j]],
                pairwise,
                epsilon = 1e-10,
                max_relative = 1e-10
            );
        }
    }
}
#[test]
fn sparse_block_local_trace_without_takahashi_matches_dense_reference() {
    // Block-local traces on the sparse Cholesky backend (no Takahashi
    // selected inverse attached) must match the dense reference where the
    // scaled block is embedded at [start, end) of an otherwise-zero matrix.
    let h = array![
        [5.0, 0.2, 0.0, 0.1],
        [0.2, 4.0, 0.3, 0.0],
        [0.0, 0.3, 3.0, 0.4],
        [0.1, 0.0, 0.4, 2.5],
    ];
    let h_sparse =
        crate::linalg::sparse_exact::dense_to_sparse_symmetric_upper(&h, 0.0).unwrap();
    let factor = std::sync::Arc::new(
        crate::linalg::sparse_exact::factorize_sparse_spd(&h_sparse).unwrap(),
    );
    let sparse = SparseCholeskyOperator::new(factor, 0.0, h.nrows());
    let dense = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let block = array![[0.8, 0.15], [0.15, 0.45]];
    let scale = 1.7;
    let start = 1;
    let end = 3;
    // Dense reference: `scale * block` embedded into the [start, end) window.
    let mut full = Array2::<f64>::zeros(h.raw_dim());
    for i in 0..block.nrows() {
        for j in 0..block.ncols() {
            full[[start + i, start + j]] = scale * block[[i, j]];
        }
    }
    assert_relative_eq!(
        sparse.trace_hinv_block_local(&block, scale, start, end),
        dense.trace_hinv_product(&full),
        epsilon = 1e-10,
        max_relative = 1e-10
    );
    assert_relative_eq!(
        sparse.trace_hinv_block_local_cross(&block, scale, start, end),
        dense.trace_hinv_product_cross(&full, &full),
        epsilon = 1e-10,
        max_relative = 1e-10
    );
}
#[test]
fn sparse_block_local_operator_cross_without_takahashi_matches_dense_reference() {
    // Cross trace of a block-local HyperOperator with itself on the sparse
    // Cholesky backend (no Takahashi inverse) must match the dense reference
    // with the local block embedded at [start, end).
    let h = array![
        [5.0, 0.2, 0.0, 0.1],
        [0.2, 4.0, 0.3, 0.0],
        [0.0, 0.3, 3.0, 0.4],
        [0.1, 0.0, 0.4, 2.5],
    ];
    let h_sparse =
        crate::linalg::sparse_exact::dense_to_sparse_symmetric_upper(&h, 0.0).unwrap();
    let factor = std::sync::Arc::new(
        crate::linalg::sparse_exact::factorize_sparse_spd(&h_sparse).unwrap(),
    );
    let sparse = SparseCholeskyOperator::new(factor, 0.0, h.nrows());
    let dense = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let local = array![[0.8, 0.15], [0.15, 0.45]];
    let start = 1;
    let end = 3;
    let op = BlockLocalDrift {
        local: local.clone(),
        start,
        end,
        total_dim: h.nrows(),
    };
    // Dense counterpart of `op`: the local block written into the window.
    let mut full = Array2::<f64>::zeros(h.raw_dim());
    full.slice_mut(ndarray::s![start..end, start..end])
        .assign(&local);
    assert_relative_eq!(
        sparse.trace_hinv_operator_cross(&op, &op),
        dense.trace_hinv_product_cross(&full, &full),
        epsilon = 1e-10,
        max_relative = 1e-10
    );
}
#[test]
fn sparse_matrix_block_operator_cross_without_takahashi_matches_dense_reference() {
    // Mixed cross trace (dense matrix × block-local HyperOperator) on the
    // sparse Cholesky backend without a Takahashi inverse must match the
    // fully dense reference.
    let h = array![
        [5.0, 0.2, 0.0, 0.1],
        [0.2, 4.0, 0.3, 0.0],
        [0.0, 0.3, 3.0, 0.4],
        [0.1, 0.0, 0.4, 2.5],
    ];
    let h_sparse =
        crate::linalg::sparse_exact::dense_to_sparse_symmetric_upper(&h, 0.0).unwrap();
    let factor = std::sync::Arc::new(
        crate::linalg::sparse_exact::factorize_sparse_spd(&h_sparse).unwrap(),
    );
    let sparse = SparseCholeskyOperator::new(factor, 0.0, h.nrows());
    let dense = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let matrix = array![
        [1.0, 0.2, -0.1, 0.3],
        [0.2, 0.7, 0.4, -0.2],
        [-0.1, 0.4, 1.2, 0.5],
        [0.3, -0.2, 0.5, 0.9],
    ];
    let local = array![[0.8, 0.15], [0.15, 0.45]];
    let start = 1;
    let end = 3;
    let op = BlockLocalDrift {
        local: local.clone(),
        start,
        end,
        total_dim: h.nrows(),
    };
    // Dense counterpart of `op`: the local block written into the window.
    let mut full = Array2::<f64>::zeros(h.raw_dim());
    full.slice_mut(ndarray::s![start..end, start..end])
        .assign(&local);
    assert_relative_eq!(
        sparse.trace_hinv_matrix_operator_cross(&matrix, &op),
        dense.trace_hinv_product_cross(&matrix, &full),
        epsilon = 1e-10,
        max_relative = 1e-10
    );
}
#[test]
fn sparse_takahashi_trace_hinv_product_pairs_symmetric_lookups() {
    // With a Takahashi selected inverse attached, tr(H⁻¹A) must match the
    // dense reference. `a` is deliberately NON-symmetric, so (per the test
    // name) the selected-inverse lookups must be paired (i,j)/(j,i) correctly
    // rather than assuming symmetry of A.
    let h = array![[4.0, 0.2, 0.1], [0.2, 3.0, 0.4], [0.1, 0.4, 2.5],];
    let h_sparse =
        crate::linalg::sparse_exact::dense_to_sparse_symmetric_upper(&h, 0.0).unwrap();
    let factor = std::sync::Arc::new(
        crate::linalg::sparse_exact::factorize_sparse_spd(&h_sparse).unwrap(),
    );
    // The Takahashi inverse needs the simplicial factorization.
    let sfactor = crate::linalg::sparse_exact::factorize_simplicial(&h_sparse).unwrap();
    let taka = std::sync::Arc::new(
        crate::linalg::sparse_exact::TakahashiInverse::compute(&sfactor).unwrap(),
    );
    let sparse = SparseCholeskyOperator::new(factor, 0.0, h.nrows()).with_takahashi(taka);
    let dense = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let a = array![[1.0, 0.7, -0.2], [0.1, 0.5, 0.9], [0.4, -0.3, 0.2],];
    assert_relative_eq!(
        sparse.trace_hinv_product(&a),
        dense.trace_hinv_product(&a),
        epsilon = 1e-10,
        max_relative = 1e-10
    );
}
#[test]
fn hyper_operator_bilinear_view_matches_owned_bilinear() {
    // `bilinear_view` on array views that are interior slices of longer
    // buffers (non-zero offsets) must agree with `bilinear` on owned copies,
    // for every HyperOperator implementation.
    let dense = DenseMatrixHyperOperator {
        matrix: array![[2.0, 0.3, -0.1], [0.3, 1.5, 0.4], [-0.1, 0.4, 3.0],],
    };
    let block = BlockLocalDrift {
        local: array![[1.2, 0.2], [0.2, 0.7]],
        start: 1,
        end: 3,
        total_dim: 3,
    };
    let composite = CompositeHyperOperator {
        dense: Some(array![[0.4, 0.1, 0.0], [0.1, 0.8, -0.2], [0.0, -0.2, 0.6],]),
        operators: vec![Arc::new(block.clone())],
        dim_hint: 3,
    };
    // Includes a negative weight to cover sign handling.
    let weighted = WeightedHyperOperator {
        terms: vec![
            (1.7, Arc::new(dense.clone()) as Arc<dyn HyperOperator>),
            (-0.4, Arc::new(block.clone()) as Arc<dyn HyperOperator>),
        ],
        dim_hint: 3,
    };
    // Length-5 buffers sliced to [1..4] so the views have a leading offset.
    let v_storage = array![9.0, 0.5, -1.2, 0.7, 8.0];
    let u_storage = array![7.0, -0.3, 1.1, 0.9, 6.0];
    let v_view = v_storage.slice(ndarray::s![1..4]);
    let u_view = u_storage.slice(ndarray::s![1..4]);
    let v_owned = v_view.to_owned();
    let u_owned = u_view.to_owned();
    let operators: [&dyn HyperOperator; 4] = [&dense, &block, &composite, &weighted];
    for op in operators {
        assert_relative_eq!(
            op.bilinear_view(v_view, u_view),
            op.bilinear(&v_owned, &u_owned),
            epsilon = 1e-12,
            max_relative = 1e-12
        );
    }
}
#[test]
fn hyper_operator_scaled_add_mul_vec_matches_owned_matvec() {
    // `scaled_add_mul_vec(v, s, acc)` must equal `acc + s·(op·v)` computed
    // via the owned `mul_vec` path, for every HyperOperator implementation.
    let dense = DenseMatrixHyperOperator {
        matrix: array![[2.0, 0.3, -0.1], [0.3, 1.5, 0.4], [-0.1, 0.4, 3.0],],
    };
    let block = BlockLocalDrift {
        local: array![[1.2, 0.2], [0.2, 0.7]],
        start: 1,
        end: 3,
        total_dim: 3,
    };
    let composite = CompositeHyperOperator {
        dense: Some(array![[0.4, 0.1, 0.0], [0.1, 0.8, -0.2], [0.0, -0.2, 0.6],]),
        operators: vec![Arc::new(block.clone())],
        dim_hint: 3,
    };
    // Includes a 0.0-weight term to cover zero-weight handling.
    let weighted = WeightedHyperOperator {
        terms: vec![
            (1.7, Arc::new(dense.clone()) as Arc<dyn HyperOperator>),
            (-0.4, Arc::new(block.clone()) as Arc<dyn HyperOperator>),
            (0.0, Arc::new(composite.clone()) as Arc<dyn HyperOperator>),
        ],
        dim_hint: 3,
    };
    // Length-5 buffer sliced to [1..4] so the view has a leading offset.
    let v_storage = array![9.0, 0.5, -1.2, 0.7, 8.0];
    let v_view = v_storage.slice(ndarray::s![1..4]);
    let v_owned = v_view.to_owned();
    let base = array![0.25, -0.5, 1.5];
    let scale = -1.3;
    let operators: [&dyn HyperOperator; 4] = [&dense, &block, &composite, &weighted];
    for op in operators {
        let mut accumulated = base.clone();
        op.scaled_add_mul_vec(v_view, scale, accumulated.view_mut());
        // Reference: base + scale · (op · v) via the owned matvec.
        let mut expected = base.clone();
        expected.scaled_add(scale, &op.mul_vec(&v_owned));
        for idx in 0..accumulated.len() {
            assert_relative_eq!(
                accumulated[idx],
                expected[idx],
                epsilon = 1e-12,
                max_relative = 1e-12
            );
        }
    }
}
#[test]
fn stochastic_single_second_order_estimators_match_batched_paths() {
    // The single-matrix and single-operator second-order estimators must
    // agree (to ~machine precision, per the 1e-12 tolerance) with the [0,0]
    // entry of the corresponding batched call under the same defaults.
    let diag = array![4.0, 3.0, 2.0];
    let hop = MatrixFreeSpdOperator::new(diag.len(), move |v| &diag * v);
    let estimator = StochasticTraceEstimator::with_defaults();
    let dense = array![[0.8, 0.2, 0.0], [0.2, 0.5, 0.1], [0.0, 0.1, 0.7],];
    let op = DenseMatrixHyperOperator {
        matrix: dense.clone(),
    };
    // Batched call with a single dense matrix and no operators.
    let no_ops: [&dyn HyperOperator; 0] = [];
    let dense_refs = [&dense];
    let batched_dense =
        estimator.estimate_second_order_traces_with_operators(&hop, &dense_refs, &no_ops);
    assert_relative_eq!(
        estimator.estimate_second_order_single_dense(&hop, &dense),
        batched_dense[[0, 0]],
        epsilon = 1e-12,
        max_relative = 1e-12
    );
    // Batched call with a single operator and no dense matrices.
    let no_dense: [&Array2<f64>; 0] = [];
    let op_refs: [&dyn HyperOperator; 1] = [&op];
    let batched_op =
        estimator.estimate_second_order_traces_with_operators(&hop, &no_dense, &op_refs);
    assert_relative_eq!(
        estimator.estimate_second_order_single_operator(&hop, &op),
        batched_op[[0, 0]],
        epsilon = 1e-12,
        max_relative = 1e-12
    );
}
#[test]
fn matrix_free_logdet_traces_use_exact_spectral_algebra() {
    // A MatrixFreeSpdOperator over a diagonal must reproduce the dense
    // spectral operator's exact logdet and trace values, and must not opt
    // into the stochastic logdet/H⁻¹-kernel shortcuts under any of the
    // tested size/flag combinations.
    let diag = array![4.0, 3.0, 2.0];
    let h = Array2::from_diag(&diag);
    let dense = DenseSpectralOperator::from_symmetric(&h).unwrap();
    let op = MatrixFreeSpdOperator::new(diag.len(), move |v| &diag * v);
    let a = array![[0.7, 0.1, 0.0], [0.1, 0.4, 0.2], [0.0, 0.2, 0.5]];
    assert_relative_eq!(op.logdet(), dense.logdet(), epsilon = 1e-12);
    assert_relative_eq!(
        op.trace_hinv_product(&a),
        dense.trace_hinv_product(&a),
        epsilon = 1e-12
    );
    assert_relative_eq!(
        op.trace_logdet_hessian_cross(&a, &a),
        dense.trace_logdet_hessian_cross(&a, &a),
        epsilon = 1e-12
    );
    assert!(!op.prefers_stochastic_trace_estimation());
    assert!(!op.logdet_traces_match_hinv_kernel());
    assert!(!can_use_stochastic_logdet_hinv_kernel(&op, 1024, true));
    assert!(!can_use_stochastic_logdet_hinv_kernel(&op, 128, true));
    assert!(!can_use_stochastic_logdet_hinv_kernel(&op, 1024, false));
}
#[test]
fn test_rademacher_probe_properties() {
    // Every probe entry must be exactly ±1 and the probe stream must be
    // fully determined by the RNG seed.
    let mut rng_a = Xoshiro256SS::from_seed(99);
    let mut probe_a = Array1::zeros(100);
    rademacher_probe_into(probe_a.view_mut(), &mut rng_a);
    assert_eq!(probe_a.len(), 100);
    assert!(
        probe_a.iter().all(|&v| v == 1.0 || v == -1.0),
        "Rademacher entry must be +/-1"
    );
    // Re-draw with an identically seeded generator.
    let mut rng_b = Xoshiro256SS::from_seed(99);
    let mut probe_b = Array1::zeros(100);
    rademacher_probe_into(probe_b.view_mut(), &mut rng_b);
    assert_eq!(probe_a, probe_b, "Same seed must produce identical probes");
}
#[test]
fn test_spectral_logdet_gradient_fd() {
    // d/dt log|H(t)| for the diagonal family H(t) = diag(2+t, 0.01+2t, 3−t):
    // the analytic trace_logdet_gradient(dH/dt) at t0 = 0 must match a
    // central finite difference of logdet. The 0.01 entry makes H nearly
    // singular, stressing accuracy.
    let t0 = 0.0_f64;
    let h_step = 1e-6;
    let dh_dt = Array2::from_diag(&array![1.0, 2.0, -1.0]);
    let h0 = Array2::from_diag(&array![2.0 + t0, 0.01 + 2.0 * t0, 3.0 - t0]);
    let op0 = DenseSpectralOperator::from_symmetric(&h0).unwrap();
    let analytic = op0.trace_logdet_gradient(&dh_dt);
    let h_plus = Array2::from_diag(&array![
        2.0 + t0 + h_step,
        0.01 + 2.0 * (t0 + h_step),
        3.0 - (t0 + h_step)
    ]);
    let h_minus = Array2::from_diag(&array![
        2.0 + t0 - h_step,
        0.01 + 2.0 * (t0 - h_step),
        3.0 - (t0 - h_step)
    ]);
    let op_plus = DenseSpectralOperator::from_symmetric(&h_plus).unwrap();
    let op_minus = DenseSpectralOperator::from_symmetric(&h_minus).unwrap();
    let fd = (op_plus.logdet() - op_minus.logdet()) / (2.0 * h_step);
    let rel_err = (analytic - fd).abs() / fd.abs().max(1e-12);
    assert!(
        rel_err < 1e-5,
        "Spectral logdet gradient mismatch: analytic={:.10e}, fd={:.10e}, rel_err={:.3e}",
        analytic,
        fd,
        rel_err,
    );
}
/// Penalty S(ψ) = R(ψ)·diag(s1, s2, 0)·R(ψ)ᵀ whose null direction rotates
/// with ψ, used by the moving-nullspace correction tests.
fn rotating_nullspace_penalty(psi: f64, s1: f64, s2: f64) -> Array2<f64> {
    let cos_psi = psi.cos();
    let sin_psi = psi.sin();
    let rotation = array![
        [cos_psi, 0.0, -sin_psi],
        [0.0, 1.0, 0.0],
        [sin_psi, 0.0, cos_psi],
    ];
    let diagonal = Array2::from_diag(&array![s1, s2, 0.0]);
    rotation.dot(&diagonal).dot(&rotation.t())
}
/// Pseudo log-determinant: Σ ln λ over eigenvalues λ strictly above `tol`.
fn pseudo_logdet(s: &Array2<f64>, tol: f64) -> f64 {
    let (eigs, _) = s.eigh(faer::Side::Lower).unwrap();
    let mut acc = 0.0;
    for &lambda in eigs.iter() {
        if lambda > tol {
            acc += lambda.ln();
        }
    }
    acc
}
/// Central finite difference of the rotating-penalty pseudo-logdet in ψ.
fn pseudo_logdet_fd_first(psi: f64, h: f64, s1: f64, s2: f64, tol: f64) -> f64 {
    let forward = pseudo_logdet(&rotating_nullspace_penalty(psi + h, s1, s2), tol);
    let backward = pseudo_logdet(&rotating_nullspace_penalty(psi - h, s1, s2), tol);
    (forward - backward) / (2.0 * h)
}
/// Second-order central difference of the rotating-penalty pseudo-logdet.
fn pseudo_logdet_fd_second(psi: f64, h: f64, s1: f64, s2: f64, tol: f64) -> f64 {
    let at = |x: f64| pseudo_logdet(&rotating_nullspace_penalty(x, s1, s2), tol);
    (at(psi + h) - 2.0 * at(psi) + at(psi - h)) / (h * h)
}
/// Analytic second derivative of the pseudo-logdet of the rotating-nullspace
/// penalty S(ψ), returned as `(with_correction, without_correction)`.
///
/// The naive positive-subspace formula is tr(S⁺S″) − tr((S⁺S′)²); when the
/// nullspace itself moves with ψ an extra term
/// 2·Σ_a σ_a⁻²‖(U₊ᵀ S′ U₀)_a‖² is added, where U₊/U₀ collect the positive
/// and null eigenvectors and σ_a the positive eigenvalues.
fn analytic_pseudo_logdet_second(psi: f64, s1: f64, s2: f64, tol: f64) -> (f64, f64) {
    let s_mat = rotating_nullspace_penalty(psi, s1, s2);
    let (eigs, vecs) = s_mat.eigh(faer::Side::Lower).unwrap();
    let p = eigs.len();
    // Split the spectrum into positive and (numerically) null parts.
    let pos_idx: Vec<usize> = (0..p).filter(|&i| eigs[i] > tol).collect();
    let null_idx: Vec<usize> = (0..p).filter(|&i| eigs[i] <= tol).collect();
    // Closed-form R(ψ), R′(ψ), R″(ψ) for the rotation used by the penalty.
    let c = psi.cos();
    let s = psi.sin();
    let r = array![[c, 0.0, -s], [0.0, 1.0, 0.0], [s, 0.0, c],];
    let rp = array![[-s, 0.0, -c], [0.0, 0.0, 0.0], [c, 0.0, -s],];
    let d = Array2::from_diag(&array![s1, s2, 0.0]);
    // S′ = R′DRᵀ + RDR′ᵀ.
    let s_psi = rp.dot(&d).dot(&r.t()) + r.dot(&d).dot(&rp.t());
    let rpp = array![[-c, 0.0, s], [0.0, 0.0, 0.0], [-s, 0.0, -c],];
    // S″ = R″DRᵀ + 2·R′DR′ᵀ + RDR″ᵀ.
    let s_psi_psi =
        rpp.dot(&d).dot(&r.t()) + 2.0 * &rp.dot(&d).dot(&rp.t()) + r.dot(&d).dot(&rpp.t());
    // Moore–Penrose pseudoinverse S⁺ assembled from the positive eigenpairs.
    let mut s_dag = Array2::<f64>::zeros((p, p));
    for &i in &pos_idx {
        let col = vecs.column(i);
        for r in 0..p {
            for c2 in 0..p {
                s_dag[[r, c2]] += col[r] * col[c2] / eigs[i];
            }
        }
    }
    let sdag_s_psi = s_dag.dot(&s_psi);
    let term_linear = trace_mat(&s_dag.dot(&s_psi_psi));
    let term_quad = trace_mat(&sdag_s_psi.dot(&sdag_s_psi));
    let without_correction = term_linear - term_quad;
    // Moving-nullspace correction: couples positive and null eigenvectors
    // through S′; only defined when both subspaces are non-empty.
    let mut correction = 0.0_f64;
    if !pos_idx.is_empty() && !null_idx.is_empty() {
        let n_pos = pos_idx.len();
        let n_null = null_idx.len();
        let mut u_pos = Array2::<f64>::zeros((p, n_pos));
        let mut u_null = Array2::<f64>::zeros((p, n_null));
        for (out, &idx) in pos_idx.iter().enumerate() {
            u_pos.column_mut(out).assign(&vecs.column(idx));
        }
        for (out, &idx) in null_idx.iter().enumerate() {
            u_null.column_mut(out).assign(&vecs.column(idx));
        }
        // L = U₊ᵀ S′ U₀; correction = 2·Σ_a σ_a⁻²‖L_a‖².
        let l_mat = u_pos.t().dot(&s_psi.dot(&u_null));
        for a in 0..n_pos {
            let sigma_inv_sq = 1.0 / (eigs[pos_idx[a]] * eigs[pos_idx[a]]);
            correction += sigma_inv_sq * l_mat.row(a).dot(&l_mat.row(a));
        }
        correction *= 2.0;
    }
    let with_correction = without_correction + correction;
    (with_correction, without_correction)
}
/// Trace of a matrix (used here only on square matrices).
fn trace_mat(a: &Array2<f64>) -> f64 {
    a.diag().sum()
}
#[test]
fn test_moving_nullspace_correction_needed() {
    // For the rotating-nullspace penalty, the analytic second derivative of
    // the pseudo-logdet must match finite differences only when the
    // moving-nullspace correction is included; the uncorrected variant must
    // visibly disagree.
    let s1 = 4.0;
    let s2 = 1.0;
    let psi = 0.3; let tol = 1e-10;
    let h = 1e-5;
    // Sanity check: the first derivative is (numerically) zero here.
    let fd_first = pseudo_logdet_fd_first(psi, h, s1, s2, tol);
    assert!(
        fd_first.is_finite() && fd_first.abs() < 1e-8,
        "First derivative should vanish for rotating nullspace, got {fd_first}"
    );
    let fd_second = pseudo_logdet_fd_second(psi, h, s1, s2, tol);
    let (with_corr, without_corr) = analytic_pseudo_logdet_second(psi, s1, s2, tol);
    // Corrected value must agree with FD.
    let rel_err_with = (with_corr - fd_second).abs() / fd_second.abs().max(1e-12);
    assert!(
        rel_err_with < 1e-4,
        "With correction: analytic={:.8e}, fd={:.8e}, rel_err={:.3e}",
        with_corr,
        fd_second,
        rel_err_with,
    );
    // Uncorrected value must disagree — proving the correction is needed.
    let rel_err_without = (without_corr - fd_second).abs() / fd_second.abs().max(1e-12);
    assert!(
        rel_err_without > 1e-2,
        "Without correction should disagree with FD: \
         without={:.8e}, fd={:.8e}, rel_err={:.3e} (expected > 1e-2)",
        without_corr,
        fd_second,
        rel_err_without,
    );
}
#[test]
fn test_fixed_nullspace_correction_vanishes() {
    // Counterpart to the rotating-nullspace test: here the null direction
    // (third diagonal entry, identically zero) does NOT move with t, so the
    // moving-nullspace correction must vanish and both analytic variants
    // must match the finite-difference second derivative.
    let tol = 1e-10;
    let h = 1e-5;
    let rho1 = 0.5_f64;
    let rho2 = -0.3_f64;
    // Diagonal penalty family with a fixed nullspace along e3.
    let build_s = |t: f64| -> Array2<f64> {
        Array2::from_diag(&array![(rho1 + t).exp(), (rho2 + 2.0 * t).exp(), 0.0])
    };
    let t0 = 0.0_f64;
    let ld_plus = pseudo_logdet(&build_s(t0 + h), tol);
    let ld_0 = pseudo_logdet(&build_s(t0), tol);
    let ld_minus = pseudo_logdet(&build_s(t0 - h), tol);
    let fd_second = (ld_plus - 2.0 * ld_0 + ld_minus) / (h * h);
    let s_mat = build_s(t0);
    // Analytic dS/dt and d²S/dt² of the diagonal family.
    let s_t = Array2::from_diag(&array![
        (rho1 + t0).exp(),
        2.0 * (rho2 + 2.0 * t0).exp(),
        0.0
    ]);
    let s_tt = Array2::from_diag(&array![
        (rho1 + t0).exp(),
        4.0 * (rho2 + 2.0 * t0).exp(),
        0.0
    ]);
    let (eigs, vecs) = s_mat.eigh(faer::Side::Lower).unwrap();
    let p = 3;
    let pos_idx: Vec<usize> = (0..p).filter(|&i| eigs[i] > tol).collect();
    let null_idx: Vec<usize> = (0..p).filter(|&i| eigs[i] <= tol).collect();
    // Moore–Penrose pseudoinverse S⁺ from the positive eigenpairs.
    let mut s_dag = Array2::<f64>::zeros((p, p));
    for &i in &pos_idx {
        let col = vecs.column(i);
        for r in 0..p {
            for c in 0..p {
                s_dag[[r, c]] += col[r] * col[c] / eigs[i];
            }
        }
    }
    // Naive formula: tr(S⁺S″) − tr((S⁺S′)²).
    let sdag_s_t = s_dag.dot(&s_t);
    let term_linear = trace_mat(&s_dag.dot(&s_tt));
    let term_quad = trace_mat(&sdag_s_t.dot(&sdag_s_t));
    let without_correction = term_linear - term_quad;
    // Same correction as analytic_pseudo_logdet_second; expected ≈ 0 because
    // S′ does not couple the positive and null subspaces here.
    let mut correction = 0.0_f64;
    if !pos_idx.is_empty() && !null_idx.is_empty() {
        let n_pos = pos_idx.len();
        let n_null = null_idx.len();
        let mut u_pos = Array2::<f64>::zeros((p, n_pos));
        let mut u_null = Array2::<f64>::zeros((p, n_null));
        for (out, &idx) in pos_idx.iter().enumerate() {
            u_pos.column_mut(out).assign(&vecs.column(idx));
        }
        for (out, &idx) in null_idx.iter().enumerate() {
            u_null.column_mut(out).assign(&vecs.column(idx));
        }
        let l_mat = u_pos.t().dot(&s_t.dot(&u_null));
        for a in 0..n_pos {
            let sigma_inv_sq = 1.0 / (eigs[pos_idx[a]] * eigs[pos_idx[a]]);
            correction += sigma_inv_sq * l_mat.row(a).dot(&l_mat.row(a));
        }
        correction *= 2.0;
    }
    assert!(
        correction.abs() < 1e-12,
        "Correction should vanish for fixed nullspace, got {:.3e}",
        correction,
    );
    // Both variants must now agree with the finite difference.
    let with_correction = without_correction + correction;
    let abs_err_with = (with_correction - fd_second).abs();
    let abs_err_without = (without_correction - fd_second).abs();
    assert!(
        abs_err_with < 1e-4,
        "With correction should match FD: with={:.8e}, fd={:.8e}, abs_err={:.3e}",
        with_correction,
        fd_second,
        abs_err_with,
    );
    assert!(
        abs_err_without < 1e-4,
        "Without correction should also match FD (fixed nullspace): \
         without={:.8e}, fd={:.8e}, abs_err={:.3e}",
        without_correction,
        fd_second,
        abs_err_without,
    );
}
#[test]
// Eigendecomposition of the identity: every eigenvalue is 1 and the
// eigenvector matrix must be orthonormal (QᵀQ = I).
fn test_symmetric_eigen_identity() {
let identity = Array2::<f64>::eye(3);
let (evals, evecs) = symmetric_eigen(&identity);
for &e in evals.iter() {
assert!((e - 1.0).abs() < 1e-12, "eigenvalue should be 1.0, got {e}");
}
// Orthonormality check on the eigenvector Gram matrix.
let gram = evecs.t().dot(&evecs);
for (i, j) in (0..3).flat_map(|i| (0..3).map(move |j| (i, j))) {
let expected = if i == j { 1.0 } else { 0.0 };
assert!(
(gram[[i, j]] - expected).abs() < 1e-12,
"Q^T Q should be identity"
);
}
}
#[test]
// For a diagonal matrix the eigenvalues are exactly the diagonal entries,
// recovered here in ascending order after sorting.
fn test_symmetric_eigen_diagonal() {
let d = Array2::from_diag(&array![4.0, 2.0, 1.0]);
let (evals, _) = symmetric_eigen(&d);
let mut sorted = evals.clone();
sorted.sort_by(|a, b| a.total_cmp(b));
for (got, want) in sorted.iter().zip([1.0, 2.0, 4.0]) {
assert!((got - want).abs() < 1e-12);
}
}
#[test]
// For G = I the pseudoinverse is I, so G⁺v must reproduce v exactly.
fn test_pseudoinverse_times_vec_identity() {
let eye = Array2::<f64>::eye(3);
let v = array![1.0, 2.0, 3.0];
let result =
pseudoinverse_times_vec(&eye, v.as_slice().expect("contiguous test vector"), 1e-8);
for (got, want) in result.iter().zip(v.iter()) {
assert!((got - want).abs() < 1e-12, "G=I: G⁺v should equal v");
}
}
#[test]
// Rank-1 matrix of all ones: G⁺ = (1/4)·ones, so G⁺·[2, 0] = [0.5, 0.5].
// Exercises the tolerance-based nullspace rejection on a singular input.
fn test_pseudoinverse_times_vec_singular() {
let g = array![[1.0, 1.0], [1.0, 1.0]];
let v = array![2.0, 0.0];
let result =
pseudoinverse_times_vec(&g, v.as_slice().expect("contiguous test vector"), 1e-8);
for (got, want) in result.iter().zip([0.5, 0.5]) {
assert!((got - want).abs() < 1e-10);
}
}
#[test]
// Checks ImplicitHyperOperator::mul_vec against a dense, hand-assembled
// reference on a tiny fixture. The dense reference built below is
//   B_d·v = dXᵀ W (X v) + Xᵀ W (dX v) + Xᵀ (c ∘ (X v)) + S_ψ v
// where dX is the materialized dX/dψ and c is the precomputed
// third-derivative weight vector c ∘ (dX β). A second operator with
// c_x_psi_beta = None checks the Gaussian path (no third term).
fn implicit_hyper_operator_third_derivative_term_matches_dense_reference() {
use crate::terms::basis::ImplicitDesignPsiDerivative;
use std::sync::Arc;
// Fixture dimensions: n observations, one axis, p = n_knots coefficients.
let n = 4usize;
let n_knots = 2usize;
let n_axes = 1usize;
let p = n_knots;
// Raw inputs for the implicit derivative (flattened n × n_knots layout —
// presumably row-major; confirmed only by the constructor's own contract).
let phi_values = array![1.0, 0.5, 0.7, 0.9, 0.3, 0.4, 0.6, 0.8];
let q_values = array![0.5, -0.2, 0.3, 0.1, -0.4, 0.2, 0.6, -0.1];
let t_values = array![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
let axis_components = array![[0.7], [0.3], [-0.4], [0.5], [0.2], [-0.1], [0.6], [0.8]];
let implicit = Arc::new(ImplicitDesignPsiDerivative::new(
phi_values,
q_values,
t_values,
axis_components,
None,
None,
n,
n_knots,
0,
n_axes,
));
// Dense design matrix X (n × p) and observation weights W (diagonal).
let x_data = array![[1.0, 0.30], [0.50, 1.20], [-0.20, 0.80], [0.90, -0.40],];
let x_design = Arc::new(DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
x_data.clone(),
)));
let w_diag = Arc::new(array![0.8, 1.2, 0.6, 1.5]);
// Penalty block S_ψ, evaluation coefficients β, and per-observation c.
let s_psi = array![[0.40, 0.05], [0.05, 0.25]];
let beta_eval = array![0.30, -0.20];
let c_array = array![0.10, -0.05, 0.20, 0.15];
// Materialize dX/dψ for axis 0 to build the dense reference.
let dx_dpsi = implicit
.materialize_first(0)
.expect("materialize_first should succeed on tiny fixture");
assert_eq!(dx_dpsi.shape(), &[n, p]);
// Third-derivative weight vector: c ∘ (dX β).
let dx_beta = dx_dpsi.dot(&beta_eval);
let c_x_psi_beta_dense = &c_array * &dx_beta;
let c_x_psi_beta = Some(Arc::new(c_x_psi_beta_dense.clone()));
let op = ImplicitHyperOperator {
implicit_deriv: Arc::clone(&implicit),
axis: 0,
x_design: Arc::clone(&x_design),
w_diag: Arc::clone(&w_diag),
s_psi: s_psi.clone(),
p,
c_x_psi_beta,
};
// Probe vectors: both unit basis vectors plus two generic directions.
let probes = [
array![1.0, 0.0],
array![0.0, 1.0],
array![0.7, -0.4],
array![-0.25, 1.10],
];
for (k, v) in probes.iter().enumerate() {
// Dense reference: four terms of B_d·v assembled explicitly.
let xv = x_data.dot(v);
let dxv = dx_dpsi.dot(v);
let w_xv = &*w_diag * &xv;
let w_dxv = &*w_diag * &dxv;
let t1 = dx_dpsi.t().dot(&w_xv);
let t2 = x_data.t().dot(&w_dxv);
let weighted = &c_x_psi_beta_dense * &xv;
let t3 = x_data.t().dot(&weighted);
let t4 = s_psi.dot(v);
let want = &t1 + &t2 + &t3 + &t4;
let got = op.mul_vec(v);
assert_eq!(got.len(), p);
for i in 0..p {
// Relative-plus-absolute tolerance at machine precision scale.
let tol = 1e-12 * want[i].abs().max(1.0) + 1e-12;
assert!(
(want[i] - got[i]).abs() <= tol,
"B_d·v mismatch at probe {k}, comp {i}: want={:.6e}, got={:.6e}",
want[i],
got[i],
);
}
}
// Gaussian variant: c_x_psi_beta = None, so the third term must be absent.
let op_gauss = ImplicitHyperOperator {
implicit_deriv: Arc::clone(&implicit),
axis: 0,
x_design,
w_diag: Arc::clone(&w_diag),
s_psi: s_psi.clone(),
p,
c_x_psi_beta: None,
};
let v = array![0.7, -0.4];
let xv = x_data.dot(&v);
let dxv = dx_dpsi.dot(&v);
let w_xv = &*w_diag * &xv;
let w_dxv = &*w_diag * &dxv;
// Reference without the third-derivative term.
let want = &dx_dpsi.t().dot(&w_xv) + &x_data.t().dot(&w_dxv) + &s_psi.dot(&v);
let got = op_gauss.mul_vec(&v);
for i in 0..p {
let tol = 1e-12 * want[i].abs().max(1.0) + 1e-12;
assert!(
(want[i] - got[i]).abs() <= tol,
"Gaussian B_d·v mismatch at comp {i}: want={:.6e}, got={:.6e}",
want[i],
got[i],
);
}
}
#[test]
// Linearity cross-check for ImplicitHyperOperator::mul_vec: a centered finite
// difference of mul_vec along basis direction e_j must reproduce mul_vec(e_j)
// column-by-column — which holds exactly (up to FD noise) when the operator
// is linear in v, including the third-derivative (c_x_psi_beta) term.
fn implicit_hyper_operator_third_derivative_term_centered_fd_matches_jacobian_column() {
use crate::terms::basis::ImplicitDesignPsiDerivative;
use std::sync::Arc;
// Slightly larger fixture than the dense-reference test: n = 5, p = 3.
let n = 5usize;
let n_knots = 3usize;
let n_axes = 1usize;
let p = n_knots;
// Deterministic synthetic inputs (affine ramps) — no randomness, so the
// test is exactly reproducible.
let phi_values =
Array1::from_vec((0..n * n_knots).map(|k| 0.1 + 0.05 * (k as f64)).collect());
let q_values =
Array1::from_vec((0..n * n_knots).map(|k| -0.2 + 0.07 * (k as f64)).collect());
let t_values = Array1::zeros(n * n_knots);
let axis_components = Array2::from_shape_vec(
(n * n_knots, n_axes),
(0..n * n_knots).map(|k| 0.3 + 0.04 * (k as f64)).collect(),
)
.unwrap();
let implicit = Arc::new(ImplicitDesignPsiDerivative::new(
phi_values,
q_values,
t_values,
axis_components,
None,
None,
n,
n_knots,
0,
n_axes,
));
// Dense design matrix X (n × p) and observation weights.
let x_data = array![
[1.0, 0.4, -0.2],
[0.5, 1.1, 0.3],
[-0.3, 0.9, 0.6],
[0.8, -0.5, 1.2],
[0.2, 0.7, -0.4],
];
let x_design = Arc::new(DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
x_data.clone(),
)));
let w_diag = Arc::new(array![1.0, 0.7, 1.3, 0.9, 1.1]);
// Small ridge-like penalty so the operator is not a pure data term.
let s_psi = Array2::<f64>::eye(p) * 0.05;
let beta_eval = array![0.20, -0.10, 0.30];
let c_array = array![0.15, -0.08, 0.22, 0.05, -0.12];
let dx_dpsi = implicit.materialize_first(0).expect("materialize_first");
// Third-derivative weight vector: c ∘ (dX β), enabled on this operator.
let dx_beta = dx_dpsi.dot(&beta_eval);
let c_x_psi_beta_dense = &c_array * &dx_beta;
let op = ImplicitHyperOperator {
implicit_deriv: Arc::clone(&implicit),
axis: 0,
x_design,
w_diag,
s_psi,
p,
c_x_psi_beta: Some(Arc::new(c_x_psi_beta_dense.clone())),
};
// Base point for the centered difference; step balances truncation vs
// rounding error for f64.
let v_base = array![0.10, -0.05, 0.20];
let eps = 1e-6;
for j in 0..p {
// Unit basis vector e_j: the analytic Jacobian column is mul_vec(e_j).
let mut e_j = Array1::<f64>::zeros(p);
e_j[j] = 1.0;
let mut v_plus = v_base.clone();
v_plus[j] += eps;
let mut v_minus = v_base.clone();
v_minus[j] -= eps;
// Centered FD of mul_vec along coordinate j.
let fd = (&op.mul_vec(&v_plus) - &op.mul_vec(&v_minus)).mapv(|x| x / (2.0 * eps));
let analytic = op.mul_vec(&e_j);
for i in 0..p {
// FD noise floor ~1e-7 with eps = 1e-6 on O(1) values.
let tol = 1e-7 * analytic[i].abs().max(1.0) + 1e-7;
assert!(
(analytic[i] - fd[i]).abs() <= tol,
"FD col {j} mismatch at row {i}: analytic={:.6e}, fd={:.6e}",
analytic[i],
fd[i],
);
}
}
}
#[test]
// Degenerate 1×1 case: the pseudoinverse of [4] is [0.25], so G⁺·[8] = [2].
fn test_pseudoinverse_scalar() {
let g = array![[4.0]];
let v = array![8.0];
let result =
pseudoinverse_times_vec(&g, v.as_slice().expect("contiguous test vector"), 1e-8);
assert!((result[0] - 2.0).abs() < 1e-12);
}
#[test]
// An outer Hessian with a negative eigenvalue (here −1) and no active bounds
// must be rejected with a populated Indefinite diagnostic — never turned into
// a covariance matrix. The legacy entry point must surface the same error.
fn corrected_covariance_indefinite_returns_diagnostic() {
let outer = ndarray::arr2(&[[2.0_f64, 0.0], [0.0, -1.0]]);
let hop = DenseSpectralOperator::from_symmetric(&Array2::<f64>::eye(2))
.expect("DenseSpectralOperator from identity should succeed");
let v0 = array![0.1, 0.2];
let v1 = array![0.3, 0.4];
let res = compute_corrected_covariance_with_constraints(
&[v0.clone(), v1.clone()],
&[],
&outer,
&hop,
None,
f64::NAN,
);
match res {
Ok(cov) => panic!(
"indefinite outer Hessian must NOT yield a covariance; got matrix shape {:?}",
cov.matrix.shape(),
),
Err(CorrectedCovarianceError::Indefinite(diag)) => {
// Diagnostic must carry the offending eigenvalue, an empty active
// set (no theta was supplied), and a non-empty remediation hint.
assert!(
diag.min_eigenvalue < -0.5,
"min eigenvalue should be ~-1, got {}",
diag.min_eigenvalue,
);
assert!(
diag.active_constraints.is_empty(),
"no theta supplied ⇒ no active constraints",
);
assert!(
!diag.suggested_action.is_empty(),
"diagnostic must include a suggested-action message",
);
}
Err(other) => panic!("expected Indefinite diagnostic, got error: {:?}", other),
}
// Same failure mode through the legacy (constraint-free) entry point.
let res_legacy = compute_corrected_covariance(&[v0, v1], &[], &outer, &hop);
assert!(
matches!(res_legacy, Err(CorrectedCovarianceError::Indefinite(_))),
"legacy entry point must also surface Indefinite, got: {:?}",
res_legacy.map(|m| m.shape().to_vec()),
);
}
#[test]
// The outer Hessian is indefinite overall, but the negative direction sits on
// an active bound (theta[1] at RHO_BOUND). With that direction excluded the
// free subspace is SPD, so a finite covariance must still be produced and the
// active constraint reported.
fn corrected_covariance_indefinite_with_active_bound_succeeds() {
let outer = ndarray::arr2(&[[3.0_f64, 0.0], [0.0, -2.0]]);
let hop = DenseSpectralOperator::from_symmetric(&Array2::<f64>::eye(2)).expect("hop");
let theta = vec![0.0_f64, crate::solver::estimate::RHO_BOUND];
let res = compute_corrected_covariance_with_constraints(
&[array![0.5, 0.0], array![0.0, 0.5]],
&[],
&outer,
&hop,
Some(&theta),
0.0,
)
.expect("free-subspace SPD ⇒ covariance returned");
assert_eq!(res.active_constraints, vec![1]);
assert!(res.matrix.iter().all(|v| v.is_finite()));
}
}