Skip to main content

gam_terms/analytic_penalties/
isometry.rs

1use super::*;
2pub use gam_problem::WeightField;
3
4// ---------------------------------------------------------------------------
5// Isometry penalty
6// ---------------------------------------------------------------------------
7
8/// Choice of reference Riemannian metric `g^ref(t)` on the latent manifold.
9///
10/// `Euclidean` is the natural default: the reference metric is `I_d`, so the
11/// penalty pulls the decoder toward locally-isometric (length-preserving)
12/// behavior. `UserSupplied` lets the caller hand in a `(n_obs, d, d)` jet of
13/// per-row reference metrics (useful for warm-starting from a chart of a
14/// pre-fit GP-LVM).
15#[derive(Clone)]
16pub enum IsometryReference {
17    Euclidean,
18    UserSupplied(Arc<Array2<f64>>), // (n_obs, d*d) row-major flattened
19}
20
21impl std::fmt::Debug for IsometryReference {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            IsometryReference::Euclidean => f.write_str("Euclidean"),
25            IsometryReference::UserSupplied(a) => f
26                .debug_tuple("UserSupplied")
27                .field(&format_args!("{}×{}", a.nrows(), a.ncols()))
28                .finish(),
29        }
30    }
31}
32
/// Radial Duchon decoder metadata used to materialize
/// `∂J_n[i, a] / ∂t_{n, c}` from `φ'(r)` and `φ''(r)` on demand. The same
/// metadata also drives the third-order jet (`∂²J/∂t∂t`) consumed by the
/// Hessian-vector product when no third-derivative cache is installed.
///
/// `radial_coefficients[k, i]` is the decoder coefficient that maps radial
/// basis column `k` into output channel `i`. Polynomial-tail columns are not
/// represented here. Callers whose decoder contains a non-linear polynomial
/// tail should provide `jacobian_second_cache` directly, and for HVPs also
/// the third-decoder-derivative cache.
#[derive(Debug, Clone)]
pub struct IsometryDuchonRadialSource {
    /// Radial basis centers, one per row. The column count must equal the
    /// latent dimension `d` (asserted by the derivative adapters).
    pub centers: Arc<Array2<f64>>,
    /// Decoder coefficients of shape `(n_centers, p_out)`: one row per center
    /// and one column per output channel (asserted by the derivative
    /// adapters).
    pub radial_coefficients: Arc<Array2<f64>>,
    /// Optional kernel length scale, forwarded verbatim to
    /// `radial_basis_cartesian_derivative`.
    pub length_scale: Option<f64>,
    /// Duchon null-space order, forwarded verbatim to
    /// `radial_basis_cartesian_derivative`.
    pub nullspace_order: DuchonNullspaceOrder,
    /// Forward hybrid spectral order `s = spec.power`. The Cartesian
    /// derivative engine must resolve the same `(p, s, κ)` the forward
    /// `build_duchon_basis` used, so it differentiates the exact resolved
    /// hybrid Green's function `φ_{p,s,κ}` rather than a hard-coded `s = 0`
    /// surrogate (issue #440).
    pub power: usize,
}
53
/// Isometry-to-reference penalty (canonical-coordinate gauge term).
///
/// Lives on ext-coords: the target slice is a row of the `LatentCoordValues` flat
/// vector (row-major `n_obs × d`). Owns one ρ-axis (`log μ_iso`).
///
/// Penalizes `½ μ Σ_n ‖g_n(t) − g^ref(t_n)‖²_F`, where the pullback metric
/// at row `n` is
///
/// ```text
///   g_n = J_n^T W_n J_n,    J_n ∈ ℝ^{p × d}
/// ```
///
/// and `W_n` is a per-row low-rank PSD behavioral metric stored as
/// `W_n = U_n U_n^T` with `U_n ∈ ℝ^{p × r}`. The canonical-coordinate
/// statement is "one unit of motion in `t` ↦ one unit of behavioral change",
/// so the `W_n` weighting is load-bearing.
///
/// NOTE(review): the HVP path works with a normalized metric state.
/// `IsometryMetricState` carries `normalizer` / `trace_denominator`, and its
/// residual derivative is `δg / N − g · δN / N²`. This suggests the residual
/// is taken on `g_n / N` rather than on the raw `g_n` above. Confirm the
/// exact objective against `normalized_metric_state`.
///
/// In the SAE objective this is the extension-coordinate gauge fix: it prevents
/// the latent chart from absorbing arbitrary smooth reparameterizations of the
/// decoder manifold. ARD, sparsity, or rank penalties can then select axes or
/// structure in a chart whose metric scale is pinned.
///
/// **Contraction order invariant.** Every place this struct touches `W_n`,
/// the contraction is `(J^T U_n)(U_n^T J)` — never `J^T W_n J` with `W_n`
/// materialized as `p × p`. Concretely we form `M_n = U_n^T J_n ∈ ℝ^{r × d}`
/// once and then `g_n = M_n^T M_n` (`d × d`). Cost per row:
/// `O(p · r · d + r · d²)`, independent of `p²`.
///
/// **When to use.** Use it whenever a `LatentCoord` block is in play without
/// an auxiliary variable (`AuxPrior`) to break the diffeomorphism gauge. This
/// fixes the audit finding that ARD is not a standalone gauge fix. With a
/// Euclidean reference, the penalty pulls the decoder toward a local isometry,
/// which is enough to make the inner Hessian on `t` full-rank and the IFT
/// well-defined.
///
/// **Math.** Let `J_n ∈ ℝ^{p × d}` be the local decoder Jacobian. Then
/// `g_n = J_n^T W_n J_n` and the penalty is
/// `½ μ Σ_n ‖J_n^T W_n J_n − g^ref_n‖²_F`. Analytic gradient w.r.t. `t_n`:
///
/// ```text
///   ∂P/∂t_{n,c}
///     = μ Σ_{a,b} (g_n − g^ref_n)_{ab}
///         [ H_{n,:,a,c}^T W_n J_{n,:,b}
///           + J_{n,:,a}^T W_n H_{n,:,b,c} ],
///   H_{n,i,a,c} = ∂J_{n,i,a}/∂t_{n,c}.
/// ```
///
/// Gotchas:
///
/// * The value path returns the configured missing-cache default when the
///   first-jet cache is absent.
/// * Gradient paths need the first and second decoder jets. The HVP path
///   additionally needs the third jet.
/// * The second and third jets can each come from a cache or be rebuilt from
///   `duchon_radial_source`. When neither is available, the call logs a
///   warning and returns zeros.
/// * The exact Hessian includes a residual-curvature term requiring the third
///   decoder jet. REML/PIRLS curvature should prefer the Gauss-Newton PSD
///   majorizer when a positive curvature block is required.
/// * `W_n` is a metric weight, not a scalar confidence. Changing it changes the
///   canonical units of latent motion.
///
/// The per-row Jacobian `J_n` is exactly the radial-derivative jet
/// `design_gradient_wrt_t` already computes for `LatentCoordValues`. The
/// second derivative `∂J/∂t` is built by the shared
/// [`crate::basis::radial_basis_cartesian_derivative`] engine from the
/// radial Hessian identity. To check the gradient formula above by finite
/// differences, central-difference `value(t ± h e_j)` and compare against
/// `grad_target(t)[j]`. The analytic value tracks the finite difference
/// until cancellation dominates. No autograd needed.
///
/// `μ` is resolved from `ρ_iso = rho[rho_index]` together with
/// `scalar_weight` by `resolve_learnable_weight` (nominally `μ = exp(ρ_iso)`).
/// It is REML-selectable as one extra ρ axis.
///
/// `jacobian_cache_slot` and `jacobian_second_cache_slot` are interior-mutable
/// (`RwLock<Option<Arc<…>>>`). This lets the SAE outer loop refresh them in
/// place each step without needing `&mut self` on the registry-held penalty
/// (see `refresh_caches` and
/// [`crate::terms::sae::manifold::refresh_isometry_caches_from_atom`]).
/// Readers go through the [`Self::jacobian_cache`] / [`Self::jacobian_second_cache`]
/// accessors, which take the read lock briefly and clone the inner `Arc`
/// (refcount bump — no payload copy). Writers go through [`Self::refresh_caches`].
#[derive(Debug)]
pub struct IsometryPenalty {
    pub target: PsiSlice,
    pub reference: IsometryReference,
    /// Index of this penalty's strength `log μ_iso` inside the *local* rho
    /// view this penalty receives. Always `0` for now (single owned axis).
    pub rho_index: usize,
    /// Cached Jacobian `J ∈ ℝ^{n_obs × p × d}`, flattened row-major
    /// `(n_obs, p*d)`. The owning driver refreshes this each IFT outer step
    /// before invoking `value` / `grad_target`. In operator-only call sites
    /// (Hessian-vector products) the cache must be live. There is no Duchon
    /// fallback for `J` itself. Access through
    /// [`Self::jacobian_cache`] / [`Self::set_jacobian_cache`].
    pub jacobian_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
    /// Optional cached per-row Jacobian *second derivative*
    /// `H_n ∈ ℝ^{p × d × d}`, flattened row-major as `(n_obs, p*d*d)`.
    /// `H_n[i, a, c] = ∂J_n[i, a] / ∂t_{n, c}`. Either this cache or
    /// `duchon_radial_source` must be present for exact isometry
    /// gradient/HVP calls; when both are present the cache wins. Access
    /// through [`Self::jacobian_second_cache`] /
    /// [`Self::set_jacobian_second_cache`].
    pub jacobian_second_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
    /// Optional radial-Duchon source used to build `jacobian_second_cache`
    /// analytically from `φ'(r)` and the public `φ''(r)` jet helper. This is
    /// the exact chain-rule path for callers that do not pre-cache `∂J/∂t`.
    /// It also supplies the third-order jet when no third-derivative cache is
    /// installed.
    pub duchon_radial_source: Option<Arc<IsometryDuchonRadialSource>>,
    /// Optional cached per-row Jacobian *third derivative*
    /// `K_n ∈ ℝ^{p × d × d × d}` with
    /// `K_n[i, a, c, e] = ∂²J_n[i, a] / ∂t_{n, c} ∂t_{n, e}`.
    ///
    /// Stored as an `Array3` with shape `(n_obs, p, d * d * d)`. The last axis
    /// packs `(a, c, e)` in row-major order `(a * d + c) * d + e`.
    ///
    /// `hvp` uses the full residual-curvature Hessian (proposal §4(b)):
    ///   B_{ab,ce} = K_{a,ce}^T W J_b + H_{a,c}^T W H_{b,e}
    ///             + H_{a,e}^T W H_{b,c} + J_a^T W K_{b,ce}.
    ///
    /// Either this cache or `duchon_radial_source` must be present for
    /// analytic `hvp` calls; when both are present the cache wins.
    ///
    /// Interior-mutable (mirrors `jacobian_second_cache_slot`) so the SAE
    /// outer loop can refresh `K` in place each step. Access through
    /// [`Self::third_decoder_derivative`] /
    /// [`Self::set_third_decoder_derivative`].
    pub third_decoder_derivative_slot: RwLock<Option<Arc<ndarray::Array3<f64>>>>,
    /// Output dimensionality `p`: the number of rows of each per-row Jacobian
    /// `J_n ∈ ℝ^{p × d}`. Each flattened cache row holds `p * d` entries.
    pub p_out: usize,
    /// Per-row behavioral metric in low-rank factored form. Defaults to
    /// `Identity` (the unweighted `J^T J` pullback). When `Factored`, all
    /// `g_n` contractions are done via `M_n = U_n^T J_n` (`r × d`). This
    /// keeps memory and FLOPs scaling at `O(p · r · d)` per row instead of
    /// `O(p²)` per row.
    pub weight: WeightField,
    /// Base strength combined with `rho[rho_index]` by
    /// `resolve_learnable_weight` to produce `μ`. Defaults to `1.0`.
    pub scalar_weight: f64,
    /// Optional schedule for `scalar_weight`. NOTE(review): installed via the
    /// `impl_with_weight_schedule!` builder; confirm how it is applied.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
177
/// Precomputed, direction-independent state for repeated isometry
/// Hessian-vector products. Built once by `IsometryPenalty::hvp_state` and
/// consumed by `IsometryPenalty::hvp_with_precomputed_state`.
pub(crate) struct IsometryHvpState<'a> {
    /// Latent dimension `d` (the target slice's `latent_dim`).
    d: usize,
    /// Number of rows, `target.len() / d`.
    n_obs: usize,
    /// Decoder output dimension `p` (`IsometryPenalty::p_out`).
    p: usize,
    /// Jacobian second derivative, shape `(n_obs, p * d * d)`. Column
    /// `(i * d + a) * d + c` holds `∂J_n[i, a] / ∂t_{n, c}`.
    jac2: CowArray<'a, f64, Ix2>,
    /// Jacobian third derivative, shape `(n_obs, p, d * d * d)`.
    jac3: CowArray<'a, f64, Ix3>,
    /// Pullback-metric state (metric, residual, normalization) at the
    /// target the state was built for.
    metric: IsometryMetricState,
    /// Per-row `W_n J_n` (`p × d` each), formed without materializing `W_n`.
    wj_rows: Vec<Array2<f64>>,
}
187
/// Pullback-metric quantities shared by the isometry HVP path, together with
/// the normalization it differentiates through. All per-row matrices are
/// flattened to `(n_obs, d * d)`, with entry `(a, b)` at column `a * d + b`.
#[derive(Debug, Clone)]
struct IsometryMetricState {
    /// Per-row pullback metric `g_n`.
    g: Array2<f64>,
    /// Per-row residual. NOTE(review): the derivative code is consistent with
    /// `residual = g / normalizer − g^ref`; confirm in `normalized_metric_state`.
    residual: Array2<f64>,
    /// Per-row gradient of the penalty w.r.t. `g`. NOTE(review): presumably
    /// `residual / N − δ_ab · residual_dot_g / (N² · trace_denominator)`,
    /// the form `metric_grad_direction` differentiates.
    metric_grad: Array2<f64>,
    /// Normalizer `N`. Its directional derivative is
    /// `Σ_n tr(δg_n) / trace_denominator`.
    normalizer: f64,
    /// Divisor applied to the summed trace when differentiating `N`.
    trace_denominator: f64,
    /// Scalar `Σ_{n,k} residual[n, k] · g[n, k]` (presumably; its directional
    /// derivative in `metric_grad_direction` has exactly that form).
    residual_dot_g: f64,
}
197
impl IsometryMetricState {
    /// Directional derivative of the normalized residual along `delta_g`.
    ///
    /// `delta_g` is a perturbation of the per-row metric, laid out like
    /// `self.g` (`(n_obs, d * d)`, row-major `a * d + b`). Returns
    /// `(δresidual, δN)` where
    /// `δN = Σ_n tr(δg_n) / trace_denominator` and
    /// `δresidual = δg / N − g · δN / N²` (quotient rule on `g / N`).
    fn residual_direction(&self, delta_g: ArrayView2<'_, f64>, d: usize) -> (Array2<f64>, f64) {
        let n_obs = self.g.nrows();
        let dd = d * d;
        // Σ_n tr(δg_n): diagonal entry (a, a) sits at flat column a * d + a.
        let mut delta_trace_sum = 0.0;
        for n in 0..n_obs {
            for a in 0..d {
                delta_trace_sum += delta_g[[n, a * d + a]];
            }
        }
        let delta_normalizer = delta_trace_sum / self.trace_denominator;
        let inv_norm = 1.0 / self.normalizer;
        let inv_norm_sq = inv_norm * inv_norm;
        // Quotient rule on g / N, entry by entry.
        let mut delta_residual = Array2::<f64>::zeros((n_obs, dd));
        for n in 0..n_obs {
            for k in 0..dd {
                delta_residual[[n, k]] =
                    delta_g[[n, k]] * inv_norm - self.g[[n, k]] * delta_normalizer * inv_norm_sq;
            }
        }
        (delta_residual, delta_normalizer)
    }

    /// Directional derivative of `metric_grad` along `delta_g`, same layout.
    ///
    /// Differentiates, row by row,
    /// `residual_{ab} / N − δ_{ab} · (residual · g) / (N² · trace_denominator)`.
    /// NOTE(review): this is the form the code below is consistent with;
    /// confirm against `normalized_metric_state`.
    fn metric_grad_direction(&self, delta_g: ArrayView2<'_, f64>, d: usize) -> Array2<f64> {
        let n_obs = self.g.nrows();
        let dd = d * d;
        let (delta_residual, delta_normalizer) = self.residual_direction(delta_g, d);
        // δ(residual · g) = δresidual · g + residual · δg, accumulated jointly.
        let mut delta_residual_dot_g = 0.0;
        for n in 0..n_obs {
            for k in 0..dd {
                delta_residual_dot_g += delta_residual[[n, k]] * self.g[[n, k]];
                delta_residual_dot_g += self.residual[[n, k]] * delta_g[[n, k]];
            }
        }
        let inv_norm = 1.0 / self.normalizer;
        let inv_norm_sq = inv_norm * inv_norm;
        // δ[(residual · g) / (N² D)]: the first term varies the numerator; the
        // second (−2 δN / N³) comes from differentiating 1 / N².
        let delta_trace_coeff = delta_residual_dot_g * inv_norm_sq / self.trace_denominator
            - 2.0 * self.residual_dot_g * delta_normalizer * inv_norm_sq * inv_norm
                / self.trace_denominator;
        let mut out = Array2::<f64>::zeros((n_obs, dd));
        for n in 0..n_obs {
            for a in 0..d {
                for b in 0..d {
                    let k = a * d + b;
                    // δ(residual / N) by the quotient rule.
                    let mut value = delta_residual[[n, k]] * inv_norm
                        - self.residual[[n, k]] * delta_normalizer * inv_norm_sq;
                    // The trace-normalization correction only touches the diagonal.
                    if a == b {
                        value -= delta_trace_coeff;
                    }
                    out[[n, k]] = value;
                }
            }
        }
        out
    }
}
254
255fn isometry_dg_entry(
256    jac2: ArrayView2<'_, f64>,
257    wj: ArrayView2<'_, f64>,
258    n: usize,
259    d: usize,
260    p: usize,
261    a: usize,
262    b: usize,
263    c: usize,
264) -> f64 {
265    let mut s = 0.0;
266    for i in 0..p {
267        s += jac2[[n, (i * d + a) * d + c]] * wj[[i, b]];
268        s += wj[[i, a]] * jac2[[n, (i * d + b) * d + c]];
269    }
270    s
271}
272
273fn isometry_row_delta_g(
274    jac2: ArrayView2<'_, f64>,
275    wj: ArrayView2<'_, f64>,
276    v: ArrayView1<'_, f64>,
277    n: usize,
278    d: usize,
279    p: usize,
280) -> Array2<f64> {
281    let mut delta_g = Array2::<f64>::zeros((d, d));
282    for a in 0..d {
283        for b in 0..d {
284            let mut s = 0.0;
285            for c in 0..d {
286                s += isometry_dg_entry(jac2, wj, n, d, p, a, b, c) * v[n * d + c];
287            }
288            delta_g[[a, b]] = s;
289        }
290    }
291    delta_g
292}
293
294impl IsometryPenalty {
295    pub const DEFAULT_VALUE_ON_MISSING_CACHE: f64 = 0.0;
296
297    #[must_use]
298    pub fn new_euclidean(target: PsiSlice, p_out: usize) -> Self {
299        Self {
300            target,
301            reference: IsometryReference::Euclidean,
302            rho_index: 0,
303            jacobian_cache_slot: RwLock::new(None),
304            jacobian_second_cache_slot: RwLock::new(None),
305            duchon_radial_source: None,
306            third_decoder_derivative_slot: RwLock::new(None),
307            p_out,
308            weight: WeightField::Identity,
309            scalar_weight: 1.0,
310            weight_schedule: None,
311        }
312    }
313
314    /// Read-side accessor: takes the read lock briefly and clones the inner
315    /// `Arc` (refcount bump only; no payload copy). Returns `None` when the
316    /// cache has not been refreshed yet. Internally panics on poisoned lock
317    /// — the lock only wraps an `Option<Arc<…>>`, so the write side cannot
318    /// leave it in an invariant-violating state.
319    #[must_use]
320    pub fn jacobian_cache(&self) -> Option<Arc<Array2<f64>>> {
321        self.jacobian_cache_slot
322            .read()
323            .expect("IsometryPenalty::jacobian_cache_slot poisoned")
324            .clone()
325    }
326
327    /// Read-side accessor for the per-row Jacobian second derivative.
328    /// Mirrors [`Self::jacobian_cache`].
329    #[must_use]
330    pub fn jacobian_second_cache(&self) -> Option<Arc<Array2<f64>>> {
331        self.jacobian_second_cache_slot
332            .read()
333            .expect("IsometryPenalty::jacobian_second_cache_slot poisoned")
334            .clone()
335    }
336
337    /// Per-step refresh entry point. Takes `&self` (no `&mut`) so the SAE
338    /// outer loop can install fresh caches on an `Arc<IsometryPenalty>` held
339    /// in the analytic-penalty registry without disturbing the surrounding
340    /// dispatcher. Pass `None` for either argument to clear that cache (the
341    /// dispatcher will then either fall back to the Duchon radial source if
342    /// available, or return the zero safe default).
343    pub fn refresh_caches(&self, jac: Option<Arc<Array2<f64>>>, jac2: Option<Arc<Array2<f64>>>) {
344        *self
345            .jacobian_cache_slot
346            .write()
347            .expect("IsometryPenalty::jacobian_cache_slot poisoned") = jac;
348        *self
349            .jacobian_second_cache_slot
350            .write()
351            .expect("IsometryPenalty::jacobian_second_cache_slot poisoned") = jac2;
352    }
353
354    /// In-place writer for just the Jacobian cache (used by callers that
355    /// already own the radial Duchon source and only want to refresh `J`).
356    pub fn set_jacobian_cache(&self, jac: Option<Arc<Array2<f64>>>) {
357        *self
358            .jacobian_cache_slot
359            .write()
360            .expect("IsometryPenalty::jacobian_cache_slot poisoned") = jac;
361    }
362
363    /// In-place writer for just the Jacobian second-derivative cache.
364    pub fn set_jacobian_second_cache(&self, jac2: Option<Arc<Array2<f64>>>) {
365        *self
366            .jacobian_second_cache_slot
367            .write()
368            .expect("IsometryPenalty::jacobian_second_cache_slot poisoned") = jac2;
369    }
370
371    /// Read-side accessor for the per-row Jacobian third derivative `K`.
372    /// Mirrors [`Self::jacobian_second_cache`].
373    #[must_use]
374    pub fn third_decoder_derivative(&self) -> Option<Arc<ndarray::Array3<f64>>> {
375        self.third_decoder_derivative_slot
376            .read()
377            .expect("IsometryPenalty::third_decoder_derivative_slot poisoned")
378            .clone()
379    }
380
381    /// In-place writer for just the Jacobian third-derivative cache `K`.
382    pub fn set_third_decoder_derivative(&self, jac3: Option<Arc<ndarray::Array3<f64>>>) {
383        *self
384            .third_decoder_derivative_slot
385            .write()
386            .expect("IsometryPenalty::third_decoder_derivative_slot poisoned") = jac3;
387    }
388}
389
390impl Clone for IsometryPenalty {
391    fn clone(&self) -> Self {
392        Self {
393            target: self.target.clone(),
394            reference: self.reference.clone(),
395            rho_index: self.rho_index,
396            jacobian_cache_slot: RwLock::new(self.jacobian_cache()),
397            jacobian_second_cache_slot: RwLock::new(self.jacobian_second_cache()),
398            duchon_radial_source: self.duchon_radial_source.clone(),
399            third_decoder_derivative_slot: RwLock::new(self.third_decoder_derivative()),
400            p_out: self.p_out,
401            weight: self.weight.clone(),
402            scalar_weight: self.scalar_weight,
403            weight_schedule: self.weight_schedule.clone(),
404        }
405    }
406}
407
408impl IsometryPenalty {
409    /// Attach a cached third decoder derivative
410    /// `K_n[i, a, c, d] = ∂²J_n[i, a] / ∂t_{n, c} ∂t_{n, d}`, flattened
411    /// row-major as `(n_obs, p * d * d * d)`. The Hessian-vector product
412    /// uses the full residual-curvature term in addition to the metric
413    /// Gauss-Newton piece.
414    #[must_use]
415    pub fn with_third_decoder_derivative(self, k: Arc<ndarray::Array3<f64>>) -> Self {
416        self.set_third_decoder_derivative(Some(k));
417        self
418    }
419
420    #[must_use]
421    pub fn with_reference(mut self, reference: IsometryReference) -> Self {
422        self.reference = reference;
423        self
424    }
425
426    #[must_use]
427    pub fn with_jacobian_cache(self, j: Arc<Array2<f64>>) -> Self {
428        self.set_jacobian_cache(Some(j));
429        self
430    }
431
432    #[must_use]
433    pub fn with_jacobian_second_cache(self, h: Arc<Array2<f64>>) -> Self {
434        self.set_jacobian_second_cache(Some(h));
435        self
436    }
437
438    /// Attach radial Duchon decoder metadata so the exact `∂J/∂t` tensor can
439    /// be rebuilt from the current target coordinates. A doc-test oracle for
440    /// this path is: build `J(t)` from `duchon_radial_first_derivative_nd`,
441    /// evaluate `grad_target(t)`, then central-difference `value(t ± h e_j)`;
442    /// the analytic component should agree to finite-difference tolerance as
443    /// `h` is refined before cancellation dominates.
444    #[must_use]
445    pub fn with_duchon_radial_source(mut self, source: Arc<IsometryDuchonRadialSource>) -> Self {
446        self.duchon_radial_source = Some(source);
447        self
448    }
449
450    /// Attach the gauge metric **from the single
451    /// [`RowMetric`](gam_problem::RowMetric)** that also drives
452    /// the reconstruction likelihood. This is the only way an `IsometryPenalty`
453    /// acquires a non-identity behavioral metric: the independent
454    /// `WeightField` setter has been removed so a gauge-metric ≠
455    /// likelihood-metric state is structurally unrepresentable. The
456    /// contraction-order invariant (`M_n = U_n^T J_n`, never materializing the
457    /// `p × p` `W_n`) is preserved by the [`WeightField::Factored`] layout the
458    /// metric emits.
459    ///
460    /// `p_out` is taken from the metric so the gauge's output dimension is
461    /// pinned to the metric's.
462    #[must_use]
463    pub fn with_row_metric(mut self, metric: &gam_problem::RowMetric) -> Self {
464        // Only a metric that drives the gauge installs a non-identity pullback
465        // weight. A Euclidean metric reduces the gauge pullback to the bare
466        // `J_nᵀ J_n`, so its `to_weight_field()` is `Identity` and the existing
467        // (default-Identity) weight is left exactly as is — bit-for-bit the
468        // pre-metric isotropic gauge. The output dimension is pinned to the
469        // metric's regardless, so the gauge and likelihood agree on `p_out`.
470        if metric.drives_gauge() {
471            self.weight = metric.to_weight_field();
472        }
473        self.p_out = metric.p_out();
474        self
475    }
476
    // Macro-generated weight-schedule builder(s) keyed on the `scalar_weight`
    // field. NOTE(review): the macro is defined elsewhere; check its
    // definition for the exact methods it expands to.
    impl_with_weight_schedule!(scalar_weight);
478
479    fn missing_cache_default(&self, method: &str, detail: &str) {
480        log::warn!(
481            "IsometryPenalty::{method} missing required derivative state: {detail}; \
482             returning the zero safe default"
483        );
484    }
485
486    fn has_jacobian_cache(&self, method: &str) -> bool {
487        if self.jacobian_cache().is_some() {
488            true
489        } else {
490            self.missing_cache_default(method, "jacobian_cache is None");
491            false
492        }
493    }
494
495    fn has_jacobian_second_source(&self, method: &str) -> bool {
496        if self.jacobian_second_cache().is_some() || self.duchon_radial_source.is_some() {
497            true
498        } else {
499            self.missing_cache_default(
500                method,
501                "both jacobian_second_cache and duchon_radial_source are None",
502            );
503            false
504        }
505    }
506
507    fn has_jacobian_third_source(&self, method: &str) -> bool {
508        if self.third_decoder_derivative().is_some() || self.duchon_radial_source.is_some() {
509            true
510        } else {
511            self.missing_cache_default(
512                method,
513                "both third_decoder_derivative cache and duchon_radial_source are None",
514            );
515            false
516        }
517    }
518
519    /// Build `M_n = U_n^T J_n ∈ ℝ^{r_n × d}` for row `n`. For
520    /// `WeightField::Identity`, `r_n = p` and `M_n = J_n`.
521    ///
522    /// This is the single contraction site where `W_n` (or its `U_n` factor)
523    /// is consumed. Every value/grad/hvp path funnels through here, so the
524    /// `(J^T U)(U^T J)` ordering invariant cannot be violated by accident.
525    fn projected_jacobian_row(&self, n: usize, d: usize) -> Option<Array2<f64>> {
526        let Some(jac) = self.jacobian_cache() else {
527            self.missing_cache_default("projected_jacobian_row", "jacobian_cache is None");
528            return None;
529        };
530        let jac_row = jac.row(n);
531        let jac_slice = jac_row
532            .as_slice()
533            .expect("jacobian cache must be in standard row-major layout");
534        match &self.weight {
535            WeightField::Identity => {
536                let p = self.p_out;
537                let mut m = Array2::<f64>::zeros((p, d));
538                for i in 0..p {
539                    for a in 0..d {
540                        m[[i, a]] = jac_slice[i * d + a];
541                    }
542                }
543                Some(m)
544            }
545            WeightField::Factored { u, rank, p_out } => {
546                let u_row = u.row(n);
547                let u_slice = u_row
548                    .as_slice()
549                    .expect("weight factor U must be in standard row-major layout");
550                Some(WeightField::project_jac_row_with_u(
551                    u_slice, jac_slice, *p_out, *rank, d,
552                ))
553            }
554        }
555    }
556
557    /// Form `W_n J_n` without materializing `W_n`.
558    fn weighted_jacobian_row(&self, n: usize, d: usize) -> Option<Array2<f64>> {
559        let Some(jac) = self.jacobian_cache() else {
560            self.missing_cache_default("weighted_jacobian_row", "jacobian_cache is None");
561            return None;
562        };
563        let p = self.p_out;
564        match &self.weight {
565            WeightField::Identity => {
566                let mut out = Array2::<f64>::zeros((p, d));
567                for i in 0..p {
568                    for a in 0..d {
569                        out[[i, a]] = jac[[n, i * d + a]];
570                    }
571                }
572                Some(out)
573            }
574            WeightField::Factored { u, rank, p_out } => {
575                assert_eq!(p, *p_out);
576                let r = *rank;
577                let m_n = self.projected_jacobian_row(n, d)?;
578                let mut out = Array2::<f64>::zeros((p, d));
579                for i in 0..p {
580                    for a in 0..d {
581                        let mut s = 0.0;
582                        for k in 0..r {
583                            s += u[[n, i * r + k]] * m_n[[k, a]];
584                        }
585                        out[[i, a]] = s;
586                    }
587                }
588                Some(out)
589            }
590        }
591    }
592
593    fn weighted_dot_decoder_vectors<F, G>(&self, n: usize, p: usize, x: F, y: G) -> f64
594    where
595        F: Fn(usize) -> f64,
596        G: Fn(usize) -> f64,
597    {
598        match &self.weight {
599            WeightField::Identity => {
600                let mut s = 0.0;
601                for i in 0..p {
602                    s += x(i) * y(i);
603                }
604                s
605            }
606            WeightField::Factored { u, rank, p_out } => {
607                assert_eq!(p, *p_out);
608                let r = *rank;
609                let mut s = 0.0;
610                for k in 0..r {
611                    let mut ux = 0.0;
612                    let mut uy = 0.0;
613                    for i in 0..p {
614                        let uik = u[[n, i * r + k]];
615                        ux += uik * x(i);
616                        uy += uik * y(i);
617                    }
618                    s += ux * uy;
619                }
620                s
621            }
622        }
623    }
624
625    fn target_matrix(target: ArrayView1<'_, f64>, n_obs: usize, d: usize) -> Array2<f64> {
626        let mut out = Array2::<f64>::zeros((n_obs, d));
627        for n in 0..n_obs {
628            for a in 0..d {
629                out[[n, a]] = target[n * d + a];
630            }
631        }
632        out
633    }
634
635    /// Second-order input-location derivative tensor of the Duchon decoder,
636    /// flattened to `(n_obs, p_out · d²)` with column layout
637    /// `i·d² + (a·d + c)`.
638    ///
639    /// Thin adapter over the shared [`radial_basis_cartesian_derivative`]
640    /// engine: it owns the radial-jet evaluation and the radial→Cartesian map;
641    /// here we only forward the source geometry.
642    fn duchon_radial_jacobian_second(
643        &self,
644        target: ArrayView1<'_, f64>,
645        n_obs: usize,
646        d: usize,
647        source: &IsometryDuchonRadialSource,
648    ) -> Result<Array2<f64>, BasisError> {
649        assert_eq!(source.centers.ncols(), d);
650        assert_eq!(source.radial_coefficients.nrows(), source.centers.nrows());
651        assert_eq!(source.radial_coefficients.ncols(), self.p_out);
652        let t = Self::target_matrix(target, n_obs, d);
653        radial_basis_cartesian_derivative(
654            2,
655            t.view(),
656            source.centers.view(),
657            source.radial_coefficients.view(),
658            source.length_scale,
659            source.nullspace_order,
660            source.power,
661        )
662    }
663
664    /// Third-order input-location derivative tensor of the Duchon decoder,
665    /// shaped `(n_obs, p_out, d³)` with last-axis layout `(a·d + c)·d + e`.
666    ///
667    /// Thin adapter over the shared [`radial_basis_cartesian_derivative`]
668    /// engine; the flat `(n_obs, p_out · d³)` result is reshaped to the
669    /// `Array3` consumed by the HVP path (row-major flatten of `(p_out, d³)`
670    /// is exactly `i·d³ + idx`).
671    fn duchon_radial_jacobian_third(
672        &self,
673        target: ArrayView1<'_, f64>,
674        n_obs: usize,
675        d: usize,
676        source: &IsometryDuchonRadialSource,
677    ) -> Result<ndarray::Array3<f64>, BasisError> {
678        assert_eq!(source.centers.ncols(), d);
679        assert_eq!(source.radial_coefficients.nrows(), source.centers.nrows());
680        assert_eq!(source.radial_coefficients.ncols(), self.p_out);
681        let t = Self::target_matrix(target, n_obs, d);
682        let flat = radial_basis_cartesian_derivative(
683            3,
684            t.view(),
685            source.centers.view(),
686            source.radial_coefficients.view(),
687            source.length_scale,
688            source.nullspace_order,
689            source.power,
690        )?;
691        Ok(flat
692            .into_shape_with_order((n_obs, self.p_out, d * d * d))
693            .expect("radial_basis_cartesian_derivative order-3 output reshapes to (n_obs, p, d³)"))
694    }
695
696    fn jacobian_second<'a>(
697        &'a self,
698        target: ArrayView1<'_, f64>,
699        n_obs: usize,
700        d: usize,
701    ) -> Option<CowArray<'a, f64, Ix2>> {
702        if let Some(jac2) = self.jacobian_second_cache() {
703            // Clone the underlying Array2 to detach from the Arc — the
704            // CowArray needs to outlive the temporary Arc returned by the
705            // accessor. The clone is `n_obs × p·d²` floats, paid once per
706            // grad_target / hvp_state invocation; same per-step cost as the
707            // pre-refactor code path which also took ownership via
708            // `jac2.view().to_owned()` semantics implicitly.
709            return Some(CowArray::from((*jac2).clone()));
710        }
711        let source = self.duchon_radial_source.as_ref()?;
712        match self.duchon_radial_jacobian_second(target, n_obs, d, source) {
713            Ok(jac2) => Some(CowArray::from(jac2)),
714            Err(err) => {
715                self.missing_cache_default(
716                    "jacobian_second",
717                    &format!("failed to materialize Duchon radial second derivative: {err}"),
718                );
719                None
720            }
721        }
722    }
723
724    fn jacobian_third<'a>(
725        &'a self,
726        target: ArrayView1<'_, f64>,
727        n_obs: usize,
728        d: usize,
729    ) -> Option<CowArray<'a, f64, Ix3>> {
730        if let Some(jac3) = self.third_decoder_derivative() {
731            return Some(CowArray::from(jac3.as_ref().clone()));
732        }
733        let source = self.duchon_radial_source.as_ref()?;
734        match self.duchon_radial_jacobian_third(target, n_obs, d, source) {
735            Ok(jac3) => Some(CowArray::from(jac3)),
736            Err(err) => {
737                self.missing_cache_default(
738                    "jacobian_third",
739                    &format!("failed to materialize Duchon radial third derivative: {err}"),
740                );
741                None
742            }
743        }
744    }
745
746    pub(crate) fn hvp_state<'a>(
747        &'a self,
748        target: ArrayView1<'_, f64>,
749    ) -> Option<IsometryHvpState<'a>> {
750        let d = self
751            .target
752            .latent_dim
753            .expect("IsometryPenalty requires latent_dim on its PsiSlice");
754        let n_obs = target.len() / d;
755        if !self.has_jacobian_cache("hvp")
756            || !self.has_jacobian_second_source("hvp")
757            || !self.has_jacobian_third_source("hvp")
758        {
759            return None;
760        }
761        let p = self.p_out;
762        let jac2 = self.jacobian_second(target.view(), n_obs, d)?;
763        let jac3 = self.jacobian_third(target.view(), n_obs, d)?;
764        let g = self.pullback_metric(d)?;
765        let metric = self.normalized_metric_state(g, n_obs, d)?;
766        let mut wj_rows = Vec::with_capacity(n_obs);
767        for n in 0..n_obs {
768            wj_rows.push(self.weighted_jacobian_row(n, d)?);
769        }
770        Some(IsometryHvpState {
771            d,
772            n_obs,
773            p,
774            jac2,
775            jac3,
776            metric,
777            wj_rows,
778        })
779    }
780
    /// Isometry Hessian-vector product `H(target; ρ) v` evaluated from a state
    /// prepared by `hvp_state`.
    ///
    /// The result is the sum of two terms, both scaled by `μ`:
    /// 1. A metric-curvature term. Per row, `delta_g = Dg · v` is formed, passed
    ///    through `metric_grad_direction`, and then contracted back against the
    ///    per-coordinate metric derivatives.
    /// 2. A decoder-curvature term, `Σ_{ab} A_ab B_{ab,cd} v_d`, with
    ///    `B_{ab,cd} = K_{a,cd}ᵀ W J_b + H_{a,c}ᵀ W H_{b,d} + H_{a,d}ᵀ W H_{b,c}
    ///    + J_aᵀ W K_{b,cd}`, where `H` is `jac2` and `K` is `jac3`.
    ///
    /// `v` is read as `(n_obs, d)` row-major (`v[n * d + c]`). The output has
    /// length `v.len()`.
    pub(crate) fn hvp_with_precomputed_state(
        &self,
        state: &IsometryHvpState<'_>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
        let d = state.d;
        let n_obs = state.n_obs;
        let p = state.p;
        let jac2 = &state.jac2;
        let jac3 = &state.jac3;
        let metric = &state.metric;
        let mut out = Array1::<f64>::zeros(v.len());
        // Term 1a: per-row metric perturbation along v, stored flattened as (n_obs, d*d).
        // NOTE(review): `isometry_row_delta_g` is defined elsewhere; presumably it
        // returns the d×d matrix Σ_c ∂g_n/∂t_{n,c} v_{n,c} — confirm.
        let mut delta_g = Array2::<f64>::zeros((n_obs, d * d));
        for n in 0..n_obs {
            let wj = &state.wj_rows[n];
            let row_delta = isometry_row_delta_g(jac2.view(), wj.view(), v, n, d, p);
            for a in 0..d {
                for b in 0..d {
                    delta_g[[n, a * d + b]] = row_delta[[a, b]];
                }
            }
        }
        // Change of the metric gradient A along delta_g. All rows are passed at
        // once, presumably because the shared gbar normalizer couples them.
        let delta_metric_grad = metric.metric_grad_direction(delta_g.view(), d);

        for n in 0..n_obs {
            let wj = &state.wj_rows[n];
            // Term 1b: contract the metric-gradient change with ∂g_ab/∂t_{n,c}.
            // `isometry_dg_entry` is assumed to return that derivative entry.
            for c in 0..d {
                let mut acc = 0.0;
                for a in 0..d {
                    for b in 0..d {
                        let dg = isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, b, c);
                        acc += dg * delta_metric_grad[[n, a * d + b]];
                    }
                }
                out[n * d + c] = mu * acc;
            }

            // Term 2: Σ_ab A_ab Σ_dd B_{ab,c dd} v_dd. This is added to term 1.
            for c in 0..d {
                let mut acc_res = 0.0;
                for a in 0..d {
                    for b in 0..d {
                        let metric_grad = metric.metric_grad[[n, a * d + b]];
                        // Zero weights contribute nothing, so skip the O(p·d) inner work.
                        // Skipping also suppresses any NaN a non-finite B entry would
                        // otherwise produce via 0 * inf.
                        if metric_grad == 0.0 {
                            continue;
                        }
                        let mut bv = 0.0;
                        for dd in 0..d {
                            let vd = v[n * d + dd];
                            if vd == 0.0 {
                                continue;
                            }
                            // K_{a,c dd}ᵀ (W J)_b
                            let mut k_a_cd_w_j_b = 0.0;
                            for i in 0..p {
                                k_a_cd_w_j_b += jac3[[n, i, ((a * d) + c) * d + dd]] * wj[[i, b]];
                            }
                            // H_{a,c}ᵀ W H_{b,dd}. The helper is assumed to apply
                            // the row's output weight W_n between the two vectors.
                            let h_a_c_w_h_b_d = self.weighted_dot_decoder_vectors(
                                n,
                                p,
                                |i| jac2[[n, (i * d + a) * d + c]],
                                |i| jac2[[n, (i * d + b) * d + dd]],
                            );
                            // H_{a,dd}ᵀ W H_{b,c}
                            let h_a_d_w_h_b_c = self.weighted_dot_decoder_vectors(
                                n,
                                p,
                                |i| jac2[[n, (i * d + a) * d + dd]],
                                |i| jac2[[n, (i * d + b) * d + c]],
                            );
                            // (W J)_aᵀ K_{b,c dd}
                            let mut j_a_w_k_b_cd = 0.0;
                            for i in 0..p {
                                j_a_w_k_b_cd += wj[[i, a]] * jac3[[n, i, ((b * d) + c) * d + dd]];
                            }
                            bv +=
                                (k_a_cd_w_j_b + h_a_c_w_h_b_d + h_a_d_w_h_b_c + j_a_w_k_b_cd) * vd;
                        }
                        acc_res += metric_grad * bv;
                    }
                }
                out[n * d + c] += mu * acc_res;
            }
        }
        out
    }
865
866    /// Per-row pullback metric `g_n = J_n^T W_n J_n = M_n^T M_n` with
867    /// `M_n = U_n^T J_n ∈ ℝ^{r_n × d}`. Returns `(n_obs, d, d)` flattened
868    /// row-major as `(n_obs, d*d)`.
869    ///
870    /// Cost per row: `O(p · r · d)` for the `M_n` build (single pass over
871    /// `U_n` and `J_n`) plus `O(r · d²)` for `M_n^T M_n`. The `p × p` weight
872    /// `W_n` is never materialized.
873    pub fn pullback_metric(&self, latent_dim: usize) -> Option<Array2<f64>> {
874        let Some(jac) = self.jacobian_cache() else {
875            self.missing_cache_default("pullback_metric", "jacobian_cache is None");
876            return None;
877        };
878        let n_obs = jac.nrows();
879        let p = self.p_out;
880        assert_eq!(jac.ncols(), p * latent_dim);
881        let mut g_all = Array2::<f64>::zeros((n_obs, latent_dim * latent_dim));
882        for n in 0..n_obs {
883            // M_n = U_n^T J_n  (or J_n itself when W = I).
884            let m = self.projected_jacobian_row(n, latent_dim)?;
885            let r = m.nrows();
886            // g_n = M_n^T M_n: (d × d) result, contracting r.
887            for a in 0..latent_dim {
888                for b in 0..latent_dim {
889                    let mut s = 0.0;
890                    for k in 0..r {
891                        s += m[[k, a]] * m[[k, b]];
892                    }
893                    g_all[[n, a * latent_dim + b]] = s;
894                }
895            }
896        }
897        Some(g_all)
898    }
899
900    /// Reference metric per row for the normalized pullback metric, `(n_obs, d*d)`.
901    fn reference_metric(&self, n_obs: usize, d: usize) -> CowArray<'_, f64, Ix2> {
902        match &self.reference {
903            IsometryReference::Euclidean => {
904                let mut out = Array2::<f64>::zeros((n_obs, d * d));
905                for n in 0..n_obs {
906                    for a in 0..d {
907                        out[[n, a * d + a]] = 1.0;
908                    }
909                }
910                CowArray::from(out)
911            }
912            IsometryReference::UserSupplied(a) => {
913                assert_eq!(a.nrows(), n_obs);
914                assert_eq!(a.ncols(), d * d);
915                CowArray::from(a.view())
916            }
917        }
918    }
919
920    /// Shared normalized metric state for the scale-invariant isometry gauge.
921    ///
922    /// The residual is `R_n = g_n / gbar - g_ref,n`, with
923    /// `gbar = (1 / (N d)) Σ_n tr(g_n)`. The metric-gradient is the exact
924    /// derivative of `0.5 Σ ||R_n||²` with respect to the raw pullback metrics:
925    ///
926    /// `A_n = R_n / gbar - (Σ_l R_l:g_l) I / (gbar² N d)`.
927    ///
928    /// All value, gradient, and HVP paths consume this state so the global
929    /// normalizer's derivative is never detached.
930    fn normalized_metric_state(
931        &self,
932        g: Array2<f64>,
933        n_obs: usize,
934        d: usize,
935    ) -> Option<IsometryMetricState> {
936        let dd = d * d;
937        let trace_denominator = (n_obs * d) as f64;
938        let mut trace_sum = 0.0;
939        for n in 0..n_obs {
940            for a in 0..d {
941                trace_sum += g[[n, a * d + a]];
942            }
943        }
944        let normalizer = trace_sum / trace_denominator;
945        if !(normalizer.is_finite() && normalizer > f64::MIN_POSITIVE) {
946            self.missing_cache_default(
947                "normalized_metric_state",
948                &format!(
949                    "unit-average-speed normalizer is non-positive or non-finite: {normalizer}"
950                ),
951            );
952            return None;
953        }
954        let g_ref = self.reference_metric(n_obs, d);
955        let mut residual = Array2::<f64>::zeros((n_obs, dd));
956        let inv_norm = 1.0 / normalizer;
957        for n in 0..n_obs {
958            for k in 0..dd {
959                residual[[n, k]] = g[[n, k]] * inv_norm - g_ref[[n, k]];
960            }
961        }
962        let mut residual_dot_g = 0.0;
963        for n in 0..n_obs {
964            for k in 0..dd {
965                residual_dot_g += residual[[n, k]] * g[[n, k]];
966            }
967        }
968        let trace_coeff = residual_dot_g / (normalizer * normalizer * trace_denominator);
969        let mut metric_grad = Array2::<f64>::zeros((n_obs, dd));
970        for n in 0..n_obs {
971            for a in 0..d {
972                for b in 0..d {
973                    let k = a * d + b;
974                    let mut value = residual[[n, k]] * inv_norm;
975                    if a == b {
976                        value -= trace_coeff;
977                    }
978                    metric_grad[[n, k]] = value;
979                }
980            }
981        }
982        Some(IsometryMetricState {
983            g,
984            residual,
985            metric_grad,
986            normalizer,
987            trace_denominator,
988            residual_dot_g,
989        })
990    }
991
992    /// Exact closed-form gradient of the isometry penalty with respect to the
993    /// cached decoder Jacobian `J ∈ ℝ^{n_obs × p × d}` (the autograd input that
994    /// torch's `_IsometryPenaltyFn` differentiates). Returns the flattened
995    /// `(n_obs, p*d)` layout that matches the Jacobian cache.
996    ///
997    /// Derivation (W-aware, reference-aware, weight-aware):
998    ///
999    ///   P        = ½ μ Σ_n ‖R_n‖²_F,
1000    ///   R_n      = g_n / gbar − g^ref_n,
1001    ///   gbar     = (1 / (N d)) Σ_n tr(g_n)
1002    ///   A_n      = ∂(P/μ)/∂g_n
1003    ///   ∂g_{ab}/∂J_{i,c}
1004    ///            = δ_{ca}(W J)_{i,b} + δ_{cb}(W J)_{i,a}   (W symmetric)
1005    ///   ∂P/∂J_{i,c}
1006    ///            = μ Σ_{a,b} A_{ab} ∂g_{ab}/∂J_{i,c}
1007    ///            = 2 μ Σ_b A_{cb} (W J)_{i,b}
1008    ///            = 2 μ ((W J) A)_{i,c}
1009    ///
1010    /// where `A` includes the exact derivative of the shared `gbar` normalizer.
1011    pub fn grad_jacobian(
1012        &self,
1013        target: ArrayView1<'_, f64>,
1014        rho: ArrayView1<'_, f64>,
1015    ) -> Array2<f64> {
1016        let d = self
1017            .target
1018            .latent_dim
1019            .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1020        let n_obs = target.len() / d;
1021        let p = self.p_out;
1022        let mut grad = Array2::<f64>::zeros((n_obs, p * d));
1023        if !self.has_jacobian_cache("grad_jacobian") {
1024            return grad;
1025        }
1026        let Some(g) = self.pullback_metric(d) else {
1027            return grad;
1028        };
1029        let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1030            return grad;
1031        };
1032        let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1033        for n in 0..n_obs {
1034            let Some(wj) = self.weighted_jacobian_row(n, d) else {
1035                return Array2::<f64>::zeros((n_obs, p * d));
1036            };
1037            for i in 0..p {
1038                for c in 0..d {
1039                    let mut acc = 0.0;
1040                    for b in 0..d {
1041                        acc += metric.metric_grad[[n, c * d + b]] * wj[[i, b]];
1042                    }
1043                    grad[[n, i * d + c]] = 2.0 * mu * acc;
1044                }
1045            }
1046        }
1047        grad
1048    }
1049}
1050
1051impl AnalyticPenalty for IsometryPenalty {
1052    fn tier(&self) -> PenaltyTier {
1053        PenaltyTier::Psi
1054    }
1055
1056    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1057        let d = self
1058            .target
1059            .latent_dim
1060            .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1061        let n_obs = target.len() / d;
1062        if !self.has_jacobian_cache("value") {
1063            return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1064        }
1065        let Some(g) = self.pullback_metric(d) else {
1066            return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1067        };
1068        let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1069            return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1070        };
1071        let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1072        let mut acc = 0.0;
1073        for n in 0..n_obs {
1074            for k in 0..(d * d) {
1075                let diff = metric.residual[[n, k]];
1076                acc += diff * diff;
1077            }
1078        }
1079        0.5 * mu * acc
1080    }
1081
1082    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1083        // Exact closed-form gradient, W-aware:
1084        //
1085        //   P     = ½ μ Σ_n ‖R_n‖²_F,   R_n = g_n / gbar − g^ref_n
1086        //   g_n   = J_n^T W_n J_n,      W_n = U_n U_n^T
1087        //   A_n   = ∂(P/μ)/∂g_n, including the exact derivative of
1088        //           gbar = (1 / (N d)) Σ_n tr(g_n)
1089        //   ∂g_{ab}/∂t_c
1090        //         = (H_{:,a,c})^T (W J)_{:,b}  +  (J_{:,a})^T W H_{:,b,c}
1091        //   ∂P/∂t_c
1092        //         = μ Σ_{a,b} A_{a,b} · ∂g_{ab}/∂t_c
1093        //
1094        // `H = ∂J/∂t` comes either from the live cache or from the radial
1095        // Duchon `φ''(r)` helper. The sign is positive: differentiating
1096        // `t - c` with respect to `t` contributes `+I`.
1097        let d = self
1098            .target
1099            .latent_dim
1100            .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1101        let n_obs = target.len() / d;
1102        if !self.has_jacobian_cache("grad_target")
1103            || !self.has_jacobian_second_source("grad_target")
1104        {
1105            return Array1::<f64>::zeros(target.len());
1106        }
1107        let Some(g) = self.pullback_metric(d) else {
1108            return Array1::<f64>::zeros(target.len());
1109        };
1110        let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1111            return Array1::<f64>::zeros(target.len());
1112        };
1113        let p = self.p_out;
1114        let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1115        let mut grad = Array1::<f64>::zeros(target.len());
1116        let Some(jac2) = self.jacobian_second(target, n_obs, d) else {
1117            return grad;
1118        };
1119        assert_eq!(jac2.ncols(), p * d * d);
1120
1121        for n in 0..n_obs {
1122            let Some(wj) = self.weighted_jacobian_row(n, d) else {
1123                return grad;
1124            };
1125            for c in 0..d {
1126                let mut acc = 0.0;
1127                for a in 0..d {
1128                    for b in 0..d {
1129                        let mut dg = 0.0;
1130                        for i in 0..p {
1131                            dg += jac2[[n, (i * d + a) * d + c]] * wj[[i, b]];
1132                            dg += wj[[i, a]] * jac2[[n, (i * d + b) * d + c]];
1133                        }
1134                        acc += metric.metric_grad[[n, a * d + b]] * dg;
1135                    }
1136                }
1137                grad[n * d + c] = mu * acc;
1138            }
1139        }
1140        grad
1141    }
1142
1143    /// Fully analytic - wired through `radial_basis_cartesian_derivative`.
1144    fn hvp(
1145        &self,
1146        target: ArrayView1<'_, f64>,
1147        rho: ArrayView1<'_, f64>,
1148        v: ArrayView1<'_, f64>,
1149    ) -> Array1<f64> {
1150        // Fully analytic isometry Hessian-vector product wired through the
1151        // shared `radial_basis_cartesian_derivative` engine when no
1152        // third-derivative cache is supplied.
1153        //
1154        // The full Hessian of P_iso = (μ/2) Σ_n ||J^T W J / gbar - G_ref||²_F
1155        // (per proposal §4(b)) is
1156        //   μ [Dgᵀ · ∂²(0.5||R||²)/∂g² · Dg + A · ∂²g],
1157        // where R = g/gbar - G_ref and A = ∂(0.5||R||²)/∂g includes the global
1158        // gbar derivative.
1159        //   B_{ab,cd} = K_{a,cd}^T W J_b + H_{a,c}^T W H_{b,d}
1160        //             + H_{a,d}^T W H_{b,c} + J_a^T W K_{b,cd},
1161        // where K is the third decoder derivative and H is the second.
1162        let Some(state) = self.hvp_state(target) else {
1163            return Array1::<f64>::zeros(v.len());
1164        };
1165        self.hvp_with_precomputed_state(&state, rho, v)
1166    }
1167
1168    /// PSD majorizer-vector product `B_GN(target; ρ) v` for the **nonconvex**
1169    /// isometry penalty.
1170    ///
1171    /// The Gauss-Newton block differentiates the normalized residual
1172    /// `R = g/gbar - G_ref` itself and returns `μ DRᵀ DR v`. This is PSD by
1173    /// construction and includes the shared-normalizer derivative exactly;
1174    /// using only `∂g` would reintroduce scale coupling and would not be the
1175    /// Gauss-Newton operator of the objective being minimized.
1176    fn psd_majorizer_hvp(
1177        &self,
1178        target: ArrayView1<'_, f64>,
1179        rho: ArrayView1<'_, f64>,
1180        v: ArrayView1<'_, f64>,
1181    ) -> Array1<f64> {
1182        let d = self
1183            .target
1184            .latent_dim
1185            .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1186        let n_obs = target.len() / d;
1187        if !self.has_jacobian_cache("psd_majorizer_hvp")
1188            || !self.has_jacobian_second_source("psd_majorizer_hvp")
1189        {
1190            return Array1::<f64>::zeros(v.len());
1191        }
1192        let Some(jac2) = self.jacobian_second(target, n_obs, d) else {
1193            return Array1::<f64>::zeros(v.len());
1194        };
1195        let Some(g) = self.pullback_metric(d) else {
1196            return Array1::<f64>::zeros(v.len());
1197        };
1198        let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1199            return Array1::<f64>::zeros(v.len());
1200        };
1201        let p = self.p_out;
1202        let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1203        let mut out = Array1::<f64>::zeros(v.len());
1204        let mut wj_rows = Vec::with_capacity(n_obs);
1205        for n in 0..n_obs {
1206            let Some(wj) = self.weighted_jacobian_row(n, d) else {
1207                return Array1::<f64>::zeros(v.len());
1208            };
1209            wj_rows.push(wj);
1210        }
1211        let mut delta_g = Array2::<f64>::zeros((n_obs, d * d));
1212        for n in 0..n_obs {
1213            let row_delta = isometry_row_delta_g(jac2.view(), wj_rows[n].view(), v, n, d, p);
1214            for a in 0..d {
1215                for b in 0..d {
1216                    delta_g[[n, a * d + b]] = row_delta[[a, b]];
1217                }
1218            }
1219        }
1220        let (delta_residual, _delta_normalizer) = metric.residual_direction(delta_g.view(), d);
1221        let mut g_dot_delta_residual = 0.0;
1222        for n in 0..n_obs {
1223            for k in 0..(d * d) {
1224                g_dot_delta_residual += metric.g[[n, k]] * delta_residual[[n, k]];
1225            }
1226        }
1227        let inv_norm = 1.0 / metric.normalizer;
1228        let inv_norm_sq = inv_norm * inv_norm;
1229        for n in 0..n_obs {
1230            let wj = &wj_rows[n];
1231            for c in 0..d {
1232                let mut trace_dg = 0.0;
1233                for a in 0..d {
1234                    trace_dg += isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, a, c);
1235                }
1236                let delta_normalizer_c = trace_dg / metric.trace_denominator;
1237                let mut acc = -delta_normalizer_c * inv_norm_sq * g_dot_delta_residual;
1238                for a in 0..d {
1239                    for b in 0..d {
1240                        let dg = isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, b, c);
1241                        acc += dg * inv_norm * delta_residual[[n, a * d + b]];
1242                    }
1243                }
1244                out[n * d + c] = mu * acc;
1245            }
1246        }
1247        out
1248    }
1249
1250    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1251        // P(ρ) = ½ μ · S, where S is the (ρ-independent) Frobenius sum and
1252        // μ = exp(ρ_iso). So ∂P/∂ρ_iso = P.
1253        let mut out = Array1::<f64>::zeros(self.rho_count());
1254        out[self.rho_index] = self.value(target, rho);
1255        out
1256    }
1257
1258    fn rho_count(&self) -> usize {
1259        1
1260    }
1261
1262    fn name(&self) -> &str {
1263        "isometry"
1264    }
1265
1266    impl_scalar_apply_schedule!(scalar_weight);
1267}