1use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
115use rayon::prelude::*;
116use serde::{Deserialize, Serialize};
117
118use faer::Side;
119
120use gam_linalg::faer_ndarray::FaerEigh;
121
122use super::{
123 AnisoBasisPsiDerivatives, AnisoPenaltyCrossProvider, BasisBuildResult, BasisError,
124 BasisMetadata, CenterStrategy, PenaltyCandidate, PenaltySource,
125 filter_active_penalty_candidates_with_ops, normalize_penalty,
126 normalize_penalty_cross_psi_derivative, normalize_penaltywith_psi_derivatives,
127 select_centers_by_strategy, trace_of_product,
128};
129
/// Truncation radius of the local Gaussian profile, in units of the scale
/// `eps`: when a local neighborhood is assembled, a center is included only if
/// its squared distance is at most `(MEASURE_JET_PROFILE_CUTOFF * eps)^2`.
pub(crate) const MEASURE_JET_PROFILE_CUTOFF: f64 = 3.0;

/// Relative eigenvalue tolerance for the pseudo-inverses in this file. The
/// absolute cutoff is this value times the matrix order times the largest
/// (nonnegative) eigenvalue; eigenvalues at or below it are treated as zero.
pub(crate) const MEASURE_JET_PSEUDOINVERSE_RTOL: f64 = 64.0 * f64::EPSILON;

/// Order `s` substituted when a spec leaves `order_s` at `0.0`.
pub(crate) const MEASURE_JET_DEFAULT_ORDER_S: f64 = 1.5;

/// Lower clamp on the automatically chosen number of band scales.
pub(crate) const MEASURE_JET_MIN_AUTO_SCALES: usize = 3;
/// Upper clamp on the automatically chosen number of band scales.
pub(crate) const MEASURE_JET_MAX_AUTO_SCALES: usize = 8;

/// Multiplier applied to the median nearest-center spacing when the design
/// length scale is requested as automatic (`length_scale == 0.0`).
pub(crate) const MEASURE_JET_AUTO_LENGTH_SCALE_FACTOR: f64 = 1.0;

/// Relative weight of the affine-preserving ridge when it is fused into the
/// single primary penalty: the ridge is scaled by
/// `fraction * ||penalty||_F / ||ridge||_F` before being added.
pub(crate) const MEASURE_JET_FUSED_RIDGE_FRACTION: f64 = 1e-2;

/// Budget, in `f64` entries, for assembling all scales in parallel. Per-scale
/// assembly runs in parallel only while
/// `m * m * n_scales * n_forms` does not exceed this value.
pub(crate) const MEASURE_JET_PARALLEL_FORM_BUDGET_DOUBLES: usize = 1 << 26;
187
/// How the center coefficients are constrained when the geometry is realized.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum MeasureJetIdentifiability {
    /// Default. The constraint transform is built from the Householder
    /// sum-to-zero construction over the centers.
    #[default]
    CenterSumToZero,
    /// Reuse a stored constraint transform verbatim. Its row count must equal
    /// the number of centers; a mismatch is reported as a dimension error.
    FrozenTransform { transform: Array2<f64> },
}
204
/// Stored quadrature that replaces the data-driven mass assignment and scale
/// band when the geometry is realized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasureJetFrozenQuadrature {
    /// Per-center masses; must contain exactly one entry per selected center.
    pub masses: Array1<f64>,
    /// Scale band; must be non-empty. With two or more entries the log step is
    /// recovered as `ln(eps_band[1] / eps_band[0])`, otherwise `ln 2` is used.
    pub eps_band: Vec<f64>,
    /// NOTE(review): presumably the per-scale support means recorded at fit
    /// time; not read in the visible part of this file — confirm against callers.
    pub support_means: Vec<f64>,
    /// NOTE(review): presumably the per-level penalty normalization scales
    /// recorded at fit time; not read in the visible part of this file.
    pub penalty_normalization_scales: Vec<f64>,
    /// NOTE(review): presumably the per-level normalization scales with the
    /// scale weight divided out; not read in the visible part of this file.
    pub raw_penalty_normalization_scales: Vec<f64>,
    /// NOTE(review): presumably the normalization scale of the fused primary
    /// penalty, when one was built; not read in the visible part of this file.
    pub fused_penalty_normalization_scale: Option<f64>,
}
228
/// Serde default for `MeasureJetBasisSpec::learn_length_scale`: specs that omit
/// the field deserialize with length-scale learning disabled.
fn measure_jet_learn_length_scale_default() -> bool {
    false
}
235
/// User-facing specification of a measure-jet smooth.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasureJetBasisSpec {
    /// Strategy handed to `select_centers_by_strategy` to pick the seed centers.
    pub center_strategy: CenterStrategy,
    /// Order `s` of the energy. `0.0` selects `MEASURE_JET_DEFAULT_ORDER_S`;
    /// the assembly requires the effective value to lie in `(0, 2)`.
    pub order_s: f64,
    /// Exponent `alpha`; enters the per-scale weight through
    /// `eta = 2 s + d (2 - 2 alpha)` and the local factor `q^(1 - 2 alpha)`.
    pub alpha: f64,
    /// Must be finite and `>= 0` for the energy form, and `> 0` wherever the
    /// jets / ψ derivatives are requested.
    pub tau0: f64,
    /// Number of band scales; `0` lets `measure_jet_band` choose automatically.
    pub num_scales: usize,
    /// Design kernel length scale; `0.0` selects the automatic value derived
    /// from the median nearest-center spacing.
    pub length_scale: f64,
    /// Adds the affine-preserving ridge: fused into the primary penalty in the
    /// single-penalty mode, or as a separate candidate in the multiscale mode.
    pub double_penalty: bool,
    /// NOTE(review): not read in the visible part of this file; presumably
    /// toggles optimizing the length scale — confirm against callers.
    #[serde(default = "measure_jet_learn_length_scale_default")]
    pub learn_length_scale: bool,
    /// When `true`, one penalty candidate is emitted per band scale instead of
    /// a single fused penalty.
    #[serde(default)]
    pub multiscale: bool,
    /// How the center coefficients are constrained.
    #[serde(default)]
    pub identifiability: MeasureJetIdentifiability,
    /// When present, replaces the data-driven masses and scale band.
    #[serde(default)]
    pub frozen_quadrature: Option<MeasureJetFrozenQuadrature>,
}
290
291impl Default for MeasureJetBasisSpec {
292 fn default() -> Self {
293 Self {
294 center_strategy: CenterStrategy::FarthestPoint { num_centers: 50 },
295 order_s: 0.0,
296 alpha: 1.0,
308 tau0: 1e-3,
309 num_scales: 0,
310 length_scale: 0.0,
311 double_penalty: true,
312 learn_length_scale: false,
313 multiscale: false,
314 identifiability: MeasureJetIdentifiability::CenterSumToZero,
315 frozen_quadrature: None,
316 }
317 }
318}
319
/// Band of scales over which the measure-jet energy is assembled.
pub struct MeasureJetBand {
    /// The scales; the assembly requires every entry to be finite and positive.
    pub eps: Vec<f64>,
    /// Log-spacing factor multiplied into every scale's weight
    /// (`log_step * eps^(-eta)`). `measure_jet_band` sets it to the log of the
    /// geometric ratio, or `ln 2` for a single-scale band.
    pub log_step: f64,
}
326
/// Energy form together with its first and second derivative forms, as
/// returned by `measure_jet_energy_form_with_jets`. Every field is an
/// `m x m` symmetrized matrix over the centers.
pub struct MeasureJetEnergyJets {
    /// The energy form itself.
    pub q: Array2<f64>,
    /// First derivative with respect to the order `s`.
    pub dq_ds: Array2<f64>,
    /// Second derivative with respect to `s`.
    pub d2q_ds2: Array2<f64>,
    /// First derivative with respect to `alpha`.
    pub dq_dalpha: Array2<f64>,
    /// Second derivative with respect to `alpha`.
    pub d2q_dalpha2: Array2<f64>,
    /// Mixed `s` / `alpha` derivative.
    pub d2q_ds_dalpha: Array2<f64>,
    /// First derivative with respect to `ln tau`. The visible assembly gives
    /// this form zero weight, so it is returned as a zero matrix.
    pub dq_dlogtau: Array2<f64>,
    /// Second derivative with respect to `ln tau` (zero weight, see above).
    pub d2q_dlogtau2: Array2<f64>,
    /// Mixed `s` / `ln tau` derivative (zero weight, see above).
    pub d2q_ds_dlogtau: Array2<f64>,
    /// Mixed `alpha` / `ln tau` derivative (zero weight, see above).
    pub d2q_dalpha_dlogtau: Array2<f64>,
}
344
345pub(crate) fn householder_sum_to_zero_u(m: usize) -> Array1<f64> {
353 let c = 1.0 / (m as f64).sqrt();
354 let mut u = Array1::<f64>::from_elem(m, c);
355 u[0] -= 1.0;
356 let norm = u.dot(&u).sqrt();
357 u.mapv_inplace(|v| v / norm);
358 u
359}
360
361pub(crate) fn householder_sum_to_zero_z(u: &Array1<f64>) -> Array2<f64> {
365 let m = u.len();
366 let mut z = Array2::<f64>::zeros((m, m - 1));
367 for j in 0..(m - 1) {
368 for i in 0..m {
369 let h = if i == j + 1 { 1.0 } else { 0.0 } - 2.0 * u[i] * u[j + 1];
370 z[(i, j)] = h;
371 }
372 }
373 z
374}
375
376pub(crate) fn symmetric_pseudoinverse(
377 a: &Array2<f64>,
378 label: &str,
379) -> Result<Array2<f64>, BasisError> {
380 let n = a.nrows();
381 if a.ncols() != n {
382 crate::bail_dim_basis!(
383 "measure-jet pseudo-inverse `{label}` needs a square matrix, got {:?}",
384 a.dim()
385 );
386 }
387 let (evals, evecs) = a.eigh(Side::Lower).map_err(|e| {
388 BasisError::InvalidInput(format!(
389 "measure-jet pseudo-inverse `{label}` eigendecomposition failed: {e}"
390 ))
391 })?;
392 let lam_max = evals.iter().fold(0.0_f64, |acc, v| acc.max((*v).max(0.0)));
393 let rank_tol = MEASURE_JET_PSEUDOINVERSE_RTOL * (n.max(1) as f64) * lam_max;
394 let mut scaled = evecs.clone();
395 for (k, mut col) in scaled.axis_iter_mut(Axis(1)).enumerate() {
396 let lam = evals[k].max(0.0);
397 let inv = if lam > rank_tol { 1.0 / lam } else { 0.0 };
398 col.mapv_inplace(|v| v * inv);
399 }
400 Ok(scaled.dot(&evecs.t()))
401}
402
403pub(crate) fn affine_preserving_coefficient_ridge(
411 kz: &Array2<f64>,
412 centers: ArrayView2<'_, f64>,
413 masses: ArrayView1<'_, f64>,
414) -> Result<Array2<f64>, BasisError> {
415 let m = centers.nrows();
416 let d = centers.ncols();
417 let p = kz.ncols();
418 if kz.nrows() != m || masses.len() != m {
419 crate::bail_dim_basis!(
420 "measure-jet affine-preserving ridge shape mismatch: kz {:?}, centers {:?}, masses {}",
421 kz.dim(),
422 centers.dim(),
423 masses.len()
424 );
425 }
426 if p == 0 {
427 return Ok(Array2::<f64>::zeros((0, 0)));
428 }
429 let mut weighted_kz = kz.clone();
430 for (i, mut row) in weighted_kz.outer_iter_mut().enumerate() {
431 row.mapv_inplace(|v| v * masses[i]);
432 }
433 let normal = kz.t().dot(&weighted_kz);
434 let normal_pinv = symmetric_pseudoinverse(&normal, "affine ridge normal")?;
435 let mut affine = Array2::<f64>::ones((m, d + 1));
436 for i in 0..m {
437 for k in 0..d {
438 affine[(i, k + 1)] = centers[(i, k)];
439 }
440 }
441 let mut weighted_affine = affine.clone();
442 for (i, mut row) in weighted_affine.outer_iter_mut().enumerate() {
443 row.mapv_inplace(|v| v * masses[i]);
444 }
445 let rhs = kz.t().dot(&weighted_affine);
446 let beta = normal_pinv.dot(&rhs);
447 let beta_gram = beta.t().dot(&beta);
448 let (evals, evecs) = beta_gram.eigh(Side::Lower).map_err(|e| {
449 BasisError::InvalidInput(format!(
450 "measure-jet affine ridge subspace eigendecomposition failed: {e}"
451 ))
452 })?;
453 let lam_max = evals.iter().fold(0.0_f64, |acc, v| acc.max((*v).max(0.0)));
454 let rank_tol = MEASURE_JET_PSEUDOINVERSE_RTOL * ((d + 1).max(1) as f64) * lam_max;
455 let mut ridge = Array2::<f64>::eye(p);
456 for k in 0..(d + 1) {
457 let lam = evals[k].max(0.0);
458 if lam <= rank_tol {
459 continue;
460 }
461 let dir = beta.dot(&evecs.column(k).to_owned()) / lam.sqrt();
462 for r in 0..p {
463 for c in 0..p {
464 ridge[(r, c)] -= dir[r] * dir[c];
465 }
466 }
467 }
468 Ok((&ridge + &ridge.t()) * 0.5)
469}
470
471pub(crate) fn pairwise_sq_dists(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
482 let an: Vec<f64> = a.outer_iter().map(|r| r.dot(&r)).collect();
483 let bn: Vec<f64> = b.outer_iter().map(|r| r.dot(&r)).collect();
484 let mut g = a.dot(&b.t());
485 g.axis_iter_mut(Axis(0))
486 .into_par_iter()
487 .enumerate()
488 .for_each(|(i, mut row)| {
489 for (j, v) in row.iter_mut().enumerate() {
490 *v = (an[i] + bn[j] - 2.0 * *v).max(0.0);
491 }
492 });
493 g
494}
495
/// Number of data rows handled per block during nearest-center assignment;
/// each block forms one `rows x centers` product at a time.
pub(crate) const MEASURE_JET_ASSIGN_BLOCK_ROWS: usize = 65_536;
500
501pub(crate) fn validate_finite_points(
502 points: ArrayView2<'_, f64>,
503 what: &str,
504) -> Result<(), BasisError> {
505 for (i, row) in points.outer_iter().enumerate() {
506 if row.iter().any(|v| !v.is_finite()) {
507 crate::bail_invalid_basis!("measure-jet {what} row {i} has a non-finite coordinate");
508 }
509 }
510 Ok(())
511}
512
513pub(crate) fn median_nearest_center_spacing(dist2: &Array2<f64>) -> Result<f64, BasisError> {
516 let m = dist2.nrows();
517 if m < 2 {
518 return Err(BasisError::InsufficientColumnsForConstraint { found: m });
519 }
520 let mut nearest: Vec<f64> = Vec::with_capacity(m);
521 for i in 0..m {
522 let mut best = f64::INFINITY;
523 for j in 0..m {
524 if j != i && dist2[(i, j)] < best {
525 best = dist2[(i, j)];
526 }
527 }
528 nearest.push(best.sqrt());
529 }
530 nearest.sort_by(|a, b| a.partial_cmp(b).expect("finite center spacings"));
531 let median = nearest[nearest.len() / 2];
532 if !(median.is_finite() && median > 0.0) {
533 crate::bail_invalid_basis!(
534 "measure-jet centers are degenerate (median nearest-center spacing = {median}); \
535 duplicate centers cannot carry a scale band"
536 );
537 }
538 Ok(median)
539}
540
541pub fn measure_jet_band(
549 centers: ArrayView2<'_, f64>,
550 num_scales: usize,
551) -> Result<MeasureJetBand, BasisError> {
552 validate_finite_points(centers, "centers")?;
553 let dist2 = pairwise_sq_dists(centers, centers);
554 let eps_min = median_nearest_center_spacing(&dist2)?;
555 let d = centers.ncols();
557 let mut diag2 = 0.0_f64;
558 for k in 0..d {
559 let col = centers.column(k);
560 let mut lo = f64::INFINITY;
561 let mut hi = f64::NEG_INFINITY;
562 for &v in col.iter() {
563 lo = lo.min(v);
564 hi = hi.max(v);
565 }
566 diag2 += (hi - lo) * (hi - lo);
567 }
568 let eps_max = 0.5 * diag2.sqrt();
569 if !(eps_max.is_finite() && eps_max > eps_min) {
570 return Ok(MeasureJetBand {
571 eps: vec![eps_min],
572 log_step: std::f64::consts::LN_2,
573 });
574 }
575 let auto = ((eps_max / eps_min).log2().ceil() as usize + 1)
576 .clamp(MEASURE_JET_MIN_AUTO_SCALES, MEASURE_JET_MAX_AUTO_SCALES);
577 let count = if num_scales == 0 { auto } else { num_scales };
578 if count == 1 {
579 return Ok(MeasureJetBand {
580 eps: vec![eps_min],
581 log_step: std::f64::consts::LN_2,
582 });
583 }
584 let ratio = (eps_max / eps_min).powf(1.0 / (count as f64 - 1.0));
585 let mut eps = Vec::with_capacity(count);
586 let mut e = eps_min;
587 for _ in 0..count {
588 eps.push(e);
589 e *= ratio;
590 }
591 Ok(MeasureJetBand {
592 eps,
593 log_step: ratio.ln(),
594 })
595}
596
597pub fn measure_jet_quadrature_nodes(
604 data: ArrayView2<'_, f64>,
605 centers: ArrayView2<'_, f64>,
606) -> Result<(Array2<f64>, Array1<f64>), BasisError> {
607 if data.ncols() != centers.ncols() {
608 crate::bail_dim_basis!(
609 "measure-jet mass assignment dimension mismatch: data d={} centers d={}",
610 data.ncols(),
611 centers.ncols()
612 );
613 }
614 validate_finite_points(data, "data")?;
615 validate_finite_points(centers, "centers")?;
616 let n = data.nrows();
617 let m = centers.nrows();
618 let d = centers.ncols();
619 if n == 0 || m == 0 {
620 crate::bail_invalid_basis!("measure-jet mass assignment needs nonempty data and centers");
621 }
622 let cn: Vec<f64> = centers.outer_iter().map(|r| r.dot(&r)).collect();
627 let assignments: Vec<usize> = (0..n)
628 .step_by(MEASURE_JET_ASSIGN_BLOCK_ROWS)
629 .flat_map(|start| {
630 let end = (start + MEASURE_JET_ASSIGN_BLOCK_ROWS).min(n);
631 let g = data.slice(ndarray::s![start..end, ..]).dot(¢ers.t());
632 let block: Vec<usize> = g
633 .axis_iter(Axis(0))
634 .into_par_iter()
635 .map(|row| {
636 let mut best_j = 0usize;
637 let mut best = f64::INFINITY;
638 for (j, &gij) in row.iter().enumerate() {
639 let s = cn[j] - 2.0 * gij;
640 if s < best {
641 best = s;
642 best_j = j;
643 }
644 }
645 best_j
646 })
647 .collect();
648 block
649 })
650 .collect();
651 let mut masses = Array1::<f64>::zeros(m);
652 let mut nodes = centers.to_owned();
653 let mut sums = Array2::<f64>::zeros((m, d));
654 let unit = 1.0 / n as f64;
655 for (i, &j) in assignments.iter().enumerate() {
656 masses[j] += unit;
657 for k in 0..d {
658 sums[(j, k)] += data[(i, k)];
659 }
660 }
661 let mut barycenter = sums;
664 for j in 0..m {
665 let count = masses[j] * n as f64;
666 if count > 0.0 {
667 for k in 0..d {
668 barycenter[(j, k)] /= count;
669 nodes[(j, k)] = barycenter[(j, k)];
670 }
671 }
672 }
673 Ok((nodes, masses))
674}
675
676pub fn measure_jet_center_masses(
679 data: ArrayView2<'_, f64>,
680 centers: ArrayView2<'_, f64>,
681) -> Result<Array1<f64>, BasisError> {
682 measure_jet_quadrature_nodes(data, centers).map(|(_, masses)| masses)
683}
684
685pub(crate) fn assemble_weighted_forms<F>(
708 centers: ArrayView2<'_, f64>,
709 masses: ArrayView1<'_, f64>,
710 band: &MeasureJetBand,
711 order_s: f64,
712 alpha: f64,
713 tau0: f64,
714 n_forms: usize,
715 channels: usize,
716 weights: &F,
717) -> Result<Vec<Array2<f64>>, BasisError>
718where
719 F: Fn(usize, f64, f64, f64, &mut [[f64; 3]]) + Sync,
720{
721 let m = centers.nrows();
722 let d = centers.ncols();
723 if n_forms == 0 || !(1..=3).contains(&channels) {
724 crate::bail_invalid_basis!(
725 "measure-jet assembly needs at least one output form and 1..=3 block channels"
726 );
727 }
728 if masses.len() != m {
729 crate::bail_dim_basis!(
730 "measure-jet energy mass/center mismatch: {} masses for {} centers",
731 masses.len(),
732 m
733 );
734 }
735 if band.eps.is_empty() || band.eps.iter().any(|e| !(e.is_finite() && *e > 0.0)) {
736 crate::bail_invalid_basis!("measure-jet energy needs a nonempty positive scale band");
737 }
738 if !(order_s.is_finite() && order_s > 0.0 && order_s < 2.0) {
739 crate::bail_invalid_basis!(
740 "measure-jet order s must lie in (0, 2) for the affine-jet energy; got {order_s}"
741 );
742 }
743 if !(alpha.is_finite() && tau0.is_finite() && tau0 >= 0.0) {
744 crate::bail_invalid_basis!(
745 "measure-jet energy needs finite alpha and finite tau0 >= 0; got alpha={alpha}, tau0={tau0}"
746 );
747 }
748 if masses.iter().any(|v| !(v.is_finite() && *v >= 0.0)) {
749 crate::bail_invalid_basis!("measure-jet energy needs finite nonnegative center masses");
750 }
751 let dist2 = pairwise_sq_dists(centers, centers);
752
753 let assemble_scale = |scale_idx: usize, eps: f64| -> Result<Vec<Array2<f64>>, BasisError> {
758 let mut out: Vec<Array2<f64>> =
759 (0..n_forms).map(|_| Array2::<f64>::zeros((m, m))).collect();
760 let cutoff2 = (MEASURE_JET_PROFILE_CUTOFF * eps) * (MEASURE_JET_PROFILE_CUTOFF * eps);
761 let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
762 let eta = 2.0 * order_s + (d as f64) * (2.0 - 2.0 * alpha);
763 let scale_weight = band.log_step * eps.powf(-eta);
764 let net_radius2 = 0.25 * eps * eps;
768 let mut outer: Vec<usize> = Vec::new();
769 for i in 0..m {
770 if masses[i] <= 0.0 {
771 continue;
772 }
773 let covered = outer.iter().any(|&o| dist2[(i, o)] <= net_radius2);
774 if !covered {
775 outer.push(i);
776 }
777 }
778 let mut net_mass = vec![0.0_f64; m];
779 for i in 0..m {
780 if masses[i] <= 0.0 {
781 continue;
782 }
783 let mut best = f64::INFINITY;
784 let mut best_o = usize::MAX;
785 for &o in &outer {
786 if dist2[(i, o)] < best {
787 best = dist2[(i, o)];
788 best_o = o;
789 }
790 }
791 if best_o != usize::MAX {
792 net_mass[best_o] += masses[i];
793 }
794 }
795 let mut wbuf = vec![[0.0_f64; 3]; n_forms];
796 for &i in &outer {
797 let mut idx: Vec<usize> = Vec::new();
799 for j in 0..m {
800 if dist2[(i, j)] <= cutoff2 {
801 idx.push(j);
802 }
803 }
804 let ml = idx.len();
805 let mut w = Array1::<f64>::zeros(ml);
807 let mut q = 0.0_f64;
808 for (a, &j) in idx.iter().enumerate() {
809 let wj = masses[j] * (-dist2[(i, j)] * inv_two_eps2).exp();
810 w[a] = wj;
811 q += wj;
812 }
813 if !(q > 0.0) {
814 continue;
815 }
816 let mut phi = Array2::<f64>::zeros((ml, d));
818 for (a, &j) in idx.iter().enumerate() {
819 for k in 0..d {
820 phi[(a, k)] = (centers[(j, k)] - centers[(i, k)]) / eps;
821 }
822 }
823 let a_mean = phi.t().dot(&w) / q;
824 let mut wphi = phi.clone();
826 for (a, mut row) in wphi.outer_iter_mut().enumerate() {
827 row.mapv_inplace(|v| v * w[a]);
828 }
829 let mut b = wphi.clone();
830 for (a, mut row) in b.outer_iter_mut().enumerate() {
831 for k in 0..d {
832 row[k] -= w[a] * a_mean[k];
833 }
834 }
835 let mut g = phi.t().dot(&wphi);
836 g.mapv_inplace(|v| v / q);
837 for r in 0..d {
838 for c in 0..d {
839 g[(r, c)] -= a_mean[r] * a_mean[c];
840 }
841 }
842 let g_pinv = symmetric_pseudoinverse(&g, "local affine Gram")?;
843 let bm = b.dot(&g_pinv);
844 let base = scale_weight * net_mass[i] * q.powf(1.0 - 2.0 * alpha);
845 weights(scale_idx, eps, q, base, &mut wbuf);
846 for (a, &ja) in idx.iter().enumerate() {
849 let bma = bm.row(a);
850 for (c, &jc) in idx.iter().enumerate() {
851 let b_c = b.row(c);
852 let mut val_r = -w[a] * w[c] / q - bma.dot(&b_c) / q;
853 if a == c {
854 val_r += w[a];
855 }
856 for (k, out_k) in out.iter_mut().enumerate() {
857 let wk = wbuf[k];
858 out_k[(ja, jc)] += wk[0] * val_r;
859 }
860 }
861 }
862 }
863 Ok(out)
864 };
865
866 let n_scales = band.eps.len();
867 let parallel_ok = m
868 .saturating_mul(m)
869 .saturating_mul(n_scales)
870 .saturating_mul(n_forms)
871 <= MEASURE_JET_PARALLEL_FORM_BUDGET_DOUBLES;
872 let per_scale: Vec<Vec<Array2<f64>>> = if parallel_ok {
873 band.eps
874 .par_iter()
875 .enumerate()
876 .map(|(scale_idx, &eps)| assemble_scale(scale_idx, eps))
877 .collect::<Result<Vec<_>, BasisError>>()?
878 } else {
879 band.eps
880 .iter()
881 .enumerate()
882 .map(|(scale_idx, &eps)| assemble_scale(scale_idx, eps))
883 .collect::<Result<Vec<_>, BasisError>>()?
884 };
885
886 let mut totals: Vec<Array2<f64>> = (0..n_forms).map(|_| Array2::<f64>::zeros((m, m))).collect();
887 for scale_forms in per_scale {
888 for (total, part) in totals.iter_mut().zip(scale_forms) {
889 *total += ∂
890 }
891 }
892 Ok(totals.into_iter().map(|t| (&t + &t.t()) * 0.5).collect())
894}
895
896pub fn measure_jet_energy_form(
911 centers: ArrayView2<'_, f64>,
912 masses: ArrayView1<'_, f64>,
913 band: &MeasureJetBand,
914 order_s: f64,
915 alpha: f64,
916 tau0: f64,
917) -> Result<Array2<f64>, BasisError> {
918 let mut forms = assemble_weighted_forms(
919 centers,
920 masses,
921 band,
922 order_s,
923 alpha,
924 tau0,
925 1,
926 1,
927 &|_, _, _, base, out: &mut [[f64; 3]]| out[0] = [base, 0.0, 0.0],
928 )?;
929 let q = forms.swap_remove(0);
930 project_symmetric_psd(q, "measure-jet energy form")
938}
939
940pub(crate) fn project_symmetric_psd(
946 a: Array2<f64>,
947 label: &str,
948) -> Result<Array2<f64>, BasisError> {
949 let n = a.nrows();
950 if n == 0 {
951 return Ok(a);
952 }
953 let (evals, evecs) = a.eigh(Side::Lower).map_err(|e| {
954 BasisError::InvalidInput(format!(
955 "measure-jet PSD projection `{label}` eigendecomposition failed: {e}"
956 ))
957 })?;
958 if evals.iter().all(|&lam| lam >= 0.0) {
959 return Ok(a);
960 }
961 let mut scaled = evecs.clone();
962 for (k, mut col) in scaled.axis_iter_mut(Axis(1)).enumerate() {
963 let lam = evals[k].max(0.0);
964 col.mapv_inplace(|v| v * lam);
965 }
966 let psd = scaled.dot(&evecs.t());
967 Ok((&psd + &psd.t()) * 0.5)
968}
969
970pub fn measure_jet_energy_form_with_jets(
985 centers: ArrayView2<'_, f64>,
986 masses: ArrayView1<'_, f64>,
987 band: &MeasureJetBand,
988 order_s: f64,
989 alpha: f64,
990 tau0: f64,
991) -> Result<MeasureJetEnergyJets, BasisError> {
992 if !(tau0.is_finite() && tau0 > 0.0) {
993 crate::bail_invalid_basis!(
994 "measure-jet jets need tau0 > 0 because the retained τ coordinate is ln τ; got {tau0}"
995 );
996 }
997 let mut forms = assemble_weighted_forms(
998 centers,
999 masses,
1000 band,
1001 order_s,
1002 alpha,
1003 tau0,
1004 10,
1005 3,
1006 &|_, eps: f64, q: f64, base: f64, out: &mut [[f64; 3]]| {
1007 let gs = -2.0 * eps.ln();
1008 let intrinsic_dim = centers.ncols() as f64;
1009 let ga = 2.0 * intrinsic_dim * eps.ln() - 2.0 * q.max(f64::MIN_POSITIVE).ln();
1010 out[0] = [base, 0.0, 0.0];
1011 out[1] = [gs * base, 0.0, 0.0];
1012 out[2] = [gs * gs * base, 0.0, 0.0];
1013 out[3] = [ga * base, 0.0, 0.0];
1014 out[4] = [ga * ga * base, 0.0, 0.0];
1015 out[5] = [gs * ga * base, 0.0, 0.0];
1016 out[6] = [0.0, 0.0, 0.0];
1017 out[7] = [0.0, 0.0, 0.0];
1018 out[8] = [0.0, 0.0, 0.0];
1019 out[9] = [0.0, 0.0, 0.0];
1020 },
1021 )?;
1022 let d2q_dalpha_dlogtau = forms.pop().expect("ten assembled forms");
1023 let d2q_ds_dlogtau = forms.pop().expect("ten assembled forms");
1024 let d2q_dlogtau2 = forms.pop().expect("ten assembled forms");
1025 let dq_dlogtau = forms.pop().expect("ten assembled forms");
1026 let d2q_ds_dalpha = forms.pop().expect("ten assembled forms");
1027 let d2q_dalpha2 = forms.pop().expect("ten assembled forms");
1028 let dq_dalpha = forms.pop().expect("ten assembled forms");
1029 let d2q_ds2 = forms.pop().expect("ten assembled forms");
1030 let dq_ds = forms.pop().expect("ten assembled forms");
1031 let q = forms.pop().expect("ten assembled forms");
1032 Ok(MeasureJetEnergyJets {
1033 q,
1034 dq_ds,
1035 d2q_ds2,
1036 dq_dalpha,
1037 d2q_dalpha2,
1038 d2q_ds_dalpha,
1039 dq_dlogtau,
1040 d2q_dlogtau2,
1041 d2q_ds_dlogtau,
1042 d2q_dalpha_dlogtau,
1043 })
1044}
1045
1046pub fn measure_jet_scale_spectrum(
1052 centers: ArrayView2<'_, f64>,
1053 masses: ArrayView1<'_, f64>,
1054 band: &MeasureJetBand,
1055 order_s: f64,
1056 alpha: f64,
1057 tau0: f64,
1058 values: ArrayView1<'_, f64>,
1059) -> Result<Vec<f64>, BasisError> {
1060 if values.len() != centers.nrows() {
1061 crate::bail_dim_basis!(
1062 "measure-jet scale spectrum needs one value per center: {} values for {} centers",
1063 values.len(),
1064 centers.nrows()
1065 );
1066 }
1067 let forms = measure_jet_energy_forms_per_scale(centers, masses, band, order_s, alpha, tau0)?;
1068 Ok(forms
1069 .iter()
1070 .map(|q_l| values.dot(&q_l.dot(&values)))
1071 .collect())
1072}
1073
1074pub fn measure_jet_energy_forms_per_scale(
1080 centers: ArrayView2<'_, f64>,
1081 masses: ArrayView1<'_, f64>,
1082 band: &MeasureJetBand,
1083 order_s: f64,
1084 alpha: f64,
1085 tau0: f64,
1086) -> Result<Vec<Array2<f64>>, BasisError> {
1087 let n_scales = band.eps.len();
1088 assemble_weighted_forms(
1089 centers,
1090 masses,
1091 band,
1092 order_s,
1093 alpha,
1094 tau0,
1095 n_scales,
1096 1,
1097 &|scale_idx, _, _, base, out: &mut [[f64; 3]]| {
1098 for (k, slot) in out.iter_mut().enumerate() {
1099 *slot = if k == scale_idx {
1100 [base, 0.0, 0.0]
1101 } else {
1102 [0.0, 0.0, 0.0]
1103 };
1104 }
1105 },
1106 )
1107}
1108
1109pub fn measure_jet_support_curve(
1117 queries: ArrayView2<'_, f64>,
1118 centers: ArrayView2<'_, f64>,
1119 masses: ArrayView1<'_, f64>,
1120 eps_band: &[f64],
1121) -> Result<Array2<f64>, BasisError> {
1122 if queries.ncols() != centers.ncols() {
1123 crate::bail_dim_basis!(
1124 "measure-jet support curve dimension mismatch: queries d={} centers d={}",
1125 queries.ncols(),
1126 centers.ncols()
1127 );
1128 }
1129 if masses.len() != centers.nrows() {
1130 crate::bail_dim_basis!(
1131 "measure-jet support curve mass/center mismatch: {} masses for {} centers",
1132 masses.len(),
1133 centers.nrows()
1134 );
1135 }
1136 if eps_band.is_empty() || eps_band.iter().any(|e| !(e.is_finite() && *e > 0.0)) {
1137 crate::bail_invalid_basis!("measure-jet support curve needs a nonempty positive band");
1138 }
1139 validate_finite_points(queries, "queries")?;
1140 validate_finite_points(centers, "centers")?;
1141 let nq = queries.nrows();
1142 let nl = eps_band.len();
1143 let d2 = pairwise_sq_dists(queries, centers);
1146 let mut out = Array2::<f64>::zeros((nq, nl));
1147 out.axis_iter_mut(Axis(0))
1148 .into_par_iter()
1149 .enumerate()
1150 .for_each(|(qi, mut row)| {
1151 let d2_row = d2.row(qi);
1152 for (li, &eps) in eps_band.iter().enumerate() {
1153 let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
1154 let mut acc = 0.0_f64;
1155 for (j, &dd) in d2_row.iter().enumerate() {
1156 acc += masses[j] * (-dd * inv_two_eps2).exp();
1157 }
1158 row[li] = acc;
1159 }
1160 });
1161 Ok(out)
1162}
1163
1164pub(crate) fn measure_jet_support_means(
1165 centers: ArrayView2<'_, f64>,
1166 masses: ArrayView1<'_, f64>,
1167 eps_band: &[f64],
1168) -> Result<Vec<f64>, BasisError> {
1169 let total_mass = masses.sum();
1170 if !(total_mass.is_finite() && total_mass > 0.0) {
1171 crate::bail_invalid_basis!(
1172 "measure-jet support means need positive finite total mass; got {total_mass}"
1173 );
1174 }
1175 let support = measure_jet_support_curve(centers, centers, masses, eps_band)?;
1176 let mut means = vec![0.0_f64; eps_band.len()];
1177 for (i, row) in support.rows().into_iter().enumerate() {
1178 let mass = masses[i];
1179 for (mean, &q) in means.iter_mut().zip(row.iter()) {
1180 *mean += mass * q;
1181 }
1182 }
1183 for mean in &mut means {
1184 *mean /= total_mass;
1185 if !(*mean).is_finite() || *mean <= 0.0 {
1186 crate::bail_invalid_basis!(
1187 "measure-jet support mean must be positive and finite; got {mean}"
1188 );
1189 }
1190 }
1191 Ok(means)
1192}
1193
1194pub fn measure_jet_design_matrix(
1196 data: ArrayView2<'_, f64>,
1197 centers: ArrayView2<'_, f64>,
1198 length_scale: f64,
1199) -> Result<Array2<f64>, BasisError> {
1200 if data.ncols() != centers.ncols() {
1201 crate::bail_dim_basis!(
1202 "measure-jet design dimension mismatch: data d={} centers d={}",
1203 data.ncols(),
1204 centers.ncols()
1205 );
1206 }
1207 if !(length_scale.is_finite() && length_scale > 0.0) {
1208 crate::bail_invalid_basis!(
1209 "measure-jet design needs a positive finite length_scale; got {length_scale}"
1210 );
1211 }
1212 validate_finite_points(data, "data")?;
1213 validate_finite_points(centers, "centers")?;
1214 let inv_two_l2 = 1.0 / (2.0 * length_scale * length_scale);
1215 let mut out = pairwise_sq_dists(data, centers);
1218 out.axis_iter_mut(Axis(0))
1219 .into_par_iter()
1220 .for_each(|mut row| {
1221 row.mapv_inplace(|d2| (-d2 * inv_two_l2).exp());
1222 });
1223 Ok(out)
1224}
1225
1226pub fn realized_measure_jet_length_scale(
1231 centers: ArrayView2<'_, f64>,
1232 spec_length_scale: f64,
1233) -> Result<f64, BasisError> {
1234 if spec_length_scale.is_finite() && spec_length_scale > 0.0 {
1235 return Ok(spec_length_scale);
1236 }
1237 if spec_length_scale != 0.0 {
1238 crate::bail_invalid_basis!(
1239 "measure-jet length_scale must be positive (or 0.0 for auto); got {spec_length_scale}"
1240 );
1241 }
1242 let dist2 = pairwise_sq_dists(centers, centers);
1243 let spacing = median_nearest_center_spacing(&dist2)?;
1244 Ok(MEASURE_JET_AUTO_LENGTH_SCALE_FACTOR * spacing)
1245}
1246
/// Everything `realize_measure_jet_geometry` derives from the data and a spec,
/// with all "auto" settings resolved.
pub(crate) struct RealizedMeasureJetGeometry {
    /// Quadrature nodes: the seed centers when a frozen quadrature is supplied,
    /// otherwise the barycenters of the nearest-center cells.
    pub(crate) centers: Array2<f64>,
    /// Per-node masses (frozen, or computed from the data).
    pub(crate) masses: Array1<f64>,
    /// Scale band (frozen, or built from the nodes).
    pub(crate) eps_band: Vec<f64>,
    /// Log spacing of the band.
    pub(crate) log_step: f64,
    /// Resolved design kernel length scale.
    pub(crate) length_scale: f64,
    /// Order `s` actually used; `MEASURE_JET_DEFAULT_ORDER_S` when the spec
    /// leaves `order_s` at `0.0`.
    pub(crate) order_s_eval: f64,
    /// Copy of the spec's `multiscale` flag.
    pub(crate) per_level: bool,
    /// Constraint transform applied to the center coefficients.
    pub(crate) z: Array2<f64>,
    /// Gauge built from `z`; used to restrict design matrices.
    pub(crate) coefficient_gauge: gam_problem::Gauge,
    /// Center-to-center kernel matrix restricted by the gauge.
    pub(crate) kz: Array2<f64>,
}
1267
1268pub(crate) fn realize_measure_jet_geometry(
1269 data: ArrayView2<'_, f64>,
1270 spec: &MeasureJetBasisSpec,
1271) -> Result<RealizedMeasureJetGeometry, BasisError> {
1272 if data.ncols() == 0 {
1273 crate::bail_invalid_basis!("measure-jet smooth needs at least one feature column");
1274 }
1275 validate_finite_points(data, "data")?;
1276 let seed_centers = select_centers_by_strategy(data, &spec.center_strategy)?;
1277 let m = seed_centers.nrows();
1278 if m < 3 {
1279 return Err(BasisError::InsufficientColumnsForConstraint { found: m });
1280 }
1281 let order_s = if spec.order_s == 0.0 {
1282 MEASURE_JET_DEFAULT_ORDER_S
1283 } else {
1284 spec.order_s
1285 };
1286 let (centers, masses, eps_band, log_step) = match &spec.frozen_quadrature {
1293 Some(frozen) => {
1294 if frozen.masses.len() != m {
1295 crate::bail_dim_basis!(
1296 "frozen measure-jet quadrature mismatch: {} masses for {} centers",
1297 frozen.masses.len(),
1298 m
1299 );
1300 }
1301 if frozen.eps_band.is_empty() {
1302 crate::bail_invalid_basis!("frozen measure-jet quadrature has an empty band");
1303 }
1304 let log_step = if frozen.eps_band.len() >= 2 {
1305 (frozen.eps_band[1] / frozen.eps_band[0]).ln()
1306 } else {
1307 std::f64::consts::LN_2
1308 };
1309 (
1310 seed_centers,
1311 frozen.masses.clone(),
1312 frozen.eps_band.clone(),
1313 log_step,
1314 )
1315 }
1316 None => {
1317 let (nodes, masses) = measure_jet_quadrature_nodes(data, seed_centers.view())?;
1318 let band = measure_jet_band(nodes.view(), spec.num_scales)?;
1319 (nodes, masses, band.eps, band.log_step)
1320 }
1321 };
1322 let length_scale = realized_measure_jet_length_scale(centers.view(), spec.length_scale)?;
1323 let (z, coefficient_gauge) = match &spec.identifiability {
1327 MeasureJetIdentifiability::FrozenTransform { transform } => {
1328 if transform.nrows() != m {
1329 crate::bail_dim_basis!(
1330 "frozen measure-jet identifiability transform mismatch: {} centers but transform has {} rows",
1331 m,
1332 transform.nrows()
1333 );
1334 }
1335 (
1336 transform.clone(),
1337 gam_problem::Gauge::from_block_transforms(&[transform.clone()]),
1338 )
1339 }
1340 MeasureJetIdentifiability::CenterSumToZero => {
1341 let u = householder_sum_to_zero_u(m);
1345 let z = householder_sum_to_zero_z(&u);
1346 (z.clone(), gam_problem::Gauge::sum_to_zero(z))
1347 }
1348 };
1349 let k_cc = measure_jet_design_matrix(centers.view(), centers.view(), length_scale)?;
1350 let kz = coefficient_gauge.restrict_design(&k_cc);
1351 Ok(RealizedMeasureJetGeometry {
1352 centers,
1353 masses,
1354 eps_band,
1355 log_step,
1356 length_scale,
1357 order_s_eval: order_s,
1358 per_level: spec.multiscale,
1363 z,
1364 coefficient_gauge,
1365 kz,
1366 })
1367}
1368
/// Whether the spec requests one penalty per band scale; returns the spec's
/// `multiscale` flag unchanged.
pub fn measure_jet_multiscale_mode(spec: &MeasureJetBasisSpec) -> bool {
    spec.multiscale
}
1378
1379pub fn build_measure_jet_basis(
1386 data: ArrayView2<'_, f64>,
1387 spec: &MeasureJetBasisSpec,
1388) -> Result<BasisBuildResult, BasisError> {
1389 let RealizedMeasureJetGeometry {
1390 centers,
1391 masses,
1392 eps_band,
1393 log_step,
1394 length_scale,
1395 order_s_eval: order_s,
1396 per_level,
1397 z,
1398 coefficient_gauge,
1399 kz,
1400 } = realize_measure_jet_geometry(data, spec)?;
1401 let band = MeasureJetBand {
1402 eps: eps_band.clone(),
1403 log_step,
1404 };
1405 let raw_design = measure_jet_design_matrix(data, centers.view(), length_scale)?;
1406 let constrained_design = coefficient_gauge.restrict_design(&raw_design);
1407 let design = gam_linalg::matrix::DesignMatrix::Dense(
1408 gam_linalg::matrix::DenseDesignMatrix::from(constrained_design),
1409 );
1410 let support_means = measure_jet_support_means(centers.view(), masses.view(), &eps_band)?;
1411 let mut candidates = Vec::new();
1424 let mut penalty_normalization_scales = Vec::new();
1425 let mut raw_penalty_normalization_scales = Vec::new();
1426 let mut fused_penalty_normalization_scale = None;
1427 if per_level {
1428 let forms = measure_jet_energy_forms_per_scale(
1429 centers.view(),
1430 masses.view(),
1431 &band,
1432 order_s,
1433 spec.alpha,
1434 spec.tau0,
1435 )?;
1436 for (level, q_l) in forms.into_iter().enumerate() {
1437 let s_l = kz.t().dot(&q_l).dot(&kz);
1438 let (s_norm, c_l) = normalize_penalty(&((&s_l + &s_l.t()) * 0.5));
1439 let intrinsic_dim = centers.ncols() as f64;
1440 let eta = 2.0 * order_s + intrinsic_dim * (2.0 - 2.0 * spec.alpha);
1441 let scale_weight = log_step * eps_band[level].powf(-eta);
1442 penalty_normalization_scales.push(c_l);
1443 raw_penalty_normalization_scales.push(c_l / scale_weight);
1444 candidates.push(PenaltyCandidate {
1445 matrix: s_norm,
1446 nullspace_dim_hint: 0,
1447 source: PenaltySource::Other(format!("measure_jet_scale_{level}")),
1448 normalization_scale: c_l,
1449 kronecker_factors: None,
1450 op: None,
1451 });
1452 }
1453 } else {
1454 let q_form = measure_jet_energy_form(
1455 centers.view(),
1456 masses.view(),
1457 &band,
1458 order_s,
1459 spec.alpha,
1460 spec.tau0,
1461 )?;
1462 let mut penalty = kz.t().dot(&q_form).dot(&kz);
1463 penalty = (&penalty + &penalty.t()) * 0.5;
1464 if spec.double_penalty {
1484 let ridge = affine_preserving_coefficient_ridge(&kz, centers.view(), masses.view())?;
1485 let primary_fro = trace_of_product(&penalty, &penalty).sqrt();
1486 let ridge_fro = trace_of_product(&ridge, &ridge).sqrt();
1487 if primary_fro.is_finite()
1488 && primary_fro > 0.0
1489 && ridge_fro.is_finite()
1490 && ridge_fro > 0.0
1491 {
1492 let w = MEASURE_JET_FUSED_RIDGE_FRACTION * primary_fro / ridge_fro;
1493 penalty = &penalty + &(&ridge * w);
1494 }
1495 }
1496 let (penalty_norm, c_primary) = normalize_penalty(&penalty);
1497 fused_penalty_normalization_scale = Some(c_primary);
1498 candidates.push(PenaltyCandidate {
1499 matrix: penalty_norm,
1500 nullspace_dim_hint: 0,
1501 source: PenaltySource::Primary,
1502 normalization_scale: c_primary,
1503 kronecker_factors: None,
1504 op: None,
1505 });
1506 }
1507 if spec.double_penalty && per_level {
1512 let ridge = affine_preserving_coefficient_ridge(&kz, centers.view(), masses.view())?;
1513 let (ridge_norm, c_ridge) = normalize_penalty(&ridge);
1514 candidates.push(PenaltyCandidate {
1515 matrix: ridge_norm,
1516 nullspace_dim_hint: 0,
1517 source: PenaltySource::DoublePenaltyNullspace,
1518 normalization_scale: c_ridge,
1519 kronecker_factors: None,
1520 op: None,
1521 });
1522 }
1523 let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
1524 filter_active_penalty_candidates_with_ops(candidates)?;
1525 Ok(BasisBuildResult {
1526 design,
1527 penalties,
1528 nullspace_dims,
1529 penaltyinfo,
1530 metadata: BasisMetadata::MeasureJet {
1531 centers,
1532 input_scales: None,
1533 length_scale,
1534 eps_band,
1535 order_s: spec.order_s,
1540 alpha: spec.alpha,
1541 tau0: spec.tau0,
1542 masses,
1543 support_means,
1544 penalty_normalization_scales,
1545 raw_penalty_normalization_scales,
1546 fused_penalty_normalization_scale,
1547 constraint_transform: Some(z),
1548 },
1549 kronecker_factored: None,
1550 ops,
1551 null_eigenvectors,
1552 joint_null_rotation: None,
1553 })
1554}
1555
1556pub fn build_measure_jet_basis_psi_derivatives(
1581 data: ArrayView2<'_, f64>,
1582 spec: &MeasureJetBasisSpec,
1583) -> Result<AnisoBasisPsiDerivatives, BasisError> {
1584 if !(spec.tau0.is_finite() && spec.tau0 > 0.0) {
1585 crate::bail_invalid_basis!(
1586 "measure-jet ψ derivatives need tau0 > 0 because the retained τ coordinate is ln τ; got {}",
1587 spec.tau0
1588 );
1589 }
1590 let geom = realize_measure_jet_geometry(data, spec)?;
1591 let band = MeasureJetBand {
1592 eps: geom.eps_band.clone(),
1593 log_step: geom.log_step,
1594 };
1595 let n = data.nrows();
1596 let p = geom.kz.ncols(); let kz = &geom.kz;
1598 let sandwich = |j: &Array2<f64>| {
1599 let s = kz.t().dot(j).dot(kz);
1600 (&s + &s.t()) * 0.5
1601 };
1602 let (n_coords, pairs, raw): (
1606 usize,
1607 Vec<(usize, usize)>,
1608 Vec<(
1609 Array2<f64>,
1610 Vec<Array2<f64>>,
1611 Vec<Array2<f64>>,
1612 Vec<Array2<f64>>,
1613 )>,
1614 ) = if geom.per_level {
1615 let l_count = band.eps.len();
1616 let forms = assemble_weighted_forms(
1619 geom.centers.view(),
1620 geom.masses.view(),
1621 &band,
1622 geom.order_s_eval,
1623 spec.alpha,
1624 spec.tau0,
1625 6 * l_count,
1626 3,
1627 &|scale_idx, eps: f64, q: f64, base: f64, out: &mut [[f64; 3]]| {
1628 for slot in out.iter_mut() {
1629 *slot = [0.0, 0.0, 0.0];
1630 }
1631 let intrinsic_dim = geom.centers.ncols() as f64;
1632 let ga = 2.0 * intrinsic_dim * eps.ln() - 2.0 * q.max(f64::MIN_POSITIVE).ln();
1633 let k0 = 6 * scale_idx;
1634 out[k0] = [base, 0.0, 0.0];
1635 out[k0 + 1] = [ga * base, 0.0, 0.0];
1636 out[k0 + 2] = [ga * ga * base, 0.0, 0.0];
1637 out[k0 + 3] = [0.0, 0.0, 0.0];
1638 out[k0 + 4] = [0.0, 0.0, 0.0];
1639 out[k0 + 5] = [0.0, 0.0, 0.0];
1640 },
1641 )?;
1642 let mut raw = Vec::with_capacity(l_count);
1643 for level in 0..l_count {
1644 let chunk = &forms[6 * level..6 * level + 6];
1645 raw.push((
1646 sandwich(&chunk[0]),
1647 vec![sandwich(&chunk[1]), sandwich(&chunk[3])],
1648 vec![sandwich(&chunk[2]), sandwich(&chunk[4])],
1649 vec![sandwich(&chunk[5])],
1650 ));
1651 }
1652 (2usize, vec![(0usize, 1usize)], raw)
1653 } else {
1654 let q_value = sandwich(&measure_jet_energy_form(
1664 geom.centers.view(),
1665 geom.masses.view(),
1666 &band,
1667 geom.order_s_eval,
1668 spec.alpha,
1669 spec.tau0,
1670 )?);
1671 let raw = vec![(q_value, Vec::new(), Vec::new(), Vec::new())];
1672 (0usize, Vec::new(), raw)
1673 };
1674 let length_scale_design = if spec.learn_length_scale {
1683 let ell = geom.length_scale;
1684 let k = measure_jet_design_matrix(data, geom.centers.view(), ell)?;
1685 let r2 = pairwise_sq_dists(data, geom.centers.view());
1686 let inv_l2 = 1.0 / (ell * ell);
1687 let mut dk = k.clone();
1688 let mut d2k = k.clone();
1689 for ((dk_v, d2k_v), &r2_v) in dk.iter_mut().zip(d2k.iter_mut()).zip(r2.iter()) {
1690 let a = r2_v * inv_l2;
1691 let kij = *dk_v;
1692 *dk_v = kij * a;
1693 *d2k_v = kij * (a * a - 2.0 * a);
1694 }
1695 let dx_du = geom.coefficient_gauge.restrict_design(&dk);
1697 let d2x_du2 = geom.coefficient_gauge.restrict_design(&d2k);
1698 Some((dx_du, d2x_du2))
1699 } else {
1700 None
1701 };
1702 let n_active = raw.len();
1703 let ridge = spec.double_penalty && geom.per_level;
1709 let n_cands = n_active + usize::from(ridge);
1710 let zero_p = || Array2::<f64>::zeros((p, p));
1711 let mut penalties_first: Vec<Vec<Array2<f64>>> =
1712 (0..n_coords).map(|_| Vec::with_capacity(n_cands)).collect();
1713 let mut penalties_second_diag: Vec<Vec<Array2<f64>>> =
1714 (0..n_coords).map(|_| Vec::with_capacity(n_cands)).collect();
1715 let mut crosses: Vec<Vec<Array2<f64>>> = (0..pairs.len()).map(|_| Vec::new()).collect();
1719 for (s_raw, firsts, seconds, cross_raw) in &raw {
1720 let fro = trace_of_product(s_raw, s_raw).sqrt();
1729 let c = if fro.is_finite() && fro > 1e-12 {
1730 fro
1731 } else {
1732 1.0
1733 };
1734 for coord in 0..n_coords {
1735 let (_, s_first, s_second, _) =
1736 normalize_penaltywith_psi_derivatives(s_raw, &firsts[coord], &seconds[coord]);
1737 penalties_first[coord].push(s_first);
1738 penalties_second_diag[coord].push(s_second);
1739 }
1740 for (pair_idx, &(a, b)) in pairs.iter().enumerate() {
1741 let cross_raw_mat = normalize_penalty_cross_psi_derivative(
1742 s_raw,
1743 &firsts[a],
1744 &firsts[b],
1745 &cross_raw[pair_idx],
1746 c,
1747 );
1748 crosses[pair_idx].push(cross_raw_mat);
1749 }
1750 }
1751 if ridge {
1752 for coord in 0..n_coords {
1753 penalties_first[coord].push(zero_p());
1754 penalties_second_diag[coord].push(zero_p());
1755 }
1756 for pair_crosses in crosses.iter_mut() {
1757 pair_crosses.push(zero_p());
1758 }
1759 }
1760 let coord_offset = usize::from(length_scale_design.is_some());
1766 if coord_offset == 1 {
1767 penalties_first.insert(0, (0..n_cands).map(|_| zero_p()).collect());
1768 penalties_second_diag.insert(0, (0..n_cands).map(|_| zero_p()).collect());
1769 }
1770 let n_coords_total = n_coords + coord_offset;
1771 let mut all_pairs: Vec<(usize, usize)> = pairs
1773 .iter()
1774 .map(|&(a, b)| (a + coord_offset, b + coord_offset))
1775 .collect();
1776 let mut all_crosses: Vec<Vec<Array2<f64>>> = crosses;
1777 if coord_offset == 1 {
1782 for c in 1..n_coords_total {
1783 all_pairs.push((0, c));
1784 all_crosses.push((0..n_cands).map(|_| zero_p()).collect());
1785 }
1786 }
1787 let pair_index: Vec<((usize, usize), Vec<Array2<f64>>)> = all_pairs
1788 .iter()
1789 .copied()
1790 .zip(all_crosses.into_iter())
1791 .collect();
1792 let shifted_pairs = all_pairs;
1793 let provider = AnisoPenaltyCrossProvider::new(move |a, b| {
1794 pair_index
1795 .iter()
1796 .find(|((pa, pb), _)| (*pa, *pb) == (a, b) || (*pa, *pb) == (b, a))
1797 .map(|(_, mats)| mats.clone())
1798 .ok_or_else(|| {
1799 BasisError::InvalidInput(format!(
1800 "measure-jet ψ cross derivative requested for unknown pair ({a}, {b})"
1801 ))
1802 })
1803 });
1804 let mut design_first: Vec<Array2<f64>> = (0..n_coords_total)
1805 .map(|_| Array2::<f64>::zeros((n, p)))
1806 .collect();
1807 let mut design_second_diag: Vec<Array2<f64>> = (0..n_coords_total)
1808 .map(|_| Array2::<f64>::zeros((n, p)))
1809 .collect();
1810 if let Some((dx_du, d2x_du2)) = length_scale_design {
1811 design_first[0] = dx_du;
1812 design_second_diag[0] = d2x_du2;
1813 }
1814 Ok(AnisoBasisPsiDerivatives {
1815 design_first,
1816 design_second_diag,
1817 design_second_cross: Vec::new(),
1818 design_second_cross_pairs: Vec::new(),
1819 penalties_first,
1820 penalties_second_diag,
1821 penalties_cross_pairs: shifted_pairs,
1822 penalties_cross_provider: Some(provider),
1823 implicit_operator: None,
1824 })
1825}
1826
1827#[cfg(test)]
1828mod tests {
1829 use super::*;
1830 pub(crate) fn two_cluster_centers() -> (ndarray::Array2<f64>, ndarray::Array1<f64>) {
1831 let centers = array![
1832 [0.00, 0.00],
1833 [0.31, 0.05],
1834 [0.58, -0.07],
1835 [0.93, 0.11],
1836 [1.22, 0.02],
1837 [1.49, -0.04],
1838 [3.10, 2.00],
1839 [3.42, 2.13],
1840 [3.71, 1.91],
1841 [4.05, 2.07],
1842 [4.33, 1.96],
1843 [4.61, 2.12],
1844 ];
1845 let m = centers.nrows();
1846 let masses = ndarray::Array1::<f64>::from_elem(m, 1.0 / m as f64);
1847 (centers, masses)
1848 }
1849 use ndarray::array;
1850
1851 pub(crate) fn band_for(centers: &Array2<f64>) -> MeasureJetBand {
1852 measure_jet_band(centers.view(), 0).expect("band")
1853 }
1854
#[test]
/// The energy form must annihilate the constant vector: every entry of `Q·1`
/// and the scalar `1ᵀQ1` have to vanish relative to the largest entry of `Q`.
pub(crate) fn energy_form_annihilates_constants_exactly() {
    let (centers, masses) = two_cluster_centers();
    // Fixed: the argument was the mis-decoded token `¢ers` (a mangled
    // `&centers`), which does not compile.
    let band = band_for(&centers);
    // Arguments after the band are order_s = 1.5, alpha = 1.0, tau0 = 1e-3.
    let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
        .expect("energy form");
    let m = q.nrows();
    let ones = Array1::<f64>::ones(m);
    let qv = q.dot(&ones);
    // Largest absolute entry of Q, used as the reference magnitude.
    let scale = q.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
    assert!(scale > 0.0, "energy form is identically zero");
    for (i, v) in qv.iter().enumerate() {
        assert!(
            v.abs() <= 1e-12 * scale,
            "Q·1 leak at row {i}: {v:.3e} vs scale {scale:.3e}"
        );
    }
    let vqv = ones.dot(&qv);
    assert!(
        vqv.abs() <= 1e-12 * scale,
        "constant carries energy: 1ᵀQ1 = {vqv:.3e}"
    );
}
1880
#[test]
/// An affine function of the center coordinates must carry (relatively) no
/// energy, while an alternating ±1 vector must carry strictly positive energy.
pub(crate) fn energy_form_annihilates_affine_at_default_tau() {
    let (centers, masses) = two_cluster_centers();
    // Fixed: the argument was the mis-decoded token `¢ers` (a mangled
    // `&centers`), which does not compile.
    let band = band_for(&centers);
    let m = centers.nrows();
    let mut affine = Array1::<f64>::zeros(m);
    let mut rough = Array1::<f64>::zeros(m);
    for i in 0..m {
        // Affine probe 0.7 + 1.3·x − 0.4·y and an alternating-sign probe.
        affine[i] = 0.7 + 1.3 * centers[(i, 0)] - 0.4 * centers[(i, 1)];
        rough[i] = if i % 2 == 0 { 1.0 } else { -1.0 };
    }
    // Arguments after the band are order_s = 1.5, alpha = 1.0, tau0 = 1e-3.
    let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
        .expect("energy form");
    let e_affine = affine.dot(&q.dot(&affine));
    let e_rough = rough.dot(&q.dot(&rough));
    assert!(e_rough > 0.0, "rough vector must pay energy");
    // The affine energy is measured relative to the rough probe's energy.
    assert!(
        e_affine.abs() <= 1e-12 * e_rough,
        "default affine energy {e_affine:.3e} vs rough {e_rough:.3e}"
    );
}
1905
#[test]
/// Spot-checks positive semi-definiteness: `vᵀQv` must not drop below −1e-10
/// on five deterministic probe vectors.
pub(crate) fn energy_form_is_psd() {
    let (centers, masses) = two_cluster_centers();
    // Fixed: the argument was the mis-decoded token `¢ers` (a mangled
    // `&centers`), which does not compile.
    let band = band_for(&centers);
    // Arguments after the band are order_s = 1.5, alpha = 1.0, tau0 = 1e-3.
    let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
        .expect("energy form");
    let m = q.nrows();
    for trial in 0..5usize {
        // Deterministic pseudo-random entries in [-0.5, 0.5).
        let v = Array1::<f64>::from_shape_fn(m, |i| {
            ((i * 7 + trial * 13) % 11) as f64 / 11.0 - 0.5
        });
        let e = v.dot(&q.dot(&v));
        assert!(e >= -1e-10, "vᵀQv = {e:.3e} < 0 on trial {trial}");
    }
}
1922
#[test]
/// On centers sampled along a gently curved path, an alternating ±0.5 vector
/// must cost more than ten times the energy of a slow quadratic trend.
pub(crate) fn rough_vector_pays_more_than_smooth() {
    let m = 24usize;
    // Centers along (4t, 0.3·sin(4t)) for t in [0, 1].
    let centers = Array2::<f64>::from_shape_fn((m, 2), |(i, k)| {
        let t = i as f64 / (m as f64 - 1.0);
        if k == 0 {
            t * 4.0
        } else {
            0.3 * (t * 4.0).sin()
        }
    });
    let masses = Array1::<f64>::from_elem(m, 1.0 / m as f64);
    // Fixed: the argument was the mis-decoded token `¢ers` (a mangled
    // `&centers`), which does not compile.
    let band = band_for(&centers);
    // Arguments after the band are order_s = 1.5, alpha = 1.0, tau0 = 1e-3.
    let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
        .expect("energy form");
    let slow = Array1::<f64>::from_shape_fn(m, |i| (i as f64 / (m as f64 - 1.0)).powi(2));
    let fast = Array1::<f64>::from_shape_fn(m, |i| if i % 2 == 0 { 0.5 } else { -0.5 });
    let e_slow = slow.dot(&q.dot(&slow));
    let e_fast = fast.dot(&q.dot(&fast));
    assert!(
        e_fast > 10.0 * e_slow,
        "alternating values must pay >> a slow trend: fast {e_fast:.3e} vs slow {e_slow:.3e}"
    );
}
1949
#[test]
/// Checks every analytic jet returned by `measure_jet_energy_form_with_jets`
/// against central finite differences of `measure_jet_energy_form` in the
/// three coordinates (order s, alpha, ln τ), including all mixed seconds.
pub(crate) fn energy_jets_match_finite_differences() {
    let (centers, masses) = two_cluster_centers();
    // Fixed: the argument was the mis-decoded token `¢ers` (a mangled
    // `&centers`), which does not compile.
    let band = band_for(&centers);
    let (s0, a0, tau) = (1.3, 0.8, 1e-3);
    let jets =
        measure_jet_energy_form_with_jets(centers.view(), masses.view(), &band, s0, a0, tau)
            .expect("jets");
    // Plain energy form as a function of (s, alpha) at the fixed tau.
    let q_at = |s: f64, a: f64| {
        measure_jet_energy_form(centers.view(), masses.view(), &band, s, a, tau)
            .expect("energy form")
    };
    let q_plain = q_at(s0, a0);
    // The jet bundle's value must agree with the plain evaluation.
    for (a, b) in jets.q.iter().zip(q_plain.iter()) {
        assert!(
            (a - b).abs() <= 1e-14 * (1.0 + b.abs()),
            "Q drift {a} vs {b}"
        );
    }
    // τ derivatives are taken in the ln τ coordinate.
    let lt0 = tau.ln();
    let q_at_lt = |lt: f64| {
        measure_jet_energy_form(centers.view(), masses.view(), &band, s0, a0, lt.exp())
            .expect("energy form")
    };
    let h = 1e-4;
    // (label, analytic jet, finite-difference reference).
    let checks: [(&str, &Array2<f64>, Array2<f64>); 9] = [
        ("dq_ds", &jets.dq_ds, {
            let (p, m_) = (q_at(s0 + h, a0), q_at(s0 - h, a0));
            (&p - &m_) / (2.0 * h)
        }),
        ("d2q_ds2", &jets.d2q_ds2, {
            let (p, c, m_) = (q_at(s0 + h, a0), q_at(s0, a0), q_at(s0 - h, a0));
            (&(&p + &m_) - &(&c * 2.0)) / (h * h)
        }),
        ("dq_dalpha", &jets.dq_dalpha, {
            let (p, m_) = (q_at(s0, a0 + h), q_at(s0, a0 - h));
            (&p - &m_) / (2.0 * h)
        }),
        ("d2q_dalpha2", &jets.d2q_dalpha2, {
            let (p, c, m_) = (q_at(s0, a0 + h), q_at(s0, a0), q_at(s0, a0 - h));
            (&(&p + &m_) - &(&c * 2.0)) / (h * h)
        }),
        ("d2q_ds_dalpha", &jets.d2q_ds_dalpha, {
            let pp = q_at(s0 + h, a0 + h);
            let pm = q_at(s0 + h, a0 - h);
            let mp = q_at(s0 - h, a0 + h);
            let mm = q_at(s0 - h, a0 - h);
            (&(&pp - &pm) - &(&mp - &mm)) / (4.0 * h * h)
        }),
        ("dq_dlogtau", &jets.dq_dlogtau, {
            let (p, m_) = (q_at_lt(lt0 + h), q_at_lt(lt0 - h));
            (&p - &m_) / (2.0 * h)
        }),
        ("d2q_dlogtau2", &jets.d2q_dlogtau2, {
            let (p, c, m_) = (q_at_lt(lt0 + h), q_at_lt(lt0), q_at_lt(lt0 - h));
            (&(&p + &m_) - &(&c * 2.0)) / (h * h)
        }),
        ("d2q_ds_dlogtau", &jets.d2q_ds_dlogtau, {
            let f = |s: f64, lt: f64| {
                measure_jet_energy_form(centers.view(), masses.view(), &band, s, a0, lt.exp())
                    .expect("energy form")
            };
            let pp = f(s0 + h, lt0 + h);
            let pm = f(s0 + h, lt0 - h);
            let mp = f(s0 - h, lt0 + h);
            let mm = f(s0 - h, lt0 - h);
            (&(&pp - &pm) - &(&mp - &mm)) / (4.0 * h * h)
        }),
        ("d2q_dalpha_dlogtau", &jets.d2q_dalpha_dlogtau, {
            let f = |a: f64, lt: f64| {
                measure_jet_energy_form(centers.view(), masses.view(), &band, s0, a, lt.exp())
                    .expect("energy form")
            };
            let pp = f(a0 + h, lt0 + h);
            let pm = f(a0 + h, lt0 - h);
            let mp = f(a0 - h, lt0 + h);
            let mm = f(a0 - h, lt0 - h);
            (&(&pp - &pm) - &(&mp - &mm)) / (4.0 * h * h)
        }),
    ];
    for (name, analytic, fd) in checks.iter() {
        // Tolerance is relative to the largest finite-difference entry.
        let scale = fd.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()));
        for (a, b) in analytic.iter().zip(fd.iter()) {
            assert!(
                (a - b).abs() <= 5e-5 * scale,
                "{name} jet mismatch: analytic {a:.6e} vs FD {b:.6e} (scale {scale:.3e})"
            );
        }
    }
}
2050
#[test]
/// The per-scale spectrum of a vector must have one entry per band scale, sum
/// to the total energy `vᵀQv`, and charge an alternating vector more at the
/// first band entry than at the last.
pub(crate) fn scale_spectrum_sums_to_total_and_localizes_roughness() {
    let m = 24usize;
    // Centers equally spaced on the segment [0, 4] of the first axis.
    let centers = Array2::<f64>::from_shape_fn((m, 2), |(i, k)| {
        let t = i as f64 / (m as f64 - 1.0);
        if k == 0 { t * 4.0 } else { 0.0 }
    });
    let masses = Array1::<f64>::from_elem(m, 1.0 / m as f64);
    // Fixed: the argument was the mis-decoded token `¢ers` (a mangled
    // `&centers`), which does not compile.
    let band = band_for(&centers);
    // Arguments after the band are order_s = 1.5, alpha = 1.0, tau0 = 1e-3.
    let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
        .expect("energy form");
    let fast = Array1::<f64>::from_shape_fn(m, |i| if i % 2 == 0 { 0.5 } else { -0.5 });
    let spec = measure_jet_scale_spectrum(
        centers.view(),
        masses.view(),
        &band,
        1.5,
        1.0,
        1e-3,
        fast.view(),
    )
    .expect("spectrum");
    assert_eq!(spec.len(), band.eps.len());
    let total = fast.dot(&q.dot(&fast));
    let sum: f64 = spec.iter().sum();
    assert!(
        (sum - total).abs() <= 1e-10 * total.abs().max(1e-30),
        "spectrum must sum to vᵀQv: {sum:.6e} vs {total:.6e}"
    );
    // The first band entry is treated as the finest scale, the last as the
    // coarsest.
    let finest = spec[0];
    let coarsest = *spec.last().expect("nonempty spectrum");
    assert!(
        finest > coarsest,
        "alternating values must charge fine scales hardest: fine {finest:.3e} vs coarse {coarsest:.3e}"
    );
}
2091
#[test]
/// The support curve at the first scale must be more than ten times larger
/// for a query on the line of centers than for one 1.5 away from it, and each
/// query's curve must be non-decreasing along the band.
pub(crate) fn support_curve_separates_on_web_from_off_web() {
    let m = 24usize;
    // Centers equally spaced on the segment [0, 4] of the first axis.
    let centers = Array2::<f64>::from_shape_fn((m, 2), |(i, k)| {
        let t = i as f64 / (m as f64 - 1.0);
        if k == 0 { t * 4.0 } else { 0.0 }
    });
    let masses = Array1::<f64>::from_elem(m, 1.0 / m as f64);
    // Fixed: the argument was the mis-decoded token `¢ers` (a mangled
    // `&centers`), which does not compile.
    let band = band_for(&centers);
    // Query 0 lies on the line of centers, query 1 is displaced off it.
    let queries = array![[2.0, 0.0], [2.0, 1.5]];
    let curves =
        measure_jet_support_curve(queries.view(), centers.view(), masses.view(), &band.eps)
            .expect("support curve");
    assert!(
        curves[(0, 0)] > 10.0 * curves[(1, 0)],
        "fine-scale support must separate web from void: on {:.3e} vs off {:.3e}",
        curves[(0, 0)],
        curves[(1, 0)]
    );
    for qi in 0..2 {
        for li in 1..band.eps.len() {
            // A 1e-15 slack absorbs rounding noise between adjacent levels.
            assert!(
                curves[(qi, li)] >= curves[(qi, li - 1)] - 1e-15,
                "support curve must be monotone in scale (query {qi}, level {li})"
            );
        }
    }
}
2124
#[test]
/// A default spec must resolve to single-scale mode and emit exactly one
/// penalty; setting `multiscale: true` must switch modes and emit more.
pub(crate) fn default_stays_single_scale_until_multiscale_opt_in() {
    let rows = 200usize;
    // Samples along the curve (3t, 0.4·sin(3t)) for t in [0, 1].
    let data = Array2::<f64>::from_shape_fn((rows, 2), |(row, col)| {
        let t = row as f64 / (rows as f64 - 1.0);
        match col {
            0 => t * 3.0,
            _ => 0.4 * (t * 3.0).sin(),
        }
    });

    // Default spec: only the center strategy is overridden.
    let default_spec = MeasureJetBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: 80 },
        ..MeasureJetBasisSpec::default()
    };
    assert!(
        !measure_jet_multiscale_mode(&default_spec),
        "default must resolve to single-scale at any center count"
    );
    let single_count = build_measure_jet_basis(data.view(), &default_spec)
        .expect("single-scale build")
        .penalties
        .len();
    assert_eq!(
        single_count, 1,
        "single-scale mode emits one fused penalty (ridge folded in, not a 2nd λ)"
    );

    // Explicit opt-in: same spec with the multiscale flag raised.
    let opt_in_spec = MeasureJetBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: 80 },
        multiscale: true,
        ..MeasureJetBasisSpec::default()
    };
    assert!(
        measure_jet_multiscale_mode(&opt_in_spec),
        "multiscale=true must resolve to multiscale mode"
    );
    let multi_count = build_measure_jet_basis(data.view(), &opt_in_spec)
        .expect("multiscale build")
        .penalties
        .len();
    assert!(
        multi_count > single_count,
        "multiscale mode emits the per-scale spectral split plus the ridge, got {} (vs single-scale {})",
        multi_count,
        single_count
    );
}
2182
#[test]
/// A single-scale build must emit exactly one penalty, and an explicitly
/// requested `order_s` must come back unchanged in the metadata.
pub(crate) fn fused_mode_emits_single_primary_candidate() {
    let sample_count = 40usize;
    // Samples along the curve (3t, 0.4·sin(3t)) for t in [0, 1].
    let data = Array2::<f64>::from_shape_fn((sample_count, 2), |(row, col)| {
        let t = row as f64 / (sample_count as f64 - 1.0);
        match col {
            0 => t * 3.0,
            _ => 0.4 * (t * 3.0).sin(),
        }
    });
    let spec = MeasureJetBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: 14 },
        order_s: 1.3,
        ..MeasureJetBasisSpec::default()
    };
    let built = build_measure_jet_basis(data.view(), &spec).expect("fused build");
    let penalty_count = built.penalties.len();
    assert_eq!(
        penalty_count, 1,
        "fused single-scale mode emits exactly one Primary candidate (ridge folded in)"
    );
    let persisted_order = match &built.metadata {
        BasisMetadata::MeasureJet { order_s, .. } => *order_s,
        _ => panic!("measure-jet build must return MeasureJet metadata"),
    };
    assert_eq!(persisted_order, 1.3, "explicit order must persist verbatim");
}
2214
#[test]
/// For m = 9, the first m − 1 columns of the Householder sum-to-zero matrix
/// must each sum to zero and be mutually orthonormal (tolerance 1e-12).
pub(crate) fn householder_sum_to_zero_basis_is_orthonormal() {
    let m = 9usize;
    let u = householder_sum_to_zero_u(m);
    let z = householder_sum_to_zero_z(&u);
    let free_columns = m - 1;
    for j in 0..free_columns {
        let col_j = z.column(j);
        let column_total = col_j.sum();
        assert!(column_total.abs() <= 1e-12, "column {j} must sum to zero");
        // Only the upper triangle (j2 ≥ j) of the Gram matrix is checked.
        for j2 in j..free_columns {
            let want = if j2 == j { 1.0 } else { 0.0 };
            let dot = col_j.dot(&z.column(j2));
            let gap = (dot - want).abs();
            assert!(
                gap <= 1e-12,
                "orthonormality failure at ({j}, {j2}): {dot}"
            );
        }
    }
}
2234
2235 pub(crate) fn frozen_spec_fixture(
2240 order_s: f64,
2241 multiscale: bool,
2242 ) -> (Array2<f64>, MeasureJetBasisSpec) {
2243 let n = 140usize;
2248 let data = Array2::<f64>::from_shape_fn((n, 2), |(i, k)| {
2249 let t = i as f64 / (n as f64 - 1.0);
2250 if k == 0 {
2251 t * 3.0
2252 } else {
2253 0.5 * (t * 3.0).cos() + if i % 9 == 0 { 0.8 } else { 0.0 }
2254 }
2255 });
2256 let spec = MeasureJetBasisSpec {
2257 center_strategy: CenterStrategy::FarthestPoint { num_centers: 70 },
2258 order_s,
2259 multiscale,
2260 learn_length_scale: false,
2264 ..MeasureJetBasisSpec::default()
2265 };
2266 let first = build_measure_jet_basis(data.view(), &spec).expect("fixture build");
2267 let BasisMetadata::MeasureJet {
2268 centers,
2269 length_scale,
2270 eps_band,
2271 masses,
2272 support_means,
2273 penalty_normalization_scales,
2274 raw_penalty_normalization_scales,
2275 fused_penalty_normalization_scale,
2276 constraint_transform,
2277 ..
2278 } = &first.metadata
2279 else {
2280 panic!("measure-jet build must return MeasureJet metadata");
2281 };
2282 let frozen = MeasureJetBasisSpec {
2283 center_strategy: CenterStrategy::UserProvided(centers.clone()),
2284 order_s,
2285 alpha: spec.alpha,
2286 tau0: spec.tau0,
2287 num_scales: eps_band.len(),
2288 length_scale: *length_scale,
2289 double_penalty: spec.double_penalty,
2290 learn_length_scale: false,
2291 multiscale,
2292 identifiability: MeasureJetIdentifiability::FrozenTransform {
2293 transform: constraint_transform.clone().expect("fit-time z"),
2294 },
2295 frozen_quadrature: Some(MeasureJetFrozenQuadrature {
2296 masses: masses.clone(),
2297 eps_band: eps_band.clone(),
2298 support_means: support_means.clone(),
2299 penalty_normalization_scales: penalty_normalization_scales.clone(),
2300 raw_penalty_normalization_scales: raw_penalty_normalization_scales.clone(),
2301 fused_penalty_normalization_scale: *fused_penalty_normalization_scale,
2302 }),
2303 };
2304 (data, frozen)
2305 }
2306
#[test]
/// In per-level mode the analytic penalty derivatives (first order in each
/// coordinate and the mixed second) must match central finite differences of
/// rebuilt penalties, and the ridge candidate must have zero derivatives.
pub(crate) fn psi_producer_matches_fd_per_level_mode() {
    let (data, frozen) = frozen_spec_fixture(0.0, true);
    let derivs =
        build_measure_jet_basis_psi_derivatives(data.view(), &frozen).expect("psi derivatives");
    let quadrature = frozen
        .frozen_quadrature
        .as_ref()
        .expect("frozen quadrature");
    let l_count = quadrature.eps_band.len();
    assert_eq!(
        derivs.penalties_first.len(),
        2,
        "per-level coords are (α, lnτ)"
    );
    // One candidate per scale plus the ridge.
    assert_eq!(derivs.penalties_first[0].len(), l_count + 1);
    assert_eq!(derivs.penalties_cross_pairs, vec![(0, 1)]);

    // Rebuild the penalties at a perturbed (alpha, tau0).
    let pen_at = |alpha: f64, tau0: f64| {
        let trial = MeasureJetBasisSpec {
            alpha,
            tau0,
            ..frozen.clone()
        };
        let rebuilt = build_measure_jet_basis(data.view(), &trial).expect("trial build");
        rebuilt.penalties
    };
    let h: f64 = 1e-4;
    let (a0, t0) = (frozen.alpha, frozen.tau0);
    // τ is stepped multiplicatively, i.e. additively in ln τ.
    let ap = pen_at(a0 + h, t0);
    let am = pen_at(a0 - h, t0);
    let tp = pen_at(a0, t0 * h.exp());
    let tm = pen_at(a0, t0 * (-h).exp());
    assert_eq!(
        ap.len(),
        l_count + 1,
        "fixture must keep every scale active"
    );

    for level in 0..l_count {
        let fd_alpha = (&ap[level] - &am[level]) / (2.0 * h);
        let fd_tau = (&tp[level] - &tm[level]) / (2.0 * h);
        let first_order = [
            ("alpha", &derivs.penalties_first[0][level], fd_alpha),
            ("ln_tau", &derivs.penalties_first[1][level], fd_tau),
        ];
        for (name, analytic, fd) in first_order {
            // Tolerance is relative to the largest finite-difference entry.
            let scale = fd.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()));
            for (x, y) in analytic.iter().zip(fd.iter()) {
                assert!(
                    (x - y).abs() <= 5e-5 * scale,
                    "{name} jet of scale-candidate {level}: analytic {x:.6e} vs FD {y:.6e}"
                );
            }
        }
    }

    // The trailing (ridge) candidate must not move with either coordinate.
    for coord in 0..2 {
        let ridge_drift = &derivs.penalties_first[coord][l_count];
        assert!(
            ridge_drift.iter().all(|v| *v == 0.0),
            "ridge candidate must have zero ψ drift"
        );
    }

    let provider = derivs
        .penalties_cross_provider
        .as_ref()
        .expect("cross provider");
    let cross = provider.evaluate(0, 1).expect("cross pair (α, lnτ)");
    // Four-point stencil for the mixed second derivative.
    let pp = pen_at(a0 + h, t0 * h.exp());
    let pm = pen_at(a0 + h, t0 * (-h).exp());
    let mp = pen_at(a0 - h, t0 * h.exp());
    let mm = pen_at(a0 - h, t0 * (-h).exp());
    for level in 0..l_count {
        let fd = (&(&pp[level] - &pm[level]) - &(&mp[level] - &mm[level])) / (4.0 * h * h);
        let scale = fd.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()));
        for (x, y) in cross[level].iter().zip(fd.iter()) {
            assert!(
                (x - y).abs() <= 5e-4 * scale,
                "cross (α, lnτ) jet of scale-candidate {level}: analytic {x:.6e} vs FD {y:.6e}"
            );
        }
    }
}
2398
#[test]
/// With `learn_length_scale` enabled in single-scale mode, the only enrolled
/// coordinate is the length scale: its penalty derivatives must be exactly
/// zero and its design derivatives must match finite differences in ln ℓ.
pub(crate) fn psi_producer_matches_fd_length_scale() {
    let (data, mut frozen) = frozen_spec_fixture(0.0, false);
    frozen.learn_length_scale = true;
    let derivs =
        build_measure_jet_basis_psi_derivatives(data.view(), &frozen).expect("psi derivatives");
    assert_eq!(
        derivs.design_first.len(),
        1,
        "single-scale + learn_length_scale enrolls exactly the ℓ coordinate"
    );
    assert_eq!(
        derivs.penalties_first[0].len(),
        1,
        "one fitted penalty candidate"
    );
    let first_is_zero = derivs.penalties_first[0][0].iter().all(|v| *v == 0.0);
    assert!(
        first_is_zero && derivs.penalties_second_diag[0][0].iter().all(|v| *v == 0.0),
        "the jet-energy penalty must not move with ℓ"
    );

    // Rebuild the dense design at a perturbed length scale.
    let ell0 = frozen.length_scale;
    let design_at = |ell: f64| {
        let trial = MeasureJetBasisSpec {
            length_scale: ell,
            ..frozen.clone()
        };
        let rebuilt = build_measure_jet_basis(data.view(), &trial).expect("trial build");
        rebuilt.design.to_dense()
    };
    let h: f64 = 1e-4;
    // ℓ is stepped multiplicatively, i.e. additively in ln ℓ.
    let design_up = design_at(ell0 * h.exp());
    let design_down = design_at(ell0 * (-h).exp());
    let design_mid = design_at(ell0);
    let fd_first = (&design_up - &design_down) / (2.0 * h);
    let fd_second = (&design_up - &(&design_mid * 2.0) + &design_down) / (h * h);

    let scale1 = fd_first.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()));
    for (x, y) in derivs.design_first[0].iter().zip(fd_first.iter()) {
        assert!(
            (x - y).abs() <= 5e-5 * scale1,
            "∂X/∂lnℓ: analytic {x:.6e} vs FD {y:.6e}"
        );
    }
    // The second difference is noisier, hence the looser 1e-3 tolerance.
    let scale2 = fd_second.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()));
    for (x, y) in derivs.design_second_diag[0].iter().zip(fd_second.iter()) {
        assert!(
            (x - y).abs() <= 1e-3 * scale2,
            "∂²X/∂lnℓ²: analytic {x:.6e} vs FD {y:.6e}"
        );
    }
}
2464
#[test]
/// Five points split 3/2 between two seeds, with a third seed far from all
/// data: masses must be 0.6 / 0.4 / 0, the first two nodes must be the
/// means of their points, and the empty seed's node must stay at the seed.
pub(crate) fn quadrature_nodes_are_cell_barycenters() {
    let data = array![
        // Three points near the first seed.
        [0.0, 0.2],
        [0.4, -0.2],
        [0.2, 0.0],
        // Two points near the second seed.
        [9.8, 10.1],
        [10.2, 9.9],
    ];
    let seeds = array![[0.1, 0.1], [10.0, 10.0], [-50.0, -50.0]];
    let (nodes, masses) =
        measure_jet_quadrature_nodes(data.view(), seeds.view()).expect("quadrature nodes");

    let total_mass = masses.sum();
    assert!((total_mass - 1.0).abs() <= 1e-15, "masses must sum to 1");
    assert!((masses[0] - 0.6).abs() <= 1e-15);
    assert!((masses[1] - 0.4).abs() <= 1e-15);
    assert_eq!(masses[2], 0.0);

    // Expected node coordinates, one row per seed, compared exactly.
    let expected_nodes = [[0.2, 0.0], [10.0, 10.0], [-50.0, -50.0]];
    for (cell, expected_row) in expected_nodes.iter().enumerate() {
        for (axis, expected) in expected_row.iter().enumerate() {
            assert_eq!(nodes[(cell, axis)], *expected);
        }
    }
}
2496
#[test]
/// Fits a multiscale basis, rebuilds it from a spec frozen out of the first
/// build's metadata, and requires design and penalties to agree entry-wise
/// to within 1e-12.
pub(crate) fn build_replay_roundtrip_reproduces_design_and_penalty() {
    let sample_count = 140usize;
    // Curve (3t, 0.5·cos(3t)) with every ninth sample lifted by 0.8.
    let data = Array2::<f64>::from_shape_fn((sample_count, 2), |(row, col)| {
        let t = row as f64 / (sample_count as f64 - 1.0);
        match col {
            0 => t * 3.0,
            _ => 0.5 * (t * 3.0).cos() + if row % 9 == 0 { 0.8 } else { 0.0 },
        }
    });
    let spec = MeasureJetBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: 70 },
        multiscale: true,
        ..MeasureJetBasisSpec::default()
    };
    let first = build_measure_jet_basis(data.view(), &spec).expect("first build");
    let BasisMetadata::MeasureJet {
        centers,
        length_scale,
        eps_band,
        order_s,
        alpha,
        tau0,
        masses,
        support_means,
        penalty_normalization_scales,
        raw_penalty_normalization_scales,
        fused_penalty_normalization_scale,
        constraint_transform,
        ..
    } = &first.metadata
    else {
        panic!("measure-jet build must return MeasureJet metadata");
    };
    // Everything the replay needs is copied out of the fit-time metadata.
    let identifiability = MeasureJetIdentifiability::FrozenTransform {
        transform: constraint_transform.clone().expect("fit-time z"),
    };
    let quadrature = MeasureJetFrozenQuadrature {
        masses: masses.clone(),
        eps_band: eps_band.clone(),
        support_means: support_means.clone(),
        penalty_normalization_scales: penalty_normalization_scales.clone(),
        raw_penalty_normalization_scales: raw_penalty_normalization_scales.clone(),
        fused_penalty_normalization_scale: *fused_penalty_normalization_scale,
    };
    let replay_spec = MeasureJetBasisSpec {
        center_strategy: CenterStrategy::UserProvided(centers.clone()),
        order_s: *order_s,
        alpha: *alpha,
        tau0: *tau0,
        num_scales: eps_band.len(),
        length_scale: *length_scale,
        double_penalty: spec.double_penalty,
        learn_length_scale: spec.learn_length_scale,
        multiscale: spec.multiscale,
        identifiability,
        frozen_quadrature: Some(quadrature),
    };
    assert_eq!(
        first.penalties.len(),
        eps_band.len() + 1,
        "per-level mode must emit one candidate per scale + ridge"
    );

    let second = build_measure_jet_basis(data.view(), &replay_spec).expect("replay build");
    let fit_design = first.design.to_dense();
    let replay_design = second.design.to_dense();
    assert_eq!(fit_design.shape(), replay_design.shape());
    for (a, b) in fit_design.iter().zip(replay_design.iter()) {
        assert!((a - b).abs() <= 1e-12, "design replay drift: {a} vs {b}");
    }
    assert_eq!(first.penalties.len(), second.penalties.len());
    let penalty_pairs = first.penalties.iter().zip(second.penalties.iter());
    for (fit_penalty, replay_penalty) in penalty_pairs {
        for (a, b) in fit_penalty.iter().zip(replay_penalty.iter()) {
            assert!((a - b).abs() <= 1e-12, "penalty replay drift: {a} vs {b}");
        }
    }
}
2580}