1use super::*;
2
3pub struct SparseCholeskyOperator {
12 pub(crate) factor: std::sync::Arc<gam_linalg::sparse_exact::SparseExactFactor>,
14 pub(crate) takahashi: Option<std::sync::Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
17 pub(crate) cached_logdet: f64,
19 pub(crate) n_dim: usize,
21}
22
23impl SparseCholeskyOperator {
24 pub fn new(
26 factor: std::sync::Arc<gam_linalg::sparse_exact::SparseExactFactor>,
27 logdet_h: f64,
28 dim: usize,
29 ) -> Self {
30 Self {
31 factor,
32 takahashi: None,
33 cached_logdet: logdet_h,
34 n_dim: dim,
35 }
36 }
37
38 pub fn with_takahashi(
39 mut self,
40 taka: std::sync::Arc<gam_linalg::sparse_exact::TakahashiInverse>,
41 ) -> Self {
42 self.takahashi = Some(taka);
43 self
44 }
45
46 pub(crate) const OPERATOR_SOLVE_CHUNK: usize = 64;
47
48 pub(crate) fn takahashi_block_trace(
49 taka: &gam_linalg::sparse_exact::TakahashiInverse,
50 block: &Array2<f64>,
51 start: usize,
52 ) -> f64 {
53 assert_eq!(block.nrows(), block.ncols());
54 let mut trace = 0.0;
55 for i in 0..block.nrows() {
56 let diag = block[[i, i]];
57 if diag.abs() > 1e-30 {
58 trace += taka.get(start + i, start + i) * diag;
59 }
60 for j in (i + 1)..block.ncols() {
61 let pair = block[[i, j]] + block[[j, i]];
62 if pair.abs() > 1e-30 {
63 trace += taka.get(start + i, start + j) * pair;
64 }
65 }
66 }
67 trace
68 }
69
70 pub(crate) fn takahashi_left_multiply_block(
71 taka: &gam_linalg::sparse_exact::TakahashiInverse,
72 block: &Array2<f64>,
73 start: usize,
74 ) -> Array2<f64> {
75 let dim = block.nrows();
76 let mut out = Array2::<f64>::zeros((dim, dim));
77 for i in 0..dim {
78 let z_diag = taka.get(start + i, start + i);
79 if z_diag.abs() > 1e-30 {
80 for k in 0..dim {
81 out[[i, k]] += z_diag * block[[i, k]];
82 }
83 }
84 for j in (i + 1)..dim {
85 let z = taka.get(start + i, start + j);
86 if z.abs() <= 1e-30 {
87 continue;
88 }
89 for k in 0..dim {
90 out[[i, k]] += z * block[[j, k]];
91 out[[j, k]] += z * block[[i, k]];
92 }
93 }
94 }
95 out
96 }
97
98 pub(crate) fn trace_hinv_operator_exact(&self, op: &dyn HyperOperator) -> f64 {
99 let (range_start, range_end) = op
100 .block_local_data()
101 .map(|(_, start, end)| (start, end))
102 .unwrap_or((0, self.n_dim));
103 let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
104 let mut trace = 0.0_f64;
105 let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
106 let mut start = range_start;
107
108 while start < range_end {
109 let end = (start + chunk).min(range_end);
110 let cols = end - start;
111 op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
112
113 let diagonal_sum = if cols == chunk {
114 gam_linalg::sparse_exact::solve_sparse_spdmulti_diagonal_sum(
115 &self.factor,
116 &rhs_block,
117 start,
118 )
119 } else {
120 let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
121 gam_linalg::sparse_exact::solve_sparse_spdmulti_diagonal_sum(
122 &self.factor,
123 &rhs_view,
124 start,
125 )
126 };
127 trace += diagonal_sum.unwrap_or_else(|e| {
128 reml_contract_panic(format!(
137 "SparseCholeskyOperator exact trace_hinv_operator solve failed: {e}"
138 ))
139 });
140 start = end;
141 }
142
143 trace
144 }
145
146 pub(crate) fn solve_operator_column_range_rows_exact(
147 &self,
148 op: &dyn HyperOperator,
149 col_start: usize,
150 col_end: usize,
151 row_start: usize,
152 row_end: usize,
153 ) -> Result<Array2<f64>, String> {
154 let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
155 let cols_total = col_end - col_start;
156 let rows_total = row_end - row_start;
157 let mut solved = Array2::<f64>::zeros((rows_total, cols_total));
158 let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
159 let mut start = col_start;
160
161 while start < col_end {
162 let end = (start + chunk).min(col_end);
163 let cols = end - start;
164 op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
165
166 let solved_block = if cols == chunk {
167 gam_linalg::sparse_exact::solve_sparse_spdmulti_rows(
168 &self.factor,
169 &rhs_block,
170 row_start,
171 row_end,
172 )
173 } else {
174 let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
175 gam_linalg::sparse_exact::solve_sparse_spdmulti_rows(
176 &self.factor,
177 &rhs_view,
178 row_start,
179 row_end,
180 )
181 }
182 .map_err(|e| {
183 format!(
184 "SparseCholeskyOperator::solve_operator_column_range_rows_exact multi-solve failed: {e}"
185 )
186 })?;
187 solved
188 .slice_mut(ndarray::s![.., start - col_start..end - col_start])
189 .assign(&solved_block);
190 start = end;
191 }
192
193 Ok(solved)
194 }
195
196 pub(crate) fn trace_hinv_matrix_operator_cross_exact(
197 &self,
198 matrix: &Array2<f64>,
199 op: &dyn HyperOperator,
200 ) -> f64 {
201 if let Some((_, range_start, range_end)) = op.block_local_data()
202 && range_end - range_start < self.n_dim
203 {
204 return self.trace_hinv_matrix_block_operator_cross_exact(
205 matrix,
206 op,
207 range_start,
208 range_end,
209 );
210 }
211
212 let solved_matrix = self.solve_multi(matrix);
213 let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
214 let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
215 let mut trace = 0.0_f64;
216 let (range_start, range_end) = op
217 .block_local_data()
218 .map(|(_, start, end)| (start, end))
219 .unwrap_or((0, self.n_dim));
220 let mut start = range_start;
221
222 while start < range_end {
223 let end = (start + chunk).min(range_end);
224 let cols = end - start;
225 op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
226
227 let solved_op = if cols == chunk {
228 gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_block)
229 } else {
230 let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
231 gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
232 };
233
234 let solved_op = solved_op.unwrap_or_else(|e| {
235 panic!("SparseCholeskyOperator exact matrix/operator cross solve failed: {e}")
242 });
243
244 for local_col in 0..cols {
245 let matrix_row = start + local_col;
246 for row in 0..self.n_dim {
247 trace += solved_matrix[[matrix_row, row]] * solved_op[[row, local_col]];
248 }
249 }
250 start = end;
251 }
252
253 trace
254 }
255
256 pub(crate) fn trace_hinv_matrix_block_operator_cross_exact(
257 &self,
258 matrix: &Array2<f64>,
259 op: &dyn HyperOperator,
260 range_start: usize,
261 range_end: usize,
262 ) -> f64 {
263 let t_start = std::time::Instant::now();
264 let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
265 let mut op_rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
266 let mut eye_rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
267 let mut trace = 0.0_f64;
268 let mut start = range_start;
269
270 while start < range_end {
271 let end = (start + chunk).min(range_end);
272 let cols = end - start;
273 op.mul_basis_columns_into(start, op_rhs_block.slice_mut(ndarray::s![.., ..cols]));
274
275 eye_rhs_block.fill(0.0);
276 for local_col in 0..cols {
277 eye_rhs_block[[start + local_col, local_col]] = 1.0;
278 }
279
280 let solved_op = if cols == chunk {
281 gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &op_rhs_block)
282 } else {
283 let rhs_view = op_rhs_block.slice(ndarray::s![.., ..cols]);
284 gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
285 };
286 let solved_op = solved_op.unwrap_or_else(|e| {
287 panic!(
293 "SparseCholeskyOperator exact matrix/block-operator cross operator solve failed: {e}"
294 )
295 });
296
297 let solved_eye = if cols == chunk {
298 gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &eye_rhs_block)
299 } else {
300 let rhs_view = eye_rhs_block.slice(ndarray::s![.., ..cols]);
301 gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
302 };
303 let solved_eye = solved_eye.unwrap_or_else(|e| {
304 panic!(
310 "SparseCholeskyOperator exact matrix/block-operator cross identity solve failed: {e}"
311 )
312 });
313
314 let selected_rows_t = matrix.t().dot(&solved_eye);
315 for local_col in 0..cols {
316 for row in 0..self.n_dim {
317 trace += selected_rows_t[[row, local_col]] * solved_op[[row, local_col]];
318 }
319 }
320 start = end;
321 }
322
323 let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
324 if elapsed_ms > REML_TRACE_SLOW_LOG_MS {
325 log::info!(
326 "[REML-trace] matrix_block_op_cross_exact | n_dim={} | block={} | {:.1}ms",
327 self.n_dim,
328 range_end - range_start,
329 elapsed_ms
330 );
331 }
332 trace
333 }
334
335 pub(crate) fn trace_hinv_operator_cross_exact(
336 &self,
337 left: &dyn HyperOperator,
338 right: &dyn HyperOperator,
339 ) -> f64 {
340 let (left_start, left_end) = left
341 .block_local_data()
342 .map(|(_, start, end)| (start, end))
343 .unwrap_or((0, self.n_dim));
344 let (right_start, right_end) = right
345 .block_local_data()
346 .map(|(_, start, end)| (start, end))
347 .unwrap_or((0, self.n_dim));
348
349 let solved_left = self
350 .solve_operator_column_range_rows_exact(
351 left,
352 left_start,
353 left_end,
354 right_start,
355 right_end,
356 )
357 .unwrap_or_else(|e| {
358 panic!("SparseCholeskyOperator exact operator cross left solve failed: {e}")
365 });
366 let same_operator =
367 std::ptr::addr_eq(left, right) && left_start == right_start && left_end == right_end;
368 let solved_right = if same_operator {
369 None
370 } else {
371 Some(
372 self.solve_operator_column_range_rows_exact(
373 right,
374 right_start,
375 right_end,
376 left_start,
377 left_end,
378 )
379 .unwrap_or_else(|e| {
380 panic!("SparseCholeskyOperator exact operator cross right solve failed: {e}")
386 }),
387 )
388 };
389
390 let right_cols = right_end - right_start;
391 let mut trace = 0.0;
392 for left_col in 0..(left_end - left_start) {
393 for right_col in 0..right_cols {
394 let right_value = match solved_right.as_ref() {
395 Some(solved) => solved[[left_col, right_col]],
396 None => solved_left[[left_col, right_col]],
397 };
398 trace += solved_left[[right_col, left_col]] * right_value;
399 }
400 }
401 trace
402 }
403}
404
405impl HessianOperator for SparseCholeskyOperator {
406 fn logdet(&self) -> f64 {
407 self.cached_logdet
408 }
409
410 fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
411 let h = gam_linalg::sparse_exact::assemble_sparse_factor_h_dense(&self.factor)
412 .map_err(|e| e.to_string())?;
413 if h.nrows() != self.n_dim || h.ncols() != self.n_dim {
414 return Err(format!(
415 "sparse Cholesky tangent projection dense H has shape {}x{}, expected {}x{}",
416 h.nrows(),
417 h.ncols(),
418 self.n_dim,
419 self.n_dim
420 ));
421 }
422 Ok(h)
423 }
424
425 fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
426 if let Some(ref taka) = self.takahashi {
429 let mut trace = 0.0;
430 for i in 0..a.nrows() {
431 let a_ii = a[[i, i]];
432 if a_ii.abs() > 1e-30 {
433 trace += taka.get(i, i) * a_ii;
434 }
435 for j in (i + 1)..a.ncols() {
436 let pair = a[[i, j]] + a[[j, i]];
437 if pair.abs() > 1e-30 {
438 trace += taka.get(i, j) * pair;
439 }
440 }
441 }
442 return trace;
443 }
444 gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, a)
445 .unwrap_or_else(|e| {
446 panic!("SparseCholeskyOperator exact trace_hinv_product solve failed: {e}")
453 })
454 .diag()
455 .sum()
456 }
457
458 fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
459 if let Some(ref taka) = self.takahashi {
460 if let Some((local, start, end)) = op.block_local_data() {
461 assert_eq!(local.nrows(), end - start);
462 return Self::takahashi_block_trace(taka, local, start);
463 }
464 if !op.is_implicit() {
466 let dense = op.to_dense();
467 return self.trace_hinv_product(&dense);
468 }
469 }
470 self.trace_hinv_operator_exact(op)
471 }
472
473 fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
474 self.trace_hinv_operator(op)
475 }
476
477 fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
478 gam_linalg::sparse_exact::solve_sparse_spd(&self.factor, rhs)
483 .unwrap_or_else(|e| panic!("SparseCholeskyOperator exact solve failed: {e}"))
485 }
486
487 fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
488 gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, rhs)
492 .unwrap_or_else(|e| panic!("SparseCholeskyOperator exact multi-solve failed: {e}"))
494 }
495
496 fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
497 let solved_a = self.solve_multi(a);
501 if std::ptr::eq(a, b) {
502 return trace_matrix_product(&solved_a, &solved_a);
503 }
504 let solved_b = self.solve_multi(b);
505 trace_matrix_product(&solved_a, &solved_b)
506 }
507
508 fn trace_hinv_matrix_operator_cross(
509 &self,
510 matrix: &Array2<f64>,
511 op: &dyn HyperOperator,
512 ) -> f64 {
513 self.trace_hinv_matrix_operator_cross_exact(matrix, op)
517 }
518
519 fn trace_hinv_operator_cross(
520 &self,
521 left: &dyn HyperOperator,
522 right: &dyn HyperOperator,
523 ) -> f64 {
524 if let Some(ref taka) = self.takahashi
527 && let (Some((a_local, a_start, a_end)), Some((b_local, b_start, b_end))) =
528 (left.block_local_data(), right.block_local_data())
529 && a_start == b_start
530 && a_end == b_end
531 {
532 let za = Self::takahashi_left_multiply_block(taka, a_local, a_start);
534 if std::ptr::addr_eq(left, right) {
535 return trace_matrix_product(&za, &za);
536 }
537 let zb = Self::takahashi_left_multiply_block(taka, b_local, b_start);
538 return (&za * &zb.t()).sum();
540 }
541 self.trace_hinv_operator_cross_exact(left, right)
544 }
545
546 fn trace_logdet_hessian_cross_matrix_operator(
547 &self,
548 h_i: &Array2<f64>,
549 h_j: &dyn HyperOperator,
550 ) -> f64 {
551 -self.trace_hinv_matrix_operator_cross(h_i, h_j)
552 }
553
554 fn trace_logdet_hessian_cross_operator(
555 &self,
556 h_i: &dyn HyperOperator,
557 h_j: &dyn HyperOperator,
558 ) -> f64 {
559 -self.trace_hinv_operator_cross(h_i, h_j)
560 }
561
562 fn active_rank(&self) -> usize {
563 self.n_dim
564 }
565
566 fn dim(&self) -> usize {
567 self.n_dim
568 }
569}
570
571pub struct DenseCholeskyValueOnlyOperator {
600 pub(crate) chol: gam_linalg::faer_ndarray::FaerCholeskyFactor,
602 pub(crate) cached_logdet: f64,
604 pub(crate) n_dim: usize,
606}
607
608impl DenseCholeskyValueOnlyOperator {
609 pub fn from_spd(h: &Array2<f64>) -> Result<Self, String> {
615 use gam_linalg::faer_ndarray::FaerCholesky;
616 use faer::Side;
617
618 let n = h.nrows();
619 if n != h.ncols() {
620 return Err(format!(
621 "DenseCholeskyValueOnlyOperator: expected square matrix, got {}×{}",
622 n,
623 h.ncols()
624 ));
625 }
626 let chol = h
627 .cholesky(Side::Lower)
628 .map_err(|e| format!("DenseCholeskyValueOnlyOperator LLT failed: {e}"))?;
629 let diag = chol.diag();
630 let cached_logdet = 2.0 * diag.iter().map(|&d| d.ln()).sum::<f64>();
631 Ok(Self {
632 chol,
633 cached_logdet,
634 n_dim: n,
635 })
636 }
637}
638
639impl HessianOperator for DenseCholeskyValueOnlyOperator {
640 fn logdet(&self) -> f64 {
641 self.cached_logdet
642 }
643
644 fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
645 let hinv_a = self.chol.solve_mat(a);
648 hinv_a.diag().iter().sum()
649 }
650
651 fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
652 self.chol.solvevec(rhs)
653 }
654
655 fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
656 self.chol.solve_mat(rhs)
657 }
658
659 fn active_rank(&self) -> usize {
660 self.n_dim
662 }
663
664 fn dim(&self) -> usize {
665 self.n_dim
666 }
667}
668
669pub struct BlockCoupledOperator {
690 pub(crate) inner: DenseSpectralOperator,
692}
693
694impl BlockCoupledOperator {
695 pub fn from_joint_hessian_with_mode(
699 joint_hessian: &Array2<f64>,
700 mode: PseudoLogdetMode,
701 ) -> Result<Self, String> {
702 let inner = DenseSpectralOperator::from_symmetric_with_mode(joint_hessian, mode)
703 .map_err(|e| format!("BlockCoupledOperator eigendecomposition: {e}"))?;
704
705 Ok(Self { inner })
706 }
707}
708
709impl HessianOperator for BlockCoupledOperator {
710 fn logdet(&self) -> f64 {
711 self.inner.logdet()
712 }
713
714 fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
715 self.inner.as_exact_dense_spectral()
716 }
717
718 fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
719 self.inner.assemble_h_dense_for_tangent_projection()
720 }
721
722 fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
723 self.inner.trace_hinv_product(a)
724 }
725
726 fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
727 self.inner.trace_logdet_gradient(a)
728 }
729
730 fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
731 self.inner.xt_logdet_kernel_x_diagonal(x)
732 }
733
734 fn trace_logdet_h_k(
735 &self,
736 a_k: &Array2<f64>,
737 third_deriv_correction: Option<&Array2<f64>>,
738 ) -> f64 {
739 self.inner.trace_logdet_h_k(a_k, third_deriv_correction)
740 }
741
742 fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
743 self.inner.trace_logdet_operator(op)
744 }
745
746 fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
747 self.inner.trace_logdet_hessian_cross(h_i, h_j)
748 }
749
750 fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
751 self.inner.solve(rhs)
752 }
753
754 fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
755 self.inner.solve_multi(rhs)
756 }
757
758 fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
759 self.inner.trace_hinv_product_cross(a, b)
760 }
761
762 fn trace_hinv_matrix_operator_cross(
763 &self,
764 matrix: &Array2<f64>,
765 op: &dyn HyperOperator,
766 ) -> f64 {
767 self.inner.trace_hinv_matrix_operator_cross(matrix, op)
768 }
769
770 fn trace_hinv_operator_cross(
771 &self,
772 left: &dyn HyperOperator,
773 right: &dyn HyperOperator,
774 ) -> f64 {
775 self.inner.trace_hinv_operator_cross(left, right)
776 }
777
778 fn active_rank(&self) -> usize {
779 self.inner.active_rank()
780 }
781
782 fn dim(&self) -> usize {
783 self.inner.dim()
784 }
785
786 fn is_dense(&self) -> bool {
787 true
788 }
789
790 fn prefers_stochastic_trace_estimation(&self) -> bool {
791 false
792 }
793
794 fn logdet_traces_match_hinv_kernel(&self) -> bool {
795 false
796 }
797
798 fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
799 Some(&self.inner)
800 }
801}
802
803pub struct MatrixFreeSpdOperator {
816 pub(crate) apply: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
817 pub(crate) dense_assemble: Option<Arc<dyn Fn() -> Option<Array2<f64>> + Send + Sync>>,
829 pub(crate) cached_logdet: gam_runtime::resource::RayonSafeOnce<f64>,
830 pub(crate) n_dim: usize,
831 pub(crate) dense_spectral: gam_runtime::resource::RayonSafeOnce<Option<DenseSpectralOperator>>,
841 pub(crate) mode: PseudoLogdetMode,
850}
851
852impl MatrixFreeSpdOperator {
853 pub(crate) const EXACT_DENSE_SPECTRAL_MAX_BYTES: usize = 512 * 1024 * 1024;
854 pub(crate) const EXACT_DENSE_SPECTRAL_ARRAYS: usize = 6;
855
856 pub fn new_with_mode<F>(dim: usize, apply: F, mode: PseudoLogdetMode) -> Self
857 where
858 F: Fn(&Array1<f64>) -> Array1<f64> + Send + Sync + 'static,
859 {
860 Self::new_with_mode_and_dense_assemble(dim, apply, mode, None)
861 }
862
863 pub fn new_with_mode_and_dense_assemble<F>(
868 dim: usize,
869 apply: F,
870 mode: PseudoLogdetMode,
871 dense_assemble: Option<Arc<dyn Fn() -> Option<Array2<f64>> + Send + Sync>>,
872 ) -> Self
873 where
874 F: Fn(&Array1<f64>) -> Array1<f64> + Send + Sync + 'static,
875 {
876 let apply = Arc::new(apply);
877
878 Self {
879 apply,
880 dense_assemble,
881 cached_logdet: gam_runtime::resource::RayonSafeOnce::new(),
882 n_dim: dim,
883 dense_spectral: gam_runtime::resource::RayonSafeOnce::new(),
884 mode,
885 }
886 }
887
888 pub(crate) fn exact_dense_spectral_bytes(&self) -> Option<usize> {
889 self.n_dim
890 .checked_mul(self.n_dim)?
891 .checked_mul(std::mem::size_of::<f64>())?
892 .checked_mul(Self::EXACT_DENSE_SPECTRAL_ARRAYS)
893 }
894
895 pub(crate) fn exact_dense_spectral_budget_ok(&self) -> bool {
896 match self.exact_dense_spectral_bytes() {
897 Some(bytes) if bytes <= Self::EXACT_DENSE_SPECTRAL_MAX_BYTES => true,
898 Some(bytes) => {
899 log::error!(
900 "MatrixFreeSpdOperator exact dense spectral materialization requires {:.2} GiB \
901 for dim={}, exceeding the {:.2} GiB cap",
902 bytes as f64 / (1024.0 * 1024.0 * 1024.0),
903 self.n_dim,
904 Self::EXACT_DENSE_SPECTRAL_MAX_BYTES as f64 / (1024.0 * 1024.0 * 1024.0),
905 );
906 false
907 }
908 None => {
909 log::error!(
910 "MatrixFreeSpdOperator exact dense spectral byte count overflow for dim={}",
911 self.n_dim
912 );
913 false
914 }
915 }
916 }
917
918 pub(crate) fn materialize_dense_operator(&self) -> Option<DenseSpectralOperator> {
919 if !self.exact_dense_spectral_budget_ok() {
920 return None;
921 }
922 let materialize_start = std::time::Instant::now();
923 let (matrix, matvec_count) =
929 match self.dense_assemble.as_ref().and_then(|assemble| assemble()) {
930 Some(mut direct)
931 if direct.nrows() == self.n_dim
932 && direct.ncols() == self.n_dim
933 && direct.iter().all(|v| v.is_finite()) =>
934 {
935 for i in 0..self.n_dim {
939 for j in (i + 1)..self.n_dim {
940 let avg = 0.5 * (direct[[i, j]] + direct[[j, i]]);
941 direct[[i, j]] = avg;
942 direct[[j, i]] = avg;
943 }
944 }
945 (direct, 0usize)
946 }
947 _ => {
948 let mut matrix = Array2::<f64>::zeros((self.n_dim, self.n_dim));
949 let mut basis = Array1::<f64>::zeros(self.n_dim);
950 for j in 0..self.n_dim {
951 basis[j] = 1.0;
952 let col = (self.apply)(&basis);
953 basis[j] = 0.0;
954 if col.len() != self.n_dim || !col.iter().all(|v| v.is_finite()) {
955 return None;
956 }
957 matrix.column_mut(j).assign(&col);
958 }
959 for i in 0..self.n_dim {
960 for j in (i + 1)..self.n_dim {
961 let avg = 0.5 * (matrix[[i, j]] + matrix[[j, i]]);
962 matrix[[i, j]] = avg;
963 matrix[[j, i]] = avg;
964 }
965 }
966 (matrix, self.n_dim)
967 }
968 };
969 let result = DenseSpectralOperator::from_symmetric_with_mode(&matrix, self.mode).ok();
970 log::info!(
971 "[STAGE] matrix_free_spd materialize n_dim={} matvec_count={} elapsed={:.3}s",
972 self.n_dim,
973 matvec_count,
974 materialize_start.elapsed().as_secs_f64(),
975 );
976 result
977 }
978
979 pub(crate) fn dense_spectral(&self) -> Option<&DenseSpectralOperator> {
980 self.dense_spectral
981 .get_or_compute(|| self.materialize_dense_operator())
982 .as_ref()
983 }
984
985 pub(crate) fn exact_dense_spectral(&self) -> &DenseSpectralOperator {
986 self.dense_spectral().expect(
987 "MatrixFreeSpdOperator exact REML algebra requires dense spectral materialization within the configured budget",
988 )
989 }
990
991 pub(crate) fn use_trace_cg(&self, rel_tol: f64) -> bool {
992 rel_tol.is_finite()
993 && rel_tol > 0.0
994 && self.prefers_stochastic_trace_estimation()
995 && self.has_matrix_free_trace_cg_operator()
996 }
997
998 pub(crate) fn cg_trace_solve(
999 &self,
1000 rhs: &Array1<f64>,
1001 rel_tol: f64,
1002 probe_id: Option<u64>,
1003 trace_state: Option<&Arc<Mutex<StochasticTraceState>>>,
1004 ) -> Array1<f64> {
1005 let dim = rhs.len();
1006 if dim != self.n_dim {
1007 return self.solve(rhs);
1008 }
1009
1010 let (initial, warm_start_used) = match (probe_id, trace_state) {
1011 (Some(id), Some(state)) => {
1012 let cached = match state.lock() {
1013 Ok(guard) => guard.cg_warm_starts.get(&id).cloned(),
1014 Err(poisoned) => poisoned.into_inner().cg_warm_starts.get(&id).cloned(),
1015 };
1016 match cached {
1017 Some(x) if x.len() == dim => (x, true),
1018 _ => (Array1::<f64>::zeros(dim), false),
1019 }
1020 }
1021 _ => (Array1::<f64>::zeros(dim), false),
1022 };
1023
1024 let Some((solution, iters, residual_norm)) =
1025 conjugate_gradient_trace_solve(rhs, rel_tol, initial, |v| (self.apply)(v))
1026 else {
1027 return self.solve(rhs);
1028 };
1029
1030 if let Some(state) = trace_state {
1031 let mut guard = match state.lock() {
1032 Ok(guard) => guard,
1033 Err(poisoned) => poisoned.into_inner(),
1034 };
1035 guard.last_linear_residual_norm = Some(
1036 guard
1037 .last_linear_residual_norm
1038 .unwrap_or(0.0)
1039 .max(residual_norm),
1040 );
1041 if let Some(id) = probe_id {
1042 guard.cg_warm_starts.insert(id, solution.clone());
1043 }
1044 }
1045
1046 let probe_label = probe_id
1047 .map(|id| id.to_string())
1048 .unwrap_or_else(|| "untracked".to_string());
1049 log::info!(
1050 "[CG-TRACE] probe_id={} iters={} rel_tol={} warm_start_used={}",
1051 probe_label,
1052 iters,
1053 rel_tol,
1054 warm_start_used
1055 );
1056
1057 solution
1058 }
1059}
1060
1061pub(crate) fn conjugate_gradient_trace_solve<F>(
1062 rhs: &Array1<f64>,
1063 rel_tol: f64,
1064 mut x: Array1<f64>,
1065 apply: F,
1066) -> Option<(Array1<f64>, usize, f64)>
1067where
1068 F: Fn(&Array1<f64>) -> Array1<f64>,
1069{
1070 let dim = rhs.len();
1071 if x.len() != dim {
1072 return None;
1073 }
1074
1075 let rhs_norm_sq = rhs.dot(rhs);
1076 if !rhs_norm_sq.is_finite() {
1077 return None;
1078 }
1079 if rhs_norm_sq <= f64::MIN_POSITIVE {
1080 return Some((Array1::<f64>::zeros(dim), 0, 0.0));
1081 }
1082
1083 let target_sq = (rel_tol * rel_tol * rhs_norm_sq).max(f64::MIN_POSITIVE);
1084 let mut r = rhs.clone();
1085 if x.iter().any(|value| *value != 0.0) {
1086 let ax = apply(&x);
1087 if ax.len() != dim || !ax.iter().all(|value| value.is_finite()) {
1088 return None;
1089 }
1090 r.scaled_add(-1.0, &ax);
1091 }
1092
1093 let mut rs_old = r.dot(&r);
1094 if !rs_old.is_finite() {
1095 return None;
1096 }
1097 if rs_old <= target_sq {
1098 return Some((x, 0, rs_old.max(0.0).sqrt()));
1099 }
1100
1101 let mut p = r.clone();
1102 let mut iters = 0usize;
1103 let mut residual_norm = rs_old.max(0.0).sqrt();
1104 for k in 0..dim.max(1) {
1105 let ap = apply(&p);
1106 if ap.len() != dim || !ap.iter().all(|value| value.is_finite()) {
1107 return None;
1108 }
1109 let denom = p.dot(&ap);
1110 if !denom.is_finite() || denom <= 0.0 {
1111 log::warn!(
1112 "[CG-TRACE] non-positive curvature in trace CG at iter={} denom={}",
1113 k + 1,
1114 denom
1115 );
1116 break;
1117 }
1118 let alpha = rs_old / denom;
1119 if !alpha.is_finite() {
1120 return None;
1121 }
1122 x.scaled_add(alpha, &p);
1123 r.scaled_add(-alpha, &ap);
1124 let rs_new = r.dot(&r);
1125 if !rs_new.is_finite() {
1126 return None;
1127 }
1128 iters = k + 1;
1129 residual_norm = rs_new.max(0.0).sqrt();
1130 if rs_new <= target_sq {
1131 break;
1132 }
1133 let beta = rs_new / rs_old;
1134 if !beta.is_finite() {
1135 return None;
1136 }
1137 p.mapv_inplace(|value| beta * value);
1138 p += &r;
1139 rs_old = rs_new;
1140 }
1141
1142 Some((x, iters, residual_norm))
1143}
1144
1145impl HessianOperator for MatrixFreeSpdOperator {
1146 fn logdet(&self) -> f64 {
1147 *self
1148 .cached_logdet
1149 .get_or_compute(|| self.exact_dense_spectral().logdet())
1150 }
1151
1152 fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
1153 Some(self.exact_dense_spectral())
1154 }
1155
1156 fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
1157 self.exact_dense_spectral().trace_hinv_product(a)
1158 }
1159
1160 fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
1161 self.exact_dense_spectral().trace_hinv_operator(op)
1162 }
1163
1164 fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
1165 self.exact_dense_spectral().trace_hinv_product_cross(a, b)
1166 }
1167
1168 fn trace_hinv_matrix_operator_cross(
1169 &self,
1170 matrix: &Array2<f64>,
1171 op: &dyn HyperOperator,
1172 ) -> f64 {
1173 self.exact_dense_spectral()
1174 .trace_hinv_matrix_operator_cross(matrix, op)
1175 }
1176
1177 fn trace_hinv_operator_cross(
1178 &self,
1179 left: &dyn HyperOperator,
1180 right: &dyn HyperOperator,
1181 ) -> f64 {
1182 self.exact_dense_spectral()
1183 .trace_hinv_operator_cross(left, right)
1184 }
1185
1186 fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
1187 let trace_start = std::time::Instant::now();
1188 let result = self.exact_dense_spectral().trace_logdet_operator(op);
1189 log::info!(
1190 "[STAGE] matrix_free_spd trace_logdet_operator implicit={} dim={} elapsed={:.3}s",
1191 op.is_implicit(),
1192 op.dim(),
1193 trace_start.elapsed().as_secs_f64(),
1194 );
1195 result
1196 }
1197
1198 fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
1199 self.exact_dense_spectral().solve(rhs)
1200 }
1201
1202 fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
1203 self.exact_dense_spectral().solve_multi(rhs)
1204 }
1205
1206 fn stochastic_trace_solve(&self, rhs: &Array1<f64>, rel_tol: f64) -> Array1<f64> {
1207 if self.use_trace_cg(rel_tol) {
1208 return self.cg_trace_solve(rhs, rel_tol, None, None);
1209 }
1210 self.solve(rhs)
1211 }
1212
1213 fn stochastic_trace_solve_for_probe(
1214 &self,
1215 rhs: &Array1<f64>,
1216 rel_tol: f64,
1217 probe_id: u64,
1218 trace_state: Option<&Arc<Mutex<StochasticTraceState>>>,
1219 ) -> Array1<f64> {
1220 if self.use_trace_cg(rel_tol) {
1221 return self.cg_trace_solve(rhs, rel_tol, Some(probe_id), trace_state);
1222 }
1223 self.solve(rhs)
1224 }
1225
1226 fn stochastic_trace_solve_multi(&self, rhs: &Array2<f64>, rel_tol: f64) -> Array2<f64> {
1227 if self.use_trace_cg(rel_tol) {
1228 let mut out = Array2::<f64>::zeros(rhs.raw_dim());
1229 for j in 0..rhs.ncols() {
1230 let solved = self.cg_trace_solve(&rhs.column(j).to_owned(), rel_tol, None, None);
1231 out.column_mut(j).assign(&solved);
1232 }
1233 return out;
1234 }
1235 self.solve_multi(rhs)
1236 }
1237
1238 fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
1239 self.exact_dense_spectral()
1240 .trace_logdet_hessian_cross(h_i, h_j)
1241 }
1242
1243 fn trace_logdet_hessian_cross_matrix_operator(
1244 &self,
1245 h_i: &Array2<f64>,
1246 h_j: &dyn HyperOperator,
1247 ) -> f64 {
1248 self.exact_dense_spectral()
1249 .trace_logdet_hessian_cross_matrix_operator(h_i, h_j)
1250 }
1251
1252 fn trace_logdet_hessian_cross_operator(
1253 &self,
1254 h_i: &dyn HyperOperator,
1255 h_j: &dyn HyperOperator,
1256 ) -> f64 {
1257 self.exact_dense_spectral()
1258 .trace_logdet_hessian_cross_operator(h_i, h_j)
1259 }
1260
1261 fn active_rank(&self) -> usize {
1262 self.n_dim
1263 }
1264
1265 fn dim(&self) -> usize {
1266 self.n_dim
1267 }
1268
1269 fn is_dense(&self) -> bool {
1270 true
1271 }
1272
1273 fn prefers_stochastic_trace_estimation(&self) -> bool {
1286 !self.exact_dense_spectral_budget_ok()
1287 }
1288
1289 fn logdet_traces_match_hinv_kernel(&self) -> bool {
1300 !self.exact_dense_spectral_budget_ok()
1301 }
1302
1303 fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
1304 self.dense_spectral()
1305 }
1306
1307 fn has_matrix_free_trace_cg_operator(&self) -> bool {
1308 true
1309 }
1310}
1311
1312pub fn penalty_matrix_root(s: &Array2<f64>) -> Result<Array2<f64>, String> {
1321 use faer::Side;
1322 let n = s.nrows();
1323 if n != s.ncols() {
1324 return Err(RemlError::DimensionMismatch {
1325 reason: format!(
1326 "penalty_matrix_root: expected square matrix, got {}×{}",
1327 n,
1328 s.ncols()
1329 ),
1330 }
1331 .into());
1332 }
1333 if n == 0 {
1334 return Ok(Array2::zeros((0, 0)));
1335 }
1336
1337 let (eigenvalues, eigenvectors) = s
1338 .eigh(Side::Lower)
1339 .map_err(|e| format!("penalty_matrix_root eigendecomposition failed: {e}"))?;
1340
1341 let max_ev = eigenvalues.iter().copied().fold(0.0_f64, f64::max);
1342 let tol = (n.max(1) as f64) * f64::EPSILON * max_ev.max(1e-12);
1343
1344 let active: Vec<usize> = eigenvalues
1345 .iter()
1346 .enumerate()
1347 .filter(|(_, v)| **v > tol)
1348 .map(|(i, _)| i)
1349 .collect();
1350 let rank = active.len();
1351
1352 let mut r = Array2::zeros((rank, n));
1353 for (out_row, &idx) in active.iter().enumerate() {
1354 let scale = eigenvalues[idx].sqrt();
1355 for col in 0..n {
1356 r[[out_row, col]] = scale * eigenvectors[[col, idx]];
1357 }
1358 }
1359 Ok(r)
1360}
1361
1362pub fn compute_block_penalty_logdet_derivs(
1381 per_block_rho: &[Array1<f64>],
1382 per_block_penalties: &[&[Array2<f64>]],
1383 ridge: f64,
1384) -> Result<PenaltyLogdetDerivs, String> {
1385 use super::super::penalty_logdet::PenaltyPseudologdet;
1386
1387 let total_k: usize = per_block_rho.iter().map(|r| r.len()).sum();
1388 let block_offsets: Vec<usize> = per_block_rho
1389 .iter()
1390 .scan(0usize, |at, rho| {
1391 let current = *at;
1392 *at += rho.len();
1393 Some(current)
1394 })
1395 .collect();
1396
1397 struct BlockPenaltyLogdetResult {
1398 pub(crate) offset: usize,
1399 pub(crate) value: f64,
1400 pub(crate) first: Array1<f64>,
1401 pub(crate) second: Array2<f64>,
1402 }
1403
1404 let compute_block = |(b, block_rho): (usize, &Array1<f64>)| {
1405 let penalties = per_block_penalties[b];
1406 let kb = block_rho.len();
1407 if penalties.is_empty() || kb == 0 {
1408 return Ok(BlockPenaltyLogdetResult {
1409 offset: block_offsets[b],
1410 value: 0.0,
1411 first: Array1::zeros(kb),
1412 second: Array2::zeros((kb, kb)),
1413 });
1414 }
1415 let lambdas: Vec<f64> = block_rho.iter().map(|&r| r.exp()).collect();
1416
1417 let pld = PenaltyPseudologdet::from_components(penalties, &lambdas, ridge)
1423 .map_err(|e| format!("penalty logdet failed for block {b}: {e}"))?;
1424
1425 let value = pld.value();
1426 let (first, second) = pld.rho_derivatives(penalties, &lambdas);
1427 Ok(BlockPenaltyLogdetResult {
1428 offset: block_offsets[b],
1429 value,
1430 first,
1431 second,
1432 })
1433 };
1434
1435 let block_results: Vec<BlockPenaltyLogdetResult> = if rayon::current_thread_index().is_some() {
1436 per_block_rho
1437 .iter()
1438 .enumerate()
1439 .map(compute_block)
1440 .collect::<Result<Vec<_>, String>>()?
1441 } else {
1442 per_block_rho
1443 .par_iter()
1444 .enumerate()
1445 .map(compute_block)
1446 .collect::<Result<Vec<_>, String>>()?
1447 };
1448
1449 let mut log_det_total = 0.0;
1450 let mut first = Array1::zeros(total_k);
1451 let mut second = Array2::zeros((total_k, total_k));
1452 for block in block_results {
1453 log_det_total += block.value;
1454 let kb = block.first.len();
1455 for k in 0..kb {
1456 first[block.offset + k] = block.first[k];
1457 }
1458 for k in 0..kb {
1459 for l in 0..kb {
1460 second[[block.offset + k, block.offset + l]] = block.second[[k, l]];
1461 }
1462 }
1463 }
1464
1465 Ok(PenaltyLogdetDerivs {
1466 value: log_det_total,
1467 first,
1468 second: Some(second),
1469 })
1470}
1471
1472