1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
50
51use faer::Side;
52
53use gam_linalg::faer_ndarray::FaerEigh;
54
55use super::{BasisError, MeasureJetBand, measure_jet_energy_form};
56
/// Neighbourhood cutoff in units of the band scale `eps`. In this file, a center
/// `j` joins the local block around net point `i` when its squared metric distance
/// is at most `(PROFILE_CUTOFF * eps)²`.
pub(crate) const PROFILE_CUTOFF: f64 = 3.0;

/// Relative tolerance for rank truncation in `eigh_pinv`. A (zero-clamped)
/// eigenvalue `λ` keeps its reciprocal only if
/// `λ > PSEUDOINVERSE_RTOL · max(n, 1) · λ_max`; otherwise it is treated as zero.
pub(crate) const PSEUDOINVERSE_RTOL: f64 = 64.0 * f64::EPSILON;
67
/// Position `(row, col)` of one free entry of the lower-triangular factor `L`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct LIndex {
    /// Row index into `L` (`row >= col` for entries produced by
    /// `lower_triangular_indices`).
    pub row: usize,
    /// Column index into `L`.
    pub col: usize,
}
78
/// Anisotropic energy form together with its derivatives with respect to the
/// free entries of the lower-triangular metric factor `L`.
pub struct MeasureJetAnisotropyJets {
    /// Energy form matrix `Q`, as produced by `measure_jet_energy_form` on the
    /// centers warped by the normalised factor.
    pub q: Array2<f64>,
    /// Free entries of `L`, in the order used to index `d_first` / `d_second`.
    pub indices: Vec<LIndex>,
    /// `d_first[a]` = ∂Q/∂L[indices[a]].
    pub d_first: Vec<Array2<f64>>,
    /// `d_second[a * n + b]` = ∂²Q/∂L[indices[a]]∂L[indices[b]], row-major with
    /// `n = indices.len()`; use [`MeasureJetAnisotropyJets::second`] to index it.
    pub d_second: Vec<Array2<f64>>,
}
98
99impl MeasureJetAnisotropyJets {
100 #[inline]
102 pub fn n_active(&self) -> usize {
103 self.indices.len()
104 }
105
106 #[inline]
108 pub fn second(&self, a: usize, b: usize) -> &Array2<f64> {
109 &self.d_second[a * self.indices.len() + b]
110 }
111}
112
113pub fn lower_triangular_indices(d: usize) -> Vec<LIndex> {
117 let mut idx = Vec::with_capacity(d * (d + 1) / 2);
118 for col in 0..d {
119 for row in col..d {
120 idx.push(LIndex { row, col });
121 }
122 }
123 idx
124}
125
/// Determinant-normalised metric factor `M = det(L)^{-1/d} · L` together with its
/// derivatives with respect to the entries of `L` selected by an index list.
pub struct NormalizedFactor {
    /// `M` itself (`d × d`).
    pub(crate) m: Array2<f64>,
    /// `dm[a]` = ∂M/∂L[indices[a]].
    pub(crate) dm: Vec<Array2<f64>>,
    /// `d2m[a * n + b]` = ∂²M/∂L[indices[a]]∂L[indices[b]], row-major with
    /// `n = indices.len()`.
    pub(crate) d2m: Vec<Array2<f64>>,
}
146
147pub(crate) fn build_normalized_factor(
148 l: ArrayView2<'_, f64>,
149 indices: &[LIndex],
150) -> Result<NormalizedFactor, BasisError> {
151 let d = l.nrows();
152 if l.ncols() != d {
153 crate::bail_dim_basis!(
154 "measure-jet anisotropy needs a square lower-triangular L, got {:?}",
155 l.dim()
156 );
157 }
158 if d == 0 {
159 crate::bail_invalid_basis!("measure-jet anisotropy needs a non-empty ambient metric");
160 }
161 for k in 0..d {
162 if !(l[(k, k)].is_finite() && l[(k, k)] > 0.0) {
163 crate::bail_invalid_basis!(
164 "measure-jet anisotropy needs a positive-definite L: diagonal entry L[{k},{k}] = {} is not finite and positive",
165 l[(k, k)]
166 );
167 }
168 for c in (k + 1)..d {
169 if l[(k, c)] != 0.0 {
170 crate::bail_invalid_basis!(
171 "measure-jet anisotropy L must be lower-triangular: upper entry L[{k},{c}] = {} is nonzero",
172 l[(k, c)]
173 );
174 }
175 if !l[(c, k)].is_finite() {
176 crate::bail_invalid_basis!(
177 "measure-jet anisotropy L has a non-finite entry L[{c},{k}]"
178 );
179 }
180 }
181 }
182
183 let n = indices.len();
184 let l_owned = l.to_owned();
185
186 let inv_d = 1.0 / d as f64;
188 let mut f_first = vec![0.0_f64; n];
189 let mut f_second = vec![0.0_f64; n * n];
192 for (a, ia) in indices.iter().enumerate() {
193 if ia.row == ia.col {
194 let lkk = l_owned[(ia.row, ia.row)];
195 f_first[a] = -inv_d / lkk;
196 f_second[a * n + a] = inv_d / (lkk * lkk);
197 }
198 }
199
200 let g = (-inv_d * {
202 let mut s = 0.0;
203 for k in 0..d {
204 s += l_owned[(k, k)].ln();
205 }
206 s
207 })
208 .exp();
209
210 let mut g_first = vec![0.0_f64; n];
214 let mut g_second = vec![0.0_f64; n * n];
215 for a in 0..n {
216 g_first[a] = g * f_first[a];
217 }
218 for a in 0..n {
219 for b in 0..n {
220 g_second[a * n + b] = g * (f_first[a] * f_first[b] + f_second[a * n + b]);
221 }
222 }
223
224 let m = &l_owned * g;
228 let mut dm = Vec::with_capacity(n);
229 for a in 0..n {
230 let ia = indices[a];
231 let mut ma = &l_owned * g_first[a];
232 ma[(ia.row, ia.col)] += g;
233 dm.push(ma);
234 }
235 let mut d2m = Vec::with_capacity(n * n);
236 for a in 0..n {
237 let ia = indices[a];
238 for b in 0..n {
239 let ib = indices[b];
240 let mut mab = &l_owned * g_second[a * n + b];
241 mab[(ia.row, ia.col)] += g_first[b];
242 mab[(ib.row, ib.col)] += g_first[a];
243 d2m.push(mab);
244 }
245 }
246
247 Ok(NormalizedFactor { m, dm, d2m })
248}
249
/// Pairwise squared distances between centers after mapping them through `M`.
pub(crate) struct MetricDist2 {
    /// `dm2[(i, j)] ≈ |(c_i − c_j)·M|²`, computed from the Gram expansion
    /// `|y_i|² + |y_j|² − 2 y_i·y_j` and clamped at zero.
    pub(crate) dm2: Array2<f64>,
}
262
263pub(crate) fn metric_sq_dists(centers: ArrayView2<'_, f64>, m: ArrayView2<'_, f64>) -> MetricDist2 {
264 let n = centers.nrows();
265 let y = centers.dot(&m);
269 let yn: Vec<f64> = y.outer_iter().map(|r| r.dot(&r)).collect();
270 let g = y.dot(&y.t());
271 let mut dm2 = Array2::<f64>::zeros((n, n));
272 for i in 0..n {
273 for j in 0..n {
274 dm2[(i, j)] = (yn[i] + yn[j] - 2.0 * g[(i, j)]).max(0.0);
275 }
276 }
277 MetricDist2 { dm2 }
278}
279
/// Symmetric eigendecomposition together with its rank-truncated pseudo-inverse.
pub(crate) struct EighPinv {
    /// Eigenvalues as returned by `eigh` (not clamped).
    pub(crate) evals: Array1<f64>,
    /// Eigenvectors, stored as columns matching `evals`.
    pub(crate) evecs: Array2<f64>,
    /// Reciprocal of each retained (zero-clamped) eigenvalue; `0.0` for
    /// eigenvalues at or below the rank tolerance.
    pub(crate) inv: Array1<f64>,
    /// Pseudo-inverse `evecs · diag(inv) · evecsᵀ`.
    pub(crate) pinv: Array2<f64>,
}
291
292pub(crate) fn eigh_pinv(a: &Array2<f64>, label: &str) -> Result<EighPinv, BasisError> {
293 let n = a.nrows();
294 let (evals, evecs) = a.eigh(Side::Lower).map_err(|e| {
295 BasisError::InvalidInput(format!(
296 "measure-jet anisotropy pseudo-inverse `{label}` eigendecomposition failed: {e}"
297 ))
298 })?;
299 let lam_max = evals.iter().fold(0.0_f64, |acc, v| acc.max((*v).max(0.0)));
300 let rank_tol = PSEUDOINVERSE_RTOL * (n.max(1) as f64) * lam_max;
301 let mut inv = Array1::<f64>::zeros(n);
302 let mut scaled = evecs.clone();
303 for k in 0..n {
304 let lam = evals[k].max(0.0);
305 let iv = if lam > rank_tol { 1.0 / lam } else { 0.0 };
306 inv[k] = iv;
307 let mut col = scaled.column_mut(k);
308 col.mapv_inplace(|v| v * iv);
309 }
310 let pinv = scaled.dot(&evecs.t());
311 Ok(EighPinv {
312 evals,
313 evecs,
314 inv,
315 pinv,
316 })
317}
318
319pub(crate) fn pinv_first_deriv(ep: &EighPinv, gdot: &Array2<f64>) -> Array2<f64> {
341 let n = ep.evals.len();
342 let vt_g = ep.evecs.t().dot(gdot);
343 let mhat = vt_g.dot(&ep.evecs); let mut core = Array2::<f64>::zeros((n, n));
345 for p in 0..n {
346 for q in 0..n {
347 core[(p, q)] = pinv_div1(ep, p, q) * mhat[(p, q)];
348 }
349 }
350 ep.evecs.dot(&core).dot(&ep.evecs.t())
351}
352
353#[inline]
354pub(crate) fn pinv_active(ep: &EighPinv, i: usize) -> bool {
355 ep.inv[i] != 0.0
356}
357
358#[inline]
359pub(crate) fn pinv_value(ep: &EighPinv, i: usize) -> f64 {
360 if pinv_active(ep, i) { ep.inv[i] } else { 0.0 }
361}
362
363#[inline]
364pub(crate) fn pinv_prime(ep: &EighPinv, i: usize) -> f64 {
365 if pinv_active(ep, i) {
366 -ep.inv[i] * ep.inv[i]
367 } else {
368 0.0
369 }
370}
371
372#[inline]
373pub(crate) fn pinv_half_second(ep: &EighPinv, i: usize) -> f64 {
374 if pinv_active(ep, i) {
375 ep.inv[i] * ep.inv[i] * ep.inv[i]
376 } else {
377 0.0
378 }
379}
380
381pub(crate) fn pinv_div1(ep: &EighPinv, i: usize, j: usize) -> f64 {
382 if i == j {
383 return pinv_prime(ep, i);
384 }
385 let li = ep.evals[i];
386 let lj = ep.evals[j];
387 let denom = li - lj;
388 let scale = li.abs().max(lj.abs()).max(1.0);
389 if denom.abs() <= 16.0 * f64::EPSILON * scale {
390 if pinv_active(ep, i) == pinv_active(ep, j) {
391 0.5 * (pinv_prime(ep, i) + pinv_prime(ep, j))
392 } else {
393 0.0
394 }
395 } else {
396 (pinv_value(ep, i) - pinv_value(ep, j)) / denom
397 }
398}
399
/// Second divided difference `f[λ_i, λ_k, λ_j]` of `f(λ) = λ⁺`.
///
/// Each confluent case is resolved with the matching closed form:
/// - all three indices equal: `f''/2` via `pinv_half_second`;
/// - `i == j`: `f[a, b, a] = (f(b) − f(a) − f'(a)·h) / h²` with `h = λ_k − λ_i`;
/// - `i == k`: `f[a, a, b] = (f'(a) − f[a, b]) / (a − b)`;
/// - `k == j`: `f[a, b, b] = (f[a, b] − f'(b)) / (a − b)`;
/// - otherwise the recursive definition `(f[λ_i, λ_k] − f[λ_k, λ_j]) / (λ_i − λ_j)`.
///
/// Gaps at most `16·ε·scale` are treated as zero, and the confluent limit is used
/// instead.
pub(crate) fn pinv_div2(ep: &EighPinv, i: usize, k: usize, j: usize) -> f64 {
    if i == k && k == j {
        return pinv_half_second(ep, i);
    }
    let li = ep.evals[i];
    let lk = ep.evals[k];
    let lj = ep.evals[j];
    if i == j {
        // f[a, b, a] with a = λ_i, b = λ_k.
        let h = lk - li;
        let scale = li.abs().max(lk.abs()).max(1.0);
        if h.abs() <= 16.0 * f64::EPSILON * scale {
            return pinv_half_second(ep, i);
        }
        return (pinv_value(ep, k) - pinv_value(ep, i) - pinv_prime(ep, i) * h) / (h * h);
    }
    if i == k {
        // f[a, a, b] with a = λ_i, b = λ_j.
        let denom = li - lj;
        let scale = li.abs().max(lj.abs()).max(1.0);
        if denom.abs() <= 16.0 * f64::EPSILON * scale {
            return pinv_half_second(ep, i);
        }
        return (pinv_prime(ep, i) - pinv_div1(ep, i, j)) / denom;
    }
    if k == j {
        // f[a, b, b] with a = λ_i, b = λ_j.
        let denom = li - lj;
        let scale = li.abs().max(lj.abs()).max(1.0);
        if denom.abs() <= 16.0 * f64::EPSILON * scale {
            return pinv_half_second(ep, j);
        }
        return (pinv_div1(ep, i, j) - pinv_prime(ep, j)) / denom;
    }
    // Three distinct indices.
    let denom = li - lj;
    let scale = li.abs().max(lj.abs()).max(1.0);
    if denom.abs() <= 16.0 * f64::EPSILON * scale {
        // λ_i ≈ λ_j: fall back to the f[a, b, a] form. Note that the gap to λ_k is
        // tested against the (λ_i, λ_j) scale computed above.
        let h = lk - li;
        if h.abs() <= 16.0 * f64::EPSILON * scale {
            pinv_half_second(ep, i)
        } else {
            (pinv_value(ep, k) - pinv_value(ep, i) - pinv_prime(ep, i) * h) / (h * h)
        }
    } else {
        (pinv_div1(ep, i, k) - pinv_div1(ep, k, j)) / denom
    }
}
444
445pub(crate) fn pinv_second_deriv(
446 ep: &EighPinv,
447 gx: &Array2<f64>,
448 gy: &Array2<f64>,
449 gxy: &Array2<f64>,
450) -> Array2<f64> {
451 let n = ep.evals.len();
452 let gx_hat = ep.evecs.t().dot(gx).dot(&ep.evecs);
453 let gy_hat = ep.evecs.t().dot(gy).dot(&ep.evecs);
454 let gxy_hat = ep.evecs.t().dot(gxy).dot(&ep.evecs);
455 let mut core = Array2::<f64>::zeros((n, n));
456 for i in 0..n {
457 for j in 0..n {
458 let mut value = pinv_div1(ep, i, j) * gxy_hat[(i, j)];
459 for k in 0..n {
460 value += pinv_div2(ep, i, k, j)
461 * (gx_hat[(i, k)] * gy_hat[(k, j)] + gy_hat[(i, k)] * gx_hat[(k, j)]);
462 }
463 core[(i, j)] = value;
464 }
465 }
466 ep.evecs.dot(&core).dot(&ep.evecs.t())
467}
468
/// Residual form of one neighbourhood block and the block's total weight, each
/// with first and second derivatives. Second derivatives are stored row-major
/// at index `x * n + y`.
pub(crate) struct BlockForms {
    /// `R = diag(w) − w wᵀ / q − P / q` (`ml × ml`).
    pub(crate) r: Array2<f64>,
    /// `dr[x]` = ∂R/∂θ_x.
    pub(crate) dr: Vec<Array2<f64>>,
    /// `d2r[x * n + y]` = ∂²R/∂θ_x∂θ_y.
    pub(crate) d2r: Vec<Array2<f64>>,
    /// Total block weight `q = Σ_a w_a`.
    pub(crate) q: f64,
    /// `dq[x]` = ∂q/∂θ_x.
    pub(crate) dq: Vec<f64>,
    /// `d2q[x * n + y]` = ∂²q/∂θ_x∂θ_y.
    pub(crate) d2q: Vec<f64>,
}
488
/// Local residual form of one neighbourhood block, plus exact first and second
/// derivatives with respect to the `n_active` free parameters of the normalised
/// factor.
///
/// Inputs:
/// - `phi`: `ml × d` offsets of the block's centers. The caller in this file
///   passes `(c_j − c_i) / eps`.
/// - `masses_local`: mass of each of the `ml` block centers.
/// - `m`, `dm`, `d2m`: the normalised factor and its derivatives. `dm` has
///   `n_active` entries; `d2m` is `n_active × n_active`, row-major.
///
/// With `ψ = φ·M`, the function forms:
/// - Gaussian weights `w_a = mass_a · exp(−|ψ_a|²/2)` and `q = Σ w_a`;
/// - the weighted mean `ā = Σ w_a ψ_a / q`;
/// - the local Gram `G = H/q − ā āᵀ` with `H = Σ w_a ψ_a ψ_aᵀ`;
/// - `P = B G⁺ Bᵀ` with rows `B_a = w_a (ψ_a − ā)`;
/// - `R = diag(w) − w wᵀ/q − P/q`.
///
/// Every intermediate is differentiated with the product and quotient rules, and
/// `G⁺` with the Daleckii–Krein formulas. If the eigendecomposition of `G` fails,
/// a zero pseudo-inverse is used, so the `P` term vanishes; no error is returned.
pub(crate) fn block_residual_jets(
    phi: &Array2<f64>,
    masses_local: &Array1<f64>,
    m: ArrayView2<'_, f64>,
    dm: &[Array2<f64>],
    d2m: &[Array2<f64>],
    n_active: usize,
) -> BlockForms {
    let ml = phi.nrows();
    let n = n_active;

    // ψ = φ·M and its derivatives (φ does not depend on the parameters).
    let psi = phi.dot(&m);
    let mut dpsi: Vec<Array2<f64>> = Vec::with_capacity(n);
    for a in 0..n {
        dpsi.push(phi.dot(&dm[a]));
    }
    let mut d2psi: Vec<Array2<f64>> = Vec::with_capacity(n * n);
    for a in 0..n {
        for b in 0..n {
            d2psi.push(phi.dot(&d2m[a * n + b]));
        }
    }

    // Weights w_a = mass_a·exp(e_a), e_a = −|ψ_a|²/2, so
    // ∂w = w·∂e and ∂²w = w·(∂e_x ∂e_y + ∂²e_xy).
    let mut w = Array1::<f64>::zeros(ml);
    let mut dw: Vec<Array1<f64>> = (0..n).map(|_| Array1::<f64>::zeros(ml)).collect();
    let mut d2w: Vec<Array1<f64>> = (0..n * n).map(|_| Array1::<f64>::zeros(ml)).collect();
    for a in 0..ml {
        let psi_a = psi.row(a);
        let e = -0.5 * psi_a.dot(&psi_a);
        let wa = masses_local[a] * e.exp();
        w[a] = wa;
        // ex[x] = ∂e_a/∂θ_x = −ψ_a·∂ψ_a/∂θ_x.
        let mut ex = vec![0.0_f64; n];
        for x in 0..n {
            ex[x] = -psi_a.dot(&dpsi[x].row(a));
            dw[x][a] = wa * ex[x];
        }
        for x in 0..n {
            for y in 0..n {
                let dpx = dpsi[x].row(a);
                let dpy = dpsi[y].row(a);
                let d2p = d2psi[x * n + y].row(a);
                let exy = -(dpx.dot(&dpy) + psi_a.dot(&d2p));
                d2w[x * n + y][a] = wa * (ex[x] * ex[y] + exy);
            }
        }
    }

    let d = phi.ncols();

    // Total weight q = Σ w and its derivatives.
    let q = w.sum();
    let mut dq = vec![0.0_f64; n];
    let mut d2q = vec![0.0_f64; n * n];
    for x in 0..n {
        dq[x] = dw[x].sum();
    }
    for x in 0..n {
        for y in 0..n {
            d2q[x * n + y] = d2w[x * n + y].sum();
        }
    }

    // First moment p = Σ w_a ψ_a and its derivatives.
    let mut pvec = Array1::<f64>::zeros(d);
    for a in 0..ml {
        for k in 0..d {
            pvec[k] += w[a] * psi[(a, k)];
        }
    }
    let mut dpvec: Vec<Array1<f64>> = (0..n).map(|_| Array1::<f64>::zeros(d)).collect();
    for x in 0..n {
        for a in 0..ml {
            for k in 0..d {
                dpvec[x][k] += dw[x][a] * psi[(a, k)] + w[a] * dpsi[x][(a, k)];
            }
        }
    }
    let mut d2pvec: Vec<Array1<f64>> = (0..n * n).map(|_| Array1::<f64>::zeros(d)).collect();
    for x in 0..n {
        for y in 0..n {
            let dst = &mut d2pvec[x * n + y];
            for a in 0..ml {
                for k in 0..d {
                    dst[k] += d2w[x * n + y][a] * psi[(a, k)]
                        + dw[x][a] * dpsi[y][(a, k)]
                        + dw[y][a] * dpsi[x][(a, k)]
                        + w[a] * d2psi[x * n + y][(a, k)];
                }
            }
        }
    }

    // Weighted mean ā = p/q via the quotient rule:
    // ∂ā = (∂p − ā ∂q)/q,
    // ∂²ā = p_xy/q − (p_x q_y + p_y q_x + p q_xy)/q² + 2 p q_x q_y/q³.
    let amean = &pvec / q;
    let mut damean: Vec<Array1<f64>> = Vec::with_capacity(n);
    for x in 0..n {
        damean.push((&dpvec[x] - &(&amean * dq[x])) / q);
    }
    let mut d2amean: Vec<Array1<f64>> = Vec::with_capacity(n * n);
    for x in 0..n {
        for y in 0..n {
            let term = (&d2pvec[x * n + y]) / q
                - (&(&dpvec[x] * dq[y]) + &(&dpvec[y] * dq[x]) + &(&pvec * d2q[x * n + y]))
                    / (q * q)
                + &(&pvec * (2.0 * dq[x] * dq[y] / (q * q * q)));
            d2amean.push(term);
        }
    }

    // Centered weighted rows B_a = w_a (ψ_a − ā) and their derivatives.
    let bmat = |wv: &Array1<f64>, psiv: &Array2<f64>, am: &Array1<f64>| -> Array2<f64> {
        let mut bb = Array2::<f64>::zeros((ml, d));
        for a in 0..ml {
            for k in 0..d {
                bb[(a, k)] = wv[a] * (psiv[(a, k)] - am[k]);
            }
        }
        bb
    };
    let b = bmat(&w, &psi, &amean);
    let mut db: Vec<Array2<f64>> = Vec::with_capacity(n);
    for x in 0..n {
        let mut bb = Array2::<f64>::zeros((ml, d));
        for a in 0..ml {
            for k in 0..d {
                bb[(a, k)] =
                    dw[x][a] * (psi[(a, k)] - amean[k]) + w[a] * (dpsi[x][(a, k)] - damean[x][k]);
            }
        }
        db.push(bb);
    }
    let mut d2b: Vec<Array2<f64>> = Vec::with_capacity(n * n);
    for x in 0..n {
        for y in 0..n {
            let mut bb = Array2::<f64>::zeros((ml, d));
            for a in 0..ml {
                for k in 0..d {
                    bb[(a, k)] = d2w[x * n + y][a] * (psi[(a, k)] - amean[k])
                        + dw[x][a] * (dpsi[y][(a, k)] - damean[y][k])
                        + dw[y][a] * (dpsi[x][(a, k)] - damean[x][k])
                        + w[a] * (d2psi[x * n + y][(a, k)] - d2amean[x * n + y][k]);
                }
            }
            d2b.push(bb);
        }
    }

    // Weighted second moment H = Σ w_a ψ_a ψ_aᵀ and its derivatives.
    let hmat = |wv: &Array1<f64>, psiv: &Array2<f64>| -> Array2<f64> {
        let mut hh = Array2::<f64>::zeros((d, d));
        for a in 0..ml {
            for r in 0..d {
                for c in 0..d {
                    hh[(r, c)] += wv[a] * psiv[(a, r)] * psiv[(a, c)];
                }
            }
        }
        hh
    };
    let hh = hmat(&w, &psi);
    let mut dhh: Vec<Array2<f64>> = Vec::with_capacity(n);
    for x in 0..n {
        let mut hd = Array2::<f64>::zeros((d, d));
        for a in 0..ml {
            for r in 0..d {
                for c in 0..d {
                    hd[(r, c)] += dw[x][a] * psi[(a, r)] * psi[(a, c)]
                        + w[a] * dpsi[x][(a, r)] * psi[(a, c)]
                        + w[a] * psi[(a, r)] * dpsi[x][(a, c)];
                }
            }
        }
        dhh.push(hd);
    }
    let mut d2hh: Vec<Array2<f64>> = Vec::with_capacity(n * n);
    for x in 0..n {
        for y in 0..n {
            let mut hd = Array2::<f64>::zeros((d, d));
            for a in 0..ml {
                for r in 0..d {
                    for c in 0..d {
                        let pr = psi[(a, r)];
                        let pc = psi[(a, c)];
                        let dprx = dpsi[x][(a, r)];
                        let dpcx = dpsi[x][(a, c)];
                        let dpry = dpsi[y][(a, r)];
                        let dpcy = dpsi[y][(a, c)];
                        let d2pr = d2psi[x * n + y][(a, r)];
                        let d2pc = d2psi[x * n + y][(a, c)];
                        // Second derivative of the triple product w·ψ_r·ψ_c.
                        hd[(r, c)] += d2w[x * n + y][a] * pr * pc
                            + dw[x][a] * (dpry * pc + pr * dpcy)
                            + dw[y][a] * (dprx * pc + pr * dpcx)
                            + w[a] * (d2pr * pc + dprx * dpcy + dpry * dpcx + pr * d2pc);
                    }
                }
            }
            d2hh.push(hd);
        }
    }

    // Local Gram G = H/q − ā āᵀ and its derivatives.
    let outer = |u: &Array1<f64>, v: &Array1<f64>| -> Array2<f64> {
        let mut o = Array2::<f64>::zeros((d, d));
        for r in 0..d {
            for c in 0..d {
                o[(r, c)] = u[r] * v[c];
            }
        }
        o
    };
    let g = &(&hh / q) - &outer(&amean, &amean);
    let mut dg: Vec<Array2<f64>> = Vec::with_capacity(n);
    for x in 0..n {
        let dhq = &(&dhh[x] / q) - &(&hh * (dq[x] / (q * q)));
        let dout = &outer(&damean[x], &amean) + &outer(&amean, &damean[x]);
        dg.push(&dhq - &dout);
    }
    let mut d2g: Vec<Array2<f64>> = Vec::with_capacity(n * n);
    for x in 0..n {
        for y in 0..n {
            let d2hq = &(&d2hh[x * n + y] / q)
                - &(&(&dhh[x] * (dq[y] / (q * q)))
                    + &(&dhh[y] * (dq[x] / (q * q)))
                    + &(&hh * (d2q[x * n + y] / (q * q))))
                + &(&hh * (2.0 * dq[x] * dq[y] / (q * q * q)));
            let d2out = &outer(&d2amean[x * n + y], &amean)
                + &outer(&damean[x], &damean[y])
                + &outer(&damean[y], &damean[x])
                + &outer(&amean, &d2amean[x * n + y]);
            d2g.push(&d2hq - &d2out);
        }
    }

    // Pseudo-inverse of G. A failed eigendecomposition falls back to a zero
    // pseudo-inverse (all eigenvalues treated as truncated).
    let ep = eigh_pinv(&g, "local affine Gram").unwrap_or_else(|_| {
        EighPinv {
            evals: Array1::zeros(d),
            evecs: Array2::eye(d),
            inv: Array1::zeros(d),
            pinv: Array2::zeros((d, d)),
        }
    });
    let gpinv = ep.pinv.clone();
    let mut dgpinv: Vec<Array2<f64>> = Vec::with_capacity(n);
    for x in 0..n {
        dgpinv.push(pinv_first_deriv(&ep, &dg[x]));
    }
    let mut d2gpinv: Vec<Array2<f64>> = Vec::with_capacity(n * n);
    for x in 0..n {
        for y in 0..n {
            d2gpinv.push(pinv_second_deriv(&ep, &dg[x], &dg[y], &d2g[x * n + y]));
        }
    }

    // P = B G⁺ Bᵀ. Its first derivative has 3 product-rule terms; the second
    // derivative has all 9 ways of distributing ∂_x and ∂_y over the three factors.
    let triple = |bb: &Array2<f64>, gp: &Array2<f64>| -> Array2<f64> { bb.dot(gp).dot(&bb.t()) };
    let p = triple(&b, &gpinv);
    let mut dp: Vec<Array2<f64>> = Vec::with_capacity(n);
    for x in 0..n {
        let t1 = db[x].dot(&gpinv).dot(&b.t());
        let t2 = b.dot(&dgpinv[x]).dot(&b.t());
        let t3 = b.dot(&gpinv).dot(&db[x].t());
        dp.push(&(&t1 + &t2) + &t3);
    }
    let mut d2p: Vec<Array2<f64>> = Vec::with_capacity(n * n);
    for x in 0..n {
        for y in 0..n {
            let bx = &db[x];
            let by = &db[y];
            let bxy = &d2b[x * n + y];
            let gx = &dgpinv[x];
            let gy = &dgpinv[y];
            let gxy = &d2gpinv[x * n + y];
            let mut acc = bxy.dot(&gpinv).dot(&b.t());
            acc += &bx.dot(gy).dot(&b.t());
            acc += &bx.dot(&gpinv).dot(&by.t());
            acc += &by.dot(gx).dot(&b.t());
            acc += &b.dot(gxy).dot(&b.t());
            acc += &b.dot(gx).dot(&by.t());
            acc += &by.dot(&gpinv).dot(&bx.t());
            acc += &b.dot(gy).dot(&bx.t());
            acc += &b.dot(&gpinv).dot(&bxy.t());
            d2p.push(acc);
        }
    }

    // R = diag(w) − w wᵀ/q − P/q.
    let assemble_r = |wv: &Array1<f64>, qv: f64, pv: &Array2<f64>| -> Array2<f64> {
        let mut rr = Array2::<f64>::zeros((ml, ml));
        for a in 0..ml {
            for c in 0..ml {
                rr[(a, c)] = -wv[a] * wv[c] / qv - pv[(a, c)] / qv;
            }
            rr[(a, a)] += wv[a];
        }
        rr
    };
    let r = assemble_r(&w, q, &p);

    // ∂R via the quotient rule on (w wᵀ)/q and P/q.
    let mut dr: Vec<Array2<f64>> = Vec::with_capacity(n);
    for x in 0..n {
        let mut rr = Array2::<f64>::zeros((ml, ml));
        for a in 0..ml {
            for c in 0..ml {
                let wwt_d = (dw[x][a] * w[c] + w[a] * dw[x][c]) / q - w[a] * w[c] * dq[x] / (q * q);
                let pd = dp[x][(a, c)] / q - p[(a, c)] * dq[x] / (q * q);
                rr[(a, c)] = -wwt_d - pd;
            }
            rr[(a, a)] += dw[x][a];
        }
        dr.push(rr);
    }

    // ∂²R via the second-order quotient rule
    // (N/q)_xy = N_xy/q − (N_x q_y + N_y q_x + N q_xy)/q² + 2 N q_x q_y/q³.
    let mut d2r: Vec<Array2<f64>> = Vec::with_capacity(n * n);
    for x in 0..n {
        for y in 0..n {
            let qx = dq[x];
            let qy = dq[y];
            let qxy = d2q[x * n + y];
            let mut rr = Array2::<f64>::zeros((ml, ml));
            for a in 0..ml {
                for c in 0..ml {
                    let num = w[a] * w[c];
                    let num_x = dw[x][a] * w[c] + w[a] * dw[x][c];
                    let num_y = dw[y][a] * w[c] + w[a] * dw[y][c];
                    let num_xy = d2w[x * n + y][a] * w[c]
                        + dw[x][a] * dw[y][c]
                        + dw[y][a] * dw[x][c]
                        + w[a] * d2w[x * n + y][c];
                    let wwt_d2 = num_xy / q - (num_x * qy + num_y * qx + num * qxy) / (q * q)
                        + 2.0 * num * qx * qy / (q * q * q);
                    let pn = p[(a, c)];
                    let pnx = dp[x][(a, c)];
                    let pny = dp[y][(a, c)];
                    let pnxy = d2p[x * n + y][(a, c)];
                    let p_d2 = pnxy / q - (pnx * qy + pny * qx + pn * qxy) / (q * q)
                        + 2.0 * pn * qx * qy / (q * q * q);
                    rr[(a, c)] = -wwt_d2 - p_d2;
                }
                rr[(a, a)] += d2w[x * n + y][a];
            }
            d2r.push(rr);
        }
    }

    BlockForms {
        r,
        dr,
        d2r,
        q,
        dq,
        d2q,
    }
}
897
898pub fn measure_jet_anisotropy_energy_form(
906 centers: ArrayView2<'_, f64>,
907 masses: ArrayView1<'_, f64>,
908 band: &MeasureJetBand,
909 order_s: f64,
910 alpha: f64,
911 l: ArrayView2<'_, f64>,
912) -> Result<Array2<f64>, BasisError> {
913 let d = centers.ncols();
923 if l.nrows() != d || l.ncols() != d {
924 crate::bail_dim_basis!(
925 "measure-jet anisotropy metric L must be {d}×{d} to match the ambient dimension, got {:?}",
926 l.dim()
927 );
928 }
929 let indices = lower_triangular_indices(d);
930 let nf = build_normalized_factor(l, &indices)?;
931 let y = centers.dot(&nf.m);
932 measure_jet_energy_form(y.view(), masses, band, order_s, alpha, 0.0)
933}
934
935pub fn measure_jet_anisotropy_energy_form_with_jets(
940 centers: ArrayView2<'_, f64>,
941 masses: ArrayView1<'_, f64>,
942 band: &MeasureJetBand,
943 order_s: f64,
944 alpha: f64,
945 l: ArrayView2<'_, f64>,
946) -> Result<MeasureJetAnisotropyJets, BasisError> {
947 let m_centers = centers.nrows();
948 let d = centers.ncols();
949 if l.nrows() != d || l.ncols() != d {
950 crate::bail_dim_basis!(
951 "measure-jet anisotropy metric L must be {d}×{d} to match the ambient dimension, got {:?}",
952 l.dim()
953 );
954 }
955 if masses.len() != m_centers {
956 crate::bail_dim_basis!(
957 "measure-jet anisotropy mass/center mismatch: {} masses for {} centers",
958 masses.len(),
959 m_centers
960 );
961 }
962 if band.eps.is_empty() || band.eps.iter().any(|e| !(e.is_finite() && *e > 0.0)) {
963 crate::bail_invalid_basis!("measure-jet anisotropy needs a nonempty positive scale band");
964 }
965 if !(order_s.is_finite() && order_s > 0.0 && order_s < 2.0) {
966 crate::bail_invalid_basis!(
967 "measure-jet order s must lie in (0, 2) for the affine-jet energy; got {order_s}"
968 );
969 }
970 if !alpha.is_finite() {
971 crate::bail_invalid_basis!("measure-jet anisotropy needs a finite alpha; got {alpha}");
972 }
973 if masses.iter().any(|v| !(v.is_finite() && *v >= 0.0)) {
974 crate::bail_invalid_basis!("measure-jet anisotropy needs finite nonnegative center masses");
975 }
976
977 let indices = lower_triangular_indices(d);
978 let n = indices.len();
979 let nf = build_normalized_factor(l, &indices)?;
980
981 let md = metric_sq_dists(centers, nf.m.view());
983
984 let mut d_first: Vec<Array2<f64>> = (0..n)
985 .map(|_| Array2::<f64>::zeros((m_centers, m_centers)))
986 .collect();
987 let mut d_second: Vec<Array2<f64>> = (0..n * n)
988 .map(|_| Array2::<f64>::zeros((m_centers, m_centers)))
989 .collect();
990
991 for &eps in &band.eps {
992 let cutoff2 = (PROFILE_CUTOFF * eps) * (PROFILE_CUTOFF * eps);
993 let intrinsic_dim = d as f64;
994 let eta = 2.0 * order_s + intrinsic_dim * (2.0 - 2.0 * alpha);
995 let scale_weight = band.log_step * eps.powf(-eta);
996 let net_radius2 = 0.25 * eps * eps;
997
998 let mut outer: Vec<usize> = Vec::new();
1002 for i in 0..m_centers {
1003 if masses[i] <= 0.0 {
1004 continue;
1005 }
1006 let covered = outer.iter().any(|&o| md.dm2[(i, o)] <= net_radius2);
1007 if !covered {
1008 outer.push(i);
1009 }
1010 }
1011 let mut net_mass = vec![0.0_f64; m_centers];
1012 for i in 0..m_centers {
1013 if masses[i] <= 0.0 {
1014 continue;
1015 }
1016 let mut best = f64::INFINITY;
1017 let mut best_o = usize::MAX;
1018 for &o in &outer {
1019 if md.dm2[(i, o)] < best {
1020 best = md.dm2[(i, o)];
1021 best_o = o;
1022 }
1023 }
1024 if best_o != usize::MAX {
1025 net_mass[best_o] += masses[i];
1026 }
1027 }
1028
1029 for &i in &outer {
1030 let mut idx: Vec<usize> = Vec::new();
1031 for j in 0..m_centers {
1032 if md.dm2[(i, j)] <= cutoff2 {
1033 idx.push(j);
1034 }
1035 }
1036 let ml = idx.len();
1037 let mut phi = Array2::<f64>::zeros((ml, d));
1039 let mut masses_local = Array1::<f64>::zeros(ml);
1040 for (a, &j) in idx.iter().enumerate() {
1041 for k in 0..d {
1042 phi[(a, k)] = (centers[(j, k)] - centers[(i, k)]) / eps;
1043 }
1044 masses_local[a] = masses[j];
1045 }
1046
1047 let q_block: f64 = idx
1050 .iter()
1051 .enumerate()
1052 .map(|(a, &j)| masses_local[a] * (-md.dm2[(i, j)] / (2.0 * eps * eps)).exp())
1053 .sum();
1054 if !(q_block > 0.0) {
1055 continue;
1056 }
1057
1058 let blk = block_residual_jets(&phi, &masses_local, nf.m.view(), &nf.dm, &nf.d2m, n);
1059
1060 let base = scale_weight * net_mass[i] * q_block.powf(1.0 - 2.0 * alpha);
1067 let beta = 1.0 - 2.0 * alpha;
1068
1069 for (a, &ja) in idx.iter().enumerate() {
1073 for (c, &jc) in idx.iter().enumerate() {
1074 for x in 0..n {
1075 let qx_over_q = blk.dq[x] / blk.q;
1076 d_first[x][(ja, jc)] +=
1077 base * (blk.dr[x][(a, c)] + beta * qx_over_q * blk.r[(a, c)]);
1078 }
1079 for x in 0..n {
1080 for y in 0..n {
1081 let qx_over_q = blk.dq[x] / blk.q;
1082 let qy_over_q = blk.dq[y] / blk.q;
1083 let qxy_over_q = blk.d2q[x * n + y] / blk.q;
1084 let density_d2 =
1085 beta * qxy_over_q + beta * (beta - 1.0) * qx_over_q * qy_over_q;
1086 d_second[x * n + y][(ja, jc)] += base
1087 * (blk.d2r[x * n + y][(a, c)]
1088 + beta * qx_over_q * blk.dr[y][(a, c)]
1089 + beta * qy_over_q * blk.dr[x][(a, c)]
1090 + density_d2 * blk.r[(a, c)]);
1091 }
1092 }
1093 }
1094 }
1095 }
1096 }
1097
1098 let y = centers.dot(&nf.m);
1105 let q = measure_jet_energy_form(y.view(), masses, band, order_s, alpha, 0.0)?;
1106
1107 let sym = |a: Array2<f64>| (&a + &a.t()) * 0.5;
1109 let d_first: Vec<Array2<f64>> = d_first.into_iter().map(sym).collect();
1110 let d_second: Vec<Array2<f64>> = d_second.into_iter().map(sym).collect();
1111
1112 Ok(MeasureJetAnisotropyJets {
1113 q,
1114 indices,
1115 d_first,
1116 d_second,
1117 })
1118}
1119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{measure_jet_band, measure_jet_energy_form};
    use ndarray::array;

    /// Scale band derived from the test centers.
    pub(crate) fn band_for(centers: &Array2<f64>) -> MeasureJetBand {
        measure_jet_band(centers.view(), 0).expect("band")
    }

    /// Two three-point clusters on opposite sides of the origin, all with mass 0.5.
    pub(crate) fn two_cluster_centers() -> (ndarray::Array2<f64>, ndarray::Array1<f64>) {
        (
            ndarray::array![
                [-0.8, -0.6],
                [-0.7, -0.5],
                [-0.6, -0.7],
                [0.8, 0.6],
                [0.7, 0.5],
                [0.6, 0.7]
            ],
            ndarray::array![0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
        )
    }

    /// With `L = I` the normalised factor is exactly the identity. The
    /// anisotropic form must therefore equal the isotropic one evaluated with
    /// the same trailing argument (`0.0`) that the anisotropic path forwards.
    #[test]
    pub(crate) fn identity_metric_reproduces_isotropic_bit_for_bit() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let (s0, a0) = (1.3, 0.8);
        let l = Array2::<f64>::eye(2);
        let q_aniso = measure_jet_anisotropy_energy_form(
            centers.view(),
            masses.view(),
            &band,
            s0,
            a0,
            l.view(),
        )
        .expect("aniso energy");
        // Must match the trailing argument used inside
        // `measure_jet_anisotropy_energy_form`; a different value compares two
        // different energies.
        let q_iso = measure_jet_energy_form(centers.view(), masses.view(), &band, s0, a0, 0.0)
            .expect("iso energy");
        assert_eq!(q_aniso.dim(), q_iso.dim());
        for (a, b) in q_aniso.iter().zip(q_iso.iter()) {
            assert_eq!(
                a.to_bits(),
                b.to_bits(),
                "Ā = I must reproduce the isotropic energy bit-for-bit: {a} vs {b}"
            );
        }
    }

    /// Checks three things against central finite differences in each free
    /// entry of `L`: the first derivatives, the diagonal second derivatives, and
    /// the mixed second derivatives. Also checks that the jet path's value is
    /// identical to the plain path, and that the Hessian blocks are symmetric.
    #[test]
    pub(crate) fn l_jets_match_finite_differences() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let (s0, a0) = (1.3, 0.8);
        let l0 = array![[1.30, 0.00], [-0.45, 0.80]];
        let jets = measure_jet_anisotropy_energy_form_with_jets(
            centers.view(),
            masses.view(),
            &band,
            s0,
            a0,
            l0.view(),
        )
        .expect("jets");

        let q_plain = measure_jet_anisotropy_energy_form(
            centers.view(),
            masses.view(),
            &band,
            s0,
            a0,
            l0.view(),
        )
        .expect("plain");
        for (a, b) in jets.q.iter().zip(q_plain.iter()) {
            assert_eq!(a.to_bits(), b.to_bits(), "value drift {a} vs {b}");
        }

        let eval = |l: &Array2<f64>| {
            measure_jet_anisotropy_energy_form(
                centers.view(),
                masses.view(),
                &band,
                s0,
                a0,
                l.view(),
            )
            .expect("energy")
        };
        let perturb = |idx: LIndex, delta: f64| {
            let mut l = l0.clone();
            l[(idx.row, idx.col)] += delta;
            l
        };

        let h = 1e-4;
        let n = jets.n_active();

        // First derivatives.
        for a in 0..n {
            let ia = jets.indices[a];
            let plus = eval(&perturb(ia, h));
            let minus = eval(&perturb(ia, -h));
            let fd = (&plus - &minus) / (2.0 * h);
            let scale = fd.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()));
            for (an, fdv) in jets.d_first[a].iter().zip(fd.iter()) {
                assert!(
                    (an - fdv).abs() <= 5e-5 * scale,
                    "∂Q/∂L[{},{}] mismatch: analytic {an:.6e} vs FD {fdv:.6e} (scale {scale:.3e})",
                    ia.row,
                    ia.col
                );
            }
        }

        // Diagonal second derivatives; the unperturbed value is loop-invariant.
        let center = eval(&l0);
        for a in 0..n {
            let ia = jets.indices[a];
            let plus = eval(&perturb(ia, h));
            let minus = eval(&perturb(ia, -h));
            let fd = (&(&plus + &minus) - &(&center * 2.0)) / (h * h);
            let scale = fd.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()));
            for (an, fdv) in jets.second(a, a).iter().zip(fd.iter()) {
                assert!(
                    (an - fdv).abs() <= 5e-5 * scale,
                    "∂²Q/∂L[{},{}]² mismatch: analytic {an:.6e} vs FD {fdv:.6e} (scale {scale:.3e})",
                    ia.row,
                    ia.col
                );
            }
        }

        // Mixed second derivatives and Hessian symmetry.
        for a in 0..n {
            let ia = jets.indices[a];
            for b in (a + 1)..n {
                let ib = jets.indices[b];
                let mut lpp = l0.clone();
                lpp[(ia.row, ia.col)] += h;
                lpp[(ib.row, ib.col)] += h;
                let mut lpm = l0.clone();
                lpm[(ia.row, ia.col)] += h;
                lpm[(ib.row, ib.col)] -= h;
                let mut lmp = l0.clone();
                lmp[(ia.row, ia.col)] -= h;
                lmp[(ib.row, ib.col)] += h;
                let mut lmm = l0.clone();
                lmm[(ia.row, ia.col)] -= h;
                lmm[(ib.row, ib.col)] -= h;
                let pp = eval(&lpp);
                let pm = eval(&lpm);
                let mp = eval(&lmp);
                let mm = eval(&lmm);
                let fd = (&(&pp - &pm) - &(&mp - &mm)) / (4.0 * h * h);
                let scale = fd.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()));
                for (an, fdv) in jets.second(a, b).iter().zip(fd.iter()) {
                    assert!(
                        (an - fdv).abs() <= 5e-5 * scale,
                        "∂²Q/∂L[{},{}]∂L[{},{}] mismatch: analytic {an:.6e} vs FD {fdv:.6e} (scale {scale:.3e})",
                        ia.row,
                        ia.col,
                        ib.row,
                        ib.col
                    );
                }
                for (sab, sba) in jets.second(a, b).iter().zip(jets.second(b, a).iter()) {
                    assert!((sab - sba).abs() <= 1e-12 * (1.0 + sab.abs()));
                }
            }
        }
    }

    /// Scaling `L` by any positive constant leaves the determinant-normalised
    /// energy unchanged.
    #[test]
    pub(crate) fn det_normalization_is_scale_invariant() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let (s0, a0) = (1.1, 0.9);
        let l0 = array![[0.90, 0.00], [0.35, 1.40]];
        let q_ref = measure_jet_anisotropy_energy_form(
            centers.view(),
            masses.view(),
            &band,
            s0,
            a0,
            l0.view(),
        )
        .expect("ref");
        let scale = q_ref.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        assert!(scale > 0.0, "energy is identically zero");
        for &c in &[0.25_f64, 0.5, 2.0, 7.5] {
            let lc = &l0 * c;
            let q_c = measure_jet_anisotropy_energy_form(
                centers.view(),
                masses.view(),
                &band,
                s0,
                a0,
                lc.view(),
            )
            .expect("scaled");
            for (a, b) in q_c.iter().zip(q_ref.iter()) {
                assert!(
                    (a - b).abs() <= 1e-10 * scale,
                    "scale c = {c} changed the normalized energy: {a:.6e} vs {b:.6e}"
                );
            }
        }
    }

    /// The anisotropic energy form annihilates constant vectors (`Q·1 = 0`).
    #[test]
    pub(crate) fn anisotropic_energy_annihilates_constants() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let l = array![[1.20, 0.00], [-0.30, 0.95]];
        let q = measure_jet_anisotropy_energy_form(
            centers.view(),
            masses.view(),
            &band,
            1.5,
            1.0,
            l.view(),
        )
        .expect("energy");
        let m = q.nrows();
        let ones = Array1::<f64>::ones(m);
        let qv = q.dot(&ones);
        let scale = q.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        assert!(scale > 0.0, "energy is identically zero");
        for (i, v) in qv.iter().enumerate() {
            assert!(
                v.abs() <= 1e-10 * scale,
                "Q·1 leak at row {i}: {v:.3e} vs scale {scale:.3e}"
            );
        }
    }
}
1378}