Skip to main content

gam_sae/manifold/
arrow_solver.rs

1use super::*;
2
/// A vector over the arrow system, stored as its two blocks.
///
/// The routines in this file treat the pair as the concatenation `(t, β)`:
/// the `t` block first, then the `beta` block.
#[derive(Debug, Clone)]
pub struct SaeArrowVector {
    /// The `t` block; vectors returned by `DeflatedArrowSolver::solve` have
    /// `cache.delta_t_len()` entries here.
    pub t: Array1<f64>,
    /// The `β` block; `apply_cached_arrow_hessian` returns `cache.k` entries here.
    pub beta: Array1<f64>,
}
8
/// Solver for the arrow system held by an `ArrowFactorCache`, optionally
/// corrected by a low-rank gauge term applied through a Woodbury factor.
///
/// Built either by `DeflatedArrowSolver::plain` (no gauges) or by
/// `DeflatedArrowSolver::from_orthonormal_gauges`.
pub(crate) struct DeflatedArrowSolver<'a> {
    /// Borrowed factor cache; every solve starts from its `full_inverse_apply`.
    pub(crate) cache: &'a ArrowFactorCache,
    /// Gauge directions, each of length `cache.delta_t_len() + cache.k`.
    /// NOTE(review): the constructor name says these are orthonormal; that is
    /// not verified anywhere in this file.
    pub(crate) gauge_basis: Vec<Array1<f64>>,
    /// One vector per gauge: the cache's inverse applied to that gauge, minus
    /// its components along every vector in `gauge_basis`.
    pub(crate) gauge_response_physical: Vec<Array1<f64>>,
    /// Lower Cholesky factor of `(1/stiffness)·I + M`, where
    /// `M[i, j] = gauge_basis[i] · (cache inverse applied to gauge j)`.
    /// `None` for the plain (undeflated) solver.
    pub(crate) woodbury_factor: Option<FaerCholeskyFactor>,
    /// `1 / stiffness`; `0.0` for the plain solver.
    pub(crate) gauge_stiffness_recip: f64,
}
16
17impl<'a> DeflatedArrowSolver<'a> {
18    pub(crate) fn plain(cache: &'a ArrowFactorCache) -> Self {
19        Self {
20            cache,
21            gauge_basis: Vec::new(),
22            gauge_response_physical: Vec::new(),
23            woodbury_factor: None,
24            gauge_stiffness_recip: 0.0,
25        }
26    }
27
28    pub(crate) fn from_orthonormal_gauges(
29        cache: &'a ArrowFactorCache,
30        gauge_basis: Vec<Array1<f64>>,
31        stiffness: f64,
32    ) -> Result<Self, String> {
33        if gauge_basis.is_empty() {
34            return Ok(Self::plain(cache));
35        }
36        if !(stiffness.is_finite() && stiffness > 0.0) {
37            return Err(format!(
38                "DeflatedArrowSolver: gauge stiffness must be finite and positive; got {stiffness}"
39            ));
40        }
41        let full_len = cache.delta_t_len() + cache.k;
42        let mut gauge_responses = Vec::with_capacity(gauge_basis.len());
43        for gauge in &gauge_basis {
44            if gauge.len() != full_len {
45                return Err(format!(
46                    "DeflatedArrowSolver: gauge length {} != cache full length {full_len}",
47                    gauge.len()
48                ));
49            }
50            let (sol_t, sol_beta) = cache
51                .full_inverse_apply(
52                    gauge.slice(s![..cache.delta_t_len()]),
53                    gauge.slice(s![cache.delta_t_len()..]),
54                )
55                .map_err(|err| format!("DeflatedArrowSolver: gauge back-solve: {err}"))?;
56            gauge_responses.push(flatten_arrow_parts(sol_t.view(), sol_beta.view()));
57        }
58
59        let rank = gauge_basis.len();
60        let stiffness_recip = stiffness.recip();
61        let mut gauge_metric = Array2::<f64>::zeros((rank, rank));
62        let mut woodbury = Array2::<f64>::eye(rank);
63        for i in 0..rank {
64            woodbury[[i, i]] *= stiffness_recip;
65            for j in 0..rank {
66                let value = gauge_basis[i].dot(&gauge_responses[j]);
67                gauge_metric[[i, j]] = value;
68                woodbury[[i, j]] += value;
69            }
70        }
71        let woodbury_factor = woodbury
72            .cholesky(Side::Lower)
73            .map_err(|err| format!("DeflatedArrowSolver: gauge Woodbury factor failed: {err}"))?;
74        let mut gauge_response_physical = gauge_responses;
75        for j in 0..rank {
76            for i in 0..rank {
77                let coeff = gauge_metric[[i, j]];
78                for row in 0..full_len {
79                    gauge_response_physical[j][row] -= coeff * gauge_basis[i][row];
80                }
81            }
82        }
83        Ok(Self {
84            cache,
85            gauge_basis,
86            gauge_response_physical,
87            woodbury_factor: Some(woodbury_factor),
88            gauge_stiffness_recip: stiffness_recip,
89        })
90    }
91
92    pub(crate) fn solve(
93        &self,
94        rhs_t: ArrayView1<'_, f64>,
95        rhs_beta: ArrayView1<'_, f64>,
96    ) -> Result<SaeArrowVector, String> {
97        let (sol_t, sol_beta) = self
98            .cache
99            .full_inverse_apply(rhs_t, rhs_beta)
100            .map_err(|err| format!("DeflatedArrowSolver: full inverse: {err}"))?;
101        let Some(factor) = self.woodbury_factor.as_ref() else {
102            return Ok(SaeArrowVector {
103                t: sol_t,
104                beta: sol_beta,
105            });
106        };
107
108        let full_len = self.cache.delta_t_len() + self.cache.k;
109        let mut flat = flatten_arrow_parts(sol_t.view(), sol_beta.view());
110        if flat.len() != full_len {
111            return Err(format!(
112                "DeflatedArrowSolver: solution length {} != cache full length {full_len}",
113                flat.len()
114            ));
115        }
116        let mut gauge_coeffs = Array1::<f64>::zeros(self.gauge_basis.len());
117        for (idx, gauge) in self.gauge_basis.iter().enumerate() {
118            gauge_coeffs[idx] = gauge.dot(&flat);
119        }
120        let weights = factor.solvevec(&gauge_coeffs);
121        for (gauge, &coeff) in self.gauge_basis.iter().zip(gauge_coeffs.iter()) {
122            for i in 0..flat.len() {
123                flat[i] -= gauge[i] * coeff;
124            }
125        }
126        for (response, &weight) in self.gauge_response_physical.iter().zip(weights.iter()) {
127            for i in 0..flat.len() {
128                flat[i] -= response[i] * weight;
129            }
130        }
131        for (gauge, &weight) in self.gauge_basis.iter().zip(weights.iter()) {
132            let coeff = self.gauge_stiffness_recip * weight;
133            for i in 0..flat.len() {
134                flat[i] += gauge[i] * coeff;
135            }
136        }
137        Ok(SaeArrowVector {
138            t: flat.slice(s![..self.cache.delta_t_len()]).to_owned(),
139            beta: flat.slice(s![self.cache.delta_t_len()..]).to_owned(),
140        })
141    }
142
143    /// Per-row latent-block inverse diagonal with the UNIT-stiffness deflated
144    /// subspace REMOVED — the kept-subspace selected inverse the outer ρ/θ
145    /// gradient diagonal traces must contract against.
146    ///
147    /// [`Self::latent_inverse_diagonal`] returns the diagonal of the DEFLATED
148    /// inverse, which assigns `1/λ̃ = 1` to every per-row direction `vᵢ` that the
149    /// undamped evidence factor stiffened to unit curvature; a `½ tr(H⁻¹ ∂H/∂ρ)`
150    /// diagonal contraction against it therefore spuriously includes
151    /// `Σ_i vᵢ[s]²` at slot `s`, a ρ/θ-independent contribution that must be 0.
152    /// This variant subtracts the per-row deflated outer-product diagonal
153    /// `Σ_i vᵢ[s]²` so the diagonal traces (ARD precision, IBP/softmax assignment
154    /// log-strength) see only the kept subspace. The deflated subspace's β-Schur
155    /// coupling is higher order and left to the per-block subtraction the
156    /// off-diagonal (`solve`-based) traces apply directly.
157    pub(crate) fn latent_inverse_diagonal_kept(&self) -> Result<Array1<f64>, String> {
158        let mut out = self.latent_inverse_diagonal()?;
159        let cache = self.cache;
160        for (row, dirs) in cache.deflated_row_directions.iter().enumerate() {
161            if dirs.is_empty() {
162                continue;
163            }
164            let base = cache.row_offsets[row];
165            for v in dirs {
166                for s in 0..v.len() {
167                    if base + s < out.len() {
168                        out[base + s] -= v[s] * v[s];
169                    }
170                }
171            }
172        }
173        Ok(out)
174    }
175
176    /// #932 FRONT C — whether the cheap row-local Takahashi selected inverse
177    /// ([`Self::beta_inv`] / [`Self::selected_inverse_row_blocks`]) reproduces
178    /// `solve`'s selected entries EXACTLY. It does so only on the plain bordered
179    /// arrow: when a gauge Woodbury deflation is active (`woodbury_factor`) the
180    /// `solve` output carries the rank-`R` gauge correction the row-local blocks
181    /// omit, and when a #1038 cross-row IBP Woodbury is present the cache's
182    /// per-row factors are the NO-SELF base `H₀'` (not the full operator). In
183    /// either case callers MUST fall back to the per-row `solve` loop — the
184    /// row-local blocks are NOT valid there.
185    pub(crate) fn plain_selected_inverse_available(&self) -> bool {
186        self.woodbury_factor.is_none() && self.cache.cross_row_woodbury.is_none()
187    }
188
189    /// #932 FRONT C — the full `(H⁻¹)_ββ = S⁻¹` block (`K×K`), formed ONCE per
190    /// outer step from the cached dense Schur factor (no per-column full-system
191    /// `solve`). On the plain arrow this equals the `beta_inv` the logdet /
192    /// α-trace consumers used to build with `K` calls to [`Self::solve`] with
193    /// unit β-RHS. ONLY valid when [`Self::plain_selected_inverse_available`].
194    pub(crate) fn beta_inv(&self) -> Result<Array2<f64>, String> {
195        let k = self.cache.k;
196        if k == 0 {
197            return Ok(Array2::<f64>::zeros((0, 0)));
198        }
199        self.cache
200            .schur_inverse_block(0..k)
201            .map_err(|err| format!("DeflatedArrowSolver::beta_inv: {err}"))
202    }
203
204    /// #932 FRONT C — row-local Takahashi selected inverse of the PLAIN bordered
205    /// arrow: returns this row's own `(H⁻¹)_tt` block (`q×q`) and its `(H⁻¹)_tβ`
206    /// block (`q×K`) WITHOUT the O(n) full-system sweep that one
207    /// [`Self::solve`] per unit RHS performs. Mirrors
208    /// `ArrowFactorCache::latent_block_inverse_diagonal` (system.rs) but returns
209    /// the full blocks rather than only the diagonal. With `A_i =
210    /// undamped_factor(i)`, `B_i = H_tβ^(i)`, `G_i = A_i⁻¹ B_i`, `S⁻¹ = beta_inv`:
211    ///
212    /// ```text
213    ///   (H⁻¹)_tt[i,i] = A_i⁻¹ + G_i S⁻¹ G_iᵀ
214    ///   (H⁻¹)_tβ[i]   = −G_i S⁻¹
215    /// ```
216    ///
217    /// Touches ONLY row `i`'s own factor, its `H_tβ^(i)` coupling, and the shared
218    /// `S⁻¹` — O(q·(q+K)) per row, no `n`-sweep. ONLY valid when
219    /// [`Self::plain_selected_inverse_available`]; pass the `S⁻¹` from
220    /// [`Self::beta_inv`].
221    pub(crate) fn selected_inverse_row_blocks(
222        &self,
223        row: usize,
224        beta_inv: &Array2<f64>,
225    ) -> Result<(Array2<f64>, Array2<f64>), String> {
226        let cache = self.cache;
227        let q = cache.row_dims[row];
228        let k = cache.k;
229        let factor = cache.undamped_factor(row);
230
231        // A_i⁻¹ (q×q): solve A_i x = e_j per column.
232        let mut a_inv = Array2::<f64>::zeros((q, q));
233        let mut e_j = Array1::<f64>::zeros(q);
234        for j in 0..q {
235            e_j.fill(0.0);
236            e_j[j] = 1.0;
237            let col = cholesky_solve_vector(factor, e_j.view());
238            for r in 0..q {
239                a_inv[[r, j]] = col[r];
240            }
241        }
242
243        if k == 0 {
244            return Ok((a_inv, Array2::<f64>::zeros((q, 0))));
245        }
246
247        // G_i = A_i⁻¹ B_i (q×K): column c is A_i⁻¹ (B_i e_c), where B_i e_c is the
248        // c-th column of H_tβ^(i) recovered via `apply_htbeta_row`.
249        let mut g = Array2::<f64>::zeros((q, k));
250        let mut e_c = Array1::<f64>::zeros(k);
251        let mut b_col = Array1::<f64>::zeros(q);
252        for c in 0..k {
253            e_c.fill(0.0);
254            e_c[c] = 1.0;
255            b_col.fill(0.0);
256            if !cache.apply_htbeta_row(row, e_c.view(), &mut b_col) {
257                return Err(format!(
258                    "DeflatedArrowSolver::selected_inverse_row_blocks: H_tβ^({row}) apply failed"
259                ));
260            }
261            let g_col = cholesky_solve_vector(factor, b_col.view());
262            for r in 0..q {
263                g[[r, c]] = g_col[r];
264            }
265        }
266
267        // GS = G_i S⁻¹ (q×K).
268        let mut gs = Array2::<f64>::zeros((q, k));
269        for r in 0..q {
270            for m in 0..k {
271                let mut acc = 0.0_f64;
272                for n in 0..k {
273                    acc += g[[r, n]] * beta_inv[[n, m]];
274                }
275                gs[[r, m]] = acc;
276            }
277        }
278
279        // (H⁻¹)_tβ[i] = −G_i S⁻¹ = −GS, layout [col, b].
280        let mut inv_vbeta = Array2::<f64>::zeros((q, k));
281        for col in 0..q {
282            for b in 0..k {
283                inv_vbeta[[col, b]] = -gs[[col, b]];
284            }
285        }
286
287        // (H⁻¹)_tt[i,i] = A_i⁻¹ + G_i S⁻¹ G_iᵀ = A_i⁻¹ + GS·Gᵀ, layout [r, col].
288        let mut inv_vv = a_inv;
289        for r in 0..q {
290            for col in 0..q {
291                let mut acc = 0.0_f64;
292                for m in 0..k {
293                    acc += gs[[r, m]] * g[[col, m]];
294                }
295                inv_vv[[r, col]] += acc;
296            }
297        }
298
299        Ok((inv_vv, inv_vbeta))
300    }
301
302    pub(crate) fn latent_inverse_diagonal(&self) -> Result<Array1<f64>, String> {
303        if self.woodbury_factor.is_none() {
304            return self
305                .cache
306                .latent_block_inverse_diagonal()
307                .map_err(|err| format!("DeflatedArrowSolver: latent inverse diagonal: {err}"));
308        }
309        let total_t = self.cache.delta_t_len();
310        let mut out = Array1::<f64>::zeros(total_t);
311        let rhs_beta = Array1::<f64>::zeros(self.cache.k);
312        for idx in 0..total_t {
313            let mut rhs_t = Array1::<f64>::zeros(total_t);
314            rhs_t[idx] = 1.0;
315            let solved = self.solve(rhs_t.view(), rhs_beta.view())?;
316            out[idx] = solved.t[idx];
317        }
318        Ok(out)
319    }
320}
321
#[cfg(test)]
mod selected_inverse_row_blocks_oracle_tests {
    //! #932 FRONT C oracle. On the plain bordered arrow, the row-local
    //! Takahashi selected-inverse blocks
    //! ([`DeflatedArrowSolver::selected_inverse_row_blocks`] / [`beta_inv`])
    //! MUST agree, to ≤1e-9, with the per-row full-system `solve` loop they
    //! replace. The logdet / α-trace consumers depend on this gate whenever
    //! they take the fast path.
    use super::*;
    use gam_solve::arrow_schur::{
        ArrowFactorSlab, ArrowHtbetaCache, ArrowSolverMode, ArrowUndampedFactors, PcgDiagnostics,
    };
    use ndarray::array;
    use std::sync::Arc;

    /// Length-`len` vector that is 1 at `hot` and 0 elsewhere.
    fn unit_vector(len: usize, hot: usize) -> Array1<f64> {
        let mut e = Array1::<f64>::zeros(len);
        e[hot] = 1.0;
        e
    }

    /// Plain bordered-arrow cache whose `H_tβ` coupling is NONZERO and whose
    /// dense Schur factor is PD, so that the β-Schur back-substitution really
    /// exercises the `G S⁻¹ Gᵀ` / `−G S⁻¹` terms and not only the block-diagonal
    /// `A⁻¹`. Stored factors are lower-Cholesky factors `L` (the block they
    /// represent is `L Lᵀ`); the row-local identity holds for any PD `A`/`S`
    /// and any `B`.
    fn coupled_arrow_cache() -> ArrowFactorCache {
        let row_factors = vec![array![[1.3_f64, 0.0], [0.4, 1.1]], array![[0.9_f64]]];
        let couplings = vec![
            array![[0.5_f64, -0.2], [0.1, 0.4]],
            array![[0.3_f64, 0.7]],
        ];
        let schur = array![[1.2_f64, 0.0], [0.25, 0.95]];
        ArrowFactorCache {
            htt_factors: ArrowFactorSlab::from_blocks(row_factors),
            htt_factors_undamped: ArrowUndampedFactors::SameAsDamped,
            schur_factor: Some(schur),
            joint_hessian_log_det: None,
            solver_mode: ArrowSolverMode::Direct,
            ridge_t: 0.0,
            ridge_beta: 0.0,
            htbeta: ArrowHtbetaCache::Dense {
                blocks: Arc::from(couplings.into_boxed_slice()),
                estimated_bytes: 0,
            },
            d: 2,
            row_dims: Arc::from(vec![2usize, 1usize].into_boxed_slice()),
            row_offsets: Arc::from(vec![0usize, 2usize, 3usize].into_boxed_slice()),
            k: 2,
            manifold_mode_fingerprint: 0,
            row_hessian_fingerprint: 0,
            pcg_diagnostics: PcgDiagnostics::default(),
            gauge_deflated_directions: 0,
            deflated_row_directions: Arc::from(Vec::new()),
            deflation_row_spectra: Arc::from(Vec::new()),
            cross_row_woodbury: None,
        }
    }

    #[test]
    fn row_local_blocks_match_per_row_solve() {
        let cache = coupled_arrow_cache();
        let solver = DeflatedArrowSolver::plain(&cache);
        assert!(
            solver.plain_selected_inverse_available(),
            "plain cache must take the fast selected-inverse path"
        );
        let total_t = cache.delta_t_len();
        let k = cache.k;
        let zero_t = Array1::<f64>::zeros(total_t);
        let zero_beta = Array1::<f64>::zeros(k);

        // β-block `(H⁻¹)_ββ = S⁻¹`: beta_inv() against one unit-β solve per column.
        let beta_inv = solver.beta_inv().expect("beta_inv");
        for col in 0..k {
            let rhs_beta = unit_vector(k, col);
            let solved = solver
                .solve(zero_t.view(), rhs_beta.view())
                .expect("β solve");
            for r in 0..k {
                assert!(
                    (beta_inv[[r, col]] - solved.beta[r]).abs() <= 1e-9,
                    "beta_inv[{r},{col}] {} != solve {}",
                    beta_inv[[r, col]],
                    solved.beta[r]
                );
            }
        }

        // Per-row `(H⁻¹)_tt` (q×q) and `(H⁻¹)_tβ` (q×K) blocks against one
        // unit-t solve per latent slot of the row.
        for row in 0..cache.n_rows() {
            let (q, base) = (cache.row_dims[row], cache.row_offsets[row]);
            let (inv_vv, inv_vbeta) = solver
                .selected_inverse_row_blocks(row, &beta_inv)
                .expect("row blocks");
            for col in 0..q {
                let rhs_t = unit_vector(total_t, base + col);
                let solved = solver
                    .solve(rhs_t.view(), zero_beta.view())
                    .expect("t solve");
                for r in 0..q {
                    assert!(
                        (inv_vv[[r, col]] - solved.t[base + r]).abs() <= 1e-9,
                        "inv_vv[{r},{col}] {} != solve {}",
                        inv_vv[[r, col]],
                        solved.t[base + r]
                    );
                }
                for b in 0..k {
                    assert!(
                        (inv_vbeta[[col, b]] - solved.beta[b]).abs() <= 1e-9,
                        "inv_vbeta[{col},{b}] {} != solve {}",
                        inv_vbeta[[col, b]],
                        solved.beta[b]
                    );
                }
            }
        }
    }
}
443
444pub(crate) fn flatten_arrow_parts(
445    t: ArrayView1<'_, f64>,
446    beta: ArrayView1<'_, f64>,
447) -> Array1<f64> {
448    let mut out = Array1::<f64>::zeros(t.len() + beta.len());
449    for i in 0..t.len() {
450        out[i] = t[i];
451    }
452    for i in 0..beta.len() {
453        out[t.len() + i] = beta[i];
454    }
455    out
456}
457
458pub(crate) fn apply_cached_arrow_hessian(
459    cache: &ArrowFactorCache,
460    v_t: ArrayView1<'_, f64>,
461    v_beta: ArrayView1<'_, f64>,
462) -> Result<SaeArrowVector, String> {
463    let total_t = cache.delta_t_len();
464    if v_t.len() != total_t || v_beta.len() != cache.k {
465        return Err(format!(
466            "apply_cached_arrow_hessian: vector shapes (t={}, beta={}) != cache shapes \
467             (t={total_t}, beta={})",
468            v_t.len(),
469            v_beta.len(),
470            cache.k
471        ));
472    }
473
474    let mut out_t = Array1::<f64>::zeros(total_t);
475    let mut out_beta = Array1::<f64>::zeros(cache.k);
476    for row in 0..cache.n_rows() {
477        let di = cache.row_dims[row];
478        let base = cache.row_offsets[row];
479        let row_v = v_t.slice(s![base..base + di]);
480        let factor = cache.undamped_factor(row);
481        let av = cholesky_factor_apply(factor, row_v);
482        for j in 0..di {
483            out_t[base + j] += av[j];
484        }
485        if cache.k > 0 {
486            let mut b_vbeta = Array1::<f64>::zeros(di);
487            if !cache.apply_htbeta_row(row, v_beta, &mut b_vbeta) {
488                return Err(format!(
489                    "apply_cached_arrow_hessian: H_tβ^({row}) apply failed"
490                ));
491            }
492            for j in 0..di {
493                out_t[base + j] += b_vbeta[j];
494            }
495            if !cache.apply_htbeta_row_transpose(row, row_v, &mut out_beta, None) {
496                return Err(format!(
497                    "apply_cached_arrow_hessian: H_βt^({row}) apply failed"
498                ));
499            }
500        }
501    }
502
503    if cache.k > 0 {
504        let Some(schur_factor) = cache.schur_factor.as_ref() else {
505            return Err(
506                "apply_cached_arrow_hessian: dense Schur factor is required for gauge probing"
507                    .to_string(),
508            );
509        };
510        let schur_v = cholesky_factor_apply(schur_factor.view(), v_beta);
511        for i in 0..cache.k {
512            out_beta[i] += schur_v[i];
513        }
514        for row in 0..cache.n_rows() {
515            let di = cache.row_dims[row];
516            let mut b_vbeta = Array1::<f64>::zeros(di);
517            if !cache.apply_htbeta_row(row, v_beta, &mut b_vbeta) {
518                return Err(format!(
519                    "apply_cached_arrow_hessian: H_tβ^({row}) Schur correction apply failed"
520                ));
521            }
522            let a_inv_b_vbeta = cholesky_solve_vector(cache.undamped_factor(row), b_vbeta.view());
523            if !cache.apply_htbeta_row_transpose(row, a_inv_b_vbeta.view(), &mut out_beta, None) {
524                return Err(format!(
525                    "apply_cached_arrow_hessian: H_βt^({row}) Schur correction apply failed"
526                ));
527            }
528        }
529    }
530
531    // #1038 IBP cross-row curvature: when the cache carries the exact rank-`R`
532    // Woodbury, the operator it represents is `H_full = H₀' + U D Uᵀ` (the same
533    // operator `full_inverse_apply` inverts and `arrow_log_det` reports). The
534    // per-row factors reconstructed above are only the NO-SELF base `H₀'`, so the
535    // forward apply MUST add `U D Uᵀ v` here — otherwise the forward operator
536    // (used by the #1418 exact-stationarity solve) silently drops the cross-row
537    // block while its CG preconditioner inverts the full `H_full`, desyncing the
538    // outer-REML gradient. `U` has no `β` support ⇒ only the `t` block changes.
539    if let Some(woodbury) = cache.cross_row_woodbury.as_ref() {
540        woodbury.apply_forward_t(v_t, &mut out_t);
541    }
542
543    Ok(SaeArrowVector {
544        t: out_t,
545        beta: out_beta,
546    })
547}
548
549pub(crate) fn cholesky_factor_apply(
550    factor: ArrayView2<'_, f64>,
551    vector: ArrayView1<'_, f64>,
552) -> Array1<f64> {
553    let n = factor.nrows();
554    let mut lt_v = Array1::<f64>::zeros(n);
555    for row in 0..n {
556        let mut acc = 0.0_f64;
557        for col in row..n {
558            acc += factor[[col, row]] * vector[col];
559        }
560        lt_v[row] = acc;
561    }
562    let mut out = Array1::<f64>::zeros(n);
563    for row in 0..n {
564        let mut acc = 0.0_f64;
565        for col in 0..=row {
566            acc += factor[[row, col]] * lt_v[col];
567        }
568        out[row] = acc;
569    }
570    out
571}
572
/// Identifies one local variable of a row: either a logit or a coordinate.
///
/// NOTE(review): the meaning of `atom` / `axis` is not established in this part
/// of the file; the descriptions below are read off the names only.
#[derive(Debug, Clone, Copy)]
pub(crate) enum SaeLocalRowVar {
    /// A logit variable, keyed by `atom` (presumably an atom index — confirm).
    Logit { atom: usize },
    /// A coordinate variable, keyed by `atom` and `axis` (presumably the
    /// coordinate axis within that atom — confirm).
    Coord { atom: usize, axis: usize },
}
578
/// One border channel record.
///
/// NOTE(review): nothing in this part of the file reads or writes these
/// fields, so their semantics are inferred from names only — confirm against
/// the code that constructs this type.
#[derive(Debug, Clone)]
pub(crate) struct SaeBorderChannel {
    /// Presumably the atom this channel belongs to — confirm.
    pub(crate) atom: usize,
    /// Presumably a basis column index — confirm.
    pub(crate) basis_col: usize,
    /// Presumably the channel's position in some flat ordering — confirm.
    pub(crate) index: usize,
    /// Per-channel output values; length and meaning not visible here.
    pub(crate) output: Vec<f64>,
}
586
/// Per-row derivative data ("jets") over a list of local row variables.
///
/// NOTE(review): the index conventions of the nested vectors are not visible
/// in this part of the file; the notes below are inferred from names and
/// nesting depth only — confirm against the producer.
#[derive(Debug, Clone)]
pub(crate) struct SaeRowJets {
    /// The local variables this row's jets are taken with respect to.
    pub(crate) vars: Vec<SaeLocalRowVar>,
    /// Presumably first derivatives, one inner vector per variable — confirm.
    pub(crate) first: Vec<Vec<f64>>,
    /// Presumably second derivatives, indexed by a pair of variables — confirm.
    pub(crate) second: Vec<Vec<Vec<f64>>>,
    /// Presumably β-related values per border channel — confirm.
    pub(crate) beta: Vec<Vec<f64>>,
    /// Presumably derivatives of the `beta` values — confirm.
    pub(crate) beta_deriv: Vec<Vec<Vec<f64>>>,
    /// Presumably a second family of `beta` derivatives ("l") — confirm.
    pub(crate) beta_l_deriv: Vec<Vec<Vec<f64>>>,
}
596
/// Dot product of two slices. When the lengths differ, only the common prefix
/// contributes.
pub(crate) fn sae_dot(a: &[f64], b: &[f64]) -> f64 {
    let shared = a.len().min(b.len());
    let (lhs, rhs) = (&a[..shared], &b[..shared]);
    lhs.iter().zip(rhs).map(|(x, y)| x * y).sum()
}
600
601/// Euclidean inner product `⟨a, b⟩` over the concatenated `(t, β)` blocks of two
602/// arrow vectors. Used by the #1418 `B`-preconditioned CG inner solve.
603pub(crate) fn sae_inner(a: &SaeArrowVector, b: &SaeArrowVector) -> f64 {
604    sae_dot(a.t.as_slice().unwrap_or(&[]), b.t.as_slice().unwrap_or(&[]))
605        + sae_dot(
606            a.beta.as_slice().unwrap_or(&[]),
607            b.beta.as_slice().unwrap_or(&[]),
608        )
609}
610
611/// Euclidean norm `‖a‖` over the concatenated `(t, β)` blocks of an arrow vector.
612pub(crate) fn sae_norm(a: &SaeArrowVector) -> f64 {
613    sae_inner(a, a).max(0.0).sqrt()
614}
615
616/// #1418: solve `A x = rhs` by **`B`-preconditioned conjugate gradients**, where
617/// `apply_a(v) = A v` is the exact stationarity-Jacobian matvec and the
618/// `solver` (the assembled `B` factorization) supplies the SPD preconditioner
619/// `B⁻¹`. The IFT step `θ̂_ρ = −A⁻¹ g_ρ` must invert the EXACT `A`, not the
620/// surrogate `B`; the earlier truncated Neumann series `Σ_m (−B⁻¹ΔC)^m B⁻¹ rhs`
621/// equals `A⁻¹ rhs` only when `ρ(B⁻¹ΔC) < 1`, and DIVERGED for large
622/// `ΔC = ⟨r, ∂²f⟩`. PCG converges for any spectral radius in ≤ `dim` steps — one
623/// `A` matvec and one `B⁻¹` solve per step, no second factorization. On
624/// non-positive curvature `pᵀ A p ≤ 0` (the high-residual `A` can be indefinite
625/// away from a strict minimum) it stops at the last finite iterate.
626pub(crate) fn solve_b_preconditioned_cg<F>(
627    solver: &DeflatedArrowSolver<'_>,
628    rhs: &SaeArrowVector,
629    apply_a: F,
630) -> Result<SaeArrowVector, String>
631where
632    F: Fn(&SaeArrowVector) -> Result<SaeArrowVector, String>,
633{
634    // x_0 = B⁻¹ rhs (the surrogate step; CG corrects it onto A⁻¹ rhs).
635    let mut x = solver
636        .solve(rhs.t.view(), rhs.beta.view())
637        .map_err(|err| format!("solve_b_preconditioned_cg: B inverse: {err}"))?;
638    // r_0 = rhs − A x_0; z_0 = B⁻¹ r_0; p_0 = z_0.
639    let ax = apply_a(&x)?;
640    let mut r = SaeArrowVector {
641        t: &rhs.t - &ax.t,
642        beta: &rhs.beta - &ax.beta,
643    };
644    let mut z = solver
645        .solve(r.t.view(), r.beta.view())
646        .map_err(|err| format!("solve_b_preconditioned_cg: B preconditioner: {err}"))?;
647    let mut p = z.clone();
648    let mut rz = sae_inner(&r, &z);
649
650    let rhs_norm = sae_norm(rhs).max(1.0);
651    let max_iters = (x.t.len() + x.beta.len()).clamp(8, 256);
652    let rel_tol = 1.0e-10;
653    for _ in 0..max_iters {
654        if !rz.is_finite() || rz <= 0.0 {
655            break; // preconditioned residual exhausted / degenerate.
656        }
657        let ap = apply_a(&p)?;
658        let p_ap = sae_inner(&p, &ap);
659        if !p_ap.is_finite() || p_ap <= 0.0 {
660            break; // non-positive curvature: keep the finite iterate.
661        }
662        let alpha = rz / p_ap;
663        for idx in 0..x.t.len() {
664            x.t[idx] += alpha * p.t[idx];
665            r.t[idx] -= alpha * ap.t[idx];
666        }
667        for idx in 0..x.beta.len() {
668            x.beta[idx] += alpha * p.beta[idx];
669            r.beta[idx] -= alpha * ap.beta[idx];
670        }
671        if sae_norm(&r) <= rel_tol * rhs_norm {
672            break;
673        }
674        z = solver
675            .solve(r.t.view(), r.beta.view())
676            .map_err(|err| format!("solve_b_preconditioned_cg: B preconditioner: {err}"))?;
677        let rz_next = sae_inner(&r, &z);
678        let beta = rz_next / rz;
679        for idx in 0..p.t.len() {
680            p.t[idx] = z.t[idx] + beta * p.t[idx];
681        }
682        for idx in 0..p.beta.len() {
683            p.beta[idx] = z.beta[idx] + beta * p.beta[idx];
684        }
685        rz = rz_next;
686    }
687    Ok(x)
688}