1use std::collections::{BTreeMap, BTreeSet, HashMap};
8use std::path::PathBuf;
9
10use ndarray::{Array2, ArrayView1};
11
12use crate::basis::{
13 BSplineBasisSpec, BSplineBoundaryConditions, BSplineEndpointBoundaryCondition,
14 BSplineIdentifiability, BSplineKnotSpec, CenterCountRequest, CenterStrategy,
15 ConstantCurvatureBasisSpec, ConstantCurvatureIdentifiability, DuchonBasisSpec,
16 DuchonNullspaceOrder, DuchonOperatorPenaltySpec, MaternBasisSpec, MaternIdentifiability,
17 MaternNu, MeasureJetBasisSpec, MeasureJetIdentifiability, OneDimensionalBoundary,
18 SpatialIdentifiability, SphereMethod, SphereWahbaKernel, SphericalSplineBasisSpec,
19 SphericalSplineIdentifiability, ThinPlateBasisSpec, auto_spatial_center_strategy,
20 default_num_centers, default_spatial_center_strategy, default_spherical_harmonic_degree,
21 plan_spatial_basis,
22};
23use crate::inference::formula_dsl::{
24 ParsedTerm, SmoothKind, option_bool, option_f64, option_f64_strict, option_usize,
25 option_usize_any, option_usize_any_strict, option_usize_strict, strip_quotes,
26};
27use crate::smooth::{
28 ByVarKind, FactorSmoothFlavour, FactorSmoothSpec, LinearCoefficientGeometry, LinearTermSpec,
29 RandomEffectTermSpec, ShapeConstraint, SmoothBasisSpec, SmoothTermSpec,
30 TensorBSplineIdentifiability, TensorBSplinePenaltyDecomposition, TensorBSplineSpec,
31 TermCollectionSpec,
32};
33use gam_problem::types::ColIdx;
34use gam_data::{ColumnKindTag, DataError, EncodedDataset as Dataset};
35use gam_runtime::resource::ResourcePolicy;
36
/// Lower clamp applied by `default_matern_length_scale` to a positive,
/// finite data-derived length scale.
const DEFAULT_MATERN_LENGTH_SCALE_FLOOR: f64 = 1e-6;

/// Default B-spline degree (3).
// NOTE(review): no consumer is visible in this part of the file; confirm usage.
const DEFAULT_BSPLINE_DEGREE: usize = 3;

/// Default penalty order (2).
// NOTE(review): no consumer is visible in this part of the file; confirm usage.
const DEFAULT_PENALTY_ORDER: usize = 2;

/// Default basis dimension for cyclic smooths (12), judging by the name.
// NOTE(review): no consumer is visible in this part of the file; confirm usage.
const CYCLIC_DEFAULT_BASIS_DIM: usize = 12;

/// Default basis dimension for factor smooths (10), judging by the name.
// NOTE(review): no consumer is visible in this part of the file; confirm usage.
const FACTOR_SMOOTH_DEFAULT_BASIS_DIM: usize = 10;

/// Default chunk size for PCA processing (4096), judging by the name.
// NOTE(review): no consumer is visible in this part of the file; confirm usage.
const DEFAULT_PCA_CHUNK_SIZE: usize = 4096;
69
70fn default_matern_length_scale(ds: &Dataset, cols: &[usize]) -> f64 {
71 let mut diameter2 = 0.0_f64;
72 for &col in cols {
73 let column = ds.values.column(col);
74 let mut lo = f64::INFINITY;
75 let mut hi = f64::NEG_INFINITY;
76 for &value in column.iter().filter(|v| v.is_finite()) {
77 lo = lo.min(value);
78 hi = hi.max(value);
79 }
80 if lo.is_finite() && hi.is_finite() && hi > lo {
81 let span = hi - lo;
82 diameter2 += span * span;
83 }
84 }
85 let diameter = diameter2.sqrt();
86 if diameter.is_finite() && diameter > 0.0 {
87 diameter.max(DEFAULT_MATERN_LENGTH_SCALE_FLOOR)
94 } else {
95 1.0
96 }
97}
98
/// Errors produced while turning parsed formula terms into term specs.
///
/// Every variant except [`TermBuilderError::ColumnNotFound`] carries a single
/// human-readable `reason`, which is exactly what `Display` prints.
#[derive(Clone, Debug)]
pub enum TermBuilderError {
    /// A column-related failure other than a plain name miss. Used in this
    /// file for internal column-kind lookup failures and as the target of
    /// every non-`ColumnNotFound` `DataError` conversion.
    MissingColumn { reason: String },
    /// A referenced column name does not exist. The fields mirror
    /// `DataError::ColumnNotFound`, and `Display` delegates to that type so
    /// the message is rendered in one place.
    ColumnNotFound {
        /// The column name that was requested.
        name: String,
        /// Optional role of the column in the model, if one was supplied.
        role: Option<String>,
        /// Column names available for the lookup.
        // NOTE(review): populated by `DataError::column_not_found`; exact contents not visible here.
        available: Vec<String>,
        /// Candidate names resembling `name`.
        // NOTE(review): similarity rule is defined by `DataError`; confirm there.
        similar: Vec<String>,
        /// Flag forwarded verbatim to `DataError::ColumnNotFound`.
        // NOTE(review): presumably toggles a TSV-related hint in the message; confirm.
        tsv_hint: bool,
    },
    /// The requested combination of settings cannot be honoured. Also the
    /// variant that plain `String` errors are converted into.
    IncompatibleConfig { reason: String },
    /// An option value could not be interpreted (e.g. a bad `shape=` value).
    InvalidOption { reason: String },
    /// The request names something that is not supported (e.g. an unknown
    /// tensor identifiability mode).
    UnsupportedFeature { reason: String },
    /// The data cannot support the requested term.
    // NOTE(review): no construction site is visible in this part of the file.
    DegenerateData { reason: String },
    /// The formula itself is structurally invalid for this builder (e.g. an
    /// unresolved `logslope(...)` declaration).
    MalformedFormula { reason: String },
}
147
148impl std::fmt::Display for TermBuilderError {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 TermBuilderError::MissingColumn { reason }
152 | TermBuilderError::IncompatibleConfig { reason }
153 | TermBuilderError::InvalidOption { reason }
154 | TermBuilderError::UnsupportedFeature { reason }
155 | TermBuilderError::DegenerateData { reason }
156 | TermBuilderError::MalformedFormula { reason } => f.write_str(reason),
157 TermBuilderError::ColumnNotFound {
163 name,
164 role,
165 available,
166 similar,
167 tsv_hint,
168 } => {
169 let canonical = DataError::ColumnNotFound {
170 name: name.clone(),
171 role: role.clone(),
172 available: available.clone(),
173 similar: similar.clone(),
174 tsv_hint: *tsv_hint,
175 };
176 std::fmt::Display::fmt(&canonical, f)
177 }
178 }
179 }
180}
181
182impl From<TermBuilderError> for String {
183 fn from(err: TermBuilderError) -> String {
184 err.to_string()
185 }
186}
187
188impl From<String> for TermBuilderError {
195 fn from(reason: String) -> Self {
196 Self::IncompatibleConfig { reason }
197 }
198}
199
200impl From<DataError> for TermBuilderError {
207 fn from(err: DataError) -> Self {
208 match err {
209 DataError::ColumnNotFound {
210 name,
211 role,
212 available,
213 similar,
214 tsv_hint,
215 } => Self::ColumnNotFound {
216 name,
217 role,
218 available,
219 similar,
220 tsv_hint,
221 },
222 DataError::SchemaMismatch { reason }
223 | DataError::ParseError { reason }
224 | DataError::EncodingFailure { reason }
225 | DataError::EmptyInput { reason }
226 | DataError::InvalidValue { reason } => Self::MissingColumn { reason },
227 }
228 }
229}
230
231impl TermBuilderError {
233 #[inline]
234 fn missing_column(reason: impl Into<String>) -> Self {
235 TermBuilderError::MissingColumn {
236 reason: reason.into(),
237 }
238 }
239 #[inline]
240 fn incompatible_config(reason: impl Into<String>) -> Self {
241 TermBuilderError::IncompatibleConfig {
242 reason: reason.into(),
243 }
244 }
245 #[inline]
246 fn invalid_option(reason: impl Into<String>) -> Self {
247 TermBuilderError::InvalidOption {
248 reason: reason.into(),
249 }
250 }
251 #[inline]
252 fn unsupported_feature(reason: impl Into<String>) -> Self {
253 TermBuilderError::UnsupportedFeature {
254 reason: reason.into(),
255 }
256 }
257 #[inline]
258 fn degenerate_data(reason: impl Into<String>) -> Self {
259 TermBuilderError::DegenerateData {
260 reason: reason.into(),
261 }
262 }
263 #[inline]
264 fn malformed_formula(reason: impl Into<String>) -> Self {
265 TermBuilderError::MalformedFormula {
266 reason: reason.into(),
267 }
268 }
269}
270
271pub fn resolve_col(col_map: &HashMap<String, usize>, name: &str) -> Result<usize, DataError> {
282 col_map
283 .get(name)
284 .copied()
285 .ok_or_else(|| DataError::column_not_found(col_map, name, None))
286}
287
288pub fn resolve_role_col(
293 col_map: &HashMap<String, usize>,
294 name: &str,
295 role: &str,
296) -> Result<usize, DataError> {
297 col_map
298 .get(name)
299 .copied()
300 .ok_or_else(|| DataError::column_not_found(col_map, name, Some(role)))
301}
302
303fn encoded_levels_for_column(ds: &Dataset, col: ColIdx) -> Vec<(u64, String)> {
304 let mut seen = BTreeSet::<u64>::new();
305 for value in ds.values.column(col.get()) {
306 if value.is_finite() {
307 seen.insert(value.to_bits());
308 }
309 }
310 let schema_levels = ds
311 .schema
312 .columns
313 .get(col.get())
314 .map(|column| column.levels.as_slice())
315 .unwrap_or(&[]);
316 seen.into_iter()
317 .enumerate()
318 .map(|(idx, bits)| {
319 let fallback = format!("level{}", idx + 1);
320 let label = schema_levels.get(idx).cloned().unwrap_or(fallback);
321 (bits, label)
322 })
323 .collect()
324}
325
/// Returns a copy of `col_map` in which `alias` also points at the index of
/// `target_column`.
///
/// The alias is only added when `target_column` exists and `alias` is not
/// already a key; an existing entry under `alias` is never overwritten.
pub fn column_map_with_alias(
    col_map: &HashMap<String, usize>,
    alias: &str,
    target_column: &str,
) -> HashMap<String, usize> {
    let mut extended = col_map.clone();
    match col_map.get(target_column) {
        Some(&target_index) if !extended.contains_key(alias) => {
            extended.insert(alias.to_owned(), target_index);
        }
        _ => {}
    }
    extended
}
337
338pub fn build_termspec(
343 terms: &[ParsedTerm],
344 ds: &Dataset,
345 col_map: &HashMap<String, usize>,
346 inference_notes: &mut Vec<String>,
347 policy: &ResourcePolicy,
348) -> Result<TermCollectionSpec, TermBuilderError> {
349 let mut linear_terms = Vec::<LinearTermSpec>::new();
350 let mut random_terms = Vec::<RandomEffectTermSpec>::new();
351 let mut smooth_terms = Vec::<SmoothTermSpec>::new();
352 let smooth_coordinate_count = terms
353 .iter()
354 .map(|term| match term {
355 ParsedTerm::Smooth { vars, .. } => vars.len(),
356 _ => 0,
357 })
358 .sum::<usize>();
359
360 for t in terms {
361 match t {
362 ParsedTerm::Linear {
363 name,
364 explicit,
365 coefficient_min,
366 coefficient_max,
367 } => {
368 let col = resolve_col(col_map, name)?;
369 let auto_kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
370 TermBuilderError::missing_column(format!(
371 "internal column-kind lookup failed for '{name}'"
372 ))
373 .to_string()
374 })?;
375 if *explicit {
376 linear_terms.push(LinearTermSpec {
377 name: name.clone(),
378 feature_col: col,
379 feature_cols: vec![col],
380 categorical_levels: vec![],
381 double_penalty: false,
384 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
385 coefficient_min: *coefficient_min,
386 coefficient_max: *coefficient_max,
387 });
388 } else {
389 match auto_kind {
390 ColumnKindTag::Continuous | ColumnKindTag::Binary => {
391 linear_terms.push(LinearTermSpec {
392 name: name.clone(),
393 feature_col: col,
394 feature_cols: vec![col],
395 categorical_levels: vec![],
396 double_penalty: false,
398 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
399 coefficient_min: *coefficient_min,
400 coefficient_max: *coefficient_max,
401 });
402 }
403 ColumnKindTag::Categorical => {
404 if coefficient_min.is_some() || coefficient_max.is_some() {
405 return Err(TermBuilderError::incompatible_config(format!(
406 "coefficient constraints are not supported for categorical auto-random-effect term '{name}'; use group({name}) or an unconstrained numeric term"
407 )));
408 }
409 random_terms.push(RandomEffectTermSpec {
410 name: name.clone(),
411 feature_col: col,
412 drop_first_level: false,
413 penalized: true,
414 frozen_levels: None,
415 });
416 }
417 }
418 }
419 }
420 ParsedTerm::BoundedLinear {
421 name,
422 min,
423 max,
424 prior,
425 } => {
426 let col = resolve_col(col_map, name)?;
427 let auto_kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
428 TermBuilderError::missing_column(format!(
429 "internal column-kind lookup failed for '{name}'"
430 ))
431 .to_string()
432 })?;
433 if !matches!(auto_kind, ColumnKindTag::Continuous | ColumnKindTag::Binary) {
434 return Err(TermBuilderError::incompatible_config(format!(
435 "bounded() currently supports only numeric columns, got categorical '{name}'"
436 )));
437 }
438 linear_terms.push(LinearTermSpec {
439 name: name.clone(),
440 feature_col: col,
441 feature_cols: vec![col],
442 categorical_levels: vec![],
443 double_penalty: false,
444 coefficient_geometry: LinearCoefficientGeometry::Bounded {
445 min: *min,
446 max: *max,
447 prior: prior.clone(),
448 },
449 coefficient_min: None,
450 coefficient_max: None,
451 });
452 }
453 ParsedTerm::RandomEffect { name } => {
454 let col = resolve_col(col_map, name)?;
455 random_terms.push(RandomEffectTermSpec {
456 name: name.clone(),
457 feature_col: col,
458 drop_first_level: false,
459 penalized: true,
460 frozen_levels: None,
461 });
462 }
463 ParsedTerm::Smooth {
464 label,
465 vars,
466 kind,
467 options,
468 } => {
469 let smooth_vars = vars.clone();
470 let by_name = options.get("by").cloned();
471 let cols = smooth_vars
481 .iter()
482 .map(|v| resolve_col(col_map, v))
483 .collect::<Result<Vec<_>, _>>()?;
484 let mut inner_options = options.clone();
485 inner_options.remove("by");
486 inner_options.remove("ordered");
490 let shape = match inner_options.remove("shape") {
496 None => ShapeConstraint::None,
497 Some(raw) => crate::smooth::parse_shape_constraint(&raw)
498 .map_err(TermBuilderError::invalid_option)?,
499 };
500 let inner_basis = build_smooth_basis(
501 *kind,
502 &smooth_vars,
503 &cols,
504 &inner_options,
505 ds,
506 inference_notes,
507 policy,
508 smooth_coordinate_count,
509 )?;
510 if let Some(by_name) = by_name {
511 let by_col = resolve_col(col_map, &by_name)?;
512 match ds.column_kinds.get(by_col).copied().ok_or_else(|| {
513 format!("internal column-kind lookup failed for by variable '{by_name}'")
514 })? {
515 ColumnKindTag::Categorical => {
516 let levels = encoded_levels_for_column(ds, ColIdx::new(by_col));
517 let penalized_group_owner_present =
530 terms.iter().any(|other| match other {
531 ParsedTerm::RandomEffect { name } => name == &by_name,
532 ParsedTerm::Linear {
533 name,
534 explicit: false,
535 ..
536 } if name == &by_name => col_map
537 .get(name)
538 .and_then(|c| ds.column_kinds.get(*c).copied())
539 .map(|kind| matches!(kind, ColumnKindTag::Categorical))
540 .unwrap_or(false),
541 _ => false,
542 });
543 if !random_terms.iter().any(|rt| rt.name == by_name)
554 && !penalized_group_owner_present
555 {
556 random_terms.push(RandomEffectTermSpec {
557 name: by_name.clone(),
558 feature_col: by_col,
559 drop_first_level: true,
560 penalized: false,
561 frozen_levels: None,
562 });
563 }
564 let frozen_levels: Vec<u64> =
569 levels.iter().map(|(bits, _)| *bits).collect();
570 smooth_terms.push(SmoothTermSpec {
571 name: label.clone(),
572 basis: SmoothBasisSpec::BySmooth {
573 smooth: Box::new(inner_basis),
574 by_kind: ByVarKind::Factor {
575 feature_col: by_col,
576 ordered: option_bool(options, "ordered").unwrap_or(false),
577 frozen_levels: Some(frozen_levels),
578 },
579 },
580 shape,
581 joint_null_rotation: None,
582 });
583 }
584 ColumnKindTag::Binary | ColumnKindTag::Continuous => {
585 smooth_terms.push(SmoothTermSpec {
586 name: label.clone(),
587 basis: SmoothBasisSpec::BySmooth {
588 smooth: Box::new(inner_basis),
589 by_kind: ByVarKind::Numeric {
590 feature_col: by_col,
591 },
592 },
593 shape,
594 joint_null_rotation: None,
595 });
596 }
597 }
598 } else {
599 smooth_terms.push(SmoothTermSpec {
600 name: label.clone(),
601 basis: inner_basis,
602 shape,
603 joint_null_rotation: None,
604 });
605 }
606 }
607 ParsedTerm::LinkWiggle { .. }
608 | ParsedTerm::TimeWiggle { .. }
609 | ParsedTerm::LinkConfig { .. }
610 | ParsedTerm::SurvivalConfig { .. } => {
611 }
613 ParsedTerm::LogSlopeSurface { .. } => {
614 return Err(TermBuilderError::malformed_formula(
615 "logslope(...) declarations must be resolved by the marginal-slope formula path before building a term spec",
616 ));
617 }
618 ParsedTerm::Interaction { vars } => {
619 let main_effect_present = |target: &str| -> bool {
652 terms.iter().any(|other| match other {
653 ParsedTerm::Linear { name, .. }
654 | ParsedTerm::BoundedLinear { name, .. }
655 | ParsedTerm::RandomEffect { name } => name == target,
656 _ => false,
657 })
658 };
659 let parent_present = |drop_var: &str| -> bool {
665 vars.iter()
666 .filter(|v| v.as_str() != drop_var)
667 .all(|v| main_effect_present(v))
668 };
669
670 let mut numeric_cols = Vec::<usize>::new();
671 let mut categorical_factors =
674 Vec::<(String, usize, Vec<(u64, String)>, bool)>::new();
675 for var in vars {
676 let col = resolve_col(col_map, var)?;
677 let kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
678 TermBuilderError::missing_column(format!(
679 "internal column-kind lookup failed for '{var}'"
680 ))
681 .to_string()
682 })?;
683 match kind {
684 ColumnKindTag::Continuous | ColumnKindTag::Binary => numeric_cols.push(col),
685 ColumnKindTag::Categorical => {
686 let mut levels = encoded_levels_for_column(ds, ColIdx::new(col));
687 let treatment_coded = parent_present(var);
691 if treatment_coded && levels.len() > 1 {
692 levels.remove(0);
693 }
694 if levels.is_empty() {
695 return Err(TermBuilderError::incompatible_config(format!(
696 "interaction `{}` references categorical column `{var}` with no usable levels",
697 vars.join(":")
698 )));
699 }
700 categorical_factors.push((var.clone(), col, levels, treatment_coded));
701 }
702 }
703 }
704
705 let label = vars.join(":");
706
707 if categorical_factors.is_empty() {
708 linear_terms.push(LinearTermSpec {
711 name: label,
712 feature_col: numeric_cols[0],
713 feature_cols: numeric_cols,
714 categorical_levels: vec![],
715 double_penalty: false,
718 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
719 coefficient_min: None,
720 coefficient_max: None,
721 });
722 inference_notes.push(format!(
723 "wired linear interaction `{}` as product of numeric columns",
724 vars.join(":")
725 ));
726 } else {
727 let mut cells: Vec<Vec<(usize, u64, String)>> = vec![Vec::new()];
732 for (_var, col, levels, _treatment_coded) in &categorical_factors {
733 let mut next = Vec::with_capacity(cells.len() * levels.len());
734 for cell in &cells {
735 for (bits, level_label) in levels {
736 let mut extended = cell.clone();
737 extended.push((*col, *bits, level_label.clone()));
738 next.push(extended);
739 }
740 }
741 cells = next;
742 }
743
744 let any_dummy_coded = categorical_factors
756 .iter()
757 .any(|(_, _, _, treatment_coded)| !*treatment_coded);
758 if numeric_cols.is_empty() && any_dummy_coded {
759 let reference_cell: Vec<(usize, u64)> = categorical_factors
762 .iter()
763 .map(|(_, col, _, _)| {
764 let levels = encoded_levels_for_column(ds, ColIdx::new(*col));
765 (*col, levels[0].0)
766 })
767 .collect();
768 cells.retain(|cell| {
769 !reference_cell.iter().all(|(rcol, rbits)| {
770 cell.iter()
771 .any(|(col, bits, _)| col == rcol && bits == rbits)
772 })
773 });
774 }
775
776 let n_cells = cells.len();
777 for cell in cells {
778 let cell_suffix = cell
779 .iter()
780 .map(|(_, _, level_label)| level_label.as_str())
781 .collect::<Vec<_>>()
782 .join(":");
783 let categorical_levels =
784 cell.iter().map(|(col, bits, _)| (*col, *bits)).collect();
785 let feature_col = numeric_cols
791 .first()
792 .copied()
793 .unwrap_or(categorical_factors[0].1);
794 linear_terms.push(LinearTermSpec {
795 name: format!("{label}:{cell_suffix}"),
796 feature_col,
797 feature_cols: numeric_cols.clone(),
798 categorical_levels,
799 double_penalty: false,
800 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
801 coefficient_min: None,
802 coefficient_max: None,
803 });
804 }
805 let all_treatment_coded = !any_dummy_coded;
806 let coding = if all_treatment_coded {
807 "treatment-coded"
808 } else {
809 "marginality-aware (full dummy / saturated)"
810 };
811 inference_notes.push(format!(
812 "wired factor-aware linear interaction `{}` as {} {} cell column(s)",
813 vars.join(":"),
814 n_cells,
815 coding
816 ));
817 }
818 }
819 }
820 }
821
822 Ok(TermCollectionSpec {
823 linear_terms,
824 random_effect_terms: random_terms,
825 smooth_terms,
826 })
827}
828
/// Splits a list-valued option into its trimmed, non-empty entries.
///
/// One layer of wrapping is removed when both delimiters are present:
/// `[...]`, `c(...)`, `C(...)` or `(...)`. Anything else is split as-is.
/// Entries are separated by commas and keep their original case.
fn split_list_option(raw: &str) -> Vec<String> {
    let trimmed = raw.trim();

    let bracketed = trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'));
    let body = match bracketed {
        Some(inside) => inside,
        None => {
            // First matching opener wins, in this order.
            let after_open = ["c(", "C(", "("]
                .iter()
                .find_map(|open| trimmed.strip_prefix(*open));
            match after_open.and_then(|rest| rest.strip_suffix(')')) {
                Some(inside) => inside,
                None => trimmed,
            }
        }
    };

    let mut entries = Vec::new();
    for piece in body.split(',') {
        let piece = piece.trim();
        if !piece.is_empty() {
            entries.push(piece.to_string());
        }
    }
    entries
}
853
/// Evaluates a small numeric expression: a `*`-separated product of factors.
///
/// Each factor is either a plain `f64` literal or a circle constant
/// (`pi`/`π`, `tau`/`τ`, matched without regard to ASCII case) with an
/// optional numeric coefficient directly in front of it, e.g. `2pi`, `0.5tau`.
/// A coefficient consisting only of a sign is accepted, so `-pi` and `+tau`
/// evaluate to `-π` and `τ`. All whitespace is ignored.
///
/// # Errors
/// Returns a message when the input is `none` (any case), contains an empty
/// factor (e.g. empty input or `2**3`), or a factor/coefficient is not a
/// valid `f64` literal.
fn parse_numeric_expr(raw: &str) -> Result<f64, String> {
    /// Splits `factor` into (coefficient text, constant value) when it ends
    /// in one of the known constant names; `None` otherwise.
    fn split_constant(factor: &str) -> Option<(&str, f64)> {
        const CONSTANTS: [(&str, f64); 4] = [
            ("pi", std::f64::consts::PI),
            ("π", std::f64::consts::PI),
            ("tau", std::f64::consts::TAU),
            ("τ", std::f64::consts::TAU),
        ];
        CONSTANTS.iter().find_map(|&(symbol, value)| {
            let split_at = factor.len().checked_sub(symbol.len())?;
            // `get` yields `None` when `split_at` is not a char boundary.
            let prefix = factor.get(..split_at)?;
            let suffix = factor.get(split_at..)?;
            suffix
                .eq_ignore_ascii_case(symbol)
                .then_some((prefix, value))
        })
    }

    let normalized: String = raw.chars().filter(|c| !c.is_whitespace()).collect();
    if normalized.eq_ignore_ascii_case("none") {
        return Err("None is not numeric".to_string());
    }

    let invalid = |err: std::num::ParseFloatError| format!("invalid numeric expression '{raw}': {err}");

    let mut product = 1.0_f64;
    for factor in normalized.split('*') {
        if factor.is_empty() {
            return Err(format!("invalid numeric expression '{raw}'"));
        }
        let value = match split_constant(factor) {
            Some((prefix, constant)) => {
                let coefficient = match prefix {
                    "" | "+" => 1.0,
                    "-" => -1.0,
                    digits => digits.parse::<f64>().map_err(invalid)?,
                };
                coefficient * constant
            }
            None => factor.parse::<f64>().map_err(invalid)?,
        };
        product *= value;
    }
    Ok(product)
}
901
902fn option_numeric_expr(
912 options: &BTreeMap<String, String>,
913 key: &str,
914) -> Result<Option<f64>, String> {
915 match options.get(key) {
916 None => Ok(None),
917 Some(raw) => parse_numeric_expr(raw)
918 .map(Some)
919 .map_err(|err| format!("option `{key}={raw}` is not a valid numeric value: {err}")),
920 }
921}
922
923fn parse_periods_option(
924 options: &BTreeMap<String, String>,
925 dim: usize,
926) -> Result<Option<Vec<Option<f64>>>, String> {
927 let Some(raw) = options.get("period") else {
928 return Ok(None);
929 };
930 let values = split_list_option(raw);
931 let mut periods = vec![None; dim];
932 if values.len() == 1 && dim == 1 {
933 periods[0] = Some(parse_numeric_expr(&values[0])?);
934 } else {
935 if values.len() != dim {
936 return Err(format!(
937 "period list length {} must match smooth dimension {}",
938 values.len(),
939 dim
940 ));
941 }
942 for (i, v) in values.iter().enumerate() {
943 if v.eq_ignore_ascii_case("none") {
944 continue;
945 }
946 periods[i] = Some(parse_numeric_expr(v)?);
947 }
948 }
949 Ok(Some(periods))
950}
951
952fn parse_periodic_axes_option(
953 options: &BTreeMap<String, String>,
954 dim: usize,
955) -> Result<Option<Vec<Option<f64>>>, String> {
956 let Some(raw_axes) = options.get("periodic") else {
957 return Ok(None);
958 };
959 let mut periods = parse_periods_option(options, dim)?.unwrap_or_else(|| vec![None; dim]);
960 let axes = split_list_option(raw_axes);
961 if axes.is_empty() {
962 return Ok(Some(periods));
963 }
964 for a in axes {
965 let axis = a
966 .parse::<usize>()
967 .map_err(|err| format!("invalid periodic axis '{a}': {err}"))?;
968 if axis >= dim {
969 return Err(format!(
970 "periodic axis {axis} out of range for {dim}D smooth"
971 ));
972 }
973 if periods[axis].is_none() {
974 return Err(format!(
975 "periodic axis {axis} requires period[{axis}] to be finite"
976 ));
977 }
978 }
979 let listed: std::collections::BTreeSet<usize> = split_list_option(raw_axes)
981 .into_iter()
982 .filter_map(|a| a.parse::<usize>().ok())
983 .collect();
984 for i in 0..dim {
985 if !listed.contains(&i) {
986 periods[i] = None;
987 }
988 }
989 Ok(Some(periods))
990}
991
/// Splits a list-valued option into normalized keyword entries.
///
/// One layer of `[...]`, `c(...)`, `C(...)` or `(...)` wrapping is removed
/// when both delimiters are present. Each comma-separated entry is trimmed,
/// stripped of surrounding double quotes and then single quotes, and
/// lower-cased (ASCII); entries that end up empty are dropped.
fn parse_option_list(raw: &str) -> Vec<String> {
    let text = raw.trim();

    let unwrapped = text
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .or_else(|| {
            ["c(", "C(", "("]
                .iter()
                .find_map(|open| text.strip_prefix(*open))
                .and_then(|rest| rest.strip_suffix(')'))
        })
        .unwrap_or(text);

    let mut entries = Vec::new();
    for token in unwrapped.split(',') {
        let cleaned = token
            .trim()
            .trim_matches('"')
            .trim_matches('\'')
            .to_ascii_lowercase();
        if !cleaned.is_empty() {
            entries.push(cleaned);
        }
    }
    entries
}
1025
1026fn parse_periodic_axes(
1027 options: &BTreeMap<String, String>,
1028 dim: usize,
1029) -> Result<Vec<bool>, String> {
1030 let mut axes = vec![false; dim];
1031 if let Some(raw) = options.get("periodic").or_else(|| options.get("cyclic")) {
1032 let lowered = raw.trim().to_ascii_lowercase();
1033 match lowered.as_str() {
1034 "true" | "yes" | "y" => {
1035 axes.fill(true);
1036 return Ok(axes);
1037 }
1038 "false" | "no" | "n" => return Ok(axes),
1039 _ => {}
1040 }
1041 for axis_raw in parse_option_list(raw) {
1042 let axis = axis_raw
1043 .parse::<usize>()
1044 .map_err(|err| format!("invalid periodic axis '{axis_raw}': {err}"))?;
1045 if axis >= dim {
1046 return Err(format!(
1047 "periodic axis {axis} out of range for {dim}D smooth"
1048 ));
1049 }
1050 axes[axis] = true;
1051 }
1052 }
1053 if let Some(raw) = options.get("boundary").or_else(|| options.get("bc")) {
1054 let boundary = parse_option_list(raw);
1055 if boundary.len() == dim {
1056 for (axis, value) in boundary.iter().enumerate() {
1057 if matches!(value.as_str(), "periodic" | "cyclic" | "cc") {
1058 axes[axis] = true;
1059 }
1060 }
1061 } else if dim == 1
1062 && matches!(
1063 boundary.first().map(String::as_str),
1064 Some("periodic" | "cyclic" | "cc")
1065 )
1066 {
1067 axes[0] = true;
1068 }
1069 }
1070 Ok(axes)
1071}
1072
1073fn parse_optional_numeric_list(
1074 options: &BTreeMap<String, String>,
1075 keys: &[&str],
1076 dim: usize,
1077) -> Result<Vec<Option<f64>>, String> {
1078 let Some(raw) = keys.iter().find_map(|key| options.get(*key)) else {
1079 return Ok(vec![None; dim]);
1080 };
1081 let values = split_list_option(raw);
1082 let mut out = vec![None; dim];
1083 if values.len() == 1 && dim == 1 {
1084 if !values[0].eq_ignore_ascii_case("none") {
1085 out[0] = Some(parse_numeric_expr(&values[0])?);
1086 }
1087 return Ok(out);
1088 }
1089 if values.len() != dim {
1090 return Err(format!(
1091 "numeric option list length {} must match smooth dimension {}",
1092 values.len(),
1093 dim
1094 ));
1095 }
1096 for (i, value) in values.iter().enumerate() {
1097 if !value.eq_ignore_ascii_case("none") {
1098 out[i] = Some(parse_numeric_expr(value)?);
1099 }
1100 }
1101 Ok(out)
1102}
1103
1104fn parse_periods(
1105 options: &BTreeMap<String, String>,
1106 periodic_axes: &[bool],
1107) -> Result<Vec<Option<f64>>, String> {
1108 let dim = periodic_axes.len();
1109 let lone_periodic_broadcast = options
1114 .get("period")
1115 .or_else(|| options.get("periods"))
1116 .and_then(|raw| {
1117 let values = split_list_option(raw);
1118 if values.len() != 1 || dim <= 1 {
1119 return None;
1120 }
1121 let mut iter = periodic_axes.iter().enumerate().filter(|(_, p)| **p);
1122 let first = iter.next()?;
1123 if iter.next().is_some() {
1124 return None;
1125 }
1126 Some((first.0, values.into_iter().next().unwrap()))
1127 });
1128 let periods = if let Some((axis, value)) = lone_periodic_broadcast {
1129 let mut out = vec![None; dim];
1130 if !value.eq_ignore_ascii_case("none") {
1131 out[axis] = Some(parse_numeric_expr(&value)?);
1132 }
1133 out
1134 } else {
1135 parse_optional_numeric_list(options, &["period", "periods"], dim)?
1136 };
1137 for (axis, (periodic, period)) in periodic_axes.iter().zip(periods.iter()).enumerate() {
1138 if *periodic
1139 && let Some(value) = period
1140 && (!value.is_finite() || *value <= 0.0)
1141 {
1142 return Err(format!(
1143 "period for periodic axis {axis} must be finite and positive, got {value}"
1144 ));
1145 }
1146 }
1147 Ok(periods)
1148}
1149
1150fn parse_period_origins(
1151 options: &BTreeMap<String, String>,
1152 periodic_axes: &[bool],
1153) -> Result<Vec<Option<f64>>, String> {
1154 parse_optional_numeric_list(
1155 options,
1156 &[
1157 "origin",
1158 "origins",
1159 "period_origin",
1160 "period-origin",
1161 "domain_origin",
1162 ],
1163 periodic_axes.len(),
1164 )
1165}
1166
1167fn parse_tensor_periodic_axes(
1175 options: &BTreeMap<String, String>,
1176 dim: usize,
1177) -> Result<Vec<bool>, String> {
1178 let mut axes = vec![false; dim];
1179 if let Some(raw) = options.get("periodic").or_else(|| options.get("cyclic")) {
1180 let lowered = raw.trim().to_ascii_lowercase();
1181 match lowered.as_str() {
1182 "true" | "yes" | "y" => {
1183 axes.fill(true);
1184 }
1185 "false" | "no" | "n" => {
1186 }
1188 _ => {
1189 let entries = parse_option_list(raw);
1190 let all_bool = !entries.is_empty()
1191 && entries.iter().all(|v| {
1192 matches!(
1193 v.as_str(),
1194 "true" | "yes" | "y" | "false" | "no" | "n" | "none"
1195 )
1196 });
1197 if all_bool {
1198 if entries.len() != dim {
1199 return Err(format!(
1200 "periodic list length {} must match smooth dimension {}",
1201 entries.len(),
1202 dim
1203 ));
1204 }
1205 for (i, v) in entries.iter().enumerate() {
1206 axes[i] = matches!(v.as_str(), "true" | "yes" | "y");
1207 }
1208 } else {
1209 for axis_raw in entries {
1210 let axis = axis_raw
1211 .parse::<usize>()
1212 .map_err(|err| format!("invalid periodic axis '{axis_raw}': {err}"))?;
1213 if axis >= dim {
1214 return Err(format!(
1215 "periodic axis {axis} out of range for {dim}D smooth"
1216 ));
1217 }
1218 axes[axis] = true;
1219 }
1220 }
1221 }
1222 }
1223 }
1224 if let Some(raw) = options.get("boundary").or_else(|| options.get("bc")) {
1225 let boundary = parse_option_list(raw);
1226 if boundary.len() == dim {
1227 for (axis, value) in boundary.iter().enumerate() {
1228 if matches!(value.as_str(), "periodic" | "cyclic" | "cc") {
1229 axes[axis] = true;
1230 }
1231 }
1232 }
1233 }
1234 Ok(axes)
1235}
1236
1237fn tensor_k_axis_option_axis(
1238 key: &str,
1239 cols: &[usize],
1240 ds: &Dataset,
1241) -> Result<Option<usize>, String> {
1242 let Some(suffix) = key.strip_prefix("k_") else {
1243 return Ok(None);
1244 };
1245 if suffix.is_empty() {
1246 return Err("tensor k axis option must be named k_<axis> or k_<variable>".to_string());
1247 }
1248 if let Ok(axis) = suffix.parse::<usize>() {
1249 return if axis < cols.len() {
1250 Ok(Some(axis))
1251 } else {
1252 Err(format!(
1253 "tensor k axis option `{key}` references axis {axis}, but the smooth has {} margins",
1254 cols.len()
1255 ))
1256 };
1257 }
1258
1259 let mut matches = cols
1260 .iter()
1261 .enumerate()
1262 .filter(|(_, col)| ds.headers.get(**col).is_some_and(|name| name == suffix))
1263 .map(|(axis, _)| axis);
1264 let first = matches.next();
1265 if matches.next().is_some() {
1266 return Err(format!(
1267 "tensor k axis option `{key}` matches more than one margin named `{suffix}`"
1268 ));
1269 }
1270 first.map(Some).ok_or_else(|| {
1271 let margin_names = cols
1272 .iter()
1273 .enumerate()
1274 .map(|(axis, col)| {
1275 let name = ds
1276 .headers
1277 .get(*col)
1278 .map(String::as_str)
1279 .unwrap_or("<unnamed>");
1280 format!("{axis}:{name}")
1281 })
1282 .collect::<Vec<_>>()
1283 .join(", ");
1284 format!(
1285 "tensor k axis option `{key}` does not match a margin index or name; tensor margins are [{margin_names}]"
1286 )
1287 })
1288}
1289
/// Reports whether `key` has the form `k_<something>` with a non-empty
/// suffix after the `k_` prefix.
fn is_tensor_k_axis_option_key(key: &str) -> bool {
    const PREFIX: &str = "k_";
    key.starts_with(PREFIX) && key.len() > PREFIX.len()
}
1294
1295fn parse_tensor_k_list(
1299 options: &BTreeMap<String, String>,
1300 cols: &[usize],
1301 ds: &Dataset,
1302) -> Result<(Vec<usize>, bool), String> {
1303 let mut axis_values = vec![None; cols.len()];
1304 let mut saw_axis_alias = false;
1305 for (key, value) in options {
1306 let Some(axis) = tensor_k_axis_option_axis(key, cols, ds)? else {
1307 continue;
1308 };
1309 saw_axis_alias = true;
1310 if axis_values[axis].is_some() {
1311 return Err(format!("tensor k axis {axis} is specified more than once"));
1312 }
1313 let k: usize = value
1314 .parse()
1315 .map_err(|err| format!("invalid tensor k option `{key}={value}`: {err}"))?;
1316 axis_values[axis] = Some(k);
1317 }
1318
1319 let raw = options
1320 .get("k")
1321 .or_else(|| options.get("basis_dim"))
1322 .or_else(|| options.get("basis-dim"))
1323 .or_else(|| options.get("basisdim"));
1324 if saw_axis_alias {
1325 if raw.is_some() {
1326 return Err(
1327 "tensor k axis aliases cannot be combined with k= or basis_dim=".to_string(),
1328 );
1329 }
1330 if let Some(missing_axis) = axis_values.iter().position(Option::is_none) {
1331 let margin_name = cols
1332 .get(missing_axis)
1333 .and_then(|col| ds.headers.get(*col))
1334 .map(String::as_str)
1335 .unwrap_or("<unnamed>");
1336 return Err(format!(
1337 "tensor k axis aliases must specify every margin; missing axis {missing_axis} ({margin_name})"
1338 ));
1339 }
1340 return Ok((
1341 axis_values
1342 .into_iter()
1343 .map(|k| k.expect("missing axis values rejected above"))
1344 .collect(),
1345 false,
1346 ));
1347 }
1348 let Some(raw) = raw else {
1349 let inferred = heuristic_tensor_margin_knots(cols, ds);
1350 return Ok((inferred, true));
1351 };
1352 let entries = split_list_option(raw);
1353 if entries.len() == 1 {
1354 let k: usize = entries[0]
1355 .parse()
1356 .map_err(|err| format!("invalid tensor k '{}': {err}", entries[0]))?;
1357 return Ok((vec![k; cols.len()], false));
1358 }
1359 if entries.len() != cols.len() {
1360 return Err(format!(
1361 "tensor k list length {} must match smooth dimension {}",
1362 entries.len(),
1363 cols.len()
1364 ));
1365 }
1366 let mut out = Vec::with_capacity(entries.len());
1367 for entry in entries {
1368 let k: usize = entry
1369 .parse()
1370 .map_err(|err| format!("invalid tensor k '{entry}': {err}"))?;
1371 out.push(k);
1372 }
1373 Ok((out, false))
1374}
1375
1376fn parse_tensor_identifiability(
1385 options: &BTreeMap<String, String>,
1386 kind: SmoothKind,
1387) -> Result<TensorBSplineIdentifiability, String> {
1388 let Some(raw) = options.get("identifiability").map(String::as_str) else {
1389 return Ok(match kind {
1390 SmoothKind::Ti => TensorBSplineIdentifiability::MarginalSumToZero,
1391 _ => TensorBSplineIdentifiability::default(),
1392 });
1393 };
1394 match raw.trim().to_ascii_lowercase().as_str() {
1395 "none" => Ok(TensorBSplineIdentifiability::None),
1396 "sum_tozero" | "sum-to-zero" | "center_sum_tozero" | "center-sum-to-zero" | "centered"
1397 | "sumtozero" => Ok(TensorBSplineIdentifiability::SumToZero),
1398 "marginal_sum_tozero" | "marginal-sum-to-zero" | "marginal_sumtozero"
1399 | "marginalsumtozero" | "interaction" => {
1400 Ok(TensorBSplineIdentifiability::MarginalSumToZero)
1401 }
1402 other => Err(TermBuilderError::unsupported_feature(format!(
1403 "invalid tensor identifiability '{other}'; expected one of: none, sum_tozero, marginal_sum_tozero"
1404 ))
1405 .to_string()),
1406 }
1407}
1408
1409fn bspline_boundary_declares_periodic_axis(options: &BTreeMap<String, String>) -> bool {
1410 options
1411 .get("boundary")
1412 .or_else(|| options.get("bc"))
1413 .map(|raw| {
1414 parse_option_list(raw)
1415 .into_iter()
1416 .any(|value| matches!(value.as_str(), "periodic" | "cyclic" | "cc"))
1417 })
1418 .unwrap_or(false)
1419}
1420
/// Maps a smooth-type alias onto its canonical name.
///
/// The comparison is exact (no trimming or case folding happens here), and any
/// name that is not a listed alias is handed back unchanged.
pub(crate) fn canonicalize_smooth_type(raw: &str) -> &str {
    // (alias, canonical name) pairs; several aliases may share one canonical name.
    const ALIASES: &[(&str, &str)] = &[
        ("tp", "tps"),
        ("gp", "matern"),
        ("curv", "curvature"),
        ("constant_curvature", "curvature"),
        ("mkappa", "curvature"),
        ("mjs", "measurejet"),
        ("measure_jet", "measurejet"),
        ("web", "measurejet"),
    ];
    ALIASES
        .iter()
        .find(|(alias, _)| *alias == raw)
        .map_or(raw, |&(_, canonical)| canonical)
}
1463
1464pub(crate) fn tensor_margin_bs_is_supported(margin_bs: &str) -> bool {
1475 matches!(
1476 canonicalize_smooth_type(margin_bs),
1477 "tps" | "ps" | "bs" | "bspline" | "cr" | "cs" | "cc" | "cp" | "cyclic"
1478 )
1479}
1480
/// Reports whether a smooth's options mention periodicity in any form.
///
/// The result is `true` when either:
/// * a `periodic` or `cyclic` key is present (its value is not inspected), or
/// * the `boundary` value contains `periodic` or `cyclic` as a case-insensitive
///   substring. `bc` is consulted only when `boundary` is absent.
pub(crate) fn smooth_options_declare_periodic(options: &BTreeMap<String, String>) -> bool {
    const MARKERS: [&str; 2] = ["periodic", "cyclic"];

    if MARKERS.iter().any(|key| options.contains_key(*key)) {
        return true;
    }
    match options.get("boundary").or_else(|| options.get("bc")) {
        Some(boundary) => {
            let lowered = boundary.to_ascii_lowercase();
            MARKERS.iter().any(|marker| lowered.contains(*marker))
        }
        None => false,
    }
}
1498
1499pub(crate) fn bs_selector_is_vector(raw: &str) -> bool {
1516 let trimmed = raw.trim();
1517 let bracketed = (trimmed.starts_with('[') && trimmed.ends_with(']'))
1518 || (trimmed.starts_with("c(") || trimmed.starts_with("C(")) && trimmed.ends_with(')')
1519 || (trimmed.starts_with('(') && trimmed.ends_with(')'));
1520 bracketed && !parse_option_list(trimmed).is_empty()
1521}
1522
1523pub fn resolve_smooth_type_name(
1524 kind: SmoothKind,
1525 n_cols: usize,
1526 options: &BTreeMap<String, String>,
1527) -> String {
1528 let selector = options.get("type").or_else(|| options.get("bs"));
1529 if let Some(raw) = selector
1534 && bs_selector_is_vector(raw)
1535 && matches!(kind, SmoothKind::Te | SmoothKind::Ti | SmoothKind::T2)
1536 {
1537 return "tensor".to_string();
1538 }
1539 selector
1540 .map(|s| canonicalize_smooth_type(&s.to_ascii_lowercase()).to_string())
1541 .unwrap_or_else(|| match kind {
1542 SmoothKind::Te | SmoothKind::Ti | SmoothKind::T2 => "tensor".to_string(),
1543 SmoothKind::S if n_cols == 1 => "bspline".to_string(),
1544 SmoothKind::S if smooth_options_declare_periodic(options) => "tensor".to_string(),
1548 SmoothKind::S => "tps".to_string(),
1549 })
1550}
1551
/// Reports whether `canonical_type` is one of the smooth types this predicate
/// associates with the spatial center heuristic: `tps`, `matern`, or `duchon`.
///
/// The comparison is exact, so the caller is expected to pass an already
/// canonical name.
pub fn smooth_type_uses_spatial_center_heuristic(canonical_type: &str) -> bool {
    ["tps", "matern", "duchon"].contains(&canonical_type)
}
1563
1564pub fn build_smooth_basis(
1565 kind: SmoothKind,
1566 vars: &[String],
1567 cols: &[usize],
1568 options: &BTreeMap<String, String>,
1569 ds: &Dataset,
1570 inference_notes: &mut Vec<String>,
1571 policy: &ResourcePolicy,
1572 smooth_coordinate_count: usize,
1573) -> Result<SmoothBasisSpec, String> {
1574 let coord_cols: Vec<(&String, usize)> = vars
1590 .iter()
1591 .zip(cols.iter().copied())
1592 .filter(|(_, col)| !matches!(ds.column_kinds.get(*col), Some(ColumnKindTag::Categorical)))
1593 .collect();
1594 if !coord_cols.is_empty() {
1595 let views: Vec<ArrayView1<'_, f64>> = coord_cols
1596 .iter()
1597 .map(|(_, col)| ds.values.column(*col))
1598 .collect();
1599 let n_rows = views[0].len();
1600 let mut distinct_points = std::collections::HashSet::<Vec<u64>>::new();
1601 for r in 0..n_rows {
1602 let key: Vec<u64> = views
1603 .iter()
1604 .map(|v| {
1605 let x = v[r];
1606 let norm = if x == 0.0 { 0.0 } else { x };
1607 norm.to_bits()
1608 })
1609 .collect();
1610 distinct_points.insert(key);
1611 if distinct_points.len() > 1 {
1612 break;
1613 }
1614 }
1615 if distinct_points.len() <= 1 {
1616 return Err(TermBuilderError::degenerate_data(if coord_cols.len() == 1 {
1617 let var = coord_cols[0].0;
1618 format!(
1619 "smooth term over '{var}' has only one unique value in the training data \
1620 — a smooth on a constant column is degenerate and would only fit the response mean. \
1621 Remove `{var}` from the smooth, drop the term, or check the data."
1622 )
1623 } else {
1624 let names = coord_cols
1625 .iter()
1626 .map(|(v, _)| v.as_str())
1627 .collect::<Vec<_>>()
1628 .join(", ");
1629 format!(
1630 "smooth term over ({names}) has only one unique joint coordinate in the training \
1631 data — every coordinate is constant, so the smooth is degenerate and would only \
1632 fit the response mean. Drop the term or check the data."
1633 )
1634 })
1635 .to_string());
1636 }
1637 }
1638 if let Some(by_name) = options.get("by").cloned() {
1639 let by_col = options
1640 .get("__by_col")
1641 .and_then(|raw| raw.parse::<usize>().ok())
1642 .or_else(|| vars.iter().position(|v| v == &by_name).map(|idx| cols[idx]))
1643 .ok_or_else(|| format!("unknown by= column '{by_name}'"))?;
1644 let mut inner_options = options.clone();
1645 inner_options.remove("by");
1646 inner_options.remove("__by_col");
1647 inner_options.remove("id");
1648 let inner = build_smooth_basis(
1649 kind,
1650 vars,
1651 cols,
1652 &inner_options,
1653 ds,
1654 inference_notes,
1655 policy,
1656 smooth_coordinate_count,
1657 )?;
1658 let by_kind = match ds.column_kinds.get(by_col).copied() {
1659 Some(ColumnKindTag::Categorical) => ByVarKind::Factor {
1660 feature_col: by_col,
1661 ordered: option_bool(options, "ordered").unwrap_or(false),
1662 frozen_levels: None,
1663 },
1664 Some(ColumnKindTag::Continuous | ColumnKindTag::Binary) => ByVarKind::Numeric {
1665 feature_col: by_col,
1666 },
1667 None => {
1668 return Err(format!(
1669 "internal column-kind lookup failed for by='{by_name}'"
1670 ));
1671 }
1672 };
1673 return Ok(SmoothBasisSpec::BySmooth {
1674 smooth: Box::new(inner),
1675 by_kind,
1676 });
1677 }
1678
1679 let smooth_double_penalty = option_bool(options, "double_penalty").unwrap_or(true);
1680 let type_opt = resolve_smooth_type_name(kind, cols.len(), options);
1681
1682 if matches!(type_opt.as_str(), "fs" | "sz" | "re") {
1683 validate_known_options(
1684 type_opt.as_str(),
1685 options,
1686 &[
1687 "type",
1688 "bs",
1689 "k",
1690 "basis_dim",
1691 "basis-dim",
1692 "basisdim",
1693 "knots",
1694 "knot_placement",
1695 "knot-placement",
1696 "knotplacement",
1697 "degree",
1698 "penalty_order",
1699 "m",
1700 "double_penalty",
1701 "ordered",
1702 ],
1703 )?;
1704 if cols.len() != 2 {
1705 return Err(format!(
1706 "{} factor-smooth currently expects exactly two variables (one numeric, one categorical)",
1707 type_opt
1708 ));
1709 }
1710 let kinds = cols
1711 .iter()
1712 .map(|&c| ds.column_kinds.get(c).copied())
1713 .collect::<Vec<_>>();
1714 let (cont_idx, group_idx) = if type_opt == "re" {
1715 match (kinds[0], kinds[1]) {
1717 (Some(ColumnKindTag::Categorical), _) => (1usize, 0usize),
1718 (_, Some(ColumnKindTag::Categorical)) => (0usize, 1usize),
1719 _ => (1usize, 0usize),
1720 }
1721 } else {
1722 match (kinds[0], kinds[1]) {
1723 (_, Some(ColumnKindTag::Categorical)) => (0usize, 1usize),
1724 (Some(ColumnKindTag::Categorical), _) => (1usize, 0usize),
1725 _ => {
1726 return Err(format!(
1727 "{} factor-smooth requires one categorical factor variable",
1728 type_opt
1729 ));
1730 }
1731 }
1732 };
1733 let c = cols[cont_idx];
1734 let (minv, maxv) = col_minmax(ds.values.column(c))?;
1735 let degree = if type_opt == "re" {
1736 1
1737 } else {
1738 option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE)
1739 };
1740 let pooled_internal = heuristic_knots_for_column(ds.values.column(c));
1760 let default_internal = if type_opt == "re" {
1761 0
1774 } else {
1775 let min_group_resolution =
1776 min_per_group_unique_count(ds.values.column(c), ds.values.column(cols[group_idx]));
1777 let basis_cap = min_group_resolution.saturating_sub(2).max(degree + 2);
1785 let internal_cap = basis_cap.saturating_sub(degree + 1);
1786 let capped = pooled_internal.min(internal_cap.max(1));
1787 let fs_default_internal = FACTOR_SMOOTH_DEFAULT_BASIS_DIM
1803 .saturating_sub(degree + 1)
1804 .max(1);
1805 capped.min(fs_default_internal)
1806 };
1807 let (n_knots, _, effective_degree) =
1808 parse_ps_internal_knots(options, degree, default_internal)?;
1809 let penalty_order = option_usize(options, "penalty_order")
1810 .unwrap_or(if effective_degree > 1 { 2 } else { 1 })
1811 .min(effective_degree);
1812 let marginal_knotspec = resolve_nonperiodic_bspline_knotspec(
1846 options,
1847 ds.values.column(c),
1848 (minv, maxv),
1849 effective_degree,
1850 n_knots,
1851 )?;
1852 let marginal = BSplineBasisSpec {
1853 degree: effective_degree,
1854 penalty_order,
1855 knotspec: marginal_knotspec,
1856 double_penalty: option_bool(options, "double_penalty")
1867 .unwrap_or(type_opt.as_str() != "sz"),
1868 identifiability: BSplineIdentifiability::None,
1869 boundary_conditions: Default::default(),
1870 boundary: OneDimensionalBoundary::Open,
1871 };
1872 let flavour = match type_opt.as_str() {
1873 "fs" => FactorSmoothFlavour::Fs {
1874 m_null_penalty_orders: vec![
1875 option_usize(options, "m").unwrap_or(DEFAULT_PENALTY_ORDER),
1876 ],
1877 },
1878 "sz" => FactorSmoothFlavour::Sz,
1879 "re" => FactorSmoothFlavour::Re,
1880 other => {
1882 return Err(format!(
1883 "internal: factor-smooth flavour dispatch reached unexpected type `{}`",
1884 other
1885 ));
1886 }
1887 };
1888 return Ok(SmoothBasisSpec::FactorSmooth {
1889 spec: FactorSmoothSpec {
1890 continuous_cols: vec![c],
1891 group_col: cols[group_idx],
1892 marginal,
1893 flavour,
1894 group_frozen_levels: None,
1895 frozen_global_orthogonality: None,
1896 },
1897 });
1898 }
1899
1900 match type_opt.as_str() {
1901 "cyclic" | "cc" | "cp" | "cyclic-ps" => {
1902 validate_known_options(
1903 "cyclic",
1904 options,
1905 &[
1906 "type",
1907 "bs",
1908 "by",
1909 "k",
1910 "basis_dim",
1911 "basis-dim",
1912 "basisdim",
1913 "degree",
1914 "penalty_order",
1915 "period",
1916 "periods",
1917 "period_start",
1918 "period_end",
1919 "start",
1920 "end",
1921 "origin",
1922 "origins",
1923 "period_origin",
1924 "period-origin",
1925 "domain_origin",
1926 "double_penalty",
1927 "id",
1928 "__by_col",
1929 "identifiability",
1930 ],
1931 )?;
1932 if cols.len() != 1 {
1933 return Err(format!(
1934 "periodic smooth expects one variable, got {}",
1935 cols.len()
1936 ));
1937 }
1938 let c = cols[0];
1939 let (minv, maxv) = col_minmax(ds.values.column(c))?;
1940 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
1941 let mut default_internal = heuristic_knots_for_column(ds.values.column(c));
1942 if ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
1943 default_internal = default_internal.min(1);
1944 }
1945 let cyclic_default_basis_cap = CYCLIC_DEFAULT_BASIS_DIM.max(degree + 1);
1959 let default_basis = (default_internal + degree + 1).min(cyclic_default_basis_cap);
1960 let num_basis = option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
1961 .unwrap_or(default_basis);
1962 if num_basis < degree + 1 {
1963 return Err(format!(
1964 "periodic smooth: k={} too small for degree {}; expected k >= {}",
1965 num_basis,
1966 degree,
1967 degree + 1
1968 ));
1969 }
1970 let periodic_axes = [true];
1981 let periods = parse_periods(options, &periodic_axes)?;
1982 let origins = parse_period_origins(options, &periodic_axes)?;
1983 let (domain_start, period) = if let Some(p) = periods[0] {
1984 (origins[0].unwrap_or(minv), p)
1985 } else {
1986 parse_periodic_domain_1d(options, minv, maxv)?
1987 };
1988 Ok(SmoothBasisSpec::BSpline1D {
1989 feature_col: c,
1990 spec: BSplineBasisSpec {
1991 degree,
1992 penalty_order: option_usize(options, "penalty_order")
1993 .unwrap_or(DEFAULT_PENALTY_ORDER),
1994 knotspec: BSplineKnotSpec::PeriodicUniform {
1995 data_range: (domain_start, domain_start + period),
1996 num_basis,
1997 },
1998 double_penalty: smooth_double_penalty,
1999 identifiability: BSplineIdentifiability::default(),
2000 boundary_conditions: Default::default(),
2001 boundary: OneDimensionalBoundary::Cyclic {
2002 start: domain_start,
2003 end: domain_start + period,
2004 },
2005 },
2006 })
2007 }
2008 "bspline" | "ps" | "p-spline" | "cr" | "cs" => {
2009 let validation_name = match type_opt.as_str() {
2023 "cr" => "cr",
2024 "cs" => "cs",
2025 _ => "bspline",
2026 };
2027 validate_known_options(
2028 validation_name,
2029 options,
2030 &[
2031 "type",
2032 "bs",
2033 "by",
2034 "k",
2035 "basis_dim",
2036 "basis-dim",
2037 "basisdim",
2038 "knots",
2039 "knot_placement",
2040 "knot-placement",
2041 "knotplacement",
2042 "degree",
2043 "penalty_order",
2044 "boundary",
2045 "bc",
2046 "boundary_conditions",
2047 "bc_left",
2048 "bc_right",
2049 "left_bc",
2050 "right_bc",
2051 "start_bc",
2052 "end_bc",
2053 "side",
2054 "anchor",
2055 "anchor_value",
2056 "value",
2057 "anchor_left",
2058 "left_anchor",
2059 "anchor_right",
2060 "right_anchor",
2061 "periodic",
2062 "period",
2063 "periods",
2064 "period_start",
2065 "period_end",
2066 "origin",
2067 "double_penalty",
2068 "by",
2069 "id",
2070 "__by_col",
2071 "identifiability",
2072 "by",
2073 ],
2074 )?;
2075 if cols.len() != 1 {
2076 return Err(TermBuilderError::incompatible_config(format!(
2077 "bspline smooth expects one variable, got {}",
2078 cols.len()
2079 ))
2080 .to_string());
2081 }
2082 let c = cols[0];
2083 let (minv, maxv) = col_minmax(ds.values.column(c))?;
2084 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
2085 let default_internal = heuristic_knots_for_column(ds.values.column(c));
2086 let (mut n_knots, inferred, effective_degree) =
2087 parse_ps_internal_knots(options, degree, default_internal)?;
2088 let periodic_axes = parse_periodic_axes(options, 1).map_err(|e| e.to_string())?;
2089 if periodic_axes[0] && effective_degree != degree {
2094 return Err(TermBuilderError::invalid_option(format!(
2095 "periodic smooth: k={} too small for degree {}; expected k >= {}",
2096 effective_degree + 1,
2097 degree,
2098 degree + 1
2099 ))
2100 .to_string());
2101 }
2102 if inferred && ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2103 n_knots = n_knots.min(1);
2104 }
2105 if inferred {
2106 let unique = unique_count_column(ds.values.column(c));
2107 let ceiling = ((unique as f64).cbrt() as usize).max(20);
2108 inference_notes.push(format!(
2109 "Automatically set {} internal knots for smooth '{}' from {} unique values (rule: clamp(unique/4, 4..max(20, cbrt(unique))) = clamp(unique/4, 4..{})). Override with knots=... or k=....",
2110 n_knots,
2111 vars.join(","),
2112 unique,
2113 ceiling,
2114 ));
2115 }
2116 let boundary_conditions =
2117 if periodic_axes[0] && bspline_boundary_declares_periodic_axis(options) {
2118 BSplineBoundaryConditions::default()
2119 } else {
2120 parse_bspline_boundary_conditions(options).map_err(|e| e.to_string())?
2121 };
2122 let periods = parse_periods(options, &periodic_axes).map_err(|e| e.to_string())?;
2123 let origins =
2124 parse_period_origins(options, &periodic_axes).map_err(|e| e.to_string())?;
2125 let (knotspec, boundary) = if periodic_axes[0] {
2126 if !boundary_conditions.is_free() {
2127 return Err(TermBuilderError::incompatible_config(
2128 "periodic B-splines cannot also declare endpoint boundary conditions",
2129 )
2130 .to_string());
2131 }
2132 {
2133 let (domain_start, p_value) = if periods[0].is_some() {
2134 (origins[0].unwrap_or(minv), periods[0].unwrap())
2135 } else {
2136 parse_periodic_domain_1d(options, minv, maxv).map_err(|e| e.to_string())?
2137 };
2138 let domain_end = domain_start + p_value;
2139 (
2140 BSplineKnotSpec::PeriodicUniform {
2141 data_range: (domain_start, domain_end),
2142 num_basis: n_knots + effective_degree + 1,
2143 },
2144 OneDimensionalBoundary::Cyclic {
2145 start: domain_start,
2146 end: domain_end,
2147 },
2148 )
2149 }
2150 } else if type_opt == "cr" || type_opt == "cs" {
2151 let k_cr = (n_knots + effective_degree + 1).max(CR_MIN_KNOTS);
2168 let knotspec = match capped_cr_marginal_knotspec(
2169 ds.values.column(c),
2170 k_cr,
2171 &vars.join(","),
2172 inference_notes,
2173 )? {
2174 Some(cr_knotspec) => cr_knotspec,
2175 None => resolve_nonperiodic_bspline_knotspec(
2176 options,
2177 ds.values.column(c),
2178 (minv, maxv),
2179 effective_degree,
2180 n_knots,
2181 )?,
2182 };
2183 (knotspec, parse_cyclic_boundary(options, minv, maxv)?)
2184 } else {
2185 (
2186 resolve_nonperiodic_bspline_knotspec(
2187 options,
2188 ds.values.column(c),
2189 (minv, maxv),
2190 effective_degree,
2191 n_knots,
2192 )?,
2193 parse_cyclic_boundary(options, minv, maxv)?,
2194 )
2195 };
2196 let double_penalty = if type_opt == "cr" {
2200 option_bool(options, "double_penalty").unwrap_or(false)
2201 } else {
2202 smooth_double_penalty
2203 };
2204 let penalty_order = option_usize(options, "penalty_order")
2209 .unwrap_or(DEFAULT_PENALTY_ORDER)
2210 .min(effective_degree);
2211 Ok(SmoothBasisSpec::BSpline1D {
2212 feature_col: c,
2213 spec: BSplineBasisSpec {
2214 degree: effective_degree,
2215 penalty_order,
2216 knotspec,
2217 double_penalty,
2218 identifiability: BSplineIdentifiability::default(),
2219 boundary,
2220 boundary_conditions,
2221 },
2222 })
2223 }
2224 "tps" | "thinplate" | "thin-plate" => {
2225 validate_known_options(
2226 "thinplate",
2227 options,
2228 &[
2229 SECONDARY_CENTER_CAP_OPTION,
2230 "type",
2231 "bs",
2232 "by",
2233 "length_scale",
2234 "centers",
2235 "k",
2236 "basis_dim",
2237 "basis-dim",
2238 "basisdim",
2239 "knots",
2240 "include_intercept",
2241 "double_penalty",
2242 "by",
2243 "id",
2244 "__by_col",
2245 "identifiability",
2246 "by",
2247 "scale_dims",
2248 ],
2249 )?;
2250 let plan = plan_spatial_basis(
2251 ds.values.nrows(),
2252 cols.len(),
2253 CenterCountRequest::Default,
2254 DuchonNullspaceOrder::Linear,
2255 option_bool(options, "scale_dims").unwrap_or(false),
2256 policy,
2257 )
2258 .map_err(|e| e.to_string())?;
2259 let default_centers = plan.centers;
2269 let centers = parse_countwith_basis_alias(
2270 options,
2271 "centers",
2272 cap_default_spatial_centers(options, default_centers),
2273 )?;
2274 let center_strategy = if has_explicit_countwith_basis_alias(options, "centers") {
2275 spatial_center_strategy_for_dimension(centers, cols.len())
2276 } else {
2277 auto_spatial_center_strategy(centers, cols.len())
2278 };
2279 Ok(SmoothBasisSpec::ThinPlate {
2280 feature_cols: cols.to_vec(),
2281 spec: ThinPlateBasisSpec {
2282 center_strategy,
2283 periodic: parse_periodic_axes_option(options, cols.len())?,
2284 length_scale: option_f64(options, "length_scale").unwrap_or(0.0),
2292 double_penalty: smooth_double_penalty,
2293 identifiability: parse_spatial_identifiability(options)
2294 .map_err(|e| e.to_string())?,
2295 radial_reparam: None,
2296 },
2297 input_scales: None,
2298 })
2299 }
2300 "sphere" | "s2" | "sos" => {
2301 validate_known_options(
2302 "sphere",
2303 options,
2304 &[
2305 "type",
2306 "bs",
2307 "by",
2308 "centers",
2309 "k",
2310 "basis_dim",
2311 "basis-dim",
2312 "basisdim",
2313 "knots",
2314 "penalty_order",
2315 "m",
2316 "double_penalty",
2317 "id",
2318 "__by_col",
2319 "kernel",
2320 "method",
2321 "radians",
2322 "units",
2323 "degree",
2324 "l",
2325 "max_degree",
2326 "max-degree",
2327 ],
2328 )?;
2329 if cols.len() != 2 {
2330 return Err(format!(
2331 "sphere smooth expects exactly two variables (lat, lon), got {}",
2332 cols.len()
2333 ));
2334 }
2335 let radians = option_bool(options, "radians").unwrap_or_else(|| {
2336 options
2337 .get("units")
2338 .map(|u| u.eq_ignore_ascii_case("radian") || u.eq_ignore_ascii_case("radians"))
2339 .unwrap_or(false)
2340 });
2341 let degree_requested = options.contains_key("degree")
2347 || options.contains_key("l")
2348 || options.contains_key("max_degree")
2349 || options.contains_key("max-degree");
2350 let kernel = options
2351 .get("kernel")
2352 .or_else(|| options.get("method"))
2353 .map(|raw| strip_quotes(raw).trim().to_ascii_lowercase())
2354 .unwrap_or_else(|| {
2355 if degree_requested {
2356 "harmonic".to_string()
2357 } else {
2358 "sobolev".to_string()
2359 }
2360 });
2361 let (method, wahba_kernel) = match kernel.as_str() {
2362 "sobolev" | "wahba" | "wahba_sobolev" | "wahba-sobolev" => {
2363 (SphereMethod::Wahba, SphereWahbaKernel::Sobolev)
2364 }
2365 "pseudo" | "mgcv" | "sos" | "wahba_pseudo" | "wahba-pseudo" => {
2366 (SphereMethod::Wahba, SphereWahbaKernel::Pseudo)
2367 }
2368 "harmonic" | "spherical_harmonic" | "spherical-harmonic" => {
2369 (SphereMethod::Harmonic, SphereWahbaKernel::Sobolev)
2370 }
2371 other => {
2372 return Err(format!(
2373 "unsupported sphere kernel '{other}'; expected sobolev, pseudo, or harmonic"
2374 ));
2375 }
2376 };
2377 let max_degree = if matches!(method, SphereMethod::Harmonic) {
2378 let degree =
2379 option_usize_any(options, &["degree", "l", "max_degree", "max-degree"])
2380 .or_else(|| option_usize(options, "centers"))
2381 .or_else(|| {
2382 option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
2383 .and_then(|k| (1..=128).find(|&l| l * (l + 2) >= k))
2384 })
2385 .unwrap_or_else(|| default_spherical_harmonic_degree(ds.values.nrows()));
2386 if degree == 0 {
2387 return Err("sphere smooth requires degree/max_degree >= 1".to_string());
2388 }
2389 if degree > 32 {
2390 return Err(format!(
2391 "sphere smooth max_degree={} is too large for the dense harmonic engine (limit 32)",
2392 degree
2393 ));
2394 }
2395 Some(degree)
2396 } else {
2397 None
2398 };
2399 let penalty_order = option_usize(options, "penalty_order")
2400 .or_else(|| option_usize(options, "m"))
2401 .unwrap_or(DEFAULT_PENALTY_ORDER);
2402 let center_strategy = if matches!(method, SphereMethod::Wahba) {
2403 let mut centers = parse_countwith_basis_alias(
2404 options,
2405 "centers",
2406 default_num_centers(ds.values.nrows(), cols.len()),
2407 )?;
2408 if penalty_order >= 4 {
2409 centers = centers.max(30);
2410 }
2411 CenterStrategy::FarthestPoint {
2412 num_centers: centers,
2413 }
2414 } else {
2415 CenterStrategy::FarthestPoint { num_centers: 0 }
2416 };
2417 Ok(SmoothBasisSpec::Sphere {
2418 feature_cols: cols.to_vec(),
2419 spec: SphericalSplineBasisSpec {
2420 center_strategy,
2421 penalty_order,
2422 double_penalty: smooth_double_penalty,
2423 radians,
2424 method,
2425 max_degree,
2426 wahba_kernel,
2427 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2428 },
2429 })
2430 }
2431 "curvature" => {
2432 validate_known_options(
2438 "curvature",
2439 options,
2440 &[
2441 "type",
2442 "bs",
2443 "by",
2444 "centers",
2445 "k",
2446 "basis_dim",
2447 "basis-dim",
2448 "basisdim",
2449 "knots",
2450 "kappa",
2451 "length_scale",
2452 "double_penalty",
2453 "id",
2454 "__by_col",
2455 ],
2456 )?;
2457 let kappa = option_f64(options, "kappa").unwrap_or(0.0);
2458 if !kappa.is_finite() {
2459 return Err("curvature smooth requires a finite kappa".to_string());
2460 }
2461 let length_scale = option_f64(options, "length_scale").unwrap_or(0.0);
2462 if !length_scale.is_finite() || length_scale < 0.0 {
2463 return Err(format!(
2464 "curvature smooth length_scale must be positive (or omitted for auto); got {length_scale}"
2465 ));
2466 }
2467 let centers = parse_countwith_basis_alias(
2468 options,
2469 "centers",
2470 default_num_centers(ds.values.nrows(), cols.len()),
2471 )?;
2472 if centers < 2 {
2473 return Err("curvature smooth requires at least 2 centers".to_string());
2474 }
2475 Ok(SmoothBasisSpec::ConstantCurvature {
2476 feature_cols: cols.to_vec(),
2477 spec: ConstantCurvatureBasisSpec {
2478 center_strategy: CenterStrategy::FarthestPoint {
2479 num_centers: centers,
2480 },
2481 kappa,
2482 length_scale,
2485 double_penalty: option_bool(options, "double_penalty").unwrap_or(false),
2492 identifiability: ConstantCurvatureIdentifiability::CenterSumToZero,
2493 },
2494 })
2495 }
2496 "measurejet" => {
2497 validate_known_options(
2503 "measurejet",
2504 options,
2505 &[
2506 "type",
2507 "bs",
2508 "by",
2509 "centers",
2510 "k",
2511 "basis_dim",
2512 "basis-dim",
2513 "basisdim",
2514 "knots",
2515 "s",
2516 "alpha",
2517 "tau",
2518 "scales",
2519 "length_scale",
2520 "double_penalty",
2521 "multiscale",
2522 "learn_length_scale",
2523 "id",
2524 "__by_col",
2525 ],
2526 )?;
2527 let order_s = option_f64(options, "s").unwrap_or(0.0);
2528 if !(order_s.is_finite() && (order_s == 0.0 || (order_s > 0.0 && order_s < 2.0))) {
2531 return Err(format!(
2532 "measurejet smooth s must lie in (0, 2) (or be omitted for auto); got {order_s}"
2533 ));
2534 }
2535 let alpha =
2543 option_f64(options, "alpha").unwrap_or(MeasureJetBasisSpec::default().alpha);
2544 if !alpha.is_finite() {
2545 return Err("measurejet smooth requires a finite alpha".to_string());
2546 }
2547 let tau0 = option_f64(options, "tau").unwrap_or(1e-3);
2548 if !(tau0.is_finite() && tau0 >= 0.0) {
2549 return Err(format!(
2550 "measurejet smooth tau must be finite and nonnegative; got {tau0}"
2551 ));
2552 }
2553 let num_scales = option_usize(options, "scales").unwrap_or(0);
2554 let length_scale = option_f64(options, "length_scale").unwrap_or(0.0);
2555 if !length_scale.is_finite() || length_scale < 0.0 {
2556 return Err(format!(
2557 "measurejet smooth length_scale must be positive (or omitted for auto); got {length_scale}"
2558 ));
2559 }
2560 let centers = parse_countwith_basis_alias(
2561 options,
2562 "centers",
2563 default_num_centers(ds.values.nrows(), cols.len()),
2564 )?;
2565 if centers < 3 {
2566 return Err("measurejet smooth requires at least 3 centers".to_string());
2567 }
2568 let multiscale = option_bool(options, "multiscale").unwrap_or(false);
2572 let learn_length_scale = option_bool(options, "learn_length_scale").unwrap_or(false);
2577 Ok(SmoothBasisSpec::MeasureJet {
2578 feature_cols: cols.to_vec(),
2579 spec: MeasureJetBasisSpec {
2580 center_strategy: CenterStrategy::FarthestPoint {
2581 num_centers: centers,
2582 },
2583 order_s,
2584 alpha,
2585 tau0,
2586 num_scales,
2587 length_scale,
2590 double_penalty: smooth_double_penalty,
2591 learn_length_scale,
2592 multiscale,
2593 identifiability: MeasureJetIdentifiability::CenterSumToZero,
2594 frozen_quadrature: None,
2595 },
2596 input_scales: None,
2597 })
2598 }
2599 "matern" => {
2600 validate_known_options(
2605 "matern",
2606 options,
2607 &[
2608 SECONDARY_CENTER_CAP_OPTION,
2609 "type",
2610 "bs",
2611 "by",
2612 "nu",
2613 "length_scale",
2614 "centers",
2615 "k",
2616 "basis_dim",
2617 "basis-dim",
2618 "basisdim",
2619 "knots",
2620 "include_intercept",
2621 "double_penalty",
2622 "by",
2623 "id",
2624 "__by_col",
2625 "identifiability",
2626 "by",
2627 "scale_dims",
2628 ],
2629 )?;
2630 let plan = plan_spatial_basis(
2631 ds.values.nrows(),
2632 cols.len(),
2633 CenterCountRequest::Default,
2634 DuchonNullspaceOrder::Zero,
2635 option_bool(options, "scale_dims").unwrap_or(false),
2636 policy,
2637 )
2638 .map_err(|e| e.to_string())?;
2639 let centers = parse_countwith_basis_alias(
2640 options,
2641 "centers",
2642 cap_default_spatial_centers(
2643 options,
2644 default_matern_center_count(ds.values.nrows(), cols.len(), plan.centers),
2645 ),
2646 )?;
2647 let center_strategy = if has_explicit_countwith_basis_alias(options, "centers") {
2648 spatial_center_strategy_for_dimension(centers, cols.len())
2649 } else {
2650 auto_spatial_center_strategy(centers, cols.len())
2651 };
2652 let nu = parse_matern_nu(options.get("nu").map(String::as_str).unwrap_or("5/2"))?;
2653 if matches!(nu, MaternNu::Half) && cols.len() >= 2 {
2659 return Err(TermBuilderError::unsupported_feature(format!(
2660 "matern() with nu=1/2 is not supported for d>=2 (got {} covariates): \
2661 the exponential kernel's Laplacian is singular at center collisions, \
2662 which makes the operator-collocation penalty non-invertible. \
2663 Choose nu>=3/2 (e.g. nu=3/2 or the default nu=5/2) for multi-dimensional smooths.",
2664 cols.len()
2665 ))
2666 .to_string());
2667 }
2668 let aniso_log_scales = if option_bool(options, "scale_dims").unwrap_or(false) {
2669 Some(vec![0.0; cols.len()])
2670 } else {
2671 None
2672 };
2673 Ok(SmoothBasisSpec::Matern {
2674 feature_cols: cols.to_vec(),
2675 spec: MaternBasisSpec {
2676 center_strategy,
2677 periodic: parse_periodic_axes_option(options, cols.len())?,
2678 length_scale: option_f64(options, "length_scale")
2679 .unwrap_or_else(|| default_matern_length_scale(ds, cols)),
2680 nu,
2681 include_intercept: option_bool(options, "include_intercept").unwrap_or(false),
2682 double_penalty: smooth_double_penalty,
2683 identifiability: parse_matern_identifiability(options)
2684 .map_err(|e| e.to_string())?,
2685 aniso_log_scales,
2686 nullspace_shrinkage_survived: None,
2691 },
2692 input_scales: None,
2693 })
2694 }
2695 "duchon" => {
2696 validate_known_options(
2697 "duchon",
2698 options,
2699 &[
2700 SECONDARY_CENTER_CAP_OPTION,
2701 "type",
2702 "bs",
2703 "by",
2704 "length_scale",
2705 "centers",
2706 "k",
2707 "basis_dim",
2708 "basis-dim",
2709 "basisdim",
2710 "knots",
2711 "power",
2712 "p",
2713 "nullspace_order",
2714 "order",
2715 "identifiability",
2716 "by",
2717 "periodic",
2718 "cyclic",
2719 "period",
2720 "period_start",
2721 "period_end",
2722 "scale_dims",
2723 "double_penalty",
2724 "by",
2725 "id",
2726 "__by_col",
2727 ],
2728 )?;
2729 if options.contains_key("double_penalty") {
2730 return Err(TermBuilderError::incompatible_config(format!(
2731 "Duchon smooth '{}' does not support double_penalty; the Duchon smoother already ships its native reproducing-norm penalty plus a null-space shrinkage ridge.",
2732 vars.join(", ")
2733 ))
2734 .to_string());
2735 }
2736 let requested_nullspace_order = parse_duchon_order(options)?;
2737 let length_scale = option_f64_strict(options, "length_scale")?;
2738 let (nullspace_order, power) = match parse_duchon_power_policy(options)? {
2751 DuchonPowerPolicy::Explicit(req_power) => {
2752 if length_scale.is_some() && req_power.fract() != 0.0 {
2753 return Err(TermBuilderError::incompatible_config(format!(
2754 "hybrid Duchon-Matern smooth '{}' (length_scale=...) requires an integer power, got power={}; \
2755 drop length_scale to use the scale-free structural kernel with a fractional power.",
2756 vars.join(", "),
2757 req_power,
2758 ))
2759 .to_string());
2760 }
2761 (requested_nullspace_order, req_power)
2762 }
2763 DuchonPowerPolicy::CubicStructuralDefault => {
2764 match length_scale {
2771 None => crate::basis::duchon_cubic_default(cols.len()),
2772 Some(_) => {
2773 let max_op = crate::basis::duchon_max_active_operator_derivative_order(
2794 &DuchonOperatorPenaltySpec::default(),
2795 );
2796 let (ns, s) = crate::basis::resolve_duchon_orders(
2797 cols.len(),
2798 requested_nullspace_order,
2799 max_op,
2800 length_scale,
2801 );
2802 (ns, s as f64)
2803 }
2804 }
2805 }
2806 };
2807 let plan = plan_spatial_basis(
2808 ds.values.nrows(),
2809 cols.len(),
2810 CenterCountRequest::Default,
2811 nullspace_order,
2812 option_bool(options, "scale_dims").unwrap_or(false),
2813 policy,
2814 )
2815 .map_err(|e| e.to_string())?;
2816 let centers_explicit = has_explicit_countwith_basis_alias(options, "centers");
2817 let requested_centers = parse_countwith_basis_alias(
2818 options,
2819 "centers",
2820 cap_default_spatial_centers(options, plan.centers),
2821 )?;
2822 let polynomial_cols = match nullspace_order {
2823 DuchonNullspaceOrder::Zero => 1,
2824 DuchonNullspaceOrder::Linear => cols.len() + 1,
2825 DuchonNullspaceOrder::Degree(degree) => {
2826 crate::basis::duchon_nullspace_dimension(cols.len(), degree)
2827 }
2828 };
2829 if requested_centers <= polynomial_cols {
2830 return Err(TermBuilderError::incompatible_config(format!(
2831 "Duchon smooth '{}' requested basis dimension {} but order={:?} in {}D needs {} polynomial null-space columns; choose centers/k > {}",
2832 vars.join(", "),
2833 requested_centers,
2834 nullspace_order,
2835 cols.len(),
2836 polynomial_cols,
2837 polynomial_cols,
2838 ))
2839 .to_string());
2840 }
2841 let mut centers = requested_centers;
2842 if !centers_explicit && ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2843 centers = centers.max(polynomial_cols + 4);
2844 }
2845 let center_strategy = if centers_explicit {
2846 spatial_center_strategy_for_dimension(centers, cols.len())
2847 } else {
2848 auto_spatial_center_strategy(centers, cols.len())
2849 };
2850 let aniso_log_scales = if option_bool(options, "scale_dims").unwrap_or(false) {
2851 Some(vec![0.0; cols.len()])
2852 } else {
2853 None
2854 };
2855 let operator_penalties = DuchonOperatorPenaltySpec::default();
2858 Ok(SmoothBasisSpec::Duchon {
2859 feature_cols: cols.to_vec(),
2860 spec: DuchonBasisSpec {
2861 center_strategy,
2862 periodic: parse_periodic_axes_option(options, cols.len())?,
2863 length_scale,
2864 power,
2865 nullspace_order,
2866 identifiability: parse_spatial_identifiability(options)
2867 .map_err(|e| e.to_string())?,
2868 aniso_log_scales,
2869 operator_penalties,
2870 boundary: if cols.len() == 1 {
2871 let c = cols[0];
2872 let (minv, maxv) = col_minmax(ds.values.column(c))?;
2873 parse_cyclic_boundary(options, minv, maxv)?
2874 } else {
2875 OneDimensionalBoundary::Open
2876 },
2877 radial_reparam: None,
2878 },
2879 input_scales: None,
2880 })
2881 }
2882 "tensor" | "te" | "ti" | "t2" => {
2883 validate_known_options(
2884 "tensor",
2885 options,
2886 &[
2887 "type",
2888 "bs",
2889 "by",
2890 "k",
2891 "basis_dim",
2892 "basis-dim",
2893 "basisdim",
2894 "knot_placement",
2895 "knot-placement",
2896 "knotplacement",
2897 "degree",
2898 "penalty_order",
2899 "double_penalty",
2900 "periodic",
2901 "cyclic",
2902 "period",
2903 "periods",
2904 "period_start",
2905 "period_end",
2906 "origin",
2907 "origins",
2908 "period_origin",
2909 "period-origin",
2910 "domain_origin",
2911 "boundary",
2912 "bc",
2913 "identifiability",
2914 "id",
2915 "__by_col",
2916 ],
2917 )?;
2918 if cols.len() < 2 {
2919 return Err(TermBuilderError::incompatible_config(format!(
2920 "tensor smooth expects at least 2 variables, got {}",
2921 cols.len()
2922 ))
2923 .to_string());
2924 }
2925 let dim = cols.len();
2926
2927 if let Some(raw) = options.get("bs").or_else(|| options.get("type"))
2950 && bs_selector_is_vector(raw)
2951 {
2952 let per_margin = parse_option_list(raw);
2953 if per_margin.len() != dim {
2954 return Err(TermBuilderError::invalid_option(format!(
2955 "tensor smooth per-margin bs vector has {} entries but the smooth has {} margins",
2956 per_margin.len(),
2957 dim
2958 ))
2959 .to_string());
2960 }
2961 for (axis, margin_bs) in per_margin.iter().enumerate() {
2962 if !tensor_margin_bs_is_supported(margin_bs) {
2963 return Err(TermBuilderError::unsupported_feature(format!(
2964 "tensor smooth margin {axis} basis '{margin_bs}' is not a supported penalized-spline margin; \
2965 tensor margins accept tp/tps/ps/bs/cr/cc"
2966 ))
2967 .to_string());
2968 }
2969 }
2970 }
2971 let periodic_axes = parse_tensor_periodic_axes(options, dim)?;
2972 let periods_opt = parse_periods(options, &periodic_axes)?;
2973 let origins_opt = parse_period_origins(options, &periodic_axes)?;
2974 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
2975 let penalty_order =
2976 option_usize(options, "penalty_order").unwrap_or(if degree > 1 { 2 } else { 1 });
2977 let (mut k_list, k_inferred) = parse_tensor_k_list(options, cols, ds)?;
2978 if ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2979 for k in &mut k_list {
2980 *k = (*k).min(degree + 2);
2981 }
2982 }
2983 if k_inferred {
2984 inference_notes.push(format!(
2985 "Automatically set per-margin basis sizes {:?} for tensor smooth '{}' \
2986 (dimension-aware tensor budget: total ∏k kept near the mgcv-te default \
2987 and within the data support, distributed geometrically across margins and \
2988 capped per margin by each column's resolution). \
2989 Override with k=<int> or k=[k0,k1,...].",
2990 k_list,
2991 vars.join(",")
2992 ));
2993 }
2994 let per_axis_bs: Vec<Option<String>> =
3007 match options.get("bs").or_else(|| options.get("type")) {
3008 Some(raw) if bs_selector_is_vector(raw) => {
3009 let list = parse_option_list(raw);
3010 (0..dim).map(|a| list.get(a).cloned()).collect()
3011 }
3012 Some(raw) => {
3013 let scalar = raw
3014 .trim()
3015 .trim_matches('"')
3016 .trim_matches('\'')
3017 .to_ascii_lowercase();
3018 vec![Some(scalar); dim]
3019 }
3020 None => vec![None; dim],
3021 };
3022 let margin_wants_cr = |bs: &Option<String>| -> bool {
3028 matches!(
3029 bs.as_deref(),
3030 None | Some("cr") | Some("cs") | Some("tp") | Some("tps")
3031 )
3032 };
3033 let mut margins: Vec<BSplineBasisSpec> = Vec::with_capacity(dim);
3034 let mut emitted_periods: Vec<Option<f64>> = Vec::with_capacity(dim);
3035 for axis in 0..dim {
3036 let c = cols[axis];
3037 let (data_min, data_max) = col_minmax(ds.values.column(c))?;
3038 let k_requested = k_list[axis];
3054 let n_distinct_axis = unique_count_column(ds.values.column(c));
3055 let k_axis = k_requested.min(n_distinct_axis).max(2);
3056 if k_axis < k_requested {
3057 log::info!(
3058 "tensor smooth: margin axis {axis} requested k={k_requested}, but the \
3059 covariate has only {n_distinct_axis} distinct value(s); reducing this \
3060 margin to k={k_axis} (mgcv-style data-support cap on the per-axis basis)."
3061 );
3062 }
3063 if k_axis < 2 {
3076 return Err(TermBuilderError::invalid_option(format!(
3077 "tensor smooth: k[{axis}]={k_axis} too small; tensor margins require k >= 2"
3078 ))
3079 .to_string());
3080 }
3081 if periodic_axes[axis] && k_axis < degree + 1 {
3082 return Err(TermBuilderError::invalid_option(format!(
3083 "tensor smooth: periodic axis {axis} requires k >= {} for degree {degree}, got k={k_axis}",
3084 degree + 1
3085 ))
3086 .to_string());
3087 }
3088 let effective_degree = degree.min(k_axis - 1).max(1);
3089 let effective_penalty_order = penalty_order.min(effective_degree);
3090 let (knotspec, boundary, axis_period) = if periodic_axes[axis] {
3091 let period_value = periods_opt[axis].ok_or_else(|| {
3092 format!(
3093 "tensor smooth axis {axis} is periodic but no period was supplied; \
3094 pass period=<value> (scalar) or period=[..., <value>, ...]"
3095 )
3096 })?;
3097 if !period_value.is_finite() || period_value <= 0.0 {
3098 return Err(format!(
3099 "tensor smooth axis {axis}: period must be a positive finite value, got {period_value}"
3100 ));
3101 }
3102 let domain_start = origins_opt[axis].unwrap_or(data_min);
3103 let domain_end = domain_start + period_value;
3104 (
3105 BSplineKnotSpec::PeriodicUniform {
3106 data_range: (domain_start, domain_end),
3107 num_basis: k_axis,
3108 },
3109 OneDimensionalBoundary::Cyclic {
3110 start: domain_start,
3111 end: domain_end,
3112 },
3113 Some(period_value),
3114 )
3115 } else if margin_wants_cr(&per_axis_bs[axis]) && k_axis >= 3 {
3116 let cr_knots =
3126 crate::basis::select_cr_knots(ds.values.column(c), k_axis)
3127 .map_err(|e| e.to_string())?;
3128 (
3129 BSplineKnotSpec::NaturalCubicRegression { knots: cr_knots },
3130 OneDimensionalBoundary::Open,
3131 None,
3132 )
3133 } else {
3134 let num_internal_knots = if effective_degree < degree {
3141 k_axis.saturating_sub(effective_degree + 1)
3142 } else {
3143 k_axis.saturating_sub(degree + 1).max(1)
3144 };
3145 let knotspec = match parse_knot_placement(options)? {
3146 crate::basis::BSplineKnotPlacement::Uniform => BSplineKnotSpec::Generate {
3147 data_range: (data_min, data_max),
3148 num_internal_knots,
3149 },
3150 crate::basis::BSplineKnotPlacement::Quantile => {
3151 crate::basis::auto_knot_vector_1d_quantile(
3152 ds.values.column(c),
3153 num_internal_knots,
3154 effective_degree,
3155 )
3156 .map_err(|e| e.to_string())?;
3157 BSplineKnotSpec::Automatic {
3158 num_internal_knots: Some(num_internal_knots),
3159 placement: crate::basis::BSplineKnotPlacement::Quantile,
3160 }
3161 }
3162 };
3163 (knotspec, OneDimensionalBoundary::Open, None)
3164 };
3165 let is_cr_margin =
3171 matches!(knotspec, BSplineKnotSpec::NaturalCubicRegression { .. });
3172 let margin_double_penalty =
3173 is_cr_margin && matches!(per_axis_bs[axis].as_deref(), Some("cs"));
3174 margins.push(BSplineBasisSpec {
3175 degree: effective_degree,
3176 penalty_order: effective_penalty_order,
3177 knotspec,
3178 double_penalty: margin_double_penalty,
3179 identifiability: BSplineIdentifiability::None,
3180 boundary,
3181 boundary_conditions: BSplineBoundaryConditions::default(),
3182 });
3183 emitted_periods.push(axis_period);
3184 }
3185 let canon_cols: Vec<usize> = {
3206 let mut perm: Vec<usize> = (0..dim).collect();
3207 perm.sort_by_key(|&a| cols[a]);
3208 if perm.iter().enumerate().any(|(i, &a)| i != a) {
3209 margins = perm.iter().map(|&a| margins[a].clone()).collect();
3210 emitted_periods = perm.iter().map(|&a| emitted_periods[a]).collect();
3211 }
3212 perm.iter().map(|&a| cols[a]).collect()
3213 };
3214 let any_periodic = emitted_periods.iter().any(|p| p.is_some());
3215 let periods_vec = if any_periodic {
3216 emitted_periods
3217 } else {
3218 Vec::new()
3219 };
3220 let tensor_double_penalty = option_bool(options, "double_penalty").unwrap_or(false);
3236 Ok(SmoothBasisSpec::TensorBSpline {
3237 feature_cols: canon_cols,
3238 spec: TensorBSplineSpec {
3239 marginalspecs: margins,
3240 periods: periods_vec,
3241 double_penalty: tensor_double_penalty,
3242 identifiability: parse_tensor_identifiability(options, kind)?,
3243 penalty_decomposition: if matches!(kind, SmoothKind::T2)
3253 || type_opt.as_str() == "t2"
3254 {
3255 TensorBSplinePenaltyDecomposition::Separable
3256 } else {
3257 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
3258 },
3259 },
3260 })
3261 }
3262 "pca" => {
3263 validate_known_options(
3264 "pca",
3265 options,
3266 &[
3267 "type",
3268 "bs",
3269 "by",
3270 "k",
3271 "basis_dim",
3272 "basis-dim",
3273 "basisdim",
3274 "lazy_path",
3275 "path",
3276 "pca_basis_path",
3277 "chunk_size",
3278 "smooth_penalty",
3279 "centered",
3280 "double_penalty",
3281 "id",
3282 "__by_col",
3283 ],
3284 )?;
3285 let path = options
3286 .get("lazy_path")
3287 .or_else(|| options.get("pca_basis_path"))
3288 .or_else(|| options.get("path"))
3289 .map(|raw| PathBuf::from(strip_quotes(raw)));
3290 let Some(path) = path else {
3291 return Err(TermBuilderError::incompatible_config(
3292 "pca smooth requires lazy_path=... on the formula path",
3293 )
3294 .to_string());
3295 };
3296 let k = option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
3297 .unwrap_or(0);
3298 let chunk_size = option_usize(options, "chunk_size").unwrap_or(DEFAULT_PCA_CHUNK_SIZE);
3299 Ok(SmoothBasisSpec::Pca {
3300 feature_cols: cols.to_vec(),
3301 basis_matrix: Array2::<f64>::zeros((cols.len(), k)),
3302 centered: option_bool(options, "centered").unwrap_or(true),
3303 smooth_penalty: option_f64(options, "smooth_penalty").unwrap_or(1.0),
3304 center_mean: None,
3305 pca_basis_path: Some(path),
3306 chunk_size,
3307 })
3308 }
3309 other => Err(TermBuilderError::unsupported_feature(format!(
3310 "unsupported smooth type '{other}'"
3311 ))
3312 .to_string()),
3313 }
3314}
3315
3316pub fn enable_scale_dimensions(spec: &mut TermCollectionSpec) {
3318 for smooth in spec.smooth_terms.iter_mut() {
3319 match &mut smooth.basis {
3320 SmoothBasisSpec::Matern {
3321 feature_cols,
3322 spec: matern,
3323 ..
3324 } => {
3325 if matern.aniso_log_scales.is_none() {
3326 let d = feature_cols.len();
3327 matern.aniso_log_scales = Some(vec![0.0; d]);
3328 }
3329 }
3330 SmoothBasisSpec::Duchon {
3331 feature_cols,
3332 spec: duchon,
3333 ..
3334 } => {
3335 if duchon.aniso_log_scales.is_none() {
3336 let d = feature_cols.len();
3337 duchon.aniso_log_scales = Some(vec![0.0; d]);
3338 }
3339 }
3340 _ => {}
3341 }
3342 }
3343}
3344
3345pub fn spatial_center_strategy_for_dimension(num_centers: usize, d: usize) -> CenterStrategy {
3350 if d <= 3 {
3351 CenterStrategy::FarthestPoint { num_centers }
3358 } else {
3359 default_spatial_center_strategy(num_centers, d)
3360 }
3361}
3362
3363pub fn col_minmax(col: ArrayView1<'_, f64>) -> Result<(f64, f64), String> {
3364 let min = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
3365 let max = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
3366 if !min.is_finite() || !max.is_finite() {
3367 return Err(TermBuilderError::degenerate_data(
3368 "non-finite data encountered while inferring knot range",
3369 )
3370 .to_string());
3371 }
3372 if (max - min).abs() < 1e-12 {
3373 Ok((min, min + 1e-6))
3374 } else {
3375 Ok((min, max))
3376 }
3377}
3378
3379pub fn unique_count_column(col: ArrayView1<'_, f64>) -> usize {
3380 use std::collections::HashSet;
3381 let mut set = HashSet::<u64>::with_capacity(col.len());
3382 for &v in col {
3383 let norm = if v == 0.0 { 0.0 } else { v };
3384 set.insert(norm.to_bits());
3385 }
3386 set.len().max(1)
3387}
3388
/// Smallest knot count for which `capped_cr_marginal_knotspec` still builds a
/// cubic-regression knot spec; below this it returns `Ok(None)` and records a note.
pub(crate) const CR_MIN_KNOTS: usize = 3;
3395
3396fn capped_cr_marginal_knotspec(
3423 col: ArrayView1<'_, f64>,
3424 k_cr_requested: usize,
3425 label: &str,
3426 inference_notes: &mut Vec<String>,
3427) -> Result<Option<BSplineKnotSpec>, String> {
3428 let n_distinct = unique_count_column(col);
3429 let k_cr = k_cr_requested.min(n_distinct);
3430 if k_cr < CR_MIN_KNOTS {
3431 inference_notes.push(format!(
3432 "Smooth '{label}': cubic-regression ('cr'/'cs'/'sz') basis requested k={k_cr_requested}, \
3433 but the covariate has only {n_distinct} distinct value(s) — too few to support a cubic \
3434 regression spline (needs >= {CR_MIN_KNOTS} distinct values). Degraded to the linear \
3435 B-spline marginal the default basis builds on the same data."
3436 ));
3437 return Ok(None);
3438 }
3439 if k_cr < k_cr_requested {
3440 inference_notes.push(format!(
3441 "Smooth '{label}': cubic-regression ('cr'/'cs'/'sz') basis reduced from k={k_cr_requested} \
3442 to k={k_cr} to match the covariate's {n_distinct} distinct value(s) (mgcv-style \
3443 data-support cap; a cr basis cannot place more value-knots than the data has)."
3444 ));
3445 }
3446 let cr_knots = crate::basis::select_cr_knots(col, k_cr).map_err(|e| e.to_string())?;
3447 Ok(Some(BSplineKnotSpec::NaturalCubicRegression {
3448 knots: cr_knots,
3449 }))
3450}
3451
3452fn min_per_group_unique_count(
3459 feature_col: ArrayView1<'_, f64>,
3460 group_col: ArrayView1<'_, f64>,
3461) -> usize {
3462 use std::collections::{HashMap, HashSet};
3463 let mut per_group: HashMap<u64, HashSet<u64>> = HashMap::new();
3464 for (xi, gi) in feature_col.iter().zip(group_col.iter()) {
3465 let xnorm = if *xi == 0.0 { 0.0 } else { *xi };
3466 let gnorm = if *gi == 0.0 { 0.0 } else { *gi };
3467 per_group
3468 .entry(gnorm.to_bits())
3469 .or_default()
3470 .insert(xnorm.to_bits());
3471 }
3472 per_group
3473 .values()
3474 .map(|s| s.len())
3475 .min()
3476 .unwrap_or(1)
3477 .max(1)
3478}
3479
3480pub fn heuristic_knots_for_column(col: ArrayView1<'_, f64>) -> usize {
3485 let unique = unique_count_column(col);
3486 let ceiling = ((unique as f64).cbrt() as usize).max(20);
3487 (unique / 4).clamp(4, ceiling)
3488}
3489
3490fn heuristic_tensor_margin_knots(cols: &[usize], ds: &Dataset) -> Vec<usize> {
3510 let d = cols.len().max(1);
3511 let degree = DEFAULT_BSPLINE_DEGREE;
3512 let min_k = degree + 2; let n = ds.values.nrows();
3514
3515 let per_margin_cap: Vec<usize> = cols
3519 .iter()
3520 .map(|&c| heuristic_knots_for_column(ds.values.column(c)).max(min_k))
3521 .collect();
3522
3523 let mgcv_like_per_margin = match d {
3530 2 => 7usize,
3531 3 => 5usize,
3532 _ => 4usize,
3533 };
3534 let mgcv_like_total = (mgcv_like_per_margin as f64).powi(d as i32);
3535 let data_budget = (n as f64) * 0.8;
3536 let p_target = mgcv_like_total
3537 .max(min_k.pow(d as u32) as f64)
3538 .min(data_budget);
3539
3540 let geo_per_margin = p_target.powf(1.0 / d as f64).round() as usize;
3543 let unclamped: Vec<usize> = per_margin_cap
3544 .iter()
3545 .map(|&cap| geo_per_margin.clamp(min_k, cap))
3546 .collect();
3547
3548 let mut k_list = unclamped;
3553 loop {
3554 let product: f64 = k_list.iter().map(|&k| k as f64).product();
3555 if product >= p_target {
3556 break;
3557 }
3558 let Some(idx) = k_list
3561 .iter()
3562 .zip(per_margin_cap.iter())
3563 .enumerate()
3564 .filter(|&(_, (k, cap))| k < cap)
3565 .max_by_key(|&(_, (k, cap))| (cap - k, *cap))
3566 .map(|(i, _)| i)
3567 else {
3568 break;
3569 };
3570 k_list[idx] += 1;
3571 }
3572 k_list
3573}
3574
/// Public forwarder to `default_num_centers`: passes `(n, d)` through unchanged and
/// returns its result.
// NOTE(review): presumably `n` is the row count and `d` the number of spatial
// coordinates — confirm against `default_num_centers`.
pub fn heuristic_centers(n: usize, d: usize) -> usize {
    default_num_centers(n, d)
}
3578
3579fn parse_endpoint_side(
3584 value: &str,
3585 context: &str,
3586) -> Result<BSplineEndpointBoundaryCondition, String> {
3587 match value.trim().to_ascii_lowercase().as_str() {
3588 "" | "none" | "open" | "unconstrained" | "free" => {
3589 Ok(BSplineEndpointBoundaryCondition::Free)
3590 }
3591 "clamped" | "clamp" | "zero_derivative" | "zero-derivative" => {
3592 Ok(BSplineEndpointBoundaryCondition::Clamped)
3593 }
3594 "anchored" | "anchor" | "zero" | "zero_value" | "zero-value" => {
3595 Ok(BSplineEndpointBoundaryCondition::Anchored { value: 0.0 })
3596 }
3597 other => Err(format!(
3598 "unsupported {context} boundary condition '{other}'; expected free, clamped, or anchored"
3599 )),
3600 }
3601}
3602
3603fn boundary_anchor_value(
3604 options: &BTreeMap<String, String>,
3605 side: &str,
3606 fallback: Option<f64>,
3607) -> Option<f64> {
3608 [
3609 format!("anchor_{side}"),
3610 format!("{side}_anchor"),
3611 format!("anchor-value-{side}"),
3612 ]
3613 .iter()
3614 .find_map(|key| option_f64(options, key))
3615 .or(fallback)
3616}
3617
3618fn apply_anchor_value(
3619 cond: BSplineEndpointBoundaryCondition,
3620 value: Option<f64>,
3621) -> BSplineEndpointBoundaryCondition {
3622 match cond {
3623 BSplineEndpointBoundaryCondition::Anchored { .. } => {
3624 BSplineEndpointBoundaryCondition::Anchored {
3625 value: value.unwrap_or(0.0),
3626 }
3627 }
3628 other => other,
3629 }
3630}
3631
3632fn parse_bspline_boundary_conditions(
3633 options: &BTreeMap<String, String>,
3634) -> Result<BSplineBoundaryConditions, String> {
3635 let fallback_anchor = option_f64(options, "anchor")
3636 .or_else(|| option_f64(options, "anchor_value"))
3637 .or_else(|| option_f64(options, "value"));
3638 let global_boundary_conditions = options
3639 .get("boundary_conditions")
3640 .or_else(|| options.get("bc"));
3641 let mut boundary_conditions = BSplineBoundaryConditions::default();
3642
3643 if let Some(raw_boundary_conditions) = global_boundary_conditions {
3644 let cond = parse_endpoint_side(raw_boundary_conditions, "boundary_conditions")?;
3645 let side = options
3646 .get("side")
3647 .map(|s| s.trim().to_ascii_lowercase())
3648 .unwrap_or_else(|| "both".to_string());
3649 match side.as_str() {
3650 "both" | "all" | "endpoints" => {
3651 boundary_conditions.left = cond;
3652 boundary_conditions.right = cond;
3653 }
3654 "left" | "start" | "lower" => boundary_conditions.left = cond,
3655 "right" | "end" | "upper" => boundary_conditions.right = cond,
3656 other => {
3657 return Err(format!(
3658 "unsupported B-spline boundary side '{other}'; expected left, right, or both"
3659 ));
3660 }
3661 }
3662 }
3663
3664 if let Some(raw) = options
3665 .get("bc_left")
3666 .or_else(|| options.get("left_bc"))
3667 .or_else(|| options.get("bc_start"))
3668 .or_else(|| options.get("start_bc"))
3669 {
3670 boundary_conditions.left = parse_endpoint_side(raw, "left endpoint")?;
3671 }
3672 if let Some(raw) = options
3673 .get("bc_right")
3674 .or_else(|| options.get("right_bc"))
3675 .or_else(|| options.get("bc_end"))
3676 .or_else(|| options.get("end_bc"))
3677 {
3678 boundary_conditions.right = parse_endpoint_side(raw, "right endpoint")?;
3679 }
3680
3681 boundary_conditions.left = apply_anchor_value(
3682 boundary_conditions.left,
3683 boundary_anchor_value(options, "left", fallback_anchor),
3684 );
3685 boundary_conditions.right = apply_anchor_value(
3686 boundary_conditions.right,
3687 boundary_anchor_value(options, "right", fallback_anchor),
3688 );
3689
3690 reject_nonzero_anchor("left", boundary_conditions.left)?;
3698 reject_nonzero_anchor("right", boundary_conditions.right)?;
3699
3700 Ok(boundary_conditions)
3701}
3702
3703fn reject_nonzero_anchor(side: &str, cond: BSplineEndpointBoundaryCondition) -> Result<(), String> {
3704 if let BSplineEndpointBoundaryCondition::Anchored { value } = cond {
3705 if value.abs() > 1e-12 {
3706 return Err(format!(
3707 "non-zero {side} anchor {value} requires an affine offset term that is not yet supported; only anchored value 0 is accepted at parse time"
3708 ));
3709 }
3710 }
3711 Ok(())
3712}
3713
3714fn parse_ps_internal_knots(
3728 options: &BTreeMap<String, String>,
3729 degree: usize,
3730 default_internal_knots: usize,
3731) -> Result<(usize, bool, usize), String> {
3732 const MIN_EXPRESSIVE_INTERNAL_KNOTS: usize = 2;
3733 let knots_internal = if knots_option_is_list(options) {
3743 None
3744 } else {
3745 option_usize_strict(options, "knots")?
3746 };
3747 let basis_dim = option_usize_any_strict(options, &["k", "basis_dim", "basis-dim", "basisdim"])?;
3748 if knots_internal.is_some() && basis_dim.is_some() {
3749 return Err(TermBuilderError::incompatible_config(
3750 "ps/bspline smooth: specify either knots=<internal_knots> or k=<basis_dim> (not both)",
3751 )
3752 .to_string());
3753 }
3754 if let Some(k) = basis_dim {
3755 if k < 2 {
3756 return Err(TermBuilderError::invalid_option(format!(
3757 "ps/bspline smooth: k={} too small; B-spline basis requires k >= 2",
3758 k
3759 ))
3760 .to_string());
3761 }
3762 let effective_degree = degree.min(k - 1).max(1);
3768 let num_internal_knots = if effective_degree < degree {
3769 k.saturating_sub(effective_degree + 1)
3772 } else {
3773 (k - degree - 1).max(MIN_EXPRESSIVE_INTERNAL_KNOTS)
3774 };
3775 Ok((num_internal_knots, false, effective_degree))
3776 } else {
3777 Ok((
3778 knots_internal.unwrap_or(default_internal_knots),
3779 knots_internal.is_none(),
3780 degree,
3781 ))
3782 }
3783}
3784
/// Reports whether the `knots` option is written as a list rather than a scalar count.
///
/// After trimming, a value counts as a list when it starts with `[`, `c(`, `C(` or `(`.
/// A missing `knots` key yields `false`.
fn knots_option_is_list(options: &BTreeMap<String, String>) -> bool {
    const LIST_OPENERS: [&str; 4] = ["[", "c(", "C(", "("];
    match options.get("knots") {
        Some(raw) => {
            let trimmed = raw.trim();
            LIST_OPENERS
                .iter()
                .any(|opener| trimmed.starts_with(*opener))
        }
        None => false,
    }
}
3799
3800fn parse_explicit_internal_knots(
3805 options: &BTreeMap<String, String>,
3806) -> Result<Option<Vec<f64>>, String> {
3807 if !knots_option_is_list(options) {
3808 return Ok(None);
3809 }
3810 let raw = options
3811 .get("knots")
3812 .expect("knots_option_is_list implies the key is present");
3813 let tokens = split_list_option(raw);
3814 if tokens.is_empty() {
3815 return Err(TermBuilderError::invalid_option(format!(
3816 "knots={raw} is an empty list; supply at least one internal knot position \
3817 (e.g. knots=[0.2, 0.5, 0.8]) or a scalar count (e.g. knots=8)"
3818 ))
3819 .to_string());
3820 }
3821 let mut positions = Vec::with_capacity(tokens.len());
3822 for tok in &tokens {
3823 let value = parse_numeric_expr(tok).map_err(|err| {
3824 TermBuilderError::invalid_option(format!(
3825 "knots list entry '{tok}' is not a numeric position: {err}"
3826 ))
3827 .to_string()
3828 })?;
3829 positions.push(value);
3830 }
3831 Ok(Some(positions))
3832}
3833
3834fn parse_knot_placement(
3840 options: &BTreeMap<String, String>,
3841) -> Result<crate::basis::BSplineKnotPlacement, String> {
3842 use crate::basis::BSplineKnotPlacement;
3843 match options
3844 .get("knot_placement")
3845 .or_else(|| options.get("knot-placement"))
3846 .or_else(|| options.get("knotplacement"))
3847 {
3848 None => Ok(BSplineKnotPlacement::Uniform),
3849 Some(raw) => match raw
3850 .trim()
3851 .trim_matches('"')
3852 .trim_matches('\'')
3853 .to_ascii_lowercase()
3854 .as_str()
3855 {
3856 "uniform" | "even" | "equal" => Ok(BSplineKnotPlacement::Uniform),
3857 "quantile" | "quantiles" | "data" | "empirical" => Ok(BSplineKnotPlacement::Quantile),
3858 other => Err(TermBuilderError::invalid_option(format!(
3859 "knot_placement={other} is not recognised; expected \"uniform\" or \"quantile\""
3860 ))
3861 .to_string()),
3862 },
3863 }
3864}
3865
3866fn resolve_nonperiodic_bspline_knotspec(
3877 options: &BTreeMap<String, String>,
3878 data: ArrayView1<'_, f64>,
3879 data_range: (f64, f64),
3880 degree: usize,
3881 n_knots: usize,
3882) -> Result<BSplineKnotSpec, String> {
3883 use crate::basis::{BSplineKnotPlacement, clamped_knot_vector_from_internal_positions};
3884 if let Some(positions) = parse_explicit_internal_knots(options)? {
3885 if option_usize_any_strict(options, &["k", "basis_dim", "basis-dim", "basisdim"])?.is_some()
3886 {
3887 return Err(TermBuilderError::incompatible_config(
3888 "ps/bspline smooth: specify either explicit knots=[...] positions or \
3889 k=<basis_dim> (not both); the basis size is fixed by the knot vector",
3890 )
3891 .to_string());
3892 }
3893 let knots = clamped_knot_vector_from_internal_positions(data_range, &positions, degree)
3894 .map_err(|e| e.to_string())?;
3895 return Ok(BSplineKnotSpec::Provided(knots));
3896 }
3897 match parse_knot_placement(options)? {
3898 BSplineKnotPlacement::Uniform => Ok(BSplineKnotSpec::Generate {
3899 data_range,
3900 num_internal_knots: n_knots,
3901 }),
3902 BSplineKnotPlacement::Quantile => {
3903 crate::basis::auto_knot_vector_1d_quantile(data, n_knots, degree)
3907 .map_err(|e| e.to_string())?;
3908 Ok(BSplineKnotSpec::Automatic {
3909 num_internal_knots: Some(n_knots),
3910 placement: BSplineKnotPlacement::Quantile,
3911 })
3912 }
3913 }
3914}
3915
3916pub fn validate_known_options(
3922 term_name: &str,
3923 options: &BTreeMap<String, String>,
3924 known: &[&str],
3925) -> Result<(), String> {
3926 let known_set: std::collections::BTreeSet<&&str> = known.iter().collect();
3927 for key in options.keys() {
3928 if !known_set.contains(&key.as_str()) {
3929 if term_name == "tensor" && is_tensor_k_axis_option_key(key) {
3930 continue;
3931 }
3932 let key_l = key.to_ascii_lowercase();
3934 let mut suggestions: Vec<&str> = known
3935 .iter()
3936 .filter(|k| {
3937 let kl = k.to_ascii_lowercase();
3938 kl.contains(&key_l) || key_l.contains(&kl) || {
3939 let n = kl
3940 .chars()
3941 .zip(key_l.chars())
3942 .take_while(|(a, b)| a == b)
3943 .count();
3944 n >= 3
3945 }
3946 })
3947 .copied()
3948 .collect();
3949 suggestions.sort_unstable();
3950 suggestions.dedup();
3951 let hint = if suggestions.is_empty() {
3952 String::new()
3953 } else {
3954 format!(" — did you mean one of [{}]?", suggestions.join(", "))
3955 };
3956 return Err(TermBuilderError::invalid_option(format!(
3957 "{term_name}() does not accept option `{key}`{hint}. Valid options: [{}]",
3958 {
3959 let mut sorted = known.to_vec();
3960 sorted.sort_unstable();
3961 sorted.join(", ")
3962 }
3963 ))
3964 .to_string());
3965 }
3966 }
3967 Ok(())
3968}
3969
/// Option key read by `cap_default_spatial_centers` as an upper bound on the default
/// spatial center count.
// NOTE(review): the double-underscore prefix suggests an internally injected key rather
// than user-facing formula syntax — confirm against the code that sets it.
pub const SECONDARY_CENTER_CAP_OPTION: &str = "__secondary_center_cap";
3980
3981pub(crate) fn cap_default_spatial_centers(
3986 options: &BTreeMap<String, String>,
3987 default_count: usize,
3988) -> usize {
3989 match option_usize(options, SECONDARY_CENTER_CAP_OPTION) {
3990 Some(cap) => default_count.min(cap),
3991 None => default_count,
3992 }
3993}
3994
/// Raises a planned Matern center count to a small-sample floor.
///
/// The floor is `min(d + 4, n)`; the result is the larger of `planned_count` and that
/// floor, and never less than 1.
fn default_matern_center_count(n: usize, d: usize, planned_count: usize) -> usize {
    let small_sample_floor = std::cmp::min(d + 4, n);
    let raised = std::cmp::max(planned_count, small_sample_floor);
    std::cmp::max(raised, 1)
}
4005
4006pub fn parse_countwith_basis_alias(
4007 options: &BTreeMap<String, String>,
4008 primarykey: &str,
4009 default_count: usize,
4010) -> Result<usize, String> {
4011 let primary = option_usize_strict(options, primarykey)?;
4016 let basis_dim = option_usize_any_strict(
4017 options,
4018 &["k", "basis_dim", "basis-dim", "basisdim", "knots"],
4019 )?;
4020 if primary.is_some() && basis_dim.is_some() {
4021 return Err(TermBuilderError::incompatible_config(format!(
4022 "specify either {}=<count> or k=<basis_dim> (not both)",
4023 primarykey
4024 ))
4025 .to_string());
4026 }
4027 Ok(primary.or(basis_dim).unwrap_or(default_count))
4028}
4029
/// Reports whether the user supplied a count at all, under `primarykey` or any of the
/// basis-dimension aliases (`k`, `basis_dim`, `basis-dim`, `basisdim`, `knots`).
///
/// Only key presence is checked; values are not parsed.
pub fn has_explicit_countwith_basis_alias(
    options: &BTreeMap<String, String>,
    primarykey: &str,
) -> bool {
    const BASIS_DIM_ALIASES: [&str; 5] = ["k", "basis_dim", "basis-dim", "basisdim", "knots"];
    if options.contains_key(primarykey) {
        return true;
    }
    BASIS_DIM_ALIASES
        .iter()
        .any(|alias| options.contains_key(*alias))
}
4039
4040pub fn parse_cyclic_boundary(
4041 options: &BTreeMap<String, String>,
4042 minv: f64,
4043 maxv: f64,
4044) -> Result<OneDimensionalBoundary, String> {
4045 let cyclic = option_bool(options, "cyclic")
4046 .or_else(|| option_bool(options, "periodic"))
4047 .unwrap_or(false);
4048 if !cyclic {
4049 return Ok(OneDimensionalBoundary::Open);
4050 }
4051 let start = match option_numeric_expr(options, "period_start")? {
4052 Some(v) => v,
4053 None => option_numeric_expr(options, "start")?.unwrap_or(minv),
4054 };
4055 let end = match option_numeric_expr(options, "period_end")? {
4056 Some(v) => v,
4057 None => option_numeric_expr(options, "end")?.unwrap_or(maxv),
4058 };
4059 if end <= start {
4060 return Err(format!(
4061 "cyclic smooth requires period_end/end ({end}) > period_start/start ({start})"
4062 ));
4063 }
4064 Ok(OneDimensionalBoundary::Cyclic { start, end })
4065}
4066
4067pub fn parse_periodic_domain_1d(
4074 options: &BTreeMap<String, String>,
4075 minv: f64,
4076 maxv: f64,
4077) -> Result<(f64, f64), String> {
4078 let start = match option_numeric_expr(options, "period_start")? {
4079 Some(v) => v,
4080 None => option_numeric_expr(options, "start")?.unwrap_or(minv),
4081 };
4082 let end = match option_numeric_expr(options, "period_end")? {
4083 Some(v) => v,
4084 None => option_numeric_expr(options, "end")?.unwrap_or(maxv),
4085 };
4086 if !(start.is_finite() && end.is_finite()) {
4087 return Err(format!(
4088 "periodic smooth domain requires finite endpoints, got ({start}, {end})"
4089 ));
4090 }
4091 if end <= start {
4092 return Err(format!(
4093 "periodic smooth requires period_end/end ({end}) > period_start/start ({start})"
4094 ));
4095 }
4096 Ok((start, end - start))
4097}
4098
4099fn parse_matern_nu(raw: &str) -> Result<MaternNu, String> {
4100 let trimmed = raw.trim();
4101 let lowered = trimmed.to_ascii_lowercase();
4102 match lowered.as_str() {
4103 "1/2" | "0.5" | "half" => return Ok(MaternNu::Half),
4104 "3/2" | "1.5" => return Ok(MaternNu::ThreeHalves),
4105 "5/2" | "2.5" => return Ok(MaternNu::FiveHalves),
4106 "7/2" | "3.5" => return Ok(MaternNu::SevenHalves),
4107 "9/2" | "4.5" => return Ok(MaternNu::NineHalves),
4108 _ => {}
4109 }
4110
4111 let value = if let Some((num, den)) = trimmed.split_once('/') {
4112 let num = num
4113 .trim()
4114 .parse::<f64>()
4115 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?;
4116 let den = den
4117 .trim()
4118 .parse::<f64>()
4119 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?;
4120 if den == 0.0 || !num.is_finite() || !den.is_finite() {
4121 return Err(unsupported_matern_nu_message(raw));
4122 }
4123 num / den
4124 } else {
4125 trimmed
4126 .parse::<f64>()
4127 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?
4128 };
4129
4130 const TOL: f64 = 1e-12;
4131 if (value - 0.5).abs() <= TOL {
4132 Ok(MaternNu::Half)
4133 } else if (value - 1.5).abs() <= TOL {
4134 Ok(MaternNu::ThreeHalves)
4135 } else if (value - 2.5).abs() <= TOL {
4136 Ok(MaternNu::FiveHalves)
4137 } else if (value - 3.5).abs() <= TOL {
4138 Ok(MaternNu::SevenHalves)
4139 } else if (value - 4.5).abs() <= TOL {
4140 Ok(MaternNu::NineHalves)
4141 } else {
4142 Err(unsupported_matern_nu_message(raw))
4143 }
4144}
4145
4146fn unsupported_matern_nu_message(raw: &str) -> String {
4147 TermBuilderError::unsupported_feature(format!(
4148 "unsupported Matern nu '{raw}'; supported half-integer values are 1/2, 3/2, 5/2, 7/2, and 9/2"
4149 ))
4150 .to_string()
4151}
4152
/// How the power of a Duchon smooth was specified by the user.
///
/// Produced by [`parse_duchon_power_policy`]; [`parse_duchon_power`] collapses
/// it to a concrete `f64`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum DuchonPowerPolicy {
    /// The `power` option was present; [`parse_duchon_power_policy`] only
    /// yields finite, non-negative values here.
    Explicit(f64),
    /// No `power` option was given; [`parse_duchon_power`] resolves this
    /// variant to `1.5`.
    CubicStructuralDefault,
}
4161
4162pub fn parse_duchon_power_policy(
4163 options: &BTreeMap<String, String>,
4164) -> Result<DuchonPowerPolicy, String> {
4165 if let Some(raw_nu) = options.get("nu") {
4166 return Err(TermBuilderError::incompatible_config(format!(
4167 "Duchon smooths use power=<number>, not nu='{}'. Use power=1.5, power=2, etc.",
4168 raw_nu
4169 ))
4170 .to_string());
4171 }
4172 match options.get("power") {
4173 Some(raw) => {
4174 let value = raw.parse::<f64>().map_err(|err| {
4175 TermBuilderError::invalid_option(format!(
4176 "invalid Duchon power '{}'; expected a non-negative number such as power=1.5 or power=2: {}",
4177 raw, err
4178 ))
4179 .to_string()
4180 })?;
4181 if !value.is_finite() || value < 0.0 {
4182 return Err(TermBuilderError::invalid_option(format!(
4183 "invalid Duchon power '{}'; expected a finite non-negative number such as power=1.5 or power=2",
4184 raw
4185 ))
4186 .to_string());
4187 }
4188 Ok(DuchonPowerPolicy::Explicit(value))
4189 }
4190 None => Ok(DuchonPowerPolicy::CubicStructuralDefault),
4191 }
4192}
4193
4194pub fn parse_duchon_power(options: &BTreeMap<String, String>) -> Result<f64, String> {
4195 match parse_duchon_power_policy(options)? {
4196 DuchonPowerPolicy::Explicit(power) => Ok(power),
4197 DuchonPowerPolicy::CubicStructuralDefault => Ok(1.5),
4203 }
4204}
4205
4206pub fn parse_duchon_order(
4207 options: &BTreeMap<String, String>,
4208) -> Result<DuchonNullspaceOrder, String> {
4209 match options.get("order") {
4210 None => Ok(DuchonNullspaceOrder::Linear),
4214 Some(raw) => match raw.parse::<usize>() {
4215 Ok(0) => Ok(DuchonNullspaceOrder::Zero),
4216 Ok(1) => Ok(DuchonNullspaceOrder::Linear),
4217 Ok(other) => Ok(DuchonNullspaceOrder::Degree(other)),
4218 Err(_) => Err(TermBuilderError::invalid_option(format!(
4219 "invalid Duchon order '{}'; expected a non-negative integer such as order=0, order=1, or order=2",
4220 raw
4221 ))
4222 .to_string()),
4223 },
4224 }
4225}
4226
4227fn parse_matern_identifiability(
4228 options: &BTreeMap<String, String>,
4229) -> Result<MaternIdentifiability, TermBuilderError> {
4230 let Some(raw) = options.get("identifiability").map(String::as_str) else {
4231 return Ok(MaternIdentifiability::default());
4232 };
4233 match raw.trim().to_ascii_lowercase().as_str() {
4234 "none" => Ok(MaternIdentifiability::None),
4235 "sum_tozero" | "sum-to-zero" | "center_sum_tozero" | "center-sum-to-zero" | "centered" => {
4236 Ok(MaternIdentifiability::CenterSumToZero)
4237 }
4238 "linear" | "center_linear_orthogonal" | "center-linear-orthogonal" => {
4239 Ok(MaternIdentifiability::CenterLinearOrthogonal)
4240 }
4241 other => Err(TermBuilderError::unsupported_feature(format!(
4242 "invalid Matérn identifiability '{other}'; expected one of: none, sum_tozero, linear"
4243 ))),
4244 }
4245}
4246
4247fn parse_spatial_identifiability(
4248 options: &BTreeMap<String, String>,
4249) -> Result<SpatialIdentifiability, TermBuilderError> {
4250 let Some(raw) = options.get("identifiability").map(String::as_str) else {
4251 return Ok(SpatialIdentifiability::default());
4252 };
4253 match raw.trim().to_ascii_lowercase().as_str() {
4254 "none" => Ok(SpatialIdentifiability::None),
4255 "orthogonal"
4256 | "orthogonal_to_parametric"
4257 | "orthogonal-to-parametric"
4258 | "parametric_orthogonal" => Ok(SpatialIdentifiability::OrthogonalToParametric),
4259 "frozen" => Err(TermBuilderError::unsupported_feature(
4260 "spatial identifiability 'frozen' is internal-only; use none or orthogonal_to_parametric",
4261 )),
4262 other => Err(TermBuilderError::unsupported_feature(format!(
4263 "invalid spatial identifiability '{other}'; expected one of: none, orthogonal_to_parametric"
4264 ))),
4265 }
4266}
4267
4268#[cfg(test)]
4269mod tests {
4270 use super::*;
4271 use crate::inference::formula_dsl::parse_formula;
4272 use gam_data::{DataSchema, SchemaColumn};
4273 use ndarray::Array2;
4274 use std::collections::BTreeMap;
4275
4276 fn continuous_dataset(headers: &[&str], rows: Vec<Vec<f64>>) -> Dataset {
4277 let nrows = rows.len();
4278 let ncols = headers.len();
4279 let values = Array2::from_shape_vec(
4280 (nrows, ncols),
4281 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
4282 )
4283 .expect("rectangular test data");
4284 Dataset {
4285 headers: headers.iter().map(|name| name.to_string()).collect(),
4286 values,
4287 schema: DataSchema {
4288 columns: headers
4289 .iter()
4290 .map(|name| SchemaColumn {
4291 name: name.to_string(),
4292 kind: ColumnKindTag::Continuous,
4293 levels: vec![],
4294 })
4295 .collect(),
4296 },
4297 column_kinds: vec![ColumnKindTag::Continuous; ncols],
4298 }
4299 }
4300
4301 fn factor_dataset() -> Dataset {
4302 let rows = (0..24)
4303 .map(|i| {
4304 let x = i as f64 / 23.0;
4305 let g = (i % 2) as f64;
4306 vec![x + g, x, g]
4307 })
4308 .collect::<Vec<_>>();
4309 Dataset {
4310 headers: vec!["y".into(), "x".into(), "g".into()],
4311 values: Array2::from_shape_vec(
4312 (rows.len(), 3),
4313 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
4314 )
4315 .expect("rectangular factor test data"),
4316 schema: DataSchema {
4317 columns: vec![
4318 SchemaColumn {
4319 name: "y".into(),
4320 kind: ColumnKindTag::Continuous,
4321 levels: vec![],
4322 },
4323 SchemaColumn {
4324 name: "x".into(),
4325 kind: ColumnKindTag::Continuous,
4326 levels: vec![],
4327 },
4328 SchemaColumn {
4329 name: "g".into(),
4330 kind: ColumnKindTag::Categorical,
4331 levels: vec!["a".into(), "b".into()],
4332 },
4333 ],
4334 },
4335 column_kinds: vec![
4336 ColumnKindTag::Continuous,
4337 ColumnKindTag::Continuous,
4338 ColumnKindTag::Categorical,
4339 ],
4340 }
4341 }
4342
/// Builds `s(x, bs=tp)` with no explicit basis size on 300 evenly spaced
/// points in `[-3, 3]` and checks that the resulting thin-plate spec carries
/// a center strategy with at least one center.
#[test]
fn default_univariate_thinplate_basis_dim_is_modest() {
    const N: usize = 300;

    let span = (N - 1) as f64;
    let mut rows: Vec<Vec<f64>> = Vec::with_capacity(N);
    for i in 0..N {
        let x = -3.0 + 6.0 * (i as f64) / span;
        rows.push(vec![x.sin(), x]);
    }
    let ds = continuous_dataset(&["y", "x"], rows);

    let options = BTreeMap::from([("bs".to_string(), "tp".to_string())]);
    let policy = ResourcePolicy::default_library();
    let mut notes = Vec::new();
    let basis = build_smooth_basis(
        SmoothKind::S,
        &["x".to_string()],
        &[1],
        &options,
        &ds,
        &mut notes,
        &policy,
        1,
    )
    .expect("build default univariate tp smooth");

    let SmoothBasisSpec::ThinPlate { spec, .. } = &basis else {
        panic!("expected ThinPlate basis, got {:?}", &basis);
    };

    // Center count of a strategy that stores one directly; `None` otherwise.
    let direct_count = |strategy: &CenterStrategy| match strategy {
        CenterStrategy::FarthestPoint { num_centers }
        | CenterStrategy::EqualMass { num_centers }
        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
        | CenterStrategy::KMeans { num_centers, .. } => Some(*num_centers),
        _ => None,
    };

    // `Auto` wraps exactly one level of concrete strategy.
    let centers = match &spec.center_strategy {
        CenterStrategy::Auto(inner) => {
            let inner = inner.as_ref();
            direct_count(inner)
                .unwrap_or_else(|| panic!("unexpected auto inner center strategy: {inner:?}"))
        }
        other => direct_count(other)
            .unwrap_or_else(|| panic!("unexpected center strategy: {other:?}")),
    };

    assert!(
        centers >= 1,
        "default univariate tp must still build a usable basis (centers={centers})",
    );
}
4405
4406 fn inferred_tensor_basis_product(ds: &Dataset) -> usize {
4407 let parsed = parse_formula("y ~ te(theta, h)").expect("parse tensor formula");
4408 let col_map = ds.column_map();
4409 let mut notes = Vec::new();
4410 let terms = build_termspec(
4411 &parsed.terms,
4412 ds,
4413 &col_map,
4414 &mut notes,
4415 &ResourcePolicy::default_library(),
4416 )
4417 .expect("build tensor termspec");
4418 let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
4419 panic!("expected tensor smooth");
4420 };
4421 spec.marginalspecs
4422 .iter()
4423 .map(|marginal| match marginal.knotspec {
4424 BSplineKnotSpec::Generate {
4425 num_internal_knots, ..
4426 } => num_internal_knots + marginal.degree + 1,
4427 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
4428 BSplineKnotSpec::Automatic {
4429 num_internal_knots: Some(num_internal_knots),
4430 ..
4431 } => num_internal_knots + marginal.degree + 1,
4432 BSplineKnotSpec::Automatic {
4433 num_internal_knots: None,
4434 ..
4435 } => panic!("test helper cannot infer automatic knot count"),
4436 BSplineKnotSpec::Provided(ref knots) => {
4437 knots.len().saturating_sub(marginal.degree + 1)
4438 }
4439 BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
4441 })
4442 .product()
4443 }
4444
4445 fn tensor_margin_basis_sizes(ds: &Dataset, formula: &str) -> Vec<usize> {
4446 let parsed = parse_formula(formula).expect("parse tensor formula");
4447 let col_map = ds.column_map();
4448 let mut notes = Vec::new();
4449 let terms = build_termspec(
4450 &parsed.terms,
4451 ds,
4452 &col_map,
4453 &mut notes,
4454 &ResourcePolicy::default_library(),
4455 )
4456 .expect("build tensor termspec");
4457 let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
4458 panic!("expected tensor smooth");
4459 };
4460 spec.marginalspecs
4461 .iter()
4462 .map(|marginal| match marginal.knotspec {
4463 BSplineKnotSpec::Generate {
4464 num_internal_knots, ..
4465 } => num_internal_knots + marginal.degree + 1,
4466 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
4467 BSplineKnotSpec::Automatic {
4468 num_internal_knots: Some(num_internal_knots),
4469 ..
4470 } => num_internal_knots + marginal.degree + 1,
4471 BSplineKnotSpec::Automatic {
4472 num_internal_knots: None,
4473 ..
4474 } => panic!("test helper cannot infer automatic knot count"),
4475 BSplineKnotSpec::Provided(ref knots) => {
4476 knots.len().saturating_sub(marginal.degree + 1)
4477 }
4478 BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
4480 })
4481 .collect()
4482 }
4483
/// A misspelled option (`lengt_scale`) must be rejected with an error that
/// names the bad option, suggests `length_scale`, and lists the valid names.
#[test]
fn validate_known_options_lists_valid_option_names_for_unknown_parameter() {
    let options = BTreeMap::from([("lengt_scale".to_string(), "0.25".to_string())]);
    let allowed = ["type", "bs", "length_scale", "centers", "k", "nu"];

    let err = validate_known_options("matern", &options, &allowed)
        .expect_err("unknown smooth option should be rejected");

    let expectations = [
        (
            "matern() does not accept option `lengt_scale`",
            "error should name the invalid option",
        ),
        (
            "did you mean one of [length_scale]",
            "error should suggest the closest valid option",
        ),
        ("Valid options: [", "error should list valid option names"),
    ];
    for (needle, why) in expectations {
        assert!(err.contains(needle), "{why}, got: {err}");
    }
}
4507
/// `te(x, z, k=[5, 6])` must produce margins with 5 and 6 basis functions
/// respectively.
#[test]
fn tensor_k_accepts_square_bracket_per_margin_list() {
    let mut rows = Vec::with_capacity(40);
    for i in 0..40 {
        let x = i as f64 / 39.0;
        // Stride-7 permutation so z is not monotone in x.
        let z = ((i * 7) % 40) as f64 / 39.0;
        rows.push(vec![x.sin() + z.cos(), x, z]);
    }
    let ds = continuous_dataset(&["y", "x", "z"], rows);

    let sizes = tensor_margin_basis_sizes(&ds, "y ~ te(x, z, k=[5, 6])");
    assert_eq!(
        sizes,
        vec![5, 6],
        "square-bracket k lists should materialize the requested per-margin values"
    );
}
4527
/// Checks three equivalent ways of declaring periodic axes on a 2-D smooth:
/// an index list (`periodic=[0]`), a per-axis `boundary` list, and a
/// `period` list written with Unicode constants (`2π`, `τ`).
#[test]
fn parse_cylinder_periodic_options_match_requested_forms() {
    use std::f64::consts::{PI, TAU};

    let options_from = |pairs: &[(&str, &str)]| -> BTreeMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    };
    let near = |got: f64, want: f64| (got - want).abs() < 1e-12;

    // Index-list form: only axis 0 is periodic.
    let opts = options_from(&[("periodic", "[0]"), ("period", "[2*pi, None]")]);
    let axes = parse_periodic_axes(&opts, 2).expect("axes");
    let periods = parse_periods(&opts, &axes).expect("periods");
    assert_eq!(axes, vec![true, false]);
    assert!(near(periods[0].unwrap(), 2.0 * PI));
    assert_eq!(periods[1], None);

    // Per-axis boundary names.
    let boundary_opts = options_from(&[
        ("boundary", "['periodic', 'natural']"),
        ("period", "[2*pi, None]"),
    ]);
    let boundary_axes = parse_periodic_axes(&boundary_opts, 2).expect("boundary axes");
    let boundary_periods =
        parse_periods(&boundary_opts, &boundary_axes).expect("boundary periods");
    assert_eq!(boundary_axes, vec![true, false]);
    assert!(near(boundary_periods[0].unwrap(), 2.0 * PI));
    assert_eq!(boundary_periods[1], None);

    // Both axes periodic, periods spelled with Unicode symbols.
    let unicode_opts = options_from(&[("periodic", "[0,1]"), ("period", "[2π, τ]")]);
    let unicode_axes = parse_periodic_axes(&unicode_opts, 2).expect("unicode axes");
    let unicode_periods = parse_periods(&unicode_opts, &unicode_axes).expect("unicode periods");
    assert_eq!(unicode_axes, vec![true, true]);
    assert!(near(unicode_periods[0].unwrap(), 2.0 * PI));
    assert!(near(unicode_periods[1].unwrap(), TAU));
}
4561
/// On a 1-D smooth, `periodic=[0]` must be read as "axis 0 is periodic"
/// rather than as a false flag; period and origin are parsed alongside it.
#[test]
fn parse_single_axis_periodic_zero_as_axis_not_false() {
    let opts = BTreeMap::from([
        ("periodic".to_string(), "[0]".to_string()),
        ("period".to_string(), "2*pi".to_string()),
        ("origin".to_string(), "0".to_string()),
    ]);

    let axes = parse_periodic_axes(&opts, 1).expect("axes");
    assert_eq!(axes, vec![true]);

    let periods = parse_periods(&opts, &axes).expect("periods");
    let origins = parse_period_origins(&opts, &axes).expect("origins");
    let period_error = (periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs();
    assert!(period_error < 1e-12);
    assert_eq!(origins[0], Some(0.0));
}
4575
/// `s(theta, boundary=periodic, period=2*pi, origin=0, k=8)` must lower to a
/// 1-D B-spline with a periodic-uniform knot spec over `(0, TAU)` and 8 basis
/// functions.
#[test]
fn one_dimensional_bspline_accepts_boundary_periodic() {
    let mut rows = Vec::with_capacity(16);
    for i in 0..16 {
        let theta = std::f64::consts::TAU * i as f64 / 16.0;
        rows.push(vec![theta.sin(), theta]);
    }
    let ds = continuous_dataset(&["y", "theta"], rows);

    let parsed = parse_formula("y ~ s(theta, boundary=periodic, period=2*pi, origin=0, k=8)")
        .expect("parse");
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();
    let mut notes = Vec::new();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("periodic boundary should build");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => spec,
        _ => panic!("expected 1D B-spline"),
    };
    assert!(matches!(
        &spec.knotspec,
        BSplineKnotSpec::PeriodicUniform {
            data_range,
            num_basis: 8
        } if *data_range == (0.0, std::f64::consts::TAU)
    ));
}
4610
/// `bs='cr'` and `bs='cs'` must both build a 1-D B-spline smooth, with
/// `double_penalty` off for `cr` and on for `cs`.
#[test]
fn univariate_smooth_accepts_mgcv_cubic_regression_aliases() {
    let mut rows = Vec::with_capacity(32);
    for i in 0..32 {
        let x = i as f64 / 31.0;
        rows.push(vec![x * x, x]);
    }
    let ds = continuous_dataset(&["y", "x"], rows);
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();

    let cases = [("cr", false), ("cs", true)];
    for (selector, expect_double_penalty) in cases {
        let formula = format!("y ~ s(x, bs='{selector}')");
        let parsed = parse_formula(&formula).expect("parse cr/cs smooth");
        let mut notes = Vec::new();
        let terms = match build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy) {
            Ok(terms) => terms,
            Err(err) => panic!("bs='{selector}' must build a 1-D smooth, got: {err:?}"),
        };

        let spec = match &terms.smooth_terms[0].basis {
            SmoothBasisSpec::BSpline1D { spec, .. } => spec,
            other => panic!("bs='{selector}' must lower to a BSpline1D; got {:?}", other),
        };
        assert_eq!(
            spec.double_penalty, expect_double_penalty,
            "bs='{selector}' must default double_penalty to mgcv's convention \
             (cr=no-shrinkage, cs=shrinkage); got double_penalty={}",
            spec.double_penalty
        );
    }
}
4650
/// With `k=3`, both `s(x, bs='ps', k=3)` and `s(x, k=3)` must build a
/// quadratic (degree 2) B-spline with zero internal knots and a penalty order
/// between 1 and the degree, rather than failing.
#[test]
fn univariate_ps_small_k_degree_reduces_through_build() {
    let mut rows = Vec::with_capacity(32);
    for i in 0..32 {
        let x = i as f64 / 31.0;
        rows.push(vec![x * x, x]);
    }
    let ds = continuous_dataset(&["y", "x"], rows);
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();

    let formulas = ["y ~ s(x, bs='ps', k=3)", "y ~ s(x, k=3)"];
    for formula in formulas {
        let parsed = parse_formula(formula).expect("parse small-k ps/cr smooth");
        let mut notes = Vec::new();
        let terms = match build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy) {
            Ok(terms) => terms,
            Err(err) => panic!("`{formula}` must degree-reduce, not error; got: {err:?}"),
        };

        let spec = match &terms.smooth_terms[0].basis {
            SmoothBasisSpec::BSpline1D { spec, .. } => spec,
            other => panic!("`{formula}` must lower to a BSpline1D; got {:?}", other),
        };
        assert_eq!(
            spec.degree, 2,
            "`{formula}` must drop the cubic default to a quadratic basis"
        );

        let num_internal = match &spec.knotspec {
            BSplineKnotSpec::Automatic {
                num_internal_knots: Some(count),
                ..
            } => *count,
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => *num_internal_knots,
            other => panic!("`{formula}` unexpected knotspec: {other:?}"),
        };
        assert_eq!(
            num_internal, 0,
            "`{formula}` must have zero internal knots (num_basis = k = 3)"
        );

        let order_in_range = (1..=spec.degree).contains(&spec.penalty_order);
        assert!(
            order_in_range,
            "`{formula}` penalty_order {} must satisfy 1 <= order <= degree={}",
            spec.penalty_order,
            spec.degree
        );
    }
}
4718
/// `shape=monotone_increasing` must surface as
/// `ShapeConstraint::MonotoneIncreasing` on the built term, while an unknown
/// shape name must fail with an "unknown shape constraint" error.
#[test]
fn formula_shape_constraint_round_trips_and_rejects_bogus() {
    let mut rows = Vec::with_capacity(32);
    for i in 0..32 {
        let x = i as f64 / 31.0;
        rows.push(vec![x * x, x]);
    }
    let ds = continuous_dataset(&["y", "x"], rows);
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();

    // Valid constraint round-trips onto the smooth term.
    let parsed =
        parse_formula("y ~ s(x, shape=monotone_increasing)").expect("parse monotone smooth");
    let mut notes = Vec::new();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("monotone smooth should build");
    assert_eq!(
        terms.smooth_terms[0].shape,
        ShapeConstraint::MonotoneIncreasing
    );

    // Unknown constraint parses as a formula but is rejected at build time.
    let parsed_bad = parse_formula("y ~ s(x, shape=bogus)").expect("parse bogus shape");
    let mut notes_bad = Vec::new();
    let err = build_termspec(&parsed_bad.terms, &ds, &col_map, &mut notes_bad, &policy)
        .expect_err("bogus shape must error");
    let rendered = format!("{err:?}");
    assert!(
        rendered.contains("unknown shape constraint"),
        "got: {err:?}"
    );
}
4763
/// `sphere(lat, lon)` with no options must build a sphere basis whose center
/// strategy is `CenterStrategy::FarthestPoint`.
#[test]
fn default_sphere_smooth_uses_spherical_farthest_point_centers() {
    let mut rows = Vec::with_capacity(24);
    for i in 0..24 {
        let t = i as f64 / 24.0;
        let lat = -60.0 + 120.0 * t;
        // Stride-7 permutation decouples longitude from latitude.
        let lon = -180.0 + 360.0 * ((7 * i) % 24) as f64 / 24.0;
        rows.push(vec![lat.to_radians().sin(), lat, lon]);
    }
    let ds = continuous_dataset(&["y", "lat", "lon"], rows);

    let parsed = parse_formula("y ~ sphere(lat, lon)").expect("parse");
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();
    let mut notes = Vec::new();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build sphere termspec");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::Sphere { spec, .. } => spec,
        _ => panic!("expected sphere term"),
    };
    assert!(matches!(
        spec.center_strategy,
        CenterStrategy::FarthestPoint { .. }
    ));
}
4796
/// `duchon(x)` with no `length_scale` option must build a Duchon spec whose
/// `length_scale` is `None`.
#[test]
fn one_dimensional_duchon_defaults_to_scale_free_length_scale() {
    let mut rows = Vec::with_capacity(32);
    for i in 0..32 {
        let x = i as f64 / 31.0;
        rows.push(vec![(std::f64::consts::TAU * x).sin(), x]);
    }
    let ds = continuous_dataset(&["y", "x"], rows);

    let parsed = parse_formula("y ~ duchon(x)").expect("parse");
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();
    let mut notes = Vec::new();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build default duchon termspec");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => spec,
        _ => panic!("expected Duchon term"),
    };
    assert_eq!(spec.length_scale, None);
}
4824
/// `duchon(x, length_scale=0.25)` must carry the requested value through to
/// the built Duchon spec as `Some(0.25)`.
#[test]
fn one_dimensional_duchon_length_scale_opts_into_hybrid_mode() {
    let mut rows = Vec::with_capacity(32);
    for i in 0..32 {
        let x = i as f64 / 31.0;
        rows.push(vec![(std::f64::consts::TAU * x).sin(), x]);
    }
    let ds = continuous_dataset(&["y", "x"], rows);

    let parsed = parse_formula("y ~ duchon(x, length_scale=0.25)").expect("parse");
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();
    let mut notes = Vec::new();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build hybrid duchon termspec");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => spec,
        _ => panic!("expected Duchon term"),
    };
    assert_eq!(spec.length_scale, Some(0.25));
}
4852
/// Every supported half-integer must be accepted in fraction, padded
/// fraction, and decimal spellings (plus the word `half`), each mapping to
/// the matching `MaternNu` variant.
#[test]
fn parse_matern_nu_accepts_equivalent_half_integer_forms() {
    // True only when both values are the same one of the five variants.
    let same_variant = |left: MaternNu, right: MaternNu| {
        matches!(
            (left, right),
            (MaternNu::Half, MaternNu::Half)
                | (MaternNu::ThreeHalves, MaternNu::ThreeHalves)
                | (MaternNu::FiveHalves, MaternNu::FiveHalves)
                | (MaternNu::SevenHalves, MaternNu::SevenHalves)
                | (MaternNu::NineHalves, MaternNu::NineHalves)
        )
    };

    let cases = [
        ("1/2", MaternNu::Half),
        (" 1 / 2 ", MaternNu::Half),
        (".5", MaternNu::Half),
        ("0.50", MaternNu::Half),
        ("half", MaternNu::Half),
        ("3 / 2", MaternNu::ThreeHalves),
        ("1.50", MaternNu::ThreeHalves),
        ("5 / 2", MaternNu::FiveHalves),
        ("2.500000000000", MaternNu::FiveHalves),
        ("7 / 2", MaternNu::SevenHalves),
        ("3.50", MaternNu::SevenHalves),
        ("9 / 2", MaternNu::NineHalves),
        ("4.50", MaternNu::NineHalves),
    ];
    for &(raw, expected) in &cases {
        let parsed = parse_matern_nu(raw).expect(raw);
        assert!(
            same_variant(parsed, expected),
            "parsed {raw:?} as {parsed:?}, expected {expected:?}"
        );
    }
}
4885
/// Integers, out-of-range half-integers, division by zero, NaN, and
/// non-numeric text must all be rejected with the standard "supported
/// half-integer values" message.
#[test]
fn parse_matern_nu_rejects_unsupported_or_invalid_values() {
    let rejected = ["1", "2", "11/2", "1/0", "nan", "fast"];
    for raw in rejected {
        let err = parse_matern_nu(raw).expect_err(raw);
        let mentions_supported = err.contains("supported half-integer values");
        assert!(mentions_supported, "unexpected error for {raw:?}: {err}");
    }
}
4896
/// For a cubic request with `k` of 4, 6, and 10, `parse_ps_internal_knots`
/// must keep degree 3, report the count as not inferred, and return 2, 2,
/// and 6 internal knots respectively.
#[test]
fn parse_ps_k_promotes_underexpressive_cubic_basis() {
    // (k as written, expect label, expected internal knots)
    let cases = [("4", "k=4", 2), ("6", "k=6", 2), ("10", "k=10", 6)];

    for (k, label, expected_internal) in cases {
        let opts = BTreeMap::from([("k".to_string(), k.to_string())]);
        let (internal, inferred, eff_degree) =
            parse_ps_internal_knots(&opts, 3, 20).expect(label);
        assert_eq!(internal, expected_internal);
        assert_eq!(eff_degree, 3);
        assert!(!inferred);
    }
}
4918
/// For a cubic request, `k=3` and `k=2` must reduce the effective degree to
/// 2 and 1 with zero internal knots, `k=1` must be rejected with a
/// "requires k >= 2" error, and `k=4` must keep degree 3 with 2 internal
/// knots.
#[test]
fn parse_ps_internal_knots_drops_degree_for_small_k() {
    // Runs the parser on a single-entry option map holding only `k`.
    let run = |k: &str| {
        let opts = BTreeMap::from([("k".to_string(), k.to_string())]);
        parse_ps_internal_knots(&opts, 3, 20)
    };

    let (internal, inferred, eff_degree) = run("3").expect("k=3");
    assert_eq!(eff_degree, 2);
    assert_eq!(internal, 0);
    assert!(!inferred);

    let (internal, inferred, eff_degree) = run("2").expect("k=2");
    assert_eq!(eff_degree, 1);
    assert_eq!(internal, 0);
    assert!(!inferred);

    let err = run("1").expect_err("k=1 is below the irreducible spline floor");
    assert!(err.contains("requires k >= 2"), "unexpected error: {err}");

    let (internal, inferred, eff_degree) = run("4").expect("k=4");
    assert_eq!(eff_degree, 3);
    assert_eq!(internal, 2);
    assert!(!inferred);
}
4956
/// `s(x, g, bs=fs, k=K)` with `K` of 3 and 2 must build a factor smooth whose
/// marginal degree is 2 and 1 respectively, whose penalty order does not
/// exceed that degree, and whose marginal basis size equals `K`.
#[test]
fn factor_smooth_marginal_degree_reduces_for_small_k() {
    let ds = factor_dataset();
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();

    let cases = [(3usize, 2usize), (2usize, 1usize)];
    for (k, expected_degree) in cases {
        let formula = format!("y ~ s(x, g, bs=fs, k={k})");
        let parsed = parse_formula(&formula).expect("parse factor smooth");
        let mut notes = Vec::new();
        let terms = match build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy) {
            Ok(terms) => terms,
            Err(err) => panic!("fs k={k} should degree-reduce, got: {err:?}"),
        };

        let spec = match &terms.smooth_terms[0].basis {
            SmoothBasisSpec::FactorSmooth { spec } => spec,
            other => panic!("expected factor smooth, got {:?}", other),
        };
        let marginal = &spec.marginal;

        assert_eq!(marginal.degree, expected_degree);
        assert!(
            marginal.penalty_order <= marginal.degree,
            "penalty_order {} must be clamped to degree {}",
            marginal.penalty_order,
            marginal.degree
        );

        let internal_knots = match &marginal.knotspec {
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => *num_internal_knots,
            BSplineKnotSpec::Automatic {
                num_internal_knots: Some(count),
                ..
            } => *count,
            other => panic!("unexpected factor-smooth knotspec: {other:?}"),
        };
        let basis_size = internal_knots + marginal.degree + 1;
        assert_eq!(basis_size, k);
    }
}
5000
5001 fn ternary_factor_dataset() -> Dataset {
5004 let rows = (0..120)
5005 .map(|i| {
5006 let x = (i % 3) as f64;
5007 let g = (i % 2) as f64;
5008 vec![x + g, x, g]
5009 })
5010 .collect::<Vec<_>>();
5011 Dataset {
5012 headers: vec!["y".into(), "x".into(), "g".into()],
5013 values: Array2::from_shape_vec(
5014 (rows.len(), 3),
5015 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5016 )
5017 .expect("rectangular ternary factor test data"),
5018 schema: DataSchema {
5019 columns: vec![
5020 SchemaColumn {
5021 name: "y".into(),
5022 kind: ColumnKindTag::Continuous,
5023 levels: vec![],
5024 },
5025 SchemaColumn {
5026 name: "x".into(),
5027 kind: ColumnKindTag::Continuous,
5028 levels: vec![],
5029 },
5030 SchemaColumn {
5031 name: "g".into(),
5032 kind: ColumnKindTag::Categorical,
5033 levels: vec!["a".into(), "b".into()],
5034 },
5035 ],
5036 },
5037 column_kinds: vec![
5038 ColumnKindTag::Continuous,
5039 ColumnKindTag::Continuous,
5040 ColumnKindTag::Categorical,
5041 ],
5042 }
5043 }
5044
/// `s(x, bs=cr, k=10)` on a covariate with only three distinct values
/// (0, 1, 2) must build a natural-cubic-regression knot spec with exactly
/// those three knots and record a "data-support cap" inference note.
#[test]
fn univariate_cr_smooth_caps_knots_to_data_support() {
    let mut rows = Vec::with_capacity(90);
    for i in 0..90 {
        let level = (i % 3) as f64;
        rows.push(vec![level, level]);
    }
    let ds = continuous_dataset(&["y", "x"], rows);
    let col_map = ds.column_map();

    let parsed = parse_formula("y ~ s(x, bs=cr, k=10)").expect("parse cr smooth");
    let policy = ResourcePolicy::default_library();
    let mut notes = Vec::new();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("cr k=10 must cap to data support instead of erroring");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => spec,
        _ => panic!("expected BSpline1D for s(x, bs=cr)"),
    };
    let knots = match &spec.knotspec {
        BSplineKnotSpec::NaturalCubicRegression { knots } => knots,
        other => panic!("expected cr knotspec, got {:?}", other),
    };
    assert_eq!(knots.len(), 3, "cr basis not capped to 3 distinct values");
    assert_eq!(knots.as_slice().unwrap(), &[0.0, 1.0, 2.0]);

    let cap_reported = notes.iter().any(|note| note.contains("data-support cap"));
    assert!(
        cap_reported,
        "cap not reported in inference notes: {notes:?}"
    );
}
5084
/// `s(x, bs=cr, k=10)` on a covariate with only two distinct values must
/// still build, must not use a natural-cubic-regression knot spec, and must
/// record a "Degraded to the linear B-spline" inference note.
#[test]
fn univariate_cr_smooth_binary_covariate_degrades_to_bspline() {
    let mut rows = Vec::with_capacity(80);
    for i in 0..80 {
        let level = (i % 2) as f64;
        rows.push(vec![level, level]);
    }
    let ds = continuous_dataset(&["y", "x"], rows);
    let col_map = ds.column_map();

    let parsed = parse_formula("y ~ s(x, bs=cr, k=10)").expect("parse cr smooth");
    let policy = ResourcePolicy::default_library();
    let mut notes = Vec::new();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("binary cr must degrade to B-spline instead of erroring");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => spec,
        _ => panic!("expected BSpline1D for s(x, bs=cr)"),
    };

    let built_cr = matches!(
        spec.knotspec,
        BSplineKnotSpec::NaturalCubicRegression { .. }
    );
    assert!(
        !built_cr,
        "binary covariate must NOT build a cr basis, got {:?}",
        spec.knotspec
    );

    let degradation_reported = notes
        .iter()
        .any(|note| note.contains("Degraded to the linear B-spline"));
    assert!(
        degradation_reported,
        "degradation not reported in inference notes: {notes:?}"
    );
}
5125
/// `s(x, g, bs=sz, k=10)` on the ternary-`x` dataset must build a factor
/// smooth whose marginal knot spec is not natural cubic regression.
#[test]
fn sz_factor_smooth_low_cardinality_uses_bspline_marginal() {
    let ds = ternary_factor_dataset();
    let col_map = ds.column_map();

    let parsed = parse_formula("y ~ s(x, g, bs=sz, k=10)").expect("parse sz factor smooth");
    let policy = ResourcePolicy::default_library();
    let mut notes = Vec::new();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("sz on a ternary covariate must build (B-spline marginal), not hard-fail");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::FactorSmooth { spec } => spec,
        _ => panic!("expected FactorSmooth for s(x, g, bs=sz)"),
    };

    let marginal_is_cr = matches!(
        spec.marginal.knotspec,
        BSplineKnotSpec::NaturalCubicRegression { .. }
    );
    assert!(
        !marginal_is_cr,
        "sz marginal must be a B-spline (curvature-capable), not the \
         natural-BC cr basis; got {:?}",
        spec.marginal.knotspec
    );
}
5161
5162 fn continuous_x_factor_dataset(n: usize, n_groups: usize) -> Dataset {
5167 let rows = (0..n)
5168 .map(|i| {
5169 let x = i as f64 / (n as f64 - 1.0);
5170 let g = (i % n_groups) as f64;
5171 vec![x + g, x, g]
5172 })
5173 .collect::<Vec<_>>();
5174 let levels: Vec<String> = (0..n_groups).map(|k| format!("g{k}")).collect();
5175 Dataset {
5176 headers: vec!["y".into(), "x".into(), "g".into()],
5177 values: Array2::from_shape_vec(
5178 (rows.len(), 3),
5179 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5180 )
5181 .expect("rectangular continuous-x factor data"),
5182 schema: DataSchema {
5183 columns: vec![
5184 SchemaColumn {
5185 name: "y".into(),
5186 kind: ColumnKindTag::Continuous,
5187 levels: vec![],
5188 },
5189 SchemaColumn {
5190 name: "x".into(),
5191 kind: ColumnKindTag::Continuous,
5192 levels: vec![],
5193 },
5194 SchemaColumn {
5195 name: "g".into(),
5196 kind: ColumnKindTag::Categorical,
5197 levels,
5198 },
5199 ],
5200 },
5201 column_kinds: vec![
5202 ColumnKindTag::Continuous,
5203 ColumnKindTag::Continuous,
5204 ColumnKindTag::Categorical,
5205 ],
5206 }
5207 }
5208
5209 fn factor_smooth_spec_for(formula: &str, ds: &Dataset) -> FactorSmoothSpec {
5210 let col_map = ds.column_map();
5211 let parsed = parse_formula(formula).expect("parse factor smooth formula");
5212 let mut notes = Vec::new();
5213 let terms = build_termspec(
5214 &parsed.terms,
5215 ds,
5216 &col_map,
5217 &mut notes,
5218 &gam_runtime::resource::ResourcePolicy::default_library(),
5219 )
5220 .expect("build factor smooth term");
5221 let SmoothBasisSpec::FactorSmooth { spec } = &terms.smooth_terms[0].basis else {
5222 panic!("expected FactorSmooth basis for `{formula}`");
5223 };
5224 spec.clone()
5225 }
5226
/// Builds the same `s(x, g, k=8)` factor smooth under `bs=sz` and `bs=fs` on a
/// 180-row, 4-group fixture and compares the two results: equal penalty
/// counts, at least two penalties on `sz`, a strictly narrower `sz` design,
/// and per-penalty bookkeeping vectors of matching length.
#[test]
fn sz_factor_smooth_carries_null_space_ridge_like_fs() {
    let data = continuous_x_factor_dataset(180, 4);
    let mut basis_workspace = crate::basis::BasisWorkspace::new();

    let sz_spec = factor_smooth_spec_for("y ~ s(x, g, bs=sz, k=8)", &data);
    let sz = crate::smooth::build_factor_smooth(
        data.values.view(),
        &sz_spec,
        "sz_term",
        &mut basis_workspace,
    )
    .expect("build sz factor smooth");

    let fs_spec = factor_smooth_spec_for("y ~ s(x, g, bs=fs, k=8)", &data);
    let fs = crate::smooth::build_factor_smooth(
        data.values.view(),
        &fs_spec,
        "fs_term",
        &mut basis_workspace,
    )
    .expect("build fs factor smooth");

    // Fall back to the fixture's group count when no frozen levels are recorded.
    let n_levels = sz_spec
        .group_frozen_levels
        .as_ref()
        .map_or(4, |frozen| frozen.len());
    assert!(n_levels >= 3, "test needs >=3 groups, got {n_levels}");

    let sz_penalty_count = sz.penalties.len();
    let fs_penalty_count = fs.penalties.len();

    assert_eq!(
        sz_penalty_count,
        fs_penalty_count,
        "sz must carry the same number of penalties as fs (wiggliness + one \
         null-space ridge per marginal null direction); sz had {} (only the \
         wiggliness penalties => null space unpenalized => over-smoothed), fs \
         had {}",
        sz_penalty_count,
        fs_penalty_count,
    );

    assert!(
        sz_penalty_count >= 2,
        "sz deviation block carries no null-space ridge (penalties={}); the \
         null space is unpenalized and REML over-smooths the deviations",
        sz_penalty_count,
    );

    assert!(
        sz.dim < fs.dim,
        "sz design width {} must be strictly less than fs width {} \
         (zero-sum contrast drops one level block)",
        sz.dim,
        fs.dim,
    );

    // Every penalty must have a matching entry in each bookkeeping vector.
    assert_eq!(sz_penalty_count, sz.nullspaces.len());
    assert_eq!(sz_penalty_count, sz.penaltyinfo.len());
    assert_eq!(sz_penalty_count, sz.null_eigenvectors.len());
}
5321
5322 fn factor_dataset_l3() -> Dataset {
5333 let rows = (0..30)
5335 .map(|i| {
5336 let x = i as f64 / 29.0;
5337 let g = (i % 3) as f64;
5338 vec![x + g, x, g]
5339 })
5340 .collect::<Vec<_>>();
5341 Dataset {
5342 headers: vec!["y".into(), "x".into(), "g".into()],
5343 values: Array2::from_shape_vec(
5344 (rows.len(), 3),
5345 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5346 )
5347 .expect("rectangular L=3 factor test data"),
5348 schema: DataSchema {
5349 columns: vec![
5350 SchemaColumn {
5351 name: "y".into(),
5352 kind: ColumnKindTag::Continuous,
5353 levels: vec![],
5354 },
5355 SchemaColumn {
5356 name: "x".into(),
5357 kind: ColumnKindTag::Continuous,
5358 levels: vec![],
5359 },
5360 SchemaColumn {
5361 name: "g".into(),
5362 kind: ColumnKindTag::Categorical,
5363 levels: vec!["a".into(), "b".into(), "c".into()],
5364 },
5365 ],
5366 },
5367 column_kinds: vec![
5368 ColumnKindTag::Continuous,
5369 ColumnKindTag::Continuous,
5370 ColumnKindTag::Categorical,
5371 ],
5372 }
5373 }
5374
/// Counts the random-effect terms named `g` produced by `s(x, by=g)` with and
/// without an extra bare `+ g`, and requires exactly one in both cases.
#[test]
fn factor_by_smooth_plus_bare_categorical_does_not_duplicate_factor_block() {
    let data = factor_dataset_l3();
    let columns = data.column_map();

    // Builds `formula` and returns how many random-effect terms are named `g`.
    let count_g_blocks = |formula: &str| -> usize {
        let parsed = parse_formula(formula).expect("parse by-smooth formula");
        let mut notes = Vec::new();
        let policy = ResourcePolicy::default_library();
        let terms = match build_termspec(&parsed.terms, &data, &columns, &mut notes, &policy) {
            Ok(terms) => terms,
            Err(err) => panic!("`{formula}` must build, got: {err:?}"),
        };
        let mut blocks = 0usize;
        for effect in terms.random_effect_terms.iter() {
            if effect.name == "g" {
                blocks += 1;
            }
        }
        blocks
    };

    let by_only = count_g_blocks("y ~ s(x, by=g, k=10)");
    assert_eq!(
        by_only, 1,
        "`y ~ s(x, by=g)` must produce exactly one `g` design block"
    );

    let by_plus_bare = count_g_blocks("y ~ s(x, by=g, k=10) + g");
    assert_eq!(
        by_plus_bare, 1,
        "`y ~ s(x, by=g) + g` must collapse to ONE `g` block (#1457): the bare \
         `+ g` already owns the factor's level offsets, so the `by=` branch \
         must not add a second, treatment-coded main effect"
    );

    assert_eq!(
        by_plus_bare, by_only,
        "the bare `+ g` collision must add zero extra `g` blocks (#1457)"
    );
}
5424
/// The plural option keys `periods` and `origins` must be parsed alongside a
/// two-axis periodic `boundary` list.
#[test]
fn parse_tensor_periods_and_origins_aliases() {
    let opts: BTreeMap<String, String> = [
        ("boundary", "['periodic', 'periodic']"),
        ("periods", "[7, 24]"),
        ("origins", "[0, -12]"),
    ]
    .into_iter()
    .map(|(key, value)| (key.to_string(), value.to_string()))
    .collect();

    let axes = parse_periodic_axes(&opts, 2).expect("axes");
    let periods = parse_periods(&opts, &axes).expect("periods");
    let origins = parse_period_origins(&opts, &axes).expect("origins");

    assert_eq!(axes, vec![true, true]);
    assert_eq!(periods, vec![Some(7.0), Some(24.0)]);
    assert_eq!(origins, vec![Some(0.0), Some(-12.0)]);
}
5441
/// A `te(...)` term given `k=[9,5]` must end up with marginal basis sizes of
/// exactly 9 and 5.
#[test]
fn tensor_smooth_honors_per_margin_k_list() {
    let mut rows = Vec::new();
    for i in 0..20 {
        let theta = std::f64::consts::TAU * i as f64 / 20.0;
        let h = -1.0 + 2.0 * (i % 5) as f64 / 4.0;
        rows.push(vec![theta.cos() + h, theta, h]);
    }
    let data = continuous_dataset(&["y", "theta", "h"], rows);

    let formula = parse_formula(
        "y ~ te(theta, h, periodic=[0], period=[2*pi, None], origin=[0, None], k=[9,5])",
    )
    .expect("parse tensor formula");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let built = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("build tensor terms");

    let tensor = match &built.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!("expected tensor B-spline"),
    };

    // Translate each marginal knot spec into the basis size it implies.
    let mut dims = Vec::new();
    for margin in tensor.marginalspecs.iter() {
        let width = match &margin.knotspec {
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => *num_internal_knots + margin.degree + 1,
            BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
            _ => panic!("unexpected tensor marginal knotspec"),
        };
        dims.push(width);
    }
    assert_eq!(dims, vec![9, 5]);
}
5487
/// The `k_x` / `k_y` option spellings must set the per-margin basis sizes of a
/// `te(x, y)` term to the requested values.
#[test]
fn tensor_smooth_honors_per_margin_k_axis_aliases() {
    let mut rows = Vec::new();
    for i in 0..12 {
        let t = i as f64 / 11.0;
        rows.push(vec![t, t, 1.0 - t]);
    }
    let data = continuous_dataset(&["resp", "x", "y"], rows);

    let sizes = tensor_margin_basis_sizes(&data, "resp ~ te(x, y, k_x=9, k_y=5)");
    assert_eq!(
        sizes,
        vec![9, 5],
        "k_<margin> aliases should materialize requested per-margin values"
    );
}
5505
/// For `te(x, b, k=[5, 2])` with a two-valued `b`, the `x` margin must stay
/// degree 3 with 5 basis functions while the `b` margin drops to degree 1 with
/// 2 basis functions and a penalty order within `1..=degree`.
#[test]
fn tensor_smooth_low_cardinality_axis_falls_back_to_lower_degree_basis() {
    let mut rows = Vec::new();
    for i in 0..40 {
        let x = i as f64 / 39.0;
        let b = (i % 2) as f64;
        rows.push(vec![x.sin() + 0.5 * b, x, b]);
    }
    let data = continuous_dataset(&["y", "x", "b"], rows);

    let formula = parse_formula("y ~ te(x, b, k=[5, 2])").expect("parse tensor with k=[5,2]");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let built = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("build tensor with binary margin");

    let tensor = match &built.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!("expected tensor B-spline for te(x, b)"),
    };
    let smooth_margin = &tensor.marginalspecs[0];
    let binary_margin = &tensor.marginalspecs[1];

    assert_eq!(smooth_margin.degree, 3);
    assert_eq!(binary_margin.degree, 1);
    assert!(
        (1..=binary_margin.degree).contains(&binary_margin.penalty_order),
        "binary margin penalty_order {} must satisfy 1 <= order <= degree={}",
        binary_margin.penalty_order,
        binary_margin.degree
    );

    // Basis size implied by a marginal's knot spec.
    let width_of = |margin: &BSplineBasisSpec| -> usize {
        match &margin.knotspec {
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => *num_internal_knots + margin.degree + 1,
            BSplineKnotSpec::Automatic {
                num_internal_knots: Some(n),
                ..
            } => *n + margin.degree + 1,
            BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
            _ => panic!("unexpected tensor marginal knotspec"),
        }
    };
    assert_eq!(width_of(smooth_margin), 5);
    assert_eq!(width_of(binary_margin), 2);
}
5568
/// With a single `k=5` applied to `te(x, b)` and a two-valued `b`, the build
/// must succeed, giving the `b` margin 2 basis functions at degree 1 and
/// leaving the `x` margin at 5.
#[test]
fn tensor_smooth_uniform_k_is_capped_to_a_low_cardinality_margins_distinct_values() {
    let mut rows = Vec::new();
    for i in 0..40 {
        let x = i as f64 / 39.0;
        let b = (i % 2) as f64;
        rows.push(vec![x.sin() + 0.5 * b, x, b]);
    }
    let data = continuous_dataset(&["y", "x", "b"], rows);

    let formula = parse_formula("y ~ te(x, b, k=5)").expect("parse tensor with uniform k=5");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let built = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("uniform k=5 must auto-cap the binary margin instead of erroring");

    let tensor = match &built.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!("expected tensor B-spline for te(x, b)"),
    };

    // Basis size implied by a marginal's knot spec.
    let width_of = |margin: &BSplineBasisSpec| match margin.knotspec {
        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
        BSplineKnotSpec::Generate {
            num_internal_knots, ..
        } => num_internal_knots + margin.degree + 1,
        BSplineKnotSpec::Automatic {
            num_internal_knots: Some(n),
            ..
        } => n + margin.degree + 1,
        BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
        ref other => panic!("unexpected tensor marginal knotspec: {other:?}"),
    };

    let binary_margin = &tensor.marginalspecs[1];
    assert_eq!(width_of(binary_margin), 2);
    assert_eq!(binary_margin.degree, 1);
    assert_eq!(width_of(&tensor.marginalspecs[0]), 5);
}
5622
/// `te(x1, x2, bs=c('tp','tp'), k=c(5,5))` must build a `TensorBSpline` basis
/// whose two margins each have 5 basis functions.
#[test]
fn tensor_all_tp_margins_with_per_margin_k_routes_to_bspline_tensor() {
    let mut rows = Vec::new();
    for i in 0..32 {
        let t = i as f64 / 31.0;
        rows.push(vec![t.sin(), t, 1.0 - t]);
    }
    let data = continuous_dataset(&["y", "x1", "x2"], rows);

    let formula =
        parse_formula("y ~ te(x1, x2, bs=c('tp','tp'), k=c(5,5))").expect("parse tensor");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let built = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("build tensor terms with per-margin k");

    let tensor = match &built.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        other => panic!(
            "expected B-spline tensor when k=c(5,5) is supplied with bs=c('tp','tp'), got {:?}",
            other
        ),
    };

    // Translate each marginal knot spec into the basis size it implies. The
    // match has no wildcard arm, so it covers every knot-spec variant.
    let mut dims = Vec::new();
    for margin in tensor.marginalspecs.iter() {
        let degree = margin.degree;
        let width = match &margin.knotspec {
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
            BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
            BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(degree + 1),
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => *num_internal_knots + degree + 1,
            BSplineKnotSpec::Automatic {
                num_internal_knots: Some(count),
                ..
            } => *count + degree + 1,
            BSplineKnotSpec::Automatic {
                num_internal_knots: None,
                ..
            } => panic!("test cannot infer automatic knot count"),
        };
        dims.push(width);
    }
    assert_eq!(dims, vec![5, 5]);
}
5693
/// `te(x1, x2, bs=c('tp','tp'))` without any `k` must still build a
/// `TensorBSpline` basis carrying exactly two marginal specs.
#[test]
fn tensor_all_tp_margins_without_per_margin_k_builds_anisotropic_tensor() {
    let mut rows = Vec::new();
    for i in 0..32 {
        let t = i as f64 / 31.0;
        rows.push(vec![t.sin(), t, 1.0 - t]);
    }
    let data = continuous_dataset(&["y", "x1", "x2"], rows);

    let formula = parse_formula("y ~ te(x1, x2, bs=c('tp','tp'))").expect("parse tensor");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let built = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("build tensor terms without per-margin k");

    let tensor = match &built.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        other => panic!(
            "te(...,bs=c('tp','tp')) must route to an anisotropic tensor product, not a \
             silent isotropic thin-plate substitution; got {:?}",
            other
        ),
    };
    let margin_count = tensor.marginalspecs.len();
    assert_eq!(
        margin_count, 2,
        "tp tensor must carry one penalized B-spline margin per axis"
    );
}
5736
/// With only 12 rows and five smooths, the explicitly sized `s(x1, k=10)` must
/// still be a generated-knot B-spline with 6 internal knots.
#[test]
fn explicit_basis_sizes_are_not_small_n_clamped() {
    let mut rows = Vec::new();
    for i in 0..12 {
        let x = i as f64 / 11.0;
        rows.push(vec![x.sin(), x, x * x, x + 0.1, 1.0 - x, (2.0 * x).sin()]);
    }
    let data = continuous_dataset(&["y", "x1", "x2", "x3", "x4", "x5"], rows);

    let formula = parse_formula("y ~ s(x1, k=10) + s(x2) + s(x3) + s(x4) + s(x5)")
        .expect("parse multi-smooth formula");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let built = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("build multi-smooth terms");

    let spec = match &built.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => spec,
        _ => panic!("expected first smooth to be B-spline"),
    };
    assert!(matches!(
        &spec.knotspec,
        BSplineKnotSpec::Generate {
            num_internal_knots: 6,
            ..
        }
    ));
}
5771
/// With only 12 rows and five smooths, `duchon(x1, centers=3)` must keep a
/// farthest-point center strategy with exactly 3 centers.
#[test]
fn explicit_duchon_centers_are_not_small_n_bumped() {
    let mut rows = Vec::new();
    for i in 0..12 {
        let x = i as f64 / 11.0;
        rows.push(vec![x.sin(), x, x * x, x + 0.1, 1.0 - x, (2.0 * x).sin()]);
    }
    let data = continuous_dataset(&["y", "x1", "x2", "x3", "x4", "x5"], rows);

    let formula = parse_formula("y ~ duchon(x1, centers=3) + s(x2) + s(x3) + s(x4) + s(x5)")
        .expect("parse multi-smooth formula");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let built = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("build multi-smooth terms");

    let spec = match &built.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => spec,
        _ => panic!("expected first smooth to be Duchon"),
    };
    assert!(matches!(
        spec.center_strategy,
        CenterStrategy::FarthestPoint { num_centers: 3 }
    ));
}
5809
/// Repeating an 800-row (50 x 16) coordinate grid twelve times must leave the
/// inferred tensor basis product unchanged.
#[test]
fn inferred_tensor_basis_cap_uses_coordinate_support_not_duplicate_rows() {
    // One row per (theta, h) grid point.
    let unique_rows: Vec<Vec<f64>> = (0..50)
        .flat_map(|i| {
            let theta = i as f64 / 50.0;
            (0..16).map(move |j| {
                let h = -1.0 + 2.0 * (j as f64) / 15.0;
                let y = theta.cos() + h;
                vec![y, theta, h]
            })
        })
        .collect();

    // Twelve back-to-back copies of the grid.
    let repeated_rows: Vec<Vec<f64>> = (0..12)
        .flat_map(|_| unique_rows.iter().cloned())
        .collect();

    let unique = continuous_dataset(&["y", "theta", "h"], unique_rows);
    let repeated = continuous_dataset(&["y", "theta", "h"], repeated_rows);

    let unique_basis = inferred_tensor_basis_product(&unique);
    let repeated_basis = inferred_tensor_basis_product(&repeated);

    assert_eq!(
        unique_basis, repeated_basis,
        "duplicating existing tensor coordinates must not inflate inferred basis width"
    );
}
5837
/// The product of marginal basis sizes for a default `te(x1, x2, x3)` must not
/// exceed 216 at either 60 or 2000 rows.
#[test]
fn inferred_three_dim_tensor_basis_stays_bounded_for_reml_selection() {
    // Builds the 3-D tensor on `n` synthetic rows and returns the product of
    // its marginal basis sizes.
    let total_basis = |n: usize| -> usize {
        let rows = (0..n)
            .map(|i| {
                let f = i as f64 / n as f64;
                vec![f.sin(), f, (2.0 * f).cos(), (3.0 * f) % 1.0]
            })
            .collect::<Vec<_>>();
        let data = continuous_dataset(&["y", "x1", "x2", "x3"], rows);
        let formula = parse_formula("y ~ te(x1, x2, x3)").expect("parse 3-D tensor");
        let columns = data.column_map();
        let mut inference_notes = Vec::new();
        let policy = ResourcePolicy::default_library();
        let built = build_termspec(
            &formula.terms,
            &data,
            &columns,
            &mut inference_notes,
            &policy,
        )
        .expect("build 3-D tensor termspec");

        let tensor = match &built.smooth_terms[0].basis {
            SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
            _ => panic!("expected tensor smooth"),
        };

        let mut product = 1usize;
        for margin in tensor.marginalspecs.iter() {
            product *= match &margin.knotspec {
                BSplineKnotSpec::Generate {
                    num_internal_knots, ..
                } => *num_internal_knots + margin.degree + 1,
                BSplineKnotSpec::Automatic {
                    num_internal_knots: Some(num_internal_knots),
                    ..
                } => *num_internal_knots + margin.degree + 1,
                BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
                _ => panic!("unexpected tensor margin knotspec"),
            };
        }
        product
    };

    let at_small_n = total_basis(60);
    assert!(
        at_small_n <= 216,
        "3-D te at small n must stay near the mgcv te default, got {}",
        at_small_n
    );

    let at_large_n = total_basis(2000);
    assert!(
        at_large_n <= 216,
        "3-D te at large n must not blow ∏k toward the data size, got {}",
        at_large_n
    );
}
5899
/// Covers three option sets for `parse_bspline_boundary_conditions`: a
/// non-zero left anchor and a non-zero right anchor must both be rejected with
/// an error naming the side and value, while `start_bc=clamped` with
/// `end_bc=zero` must parse to a clamped left end and a right end anchored at
/// (numerically) zero.
#[test]
fn parse_bspline_boundary_conditions_and_side_selector() {
    // Turns borrowed key/value pairs into the owned option map the parser takes.
    let options = |pairs: &[(&str, &str)]| -> BTreeMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    };

    let left_anchored = options(&[
        ("boundary_conditions", "anchored"),
        ("side", "left"),
        ("anchor", "2.5"),
    ]);
    let err = parse_bspline_boundary_conditions(&left_anchored)
        .expect_err("non-zero left anchor must be rejected")
        .to_string();
    let names_left_and_value = err.contains("left") && err.contains("2.5");
    assert!(
        names_left_and_value,
        "rejection should name the affected side and value: {err}"
    );

    let right_anchored = options(&[
        ("start_bc", "clamped"),
        ("end_bc", "zero"),
        ("right_anchor", "-1.0"),
    ]);
    let err = parse_bspline_boundary_conditions(&right_anchored)
        .expect_err("non-zero right anchor must be rejected")
        .to_string();
    let names_right_and_value = err.contains("right") && err.contains("-1");
    assert!(
        names_right_and_value,
        "rejection should name the affected side and value: {err}"
    );

    let clamped_then_zero = options(&[("start_bc", "clamped"), ("end_bc", "zero")]);
    let parsed =
        parse_bspline_boundary_conditions(&clamped_then_zero).expect("boundary conditions");
    assert!(matches!(
        parsed.left,
        BSplineEndpointBoundaryCondition::Clamped
    ));
    assert!(matches!(
        parsed.right,
        BSplineEndpointBoundaryCondition::Anchored { value } if value.abs() < 1e-12
    ));
}
5949
/// `y ~ x:g` on the `factor_dataset` fixture must yield two interaction terms,
/// each carrying `x` as its only feature and a single gate on `g`. The realized
/// column must equal `x` on rows whose `g` matches the gate and 0 elsewhere,
/// and the gates seen must include both 0.0 and 1.0.
#[test]
fn categorical_by_numeric_interaction_expands_treatment_coded_cells() {
    let data = factor_dataset();
    let formula = parse_formula("y ~ x:g").expect("parse `y ~ x:g`");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("factor-aware `x:g` interaction must build, not error");

    let n_linear = terms.linear_terms.len();
    assert_eq!(
        n_linear, 2,
        "interaction-only `x:g` keeps ALL factor levels (full dummy coding): one slope column per group"
    );

    let x_col = *columns.get("x").expect("x column");
    let g_col = *columns.get("g").expect("g column");

    let mut seen_bits = std::collections::HashSet::new();
    for term in &terms.linear_terms {
        let flagged_as_interaction = term.is_interaction();
        assert!(
            flagged_as_interaction,
            "the categorical-by-numeric cell is a Wilkinson-Rogers interaction"
        );
        assert_eq!(term.feature_cols, vec![x_col]);
        assert_eq!(term.categorical_levels.len(), 1);
        let (gate_col, gate_bits) = term.categorical_levels[0];
        assert_eq!(gate_col, g_col);
        let first_time_seen = seen_bits.insert(gate_bits);
        assert!(first_time_seen, "each level appears once");

        let column = term
            .realized_design_column(data.values.view())
            .expect("realize cell column");
        let n = data.values.nrows();
        assert_eq!(column.len(), n);

        // The gate compares raw bit patterns, so the expectation does as well.
        for row in 0..n {
            let x = data.values[[row, x_col]];
            let g = data.values[[row, g_col]];
            let expected = match g.to_bits() == gate_bits {
                true => x,
                false => 0.0,
            };
            let deviation = (column[row] - expected).abs();
            assert!(
                deviation < 1e-12,
                "row {row}: g={g}, x={x}, expected {expected}, got {}",
                column[row]
            );
        }
    }
    assert!(seen_bits.contains(&0.0_f64.to_bits()));
    assert!(seen_bits.contains(&1.0_f64.to_bits()));
}
6021
/// `y ~ x + x:g` on the `factor_dataset` fixture must yield exactly one
/// interaction term: feature `x`, gated on `g` at the bit pattern of 1.0.
#[test]
fn categorical_by_numeric_interaction_keeps_treatment_coding_with_parent() {
    let data = factor_dataset();
    let formula = parse_formula("y ~ x + x:g").expect("parse `y ~ x + x:g`");
    let columns = data.column_map();
    let mut inference_notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(
        &formula.terms,
        &data,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("`x + x:g` must build");

    let x_col = *columns.get("x").expect("x column");
    let g_col = *columns.get("g").expect("g column");

    // Keep only the linear terms flagged as interactions.
    let mut interaction_cells = Vec::new();
    for candidate in terms.linear_terms.iter() {
        if candidate.is_interaction() {
            interaction_cells.push(candidate);
        }
    }
    let n_cells = interaction_cells.len();
    assert_eq!(
        n_cells, 1,
        "with `x` present, `x:g` is treatment-coded → one cell (reference dropped)"
    );

    let term = interaction_cells[0];
    assert_eq!(term.feature_cols, vec![x_col]);
    assert_eq!(term.categorical_levels.len(), 1);
    let (gate_col, gate_bits) = term.categorical_levels[0];
    assert_eq!(gate_col, g_col);
    assert_eq!(gate_bits, 1.0_f64.to_bits());
}
6065
/// `y ~ f:g` with a 3-level `f` and a 2-level `g` must yield five gated
/// indicator terms: every `(f, g)` combination except `(0.0, 0.0)`. Each
/// realized column must be 1 exactly on the rows matching both gates, and must
/// contain at least one 1.
#[test]
fn categorical_by_categorical_interaction_expands_full_cross_cells() {
    let n = 30usize;

    // Row-major buffer of (y, f, g) with f cycling mod 3 and g cycling mod 2.
    let flat: Vec<f64> = (0..n)
        .flat_map(|i| {
            let y = (i as f64).sin();
            let f = (i % 3) as f64;
            let g = (i % 2) as f64;
            [y, f, g]
        })
        .collect();
    let values = Array2::from_shape_vec((n, 3), flat).expect("rectangular cross-factor data");

    let response = SchemaColumn {
        name: "y".into(),
        kind: ColumnKindTag::Continuous,
        levels: Vec::new(),
    };
    let first_factor = SchemaColumn {
        name: "f".into(),
        kind: ColumnKindTag::Categorical,
        levels: vec!["f0".into(), "f1".into(), "f2".into()],
    };
    let second_factor = SchemaColumn {
        name: "g".into(),
        kind: ColumnKindTag::Categorical,
        levels: vec!["g0".into(), "g1".into()],
    };
    let ds = Dataset {
        headers: vec!["y".into(), "f".into(), "g".into()],
        values,
        schema: DataSchema {
            columns: vec![response, first_factor, second_factor],
        },
        column_kinds: vec![
            ColumnKindTag::Continuous,
            ColumnKindTag::Categorical,
            ColumnKindTag::Categorical,
        ],
    };

    let formula = parse_formula("y ~ f:g").expect("parse `y ~ f:g`");
    let columns = ds.column_map();
    let mut inference_notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(
        &formula.terms,
        &ds,
        &columns,
        &mut inference_notes,
        &policy,
    )
    .expect("factor-by-factor `f:g` interaction must build, not error");

    let n_linear = terms.linear_terms.len();
    assert_eq!(
        n_linear, 5,
        "saturated 3*2 = 6 cross cells minus one reference cell (f0:g0) = 5"
    );

    let f_col = *columns.get("f").expect("f column");
    let g_col = *columns.get("g").expect("g column");
    let f0 = 0.0_f64.to_bits();
    let g0 = 0.0_f64.to_bits();

    let mut emitted = std::collections::HashSet::new();
    for term in &terms.linear_terms {
        assert!(term.feature_cols.is_empty());
        assert_eq!(term.categorical_levels.len(), 2);

        // Index the two gates by the column they apply to.
        let gates: std::collections::HashMap<_, _> = term
            .categorical_levels
            .iter()
            .map(|&(col, bits)| (col, bits))
            .collect();
        let f_bits = *gates.get(&f_col).expect("f gate present");
        let g_bits = *gates.get(&g_col).expect("g gate present");
        assert!(
            (f_bits, g_bits) != (f0, g0),
            "the reference cell f0:g0 must be absorbed by the intercept, not emitted"
        );
        emitted.insert((f_bits, g_bits));

        let column = term
            .realized_design_column(ds.values.view())
            .expect("realize cross cell");
        for row in 0..n {
            let f = ds.values[[row, f_col]];
            let g = ds.values[[row, g_col]];
            let in_cell = f.to_bits() == f_bits && g.to_bits() == g_bits;
            let expected = if in_cell { 1.0 } else { 0.0 };
            let deviation = (column[row] - expected).abs();
            assert!(
                deviation < 1e-12,
                "row {row}: expected {expected}, got {}",
                column[row]
            );
        }
        let observed = column.iter().any(|&v| v == 1.0);
        assert!(observed, "each cross cell must be observed in the data");
    }

    // Every non-reference combination must have been emitted.
    let f_levels = [0.0_f64.to_bits(), 1.0_f64.to_bits(), 2.0_f64.to_bits()];
    let g_levels = [0.0_f64.to_bits(), 1.0_f64.to_bits()];
    let non_reference_cells = f_levels
        .iter()
        .flat_map(|&fb| g_levels.iter().map(move |&gb| (fb, gb)))
        .filter(|&cell| cell != (f0, g0));
    for cell in non_reference_cells {
        assert!(
            emitted.contains(&cell),
            "saturated cross cell must be present"
        );
    }
}
6201}