1use gam_linalg::faer_ndarray::FaerEigh;
62use gam_terms::inference::smooth_test::{SmoothTestInput, SmoothTestScale, wood_smooth_test};
63use gam_terms::basis::{
64 BasisOptions, Dense, KnotSource, PeriodicBSplineBasisSpec, build_periodic_bspline_basis_1d,
65 create_basis, create_cyclic_difference_penalty_matrix, create_difference_penalty_matrix,
66 periodic_bspline_first_derivative_nd,
67};
68use crate::chart_canonicalization::CanonicalChartTopology;
69use faer::Side;
70use ndarray::{Array1, Array2, ArrayView1, Axis};
71use statrs::distribution::{ContinuousCDF, Normal};
72use std::f64::consts::{PI, TAU};
73
/// Polynomial degree of the B-spline bases used for the smooth part of every transport.
const TRANSPORT_SPLINE_DEGREE: usize = 3;
/// Order of the (cyclic or open) difference penalty applied to the spline coefficients.
const TRANSPORT_PENALTY_ORDER: usize = 2;
/// Minimum number of paired observations `fit_transport_map` accepts; also the minimum
/// grid size accepted by `composition_defect`.
const MIN_TRANSPORT_OBS: usize = 16;
/// Target observations per basis function when sizing a basis (`n / OBS_PER_BASIS`).
const OBS_PER_BASIS: usize = 8;
/// Lower clamp on the periodic (circle-domain) basis size.
const MIN_PERIODIC_BASIS: usize = 8;
/// Upper clamp on the periodic (circle-domain) basis size.
const MAX_PERIODIC_BASIS: usize = 20;
/// Lower clamp on the number of internal knots for open (interval-domain) bases.
const MIN_OPEN_INTERNAL_KNOTS: usize = 4;
/// Upper clamp on the number of internal knots for open (interval-domain) bases.
const MAX_OPEN_INTERNAL_KNOTS: usize = 12;
/// Integer winding degrees tried when fitting a circle→circle transport.
const DEGREE_CANDIDATES: [i32; 5] = [-2, -1, 0, 1, 2];
/// Number of domain grid points on which `min_directional_derivative` is evaluated.
const FOLD_CHECK_GRID: usize = 512;
/// Default grid size for composition checks.
/// NOTE(review): not referenced in the visible code; presumably the recommended
/// `n_grid` for `composition_defect` — confirm against callers.
pub const DEFAULT_COMPOSITION_GRID: usize = 256;
/// Number of log-spaced λ values in the coarse smoothing-parameter search.
const REML_LAMBDA_GRID_POINTS: usize = 41;
/// Golden-section refinement iterations (in log λ) after the coarse grid search.
const REML_GOLDEN_ITERATIONS: usize = 40;
/// Half-width, in decades, of the λ search range around the mean XᵀWX diagonal.
const REML_LAMBDA_SPAN_DECADES: f64 = 8.0;
100
/// Topology of a one-dimensional chart coordinate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ChartTopology {
    /// Angular coordinate; this module reduces such values into `[0, 2π)` with `wrap_tau`.
    Circle,
    /// Bounded coordinate on `[lo, hi]`; `validate` requires finite bounds with `lo < hi`.
    Interval { lo: f64, hi: f64 },
}
109
110impl ChartTopology {
111 pub fn name(&self) -> &'static str {
113 match self {
114 ChartTopology::Circle => "circle",
115 ChartTopology::Interval { .. } => "interval",
116 }
117 }
118
119 fn validate(&self) -> Result<(), String> {
120 match *self {
121 ChartTopology::Circle => Ok(()),
122 ChartTopology::Interval { lo, hi } => {
123 if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
124 Err(format!(
125 "interval chart bounds must be finite and ordered; got [{lo}, {hi}]"
126 ))
127 } else {
128 Ok(())
129 }
130 }
131 }
132 }
133}
134
135impl From<&CanonicalChartTopology> for ChartTopology {
146 fn from(src: &CanonicalChartTopology) -> Self {
147 match src {
148 CanonicalChartTopology::Circle { .. } => ChartTopology::Circle,
149 CanonicalChartTopology::Interval => ChartTopology::Interval { lo: 0.0, hi: 1.0 },
150 }
151 }
152}
153
154impl From<CanonicalChartTopology> for ChartTopology {
155 fn from(src: CanonicalChartTopology) -> Self {
156 ChartTopology::from(&src)
157 }
158}
159
/// Reduce an angle into `[0, 2π)` using the Euclidean remainder.
///
/// As documented for `f64::rem_euclid`, rounding can make the result exactly `2π`
/// for inputs that are tiny negative numbers.
fn wrap_tau(x: f64) -> f64 {
    f64::rem_euclid(x, TAU)
}
164
/// Reduce an angle into the half-open range `(-π, π]`.
fn wrap_pi(x: f64) -> f64 {
    // Shift so the Euclidean remainder lands in roughly [-π, π) …
    let shifted = f64::rem_euclid(x + PI, TAU) - PI;
    // … then fold the lower endpoint -π onto +π.
    if shifted > -PI {
        shifted
    } else {
        shifted + TAU
    }
}
170
/// Circular (directional) mean of a set of angles, in `(-π, π]`.
///
/// Returns `0.0` when the resultant vector is numerically zero, i.e. its length is at
/// most `ε · max(len, 1)`. This covers empty input and perfectly balanced angle sets,
/// where the mean direction is undefined.
fn circular_mean(angles: &[f64]) -> f64 {
    let (sin_sum, cos_sum) = angles
        .iter()
        .fold((0.0_f64, 0.0_f64), |(s, c), &a| (s + a.sin(), c + a.cos()));
    let degenerate_cutoff = f64::EPSILON * angles.len().max(1) as f64;
    if sin_sum.hypot(cos_sum) <= degenerate_cutoff {
        return 0.0;
    }
    sin_sum.atan2(cos_sum)
}
185
/// Mean resultant length `R̄ = |Σ e^{iθ}| / n` of a set of angles, in `[0, 1]`.
///
/// `1` means all angles coincide; values near `0` mean no preferred direction.
/// Empty input yields `0.0`.
fn resultant_length(angles: &[f64]) -> f64 {
    let count = angles.len();
    if count == 0 {
        return 0.0;
    }
    let (sin_sum, cos_sum) = angles
        .iter()
        .fold((0.0_f64, 0.0_f64), |(s, c), &a| (s + a.sin(), c + a.cos()));
    sin_sum.hypot(cos_sum) / count as f64
}
199
/// Spline basis on the *source* chart of a transport map.
#[derive(Debug, Clone)]
enum DomainBasis {
    /// Periodic B-spline basis for circle domains (built with origin `0` and period `2π`).
    Periodic(PeriodicBSplineBasisSpec),
    /// Open (clamped) B-spline basis for interval domains, described by its full knot
    /// vector and polynomial degree.
    Open { knots: Array1<f64>, degree: usize },
}
208
impl DomainBasis {
    /// Build a basis sized to the number of observations `coords.len()`.
    ///
    /// Circle domains get a periodic cubic basis with
    /// `clamp(n / OBS_PER_BASIS, MIN_PERIODIC_BASIS, MAX_PERIODIC_BASIS)` functions.
    /// Interval domains get an open basis whose knots are generated over `[lo, hi]` with
    /// `clamp(n / OBS_PER_BASIS, MIN_OPEN_INTERNAL_KNOTS, MAX_OPEN_INTERNAL_KNOTS)`
    /// internal knots.
    ///
    /// # Errors
    /// Propagates basis-construction failures, and errors if the seed evaluation does
    /// not return one row per input.
    fn build(topology: ChartTopology, coords: ArrayView1<'_, f64>) -> Result<Self, String> {
        let n = coords.len();
        match topology {
            ChartTopology::Circle => {
                let num_basis = (n / OBS_PER_BASIS).clamp(MIN_PERIODIC_BASIS, MAX_PERIODIC_BASIS);
                Ok(DomainBasis::Periodic(PeriodicBSplineBasisSpec {
                    degree: TRANSPORT_SPLINE_DEGREE,
                    num_basis,
                    period: TAU,
                    origin: 0.0,
                    penalty_order: TRANSPORT_PENALTY_ORDER,
                }))
            }
            ChartTopology::Interval { lo, hi } => {
                let num_internal =
                    (n / OBS_PER_BASIS).clamp(MIN_OPEN_INTERNAL_KNOTS, MAX_OPEN_INTERNAL_KNOTS);
                // Coordinates are clamped into the chart before seeding the basis; only
                // the returned knot vector is kept, the seed rows are a sanity check.
                let (seed, knots) = create_basis::<Dense>(
                    coords.mapv(|v| v.clamp(lo, hi)).view(),
                    KnotSource::Generate {
                        data_range: (lo, hi),
                        num_internal_knots: num_internal,
                    },
                    TRANSPORT_SPLINE_DEGREE,
                    BasisOptions::value(),
                )
                .map_err(|e| format!("layer transport open basis construction failed: {e}"))?;
                if seed.nrows() != n {
                    return Err(format!(
                        "layer transport open basis returned {} rows for {n} inputs",
                        seed.nrows()
                    ));
                }
                Ok(DomainBasis::Open {
                    knots,
                    degree: TRANSPORT_SPLINE_DEGREE,
                })
            }
        }
    }

    /// Number of basis functions (columns of the design matrix).
    fn num_basis(&self) -> usize {
        match self {
            DomainBasis::Periodic(spec) => spec.num_basis,
            // Standard B-spline count: len(knots) - degree - 1.
            DomainBasis::Open { knots, degree } => knots.len() - degree - 1,
        }
    }

    /// Rank of the difference penalty returned by `penalty`.
    ///
    /// The code assumes the cyclic penalty has a one-dimensional null space (constants)
    /// and the open order-`TRANSPORT_PENALTY_ORDER` penalty has a null space of dimension
    /// `TRANSPORT_PENALTY_ORDER`.
    fn penalty_rank(&self) -> usize {
        match self {
            DomainBasis::Periodic(spec) => spec.num_basis - 1,
            DomainBasis::Open { .. } => self.num_basis() - TRANSPORT_PENALTY_ORDER,
        }
    }

    /// Difference penalty matrix `S` of order `TRANSPORT_PENALTY_ORDER`, cyclic for
    /// periodic bases and open otherwise.
    ///
    /// # Errors
    /// Propagates failures from the penalty constructors.
    fn penalty(&self) -> Result<Array2<f64>, String> {
        match self {
            DomainBasis::Periodic(spec) => {
                create_cyclic_difference_penalty_matrix(spec.num_basis, TRANSPORT_PENALTY_ORDER)
                    .map_err(|e| format!("cyclic transport penalty failed: {e}"))
            }
            DomainBasis::Open { .. } => {
                create_difference_penalty_matrix(self.num_basis(), TRANSPORT_PENALTY_ORDER, None)
                    .map_err(|e| format!("open transport penalty failed: {e}"))
            }
        }
    }

    /// Map an arbitrary coordinate into the basis domain.
    ///
    /// Periodic coordinates wrap into `[0, 2π)`. Open coordinates clamp into the span
    /// `[knots[degree], knots[len - 1 - degree]]`.
    fn project(&self, t: f64) -> f64 {
        match self {
            DomainBasis::Periodic(_) => wrap_tau(t),
            DomainBasis::Open { knots, degree } => {
                let lo = knots[*degree];
                let hi = knots[knots.len() - 1 - degree];
                t.clamp(lo, hi)
            }
        }
    }

    /// Basis values at each (projected) `t`: one row per input, one column per basis
    /// function.
    ///
    /// # Errors
    /// Propagates evaluation failures. For open bases, also errors if the evaluator
    /// reports a knot vector of a different length than the stored one.
    fn value_rows(&self, t: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
        let projected = t.mapv(|v| self.project(v));
        match self {
            DomainBasis::Periodic(spec) => build_periodic_bspline_basis_1d(projected.view(), spec)
                .map_err(|e| format!("periodic transport basis evaluation failed: {e}")),
            DomainBasis::Open { knots, degree } => {
                let (rows, used_knots) = create_basis::<Dense>(
                    projected.view(),
                    KnotSource::Provided(knots.view()),
                    *degree,
                    BasisOptions::value(),
                )
                .map_err(|e| format!("open transport basis evaluation failed: {e}"))?;
                if used_knots.len() != knots.len() {
                    return Err("open transport basis knot vector drifted".to_string());
                }
                Ok(rows.as_ref().to_owned())
            }
        }
    }

    /// Polynomial degree of the first derivative on each knot span (`degree - 1`,
    /// saturating at 0).
    fn derivative_poly_degree(&self) -> usize {
        let degree = match self {
            DomainBasis::Periodic(spec) => spec.degree,
            DomainBasis::Open { degree, .. } => *degree,
        };
        degree.saturating_sub(1)
    }

    /// Sorted breakpoints delimiting the spans on which the derivative is a single
    /// polynomial, including both domain endpoints.
    ///
    /// Periodic bases use `num_basis` equal segments over `[origin, origin + period]`.
    /// Open bases use the distinct interior knots strictly inside the clamped span,
    /// bracketed by the span endpoints.
    fn derivative_breakpoints(&self) -> Vec<f64> {
        match self {
            DomainBasis::Periodic(spec) => {
                let n_seg = spec.num_basis.max(1);
                (0..=n_seg)
                    .map(|k| spec.origin + spec.period * k as f64 / n_seg as f64)
                    .collect()
            }
            DomainBasis::Open { knots, degree } => {
                let lo = knots[*degree];
                let hi = knots[knots.len() - 1 - degree];
                let mut breaks: Vec<f64> = Vec::with_capacity(knots.len());
                for &k in knots.iter() {
                    if k > lo + 0.0 && k < hi {
                        breaks.push(k);
                    }
                }
                breaks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                // Collapse repeated interior knots (within a relative ε of the span end).
                breaks.dedup_by(|a, b| (*a - *b).abs() <= f64::EPSILON * hi.abs().max(1.0));
                let mut out = Vec::with_capacity(breaks.len() + 2);
                out.push(lo);
                // NOTE(review): this second filter repeats the one applied above.
                out.extend(breaks.into_iter().filter(|&k| k > lo && k < hi));
                out.push(hi);
                out
            }
        }
    }

    /// First-derivative basis rows at each (projected) `t`, laid out like `value_rows`.
    ///
    /// # Errors
    /// Propagates evaluation failures and open-basis knot-vector drift.
    fn derivative_rows(&self, t: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
        let projected = t.mapv(|v| self.project(v));
        match self {
            DomainBasis::Periodic(spec) => {
                let n = projected.len();
                // The N-D derivative helper expects an (n × 1) coordinate matrix.
                let mut col = Array2::<f64>::zeros((n, 1));
                for (i, &v) in projected.iter().enumerate() {
                    col[[i, 0]] = v;
                }
                // NOTE(review): range is hard-coded to (0, 2π), which matches the spec
                // built in `build` (origin 0, period TAU).
                let jet = periodic_bspline_first_derivative_nd(
                    col.view(),
                    (0.0, TAU),
                    spec.degree,
                    spec.num_basis,
                )
                .map_err(|e| format!("periodic transport derivative failed: {e}"))?;
                // Axis 2 indexes the input dimension; there is only dimension 0 here.
                Ok(jet.index_axis(Axis(2), 0).to_owned())
            }
            DomainBasis::Open { knots, degree } => {
                let (rows, used_knots) = create_basis::<Dense>(
                    projected.view(),
                    KnotSource::Provided(knots.view()),
                    *degree,
                    BasisOptions::first_derivative(),
                )
                .map_err(|e| format!("open transport derivative failed: {e}"))?;
                if used_knots.len() != knots.len() {
                    return Err("open transport derivative knot vector drifted".to_string());
                }
                Ok(rows.as_ref().to_owned())
            }
        }
    }
}
395
/// Result of `fit_penalized_1d`: a penalized (weighted) least-squares spline fit.
struct Penalized1dFit {
    /// Fitted coefficients `β = (XᵀWX + λS)⁻¹ XᵀWy` (with a tiny diagonal jitter).
    beta: Array1<f64>,
    /// `σ² · (XᵀWX + λS)⁻¹`, the covariance used for pointwise and delta-method
    /// uncertainty downstream.
    covariance: Array2<f64>,
    /// Influence matrix `(XᵀWX + λS)⁻¹ XᵀWX`; its trace is `edf`.
    influence: Array2<f64>,
    /// Selected smoothing parameter λ.
    lambda: f64,
    /// Effective degrees of freedom, `tr(influence)`.
    edf: f64,
    /// Residual variance estimate `rss / max(n - edf, 1)`, or `1.0` for a known scale.
    sigma2: f64,
    /// Weighted root-mean-square residual, `sqrt(rss / Σw)`.
    residual_rms: f64,
}
413
/// Fit a single-penalty smoothing spline by penalized weighted least squares, choosing λ
/// by minimising a REML-style criterion.
///
/// The criterion minimised over λ is
/// `fit_term + log|XᵀWX + λS + εI| − rank(S)·log λ`. Here `fit_term` is the penalized
/// RSS when `known_scale` is true, and otherwise `(n − nullspace_dim)·log(PRSS)` (the
/// scale profiled out). λ is searched on `REML_LAMBDA_GRID_POINTS` log-spaced values
/// spanning ±`REML_LAMBDA_SPAN_DECADES` decades around the mean diagonal of XᵀWX. It is
/// then refined by `REML_GOLDEN_ITERATIONS` golden-section steps in log λ between the
/// grid neighbours of the best grid point.
///
/// * `design` — `n × m` basis matrix `X`.
/// * `penalty` — `m × m` penalty `S`.
/// * `response` — length-`n` response `y`.
/// * `weights` — optional length-`n` positive prior weights (default all 1).
/// * `penalty_rank` — rank of `S`; sets the `log λ` term and the null-space dimension.
/// * `known_scale` — treat the scale as 1 instead of estimating `σ²`.
///
/// # Errors
/// Shape mismatches, non-finite/non-positive weights, eigendecomposition failure at the
/// selected λ, or non-finite coefficients.
fn fit_penalized_1d(
    design: &Array2<f64>,
    penalty: &Array2<f64>,
    response: ArrayView1<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
    penalty_rank: usize,
    known_scale: bool,
) -> Result<Penalized1dFit, String> {
    let n = design.nrows();
    let m = design.ncols();
    if response.len() != n || penalty.nrows() != m || penalty.ncols() != m {
        return Err(format!(
            "penalized 1-D fit shape mismatch: X is {n}×{m}, y has {}, S is {}×{}",
            response.len(),
            penalty.nrows(),
            penalty.ncols()
        ));
    }
    if let Some(w) = weights {
        if w.len() != n {
            return Err(format!(
                "penalized 1-D fit weight length {} does not match n = {n}",
                w.len()
            ));
        }
        if w.iter().any(|&v| !v.is_finite() || v <= 0.0) {
            return Err("penalized 1-D fit weights must be finite and positive".to_string());
        }
    }

    // Accumulate the sufficient statistics XᵀWX (upper triangle), XᵀWy, yᵀWy and Σw in
    // one pass; zero basis entries are skipped because B-spline rows are sparse.
    let mut xtwx = Array2::<f64>::zeros((m, m));
    let mut xtwy = Array1::<f64>::zeros(m);
    let mut ytwy = 0.0_f64;
    let mut sum_w = 0.0_f64;
    for r in 0..n {
        let w = weights.map_or(1.0, |wv| wv[r]);
        let y = response[r];
        ytwy += w * y * y;
        sum_w += w;
        for j in 0..m {
            let xj = design[[r, j]];
            if xj == 0.0 {
                continue;
            }
            xtwy[j] += w * xj * y;
            for k in j..m {
                xtwx[[j, k]] += w * xj * design[[r, k]];
            }
        }
    }
    // Mirror the upper triangle into the lower one.
    for j in 0..m {
        for k in 0..j {
            xtwx[[j, k]] = xtwx[[k, j]];
        }
    }

    // Mean diagonal of XᵀWX anchors the λ search range to the data's scale.
    let trace_scale = (0..m).map(|i| xtwx[[i, i]]).sum::<f64>() / m as f64;
    let anchor = trace_scale.max(f64::MIN_POSITIVE);
    let nullspace_dim = m.saturating_sub(penalty_rank);
    let dof = ((n as f64) - nullspace_dim as f64).max(1.0);
    let rank_f = penalty_rank as f64;

    // Eigendecompose A = XᵀWX + λS (+ relative 1e-12 jitter); returns the eigenvalues,
    // XᵀWy rotated into the eigenbasis, and the eigenvectors.
    let solve_at = |lambda: f64| -> Result<(Array1<f64>, Array1<f64>, Array2<f64>), String> {
        let mut a = xtwx.clone();
        for j in 0..m {
            for k in 0..m {
                a[[j, k]] += lambda * penalty[[j, k]];
            }
        }
        let diag_scale = (0..m).map(|i| a[[i, i]].abs()).fold(1.0_f64, f64::max);
        for i in 0..m {
            a[[i, i]] += 1e-12 * diag_scale;
        }
        let (evals, evecs) = a
            .eigh(Side::Lower)
            .map_err(|e| format!("penalized 1-D fit eigendecomposition failed: {e:?}"))?;
        Ok((evals, evecs.t().dot(&xtwy), evecs))
    };

    // Smoothing-parameter criterion; an eigendecomposition failure scores +∞.
    let criterion = |lambda: f64| -> f64 {
        let Ok(parts) = solve_at(lambda) else {
            return f64::INFINITY;
        };
        let (evals, rotated) = (&parts.0, &parts.1);
        let floor = evals.iter().copied().fold(0.0_f64, f64::max) * 1e-14;
        // PRSS = yᵀWy − bᵀA⁻¹b with b = XᵀWy, evaluated in the eigenbasis.
        let mut prss = ytwy;
        let mut logdet = 0.0_f64;
        for i in 0..m {
            let d = evals[i].max(floor).max(f64::MIN_POSITIVE);
            prss -= rotated[i] * rotated[i] / d;
            logdet += d.ln();
        }
        let prss = prss.max(f64::MIN_POSITIVE);
        let fit_term = if known_scale { prss } else { dof * prss.ln() };
        fit_term + logdet - rank_f * lambda.ln()
    };

    // Coarse log-spaced grid search.
    let lo = anchor * 10f64.powf(-REML_LAMBDA_SPAN_DECADES);
    let hi = anchor * 10f64.powf(REML_LAMBDA_SPAN_DECADES);
    let grid: Vec<f64> = (0..REML_LAMBDA_GRID_POINTS)
        .map(|i| {
            let t = i as f64 / (REML_LAMBDA_GRID_POINTS - 1) as f64;
            lo * (hi / lo).powf(t)
        })
        .collect();
    let mut best_idx = 0usize;
    let mut best_val = f64::INFINITY;
    for (i, &lam) in grid.iter().enumerate() {
        let v = criterion(lam);
        if v < best_val {
            best_val = v;
            best_idx = i;
        }
    }
    // Golden-section refinement in log λ between the best point's grid neighbours.
    let mut a_log = grid[best_idx.saturating_sub(1)].ln();
    let mut c_log = grid[(best_idx + 1).min(REML_LAMBDA_GRID_POINTS - 1)].ln();
    let golden = (5.0_f64.sqrt() - 1.0) / 2.0;
    let mut x1 = c_log - golden * (c_log - a_log);
    let mut x2 = a_log + golden * (c_log - a_log);
    let mut f1 = criterion(x1.exp());
    let mut f2 = criterion(x2.exp());
    for _ in 0..REML_GOLDEN_ITERATIONS {
        if f1 <= f2 {
            c_log = x2;
            x2 = x1;
            f2 = f1;
            x1 = c_log - golden * (c_log - a_log);
            f1 = criterion(x1.exp());
        } else {
            a_log = x1;
            x1 = x2;
            f1 = f2;
            x2 = a_log + golden * (c_log - a_log);
            f2 = criterion(x2.exp());
        }
    }
    let lambda = (0.5 * (a_log + c_log)).exp();

    // Final solve at the selected λ: β = A⁻¹XᵀWy and A⁻¹ from the eigenpairs, with the
    // same eigenvalue floor the criterion used.
    let (evals, rotated, evecs) = solve_at(lambda)?;
    let floor = evals.iter().copied().fold(0.0_f64, f64::max) * 1e-14;
    let mut a_inv = Array2::<f64>::zeros((m, m));
    let mut beta = Array1::<f64>::zeros(m);
    for i in 0..m {
        let d = evals[i].max(floor).max(f64::MIN_POSITIVE);
        let coeff = rotated[i] / d;
        for j in 0..m {
            beta[j] += evecs[[j, i]] * coeff;
            for k in 0..m {
                a_inv[[j, k]] += evecs[[j, i]] * evecs[[k, i]] / d;
            }
        }
    }
    let influence = a_inv.dot(&xtwx);
    let edf = (0..m).map(|i| influence[[i, i]]).sum::<f64>();

    let fitted = design.dot(&beta);
    let mut rss = 0.0_f64;
    for r in 0..n {
        let w = weights.map_or(1.0, |wv| wv[r]);
        let e = response[r] - fitted[r];
        rss += w * e * e;
    }
    let sigma2 = if known_scale {
        1.0
    } else {
        (rss / ((n as f64) - edf).max(1.0)).max(f64::MIN_POSITIVE)
    };
    let covariance = a_inv.mapv(|v| v * sigma2);
    let residual_rms = (rss / sum_w.max(f64::MIN_POSITIVE)).sqrt();

    if beta.iter().any(|v| !v.is_finite()) {
        return Err("penalized 1-D fit produced non-finite coefficients".to_string());
    }
    Ok(Penalized1dFit {
        beta,
        covariance,
        influence,
        lambda,
        edf,
        sigma2,
        residual_rms,
    })
}
607
/// A fitted 1-D transport map `h(t) = slope·t + rotation_offset + B(t)β` between two
/// charts. `slope` is the integer `degree` (0 when absent) and `B` is the source-domain
/// spline basis. Circle targets are reported modulo `2π`.
#[derive(Debug, Clone)]
pub struct FittedTransport {
    /// Topology of the source chart.
    pub topology_from: ChartTopology,
    /// Topology of the target chart.
    pub topology_to: ChartTopology,
    /// Winding degree selected from `DEGREE_CANDIDATES` (circle→circle only).
    pub degree: Option<i32>,
    /// Mean resultant length of the degree-adjusted residual angles for the selected
    /// degree (circle→circle only).
    pub degree_concentration: Option<f64>,
    /// Constant offset of the map: a circular mean for circle targets, `0.0` otherwise.
    pub rotation_offset: f64,
    /// Spline coefficients of the smooth part.
    pub beta: Array1<f64>,
    /// Covariance of `beta` (scaled by the estimated noise variance).
    pub covariance: Array2<f64>,
    /// Selected smoothing parameter λ.
    pub smoothing_lambda: f64,
    /// Effective degrees of freedom of the smooth part.
    pub edf: f64,
    /// Estimated residual variance of the fit.
    pub noise_variance: f64,
    /// Number of paired observations used in the fit.
    pub n_obs: usize,
    /// Mean over observations of `(|h′(t)| − 1)²`.
    pub isometry_defect: f64,
    /// Delta-method standard error of `isometry_defect` from `covariance`.
    pub isometry_defect_se: f64,
    /// Whether the map is an orientation-consistent, non-folding map of like topologies.
    pub topology_preserved: bool,
    /// Minimum of `orientation·h′` over a `FOLD_CHECK_GRID`-point domain grid.
    pub min_directional_derivative: f64,
    /// Root-mean-square residual of the smooth fit.
    pub residual_rms: f64,
    /// Source-domain basis used to evaluate `B(t)`.
    basis: DomainBasis,
}
654
655impl FittedTransport {
656 fn linear_slope(&self) -> f64 {
657 self.degree.map_or(0.0, f64::from)
658 }
659
660 pub fn eval(&self, t: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
662 let rows = self.basis.value_rows(t)?;
663 let smooth = rows.dot(&self.beta);
664 let slope = self.linear_slope();
665 let mut out = Array1::<f64>::zeros(t.len());
666 for i in 0..t.len() {
667 let raw = slope * t[i] + self.rotation_offset + smooth[i];
668 out[i] = match self.topology_to {
669 ChartTopology::Circle => wrap_tau(raw),
670 ChartTopology::Interval { .. } => raw,
671 };
672 }
673 Ok(out)
674 }
675
676 pub fn eval_with_variance(
678 &self,
679 t: ArrayView1<'_, f64>,
680 ) -> Result<(Array1<f64>, Array1<f64>), String> {
681 let rows = self.basis.value_rows(t)?;
682 let values = self.eval(t)?;
683 let mut variances = Array1::<f64>::zeros(t.len());
684 for i in 0..t.len() {
685 let row = rows.row(i);
686 variances[i] = row.dot(&self.covariance.dot(&row)).max(0.0);
687 }
688 Ok((values, variances))
689 }
690
691 pub fn derivative(&self, t: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
693 let rows = self.basis.derivative_rows(t)?;
694 let slope = self.linear_slope();
695 Ok(rows.dot(&self.beta).mapv(|v| v + slope))
696 }
697
698 fn raw_at(&self, t: f64) -> Result<f64, String> {
702 let arr = Array1::from_elem(1, t);
703 let smooth = self.basis.value_rows(arr.view())?.dot(&self.beta)[0];
704 Ok(self.linear_slope() * t + self.rotation_offset + smooth)
705 }
706
707 fn oriented_derivative_at(&self, t: &[f64], orientation: f64) -> Result<Vec<f64>, String> {
709 let arr = Array1::from_vec(t.to_vec());
710 let rows = self.basis.derivative_rows(arr.view())?;
711 let slope = self.linear_slope();
712 Ok((0..t.len())
713 .map(|i| orientation * (rows.row(i).dot(&self.beta) + slope))
714 .collect())
715 }
716
717 fn certify_strict_monotonicity(&self) -> Result<f64, String> {
735 let (lo, hi) = match self.topology_from {
736 ChartTopology::Circle => (0.0, TAU),
737 ChartTopology::Interval { lo, hi } => (lo, hi),
738 };
739 let raw_lo = self.raw_at(lo)?;
742 let raw_hi = self.raw_at(hi)?;
743 let orientation = if raw_hi >= raw_lo { 1.0 } else { -1.0 };
744
745 let deg = self.basis.derivative_poly_degree().max(1);
746 let breaks = self.basis.derivative_breakpoints();
747 for window in breaks.windows(2) {
750 let (a, b) = (window[0], window[1]);
751 if !(b > a) {
752 continue;
753 }
754 let span = b - a;
755 let pad = span * 1.0e-9;
759 let n_nodes = deg + 1;
760 let nodes: Vec<f64> = (0..n_nodes)
761 .map(|i| {
762 let s = if n_nodes == 1 {
763 0.5
764 } else {
765 i as f64 / (n_nodes - 1) as f64
766 };
767 (a + pad) + (span - 2.0 * pad) * s
768 })
769 .collect();
770 let values = self.oriented_derivative_at(&nodes, orientation)?;
771
772 let step = if n_nodes > 1 {
777 nodes[1] - nodes[0]
778 } else {
779 span
780 };
781 let coeffs = monomial_from_equispaced(&values);
782
783 let probe_t = a + 0.37 * span;
789 let probe_u = (probe_t - nodes[0]) / step;
790 let probe_recon = eval_monomial(&coeffs, probe_u);
791 let probe_actual = self.oriented_derivative_at(&[probe_t], orientation)?[0];
792 let scale = probe_actual.abs().max(1.0);
793 if (probe_recon - probe_actual).abs() > 1.0e-6 * scale {
794 return Err(format!(
795 "transport monotonicity certificate could not reconstruct h′ on the \
796 span [{a}, {b}] (reconstruction {probe_recon} vs actual {probe_actual}); \
797 refusing to certify"
798 ));
799 }
800
801 for &edge in &[a, b] {
803 let u = (edge - nodes[0]) / step;
804 let v = eval_monomial(&coeffs, u);
805 if !(v > 0.0) {
806 return Err(format!(
807 "transport map is not strictly monotone: orientation·h′ = {v} ≤ 0 at \
808 t = {edge}"
809 ));
810 }
811 }
812 for u_crit in monomial_critical_points(&coeffs) {
815 let t_crit = nodes[0] + u_crit * step;
816 if t_crit > a && t_crit < b {
817 let v = eval_monomial(&coeffs, u_crit);
818 if !(v > 0.0) {
819 return Err(format!(
820 "transport map folds: orientation·h′ = {v} ≤ 0 at interior \
821 extremum t = {t_crit}"
822 ));
823 }
824 }
825 }
826 }
827 Ok(orientation)
828 }
829
830 pub fn invert(&self, y: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
850 if y.iter().any(|v| !v.is_finite()) {
851 return Err("transport inverse targets must be finite".to_string());
852 }
853 self.certify_strict_monotonicity()?;
856 let (lo, hi) = match self.topology_from {
857 ChartTopology::Circle => (0.0, TAU),
858 ChartTopology::Interval { lo, hi } => (lo, hi),
859 };
860 let raw_lo = self.raw_at(lo)?;
863 let raw_hi = self.raw_at(hi)?;
864 let increasing = raw_hi > raw_lo;
865 let (raw_min, raw_max) = if increasing {
866 (raw_lo, raw_hi)
867 } else {
868 (raw_hi, raw_lo)
869 };
870 let scale = raw_min.abs().max(raw_max.abs()).max(1.0);
873 let tol = 32.0 * f64::EPSILON * scale;
874
875 let mut probe = Array1::<f64>::zeros(1);
878 let mut raw_at_into = |t: f64| -> Result<f64, String> {
879 probe[0] = t;
880 let smooth = self.basis.value_rows(probe.view())?.dot(&self.beta)[0];
881 Ok(self.linear_slope() * t + self.rotation_offset + smooth)
882 };
883
884 let mut out = Array1::<f64>::zeros(y.len());
885 for (idx, &yi) in y.iter().enumerate() {
886 let target = match self.topology_to {
888 ChartTopology::Interval { .. } => {
889 if yi < raw_min - tol || yi > raw_max + tol {
890 return Err(format!(
891 "transport inverse target {yi} is outside the fitted image \
892 [{raw_min}, {raw_max}]"
893 ));
894 }
895 yi.clamp(raw_min, raw_max)
896 }
897 ChartTopology::Circle => {
898 let ywrapped = wrap_tau(yi);
901 let m = ((raw_min - ywrapped) / TAU).ceil();
902 ywrapped + TAU * m
903 }
904 };
905 let (mut a, mut b) = (lo, hi);
909 let width_floor = f64::EPSILON * hi.abs().max(lo.abs()).max(1.0);
910 for _ in 0..100 {
911 if (b - a) <= width_floor {
912 break;
913 }
914 let mid = 0.5 * (a + b);
915 let rm = raw_at_into(mid)?;
916 let go_right = if increasing { rm < target } else { rm > target };
917 if go_right {
918 a = mid;
919 } else {
920 b = mid;
921 }
922 }
923 out[idx] = 0.5 * (a + b);
924 }
925 Ok(out)
926 }
927
928 pub fn report(&self, layer_from: usize, layer_to: usize) -> LayerTransportReport {
931 LayerTransportReport {
932 layer_from,
933 layer_to,
934 topology_from: self.topology_from,
935 topology_to: self.topology_to,
936 topology_preserved: self.topology_preserved,
937 degree: self.degree,
938 degree_concentration: self.degree_concentration,
939 rotation_offset: self.rotation_offset,
940 isometry_defect: self.isometry_defect,
941 isometry_defect_se: self.isometry_defect_se,
942 min_directional_derivative: self.min_directional_derivative,
943 transport_edf: self.edf,
944 smoothing_lambda: self.smoothing_lambda,
945 noise_variance: self.noise_variance,
946 residual_rms: self.residual_rms,
947 n_obs: self.n_obs,
948 composition_defect: None,
949 composition_max_studentized: None,
950 composition_p_value: None,
951 composition_gauge_reflected: None,
952 }
953 }
954}
955
/// Plain-data summary of a fitted layer-to-layer transport, optionally augmented with
/// composition-consistency results via `with_composition`.
#[derive(Debug, Clone)]
pub struct LayerTransportReport {
    /// Index of the source layer.
    pub layer_from: usize,
    /// Index of the target layer.
    pub layer_to: usize,
    /// Source chart topology.
    pub topology_from: ChartTopology,
    /// Target chart topology.
    pub topology_to: ChartTopology,
    /// Copied from `FittedTransport::topology_preserved`.
    pub topology_preserved: bool,
    /// Selected winding degree (circle→circle only).
    pub degree: Option<i32>,
    /// Mean resultant length supporting the selected degree (circle→circle only).
    pub degree_concentration: Option<f64>,
    /// Constant offset of the fitted map.
    pub rotation_offset: f64,
    /// Mean `(|h′| − 1)²` over the observations.
    pub isometry_defect: f64,
    /// Delta-method standard error of `isometry_defect`.
    pub isometry_defect_se: f64,
    /// Minimum oriented derivative over the fold-check grid.
    pub min_directional_derivative: f64,
    /// Effective degrees of freedom of the smooth part.
    pub transport_edf: f64,
    /// Selected smoothing parameter λ.
    pub smoothing_lambda: f64,
    /// Estimated residual variance.
    pub noise_variance: f64,
    /// Root-mean-square residual of the fit.
    pub residual_rms: f64,
    /// Number of paired observations.
    pub n_obs: usize,
    /// RMS composition defect (set by `with_composition`).
    pub composition_defect: Option<f64>,
    /// Largest studentized composition defect (set by `with_composition`).
    pub composition_max_studentized: Option<f64>,
    /// Composition test p-value (set by `with_composition`).
    pub composition_p_value: Option<f64>,
    /// Whether the composition comparison used the reflected gauge (set by
    /// `with_composition`).
    pub composition_gauge_reflected: Option<bool>,
}
994
995impl LayerTransportReport {
996 pub fn with_composition(mut self, composition: &CompositionDefectReport) -> Self {
998 self.composition_defect = Some(composition.rms_defect);
999 self.composition_max_studentized = Some(composition.max_studentized_defect);
1000 self.composition_p_value = Some(composition.p_value);
1001 self.composition_gauge_reflected = Some(composition.gauge_reflected);
1002 self
1003 }
1004}
1005
/// Fit a smooth transport map `h` from paired source/target chart coordinates.
///
/// How the response is built depends on the topologies:
/// * Circle→circle: an integer winding degree `d` is chosen from `DEGREE_CANDIDATES` by
///   maximising the mean resultant length of `to − d·from`. The smooth part is fitted
///   to those residual angles, centred on their circular mean and wrapped to `(-π, π]`.
/// * Any→circle (non-circle source): the target angles are centred on their circular
///   mean and wrapped to `(-π, π]`.
/// * Any→interval: the target coordinates are used directly.
///
/// The smooth part is a penalized spline on the source domain with λ chosen by
/// `fit_penalized_1d` (scale estimated). The function then computes:
/// * the isometry defect `mean (|h′| − 1)²` at the observations and its delta-method
///   standard error;
/// * the minimum oriented derivative on a `FOLD_CHECK_GRID`-point grid;
/// * a `topology_preserved` flag (degree ±1 and no fold for circle→circle; no fold for
///   interval→interval; `false` for mixed topologies).
///
/// # Errors
/// Length mismatch, fewer than `MIN_TRANSPORT_OBS` pairs, non-finite coordinates,
/// invalid chart bounds, or any basis/penalty/fit failure.
pub fn fit_transport_map(
    coords_from: ArrayView1<'_, f64>,
    coords_to: ArrayView1<'_, f64>,
    topology_from: ChartTopology,
    topology_to: ChartTopology,
) -> Result<FittedTransport, String> {
    let n = coords_from.len();
    if coords_to.len() != n {
        return Err(format!(
            "layer transport coordinate lengths disagree: {} vs {}",
            n,
            coords_to.len()
        ));
    }
    if n < MIN_TRANSPORT_OBS {
        return Err(format!(
            "layer transport needs at least {MIN_TRANSPORT_OBS} paired observations, got {n}"
        ));
    }
    if coords_from
        .iter()
        .chain(coords_to.iter())
        .any(|v| !v.is_finite())
    {
        return Err("layer transport coordinates must all be finite".to_string());
    }
    topology_from.validate()?;
    topology_to.validate()?;

    // Decompose h into (winding degree, concentration, offset) plus the response the
    // smooth part must explain.
    let (degree, degree_concentration, rotation_offset, response): (
        Option<i32>,
        Option<f64>,
        f64,
        Array1<f64>,
    ) = match (topology_from, topology_to) {
        (ChartTopology::Circle, ChartTopology::Circle) => {
            // Pick the degree whose residual angles are most concentrated (first
            // candidate wins ties).
            let mut best_degree = DEGREE_CANDIDATES[0];
            let mut best_r = f64::NEG_INFINITY;
            for &d in DEGREE_CANDIDATES.iter() {
                let residual: Vec<f64> = (0..n)
                    .map(|i| coords_to[i] - f64::from(d) * coords_from[i])
                    .collect();
                let r = resultant_length(&residual);
                if r > best_r {
                    best_r = r;
                    best_degree = d;
                }
            }
            let residual: Vec<f64> = (0..n)
                .map(|i| coords_to[i] - f64::from(best_degree) * coords_from[i])
                .collect();
            let mu = circular_mean(&residual);
            let response = Array1::from_iter(residual.iter().map(|&r| wrap_pi(r - mu)));
            (Some(best_degree), Some(best_r), mu, response)
        }
        (_, ChartTopology::Circle) => {
            let angles: Vec<f64> = coords_to.iter().copied().collect();
            let mu = circular_mean(&angles);
            let response = Array1::from_iter(angles.iter().map(|&a| wrap_pi(a - mu)));
            (None, None, mu, response)
        }
        (_, ChartTopology::Interval { .. }) => (None, None, 0.0, coords_to.to_owned()),
    };

    let basis = DomainBasis::build(topology_from, coords_from)?;
    let design = basis.value_rows(coords_from)?;
    let penalty = basis.penalty()?;
    let fit = fit_penalized_1d(
        &design,
        &penalty,
        response.view(),
        None,
        basis.penalty_rank(),
        false,
    )?;

    // Isometry defect D = mean (|h′| − 1)² at the observations, and its gradient in β
    // for a delta-method standard error: ∂D/∂β = mean 2(|h′| − 1)·sign(h′)·B′(t).
    let slope = degree.map_or(0.0, f64::from);
    let deriv_rows = basis.derivative_rows(coords_from)?;
    let deriv = deriv_rows.dot(&fit.beta).mapv(|v| v + slope);
    let m = basis.num_basis();
    let mut defect = 0.0_f64;
    let mut grad = Array1::<f64>::zeros(m);
    for i in 0..n {
        let speed = deriv[i].abs();
        let gap = speed - 1.0;
        defect += gap * gap;
        let sgn = if deriv[i] >= 0.0 { 1.0 } else { -1.0 };
        for j in 0..m {
            grad[j] += 2.0 * gap * sgn * deriv_rows[[i, j]];
        }
    }
    defect /= n as f64;
    grad.mapv_inplace(|v| v / n as f64);
    let isometry_defect_se = grad.dot(&fit.covariance.dot(&grad)).max(0.0).sqrt();

    // Fold check on a regular domain grid. Orientation comes from the winding degree
    // when present, otherwise from the sign of the mean grid derivative.
    let grid = domain_grid(topology_from, FOLD_CHECK_GRID);
    let grid_deriv = basis
        .derivative_rows(grid.view())?
        .dot(&fit.beta)
        .mapv(|v| v + slope);
    let orientation = if slope != 0.0 {
        slope.signum()
    } else {
        let mean = grid_deriv.iter().sum::<f64>() / grid_deriv.len() as f64;
        if mean < 0.0 { -1.0 } else { 1.0 }
    };
    let min_directional_derivative = grid_deriv
        .iter()
        .map(|&v| orientation * v)
        .fold(f64::INFINITY, f64::min);
    let topology_preserved = match (topology_from, topology_to) {
        (ChartTopology::Circle, ChartTopology::Circle) => {
            matches!(degree, Some(1) | Some(-1)) && min_directional_derivative > 0.0
        }
        (ChartTopology::Interval { .. }, ChartTopology::Interval { .. }) => {
            min_directional_derivative > 0.0
        }
        _ => false,
    };

    Ok(FittedTransport {
        topology_from,
        topology_to,
        degree,
        degree_concentration,
        rotation_offset,
        beta: fit.beta,
        covariance: fit.covariance,
        smoothing_lambda: fit.lambda,
        edf: fit.edf,
        noise_variance: fit.sigma2,
        n_obs: n,
        isometry_defect: defect,
        isometry_defect_se,
        topology_preserved,
        min_directional_derivative,
        residual_rms: fit.residual_rms,
        basis,
    })
}
1165
1166pub fn fit_layer_transport(
1168 layer_from: usize,
1169 layer_to: usize,
1170 coords_from: ArrayView1<'_, f64>,
1171 coords_to: ArrayView1<'_, f64>,
1172 topology_from: ChartTopology,
1173 topology_to: ChartTopology,
1174) -> Result<LayerTransportReport, String> {
1175 Ok(
1176 fit_transport_map(coords_from, coords_to, topology_from, topology_to)?
1177 .report(layer_from, layer_to),
1178 )
1179}
1180
/// Consistency check of `h_bc ∘ h_ab` against a directly fitted `h_ac`, on a grid over
/// the source chart.
#[derive(Debug, Clone)]
pub struct CompositionDefectReport {
    /// Number of grid points compared.
    pub n_grid: usize,
    /// Circular-mean rotation removed before measuring the defect (0 for interval
    /// targets).
    pub gauge_rotation: f64,
    /// Whether the composed map was reflected before comparison (the orientation with
    /// the smaller sum of squared defects is kept).
    pub gauge_reflected: bool,
    /// Mean absolute defect over the grid.
    pub mean_abs_defect: f64,
    /// Root-mean-square defect over the grid.
    pub rms_defect: f64,
    /// Largest absolute defect over the grid.
    pub max_abs_defect: f64,
    /// Largest `|defect| / sd`, with the propagated pointwise variance floored at a
    /// small fraction of its maximum.
    pub max_studentized_defect: f64,
    /// NOTE(review): computed past the visible code; presumably a p-value for
    /// `max_studentized_defect` — confirm.
    pub max_studentized_p_value: f64,
    /// NOTE(review): computed past the visible code; presumably the effective degrees
    /// of freedom of a smooth fitted to the defect — confirm.
    pub defect_edf: f64,
    /// NOTE(review): presumably the smooth-test statistic for "defect ≡ 0" — confirm.
    pub statistic: f64,
    /// NOTE(review): presumably the reference degrees of freedom of that test — confirm.
    pub ref_df: f64,
    /// P-value of the composition test; copied into
    /// `LayerTransportReport::composition_p_value`.
    pub p_value: f64,
}
1205
/// Convert samples of a polynomial at `u = 0, 1, …, n−1` into its monomial
/// coefficients in `u` (ascending powers).
///
/// Uses Newton's forward-difference form `p(u) = Σ_k Δᵏp(0)/k! · u(u−1)…(u−k+1)`,
/// expanding each falling factorial into monomials. The result has degree ≤ n−1.
/// Empty input yields an empty vector.
fn monomial_from_equispaced(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    if n == 0 {
        return Vec::new();
    }

    // Leading entries Δᵏp(0) of the forward-difference table, computed in place.
    let mut table = values.to_vec();
    let mut leading = Vec::with_capacity(n);
    leading.push(table[0]);
    for order in 1..n {
        for i in 0..(n - order) {
            table[i] = table[i + 1] - table[i];
        }
        leading.push(table[0]);
    }

    let mut coeffs = vec![0.0_f64; n];
    // Monomial coefficients of the current falling factorial u(u−1)…(u−k+1).
    let mut falling = vec![0.0_f64; n];
    falling[0] = 1.0;
    let mut falling_len = 1usize;
    let mut k_factorial = 1.0_f64;
    for (k, &delta) in leading.iter().enumerate() {
        if k > 0 {
            k_factorial *= k as f64;
        }
        let weight = delta / k_factorial;
        for (c, &p) in coeffs.iter_mut().zip(&falling[..falling_len]) {
            *c += weight * p;
        }
        if k + 1 == n {
            break;
        }
        // Multiply the falling factorial by (u − k).
        let shift = k as f64;
        let mut next = vec![0.0_f64; falling_len + 1];
        for (i, &p) in falling[..falling_len].iter().enumerate() {
            next[i + 1] += p;
            next[i] -= shift * p;
        }
        falling[..=falling_len].copy_from_slice(&next);
        falling_len += 1;
    }
    coeffs
}
1258
/// Evaluate `Σ coeffs[k]·uᵏ` by Horner's rule (coefficients in ascending powers).
///
/// An empty coefficient slice evaluates to `0.0`.
fn eval_monomial(coeffs: &[f64], u: f64) -> f64 {
    let mut acc = 0.0_f64;
    for &c in coeffs.iter().rev() {
        acc = acc * u + c;
    }
    acc
}
1263
/// Real zeros of `p′` for the polynomial `p` with ascending monomial `coeffs`.
///
/// * Constant or linear `p`: no critical points.
/// * Quadratic/cubic `p`: closed-form roots, returned even if they lie outside any
///   particular window. Near-zero leading coefficients (`≤ f64::MIN_POSITIVE`) fall
///   back to the lower-degree formula.
/// * Higher degrees: sign-change scan of `p′` on `[0, len−1]` with 256 steps, refined by
///   60 bisection steps. Exact zeros at scan points are reported as-is. Roots outside
///   the scan range, or touching zeros without a sign change, are not found.
fn monomial_critical_points(coeffs: &[f64]) -> Vec<f64> {
    if coeffs.len() <= 1 {
        return Vec::new();
    }
    // Coefficients of p′ in ascending powers.
    let slope_coeffs: Vec<f64> = coeffs
        .iter()
        .enumerate()
        .skip(1)
        .map(|(k, &c)| k as f64 * c)
        .collect();
    // Horner evaluation of p′(u).
    let slope_at = |u: f64| {
        slope_coeffs
            .iter()
            .rev()
            .fold(0.0_f64, |acc, &c| acc * u + c)
    };

    match slope_coeffs.as_slice() {
        &[] | &[_] => Vec::new(),
        // p′(u) = b + a·u
        &[b, a] => {
            if a.abs() <= f64::MIN_POSITIVE {
                Vec::new()
            } else {
                vec![-b / a]
            }
        }
        // p′(u) = c + b·u + a·u²
        &[c, b, a] => {
            if a.abs() <= f64::MIN_POSITIVE {
                if b.abs() <= f64::MIN_POSITIVE {
                    Vec::new()
                } else {
                    vec![-c / b]
                }
            } else {
                let disc = b * b - 4.0 * a * c;
                if disc < 0.0 {
                    Vec::new()
                } else {
                    let root = disc.sqrt();
                    vec![(-b + root) / (2.0 * a), (-b - root) / (2.0 * a)]
                }
            }
        }
        _ => {
            let u_lo = 0.0;
            // Scan the node range used by the equispaced reconstruction.
            let u_hi = (coeffs.len() - 1) as f64;
            let scan_steps: usize = 256;
            let mut roots = Vec::new();
            let mut left_u = u_lo;
            let mut left_v = slope_at(u_lo);
            for step in 1..=scan_steps {
                let u = u_lo + (u_hi - u_lo) * step as f64 / scan_steps as f64;
                let v = slope_at(u);
                if left_v == 0.0 {
                    roots.push(left_u);
                } else if left_v * v < 0.0 {
                    let (mut lo, mut hi) = (left_u, u);
                    for _ in 0..60 {
                        let mid = 0.5 * (lo + hi);
                        if slope_at(lo) * slope_at(mid) <= 0.0 {
                            hi = mid;
                        } else {
                            lo = mid;
                        }
                    }
                    roots.push(0.5 * (lo + hi));
                }
                left_u = u;
                left_v = v;
            }
            roots
        }
    }
}
1345
1346fn domain_grid(topology: ChartTopology, n: usize) -> Array1<f64> {
1348 match topology {
1349 ChartTopology::Circle => Array1::from_iter((0..n).map(|i| TAU * i as f64 / n as f64)),
1350 ChartTopology::Interval { lo, hi } => {
1351 Array1::from_iter((0..n).map(|i| lo + (hi - lo) * i as f64 / (n - 1).max(1) as f64))
1352 }
1353 }
1354}
1355
1356pub fn composition_defect(
1371 h_ab: &FittedTransport,
1372 h_bc: &FittedTransport,
1373 h_ac: &FittedTransport,
1374 n_grid: usize,
1375) -> Result<CompositionDefectReport, String> {
1376 if h_ab.topology_from != h_ac.topology_from
1377 || h_ab.topology_to != h_bc.topology_from
1378 || h_bc.topology_to != h_ac.topology_to
1379 {
1380 return Err("composition defect requires chart-compatible transports: \
1381 h_ab: A→B, h_bc: B→C, h_ac: A→C"
1382 .to_string());
1383 }
1384 if n_grid < MIN_TRANSPORT_OBS {
1385 return Err(format!(
1386 "composition defect grid must have at least {MIN_TRANSPORT_OBS} points, got {n_grid}"
1387 ));
1388 }
1389
1390 let grid = domain_grid(h_ab.topology_from, n_grid);
1391 let (direct, var_direct) = h_ac.eval_with_variance(grid.view())?;
1392 let (mid, var_mid) = h_ab.eval_with_variance(grid.view())?;
1393 let (composed, var_bc) = h_bc.eval_with_variance(mid.view())?;
1394 let mid_slope = h_bc.derivative(mid.view())?;
1395 let mut variance = Array1::<f64>::zeros(n_grid);
1396 for i in 0..n_grid {
1397 variance[i] = var_direct[i] + var_bc[i] + mid_slope[i] * mid_slope[i] * var_mid[i];
1398 }
1399
1400 let circle_target = matches!(h_ac.topology_to, ChartTopology::Circle);
1402 let mut gauge_reflected = false;
1403 let mut gauge_rotation = 0.0_f64;
1404 let mut defect = Array1::<f64>::zeros(n_grid);
1405 let mut best_sse = f64::INFINITY;
1406 for reflected in [false, true] {
1407 let composed_oriented: Array1<f64> = match (h_ac.topology_to, reflected) {
1408 (_, false) => composed.clone(),
1409 (ChartTopology::Circle, true) => composed.mapv(|v| wrap_tau(-v)),
1410 (ChartTopology::Interval { lo, hi }, true) => composed.mapv(|v| lo + hi - v),
1411 };
1412 let (rotation, candidate): (f64, Array1<f64>) = if circle_target {
1413 let raw: Vec<f64> = (0..n_grid)
1414 .map(|i| wrap_pi(direct[i] - composed_oriented[i]))
1415 .collect();
1416 let rot = circular_mean(&raw);
1417 (
1418 rot,
1419 Array1::from_iter(raw.iter().map(|&d| wrap_pi(d - rot))),
1420 )
1421 } else {
1422 (
1423 0.0,
1424 Array1::from_iter((0..n_grid).map(|i| direct[i] - composed_oriented[i])),
1425 )
1426 };
1427 let sse = candidate.iter().map(|&d| d * d).sum::<f64>();
1428 if sse < best_sse {
1429 best_sse = sse;
1430 gauge_reflected = reflected;
1431 gauge_rotation = rotation;
1432 defect = candidate;
1433 }
1434 }
1435
1436 let max_var = variance.iter().copied().fold(0.0_f64, f64::max);
1438 let var_floor = (max_var * 1e-10).max(f64::MIN_POSITIVE);
1439 let mut max_abs = 0.0_f64;
1440 let mut sum_abs = 0.0_f64;
1441 let mut sum_sq = 0.0_f64;
1442 let mut max_z = 0.0_f64;
1443 for i in 0..n_grid {
1444 let d = defect[i];
1445 let a = d.abs();
1446 max_abs = max_abs.max(a);
1447 sum_abs += a;
1448 sum_sq += d * d;
1449 let z = a / variance[i].max(var_floor).sqrt();
1450 max_z = max_z.max(z);
1451 }
1452 let mean_abs_defect = sum_abs / n_grid as f64;
1453 let rms_defect = (sum_sq / n_grid as f64).sqrt();
1454
1455 let basis = DomainBasis::build(h_ab.topology_from, grid.view())?;
1457 let design = basis.value_rows(grid.view())?;
1458 let penalty = basis.penalty()?;
1459 let weights = variance.mapv(|v| 1.0 / v.max(var_floor));
1460 let fit = fit_penalized_1d(
1461 &design,
1462 &penalty,
1463 defect.view(),
1464 Some(weights.view()),
1465 basis.penalty_rank(),
1466 true,
1467 )?;
1468 let m = basis.num_basis();
1469 let test = wood_smooth_test(SmoothTestInput {
1470 beta: fit.beta.view(),
1471 covariance: &fit.covariance,
1472 influence_matrix: Some(&fit.influence),
1473 coeff_range: 0..m,
1474 edf: fit.edf,
1475 nullspace_dim: 0,
1476 residual_df: (n_grid as f64 - fit.edf).max(1.0),
1477 scale: SmoothTestScale::Known,
1478 })
1479 .ok_or_else(|| "composition defect smooth test degenerated".to_string())?;
1480
1481 let normal =
1484 Normal::new(0.0, 1.0).map_err(|e| format!("standard normal construction failed: {e}"))?;
1485 let pointwise: f64 = (2.0 * (1.0 - normal.cdf(max_z))).clamp(0.0, 1.0);
1486 let max_studentized_p_value = (n_grid as f64 * pointwise).min(1.0);
1487
1488 Ok(CompositionDefectReport {
1489 n_grid,
1490 gauge_rotation,
1491 gauge_reflected,
1492 mean_abs_defect,
1493 rms_defect,
1494 max_abs_defect: max_abs,
1495 max_studentized_defect: max_z,
1496 max_studentized_p_value,
1497 defect_edf: fit.edf,
1498 statistic: test.statistic,
1499 ref_df: test.ref_df,
1500 p_value: test.p_value,
1501 })
1502}
1503
/// Result of [`transport_ladder`]: transport reports along a chain of layers.
#[derive(Debug, Clone)]
pub struct TransportLadderReport {
    /// One report per consecutive layer pair `layers[k] → layers[k + 1]`, in chain order.
    pub adjacent: Vec<LayerTransportReport>,
    /// One report per skip-one pair `layers[k] → layers[k + 2]`, from a direct fit.
    /// Each carries the composition test of the two adjacent hops against that
    /// direct fit. Empty when the chain has only two layers.
    pub two_hop: Vec<LayerTransportReport>,
}
1514
1515pub fn transport_ladder(
1521 layers: &[usize],
1522 coords: &[Array1<f64>],
1523 topologies: &[ChartTopology],
1524) -> Result<TransportLadderReport, String> {
1525 let depth = layers.len();
1526 if coords.len() != depth || topologies.len() != depth {
1527 return Err(format!(
1528 "transport ladder inputs disagree: {depth} layers, {} coordinate vectors, {} topologies",
1529 coords.len(),
1530 topologies.len()
1531 ));
1532 }
1533 if depth < 2 {
1534 return Err("transport ladder needs at least two layers".to_string());
1535 }
1536
1537 let mut adjacent_fits: Vec<FittedTransport> = Vec::with_capacity(depth - 1);
1538 let mut adjacent: Vec<LayerTransportReport> = Vec::with_capacity(depth - 1);
1539 for k in 0..depth - 1 {
1540 let fit = fit_transport_map(
1541 coords[k].view(),
1542 coords[k + 1].view(),
1543 topologies[k],
1544 topologies[k + 1],
1545 )
1546 .map_err(|e| {
1547 format!(
1548 "adjacent transport {}→{} failed: {e}",
1549 layers[k],
1550 layers[k + 1]
1551 )
1552 })?;
1553 adjacent.push(fit.report(layers[k], layers[k + 1]));
1554 adjacent_fits.push(fit);
1555 }
1556
1557 let mut two_hop: Vec<LayerTransportReport> = Vec::with_capacity(depth.saturating_sub(2));
1558 for k in 0..depth.saturating_sub(2) {
1559 let direct = fit_transport_map(
1560 coords[k].view(),
1561 coords[k + 2].view(),
1562 topologies[k],
1563 topologies[k + 2],
1564 )
1565 .map_err(|e| {
1566 format!(
1567 "two-hop transport {}→{} failed: {e}",
1568 layers[k],
1569 layers[k + 2]
1570 )
1571 })?;
1572 let composition = composition_defect(
1573 &adjacent_fits[k],
1574 &adjacent_fits[k + 1],
1575 &direct,
1576 DEFAULT_COMPOSITION_GRID,
1577 )
1578 .map_err(|e| {
1579 format!(
1580 "composition test {}→{}→{} failed: {e}",
1581 layers[k],
1582 layers[k + 1],
1583 layers[k + 2]
1584 )
1585 })?;
1586 two_hop.push(
1587 direct
1588 .report(layers[k], layers[k + 2])
1589 .with_composition(&composition),
1590 );
1591 }
1592
1593 Ok(TransportLadderReport { adjacent, two_hop })
1594}
1595
#[cfg(test)]
mod invert_tests {
    //! Round-trip and rejection tests for `FittedTransport::invert`, plus an
    //! exactness check of `monomial_from_equispaced` / `monomial_critical_points`.
    use super::*;
    use ndarray::Array1;

    /// Shorthand for the interval chart `[lo, hi]`.
    fn interval(lo: f64, hi: f64) -> ChartTopology {
        ChartTopology::Interval { lo, hi }
    }

    /// An increasing interval warp must round-trip `eval → invert → eval`.
    #[test]
    fn invert_round_trips_interval_transport() {
        let n = 64;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| i as f64 / (n as f64 - 1.0)));
        // Derivative is 1 + cos(2πt)/4 ≥ 3/4, so the warp is strictly increasing
        // and fixes both endpoints.
        let to: Array1<f64> = from.mapv(|t| t + 0.25 * (TAU * t).sin() / TAU);
        let ft = fit_transport_map(
            from.view(),
            to.view(),
            interval(0.0, 1.0),
            interval(0.0, 1.0),
        )
        .expect("fit");
        assert!(
            ft.topology_preserved,
            "monotone warp should preserve topology"
        );

        let probe = Array1::from_iter((1..10).map(|i| i as f64 / 10.0));
        let fwd = ft.eval(probe.view()).expect("eval");
        let back = ft.invert(fwd.view()).expect("invert");
        for i in 0..probe.len() {
            assert!(
                (back[i] - probe[i]).abs() < 1e-6,
                "round-trip failed: t={} back={}",
                probe[i],
                back[i]
            );
        }
        let re_eval = ft.eval(back.view()).expect("eval");
        for i in 0..fwd.len() {
            assert!((re_eval[i] - fwd[i]).abs() < 1e-9);
        }
    }

    /// An orientation-reversing interval map must also round-trip.
    #[test]
    fn invert_round_trips_decreasing_interval_transport() {
        let n = 64;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| i as f64 / (n as f64 - 1.0)));
        // Derivative is −0.5 − t < 0 on [0, 1]; maps 0 → 1 and 1 → 0.
        let to: Array1<f64> = from.mapv(|t| 1.0 - 0.5 * t - 0.5 * t * t);
        let ft = fit_transport_map(
            from.view(),
            to.view(),
            interval(0.0, 1.0),
            interval(0.0, 1.0),
        )
        .expect("fit");
        assert!(ft.topology_preserved);
        let probe = Array1::from_iter((1..10).map(|i| i as f64 / 10.0));
        let fwd = ft.eval(probe.view()).expect("eval");
        let back = ft.invert(fwd.view()).expect("invert");
        for i in 0..probe.len() {
            assert!(
                (back[i] - probe[i]).abs() < 1e-6,
                "t={} back={}",
                probe[i],
                back[i]
            );
        }
    }

    /// A rotated, warped circle map must round-trip up to wrapping.
    #[test]
    fn invert_round_trips_circle_transport() {
        let n = 128;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| TAU * i as f64 / n as f64));
        // Derivative is 1 + 0.2·cos(t) > 0.
        let to: Array1<f64> = from.mapv(|t| wrap_tau(t + 0.3 + 0.2 * t.sin()));
        let ft = fit_transport_map(
            from.view(),
            to.view(),
            ChartTopology::Circle,
            ChartTopology::Circle,
        )
        .expect("fit");
        assert!(ft.topology_preserved, "degree {:?}", ft.degree);

        let probe = Array1::from_iter((0..7).map(|i| TAU * (i as f64 + 0.5) / 7.0));
        let fwd = ft.eval(probe.view()).expect("eval");
        let back = ft.invert(fwd.view()).expect("invert");
        for i in 0..probe.len() {
            // Compare on the circle, not on the real line.
            let d = wrap_pi(back[i] - probe[i]).abs();
            assert!(d < 1e-5, "probe={} back={} d={}", probe[i], back[i], d);
        }
    }

    /// Targets outside the fitted image of an interval map must be rejected.
    #[test]
    fn invert_rejects_target_outside_interval_image() {
        let n = 32;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| i as f64 / (n as f64 - 1.0)));
        // Image is [0, 0.5], so 0.9 has no preimage.
        let to: Array1<f64> = from.mapv(|t| 0.5 * t);
        let ft = fit_transport_map(
            from.view(),
            to.view(),
            interval(0.0, 1.0),
            interval(0.0, 1.0),
        )
        .expect("fit");
        assert!(ft.invert(Array1::from_elem(1, 0.9).view()).is_err());
    }

    /// Builds a `FittedTransport` on `[lo, hi] → [lo, hi]` directly from a
    /// least-squares fit of `target` on the interval basis.
    ///
    /// This bypasses `fit_transport_map`. There is no smoothing penalty, and the
    /// diagnostic fields hold placeholder values, so tests can plant exact
    /// shapes (e.g. a hidden fold).
    fn fitted_from_target(
        from: ArrayView1<'_, f64>,
        target: ArrayView1<'_, f64>,
        lo: f64,
        hi: f64,
    ) -> FittedTransport {
        let basis = DomainBasis::build(interval(lo, hi), from).expect("basis");
        let design = basis.value_rows(from).expect("design");
        let m = design.ncols();
        let mut xtx = design.t().dot(&design);
        let xty = design.t().dot(&target);
        // Tiny ridge (1e-10 × max(1, largest |diagonal|)) keeps the normal
        // equations numerically positive definite.
        let diag = (0..m).map(|i| xtx[[i, i]].abs()).fold(1.0_f64, f64::max);
        for i in 0..m {
            xtx[[i, i]] += 1e-10 * diag;
        }
        // Solve (XᵀX) β = Xᵀy via the eigendecomposition: β = V Λ⁻¹ Vᵀ Xᵀy.
        let (evals, evecs) = xtx.eigh(Side::Lower).expect("eigh");
        let rotated = evecs.t().dot(&xty);
        let mut beta = Array1::<f64>::zeros(m);
        for i in 0..m {
            let d = evals[i].max(f64::MIN_POSITIVE);
            let c = rotated[i] / d;
            for j in 0..m {
                beta[j] += evecs[[j, i]] * c;
            }
        }
        FittedTransport {
            topology_from: interval(lo, hi),
            topology_to: interval(lo, hi),
            degree: None,
            degree_concentration: None,
            rotation_offset: 0.0,
            beta,
            covariance: Array2::<f64>::zeros((m, m)),
            smoothing_lambda: 0.0,
            edf: 0.0,
            noise_variance: 1.0,
            n_obs: from.len(),
            isometry_defect: 0.0,
            isometry_defect_se: 0.0,
            topology_preserved: true,
            min_directional_derivative: 1.0,
            residual_rms: 0.0,
            basis,
        }
    }

    /// A fold narrower than the fold-check grid spacing, centred between two
    /// grid points, must still be rejected by `invert`.
    #[test]
    fn invert_rejects_between_grid_fold() {
        let n = 256;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| i as f64 / (n as f64 - 1.0)));
        // Derivative (t − 0.5)² − eps² is negative only for |t − 0.5| < eps.
        // eps is 0.4 grid spacings. For an even FOLD_CHECK_GRID, t = 0.5 lies
        // midway between two grid points, 0.5 spacings from each, so every
        // grid point sees a positive derivative.
        let eps = 0.4 / (FOLD_CHECK_GRID - 1) as f64;
        let target: Array1<f64> = from.mapv(|t| (t - 0.5).powi(3) / 3.0 - eps * eps * t);
        let mut ft = fitted_from_target(from.view(), target.view(), 0.0, 1.0);

        let grid = domain_grid(interval(0.0, 1.0), FOLD_CHECK_GRID);
        let grid_d = ft.derivative(grid.view()).expect("grid deriv");
        let mean = grid_d.iter().sum::<f64>() / grid_d.len() as f64;
        let orientation = if mean < 0.0 { -1.0 } else { 1.0 };
        let min_grid = grid_d
            .iter()
            .map(|&v| orientation * v)
            .fold(f64::INFINITY, f64::min);
        let dense = Array1::from_iter((0..5120).map(|i| i as f64 / 5119.0));
        let dense_d = ft.derivative(dense.view()).expect("dense deriv");
        let min_dense = dense_d
            .iter()
            .map(|&v| orientation * v)
            .fold(f64::INFINITY, f64::min);
        // Mimic a grid-only fold check: it passes, so any rejection must come
        // from `invert`'s own span-exact certificate.
        ft.topology_preserved = min_grid > 0.0;
        ft.min_directional_derivative = min_grid;
        assert!(
            min_grid > 0.0 && min_dense < 0.0,
            "fixture must hide a between-grid fold: min on {FOLD_CHECK_GRID}-grid={min_grid}, \
             min on dense grid={min_dense}"
        );

        let res = ft.invert(Array1::from_elem(1, 0.0).view());
        assert!(
            res.is_err(),
            "between-grid fold must be rejected by the span-exact certificate \
             (topology_preserved={}, min_grid={min_grid}, min_dense={min_dense})",
            ft.topology_preserved
        );
    }

    /// NaN and ±∞ targets must be rejected rather than propagated.
    #[test]
    fn invert_rejects_non_finite_targets() {
        let n = 64;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| i as f64 / (n as f64 - 1.0)));
        let to: Array1<f64> = from.mapv(|t| 0.5 * t);
        let ft = fit_transport_map(
            from.view(),
            to.view(),
            interval(0.0, 1.0),
            interval(0.0, 1.0),
        )
        .expect("fit");
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                ft.invert(Array1::from_elem(1, bad).view()).is_err(),
                "non-finite target {bad} must be rejected"
            );
        }
    }

    /// The image-membership tolerance must scale with the map's range, so a
    /// tiny-range map still rejects targets just outside its image.
    #[test]
    fn invert_image_tolerance_is_scale_aware() {
        let n = 64;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| i as f64 / (n as f64 - 1.0)));
        let scale = 1.0e-8;
        let to: Array1<f64> = from.mapv(|t| scale * t);
        let ft = fit_transport_map(
            from.view(),
            to.view(),
            interval(0.0, 1.0),
            interval(0.0, 1.0),
        )
        .expect("fit");
        let outside = 1.05e-8;
        assert!(
            ft.invert(Array1::from_elem(1, outside).view()).is_err(),
            "target {outside} is 5% outside the [0, {scale}] image and must be rejected"
        );
        let inside = 0.5e-8;
        let t = ft
            .invert(Array1::from_elem(1, inside).view())
            .expect("invert inside");
        let re = ft.eval(t.view()).expect("eval");
        assert!((re[0] - inside).abs() < 1e-3 * scale);
    }

    /// An orientation-reversing (degree −1) circle cover must round-trip.
    #[test]
    fn invert_round_trips_degree_minus_one_circle() {
        let n = 128;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| TAU * i as f64 / n as f64));
        // Derivative is −1 + 0.15·cos(t) < 0.
        let to: Array1<f64> = from.mapv(|t| wrap_tau(-t + 0.4 + 0.15 * t.sin()));
        let ft = fit_transport_map(
            from.view(),
            to.view(),
            ChartTopology::Circle,
            ChartTopology::Circle,
        )
        .expect("fit");
        assert_eq!(ft.degree, Some(-1), "expected a degree −1 cover");
        assert!(ft.topology_preserved, "degree {:?}", ft.degree);
        let probe = Array1::from_iter((0..7).map(|i| TAU * (i as f64 + 0.5) / 7.0));
        let fwd = ft.eval(probe.view()).expect("eval");
        let back = ft.invert(fwd.view()).expect("invert");
        for i in 0..probe.len() {
            let d = wrap_pi(back[i] - probe[i]).abs();
            assert!(d < 1e-5, "probe={} back={} d={}", probe[i], back[i], d);
        }
    }

    /// Targets at the circle seam (0 / 2π) and at the exact interval endpoint
    /// images must invert to preimages that re-evaluate to the target.
    #[test]
    fn invert_round_trips_circle_seam_and_interval_endpoints() {
        let n = 128;
        let from: Array1<f64> = Array1::from_iter((0..n).map(|i| TAU * i as f64 / n as f64));
        let to: Array1<f64> = from.mapv(|t| wrap_tau(t + 0.3 + 0.2 * t.sin()));
        let ft = fit_transport_map(
            from.view(),
            to.view(),
            ChartTopology::Circle,
            ChartTopology::Circle,
        )
        .expect("fit");
        assert!(ft.topology_preserved);
        for seam in [1e-9, TAU - 1e-9, 0.0] {
            let t = ft
                .invert(Array1::from_elem(1, seam).view())
                .expect("invert seam");
            let re = ft.eval(t.view()).expect("eval");
            let d = wrap_pi(re[0] - wrap_tau(seam)).abs();
            assert!(d < 1e-6, "seam={seam} re={} d={d}", re[0]);
        }

        let m = 64;
        let ifrom: Array1<f64> = Array1::from_iter((0..m).map(|i| i as f64 / (m as f64 - 1.0)));
        let ito: Array1<f64> = ifrom.mapv(|t| t + 0.25 * (TAU * t).sin() / TAU);
        let ift = fit_transport_map(
            ifrom.view(),
            ito.view(),
            interval(0.0, 1.0),
            interval(0.0, 1.0),
        )
        .expect("fit");
        let raw_lo = ift.raw_at(0.0).expect("raw lo");
        let raw_hi = ift.raw_at(1.0).expect("raw hi");
        for &edge in &[raw_lo, raw_hi] {
            let t = ift
                .invert(Array1::from_elem(1, edge).view())
                .expect("invert endpoint");
            assert!(t[0] >= -1e-9 && t[0] <= 1.0 + 1e-9, "endpoint t={}", t[0]);
            let re = ift.eval(t.view()).expect("eval");
            assert!((re[0] - edge).abs() < 1e-6, "edge={edge} re={}", re[0]);
        }
    }

    /// Three equispaced samples of a quadratic determine it exactly, and its
    /// single critical point is the root of the derivative.
    #[test]
    fn monomial_reconstruction_is_exact_for_quadratic() {
        // p(u) = 0.7 − 1.3u + 2.1u², sampled at u = 0, 1, 2.
        let coeffs_true = [0.7_f64, -1.3, 2.1];
        let values: Vec<f64> = (0..3)
            .map(|i| eval_monomial(&coeffs_true, i as f64))
            .collect();
        let recon = monomial_from_equispaced(&values);
        for (a, b) in recon.iter().zip(coeffs_true.iter()) {
            assert!((a - b).abs() < 1e-12, "recon {a} vs {b}");
        }
        let crit = monomial_critical_points(&recon);
        // p'(u) = −1.3 + 4.2u vanishes at u = 1.3 / 4.2.
        assert_eq!(crit.len(), 1);
        assert!((crit[0] - 1.3 / 4.2).abs() < 1e-12);
    }
}