use super::*;
/// Chooses which eigen-directions of the decomposed matrix contribute to the
/// (pseudo-)log-determinant, the inverse factor, and their derivative kernels.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PseudoLogdetMode {
    /// Every eigen-direction is active; each eigenvalue is passed through the
    /// spectral regularizer. This is the default mode.
    #[default]
    Smooth,
    /// Only eigenvalues strictly greater than the spectral epsilon are active;
    /// the remaining directions are dropped from every sum and factor.
    HardPseudo,
}
/// Dense Hessian operator backed by a full symmetric eigendecomposition.
///
/// Every field is filled by `DenseSpectralOperator::from_symmetric_with_mode`.
/// Notation: `σ_j` eigenvalues, `u_j` eigenvectors, `r_j` regularized
/// eigenvalues, "active" as recorded in `active_mask`.
pub struct DenseSpectralOperator {
    // --- spectral decomposition -------------------------------------------
    /// Dimension `n` of the square matrix that was decomposed.
    pub(crate) n_dim: usize,
    /// Eigenvectors stored as columns, in the order produced by `eigh`.
    pub(crate) eigenvectors: Array2<f64>,
    /// `r_j = spectral_regularize(σ_j, ε)` for every eigenpair `j`.
    pub(crate) reg_eigenvalues: Vec<f64>,
    /// `true` for eigen-directions that take part in sums, solves and factors.
    pub(crate) active_mask: Vec<bool>,
    /// `Σ ln r_j` over the active directions; returned by `logdet()`.
    pub(crate) cached_logdet: f64,

    // --- factors (inactive columns are left at zero) -----------------------
    /// Columns `u_j / sqrt(r_j)`, so `W Wᵀ = Σ_active u_j u_jᵀ / r_j`.
    pub(crate) w_factor: Array2<f64>,
    /// Columns `u_j · sqrt(1 / sqrt(σ_j² + 4ε²))`, used for logdet gradients.
    pub(crate) g_factor: Array2<f64>,

    // --- pairwise kernels in the eigenbasis (inactive pairs are zero) ------
    /// Entry `(a, b)` is `1 / (r_a r_b)`.
    pub(crate) hinv_cross_kernel: Array2<f64>,
    /// Divided-difference kernel used by the logdet second-derivative traces.
    pub(crate) logdet_hessian_kernel: Array2<f64>,

    // --- caches -------------------------------------------------------------
    /// Cache handed to `HyperOperator` projection/trace calls.
    pub(crate) projected_factor_cache: ProjectedFactorCache,
}
impl DenseSpectralOperator {
pub fn from_symmetric(h: &Array2<f64>) -> Result<Self, String> {
Self::from_symmetric_with_mode(h, PseudoLogdetMode::Smooth)
}
pub fn from_symmetric_with_mode(
h: &Array2<f64>,
mode: PseudoLogdetMode,
) -> Result<Self, String> {
use faer::Side;
let n = h.nrows();
if n != h.ncols() {
return Err(RemlError::DimensionMismatch {
reason: format!(
"HessianOperator: expected square matrix, got {}×{}",
n,
h.ncols()
),
}
.into());
}
let (eigenvalues, eigenvectors) = h
.eigh(Side::Lower)
.map_err(|e| format!("Eigendecomposition failed: {e}"))?;
let epsilon = spectral_epsilon(eigenvalues.as_slice().unwrap());
let active: Vec<bool> = match mode {
PseudoLogdetMode::Smooth => vec![true; n],
PseudoLogdetMode::HardPseudo => eigenvalues.iter().map(|&s| s > epsilon).collect(),
};
let reg_eigenvalues: Vec<f64> = eigenvalues
.iter()
.map(|&sigma| spectral_regularize(sigma, epsilon))
.collect();
let mut w_factor = Array2::zeros((n, n));
for j in 0..n {
if !active[j] {
continue;
}
let scale = 1.0 / reg_eigenvalues[j].sqrt();
for row in 0..n {
w_factor[[row, j]] = eigenvectors[[row, j]] * scale;
}
}
let mut hinv_cross_kernel = Array2::zeros((n, n));
for a in 0..n {
if !active[a] {
continue;
}
let inv_ra = 1.0 / reg_eigenvalues[a];
for b in 0..n {
if !active[b] {
continue;
}
hinv_cross_kernel[[a, b]] = inv_ra / reg_eigenvalues[b];
}
}
let four_eps_sq = 4.0 * epsilon * epsilon;
let mut g_factor = Array2::zeros((n, n));
for j in 0..n {
if !active[j] {
continue;
}
let sigma = eigenvalues[j];
let phi_prime = 1.0 / (sigma * sigma + four_eps_sq).sqrt();
let scale = phi_prime.sqrt();
for row in 0..n {
g_factor[[row, j]] = eigenvectors[[row, j]] * scale;
}
}
let mut logdet_hessian_kernel = Array2::zeros((n, n));
let sqrt_disc: Vec<f64> = eigenvalues
.iter()
.map(|&s| (s * s + four_eps_sq).sqrt())
.collect();
for a in 0..n {
if !active[a] {
continue;
}
let sigma_a = eigenvalues[a];
let sqrt_a = sqrt_disc[a];
for b in 0..n {
if !active[b] {
continue;
}
logdet_hessian_kernel[[a, b]] = if a == b {
-sigma_a / (sqrt_a * sqrt_a * sqrt_a)
} else {
let sigma_b = eigenvalues[b];
let sqrt_b = sqrt_disc[b];
-(sigma_a + sigma_b) / (sqrt_a * sqrt_b * (sqrt_a + sqrt_b))
};
}
}
let cached_logdet: f64 = reg_eigenvalues
.iter()
.zip(active.iter())
.filter_map(|(&v, &act)| if act { Some(v.ln()) } else { None })
.sum();
Ok(Self {
reg_eigenvalues,
active_mask: active,
eigenvectors,
w_factor,
hinv_cross_kernel,
g_factor,
logdet_hessian_kernel,
cached_logdet,
projected_factor_cache: ProjectedFactorCache::default(),
n_dim: n,
})
}
#[inline]
pub(crate) fn rotate_to_eigenbasis(&self, matrix: &Array2<f64>) -> Array2<f64> {
let left = crate::faer_ndarray::fast_atb(&self.eigenvectors, matrix);
crate::faer_ndarray::fast_ab(&left, &self.eigenvectors)
}
pub fn logdet_gradient_factor(&self) -> &Array2<f64> {
&self.g_factor
}
#[inline]
pub(crate) fn trace_hinv_product_cross_rotated(
&self,
a_rot: &Array2<f64>,
b_rot: &Array2<f64>,
) -> f64 {
let mut result = 0.0;
for ((kernel_row, a_row), b_col) in self
.hinv_cross_kernel
.rows()
.into_iter()
.zip(a_rot.rows().into_iter())
.zip(b_rot.columns().into_iter())
{
for ((kernel, a_value), b_value) in kernel_row
.iter()
.copied()
.zip(a_row.iter().copied())
.zip(b_col.iter().copied())
{
result += kernel * a_value * b_value;
}
}
result
}
#[inline]
pub(crate) fn trace_hinv_product_cross_dense(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
let a_rot = self.rotate_to_eigenbasis(a);
if std::ptr::eq(a, b) {
return self.trace_hinv_product_cross_rotated(&a_rot, &a_rot);
}
let b_rot = self.rotate_to_eigenbasis(b);
self.trace_hinv_product_cross_rotated(&a_rot, &b_rot)
}
#[inline]
pub(crate) fn projected_matrix(&self, matrix: &Array2<f64>) -> Array2<f64> {
let left = crate::faer_ndarray::fast_atb(&self.w_factor, matrix);
crate::faer_ndarray::fast_ab(&left, &self.w_factor)
}
#[inline]
pub(crate) fn projected_operator(
&self,
factor: &Array2<f64>,
op: &dyn HyperOperator,
) -> Array2<f64> {
if log::log_enabled!(log::Level::Info) {
let start = std::time::Instant::now();
let result = op.projected_matrix_cached(factor, &self.projected_factor_cache);
let signature = format!(
"DenseSpectralOperator::projected_operator dim={} rank={} implicit={}",
self.n_dim,
factor.ncols(),
op.is_implicit(),
);
dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
result
} else {
op.projected_matrix_cached(factor, &self.projected_factor_cache)
}
}
#[inline]
pub(crate) fn trace_projected_cross(&self, left: &Array2<f64>, right: &Array2<f64>) -> f64 {
let mut result = 0.0;
for (left_row, right_col) in left.rows().into_iter().zip(right.columns().into_iter()) {
for (left_value, right_value) in left_row.iter().copied().zip(right_col.iter().copied())
{
result += left_value * right_value;
}
}
result
}
#[inline]
pub(crate) fn trace_logdet_hessian_cross_rotated(
&self,
h_i_rot: &Array2<f64>,
h_j_rot: &Array2<f64>,
) -> f64 {
let mut result = 0.0;
for ((kernel_row, h_i_row), h_j_col) in self
.logdet_hessian_kernel
.rows()
.into_iter()
.zip(h_i_rot.rows().into_iter())
.zip(h_j_rot.columns().into_iter())
{
for ((kernel, h_i_value), h_j_value) in kernel_row
.iter()
.copied()
.zip(h_i_row.iter().copied())
.zip(h_j_col.iter().copied())
{
result += kernel * h_i_value * h_j_value;
}
}
result
}
}
/// Logs the wall-clock time of one instrumented stage, collapsing repeats.
///
/// Statistics are kept separately for each `signature`:
/// * The first call logs `"[STAGE] <signature> elapsed=…"`.
/// * Later calls only accumulate count, total, min and max. A cumulative
///   progress line is emitted when the call count reaches 2, 4, 8, …, so a hot
///   loop produces only logarithmically many lines.
///
/// Keeping one entry per signature, rather than only the most recent one,
/// lets nested or interleaved stages keep their own statistics. For example,
/// `projected_operator` is timed inside `trace_hinv_operator_cross`. Such stages
/// no longer reset each other's counters and re-log on every call.
///
/// At most `MAX_TRACKED` signatures are held. When a new signature arrives at
/// capacity, every tracked entry that repeated is flushed with a `final`
/// summary and tracking restarts.
///
/// A poisoned lock is recovered rather than propagated, since this is
/// diagnostics only.
pub(crate) fn dense_spectral_stage_log(signature: &str, elapsed_s: f64) {
    use std::sync::Mutex;

    // Aggregated timings for one stage signature.
    struct Repeat {
        signature: String,
        count: u64,
        total: f64,
        min: f64,
        max: f64,
        // Call count at which the next cumulative line is emitted; doubles each time.
        next_heartbeat: u64,
    }

    // Bound on simultaneously tracked signatures, so memory stays bounded when
    // many distinct signatures occur over a long run.
    const MAX_TRACKED: usize = 64;

    static REPEATS: Mutex<Vec<Repeat>> = Mutex::new(Vec::new());

    let mut tracked = match REPEATS.lock() {
        Ok(g) => g,
        Err(poisoned) => poisoned.into_inner(),
    };

    if let Some(state) = tracked
        .iter_mut()
        .find(|state| state.signature == signature)
    {
        state.count += 1;
        state.total += elapsed_s;
        if elapsed_s < state.min {
            state.min = elapsed_s;
        }
        if elapsed_s > state.max {
            state.max = elapsed_s;
        }
        if state.count >= state.next_heartbeat {
            log::info!(
                "[STAGE] {} (×{} so far, total={:.3}s min={:.3}s max={:.3}s avg={:.3}s)",
                state.signature,
                state.count,
                state.total,
                state.min,
                state.max,
                state.total / state.count as f64,
            );
            state.next_heartbeat = state.next_heartbeat.saturating_mul(2);
        }
        return;
    }

    if tracked.len() >= MAX_TRACKED {
        for state in tracked.drain(..) {
            if state.count > 1 {
                log::info!(
                    "[STAGE] {} final ×{} total={:.3}s min={:.3}s max={:.3}s avg={:.3}s",
                    state.signature,
                    state.count,
                    state.total,
                    state.min,
                    state.max,
                    state.total / state.count as f64,
                );
            }
        }
    }

    log::info!("[STAGE] {} elapsed={:.3}s", signature, elapsed_s);
    tracked.push(Repeat {
        signature: signature.to_string(),
        count: 1,
        total: elapsed_s,
        min: elapsed_s,
        max: elapsed_s,
        next_heartbeat: 2,
    });
}
impl HessianOperator for DenseSpectralOperator {
    /// Returns the log-determinant precomputed by the constructor.
    fn logdet(&self) -> f64 {
        self.cached_logdet
    }

    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        Some(self)
    }

    /// Delegates to `assemble_h_raw_dense`. This implementation always returns `Ok`.
    fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
        Ok(assemble_h_raw_dense(self))
    }

    /// Computes `tr(A · W Wᵀ)` as the element-wise inner product of `A W` with `W`.
    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
        let aw = a.dot(&self.w_factor);
        aw.iter()
            .zip(self.w_factor.iter())
            .map(|(&a, &w)| a * w)
            .sum()
    }

    /// Computes `Σ_j u_j (u_jᵀ rhs) / r_j` over the active directions.
    /// Inactive directions contribute nothing.
    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
        let mut result = Array1::<f64>::zeros(self.n_dim);
        for (j, &is_active) in self.active_mask.iter().enumerate() {
            if !is_active {
                continue;
            }
            let u = self.eigenvectors.column(j);
            let coeff = u.dot(rhs) / self.reg_eigenvalues[j];
            // result += coeff · u_j
            result.scaled_add(coeff, &u);
        }
        result
    }

    /// Column-wise version of `solve`: rotates into the eigenbasis, scales row `j`
    /// by `1 / r_j` (zeroing inactive rows), and rotates back.
    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
        let mut projected = self.eigenvectors.t().dot(rhs);
        for j in 0..self.n_dim {
            if self.active_mask[j] {
                let scale = 1.0 / self.reg_eigenvalues[j];
                projected.row_mut(j).mapv_inplace(|value| value * scale);
            } else {
                projected.row_mut(j).fill(0.0);
            }
        }
        self.eigenvectors.dot(&projected)
    }

    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        self.trace_hinv_product_cross_dense(a, b)
    }

    /// Trace of `op` projected onto `W`, using the shared projection cache.
    /// The call is timed when `Info` logging is enabled.
    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
        if !log::log_enabled!(log::Level::Info) {
            return op.trace_projected_factor_cached(&self.w_factor, &self.projected_factor_cache);
        }
        let start = std::time::Instant::now();
        let result =
            op.trace_projected_factor_cached(&self.w_factor, &self.projected_factor_cache);
        let signature = format!(
            "DenseSpectralOperator::trace_hinv_operator dim={} rank={} implicit={}",
            self.n_dim,
            self.w_factor.ncols(),
            op.is_implicit(),
        );
        dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
        result
    }

    /// Computes `tr((Wᵀ M W)(Wᵀ Op W))`.
    fn trace_hinv_matrix_operator_cross(
        &self,
        matrix: &Array2<f64>,
        op: &dyn HyperOperator,
    ) -> f64 {
        // Same faer-backed `Wᵀ M W` helper used elsewhere on this type, instead of
        // a hand-rolled ndarray `.dot` chain.
        let left = self.projected_matrix(matrix);
        let right = self.projected_operator(&self.w_factor, op);
        self.trace_projected_cross(&left, &right)
    }

    /// Computes `tr((Wᵀ L W)(Wᵀ R W))`.
    ///
    /// When `left` and `right` are the same object, only one projection is done.
    /// The call is timed when `Info` logging is enabled.
    fn trace_hinv_operator_cross(
        &self,
        left: &dyn HyperOperator,
        right: &dyn HyperOperator,
    ) -> f64 {
        let compute = || {
            let left_proj = self.projected_operator(&self.w_factor, left);
            if std::ptr::addr_eq(left, right) {
                self.trace_projected_cross(&left_proj, &left_proj)
            } else {
                let right_proj = self.projected_operator(&self.w_factor, right);
                self.trace_projected_cross(&left_proj, &right_proj)
            }
        };
        if !log::log_enabled!(log::Level::Info) {
            return compute();
        }
        let start = std::time::Instant::now();
        let result = compute();
        let signature = format!(
            "DenseSpectralOperator::trace_hinv_operator_cross dim={} rank={} left_implicit={} right_implicit={}",
            self.n_dim,
            self.w_factor.ncols(),
            left.is_implicit(),
            right.is_implicit(),
        );
        dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
        result
    }

    /// Computes `tr(A · G Gᵀ)` as the element-wise inner product of `A G` with `G`.
    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
        let ag = a.dot(&self.g_factor);
        ag.iter()
            .zip(self.g_factor.iter())
            .map(|(&a, &g)| a * g)
            .sum()
    }

    /// Returns `h[i] = ‖(X G)[i, :]‖²` for every row `i` of `x`.
    ///
    /// The GPU dispatch helper is tried first. Otherwise the rows are processed
    /// in byte-balanced chunks.
    fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
        let n = x.nrows();
        let p = x.ncols();
        let rank = self.g_factor.ncols();
        let mut h = Array1::<f64>::zeros(n);
        if n == 0 || p == 0 || rank == 0 {
            return h;
        }
        if let Some(gpu) = crate::gpu::linalg_dispatch::try_fast_spectral_leverage_diagonal(
            x,
            self.g_factor.view(),
        ) {
            return gpu;
        }
        let chunk_rows = byte_balanced_row_chunk(p + rank, n);
        let mut start = 0usize;
        while start < n {
            let end = (start + chunk_rows).min(n);
            let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
                reml_contract_panic(format!(
                    "xt_logdet_kernel_x_diagonal: row chunk failed: {err}"
                ))
            });
            let xg = crate::faer_ndarray::fast_ab(&rows, &self.g_factor);
            for (local, row) in xg.outer_iter().enumerate() {
                h[start + local] = row.iter().map(|v| v * v).sum();
            }
            start = end;
        }
        h
    }

    /// Computes `scale · tr(block · G_b G_bᵀ)`, where `G_b` is rows `start..end` of `G`.
    fn trace_logdet_block_local(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let g_block = self.g_factor.slice(ndarray::s![start..end, ..]);
        let ag = block.dot(&g_block);
        scale
            * ag.iter()
                .zip(g_block.iter())
                .map(|(&a, &g)| a * g)
                .sum::<f64>()
    }

    /// Trace of `op` projected onto `G`, using the shared projection cache.
    /// The call is timed when `Info` logging is enabled.
    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
        if !log::log_enabled!(log::Level::Info) {
            return op.trace_projected_factor_cached(&self.g_factor, &self.projected_factor_cache);
        }
        let start = std::time::Instant::now();
        let result =
            op.trace_projected_factor_cached(&self.g_factor, &self.projected_factor_cache);
        let signature = format!(
            "DenseSpectralOperator::trace_logdet_operator dim={} rank={} implicit={}",
            self.n_dim,
            self.g_factor.ncols(),
            op.is_implicit(),
        );
        dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
        result
    }

    /// Logdet second-derivative cross trace for two dense matrices. Both are
    /// rotated into the eigenbasis; when they are the same object, the rotation
    /// is done only once.
    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
        let hp_i = self.rotate_to_eigenbasis(h_i);
        if std::ptr::eq(h_i, h_j) {
            return self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_i);
        }
        let hp_j = self.rotate_to_eigenbasis(h_j);
        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
    }

    /// Same as `trace_logdet_hessian_cross`, with the second argument supplied as
    /// an operator projected onto the eigenvectors.
    fn trace_logdet_hessian_cross_matrix_operator(
        &self,
        h_i: &Array2<f64>,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        let hp_i = self.rotate_to_eigenbasis(h_i);
        let hp_j = self.projected_operator(&self.eigenvectors, h_j);
        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
    }

    /// Same as `trace_logdet_hessian_cross`, with both arguments supplied as
    /// operators projected onto the eigenvectors.
    fn trace_logdet_hessian_cross_operator(
        &self,
        h_i: &dyn HyperOperator,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        let hp_i = self.projected_operator(&self.eigenvectors, h_i);
        if std::ptr::addr_eq(h_i, h_j) {
            return self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_i);
        }
        let hp_j = self.projected_operator(&self.eigenvectors, h_j);
        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
    }

    /// Number of active eigen-directions.
    fn active_rank(&self) -> usize {
        self.active_mask.iter().filter(|&&active| active).count()
    }

    fn dim(&self) -> usize {
        self.n_dim
    }

    fn is_dense(&self) -> bool {
        true
    }

    fn prefers_stochastic_trace_estimation(&self) -> bool {
        false
    }

    /// Logdet traces use `G` and the logdet kernel, not `W` and the inverse
    /// kernel, so they do not coincide.
    fn logdet_traces_match_hinv_kernel(&self) -> bool {
        false
    }

    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        Some(self)
    }
}