use super::*;
/// Mutable state carried across stochastic trace-estimation calls.
///
/// `HessianOperator::stochastic_trace_solve_for_probe` receives it behind an
/// `Arc<Mutex<StochasticTraceState>>`, i.e. shared ownership with locked mutation.
/// All fields start at their `Default` values (`0`, empty map, `None`).
#[derive(Debug, Default)]
pub struct StochasticTraceState {
/// Lower bound on the number of probes to use.
/// NOTE(review): the "monotone" name suggests writers only ever raise it — confirm.
pub monotone_probe_floor: usize,
/// Per-probe cached vectors keyed by probe id. The default
/// `HessianOperator::stochastic_trace_solve_for_probe` removes the entry for its `probe_id`.
/// NOTE(review): presumably previous CG solutions reused as initial guesses — confirm.
pub cg_warm_starts: HashMap<u64, Array1<f64>>,
/// Optional replacement for the relative solve tolerance.
/// NOTE(review): presumably takes precedence over a caller-supplied `rel_tol` — confirm.
pub solve_rel_tol_override: Option<f64>,
/// Residual norm of the most recent linear solve, if one was recorded (TODO confirm writer).
pub last_linear_residual_norm: Option<f64>,
/// Probe variance estimate from the most recent trace estimate, if recorded (TODO confirm).
pub last_probe_sigma_sq: Option<f64>,
/// Number of probes used by the most recent trace estimate (TODO confirm writer).
pub last_probe_count: usize,
}
/// Operator dimension at or above which the default `HessianOperator` trace methods switch
/// implicit `HyperOperator`s to a Hutch++ stochastic estimate instead of dense materialization.
pub(crate) const HUTCHPP_TRACE_MIN_DIM: usize = 128;
/// Builds a Hutch++ trace-estimation config sized for an operator of dimension `dim`.
///
/// The sketch width is `dim / 32`, clamped to `4..=16`. The probe budget is tied to that
/// width: at most `max(4 * sketch, 32)` probes and at least `max(sketch, 8)` probes. Because
/// the sketch width never exceeds 16, the minimum never exceeds the maximum. Every other field
/// keeps its `StochasticTraceConfig::default()` value.
pub(crate) fn hutchpp_config_for_dim(dim: usize) -> StochasticTraceConfig {
    const DIMS_PER_SKETCH_COLUMN: usize = 32;
    const SKETCH_BOUNDS: (usize, usize) = (4, 16);
    const PROBES_PER_SKETCH_COLUMN: usize = 4;
    // Absolute lower bounds applied to the derived probe limits.
    const MAX_PROBES_AT_LEAST: usize = 32;
    const MIN_PROBES_AT_LEAST: usize = 8;

    let (lo, hi) = SKETCH_BOUNDS;
    let sketch_width = Ord::clamp(dim / DIMS_PER_SKETCH_COLUMN, lo, hi);
    let probe_ceiling = Ord::max(PROBES_PER_SKETCH_COLUMN * sketch_width, MAX_PROBES_AT_LEAST);
    let probe_floor = Ord::max(sketch_width, MIN_PROBES_AT_LEAST);

    let mut cfg = StochasticTraceConfig::default();
    cfg.hutchpp_sketch_dim = Some(sketch_width);
    cfg.n_probes_min = probe_floor;
    cfg.n_probes_max = probe_ceiling;
    cfg
}
/// Backend-agnostic interface to a Hessian operator `H`.
///
/// Implementors supply the primitives `logdet`, `trace_hinv_product`, `solve`, `solve_multi`,
/// `active_rank` and `dim`. Every other method has a default expressed through those
/// primitives; backends with a cheaper matrix-free path are expected to override them.
///
/// Several defaults that receive an implicit [`HyperOperator`] on an operator with
/// `dim() >= HUTCHPP_TRACE_MIN_DIM` use a Hutch++ stochastic estimate configured by
/// `hutchpp_config_for_dim`. Otherwise they materialize the operator with `to_dense()`,
/// logging a warning when it was implicit.
pub trait HessianOperator: Send + Sync {
    /// Log-determinant of the operator.
    ///
    /// NOTE(review): `hessian_operator_geometric_scale` divides this by `active_rank()`, which
    /// suggests a pseudo-log-determinant over the active subspace — confirm per backend.
    fn logdet(&self) -> f64;

    /// Trace of `H⁻¹ a` for a dense `dim() × dim()` matrix `a`.
    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64;

    /// Exact dense spectral representation, when the backend has one. Default: `None`.
    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        None
    }

    /// Assembles `H` densely for tangent-space projection.
    ///
    /// # Errors
    /// The default always errors: the backend has no dense assembly path.
    fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
        Err("backend does not support tangent projection".to_string())
    }

    /// Trace of `H⁻¹` applied to a [`HyperOperator`].
    ///
    /// Uses Hutch++ for an implicit `op` on large operators; otherwise materializes `op` and
    /// defers to `trace_hinv_product`.
    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
        if op.is_implicit() && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
            let config = hutchpp_config_for_dim(self.dim());
            return hutchpp_estimate_trace_hinv_operator(self, op, &config);
        }
        if op.is_implicit() {
            log::warn!(
                "trace_hinv_operator: materializing implicit HyperOperator — \
                 backend should provide a matrix-free override"
            );
        }
        self.trace_hinv_product(&op.to_dense())
    }

    /// Applies `H⁻¹` to a single right-hand side vector.
    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64>;

    /// Applies `H⁻¹` to every column of `rhs`.
    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64>;

    /// Solve used inside stochastic trace estimation, with relative tolerance `rel_tol`.
    ///
    /// The default ignores the tolerance and performs an exact `solve`.
    ///
    /// # Panics
    /// If `rel_tol` is not finite and strictly positive.
    fn stochastic_trace_solve(&self, rhs: &Array1<f64>, rel_tol: f64) -> Array1<f64> {
        assert!(
            rel_tol.is_finite() && rel_tol > 0.0,
            "stochastic trace solve tolerance must be positive and finite"
        );
        self.solve(rhs)
    }

    /// Per-probe variant of `stochastic_trace_solve` with optional shared state.
    ///
    /// The default drops any cached `cg_warm_starts` entry for `probe_id` and then calls
    /// `stochastic_trace_solve`. Clearing is best-effort: a poisoned lock is skipped silently.
    fn stochastic_trace_solve_for_probe(
        &self,
        rhs: &Array1<f64>,
        rel_tol: f64,
        probe_id: u64,
        state: Option<&Arc<Mutex<StochasticTraceState>>>,
    ) -> Array1<f64> {
        if let Some(state_arc) = state
            && let Ok(mut guard) = state_arc.lock()
        {
            guard.cg_warm_starts.remove(&probe_id);
        }
        self.stochastic_trace_solve(rhs, rel_tol)
    }

    /// Multi-column counterpart of `stochastic_trace_solve`. The default is an exact
    /// `solve_multi`.
    ///
    /// # Panics
    /// If `rel_tol` is not finite and strictly positive.
    fn stochastic_trace_solve_multi(&self, rhs: &Array2<f64>, rel_tol: f64) -> Array2<f64> {
        assert!(
            rel_tol.is_finite() && rel_tol > 0.0,
            "stochastic trace multi-solve tolerance must be positive and finite"
        );
        self.solve_multi(rhs)
    }

    /// Whether the backend exposes a matrix-free CG operator for trace estimation.
    /// Default: `false`.
    fn has_matrix_free_trace_cg_operator(&self) -> bool {
        false
    }

    /// Cross trace built from `solve_multi(a)` and `solve_multi(b)` via `trace_matrix_product`.
    ///
    /// When `a` and `b` are the same reference, only one solve is performed.
    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let solved_a = self.solve_multi(a);
        if std::ptr::eq(a, b) {
            return trace_matrix_product(&solved_a, &solved_a);
        }
        let solved_b = self.solve_multi(b);
        trace_matrix_product(&solved_a, &solved_b)
    }

    /// Cross trace between a dense `matrix` and a [`HyperOperator`].
    ///
    /// Uses Hutch++ (wrapping `matrix` in a `DenseMatrixHyperOperator`) when `op` is implicit
    /// and the operator is large; otherwise materializes `op` and calls
    /// `trace_hinv_product_cross`.
    fn trace_hinv_matrix_operator_cross(
        &self,
        matrix: &Array2<f64>,
        op: &dyn HyperOperator,
    ) -> f64 {
        if op.is_implicit() && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
            let config = hutchpp_config_for_dim(self.dim());
            let lhs = DenseMatrixHyperOperator {
                matrix: matrix.clone(),
            };
            return hutchpp_estimate_trace_hinv_operator_cross(self, &lhs, op, &config);
        }
        if op.is_implicit() {
            log::warn!(
                "trace_hinv_matrix_operator_cross: materializing implicit HyperOperator — \
                 backend should provide a matrix-free override"
            );
        }
        self.trace_hinv_product_cross(matrix, &op.to_dense())
    }

    /// Cross trace between two [`HyperOperator`]s.
    ///
    /// When `left` and `right` are the same object, both the Hutch++ path and the dense
    /// fallback do the work once instead of twice.
    fn trace_hinv_operator_cross(
        &self,
        left: &dyn HyperOperator,
        right: &dyn HyperOperator,
    ) -> f64 {
        // Compare data pointers only (vtables are ignored): identical operators let both
        // branches below skip duplicated materialization and solves.
        let same_operator = std::ptr::eq(
            left as *const dyn HyperOperator as *const (),
            right as *const dyn HyperOperator as *const (),
        );
        let l_implicit = left.is_implicit();
        let r_implicit = right.is_implicit();
        if (l_implicit || r_implicit) && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
            let config = hutchpp_config_for_dim(self.dim());
            if same_operator {
                return hutchpp_estimate_trace_hinv_op_squared(self, left, &config);
            }
            return hutchpp_estimate_trace_hinv_operator_cross(self, left, right, &config);
        }
        if l_implicit || r_implicit {
            log::warn!(
                "trace_hinv_operator_cross: materializing implicit HyperOperator(s) — \
                 backend should provide a matrix-free override"
            );
        }
        if same_operator {
            // Materialize once and pass the same reference twice so that
            // `trace_hinv_product_cross` sees the alias and performs a single solve.
            let dense = left.to_dense();
            return self.trace_hinv_product_cross(&dense, &dense);
        }
        self.trace_hinv_product_cross(&left.to_dense(), &right.to_dense())
    }

    /// Trace term of the log-determinant gradient for a dense matrix `a`.
    /// Default: `trace_hinv_product(a)`.
    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
        self.trace_hinv_product(a)
    }

    /// Returns `h[r] = x_r · solve(x_r)` for every row `x_r` of `x`, i.e. the diagonal of
    /// `X H⁻¹ Xᵀ`.
    ///
    /// Rows are processed in chunks of roughly 2¹⁶ floats, each solved with one
    /// `solve_multi`, so that `x` never has to be materialized in full.
    ///
    /// # Panics
    /// If `logdet_traces_match_hinv_kernel()` is `false`, or if a row chunk cannot be read
    /// (via `reml_contract_panic`).
    fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
        assert!(
            self.logdet_traces_match_hinv_kernel(),
            "xt_logdet_kernel_x_diagonal: default requires logdet traces to match the H^-1 kernel"
        );
        let n = x.nrows();
        let p = x.ncols();
        let block = {
            const TARGET_CHUNK_FLOATS: usize = 1 << 16;
            (TARGET_CHUNK_FLOATS / p.max(1)).clamp(1, n.max(1))
        };
        let mut h = Array1::<f64>::zeros(n);
        let mut start = 0usize;
        while start < n {
            let end = (start + block).min(n);
            let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
                reml_contract_panic(format!(
                    "xt_logdet_kernel_x_diagonal: row chunk failed: {err}"
                ))
            });
            // Columns of `chunk_t` are the chunk's rows, so `z_chunk` column i is H⁻¹ x_(start+i).
            let chunk_t = rows.t().to_owned();
            let z_chunk = self.solve_multi(&chunk_t);
            for (i, (row, z_col)) in rows
                .outer_iter()
                .zip(z_chunk.columns().into_iter())
                .enumerate()
            {
                let mut acc = 0.0;
                for (row_value, z_value) in row.iter().copied().zip(z_col.iter().copied()) {
                    acc += row_value * z_value;
                }
                h[start + i] = acc;
            }
            start = end;
        }
        h
    }

    /// Log-determinant trace term for a [`HyperOperator`].
    ///
    /// Hutch++ is used only when `op` is implicit, the operator is large, and the logdet
    /// kernel matches `H⁻¹`; otherwise `op` is materialized for `trace_logdet_gradient`.
    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
        if op.is_implicit()
            && self.dim() >= HUTCHPP_TRACE_MIN_DIM
            && self.logdet_traces_match_hinv_kernel()
        {
            let config = hutchpp_config_for_dim(self.dim());
            return hutchpp_estimate_trace_hinv_operator(self, op, &config);
        }
        if op.is_implicit() {
            log::warn!(
                "trace_logdet_operator: materializing implicit HyperOperator — \
                 backend should provide a matrix-free override"
            );
        }
        self.trace_logdet_gradient(&op.to_dense())
    }

    /// `trace_logdet_gradient(a_k)`, plus `trace_logdet_gradient(c)` when a third-derivative
    /// correction `c` is supplied.
    fn trace_logdet_h_k(
        &self,
        a_k: &Array2<f64>,
        third_deriv_correction: Option<&Array2<f64>>,
    ) -> f64 {
        let base = self.trace_logdet_gradient(a_k);
        match third_deriv_correction {
            Some(c) => base + self.trace_logdet_gradient(c),
            None => base,
        }
    }

    /// `trace_logdet_gradient` of a `dim() × dim()` zero matrix whose diagonal block
    /// `[start..end, start..end]` is filled with `scale * block`.
    ///
    /// Only the leading `(end - start) × (end - start)` corner of `block` is read.
    ///
    /// # Panics
    /// If `start > end`, or if the index range exceeds `block` or `dim()`.
    fn trace_logdet_block_local(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        assert!(
            start <= end,
            "trace_logdet_block_local: start ({start}) must not exceed end ({end})"
        );
        let p = self.dim();
        let mut full = Array2::<f64>::zeros((p, p));
        let bs = end - start;
        for i in 0..bs {
            for j in 0..bs {
                full[[start + i, start + j]] = scale * block[[i, j]];
            }
        }
        self.trace_logdet_gradient(&full)
    }

    /// Second-order log-determinant cross term:
    /// `-trace_matrix_product(solve_multi(h_j), solve_multi(h_i))`.
    ///
    /// Only one solve is performed when `h_i` and `h_j` are the same reference.
    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
        let y_i = self.solve_multi(h_i);
        if std::ptr::eq(h_i, h_j) {
            return -trace_matrix_product(&y_i, &y_i);
        }
        let y_j = self.solve_multi(h_j);
        -trace_matrix_product(&y_j, &y_i)
    }

    /// `trace_logdet_hessian_cross` with the second argument given as a [`HyperOperator`],
    /// which the default materializes densely.
    fn trace_logdet_hessian_cross_matrix_operator(
        &self,
        h_i: &Array2<f64>,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        self.trace_logdet_hessian_cross(h_i, &h_j.to_dense())
    }

    /// `trace_logdet_hessian_cross` for two [`HyperOperator`]s, materialized densely.
    ///
    /// When `h_i` and `h_j` are the same object it is materialized once, so the alias check
    /// in `trace_logdet_hessian_cross` avoids a second solve.
    fn trace_logdet_hessian_cross_operator(
        &self,
        h_i: &dyn HyperOperator,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        if std::ptr::eq(
            h_i as *const dyn HyperOperator as *const (),
            h_j as *const dyn HyperOperator as *const (),
        ) {
            let dense = h_i.to_dense();
            return self.trace_logdet_hessian_cross(&dense, &dense);
        }
        self.trace_logdet_hessian_cross(&h_i.to_dense(), &h_j.to_dense())
    }

    /// Number of active (retained) directions of the operator.
    fn active_rank(&self) -> usize;

    /// Dimension `p` of the (square) operator.
    fn dim(&self) -> usize;

    /// Whether the backend is dense. Default: `false`.
    fn is_dense(&self) -> bool {
        false
    }

    /// Whether callers should prefer stochastic trace estimation. Default: `is_dense()`.
    fn prefers_stochastic_trace_estimation(&self) -> bool {
        self.is_dense()
    }

    /// Whether log-determinant traces use the same kernel as `H⁻¹`. Default: `true`.
    ///
    /// This gates the Hutch++ path of `trace_logdet_operator` and the default
    /// `xt_logdet_kernel_x_diagonal`.
    fn logdet_traces_match_hinv_kernel(&self) -> bool {
        true
    }

    /// Dense spectral view of the operator, when available. Default: `None`.
    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        None
    }
}
/// Returns `exp(logdet / active_rank)` for `op`, or `None` when that is not a usable scale.
///
/// `None` is returned when the active rank is zero, when `logdet()` is not finite, or when
/// the exponentiated value is not a finite, strictly positive number (e.g. under/overflow).
/// `logdet()` is queried only after the rank is known to be non-zero.
///
/// NOTE(review): if `logdet()` sums log-eigenvalues over the `active_rank()` retained
/// directions, this is their geometric mean — confirm per backend.
pub fn hessian_operator_geometric_scale(op: &dyn HessianOperator) -> Option<f64> {
    let active = match op.active_rank() {
        0 => return None,
        r => r as f64,
    };
    let log_det = Some(op.logdet()).filter(|v| v.is_finite())?;
    let candidate = (log_det / active).exp();
    (candidate.is_finite() && candidate > 0.0).then_some(candidate)
}