use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use faer::Side;
use gam_linalg::faer_ndarray::FaerEigh;
use super::{
AnisoBasisPsiDerivatives, AnisoPenaltyCrossProvider, BasisBuildResult, BasisError,
BasisMetadata, CenterStrategy, PenaltyCandidate, PenaltySource,
filter_active_penalty_candidates_with_ops, normalize_penalty,
normalize_penalty_cross_psi_derivative, normalize_penaltywith_psi_derivatives,
select_centers_by_strategy, trace_of_product,
};
/// Truncation radius of the Gaussian profile, in units of the scale `eps`:
/// only centers within `MEASURE_JET_PROFILE_CUTOFF * eps` of a net point take
/// part in that point's local affine fit.
pub(crate) const MEASURE_JET_PROFILE_CUTOFF: f64 = 3.0;
/// Relative eigenvalue cutoff for the symmetric pseudo-inverses; the absolute
/// threshold is this value times the matrix size times the largest
/// nonnegative eigenvalue.
pub(crate) const MEASURE_JET_PSEUDOINVERSE_RTOL: f64 = 64.0 * f64::EPSILON;
/// Order `s` used when a spec leaves `order_s` at its `0.0` "auto" sentinel.
pub(crate) const MEASURE_JET_DEFAULT_ORDER_S: f64 = 1.5;
/// Lower clamp for the automatically chosen number of band scales.
pub(crate) const MEASURE_JET_MIN_AUTO_SCALES: usize = 3;
/// Upper clamp for the automatically chosen number of band scales.
pub(crate) const MEASURE_JET_MAX_AUTO_SCALES: usize = 8;
/// Multiplier on the median nearest-center spacing when the design-kernel
/// length scale is chosen automatically.
pub(crate) const MEASURE_JET_AUTO_LENGTH_SCALE_FACTOR: f64 = 1.0;
/// Frobenius-norm fraction of the primary penalty at which the
/// affine-preserving ridge is fused into it (single-penalty mode).
pub(crate) const MEASURE_JET_FUSED_RIDGE_FRACTION: f64 = 1e-2;
/// Scales are assembled in parallel only while `m * m * n_scales * n_forms`
/// (the number of f64 entries held across all per-scale results) stays at or
/// below this budget; larger problems are assembled serially.
pub(crate) const MEASURE_JET_PARALLEL_FORM_BUDGET_DOUBLES: usize = 1 << 26;
/// How the raw per-center coefficients are constrained for identifiability.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum MeasureJetIdentifiability {
/// Default: Householder-based sum-to-zero reparameterization over the centers
/// (see `householder_sum_to_zero_z`).
#[default]
CenterSumToZero,
/// A previously realized transform reused verbatim; it must have one row per
/// center.
FrozenTransform { transform: Array2<f64> },
}
/// Quadrature captured from an earlier build. When present in a spec, it
/// replaces recomputation of center masses and of the scale band.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasureJetFrozenQuadrature {
/// Per-center masses; must have exactly one entry per selected center.
pub masses: Array1<f64>,
/// Scale band; must be nonempty. The log step is derived from the ratio of
/// the first two entries (ln 2 when there is only one).
pub eps_band: Vec<f64>,
/// NOTE(review): presumably the per-scale support means recorded at build
/// time; not read by the portion of this file shown here — confirm.
pub support_means: Vec<f64>,
/// NOTE(review): presumably the recorded per-level penalty normalization
/// scales (multiscale mode) — confirm against the consumer.
pub penalty_normalization_scales: Vec<f64>,
/// NOTE(review): presumably the per-level normalization scales divided by
/// each level's scale weight — confirm against the consumer.
pub raw_penalty_normalization_scales: Vec<f64>,
/// NOTE(review): presumably the normalization scale of the single fused
/// penalty (non-multiscale mode) — confirm against the consumer.
pub fused_penalty_normalization_scale: Option<f64>,
}
/// Serde default for `MeasureJetBasisSpec::learn_length_scale`: serialized
/// specs that omit the field do not learn the kernel length scale.
fn measure_jet_learn_length_scale_default() -> bool {
    bool::default()
}
/// User-facing configuration of a measure-jet smooth.
///
/// `order_s`, `num_scales`, and `length_scale` accept `0` / `0.0` as an
/// "auto" sentinel that is resolved when the geometry is realized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasureJetBasisSpec {
/// How seed centers are chosen from the data.
pub center_strategy: CenterStrategy,
/// Smoothness order `s`; the realized value must lie in (0, 2). `0.0` selects
/// `MEASURE_JET_DEFAULT_ORDER_S`.
pub order_s: f64,
/// Mass exponent: enters each scale's weight as `q^(1 - 2 alpha)` and the
/// scale exponent as `d * (2 - 2 alpha)`.
pub alpha: f64,
/// Must be finite and >= 0 for assembly, and > 0 when jets / ψ derivatives
/// are requested (their τ coordinate is ln τ). NOTE(review): not otherwise
/// used by the energy assembly visible in this file.
pub tau0: f64,
/// Number of band scales; `0` picks a count automatically within
/// `MEASURE_JET_MIN_AUTO_SCALES..=MEASURE_JET_MAX_AUTO_SCALES`.
pub num_scales: usize,
/// Gaussian design-kernel length scale; `0.0` selects
/// `MEASURE_JET_AUTO_LENGTH_SCALE_FACTOR` times the median nearest-center spacing.
pub length_scale: f64,
/// Add the affine-preserving coefficient ridge. It is fused into the primary
/// penalty, or added as a separate candidate in multiscale mode.
pub double_penalty: bool,
/// Whether ψ derivatives include a kernel length-scale coordinate
/// (`false` when absent from serialized specs).
#[serde(default = "measure_jet_learn_length_scale_default")]
pub learn_length_scale: bool,
/// Emit one penalty per band scale instead of a single summed penalty.
#[serde(default)]
pub multiscale: bool,
/// Coefficient constraint; defaults to center sum-to-zero.
#[serde(default)]
pub identifiability: MeasureJetIdentifiability,
/// Previously realized quadrature to reuse instead of recomputing it.
#[serde(default)]
pub frozen_quadrature: Option<MeasureJetFrozenQuadrature>,
}
impl Default for MeasureJetBasisSpec {
fn default() -> Self {
Self {
center_strategy: CenterStrategy::FarthestPoint { num_centers: 50 },
order_s: 0.0,
alpha: 1.0,
tau0: 1e-3,
num_scales: 0,
length_scale: 0.0,
double_penalty: true,
learn_length_scale: false,
multiscale: false,
identifiability: MeasureJetIdentifiability::CenterSumToZero,
frozen_quadrature: None,
}
}
}
/// Band of kernel scales over which the measure-jet energy is integrated.
///
/// Derives `Debug`/`Clone`/`PartialEq` so bands can be logged, copied, and
/// compared (all fields are plain `f64` data).
#[derive(Debug, Clone, PartialEq)]
pub struct MeasureJetBand {
    /// Scales; assembly requires a nonempty list of finite positive values.
    /// `measure_jet_band` produces them as a geometric progression.
    pub eps: Vec<f64>,
    /// Natural log of the ratio between consecutive scales (ln 2 for a
    /// single-scale band). It multiplies every scale's weight.
    pub log_step: f64,
}
/// Measure-jet energy form together with its derivatives with respect to the
/// order `s`, the exponent `alpha`, and `ln tau`.
///
/// All matrices are `m x m` over the centers. The `ln tau` derivatives are
/// currently assembled as zero by `measure_jet_energy_form_with_jets`.
#[derive(Debug, Clone)]
pub struct MeasureJetEnergyJets {
    /// Energy form `Q`.
    pub q: Array2<f64>,
    /// `dQ/ds`.
    pub dq_ds: Array2<f64>,
    /// `d²Q/ds²`.
    pub d2q_ds2: Array2<f64>,
    /// `dQ/dalpha`.
    pub dq_dalpha: Array2<f64>,
    /// `d²Q/dalpha²`.
    pub d2q_dalpha2: Array2<f64>,
    /// `d²Q/(ds dalpha)`.
    pub d2q_ds_dalpha: Array2<f64>,
    /// `dQ/d ln tau`.
    pub dq_dlogtau: Array2<f64>,
    /// `d²Q/(d ln tau)²`.
    pub d2q_dlogtau2: Array2<f64>,
    /// `d²Q/(ds d ln tau)`.
    pub d2q_ds_dlogtau: Array2<f64>,
    /// `d²Q/(dalpha d ln tau)`.
    pub d2q_dalpha_dlogtau: Array2<f64>,
}
/// Unit Householder vector `u` for which `I - 2 u uᵀ` maps `e₀` onto the
/// normalized all-ones direction `1/√m`. The remaining columns of that
/// reflection therefore span the sum-to-zero subspace.
///
/// Expects `m >= 2`: `m == 0` panics on indexing, and `m == 1` produces a
/// zero vector whose normalization yields NaNs.
pub(crate) fn householder_sum_to_zero_u(m: usize) -> Array1<f64> {
    let inv_sqrt_m = (m as f64).sqrt().recip();
    let mut reflector = Array1::<f64>::from_elem(m, inv_sqrt_m);
    // Subtract e₀ so the reflection swaps e₀ and the ones direction.
    reflector[0] -= 1.0;
    let length = reflector.dot(&reflector).sqrt();
    reflector.mapv_inplace(|x| x / length);
    reflector
}
/// Columns `1..m` of the Householder reflection `I - 2 u uᵀ`, forming an
/// `m x (m - 1)` orthonormal basis of the sum-to-zero subspace when `u` comes
/// from `householder_sum_to_zero_u`.
pub(crate) fn householder_sum_to_zero_z(u: &Array1<f64>) -> Array2<f64> {
    let m = u.len();
    Array2::<f64>::from_shape_fn((m, m - 1), |(row, col)| {
        // Column `col` of Z is column `col + 1` of the full reflection.
        let identity = if row == col + 1 { 1.0 } else { 0.0 };
        identity - 2.0 * u[row] * u[col + 1]
    })
}
/// Pseudo-inverse of a symmetric matrix via its eigendecomposition (the lower
/// triangle is read).
///
/// Negative eigenvalues are clipped to zero. Eigenvalues at or below
/// `MEASURE_JET_PSEUDOINVERSE_RTOL * n * lambda_max` are treated as zero, so
/// their directions are dropped from the inverse.
///
/// # Errors
/// Non-square input, or a failed eigendecomposition; `label` names the matrix
/// in the message.
pub(crate) fn symmetric_pseudoinverse(
    a: &Array2<f64>,
    label: &str,
) -> Result<Array2<f64>, BasisError> {
    let (rows, cols) = a.dim();
    if rows != cols {
        crate::bail_dim_basis!(
            "measure-jet pseudo-inverse `{label}` needs a square matrix, got {:?}",
            a.dim()
        );
    }
    let (evals, evecs) = a.eigh(Side::Lower).map_err(|e| {
        BasisError::InvalidInput(format!(
            "measure-jet pseudo-inverse `{label}` eigendecomposition failed: {e}"
        ))
    })?;
    let clipped: Vec<f64> = evals.iter().map(|&lam| lam.max(0.0)).collect();
    let lam_max = clipped.iter().copied().fold(0.0_f64, f64::max);
    let rank_tol = MEASURE_JET_PSEUDOINVERSE_RTOL * (rows.max(1) as f64) * lam_max;
    let mut inv_scaled = evecs.clone();
    for (mut column, &lam) in inv_scaled.axis_iter_mut(Axis(1)).zip(clipped.iter()) {
        let factor = if lam > rank_tol { 1.0 / lam } else { 0.0 };
        column.mapv_inplace(|v| v * factor);
    }
    // V · diag(1/λ) · Vᵀ restricted to the retained eigenvalues.
    Ok(inv_scaled.dot(&evecs.t()))
}
/// Coefficient-space ridge that leaves affine functions unpenalized.
///
/// The matrix `beta` holds the mass-weighted least-squares coefficients that
/// reproduce the affine monomials `[1, x_1, ..., x_d]` at the centers through
/// `kz`. The result is the identity minus the orthogonal projector onto the
/// span of `beta`'s columns, symmetrized.
///
/// It returns a `p x p` matrix, where `p = kz.ncols()`, and `0 x 0` when
/// `kz` has no columns.
///
/// # Errors
/// Shape mismatch between `kz`, `centers`, and `masses`, or a failed
/// pseudo-inverse / eigendecomposition.
pub(crate) fn affine_preserving_coefficient_ridge(
    kz: &Array2<f64>,
    centers: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, BasisError> {
    let (m, d) = centers.dim();
    let p = kz.ncols();
    if kz.nrows() != m || masses.len() != m {
        crate::bail_dim_basis!(
            "measure-jet affine-preserving ridge shape mismatch: kz {:?}, centers {:?}, masses {}",
            kz.dim(),
            centers.dim(),
            masses.len()
        );
    }
    if p == 0 {
        return Ok(Array2::<f64>::zeros((0, 0)));
    }
    // diag(masses) · mat, computed row by row.
    let mass_weighted = |mat: &Array2<f64>| -> Array2<f64> {
        let mut scaled = mat.clone();
        for (mut row, &mass) in scaled.outer_iter_mut().zip(masses.iter()) {
            row.mapv_inplace(|v| v * mass);
        }
        scaled
    };
    let normal = kz.t().dot(&mass_weighted(kz));
    let normal_pinv = symmetric_pseudoinverse(&normal, "affine ridge normal")?;
    // Affine monomials [1, x_1, ..., x_d] evaluated at every center.
    let affine = Array2::<f64>::from_shape_fn((m, d + 1), |(i, k)| {
        if k == 0 {
            1.0
        } else {
            centers[(i, k - 1)]
        }
    });
    let rhs = kz.t().dot(&mass_weighted(&affine));
    let beta = normal_pinv.dot(&rhs);
    let beta_gram = beta.t().dot(&beta);
    let (evals, evecs) = beta_gram.eigh(Side::Lower).map_err(|e| {
        BasisError::InvalidInput(format!(
            "measure-jet affine ridge subspace eigendecomposition failed: {e}"
        ))
    })?;
    let lam_max = evals
        .iter()
        .map(|&v| v.max(0.0))
        .fold(0.0_f64, f64::max);
    let rank_tol = MEASURE_JET_PSEUDOINVERSE_RTOL * ((d + 1).max(1) as f64) * lam_max;
    let mut ridge = Array2::<f64>::eye(p);
    for k in 0..=d {
        let lam = evals[k].max(0.0);
        if lam > rank_tol {
            // beta v_k / sqrt(lambda_k) is a unit vector in span(beta).
            let direction = beta.dot(&evecs.column(k).to_owned()) / lam.sqrt();
            for ((r, c), entry) in ridge.indexed_iter_mut() {
                *entry -= direction[r] * direction[c];
            }
        }
    }
    Ok((&ridge + &ridge.t()) * 0.5)
}
/// Squared Euclidean distances between the rows of `a` and the rows of `b`.
///
/// Uses `|a_i|² + |b_j|² - 2 a_i·b_j`, clamped at zero to absorb cancellation
/// error. Rows of the result are finalized in parallel.
pub(crate) fn pairwise_sq_dists(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
    let a_norms: Vec<f64> = a.axis_iter(Axis(0)).map(|row| row.dot(&row)).collect();
    let b_norms: Vec<f64> = b.axis_iter(Axis(0)).map(|row| row.dot(&row)).collect();
    let mut dists = a.dot(&b.t());
    dists
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .for_each(|(i, mut row)| {
            let a_sq = a_norms[i];
            for (cell, &b_sq) in row.iter_mut().zip(b_norms.iter()) {
                *cell = (a_sq + b_sq - 2.0 * *cell).max(0.0);
            }
        });
    dists
}
/// Number of data rows per block when assigning rows to nearest centers;
/// bounds the size of each block's data-by-center inner-product matrix.
pub(crate) const MEASURE_JET_ASSIGN_BLOCK_ROWS: usize = 65_536;
pub(crate) fn validate_finite_points(
points: ArrayView2<'_, f64>,
what: &str,
) -> Result<(), BasisError> {
for (i, row) in points.outer_iter().enumerate() {
if row.iter().any(|v| !v.is_finite()) {
crate::bail_invalid_basis!("measure-jet {what} row {i} has a non-finite coordinate");
}
}
Ok(())
}
pub(crate) fn median_nearest_center_spacing(dist2: &Array2<f64>) -> Result<f64, BasisError> {
let m = dist2.nrows();
if m < 2 {
return Err(BasisError::InsufficientColumnsForConstraint { found: m });
}
let mut nearest: Vec<f64> = Vec::with_capacity(m);
for i in 0..m {
let mut best = f64::INFINITY;
for j in 0..m {
if j != i && dist2[(i, j)] < best {
best = dist2[(i, j)];
}
}
nearest.push(best.sqrt());
}
nearest.sort_by(|a, b| a.partial_cmp(b).expect("finite center spacings"));
let median = nearest[nearest.len() / 2];
if !(median.is_finite() && median > 0.0) {
crate::bail_invalid_basis!(
"measure-jet centers are degenerate (median nearest-center spacing = {median}); \
duplicate centers cannot carry a scale band"
);
}
Ok(median)
}
pub fn measure_jet_band(
centers: ArrayView2<'_, f64>,
num_scales: usize,
) -> Result<MeasureJetBand, BasisError> {
validate_finite_points(centers, "centers")?;
let dist2 = pairwise_sq_dists(centers, centers);
let eps_min = median_nearest_center_spacing(&dist2)?;
let d = centers.ncols();
let mut diag2 = 0.0_f64;
for k in 0..d {
let col = centers.column(k);
let mut lo = f64::INFINITY;
let mut hi = f64::NEG_INFINITY;
for &v in col.iter() {
lo = lo.min(v);
hi = hi.max(v);
}
diag2 += (hi - lo) * (hi - lo);
}
let eps_max = 0.5 * diag2.sqrt();
if !(eps_max.is_finite() && eps_max > eps_min) {
return Ok(MeasureJetBand {
eps: vec![eps_min],
log_step: std::f64::consts::LN_2,
});
}
let auto = ((eps_max / eps_min).log2().ceil() as usize + 1)
.clamp(MEASURE_JET_MIN_AUTO_SCALES, MEASURE_JET_MAX_AUTO_SCALES);
let count = if num_scales == 0 { auto } else { num_scales };
if count == 1 {
return Ok(MeasureJetBand {
eps: vec![eps_min],
log_step: std::f64::consts::LN_2,
});
}
let ratio = (eps_max / eps_min).powf(1.0 / (count as f64 - 1.0));
let mut eps = Vec::with_capacity(count);
let mut e = eps_min;
for _ in 0..count {
eps.push(e);
e *= ratio;
}
Ok(MeasureJetBand {
eps,
log_step: ratio.ln(),
})
}
/// Assigns every data row to its nearest center and returns quadrature
/// nodes and masses.
///
/// Nearest means smallest `|c_j|² - 2 x·c_j`, i.e. squared distance up to the
/// per-row constant `|x|²`. Ties go to the lowest center index. Rows are
/// processed in blocks of `MEASURE_JET_ASSIGN_BLOCK_ROWS`, with the rows of
/// each block assigned in parallel.
///
/// Each mass is the exact fraction `count / n` of rows assigned to that
/// center. A center with at least one assigned row is replaced by the
/// barycenter of those rows. A center with no rows keeps its position and
/// gets mass `0`. Counts are kept as integers, so barycenters divide by the
/// exact count rather than a float reconstruction of it.
///
/// # Errors
/// Dimension mismatch, non-finite coordinates, or empty data/centers.
pub fn measure_jet_quadrature_nodes(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
) -> Result<(Array2<f64>, Array1<f64>), BasisError> {
    if data.ncols() != centers.ncols() {
        crate::bail_dim_basis!(
            "measure-jet mass assignment dimension mismatch: data d={} centers d={}",
            data.ncols(),
            centers.ncols()
        );
    }
    validate_finite_points(data, "data")?;
    validate_finite_points(centers, "centers")?;
    let n = data.nrows();
    let m = centers.nrows();
    let d = centers.ncols();
    if n == 0 || m == 0 {
        crate::bail_invalid_basis!("measure-jet mass assignment needs nonempty data and centers");
    }
    let cn: Vec<f64> = centers.outer_iter().map(|r| r.dot(&r)).collect();
    let assignments: Vec<usize> = (0..n)
        .step_by(MEASURE_JET_ASSIGN_BLOCK_ROWS)
        .flat_map(|start| {
            let end = (start + MEASURE_JET_ASSIGN_BLOCK_ROWS).min(n);
            let g = data.slice(ndarray::s![start..end, ..]).dot(&centers.t());
            g.axis_iter(Axis(0))
                .into_par_iter()
                .map(|row| {
                    let mut best_j = 0usize;
                    let mut best = f64::INFINITY;
                    for (j, &gij) in row.iter().enumerate() {
                        let s = cn[j] - 2.0 * gij;
                        if s < best {
                            best = s;
                            best_j = j;
                        }
                    }
                    best_j
                })
                .collect::<Vec<usize>>()
        })
        .collect();
    let mut counts = vec![0usize; m];
    let mut sums = Array2::<f64>::zeros((m, d));
    for (i, &j) in assignments.iter().enumerate() {
        counts[j] += 1;
        for k in 0..d {
            sums[(j, k)] += data[(i, k)];
        }
    }
    let n_f = n as f64;
    let masses = Array1::from_iter(counts.iter().map(|&c| c as f64 / n_f));
    let mut nodes = centers.to_owned();
    for (j, &count) in counts.iter().enumerate() {
        if count > 0 {
            let count_f = count as f64;
            for k in 0..d {
                nodes[(j, k)] = sums[(j, k)] / count_f;
            }
        }
    }
    Ok((nodes, masses))
}
/// Per-center quadrature masses: the fraction of data rows nearest to each
/// center. Barycentric nodes are computed and then discarded.
///
/// # Errors
/// Same conditions as `measure_jet_quadrature_nodes`.
pub fn measure_jet_center_masses(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
) -> Result<Array1<f64>, BasisError> {
    let (_nodes, masses) = measure_jet_quadrature_nodes(data, centers)?;
    Ok(masses)
}
/// Assembles `n_forms` weighted measure-jet energy forms (`m x m` over the
/// centers) over every scale of `band`.
///
/// For each scale `eps`:
/// 1. A greedy `eps/2`-net of positive-mass centers is built. Every
///    positive-mass center hands its mass to its nearest net point.
/// 2. Each net point `i` fits a mass- and Gaussian-weighted local affine
///    model over the centers within `MEASURE_JET_PROFILE_CUTOFF * eps`.
/// 3. The residual projector of that fit is added into every output form,
///    scaled by the weight `weights` writes for that form.
///
/// `weights(scale_idx, eps, q, base, out)` receives the local kernel mass `q`
/// and the base weight `log_step * eps^(-eta) * net_mass * q^(1 - 2 alpha)`,
/// where `eta = 2 s + d (2 - 2 alpha)`. It must fill `out`, which has one
/// slot per output form.
///
/// Only channel `0` of each slot is accumulated. `channels` is validated to
/// lie in `1..=3` but does not otherwise affect assembly, and `tau0` is
/// validated (finite, `>= 0`) but not used here.
///
/// Scales run in parallel while `m² · n_scales · n_forms` fits within
/// `MEASURE_JET_PARALLEL_FORM_BUDGET_DOUBLES`, and serially otherwise. Each
/// returned form is explicitly symmetrized.
///
/// # Errors
/// - Invalid inputs: zero forms, a channel count outside `1..=3`, a
///   mass/center mismatch, or an empty or non-positive band.
/// - Invalid parameters: `order_s` outside `(0, 2)`, non-finite `alpha` or
///   `tau0`, `tau0 < 0`, or negative masses.
/// - A failed local-Gram pseudo-inverse.
pub(crate) fn assemble_weighted_forms<F>(
    centers: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
    band: &MeasureJetBand,
    order_s: f64,
    alpha: f64,
    tau0: f64,
    n_forms: usize,
    channels: usize,
    weights: &F,
) -> Result<Vec<Array2<f64>>, BasisError>
where
    F: Fn(usize, f64, f64, f64, &mut [[f64; 3]]) + Sync,
{
    let m = centers.nrows();
    let d = centers.ncols();
    if n_forms == 0 || !(1..=3).contains(&channels) {
        crate::bail_invalid_basis!(
            "measure-jet assembly needs at least one output form and 1..=3 block channels"
        );
    }
    if masses.len() != m {
        crate::bail_dim_basis!(
            "measure-jet energy mass/center mismatch: {} masses for {} centers",
            masses.len(),
            m
        );
    }
    if band.eps.is_empty() || band.eps.iter().any(|e| !(e.is_finite() && *e > 0.0)) {
        crate::bail_invalid_basis!("measure-jet energy needs a nonempty positive scale band");
    }
    if !(order_s.is_finite() && order_s > 0.0 && order_s < 2.0) {
        crate::bail_invalid_basis!(
            "measure-jet order s must lie in (0, 2) for the affine-jet energy; got {order_s}"
        );
    }
    if !(alpha.is_finite() && tau0.is_finite() && tau0 >= 0.0) {
        crate::bail_invalid_basis!(
            "measure-jet energy needs finite alpha and finite tau0 >= 0; got alpha={alpha}, tau0={tau0}"
        );
    }
    if masses.iter().any(|v| !(v.is_finite() && *v >= 0.0)) {
        crate::bail_invalid_basis!("measure-jet energy needs finite nonnegative center masses");
    }
    let dist2 = pairwise_sq_dists(centers, centers);
    // Scale exponent eta = 2 s + d (2 - 2 alpha); identical for every scale.
    let eta = 2.0 * order_s + (d as f64) * (2.0 - 2.0 * alpha);
    let assemble_scale = |scale_idx: usize, eps: f64| -> Result<Vec<Array2<f64>>, BasisError> {
        let mut out: Vec<Array2<f64>> =
            (0..n_forms).map(|_| Array2::<f64>::zeros((m, m))).collect();
        let cutoff2 = (MEASURE_JET_PROFILE_CUTOFF * eps) * (MEASURE_JET_PROFILE_CUTOFF * eps);
        let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
        let scale_weight = band.log_step * eps.powf(-eta);
        // Greedy eps/2-net over the positive-mass centers.
        let net_radius2 = 0.25 * eps * eps;
        let mut outer: Vec<usize> = Vec::new();
        for i in 0..m {
            if masses[i] <= 0.0 {
                continue;
            }
            let covered = outer.iter().any(|&o| dist2[(i, o)] <= net_radius2);
            if !covered {
                outer.push(i);
            }
        }
        // Every positive-mass center hands its mass to its nearest net point.
        let mut net_mass = vec![0.0_f64; m];
        for i in 0..m {
            if masses[i] <= 0.0 {
                continue;
            }
            let mut best = f64::INFINITY;
            let mut best_o = usize::MAX;
            for &o in &outer {
                if dist2[(i, o)] < best {
                    best = dist2[(i, o)];
                    best_o = o;
                }
            }
            if best_o != usize::MAX {
                net_mass[best_o] += masses[i];
            }
        }
        let mut wbuf = vec![[0.0_f64; 3]; n_forms];
        for &i in &outer {
            // Centers inside the truncated Gaussian profile around net point i.
            let mut idx: Vec<usize> = Vec::new();
            for j in 0..m {
                if dist2[(i, j)] <= cutoff2 {
                    idx.push(j);
                }
            }
            let ml = idx.len();
            let mut w = Array1::<f64>::zeros(ml);
            let mut q = 0.0_f64;
            for (a, &j) in idx.iter().enumerate() {
                let wj = masses[j] * (-dist2[(i, j)] * inv_two_eps2).exp();
                w[a] = wj;
                q += wj;
            }
            if !(q > 0.0) {
                continue;
            }
            // Offsets from the net point, in units of eps.
            let mut phi = Array2::<f64>::zeros((ml, d));
            for (a, &j) in idx.iter().enumerate() {
                for k in 0..d {
                    phi[(a, k)] = (centers[(j, k)] - centers[(i, k)]) / eps;
                }
            }
            let a_mean = phi.t().dot(&w) / q;
            let mut wphi = phi.clone();
            for (a, mut row) in wphi.outer_iter_mut().enumerate() {
                row.mapv_inplace(|v| v * w[a]);
            }
            // b = W (phi - 1 a_meanᵀ): weighted, centered offsets.
            let mut b = wphi.clone();
            for (a, mut row) in b.outer_iter_mut().enumerate() {
                for k in 0..d {
                    row[k] -= w[a] * a_mean[k];
                }
            }
            // g: weighted covariance of the offsets.
            let mut g = phi.t().dot(&wphi);
            g.mapv_inplace(|v| v / q);
            for r in 0..d {
                for c in 0..d {
                    g[(r, c)] -= a_mean[r] * a_mean[c];
                }
            }
            let g_pinv = symmetric_pseudoinverse(&g, "local affine Gram")?;
            let bm = b.dot(&g_pinv);
            let base = scale_weight * net_mass[i] * q.powf(1.0 - 2.0 * alpha);
            weights(scale_idx, eps, q, base, &mut wbuf);
            // Residual of the weighted affine fit:
            // W - W11ᵀW/q - b G⁺ bᵀ / q.
            for (a, &ja) in idx.iter().enumerate() {
                let bma = bm.row(a);
                for (c, &jc) in idx.iter().enumerate() {
                    let b_c = b.row(c);
                    let mut val_r = -w[a] * w[c] / q - bma.dot(&b_c) / q;
                    if a == c {
                        val_r += w[a];
                    }
                    for (k, out_k) in out.iter_mut().enumerate() {
                        let wk = wbuf[k];
                        out_k[(ja, jc)] += wk[0] * val_r;
                    }
                }
            }
        }
        Ok(out)
    };
    let n_scales = band.eps.len();
    let parallel_ok = m
        .saturating_mul(m)
        .saturating_mul(n_scales)
        .saturating_mul(n_forms)
        <= MEASURE_JET_PARALLEL_FORM_BUDGET_DOUBLES;
    let per_scale: Vec<Vec<Array2<f64>>> = if parallel_ok {
        band.eps
            .par_iter()
            .enumerate()
            .map(|(scale_idx, &eps)| assemble_scale(scale_idx, eps))
            .collect::<Result<Vec<_>, BasisError>>()?
    } else {
        band.eps
            .iter()
            .enumerate()
            .map(|(scale_idx, &eps)| assemble_scale(scale_idx, eps))
            .collect::<Result<Vec<_>, BasisError>>()?
    };
    let mut totals: Vec<Array2<f64>> = (0..n_forms).map(|_| Array2::<f64>::zeros((m, m))).collect();
    for scale_forms in per_scale {
        for (total, part) in totals.iter_mut().zip(scale_forms) {
            *total += &part;
        }
    }
    Ok(totals.into_iter().map(|t| (&t + &t.t()) * 0.5).collect())
}
/// Measure-jet energy summed over all band scales, projected onto the
/// positive-semidefinite cone.
///
/// # Errors
/// Propagates assembly validation errors and eigendecomposition failures.
pub fn measure_jet_energy_form(
    centers: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
    band: &MeasureJetBand,
    order_s: f64,
    alpha: f64,
    tau0: f64,
) -> Result<Array2<f64>, BasisError> {
    let summed = assemble_weighted_forms(
        centers,
        masses,
        band,
        order_s,
        alpha,
        tau0,
        1,
        1,
        &|_, _, _, base, out: &mut [[f64; 3]]| {
            // Every scale contributes with its unmodified base weight.
            out[0] = [base, 0.0, 0.0];
        },
    )?
    .into_iter()
    .next()
    .expect("one assembled form");
    project_symmetric_psd(summed, "measure-jet energy form")
}
/// Nearest positive-semidefinite matrix in the eigen-clipping sense.
///
/// The input is returned unchanged when it is empty or when all its
/// eigenvalues are already `>= 0`. Otherwise negative eigenvalues are zeroed,
/// and the reconstruction `V diag(max(λ, 0)) Vᵀ` is symmetrized.
///
/// # Errors
/// A failed eigendecomposition; `label` names the matrix in the message.
pub(crate) fn project_symmetric_psd(
    a: Array2<f64>,
    label: &str,
) -> Result<Array2<f64>, BasisError> {
    if a.nrows() == 0 {
        return Ok(a);
    }
    let (evals, evecs) = a.eigh(Side::Lower).map_err(|e| {
        BasisError::InvalidInput(format!(
            "measure-jet PSD projection `{label}` eigendecomposition failed: {e}"
        ))
    })?;
    let already_psd = evals.iter().all(|&lam| lam >= 0.0);
    if already_psd {
        return Ok(a);
    }
    let mut clipped = evecs.clone();
    for (mut column, &lam) in clipped.axis_iter_mut(Axis(1)).zip(evals.iter()) {
        let kept = lam.max(0.0);
        column.mapv_inplace(|v| v * kept);
    }
    let rebuilt = clipped.dot(&evecs.t());
    Ok((&rebuilt + &rebuilt.t()) * 0.5)
}
/// Measure-jet energy with first and second derivatives in `s`, `alpha`, and
/// `ln tau`.
///
/// Each scale's base weight is `log_step * eps^(-eta) * net_mass * q^(1 - 2 alpha)`
/// with `eta = 2 s + d (2 - 2 alpha)`. Since the exponents are linear in `s`
/// and `alpha`:
/// - `d/ds` multiplies the weight by `-2 ln eps`;
/// - `d/dalpha` multiplies it by `2 d ln eps - 2 ln q`;
/// - second derivatives multiply it by the corresponding products.
///
/// The `ln tau` derivatives are assembled as zero. Unlike
/// `measure_jet_energy_form`, the forms are not PSD-projected.
///
/// # Errors
/// Requires a finite `tau0 > 0`; otherwise propagates assembly errors.
pub fn measure_jet_energy_form_with_jets(
    centers: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
    band: &MeasureJetBand,
    order_s: f64,
    alpha: f64,
    tau0: f64,
) -> Result<MeasureJetEnergyJets, BasisError> {
    if !(tau0.is_finite() && tau0 > 0.0) {
        crate::bail_invalid_basis!(
            "measure-jet jets need tau0 > 0 because the retained τ coordinate is ln τ; got {tau0}"
        );
    }
    let intrinsic_dim = centers.ncols() as f64;
    let forms = assemble_weighted_forms(
        centers,
        masses,
        band,
        order_s,
        alpha,
        tau0,
        10,
        3,
        &|_, eps: f64, q: f64, base: f64, out: &mut [[f64; 3]]| {
            let log_eps = eps.ln();
            let g_s = -2.0 * log_eps;
            let g_alpha = 2.0 * intrinsic_dim * log_eps - 2.0 * q.max(f64::MIN_POSITIVE).ln();
            out[0] = [base, 0.0, 0.0];
            out[1] = [g_s * base, 0.0, 0.0];
            out[2] = [g_s * g_s * base, 0.0, 0.0];
            out[3] = [g_alpha * base, 0.0, 0.0];
            out[4] = [g_alpha * g_alpha * base, 0.0, 0.0];
            out[5] = [g_s * g_alpha * base, 0.0, 0.0];
            // The energy has no tau dependence, so every ln-tau jet is zero.
            for slot in &mut out[6..] {
                *slot = [0.0; 3];
            }
        },
    )?;
    let mut remaining = forms.into_iter();
    let mut next_form = || remaining.next().expect("ten assembled forms");
    // Struct-literal fields are evaluated in the order written, matching slots 0..=9.
    Ok(MeasureJetEnergyJets {
        q: next_form(),
        dq_ds: next_form(),
        d2q_ds2: next_form(),
        dq_dalpha: next_form(),
        d2q_dalpha2: next_form(),
        d2q_ds_dalpha: next_form(),
        dq_dlogtau: next_form(),
        d2q_dlogtau2: next_form(),
        d2q_ds_dlogtau: next_form(),
        d2q_dalpha_dlogtau: next_form(),
    })
}
/// Per-scale energies `vᵀ Q_l v` of a vector of center values.
///
/// # Errors
/// `values` must have one entry per center; otherwise assembly errors are
/// propagated.
pub fn measure_jet_scale_spectrum(
    centers: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
    band: &MeasureJetBand,
    order_s: f64,
    alpha: f64,
    tau0: f64,
    values: ArrayView1<'_, f64>,
) -> Result<Vec<f64>, BasisError> {
    let n_centers = centers.nrows();
    if values.len() != n_centers {
        crate::bail_dim_basis!(
            "measure-jet scale spectrum needs one value per center: {} values for {} centers",
            values.len(),
            n_centers
        );
    }
    let per_scale = measure_jet_energy_forms_per_scale(centers, masses, band, order_s, alpha, tau0)?;
    let mut spectrum = Vec::with_capacity(per_scale.len());
    for q_l in &per_scale {
        let q_values = q_l.dot(&values);
        spectrum.push(values.dot(&q_values));
    }
    Ok(spectrum)
}
/// One measure-jet energy form per band scale (not PSD-projected).
///
/// Form `l` receives only the contribution of scale `l`.
///
/// # Errors
/// Propagates assembly validation errors, including an empty band.
pub fn measure_jet_energy_forms_per_scale(
    centers: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
    band: &MeasureJetBand,
    order_s: f64,
    alpha: f64,
    tau0: f64,
) -> Result<Vec<Array2<f64>>, BasisError> {
    assemble_weighted_forms(
        centers,
        masses,
        band,
        order_s,
        alpha,
        tau0,
        band.eps.len(),
        1,
        &|scale_idx, _, _, base, out: &mut [[f64; 3]]| {
            // Route this scale's weight into its own form only.
            out.fill([0.0; 3]);
            out[scale_idx] = [base, 0.0, 0.0];
        },
    )
}
/// Gaussian-smoothed mass at each query point for each scale.
///
/// Entry `(q, l)` is `sum_j masses[j] * exp(-|x_q - c_j|² / (2 eps_l²))`.
/// Query rows are evaluated in parallel.
///
/// # Errors
/// Dimension or mass-count mismatch, an empty or non-positive band, or
/// non-finite coordinates.
pub fn measure_jet_support_curve(
    queries: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
    eps_band: &[f64],
) -> Result<Array2<f64>, BasisError> {
    if queries.ncols() != centers.ncols() {
        crate::bail_dim_basis!(
            "measure-jet support curve dimension mismatch: queries d={} centers d={}",
            queries.ncols(),
            centers.ncols()
        );
    }
    if masses.len() != centers.nrows() {
        crate::bail_dim_basis!(
            "measure-jet support curve mass/center mismatch: {} masses for {} centers",
            masses.len(),
            centers.nrows()
        );
    }
    if eps_band.is_empty() || eps_band.iter().any(|e| !(e.is_finite() && *e > 0.0)) {
        crate::bail_invalid_basis!("measure-jet support curve needs a nonempty positive band");
    }
    validate_finite_points(queries, "queries")?;
    validate_finite_points(centers, "centers")?;
    let sq_dists = pairwise_sq_dists(queries, centers);
    let mut curve = Array2::<f64>::zeros((queries.nrows(), eps_band.len()));
    curve
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .for_each(|(qi, mut row)| {
            let dist_row = sq_dists.row(qi);
            for (slot, &eps) in row.iter_mut().zip(eps_band.iter()) {
                let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
                *slot = dist_row
                    .iter()
                    .zip(masses.iter())
                    .fold(0.0_f64, |acc, (&dd, &mass)| {
                        acc + mass * (-dd * inv_two_eps2).exp()
                    });
            }
        });
    Ok(curve)
}
/// Mass-weighted mean, at each scale, of the support curve evaluated at the
/// centers themselves.
///
/// # Errors
/// - Non-positive or non-finite total mass.
/// - Any support-curve error.
/// - A mean that is not positive and finite.
pub(crate) fn measure_jet_support_means(
    centers: ArrayView2<'_, f64>,
    masses: ArrayView1<'_, f64>,
    eps_band: &[f64],
) -> Result<Vec<f64>, BasisError> {
    let total_mass = masses.sum();
    if !(total_mass.is_finite() && total_mass > 0.0) {
        crate::bail_invalid_basis!(
            "measure-jet support means need positive finite total mass; got {total_mass}"
        );
    }
    let support = measure_jet_support_curve(centers, centers, masses, eps_band)?;
    let mut means = vec![0.0_f64; eps_band.len()];
    for (row, &weight) in support.outer_iter().zip(masses.iter()) {
        for (acc, &q) in means.iter_mut().zip(row.iter()) {
            *acc += weight * q;
        }
    }
    for slot in means.iter_mut() {
        *slot /= total_mass;
        let mean = *slot;
        if !(mean.is_finite() && mean > 0.0) {
            crate::bail_invalid_basis!(
                "measure-jet support mean must be positive and finite; got {mean}"
            );
        }
    }
    Ok(means)
}
/// Gaussian kernel design `K[i, j] = exp(-|x_i - c_j|² / (2 l²))`.
///
/// # Errors
/// Dimension mismatch, a non-positive or non-finite `length_scale`, or
/// non-finite coordinates.
pub fn measure_jet_design_matrix(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    length_scale: f64,
) -> Result<Array2<f64>, BasisError> {
    if data.ncols() != centers.ncols() {
        crate::bail_dim_basis!(
            "measure-jet design dimension mismatch: data d={} centers d={}",
            data.ncols(),
            centers.ncols()
        );
    }
    if !(length_scale.is_finite() && length_scale > 0.0) {
        crate::bail_invalid_basis!(
            "measure-jet design needs a positive finite length_scale; got {length_scale}"
        );
    }
    validate_finite_points(data, "data")?;
    validate_finite_points(centers, "centers")?;
    let inv_two_l2 = (2.0 * length_scale * length_scale).recip();
    let mut kernel = pairwise_sq_dists(data, centers);
    kernel
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .for_each(|mut row| {
            for value in row.iter_mut() {
                *value = (-*value * inv_two_l2).exp();
            }
        });
    Ok(kernel)
}
/// Resolves the design-kernel length scale.
///
/// A positive finite spec value is used as-is. `0.0` selects
/// `MEASURE_JET_AUTO_LENGTH_SCALE_FACTOR` times the median nearest-center
/// spacing.
///
/// # Errors
/// Any other spec value (negative, NaN, or infinite), or degenerate centers
/// in auto mode.
pub fn realized_measure_jet_length_scale(
    centers: ArrayView2<'_, f64>,
    spec_length_scale: f64,
) -> Result<f64, BasisError> {
    let explicit = spec_length_scale.is_finite() && spec_length_scale > 0.0;
    if !explicit && spec_length_scale != 0.0 {
        crate::bail_invalid_basis!(
            "measure-jet length_scale must be positive (or 0.0 for auto); got {spec_length_scale}"
        );
    }
    if explicit {
        return Ok(spec_length_scale);
    }
    let spacing = median_nearest_center_spacing(&pairwise_sq_dists(centers, centers))?;
    Ok(MEASURE_JET_AUTO_LENGTH_SCALE_FACTOR * spacing)
}
/// Geometry realized from data and a spec; produced by
/// `realize_measure_jet_geometry`.
pub(crate) struct RealizedMeasureJetGeometry {
/// Quadrature nodes (barycenters of the assigned data), or the seed centers
/// when the spec carries frozen quadrature.
pub(crate) centers: Array2<f64>,
/// Per-center quadrature masses.
pub(crate) masses: Array1<f64>,
/// Scale band.
pub(crate) eps_band: Vec<f64>,
/// Natural log of the band's scale ratio.
pub(crate) log_step: f64,
/// Realized Gaussian design-kernel length scale.
pub(crate) length_scale: f64,
/// Order `s`, with the `0.0` sentinel resolved to the default.
pub(crate) order_s_eval: f64,
/// Copy of `spec.multiscale`: one penalty per scale when true.
pub(crate) per_level: bool,
/// Identifiability transform, one row per center.
pub(crate) z: Array2<f64>,
/// Gauge that restricts raw designs to the constrained coefficients.
pub(crate) coefficient_gauge: gam_problem::Gauge,
/// Center-by-center kernel matrix restricted by `coefficient_gauge`.
pub(crate) kz: Array2<f64>,
}
/// Realizes the center geometry, quadrature, band, length scale, and
/// identifiability transform for a measure-jet smooth.
///
/// Without frozen quadrature, the seed centers are replaced by barycentric
/// quadrature nodes, and the band is derived from those nodes. With frozen
/// quadrature, the seed centers, masses, and band are used as stored.
///
/// # Errors
/// - Invalid data: no feature columns, or non-finite data.
/// - Too few centers: fewer than three.
/// - Inconsistent frozen state: mismatched frozen masses or transform, or an
///   empty frozen band.
/// - Downstream failures: any error from center selection, quadrature, band,
///   or design construction.
pub(crate) fn realize_measure_jet_geometry(
    data: ArrayView2<'_, f64>,
    spec: &MeasureJetBasisSpec,
) -> Result<RealizedMeasureJetGeometry, BasisError> {
    if data.ncols() == 0 {
        crate::bail_invalid_basis!("measure-jet smooth needs at least one feature column");
    }
    validate_finite_points(data, "data")?;
    let seed_centers = select_centers_by_strategy(data, &spec.center_strategy)?;
    let m = seed_centers.nrows();
    if m < 3 {
        return Err(BasisError::InsufficientColumnsForConstraint { found: m });
    }
    let order_s = match spec.order_s {
        s if s == 0.0 => MEASURE_JET_DEFAULT_ORDER_S,
        s => s,
    };
    let (centers, masses, eps_band, log_step) = if let Some(frozen) = &spec.frozen_quadrature {
        if frozen.masses.len() != m {
            crate::bail_dim_basis!(
                "frozen measure-jet quadrature mismatch: {} masses for {} centers",
                frozen.masses.len(),
                m
            );
        }
        if frozen.eps_band.is_empty() {
            crate::bail_invalid_basis!("frozen measure-jet quadrature has an empty band");
        }
        let log_step = match frozen.eps_band.as_slice() {
            [first, second, ..] => (second / first).ln(),
            _ => std::f64::consts::LN_2,
        };
        (
            seed_centers,
            frozen.masses.clone(),
            frozen.eps_band.clone(),
            log_step,
        )
    } else {
        let (nodes, masses) = measure_jet_quadrature_nodes(data, seed_centers.view())?;
        let band = measure_jet_band(nodes.view(), spec.num_scales)?;
        (nodes, masses, band.eps, band.log_step)
    };
    let length_scale = realized_measure_jet_length_scale(centers.view(), spec.length_scale)?;
    let (z, coefficient_gauge) = match &spec.identifiability {
        MeasureJetIdentifiability::FrozenTransform { transform } => {
            if transform.nrows() != m {
                crate::bail_dim_basis!(
                    "frozen measure-jet identifiability transform mismatch: {} centers but transform has {} rows",
                    m,
                    transform.nrows()
                );
            }
            let gauge = gam_problem::Gauge::from_block_transforms(&[transform.clone()]);
            (transform.clone(), gauge)
        }
        MeasureJetIdentifiability::CenterSumToZero => {
            let z = householder_sum_to_zero_z(&householder_sum_to_zero_u(m));
            let gauge = gam_problem::Gauge::sum_to_zero(z.clone());
            (z, gauge)
        }
    };
    let center_kernel = measure_jet_design_matrix(centers.view(), centers.view(), length_scale)?;
    let kz = coefficient_gauge.restrict_design(&center_kernel);
    Ok(RealizedMeasureJetGeometry {
        centers,
        masses,
        eps_band,
        log_step,
        length_scale,
        order_s_eval: order_s,
        per_level: spec.multiscale,
        z,
        coefficient_gauge,
        kz,
    })
}
pub fn measure_jet_multiscale_mode(spec: &MeasureJetBasisSpec) -> bool {
spec.multiscale
}
/// Builds a measure-jet smooth.
///
/// The design is a Gaussian kernel on the realized centers, restricted by the
/// identifiability gauge. The penalty candidates are normalized and derived
/// from the measure-jet energy.
///
/// - Multiscale mode: each band scale yields its own penalty. With
///   `double_penalty`, the affine-preserving ridge is added as a separate
///   candidate.
/// - Otherwise: the PSD-projected summed energy is used. With
///   `double_penalty`, the ridge is fused in at
///   `MEASURE_JET_FUSED_RIDGE_FRACTION` of the energy's Frobenius norm.
///
/// The metadata records the spec's `order_s` as given (possibly the `0.0`
/// sentinel) rather than the realized order.
///
/// # Errors
/// Propagates geometry, design, support-mean, energy, ridge, and
/// penalty-filtering errors.
pub fn build_measure_jet_basis(
    data: ArrayView2<'_, f64>,
    spec: &MeasureJetBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
    let RealizedMeasureJetGeometry {
        centers,
        masses,
        eps_band,
        log_step,
        length_scale,
        order_s_eval: order_s,
        per_level,
        z,
        coefficient_gauge,
        kz,
    } = realize_measure_jet_geometry(data, spec)?;
    let band = MeasureJetBand {
        eps: eps_band.clone(),
        log_step,
    };
    let raw_design = measure_jet_design_matrix(data, centers.view(), length_scale)?;
    let constrained_design = coefficient_gauge.restrict_design(&raw_design);
    let design = gam_linalg::matrix::DesignMatrix::Dense(
        gam_linalg::matrix::DenseDesignMatrix::from(constrained_design),
    );
    let support_means = measure_jet_support_means(centers.view(), masses.view(), &eps_band)?;
    let mut candidates = Vec::new();
    let mut penalty_normalization_scales = Vec::new();
    let mut raw_penalty_normalization_scales = Vec::new();
    let mut fused_penalty_normalization_scale = None;
    if per_level {
        let forms = measure_jet_energy_forms_per_scale(
            centers.view(),
            masses.view(),
            &band,
            order_s,
            spec.alpha,
            spec.tau0,
        )?;
        // Scale exponent shared by every level; the same eta used during assembly.
        let intrinsic_dim = centers.ncols() as f64;
        let eta = 2.0 * order_s + intrinsic_dim * (2.0 - 2.0 * spec.alpha);
        for (level, q_l) in forms.into_iter().enumerate() {
            let s_l = kz.t().dot(&q_l).dot(&kz);
            let (s_norm, c_l) = normalize_penalty(&((&s_l + &s_l.t()) * 0.5));
            // Dividing out the level's scale weight gives the unweighted scale.
            let scale_weight = log_step * eps_band[level].powf(-eta);
            penalty_normalization_scales.push(c_l);
            raw_penalty_normalization_scales.push(c_l / scale_weight);
            candidates.push(PenaltyCandidate {
                matrix: s_norm,
                nullspace_dim_hint: 0,
                source: PenaltySource::Other(format!("measure_jet_scale_{level}")),
                normalization_scale: c_l,
                kronecker_factors: None,
                op: None,
            });
        }
    } else {
        let q_form = measure_jet_energy_form(
            centers.view(),
            masses.view(),
            &band,
            order_s,
            spec.alpha,
            spec.tau0,
        )?;
        let mut penalty = kz.t().dot(&q_form).dot(&kz);
        penalty = (&penalty + &penalty.t()) * 0.5;
        if spec.double_penalty {
            let ridge = affine_preserving_coefficient_ridge(&kz, centers.view(), masses.view())?;
            let primary_fro = trace_of_product(&penalty, &penalty).sqrt();
            let ridge_fro = trace_of_product(&ridge, &ridge).sqrt();
            // Fuse only when both norms are usable; otherwise keep the bare energy.
            if primary_fro.is_finite()
                && primary_fro > 0.0
                && ridge_fro.is_finite()
                && ridge_fro > 0.0
            {
                let w = MEASURE_JET_FUSED_RIDGE_FRACTION * primary_fro / ridge_fro;
                penalty = &penalty + &(&ridge * w);
            }
        }
        let (penalty_norm, c_primary) = normalize_penalty(&penalty);
        fused_penalty_normalization_scale = Some(c_primary);
        candidates.push(PenaltyCandidate {
            matrix: penalty_norm,
            nullspace_dim_hint: 0,
            source: PenaltySource::Primary,
            normalization_scale: c_primary,
            kronecker_factors: None,
            op: None,
        });
    }
    if spec.double_penalty && per_level {
        let ridge = affine_preserving_coefficient_ridge(&kz, centers.view(), masses.view())?;
        let (ridge_norm, c_ridge) = normalize_penalty(&ridge);
        candidates.push(PenaltyCandidate {
            matrix: ridge_norm,
            nullspace_dim_hint: 0,
            source: PenaltySource::DoublePenaltyNullspace,
            normalization_scale: c_ridge,
            kronecker_factors: None,
            op: None,
        });
    }
    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
        filter_active_penalty_candidates_with_ops(candidates)?;
    Ok(BasisBuildResult {
        design,
        penalties,
        nullspace_dims,
        penaltyinfo,
        metadata: BasisMetadata::MeasureJet {
            centers,
            input_scales: None,
            length_scale,
            eps_band,
            order_s: spec.order_s,
            alpha: spec.alpha,
            tau0: spec.tau0,
            masses,
            support_means,
            penalty_normalization_scales,
            raw_penalty_normalization_scales,
            fused_penalty_normalization_scale,
            constraint_transform: Some(z),
        },
        kronecker_factored: None,
        ops,
        null_eigenvectors,
        joint_null_rotation: None,
    })
}
pub fn build_measure_jet_basis_psi_derivatives(
data: ArrayView2<'_, f64>,
spec: &MeasureJetBasisSpec,
) -> Result<AnisoBasisPsiDerivatives, BasisError> {
if !(spec.tau0.is_finite() && spec.tau0 > 0.0) {
crate::bail_invalid_basis!(
"measure-jet ψ derivatives need tau0 > 0 because the retained τ coordinate is ln τ; got {}",
spec.tau0
);
}
let geom = realize_measure_jet_geometry(data, spec)?;
let band = MeasureJetBand {
eps: geom.eps_band.clone(),
log_step: geom.log_step,
};
let n = data.nrows();
let p = geom.kz.ncols(); let kz = &geom.kz;
let sandwich = |j: &Array2<f64>| {
let s = kz.t().dot(j).dot(kz);
(&s + &s.t()) * 0.5
};
let (n_coords, pairs, raw): (
usize,
Vec<(usize, usize)>,
Vec<(
Array2<f64>,
Vec<Array2<f64>>,
Vec<Array2<f64>>,
Vec<Array2<f64>>,
)>,
) = if geom.per_level {
let l_count = band.eps.len();
let forms = assemble_weighted_forms(
geom.centers.view(),
geom.masses.view(),
&band,
geom.order_s_eval,
spec.alpha,
spec.tau0,
6 * l_count,
3,
&|scale_idx, eps: f64, q: f64, base: f64, out: &mut [[f64; 3]]| {
for slot in out.iter_mut() {
*slot = [0.0, 0.0, 0.0];
}
let intrinsic_dim = geom.centers.ncols() as f64;
let ga = 2.0 * intrinsic_dim * eps.ln() - 2.0 * q.max(f64::MIN_POSITIVE).ln();
let k0 = 6 * scale_idx;
out[k0] = [base, 0.0, 0.0];
out[k0 + 1] = [ga * base, 0.0, 0.0];
out[k0 + 2] = [ga * ga * base, 0.0, 0.0];
out[k0 + 3] = [0.0, 0.0, 0.0];
out[k0 + 4] = [0.0, 0.0, 0.0];
out[k0 + 5] = [0.0, 0.0, 0.0];
},
)?;
let mut raw = Vec::with_capacity(l_count);
for level in 0..l_count {
let chunk = &forms[6 * level..6 * level + 6];
raw.push((
sandwich(&chunk[0]),
vec![sandwich(&chunk[1]), sandwich(&chunk[3])],
vec![sandwich(&chunk[2]), sandwich(&chunk[4])],
vec![sandwich(&chunk[5])],
));
}
(2usize, vec![(0usize, 1usize)], raw)
} else {
let q_value = sandwich(&measure_jet_energy_form(
geom.centers.view(),
geom.masses.view(),
&band,
geom.order_s_eval,
spec.alpha,
spec.tau0,
)?);
let raw = vec![(q_value, Vec::new(), Vec::new(), Vec::new())];
(0usize, Vec::new(), raw)
};
let length_scale_design = if spec.learn_length_scale {
let ell = geom.length_scale;
let k = measure_jet_design_matrix(data, geom.centers.view(), ell)?;
let r2 = pairwise_sq_dists(data, geom.centers.view());
let inv_l2 = 1.0 / (ell * ell);
let mut dk = k.clone();
let mut d2k = k.clone();
for ((dk_v, d2k_v), &r2_v) in dk.iter_mut().zip(d2k.iter_mut()).zip(r2.iter()) {
let a = r2_v * inv_l2;
let kij = *dk_v;
*dk_v = kij * a;
*d2k_v = kij * (a * a - 2.0 * a);
}
let dx_du = geom.coefficient_gauge.restrict_design(&dk);
let d2x_du2 = geom.coefficient_gauge.restrict_design(&d2k);
Some((dx_du, d2x_du2))
} else {
None
};
let n_active = raw.len();
let ridge = spec.double_penalty && geom.per_level;
let n_cands = n_active + usize::from(ridge);
let zero_p = || Array2::<f64>::zeros((p, p));
let mut penalties_first: Vec<Vec<Array2<f64>>> =
(0..n_coords).map(|_| Vec::with_capacity(n_cands)).collect();
let mut penalties_second_diag: Vec<Vec<Array2<f64>>> =
(0..n_coords).map(|_| Vec::with_capacity(n_cands)).collect();
let mut crosses: Vec<Vec<Array2<f64>>> = (0..pairs.len()).map(|_| Vec::new()).collect();
for (s_raw, firsts, seconds, cross_raw) in &raw {
let fro = trace_of_product(s_raw, s_raw).sqrt();
let c = if fro.is_finite() && fro > 1e-12 {
fro
} else {
1.0
};
for coord in 0..n_coords {
let (_, s_first, s_second, _) =
normalize_penaltywith_psi_derivatives(s_raw, &firsts[coord], &seconds[coord]);
penalties_first[coord].push(s_first);
penalties_second_diag[coord].push(s_second);
}
for (pair_idx, &(a, b)) in pairs.iter().enumerate() {
let cross_raw_mat = normalize_penalty_cross_psi_derivative(
s_raw,
&firsts[a],
&firsts[b],
&cross_raw[pair_idx],
c,
);
crosses[pair_idx].push(cross_raw_mat);
}
}
if ridge {
for coord in 0..n_coords {
penalties_first[coord].push(zero_p());
penalties_second_diag[coord].push(zero_p());
}
for pair_crosses in crosses.iter_mut() {
pair_crosses.push(zero_p());
}
}
let coord_offset = usize::from(length_scale_design.is_some());
if coord_offset == 1 {
penalties_first.insert(0, (0..n_cands).map(|_| zero_p()).collect());
penalties_second_diag.insert(0, (0..n_cands).map(|_| zero_p()).collect());
}
let n_coords_total = n_coords + coord_offset;
let mut all_pairs: Vec<(usize, usize)> = pairs
.iter()
.map(|&(a, b)| (a + coord_offset, b + coord_offset))
.collect();
let mut all_crosses: Vec<Vec<Array2<f64>>> = crosses;
if coord_offset == 1 {
for c in 1..n_coords_total {
all_pairs.push((0, c));
all_crosses.push((0..n_cands).map(|_| zero_p()).collect());
}
}
let pair_index: Vec<((usize, usize), Vec<Array2<f64>>)> = all_pairs
.iter()
.copied()
.zip(all_crosses.into_iter())
.collect();
let shifted_pairs = all_pairs;
let provider = AnisoPenaltyCrossProvider::new(move |a, b| {
pair_index
.iter()
.find(|((pa, pb), _)| (*pa, *pb) == (a, b) || (*pa, *pb) == (b, a))
.map(|(_, mats)| mats.clone())
.ok_or_else(|| {
BasisError::InvalidInput(format!(
"measure-jet ψ cross derivative requested for unknown pair ({a}, {b})"
))
})
});
let mut design_first: Vec<Array2<f64>> = (0..n_coords_total)
.map(|_| Array2::<f64>::zeros((n, p)))
.collect();
let mut design_second_diag: Vec<Array2<f64>> = (0..n_coords_total)
.map(|_| Array2::<f64>::zeros((n, p)))
.collect();
if let Some((dx_du, d2x_du2)) = length_scale_design {
design_first[0] = dx_du;
design_second_diag[0] = d2x_du2;
}
Ok(AnisoBasisPsiDerivatives {
design_first,
design_second_diag,
design_second_cross: Vec::new(),
design_second_cross_pairs: Vec::new(),
penalties_first,
penalties_second_diag,
penalties_cross_pairs: shifted_pairs,
penalties_cross_provider: Some(provider),
implicit_operator: None,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Largest absolute entry of a finite-difference reference, floored at
    /// `1e-30` so relative tolerances built from it never collapse to zero.
    fn fd_scale(fd: &Array2<f64>) -> f64 {
        fd.iter().fold(1e-30_f64, |acc, v| acc.max(v.abs()))
    }

    /// `m` equally spaced 2-D centers on the segment from (0, 0) to (4, 0).
    fn straight_line_centers(m: usize) -> Array2<f64> {
        Array2::<f64>::from_shape_fn((m, 2), |(i, k)| {
            let t = i as f64 / (m as f64 - 1.0);
            if k == 0 { t * 4.0 } else { 0.0 }
        })
    }

    /// `n` samples of the curve t ↦ (3t, 0.4·sin 3t) with t uniformly spaced in [0, 1].
    fn sine_curve_data(n: usize) -> Array2<f64> {
        Array2::<f64>::from_shape_fn((n, 2), |(i, k)| {
            let t = i as f64 / (n as f64 - 1.0);
            if k == 0 {
                t * 3.0
            } else {
                0.4 * (t * 3.0).sin()
            }
        })
    }

    /// `n` samples of the curve t ↦ (3t, 0.5·cos 3t) with t uniformly spaced in [0, 1].
    /// Every ninth sample (`i % 9 == 0`) has its second coordinate lifted by 0.8.
    fn spiked_curve_data(n: usize) -> Array2<f64> {
        Array2::<f64>::from_shape_fn((n, 2), |(i, k)| {
            let t = i as f64 / (n as f64 - 1.0);
            if k == 0 {
                t * 3.0
            } else {
                0.5 * (t * 3.0).cos() + if i % 9 == 0 { 0.8 } else { 0.0 }
            }
        })
    }

    /// Twelve 2-D centers in two clusters of six, separated by a gap much
    /// wider than the spacing inside each cluster, with uniform masses `1/12`.
    pub(crate) fn two_cluster_centers() -> (Array2<f64>, Array1<f64>) {
        let centers = array![
            [0.00, 0.00],
            [0.31, 0.05],
            [0.58, -0.07],
            [0.93, 0.11],
            [1.22, 0.02],
            [1.49, -0.04],
            [3.10, 2.00],
            [3.42, 2.13],
            [3.71, 1.91],
            [4.05, 2.07],
            [4.33, 1.96],
            [4.61, 2.12],
        ];
        let m = centers.nrows();
        let masses = Array1::<f64>::from_elem(m, 1.0 / m as f64);
        (centers, masses)
    }

    /// Scale band for `centers`, built by `measure_jet_band` with a requested
    /// scale count of `0`. That value presumably means "choose automatically";
    /// confirm against `measure_jet_band`.
    pub(crate) fn band_for(centers: &Array2<f64>) -> MeasureJetBand {
        measure_jet_band(centers.view(), 0).expect("band")
    }

    /// The jet energy form must annihilate constants: every row of `Q·1` and
    /// the quadratic form `1ᵀQ1` vanish relative to the largest entry of `Q`.
    #[test]
    pub(crate) fn energy_form_annihilates_constants_exactly() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
            .expect("energy form");
        let ones = Array1::<f64>::ones(q.nrows());
        let qv = q.dot(&ones);
        let scale = q.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        assert!(scale > 0.0, "energy form is identically zero");
        for (i, v) in qv.iter().enumerate() {
            assert!(
                v.abs() <= 1e-12 * scale,
                "Q·1 leak at row {i}: {v:.3e} vs scale {scale:.3e}"
            );
        }
        let vqv = ones.dot(&qv);
        assert!(
            vqv.abs() <= 1e-12 * scale,
            "constant carries energy: 1ᵀQ1 = {vqv:.3e}"
        );
    }

    /// At τ = 1e-3 an affine function of the center coordinates must carry
    /// negligible energy compared with a sign-alternating vector.
    #[test]
    pub(crate) fn energy_form_annihilates_affine_at_default_tau() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let m = centers.nrows();
        let affine = Array1::<f64>::from_shape_fn(m, |i| {
            0.7 + 1.3 * centers[(i, 0)] - 0.4 * centers[(i, 1)]
        });
        // Sign alternates with the center index.
        let rough = Array1::<f64>::from_shape_fn(m, |i| if i % 2 == 0 { 1.0 } else { -1.0 });
        let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
            .expect("energy form");
        let e_affine = affine.dot(&q.dot(&affine));
        let e_rough = rough.dot(&q.dot(&rough));
        assert!(e_rough > 0.0, "rough vector must pay energy");
        assert!(
            e_affine.abs() <= 1e-12 * e_rough,
            "default affine energy {e_affine:.3e} vs rough {e_rough:.3e}"
        );
    }

    /// `vᵀQv` is non-negative (up to round-off) on a handful of deterministic
    /// pseudo-random vectors.
    #[test]
    pub(crate) fn energy_form_is_psd() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
            .expect("energy form");
        let m = q.nrows();
        for trial in 0..5usize {
            let v = Array1::<f64>::from_shape_fn(m, |i| {
                ((i * 7 + trial * 13) % 11) as f64 / 11.0 - 0.5
            });
            let e = v.dot(&q.dot(&v));
            assert!(e >= -1e-10, "vᵀQv = {e:.3e} < 0 on trial {trial}");
        }
    }

    /// On centers along a gentle sine curve, an alternating vector must pay
    /// more than ten times the energy of a slow quadratic trend.
    #[test]
    pub(crate) fn rough_vector_pays_more_than_smooth() {
        let m = 24usize;
        let centers = Array2::<f64>::from_shape_fn((m, 2), |(i, k)| {
            let t = i as f64 / (m as f64 - 1.0);
            if k == 0 {
                t * 4.0
            } else {
                0.3 * (t * 4.0).sin()
            }
        });
        let masses = Array1::<f64>::from_elem(m, 1.0 / m as f64);
        let band = band_for(&centers);
        let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
            .expect("energy form");
        let slow = Array1::<f64>::from_shape_fn(m, |i| (i as f64 / (m as f64 - 1.0)).powi(2));
        let fast = Array1::<f64>::from_shape_fn(m, |i| if i % 2 == 0 { 0.5 } else { -0.5 });
        let e_slow = slow.dot(&q.dot(&slow));
        let e_fast = fast.dot(&q.dot(&fast));
        assert!(
            e_fast > 10.0 * e_slow,
            "alternating values must pay >> a slow trend: fast {e_fast:.3e} vs slow {e_slow:.3e}"
        );
    }

    /// Every analytic jet returned by `measure_jet_energy_form_with_jets`
    /// (first and second derivatives in s, α and ln τ, plus mixed terms)
    /// matches central finite differences of `measure_jet_energy_form`.
    #[test]
    pub(crate) fn energy_jets_match_finite_differences() {
        let (centers, masses) = two_cluster_centers();
        let band = band_for(&centers);
        let (s0, a0, tau): (f64, f64, f64) = (1.3, 0.8, 1e-3);
        let jets =
            measure_jet_energy_form_with_jets(centers.view(), masses.view(), &band, s0, a0, tau)
                .expect("jets");
        // Q as a function of (s, α) at fixed τ.
        let q_at = |s: f64, a: f64| {
            measure_jet_energy_form(centers.view(), masses.view(), &band, s, a, tau)
                .expect("energy form")
        };
        // The jet evaluator's value must agree with the plain energy form.
        let q_plain = q_at(s0, a0);
        for (a, b) in jets.q.iter().zip(q_plain.iter()) {
            assert!(
                (a - b).abs() <= 1e-14 * (1.0 + b.abs()),
                "Q drift {a} vs {b}"
            );
        }
        // τ derivatives are checked with respect to ln τ.
        let lt0 = tau.ln();
        let q_at_lt = |lt: f64| {
            measure_jet_energy_form(centers.view(), masses.view(), &band, s0, a0, lt.exp())
                .expect("energy form")
        };
        let h: f64 = 1e-4;
        let checks: [(&str, &Array2<f64>, Array2<f64>); 9] = [
            ("dq_ds", &jets.dq_ds, {
                let (p, m_) = (q_at(s0 + h, a0), q_at(s0 - h, a0));
                (&p - &m_) / (2.0 * h)
            }),
            ("d2q_ds2", &jets.d2q_ds2, {
                let (p, c, m_) = (q_at(s0 + h, a0), q_at(s0, a0), q_at(s0 - h, a0));
                (&(&p + &m_) - &(&c * 2.0)) / (h * h)
            }),
            ("dq_dalpha", &jets.dq_dalpha, {
                let (p, m_) = (q_at(s0, a0 + h), q_at(s0, a0 - h));
                (&p - &m_) / (2.0 * h)
            }),
            ("d2q_dalpha2", &jets.d2q_dalpha2, {
                let (p, c, m_) = (q_at(s0, a0 + h), q_at(s0, a0), q_at(s0, a0 - h));
                (&(&p + &m_) - &(&c * 2.0)) / (h * h)
            }),
            ("d2q_ds_dalpha", &jets.d2q_ds_dalpha, {
                let pp = q_at(s0 + h, a0 + h);
                let pm = q_at(s0 + h, a0 - h);
                let mp = q_at(s0 - h, a0 + h);
                let mm = q_at(s0 - h, a0 - h);
                (&(&pp - &pm) - &(&mp - &mm)) / (4.0 * h * h)
            }),
            ("dq_dlogtau", &jets.dq_dlogtau, {
                let (p, m_) = (q_at_lt(lt0 + h), q_at_lt(lt0 - h));
                (&p - &m_) / (2.0 * h)
            }),
            ("d2q_dlogtau2", &jets.d2q_dlogtau2, {
                let (p, c, m_) = (q_at_lt(lt0 + h), q_at_lt(lt0), q_at_lt(lt0 - h));
                (&(&p + &m_) - &(&c * 2.0)) / (h * h)
            }),
            ("d2q_ds_dlogtau", &jets.d2q_ds_dlogtau, {
                let f = |s: f64, lt: f64| {
                    measure_jet_energy_form(centers.view(), masses.view(), &band, s, a0, lt.exp())
                        .expect("energy form")
                };
                let pp = f(s0 + h, lt0 + h);
                let pm = f(s0 + h, lt0 - h);
                let mp = f(s0 - h, lt0 + h);
                let mm = f(s0 - h, lt0 - h);
                (&(&pp - &pm) - &(&mp - &mm)) / (4.0 * h * h)
            }),
            ("d2q_dalpha_dlogtau", &jets.d2q_dalpha_dlogtau, {
                let f = |a: f64, lt: f64| {
                    measure_jet_energy_form(centers.view(), masses.view(), &band, s0, a, lt.exp())
                        .expect("energy form")
                };
                let pp = f(a0 + h, lt0 + h);
                let pm = f(a0 + h, lt0 - h);
                let mp = f(a0 - h, lt0 + h);
                let mm = f(a0 - h, lt0 - h);
                (&(&pp - &pm) - &(&mp - &mm)) / (4.0 * h * h)
            }),
        ];
        for (name, analytic, fd) in checks.iter() {
            let scale = fd_scale(fd);
            for (a, b) in analytic.iter().zip(fd.iter()) {
                assert!(
                    (a - b).abs() <= 5e-5 * scale,
                    "{name} jet mismatch: analytic {a:.6e} vs FD {b:.6e} (scale {scale:.3e})"
                );
            }
        }
    }

    /// The per-scale spectrum of an alternating vector has one entry per band
    /// scale and sums to `vᵀQv`. Its finest-scale entry exceeds its coarsest.
    #[test]
    pub(crate) fn scale_spectrum_sums_to_total_and_localizes_roughness() {
        let m = 24usize;
        let centers = straight_line_centers(m);
        let masses = Array1::<f64>::from_elem(m, 1.0 / m as f64);
        let band = band_for(&centers);
        let q = measure_jet_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
            .expect("energy form");
        let fast = Array1::<f64>::from_shape_fn(m, |i| if i % 2 == 0 { 0.5 } else { -0.5 });
        let spec = measure_jet_scale_spectrum(
            centers.view(),
            masses.view(),
            &band,
            1.5,
            1.0,
            1e-3,
            fast.view(),
        )
        .expect("spectrum");
        assert_eq!(spec.len(), band.eps.len());
        let total = fast.dot(&q.dot(&fast));
        let sum: f64 = spec.iter().sum();
        assert!(
            (sum - total).abs() <= 1e-10 * total.abs().max(1e-30),
            "spectrum must sum to vᵀQv: {sum:.6e} vs {total:.6e}"
        );
        let finest = spec[0];
        let coarsest = *spec.last().expect("nonempty spectrum");
        assert!(
            finest > coarsest,
            "alternating values must charge fine scales hardest: fine {finest:.3e} vs coarse {coarsest:.3e}"
        );
    }

    /// For centers on a line segment, the finest-scale support at a query on
    /// the segment dominates the support at a query 1.5 units off it.
    /// Each query's support curve is non-decreasing across band levels.
    #[test]
    pub(crate) fn support_curve_separates_on_web_from_off_web() {
        let m = 24usize;
        let centers = straight_line_centers(m);
        let masses = Array1::<f64>::from_elem(m, 1.0 / m as f64);
        let band = band_for(&centers);
        let queries = array![[2.0, 0.0], [2.0, 1.5]];
        let curves =
            measure_jet_support_curve(queries.view(), centers.view(), masses.view(), &band.eps)
                .expect("support curve");
        assert!(
            curves[(0, 0)] > 10.0 * curves[(1, 0)],
            "fine-scale support must separate web from void: on {:.3e} vs off {:.3e}",
            curves[(0, 0)],
            curves[(1, 0)]
        );
        for qi in 0..2 {
            for li in 1..band.eps.len() {
                assert!(
                    curves[(qi, li)] >= curves[(qi, li - 1)] - 1e-15,
                    "support curve must be monotone in scale (query {qi}, level {li})"
                );
            }
        }
    }

    /// The default spec resolves to single-scale mode and emits exactly one
    /// penalty. `multiscale: true` switches modes and emits more penalties.
    #[test]
    pub(crate) fn default_stays_single_scale_until_multiscale_opt_in() {
        let data = sine_curve_data(200);
        let single = MeasureJetBasisSpec {
            center_strategy: CenterStrategy::FarthestPoint { num_centers: 80 },
            ..MeasureJetBasisSpec::default()
        };
        assert!(
            !measure_jet_multiscale_mode(&single),
            "default must resolve to single-scale at any center count"
        );
        let built_single =
            build_measure_jet_basis(data.view(), &single).expect("single-scale build");
        assert_eq!(
            built_single.penalties.len(),
            1,
            "single-scale mode emits one fused penalty (ridge folded in, not a 2nd λ)"
        );
        let multi = MeasureJetBasisSpec {
            center_strategy: CenterStrategy::FarthestPoint { num_centers: 80 },
            multiscale: true,
            ..MeasureJetBasisSpec::default()
        };
        assert!(
            measure_jet_multiscale_mode(&multi),
            "multiscale=true must resolve to multiscale mode"
        );
        let built_multi = build_measure_jet_basis(data.view(), &multi).expect("multiscale build");
        assert!(
            built_multi.penalties.len() > built_single.penalties.len(),
            "multiscale mode emits the per-scale spectral split plus the ridge, got {} (vs single-scale {})",
            built_multi.penalties.len(),
            built_single.penalties.len()
        );
    }

    /// Fused single-scale mode emits one penalty. An explicitly supplied
    /// `order_s` is stored unchanged in the returned metadata.
    #[test]
    pub(crate) fn fused_mode_emits_single_primary_candidate() {
        let data = sine_curve_data(40);
        let spec = MeasureJetBasisSpec {
            center_strategy: CenterStrategy::FarthestPoint { num_centers: 14 },
            order_s: 1.3,
            ..MeasureJetBasisSpec::default()
        };
        let built = build_measure_jet_basis(data.view(), &spec).expect("fused build");
        assert_eq!(
            built.penalties.len(),
            1,
            "fused single-scale mode emits exactly one Primary candidate (ridge folded in)"
        );
        let BasisMetadata::MeasureJet { order_s, .. } = &built.metadata else {
            panic!("measure-jet build must return MeasureJet metadata");
        };
        assert_eq!(*order_s, 1.3, "explicit order must persist verbatim");
    }

    /// The first `m - 1` columns of the Householder sum-to-zero basis each sum
    /// to zero and are mutually orthonormal.
    #[test]
    pub(crate) fn householder_sum_to_zero_basis_is_orthonormal() {
        let m = 9usize;
        let u = householder_sum_to_zero_u(m);
        let z = householder_sum_to_zero_z(&u);
        for j in 0..(m - 1) {
            let col_j = z.column(j);
            assert!(col_j.sum().abs() <= 1e-12, "column {j} must sum to zero");
            for j2 in j..(m - 1) {
                let dot = col_j.dot(&z.column(j2));
                let want = if j == j2 { 1.0 } else { 0.0 };
                assert!(
                    (dot - want).abs() <= 1e-12,
                    "orthonormality failure at ({j}, {j2}): {dot}"
                );
            }
        }
    }

    /// Builds a basis on `spiked_curve_data(140)` with 70 farthest-point
    /// centers and `learn_length_scale: false`. Returns the data together
    /// with a spec that replays that fit.
    ///
    /// The replay spec uses the fitted centers as user-provided centers. It
    /// takes the fitted length scale, scale band, constraint transform and
    /// frozen quadrature (masses, support means, normalization scales) from
    /// the build metadata. `alpha`, `tau0` and `double_penalty` are copied
    /// from the input spec. `order_s` is the argument exactly as passed, not
    /// the value recorded in the metadata.
    pub(crate) fn frozen_spec_fixture(
        order_s: f64,
        multiscale: bool,
    ) -> (Array2<f64>, MeasureJetBasisSpec) {
        let data = spiked_curve_data(140);
        let spec = MeasureJetBasisSpec {
            center_strategy: CenterStrategy::FarthestPoint { num_centers: 70 },
            order_s,
            multiscale,
            learn_length_scale: false,
            ..MeasureJetBasisSpec::default()
        };
        let first = build_measure_jet_basis(data.view(), &spec).expect("fixture build");
        let BasisMetadata::MeasureJet {
            centers,
            length_scale,
            eps_band,
            masses,
            support_means,
            penalty_normalization_scales,
            raw_penalty_normalization_scales,
            fused_penalty_normalization_scale,
            constraint_transform,
            ..
        } = &first.metadata
        else {
            panic!("measure-jet build must return MeasureJet metadata");
        };
        let frozen = MeasureJetBasisSpec {
            center_strategy: CenterStrategy::UserProvided(centers.clone()),
            order_s,
            alpha: spec.alpha,
            tau0: spec.tau0,
            num_scales: eps_band.len(),
            length_scale: *length_scale,
            double_penalty: spec.double_penalty,
            learn_length_scale: false,
            multiscale,
            identifiability: MeasureJetIdentifiability::FrozenTransform {
                transform: constraint_transform.clone().expect("fit-time z"),
            },
            frozen_quadrature: Some(MeasureJetFrozenQuadrature {
                masses: masses.clone(),
                eps_band: eps_band.clone(),
                support_means: support_means.clone(),
                penalty_normalization_scales: penalty_normalization_scales.clone(),
                raw_penalty_normalization_scales: raw_penalty_normalization_scales.clone(),
                fused_penalty_normalization_scale: *fused_penalty_normalization_scale,
            }),
        };
        (data, frozen)
    }

    /// In per-level (multiscale) mode the ψ producer exposes coordinates
    /// (α, ln τ). It emits one candidate per scale plus a ridge with zero ψ
    /// drift and a single cross pair (0, 1). Its first derivatives and cross
    /// derivative match central finite differences of rebuilt penalties.
    #[test]
    pub(crate) fn psi_producer_matches_fd_per_level_mode() {
        let (data, frozen) = frozen_spec_fixture(0.0, true);
        let derivs =
            build_measure_jet_basis_psi_derivatives(data.view(), &frozen).expect("psi derivatives");
        let l_count = frozen
            .frozen_quadrature
            .as_ref()
            .expect("frozen quadrature")
            .eps_band
            .len();
        assert_eq!(
            derivs.penalties_first.len(),
            2,
            "per-level coords are (α, lnτ)"
        );
        assert_eq!(derivs.penalties_first[0].len(), l_count + 1);
        assert_eq!(derivs.penalties_cross_pairs, vec![(0, 1)]);
        // Penalties rebuilt from the frozen spec with α and τ0 replaced.
        let pen_at = |alpha: f64, tau0: f64| {
            let trial = MeasureJetBasisSpec {
                alpha,
                tau0,
                ..frozen.clone()
            };
            build_measure_jet_basis(data.view(), &trial)
                .expect("trial build")
                .penalties
        };
        // Explicit type: `h.exp()` below needs a concrete float type.
        let h: f64 = 1e-4;
        let (a0, t0) = (frozen.alpha, frozen.tau0);
        // τ is stepped multiplicatively, so the central difference is in ln τ.
        let ap = pen_at(a0 + h, t0);
        let am = pen_at(a0 - h, t0);
        let tp = pen_at(a0, t0 * h.exp());
        let tm = pen_at(a0, t0 * (-h).exp());
        assert_eq!(
            ap.len(),
            l_count + 1,
            "fixture must keep every scale active"
        );
        for level in 0..l_count {
            let fd_alpha = (&ap[level] - &am[level]) / (2.0 * h);
            let fd_tau = (&tp[level] - &tm[level]) / (2.0 * h);
            for (name, analytic, fd) in [
                ("alpha", &derivs.penalties_first[0][level], fd_alpha),
                ("ln_tau", &derivs.penalties_first[1][level], fd_tau),
            ] {
                let scale = fd_scale(&fd);
                for (x, y) in analytic.iter().zip(fd.iter()) {
                    assert!(
                        (x - y).abs() <= 5e-5 * scale,
                        "{name} jet of scale-candidate {level}: analytic {x:.6e} vs FD {y:.6e}"
                    );
                }
            }
        }
        for coord in 0..2 {
            assert!(
                derivs.penalties_first[coord][l_count]
                    .iter()
                    .all(|v| *v == 0.0),
                "ridge candidate must have zero ψ drift"
            );
        }
        let provider = derivs
            .penalties_cross_provider
            .as_ref()
            .expect("cross provider");
        let cross = provider.evaluate(0, 1).expect("cross pair (α, lnτ)");
        let pp = pen_at(a0 + h, t0 * h.exp());
        let pm = pen_at(a0 + h, t0 * (-h).exp());
        let mp = pen_at(a0 - h, t0 * h.exp());
        let mm = pen_at(a0 - h, t0 * (-h).exp());
        for level in 0..l_count {
            let fd = (&(&pp[level] - &pm[level]) - &(&mp[level] - &mm[level])) / (4.0 * h * h);
            let scale = fd_scale(&fd);
            for (x, y) in cross[level].iter().zip(fd.iter()) {
                assert!(
                    (x - y).abs() <= 5e-4 * scale,
                    "cross (α, lnτ) jet of scale-candidate {level}: analytic {x:.6e} vs FD {y:.6e}"
                );
            }
        }
    }

    /// In single-scale mode with `learn_length_scale`, the ψ producer enrolls
    /// only the ln ℓ coordinate. The penalty does not move with ℓ. The design
    /// derivatives ∂X/∂lnℓ and ∂²X/∂lnℓ² match finite differences of the
    /// rebuilt design.
    #[test]
    pub(crate) fn psi_producer_matches_fd_length_scale() {
        let (data, mut frozen) = frozen_spec_fixture(0.0, false);
        frozen.learn_length_scale = true;
        let derivs =
            build_measure_jet_basis_psi_derivatives(data.view(), &frozen).expect("psi derivatives");
        assert_eq!(
            derivs.design_first.len(),
            1,
            "single-scale + learn_length_scale enrolls exactly the ℓ coordinate"
        );
        assert_eq!(
            derivs.penalties_first[0].len(),
            1,
            "one fitted penalty candidate"
        );
        assert!(
            derivs.penalties_first[0][0].iter().all(|v| *v == 0.0)
                && derivs.penalties_second_diag[0][0].iter().all(|v| *v == 0.0),
            "the jet-energy penalty must not move with ℓ"
        );
        let ell0 = frozen.length_scale;
        let design_at = |ell: f64| {
            let trial = MeasureJetBasisSpec {
                length_scale: ell,
                ..frozen.clone()
            };
            build_measure_jet_basis(data.view(), &trial)
                .expect("trial build")
                .design
                .to_dense()
        };
        let h: f64 = 1e-4;
        // ℓ is stepped multiplicatively, so the differences are in ln ℓ.
        let x_plus = design_at(ell0 * h.exp());
        let x_minus = design_at(ell0 * (-h).exp());
        let x_0 = design_at(ell0);
        let fd_first = (&x_plus - &x_minus) / (2.0 * h);
        let fd_second = (&x_plus - &(&x_0 * 2.0) + &x_minus) / (h * h);
        let scale1 = fd_scale(&fd_first);
        for (x, y) in derivs.design_first[0].iter().zip(fd_first.iter()) {
            assert!(
                (x - y).abs() <= 5e-5 * scale1,
                "∂X/∂lnℓ: analytic {x:.6e} vs FD {y:.6e}"
            );
        }
        let scale2 = fd_scale(&fd_second);
        for (x, y) in derivs.design_second_diag[0].iter().zip(fd_second.iter()) {
            assert!(
                (x - y).abs() <= 1e-3 * scale2,
                "∂²X/∂lnℓ²: analytic {x:.6e} vs FD {y:.6e}"
            );
        }
    }

    /// Quadrature nodes are the barycenters of the data assigned to each seed.
    /// Masses are the assigned fractions of the data. An empty cell gets zero
    /// mass and its node stays at the seed.
    #[test]
    pub(crate) fn quadrature_nodes_are_cell_barycenters() {
        let data = array![
            [0.0, 0.2],
            [0.4, -0.2],
            [0.2, 0.0],
            [9.8, 10.1],
            [10.2, 9.9],
        ];
        // The third seed lies far from every data point, so its cell is empty.
        let seeds = array![[0.1, 0.1], [10.0, 10.0], [-50.0, -50.0]];
        let (nodes, masses) =
            measure_jet_quadrature_nodes(data.view(), seeds.view()).expect("quadrature nodes");
        assert!((masses.sum() - 1.0).abs() <= 1e-15, "masses must sum to 1");
        assert!((masses[0] - 0.6).abs() <= 1e-15);
        assert!((masses[1] - 0.4).abs() <= 1e-15);
        assert_eq!(masses[2], 0.0);
        // Barycenters are floating-point means, so compare them with a tight
        // tolerance rather than bit equality. In f64, 0.4 + 0.2 rounds to
        // 0.6000000000000001, so (0.0 + 0.4 + 0.2) / 3 is not exactly 0.2.
        let close = |got: f64, want: f64| (got - want).abs() <= 1e-15;
        assert!(close(nodes[(0, 0)], 0.2), "node (0, 0) = {}", nodes[(0, 0)]);
        assert!(close(nodes[(0, 1)], 0.0), "node (0, 1) = {}", nodes[(0, 1)]);
        assert!(close(nodes[(1, 0)], 10.0), "node (1, 0) = {}", nodes[(1, 0)]);
        assert!(close(nodes[(1, 1)], 10.0), "node (1, 1) = {}", nodes[(1, 1)]);
        // The empty cell's node is the seed itself, so exact equality applies.
        assert_eq!(nodes[(2, 0)], -50.0);
        assert_eq!(nodes[(2, 1)], -50.0);
    }

    /// A multiscale build, replayed from its own metadata (user-provided
    /// centers, frozen transform and quadrature), reproduces the design and
    /// every penalty to 1e-12. Per-level mode emits one penalty per scale
    /// plus the ridge.
    #[test]
    pub(crate) fn build_replay_roundtrip_reproduces_design_and_penalty() {
        let data = spiked_curve_data(140);
        let spec = MeasureJetBasisSpec {
            center_strategy: CenterStrategy::FarthestPoint { num_centers: 70 },
            multiscale: true,
            ..MeasureJetBasisSpec::default()
        };
        let first = build_measure_jet_basis(data.view(), &spec).expect("first build");
        let BasisMetadata::MeasureJet {
            centers,
            length_scale,
            eps_band,
            order_s,
            alpha,
            tau0,
            masses,
            support_means,
            penalty_normalization_scales,
            raw_penalty_normalization_scales,
            fused_penalty_normalization_scale,
            constraint_transform,
            ..
        } = &first.metadata
        else {
            panic!("measure-jet build must return MeasureJet metadata");
        };
        let replay_spec = MeasureJetBasisSpec {
            center_strategy: CenterStrategy::UserProvided(centers.clone()),
            order_s: *order_s,
            alpha: *alpha,
            tau0: *tau0,
            num_scales: eps_band.len(),
            length_scale: *length_scale,
            double_penalty: spec.double_penalty,
            learn_length_scale: spec.learn_length_scale,
            multiscale: spec.multiscale,
            identifiability: MeasureJetIdentifiability::FrozenTransform {
                transform: constraint_transform.clone().expect("fit-time z"),
            },
            frozen_quadrature: Some(MeasureJetFrozenQuadrature {
                masses: masses.clone(),
                eps_band: eps_band.clone(),
                support_means: support_means.clone(),
                penalty_normalization_scales: penalty_normalization_scales.clone(),
                raw_penalty_normalization_scales: raw_penalty_normalization_scales.clone(),
                fused_penalty_normalization_scale: *fused_penalty_normalization_scale,
            }),
        };
        assert_eq!(
            first.penalties.len(),
            eps_band.len() + 1,
            "per-level mode must emit one candidate per scale + ridge"
        );
        let second = build_measure_jet_basis(data.view(), &replay_spec).expect("replay build");
        let x1 = first.design.to_dense();
        let x2 = second.design.to_dense();
        assert_eq!(x1.shape(), x2.shape());
        for (a, b) in x1.iter().zip(x2.iter()) {
            assert!((a - b).abs() <= 1e-12, "design replay drift: {a} vs {b}");
        }
        assert_eq!(first.penalties.len(), second.penalties.len());
        for (p1, p2) in first.penalties.iter().zip(second.penalties.iter()) {
            for (a, b) in p1.iter().zip(p2.iter()) {
                assert!((a - b).abs() <= 1e-12, "penalty replay drift: {a} vs {b}");
            }
        }
    }
}