1use coefficient_transforms::{
2 convex_divided_difference_transform_matrix, cumulative_exp, cumulative_sum_transform_matrix,
3 second_cumulative_exp,
4};
5
6pub use error::SmoothError;
7
8use input_standardization::{
9 apply_input_standardization, compensate_length_scale_for_standardization,
10 compensate_optional_length_scale_for_standardization, compute_spatial_input_scales,
11};
12
13use shape_constraints::{
14 build_shape_constraint_design_1d, build_shape_linear_constraints_1d,
15 merge_linear_constraints_global, shape_lower_bounds_local, shape_order_and_sign,
16 shape_supports_basis, shape_uses_box_reparameterization,
17};
18
19pub fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
20 match strategy {
21 CenterStrategy::Auto(inner) => describe_thin_plate_center_request(inner),
22 CenterStrategy::UserProvided(centers) => format!("{} centers", centers.nrows()),
23 CenterStrategy::EqualMass { num_centers }
24 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
25 | CenterStrategy::FarthestPoint { num_centers }
26 | CenterStrategy::KMeans { num_centers, .. } => format!("{num_centers} centers"),
27 CenterStrategy::UniformGrid { points_per_dim } => {
28 format!("uniform grid with {points_per_dim} points per dimension")
29 }
30 }
31}
32
33pub fn rewrite_thin_plate_knots_error(
34 err: BasisError,
35 termname: &str,
36 feature_count: usize,
37 spec: &ThinPlateBasisSpec,
38) -> BasisError {
39 match err {
40 BasisError::InvalidInput(msg)
43 if msg.contains("thin-plate spline requires at least")
44 && (msg.contains("centers to span") || msg.contains("knots to span")) =>
45 {
46 let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
47 let requested = describe_thin_plate_center_request(&spec.center_strategy);
48 BasisError::InvalidInput(format!(
49 "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
50 ))
51 }
52 BasisError::InvalidInput(msg)
57 if msg.starts_with("requested ") && msg.contains(" knots but only ") =>
58 {
59 let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
60 let requested = describe_thin_plate_center_request(&spec.center_strategy);
61 BasisError::InvalidInput(format!(
62 "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
63 ))
64 }
65 other => other,
66 }
67}
68
/// Shape restriction that can be attached to a smooth term.
///
/// The textual spellings accepted for each variant are defined by
/// `parse_shape_constraint`; the canonical spelling is given by
/// `ShapeConstraint::dsl_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShapeConstraint {
    /// No shape restriction.
    None,
    /// Canonical DSL spelling: `monotone_increasing`.
    MonotoneIncreasing,
    /// Canonical DSL spelling: `monotone_decreasing`.
    MonotoneDecreasing,
    /// Canonical DSL spelling: `convex`.
    Convex,
    /// Canonical DSL spelling: `concave`.
    Concave,
}
77
78pub fn parse_shape_constraint(raw: &str) -> Result<ShapeConstraint, String> {
89 let normalized = raw.trim().to_ascii_lowercase().replace('-', "_");
90 match normalized.as_str() {
91 "" | "none" => Ok(ShapeConstraint::None),
92 "monotone_increasing" | "monotonic_increasing" | "increasing" | "mono_inc" | "mpi" => {
93 Ok(ShapeConstraint::MonotoneIncreasing)
94 }
95 "monotone_decreasing" | "monotonic_decreasing" | "decreasing" | "mono_dec" | "mpd" => {
96 Ok(ShapeConstraint::MonotoneDecreasing)
97 }
98 "convex" | "cvx" => Ok(ShapeConstraint::Convex),
99 "concave" | "ccv" => Ok(ShapeConstraint::Concave),
100 other => Err(format!(
101 "unknown shape constraint {other:?}; expected one of \
102 \"none\", \"monotone_increasing\", \"monotone_decreasing\", \
103 \"convex\", \"concave\""
104 )),
105 }
106}
107
108impl ShapeConstraint {
109 pub fn dsl_str(&self) -> &'static str {
112 match self {
113 ShapeConstraint::None => "none",
114 ShapeConstraint::MonotoneIncreasing => "monotone_increasing",
115 ShapeConstraint::MonotoneDecreasing => "monotone_decreasing",
116 ShapeConstraint::Convex => "convex",
117 ShapeConstraint::Concave => "concave",
118 }
119 }
120}
121
/// Function-call heads that `apply_shape_constraints_to_formula` treats as the
/// start of a smooth term when they appear at an identifier boundary and are
/// followed (after optional whitespace) by `(`.
///
/// The scanner tries the entries in array order at each position.
pub const SMOOTH_HEAD_KEYWORDS: [&str; 11] = [
    "s",
    "smooth",
    "te",
    "tensor",
    "thinplate",
    "tps",
    "duchon",
    "matern",
    "sphere",
    "bs",
    "bspline",
];
137
138pub fn apply_shape_constraints_to_formula(
151 formula: &str,
152 constraints: &[(String, String)],
153) -> Result<String, String> {
154 use std::collections::{BTreeMap, BTreeSet};
155
156 if constraints.is_empty() {
157 return Ok(formula.to_string());
158 }
159 let strip_ws = |s: &str| -> String { s.chars().filter(|c| !c.is_whitespace()).collect() };
160
161 let mut wanted: BTreeMap<String, &'static str> = BTreeMap::new();
163 let mut originals: BTreeMap<String, String> = BTreeMap::new();
165 for (key, kind_raw) in constraints {
166 let kind = parse_shape_constraint(kind_raw)?;
167 let nk = strip_ws(key);
168 originals.entry(nk.clone()).or_insert_with(|| key.clone());
169 if kind != ShapeConstraint::None {
170 wanted.insert(nk, kind.dsl_str());
171 }
172 }
173 if wanted.is_empty() {
174 return Ok(formula.to_string());
175 }
176
177 let chars: Vec<char> = formula.chars().collect();
178 let n = chars.len();
179 let is_ident = |c: char| c.is_ascii_alphanumeric() || c == '_';
180
181 let mut out = String::with_capacity(formula.len() + 32);
182 let mut matched: BTreeSet<String> = BTreeSet::new();
183 let mut i = 0usize;
184 while i < n {
185 let mut head: Option<(usize, usize)> = None; let mut p = i;
189 while p < n {
190 let boundary = p == 0 || !is_ident(chars[p - 1]);
191 if boundary {
192 for kw in SMOOTH_HEAD_KEYWORDS.iter() {
193 let klen = kw.chars().count();
194 if p + klen > n || chars[p..p + klen].iter().collect::<String>() != **kw {
195 continue;
196 }
197 let mut q = p + klen;
198 while q < n && chars[q].is_whitespace() {
199 q += 1;
200 }
201 if q < n && chars[q] == '(' {
202 head = Some((p, q));
203 break;
204 }
205 }
206 }
207 if head.is_some() {
208 break;
209 }
210 p += 1;
211 }
212 let (head_start, paren_open) = match head {
213 Some(h) => h,
214 None => {
215 out.extend(chars[i..].iter());
216 break;
217 }
218 };
219 out.extend(chars[i..head_start].iter());
220
221 let body_start = paren_open + 1;
223 let mut depth = 1i32;
224 let mut j = body_start;
225 let mut in_str: Option<char> = None;
226 let mut closed = false;
227 while j < n {
228 let ch = chars[j];
229 if let Some(quote) = in_str {
230 if ch == quote {
231 in_str = None;
232 }
233 } else if ch == '\'' || ch == '"' {
234 in_str = Some(ch);
235 } else if ch == '(' {
236 depth += 1;
237 } else if ch == ')' {
238 depth -= 1;
239 if depth == 0 {
240 closed = true;
241 break;
242 }
243 }
244 j += 1;
245 }
246
247 if !closed {
248 out.extend(chars[head_start..].iter());
251 break;
252 }
253
254 let term_text: String = chars[head_start..=j].iter().collect();
255
256 let key_norm = strip_ws(&term_text);
257
258 match wanted.get(&key_norm) {
259 None => out.extend(chars[head_start..=j].iter()),
260 Some(kind) => {
261 let head_paren: String = chars[head_start..body_start].iter().collect();
262 let inside: String = chars[body_start..j].iter().collect();
263 let inside = inside.trim();
264 if inside.is_empty() {
265 out.push_str(&format!("{head_paren}shape={kind})"));
266 } else {
267 out.push_str(&format!("{head_paren}{inside}, shape={kind})"));
268 }
269 matched.insert(key_norm);
270 }
271 }
272
273 i = j + 1;
274 }
275
276 let mut missing: Vec<String> = wanted
277 .keys()
278 .filter(|k| !matched.contains(*k))
279 .map(|k| originals.get(k).cloned().unwrap_or_else(|| k.clone()))
280 .collect();
281
282 if !missing.is_empty() {
283 missing.sort();
284 return Err(format!(
285 "shape constraints referenced smooth term(s) not found in formula: {}",
286 missing.join(", ")
287 ));
288 }
289
290 Ok(out)
291}
292
/// Kind tag stored in the `kind` field of `SmoothBasisSpec::ByVariable`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BySmoothKind {
    /// The `by` variable is numeric.
    Numeric,
    /// The `by` variable selects a single level, carried as a `u64`.
    // NOTE(review): `level_bits` is presumably the `f64::to_bits` encoding of
    // the level value — confirm against the code that builds this variant.
    Level { level_bits: u64 },
}
298
/// Serializable description of the basis used by one smooth term.
///
/// Wrapper variants (`ByVariable`, `FactorSumToZero`, `BySmooth`) box another
/// `SmoothBasisSpec`; the remaining variants describe a concrete basis family
/// together with the data column(s) it reads.
///
/// Fields marked `#[serde(default)]` may be absent from serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SmoothBasisSpec {
    /// An inner smooth combined with a `by` variable read from column `by_col`.
    ByVariable {
        inner: Box<SmoothBasisSpec>,
        by_col: usize,
        kind: BySmoothKind,
        by: ByVariableSpec,
    },
    /// An inner smooth combined with the factor in column `by_col`, with an
    /// explicit list of levels.
    // NOTE(review): the name suggests a sum-to-zero contrast across levels —
    // confirm against the design builder.
    FactorSumToZero {
        inner: Box<SmoothBasisSpec>,
        by_col: usize,
        levels: Vec<u64>,
        #[serde(default)]
        frozen_global_orthogonality: Option<Array2<f64>>,
    },
    /// Univariate B-spline basis on a single column.
    BSpline1D {
        feature_col: usize,
        spec: BSplineBasisSpec,
    },
    /// A smooth paired with a `ByVarKind` describing its `by` variable.
    BySmooth {
        smooth: Box<SmoothBasisSpec>,
        by_kind: ByVarKind,
    },
    /// Factor-smooth interaction; all configuration lives in `FactorSmoothSpec`.
    FactorSmooth { spec: FactorSmoothSpec },
    /// Thin-plate basis over `feature_cols`.
    ThinPlate {
        feature_cols: Vec<usize>,
        spec: ThinPlateBasisSpec,
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Spherical spline basis over `feature_cols`.
    Sphere {
        feature_cols: Vec<usize>,
        spec: SphericalSplineBasisSpec,
    },
    /// Constant-curvature basis over `feature_cols`.
    ConstantCurvature {
        feature_cols: Vec<usize>,
        spec: ConstantCurvatureBasisSpec,
    },
    /// Matern basis over `feature_cols`.
    Matern {
        feature_cols: Vec<usize>,
        spec: MaternBasisSpec,
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// MeasureJet basis over `feature_cols`.
    MeasureJet {
        feature_cols: Vec<usize>,
        spec: MeasureJetBasisSpec,
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Duchon basis over `feature_cols`.
    Duchon {
        feature_cols: Vec<usize>,
        spec: DuchonBasisSpec,
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Basis defined by an explicit `basis_matrix` over `feature_cols`.
    ///
    /// `smooth_penalty` defaults to `default_pca_smooth_penalty()` and
    /// `chunk_size` to `default_pca_chunk_size()` when absent from serialized
    /// input.
    Pca {
        feature_cols: Vec<usize>,
        basis_matrix: Array2<f64>,
        centered: bool,
        #[serde(default = "default_pca_smooth_penalty")]
        smooth_penalty: f64,
        #[serde(default)]
        center_mean: Option<Array1<f64>>,
        #[serde(default)]
        pca_basis_path: Option<PathBuf>,
        #[serde(default = "default_pca_chunk_size")]
        chunk_size: usize,
    },
    /// Tensor-product B-spline basis over `feature_cols`.
    TensorBSpline {
        feature_cols: Vec<usize>,
        spec: TensorBSplineSpec,
    },
}
416
417impl SmoothBasisSpec {
418 pub fn min_sample_rows(&self) -> usize {
435 const RADIAL_FLOOR: usize = 5;
440
441 match self {
442 Self::ByVariable { inner, .. } => inner.min_sample_rows(),
443 Self::FactorSumToZero { inner, levels, .. } => {
444 let inner_min = inner.min_sample_rows();
448 let lvls = levels.len().saturating_sub(1).max(1);
449 inner_min.saturating_mul(lvls)
450 }
451 Self::BSpline1D { spec, .. } => bspline_basis_min_rows(spec),
452 Self::BySmooth { smooth, .. } => smooth.min_sample_rows(),
453 Self::FactorSmooth { spec } => {
454 bspline_basis_min_rows(&spec.marginal)
458 }
459 Self::ThinPlate { .. }
460 | Self::Sphere { .. }
461 | Self::ConstantCurvature { .. }
462 | Self::Matern { .. }
463 | Self::MeasureJet { .. }
464 | Self::Duchon { .. } => RADIAL_FLOOR,
465 Self::Pca { basis_matrix, .. } => basis_matrix.ncols().max(1),
466 Self::TensorBSpline { spec, .. } => {
467 let mut total: usize = 0;
513 for marginal in &spec.marginalspecs {
514 let m = bspline_basis_min_rows(marginal);
515 total = total.saturating_add(m.max(1));
516 }
517 total.max(RADIAL_FLOOR)
518 }
519 }
520 }
521
522 pub fn structural_kind(&self) -> &'static str {
533 match self {
534 Self::ByVariable { .. } => "by_variable",
535 Self::FactorSumToZero { .. } => "factor_sum_to_zero",
536 Self::BSpline1D { .. } => "bspline_1d",
537 Self::BySmooth { .. } => "by_smooth",
538 Self::FactorSmooth { .. } => "factor_smooth",
539 Self::ThinPlate { .. } => "thin_plate",
540 Self::Sphere { .. } => "sphere",
541 Self::ConstantCurvature { .. } => "constant_curvature",
542 Self::Matern { .. } => "matern",
543 Self::MeasureJet { .. } => "measurejet",
544 Self::Duchon { .. } => "duchon",
545 Self::Pca { .. } => "pca",
546 Self::TensorBSpline { .. } => "tensor_bspline",
547 }
548 }
549
550 pub fn is_marginally_centered_tensor(&self) -> bool {
559 matches!(
560 self,
561 Self::TensorBSpline { spec, .. }
562 if matches!(spec.identifiability, TensorBSplineIdentifiability::MarginalSumToZero)
563 )
564 }
565
566 pub fn structural_feature_cols(&self) -> Vec<usize> {
570 match self {
571 Self::ByVariable { inner, .. } | Self::FactorSumToZero { inner, .. } => {
572 inner.structural_feature_cols()
573 }
574 Self::BySmooth { smooth, .. } => smooth.structural_feature_cols(),
575 Self::FactorSmooth { .. } => Vec::new(),
576 Self::BSpline1D { feature_col, .. } => vec![*feature_col],
577 Self::ThinPlate { feature_cols, .. }
578 | Self::Sphere { feature_cols, .. }
579 | Self::ConstantCurvature { feature_cols, .. }
580 | Self::Matern { feature_cols, .. }
581 | Self::MeasureJet { feature_cols, .. }
582 | Self::Duchon { feature_cols, .. }
583 | Self::Pca { feature_cols, .. }
584 | Self::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
585 }
586 }
587}
588
589pub fn bspline_basis_min_rows(spec: &crate::basis::BSplineBasisSpec) -> usize {
614 use crate::basis::BSplineKnotSpec;
615 let columns = match &spec.knotspec {
616 BSplineKnotSpec::Generate {
617 num_internal_knots, ..
618 } => *num_internal_knots + spec.degree + 1,
619 BSplineKnotSpec::Automatic {
620 num_internal_knots: Some(k),
621 ..
622 } => *k + spec.degree + 1,
623 BSplineKnotSpec::Automatic {
624 num_internal_knots: None,
625 ..
626 } => {
627 spec.degree + 2
631 }
632 BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1).max(1),
633 BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
635 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
636 };
637 let columns = columns.max(spec.degree + 2);
638
639 if spec.double_penalty {
640 const DOUBLE_PENALTY_FLOOR: usize = 2;
643 DOUBLE_PENALTY_FLOOR.min(columns).max(1)
644 } else {
645 columns
646 }
647}
648
/// Description of the `by` variable stored in `SmoothBasisSpec::ByVariable`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ByVariableSpec {
    /// The `by` variable is numeric.
    Numeric,
    /// The `by` variable selects one level, carried as a `u64` plus a label.
    // NOTE(review): `value_bits` is presumably the `f64::to_bits` encoding of
    // the level value — confirm against the producer.
    Level { value_bits: u64, label: String },
}
654
655
/// Description of the `by` variable attached to `SmoothBasisSpec::BySmooth`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ByVarKind {
    /// Numeric `by` variable read from `feature_col`.
    Numeric {
        feature_col: usize,
    },
    /// Factor `by` variable read from `feature_col`.
    ///
    /// `frozen_levels`, when present, carries an explicit level list.
    // NOTE(review): presumably populated when the spec is frozen after
    // fitting — confirm against the freezing code.
    Factor {
        feature_col: usize,
        ordered: bool,
        frozen_levels: Option<Vec<u64>>,
    },
}
667
/// Configuration of a factor-smooth interaction (`SmoothBasisSpec::FactorSmooth`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactorSmoothSpec {
    /// Columns of the continuous covariate(s).
    pub continuous_cols: Vec<usize>,
    /// Column of the grouping factor.
    pub group_col: usize,
    /// B-spline spec of the margin; `SmoothBasisSpec::min_sample_rows` derives
    /// this term's minimum row count from it.
    pub marginal: BSplineBasisSpec,
    /// Which factor-smooth flavour to build.
    pub flavour: FactorSmoothFlavour,
    /// Explicit level list for the grouping factor, when present.
    pub group_frozen_levels: Option<Vec<u64>>,
    /// Optional matrix; may be absent from serialized input.
    // NOTE(review): semantics not visible in this part of the file — confirm.
    #[serde(default)]
    pub frozen_global_orthogonality: Option<Array2<f64>>,
}
683
/// Flavour selector for `FactorSmoothSpec`.
// NOTE(review): the variant names look like mgcv-style `fs` / `sz` / `re`
// bases — confirm against the basis builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FactorSmoothFlavour {
    /// `fs` flavour, carrying a list of null-space penalty orders.
    Fs { m_null_penalty_orders: Vec<usize> },
    /// `sz` flavour.
    Sz,
    /// `re` flavour.
    Re,
}
690
/// Configuration of a tensor-product B-spline smooth
/// (`SmoothBasisSpec::TensorBSpline`).
///
/// Fields marked `#[serde(default)]` may be absent from serialized input.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TensorBSplineSpec {
    /// One B-spline spec per margin.
    pub marginalspecs: Vec<BSplineBasisSpec>,
    /// Optional period per margin.
    // NOTE(review): presumably `None` means the margin is not periodic — confirm.
    #[serde(default)]
    pub periods: Vec<Option<f64>>,
    /// Whether the double-penalty option is enabled for the tensor term.
    pub double_penalty: bool,
    /// Identifiability treatment; defaults to
    /// `TensorBSplineIdentifiability::SumToZero`.
    #[serde(default)]
    pub identifiability: TensorBSplineIdentifiability,
    /// Penalty decomposition; defaults to
    /// `TensorBSplinePenaltyDecomposition::MarginalKroneckerSum`.
    #[serde(default)]
    pub penalty_decomposition: TensorBSplinePenaltyDecomposition,
}
702
/// Identifiability treatment for a tensor-product B-spline smooth.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum TensorBSplineIdentifiability {
    /// No identifiability treatment.
    None,
    /// The default mode.
    #[default]
    SumToZero,
    /// The mode reported by `SmoothBasisSpec::is_marginally_centered_tensor`.
    MarginalSumToZero,
    /// An explicit, stored transform matrix.
    FrozenTransform {
        transform: Array2<f64>,
    },
}
722
/// How the penalty of a tensor-product B-spline smooth is decomposed.
// NOTE(review): the exact penalty construction behind each variant is not
// visible in this part of the file — confirm against the penalty builder.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorBSplinePenaltyDecomposition {
    /// The default decomposition.
    #[default]
    MarginalKroneckerSum,
    /// Alternative decomposition.
    Separable,
}
734
/// Serializable specification of one smooth term: its name, basis, and shape
/// constraint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothTermSpec {
    /// Term name, used in diagnostics and in the structural shape hash.
    pub name: String,
    /// Basis description.
    pub basis: SmoothBasisSpec,
    /// Shape restriction applied to the term.
    pub shape: ShapeConstraint,
    /// Optional stored rotation; may be absent from serialized input.
    // NOTE(review): presumably the fit-time rotation replayed at prediction
    // (compare `SmoothTerm::apply_rotation_to_predict`) — confirm.
    #[serde(default)]
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
}
751
/// A realized smooth term: where its coefficients live and the penalties,
/// constraints, and metadata attached to it.
#[derive(Debug, Clone)]
pub struct SmoothTerm {
    /// Term name.
    pub name: String,
    /// Coefficient index range of this term; its length is the local
    /// coefficient count used by `wald_unpenalized_dim`.
    pub coeff_range: Range<usize>,
    /// Shape restriction applied to the term.
    pub shape: ShapeConstraint,
    /// Penalty matrices in the term's local coefficient coordinates.
    pub penalties_local: Vec<Array2<f64>>,
    /// Null-space dimensions, passed alongside `penalties_local` to
    /// `joint_unpenalized_dim`.
    pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the local penalties.
    pub penaltyinfo_local: Vec<PenaltyInfo>,
    /// Basis metadata.
    pub metadata: BasisMetadata,
    /// Optional per-coefficient lower bounds in local coordinates.
    pub lower_bounds_local: Option<Array1<f64>>,
    /// Optional linear inequality constraints in local coordinates.
    pub linear_constraints_local: Option<LinearInequalityConstraints>,
    /// Optional Kronecker-factored representation of the basis.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
    /// Optional rotation that `apply_rotation_to_predict` right-multiplies
    /// onto a raw design matrix.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Optional matrix.
    // NOTE(review): semantics not visible in this part of the file — confirm.
    pub unabsorbed_global_orthogonality: Option<Array2<f64>>,
}
804
805impl SmoothTerm {
806 pub fn apply_rotation_to_predict(
822 &self,
823 x_new_raw: Array2<f64>,
824 ) -> Result<Array2<f64>, BasisError> {
825 let Some(rot) = self.joint_null_rotation.as_ref() else {
826 return Ok(x_new_raw);
827 };
828 let p_local = rot.rotation.nrows();
829 if x_new_raw.ncols() != p_local {
830 crate::bail_dim_basis!(
831 "joint-null rotation replay for term '{}': raw design has {} columns, \
832 rotation expects {} (the raw basis builder must emit the same column \
833 count as at fit time)",
834 self.name,
835 x_new_raw.ncols(),
836 p_local,
837 );
838 }
839 Ok(gam_linalg::faer_ndarray::fast_ab(
840 &x_new_raw,
841 &rot.rotation,
842 ))
843 }
844
845 pub fn wald_unpenalized_dim(&self) -> usize {
868 joint_unpenalized_dim(
869 self.coeff_range.len(),
870 &self.penalties_local,
871 &self.nullspace_dims,
872 )
873 }
874}
875
876pub fn joint_unpenalized_dim(
881 p_local: usize,
882 penalties_local: &[Array2<f64>],
883 nullspace_dims: &[usize],
884) -> usize {
885 use gam_linalg::faer_ndarray::FaerEigh;
886 if p_local == 0 {
887 return 0;
888 }
889 if penalties_local.is_empty() {
890 return p_local;
892 }
893 let mut s_total = Array2::<f64>::zeros((p_local, p_local));
898 let mut materialized = 0usize;
899 for s in penalties_local {
900 if s.nrows() == p_local && s.ncols() == p_local {
901 s_total += s;
902 materialized += 1;
903 }
904 }
905 if materialized == penalties_local.len() {
906 let symmetric = {
907 let transpose = s_total.t().to_owned();
908 (&s_total + &transpose) * 0.5
909 };
910 if let Ok((evals, _)) = symmetric.eigh(faer::Side::Lower) {
911 let max_abs = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
912 if max_abs == 0.0 {
913 return p_local;
915 }
916 let tol = max_abs * (p_local as f64) * 1e-12;
917 let rank = evals.iter().filter(|&&v| v > tol).count();
918 return p_local.saturating_sub(rank);
919 }
920 }
921 if penalties_local.len() >= 2 {
926 0
927 } else {
928 nullspace_dims
929 .iter()
930 .copied()
931 .min()
932 .unwrap_or(0)
933 .min(p_local)
934 }
935}
936
/// Descriptive record for one penalty block, with its global index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyBlockInfo {
    /// Index of this penalty in the global penalty list.
    pub global_index: usize,
    /// Name of the owning term, when there is one.
    pub termname: Option<String>,
    /// Penalty description.
    pub penalty: PenaltyInfo,
}
943
/// Descriptive record for a penalty block listed under `dropped_penaltyinfo`;
/// unlike `PenaltyBlockInfo` it carries no global index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DroppedPenaltyBlockInfo {
    /// Name of the owning term, when there is one.
    pub termname: Option<String>,
    /// Penalty description.
    pub penalty: PenaltyInfo,
}
949
/// Assembled design for a collection of smooth terms: per-term design
/// matrices, penalties, and constraint data.
///
/// Has the same fields as `RawSmoothDesign` and can be built from one via
/// `From`.
#[derive(Debug, Clone)]
pub struct SmoothDesign {
    /// One design matrix per term.
    pub term_designs: Vec<DesignMatrix>,
    /// Blockwise penalties.
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions.
    // NOTE(review): presumably one entry per penalty — confirm.
    pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the penalty blocks.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Descriptive info for dropped penalty blocks.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Realized smooth terms.
    pub terms: Vec<SmoothTerm>,
    /// Optional coefficient lower bounds.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional linear inequality constraints.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
967
968impl SmoothDesign {
969 pub fn total_smooth_cols(&self) -> usize {
970 self.term_designs.iter().map(DesignMatrix::ncols).sum()
971 }
972 pub fn nrows(&self) -> usize {
973 self.term_designs.first().map_or(0, DesignMatrix::nrows)
974 }
975}
976
/// Smooth design with the same fields as `SmoothDesign`; convertible into it
/// via `From<RawSmoothDesign>`.
// NOTE(review): "raw" presumably means before a later transformation stage —
// confirm against the producer.
#[derive(Debug, Clone)]
pub struct RawSmoothDesign {
    /// One design matrix per term.
    pub term_designs: Vec<DesignMatrix>,
    /// Blockwise penalties.
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions.
    pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the penalty blocks.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Descriptive info for dropped penalty blocks.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Realized smooth terms.
    pub terms: Vec<SmoothTerm>,
    /// Optional coefficient lower bounds.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional linear inequality constraints.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
990
991impl RawSmoothDesign {
992 pub fn total_smooth_cols(&self) -> usize {
993 self.term_designs.iter().map(DesignMatrix::ncols).sum()
994 }
995 pub fn nrows(&self) -> usize {
996 self.term_designs.first().map_or(0, DesignMatrix::nrows)
997 }
998}
999
1000impl From<RawSmoothDesign> for SmoothDesign {
1001 fn from(value: RawSmoothDesign) -> Self {
1002 Self {
1003 term_designs: value.term_designs,
1004 penalties: value.penalties,
1005 nullspace_dims: value.nullspace_dims,
1006 penaltyinfo: value.penaltyinfo,
1007 dropped_penaltyinfo: value.dropped_penaltyinfo,
1008 terms: value.terms,
1009 coefficient_lower_bounds: value.coefficient_lower_bounds,
1010 linear_constraints: value.linear_constraints,
1011 }
1012 }
1013}
1014
/// Prior attached to a bounded linear coefficient
/// (`LinearCoefficientGeometry::Bounded`).
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum BoundedCoefficientPriorSpec {
    /// No prior (the default).
    #[default]
    None,
    /// Uniform prior.
    Uniform,
    /// Beta prior with parameters `a` and `b`.
    Beta {
        a: f64,
        b: f64,
    },
}
1025
/// Geometry of a linear term's coefficient.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum LinearCoefficientGeometry {
    /// Unconstrained coefficient (the default).
    #[default]
    Unconstrained,
    /// Coefficient bounded by `min` and `max`, with an optional prior that
    /// defaults to `BoundedCoefficientPriorSpec::None` when absent from
    /// serialized input.
    Bounded {
        min: f64,
        max: f64,
        #[serde(default)]
        prior: BoundedCoefficientPriorSpec,
    },
}
1037
/// Serializable specification of one linear (parametric) term.
///
/// Fields marked `#[serde(default)]` may be absent from serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearTermSpec {
    /// Term name.
    pub name: String,
    /// Single data column; used by `effective_feature_cols` when
    /// `feature_cols` is empty.
    pub feature_col: usize,
    /// Numeric data columns whose elementwise product forms the design
    /// column; when non-empty it takes precedence over `feature_col`.
    #[serde(default)]
    pub feature_cols: Vec<usize>,
    /// `(column, level_bits)` gates: `realized_design_column` zeroes every row
    /// whose value in `column` does not have exactly the bit pattern
    /// `level_bits`.
    #[serde(default)]
    pub categorical_levels: Vec<(usize, u64)>,
    /// Double-penalty flag; defaults to `default_linear_term_double_penalty()`.
    #[serde(default = "default_linear_term_double_penalty")]
    pub double_penalty: bool,
    /// Coefficient geometry; defaults to `Unconstrained`.
    #[serde(default)]
    pub coefficient_geometry: LinearCoefficientGeometry,
    /// Optional lower limit on the coefficient.
    #[serde(default)]
    pub coefficient_min: Option<f64>,
    /// Optional upper limit on the coefficient.
    #[serde(default)]
    pub coefficient_max: Option<f64>,
}
1078
1079impl LinearTermSpec {
1080 pub fn effective_feature_cols(&self) -> Vec<usize> {
1083 if self.feature_cols.is_empty() {
1084 vec![self.feature_col]
1085 } else {
1086 self.feature_cols.clone()
1087 }
1088 }
1089
1090 pub fn is_interaction(&self) -> bool {
1092 self.feature_cols.len() > 1 || !self.categorical_levels.is_empty()
1093 }
1094
1095 pub fn realized_design_column(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
1107 let n = data.nrows();
1108 let p = data.ncols();
1109 let bounds = |col: usize| -> Result<(), String> {
1110 if col >= p {
1111 Err(format!(
1112 "linear term '{}' feature column {} out of bounds for {} columns",
1113 self.name, col, p
1114 ))
1115 } else {
1116 Ok(())
1117 }
1118 };
1119
1120 let mut column = if self.categorical_levels.is_empty() {
1125 let cols = self.effective_feature_cols();
1126 for &c in &cols {
1127 bounds(c)?;
1128 }
1129 let mut acc = data.column(cols[0]).to_owned();
1130 for &c in cols.iter().skip(1) {
1131 acc *= &data.column(c);
1132 }
1133 acc
1134 } else {
1135 let mut acc = Array1::<f64>::ones(n);
1136 for &c in &self.feature_cols {
1137 bounds(c)?;
1138 acc *= &data.column(c);
1139 }
1140 acc
1141 };
1142
1143 for &(col, level_bits) in &self.categorical_levels {
1144 bounds(col)?;
1145 let gate = data.column(col);
1146 for (out, &v) in column.iter_mut().zip(gate.iter()) {
1147 if v.to_bits() != level_bits {
1148 *out = 0.0;
1149 }
1150 }
1151 }
1152
1153 Ok(column)
1154 }
1155}
1156
/// Serde default for `LinearTermSpec::double_penalty`: the double penalty is
/// off unless the serialized spec sets it.
pub const fn default_linear_term_double_penalty() -> bool {
    false
}
1167
/// Serde default for the `smooth_penalty` field of `SmoothBasisSpec::Pca`.
pub const fn default_pca_smooth_penalty() -> f64 {
    1.0
}
1171
/// Serde default for the `chunk_size` field of `SmoothBasisSpec::Pca`.
pub const fn default_pca_chunk_size() -> usize {
    4096
}
1175
/// Serializable specification of one random-effect term.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomEffectTermSpec {
    /// Term name.
    pub name: String,
    /// Data column holding the grouping variable.
    pub feature_col: usize,
    /// Whether the first level is dropped.
    pub drop_first_level: bool,
    /// Whether the term is penalized; defaults to
    /// `default_random_effect_penalized()` when absent from serialized input.
    #[serde(default = "default_random_effect_penalized")]
    pub penalized: bool,
    /// Explicit level list, when present; may be absent from serialized input.
    #[serde(default)]
    pub frozen_levels: Option<Vec<u64>>,
}
1198
/// Serde default for `RandomEffectTermSpec::penalized`: random effects are
/// penalized unless the serialized spec says otherwise.
pub fn default_random_effect_penalized() -> bool {
    true
}
1202
1203pub fn validate_measure_jet_positive_vec_len(
1204 label: &str,
1205 term_name: &str,
1206 field: &str,
1207 values: &[f64],
1208 expected: usize,
1209) -> Result<(), String> {
1210 if values.len() != expected {
1211 return Err(SmoothError::invalid_config(format!(
1212 "{label} term '{term_name}' frozen MeasureJet {field} has length {}, expected {expected}",
1213 values.len()
1214 ))
1215 .into());
1216 }
1217 if values
1218 .iter()
1219 .any(|value| !(value.is_finite() && *value > 0.0))
1220 {
1221 return Err(SmoothError::invalid_config(format!(
1222 "{label} term '{term_name}' frozen MeasureJet {field} values must be positive and finite"
1223 ))
1224 .into());
1225 }
1226 Ok(())
1227}
1228
/// Serializable specification of a full set of model terms, grouped by kind.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermCollectionSpec {
    /// Linear (parametric) terms.
    pub linear_terms: Vec<LinearTermSpec>,
    /// Random-effect terms.
    pub random_effect_terms: Vec<RandomEffectTermSpec>,
    /// Smooth terms.
    pub smooth_terms: Vec<SmoothTermSpec>,
}
1235
1236pub fn validate_smooth_basis_frozen(
1237 basis: &SmoothBasisSpec,
1238 label: &str,
1239 term_name: &str,
1240) -> Result<(), String> {
1241 match basis {
1242 SmoothBasisSpec::ByVariable { inner, .. }
1243 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
1244 validate_smooth_basis_frozen(inner, label, term_name)
1245 }
1246 SmoothBasisSpec::BSpline1D { spec, .. } => {
1247 if !matches!(
1248 spec.knotspec,
1249 BSplineKnotSpec::Provided(_)
1250 | BSplineKnotSpec::PeriodicUniform { .. }
1251 | BSplineKnotSpec::NaturalCubicRegression { .. }
1252 ) {
1253 return Err(format!(
1254 "{label} term '{term_name}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression"
1255 ));
1256 }
1257 Ok(())
1258 }
1259 SmoothBasisSpec::ThinPlate { spec, .. } => {
1260 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1261 return Err(format!(
1262 "{label} term '{term_name}' is not frozen: ThinPlate centers must be UserProvided"
1263 ));
1264 }
1265 if matches!(
1266 spec.identifiability,
1267 SpatialIdentifiability::OrthogonalToParametric
1268 ) {
1269 return Err(format!(
1270 "{label} term '{term_name}' is not frozen: ThinPlate identifiability must be FrozenTransform or None"
1271 ));
1272 }
1273 Ok(())
1274 }
1275 _ => Ok(()),
1276 }
1277}
1278
1279impl TermCollectionSpec {
1280 pub fn write_structural_shape_hash(&self, h: &mut gam_runtime::warm_start::Fingerprinter) {
1294 h.write_str("term-collection");
1295 h.write_usize(self.linear_terms.len());
1296 for linear in &self.linear_terms {
1297 h.write_str(&linear.name);
1298 }
1299 h.write_usize(self.random_effect_terms.len());
1300 h.write_usize(self.smooth_terms.len());
1301 for smooth in &self.smooth_terms {
1302 h.write_str(&smooth.name);
1303 h.write_str(smooth.basis.structural_kind());
1304 for col in smooth.basis.structural_feature_cols() {
1305 h.write_usize(col);
1306 }
1307 }
1308 }
1309
1310 pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
1314 for linear in &self.linear_terms {
1315 if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
1316 && (!min.is_finite() || !max.is_finite() || min > max)
1317 {
1318 return Err(SmoothError::invalid_config(format!(
1319 "{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
1320 linear.name
1321 ))
1322 .into());
1323 }
1324 if let Some(min) = linear.coefficient_min
1325 && !min.is_finite()
1326 {
1327 return Err(SmoothError::invalid_config(format!(
1328 "{label} linear term '{}' has non-finite coefficient minimum {min}",
1329 linear.name
1330 ))
1331 .into());
1332 }
1333 if let Some(max) = linear.coefficient_max
1334 && !max.is_finite()
1335 {
1336 return Err(SmoothError::invalid_config(format!(
1337 "{label} linear term '{}' has non-finite coefficient maximum {max}",
1338 linear.name
1339 ))
1340 .into());
1341 }
1342 if let LinearCoefficientGeometry::Bounded { min, max, prior } =
1343 &linear.coefficient_geometry
1344 {
1345 if !min.is_finite() || !max.is_finite() || min >= max {
1346 return Err(SmoothError::invalid_config(format!(
1347 "{label} bounded term '{}' has invalid bounds [{min}, {max}]",
1348 linear.name
1349 ))
1350 .into());
1351 }
1352 match prior {
1353 BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
1354 BoundedCoefficientPriorSpec::Beta { a, b } => {
1355 if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
1356 return Err(SmoothError::invalid_config(format!(
1357 "{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
1358 linear.name
1359 ))
1360 .into());
1361 }
1362 }
1363 }
1364 }
1365 }
1366 for st in &self.smooth_terms {
1367 match &st.basis {
1368 SmoothBasisSpec::ByVariable { inner, .. } => {
1369 validate_smooth_basis_frozen(inner, label, &st.name)?;
1370 let nested = SmoothTermSpec {
1371 name: st.name.clone(),
1372 basis: (**inner).clone(),
1373 shape: st.shape,
1374 joint_null_rotation: None,
1375 };
1376 TermCollectionSpec {
1377 linear_terms: Vec::new(),
1378 random_effect_terms: Vec::new(),
1379 smooth_terms: vec![nested],
1380 }
1381 .validate_frozen(label)?;
1382 }
1383 SmoothBasisSpec::FactorSumToZero { inner, levels, .. } => {
1384 if levels.len() < 2 {
1385 return Err(format!(
1386 "{label} term '{}' has invalid frozen sz levels",
1387 st.name
1388 ));
1389 }
1390 validate_smooth_basis_frozen(inner, label, &st.name)?;
1391 }
1392 SmoothBasisSpec::BSpline1D { spec, .. } => {
1393 if !matches!(
1394 spec.knotspec,
1395 BSplineKnotSpec::Provided(_)
1396 | BSplineKnotSpec::PeriodicUniform { .. }
1397 | BSplineKnotSpec::NaturalCubicRegression { .. }
1398 ) {
1399 return Err(SmoothError::invalid_config(format!(
1400 "{label} term '{}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1401 st.name
1402 ))
1403 .into());
1404 }
1405 }
1406 SmoothBasisSpec::ThinPlate { spec, .. } => {
1407 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1408 return Err(SmoothError::invalid_config(format!(
1409 "{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
1410 st.name
1411 ))
1412 .into());
1413 }
1414 if matches!(
1415 spec.identifiability,
1416 SpatialIdentifiability::OrthogonalToParametric
1417 ) {
1418 return Err(SmoothError::invalid_config(format!(
1419 "{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
1420 st.name
1421 ))
1422 .into());
1423 }
1424 }
1425 SmoothBasisSpec::Sphere { spec, .. } => {
1426 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1427 return Err(SmoothError::invalid_config(format!(
1428 "{label} term '{}' is not frozen: Sphere centers must be UserProvided",
1429 st.name
1430 ))
1431 .into());
1432 }
1433 if matches!(spec.method, crate::basis::SphereMethod::Harmonic)
1434 && spec.max_degree.is_none_or(|d| d == 0)
1435 {
1436 return Err(format!(
1437 "{label} term '{}' is not frozen: sphere max_degree must be positive",
1438 st.name
1439 ));
1440 }
1441 }
1442 SmoothBasisSpec::ConstantCurvature { spec, .. } => {
1443 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1444 return Err(SmoothError::invalid_config(format!(
1445 "{label} term '{}' is not frozen: ConstantCurvature centers must be UserProvided",
1446 st.name
1447 ))
1448 .into());
1449 }
1450 if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1451 return Err(SmoothError::invalid_config(format!(
1452 "{label} term '{}' is not frozen: ConstantCurvature length_scale must be the realized positive value",
1453 st.name
1454 ))
1455 .into());
1456 }
1457 }
1458 SmoothBasisSpec::MeasureJet { spec, .. } => {
1459 let centers = match &spec.center_strategy {
1460 CenterStrategy::UserProvided(centers) => centers,
1461 _ => {
1462 return Err(SmoothError::invalid_config(format!(
1463 "{label} term '{}' is not frozen: MeasureJet centers must be UserProvided",
1464 st.name
1465 ))
1466 .into());
1467 }
1468 };
1469 if centers.nrows() == 0 {
1470 return Err(SmoothError::invalid_config(format!(
1471 "{label} term '{}' is not frozen: MeasureJet centers are empty",
1472 st.name
1473 ))
1474 .into());
1475 }
1476 if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1477 return Err(SmoothError::invalid_config(format!(
1478 "{label} term '{}' is not frozen: MeasureJet length_scale must be the realized positive value",
1479 st.name
1480 ))
1481 .into());
1482 }
1483 let frozen = spec.frozen_quadrature.as_ref().ok_or_else(|| {
1486 SmoothError::invalid_config(format!(
1487 "{label} term '{}' is not frozen: MeasureJet frozen_quadrature payload is missing",
1488 st.name
1489 ))
1490 })?;
1491 if frozen.masses.len() != centers.nrows() {
1492 return Err(SmoothError::invalid_config(format!(
1493 "{label} term '{}' frozen MeasureJet has {} masses for {} centers",
1494 st.name,
1495 frozen.masses.len(),
1496 centers.nrows()
1497 ))
1498 .into());
1499 }
1500 let total_mass = frozen.masses.sum();
1501 if frozen
1502 .masses
1503 .iter()
1504 .any(|mass| !(mass.is_finite() && *mass >= 0.0))
1505 || !(total_mass.is_finite() && total_mass > 0.0)
1506 {
1507 return Err(SmoothError::invalid_config(format!(
1508 "{label} term '{}' frozen MeasureJet masses must be finite, nonnegative, and have positive total mass",
1509 st.name
1510 ))
1511 .into());
1512 }
1513 let n_levels = frozen.eps_band.len();
1514 if n_levels == 0
1515 || frozen
1516 .eps_band
1517 .iter()
1518 .any(|eps| !(eps.is_finite() && *eps > 0.0))
1519 {
1520 return Err(SmoothError::invalid_config(format!(
1521 "{label} term '{}' frozen MeasureJet eps_band must be nonempty, finite, and positive",
1522 st.name
1523 ))
1524 .into());
1525 }
1526 for (idx, pair) in frozen.eps_band.windows(2).enumerate() {
1527 if pair[1] <= pair[0] {
1528 return Err(SmoothError::invalid_config(format!(
1529 "{label} term '{}' frozen MeasureJet eps_band is not strictly ascending at {idx}: {} then {}",
1530 st.name,
1531 pair[0],
1532 pair[1]
1533 ))
1534 .into());
1535 }
1536 }
1537 validate_measure_jet_positive_vec_len(
1538 label,
1539 &st.name,
1540 "support_means",
1541 &frozen.support_means,
1542 n_levels,
1543 )?;
1544 let per_level = crate::basis::measure_jet_multiscale_mode(spec);
1552 if per_level {
1553 validate_measure_jet_positive_vec_len(
1554 label,
1555 &st.name,
1556 "penalty_normalization_scales",
1557 &frozen.penalty_normalization_scales,
1558 n_levels,
1559 )?;
1560 validate_measure_jet_positive_vec_len(
1561 label,
1562 &st.name,
1563 "raw_penalty_normalization_scales",
1564 &frozen.raw_penalty_normalization_scales,
1565 n_levels,
1566 )?;
1567 if frozen.fused_penalty_normalization_scale.is_some() {
1568 return Err(SmoothError::invalid_config(format!(
1569 "{label} term '{}' per-level MeasureJet must not carry a fused penalty normalization scale",
1570 st.name
1571 ))
1572 .into());
1573 }
1574 } else {
1575 if !frozen.penalty_normalization_scales.is_empty()
1576 || !frozen.raw_penalty_normalization_scales.is_empty()
1577 {
1578 return Err(SmoothError::invalid_config(format!(
1579 "{label} term '{}' fused MeasureJet must not carry per-level penalty normalization scales",
1580 st.name
1581 ))
1582 .into());
1583 }
1584 match frozen.fused_penalty_normalization_scale {
1585 Some(scale) if scale.is_finite() && scale > 0.0 => {}
1586 Some(scale) => {
1587 return Err(SmoothError::invalid_config(format!(
1588 "{label} term '{}' fused MeasureJet penalty normalization scale must be positive and finite, got {scale}",
1589 st.name
1590 ))
1591 .into());
1592 }
1593 None => {
1594 return Err(SmoothError::invalid_config(format!(
1595 "{label} term '{}' fused MeasureJet is missing its penalty normalization scale",
1596 st.name
1597 ))
1598 .into());
1599 }
1600 }
1601 }
1602 }
1603 SmoothBasisSpec::Matern { spec, .. } => {
1604 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1605 return Err(SmoothError::invalid_config(format!(
1606 "{label} term '{}' is not frozen: Matern centers must be UserProvided",
1607 st.name
1608 ))
1609 .into());
1610 }
1611 }
1612 SmoothBasisSpec::Duchon { spec, .. } => {
1613 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1614 return Err(SmoothError::invalid_config(format!(
1615 "{label} term '{}' is not frozen: Duchon centers must be UserProvided",
1616 st.name
1617 ))
1618 .into());
1619 }
1620 if matches!(
1621 spec.identifiability,
1622 SpatialIdentifiability::OrthogonalToParametric
1623 ) {
1624 return Err(SmoothError::invalid_config(format!(
1625 "{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
1626 st.name
1627 ))
1628 .into());
1629 }
1630 }
1631 SmoothBasisSpec::Pca {
1632 centered,
1633 center_mean,
1634 pca_basis_path,
1635 ..
1636 } => {
1637 if *centered && center_mean.is_none() && pca_basis_path.is_none() {
1638 return Err(SmoothError::invalid_config(format!(
1639 "{label} term '{}' is not frozen: centered Pca missing center_mean",
1640 st.name
1641 ))
1642 .into());
1643 }
1644 }
1645 SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1646 if let SmoothBasisSpec::BySmooth { .. } = smooth.as_ref() {
1647 return Err(format!("{label} term '{}' has nested by-smooths", st.name));
1648 }
1649 match by_kind {
1650 ByVarKind::Numeric { .. } => {}
1651 ByVarKind::Factor { frozen_levels, .. } if frozen_levels.is_none() => {
1652 return Err(format!(
1653 "{label} term '{}' is not frozen: by-factor levels missing",
1654 st.name
1655 ));
1656 }
1657 ByVarKind::Factor { .. } => {}
1658 }
1659 let nested = TermCollectionSpec {
1660 linear_terms: vec![],
1661 random_effect_terms: vec![],
1662 smooth_terms: vec![SmoothTermSpec {
1663 name: st.name.clone(),
1664 basis: (**smooth).clone(),
1665 shape: st.shape,
1666 joint_null_rotation: None,
1667 }],
1668 };
1669 nested.validate_frozen(label)?;
1670 }
1671 SmoothBasisSpec::FactorSmooth { spec } => {
1672 if spec.group_frozen_levels.is_none() {
1673 return Err(format!(
1674 "{label} term '{}' is not frozen: factor-smooth levels missing",
1675 st.name
1676 ));
1677 }
1678 if !matches!(
1679 spec.marginal.knotspec,
1680 BSplineKnotSpec::Provided(_)
1681 | BSplineKnotSpec::PeriodicUniform { .. }
1682 | BSplineKnotSpec::NaturalCubicRegression { .. }
1694 ) {
1695 return Err(format!(
1696 "{label} term '{}' is not frozen: factor-smooth marginal knots missing",
1697 st.name
1698 ));
1699 }
1700 }
1701 SmoothBasisSpec::TensorBSpline { spec, .. } => {
1702 for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
1703 if !matches!(
1704 marginal.knotspec,
1705 BSplineKnotSpec::Provided(_)
1706 | BSplineKnotSpec::PeriodicUniform { .. }
1707 | BSplineKnotSpec::NaturalCubicRegression { .. }
1708 ) {
1709 return Err(SmoothError::invalid_config(format!(
1710 "{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1711 st.name, dim
1712 ))
1713 .into());
1714 }
1715 }
1716 if matches!(
1717 spec.identifiability,
1718 TensorBSplineIdentifiability::SumToZero
1719 | TensorBSplineIdentifiability::MarginalSumToZero
1720 ) {
1721 return Err(SmoothError::invalid_config(format!(
1722 "{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
1723 st.name
1724 ))
1725 .into());
1726 }
1727 }
1728 }
1729 }
1730
1731 for rt in &self.random_effect_terms {
1732 if rt.frozen_levels.is_none() {
1733 return Err(SmoothError::invalid_config(format!(
1734 "{label} random-effect term '{}' is not frozen: missing frozen_levels",
1735 rt.name
1736 ))
1737 .into());
1738 }
1739 }
1740
1741 Ok(())
1742 }
1743
1744 pub fn remap_feature_columns<E, F>(&self, mut remap: F) -> Result<TermCollectionSpec, E>
1763 where
1764 F: FnMut(usize) -> Result<usize, E>,
1765 {
1766 let mut out = self.clone();
1767 for lt in &mut out.linear_terms {
1768 lt.feature_col = remap(lt.feature_col)?;
1769 for fc in lt.feature_cols.iter_mut() {
1779 *fc = remap(*fc)?;
1780 }
1781 for (col, _bits) in lt.categorical_levels.iter_mut() {
1786 *col = remap(*col)?;
1787 }
1788 }
1789 for rt in &mut out.random_effect_terms {
1790 rt.feature_col = remap(rt.feature_col)?;
1791 }
1792 for st in &mut out.smooth_terms {
1793 remap_smooth_basis_feature_columns(&mut st.basis, &mut remap)?;
1794 }
1795 Ok(out)
1796 }
1797}
1798
1799pub fn remap_smooth_basis_feature_columns<E, F>(
1804 basis: &mut SmoothBasisSpec,
1805 remap: &mut F,
1806) -> Result<(), E>
1807where
1808 F: FnMut(usize) -> Result<usize, E>,
1809{
1810 match basis {
1811 SmoothBasisSpec::ByVariable { inner, by_col, .. }
1812 | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
1813 *by_col = remap(*by_col)?;
1814 remap_smooth_basis_feature_columns(inner, remap)?;
1815 }
1816 SmoothBasisSpec::BSpline1D { feature_col, .. } => {
1817 *feature_col = remap(*feature_col)?;
1818 }
1819 SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1820 let by_feature_col = match by_kind {
1821 ByVarKind::Numeric { feature_col } | ByVarKind::Factor { feature_col, .. } => {
1822 feature_col
1823 }
1824 };
1825 *by_feature_col = remap(*by_feature_col)?;
1826 remap_smooth_basis_feature_columns(smooth, remap)?;
1827 }
1828 SmoothBasisSpec::FactorSmooth { spec } => {
1829 for fc in spec.continuous_cols.iter_mut() {
1830 *fc = remap(*fc)?;
1831 }
1832 spec.group_col = remap(spec.group_col)?;
1833 }
1834 SmoothBasisSpec::ThinPlate { feature_cols, .. }
1835 | SmoothBasisSpec::Sphere { feature_cols, .. }
1836 | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
1837 | SmoothBasisSpec::Matern { feature_cols, .. }
1838 | SmoothBasisSpec::MeasureJet { feature_cols, .. }
1839 | SmoothBasisSpec::Duchon { feature_cols, .. }
1840 | SmoothBasisSpec::Pca { feature_cols, .. }
1841 | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
1842 for fc in feature_cols.iter_mut() {
1843 *fc = remap(*fc)?;
1844 }
1845 }
1846 }
1847 Ok(())
1848}
1849
/// Optional annotation recording how a [`BlockwisePenalty`]'s dense `local`
/// matrix was constructed.
///
/// The hint is stored alongside the dense matrix, not instead of it.
/// NOTE(review): presumably consumers use it to pick structure-aware fast
/// paths; confirm against the call sites.
#[derive(Debug, Clone)]
pub enum PenaltyStructureHint {
    /// Diagonal block with the carried value on every diagonal entry, as built
    /// by `BlockwisePenalty::ridge`.
    Ridge(f64),
    /// Factor matrices handed to `BlockwisePenalty::kronecker`.
    Kronecker(Vec<Array2<f64>>),
}
1855
/// A penalty matrix stored only on the coefficient block it touches.
///
/// `local` is a square matrix whose side equals `col_range.len()`; the
/// constructors `new` and `kronecker` assert this. `to_global` embeds it at
/// rows and columns `col_range` of an otherwise zero `p_total × p_total`
/// matrix.
#[derive(Clone)]
pub struct BlockwisePenalty {
    /// Global coefficient columns covered by this block.
    pub col_range: Range<usize>,
    /// Dense penalty matrix restricted to `col_range`.
    pub local: Array2<f64>,
    /// Prior mean attached to the block; every constructor starts at
    /// `CoefficientPriorMean::Zero`.
    pub prior_mean: gam_problem::CoefficientPriorMean,
    /// Structural annotation set by `ridge` and `kronecker`, `None` otherwise.
    pub structure_hint: Option<PenaltyStructureHint>,
    /// Optional penalty operator attached through `with_op`.
    /// NOTE(review): presumably an analytic/operator form of the same penalty —
    /// confirm against `analytic_penalties::PenaltyOp`.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
1880
1881impl std::fmt::Debug for BlockwisePenalty {
1882 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1883 f.debug_struct("BlockwisePenalty")
1884 .field("col_range", &self.col_range)
1885 .field(
1886 "local",
1887 &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
1888 )
1889 .field("prior_mean", &self.prior_mean)
1890 .field("structure_hint", &self.structure_hint)
1891 .field("op", &self.op.as_ref().map(|o| o.dim()))
1892 .finish()
1893 }
1894}
1895
1896impl BlockwisePenalty {
1897 pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
1899 assert_eq!(col_range.len(), local.nrows());
1900 assert_eq!(col_range.len(), local.ncols());
1901 Self {
1902 col_range,
1903 local,
1904 prior_mean: gam_problem::CoefficientPriorMean::Zero,
1905 structure_hint: None,
1906 op: None,
1907 }
1908 }
1909
1910 pub fn with_prior_mean(
1911 mut self,
1912 prior_mean: gam_problem::CoefficientPriorMean,
1913 ) -> Self {
1914 self.prior_mean = prior_mean;
1915 self
1916 }
1917
1918 pub fn with_op(
1920 mut self,
1921 op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1922 ) -> Self {
1923 self.op = op;
1924 self
1925 }
1926
1927 pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
1928 let block_size = col_range.len();
1929 let mut local = Array2::<f64>::zeros((block_size, block_size));
1930 for i in 0..block_size {
1931 local[[i, i]] = scale;
1932 }
1933 Self {
1934 col_range,
1935 local,
1936 prior_mean: gam_problem::CoefficientPriorMean::Zero,
1937 structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
1938 op: None,
1939 }
1940 }
1941
1942 pub fn kronecker(
1943 col_range: Range<usize>,
1944 local: Array2<f64>,
1945 factors: Vec<Array2<f64>>,
1946 ) -> Self {
1947 assert_eq!(col_range.len(), local.nrows());
1948 assert_eq!(col_range.len(), local.ncols());
1949 Self {
1950 col_range,
1951 local,
1952 prior_mean: gam_problem::CoefficientPriorMean::Zero,
1953 structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
1954 op: None,
1955 }
1956 }
1957
1958 pub fn to_global(&self, p_total: usize) -> Array2<f64> {
1962 let mut g = Array2::<f64>::zeros((p_total, p_total));
1963 let r = &self.col_range;
1964 assert!(
1965 r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
1966 "BlockwisePenalty::to_global shape invariant violated: \
1967 col_range={}..{}, local={}x{}, p_total={}",
1968 r.start,
1969 r.end,
1970 self.local.nrows(),
1971 self.local.ncols(),
1972 p_total,
1973 );
1974 g.slice_mut(s![r.start..r.end, r.start..r.end])
1975 .assign(&self.local);
1976 g
1977 }
1978
1979 pub fn to_penalty_matrix(
1982 &self,
1983 total_dim: usize,
1984 ) -> gam_problem::PenaltyMatrix {
1985 gam_problem::PenaltyMatrix::Blockwise {
1986 local: self.local.clone(),
1987 col_range: self.col_range.clone(),
1988 total_dim,
1989 }
1990 }
1991
1992 #[inline]
1994 pub fn block_size(&self) -> usize {
1995 self.col_range.len()
1996 }
1997}
1998
1999pub fn weighted_blockwise_penalty_sum(
2003 penalties: &[BlockwisePenalty],
2004 lambdas: &[f64],
2005 p_total: usize,
2006) -> Array2<f64> {
2007 assert_eq!(penalties.len(), lambdas.len());
2008 for (idx, &lam) in lambdas.iter().enumerate() {
2015 assert!(
2016 lam.is_finite() && lam >= 0.0,
2017 "weighted_blockwise_penalty_sum: lambdas[{idx}] = {lam} is invalid (must be finite and non-negative; negative smoothing parameters violate S_λ ⪰ 0)",
2018 );
2019 }
2020 for (idx, bp) in penalties.iter().enumerate() {
2024 let r = &bp.col_range;
2025 assert!(
2026 r.end <= p_total,
2027 "weighted_blockwise_penalty_sum: penalties[{idx}] col_range {:?} exceeds p_total = {p_total}",
2028 r,
2029 );
2030 }
2031 let mut out = Array2::<f64>::zeros((p_total, p_total));
2032 for (bp, &lam) in penalties.iter().zip(lambdas.iter()) {
2033 let r = &bp.col_range;
2034 let mut slice = out.slice_mut(s![r.start..r.end, r.start..r.end]);
2035 slice.scaled_add(lam, &bp.local);
2036 }
2037 out
2038}
2039
/// Penalty system for a tensor-product term, kept in factored (per-margin)
/// form together with one eigensystem per marginal penalty.
///
/// `KroneckerPenaltySystem::new` fills `marginal_eigensystems` from
/// `marginal_penalties` and checks that penalties and dims have equal counts.
#[derive(Debug, Clone)]
pub struct KroneckerPenaltySystem {
    /// One penalty matrix per margin.
    pub marginal_penalties: Vec<Array2<f64>>,
    /// One entry per margin; the first element holds the eigenvalues read by
    /// the log-determinant routines.
    /// NOTE(review): the second element is presumably the matching
    /// eigenvectors — confirm in `kronecker_marginal_eigensystems`.
    pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
    /// Basis dimension of each margin; their product is `p_total()`.
    pub marginal_dims: Vec<usize>,
    /// When true, one extra smoothing parameter is present (see
    /// `num_penalties`) and is applied to the joint structural null modes in
    /// `logdet_rank_and_derivatives`.
    pub has_double_penalty: bool,
}
2057
2058impl KroneckerPenaltySystem {
2059 pub fn new(
2060 marginal_penalties: Vec<Array2<f64>>,
2061 marginal_dims: Vec<usize>,
2062 has_double_penalty: bool,
2063 ) -> Result<Self, BasisError> {
2064 if marginal_penalties.len() != marginal_dims.len() {
2065 crate::bail_dim_basis!(
2066 "KroneckerPenaltySystem: {} penalties vs {} dims",
2067 marginal_penalties.len(),
2068 marginal_dims.len()
2069 );
2070 }
2071 let eigensystems =
2072 kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
2073 .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2074 Ok(Self {
2075 marginal_penalties,
2076 marginal_eigensystems: eigensystems,
2077 marginal_dims,
2078 has_double_penalty,
2079 })
2080 }
2081
2082 pub fn p_total(&self) -> usize {
2083 self.marginal_dims.iter().copied().product()
2084 }
2085
2086 pub fn ndim(&self) -> usize {
2087 self.marginal_dims.len()
2088 }
2089
2090 pub fn num_penalties(&self) -> usize {
2091 self.marginal_dims.len() + if self.has_double_penalty { 1 } else { 0 }
2092 }
2093
2094 pub fn logdet_and_derivatives(
2098 &self,
2099 lambdas: &[f64],
2100 ridge: f64,
2101 ) -> (f64, Array1<f64>, Array2<f64>) {
2102 let n_pen = self.num_penalties();
2103 assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2104 let marginal_evals: Vec<_> = self
2105 .marginal_eigensystems
2106 .iter()
2107 .map(|(evals, _)| evals.view())
2108 .collect();
2109 kronecker_logdet_and_derivatives(
2110 &marginal_evals,
2111 &self.marginal_dims,
2112 lambdas,
2113 self.has_double_penalty,
2114 ridge,
2115 )
2116 }
2117
2118 pub fn logdet_rank_and_derivatives(
2119 &self,
2120 lambdas: &[f64],
2121 ridge: f64,
2122 ) -> (f64, usize, Array1<f64>, Array2<f64>) {
2123 let n_pen = self.num_penalties();
2124 assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2125 let d = self.marginal_dims.len();
2126 let mut logdet = 0.0;
2127 let mut rank = 0usize;
2128 let mut grad = Array1::<f64>::zeros(n_pen);
2129 let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2130 const EIGENVALUE_POSITIVITY_FLOOR: f64 = 1e-12;
2134 const STRUCTURAL_ZERO_FLOOR: f64 = 1e-12;
2138 let mut multi_idx = vec![0usize; d];
2139 loop {
2140 let mut sigma = 0.0;
2141 let mut structural_sigma = 0.0;
2142 for k in 0..d {
2143 let marginal_eigenvalue = self.marginal_eigensystems[k].0[multi_idx[k]];
2144 structural_sigma += marginal_eigenvalue;
2145 sigma += lambdas[k] * marginal_eigenvalue;
2146 }
2147 let joint_null = structural_sigma <= STRUCTURAL_ZERO_FLOOR;
2148 if self.has_double_penalty && joint_null {
2149 sigma += lambdas[d];
2150 }
2151 if structural_sigma > STRUCTURAL_ZERO_FLOOR {
2152 sigma += ridge;
2153 }
2154
2155 if sigma > EIGENVALUE_POSITIVITY_FLOOR {
2156 rank += 1;
2157 logdet += sigma.ln();
2158 let inv_sigma = 1.0 / sigma;
2159 let inv_sigma2 = inv_sigma * inv_sigma;
2160 for k in 0..n_pen {
2161 let ck = if k < d {
2162 lambdas[k] * self.marginal_eigensystems[k].0[multi_idx[k]]
2163 } else if joint_null {
2164 lambdas[d]
2165 } else {
2166 0.0
2167 };
2168 grad[k] += ck * inv_sigma;
2169 hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2170 for l in (k + 1)..n_pen {
2171 let cl = if l < d {
2172 lambdas[l] * self.marginal_eigensystems[l].0[multi_idx[l]]
2173 } else if joint_null {
2174 lambdas[d]
2175 } else {
2176 0.0
2177 };
2178 let off = -ck * cl * inv_sigma2;
2179 hess[[k, l]] += off;
2180 hess[[l, k]] += off;
2181 }
2182 }
2183 }
2184
2185 let mut carry = true;
2186 for dim in (0..d).rev() {
2187 if carry {
2188 multi_idx[dim] += 1;
2189 if multi_idx[dim] < self.marginal_dims[dim] {
2190 carry = false;
2191 } else {
2192 multi_idx[dim] = 0;
2193 }
2194 }
2195 }
2196 if carry {
2197 break;
2198 }
2199 }
2200 (logdet, rank, grad, hess)
2201 }
2202}
2203
/// Behavioural tests for `joint_unpenalized_dim`, covering the empty,
/// single-penalty, multi-penalty and shape-mismatch cases.
#[cfg(test)]
mod joint_unpenalized_dim_tests {
    use super::joint_unpenalized_dim;
    use ndarray::{Array2, array};

    /// With no penalties at all, every one of the 4 coefficients is free.
    #[test]
    fn no_penalty_is_fully_unpenalized() {
        let unpenalized = joint_unpenalized_dim(4, &[], &[]);
        assert_eq!(unpenalized, 4);
    }

    /// A lone penalty touching only the last coordinate leaves two free.
    #[test]
    fn single_penalty_returns_its_own_null_space() {
        let penalty = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 5.0]];
        let penalties = std::slice::from_ref(&penalty);
        assert_eq!(joint_unpenalized_dim(3, penalties, &[2]), 2);
    }

    /// Two penalties whose null spaces do not intersect leave nothing free.
    #[test]
    fn complementary_double_penalty_has_empty_joint_null_space() {
        let bending = array![[0.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]];
        let ridge = array![[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
        let penalties = [bending, ridge];
        assert_eq!(joint_unpenalized_dim(3, &penalties, &[1, 2]), 0);
    }

    /// Both penalties leave the first coordinate alone, so exactly one
    /// direction stays unpenalized.
    #[test]
    fn partial_overlap_keeps_shared_null_direction() {
        let on_second = array![[0.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]];
        let on_third = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]];
        let penalties = [on_second, on_third];
        assert_eq!(joint_unpenalized_dim(3, &penalties, &[2, 2]), 1);
    }

    /// Penalties whose matrix size differs from the coefficient count.
    /// NOTE(review): the expected values pin the fallback behaviour; the exact
    /// fallback rule lives in `joint_unpenalized_dim` and is not visible here.
    #[test]
    fn non_materialized_penalty_falls_back_conservatively() {
        let full: Array2<f64> = array![[0.0, 0.0], [0.0, 1.0]];
        let factor: Array2<f64> = array![[1.0]];

        let mixed = [full, factor.clone()];
        assert_eq!(joint_unpenalized_dim(2, &mixed, &[1, 0]), 0);

        let factor_only = std::slice::from_ref(&factor);
        assert_eq!(joint_unpenalized_dim(4, factor_only, &[2]), 2);
    }
}
2260
/// Regression test for `KroneckerPenaltySystem::logdet_rank_and_derivatives`
/// with the double penalty enabled.
#[cfg(test)]
mod kronecker_penalty_system_tests {
    use super::KroneckerPenaltySystem;
    use ndarray::array;

    /// Two 2×2 diagonal margins with eigenvalues {0, 2} and {0, 3}. With
    /// lambdas (5, 7, 11) the four combined modes are 11 (joint null mode,
    /// double penalty only), 21 = 7·3, 10 = 5·2 and 31 = 10 + 21, so the
    /// double-penalty parameter must contribute to exactly one mode.
    #[test]
    fn double_penalty_rank_derivatives_use_only_joint_null_space() {
        let first_margin = array![[0.0, 0.0], [0.0, 2.0]];
        let second_margin = array![[0.0, 0.0], [0.0, 3.0]];
        let system = KroneckerPenaltySystem::new(
            vec![first_margin, second_margin],
            vec![2usize, 2usize],
            true,
        )
        .unwrap();
        let lambdas = [5.0, 7.0, 11.0];

        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, 0.0);

        let expected_diag = [11.0_f64, 21.0, 10.0, 31.0];
        let expected_logdet: f64 = expected_diag.iter().map(|v| v.ln()).sum();
        assert_eq!(rank, 4);
        assert!((logdet - expected_logdet).abs() <= 1e-12);

        let double_penalty_grad = grad[2];
        assert!(
            (double_penalty_grad - 1.0).abs() <= 1e-12,
            "double-penalty rank derivative must count only the joint null mode, got {}",
            double_penalty_grad
        );
        assert!(hess[[2, 2]].abs() <= 1e-12);
    }
}
2289
/// Realized design for a whole term collection: the assembled design matrix,
/// its penalties, constraints and the column bookkeeping for each term kind.
///
/// NOTE(review): field descriptions below are inferred from names and from
/// the methods on this type; confirm details against the builder.
#[derive(Clone, Debug)]
pub struct TermCollectionDesign {
    /// Full design matrix; its column count is the total coefficient count
    /// used by `penalties_as_penalty_matrix`.
    pub design: DesignMatrix,
    /// Penalty blocks, each addressed by global coefficient columns.
    pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions — presumably one per entry of `penalties`.
    pub nullspace_dims: Vec<usize>,
    /// Metadata for the retained penalty blocks.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Metadata for penalty blocks that were dropped.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Optional per-coefficient lower bounds.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional linear inequality constraints on the coefficients.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// Columns occupied by the intercept.
    pub intercept_range: Range<usize>,
    /// `(term name, column range)` for each linear term.
    pub linear_ranges: Vec<(String, Range<usize>)>,
    /// `(term name, column range)` for each random-effect term.
    pub random_effect_ranges: Vec<(String, Range<usize>)>,
    /// `(term name, levels)` for each random-effect term.
    pub random_effect_levels: Vec<(String, Vec<u64>)>,
    /// Realized smooth terms; `kronecker_penalty_system` reads `smooth.terms`.
    pub smooth: SmoothDesign,
}
2317
2318impl TermCollectionDesign {
2319 pub fn penalties_as_penalty_matrix(&self) -> Vec<gam_problem::PenaltyMatrix> {
2323 let p = self.design.ncols();
2324 self.penalties
2325 .iter()
2326 .map(|bp| bp.to_penalty_matrix(p))
2327 .collect()
2328 }
2329
2330 #[inline]
2332 pub fn num_penalties(&self) -> usize {
2333 self.penalties.len()
2334 }
2335
2336 pub fn realize_coefficient_groups(
2339 &self,
2340 groups: &[CoefficientGroupSpec],
2341 base_prior: &gam_spec::RhoPrior,
2342 ) -> Result<RealizedCoefficientGroups, BasisError> {
2343 realize_coefficient_groups(self, groups, base_prior)
2344 }
2345
2346 pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
2357 let [only_term] = self.smooth.terms.as_slice() else {
2358 return None;
2359 };
2360 let kron = only_term.kronecker_factored.as_ref()?;
2361 if kron.marginal_dims.len() < 2
2367 || kron.marginal_penalties.len() != kron.marginal_dims.len()
2368 || kron.marginal_designs.len() != kron.marginal_dims.len()
2369 {
2370 return None;
2371 }
2372 KroneckerPenaltySystem::new(
2373 kron.marginal_penalties.clone(),
2374 kron.marginal_dims.clone(),
2375 kron.has_double_penalty,
2376 )
2377 .ok()
2378 }
2379}
2380
/// Configuration bundle tying latent coordinate values to one smooth term.
///
/// NOTE(review): the field descriptions are inferred from names and types
/// only; confirm against the code that consumes this struct.
#[derive(Clone)]
pub struct StandardLatentCoordConfig {
    /// Shared latent coordinate values.
    pub values: std::sync::Arc<crate::latent::LatentCoordValues>,
    /// Index of the smooth term these coordinates belong to.
    pub term_index: gam_problem::types::SmoothTermIdx,
    /// Feature columns associated with the latent coordinates.
    pub feature_cols: Vec<usize>,
    /// Manifold the latent coordinates are taken to live on.
    pub manifold: crate::latent::LatentManifold,
    /// Presumably true when `manifold` was selected automatically — confirm.
    pub manifold_auto: bool,
    /// Registry of retractions for latent updates.
    pub retraction_registry: gam_problem::LatentRetractionRegistry,
    /// Optional shared registry of analytic penalties.
    pub analytic_penalties: Option<std::sync::Arc<crate::AnalyticPenaltyRegistry>>,
}
2396
/// Serializable per-term record of adaptive spatial weights.
///
/// NOTE(review): the meaning of the weight vectors is inferred from the field
/// names (magnitude / gradient / Laplacian); confirm against the adaptive
/// regularization code. Presumably each vector has one entry per row of
/// `collocation_points`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveSpatialMap {
    /// Name of the smooth term this map belongs to.
    pub termname: String,
    /// Feature columns of that term.
    pub feature_cols: Vec<usize>,
    /// Collocation points as a 2-D array (presumably one point per row).
    pub collocation_points: Array2<f64>,
    /// Inverse weights for the magnitude component.
    pub inv_magweight: Array1<f64>,
    /// Inverse weights for the gradient component.
    pub invgradweight: Array1<f64>,
    /// Inverse weights for the Laplacian component.
    pub inv_lapweight: Array1<f64>,
}
2406
/// Serializable diagnostics from an adaptive-regularization run.
///
/// NOTE(review): interpretations below are inferred from the field names;
/// confirm against the routine that fills this struct.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveRegularizationDiagnostics {
    /// Epsilon value for the `0` component (presumably magnitude).
    pub epsilon_0: f64,
    /// Epsilon value for the `g` component (presumably gradient).
    pub epsilon_g: f64,
    /// Epsilon value for the `c` component (presumably curvature).
    pub epsilon_c: f64,
    /// Number of outer iterations over epsilon.
    pub epsilon_outer_iterations: usize,
    /// Number of MM iterations performed.
    pub mm_iterations: usize,
    /// Whether the procedure reported convergence.
    pub converged: bool,
    /// Per-term adaptive spatial maps.
    pub maps: Vec<AdaptiveSpatialMap>,
}
2417
/// Conditioning record for a single linear design column.
///
/// NOTE(review): presumably the column is centred by `mean` and divided by
/// `scale` before fitting — confirm where these values are applied.
#[derive(Debug, Clone)]
pub struct LinearColumnConditioning {
    /// Index of the conditioned column.
    col_idx: usize,
    /// Centre recorded for that column.
    mean: f64,
    /// Scale recorded for that column.
    scale: f64,
}
2424
/// Conditioning applied to the linear part of a fit: the intercept position
/// plus one record per conditioned column.
///
/// The derived `Default` yields `intercept_idx == 0` and no columns.
#[derive(Debug, Clone, Default)]
pub struct LinearFitConditioning {
    /// Column index of the intercept.
    pub intercept_idx: usize,
    /// Per-column conditioning records.
    pub columns: Vec<LinearColumnConditioning>,
}
2430
/// Derivative information of one spatial term with respect to a ψ
/// (log-scale) coordinate, stored in the term's local coefficient block.
///
/// NOTE(review): the descriptions below are inferred from field names and
/// types only; confirm against the code that builds and consumes this struct.
#[derive(Clone)]
pub struct SpatialPsiDerivative {
    /// Index of a penalty associated with this derivative.
    pub penalty_index: usize,
    /// Indices of all penalties associated with this derivative.
    pub penalty_indices: Vec<usize>,
    /// Global coefficient columns of the term's block.
    pub global_range: Range<usize>,
    /// Total number of coefficients in the model.
    pub total_p: usize,
    /// First ψ-derivative of the term's design block.
    pub x_psi_local: Array2<f64>,
    /// First ψ-derivatives of the penalty components, in local coordinates.
    pub s_psi_components_local: Vec<Array2<f64>>,
    /// Second ψ-derivative of the term's design block.
    pub x_psi_psi_local: Array2<f64>,
    /// Second ψ-derivatives of the penalty components, in local coordinates.
    pub s_psi_psi_components_local: Vec<Array2<f64>>,
    /// Identifier grouping the axes of one anisotropic term, if any.
    pub aniso_group_id: Option<usize>,
    /// Optional `(index, matrix)` pairs of cross design derivatives.
    pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
    /// Optional thread-safe callback that, given an index, returns cross
    /// penalty derivative matrices or an `EstimationError`.
    pub aniso_cross_penalty_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
        >,
    >,
    /// Optional shared implicit (non-materialized) design-derivative operator.
    pub implicit_operator: Option<std::sync::Arc<crate::basis::ImplicitDesignPsiDerivative>>,
    /// Axis index used together with `implicit_operator`.
    pub implicit_axis: usize,
}
2462
/// Flat vector of spatial ψ coordinates, partitioned into consecutive
/// per-term blocks.
///
/// Term `i` owns `dims_per_term[i]` consecutive entries of `values`, in term
/// order. `new_with_dims` asserts `values.len() == sum(dims_per_term)`.
#[derive(Debug, Clone)]
pub struct SpatialLogKappaCoords {
    /// Concatenated coordinates of all terms.
    pub values: Array1<f64>,
    /// Number of coordinates owned by each term.
    pub dims_per_term: Vec<usize>,
}
2471
/// Selects which end of a `(lower, upper)` bound pair
/// `SpatialLogKappaCoords::aniso_bounds_from_data` writes out.
#[derive(Clone, Copy)]
pub enum AnisoBoundEnd {
    /// Take the lower bound.
    Lower,
    /// Take the upper bound.
    Upper,
}
2481
2482impl SpatialLogKappaCoords {
2483 pub fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
2485 assert_eq!(
2486 values.len(),
2487 dims_per_term.iter().sum::<usize>(),
2488 "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
2489 values.len(),
2490 dims_per_term.iter().sum::<usize>(),
2491 );
2492 Self {
2493 values,
2494 dims_per_term,
2495 }
2496 }
2497
2498 pub fn from_length_scales(
2500 spec: &TermCollectionSpec,
2501 term_indices: &[usize],
2502 options: &SpatialLengthScaleOptimizationOptions,
2503 ) -> Self {
2504 let mut out = Array1::<f64>::zeros(term_indices.len());
2505 for (slot, &term_idx) in term_indices.iter().enumerate() {
2506 if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2512 out[slot] = cc.kappa;
2513 continue;
2514 }
2515 let length_scale = get_spatial_length_scale(spec, term_idx)
2516 .unwrap_or(options.min_length_scale)
2517 .clamp(options.min_length_scale, options.max_length_scale);
2518 out[slot] = -length_scale.ln();
2519 }
2520 Self {
2521 values: out,
2522 dims_per_term: vec![1; term_indices.len()],
2523 }
2524 }
2525
2526 pub fn from_length_scales_aniso(
2544 spec: &TermCollectionSpec,
2545 term_indices: &[usize],
2546 options: &SpatialLengthScaleOptimizationOptions,
2547 ) -> Self {
2548 let mut vals = Vec::new();
2549 let mut dims = Vec::new();
2550 for &term_idx in term_indices {
2551 if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2555 let seed = measure_jet_psi_seed(mj);
2556 dims.push(seed.len());
2557 vals.extend(seed);
2558 continue;
2559 }
2560 if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2566 vals.push(cc.kappa);
2567 dims.push(1);
2568 continue;
2569 }
2570 let length_scale = get_spatial_length_scale(spec, term_idx)
2571 .unwrap_or(options.min_length_scale)
2572 .clamp(options.min_length_scale, options.max_length_scale);
2573 let psi_bar = -length_scale.ln(); if spatial_term_uses_per_axis_psi(spec, term_idx) {
2576 let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
2581 let eta_raw = get_spatial_aniso_log_scales(spec, term_idx)
2582 .expect("predicate guarantees aniso_log_scales is Some");
2583 let eta = center_aniso_log_scales(&eta_raw);
2584 for &eta_a in &eta {
2585 vals.push(psi_bar + eta_a);
2586 }
2587 dims.push(d);
2588 } else {
2589 vals.push(psi_bar);
2596 dims.push(1);
2597 }
2598 }
2599 Self {
2600 values: Array1::from_vec(vals),
2601 dims_per_term: dims,
2602 }
2603 }
2604
2605 pub fn lower_bounds_from_data(
2609 data: ArrayView2<'_, f64>,
2610 spec: &TermCollectionSpec,
2611 term_indices: &[usize],
2612 options: &SpatialLengthScaleOptimizationOptions,
2613 ) -> Self {
2614 let mut values = Array1::<f64>::zeros(term_indices.len());
2615 for (slot, &term_idx) in term_indices.iter().enumerate() {
2616 values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
2617 }
2618 Self {
2619 values,
2620 dims_per_term: vec![1; term_indices.len()],
2621 }
2622 }
2623
2624 pub fn upper_bounds_from_data(
2626 data: ArrayView2<'_, f64>,
2627 spec: &TermCollectionSpec,
2628 term_indices: &[usize],
2629 options: &SpatialLengthScaleOptimizationOptions,
2630 ) -> Self {
2631 let mut values = Array1::<f64>::zeros(term_indices.len());
2632 for (slot, &term_idx) in term_indices.iter().enumerate() {
2633 values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
2634 }
2635 Self {
2636 values,
2637 dims_per_term: vec![1; term_indices.len()],
2638 }
2639 }
2640
2641 pub fn lower_bounds_aniso_from_data(
2658 data: ArrayView2<'_, f64>,
2659 spec: &TermCollectionSpec,
2660 term_indices: &[usize],
2661 dims_per_term: &[usize],
2662 options: &SpatialLengthScaleOptimizationOptions,
2663 ) -> Self {
2664 Self::aniso_bounds_from_data(
2665 data,
2666 spec,
2667 term_indices,
2668 dims_per_term,
2669 options,
2670 AnisoBoundEnd::Lower,
2671 )
2672 }
2673
2674 pub fn upper_bounds_aniso_from_data(
2678 data: ArrayView2<'_, f64>,
2679 spec: &TermCollectionSpec,
2680 term_indices: &[usize],
2681 dims_per_term: &[usize],
2682 options: &SpatialLengthScaleOptimizationOptions,
2683 ) -> Self {
2684 Self::aniso_bounds_from_data(
2685 data,
2686 spec,
2687 term_indices,
2688 dims_per_term,
2689 options,
2690 AnisoBoundEnd::Upper,
2691 )
2692 }
2693
2694 fn aniso_bounds_from_data(
2700 data: ArrayView2<'_, f64>,
2701 spec: &TermCollectionSpec,
2702 term_indices: &[usize],
2703 dims_per_term: &[usize],
2704 options: &SpatialLengthScaleOptimizationOptions,
2705 end: AnisoBoundEnd,
2706 ) -> Self {
2707 assert_eq!(term_indices.len(), dims_per_term.len());
2708 let total: usize = dims_per_term.iter().sum();
2709 let mut values = Array1::<f64>::zeros(total);
2710 let mut cursor = 0;
2711 for (slot, &term_idx) in term_indices.iter().enumerate() {
2712 let d = dims_per_term[slot];
2713 if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2716 let bounds = measure_jet_psi_bound_values(mj, matches!(end, AnisoBoundEnd::Upper));
2717 for (offset, bound) in bounds.into_iter().enumerate() {
2718 if offset < d {
2719 values[cursor + offset] = bound;
2720 }
2721 }
2722 cursor += d;
2723 continue;
2724 }
2725 if constant_curvature_term_spec(spec, term_idx).is_some() {
2728 let (lo, hi) = constant_curvature_kappa_bounds(data, spec, term_idx);
2729 if d >= 1 {
2730 values[cursor] = match end {
2731 AnisoBoundEnd::Lower => lo,
2732 AnisoBoundEnd::Upper => hi,
2733 };
2734 }
2735 cursor += d;
2736 continue;
2737 }
2738 let psi_bound = {
2739 let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
2740 match end {
2741 AnisoBoundEnd::Lower => lo,
2742 AnisoBoundEnd::Upper => hi,
2743 }
2744 };
2745 let axis_offsets = if d <= 1 {
2746 vec![0.0; d]
2747 } else {
2748 get_spatial_aniso_log_scales(spec, term_idx)
2749 .filter(|eta| eta.len() == d)
2750 .map(|eta| center_aniso_log_scales(&eta))
2751 .unwrap_or_else(|| vec![0.0; d])
2752 };
2753 for offset in 0..d {
2754 values[cursor + offset] = psi_bound + axis_offsets[offset];
2755 }
2756 cursor += d;
2757 }
2758 Self {
2759 values,
2760 dims_per_term: dims_per_term.to_vec(),
2761 }
2762 }
2763
2764 pub fn reseed_from_data(
2773 mut self,
2774 data: ArrayView2<'_, f64>,
2775 spec: &TermCollectionSpec,
2776 term_indices: &[usize],
2777 options: &SpatialLengthScaleOptimizationOptions,
2778 ) -> Self {
2779 assert_eq!(term_indices.len(), self.dims_per_term.len());
2780 let mut cursor = 0;
2781 for (slot, &term_idx) in term_indices.iter().enumerate() {
2782 let d = self.dims_per_term[slot];
2783 if measure_jet_term_spec(spec, term_idx).is_some() {
2786 cursor += d;
2787 continue;
2788 }
2789 if constant_curvature_term_spec(spec, term_idx).is_some() {
2793 cursor += d;
2794 continue;
2795 }
2796 let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
2797 cursor += d;
2798 continue;
2799 };
2800 if d == 0 {
2801 continue;
2802 }
2803 let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
2804 let psi_bar_old = current.iter().sum::<f64>() / d as f64;
2805 for (offset, &old_value) in current.iter().enumerate() {
2806 self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
2807 }
2808 cursor += d;
2809 }
2810 self
2811 }
2812
2813 pub fn clamp_to_bounds(
2824 mut self,
2825 lower: &SpatialLogKappaCoords,
2826 upper: &SpatialLogKappaCoords,
2827 ) -> Self {
2828 assert_eq!(self.values.len(), lower.values.len());
2829 assert_eq!(self.values.len(), upper.values.len());
2830 let mut n_projected = 0usize;
2831 let mut worst_delta = 0.0_f64;
2832 for idx in 0..self.values.len() {
2833 let lo = lower.values[idx];
2834 let hi = upper.values[idx];
2835 if !(lo.is_finite() && hi.is_finite()) {
2836 continue;
2837 }
2838 let v = self.values[idx];
2839 if v < lo {
2840 worst_delta = worst_delta.max(lo - v);
2841 self.values[idx] = lo;
2842 n_projected += 1;
2843 } else if v > hi {
2844 worst_delta = worst_delta.max(v - hi);
2845 self.values[idx] = hi;
2846 n_projected += 1;
2847 }
2848 }
2849 if n_projected > 0 {
2850 log::info!(
2851 "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
2852 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
2853 [{KERNEL_RANGE_MIN_DIAMETER_FRACTION}/r_max, {KERNEL_RANGE_MAX_SPACING_MULTIPLE}/r_min] geometry window",
2854 self.values.len()
2855 );
2856 }
2857 self
2858 }
2859
2860 pub fn from_theta_tail_with_dims(
2862 theta: &Array1<f64>,
2863 start: usize,
2864 dims_per_term: Vec<usize>,
2865 ) -> Self {
2866 let total: usize = dims_per_term.iter().sum();
2867 Self {
2868 values: theta.slice(s![start..start + total]).to_owned(),
2869 dims_per_term,
2870 }
2871 }
2872
2873 pub fn len(&self) -> usize {
2875 self.values.len()
2876 }
2877
2878 pub fn dims_per_term(&self) -> &[usize] {
2880 &self.dims_per_term
2881 }
2882
2883 fn term_offset(&self, term_idx: usize) -> usize {
2885 self.dims_per_term[..term_idx].iter().sum()
2886 }
2887
2888 pub fn term_slice(&self, term_idx: usize) -> &[f64] {
2890 let offset = self.term_offset(term_idx);
2891 let d = self.dims_per_term[term_idx];
2892 &self.values.as_slice().unwrap()[offset..offset + d]
2893 }
2894
2895 pub fn as_array(&self) -> &Array1<f64> {
2896 &self.values
2897 }
2898
2899 pub fn set_scalar_slot(&mut self, slot: usize, value: f64) -> bool {
2905 if slot >= self.dims_per_term.len() || self.dims_per_term[slot] != 1 {
2906 return false;
2907 }
2908 let offset = self.term_offset(slot);
2909 self.values[offset] = value;
2910 true
2911 }
2912
2913 pub fn split_at(&self, mid: usize) -> (Self, Self) {
2916 let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
2917 (
2918 Self {
2919 values: self.values.slice(s![0..flat_mid]).to_owned(),
2920 dims_per_term: self.dims_per_term[..mid].to_vec(),
2921 },
2922 Self {
2923 values: self.values.slice(s![flat_mid..]).to_owned(),
2924 dims_per_term: self.dims_per_term[mid..].to_vec(),
2925 },
2926 )
2927 }
2928
2929 pub fn apply_tospec(
2936 &self,
2937 spec: &TermCollectionSpec,
2938 term_indices: &[usize],
2939 ) -> Result<TermCollectionSpec, EstimationError> {
2940 if term_indices.len() != self.dims_per_term.len() {
2941 crate::bail_invalid_estim!(
2942 "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
2943 term_indices={} dims_per_term={}",
2944 term_indices.len(),
2945 self.dims_per_term.len()
2946 );
2947 }
2948 let mut updated = spec.clone();
2949 for (slot, &term_idx) in term_indices.iter().enumerate() {
2950 let psi = self.term_slice(slot);
2951 let d = self.dims_per_term[slot];
2952 if measure_jet_term_spec(&updated, term_idx).is_some() {
2955 set_measure_jet_psi_dials(&mut updated, term_idx, psi)?;
2956 continue;
2957 }
2958 if constant_curvature_term_spec(&updated, term_idx).is_some() {
2962 set_constant_curvature_kappa(&mut updated, term_idx, psi)?;
2963 continue;
2964 }
2965 let (next_length_scale, next_aniso) = spatial_term_psi_to_length_scale_and_aniso(psi);
2966 if (d == 1 || next_length_scale.is_some())
2967 && let Some(length_scale) = next_length_scale
2968 {
2969 set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
2970 }
2971 if let Some(eta) = next_aniso {
2972 set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
2973 }
2974 }
2975 Ok(updated)
2976 }
2977}
2978
/// Subtracts the mean from `eta` so the returned contrasts sum to (about) zero.
///
/// Inputs with at most one entry are returned unchanged. Centered entries whose
/// magnitude is at most `1e-15` are snapped to exactly `0.0`.
pub fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    if eta.len() < 2 {
        return eta.to_vec();
    }
    let mean = eta.iter().sum::<f64>() / eta.len() as f64;
    let mut centered = Vec::with_capacity(eta.len());
    for &value in eta {
        let shifted = value - mean;
        // Keep the `<=` form: a NaN must fall through to the `else` branch.
        centered.push(if shifted.abs() <= 1e-15 { 0.0 } else { shifted });
    }
    centered
}
2995
2996pub fn spatial_term_uses_per_axis_psi(resolvedspec: &TermCollectionSpec, term_idx: usize) -> bool {
2999 if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
3000 return measure_jet_enrolls_psi(mj);
3001 }
3002 let Some(d) = get_spatial_feature_dim(resolvedspec, term_idx) else {
3003 return false;
3004 };
3005 if d <= 1 {
3006 return false;
3007 }
3008 let Some(eta) = get_spatial_aniso_log_scales(resolvedspec, term_idx) else {
3009 return false;
3010 };
3011 if eta.len() != d {
3012 return false;
3013 }
3014 !matches!(
3015 resolvedspec.smooth_terms.get(term_idx).map(|term| &term.basis),
3016 Some(SmoothBasisSpec::Duchon { .. })
3017 )
3018}
3019
3020pub fn set_spatial_length_scale(
3021 spec: &mut TermCollectionSpec,
3022 term_idx: usize,
3023 length_scale: f64,
3024) -> Result<(), EstimationError> {
3025 let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3026 crate::bail_invalid_estim!("spatial length-scale term index {term_idx} out of range");
3027 };
3028 match &mut term.basis {
3029 SmoothBasisSpec::ThinPlate { spec, .. } => {
3030 spec.length_scale = length_scale;
3031 Ok(())
3032 }
3033 SmoothBasisSpec::Matern { spec, .. } => {
3034 spec.length_scale = length_scale;
3035 Ok(())
3036 }
3037 SmoothBasisSpec::Duchon { spec, .. } => {
3038 spec.length_scale = Some(length_scale);
3039 Ok(())
3040 }
3041 _ => Err(EstimationError::InvalidInput(format!(
3042 "term '{}' does not expose a spatial length scale",
3043 term.name
3044 ))),
3045 }
3046}
3047
3048pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
3049 spec.smooth_terms
3050 .get(term_idx)
3051 .and_then(|term| match &term.basis {
3052 SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
3053 SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
3054 SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
3055 _ => None,
3056 })
3057}
3058
3059pub fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3060 if let Some(term) = spec.smooth_terms.get(term_idx)
3066 && let SmoothBasisSpec::ThinPlate { .. } = &term.basis
3067 {
3068 return false;
3069 }
3070
3071 if let Some(term) = spec.smooth_terms.get(term_idx)
3096 && let SmoothBasisSpec::Matern { .. } = &term.basis
3097 {
3098 return true;
3099 }
3100
3101 if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
3104 return measure_jet_enrolls_psi(mj);
3105 }
3106
3107 if constant_curvature_term_spec(spec, term_idx).is_some() {
3114 return true;
3115 }
3116
3117 get_spatial_length_scale(spec, term_idx).is_some()
3118}
3119
3120pub fn measure_jet_term_spec(
3123 spec: &TermCollectionSpec,
3124 term_idx: usize,
3125) -> Option<&crate::basis::MeasureJetBasisSpec> {
3126 spec.smooth_terms
3127 .get(term_idx)
3128 .and_then(|term| match &term.basis {
3129 SmoothBasisSpec::MeasureJet { spec, .. } => Some(spec),
3130 _ => None,
3131 })
3132}
3133
3134pub fn measure_jet_enrolls_psi(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3141 measure_jet_learns_length_scale(mj)
3150 || (mj.tau0 > 0.0 && crate::basis::measure_jet_multiscale_mode(mj))
3151}
3152
3153pub fn measure_jet_learns_length_scale(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3156 mj.learn_length_scale
3157}
3158
3159pub fn freeze_measure_jet_length_scale_learning(spec: &mut TermCollectionSpec) -> usize {
3160 let mut frozen = 0;
3161 for term in spec.smooth_terms.iter_mut() {
3162 if let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis
3163 && mj.learn_length_scale
3164 {
3165 mj.learn_length_scale = false;
3166 frozen += 1;
3167 }
3168 }
3169 frozen
3170}
3171
/// `(lower, upper)` box for the measure-jet `alpha` dial, as handed out by
/// `measure_jet_psi_bound_values`.
pub const MEASURE_JET_PSI_ALPHA_BOUNDS: (f64, f64) = (-1.0, 3.0);

/// `(lower, upper)` box for the measure-jet `ln(tau)` dial.
/// Numerically these are `ln(1e-8)` and `ln(1e2)`.
pub const MEASURE_JET_PSI_LN_TAU_BOUNDS: (f64, f64) = (-18.420680743952367, 4.605170185988092);

/// `(lower, upper)` box for the measure-jet `ln(length_scale)` dial.
/// Numerically these are `ln(1e-3)` and `ln(1e2)`.
pub const MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS: (f64, f64) = (-6.907755278982137, 4.605170185988092);
3189
3190pub fn measure_jet_penalty_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3198 if crate::basis::measure_jet_multiscale_mode(mj) {
3199 2
3200 } else {
3201 0
3202 }
3203}
3204
3205pub fn measure_jet_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3209 usize::from(measure_jet_learns_length_scale(mj)) + measure_jet_penalty_psi_dim(mj)
3210}
3211
3212pub fn measure_jet_psi_seed(mj: &crate::basis::MeasureJetBasisSpec) -> Vec<f64> {
3217 let mut seed = Vec::with_capacity(measure_jet_psi_dim(mj));
3218 if measure_jet_learns_length_scale(mj) {
3219 let ell = if mj.length_scale > 0.0 {
3223 mj.length_scale
3224 } else {
3225 1.0
3226 };
3227 seed.push(ell.ln());
3228 }
3229 if measure_jet_penalty_psi_dim(mj) > 0 {
3230 let ln_tau = mj.tau0.max(f64::MIN_POSITIVE).ln();
3232 seed.extend_from_slice(&[mj.alpha, ln_tau]);
3233 }
3234 seed
3235}
3236
3237pub fn measure_jet_psi_bound_values(mj: &crate::basis::MeasureJetBasisSpec, upper: bool) -> Vec<f64> {
3240 let pick = |b: (f64, f64)| if upper { b.1 } else { b.0 };
3241 let mut bounds = Vec::with_capacity(measure_jet_psi_dim(mj));
3242 if measure_jet_learns_length_scale(mj) {
3243 bounds.push(pick(MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS));
3244 }
3245 if measure_jet_penalty_psi_dim(mj) > 0 {
3246 bounds.push(pick(MEASURE_JET_PSI_ALPHA_BOUNDS));
3248 bounds.push(pick(MEASURE_JET_PSI_LN_TAU_BOUNDS));
3249 }
3250 bounds
3251}
3252
3253pub fn apply_measure_jet_psi(
3258 mj: &mut crate::basis::MeasureJetBasisSpec,
3259 psi: &[f64],
3260) -> Result<bool, EstimationError> {
3261 if psi.len() != measure_jet_psi_dim(mj) {
3262 crate::bail_invalid_estim!(
3263 "measure-jet ψ write-back dimension mismatch: got {} values for a {}-dial term",
3264 psi.len(),
3265 measure_jet_psi_dim(mj)
3266 );
3267 }
3268 let mut changed = false;
3269 let mut cursor = 0usize;
3273 if measure_jet_learns_length_scale(mj) {
3274 let next_ell = psi[cursor].exp();
3275 cursor += 1;
3276 if !(next_ell.is_finite() && next_ell > 0.0) {
3277 crate::bail_invalid_estim!(
3278 "measure-jet ψ write-back produced a non-finite/non-positive length_scale (ℓ={next_ell})"
3279 );
3280 }
3281 if next_ell != mj.length_scale {
3282 mj.length_scale = next_ell;
3283 changed = true;
3284 }
3285 }
3286 if measure_jet_penalty_psi_dim(mj) > 0 {
3287 let next_alpha = psi[cursor];
3290 let next_tau = psi[cursor + 1].exp();
3291 if !(next_alpha.is_finite() && next_tau.is_finite() && next_tau > 0.0) {
3292 crate::bail_invalid_estim!(
3293 "measure-jet ψ write-back produced non-finite dials (alpha={next_alpha}, tau={next_tau})"
3294 );
3295 }
3296 if next_alpha != mj.alpha {
3297 mj.alpha = next_alpha;
3298 changed = true;
3299 }
3300 if next_tau != mj.tau0 {
3301 mj.tau0 = next_tau;
3302 changed = true;
3303 }
3304 }
3305 Ok(changed)
3306}
3307
3308pub fn set_measure_jet_psi_dials(
3311 spec: &mut TermCollectionSpec,
3312 term_idx: usize,
3313 psi: &[f64],
3314) -> Result<bool, EstimationError> {
3315 let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3316 crate::bail_invalid_estim!("measure-jet ψ write-back: term index {term_idx} out of range");
3317 };
3318 set_single_term_measure_jet_psi_dials(term, psi)
3319}
3320
3321pub fn set_single_term_measure_jet_psi_dials(
3326 term: &mut SmoothTermSpec,
3327 psi: &[f64],
3328) -> Result<bool, EstimationError> {
3329 let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis else {
3330 crate::bail_invalid_estim!("measure-jet ψ write-back targeted a non-measure-jet term");
3331 };
3332 apply_measure_jet_psi(mj, psi)
3333}
3334
3335pub fn constant_curvature_term_spec(
3338 spec: &TermCollectionSpec,
3339 term_idx: usize,
3340) -> Option<&crate::basis::ConstantCurvatureBasisSpec> {
3341 spec.smooth_terms
3342 .get(term_idx)
3343 .and_then(|term| match &term.basis {
3344 SmoothBasisSpec::ConstantCurvature { spec, .. } => Some(spec),
3345 _ => None,
3346 })
3347}
3348
/// Numerator of the symmetric κ bound: `constant_curvature_kappa_bounds`
/// returns `±(this / max_r2)`, where `max_r2` is the largest squared feature
/// norm seen in the data.
pub const CONSTANT_CURVATURE_KAPPA_CHART_FRACTION: f64 = 0.5;

/// Floor applied to `max_r2` in `constant_curvature_kappa_bounds`, which keeps
/// the κ bound finite when all feature rows are (near) zero.
pub const CONSTANT_CURVATURE_MIN_CHART_RADIUS2: f64 = 1e-8;
3362
3363pub fn constant_curvature_kappa_bounds(
3368 data: ArrayView2<'_, f64>,
3369 spec: &TermCollectionSpec,
3370 term_idx: usize,
3371) -> (f64, f64) {
3372 let feature_cols = match spec.smooth_terms.get(term_idx).map(|t| &t.basis) {
3373 Some(SmoothBasisSpec::ConstantCurvature { feature_cols, .. }) => feature_cols,
3374 _ => return (-1.0, 1.0),
3375 };
3376 let mut max_r2 = CONSTANT_CURVATURE_MIN_CHART_RADIUS2;
3377 for row in data.outer_iter() {
3378 let mut r2 = 0.0_f64;
3379 for &c in feature_cols.iter() {
3380 if let Some(&v) = row.get(c)
3381 && v.is_finite()
3382 {
3383 r2 += v * v;
3384 }
3385 }
3386 if r2 > max_r2 {
3387 max_r2 = r2;
3388 }
3389 }
3390 let half = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r2;
3391 (-half, half)
3392}
3393
3394pub fn set_constant_curvature_kappa(
3398 spec: &mut TermCollectionSpec,
3399 term_idx: usize,
3400 psi: &[f64],
3401) -> Result<bool, EstimationError> {
3402 let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3403 crate::bail_invalid_estim!(
3404 "constant-curvature κ write-back: term index {term_idx} out of range"
3405 );
3406 };
3407 set_single_term_constant_curvature_kappa(term, psi)
3408}
3409
3410pub fn set_single_term_constant_curvature_kappa(
3415 term: &mut SmoothTermSpec,
3416 psi: &[f64],
3417) -> Result<bool, EstimationError> {
3418 if psi.len() != 1 {
3419 crate::bail_invalid_estim!(
3420 "constant-curvature κ write-back expects exactly one value, got {}",
3421 psi.len()
3422 );
3423 }
3424 let next_kappa = psi[0];
3425 if !next_kappa.is_finite() {
3426 crate::bail_invalid_estim!(
3427 "constant-curvature κ write-back produced a non-finite κ = {next_kappa}"
3428 );
3429 }
3430 let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &mut term.basis else {
3431 crate::bail_invalid_estim!(
3432 "constant-curvature κ write-back targeted a non-constant-curvature term"
3433 );
3434 };
3435 if cc.kappa != next_kappa {
3436 cc.kappa = next_kappa;
3437 Ok(true)
3438 } else {
3439 Ok(false)
3440 }
3441}
3442
3443pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3454 get_spatial_length_scale(spec, term_idx).is_some()
3455 && !spatial_term_uses_per_axis_psi(spec, term_idx)
3456}
3457
3458pub fn all_spatial_terms_kappa_fixed(spec: &TermCollectionSpec) -> bool {
3459 spec.smooth_terms.iter().enumerate().all(|(idx, _)| {
3460 !spatial_term_supports_hyper_optimization(spec, idx)
3461 || spatial_term_has_locked_kappa(spec, idx)
3462 })
3463}
3464
3465pub fn spatial_identifiability_policy(termspec: &SmoothTermSpec) -> Option<&SpatialIdentifiability> {
3466 match &termspec.basis {
3467 SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.identifiability),
3468 SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.identifiability),
3469 _ => None,
3470 }
3471}
3472
/// Standard deviation that, together with a mean of exactly `0.0`, marks a
/// `RhoPrior::Normal` as the nullspace-degeneracy prior recognized by
/// `is_nullspace_degeneracy_prior`.
pub const NULLSPACE_WELLDET_DEGENERACY_RHO_SD: f64 = 15.0;
3477
3478pub fn is_nullspace_degeneracy_prior(prior: &gam_spec::RhoPrior) -> bool {
3481 matches!(
3482 prior,
3483 gam_spec::RhoPrior::Normal { mean, sd }
3484 if *mean == 0.0 && *sd == NULLSPACE_WELLDET_DEGENERACY_RHO_SD
3485 )
3486}
3487
/// Numerator of the data-derived lower ψ bound: `spatial_term_psi_bounds`
/// uses `ln(this / r_max)`, with `r_max` the largest pairwise distance found.
pub const KERNEL_RANGE_MIN_DIAMETER_FRACTION: f64 = 2.0;

/// Numerator of the data-derived upper ψ bound: `spatial_term_psi_bounds`
/// uses `ln(this / r_min)`, with `r_min` the smallest pairwise distance found.
pub const KERNEL_RANGE_MAX_SPACING_MULTIPLE: f64 = 1e2;
3506
3507
3508pub fn spatial_term_psi_bounds(
3517 data: ArrayView2<'_, f64>,
3518 spec: &TermCollectionSpec,
3519 term_idx: usize,
3520 options: &SpatialLengthScaleOptimizationOptions,
3521) -> (f64, f64) {
3522 let fallback = (
3523 -options.max_length_scale.ln(),
3524 -options.min_length_scale.ln(),
3525 );
3526 if constant_curvature_term_spec(spec, term_idx).is_some() {
3531 return constant_curvature_kappa_bounds(data, spec, term_idx);
3532 }
3533 let Some(term) = spec.smooth_terms.get(term_idx) else {
3534 return fallback;
3535 };
3536 let aniso = get_spatial_aniso_log_scales(spec, term_idx);
3549 let r_bounds = match spatial_term_center_strategy(term) {
3550 Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
3551 match aniso.as_deref() {
3552 Some(eta) if eta.len() == centers.ncols() => {
3553 let y = points_in_aniso_y_space(centers.view(), eta);
3554 pairwise_distance_bounds(y.view())
3555 }
3556 _ => pairwise_distance_bounds(centers.view()),
3557 }
3558 }
3559 _ => standardized_spatial_term_data(data, term)
3560 .ok()
3561 .and_then(|x| match aniso.as_deref() {
3562 Some(eta) if eta.len() == x.ncols() => {
3563 let y = points_in_aniso_y_space(x.view(), eta);
3564 pairwise_distance_bounds_sampled(y.view())
3565 }
3566 _ => pairwise_distance_bounds_sampled(x.view()),
3567 }),
3568 };
3569 let Some((r_min, r_max)) = r_bounds else {
3570 return fallback;
3571 };
3572 let psi_lo_data = (KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max).ln();
3578 let psi_hi_data = (KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min).ln();
3579 let psi_lo = psi_lo_data.max(fallback.0);
3589 let psi_hi = psi_hi_data.min(fallback.1);
3590 if psi_lo >= psi_hi {
3591 return fallback;
3594 }
3595 (psi_lo, psi_hi)
3596}
3597
3598pub fn spatial_term_psi_seed(
3602 data: ArrayView2<'_, f64>,
3603 spec: &TermCollectionSpec,
3604 term_idx: usize,
3605 options: &SpatialLengthScaleOptimizationOptions,
3606) -> Option<f64> {
3607 if get_spatial_length_scale(spec, term_idx).is_some() {
3608 return None; }
3610 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
3611 Some(0.5 * (psi_lo + psi_hi))
3612}
3613
/// Converts ψ coordinates into `(length_scale, anisotropy contrasts)`.
///
/// The length scale is `exp(-mean(psi))`, with an empty `psi` treated as a
/// single zero, so it is always `Some`. Contrasts (`psi[i] - mean(psi)`) are
/// produced only when `psi` has more than one entry.
pub fn spatial_term_psi_to_length_scale_and_aniso(psi: &[f64]) -> (Option<f64>, Option<Vec<f64>>) {
    let center = match psi {
        [] => 0.0,
        [only] => *only,
        _ => psi.iter().sum::<f64>() / psi.len() as f64,
    };
    let length_scale = Some((-center).exp());
    let contrasts = (psi.len() > 1).then(|| psi.iter().map(|&value| value - center).collect());
    (length_scale, contrasts)
}
3625
3626pub fn get_spatial_aniso_log_scales(
3628 spec: &TermCollectionSpec,
3629 term_idx: usize,
3630) -> Option<Vec<f64>> {
3631 spec.smooth_terms
3632 .get(term_idx)
3633 .and_then(|term| match &term.basis {
3634 SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales.clone(),
3635 SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales.clone(),
3636 _ => None,
3637 })
3638}
3639
3640pub fn response_aware_axis_contrasts(
3660 x: ndarray::ArrayView2<'_, f64>,
3661 y: ndarray::ArrayView1<'_, f64>,
3662) -> Option<Vec<f64>> {
3663 let n = x.nrows();
3664 let d = x.ncols();
3665 if d <= 1 || n < 4 || y.len() != n {
3666 return None;
3667 }
3668 if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
3669 return None;
3670 }
3671 let mut scores = Vec::with_capacity(d);
3672 for a in 0..d {
3673 let mut order: Vec<usize> = (0..n).collect();
3674 let col = x.column(a);
3675 order.sort_by(|&i, &j| {
3676 col[i]
3677 .partial_cmp(&col[j])
3678 .unwrap_or(std::cmp::Ordering::Equal)
3679 });
3680 let mut tv = 0.0_f64;
3681 for w in order.windows(2) {
3682 let diff = y[w[1]] - y[w[0]];
3683 tv += diff * diff;
3684 }
3685 scores.push(-0.5 * (tv + 1e-12).ln());
3687 }
3688 if scores.iter().any(|v| !v.is_finite()) {
3689 return None;
3690 }
3691 let mean = scores.iter().sum::<f64>() / d as f64;
3692 let centered: Vec<f64> = scores.iter().map(|&s| s - mean).collect();
3693 if centered.iter().all(|&v| v.abs() < 1e-9) {
3696 return None;
3697 }
3698 Some(centered)
3699}
3700
3701pub fn apply_response_aware_anisotropy_seed(
3710 data: ArrayView2<'_, f64>,
3711 y: ndarray::ArrayView1<'_, f64>,
3712 spec: &mut TermCollectionSpec,
3713 spatial_terms: &[usize],
3714) {
3715 const MAX_NUDGE: f64 = std::f64::consts::LN_2;
3720 for &term_idx in spatial_terms {
3721 let Some(current_eta) = get_spatial_aniso_log_scales(spec, term_idx) else {
3722 continue;
3723 };
3724 let d = current_eta.len();
3725 if d <= 1 {
3726 continue;
3727 }
3728 let Some(term) = spec.smooth_terms.get(term_idx) else {
3729 continue;
3730 };
3731 let feature_cols = term.basis.structural_feature_cols();
3732 if feature_cols.len() != d {
3733 continue;
3734 }
3735 let Ok(x) = select_columns(data, &feature_cols) else {
3736 continue;
3737 };
3738 let Some(contrast) = response_aware_axis_contrasts(x.view(), y) else {
3739 continue;
3740 };
3741 let nudged: Vec<f64> = current_eta
3742 .iter()
3743 .zip(contrast.iter())
3744 .map(|(&eta_a, &c_a)| eta_a + c_a.clamp(-MAX_NUDGE, MAX_NUDGE))
3745 .collect();
3746 if let Err(err) = set_spatial_aniso_log_scales(spec, term_idx, nudged) {
3749 log::debug!(
3750 "[spatial-kappa] response-aware anisotropy seed skipped for term {term_idx}: {err}"
3751 );
3752 }
3753 }
3754}
3755
3756pub fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
3758 spec.smooth_terms
3759 .get(term_idx)
3760 .and_then(|term| match &term.basis {
3761 SmoothBasisSpec::ThinPlate { feature_cols, .. } => Some(feature_cols.len()),
3762 SmoothBasisSpec::Matern { feature_cols, .. } => Some(feature_cols.len()),
3763 SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
3764 _ => None,
3765 })
3766}
3767
3768pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
3775 for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
3776 let (aniso, length_scale) = match &term.basis {
3777 SmoothBasisSpec::Matern { spec, .. } => {
3778 (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
3779 }
3780 SmoothBasisSpec::Duchon { spec, .. } => {
3781 (spec.aniso_log_scales.as_ref(), spec.length_scale)
3782 }
3783 _ => (None, None),
3784 };
3785 let Some(eta) = aniso else { continue };
3786 if eta.is_empty() {
3787 continue;
3788 }
3789 let mut lines = match length_scale {
3790 Some(ls) => format!(
3791 "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
3792 term_idx, term.name, ls
3793 ),
3794 None => format!(
3795 "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
3796 term_idx, term.name
3797 ),
3798 };
3799 for (a, &eta_a) in eta.iter().enumerate() {
3800 if let Some(ls) = length_scale {
3801 let length_a = ls * (-eta_a).exp();
3802 let kappa_a = (1.0 / ls) * eta_a.exp();
3803 lines.push_str(&format!(
3804 "\n axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
3805 a, eta_a, length_a, kappa_a
3806 ));
3807 } else {
3808 lines.push_str(&format!("\n axis {}: eta={:+.4}", a, eta_a));
3809 }
3810 }
3811 log::info!("{}", lines);
3812 }
3813}
3814
3815pub fn set_spatial_aniso_log_scales(
3817 spec: &mut TermCollectionSpec,
3818 term_idx: usize,
3819 eta: Vec<f64>,
3820) -> Result<(), EstimationError> {
3821 let eta = center_aniso_log_scales(&eta);
3822 let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3823 crate::bail_invalid_estim!("spatial aniso_log_scales term index {term_idx} out of range");
3824 };
3825 match &mut term.basis {
3826 SmoothBasisSpec::Matern { spec, .. } => {
3827 spec.aniso_log_scales = Some(eta);
3828 Ok(())
3829 }
3830 SmoothBasisSpec::Duchon { spec, .. } => {
3831 spec.aniso_log_scales = Some(eta);
3832 Ok(())
3833 }
3834 _ => Err(EstimationError::InvalidInput(format!(
3835 "term '{}' does not support aniso_log_scales",
3836 term.name
3837 ))),
3838 }
3839}
3840
3841pub fn sync_aniso_contrasts_from_metadata(
3848 spec: &mut TermCollectionSpec,
3849 design: &SmoothDesign,
3850) {
3851 for (term_idx, term) in design.terms.iter().enumerate() {
3852 let meta_aniso = match &term.metadata {
3853 BasisMetadata::Matern {
3854 aniso_log_scales, ..
3855 } => aniso_log_scales.clone(),
3856 BasisMetadata::Duchon {
3857 aniso_log_scales, ..
3858 } => aniso_log_scales.clone(),
3859 _ => None,
3860 };
3861 if let Some(eta) = meta_aniso
3862 && eta.len() > 1
3863 {
3864 set_spatial_aniso_log_scales(spec, term_idx, eta).ok();
3865 }
3866 }
3867}
3868
/// Options controlling spatial length-scale (ψ / κ) optimization.
///
/// `validate` checks the numeric fields it documents below; the defaults are
/// given by the `Default` implementation.
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// Master switch; defaults to `true`.
    /// NOTE(review): consumers are outside this section — presumably turns the optimization off when `false`.
    pub enabled: bool,
    /// Defaults to 80. NOTE(review): presumably a cap on outer optimizer iterations — confirm at the call site.
    pub max_outer_iter: usize,
    /// Relative tolerance; must be finite and `> 0` (checked by `validate`). Defaults to `1e-4`.
    pub rel_tol: f64,
    /// Step size on the log scale; must be finite and `> 0` (checked by `validate`). Defaults to `ln 2`.
    pub log_step: f64,
    /// Smallest admissible length scale; must be finite, `> 0`, and strictly below
    /// `max_length_scale`. `spatial_term_psi_bounds` uses `-ln(min_length_scale)` as
    /// the fallback upper ψ bound. Defaults to `1e-3`.
    pub min_length_scale: f64,
    /// Largest admissible length scale; must be finite and `> 0`.
    /// `spatial_term_psi_bounds` uses `-ln(max_length_scale)` as the fallback lower
    /// ψ bound. Defaults to `1e3`.
    pub max_length_scale: f64,
    /// Defaults to 10 000. NOTE(review): presumably the row count above which a pilot subsample is used — confirm.
    pub pilot_subsample_threshold: usize,
    /// Optional budget in seconds; defaults to `None`.
    /// NOTE(review): presumably a wall-clock limit for the outer loop — confirm at the call site.
    pub outer_wall_clock_budget_secs: Option<f64>,
}
3907
impl Default for SpatialLengthScaleOptimizationOptions {
    /// Optimization enabled, at most 80 outer iterations, relative tolerance
    /// `1e-4`, log step `ln 2`, length scales in `[1e-3, 1e3]`, pilot subsample
    /// threshold 10 000, and no wall-clock budget.
    fn default() -> Self {
        Self {
            enabled: true,
            max_outer_iter: 80,
            rel_tol: 1e-4,
            log_step: std::f64::consts::LN_2,
            min_length_scale: 1e-3,
            max_length_scale: 1e3,
            pilot_subsample_threshold: 10_000,
            outer_wall_clock_budget_secs: None,
        }
    }
}
3922
3923impl SpatialLengthScaleOptimizationOptions {
3924 pub fn validate(&self) -> Result<(), String> {
3942 if !self.min_length_scale.is_finite() || self.min_length_scale <= 0.0 {
3943 return Err(SmoothError::invalid_config(format!(
3944 "SpatialLengthScaleOptimizationOptions::min_length_scale must be > 0 and finite, got {}",
3945 self.min_length_scale
3946 ))
3947 .into());
3948 }
3949 if !self.max_length_scale.is_finite() || self.max_length_scale <= 0.0 {
3950 return Err(SmoothError::invalid_config(format!(
3951 "SpatialLengthScaleOptimizationOptions::max_length_scale must be > 0 and finite, got {}",
3952 self.max_length_scale
3953 ))
3954 .into());
3955 }
3956 if self.min_length_scale >= self.max_length_scale {
3957 return Err(SmoothError::invalid_config(format!(
3958 "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
3959 self.min_length_scale, self.max_length_scale
3960 ))
3961 .into());
3962 }
3963 if !self.rel_tol.is_finite() || self.rel_tol <= 0.0 {
3964 return Err(SmoothError::invalid_config(format!(
3965 "SpatialLengthScaleOptimizationOptions::rel_tol must be > 0 and finite, got {}",
3966 self.rel_tol
3967 ))
3968 .into());
3969 }
3970 if !self.log_step.is_finite() || self.log_step <= 0.0 {
3971 return Err(SmoothError::invalid_config(format!(
3972 "SpatialLengthScaleOptimizationOptions::log_step must be > 0 and finite, got {}",
3973 self.log_step
3974 ))
3975 .into());
3976 }
3977 Ok(())
3978 }
3979}
3980
/// Description of one random-effect design block.
#[derive(Debug, Clone)]
pub struct RandomEffectBlock {
    /// Name of the block. NOTE(review): presumably the random-effect term's name — confirm.
    pub name: String,
    /// One entry per row; `None` where the row has no group.
    /// NOTE(review): presumably `Some(g)` is a zero-based group column with `g < num_groups` — confirm.
    pub group_ids: Vec<Option<usize>>,
    /// Number of groups. NOTE(review): presumably the column count of the block — confirm.
    pub num_groups: usize,
    /// NOTE(review): presumably the encoded factor levels retained as groups — confirm against the builder.
    pub kept_levels: Vec<u64>,
}
3990
/// Entries with absolute value at or below this threshold are treated as zero
/// when counting non-zeros of dense blocks and when emitting sparse triplets.
pub const BLOCK_SPARSE_ZERO_EPS: f64 = 1e-12;

/// Largest fraction of non-zero cells for which a design made only of
/// intercept/dense blocks is still assembled as a sparse matrix.
pub const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.20;
3994
3995pub fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
3996 blocks
3997 .iter()
3998 .any(|block| matches!(block, DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_)))
3999}
4000
4001pub fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
4002 match block {
4003 DesignBlock::Intercept(n) => Some(*n),
4004 DesignBlock::RandomEffect(op) => {
4005 Some(op.group_ids.iter().filter(|gid| gid.is_some()).count())
4006 }
4007 DesignBlock::Sparse(sparse) => Some(sparse.val().len()),
4008 DesignBlock::Dense(dense) => dense.as_dense_ref().map(|matrix| {
4009 matrix
4010 .iter()
4011 .filter(|&&value| value.abs() > BLOCK_SPARSE_ZERO_EPS)
4012 .count()
4013 }),
4014 }
4015}
4016
/// Tries to assemble the column-wise concatenation of `blocks` as one sparse design matrix.
///
/// Returns `Ok(None)` when sparse assembly is declined:
/// - `blocks` is empty, has zero rows, or has at most 32 columns in total;
/// - some block cannot report a non-zero count (see `sparse_compatible_block_nnz`);
/// - no block is intrinsically sparse and the running non-zero count exceeds
///   `BLOCK_SPARSE_MAX_DENSITY` of all cells.
///
/// # Errors
/// Fails when a dense block has no materialized matrix during assembly, or
/// when the sparse matrix cannot be built from the collected triplets.
pub fn try_build_sparse_design_from_blocks(
    blocks: &[DesignBlock],
) -> Result<Option<DesignMatrix>, BasisError> {
    if blocks.is_empty() {
        return Ok(None);
    }
    // The row count is taken from the first block only.
    let nrows = blocks[0].nrows();
    let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
    if nrows == 0 || ncols == 0 || ncols <= 32 {
        return Ok(None);
    }

    // Designs that already contain sparse/random-effect blocks are kept sparse
    // regardless of density; otherwise a density cap applies.
    let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
    let sparse_nnz_limit = if preserve_sparse_storage {
        usize::MAX
    } else {
        let total_cells = nrows.saturating_mul(ncols);
        ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
    };
    // First pass: count non-zeros and bail out early once the cap is exceeded.
    let mut nnz = 0usize;
    for block in blocks {
        let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
            block_nnz
        } else {
            return Ok(None);
        };
        nnz = nnz.saturating_add(block_nnz);
        if nnz > sparse_nnz_limit {
            return Ok(None);
        }
    }

    // Second pass: emit (row, column, value) triplets, shifting each block's
    // columns by the total width of the blocks before it.
    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
    let mut col_offset = 0usize;
    for block in blocks {
        match block {
            DesignBlock::Intercept(n) => {
                // A single column of ones.
                for row in 0..*n {
                    triplets.push(Triplet::new(row, col_offset, 1.0));
                }
            }
            DesignBlock::RandomEffect(op) => {
                // One indicator entry per row that has a group.
                for (row, group_id) in op.group_ids.iter().enumerate() {
                    if let Some(group) = group_id {
                        triplets.push(Triplet::new(row, col_offset + group, 1.0));
                    }
                }
            }
            DesignBlock::Sparse(sparse) => {
                // Walk the column-pointer / row-index arrays of the stored matrix.
                let (symbolic, values) = sparse.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                for col in 0..sparse.ncols() {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let value = values[idx];
                        // Stored near-zeros are dropped.
                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
                            triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
                        }
                    }
                }
            }
            DesignBlock::Dense(dense) => {
                let matrix = dense.as_dense_ref().ok_or_else(|| {
                    BasisError::InvalidInput(
                        "sparse-compatible block assembly requires materialized dense blocks"
                            .to_string(),
                    )
                })?;
                for row in 0..matrix.nrows() {
                    for col in 0..matrix.ncols() {
                        let value = matrix[[row, col]];
                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
                            triplets.push(Triplet::new(row, col_offset + col, value));
                        }
                    }
                }
            }
        }
        col_offset += block.ncols();
    }

    let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
        BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
    })?;
    Ok(Some(DesignMatrix::Sparse(
        gam_linalg::matrix::SparseDesignMatrix::new(sparse),
    )))
}
4105
4106pub fn assemble_term_collection_design_matrix(
4107 blocks: Vec<DesignBlock>,
4108) -> Result<DesignMatrix, BasisError> {
4109 if let Some(sparse) = try_build_sparse_design_from_blocks(&blocks)? {
4110 return Ok(sparse);
4111 }
4112 let block_op = BlockDesignOperator::new(blocks).map_err(|e| {
4113 BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
4114 })?;
4115 Ok(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
4116 Arc::new(block_op),
4117 )))
4118}
4119
4120pub fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
4121 let n = data.nrows();
4122 let p = data.ncols();
4123 for &c in cols {
4124 if c >= p {
4125 crate::bail_dim_basis!("feature column {c} is out of bounds for data with {p} columns");
4126 }
4127 }
4128 let mut out = Array2::<f64>::zeros((n, cols.len()));
4129 for (j, &c) in cols.iter().enumerate() {
4130 out.column_mut(j).assign(&data.column(c));
4131 }
4132 Ok(out)
4133}
4134
/// Short label for a non-finite value: `"NaN"`, `"+Inf"`, or `"-Inf"`.
///
/// Only NaN and the sign bit are inspected, so a finite input is labelled by
/// its sign as well.
pub fn nonfinite_value_label(value: f64) -> &'static str {
    match (value.is_nan(), value.is_sign_positive()) {
        (true, _) => "NaN",
        (false, true) => "+Inf",
        (false, false) => "-Inf",
    }
}
4144
4145pub fn validate_term_feature_column_finite(
4146 data: ArrayView2<'_, f64>,
4147 term_kind: &str,
4148 term_name: &str,
4149 feature_col: usize,
4150) -> Result<(), BasisError> {
4151 let p = data.ncols();
4152 if feature_col >= p {
4153 crate::bail_dim_basis!(
4154 "{term_kind} term '{term_name}' feature column {feature_col} out of bounds for {p} columns"
4155 );
4156 }
4157 for (row, &value) in data.column(feature_col).iter().enumerate() {
4158 if !value.is_finite() {
4159 crate::bail_invalid_basis!(
4160 "{term_kind} term '{term_name}' feature column {feature_col} row {row} contains non-finite value {}",
4161 nonfinite_value_label(value)
4162 );
4163 }
4164 }
4165 Ok(())
4166}
4167
4168pub fn validate_smooth_terms_finite_inputs(
4169 data: ArrayView2<'_, f64>,
4170 terms: &[SmoothTermSpec],
4171) -> Result<(), BasisError> {
4172 for term in terms {
4173 for feature_col in smooth_term_feature_cols(term) {
4174 validate_term_feature_column_finite(data, "smooth", &term.name, feature_col)?;
4175 }
4176 }
4177 Ok(())
4178}
4179
4180pub fn validate_term_collection_finite_inputs(
4181 data: ArrayView2<'_, f64>,
4182 spec: &TermCollectionSpec,
4183) -> Result<(), BasisError> {
4184 for term in &spec.linear_terms {
4185 validate_term_feature_column_finite(data, "linear", &term.name, term.feature_col)?;
4186 }
4187 for term in &spec.random_effect_terms {
4188 validate_term_feature_column_finite(data, "random-effect", &term.name, term.feature_col)?;
4189 }
4190 validate_smooth_terms_finite_inputs(data, &spec.smooth_terms)
4191}
4192
/// Key used to group spatial smooth terms that can share one set of centers.
///
/// Built by `spatial_term_group_key` and used as a `BTreeMap` key during joint
/// center planning. The derived `Ord` compares fields in declaration order, so
/// reordering the fields changes the iteration order of that map.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct JointSpatialCenterGroupKey {
    /// Feature columns of the term's basis (cloned from the basis spec).
    feature_cols: Vec<usize>,
    /// Kind of center strategy, as reported by `center_strategy_kind`.
    strategy_kind: CenterStrategyKind,
    /// Strategy-specific extra parameter: `max_iter` for k-means,
    /// `points_per_dim` for a uniform grid, and `0` for everything else.
    strategy_aux: usize,
    /// Number of centers requested by the strategy
    /// (from `center_strategy_num_centers`).
    requested_num_centers: usize,
    /// Bit patterns (`f64::to_bits`) of the term's explicit input scales, or
    /// `None` when the term carries no explicit scales.
    input_scale_bits: Option<Vec<u64>>,
}
4201
4202pub fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
4203 match &term.basis {
4204 SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.len() + 1,
4205 SmoothBasisSpec::Duchon {
4206 feature_cols, spec, ..
4207 } => match spec.nullspace_order {
4208 crate::basis::DuchonNullspaceOrder::Zero => 1,
4209 crate::basis::DuchonNullspaceOrder::Linear => feature_cols.len() + 1,
4210 crate::basis::DuchonNullspaceOrder::Degree(degree) => {
4211 crate::basis::duchon_nullspace_dimension(feature_cols.len(), degree)
4212 }
4213 },
4214 SmoothBasisSpec::Matern { .. } => 1,
4215 _ => 1,
4216 }
4217}
4218
4219pub fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
4220 let (feature_cols, strategy, input_scales) = match &term.basis {
4221 SmoothBasisSpec::ThinPlate {
4222 feature_cols,
4223 spec,
4224 input_scales,
4225 } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4226 SmoothBasisSpec::Matern {
4227 feature_cols,
4228 spec,
4229 input_scales,
4230 } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4231 SmoothBasisSpec::Duchon {
4232 feature_cols,
4233 spec,
4234 input_scales,
4235 } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4236 _ => return None,
4237 };
4238 let strategy_kind = center_strategy_kind(strategy);
4239 let strategy_aux = match strategy {
4240 CenterStrategy::Auto(inner) => match inner.as_ref() {
4241 CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4242 CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4243 _ => 0,
4244 },
4245 CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4246 CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4247 _ => 0,
4248 };
4249 Some(JointSpatialCenterGroupKey {
4250 feature_cols: feature_cols.clone(),
4251 strategy_kind,
4252 strategy_aux,
4253 requested_num_centers: center_strategy_num_centers(strategy)?,
4254 input_scale_bits: input_scales
4255 .map(|values| values.iter().map(|value| value.to_bits()).collect()),
4256 })
4257}
4258
4259pub fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
4260 match &term.basis {
4261 SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.center_strategy),
4262 SmoothBasisSpec::Matern { spec, .. } => Some(&spec.center_strategy),
4263 SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.center_strategy),
4264 _ => None,
4265 }
4266}
4267
4268pub fn set_spatial_term_centers(
4269 term: &mut SmoothTermSpec,
4270 centers: Array2<f64>,
4271) -> Result<(), BasisError> {
4272 match &mut term.basis {
4273 SmoothBasisSpec::ThinPlate { spec, .. } => {
4274 spec.center_strategy = CenterStrategy::UserProvided(centers);
4275 Ok(())
4276 }
4277 SmoothBasisSpec::Matern { spec, .. } => {
4278 spec.center_strategy = CenterStrategy::UserProvided(centers);
4279 Ok(())
4280 }
4281 SmoothBasisSpec::Duchon { spec, .. } => {
4282 spec.center_strategy = CenterStrategy::UserProvided(centers);
4283 Ok(())
4284 }
4285 _ => Err(BasisError::InvalidInput(format!(
4286 "term '{}' does not support spatial center planning",
4287 term.name
4288 ))),
4289 }
4290}
4291
4292pub fn standardized_spatial_term_data(
4293 data: ArrayView2<'_, f64>,
4294 term: &SmoothTermSpec,
4295) -> Result<Array2<f64>, BasisError> {
4296 let (feature_cols, input_scales) = match &term.basis {
4297 SmoothBasisSpec::ThinPlate {
4298 feature_cols,
4299 input_scales,
4300 ..
4301 }
4302 | SmoothBasisSpec::Matern {
4303 feature_cols,
4304 input_scales,
4305 ..
4306 }
4307 | SmoothBasisSpec::Duchon {
4308 feature_cols,
4309 input_scales,
4310 ..
4311 } => (feature_cols, input_scales.as_ref()),
4312 _ => {
4313 crate::bail_invalid_basis!("term '{}' is not a spatial smooth", term.name);
4314 }
4315 };
4316 let mut x = select_columns(data, feature_cols)?;
4317 if let Some(scales) = input_scales {
4318 apply_input_standardization(&mut x, scales);
4319 } else if let Some(scales) = compute_spatial_input_scales(x.view()) {
4320 apply_input_standardization(&mut x, &scales);
4321 }
4322 Ok(x)
4323}
4324
4325pub fn plan_joint_spatial_centers_for_term_blocks(
4326 data: ArrayView2<'_, f64>,
4327 term_blocks: &[Vec<SmoothTermSpec>],
4328) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
4329 let mut planned_blocks = term_blocks.to_vec();
4330 let n = data.nrows();
4331 let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
4332
4333 for (block_idx, terms) in planned_blocks.iter().enumerate() {
4334 for (term_idx, term) in terms.iter().enumerate() {
4335 let Some(strategy) = spatial_term_center_strategy(term) else {
4336 continue;
4337 };
4338 if !center_strategy_is_auto(strategy) {
4339 continue;
4340 }
4341 let Some(group_key) = spatial_term_group_key(term) else {
4342 continue;
4343 };
4344 if !matches!(
4345 group_key.strategy_kind,
4346 CenterStrategyKind::EqualMass
4347 | CenterStrategyKind::EqualMassCovarRepresentative
4348 | CenterStrategyKind::FarthestPoint
4349 | CenterStrategyKind::KMeans
4350 ) {
4351 continue;
4352 }
4353 if center_strategy_num_centers(strategy).is_none() {
4354 continue;
4355 }
4356 groups
4357 .entry(group_key)
4358 .or_default()
4359 .push((block_idx, term_idx));
4360 }
4361 }
4362
4363 for (group_key, members) in groups {
4364 if members.len() < 2 {
4365 continue;
4366 }
4367 let min_required = members
4368 .iter()
4369 .map(|&(block_idx, term_idx)| {
4370 spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
4371 })
4372 .max()
4373 .unwrap_or(1);
4374 let joint_centers = group_key
4375 .requested_num_centers
4376 .max(min_required)
4377 .min(n.max(1));
4378 let (first_block_idx, first_term_idx) = members[0];
4379 let prototype = &planned_blocks[first_block_idx][first_term_idx];
4380 let standardized = standardized_spatial_term_data(data, prototype)?;
4381 let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
4382 BasisError::InvalidInput(format!(
4383 "term '{}' lost its spatial center strategy during joint planning",
4384 prototype.name
4385 ))
4386 })?;
4387 let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
4388 let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
4389 log::info!(
4390 "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
4391 shared_centers.nrows(),
4392 members.len(),
4393 group_key.feature_cols,
4394 group_key.requested_num_centers,
4395 );
4396 for (block_idx, term_idx) in members {
4397 set_spatial_term_centers(
4398 &mut planned_blocks[block_idx][term_idx],
4399 shared_centers.clone(),
4400 )?;
4401 }
4402 }
4403
4404 for block in planned_blocks.iter_mut() {
4411 for term in block.iter_mut() {
4412 auto_init_length_scale_in_place(data, term);
4413 }
4414 }
4415
4416 Ok(planned_blocks)
4417}
4418
4419pub fn auto_initial_length_scale(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> f64 {
4426 const LENGTH_SCALE_FLOOR: f64 = 1e-6;
4429 let n = data.nrows();
4430 if n == 0 || feature_cols.is_empty() {
4431 return 1.0;
4432 }
4433 let mut max_range = 0.0_f64;
4434 for &c in feature_cols {
4435 if c >= data.ncols() {
4436 continue;
4437 }
4438 let col = data.column(c);
4439 let mut lo = f64::INFINITY;
4440 let mut hi = f64::NEG_INFINITY;
4441 for &v in col.iter() {
4442 if v.is_finite() {
4443 if v < lo {
4444 lo = v;
4445 }
4446 if v > hi {
4447 hi = v;
4448 }
4449 }
4450 }
4451 if hi > lo {
4452 let r = hi - lo;
4453 if r > max_range {
4454 max_range = r;
4455 }
4456 }
4457 }
4458 if !max_range.is_finite() || max_range <= 0.0 {
4459 return 1.0;
4460 }
4461 let init = max_range / (n as f64).sqrt();
4462 init.max(LENGTH_SCALE_FLOOR).min(max_range)
4463}
4464
4465pub fn auto_init_length_scale_in_place(data: ArrayView2<'_, f64>, term: &mut SmoothTermSpec) {
4469 auto_init_length_scale_in_basis(data, &mut term.basis);
4470}
4471
4472pub fn auto_init_length_scale_in_basis(data: ArrayView2<'_, f64>, basis: &mut SmoothBasisSpec) {
4485 match basis {
4486 SmoothBasisSpec::Matern {
4487 feature_cols, spec, ..
4488 } => {
4489 if spec.length_scale == 0.0 {
4490 spec.length_scale = auto_initial_length_scale(data, feature_cols);
4491 }
4492 }
4493 SmoothBasisSpec::ThinPlate {
4494 feature_cols, spec, ..
4495 } => {
4496 if spec.length_scale == 0.0 {
4497 spec.length_scale = auto_initial_length_scale(data, feature_cols);
4498 }
4499 }
4500 SmoothBasisSpec::ByVariable { inner, .. }
4501 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
4502 auto_init_length_scale_in_basis(data, inner);
4503 }
4504 SmoothBasisSpec::BySmooth { smooth, .. } => {
4505 auto_init_length_scale_in_basis(data, smooth);
4506 }
4507 _ => {}
4508 }
4509}
4510
4511impl LinearFitConditioning {
4512 pub fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
4513 const SCALE_EPS: f64 = 1e-12;
4514 let n = design.design.nrows();
4515 let p = design.design.ncols();
4516 let mut columns = Vec::with_capacity(selected_cols.len());
4517 if n == 0 || selected_cols.is_empty() {
4518 return Self {
4519 intercept_idx: design.intercept_range.start,
4520 columns,
4521 };
4522 }
4523 let chunk_rows = gam_linalg::utils::row_chunk_for_byte_budget(n, p);
4524 let mut sums = vec![0.0_f64; selected_cols.len()];
4530 for start in (0..n).step_by(chunk_rows) {
4531 let end = (start + chunk_rows).min(n);
4532 let chunk = design
4533 .design
4534 .try_row_chunk(start..end)
4535 .expect("LinearFitConditioning::from_columns row chunk failed");
4536 for (k, &col_idx) in selected_cols.iter().enumerate() {
4537 let column = chunk.column(col_idx);
4538 for &v in column.iter() {
4539 sums[k] += v;
4540 }
4541 }
4542 }
4543 let inv_n = 1.0_f64 / n as f64;
4544 let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
4545 let mut sq_devs = vec![0.0_f64; selected_cols.len()];
4546 for start in (0..n).step_by(chunk_rows) {
4547 let end = (start + chunk_rows).min(n);
4548 let chunk = design
4549 .design
4550 .try_row_chunk(start..end)
4551 .expect("LinearFitConditioning::from_columns row chunk failed");
4552 for (k, &col_idx) in selected_cols.iter().enumerate() {
4553 let mean_k = means[k];
4554 let column = chunk.column(col_idx);
4555 for &v in column.iter() {
4556 let d = v - mean_k;
4557 sq_devs[k] += d * d;
4558 }
4559 }
4560 }
4561 for (k, &col_idx) in selected_cols.iter().enumerate() {
4562 let mean = means[k];
4563 let var = sq_devs[k] * inv_n;
4564 let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
4565 (mean, var.sqrt())
4566 } else {
4567 (0.0, 1.0)
4570 };
4571 columns.push(LinearColumnConditioning {
4572 col_idx,
4573 mean,
4574 scale,
4575 });
4576 }
4577 Self {
4578 intercept_idx: design.intercept_range.start,
4579 columns,
4580 }
4581 }
4582
4583 pub fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
4584 let mut out = design.clone();
4585 for col in &self.columns {
4586 {
4587 let mut dst = out.column_mut(col.col_idx);
4588 dst -= col.mean;
4589 }
4590 if col.scale != 1.0 {
4591 out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
4592 }
4593 }
4594 out
4595 }
4596
4597 fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
4598 let mut out = mat.clone();
4599 let intercept = self.intercept_idx;
4600 for col in &self.columns {
4601 let intercept_col = out.column(intercept).to_owned();
4602 let mut target = out.column_mut(col.col_idx);
4603 target -= &(intercept_col * col.mean);
4604 if col.scale != 1.0 {
4605 target.mapv_inplace(|v| v / col.scale);
4606 }
4607 }
4608 out
4609 }
4610
4611 fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
4612 let mut out = mat.clone();
4613 let intercept = self.intercept_idx;
4614 for col in &self.columns {
4615 let interceptrow = out.row(intercept).to_owned();
4616 let mut target = out.row_mut(col.col_idx);
4617 target -= &(interceptrow * col.mean);
4618 if col.scale != 1.0 {
4619 target.mapv_inplace(|v| v / col.scale);
4620 }
4621 }
4622 out
4623 }
4624
4625 fn left_multiply_by_m_inv_transpose(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4630 let mut out = mat_internal.clone();
4631 let intercept = self.intercept_idx;
4632 let interceptrow_snapshot = mat_internal.row(intercept).to_owned();
4633 for col in &self.columns {
4634 if col.scale != 1.0 {
4635 out.row_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4636 }
4637 if col.mean != 0.0 {
4638 let mut target = out.row_mut(col.col_idx);
4639 target += &(&interceptrow_snapshot * col.mean);
4640 }
4641 }
4642 out
4643 }
4644
4645 fn right_multiply_by_m_inv(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4648 let mut out = mat_internal.clone();
4649 let intercept = self.intercept_idx;
4650 let intercept_col_snapshot = mat_internal.column(intercept).to_owned();
4651 for col in &self.columns {
4652 if col.scale != 1.0 {
4653 out.column_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4654 }
4655 if col.mean != 0.0 {
4656 let mut target = out.column_mut(col.col_idx);
4657 target += &(&intercept_col_snapshot * col.mean);
4658 }
4659 }
4660 out
4661 }
4662
4663 pub fn transform_blockwise_penalties_to_internal(
4670 &self,
4671 penalties: &[BlockwisePenalty],
4672 p: usize,
4673 ) -> Vec<crate::penalty_spec::PenaltySpec> {
4674 let conditioning_cols: std::collections::HashSet<usize> =
4675 self.columns.iter().map(|c| c.col_idx).collect();
4676 penalties
4677 .iter()
4678 .map(|bp| {
4679 let overlaps =
4680 (bp.col_range.start..bp.col_range.end).any(|j| conditioning_cols.contains(&j));
4681 if overlaps {
4682 let global = bp.to_global(p);
4685 let right = self.transform_matrix_columnswith_a(&global);
4686 let transformed = self.transform_matrixrowswith_a_transpose(&right);
4687 crate::penalty_spec::PenaltySpec::Dense(transformed)
4688 } else {
4689 crate::penalty_spec::PenaltySpec::from_blockwise(bp.clone())
4692 }
4693 })
4694 .collect()
4695 }
4696
4697 pub fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
4698 let mut beta = beta_internal.clone();
4699 let intercept = self.intercept_idx;
4700 for col in &self.columns {
4701 beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
4702 beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
4703 }
4704 beta
4705 }
4706
4707 pub fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
4710 let right = self.right_multiply_by_m_inv(h_internal);
4711 self.left_multiply_by_m_inv_transpose(&right)
4712 }
4713
4714 pub fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
4715 if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
4716 (min * col.scale, max * col.scale)
4717 } else {
4718 (min, max)
4719 }
4720 }
4721}
4722
4723pub fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
4724 match metadata {
4725 BasisMetadata::ThinPlate {
4726 centers,
4727 length_scale,
4728 periodic,
4729 identifiability_transform: None,
4730 input_scales,
4731 radial_reparam,
4732 } => BasisMetadata::ThinPlate {
4733 centers,
4734 length_scale,
4735 periodic,
4736 identifiability_transform: Some(Array2::eye(raw_cols)),
4737 input_scales,
4738 radial_reparam,
4739 },
4740 BasisMetadata::Duchon {
4741 centers,
4742 length_scale,
4743 periodic,
4744 power,
4745 nullspace_order,
4746 identifiability_transform: None,
4747 input_scales,
4748 aniso_log_scales,
4749 operator_collocation_points,
4750 radial_reparam,
4751 } => BasisMetadata::Duchon {
4752 centers,
4753 length_scale,
4754 periodic,
4755 power,
4756 nullspace_order,
4757 identifiability_transform: Some(Array2::eye(raw_cols)),
4758 input_scales,
4759 aniso_log_scales,
4760 operator_collocation_points,
4761 radial_reparam,
4762 },
4763 other => other,
4764 }
4765}
4766
4767pub fn matern_operator_penalty_triplet_from_metadata(
4768 metadata: &BasisMetadata,
4769) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4770 let BasisMetadata::Matern {
4771 centers,
4772 length_scale,
4773 periodic,
4774 nu,
4775 include_intercept,
4776 identifiability_transform,
4777 aniso_log_scales,
4778 input_scales,
4779 ..
4780 } = metadata
4781 else {
4782 crate::bail_invalid_basis!("Matérn operator penalties require Matérn metadata");
4783 };
4784 let penalty_length_scale = match input_scales.as_deref() {
4796 Some(scales) => compensate_length_scale_for_standardization(*length_scale, scales),
4797 None => *length_scale,
4798 };
4799 matern_operator_penalty_triplet_at_length_scale(
4800 centers.view(),
4801 periodic.as_deref(),
4802 identifiability_transform.as_ref(),
4803 *nu,
4804 *include_intercept,
4805 aniso_log_scales.as_deref(),
4806 penalty_length_scale,
4807 )
4808}
4809
4810pub fn matern_operator_penalty_triplet_at_length_scale(
4828 centers: ArrayView2<'_, f64>,
4829 periodic: Option<&[Option<f64>]>,
4830 identifiability_transform: Option<&Array2<f64>>,
4831 nu: crate::basis::MaternNu,
4832 include_intercept: bool,
4833 aniso_log_scales: Option<&[f64]>,
4834 effective_length_scale: f64,
4835) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4836 let penalty_centers = crate::basis::expand_periodic_centers(¢ers.to_owned(), periodic)?;
4837 let ops = build_matern_collocation_operator_matrices(
4838 penalty_centers.view(),
4839 None,
4840 effective_length_scale,
4841 nu,
4842 include_intercept,
4843 identifiability_transform.map(|z| z.view()),
4844 aniso_log_scales,
4845 )?;
4846 const ORDER_EPS: f64 = 1e-9;
4854 let d = penalty_centers.ncols();
4855 let m = nu.half_integer_value() + 0.5 * d as f64;
4856 let mut candidates = Vec::with_capacity(3);
4857 for (raw, source, min_order) in [
4858 (ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass, 0.0),
4859 (ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension, 1.0),
4860 (
4861 ops.d2.t().dot(&ops.d2),
4862 PenaltySource::OperatorStiffness,
4863 2.0,
4864 ),
4865 ] {
4866 if min_order > 0.0 && m <= min_order + ORDER_EPS {
4867 continue;
4868 }
4869 let sym = (&raw + &raw.t()) * 0.5;
4870 let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
4871 candidates.push(PenaltyCandidate {
4872 matrix,
4873 nullspace_dim_hint: 0,
4874 source,
4875 normalization_scale,
4876 kronecker_factors: None,
4877 op: None,
4878 });
4879 }
4880 filter_active_penalty_candidates(candidates)
4881}
4882
4883pub fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
4884 let matrix = (matrix + &matrix.t().to_owned()) * 0.5;
4889 let matrix = crate::basis::project_penalty_to_psd_cone(&matrix);
4891 let c = matrix.iter().map(|v| v * v).sum::<f64>().sqrt();
4892 if c.is_finite() && c > 0.0 {
4893 (matrix.mapv(|v| v / c), c)
4894 } else {
4895 (matrix, 1.0)
4896 }
4897}
4898
4899pub fn tensor_product_design_from_sparse_marginals(
4900 marginal_sparse: &[&SparseColMat<usize, f64>],
4901) -> Result<SparseColMat<usize, f64>, BasisError> {
4902 if marginal_sparse.is_empty() {
4903 crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
4904 }
4905 let n = marginal_sparse[0].nrows();
4906 for (i, m) in marginal_sparse.iter().enumerate().skip(1) {
4907 if m.nrows() != n {
4908 crate::bail_dim_basis!(
4909 "tensor sparse marginal row mismatch at dim {i}: expected {n}, got {}",
4910 m.nrows()
4911 );
4912 }
4913 }
4914 let dims: Vec<usize> = marginal_sparse.iter().map(|m| m.ncols()).collect();
4915 let total_cols = dims.iter().try_fold(1usize, |acc, &q| {
4916 acc.checked_mul(q)
4917 .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
4918 })?;
4919 let mut strides = vec![1usize; dims.len()];
4920 for d in (0..dims.len().saturating_sub(1)).rev() {
4921 strides[d] = strides[d + 1]
4922 .checked_mul(dims[d + 1])
4923 .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))?;
4924 }
4925
4926 use faer::sparse::SparseRowMat;
4927 let csrs: Vec<SparseRowMat<usize, f64>> = marginal_sparse
4928 .iter()
4929 .enumerate()
4930 .map(|(d, m)| {
4931 m.as_ref().to_row_major().map_err(|e| {
4932 BasisError::SparseCreation(format!(
4933 "tensor sparse marginal {d} CSR conversion failed: {e:?}"
4934 ))
4935 })
4936 })
4937 .collect::<Result<Vec<_>, _>>()?;
4938 let row_ptrs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().row_ptr()).collect();
4939 let col_idxs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().col_idx()).collect();
4940 let vals: Vec<&[f64]> = csrs.iter().map(|c| c.val()).collect();
4941
4942 use rayon::prelude::*;
4943 const CHUNK: usize = 1024;
4944 let num_chunks = n.div_ceil(CHUNK);
4945 let per_chunk: Vec<Vec<Triplet<usize, usize, f64>>> = (0..num_chunks)
4946 .into_par_iter()
4947 .map(|chunk_idx| {
4948 let row_start = chunk_idx * CHUNK;
4949 let row_end = (row_start + CHUNK).min(n);
4950 let mut chunk_triplets = Vec::<Triplet<usize, usize, f64>>::new();
4951 let mut cur_cols = Vec::<usize>::with_capacity(64);
4952 let mut cur_vals = Vec::<f64>::with_capacity(64);
4953 let mut next_cols = Vec::<usize>::with_capacity(64);
4954 let mut next_vals = Vec::<f64>::with_capacity(64);
4955 for i in row_start..row_end {
4956 cur_cols.clear();
4957 cur_vals.clear();
4958 cur_cols.push(0);
4959 cur_vals.push(1.0);
4960 let mut row_is_zero = false;
4961 for d in 0..dims.len() {
4962 let row_start_d = row_ptrs[d][i];
4963 let row_end_d = row_ptrs[d][i + 1];
4964 if row_start_d == row_end_d {
4965 row_is_zero = true;
4966 break;
4967 }
4968 let stride = strides[d];
4969 next_cols.clear();
4970 next_vals.clear();
4971 next_cols.reserve(cur_cols.len() * (row_end_d - row_start_d));
4972 next_vals.reserve(cur_vals.len() * (row_end_d - row_start_d));
4973 for (&prev_col, &prev_val) in cur_cols.iter().zip(cur_vals.iter()) {
4974 for ptr in row_start_d..row_end_d {
4975 let cj = col_idxs[d][ptr];
4976 let vj = vals[d][ptr];
4977 next_cols.push(prev_col + cj * stride);
4978 next_vals.push(prev_val * vj);
4979 }
4980 }
4981 std::mem::swap(&mut cur_cols, &mut next_cols);
4982 std::mem::swap(&mut cur_vals, &mut next_vals);
4983 }
4984 if row_is_zero {
4985 continue;
4986 }
4987 for (&col, &val) in cur_cols.iter().zip(cur_vals.iter()) {
4988 chunk_triplets.push(Triplet::new(i, col, val));
4989 }
4990 }
4991 chunk_triplets
4992 })
4993 .collect();
4994 let total_nnz: usize = per_chunk.iter().map(Vec::len).sum();
4995 let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(total_nnz);
4996 for chunk in per_chunk {
4997 triplets.extend(chunk);
4998 }
4999 SparseColMat::try_new_from_triplets(n, total_cols, &triplets).map_err(|e| {
5000 BasisError::SparseCreation(format!(
5001 "failed to assemble sparse tensor product design: {e:?}"
5002 ))
5003 })
5004}
5005
5006pub fn dense_local_margin_to_sparse(
5007 dense: &Array2<f64>,
5008) -> Result<SparseColMat<usize, f64>, BasisError> {
5009 let expected_row_nnz = dense.ncols().min(4);
5010 let mut triplets =
5011 Vec::<Triplet<usize, usize, f64>>::with_capacity(dense.nrows() * expected_row_nnz);
5012 for ((row, col), &value) in dense.indexed_iter() {
5013 if value != 0.0 {
5014 triplets.push(Triplet::new(row, col, value));
5015 }
5016 }
5017 SparseColMat::try_new_from_triplets(dense.nrows(), dense.ncols(), &triplets).map_err(|e| {
5018 BasisError::SparseCreation(format!(
5019 "failed to convert tensor marginal design to sparse form: {e:?}"
5020 ))
5021 })
5022}
5023
/// Pair of projector matrices for one tensor margin, as produced by
/// `tensor_margin_range_null_projectors`.
pub struct TensorMarginRangeNullProjectors {
    /// Projector built from the eigenvectors whose eigenvalue exceeds the
    /// penalty analysis tolerance.
    range: Array2<f64>,
    /// Projector built from the remaining eigenvectors (eigenvalue at or below
    /// the tolerance).
    null: Array2<f64>,
}
5028
5029pub fn projector_from_columns(columns: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
5030 if indices.is_empty() {
5031 return Array2::<f64>::zeros((columns.nrows(), columns.nrows()));
5032 }
5033 let basis = columns.select(Axis(1), indices);
5034 basis.dot(&basis.t())
5035}
5036
5037pub fn tensor_margin_range_null_projectors(
5038 normalized_marginal_penalties: &[(Array2<f64>, f64)],
5039) -> Result<Vec<TensorMarginRangeNullProjectors>, BasisError> {
5040 normalized_marginal_penalties
5041 .iter()
5042 .enumerate()
5043 .map(|(dim, (penalty, _))| {
5044 let analysis = crate::basis::analyze_penalty_block(penalty)?;
5045 if analysis.rank == 0 {
5046 crate::bail_invalid_basis!(
5047 "t2 separable tensor penalty margin {dim} has rank-zero penalty; \
5048 cannot split penalized and null subspaces"
5049 );
5050 }
5051 let mut range_idx = Vec::<usize>::new();
5052 let mut null_idx = Vec::<usize>::new();
5053 for (idx, &ev) in analysis.eigenvalues.iter().enumerate() {
5054 if ev > analysis.tol {
5055 range_idx.push(idx);
5056 } else {
5057 null_idx.push(idx);
5058 }
5059 }
5060 Ok(TensorMarginRangeNullProjectors {
5061 range: projector_from_columns(&analysis.eigenvectors, &range_idx),
5062 null: projector_from_columns(&analysis.eigenvectors, &null_idx),
5063 })
5064 })
5065 .collect()
5066}
5067
5068pub fn build_tensor_bspline_basis(
5069 data: ArrayView2<'_, f64>,
5070 feature_cols: &[usize],
5071 spec: &TensorBSplineSpec,
5072) -> Result<BasisBuildResult, BasisError> {
5073 if feature_cols.is_empty() {
5074 crate::bail_invalid_basis!("TensorBSpline requires at least one feature column");
5075 }
5076 if feature_cols.len() != spec.marginalspecs.len() {
5077 crate::bail_dim_basis!(
5078 "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
5079 feature_cols.len(),
5080 spec.marginalspecs.len()
5081 );
5082 }
5083 if !spec.periods.is_empty() && spec.periods.len() != feature_cols.len() {
5084 crate::bail_dim_basis!(
5085 "TensorBSpline periods length {} does not match feature count {}",
5086 spec.periods.len(),
5087 feature_cols.len()
5088 );
5089 }
5090 let p = data.ncols();
5091 for &c in feature_cols {
5092 if c >= p {
5093 crate::bail_dim_basis!(
5094 "tensor feature column {c} is out of bounds for data with {p} columns"
5095 );
5096 }
5097 }
5098
5099 let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
5100 let mut marginal_is_cr_flags = Vec::<bool>::with_capacity(feature_cols.len());
5103 let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
5104 let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
5105 let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5106 let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5107 let mut marginal_effective_periods = Vec::<Option<f64>>::with_capacity(feature_cols.len());
5115 let mut marginal_sparse =
5123 Vec::<Option<SparseColMat<usize, f64>>>::with_capacity(feature_cols.len());
5124
5125 for (dim, (&col, marginalspec)) in feature_cols
5128 .iter()
5129 .zip(spec.marginalspecs.iter())
5130 .enumerate()
5131 {
5132 let mut marginal_unconstrained = marginalspec.clone();
5137 marginal_unconstrained.identifiability = BSplineIdentifiability::None;
5138 let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
5139 let (knots, marginal_is_cr) = match built.metadata {
5144 BasisMetadata::BSpline1D { knots, .. } => (knots, false),
5145 BasisMetadata::CubicRegression1D { knots, .. } => (knots, true),
5146 _ => {
5147 crate::bail_invalid_basis!(
5148 "internal TensorBSpline error at dim {dim}: expected BSpline1D or CubicRegression1D metadata"
5149 );
5150 }
5151 };
5152 let metadata_knots = match marginalspec.knotspec {
5153 BSplineKnotSpec::PeriodicUniform {
5154 data_range,
5155 num_basis,
5156 } => Array1::linspace(data_range.0, data_range.1, num_basis),
5157 _ => knots,
5158 };
5159 marginal_knots.push(metadata_knots);
5160 marginal_is_cr_flags.push(marginal_is_cr);
5161 marginal_degrees.push(marginalspec.degree);
5162 marginalnum_basis.push(built.design.ncols());
5163 let dense_marginal = built.design.to_dense();
5168 let sparse_view: Option<SparseColMat<usize, f64>> = match built.design.as_sparse() {
5169 Some(sd) => {
5170 let inner: &SparseColMat<usize, f64> = sd;
5171 Some(inner.clone())
5172 }
5173 None => match marginalspec.knotspec {
5174 BSplineKnotSpec::PeriodicUniform { .. } => {
5175 Some(dense_local_margin_to_sparse(&dense_marginal)?)
5176 }
5177 _ => None,
5178 },
5179 };
5180 marginal_sparse.push(sparse_view);
5181 marginal_designs.push(dense_marginal);
5182 marginal_penalties.push(
5183 built
5184 .penalties
5185 .first()
5186 .ok_or_else(|| {
5187 BasisError::InvalidInput(format!(
5188 "internal TensorBSpline error at dim {dim}: missing marginal penalty"
5189 ))
5190 })?
5191 .clone(),
5192 );
5193 built.nullspace_dims.first().ok_or_else(|| {
5194 BasisError::InvalidInput(format!(
5195 "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
5196 ))
5197 })?;
5198 let implied_period = match marginalspec.knotspec {
5206 BSplineKnotSpec::PeriodicUniform { data_range, .. } => {
5207 Some(data_range.1 - data_range.0)
5208 }
5209 _ => spec.periods.get(dim).and_then(|p| *p),
5210 };
5211 marginal_effective_periods.push(implied_period);
5212 }
5213
5214 let total_cols: usize = marginalnum_basis.iter().product();
5215 let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
5216 .then(|| tensor_product_design_from_marginals(&marginal_designs))
5217 .transpose()?;
5218 let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
5219 match spec.penalty_decomposition {
5220 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => marginal_penalties.len(),
5221 TensorBSplinePenaltyDecomposition::Separable => marginal_penalties.len() * 2,
5222 } + if spec.double_penalty { 1 } else { 0 },
5223 );
5224
5225 let normalized_marginal_penalties: Vec<(Array2<f64>, f64)> = marginal_penalties
5233 .iter()
5234 .map(normalize_penalty_in_constrained_space)
5235 .collect();
5236 let mut kronecker_marginal_penalties =
5237 Vec::<Array2<f64>>::with_capacity(normalized_marginal_penalties.len());
5238
5239 match spec.penalty_decomposition {
5240 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => {
5241 let mut marginal_kron_sum = Array2::<f64>::zeros((total_cols, total_cols));
5247
5248 for dim in 0..normalized_marginal_penalties.len() {
5249 let mut s_dim = Array2::<f64>::eye(1);
5250 let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
5251 for (j, &qj) in marginalnum_basis.iter().enumerate() {
5252 let factor = if j == dim {
5253 normalized_marginal_penalties[j].0.clone()
5254 } else {
5255 Array2::<f64>::eye(qj)
5256 };
5257 factors.push(factor.clone());
5258 s_dim = kronecker_product(&s_dim, &factor);
5259 }
5260 if dim == kronecker_marginal_penalties.len() {
5261 kronecker_marginal_penalties.push(normalized_marginal_penalties[dim].0.clone());
5262 }
5263 marginal_kron_sum += &s_dim;
5264
5265 candidates.push(PenaltyCandidate {
5266 matrix: s_dim,
5267 nullspace_dim_hint: 0,
5268 source: PenaltySource::TensorMarginal { dim },
5269 normalization_scale: normalized_marginal_penalties[dim].1,
5270 kronecker_factors: Some(factors),
5271 op: None,
5272 });
5273 }
5274
5275 if spec.double_penalty
5276 && let Some(shrink) =
5277 crate::basis::build_nullspace_shrinkage_penalty(&marginal_kron_sum)?
5278 {
5279 let (matrix, normalization_scale) =
5280 normalize_penalty_in_constrained_space(&shrink.sym_penalty);
5281 candidates.push(PenaltyCandidate {
5282 matrix,
5283 nullspace_dim_hint: 0,
5284 source: PenaltySource::TensorGlobalRidge,
5285 normalization_scale,
5286 kronecker_factors: None,
5287 op: None,
5288 });
5289 }
5290 }
5291 TensorBSplinePenaltyDecomposition::Separable => {
5292 let projectors = tensor_margin_range_null_projectors(&normalized_marginal_penalties)?;
5293 let n_masks = 1usize.checked_shl(projectors.len() as u32).ok_or_else(|| {
5294 BasisError::InvalidInput(format!(
5295 "t2 separable tensor penalty supports at most {} margins, got {}",
5296 usize::BITS - 1,
5297 projectors.len()
5298 ))
5299 })?;
5300 for mask in 1..n_masks {
5301 let mut matrix = Array2::<f64>::eye(1);
5302 let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5303 let mut penalized_margins = Vec::<usize>::new();
5304 for (dim, projector) in projectors.iter().enumerate() {
5305 let use_range = ((mask >> dim) & 1) == 1;
5306 let factor = if use_range {
5307 penalized_margins.push(dim);
5308 projector.range.clone()
5309 } else {
5310 projector.null.clone()
5311 };
5312 matrix = kronecker_product(&matrix, &factor);
5313 factors.push(factor);
5314 }
5315 let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5316 candidates.push(PenaltyCandidate {
5317 matrix,
5318 nullspace_dim_hint: 0,
5319 source: PenaltySource::TensorSeparable { penalized_margins },
5320 normalization_scale,
5321 kronecker_factors: Some(factors),
5322 op: None,
5323 });
5324 }
5325
5326 if spec.double_penalty {
5327 let mut matrix = Array2::<f64>::eye(1);
5328 let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5329 for projector in &projectors {
5330 matrix = kronecker_product(&matrix, &projector.null);
5331 factors.push(projector.null.clone());
5332 }
5333 let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5334 candidates.push(PenaltyCandidate {
5335 matrix,
5336 nullspace_dim_hint: 0,
5337 source: PenaltySource::TensorGlobalRidge,
5338 normalization_scale,
5339 kronecker_factors: Some(factors),
5340 op: None,
5341 });
5342 }
5343 }
5344 }
5345
5346 let z_opt = match &spec.identifiability {
5347 TensorBSplineIdentifiability::None => None,
5348 TensorBSplineIdentifiability::SumToZero => {
5349 if total_cols < 2 {
5350 crate::bail_invalid_basis!(
5351 "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability"
5352 );
5353 }
5354 let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
5355 BasisError::InvalidInput(
5356 "tensor sum-to-zero identifiability requires a realized basis".to_string(),
5357 )
5358 })?;
5359 let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
5360 let gauge = gam_problem::Gauge::sum_to_zero(z);
5361 Some(gauge.block_transform(0))
5362 }
5363 TensorBSplineIdentifiability::MarginalSumToZero => {
5364 if marginal_designs.len() < 2 {
5375 crate::bail_invalid_basis!(
5376 "tensor interaction (ti) identifiability requires at least 2 margins"
5377 );
5378 }
5379 let mut z = Array2::<f64>::eye(1);
5380 for (dim, marginal) in marginal_designs.iter().enumerate() {
5381 if marginal.ncols() < 2 {
5382 crate::bail_invalid_basis!(
5383 "tensor interaction (ti) margin {dim} has fewer than 2 basis functions; \
5384 cannot remove its marginal main effect"
5385 );
5386 }
5387 let (_, z_dim) = apply_sum_to_zero_constraint(marginal.view(), None)?;
5388 let gauge_dim = gam_problem::Gauge::sum_to_zero(z_dim);
5389 let z_dim = gauge_dim.block_transform(0);
5390 z = kronecker_product(&z, &z_dim);
5391 }
5392 Some(z)
5393 }
5394 TensorBSplineIdentifiability::FrozenTransform { transform } => {
5395 if transform.nrows() != total_cols {
5396 crate::bail_dim_basis!(
5397 "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
5398 total_cols,
5399 transform.nrows()
5400 );
5401 }
5402 Some(transform.clone())
5403 }
5404 };
5405
5406 if let Some(z) = z_opt.as_ref() {
5407 let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
5408 let dense = dense_design.as_mut().ok_or_else(|| {
5409 BasisError::InvalidInput(
5410 "tensor identifiability transform requires a realized basis".to_string(),
5411 )
5412 })?;
5413 let restricted_design = gauge.restrict_design(dense);
5414 *dense = restricted_design;
5415 candidates = candidates
5416 .into_iter()
5417 .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
5418 let matrix = gauge.restrict_penalty(&candidate.matrix);
5419 let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
5427 Ok(PenaltyCandidate {
5428 nullspace_dim_hint: candidate.nullspace_dim_hint,
5429 matrix,
5430 source: candidate.source,
5431 normalization_scale: candidate.normalization_scale * c_new,
5432 kronecker_factors: None,
5438 op: candidate.op.clone(),
5439 })
5440 })
5441 .collect::<Result<Vec<_>, _>>()?;
5442 }
5443
5444 let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
5445 filter_active_penalty_candidates_with_ops(candidates)?;
5446 let identifiability_is_none =
5447 matches!(spec.identifiability, TensorBSplineIdentifiability::None);
5448 let all_marginals_sparse = marginal_sparse.iter().all(Option::is_some);
5456 let design = if let Some(dense_design) = dense_design {
5457 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense_design))
5458 } else if identifiability_is_none && all_marginals_sparse {
5459 let sparse_marginals: Vec<&SparseColMat<usize, f64>> = marginal_sparse
5465 .iter()
5466 .map(|m| m.as_ref().expect("all_marginals_sparse just verified"))
5467 .collect();
5468 let sparse_design = tensor_product_design_from_sparse_marginals(&sparse_marginals)?;
5469 DesignMatrix::Sparse(gam_linalg::matrix::SparseDesignMatrix::new(sparse_design))
5470 } else {
5471 let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
5472 .iter()
5473 .map(|m| Arc::new(m.clone()))
5474 .collect();
5475 let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
5476 BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
5477 })?;
5478 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
5479 };
5480
5481 Ok(BasisBuildResult {
5482 design,
5483 penalties,
5484 nullspace_dims,
5485 penaltyinfo,
5486 ops,
5487 null_eigenvectors,
5488 joint_null_rotation: None,
5489 metadata: BasisMetadata::TensorBSpline {
5490 feature_cols: feature_cols.to_vec(),
5491 knots: marginal_knots,
5492 degrees: marginal_degrees,
5493 periods: marginal_effective_periods,
5500 is_cr: marginal_is_cr_flags,
5501 identifiability_transform: z_opt,
5502 },
5503 kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None)
5504 && matches!(
5505 spec.penalty_decomposition,
5506 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
5507 ) {
5508 Some(KroneckerFactoredBasis::new(
5509 marginal_designs,
5510 kronecker_marginal_penalties,
5511 marginalnum_basis.clone(),
5512 spec.double_penalty,
5513 ))
5514 } else {
5515 None
5516 },
5517 })
5518}
5519
5520pub fn tensor_product_design_from_marginals(
5521 marginal_designs: &[Array2<f64>],
5522) -> Result<Array2<f64>, BasisError> {
5523 if marginal_designs.is_empty() {
5524 crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5525 }
5526 let n = marginal_designs[0].nrows();
5527 for (i, b) in marginal_designs.iter().enumerate().skip(1) {
5528 if b.nrows() != n {
5529 crate::bail_dim_basis!(
5530 "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
5531 b.nrows()
5532 );
5533 }
5534 }
5535 let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
5536 acc.checked_mul(b.ncols())
5537 .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5538 })?;
5539 use ndarray::parallel::prelude::*;
5545 use rayon::iter::{IntoParallelIterator, ParallelIterator};
5546 let mut design = Array2::<f64>::zeros((n, total_cols));
5547 design
5548 .axis_chunks_iter_mut(ndarray::Axis(0), 1024)
5549 .into_par_iter()
5550 .enumerate()
5551 .for_each(|(chunk_idx, mut block)| {
5552 let row_offset = chunk_idx * 1024;
5553 let mut cur = Vec::<f64>::with_capacity(total_cols);
5555 let mut next = Vec::<f64>::with_capacity(total_cols);
5556 for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
5557 let i = row_offset + local_i;
5558 cur.clear();
5559 cur.push(1.0);
5560 for b in marginal_designs {
5561 let q = b.ncols();
5562 next.clear();
5563 next.resize(cur.len() * q, 0.0);
5564 let b_row = b.row(i);
5568 let b_slice = b_row
5569 .as_slice()
5570 .expect("Array2 row from outer_iter is contiguous");
5571 for (a_idx, &aval) in cur.iter().enumerate() {
5572 let off = a_idx * q;
5573 let dst = &mut next[off..off + q];
5574 for col in 0..q {
5575 dst[col] = aval * b_slice[col];
5576 }
5577 }
5578 std::mem::swap(&mut cur, &mut next);
5579 }
5580 let out_slice = out_row
5585 .as_slice_mut()
5586 .expect("design row is contiguous in C-major Array2");
5587 out_slice.copy_from_slice(&cur);
5588 }
5589 });
5590 Ok(design)
5591}
5592
5593pub fn build_random_effect_block(
5594 data: ArrayView2<'_, f64>,
5595 spec: &RandomEffectTermSpec,
5596) -> Result<RandomEffectBlock, BasisError> {
5597 let n = data.nrows();
5598 let p = data.ncols();
5599 if spec.feature_col >= p {
5600 crate::bail_dim_basis!(
5601 "random-effect term '{}' feature column {} out of bounds for {} columns",
5602 spec.name,
5603 spec.feature_col,
5604 p
5605 );
5606 }
5607
5608 let col = data.column(spec.feature_col);
5609 if col.iter().any(|v| !v.is_finite()) {
5610 crate::bail_invalid_basis!(
5611 "random-effect term '{}' contains non-finite group values",
5612 spec.name
5613 );
5614 }
5615
5616 let kept_levels: Vec<u64> = if let Some(levels) = spec.frozen_levels.as_ref() {
5617 if levels.is_empty() {
5618 crate::bail_invalid_basis!(
5619 "random-effect term '{}' has empty frozen_levels",
5620 spec.name
5621 );
5622 }
5623 levels.clone()
5624 } else {
5625 let mut levels_set = BTreeSet::<u64>::new();
5626 for &v in col {
5627 levels_set.insert(v.to_bits());
5628 }
5629 if levels_set.is_empty() {
5630 crate::bail_invalid_basis!("random-effect term '{}' has no observed levels", spec.name);
5631 }
5632 let levels: Vec<u64> = levels_set.into_iter().collect();
5633 let start_idx = if spec.drop_first_level && levels.len() > 1 {
5634 1usize
5635 } else {
5636 0usize
5637 };
5638 levels[start_idx..].to_vec()
5639 };
5640
5641 if kept_levels.is_empty() {
5642 crate::bail_invalid_basis!(
5643 "random-effect term '{}' drops all levels; keep at least one level",
5644 spec.name
5645 );
5646 }
5647
5648 let q = kept_levels.len();
5649 let mut level_to_col = BTreeMap::<u64, usize>::new();
5650 for (idx, &bits) in kept_levels.iter().enumerate() {
5651 if level_to_col.insert(bits, idx).is_some() {
5652 crate::bail_invalid_basis!(
5653 "random-effect term '{}' has duplicate frozen level bits {bits}",
5654 spec.name
5655 );
5656 }
5657 }
5658 let mut group_ids = Vec::with_capacity(n);
5659 for &v in col {
5660 let bits = v.to_bits();
5661 group_ids.push(level_to_col.get(&bits).copied());
5662 }
5663
5664 Ok(RandomEffectBlock {
5665 name: spec.name.clone(),
5666 group_ids,
5667 num_groups: q,
5668 kept_levels,
5669 })
5670}
5671
5672impl SmoothDesign {
5673 pub fn map_term_coefficients(
5676 unconstrained: &Array1<f64>,
5677 shape: ShapeConstraint,
5678 ) -> Result<Array1<f64>, BasisError> {
5679 if unconstrained.is_empty() {
5680 crate::bail_invalid_basis!("unconstrained coefficient vector cannot be empty");
5681 }
5682 let mapped = match shape {
5683 ShapeConstraint::None => unconstrained.clone(),
5684 ShapeConstraint::MonotoneIncreasing => cumulative_exp(unconstrained, 1.0),
5685 ShapeConstraint::MonotoneDecreasing => cumulative_exp(unconstrained, -1.0),
5686 ShapeConstraint::Convex => second_cumulative_exp(unconstrained, 1.0),
5687 ShapeConstraint::Concave => second_cumulative_exp(unconstrained, -1.0),
5688 };
5689 Ok(mapped)
5690 }
5691}
5692
/// Result of building one smooth term in its own (local) coefficient space.
///
/// The per-penalty vectors (`penalties`, `ops`, `nullspaces`, `null_eigenvectors`,
/// `penaltyinfo`) are filled in lockstep by the builders in this file, one entry per penalty.
pub struct LocalSmoothTermBuild {
    /// Number of coefficient columns of `design`.
    pub dim: usize,
    /// Design matrix of the term.
    pub design: DesignMatrix,
    /// Penalty matrices, one per active penalty.
    pub penalties: Vec<Array2<f64>>,
    /// Optional operator-form penalty per entry of `penalties` (`None` when only the dense
    /// matrix is available).
    pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
    /// One value per penalty — presumably the null-space dimension of that penalty; confirm
    /// against the producer.
    pub nullspaces: Vec<usize>,
    /// Optional matrix per penalty — presumably a basis of that penalty's null space; confirm
    /// against the producer.
    pub null_eigenvectors: Vec<Option<Array2<f64>>>,
    /// Optional joint null-space rotation; the by-factor builder in this file leaves it `None`.
    pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Descriptor of each entry of `penalties`.
    pub penaltyinfo: Vec<PenaltyInfo>,
    /// Descriptors that are carried through unchanged by the by-factor builder — presumably
    /// for penalties dropped before activation; TODO confirm.
    pub pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
    /// Metadata describing how the basis was constructed.
    pub metadata: BasisMetadata,
    /// Optional linear inequality constraints on the coefficients.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// NOTE(review): presumably flags a box reparameterization of shape-constrained terms —
    /// confirm with the consumer.
    pub box_reparam: bool,
    /// Optional Kronecker-factored form of the basis; cleared by
    /// `apply_by_variable_to_local_build` once design rows are rescaled.
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
}
5721
/// Design operator backed by a memory-mapped `.npy` file of PCA scores.
///
/// Elements are read on demand as little-endian `f64` values laid out row-major after the
/// header, so the score matrix is never copied into an owned array unless `to_dense` or
/// `row_chunk_into` is called.
#[derive(Clone)]
pub struct PcaScoresMemmapDesignOperator {
    /// Read-only mapping of the whole `.npy` file, shared between clones.
    mmap: Arc<memmap2::Mmap>,
    /// Byte offset of the first matrix element, i.e. the end of the `.npy` header.
    data_offset: usize,
    /// Number of rows declared by the header shape.
    nrows: usize,
    /// Number of columns declared by the header shape.
    ncols: usize,
    /// Requested rows per processing chunk; always at least 1.
    chunk_size: usize,
}
5730
5731impl PcaScoresMemmapDesignOperator {
5732 fn open(path: PathBuf, chunk_size: usize) -> Result<Self, BasisError> {
5733 let file = File::open(&path).map_err(|err| {
5734 BasisError::InvalidInput(format!(
5735 "failed to open lazy Pca .npy scores '{}': {err}",
5736 path.display()
5737 ))
5738 })?;
5739 let mmap = unsafe {
5745 memmap2::Mmap::map(&file).map_err(|err| {
5746 BasisError::InvalidInput(format!(
5747 "failed to memmap lazy Pca .npy scores '{}': {err}",
5748 path.display()
5749 ))
5750 })?
5751 };
5752 let (data_offset, nrows, ncols) = parse_f64_2d_npy_header(&mmap, &path)?;
5753 let expected = data_offset
5754 .checked_add(nrows.saturating_mul(ncols).saturating_mul(8))
5755 .ok_or_else(|| {
5756 BasisError::InvalidInput(format!(
5757 "lazy Pca .npy scores '{}' shape is too large",
5758 path.display()
5759 ))
5760 })?;
5761 if mmap.len() < expected {
5762 crate::bail_invalid_basis!(
5763 "lazy Pca .npy scores '{}' is truncated: header expects {} bytes, file has {}",
5764 path.display(),
5765 expected,
5766 mmap.len()
5767 );
5768 }
5769 Ok(Self {
5770 mmap: Arc::new(mmap),
5771 data_offset,
5772 nrows,
5773 ncols,
5774 chunk_size: chunk_size.max(1),
5775 })
5776 }
5777
5778 fn value(&self, row: usize, col: usize) -> f64 {
5779 let offset = self.data_offset + (row * self.ncols + col) * 8;
5780 let mut bytes = [0_u8; 8];
5781 bytes.copy_from_slice(&self.mmap[offset..offset + 8]);
5782 f64::from_le_bytes(bytes)
5783 }
5784
5785 fn chunk_rows(&self) -> usize {
5786 self.chunk_size.min(self.nrows.max(1))
5787 }
5788}
5789
5790impl LinearOperator for PcaScoresMemmapDesignOperator {
5791 fn nrows(&self) -> usize {
5792 self.nrows
5793 }
5794
5795 fn ncols(&self) -> usize {
5796 self.ncols
5797 }
5798
5799 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5800 assert_eq!(
5801 vector.len(),
5802 self.ncols,
5803 "lazy Pca apply vector length mismatch"
5804 );
5805 let mut out = Array1::<f64>::zeros(self.nrows);
5806 for start in (0..self.nrows).step_by(self.chunk_rows()) {
5807 let end = (start + self.chunk_rows()).min(self.nrows);
5808 for row in start..end {
5809 let mut acc = 0.0;
5810 for col in 0..self.ncols {
5811 acc += self.value(row, col) * vector[col];
5812 }
5813 out[row] = acc;
5814 }
5815 }
5816 out
5817 }
5818
5819 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5820 assert_eq!(
5821 vector.len(),
5822 self.nrows,
5823 "lazy Pca apply_transpose vector length mismatch"
5824 );
5825 let mut out = Array1::<f64>::zeros(self.ncols);
5826 for start in (0..self.nrows).step_by(self.chunk_rows()) {
5827 let end = (start + self.chunk_rows()).min(self.nrows);
5828 for row in start..end {
5829 let scale = vector[row];
5830 if scale == 0.0 {
5831 continue;
5832 }
5833 for col in 0..self.ncols {
5834 out[col] += scale * self.value(row, col);
5835 }
5836 }
5837 }
5838 out
5839 }
5840
5841 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5842 if weights.len() != self.nrows {
5843 return Err(format!(
5844 "lazy Pca diag_xtw_x weight length mismatch: weights={}, nrows={}",
5845 weights.len(),
5846 self.nrows
5847 ));
5848 }
5849 let mut gram = Array2::<f64>::zeros((self.ncols, self.ncols));
5850 for start in (0..self.nrows).step_by(self.chunk_rows()) {
5851 let end = (start + self.chunk_rows()).min(self.nrows);
5852 for row in start..end {
5853 let w = weights[row];
5854 if w == 0.0 {
5855 continue;
5856 }
5857 for a in 0..self.ncols {
5858 let xa = self.value(row, a);
5859 if xa == 0.0 {
5860 continue;
5861 }
5862 for b in a..self.ncols {
5863 gram[[a, b]] += w * xa * self.value(row, b);
5864 }
5865 }
5866 }
5867 }
5868 for a in 0..self.ncols {
5869 for b in 0..a {
5870 gram[[a, b]] = gram[[b, a]];
5871 }
5872 }
5873 Ok(gram)
5874 }
5875
5876 fn apply_weighted_normal(
5877 &self,
5878 weights: &Array1<f64>,
5879 vector: &Array1<f64>,
5880 penalty: Option<&Array2<f64>>,
5881 ridge: f64,
5882 ) -> Array1<f64> {
5883 assert_eq!(
5884 weights.len(),
5885 self.nrows,
5886 "lazy Pca weighted-normal weight mismatch"
5887 );
5888 assert_eq!(
5889 vector.len(),
5890 self.ncols,
5891 "lazy Pca weighted-normal vector mismatch"
5892 );
5893 let mut out = Array1::<f64>::zeros(self.ncols);
5894 for start in (0..self.nrows).step_by(self.chunk_rows()) {
5895 let end = (start + self.chunk_rows()).min(self.nrows);
5896 for row in start..end {
5897 let w = weights[row].max(0.0);
5898 if w == 0.0 {
5899 continue;
5900 }
5901 let mut row_dot = 0.0;
5902 for col in 0..self.ncols {
5903 row_dot += self.value(row, col) * vector[col];
5904 }
5905 if row_dot == 0.0 {
5906 continue;
5907 }
5908 let scaled = w * row_dot;
5909 for col in 0..self.ncols {
5910 out[col] += scaled * self.value(row, col);
5911 }
5912 }
5913 }
5914 if let Some(pen) = penalty {
5915 out += &pen.dot(vector);
5916 }
5917 if ridge > 0.0 {
5918 out += &vector.mapv(|x| ridge * x);
5919 }
5920 out
5921 }
5922}
5923
5924impl DenseDesignOperator for PcaScoresMemmapDesignOperator {
5925 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
5926 if weights.len() != self.nrows || y.len() != self.nrows {
5927 return Err(format!(
5928 "lazy Pca compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
5929 weights.len(),
5930 y.len(),
5931 self.nrows
5932 ));
5933 }
5934 let mut out = Array1::<f64>::zeros(self.ncols);
5935 for start in (0..self.nrows).step_by(self.chunk_rows()) {
5936 let end = (start + self.chunk_rows()).min(self.nrows);
5937 for row in start..end {
5938 let scale = weights[row] * y[row];
5939 if scale == 0.0 {
5940 continue;
5941 }
5942 for col in 0..self.ncols {
5943 out[col] += scale * self.value(row, col);
5944 }
5945 }
5946 }
5947 Ok(out)
5948 }
5949
5950 fn row_chunk_into(
5951 &self,
5952 rows: Range<usize>,
5953 mut out: ArrayViewMut2<'_, f64>,
5954 ) -> Result<(), MatrixMaterializationError> {
5955 if rows.end > self.nrows || rows.start > rows.end {
5956 return Err(MatrixMaterializationError::MissingRowChunk {
5957 context: "lazy Pca row range out of bounds",
5958 });
5959 }
5960 if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols {
5961 return Err(MatrixMaterializationError::MissingRowChunk {
5962 context: "lazy Pca row_chunk_into shape mismatch",
5963 });
5964 }
5965 for (local, row) in (rows.start..rows.end).enumerate() {
5966 for col in 0..self.ncols {
5967 out[[local, col]] = self.value(row, col);
5968 }
5969 }
5970 Ok(())
5971 }
5972
5973 fn to_dense(&self) -> Array2<f64> {
5974 let mut out = Array2::<f64>::zeros((self.nrows, self.ncols));
5975 self.row_chunk_into(0..self.nrows, out.view_mut())
5976 .expect("lazy Pca full materialization failed");
5977 out
5978 }
5979}
5980
5981pub fn parse_f64_2d_npy_header(
5982 bytes: &[u8],
5983 path: &PathBuf,
5984) -> Result<(usize, usize, usize), BasisError> {
5985 if bytes.len() < 10 || &bytes[0..6] != b"\x93NUMPY" {
5986 crate::bail_invalid_basis!("lazy Pca scores '{}' is not a .npy file", path.display());
5987 }
5988 let major = bytes[6];
5989 let header_len = match major {
5990 1 => u16::from_le_bytes([bytes[8], bytes[9]]) as usize,
5991 2 | 3 => {
5992 if bytes.len() < 12 {
5993 crate::bail_invalid_basis!(
5994 "lazy Pca scores '{}' has a truncated .npy header",
5995 path.display()
5996 );
5997 }
5998 u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize
5999 }
6000 other => {
6001 crate::bail_invalid_basis!(
6002 "lazy Pca scores '{}' uses unsupported .npy version {}",
6003 path.display(),
6004 other
6005 );
6006 }
6007 };
6008 let header_start = if major == 1 { 10 } else { 12 };
6009 let data_offset = header_start + header_len;
6010 if bytes.len() < data_offset {
6011 crate::bail_invalid_basis!(
6012 "lazy Pca scores '{}' has a truncated .npy header",
6013 path.display()
6014 );
6015 }
6016 let header = std::str::from_utf8(&bytes[header_start..data_offset]).map_err(|err| {
6017 BasisError::InvalidInput(format!(
6018 "lazy Pca scores '{}' has a non-UTF8 .npy header: {err}",
6019 path.display()
6020 ))
6021 })?;
6022 if !(header.contains("'descr': '<f8'")
6023 || header.contains("\"descr\": \"<f8\"")
6024 || header.contains("'descr': '|f8'")
6025 || header.contains("\"descr\": \"|f8\""))
6026 {
6027 crate::bail_invalid_basis!(
6028 "lazy Pca scores '{}' must be float64 little-endian .npy",
6029 path.display()
6030 );
6031 }
6032 if header.contains("True") {
6033 crate::bail_invalid_basis!(
6034 "lazy Pca scores '{}' must be C-contiguous, not Fortran-ordered",
6035 path.display()
6036 );
6037 }
6038 let shape_pos = header.find("shape").ok_or_else(|| {
6039 BasisError::InvalidInput(format!(
6040 "lazy Pca scores '{}' .npy header is missing shape",
6041 path.display()
6042 ))
6043 })?;
6044 let open = header[shape_pos..].find('(').ok_or_else(|| {
6045 BasisError::InvalidInput(format!(
6046 "lazy Pca scores '{}' .npy header has malformed shape",
6047 path.display()
6048 ))
6049 })? + shape_pos;
6050 let close = header[open..].find(')').ok_or_else(|| {
6051 BasisError::InvalidInput(format!(
6052 "lazy Pca scores '{}' .npy header has malformed shape",
6053 path.display()
6054 ))
6055 })? + open;
6056 let dims = header[open + 1..close]
6057 .split(',')
6058 .map(str::trim)
6059 .filter(|part| !part.is_empty())
6060 .map(|part| part.parse::<usize>())
6061 .collect::<Result<Vec<_>, _>>()
6062 .map_err(|err| {
6063 BasisError::InvalidInput(format!(
6064 "lazy Pca scores '{}' .npy shape is not integral: {err}",
6065 path.display()
6066 ))
6067 })?;
6068 if dims.len() != 2 {
6069 crate::bail_invalid_basis!(
6070 "lazy Pca scores '{}' must have shape (N, K), got {:?}",
6071 path.display(),
6072 dims
6073 );
6074 }
6075 Ok((data_offset, dims[0], dims[1]))
6076}
6077
6078pub fn pca_center_mean(x: ArrayView2<'_, f64>) -> Result<Array1<f64>, BasisError> {
6079 if x.nrows() == 0 {
6080 crate::bail_invalid_basis!("Pca basis requires at least one row to compute center mean");
6081 }
6082 let mut mean = Array1::<f64>::zeros(x.ncols());
6083 for row in x.rows() {
6084 mean += &row;
6085 }
6086 mean.mapv_inplace(|v| v / x.nrows() as f64);
6087 Ok(mean)
6088}
6089
6090pub fn build_pca_smooth_basis(
6091 data: ArrayView2<'_, f64>,
6092 feature_cols: &[usize],
6093 basis_matrix: &Array2<f64>,
6094 centered: bool,
6095 smooth_penalty: f64,
6096 center_mean: Option<&Array1<f64>>,
6097 pca_basis_path: Option<&PathBuf>,
6098 chunk_size: usize,
6099) -> Result<BasisBuildResult, BasisError> {
6100 if let Some(path) = pca_basis_path {
6101 let op = PcaScoresMemmapDesignOperator::open(path.clone(), chunk_size)?;
6102 if op.nrows != data.nrows() {
6103 crate::bail_dim_basis!(
6104 "lazy Pca scores row mismatch: .npy has {}, data has {}",
6105 op.nrows,
6106 data.nrows()
6107 );
6108 }
6109 let k = op.ncols;
6110 let mut penalty = Array2::<f64>::eye(k);
6111 penalty.mapv_inplace(|v| v * smooth_penalty);
6112 let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6113 filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6114 matrix: penalty,
6115 nullspace_dim_hint: 0,
6116 source: PenaltySource::Other("PcaRidge".to_string()),
6117 normalization_scale: 1.0,
6118 kronecker_factors: None,
6119 op: None,
6120 }])?;
6121 return Ok(BasisBuildResult {
6122 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op))),
6123 penalties,
6124 nullspace_dims,
6125 penaltyinfo,
6126 ops,
6127 null_eigenvectors,
6128 joint_null_rotation: None,
6129 metadata: BasisMetadata::Pca {
6130 feature_cols: feature_cols.to_vec(),
6131 basis_matrix: basis_matrix.clone(),
6132 centered,
6133 smooth_penalty,
6134 center_mean: center_mean.cloned(),
6135 pca_basis_path: Some(path.clone()),
6136 chunk_size: chunk_size.max(1),
6137 },
6138 kronecker_factored: None,
6139 });
6140 }
6141 if basis_matrix.nrows() != feature_cols.len() {
6142 crate::bail_dim_basis!(
6143 "Pca basis row mismatch: basis rows={}, feature columns={}",
6144 basis_matrix.nrows(),
6145 feature_cols.len()
6146 );
6147 }
6148 let mut x = select_columns(data, feature_cols)?;
6149 let mean = if centered {
6150 match center_mean {
6151 Some(mean) => mean.clone(),
6152 None => pca_center_mean(x.view())?,
6153 }
6154 } else {
6155 Array1::<f64>::zeros(feature_cols.len())
6156 };
6157 if centered {
6158 for mut row in x.rows_mut() {
6159 row -= &mean;
6160 }
6161 }
6162 let design = fast_ab(&x, basis_matrix);
6163 let k = basis_matrix.ncols();
6164 let mut penalty = Array2::<f64>::eye(k);
6165 penalty.mapv_inplace(|v| v * smooth_penalty);
6166 let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6167 filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6168 matrix: penalty,
6169 nullspace_dim_hint: 0,
6170 source: PenaltySource::Other("PcaRidge".to_string()),
6171 normalization_scale: 1.0,
6172 kronecker_factors: None,
6173 op: None,
6174 }])?;
6175 Ok(BasisBuildResult {
6176 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
6177 penalties,
6178 nullspace_dims,
6179 penaltyinfo,
6180 ops,
6181 null_eigenvectors,
6182 joint_null_rotation: None,
6183 metadata: BasisMetadata::Pca {
6184 feature_cols: feature_cols.to_vec(),
6185 basis_matrix: basis_matrix.clone(),
6186 centered,
6187 smooth_penalty,
6188 center_mean: centered.then_some(mean),
6189 pca_basis_path: None,
6190 chunk_size: chunk_size.max(1),
6191 },
6192 kronecker_factored: None,
6193 })
6194}
6195
6196pub fn defer_inner_model_centering_to_factor_level_wrapper(basis: &mut SmoothBasisSpec) {
6212 if let SmoothBasisSpec::BSpline1D { spec, .. } = basis
6213 && matches!(
6214 spec.identifiability,
6215 BSplineIdentifiability::WeightedSumToZero { .. }
6216 )
6217 {
6218 spec.identifiability = BSplineIdentifiability::None;
6219 }
6220}
6221
6222pub fn apply_by_variable_to_local_build(
6223 mut built: LocalSmoothTermBuild,
6224 data: ArrayView2<'_, f64>,
6225 by_col: usize,
6226 by: &ByVariableSpec,
6227 term_name: &str,
6228) -> Result<LocalSmoothTermBuild, BasisError> {
6229 if by_col >= data.ncols() {
6230 crate::bail_dim_basis!(
6231 "by-variable smooth term '{term_name}' references column {by_col}, but data has {} columns",
6232 data.ncols()
6233 );
6234 }
6235 let weights = match by {
6236 ByVariableSpec::Numeric => data.column(by_col).to_owned(),
6237 ByVariableSpec::Level { value_bits, .. } => data.column(by_col).mapv(|value| {
6238 if value.to_bits() == *value_bits {
6239 1.0
6240 } else {
6241 0.0
6242 }
6243 }),
6244 };
6245 if weights.iter().any(|value| !value.is_finite()) {
6246 crate::bail_invalid_basis!(
6247 "by-variable smooth term '{term_name}' has non-finite by-column values"
6248 );
6249 }
6250
6251 let mut dense = built
6252 .design
6253 .try_to_dense_by_chunks("by-variable smooth row gating")
6254 .map_err(BasisError::InvalidInput)?;
6255 for (mut row, &weight) in dense.rows_mut().into_iter().zip(weights.iter()) {
6256 row.mapv_inplace(|value| value * weight);
6257 }
6258 built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
6259 built.kronecker_factored = None;
6260 Ok(built)
6261}
6262
6263pub fn build_by_smooth_local(
6274 data: ArrayView2<'_, f64>,
6275 term: &SmoothTermSpec,
6276 smooth: &SmoothBasisSpec,
6277 by_kind: &ByVarKind,
6278 workspace: &mut crate::basis::BasisWorkspace,
6279) -> Result<LocalSmoothTermBuild, BasisError> {
6280 let inner_term = SmoothTermSpec {
6281 name: term.name.clone(),
6282 basis: (*smooth).clone(),
6283 shape: term.shape,
6284 joint_null_rotation: None,
6285 };
6286 let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6287
6288 match by_kind {
6289 ByVarKind::Numeric { feature_col } => {
6290 let inner_meta = inner.metadata.clone();
6291 let mut built = apply_by_variable_to_local_build(
6292 inner,
6293 data,
6294 *feature_col,
6295 &ByVariableSpec::Numeric,
6296 &term.name,
6297 )?;
6298 built.metadata = BasisMetadata::BySmooth {
6299 inner: Box::new(inner_meta),
6300 by_col: *feature_col,
6301 levels: None,
6302 ordered: false,
6303 };
6304 Ok(built)
6305 }
6306 ByVarKind::Factor {
6307 feature_col,
6308 frozen_levels,
6309 ordered,
6310 } => {
6311 let level_bits: Vec<u64> = if let Some(fl) = frozen_levels {
6314 fl.clone()
6315 } else {
6316 let col = data.column(*feature_col);
6317 let mut seen = BTreeSet::<u64>::new();
6318 for &v in col.iter() {
6319 if v.is_finite() {
6320 seen.insert(v.to_bits());
6321 }
6322 }
6323 seen.into_iter().collect()
6324 };
6325 let n_levels = level_bits.len();
6326 if n_levels == 0 {
6327 crate::bail_invalid_basis!(
6328 "by-factor smooth term '{}': factor column {} has no observed levels",
6329 term.name,
6330 feature_col
6331 );
6332 }
6333 let p = inner.dim;
6334 let q = n_levels * p;
6335 let n = data.nrows();
6336
6337 let inner_dense = inner
6338 .design
6339 .try_to_dense_by_chunks("by-factor smooth design gating")
6340 .map_err(BasisError::InvalidInput)?;
6341
6342 let mut combined = Array2::<f64>::zeros((n, q));
6344 for (lvl_idx, &bits) in level_bits.iter().enumerate() {
6345 let col_start = lvl_idx * p;
6346 for row in 0..n {
6347 if data[[row, *feature_col]].to_bits() == bits {
6348 combined
6349 .slice_mut(s![row, col_start..col_start + p])
6350 .assign(&inner_dense.row(row));
6351 }
6352 }
6353 }
6354
6355 let inner_meta = inner.metadata.clone();
6367 let n_penalties = inner.penalties.len();
6368 let n_blocks = n_penalties.saturating_mul(n_levels);
6369 let mut penalties = Vec::<Array2<f64>>::with_capacity(n_blocks);
6370 let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(n_blocks);
6371 let mut nullspaces = Vec::<usize>::with_capacity(n_blocks);
6372 for (pen_pos, s_inner) in inner.penalties.iter().enumerate() {
6373 for lvl in 0..n_levels {
6374 let off = lvl * p;
6375 let mut s_big = Array2::<f64>::zeros((q, q));
6376 s_big
6377 .slice_mut(s![off..off + p, off..off + p])
6378 .assign(s_inner);
6379 let (s_big, scale) = normalize_penalty_in_constrained_space(&s_big);
6380 let mut info = inner.penaltyinfo[pen_pos].clone();
6381 info.original_index = pen_pos * n_levels + lvl;
6384 info.normalization_scale *= scale;
6385 info.kronecker_factors = None;
6388 penalties.push(s_big);
6389 penaltyinfo.push(info);
6390 nullspaces.push(inner.nullspaces[pen_pos]);
6391 }
6392 }
6393
6394 let null_eigenvectors = vec![None; penalties.len()];
6395 let ops = vec![None; penalties.len()];
6396
6397 Ok(LocalSmoothTermBuild {
6398 dim: q,
6399 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(combined)),
6400 penalties,
6401 ops,
6402 nullspaces,
6403 null_eigenvectors,
6404 joint_null_rotation: None,
6405 penaltyinfo,
6406 pre_dropped_penaltyinfo: inner.pre_dropped_penaltyinfo,
6407 metadata: BasisMetadata::BySmooth {
6408 inner: Box::new(inner_meta),
6409 by_col: *feature_col,
6410 levels: Some(level_bits),
6411 ordered: *ordered,
6412 },
6413 linear_constraints: None,
6414 box_reparam: false,
6415 kronecker_factored: None,
6416 })
6417 }
6418 }
6419}
6420
6421pub fn ensure_by_variable_specs_match(
6422 kind: &BySmoothKind,
6423 by: &ByVariableSpec,
6424 term_name: &str,
6425) -> Result<(), BasisError> {
6426 match (kind, by) {
6427 (BySmoothKind::Numeric, ByVariableSpec::Numeric) => Ok(()),
6428 (BySmoothKind::Level { level_bits }, ByVariableSpec::Level { value_bits, .. })
6429 if level_bits == value_bits =>
6430 {
6431 Ok(())
6432 }
6433 _ => Err(BasisError::InvalidInput(format!(
6434 "by-variable smooth term '{term_name}' has inconsistent by-variable specifications"
6435 ))),
6436 }
6437}
6438
/// Builds the local design matrix, penalties and metadata for a factor smooth term.
///
/// Exactly one continuous covariate is supported (`spec.continuous_cols.len() == 1`); that
/// column and `spec.group_col` must both be valid column indices of `data`.
///
/// Behaviour by flavour:
/// * `FactorSmoothFlavour::Sz`: delegates to `build_single_local_smooth_term` with a
///   `SmoothBasisSpec::FactorSumToZero` wrapper around a 1-D B-spline marginal, then replaces the
///   returned metadata with `BasisMetadata::FactorSmooth` (flavour `"sz"`).
/// * Every other flavour: builds the marginal once, copies each row of the marginal design into
///   the column block that belongs to that row's grouping level (all other blocks stay zero), and
///   repeats each marginal penalty on the block diagonal, one block per level.
///   `FactorSmoothFlavour::Re` uses a `p x p` identity as its single marginal penalty instead of
///   the marginal's own penalties.
///
/// # Errors
/// Returns a `BasisError` when the covariate count is not one, a column index is out of bounds,
/// fewer than two grouping levels are available (non-`Sz` path), a row carries a grouping value
/// that is not in the resolved level list, the marginal metadata is of an unexpected variant, or
/// any of the called builders/helpers fails.
6439pub fn build_factor_smooth(
6467 data: ArrayView2<'_, f64>,
6468 spec: &FactorSmoothSpec,
6469 term_name: &str,
6470 workspace: &mut crate::basis::BasisWorkspace,
6471) -> Result<LocalSmoothTermBuild, BasisError> {
    // Only a single continuous covariate is supported.
6472 if spec.continuous_cols.len() != 1 {
6473 crate::bail_invalid_basis!(
6474 "factor smooth term '{}' currently supports exactly one continuous covariate; found {}",
6475 term_name,
6476 spec.continuous_cols.len()
6477 );
6478 }
6479 let feature_col = spec.continuous_cols[0];
6480 let group_col = spec.group_col;
    // Both referenced columns must exist before any indexing into `data`.
6481 if feature_col >= data.ncols() || group_col >= data.ncols() {
6482 crate::bail_dim_basis!(
6483 "factor smooth term '{}' references columns ({}, {}) out of bounds for {} columns",
6484 term_name,
6485 feature_col,
6486 group_col,
6487 data.ncols()
6488 );
6489 }
6490
    // Sum-to-zero flavour: reuse the FactorSumToZero builder and only rewrite the metadata.
6491 if matches!(spec.flavour, FactorSmoothFlavour::Sz) {
6494 let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6495 let inner = SmoothBasisSpec::BSpline1D {
6496 feature_col,
6497 spec: factor_smooth_marginal_for_replay(&spec.marginal),
6498 };
6499 let sz_term = SmoothTermSpec {
6500 name: term_name.to_string(),
6501 basis: SmoothBasisSpec::FactorSumToZero {
6502 inner: Box::new(inner),
6503 by_col: group_col,
6504 levels: levels.clone(),
6505 frozen_global_orthogonality: None,
6506 },
6507 shape: ShapeConstraint::None,
6508 joint_null_rotation: None,
6509 };
6510 let mut built = build_single_local_smooth_term(data, &sz_term, workspace)?;
        // Extract knots/degree/periodicity from whatever marginal metadata the build returned.
        // A missing degree falls back to `spec.marginal.degree`; a cubic-regression marginal is
        // recorded with `periodic = None` and `marginal_is_cr = true`.
6511 let (knots, degree, periodic, marginal_is_cr) = match &built.metadata {
6529 BasisMetadata::BSpline1D {
6530 knots,
6531 periodic,
6532 degree,
6533 ..
6534 } => (
6535 knots.clone(),
6536 degree.unwrap_or(spec.marginal.degree),
6537 *periodic,
6538 false,
6539 ),
6540 BasisMetadata::CubicRegression1D { knots, .. } => {
6541 (knots.clone(), spec.marginal.degree, None, true)
6542 }
6543 other => {
6544 crate::bail_invalid_basis!(
6545 "sz factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6546 term_name,
6547 other
6548 );
6549 }
6550 };
6551 built.metadata = BasisMetadata::FactorSmooth {
6552 continuous_cols: spec.continuous_cols.clone(),
6553 group_col,
6554 knots,
6555 degree,
6556 periodic,
6557 group_levels: levels,
6558 flavour: "sz".to_string(),
6559 marginal_is_cr,
6560 };
6561 return Ok(built);
6562 }
6563
    // Remaining flavours: one full copy of the marginal basis per grouping level.
6564 let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6565 let n_levels = levels.len();
6566 if n_levels < 2 {
6567 crate::bail_invalid_basis!(
6568 "factor smooth term '{}' requires at least two grouping levels; found {}",
6569 term_name,
6570 n_levels
6571 );
6572 }
6573
    // True only for the `Fs` flavour when the largest requested null-penalty order is >= 1
    // (an empty order list counts as 0).
6574 let use_per_dim_null = matches!(
6582 &spec.flavour,
6583 FactorSmoothFlavour::Fs { m_null_penalty_orders }
6584 if m_null_penalty_orders.iter().copied().max().unwrap_or(0) >= 1
6585 );
6586
6587 let mut marginal_spec = factor_smooth_marginal_for_replay(&spec.marginal);
    // When separate per-direction null penalties are added further down, the marginal's own
    // double penalty is switched off (presumably to avoid penalizing the null space twice —
    // TODO confirm).
6593 if use_per_dim_null {
6594 marginal_spec.double_penalty = false;
6595 }
6596 let inner_term = SmoothTermSpec {
6597 name: format!("{term_name}::marginal"),
6598 basis: SmoothBasisSpec::BSpline1D {
6599 feature_col,
6600 spec: marginal_spec,
6601 },
6602 shape: ShapeConstraint::None,
6603 joint_null_rotation: None,
6604 };
6605 let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6606 let base = inner
6607 .design
6608 .try_to_dense_by_chunks("factor smooth marginal")
6609 .map_err(BasisError::InvalidInput)?;
    // n rows, p marginal columns, q = p * n_levels columns in the expanded design.
6610 let n = base.nrows();
6611 let p = base.ncols();
6612 let q = p * n_levels;
6613
6614 let mut dense = Array2::<f64>::zeros((n, q));
    // Copy row i of the marginal design into columns [level_idx * p, level_idx * p + p).
    // Levels are matched by exact `f64` bit pattern; a value not present in `levels` is an error.
6617 for i in 0..n {
6618 let bits = data[[i, group_col]].to_bits();
6619 let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6620 BasisError::InvalidInput(format!(
6621 "factor smooth term '{term_name}' saw an unseen grouping level at row {}",
6622 i + 1
6623 ))
6624 })?;
6625 let start = level_idx * p;
6626 dense
6627 .slice_mut(s![i, start..start + p])
6628 .assign(&base.row(i));
6629 }
6630
    // `Re` flavour: a single identity penalty of size p; otherwise reuse the marginal penalties.
6631 let marginal_penalties: Vec<Array2<f64>> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6637 vec![Array2::<f64>::eye(p)]
6638 } else {
6639 inner.penalties.clone()
6640 };
    // Penalty info must stay index-aligned with `marginal_penalties`; the `Re` identity penalty
    // gets a synthetic full-rank entry.
6641 let marginal_penaltyinfo: Vec<PenaltyInfo> = if matches!(spec.flavour, FactorSmoothFlavour::Re)
6642 {
6643 vec![PenaltyInfo {
6644 source: PenaltySource::Primary,
6645 original_index: 0,
6646 active: true,
6647 effective_rank: p,
6648 dropped_reason: None,
6649 nullspace_dim_hint: 0,
6650 normalization_scale: 1.0,
6651 kronecker_factors: None,
6652 }]
6653 } else {
6654 inner.penaltyinfo.clone()
6655 };
6656 if marginal_penalties.len() != marginal_penaltyinfo.len() {
6657 crate::bail_invalid_basis!(
6658 "internal factor-smooth penalty metadata mismatch for term '{}': penalties={}, infos={}",
6659 term_name,
6660 marginal_penalties.len(),
6661 marginal_penaltyinfo.len()
6662 );
6663 }
6664
6665 let mut penalties = Vec::<Array2<f64>>::with_capacity(marginal_penalties.len());
6666 let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(marginal_penalties.len());
    // Expand every marginal penalty to q x q by repeating it on the block diagonal (one block
    // per level), then re-normalize and fold the new scale into the penalty info.
6667 for (penalty_pos, s_inner) in marginal_penalties.iter().enumerate() {
6668 let mut s_big = Array2::<f64>::zeros((q, q));
6669 for level in 0..n_levels {
6670 let start = level * p;
6671 s_big
6672 .slice_mut(s![start..start + p, start..start + p])
6673 .assign(s_inner);
6674 }
6675 let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
6676 let mut info = marginal_penaltyinfo[penalty_pos].clone();
6677 info.original_index = penalty_pos;
6678 info.normalization_scale *= factor_smooth_scale;
        // The null-space hint is multiplied by the number of level blocks.
6679 info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(n_levels);
6680 info.kronecker_factors = None;
6681 penalties.push(s_big);
6682 penaltyinfo.push(info);
6683 }
6684
    // Null-space dimensions scale with the number of level blocks; `Re` has none.
6685 let mut nullspaces: Vec<usize> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6686 vec![0]
6687 } else {
6688 inner
6689 .nullspaces
6690 .iter()
6691 .map(|ns| ns.saturating_mul(n_levels))
6692 .collect()
6693 };
6694
    // Optional extra penalties, one per column z_k of the marginal's first null-eigenvector
    // matrix. This block is skipped without error when that entry is absent or its row count
    // differs from p.
6695 if use_per_dim_null
6725 && let Some(Some(z)) = inner.null_eigenvectors.first()
6726 && z.nrows() == p
6727 {
6728 for k in 0..z.ncols() {
6729 let zk = z.column(k);
            // Rank-one outer product z_k z_k^T.
6734 let mut p_k = Array2::<f64>::zeros((p, p));
6735 for a in 0..p {
6736 for b in 0..p {
6737 p_k[[a, b]] = zk[a] * zk[b];
6738 }
6739 }
            // Repeat the outer product on the block diagonal, one block per level.
6740 let mut s_null = Array2::<f64>::zeros((q, q));
6741 for level in 0..n_levels {
6742 let start = level * p;
6743 s_null
6744 .slice_mut(s![start..start + p, start..start + p])
6745 .assign(&p_k);
6746 }
6747 let (s_null, null_scale) = normalize_penalty_in_constrained_space(&s_null);
6748 let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
            // Only keep the penalty if the analysis reports a positive rank.
6749 if null_block.rank > 0 {
6750 let original_index = penalties.len();
6751 penalties.push(null_block.sym_penalty);
6752 nullspaces.push(null_block.nullity);
6753 penaltyinfo.push(PenaltyInfo {
6754 source: PenaltySource::Primary,
6755 original_index,
6756 active: true,
6757 effective_rank: null_block.rank,
6758 dropped_reason: None,
6759 nullspace_dim_hint: null_block.nullity,
6760 normalization_scale: null_scale,
6761 kronecker_factors: None,
6762 });
6763 }
6764 }
6765 }
    // Null eigenvectors and the joint null rotation are derived from the final penalty list.
6766 let null_eigenvectors = crate::basis::recompute_null_eigenvectors(&penalties)?;
6767 let joint_null_rotation = crate::basis::compute_joint_null_rotation(&penalties)?;
6768
    // On this path only B-spline marginal metadata is accepted; anything else is an error.
6769 let (knots, degree, periodic) = match &inner.metadata {
6772 BasisMetadata::BSpline1D {
6773 knots,
6774 periodic,
6775 degree,
6776 ..
6777 } => (
6778 knots.clone(),
6779 degree.unwrap_or(spec.marginal.degree),
6780 *periodic,
6781 ),
6782 other => {
6783 crate::bail_invalid_basis!(
6784 "factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6785 term_name,
6786 other
6787 );
6788 }
6789 };
    // `Sz` has already returned (or failed) above, so its arm here only keeps the match exhaustive.
6790 let flavour_tag = match &spec.flavour {
6791 FactorSmoothFlavour::Fs { .. } => "fs",
6792 FactorSmoothFlavour::Sz => "sz",
6793 FactorSmoothFlavour::Re => "re",
6794 }
6795 .to_string();
6796 let metadata = BasisMetadata::FactorSmooth {
6797 continuous_cols: spec.continuous_cols.clone(),
6798 group_col,
6799 knots,
6800 degree,
6801 periodic,
6802 group_levels: levels,
6803 flavour: flavour_tag,
6804 marginal_is_cr: false,
6807 };
6808
    // No analytic penalty operators are attached to the expanded penalties.
6809 let ops = vec![None; penalties.len()];
6810 Ok(LocalSmoothTermBuild {
6811 dim: q,
6812 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense)),
6813 penalties,
6814 ops,
6815 nullspaces,
6816 null_eigenvectors,
6817 joint_null_rotation,
6818 penaltyinfo,
        // NOTE(review): the marginal's `pre_dropped_penaltyinfo` is not forwarded here — confirm
        // that dropping it is intended.
6819 pre_dropped_penaltyinfo: Vec::new(),
6820 metadata,
6821 linear_constraints: None,
6822 box_reparam: false,
6823 kronecker_factored: None,
6824 })
6825}
6826
6827pub fn resolve_factor_smooth_levels(
6831 data: ArrayView2<'_, f64>,
6832 group_col: usize,
6833 spec: &FactorSmoothSpec,
6834 term_name: &str,
6835) -> Result<Vec<u64>, BasisError> {
6836 if let Some(frozen) = &spec.group_frozen_levels {
6837 if frozen.is_empty() {
6838 crate::bail_invalid_basis!(
6839 "factor smooth term '{}' has an empty frozen level list",
6840 term_name
6841 );
6842 }
6843 return Ok(frozen.clone());
6844 }
6845 let mut bits: Vec<u64> = data.column(group_col).iter().map(|v| v.to_bits()).collect();
6846 bits.sort_by(|a, b| {
6847 f64::from_bits(*a)
6848 .partial_cmp(&f64::from_bits(*b))
6849 .unwrap_or(std::cmp::Ordering::Equal)
6850 });
6851 bits.dedup();
6852 Ok(bits)
6853}
6854
6855pub fn factor_smooth_marginal_for_replay(marginal: &BSplineBasisSpec) -> BSplineBasisSpec {
6862 let mut m = marginal.clone();
6863 m.identifiability = BSplineIdentifiability::None;
6864 m
6865}
6866
/// Builds a single smooth term in its own local coefficient space.
///
/// Dispatch on `term.basis`:
/// * `ByVariable`: checks `kind`/`by` consistency, builds the inner basis by recursion and hands
///   the result to `apply_by_variable_to_local_build`.
/// * `BySmooth`: delegated to `build_by_smooth_local`.
/// * `FactorSumToZero` and `FactorSmooth`: built and returned directly from their match arms.
/// * Every other variant produces a `BasisBuildResult` that then runs through the shared
///   post-processing below: optional box reparameterization for shape constraints, penalty
///   normalization and filtering, optional shape linear constraints, and the joint-null rotation.
///
/// # Errors
/// Returns a `BasisError` when the shape constraint is not supported for the basis, a referenced
/// column is out of bounds, internal penalty/info bookkeeping is inconsistent, or any called
/// builder/helper fails.
6867pub fn build_single_local_smooth_term(
6868 data: ArrayView2<'_, f64>,
6869 term: &SmoothTermSpec,
6870 workspace: &mut crate::basis::BasisWorkspace,
6871) -> Result<LocalSmoothTermBuild, BasisError> {
    // Reject shape constraints up front when the basis cannot support them.
6872 if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
6873 crate::bail_invalid_basis!(
6874 "ShapeConstraint::{:?} is unsupported for term '{}'",
6875 term.shape,
6876 term.name
6877 );
6878 }
    // By-variable wrapper: build the inner term (same name and shape), then apply the by-variable.
6879 if let SmoothBasisSpec::ByVariable {
6880 inner,
6881 by_col,
6882 kind,
6883 by,
6884 } = &term.basis
6885 {
6886 ensure_by_variable_specs_match(kind, by, &term.name)?;
6887 let mut inner_basis = (**inner).clone();
        // For level-type by-variables the cloned inner basis is adjusted first by
        // `defer_inner_model_centering_to_factor_level_wrapper`.
6888 if matches!(by, ByVariableSpec::Level { .. }) {
6895 defer_inner_model_centering_to_factor_level_wrapper(&mut inner_basis);
6896 }
6897 let inner_term = SmoothTermSpec {
6898 name: term.name.clone(),
6899 basis: inner_basis,
6900 shape: term.shape,
6901 joint_null_rotation: None,
6902 };
6903 let built = build_single_local_smooth_term(data, &inner_term, workspace)?;
6904 return apply_by_variable_to_local_build(built, data, *by_col, by, &term.name);
6905 }
6906
    // By-smooth terms have their own builder.
6907 if let SmoothBasisSpec::BySmooth { smooth, by_kind } = &term.basis {
6910 return build_by_smooth_local(data, term, smooth, by_kind, workspace);
6911 }
6912
    // Set by the ThinPlate/Matern/Duchon arms when a shape constraint is requested; consumed
    // later when building shape linear constraints.
6913 let mut shape_axis_col: Option<usize> = None;
6914 let mut built: BasisBuildResult = match &term.basis {
6915 SmoothBasisSpec::FactorSumToZero {
6916 inner,
6917 by_col,
6918 levels,
6919 ..
6920 } => {
6921 if *by_col >= data.ncols() {
6922 crate::bail_dim_basis!(
6923 "term '{}' by column {} out of bounds for {} columns",
6924 term.name,
6925 by_col,
6926 data.ncols()
6927 );
6928 }
6929 if levels.len() < 2 {
6930 crate::bail_invalid_basis!(
6931 "sum-to-zero factor smooth term '{}' requires at least two levels",
6932 term.name
6933 );
6934 }
6935 if term.shape != ShapeConstraint::None {
6936 crate::bail_invalid_basis!(
6937 "ShapeConstraint::{:?} is unsupported for sum-to-zero factor smooth term '{}'",
6938 term.shape,
6939 term.name
6940 );
6941 }
6942 let inner_term = SmoothTermSpec {
6943 name: format!("{}::inner", term.name),
6944 basis: (**inner).clone(),
6945 shape: ShapeConstraint::None,
6946 joint_null_rotation: None,
6947 };
6948 let mut inner_built = build_single_local_smooth_term(data, &inner_term, workspace)?;
6949 let base = inner_built
6950 .design
6951 .try_to_dense_by_chunks("sum-to-zero factor smooth")
6952 .map_err(BasisError::InvalidInput)?;
6953 let n = base.nrows();
6954 let p = base.ncols();
            // Only L - 1 level blocks are stored. Rows of the first L - 1 levels get the inner
            // row in their own block; rows of the last level get the negated inner row in every
            // block.
6955 let l_minus_one = levels.len() - 1;
6956 let mut dense = Array2::<f64>::zeros((n, p * l_minus_one));
6957 for i in 0..n {
6958 let bits = data[[i, *by_col]].to_bits();
6959 let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6960 BasisError::InvalidInput(format!(
6961 "sum-to-zero factor smooth term '{}' saw an unseen level at row {}",
6962 term.name,
6963 i + 1
6964 ))
6965 })?;
6966 if level_idx < l_minus_one {
6967 let start = level_idx * p;
6968 dense
6969 .slice_mut(s![i, start..start + p])
6970 .assign(&base.row(i));
6971 } else {
6972 for level in 0..l_minus_one {
6973 let start = level * p;
6974 dense
6975 .slice_mut(s![i, start..start + p])
6976 .assign(&base.row(i).mapv(|v| -v));
6977 }
6978 }
6979 }
6980 let mut penalties = Vec::<Array2<f64>>::with_capacity(inner_built.penalties.len());
            // Positions of the active infos; there must be exactly one per inner penalty so that
            // `penalty_pos` can be mapped back to its info entry.
6981 let active_penalty_indices = inner_built
6982 .penaltyinfo
6983 .iter()
6984 .enumerate()
6985 .filter_map(|(idx, info)| info.active.then_some(idx))
6986 .collect::<Vec<_>>();
6987 if active_penalty_indices.len() != inner_built.penalties.len() {
6988 crate::bail_invalid_basis!(
6989 "internal sz penalty metadata mismatch: activeinfos={}, penalties={}",
6990 active_penalty_indices.len(),
6991 inner_built.penalties.len()
6992 );
6993 }
            // Expanded penalty: block (a, b) is the inner penalty times 2 on the diagonal and
            // times 1 off the diagonal; the result is re-normalized.
6994 for (penalty_pos, s_inner) in inner_built.penalties.iter().enumerate() {
6995 let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
6996 for a in 0..l_minus_one {
6997 for b in 0..l_minus_one {
6998 let factor = if a == b { 2.0 } else { 1.0 };
6999 let mut block = s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7000 block.assign(&s_inner.mapv(|v| v * factor));
7001 }
7002 }
7003 let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
7004 let info_idx = active_penalty_indices[penalty_pos];
7005 inner_built.penaltyinfo[info_idx].normalization_scale *= factor_smooth_scale;
7006 penalties.push(s_big);
7007 }
            // Overwrite the inner build in place; fields not assigned below (e.g. metadata) stay
            // as the inner build produced them.
7008 inner_built.dim = p * l_minus_one;
7009 inner_built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
7010 inner_built.penalties = penalties;
7011 inner_built.ops = vec![None; inner_built.penalties.len()];
7012 inner_built.nullspaces = inner_built
7013 .nullspaces
7014 .iter()
7015 .map(|ns| ns.saturating_mul(l_minus_one))
7016 .collect();
7017 inner_built.null_eigenvectors =
7024 crate::basis::recompute_null_eigenvectors(&inner_built.penalties)?;
7025 inner_built.joint_null_rotation =
7026 crate::basis::compute_joint_null_rotation(&inner_built.penalties)?;
7027 inner_built.kronecker_factored = None;
7028 return Ok(inner_built);
7029 }
7030 SmoothBasisSpec::BSpline1D { feature_col, spec } => {
7031 if *feature_col >= data.ncols() {
7032 crate::bail_dim_basis!(
7033 "term '{}' feature column {} out of bounds for {} columns",
7034 term.name,
7035 feature_col,
7036 data.ncols()
7037 );
7038 }
            // With a shape constraint the basis is built without an identifiability constraint.
7039 let mut spec_local = spec.clone();
7040 if term.shape != ShapeConstraint::None {
7041 spec_local.identifiability = BSplineIdentifiability::None;
7044 }
7045 build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
7049 }
7050 SmoothBasisSpec::ThinPlate {
7051 feature_cols,
7052 spec,
7053 input_scales,
7054 } => {
            // Shape constraints need exactly one feature axis; remember it for later.
7055 if term.shape != ShapeConstraint::None {
7056 if feature_cols.len() != 1 {
7057 crate::bail_invalid_basis!(
7058 "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
7059 term.shape,
7060 term.name,
7061 feature_cols.len()
7062 );
7063 }
7064 shape_axis_col = Some(feature_cols[0]);
7065 }
7066 let mut x = select_columns(data, feature_cols)?;
            // Standardize inputs with the provided scales if any, else with freshly computed
            // ones, and compensate the length scale accordingly. No scales means no change.
7067 let (scales, length_scale_eff) = if let Some(s) = input_scales {
7073 apply_input_standardization(&mut x, s);
7074 (
7075 Some(s.clone()),
7076 compensate_length_scale_for_standardization(spec.length_scale, s),
7077 )
7078 } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7079 apply_input_standardization(&mut x, &s);
7080 let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7081 (Some(s), l_eff)
7082 } else {
7083 (None, spec.length_scale)
7084 };
7085 let mut spec_local = spec.clone();
7086 spec_local.length_scale = length_scale_eff;
            // `OrthogonalToParametric` is cleared for the local build (presumably applied at a
            // later, global stage — TODO confirm).
7087 if matches!(
7088 spec_local.identifiability,
7089 SpatialIdentifiability::OrthogonalToParametric
7090 ) {
7091 spec_local.identifiability = SpatialIdentifiability::None;
7092 }
7093 let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
7094 rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
7095 })?;
            // Metadata stores the scales actually applied and the original `spec.length_scale`,
            // not the compensated value used for the build. Both ThinPlate and Duchon metadata
            // variants are handled here.
7096 match &mut result.metadata {
7104 BasisMetadata::ThinPlate {
7105 input_scales: ms,
7106 length_scale,
7107 ..
7108 } => {
7109 *ms = scales;
7110 *length_scale = spec.length_scale;
7111 }
7112 BasisMetadata::Duchon {
7113 input_scales: ms,
7114 length_scale,
7115 ..
7116 } => {
7117 *ms = scales;
7118 *length_scale = Some(spec.length_scale);
7128 }
7129 _ => {}
7130 }
7131 result
7132 }
7133 SmoothBasisSpec::Sphere { feature_cols, spec } => {
7134 if term.shape != ShapeConstraint::None {
7135 crate::bail_invalid_basis!(
7136 "ShapeConstraint::{:?} for term '{}' is not supported on spherical splines",
7137 term.shape,
7138 term.name
7139 );
7140 }
7141 let x = select_columns(data, feature_cols)?;
7142 build_spherical_spline_basis(x.view(), spec)?
7143 }
7144 SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
7145 if term.shape != ShapeConstraint::None {
7146 crate::bail_invalid_basis!(
7147 "ShapeConstraint::{:?} for term '{}' is not supported on constant-curvature smooths",
7148 term.shape,
7149 term.name
7150 );
7151 }
7152 let x = select_columns(data, feature_cols)?;
7159 build_constant_curvature_basis(x.view(), spec)?
7160 }
7161 SmoothBasisSpec::MeasureJet {
7162 feature_cols,
7163 spec,
7164 input_scales,
7165 } => {
7166 if term.shape != ShapeConstraint::None {
7167 crate::bail_invalid_basis!(
7168 "ShapeConstraint::{:?} for term '{}' is not supported on measure-jet smooths",
7169 term.shape,
7170 term.name
7171 );
7172 }
7173 let mut x = select_columns(data, feature_cols)?;
            // NOTE(review): unlike the other spatial arms, provided `input_scales` do not trigger
            // length-scale compensation here, and computed scales only compensate a positive
            // length scale — confirm this asymmetry is intended.
7174 let (scales, length_scale_eff) = if let Some(s) = input_scales {
7186 apply_input_standardization(&mut x, s);
7187 (Some(s.clone()), spec.length_scale)
7188 } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7189 apply_input_standardization(&mut x, &s);
7190 let l_eff = if spec.length_scale > 0.0 {
7191 compensate_length_scale_for_standardization(spec.length_scale, &s)
7192 } else {
7193 spec.length_scale
7194 };
7195 (Some(s), l_eff)
7196 } else {
7197 (None, spec.length_scale)
7198 };
7199 let mut spec_local = spec.clone();
7200 spec_local.length_scale = length_scale_eff;
7201 let mut result = build_measure_jet_basis(x.view(), &spec_local)?;
7202 if let BasisMetadata::MeasureJet {
7203 input_scales: ms, ..
7204 } = &mut result.metadata
7205 {
7206 *ms = scales;
7207 }
7208 result
7209 }
7210 SmoothBasisSpec::Matern {
7211 feature_cols,
7212 spec,
7213 input_scales,
7214 } => {
7215 if term.shape != ShapeConstraint::None {
7216 if feature_cols.len() != 1 {
7217 crate::bail_invalid_basis!(
7218 "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
7219 term.shape,
7220 term.name,
7221 feature_cols.len()
7222 );
7223 }
7224 shape_axis_col = Some(feature_cols[0]);
7225 }
7226 let mut x = select_columns(data, feature_cols)?;
            // Same standardization/compensation scheme as the ThinPlate arm.
7227 let (scales, length_scale_eff) = if let Some(s) = input_scales {
7242 apply_input_standardization(&mut x, s);
7243 (
7244 Some(s.clone()),
7245 compensate_length_scale_for_standardization(spec.length_scale, s),
7246 )
7247 } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7248 apply_input_standardization(&mut x, &s);
7249 let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7250 (Some(s), l_eff)
7251 } else {
7252 (None, spec.length_scale)
7253 };
7254 let mut spec_local = spec.clone();
7255 spec_local.length_scale = length_scale_eff;
7256 let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
            // Metadata keeps the original (uncompensated) length scale.
7257 if let BasisMetadata::Matern {
7258 input_scales,
7259 length_scale,
7260 ..
7261 } = &mut result.metadata
7262 {
7263 *input_scales = scales;
7264 *length_scale = spec.length_scale;
7265 }
7266 result
7267 }
7268 SmoothBasisSpec::Duchon {
7269 feature_cols,
7270 spec,
7271 input_scales,
7272 } => {
7273 if term.shape != ShapeConstraint::None {
7274 if feature_cols.len() != 1 {
7275 crate::bail_invalid_basis!(
7276 "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
7277 term.shape,
7278 term.name,
7279 feature_cols.len()
7280 );
7281 }
7282 shape_axis_col = Some(feature_cols[0]);
7283 }
7284 let mut x = select_columns(data, feature_cols)?;
            // Duchon uses the `optional` compensation helper for its length scale.
7285 let (scales, length_scale_eff) = if let Some(s) = input_scales {
7296 apply_input_standardization(&mut x, s);
7297 (
7298 Some(s.clone()),
7299 compensate_optional_length_scale_for_standardization(spec.length_scale, s),
7300 )
7301 } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7302 apply_input_standardization(&mut x, &s);
7303 let l_eff =
7304 compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
7305 (Some(s), l_eff)
7306 } else {
7307 (None, spec.length_scale)
7308 };
7309 let mut spec_local = spec.clone();
7310 spec_local.length_scale = length_scale_eff;
            // As for ThinPlate, `OrthogonalToParametric` is cleared for the local build.
7311 if matches!(
7312 spec_local.identifiability,
7313 SpatialIdentifiability::OrthogonalToParametric
7314 ) {
7315 spec_local.identifiability = SpatialIdentifiability::None;
7316 }
7317 let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
7318 if let BasisMetadata::Duchon {
7319 input_scales,
7320 length_scale,
7321 ..
7322 } = &mut result.metadata
7323 {
7324 *input_scales = scales;
7325 *length_scale = spec.length_scale;
7326 }
7327 result
7328 }
7329 SmoothBasisSpec::Pca {
7330 feature_cols,
7331 basis_matrix,
7332 centered,
7333 smooth_penalty,
7334 center_mean,
7335 pca_basis_path,
7336 chunk_size,
7337 } => {
7338 if term.shape != ShapeConstraint::None {
7339 crate::bail_invalid_basis!(
7340 "ShapeConstraint::{:?} for term '{}' is not supported on Pca basis",
7341 term.shape,
7342 term.name
7343 );
7344 }
7345 build_pca_smooth_basis(
7346 data,
7347 feature_cols,
7348 basis_matrix,
7349 *centered,
7350 *smooth_penalty,
7351 center_mean.as_ref(),
7352 pca_basis_path.as_ref(),
7353 *chunk_size,
7354 )?
7355 }
7356 SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
7357 build_tensor_bspline_basis(data, feature_cols, spec)?
7358 }
            // The next two arms are defensive: both variants return earlier in this function.
7359 SmoothBasisSpec::ByVariable { .. } => {
7360 crate::bail_invalid_basis!(
7361 "internal: ByVariable smooths must return before inner basis dispatch"
7362 );
7363 }
            // NOTE(review): the message talks about lowering to ByVariable, while the early
            // return above calls `build_by_smooth_local` — confirm the wording is still accurate.
7364 SmoothBasisSpec::BySmooth { .. } => {
7365 crate::bail_invalid_basis!("internal: BySmooth smooths must be lowered to ByVariable before inner basis dispatch"
7366 .to_string(),);
7367 }
7368 SmoothBasisSpec::FactorSmooth { spec } => {
7369 if term.shape != ShapeConstraint::None {
7370 crate::bail_invalid_basis!(
7371 "ShapeConstraint::{:?} is unsupported for factor smooth term '{}'",
7372 term.shape,
7373 term.name
7374 );
7375 }
7376 return build_factor_smooth(data, spec, &term.name, workspace);
7377 }
7378 };
7379
    // Matern terms replace the builder's penalties with the triplet derived from the metadata.
7380 if let SmoothBasisSpec::Matern { .. } = &term.basis {
7396 let (penalties, nullspace_dims, penaltyinfo) =
7397 matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
7398 built.penalties = penalties;
7399 built.nullspace_dims = nullspace_dims;
7400 built.penaltyinfo = penaltyinfo;
7401 }
7402
7403 let p_local = built.design.ncols();
7404 let mut metadata = built.metadata.clone();
    // The Kronecker-factored form is only kept when no shape constraint is active.
7405 let kron_factored = if term.shape == ShapeConstraint::None {
7408 built.kronecker_factored
7409 } else {
7410 None
7411 };
7412 let mut design_t = built.design;
7413 let mut penalties_t: Vec<Array2<f64>> = built.penalties;
7414 let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>> =
7419 built.ops;
    // Terms whose policy is `OrthogonalToParametric` get their metadata passed through
    // `freeze_raw_spatial_metadata` together with the current column count.
7420 if matches!(
7421 spatial_identifiability_policy(term),
7422 Some(SpatialIdentifiability::OrthogonalToParametric)
7423 ) {
7424 metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
7425 }
7426
    // Split penalty info into active entries (paired with `penalties_t`) and entries that were
    // already dropped by the builder.
7427 let active_penaltyinfo_t = built
7428 .penaltyinfo
7429 .iter()
7430 .filter(|info| info.active)
7431 .cloned()
7432 .collect::<Vec<_>>();
7433 let pre_dropped_penaltyinfo_t = built
7434 .penaltyinfo
7435 .iter()
7436 .filter(|info| !info.active)
7437 .cloned()
7438 .collect::<Vec<_>>();
7439 let use_box_reparam =
7440 term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
    // Box reparameterization: pick a coefficient transform `t`, wrap the design in a
    // `CoefficientTransformOperator`, and rebuild the penalties for the transformed coefficients.
7441 if let Some((order, sign)) = shape_order_and_sign(term.shape)
7442 && use_box_reparam
7443 {
        // Order 2 on a non-periodic B-spline of degree >= 1 uses the Greville-abscissae based
        // convex divided-difference transform; every other case uses the cumulative-sum transform.
7444 let t = if order == 2 {
7459 let bspline_meta = match &metadata {
7460 BasisMetadata::BSpline1D {
7461 knots,
7462 degree,
7463 periodic,
7464 ..
7465 } if periodic.is_none() => Some((knots.clone(), degree.unwrap_or(0))),
7466 _ => None,
7467 };
7468 match bspline_meta {
7469 Some((knots, degree)) if degree >= 1 => {
7470 let greville = crate::basis::compute_greville_abscissae(&knots, degree)?;
7471 if greville.len() != p_local {
7472 crate::bail_invalid_basis!(
7473 "shape-constraint Greville abscissae count {} does not match basis dim {} for term '{}'",
7474 greville.len(),
7475 p_local,
7476 term.name
7477 );
7478 }
7479 convex_divided_difference_transform_matrix(&greville, sign)?
7480 }
7481 _ => cumulative_sum_transform_matrix(p_local, order, sign),
7482 }
7483 } else {
7484 cumulative_sum_transform_matrix(p_local, order, sign)
7485 };
        // The operator needs a dense inner design; sparse designs are densified first.
7486 let inner_dense = match design_t {
7490 DesignMatrix::Dense(d) => d,
7491 DesignMatrix::Sparse(sp) => gam_linalg::matrix::DenseDesignMatrix::from(
7492 sp.try_to_dense_arc("shape-constrained coefficient transform")
7493 .map_err(BasisError::InvalidInput)?,
7494 ),
7495 };
7496 let coeff_op = gam_linalg::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
7497 .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
7498 design_t = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
7499 if penalties_t.len() != active_penaltyinfo_t.len() {
7500 crate::bail_invalid_basis!(
7501 "internal box-reparam penalty/info mismatch for term '{}': penalties={}, infos={}",
7502 term.name,
7503 penalties_t.len(),
7504 active_penaltyinfo_t.len()
7505 );
7506 }
        // Transform of the first penalty that is not a double-penalty ridge (presumably
        // t^T * S * t, going by the helper names `fast_atb`/`fast_ab` — TODO confirm).
7507 let transformed_wiggliness = penalties_t
7523 .iter()
7524 .zip(active_penaltyinfo_t.iter())
7525 .find(|(_, info)| !matches!(info.source, PenaltySource::DoublePenaltyNullspace))
7526 .map(|(s_local, _)| {
7527 let tt_s = fast_atb(&t, s_local);
7528 fast_ab(&tt_s, &t)
7529 });
7530 let mut rebuilt = Vec::with_capacity(penalties_t.len());
        // Double-penalty ridges are rebuilt from the transformed wiggliness penalty (zeros when
        // the helper returns `None`); all other penalties are transformed directly.
7531 for (s_local, info) in penalties_t.iter().zip(active_penaltyinfo_t.iter()) {
7532 if matches!(info.source, PenaltySource::DoublePenaltyNullspace) {
7533 let s_wiggle_t = transformed_wiggliness.as_ref().ok_or_else(|| {
7534 BasisError::InvalidInput(format!(
7535 "box-reparam term '{}' has a double-penalty ridge but no primary wiggliness penalty to derive its nullspace from",
7536 term.name
7537 ))
7538 })?;
7539 let ridge = crate::basis::build_nullspace_shrinkage_penalty(s_wiggle_t)?
7540 .map(|shrink| shrink.sym_penalty)
7541 .unwrap_or_else(|| Array2::<f64>::zeros((p_local, p_local)));
7542 rebuilt.push(ridge);
7543 } else {
7544 let tt_s = fast_atb(&t, s_local);
7545 rebuilt.push(fast_ab(&tt_s, &t));
7546 }
7547 }
7548 penalties_t = rebuilt;
        // Operators from the builder no longer match the transformed penalties.
7549 ops_t = vec![None; penalties_t.len()];
7552 }
7553 if penalties_t.len() != active_penaltyinfo_t.len() {
7554 crate::bail_invalid_basis!(
7555 "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
7556 term.name,
7557 penalties_t.len(),
7558 active_penaltyinfo_t.len()
7559 );
7560 }
    // Keep `ops_t` index-aligned with the penalties; a length mismatch discards all operators.
7561 if ops_t.len() != penalties_t.len() {
7562 ops_t = vec![None; penalties_t.len()];
7563 }
    // Normalize each penalty and assemble candidates. The new scale `c_new` is folded into
    // `normalization_scale`; the operator and the first Kronecker factor are scaled by 1 / c_new.
    // NOTE(review): the operator is dropped when 1 / c_new is non-positive or non-finite, but the
    // Kronecker factor is scaled without that guard — confirm `c_new` can never be zero here.
7564 let penalty_candidates = penalties_t
7565 .into_iter()
7566 .zip(active_penaltyinfo_t.into_iter())
7567 .zip(ops_t.into_iter())
7568 .map(
7569 |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
7570 let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
7571 let normalization_scale = info.normalization_scale * c_new;
7572 let op_scale = 1.0 / c_new;
7573 let kronecker_scale = 1.0 / c_new;
7574 let scaled_op = if op_scale > 0.0 && op_scale.is_finite() {
7577 op_in.map(|op| {
7578 std::sync::Arc::new(crate::analytic_penalties::ScaledPenaltyOp::new(
7579 op, op_scale,
7580 ))
7581 as std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>
7582 })
7583 } else {
7584 None
7585 };
7586 let kronecker_factors = info.kronecker_factors.map(|mut factors| {
7587 if let Some(first) = factors.first_mut() {
7588 first.mapv_inplace(|v| v * kronecker_scale);
7589 }
7590 factors
7591 });
7592 Ok(PenaltyCandidate {
7593 nullspace_dim_hint: info.nullspace_dim_hint,
7594 matrix,
7595 source: info.source,
7596 normalization_scale,
7597 kronecker_factors,
7598 op: scaled_op,
7599 })
7600 },
7601 )
7602 .collect::<Result<Vec<_>, _>>()?;
7603 let (penalties_t, nullspaces_t, penaltyinfo_t, null_eigenvectors_t, ops_t) =
7604 crate::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
    // Shape constraints that are not handled by box reparameterization become linear constraints
    // along the single feature axis recorded in `shape_axis_col`.
7605 let shape_linear_constraints = if term.shape != ShapeConstraint::None && !use_box_reparam {
7606 let axis = shape_axis_col.ok_or_else(|| {
7607 BasisError::InvalidInput(format!(
7608 "internal shape-constraint axis missing for term '{}'",
7609 term.name
7610 ))
7611 })?;
7612 let (x_shape_eval, design_shape_eval) =
7613 build_shape_constraint_design_1d(data, term, &metadata, axis)?;
7614 build_shape_linear_constraints_1d(
7615 x_shape_eval.view(),
7616 design_shape_eval.view(),
7617 term.shape,
7618 )?
7619 } else {
7620 None
7621 };
7622 let linear_constraints_local = merge_linear_constraints_global(shape_linear_constraints, None);
7623
    // A rotation stored on the term wins. Otherwise none is produced for terms with frozen
    // identifiability or a Kronecker-factored form; in all other cases it is computed from the
    // final penalties.
7624 let joint_null_rotation = match term.joint_null_rotation.clone() {
7643 Some(persisted) => Some(persisted),
7644 None if smooth_has_frozen_identifiability(term) => None,
7645 None if kron_factored.is_some() => None,
7646 None => crate::basis::compute_joint_null_rotation(&penalties_t)?,
7647 };
7648
7649 Ok(LocalSmoothTermBuild {
7650 dim: p_local,
7651 design: design_t,
7652 penalties: penalties_t,
7653 ops: ops_t,
7654 nullspaces: nullspaces_t,
7655 null_eigenvectors: null_eigenvectors_t,
7656 joint_null_rotation,
7657 penaltyinfo: penaltyinfo_t,
7658 pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
7659 metadata,
7660 linear_constraints: linear_constraints_local,
7661 box_reparam: use_box_reparam,
7662 kronecker_factored: kron_factored,
7663 })
7664}
7665
7666pub fn build_smooth_design(
7667 data: ArrayView2<'_, f64>,
7668 terms: &[SmoothTermSpec],
7669) -> Result<RawSmoothDesign, BasisError> {
7670 let mut ws = crate::basis::BasisWorkspace::new();
7671 build_smooth_design_withworkspace(data, terms, &mut ws)
7672}
7673
7674pub fn build_smooth_design_withworkspace(
7681 data: ArrayView2<'_, f64>,
7682 terms: &[SmoothTermSpec],
7683 workspace: &mut crate::basis::BasisWorkspace,
7684) -> Result<RawSmoothDesign, BasisError> {
7685 validate_smooth_terms_finite_inputs(data, terms)?;
7686 build_smooth_design_withworkspace_unvalidated(data, terms, workspace)
7687}
7688
7689pub fn build_smooth_design_withworkspace_unvalidated(
7690 data: ArrayView2<'_, f64>,
7691 terms: &[SmoothTermSpec],
7692 workspace: &mut crate::basis::BasisWorkspace,
7693) -> Result<RawSmoothDesign, BasisError> {
7694 let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
7695 let planned_terms = planned_blocks.pop().ok_or_else(|| {
7696 BasisError::InvalidInput(
7697 "joint spatial center planner returned no smooth blocks".to_string(),
7698 )
7699 })?;
7700 let policy = workspace.policy().clone();
7701 let local_builds: Vec<LocalSmoothTermBuild> = {
7702 use rayon::iter::{IntoParallelIterator, ParallelIterator};
7703 planned_terms
7704 .into_par_iter()
7705 .map(|term| {
7706 let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
7707 build_single_local_smooth_term(data, &term, &mut term_workspace)
7708 })
7709 .collect::<Result<Vec<_>, _>>()?
7710 };
7711
7712 let total_p: usize = local_builds.iter().map(|built| built.dim).sum();
7713
7714 let mut local_designs: Vec<DesignMatrix> = Vec::with_capacity(local_builds.len());
7715 let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
7716 let mut penalties_global = Vec::<BlockwisePenalty>::new();
7717 let mut nullspace_dims_global = Vec::<usize>::new();
7718 let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
7719 let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
7720 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
7721 let mut any_bounds = false;
7722 let mut linear_constraintsrows: Vec<(usize, usize, Array1<f64>)> = Vec::new();
7727 let mut linear_constraints_b: Vec<f64> = Vec::new();
7728
7729 let mut col_start = 0usize;
7730 for (term, mut built) in terms.iter().zip(local_builds.into_iter()) {
7731 let p_local = built.dim;
7732 let col_end = col_start + p_local;
7733 let lb_local = if built.box_reparam {
7734 shape_lower_bounds_local(term.shape, p_local)
7735 } else {
7736 None
7737 };
7738
7739 let applied_rotation: Option<crate::basis::JointNullRotation> = match (
7771 built.joint_null_rotation.take(),
7772 lb_local.is_some(),
7773 built.linear_constraints.is_some(),
7774 ) {
7775 (Some(rot), false, false) => {
7776 let q = &rot.rotation;
7777 let dense = built
7778 .design
7779 .try_to_dense_by_chunks("joint-null absorption rotation")
7780 .map_err(BasisError::InvalidInput)?;
7781 let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
7782 built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
7783 built.penalties = built
7784 .penalties
7785 .into_iter()
7786 .map(|s_local| {
7787 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
7788 gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
7789 })
7790 .collect();
7791 built.ops = vec![None; built.penalties.len()];
7792 built.kronecker_factored = None;
7793 Some(rot)
7794 }
7795 (Some(_), _, _) => None,
7796 (None, _, _) => None,
7797 };
7798
7799 let activeinfos = built
7800 .penaltyinfo
7801 .iter()
7802 .filter(|info| info.active)
7803 .collect::<Vec<_>>();
7804 if activeinfos.len() != built.penalties.len() {
7805 crate::bail_invalid_basis!(
7806 "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
7807 term.name,
7808 activeinfos.len(),
7809 built.penalties.len()
7810 );
7811 }
7812 for (((s_local, &ns), info), op_local) in built
7813 .penalties
7814 .iter()
7815 .zip(built.nullspaces.iter())
7816 .zip(activeinfos.into_iter())
7817 .zip(built.ops.iter())
7818 {
7819 let global_index = penalties_global.len();
7820 penalties_global.push(
7821 BlockwisePenalty::new(col_start..col_end, s_local.clone())
7822 .with_op(op_local.clone()),
7823 );
7824 nullspace_dims_global.push(ns);
7825 let mut penalty = info.clone();
7826 penalty.nullspace_dim_hint = ns;
7827 penaltyinfo_global.push(PenaltyBlockInfo {
7828 global_index,
7829 termname: Some(term.name.clone()),
7830 penalty,
7831 });
7832 }
7833 for info in built.penaltyinfo.iter().filter(|info| !info.active) {
7834 dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
7835 termname: Some(term.name.clone()),
7836 penalty: info.clone(),
7837 });
7838 }
7839 for info in &built.pre_dropped_penaltyinfo {
7840 dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
7841 termname: Some(term.name.clone()),
7842 penalty: info.clone(),
7843 });
7844 }
7845
7846 if let Some(lin_local) = &built.linear_constraints {
7847 for r in 0..lin_local.a.nrows() {
7848 linear_constraintsrows.push((col_start, col_end, lin_local.a.row(r).to_owned()));
7849 linear_constraints_b.push(lin_local.b[r]);
7850 }
7851 }
7852 if let Some(lb_local) = &lb_local {
7853 coefficient_lower_bounds
7854 .slice_mut(s![col_start..col_end])
7855 .assign(lb_local);
7856 any_bounds = true;
7857 }
7858
7859 local_designs.push(built.design);
7861
7862 terms_out.push(SmoothTerm {
7863 name: term.name.clone(),
7864 coeff_range: col_start..col_end,
7865 shape: term.shape,
7866 penalties_local: built.penalties,
7867 nullspace_dims: built.nullspaces,
7868 penaltyinfo_local: built.penaltyinfo,
7869 metadata: built.metadata,
7870 lower_bounds_local: lb_local,
7871 linear_constraints_local: built.linear_constraints,
7872 kronecker_factored: built.kronecker_factored.take(),
7873 joint_null_rotation: applied_rotation,
7874 unabsorbed_global_orthogonality: None,
7875 });
7876
7877 col_start = col_end;
7878 }
7879
7880 assert_eq!(
7881 penalties_global.len(),
7882 nullspace_dims_global.len(),
7883 "global smooth penalty/nullspace bookkeeping diverged"
7884 );
7885 assert_eq!(
7886 penalties_global.len(),
7887 penaltyinfo_global.len(),
7888 "global smooth penalty metadata bookkeeping diverged"
7889 );
7890
7891 Ok(RawSmoothDesign {
7892 term_designs: local_designs,
7893 penalties: penalties_global,
7894 nullspace_dims: nullspace_dims_global,
7895 penaltyinfo: penaltyinfo_global,
7896 dropped_penaltyinfo: dropped_penaltyinfo_global,
7897 terms: terms_out,
7898 coefficient_lower_bounds: if any_bounds {
7899 Some(coefficient_lower_bounds)
7900 } else {
7901 None
7902 },
7903 linear_constraints: if linear_constraintsrows.is_empty() {
7904 None
7905 } else {
7906 let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
7907 for (i, (cs, ce, values)) in linear_constraintsrows.iter().enumerate() {
7908 a.row_mut(i).slice_mut(s![*cs..*ce]).assign(values);
7909 }
7910 Some(LinearInequalityConstraints {
7911 a,
7912 b: Array1::from_vec(linear_constraints_b),
7913 })
7914 },
7915 })
7916}