// gam_sae/inference/layer_transport.rs

//! Functorial inter-layer concept transport maps (issue #1013).
//!
//! For an atom whose layer-`l` chart assigns coordinates `t_l` to each row and
//! whose continuation at layer `l+1` assigns `t_{l+1}`, the estimand is the
//! smooth transport map
//!
//! ```text
//!     t_{l+1} = h_{l→l+1}(t_l)
//! ```
//!
//! fitted as a small penalized GAM with the engine's Gaussian REML machinery
//! (exact 1-D criterion, no GCV per policy). Three questions are answered with
//! evidence:
//!
//! 1. **Topology compatibility** — does `h` preserve the chart topology
//!    (circle→circle degree-±1 covering, i.e. a homeomorphism of `S¹`) or
//!    break it (circle→arcs, folds)? For circle charts the winding **degree**
//!    is estimated by maximizing the circular concentration (mean resultant
//!    length) of the de-wound residual `θ_to − d·θ_from` over candidate
//!    degrees `d ∈ {−2,−1,0,1,2}` — for a transport whose smooth residual
//!    stays inside half a turn this is the circular-correlation-maximizing
//!    degree, and it is exact in the noiseless limit. A fold check on a dense
//!    grid (`sign(d)·h′(t) > 0` everywhere) separates genuine degree-±1
//!    covers from degree-±1 maps with local back-tracking.
//! 2. **Isometry defect** — `∫ (|h′| − 1)² dP̂` under the empirical data
//!    density `P̂` (the integral is evaluated at the observed coordinates, so
//!    dense regions of the chart dominate, as the issue requires). A
//!    delta-method standard error is propagated from the coefficient
//!    covariance. Near-zero defect ⇒ TRANSPORT layer (the concept is carried
//!    isometrically); large defect ⇒ COMPUTE layer (the chart metric is
//!    reshaped).
//! 3. **Composition law** — `h_{l→l+2}` vs `h_{l+1→l+2} ∘ h_{l→l+1}`. The
//!    defect `d(t) = h_ac(t) ⊖ h_bc(h_ab(t))` (circular difference on circle
//!    charts) is evaluated on a grid, studentized by the composed
//!    delta-method bands, and tested with the existing
//!    [`wood_smooth_test`](gam_terms::inference::smooth_test::wood_smooth_test)
//!    machinery applied to a REML smooth of the defect.
//!
//! # Gauge discipline
//!
//! Each chart coordinate is identified only up to the residual isometry gauge
//! of its chart, so a transport map is identified only up to the **double
//! coset** `[Isom(M_to)] · h · [Isom(M_from)]`. Two facts are used:
//!
//! * All three routes in a composition test consume the *same* source
//!   coordinates, so any isometry of the source chart acts identically on
//!   `h_ac` and on `h_bc ∘ h_ab`; the source gauge cancels in the defect and
//!   needs no explicit alignment.
//! * The target gauge does not cancel: before testing, the composed route is
//!   aligned to the direct route using ONLY the certified finite/1-parameter
//!   isometries of the target chart — for a circle, the rotation (fixed at
//!   the circular mean of the defect) and the reflection (the orientation
//!   with the smaller squared defect); for an interval, the reflection about
//!   its midpoint. No general reparameterization is ever fitted away.
//!
//! All smooths reuse the engine's existing periodic cardinal-B-spline basis
//! ([`build_periodic_bspline_basis_1d`]) with the cyclic difference penalty on
//! circular domains, and the open B-spline basis with the standard difference
//! penalty on interval domains — constructed directly, not via the string DSL.

use std::f64::consts::{PI, TAU};

use faer::Side;
use gam_linalg::faer_ndarray::FaerEigh;
use gam_terms::basis::{
    BasisOptions, Dense, KnotSource, PeriodicBSplineBasisSpec, build_periodic_bspline_basis_1d,
    create_basis, create_cyclic_difference_penalty_matrix, create_difference_penalty_matrix,
    periodic_bspline_first_derivative_nd,
};
use gam_terms::inference::smooth_test::{SmoothTestInput, SmoothTestScale, wood_smooth_test};
use ndarray::{Array1, Array2, ArrayView1, Axis};
use statrs::distribution::{ContinuousCDF, Normal};

use crate::chart_canonicalization::CanonicalChartTopology;

/// Cubic splines for every transport smooth.
const TRANSPORT_SPLINE_DEGREE: usize = 3;
/// Second-order (curvature) difference penalty: the cyclic variant leaves
/// constants unpenalized on a circle; the open variant leaves affine maps
/// unpenalized on an interval — exactly the isometry-adjacent null spaces.
const TRANSPORT_PENALTY_ORDER: usize = 2;
/// Minimum paired observations for a transport fit.
const MIN_TRANSPORT_OBS: usize = 16;
/// Target observations per basis function when auto-sizing the basis.
const OBS_PER_BASIS: usize = 8;
/// Periodic basis size bounds (auto-derived from `n`, never a caller knob).
const MIN_PERIODIC_BASIS: usize = 8;
const MAX_PERIODIC_BASIS: usize = 20;
/// Open-interval internal-knot bounds.
const MIN_OPEN_INTERNAL_KNOTS: usize = 4;
const MAX_OPEN_INTERNAL_KNOTS: usize = 12;
/// Candidate winding degrees scanned by the circular-concentration estimator.
const DEGREE_CANDIDATES: [i32; 5] = [-2, -1, 0, 1, 2];
/// Dense grid used for the fold / orientation check of `h′`.
const FOLD_CHECK_GRID: usize = 512;
/// Default evaluation grid for the composition-law defect.
pub const DEFAULT_COMPOSITION_GRID: usize = 256;
/// REML λ-profile: log-spaced grid points then golden-section refinement.
const REML_LAMBDA_GRID_POINTS: usize = 41;
const REML_GOLDEN_ITERATIONS: usize = 40;
const REML_LAMBDA_SPAN_DECADES: f64 = 8.0;

/// Topology of a one-dimensional concept chart.
///
/// The type is `Copy`, so it is passed by value throughout the module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ChartTopology {
    /// Circular chart; coordinates are angles in radians, identified mod 2π.
    Circle,
    /// Interval chart with the Euclidean metric on `[lo, hi]`.
    Interval {
        /// Lower end of the interval.
        lo: f64,
        /// Upper end of the interval.
        hi: f64,
    },
}

110impl ChartTopology {
111    /// Short stable name used by FFI payloads.
112    pub fn name(&self) -> &'static str {
113        match self {
114            ChartTopology::Circle => "circle",
115            ChartTopology::Interval { .. } => "interval",
116        }
117    }
118
119    fn validate(&self) -> Result<(), String> {
120        match *self {
121            ChartTopology::Circle => Ok(()),
122            ChartTopology::Interval { lo, hi } => {
123                if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
124                    Err(format!(
125                        "interval chart bounds must be finite and ordered; got [{lo}, {hi}]"
126                    ))
127                } else {
128                    Ok(())
129                }
130            }
131        }
132    }
133}
134
135/// Bridge from the SAE canonicalization topology to the transport topology.
136///
137/// `CanonicalChartTopology::Circle { period }` becomes a `Circle` chart whose
138/// coordinates are interpreted on `[0, period)` — the transport module's period
139/// is fixed to `TAU` (angles in radians), so the conversion rescales by mapping
140/// the period-normalized angle `t / period * TAU` at the call site. The caller
141/// must apply this rescaling before handing coordinates to `fit_transport_map`.
142///
143/// `CanonicalChartTopology::Interval` becomes `Interval { lo: 0.0, hi: 1.0 }`
144/// (the canonical unit-speed interval span set by the canonicalization step).
145impl From<&CanonicalChartTopology> for ChartTopology {
146    fn from(src: &CanonicalChartTopology) -> Self {
147        match src {
148            CanonicalChartTopology::Circle { .. } => ChartTopology::Circle,
149            CanonicalChartTopology::Interval => ChartTopology::Interval { lo: 0.0, hi: 1.0 },
150        }
151    }
152}
153
154impl From<CanonicalChartTopology> for ChartTopology {
155    fn from(src: CanonicalChartTopology) -> Self {
156        ChartTopology::from(&src)
157    }
158}
159
/// Wrap an angle into `[0, 2π)`.
///
/// `f64::rem_euclid` can return exactly `TAU` when the input is a negative
/// number much smaller in magnitude than `TAU` (the exact remainder rounds up
/// to the modulus). That value is outside the half-open range, so it is
/// mapped to `0.0`, the same point on the circle. NaN inputs propagate as NaN.
fn wrap_tau(x: f64) -> f64 {
    let wrapped = x.rem_euclid(TAU);
    if wrapped >= TAU { 0.0 } else { wrapped }
}

/// Wrap an angle into `(−π, π]`.
///
/// The shifted remainder lands in `[−π, π]`; the closed lower end `−π` is
/// folded onto `+π` so the branch cut has a single representative.
fn wrap_pi(x: f64) -> f64 {
    let shifted = (x + PI).rem_euclid(TAU) - PI;
    if shifted <= -PI {
        shifted + TAU
    } else {
        shifted
    }
}

/// Circular mean of a set of angles; `0` when the resultant degenerates.
///
/// The mean direction is `atan2(Σ sin, Σ cos)`. When the resultant vector is
/// no longer than `f64::EPSILON · max(len, 1)` the direction is numerically
/// meaningless (antipodal or empty input), so `0.0` is returned instead.
fn circular_mean(angles: &[f64]) -> f64 {
    let (sin_sum, cos_sum) = angles
        .iter()
        .fold((0.0_f64, 0.0_f64), |(s, c), &a| (s + a.sin(), c + a.cos()));
    let degenerate_below = f64::EPSILON * angles.len().max(1) as f64;
    if sin_sum.hypot(cos_sum) <= degenerate_below {
        0.0
    } else {
        sin_sum.atan2(cos_sum)
    }
}

/// Mean resultant length `R ∈ [0, 1]` of a set of angles.
///
/// `R` is the length of the average unit vector: `1` for perfectly aligned
/// angles, near `0` for angles that cancel. An empty slice yields `0.0`
/// rather than dividing by zero.
fn resultant_length(angles: &[f64]) -> f64 {
    if angles.is_empty() {
        return 0.0;
    }
    let (sin_sum, cos_sum) = angles
        .iter()
        .fold((0.0_f64, 0.0_f64), |(s, c), &a| (s + a.sin(), c + a.cos()));
    sin_sum.hypot(cos_sum) / angles.len() as f64
}

200/// Domain-side basis carrier: periodic cardinal B-splines on a circle, open
201/// B-splines on an interval. Both reuse the existing basis constructors
202/// directly (no string DSL round-trip).
203#[derive(Debug, Clone)]
204enum DomainBasis {
205    Periodic(PeriodicBSplineBasisSpec),
206    Open { knots: Array1<f64>, degree: usize },
207}
208
209impl DomainBasis {
210    fn build(topology: ChartTopology, coords: ArrayView1<'_, f64>) -> Result<Self, String> {
211        let n = coords.len();
212        match topology {
213            ChartTopology::Circle => {
214                let num_basis = (n / OBS_PER_BASIS).clamp(MIN_PERIODIC_BASIS, MAX_PERIODIC_BASIS);
215                Ok(DomainBasis::Periodic(PeriodicBSplineBasisSpec {
216                    degree: TRANSPORT_SPLINE_DEGREE,
217                    num_basis,
218                    period: TAU,
219                    origin: 0.0,
220                    penalty_order: TRANSPORT_PENALTY_ORDER,
221                }))
222            }
223            ChartTopology::Interval { lo, hi } => {
224                let num_internal =
225                    (n / OBS_PER_BASIS).clamp(MIN_OPEN_INTERNAL_KNOTS, MAX_OPEN_INTERNAL_KNOTS);
226                let (seed, knots) = create_basis::<Dense>(
227                    coords.mapv(|v| v.clamp(lo, hi)).view(),
228                    KnotSource::Generate {
229                        data_range: (lo, hi),
230                        num_internal_knots: num_internal,
231                    },
232                    TRANSPORT_SPLINE_DEGREE,
233                    BasisOptions::value(),
234                )
235                .map_err(|e| format!("layer transport open basis construction failed: {e}"))?;
236                if seed.nrows() != n {
237                    return Err(format!(
238                        "layer transport open basis returned {} rows for {n} inputs",
239                        seed.nrows()
240                    ));
241                }
242                Ok(DomainBasis::Open {
243                    knots,
244                    degree: TRANSPORT_SPLINE_DEGREE,
245                })
246            }
247        }
248    }
249
250    fn num_basis(&self) -> usize {
251        match self {
252            DomainBasis::Periodic(spec) => spec.num_basis,
253            DomainBasis::Open { knots, degree } => knots.len() - degree - 1,
254        }
255    }
256
257    /// Rank of the smoothing penalty: the cyclic 2nd-difference penalty
258    /// annihilates only constants (a linear map is not periodic), the open
259    /// 2nd-difference penalty annihilates affine maps.
260    fn penalty_rank(&self) -> usize {
261        match self {
262            DomainBasis::Periodic(spec) => spec.num_basis - 1,
263            DomainBasis::Open { .. } => self.num_basis() - TRANSPORT_PENALTY_ORDER,
264        }
265    }
266
267    fn penalty(&self) -> Result<Array2<f64>, String> {
268        match self {
269            DomainBasis::Periodic(spec) => {
270                create_cyclic_difference_penalty_matrix(spec.num_basis, TRANSPORT_PENALTY_ORDER)
271                    .map_err(|e| format!("cyclic transport penalty failed: {e}"))
272            }
273            DomainBasis::Open { .. } => {
274                create_difference_penalty_matrix(self.num_basis(), TRANSPORT_PENALTY_ORDER, None)
275                    .map_err(|e| format!("open transport penalty failed: {e}"))
276            }
277        }
278    }
279
280    /// Clamp/wrap an evaluation point into the basis domain.
281    fn project(&self, t: f64) -> f64 {
282        match self {
283            DomainBasis::Periodic(_) => wrap_tau(t),
284            DomainBasis::Open { knots, degree } => {
285                let lo = knots[*degree];
286                let hi = knots[knots.len() - 1 - degree];
287                t.clamp(lo, hi)
288            }
289        }
290    }
291
292    fn value_rows(&self, t: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
293        let projected = t.mapv(|v| self.project(v));
294        match self {
295            DomainBasis::Periodic(spec) => build_periodic_bspline_basis_1d(projected.view(), spec)
296                .map_err(|e| format!("periodic transport basis evaluation failed: {e}")),
297            DomainBasis::Open { knots, degree } => {
298                let (rows, used_knots) = create_basis::<Dense>(
299                    projected.view(),
300                    KnotSource::Provided(knots.view()),
301                    *degree,
302                    BasisOptions::value(),
303                )
304                .map_err(|e| format!("open transport basis evaluation failed: {e}"))?;
305                if used_knots.len() != knots.len() {
306                    return Err("open transport basis knot vector drifted".to_string());
307                }
308                Ok(rows.as_ref().to_owned())
309            }
310        }
311    }
312
313    /// Polynomial degree of `h′` on each knot span: the basis degree minus one
314    /// (a cubic spline derivative is piecewise quadratic).
315    fn derivative_poly_degree(&self) -> usize {
316        let degree = match self {
317            DomainBasis::Periodic(spec) => spec.degree,
318            DomainBasis::Open { degree, .. } => *degree,
319        };
320        degree.saturating_sub(1)
321    }
322
323    /// Sorted distinct breakpoints bounding the polynomial pieces of `h′` over
324    /// the active domain `[lo, hi]`. Within each `[breakpoints[k],
325    /// breakpoints[k+1]]` span the derivative is a single polynomial of degree
326    /// [`Self::derivative_poly_degree`], which is what the exact monotonicity
327    /// certificate reconstructs and checks. For the open basis these are the
328    /// distinct interior+boundary knots; for the periodic basis they are the
329    /// uniform cardinal-B-spline segment boundaries over `[0, 2π]`.
330    fn derivative_breakpoints(&self) -> Vec<f64> {
331        match self {
332            DomainBasis::Periodic(spec) => {
333                // Cardinal periodic B-splines on `[origin, origin+period]` have
334                // `num_basis` uniform segments; the derivative is a separate
335                // polynomial on each.
336                let n_seg = spec.num_basis.max(1);
337                (0..=n_seg)
338                    .map(|k| spec.origin + spec.period * k as f64 / n_seg as f64)
339                    .collect()
340            }
341            DomainBasis::Open { knots, degree } => {
342                let lo = knots[*degree];
343                let hi = knots[knots.len() - 1 - degree];
344                let mut breaks: Vec<f64> = Vec::with_capacity(knots.len());
345                for &k in knots.iter() {
346                    if k > lo + 0.0 && k < hi {
347                        breaks.push(k);
348                    }
349                }
350                breaks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
351                breaks.dedup_by(|a, b| (*a - *b).abs() <= f64::EPSILON * hi.abs().max(1.0));
352                let mut out = Vec::with_capacity(breaks.len() + 2);
353                out.push(lo);
354                out.extend(breaks.into_iter().filter(|&k| k > lo && k < hi));
355                out.push(hi);
356                out
357            }
358        }
359    }
360
361    fn derivative_rows(&self, t: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
362        let projected = t.mapv(|v| self.project(v));
363        match self {
364            DomainBasis::Periodic(spec) => {
365                let n = projected.len();
366                let mut col = Array2::<f64>::zeros((n, 1));
367                for (i, &v) in projected.iter().enumerate() {
368                    col[[i, 0]] = v;
369                }
370                let jet = periodic_bspline_first_derivative_nd(
371                    col.view(),
372                    (0.0, TAU),
373                    spec.degree,
374                    spec.num_basis,
375                )
376                .map_err(|e| format!("periodic transport derivative failed: {e}"))?;
377                Ok(jet.index_axis(Axis(2), 0).to_owned())
378            }
379            DomainBasis::Open { knots, degree } => {
380                let (rows, used_knots) = create_basis::<Dense>(
381                    projected.view(),
382                    KnotSource::Provided(knots.view()),
383                    *degree,
384                    BasisOptions::first_derivative(),
385                )
386                .map_err(|e| format!("open transport derivative failed: {e}"))?;
387                if used_knots.len() != knots.len() {
388                    return Err("open transport derivative knot vector drifted".to_string());
389                }
390                Ok(rows.as_ref().to_owned())
391            }
392        }
393    }
394}
395
396/// One penalized 1-D smooth chosen by exact Gaussian REML (or known-scale
397/// REML for the weighted defect fit), with everything downstream inference
398/// needs: scale-included covariance, the influence block for trace-corrected
399/// reference d.f., EDF, and the selected λ.
400struct Penalized1dFit {
401    beta: Array1<f64>,
402    /// Scale-included posterior covariance `σ̂²(XᵀWX + λS)⁻¹` (φ̂ = 1 in the
403    /// known-scale branch).
404    covariance: Array2<f64>,
405    /// Coefficient-space influence `F = (XᵀWX + λS)⁻¹ XᵀWX` for Wood's
406    /// trace-corrected reference d.f.
407    influence: Array2<f64>,
408    lambda: f64,
409    edf: f64,
410    sigma2: f64,
411    residual_rms: f64,
412}
413
414/// Exact 1-D Gaussian REML on a fixed design/penalty pair.
415///
416/// Estimated scale (`known_scale = false`): profile σ² out of Wood's REML,
417/// `V(λ) = (n − M₀)·log PRSS(λ) + log|XᵀWX + λS| − rank(S)·log λ`, with
418/// `M₀ = dim ker S` and `PRSS = yᵀWy − β̂ᵀXᵀWy`. Known scale (φ = 1, used for
419/// the variance-weighted defect smooth): `V(λ) = PRSS + log|XᵀWX + λS| −
420/// rank(S)·log λ`. λ is selected on a deterministic log grid spanning
421/// ±[`REML_LAMBDA_SPAN_DECADES`] decades around the design's trace scale and
422/// refined by golden section — no RNG, no caller knobs.
423fn fit_penalized_1d(
424    design: &Array2<f64>,
425    penalty: &Array2<f64>,
426    response: ArrayView1<'_, f64>,
427    weights: Option<ArrayView1<'_, f64>>,
428    penalty_rank: usize,
429    known_scale: bool,
430) -> Result<Penalized1dFit, String> {
431    let n = design.nrows();
432    let m = design.ncols();
433    if response.len() != n || penalty.nrows() != m || penalty.ncols() != m {
434        return Err(format!(
435            "penalized 1-D fit shape mismatch: X is {n}×{m}, y has {}, S is {}×{}",
436            response.len(),
437            penalty.nrows(),
438            penalty.ncols()
439        ));
440    }
441    if let Some(w) = weights {
442        if w.len() != n {
443            return Err(format!(
444                "penalized 1-D fit weight length {} does not match n = {n}",
445                w.len()
446            ));
447        }
448        if w.iter().any(|&v| !v.is_finite() || v <= 0.0) {
449            return Err("penalized 1-D fit weights must be finite and positive".to_string());
450        }
451    }
452
453    let mut xtwx = Array2::<f64>::zeros((m, m));
454    let mut xtwy = Array1::<f64>::zeros(m);
455    let mut ytwy = 0.0_f64;
456    let mut sum_w = 0.0_f64;
457    for r in 0..n {
458        let w = weights.map_or(1.0, |wv| wv[r]);
459        let y = response[r];
460        ytwy += w * y * y;
461        sum_w += w;
462        for j in 0..m {
463            let xj = design[[r, j]];
464            if xj == 0.0 {
465                continue;
466            }
467            xtwy[j] += w * xj * y;
468            for k in j..m {
469                xtwx[[j, k]] += w * xj * design[[r, k]];
470            }
471        }
472    }
473    for j in 0..m {
474        for k in 0..j {
475            xtwx[[j, k]] = xtwx[[k, j]];
476        }
477    }
478
479    let trace_scale = (0..m).map(|i| xtwx[[i, i]]).sum::<f64>() / m as f64;
480    let anchor = trace_scale.max(f64::MIN_POSITIVE);
481    let nullspace_dim = m.saturating_sub(penalty_rank);
482    let dof = ((n as f64) - nullspace_dim as f64).max(1.0);
483    let rank_f = penalty_rank as f64;
484
485    let solve_at = |lambda: f64| -> Result<(Array1<f64>, Array1<f64>, Array2<f64>), String> {
486        let mut a = xtwx.clone();
487        for j in 0..m {
488            for k in 0..m {
489                a[[j, k]] += lambda * penalty[[j, k]];
490            }
491        }
492        // Representative-selecting micro-ridge for exactly aliased designs.
493        let diag_scale = (0..m).map(|i| a[[i, i]].abs()).fold(1.0_f64, f64::max);
494        for i in 0..m {
495            a[[i, i]] += 1e-12 * diag_scale;
496        }
497        let (evals, evecs) = a
498            .eigh(Side::Lower)
499            .map_err(|e| format!("penalized 1-D fit eigendecomposition failed: {e:?}"))?;
500        Ok((evals, evecs.t().dot(&xtwy), evecs))
501    };
502
503    let criterion = |lambda: f64| -> f64 {
504        let Ok(parts) = solve_at(lambda) else {
505            return f64::INFINITY;
506        };
507        let (evals, rotated) = (&parts.0, &parts.1);
508        let floor = evals.iter().copied().fold(0.0_f64, f64::max) * 1e-14;
509        let mut prss = ytwy;
510        let mut logdet = 0.0_f64;
511        for i in 0..m {
512            let d = evals[i].max(floor).max(f64::MIN_POSITIVE);
513            prss -= rotated[i] * rotated[i] / d;
514            logdet += d.ln();
515        }
516        let prss = prss.max(f64::MIN_POSITIVE);
517        let fit_term = if known_scale { prss } else { dof * prss.ln() };
518        fit_term + logdet - rank_f * lambda.ln()
519    };
520
521    let lo = anchor * 10f64.powf(-REML_LAMBDA_SPAN_DECADES);
522    let hi = anchor * 10f64.powf(REML_LAMBDA_SPAN_DECADES);
523    let grid: Vec<f64> = (0..REML_LAMBDA_GRID_POINTS)
524        .map(|i| {
525            let t = i as f64 / (REML_LAMBDA_GRID_POINTS - 1) as f64;
526            lo * (hi / lo).powf(t)
527        })
528        .collect();
529    let mut best_idx = 0usize;
530    let mut best_val = f64::INFINITY;
531    for (i, &lam) in grid.iter().enumerate() {
532        let v = criterion(lam);
533        if v < best_val {
534            best_val = v;
535            best_idx = i;
536        }
537    }
538    let mut a_log = grid[best_idx.saturating_sub(1)].ln();
539    let mut c_log = grid[(best_idx + 1).min(REML_LAMBDA_GRID_POINTS - 1)].ln();
540    let golden = (5.0_f64.sqrt() - 1.0) / 2.0;
541    let mut x1 = c_log - golden * (c_log - a_log);
542    let mut x2 = a_log + golden * (c_log - a_log);
543    let mut f1 = criterion(x1.exp());
544    let mut f2 = criterion(x2.exp());
545    for _ in 0..REML_GOLDEN_ITERATIONS {
546        if f1 <= f2 {
547            c_log = x2;
548            x2 = x1;
549            f2 = f1;
550            x1 = c_log - golden * (c_log - a_log);
551            f1 = criterion(x1.exp());
552        } else {
553            a_log = x1;
554            x1 = x2;
555            f1 = f2;
556            x2 = a_log + golden * (c_log - a_log);
557            f2 = criterion(x2.exp());
558        }
559    }
560    let lambda = (0.5 * (a_log + c_log)).exp();
561
562    let (evals, rotated, evecs) = solve_at(lambda)?;
563    let floor = evals.iter().copied().fold(0.0_f64, f64::max) * 1e-14;
564    let mut a_inv = Array2::<f64>::zeros((m, m));
565    let mut beta = Array1::<f64>::zeros(m);
566    for i in 0..m {
567        let d = evals[i].max(floor).max(f64::MIN_POSITIVE);
568        let coeff = rotated[i] / d;
569        for j in 0..m {
570            beta[j] += evecs[[j, i]] * coeff;
571            for k in 0..m {
572                a_inv[[j, k]] += evecs[[j, i]] * evecs[[k, i]] / d;
573            }
574        }
575    }
576    let influence = a_inv.dot(&xtwx);
577    let edf = (0..m).map(|i| influence[[i, i]]).sum::<f64>();
578
579    let fitted = design.dot(&beta);
580    let mut rss = 0.0_f64;
581    for r in 0..n {
582        let w = weights.map_or(1.0, |wv| wv[r]);
583        let e = response[r] - fitted[r];
584        rss += w * e * e;
585    }
586    let sigma2 = if known_scale {
587        1.0
588    } else {
589        (rss / ((n as f64) - edf).max(1.0)).max(f64::MIN_POSITIVE)
590    };
591    let covariance = a_inv.mapv(|v| v * sigma2);
592    let residual_rms = (rss / sum_w.max(f64::MIN_POSITIVE)).sqrt();
593
594    if beta.iter().any(|v| !v.is_finite()) {
595        return Err("penalized 1-D fit produced non-finite coefficients".to_string());
596    }
597    Ok(Penalized1dFit {
598        beta,
599        covariance,
600        influence,
601        lambda,
602        edf,
603        sigma2,
604        residual_rms,
605    })
606}
607
608/// A fitted inter-layer transport map with full posterior bookkeeping, ready
609/// for evaluation, banding, and composition testing.
610///
611/// Representation: `h(t) = degree·t + rotation_offset + g(t)` on circle
612/// targets (`g` the REML periodic/open spline; the result is read mod 2π) and
613/// `h(t) = g(t)` on interval targets. The discrete winding `degree` and the
614/// wrap-branch offset are treated as fixed (a discrete selection and a gauge
615/// representative respectively); pointwise variances propagate the spline
616/// coefficient covariance only.
617#[derive(Debug, Clone)]
618pub struct FittedTransport {
619    pub topology_from: ChartTopology,
620    pub topology_to: ChartTopology,
621    /// Winding degree of the map (circle→circle charts only).
622    pub degree: Option<i32>,
623    /// Mean resultant length of the de-wound residual at the selected degree
624    /// (circle→circle only): the concentration evidence behind `degree`.
625    pub degree_concentration: Option<f64>,
626    /// Rotation gauge representative used to pick the wrap branch of the
627    /// angular response (circle targets; `0` for interval targets). The
628    /// estimand is the double coset, so this constant carries no information
629    /// on its own.
630    pub rotation_offset: f64,
631    /// Spline coefficients of the residual smooth `g`.
632    pub beta: Array1<f64>,
633    /// Scale-included posterior covariance of `beta` (mgcv `Vb` analogue).
634    pub covariance: Array2<f64>,
635    pub smoothing_lambda: f64,
636    /// Effective degrees of freedom of the transport smooth.
637    pub edf: f64,
638    /// REML-profiled residual variance σ̂² of the (unwrapped) response.
639    pub noise_variance: f64,
640    pub n_obs: usize,
641    /// Empirical-density-weighted isometry defect `mean((|h′(tᵢ)| − 1)²)`.
642    pub isometry_defect: f64,
643    /// Delta-method standard error of the isometry defect.
644    pub isometry_defect_se: f64,
645    /// Whether `h` is compatible with both chart topologies: a degree-±1
646    /// circle cover without folds, or a fold-free interval homeomorphism.
647    pub topology_preserved: bool,
648    /// `min over a dense grid of orientation·h′(t)`; positive ⇔ no folds.
649    pub min_directional_derivative: f64,
650    /// RMS of the response residuals at the fitted map.
651    pub residual_rms: f64,
652    basis: DomainBasis,
653}
654
655impl FittedTransport {
656    fn linear_slope(&self) -> f64 {
657        self.degree.map_or(0.0, f64::from)
658    }
659
660    /// Evaluate `h` at `t` (wrapped to `[0, 2π)` on circle targets).
661    pub fn eval(&self, t: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
662        let rows = self.basis.value_rows(t)?;
663        let smooth = rows.dot(&self.beta);
664        let slope = self.linear_slope();
665        let mut out = Array1::<f64>::zeros(t.len());
666        for i in 0..t.len() {
667            let raw = slope * t[i] + self.rotation_offset + smooth[i];
668            out[i] = match self.topology_to {
669                ChartTopology::Circle => wrap_tau(raw),
670                ChartTopology::Interval { .. } => raw,
671            };
672        }
673        Ok(out)
674    }
675
676    /// Evaluate `h` and its pointwise delta-method variance.
677    pub fn eval_with_variance(
678        &self,
679        t: ArrayView1<'_, f64>,
680    ) -> Result<(Array1<f64>, Array1<f64>), String> {
681        let rows = self.basis.value_rows(t)?;
682        let values = self.eval(t)?;
683        let mut variances = Array1::<f64>::zeros(t.len());
684        for i in 0..t.len() {
685            let row = rows.row(i);
686            variances[i] = row.dot(&self.covariance.dot(&row)).max(0.0);
687        }
688        Ok((values, variances))
689    }
690
691    /// Evaluate `h′(t)` (chart-coordinate derivative).
692    pub fn derivative(&self, t: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
693        let rows = self.basis.derivative_rows(t)?;
694        let slope = self.linear_slope();
695        Ok(rows.dot(&self.beta).mapv(|v| v + slope))
696    }
697
698    /// Pre-wrap map value `slope·t + offset + g(t)` at a single point — the
699    /// strictly monotone (when fold-free) handle that [`Self::eval`] wraps for
700    /// circle targets and [`Self::invert`] bisects on.
701    fn raw_at(&self, t: f64) -> Result<f64, String> {
702        let arr = Array1::from_elem(1, t);
703        let smooth = self.basis.value_rows(arr.view())?.dot(&self.beta)[0];
704        Ok(self.linear_slope() * t + self.rotation_offset + smooth)
705    }
706
707    /// `orientation·h′` at the supplied source-chart coordinates.
708    fn oriented_derivative_at(&self, t: &[f64], orientation: f64) -> Result<Vec<f64>, String> {
709        let arr = Array1::from_vec(t.to_vec());
710        let rows = self.basis.derivative_rows(arr.view())?;
711        let slope = self.linear_slope();
712        Ok((0..t.len())
713            .map(|i| orientation * (rows.row(i).dot(&self.beta) + slope))
714            .collect())
715    }
716
717    /// Exactly certify that `h` is strictly monotone over the whole source
718    /// domain, returning the certified orientation (+1 increasing, −1
719    /// decreasing) or an `Err` describing where monotonicity fails.
720    ///
721    /// Unlike [`Self::topology_preserved`], which only samples `h′` on a fixed
722    /// 512-point grid and so can miss a fold *between* grid samples, this is a
723    /// span-exact certificate. On each knot span `h′` is a single polynomial of
724    /// degree `d = `[`DomainBasis::derivative_poly_degree`]` (cubic spline ⇒
725    /// quadratic). A degree-`d` polynomial is determined by `d + 1` samples, so
726    /// per span we sample `h′` at `d + 1` equally-spaced abscissae, reconstruct
727    /// the polynomial by finite differences, locate its interior critical
728    /// points in closed form, and require `orientation·h′ > 0` at the span
729    /// endpoints **and** every interior critical point. To stay sound even if a
730    /// basis is not an exact polynomial of the assumed degree on a span (e.g. a
731    /// row-normalized periodic basis whose row-sum is not a partition of unity),
732    /// the reconstruction is verified against an independent interior sample;
733    /// any mismatch falls back to refusing the span.
734    fn certify_strict_monotonicity(&self) -> Result<f64, String> {
735        let (lo, hi) = match self.topology_from {
736            ChartTopology::Circle => (0.0, TAU),
737            ChartTopology::Interval { lo, hi } => (lo, hi),
738        };
739        // Orientation from the endpoint span of the pre-wrap map, matching the
740        // sign convention `invert` bisects with.
741        let raw_lo = self.raw_at(lo)?;
742        let raw_hi = self.raw_at(hi)?;
743        let orientation = if raw_hi >= raw_lo { 1.0 } else { -1.0 };
744
745        let deg = self.basis.derivative_poly_degree().max(1);
746        let breaks = self.basis.derivative_breakpoints();
747        // Restrict the breakpoints to the active domain (the periodic segment
748        // grid already coincides with `[lo, hi]`).
749        for window in breaks.windows(2) {
750            let (a, b) = (window[0], window[1]);
751            if !(b > a) {
752                continue;
753            }
754            let span = b - a;
755            // Reconstruction abscissae: `deg + 1` equally spaced nodes on the
756            // closed span (sampling strictly inside avoids the knot where two
757            // pieces meet and the open-basis derivative can be one-sided).
758            let pad = span * 1.0e-9;
759            let n_nodes = deg + 1;
760            let nodes: Vec<f64> = (0..n_nodes)
761                .map(|i| {
762                    let s = if n_nodes == 1 {
763                        0.5
764                    } else {
765                        i as f64 / (n_nodes - 1) as f64
766                    };
767                    (a + pad) + (span - 2.0 * pad) * s
768                })
769                .collect();
770            let values = self.oriented_derivative_at(&nodes, orientation)?;
771
772            // Polynomial in the local coordinate u = (t - nodes[0]) / step,
773            // reconstructed by Newton forward differences on the equally-spaced
774            // nodes. Coefficients in the monomial basis of u are recovered for
775            // the closed-form critical-point search.
776            let step = if n_nodes > 1 {
777                nodes[1] - nodes[0]
778            } else {
779                span
780            };
781            let coeffs = monomial_from_equispaced(&values);
782
783            // Sound guard: verify the reconstruction reproduces an independent
784            // interior sample (deliberately off the reconstruction nodes — the
785            // equispaced nodes never land on a 0.37 fraction). If the basis is
786            // not exactly polynomial of the assumed degree on this span, refuse
787            // rather than trust the fit.
788            let probe_t = a + 0.37 * span;
789            let probe_u = (probe_t - nodes[0]) / step;
790            let probe_recon = eval_monomial(&coeffs, probe_u);
791            let probe_actual = self.oriented_derivative_at(&[probe_t], orientation)?[0];
792            let scale = probe_actual.abs().max(1.0);
793            if (probe_recon - probe_actual).abs() > 1.0e-6 * scale {
794                return Err(format!(
795                    "transport monotonicity certificate could not reconstruct h′ on the \
796                     span [{a}, {b}] (reconstruction {probe_recon} vs actual {probe_actual}); \
797                     refusing to certify"
798                ));
799            }
800
801            // Require positivity at the closed-span endpoints.
802            for &edge in &[a, b] {
803                let u = (edge - nodes[0]) / step;
804                let v = eval_monomial(&coeffs, u);
805                if !(v > 0.0) {
806                    return Err(format!(
807                        "transport map is not strictly monotone: orientation·h′ = {v} ≤ 0 at \
808                         t = {edge}"
809                    ));
810                }
811            }
812            // Require positivity at every interior critical point of the
813            // polynomial within the span.
814            for u_crit in monomial_critical_points(&coeffs) {
815                let t_crit = nodes[0] + u_crit * step;
816                if t_crit > a && t_crit < b {
817                    let v = eval_monomial(&coeffs, u_crit);
818                    if !(v > 0.0) {
819                        return Err(format!(
820                            "transport map folds: orientation·h′ = {v} ≤ 0 at interior \
821                             extremum t = {t_crit}"
822                        ));
823                    }
824                }
825            }
826        }
827        Ok(orientation)
828    }
829
830    /// Invert the transport: for each target-chart coordinate `y`, return the
831    /// source-chart coordinate `t` with `eval([t]) == y`.
832    ///
833    /// Requires a strictly monotone, fold-free map (a degree-±1 cover for
834    /// circle charts, a homeomorphism for intervals), so the inverse is
835    /// single-valued; otherwise this errors rather than picking an arbitrary
836    /// branch. Monotonicity is established with [`Self::certify_strict_monotonicity`]
837    /// — a span-exact polynomial certificate, **not** the sampled
838    /// `topology_preserved` diagnostic, which can miss a narrow fold between its
839    /// grid samples. Non-finite targets are rejected. Interval targets reject a
840    /// `y` outside the fitted image (scale-aware tolerance); circle targets
841    /// accept any `y` (the pre-wrap map covers a full `2π`). The root is found
842    /// by monotone bisection on the pre-wrap map `raw_at`, which converges to
843    /// f64 precision (~53 significand bits) in the source coordinate after on
844    /// the order of 50 iterations.
845    ///
846    /// This is the exact inverse of [`Self::eval`] and the missing half of the
847    /// transport algebra alongside [`composition_defect`]: it is what lets a
848    /// caller form `g_B ∘ g_A⁻¹` from two fitted transports.
849    pub fn invert(&self, y: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
850        if y.iter().any(|v| !v.is_finite()) {
851            return Err("transport inverse targets must be finite".to_string());
852        }
853        // Span-exact strict-monotonicity certificate; supersedes the sampled
854        // `topology_preserved` flag, which can pass over a between-sample fold.
855        self.certify_strict_monotonicity()?;
856        let (lo, hi) = match self.topology_from {
857            ChartTopology::Circle => (0.0, TAU),
858            ChartTopology::Interval { lo, hi } => (lo, hi),
859        };
860        // The pre-wrap map is strictly monotone over [lo, hi]; the endpoints
861        // anchor its orientation and image span.
862        let raw_lo = self.raw_at(lo)?;
863        let raw_hi = self.raw_at(hi)?;
864        let increasing = raw_hi > raw_lo;
865        let (raw_min, raw_max) = if increasing {
866            (raw_lo, raw_hi)
867        } else {
868            (raw_hi, raw_lo)
869        };
870        // Scale-aware image tolerance: an absolute 1e-9 would wrongly accept a
871        // target well outside a tiny image (e.g. [0, 1e-8]).
872        let scale = raw_min.abs().max(raw_max.abs()).max(1.0);
873        let tol = 32.0 * f64::EPSILON * scale;
874
875        // One reusable single-element buffer for the bisection probes (rebuilt
876        // basis rows on every probe otherwise allocated a fresh `Array1`).
877        let mut probe = Array1::<f64>::zeros(1);
878        let mut raw_at_into = |t: f64| -> Result<f64, String> {
879            probe[0] = t;
880            let smooth = self.basis.value_rows(probe.view())?.dot(&self.beta)[0];
881            Ok(self.linear_slope() * t + self.rotation_offset + smooth)
882        };
883
884        let mut out = Array1::<f64>::zeros(y.len());
885        for (idx, &yi) in y.iter().enumerate() {
886            // Target value in the pre-wrap coordinate.
887            let target = match self.topology_to {
888                ChartTopology::Interval { .. } => {
889                    if yi < raw_min - tol || yi > raw_max + tol {
890                        return Err(format!(
891                            "transport inverse target {yi} is outside the fitted image \
892                             [{raw_min}, {raw_max}]"
893                        ));
894                    }
895                    yi.clamp(raw_min, raw_max)
896                }
897                ChartTopology::Circle => {
898                    // The pre-wrap map covers exactly 2π; shift wrap_tau(y) by
899                    // the unique integer multiple of 2π that lands in the image.
900                    let ywrapped = wrap_tau(yi);
901                    let m = ((raw_min - ywrapped) / TAU).ceil();
902                    ywrapped + TAU * m
903                }
904            };
905            // Monotone bisection on the pre-wrap map over [lo, hi]; stop once
906            // the bracket is below the source-coordinate precision floor (f64
907            // bisection stagnates well before 100 iterations).
908            let (mut a, mut b) = (lo, hi);
909            let width_floor = f64::EPSILON * hi.abs().max(lo.abs()).max(1.0);
910            for _ in 0..100 {
911                if (b - a) <= width_floor {
912                    break;
913                }
914                let mid = 0.5 * (a + b);
915                let rm = raw_at_into(mid)?;
916                let go_right = if increasing { rm < target } else { rm > target };
917                if go_right {
918                    a = mid;
919                } else {
920                    b = mid;
921                }
922            }
923            out[idx] = 0.5 * (a + b);
924        }
925        Ok(out)
926    }
927
928    /// Package the fit as a [`LayerTransportReport`] for the given layer pair
929    /// (composition fields empty; see [`LayerTransportReport::with_composition`]).
930    pub fn report(&self, layer_from: usize, layer_to: usize) -> LayerTransportReport {
931        LayerTransportReport {
932            layer_from,
933            layer_to,
934            topology_from: self.topology_from,
935            topology_to: self.topology_to,
936            topology_preserved: self.topology_preserved,
937            degree: self.degree,
938            degree_concentration: self.degree_concentration,
939            rotation_offset: self.rotation_offset,
940            isometry_defect: self.isometry_defect,
941            isometry_defect_se: self.isometry_defect_se,
942            min_directional_derivative: self.min_directional_derivative,
943            transport_edf: self.edf,
944            smoothing_lambda: self.smoothing_lambda,
945            noise_variance: self.noise_variance,
946            residual_rms: self.residual_rms,
947            n_obs: self.n_obs,
948            composition_defect: None,
949            composition_max_studentized: None,
950            composition_p_value: None,
951            composition_gauge_reflected: None,
952        }
953    }
954}
955
/// Evidence payload for one estimated inter-layer transport map.
///
/// Produced by `FittedTransport::report`; the `composition_*` fields stay
/// `None` until a composition-law test is merged in with
/// [`LayerTransportReport::with_composition`].
#[derive(Debug, Clone)]
pub struct LayerTransportReport {
    /// Source layer index, as supplied by the caller of `report`.
    pub layer_from: usize,
    /// Target layer index, as supplied by the caller of `report`.
    pub layer_to: usize,
    /// Chart topology of the source coordinates.
    pub topology_from: ChartTopology,
    /// Chart topology of the target coordinates.
    pub topology_to: ChartTopology,
    /// Degree-±1 fold-free circle cover (or fold-free interval homeo).
    pub topology_preserved: bool,
    /// Estimated winding degree (circle→circle only).
    pub degree: Option<i32>,
    /// Circular concentration of the de-wound residual at `degree`.
    pub degree_concentration: Option<f64>,
    /// Rotation gauge representative (circle targets).
    pub rotation_offset: f64,
    /// `∫(|h′| − 1)² dP̂` under the empirical chart density.
    pub isometry_defect: f64,
    /// Delta-method SE of the isometry defect.
    pub isometry_defect_se: f64,
    /// Fold diagnostic: min of orientation·h′ over a dense grid.
    pub min_directional_derivative: f64,
    /// EDF of the REML transport smooth.
    pub transport_edf: f64,
    /// Smoothing parameter reported by the penalized transport fit.
    pub smoothing_lambda: f64,
    /// Noise variance reported by the penalized transport fit.
    pub noise_variance: f64,
    /// Residual RMS reported by the penalized transport fit.
    pub residual_rms: f64,
    /// Number of paired observations the transport was fitted on.
    pub n_obs: usize,
    /// RMS composition defect of the triple ending at this two-hop map
    /// (populated by [`transport_ladder`] / [`LayerTransportReport::with_composition`]).
    pub composition_defect: Option<f64>,
    /// Max studentized composition defect against the composed bands.
    pub composition_max_studentized: Option<f64>,
    /// `wood_smooth_test` p-value of the defect smooth (H₀: defect ≡ 0 up to
    /// the target-chart gauge).
    pub composition_p_value: Option<f64>,
    /// Whether the gauge alignment chose the reflected target orientation.
    pub composition_gauge_reflected: Option<bool>,
}
994
995impl LayerTransportReport {
996    /// Merge a composition-law test into this (direct, two-hop) report.
997    pub fn with_composition(mut self, composition: &CompositionDefectReport) -> Self {
998        self.composition_defect = Some(composition.rms_defect);
999        self.composition_max_studentized = Some(composition.max_studentized_defect);
1000        self.composition_p_value = Some(composition.p_value);
1001        self.composition_gauge_reflected = Some(composition.gauge_reflected);
1002        self
1003    }
1004}
1005
1006/// Estimate the transport map `h: M_from → M_to` between two chart
1007/// coordinatizations of the same rows.
1008///
1009/// `coords_from[i]` and `coords_to[i]` must coordinatize the same observation
1010/// in the source and target charts. Circle coordinates are radians (any
1011/// branch; wrapped internally). See the module docs for the estimator.
1012pub fn fit_transport_map(
1013    coords_from: ArrayView1<'_, f64>,
1014    coords_to: ArrayView1<'_, f64>,
1015    topology_from: ChartTopology,
1016    topology_to: ChartTopology,
1017) -> Result<FittedTransport, String> {
1018    let n = coords_from.len();
1019    if coords_to.len() != n {
1020        return Err(format!(
1021            "layer transport coordinate lengths disagree: {} vs {}",
1022            n,
1023            coords_to.len()
1024        ));
1025    }
1026    if n < MIN_TRANSPORT_OBS {
1027        return Err(format!(
1028            "layer transport needs at least {MIN_TRANSPORT_OBS} paired observations, got {n}"
1029        ));
1030    }
1031    if coords_from
1032        .iter()
1033        .chain(coords_to.iter())
1034        .any(|v| !v.is_finite())
1035    {
1036        return Err("layer transport coordinates must all be finite".to_string());
1037    }
1038    topology_from.validate()?;
1039    topology_to.validate()?;
1040
1041    // --- degree + rotation gauge + unwrapped response -----------------------
1042    let (degree, degree_concentration, rotation_offset, response): (
1043        Option<i32>,
1044        Option<f64>,
1045        f64,
1046        Array1<f64>,
1047    ) = match (topology_from, topology_to) {
1048        (ChartTopology::Circle, ChartTopology::Circle) => {
1049            // Winding degree by circular concentration: over candidate
1050            // degrees d, the de-wound residual r_i(d) = θ_to − d·θ_from is
1051            // tightest (largest mean resultant length R_d) at the true
1052            // degree whenever the smooth residual stays inside half a turn.
1053            // This is the circular-correlation-maximizing degree estimate
1054            // the issue specifies, in resultant form.
1055            let mut best_degree = DEGREE_CANDIDATES[0];
1056            let mut best_r = f64::NEG_INFINITY;
1057            for &d in DEGREE_CANDIDATES.iter() {
1058                let residual: Vec<f64> = (0..n)
1059                    .map(|i| coords_to[i] - f64::from(d) * coords_from[i])
1060                    .collect();
1061                let r = resultant_length(&residual);
1062                if r > best_r {
1063                    best_r = r;
1064                    best_degree = d;
1065                }
1066            }
1067            let residual: Vec<f64> = (0..n)
1068                .map(|i| coords_to[i] - f64::from(best_degree) * coords_from[i])
1069                .collect();
1070            let mu = circular_mean(&residual);
1071            let response = Array1::from_iter(residual.iter().map(|&r| wrap_pi(r - mu)));
1072            (Some(best_degree), Some(best_r), mu, response)
1073        }
1074        (_, ChartTopology::Circle) => {
1075            // Interval domain, circular target: the domain is contractible so
1076            // the map is null-homotopic — no winding term. Unwrap the angular
1077            // response about its circular mean.
1078            let angles: Vec<f64> = coords_to.iter().copied().collect();
1079            let mu = circular_mean(&angles);
1080            let response = Array1::from_iter(angles.iter().map(|&a| wrap_pi(a - mu)));
1081            (None, None, mu, response)
1082        }
1083        (_, ChartTopology::Interval { .. }) => (None, None, 0.0, coords_to.to_owned()),
1084    };
1085
1086    // --- REML residual smooth on the source chart ---------------------------
1087    let basis = DomainBasis::build(topology_from, coords_from)?;
1088    let design = basis.value_rows(coords_from)?;
1089    let penalty = basis.penalty()?;
1090    let fit = fit_penalized_1d(
1091        &design,
1092        &penalty,
1093        response.view(),
1094        None,
1095        basis.penalty_rank(),
1096        false,
1097    )?;
1098
1099    // --- isometry defect under the empirical density -------------------------
1100    let slope = degree.map_or(0.0, f64::from);
1101    let deriv_rows = basis.derivative_rows(coords_from)?;
1102    let deriv = deriv_rows.dot(&fit.beta).mapv(|v| v + slope);
1103    let m = basis.num_basis();
1104    let mut defect = 0.0_f64;
1105    let mut grad = Array1::<f64>::zeros(m);
1106    for i in 0..n {
1107        let speed = deriv[i].abs();
1108        let gap = speed - 1.0;
1109        defect += gap * gap;
1110        let sgn = if deriv[i] >= 0.0 { 1.0 } else { -1.0 };
1111        for j in 0..m {
1112            grad[j] += 2.0 * gap * sgn * deriv_rows[[i, j]];
1113        }
1114    }
1115    defect /= n as f64;
1116    grad.mapv_inplace(|v| v / n as f64);
1117    let isometry_defect_se = grad.dot(&fit.covariance.dot(&grad)).max(0.0).sqrt();
1118
1119    // --- fold / orientation check on a dense grid ---------------------------
1120    let grid = domain_grid(topology_from, FOLD_CHECK_GRID);
1121    let grid_deriv = basis
1122        .derivative_rows(grid.view())?
1123        .dot(&fit.beta)
1124        .mapv(|v| v + slope);
1125    let orientation = if slope != 0.0 {
1126        slope.signum()
1127    } else {
1128        let mean = grid_deriv.iter().sum::<f64>() / grid_deriv.len() as f64;
1129        if mean < 0.0 { -1.0 } else { 1.0 }
1130    };
1131    let min_directional_derivative = grid_deriv
1132        .iter()
1133        .map(|&v| orientation * v)
1134        .fold(f64::INFINITY, f64::min);
1135    let topology_preserved = match (topology_from, topology_to) {
1136        (ChartTopology::Circle, ChartTopology::Circle) => {
1137            matches!(degree, Some(1) | Some(-1)) && min_directional_derivative > 0.0
1138        }
1139        (ChartTopology::Interval { .. }, ChartTopology::Interval { .. }) => {
1140            min_directional_derivative > 0.0
1141        }
1142        _ => false,
1143    };
1144
1145    Ok(FittedTransport {
1146        topology_from,
1147        topology_to,
1148        degree,
1149        degree_concentration,
1150        rotation_offset,
1151        beta: fit.beta,
1152        covariance: fit.covariance,
1153        smoothing_lambda: fit.lambda,
1154        edf: fit.edf,
1155        noise_variance: fit.sigma2,
1156        n_obs: n,
1157        isometry_defect: defect,
1158        isometry_defect_se,
1159        topology_preserved,
1160        min_directional_derivative,
1161        residual_rms: fit.residual_rms,
1162        basis,
1163    })
1164}
1165
1166/// Estimate the transport map between two layers and package the evidence.
1167pub fn fit_layer_transport(
1168    layer_from: usize,
1169    layer_to: usize,
1170    coords_from: ArrayView1<'_, f64>,
1171    coords_to: ArrayView1<'_, f64>,
1172    topology_from: ChartTopology,
1173    topology_to: ChartTopology,
1174) -> Result<LayerTransportReport, String> {
1175    Ok(
1176        fit_transport_map(coords_from, coords_to, topology_from, topology_to)?
1177            .report(layer_from, layer_to),
1178    )
1179}
1180
/// Composition-law test report for one triple `(h_ab, h_bc, h_ac)`.
///
/// Summarizes the gauge-aligned defect between the direct map `h_ac` and the
/// composed route `h_bc ∘ h_ab` evaluated on a source-chart grid.
#[derive(Debug, Clone)]
pub struct CompositionDefectReport {
    /// Number of grid points the defect was evaluated on.
    pub n_grid: usize,
    /// Rotation gauge applied to the composed route (circle targets).
    pub gauge_rotation: f64,
    /// Whether the reflected target orientation minimized the defect.
    pub gauge_reflected: bool,
    /// Mean of `|d(t)|` over the grid.
    pub mean_abs_defect: f64,
    /// Root-mean-square of `d(t)` over the grid.
    pub rms_defect: f64,
    /// Largest `|d(t)|` over the grid.
    pub max_abs_defect: f64,
    /// `max_t |d(t)| / band(t)` against the composed pointwise bands.
    pub max_studentized_defect: f64,
    /// Bonferroni p-value bound for the max studentized defect over all tested
    /// grid points.
    pub max_studentized_p_value: f64,
    /// EDF of the variance-weighted REML defect smooth.
    pub defect_edf: f64,
    /// Wood rank-truncated Wald statistic of the defect smooth.
    pub statistic: f64,
    /// Reference degrees of freedom paired with `statistic` (presumably as
    /// returned by `wood_smooth_test` — confirm against the constructor).
    pub ref_df: f64,
    /// `wood_smooth_test` p-value for H₀: the gauge-aligned defect is zero.
    pub p_value: f64,
}
1205
/// Monomial coefficients (ascending: `c[0] + c[1]·u + …`) of the polynomial
/// of degree `values.len() − 1` that interpolates `values` at the integer
/// abscissae `u = 0, 1, …, values.len() − 1`.
///
/// The interpolant is assembled in Newton forward-difference form,
/// `p(u) = Σ_k Δᵏf[0] · C(u, k)` with `C(u, k) = u(u−1)…(u−k+1)/k!`, and each
/// term is expanded into monomials as it is added. An empty input yields an
/// empty coefficient vector. The strict-monotonicity certificate uses this to
/// rebuild `h′` on a knot span from equally spaced samples.
fn monomial_from_equispaced(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    if n == 0 {
        return Vec::new();
    }

    // Leading entries Δ^k f[0] of the successive forward-difference rows.
    let mut leading = Vec::with_capacity(n);
    let mut row: Vec<f64> = values.to_vec();
    while !row.is_empty() {
        leading.push(row[0]);
        row = row.windows(2).map(|w| w[1] - w[0]).collect();
    }

    let mut coeffs = vec![0.0_f64; n];
    // `falling` holds the expanded product Π_{j<k}(u − j) in ascending monomial
    // order; dividing by the running k! turns it into C(u, k).
    let mut falling = vec![1.0_f64];
    let mut k_factorial = 1.0_f64;
    for (k, &delta) in leading.iter().enumerate() {
        if k > 0 {
            k_factorial *= k as f64;
        }
        let weight = delta / k_factorial;
        for (c, &p) in coeffs.iter_mut().zip(falling.iter()) {
            *c += weight * p;
        }
        // Advance the product: falling ← falling · (u − k).
        if k + 1 < n {
            let shift = k as f64;
            let mut widened = vec![0.0_f64; falling.len() + 1];
            for (i, &p) in falling.iter().enumerate() {
                widened[i + 1] += p; // u · falling
                widened[i] -= shift * p; // −k · falling
            }
            falling = widened;
        }
    }
    coeffs
}
1258
/// Evaluate the polynomial with ascending monomial coefficients `coeffs` at
/// `u` using Horner's scheme. An empty coefficient slice evaluates to `0.0`.
fn eval_monomial(coeffs: &[f64], u: f64) -> f64 {
    let mut acc = 0.0_f64;
    // Highest-order coefficient first.
    for &c in coeffs.iter().rev() {
        acc = acc * u + c;
    }
    acc
}
1263
1264/// Interior critical points (roots of the derivative) of an ascending monomial
1265/// polynomial, in the local `u` coordinate. Returns the closed-form roots for
1266/// degree ≤ 2 derivatives (i.e. cubic-spline pieces, the production path);
1267/// higher-degree derivatives fall back to a robust bisection root-isolation so
1268/// the certificate stays exact-enough (a missed extremum can only make the
1269/// certificate stricter, never falsely accept a fold, because the endpoints and
1270/// every sign change found are still checked). For the cubic transport splines
1271/// the polynomial is quadratic and this is the single vertex.
1272fn monomial_critical_points(coeffs: &[f64]) -> Vec<f64> {
1273    // Derivative coefficients: d/du Σ c_k u^k = Σ k·c_k u^{k−1}.
1274    let n = coeffs.len();
1275    if n <= 1 {
1276        return Vec::new();
1277    }
1278    let deriv: Vec<f64> = (1..n).map(|k| k as f64 * coeffs[k]).collect();
1279    // deriv is ascending of length n−1 (degree n−2).
1280    match deriv.len() {
1281        0 => Vec::new(),
1282        1 => Vec::new(), // constant derivative: no critical point
1283        2 => {
1284            // Linear b + a·u = 0 (a = deriv[1]).
1285            let (b, a) = (deriv[0], deriv[1]);
1286            if a.abs() <= f64::MIN_POSITIVE {
1287                Vec::new()
1288            } else {
1289                vec![-b / a]
1290            }
1291        }
1292        3 => {
1293            // Quadratic c + b·u + a·u² = 0.
1294            let (c, b, a) = (deriv[0], deriv[1], deriv[2]);
1295            if a.abs() <= f64::MIN_POSITIVE {
1296                if b.abs() <= f64::MIN_POSITIVE {
1297                    Vec::new()
1298                } else {
1299                    vec![-c / b]
1300                }
1301            } else {
1302                let disc = b * b - 4.0 * a * c;
1303                if disc < 0.0 {
1304                    Vec::new()
1305                } else {
1306                    let s = disc.sqrt();
1307                    vec![(-b + s) / (2.0 * a), (-b - s) / (2.0 * a)]
1308                }
1309            }
1310        }
1311        _ => {
1312            // General fallback: scan for sign changes of the derivative on a
1313            // dense [0, deg] grid and bisect each bracket. Conservative.
1314            let lo = 0.0;
1315            let hi = (coeffs.len() - 1) as f64;
1316            let steps = 256;
1317            let mut roots = Vec::new();
1318            let f = |u: f64| eval_monomial(&deriv, u);
1319            let mut prev_u = lo;
1320            let mut prev_v = f(lo);
1321            for i in 1..=steps {
1322                let u = lo + (hi - lo) * i as f64 / steps as f64;
1323                let v = f(u);
1324                if prev_v == 0.0 {
1325                    roots.push(prev_u);
1326                } else if prev_v * v < 0.0 {
1327                    let (mut a, mut b) = (prev_u, u);
1328                    for _ in 0..60 {
1329                        let m = 0.5 * (a + b);
1330                        if f(a) * f(m) <= 0.0 {
1331                            b = m;
1332                        } else {
1333                            a = m;
1334                        }
1335                    }
1336                    roots.push(0.5 * (a + b));
1337                }
1338                prev_u = u;
1339                prev_v = v;
1340            }
1341            roots
1342        }
1343    }
1344}
1345
1346/// Uniform evaluation grid over a chart domain.
1347fn domain_grid(topology: ChartTopology, n: usize) -> Array1<f64> {
1348    match topology {
1349        ChartTopology::Circle => Array1::from_iter((0..n).map(|i| TAU * i as f64 / n as f64)),
1350        ChartTopology::Interval { lo, hi } => {
1351            Array1::from_iter((0..n).map(|i| lo + (hi - lo) * i as f64 / (n - 1).max(1) as f64))
1352        }
1353    }
1354}
1355
1356/// Test the composition law `h_ac ≟ h_bc ∘ h_ab` on `n_grid` points.
1357///
1358/// The defect `d(t) = h_ac(t) ⊖ (h_bc ∘ h_ab)(t)` (circular difference on
1359/// circle targets) is first quotiented by the certified isometry gauge of the
1360/// TARGET chart only — the source gauge cancels because both routes consume
1361/// identical source coordinates (double-coset estimand; see module docs):
1362/// rotation fixed at the circular mean of the defect, reflection chosen as
1363/// the orientation with smaller squared defect. The aligned defect is then
1364/// (a) studentized pointwise against the composed delta-method bands
1365/// `var(h_ac) + var(h_bc) + h_bc′² var(h_ab)` (the three maps are fitted from
1366/// disjoint response pairs; cross-correlations through shared rows are
1367/// neglected), with the max studentized defect as the headline statistic, and
1368/// (b) smoothed by a variance-weighted known-scale REML fit whose coefficients
1369/// feed [`wood_smooth_test`] for the calibrated p-value.
1370pub fn composition_defect(
1371    h_ab: &FittedTransport,
1372    h_bc: &FittedTransport,
1373    h_ac: &FittedTransport,
1374    n_grid: usize,
1375) -> Result<CompositionDefectReport, String> {
1376    if h_ab.topology_from != h_ac.topology_from
1377        || h_ab.topology_to != h_bc.topology_from
1378        || h_bc.topology_to != h_ac.topology_to
1379    {
1380        return Err("composition defect requires chart-compatible transports: \
1381             h_ab: A→B, h_bc: B→C, h_ac: A→C"
1382            .to_string());
1383    }
1384    if n_grid < MIN_TRANSPORT_OBS {
1385        return Err(format!(
1386            "composition defect grid must have at least {MIN_TRANSPORT_OBS} points, got {n_grid}"
1387        ));
1388    }
1389
1390    let grid = domain_grid(h_ab.topology_from, n_grid);
1391    let (direct, var_direct) = h_ac.eval_with_variance(grid.view())?;
1392    let (mid, var_mid) = h_ab.eval_with_variance(grid.view())?;
1393    let (composed, var_bc) = h_bc.eval_with_variance(mid.view())?;
1394    let mid_slope = h_bc.derivative(mid.view())?;
1395    let mut variance = Array1::<f64>::zeros(n_grid);
1396    for i in 0..n_grid {
1397        variance[i] = var_direct[i] + var_bc[i] + mid_slope[i] * mid_slope[i] * var_mid[i];
1398    }
1399
1400    // --- target-chart gauge alignment (rotation + reflection only) ----------
1401    let circle_target = matches!(h_ac.topology_to, ChartTopology::Circle);
1402    let mut gauge_reflected = false;
1403    let mut gauge_rotation = 0.0_f64;
1404    let mut defect = Array1::<f64>::zeros(n_grid);
1405    let mut best_sse = f64::INFINITY;
1406    for reflected in [false, true] {
1407        let composed_oriented: Array1<f64> = match (h_ac.topology_to, reflected) {
1408            (_, false) => composed.clone(),
1409            (ChartTopology::Circle, true) => composed.mapv(|v| wrap_tau(-v)),
1410            (ChartTopology::Interval { lo, hi }, true) => composed.mapv(|v| lo + hi - v),
1411        };
1412        let (rotation, candidate): (f64, Array1<f64>) = if circle_target {
1413            let raw: Vec<f64> = (0..n_grid)
1414                .map(|i| wrap_pi(direct[i] - composed_oriented[i]))
1415                .collect();
1416            let rot = circular_mean(&raw);
1417            (
1418                rot,
1419                Array1::from_iter(raw.iter().map(|&d| wrap_pi(d - rot))),
1420            )
1421        } else {
1422            (
1423                0.0,
1424                Array1::from_iter((0..n_grid).map(|i| direct[i] - composed_oriented[i])),
1425            )
1426        };
1427        let sse = candidate.iter().map(|&d| d * d).sum::<f64>();
1428        if sse < best_sse {
1429            best_sse = sse;
1430            gauge_reflected = reflected;
1431            gauge_rotation = rotation;
1432            defect = candidate;
1433        }
1434    }
1435
1436    // --- pointwise studentization against the composed bands ----------------
1437    let max_var = variance.iter().copied().fold(0.0_f64, f64::max);
1438    let var_floor = (max_var * 1e-10).max(f64::MIN_POSITIVE);
1439    let mut max_abs = 0.0_f64;
1440    let mut sum_abs = 0.0_f64;
1441    let mut sum_sq = 0.0_f64;
1442    let mut max_z = 0.0_f64;
1443    for i in 0..n_grid {
1444        let d = defect[i];
1445        let a = d.abs();
1446        max_abs = max_abs.max(a);
1447        sum_abs += a;
1448        sum_sq += d * d;
1449        let z = a / variance[i].max(var_floor).sqrt();
1450        max_z = max_z.max(z);
1451    }
1452    let mean_abs_defect = sum_abs / n_grid as f64;
1453    let rms_defect = (sum_sq / n_grid as f64).sqrt();
1454
1455    // --- variance-weighted REML defect smooth + Wood Wald test ---------------
1456    let basis = DomainBasis::build(h_ab.topology_from, grid.view())?;
1457    let design = basis.value_rows(grid.view())?;
1458    let penalty = basis.penalty()?;
1459    let weights = variance.mapv(|v| 1.0 / v.max(var_floor));
1460    let fit = fit_penalized_1d(
1461        &design,
1462        &penalty,
1463        defect.view(),
1464        Some(weights.view()),
1465        basis.penalty_rank(),
1466        true,
1467    )?;
1468    let m = basis.num_basis();
1469    let test = wood_smooth_test(SmoothTestInput {
1470        beta: fit.beta.view(),
1471        covariance: &fit.covariance,
1472        influence_matrix: Some(&fit.influence),
1473        coeff_range: 0..m,
1474        edf: fit.edf,
1475        nullspace_dim: 0,
1476        residual_df: (n_grid as f64 - fit.edf).max(1.0),
1477        scale: SmoothTestScale::Known,
1478    })
1479    .ok_or_else(|| "composition defect smooth test degenerated".to_string())?;
1480
1481    // Bonferroni bound for the max studentized defect over the actual grid:
1482    // valid for arbitrary dependence among the tested pointwise contrasts.
1483    let normal =
1484        Normal::new(0.0, 1.0).map_err(|e| format!("standard normal construction failed: {e}"))?;
1485    let pointwise: f64 = (2.0 * (1.0 - normal.cdf(max_z))).clamp(0.0, 1.0);
1486    let max_studentized_p_value = (n_grid as f64 * pointwise).min(1.0);
1487
1488    Ok(CompositionDefectReport {
1489        n_grid,
1490        gauge_rotation,
1491        gauge_reflected,
1492        mean_abs_defect,
1493        rms_defect,
1494        max_abs_defect: max_abs,
1495        max_studentized_defect: max_z,
1496        max_studentized_p_value,
1497        defect_edf: fit.edf,
1498        statistic: test.statistic,
1499        ref_df: test.ref_df,
1500        p_value: test.p_value,
1501    })
1502}
1503
/// Full transport report for a ladder of layers: every adjacent map plus
/// every two-hop map with its composition-law test attached.
///
/// As assembled by `transport_ladder`, entry `k` of `adjacent` covers rungs
/// `k → k + 1` of the ladder and entry `k` of `two_hop` covers rungs
/// `k → k + 2`, so a ladder of `depth` layers yields `depth − 1` adjacent
/// reports and `depth − 2` two-hop reports.
#[derive(Debug, Clone)]
pub struct TransportLadderReport {
    /// `h_{l→l+1}` for each consecutive pair of layers, in ladder order.
    pub adjacent: Vec<LayerTransportReport>,
    /// `h_{l→l+2}` for each consecutive triple of layers, in ladder order,
    /// with the composition test against the composed adjacent pair merged
    /// in.
    pub two_hop: Vec<LayerTransportReport>,
}
1514
1515/// Fit the whole transport ladder: adjacent maps, two-hop maps, and the
1516/// composition law `h_{l→l+2} ≟ h_{l+1→l+2} ∘ h_{l→l+1}` per triple.
1517///
1518/// `layers[k]`, `coords[k]`, `topologies[k]` describe layer `k` of the
1519/// ladder; all coordinate vectors must index the same rows.
1520pub fn transport_ladder(
1521    layers: &[usize],
1522    coords: &[Array1<f64>],
1523    topologies: &[ChartTopology],
1524) -> Result<TransportLadderReport, String> {
1525    let depth = layers.len();
1526    if coords.len() != depth || topologies.len() != depth {
1527        return Err(format!(
1528            "transport ladder inputs disagree: {depth} layers, {} coordinate vectors, {} topologies",
1529            coords.len(),
1530            topologies.len()
1531        ));
1532    }
1533    if depth < 2 {
1534        return Err("transport ladder needs at least two layers".to_string());
1535    }
1536
1537    let mut adjacent_fits: Vec<FittedTransport> = Vec::with_capacity(depth - 1);
1538    let mut adjacent: Vec<LayerTransportReport> = Vec::with_capacity(depth - 1);
1539    for k in 0..depth - 1 {
1540        let fit = fit_transport_map(
1541            coords[k].view(),
1542            coords[k + 1].view(),
1543            topologies[k],
1544            topologies[k + 1],
1545        )
1546        .map_err(|e| {
1547            format!(
1548                "adjacent transport {}→{} failed: {e}",
1549                layers[k],
1550                layers[k + 1]
1551            )
1552        })?;
1553        adjacent.push(fit.report(layers[k], layers[k + 1]));
1554        adjacent_fits.push(fit);
1555    }
1556
1557    let mut two_hop: Vec<LayerTransportReport> = Vec::with_capacity(depth.saturating_sub(2));
1558    for k in 0..depth.saturating_sub(2) {
1559        let direct = fit_transport_map(
1560            coords[k].view(),
1561            coords[k + 2].view(),
1562            topologies[k],
1563            topologies[k + 2],
1564        )
1565        .map_err(|e| {
1566            format!(
1567                "two-hop transport {}→{} failed: {e}",
1568                layers[k],
1569                layers[k + 2]
1570            )
1571        })?;
1572        let composition = composition_defect(
1573            &adjacent_fits[k],
1574            &adjacent_fits[k + 1],
1575            &direct,
1576            DEFAULT_COMPOSITION_GRID,
1577        )
1578        .map_err(|e| {
1579            format!(
1580                "composition test {}→{}→{} failed: {e}",
1581                layers[k],
1582                layers[k + 1],
1583                layers[k + 2]
1584            )
1585        })?;
1586        two_hop.push(
1587            direct
1588                .report(layers[k], layers[k + 2])
1589                .with_composition(&composition),
1590        );
1591    }
1592
1593    Ok(TransportLadderReport { adjacent, two_hop })
1594}
1595
#[cfg(test)]
mod invert_tests {
    use super::*;
    use ndarray::Array1;

    /// Interval chart `[lo, hi]`.
    fn interval(lo: f64, hi: f64) -> ChartTopology {
        ChartTopology::Interval { lo, hi }
    }

    /// `n` equally spaced samples `i / (n − 1)`, running from 0 to 1 inclusive.
    fn equispaced_unit(n: usize) -> Array1<f64> {
        Array1::from_iter((0..n).map(|i| i as f64 / (n as f64 - 1.0)))
    }

    /// `n` equally spaced angles `2π·i / n`, starting at 0 and stopping one
    /// step short of `2π`.
    fn equispaced_circle(n: usize) -> Array1<f64> {
        Array1::from_iter((0..n).map(|i| TAU * i as f64 / n as f64))
    }

    /// One-element vector, for single-target `invert` calls.
    fn single(value: f64) -> Array1<f64> {
        Array1::from_elem(1, value)
    }

    /// Fit `from ↦ to` as a `[0, 1]` → `[0, 1]` interval transport.
    fn fit_unit_interval(from: &Array1<f64>, to: &Array1<f64>) -> FittedTransport {
        fit_transport_map(
            from.view(),
            to.view(),
            interval(0.0, 1.0),
            interval(0.0, 1.0),
        )
        .expect("fit")
    }

    /// Fit `from ↦ to` as a circle → circle transport.
    fn fit_circle(from: &Array1<f64>, to: &Array1<f64>) -> FittedTransport {
        fit_transport_map(
            from.view(),
            to.view(),
            ChartTopology::Circle,
            ChartTopology::Circle,
        )
        .expect("fit")
    }

    /// Smallest value of `orientation · v` over `values`.
    fn min_oriented<'a>(values: impl Iterator<Item = &'a f64>, orientation: f64) -> f64 {
        values
            .map(|&v| orientation * v)
            .fold(f64::INFINITY, f64::min)
    }

    #[test]
    fn invert_round_trips_interval_transport() {
        // Strictly increasing warp h(t) = t + 0.25·sin(2πt)/(2π) on [0, 1];
        // h′(t) = 1 + 0.25·cos(2πt) lies in [0.75, 1.25], well away from zero.
        let from = equispaced_unit(64);
        let to = from.mapv(|t| t + 0.25 * (TAU * t).sin() / TAU);
        let ft = fit_unit_interval(&from, &to);
        assert!(
            ft.topology_preserved,
            "monotone warp should preserve topology"
        );

        // invert ∘ eval must reproduce the probe points ...
        let probe = Array1::from_iter((1..10).map(|i| i as f64 / 10.0));
        let fwd = ft.eval(probe.view()).expect("eval");
        let back = ft.invert(fwd.view()).expect("invert");
        for (i, &t) in probe.iter().enumerate() {
            let recovered = back[i];
            assert!(
                (recovered - t).abs() < 1e-6,
                "round-trip failed: t={} back={}",
                t,
                recovered
            );
        }
        // ... and eval ∘ invert must reproduce the forward images.
        let re_eval = ft.eval(back.view()).expect("eval");
        for (i, &image) in fwd.iter().enumerate() {
            assert!((re_eval[i] - image).abs() < 1e-9);
        }
    }

    #[test]
    fn invert_round_trips_decreasing_interval_transport() {
        // Orientation-reversing map h(t) = 1 − 0.5·t − 0.5·t² on [0, 1], with
        // h′(t) = −0.5 − t ≤ −0.5 throughout.
        let from = equispaced_unit(64);
        let to = from.mapv(|t| 1.0 - 0.5 * t - 0.5 * t * t);
        let ft = fit_unit_interval(&from, &to);
        assert!(ft.topology_preserved);

        let probe = Array1::from_iter((1..10).map(|i| i as f64 / 10.0));
        let fwd = ft.eval(probe.view()).expect("eval");
        let back = ft.invert(fwd.view()).expect("invert");
        for (i, &t) in probe.iter().enumerate() {
            let recovered = back[i];
            assert!((recovered - t).abs() < 1e-6, "t={} back={}", t, recovered);
        }
    }

    #[test]
    fn invert_round_trips_circle_transport() {
        // Degree-1 cover of the circle: rotation by 0.3 plus a 0.2·sin wiggle
        // whose slope 1 + 0.2·cos(t) never changes sign.
        let from = equispaced_circle(128);
        let to = from.mapv(|t| wrap_tau(t + 0.3 + 0.2 * t.sin()));
        let ft = fit_circle(&from, &to);
        assert!(ft.topology_preserved, "degree {:?}", ft.degree);

        let probe = Array1::from_iter((0..7).map(|i| TAU * (i as f64 + 0.5) / 7.0));
        let fwd = ft.eval(probe.view()).expect("eval");
        let back = ft.invert(fwd.view()).expect("invert");
        for (i, &t) in probe.iter().enumerate() {
            // Angles are compared modulo 2π.
            let d = wrap_pi(back[i] - t).abs();
            assert!(d < 1e-5, "probe={} back={} d={}", t, back[i], d);
        }
    }

    #[test]
    fn invert_rejects_target_outside_interval_image() {
        // `to = 0.5·from` has image ~[0, 0.5], so the target 0.9 has no
        // preimage.
        let from = equispaced_unit(32);
        let to = from.mapv(|t| 0.5 * t);
        let ft = fit_unit_interval(&from, &to);
        assert!(ft.invert(single(0.9).view()).is_err());
    }

    /// Hand-assemble a `FittedTransport` on `[lo, hi]` whose coefficients come
    /// from an essentially unpenalized least-squares fit of `target` on the
    /// domain basis, so that sharp features of `target` (such as a narrow
    /// fold) are kept rather than smoothed out. Every field that is not a
    /// fitted quantity is filled with a fixed placeholder.
    fn fitted_from_target(
        from: ArrayView1<'_, f64>,
        target: ArrayView1<'_, f64>,
        lo: f64,
        hi: f64,
    ) -> FittedTransport {
        let basis = DomainBasis::build(interval(lo, hi), from).expect("basis");
        let design = basis.value_rows(from).expect("design");
        let m = design.ncols();

        // Normal equations XᵀX β = Xᵀy; the ridge (1e-10 × the largest
        // diagonal magnitude, at least 1e-10) is there for conditioning only.
        let mut gram = design.t().dot(&design);
        let moment = design.t().dot(&target);
        let ridge = 1e-10 * gram.diag().iter().map(|v| v.abs()).fold(1.0_f64, f64::max);
        for entry in gram.diag_mut() {
            *entry += ridge;
        }

        // Solve through the eigendecomposition: β = V · diag(1/λ) · Vᵀ Xᵀy,
        // with eigenvalues floored at the smallest positive f64.
        let (evals, evecs) = gram.eigh(Side::Lower).expect("eigh");
        let projected = evecs.t().dot(&moment);
        let scaled: Vec<f64> = (0..m)
            .map(|i| projected[i] / evals[i].max(f64::MIN_POSITIVE))
            .collect();
        let beta = Array1::from_iter(
            (0..m).map(|j| (0..m).fold(0.0_f64, |acc, i| acc + evecs[[j, i]] * scaled[i])),
        );

        FittedTransport {
            topology_from: interval(lo, hi),
            topology_to: interval(lo, hi),
            degree: None,
            degree_concentration: None,
            rotation_offset: 0.0,
            beta,
            covariance: Array2::<f64>::zeros((m, m)),
            smoothing_lambda: 0.0,
            edf: 0.0,
            noise_variance: 1.0,
            n_obs: from.len(),
            isometry_defect: 0.0,
            isometry_defect_se: 0.0,
            topology_preserved: true,
            min_directional_derivative: 1.0,
            residual_rms: 0.0,
            basis,
        }
    }

    /// Between-grid fold reproducer. The target h(t) = (t−0.5)³/3 − ε²·t with
    /// ε = 0.4/511 has h′(t) = (t−0.5)² − ε², which is negative only for
    /// |t − 0.5| < ε. The test first asserts that the `FOLD_CHECK_GRID`
    /// sampled diagnostic sees no fold while a 5120-point grid does, then
    /// requires `invert` to refuse the map anyway.
    #[test]
    fn invert_rejects_between_grid_fold() {
        let from = equispaced_unit(256);
        let eps = 0.4 / 511.0;
        let target = from.mapv(|t| (t - 0.5).powi(3) / 3.0 - eps * eps * t);
        let mut ft = fitted_from_target(from.view(), target.view(), 0.0, 1.0);

        // Sampled diagnostic on the certification grid: orientation is the
        // sign of the mean derivative, then the minimum oriented slope.
        let grid = domain_grid(interval(0.0, 1.0), FOLD_CHECK_GRID);
        let grid_d = ft.derivative(grid.view()).expect("grid deriv");
        let mean = grid_d.iter().sum::<f64>() / grid_d.len() as f64;
        let orientation = if mean < 0.0 { -1.0 } else { 1.0 };
        let min_grid = min_oriented(grid_d.iter(), orientation);

        // A much finer grid lands samples inside the fold.
        let dense = Array1::from_iter((0..5120).map(|i| i as f64 / 5119.0));
        let dense_d = ft.derivative(dense.view()).expect("dense deriv");
        let min_dense = min_oriented(dense_d.iter(), orientation);

        ft.topology_preserved = min_grid > 0.0;
        ft.min_directional_derivative = min_grid;
        assert!(
            min_grid > 0.0 && min_dense < 0.0,
            "fixture must hide a between-grid fold: min on 512-grid={min_grid}, \
             min on dense grid={min_dense}"
        );

        // The sampled diagnostic passed; inversion must still be refused.
        let res = ft.invert(single(0.0).view());
        assert!(
            res.is_err(),
            "between-grid fold must be rejected by the span-exact certificate \
             (topology_preserved={}, min_grid={min_grid}, min_dense={min_dense})",
            ft.topology_preserved
        );
    }

    #[test]
    fn invert_rejects_non_finite_targets() {
        let from = equispaced_unit(64);
        let to = from.mapv(|t| 0.5 * t);
        let ft = fit_unit_interval(&from, &to);
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                ft.invert(single(bad).view()).is_err(),
                "non-finite target {bad} must be rejected"
            );
        }
    }

    #[test]
    fn invert_image_tolerance_is_scale_aware() {
        // `to = 1e-8·from` has image ~[0, 1e-8]. A target 5% past the upper
        // end overshoots by only 5e-10, which an absolute tolerance near 1e-9
        // would swallow; it must be rejected rather than clamped.
        let from = equispaced_unit(64);
        let scale = 1.0e-8;
        let to = from.mapv(|t| scale * t);
        let ft = fit_unit_interval(&from, &to);

        let outside = 1.05e-8;
        assert!(
            ft.invert(single(outside).view()).is_err(),
            "target {outside} is 5% outside the [0, {scale}] image and must be rejected"
        );

        // A target well inside the image must still round-trip.
        let inside = 0.5e-8;
        let t = ft.invert(single(inside).view()).expect("invert inside");
        let re = ft.eval(t.view()).expect("eval");
        assert!((re[0] - inside).abs() < 1e-3 * scale);
    }

    #[test]
    fn invert_round_trips_degree_minus_one_circle() {
        // Degree −1 cover of the circle: a reflection shifted by 0.4 plus a
        // 0.15·sin wiggle whose slope −1 + 0.15·cos(t) never changes sign.
        let from = equispaced_circle(128);
        let to = from.mapv(|t| wrap_tau(-t + 0.4 + 0.15 * t.sin()));
        let ft = fit_circle(&from, &to);
        assert_eq!(ft.degree, Some(-1), "expected a degree −1 cover");
        assert!(ft.topology_preserved, "degree {:?}", ft.degree);

        let probe = Array1::from_iter((0..7).map(|i| TAU * (i as f64 + 0.5) / 7.0));
        let fwd = ft.eval(probe.view()).expect("eval");
        let back = ft.invert(fwd.view()).expect("invert");
        for (i, &t) in probe.iter().enumerate() {
            let d = wrap_pi(back[i] - t).abs();
            assert!(d < 1e-5, "probe={} back={} d={}", t, back[i], d);
        }
    }

    #[test]
    fn invert_round_trips_circle_seam_and_interval_endpoints() {
        // Circle: targets sitting on, just above and just below the 0/2π seam.
        let from = equispaced_circle(128);
        let to = from.mapv(|t| wrap_tau(t + 0.3 + 0.2 * t.sin()));
        let ft = fit_circle(&from, &to);
        assert!(ft.topology_preserved);
        for seam in [1e-9, TAU - 1e-9, 0.0] {
            let t = ft.invert(single(seam).view()).expect("invert seam");
            let re = ft.eval(t.view()).expect("eval");
            let d = wrap_pi(re[0] - wrap_tau(seam)).abs();
            assert!(d < 1e-6, "seam={seam} re={} d={d}", re[0]);
        }

        // Interval: the images of the two domain endpoints must invert back
        // into [0, 1] and re-evaluate to themselves.
        let ifrom = equispaced_unit(64);
        let ito = ifrom.mapv(|t| t + 0.25 * (TAU * t).sin() / TAU);
        let ift = fit_unit_interval(&ifrom, &ito);
        let raw_lo = ift.raw_at(0.0).expect("raw lo");
        let raw_hi = ift.raw_at(1.0).expect("raw hi");
        for edge in [raw_lo, raw_hi] {
            let t = ift.invert(single(edge).view()).expect("invert endpoint");
            assert!(
                (-1e-9..=1.0 + 1e-9).contains(&t[0]),
                "endpoint t={}",
                t[0]
            );
            let re = ift.eval(t.view()).expect("eval");
            assert!((re[0] - edge).abs() < 1e-6, "edge={edge} re={}", re[0]);
        }
    }

    #[test]
    fn monomial_reconstruction_is_exact_for_quadratic() {
        // Sample the quadratic 0.7 − 1.3u + 2.1u² at u = 0, 1, 2 and check that
        // the reconstruction returns the same monomial coefficients.
        let coeffs_true = [0.7_f64, -1.3, 2.1];
        let mut values: Vec<f64> = Vec::with_capacity(3);
        for i in 0..3 {
            values.push(eval_monomial(&coeffs_true, i as f64));
        }
        let recon = monomial_from_equispaced(&values);
        for (a, b) in recon.iter().zip(coeffs_true.iter()) {
            assert!((a - b).abs() < 1e-12, "recon {a} vs {b}");
        }

        // The only critical point is the vertex, at u = 1.3 / (2·2.1).
        let crit = monomial_critical_points(&recon);
        assert_eq!(crit.len(), 1);
        assert!((crit[0] - 1.3 / 4.2).abs() < 1e-12);
    }
}