1use super::*;
2
/// A vector of an arrow-structured system, split into its latent block and its β block.
///
/// As produced by the solvers in this file, `t` has length `cache.delta_t_len()` and
/// concatenates the per-row latent entries (row `i` occupies
/// `row_offsets[i]..row_offsets[i] + row_dims[i]`); `beta` has length `cache.k`.
#[derive(Debug, Clone)]
pub struct SaeArrowVector {
    /// Latent (per-row, concatenated) part.
    pub t: Array1<f64>,
    /// β part.
    pub beta: Array1<f64>,
}
8
/// Solver over a cached arrow factorization, optionally with a Woodbury correction that
/// stiffens a set of gauge directions (see `from_orthonormal_gauges` / `solve`).
pub(crate) struct DeflatedArrowSolver<'a> {
    /// Cached factorization whose `full_inverse_apply` provides the base solve.
    pub(crate) cache: &'a ArrowFactorCache,
    /// Gauge directions in flattened `(t, β)` layout; empty for a plain solver.
    pub(crate) gauge_basis: Vec<Array1<f64>>,
    /// For each gauge `u_j`: the back-solved response with its components along the gauges
    /// removed, `H⁻¹u_j - Σ_i (u_iᵀ H⁻¹ u_j) u_i`.
    pub(crate) gauge_response_physical: Vec<Array1<f64>>,
    /// Cholesky factor of `I/stiffness + UᵀH⁻¹U`; `None` means no gauge correction.
    pub(crate) woodbury_factor: Option<FaerCholeskyFactor>,
    /// `1 / stiffness`; `0.0` for a plain solver.
    pub(crate) gauge_stiffness_recip: f64,
}
16
17impl<'a> DeflatedArrowSolver<'a> {
18 pub(crate) fn plain(cache: &'a ArrowFactorCache) -> Self {
19 Self {
20 cache,
21 gauge_basis: Vec::new(),
22 gauge_response_physical: Vec::new(),
23 woodbury_factor: None,
24 gauge_stiffness_recip: 0.0,
25 }
26 }
27
28 pub(crate) fn from_orthonormal_gauges(
29 cache: &'a ArrowFactorCache,
30 gauge_basis: Vec<Array1<f64>>,
31 stiffness: f64,
32 ) -> Result<Self, String> {
33 if gauge_basis.is_empty() {
34 return Ok(Self::plain(cache));
35 }
36 if !(stiffness.is_finite() && stiffness > 0.0) {
37 return Err(format!(
38 "DeflatedArrowSolver: gauge stiffness must be finite and positive; got {stiffness}"
39 ));
40 }
41 let full_len = cache.delta_t_len() + cache.k;
42 let mut gauge_responses = Vec::with_capacity(gauge_basis.len());
43 for gauge in &gauge_basis {
44 if gauge.len() != full_len {
45 return Err(format!(
46 "DeflatedArrowSolver: gauge length {} != cache full length {full_len}",
47 gauge.len()
48 ));
49 }
50 let (sol_t, sol_beta) = cache
51 .full_inverse_apply(
52 gauge.slice(s![..cache.delta_t_len()]),
53 gauge.slice(s![cache.delta_t_len()..]),
54 )
55 .map_err(|err| format!("DeflatedArrowSolver: gauge back-solve: {err}"))?;
56 gauge_responses.push(flatten_arrow_parts(sol_t.view(), sol_beta.view()));
57 }
58
59 let rank = gauge_basis.len();
60 let stiffness_recip = stiffness.recip();
61 let mut gauge_metric = Array2::<f64>::zeros((rank, rank));
62 let mut woodbury = Array2::<f64>::eye(rank);
63 for i in 0..rank {
64 woodbury[[i, i]] *= stiffness_recip;
65 for j in 0..rank {
66 let value = gauge_basis[i].dot(&gauge_responses[j]);
67 gauge_metric[[i, j]] = value;
68 woodbury[[i, j]] += value;
69 }
70 }
71 let woodbury_factor = woodbury
72 .cholesky(Side::Lower)
73 .map_err(|err| format!("DeflatedArrowSolver: gauge Woodbury factor failed: {err}"))?;
74 let mut gauge_response_physical = gauge_responses;
75 for j in 0..rank {
76 for i in 0..rank {
77 let coeff = gauge_metric[[i, j]];
78 for row in 0..full_len {
79 gauge_response_physical[j][row] -= coeff * gauge_basis[i][row];
80 }
81 }
82 }
83 Ok(Self {
84 cache,
85 gauge_basis,
86 gauge_response_physical,
87 woodbury_factor: Some(woodbury_factor),
88 gauge_stiffness_recip: stiffness_recip,
89 })
90 }
91
92 pub(crate) fn solve(
93 &self,
94 rhs_t: ArrayView1<'_, f64>,
95 rhs_beta: ArrayView1<'_, f64>,
96 ) -> Result<SaeArrowVector, String> {
97 let (sol_t, sol_beta) = self
98 .cache
99 .full_inverse_apply(rhs_t, rhs_beta)
100 .map_err(|err| format!("DeflatedArrowSolver: full inverse: {err}"))?;
101 let Some(factor) = self.woodbury_factor.as_ref() else {
102 return Ok(SaeArrowVector {
103 t: sol_t,
104 beta: sol_beta,
105 });
106 };
107
108 let full_len = self.cache.delta_t_len() + self.cache.k;
109 let mut flat = flatten_arrow_parts(sol_t.view(), sol_beta.view());
110 if flat.len() != full_len {
111 return Err(format!(
112 "DeflatedArrowSolver: solution length {} != cache full length {full_len}",
113 flat.len()
114 ));
115 }
116 let mut gauge_coeffs = Array1::<f64>::zeros(self.gauge_basis.len());
117 for (idx, gauge) in self.gauge_basis.iter().enumerate() {
118 gauge_coeffs[idx] = gauge.dot(&flat);
119 }
120 let weights = factor.solvevec(&gauge_coeffs);
121 for (gauge, &coeff) in self.gauge_basis.iter().zip(gauge_coeffs.iter()) {
122 for i in 0..flat.len() {
123 flat[i] -= gauge[i] * coeff;
124 }
125 }
126 for (response, &weight) in self.gauge_response_physical.iter().zip(weights.iter()) {
127 for i in 0..flat.len() {
128 flat[i] -= response[i] * weight;
129 }
130 }
131 for (gauge, &weight) in self.gauge_basis.iter().zip(weights.iter()) {
132 let coeff = self.gauge_stiffness_recip * weight;
133 for i in 0..flat.len() {
134 flat[i] += gauge[i] * coeff;
135 }
136 }
137 Ok(SaeArrowVector {
138 t: flat.slice(s![..self.cache.delta_t_len()]).to_owned(),
139 beta: flat.slice(s![self.cache.delta_t_len()..]).to_owned(),
140 })
141 }
142
143 pub(crate) fn latent_inverse_diagonal_kept(&self) -> Result<Array1<f64>, String> {
158 let mut out = self.latent_inverse_diagonal()?;
159 let cache = self.cache;
160 for (row, dirs) in cache.deflated_row_directions.iter().enumerate() {
161 if dirs.is_empty() {
162 continue;
163 }
164 let base = cache.row_offsets[row];
165 for v in dirs {
166 for s in 0..v.len() {
167 if base + s < out.len() {
168 out[base + s] -= v[s] * v[s];
169 }
170 }
171 }
172 }
173 Ok(out)
174 }
175
176 pub(crate) fn plain_selected_inverse_available(&self) -> bool {
186 self.woodbury_factor.is_none() && self.cache.cross_row_woodbury.is_none()
187 }
188
189 pub(crate) fn beta_inv(&self) -> Result<Array2<f64>, String> {
195 let k = self.cache.k;
196 if k == 0 {
197 return Ok(Array2::<f64>::zeros((0, 0)));
198 }
199 self.cache
200 .schur_inverse_block(0..k)
201 .map_err(|err| format!("DeflatedArrowSolver::beta_inv: {err}"))
202 }
203
204 pub(crate) fn selected_inverse_row_blocks(
222 &self,
223 row: usize,
224 beta_inv: &Array2<f64>,
225 ) -> Result<(Array2<f64>, Array2<f64>), String> {
226 let cache = self.cache;
227 let q = cache.row_dims[row];
228 let k = cache.k;
229 let factor = cache.undamped_factor(row);
230
231 let mut a_inv = Array2::<f64>::zeros((q, q));
233 let mut e_j = Array1::<f64>::zeros(q);
234 for j in 0..q {
235 e_j.fill(0.0);
236 e_j[j] = 1.0;
237 let col = cholesky_solve_vector(factor, e_j.view());
238 for r in 0..q {
239 a_inv[[r, j]] = col[r];
240 }
241 }
242
243 if k == 0 {
244 return Ok((a_inv, Array2::<f64>::zeros((q, 0))));
245 }
246
247 let mut g = Array2::<f64>::zeros((q, k));
250 let mut e_c = Array1::<f64>::zeros(k);
251 let mut b_col = Array1::<f64>::zeros(q);
252 for c in 0..k {
253 e_c.fill(0.0);
254 e_c[c] = 1.0;
255 b_col.fill(0.0);
256 if !cache.apply_htbeta_row(row, e_c.view(), &mut b_col) {
257 return Err(format!(
258 "DeflatedArrowSolver::selected_inverse_row_blocks: H_tβ^({row}) apply failed"
259 ));
260 }
261 let g_col = cholesky_solve_vector(factor, b_col.view());
262 for r in 0..q {
263 g[[r, c]] = g_col[r];
264 }
265 }
266
267 let mut gs = Array2::<f64>::zeros((q, k));
269 for r in 0..q {
270 for m in 0..k {
271 let mut acc = 0.0_f64;
272 for n in 0..k {
273 acc += g[[r, n]] * beta_inv[[n, m]];
274 }
275 gs[[r, m]] = acc;
276 }
277 }
278
279 let mut inv_vbeta = Array2::<f64>::zeros((q, k));
281 for col in 0..q {
282 for b in 0..k {
283 inv_vbeta[[col, b]] = -gs[[col, b]];
284 }
285 }
286
287 let mut inv_vv = a_inv;
289 for r in 0..q {
290 for col in 0..q {
291 let mut acc = 0.0_f64;
292 for m in 0..k {
293 acc += gs[[r, m]] * g[[col, m]];
294 }
295 inv_vv[[r, col]] += acc;
296 }
297 }
298
299 Ok((inv_vv, inv_vbeta))
300 }
301
302 pub(crate) fn latent_inverse_diagonal(&self) -> Result<Array1<f64>, String> {
303 if self.woodbury_factor.is_none() {
304 return self
305 .cache
306 .latent_block_inverse_diagonal()
307 .map_err(|err| format!("DeflatedArrowSolver: latent inverse diagonal: {err}"));
308 }
309 let total_t = self.cache.delta_t_len();
310 let mut out = Array1::<f64>::zeros(total_t);
311 let rhs_beta = Array1::<f64>::zeros(self.cache.k);
312 for idx in 0..total_t {
313 let mut rhs_t = Array1::<f64>::zeros(total_t);
314 rhs_t[idx] = 1.0;
315 let solved = self.solve(rhs_t.view(), rhs_beta.view())?;
316 out[idx] = solved.t[idx];
317 }
318 Ok(out)
319 }
320}
321
#[cfg(test)]
mod selected_inverse_row_blocks_oracle_tests {
    use super::*;
    use gam_solve::arrow_schur::{
        ArrowFactorSlab, ArrowHtbetaCache, ArrowSolverMode, ArrowUndampedFactors, PcgDiagnostics,
    };
    use ndarray::array;
    use std::sync::Arc;

    /// Absolute tolerance when comparing selected-inverse blocks against full solves.
    const TOL: f64 = 1e-9;

    /// Unit vector of length `len` with a single one at `idx`.
    fn unit(len: usize, idx: usize) -> Array1<f64> {
        let mut e = Array1::<f64>::zeros(len);
        e[idx] = 1.0;
        e
    }

    /// Two latent rows (dims 2 and 1) plus `k = 2` β entries.
    ///
    /// Both rows have nonzero dense `H_tβ` blocks, so the row-local blocks exercise the
    /// Schur correction. The latent and Schur entries are given as lower-triangular factors.
    /// No deflation and no cross-row Woodbury are set.
    fn coupled_arrow_cache() -> ArrowFactorCache {
        let htt = ArrowFactorSlab::from_blocks(vec![
            array![[1.3_f64, 0.0], [0.4, 1.1]],
            array![[0.9_f64]],
        ]);
        let schur = array![[1.2_f64, 0.0], [0.25, 0.95]];
        let htbeta_blocks = vec![
            array![[0.5_f64, -0.2], [0.1, 0.4]],
            array![[0.3_f64, 0.7]],
        ];
        ArrowFactorCache {
            htt_factors: htt,
            htt_factors_undamped: ArrowUndampedFactors::SameAsDamped,
            schur_factor: Some(schur),
            joint_hessian_log_det: None,
            solver_mode: ArrowSolverMode::Direct,
            ridge_t: 0.0,
            ridge_beta: 0.0,
            htbeta: ArrowHtbetaCache::Dense {
                blocks: Arc::from(htbeta_blocks.into_boxed_slice()),
                estimated_bytes: 0,
            },
            d: 2,
            row_dims: Arc::from(vec![2usize, 1usize].into_boxed_slice()),
            row_offsets: Arc::from(vec![0usize, 2usize, 3usize].into_boxed_slice()),
            k: 2,
            manifold_mode_fingerprint: 0,
            row_hessian_fingerprint: 0,
            pcg_diagnostics: PcgDiagnostics::default(),
            gauge_deflated_directions: 0,
            deflated_row_directions: Arc::from(Vec::new()),
            deflation_row_spectra: Arc::from(Vec::new()),
            cross_row_woodbury: None,
        }
    }

    /// The row-local selected-inverse blocks must match columns of the full inverse.
    ///
    /// The β-β block is compared against β unit solves. The t-t and t-β blocks of each row
    /// are compared against latent unit solves.
    #[test]
    fn row_local_blocks_match_per_row_solve() {
        let cache = coupled_arrow_cache();
        let solver = DeflatedArrowSolver::plain(&cache);
        assert!(
            solver.plain_selected_inverse_available(),
            "plain cache must take the fast selected-inverse path"
        );
        let (total_t, k) = (cache.delta_t_len(), cache.k);

        // β-β block: column `col` of beta_inv is the β part of H⁻¹ e_{β,col}.
        let beta_inv = solver.beta_inv().expect("beta_inv");
        let zero_t = Array1::<f64>::zeros(total_t);
        for col in 0..k {
            let solved = solver
                .solve(zero_t.view(), unit(k, col).view())
                .expect("β solve");
            for r in 0..k {
                let (block, full) = (beta_inv[[r, col]], solved.beta[r]);
                assert!(
                    (block - full).abs() <= TOL,
                    "beta_inv[{r},{col}] {} != solve {}",
                    block,
                    full
                );
            }
        }

        // Per-row t-t and t-β blocks against latent unit solves.
        let zero_beta = Array1::<f64>::zeros(k);
        for row in 0..cache.n_rows() {
            let (q, base) = (cache.row_dims[row], cache.row_offsets[row]);
            let (inv_vv, inv_vbeta) = solver
                .selected_inverse_row_blocks(row, &beta_inv)
                .expect("row blocks");
            for col in 0..q {
                let solved = solver
                    .solve(unit(total_t, base + col).view(), zero_beta.view())
                    .expect("t solve");
                for r in 0..q {
                    let (block, full) = (inv_vv[[r, col]], solved.t[base + r]);
                    assert!(
                        (block - full).abs() <= TOL,
                        "inv_vv[{r},{col}] {} != solve {}",
                        block,
                        full
                    );
                }
                for b in 0..k {
                    let (block, full) = (inv_vbeta[[col, b]], solved.beta[b]);
                    assert!(
                        (block - full).abs() <= TOL,
                        "inv_vbeta[{col},{b}] {} != solve {}",
                        block,
                        full
                    );
                }
            }
        }
    }
}
443
444pub(crate) fn flatten_arrow_parts(
445 t: ArrayView1<'_, f64>,
446 beta: ArrayView1<'_, f64>,
447) -> Array1<f64> {
448 let mut out = Array1::<f64>::zeros(t.len() + beta.len());
449 for i in 0..t.len() {
450 out[i] = t[i];
451 }
452 for i in 0..beta.len() {
453 out[t.len() + i] = beta[i];
454 }
455 out
456}
457
458pub(crate) fn apply_cached_arrow_hessian(
459 cache: &ArrowFactorCache,
460 v_t: ArrayView1<'_, f64>,
461 v_beta: ArrayView1<'_, f64>,
462) -> Result<SaeArrowVector, String> {
463 let total_t = cache.delta_t_len();
464 if v_t.len() != total_t || v_beta.len() != cache.k {
465 return Err(format!(
466 "apply_cached_arrow_hessian: vector shapes (t={}, beta={}) != cache shapes \
467 (t={total_t}, beta={})",
468 v_t.len(),
469 v_beta.len(),
470 cache.k
471 ));
472 }
473
474 let mut out_t = Array1::<f64>::zeros(total_t);
475 let mut out_beta = Array1::<f64>::zeros(cache.k);
476 for row in 0..cache.n_rows() {
477 let di = cache.row_dims[row];
478 let base = cache.row_offsets[row];
479 let row_v = v_t.slice(s![base..base + di]);
480 let factor = cache.undamped_factor(row);
481 let av = cholesky_factor_apply(factor, row_v);
482 for j in 0..di {
483 out_t[base + j] += av[j];
484 }
485 if cache.k > 0 {
486 let mut b_vbeta = Array1::<f64>::zeros(di);
487 if !cache.apply_htbeta_row(row, v_beta, &mut b_vbeta) {
488 return Err(format!(
489 "apply_cached_arrow_hessian: H_tβ^({row}) apply failed"
490 ));
491 }
492 for j in 0..di {
493 out_t[base + j] += b_vbeta[j];
494 }
495 if !cache.apply_htbeta_row_transpose(row, row_v, &mut out_beta, None) {
496 return Err(format!(
497 "apply_cached_arrow_hessian: H_βt^({row}) apply failed"
498 ));
499 }
500 }
501 }
502
503 if cache.k > 0 {
504 let Some(schur_factor) = cache.schur_factor.as_ref() else {
505 return Err(
506 "apply_cached_arrow_hessian: dense Schur factor is required for gauge probing"
507 .to_string(),
508 );
509 };
510 let schur_v = cholesky_factor_apply(schur_factor.view(), v_beta);
511 for i in 0..cache.k {
512 out_beta[i] += schur_v[i];
513 }
514 for row in 0..cache.n_rows() {
515 let di = cache.row_dims[row];
516 let mut b_vbeta = Array1::<f64>::zeros(di);
517 if !cache.apply_htbeta_row(row, v_beta, &mut b_vbeta) {
518 return Err(format!(
519 "apply_cached_arrow_hessian: H_tβ^({row}) Schur correction apply failed"
520 ));
521 }
522 let a_inv_b_vbeta = cholesky_solve_vector(cache.undamped_factor(row), b_vbeta.view());
523 if !cache.apply_htbeta_row_transpose(row, a_inv_b_vbeta.view(), &mut out_beta, None) {
524 return Err(format!(
525 "apply_cached_arrow_hessian: H_βt^({row}) Schur correction apply failed"
526 ));
527 }
528 }
529 }
530
531 if let Some(woodbury) = cache.cross_row_woodbury.as_ref() {
540 woodbury.apply_forward_t(v_t, &mut out_t);
541 }
542
543 Ok(SaeArrowVector {
544 t: out_t,
545 beta: out_beta,
546 })
547}
548
549pub(crate) fn cholesky_factor_apply(
550 factor: ArrayView2<'_, f64>,
551 vector: ArrayView1<'_, f64>,
552) -> Array1<f64> {
553 let n = factor.nrows();
554 let mut lt_v = Array1::<f64>::zeros(n);
555 for row in 0..n {
556 let mut acc = 0.0_f64;
557 for col in row..n {
558 acc += factor[[col, row]] * vector[col];
559 }
560 lt_v[row] = acc;
561 }
562 let mut out = Array1::<f64>::zeros(n);
563 for row in 0..n {
564 let mut acc = 0.0_f64;
565 for col in 0..=row {
566 acc += factor[[row, col]] * lt_v[col];
567 }
568 out[row] = acc;
569 }
570 out
571}
572
/// Tags one local variable: either the `Logit` variable of an atom, or one coordinate axis
/// (`Coord`) of an atom.
///
/// NOTE(review): nothing in this chunk constructs or matches on it. The meaning of the
/// `atom`/`axis` indices is defined by callers elsewhere; confirm there.
#[derive(Debug, Clone, Copy)]
pub(crate) enum SaeLocalRowVar {
    /// Logit variable of `atom`.
    Logit { atom: usize },
    /// Coordinate `axis` of `atom`.
    Coord { atom: usize, axis: usize },
}
578
/// Record pairing an atom and a basis column with an index and an output vector.
///
/// NOTE(review): unused in this chunk. "Border" presumably refers to the β border of the
/// arrow system, and `output` to per-observation values. Both are unverified; confirm the
/// field semantics against the code that builds these.
#[derive(Debug, Clone)]
pub(crate) struct SaeBorderChannel {
    /// Atom index.
    pub(crate) atom: usize,
    /// Basis column index (presumably within the atom's basis — TODO confirm).
    pub(crate) basis_col: usize,
    /// Channel index (presumably its position in the β block — TODO confirm).
    pub(crate) index: usize,
    /// Output values for this channel.
    pub(crate) output: Vec<f64>,
}
586
/// Per-row derivative data ("jets") keyed by the local variables in `vars`.
///
/// NOTE(review): unused in this chunk. The names suggest `first`/`second` are first and
/// second derivatives with respect to `vars`, and the `beta*` fields their β-border
/// counterparts. The indexing order of the nested vectors is not visible here; confirm it
/// against the builder before relying on it.
#[derive(Debug, Clone)]
pub(crate) struct SaeRowJets {
    /// Local variables these jets are taken with respect to.
    pub(crate) vars: Vec<SaeLocalRowVar>,
    /// Two-level nested values (presumably first derivatives).
    pub(crate) first: Vec<Vec<f64>>,
    /// Three-level nested values (presumably second derivatives).
    pub(crate) second: Vec<Vec<Vec<f64>>>,
    /// Two-level nested β-related values.
    pub(crate) beta: Vec<Vec<f64>>,
    /// Three-level nested β-related derivative values.
    pub(crate) beta_deriv: Vec<Vec<Vec<f64>>>,
    /// Three-level nested β-related derivative values (second family; meaning of `l` unknown
    /// here — TODO confirm).
    pub(crate) beta_l_deriv: Vec<Vec<Vec<f64>>>,
}
596
/// Dot product of two slices.
///
/// Pairs entries positionally; trailing entries of the longer slice are ignored.
pub(crate) fn sae_dot(a: &[f64], b: &[f64]) -> f64 {
    std::iter::zip(a, b).map(|(x, y)| x * y).sum()
}
600
601pub(crate) fn sae_inner(a: &SaeArrowVector, b: &SaeArrowVector) -> f64 {
604 sae_dot(a.t.as_slice().unwrap_or(&[]), b.t.as_slice().unwrap_or(&[]))
605 + sae_dot(
606 a.beta.as_slice().unwrap_or(&[]),
607 b.beta.as_slice().unwrap_or(&[]),
608 )
609}
610
611pub(crate) fn sae_norm(a: &SaeArrowVector) -> f64 {
613 sae_inner(a, a).max(0.0).sqrt()
614}
615
616pub(crate) fn solve_b_preconditioned_cg<F>(
627 solver: &DeflatedArrowSolver<'_>,
628 rhs: &SaeArrowVector,
629 apply_a: F,
630) -> Result<SaeArrowVector, String>
631where
632 F: Fn(&SaeArrowVector) -> Result<SaeArrowVector, String>,
633{
634 let mut x = solver
636 .solve(rhs.t.view(), rhs.beta.view())
637 .map_err(|err| format!("solve_b_preconditioned_cg: B inverse: {err}"))?;
638 let ax = apply_a(&x)?;
640 let mut r = SaeArrowVector {
641 t: &rhs.t - &ax.t,
642 beta: &rhs.beta - &ax.beta,
643 };
644 let mut z = solver
645 .solve(r.t.view(), r.beta.view())
646 .map_err(|err| format!("solve_b_preconditioned_cg: B preconditioner: {err}"))?;
647 let mut p = z.clone();
648 let mut rz = sae_inner(&r, &z);
649
650 let rhs_norm = sae_norm(rhs).max(1.0);
651 let max_iters = (x.t.len() + x.beta.len()).clamp(8, 256);
652 let rel_tol = 1.0e-10;
653 for _ in 0..max_iters {
654 if !rz.is_finite() || rz <= 0.0 {
655 break; }
657 let ap = apply_a(&p)?;
658 let p_ap = sae_inner(&p, &ap);
659 if !p_ap.is_finite() || p_ap <= 0.0 {
660 break; }
662 let alpha = rz / p_ap;
663 for idx in 0..x.t.len() {
664 x.t[idx] += alpha * p.t[idx];
665 r.t[idx] -= alpha * ap.t[idx];
666 }
667 for idx in 0..x.beta.len() {
668 x.beta[idx] += alpha * p.beta[idx];
669 r.beta[idx] -= alpha * ap.beta[idx];
670 }
671 if sae_norm(&r) <= rel_tol * rhs_norm {
672 break;
673 }
674 z = solver
675 .solve(r.t.view(), r.beta.view())
676 .map_err(|err| format!("solve_b_preconditioned_cg: B preconditioner: {err}"))?;
677 let rz_next = sae_inner(&r, &z);
678 let beta = rz_next / rz;
679 for idx in 0..p.t.len() {
680 p.t[idx] = z.t[idx] + beta * p.t[idx];
681 }
682 for idx in 0..p.beta.len() {
683 p.beta[idx] = z.beta[idx] + beta * p.beta[idx];
684 }
685 rz = rz_next;
686 }
687 Ok(x)
688}