1use super::*;
2
/// Dense spectral representation of a square matrix `H`, with every factor and
/// kernel derived from its eigendecomposition precomputed by the constructor.
///
/// Notation used below: `sigma_j` / `u_j` are the eigenvalues and eigenvector
/// columns returned by the eigensolver, `eps` is the value returned by
/// `spectral_epsilon`, `r_j = spectral_regularize(sigma_j, eps)` and
/// `s_j = sqrt(sigma_j^2 + 4 * eps^2)`.
pub struct DenseSpectralOperator {
    /// Regularized eigenvalues `r_j`, one per eigenvalue, in eigensolver order.
    pub(crate) reg_eigenvalues: Vec<f64>,
    /// Per-direction activity flag: all `true` in `Smooth` mode, `sigma_j > eps`
    /// in `HardPseudo` mode.
    pub(crate) active_mask: Vec<bool>,
    /// Eigenvector matrix from the eigendecomposition; column `j` is paired
    /// with entry `j` of `reg_eigenvalues` / `active_mask`.
    pub(crate) eigenvectors: Array2<f64>,
    /// Column `j` is `u_j / sqrt(r_j)` for active directions and zero otherwise.
    pub(crate) w_factor: Array2<f64>,
    /// Entry `[a, b]` is `(1 / r_a) / r_b` when both directions are active and
    /// zero otherwise.
    pub(crate) hinv_cross_kernel: Array2<f64>,
    /// Column `j` is `u_j * sqrt(1 / s_j)` for active directions and zero
    /// otherwise.
    pub(crate) g_factor: Array2<f64>,
    /// Entry `[a, a]` is `-sigma_a / s_a^3`; entry `[a, b]` with `a != b` is
    /// `-(sigma_a + sigma_b) / (s_a * s_b * (s_a + s_b))`. Entries touching an
    /// inactive direction are zero.
    pub(crate) logdet_hessian_kernel: Array2<f64>,
    /// Sum of `ln(r_j)` over the active directions.
    pub(crate) cached_logdet: f64,
    /// Cache handed to the `HyperOperator` `*_cached` projection calls;
    /// starts out as `ProjectedFactorCache::default()`.
    pub(crate) projected_factor_cache: ProjectedFactorCache,
    /// Number of rows (and columns) of the matrix the operator was built from.
    pub(crate) n_dim: usize,
}
44
45impl DenseSpectralOperator {
/// Builds the operator from `h` using [`PseudoLogdetMode::Smooth`].
///
/// # Errors
/// Returns whatever error [`Self::from_symmetric_with_mode`] reports
/// (non-square input or a failed eigendecomposition).
pub fn from_symmetric(h: &Array2<f64>) -> Result<Self, String> {
    Self::from_symmetric_with_mode(h, PseudoLogdetMode::Smooth)
}
54
55 pub fn from_symmetric_with_mode(
64 h: &Array2<f64>,
65 mode: PseudoLogdetMode,
66 ) -> Result<Self, String> {
67 use faer::Side;
68
69 let n = h.nrows();
70 if n != h.ncols() {
71 return Err(RemlError::DimensionMismatch {
72 reason: format!(
73 "HessianOperator: expected square matrix, got {}×{}",
74 n,
75 h.ncols()
76 ),
77 }
78 .into());
79 }
80
81 let (eigenvalues, eigenvectors) = h
82 .eigh(Side::Lower)
83 .map_err(|e| format!("Eigendecomposition failed: {e}"))?;
84
85 let epsilon = spectral_epsilon(eigenvalues.as_slice().unwrap());
86
87 let active: Vec<bool> = match mode {
100 PseudoLogdetMode::Smooth => vec![true; n],
101 PseudoLogdetMode::HardPseudo => eigenvalues.iter().map(|&s| s > epsilon).collect(),
102 };
103
104 let reg_eigenvalues: Vec<f64> = eigenvalues
109 .iter()
110 .map(|&sigma| spectral_regularize(sigma, epsilon))
111 .collect();
112
113 let mut w_factor = Array2::zeros((n, n));
116 for j in 0..n {
117 if !active[j] {
118 continue;
119 }
120 let scale = 1.0 / reg_eigenvalues[j].sqrt();
121 for row in 0..n {
122 w_factor[[row, j]] = eigenvectors[[row, j]] * scale;
123 }
124 }
125
126 let mut hinv_cross_kernel = Array2::zeros((n, n));
127 for a in 0..n {
128 if !active[a] {
129 continue;
130 }
131 let inv_ra = 1.0 / reg_eigenvalues[a];
132 for b in 0..n {
133 if !active[b] {
134 continue;
135 }
136 hinv_cross_kernel[[a, b]] = inv_ra / reg_eigenvalues[b];
137 }
138 }
139
140 let four_eps_sq = 4.0 * epsilon * epsilon;
145 let mut g_factor = Array2::zeros((n, n));
146 for j in 0..n {
147 if !active[j] {
148 continue;
149 }
150 let sigma = eigenvalues[j];
151 let phi_prime = 1.0 / (sigma * sigma + four_eps_sq).sqrt();
152 let scale = phi_prime.sqrt();
153 for row in 0..n {
154 g_factor[[row, j]] = eigenvectors[[row, j]] * scale;
155 }
156 }
157
158 let mut logdet_hessian_kernel = Array2::zeros((n, n));
159 let sqrt_disc: Vec<f64> = eigenvalues
160 .iter()
161 .map(|&s| (s * s + four_eps_sq).sqrt())
162 .collect();
163 for a in 0..n {
164 if !active[a] {
165 continue;
166 }
167 let sigma_a = eigenvalues[a];
168 let sqrt_a = sqrt_disc[a];
169 for b in 0..n {
170 if !active[b] {
171 continue;
172 }
173 logdet_hessian_kernel[[a, b]] = if a == b {
174 -sigma_a / (sqrt_a * sqrt_a * sqrt_a)
175 } else {
176 let sigma_b = eigenvalues[b];
177 let sqrt_b = sqrt_disc[b];
178 -(sigma_a + sigma_b) / (sqrt_a * sqrt_b * (sqrt_a + sqrt_b))
179 };
180 }
181 }
182
183 let cached_logdet: f64 = reg_eigenvalues
185 .iter()
186 .zip(active.iter())
187 .filter_map(|(&v, &act)| if act { Some(v.ln()) } else { None })
188 .sum();
189
190 Ok(Self {
191 reg_eigenvalues,
192 active_mask: active,
193 eigenvectors,
194 w_factor,
195 hinv_cross_kernel,
196 g_factor,
197 logdet_hessian_kernel,
198 cached_logdet,
199 projected_factor_cache: ProjectedFactorCache::default(),
200 n_dim: n,
201 })
202 }
203
204 #[inline]
205 pub(crate) fn rotate_to_eigenbasis(&self, matrix: &Array2<f64>) -> Array2<f64> {
206 let left = gam_linalg::faer_ndarray::fast_atb(&self.eigenvectors, matrix);
207 gam_linalg::faer_ndarray::fast_ab(&left, &self.eigenvectors)
208 }
209
/// Borrows the precomputed `g_factor` matrix (the scaled-eigenvector factor
/// used by the `trace_logdet_*` methods).
pub fn logdet_gradient_factor(&self) -> &Array2<f64> {
    &self.g_factor
}
217
218 #[inline]
219 pub(crate) fn trace_hinv_product_cross_rotated(
220 &self,
221 a_rot: &Array2<f64>,
222 b_rot: &Array2<f64>,
223 ) -> f64 {
224 let mut result = 0.0;
225 for ((kernel_row, a_row), b_col) in self
226 .hinv_cross_kernel
227 .rows()
228 .into_iter()
229 .zip(a_rot.rows().into_iter())
230 .zip(b_rot.columns().into_iter())
231 {
232 for ((kernel, a_value), b_value) in kernel_row
233 .iter()
234 .copied()
235 .zip(a_row.iter().copied())
236 .zip(b_col.iter().copied())
237 {
238 result += kernel * a_value * b_value;
239 }
240 }
241 result
242 }
243
244 #[inline]
245 pub(crate) fn trace_hinv_product_cross_dense(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
246 let a_rot = self.rotate_to_eigenbasis(a);
247 if std::ptr::eq(a, b) {
248 return self.trace_hinv_product_cross_rotated(&a_rot, &a_rot);
249 }
250 let b_rot = self.rotate_to_eigenbasis(b);
251 self.trace_hinv_product_cross_rotated(&a_rot, &b_rot)
252 }
253
254 #[inline]
255 pub(crate) fn projected_matrix(&self, matrix: &Array2<f64>) -> Array2<f64> {
256 let left = gam_linalg::faer_ndarray::fast_atb(&self.w_factor, matrix);
257 gam_linalg::faer_ndarray::fast_ab(&left, &self.w_factor)
258 }
259
260 #[inline]
261 pub(crate) fn projected_operator(
262 &self,
263 factor: &Array2<f64>,
264 op: &dyn HyperOperator,
265 ) -> Array2<f64> {
266 if log::log_enabled!(log::Level::Info) {
267 let start = std::time::Instant::now();
268 let result = op.projected_matrix_cached(factor, &self.projected_factor_cache);
269 let signature = format!(
270 "DenseSpectralOperator::projected_operator dim={} rank={} implicit={}",
271 self.n_dim,
272 factor.ncols(),
273 op.is_implicit(),
274 );
275 dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
276 result
277 } else {
278 op.projected_matrix_cached(factor, &self.projected_factor_cache)
279 }
280 }
281
282 #[inline]
283 pub(crate) fn trace_projected_cross(&self, left: &Array2<f64>, right: &Array2<f64>) -> f64 {
284 let mut result = 0.0;
285 for (left_row, right_col) in left.rows().into_iter().zip(right.columns().into_iter()) {
286 for (left_value, right_value) in left_row.iter().copied().zip(right_col.iter().copied())
287 {
288 result += left_value * right_value;
289 }
290 }
291 result
292 }
293
294 #[inline]
295 pub(crate) fn trace_logdet_hessian_cross_rotated(
296 &self,
297 h_i_rot: &Array2<f64>,
298 h_j_rot: &Array2<f64>,
299 ) -> f64 {
300 let mut result = 0.0;
301 for ((kernel_row, h_i_row), h_j_col) in self
302 .logdet_hessian_kernel
303 .rows()
304 .into_iter()
305 .zip(h_i_rot.rows().into_iter())
306 .zip(h_j_rot.columns().into_iter())
307 {
308 for ((kernel, h_i_value), h_j_value) in kernel_row
309 .iter()
310 .copied()
311 .zip(h_i_row.iter().copied())
312 .zip(h_j_col.iter().copied())
313 {
314 result += kernel * h_i_value * h_j_value;
315 }
316 }
317 result
318 }
319}
320
321pub(crate) fn dense_spectral_stage_log(signature: &str, elapsed_s: f64) {
327 use std::sync::Mutex;
328 struct Repeat {
329 pub(crate) signature: String,
330 pub(crate) count: u64,
331 pub(crate) total: f64,
332 pub(crate) min: f64,
333 pub(crate) max: f64,
334 pub(crate) next_heartbeat: u64,
335 }
336 static REPEAT: Mutex<Option<Repeat>> = Mutex::new(None);
337
338 let mut guard = match REPEAT.lock() {
339 Ok(g) => g,
340 Err(poisoned) => poisoned.into_inner(),
341 };
342
343 if let Some(state) = guard.as_mut() {
344 if state.signature == signature {
345 state.count += 1;
346 state.total += elapsed_s;
347 if elapsed_s < state.min {
348 state.min = elapsed_s;
349 }
350 if elapsed_s > state.max {
351 state.max = elapsed_s;
352 }
353 if state.count >= state.next_heartbeat {
354 log::info!(
355 "[STAGE] {} (×{} so far, total={:.3}s min={:.3}s max={:.3}s avg={:.3}s)",
356 state.signature,
357 state.count,
358 state.total,
359 state.min,
360 state.max,
361 state.total / state.count as f64,
362 );
363 state.next_heartbeat = state.next_heartbeat.saturating_mul(2);
364 }
365 return;
366 }
367 if state.count > 1 {
371 log::info!(
372 "[STAGE] {} final ×{} total={:.3}s min={:.3}s max={:.3}s avg={:.3}s",
373 state.signature,
374 state.count,
375 state.total,
376 state.min,
377 state.max,
378 state.total / state.count as f64,
379 );
380 }
381 }
382
383 log::info!("[STAGE] {} elapsed={:.3}s", signature, elapsed_s);
384 *guard = Some(Repeat {
385 signature: signature.to_string(),
386 count: 1,
387 total: elapsed_s,
388 min: elapsed_s,
389 max: elapsed_s,
390 next_heartbeat: 2,
391 });
392}
393
394impl HessianOperator for DenseSpectralOperator {
/// Returns the log-determinant cached at construction: the sum of the
/// logarithms of the regularized eigenvalues over the active directions.
fn logdet(&self) -> f64 {
    self.cached_logdet
}
398
/// Always returns `Some(self)`: this type is itself the dense spectral
/// operator.
fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
    Some(self)
}
402
/// Delegates to `assemble_h_raw_dense(self)` and wraps the result in `Ok`;
/// this implementation never returns `Err`.
///
/// NOTE(review): the exact matrix reconstructed by `assemble_h_raw_dense`
/// (raw vs. regularized spectrum) is defined outside this block — confirm
/// before relying on it.
fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
    Ok(assemble_h_raw_dense(self))
}
406
407 fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
408 let aw = a.dot(&self.w_factor);
411 aw.iter()
412 .zip(self.w_factor.iter())
413 .map(|(&a, &w)| a * w)
414 .sum()
415 }
416
417 fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
418 let mut result = Array1::zeros(self.n_dim);
425 for j in 0..self.n_dim {
426 if !self.active_mask[j] {
427 continue;
428 }
429 let u = self.eigenvectors.column(j);
430 let coeff = u.dot(rhs) / self.reg_eigenvalues[j];
431 for row in 0..self.n_dim {
432 result[row] += coeff * u[row];
433 }
434 }
435 result
436 }
437
438 fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
439 let mut projected = self.eigenvectors.t().dot(rhs);
440 for j in 0..self.n_dim {
441 if self.active_mask[j] {
442 let scale = 1.0 / self.reg_eigenvalues[j];
443 projected.row_mut(j).mapv_inplace(|value| value * scale);
444 } else {
445 projected.row_mut(j).fill(0.0);
448 }
449 }
450 self.eigenvectors.dot(&projected)
451 }
452
/// Trait entry point that forwards to
/// [`DenseSpectralOperator::trace_hinv_product_cross_dense`].
fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
    self.trace_hinv_product_cross_dense(a, b)
}
456
457 fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
458 if log::log_enabled!(log::Level::Info) {
459 let start = std::time::Instant::now();
460 let result =
461 op.trace_projected_factor_cached(&self.w_factor, &self.projected_factor_cache);
462 let signature = format!(
463 "DenseSpectralOperator::trace_hinv_operator dim={} rank={} implicit={}",
464 self.n_dim,
465 self.w_factor.ncols(),
466 op.is_implicit(),
467 );
468 dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
469 result
470 } else {
471 op.trace_projected_factor_cached(&self.w_factor, &self.projected_factor_cache)
472 }
473 }
474
475 fn trace_hinv_matrix_operator_cross(
476 &self,
477 matrix: &Array2<f64>,
478 op: &dyn HyperOperator,
479 ) -> f64 {
480 let left = self.w_factor.t().dot(matrix).dot(&self.w_factor);
481 let right = self.projected_operator(&self.w_factor, op);
482 self.trace_projected_cross(&left, &right)
483 }
484
485 fn trace_hinv_operator_cross(
486 &self,
487 left: &dyn HyperOperator,
488 right: &dyn HyperOperator,
489 ) -> f64 {
490 if log::log_enabled!(log::Level::Info) {
491 let start = std::time::Instant::now();
492 let left_proj = self.projected_operator(&self.w_factor, left);
493 let result = if std::ptr::addr_eq(left, right) {
494 self.trace_projected_cross(&left_proj, &left_proj)
495 } else {
496 let right_proj = self.projected_operator(&self.w_factor, right);
497 self.trace_projected_cross(&left_proj, &right_proj)
498 };
499 let signature = format!(
500 "DenseSpectralOperator::trace_hinv_operator_cross dim={} rank={} left_implicit={} right_implicit={}",
501 self.n_dim,
502 self.w_factor.ncols(),
503 left.is_implicit(),
504 right.is_implicit(),
505 );
506 dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
507 result
508 } else {
509 let left_proj = self.projected_operator(&self.w_factor, left);
510 if std::ptr::addr_eq(left, right) {
511 self.trace_projected_cross(&left_proj, &left_proj)
512 } else {
513 let right_proj = self.projected_operator(&self.w_factor, right);
514 self.trace_projected_cross(&left_proj, &right_proj)
515 }
516 }
517 }
518
519 fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
520 let ag = a.dot(&self.g_factor);
524 ag.iter()
525 .zip(self.g_factor.iter())
526 .map(|(&a, &g)| a * g)
527 .sum()
528 }
529
530 fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
531 let n = x.nrows();
540 let p = x.ncols();
541 let rank = self.g_factor.ncols();
542 let mut h = Array1::<f64>::zeros(n);
543 if n == 0 || p == 0 || rank == 0 {
544 return h;
545 }
546 if let Some(gpu) = gam_gpu::linalg_dispatch::try_fast_spectral_leverage_diagonal(
553 x,
554 self.g_factor.view(),
555 ) {
556 return gpu;
557 }
558 let chunk_rows = byte_balanced_row_chunk(p + rank, n);
559 let mut start = 0usize;
560 while start < n {
561 let end = (start + chunk_rows).min(n);
562 let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
563 reml_contract_panic(format!(
570 "xt_logdet_kernel_x_diagonal: row chunk failed: {err}"
571 ))
572 });
573 let xg = gam_linalg::faer_ndarray::fast_ab(&rows, &self.g_factor);
574 for (local, row) in xg.outer_iter().enumerate() {
575 h[start + local] = row.iter().map(|v| v * v).sum();
576 }
577 start = end;
578 }
579 h
580 }
581
582 fn trace_logdet_block_local(
583 &self,
584 block: &Array2<f64>,
585 scale: f64,
586 start: usize,
587 end: usize,
588 ) -> f64 {
589 let g_block = self.g_factor.slice(ndarray::s![start..end, ..]);
592 let ag = block.dot(&g_block);
593 scale
594 * ag.iter()
595 .zip(g_block.iter())
596 .map(|(&a, &g)| a * g)
597 .sum::<f64>()
598 }
599
600 fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
601 if log::log_enabled!(log::Level::Info) {
602 let start = std::time::Instant::now();
603 let result =
604 op.trace_projected_factor_cached(&self.g_factor, &self.projected_factor_cache);
605 let signature = format!(
606 "DenseSpectralOperator::trace_logdet_operator dim={} rank={} implicit={}",
607 self.n_dim,
608 self.g_factor.ncols(),
609 op.is_implicit(),
610 );
611 dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
612 result
613 } else {
614 op.trace_projected_factor_cached(&self.g_factor, &self.projected_factor_cache)
615 }
616 }
617
618 fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
619 let hp_i = self.rotate_to_eigenbasis(h_i);
620 if std::ptr::eq(h_i, h_j) {
621 return self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_i);
622 }
623 let hp_j = self.rotate_to_eigenbasis(h_j);
624 self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
625 }
626
627 fn trace_logdet_hessian_cross_matrix_operator(
628 &self,
629 h_i: &Array2<f64>,
630 h_j: &dyn HyperOperator,
631 ) -> f64 {
632 let hp_i = self.rotate_to_eigenbasis(h_i);
633 let hp_j = self.projected_operator(&self.eigenvectors, h_j);
634 self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
635 }
636
637 fn trace_logdet_hessian_cross_operator(
638 &self,
639 h_i: &dyn HyperOperator,
640 h_j: &dyn HyperOperator,
641 ) -> f64 {
642 let hp_i = self.projected_operator(&self.eigenvectors, h_i);
643 if std::ptr::addr_eq(h_i, h_j) {
644 return self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_i);
645 }
646 let hp_j = self.projected_operator(&self.eigenvectors, h_j);
647 self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
648 }
649
650 fn active_rank(&self) -> usize {
651 self.active_mask.iter().filter(|&&active| active).count()
652 }
653
/// Returns `n_dim`, the side length of the matrix the operator was built from.
fn dim(&self) -> usize {
    self.n_dim
}
657
/// Always `true` for this operator.
fn is_dense(&self) -> bool {
    true
}
661
/// Always `false` for this operator; the trace methods here are computed
/// directly from the stored spectral factors.
fn prefers_stochastic_trace_estimation(&self) -> bool {
    false
}
665
/// Always `false` for this operator. The `trace_logdet_*` methods use
/// `g_factor` / `logdet_hessian_kernel`, which are built from different
/// scalings than the `w_factor` / `hinv_cross_kernel` used by the
/// `trace_hinv_*` methods.
///
/// NOTE(review): the precise contract of this flag is defined by the
/// `HessianOperator` trait, which is outside this block — confirm.
fn logdet_traces_match_hinv_kernel(&self) -> bool {
    false
}
669
/// Always returns `Some(self)`, giving callers access to the concrete
/// dense spectral operator.
fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
    Some(self)
}
673}