1use super::*;
2
/// Key of the process-wide Duchon basis cache: the pair of 64-bit digests
/// produced by `duchon_basis_fingerprint` from the data matrix and the spec.
type DuchonBasisCacheKey = (u64, u64);
26
/// Newtype around a finished `BasisBuildResult`; it is the value type stored
/// in the Duchon basis cache and carries the `ResidentBytes` implementation
/// used for that cache.
#[derive(Clone)]
struct CachedDuchonBasis(BasisBuildResult);
29
30impl gam_runtime::resource::ResidentBytes for CachedDuchonBasis {
31 fn resident_bytes(&self) -> usize {
32 let design_bytes = self
36 .0
37 .design
38 .nrows()
39 .saturating_mul(self.0.design.ncols())
40 .saturating_mul(std::mem::size_of::<f64>());
41 let penalty_bytes: usize = self
42 .0
43 .penalties
44 .iter()
45 .map(|s| s.len().saturating_mul(std::mem::size_of::<f64>()))
46 .sum();
47 design_bytes
48 .saturating_add(penalty_bytes)
49 .saturating_add(4096)
50 }
51}
52
53fn duchon_basis_cache()
58-> &'static gam_runtime::resource::ByteLruCache<DuchonBasisCacheKey, CachedDuchonBasis> {
59 static CACHE: std::sync::OnceLock<
60 gam_runtime::resource::ByteLruCache<DuchonBasisCacheKey, CachedDuchonBasis>,
61 > = std::sync::OnceLock::new();
62 CACHE.get_or_init(|| gam_runtime::resource::ByteLruCache::new(1 << 30))
63}
64
65fn duchon_basis_fingerprint(
75 data: ArrayView2<'_, f64>,
76 spec: &DuchonBasisSpec,
77) -> Option<DuchonBasisCacheKey> {
78 let spec_bytes = serde_json::to_vec(spec).ok()?;
79 let mut lo = DefaultHasher::new();
80 let mut hi = DefaultHasher::new();
81 0x9E37_79B9_7F4A_7C15u64.hash(&mut hi);
83
84 let (nrows, ncols) = data.dim();
85 for h in [&mut lo, &mut hi] {
86 nrows.hash(h);
87 ncols.hash(h);
88 }
89 for row in data.rows() {
93 for &v in row {
94 let bits = v.to_bits();
95 bits.hash(&mut lo);
96 bits.hash(&mut hi);
97 }
98 }
99 for h in [&mut lo, &mut hi] {
100 spec_bytes.len().hash(h);
101 spec_bytes.hash(h);
102 }
103 Some((lo.finish(), hi.finish()))
104}
105
106pub fn build_duchon_basiswithworkspace(
107 data: ArrayView2<'_, f64>,
108 spec: &DuchonBasisSpec,
109 workspace: &mut BasisWorkspace,
110) -> Result<BasisBuildResult, BasisError> {
111 if let Some(key) = duchon_basis_fingerprint(data, spec) {
112 if let Some(hit) = duchon_basis_cache().get(&key) {
113 return Ok(hit.0);
114 }
115 let result = build_duchon_basis_uncached(data, spec, workspace)?;
116 duchon_basis_cache().insert(key, CachedDuchonBasis(result.clone()));
117 return Ok(result);
118 }
119 build_duchon_basis_uncached(data, spec, workspace)
120}
121
/// Builds the Duchon basis (design, penalties and metadata) for `data`
/// according to `spec`, without consulting the basis cache.
///
/// Dispatch order:
/// 1. If `spec.boundary.period()` yields a period, the cyclic 1-D builder is used.
/// 2. If `spec.periodic` is set, one of the periodic builders is used.
/// 3. Otherwise the non-periodic basis is assembled here, either as a lazy
///    chunked operator or as a dense matrix.
///
/// # Errors
/// Returns a `BasisError` when center selection, kernel-order validation,
/// shape checks, or any of the delegated builders fail.
fn build_duchon_basis_uncached(
    data: ArrayView2<'_, f64>,
    spec: &DuchonBasisSpec,
    workspace: &mut BasisWorkspace,
) -> Result<BasisBuildResult, BasisError> {
    // A periodic boundary takes precedence over every other setting.
    if let Some((start, end, _period)) = spec.boundary.period() {
        return build_cyclic_duchon_basis_1dwithworkspace(data, spec, start, end);
    }
    let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
    assert_spatial_centers_below_large_scale_cap(data.ncols(), centers.view())?;
    if let Some(periodic) = spec.periodic.as_ref() {
        // One entry per covariate axis is required.
        if periodic.len() != data.ncols() {
            crate::bail_invalid_basis!(
                "periodic must have length d={}, got {}",
                data.ncols(),
                periodic.len()
            );
        }
        // Multi-dimensional input with at least one periodic axis: mixed builder.
        // Axes without a period are passed a placeholder period of 1.0.
        if data.ncols() > 1 && periodic.iter().any(Option::is_some) {
            let flags = periodic.iter().map(Option::is_some).collect::<Vec<_>>();
            let periods = periodic
                .iter()
                .map(|axis| axis.unwrap_or(1.0))
                .collect::<Vec<_>>();
            return build_duchon_basis_mixed_periodicity_auto(data, spec, &flags, Some(&periods));
        }
        // NOTE(review): this is also reached when `data.ncols() > 1` and every
        // entry of `periodic` is `None` — confirm the 1-D periodic builder is
        // the intended target in that case.
        return build_periodic_duchon_basis_1d(data, spec, centers, workspace);
    }
    let effective_nullspace_order =
        duchon_effective_nullspace_order(centers.view(), spec.nullspace_order);
    let p_order = duchon_p_from_nullspace_order(effective_nullspace_order);
    let aniso = auto_seed_aniso_contrasts(centers.view(), spec.aniso_log_scales.as_deref());
    // With a length scale the power is validated as an integer; without one
    // the raw floating-point power is validated.
    let validation_power = if spec.length_scale.is_some() {
        spec.power_as_usize() as f64
    } else {
        spec.power
    };
    validate_duchon_kernel_orders(spec.length_scale, p_order, validation_power, data.ncols())?;
    let mut kernel_transform = kernel_constraint_nullspace(
        centers.view(),
        effective_nullspace_order,
        &mut workspace.cache,
    )?;
    // Only the column count of the polynomial block is needed here.
    let poly_cols = polynomial_block_from_order(data, effective_nullspace_order).ncols();
    let base_cols = kernel_transform.ncols() + poly_cols;
    let dense_bytes = dense_design_bytes(data.nrows(), base_cols);
    let use_lazy = should_use_lazy_spatial_design(data.nrows(), base_cols, workspace.policy());
    let mut frozen_radial_reparam: Option<Array2<f64>> = None;
    // A reparameterisation supplied by the spec is folded into the kernel
    // transform and recorded so it ends up in the returned metadata.
    if let Some(v) = spec.radial_reparam.as_ref() {
        if v.nrows() != kernel_transform.ncols() {
            crate::bail_dim_basis!(
                "Duchon frozen radial reparam shape {:?} does not match constrained kernel dimension {}",
                v.dim(),
                kernel_transform.ncols()
            );
        }
        kernel_transform = fast_ab(&kernel_transform, v);
        frozen_radial_reparam = Some(v.clone());
    }
    let (design, identifiability_transform) = if use_lazy {
        // Lazy path: the design is wrapped in a chunked kernel operator
        // instead of being materialised up front.
        log::info!(
            "Duchon basis switching to lazy chunked design: n={} p={} ({:.1} MiB dense)",
            data.nrows(),
            base_cols,
            dense_bytes as f64 / (1024.0 * 1024.0),
        );
        let d = data.ncols();
        let shared_data = shared_owned_data_matrix(data, &workspace.cache);
        let p_order = duchon_p_from_nullspace_order(effective_nullspace_order);
        let s_order: f64 = spec.power;
        let length_scale = spec.length_scale;
        // The integer power and partial-fraction coefficients exist only when
        // a length scale is set.
        let s_order_int = length_scale.map(|_| duchon_power_to_usize(s_order));
        let coeffs = length_scale.map(|ls| {
            duchon_partial_fraction_coeffs(
                p_order,
                s_order_int.expect("hybrid Duchon requires integer power"),
                1.0 / ls.max(1e-300),
            )
        });
        // Without a length scale the kernel is evaluated via a polyharmonic
        // block coefficient instead.
        let pure_poly_coeff = if length_scale.is_none() {
            Some(PolyharmonicBlockCoeff::new(
                pure_duchon_block_order(p_order, s_order),
                d,
            ))
        } else {
            None
        };
        // The polynomial block is evaluated on data shifted by the per-axis
        // mean of the centers.
        let center_mean: Vec<f64> = (0..d)
            .map(|c| centers.column(c).sum() / (centers.nrows().max(1) as f64))
            .collect();
        let mut data_centered = data.to_owned();
        for c in 0..d {
            let mu = center_mean[c];
            data_centered.column_mut(c).mapv_inplace(|v| v - mu);
        }
        let poly_block =
            polynomial_block_from_order(data_centered.view(), effective_nullspace_order);
        // NOTE(review): `duchon_power_to_usize(s_order)` is evaluated here even
        // when `length_scale` is `None` — confirm it is well-defined for the
        // powers allowed in that mode.
        let kernel_amp = duchon_kernel_amplification(
            centers.view(),
            length_scale,
            p_order,
            duchon_power_to_usize(s_order),
            d,
            aniso.as_deref(),
            coeffs.as_ref(),
            pure_poly_coeff.as_ref(),
        );
        let base_design = if let Some(eta) = aniso.as_ref() {
            // Anisotropic kernel: squared axis differences are weighted by
            // exp(2 * eta[axis]) before taking the square root.
            let metric_weights = eta.iter().map(|&v| (2.0 * v).exp()).collect::<Vec<_>>();
            let coeffs = coeffs.clone();
            let kernel = move |data_row: &[f64], center_row: &[f64]| -> f64 {
                let mut q = 0.0f64;
                for axis in 0..data_row.len() {
                    let delta = data_row[axis] - center_row[axis];
                    q += metric_weights[axis] * delta * delta;
                }
                let r = q.sqrt();
                let raw = if let Some(ppc) = pure_poly_coeff {
                    ppc.eval(r)
                } else {
                    duchon_matern_kernel_general_from_distance(
                        r,
                        length_scale,
                        p_order,
                        s_order_int.expect("hybrid Duchon requires integer power"),
                        d,
                        coeffs.as_ref(),
                    )
                    .expect("validated Duchon inputs should not fail")
                };
                raw * kernel_amp
            };
            let kernel_gauge = Arc::new(gam_problem::Gauge::from_block_transforms(&[
                kernel_transform.clone(),
            ]));
            let base_op = ChunkedKernelDesignOperator::new(
                shared_data.clone(),
                Arc::new(centers.clone()),
                kernel,
                Some(kernel_gauge),
                Some(Arc::new(poly_block.clone())),
            )
            .map_err(BasisError::InvalidInput)?;
            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(
                base_op,
            )))
        } else {
            // Isotropic kernel: plain Euclidean distance over the first `d` axes.
            let coeffs = coeffs.clone();
            let kernel = move |data_row: &[f64], center_row: &[f64]| -> f64 {
                let r = stable_euclidean_norm((0..d).map(|axis| data_row[axis] - center_row[axis]));
                let raw = if let Some(ppc) = pure_poly_coeff {
                    ppc.eval(r)
                } else {
                    duchon_matern_kernel_general_from_distance(
                        r,
                        length_scale,
                        p_order,
                        s_order_int.expect("hybrid Duchon requires integer power"),
                        d,
                        coeffs.as_ref(),
                    )
                    .expect("validated Duchon inputs should not fail")
                };
                raw * kernel_amp
            };
            let kernel_gauge = Arc::new(gam_problem::Gauge::from_block_transforms(&[
                kernel_transform.clone(),
            ]));
            let base_op = ChunkedKernelDesignOperator::new(
                shared_data,
                Arc::new(centers.clone()),
                kernel,
                Some(kernel_gauge),
                Some(Arc::new(poly_block)),
            )
            .map_err(BasisError::InvalidInput)?;
            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(
                base_op,
            )))
        };
        let identifiability_transform = spatial_identifiability_transform_from_design_matrix(
            data,
            &base_design,
            &spec.identifiability,
            "Duchon",
        )?;
        // Apply the identifiability transform on top of the operator when one exists.
        let design = if let Some(transform) = identifiability_transform.as_ref() {
            wrap_dense_design_with_transform(base_design, transform, "Duchon")?
        } else {
            base_design
        };
        (design, identifiability_transform)
    } else {
        // Dense path.
        let operators_active = matches!(
            spec.operator_penalties.mass,
            OperatorPenaltySpec::Active { .. }
        ) || matches!(
            spec.operator_penalties.tension,
            OperatorPenaltySpec::Active { .. }
        ) || matches!(
            spec.operator_penalties.stiffness,
            OperatorPenaltySpec::Active { .. }
        );
        // With no reparameterisation from the spec and no active operator
        // penalty, derive one from the constrained bending penalty and the
        // Gram matrix of the kernel columns of a preliminary design.
        if frozen_radial_reparam.is_none() && !operators_active {
            let kernel_cols = kernel_transform.ncols();
            if kernel_cols > 0 {
                let raw = build_duchon_basis_designwithworkspace(
                    data,
                    centers.view(),
                    spec.length_scale,
                    spec.power,
                    effective_nullspace_order,
                    aniso.as_deref(),
                    None,
                    workspace,
                )?;
                let kernel_block = raw.basis.slice(s![.., 0..kernel_cols]);
                let design_gram = symmetrize_penalty(&fast_atb(&kernel_block, &kernel_block));
                let omega_constrained = duchon_constrained_bending_penalty(
                    centers.view(),
                    spec.length_scale,
                    spec.power,
                    effective_nullspace_order,
                    aniso.as_deref(),
                    &kernel_transform,
                )?;
                let (v, _mu) =
                    thin_plate_radial_reparam_data_metric(&omega_constrained, &design_gram)?;
                // An empty reparameterisation leaves the kernel transform untouched.
                if v.ncols() > 0 {
                    kernel_transform = fast_ab(&kernel_transform, &v);
                    frozen_radial_reparam = Some(v);
                }
            }
        }
        // Final dense design, built with whichever reparameterisation is in effect.
        let d = build_duchon_basis_designwithworkspace(
            data,
            centers.view(),
            spec.length_scale,
            spec.power,
            effective_nullspace_order,
            aniso.as_deref(),
            frozen_radial_reparam.as_ref(),
            workspace,
        )?;
        let basis = d.basis;
        let identifiability_transform = spatial_identifiability_transform_from_design(
            data,
            basis.view(),
            &spec.identifiability,
            "Duchon",
        )?;
        let design = if let Some(z) = identifiability_transform.as_ref() {
            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(fast_ab(
                &basis, z,
            )))
        } else {
            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(basis))
        };
        (design, identifiability_transform)
    };
    // Collocation points are selected only when at least one operator penalty
    // is active; at most `DUCHON_COLLOCATION_OVERSAMPLE * #centers` points,
    // capped at the number of data rows.
    let operator_collocation_points = {
        let any_operator = matches!(
            spec.operator_penalties.mass,
            OperatorPenaltySpec::Active { .. }
        ) || matches!(
            spec.operator_penalties.tension,
            OperatorPenaltySpec::Active { .. }
        ) || matches!(
            spec.operator_penalties.stiffness,
            OperatorPenaltySpec::Active { .. }
        );
        if any_operator {
            let m = (DUCHON_COLLOCATION_OVERSAMPLE * centers.nrows()).min(data.nrows());
            Some(select_thin_plate_knots(data, m)?)
        } else {
            None
        }
    };
    // Native penalty candidates first, then operator candidates if any.
    let mut candidates = duchon_native_penalty_candidates(
        centers.view(),
        spec.length_scale,
        spec.power,
        effective_nullspace_order,
        aniso.as_deref(),
        &kernel_transform,
        identifiability_transform.as_ref(),
        poly_cols,
    )?;
    if let Some(points) = operator_collocation_points.as_ref() {
        candidates.extend(duchon_operator_penalty_candidates(
            points.view(),
            centers.view(),
            &spec.operator_penalties,
            spec.length_scale,
            spec.power,
            effective_nullspace_order,
            aniso.is_some(),
            identifiability_transform.as_ref(),
            workspace,
        )?);
    }
    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
        filter_active_penalty_candidates_with_ops(candidates)?;
    Ok(BasisBuildResult {
        design,
        penalties,
        nullspace_dims,
        penaltyinfo,
        ops,
        null_eigenvectors,
        joint_null_rotation: None,
        // The metadata records the resolved (effective) settings, not just the
        // requested ones.
        metadata: BasisMetadata::Duchon {
            centers,
            length_scale: spec.length_scale,
            periodic: spec.periodic.clone(),
            power: spec.power,
            nullspace_order: effective_nullspace_order,
            identifiability_transform,
            input_scales: None,
            aniso_log_scales: aniso,
            operator_collocation_points,
            radial_reparam: frozen_radial_reparam,
        },
        kronecker_factored: None,
    })
}
538
539pub fn duchon_penalties_at_length_scale(
560 centers: ArrayView2<'_, f64>,
561 identifiability_transform: Option<&Array2<f64>>,
562 operator_collocation_points: Option<ArrayView2<'_, f64>>,
563 operator_penalties: &DuchonOperatorPenaltySpec,
564 power: f64,
565 nullspace_order: DuchonNullspaceOrder,
566 aniso_log_scales: Option<&[f64]>,
567 radial_reparam: Option<&Array2<f64>>,
568 length_scale: Option<f64>,
569 workspace: &mut BasisWorkspace,
570) -> Result<(Vec<Array2<f64>>, Vec<usize>), BasisError> {
571 let effective_nullspace_order = duchon_effective_nullspace_order(centers, nullspace_order);
575 let aniso = auto_seed_aniso_contrasts(centers, aniso_log_scales);
576 let mut kernel_transform =
578 kernel_constraint_nullspace(centers, effective_nullspace_order, &mut workspace.cache)?;
579 if let Some(v) = radial_reparam {
582 if v.nrows() != kernel_transform.ncols() {
583 crate::bail_dim_basis!(
584 "Duchon frozen radial reparam shape {:?} does not match constrained kernel dimension {}",
585 v.dim(),
586 kernel_transform.ncols()
587 );
588 }
589 kernel_transform = fast_ab(&kernel_transform, v);
590 }
591 let poly_cols = polynomial_block_from_order(centers, effective_nullspace_order).ncols();
594 let mut candidates = duchon_native_penalty_candidates(
595 centers,
596 length_scale,
597 power,
598 effective_nullspace_order,
599 aniso.as_deref(),
600 &kernel_transform,
601 identifiability_transform,
602 poly_cols,
603 )?;
604 if let Some(points) = operator_collocation_points {
605 candidates.extend(duchon_operator_penalty_candidates(
606 points,
607 centers,
608 operator_penalties,
609 length_scale,
610 power,
611 effective_nullspace_order,
612 aniso.is_some(),
613 identifiability_transform,
614 workspace,
615 )?);
616 }
617 let (penalties, nullspace_dims, _info, _eig, _ops) =
618 filter_active_penalty_candidates_with_ops(candidates)?;
619 Ok((penalties, nullspace_dims))
620}
621
622pub(crate) fn polynomial_block_from_order(
641 points: ArrayView2<'_, f64>,
642 order: DuchonNullspaceOrder,
643) -> Array2<f64> {
644 let n = points.nrows();
645 let d = points.ncols();
646 match order {
647 DuchonNullspaceOrder::Zero => Array2::<f64>::ones((n, 1)),
648 DuchonNullspaceOrder::Linear => {
649 let mut poly = Array2::<f64>::zeros((n, d + 1));
650 poly.column_mut(0).fill(1.0);
651 for c in 0..d {
652 poly.column_mut(c + 1).assign(&points.column(c));
653 }
654 poly
655 }
656 DuchonNullspaceOrder::Degree(degree) => monomial_basis_block(points, degree),
657 }
658}
659
/// Lists the exponent vectors of all monomials in `dimension` variables with
/// total degree at most `max_total_degree`.
///
/// Vectors are grouped by increasing total degree; within a degree, the
/// exponent of each earlier axis runs from high to low before later axes
/// change. For `dimension == 0` the result is a single empty exponent vector.
pub fn monomial_exponents(dimension: usize, max_total_degree: usize) -> Vec<Vec<usize>> {
    /// Extends `prefix` with every way of spending exactly `budget` across the
    /// remaining axes and appends each completed vector to `sink`.
    fn distribute(
        prefix: &mut Vec<usize>,
        dimension: usize,
        budget: usize,
        sink: &mut Vec<Vec<usize>>,
    ) {
        // The final axis has no freedom left: it absorbs the whole budget.
        if prefix.len() + 1 == dimension {
            prefix.push(budget);
            sink.push(prefix.clone());
            prefix.pop();
            return;
        }
        // `leftover` ascending means this axis's exponent descends.
        for leftover in 0..=budget {
            prefix.push(budget - leftover);
            distribute(prefix, dimension, leftover, sink);
            prefix.pop();
        }
    }

    if dimension == 0 {
        return vec![Vec::new()];
    }

    let mut table = Vec::new();
    let mut prefix = Vec::with_capacity(dimension);
    for degree in 0..=max_total_degree {
        distribute(&mut prefix, dimension, degree, &mut table);
    }
    table
}
689
/// Number of monomials in `dimension` variables with total degree at most
/// `max_total_degree`, i.e. the binomial coefficient
/// `C(dimension + max_total_degree, dimension)`.
///
/// This equals `monomial_exponents(dimension, max_total_degree).len()` but is
/// computed in closed form, so no exponent table is allocated just to be
/// counted. For `dimension == 0` the result is 1 (the constant monomial).
/// The result saturates at `usize::MAX` if the count does not fit.
pub fn duchon_nullspace_dimension(dimension: usize, max_total_degree: usize) -> usize {
    // C(a + b, a) is symmetric in a and b; iterate over the smaller one.
    let smaller = dimension.min(max_total_degree) as u128;
    let larger = dimension.max(max_total_degree) as u128;

    let mut count: u128 = 1;
    for step in 1..=smaller {
        // Invariant: on entry `count == C(larger + step - 1, step - 1)`, so
        // `count * (larger + step)` is exactly divisible by `step`.
        count = match count.checked_mul(larger + step) {
            Some(product) => product / step,
            None => return usize::MAX,
        };
    }

    if count > usize::MAX as u128 {
        usize::MAX
    } else {
        count as usize
    }
}
693
694pub(crate) fn monomial_basis_block(
695 points: ArrayView2<'_, f64>,
696 max_total_degree: usize,
697) -> Array2<f64> {
698 let n = points.nrows();
699 let exponents = monomial_exponents(points.ncols(), max_total_degree);
700 let mut block = Array2::<f64>::zeros((n, exponents.len()));
701 for (col, exponents) in exponents.iter().enumerate() {
702 for row in 0..n {
703 let mut value = 1.0;
704 for axis in 0..points.ncols() {
705 let exponent = exponents[axis];
706 if exponent != 0 {
707 value *= points[[row, axis]].powi(exponent as i32);
708 }
709 }
710 block[[row, col]] = value;
711 }
712 }
713 block
714}
715
716#[inline(always)]
717pub(crate) fn thin_plate_polynomial_degree(dimension: usize) -> usize {
718 thin_plate_penalty_order(dimension).saturating_sub(1)
719}
720
721pub(crate) fn thin_plate_polynomial_block(points: ArrayView2<'_, f64>) -> Array2<f64> {
722 monomial_basis_block(points, thin_plate_polynomial_degree(points.ncols()))
723}
724
725pub fn thin_plate_polynomial_basis_dimension(dimension: usize) -> usize {
726 monomial_exponents(dimension, thin_plate_polynomial_degree(dimension)).len()
727}
728
729fn thin_plate_retained_radial_indices(evals: &Array1<f64>) -> Vec<usize> {
753 let k = evals.len();
754 if k == 0 {
755 return Vec::new();
756 }
757 let max_eval = evals
758 .iter()
759 .copied()
760 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
761 if !max_eval.is_finite() || max_eval <= 0.0 {
762 return Vec::new();
763 }
764 let num_floor = (k as f64) * f64::EPSILON * max_eval;
768 evals
769 .iter()
770 .enumerate()
771 .filter_map(|(idx, &value)| (value.abs() > num_floor).then_some(idx))
772 .collect()
773}
774
775pub(crate) fn thin_plate_radial_reparam_from_constrained_penalty(
776 omega_constrained: &Array2<f64>,
777) -> Result<(Array2<f64>, Array1<f64>), BasisError> {
778 let kernel_cols = omega_constrained.nrows();
779 if kernel_cols != omega_constrained.ncols() {
780 crate::bail_dim_basis!(
781 "thin-plate constrained radial penalty must be square: got {:?}",
782 omega_constrained.dim()
783 );
784 }
785 if kernel_cols == 0 {
786 return Ok((Array2::<f64>::zeros((0, 0)), Array1::<f64>::zeros(0)));
787 }
788 let sym = symmetrize_penalty(omega_constrained);
789 let (mut evals, evecs) = FaerEigh::eigh(&sym, Side::Lower).map_err(BasisError::LinalgError)?;
790 for value in evals.iter_mut() {
791 if *value < 0.0 {
792 *value = 0.0;
793 }
794 }
795 let keep = thin_plate_retained_radial_indices(&evals);
796 Ok((evecs.select(Axis(1), &keep), evals.select(Axis(0), &keep)))
797}
798
799pub(crate) fn thin_plate_radial_reparam_data_metric(
829 omega_constrained: &Array2<f64>,
830 design_gram: &Array2<f64>,
831) -> Result<(Array2<f64>, Array1<f64>), BasisError> {
832 let k = omega_constrained.nrows();
833 if k != omega_constrained.ncols() || design_gram.nrows() != k || design_gram.ncols() != k {
834 crate::bail_dim_basis!(
835 "thin-plate data-metric reparam requires square k×k Ω_c and G_c: Ω_c={:?}, G_c={:?}",
836 omega_constrained.dim(),
837 design_gram.dim()
838 );
839 }
840 if k == 0 {
841 return Ok((Array2::<f64>::zeros((0, 0)), Array1::<f64>::zeros(0)));
842 }
843 let g_sym = symmetrize_penalty(design_gram);
846 let (g_evals, g_evecs) =
847 FaerEigh::eigh(&g_sym, Side::Lower).map_err(BasisError::LinalgError)?;
848 let gmax = g_evals.iter().copied().fold(0.0_f64, |a, b| a.max(b.abs()));
849 if !gmax.is_finite() || gmax <= 0.0 {
850 return thin_plate_radial_reparam_from_constrained_penalty(omega_constrained);
852 }
853 let g_floor = (k as f64) * f64::EPSILON * gmax;
854 let mut cols: Vec<usize> = Vec::with_capacity(k);
855 for j in 0..k {
856 if g_evals[j] > g_floor {
857 cols.push(j);
858 }
859 }
860 let m = cols.len();
861 if m == 0 {
862 return thin_plate_radial_reparam_from_constrained_penalty(omega_constrained);
863 }
864 let mut w = Array2::<f64>::zeros((k, m));
865 for (c, &j) in cols.iter().enumerate() {
866 let inv_sqrt = 1.0 / g_evals[j].sqrt();
867 for i in 0..k {
868 w[[i, c]] = g_evecs[[i, j]] * inv_sqrt;
869 }
870 }
871 let omega_sym = symmetrize_penalty(omega_constrained);
873 let wt_omega = fast_atb(&w, &omega_sym);
874 let m_mat = symmetrize_penalty(&fast_ab(&wt_omega, &w));
875 let (mut mu, p_mat) = FaerEigh::eigh(&m_mat, Side::Lower).map_err(BasisError::LinalgError)?;
876 for value in mu.iter_mut() {
877 if *value < 0.0 {
878 *value = 0.0;
879 }
880 }
881 let v_full = fast_ab(&w, &p_mat); let keep = thin_plate_retained_radial_indices(&mu);
883 Ok((v_full.select(Axis(1), &keep), mu.select(Axis(0), &keep)))
884}
885
886pub(crate) fn thin_plate_radial_reparam_from_centers(
887 centers: ArrayView2<'_, f64>,
888 length_scale: f64,
889 kernel_transform: &Array2<f64>,
890) -> Result<(Array2<f64>, Array1<f64>), BasisError> {
891 let k = centers.nrows();
892 let d = centers.ncols();
893 let mut omega = Array2::<f64>::zeros((k, k));
894 let length_scale_sq = length_scale * length_scale;
895 fill_symmetric_from_row_kernel(&mut omega, |i, j| {
896 let mut dist2 = 0.0;
897 for c in 0..d {
898 let delta = centers[[i, c]] - centers[[j, c]];
899 dist2 += delta * delta;
900 }
901 thin_plate_kernel_from_dist2(dist2 / length_scale_sq, d)
902 })?;
903 let kernel_gauge = gam_problem::Gauge::from_block_transforms(&[kernel_transform.clone()]);
904 let omega_constrained = symmetrize_penalty(&kernel_gauge.restrict_penalty(&omega));
905 thin_plate_radial_reparam_from_constrained_penalty(&omega_constrained)
906}
907
908pub(crate) fn kernel_constraint_nullspace_from_matrix(
909 constraint_matrix: ArrayView2<'_, f64>,
910) -> Result<Array2<f64>, BasisError> {
911 let k = constraint_matrix.nrows();
912 let q = constraint_matrix.ncols();
913 if q == 0 {
914 return Ok(Array2::<f64>::eye(k));
915 }
916 let (z, _) = rrqr_nullspace_basis(&constraint_matrix, default_rrqr_rank_alpha())
919 .map_err(BasisError::LinalgError)?;
920 Ok(z)
921}
922
923pub fn select_thin_plate_knots(
927 data: ArrayView2<f64>,
928 num_knots: usize,
929) -> Result<Array2<f64>, BasisError> {
930 let n = data.nrows();
931 let d = data.ncols();
932 if d == 0 {
933 crate::bail_invalid_basis!("thin-plate spline requires at least one covariate dimension");
934 }
935 if n == 0 {
936 crate::bail_invalid_basis!("cannot select thin-plate knots from empty data");
937 }
938 if data.iter().any(|v| !v.is_finite()) {
939 crate::bail_invalid_basis!("thin-plate spline knot selection requires finite data");
940 }
941 if num_knots == 0 {
942 crate::bail_invalid_basis!("thin-plate spline knot count must be positive");
943 }
944 if num_knots > n {
945 crate::bail_invalid_basis!(
946 "requested {} knots but only {} rows are available",
947 num_knots,
948 n
949 );
950 }
951
952 let centroid: Vec<f64> = (0..d)
989 .map(|c| {
990 let mut col: Vec<f64> = (0..n).map(|i| data[[i, c]]).collect();
991 col.sort_by(|a, b| a.total_cmp(b));
992 let s: f64 = col.iter().sum();
993 s / n as f64
994 })
995 .collect();
996 let dist2_to_centroid: Vec<f64> = (0..n)
997 .into_par_iter()
998 .map(|i| {
999 let mut d2 = 0.0;
1000 for c in 0..d {
1001 let delta = data[[i, c]] - centroid[c];
1002 d2 += delta * delta;
1003 }
1004 d2
1005 })
1006 .collect();
1007
1008 let value_less = |i: usize, j: usize| -> bool {
1016 for c in 0..d {
1017 let vi = data[[i, c]];
1018 let vj = data[[j, c]];
1019 if vi < vj {
1020 return true;
1021 }
1022 if vi > vj {
1023 return false;
1024 }
1025 }
1026 i < j
1029 };
1030
1031 let seed_idx = (0..n)
1035 .into_par_iter()
1036 .map(|i| (i, dist2_to_centroid[i]))
1037 .reduce_with(|a, b| {
1038 if b.1 < a.1 || (b.1 == a.1 && value_less(b.0, a.0)) {
1039 b
1040 } else {
1041 a
1042 }
1043 })
1044 .map(|(i, _)| i)
1045 .unwrap_or(0);
1046
1047 let mut selected = Vec::with_capacity(num_knots);
1048 let mut chosen = vec![false; n];
1049 let mut min_dist2 = vec![f64::INFINITY; n];
1050
1051 selected.push(seed_idx);
1052 chosen[seed_idx] = true;
1053
1054 min_dist2.par_iter_mut().enumerate().for_each(|(i, slot)| {
1055 let mut d2 = 0.0;
1056 for c in 0..d {
1057 let delta = data[[i, c]] - data[[seed_idx, c]];
1058 d2 += delta * delta;
1059 }
1060 *slot = d2;
1061 });
1062 min_dist2[seed_idx] = 0.0;
1063
1064 while selected.len() < num_knots {
1065 let best_idx = min_dist2
1066 .par_iter()
1067 .enumerate()
1068 .filter(|(i, _)| !chosen[*i])
1069 .map(|(i, &cand)| (i, cand))
1070 .reduce_with(|a, b| {
1071 let pick_b = b.1 > a.1
1080 || (b.1 == a.1
1081 && (dist2_to_centroid[b.0] > dist2_to_centroid[a.0]
1082 || (dist2_to_centroid[b.0] == dist2_to_centroid[a.0]
1083 && value_less(b.0, a.0))));
1084 if pick_b { b } else { a }
1085 })
1086 .map(|(i, _)| i);
1087 let next_idx = match best_idx {
1088 Some(i) => i,
1089 None => break,
1090 };
1091 selected.push(next_idx);
1092 chosen[next_idx] = true;
1093
1094 min_dist2.par_iter_mut().enumerate().for_each(|(i, slot)| {
1095 if chosen[i] {
1096 return;
1097 }
1098 let mut d2 = 0.0;
1099 for c in 0..d {
1100 let delta = data[[i, c]] - data[[next_idx, c]];
1101 d2 += delta * delta;
1102 }
1103 if d2 < *slot {
1104 *slot = d2;
1105 }
1106 });
1107 }
1108
1109 let mut knots = Array2::<f64>::zeros((selected.len(), d));
1110 for (r, &idx) in selected.iter().enumerate() {
1111 knots.row_mut(r).assign(&data.row(idx));
1112 }
1113 Ok(knots)
1114}
1115
1116#[inline(always)]
1117pub(crate) fn thin_plate_kernel_from_dist2(
1118 dist2: f64,
1119 dimension: usize,
1120) -> Result<f64, BasisError> {
1121 if !dist2.is_finite() || dist2 < 0.0 {
1122 crate::bail_invalid_basis!("thin-plate kernel distance must be finite and non-negative");
1123 }
1124 if dist2 == 0.0 {
1125 return Ok(0.0);
1126 }
1127 match dimension {
1128 1 => Ok(dist2 * dist2.sqrt()),
1134 2 => Ok(0.5 * dist2 * dist2.ln()),
1135 3 => Ok(-dist2.sqrt()),
1136 _ => {
1137 let m = dimension / 2 + 1;
1141 let r = dist2.sqrt();
1142 Ok(polyharmonic_kernel(r, (m) as f64, dimension))
1143 }
1144 }
1145}
1146
/// Penalty order used for thin-plate splines in the given dimension:
/// `2` for dimensions 1 through 3, and `dimension / 2 + 1` otherwise
/// (including dimension 0).
#[inline(always)]
pub(crate) fn thin_plate_penalty_order(dimension: usize) -> usize {
    if (1..=3).contains(&dimension) {
        2
    } else {
        dimension / 2 + 1
    }
}
1154
1155#[inline(always)]
1159pub(crate) fn d_canonical_tps_infeasible(dimension: usize, num_centers: usize) -> bool {
1160 num_centers < thin_plate_polynomial_basis_dimension(dimension)
1161}
1162
1163pub(crate) fn duchon_thin_plate_fallback_params(
1175 dimension: usize,
1176 num_centers: usize,
1177) -> Option<(DuchonNullspaceOrder, usize)> {
1178 let d = dimension;
1179 let max_op = 2usize; for (order, p, m_poly) in [
1181 (DuchonNullspaceOrder::Linear, 2usize, d + 1),
1182 (DuchonNullspaceOrder::Zero, 1usize, 1usize),
1183 ] {
1184 if num_centers < m_poly {
1185 continue;
1186 }
1187 let target = d + max_op;
1189 let s_min = if 2 * p > target {
1190 0
1191 } else {
1192 (target - 2 * p) / 2 + 1
1193 };
1194 return Some((order, s_min));
1195 }
1196 None
1197}
1198
1199pub(crate) fn hybrid_duchon_promotion_length_scale(
1212 centers: ArrayView2<'_, f64>,
1213 requested_length_scale: f64,
1214) -> f64 {
1215 match pairwise_distance_bounds_sampled(centers) {
1216 Some((r_min, r_max)) => {
1217 (r_min * r_max).sqrt()
1220 }
1221 None => {
1222 if requested_length_scale.is_finite() && requested_length_scale > 0.0 {
1223 requested_length_scale
1224 } else {
1225 1.0
1226 }
1227 }
1228 }
1229}
1230
1231#[inline(always)]
1232pub(crate) fn thin_plate_kernel_triplet_from_scaled_distance(
1233 scaled_distance: f64,
1234 dimension: usize,
1235) -> Result<(f64, f64, f64), BasisError> {
1236 if !scaled_distance.is_finite() || scaled_distance < 0.0 {
1237 crate::bail_invalid_basis!("thin-plate scaled distance must be finite and non-negative");
1238 }
1239 if scaled_distance == 0.0 {
1240 return Ok((0.0, 0.0, 0.0));
1241 }
1242
1243 match dimension {
1244 1 => {
1245 let value = scaled_distance.powi(3);
1246 let first = 3.0 * scaled_distance.powi(2);
1247 let second = 6.0 * scaled_distance;
1248 Ok((value, first, second))
1249 }
1250 2 => {
1251 let log_r = scaled_distance.max(1e-300).ln();
1252 let value = scaled_distance.powi(2) * log_r;
1253 let first = 2.0 * scaled_distance * log_r + scaled_distance;
1254 let second = 2.0 * log_r + 3.0;
1255 Ok((value, first, second))
1256 }
1257 3 => Ok((-scaled_distance, -1.0, 0.0)),
1258 _ => polyharmonic_kernel_triplet(
1259 scaled_distance,
1260 thin_plate_penalty_order(dimension) as f64,
1261 dimension,
1262 ),
1263 }
1264}
1265
1266#[inline(always)]
1267pub(crate) fn thin_plate_kernel_psi_triplet_from_distance(
1268 distance: f64,
1269 length_scale: f64,
1270 dimension: usize,
1271) -> Result<(f64, f64, f64), BasisError> {
1272 if !distance.is_finite() || distance < 0.0 {
1273 crate::bail_invalid_basis!("thin-plate kernel distance must be finite and non-negative");
1274 }
1275 if !length_scale.is_finite() || length_scale <= 0.0 {
1276 crate::bail_invalid_basis!("thin-plate length_scale must be finite and positive");
1277 }
1278
1279 let scaled_distance = distance / length_scale;
1295 let (value, radial_first, radial_second) =
1296 thin_plate_kernel_triplet_from_scaled_distance(scaled_distance, dimension)?;
1297 let psi = radial_first * scaled_distance;
1298 let psi_psi = radial_second * scaled_distance * scaled_distance + psi;
1299 Ok((value, psi, psi_psi))
1300}
1301
1302pub fn create_thin_plate_spline_basis(
1315 data: ArrayView2<f64>,
1316 knots: ArrayView2<f64>,
1317) -> Result<ThinPlateSplineBasis, BasisError> {
1318 let mut workspace = BasisWorkspace::default();
1319 create_thin_plate_spline_basiswithworkspace(data, knots, &mut workspace)
1320}
1321
1322pub fn create_thin_plate_spline_basiswithworkspace(
1323 data: ArrayView2<f64>,
1324 knots: ArrayView2<f64>,
1325 workspace: &mut BasisWorkspace,
1326) -> Result<ThinPlateSplineBasis, BasisError> {
1327 create_thin_plate_spline_basis_scaledwithworkspace(data, knots, 1.0, None, workspace)
1328}
1329
1330pub(crate) fn create_thin_plate_spline_basis_scaledwithworkspace(
1331 data: ArrayView2<f64>,
1332 knots: ArrayView2<f64>,
1333 length_scale: f64,
1334 frozen_radial_reparam: Option<&Array2<f64>>,
1335 workspace: &mut BasisWorkspace,
1336) -> Result<ThinPlateSplineBasis, BasisError> {
1337 let n = data.nrows();
1338 let k = knots.nrows();
1339 let d = data.ncols();
1340
1341 if d == 0 {
1342 crate::bail_invalid_basis!("thin-plate spline requires at least one covariate dimension");
1343 }
1344 if d != knots.ncols() {
1345 crate::bail_dim_basis!(
1346 "thin-plate spline dimension mismatch: data has {} columns, knots have {} columns",
1347 d,
1348 knots.ncols()
1349 );
1350 }
1351 let poly_cols = thin_plate_polynomial_basis_dimension(d);
1352 if k < poly_cols {
1353 crate::bail_invalid_basis!(
1354 "thin-plate spline requires at least {} knots to span the degree-{} polynomial null space in dimension {}; got {}",
1355 poly_cols,
1356 thin_plate_polynomial_degree(d),
1357 d,
1358 k
1359 );
1360 }
1361 if data.iter().any(|v| !v.is_finite()) || knots.iter().any(|v| !v.is_finite()) {
1362 crate::bail_invalid_basis!("thin-plate spline requires finite data and knot values");
1363 }
1364 if !length_scale.is_finite() || length_scale <= 0.0 {
1365 crate::bail_invalid_basis!("thin-plate length_scale must be finite and positive");
1366 }
1367
1368 let knot_mean: Vec<f64> = (0..d)
1384 .map(|c| knots.column(c).sum() / (k.max(1) as f64))
1385 .collect();
1386 let mut data_centered = data.to_owned();
1387 let mut knots_centered = knots.to_owned();
1388 for c in 0..d {
1389 let mu = knot_mean[c];
1390 data_centered.column_mut(c).mapv_inplace(|v| v - mu);
1391 knots_centered.column_mut(c).mapv_inplace(|v| v - mu);
1392 }
1393 let data = data_centered.view();
1394 let knots = knots_centered.view();
1395
1396 let mut kernel_block = Array2::<f64>::zeros((n, k));
1398 let kernel_result: Result<(), BasisError> = kernel_block
1399 .axis_iter_mut(Axis(0))
1400 .into_par_iter()
1401 .enumerate()
1402 .try_for_each(|(i, mut row)| {
1403 for j in 0..k {
1404 let mut dist2 = 0.0;
1405 for c in 0..d {
1406 let delta = data[[i, c]] - knots[[j, c]];
1407 dist2 += delta * delta;
1408 }
1409 row[j] = thin_plate_kernel_from_dist2(dist2 / (length_scale * length_scale), d)?;
1410 }
1411 Ok(())
1412 });
1413 kernel_result?;
1414
1415 let poly_block = thin_plate_polynomial_block(data);
1417
1418 let mut omega = Array2::<f64>::zeros((k, k));
1420 let length_scale_sq = length_scale * length_scale;
1421 fill_symmetric_from_row_kernel(&mut omega, |i, j| {
1422 let mut dist2 = 0.0;
1423 for c in 0..d {
1424 let delta = knots[[i, c]] - knots[[j, c]];
1425 dist2 += delta * delta;
1426 }
1427 thin_plate_kernel_from_dist2(dist2 / length_scale_sq, d)
1428 })?;
1429
1430 let z = thin_plate_kernel_constraint_nullspace(knots, &mut workspace.cache)?;
1433 let kernel_constrained = fast_ab(&kernel_block, &z);
1434 let omega_constrained = {
1435 let zt_o = fast_atb(&z, &omega);
1436 symmetrize_penalty(&fast_ab(&zt_o, &z))
1437 };
1438 let omega_psd = validate_psd_penalty(
1439 &omega_constrained,
1440 &format!("thin_plate bending penalty (dimension={d})"),
1441 "thin-plate kernel and side-constraint assembly must yield a PSD penalty on the constrained subspace",
1442 )?;
1443 assert!(
1444 omega_psd.min_eigenvalue >= -omega_psd.tolerance,
1445 "thin-plate constrained penalty PSD validation violated tolerance after validation: min_eigenvalue={}, tolerance={}",
1446 omega_psd.min_eigenvalue,
1447 omega_psd.tolerance
1448 );
1449 assert!(
1450 omega_psd.max_abs_eigenvalue.is_finite(),
1451 "thin-plate constrained penalty has non-finite max eigenvalue after validation: max_abs_eigenvalue={}",
1452 omega_psd.max_abs_eigenvalue
1453 );
1454 assert!(
1455 omega_psd.effective_rank <= omega_constrained.nrows(),
1456 "thin-plate constrained penalty rank exceeds constrained rows: effective_rank={}, rows={}",
1457 omega_psd.effective_rank,
1458 omega_constrained.nrows()
1459 );
1460
1461 let constrained_kernel_cols = kernel_constrained.ncols();
1462
1463 let (radial_reparam, radial_eigvals): (Array2<f64>, Array1<f64>) = if let Some(frozen) =
1472 frozen_radial_reparam
1473 {
1474 if frozen.nrows() != constrained_kernel_cols {
1475 crate::bail_dim_basis!(
1476 "thin-plate frozen radial reparam shape {:?} does not match constrained radial dimension {}",
1477 frozen.dim(),
1478 constrained_kernel_cols
1479 );
1480 }
1481 let v = frozen.to_owned();
1482 let vt_omega_v = fast_atb(&v, &omega_constrained);
1483 let lambda_diag = fast_ab(&vt_omega_v, &v);
1484 let mut evals = Array1::<f64>::zeros(v.ncols());
1485 for i in 0..v.ncols() {
1486 evals[i] = lambda_diag[[i, i]].max(0.0);
1487 }
1488 (v, evals)
1489 } else if constrained_kernel_cols == 0 {
1490 (Array2::<f64>::zeros((0, 0)), Array1::<f64>::zeros(0))
1491 } else {
1492 let design_gram = symmetrize_penalty(&fast_atb(&kernel_constrained, &kernel_constrained));
1497 thin_plate_radial_reparam_data_metric(&omega_constrained, &design_gram)?
1498 };
1499 let kernel_cols = radial_eigvals.len();
1500 let total_cols = kernel_cols + poly_cols;
1501
1502 let kernel_rotated = if kernel_cols == 0 {
1503 Array2::<f64>::zeros((n, 0))
1504 } else {
1505 fast_ab(&kernel_constrained, &radial_reparam)
1506 };
1507
1508 let mut basis = Array2::<f64>::zeros((n, total_cols));
1509 basis
1510 .slice_mut(s![.., 0..kernel_cols])
1511 .assign(&kernel_rotated);
1512 basis.slice_mut(s![.., kernel_cols..]).assign(&poly_block);
1513
1514 let mut penalty_bending = Array2::<f64>::zeros((total_cols, total_cols));
1515 for i in 0..kernel_cols {
1516 penalty_bending[[i, i]] = radial_eigvals[i];
1517 }
1518 let penalty_ridge = build_nullspace_shrinkage_penalty(&penalty_bending)?
1519 .map(|block| block.sym_penalty)
1520 .unwrap_or_else(|| Array2::<f64>::zeros((total_cols, total_cols)));
1521
1522 Ok(ThinPlateSplineBasis {
1523 basis,
1524 penalty_bending,
1525 penalty_ridge,
1526 num_kernel_basis: kernel_cols,
1527 num_polynomial_basis: poly_cols,
1528 dimension: d,
1529 radial_reparam,
1530 })
1531}
1532
1533pub(crate) fn active_thin_plate_penalty_derivatives(
1534 penaltyinfo: &[PenaltyInfo],
1535 primary_derivative: &Array2<f64>,
1536) -> Result<Vec<Array2<f64>>, BasisError> {
1537 penaltyinfo
1538 .iter()
1539 .filter(|info| info.active)
1540 .map(|info| match &info.source {
1541 PenaltySource::Primary => Ok(primary_derivative.clone()),
1542 PenaltySource::DoublePenaltyNullspace => {
1543 Ok(Array2::<f64>::zeros(primary_derivative.raw_dim()))
1544 }
1545 other => Err(BasisError::InvalidInput(format!(
1546 "unexpected ThinPlate penalty source in psi-derivative path: {other:?}"
1547 ))),
1548 })
1549 .collect()
1550}
1551
/// Builds the first and second ψ-derivatives of the normalized thin-plate
/// penalty, returned as `(S_ψ, S_ψψ)` after projection through the optional
/// `identifiability_transform`.
///
/// Outline of what the body does:
/// 1. Evaluates the kernel value and its two ψ-derivatives for every pair of
///    `centers` (via `thin_plate_kernel_psi_triplet_from_distance`, using
///    `spec.length_scale`), giving three symmetric `k × k` matrices.
/// 2. Sandwiches each of them with the kernel constraint null-space basis
///    `Z` (`Zᵀ · Ω · Z`) and symmetrizes the result.
/// 3. Chooses the radial reparameterization `V` and the diagonal `lambda`:
///    the frozen `spec.radial_reparam` when present, otherwise the retained
///    eigenvectors of the constrained penalty.
/// 4. Forms the raw penalty and its derivative blocks in the `V` coordinates,
///    zero-pads them for the polynomial columns, normalizes them with
///    `normalize_penaltywith_psi_derivatives`, and projects the two
///    derivative matrices with `project_penalty_matrix`.
///
/// NOTE(review): ψ is presumably the log length-scale / log-κ parameter used
/// by the callers of this function — confirm against
/// `thin_plate_kernel_psi_triplet_from_distance`.
///
/// # Errors
/// Propagates errors from the null-space construction and the kernel triplet
/// evaluation, returns a dimension error when the frozen reparameterization's
/// row count differs from the constrained radial dimension, and maps an
/// eigendecomposition failure to `BasisError::LinalgError`.
pub fn build_thin_plate_penalty_psi_derivativeswithworkspace(
    centers: ArrayView2<'_, f64>,
    spec: &ThinPlateBasisSpec,
    identifiability_transform: Option<&Array2<f64>>,
    workspace: &mut BasisWorkspace,
) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
    let z_kernel = thin_plate_kernel_constraint_nullspace(centers, &mut workspace.cache)?;
    let constrained_kernel_cols = z_kernel.ncols();
    let poly_cols = thin_plate_polynomial_basis_dimension(centers.ncols());
    let k = centers.nrows();
    let d = centers.ncols();

    // Kernel matrix between centers and its first / second ψ-derivatives.
    let mut omega = Array2::<f64>::zeros((k, k));
    let mut omega_psi = Array2::<f64>::zeros((k, k));
    let mut omega_psi_psi = Array2::<f64>::zeros((k, k));

    // One lower-triangle entry (i >= j) produced by a parallel tile worker;
    // it is scattered into the three matrices afterwards.
    struct ThinPlatePsiTileEntry {
        pub(crate) i: usize,
        pub(crate) j: usize,
        pub(crate) phi: f64,
        pub(crate) phi_psi: f64,
        pub(crate) phi_psi_psi: f64,
    }

    // Rows are split into tiles of THIN_PLATE_PENALTY_PSI_TILE_ROWS; each tile
    // evaluates its share of the lower triangle independently via
    // `into_par_iter`, and the first kernel error aborts the collection.
    let n_tiles = k.div_ceil(THIN_PLATE_PENALTY_PSI_TILE_ROWS);
    let omega_tiles: Result<Vec<Vec<ThinPlatePsiTileEntry>>, BasisError> = (0..n_tiles)
        .into_par_iter()
        .map(|tile_idx| {
            let row_start = tile_idx * THIN_PLATE_PENALTY_PSI_TILE_ROWS;
            let row_end = (row_start + THIN_PLATE_PENALTY_PSI_TILE_ROWS).min(k);
            // Row i contributes i + 1 pairs (j = 0..=i); reserve exactly that many.
            let tile_pairs = (row_start..row_end).map(|i| i + 1).sum::<usize>();
            let mut entries = Vec::with_capacity(tile_pairs);
            for i in row_start..row_end {
                for j in 0..=i {
                    // Squared Euclidean distance between centers i and j.
                    let mut dist2 = 0.0;
                    for axis in 0..d {
                        let delta = centers[[i, axis]] - centers[[j, axis]];
                        dist2 += delta * delta;
                    }
                    let (phi, phi_psi, phi_psi_psi) = thin_plate_kernel_psi_triplet_from_distance(
                        dist2.sqrt(),
                        spec.length_scale,
                        d,
                    )?;
                    entries.push(ThinPlatePsiTileEntry {
                        i,
                        j,
                        phi,
                        phi_psi,
                        phi_psi_psi,
                    });
                }
            }
            Ok(entries)
        })
        .collect();

    // Scatter the lower-triangle entries and mirror them across the diagonal.
    for tile in omega_tiles? {
        for entry in tile {
            omega[[entry.i, entry.j]] = entry.phi;
            omega_psi[[entry.i, entry.j]] = entry.phi_psi;
            omega_psi_psi[[entry.i, entry.j]] = entry.phi_psi_psi;
            if entry.i != entry.j {
                omega[[entry.j, entry.i]] = entry.phi;
                omega_psi[[entry.j, entry.i]] = entry.phi_psi;
                omega_psi_psi[[entry.j, entry.i]] = entry.phi_psi_psi;
            }
        }
    }

    // Restrict each matrix to the constraint null space: Zᵀ · Ω · Z.
    let m_constrained = symmetrize_penalty(&z_kernel.t().dot(&omega).dot(&z_kernel));
    let m_psi_constrained = symmetrize_penalty(&z_kernel.t().dot(&omega_psi).dot(&z_kernel));
    let m_pp_constrained = symmetrize_penalty(&z_kernel.t().dot(&omega_psi_psi).dot(&z_kernel));

    // Pick the radial reparameterization V and the diagonal penalty entries.
    let (v, lambda) = if let Some(frozen) = spec.radial_reparam.as_ref() {
        // Frozen V: take the diagonal of Vᵀ M V, clamped at zero.
        if frozen.nrows() != constrained_kernel_cols {
            crate::bail_dim_basis!(
                "thin-plate frozen radial reparam shape {:?} does not match constrained radial dimension {}",
                frozen.dim(),
                constrained_kernel_cols
            );
        }
        let v_owned = frozen.to_owned();
        let lambda_diag = fast_ab(&fast_atb(&v_owned, &m_constrained), &v_owned);
        let mut evals = Array1::<f64>::zeros(v_owned.ncols());
        for i in 0..v_owned.ncols() {
            evals[i] = lambda_diag[[i, i]].max(0.0);
        }
        (v_owned, evals)
    } else if constrained_kernel_cols == 0 {
        // Nothing left after the constraint: empty reparameterization.
        (Array2::<f64>::zeros((0, 0)), Array1::<f64>::zeros(0))
    } else {
        // No frozen V: eigendecompose the constrained penalty, clamp negative
        // eigenvalues to zero, and keep only the modes selected by
        // `thin_plate_retained_radial_indices`.
        let (mut evals, evecs) =
            FaerEigh::eigh(&m_constrained, Side::Lower).map_err(BasisError::LinalgError)?;
        for ev in evals.iter_mut() {
            if *ev < 0.0 {
                *ev = 0.0;
            }
        }
        let keep = thin_plate_retained_radial_indices(&evals);
        (evecs.select(Axis(1), &keep), evals.select(Axis(0), &keep))
    };
    let kernel_cols = lambda.len();
    let total_cols = kernel_cols + poly_cols;
    let v_is_frozen = spec.radial_reparam.is_some();

    // Derivative matrices expressed in the V coordinates: Vᵀ · M_ψ · V and
    // Vᵀ · M_ψψ · V.
    let a_psi = if kernel_cols > 0 {
        v.t().dot(&m_psi_constrained).dot(&v)
    } else {
        Array2::<f64>::zeros((0, 0))
    };
    let a_pp = if kernel_cols > 0 {
        v.t().dot(&m_pp_constrained).dot(&v)
    } else {
        Array2::<f64>::zeros((0, 0))
    };

    // The raw kernel penalty is diagonal in the V coordinates.
    let s_raw_kernel = Array2::from_diag(&lambda);
    // First derivative: the full matrix when V is frozen, otherwise only its
    // diagonal.
    let s_raw_psi_kernel = if v_is_frozen {
        a_psi.clone()
    } else {
        let mut diag = Array2::<f64>::zeros((kernel_cols, kernel_cols));
        for i in 0..kernel_cols {
            diag[[i, i]] = a_psi[[i, i]];
        }
        diag
    };
    // Second derivative: the full matrix when V is frozen. Otherwise each
    // diagonal entry is a_pp[i, i] plus 2 · a_psi[i, k]² / (λ_i − λ_k) summed
    // over the other retained modes; pairs whose eigenvalue gap is at most
    // 1e-14 in magnitude are skipped to avoid dividing by ~0.
    // NOTE(review): this has the shape of the usual second-order eigenvalue
    // perturbation term for a moving eigenbasis — confirm the derivation.
    let s_raw_pp_kernel = if v_is_frozen {
        a_pp.clone()
    } else {
        let mut diag = Array2::<f64>::zeros((kernel_cols, kernel_cols));
        for i in 0..kernel_cols {
            let mut acc = a_pp[[i, i]];
            for k_idx in 0..kernel_cols {
                if k_idx == i {
                    continue;
                }
                let denom = lambda[i] - lambda[k_idx];
                if denom.abs() > 1e-14 {
                    acc += 2.0 * a_psi[[i, k_idx]].powi(2) / denom;
                }
            }
            diag[[i, i]] = acc;
        }
        diag
    };

    // Embed a kernel-sized block in the top-left corner of a
    // total_cols × total_cols matrix; the polynomial rows/columns stay zero.
    let pad = |kernel_block: &Array2<f64>| -> Array2<f64> {
        let mut s = Array2::<f64>::zeros((total_cols, total_cols));
        if kernel_cols > 0 {
            s.slice_mut(s![0..kernel_cols, 0..kernel_cols])
                .assign(kernel_block);
        }
        s
    };
    let s_raw = pad(&s_raw_kernel);
    let s_raw_psi = pad(&s_raw_psi_kernel);
    let s_raw_pp = pad(&s_raw_pp_kernel);

    // Only the normalized derivatives are needed here; the normalized penalty
    // itself and the fourth return value are discarded.
    let (_, s_norm_psi, s_norm_pp, _c) =
        normalize_penaltywith_psi_derivatives(&s_raw, &s_raw_psi, &s_raw_pp);

    let s_psi_out = project_penalty_matrix(&s_norm_psi, identifiability_transform);
    let s_psi_psi_out = project_penalty_matrix(&s_norm_pp, identifiability_transform);

    Ok((s_psi_out, s_psi_psi_out))
}
1776
1777pub(crate) fn build_thin_plate_scalar_design_psi_derivatives(
1787 data: ArrayView2<'_, f64>,
1788 centers: ArrayView2<'_, f64>,
1789 spec: &ThinPlateBasisSpec,
1790 identifiability_transform: Option<&Array2<f64>>,
1791 workspace: &mut BasisWorkspace,
1792) -> Result<ScalarDesignPsiDerivatives, BasisError> {
1793 let z_kernel = thin_plate_kernel_constraint_nullspace(centers, &mut workspace.cache)?;
1794 let constrained_kernel_cols = z_kernel.ncols();
1795 let kernel_transform = if let Some(v) = spec.radial_reparam.as_ref() {
1796 if v.nrows() != constrained_kernel_cols {
1797 crate::bail_dim_basis!(
1798 "thin-plate radial reparam shape {:?} does not match constrained radial dimension {}",
1799 v.dim(),
1800 constrained_kernel_cols
1801 );
1802 }
1803 fast_ab(&z_kernel, v)
1804 } else {
1805 z_kernel
1806 };
1807 let kernel_cols = kernel_transform.ncols();
1808 let poly_cols = thin_plate_polynomial_basis_dimension(data.ncols());
1809 let p_after_pad = kernel_cols + poly_cols;
1810 let p_final = identifiability_transform
1811 .map(|zf| zf.ncols())
1812 .unwrap_or(p_after_pad);
1813 build_scalar_design_psi_derivatives_shared(
1814 data,
1815 centers,
1816 None,
1817 p_final,
1818 Some(kernel_transform),
1819 identifiability_transform.cloned(),
1820 poly_cols,
1821 RadialScalarKind::ThinPlate {
1822 length_scale: spec.length_scale,
1823 dim: data.ncols(),
1824 },
1825 0.0,
1826 )
1827}
1828
1829pub fn build_thin_plate_basis_log_kappa_derivative(
1830 data: ArrayView2<'_, f64>,
1831 spec: &ThinPlateBasisSpec,
1832) -> Result<BasisPsiDerivativeResult, BasisError> {
1833 let mut workspace = BasisWorkspace::default();
1834 build_thin_plate_basis_log_kappa_derivativewithworkspace(data, spec, &mut workspace)
1835}
1836
1837pub fn build_thin_plate_basis_log_kappa_derivativewithworkspace(
1838 data: ArrayView2<'_, f64>,
1839 spec: &ThinPlateBasisSpec,
1840 workspace: &mut BasisWorkspace,
1841) -> Result<BasisPsiDerivativeResult, BasisError> {
1842 let mut bundle =
1843 build_thin_plate_basis_log_kappa_derivativeswithworkspace(data, spec, workspace)?;
1844 bundle.first.implicit_operator = bundle.implicit_operator;
1845 Ok(bundle.first)
1846}
1847
1848pub fn build_thin_plate_basis_log_kappa_derivatives(
1849 data: ArrayView2<'_, f64>,
1850 spec: &ThinPlateBasisSpec,
1851) -> Result<BasisPsiDerivativeBundle, BasisError> {
1852 let mut workspace = BasisWorkspace::default();
1853 build_thin_plate_basis_log_kappa_derivativeswithworkspace(data, spec, &mut workspace)
1854}
1855
1856pub fn build_thin_plate_basis_log_kappa_derivativeswithworkspace(
1857 data: ArrayView2<'_, f64>,
1858 spec: &ThinPlateBasisSpec,
1859 workspace: &mut BasisWorkspace,
1860) -> Result<BasisPsiDerivativeBundle, BasisError> {
1861 let base = build_thin_plate_basiswithworkspace(data, spec, workspace)?;
1862 let (centers, identifiability_transform, radial_reparam) = match &base.metadata {
1863 BasisMetadata::ThinPlate {
1864 centers,
1865 identifiability_transform,
1866 radial_reparam,
1867 ..
1868 } => (
1869 centers.clone(),
1870 identifiability_transform.clone(),
1871 radial_reparam.clone(),
1872 ),
1873 _ => {
1874 crate::bail_invalid_basis!("ThinPlate derivative path expected ThinPlate metadata");
1875 }
1876 };
1877 let mut derivative_spec = spec.clone();
1878 if derivative_spec.radial_reparam.is_none() {
1879 derivative_spec.radial_reparam = radial_reparam;
1880 }
1881 let scalar = build_thin_plate_scalar_design_psi_derivatives(
1882 data,
1883 centers.view(),
1884 &derivative_spec,
1885 identifiability_transform.as_ref(),
1886 workspace,
1887 )?;
1888 let (primary_derivative_opt, primarysecond_derivative_opt) =
1889 build_thin_plate_penalty_psi_derivativeswithworkspace(
1890 centers.view(),
1891 &derivative_spec,
1892 identifiability_transform.as_ref(),
1893 workspace,
1894 )?;
1895 let primary_derivative = primary_derivative_opt;
1896 let primarysecond_derivative = primarysecond_derivative_opt;
1897 let penalties_derivative =
1898 active_thin_plate_penalty_derivatives(&base.penaltyinfo, &primary_derivative)?;
1899 let penaltiessecond_derivative =
1900 active_thin_plate_penalty_derivatives(&base.penaltyinfo, &primarysecond_derivative)?;
1901 Ok(BasisPsiDerivativeBundle {
1902 first: BasisPsiDerivativeResult {
1903 design_derivative: scalar.design_first,
1904 penalties_derivative,
1905 implicit_operator: None,
1906 },
1907 second: BasisPsiSecondDerivativeResult {
1908 designsecond_derivative: scalar.design_second_diag,
1909 penaltiessecond_derivative,
1910 implicit_operator: None,
1911 },
1912 implicit_operator: scalar.implicit_operator,
1913 })
1914}
1915
1916pub fn build_thin_plate_basis_log_kappasecond_derivative(
1917 data: ArrayView2<'_, f64>,
1918 spec: &ThinPlateBasisSpec,
1919) -> Result<BasisPsiSecondDerivativeResult, BasisError> {
1920 let mut workspace = BasisWorkspace::default();
1921 build_thin_plate_basis_log_kappasecond_derivativewithworkspace(data, spec, &mut workspace)
1922}
1923
1924pub fn build_thin_plate_basis_log_kappasecond_derivativewithworkspace(
1925 data: ArrayView2<'_, f64>,
1926 spec: &ThinPlateBasisSpec,
1927 workspace: &mut BasisWorkspace,
1928) -> Result<BasisPsiSecondDerivativeResult, BasisError> {
1929 let mut bundle =
1930 build_thin_plate_basis_log_kappa_derivativeswithworkspace(data, spec, workspace)?;
1931 bundle.second.implicit_operator = bundle.implicit_operator;
1932 Ok(bundle.second)
1933}
1934
1935pub fn create_thin_plate_spline_basis_with_knot_count(
1937 data: ArrayView2<f64>,
1938 num_knots: usize,
1939) -> Result<(ThinPlateSplineBasis, Array2<f64>), BasisError> {
1940 let mut workspace = BasisWorkspace::default();
1941 create_thin_plate_spline_basis_with_knot_count_andworkspace(data, num_knots, &mut workspace)
1942}
1943
1944pub fn create_thin_plate_spline_basis_with_knot_count_andworkspace(
1945 data: ArrayView2<f64>,
1946 num_knots: usize,
1947 workspace: &mut BasisWorkspace,
1948) -> Result<(ThinPlateSplineBasis, Array2<f64>), BasisError> {
1949 let knots = select_thin_plate_knots(data, num_knots)?;
1950 let basis = create_thin_plate_spline_basiswithworkspace(data, knots.view(), workspace)?;
1951 Ok((basis, knots))
1952}
1953
1954pub fn apply_sum_to_zero_constraint(
1969 basis_matrix: ArrayView2<f64>,
1970 weights: Option<ArrayView1<f64>>,
1971) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
1972 let n = basis_matrix.nrows();
1973 let k = basis_matrix.ncols();
1974 if k < 2 {
1975 return Err(BasisError::InsufficientColumnsForConstraint { found: k });
1976 }
1977
1978 let constraintvector = match weights {
1980 Some(w) => {
1981 if w.len() != n {
1982 return Err(BasisError::WeightsDimensionMismatch {
1983 expected: n,
1984 found: w.len(),
1985 });
1986 }
1987 w.to_owned()
1988 }
1989 None => Array1::<f64>::ones(n),
1990 };
1991 let c = basis_matrix.t().dot(&constraintvector); let mut c_mat = Array2::<f64>::zeros((k, 1));
1996 c_mat.column_mut(0).assign(&c);
1997 let (z, rank) =
1998 rrqr_nullspace_basis(&c_mat, default_rrqr_rank_alpha()).map_err(BasisError::LinalgError)?;
1999 if rank >= k {
2000 return Err(BasisError::ConstraintNullspaceCollapsed {
2001 site: "apply_sum_to_zero_constraint",
2002 cross_rank: rank,
2003 coeff_dim: k,
2004 cross_frobenius: c.iter().map(|v| v * v).sum::<f64>().sqrt(),
2005 gram_spectrum: "not computed (structural rank collapse before Gram eigendecomposition)"
2006 .to_string(),
2007 });
2008 }
2009 if rank == 0 {
2010 return Ok((basis_matrix.to_owned(), Array2::eye(k)));
2012 }
2013
2014 let gauge = gam_problem::Gauge::sum_to_zero(z);
2015 let constrained = gauge.restrict_design(&basis_matrix);
2016 let z = gauge.block_transform(0);
2017 Ok((constrained, z))
2018}
2019
2020pub fn apply_sum_to_zero_constraint_sparse(
2038 basis_matrix: &SparseColMat<usize, f64>,
2039 weights: Option<ArrayView1<f64>>,
2040) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
2041 let n = basis_matrix.nrows();
2042 let k = basis_matrix.ncols();
2043 if k < 2 {
2044 return Err(BasisError::InsufficientColumnsForConstraint { found: k });
2045 }
2046
2047 let constraint_weights = match weights {
2048 Some(w) => {
2049 if w.len() != n {
2050 return Err(BasisError::WeightsDimensionMismatch {
2051 expected: n,
2052 found: w.len(),
2053 });
2054 }
2055 w.to_owned()
2056 }
2057 None => Array1::<f64>::ones(n),
2058 };
2059
2060 let mut c = Array1::<f64>::zeros(k);
2063 let (symbolic, values) = basis_matrix.parts();
2064 let col_ptr = symbolic.col_ptr();
2065 let row_idx = symbolic.row_idx();
2066 for col in 0..k {
2067 let mut sum = 0.0;
2068 for idx in col_ptr[col]..col_ptr[col + 1] {
2069 sum += values[idx] * constraint_weights[row_idx[idx]];
2070 }
2071 c[col] = sum;
2072 }
2073
2074 let mut c_mat = Array2::<f64>::zeros((k, 1));
2079 c_mat.column_mut(0).assign(&c);
2080 let (z, rank) =
2081 rrqr_nullspace_basis(&c_mat, default_rrqr_rank_alpha()).map_err(BasisError::LinalgError)?;
2082 if rank >= k {
2083 return Err(BasisError::ConstraintNullspaceCollapsed {
2084 site: "apply_sum_to_zero_constraint_sparse",
2085 cross_rank: rank,
2086 coeff_dim: k,
2087 cross_frobenius: c.iter().map(|v| v * v).sum::<f64>().sqrt(),
2088 gram_spectrum: "not computed (structural rank collapse before Gram eigendecomposition)"
2089 .to_string(),
2090 });
2091 }
2092 if rank == 0 {
2093 let mut dense_b = Array2::<f64>::zeros((n, k));
2097 for col in 0..k {
2098 for idx in col_ptr[col]..col_ptr[col + 1] {
2099 dense_b[[row_idx[idx], col]] = values[idx];
2100 }
2101 }
2102 return Ok((dense_b, Array2::eye(k)));
2103 }
2104
2105 let kc = z.ncols();
2109 let mut constrained = Array2::<f64>::zeros((n, kc));
2110 for out_col in 0..kc {
2111 let z_col = z.column(out_col);
2112 let mut dst = constrained.column_mut(out_col);
2113 for src_col in 0..k {
2114 let coeff = z_col[src_col];
2115 if coeff == 0.0 {
2116 continue;
2117 }
2118 for idx in col_ptr[src_col]..col_ptr[src_col + 1] {
2119 dst[row_idx[idx]] += coeff * values[idx];
2120 }
2121 }
2122 }
2123
2124 Ok((constrained, z))
2125}
2126
2127pub fn applyweighted_orthogonality_constraint(
2149 basis_matrix: ArrayView2<f64>,
2150 constraint_matrix: ArrayView2<f64>,
2151 weights: Option<ArrayView1<f64>>,
2152) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
2153 let n = basis_matrix.nrows();
2154 let k = basis_matrix.ncols();
2155 if constraint_matrix.nrows() != n {
2156 return Err(BasisError::ConstraintMatrixRowMismatch {
2157 basisrows: n,
2158 constraintrows: constraint_matrix.nrows(),
2159 });
2160 }
2161 if k == 0 {
2162 return Err(BasisError::InsufficientColumnsForConstraint { found: 0 });
2163 }
2164 let q = constraint_matrix.ncols();
2165 if q == 0 {
2166 return Ok((basis_matrix.to_owned(), Array2::eye(k)));
2167 }
2168
2169 let mut weighted_constraints = constraint_matrix.to_owned();
2171 if let Some(w) = weights {
2172 if w.len() != n {
2173 return Err(BasisError::WeightsDimensionMismatch {
2174 expected: n,
2175 found: w.len(),
2176 });
2177 }
2178 for (mut row, &weight) in weighted_constraints.axis_iter_mut(Axis(0)).zip(w.iter()) {
2179 row *= weight;
2180 }
2181 }
2182
2183 let constraint_cross = basis_matrix.t().dot(&weighted_constraints); let gram = fast_ata(&basis_matrix);
2187 let transform = orthogonality_transform_from_cross_and_gram(&constraint_cross, &gram)?;
2188 let basis_orthonormal = fast_ab(&basis_matrix, &transform);
2189 Ok((basis_orthonormal, transform))
2190}
2191
2192pub fn compute_greville_abscissae(
2213 knot_vector: &Array1<f64>,
2214 degree: usize,
2215) -> Result<Array1<f64>, BasisError> {
2216 let n_knots = knot_vector.len();
2217 if degree == 0 {
2218 let n_basis = n_knots.saturating_sub(1);
2220 if n_basis == 0 {
2221 return Err(BasisError::InsufficientColumnsForConstraint { found: 0 });
2222 }
2223 let mut g = Array1::<f64>::zeros(n_basis);
2224 for j in 0..n_basis {
2225 g[j] = 0.5 * (knot_vector[j] + knot_vector[j + 1]);
2226 }
2227 return Ok(g);
2228 }
2229
2230 if n_knots <= degree + 1 {
2232 return Err(BasisError::InsufficientColumnsForConstraint {
2233 found: n_knots.saturating_sub(degree + 1),
2234 });
2235 }
2236 let n_basis = n_knots - degree - 1;
2237
2238 let mut g = Array1::<f64>::zeros(n_basis);
2239 let d_inv = 1.0 / (degree as f64);
2240
2241 for j in 0..n_basis {
2242 let mut sum = 0.0;
2244 for k in 1..=degree {
2245 sum += knot_vector[j + k];
2246 }
2247 g[j] = sum * d_inv;
2248 }
2249
2250 let g_min = g.iter().cloned().fold(f64::INFINITY, f64::min);
2252 let g_max = g.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2253 if (g_max - g_min) < 1e-10 {
2254 return Err(BasisError::DegenerateKnots);
2255 }
2256
2257 Ok(g)
2258}
2259
2260pub fn compute_geometric_constraint_transform(
2285 knot_vector: &Array1<f64>,
2286 degree: usize,
2287 penalty_order: usize,
2288) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
2289 let g = compute_greville_abscissae(knot_vector, degree)?;
2291 let k = g.len();
2292
2293 if k < 3 {
2294 return Err(BasisError::InsufficientColumnsForConstraint { found: k });
2295 }
2296
2297 let mut c_geom = Array2::<f64>::zeros((2, k));
2301 for j in 0..k {
2302 c_geom[[0, j]] = 1.0;
2303 c_geom[[1, j]] = g[j];
2304 }
2305
2306 let g_mean = g.mean().unwrap_or(0.0);
2308 let gvar = g.iter().map(|&x| (x - g_mean).powi(2)).sum::<f64>() / (k as f64);
2309 let g_std = gvar.sqrt().max(1e-10);
2310 for j in 0..k {
2311 c_geom[[1, j]] = (c_geom[[1, j]] - g_mean) / g_std;
2312 }
2313
2314 let (z, rank) = rrqr_nullspace_basis(&c_geom.t(), default_rrqr_rank_alpha())
2316 .map_err(BasisError::LinalgError)?;
2317 if rank >= k {
2318 return Err(BasisError::ConstraintNullspaceCollapsed {
2319 site: "compute_geometric_constraint_transform",
2320 cross_rank: rank,
2321 coeff_dim: k,
2322 cross_frobenius: f64::NAN,
2323 gram_spectrum: "not computed (structural rank collapse before Gram eigendecomposition)"
2324 .to_string(),
2325 });
2326 }
2327
2328 if z.ncols() == 0 {
2329 return Err(BasisError::ConstraintNullspaceCollapsed {
2330 site: "compute_geometric_constraint_transform",
2331 cross_rank: 0,
2332 coeff_dim: k,
2333 cross_frobenius: f64::NAN,
2334 gram_spectrum: "not computed (structural rank collapse before Gram eigendecomposition)"
2335 .to_string(),
2336 });
2337 }
2338
2339 let s_raw = create_difference_penalty_matrix(k, penalty_order, Some(g.view()))?;
2341 let s_constrained = {
2342 let zt_s = fast_atb(&z, &s_raw);
2343 fast_ab(&zt_s, &z)
2344 };
2345
2346 Ok((z, s_constrained))
2347}
2348
/// Result of automatic B-spline knot placement, as returned by
/// `auto_knot_vector_1d_quantile`.
#[derive(Debug, Clone)]
pub struct AutoBSplineKnots {
    /// Knot vector generated for the effective knot count and degree below.
    pub knots: Array1<f64>,
    /// Effective spline degree used to generate `knots`; it is the value
    /// chosen by `auto_shrink_bspline_config`, not necessarily the requested one.
    pub degree: usize,
    /// Effective number of internal knots used to generate `knots`; likewise
    /// chosen by `auto_shrink_bspline_config`.
    pub num_internal_knots: usize,
    /// Flag reported by `auto_shrink_bspline_config`.
    /// NOTE(review): presumably `true` when the requested configuration had to
    /// be reduced — confirm against that helper.
    pub shrunk: bool,
}
2367
2368pub fn auto_knot_vector_1d_quantile(
2384 data: ArrayView1<'_, f64>,
2385 num_internal_knots: usize,
2386 degree: usize,
2387) -> Result<AutoBSplineKnots, BasisError> {
2388 let n = data.len();
2389 let Some((eff_knots, eff_degree, shrunk)) =
2390 auto_shrink_bspline_config(n, num_internal_knots, degree)
2391 else {
2392 crate::bail_invalid_basis!(
2393 "auto-knot placement needs at least 2 finite evaluation points (got n={n}); \
2394 cannot fit even a linear B-spline",
2395 );
2396 };
2397 let knots = internal::generate_full_knot_vector_quantile(data, eff_knots, eff_degree)?;
2398 Ok(AutoBSplineKnots {
2399 knots,
2400 degree: eff_degree,
2401 num_internal_knots: eff_knots,
2402 shrunk,
2403 })
2404}
2405
2406pub fn clamped_knot_vector_from_internal_positions(
2422 data_range: (f64, f64),
2423 internal_positions: &[f64],
2424 degree: usize,
2425) -> Result<Array1<f64>, BasisError> {
2426 let (minval, maxval) = data_range;
2427 if !(minval.is_finite() && maxval.is_finite()) {
2428 crate::bail_invalid_basis!(
2429 "explicit knots require a finite data range, got ({minval:.6e}, {maxval:.6e})"
2430 );
2431 }
2432 if minval >= maxval {
2433 return Err(BasisError::InvalidRange(minval, maxval));
2434 }
2435 let scale = (maxval - minval).abs().max(1.0);
2436 let tol = 1e-12 * scale;
2437
2438 let mut interior: Vec<f64> = Vec::with_capacity(internal_positions.len());
2439 for &k in internal_positions {
2440 if !k.is_finite() {
2441 crate::bail_invalid_basis!("explicit knot position {k:.6e} is not finite");
2442 }
2443 if k <= minval + tol || k >= maxval - tol {
2444 crate::bail_invalid_basis!(
2445 "explicit internal knot {k:.6e} must lie strictly inside the data range \
2446 ({minval:.6e}, {maxval:.6e}); boundary knots are added automatically"
2447 );
2448 }
2449 interior.push(k);
2450 }
2451 interior.sort_by(f64::total_cmp);
2452 for w in interior.windows(2) {
2453 if (w[1] - w[0]).abs() <= tol {
2454 crate::bail_invalid_basis!(
2455 "explicit internal knots must be strictly increasing; \
2456 found a duplicate/near-duplicate near {:.6e}",
2457 w[0]
2458 );
2459 }
2460 }
2461
2462 let total_knots = interior.len() + 2 * (degree + 1);
2463 let mut knots = Vec::with_capacity(total_knots);
2464 for _ in 0..=degree {
2465 knots.push(minval);
2466 }
2467 knots.extend_from_slice(&interior);
2468 for _ in 0..=degree {
2469 knots.push(maxval);
2470 }
2471 Ok(Array::from_vec(knots))
2472}
2473
2474pub fn auto_centers_1d_equal_mass(
2479 data: ArrayView1<'_, f64>,
2480 num_centers: usize,
2481) -> Result<Array1<f64>, BasisError> {
2482 let column = data.to_owned().insert_axis(Axis(1));
2483 let centers = select_equal_mass_centers(column.view(), num_centers)?;
2484 let mut flat: Vec<f64> = centers.column(0).iter().copied().collect();
2485 flat.sort_by(f64::total_cmp);
2486 Ok(Array1::from_vec(flat))
2487}
2488
#[cfg(test)]
mod knot_selection_invariance_tests {
    use super::select_thin_plate_knots;
    use ndarray::Array2;

    /// Fixed cloud of twelve 2-D points shared by the tests in this module.
    fn sample_cloud() -> Array2<f64> {
        const POINTS: [[f64; 2]; 12] = [
            [0.10, 0.20],
            [1.30, 0.05],
            [2.10, 1.40],
            [0.40, 2.30],
            [1.90, 2.80],
            [3.20, 0.70],
            [2.70, 3.10],
            [0.90, 1.10],
            [3.50, 2.20],
            [1.60, 3.60],
            [0.05, 3.05],
            [2.40, 0.30],
        ];
        Array2::from_shape_fn((POINTS.len(), 2), |(row, col)| POINTS[row][col])
    }

    /// Order-independent fingerprint of a knot set: the bit patterns of each
    /// row's two coordinates, sorted.
    fn canonical(knots: &Array2<f64>) -> Vec<(u64, u64)> {
        let mut fingerprint = Vec::with_capacity(knots.nrows());
        for r in 0..knots.nrows() {
            fingerprint.push((knots[[r, 0]].to_bits(), knots[[r, 1]].to_bits()));
        }
        fingerprint.sort_unstable();
        fingerprint
    }

    /// Mean of each of the two coordinate columns.
    fn data_centroid_2d(data: &Array2<f64>) -> (f64, f64) {
        let n = data.nrows();
        let column_mean = |col: usize| (0..n).map(|i| data[[i, col]]).sum::<f64>() / n as f64;
        (column_mean(0), column_mean(1))
    }

    /// Maps every point `(x, z)` to `(cx - (z - cz), cz + (x - cx))`, i.e. a
    /// quarter turn about `(cx, cz)`.
    fn rotate_90_about(data: &Array2<f64>, cx: f64, cz: f64) -> Array2<f64> {
        let mut rotated = Array2::<f64>::zeros((data.nrows(), 2));
        for (src, mut dst) in data.outer_iter().zip(rotated.outer_iter_mut()) {
            let dx = src[0] - cx;
            let dz = src[1] - cz;
            dst[0] = cx - dz;
            dst[1] = cz + dx;
        }
        rotated
    }

    #[test]
    fn knot_set_is_rotation_invariant_gh1456() {
        let cloud = sample_cloud();
        let num_knots = 5;
        assert!(
            num_knots < cloud.nrows(),
            "must exercise the farthest-point selector"
        );

        let knots = select_thin_plate_knots(cloud.view(), num_knots).expect("select knots");
        assert_eq!(knots.nrows(), num_knots);

        // Rotate the data about its own centroid, then select again.
        let (cx, cz) = data_centroid_2d(&cloud);
        let rotated_cloud = rotate_90_about(&cloud, cx, cz);
        let knots_rot =
            select_thin_plate_knots(rotated_cloud.view(), num_knots).expect("select rotated");

        // Selecting first and rotating afterwards must give the same set.
        let knots_then_rotate = rotate_90_about(&knots, cx, cz);
        assert_eq!(
            canonical(&knots_then_rotate),
            canonical(&knots_rot),
            "rotating-then-selecting must equal selecting-then-rotating (gh#1456); \
             the OLD lexicographic seed picks a different physical point after rotation"
        );
    }

    #[test]
    fn knot_set_is_row_permutation_invariant_gh1378() {
        let cloud = sample_cloud();
        let n = cloud.nrows();
        let num_knots = 5;
        assert!(num_knots < n, "must exercise the farthest-point selector");

        let knots = select_thin_plate_knots(cloud.view(), num_knots).expect("select knots");

        // Same points, different row order.
        let perm: Vec<usize> = vec![7, 0, 11, 3, 9, 1, 5, 10, 2, 8, 4, 6];
        assert_eq!(perm.len(), n);
        let permuted = Array2::from_shape_fn((n, 2), |(new_row, col)| cloud[[perm[new_row], col]]);

        let knots_perm =
            select_thin_plate_knots(permuted.view(), num_knots).expect("select permuted");

        assert_eq!(
            canonical(&knots),
            canonical(&knots_perm),
            "reordering rows must not change the selected knot set (gh#1378)"
        );
    }
}
2640
#[cfg(test)]
mod retained_radial_indices_tests {
    use super::thin_plate_retained_radial_indices;
    use ndarray::Array1;

    /// Number of modes the selector keeps for the given spectrum.
    fn retained_count(spectrum: &Array1<f64>) -> usize {
        thin_plate_retained_radial_indices(spectrum).len()
    }

    #[test]
    fn linear_data_spectrum_keeps_every_mode() {
        let spectrum = Array1::from_vec(vec![
            885.4, 119.98, 26.287, 10.030, 5.066, 2.330, 1.3953, 0.67709, 0.46814, 0.34210,
            0.26488, 0.17895, 0.14514,
        ]);
        assert_eq!(
            retained_count(&spectrum),
            spectrum.len(),
            "all numerically real modes must be retained"
        );
    }

    #[test]
    fn lidar_spectrum_keeps_every_real_mode() {
        let spectrum = Array1::from_vec(vec![
            1212.2, 144.94, 37.270, 15.529, 6.0768, 3.5845, 1.8094, 1.1058, 0.73002, 0.43701,
            0.33814, 0.23136, 0.18267, 0.15702, 0.13654, 0.044936, 0.041844, 0.038235,
        ]);
        assert_eq!(
            retained_count(&spectrum),
            spectrum.len(),
            "every above-floor mode is kept"
        );
    }

    #[test]
    fn pure_roundoff_modes_are_dropped() {
        let big = 1.0e3;
        // Tiny trailing eigenvalue, scaled relative to the largest one.
        let dust = 0.1 * 5.0 * f64::EPSILON * big;
        let spectrum = Array1::from_vec(vec![big, 100.0, 10.0, 1.0, dust]);
        let retained = thin_plate_retained_radial_indices(&spectrum);
        assert_eq!(
            retained.len(),
            4,
            "the sub-floor roundoff mode must be pruned"
        );
        assert!(!retained.contains(&4));
    }

    #[test]
    fn empty_and_singleton_spectra_are_handled() {
        let nothing = thin_plate_retained_radial_indices(&Array1::from_vec(vec![]));
        assert!(nothing.is_empty());

        let single = thin_plate_retained_radial_indices(&Array1::from_vec(vec![5.0]));
        assert_eq!(single, vec![0]);
    }
}