use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use rayon::prelude::*;
use faer::Side;
use crate::linalg::faer_ndarray::FaerEigh;
use super::{BasisError, MeasureJetBand};
/// Neighbourhood radius in units of the scale `eps`: only centers within
/// `MEASURE_JET3_PROFILE_CUTOFF * eps` of a net point enter its local fit
/// (the Gaussian profile `exp(-r^2 / (2 eps^2))` is already `exp(-4.5)` at that edge).
pub(crate) const MEASURE_JET3_PROFILE_CUTOFF: f64 = 3.0_f64;
/// Relative eigenvalue cutoff for the symmetric pseudo-inverse and for selecting local frame
/// directions; it is multiplied by a dimension factor and by the largest non-negative
/// eigenvalue before being compared against each eigenvalue.
pub(crate) const MEASURE_JET3_PSEUDOINVERSE_RTOL: f64 = 64.0_f64 * f64::EPSILON;
/// Maximum number of local frame directions retained per neighbourhood.
pub(crate) const MEASURE_JET3_FRAME_MAX: usize = 8_usize;
/// Budget, in `f64` entries (2^26, i.e. 512 MiB), for `m * m * n_scales`: the per-scale
/// `m x m` matrices are only assembled in parallel when their combined size fits.
pub(crate) const MEASURE_JET3_PARALLEL_BUDGET_DOUBLES: usize = 1_usize << 26;
pub(crate) fn pairwise_sq_dists(a: ArrayView2<'_, f64>) -> Array2<f64> {
let norms: Vec<f64> = a.outer_iter().map(|r| r.dot(&r)).collect();
let mut g = a.dot(&a.t());
g.axis_iter_mut(Axis(0))
.into_par_iter()
.enumerate()
.for_each(|(i, mut row)| {
for (j, v) in row.iter_mut().enumerate() {
*v = (norms[i] + norms[j] - 2.0 * *v).max(0.0);
}
});
g
}
/// Moore–Penrose pseudo-inverse of a symmetric matrix through its eigendecomposition.
///
/// `a` is decomposed with `eigh(Side::Lower)`. Each eigenvalue is clamped at zero. It is then
/// inverted only if it exceeds `MEASURE_JET3_PSEUDOINVERSE_RTOL * n * λ_max`, where `λ_max` is
/// the largest clamped eigenvalue; otherwise it is treated as a null direction. The result is
/// `V diag(1/λ) Vᵀ`. An empty matrix yields an empty result. `label` only appears in error
/// messages.
///
/// # Errors
///
/// Returns a `BasisError` when `a` is not square or when the eigendecomposition fails.
pub(crate) fn symmetric_pseudoinverse(
    a: &Array2<f64>,
    label: &str,
) -> Result<Array2<f64>, BasisError> {
    let dim = a.nrows();
    if a.ncols() != dim {
        crate::bail_dim_basis!(
            "measure-jet r=3 pseudo-inverse `{label}` needs a square matrix, got {:?}",
            a.dim()
        );
    }
    if dim == 0 {
        return Ok(Array2::<f64>::zeros((0, 0)));
    }
    let (evals, evecs) = a.eigh(Side::Lower).map_err(|e| {
        BasisError::InvalidInput(format!(
            "measure-jet r=3 pseudo-inverse `{label}` eigendecomposition failed: {e}"
        ))
    })?;
    // Largest eigenvalue after clamping negatives (round-off on a PSD input) to zero.
    let spectral_radius = evals
        .iter()
        .map(|&lam| lam.max(0.0))
        .fold(0.0_f64, f64::max);
    let cutoff = MEASURE_JET3_PSEUDOINVERSE_RTOL * (dim as f64) * spectral_radius;
    // Scale eigenvector column k by 1/λ_k (or zero it for a numerically null direction).
    let mut scaled_vecs = evecs.clone();
    for (mut column, &lam) in scaled_vecs.axis_iter_mut(Axis(1)).zip(evals.iter()) {
        let clamped = lam.max(0.0);
        let factor = if clamped > cutoff { clamped.recip() } else { 0.0 };
        column.mapv_inplace(|entry| entry * factor);
    }
    Ok(scaled_vecs.dot(&evecs.t()))
}
pub fn measure_jet_order3_energy_form(
centers: ArrayView2<'_, f64>,
masses: ArrayView1<'_, f64>,
band: &MeasureJetBand,
order_s: f64,
alpha: f64,
tau0: f64,
) -> Result<Array2<f64>, BasisError> {
let m = centers.nrows();
let d = centers.ncols();
if masses.len() != m {
crate::bail_dim_basis!(
"measure-jet r=3 mass/center mismatch: {} masses for {} centers",
masses.len(),
m
);
}
if band.eps.is_empty() || band.eps.iter().any(|e| !(e.is_finite() && *e > 0.0)) {
crate::bail_invalid_basis!("measure-jet r=3 energy needs a nonempty positive scale band");
}
if !(order_s.is_finite() && order_s > 0.0 && order_s < 3.0) {
crate::bail_invalid_basis!(
"measure-jet r=3 order s must lie in (0, 3) for the cubic-jet energy; got {order_s}"
);
}
if !(alpha.is_finite() && tau0.is_finite() && tau0 >= 0.0) {
crate::bail_invalid_basis!(
"measure-jet r=3 energy needs finite alpha and finite tau0 >= 0; got alpha={alpha}, tau0={tau0}"
);
}
if masses.iter().any(|v| !(v.is_finite() && *v >= 0.0)) {
crate::bail_invalid_basis!("measure-jet r=3 energy needs finite nonnegative center masses");
}
if m == 0 {
return Ok(Array2::<f64>::zeros((0, 0)));
}
let dist2 = pairwise_sq_dists(centers);
let assemble_scale = |eps: f64| -> Result<Array2<f64>, BasisError> {
let mut out = Array2::<f64>::zeros((m, m));
let cutoff2 = (MEASURE_JET3_PROFILE_CUTOFF * eps) * (MEASURE_JET3_PROFILE_CUTOFF * eps);
let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
let eta = 2.0 * order_s + (d as f64) * (2.0 - 2.0 * alpha);
let scale_weight = band.log_step * eps.powf(-eta);
let net_radius2 = 0.25 * eps * eps;
let mut outer: Vec<usize> = Vec::new();
for i in 0..m {
if masses[i] <= 0.0 {
continue;
}
let covered = outer.iter().any(|&o| dist2[(i, o)] <= net_radius2);
if !covered {
outer.push(i);
}
}
let mut net_mass = vec![0.0_f64; m];
for i in 0..m {
if masses[i] <= 0.0 {
continue;
}
let mut best = f64::INFINITY;
let mut best_o = usize::MAX;
for &o in &outer {
if dist2[(i, o)] < best {
best = dist2[(i, o)];
best_o = o;
}
}
if best_o != usize::MAX {
net_mass[best_o] += masses[i];
}
}
for &i in &outer {
let mut idx: Vec<usize> = Vec::new();
for j in 0..m {
if dist2[(i, j)] <= cutoff2 {
idx.push(j);
}
}
let ml = idx.len();
let mut w = Array1::<f64>::zeros(ml);
let mut q = 0.0_f64;
for (a, &j) in idx.iter().enumerate() {
let wj = masses[j] * (-dist2[(i, j)] * inv_two_eps2).exp();
w[a] = wj;
q += wj;
}
if !(q > 0.0) {
continue;
}
let mut phi = Array2::<f64>::zeros((ml, d));
for (a, &j) in idx.iter().enumerate() {
for k in 0..d {
phi[(a, k)] = (centers[(j, k)] - centers[(i, k)]) / eps;
}
}
let a_mean = phi.t().dot(&w) / q;
let mut phi_c = phi.clone();
for mut row in phi_c.outer_iter_mut() {
for k in 0..d {
row[k] -= a_mean[k];
}
}
let mut wphic = phi_c.clone();
for (a, mut row) in wphic.outer_iter_mut().enumerate() {
row.mapv_inplace(|v| v * w[a]);
}
let mut g = phi_c.t().dot(&wphic);
g.mapv_inplace(|v| v / q);
let (g_evals, g_evecs) = g.eigh(Side::Lower).map_err(|e| {
BasisError::InvalidInput(format!(
"measure-jet r=3 frame eigendecomposition failed: {e}"
))
})?;
let lam_max = g_evals
.iter()
.fold(0.0_f64, |acc, v| acc.max((*v).max(0.0)));
if !(lam_max > 0.0) {
continue;
}
let frame_tol = MEASURE_JET3_PSEUDOINVERSE_RTOL * (d.max(1) as f64) * lam_max;
let mut keep: Vec<usize> = (0..g_evals.len())
.rev()
.filter(|&k| g_evals[k] > frame_tol)
.collect();
keep.truncate(MEASURE_JET3_FRAME_MAX);
let qdim = keep.len();
if qdim == 0 {
continue;
}
let mut u = Array2::<f64>::zeros((d, qdim));
for (col, &k) in keep.iter().enumerate() {
for r in 0..d {
u[(r, col)] = g_evecs[(r, k)];
}
}
let y = phi_c.dot(&u);
let n_quad = qdim * (qdim + 1) / 2;
let p3 = 1 + qdim + n_quad;
let mut pmat = Array2::<f64>::zeros((ml, p3));
for a in 0..ml {
pmat[(a, 0)] = 1.0;
for c in 0..qdim {
pmat[(a, 1 + c)] = y[(a, c)];
}
let mut col = 1 + qdim;
for c0 in 0..qdim {
for c1 in c0..qdim {
pmat[(a, col)] = y[(a, c0)] * y[(a, c1)];
col += 1;
}
}
}
let mut wp = pmat.clone();
for (a, mut row) in wp.outer_iter_mut().enumerate() {
row.mapv_inplace(|v| v * w[a]);
}
let ptwp = pmat.t().dot(&wp); let ptwp_pinv = symmetric_pseudoinverse(&ptwp, "degree-<3 frame normal matrix")?;
let tmp = wp.dot(&ptwp_pinv); let h = tmp.dot(&wp.t());
let base = scale_weight * net_mass[i] * q.powf(1.0 - 2.0 * alpha);
for (a, &ja) in idx.iter().enumerate() {
for (c, &jc) in idx.iter().enumerate() {
let mut r_ac = -h[(a, c)];
if a == c {
r_ac += w[a];
}
out[(ja, jc)] += base * r_ac;
}
}
}
Ok(out)
};
let n_scales = band.eps.len();
let parallel_ok =
m.saturating_mul(m).saturating_mul(n_scales) <= MEASURE_JET3_PARALLEL_BUDGET_DOUBLES;
let per_scale: Vec<Array2<f64>> = if parallel_ok {
band.eps
.par_iter()
.map(|&eps| assemble_scale(eps))
.collect::<Result<Vec<_>, BasisError>>()?
} else {
band.eps
.iter()
.map(|&eps| assemble_scale(eps))
.collect::<Result<Vec<_>, BasisError>>()?
};
let mut total = Array2::<f64>::zeros((m, m));
for part in per_scale {
total += ∂
}
Ok((&total + &total.t()) * 0.5)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::terms::basis::measure_jet_band;

    /// 25 centers on a 5 x 5 grid over the unit square, with a small smooth deterministic jitter.
    pub(crate) fn grid_centers() -> Array2<f64> {
        let n = 5usize;
        let mut v: Vec<f64> = Vec::with_capacity(n * n * 2);
        for i in 0..n {
            for j in 0..n {
                let xi = i as f64 / (n as f64 - 1.0);
                let yj = j as f64 / (n as f64 - 1.0);
                let jx = 0.04 * (3.0 * xi + 1.7 * yj).sin();
                let jy = 0.04 * (2.1 * yj - 1.3 * xi).cos();
                v.push(xi + jx);
                v.push(yj + jy);
            }
        }
        Array2::from_shape_vec((n * n, 2), v).expect("grid centers")
    }

    /// Equal masses `1 / m` on `m` centers.
    pub(crate) fn uniform_masses(m: usize) -> Array1<f64> {
        Array1::from_elem(m, 1.0 / m as f64)
    }

    /// Scale band from `measure_jet_band` with second argument 0
    /// (presumably the automatic choice — see the "auto band" expectation).
    pub(crate) fn band_for(centers: &Array2<f64>) -> MeasureJetBand {
        measure_jet_band(centers.view(), 0).expect("auto band")
    }

    /// Quadratic form `fᵀ Q f`.
    pub(crate) fn energy_of(q: &Array2<f64>, f: &Array1<f64>) -> f64 {
        f.dot(&q.dot(f))
    }

    /// Samples `a + b·x + xᵀ M x` at every 2-D center.
    pub(crate) fn sample_quadratic(
        centers: &Array2<f64>,
        a: f64,
        b: [f64; 2],
        m: [[f64; 2]; 2],
    ) -> Array1<f64> {
        let n = centers.nrows();
        Array1::from_shape_fn(n, |i| {
            let x0 = centers[(i, 0)];
            let x1 = centers[(i, 1)];
            a + b[0] * x0
                + b[1] * x1
                + m[0][0] * x0 * x0
                + (m[0][1] + m[1][0]) * x0 * x1
                + m[1][1] * x1 * x1
        })
    }

    #[test]
    pub(crate) fn order3_annihilates_ambient_quadratics_r2_does_not() {
        let centers = grid_centers();
        let masses = uniform_masses(centers.nrows());
        let band = band_for(&centers);
        let q3 =
            measure_jet_order3_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
                .expect("r=3 energy");
        let q2 = crate::terms::basis::measure_jet_energy_form(
            centers.view(),
            masses.view(),
            &band,
            1.5,
            1.0,
            1e-3,
        )
        .expect("r=2 energy");
        let fq = sample_quadratic(&centers, 0.3, [-0.7, 1.1], [[0.9, 0.4], [0.4, -0.6]]);
        let n = centers.nrows();
        let fcubic = Array1::from_shape_fn(n, |i| {
            let x0 = centers[(i, 0)];
            let x1 = centers[(i, 1)];
            x0 * x0 * x0 - 1.4 * x1 * x1 * x1 + 0.8 * x0 * x0 * x1
        });
        let rough3 = energy_of(&q3, &fcubic).abs().max(1e-30);
        let e3_quad = energy_of(&q3, &fq);
        assert!(
            e3_quad.abs() <= 1e-8 * rough3,
            "r=3 quadratic energy {e3_quad:.3e} not annihilated (rough3 = {rough3:.3e})"
        );
        let e2_quad = energy_of(&q2, &fq);
        assert!(
            e2_quad > 1e-6 * rough3,
            "r=2 quadratic energy {e2_quad:.3e} unexpectedly small; expected the affine \
             energy to penalize curvature (rough3 = {rough3:.3e})"
        );
        assert!(
            e2_quad > 1e6 * e3_quad.abs().max(f64::MIN_POSITIVE),
            "expected r=2 quadratic energy ({e2_quad:.3e}) >> r=3 ({e3_quad:.3e})"
        );
    }

    #[test]
    pub(crate) fn order3_energy_form_is_psd() {
        let centers = grid_centers();
        let masses = uniform_masses(centers.nrows());
        let band = band_for(&centers);
        let q =
            measure_jet_order3_energy_form(centers.view(), masses.view(), &band, 1.5, 1.0, 1e-3)
                .expect("r=3 energy");
        let m = q.nrows();
        let scale = q.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        assert!(scale > 0.0, "r=3 energy form is identically zero");
        for trial in 0..8usize {
            let v = Array1::from_shape_fn(m, |i| ((i * 13 + trial * 29) % 23) as f64 / 23.0 - 0.5);
            let e = energy_of(&q, &v);
            assert!(
                e >= -1e-9 * scale,
                "vᵀQv = {e:.3e} < 0 on trial {trial} (scale {scale:.3e})"
            );
        }
        let ones = Array1::ones(m);
        let qv = q.dot(&ones);
        let leak = qv.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        assert!(
            leak <= 1e-10 * scale,
            "Q·1 leak {leak:.3e} vs scale {scale:.3e}"
        );
    }

    #[test]
    pub(crate) fn order3_sinusoid_energy_concentrates_at_fine_scales() {
        let centers = grid_centers();
        let masses = uniform_masses(centers.nrows());
        let band = band_for(&centers);
        let n = centers.nrows();
        let per_scale_energy = |f: &Array1<f64>| -> Vec<f64> {
            band.eps
                .iter()
                .map(|&e| {
                    let single = MeasureJetBand {
                        eps: vec![e],
                        log_step: band.log_step,
                    };
                    let q = measure_jet_order3_energy_form(
                        centers.view(),
                        masses.view(),
                        &single,
                        1.5,
                        1.0,
                        1e-3,
                    )
                    .expect("single-scale r=3 energy");
                    energy_of(&q, f).max(0.0)
                })
                .collect()
        };
        let sinusoid = |freq: f64| -> Array1<f64> {
            Array1::from_shape_fn(n, |i| {
                let x0 = centers[(i, 0)];
                let x1 = centers[(i, 1)];
                (freq * (x0 + 0.5 * x1)).sin()
            })
        };
        let weighted_scale_index = |es: &[f64]| -> f64 {
            let tot: f64 = es.iter().sum::<f64>().max(1e-300);
            es.iter()
                .enumerate()
                .map(|(k, e)| k as f64 * e)
                .sum::<f64>()
                / tot
        };
        let f_fast = sinusoid(14.0);
        let f_slow = sinusoid(3.0);
        let es_fast = per_scale_energy(&f_fast);
        let es_slow = per_scale_energy(&f_slow);
        let idx_fast = weighted_scale_index(&es_fast);
        let idx_slow = weighted_scale_index(&es_slow);
        assert!(
            idx_fast < idx_slow,
            "fine sinusoid did not concentrate at finer scales: \
             idx_fast {idx_fast:.3} should be < idx_slow {idx_slow:.3} \
             (es_fast = {es_fast:?}, es_slow = {es_slow:?})"
        );
    }
}