Skip to main content

gam_terms/basis/
measure_jet_anisotropy.rs

1//! Learned ambient anisotropy for the measure-jet energy.
2//!
3//! The isotropic measure-jet energy [`super::measure_jet_energy_form`] treats
4//! the ambient coordinates with a Euclidean local Gram: the Gaussian kernel
5//! weight is `exp(−‖δ‖²/2ε²)` and the local affine features are `δ/ε` with
6//! `δ = x_j − x_i`. This module generalizes that Euclidean inner product to a
7//! learned Mahalanobis metric
8//!
9//! ```text
10//!   A = L Lᵀ,        Ā = A / det(A)^(1/d)      (det-normalized, det Ā = 1),
11//! ```
12//!
13//! parametrized by the lower-triangular Cholesky factor `L` (d×d). The metric
14//! enters every local block through the SINGLE substitution
15//!
16//! ```text
17//!   ⟨u, v⟩  ↦  uᵀ Ā v ,
18//! ```
19//!
20//! which is realized exactly by transforming the centers once with the
21//! det-normalized factor `M = L / det(L)^(1/d)` (so `M Mᵀ = Ā`, `det M = 1`):
22//!
23//! ```text
24//!   ‖δ M‖²       = δ Ā δᵀ           (metric squared distance → kernel),
25//!   (δ/ε)M       = metric local affine features,
26//!   Y = X M      (transformed row centers; E_A(X) ≡ E_I(Y)).
27//! ```
28//!
29//! Because the local affine residual projects each block's center values onto
30//! `span{1, local affine coords}` and `M` is invertible, the projection is
31//! reparametrization-invariant: the metric reaches the energy ONLY through the
32//! kernel weights `w` and the (linearly transformed) features. With `Ā = I`
33//! (`M = I`, `Y = X`) the construction collapses to the isotropic energy
34//! bit-for-bit — that is the contract the first oracle test pins.
35//!
36//! To learn `L` by REML the energy needs exact first and second derivatives
37//! `∂E/∂L_ij`, `∂²E/∂L_ij∂L_kl`. They are produced from the SAME local block
38//! walk as the value (no second assembly that could drift from the first),
39//! by carrying, per requested `L`-direction, the exact first/second
40//! directional derivatives of every metric-dependent block quantity — the
41//! transformed features, the Gaussian weights, the weighted mean, `B`, `G`,
42//! `G⁺` and the residual — through the closed-form product/chain rules.
43//!
44//! All ∂/∂L jets are FD-gated in this module's tests against central
45//! differences of the energy (rel tol `5e-5`, step `h = 1e-4`, the
46//! second-difference-optimal step mirroring `measure_jet_smooth`'s own jet
47//! gates).
48
49use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
50
51use faer::Side;
52
53use gam_linalg::faer_ndarray::FaerEigh;
54
55use super::{BasisError, MeasureJetBand, measure_jet_energy_form};
56
/// Gaussian-profile truncation radius, expressed in multiples of the scale ε.
///
/// At a metric distance of `3ε` the kernel factor `exp(−‖δ‖²/2ε²)` has decayed
/// to `e^{-4.5}` of its peak; neighbours beyond that radius are left out of
/// both the local fit and the `q^(1−2α)` outer weight. The value matches the
/// one used by `measure_jet_smooth`.
pub(crate) const PROFILE_CUTOFF: f64 = 3.0;
62
/// Relative rank threshold used when pseudo-inverting the local affine Gram.
///
/// An eigenvalue is retained only when it exceeds
/// `PSEUDOINVERSE_RTOL · n · λ_max`, i.e. `64·ε_f64` scaled by the matrix size
/// and the largest eigenvalue. The constant is kept identical to
/// `measure_jet_smooth`'s so the `Ā = I` path stays bit-for-bit.
pub(crate) const PSEUDOINVERSE_RTOL: f64 = 64.0 * f64::EPSILON;
67
/// One first-order derivative channel in `L`-space.
///
/// It identifies the lower-triangular entry `(row, col)` of the Cholesky
/// factor, with `row >= col`. The value itself (the zeroth-order channel) is
/// not represented here; it is handled separately.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct LIndex {
    /// Row position inside the lower-triangular factor; never below `col`.
    pub row: usize,
    /// Column position inside the lower-triangular factor; never above `row`.
    pub col: usize,
}
78
79/// The anisotropic energy together with its exact first and second jets with
80/// respect to the lower-triangular Cholesky factor entries of `L`.
81///
82/// `indices[a]` names the `(row, col)` of the `a`-th active lower-triangular
83/// entry (column-major over the lower triangle: for each column `j`, rows
84/// `j..d`). `d_first[a] = ∂Q/∂L_{indices[a]}`, and `d_second[(a, b)]` (stored
85/// for the full pair grid, symmetric in `a, b`) is
86/// `∂²Q/∂L_{indices[a]}∂L_{indices[b]}`.
87pub struct MeasureJetAnisotropyJets {
88    /// The det-normalized anisotropic energy form (m×m, symmetric PSD).
89    pub q: Array2<f64>,
90    /// Active lower-triangular `L`-entry indices, in the derivative order.
91    pub indices: Vec<LIndex>,
92    /// First derivatives `∂Q/∂L_a`, one m×m form per active index.
93    pub d_first: Vec<Array2<f64>>,
94    /// Second derivatives `∂²Q/∂L_a∂L_b`, indexed by `a*n + b` over the
95    /// `n = indices.len()` active entries (full symmetric grid).
96    pub d_second: Vec<Array2<f64>>,
97}
98
99impl MeasureJetAnisotropyJets {
100    /// Number of active lower-triangular derivative channels.
101    #[inline]
102    pub fn n_active(&self) -> usize {
103        self.indices.len()
104    }
105
106    /// Borrow the second-derivative form `∂²Q/∂L_a∂L_b`.
107    #[inline]
108    pub fn second(&self, a: usize, b: usize) -> &Array2<f64> {
109        &self.d_second[a * self.indices.len() + b]
110    }
111}
112
113/// Enumerate the active lower-triangular entries of a `d×d` factor in
114/// column-major order (column `j`, rows `j..d`). This is the canonical order
115/// every output of this module uses for its `L`-derivative channels.
116pub fn lower_triangular_indices(d: usize) -> Vec<LIndex> {
117    let mut idx = Vec::with_capacity(d * (d + 1) / 2);
118    for col in 0..d {
119        for row in col..d {
120            idx.push(LIndex { row, col });
121        }
122    }
123    idx
124}
125
126// ----------------------------------------------------------------------------
127// Det-normalized factor M = L / det(L)^(1/d) and its exact L-jets.
128// ----------------------------------------------------------------------------
129
130/// The det-normalized factor `M = L · g`, `g = det(L)^(−1/d) = (∏ L_kk)^(−1/d)`,
131/// together with its first and second directional derivatives with respect to
132/// the active lower-triangular entries of `L`.
133///
134/// `det L = ∏_k L_kk` depends only on the diagonal, so `∂ ln det L / ∂L_ij`
135/// is `1/L_ii` when `i == j` and `0` otherwise. Writing `f = ln g = −(1/d)·ln
136/// det L`, `M = L·e^f`, every derivative below is the exact product rule on
137/// `L·e^f`.
138pub struct NormalizedFactor {
139    /// `M = L / det(L)^(1/d)` (d×d, lower-triangular, `det M = 1`).
140    pub(crate) m: Array2<f64>,
141    /// `∂M/∂L_a` for each active index `a` (d×d).
142    pub(crate) dm: Vec<Array2<f64>>,
143    /// `∂²M/∂L_a∂L_b` for the full pair grid `a*n+b` (d×d).
144    pub(crate) d2m: Vec<Array2<f64>>,
145}
146
147pub(crate) fn build_normalized_factor(
148    l: ArrayView2<'_, f64>,
149    indices: &[LIndex],
150) -> Result<NormalizedFactor, BasisError> {
151    let d = l.nrows();
152    if l.ncols() != d {
153        crate::bail_dim_basis!(
154            "measure-jet anisotropy needs a square lower-triangular L, got {:?}",
155            l.dim()
156        );
157    }
158    if d == 0 {
159        crate::bail_invalid_basis!("measure-jet anisotropy needs a non-empty ambient metric");
160    }
161    for k in 0..d {
162        if !(l[(k, k)].is_finite() && l[(k, k)] > 0.0) {
163            crate::bail_invalid_basis!(
164                "measure-jet anisotropy needs a positive-definite L: diagonal entry L[{k},{k}] = {} is not finite and positive",
165                l[(k, k)]
166            );
167        }
168        for c in (k + 1)..d {
169            if l[(k, c)] != 0.0 {
170                crate::bail_invalid_basis!(
171                    "measure-jet anisotropy L must be lower-triangular: upper entry L[{k},{c}] = {} is nonzero",
172                    l[(k, c)]
173                );
174            }
175            if !l[(c, k)].is_finite() {
176                crate::bail_invalid_basis!(
177                    "measure-jet anisotropy L has a non-finite entry L[{c},{k}]"
178                );
179            }
180        }
181    }
182
183    let n = indices.len();
184    let l_owned = l.to_owned();
185
186    // f = ln g = −(1/d)·Σ_k ln L_kk. Diagonal-only first/second derivatives.
187    let inv_d = 1.0 / d as f64;
188    let mut f_first = vec![0.0_f64; n];
189    // f second derivatives are diagonal in the (a == b, both the same diagonal
190    // entry) sense: ∂²f/∂L_kk² = +(1/d)/L_kk², all cross/off-diagonal zero.
191    let mut f_second = vec![0.0_f64; n * n];
192    for (a, ia) in indices.iter().enumerate() {
193        if ia.row == ia.col {
194            let lkk = l_owned[(ia.row, ia.row)];
195            f_first[a] = -inv_d / lkk;
196            f_second[a * n + a] = inv_d / (lkk * lkk);
197        }
198    }
199
200    // g = e^f, M = L·g.
201    let g = (-inv_d * {
202        let mut s = 0.0;
203        for k in 0..d {
204            s += l_owned[(k, k)].ln();
205        }
206        s
207    })
208    .exp();
209
210    // g first/second derivatives via the chain rule on e^f:
211    //   g_a   = g·f_a,
212    //   g_ab  = g·(f_a·f_b + f_ab).
213    let mut g_first = vec![0.0_f64; n];
214    let mut g_second = vec![0.0_f64; n * n];
215    for a in 0..n {
216        g_first[a] = g * f_first[a];
217    }
218    for a in 0..n {
219        for b in 0..n {
220            g_second[a * n + b] = g * (f_first[a] * f_first[b] + f_second[a * n + b]);
221        }
222    }
223
224    // E_a = ∂L/∂L_a : the single-entry indicator matrix.
225    // M = L·g  ⇒  M_a = E_a·g + L·g_a,
226    //              M_ab = E_a·g_b + E_b·g_a + L·g_ab.
227    let m = &l_owned * g;
228    let mut dm = Vec::with_capacity(n);
229    for a in 0..n {
230        let ia = indices[a];
231        let mut ma = &l_owned * g_first[a];
232        ma[(ia.row, ia.col)] += g;
233        dm.push(ma);
234    }
235    let mut d2m = Vec::with_capacity(n * n);
236    for a in 0..n {
237        let ia = indices[a];
238        for b in 0..n {
239            let ib = indices[b];
240            let mut mab = &l_owned * g_second[a * n + b];
241            mab[(ia.row, ia.col)] += g_first[b];
242            mab[(ib.row, ib.col)] += g_first[a];
243            d2m.push(mab);
244        }
245    }
246
247    Ok(NormalizedFactor { m, dm, d2m })
248}
249
250// ----------------------------------------------------------------------------
251// Per-block algebra and its exact L-jets.
252// ----------------------------------------------------------------------------
253
254/// Squared metric distances `δĀδᵀ = ‖δ M‖²` for every center pair, plus the
255/// active first/second `L`-directional derivatives of each. Used for the ε/2
256/// outer net, the neighbor cutoff, and the kernel exponent — exactly the role
257/// `pairwise_sq_dists` plays in the isotropic assembly.
258pub(crate) struct MetricDist2 {
259    /// `dM2[(i, j)] = ‖(x_i − x_j) M‖²`.
260    pub(crate) dm2: Array2<f64>,
261}
262
263pub(crate) fn metric_sq_dists(centers: ArrayView2<'_, f64>, m: ArrayView2<'_, f64>) -> MetricDist2 {
264    let n = centers.nrows();
265    // Y = X M ; ‖δ M‖² = ‖Y_i − Y_j‖². Build Y once, then GEMM-style Gram with
266    // the same `‖a‖²+‖b‖²−2aᵀb`, clamped at 0 (mirrors pairwise_sq_dists so the
267    // `M = I` path lands identically).
268    let y = centers.dot(&m);
269    let yn: Vec<f64> = y.outer_iter().map(|r| r.dot(&r)).collect();
270    let g = y.dot(&y.t());
271    let mut dm2 = Array2::<f64>::zeros((n, n));
272    for i in 0..n {
273        for j in 0..n {
274            dm2[(i, j)] = (yn[i] + yn[j] - 2.0 * g[(i, j)]).max(0.0);
275        }
276    }
277    MetricDist2 { dm2 }
278}
279
280/// Symmetric pseudo-inverse via eigendecomposition with the same rank cutoff
281/// as `measure_jet_smooth::symmetric_pseudoinverse` (so `Ā = I` is bit-exact),
282/// additionally returning the eigenpairs so the projector's `L`-derivatives can
283/// be propagated through `G⁺` analytically.
284pub(crate) struct EighPinv {
285    pub(crate) evals: Array1<f64>,
286    pub(crate) evecs: Array2<f64>,
287    /// Per-mode inverse eigenvalue (0 below the rank cutoff).
288    pub(crate) inv: Array1<f64>,
289    pub(crate) pinv: Array2<f64>,
290}
291
292pub(crate) fn eigh_pinv(a: &Array2<f64>, label: &str) -> Result<EighPinv, BasisError> {
293    let n = a.nrows();
294    let (evals, evecs) = a.eigh(Side::Lower).map_err(|e| {
295        BasisError::InvalidInput(format!(
296            "measure-jet anisotropy pseudo-inverse `{label}` eigendecomposition failed: {e}"
297        ))
298    })?;
299    let lam_max = evals.iter().fold(0.0_f64, |acc, v| acc.max((*v).max(0.0)));
300    let rank_tol = PSEUDOINVERSE_RTOL * (n.max(1) as f64) * lam_max;
301    let mut inv = Array1::<f64>::zeros(n);
302    let mut scaled = evecs.clone();
303    for k in 0..n {
304        let lam = evals[k].max(0.0);
305        let iv = if lam > rank_tol { 1.0 / lam } else { 0.0 };
306        inv[k] = iv;
307        let mut col = scaled.column_mut(k);
308        col.mapv_inplace(|v| v * iv);
309    }
310    let pinv = scaled.dot(&evecs.t());
311    Ok(EighPinv {
312        evals,
313        evecs,
314        inv,
315        pinv,
316    })
317}
318
319/// The exact `L`-derivative of `G⁺` in a single direction, given `G`'s
320/// eigenpairs and the derivative `Ġ` of `G` in that direction:
321///
322/// ```text
323///   d(G⁺) = Σ_{p,q}  c_{pq} · v_p (v_pᵀ Ġ v_q) v_qᵀ ,
324/// ```
325///
326/// the standard pseudo-inverse perturbation on the retained (full-rank) modes
327/// with
328///
329/// ```text
330///   c_{pq} = −1/(λ_p λ_q)      if both p, q retained,
331///          =  1/λ_p²·…         range/null cross terms,
332/// ```
333///
334/// Here the local Gram is at most rank `d` and the retained block is exactly
335/// the numerical range; on that range `G⁺ = G_r⁻¹`, so the formula reduces to
336/// the symmetric-inverse perturbation `−G⁺ Ġ G⁺` PLUS the two range↔null cross
337/// corrections `P⊥ Ġ G⁺ + G⁺ Ġ P⊥` divided by the retained eigenvalues, which
338/// the eigen-mode sum below captures exactly. We assemble it directly in the
339/// eigenbasis to stay exact across the rank boundary.
340pub(crate) fn pinv_first_deriv(ep: &EighPinv, gdot: &Array2<f64>) -> Array2<f64> {
341    let n = ep.evals.len();
342    let vt_g = ep.evecs.t().dot(gdot);
343    let mhat = vt_g.dot(&ep.evecs); // (n×n) in eigen coords
344    let mut core = Array2::<f64>::zeros((n, n));
345    for p in 0..n {
346        for q in 0..n {
347            core[(p, q)] = pinv_div1(ep, p, q) * mhat[(p, q)];
348        }
349    }
350    ep.evecs.dot(&core).dot(&ep.evecs.t())
351}
352
353#[inline]
354pub(crate) fn pinv_active(ep: &EighPinv, i: usize) -> bool {
355    ep.inv[i] != 0.0
356}
357
358#[inline]
359pub(crate) fn pinv_value(ep: &EighPinv, i: usize) -> f64 {
360    if pinv_active(ep, i) { ep.inv[i] } else { 0.0 }
361}
362
363#[inline]
364pub(crate) fn pinv_prime(ep: &EighPinv, i: usize) -> f64 {
365    if pinv_active(ep, i) {
366        -ep.inv[i] * ep.inv[i]
367    } else {
368        0.0
369    }
370}
371
372#[inline]
373pub(crate) fn pinv_half_second(ep: &EighPinv, i: usize) -> f64 {
374    if pinv_active(ep, i) {
375        ep.inv[i] * ep.inv[i] * ep.inv[i]
376    } else {
377        0.0
378    }
379}
380
381pub(crate) fn pinv_div1(ep: &EighPinv, i: usize, j: usize) -> f64 {
382    if i == j {
383        return pinv_prime(ep, i);
384    }
385    let li = ep.evals[i];
386    let lj = ep.evals[j];
387    let denom = li - lj;
388    let scale = li.abs().max(lj.abs()).max(1.0);
389    if denom.abs() <= 16.0 * f64::EPSILON * scale {
390        if pinv_active(ep, i) == pinv_active(ep, j) {
391            0.5 * (pinv_prime(ep, i) + pinv_prime(ep, j))
392        } else {
393            0.0
394        }
395    } else {
396        (pinv_value(ep, i) - pinv_value(ep, j)) / denom
397    }
398}
399
400pub(crate) fn pinv_div2(ep: &EighPinv, i: usize, k: usize, j: usize) -> f64 {
401    if i == k && k == j {
402        return pinv_half_second(ep, i);
403    }
404    let li = ep.evals[i];
405    let lk = ep.evals[k];
406    let lj = ep.evals[j];
407    if i == j {
408        let h = lk - li;
409        let scale = li.abs().max(lk.abs()).max(1.0);
410        if h.abs() <= 16.0 * f64::EPSILON * scale {
411            return pinv_half_second(ep, i);
412        }
413        return (pinv_value(ep, k) - pinv_value(ep, i) - pinv_prime(ep, i) * h) / (h * h);
414    }
415    if i == k {
416        let denom = li - lj;
417        let scale = li.abs().max(lj.abs()).max(1.0);
418        if denom.abs() <= 16.0 * f64::EPSILON * scale {
419            return pinv_half_second(ep, i);
420        }
421        return (pinv_prime(ep, i) - pinv_div1(ep, i, j)) / denom;
422    }
423    if k == j {
424        let denom = li - lj;
425        let scale = li.abs().max(lj.abs()).max(1.0);
426        if denom.abs() <= 16.0 * f64::EPSILON * scale {
427            return pinv_half_second(ep, j);
428        }
429        return (pinv_div1(ep, i, j) - pinv_prime(ep, j)) / denom;
430    }
431    let denom = li - lj;
432    let scale = li.abs().max(lj.abs()).max(1.0);
433    if denom.abs() <= 16.0 * f64::EPSILON * scale {
434        let h = lk - li;
435        if h.abs() <= 16.0 * f64::EPSILON * scale {
436            pinv_half_second(ep, i)
437        } else {
438            (pinv_value(ep, k) - pinv_value(ep, i) - pinv_prime(ep, i) * h) / (h * h)
439        }
440    } else {
441        (pinv_div1(ep, i, k) - pinv_div1(ep, k, j)) / denom
442    }
443}
444
445pub(crate) fn pinv_second_deriv(
446    ep: &EighPinv,
447    gx: &Array2<f64>,
448    gy: &Array2<f64>,
449    gxy: &Array2<f64>,
450) -> Array2<f64> {
451    let n = ep.evals.len();
452    let gx_hat = ep.evecs.t().dot(gx).dot(&ep.evecs);
453    let gy_hat = ep.evecs.t().dot(gy).dot(&ep.evecs);
454    let gxy_hat = ep.evecs.t().dot(gxy).dot(&ep.evecs);
455    let mut core = Array2::<f64>::zeros((n, n));
456    for i in 0..n {
457        for j in 0..n {
458            let mut value = pinv_div1(ep, i, j) * gxy_hat[(i, j)];
459            for k in 0..n {
460                value += pinv_div2(ep, i, k, j)
461                    * (gx_hat[(i, k)] * gy_hat[(k, j)] + gy_hat[(i, k)] * gx_hat[(k, j)]);
462            }
463            core[(i, j)] = value;
464        }
465    }
466    ep.evecs.dot(&core).dot(&ep.evecs.t())
467}
468
469/// Outputs of one local block's residual `R` and its requested `L`-jets, all
470/// scattered into the energy forms with the SAME outer weight `base`. The
471/// metric only changes the kernel weights and the linearly transformed
472/// features; the projection algebra (`a_mean`, `B`, `G`, `G⁺`, `R`) is the
473/// isotropic one, differentiated through those two metric channels.
474pub(crate) struct BlockForms {
475    /// `R` value (ml×ml) before the outer weight.
476    pub(crate) r: Array2<f64>,
477    /// `∂R/∂L_a` (ml×ml).
478    pub(crate) dr: Vec<Array2<f64>>,
479    /// `∂²R/∂L_a∂L_b` (ml×ml), full pair grid `a*n+b`.
480    pub(crate) d2r: Vec<Array2<f64>>,
481    /// Kernel mass `q = Σ_a w_a` before the outer density exponent.
482    pub(crate) q: f64,
483    /// `∂q/∂L_a`.
484    pub(crate) dq: Vec<f64>,
485    /// `∂²q/∂L_a∂L_b`, full pair grid `a*n+b`.
486    pub(crate) d2q: Vec<f64>,
487}
488
489/// Assemble one local block's residual `R = CᵀWC − B G⁺ Bᵀ / q` and its exact
490/// first/second `L`-jets. `phi[a,k] = δ_{a,k}/ε`, `w[a] = mass·exp(−‖φ_a M‖²/2)`.
491/// `dpsi`/`d2psi` are the directional derivatives of `φ M` (i.e. `φ Ṁ`,
492/// `φ M̈`). This is the metric generalization of the inner loop in
493/// `measure_jet_smooth::assemble_weighted_forms`, with value and jets sharing
494/// one walk so a value↔derivative desync is structurally impossible.
495pub(crate) fn block_residual_jets(
496    phi: &Array2<f64>,          // ml×d : δ/ε (metric-free local features)
497    masses_local: &Array1<f64>, // ml
498    m: ArrayView2<'_, f64>,     // d×d : M
499    dm: &[Array2<f64>],         // n × (d×d) : ∂M/∂L_a
500    d2m: &[Array2<f64>],        // n² × (d×d) : ∂²M/∂L_a∂L_b
501    n_active: usize,
502) -> BlockForms {
503    let ml = phi.nrows();
504    let n = n_active;
505
506    // Transformed row features psi = phi·M (ml×d) and its L-derivatives.
507    let psi = phi.dot(&m);
508    let mut dpsi: Vec<Array2<f64>> = Vec::with_capacity(n);
509    for a in 0..n {
510        dpsi.push(phi.dot(&dm[a]));
511    }
512    let mut d2psi: Vec<Array2<f64>> = Vec::with_capacity(n * n);
513    for a in 0..n {
514        for b in 0..n {
515            d2psi.push(phi.dot(&d2m[a * n + b]));
516        }
517    }
518
519    // Kernel weights w[a] = mass·exp(−½‖psi_a‖²) and L-derivatives.
520    //   e_a       = −½‖psi_a‖²
521    //   de/dL_x   = −psi_a·dpsi^x_a
522    //   d²e       = −(dpsi^x_a·dpsi^y_a + psi_a·d2psi^{xy}_a)
523    //   w = mass·exp(e); w_x = w·e_x; w_xy = w·(e_x·e_y + e_xy)
524    let mut w = Array1::<f64>::zeros(ml);
525    let mut dw: Vec<Array1<f64>> = (0..n).map(|_| Array1::<f64>::zeros(ml)).collect();
526    let mut d2w: Vec<Array1<f64>> = (0..n * n).map(|_| Array1::<f64>::zeros(ml)).collect();
527    for a in 0..ml {
528        let psi_a = psi.row(a);
529        let e = -0.5 * psi_a.dot(&psi_a);
530        let wa = masses_local[a] * e.exp();
531        w[a] = wa;
532        // First-order energy-exponent derivatives.
533        let mut ex = vec![0.0_f64; n];
534        for x in 0..n {
535            ex[x] = -psi_a.dot(&dpsi[x].row(a));
536            dw[x][a] = wa * ex[x];
537        }
538        // Second-order.
539        for x in 0..n {
540            for y in 0..n {
541                let dpx = dpsi[x].row(a);
542                let dpy = dpsi[y].row(a);
543                let d2p = d2psi[x * n + y].row(a);
544                let exy = -(dpx.dot(&dpy) + psi_a.dot(&d2p));
545                d2w[x * n + y][a] = wa * (ex[x] * ex[y] + exy);
546            }
547        }
548    }
549
550    // From here the algebra mirrors the isotropic block, with psi as the
551    // features and w as the weights — both metric-dependent — propagated by
552    // the product rule across the four bilinear pieces.
553    //   q     = Σ w
554    //   a_mean= Φᵀw / q             (Φ ≡ psi here)
555    //   B     = WΦ − w·a_meanᵀ
556    //   G     = (ΦᵀWΦ)/q − a_mean·a_meanᵀ
557    //   R     = CᵀWC − B G⁺ Bᵀ / q ,  CᵀWC = W − w·wᵀ/q  (diagonal W).
558    //
559    // We assemble value + first + second jets of every intermediate in lock
560    // step. To keep the code linear we build, for the (value, {x}, {x,y})
561    // levels, each quantity; products use Leibniz.
562
563    let d = phi.ncols();
564
565    // q and jets.
566    let q = w.sum();
567    let mut dq = vec![0.0_f64; n];
568    let mut d2q = vec![0.0_f64; n * n];
569    for x in 0..n {
570        dq[x] = dw[x].sum();
571    }
572    for x in 0..n {
573        for y in 0..n {
574            d2q[x * n + y] = d2w[x * n + y].sum();
575        }
576    }
577
578    // Φᵀ w  (length-d vector p) and jets:  p = Σ_a w_a · psi_a.
579    // Build p, dp (n×d), d2p (n²×d).
580    let mut pvec = Array1::<f64>::zeros(d);
581    for a in 0..ml {
582        for k in 0..d {
583            pvec[k] += w[a] * psi[(a, k)];
584        }
585    }
586    let mut dpvec: Vec<Array1<f64>> = (0..n).map(|_| Array1::<f64>::zeros(d)).collect();
587    for x in 0..n {
588        for a in 0..ml {
589            for k in 0..d {
590                dpvec[x][k] += dw[x][a] * psi[(a, k)] + w[a] * dpsi[x][(a, k)];
591            }
592        }
593    }
594    let mut d2pvec: Vec<Array1<f64>> = (0..n * n).map(|_| Array1::<f64>::zeros(d)).collect();
595    for x in 0..n {
596        for y in 0..n {
597            let dst = &mut d2pvec[x * n + y];
598            for a in 0..ml {
599                for k in 0..d {
600                    dst[k] += d2w[x * n + y][a] * psi[(a, k)]
601                        + dw[x][a] * dpsi[y][(a, k)]
602                        + dw[y][a] * dpsi[x][(a, k)]
603                        + w[a] * d2psi[x * n + y][(a, k)];
604                }
605            }
606        }
607    }
608
609    // a_mean = p / q  (quotient rule).
610    let amean = &pvec / q;
611    let mut damean: Vec<Array1<f64>> = Vec::with_capacity(n);
612    for x in 0..n {
613        damean.push((&dpvec[x] - &(&amean * dq[x])) / q);
614    }
615    let mut d2amean: Vec<Array1<f64>> = Vec::with_capacity(n * n);
616    for x in 0..n {
617        for y in 0..n {
618            // (p/q)'' = p''/q − (p'·q' + p''_cross...) ; use explicit quotient.
619            // d²(p/q) = p_xy/q − (p_x q_y + p_y q_x + p q_xy)/q² + 2 p q_x q_y / q³
620            let term = (&d2pvec[x * n + y]) / q
621                - (&(&dpvec[x] * dq[y]) + &(&dpvec[y] * dq[x]) + &(&pvec * d2q[x * n + y]))
622                    / (q * q)
623                + &(&pvec * (2.0 * dq[x] * dq[y] / (q * q * q)));
624            d2amean.push(term);
625        }
626    }
627
628    // B = WΦ − w·a_meanᵀ  (ml×d):  B[a,k] = w_a·psi[a,k] − w_a·amean[k].
629    let bmat = |wv: &Array1<f64>, psiv: &Array2<f64>, am: &Array1<f64>| -> Array2<f64> {
630        let mut bb = Array2::<f64>::zeros((ml, d));
631        for a in 0..ml {
632            for k in 0..d {
633                bb[(a, k)] = wv[a] * (psiv[(a, k)] - am[k]);
634            }
635        }
636        bb
637    };
638    let b = bmat(&w, &psi, &amean);
639    // dB[x][a,k] = dw·(psi − am) + w·(dpsi − dam)
640    let mut db: Vec<Array2<f64>> = Vec::with_capacity(n);
641    for x in 0..n {
642        let mut bb = Array2::<f64>::zeros((ml, d));
643        for a in 0..ml {
644            for k in 0..d {
645                bb[(a, k)] =
646                    dw[x][a] * (psi[(a, k)] - amean[k]) + w[a] * (dpsi[x][(a, k)] - damean[x][k]);
647            }
648        }
649        db.push(bb);
650    }
651    // d²B[x,y][a,k] = d2w·(psi−am) + dw_x·(dpsi_y−dam_y) + dw_y·(dpsi_x−dam_x)
652    //                + w·(d2psi − d2am)
653    let mut d2b: Vec<Array2<f64>> = Vec::with_capacity(n * n);
654    for x in 0..n {
655        for y in 0..n {
656            let mut bb = Array2::<f64>::zeros((ml, d));
657            for a in 0..ml {
658                for k in 0..d {
659                    bb[(a, k)] = d2w[x * n + y][a] * (psi[(a, k)] - amean[k])
660                        + dw[x][a] * (dpsi[y][(a, k)] - damean[y][k])
661                        + dw[y][a] * (dpsi[x][(a, k)] - damean[x][k])
662                        + w[a] * (d2psi[x * n + y][(a, k)] - d2amean[x * n + y][k]);
663                }
664            }
665            d2b.push(bb);
666        }
667    }
668
669    // H = ΦᵀWΦ  (d×d):  H[r,c] = Σ_a w_a·psi[a,r]·psi[a,c].
670    let hmat = |wv: &Array1<f64>, psiv: &Array2<f64>| -> Array2<f64> {
671        let mut hh = Array2::<f64>::zeros((d, d));
672        for a in 0..ml {
673            for r in 0..d {
674                for c in 0..d {
675                    hh[(r, c)] += wv[a] * psiv[(a, r)] * psiv[(a, c)];
676                }
677            }
678        }
679        hh
680    };
681    let hh = hmat(&w, &psi);
682    let mut dhh: Vec<Array2<f64>> = Vec::with_capacity(n);
683    for x in 0..n {
684        let mut hd = Array2::<f64>::zeros((d, d));
685        for a in 0..ml {
686            for r in 0..d {
687                for c in 0..d {
688                    hd[(r, c)] += dw[x][a] * psi[(a, r)] * psi[(a, c)]
689                        + w[a] * dpsi[x][(a, r)] * psi[(a, c)]
690                        + w[a] * psi[(a, r)] * dpsi[x][(a, c)];
691                }
692            }
693        }
694        dhh.push(hd);
695    }
696    let mut d2hh: Vec<Array2<f64>> = Vec::with_capacity(n * n);
697    for x in 0..n {
698        for y in 0..n {
699            let mut hd = Array2::<f64>::zeros((d, d));
700            for a in 0..ml {
701                for r in 0..d {
702                    for c in 0..d {
703                        let pr = psi[(a, r)];
704                        let pc = psi[(a, c)];
705                        let dprx = dpsi[x][(a, r)];
706                        let dpcx = dpsi[x][(a, c)];
707                        let dpry = dpsi[y][(a, r)];
708                        let dpcy = dpsi[y][(a, c)];
709                        let d2pr = d2psi[x * n + y][(a, r)];
710                        let d2pc = d2psi[x * n + y][(a, c)];
711                        hd[(r, c)] += d2w[x * n + y][a] * pr * pc
712                            + dw[x][a] * (dpry * pc + pr * dpcy)
713                            + dw[y][a] * (dprx * pc + pr * dpcx)
714                            + w[a] * (d2pr * pc + dprx * dpcy + dpry * dpcx + pr * d2pc);
715                    }
716                }
717            }
718            d2hh.push(hd);
719        }
720    }
721
722    // G = H/q − a_mean·a_meanᵀ. Build G and jets.
723    let outer = |u: &Array1<f64>, v: &Array1<f64>| -> Array2<f64> {
724        let mut o = Array2::<f64>::zeros((d, d));
725        for r in 0..d {
726            for c in 0..d {
727                o[(r, c)] = u[r] * v[c];
728            }
729        }
730        o
731    };
732    let g = &(&hh / q) - &outer(&amean, &amean);
733    let mut dg: Vec<Array2<f64>> = Vec::with_capacity(n);
734    for x in 0..n {
735        // d(H/q) = dH/q − H·dq/q²
736        let dhq = &(&dhh[x] / q) - &(&hh * (dq[x] / (q * q)));
737        // d(am amᵀ) = dam·amᵀ + am·damᵀ
738        let dout = &outer(&damean[x], &amean) + &outer(&amean, &damean[x]);
739        dg.push(&dhq - &dout);
740    }
741    let mut d2g: Vec<Array2<f64>> = Vec::with_capacity(n * n);
742    for x in 0..n {
743        for y in 0..n {
744            // d²(H/q) = d2H/q − (dH_x q_y + dH_y q_x + H q_xy)/q² + 2 H q_x q_y/q³
745            let d2hq = &(&d2hh[x * n + y] / q)
746                - &(&(&dhh[x] * (dq[y] / (q * q)))
747                    + &(&dhh[y] * (dq[x] / (q * q)))
748                    + &(&hh * (d2q[x * n + y] / (q * q))))
749                + &(&hh * (2.0 * dq[x] * dq[y] / (q * q * q)));
750            // d²(am amᵀ) = d2am·amᵀ + dam_x·dam_yᵀ + dam_y·dam_xᵀ + am·d2amᵀ
751            let d2out = &outer(&d2amean[x * n + y], &amean)
752                + &outer(&damean[x], &damean[y])
753                + &outer(&damean[y], &damean[x])
754                + &outer(&amean, &d2amean[x * n + y]);
755            d2g.push(&d2hq - &d2out);
756        }
757    }
758
759    // G⁺ and jets (eigen-perturbation).
760    let ep = eigh_pinv(&g, "local affine Gram").unwrap_or_else(|_| {
761        // A degenerate eigensolve here means the block geometry is singular to
762        // machine precision; fall back to a zero projector (the residual then
763        // reduces to CᵀWC), keeping the value finite. The isotropic path
764        // never hits this on well-posed center sets.
765        EighPinv {
766            evals: Array1::zeros(d),
767            evecs: Array2::eye(d),
768            inv: Array1::zeros(d),
769            pinv: Array2::zeros((d, d)),
770        }
771    });
772    let gpinv = ep.pinv.clone();
773    let mut dgpinv: Vec<Array2<f64>> = Vec::with_capacity(n);
774    for x in 0..n {
775        dgpinv.push(pinv_first_deriv(&ep, &dg[x]));
776    }
777    // Second derivative of G⁺ as a fixed-rank spectral matrix function:
778    // K_xy = DK[G][G_xy] + D²K[G][G_x, G_y]. The divided-difference formulas
779    // include retained-range, inactive-range, and cross terms without assuming
780    // an inverse on a frozen range block.
781    let mut d2gpinv: Vec<Array2<f64>> = Vec::with_capacity(n * n);
782    for x in 0..n {
783        for y in 0..n {
784            d2gpinv.push(pinv_second_deriv(&ep, &dg[x], &dg[y], &d2g[x * n + y]));
785        }
786    }
787
788    // P = B G⁺ Bᵀ (ml×ml) and jets, then R = CᵀWC − P/q.
789    // CᵀWC = diag(w) − w wᵀ/q.
790    let triple = |bb: &Array2<f64>, gp: &Array2<f64>| -> Array2<f64> { bb.dot(gp).dot(&bb.t()) };
791    let p = triple(&b, &gpinv);
792    let mut dp: Vec<Array2<f64>> = Vec::with_capacity(n);
793    for x in 0..n {
794        // d(B G⁺ Bᵀ) = dB G⁺ Bᵀ + B dG⁺ Bᵀ + B G⁺ dBᵀ
795        let t1 = db[x].dot(&gpinv).dot(&b.t());
796        let t2 = b.dot(&dgpinv[x]).dot(&b.t());
797        let t3 = b.dot(&gpinv).dot(&db[x].t());
798        dp.push(&(&t1 + &t2) + &t3);
799    }
800    let mut d2p: Vec<Array2<f64>> = Vec::with_capacity(n * n);
801    for x in 0..n {
802        for y in 0..n {
803            // Full Leibniz over the three factors (B, G⁺, Bᵀ).
804            let bx = &db[x];
805            let by = &db[y];
806            let bxy = &d2b[x * n + y];
807            let gx = &dgpinv[x];
808            let gy = &dgpinv[y];
809            let gxy = &d2gpinv[x * n + y];
810            let mut acc = bxy.dot(&gpinv).dot(&b.t());
811            acc += &bx.dot(gy).dot(&b.t());
812            acc += &bx.dot(&gpinv).dot(&by.t());
813            acc += &by.dot(gx).dot(&b.t());
814            acc += &b.dot(gxy).dot(&b.t());
815            acc += &b.dot(gx).dot(&by.t());
816            acc += &by.dot(&gpinv).dot(&bx.t());
817            acc += &b.dot(gy).dot(&bx.t());
818            acc += &b.dot(&gpinv).dot(&bxy.t());
819            d2p.push(acc);
820        }
821    }
822
823    // R = diag(w) − w wᵀ/q − P/q.
824    let assemble_r = |wv: &Array1<f64>, qv: f64, pv: &Array2<f64>| -> Array2<f64> {
825        let mut rr = Array2::<f64>::zeros((ml, ml));
826        for a in 0..ml {
827            for c in 0..ml {
828                rr[(a, c)] = -wv[a] * wv[c] / qv - pv[(a, c)] / qv;
829            }
830            rr[(a, a)] += wv[a];
831        }
832        rr
833    };
834    let r = assemble_r(&w, q, &p);
835
836    // dR = diag(dw) − d(w wᵀ/q) − d(P/q).
837    //   d(w wᵀ/q) = (dw wᵀ + w dwᵀ)/q − w wᵀ dq/q²
838    //   d(P/q)    = dP/q − P dq/q²
839    let mut dr: Vec<Array2<f64>> = Vec::with_capacity(n);
840    for x in 0..n {
841        let mut rr = Array2::<f64>::zeros((ml, ml));
842        for a in 0..ml {
843            for c in 0..ml {
844                let wwt_d = (dw[x][a] * w[c] + w[a] * dw[x][c]) / q - w[a] * w[c] * dq[x] / (q * q);
845                let pd = dp[x][(a, c)] / q - p[(a, c)] * dq[x] / (q * q);
846                rr[(a, c)] = -wwt_d - pd;
847            }
848            rr[(a, a)] += dw[x][a];
849        }
850        dr.push(rr);
851    }
852
853    // d²R similarly, full product rule on each 1/q-scaled bilinear.
854    let mut d2r: Vec<Array2<f64>> = Vec::with_capacity(n * n);
855    for x in 0..n {
856        for y in 0..n {
857            let qx = dq[x];
858            let qy = dq[y];
859            let qxy = d2q[x * n + y];
860            let mut rr = Array2::<f64>::zeros((ml, ml));
861            for a in 0..ml {
862                for c in 0..ml {
863                    // w wᵀ / q second derivative.
864                    let num = w[a] * w[c];
865                    let num_x = dw[x][a] * w[c] + w[a] * dw[x][c];
866                    let num_y = dw[y][a] * w[c] + w[a] * dw[y][c];
867                    let num_xy = d2w[x * n + y][a] * w[c]
868                        + dw[x][a] * dw[y][c]
869                        + dw[y][a] * dw[x][c]
870                        + w[a] * d2w[x * n + y][c];
871                    let wwt_d2 = num_xy / q - (num_x * qy + num_y * qx + num * qxy) / (q * q)
872                        + 2.0 * num * qx * qy / (q * q * q);
873                    // P / q second derivative.
874                    let pn = p[(a, c)];
875                    let pnx = dp[x][(a, c)];
876                    let pny = dp[y][(a, c)];
877                    let pnxy = d2p[x * n + y][(a, c)];
878                    let p_d2 = pnxy / q - (pnx * qy + pny * qx + pn * qxy) / (q * q)
879                        + 2.0 * pn * qx * qy / (q * q * q);
880                    rr[(a, c)] = -wwt_d2 - p_d2;
881                }
882                rr[(a, a)] += d2w[x * n + y][a];
883            }
884            d2r.push(rr);
885        }
886    }
887
888    BlockForms {
889        r,
890        dr,
891        d2r,
892        q,
893        dq,
894        d2q,
895    }
896}
897
898// ----------------------------------------------------------------------------
899// Top-level energy and L-jets.
900// ----------------------------------------------------------------------------
901
902/// The det-normalized anisotropic measure-jet energy form `Q` for the metric
903/// `A = L Lᵀ`. With `L = I` this returns the isotropic
904/// [`super::measure_jet_energy_form`] bit-for-bit.
905pub fn measure_jet_anisotropy_energy_form(
906    centers: ArrayView2<'_, f64>,
907    masses: ArrayView1<'_, f64>,
908    band: &MeasureJetBand,
909    order_s: f64,
910    alpha: f64,
911    l: ArrayView2<'_, f64>,
912) -> Result<Array2<f64>, BasisError> {
913    // The anisotropic energy is EXACTLY the isotropic energy on the
914    // metric-transformed centers `Y = X·M` (module header: `E_A(X) ≡ E_I(Y)`):
915    // every metric-dependent quantity — the kernel distances `‖δM‖²`, the local
916    // affine features `(δ/ε)M`, the ε/2-net, the neighbor cutoff and the
917    // residual algebra — is the isotropic one evaluated on `Y`. Computing the
918    // value by that single substitution (rather than re-deriving it through the
919    // metric block walk) keeps it bit-for-bit identical to the isotropic form at
920    // `M = I` and routes it through the SAME PSD projection, instead of an
921    // operation-reordered re-assembly that drifts by round-off.
922    let d = centers.ncols();
923    if l.nrows() != d || l.ncols() != d {
924        crate::bail_dim_basis!(
925            "measure-jet anisotropy metric L must be {d}×{d} to match the ambient dimension, got {:?}",
926            l.dim()
927        );
928    }
929    let indices = lower_triangular_indices(d);
930    let nf = build_normalized_factor(l, &indices)?;
931    let y = centers.dot(&nf.m);
932    measure_jet_energy_form(y.view(), masses, band, order_s, alpha, 0.0)
933}
934
935/// The det-normalized anisotropic energy together with its EXACT first and
936/// second derivatives with respect to the lower-triangular Cholesky factor
937/// entries of `L`. Value and jets come from one block walk so they cannot
938/// drift apart.
939pub fn measure_jet_anisotropy_energy_form_with_jets(
940    centers: ArrayView2<'_, f64>,
941    masses: ArrayView1<'_, f64>,
942    band: &MeasureJetBand,
943    order_s: f64,
944    alpha: f64,
945    l: ArrayView2<'_, f64>,
946) -> Result<MeasureJetAnisotropyJets, BasisError> {
947    let m_centers = centers.nrows();
948    let d = centers.ncols();
949    if l.nrows() != d || l.ncols() != d {
950        crate::bail_dim_basis!(
951            "measure-jet anisotropy metric L must be {d}×{d} to match the ambient dimension, got {:?}",
952            l.dim()
953        );
954    }
955    if masses.len() != m_centers {
956        crate::bail_dim_basis!(
957            "measure-jet anisotropy mass/center mismatch: {} masses for {} centers",
958            masses.len(),
959            m_centers
960        );
961    }
962    if band.eps.is_empty() || band.eps.iter().any(|e| !(e.is_finite() && *e > 0.0)) {
963        crate::bail_invalid_basis!("measure-jet anisotropy needs a nonempty positive scale band");
964    }
965    if !(order_s.is_finite() && order_s > 0.0 && order_s < 2.0) {
966        crate::bail_invalid_basis!(
967            "measure-jet order s must lie in (0, 2) for the affine-jet energy; got {order_s}"
968        );
969    }
970    if !alpha.is_finite() {
971        crate::bail_invalid_basis!("measure-jet anisotropy needs a finite alpha; got {alpha}");
972    }
973    if masses.iter().any(|v| !(v.is_finite() && *v >= 0.0)) {
974        crate::bail_invalid_basis!("measure-jet anisotropy needs finite nonnegative center masses");
975    }
976
977    let indices = lower_triangular_indices(d);
978    let n = indices.len();
979    let nf = build_normalized_factor(l, &indices)?;
980
981    // Metric distances for the ε/2-net, neighbor cutoff and kernel exponent.
982    let md = metric_sq_dists(centers, nf.m.view());
983
984    let mut d_first: Vec<Array2<f64>> = (0..n)
985        .map(|_| Array2::<f64>::zeros((m_centers, m_centers)))
986        .collect();
987    let mut d_second: Vec<Array2<f64>> = (0..n * n)
988        .map(|_| Array2::<f64>::zeros((m_centers, m_centers)))
989        .collect();
990
991    for &eps in &band.eps {
992        let cutoff2 = (PROFILE_CUTOFF * eps) * (PROFILE_CUTOFF * eps);
993        let intrinsic_dim = d as f64;
994        let eta = 2.0 * order_s + intrinsic_dim * (2.0 - 2.0 * alpha);
995        let scale_weight = band.log_step * eps.powf(-eta);
996        let net_radius2 = 0.25 * eps * eps;
997
998        // Greedy ε/2-net over the metric distances, mass aggregated to nearest
999        // member (lowest-index tie break) — identical policy to the isotropic
1000        // assembly, applied in the metric geometry.
1001        let mut outer: Vec<usize> = Vec::new();
1002        for i in 0..m_centers {
1003            if masses[i] <= 0.0 {
1004                continue;
1005            }
1006            let covered = outer.iter().any(|&o| md.dm2[(i, o)] <= net_radius2);
1007            if !covered {
1008                outer.push(i);
1009            }
1010        }
1011        let mut net_mass = vec![0.0_f64; m_centers];
1012        for i in 0..m_centers {
1013            if masses[i] <= 0.0 {
1014                continue;
1015            }
1016            let mut best = f64::INFINITY;
1017            let mut best_o = usize::MAX;
1018            for &o in &outer {
1019                if md.dm2[(i, o)] < best {
1020                    best = md.dm2[(i, o)];
1021                    best_o = o;
1022                }
1023            }
1024            if best_o != usize::MAX {
1025                net_mass[best_o] += masses[i];
1026            }
1027        }
1028
1029        for &i in &outer {
1030            let mut idx: Vec<usize> = Vec::new();
1031            for j in 0..m_centers {
1032                if md.dm2[(i, j)] <= cutoff2 {
1033                    idx.push(j);
1034                }
1035            }
1036            let ml = idx.len();
1037            // Metric-free local features phi = δ/ε and local masses.
1038            let mut phi = Array2::<f64>::zeros((ml, d));
1039            let mut masses_local = Array1::<f64>::zeros(ml);
1040            for (a, &j) in idx.iter().enumerate() {
1041                for k in 0..d {
1042                    phi[(a, k)] = (centers[(j, k)] - centers[(i, k)]) / eps;
1043                }
1044                masses_local[a] = masses[j];
1045            }
1046
1047            // The kernel mass q for this block uses the metric distances; skip
1048            // empty blocks exactly as the isotropic assembly does.
1049            let q_block: f64 = idx
1050                .iter()
1051                .enumerate()
1052                .map(|(a, &j)| masses_local[a] * (-md.dm2[(i, j)] / (2.0 * eps * eps)).exp())
1053                .sum();
1054            if !(q_block > 0.0) {
1055                continue;
1056            }
1057
1058            let blk = block_residual_jets(&phi, &masses_local, nf.m.view(), &nf.dm, &nf.d2m, n);
1059
1060            // Outer weight base = log_step · ε^(−η) · net_mass_i · q^(1−2α),
1061            // η = 2s + d(2−2α), preserving the advertised |ξ|^(2s) order for
1062            // the available dimension parameter.
1063            // q here is the block's metric kernel mass (matches the isotropic
1064            // assembly's `q`); it is metric-dependent but enters the energy as
1065            // a fixed outer scalar, identical to the isotropic convention.
1066            let base = scale_weight * net_mass[i] * q_block.powf(1.0 - 2.0 * alpha);
1067            let beta = 1.0 - 2.0 * alpha;
1068
1069            // Scatter value + jets with the outer q^β product rule. The block
1070            // derivatives are for R; q, dq and d2q carry the metric-dependent
1071            // density weight.
1072            for (a, &ja) in idx.iter().enumerate() {
1073                for (c, &jc) in idx.iter().enumerate() {
1074                    for x in 0..n {
1075                        let qx_over_q = blk.dq[x] / blk.q;
1076                        d_first[x][(ja, jc)] +=
1077                            base * (blk.dr[x][(a, c)] + beta * qx_over_q * blk.r[(a, c)]);
1078                    }
1079                    for x in 0..n {
1080                        for y in 0..n {
1081                            let qx_over_q = blk.dq[x] / blk.q;
1082                            let qy_over_q = blk.dq[y] / blk.q;
1083                            let qxy_over_q = blk.d2q[x * n + y] / blk.q;
1084                            let density_d2 =
1085                                beta * qxy_over_q + beta * (beta - 1.0) * qx_over_q * qy_over_q;
1086                            d_second[x * n + y][(ja, jc)] += base
1087                                * (blk.d2r[x * n + y][(a, c)]
1088                                    + beta * qx_over_q * blk.dr[y][(a, c)]
1089                                    + beta * qy_over_q * blk.dr[x][(a, c)]
1090                                    + density_d2 * blk.r[(a, c)]);
1091                        }
1092                    }
1093                }
1094            }
1095        }
1096    }
1097
1098    // VALUE: the exact reduction `E_A(X; L) = E_I(X·M)`. Taking the value from
1099    // the isotropic energy on the metric-transformed centers (rather than the
1100    // operation-reordered metric block walk above) makes it bit-for-bit
1101    // identical to the isotropic form at `M = I` and routes it through the SAME
1102    // PSD projection. The block walk above is retained solely for the EXACT
1103    // `L`-jets, which are FD-gated against this value.
1104    let y = centers.dot(&nf.m);
1105    let q = measure_jet_energy_form(y.view(), masses, band, order_s, alpha, 0.0)?;
1106
1107    // Numerical symmetrization (every analytic derivative form here is symmetric).
1108    let sym = |a: Array2<f64>| (&a + &a.t()) * 0.5;
1109    let d_first: Vec<Array2<f64>> = d_first.into_iter().map(sym).collect();
1110    let d_second: Vec<Array2<f64>> = d_second.into_iter().map(sym).collect();
1111
1112    Ok(MeasureJetAnisotropyJets {
1113        q,
1114        indices,
1115        d_first,
1116        d_second,
1117    })
1118}
1119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{measure_jet_band, measure_jet_energy_form};
    use ndarray::array;

    /// Central-difference step used by the jet oracle.
    const FD_STEP: f64 = 1e-4;
    /// Relative tolerance (against the largest FD entry) for the jet oracle.
    const FD_REL_TOL: f64 = 5e-5;

    /// Scale band built from the centers via `measure_jet_band`.
    pub(crate) fn band_for(centers: &Array2<f64>) -> MeasureJetBand {
        measure_jet_band(centers.view(), 0).expect("band")
    }

    /// Six planar centers in two well-separated clusters, each with mass 0.5.
    pub(crate) fn two_cluster_centers() -> (ndarray::Array2<f64>, ndarray::Array1<f64>) {
        let centers = array![
            [-0.8, -0.6],
            [-0.7, -0.5],
            [-0.6, -0.7],
            [0.8, 0.6],
            [0.7, 0.5],
            [0.6, 0.7]
        ];
        let masses = ndarray::Array1::from_elem(centers.nrows(), 0.5);
        (centers, masses)
    }

    /// Largest absolute entry of `m`, never smaller than `floor`.
    fn max_abs(m: &Array2<f64>, floor: f64) -> f64 {
        m.iter().fold(floor, |acc, v| acc.max(v.abs()))
    }

    /// Asserts that every analytic entry is within `FD_REL_TOL · scale` of the
    /// matching finite-difference entry, where `scale` is the largest FD
    /// magnitude (floored at `1e-30`). `label` names the derivative in the
    /// failure message.
    fn assert_matches_fd<'a>(
        analytic: impl Iterator<Item = &'a f64>,
        fd: &Array2<f64>,
        label: &str,
    ) {
        let scale = max_abs(fd, 1e-30);
        for (an, fdv) in analytic.zip(fd.iter()) {
            assert!(
                (an - fdv).abs() <= FD_REL_TOL * scale,
                "{label} mismatch: analytic {an:.6e} vs FD {fdv:.6e} (scale {scale:.3e})"
            );
        }
    }

    /// Oracle (1): with `L = I` (so `Ā = I`, `M = I`) the anisotropic energy
    /// reproduces the isotropic `measure_jet_energy_form` bit-for-bit. The
    /// reference is evaluated with the same trailing argument (`0.0`) that the
    /// anisotropic entry point forwards, so the two sides are the same call on
    /// the same centers.
    #[test]
    pub(crate) fn identity_metric_reproduces_isotropic_bit_for_bit() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let (s0, a0) = (1.3, 0.8);
        let identity = Array2::<f64>::eye(centers.ncols());
        let q_aniso = measure_jet_anisotropy_energy_form(
            centers.view(),
            masses.view(),
            &band,
            s0,
            a0,
            identity.view(),
        )
        .expect("aniso energy");
        let q_iso = measure_jet_energy_form(centers.view(), masses.view(), &band, s0, a0, 0.0)
            .expect("iso energy");
        assert_eq!(q_aniso.dim(), q_iso.dim());
        for (a, b) in q_aniso.iter().zip(q_iso.iter()) {
            assert_eq!(
                a.to_bits(),
                b.to_bits(),
                "Ā = I must reproduce the isotropic energy bit-for-bit: {a} vs {b}"
            );
        }
    }

    /// Oracle (2): every `∂Q/∂L_ij` and `∂²Q/∂L_ij∂L_kl` matches central
    /// finite differences of the energy. Step `h = 1e-4`, rel tol `5e-5`. A
    /// non-identity lower-triangular `L` with a nonzero off-diagonal entry
    /// exercises every active channel and the off-diagonal coupling.
    #[test]
    pub(crate) fn l_jets_match_finite_differences() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let (s0, a0) = (1.3, 0.8);
        // A genuinely anisotropic, full lower-triangular factor.
        let l0 = array![[1.30, 0.00], [-0.45, 0.80]];
        let jets = measure_jet_anisotropy_energy_form_with_jets(
            centers.view(),
            masses.view(),
            &band,
            s0,
            a0,
            l0.view(),
        )
        .expect("jets");

        let eval = |l: &Array2<f64>| {
            measure_jet_anisotropy_energy_form(
                centers.view(),
                masses.view(),
                &band,
                s0,
                a0,
                l.view(),
            )
            .expect("energy")
        };
        // `l0` with each listed entry shifted by its delta.
        let shifted = |moves: &[(LIndex, f64)]| {
            let mut l = l0.clone();
            for &(idx, delta) in moves {
                l[(idx.row, idx.col)] += delta;
            }
            l
        };

        // Base value must equal a plain re-evaluation bit-for-bit. The same
        // evaluation is the centre point of the three-point stencil below.
        let center = eval(&l0);
        for (a, b) in jets.q.iter().zip(center.iter()) {
            assert_eq!(a.to_bits(), b.to_bits(), "value drift {a} vs {b}");
        }

        let h = FD_STEP;
        let n = jets.n_active();

        // First derivatives (central two-point stencil) and diagonal second
        // derivatives (three-point stencil) share the same ±h evaluations.
        for a in 0..n {
            let ia = jets.indices[a];
            let plus = eval(&shifted(&[(ia, h)]));
            let minus = eval(&shifted(&[(ia, -h)]));

            let fd_first = (&plus - &minus) / (2.0 * h);
            assert_matches_fd(
                jets.d_first[a].iter(),
                &fd_first,
                &format!("∂Q/∂L[{},{}]", ia.row, ia.col),
            );

            let fd_second = (&(&plus + &minus) - &(&center * 2.0)) / (h * h);
            assert_matches_fd(
                jets.second(a, a).iter(),
                &fd_second,
                &format!("∂²Q/∂L[{},{}]²", ia.row, ia.col),
            );
        }

        // Cross second derivatives via the four-point stencil.
        for a in 0..n {
            let ia = jets.indices[a];
            for b in (a + 1)..n {
                let ib = jets.indices[b];
                let pp = eval(&shifted(&[(ia, h), (ib, h)]));
                let pm = eval(&shifted(&[(ia, h), (ib, -h)]));
                let mp = eval(&shifted(&[(ia, -h), (ib, h)]));
                let mm = eval(&shifted(&[(ia, -h), (ib, -h)]));
                let fd = (&(&pp - &pm) - &(&mp - &mm)) / (4.0 * h * h);
                assert_matches_fd(
                    jets.second(a, b).iter(),
                    &fd,
                    &format!(
                        "∂²Q/∂L[{},{}]∂L[{},{}]",
                        ia.row, ia.col, ib.row, ib.col
                    ),
                );
                // Symmetry of the second-derivative grid.
                for (sab, sba) in jets.second(a, b).iter().zip(jets.second(b, a).iter()) {
                    assert!((sab - sba).abs() <= 1e-12 * (1.0 + sab.abs()));
                }
            }
        }
    }

    /// Oracle (3): det-normalization invariance — scaling `L` by any `c > 0`
    /// leaves the energy unchanged, because `Ā = (c L)(c L)ᵀ / det(c² L Lᵀ)^(1/d)
    /// = L Lᵀ / det(L Lᵀ)^(1/d)`. The whole point of the normalization is that
    /// only the SHAPE of the metric, not its overall scale, is learned.
    #[test]
    pub(crate) fn det_normalization_is_scale_invariant() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let (s0, a0) = (1.1, 0.9);
        let l0 = array![[0.90, 0.00], [0.35, 1.40]];
        let energy_at = |l: &Array2<f64>, what: &str| {
            measure_jet_anisotropy_energy_form(
                centers.view(),
                masses.view(),
                &band,
                s0,
                a0,
                l.view(),
            )
            .expect(what)
        };
        let q_ref = energy_at(&l0, "ref");
        let scale = max_abs(&q_ref, 0.0);
        assert!(scale > 0.0, "energy is identically zero");
        for &c in &[0.25_f64, 0.5, 2.0, 7.5] {
            let lc = &l0 * c;
            let q_c = energy_at(&lc, "scaled");
            for (a, b) in q_c.iter().zip(q_ref.iter()) {
                assert!(
                    (a - b).abs() <= 1e-10 * scale,
                    "scale c = {c} changed the normalized energy: {a:.6e} vs {b:.6e}"
                );
            }
        }
    }

    /// The energy must annihilate constants at every metric: `Q·1` vanishes
    /// (to `1e-10` of the largest entry of `Q`), mirroring the isotropic
    /// contract.
    #[test]
    pub(crate) fn anisotropic_energy_annihilates_constants() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let l = array![[1.20, 0.00], [-0.30, 0.95]];
        let q = measure_jet_anisotropy_energy_form(
            centers.view(),
            masses.view(),
            &band,
            1.5,
            1.0,
            l.view(),
        )
        .expect("energy");
        let scale = max_abs(&q, 0.0);
        assert!(scale > 0.0, "energy is identically zero");
        let leak = q.dot(&Array1::<f64>::ones(q.nrows()));
        for (i, v) in leak.iter().enumerate() {
            assert!(
                v.abs() <= 1e-10 * scale,
                "Q·1 leak at row {i}: {v:.3e} vs scale {scale:.3e}"
            );
        }
    }
}