1use std::collections::{BTreeMap, BTreeSet, HashMap};
8use std::path::PathBuf;
9
10use ndarray::{Array2, ArrayView1};
11
12use crate::basis::{
13 BSplineBasisSpec, BSplineBoundaryConditions, BSplineEndpointBoundaryCondition,
14 BSplineIdentifiability, BSplineKnotSpec, CenterCountRequest, CenterStrategy,
15 ConstantCurvatureBasisSpec, ConstantCurvatureIdentifiability, DuchonBasisSpec,
16 DuchonNullspaceOrder, DuchonOperatorPenaltySpec, MaternBasisSpec, MaternIdentifiability,
17 MaternNu, MeasureJetBasisSpec, MeasureJetIdentifiability, OneDimensionalBoundary,
18 SpatialIdentifiability, SphereMethod, SphereWahbaKernel, SphericalSplineBasisSpec,
19 SphericalSplineIdentifiability, ThinPlateBasisSpec, auto_spatial_center_strategy,
20 default_num_centers, default_spatial_center_strategy, default_spherical_harmonic_degree,
21 plan_spatial_basis,
22};
23use crate::inference::formula_dsl::{
24 ParsedTerm, SmoothKind, option_bool, option_f64, option_f64_strict, option_usize,
25 option_usize_any, option_usize_any_strict, option_usize_strict, strip_quotes,
26};
27use crate::smooth::{
28 ByVarKind, FactorSmoothFlavour, FactorSmoothSpec, LinearCoefficientGeometry, LinearTermSpec,
29 RandomEffectTermSpec, ShapeConstraint, SmoothBasisSpec, SmoothTermSpec,
30 TensorBSplineIdentifiability, TensorBSplinePenaltyDecomposition, TensorBSplineSpec,
31 TermCollectionSpec,
32};
33use gam_problem::types::ColIdx;
34use gam_data::{ColumnKindTag, DataError, EncodedDataset as Dataset};
35use gam_runtime::resource::ResourcePolicy;
36
/// Lower bound applied to the data-derived default Matérn length scale
/// computed by `default_matern_length_scale`.
const DEFAULT_MATERN_LENGTH_SCALE_FLOOR: f64 = 1e-6;

/// Default B-spline degree. NOTE(review): consumer not visible in this chunk;
/// presumably used when a smooth does not specify a degree — confirm.
const DEFAULT_BSPLINE_DEGREE: usize = 3;

/// Default penalty order. NOTE(review): consumer not visible in this chunk —
/// confirm which penalties read it.
const DEFAULT_PENALTY_ORDER: usize = 2;

/// Default basis dimension for cyclic smooths. NOTE(review): consumer not
/// visible in this chunk.
const CYCLIC_DEFAULT_BASIS_DIM: usize = 12;

/// Default basis dimension for factor smooths. NOTE(review): consumer not
/// visible in this chunk.
const FACTOR_SMOOTH_DEFAULT_BASIS_DIM: usize = 10;

/// Default chunk size for PCA processing. NOTE(review): units (rows?) and
/// consumer not visible in this chunk — confirm.
const DEFAULT_PCA_CHUNK_SIZE: usize = 4096;
69
70fn default_matern_length_scale(ds: &Dataset, cols: &[usize]) -> f64 {
71 let mut diameter2 = 0.0_f64;
72 for &col in cols {
73 let column = ds.values.column(col);
74 let mut lo = f64::INFINITY;
75 let mut hi = f64::NEG_INFINITY;
76 for &value in column.iter().filter(|v| v.is_finite()) {
77 lo = lo.min(value);
78 hi = hi.max(value);
79 }
80 if lo.is_finite() && hi.is_finite() && hi > lo {
81 let span = hi - lo;
82 diameter2 += span * span;
83 }
84 }
85 let diameter = diameter2.sqrt();
86 if diameter.is_finite() && diameter > 0.0 {
87 diameter.max(DEFAULT_MATERN_LENGTH_SCALE_FLOOR)
94 } else {
95 1.0
96 }
97}
98
/// Errors produced while translating parsed formula terms into term specs.
///
/// Every variant except `ColumnNotFound` carries a preformatted `reason`,
/// which the `Display` impl writes verbatim.
#[derive(Clone, Debug)]
pub enum TermBuilderError {
    /// A column, or the kind metadata for a resolved column, is unusable.
    /// Also the target for non-`ColumnNotFound` `DataError`s converted via
    /// `From<DataError>`.
    MissingColumn { reason: String },
    /// A referenced column name is absent from the column map. The fields
    /// mirror `DataError::ColumnNotFound`, whose `Display` renders this variant.
    ColumnNotFound {
        name: String,
        role: Option<String>,
        available: Vec<String>,
        similar: Vec<String>,
        tsv_hint: bool,
    },
    /// Options or column kinds that cannot be combined. Also produced by
    /// `From<String>` for bare string errors.
    IncompatibleConfig { reason: String },
    /// An option value failed to parse (for example `shape=`).
    InvalidOption { reason: String },
    /// A requested setting that this builder does not support.
    UnsupportedFeature { reason: String },
    /// The data cannot support the requested term. NOTE(review): not
    /// constructed in the visible part of this file — confirm where it is used.
    DegenerateData { reason: String },
    /// A formula construct that is structurally invalid or must be resolved by
    /// another code path before term building.
    MalformedFormula { reason: String },
}
147
148impl std::fmt::Display for TermBuilderError {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 TermBuilderError::MissingColumn { reason }
152 | TermBuilderError::IncompatibleConfig { reason }
153 | TermBuilderError::InvalidOption { reason }
154 | TermBuilderError::UnsupportedFeature { reason }
155 | TermBuilderError::DegenerateData { reason }
156 | TermBuilderError::MalformedFormula { reason } => f.write_str(reason),
157 TermBuilderError::ColumnNotFound {
163 name,
164 role,
165 available,
166 similar,
167 tsv_hint,
168 } => {
169 let canonical = DataError::ColumnNotFound {
170 name: name.clone(),
171 role: role.clone(),
172 available: available.clone(),
173 similar: similar.clone(),
174 tsv_hint: *tsv_hint,
175 };
176 std::fmt::Display::fmt(&canonical, f)
177 }
178 }
179 }
180}
181
182impl From<TermBuilderError> for String {
183 fn from(err: TermBuilderError) -> String {
184 err.to_string()
185 }
186}
187
188impl From<String> for TermBuilderError {
195 fn from(reason: String) -> Self {
196 Self::IncompatibleConfig { reason }
197 }
198}
199
200impl From<DataError> for TermBuilderError {
207 fn from(err: DataError) -> Self {
208 match err {
209 DataError::ColumnNotFound {
210 name,
211 role,
212 available,
213 similar,
214 tsv_hint,
215 } => Self::ColumnNotFound {
216 name,
217 role,
218 available,
219 similar,
220 tsv_hint,
221 },
222 DataError::SchemaMismatch { reason }
223 | DataError::ParseError { reason }
224 | DataError::EncodingFailure { reason }
225 | DataError::EmptyInput { reason }
226 | DataError::InvalidValue { reason } => Self::MissingColumn { reason },
227 }
228 }
229}
230
231impl TermBuilderError {
233 #[inline]
234 fn missing_column(reason: impl Into<String>) -> Self {
235 TermBuilderError::MissingColumn {
236 reason: reason.into(),
237 }
238 }
239 #[inline]
240 fn incompatible_config(reason: impl Into<String>) -> Self {
241 TermBuilderError::IncompatibleConfig {
242 reason: reason.into(),
243 }
244 }
245 #[inline]
246 fn invalid_option(reason: impl Into<String>) -> Self {
247 TermBuilderError::InvalidOption {
248 reason: reason.into(),
249 }
250 }
251 #[inline]
252 fn unsupported_feature(reason: impl Into<String>) -> Self {
253 TermBuilderError::UnsupportedFeature {
254 reason: reason.into(),
255 }
256 }
257 #[inline]
258 fn degenerate_data(reason: impl Into<String>) -> Self {
259 TermBuilderError::DegenerateData {
260 reason: reason.into(),
261 }
262 }
263 #[inline]
264 fn malformed_formula(reason: impl Into<String>) -> Self {
265 TermBuilderError::MalformedFormula {
266 reason: reason.into(),
267 }
268 }
269}
270
271pub fn resolve_col(col_map: &HashMap<String, usize>, name: &str) -> Result<usize, DataError> {
282 col_map
283 .get(name)
284 .copied()
285 .ok_or_else(|| DataError::column_not_found(col_map, name, None))
286}
287
288pub fn resolve_role_col(
293 col_map: &HashMap<String, usize>,
294 name: &str,
295 role: &str,
296) -> Result<usize, DataError> {
297 col_map
298 .get(name)
299 .copied()
300 .ok_or_else(|| DataError::column_not_found(col_map, name, Some(role)))
301}
302
303fn encoded_levels_for_column(ds: &Dataset, col: ColIdx) -> Vec<(u64, String)> {
304 let mut seen = BTreeSet::<u64>::new();
305 for value in ds.values.column(col.get()) {
306 if value.is_finite() {
307 seen.insert(value.to_bits());
308 }
309 }
310 let schema_levels = ds
311 .schema
312 .columns
313 .get(col.get())
314 .map(|column| column.levels.as_slice())
315 .unwrap_or(&[]);
316 seen.into_iter()
317 .enumerate()
318 .map(|(idx, bits)| {
319 let fallback = format!("level{}", idx + 1);
320 let label = schema_levels.get(idx).cloned().unwrap_or(fallback);
321 (bits, label)
322 })
323 .collect()
324}
325
/// Returns a copy of `col_map` in which `alias` also maps to `target_column`'s index.
///
/// The alias is added only when `target_column` exists. An existing entry
/// already named `alias` is never overwritten.
pub fn column_map_with_alias(
    col_map: &HashMap<String, usize>,
    alias: &str,
    target_column: &str,
) -> HashMap<String, usize> {
    let mut extended = col_map.clone();
    match col_map.get(target_column) {
        Some(&target_idx) if !extended.contains_key(alias) => {
            extended.insert(alias.to_owned(), target_idx);
        }
        _ => {}
    }
    extended
}
337
338pub fn build_termspec(
343 terms: &[ParsedTerm],
344 ds: &Dataset,
345 col_map: &HashMap<String, usize>,
346 inference_notes: &mut Vec<String>,
347 policy: &ResourcePolicy,
348) -> Result<TermCollectionSpec, TermBuilderError> {
349 let mut linear_terms = Vec::<LinearTermSpec>::new();
350 let mut random_terms = Vec::<RandomEffectTermSpec>::new();
351 let mut smooth_terms = Vec::<SmoothTermSpec>::new();
352 let smooth_coordinate_count = terms
353 .iter()
354 .map(|term| match term {
355 ParsedTerm::Smooth { vars, .. } => vars.len(),
356 _ => 0,
357 })
358 .sum::<usize>();
359
360 for t in terms {
361 match t {
362 ParsedTerm::Linear {
363 name,
364 explicit,
365 coefficient_min,
366 coefficient_max,
367 } => {
368 let col = resolve_col(col_map, name)?;
369 let auto_kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
370 TermBuilderError::missing_column(format!(
371 "internal column-kind lookup failed for '{name}'"
372 ))
373 .to_string()
374 })?;
375 if *explicit {
376 linear_terms.push(LinearTermSpec {
377 name: name.clone(),
378 feature_col: col,
379 feature_cols: vec![col],
380 categorical_levels: vec![],
381 double_penalty: false,
384 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
385 coefficient_min: *coefficient_min,
386 coefficient_max: *coefficient_max,
387 });
388 } else {
389 match auto_kind {
390 ColumnKindTag::Continuous | ColumnKindTag::Binary => {
391 linear_terms.push(LinearTermSpec {
392 name: name.clone(),
393 feature_col: col,
394 feature_cols: vec![col],
395 categorical_levels: vec![],
396 double_penalty: false,
398 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
399 coefficient_min: *coefficient_min,
400 coefficient_max: *coefficient_max,
401 });
402 }
403 ColumnKindTag::Categorical => {
404 if coefficient_min.is_some() || coefficient_max.is_some() {
405 return Err(TermBuilderError::incompatible_config(format!(
406 "coefficient constraints are not supported for categorical auto-random-effect term '{name}'; use group({name}) or an unconstrained numeric term"
407 )));
408 }
409 random_terms.push(RandomEffectTermSpec {
410 name: name.clone(),
411 feature_col: col,
412 drop_first_level: false,
413 penalized: true,
414 frozen_levels: None,
415 });
416 }
417 }
418 }
419 }
420 ParsedTerm::BoundedLinear {
421 name,
422 min,
423 max,
424 prior,
425 } => {
426 let col = resolve_col(col_map, name)?;
427 let auto_kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
428 TermBuilderError::missing_column(format!(
429 "internal column-kind lookup failed for '{name}'"
430 ))
431 .to_string()
432 })?;
433 if !matches!(auto_kind, ColumnKindTag::Continuous | ColumnKindTag::Binary) {
434 return Err(TermBuilderError::incompatible_config(format!(
435 "bounded() currently supports only numeric columns, got categorical '{name}'"
436 )));
437 }
438 linear_terms.push(LinearTermSpec {
439 name: name.clone(),
440 feature_col: col,
441 feature_cols: vec![col],
442 categorical_levels: vec![],
443 double_penalty: false,
444 coefficient_geometry: LinearCoefficientGeometry::Bounded {
445 min: *min,
446 max: *max,
447 prior: prior.clone(),
448 },
449 coefficient_min: None,
450 coefficient_max: None,
451 });
452 }
453 ParsedTerm::RandomEffect { name } => {
454 let col = resolve_col(col_map, name)?;
455 random_terms.push(RandomEffectTermSpec {
456 name: name.clone(),
457 feature_col: col,
458 drop_first_level: false,
459 penalized: true,
460 frozen_levels: None,
461 });
462 }
463 ParsedTerm::Smooth {
464 label,
465 vars,
466 kind,
467 options,
468 } => {
469 let smooth_vars = vars.clone();
470 let by_name = options.get("by").cloned();
471 let cols = smooth_vars
481 .iter()
482 .map(|v| resolve_col(col_map, v))
483 .collect::<Result<Vec<_>, _>>()?;
484 let mut inner_options = options.clone();
485 inner_options.remove("by");
486 inner_options.remove("ordered");
490 let shape = match inner_options.remove("shape") {
496 None => ShapeConstraint::None,
497 Some(raw) => crate::smooth::parse_shape_constraint(&raw)
498 .map_err(TermBuilderError::invalid_option)?,
499 };
500 let inner_basis = build_smooth_basis(
501 *kind,
502 &smooth_vars,
503 &cols,
504 &inner_options,
505 ds,
506 inference_notes,
507 policy,
508 smooth_coordinate_count,
509 )?;
510 if let Some(by_name) = by_name {
511 let by_col = resolve_col(col_map, &by_name)?;
512 match ds.column_kinds.get(by_col).copied().ok_or_else(|| {
513 format!("internal column-kind lookup failed for by variable '{by_name}'")
514 })? {
515 ColumnKindTag::Categorical => {
516 let levels = encoded_levels_for_column(ds, ColIdx::new(by_col));
517 let penalized_group_owner_present =
530 terms.iter().any(|other| match other {
531 ParsedTerm::RandomEffect { name } => name == &by_name,
532 ParsedTerm::Linear {
533 name,
534 explicit: false,
535 ..
536 } if name == &by_name => col_map
537 .get(name)
538 .and_then(|c| ds.column_kinds.get(*c).copied())
539 .map(|kind| matches!(kind, ColumnKindTag::Categorical))
540 .unwrap_or(false),
541 _ => false,
542 });
543 if !random_terms.iter().any(|rt| rt.name == by_name)
554 && !penalized_group_owner_present
555 {
556 random_terms.push(RandomEffectTermSpec {
557 name: by_name.clone(),
558 feature_col: by_col,
559 drop_first_level: true,
560 penalized: false,
561 frozen_levels: None,
562 });
563 }
564 let frozen_levels: Vec<u64> =
569 levels.iter().map(|(bits, _)| *bits).collect();
570 smooth_terms.push(SmoothTermSpec {
571 name: label.clone(),
572 basis: SmoothBasisSpec::BySmooth {
573 smooth: Box::new(inner_basis),
574 by_kind: ByVarKind::Factor {
575 feature_col: by_col,
576 ordered: option_bool(options, "ordered").unwrap_or(false),
577 frozen_levels: Some(frozen_levels),
578 },
579 },
580 shape,
581 joint_null_rotation: None,
582 });
583 }
584 ColumnKindTag::Binary | ColumnKindTag::Continuous => {
585 smooth_terms.push(SmoothTermSpec {
586 name: label.clone(),
587 basis: SmoothBasisSpec::BySmooth {
588 smooth: Box::new(inner_basis),
589 by_kind: ByVarKind::Numeric {
590 feature_col: by_col,
591 },
592 },
593 shape,
594 joint_null_rotation: None,
595 });
596 }
597 }
598 } else {
599 smooth_terms.push(SmoothTermSpec {
600 name: label.clone(),
601 basis: inner_basis,
602 shape,
603 joint_null_rotation: None,
604 });
605 }
606 }
607 ParsedTerm::LinkWiggle { .. }
608 | ParsedTerm::TimeWiggle { .. }
609 | ParsedTerm::LinkConfig { .. }
610 | ParsedTerm::SurvivalConfig { .. } => {
611 }
613 ParsedTerm::LogSlopeSurface { .. } => {
614 return Err(TermBuilderError::malformed_formula(
615 "logslope(...) declarations must be resolved by the marginal-slope formula path before building a term spec",
616 ));
617 }
618 ParsedTerm::Interaction { vars } => {
619 let main_effect_present = |target: &str| -> bool {
652 terms.iter().any(|other| match other {
653 ParsedTerm::Linear { name, .. }
654 | ParsedTerm::BoundedLinear { name, .. }
655 | ParsedTerm::RandomEffect { name } => name == target,
656 _ => false,
657 })
658 };
659 let parent_present = |drop_var: &str| -> bool {
665 vars.iter()
666 .filter(|v| v.as_str() != drop_var)
667 .all(|v| main_effect_present(v))
668 };
669
670 let mut numeric_cols = Vec::<usize>::new();
671 let mut categorical_factors =
674 Vec::<(String, usize, Vec<(u64, String)>, bool)>::new();
675 for var in vars {
676 let col = resolve_col(col_map, var)?;
677 let kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
678 TermBuilderError::missing_column(format!(
679 "internal column-kind lookup failed for '{var}'"
680 ))
681 .to_string()
682 })?;
683 match kind {
684 ColumnKindTag::Continuous | ColumnKindTag::Binary => numeric_cols.push(col),
685 ColumnKindTag::Categorical => {
686 let mut levels = encoded_levels_for_column(ds, ColIdx::new(col));
687 let treatment_coded = parent_present(var);
691 if treatment_coded && levels.len() > 1 {
692 levels.remove(0);
693 }
694 if levels.is_empty() {
695 return Err(TermBuilderError::incompatible_config(format!(
696 "interaction `{}` references categorical column `{var}` with no usable levels",
697 vars.join(":")
698 )));
699 }
700 categorical_factors.push((var.clone(), col, levels, treatment_coded));
701 }
702 }
703 }
704
705 let label = vars.join(":");
706
707 if categorical_factors.is_empty() {
708 linear_terms.push(LinearTermSpec {
711 name: label,
712 feature_col: numeric_cols[0],
713 feature_cols: numeric_cols,
714 categorical_levels: vec![],
715 double_penalty: false,
718 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
719 coefficient_min: None,
720 coefficient_max: None,
721 });
722 inference_notes.push(format!(
723 "wired linear interaction `{}` as product of numeric columns",
724 vars.join(":")
725 ));
726 } else {
727 let mut cells: Vec<Vec<(usize, u64, String)>> = vec![Vec::new()];
732 for (_var, col, levels, _treatment_coded) in &categorical_factors {
733 let mut next = Vec::with_capacity(cells.len() * levels.len());
734 for cell in &cells {
735 for (bits, level_label) in levels {
736 let mut extended = cell.clone();
737 extended.push((*col, *bits, level_label.clone()));
738 next.push(extended);
739 }
740 }
741 cells = next;
742 }
743
744 let any_dummy_coded = categorical_factors
756 .iter()
757 .any(|(_, _, _, treatment_coded)| !*treatment_coded);
758 if numeric_cols.is_empty() && any_dummy_coded {
759 let reference_cell: Vec<(usize, u64)> = categorical_factors
762 .iter()
763 .map(|(_, col, _, _)| {
764 let levels = encoded_levels_for_column(ds, ColIdx::new(*col));
765 (*col, levels[0].0)
766 })
767 .collect();
768 cells.retain(|cell| {
769 !reference_cell.iter().all(|(rcol, rbits)| {
770 cell.iter()
771 .any(|(col, bits, _)| col == rcol && bits == rbits)
772 })
773 });
774 }
775
776 let n_cells = cells.len();
777 for cell in cells {
778 let cell_suffix = cell
779 .iter()
780 .map(|(_, _, level_label)| level_label.as_str())
781 .collect::<Vec<_>>()
782 .join(":");
783 let categorical_levels =
784 cell.iter().map(|(col, bits, _)| (*col, *bits)).collect();
785 let feature_col = numeric_cols
791 .first()
792 .copied()
793 .unwrap_or(categorical_factors[0].1);
794 linear_terms.push(LinearTermSpec {
795 name: format!("{label}:{cell_suffix}"),
796 feature_col,
797 feature_cols: numeric_cols.clone(),
798 categorical_levels,
799 double_penalty: false,
800 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
801 coefficient_min: None,
802 coefficient_max: None,
803 });
804 }
805 let all_treatment_coded = !any_dummy_coded;
806 let coding = if all_treatment_coded {
807 "treatment-coded"
808 } else {
809 "marginality-aware (full dummy / saturated)"
810 };
811 inference_notes.push(format!(
812 "wired factor-aware linear interaction `{}` as {} {} cell column(s)",
813 vars.join(":"),
814 n_cells,
815 coding
816 ));
817 }
818 }
819 }
820 }
821
822 Ok(TermCollectionSpec {
823 linear_terms,
824 random_effect_terms: random_terms,
825 smooth_terms,
826 })
827}
828
/// Splits a list-valued option into trimmed, non-empty entries.
///
/// Accepts `[a, b]`, `c(a, b)`, `C(a, b)`, `(a, b)`, or a bare comma-separated
/// string. Entries keep their original case and quotes.
fn split_list_option(raw: &str) -> Vec<String> {
    let trimmed = raw.trim();
    let body = if let Some(inner) = trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
    {
        inner
    } else {
        // The first matching opener wins; the closing paren must also be present.
        ["c(", "C(", "("]
            .iter()
            .find_map(|opener| trimmed.strip_prefix(*opener))
            .and_then(|rest| rest.strip_suffix(')'))
            .unwrap_or(trimmed)
    };

    let mut entries = Vec::new();
    for piece in body.split(',') {
        let piece = piece.trim();
        if !piece.is_empty() {
            entries.push(piece.to_string());
        }
    }
    entries
}
853
/// Parses a scalar numeric option value that may use the constants π and τ.
///
/// The expression is a `*`-separated product of factors; all whitespace is
/// ignored. Each factor is one of:
/// - a plain `f64` literal;
/// - a bare constant: `pi`, `tau` (ASCII case-insensitive), `π`, or `τ`;
/// - a constant with a leading coefficient, such as `2pi`, `0.5TAU`, `-pi`,
///   or `+τ`. A bare sign is a coefficient of ±1.
///
/// Signed constants matter for values like a period origin of `-pi`, which
/// previously failed because the prefix `-` was fed to `f64::parse`.
///
/// # Errors
/// Returns a message when:
/// - `raw` is `none` (case-insensitive);
/// - a factor is empty (e.g. `2**pi`);
/// - a coefficient or literal does not parse as `f64`.
fn parse_numeric_expr(raw: &str) -> Result<f64, String> {
    // Strips `suffix` (ASCII case-insensitive) from the end of `s`. `str::get`
    // returns None at a non-char boundary, so multibyte input cannot panic.
    fn strip_suffix_ignore_case<'a>(s: &'a str, suffix: &str) -> Option<&'a str> {
        let split = s.len().checked_sub(suffix.len())?;
        let head = s.get(..split)?;
        let tail = s.get(split..)?;
        tail.eq_ignore_ascii_case(suffix).then_some(head)
    }

    // Parses the multiplier written in front of a constant; a bare sign is ±1.
    fn coefficient(prefix: &str, raw: &str) -> Result<f64, String> {
        match prefix {
            "" | "+" => Ok(1.0),
            "-" => Ok(-1.0),
            _ => prefix
                .parse::<f64>()
                .map_err(|err| format!("invalid numeric expression '{raw}': {err}")),
        }
    }

    let normalized: String = raw.chars().filter(|c| !c.is_whitespace()).collect();
    if normalized.eq_ignore_ascii_case("none") {
        return Err("None is not numeric".to_string());
    }
    let mut acc = 1.0f64;
    for factor in normalized.split('*') {
        if factor.is_empty() {
            return Err(format!("invalid numeric expression '{raw}'"));
        }
        let value = if let Some(prefix) =
            strip_suffix_ignore_case(factor, "pi").or_else(|| factor.strip_suffix('π'))
        {
            coefficient(prefix, raw)? * std::f64::consts::PI
        } else if let Some(prefix) =
            strip_suffix_ignore_case(factor, "tau").or_else(|| factor.strip_suffix('τ'))
        {
            coefficient(prefix, raw)? * std::f64::consts::TAU
        } else {
            factor
                .parse::<f64>()
                .map_err(|err| format!("invalid numeric expression '{raw}': {err}"))?
        };
        acc *= value;
    }
    Ok(acc)
}
901
902fn option_numeric_expr(
912 options: &BTreeMap<String, String>,
913 key: &str,
914) -> Result<Option<f64>, String> {
915 match options.get(key) {
916 None => Ok(None),
917 Some(raw) => parse_numeric_expr(raw)
918 .map(Some)
919 .map_err(|err| format!("option `{key}={raw}` is not a valid numeric value: {err}")),
920 }
921}
922
923fn parse_periods_option(
924 options: &BTreeMap<String, String>,
925 dim: usize,
926) -> Result<Option<Vec<Option<f64>>>, String> {
927 let Some(raw) = options.get("period") else {
928 return Ok(None);
929 };
930 let values = split_list_option(raw);
931 let mut periods = vec![None; dim];
932 if values.len() == 1 && dim == 1 {
933 periods[0] = Some(parse_numeric_expr(&values[0])?);
934 } else {
935 if values.len() != dim {
936 return Err(format!(
937 "period list length {} must match smooth dimension {}",
938 values.len(),
939 dim
940 ));
941 }
942 for (i, v) in values.iter().enumerate() {
943 if v.eq_ignore_ascii_case("none") {
944 continue;
945 }
946 periods[i] = Some(parse_numeric_expr(v)?);
947 }
948 }
949 Ok(Some(periods))
950}
951
952fn parse_periodic_axes_option(
953 options: &BTreeMap<String, String>,
954 dim: usize,
955) -> Result<Option<Vec<Option<f64>>>, String> {
956 let Some(raw_axes) = options.get("periodic") else {
957 return Ok(None);
958 };
959 let mut periods = parse_periods_option(options, dim)?.unwrap_or_else(|| vec![None; dim]);
960 let axes = split_list_option(raw_axes);
961 if axes.is_empty() {
962 return Ok(Some(periods));
963 }
964 for a in axes {
965 let axis = a
966 .parse::<usize>()
967 .map_err(|err| format!("invalid periodic axis '{a}': {err}"))?;
968 if axis >= dim {
969 return Err(format!(
970 "periodic axis {axis} out of range for {dim}D smooth"
971 ));
972 }
973 if periods[axis].is_none() {
974 return Err(format!(
975 "periodic axis {axis} requires period[{axis}] to be finite"
976 ));
977 }
978 }
979 let listed: std::collections::BTreeSet<usize> = split_list_option(raw_axes)
981 .into_iter()
982 .filter_map(|a| a.parse::<usize>().ok())
983 .collect();
984 for i in 0..dim {
985 if !listed.contains(&i) {
986 periods[i] = None;
987 }
988 }
989 Ok(Some(periods))
990}
991
/// Splits a list-valued option into normalized entries.
///
/// Accepts the same list wrappers as `split_list_option`. Each entry has
/// surrounding whitespace removed, then surrounding double quotes, then
/// surrounding single quotes, and is lowercased (ASCII). Empty results are
/// dropped.
fn parse_option_list(raw: &str) -> Vec<String> {
    let whole = raw.trim();
    let body = match whole.strip_prefix('[').and_then(|rest| rest.strip_suffix(']')) {
        Some(inner) => inner,
        None => ["c(", "C(", "("]
            .iter()
            .find_map(|opener| whole.strip_prefix(*opener))
            .and_then(|rest| rest.strip_suffix(')'))
            .unwrap_or(whole),
    };

    let mut entries = Vec::new();
    for item in body.split(',') {
        let cleaned = item
            .trim()
            .trim_matches('"')
            .trim_matches('\'')
            .to_ascii_lowercase();
        if !cleaned.is_empty() {
            entries.push(cleaned);
        }
    }
    entries
}
1025
1026fn parse_periodic_axes(
1027 options: &BTreeMap<String, String>,
1028 dim: usize,
1029) -> Result<Vec<bool>, String> {
1030 let mut axes = vec![false; dim];
1031 if let Some(raw) = options.get("periodic").or_else(|| options.get("cyclic")) {
1032 let lowered = raw.trim().to_ascii_lowercase();
1033 match lowered.as_str() {
1034 "true" | "yes" | "y" => {
1035 axes.fill(true);
1036 return Ok(axes);
1037 }
1038 "false" | "no" | "n" => return Ok(axes),
1039 _ => {}
1040 }
1041 for axis_raw in parse_option_list(raw) {
1042 let axis = axis_raw
1043 .parse::<usize>()
1044 .map_err(|err| format!("invalid periodic axis '{axis_raw}': {err}"))?;
1045 if axis >= dim {
1046 return Err(format!(
1047 "periodic axis {axis} out of range for {dim}D smooth"
1048 ));
1049 }
1050 axes[axis] = true;
1051 }
1052 }
1053 if let Some(raw) = options.get("boundary").or_else(|| options.get("bc")) {
1054 let boundary = parse_option_list(raw);
1055 if boundary.len() == dim {
1056 for (axis, value) in boundary.iter().enumerate() {
1057 if matches!(value.as_str(), "periodic" | "cyclic" | "cc") {
1058 axes[axis] = true;
1059 }
1060 }
1061 } else if dim == 1
1062 && matches!(
1063 boundary.first().map(String::as_str),
1064 Some("periodic" | "cyclic" | "cc")
1065 )
1066 {
1067 axes[0] = true;
1068 }
1069 }
1070 Ok(axes)
1071}
1072
1073fn parse_optional_numeric_list(
1074 options: &BTreeMap<String, String>,
1075 keys: &[&str],
1076 dim: usize,
1077) -> Result<Vec<Option<f64>>, String> {
1078 let Some(raw) = keys.iter().find_map(|key| options.get(*key)) else {
1079 return Ok(vec![None; dim]);
1080 };
1081 let values = split_list_option(raw);
1082 let mut out = vec![None; dim];
1083 if values.len() == 1 && dim == 1 {
1084 if !values[0].eq_ignore_ascii_case("none") {
1085 out[0] = Some(parse_numeric_expr(&values[0])?);
1086 }
1087 return Ok(out);
1088 }
1089 if values.len() != dim {
1090 return Err(format!(
1091 "numeric option list length {} must match smooth dimension {}",
1092 values.len(),
1093 dim
1094 ));
1095 }
1096 for (i, value) in values.iter().enumerate() {
1097 if !value.eq_ignore_ascii_case("none") {
1098 out[i] = Some(parse_numeric_expr(value)?);
1099 }
1100 }
1101 Ok(out)
1102}
1103
1104fn parse_periods(
1105 options: &BTreeMap<String, String>,
1106 periodic_axes: &[bool],
1107) -> Result<Vec<Option<f64>>, String> {
1108 let dim = periodic_axes.len();
1109 let lone_periodic_broadcast = options
1114 .get("period")
1115 .or_else(|| options.get("periods"))
1116 .and_then(|raw| {
1117 let values = split_list_option(raw);
1118 if values.len() != 1 || dim <= 1 {
1119 return None;
1120 }
1121 let mut iter = periodic_axes.iter().enumerate().filter(|(_, p)| **p);
1122 let first = iter.next()?;
1123 if iter.next().is_some() {
1124 return None;
1125 }
1126 Some((first.0, values.into_iter().next().unwrap()))
1127 });
1128 let periods = if let Some((axis, value)) = lone_periodic_broadcast {
1129 let mut out = vec![None; dim];
1130 if !value.eq_ignore_ascii_case("none") {
1131 out[axis] = Some(parse_numeric_expr(&value)?);
1132 }
1133 out
1134 } else {
1135 parse_optional_numeric_list(options, &["period", "periods"], dim)?
1136 };
1137 for (axis, (periodic, period)) in periodic_axes.iter().zip(periods.iter()).enumerate() {
1138 if *periodic
1139 && let Some(value) = period
1140 && (!value.is_finite() || *value <= 0.0)
1141 {
1142 return Err(format!(
1143 "period for periodic axis {axis} must be finite and positive, got {value}"
1144 ));
1145 }
1146 }
1147 Ok(periods)
1148}
1149
1150fn parse_period_origins(
1151 options: &BTreeMap<String, String>,
1152 periodic_axes: &[bool],
1153) -> Result<Vec<Option<f64>>, String> {
1154 parse_optional_numeric_list(
1155 options,
1156 &[
1157 "origin",
1158 "origins",
1159 "period_origin",
1160 "period-origin",
1161 "domain_origin",
1162 ],
1163 periodic_axes.len(),
1164 )
1165}
1166
1167fn parse_tensor_periodic_axes(
1175 options: &BTreeMap<String, String>,
1176 dim: usize,
1177) -> Result<Vec<bool>, String> {
1178 let mut axes = vec![false; dim];
1179 if let Some(raw) = options.get("periodic").or_else(|| options.get("cyclic")) {
1180 let lowered = raw.trim().to_ascii_lowercase();
1181 match lowered.as_str() {
1182 "true" | "yes" | "y" => {
1183 axes.fill(true);
1184 }
1185 "false" | "no" | "n" => {
1186 }
1188 _ => {
1189 let entries = parse_option_list(raw);
1190 let all_bool = !entries.is_empty()
1191 && entries.iter().all(|v| {
1192 matches!(
1193 v.as_str(),
1194 "true" | "yes" | "y" | "false" | "no" | "n" | "none"
1195 )
1196 });
1197 if all_bool {
1198 if entries.len() != dim {
1199 return Err(format!(
1200 "periodic list length {} must match smooth dimension {}",
1201 entries.len(),
1202 dim
1203 ));
1204 }
1205 for (i, v) in entries.iter().enumerate() {
1206 axes[i] = matches!(v.as_str(), "true" | "yes" | "y");
1207 }
1208 } else {
1209 for axis_raw in entries {
1210 let axis = axis_raw
1211 .parse::<usize>()
1212 .map_err(|err| format!("invalid periodic axis '{axis_raw}': {err}"))?;
1213 if axis >= dim {
1214 return Err(format!(
1215 "periodic axis {axis} out of range for {dim}D smooth"
1216 ));
1217 }
1218 axes[axis] = true;
1219 }
1220 }
1221 }
1222 }
1223 }
1224 if let Some(raw) = options.get("boundary").or_else(|| options.get("bc")) {
1225 let boundary = parse_option_list(raw);
1226 if boundary.len() == dim {
1227 for (axis, value) in boundary.iter().enumerate() {
1228 if matches!(value.as_str(), "periodic" | "cyclic" | "cc") {
1229 axes[axis] = true;
1230 }
1231 }
1232 }
1233 }
1234 Ok(axes)
1235}
1236
1237fn tensor_k_axis_option_axis(
1238 key: &str,
1239 cols: &[usize],
1240 ds: &Dataset,
1241) -> Result<Option<usize>, String> {
1242 let Some(suffix) = key.strip_prefix("k_") else {
1243 return Ok(None);
1244 };
1245 if suffix.is_empty() {
1246 return Err("tensor k axis option must be named k_<axis> or k_<variable>".to_string());
1247 }
1248 if let Ok(axis) = suffix.parse::<usize>() {
1249 return if axis < cols.len() {
1250 Ok(Some(axis))
1251 } else {
1252 Err(format!(
1253 "tensor k axis option `{key}` references axis {axis}, but the smooth has {} margins",
1254 cols.len()
1255 ))
1256 };
1257 }
1258
1259 let mut matches = cols
1260 .iter()
1261 .enumerate()
1262 .filter(|(_, col)| ds.headers.get(**col).is_some_and(|name| name == suffix))
1263 .map(|(axis, _)| axis);
1264 let first = matches.next();
1265 if matches.next().is_some() {
1266 return Err(format!(
1267 "tensor k axis option `{key}` matches more than one margin named `{suffix}`"
1268 ));
1269 }
1270 first.map(Some).ok_or_else(|| {
1271 let margin_names = cols
1272 .iter()
1273 .enumerate()
1274 .map(|(axis, col)| {
1275 let name = ds
1276 .headers
1277 .get(*col)
1278 .map(String::as_str)
1279 .unwrap_or("<unnamed>");
1280 format!("{axis}:{name}")
1281 })
1282 .collect::<Vec<_>>()
1283 .join(", ");
1284 format!(
1285 "tensor k axis option `{key}` does not match a margin index or name; tensor margins are [{margin_names}]"
1286 )
1287 })
1288}
1289
/// True for option keys of the form `k_<something>` with a non-empty suffix.
fn is_tensor_k_axis_option_key(key: &str) -> bool {
    matches!(key.strip_prefix("k_"), Some(suffix) if !suffix.is_empty())
}
1294
1295fn parse_tensor_k_list(
1299 options: &BTreeMap<String, String>,
1300 cols: &[usize],
1301 ds: &Dataset,
1302) -> Result<(Vec<usize>, bool), String> {
1303 let mut axis_values = vec![None; cols.len()];
1304 let mut saw_axis_alias = false;
1305 for (key, value) in options {
1306 let Some(axis) = tensor_k_axis_option_axis(key, cols, ds)? else {
1307 continue;
1308 };
1309 saw_axis_alias = true;
1310 if axis_values[axis].is_some() {
1311 return Err(format!("tensor k axis {axis} is specified more than once"));
1312 }
1313 let k: usize = value
1314 .parse()
1315 .map_err(|err| format!("invalid tensor k option `{key}={value}`: {err}"))?;
1316 axis_values[axis] = Some(k);
1317 }
1318
1319 let raw = options
1320 .get("k")
1321 .or_else(|| options.get("basis_dim"))
1322 .or_else(|| options.get("basis-dim"))
1323 .or_else(|| options.get("basisdim"));
1324 if saw_axis_alias {
1325 if raw.is_some() {
1326 return Err(
1327 "tensor k axis aliases cannot be combined with k= or basis_dim=".to_string(),
1328 );
1329 }
1330 if let Some(missing_axis) = axis_values.iter().position(Option::is_none) {
1331 let margin_name = cols
1332 .get(missing_axis)
1333 .and_then(|col| ds.headers.get(*col))
1334 .map(String::as_str)
1335 .unwrap_or("<unnamed>");
1336 return Err(format!(
1337 "tensor k axis aliases must specify every margin; missing axis {missing_axis} ({margin_name})"
1338 ));
1339 }
1340 return Ok((
1341 axis_values
1342 .into_iter()
1343 .map(|k| k.expect("missing axis values rejected above"))
1344 .collect(),
1345 false,
1346 ));
1347 }
1348 let Some(raw) = raw else {
1349 let inferred = heuristic_tensor_margin_knots(cols, ds);
1350 return Ok((inferred, true));
1351 };
1352 let entries = split_list_option(raw);
1353 if entries.len() == 1 {
1354 let k: usize = entries[0]
1355 .parse()
1356 .map_err(|err| format!("invalid tensor k '{}': {err}", entries[0]))?;
1357 return Ok((vec![k; cols.len()], false));
1358 }
1359 if entries.len() != cols.len() {
1360 return Err(format!(
1361 "tensor k list length {} must match smooth dimension {}",
1362 entries.len(),
1363 cols.len()
1364 ));
1365 }
1366 let mut out = Vec::with_capacity(entries.len());
1367 for entry in entries {
1368 let k: usize = entry
1369 .parse()
1370 .map_err(|err| format!("invalid tensor k '{entry}': {err}"))?;
1371 out.push(k);
1372 }
1373 Ok((out, false))
1374}
1375
1376fn parse_tensor_identifiability(
1385 options: &BTreeMap<String, String>,
1386 kind: SmoothKind,
1387) -> Result<TensorBSplineIdentifiability, String> {
1388 let Some(raw) = options.get("identifiability").map(String::as_str) else {
1389 return Ok(match kind {
1390 SmoothKind::Ti => TensorBSplineIdentifiability::MarginalSumToZero,
1391 _ => TensorBSplineIdentifiability::default(),
1392 });
1393 };
1394 match raw.trim().to_ascii_lowercase().as_str() {
1395 "none" => Ok(TensorBSplineIdentifiability::None),
1396 "sum_tozero" | "sum-to-zero" | "center_sum_tozero" | "center-sum-to-zero" | "centered"
1397 | "sumtozero" => Ok(TensorBSplineIdentifiability::SumToZero),
1398 "marginal_sum_tozero" | "marginal-sum-to-zero" | "marginal_sumtozero"
1399 | "marginalsumtozero" | "interaction" => {
1400 Ok(TensorBSplineIdentifiability::MarginalSumToZero)
1401 }
1402 other => Err(TermBuilderError::unsupported_feature(format!(
1403 "invalid tensor identifiability '{other}'; expected one of: none, sum_tozero, marginal_sum_tozero"
1404 ))
1405 .to_string()),
1406 }
1407}
1408
1409fn bspline_boundary_declares_periodic_axis(options: &BTreeMap<String, String>) -> bool {
1410 options
1411 .get("boundary")
1412 .or_else(|| options.get("bc"))
1413 .map(|raw| {
1414 parse_option_list(raw)
1415 .into_iter()
1416 .any(|value| matches!(value.as_str(), "periodic" | "cyclic" | "cc"))
1417 })
1418 .unwrap_or(false)
1419}
1420
/// Rewrite a smooth-type alias to the canonical name used for basis dispatch.
///
/// Matching is exact and case-sensitive; callers that want case-insensitive lookup
/// lowercase first. Inputs that are not a known alias are returned unchanged.
pub(crate) fn canonicalize_smooth_type(raw: &str) -> &str {
    const ALIASES: [(&str, &str); 8] = [
        ("tp", "tps"),
        ("gp", "matern"),
        ("curv", "curvature"),
        ("constant_curvature", "curvature"),
        ("mkappa", "curvature"),
        ("mjs", "measurejet"),
        ("measure_jet", "measurejet"),
        ("web", "measurejet"),
    ];
    ALIASES
        .iter()
        .find_map(|&(alias, canonical)| (alias == raw).then_some(canonical))
        .unwrap_or(raw)
}
1463
1464pub(crate) fn tensor_margin_bs_is_supported(margin_bs: &str) -> bool {
1475 matches!(
1476 canonicalize_smooth_type(margin_bs),
1477 "tps" | "ps" | "bs" | "bspline" | "cr" | "cs" | "cc" | "cp" | "cyclic"
1478 )
1479}
1480
/// Whether a smooth's options request periodic behaviour.
///
/// The presence of a `periodic` or `cyclic` key is enough, whatever its value.
/// Otherwise the `boundary` value (or `bc` when `boundary` is absent) is lowercased and
/// checked for the substrings `periodic` or `cyclic`.
pub(crate) fn smooth_options_declare_periodic(options: &BTreeMap<String, String>) -> bool {
    if options.contains_key("periodic") || options.contains_key("cyclic") {
        return true;
    }
    let Some(boundary) = options.get("boundary").or_else(|| options.get("bc")) else {
        return false;
    };
    let lowered = boundary.to_ascii_lowercase();
    lowered.contains("periodic") || lowered.contains("cyclic")
}
1498
1499pub(crate) fn bs_selector_is_vector(raw: &str) -> bool {
1516 let trimmed = raw.trim();
1517 let bracketed = (trimmed.starts_with('[') && trimmed.ends_with(']'))
1518 || (trimmed.starts_with("c(") || trimmed.starts_with("C(")) && trimmed.ends_with(')')
1519 || (trimmed.starts_with('(') && trimmed.ends_with(')'));
1520 bracketed && !parse_option_list(trimmed).is_empty()
1521}
1522
1523pub fn resolve_smooth_type_name(
1524 kind: SmoothKind,
1525 n_cols: usize,
1526 options: &BTreeMap<String, String>,
1527) -> String {
1528 let selector = options.get("type").or_else(|| options.get("bs"));
1529 if let Some(raw) = selector
1534 && bs_selector_is_vector(raw)
1535 && matches!(kind, SmoothKind::Te | SmoothKind::Ti | SmoothKind::T2)
1536 {
1537 return "tensor".to_string();
1538 }
1539 selector
1540 .map(|s| canonicalize_smooth_type(&s.to_ascii_lowercase()).to_string())
1541 .unwrap_or_else(|| match kind {
1542 SmoothKind::Te | SmoothKind::Ti | SmoothKind::T2 => "tensor".to_string(),
1543 SmoothKind::S if n_cols == 1 => "bspline".to_string(),
1544 SmoothKind::S if smooth_options_declare_periodic(options) => "tensor".to_string(),
1548 SmoothKind::S => "tps".to_string(),
1549 })
1550}
1551
/// Whether a canonical smooth type (`tps`, `matern` or `duchon`) takes the spatial
/// center-count heuristic.
///
/// The input is expected to be already canonicalised; matching is exact and
/// case-sensitive.
pub fn smooth_type_uses_spatial_center_heuristic(canonical_type: &str) -> bool {
    const SPATIAL_CENTER_TYPES: [&str; 3] = ["tps", "matern", "duchon"];
    SPATIAL_CENTER_TYPES
        .iter()
        .any(|&candidate| candidate == canonical_type)
}
1563
1564pub fn build_smooth_basis(
1565 kind: SmoothKind,
1566 vars: &[String],
1567 cols: &[usize],
1568 options: &BTreeMap<String, String>,
1569 ds: &Dataset,
1570 inference_notes: &mut Vec<String>,
1571 policy: &ResourcePolicy,
1572 smooth_coordinate_count: usize,
1573) -> Result<SmoothBasisSpec, String> {
1574 let coord_cols: Vec<(&String, usize)> = vars
1590 .iter()
1591 .zip(cols.iter().copied())
1592 .filter(|(_, col)| !matches!(ds.column_kinds.get(*col), Some(ColumnKindTag::Categorical)))
1593 .collect();
1594 if !coord_cols.is_empty() {
1595 let views: Vec<ArrayView1<'_, f64>> = coord_cols
1596 .iter()
1597 .map(|(_, col)| ds.values.column(*col))
1598 .collect();
1599 let n_rows = views[0].len();
1600 let mut distinct_points = std::collections::HashSet::<Vec<u64>>::new();
1601 for r in 0..n_rows {
1602 let key: Vec<u64> = views
1603 .iter()
1604 .map(|v| {
1605 let x = v[r];
1606 let norm = if x == 0.0 { 0.0 } else { x };
1607 norm.to_bits()
1608 })
1609 .collect();
1610 distinct_points.insert(key);
1611 if distinct_points.len() > 1 {
1612 break;
1613 }
1614 }
1615 if distinct_points.len() <= 1 {
1616 return Err(TermBuilderError::degenerate_data(if coord_cols.len() == 1 {
1617 let var = coord_cols[0].0;
1618 format!(
1619 "smooth term over '{var}' has only one unique value in the training data \
1620 — a smooth on a constant column is degenerate and would only fit the response mean. \
1621 Remove `{var}` from the smooth, drop the term, or check the data."
1622 )
1623 } else {
1624 let names = coord_cols
1625 .iter()
1626 .map(|(v, _)| v.as_str())
1627 .collect::<Vec<_>>()
1628 .join(", ");
1629 format!(
1630 "smooth term over ({names}) has only one unique joint coordinate in the training \
1631 data — every coordinate is constant, so the smooth is degenerate and would only \
1632 fit the response mean. Drop the term or check the data."
1633 )
1634 })
1635 .to_string());
1636 }
1637 }
1638 if let Some(by_name) = options.get("by").cloned() {
1639 let by_col = options
1640 .get("__by_col")
1641 .and_then(|raw| raw.parse::<usize>().ok())
1642 .or_else(|| vars.iter().position(|v| v == &by_name).map(|idx| cols[idx]))
1643 .ok_or_else(|| format!("unknown by= column '{by_name}'"))?;
1644 let mut inner_options = options.clone();
1645 inner_options.remove("by");
1646 inner_options.remove("__by_col");
1647 inner_options.remove("id");
1648 let inner = build_smooth_basis(
1649 kind,
1650 vars,
1651 cols,
1652 &inner_options,
1653 ds,
1654 inference_notes,
1655 policy,
1656 smooth_coordinate_count,
1657 )?;
1658 let by_kind = match ds.column_kinds.get(by_col).copied() {
1659 Some(ColumnKindTag::Categorical) => ByVarKind::Factor {
1660 feature_col: by_col,
1661 ordered: option_bool(options, "ordered").unwrap_or(false),
1662 frozen_levels: None,
1663 },
1664 Some(ColumnKindTag::Continuous | ColumnKindTag::Binary) => ByVarKind::Numeric {
1665 feature_col: by_col,
1666 },
1667 None => {
1668 return Err(format!(
1669 "internal column-kind lookup failed for by='{by_name}'"
1670 ));
1671 }
1672 };
1673 return Ok(SmoothBasisSpec::BySmooth {
1674 smooth: Box::new(inner),
1675 by_kind,
1676 });
1677 }
1678
1679 let smooth_double_penalty = option_bool(options, "double_penalty").unwrap_or(true);
1680 let type_opt = resolve_smooth_type_name(kind, cols.len(), options);
1681
1682 if matches!(type_opt.as_str(), "fs" | "sz" | "re") {
1683 validate_known_options(
1684 type_opt.as_str(),
1685 options,
1686 &[
1687 "type",
1688 "bs",
1689 "k",
1690 "basis_dim",
1691 "basis-dim",
1692 "basisdim",
1693 "knots",
1694 "knot_placement",
1695 "knot-placement",
1696 "knotplacement",
1697 "degree",
1698 "penalty_order",
1699 "m",
1700 "double_penalty",
1701 "ordered",
1702 ],
1703 )?;
1704 if cols.len() != 2 {
1705 return Err(format!(
1706 "{} factor-smooth currently expects exactly two variables (one numeric, one categorical)",
1707 type_opt
1708 ));
1709 }
1710 let kinds = cols
1711 .iter()
1712 .map(|&c| ds.column_kinds.get(c).copied())
1713 .collect::<Vec<_>>();
1714 let (cont_idx, group_idx) = if type_opt == "re" {
1715 match (kinds[0], kinds[1]) {
1717 (Some(ColumnKindTag::Categorical), _) => (1usize, 0usize),
1718 (_, Some(ColumnKindTag::Categorical)) => (0usize, 1usize),
1719 _ => (1usize, 0usize),
1720 }
1721 } else {
1722 match (kinds[0], kinds[1]) {
1723 (_, Some(ColumnKindTag::Categorical)) => (0usize, 1usize),
1724 (Some(ColumnKindTag::Categorical), _) => (1usize, 0usize),
1725 _ => {
1726 return Err(format!(
1727 "{} factor-smooth requires one categorical factor variable",
1728 type_opt
1729 ));
1730 }
1731 }
1732 };
1733 let c = cols[cont_idx];
1734 let (minv, maxv) = col_minmax(ds.values.column(c))?;
1735 let degree = if type_opt == "re" {
1736 1
1737 } else {
1738 option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE)
1739 };
1740 let pooled_internal = heuristic_knots_for_column(ds.values.column(c));
1760 let default_internal = if type_opt == "re" {
1761 0
1774 } else {
1775 let min_group_resolution =
1776 min_per_group_unique_count(ds.values.column(c), ds.values.column(cols[group_idx]));
1777 let basis_cap = min_group_resolution.saturating_sub(2).max(degree + 2);
1785 let internal_cap = basis_cap.saturating_sub(degree + 1);
1786 let capped = pooled_internal.min(internal_cap.max(1));
1787 let fs_default_internal = FACTOR_SMOOTH_DEFAULT_BASIS_DIM
1803 .saturating_sub(degree + 1)
1804 .max(1);
1805 capped.min(fs_default_internal)
1806 };
1807 let (n_knots, _, effective_degree) =
1808 parse_ps_internal_knots(options, degree, default_internal)?;
1809 let penalty_order = option_usize(options, "penalty_order")
1810 .unwrap_or(if effective_degree > 1 { 2 } else { 1 })
1811 .min(effective_degree);
1812 let sz_uses_cr = type_opt.as_str() == "sz" && effective_degree == DEFAULT_BSPLINE_DEGREE;
1826 let marginal_knotspec = if sz_uses_cr {
1835 let k_cr = (n_knots + effective_degree + 1).max(CR_MIN_KNOTS);
1836 match capped_cr_marginal_knotspec(
1837 ds.values.column(c),
1838 k_cr,
1839 &vars.join(","),
1840 inference_notes,
1841 )? {
1842 Some(cr_knotspec) => cr_knotspec,
1843 None => resolve_nonperiodic_bspline_knotspec(
1844 options,
1845 ds.values.column(c),
1846 (minv, maxv),
1847 effective_degree,
1848 n_knots,
1849 )?,
1850 }
1851 } else {
1852 resolve_nonperiodic_bspline_knotspec(
1853 options,
1854 ds.values.column(c),
1855 (minv, maxv),
1856 effective_degree,
1857 n_knots,
1858 )?
1859 };
1860 let marginal = BSplineBasisSpec {
1861 degree: effective_degree,
1862 penalty_order,
1863 knotspec: marginal_knotspec,
1864 double_penalty: option_bool(options, "double_penalty")
1875 .unwrap_or(type_opt.as_str() != "sz"),
1876 identifiability: BSplineIdentifiability::None,
1877 boundary_conditions: Default::default(),
1878 boundary: OneDimensionalBoundary::Open,
1879 };
1880 let flavour = match type_opt.as_str() {
1881 "fs" => FactorSmoothFlavour::Fs {
1882 m_null_penalty_orders: vec![
1883 option_usize(options, "m").unwrap_or(DEFAULT_PENALTY_ORDER),
1884 ],
1885 },
1886 "sz" => FactorSmoothFlavour::Sz,
1887 "re" => FactorSmoothFlavour::Re,
1888 other => {
1890 return Err(format!(
1891 "internal: factor-smooth flavour dispatch reached unexpected type `{}`",
1892 other
1893 ));
1894 }
1895 };
1896 return Ok(SmoothBasisSpec::FactorSmooth {
1897 spec: FactorSmoothSpec {
1898 continuous_cols: vec![c],
1899 group_col: cols[group_idx],
1900 marginal,
1901 flavour,
1902 group_frozen_levels: None,
1903 frozen_global_orthogonality: None,
1904 },
1905 });
1906 }
1907
1908 match type_opt.as_str() {
1909 "cyclic" | "cc" | "cp" | "cyclic-ps" => {
1910 validate_known_options(
1911 "cyclic",
1912 options,
1913 &[
1914 "type",
1915 "bs",
1916 "by",
1917 "k",
1918 "basis_dim",
1919 "basis-dim",
1920 "basisdim",
1921 "degree",
1922 "penalty_order",
1923 "period",
1924 "periods",
1925 "period_start",
1926 "period_end",
1927 "start",
1928 "end",
1929 "origin",
1930 "origins",
1931 "period_origin",
1932 "period-origin",
1933 "domain_origin",
1934 "double_penalty",
1935 "id",
1936 "__by_col",
1937 "identifiability",
1938 ],
1939 )?;
1940 if cols.len() != 1 {
1941 return Err(format!(
1942 "periodic smooth expects one variable, got {}",
1943 cols.len()
1944 ));
1945 }
1946 let c = cols[0];
1947 let (minv, maxv) = col_minmax(ds.values.column(c))?;
1948 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
1949 let mut default_internal = heuristic_knots_for_column(ds.values.column(c));
1950 if ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
1951 default_internal = default_internal.min(1);
1952 }
1953 let cyclic_default_basis_cap = CYCLIC_DEFAULT_BASIS_DIM.max(degree + 1);
1967 let default_basis = (default_internal + degree + 1).min(cyclic_default_basis_cap);
1968 let num_basis = option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
1969 .unwrap_or(default_basis);
1970 if num_basis < degree + 1 {
1971 return Err(format!(
1972 "periodic smooth: k={} too small for degree {}; expected k >= {}",
1973 num_basis,
1974 degree,
1975 degree + 1
1976 ));
1977 }
1978 let periodic_axes = [true];
1989 let periods = parse_periods(options, &periodic_axes)?;
1990 let origins = parse_period_origins(options, &periodic_axes)?;
1991 let (domain_start, period) = if let Some(p) = periods[0] {
1992 (origins[0].unwrap_or(minv), p)
1993 } else {
1994 parse_periodic_domain_1d(options, minv, maxv)?
1995 };
1996 Ok(SmoothBasisSpec::BSpline1D {
1997 feature_col: c,
1998 spec: BSplineBasisSpec {
1999 degree,
2000 penalty_order: option_usize(options, "penalty_order")
2001 .unwrap_or(DEFAULT_PENALTY_ORDER),
2002 knotspec: BSplineKnotSpec::PeriodicUniform {
2003 data_range: (domain_start, domain_start + period),
2004 num_basis,
2005 },
2006 double_penalty: smooth_double_penalty,
2007 identifiability: BSplineIdentifiability::default(),
2008 boundary_conditions: Default::default(),
2009 boundary: OneDimensionalBoundary::Cyclic {
2010 start: domain_start,
2011 end: domain_start + period,
2012 },
2013 },
2014 })
2015 }
2016 "bspline" | "ps" | "p-spline" | "cr" | "cs" => {
2017 let validation_name = match type_opt.as_str() {
2031 "cr" => "cr",
2032 "cs" => "cs",
2033 _ => "bspline",
2034 };
2035 validate_known_options(
2036 validation_name,
2037 options,
2038 &[
2039 "type",
2040 "bs",
2041 "by",
2042 "k",
2043 "basis_dim",
2044 "basis-dim",
2045 "basisdim",
2046 "knots",
2047 "knot_placement",
2048 "knot-placement",
2049 "knotplacement",
2050 "degree",
2051 "penalty_order",
2052 "boundary",
2053 "bc",
2054 "boundary_conditions",
2055 "bc_left",
2056 "bc_right",
2057 "left_bc",
2058 "right_bc",
2059 "start_bc",
2060 "end_bc",
2061 "side",
2062 "anchor",
2063 "anchor_value",
2064 "value",
2065 "anchor_left",
2066 "left_anchor",
2067 "anchor_right",
2068 "right_anchor",
2069 "periodic",
2070 "period",
2071 "periods",
2072 "period_start",
2073 "period_end",
2074 "origin",
2075 "double_penalty",
2076 "by",
2077 "id",
2078 "__by_col",
2079 "identifiability",
2080 "by",
2081 ],
2082 )?;
2083 if cols.len() != 1 {
2084 return Err(TermBuilderError::incompatible_config(format!(
2085 "bspline smooth expects one variable, got {}",
2086 cols.len()
2087 ))
2088 .to_string());
2089 }
2090 let c = cols[0];
2091 let (minv, maxv) = col_minmax(ds.values.column(c))?;
2092 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
2093 let default_internal = heuristic_knots_for_column(ds.values.column(c));
2094 let (mut n_knots, inferred, effective_degree) =
2095 parse_ps_internal_knots(options, degree, default_internal)?;
2096 let periodic_axes = parse_periodic_axes(options, 1).map_err(|e| e.to_string())?;
2097 if periodic_axes[0] && effective_degree != degree {
2102 return Err(TermBuilderError::invalid_option(format!(
2103 "periodic smooth: k={} too small for degree {}; expected k >= {}",
2104 effective_degree + 1,
2105 degree,
2106 degree + 1
2107 ))
2108 .to_string());
2109 }
2110 if inferred && ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2111 n_knots = n_knots.min(1);
2112 }
2113 if inferred {
2114 let unique = unique_count_column(ds.values.column(c));
2115 let ceiling = ((unique as f64).cbrt() as usize).max(20);
2116 inference_notes.push(format!(
2117 "Automatically set {} internal knots for smooth '{}' from {} unique values (rule: clamp(unique/4, 4..max(20, cbrt(unique))) = clamp(unique/4, 4..{})). Override with knots=... or k=....",
2118 n_knots,
2119 vars.join(","),
2120 unique,
2121 ceiling,
2122 ));
2123 }
2124 let boundary_conditions =
2125 if periodic_axes[0] && bspline_boundary_declares_periodic_axis(options) {
2126 BSplineBoundaryConditions::default()
2127 } else {
2128 parse_bspline_boundary_conditions(options).map_err(|e| e.to_string())?
2129 };
2130 let periods = parse_periods(options, &periodic_axes).map_err(|e| e.to_string())?;
2131 let origins =
2132 parse_period_origins(options, &periodic_axes).map_err(|e| e.to_string())?;
2133 let (knotspec, boundary) = if periodic_axes[0] {
2134 if !boundary_conditions.is_free() {
2135 return Err(TermBuilderError::incompatible_config(
2136 "periodic B-splines cannot also declare endpoint boundary conditions",
2137 )
2138 .to_string());
2139 }
2140 {
2141 let (domain_start, p_value) = if periods[0].is_some() {
2142 (origins[0].unwrap_or(minv), periods[0].unwrap())
2143 } else {
2144 parse_periodic_domain_1d(options, minv, maxv).map_err(|e| e.to_string())?
2145 };
2146 let domain_end = domain_start + p_value;
2147 (
2148 BSplineKnotSpec::PeriodicUniform {
2149 data_range: (domain_start, domain_end),
2150 num_basis: n_knots + effective_degree + 1,
2151 },
2152 OneDimensionalBoundary::Cyclic {
2153 start: domain_start,
2154 end: domain_end,
2155 },
2156 )
2157 }
2158 } else if type_opt == "cr" || type_opt == "cs" {
2159 let k_cr = (n_knots + effective_degree + 1).max(CR_MIN_KNOTS);
2176 let knotspec = match capped_cr_marginal_knotspec(
2177 ds.values.column(c),
2178 k_cr,
2179 &vars.join(","),
2180 inference_notes,
2181 )? {
2182 Some(cr_knotspec) => cr_knotspec,
2183 None => resolve_nonperiodic_bspline_knotspec(
2184 options,
2185 ds.values.column(c),
2186 (minv, maxv),
2187 effective_degree,
2188 n_knots,
2189 )?,
2190 };
2191 (knotspec, parse_cyclic_boundary(options, minv, maxv)?)
2192 } else {
2193 (
2194 resolve_nonperiodic_bspline_knotspec(
2195 options,
2196 ds.values.column(c),
2197 (minv, maxv),
2198 effective_degree,
2199 n_knots,
2200 )?,
2201 parse_cyclic_boundary(options, minv, maxv)?,
2202 )
2203 };
2204 let double_penalty = if type_opt == "cr" {
2208 option_bool(options, "double_penalty").unwrap_or(false)
2209 } else {
2210 smooth_double_penalty
2211 };
2212 let penalty_order = option_usize(options, "penalty_order")
2217 .unwrap_or(DEFAULT_PENALTY_ORDER)
2218 .min(effective_degree);
2219 Ok(SmoothBasisSpec::BSpline1D {
2220 feature_col: c,
2221 spec: BSplineBasisSpec {
2222 degree: effective_degree,
2223 penalty_order,
2224 knotspec,
2225 double_penalty,
2226 identifiability: BSplineIdentifiability::default(),
2227 boundary,
2228 boundary_conditions,
2229 },
2230 })
2231 }
2232 "tps" | "thinplate" | "thin-plate" => {
2233 validate_known_options(
2234 "thinplate",
2235 options,
2236 &[
2237 SECONDARY_CENTER_CAP_OPTION,
2238 "type",
2239 "bs",
2240 "by",
2241 "length_scale",
2242 "centers",
2243 "k",
2244 "basis_dim",
2245 "basis-dim",
2246 "basisdim",
2247 "knots",
2248 "include_intercept",
2249 "double_penalty",
2250 "by",
2251 "id",
2252 "__by_col",
2253 "identifiability",
2254 "by",
2255 "scale_dims",
2256 ],
2257 )?;
2258 let plan = plan_spatial_basis(
2259 ds.values.nrows(),
2260 cols.len(),
2261 CenterCountRequest::Default,
2262 DuchonNullspaceOrder::Linear,
2263 option_bool(options, "scale_dims").unwrap_or(false),
2264 policy,
2265 )
2266 .map_err(|e| e.to_string())?;
2267 let default_centers = plan.centers;
2277 let centers = parse_countwith_basis_alias(
2278 options,
2279 "centers",
2280 cap_default_spatial_centers(options, default_centers),
2281 )?;
2282 let center_strategy = if has_explicit_countwith_basis_alias(options, "centers") {
2283 spatial_center_strategy_for_dimension(centers, cols.len())
2284 } else {
2285 auto_spatial_center_strategy(centers, cols.len())
2286 };
2287 Ok(SmoothBasisSpec::ThinPlate {
2288 feature_cols: cols.to_vec(),
2289 spec: ThinPlateBasisSpec {
2290 center_strategy,
2291 periodic: parse_periodic_axes_option(options, cols.len())?,
2292 length_scale: option_f64(options, "length_scale").unwrap_or(0.0),
2300 double_penalty: smooth_double_penalty,
2301 identifiability: parse_spatial_identifiability(options)
2302 .map_err(|e| e.to_string())?,
2303 radial_reparam: None,
2304 },
2305 input_scales: None,
2306 })
2307 }
2308 "sphere" | "s2" | "sos" => {
2309 validate_known_options(
2310 "sphere",
2311 options,
2312 &[
2313 "type",
2314 "bs",
2315 "by",
2316 "centers",
2317 "k",
2318 "basis_dim",
2319 "basis-dim",
2320 "basisdim",
2321 "knots",
2322 "penalty_order",
2323 "m",
2324 "double_penalty",
2325 "id",
2326 "__by_col",
2327 "kernel",
2328 "method",
2329 "radians",
2330 "units",
2331 "degree",
2332 "l",
2333 "max_degree",
2334 "max-degree",
2335 ],
2336 )?;
2337 if cols.len() != 2 {
2338 return Err(format!(
2339 "sphere smooth expects exactly two variables (lat, lon), got {}",
2340 cols.len()
2341 ));
2342 }
2343 let radians = option_bool(options, "radians").unwrap_or_else(|| {
2344 options
2345 .get("units")
2346 .map(|u| u.eq_ignore_ascii_case("radian") || u.eq_ignore_ascii_case("radians"))
2347 .unwrap_or(false)
2348 });
2349 let degree_requested = options.contains_key("degree")
2355 || options.contains_key("l")
2356 || options.contains_key("max_degree")
2357 || options.contains_key("max-degree");
2358 let kernel = options
2359 .get("kernel")
2360 .or_else(|| options.get("method"))
2361 .map(|raw| strip_quotes(raw).trim().to_ascii_lowercase())
2362 .unwrap_or_else(|| {
2363 if degree_requested {
2364 "harmonic".to_string()
2365 } else {
2366 "sobolev".to_string()
2367 }
2368 });
2369 let (method, wahba_kernel) = match kernel.as_str() {
2370 "sobolev" | "wahba" | "wahba_sobolev" | "wahba-sobolev" => {
2371 (SphereMethod::Wahba, SphereWahbaKernel::Sobolev)
2372 }
2373 "pseudo" | "mgcv" | "sos" | "wahba_pseudo" | "wahba-pseudo" => {
2374 (SphereMethod::Wahba, SphereWahbaKernel::Pseudo)
2375 }
2376 "harmonic" | "spherical_harmonic" | "spherical-harmonic" => {
2377 (SphereMethod::Harmonic, SphereWahbaKernel::Sobolev)
2378 }
2379 other => {
2380 return Err(format!(
2381 "unsupported sphere kernel '{other}'; expected sobolev, pseudo, or harmonic"
2382 ));
2383 }
2384 };
2385 let max_degree = if matches!(method, SphereMethod::Harmonic) {
2386 let degree =
2387 option_usize_any(options, &["degree", "l", "max_degree", "max-degree"])
2388 .or_else(|| option_usize(options, "centers"))
2389 .or_else(|| {
2390 option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
2391 .and_then(|k| (1..=128).find(|&l| l * (l + 2) >= k))
2392 })
2393 .unwrap_or_else(|| default_spherical_harmonic_degree(ds.values.nrows()));
2394 if degree == 0 {
2395 return Err("sphere smooth requires degree/max_degree >= 1".to_string());
2396 }
2397 if degree > 32 {
2398 return Err(format!(
2399 "sphere smooth max_degree={} is too large for the dense harmonic engine (limit 32)",
2400 degree
2401 ));
2402 }
2403 Some(degree)
2404 } else {
2405 None
2406 };
2407 let penalty_order = option_usize(options, "penalty_order")
2408 .or_else(|| option_usize(options, "m"))
2409 .unwrap_or(DEFAULT_PENALTY_ORDER);
2410 let center_strategy = if matches!(method, SphereMethod::Wahba) {
2411 let mut centers = parse_countwith_basis_alias(
2412 options,
2413 "centers",
2414 default_num_centers(ds.values.nrows(), cols.len()),
2415 )?;
2416 if penalty_order >= 4 {
2417 centers = centers.max(30);
2418 }
2419 CenterStrategy::FarthestPoint {
2420 num_centers: centers,
2421 }
2422 } else {
2423 CenterStrategy::FarthestPoint { num_centers: 0 }
2424 };
2425 Ok(SmoothBasisSpec::Sphere {
2426 feature_cols: cols.to_vec(),
2427 spec: SphericalSplineBasisSpec {
2428 center_strategy,
2429 penalty_order,
2430 double_penalty: smooth_double_penalty,
2431 radians,
2432 method,
2433 max_degree,
2434 wahba_kernel,
2435 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2436 },
2437 })
2438 }
2439 "curvature" => {
2440 validate_known_options(
2446 "curvature",
2447 options,
2448 &[
2449 "type",
2450 "bs",
2451 "by",
2452 "centers",
2453 "k",
2454 "basis_dim",
2455 "basis-dim",
2456 "basisdim",
2457 "knots",
2458 "kappa",
2459 "length_scale",
2460 "double_penalty",
2461 "id",
2462 "__by_col",
2463 ],
2464 )?;
2465 let kappa = option_f64(options, "kappa").unwrap_or(0.0);
2466 if !kappa.is_finite() {
2467 return Err("curvature smooth requires a finite kappa".to_string());
2468 }
2469 let length_scale = option_f64(options, "length_scale").unwrap_or(0.0);
2470 if !length_scale.is_finite() || length_scale < 0.0 {
2471 return Err(format!(
2472 "curvature smooth length_scale must be positive (or omitted for auto); got {length_scale}"
2473 ));
2474 }
2475 let centers = parse_countwith_basis_alias(
2476 options,
2477 "centers",
2478 default_num_centers(ds.values.nrows(), cols.len()),
2479 )?;
2480 if centers < 2 {
2481 return Err("curvature smooth requires at least 2 centers".to_string());
2482 }
2483 Ok(SmoothBasisSpec::ConstantCurvature {
2484 feature_cols: cols.to_vec(),
2485 spec: ConstantCurvatureBasisSpec {
2486 center_strategy: CenterStrategy::FarthestPoint {
2487 num_centers: centers,
2488 },
2489 kappa,
2490 length_scale,
2493 double_penalty: option_bool(options, "double_penalty").unwrap_or(false),
2500 identifiability: ConstantCurvatureIdentifiability::CenterSumToZero,
2501 },
2502 })
2503 }
2504 "measurejet" => {
2505 validate_known_options(
2511 "measurejet",
2512 options,
2513 &[
2514 "type",
2515 "bs",
2516 "by",
2517 "centers",
2518 "k",
2519 "basis_dim",
2520 "basis-dim",
2521 "basisdim",
2522 "knots",
2523 "s",
2524 "alpha",
2525 "tau",
2526 "scales",
2527 "length_scale",
2528 "double_penalty",
2529 "multiscale",
2530 "learn_length_scale",
2531 "id",
2532 "__by_col",
2533 ],
2534 )?;
2535 let order_s = option_f64(options, "s").unwrap_or(0.0);
2536 if !(order_s.is_finite() && (order_s == 0.0 || (order_s > 0.0 && order_s < 2.0))) {
2539 return Err(format!(
2540 "measurejet smooth s must lie in (0, 2) (or be omitted for auto); got {order_s}"
2541 ));
2542 }
2543 let alpha =
2551 option_f64(options, "alpha").unwrap_or(MeasureJetBasisSpec::default().alpha);
2552 if !alpha.is_finite() {
2553 return Err("measurejet smooth requires a finite alpha".to_string());
2554 }
2555 let tau0 = option_f64(options, "tau").unwrap_or(1e-3);
2556 if !(tau0.is_finite() && tau0 >= 0.0) {
2557 return Err(format!(
2558 "measurejet smooth tau must be finite and nonnegative; got {tau0}"
2559 ));
2560 }
2561 let num_scales = option_usize(options, "scales").unwrap_or(0);
2562 let length_scale = option_f64(options, "length_scale").unwrap_or(0.0);
2563 if !length_scale.is_finite() || length_scale < 0.0 {
2564 return Err(format!(
2565 "measurejet smooth length_scale must be positive (or omitted for auto); got {length_scale}"
2566 ));
2567 }
2568 let centers = parse_countwith_basis_alias(
2569 options,
2570 "centers",
2571 default_num_centers(ds.values.nrows(), cols.len()),
2572 )?;
2573 if centers < 3 {
2574 return Err("measurejet smooth requires at least 3 centers".to_string());
2575 }
2576 let multiscale = option_bool(options, "multiscale").unwrap_or(false);
2580 let learn_length_scale = option_bool(options, "learn_length_scale").unwrap_or(false);
2585 Ok(SmoothBasisSpec::MeasureJet {
2586 feature_cols: cols.to_vec(),
2587 spec: MeasureJetBasisSpec {
2588 center_strategy: CenterStrategy::FarthestPoint {
2589 num_centers: centers,
2590 },
2591 order_s,
2592 alpha,
2593 tau0,
2594 num_scales,
2595 length_scale,
2598 double_penalty: smooth_double_penalty,
2599 learn_length_scale,
2600 multiscale,
2601 identifiability: MeasureJetIdentifiability::CenterSumToZero,
2602 frozen_quadrature: None,
2603 },
2604 input_scales: None,
2605 })
2606 }
2607 "matern" => {
2608 validate_known_options(
2613 "matern",
2614 options,
2615 &[
2616 SECONDARY_CENTER_CAP_OPTION,
2617 "type",
2618 "bs",
2619 "by",
2620 "nu",
2621 "length_scale",
2622 "centers",
2623 "k",
2624 "basis_dim",
2625 "basis-dim",
2626 "basisdim",
2627 "knots",
2628 "include_intercept",
2629 "double_penalty",
2630 "by",
2631 "id",
2632 "__by_col",
2633 "identifiability",
2634 "by",
2635 "scale_dims",
2636 ],
2637 )?;
2638 let plan = plan_spatial_basis(
2639 ds.values.nrows(),
2640 cols.len(),
2641 CenterCountRequest::Default,
2642 DuchonNullspaceOrder::Zero,
2643 option_bool(options, "scale_dims").unwrap_or(false),
2644 policy,
2645 )
2646 .map_err(|e| e.to_string())?;
2647 let centers = parse_countwith_basis_alias(
2648 options,
2649 "centers",
2650 cap_default_spatial_centers(
2651 options,
2652 default_matern_center_count(ds.values.nrows(), cols.len(), plan.centers),
2653 ),
2654 )?;
2655 let center_strategy = if has_explicit_countwith_basis_alias(options, "centers") {
2656 spatial_center_strategy_for_dimension(centers, cols.len())
2657 } else {
2658 auto_spatial_center_strategy(centers, cols.len())
2659 };
2660 let nu = parse_matern_nu(options.get("nu").map(String::as_str).unwrap_or("5/2"))?;
2661 if matches!(nu, MaternNu::Half) && cols.len() >= 2 {
2667 return Err(TermBuilderError::unsupported_feature(format!(
2668 "matern() with nu=1/2 is not supported for d>=2 (got {} covariates): \
2669 the exponential kernel's Laplacian is singular at center collisions, \
2670 which makes the operator-collocation penalty non-invertible. \
2671 Choose nu>=3/2 (e.g. nu=3/2 or the default nu=5/2) for multi-dimensional smooths.",
2672 cols.len()
2673 ))
2674 .to_string());
2675 }
2676 let aniso_log_scales = if option_bool(options, "scale_dims").unwrap_or(false) {
2677 Some(vec![0.0; cols.len()])
2678 } else {
2679 None
2680 };
2681 Ok(SmoothBasisSpec::Matern {
2682 feature_cols: cols.to_vec(),
2683 spec: MaternBasisSpec {
2684 center_strategy,
2685 periodic: parse_periodic_axes_option(options, cols.len())?,
2686 length_scale: option_f64(options, "length_scale")
2687 .unwrap_or_else(|| default_matern_length_scale(ds, cols)),
2688 nu,
2689 include_intercept: option_bool(options, "include_intercept").unwrap_or(false),
2690 double_penalty: smooth_double_penalty,
2691 identifiability: parse_matern_identifiability(options)
2692 .map_err(|e| e.to_string())?,
2693 aniso_log_scales,
2694 nullspace_shrinkage_survived: None,
2699 },
2700 input_scales: None,
2701 })
2702 }
2703 "duchon" => {
2704 validate_known_options(
2705 "duchon",
2706 options,
2707 &[
2708 SECONDARY_CENTER_CAP_OPTION,
2709 "type",
2710 "bs",
2711 "by",
2712 "length_scale",
2713 "centers",
2714 "k",
2715 "basis_dim",
2716 "basis-dim",
2717 "basisdim",
2718 "knots",
2719 "power",
2720 "p",
2721 "nullspace_order",
2722 "order",
2723 "identifiability",
2724 "by",
2725 "periodic",
2726 "cyclic",
2727 "period",
2728 "period_start",
2729 "period_end",
2730 "scale_dims",
2731 "double_penalty",
2732 "by",
2733 "id",
2734 "__by_col",
2735 ],
2736 )?;
2737 if options.contains_key("double_penalty") {
2738 return Err(TermBuilderError::incompatible_config(format!(
2739 "Duchon smooth '{}' does not support double_penalty; the Duchon smoother already ships its native reproducing-norm penalty plus a null-space shrinkage ridge.",
2740 vars.join(", ")
2741 ))
2742 .to_string());
2743 }
2744 let requested_nullspace_order = parse_duchon_order(options)?;
2745 let length_scale = option_f64_strict(options, "length_scale")?;
2746 let (nullspace_order, power) = match parse_duchon_power_policy(options)? {
2759 DuchonPowerPolicy::Explicit(req_power) => {
2760 if length_scale.is_some() && req_power.fract() != 0.0 {
2761 return Err(TermBuilderError::incompatible_config(format!(
2762 "hybrid Duchon-Matern smooth '{}' (length_scale=...) requires an integer power, got power={}; \
2763 drop length_scale to use the scale-free structural kernel with a fractional power.",
2764 vars.join(", "),
2765 req_power,
2766 ))
2767 .to_string());
2768 }
2769 (requested_nullspace_order, req_power)
2770 }
2771 DuchonPowerPolicy::CubicStructuralDefault => {
2772 match length_scale {
2779 None => crate::basis::duchon_cubic_default(cols.len()),
2780 Some(_) => {
2781 let max_op = crate::basis::duchon_max_active_operator_derivative_order(
2802 &DuchonOperatorPenaltySpec::default(),
2803 );
2804 let (ns, s) = crate::basis::resolve_duchon_orders(
2805 cols.len(),
2806 requested_nullspace_order,
2807 max_op,
2808 length_scale,
2809 );
2810 (ns, s as f64)
2811 }
2812 }
2813 }
2814 };
2815 let plan = plan_spatial_basis(
2816 ds.values.nrows(),
2817 cols.len(),
2818 CenterCountRequest::Default,
2819 nullspace_order,
2820 option_bool(options, "scale_dims").unwrap_or(false),
2821 policy,
2822 )
2823 .map_err(|e| e.to_string())?;
2824 let centers_explicit = has_explicit_countwith_basis_alias(options, "centers");
2825 let requested_centers = parse_countwith_basis_alias(
2826 options,
2827 "centers",
2828 cap_default_spatial_centers(options, plan.centers),
2829 )?;
2830 let polynomial_cols = match nullspace_order {
2831 DuchonNullspaceOrder::Zero => 1,
2832 DuchonNullspaceOrder::Linear => cols.len() + 1,
2833 DuchonNullspaceOrder::Degree(degree) => {
2834 crate::basis::duchon_nullspace_dimension(cols.len(), degree)
2835 }
2836 };
2837 if requested_centers <= polynomial_cols {
2838 return Err(TermBuilderError::incompatible_config(format!(
2839 "Duchon smooth '{}' requested basis dimension {} but order={:?} in {}D needs {} polynomial null-space columns; choose centers/k > {}",
2840 vars.join(", "),
2841 requested_centers,
2842 nullspace_order,
2843 cols.len(),
2844 polynomial_cols,
2845 polynomial_cols,
2846 ))
2847 .to_string());
2848 }
2849 let mut centers = requested_centers;
2850 if !centers_explicit && ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2851 centers = centers.max(polynomial_cols + 4);
2852 }
2853 let center_strategy = if centers_explicit {
2854 spatial_center_strategy_for_dimension(centers, cols.len())
2855 } else {
2856 auto_spatial_center_strategy(centers, cols.len())
2857 };
2858 let aniso_log_scales = if option_bool(options, "scale_dims").unwrap_or(false) {
2859 Some(vec![0.0; cols.len()])
2860 } else {
2861 None
2862 };
2863 let operator_penalties = DuchonOperatorPenaltySpec::default();
2866 Ok(SmoothBasisSpec::Duchon {
2867 feature_cols: cols.to_vec(),
2868 spec: DuchonBasisSpec {
2869 center_strategy,
2870 periodic: parse_periodic_axes_option(options, cols.len())?,
2871 length_scale,
2872 power,
2873 nullspace_order,
2874 identifiability: parse_spatial_identifiability(options)
2875 .map_err(|e| e.to_string())?,
2876 aniso_log_scales,
2877 operator_penalties,
2878 boundary: if cols.len() == 1 {
2879 let c = cols[0];
2880 let (minv, maxv) = col_minmax(ds.values.column(c))?;
2881 parse_cyclic_boundary(options, minv, maxv)?
2882 } else {
2883 OneDimensionalBoundary::Open
2884 },
2885 radial_reparam: None,
2886 },
2887 input_scales: None,
2888 })
2889 }
2890 "tensor" | "te" | "ti" | "t2" => {
2891 validate_known_options(
2892 "tensor",
2893 options,
2894 &[
2895 "type",
2896 "bs",
2897 "by",
2898 "k",
2899 "basis_dim",
2900 "basis-dim",
2901 "basisdim",
2902 "knot_placement",
2903 "knot-placement",
2904 "knotplacement",
2905 "degree",
2906 "penalty_order",
2907 "double_penalty",
2908 "periodic",
2909 "cyclic",
2910 "period",
2911 "periods",
2912 "period_start",
2913 "period_end",
2914 "origin",
2915 "origins",
2916 "period_origin",
2917 "period-origin",
2918 "domain_origin",
2919 "boundary",
2920 "bc",
2921 "identifiability",
2922 "id",
2923 "__by_col",
2924 ],
2925 )?;
2926 if cols.len() < 2 {
2927 return Err(TermBuilderError::incompatible_config(format!(
2928 "tensor smooth expects at least 2 variables, got {}",
2929 cols.len()
2930 ))
2931 .to_string());
2932 }
2933 let dim = cols.len();
2934
2935 if let Some(raw) = options.get("bs").or_else(|| options.get("type"))
2958 && bs_selector_is_vector(raw)
2959 {
2960 let per_margin = parse_option_list(raw);
2961 if per_margin.len() != dim {
2962 return Err(TermBuilderError::invalid_option(format!(
2963 "tensor smooth per-margin bs vector has {} entries but the smooth has {} margins",
2964 per_margin.len(),
2965 dim
2966 ))
2967 .to_string());
2968 }
2969 for (axis, margin_bs) in per_margin.iter().enumerate() {
2970 if !tensor_margin_bs_is_supported(margin_bs) {
2971 return Err(TermBuilderError::unsupported_feature(format!(
2972 "tensor smooth margin {axis} basis '{margin_bs}' is not a supported penalized-spline margin; \
2973 tensor margins accept tp/tps/ps/bs/cr/cc"
2974 ))
2975 .to_string());
2976 }
2977 }
2978 }
2979 let periodic_axes = parse_tensor_periodic_axes(options, dim)?;
2980 let periods_opt = parse_periods(options, &periodic_axes)?;
2981 let origins_opt = parse_period_origins(options, &periodic_axes)?;
2982 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
2983 let penalty_order =
2984 option_usize(options, "penalty_order").unwrap_or(if degree > 1 { 2 } else { 1 });
2985 let (mut k_list, k_inferred) = parse_tensor_k_list(options, cols, ds)?;
2986 if ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2987 for k in &mut k_list {
2988 *k = (*k).min(degree + 2);
2989 }
2990 }
2991 if k_inferred {
2992 inference_notes.push(format!(
2993 "Automatically set per-margin basis sizes {:?} for tensor smooth '{}' \
2994 (dimension-aware tensor budget: total ∏k kept near the mgcv-te default \
2995 and within the data support, distributed geometrically across margins and \
2996 capped per margin by each column's resolution). \
2997 Override with k=<int> or k=[k0,k1,...].",
2998 k_list,
2999 vars.join(",")
3000 ));
3001 }
3002 let per_axis_bs: Vec<Option<String>> =
3015 match options.get("bs").or_else(|| options.get("type")) {
3016 Some(raw) if bs_selector_is_vector(raw) => {
3017 let list = parse_option_list(raw);
3018 (0..dim).map(|a| list.get(a).cloned()).collect()
3019 }
3020 Some(raw) => {
3021 let scalar = raw
3022 .trim()
3023 .trim_matches('"')
3024 .trim_matches('\'')
3025 .to_ascii_lowercase();
3026 vec![Some(scalar); dim]
3027 }
3028 None => vec![None; dim],
3029 };
3030 let margin_wants_cr = |bs: &Option<String>| -> bool {
3036 matches!(
3037 bs.as_deref(),
3038 None | Some("cr") | Some("cs") | Some("tp") | Some("tps")
3039 )
3040 };
3041 let mut margins: Vec<BSplineBasisSpec> = Vec::with_capacity(dim);
3042 let mut emitted_periods: Vec<Option<f64>> = Vec::with_capacity(dim);
3043 for axis in 0..dim {
3044 let c = cols[axis];
3045 let (data_min, data_max) = col_minmax(ds.values.column(c))?;
3046 let k_requested = k_list[axis];
3062 let n_distinct_axis = unique_count_column(ds.values.column(c));
3063 let k_axis = k_requested.min(n_distinct_axis).max(2);
3064 if k_axis < k_requested {
3065 log::info!(
3066 "tensor smooth: margin axis {axis} requested k={k_requested}, but the \
3067 covariate has only {n_distinct_axis} distinct value(s); reducing this \
3068 margin to k={k_axis} (mgcv-style data-support cap on the per-axis basis)."
3069 );
3070 }
3071 if k_axis < 2 {
3084 return Err(TermBuilderError::invalid_option(format!(
3085 "tensor smooth: k[{axis}]={k_axis} too small; tensor margins require k >= 2"
3086 ))
3087 .to_string());
3088 }
3089 if periodic_axes[axis] && k_axis < degree + 1 {
3090 return Err(TermBuilderError::invalid_option(format!(
3091 "tensor smooth: periodic axis {axis} requires k >= {} for degree {degree}, got k={k_axis}",
3092 degree + 1
3093 ))
3094 .to_string());
3095 }
3096 let effective_degree = degree.min(k_axis - 1).max(1);
3097 let effective_penalty_order = penalty_order.min(effective_degree);
3098 let (knotspec, boundary, axis_period) = if periodic_axes[axis] {
3099 let period_value = periods_opt[axis].ok_or_else(|| {
3100 format!(
3101 "tensor smooth axis {axis} is periodic but no period was supplied; \
3102 pass period=<value> (scalar) or period=[..., <value>, ...]"
3103 )
3104 })?;
3105 if !period_value.is_finite() || period_value <= 0.0 {
3106 return Err(format!(
3107 "tensor smooth axis {axis}: period must be a positive finite value, got {period_value}"
3108 ));
3109 }
3110 let domain_start = origins_opt[axis].unwrap_or(data_min);
3111 let domain_end = domain_start + period_value;
3112 (
3113 BSplineKnotSpec::PeriodicUniform {
3114 data_range: (domain_start, domain_end),
3115 num_basis: k_axis,
3116 },
3117 OneDimensionalBoundary::Cyclic {
3118 start: domain_start,
3119 end: domain_end,
3120 },
3121 Some(period_value),
3122 )
3123 } else if margin_wants_cr(&per_axis_bs[axis]) && k_axis >= 3 {
3124 let cr_knots =
3134 crate::basis::select_cr_knots(ds.values.column(c), k_axis)
3135 .map_err(|e| e.to_string())?;
3136 (
3137 BSplineKnotSpec::NaturalCubicRegression { knots: cr_knots },
3138 OneDimensionalBoundary::Open,
3139 None,
3140 )
3141 } else {
3142 let num_internal_knots = if effective_degree < degree {
3149 k_axis.saturating_sub(effective_degree + 1)
3150 } else {
3151 k_axis.saturating_sub(degree + 1).max(1)
3152 };
3153 let knotspec = match parse_knot_placement(options)? {
3154 crate::basis::BSplineKnotPlacement::Uniform => BSplineKnotSpec::Generate {
3155 data_range: (data_min, data_max),
3156 num_internal_knots,
3157 },
3158 crate::basis::BSplineKnotPlacement::Quantile => {
3159 crate::basis::auto_knot_vector_1d_quantile(
3160 ds.values.column(c),
3161 num_internal_knots,
3162 effective_degree,
3163 )
3164 .map_err(|e| e.to_string())?;
3165 BSplineKnotSpec::Automatic {
3166 num_internal_knots: Some(num_internal_knots),
3167 placement: crate::basis::BSplineKnotPlacement::Quantile,
3168 }
3169 }
3170 };
3171 (knotspec, OneDimensionalBoundary::Open, None)
3172 };
3173 let is_cr_margin =
3179 matches!(knotspec, BSplineKnotSpec::NaturalCubicRegression { .. });
3180 let margin_double_penalty =
3181 is_cr_margin && matches!(per_axis_bs[axis].as_deref(), Some("cs"));
3182 margins.push(BSplineBasisSpec {
3183 degree: effective_degree,
3184 penalty_order: effective_penalty_order,
3185 knotspec,
3186 double_penalty: margin_double_penalty,
3187 identifiability: BSplineIdentifiability::None,
3188 boundary,
3189 boundary_conditions: BSplineBoundaryConditions::default(),
3190 });
3191 emitted_periods.push(axis_period);
3192 }
3193 let canon_cols: Vec<usize> = {
3214 let mut perm: Vec<usize> = (0..dim).collect();
3215 perm.sort_by_key(|&a| cols[a]);
3216 if perm.iter().enumerate().any(|(i, &a)| i != a) {
3217 margins = perm.iter().map(|&a| margins[a].clone()).collect();
3218 emitted_periods = perm.iter().map(|&a| emitted_periods[a]).collect();
3219 }
3220 perm.iter().map(|&a| cols[a]).collect()
3221 };
3222 let any_periodic = emitted_periods.iter().any(|p| p.is_some());
3223 let periods_vec = if any_periodic {
3224 emitted_periods
3225 } else {
3226 Vec::new()
3227 };
3228 let tensor_double_penalty = option_bool(options, "double_penalty").unwrap_or(false);
3244 Ok(SmoothBasisSpec::TensorBSpline {
3245 feature_cols: canon_cols,
3246 spec: TensorBSplineSpec {
3247 marginalspecs: margins,
3248 periods: periods_vec,
3249 double_penalty: tensor_double_penalty,
3250 identifiability: parse_tensor_identifiability(options, kind)?,
3251 penalty_decomposition: if matches!(kind, SmoothKind::T2)
3261 || type_opt.as_str() == "t2"
3262 {
3263 TensorBSplinePenaltyDecomposition::Separable
3264 } else {
3265 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
3266 },
3267 },
3268 })
3269 }
3270 "pca" => {
3271 validate_known_options(
3272 "pca",
3273 options,
3274 &[
3275 "type",
3276 "bs",
3277 "by",
3278 "k",
3279 "basis_dim",
3280 "basis-dim",
3281 "basisdim",
3282 "lazy_path",
3283 "path",
3284 "pca_basis_path",
3285 "chunk_size",
3286 "smooth_penalty",
3287 "centered",
3288 "double_penalty",
3289 "id",
3290 "__by_col",
3291 ],
3292 )?;
3293 let path = options
3294 .get("lazy_path")
3295 .or_else(|| options.get("pca_basis_path"))
3296 .or_else(|| options.get("path"))
3297 .map(|raw| PathBuf::from(strip_quotes(raw)));
3298 let Some(path) = path else {
3299 return Err(TermBuilderError::incompatible_config(
3300 "pca smooth requires lazy_path=... on the formula path",
3301 )
3302 .to_string());
3303 };
3304 let k = option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
3305 .unwrap_or(0);
3306 let chunk_size = option_usize(options, "chunk_size").unwrap_or(DEFAULT_PCA_CHUNK_SIZE);
3307 Ok(SmoothBasisSpec::Pca {
3308 feature_cols: cols.to_vec(),
3309 basis_matrix: Array2::<f64>::zeros((cols.len(), k)),
3310 centered: option_bool(options, "centered").unwrap_or(true),
3311 smooth_penalty: option_f64(options, "smooth_penalty").unwrap_or(1.0),
3312 center_mean: None,
3313 pca_basis_path: Some(path),
3314 chunk_size,
3315 })
3316 }
3317 other => Err(TermBuilderError::unsupported_feature(format!(
3318 "unsupported smooth type '{other}'"
3319 ))
3320 .to_string()),
3321 }
3322}
3323
3324pub fn enable_scale_dimensions(spec: &mut TermCollectionSpec) {
3326 for smooth in spec.smooth_terms.iter_mut() {
3327 match &mut smooth.basis {
3328 SmoothBasisSpec::Matern {
3329 feature_cols,
3330 spec: matern,
3331 ..
3332 } => {
3333 if matern.aniso_log_scales.is_none() {
3334 let d = feature_cols.len();
3335 matern.aniso_log_scales = Some(vec![0.0; d]);
3336 }
3337 }
3338 SmoothBasisSpec::Duchon {
3339 feature_cols,
3340 spec: duchon,
3341 ..
3342 } => {
3343 if duchon.aniso_log_scales.is_none() {
3344 let d = feature_cols.len();
3345 duchon.aniso_log_scales = Some(vec![0.0; d]);
3346 }
3347 }
3348 _ => {}
3349 }
3350 }
3351}
3352
3353pub fn spatial_center_strategy_for_dimension(num_centers: usize, d: usize) -> CenterStrategy {
3358 if d <= 3 {
3359 CenterStrategy::FarthestPoint { num_centers }
3366 } else {
3367 default_spatial_center_strategy(num_centers, d)
3368 }
3369}
3370
3371pub fn col_minmax(col: ArrayView1<'_, f64>) -> Result<(f64, f64), String> {
3372 let min = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
3373 let max = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
3374 if !min.is_finite() || !max.is_finite() {
3375 return Err(TermBuilderError::degenerate_data(
3376 "non-finite data encountered while inferring knot range",
3377 )
3378 .to_string());
3379 }
3380 if (max - min).abs() < 1e-12 {
3381 Ok((min, min + 1e-6))
3382 } else {
3383 Ok((min, max))
3384 }
3385}
3386
3387pub fn unique_count_column(col: ArrayView1<'_, f64>) -> usize {
3388 use std::collections::HashSet;
3389 let mut set = HashSet::<u64>::with_capacity(col.len());
3390 for &v in col {
3391 let norm = if v == 0.0 { 0.0 } else { v };
3392 set.insert(norm.to_bits());
3393 }
3394 set.len().max(1)
3395}
3396
/// Minimum number of distinct covariate values a cubic-regression (`cr`/`cs`/`sz`) marginal
/// needs. With fewer, `capped_cr_marginal_knotspec` declines to build one.
pub(crate) const CR_MIN_KNOTS: usize = 3;
3403
3404fn capped_cr_marginal_knotspec(
3431 col: ArrayView1<'_, f64>,
3432 k_cr_requested: usize,
3433 label: &str,
3434 inference_notes: &mut Vec<String>,
3435) -> Result<Option<BSplineKnotSpec>, String> {
3436 let n_distinct = unique_count_column(col);
3437 let k_cr = k_cr_requested.min(n_distinct);
3438 if k_cr < CR_MIN_KNOTS {
3439 inference_notes.push(format!(
3440 "Smooth '{label}': cubic-regression ('cr'/'cs'/'sz') basis requested k={k_cr_requested}, \
3441 but the covariate has only {n_distinct} distinct value(s) — too few to support a cubic \
3442 regression spline (needs >= {CR_MIN_KNOTS} distinct values). Degraded to the linear \
3443 B-spline marginal the default basis builds on the same data."
3444 ));
3445 return Ok(None);
3446 }
3447 if k_cr < k_cr_requested {
3448 inference_notes.push(format!(
3449 "Smooth '{label}': cubic-regression ('cr'/'cs'/'sz') basis reduced from k={k_cr_requested} \
3450 to k={k_cr} to match the covariate's {n_distinct} distinct value(s) (mgcv-style \
3451 data-support cap; a cr basis cannot place more value-knots than the data has)."
3452 ));
3453 }
3454 let cr_knots = crate::basis::select_cr_knots(col, k_cr).map_err(|e| e.to_string())?;
3455 Ok(Some(BSplineKnotSpec::NaturalCubicRegression {
3456 knots: cr_knots,
3457 }))
3458}
3459
3460fn min_per_group_unique_count(
3467 feature_col: ArrayView1<'_, f64>,
3468 group_col: ArrayView1<'_, f64>,
3469) -> usize {
3470 use std::collections::{HashMap, HashSet};
3471 let mut per_group: HashMap<u64, HashSet<u64>> = HashMap::new();
3472 for (xi, gi) in feature_col.iter().zip(group_col.iter()) {
3473 let xnorm = if *xi == 0.0 { 0.0 } else { *xi };
3474 let gnorm = if *gi == 0.0 { 0.0 } else { *gi };
3475 per_group
3476 .entry(gnorm.to_bits())
3477 .or_default()
3478 .insert(xnorm.to_bits());
3479 }
3480 per_group
3481 .values()
3482 .map(|s| s.len())
3483 .min()
3484 .unwrap_or(1)
3485 .max(1)
3486}
3487
3488pub fn heuristic_knots_for_column(col: ArrayView1<'_, f64>) -> usize {
3493 let unique = unique_count_column(col);
3494 let ceiling = ((unique as f64).cbrt() as usize).max(20);
3495 (unique / 4).clamp(4, ceiling)
3496}
3497
3498fn heuristic_tensor_margin_knots(cols: &[usize], ds: &Dataset) -> Vec<usize> {
3518 let d = cols.len().max(1);
3519 let degree = DEFAULT_BSPLINE_DEGREE;
3520 let min_k = degree + 2; let n = ds.values.nrows();
3522
3523 let per_margin_cap: Vec<usize> = cols
3527 .iter()
3528 .map(|&c| heuristic_knots_for_column(ds.values.column(c)).max(min_k))
3529 .collect();
3530
3531 let mgcv_like_per_margin = match d {
3538 2 => 7usize,
3539 3 => 5usize,
3540 _ => 4usize,
3541 };
3542 let mgcv_like_total = (mgcv_like_per_margin as f64).powi(d as i32);
3543 let data_budget = (n as f64) * 0.8;
3544 let p_target = mgcv_like_total
3545 .max(min_k.pow(d as u32) as f64)
3546 .min(data_budget);
3547
3548 let geo_per_margin = p_target.powf(1.0 / d as f64).round() as usize;
3551 let unclamped: Vec<usize> = per_margin_cap
3552 .iter()
3553 .map(|&cap| geo_per_margin.clamp(min_k, cap))
3554 .collect();
3555
3556 let mut k_list = unclamped;
3561 loop {
3562 let product: f64 = k_list.iter().map(|&k| k as f64).product();
3563 if product >= p_target {
3564 break;
3565 }
3566 let Some(idx) = k_list
3569 .iter()
3570 .zip(per_margin_cap.iter())
3571 .enumerate()
3572 .filter(|&(_, (k, cap))| k < cap)
3573 .max_by_key(|&(_, (k, cap))| (cap - k, *cap))
3574 .map(|(i, _)| i)
3575 else {
3576 break;
3577 };
3578 k_list[idx] += 1;
3579 }
3580 k_list
3581}
3582
3583pub fn heuristic_centers(n: usize, d: usize) -> usize {
3584 default_num_centers(n, d)
3585}
3586
3587fn parse_endpoint_side(
3592 value: &str,
3593 context: &str,
3594) -> Result<BSplineEndpointBoundaryCondition, String> {
3595 match value.trim().to_ascii_lowercase().as_str() {
3596 "" | "none" | "open" | "unconstrained" | "free" => {
3597 Ok(BSplineEndpointBoundaryCondition::Free)
3598 }
3599 "clamped" | "clamp" | "zero_derivative" | "zero-derivative" => {
3600 Ok(BSplineEndpointBoundaryCondition::Clamped)
3601 }
3602 "anchored" | "anchor" | "zero" | "zero_value" | "zero-value" => {
3603 Ok(BSplineEndpointBoundaryCondition::Anchored { value: 0.0 })
3604 }
3605 other => Err(format!(
3606 "unsupported {context} boundary condition '{other}'; expected free, clamped, or anchored"
3607 )),
3608 }
3609}
3610
3611fn boundary_anchor_value(
3612 options: &BTreeMap<String, String>,
3613 side: &str,
3614 fallback: Option<f64>,
3615) -> Option<f64> {
3616 [
3617 format!("anchor_{side}"),
3618 format!("{side}_anchor"),
3619 format!("anchor-value-{side}"),
3620 ]
3621 .iter()
3622 .find_map(|key| option_f64(options, key))
3623 .or(fallback)
3624}
3625
3626fn apply_anchor_value(
3627 cond: BSplineEndpointBoundaryCondition,
3628 value: Option<f64>,
3629) -> BSplineEndpointBoundaryCondition {
3630 match cond {
3631 BSplineEndpointBoundaryCondition::Anchored { .. } => {
3632 BSplineEndpointBoundaryCondition::Anchored {
3633 value: value.unwrap_or(0.0),
3634 }
3635 }
3636 other => other,
3637 }
3638}
3639
3640fn parse_bspline_boundary_conditions(
3641 options: &BTreeMap<String, String>,
3642) -> Result<BSplineBoundaryConditions, String> {
3643 let fallback_anchor = option_f64(options, "anchor")
3644 .or_else(|| option_f64(options, "anchor_value"))
3645 .or_else(|| option_f64(options, "value"));
3646 let global_boundary_conditions = options
3647 .get("boundary_conditions")
3648 .or_else(|| options.get("bc"));
3649 let mut boundary_conditions = BSplineBoundaryConditions::default();
3650
3651 if let Some(raw_boundary_conditions) = global_boundary_conditions {
3652 let cond = parse_endpoint_side(raw_boundary_conditions, "boundary_conditions")?;
3653 let side = options
3654 .get("side")
3655 .map(|s| s.trim().to_ascii_lowercase())
3656 .unwrap_or_else(|| "both".to_string());
3657 match side.as_str() {
3658 "both" | "all" | "endpoints" => {
3659 boundary_conditions.left = cond;
3660 boundary_conditions.right = cond;
3661 }
3662 "left" | "start" | "lower" => boundary_conditions.left = cond,
3663 "right" | "end" | "upper" => boundary_conditions.right = cond,
3664 other => {
3665 return Err(format!(
3666 "unsupported B-spline boundary side '{other}'; expected left, right, or both"
3667 ));
3668 }
3669 }
3670 }
3671
3672 if let Some(raw) = options
3673 .get("bc_left")
3674 .or_else(|| options.get("left_bc"))
3675 .or_else(|| options.get("bc_start"))
3676 .or_else(|| options.get("start_bc"))
3677 {
3678 boundary_conditions.left = parse_endpoint_side(raw, "left endpoint")?;
3679 }
3680 if let Some(raw) = options
3681 .get("bc_right")
3682 .or_else(|| options.get("right_bc"))
3683 .or_else(|| options.get("bc_end"))
3684 .or_else(|| options.get("end_bc"))
3685 {
3686 boundary_conditions.right = parse_endpoint_side(raw, "right endpoint")?;
3687 }
3688
3689 boundary_conditions.left = apply_anchor_value(
3690 boundary_conditions.left,
3691 boundary_anchor_value(options, "left", fallback_anchor),
3692 );
3693 boundary_conditions.right = apply_anchor_value(
3694 boundary_conditions.right,
3695 boundary_anchor_value(options, "right", fallback_anchor),
3696 );
3697
3698 reject_nonzero_anchor("left", boundary_conditions.left)?;
3706 reject_nonzero_anchor("right", boundary_conditions.right)?;
3707
3708 Ok(boundary_conditions)
3709}
3710
3711fn reject_nonzero_anchor(side: &str, cond: BSplineEndpointBoundaryCondition) -> Result<(), String> {
3712 if let BSplineEndpointBoundaryCondition::Anchored { value } = cond {
3713 if value.abs() > 1e-12 {
3714 return Err(format!(
3715 "non-zero {side} anchor {value} requires an affine offset term that is not yet supported; only anchored value 0 is accepted at parse time"
3716 ));
3717 }
3718 }
3719 Ok(())
3720}
3721
3722fn parse_ps_internal_knots(
3736 options: &BTreeMap<String, String>,
3737 degree: usize,
3738 default_internal_knots: usize,
3739) -> Result<(usize, bool, usize), String> {
3740 const MIN_EXPRESSIVE_INTERNAL_KNOTS: usize = 2;
3741 let knots_internal = if knots_option_is_list(options) {
3751 None
3752 } else {
3753 option_usize_strict(options, "knots")?
3754 };
3755 let basis_dim = option_usize_any_strict(options, &["k", "basis_dim", "basis-dim", "basisdim"])?;
3756 if knots_internal.is_some() && basis_dim.is_some() {
3757 return Err(TermBuilderError::incompatible_config(
3758 "ps/bspline smooth: specify either knots=<internal_knots> or k=<basis_dim> (not both)",
3759 )
3760 .to_string());
3761 }
3762 if let Some(k) = basis_dim {
3763 if k < 2 {
3764 return Err(TermBuilderError::invalid_option(format!(
3765 "ps/bspline smooth: k={} too small; B-spline basis requires k >= 2",
3766 k
3767 ))
3768 .to_string());
3769 }
3770 let effective_degree = degree.min(k - 1).max(1);
3776 let num_internal_knots = if effective_degree < degree {
3777 k.saturating_sub(effective_degree + 1)
3780 } else {
3781 (k - degree - 1).max(MIN_EXPRESSIVE_INTERNAL_KNOTS)
3782 };
3783 Ok((num_internal_knots, false, effective_degree))
3784 } else {
3785 Ok((
3786 knots_internal.unwrap_or(default_internal_knots),
3787 knots_internal.is_none(),
3788 degree,
3789 ))
3790 }
3791}
3792
/// Whether `knots=` is written as a list (`[...]`, `c(...)`, `C(...)`, or `(...)`) rather than a
/// scalar count. Leading whitespace is ignored; a missing `knots` key yields `false`.
fn knots_option_is_list(options: &BTreeMap<String, String>) -> bool {
    const LIST_OPENERS: [&str; 4] = ["[", "c(", "C(", "("];
    match options.get("knots") {
        Some(raw) => {
            let text = raw.trim();
            LIST_OPENERS.iter().any(|&opener| text.starts_with(opener))
        }
        None => false,
    }
}
3807
3808fn parse_explicit_internal_knots(
3813 options: &BTreeMap<String, String>,
3814) -> Result<Option<Vec<f64>>, String> {
3815 if !knots_option_is_list(options) {
3816 return Ok(None);
3817 }
3818 let raw = options
3819 .get("knots")
3820 .expect("knots_option_is_list implies the key is present");
3821 let tokens = split_list_option(raw);
3822 if tokens.is_empty() {
3823 return Err(TermBuilderError::invalid_option(format!(
3824 "knots={raw} is an empty list; supply at least one internal knot position \
3825 (e.g. knots=[0.2, 0.5, 0.8]) or a scalar count (e.g. knots=8)"
3826 ))
3827 .to_string());
3828 }
3829 let mut positions = Vec::with_capacity(tokens.len());
3830 for tok in &tokens {
3831 let value = parse_numeric_expr(tok).map_err(|err| {
3832 TermBuilderError::invalid_option(format!(
3833 "knots list entry '{tok}' is not a numeric position: {err}"
3834 ))
3835 .to_string()
3836 })?;
3837 positions.push(value);
3838 }
3839 Ok(Some(positions))
3840}
3841
3842fn parse_knot_placement(
3848 options: &BTreeMap<String, String>,
3849) -> Result<crate::basis::BSplineKnotPlacement, String> {
3850 use crate::basis::BSplineKnotPlacement;
3851 match options
3852 .get("knot_placement")
3853 .or_else(|| options.get("knot-placement"))
3854 .or_else(|| options.get("knotplacement"))
3855 {
3856 None => Ok(BSplineKnotPlacement::Uniform),
3857 Some(raw) => match raw
3858 .trim()
3859 .trim_matches('"')
3860 .trim_matches('\'')
3861 .to_ascii_lowercase()
3862 .as_str()
3863 {
3864 "uniform" | "even" | "equal" => Ok(BSplineKnotPlacement::Uniform),
3865 "quantile" | "quantiles" | "data" | "empirical" => Ok(BSplineKnotPlacement::Quantile),
3866 other => Err(TermBuilderError::invalid_option(format!(
3867 "knot_placement={other} is not recognised; expected \"uniform\" or \"quantile\""
3868 ))
3869 .to_string()),
3870 },
3871 }
3872}
3873
3874fn resolve_nonperiodic_bspline_knotspec(
3885 options: &BTreeMap<String, String>,
3886 data: ArrayView1<'_, f64>,
3887 data_range: (f64, f64),
3888 degree: usize,
3889 n_knots: usize,
3890) -> Result<BSplineKnotSpec, String> {
3891 use crate::basis::{BSplineKnotPlacement, clamped_knot_vector_from_internal_positions};
3892 if let Some(positions) = parse_explicit_internal_knots(options)? {
3893 if option_usize_any_strict(options, &["k", "basis_dim", "basis-dim", "basisdim"])?.is_some()
3894 {
3895 return Err(TermBuilderError::incompatible_config(
3896 "ps/bspline smooth: specify either explicit knots=[...] positions or \
3897 k=<basis_dim> (not both); the basis size is fixed by the knot vector",
3898 )
3899 .to_string());
3900 }
3901 let knots = clamped_knot_vector_from_internal_positions(data_range, &positions, degree)
3902 .map_err(|e| e.to_string())?;
3903 return Ok(BSplineKnotSpec::Provided(knots));
3904 }
3905 match parse_knot_placement(options)? {
3906 BSplineKnotPlacement::Uniform => Ok(BSplineKnotSpec::Generate {
3907 data_range,
3908 num_internal_knots: n_knots,
3909 }),
3910 BSplineKnotPlacement::Quantile => {
3911 crate::basis::auto_knot_vector_1d_quantile(data, n_knots, degree)
3915 .map_err(|e| e.to_string())?;
3916 Ok(BSplineKnotSpec::Automatic {
3917 num_internal_knots: Some(n_knots),
3918 placement: BSplineKnotPlacement::Quantile,
3919 })
3920 }
3921 }
3922}
3923
3924pub fn validate_known_options(
3930 term_name: &str,
3931 options: &BTreeMap<String, String>,
3932 known: &[&str],
3933) -> Result<(), String> {
3934 let known_set: std::collections::BTreeSet<&&str> = known.iter().collect();
3935 for key in options.keys() {
3936 if !known_set.contains(&key.as_str()) {
3937 if term_name == "tensor" && is_tensor_k_axis_option_key(key) {
3938 continue;
3939 }
3940 let key_l = key.to_ascii_lowercase();
3942 let mut suggestions: Vec<&str> = known
3943 .iter()
3944 .filter(|k| {
3945 let kl = k.to_ascii_lowercase();
3946 kl.contains(&key_l) || key_l.contains(&kl) || {
3947 let n = kl
3948 .chars()
3949 .zip(key_l.chars())
3950 .take_while(|(a, b)| a == b)
3951 .count();
3952 n >= 3
3953 }
3954 })
3955 .copied()
3956 .collect();
3957 suggestions.sort_unstable();
3958 suggestions.dedup();
3959 let hint = if suggestions.is_empty() {
3960 String::new()
3961 } else {
3962 format!(" — did you mean one of [{}]?", suggestions.join(", "))
3963 };
3964 return Err(TermBuilderError::invalid_option(format!(
3965 "{term_name}() does not accept option `{key}`{hint}. Valid options: [{}]",
3966 {
3967 let mut sorted = known.to_vec();
3968 sorted.sort_unstable();
3969 sorted.join(", ")
3970 }
3971 ))
3972 .to_string());
3973 }
3974 }
3975 Ok(())
3976}
3977
/// Option key whose integer value caps the default spatial center count, as read by
/// `cap_default_spatial_centers`. The `__` prefix suggests it is internal rather than
/// user-facing. NOTE(review): presumably injected by a caller, not typed by users; confirm.
pub const SECONDARY_CENTER_CAP_OPTION: &str = "__secondary_center_cap";
3988
3989pub(crate) fn cap_default_spatial_centers(
3994 options: &BTreeMap<String, String>,
3995 default_count: usize,
3996) -> usize {
3997 match option_usize(options, SECONDARY_CENTER_CAP_OPTION) {
3998 Some(cap) => default_count.min(cap),
3999 None => default_count,
4000 }
4001}
4002
/// Default Matérn center count: `planned_count`, raised to at least `min(d + 4, n)` and never
/// below 1.
fn default_matern_center_count(n: usize, d: usize, planned_count: usize) -> usize {
    let small_sample_floor = std::cmp::min(d + 4, n);
    std::cmp::max(std::cmp::max(planned_count, small_sample_floor), 1)
}
4013
4014pub fn parse_countwith_basis_alias(
4015 options: &BTreeMap<String, String>,
4016 primarykey: &str,
4017 default_count: usize,
4018) -> Result<usize, String> {
4019 let primary = option_usize_strict(options, primarykey)?;
4024 let basis_dim = option_usize_any_strict(
4025 options,
4026 &["k", "basis_dim", "basis-dim", "basisdim", "knots"],
4027 )?;
4028 if primary.is_some() && basis_dim.is_some() {
4029 return Err(TermBuilderError::incompatible_config(format!(
4030 "specify either {}=<count> or k=<basis_dim> (not both)",
4031 primarykey
4032 ))
4033 .to_string());
4034 }
4035 Ok(primary.or(basis_dim).unwrap_or(default_count))
4036}
4037
/// Whether `primarykey` or any basis-size alias (`k`, `basis_dim`, `basis-dim`, `basisdim`,
/// `knots`) is present in `options`, whatever its value.
pub fn has_explicit_countwith_basis_alias(
    options: &BTreeMap<String, String>,
    primarykey: &str,
) -> bool {
    std::iter::once(primarykey)
        .chain(["k", "basis_dim", "basis-dim", "basisdim", "knots"])
        .any(|key| options.contains_key(key))
}
4047
4048pub fn parse_cyclic_boundary(
4049 options: &BTreeMap<String, String>,
4050 minv: f64,
4051 maxv: f64,
4052) -> Result<OneDimensionalBoundary, String> {
4053 let cyclic = option_bool(options, "cyclic")
4054 .or_else(|| option_bool(options, "periodic"))
4055 .unwrap_or(false);
4056 if !cyclic {
4057 return Ok(OneDimensionalBoundary::Open);
4058 }
4059 let start = match option_numeric_expr(options, "period_start")? {
4060 Some(v) => v,
4061 None => option_numeric_expr(options, "start")?.unwrap_or(minv),
4062 };
4063 let end = match option_numeric_expr(options, "period_end")? {
4064 Some(v) => v,
4065 None => option_numeric_expr(options, "end")?.unwrap_or(maxv),
4066 };
4067 if end <= start {
4068 return Err(format!(
4069 "cyclic smooth requires period_end/end ({end}) > period_start/start ({start})"
4070 ));
4071 }
4072 Ok(OneDimensionalBoundary::Cyclic { start, end })
4073}
4074
4075pub fn parse_periodic_domain_1d(
4082 options: &BTreeMap<String, String>,
4083 minv: f64,
4084 maxv: f64,
4085) -> Result<(f64, f64), String> {
4086 let start = match option_numeric_expr(options, "period_start")? {
4087 Some(v) => v,
4088 None => option_numeric_expr(options, "start")?.unwrap_or(minv),
4089 };
4090 let end = match option_numeric_expr(options, "period_end")? {
4091 Some(v) => v,
4092 None => option_numeric_expr(options, "end")?.unwrap_or(maxv),
4093 };
4094 if !(start.is_finite() && end.is_finite()) {
4095 return Err(format!(
4096 "periodic smooth domain requires finite endpoints, got ({start}, {end})"
4097 ));
4098 }
4099 if end <= start {
4100 return Err(format!(
4101 "periodic smooth requires period_end/end ({end}) > period_start/start ({start})"
4102 ));
4103 }
4104 Ok((start, end - start))
4105}
4106
4107fn parse_matern_nu(raw: &str) -> Result<MaternNu, String> {
4108 let trimmed = raw.trim();
4109 let lowered = trimmed.to_ascii_lowercase();
4110 match lowered.as_str() {
4111 "1/2" | "0.5" | "half" => return Ok(MaternNu::Half),
4112 "3/2" | "1.5" => return Ok(MaternNu::ThreeHalves),
4113 "5/2" | "2.5" => return Ok(MaternNu::FiveHalves),
4114 "7/2" | "3.5" => return Ok(MaternNu::SevenHalves),
4115 "9/2" | "4.5" => return Ok(MaternNu::NineHalves),
4116 _ => {}
4117 }
4118
4119 let value = if let Some((num, den)) = trimmed.split_once('/') {
4120 let num = num
4121 .trim()
4122 .parse::<f64>()
4123 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?;
4124 let den = den
4125 .trim()
4126 .parse::<f64>()
4127 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?;
4128 if den == 0.0 || !num.is_finite() || !den.is_finite() {
4129 return Err(unsupported_matern_nu_message(raw));
4130 }
4131 num / den
4132 } else {
4133 trimmed
4134 .parse::<f64>()
4135 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?
4136 };
4137
4138 const TOL: f64 = 1e-12;
4139 if (value - 0.5).abs() <= TOL {
4140 Ok(MaternNu::Half)
4141 } else if (value - 1.5).abs() <= TOL {
4142 Ok(MaternNu::ThreeHalves)
4143 } else if (value - 2.5).abs() <= TOL {
4144 Ok(MaternNu::FiveHalves)
4145 } else if (value - 3.5).abs() <= TOL {
4146 Ok(MaternNu::SevenHalves)
4147 } else if (value - 4.5).abs() <= TOL {
4148 Ok(MaternNu::NineHalves)
4149 } else {
4150 Err(unsupported_matern_nu_message(raw))
4151 }
4152}
4153
4154fn unsupported_matern_nu_message(raw: &str) -> String {
4155 TermBuilderError::unsupported_feature(format!(
4156 "unsupported Matern nu '{raw}'; supported half-integer values are 1/2, 3/2, 5/2, 7/2, and 9/2"
4157 ))
4158 .to_string()
4159}
4160
/// How the Duchon `power` parameter was chosen from the formula options.
///
/// Produced by `parse_duchon_power_policy`; `parse_duchon_power` collapses it to a
/// concrete `f64`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum DuchonPowerPolicy {
    /// The user supplied `power=<number>`; the parser only produces finite, non-negative values.
    Explicit(f64),
    /// No `power` option was given; `parse_duchon_power` resolves this to `1.5`.
    CubicStructuralDefault,
}
4169
4170pub fn parse_duchon_power_policy(
4171 options: &BTreeMap<String, String>,
4172) -> Result<DuchonPowerPolicy, String> {
4173 if let Some(raw_nu) = options.get("nu") {
4174 return Err(TermBuilderError::incompatible_config(format!(
4175 "Duchon smooths use power=<number>, not nu='{}'. Use power=1.5, power=2, etc.",
4176 raw_nu
4177 ))
4178 .to_string());
4179 }
4180 match options.get("power") {
4181 Some(raw) => {
4182 let value = raw.parse::<f64>().map_err(|err| {
4183 TermBuilderError::invalid_option(format!(
4184 "invalid Duchon power '{}'; expected a non-negative number such as power=1.5 or power=2: {}",
4185 raw, err
4186 ))
4187 .to_string()
4188 })?;
4189 if !value.is_finite() || value < 0.0 {
4190 return Err(TermBuilderError::invalid_option(format!(
4191 "invalid Duchon power '{}'; expected a finite non-negative number such as power=1.5 or power=2",
4192 raw
4193 ))
4194 .to_string());
4195 }
4196 Ok(DuchonPowerPolicy::Explicit(value))
4197 }
4198 None => Ok(DuchonPowerPolicy::CubicStructuralDefault),
4199 }
4200}
4201
4202pub fn parse_duchon_power(options: &BTreeMap<String, String>) -> Result<f64, String> {
4203 match parse_duchon_power_policy(options)? {
4204 DuchonPowerPolicy::Explicit(power) => Ok(power),
4205 DuchonPowerPolicy::CubicStructuralDefault => Ok(1.5),
4211 }
4212}
4213
4214pub fn parse_duchon_order(
4215 options: &BTreeMap<String, String>,
4216) -> Result<DuchonNullspaceOrder, String> {
4217 match options.get("order") {
4218 None => Ok(DuchonNullspaceOrder::Linear),
4222 Some(raw) => match raw.parse::<usize>() {
4223 Ok(0) => Ok(DuchonNullspaceOrder::Zero),
4224 Ok(1) => Ok(DuchonNullspaceOrder::Linear),
4225 Ok(other) => Ok(DuchonNullspaceOrder::Degree(other)),
4226 Err(_) => Err(TermBuilderError::invalid_option(format!(
4227 "invalid Duchon order '{}'; expected a non-negative integer such as order=0, order=1, or order=2",
4228 raw
4229 ))
4230 .to_string()),
4231 },
4232 }
4233}
4234
4235fn parse_matern_identifiability(
4236 options: &BTreeMap<String, String>,
4237) -> Result<MaternIdentifiability, TermBuilderError> {
4238 let Some(raw) = options.get("identifiability").map(String::as_str) else {
4239 return Ok(MaternIdentifiability::default());
4240 };
4241 match raw.trim().to_ascii_lowercase().as_str() {
4242 "none" => Ok(MaternIdentifiability::None),
4243 "sum_tozero" | "sum-to-zero" | "center_sum_tozero" | "center-sum-to-zero" | "centered" => {
4244 Ok(MaternIdentifiability::CenterSumToZero)
4245 }
4246 "linear" | "center_linear_orthogonal" | "center-linear-orthogonal" => {
4247 Ok(MaternIdentifiability::CenterLinearOrthogonal)
4248 }
4249 other => Err(TermBuilderError::unsupported_feature(format!(
4250 "invalid Matérn identifiability '{other}'; expected one of: none, sum_tozero, linear"
4251 ))),
4252 }
4253}
4254
4255fn parse_spatial_identifiability(
4256 options: &BTreeMap<String, String>,
4257) -> Result<SpatialIdentifiability, TermBuilderError> {
4258 let Some(raw) = options.get("identifiability").map(String::as_str) else {
4259 return Ok(SpatialIdentifiability::default());
4260 };
4261 match raw.trim().to_ascii_lowercase().as_str() {
4262 "none" => Ok(SpatialIdentifiability::None),
4263 "orthogonal"
4264 | "orthogonal_to_parametric"
4265 | "orthogonal-to-parametric"
4266 | "parametric_orthogonal" => Ok(SpatialIdentifiability::OrthogonalToParametric),
4267 "frozen" => Err(TermBuilderError::unsupported_feature(
4268 "spatial identifiability 'frozen' is internal-only; use none or orthogonal_to_parametric",
4269 )),
4270 other => Err(TermBuilderError::unsupported_feature(format!(
4271 "invalid spatial identifiability '{other}'; expected one of: none, orthogonal_to_parametric"
4272 ))),
4273 }
4274}
4275
4276#[cfg(test)]
4277mod tests {
4278 use super::*;
4279 use crate::inference::formula_dsl::parse_formula;
4280 use gam_data::{DataSchema, SchemaColumn};
4281 use ndarray::Array2;
4282 use std::collections::BTreeMap;
4283
4284 fn continuous_dataset(headers: &[&str], rows: Vec<Vec<f64>>) -> Dataset {
4285 let nrows = rows.len();
4286 let ncols = headers.len();
4287 let values = Array2::from_shape_vec(
4288 (nrows, ncols),
4289 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
4290 )
4291 .expect("rectangular test data");
4292 Dataset {
4293 headers: headers.iter().map(|name| name.to_string()).collect(),
4294 values,
4295 schema: DataSchema {
4296 columns: headers
4297 .iter()
4298 .map(|name| SchemaColumn {
4299 name: name.to_string(),
4300 kind: ColumnKindTag::Continuous,
4301 levels: vec![],
4302 })
4303 .collect(),
4304 },
4305 column_kinds: vec![ColumnKindTag::Continuous; ncols],
4306 }
4307 }
4308
4309 fn factor_dataset() -> Dataset {
4310 let rows = (0..24)
4311 .map(|i| {
4312 let x = i as f64 / 23.0;
4313 let g = (i % 2) as f64;
4314 vec![x + g, x, g]
4315 })
4316 .collect::<Vec<_>>();
4317 Dataset {
4318 headers: vec!["y".into(), "x".into(), "g".into()],
4319 values: Array2::from_shape_vec(
4320 (rows.len(), 3),
4321 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
4322 )
4323 .expect("rectangular factor test data"),
4324 schema: DataSchema {
4325 columns: vec![
4326 SchemaColumn {
4327 name: "y".into(),
4328 kind: ColumnKindTag::Continuous,
4329 levels: vec![],
4330 },
4331 SchemaColumn {
4332 name: "x".into(),
4333 kind: ColumnKindTag::Continuous,
4334 levels: vec![],
4335 },
4336 SchemaColumn {
4337 name: "g".into(),
4338 kind: ColumnKindTag::Categorical,
4339 levels: vec!["a".into(), "b".into()],
4340 },
4341 ],
4342 },
4343 column_kinds: vec![
4344 ColumnKindTag::Continuous,
4345 ColumnKindTag::Continuous,
4346 ColumnKindTag::Categorical,
4347 ],
4348 }
4349 }
4350
4351 #[test]
4359 fn default_univariate_thinplate_basis_dim_is_modest() {
4360 let n = 300usize;
4363 let rows: Vec<Vec<f64>> = (0..n)
4364 .map(|i| {
4365 let x = -3.0 + 6.0 * (i as f64) / ((n - 1) as f64);
4366 vec![x.sin(), x]
4367 })
4368 .collect();
4369 let ds = continuous_dataset(&["y", "x"], rows);
4370
4371 let mut options = BTreeMap::new();
4372 options.insert("bs".to_string(), "tp".to_string());
4373
4374 let mut notes = Vec::new();
4375 let basis = build_smooth_basis(
4376 SmoothKind::S,
4377 &["x".to_string()],
4378 &[1],
4379 &options,
4380 &ds,
4381 &mut notes,
4382 &ResourcePolicy::default_library(),
4383 1,
4384 )
4385 .expect("build default univariate tp smooth");
4386
4387 let centers = match &basis {
4388 SmoothBasisSpec::ThinPlate { spec, .. } => match &spec.center_strategy {
4389 CenterStrategy::Auto(inner) => match inner.as_ref() {
4390 CenterStrategy::FarthestPoint { num_centers }
4391 | CenterStrategy::EqualMass { num_centers }
4392 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4393 | CenterStrategy::KMeans { num_centers, .. } => *num_centers,
4394 other => panic!("unexpected auto inner center strategy: {other:?}"),
4395 },
4396 CenterStrategy::FarthestPoint { num_centers }
4397 | CenterStrategy::EqualMass { num_centers }
4398 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4399 | CenterStrategy::KMeans { num_centers, .. } => *num_centers,
4400 other => panic!("unexpected center strategy: {other:?}"),
4401 },
4402 other => panic!("expected ThinPlate basis, got {other:?}"),
4403 };
4404
4405 assert!(
4409 centers >= 1,
4410 "default univariate tp must still build a usable basis (centers={centers})",
4411 );
4412 }
4413
4414 fn inferred_tensor_basis_product(ds: &Dataset) -> usize {
4415 let parsed = parse_formula("y ~ te(theta, h)").expect("parse tensor formula");
4416 let col_map = ds.column_map();
4417 let mut notes = Vec::new();
4418 let terms = build_termspec(
4419 &parsed.terms,
4420 ds,
4421 &col_map,
4422 &mut notes,
4423 &ResourcePolicy::default_library(),
4424 )
4425 .expect("build tensor termspec");
4426 let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
4427 panic!("expected tensor smooth");
4428 };
4429 spec.marginalspecs
4430 .iter()
4431 .map(|marginal| match marginal.knotspec {
4432 BSplineKnotSpec::Generate {
4433 num_internal_knots, ..
4434 } => num_internal_knots + marginal.degree + 1,
4435 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
4436 BSplineKnotSpec::Automatic {
4437 num_internal_knots: Some(num_internal_knots),
4438 ..
4439 } => num_internal_knots + marginal.degree + 1,
4440 BSplineKnotSpec::Automatic {
4441 num_internal_knots: None,
4442 ..
4443 } => panic!("test helper cannot infer automatic knot count"),
4444 BSplineKnotSpec::Provided(ref knots) => {
4445 knots.len().saturating_sub(marginal.degree + 1)
4446 }
4447 BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
4449 })
4450 .product()
4451 }
4452
4453 fn tensor_margin_basis_sizes(ds: &Dataset, formula: &str) -> Vec<usize> {
4454 let parsed = parse_formula(formula).expect("parse tensor formula");
4455 let col_map = ds.column_map();
4456 let mut notes = Vec::new();
4457 let terms = build_termspec(
4458 &parsed.terms,
4459 ds,
4460 &col_map,
4461 &mut notes,
4462 &ResourcePolicy::default_library(),
4463 )
4464 .expect("build tensor termspec");
4465 let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
4466 panic!("expected tensor smooth");
4467 };
4468 spec.marginalspecs
4469 .iter()
4470 .map(|marginal| match marginal.knotspec {
4471 BSplineKnotSpec::Generate {
4472 num_internal_knots, ..
4473 } => num_internal_knots + marginal.degree + 1,
4474 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
4475 BSplineKnotSpec::Automatic {
4476 num_internal_knots: Some(num_internal_knots),
4477 ..
4478 } => num_internal_knots + marginal.degree + 1,
4479 BSplineKnotSpec::Automatic {
4480 num_internal_knots: None,
4481 ..
4482 } => panic!("test helper cannot infer automatic knot count"),
4483 BSplineKnotSpec::Provided(ref knots) => {
4484 knots.len().saturating_sub(marginal.degree + 1)
4485 }
4486 BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
4488 })
4489 .collect()
4490 }
4491
4492 #[test]
4493 fn validate_known_options_lists_valid_option_names_for_unknown_parameter() {
4494 let mut options = BTreeMap::new();
4495 options.insert("lengt_scale".to_string(), "0.25".to_string());
4496 let err = validate_known_options(
4497 "matern",
4498 &options,
4499 &["type", "bs", "length_scale", "centers", "k", "nu"],
4500 )
4501 .expect_err("unknown smooth option should be rejected");
4502 assert!(
4503 err.contains("matern() does not accept option `lengt_scale`"),
4504 "error should name the invalid option, got: {err}"
4505 );
4506 assert!(
4507 err.contains("did you mean one of [length_scale]"),
4508 "error should suggest the closest valid option, got: {err}"
4509 );
4510 assert!(
4511 err.contains("Valid options: ["),
4512 "error should list valid option names, got: {err}"
4513 );
4514 }
4515
4516 #[test]
4517 fn tensor_k_accepts_square_bracket_per_margin_list() {
4518 let ds = continuous_dataset(
4519 &["y", "x", "z"],
4520 (0..40)
4521 .map(|i| {
4522 let x = i as f64 / 39.0;
4523 let z = ((i * 7) % 40) as f64 / 39.0;
4524 vec![x.sin() + z.cos(), x, z]
4525 })
4526 .collect(),
4527 );
4528
4529 assert_eq!(
4530 tensor_margin_basis_sizes(&ds, "y ~ te(x, z, k=[5, 6])"),
4531 vec![5, 6],
4532 "square-bracket k lists should materialize the requested per-margin values"
4533 );
4534 }
4535
4536 #[test]
4537 fn parse_cylinder_periodic_options_match_requested_forms() {
4538 let mut opts = BTreeMap::new();
4539 opts.insert("periodic".to_string(), "[0]".to_string());
4540 opts.insert("period".to_string(), "[2*pi, None]".to_string());
4541 let axes = parse_periodic_axes(&opts, 2).expect("axes");
4542 let periods = parse_periods(&opts, &axes).expect("periods");
4543 assert_eq!(axes, vec![true, false]);
4544 assert!((periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4545 assert_eq!(periods[1], None);
4546
4547 let mut boundary_opts = BTreeMap::new();
4548 boundary_opts.insert(
4549 "boundary".to_string(),
4550 "['periodic', 'natural']".to_string(),
4551 );
4552 boundary_opts.insert("period".to_string(), "[2*pi, None]".to_string());
4553 let boundary_axes = parse_periodic_axes(&boundary_opts, 2).expect("boundary axes");
4554 let boundary_periods =
4555 parse_periods(&boundary_opts, &boundary_axes).expect("boundary periods");
4556 assert_eq!(boundary_axes, vec![true, false]);
4557 assert!((boundary_periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4558 assert_eq!(boundary_periods[1], None);
4559
4560 let mut unicode_opts = BTreeMap::new();
4561 unicode_opts.insert("periodic".to_string(), "[0,1]".to_string());
4562 unicode_opts.insert("period".to_string(), "[2π, τ]".to_string());
4563 let unicode_axes = parse_periodic_axes(&unicode_opts, 2).expect("unicode axes");
4564 let unicode_periods = parse_periods(&unicode_opts, &unicode_axes).expect("unicode periods");
4565 assert_eq!(unicode_axes, vec![true, true]);
4566 assert!((unicode_periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4567 assert!((unicode_periods[1].unwrap() - std::f64::consts::TAU).abs() < 1e-12);
4568 }
4569
4570 #[test]
4571 fn parse_single_axis_periodic_zero_as_axis_not_false() {
4572 let mut opts = BTreeMap::new();
4573 opts.insert("periodic".to_string(), "[0]".to_string());
4574 opts.insert("period".to_string(), "2*pi".to_string());
4575 opts.insert("origin".to_string(), "0".to_string());
4576 let axes = parse_periodic_axes(&opts, 1).expect("axes");
4577 let periods = parse_periods(&opts, &axes).expect("periods");
4578 let origins = parse_period_origins(&opts, &axes).expect("origins");
4579 assert_eq!(axes, vec![true]);
4580 assert!((periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4581 assert_eq!(origins[0], Some(0.0));
4582 }
4583
4584 #[test]
4585 fn one_dimensional_bspline_accepts_boundary_periodic() {
4586 let ds = continuous_dataset(
4587 &["y", "theta"],
4588 (0..16)
4589 .map(|i| {
4590 let theta = std::f64::consts::TAU * i as f64 / 16.0;
4591 vec![theta.sin(), theta]
4592 })
4593 .collect(),
4594 );
4595 let parsed = parse_formula("y ~ s(theta, boundary=periodic, period=2*pi, origin=0, k=8)")
4596 .expect("parse");
4597 let col_map = ds.column_map();
4598 let mut notes = Vec::new();
4599 let terms = build_termspec(
4600 &parsed.terms,
4601 &ds,
4602 &col_map,
4603 &mut notes,
4604 &gam_runtime::resource::ResourcePolicy::default_library(),
4605 )
4606 .expect("periodic boundary should build");
4607 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4608 panic!("expected 1D B-spline");
4609 };
4610 assert!(matches!(
4611 &spec.knotspec,
4612 BSplineKnotSpec::PeriodicUniform {
4613 data_range,
4614 num_basis: 8
4615 } if *data_range == (0.0, std::f64::consts::TAU)
4616 ));
4617 }
4618
4619 #[test]
4620 fn univariate_smooth_accepts_mgcv_cubic_regression_aliases() {
4621 let ds = continuous_dataset(
4622 &["y", "x"],
4623 (0..32)
4624 .map(|i| {
4625 let x = i as f64 / 31.0;
4626 vec![x * x, x]
4627 })
4628 .collect(),
4629 );
4630 let col_map = ds.column_map();
4631
4632 for (selector, expect_double_penalty) in [("cr", false), ("cs", true)] {
4633 let formula = format!("y ~ s(x, bs='{selector}')");
4634 let parsed = parse_formula(&formula).expect("parse cr/cs smooth");
4635 let mut notes = Vec::new();
4636 let terms = build_termspec(
4637 &parsed.terms,
4638 &ds,
4639 &col_map,
4640 &mut notes,
4641 &gam_runtime::resource::ResourcePolicy::default_library(),
4642 )
4643 .unwrap_or_else(|err| panic!("bs='{selector}' must build a 1-D smooth, got: {err:?}"));
4644 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4645 panic!(
4646 "bs='{selector}' must lower to a BSpline1D; got {:?}",
4647 terms.smooth_terms[0].basis
4648 );
4649 };
4650 assert_eq!(
4651 spec.double_penalty, expect_double_penalty,
4652 "bs='{selector}' must default double_penalty to mgcv's convention \
4653 (cr=no-shrinkage, cs=shrinkage); got double_penalty={}",
4654 spec.double_penalty
4655 );
4656 }
4657 }
4658
4659 #[test]
4660 fn univariate_ps_small_k_degree_reduces_through_build() {
4661 let ds = continuous_dataset(
4670 &["y", "x"],
4671 (0..32)
4672 .map(|i| {
4673 let x = i as f64 / 31.0;
4674 vec![x * x, x]
4675 })
4676 .collect(),
4677 );
4678 let col_map = ds.column_map();
4679
4680 for formula in ["y ~ s(x, bs='ps', k=3)", "y ~ s(x, k=3)"] {
4681 let parsed = parse_formula(formula).expect("parse small-k ps/cr smooth");
4682 let mut notes = Vec::new();
4683 let terms = build_termspec(
4684 &parsed.terms,
4685 &ds,
4686 &col_map,
4687 &mut notes,
4688 &gam_runtime::resource::ResourcePolicy::default_library(),
4689 )
4690 .unwrap_or_else(|err| {
4691 panic!("`{formula}` must degree-reduce, not error; got: {err:?}")
4692 });
4693 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4694 panic!(
4695 "`{formula}` must lower to a BSpline1D; got {:?}",
4696 terms.smooth_terms[0].basis
4697 );
4698 };
4699 assert_eq!(
4700 spec.degree, 2,
4701 "`{formula}` must drop the cubic default to a quadratic basis"
4702 );
4703 let num_internal = match &spec.knotspec {
4704 BSplineKnotSpec::Generate {
4705 num_internal_knots, ..
4706 } => *num_internal_knots,
4707 BSplineKnotSpec::Automatic {
4708 num_internal_knots: Some(n),
4709 ..
4710 } => *n,
4711 other => panic!("`{formula}` unexpected knotspec: {other:?}"),
4712 };
4713 assert_eq!(
4714 num_internal, 0,
4715 "`{formula}` must have zero internal knots (num_basis = k = 3)"
4716 );
4717 assert!(
4719 spec.penalty_order >= 1 && spec.penalty_order <= spec.degree,
4720 "`{formula}` penalty_order {} must satisfy 1 <= order <= degree={}",
4721 spec.penalty_order,
4722 spec.degree
4723 );
4724 }
4725 }
4726
4727 #[test]
4728 fn formula_shape_constraint_round_trips_and_rejects_bogus() {
4729 let ds = continuous_dataset(
4730 &["y", "x"],
4731 (0..32)
4732 .map(|i| {
4733 let x = i as f64 / 31.0;
4734 vec![x * x, x]
4735 })
4736 .collect(),
4737 );
4738 let col_map = ds.column_map();
4739
4740 let parsed =
4741 parse_formula("y ~ s(x, shape=monotone_increasing)").expect("parse monotone smooth");
4742 let mut notes = Vec::new();
4743 let terms = build_termspec(
4744 &parsed.terms,
4745 &ds,
4746 &col_map,
4747 &mut notes,
4748 &gam_runtime::resource::ResourcePolicy::default_library(),
4749 )
4750 .expect("monotone smooth should build");
4751 assert_eq!(
4752 terms.smooth_terms[0].shape,
4753 ShapeConstraint::MonotoneIncreasing
4754 );
4755
4756 let parsed_bad = parse_formula("y ~ s(x, shape=bogus)").expect("parse bogus shape");
4757 let mut notes_bad = Vec::new();
4758 let err = build_termspec(
4759 &parsed_bad.terms,
4760 &ds,
4761 &col_map,
4762 &mut notes_bad,
4763 &gam_runtime::resource::ResourcePolicy::default_library(),
4764 )
4765 .expect_err("bogus shape must error");
4766 assert!(
4767 format!("{err:?}").contains("unknown shape constraint"),
4768 "got: {err:?}"
4769 );
4770 }
4771
4772 #[test]
4773 fn default_sphere_smooth_uses_spherical_farthest_point_centers() {
4774 let ds = continuous_dataset(
4775 &["y", "lat", "lon"],
4776 (0..24)
4777 .map(|i| {
4778 let t = i as f64 / 24.0;
4779 let lat = -60.0 + 120.0 * t;
4780 let lon = -180.0 + 360.0 * ((7 * i) % 24) as f64 / 24.0;
4781 vec![lat.to_radians().sin(), lat, lon]
4782 })
4783 .collect(),
4784 );
4785 let parsed = parse_formula("y ~ sphere(lat, lon)").expect("parse");
4786 let col_map = ds.column_map();
4787 let mut notes = Vec::new();
4788 let terms = build_termspec(
4789 &parsed.terms,
4790 &ds,
4791 &col_map,
4792 &mut notes,
4793 &gam_runtime::resource::ResourcePolicy::default_library(),
4794 )
4795 .expect("build sphere termspec");
4796 let SmoothBasisSpec::Sphere { spec, .. } = &terms.smooth_terms[0].basis else {
4797 panic!("expected sphere term");
4798 };
4799 assert!(matches!(
4800 spec.center_strategy,
4801 CenterStrategy::FarthestPoint { .. }
4802 ));
4803 }
4804
4805 #[test]
4806 fn one_dimensional_duchon_defaults_to_scale_free_length_scale() {
4807 let ds = continuous_dataset(
4808 &["y", "x"],
4809 (0..32)
4810 .map(|i| {
4811 let x = i as f64 / 31.0;
4812 vec![(std::f64::consts::TAU * x).sin(), x]
4813 })
4814 .collect(),
4815 );
4816 let parsed = parse_formula("y ~ duchon(x)").expect("parse");
4817 let col_map = ds.column_map();
4818 let mut notes = Vec::new();
4819 let terms = build_termspec(
4820 &parsed.terms,
4821 &ds,
4822 &col_map,
4823 &mut notes,
4824 &gam_runtime::resource::ResourcePolicy::default_library(),
4825 )
4826 .expect("build default duchon termspec");
4827 let SmoothBasisSpec::Duchon { spec, .. } = &terms.smooth_terms[0].basis else {
4828 panic!("expected Duchon term");
4829 };
4830 assert_eq!(spec.length_scale, None);
4831 }
4832
4833 #[test]
4834 fn one_dimensional_duchon_length_scale_opts_into_hybrid_mode() {
4835 let ds = continuous_dataset(
4836 &["y", "x"],
4837 (0..32)
4838 .map(|i| {
4839 let x = i as f64 / 31.0;
4840 vec![(std::f64::consts::TAU * x).sin(), x]
4841 })
4842 .collect(),
4843 );
4844 let parsed = parse_formula("y ~ duchon(x, length_scale=0.25)").expect("parse");
4845 let col_map = ds.column_map();
4846 let mut notes = Vec::new();
4847 let terms = build_termspec(
4848 &parsed.terms,
4849 &ds,
4850 &col_map,
4851 &mut notes,
4852 &gam_runtime::resource::ResourcePolicy::default_library(),
4853 )
4854 .expect("build hybrid duchon termspec");
4855 let SmoothBasisSpec::Duchon { spec, .. } = &terms.smooth_terms[0].basis else {
4856 panic!("expected Duchon term");
4857 };
4858 assert_eq!(spec.length_scale, Some(0.25));
4859 }
4860
4861 #[test]
4862 fn parse_matern_nu_accepts_equivalent_half_integer_forms() {
4863 let cases = [
4864 ("1/2", MaternNu::Half),
4865 (" 1 / 2 ", MaternNu::Half),
4866 (".5", MaternNu::Half),
4867 ("0.50", MaternNu::Half),
4868 ("half", MaternNu::Half),
4869 ("3 / 2", MaternNu::ThreeHalves),
4870 ("1.50", MaternNu::ThreeHalves),
4871 ("5 / 2", MaternNu::FiveHalves),
4872 ("2.500000000000", MaternNu::FiveHalves),
4873 ("7 / 2", MaternNu::SevenHalves),
4874 ("3.50", MaternNu::SevenHalves),
4875 ("9 / 2", MaternNu::NineHalves),
4876 ("4.50", MaternNu::NineHalves),
4877 ];
4878 for (raw, expected) in cases {
4879 let parsed = parse_matern_nu(raw).expect(raw);
4880 assert!(
4881 matches!(
4882 (parsed, expected),
4883 (MaternNu::Half, MaternNu::Half)
4884 | (MaternNu::ThreeHalves, MaternNu::ThreeHalves)
4885 | (MaternNu::FiveHalves, MaternNu::FiveHalves)
4886 | (MaternNu::SevenHalves, MaternNu::SevenHalves)
4887 | (MaternNu::NineHalves, MaternNu::NineHalves)
4888 ),
4889 "parsed {raw:?} as {parsed:?}, expected {expected:?}"
4890 );
4891 }
4892 }
4893
4894 #[test]
4895 fn parse_matern_nu_rejects_unsupported_or_invalid_values() {
4896 for raw in ["1", "2", "11/2", "1/0", "nan", "fast"] {
4897 let err = parse_matern_nu(raw).expect_err(raw);
4898 assert!(
4899 err.contains("supported half-integer values"),
4900 "unexpected error for {raw:?}: {err}"
4901 );
4902 }
4903 }
4904
4905 #[test]
4906 fn parse_ps_k_promotes_underexpressive_cubic_basis() {
4907 let mut opts = BTreeMap::new();
4908 opts.insert("k".to_string(), "4".to_string());
4909 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=4");
4910 assert_eq!(internal, 2);
4911 assert_eq!(eff_degree, 3);
4912 assert!(!inferred);
4913
4914 opts.insert("k".to_string(), "6".to_string());
4915 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=6");
4916 assert_eq!(internal, 2);
4917 assert_eq!(eff_degree, 3);
4918 assert!(!inferred);
4919
4920 opts.insert("k".to_string(), "10".to_string());
4921 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=10");
4922 assert_eq!(internal, 6);
4923 assert_eq!(eff_degree, 3);
4924 assert!(!inferred);
4925 }
4926
4927 #[test]
4928 fn parse_ps_internal_knots_drops_degree_for_small_k() {
4929 let mut opts = BTreeMap::new();
4934 opts.insert("k".to_string(), "3".to_string());
4935 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=3");
4936 assert_eq!(eff_degree, 2);
4937 assert_eq!(internal, 0);
4938 assert!(!inferred);
4939
4940 opts.insert("k".to_string(), "2".to_string());
4943 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=2");
4944 assert_eq!(eff_degree, 1);
4945 assert_eq!(internal, 0);
4946 assert!(!inferred);
4947
4948 opts.insert("k".to_string(), "1".to_string());
4952 let err = parse_ps_internal_knots(&opts, 3, 20)
4953 .expect_err("k=1 is below the irreducible spline floor");
4954 assert!(err.contains("requires k >= 2"), "unexpected error: {err}");
4955
4956 opts.insert("k".to_string(), "4".to_string());
4959 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=4");
4960 assert_eq!(eff_degree, 3);
4961 assert_eq!(internal, 2);
4962 assert!(!inferred);
4963 }
4964
4965 #[test]
4966 fn factor_smooth_marginal_degree_reduces_for_small_k() {
4967 let ds = factor_dataset();
4968 let col_map = ds.column_map();
4969
4970 for (k, expected_degree) in [(3usize, 2usize), (2usize, 1usize)] {
4971 let parsed =
4972 parse_formula(&format!("y ~ s(x, g, bs=fs, k={k})")).expect("parse factor smooth");
4973 let mut notes = Vec::new();
4974 let terms = build_termspec(
4975 &parsed.terms,
4976 &ds,
4977 &col_map,
4978 &mut notes,
4979 &gam_runtime::resource::ResourcePolicy::default_library(),
4980 )
4981 .unwrap_or_else(|err| panic!("fs k={k} should degree-reduce, got: {err:?}"));
4982 let SmoothBasisSpec::FactorSmooth { spec } = &terms.smooth_terms[0].basis else {
4983 panic!(
4984 "expected factor smooth, got {:?}",
4985 terms.smooth_terms[0].basis
4986 );
4987 };
4988 assert_eq!(spec.marginal.degree, expected_degree);
4989 assert!(
4990 spec.marginal.penalty_order <= spec.marginal.degree,
4991 "penalty_order {} must be clamped to degree {}",
4992 spec.marginal.penalty_order,
4993 spec.marginal.degree
4994 );
4995 let basis_size = match spec.marginal.knotspec {
4996 BSplineKnotSpec::Generate {
4997 num_internal_knots, ..
4998 } => num_internal_knots + spec.marginal.degree + 1,
4999 BSplineKnotSpec::Automatic {
5000 num_internal_knots: Some(num_internal_knots),
5001 ..
5002 } => num_internal_knots + spec.marginal.degree + 1,
5003 ref other => panic!("unexpected factor-smooth knotspec: {other:?}"),
5004 };
5005 assert_eq!(basis_size, k);
5006 }
5007 }
5008
5009 fn ternary_factor_dataset() -> Dataset {
5012 let rows = (0..120)
5013 .map(|i| {
5014 let x = (i % 3) as f64;
5015 let g = (i % 2) as f64;
5016 vec![x + g, x, g]
5017 })
5018 .collect::<Vec<_>>();
5019 Dataset {
5020 headers: vec!["y".into(), "x".into(), "g".into()],
5021 values: Array2::from_shape_vec(
5022 (rows.len(), 3),
5023 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5024 )
5025 .expect("rectangular ternary factor test data"),
5026 schema: DataSchema {
5027 columns: vec![
5028 SchemaColumn {
5029 name: "y".into(),
5030 kind: ColumnKindTag::Continuous,
5031 levels: vec![],
5032 },
5033 SchemaColumn {
5034 name: "x".into(),
5035 kind: ColumnKindTag::Continuous,
5036 levels: vec![],
5037 },
5038 SchemaColumn {
5039 name: "g".into(),
5040 kind: ColumnKindTag::Categorical,
5041 levels: vec!["a".into(), "b".into()],
5042 },
5043 ],
5044 },
5045 column_kinds: vec![
5046 ColumnKindTag::Continuous,
5047 ColumnKindTag::Continuous,
5048 ColumnKindTag::Categorical,
5049 ],
5050 }
5051 }
5052
5053 #[test]
5054 fn univariate_cr_smooth_caps_knots_to_data_support() {
5055 let ds = continuous_dataset(
5061 &["y", "x"],
5062 (0..90)
5063 .map(|i| vec![(i % 3) as f64, (i % 3) as f64])
5064 .collect(),
5065 );
5066 let col_map = ds.column_map();
5067 let parsed = parse_formula("y ~ s(x, bs=cr, k=10)").expect("parse cr smooth");
5068 let mut notes = Vec::new();
5069 let terms = build_termspec(
5070 &parsed.terms,
5071 &ds,
5072 &col_map,
5073 &mut notes,
5074 &gam_runtime::resource::ResourcePolicy::default_library(),
5075 )
5076 .expect("cr k=10 must cap to data support instead of erroring");
5077 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
5078 panic!("expected BSpline1D for s(x, bs=cr)");
5079 };
5080 let BSplineKnotSpec::NaturalCubicRegression { knots } = &spec.knotspec else {
5081 panic!("expected cr knotspec, got {:?}", spec.knotspec);
5082 };
5083 assert_eq!(knots.len(), 3, "cr basis not capped to 3 distinct values");
5085 assert_eq!(knots.as_slice().unwrap(), &[0.0, 1.0, 2.0]);
5086 assert!(
5088 notes.iter().any(|n| n.contains("data-support cap")),
5089 "cap not reported in inference notes: {notes:?}"
5090 );
5091 }
5092
5093 #[test]
5094 fn univariate_cr_smooth_binary_covariate_degrades_to_bspline() {
5095 let ds = continuous_dataset(
5099 &["y", "x"],
5100 (0..80)
5101 .map(|i| vec![(i % 2) as f64, (i % 2) as f64])
5102 .collect(),
5103 );
5104 let col_map = ds.column_map();
5105 let parsed = parse_formula("y ~ s(x, bs=cr, k=10)").expect("parse cr smooth");
5106 let mut notes = Vec::new();
5107 let terms = build_termspec(
5108 &parsed.terms,
5109 &ds,
5110 &col_map,
5111 &mut notes,
5112 &gam_runtime::resource::ResourcePolicy::default_library(),
5113 )
5114 .expect("binary cr must degrade to B-spline instead of erroring");
5115 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
5116 panic!("expected BSpline1D for s(x, bs=cr)");
5117 };
5118 assert!(
5119 !matches!(
5120 spec.knotspec,
5121 BSplineKnotSpec::NaturalCubicRegression { .. }
5122 ),
5123 "binary covariate must NOT build a cr basis, got {:?}",
5124 spec.knotspec
5125 );
5126 assert!(
5127 notes
5128 .iter()
5129 .any(|n| n.contains("Degraded to the linear B-spline")),
5130 "degradation not reported in inference notes: {notes:?}"
5131 );
5132 }
5133
5134 #[test]
5135 fn sz_factor_smooth_caps_cr_marginal_to_data_support() {
5136 let ds = ternary_factor_dataset();
5140 let col_map = ds.column_map();
5141 let parsed = parse_formula("y ~ s(x, g, bs=sz, k=10)").expect("parse sz factor smooth");
5142 let mut notes = Vec::new();
5143 let terms = build_termspec(
5144 &parsed.terms,
5145 &ds,
5146 &col_map,
5147 &mut notes,
5148 &gam_runtime::resource::ResourcePolicy::default_library(),
5149 )
5150 .expect("sz k=10 must cap the cr marginal instead of erroring");
5151 let SmoothBasisSpec::FactorSmooth { spec } = &terms.smooth_terms[0].basis else {
5152 panic!("expected FactorSmooth for s(x, g, bs=sz)");
5153 };
5154 let BSplineKnotSpec::NaturalCubicRegression { knots } = &spec.marginal.knotspec else {
5155 panic!(
5156 "expected cr marginal knotspec, got {:?}",
5157 spec.marginal.knotspec
5158 );
5159 };
5160 assert_eq!(
5161 knots.len(),
5162 3,
5163 "sz cr marginal not capped to 3 distinct values"
5164 );
5165 assert_eq!(knots.as_slice().unwrap(), &[0.0, 1.0, 2.0]);
5166 }
5167
5168 fn factor_dataset_l3() -> Dataset {
5179 let rows = (0..30)
5181 .map(|i| {
5182 let x = i as f64 / 29.0;
5183 let g = (i % 3) as f64;
5184 vec![x + g, x, g]
5185 })
5186 .collect::<Vec<_>>();
5187 Dataset {
5188 headers: vec!["y".into(), "x".into(), "g".into()],
5189 values: Array2::from_shape_vec(
5190 (rows.len(), 3),
5191 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5192 )
5193 .expect("rectangular L=3 factor test data"),
5194 schema: DataSchema {
5195 columns: vec![
5196 SchemaColumn {
5197 name: "y".into(),
5198 kind: ColumnKindTag::Continuous,
5199 levels: vec![],
5200 },
5201 SchemaColumn {
5202 name: "x".into(),
5203 kind: ColumnKindTag::Continuous,
5204 levels: vec![],
5205 },
5206 SchemaColumn {
5207 name: "g".into(),
5208 kind: ColumnKindTag::Categorical,
5209 levels: vec!["a".into(), "b".into(), "c".into()],
5210 },
5211 ],
5212 },
5213 column_kinds: vec![
5214 ColumnKindTag::Continuous,
5215 ColumnKindTag::Continuous,
5216 ColumnKindTag::Categorical,
5217 ],
5218 }
5219 }
5220
/// Regression test for #1457: adding the bare factor to a factor `by=` smooth
/// (`s(x, by=g) + g`) must not add a second `g` random-effect block. Both
/// formulas must yield exactly one block named `g`.
#[test]
fn factor_by_smooth_plus_bare_categorical_does_not_duplicate_factor_block() {
    let ds = factor_dataset_l3();
    let col_map = ds.column_map();
    let policy = ResourcePolicy::default_library();

    // Counts the random-effect terms named `g` that `formula` builds into.
    let count_g_blocks = |formula: &str| -> usize {
        let parsed = parse_formula(formula).expect("parse by-smooth formula");
        let mut notes = Vec::new();
        build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
            .unwrap_or_else(|err| panic!("`{formula}` must build, got: {err:?}"))
            .random_effect_terms
            .iter()
            .filter(|block| block.name == "g")
            .count()
    };

    let by_only = count_g_blocks("y ~ s(x, by=g, k=10)");
    assert_eq!(
        by_only, 1,
        "`y ~ s(x, by=g)` must produce exactly one `g` design block"
    );

    let by_plus_bare = count_g_blocks("y ~ s(x, by=g, k=10) + g");
    assert_eq!(
        by_plus_bare, 1,
        "`y ~ s(x, by=g) + g` must collapse to ONE `g` block (#1457): the bare \
         `+ g` already owns the factor's level offsets, so the `by=` branch \
         must not add a second, treatment-coded main effect"
    );

    assert_eq!(
        by_plus_bare, by_only,
        "the bare `+ g` collision must add zero extra `g` blocks (#1457)"
    );
}
5270
/// The plural `periods` / `origins` option names must be accepted for a
/// two-axis periodic `boundary` list and parsed element-wise.
#[test]
fn parse_tensor_periods_and_origins_aliases() {
    let opts: BTreeMap<String, String> = [
        ("boundary", "['periodic', 'periodic']"),
        ("periods", "[7, 24]"),
        ("origins", "[0, -12]"),
    ]
    .iter()
    .map(|&(key, value)| (key.to_string(), value.to_string()))
    .collect();

    let axes = parse_periodic_axes(&opts, 2).expect("axes");
    let periods = parse_periods(&opts, &axes).expect("periods");
    let origins = parse_period_origins(&opts, &axes).expect("origins");

    assert_eq!(axes, vec![true, true]);
    assert_eq!(periods, vec![Some(7.0), Some(24.0)]);
    assert_eq!(origins, vec![Some(0.0), Some(-12.0)]);
}
5287
/// `te(theta, h, periodic=[0], ..., k=[9,5])` must honor the per-margin `k`
/// list: the periodic `theta` margin gets 9 basis functions and `h` gets 5.
#[test]
fn tensor_smooth_honors_per_margin_k_list() {
    let ds = continuous_dataset(
        &["y", "theta", "h"],
        (0..20)
            .map(|i| {
                let theta = std::f64::consts::TAU * i as f64 / 20.0;
                let h = -1.0 + 2.0 * (i % 5) as f64 / 4.0;
                vec![theta.cos() + h, theta, h]
            })
            .collect(),
    );
    let parsed = parse_formula(
        "y ~ te(theta, h, periodic=[0], period=[2*pi, None], origin=[0, None], k=[9,5])",
    )
    .expect("parse tensor formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let terms = build_termspec(
        &parsed.terms,
        &ds,
        &col_map,
        &mut notes,
        &ResourcePolicy::default_library(),
    )
    .expect("build tensor terms");
    let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
        panic!("expected tensor B-spline");
    };
    let dims = spec
        .marginalspecs
        .iter()
        .map(|m| match m.knotspec {
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => num_internal_knots + m.degree + 1,
            BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
            // Name the offending variant so a routing regression is diagnosable
            // from the failure message alone.
            _ => panic!("unexpected tensor marginal knotspec: {:?}", m.knotspec),
        })
        .collect::<Vec<_>>();
    assert_eq!(dims, vec![9, 5]);
}
5333
/// The `k_<margin>` aliases (`k_x`, `k_y`) must set the per-margin tensor
/// basis sizes just like a `k=[..]` list would.
#[test]
fn tensor_smooth_honors_per_margin_k_axis_aliases() {
    let mut rows = Vec::with_capacity(12);
    for i in 0..12 {
        let t = i as f64 / 11.0;
        rows.push(vec![t, t, 1.0 - t]);
    }
    let ds = continuous_dataset(&["resp", "x", "y"], rows);

    let sizes = tensor_margin_basis_sizes(&ds, "resp ~ te(x, y, k_x=9, k_y=5)");
    assert_eq!(
        sizes,
        vec![9, 5],
        "k_<margin> aliases should materialize requested per-margin values"
    );
}
5351
/// `te(x, b, k=[5, 2])` with a binary `b` margin: the continuous `x` margin
/// stays cubic with 5 basis functions. The binary margin falls back to
/// degree 1 with 2 basis functions and a penalty order in `1..=degree`.
#[test]
fn tensor_smooth_low_cardinality_axis_falls_back_to_lower_degree_basis() {
    let ds = continuous_dataset(
        &["y", "x", "b"],
        (0..40)
            .map(|i| {
                let x = i as f64 / 39.0;
                let b = (i % 2) as f64;
                vec![x.sin() + 0.5 * b, x, b]
            })
            .collect(),
    );
    let parsed = parse_formula("y ~ te(x, b, k=[5, 2])").expect("parse tensor with k=[5,2]");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let terms = build_termspec(
        &parsed.terms,
        &ds,
        &col_map,
        &mut notes,
        &ResourcePolicy::default_library(),
    )
    .expect("build tensor with binary margin");
    let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
        panic!("expected tensor B-spline for te(x, b)");
    };
    let continuous = &spec.marginalspecs[0];
    let binary = &spec.marginalspecs[1];
    assert_eq!(continuous.degree, 3);
    assert_eq!(binary.degree, 1);
    assert!(
        (1..=binary.degree).contains(&binary.penalty_order),
        "binary margin penalty_order {} must satisfy 1 <= order <= degree={}",
        binary.penalty_order,
        binary.degree
    );
    // Basis dimension implied by a marginal's knot specification.
    let basis_size = |m: &BSplineBasisSpec| match m.knotspec {
        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
        BSplineKnotSpec::Generate {
            num_internal_knots, ..
        } => num_internal_knots + m.degree + 1,
        BSplineKnotSpec::Automatic {
            num_internal_knots: Some(n),
            ..
        } => n + m.degree + 1,
        BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
        _ => panic!("unexpected tensor marginal knotspec: {:?}", m.knotspec),
    };
    assert_eq!(basis_size(continuous), 5);
    assert_eq!(basis_size(binary), 2);
}
5414
/// A single uniform `k=5` on `te(x, b)` with a binary `b` must be capped per
/// margin rather than rejected. `b` becomes a degree-1 margin with 2 basis
/// functions, while `x` keeps the requested 5.
#[test]
fn tensor_smooth_uniform_k_is_capped_to_a_low_cardinality_margins_distinct_values() {
    // Basis dimension implied by a marginal's knot specification.
    fn basis_size(m: &BSplineBasisSpec) -> usize {
        match &m.knotspec {
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => num_internal_knots + m.degree + 1,
            BSplineKnotSpec::Automatic {
                num_internal_knots: Some(n),
                ..
            } => n + m.degree + 1,
            BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
            other => panic!("unexpected tensor marginal knotspec: {other:?}"),
        }
    }

    let mut rows = Vec::with_capacity(40);
    for i in 0..40 {
        let x = i as f64 / 39.0;
        let b = (i % 2) as f64;
        rows.push(vec![x.sin() + 0.5 * b, x, b]);
    }
    let ds = continuous_dataset(&["y", "x", "b"], rows);

    let parsed = parse_formula("y ~ te(x, b, k=5)").expect("parse tensor with uniform k=5");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("uniform k=5 must auto-cap the binary margin instead of erroring");
    let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
        panic!("expected tensor B-spline for te(x, b)");
    };

    let binary = &spec.marginalspecs[1];
    assert_eq!(basis_size(binary), 2);
    assert_eq!(binary.degree, 1);
    assert_eq!(basis_size(&spec.marginalspecs[0]), 5);
}
5468
/// `te(x1, x2, bs=c('tp','tp'), k=c(5,5))` must route to a B-spline tensor
/// whose two margins each carry 5 basis functions.
#[test]
fn tensor_all_tp_margins_with_per_margin_k_routes_to_bspline_tensor() {
    let ds = continuous_dataset(
        &["y", "x1", "x2"],
        (0..32)
            .map(|i| {
                let t = i as f64 / 31.0;
                vec![t.sin(), t, 1.0 - t]
            })
            .collect(),
    );
    let parsed =
        parse_formula("y ~ te(x1, x2, bs=c('tp','tp'), k=c(5,5))").expect("parse tensor");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let terms = build_termspec(
        &parsed.terms,
        &ds,
        &col_map,
        &mut notes,
        &ResourcePolicy::default_library(),
    )
    .expect("build tensor terms with per-margin k");
    let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
        panic!(
            "expected B-spline tensor when k=c(5,5) is supplied with bs=c('tp','tp'), got {:?}",
            terms.smooth_terms[0].basis
        );
    };
    let dims = spec
        .marginalspecs
        .iter()
        .map(|m| match m.knotspec {
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => num_internal_knots + m.degree + 1,
            // Name the offending variant so a failure is diagnosable from the
            // message alone.
            _ => panic!("unexpected tensor marginal knotspec: {:?}", m.knotspec),
        })
        .collect::<Vec<_>>();
    assert_eq!(dims, vec![5, 5]);
}
5518
/// `te(x1, x2, bs=c('tp','tp'))` without per-margin `k` must build an
/// anisotropic B-spline tensor with one margin per axis, not silently fall
/// back to an isotropic thin-plate smooth.
#[test]
fn tensor_all_tp_margins_without_per_margin_k_builds_anisotropic_tensor() {
    let mut rows = Vec::with_capacity(32);
    for i in 0..32 {
        let t = i as f64 / 31.0;
        rows.push(vec![t.sin(), t, 1.0 - t]);
    }
    let ds = continuous_dataset(&["y", "x1", "x2"], rows);

    let parsed = parse_formula("y ~ te(x1, x2, bs=c('tp','tp'))").expect("parse tensor");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build tensor terms without per-margin k");

    let basis = &terms.smooth_terms[0].basis;
    let SmoothBasisSpec::TensorBSpline { spec, .. } = basis else {
        panic!(
            "te(...,bs=c('tp','tp')) must route to an anisotropic tensor product, not a \
             silent isotropic thin-plate substitution; got {:?}",
            basis
        );
    };
    assert_eq!(
        spec.marginalspecs.len(),
        2,
        "tp tensor must carry one penalized B-spline margin per axis"
    );
}
5561
/// An explicit `k=10` must not be clamped even with only 12 rows and four
/// other default smooths in the formula. The cubic B-spline keeps
/// 6 internal knots (6 + degree 3 + 1 = 10 basis functions).
#[test]
fn explicit_basis_sizes_are_not_small_n_clamped() {
    let ds = continuous_dataset(
        &["y", "x1", "x2", "x3", "x4", "x5"],
        (0..12)
            .map(|i| {
                let x = i as f64 / 11.0;
                vec![x.sin(), x, x * x, x + 0.1, 1.0 - x, (2.0 * x).sin()]
            })
            .collect(),
    );
    let parsed = parse_formula("y ~ s(x1, k=10) + s(x2) + s(x3) + s(x4) + s(x5)")
        .expect("parse multi-smooth formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let terms = build_termspec(
        &parsed.terms,
        &ds,
        &col_map,
        &mut notes,
        &ResourcePolicy::default_library(),
    )
    .expect("build multi-smooth terms");
    let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
        panic!("expected first smooth to be B-spline");
    };
    // Print the actual knotspec on failure; a bare `matches!` assert would only
    // report "assertion failed".
    assert!(
        matches!(
            &spec.knotspec,
            BSplineKnotSpec::Generate {
                num_internal_knots: 6,
                ..
            }
        ),
        "explicit k=10 must keep 6 internal knots at n=12, got {:?}",
        spec.knotspec
    );
}
5596
/// An explicit `centers=3` on a Duchon smooth must be kept as a farthest-point
/// strategy with exactly 3 centers, not adjusted for the small sample (12 rows).
#[test]
fn explicit_duchon_centers_are_not_small_n_bumped() {
    let ds = continuous_dataset(
        &["y", "x1", "x2", "x3", "x4", "x5"],
        (0..12)
            .map(|i| {
                let x = i as f64 / 11.0;
                vec![x.sin(), x, x * x, x + 0.1, 1.0 - x, (2.0 * x).sin()]
            })
            .collect(),
    );
    let parsed = parse_formula("y ~ duchon(x1, centers=3) + s(x2) + s(x3) + s(x4) + s(x5)")
        .expect("parse multi-smooth formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let terms = build_termspec(
        &parsed.terms,
        &ds,
        &col_map,
        &mut notes,
        &ResourcePolicy::default_library(),
    )
    .expect("build multi-smooth terms");
    let SmoothBasisSpec::Duchon { spec, .. } = &terms.smooth_terms[0].basis else {
        panic!("expected first smooth to be Duchon");
    };
    // Print the built basis on failure; a bare `matches!` assert would only
    // report "assertion failed".
    assert!(
        matches!(
            spec.center_strategy,
            CenterStrategy::FarthestPoint { num_centers: 3 }
        ),
        "explicit centers=3 must be kept as-is, got {:?}",
        terms.smooth_terms[0].basis
    );
}
5634
/// Replicating every row twelve times must leave the inferred tensor basis
/// width unchanged. The cap must follow the distinct coordinate support, not
/// the raw row count.
#[test]
fn inferred_tensor_basis_cap_uses_coordinate_support_not_duplicate_rows() {
    // 50 x 16 grid of distinct (theta, h) coordinates, theta-major.
    let unique_rows: Vec<Vec<f64>> = (0..50)
        .flat_map(|i| {
            let theta = i as f64 / 50.0;
            (0..16).map(move |j| {
                let h = -1.0 + 2.0 * (j as f64) / 15.0;
                vec![theta.cos() + h, theta, h]
            })
        })
        .collect();
    // The same grid laid out twelve times back-to-back.
    let repeated_rows: Vec<Vec<f64>> = (0..12)
        .flat_map(|_| unique_rows.iter().cloned())
        .collect();

    let unique = continuous_dataset(&["y", "theta", "h"], unique_rows);
    let repeated = continuous_dataset(&["y", "theta", "h"], repeated_rows);

    assert_eq!(
        inferred_tensor_basis_product(&unique),
        inferred_tensor_basis_product(&repeated),
        "duplicating existing tensor coordinates must not inflate inferred basis width"
    );
}
5662
/// A default `te(x1, x2, x3)` (no explicit `k`) must keep the product of its
/// margin basis sizes at or below 216 (= 6^3). This must hold for both n = 60
/// and n = 2000, so the inferred basis does not grow toward the sample size.
#[test]
fn inferred_three_dim_tensor_basis_stays_bounded_for_reml_selection() {
    // Builds the default 3-D tensor on `n` synthetic rows and returns the
    // product of its margin basis sizes.
    let make = |n: usize| -> usize {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let f = i as f64 / n as f64;
            rows.push(vec![f.sin(), f, (2.0 * f).cos(), (3.0 * f) % 1.0]);
        }
        let ds = continuous_dataset(&["y", "x1", "x2", "x3"], rows);
        let parsed = parse_formula("y ~ te(x1, x2, x3)").expect("parse 3-D tensor");
        let col_map = ds.column_map();
        let mut notes = Vec::new();
        let terms = build_termspec(
            &parsed.terms,
            &ds,
            &col_map,
            &mut notes,
            &ResourcePolicy::default_library(),
        )
        .expect("build 3-D tensor termspec");
        let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
            panic!("expected tensor smooth");
        };
        spec.marginalspecs
            .iter()
            .map(|m| match m.knotspec {
                BSplineKnotSpec::Generate {
                    num_internal_knots, ..
                } => num_internal_knots + m.degree + 1,
                BSplineKnotSpec::Automatic {
                    num_internal_knots: Some(num_internal_knots),
                    ..
                } => num_internal_knots + m.degree + 1,
                BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
                _ => panic!("unexpected tensor margin knotspec: {:?}", m.knotspec),
            })
            .product()
    };

    // Evaluate each size once instead of re-running the whole build inside the
    // failure message.
    let small_n_product = make(60);
    assert!(
        small_n_product <= 216,
        "3-D te at small n must stay near the mgcv te default, got {}",
        small_n_product
    );
    let large_n_product = make(2000);
    assert!(
        large_n_product <= 216,
        "3-D te at large n must not blow ∏k toward the data size, got {}",
        large_n_product
    );
}
5724
/// `parse_bspline_boundary_conditions` must reject a non-zero anchor on either
/// side, naming the side and value in the error. It must also accept the
/// `start_bc` / `end_bc` aliases, where `start_bc=clamped, end_bc=zero` yields
/// a clamped left end and a right anchor at 0.
#[test]
fn parse_bspline_boundary_conditions_and_side_selector() {
    // Builds an owned option map from borrowed key/value pairs.
    fn options(pairs: &[(&str, &str)]) -> BTreeMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    let left_opts = options(&[
        ("boundary_conditions", "anchored"),
        ("side", "left"),
        ("anchor", "2.5"),
    ]);
    let left_err = parse_bspline_boundary_conditions(&left_opts)
        .expect_err("non-zero left anchor must be rejected")
        .to_string();
    assert!(
        left_err.contains("left") && left_err.contains("2.5"),
        "rejection should name the affected side and value: {}",
        left_err
    );

    let right_opts = options(&[
        ("start_bc", "clamped"),
        ("end_bc", "zero"),
        ("right_anchor", "-1.0"),
    ]);
    let right_err = parse_bspline_boundary_conditions(&right_opts)
        .expect_err("non-zero right anchor must be rejected")
        .to_string();
    assert!(
        right_err.contains("right") && right_err.contains("-1"),
        "rejection should name the affected side and value: {}",
        right_err
    );

    let alias_opts = options(&[("start_bc", "clamped"), ("end_bc", "zero")]);
    let parsed = parse_bspline_boundary_conditions(&alias_opts).expect("boundary conditions");
    assert!(matches!(
        parsed.left,
        BSplineEndpointBoundaryCondition::Clamped
    ));
    assert!(matches!(
        parsed.right,
        BSplineEndpointBoundaryCondition::Anchored { value } if value.abs() < 1e-12
    ));
}
5774
/// Interaction-only `x:g` (no `x` main effect) keeps every level of `g`.
/// It expands into one gated slope column per level, each equal to `x` on
/// that level's rows and `0` elsewhere.
#[test]
fn categorical_by_numeric_interaction_expands_treatment_coded_cells() {
    let ds = factor_dataset();
    let parsed = parse_formula("y ~ x:g").expect("parse `y ~ x:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("factor-aware `x:g` interaction must build, not error");

    assert_eq!(
        terms.linear_terms.len(),
        2,
        "interaction-only `x:g` keeps ALL factor levels (full dummy coding): one slope column per group"
    );

    let x_col = *col_map.get("x").expect("x column");
    let g_col = *col_map.get("g").expect("g column");

    let mut seen_bits = std::collections::HashSet::new();
    for term in terms.linear_terms.iter() {
        assert!(
            term.is_interaction(),
            "the categorical-by-numeric cell is a Wilkinson-Rogers interaction"
        );
        assert_eq!(term.feature_cols, vec![x_col]);
        assert_eq!(term.categorical_levels.len(), 1);
        let (gate_col, gate_bits) = term.categorical_levels[0];
        assert_eq!(gate_col, g_col);
        assert!(seen_bits.insert(gate_bits), "each level appears once");

        let column = term
            .realized_design_column(ds.values.view())
            .expect("realize cell column");
        let n_rows = ds.values.nrows();
        assert_eq!(column.len(), n_rows);
        for row in 0..n_rows {
            let x = ds.values[[row, x_col]];
            let g = ds.values[[row, g_col]];
            // The cell is `x` inside its gated level and zero outside it.
            let expected = if g.to_bits() != gate_bits { 0.0 } else { x };
            assert!(
                (column[row] - expected).abs() < 1e-12,
                "row {row}: g={g}, x={x}, expected {expected}, got {}",
                column[row]
            );
        }
    }
    for level in [0.0_f64, 1.0_f64] {
        assert!(seen_bits.contains(&level.to_bits()));
    }
}
5846
/// With the `x` main effect present, `x + x:g` treatment-codes the interaction.
/// Exactly one gated cell remains, for level `1.0`; the reference level is
/// dropped.
#[test]
fn categorical_by_numeric_interaction_keeps_treatment_coding_with_parent() {
    let ds = factor_dataset();
    let parsed = parse_formula("y ~ x + x:g").expect("parse `y ~ x + x:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("`x + x:g` must build");

    let x_col = *col_map.get("x").expect("x column");
    let g_col = *col_map.get("g").expect("g column");

    let interaction_cells = terms
        .linear_terms
        .iter()
        .filter(|cell| cell.is_interaction())
        .collect::<Vec<_>>();
    assert_eq!(
        interaction_cells.len(),
        1,
        "with `x` present, `x:g` is treatment-coded → one cell (reference dropped)"
    );

    let cell = interaction_cells[0];
    assert_eq!(cell.feature_cols, vec![x_col]);
    assert_eq!(cell.categorical_levels.len(), 1);
    let (gate_col, gate_bits) = cell.categorical_levels[0];
    assert_eq!(gate_col, g_col);
    assert_eq!(gate_bits, 1.0_f64.to_bits());
}
5890
/// `y ~ f:g` for two factors (3 x 2 levels) must expand into the saturated
/// indicator cross cells minus the reference cell `f0:g0`, which the
/// intercept absorbs. That gives five columns, each `1` exactly on the rows
/// of its cell.
#[test]
fn categorical_by_categorical_interaction_expands_full_cross_cells() {
    let n = 30usize;
    let mut rows = Vec::with_capacity(n);
    for i in 0..n {
        let y = (i as f64).sin();
        // `f` cycles through 3 levels and `g` through 2, so every one of the
        // 6 (f, g) combinations occurs in each run of 6 rows.
        let f = (i % 3) as f64;
        let g = (i % 2) as f64;
        rows.push(vec![y, f, g]);
    }
    let values = Array2::from_shape_vec(
        (n, 3),
        rows.into_iter().flat_map(|row| row.into_iter()).collect(),
    )
    .expect("rectangular cross-factor data");
    let ds = Dataset {
        headers: vec!["y".into(), "f".into(), "g".into()],
        values,
        schema: DataSchema {
            columns: vec![
                SchemaColumn {
                    name: "y".into(),
                    kind: ColumnKindTag::Continuous,
                    levels: vec![],
                },
                SchemaColumn {
                    name: "f".into(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["f0".into(), "f1".into(), "f2".into()],
                },
                SchemaColumn {
                    name: "g".into(),
                    kind: ColumnKindTag::Categorical,
                    levels: vec!["g0".into(), "g1".into()],
                },
            ],
        },
        column_kinds: vec![
            ColumnKindTag::Continuous,
            ColumnKindTag::Categorical,
            ColumnKindTag::Categorical,
        ],
    };

    let parsed = parse_formula("y ~ f:g").expect("parse `y ~ f:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let terms = build_termspec(
        &parsed.terms,
        &ds,
        &col_map,
        &mut notes,
        &ResourcePolicy::default_library(),
    )
    .expect("factor-by-factor `f:g` interaction must build, not error");

    assert_eq!(
        terms.linear_terms.len(),
        5,
        "saturated 3*2 = 6 cross cells minus one reference cell (f0:g0) = 5"
    );

    let f_col = *col_map.get("f").expect("f column");
    let g_col = *col_map.get("g").expect("g column");
    let f0 = 0.0_f64.to_bits();
    let g0 = 0.0_f64.to_bits();
    let mut emitted = std::collections::HashSet::new();
    for term in &terms.linear_terms {
        // A pure factor-by-factor cell has no numeric feature, only two gates.
        assert!(term.feature_cols.is_empty());
        assert_eq!(term.categorical_levels.len(), 2);
        let mut gates = std::collections::HashMap::new();
        for &(col, bits) in &term.categorical_levels {
            gates.insert(col, bits);
        }
        let f_bits = *gates.get(&f_col).expect("f gate present");
        let g_bits = *gates.get(&g_col).expect("g gate present");
        assert!(
            !(f_bits == f0 && g_bits == g0),
            "the reference cell f0:g0 must be absorbed by the intercept, not emitted"
        );
        emitted.insert((f_bits, g_bits));

        let column = term
            .realized_design_column(ds.values.view())
            .expect("realize cross cell");
        for row in 0..n {
            let f = ds.values[[row, f_col]];
            let g = ds.values[[row, g_col]];
            let expected = if f.to_bits() == f_bits && g.to_bits() == g_bits {
                1.0
            } else {
                0.0
            };
            assert!(
                (column[row] - expected).abs() < 1e-12,
                "row {row}: expected {expected}, got {}",
                column[row]
            );
        }
        assert!(
            column.iter().any(|&v| v == 1.0),
            "each cross cell must be observed in the data"
        );
    }
    let f_levels = [0.0_f64.to_bits(), 1.0_f64.to_bits(), 2.0_f64.to_bits()];
    let g_levels = [0.0_f64.to_bits(), 1.0_f64.to_bits()];
    for &fb in &f_levels {
        for &gb in &g_levels {
            if fb == f0 && gb == g0 {
                continue;
            }
            assert!(
                emitted.contains(&(fb, gb)),
                "saturated cross cell (f={}, g={}) must be present",
                f64::from_bits(fb),
                f64::from_bits(gb)
            );
        }
    }
}
6026}