1use coefficient_transforms::{
2 convex_divided_difference_transform_matrix, cumulative_exp, cumulative_sum_transform_matrix,
3 second_cumulative_exp,
4};
5
6pub use error::SmoothError;
7
8use input_standardization::{
9 apply_input_standardization, compensate_length_scale_for_standardization,
10 compensate_optional_length_scale_for_standardization, compute_spatial_input_scales,
11};
12
13use shape_constraints::{
14 build_shape_constraint_design_1d, build_shape_linear_constraints_1d,
15 merge_linear_constraints_global, shape_lower_bounds_local, shape_order_and_sign,
16 shape_supports_basis, shape_uses_box_reparameterization,
17};
18
19pub fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
20 match strategy {
21 CenterStrategy::Auto(inner) => describe_thin_plate_center_request(inner),
22 CenterStrategy::UserProvided(centers) => format!("{} centers", centers.nrows()),
23 CenterStrategy::EqualMass { num_centers }
24 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
25 | CenterStrategy::FarthestPoint { num_centers }
26 | CenterStrategy::KMeans { num_centers, .. } => format!("{num_centers} centers"),
27 CenterStrategy::UniformGrid { points_per_dim } => {
28 format!("uniform grid with {points_per_dim} points per dimension")
29 }
30 }
31}
32
33pub fn rewrite_thin_plate_knots_error(
34 err: BasisError,
35 termname: &str,
36 feature_count: usize,
37 spec: &ThinPlateBasisSpec,
38) -> BasisError {
39 match err {
40 BasisError::InvalidInput(msg)
43 if msg.contains("thin-plate spline requires at least")
44 && (msg.contains("centers to span") || msg.contains("knots to span")) =>
45 {
46 let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
47 let requested = describe_thin_plate_center_request(&spec.center_strategy);
48 BasisError::InvalidInput(format!(
49 "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
50 ))
51 }
52 BasisError::InvalidInput(msg)
57 if msg.starts_with("requested ") && msg.contains(" knots but only ") =>
58 {
59 let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
60 let requested = describe_thin_plate_center_request(&spec.center_strategy);
61 BasisError::InvalidInput(format!(
62 "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
63 ))
64 }
65 other => other,
66 }
67}
68
/// Shape restriction that can be attached to a smooth term.
///
/// Parsed from user text by `parse_shape_constraint` and rendered back to its canonical
/// spelling by `ShapeConstraint::dsl_str`.
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
70pub enum ShapeConstraint {
    /// No shape restriction.
71 None,
    /// Canonical spelling `monotone_increasing`.
72 MonotoneIncreasing,
    /// Canonical spelling `monotone_decreasing`.
73 MonotoneDecreasing,
    /// Canonical spelling `convex`.
74 Convex,
    /// Canonical spelling `concave`.
75 Concave,
76}
77
78pub fn parse_shape_constraint(raw: &str) -> Result<ShapeConstraint, String> {
89 let normalized = raw.trim().to_ascii_lowercase().replace('-', "_");
90 match normalized.as_str() {
91 "" | "none" => Ok(ShapeConstraint::None),
92 "monotone_increasing" | "monotonic_increasing" | "increasing" | "mono_inc" | "mpi" => {
93 Ok(ShapeConstraint::MonotoneIncreasing)
94 }
95 "monotone_decreasing" | "monotonic_decreasing" | "decreasing" | "mono_dec" | "mpd" => {
96 Ok(ShapeConstraint::MonotoneDecreasing)
97 }
98 "convex" | "cvx" => Ok(ShapeConstraint::Convex),
99 "concave" | "ccv" => Ok(ShapeConstraint::Concave),
100 other => Err(format!(
101 "unknown shape constraint {other:?}; expected one of \
102 \"none\", \"monotone_increasing\", \"monotone_decreasing\", \
103 \"convex\", \"concave\""
104 )),
105 }
106}
107
108impl ShapeConstraint {
109 pub fn dsl_str(&self) -> &'static str {
112 match self {
113 ShapeConstraint::None => "none",
114 ShapeConstraint::MonotoneIncreasing => "monotone_increasing",
115 ShapeConstraint::MonotoneDecreasing => "monotone_decreasing",
116 ShapeConstraint::Convex => "convex",
117 ShapeConstraint::Concave => "concave",
118 }
119 }
120}
121
/// Function-call heads that `apply_shape_constraints_to_formula` treats as the start of a
/// smooth term when they sit on an identifier boundary and are followed (optionally after
/// whitespace) by `(`.
122pub const SMOOTH_HEAD_KEYWORDS: [&str; 11] = [
125 "s",
126 "smooth",
127 "te",
128 "tensor",
129 "thinplate",
130 "tps",
131 "duchon",
132 "matern",
133 "sphere",
134 "bs",
135 "bspline",
136];
137
138pub fn apply_shape_constraints_to_formula(
151 formula: &str,
152 constraints: &[(String, String)],
153) -> Result<String, String> {
154 use std::collections::{BTreeMap, BTreeSet};
155
156 if constraints.is_empty() {
157 return Ok(formula.to_string());
158 }
159 let strip_ws = |s: &str| -> String { s.chars().filter(|c| !c.is_whitespace()).collect() };
160
161 let mut wanted: BTreeMap<String, &'static str> = BTreeMap::new();
163 let mut originals: BTreeMap<String, String> = BTreeMap::new();
165 for (key, kind_raw) in constraints {
166 let kind = parse_shape_constraint(kind_raw)?;
167 let nk = strip_ws(key);
168 originals.entry(nk.clone()).or_insert_with(|| key.clone());
169 if kind != ShapeConstraint::None {
170 wanted.insert(nk, kind.dsl_str());
171 }
172 }
173 if wanted.is_empty() {
174 return Ok(formula.to_string());
175 }
176
177 let chars: Vec<char> = formula.chars().collect();
178 let n = chars.len();
179 let is_ident = |c: char| c.is_ascii_alphanumeric() || c == '_';
180
181 let mut out = String::with_capacity(formula.len() + 32);
182 let mut matched: BTreeSet<String> = BTreeSet::new();
183 let mut i = 0usize;
184 while i < n {
185 let mut head: Option<(usize, usize)> = None; let mut p = i;
189 while p < n {
190 let boundary = p == 0 || !is_ident(chars[p - 1]);
191 if boundary {
192 for kw in SMOOTH_HEAD_KEYWORDS.iter() {
193 let klen = kw.chars().count();
194 if p + klen > n || chars[p..p + klen].iter().collect::<String>() != **kw {
195 continue;
196 }
197 let mut q = p + klen;
198 while q < n && chars[q].is_whitespace() {
199 q += 1;
200 }
201 if q < n && chars[q] == '(' {
202 head = Some((p, q));
203 break;
204 }
205 }
206 }
207 if head.is_some() {
208 break;
209 }
210 p += 1;
211 }
212 let (head_start, paren_open) = match head {
213 Some(h) => h,
214 None => {
215 out.extend(chars[i..].iter());
216 break;
217 }
218 };
219 out.extend(chars[i..head_start].iter());
220
221 let body_start = paren_open + 1;
223 let mut depth = 1i32;
224 let mut j = body_start;
225 let mut in_str: Option<char> = None;
226 let mut closed = false;
227 while j < n {
228 let ch = chars[j];
229 if let Some(quote) = in_str {
230 if ch == quote {
231 in_str = None;
232 }
233 } else if ch == '\'' || ch == '"' {
234 in_str = Some(ch);
235 } else if ch == '(' {
236 depth += 1;
237 } else if ch == ')' {
238 depth -= 1;
239 if depth == 0 {
240 closed = true;
241 break;
242 }
243 }
244 j += 1;
245 }
246
247 if !closed {
248 out.extend(chars[head_start..].iter());
251 break;
252 }
253
254 let term_text: String = chars[head_start..=j].iter().collect();
255
256 let key_norm = strip_ws(&term_text);
257
258 match wanted.get(&key_norm) {
259 None => out.extend(chars[head_start..=j].iter()),
260 Some(kind) => {
261 let head_paren: String = chars[head_start..body_start].iter().collect();
262 let inside: String = chars[body_start..j].iter().collect();
263 let inside = inside.trim();
264 if inside.is_empty() {
265 out.push_str(&format!("{head_paren}shape={kind})"));
266 } else {
267 out.push_str(&format!("{head_paren}{inside}, shape={kind})"));
268 }
269 matched.insert(key_norm);
270 }
271 }
272
273 i = j + 1;
274 }
275
276 let mut missing: Vec<String> = wanted
277 .keys()
278 .filter(|k| !matched.contains(*k))
279 .map(|k| originals.get(k).cloned().unwrap_or_else(|| k.clone()))
280 .collect();
281
282 if !missing.is_empty() {
283 missing.sort();
284 return Err(format!(
285 "shape constraints referenced smooth term(s) not found in formula: {}",
286 missing.join(", ")
287 ));
288 }
289
290 Ok(out)
291}
292
/// Kind tag stored in the `kind` field of `SmoothBasisSpec::ByVariable`.
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub enum BySmoothKind {
    /// Numeric by-variable.
295 Numeric,
    /// A single level of a by-variable.
    /// NOTE(review): `level_bits` is presumably the `f64::to_bits` encoding of the level
    /// value — confirm against the code that builds this variant.
296 Level { level_bits: u64 },
297}
298
/// Serializable description of the basis behind one smooth term.
///
/// `ByVariable`, `FactorSumToZero` and `BySmooth` wrap another `SmoothBasisSpec`; the
/// remaining variants name a concrete basis family together with the data columns it reads.
/// `SmoothBasisSpec::structural_kind` gives a stable string tag per variant.
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub enum SmoothBasisSpec {
    /// Wrapper pairing an inner basis with a by-variable column (`by_col`).
301 ByVariable {
311 inner: Box<SmoothBasisSpec>,
312 by_col: usize,
313 kind: BySmoothKind,
314 by: ByVariableSpec,
315 },
    /// Wrapper pairing an inner basis with a factor column and its `levels`.
    /// `min_sample_rows` scales the inner requirement by `levels.len() - 1` (at least 1).
316 FactorSumToZero {
320 inner: Box<SmoothBasisSpec>,
321 by_col: usize,
322 levels: Vec<u64>,
    /// Optional matrix; absent in serialized data defaults to `None`.
323 #[serde(default)]
334 frozen_global_orthogonality: Option<Array2<f64>>,
335 },
    /// B-spline basis over the single column `feature_col`.
336 BSpline1D {
337 feature_col: usize,
338 spec: BSplineBasisSpec,
339 },
    /// Wrapper pairing an inner smooth with a by-variable description (`by_kind`).
340 BySmooth {
343 smooth: Box<SmoothBasisSpec>,
344 by_kind: ByVarKind,
345 },
    /// Factor smooth; all configuration lives in `FactorSmoothSpec`.
346 FactorSmooth { spec: FactorSmoothSpec },
    /// Thin-plate basis over `feature_cols`.
349 ThinPlate {
350 feature_cols: Vec<usize>,
351 spec: ThinPlateBasisSpec,
    /// Optional per-column scales; defaults to `None` when missing from serialized data.
352 #[serde(default)]
356 input_scales: Option<Vec<f64>>,
357 },
    /// Spherical spline basis over `feature_cols`.
358 Sphere {
359 feature_cols: Vec<usize>,
360 spec: SphericalSplineBasisSpec,
361 },
    /// Constant-curvature basis over `feature_cols`.
362 ConstantCurvature {
368 feature_cols: Vec<usize>,
369 spec: ConstantCurvatureBasisSpec,
370 },
    /// Matérn basis over `feature_cols`.
371 Matern {
372 feature_cols: Vec<usize>,
373 spec: MaternBasisSpec,
    /// Optional per-column scales; defaults to `None` when missing from serialized data.
374 #[serde(default)]
375 input_scales: Option<Vec<f64>>,
376 },
    /// MeasureJet basis over `feature_cols`.
377 MeasureJet {
383 feature_cols: Vec<usize>,
384 spec: MeasureJetBasisSpec,
    /// Optional per-column scales; defaults to `None` when missing from serialized data.
385 #[serde(default)]
386 input_scales: Option<Vec<f64>>,
387 },
    /// Duchon basis over `feature_cols`.
388 Duchon {
389 feature_cols: Vec<usize>,
390 spec: DuchonBasisSpec,
    /// Optional per-column scales; defaults to `None` when missing from serialized data.
391 #[serde(default)]
392 input_scales: Option<Vec<f64>>,
393 },
    /// Basis given by an explicit matrix; `min_sample_rows` uses its column count.
394 Pca {
395 feature_cols: Vec<usize>,
396 basis_matrix: Array2<f64>,
397 centered: bool,
    /// Defaults to `default_pca_smooth_penalty()` (1.0) when missing.
398 #[serde(default = "default_pca_smooth_penalty")]
399 smooth_penalty: f64,
400 #[serde(default)]
401 center_mean: Option<Array1<f64>>,
402 #[serde(default)]
403 pca_basis_path: Option<PathBuf>,
    /// Defaults to `default_pca_chunk_size()` (4096) when missing.
404 #[serde(default = "default_pca_chunk_size")]
405 chunk_size: usize,
406 },
    /// Tensor-product B-spline basis over `feature_cols`.
407 TensorBSpline {
412 feature_cols: Vec<usize>,
413 spec: TensorBSplineSpec,
414 },
415}
416
417impl SmoothBasisSpec {
418 pub fn min_sample_rows(&self) -> usize {
435 const RADIAL_FLOOR: usize = 5;
440
441 match self {
442 Self::ByVariable { inner, .. } => inner.min_sample_rows(),
443 Self::FactorSumToZero { inner, levels, .. } => {
444 let inner_min = inner.min_sample_rows();
448 let lvls = levels.len().saturating_sub(1).max(1);
449 inner_min.saturating_mul(lvls)
450 }
451 Self::BSpline1D { spec, .. } => bspline_basis_min_rows(spec),
452 Self::BySmooth { smooth, .. } => smooth.min_sample_rows(),
453 Self::FactorSmooth { spec } => {
454 bspline_basis_min_rows(&spec.marginal)
458 }
459 Self::ThinPlate { .. }
460 | Self::Sphere { .. }
461 | Self::ConstantCurvature { .. }
462 | Self::Matern { .. }
463 | Self::MeasureJet { .. }
464 | Self::Duchon { .. } => RADIAL_FLOOR,
465 Self::Pca { basis_matrix, .. } => basis_matrix.ncols().max(1),
466 Self::TensorBSpline { spec, .. } => {
467 let mut total: usize = 0;
513 for marginal in &spec.marginalspecs {
514 let m = bspline_basis_min_rows(marginal);
515 total = total.saturating_add(m.max(1));
516 }
517 total.max(RADIAL_FLOOR)
518 }
519 }
520 }
521
522 pub fn structural_kind(&self) -> &'static str {
533 match self {
534 Self::ByVariable { .. } => "by_variable",
535 Self::FactorSumToZero { .. } => "factor_sum_to_zero",
536 Self::BSpline1D { .. } => "bspline_1d",
537 Self::BySmooth { .. } => "by_smooth",
538 Self::FactorSmooth { .. } => "factor_smooth",
539 Self::ThinPlate { .. } => "thin_plate",
540 Self::Sphere { .. } => "sphere",
541 Self::ConstantCurvature { .. } => "constant_curvature",
542 Self::Matern { .. } => "matern",
543 Self::MeasureJet { .. } => "measurejet",
544 Self::Duchon { .. } => "duchon",
545 Self::Pca { .. } => "pca",
546 Self::TensorBSpline { .. } => "tensor_bspline",
547 }
548 }
549
550 pub fn is_marginally_centered_tensor(&self) -> bool {
559 matches!(
560 self,
561 Self::TensorBSpline { spec, .. }
562 if matches!(spec.identifiability, TensorBSplineIdentifiability::MarginalSumToZero)
563 )
564 }
565
566 pub fn is_sum_to_zero_factor_smooth(&self) -> bool {
583 matches!(
584 self,
585 Self::FactorSumToZero { .. }
586 | Self::FactorSmooth {
587 spec: FactorSmoothSpec {
588 flavour: FactorSmoothFlavour::Sz,
589 ..
590 }
591 }
592 )
593 }
594
595 pub fn structural_feature_cols(&self) -> Vec<usize> {
599 match self {
600 Self::ByVariable { inner, .. } | Self::FactorSumToZero { inner, .. } => {
601 inner.structural_feature_cols()
602 }
603 Self::BySmooth { smooth, .. } => smooth.structural_feature_cols(),
604 Self::FactorSmooth { .. } => Vec::new(),
605 Self::BSpline1D { feature_col, .. } => vec![*feature_col],
606 Self::ThinPlate { feature_cols, .. }
607 | Self::Sphere { feature_cols, .. }
608 | Self::ConstantCurvature { feature_cols, .. }
609 | Self::Matern { feature_cols, .. }
610 | Self::MeasureJet { feature_cols, .. }
611 | Self::Duchon { feature_cols, .. }
612 | Self::Pca { feature_cols, .. }
613 | Self::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
614 }
615 }
616}
617
618pub fn bspline_basis_min_rows(spec: &crate::basis::BSplineBasisSpec) -> usize {
643 use crate::basis::BSplineKnotSpec;
644 let columns = match &spec.knotspec {
645 BSplineKnotSpec::Generate {
646 num_internal_knots, ..
647 } => *num_internal_knots + spec.degree + 1,
648 BSplineKnotSpec::Automatic {
649 num_internal_knots: Some(k),
650 ..
651 } => *k + spec.degree + 1,
652 BSplineKnotSpec::Automatic {
653 num_internal_knots: None,
654 ..
655 } => {
656 spec.degree + 2
660 }
661 BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1).max(1),
662 BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
664 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
665 };
666 let columns = columns.max(spec.degree + 2);
667
668 if spec.double_penalty {
669 const DOUBLE_PENALTY_FLOOR: usize = 2;
672 DOUBLE_PENALTY_FLOOR.min(columns).max(1)
673 } else {
674 columns
675 }
676}
677
/// By-variable description stored in the `by` field of `SmoothBasisSpec::ByVariable`.
678#[derive(Debug, Clone, Serialize, Deserialize)]
679pub enum ByVariableSpec {
    /// Numeric by-variable.
680 Numeric,
    /// One labelled level of a by-variable.
    /// NOTE(review): `value_bits` is presumably the `f64::to_bits` encoding of the level
    /// value — confirm against the producer of this variant.
681 Level { value_bits: u64, label: String },
682}
683
684
/// By-variable description stored in the `by_kind` field of `SmoothBasisSpec::BySmooth`.
685#[derive(Debug, Clone, Serialize, Deserialize)]
686pub enum ByVarKind {
    /// Numeric by-variable read from column `feature_col`.
687 Numeric {
688 feature_col: usize,
689 },
    /// Factor by-variable read from column `feature_col`; `ordered` flags an ordered
    /// factor and `frozen_levels`, when present, carries a fixed level list.
690 Factor {
691 feature_col: usize,
692 ordered: bool,
693 frozen_levels: Option<Vec<u64>>,
694 },
695}
696
/// Configuration carried by `SmoothBasisSpec::FactorSmooth`.
697#[derive(Debug, Clone, Serialize, Deserialize)]
698pub struct FactorSmoothSpec {
    /// Continuous covariate columns.
699 pub continuous_cols: Vec<usize>,
    /// Column holding the grouping factor.
700 pub group_col: usize,
    /// Marginal B-spline specification; `SmoothBasisSpec::min_sample_rows` derives the
    /// row requirement of a factor smooth from it.
701 pub marginal: BSplineBasisSpec,
    /// Which factor-smooth flavour to build.
702 pub flavour: FactorSmoothFlavour,
    /// Fixed level list for the grouping factor, when available.
703 pub group_frozen_levels: Option<Vec<u64>>,
    /// Optional matrix; defaults to `None` when missing from serialized data.
704 #[serde(default)]
710 pub frozen_global_orthogonality: Option<Array2<f64>>,
711}
712
/// Flavour selector for a factor smooth.
713#[derive(Debug, Clone, Serialize, Deserialize)]
714pub enum FactorSmoothFlavour {
    /// `Fs` flavour, parameterized by a list of null-space penalty orders.
715 Fs { m_null_penalty_orders: Vec<usize> },
    /// `Sz` flavour; `SmoothBasisSpec::is_sum_to_zero_factor_smooth` returns `true` for it.
716 Sz,
    /// `Re` flavour.
717 Re,
718}
719
/// Configuration carried by `SmoothBasisSpec::TensorBSpline`.
720#[derive(Debug, Default, Clone, Serialize, Deserialize)]
721pub struct TensorBSplineSpec {
    /// One B-spline specification per marginal.
722 pub marginalspecs: Vec<BSplineBasisSpec>,
    /// Optional period per marginal; defaults to an empty list when missing.
723 #[serde(default)]
724 pub periods: Vec<Option<f64>>,
    /// Double-penalty flag.
725 pub double_penalty: bool,
    /// Identifiability treatment; defaults to `TensorBSplineIdentifiability::SumToZero`.
726 #[serde(default)]
727 pub identifiability: TensorBSplineIdentifiability,
    /// Penalty layout; defaults to `TensorBSplinePenaltyDecomposition::MarginalKroneckerSum`.
728 #[serde(default)]
729 pub penalty_decomposition: TensorBSplinePenaltyDecomposition,
730}
731
/// Identifiability treatment for a tensor-product B-spline term.
732#[derive(Debug, Default, Clone, Serialize, Deserialize)]
733pub enum TensorBSplineIdentifiability {
    /// No identifiability constraint.
734 None,
    /// The default choice.
735 #[default]
736 SumToZero,
    /// The variant `SmoothBasisSpec::is_marginally_centered_tensor` tests for.
737 MarginalSumToZero,
    /// Carries an explicit transform matrix.
747 FrozenTransform {
748 transform: Array2<f64>,
749 },
750}
751
/// Penalty layout selector for a tensor-product B-spline term.
752#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
753pub enum TensorBSplinePenaltyDecomposition {
    /// The default choice.
754 #[default]
757 MarginalKroneckerSum,
    /// Alternative layout.
758 Separable,
762}
763
/// Serializable specification of one smooth term: its name, basis, and shape constraint.
764#[derive(Debug, Clone, Serialize, Deserialize)]
765pub struct SmoothTermSpec {
    /// Term name, used in validation error messages.
766 pub name: String,
    /// Basis description.
767 pub basis: SmoothBasisSpec,
    /// Requested shape constraint.
768 pub shape: ShapeConstraint,
    /// Optional joint-null rotation; defaults to `None` when missing from serialized data.
769 #[serde(default)]
778 pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
779}
780
/// A built smooth term: its coefficient range plus local penalties, metadata and
/// constraints.
781#[derive(Debug, Clone)]
782pub struct SmoothTerm {
    /// Term name.
783 pub name: String,
    /// Coefficient index range of this term; its length is the local coefficient count
    /// used by `wald_unpenalized_dim`.
784 pub coeff_range: Range<usize>,
    /// Shape constraint attached to the term.
785 pub shape: ShapeConstraint,
    /// Penalty matrices in the term's local coefficient space.
786 pub penalties_local: Vec<Array2<f64>>,
    /// Null-space dimensions (presumably one per local penalty — TODO confirm).
787 pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the local penalties.
788 pub penaltyinfo_local: Vec<PenaltyInfo>,
    /// Basis metadata.
789 pub metadata: BasisMetadata,
    /// Optional lower bounds on the local coefficients.
790 pub lower_bounds_local: Option<Array1<f64>>,
    /// Optional linear inequality constraints on the local coefficients.
793 pub linear_constraints_local: Option<LinearInequalityConstraints>,
    /// Optional Kronecker-factored representation of the basis.
796 pub kronecker_factored: Option<KroneckerFactoredBasis>,
    /// Optional rotation that `apply_rotation_to_predict` multiplies a raw design by.
799 pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
    /// Optional matrix (NOTE(review): semantics not visible in this section).
822 pub unabsorbed_global_orthogonality: Option<Array2<f64>>,
832}
833
834impl SmoothTerm {
835 pub fn apply_rotation_to_predict(
851 &self,
852 x_new_raw: Array2<f64>,
853 ) -> Result<Array2<f64>, BasisError> {
854 let Some(rot) = self.joint_null_rotation.as_ref() else {
855 return Ok(x_new_raw);
856 };
857 let p_local = rot.rotation.nrows();
858 if x_new_raw.ncols() != p_local {
859 crate::bail_dim_basis!(
860 "joint-null rotation replay for term '{}': raw design has {} columns, \
861 rotation expects {} (the raw basis builder must emit the same column \
862 count as at fit time)",
863 self.name,
864 x_new_raw.ncols(),
865 p_local,
866 );
867 }
868 Ok(gam_linalg::faer_ndarray::fast_ab(
869 &x_new_raw,
870 &rot.rotation,
871 ))
872 }
873
874 pub fn wald_unpenalized_dim(&self) -> usize {
897 joint_unpenalized_dim(
898 self.coeff_range.len(),
899 &self.penalties_local,
900 &self.nullspace_dims,
901 )
902 }
903}
904
905pub fn joint_unpenalized_dim(
910 p_local: usize,
911 penalties_local: &[Array2<f64>],
912 nullspace_dims: &[usize],
913) -> usize {
914 use gam_linalg::faer_ndarray::FaerEigh;
915 if p_local == 0 {
916 return 0;
917 }
918 if penalties_local.is_empty() {
919 return p_local;
921 }
922 let mut s_total = Array2::<f64>::zeros((p_local, p_local));
927 let mut materialized = 0usize;
928 for s in penalties_local {
929 if s.nrows() == p_local && s.ncols() == p_local {
930 s_total += s;
931 materialized += 1;
932 }
933 }
934 if materialized == penalties_local.len() {
935 let symmetric = {
936 let transpose = s_total.t().to_owned();
937 (&s_total + &transpose) * 0.5
938 };
939 if let Ok((evals, _)) = symmetric.eigh(faer::Side::Lower) {
940 let max_abs = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
941 if max_abs == 0.0 {
942 return p_local;
944 }
945 let tol = max_abs * (p_local as f64) * 1e-12;
946 let rank = evals.iter().filter(|&&v| v > tol).count();
947 return p_local.saturating_sub(rank);
948 }
949 }
950 if penalties_local.len() >= 2 {
955 0
956 } else {
957 nullspace_dims
958 .iter()
959 .copied()
960 .min()
961 .unwrap_or(0)
962 .min(p_local)
963 }
964}
965
/// Descriptive record for one penalty block.
966#[derive(Debug, Clone, Serialize, Deserialize)]
967pub struct PenaltyBlockInfo {
    /// Index of the block (presumably within the design's penalty list — TODO confirm).
968 pub global_index: usize,
    /// Name of the owning term, when there is one.
969 pub termname: Option<String>,
    /// Penalty description.
970 pub penalty: PenaltyInfo,
971}
972
/// Descriptive record for a penalty block listed under `dropped_penaltyinfo`; unlike
/// `PenaltyBlockInfo` it carries no `global_index`.
973#[derive(Debug, Clone, Serialize, Deserialize)]
974pub struct DroppedPenaltyBlockInfo {
    /// Name of the owning term, when there is one.
975 pub termname: Option<String>,
    /// Penalty description.
976 pub penalty: PenaltyInfo,
977}
978
/// Assembled smooth design: per-term design matrices plus penalties, term records and
/// coefficient constraints. Can be built from a `RawSmoothDesign` via `From`.
979#[derive(Debug, Clone)]
980pub struct SmoothDesign {
    /// One design matrix per term; `total_smooth_cols` sums their column counts and
    /// `nrows` reads the row count of the first.
981 pub term_designs: Vec<DesignMatrix>,
    /// Penalty blocks.
982 pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions (presumably one per penalty — TODO confirm).
985 pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the penalty blocks.
986 pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Descriptive info for dropped penalty blocks.
987 pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Built smooth terms.
988 pub terms: Vec<SmoothTerm>,
    /// Optional lower bounds on the coefficients.
989 pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional linear inequality constraints on the coefficients.
992 pub linear_constraints: Option<LinearInequalityConstraints>,
995}
996
997impl SmoothDesign {
998 pub fn total_smooth_cols(&self) -> usize {
999 self.term_designs.iter().map(DesignMatrix::ncols).sum()
1000 }
1001 pub fn nrows(&self) -> usize {
1002 self.term_designs.first().map_or(0, DesignMatrix::nrows)
1003 }
1004}
1005
/// Smooth design with the same fields as `SmoothDesign`; converts into it field by field
/// through `From<RawSmoothDesign> for SmoothDesign`.
1006#[derive(Debug, Clone)]
1007pub struct RawSmoothDesign {
    /// One design matrix per term.
1008 pub term_designs: Vec<DesignMatrix>,
    /// Penalty blocks.
1009 pub penalties: Vec<BlockwisePenalty>,
    /// Null-space dimensions (presumably one per penalty — TODO confirm).
1012 pub nullspace_dims: Vec<usize>,
    /// Descriptive info for the penalty blocks.
1013 pub penaltyinfo: Vec<PenaltyBlockInfo>,
    /// Descriptive info for dropped penalty blocks.
1014 pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Built smooth terms.
1015 pub terms: Vec<SmoothTerm>,
    /// Optional lower bounds on the coefficients.
1016 pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional linear inequality constraints on the coefficients.
1017 pub linear_constraints: Option<LinearInequalityConstraints>,
1018}
1019
1020impl RawSmoothDesign {
1021 pub fn total_smooth_cols(&self) -> usize {
1022 self.term_designs.iter().map(DesignMatrix::ncols).sum()
1023 }
1024 pub fn nrows(&self) -> usize {
1025 self.term_designs.first().map_or(0, DesignMatrix::nrows)
1026 }
1027}
1028
1029impl From<RawSmoothDesign> for SmoothDesign {
1030 fn from(value: RawSmoothDesign) -> Self {
1031 Self {
1032 term_designs: value.term_designs,
1033 penalties: value.penalties,
1034 nullspace_dims: value.nullspace_dims,
1035 penaltyinfo: value.penaltyinfo,
1036 dropped_penaltyinfo: value.dropped_penaltyinfo,
1037 terms: value.terms,
1038 coefficient_lower_bounds: value.coefficient_lower_bounds,
1039 linear_constraints: value.linear_constraints,
1040 }
1041 }
1042}
1043
/// Prior choice for a bounded linear coefficient (see `LinearCoefficientGeometry::Bounded`).
1044#[derive(Debug, Default, Clone, Serialize, Deserialize)]
1045pub enum BoundedCoefficientPriorSpec {
    /// The default: no prior.
1046 #[default]
1047 None,
    /// Uniform prior.
1048 Uniform,
    /// Beta prior with parameters `a` and `b`; `TermCollectionSpec::validate_frozen`
    /// rejects values that are non-finite or below 1.0.
1049 Beta {
1050 a: f64,
1051 b: f64,
1052 },
1053}
1054
/// Geometry of a linear coefficient: free, or confined to an interval.
1055#[derive(Debug, Clone, Serialize, Deserialize, Default)]
1056pub enum LinearCoefficientGeometry {
    /// The default: no bounds.
1057 #[default]
1058 Unconstrained,
    /// Coefficient confined to `[min, max]`; `TermCollectionSpec::validate_frozen`
    /// requires both to be finite with `min < max`.
1059 Bounded {
1060 min: f64,
1061 max: f64,
    /// Prior over the interval; defaults to `BoundedCoefficientPriorSpec::None`.
1062 #[serde(default)]
1063 prior: BoundedCoefficientPriorSpec,
1064 },
1065}
1066
/// Serializable specification of one linear (parametric) term.
1067#[derive(Debug, Clone, Serialize, Deserialize)]
1068pub struct LinearTermSpec {
    /// Term name, used in error messages.
1069 pub name: String,
    /// Single data column; used by `effective_feature_cols` when `feature_cols` is empty.
1070 pub feature_col: usize,
    /// Data columns multiplied together by `realized_design_column`; defaults to empty.
1076 #[serde(default)]
1079 pub feature_cols: Vec<usize>,
    /// `(column, level_bits)` gates: `realized_design_column` zeroes every row whose
    /// value in `column` does not have exactly these `f64` bits. Defaults to empty.
1080 #[serde(default)]
1091 pub categorical_levels: Vec<(usize, u64)>,
    /// Double-penalty flag; defaults to `default_linear_term_double_penalty()` (false).
1092 #[serde(default = "default_linear_term_double_penalty")]
1099 pub double_penalty: bool,
    /// Coefficient geometry; defaults to `LinearCoefficientGeometry::Unconstrained`.
1100 #[serde(default)]
1101 pub coefficient_geometry: LinearCoefficientGeometry,
    /// Optional lower coefficient limit; must be finite when set (checked by
    /// `TermCollectionSpec::validate_frozen`).
1102 #[serde(default)]
1103 pub coefficient_min: Option<f64>,
    /// Optional upper coefficient limit; must be finite when set (checked by
    /// `TermCollectionSpec::validate_frozen`).
1104 #[serde(default)]
1105 pub coefficient_max: Option<f64>,
1106}
1107
1108impl LinearTermSpec {
1109 pub fn effective_feature_cols(&self) -> Vec<usize> {
1112 if self.feature_cols.is_empty() {
1113 vec![self.feature_col]
1114 } else {
1115 self.feature_cols.clone()
1116 }
1117 }
1118
1119 pub fn is_interaction(&self) -> bool {
1121 self.feature_cols.len() > 1 || !self.categorical_levels.is_empty()
1122 }
1123
1124 pub fn realized_design_column(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
1136 let n = data.nrows();
1137 let p = data.ncols();
1138 let bounds = |col: usize| -> Result<(), String> {
1139 if col >= p {
1140 Err(format!(
1141 "linear term '{}' feature column {} out of bounds for {} columns",
1142 self.name, col, p
1143 ))
1144 } else {
1145 Ok(())
1146 }
1147 };
1148
1149 let mut column = if self.categorical_levels.is_empty() {
1154 let cols = self.effective_feature_cols();
1155 for &c in &cols {
1156 bounds(c)?;
1157 }
1158 let mut acc = data.column(cols[0]).to_owned();
1159 for &c in cols.iter().skip(1) {
1160 acc *= &data.column(c);
1161 }
1162 acc
1163 } else {
1164 let mut acc = Array1::<f64>::ones(n);
1165 for &c in &self.feature_cols {
1166 bounds(c)?;
1167 acc *= &data.column(c);
1168 }
1169 acc
1170 };
1171
1172 for &(col, level_bits) in &self.categorical_levels {
1173 bounds(col)?;
1174 let gate = data.column(col);
1175 for (out, &v) in column.iter_mut().zip(gate.iter()) {
1176 if v.to_bits() != level_bits {
1177 *out = 0.0;
1178 }
1179 }
1180 }
1181
1182 Ok(column)
1183 }
1184}
1185
/// Serde default for `LinearTermSpec::double_penalty`: `false`.
1186pub const fn default_linear_term_double_penalty() -> bool {
1187 false
1195}
1196
/// Serde default for the `smooth_penalty` field of `SmoothBasisSpec::Pca`: `1.0`.
1197pub const fn default_pca_smooth_penalty() -> f64 {
1198 1.0
1199}
1200
/// Serde default for the `chunk_size` field of `SmoothBasisSpec::Pca`: `4096`.
1201pub const fn default_pca_chunk_size() -> usize {
1202 4096
1203}
1204
/// Serializable specification of one random-effect term.
1205#[derive(Debug, Clone, Serialize, Deserialize)]
1211pub struct RandomEffectTermSpec {
    /// Term name.
1212 pub name: String,
    /// Data column holding the grouping variable.
1213 pub feature_col: usize,
    /// Whether the first level is dropped.
1214 pub drop_first_level: bool,
    /// Penalization flag; defaults to `default_random_effect_penalized()` (true).
1217 #[serde(default = "default_random_effect_penalized")]
1221 pub penalized: bool,
    /// Optional fixed level list; defaults to `None` when missing from serialized data.
1222 #[serde(default)]
1225 pub frozen_levels: Option<Vec<u64>>,
1226}
1227
/// Serde default for `RandomEffectTermSpec::penalized`: `true`.
1228pub fn default_random_effect_penalized() -> bool {
1229 true
1230}
1231
1232pub fn validate_measure_jet_positive_vec_len(
1233 label: &str,
1234 term_name: &str,
1235 field: &str,
1236 values: &[f64],
1237 expected: usize,
1238) -> Result<(), String> {
1239 if values.len() != expected {
1240 return Err(SmoothError::invalid_config(format!(
1241 "{label} term '{term_name}' frozen MeasureJet {field} has length {}, expected {expected}",
1242 values.len()
1243 ))
1244 .into());
1245 }
1246 if values
1247 .iter()
1248 .any(|value| !(value.is_finite() && *value > 0.0))
1249 {
1250 return Err(SmoothError::invalid_config(format!(
1251 "{label} term '{term_name}' frozen MeasureJet {field} values must be positive and finite"
1252 ))
1253 .into());
1254 }
1255 Ok(())
1256}
1257
/// Serializable collection of all term specifications: linear, random-effect and smooth.
1258#[derive(Debug, Clone, Serialize, Deserialize)]
1259pub struct TermCollectionSpec {
    /// Linear (parametric) terms.
1260 pub linear_terms: Vec<LinearTermSpec>,
    /// Random-effect terms.
1261 pub random_effect_terms: Vec<RandomEffectTermSpec>,
    /// Smooth terms.
1262 pub smooth_terms: Vec<SmoothTermSpec>,
1263}
1264
1265pub fn validate_smooth_basis_frozen(
1266 basis: &SmoothBasisSpec,
1267 label: &str,
1268 term_name: &str,
1269) -> Result<(), String> {
1270 match basis {
1271 SmoothBasisSpec::ByVariable { inner, .. }
1272 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
1273 validate_smooth_basis_frozen(inner, label, term_name)
1274 }
1275 SmoothBasisSpec::BSpline1D { spec, .. } => {
1276 if !matches!(
1277 spec.knotspec,
1278 BSplineKnotSpec::Provided(_)
1279 | BSplineKnotSpec::PeriodicUniform { .. }
1280 | BSplineKnotSpec::NaturalCubicRegression { .. }
1281 ) {
1282 return Err(format!(
1283 "{label} term '{term_name}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression"
1284 ));
1285 }
1286 Ok(())
1287 }
1288 SmoothBasisSpec::ThinPlate { spec, .. } => {
1289 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1290 return Err(format!(
1291 "{label} term '{term_name}' is not frozen: ThinPlate centers must be UserProvided"
1292 ));
1293 }
1294 if matches!(
1295 spec.identifiability,
1296 SpatialIdentifiability::OrthogonalToParametric
1297 ) {
1298 return Err(format!(
1299 "{label} term '{term_name}' is not frozen: ThinPlate identifiability must be FrozenTransform or None"
1300 ));
1301 }
1302 Ok(())
1303 }
1304 _ => Ok(()),
1305 }
1306}
1307
1308impl TermCollectionSpec {
1309 pub fn write_structural_shape_hash(&self, h: &mut gam_runtime::warm_start::Fingerprinter) {
1323 h.write_str("term-collection");
1324 h.write_usize(self.linear_terms.len());
1325 for linear in &self.linear_terms {
1326 h.write_str(&linear.name);
1327 }
1328 h.write_usize(self.random_effect_terms.len());
1329 h.write_usize(self.smooth_terms.len());
1330 for smooth in &self.smooth_terms {
1331 h.write_str(&smooth.name);
1332 h.write_str(smooth.basis.structural_kind());
1333 for col in smooth.basis.structural_feature_cols() {
1334 h.write_usize(col);
1335 }
1336 }
1337 }
1338
1339 pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
1343 for linear in &self.linear_terms {
1344 if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
1345 && (!min.is_finite() || !max.is_finite() || min > max)
1346 {
1347 return Err(SmoothError::invalid_config(format!(
1348 "{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
1349 linear.name
1350 ))
1351 .into());
1352 }
1353 if let Some(min) = linear.coefficient_min
1354 && !min.is_finite()
1355 {
1356 return Err(SmoothError::invalid_config(format!(
1357 "{label} linear term '{}' has non-finite coefficient minimum {min}",
1358 linear.name
1359 ))
1360 .into());
1361 }
1362 if let Some(max) = linear.coefficient_max
1363 && !max.is_finite()
1364 {
1365 return Err(SmoothError::invalid_config(format!(
1366 "{label} linear term '{}' has non-finite coefficient maximum {max}",
1367 linear.name
1368 ))
1369 .into());
1370 }
1371 if let LinearCoefficientGeometry::Bounded { min, max, prior } =
1372 &linear.coefficient_geometry
1373 {
1374 if !min.is_finite() || !max.is_finite() || min >= max {
1375 return Err(SmoothError::invalid_config(format!(
1376 "{label} bounded term '{}' has invalid bounds [{min}, {max}]",
1377 linear.name
1378 ))
1379 .into());
1380 }
1381 match prior {
1382 BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
1383 BoundedCoefficientPriorSpec::Beta { a, b } => {
1384 if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
1385 return Err(SmoothError::invalid_config(format!(
1386 "{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
1387 linear.name
1388 ))
1389 .into());
1390 }
1391 }
1392 }
1393 }
1394 }
1395 for st in &self.smooth_terms {
1396 match &st.basis {
1397 SmoothBasisSpec::ByVariable { inner, .. } => {
1398 validate_smooth_basis_frozen(inner, label, &st.name)?;
1399 let nested = SmoothTermSpec {
1400 name: st.name.clone(),
1401 basis: (**inner).clone(),
1402 shape: st.shape,
1403 joint_null_rotation: None,
1404 };
1405 TermCollectionSpec {
1406 linear_terms: Vec::new(),
1407 random_effect_terms: Vec::new(),
1408 smooth_terms: vec![nested],
1409 }
1410 .validate_frozen(label)?;
1411 }
1412 SmoothBasisSpec::FactorSumToZero { inner, levels, .. } => {
1413 if levels.len() < 2 {
1414 return Err(format!(
1415 "{label} term '{}' has invalid frozen sz levels",
1416 st.name
1417 ));
1418 }
1419 validate_smooth_basis_frozen(inner, label, &st.name)?;
1420 }
1421 SmoothBasisSpec::BSpline1D { spec, .. } => {
1422 if !matches!(
1423 spec.knotspec,
1424 BSplineKnotSpec::Provided(_)
1425 | BSplineKnotSpec::PeriodicUniform { .. }
1426 | BSplineKnotSpec::NaturalCubicRegression { .. }
1427 ) {
1428 return Err(SmoothError::invalid_config(format!(
1429 "{label} term '{}' is not frozen: BSpline knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1430 st.name
1431 ))
1432 .into());
1433 }
1434 }
1435 SmoothBasisSpec::ThinPlate { spec, .. } => {
1436 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1437 return Err(SmoothError::invalid_config(format!(
1438 "{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
1439 st.name
1440 ))
1441 .into());
1442 }
1443 if matches!(
1444 spec.identifiability,
1445 SpatialIdentifiability::OrthogonalToParametric
1446 ) {
1447 return Err(SmoothError::invalid_config(format!(
1448 "{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
1449 st.name
1450 ))
1451 .into());
1452 }
1453 }
1454 SmoothBasisSpec::Sphere { spec, .. } => {
1455 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1456 return Err(SmoothError::invalid_config(format!(
1457 "{label} term '{}' is not frozen: Sphere centers must be UserProvided",
1458 st.name
1459 ))
1460 .into());
1461 }
1462 if matches!(spec.method, crate::basis::SphereMethod::Harmonic)
1463 && spec.max_degree.is_none_or(|d| d == 0)
1464 {
1465 return Err(format!(
1466 "{label} term '{}' is not frozen: sphere max_degree must be positive",
1467 st.name
1468 ));
1469 }
1470 }
1471 SmoothBasisSpec::ConstantCurvature { spec, .. } => {
1472 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1473 return Err(SmoothError::invalid_config(format!(
1474 "{label} term '{}' is not frozen: ConstantCurvature centers must be UserProvided",
1475 st.name
1476 ))
1477 .into());
1478 }
1479 if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1480 return Err(SmoothError::invalid_config(format!(
1481 "{label} term '{}' is not frozen: ConstantCurvature length_scale must be the realized positive value",
1482 st.name
1483 ))
1484 .into());
1485 }
1486 }
1487 SmoothBasisSpec::MeasureJet { spec, .. } => {
1488 let centers = match &spec.center_strategy {
1489 CenterStrategy::UserProvided(centers) => centers,
1490 _ => {
1491 return Err(SmoothError::invalid_config(format!(
1492 "{label} term '{}' is not frozen: MeasureJet centers must be UserProvided",
1493 st.name
1494 ))
1495 .into());
1496 }
1497 };
1498 if centers.nrows() == 0 {
1499 return Err(SmoothError::invalid_config(format!(
1500 "{label} term '{}' is not frozen: MeasureJet centers are empty",
1501 st.name
1502 ))
1503 .into());
1504 }
1505 if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
1506 return Err(SmoothError::invalid_config(format!(
1507 "{label} term '{}' is not frozen: MeasureJet length_scale must be the realized positive value",
1508 st.name
1509 ))
1510 .into());
1511 }
1512 let frozen = spec.frozen_quadrature.as_ref().ok_or_else(|| {
1515 SmoothError::invalid_config(format!(
1516 "{label} term '{}' is not frozen: MeasureJet frozen_quadrature payload is missing",
1517 st.name
1518 ))
1519 })?;
1520 if frozen.masses.len() != centers.nrows() {
1521 return Err(SmoothError::invalid_config(format!(
1522 "{label} term '{}' frozen MeasureJet has {} masses for {} centers",
1523 st.name,
1524 frozen.masses.len(),
1525 centers.nrows()
1526 ))
1527 .into());
1528 }
1529 let total_mass = frozen.masses.sum();
1530 if frozen
1531 .masses
1532 .iter()
1533 .any(|mass| !(mass.is_finite() && *mass >= 0.0))
1534 || !(total_mass.is_finite() && total_mass > 0.0)
1535 {
1536 return Err(SmoothError::invalid_config(format!(
1537 "{label} term '{}' frozen MeasureJet masses must be finite, nonnegative, and have positive total mass",
1538 st.name
1539 ))
1540 .into());
1541 }
1542 let n_levels = frozen.eps_band.len();
1543 if n_levels == 0
1544 || frozen
1545 .eps_band
1546 .iter()
1547 .any(|eps| !(eps.is_finite() && *eps > 0.0))
1548 {
1549 return Err(SmoothError::invalid_config(format!(
1550 "{label} term '{}' frozen MeasureJet eps_band must be nonempty, finite, and positive",
1551 st.name
1552 ))
1553 .into());
1554 }
1555 for (idx, pair) in frozen.eps_band.windows(2).enumerate() {
1556 if pair[1] <= pair[0] {
1557 return Err(SmoothError::invalid_config(format!(
1558 "{label} term '{}' frozen MeasureJet eps_band is not strictly ascending at {idx}: {} then {}",
1559 st.name,
1560 pair[0],
1561 pair[1]
1562 ))
1563 .into());
1564 }
1565 }
1566 validate_measure_jet_positive_vec_len(
1567 label,
1568 &st.name,
1569 "support_means",
1570 &frozen.support_means,
1571 n_levels,
1572 )?;
1573 let per_level = crate::basis::measure_jet_multiscale_mode(spec);
1581 if per_level {
1582 validate_measure_jet_positive_vec_len(
1583 label,
1584 &st.name,
1585 "penalty_normalization_scales",
1586 &frozen.penalty_normalization_scales,
1587 n_levels,
1588 )?;
1589 validate_measure_jet_positive_vec_len(
1590 label,
1591 &st.name,
1592 "raw_penalty_normalization_scales",
1593 &frozen.raw_penalty_normalization_scales,
1594 n_levels,
1595 )?;
1596 if frozen.fused_penalty_normalization_scale.is_some() {
1597 return Err(SmoothError::invalid_config(format!(
1598 "{label} term '{}' per-level MeasureJet must not carry a fused penalty normalization scale",
1599 st.name
1600 ))
1601 .into());
1602 }
1603 } else {
1604 if !frozen.penalty_normalization_scales.is_empty()
1605 || !frozen.raw_penalty_normalization_scales.is_empty()
1606 {
1607 return Err(SmoothError::invalid_config(format!(
1608 "{label} term '{}' fused MeasureJet must not carry per-level penalty normalization scales",
1609 st.name
1610 ))
1611 .into());
1612 }
1613 match frozen.fused_penalty_normalization_scale {
1614 Some(scale) if scale.is_finite() && scale > 0.0 => {}
1615 Some(scale) => {
1616 return Err(SmoothError::invalid_config(format!(
1617 "{label} term '{}' fused MeasureJet penalty normalization scale must be positive and finite, got {scale}",
1618 st.name
1619 ))
1620 .into());
1621 }
1622 None => {
1623 return Err(SmoothError::invalid_config(format!(
1624 "{label} term '{}' fused MeasureJet is missing its penalty normalization scale",
1625 st.name
1626 ))
1627 .into());
1628 }
1629 }
1630 }
1631 }
1632 SmoothBasisSpec::Matern { spec, .. } => {
1633 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1634 return Err(SmoothError::invalid_config(format!(
1635 "{label} term '{}' is not frozen: Matern centers must be UserProvided",
1636 st.name
1637 ))
1638 .into());
1639 }
1640 }
1641 SmoothBasisSpec::Duchon { spec, .. } => {
1642 if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
1643 return Err(SmoothError::invalid_config(format!(
1644 "{label} term '{}' is not frozen: Duchon centers must be UserProvided",
1645 st.name
1646 ))
1647 .into());
1648 }
1649 if matches!(
1650 spec.identifiability,
1651 SpatialIdentifiability::OrthogonalToParametric
1652 ) {
1653 return Err(SmoothError::invalid_config(format!(
1654 "{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
1655 st.name
1656 ))
1657 .into());
1658 }
1659 }
1660 SmoothBasisSpec::Pca {
1661 centered,
1662 center_mean,
1663 pca_basis_path,
1664 ..
1665 } => {
1666 if *centered && center_mean.is_none() && pca_basis_path.is_none() {
1667 return Err(SmoothError::invalid_config(format!(
1668 "{label} term '{}' is not frozen: centered Pca missing center_mean",
1669 st.name
1670 ))
1671 .into());
1672 }
1673 }
1674 SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1675 if let SmoothBasisSpec::BySmooth { .. } = smooth.as_ref() {
1676 return Err(format!("{label} term '{}' has nested by-smooths", st.name));
1677 }
1678 match by_kind {
1679 ByVarKind::Numeric { .. } => {}
1680 ByVarKind::Factor { frozen_levels, .. } if frozen_levels.is_none() => {
1681 return Err(format!(
1682 "{label} term '{}' is not frozen: by-factor levels missing",
1683 st.name
1684 ));
1685 }
1686 ByVarKind::Factor { .. } => {}
1687 }
1688 let nested = TermCollectionSpec {
1689 linear_terms: vec![],
1690 random_effect_terms: vec![],
1691 smooth_terms: vec![SmoothTermSpec {
1692 name: st.name.clone(),
1693 basis: (**smooth).clone(),
1694 shape: st.shape,
1695 joint_null_rotation: None,
1696 }],
1697 };
1698 nested.validate_frozen(label)?;
1699 }
1700 SmoothBasisSpec::FactorSmooth { spec } => {
1701 if spec.group_frozen_levels.is_none() {
1702 return Err(format!(
1703 "{label} term '{}' is not frozen: factor-smooth levels missing",
1704 st.name
1705 ));
1706 }
1707 if !matches!(
1708 spec.marginal.knotspec,
1709 BSplineKnotSpec::Provided(_)
1710 | BSplineKnotSpec::PeriodicUniform { .. }
1711 | BSplineKnotSpec::NaturalCubicRegression { .. }
1723 ) {
1724 return Err(format!(
1725 "{label} term '{}' is not frozen: factor-smooth marginal knots missing",
1726 st.name
1727 ));
1728 }
1729 }
1730 SmoothBasisSpec::TensorBSpline { spec, .. } => {
1731 for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
1732 if !matches!(
1733 marginal.knotspec,
1734 BSplineKnotSpec::Provided(_)
1735 | BSplineKnotSpec::PeriodicUniform { .. }
1736 | BSplineKnotSpec::NaturalCubicRegression { .. }
1737 ) {
1738 return Err(SmoothError::invalid_config(format!(
1739 "{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided, PeriodicUniform, or NaturalCubicRegression",
1740 st.name, dim
1741 ))
1742 .into());
1743 }
1744 }
1745 if matches!(
1746 spec.identifiability,
1747 TensorBSplineIdentifiability::SumToZero
1748 | TensorBSplineIdentifiability::MarginalSumToZero
1749 ) {
1750 return Err(SmoothError::invalid_config(format!(
1751 "{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
1752 st.name
1753 ))
1754 .into());
1755 }
1756 }
1757 }
1758 }
1759
1760 for rt in &self.random_effect_terms {
1761 if rt.frozen_levels.is_none() {
1762 return Err(SmoothError::invalid_config(format!(
1763 "{label} random-effect term '{}' is not frozen: missing frozen_levels",
1764 rt.name
1765 ))
1766 .into());
1767 }
1768 }
1769
1770 Ok(())
1771 }
1772
1773 pub fn remap_feature_columns<E, F>(&self, mut remap: F) -> Result<TermCollectionSpec, E>
1792 where
1793 F: FnMut(usize) -> Result<usize, E>,
1794 {
1795 let mut out = self.clone();
1796 for lt in &mut out.linear_terms {
1797 lt.feature_col = remap(lt.feature_col)?;
1798 for fc in lt.feature_cols.iter_mut() {
1808 *fc = remap(*fc)?;
1809 }
1810 for (col, _bits) in lt.categorical_levels.iter_mut() {
1815 *col = remap(*col)?;
1816 }
1817 }
1818 for rt in &mut out.random_effect_terms {
1819 rt.feature_col = remap(rt.feature_col)?;
1820 }
1821 for st in &mut out.smooth_terms {
1822 remap_smooth_basis_feature_columns(&mut st.basis, &mut remap)?;
1823 }
1824 Ok(out)
1825 }
1826}
1827
1828pub fn remap_smooth_basis_feature_columns<E, F>(
1833 basis: &mut SmoothBasisSpec,
1834 remap: &mut F,
1835) -> Result<(), E>
1836where
1837 F: FnMut(usize) -> Result<usize, E>,
1838{
1839 match basis {
1840 SmoothBasisSpec::ByVariable { inner, by_col, .. }
1841 | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
1842 *by_col = remap(*by_col)?;
1843 remap_smooth_basis_feature_columns(inner, remap)?;
1844 }
1845 SmoothBasisSpec::BSpline1D { feature_col, .. } => {
1846 *feature_col = remap(*feature_col)?;
1847 }
1848 SmoothBasisSpec::BySmooth { smooth, by_kind } => {
1849 let by_feature_col = match by_kind {
1850 ByVarKind::Numeric { feature_col } | ByVarKind::Factor { feature_col, .. } => {
1851 feature_col
1852 }
1853 };
1854 *by_feature_col = remap(*by_feature_col)?;
1855 remap_smooth_basis_feature_columns(smooth, remap)?;
1856 }
1857 SmoothBasisSpec::FactorSmooth { spec } => {
1858 for fc in spec.continuous_cols.iter_mut() {
1859 *fc = remap(*fc)?;
1860 }
1861 spec.group_col = remap(spec.group_col)?;
1862 }
1863 SmoothBasisSpec::ThinPlate { feature_cols, .. }
1864 | SmoothBasisSpec::Sphere { feature_cols, .. }
1865 | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
1866 | SmoothBasisSpec::Matern { feature_cols, .. }
1867 | SmoothBasisSpec::MeasureJet { feature_cols, .. }
1868 | SmoothBasisSpec::Duchon { feature_cols, .. }
1869 | SmoothBasisSpec::Pca { feature_cols, .. }
1870 | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
1871 for fc in feature_cols.iter_mut() {
1872 *fc = remap(*fc)?;
1873 }
1874 }
1875 }
1876 Ok(())
1877}
1878
/// Optional description of exploitable structure attached to a [`BlockwisePenalty`].
#[derive(Debug, Clone)]
pub enum PenaltyStructureHint {
    /// Scaled-identity block. The payload is the value that `BlockwisePenalty::ridge`
    /// writes on the diagonal of the local matrix.
    Ridge(f64),
    /// The factor matrices passed to `BlockwisePenalty::kronecker`.
    // NOTE(review): presumably the local matrix is assembled from Kronecker products of
    // these factors — confirm against the code that builds them.
    Kronecker(Vec<Array2<f64>>),
}
1884
/// A penalty matrix stored only on the coefficient block it touches.
///
/// `local` is a square matrix whose side equals `col_range.len()`; `to_global` embeds it at
/// rows/columns `col_range` of an otherwise zero `p_total × p_total` matrix.
#[derive(Clone)]
pub struct BlockwisePenalty {
    /// Global coefficient columns covered by this block.
    pub col_range: Range<usize>,
    /// Dense block-local penalty matrix (`col_range.len()` × `col_range.len()`).
    pub local: Array2<f64>,
    /// Prior mean for the block's coefficients; the constructors in this file default it to
    /// `CoefficientPriorMean::Zero`.
    pub prior_mean: gam_problem::CoefficientPriorMean,
    /// Structure hint set by the `ridge` / `kronecker` constructors; `None` otherwise.
    pub structure_hint: Option<PenaltyStructureHint>,
    /// Optional operator form of the penalty, attached via `with_op`.
    // NOTE(review): how `op` relates to `local` when both are present is not visible here.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
1909
1910impl std::fmt::Debug for BlockwisePenalty {
1911 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1912 f.debug_struct("BlockwisePenalty")
1913 .field("col_range", &self.col_range)
1914 .field(
1915 "local",
1916 &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
1917 )
1918 .field("prior_mean", &self.prior_mean)
1919 .field("structure_hint", &self.structure_hint)
1920 .field("op", &self.op.as_ref().map(|o| o.dim()))
1921 .finish()
1922 }
1923}
1924
1925impl BlockwisePenalty {
1926 pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
1928 assert_eq!(col_range.len(), local.nrows());
1929 assert_eq!(col_range.len(), local.ncols());
1930 Self {
1931 col_range,
1932 local,
1933 prior_mean: gam_problem::CoefficientPriorMean::Zero,
1934 structure_hint: None,
1935 op: None,
1936 }
1937 }
1938
1939 pub fn with_prior_mean(
1940 mut self,
1941 prior_mean: gam_problem::CoefficientPriorMean,
1942 ) -> Self {
1943 self.prior_mean = prior_mean;
1944 self
1945 }
1946
1947 pub fn with_op(
1949 mut self,
1950 op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
1951 ) -> Self {
1952 self.op = op;
1953 self
1954 }
1955
1956 pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
1957 let block_size = col_range.len();
1958 let mut local = Array2::<f64>::zeros((block_size, block_size));
1959 for i in 0..block_size {
1960 local[[i, i]] = scale;
1961 }
1962 Self {
1963 col_range,
1964 local,
1965 prior_mean: gam_problem::CoefficientPriorMean::Zero,
1966 structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
1967 op: None,
1968 }
1969 }
1970
1971 pub fn kronecker(
1972 col_range: Range<usize>,
1973 local: Array2<f64>,
1974 factors: Vec<Array2<f64>>,
1975 ) -> Self {
1976 assert_eq!(col_range.len(), local.nrows());
1977 assert_eq!(col_range.len(), local.ncols());
1978 Self {
1979 col_range,
1980 local,
1981 prior_mean: gam_problem::CoefficientPriorMean::Zero,
1982 structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
1983 op: None,
1984 }
1985 }
1986
1987 pub fn to_global(&self, p_total: usize) -> Array2<f64> {
1991 let mut g = Array2::<f64>::zeros((p_total, p_total));
1992 let r = &self.col_range;
1993 assert!(
1994 r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
1995 "BlockwisePenalty::to_global shape invariant violated: \
1996 col_range={}..{}, local={}x{}, p_total={}",
1997 r.start,
1998 r.end,
1999 self.local.nrows(),
2000 self.local.ncols(),
2001 p_total,
2002 );
2003 g.slice_mut(s![r.start..r.end, r.start..r.end])
2004 .assign(&self.local);
2005 g
2006 }
2007
2008 pub fn to_penalty_matrix(
2011 &self,
2012 total_dim: usize,
2013 ) -> gam_problem::PenaltyMatrix {
2014 gam_problem::PenaltyMatrix::Blockwise {
2015 local: self.local.clone(),
2016 col_range: self.col_range.clone(),
2017 total_dim,
2018 }
2019 }
2020
2021 #[inline]
2023 pub fn block_size(&self) -> usize {
2024 self.col_range.len()
2025 }
2026}
2027
2028pub fn weighted_blockwise_penalty_sum(
2032 penalties: &[BlockwisePenalty],
2033 lambdas: &[f64],
2034 p_total: usize,
2035) -> Array2<f64> {
2036 assert_eq!(penalties.len(), lambdas.len());
2037 for (idx, &lam) in lambdas.iter().enumerate() {
2044 assert!(
2045 lam.is_finite() && lam >= 0.0,
2046 "weighted_blockwise_penalty_sum: lambdas[{idx}] = {lam} is invalid (must be finite and non-negative; negative smoothing parameters violate S_λ ⪰ 0)",
2047 );
2048 }
2049 for (idx, bp) in penalties.iter().enumerate() {
2053 let r = &bp.col_range;
2054 assert!(
2055 r.end <= p_total,
2056 "weighted_blockwise_penalty_sum: penalties[{idx}] col_range {:?} exceeds p_total = {p_total}",
2057 r,
2058 );
2059 }
2060 let mut out = Array2::<f64>::zeros((p_total, p_total));
2061 for (bp, &lam) in penalties.iter().zip(lambdas.iter()) {
2062 let r = &bp.col_range;
2063 let mut slice = out.slice_mut(s![r.start..r.end, r.start..r.end]);
2064 slice.scaled_add(lam, &bp.local);
2065 }
2066 out
2067}
2068
/// Penalty system described by one marginal penalty per tensor dimension, plus cached
/// per-marginal eigensystems.
#[derive(Debug, Clone)]
pub struct KroneckerPenaltySystem {
    /// One penalty matrix per marginal dimension.
    pub marginal_penalties: Vec<Array2<f64>>,
    /// Per-marginal eigensystem as produced by `kronecker_marginal_eigensystems`; the first
    /// component holds the eigenvalues that the log-determinant routines index.
    // NOTE(review): the second component is presumably the eigenvector matrix — confirm.
    pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
    /// Size of each marginal; their product is `p_total()`.
    pub marginal_dims: Vec<usize>,
    /// When `true`, one extra penalty (and lambda) follows the per-marginal ones; see
    /// `num_penalties`.
    pub has_double_penalty: bool,
}
2086
2087impl KroneckerPenaltySystem {
2088 pub fn new(
2089 marginal_penalties: Vec<Array2<f64>>,
2090 marginal_dims: Vec<usize>,
2091 has_double_penalty: bool,
2092 ) -> Result<Self, BasisError> {
2093 if marginal_penalties.len() != marginal_dims.len() {
2094 crate::bail_dim_basis!(
2095 "KroneckerPenaltySystem: {} penalties vs {} dims",
2096 marginal_penalties.len(),
2097 marginal_dims.len()
2098 );
2099 }
2100 let eigensystems =
2101 kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
2102 .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
2103 Ok(Self {
2104 marginal_penalties,
2105 marginal_eigensystems: eigensystems,
2106 marginal_dims,
2107 has_double_penalty,
2108 })
2109 }
2110
2111 pub fn p_total(&self) -> usize {
2112 self.marginal_dims.iter().copied().product()
2113 }
2114
2115 pub fn ndim(&self) -> usize {
2116 self.marginal_dims.len()
2117 }
2118
2119 pub fn num_penalties(&self) -> usize {
2120 self.marginal_dims.len() + if self.has_double_penalty { 1 } else { 0 }
2121 }
2122
2123 pub fn logdet_and_derivatives(
2127 &self,
2128 lambdas: &[f64],
2129 ridge: f64,
2130 ) -> (f64, Array1<f64>, Array2<f64>) {
2131 let n_pen = self.num_penalties();
2132 assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2133 let marginal_evals: Vec<_> = self
2134 .marginal_eigensystems
2135 .iter()
2136 .map(|(evals, _)| evals.view())
2137 .collect();
2138 kronecker_logdet_and_derivatives(
2139 &marginal_evals,
2140 &self.marginal_dims,
2141 lambdas,
2142 self.has_double_penalty,
2143 ridge,
2144 )
2145 }
2146
2147 pub fn logdet_rank_and_derivatives(
2148 &self,
2149 lambdas: &[f64],
2150 ridge: f64,
2151 ) -> (f64, usize, Array1<f64>, Array2<f64>) {
2152 let n_pen = self.num_penalties();
2153 assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
2154 let d = self.marginal_dims.len();
2155 let mut logdet = 0.0;
2156 let mut rank = 0usize;
2157 let mut grad = Array1::<f64>::zeros(n_pen);
2158 let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2159 const EIGENVALUE_POSITIVITY_FLOOR: f64 = 1e-12;
2163 const STRUCTURAL_ZERO_FLOOR: f64 = 1e-12;
2167 let mut multi_idx = vec![0usize; d];
2168 loop {
2169 let mut sigma = 0.0;
2170 let mut structural_sigma = 0.0;
2171 for k in 0..d {
2172 let marginal_eigenvalue = self.marginal_eigensystems[k].0[multi_idx[k]];
2173 structural_sigma += marginal_eigenvalue;
2174 sigma += lambdas[k] * marginal_eigenvalue;
2175 }
2176 let joint_null = structural_sigma <= STRUCTURAL_ZERO_FLOOR;
2177 if self.has_double_penalty && joint_null {
2178 sigma += lambdas[d];
2179 }
2180 if structural_sigma > STRUCTURAL_ZERO_FLOOR {
2181 sigma += ridge;
2182 }
2183
2184 if sigma > EIGENVALUE_POSITIVITY_FLOOR {
2185 rank += 1;
2186 logdet += sigma.ln();
2187 let inv_sigma = 1.0 / sigma;
2188 let inv_sigma2 = inv_sigma * inv_sigma;
2189 for k in 0..n_pen {
2190 let ck = if k < d {
2191 lambdas[k] * self.marginal_eigensystems[k].0[multi_idx[k]]
2192 } else if joint_null {
2193 lambdas[d]
2194 } else {
2195 0.0
2196 };
2197 grad[k] += ck * inv_sigma;
2198 hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2199 for l in (k + 1)..n_pen {
2200 let cl = if l < d {
2201 lambdas[l] * self.marginal_eigensystems[l].0[multi_idx[l]]
2202 } else if joint_null {
2203 lambdas[d]
2204 } else {
2205 0.0
2206 };
2207 let off = -ck * cl * inv_sigma2;
2208 hess[[k, l]] += off;
2209 hess[[l, k]] += off;
2210 }
2211 }
2212 }
2213
2214 let mut carry = true;
2215 for dim in (0..d).rev() {
2216 if carry {
2217 multi_idx[dim] += 1;
2218 if multi_idx[dim] < self.marginal_dims[dim] {
2219 carry = false;
2220 } else {
2221 multi_idx[dim] = 0;
2222 }
2223 }
2224 }
2225 if carry {
2226 break;
2227 }
2228 }
2229 (logdet, rank, grad, hess)
2230 }
2231}
2232
// Unit tests for `joint_unpenalized_dim(p, penalties, dims)`.
// NOTE(review): judging by the fixtures, the third argument looks like the declared
// per-penalty null-space dimension — confirm against the function's definition.
#[cfg(test)]
mod joint_unpenalized_dim_tests {
    use super::joint_unpenalized_dim;
    use ndarray::{Array2, array};

    // With no penalties at all, every one of the 4 coefficients counts as unpenalized.
    #[test]
    fn no_penalty_is_fully_unpenalized() {
        assert_eq!(joint_unpenalized_dim(4, &[], &[]), 4);
    }

    // A single 3×3 penalty that only touches the last coordinate leaves 2 free directions.
    #[test]
    fn single_penalty_returns_its_own_null_space() {
        let s = array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 5.0]];
        assert_eq!(joint_unpenalized_dim(3, std::slice::from_ref(&s), &[2]), 2);
    }

    // `bending` penalizes coordinates 1 and 2, `ridge` penalizes coordinate 0: together
    // nothing is left unpenalized.
    #[test]
    fn complementary_double_penalty_has_empty_joint_null_space() {
        let bending = array![[0.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]];
        let ridge = array![[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
        assert_eq!(joint_unpenalized_dim(3, &[bending, ridge], &[1, 2]), 0);
    }

    // `a` penalizes coordinate 1 and `b` coordinate 2; coordinate 0 is free under both.
    #[test]
    fn partial_overlap_keeps_shared_null_direction() {
        let a = array![[0.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]];
        let b = array![[0.0, 0.0, 0.0], [0.0, 0.0, 3.0], [0.0, 0.0, 0.0]];
        assert_eq!(joint_unpenalized_dim(3, &[a, b], &[2, 2]), 1);
    }

    // `factor` is 1×1, i.e. not a full p×p matrix. The expected results (0 and 2) match the
    // declared dims passed alongside it.
    // NOTE(review): presumably the function falls back to the declared dims when a penalty
    // is not materialized at full size — confirm.
    #[test]
    fn non_materialized_penalty_falls_back_conservatively() {
        let full: Array2<f64> = array![[0.0, 0.0], [0.0, 1.0]];
        let factor: Array2<f64> = array![[1.0]];
        assert_eq!(
            joint_unpenalized_dim(2, &[full, factor.clone()], &[1, 0]),
            0
        );
        assert_eq!(joint_unpenalized_dim(4, std::slice::from_ref(&factor), &[2]), 2);
    }
}
2289
// Unit test for `KroneckerPenaltySystem::logdet_rank_and_derivatives` with the double
// penalty enabled.
#[cfg(test)]
mod kronecker_penalty_system_tests {
    use super::KroneckerPenaltySystem;
    use ndarray::array;

    // Two 2×2 marginals with eigenvalues {0, 2} and {0, 3}, lambdas (5, 7) and a
    // double-penalty lambda of 11. The four modes are:
    //   (0,0) -> joint null, sigma = 11
    //   (0,3) -> 7·3 = 21
    //   (2,0) -> 5·2 = 10
    //   (2,3) -> 10 + 21 = 31
    // Only the joint-null mode feeds the third gradient entry, so grad[2] = 11/11 = 1 and
    // hess[2][2] = 1 − 1 = 0.
    #[test]
    fn double_penalty_rank_derivatives_use_only_joint_null_space() {
        let penalties = vec![
            array![[0.0, 0.0], [0.0, 2.0]],
            array![[0.0, 0.0], [0.0, 3.0]],
        ];
        let system = KroneckerPenaltySystem::new(penalties, vec![2usize, 2usize], true).unwrap();
        let lambdas = vec![5.0, 7.0, 11.0];

        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, 0.0);

        let expected_diag = [11.0_f64, 21.0, 10.0, 31.0];
        let expected_logdet: f64 = expected_diag.iter().map(|v| v.ln()).sum();
        assert_eq!(rank, 4);
        assert!((logdet - expected_logdet).abs() <= 1e-12);
        assert!(
            (grad[2] - 1.0).abs() <= 1e-12,
            "double-penalty rank derivative must count only the joint null mode, got {}",
            grad[2]
        );
        assert!(hess[[2, 2]].abs() <= 1e-12);
    }
}
2318
/// Realized design for a collection of terms: the design matrix, its penalty blocks and the
/// bookkeeping that locates each term's coefficients.
#[derive(Clone, Debug)]
pub struct TermCollectionDesign {
    /// Full design matrix; its column count is the total coefficient dimension used by
    /// `penalties_as_penalty_matrix`.
    pub design: DesignMatrix,
    /// Penalty blocks, each stored on its own column range.
    pub penalties: Vec<BlockwisePenalty>,
    // NOTE(review): presumably one null-space dimension per entry of `penalties` — confirm.
    pub nullspace_dims: Vec<usize>,
    /// Metadata per penalty block; `leading_penalty_blocks_before_smooth` inspects
    /// `penalty.source` on these entries.
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    // NOTE(review): presumably metadata for penalty blocks that were dropped — confirm.
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    /// Optional per-coefficient lower bounds.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional linear inequality constraints on the coefficients.
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// Column range of the intercept.
    pub intercept_range: Range<usize>,
    /// `(term name, column range)` for each linear term.
    pub linear_ranges: Vec<(String, Range<usize>)>,
    /// `(term name, column range)` for each random-effect term.
    pub random_effect_ranges: Vec<(String, Range<usize>)>,
    // NOTE(review): presumably `(term name, level codes)` per random-effect term — confirm.
    pub random_effect_levels: Vec<(String, Vec<u64>)>,
    /// Smooth-term part of the design; `kronecker_penalty_system` reads `smooth.terms`.
    pub smooth: SmoothDesign,
}
2346
2347impl TermCollectionDesign {
2348 pub fn leading_penalty_blocks_before_smooth(&self) -> usize {
2356 self.penaltyinfo
2357 .iter()
2358 .take_while(|info| {
2359 matches!(
2360 &info.penalty.source,
2361 crate::basis::PenaltySource::Other(source)
2362 if source == "LinearTermRidge"
2363 || source.starts_with("RandomEffectRidge(")
2364 )
2365 })
2366 .count()
2367 }
2368
2369 pub fn penalties_as_penalty_matrix(&self) -> Vec<gam_problem::PenaltyMatrix> {
2373 let p = self.design.ncols();
2374 self.penalties
2375 .iter()
2376 .map(|bp| bp.to_penalty_matrix(p))
2377 .collect()
2378 }
2379
2380 #[inline]
2382 pub fn num_penalties(&self) -> usize {
2383 self.penalties.len()
2384 }
2385
2386 pub fn realize_coefficient_groups(
2389 &self,
2390 groups: &[CoefficientGroupSpec],
2391 base_prior: &gam_spec::RhoPrior,
2392 ) -> Result<RealizedCoefficientGroups, BasisError> {
2393 realize_coefficient_groups(self, groups, base_prior)
2394 }
2395
2396 pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
2407 let [only_term] = self.smooth.terms.as_slice() else {
2408 return None;
2409 };
2410 let kron = only_term.kronecker_factored.as_ref()?;
2411 if kron.marginal_dims.len() < 2
2417 || kron.marginal_penalties.len() != kron.marginal_dims.len()
2418 || kron.marginal_designs.len() != kron.marginal_dims.len()
2419 {
2420 return None;
2421 }
2422 KroneckerPenaltySystem::new(
2423 kron.marginal_penalties.clone(),
2424 kron.marginal_dims.clone(),
2425 kron.has_double_penalty,
2426 )
2427 .ok()
2428 }
2429}
2430
/// Configuration bundle tying latent-coordinate values to one smooth term.
// NOTE(review): field semantics below are inferred from names and types only; the consuming
// code is not visible here — confirm before relying on them.
#[derive(Clone)]
pub struct StandardLatentCoordConfig {
    /// Shared latent-coordinate values.
    pub values: std::sync::Arc<crate::latent::LatentCoordValues>,
    /// Index of the smooth term this configuration refers to.
    pub term_index: gam_problem::types::SmoothTermIdx,
    /// Feature-column indices associated with the term.
    pub feature_cols: Vec<usize>,
    /// Manifold of the latent coordinates.
    pub manifold: crate::latent::LatentManifold,
    // NOTE(review): presumably marks that `manifold` was selected automatically — confirm.
    pub manifold_auto: bool,
    /// Retraction registry for the latent coordinates.
    pub retraction_registry: gam_problem::LatentRetractionRegistry,
    /// Optional shared registry of analytic penalties.
    pub analytic_penalties: Option<std::sync::Arc<crate::AnalyticPenaltyRegistry>>,
}
2446
/// Serializable per-term record of adaptive spatial weights.
// NOTE(review): the meaning of the three weight vectors is inferred from their names
// (magnitude / gradient / Laplacian); presumably each has one entry per row of
// `collocation_points` — confirm against the code that fills them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveSpatialMap {
    /// Name of the term this map belongs to.
    pub termname: String,
    /// Feature-column indices of the term.
    pub feature_cols: Vec<usize>,
    /// Collocation points, stored as a 2-D array.
    pub collocation_points: Array2<f64>,
    pub inv_magweight: Array1<f64>,
    pub invgradweight: Array1<f64>,
    pub inv_lapweight: Array1<f64>,
}
2456
/// Serializable diagnostics of an adaptive-regularization run.
// NOTE(review): the roles of `epsilon_0` / `epsilon_g` / `epsilon_c` and of the two
// iteration counters are not established by the code visible here — confirm with the
// producer before documenting them further.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveRegularizationDiagnostics {
    pub epsilon_0: f64,
    pub epsilon_g: f64,
    pub epsilon_c: f64,
    pub epsilon_outer_iterations: usize,
    pub mm_iterations: usize,
    /// Convergence flag reported by the run.
    pub converged: bool,
    /// Per-term adaptive spatial maps.
    pub maps: Vec<AdaptiveSpatialMap>,
}
2467
/// Conditioning record for one linear design column.
// NOTE(review): presumably the column is centered by `mean` and divided by `scale` during
// fitting — the code applying it is not visible here, confirm.
#[derive(Debug, Clone)]
pub struct LinearColumnConditioning {
    /// Index of the conditioned column.
    col_idx: usize,
    mean: f64,
    scale: f64,
}
2474
/// Collection of per-column conditioning records plus the intercept column index.
/// `Default` yields `intercept_idx == 0` and no columns.
#[derive(Debug, Clone, Default)]
pub struct LinearFitConditioning {
    /// Index of the intercept column.
    pub intercept_idx: usize,
    /// Conditioning records, one per conditioned column.
    pub columns: Vec<LinearColumnConditioning>,
}
2480
/// Derivative information of one spatial term with respect to a ψ coordinate.
// NOTE(review): field roles below are inferred from names and types; the code that
// populates and consumes this struct is not visible here — confirm.
#[derive(Clone)]
pub struct SpatialPsiDerivative {
    /// Index of a penalty associated with this derivative.
    pub penalty_index: usize,
    /// Indices of all penalties associated with this derivative.
    pub penalty_indices: Vec<usize>,
    /// Global coefficient range of the term's block.
    pub global_range: Range<usize>,
    /// Total number of coefficients.
    pub total_p: usize,
    // Presumably first-order ψ-derivatives of the block-local design and penalty components.
    pub x_psi_local: Array2<f64>,
    pub s_psi_components_local: Vec<Array2<f64>>,
    // Presumably second-order ψ-derivatives of the same quantities.
    pub x_psi_psi_local: Array2<f64>,
    pub s_psi_psi_components_local: Vec<Array2<f64>>,
    // Presumably identifies the anisotropy group of this axis, if any.
    pub aniso_group_id: Option<usize>,
    // Presumably cross-derivative designs keyed by the other axis index.
    pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
    /// Optional thread-safe callback mapping an index to a list of matrices, or an
    /// `EstimationError`.
    pub aniso_cross_penalty_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
        >,
    >,
    /// Optional shared implicit (operator-form) design derivative.
    pub implicit_operator: Option<std::sync::Arc<crate::basis::ImplicitDesignPsiDerivative>>,
    // Presumably the axis of `implicit_operator` this entry refers to.
    pub implicit_axis: usize,
}
2512
/// Flat vector of per-term spatial ψ coordinates with a record of how many entries each
/// term owns.
///
/// Term `i` occupies `values[offset..offset + dims_per_term[i]]`, where `offset` is the sum
/// of the preceding entries of `dims_per_term` (see `term_slice`). `new_with_dims` enforces
/// `values.len() == dims_per_term.iter().sum()`.
#[derive(Debug, Clone)]
pub struct SpatialLogKappaCoords {
    /// Concatenated coordinates of all terms.
    pub values: Array1<f64>,
    /// Number of coordinates owned by each term, in term order.
    pub dims_per_term: Vec<usize>,
}
2521
/// Selects which end of a `(lower, upper)` bound pair `aniso_bounds_from_data` emits.
#[derive(Clone, Copy)]
pub enum AnisoBoundEnd {
    /// Emit the lower bound of each pair.
    Lower,
    /// Emit the upper bound of each pair.
    Upper,
}
2531
2532impl SpatialLogKappaCoords {
2533 pub fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
2535 assert_eq!(
2536 values.len(),
2537 dims_per_term.iter().sum::<usize>(),
2538 "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
2539 values.len(),
2540 dims_per_term.iter().sum::<usize>(),
2541 );
2542 Self {
2543 values,
2544 dims_per_term,
2545 }
2546 }
2547
2548 pub fn from_length_scales(
2550 spec: &TermCollectionSpec,
2551 term_indices: &[usize],
2552 options: &SpatialLengthScaleOptimizationOptions,
2553 ) -> Self {
2554 let mut out = Array1::<f64>::zeros(term_indices.len());
2555 for (slot, &term_idx) in term_indices.iter().enumerate() {
2556 if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2562 out[slot] = cc.kappa;
2563 continue;
2564 }
2565 let length_scale = get_spatial_length_scale(spec, term_idx)
2566 .unwrap_or(options.min_length_scale)
2567 .clamp(options.min_length_scale, options.max_length_scale);
2568 out[slot] = -length_scale.ln();
2569 }
2570 Self {
2571 values: out,
2572 dims_per_term: vec![1; term_indices.len()],
2573 }
2574 }
2575
2576 pub fn from_length_scales_aniso(
2594 spec: &TermCollectionSpec,
2595 term_indices: &[usize],
2596 options: &SpatialLengthScaleOptimizationOptions,
2597 ) -> Self {
2598 let mut vals = Vec::new();
2599 let mut dims = Vec::new();
2600 for &term_idx in term_indices {
2601 if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2605 let seed = measure_jet_psi_seed(mj);
2606 dims.push(seed.len());
2607 vals.extend(seed);
2608 continue;
2609 }
2610 if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
2616 vals.push(cc.kappa);
2617 dims.push(1);
2618 continue;
2619 }
2620 let length_scale = get_spatial_length_scale(spec, term_idx)
2621 .unwrap_or(options.min_length_scale)
2622 .clamp(options.min_length_scale, options.max_length_scale);
2623 let psi_bar = -length_scale.ln(); if spatial_term_uses_per_axis_psi(spec, term_idx) {
2626 let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
2631 let eta_raw = get_spatial_aniso_log_scales(spec, term_idx)
2632 .expect("predicate guarantees aniso_log_scales is Some");
2633 let eta = center_aniso_log_scales(&eta_raw);
2634 for &eta_a in &eta {
2635 vals.push(psi_bar + eta_a);
2636 }
2637 dims.push(d);
2638 } else {
2639 vals.push(psi_bar);
2646 dims.push(1);
2647 }
2648 }
2649 Self {
2650 values: Array1::from_vec(vals),
2651 dims_per_term: dims,
2652 }
2653 }
2654
2655 pub fn lower_bounds_from_data(
2659 data: ArrayView2<'_, f64>,
2660 spec: &TermCollectionSpec,
2661 term_indices: &[usize],
2662 options: &SpatialLengthScaleOptimizationOptions,
2663 ) -> Self {
2664 let mut values = Array1::<f64>::zeros(term_indices.len());
2665 for (slot, &term_idx) in term_indices.iter().enumerate() {
2666 values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
2667 }
2668 Self {
2669 values,
2670 dims_per_term: vec![1; term_indices.len()],
2671 }
2672 }
2673
2674 pub fn upper_bounds_from_data(
2676 data: ArrayView2<'_, f64>,
2677 spec: &TermCollectionSpec,
2678 term_indices: &[usize],
2679 options: &SpatialLengthScaleOptimizationOptions,
2680 ) -> Self {
2681 let mut values = Array1::<f64>::zeros(term_indices.len());
2682 for (slot, &term_idx) in term_indices.iter().enumerate() {
2683 values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
2684 }
2685 Self {
2686 values,
2687 dims_per_term: vec![1; term_indices.len()],
2688 }
2689 }
2690
2691 pub fn lower_bounds_aniso_from_data(
2708 data: ArrayView2<'_, f64>,
2709 spec: &TermCollectionSpec,
2710 term_indices: &[usize],
2711 dims_per_term: &[usize],
2712 options: &SpatialLengthScaleOptimizationOptions,
2713 ) -> Self {
2714 Self::aniso_bounds_from_data(
2715 data,
2716 spec,
2717 term_indices,
2718 dims_per_term,
2719 options,
2720 AnisoBoundEnd::Lower,
2721 )
2722 }
2723
2724 pub fn upper_bounds_aniso_from_data(
2728 data: ArrayView2<'_, f64>,
2729 spec: &TermCollectionSpec,
2730 term_indices: &[usize],
2731 dims_per_term: &[usize],
2732 options: &SpatialLengthScaleOptimizationOptions,
2733 ) -> Self {
2734 Self::aniso_bounds_from_data(
2735 data,
2736 spec,
2737 term_indices,
2738 dims_per_term,
2739 options,
2740 AnisoBoundEnd::Upper,
2741 )
2742 }
2743
2744 fn aniso_bounds_from_data(
2750 data: ArrayView2<'_, f64>,
2751 spec: &TermCollectionSpec,
2752 term_indices: &[usize],
2753 dims_per_term: &[usize],
2754 options: &SpatialLengthScaleOptimizationOptions,
2755 end: AnisoBoundEnd,
2756 ) -> Self {
2757 assert_eq!(term_indices.len(), dims_per_term.len());
2758 let total: usize = dims_per_term.iter().sum();
2759 let mut values = Array1::<f64>::zeros(total);
2760 let mut cursor = 0;
2761 for (slot, &term_idx) in term_indices.iter().enumerate() {
2762 let d = dims_per_term[slot];
2763 if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
2766 let bounds = measure_jet_psi_bound_values(mj, matches!(end, AnisoBoundEnd::Upper));
2767 for (offset, bound) in bounds.into_iter().enumerate() {
2768 if offset < d {
2769 values[cursor + offset] = bound;
2770 }
2771 }
2772 cursor += d;
2773 continue;
2774 }
2775 if constant_curvature_term_spec(spec, term_idx).is_some() {
2778 let (lo, hi) = constant_curvature_kappa_bounds(data, spec, term_idx);
2779 if d >= 1 {
2780 values[cursor] = match end {
2781 AnisoBoundEnd::Lower => lo,
2782 AnisoBoundEnd::Upper => hi,
2783 };
2784 }
2785 cursor += d;
2786 continue;
2787 }
2788 let psi_bound = {
2789 let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
2790 match end {
2791 AnisoBoundEnd::Lower => lo,
2792 AnisoBoundEnd::Upper => hi,
2793 }
2794 };
2795 let axis_offsets = if d <= 1 {
2796 vec![0.0; d]
2797 } else {
2798 get_spatial_aniso_log_scales(spec, term_idx)
2799 .filter(|eta| eta.len() == d)
2800 .map(|eta| center_aniso_log_scales(&eta))
2801 .unwrap_or_else(|| vec![0.0; d])
2802 };
2803 for offset in 0..d {
2804 values[cursor + offset] = psi_bound + axis_offsets[offset];
2805 }
2806 cursor += d;
2807 }
2808 Self {
2809 values,
2810 dims_per_term: dims_per_term.to_vec(),
2811 }
2812 }
2813
2814 pub fn reseed_from_data(
2823 mut self,
2824 data: ArrayView2<'_, f64>,
2825 spec: &TermCollectionSpec,
2826 term_indices: &[usize],
2827 options: &SpatialLengthScaleOptimizationOptions,
2828 ) -> Self {
2829 assert_eq!(term_indices.len(), self.dims_per_term.len());
2830 let mut cursor = 0;
2831 for (slot, &term_idx) in term_indices.iter().enumerate() {
2832 let d = self.dims_per_term[slot];
2833 if measure_jet_term_spec(spec, term_idx).is_some() {
2836 cursor += d;
2837 continue;
2838 }
2839 if constant_curvature_term_spec(spec, term_idx).is_some() {
2843 cursor += d;
2844 continue;
2845 }
2846 let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
2847 cursor += d;
2848 continue;
2849 };
2850 if d == 0 {
2851 continue;
2852 }
2853 let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
2854 let psi_bar_old = current.iter().sum::<f64>() / d as f64;
2855 for (offset, &old_value) in current.iter().enumerate() {
2856 self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
2857 }
2858 cursor += d;
2859 }
2860 self
2861 }
2862
2863 pub fn clamp_to_bounds(
2874 mut self,
2875 lower: &SpatialLogKappaCoords,
2876 upper: &SpatialLogKappaCoords,
2877 ) -> Self {
2878 assert_eq!(self.values.len(), lower.values.len());
2879 assert_eq!(self.values.len(), upper.values.len());
2880 let mut n_projected = 0usize;
2881 let mut worst_delta = 0.0_f64;
2882 for idx in 0..self.values.len() {
2883 let lo = lower.values[idx];
2884 let hi = upper.values[idx];
2885 if !(lo.is_finite() && hi.is_finite()) {
2886 continue;
2887 }
2888 let v = self.values[idx];
2889 if v < lo {
2890 worst_delta = worst_delta.max(lo - v);
2891 self.values[idx] = lo;
2892 n_projected += 1;
2893 } else if v > hi {
2894 worst_delta = worst_delta.max(v - hi);
2895 self.values[idx] = hi;
2896 n_projected += 1;
2897 }
2898 }
2899 if n_projected > 0 {
2900 log::info!(
2901 "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
2902 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
2903 [{KERNEL_RANGE_MIN_DIAMETER_FRACTION}/r_max, {KERNEL_RANGE_MAX_SPACING_MULTIPLE}/r_min] geometry window",
2904 self.values.len()
2905 );
2906 }
2907 self
2908 }
2909
2910 pub fn from_theta_tail_with_dims(
2912 theta: &Array1<f64>,
2913 start: usize,
2914 dims_per_term: Vec<usize>,
2915 ) -> Self {
2916 let total: usize = dims_per_term.iter().sum();
2917 Self {
2918 values: theta.slice(s![start..start + total]).to_owned(),
2919 dims_per_term,
2920 }
2921 }
2922
2923 pub fn len(&self) -> usize {
2925 self.values.len()
2926 }
2927
2928 pub fn dims_per_term(&self) -> &[usize] {
2930 &self.dims_per_term
2931 }
2932
2933 fn term_offset(&self, term_idx: usize) -> usize {
2935 self.dims_per_term[..term_idx].iter().sum()
2936 }
2937
2938 pub fn term_slice(&self, term_idx: usize) -> &[f64] {
2940 let offset = self.term_offset(term_idx);
2941 let d = self.dims_per_term[term_idx];
2942 &self.values.as_slice().unwrap()[offset..offset + d]
2943 }
2944
    /// Borrows the flat ψ coordinate vector, i.e. all term blocks laid out
    /// back to back in slot order.
    pub fn as_array(&self) -> &Array1<f64> {
        &self.values
    }
2948
2949 pub fn set_scalar_slot(&mut self, slot: usize, value: f64) -> bool {
2955 if slot >= self.dims_per_term.len() || self.dims_per_term[slot] != 1 {
2956 return false;
2957 }
2958 let offset = self.term_offset(slot);
2959 self.values[offset] = value;
2960 true
2961 }
2962
2963 pub fn split_at(&self, mid: usize) -> (Self, Self) {
2966 let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
2967 (
2968 Self {
2969 values: self.values.slice(s![0..flat_mid]).to_owned(),
2970 dims_per_term: self.dims_per_term[..mid].to_vec(),
2971 },
2972 Self {
2973 values: self.values.slice(s![flat_mid..]).to_owned(),
2974 dims_per_term: self.dims_per_term[mid..].to_vec(),
2975 },
2976 )
2977 }
2978
2979 pub fn apply_tospec(
2986 &self,
2987 spec: &TermCollectionSpec,
2988 term_indices: &[usize],
2989 ) -> Result<TermCollectionSpec, EstimationError> {
2990 if term_indices.len() != self.dims_per_term.len() {
2991 crate::bail_invalid_estim!(
2992 "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
2993 term_indices={} dims_per_term={}",
2994 term_indices.len(),
2995 self.dims_per_term.len()
2996 );
2997 }
2998 let mut updated = spec.clone();
2999 for (slot, &term_idx) in term_indices.iter().enumerate() {
3000 let psi = self.term_slice(slot);
3001 let d = self.dims_per_term[slot];
3002 if measure_jet_term_spec(&updated, term_idx).is_some() {
3005 set_measure_jet_psi_dials(&mut updated, term_idx, psi)?;
3006 continue;
3007 }
3008 if constant_curvature_term_spec(&updated, term_idx).is_some() {
3012 set_constant_curvature_kappa(&mut updated, term_idx, psi)?;
3013 continue;
3014 }
3015 let (next_length_scale, next_aniso) = spatial_term_psi_to_length_scale_and_aniso(psi);
3016 if (d == 1 || next_length_scale.is_some())
3017 && let Some(length_scale) = next_length_scale
3018 {
3019 set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
3020 }
3021 if let Some(eta) = next_aniso {
3022 set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
3023 }
3024 }
3025 Ok(updated)
3026 }
3027}
3028
/// Returns `eta` shifted to have zero mean.
///
/// Inputs with at most one entry are returned unchanged. For longer inputs
/// the arithmetic mean is subtracted from every entry, and any centred value
/// whose magnitude is at most `1e-15` is snapped to exactly `0.0`.
pub fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    let count = eta.len();
    if count <= 1 {
        return eta.to_vec();
    }
    let mean = eta.iter().sum::<f64>() / count as f64;
    let mut centered = Vec::with_capacity(count);
    for &raw in eta {
        let shifted = raw - mean;
        centered.push(if shifted.abs() <= 1e-15 { 0.0 } else { shifted });
    }
    centered
}
3045
3046pub fn spatial_term_uses_per_axis_psi(resolvedspec: &TermCollectionSpec, term_idx: usize) -> bool {
3049 if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
3050 return measure_jet_enrolls_psi(mj);
3051 }
3052 let Some(d) = get_spatial_feature_dim(resolvedspec, term_idx) else {
3053 return false;
3054 };
3055 if d <= 1 {
3056 return false;
3057 }
3058 let Some(eta) = get_spatial_aniso_log_scales(resolvedspec, term_idx) else {
3059 return false;
3060 };
3061 if eta.len() != d {
3062 return false;
3063 }
3064 !matches!(
3065 resolvedspec.smooth_terms.get(term_idx).map(|term| &term.basis),
3066 Some(SmoothBasisSpec::Duchon { .. })
3067 )
3068}
3069
3070pub fn set_spatial_length_scale(
3071 spec: &mut TermCollectionSpec,
3072 term_idx: usize,
3073 length_scale: f64,
3074) -> Result<(), EstimationError> {
3075 let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3076 crate::bail_invalid_estim!("spatial length-scale term index {term_idx} out of range");
3077 };
3078 match &mut term.basis {
3079 SmoothBasisSpec::ThinPlate { spec, .. } => {
3080 spec.length_scale = length_scale;
3081 Ok(())
3082 }
3083 SmoothBasisSpec::Matern { spec, .. } => {
3084 spec.length_scale = length_scale;
3085 Ok(())
3086 }
3087 SmoothBasisSpec::Duchon { spec, .. } => {
3088 spec.length_scale = Some(length_scale);
3089 Ok(())
3090 }
3091 _ => Err(EstimationError::InvalidInput(format!(
3092 "term '{}' does not expose a spatial length scale",
3093 term.name
3094 ))),
3095 }
3096}
3097
3098pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
3099 spec.smooth_terms
3100 .get(term_idx)
3101 .and_then(|term| match &term.basis {
3102 SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
3103 SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
3104 SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
3105 _ => None,
3106 })
3107}
3108
3109pub fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3110 if let Some(term) = spec.smooth_terms.get(term_idx)
3116 && let SmoothBasisSpec::ThinPlate { .. } = &term.basis
3117 {
3118 return false;
3119 }
3120
3121 if let Some(term) = spec.smooth_terms.get(term_idx)
3146 && let SmoothBasisSpec::Matern { .. } = &term.basis
3147 {
3148 return true;
3149 }
3150
3151 if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
3154 return measure_jet_enrolls_psi(mj);
3155 }
3156
3157 if constant_curvature_term_spec(spec, term_idx).is_some() {
3164 return true;
3165 }
3166
3167 get_spatial_length_scale(spec, term_idx).is_some()
3168}
3169
3170pub fn measure_jet_term_spec(
3173 spec: &TermCollectionSpec,
3174 term_idx: usize,
3175) -> Option<&crate::basis::MeasureJetBasisSpec> {
3176 spec.smooth_terms
3177 .get(term_idx)
3178 .and_then(|term| match &term.basis {
3179 SmoothBasisSpec::MeasureJet { spec, .. } => Some(spec),
3180 _ => None,
3181 })
3182}
3183
3184pub fn measure_jet_enrolls_psi(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
3191 measure_jet_learns_length_scale(mj)
3200 || (mj.tau0 > 0.0 && crate::basis::measure_jet_multiscale_mode(mj))
3201}
3202
/// Returns the term's `learn_length_scale` flag.
pub fn measure_jet_learns_length_scale(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
    mj.learn_length_scale
}
3208
3209pub fn freeze_measure_jet_length_scale_learning(spec: &mut TermCollectionSpec) -> usize {
3210 let mut frozen = 0;
3211 for term in spec.smooth_terms.iter_mut() {
3212 if let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis
3213 && mj.learn_length_scale
3214 {
3215 mj.learn_length_scale = false;
3216 frozen += 1;
3217 }
3218 }
3219 frozen
3220}
3221
/// `(lower, upper)` bounds reported by `measure_jet_psi_bound_values` for the
/// measure-jet `alpha` dial.
pub const MEASURE_JET_PSI_ALPHA_BOUNDS: (f64, f64) = (-1.0, 3.0);

/// `(lower, upper)` bounds for the measure-jet `ln tau` dial; numerically
/// `(ln 1e-8, ln 100)`.
pub const MEASURE_JET_PSI_LN_TAU_BOUNDS: (f64, f64) = (-18.420680743952367, 4.605170185988092);

/// `(lower, upper)` bounds for the measure-jet log length-scale dial;
/// numerically `(ln 1e-3, ln 100)`.
pub const MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS: (f64, f64) = (-6.907755278982137, 4.605170185988092);
3239
3240pub fn measure_jet_penalty_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3248 if crate::basis::measure_jet_multiscale_mode(mj) {
3249 2
3250 } else {
3251 0
3252 }
3253}
3254
3255pub fn measure_jet_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
3259 usize::from(measure_jet_learns_length_scale(mj)) + measure_jet_penalty_psi_dim(mj)
3260}
3261
3262pub fn measure_jet_psi_seed(mj: &crate::basis::MeasureJetBasisSpec) -> Vec<f64> {
3267 let mut seed = Vec::with_capacity(measure_jet_psi_dim(mj));
3268 if measure_jet_learns_length_scale(mj) {
3269 let ell = if mj.length_scale > 0.0 {
3273 mj.length_scale
3274 } else {
3275 1.0
3276 };
3277 seed.push(ell.ln());
3278 }
3279 if measure_jet_penalty_psi_dim(mj) > 0 {
3280 let ln_tau = mj.tau0.max(f64::MIN_POSITIVE).ln();
3282 seed.extend_from_slice(&[mj.alpha, ln_tau]);
3283 }
3284 seed
3285}
3286
3287pub fn measure_jet_psi_bound_values(mj: &crate::basis::MeasureJetBasisSpec, upper: bool) -> Vec<f64> {
3290 let pick = |b: (f64, f64)| if upper { b.1 } else { b.0 };
3291 let mut bounds = Vec::with_capacity(measure_jet_psi_dim(mj));
3292 if measure_jet_learns_length_scale(mj) {
3293 bounds.push(pick(MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS));
3294 }
3295 if measure_jet_penalty_psi_dim(mj) > 0 {
3296 bounds.push(pick(MEASURE_JET_PSI_ALPHA_BOUNDS));
3298 bounds.push(pick(MEASURE_JET_PSI_LN_TAU_BOUNDS));
3299 }
3300 bounds
3301}
3302
3303pub fn apply_measure_jet_psi(
3308 mj: &mut crate::basis::MeasureJetBasisSpec,
3309 psi: &[f64],
3310) -> Result<bool, EstimationError> {
3311 if psi.len() != measure_jet_psi_dim(mj) {
3312 crate::bail_invalid_estim!(
3313 "measure-jet ψ write-back dimension mismatch: got {} values for a {}-dial term",
3314 psi.len(),
3315 measure_jet_psi_dim(mj)
3316 );
3317 }
3318 let mut changed = false;
3319 let mut cursor = 0usize;
3323 if measure_jet_learns_length_scale(mj) {
3324 let next_ell = psi[cursor].exp();
3325 cursor += 1;
3326 if !(next_ell.is_finite() && next_ell > 0.0) {
3327 crate::bail_invalid_estim!(
3328 "measure-jet ψ write-back produced a non-finite/non-positive length_scale (ℓ={next_ell})"
3329 );
3330 }
3331 if next_ell != mj.length_scale {
3332 mj.length_scale = next_ell;
3333 changed = true;
3334 }
3335 }
3336 if measure_jet_penalty_psi_dim(mj) > 0 {
3337 let next_alpha = psi[cursor];
3340 let next_tau = psi[cursor + 1].exp();
3341 if !(next_alpha.is_finite() && next_tau.is_finite() && next_tau > 0.0) {
3342 crate::bail_invalid_estim!(
3343 "measure-jet ψ write-back produced non-finite dials (alpha={next_alpha}, tau={next_tau})"
3344 );
3345 }
3346 if next_alpha != mj.alpha {
3347 mj.alpha = next_alpha;
3348 changed = true;
3349 }
3350 if next_tau != mj.tau0 {
3351 mj.tau0 = next_tau;
3352 changed = true;
3353 }
3354 }
3355 Ok(changed)
3356}
3357
3358pub fn set_measure_jet_psi_dials(
3361 spec: &mut TermCollectionSpec,
3362 term_idx: usize,
3363 psi: &[f64],
3364) -> Result<bool, EstimationError> {
3365 let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3366 crate::bail_invalid_estim!("measure-jet ψ write-back: term index {term_idx} out of range");
3367 };
3368 set_single_term_measure_jet_psi_dials(term, psi)
3369}
3370
3371pub fn set_single_term_measure_jet_psi_dials(
3376 term: &mut SmoothTermSpec,
3377 psi: &[f64],
3378) -> Result<bool, EstimationError> {
3379 let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis else {
3380 crate::bail_invalid_estim!("measure-jet ψ write-back targeted a non-measure-jet term");
3381 };
3382 apply_measure_jet_psi(mj, psi)
3383}
3384
3385pub fn constant_curvature_term_spec(
3388 spec: &TermCollectionSpec,
3389 term_idx: usize,
3390) -> Option<&crate::basis::ConstantCurvatureBasisSpec> {
3391 spec.smooth_terms
3392 .get(term_idx)
3393 .and_then(|term| match &term.basis {
3394 SmoothBasisSpec::ConstantCurvature { spec, .. } => Some(spec),
3395 _ => None,
3396 })
3397}
3398
/// Numerator of the symmetric κ bound computed by
/// `constant_curvature_kappa_bounds`: the bound is this value divided by the
/// largest squared row norm of the term's feature columns.
pub const CONSTANT_CURVATURE_KAPPA_CHART_FRACTION: f64 = 0.5;

/// Floor applied to the largest squared row norm in
/// `constant_curvature_kappa_bounds`, which keeps the division well defined
/// when every row is at (or very near) the origin.
pub const CONSTANT_CURVATURE_MIN_CHART_RADIUS2: f64 = 1e-8;
3412
3413pub fn constant_curvature_kappa_bounds(
3418 data: ArrayView2<'_, f64>,
3419 spec: &TermCollectionSpec,
3420 term_idx: usize,
3421) -> (f64, f64) {
3422 let feature_cols = match spec.smooth_terms.get(term_idx).map(|t| &t.basis) {
3423 Some(SmoothBasisSpec::ConstantCurvature { feature_cols, .. }) => feature_cols,
3424 _ => return (-1.0, 1.0),
3425 };
3426 let mut max_r2 = CONSTANT_CURVATURE_MIN_CHART_RADIUS2;
3427 for row in data.outer_iter() {
3428 let mut r2 = 0.0_f64;
3429 for &c in feature_cols.iter() {
3430 if let Some(&v) = row.get(c)
3431 && v.is_finite()
3432 {
3433 r2 += v * v;
3434 }
3435 }
3436 if r2 > max_r2 {
3437 max_r2 = r2;
3438 }
3439 }
3440 let half = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r2;
3441 (-half, half)
3442}
3443
3444pub fn set_constant_curvature_kappa(
3448 spec: &mut TermCollectionSpec,
3449 term_idx: usize,
3450 psi: &[f64],
3451) -> Result<bool, EstimationError> {
3452 let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3453 crate::bail_invalid_estim!(
3454 "constant-curvature κ write-back: term index {term_idx} out of range"
3455 );
3456 };
3457 set_single_term_constant_curvature_kappa(term, psi)
3458}
3459
3460pub fn set_single_term_constant_curvature_kappa(
3465 term: &mut SmoothTermSpec,
3466 psi: &[f64],
3467) -> Result<bool, EstimationError> {
3468 if psi.len() != 1 {
3469 crate::bail_invalid_estim!(
3470 "constant-curvature κ write-back expects exactly one value, got {}",
3471 psi.len()
3472 );
3473 }
3474 let next_kappa = psi[0];
3475 if !next_kappa.is_finite() {
3476 crate::bail_invalid_estim!(
3477 "constant-curvature κ write-back produced a non-finite κ = {next_kappa}"
3478 );
3479 }
3480 let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &mut term.basis else {
3481 crate::bail_invalid_estim!(
3482 "constant-curvature κ write-back targeted a non-constant-curvature term"
3483 );
3484 };
3485 if cc.kappa != next_kappa {
3486 cc.kappa = next_kappa;
3487 Ok(true)
3488 } else {
3489 Ok(false)
3490 }
3491}
3492
3493pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
3504 get_spatial_length_scale(spec, term_idx).is_some()
3505 && !spatial_term_uses_per_axis_psi(spec, term_idx)
3506}
3507
3508pub fn all_spatial_terms_kappa_fixed(spec: &TermCollectionSpec) -> bool {
3509 spec.smooth_terms.iter().enumerate().all(|(idx, _)| {
3510 !spatial_term_supports_hyper_optimization(spec, idx)
3511 || spatial_term_has_locked_kappa(spec, idx)
3512 })
3513}
3514
3515pub fn spatial_identifiability_policy(termspec: &SmoothTermSpec) -> Option<&SpatialIdentifiability> {
3516 match &termspec.basis {
3517 SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.identifiability),
3518 SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.identifiability),
3519 _ => None,
3520 }
3521}
3522
/// Standard deviation that, together with a mean of exactly `0.0`, marks a
/// `RhoPrior::Normal` as the prior recognised by
/// `is_nullspace_degeneracy_prior`.
pub const NULLSPACE_WELLDET_DEGENERACY_RHO_SD: f64 = 15.0;
3527
3528pub fn is_nullspace_degeneracy_prior(prior: &gam_spec::RhoPrior) -> bool {
3531 matches!(
3532 prior,
3533 gam_spec::RhoPrior::Normal { mean, sd }
3534 if *mean == 0.0 && *sd == NULLSPACE_WELLDET_DEGENERACY_RHO_SD
3535 )
3536}
3537
/// Numerator of the data-derived lower ψ bound in `spatial_term_psi_bounds`:
/// `psi_lo = ln(KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max)`, where `r_max`
/// is the larger of the two pairwise-distance bounds.
pub const KERNEL_RANGE_MIN_DIAMETER_FRACTION: f64 = 2.0;

/// Numerator of the data-derived upper ψ bound in `spatial_term_psi_bounds`:
/// `psi_hi = ln(KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min)`, where `r_min`
/// is the smaller of the two pairwise-distance bounds.
pub const KERNEL_RANGE_MAX_SPACING_MULTIPLE: f64 = 1e2;
3556
3557
3558pub fn spatial_term_psi_bounds(
3567 data: ArrayView2<'_, f64>,
3568 spec: &TermCollectionSpec,
3569 term_idx: usize,
3570 options: &SpatialLengthScaleOptimizationOptions,
3571) -> (f64, f64) {
3572 let fallback = (
3573 -options.max_length_scale.ln(),
3574 -options.min_length_scale.ln(),
3575 );
3576 if constant_curvature_term_spec(spec, term_idx).is_some() {
3581 return constant_curvature_kappa_bounds(data, spec, term_idx);
3582 }
3583 let Some(term) = spec.smooth_terms.get(term_idx) else {
3584 return fallback;
3585 };
3586 let aniso = get_spatial_aniso_log_scales(spec, term_idx);
3599 let r_bounds = match spatial_term_center_strategy(term) {
3600 Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
3601 match aniso.as_deref() {
3602 Some(eta) if eta.len() == centers.ncols() => {
3603 let y = points_in_aniso_y_space(centers.view(), eta);
3604 pairwise_distance_bounds(y.view())
3605 }
3606 _ => pairwise_distance_bounds(centers.view()),
3607 }
3608 }
3609 _ => standardized_spatial_term_data(data, term)
3610 .ok()
3611 .and_then(|x| match aniso.as_deref() {
3612 Some(eta) if eta.len() == x.ncols() => {
3613 let y = points_in_aniso_y_space(x.view(), eta);
3614 pairwise_distance_bounds_sampled(y.view())
3615 }
3616 _ => pairwise_distance_bounds_sampled(x.view()),
3617 }),
3618 };
3619 let Some((r_min, r_max)) = r_bounds else {
3620 return fallback;
3621 };
3622 let psi_lo_data = (KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max).ln();
3628 let psi_hi_data = (KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min).ln();
3629 let psi_lo = psi_lo_data.max(fallback.0);
3639 let psi_hi = psi_hi_data.min(fallback.1);
3640 if psi_lo >= psi_hi {
3641 return fallback;
3644 }
3645 (psi_lo, psi_hi)
3646}
3647
3648pub fn spatial_term_psi_seed(
3652 data: ArrayView2<'_, f64>,
3653 spec: &TermCollectionSpec,
3654 term_idx: usize,
3655 options: &SpatialLengthScaleOptimizationOptions,
3656) -> Option<f64> {
3657 if get_spatial_length_scale(spec, term_idx).is_some() {
3658 return None; }
3660 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
3661 Some(0.5 * (psi_lo + psi_hi))
3662}
3663
/// Converts a ψ block into `(length_scale, anisotropy contrasts)`.
///
/// * Zero or one coordinate: the length scale is `exp(-ψ)` (with ψ taken as
///   `0.0` for an empty block) and there are no contrasts.
/// * Two or more coordinates: the length scale is `exp(-mean(ψ))` and the
///   contrasts are each coordinate minus that mean.
pub fn spatial_term_psi_to_length_scale_and_aniso(psi: &[f64]) -> (Option<f64>, Option<Vec<f64>>) {
    if let [_, _, ..] = psi {
        let psi_bar = psi.iter().sum::<f64>() / psi.len() as f64;
        let contrasts: Vec<f64> = psi.iter().map(|&value| value - psi_bar).collect();
        return (Some((-psi_bar).exp()), Some(contrasts));
    }
    let scalar = match psi.first() {
        Some(&value) => value,
        None => 0.0,
    };
    (Some((-scalar).exp()), None)
}
3675
3676pub fn get_spatial_aniso_log_scales(
3678 spec: &TermCollectionSpec,
3679 term_idx: usize,
3680) -> Option<Vec<f64>> {
3681 spec.smooth_terms
3682 .get(term_idx)
3683 .and_then(|term| match &term.basis {
3684 SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales.clone(),
3685 SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales.clone(),
3686 _ => None,
3687 })
3688}
3689
3690pub fn response_aware_axis_contrasts(
3710 x: ndarray::ArrayView2<'_, f64>,
3711 y: ndarray::ArrayView1<'_, f64>,
3712) -> Option<Vec<f64>> {
3713 let n = x.nrows();
3714 let d = x.ncols();
3715 if d <= 1 || n < 4 || y.len() != n {
3716 return None;
3717 }
3718 if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
3719 return None;
3720 }
3721 let mut scores = Vec::with_capacity(d);
3722 for a in 0..d {
3723 let mut order: Vec<usize> = (0..n).collect();
3724 let col = x.column(a);
3725 order.sort_by(|&i, &j| {
3726 col[i]
3727 .partial_cmp(&col[j])
3728 .unwrap_or(std::cmp::Ordering::Equal)
3729 });
3730 let mut tv = 0.0_f64;
3731 for w in order.windows(2) {
3732 let diff = y[w[1]] - y[w[0]];
3733 tv += diff * diff;
3734 }
3735 scores.push(-0.5 * (tv + 1e-12).ln());
3737 }
3738 if scores.iter().any(|v| !v.is_finite()) {
3739 return None;
3740 }
3741 let mean = scores.iter().sum::<f64>() / d as f64;
3742 let centered: Vec<f64> = scores.iter().map(|&s| s - mean).collect();
3743 if centered.iter().all(|&v| v.abs() < 1e-9) {
3746 return None;
3747 }
3748 Some(centered)
3749}
3750
3751pub fn apply_response_aware_anisotropy_seed(
3760 data: ArrayView2<'_, f64>,
3761 y: ndarray::ArrayView1<'_, f64>,
3762 spec: &mut TermCollectionSpec,
3763 spatial_terms: &[usize],
3764) {
3765 const MAX_NUDGE: f64 = std::f64::consts::LN_2;
3770 for &term_idx in spatial_terms {
3771 let Some(current_eta) = get_spatial_aniso_log_scales(spec, term_idx) else {
3772 continue;
3773 };
3774 let d = current_eta.len();
3775 if d <= 1 {
3776 continue;
3777 }
3778 let Some(term) = spec.smooth_terms.get(term_idx) else {
3779 continue;
3780 };
3781 let feature_cols = term.basis.structural_feature_cols();
3782 if feature_cols.len() != d {
3783 continue;
3784 }
3785 let Ok(x) = select_columns(data, &feature_cols) else {
3786 continue;
3787 };
3788 let Some(contrast) = response_aware_axis_contrasts(x.view(), y) else {
3789 continue;
3790 };
3791 let nudged: Vec<f64> = current_eta
3792 .iter()
3793 .zip(contrast.iter())
3794 .map(|(&eta_a, &c_a)| eta_a + c_a.clamp(-MAX_NUDGE, MAX_NUDGE))
3795 .collect();
3796 if let Err(err) = set_spatial_aniso_log_scales(spec, term_idx, nudged) {
3799 log::debug!(
3800 "[spatial-kappa] response-aware anisotropy seed skipped for term {term_idx}: {err}"
3801 );
3802 }
3803 }
3804}
3805
3806pub fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
3808 spec.smooth_terms
3809 .get(term_idx)
3810 .and_then(|term| match &term.basis {
3811 SmoothBasisSpec::ThinPlate { feature_cols, .. } => Some(feature_cols.len()),
3812 SmoothBasisSpec::Matern { feature_cols, .. } => Some(feature_cols.len()),
3813 SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
3814 _ => None,
3815 })
3816}
3817
3818pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
3825 for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
3826 let (aniso, length_scale) = match &term.basis {
3827 SmoothBasisSpec::Matern { spec, .. } => {
3828 (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
3829 }
3830 SmoothBasisSpec::Duchon { spec, .. } => {
3831 (spec.aniso_log_scales.as_ref(), spec.length_scale)
3832 }
3833 _ => (None, None),
3834 };
3835 let Some(eta) = aniso else { continue };
3836 if eta.is_empty() {
3837 continue;
3838 }
3839 let mut lines = match length_scale {
3840 Some(ls) => format!(
3841 "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
3842 term_idx, term.name, ls
3843 ),
3844 None => format!(
3845 "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
3846 term_idx, term.name
3847 ),
3848 };
3849 for (a, &eta_a) in eta.iter().enumerate() {
3850 if let Some(ls) = length_scale {
3851 let length_a = ls * (-eta_a).exp();
3852 let kappa_a = (1.0 / ls) * eta_a.exp();
3853 lines.push_str(&format!(
3854 "\n axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
3855 a, eta_a, length_a, kappa_a
3856 ));
3857 } else {
3858 lines.push_str(&format!("\n axis {}: eta={:+.4}", a, eta_a));
3859 }
3860 }
3861 log::info!("{}", lines);
3862 }
3863}
3864
3865pub fn set_spatial_aniso_log_scales(
3867 spec: &mut TermCollectionSpec,
3868 term_idx: usize,
3869 eta: Vec<f64>,
3870) -> Result<(), EstimationError> {
3871 let eta = center_aniso_log_scales(&eta);
3872 let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
3873 crate::bail_invalid_estim!("spatial aniso_log_scales term index {term_idx} out of range");
3874 };
3875 match &mut term.basis {
3876 SmoothBasisSpec::Matern { spec, .. } => {
3877 spec.aniso_log_scales = Some(eta);
3878 Ok(())
3879 }
3880 SmoothBasisSpec::Duchon { spec, .. } => {
3881 spec.aniso_log_scales = Some(eta);
3882 Ok(())
3883 }
3884 _ => Err(EstimationError::InvalidInput(format!(
3885 "term '{}' does not support aniso_log_scales",
3886 term.name
3887 ))),
3888 }
3889}
3890
3891pub fn sync_aniso_contrasts_from_metadata(
3898 spec: &mut TermCollectionSpec,
3899 design: &SmoothDesign,
3900) {
3901 for (term_idx, term) in design.terms.iter().enumerate() {
3902 let meta_aniso = match &term.metadata {
3903 BasisMetadata::Matern {
3904 aniso_log_scales, ..
3905 } => aniso_log_scales.clone(),
3906 BasisMetadata::Duchon {
3907 aniso_log_scales, ..
3908 } => aniso_log_scales.clone(),
3909 _ => None,
3910 };
3911 if let Some(eta) = meta_aniso
3912 && eta.len() > 1
3913 {
3914 set_spatial_aniso_log_scales(spec, term_idx, eta).ok();
3915 }
3916 }
3917}
3918
/// Tuning knobs for spatial length-scale (ψ) optimization.
///
/// `validate` checks the floating-point fields; `Default` supplies the
/// baseline values.
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// NOTE(review): presumably the master switch for the optimization; it is
    /// not read in this part of the file — confirm at the call sites.
    pub enabled: bool,
    /// NOTE(review): presumably the cap on outer optimizer iterations — confirm.
    pub max_outer_iter: usize,
    /// Relative tolerance; `validate` requires it to be finite and `> 0`.
    pub rel_tol: f64,
    /// Step in log units; `validate` requires it to be finite and `> 0`.
    pub log_step: f64,
    /// Smallest admissible length scale. `spatial_term_psi_bounds` uses
    /// `-ln(min_length_scale)` as the upper end of its fallback ψ window.
    pub min_length_scale: f64,
    /// Largest admissible length scale. `spatial_term_psi_bounds` uses
    /// `-ln(max_length_scale)` as the lower end of its fallback ψ window.
    pub max_length_scale: f64,
    /// NOTE(review): presumably the row count above which a pilot subsample
    /// is used; not read in this part of the file — confirm.
    pub pilot_subsample_threshold: usize,
}
3949
3950impl Default for SpatialLengthScaleOptimizationOptions {
3951 fn default() -> Self {
3952 Self {
3953 enabled: true,
3954 max_outer_iter: 80,
3955 rel_tol: 1e-4,
3956 log_step: std::f64::consts::LN_2,
3957 min_length_scale: 1e-3,
3958 max_length_scale: 1e3,
3959 pilot_subsample_threshold: 10_000,
3960 }
3961 }
3962}
3963
3964impl SpatialLengthScaleOptimizationOptions {
3965 pub fn validate(&self) -> Result<(), String> {
3983 if !self.min_length_scale.is_finite() || self.min_length_scale <= 0.0 {
3984 return Err(SmoothError::invalid_config(format!(
3985 "SpatialLengthScaleOptimizationOptions::min_length_scale must be > 0 and finite, got {}",
3986 self.min_length_scale
3987 ))
3988 .into());
3989 }
3990 if !self.max_length_scale.is_finite() || self.max_length_scale <= 0.0 {
3991 return Err(SmoothError::invalid_config(format!(
3992 "SpatialLengthScaleOptimizationOptions::max_length_scale must be > 0 and finite, got {}",
3993 self.max_length_scale
3994 ))
3995 .into());
3996 }
3997 if self.min_length_scale >= self.max_length_scale {
3998 return Err(SmoothError::invalid_config(format!(
3999 "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
4000 self.min_length_scale, self.max_length_scale
4001 ))
4002 .into());
4003 }
4004 if !self.rel_tol.is_finite() || self.rel_tol <= 0.0 {
4005 return Err(SmoothError::invalid_config(format!(
4006 "SpatialLengthScaleOptimizationOptions::rel_tol must be > 0 and finite, got {}",
4007 self.rel_tol
4008 ))
4009 .into());
4010 }
4011 if !self.log_step.is_finite() || self.log_step <= 0.0 {
4012 return Err(SmoothError::invalid_config(format!(
4013 "SpatialLengthScaleOptimizationOptions::log_step must be > 0 and finite, got {}",
4014 self.log_step
4015 ))
4016 .into());
4017 }
4018 Ok(())
4019 }
4020}
4021
/// Description of one random-effect block.
// NOTE(review): the field meanings below are inferred from names and types
// only; nothing in this part of the file constructs the struct — confirm.
#[derive(Debug, Clone)]
pub struct RandomEffectBlock {
    /// Name of the block.
    pub name: String,
    /// One entry per row; presumably `Some(group)` is the row's group index
    /// and `None` marks a row without a group.
    pub group_ids: Vec<Option<usize>>,
    /// Presumably the number of distinct groups.
    pub num_groups: usize,
    /// Presumably the raw level codes that were kept.
    pub kept_levels: Vec<u64>,
}
4031
/// Magnitude at or below which an entry is treated as zero when counting and
/// emitting non-zeros for sparse assembly of design blocks.
pub const BLOCK_SPARSE_ZERO_EPS: f64 = 1e-12;

/// Largest non-zero fraction of the full `nrows * ncols` grid for which
/// `try_build_sparse_design_from_blocks` still assembles a sparse matrix
/// when none of the blocks is intrinsically sparse.
pub const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.20;
4035
4036pub fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
4037 blocks
4038 .iter()
4039 .any(|block| matches!(block, DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_)))
4040}
4041
4042pub fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
4043 match block {
4044 DesignBlock::Intercept(n) => Some(*n),
4045 DesignBlock::RandomEffect(op) => {
4046 Some(op.group_ids.iter().filter(|gid| gid.is_some()).count())
4047 }
4048 DesignBlock::Sparse(sparse) => Some(sparse.val().len()),
4049 DesignBlock::Dense(dense) => dense.as_dense_ref().map(|matrix| {
4050 matrix
4051 .iter()
4052 .filter(|&&value| value.abs() > BLOCK_SPARSE_ZERO_EPS)
4053 .count()
4054 }),
4055 }
4056}
4057
4058pub fn try_build_sparse_design_from_blocks(
4059 blocks: &[DesignBlock],
4060) -> Result<Option<DesignMatrix>, BasisError> {
4061 if blocks.is_empty() {
4062 return Ok(None);
4063 }
4064 let nrows = blocks[0].nrows();
4065 let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
4066 if nrows == 0 || ncols == 0 || ncols <= 32 {
4067 return Ok(None);
4068 }
4069
4070 let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
4071 let sparse_nnz_limit = if preserve_sparse_storage {
4072 usize::MAX
4073 } else {
4074 let total_cells = nrows.saturating_mul(ncols);
4075 ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
4076 };
4077 let mut nnz = 0usize;
4078 for block in blocks {
4079 let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
4080 block_nnz
4081 } else {
4082 return Ok(None);
4083 };
4084 nnz = nnz.saturating_add(block_nnz);
4085 if nnz > sparse_nnz_limit {
4086 return Ok(None);
4087 }
4088 }
4089
4090 let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
4091 let mut col_offset = 0usize;
4092 for block in blocks {
4093 match block {
4094 DesignBlock::Intercept(n) => {
4095 for row in 0..*n {
4096 triplets.push(Triplet::new(row, col_offset, 1.0));
4097 }
4098 }
4099 DesignBlock::RandomEffect(op) => {
4100 for (row, group_id) in op.group_ids.iter().enumerate() {
4101 if let Some(group) = group_id {
4102 triplets.push(Triplet::new(row, col_offset + group, 1.0));
4103 }
4104 }
4105 }
4106 DesignBlock::Sparse(sparse) => {
4107 let (symbolic, values) = sparse.parts();
4108 let col_ptr = symbolic.col_ptr();
4109 let row_idx = symbolic.row_idx();
4110 for col in 0..sparse.ncols() {
4111 for idx in col_ptr[col]..col_ptr[col + 1] {
4112 let value = values[idx];
4113 if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4114 triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
4115 }
4116 }
4117 }
4118 }
4119 DesignBlock::Dense(dense) => {
4120 let matrix = dense.as_dense_ref().ok_or_else(|| {
4121 BasisError::InvalidInput(
4122 "sparse-compatible block assembly requires materialized dense blocks"
4123 .to_string(),
4124 )
4125 })?;
4126 for row in 0..matrix.nrows() {
4127 for col in 0..matrix.ncols() {
4128 let value = matrix[[row, col]];
4129 if value.abs() > BLOCK_SPARSE_ZERO_EPS {
4130 triplets.push(Triplet::new(row, col_offset + col, value));
4131 }
4132 }
4133 }
4134 }
4135 }
4136 col_offset += block.ncols();
4137 }
4138
4139 let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
4140 BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
4141 })?;
4142 Ok(Some(DesignMatrix::Sparse(
4143 gam_linalg::matrix::SparseDesignMatrix::new(sparse),
4144 )))
4145}
4146
4147pub fn assemble_term_collection_design_matrix(
4148 blocks: Vec<DesignBlock>,
4149) -> Result<DesignMatrix, BasisError> {
4150 if let Some(sparse) = try_build_sparse_design_from_blocks(&blocks)? {
4151 return Ok(sparse);
4152 }
4153 let block_op = BlockDesignOperator::new(blocks).map_err(|e| {
4154 BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
4155 })?;
4156 Ok(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
4157 Arc::new(block_op),
4158 )))
4159}
4160
4161pub fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
4162 let n = data.nrows();
4163 let p = data.ncols();
4164 for &c in cols {
4165 if c >= p {
4166 crate::bail_dim_basis!("feature column {c} is out of bounds for data with {p} columns");
4167 }
4168 }
4169 let mut out = Array2::<f64>::zeros((n, cols.len()));
4170 for (j, &c) in cols.iter().enumerate() {
4171 out.column_mut(j).assign(&data.column(c));
4172 }
4173 Ok(out)
4174}
4175
/// Short label for a non-finite value: `"NaN"` for NaN, otherwise `"+Inf"`
/// or `"-Inf"` according to the sign bit. Only the NaN test and the sign are
/// inspected, so a finite input is labelled by its sign as well.
pub fn nonfinite_value_label(value: f64) -> &'static str {
    match value {
        v if v.is_nan() => "NaN",
        v if v.is_sign_positive() => "+Inf",
        _ => "-Inf",
    }
}
4185
4186pub fn validate_term_feature_column_finite(
4187 data: ArrayView2<'_, f64>,
4188 term_kind: &str,
4189 term_name: &str,
4190 feature_col: usize,
4191) -> Result<(), BasisError> {
4192 let p = data.ncols();
4193 if feature_col >= p {
4194 crate::bail_dim_basis!(
4195 "{term_kind} term '{term_name}' feature column {feature_col} out of bounds for {p} columns"
4196 );
4197 }
4198 for (row, &value) in data.column(feature_col).iter().enumerate() {
4199 if !value.is_finite() {
4200 crate::bail_invalid_basis!(
4201 "{term_kind} term '{term_name}' feature column {feature_col} row {row} contains non-finite value {}",
4202 nonfinite_value_label(value)
4203 );
4204 }
4205 }
4206 Ok(())
4207}
4208
4209pub fn validate_smooth_terms_finite_inputs(
4210 data: ArrayView2<'_, f64>,
4211 terms: &[SmoothTermSpec],
4212) -> Result<(), BasisError> {
4213 for term in terms {
4214 for feature_col in smooth_term_feature_cols(term) {
4215 validate_term_feature_column_finite(data, "smooth", &term.name, feature_col)?;
4216 }
4217 }
4218 Ok(())
4219}
4220
4221pub fn validate_term_collection_finite_inputs(
4222 data: ArrayView2<'_, f64>,
4223 spec: &TermCollectionSpec,
4224) -> Result<(), BasisError> {
4225 for term in &spec.linear_terms {
4226 validate_term_feature_column_finite(data, "linear", &term.name, term.feature_col)?;
4227 }
4228 for term in &spec.random_effect_terms {
4229 validate_term_feature_column_finite(data, "random-effect", &term.name, term.feature_col)?;
4230 }
4231 validate_smooth_terms_finite_inputs(data, &spec.smooth_terms)
4232}
4233
/// Grouping key that decides which spatial smooth terms may share one set of
/// jointly planned centers.
///
/// Two terms land in the same group only when every field compares equal.
/// The derived `Ord` compares fields in declaration order, which fixes the
/// iteration order when the key is used in a `BTreeMap`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct JointSpatialCenterGroupKey {
    /// Data columns the term is built over.
    feature_cols: Vec<usize>,
    /// Kind of center-selection strategy the term requests.
    strategy_kind: CenterStrategyKind,
    /// Strategy-specific extra parameter: `max_iter` for k-means,
    /// `points_per_dim` for uniform grids, `0` for everything else.
    strategy_aux: usize,
    /// Number of centers the term's strategy asks for.
    requested_num_centers: usize,
    /// `f64::to_bits` patterns of the term's explicit input scales, if any;
    /// storing bits rather than floats lets the key derive `Eq` and `Ord`.
    input_scale_bits: Option<Vec<u64>>,
}
4242
4243pub fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
4244 match &term.basis {
4245 SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.len() + 1,
4246 SmoothBasisSpec::Duchon {
4247 feature_cols, spec, ..
4248 } => match spec.nullspace_order {
4249 crate::basis::DuchonNullspaceOrder::Zero => 1,
4250 crate::basis::DuchonNullspaceOrder::Linear => feature_cols.len() + 1,
4251 crate::basis::DuchonNullspaceOrder::Degree(degree) => {
4252 crate::basis::duchon_nullspace_dimension(feature_cols.len(), degree)
4253 }
4254 },
4255 SmoothBasisSpec::Matern { .. } => 1,
4256 _ => 1,
4257 }
4258}
4259
4260pub fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
4261 let (feature_cols, strategy, input_scales) = match &term.basis {
4262 SmoothBasisSpec::ThinPlate {
4263 feature_cols,
4264 spec,
4265 input_scales,
4266 } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4267 SmoothBasisSpec::Matern {
4268 feature_cols,
4269 spec,
4270 input_scales,
4271 } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4272 SmoothBasisSpec::Duchon {
4273 feature_cols,
4274 spec,
4275 input_scales,
4276 } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
4277 _ => return None,
4278 };
4279 let strategy_kind = center_strategy_kind(strategy);
4280 let strategy_aux = match strategy {
4281 CenterStrategy::Auto(inner) => match inner.as_ref() {
4282 CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4283 CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4284 _ => 0,
4285 },
4286 CenterStrategy::KMeans { max_iter, .. } => *max_iter,
4287 CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
4288 _ => 0,
4289 };
4290 Some(JointSpatialCenterGroupKey {
4291 feature_cols: feature_cols.clone(),
4292 strategy_kind,
4293 strategy_aux,
4294 requested_num_centers: center_strategy_num_centers(strategy)?,
4295 input_scale_bits: input_scales
4296 .map(|values| values.iter().map(|value| value.to_bits()).collect()),
4297 })
4298}
4299
4300pub fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
4301 match &term.basis {
4302 SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.center_strategy),
4303 SmoothBasisSpec::Matern { spec, .. } => Some(&spec.center_strategy),
4304 SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.center_strategy),
4305 _ => None,
4306 }
4307}
4308
4309pub fn set_spatial_term_centers(
4310 term: &mut SmoothTermSpec,
4311 centers: Array2<f64>,
4312) -> Result<(), BasisError> {
4313 match &mut term.basis {
4314 SmoothBasisSpec::ThinPlate { spec, .. } => {
4315 spec.center_strategy = CenterStrategy::UserProvided(centers);
4316 Ok(())
4317 }
4318 SmoothBasisSpec::Matern { spec, .. } => {
4319 spec.center_strategy = CenterStrategy::UserProvided(centers);
4320 Ok(())
4321 }
4322 SmoothBasisSpec::Duchon { spec, .. } => {
4323 spec.center_strategy = CenterStrategy::UserProvided(centers);
4324 Ok(())
4325 }
4326 _ => Err(BasisError::InvalidInput(format!(
4327 "term '{}' does not support spatial center planning",
4328 term.name
4329 ))),
4330 }
4331}
4332
4333pub fn standardized_spatial_term_data(
4334 data: ArrayView2<'_, f64>,
4335 term: &SmoothTermSpec,
4336) -> Result<Array2<f64>, BasisError> {
4337 let (feature_cols, input_scales) = match &term.basis {
4338 SmoothBasisSpec::ThinPlate {
4339 feature_cols,
4340 input_scales,
4341 ..
4342 }
4343 | SmoothBasisSpec::Matern {
4344 feature_cols,
4345 input_scales,
4346 ..
4347 }
4348 | SmoothBasisSpec::Duchon {
4349 feature_cols,
4350 input_scales,
4351 ..
4352 } => (feature_cols, input_scales.as_ref()),
4353 _ => {
4354 crate::bail_invalid_basis!("term '{}' is not a spatial smooth", term.name);
4355 }
4356 };
4357 let mut x = select_columns(data, feature_cols)?;
4358 if let Some(scales) = input_scales {
4359 apply_input_standardization(&mut x, scales);
4360 } else if let Some(scales) = compute_spatial_input_scales(x.view()) {
4361 apply_input_standardization(&mut x, &scales);
4362 }
4363 Ok(x)
4364}
4365
4366pub fn plan_joint_spatial_centers_for_term_blocks(
4367 data: ArrayView2<'_, f64>,
4368 term_blocks: &[Vec<SmoothTermSpec>],
4369) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
4370 let mut planned_blocks = term_blocks.to_vec();
4371 let n = data.nrows();
4372 let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
4373
4374 for (block_idx, terms) in planned_blocks.iter().enumerate() {
4375 for (term_idx, term) in terms.iter().enumerate() {
4376 let Some(strategy) = spatial_term_center_strategy(term) else {
4377 continue;
4378 };
4379 if !center_strategy_is_auto(strategy) {
4380 continue;
4381 }
4382 let Some(group_key) = spatial_term_group_key(term) else {
4383 continue;
4384 };
4385 if !matches!(
4386 group_key.strategy_kind,
4387 CenterStrategyKind::EqualMass
4388 | CenterStrategyKind::EqualMassCovarRepresentative
4389 | CenterStrategyKind::FarthestPoint
4390 | CenterStrategyKind::KMeans
4391 ) {
4392 continue;
4393 }
4394 if center_strategy_num_centers(strategy).is_none() {
4395 continue;
4396 }
4397 groups
4398 .entry(group_key)
4399 .or_default()
4400 .push((block_idx, term_idx));
4401 }
4402 }
4403
4404 for (group_key, members) in groups {
4405 if members.len() < 2 {
4406 continue;
4407 }
4408 let min_required = members
4409 .iter()
4410 .map(|&(block_idx, term_idx)| {
4411 spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
4412 })
4413 .max()
4414 .unwrap_or(1);
4415 let joint_centers = group_key
4416 .requested_num_centers
4417 .max(min_required)
4418 .min(n.max(1));
4419 let (first_block_idx, first_term_idx) = members[0];
4420 let prototype = &planned_blocks[first_block_idx][first_term_idx];
4421 let standardized = standardized_spatial_term_data(data, prototype)?;
4422 let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
4423 BasisError::InvalidInput(format!(
4424 "term '{}' lost its spatial center strategy during joint planning",
4425 prototype.name
4426 ))
4427 })?;
4428 let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
4429 let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
4430 log::info!(
4431 "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
4432 shared_centers.nrows(),
4433 members.len(),
4434 group_key.feature_cols,
4435 group_key.requested_num_centers,
4436 );
4437 for (block_idx, term_idx) in members {
4438 set_spatial_term_centers(
4439 &mut planned_blocks[block_idx][term_idx],
4440 shared_centers.clone(),
4441 )?;
4442 }
4443 }
4444
4445 for block in planned_blocks.iter_mut() {
4452 for term in block.iter_mut() {
4453 auto_init_length_scale_in_place(data, term);
4454 }
4455 }
4456
4457 Ok(planned_blocks)
4458}
4459
/// Lower clamp applied by the automatic length-scale initializers before the
/// result is capped at the observed feature range.
const AUTO_LENGTH_SCALE_FLOOR: f64 = 1e-6;
4463
4464fn feature_columns_max_range(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> Option<f64> {
4467 let mut max_range = 0.0_f64;
4468 for &c in feature_cols {
4469 if c >= data.ncols() {
4470 continue;
4471 }
4472 let col = data.column(c);
4473 let mut lo = f64::INFINITY;
4474 let mut hi = f64::NEG_INFINITY;
4475 for &v in col.iter() {
4476 if v.is_finite() {
4477 if v < lo {
4478 lo = v;
4479 }
4480 if v > hi {
4481 hi = v;
4482 }
4483 }
4484 }
4485 if hi > lo {
4486 let r = hi - lo;
4487 if r > max_range {
4488 max_range = r;
4489 }
4490 }
4491 }
4492 if max_range.is_finite() && max_range > 0.0 {
4493 Some(max_range)
4494 } else {
4495 None
4496 }
4497}
4498
4499pub fn auto_initial_length_scale(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> f64 {
4506 let n = data.nrows();
4507 if n == 0 || feature_cols.is_empty() {
4508 return 1.0;
4509 }
4510 let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4511 return 1.0;
4512 };
4513 let init = max_range / (n as f64).sqrt();
4514 init.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4515}
4516
4517pub fn auto_initial_length_scale_for_centers(
4540 data: ArrayView2<'_, f64>,
4541 feature_cols: &[usize],
4542 num_centers: usize,
4543) -> f64 {
4544 let n = data.nrows();
4545 if n == 0 || feature_cols.is_empty() {
4546 return 1.0;
4547 }
4548 let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4549 return 1.0;
4550 };
4551 let resolution_points = n.max(num_centers).max(1) as f64;
4557 let spacing = max_range / resolution_points.sqrt();
4558 spacing.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4559}
4560
4561pub fn auto_initial_length_scale_for_low_rank_centers(
4571 data: ArrayView2<'_, f64>,
4572 feature_cols: &[usize],
4573 num_centers: usize,
4574) -> f64 {
4575 if data.nrows() == 0 || feature_cols.is_empty() {
4576 return 1.0;
4577 }
4578 let Some(max_range) = feature_columns_max_range(data, feature_cols) else {
4579 return 1.0;
4580 };
4581 let resolution_points = num_centers.max(1) as f64;
4582 let spacing = max_range / resolution_points.sqrt();
4583 spacing.max(AUTO_LENGTH_SCALE_FLOOR).min(max_range)
4584}
4585
4586fn center_strategy_requested_count(strategy: &CenterStrategy) -> Option<usize> {
4589 match strategy {
4590 CenterStrategy::Auto(inner) => center_strategy_requested_count(inner),
4591 CenterStrategy::UserProvided(centers) => Some(centers.nrows()),
4592 CenterStrategy::EqualMass { num_centers }
4593 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4594 | CenterStrategy::FarthestPoint { num_centers }
4595 | CenterStrategy::KMeans { num_centers, .. } => Some(*num_centers),
4596 CenterStrategy::UniformGrid { .. } => None,
4597 }
4598}
4599
/// Fills in an automatic initial length scale for `term` where one is needed.
///
/// Thin wrapper that forwards the term's basis spec to
/// [`auto_init_length_scale_in_basis`].
pub fn auto_init_length_scale_in_place(data: ArrayView2<'_, f64>, term: &mut SmoothTermSpec) {
    auto_init_length_scale_in_basis(data, &mut term.basis);
}
4606
4607pub fn auto_init_length_scale_in_basis(data: ArrayView2<'_, f64>, basis: &mut SmoothBasisSpec) {
4620 match basis {
4621 SmoothBasisSpec::Matern {
4622 feature_cols, spec, ..
4623 } => {
4624 if spec.length_scale == 0.0 {
4625 spec.length_scale = match center_strategy_requested_count(&spec.center_strategy) {
4634 Some(k) => auto_initial_length_scale_for_centers(data, feature_cols, k),
4635 None => auto_initial_length_scale(data, feature_cols),
4636 };
4637 }
4638 }
4639 SmoothBasisSpec::ThinPlate {
4640 feature_cols, spec, ..
4641 } => {
4642 if spec.length_scale == 0.0 {
4643 spec.length_scale = match center_strategy_requested_count(&spec.center_strategy) {
4644 Some(k) => {
4645 auto_initial_length_scale_for_low_rank_centers(data, feature_cols, k)
4646 }
4647 None => auto_initial_length_scale(data, feature_cols),
4648 };
4649 }
4650 }
4651 SmoothBasisSpec::ByVariable { inner, .. }
4652 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
4653 auto_init_length_scale_in_basis(data, inner);
4654 }
4655 SmoothBasisSpec::BySmooth { smooth, .. } => {
4656 auto_init_length_scale_in_basis(data, smooth);
4657 }
4658 _ => {}
4659 }
4660}
4661
4662impl LinearFitConditioning {
4663 pub fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
4664 const SCALE_EPS: f64 = 1e-12;
4665 let n = design.design.nrows();
4666 let p = design.design.ncols();
4667 let mut columns = Vec::with_capacity(selected_cols.len());
4668 if n == 0 || selected_cols.is_empty() {
4669 return Self {
4670 intercept_idx: design.intercept_range.start,
4671 columns,
4672 };
4673 }
4674 let chunk_rows = gam_linalg::utils::row_chunk_for_byte_budget(n, p);
4675 let mut sums = vec![0.0_f64; selected_cols.len()];
4681 for start in (0..n).step_by(chunk_rows) {
4682 let end = (start + chunk_rows).min(n);
4683 let chunk = design
4684 .design
4685 .try_row_chunk(start..end)
4686 .expect("LinearFitConditioning::from_columns row chunk failed");
4687 for (k, &col_idx) in selected_cols.iter().enumerate() {
4688 let column = chunk.column(col_idx);
4689 for &v in column.iter() {
4690 sums[k] += v;
4691 }
4692 }
4693 }
4694 let inv_n = 1.0_f64 / n as f64;
4695 let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
4696 let mut sq_devs = vec![0.0_f64; selected_cols.len()];
4697 for start in (0..n).step_by(chunk_rows) {
4698 let end = (start + chunk_rows).min(n);
4699 let chunk = design
4700 .design
4701 .try_row_chunk(start..end)
4702 .expect("LinearFitConditioning::from_columns row chunk failed");
4703 for (k, &col_idx) in selected_cols.iter().enumerate() {
4704 let mean_k = means[k];
4705 let column = chunk.column(col_idx);
4706 for &v in column.iter() {
4707 let d = v - mean_k;
4708 sq_devs[k] += d * d;
4709 }
4710 }
4711 }
4712 for (k, &col_idx) in selected_cols.iter().enumerate() {
4713 let mean = means[k];
4714 let var = sq_devs[k] * inv_n;
4715 let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
4716 (mean, var.sqrt())
4717 } else {
4718 (0.0, 1.0)
4721 };
4722 columns.push(LinearColumnConditioning {
4723 col_idx,
4724 mean,
4725 scale,
4726 });
4727 }
4728 Self {
4729 intercept_idx: design.intercept_range.start,
4730 columns,
4731 }
4732 }
4733
4734 pub fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
4735 let mut out = design.clone();
4736 for col in &self.columns {
4737 {
4738 let mut dst = out.column_mut(col.col_idx);
4739 dst -= col.mean;
4740 }
4741 if col.scale != 1.0 {
4742 out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
4743 }
4744 }
4745 out
4746 }
4747
4748 fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
4749 let mut out = mat.clone();
4750 let intercept = self.intercept_idx;
4751 for col in &self.columns {
4752 let intercept_col = out.column(intercept).to_owned();
4753 let mut target = out.column_mut(col.col_idx);
4754 target -= &(intercept_col * col.mean);
4755 if col.scale != 1.0 {
4756 target.mapv_inplace(|v| v / col.scale);
4757 }
4758 }
4759 out
4760 }
4761
4762 fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
4763 let mut out = mat.clone();
4764 let intercept = self.intercept_idx;
4765 for col in &self.columns {
4766 let interceptrow = out.row(intercept).to_owned();
4767 let mut target = out.row_mut(col.col_idx);
4768 target -= &(interceptrow * col.mean);
4769 if col.scale != 1.0 {
4770 target.mapv_inplace(|v| v / col.scale);
4771 }
4772 }
4773 out
4774 }
4775
4776 fn left_multiply_by_m_inv_transpose(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4781 let mut out = mat_internal.clone();
4782 let intercept = self.intercept_idx;
4783 let interceptrow_snapshot = mat_internal.row(intercept).to_owned();
4784 for col in &self.columns {
4785 if col.scale != 1.0 {
4786 out.row_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4787 }
4788 if col.mean != 0.0 {
4789 let mut target = out.row_mut(col.col_idx);
4790 target += &(&interceptrow_snapshot * col.mean);
4791 }
4792 }
4793 out
4794 }
4795
4796 fn right_multiply_by_m_inv(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
4799 let mut out = mat_internal.clone();
4800 let intercept = self.intercept_idx;
4801 let intercept_col_snapshot = mat_internal.column(intercept).to_owned();
4802 for col in &self.columns {
4803 if col.scale != 1.0 {
4804 out.column_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
4805 }
4806 if col.mean != 0.0 {
4807 let mut target = out.column_mut(col.col_idx);
4808 target += &(&intercept_col_snapshot * col.mean);
4809 }
4810 }
4811 out
4812 }
4813
4814 pub fn transform_blockwise_penalties_to_internal(
4821 &self,
4822 penalties: &[BlockwisePenalty],
4823 p: usize,
4824 ) -> Vec<crate::penalty_spec::PenaltySpec> {
4825 let conditioning_cols: std::collections::HashSet<usize> =
4826 self.columns.iter().map(|c| c.col_idx).collect();
4827 penalties
4828 .iter()
4829 .map(|bp| {
4830 let overlaps =
4831 (bp.col_range.start..bp.col_range.end).any(|j| conditioning_cols.contains(&j));
4832 if overlaps {
4833 let global = bp.to_global(p);
4836 let right = self.transform_matrix_columnswith_a(&global);
4837 let transformed = self.transform_matrixrowswith_a_transpose(&right);
4838 crate::penalty_spec::PenaltySpec::Dense(transformed)
4839 } else {
4840 crate::penalty_spec::PenaltySpec::from_blockwise(bp.clone())
4843 }
4844 })
4845 .collect()
4846 }
4847
4848 pub fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
4849 let mut beta = beta_internal.clone();
4850 let intercept = self.intercept_idx;
4851 for col in &self.columns {
4852 beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
4853 beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
4854 }
4855 beta
4856 }
4857
4858 pub fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
4861 let right = self.right_multiply_by_m_inv(h_internal);
4862 self.left_multiply_by_m_inv_transpose(&right)
4863 }
4864
4865 pub fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
4866 if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
4867 (min * col.scale, max * col.scale)
4868 } else {
4869 (min, max)
4870 }
4871 }
4872}
4873
4874pub fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
4875 match metadata {
4876 BasisMetadata::ThinPlate {
4877 centers,
4878 length_scale,
4879 periodic,
4880 identifiability_transform: None,
4881 input_scales,
4882 radial_reparam,
4883 } => BasisMetadata::ThinPlate {
4884 centers,
4885 length_scale,
4886 periodic,
4887 identifiability_transform: Some(Array2::eye(raw_cols)),
4888 input_scales,
4889 radial_reparam,
4890 },
4891 BasisMetadata::Duchon {
4892 centers,
4893 length_scale,
4894 periodic,
4895 power,
4896 nullspace_order,
4897 identifiability_transform: None,
4898 input_scales,
4899 aniso_log_scales,
4900 operator_collocation_points,
4901 radial_reparam,
4902 } => BasisMetadata::Duchon {
4903 centers,
4904 length_scale,
4905 periodic,
4906 power,
4907 nullspace_order,
4908 identifiability_transform: Some(Array2::eye(raw_cols)),
4909 input_scales,
4910 aniso_log_scales,
4911 operator_collocation_points,
4912 radial_reparam,
4913 },
4914 other => other,
4915 }
4916}
4917
4918pub fn matern_operator_penalty_triplet_from_metadata(
4919 metadata: &BasisMetadata,
4920) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4921 let BasisMetadata::Matern {
4922 centers,
4923 length_scale,
4924 periodic,
4925 nu,
4926 include_intercept,
4927 identifiability_transform,
4928 aniso_log_scales,
4929 input_scales,
4930 ..
4931 } = metadata
4932 else {
4933 crate::bail_invalid_basis!("Matérn operator penalties require Matérn metadata");
4934 };
4935 let penalty_length_scale = match input_scales.as_deref() {
4947 Some(scales) => compensate_length_scale_for_standardization(*length_scale, scales),
4948 None => *length_scale,
4949 };
4950 matern_operator_penalty_triplet_at_length_scale(
4951 centers.view(),
4952 periodic.as_deref(),
4953 identifiability_transform.as_ref(),
4954 *nu,
4955 *include_intercept,
4956 aniso_log_scales.as_deref(),
4957 penalty_length_scale,
4958 )
4959}
4960
4961pub fn matern_operator_penalty_triplet_at_length_scale(
4979 centers: ArrayView2<'_, f64>,
4980 periodic: Option<&[Option<f64>]>,
4981 identifiability_transform: Option<&Array2<f64>>,
4982 nu: crate::basis::MaternNu,
4983 include_intercept: bool,
4984 aniso_log_scales: Option<&[f64]>,
4985 effective_length_scale: f64,
4986) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
4987 let penalty_centers = crate::basis::expand_periodic_centers(¢ers.to_owned(), periodic)?;
4988 let ops = build_matern_collocation_operator_matrices(
4989 penalty_centers.view(),
4990 None,
4991 effective_length_scale,
4992 nu,
4993 include_intercept,
4994 identifiability_transform.map(|z| z.view()),
4995 aniso_log_scales,
4996 )?;
4997 const ORDER_EPS: f64 = 1e-9;
5005 let d = penalty_centers.ncols();
5006 let m = nu.half_integer_value() + 0.5 * d as f64;
5007 let mut candidates = Vec::with_capacity(3);
5008 for (raw, source, min_order) in [
5009 (ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass, 0.0),
5010 (ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension, 1.0),
5011 (
5012 ops.d2.t().dot(&ops.d2),
5013 PenaltySource::OperatorStiffness,
5014 2.0,
5015 ),
5016 ] {
5017 if min_order > 0.0 && m <= min_order + ORDER_EPS {
5018 continue;
5019 }
5020 let sym = (&raw + &raw.t()) * 0.5;
5021 let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
5022 candidates.push(PenaltyCandidate {
5023 matrix,
5024 nullspace_dim_hint: 0,
5025 source,
5026 normalization_scale,
5027 kronecker_factors: None,
5028 op: None,
5029 });
5030 }
5031 filter_active_penalty_candidates(candidates)
5032}
5033
5034pub fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
5035 let matrix = (matrix + &matrix.t().to_owned()) * 0.5;
5040 let matrix = crate::basis::project_penalty_to_psd_cone(&matrix);
5042 let c = matrix.iter().map(|v| v * v).sum::<f64>().sqrt();
5043 if c.is_finite() && c > 0.0 {
5044 (matrix.mapv(|v| v / c), c)
5045 } else {
5046 (matrix, 1.0)
5047 }
5048}
5049
/// Builds the sparse row-wise tensor-product design from sparse marginal
/// designs.
///
/// Row `i` of the result holds every product formed by picking one stored
/// entry from row `i` of each marginal. The output column of a product is the
/// flattened index of the chosen marginal columns, with the last marginal
/// varying fastest. A row for which any marginal has no stored entries
/// contributes nothing to the result.
///
/// # Errors
///
/// * No marginals were supplied.
/// * The marginals do not all have the same number of rows.
/// * The product of the marginal column counts overflows `usize`.
/// * A marginal cannot be converted to row-major form, or the final sparse
///   matrix cannot be assembled from the generated triplets.
pub fn tensor_product_design_from_sparse_marginals(
    marginal_sparse: &[&SparseColMat<usize, f64>],
) -> Result<SparseColMat<usize, f64>, BasisError> {
    if marginal_sparse.is_empty() {
        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
    }
    // The first marginal fixes the row count that all others must match.
    let n = marginal_sparse[0].nrows();
    for (i, m) in marginal_sparse.iter().enumerate().skip(1) {
        if m.nrows() != n {
            crate::bail_dim_basis!(
                "tensor sparse marginal row mismatch at dim {i}: expected {n}, got {}",
                m.nrows()
            );
        }
    }
    let dims: Vec<usize> = marginal_sparse.iter().map(|m| m.ncols()).collect();
    // Overflow-checked product of the marginal widths.
    let total_cols = dims.iter().try_fold(1usize, |acc, &q| {
        acc.checked_mul(q)
            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
    })?;
    // strides[d] is the product of the widths of all later marginals, so the
    // last marginal has stride 1.
    let mut strides = vec![1usize; dims.len()];
    for d in (0..dims.len().saturating_sub(1)).rev() {
        strides[d] = strides[d + 1]
            .checked_mul(dims[d + 1])
            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))?;
    }

    // Row-major copies give direct access to the stored entries of each row.
    use faer::sparse::SparseRowMat;
    let csrs: Vec<SparseRowMat<usize, f64>> = marginal_sparse
        .iter()
        .enumerate()
        .map(|(d, m)| {
            m.as_ref().to_row_major().map_err(|e| {
                BasisError::SparseCreation(format!(
                    "tensor sparse marginal {d} CSR conversion failed: {e:?}"
                ))
            })
        })
        .collect::<Result<Vec<_>, _>>()?;
    let row_ptrs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().row_ptr()).collect();
    let col_idxs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().col_idx()).collect();
    let vals: Vec<&[f64]> = csrs.iter().map(|c| c.val()).collect();

    // Rows are processed in fixed-size chunks through a rayon parallel
    // iterator; each chunk produces its own triplet list.
    use rayon::prelude::*;
    const CHUNK: usize = 1024;
    let num_chunks = n.div_ceil(CHUNK);
    let per_chunk: Vec<Vec<Triplet<usize, usize, f64>>> = (0..num_chunks)
        .into_par_iter()
        .map(|chunk_idx| {
            let row_start = chunk_idx * CHUNK;
            let row_end = (row_start + CHUNK).min(n);
            let mut chunk_triplets = Vec::<Triplet<usize, usize, f64>>::new();
            // Scratch buffers reused across the rows of this chunk: `cur_*`
            // hold the partial products so far, `next_*` the expansion by the
            // marginal currently being folded in.
            let mut cur_cols = Vec::<usize>::with_capacity(64);
            let mut cur_vals = Vec::<f64>::with_capacity(64);
            let mut next_cols = Vec::<usize>::with_capacity(64);
            let mut next_vals = Vec::<f64>::with_capacity(64);
            for i in row_start..row_end {
                // Seed with the empty product: column offset 0, value 1.
                cur_cols.clear();
                cur_vals.clear();
                cur_cols.push(0);
                cur_vals.push(1.0);
                let mut row_is_zero = false;
                for d in 0..dims.len() {
                    let row_start_d = row_ptrs[d][i];
                    let row_end_d = row_ptrs[d][i + 1];
                    // A marginal with no stored entries in this row zeroes
                    // the whole product row.
                    if row_start_d == row_end_d {
                        row_is_zero = true;
                        break;
                    }
                    let stride = strides[d];
                    next_cols.clear();
                    next_vals.clear();
                    next_cols.reserve(cur_cols.len() * (row_end_d - row_start_d));
                    next_vals.reserve(cur_vals.len() * (row_end_d - row_start_d));
                    // Pair every partial product with every stored entry of
                    // marginal `d` in row `i`.
                    for (&prev_col, &prev_val) in cur_cols.iter().zip(cur_vals.iter()) {
                        for ptr in row_start_d..row_end_d {
                            let cj = col_idxs[d][ptr];
                            let vj = vals[d][ptr];
                            next_cols.push(prev_col + cj * stride);
                            next_vals.push(prev_val * vj);
                        }
                    }
                    std::mem::swap(&mut cur_cols, &mut next_cols);
                    std::mem::swap(&mut cur_vals, &mut next_vals);
                }
                if row_is_zero {
                    continue;
                }
                for (&col, &val) in cur_cols.iter().zip(cur_vals.iter()) {
                    chunk_triplets.push(Triplet::new(i, col, val));
                }
            }
            chunk_triplets
        })
        .collect();
    // Concatenate the per-chunk lists, in chunk order, into one buffer.
    let total_nnz: usize = per_chunk.iter().map(Vec::len).sum();
    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(total_nnz);
    for chunk in per_chunk {
        triplets.extend(chunk);
    }
    SparseColMat::try_new_from_triplets(n, total_cols, &triplets).map_err(|e| {
        BasisError::SparseCreation(format!(
            "failed to assemble sparse tensor product design: {e:?}"
        ))
    })
}
5156
5157pub fn dense_local_margin_to_sparse(
5158 dense: &Array2<f64>,
5159) -> Result<SparseColMat<usize, f64>, BasisError> {
5160 let expected_row_nnz = dense.ncols().min(4);
5161 let mut triplets =
5162 Vec::<Triplet<usize, usize, f64>>::with_capacity(dense.nrows() * expected_row_nnz);
5163 for ((row, col), &value) in dense.indexed_iter() {
5164 if value != 0.0 {
5165 triplets.push(Triplet::new(row, col, value));
5166 }
5167 }
5168 SparseColMat::try_new_from_triplets(dense.nrows(), dense.ncols(), &triplets).map_err(|e| {
5169 BasisError::SparseCreation(format!(
5170 "failed to convert tensor marginal design to sparse form: {e:?}"
5171 ))
5172 })
5173}
5174
/// Pair of matrices for one tensor margin, built from the eigen-analysis of
/// that margin's penalty.
pub struct TensorMarginRangeNullProjectors {
    /// Assembled from the eigenvectors whose eigenvalue exceeds the analysis
    /// tolerance, i.e. the directions the penalty acts on.
    range: Array2<f64>,
    /// Assembled from the remaining eigenvectors, i.e. the directions the
    /// penalty leaves unpenalized.
    null: Array2<f64>,
}
5179
5180pub fn projector_from_columns(columns: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
5181 if indices.is_empty() {
5182 return Array2::<f64>::zeros((columns.nrows(), columns.nrows()));
5183 }
5184 let basis = columns.select(Axis(1), indices);
5185 basis.dot(&basis.t())
5186}
5187
5188pub fn tensor_margin_range_null_projectors(
5189 normalized_marginal_penalties: &[(Array2<f64>, f64)],
5190) -> Result<Vec<TensorMarginRangeNullProjectors>, BasisError> {
5191 normalized_marginal_penalties
5192 .iter()
5193 .enumerate()
5194 .map(|(dim, (penalty, _))| {
5195 let analysis = crate::basis::analyze_penalty_block(penalty)?;
5196 if analysis.rank == 0 {
5197 crate::bail_invalid_basis!(
5198 "t2 separable tensor penalty margin {dim} has rank-zero penalty; \
5199 cannot split penalized and null subspaces"
5200 );
5201 }
5202 let mut range_idx = Vec::<usize>::new();
5203 let mut null_idx = Vec::<usize>::new();
5204 for (idx, &ev) in analysis.eigenvalues.iter().enumerate() {
5205 if ev > analysis.tol {
5206 range_idx.push(idx);
5207 } else {
5208 null_idx.push(idx);
5209 }
5210 }
5211 Ok(TensorMarginRangeNullProjectors {
5212 range: projector_from_columns(&analysis.eigenvectors, &range_idx),
5213 null: projector_from_columns(&analysis.eigenvectors, &null_idx),
5214 })
5215 })
5216 .collect()
5217}
5218
5219pub fn build_tensor_bspline_basis(
5220 data: ArrayView2<'_, f64>,
5221 feature_cols: &[usize],
5222 spec: &TensorBSplineSpec,
5223) -> Result<BasisBuildResult, BasisError> {
5224 if feature_cols.is_empty() {
5225 crate::bail_invalid_basis!("TensorBSpline requires at least one feature column");
5226 }
5227 if feature_cols.len() != spec.marginalspecs.len() {
5228 crate::bail_dim_basis!(
5229 "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
5230 feature_cols.len(),
5231 spec.marginalspecs.len()
5232 );
5233 }
5234 if !spec.periods.is_empty() && spec.periods.len() != feature_cols.len() {
5235 crate::bail_dim_basis!(
5236 "TensorBSpline periods length {} does not match feature count {}",
5237 spec.periods.len(),
5238 feature_cols.len()
5239 );
5240 }
5241 let p = data.ncols();
5242 for &c in feature_cols {
5243 if c >= p {
5244 crate::bail_dim_basis!(
5245 "tensor feature column {c} is out of bounds for data with {p} columns"
5246 );
5247 }
5248 }
5249
5250 let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
5251 let mut marginal_is_cr_flags = Vec::<bool>::with_capacity(feature_cols.len());
5254 let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
5255 let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
5256 let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5257 let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
5258 let mut marginal_effective_periods = Vec::<Option<f64>>::with_capacity(feature_cols.len());
5266 let mut marginal_sparse =
5274 Vec::<Option<SparseColMat<usize, f64>>>::with_capacity(feature_cols.len());
5275
5276 for (dim, (&col, marginalspec)) in feature_cols
5279 .iter()
5280 .zip(spec.marginalspecs.iter())
5281 .enumerate()
5282 {
5283 let mut marginal_unconstrained = marginalspec.clone();
5288 marginal_unconstrained.identifiability = BSplineIdentifiability::None;
5289 let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
5290 let (knots, marginal_is_cr) = match built.metadata {
5295 BasisMetadata::BSpline1D { knots, .. } => (knots, false),
5296 BasisMetadata::CubicRegression1D { knots, .. } => (knots, true),
5297 _ => {
5298 crate::bail_invalid_basis!(
5299 "internal TensorBSpline error at dim {dim}: expected BSpline1D or CubicRegression1D metadata"
5300 );
5301 }
5302 };
5303 let metadata_knots = match marginalspec.knotspec {
5304 BSplineKnotSpec::PeriodicUniform {
5305 data_range,
5306 num_basis,
5307 } => Array1::linspace(data_range.0, data_range.1, num_basis),
5308 _ => knots,
5309 };
5310 marginal_knots.push(metadata_knots);
5311 marginal_is_cr_flags.push(marginal_is_cr);
5312 marginal_degrees.push(marginalspec.degree);
5313 marginalnum_basis.push(built.design.ncols());
5314 let dense_marginal = built.design.to_dense();
5319 let sparse_view: Option<SparseColMat<usize, f64>> = match built.design.as_sparse() {
5320 Some(sd) => {
5321 let inner: &SparseColMat<usize, f64> = sd;
5322 Some(inner.clone())
5323 }
5324 None => match marginalspec.knotspec {
5325 BSplineKnotSpec::PeriodicUniform { .. } => {
5326 Some(dense_local_margin_to_sparse(&dense_marginal)?)
5327 }
5328 _ => None,
5329 },
5330 };
5331 marginal_sparse.push(sparse_view);
5332 marginal_designs.push(dense_marginal);
5333 marginal_penalties.push(
5334 built
5335 .penalties
5336 .first()
5337 .ok_or_else(|| {
5338 BasisError::InvalidInput(format!(
5339 "internal TensorBSpline error at dim {dim}: missing marginal penalty"
5340 ))
5341 })?
5342 .clone(),
5343 );
5344 built.nullspace_dims.first().ok_or_else(|| {
5345 BasisError::InvalidInput(format!(
5346 "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
5347 ))
5348 })?;
5349 let implied_period = match marginalspec.knotspec {
5357 BSplineKnotSpec::PeriodicUniform { data_range, .. } => {
5358 Some(data_range.1 - data_range.0)
5359 }
5360 _ => spec.periods.get(dim).and_then(|p| *p),
5361 };
5362 marginal_effective_periods.push(implied_period);
5363 }
5364
5365 let total_cols: usize = marginalnum_basis.iter().product();
5366 let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
5367 .then(|| tensor_product_design_from_marginals(&marginal_designs))
5368 .transpose()?;
5369 let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
5370 match spec.penalty_decomposition {
5371 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => marginal_penalties.len(),
5372 TensorBSplinePenaltyDecomposition::Separable => marginal_penalties.len() * 2,
5373 } + if spec.double_penalty { 1 } else { 0 },
5374 );
5375
5376 let normalized_marginal_penalties: Vec<(Array2<f64>, f64)> = marginal_penalties
5384 .iter()
5385 .map(normalize_penalty_in_constrained_space)
5386 .collect();
5387 let mut kronecker_marginal_penalties =
5388 Vec::<Array2<f64>>::with_capacity(normalized_marginal_penalties.len());
5389
5390 match spec.penalty_decomposition {
5391 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => {
5392 let mut marginal_kron_sum = Array2::<f64>::zeros((total_cols, total_cols));
5398
5399 for dim in 0..normalized_marginal_penalties.len() {
5400 let mut s_dim = Array2::<f64>::eye(1);
5401 let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
5402 for (j, &qj) in marginalnum_basis.iter().enumerate() {
5403 let factor = if j == dim {
5404 normalized_marginal_penalties[j].0.clone()
5405 } else {
5406 Array2::<f64>::eye(qj)
5407 };
5408 factors.push(factor.clone());
5409 s_dim = kronecker_product(&s_dim, &factor);
5410 }
5411 if dim == kronecker_marginal_penalties.len() {
5412 kronecker_marginal_penalties.push(normalized_marginal_penalties[dim].0.clone());
5413 }
5414 marginal_kron_sum += &s_dim;
5415
5416 candidates.push(PenaltyCandidate {
5417 matrix: s_dim,
5418 nullspace_dim_hint: 0,
5419 source: PenaltySource::TensorMarginal { dim },
5420 normalization_scale: normalized_marginal_penalties[dim].1,
5421 kronecker_factors: Some(factors),
5422 op: None,
5423 });
5424 }
5425
5426 if spec.double_penalty
5427 && let Some(shrink) =
5428 crate::basis::build_nullspace_shrinkage_penalty(&marginal_kron_sum)?
5429 {
5430 let (matrix, normalization_scale) =
5431 normalize_penalty_in_constrained_space(&shrink.sym_penalty);
5432 candidates.push(PenaltyCandidate {
5433 matrix,
5434 nullspace_dim_hint: 0,
5435 source: PenaltySource::TensorGlobalRidge,
5436 normalization_scale,
5437 kronecker_factors: None,
5438 op: None,
5439 });
5440 }
5441 }
5442 TensorBSplinePenaltyDecomposition::Separable => {
5443 let projectors = tensor_margin_range_null_projectors(&normalized_marginal_penalties)?;
5444 let n_masks = 1usize.checked_shl(projectors.len() as u32).ok_or_else(|| {
5445 BasisError::InvalidInput(format!(
5446 "t2 separable tensor penalty supports at most {} margins, got {}",
5447 usize::BITS - 1,
5448 projectors.len()
5449 ))
5450 })?;
5451 for mask in 1..n_masks {
5452 let mut matrix = Array2::<f64>::eye(1);
5453 let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5454 let mut penalized_margins = Vec::<usize>::new();
5455 for (dim, projector) in projectors.iter().enumerate() {
5456 let use_range = ((mask >> dim) & 1) == 1;
5457 let factor = if use_range {
5458 penalized_margins.push(dim);
5459 projector.range.clone()
5460 } else {
5461 projector.null.clone()
5462 };
5463 matrix = kronecker_product(&matrix, &factor);
5464 factors.push(factor);
5465 }
5466 let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5467 candidates.push(PenaltyCandidate {
5468 matrix,
5469 nullspace_dim_hint: 0,
5470 source: PenaltySource::TensorSeparable { penalized_margins },
5471 normalization_scale,
5472 kronecker_factors: Some(factors),
5473 op: None,
5474 });
5475 }
5476
5477 if spec.double_penalty {
5478 let mut matrix = Array2::<f64>::eye(1);
5479 let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
5480 for projector in &projectors {
5481 matrix = kronecker_product(&matrix, &projector.null);
5482 factors.push(projector.null.clone());
5483 }
5484 let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
5485 candidates.push(PenaltyCandidate {
5486 matrix,
5487 nullspace_dim_hint: 0,
5488 source: PenaltySource::TensorGlobalRidge,
5489 normalization_scale,
5490 kronecker_factors: Some(factors),
5491 op: None,
5492 });
5493 }
5494 }
5495 }
5496
5497 let z_opt = match &spec.identifiability {
5498 TensorBSplineIdentifiability::None => None,
5499 TensorBSplineIdentifiability::SumToZero => {
5500 if total_cols < 2 {
5501 crate::bail_invalid_basis!(
5502 "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability"
5503 );
5504 }
5505 let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
5506 BasisError::InvalidInput(
5507 "tensor sum-to-zero identifiability requires a realized basis".to_string(),
5508 )
5509 })?;
5510 let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
5511 let gauge = gam_problem::Gauge::sum_to_zero(z);
5512 Some(gauge.block_transform(0))
5513 }
5514 TensorBSplineIdentifiability::MarginalSumToZero => {
5515 if marginal_designs.len() < 2 {
5526 crate::bail_invalid_basis!(
5527 "tensor interaction (ti) identifiability requires at least 2 margins"
5528 );
5529 }
5530 let mut z = Array2::<f64>::eye(1);
5531 for (dim, marginal) in marginal_designs.iter().enumerate() {
5532 if marginal.ncols() < 2 {
5533 crate::bail_invalid_basis!(
5534 "tensor interaction (ti) margin {dim} has fewer than 2 basis functions; \
5535 cannot remove its marginal main effect"
5536 );
5537 }
5538 let (_, z_dim) = apply_sum_to_zero_constraint(marginal.view(), None)?;
5539 let gauge_dim = gam_problem::Gauge::sum_to_zero(z_dim);
5540 let z_dim = gauge_dim.block_transform(0);
5541 z = kronecker_product(&z, &z_dim);
5542 }
5543 Some(z)
5544 }
5545 TensorBSplineIdentifiability::FrozenTransform { transform } => {
5546 if transform.nrows() != total_cols {
5547 crate::bail_dim_basis!(
5548 "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
5549 total_cols,
5550 transform.nrows()
5551 );
5552 }
5553 Some(transform.clone())
5554 }
5555 };
5556
5557 if let Some(z) = z_opt.as_ref() {
5558 let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
5559 let dense = dense_design.as_mut().ok_or_else(|| {
5560 BasisError::InvalidInput(
5561 "tensor identifiability transform requires a realized basis".to_string(),
5562 )
5563 })?;
5564 let restricted_design = gauge.restrict_design(dense);
5565 *dense = restricted_design;
5566 candidates = candidates
5567 .into_iter()
5568 .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
5569 let matrix = gauge.restrict_penalty(&candidate.matrix);
5570 let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
5578 Ok(PenaltyCandidate {
5579 nullspace_dim_hint: candidate.nullspace_dim_hint,
5580 matrix,
5581 source: candidate.source,
5582 normalization_scale: candidate.normalization_scale * c_new,
5583 kronecker_factors: None,
5589 op: candidate.op.clone(),
5590 })
5591 })
5592 .collect::<Result<Vec<_>, _>>()?;
5593 }
5594
5595 let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
5596 filter_active_penalty_candidates_with_ops(candidates)?;
5597 let identifiability_is_none =
5598 matches!(spec.identifiability, TensorBSplineIdentifiability::None);
5599 let all_marginals_sparse = marginal_sparse.iter().all(Option::is_some);
5607 let design = if let Some(dense_design) = dense_design {
5608 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense_design))
5609 } else if identifiability_is_none && all_marginals_sparse {
5610 let sparse_marginals: Vec<&SparseColMat<usize, f64>> = marginal_sparse
5616 .iter()
5617 .map(|m| m.as_ref().expect("all_marginals_sparse just verified"))
5618 .collect();
5619 let sparse_design = tensor_product_design_from_sparse_marginals(&sparse_marginals)?;
5620 DesignMatrix::Sparse(gam_linalg::matrix::SparseDesignMatrix::new(sparse_design))
5621 } else {
5622 let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
5623 .iter()
5624 .map(|m| Arc::new(m.clone()))
5625 .collect();
5626 let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
5627 BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
5628 })?;
5629 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
5630 };
5631
5632 Ok(BasisBuildResult {
5633 design,
5634 penalties,
5635 nullspace_dims,
5636 penaltyinfo,
5637 ops,
5638 null_eigenvectors,
5639 joint_null_rotation: None,
5640 metadata: BasisMetadata::TensorBSpline {
5641 feature_cols: feature_cols.to_vec(),
5642 knots: marginal_knots,
5643 degrees: marginal_degrees,
5644 periods: marginal_effective_periods,
5651 is_cr: marginal_is_cr_flags,
5652 identifiability_transform: z_opt,
5653 },
5654 kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None)
5655 && matches!(
5656 spec.penalty_decomposition,
5657 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
5658 ) {
5659 Some(KroneckerFactoredBasis::new(
5660 marginal_designs,
5661 kronecker_marginal_penalties,
5662 marginalnum_basis.clone(),
5663 spec.double_penalty,
5664 ))
5665 } else {
5666 None
5667 },
5668 })
5669}
5670
5671pub fn tensor_product_design_from_marginals(
5672 marginal_designs: &[Array2<f64>],
5673) -> Result<Array2<f64>, BasisError> {
5674 if marginal_designs.is_empty() {
5675 crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
5676 }
5677 let n = marginal_designs[0].nrows();
5678 for (i, b) in marginal_designs.iter().enumerate().skip(1) {
5679 if b.nrows() != n {
5680 crate::bail_dim_basis!(
5681 "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
5682 b.nrows()
5683 );
5684 }
5685 }
5686 let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
5687 acc.checked_mul(b.ncols())
5688 .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
5689 })?;
5690 use ndarray::parallel::prelude::*;
5696 use rayon::iter::{IntoParallelIterator, ParallelIterator};
5697 let mut design = Array2::<f64>::zeros((n, total_cols));
5698 design
5699 .axis_chunks_iter_mut(ndarray::Axis(0), 1024)
5700 .into_par_iter()
5701 .enumerate()
5702 .for_each(|(chunk_idx, mut block)| {
5703 let row_offset = chunk_idx * 1024;
5704 let mut cur = Vec::<f64>::with_capacity(total_cols);
5706 let mut next = Vec::<f64>::with_capacity(total_cols);
5707 for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
5708 let i = row_offset + local_i;
5709 cur.clear();
5710 cur.push(1.0);
5711 for b in marginal_designs {
5712 let q = b.ncols();
5713 next.clear();
5714 next.resize(cur.len() * q, 0.0);
5715 let b_row = b.row(i);
5719 let b_slice = b_row
5720 .as_slice()
5721 .expect("Array2 row from outer_iter is contiguous");
5722 for (a_idx, &aval) in cur.iter().enumerate() {
5723 let off = a_idx * q;
5724 let dst = &mut next[off..off + q];
5725 for col in 0..q {
5726 dst[col] = aval * b_slice[col];
5727 }
5728 }
5729 std::mem::swap(&mut cur, &mut next);
5730 }
5731 let out_slice = out_row
5736 .as_slice_mut()
5737 .expect("design row is contiguous in C-major Array2");
5738 out_slice.copy_from_slice(&cur);
5739 }
5740 });
5741 Ok(design)
5742}
5743
5744pub fn build_random_effect_block(
5745 data: ArrayView2<'_, f64>,
5746 spec: &RandomEffectTermSpec,
5747) -> Result<RandomEffectBlock, BasisError> {
5748 let n = data.nrows();
5749 let p = data.ncols();
5750 if spec.feature_col >= p {
5751 crate::bail_dim_basis!(
5752 "random-effect term '{}' feature column {} out of bounds for {} columns",
5753 spec.name,
5754 spec.feature_col,
5755 p
5756 );
5757 }
5758
5759 let col = data.column(spec.feature_col);
5760 if col.iter().any(|v| !v.is_finite()) {
5761 crate::bail_invalid_basis!(
5762 "random-effect term '{}' contains non-finite group values",
5763 spec.name
5764 );
5765 }
5766
5767 let kept_levels: Vec<u64> = if let Some(levels) = spec.frozen_levels.as_ref() {
5768 if levels.is_empty() {
5769 crate::bail_invalid_basis!(
5770 "random-effect term '{}' has empty frozen_levels",
5771 spec.name
5772 );
5773 }
5774 levels.clone()
5775 } else {
5776 let mut seen = BTreeSet::<u64>::new();
5777 let mut levels = Vec::<u64>::new();
5778 for &v in col {
5779 let bits = v.to_bits();
5780 if seen.insert(bits) {
5781 levels.push(bits);
5782 }
5783 }
5784 if levels.is_empty() {
5785 crate::bail_invalid_basis!("random-effect term '{}' has no observed levels", spec.name);
5786 }
5787 let start_idx = if spec.drop_first_level && levels.len() > 1 {
5788 1usize
5789 } else {
5790 0usize
5791 };
5792 levels[start_idx..].to_vec()
5793 };
5794
5795 if kept_levels.is_empty() {
5796 crate::bail_invalid_basis!(
5797 "random-effect term '{}' drops all levels; keep at least one level",
5798 spec.name
5799 );
5800 }
5801
5802 let q = kept_levels.len();
5803 let mut level_to_col = BTreeMap::<u64, usize>::new();
5804 for (idx, &bits) in kept_levels.iter().enumerate() {
5805 if level_to_col.insert(bits, idx).is_some() {
5806 crate::bail_invalid_basis!(
5807 "random-effect term '{}' has duplicate frozen level bits {bits}",
5808 spec.name
5809 );
5810 }
5811 }
5812 let mut group_ids = Vec::with_capacity(n);
5813 for &v in col {
5814 let bits = v.to_bits();
5815 group_ids.push(level_to_col.get(&bits).copied());
5816 }
5817
5818 Ok(RandomEffectBlock {
5819 name: spec.name.clone(),
5820 group_ids,
5821 num_groups: q,
5822 kept_levels,
5823 })
5824}
5825
5826impl SmoothDesign {
5827 pub fn map_term_coefficients(
5830 unconstrained: &Array1<f64>,
5831 shape: ShapeConstraint,
5832 ) -> Result<Array1<f64>, BasisError> {
5833 if unconstrained.is_empty() {
5834 crate::bail_invalid_basis!("unconstrained coefficient vector cannot be empty");
5835 }
5836 let mapped = match shape {
5837 ShapeConstraint::None => unconstrained.clone(),
5838 ShapeConstraint::MonotoneIncreasing => cumulative_exp(unconstrained, 1.0),
5839 ShapeConstraint::MonotoneDecreasing => cumulative_exp(unconstrained, -1.0),
5840 ShapeConstraint::Convex => second_cumulative_exp(unconstrained, 1.0),
5841 ShapeConstraint::Concave => second_cumulative_exp(unconstrained, -1.0),
5842 };
5843 Ok(mapped)
5844 }
5845}
5846
5847pub struct LocalSmoothTermBuild {
5848 pub dim: usize,
5849 pub design: DesignMatrix,
5850 pub penalties: Vec<Array2<f64>>,
5851 pub ops: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>>,
5852 pub nullspaces: Vec<usize>,
5853 pub null_eigenvectors: Vec<Option<Array2<f64>>>,
5861 pub joint_null_rotation: Option<crate::basis::JointNullRotation>,
5868 pub penaltyinfo: Vec<PenaltyInfo>,
5869 pub pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
5870 pub metadata: BasisMetadata,
5871 pub linear_constraints: Option<LinearInequalityConstraints>,
5872 pub box_reparam: bool,
5873 pub kronecker_factored: Option<KroneckerFactoredBasis>,
5874}
5875
5876#[derive(Clone)]
5877pub struct PcaScoresMemmapDesignOperator {
5878 mmap: Arc<memmap2::Mmap>,
5879 data_offset: usize,
5880 nrows: usize,
5881 ncols: usize,
5882 chunk_size: usize,
5883}
5884
5885impl PcaScoresMemmapDesignOperator {
5886 fn open(path: PathBuf, chunk_size: usize) -> Result<Self, BasisError> {
5887 let file = File::open(&path).map_err(|err| {
5888 BasisError::InvalidInput(format!(
5889 "failed to open lazy Pca .npy scores '{}': {err}",
5890 path.display()
5891 ))
5892 })?;
5893 let mmap = unsafe {
5899 memmap2::Mmap::map(&file).map_err(|err| {
5900 BasisError::InvalidInput(format!(
5901 "failed to memmap lazy Pca .npy scores '{}': {err}",
5902 path.display()
5903 ))
5904 })?
5905 };
5906 let (data_offset, nrows, ncols) = parse_f64_2d_npy_header(&mmap, &path)?;
5907 let expected = data_offset
5908 .checked_add(nrows.saturating_mul(ncols).saturating_mul(8))
5909 .ok_or_else(|| {
5910 BasisError::InvalidInput(format!(
5911 "lazy Pca .npy scores '{}' shape is too large",
5912 path.display()
5913 ))
5914 })?;
5915 if mmap.len() < expected {
5916 crate::bail_invalid_basis!(
5917 "lazy Pca .npy scores '{}' is truncated: header expects {} bytes, file has {}",
5918 path.display(),
5919 expected,
5920 mmap.len()
5921 );
5922 }
5923 Ok(Self {
5924 mmap: Arc::new(mmap),
5925 data_offset,
5926 nrows,
5927 ncols,
5928 chunk_size: chunk_size.max(1),
5929 })
5930 }
5931
5932 fn value(&self, row: usize, col: usize) -> f64 {
5933 let offset = self.data_offset + (row * self.ncols + col) * 8;
5934 let mut bytes = [0_u8; 8];
5935 bytes.copy_from_slice(&self.mmap[offset..offset + 8]);
5936 f64::from_le_bytes(bytes)
5937 }
5938
5939 fn chunk_rows(&self) -> usize {
5940 self.chunk_size.min(self.nrows.max(1))
5941 }
5942}
5943
5944impl LinearOperator for PcaScoresMemmapDesignOperator {
5945 fn nrows(&self) -> usize {
5946 self.nrows
5947 }
5948
5949 fn ncols(&self) -> usize {
5950 self.ncols
5951 }
5952
5953 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5954 assert_eq!(
5955 vector.len(),
5956 self.ncols,
5957 "lazy Pca apply vector length mismatch"
5958 );
5959 let mut out = Array1::<f64>::zeros(self.nrows);
5960 for start in (0..self.nrows).step_by(self.chunk_rows()) {
5961 let end = (start + self.chunk_rows()).min(self.nrows);
5962 for row in start..end {
5963 let mut acc = 0.0;
5964 for col in 0..self.ncols {
5965 acc += self.value(row, col) * vector[col];
5966 }
5967 out[row] = acc;
5968 }
5969 }
5970 out
5971 }
5972
5973 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5974 assert_eq!(
5975 vector.len(),
5976 self.nrows,
5977 "lazy Pca apply_transpose vector length mismatch"
5978 );
5979 let mut out = Array1::<f64>::zeros(self.ncols);
5980 for start in (0..self.nrows).step_by(self.chunk_rows()) {
5981 let end = (start + self.chunk_rows()).min(self.nrows);
5982 for row in start..end {
5983 let scale = vector[row];
5984 if scale == 0.0 {
5985 continue;
5986 }
5987 for col in 0..self.ncols {
5988 out[col] += scale * self.value(row, col);
5989 }
5990 }
5991 }
5992 out
5993 }
5994
5995 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5996 if weights.len() != self.nrows {
5997 return Err(format!(
5998 "lazy Pca diag_xtw_x weight length mismatch: weights={}, nrows={}",
5999 weights.len(),
6000 self.nrows
6001 ));
6002 }
6003 let mut gram = Array2::<f64>::zeros((self.ncols, self.ncols));
6004 for start in (0..self.nrows).step_by(self.chunk_rows()) {
6005 let end = (start + self.chunk_rows()).min(self.nrows);
6006 for row in start..end {
6007 let w = weights[row];
6008 if w == 0.0 {
6009 continue;
6010 }
6011 for a in 0..self.ncols {
6012 let xa = self.value(row, a);
6013 if xa == 0.0 {
6014 continue;
6015 }
6016 for b in a..self.ncols {
6017 gram[[a, b]] += w * xa * self.value(row, b);
6018 }
6019 }
6020 }
6021 }
6022 for a in 0..self.ncols {
6023 for b in 0..a {
6024 gram[[a, b]] = gram[[b, a]];
6025 }
6026 }
6027 Ok(gram)
6028 }
6029
6030 fn apply_weighted_normal(
6031 &self,
6032 weights: &Array1<f64>,
6033 vector: &Array1<f64>,
6034 penalty: Option<&Array2<f64>>,
6035 ridge: f64,
6036 ) -> Array1<f64> {
6037 assert_eq!(
6038 weights.len(),
6039 self.nrows,
6040 "lazy Pca weighted-normal weight mismatch"
6041 );
6042 assert_eq!(
6043 vector.len(),
6044 self.ncols,
6045 "lazy Pca weighted-normal vector mismatch"
6046 );
6047 let mut out = Array1::<f64>::zeros(self.ncols);
6048 for start in (0..self.nrows).step_by(self.chunk_rows()) {
6049 let end = (start + self.chunk_rows()).min(self.nrows);
6050 for row in start..end {
6051 let w = weights[row].max(0.0);
6052 if w == 0.0 {
6053 continue;
6054 }
6055 let mut row_dot = 0.0;
6056 for col in 0..self.ncols {
6057 row_dot += self.value(row, col) * vector[col];
6058 }
6059 if row_dot == 0.0 {
6060 continue;
6061 }
6062 let scaled = w * row_dot;
6063 for col in 0..self.ncols {
6064 out[col] += scaled * self.value(row, col);
6065 }
6066 }
6067 }
6068 if let Some(pen) = penalty {
6069 out += &pen.dot(vector);
6070 }
6071 if ridge > 0.0 {
6072 out += &vector.mapv(|x| ridge * x);
6073 }
6074 out
6075 }
6076}
6077
6078impl DenseDesignOperator for PcaScoresMemmapDesignOperator {
6079 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
6080 if weights.len() != self.nrows || y.len() != self.nrows {
6081 return Err(format!(
6082 "lazy Pca compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
6083 weights.len(),
6084 y.len(),
6085 self.nrows
6086 ));
6087 }
6088 let mut out = Array1::<f64>::zeros(self.ncols);
6089 for start in (0..self.nrows).step_by(self.chunk_rows()) {
6090 let end = (start + self.chunk_rows()).min(self.nrows);
6091 for row in start..end {
6092 let scale = weights[row] * y[row];
6093 if scale == 0.0 {
6094 continue;
6095 }
6096 for col in 0..self.ncols {
6097 out[col] += scale * self.value(row, col);
6098 }
6099 }
6100 }
6101 Ok(out)
6102 }
6103
6104 fn row_chunk_into(
6105 &self,
6106 rows: Range<usize>,
6107 mut out: ArrayViewMut2<'_, f64>,
6108 ) -> Result<(), MatrixMaterializationError> {
6109 if rows.end > self.nrows || rows.start > rows.end {
6110 return Err(MatrixMaterializationError::MissingRowChunk {
6111 context: "lazy Pca row range out of bounds",
6112 });
6113 }
6114 if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols {
6115 return Err(MatrixMaterializationError::MissingRowChunk {
6116 context: "lazy Pca row_chunk_into shape mismatch",
6117 });
6118 }
6119 for (local, row) in (rows.start..rows.end).enumerate() {
6120 for col in 0..self.ncols {
6121 out[[local, col]] = self.value(row, col);
6122 }
6123 }
6124 Ok(())
6125 }
6126
6127 fn to_dense(&self) -> Array2<f64> {
6128 let mut out = Array2::<f64>::zeros((self.nrows, self.ncols));
6129 self.row_chunk_into(0..self.nrows, out.view_mut())
6130 .expect("lazy Pca full materialization failed");
6131 out
6132 }
6133}
6134
6135pub fn parse_f64_2d_npy_header(
6136 bytes: &[u8],
6137 path: &PathBuf,
6138) -> Result<(usize, usize, usize), BasisError> {
6139 if bytes.len() < 10 || &bytes[0..6] != b"\x93NUMPY" {
6140 crate::bail_invalid_basis!("lazy Pca scores '{}' is not a .npy file", path.display());
6141 }
6142 let major = bytes[6];
6143 let header_len = match major {
6144 1 => u16::from_le_bytes([bytes[8], bytes[9]]) as usize,
6145 2 | 3 => {
6146 if bytes.len() < 12 {
6147 crate::bail_invalid_basis!(
6148 "lazy Pca scores '{}' has a truncated .npy header",
6149 path.display()
6150 );
6151 }
6152 u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize
6153 }
6154 other => {
6155 crate::bail_invalid_basis!(
6156 "lazy Pca scores '{}' uses unsupported .npy version {}",
6157 path.display(),
6158 other
6159 );
6160 }
6161 };
6162 let header_start = if major == 1 { 10 } else { 12 };
6163 let data_offset = header_start + header_len;
6164 if bytes.len() < data_offset {
6165 crate::bail_invalid_basis!(
6166 "lazy Pca scores '{}' has a truncated .npy header",
6167 path.display()
6168 );
6169 }
6170 let header = std::str::from_utf8(&bytes[header_start..data_offset]).map_err(|err| {
6171 BasisError::InvalidInput(format!(
6172 "lazy Pca scores '{}' has a non-UTF8 .npy header: {err}",
6173 path.display()
6174 ))
6175 })?;
6176 if !(header.contains("'descr': '<f8'")
6177 || header.contains("\"descr\": \"<f8\"")
6178 || header.contains("'descr': '|f8'")
6179 || header.contains("\"descr\": \"|f8\""))
6180 {
6181 crate::bail_invalid_basis!(
6182 "lazy Pca scores '{}' must be float64 little-endian .npy",
6183 path.display()
6184 );
6185 }
6186 if header.contains("True") {
6187 crate::bail_invalid_basis!(
6188 "lazy Pca scores '{}' must be C-contiguous, not Fortran-ordered",
6189 path.display()
6190 );
6191 }
6192 let shape_pos = header.find("shape").ok_or_else(|| {
6193 BasisError::InvalidInput(format!(
6194 "lazy Pca scores '{}' .npy header is missing shape",
6195 path.display()
6196 ))
6197 })?;
6198 let open = header[shape_pos..].find('(').ok_or_else(|| {
6199 BasisError::InvalidInput(format!(
6200 "lazy Pca scores '{}' .npy header has malformed shape",
6201 path.display()
6202 ))
6203 })? + shape_pos;
6204 let close = header[open..].find(')').ok_or_else(|| {
6205 BasisError::InvalidInput(format!(
6206 "lazy Pca scores '{}' .npy header has malformed shape",
6207 path.display()
6208 ))
6209 })? + open;
6210 let dims = header[open + 1..close]
6211 .split(',')
6212 .map(str::trim)
6213 .filter(|part| !part.is_empty())
6214 .map(|part| part.parse::<usize>())
6215 .collect::<Result<Vec<_>, _>>()
6216 .map_err(|err| {
6217 BasisError::InvalidInput(format!(
6218 "lazy Pca scores '{}' .npy shape is not integral: {err}",
6219 path.display()
6220 ))
6221 })?;
6222 if dims.len() != 2 {
6223 crate::bail_invalid_basis!(
6224 "lazy Pca scores '{}' must have shape (N, K), got {:?}",
6225 path.display(),
6226 dims
6227 );
6228 }
6229 Ok((data_offset, dims[0], dims[1]))
6230}
6231
6232pub fn pca_center_mean(x: ArrayView2<'_, f64>) -> Result<Array1<f64>, BasisError> {
6233 if x.nrows() == 0 {
6234 crate::bail_invalid_basis!("Pca basis requires at least one row to compute center mean");
6235 }
6236 let mut mean = Array1::<f64>::zeros(x.ncols());
6237 for row in x.rows() {
6238 mean += &row;
6239 }
6240 mean.mapv_inplace(|v| v / x.nrows() as f64);
6241 Ok(mean)
6242}
6243
6244pub fn build_pca_smooth_basis(
6245 data: ArrayView2<'_, f64>,
6246 feature_cols: &[usize],
6247 basis_matrix: &Array2<f64>,
6248 centered: bool,
6249 smooth_penalty: f64,
6250 center_mean: Option<&Array1<f64>>,
6251 pca_basis_path: Option<&PathBuf>,
6252 chunk_size: usize,
6253) -> Result<BasisBuildResult, BasisError> {
6254 if let Some(path) = pca_basis_path {
6255 let op = PcaScoresMemmapDesignOperator::open(path.clone(), chunk_size)?;
6256 if op.nrows != data.nrows() {
6257 crate::bail_dim_basis!(
6258 "lazy Pca scores row mismatch: .npy has {}, data has {}",
6259 op.nrows,
6260 data.nrows()
6261 );
6262 }
6263 let k = op.ncols;
6264 let mut penalty = Array2::<f64>::eye(k);
6265 penalty.mapv_inplace(|v| v * smooth_penalty);
6266 let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6267 filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6268 matrix: penalty,
6269 nullspace_dim_hint: 0,
6270 source: PenaltySource::Other("PcaRidge".to_string()),
6271 normalization_scale: 1.0,
6272 kronecker_factors: None,
6273 op: None,
6274 }])?;
6275 return Ok(BasisBuildResult {
6276 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op))),
6277 penalties,
6278 nullspace_dims,
6279 penaltyinfo,
6280 ops,
6281 null_eigenvectors,
6282 joint_null_rotation: None,
6283 metadata: BasisMetadata::Pca {
6284 feature_cols: feature_cols.to_vec(),
6285 basis_matrix: basis_matrix.clone(),
6286 centered,
6287 smooth_penalty,
6288 center_mean: center_mean.cloned(),
6289 pca_basis_path: Some(path.clone()),
6290 chunk_size: chunk_size.max(1),
6291 },
6292 kronecker_factored: None,
6293 });
6294 }
6295 if basis_matrix.nrows() != feature_cols.len() {
6296 crate::bail_dim_basis!(
6297 "Pca basis row mismatch: basis rows={}, feature columns={}",
6298 basis_matrix.nrows(),
6299 feature_cols.len()
6300 );
6301 }
6302 let mut x = select_columns(data, feature_cols)?;
6303 let mean = if centered {
6304 match center_mean {
6305 Some(mean) => mean.clone(),
6306 None => pca_center_mean(x.view())?,
6307 }
6308 } else {
6309 Array1::<f64>::zeros(feature_cols.len())
6310 };
6311 if centered {
6312 for mut row in x.rows_mut() {
6313 row -= &mean;
6314 }
6315 }
6316 let design = fast_ab(&x, basis_matrix);
6317 let k = basis_matrix.ncols();
6318 let mut penalty = Array2::<f64>::eye(k);
6319 penalty.mapv_inplace(|v| v * smooth_penalty);
6320 let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
6321 filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
6322 matrix: penalty,
6323 nullspace_dim_hint: 0,
6324 source: PenaltySource::Other("PcaRidge".to_string()),
6325 normalization_scale: 1.0,
6326 kronecker_factors: None,
6327 op: None,
6328 }])?;
6329 Ok(BasisBuildResult {
6330 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
6331 penalties,
6332 nullspace_dims,
6333 penaltyinfo,
6334 ops,
6335 null_eigenvectors,
6336 joint_null_rotation: None,
6337 metadata: BasisMetadata::Pca {
6338 feature_cols: feature_cols.to_vec(),
6339 basis_matrix: basis_matrix.clone(),
6340 centered,
6341 smooth_penalty,
6342 center_mean: centered.then_some(mean),
6343 pca_basis_path: None,
6344 chunk_size: chunk_size.max(1),
6345 },
6346 kronecker_factored: None,
6347 })
6348}
6349
6350pub fn defer_inner_model_centering_to_factor_level_wrapper(basis: &mut SmoothBasisSpec) {
6366 if let SmoothBasisSpec::BSpline1D { spec, .. } = basis
6367 && matches!(
6368 spec.identifiability,
6369 BSplineIdentifiability::WeightedSumToZero { .. }
6370 )
6371 {
6372 spec.identifiability = BSplineIdentifiability::None;
6373 }
6374}
6375
6376pub fn apply_by_variable_to_local_build(
6377 mut built: LocalSmoothTermBuild,
6378 data: ArrayView2<'_, f64>,
6379 by_col: usize,
6380 by: &ByVariableSpec,
6381 term_name: &str,
6382) -> Result<LocalSmoothTermBuild, BasisError> {
6383 if by_col >= data.ncols() {
6384 crate::bail_dim_basis!(
6385 "by-variable smooth term '{term_name}' references column {by_col}, but data has {} columns",
6386 data.ncols()
6387 );
6388 }
6389 let weights = match by {
6390 ByVariableSpec::Numeric => data.column(by_col).to_owned(),
6391 ByVariableSpec::Level { value_bits, .. } => data.column(by_col).mapv(|value| {
6392 if value.to_bits() == *value_bits {
6393 1.0
6394 } else {
6395 0.0
6396 }
6397 }),
6398 };
6399 if weights.iter().any(|value| !value.is_finite()) {
6400 crate::bail_invalid_basis!(
6401 "by-variable smooth term '{term_name}' has non-finite by-column values"
6402 );
6403 }
6404
6405 let mut dense = built
6406 .design
6407 .try_to_dense_by_chunks("by-variable smooth row gating")
6408 .map_err(BasisError::InvalidInput)?;
6409 for (mut row, &weight) in dense.rows_mut().into_iter().zip(weights.iter()) {
6410 row.mapv_inplace(|value| value * weight);
6411 }
6412 built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
6413 built.kronecker_factored = None;
6414 Ok(built)
6415}
6416
6417pub fn build_by_smooth_local(
6428 data: ArrayView2<'_, f64>,
6429 term: &SmoothTermSpec,
6430 smooth: &SmoothBasisSpec,
6431 by_kind: &ByVarKind,
6432 workspace: &mut crate::basis::BasisWorkspace,
6433) -> Result<LocalSmoothTermBuild, BasisError> {
6434 let inner_term = SmoothTermSpec {
6435 name: term.name.clone(),
6436 basis: (*smooth).clone(),
6437 shape: term.shape,
6438 joint_null_rotation: None,
6439 };
6440 let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6441
6442 match by_kind {
6443 ByVarKind::Numeric { feature_col } => {
6444 let inner_meta = inner.metadata.clone();
6445 let mut built = apply_by_variable_to_local_build(
6446 inner,
6447 data,
6448 *feature_col,
6449 &ByVariableSpec::Numeric,
6450 &term.name,
6451 )?;
6452 built.metadata = BasisMetadata::BySmooth {
6453 inner: Box::new(inner_meta),
6454 by_col: *feature_col,
6455 levels: None,
6456 ordered: false,
6457 };
6458 Ok(built)
6459 }
6460 ByVarKind::Factor {
6461 feature_col,
6462 frozen_levels,
6463 ordered,
6464 } => {
6465 let level_bits: Vec<u64> = if let Some(fl) = frozen_levels {
6468 fl.clone()
6469 } else {
6470 let col = data.column(*feature_col);
6471 let mut seen = BTreeSet::<u64>::new();
6472 for &v in col.iter() {
6473 if v.is_finite() {
6474 seen.insert(v.to_bits());
6475 }
6476 }
6477 seen.into_iter().collect()
6478 };
6479 let n_levels = level_bits.len();
6480 if n_levels == 0 {
6481 crate::bail_invalid_basis!(
6482 "by-factor smooth term '{}': factor column {} has no observed levels",
6483 term.name,
6484 feature_col
6485 );
6486 }
6487 let p = inner.dim;
6488 let q = n_levels * p;
6489 let n = data.nrows();
6490
6491 let inner_dense = inner
6492 .design
6493 .try_to_dense_by_chunks("by-factor smooth design gating")
6494 .map_err(BasisError::InvalidInput)?;
6495
6496 let mut combined = Array2::<f64>::zeros((n, q));
6498 for (lvl_idx, &bits) in level_bits.iter().enumerate() {
6499 let col_start = lvl_idx * p;
6500 for row in 0..n {
6501 if data[[row, *feature_col]].to_bits() == bits {
6502 combined
6503 .slice_mut(s![row, col_start..col_start + p])
6504 .assign(&inner_dense.row(row));
6505 }
6506 }
6507 }
6508
6509 let inner_meta = inner.metadata.clone();
6521 let n_penalties = inner.penalties.len();
6522 let n_blocks = n_penalties.saturating_mul(n_levels);
6523 let mut penalties = Vec::<Array2<f64>>::with_capacity(n_blocks);
6524 let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(n_blocks);
6525 let mut nullspaces = Vec::<usize>::with_capacity(n_blocks);
6526 for (pen_pos, s_inner) in inner.penalties.iter().enumerate() {
6527 for lvl in 0..n_levels {
6528 let off = lvl * p;
6529 let mut s_big = Array2::<f64>::zeros((q, q));
6530 s_big
6531 .slice_mut(s![off..off + p, off..off + p])
6532 .assign(s_inner);
6533 let (s_big, scale) = normalize_penalty_in_constrained_space(&s_big);
6534 let mut info = inner.penaltyinfo[pen_pos].clone();
6535 info.original_index = pen_pos * n_levels + lvl;
6538 info.normalization_scale *= scale;
6539 info.kronecker_factors = None;
6542 penalties.push(s_big);
6543 penaltyinfo.push(info);
6544 nullspaces.push(inner.nullspaces[pen_pos]);
6545 }
6546 }
6547
6548 let null_eigenvectors = vec![None; penalties.len()];
6549 let ops = vec![None; penalties.len()];
6550
6551 Ok(LocalSmoothTermBuild {
6552 dim: q,
6553 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(combined)),
6554 penalties,
6555 ops,
6556 nullspaces,
6557 null_eigenvectors,
6558 joint_null_rotation: None,
6559 penaltyinfo,
6560 pre_dropped_penaltyinfo: inner.pre_dropped_penaltyinfo,
6561 metadata: BasisMetadata::BySmooth {
6562 inner: Box::new(inner_meta),
6563 by_col: *feature_col,
6564 levels: Some(level_bits),
6565 ordered: *ordered,
6566 },
6567 linear_constraints: None,
6568 box_reparam: false,
6569 kronecker_factored: None,
6570 })
6571 }
6572 }
6573}
6574
6575pub fn ensure_by_variable_specs_match(
6576 kind: &BySmoothKind,
6577 by: &ByVariableSpec,
6578 term_name: &str,
6579) -> Result<(), BasisError> {
6580 match (kind, by) {
6581 (BySmoothKind::Numeric, ByVariableSpec::Numeric) => Ok(()),
6582 (BySmoothKind::Level { level_bits }, ByVariableSpec::Level { value_bits, .. })
6583 if level_bits == value_bits =>
6584 {
6585 Ok(())
6586 }
6587 _ => Err(BasisError::InvalidInput(format!(
6588 "by-variable smooth term '{term_name}' has inconsistent by-variable specifications"
6589 ))),
6590 }
6591}
6592
6593pub fn build_factor_smooth(
6621 data: ArrayView2<'_, f64>,
6622 spec: &FactorSmoothSpec,
6623 term_name: &str,
6624 workspace: &mut crate::basis::BasisWorkspace,
6625) -> Result<LocalSmoothTermBuild, BasisError> {
6626 if spec.continuous_cols.len() != 1 {
6627 crate::bail_invalid_basis!(
6628 "factor smooth term '{}' currently supports exactly one continuous covariate; found {}",
6629 term_name,
6630 spec.continuous_cols.len()
6631 );
6632 }
6633 let feature_col = spec.continuous_cols[0];
6634 let group_col = spec.group_col;
6635 if feature_col >= data.ncols() || group_col >= data.ncols() {
6636 crate::bail_dim_basis!(
6637 "factor smooth term '{}' references columns ({}, {}) out of bounds for {} columns",
6638 term_name,
6639 feature_col,
6640 group_col,
6641 data.ncols()
6642 );
6643 }
6644
6645 if matches!(spec.flavour, FactorSmoothFlavour::Sz) {
6648 let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6649 let inner = SmoothBasisSpec::BSpline1D {
6650 feature_col,
6651 spec: factor_smooth_marginal_for_replay(&spec.marginal),
6652 };
6653 let sz_term = SmoothTermSpec {
6654 name: term_name.to_string(),
6655 basis: SmoothBasisSpec::FactorSumToZero {
6656 inner: Box::new(inner),
6657 by_col: group_col,
6658 levels: levels.clone(),
6659 frozen_global_orthogonality: None,
6660 },
6661 shape: ShapeConstraint::None,
6662 joint_null_rotation: None,
6663 };
6664 let mut built = build_single_local_smooth_term(data, &sz_term, workspace)?;
6665 let (knots, degree, periodic, marginal_is_cr) = match &built.metadata {
6686 BasisMetadata::BSpline1D {
6687 knots,
6688 periodic,
6689 degree,
6690 ..
6691 } => (
6692 knots.clone(),
6693 degree.unwrap_or(spec.marginal.degree),
6694 *periodic,
6695 false,
6696 ),
6697 BasisMetadata::CubicRegression1D { knots, .. } => {
6698 (knots.clone(), spec.marginal.degree, None, true)
6699 }
6700 other => {
6701 crate::bail_invalid_basis!(
6702 "sz factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6703 term_name,
6704 other
6705 );
6706 }
6707 };
6708 built.metadata = BasisMetadata::FactorSmooth {
6709 continuous_cols: spec.continuous_cols.clone(),
6710 group_col,
6711 knots,
6712 degree,
6713 periodic,
6714 group_levels: levels,
6715 flavour: "sz".to_string(),
6716 marginal_is_cr,
6717 };
6718 return Ok(built);
6719 }
6720
6721 let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
6722 let n_levels = levels.len();
6723 if n_levels < 2 {
6724 crate::bail_invalid_basis!(
6725 "factor smooth term '{}' requires at least two grouping levels; found {}",
6726 term_name,
6727 n_levels
6728 );
6729 }
6730
6731 let use_per_dim_null = matches!(
6739 &spec.flavour,
6740 FactorSmoothFlavour::Fs { m_null_penalty_orders }
6741 if m_null_penalty_orders.iter().copied().max().unwrap_or(0) >= 1
6742 );
6743
6744 let mut marginal_spec = factor_smooth_marginal_for_replay(&spec.marginal);
6750 if use_per_dim_null {
6751 marginal_spec.double_penalty = false;
6752 }
6753 let inner_term = SmoothTermSpec {
6754 name: format!("{term_name}::marginal"),
6755 basis: SmoothBasisSpec::BSpline1D {
6756 feature_col,
6757 spec: marginal_spec,
6758 },
6759 shape: ShapeConstraint::None,
6760 joint_null_rotation: None,
6761 };
6762 let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
6763 let mut base = inner
6764 .design
6765 .try_to_dense_by_chunks("factor smooth marginal")
6766 .map_err(BasisError::InvalidInput)?;
6767 if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6768 let center = match &inner.metadata {
6778 BasisMetadata::BSpline1D { knots, .. } if !knots.is_empty() => {
6779 0.5 * (knots[0] + knots[knots.len() - 1])
6780 }
6781 _ => 0.0,
6782 };
6783 let mut linear = Array2::<f64>::ones((data.nrows(), 2));
6784 linear
6785 .column_mut(1)
6786 .assign(&data.column(feature_col).mapv(|x| x - center));
6787 base = linear;
6788 }
6789 let n = base.nrows();
6790 let p = base.ncols();
6791 let q = p * n_levels;
6792
6793 let mut dense = Array2::<f64>::zeros((n, q));
6796 for i in 0..n {
6797 let bits = data[[i, group_col]].to_bits();
6798 let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
6799 BasisError::InvalidInput(format!(
6800 "factor smooth term '{term_name}' saw an unseen grouping level at row {}",
6801 i + 1
6802 ))
6803 })?;
6804 let start = level_idx * p;
6805 dense
6806 .slice_mut(s![i, start..start + p])
6807 .assign(&base.row(i));
6808 }
6809
6810 let marginal_penalties: Vec<Array2<f64>> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6816 (0..p)
6817 .map(|j| {
6818 let mut s = Array2::<f64>::zeros((p, p));
6819 s[[j, j]] = 1.0;
6820 s
6821 })
6822 .collect()
6823 } else {
6824 inner.penalties.clone()
6825 };
6826 let marginal_penaltyinfo: Vec<PenaltyInfo> = if matches!(spec.flavour, FactorSmoothFlavour::Re)
6827 {
6828 (0..p)
6829 .map(|j| PenaltyInfo {
6830 source: PenaltySource::Primary,
6831 original_index: j,
6832 active: true,
6833 effective_rank: 1,
6834 dropped_reason: None,
6835 nullspace_dim_hint: p.saturating_sub(1),
6836 normalization_scale: 1.0,
6837 kronecker_factors: None,
6838 })
6839 .collect()
6840 } else {
6841 inner.penaltyinfo.clone()
6842 };
6843 if marginal_penalties.len() != marginal_penaltyinfo.len() {
6844 crate::bail_invalid_basis!(
6845 "internal factor-smooth penalty metadata mismatch for term '{}': penalties={}, infos={}",
6846 term_name,
6847 marginal_penalties.len(),
6848 marginal_penaltyinfo.len()
6849 );
6850 }
6851
6852 let mut penalties = Vec::<Array2<f64>>::with_capacity(marginal_penalties.len());
6853 let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(marginal_penalties.len());
6854 for (penalty_pos, s_inner) in marginal_penalties.iter().enumerate() {
6855 let mut s_big = Array2::<f64>::zeros((q, q));
6856 for level in 0..n_levels {
6857 let start = level * p;
6858 s_big
6859 .slice_mut(s![start..start + p, start..start + p])
6860 .assign(s_inner);
6861 }
6862 let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
6863 let mut info = marginal_penaltyinfo[penalty_pos].clone();
6864 info.original_index = penalty_pos;
6865 info.normalization_scale *= factor_smooth_scale;
6866 info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(n_levels);
6867 info.kronecker_factors = None;
6868 penalties.push(s_big);
6869 penaltyinfo.push(info);
6870 }
6871
6872 let mut nullspaces: Vec<usize> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
6873 vec![q.saturating_sub(n_levels); p]
6874 } else {
6875 inner
6876 .nullspaces
6877 .iter()
6878 .map(|ns| ns.saturating_mul(n_levels))
6879 .collect()
6880 };
6881
6882 if use_per_dim_null
6912 && let Some(Some(z)) = inner.null_eigenvectors.first()
6913 && z.nrows() == p
6914 {
6915 for k in 0..z.ncols() {
6916 let zk = z.column(k);
6921 let mut p_k = Array2::<f64>::zeros((p, p));
6922 for a in 0..p {
6923 for b in 0..p {
6924 p_k[[a, b]] = zk[a] * zk[b];
6925 }
6926 }
6927 let mut s_null = Array2::<f64>::zeros((q, q));
6928 for level in 0..n_levels {
6929 let start = level * p;
6930 s_null
6931 .slice_mut(s![start..start + p, start..start + p])
6932 .assign(&p_k);
6933 }
6934 let (s_null, null_scale) = normalize_penalty_in_constrained_space(&s_null);
6935 let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
6936 if null_block.rank > 0 {
6937 let original_index = penalties.len();
6938 penalties.push(null_block.sym_penalty);
6939 nullspaces.push(null_block.nullity);
6940 penaltyinfo.push(PenaltyInfo {
6941 source: PenaltySource::Primary,
6942 original_index,
6943 active: true,
6944 effective_rank: null_block.rank,
6945 dropped_reason: None,
6946 nullspace_dim_hint: null_block.nullity,
6947 normalization_scale: null_scale,
6948 kronecker_factors: None,
6949 });
6950 }
6951 }
6952 }
6953 let null_eigenvectors = crate::basis::recompute_null_eigenvectors(&penalties)?;
6954 let joint_null_rotation = crate::basis::compute_joint_null_rotation(&penalties)?;
6955
6956 let (knots, degree, periodic) = match &inner.metadata {
6959 BasisMetadata::BSpline1D {
6960 knots,
6961 periodic,
6962 degree,
6963 ..
6964 } => (
6965 knots.clone(),
6966 degree.unwrap_or(spec.marginal.degree),
6967 *periodic,
6968 ),
6969 other => {
6970 crate::bail_invalid_basis!(
6971 "factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
6972 term_name,
6973 other
6974 );
6975 }
6976 };
6977 let flavour_tag = match &spec.flavour {
6978 FactorSmoothFlavour::Fs { .. } => "fs",
6979 FactorSmoothFlavour::Sz => "sz",
6980 FactorSmoothFlavour::Re => "re",
6981 }
6982 .to_string();
6983 let metadata = BasisMetadata::FactorSmooth {
6984 continuous_cols: spec.continuous_cols.clone(),
6985 group_col,
6986 knots,
6987 degree,
6988 periodic,
6989 group_levels: levels,
6990 flavour: flavour_tag,
6991 marginal_is_cr: false,
6994 };
6995
6996 let ops = vec![None; penalties.len()];
6997 Ok(LocalSmoothTermBuild {
6998 dim: q,
6999 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense)),
7000 penalties,
7001 ops,
7002 nullspaces,
7003 null_eigenvectors,
7004 joint_null_rotation,
7005 penaltyinfo,
7006 pre_dropped_penaltyinfo: Vec::new(),
7007 metadata,
7008 linear_constraints: None,
7009 box_reparam: false,
7010 kronecker_factored: None,
7011 })
7012}
7013
7014pub fn resolve_factor_smooth_levels(
7018 data: ArrayView2<'_, f64>,
7019 group_col: usize,
7020 spec: &FactorSmoothSpec,
7021 term_name: &str,
7022) -> Result<Vec<u64>, BasisError> {
7023 if let Some(frozen) = &spec.group_frozen_levels {
7024 if frozen.is_empty() {
7025 crate::bail_invalid_basis!(
7026 "factor smooth term '{}' has an empty frozen level list",
7027 term_name
7028 );
7029 }
7030 return Ok(frozen.clone());
7031 }
7032 let mut bits: Vec<u64> = data.column(group_col).iter().map(|v| v.to_bits()).collect();
7033 bits.sort_by(|a, b| {
7034 f64::from_bits(*a)
7035 .partial_cmp(&f64::from_bits(*b))
7036 .unwrap_or(std::cmp::Ordering::Equal)
7037 });
7038 bits.dedup();
7039 Ok(bits)
7040}
7041
7042pub fn factor_smooth_marginal_for_replay(marginal: &BSplineBasisSpec) -> BSplineBasisSpec {
7049 let mut m = marginal.clone();
7050 m.identifiability = BSplineIdentifiability::None;
7051 m
7052}
7053
7054pub fn build_single_local_smooth_term(
7055 data: ArrayView2<'_, f64>,
7056 term: &SmoothTermSpec,
7057 workspace: &mut crate::basis::BasisWorkspace,
7058) -> Result<LocalSmoothTermBuild, BasisError> {
7059 if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
7060 crate::bail_invalid_basis!(
7061 "ShapeConstraint::{:?} is unsupported for term '{}'",
7062 term.shape,
7063 term.name
7064 );
7065 }
7066 if let SmoothBasisSpec::ByVariable {
7067 inner,
7068 by_col,
7069 kind,
7070 by,
7071 } = &term.basis
7072 {
7073 ensure_by_variable_specs_match(kind, by, &term.name)?;
7074 let mut inner_basis = (**inner).clone();
7075 if matches!(by, ByVariableSpec::Level { .. }) {
7082 defer_inner_model_centering_to_factor_level_wrapper(&mut inner_basis);
7083 }
7084 let inner_term = SmoothTermSpec {
7085 name: term.name.clone(),
7086 basis: inner_basis,
7087 shape: term.shape,
7088 joint_null_rotation: None,
7089 };
7090 let built = build_single_local_smooth_term(data, &inner_term, workspace)?;
7091 return apply_by_variable_to_local_build(built, data, *by_col, by, &term.name);
7092 }
7093
7094 if let SmoothBasisSpec::BySmooth { smooth, by_kind } = &term.basis {
7097 return build_by_smooth_local(data, term, smooth, by_kind, workspace);
7098 }
7099
7100 let mut shape_axis_col: Option<usize> = None;
7101 let mut built: BasisBuildResult = match &term.basis {
7102 SmoothBasisSpec::FactorSumToZero {
7103 inner,
7104 by_col,
7105 levels,
7106 ..
7107 } => {
7108 if *by_col >= data.ncols() {
7109 crate::bail_dim_basis!(
7110 "term '{}' by column {} out of bounds for {} columns",
7111 term.name,
7112 by_col,
7113 data.ncols()
7114 );
7115 }
7116 if levels.len() < 2 {
7117 crate::bail_invalid_basis!(
7118 "sum-to-zero factor smooth term '{}' requires at least two levels",
7119 term.name
7120 );
7121 }
7122 if term.shape != ShapeConstraint::None {
7123 crate::bail_invalid_basis!(
7124 "ShapeConstraint::{:?} is unsupported for sum-to-zero factor smooth term '{}'",
7125 term.shape,
7126 term.name
7127 );
7128 }
7129 let inner_term = SmoothTermSpec {
7130 name: format!("{}::inner", term.name),
7131 basis: (**inner).clone(),
7132 shape: ShapeConstraint::None,
7133 joint_null_rotation: None,
7134 };
7135 let mut inner_built = build_single_local_smooth_term(data, &inner_term, workspace)?;
7136 let inner_null_eigenvectors = inner_built.null_eigenvectors.clone();
7140 let base = inner_built
7141 .design
7142 .try_to_dense_by_chunks("sum-to-zero factor smooth")
7143 .map_err(BasisError::InvalidInput)?;
7144 let n = base.nrows();
7145 let p = base.ncols();
7146 let l_minus_one = levels.len() - 1;
7147 let mut dense = Array2::<f64>::zeros((n, p * l_minus_one));
7148 for i in 0..n {
7149 let bits = data[[i, *by_col]].to_bits();
7150 let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
7151 BasisError::InvalidInput(format!(
7152 "sum-to-zero factor smooth term '{}' saw an unseen level at row {}",
7153 term.name,
7154 i + 1
7155 ))
7156 })?;
7157 if level_idx < l_minus_one {
7158 let start = level_idx * p;
7159 dense
7160 .slice_mut(s![i, start..start + p])
7161 .assign(&base.row(i));
7162 } else {
7163 for level in 0..l_minus_one {
7164 let start = level * p;
7165 dense
7166 .slice_mut(s![i, start..start + p])
7167 .assign(&base.row(i).mapv(|v| -v));
7168 }
7169 }
7170 }
7171 let mut penalties = Vec::<Array2<f64>>::with_capacity(inner_built.penalties.len());
7172 let active_penalty_indices = inner_built
7173 .penaltyinfo
7174 .iter()
7175 .enumerate()
7176 .filter_map(|(idx, info)| info.active.then_some(idx))
7177 .collect::<Vec<_>>();
7178 if active_penalty_indices.len() != inner_built.penalties.len() {
7179 crate::bail_invalid_basis!(
7180 "internal sz penalty metadata mismatch: activeinfos={}, penalties={}",
7181 active_penalty_indices.len(),
7182 inner_built.penalties.len()
7183 );
7184 }
7185 let stz_per_group_penalty = |s_inner: &Array2<f64>, which_level: usize| -> Array2<f64> {
7220 let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7221 if which_level < l_minus_one {
7222 let k = which_level;
7224 let mut block = s_big.slice_mut(s![k * p..(k + 1) * p, k * p..(k + 1) * p]);
7225 block.assign(s_inner);
7226 } else {
7227 for a in 0..l_minus_one {
7229 for b in 0..l_minus_one {
7230 let mut block =
7231 s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7232 block.assign(s_inner);
7233 }
7234 }
7235 }
7236 s_big
7237 };
7238 let mut nullspaces = Vec::<usize>::with_capacity(penalties.capacity());
7244 for (penalty_pos, s_inner) in inner_built.penalties.iter().enumerate() {
7245 let info_idx = active_penalty_indices[penalty_pos];
7246 let base_info = inner_built.penaltyinfo[info_idx].clone();
7247 let marginal_nullity = inner_built.nullspaces.get(penalty_pos).copied().unwrap_or(0);
7248 for which_level in 0..=l_minus_one {
7250 let raw = stz_per_group_penalty(s_inner, which_level);
7251 let (s_big, group_scale) = normalize_penalty_in_constrained_space(&raw);
7252 let block = crate::basis::analyze_penalty_block_with_op(&s_big, None)?;
7253 if block.rank == 0 {
7254 continue;
7255 }
7256 if which_level == 0 {
7257 inner_built.penaltyinfo[info_idx].normalization_scale *= group_scale;
7260 inner_built.penaltyinfo[info_idx].original_index = penalties.len();
7261 inner_built.penaltyinfo[info_idx].effective_rank = block.rank;
7262 inner_built.penaltyinfo[info_idx].nullspace_dim_hint = block.nullity;
7263 } else {
7264 let mut info = base_info.clone();
7265 info.original_index = penalties.len();
7266 info.normalization_scale = base_info.normalization_scale * group_scale;
7267 info.effective_rank = block.rank;
7268 info.nullspace_dim_hint = block.nullity;
7269 info.kronecker_factors = None;
7270 inner_built.penaltyinfo.push(info);
7271 }
7272 penalties.push(block.sym_penalty);
7273 nullspaces.push(marginal_nullity);
7279 }
7280 }
7281
7282 if let Some(Some(z)) = inner_null_eigenvectors.first()
7300 && z.nrows() == p
7301 {
7302 for k in 0..z.ncols() {
7303 let zk = z.column(k);
7304 let mut p_k = Array2::<f64>::zeros((p, p));
7305 for a in 0..p {
7306 for b in 0..p {
7307 p_k[[a, b]] = zk[a] * zk[b];
7308 }
7309 }
7310 let stz_pooled_null = {
7315 let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
7316 for a in 0..l_minus_one {
7317 for b in 0..l_minus_one {
7318 let factor = if a == b { 2.0 } else { 1.0 };
7319 let mut block =
7320 s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
7321 block.assign(&p_k.mapv(|v| v * factor));
7322 }
7323 }
7324 s_big
7325 };
7326 let (s_null, null_scale) =
7327 normalize_penalty_in_constrained_space(&stz_pooled_null);
7328 let null_block = crate::basis::analyze_penalty_block_with_op(&s_null, None)?;
7329 if null_block.rank > 0 {
7330 let original_index = penalties.len();
7331 penalties.push(null_block.sym_penalty);
7332 nullspaces.push(null_block.nullity);
7333 inner_built.penaltyinfo.push(PenaltyInfo {
7334 source: PenaltySource::Primary,
7335 original_index,
7336 active: true,
7337 effective_rank: null_block.rank,
7338 dropped_reason: None,
7339 nullspace_dim_hint: null_block.nullity,
7340 normalization_scale: null_scale,
7341 kronecker_factors: None,
7342 });
7343 }
7344 }
7345 }
7346 inner_built.dim = p * l_minus_one;
7347 inner_built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(dense));
7348 inner_built.penalties = penalties;
7349 inner_built.ops = vec![None; inner_built.penalties.len()];
7350 inner_built.nullspaces = nullspaces;
7351 inner_built.null_eigenvectors =
7358 crate::basis::recompute_null_eigenvectors(&inner_built.penalties)?;
7359 inner_built.joint_null_rotation =
7360 crate::basis::compute_joint_null_rotation(&inner_built.penalties)?;
7361 inner_built.kronecker_factored = None;
7362 return Ok(inner_built);
7363 }
7364 SmoothBasisSpec::BSpline1D { feature_col, spec } => {
7365 if *feature_col >= data.ncols() {
7366 crate::bail_dim_basis!(
7367 "term '{}' feature column {} out of bounds for {} columns",
7368 term.name,
7369 feature_col,
7370 data.ncols()
7371 );
7372 }
7373 let mut spec_local = spec.clone();
7374 if term.shape != ShapeConstraint::None {
7375 spec_local.identifiability = BSplineIdentifiability::None;
7378 }
7379 build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
7383 }
7384 SmoothBasisSpec::ThinPlate {
7385 feature_cols,
7386 spec,
7387 input_scales,
7388 } => {
7389 if term.shape != ShapeConstraint::None {
7390 if feature_cols.len() != 1 {
7391 crate::bail_invalid_basis!(
7392 "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
7393 term.shape,
7394 term.name,
7395 feature_cols.len()
7396 );
7397 }
7398 shape_axis_col = Some(feature_cols[0]);
7399 }
7400 let mut x = select_columns(data, feature_cols)?;
7401 let (scales, length_scale_eff) = if let Some(s) = input_scales {
7407 apply_input_standardization(&mut x, s);
7408 (
7409 Some(s.clone()),
7410 compensate_length_scale_for_standardization(spec.length_scale, s),
7411 )
7412 } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7413 apply_input_standardization(&mut x, &s);
7414 let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7415 (Some(s), l_eff)
7416 } else {
7417 (None, spec.length_scale)
7418 };
7419 let mut spec_local = spec.clone();
7420 spec_local.length_scale = length_scale_eff;
7421 if matches!(
7422 spec_local.identifiability,
7423 SpatialIdentifiability::OrthogonalToParametric
7424 ) {
7425 spec_local.identifiability = SpatialIdentifiability::None;
7426 }
7427 let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
7428 rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
7429 })?;
7430 match &mut result.metadata {
7438 BasisMetadata::ThinPlate {
7439 input_scales: ms,
7440 length_scale,
7441 ..
7442 } => {
7443 *ms = scales;
7444 *length_scale = spec.length_scale;
7445 }
7446 BasisMetadata::Duchon {
7447 input_scales: ms,
7448 length_scale,
7449 ..
7450 } => {
7451 if let (Some(s), Some(realized)) = (scales.as_ref(), *length_scale) {
7476 let inv_sigma_geom =
7477 compensate_length_scale_for_standardization(1.0, s);
7478 if inv_sigma_geom.is_finite() && inv_sigma_geom > 0.0 {
7479 *length_scale = Some(realized / inv_sigma_geom);
7480 }
7481 }
7482 *ms = scales;
7483 }
7484 _ => {}
7485 }
7486 result
7487 }
7488 SmoothBasisSpec::Sphere { feature_cols, spec } => {
7489 if term.shape != ShapeConstraint::None {
7490 crate::bail_invalid_basis!(
7491 "ShapeConstraint::{:?} for term '{}' is not supported on spherical splines",
7492 term.shape,
7493 term.name
7494 );
7495 }
7496 let x = select_columns(data, feature_cols)?;
7497 build_spherical_spline_basis(x.view(), spec)?
7498 }
7499 SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
7500 if term.shape != ShapeConstraint::None {
7501 crate::bail_invalid_basis!(
7502 "ShapeConstraint::{:?} for term '{}' is not supported on constant-curvature smooths",
7503 term.shape,
7504 term.name
7505 );
7506 }
7507 let x = select_columns(data, feature_cols)?;
7514 build_constant_curvature_basis(x.view(), spec)?
7515 }
7516 SmoothBasisSpec::MeasureJet {
7517 feature_cols,
7518 spec,
7519 input_scales,
7520 } => {
7521 if term.shape != ShapeConstraint::None {
7522 crate::bail_invalid_basis!(
7523 "ShapeConstraint::{:?} for term '{}' is not supported on measure-jet smooths",
7524 term.shape,
7525 term.name
7526 );
7527 }
7528 let mut x = select_columns(data, feature_cols)?;
7529 let (scales, length_scale_eff) = if let Some(s) = input_scales {
7541 apply_input_standardization(&mut x, s);
7542 (Some(s.clone()), spec.length_scale)
7543 } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7544 apply_input_standardization(&mut x, &s);
7545 let l_eff = if spec.length_scale > 0.0 {
7546 compensate_length_scale_for_standardization(spec.length_scale, &s)
7547 } else {
7548 spec.length_scale
7549 };
7550 (Some(s), l_eff)
7551 } else {
7552 (None, spec.length_scale)
7553 };
7554 let mut spec_local = spec.clone();
7555 spec_local.length_scale = length_scale_eff;
7556 let mut result = build_measure_jet_basis(x.view(), &spec_local)?;
7557 if let BasisMetadata::MeasureJet {
7558 input_scales: ms, ..
7559 } = &mut result.metadata
7560 {
7561 *ms = scales;
7562 }
7563 result
7564 }
7565 SmoothBasisSpec::Matern {
7566 feature_cols,
7567 spec,
7568 input_scales,
7569 } => {
7570 if term.shape != ShapeConstraint::None {
7571 if feature_cols.len() != 1 {
7572 crate::bail_invalid_basis!(
7573 "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
7574 term.shape,
7575 term.name,
7576 feature_cols.len()
7577 );
7578 }
7579 shape_axis_col = Some(feature_cols[0]);
7580 }
7581 let mut x = select_columns(data, feature_cols)?;
7582 let (scales, length_scale_eff) = if let Some(s) = input_scales {
7597 apply_input_standardization(&mut x, s);
7598 (
7599 Some(s.clone()),
7600 compensate_length_scale_for_standardization(spec.length_scale, s),
7601 )
7602 } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7603 apply_input_standardization(&mut x, &s);
7604 let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
7605 (Some(s), l_eff)
7606 } else {
7607 (None, spec.length_scale)
7608 };
7609 let mut spec_local = spec.clone();
7610 spec_local.length_scale = length_scale_eff;
7611 let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
7612 if let BasisMetadata::Matern {
7613 input_scales,
7614 length_scale,
7615 ..
7616 } = &mut result.metadata
7617 {
7618 *input_scales = scales;
7619 *length_scale = spec.length_scale;
7620 }
7621 result
7622 }
7623 SmoothBasisSpec::Duchon {
7624 feature_cols,
7625 spec,
7626 input_scales,
7627 } => {
7628 if term.shape != ShapeConstraint::None {
7629 if feature_cols.len() != 1 {
7630 crate::bail_invalid_basis!(
7631 "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
7632 term.shape,
7633 term.name,
7634 feature_cols.len()
7635 );
7636 }
7637 shape_axis_col = Some(feature_cols[0]);
7638 }
7639 let mut x = select_columns(data, feature_cols)?;
7640 let (scales, length_scale_eff) = if let Some(s) = input_scales {
7651 apply_input_standardization(&mut x, s);
7652 (
7653 Some(s.clone()),
7654 compensate_optional_length_scale_for_standardization(spec.length_scale, s),
7655 )
7656 } else if let Some(s) = compute_spatial_input_scales(x.view()) {
7657 apply_input_standardization(&mut x, &s);
7658 let l_eff =
7659 compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
7660 (Some(s), l_eff)
7661 } else {
7662 (None, spec.length_scale)
7663 };
7664 let mut spec_local = spec.clone();
7665 spec_local.length_scale = length_scale_eff;
7666 if let (Some(s), crate::basis::OneDimensionalBoundary::Cyclic { start, end }) =
7677 (scales.as_ref(), spec_local.boundary.clone())
7678 && s.len() == 1
7679 && s[0] > 0.0
7680 {
7681 spec_local.boundary = crate::basis::OneDimensionalBoundary::Cyclic {
7682 start: start / s[0],
7683 end: end / s[0],
7684 };
7685 }
7686 if matches!(
7687 spec_local.identifiability,
7688 SpatialIdentifiability::OrthogonalToParametric
7689 ) {
7690 spec_local.identifiability = SpatialIdentifiability::None;
7691 }
7692 let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
7693 if let BasisMetadata::Duchon {
7694 input_scales,
7695 length_scale,
7696 ..
7697 } = &mut result.metadata
7698 {
7699 *input_scales = scales;
7700 *length_scale = spec.length_scale;
7701 }
7702 result
7703 }
7704 SmoothBasisSpec::Pca {
7705 feature_cols,
7706 basis_matrix,
7707 centered,
7708 smooth_penalty,
7709 center_mean,
7710 pca_basis_path,
7711 chunk_size,
7712 } => {
7713 if term.shape != ShapeConstraint::None {
7714 crate::bail_invalid_basis!(
7715 "ShapeConstraint::{:?} for term '{}' is not supported on Pca basis",
7716 term.shape,
7717 term.name
7718 );
7719 }
7720 build_pca_smooth_basis(
7721 data,
7722 feature_cols,
7723 basis_matrix,
7724 *centered,
7725 *smooth_penalty,
7726 center_mean.as_ref(),
7727 pca_basis_path.as_ref(),
7728 *chunk_size,
7729 )?
7730 }
7731 SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
7732 build_tensor_bspline_basis(data, feature_cols, spec)?
7733 }
7734 SmoothBasisSpec::ByVariable { .. } => {
7735 crate::bail_invalid_basis!(
7736 "internal: ByVariable smooths must return before inner basis dispatch"
7737 );
7738 }
7739 SmoothBasisSpec::BySmooth { .. } => {
7740 crate::bail_invalid_basis!("internal: BySmooth smooths must be lowered to ByVariable before inner basis dispatch"
7741 .to_string(),);
7742 }
7743 SmoothBasisSpec::FactorSmooth { spec } => {
7744 if term.shape != ShapeConstraint::None {
7745 crate::bail_invalid_basis!(
7746 "ShapeConstraint::{:?} is unsupported for factor smooth term '{}'",
7747 term.shape,
7748 term.name
7749 );
7750 }
7751 return build_factor_smooth(data, spec, &term.name, workspace);
7752 }
7753 };
7754
7755 if let SmoothBasisSpec::Matern { .. } = &term.basis {
7771 let (penalties, nullspace_dims, penaltyinfo) =
7772 matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
7773 built.penalties = penalties;
7774 built.nullspace_dims = nullspace_dims;
7775 built.penaltyinfo = penaltyinfo;
7776 }
7777
7778 let p_local = built.design.ncols();
7779 let mut metadata = built.metadata.clone();
7780 let kron_factored = if term.shape == ShapeConstraint::None {
7783 built.kronecker_factored
7784 } else {
7785 None
7786 };
7787 let mut design_t = built.design;
7788 let mut penalties_t: Vec<Array2<f64>> = built.penalties;
7789 let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>> =
7794 built.ops;
7795 if matches!(
7796 spatial_identifiability_policy(term),
7797 Some(SpatialIdentifiability::OrthogonalToParametric)
7798 ) {
7799 metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
7800 }
7801
7802 let active_penaltyinfo_t = built
7803 .penaltyinfo
7804 .iter()
7805 .filter(|info| info.active)
7806 .cloned()
7807 .collect::<Vec<_>>();
7808 let pre_dropped_penaltyinfo_t = built
7809 .penaltyinfo
7810 .iter()
7811 .filter(|info| !info.active)
7812 .cloned()
7813 .collect::<Vec<_>>();
7814 let use_box_reparam =
7815 term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
7816 if let Some((order, sign)) = shape_order_and_sign(term.shape)
7817 && use_box_reparam
7818 {
7819 let t = if order == 2 {
7834 let bspline_meta = match &metadata {
7835 BasisMetadata::BSpline1D {
7836 knots,
7837 degree,
7838 periodic,
7839 ..
7840 } if periodic.is_none() => Some((knots.clone(), degree.unwrap_or(0))),
7841 _ => None,
7842 };
7843 match bspline_meta {
7844 Some((knots, degree)) if degree >= 1 => {
7845 let greville = crate::basis::compute_greville_abscissae(&knots, degree)?;
7846 if greville.len() != p_local {
7847 crate::bail_invalid_basis!(
7848 "shape-constraint Greville abscissae count {} does not match basis dim {} for term '{}'",
7849 greville.len(),
7850 p_local,
7851 term.name
7852 );
7853 }
7854 convex_divided_difference_transform_matrix(&greville, sign)?
7855 }
7856 _ => cumulative_sum_transform_matrix(p_local, order, sign),
7857 }
7858 } else {
7859 cumulative_sum_transform_matrix(p_local, order, sign)
7860 };
7861 let inner_dense = match design_t {
7865 DesignMatrix::Dense(d) => d,
7866 DesignMatrix::Sparse(sp) => gam_linalg::matrix::DenseDesignMatrix::from(
7867 sp.try_to_dense_arc("shape-constrained coefficient transform")
7868 .map_err(BasisError::InvalidInput)?,
7869 ),
7870 };
7871 let coeff_op = gam_linalg::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
7872 .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
7873 design_t = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
7874 if penalties_t.len() != active_penaltyinfo_t.len() {
7875 crate::bail_invalid_basis!(
7876 "internal box-reparam penalty/info mismatch for term '{}': penalties={}, infos={}",
7877 term.name,
7878 penalties_t.len(),
7879 active_penaltyinfo_t.len()
7880 );
7881 }
7882 let transformed_wiggliness = penalties_t
7898 .iter()
7899 .zip(active_penaltyinfo_t.iter())
7900 .find(|(_, info)| !matches!(info.source, PenaltySource::DoublePenaltyNullspace))
7901 .map(|(s_local, _)| {
7902 let tt_s = fast_atb(&t, s_local);
7903 fast_ab(&tt_s, &t)
7904 });
7905 let mut rebuilt = Vec::with_capacity(penalties_t.len());
7906 for (s_local, info) in penalties_t.iter().zip(active_penaltyinfo_t.iter()) {
7907 if matches!(info.source, PenaltySource::DoublePenaltyNullspace) {
7908 if order == 2 {
7943 let tt_s = fast_atb(&t, s_local);
7944 rebuilt.push(fast_ab(&tt_s, &t));
7945 } else {
7946 let s_wiggle_t = transformed_wiggliness.as_ref().ok_or_else(|| {
7947 BasisError::InvalidInput(format!(
7948 "box-reparam term '{}' has a double-penalty ridge but no primary wiggliness penalty to derive its nullspace from",
7949 term.name
7950 ))
7951 })?;
7952 let ridge = crate::basis::build_nullspace_shrinkage_penalty(s_wiggle_t)?
7953 .map(|shrink| shrink.sym_penalty)
7954 .unwrap_or_else(|| Array2::<f64>::zeros((p_local, p_local)));
7955 rebuilt.push(ridge);
7956 }
7957 } else {
7958 let tt_s = fast_atb(&t, s_local);
7959 rebuilt.push(fast_ab(&tt_s, &t));
7960 }
7961 }
7962 penalties_t = rebuilt;
7963 ops_t = vec![None; penalties_t.len()];
7966 }
7967 if penalties_t.len() != active_penaltyinfo_t.len() {
7968 crate::bail_invalid_basis!(
7969 "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
7970 term.name,
7971 penalties_t.len(),
7972 active_penaltyinfo_t.len()
7973 );
7974 }
7975 if ops_t.len() != penalties_t.len() {
7976 ops_t = vec![None; penalties_t.len()];
7977 }
7978 let penalty_candidates = penalties_t
7979 .into_iter()
7980 .zip(active_penaltyinfo_t.into_iter())
7981 .zip(ops_t.into_iter())
7982 .map(
7983 |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
7984 let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
7985 let normalization_scale = info.normalization_scale * c_new;
7986 let op_scale = 1.0 / c_new;
7987 let kronecker_scale = 1.0 / c_new;
7988 let scaled_op = if op_scale > 0.0 && op_scale.is_finite() {
7991 op_in.map(|op| {
7992 std::sync::Arc::new(crate::analytic_penalties::ScaledPenaltyOp::new(
7993 op, op_scale,
7994 ))
7995 as std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>
7996 })
7997 } else {
7998 None
7999 };
8000 let kronecker_factors = info.kronecker_factors.map(|mut factors| {
8001 if let Some(first) = factors.first_mut() {
8002 first.mapv_inplace(|v| v * kronecker_scale);
8003 }
8004 factors
8005 });
8006 Ok(PenaltyCandidate {
8007 nullspace_dim_hint: info.nullspace_dim_hint,
8008 matrix,
8009 source: info.source,
8010 normalization_scale,
8011 kronecker_factors,
8012 op: scaled_op,
8013 })
8014 },
8015 )
8016 .collect::<Result<Vec<_>, _>>()?;
8017 let (penalties_t, nullspaces_t, penaltyinfo_t, null_eigenvectors_t, ops_t) =
8018 crate::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
8019 let shape_linear_constraints = if term.shape != ShapeConstraint::None && !use_box_reparam {
8020 let axis = shape_axis_col.ok_or_else(|| {
8021 BasisError::InvalidInput(format!(
8022 "internal shape-constraint axis missing for term '{}'",
8023 term.name
8024 ))
8025 })?;
8026 let (x_shape_eval, design_shape_eval) =
8027 build_shape_constraint_design_1d(data, term, &metadata, axis)?;
8028 build_shape_linear_constraints_1d(
8029 x_shape_eval.view(),
8030 design_shape_eval.view(),
8031 term.shape,
8032 )?
8033 } else {
8034 None
8035 };
8036 let linear_constraints_local = merge_linear_constraints_global(shape_linear_constraints, None);
8037
8038 let joint_null_rotation = match term.joint_null_rotation.clone() {
8057 Some(persisted) => Some(persisted),
8058 None if smooth_has_frozen_identifiability(term) => None,
8059 None if kron_factored.is_some() => None,
8060 None => crate::basis::compute_joint_null_rotation(&penalties_t)?,
8061 };
8062
8063 Ok(LocalSmoothTermBuild {
8064 dim: p_local,
8065 design: design_t,
8066 penalties: penalties_t,
8067 ops: ops_t,
8068 nullspaces: nullspaces_t,
8069 null_eigenvectors: null_eigenvectors_t,
8070 joint_null_rotation,
8071 penaltyinfo: penaltyinfo_t,
8072 pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
8073 metadata,
8074 linear_constraints: linear_constraints_local,
8075 box_reparam: use_box_reparam,
8076 kronecker_factored: kron_factored,
8077 })
8078}
8079
8080pub fn build_smooth_design(
8081 data: ArrayView2<'_, f64>,
8082 terms: &[SmoothTermSpec],
8083) -> Result<RawSmoothDesign, BasisError> {
8084 let mut ws = crate::basis::BasisWorkspace::new();
8085 build_smooth_design_withworkspace(data, terms, &mut ws)
8086}
8087
8088pub fn build_smooth_design_withworkspace(
8095 data: ArrayView2<'_, f64>,
8096 terms: &[SmoothTermSpec],
8097 workspace: &mut crate::basis::BasisWorkspace,
8098) -> Result<RawSmoothDesign, BasisError> {
8099 validate_smooth_terms_finite_inputs(data, terms)?;
8100 build_smooth_design_withworkspace_unvalidated(data, terms, workspace)
8101}
8102
8103pub fn build_smooth_design_withworkspace_unvalidated(
8104 data: ArrayView2<'_, f64>,
8105 terms: &[SmoothTermSpec],
8106 workspace: &mut crate::basis::BasisWorkspace,
8107) -> Result<RawSmoothDesign, BasisError> {
8108 let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
8109 let planned_terms = planned_blocks.pop().ok_or_else(|| {
8110 BasisError::InvalidInput(
8111 "joint spatial center planner returned no smooth blocks".to_string(),
8112 )
8113 })?;
8114 let policy = workspace.policy().clone();
8115 let local_builds: Vec<LocalSmoothTermBuild> = {
8116 use rayon::iter::{IntoParallelIterator, ParallelIterator};
8117 planned_terms
8118 .into_par_iter()
8119 .map(|term| {
8120 let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
8121 build_single_local_smooth_term(data, &term, &mut term_workspace)
8122 })
8123 .collect::<Result<Vec<_>, _>>()?
8124 };
8125
8126 let total_p: usize = local_builds.iter().map(|built| built.dim).sum();
8127
8128 let mut local_designs: Vec<DesignMatrix> = Vec::with_capacity(local_builds.len());
8129 let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
8130 let mut penalties_global = Vec::<BlockwisePenalty>::new();
8131 let mut nullspace_dims_global = Vec::<usize>::new();
8132 let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
8133 let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
8134 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
8135 let mut any_bounds = false;
8136 let mut linear_constraintsrows: Vec<(usize, usize, Array1<f64>)> = Vec::new();
8141 let mut linear_constraints_b: Vec<f64> = Vec::new();
8142
8143 let mut col_start = 0usize;
8144 for (term, mut built) in terms.iter().zip(local_builds.into_iter()) {
8145 let p_local = built.dim;
8146 let col_end = col_start + p_local;
8147 let lb_local = if built.box_reparam {
8148 shape_lower_bounds_local(term.shape, p_local)
8149 } else {
8150 None
8151 };
8152
8153 let applied_rotation: Option<crate::basis::JointNullRotation> = match (
8185 built.joint_null_rotation.take(),
8186 lb_local.is_some(),
8187 built.linear_constraints.is_some(),
8188 ) {
8189 (Some(rot), false, false) => {
8190 let q = &rot.rotation;
8191 let dense = built
8192 .design
8193 .try_to_dense_by_chunks("joint-null absorption rotation")
8194 .map_err(BasisError::InvalidInput)?;
8195 let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
8196 built.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
8197 built.penalties = built
8198 .penalties
8199 .into_iter()
8200 .map(|s_local| {
8201 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
8202 gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
8203 })
8204 .collect();
8205 built.ops = vec![None; built.penalties.len()];
8206 built.kronecker_factored = None;
8207 Some(rot)
8208 }
8209 (Some(_), _, _) => None,
8210 (None, _, _) => None,
8211 };
8212
8213 let activeinfos = built
8214 .penaltyinfo
8215 .iter()
8216 .filter(|info| info.active)
8217 .collect::<Vec<_>>();
8218 if activeinfos.len() != built.penalties.len() {
8219 crate::bail_invalid_basis!(
8220 "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
8221 term.name,
8222 activeinfos.len(),
8223 built.penalties.len()
8224 );
8225 }
8226 for (((s_local, &ns), info), op_local) in built
8227 .penalties
8228 .iter()
8229 .zip(built.nullspaces.iter())
8230 .zip(activeinfos.into_iter())
8231 .zip(built.ops.iter())
8232 {
8233 let global_index = penalties_global.len();
8234 penalties_global.push(
8235 BlockwisePenalty::new(col_start..col_end, s_local.clone())
8236 .with_op(op_local.clone()),
8237 );
8238 nullspace_dims_global.push(ns);
8239 let mut penalty = info.clone();
8240 penalty.nullspace_dim_hint = ns;
8241 penaltyinfo_global.push(PenaltyBlockInfo {
8242 global_index,
8243 termname: Some(term.name.clone()),
8244 penalty,
8245 });
8246 }
8247 for info in built.penaltyinfo.iter().filter(|info| !info.active) {
8248 dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
8249 termname: Some(term.name.clone()),
8250 penalty: info.clone(),
8251 });
8252 }
8253 for info in &built.pre_dropped_penaltyinfo {
8254 dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
8255 termname: Some(term.name.clone()),
8256 penalty: info.clone(),
8257 });
8258 }
8259
8260 if let Some(lin_local) = &built.linear_constraints {
8261 for r in 0..lin_local.a.nrows() {
8262 linear_constraintsrows.push((col_start, col_end, lin_local.a.row(r).to_owned()));
8263 linear_constraints_b.push(lin_local.b[r]);
8264 }
8265 }
8266 if let Some(lb_local) = &lb_local {
8267 coefficient_lower_bounds
8268 .slice_mut(s![col_start..col_end])
8269 .assign(lb_local);
8270 any_bounds = true;
8271 }
8272
8273 local_designs.push(built.design);
8275
8276 terms_out.push(SmoothTerm {
8277 name: term.name.clone(),
8278 coeff_range: col_start..col_end,
8279 shape: term.shape,
8280 penalties_local: built.penalties,
8281 nullspace_dims: built.nullspaces,
8282 penaltyinfo_local: built.penaltyinfo,
8283 metadata: built.metadata,
8284 lower_bounds_local: lb_local,
8285 linear_constraints_local: built.linear_constraints,
8286 kronecker_factored: built.kronecker_factored.take(),
8287 joint_null_rotation: applied_rotation,
8288 unabsorbed_global_orthogonality: None,
8289 });
8290
8291 col_start = col_end;
8292 }
8293
8294 assert_eq!(
8295 penalties_global.len(),
8296 nullspace_dims_global.len(),
8297 "global smooth penalty/nullspace bookkeeping diverged"
8298 );
8299 assert_eq!(
8300 penalties_global.len(),
8301 penaltyinfo_global.len(),
8302 "global smooth penalty metadata bookkeeping diverged"
8303 );
8304
8305 Ok(RawSmoothDesign {
8306 term_designs: local_designs,
8307 penalties: penalties_global,
8308 nullspace_dims: nullspace_dims_global,
8309 penaltyinfo: penaltyinfo_global,
8310 dropped_penaltyinfo: dropped_penaltyinfo_global,
8311 terms: terms_out,
8312 coefficient_lower_bounds: if any_bounds {
8313 Some(coefficient_lower_bounds)
8314 } else {
8315 None
8316 },
8317 linear_constraints: if linear_constraintsrows.is_empty() {
8318 None
8319 } else {
8320 let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
8321 for (i, (cs, ce, values)) in linear_constraintsrows.iter().enumerate() {
8322 a.row_mut(i).slice_mut(s![*cs..*ce]).assign(values);
8323 }
8324 Some(LinearInequalityConstraints {
8325 a,
8326 b: Array1::from_vec(linear_constraints_b),
8327 })
8328 },
8329 })
8330}