1use std::collections::{BTreeMap, BTreeSet, HashMap};
8use std::path::PathBuf;
9
10use ndarray::{Array2, ArrayView1};
11
12use crate::basis::{
13 BSplineBasisSpec, BSplineBoundaryConditions, BSplineEndpointBoundaryCondition,
14 BSplineIdentifiability, BSplineKnotSpec, CenterCountRequest, CenterStrategy,
15 ConstantCurvatureBasisSpec, ConstantCurvatureIdentifiability, DuchonBasisSpec,
16 DuchonNullspaceOrder, DuchonOperatorPenaltySpec, MaternBasisSpec, MaternIdentifiability,
17 MaternNu, MeasureJetBasisSpec, MeasureJetIdentifiability, OneDimensionalBoundary,
18 SpatialIdentifiability, SphereMethod, SphereWahbaKernel, SphericalSplineBasisSpec,
19 SphericalSplineIdentifiability, ThinPlateBasisSpec, auto_spatial_center_strategy,
20 default_num_centers, default_spatial_center_strategy, default_spherical_harmonic_degree,
21 plan_spatial_basis, thin_plate_penalty_order,
22};
23use crate::inference::formula_dsl::{
24 ParsedTerm, SmoothKind, option_bool, option_f64, option_f64_strict, option_usize,
25 option_usize_any, option_usize_any_strict, option_usize_strict, strip_quotes,
26};
27use crate::smooth::{
28 ByVarKind, FactorSmoothFlavour, FactorSmoothSpec, LinearCoefficientGeometry, LinearTermSpec,
29 RandomEffectTermSpec, ShapeConstraint, SmoothBasisSpec, SmoothTermSpec,
30 TensorBSplineIdentifiability, TensorBSplinePenaltyDecomposition, TensorBSplineSpec,
31 TermCollectionSpec,
32};
33use gam_problem::types::ColIdx;
34use gam_data::{ColumnKindTag, DataError, EncodedDataset as Dataset};
35use gam_runtime::resource::ResourcePolicy;
36
// Default B-spline degree; degree 3 is a cubic spline.
// NOTE(review): the consumers of these defaults are not visible in this chunk — confirm usage.
const DEFAULT_BSPLINE_DEGREE: usize = 3;

// Default penalty order (presumably the derivative/difference order of the smoothing penalty — confirm).
const DEFAULT_PENALTY_ORDER: usize = 2;

// Default basis dimension for cyclic smooths, judging by the name — confirm against the caller.
const CYCLIC_DEFAULT_BASIS_DIM: usize = 12;

// Default basis dimension for factor smooths, judging by the name — confirm against the caller.
const FACTOR_SMOOTH_DEFAULT_BASIS_DIM: usize = 10;

// Default chunk size for PCA processing (presumably rows per chunk — confirm).
const DEFAULT_PCA_CHUNK_SIZE: usize = 4096;
65
/// Errors produced while turning parsed formula terms into term specs.
///
/// Every variant except [`TermBuilderError::ColumnNotFound`] carries a
/// free-form `reason` that is printed verbatim by the `Display` impl.
#[derive(Clone, Debug)]
pub enum TermBuilderError {
    /// A column (or its metadata) could not be obtained. Also the target of
    /// every `DataError` conversion other than `DataError::ColumnNotFound`.
    MissingColumn { reason: String },
    /// Structured "column not found" error; the fields mirror
    /// `DataError::ColumnNotFound`, which is also used to render it.
    ColumnNotFound {
        /// Name that was looked up.
        name: String,
        /// Optional role of the column in the request.
        role: Option<String>,
        /// Column names that were available.
        available: Vec<String>,
        /// Candidate names similar to `name`.
        similar: Vec<String>,
        /// Forwarded unchanged to `DataError::ColumnNotFound`.
        tsv_hint: bool,
    },
    /// The requested combination of terms/options is inconsistent. This is
    /// also the variant produced when a bare `String` is converted.
    IncompatibleConfig { reason: String },
    /// An option value was rejected (e.g. an unparsable `shape=` value).
    InvalidOption { reason: String },
    /// The request names a feature or value that is not supported.
    UnsupportedFeature { reason: String },
    /// NOTE(review): not constructed in the visible chunk; presumably data
    /// that cannot support the requested term — confirm.
    DegenerateData { reason: String },
    /// The formula itself is structurally invalid for this code path.
    MalformedFormula { reason: String },
}
114
115impl std::fmt::Display for TermBuilderError {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 match self {
118 TermBuilderError::MissingColumn { reason }
119 | TermBuilderError::IncompatibleConfig { reason }
120 | TermBuilderError::InvalidOption { reason }
121 | TermBuilderError::UnsupportedFeature { reason }
122 | TermBuilderError::DegenerateData { reason }
123 | TermBuilderError::MalformedFormula { reason } => f.write_str(reason),
124 TermBuilderError::ColumnNotFound {
130 name,
131 role,
132 available,
133 similar,
134 tsv_hint,
135 } => {
136 let canonical = DataError::ColumnNotFound {
137 name: name.clone(),
138 role: role.clone(),
139 available: available.clone(),
140 similar: similar.clone(),
141 tsv_hint: *tsv_hint,
142 };
143 std::fmt::Display::fmt(&canonical, f)
144 }
145 }
146 }
147}
148
149impl From<TermBuilderError> for String {
150 fn from(err: TermBuilderError) -> String {
151 err.to_string()
152 }
153}
154
155impl From<String> for TermBuilderError {
162 fn from(reason: String) -> Self {
163 Self::IncompatibleConfig { reason }
164 }
165}
166
167impl From<DataError> for TermBuilderError {
174 fn from(err: DataError) -> Self {
175 match err {
176 DataError::ColumnNotFound {
177 name,
178 role,
179 available,
180 similar,
181 tsv_hint,
182 } => Self::ColumnNotFound {
183 name,
184 role,
185 available,
186 similar,
187 tsv_hint,
188 },
189 DataError::SchemaMismatch { reason }
190 | DataError::ParseError { reason }
191 | DataError::EncodingFailure { reason }
192 | DataError::EmptyInput { reason }
193 | DataError::InvalidValue { reason } => Self::MissingColumn { reason },
194 }
195 }
196}
197
198impl TermBuilderError {
200 #[inline]
201 fn missing_column(reason: impl Into<String>) -> Self {
202 TermBuilderError::MissingColumn {
203 reason: reason.into(),
204 }
205 }
206 #[inline]
207 fn incompatible_config(reason: impl Into<String>) -> Self {
208 TermBuilderError::IncompatibleConfig {
209 reason: reason.into(),
210 }
211 }
212 #[inline]
213 fn invalid_option(reason: impl Into<String>) -> Self {
214 TermBuilderError::InvalidOption {
215 reason: reason.into(),
216 }
217 }
218 #[inline]
219 fn unsupported_feature(reason: impl Into<String>) -> Self {
220 TermBuilderError::UnsupportedFeature {
221 reason: reason.into(),
222 }
223 }
224 #[inline]
225 fn degenerate_data(reason: impl Into<String>) -> Self {
226 TermBuilderError::DegenerateData {
227 reason: reason.into(),
228 }
229 }
230 #[inline]
231 fn malformed_formula(reason: impl Into<String>) -> Self {
232 TermBuilderError::MalformedFormula {
233 reason: reason.into(),
234 }
235 }
236}
237
238pub fn resolve_col(col_map: &HashMap<String, usize>, name: &str) -> Result<usize, DataError> {
249 col_map
250 .get(name)
251 .copied()
252 .ok_or_else(|| DataError::column_not_found(col_map, name, None))
253}
254
255pub fn resolve_role_col(
260 col_map: &HashMap<String, usize>,
261 name: &str,
262 role: &str,
263) -> Result<usize, DataError> {
264 col_map
265 .get(name)
266 .copied()
267 .ok_or_else(|| DataError::column_not_found(col_map, name, Some(role)))
268}
269
270fn encoded_levels_for_column(ds: &Dataset, col: ColIdx) -> Vec<(u64, String)> {
271 let mut seen = BTreeSet::<u64>::new();
272 for value in ds.values.column(col.get()) {
273 if value.is_finite() {
274 seen.insert(value.to_bits());
275 }
276 }
277 let schema_levels = ds
278 .schema
279 .columns
280 .get(col.get())
281 .map(|column| column.levels.as_slice())
282 .unwrap_or(&[]);
283 seen.into_iter()
284 .enumerate()
285 .map(|(idx, bits)| {
286 let fallback = format!("level{}", idx + 1);
287 let label = schema_levels.get(idx).cloned().unwrap_or(fallback);
288 (bits, label)
289 })
290 .collect()
291}
292
/// Returns a copy of `col_map` in which `alias` also resolves to `target_column`.
///
/// The alias is only added when `target_column` exists and `alias` is not
/// already a key; an existing entry under `alias` is never overwritten.
pub fn column_map_with_alias(
    col_map: &HashMap<String, usize>,
    alias: &str,
    target_column: &str,
) -> HashMap<String, usize> {
    let mut aliased = col_map.clone();
    match col_map.get(target_column) {
        Some(&target_idx) if !aliased.contains_key(alias) => {
            aliased.insert(alias.to_owned(), target_idx);
        }
        _ => {}
    }
    aliased
}
304
305pub fn build_termspec(
310 terms: &[ParsedTerm],
311 ds: &Dataset,
312 col_map: &HashMap<String, usize>,
313 inference_notes: &mut Vec<String>,
314 policy: &ResourcePolicy,
315) -> Result<TermCollectionSpec, TermBuilderError> {
316 let mut linear_terms = Vec::<LinearTermSpec>::new();
317 let mut random_terms = Vec::<RandomEffectTermSpec>::new();
318 let mut smooth_terms = Vec::<SmoothTermSpec>::new();
319 let smooth_coordinate_count = terms
320 .iter()
321 .map(|term| match term {
322 ParsedTerm::Smooth { vars, .. } => vars.len(),
323 _ => 0,
324 })
325 .sum::<usize>();
326
327 for t in terms {
328 match t {
329 ParsedTerm::Linear {
330 name,
331 explicit,
332 coefficient_min,
333 coefficient_max,
334 } => {
335 let col = resolve_col(col_map, name)?;
336 let auto_kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
337 TermBuilderError::missing_column(format!(
338 "internal column-kind lookup failed for '{name}'"
339 ))
340 .to_string()
341 })?;
342 if *explicit {
343 linear_terms.push(LinearTermSpec {
344 name: name.clone(),
345 feature_col: col,
346 feature_cols: vec![col],
347 categorical_levels: vec![],
348 double_penalty: false,
351 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
352 coefficient_min: *coefficient_min,
353 coefficient_max: *coefficient_max,
354 });
355 } else {
356 match auto_kind {
357 ColumnKindTag::Continuous | ColumnKindTag::Binary => {
358 linear_terms.push(LinearTermSpec {
359 name: name.clone(),
360 feature_col: col,
361 feature_cols: vec![col],
362 categorical_levels: vec![],
363 double_penalty: false,
365 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
366 coefficient_min: *coefficient_min,
367 coefficient_max: *coefficient_max,
368 });
369 }
370 ColumnKindTag::Categorical => {
371 if coefficient_min.is_some() || coefficient_max.is_some() {
372 return Err(TermBuilderError::incompatible_config(format!(
373 "coefficient constraints are not supported for categorical auto-random-effect term '{name}'; use group({name}) or an unconstrained numeric term"
374 )));
375 }
376 random_terms.push(RandomEffectTermSpec {
377 name: name.clone(),
378 feature_col: col,
379 drop_first_level: false,
380 penalized: true,
381 frozen_levels: None,
382 });
383 }
384 }
385 }
386 }
387 ParsedTerm::BoundedLinear {
388 name,
389 min,
390 max,
391 prior,
392 } => {
393 let col = resolve_col(col_map, name)?;
394 let auto_kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
395 TermBuilderError::missing_column(format!(
396 "internal column-kind lookup failed for '{name}'"
397 ))
398 .to_string()
399 })?;
400 if !matches!(auto_kind, ColumnKindTag::Continuous | ColumnKindTag::Binary) {
401 return Err(TermBuilderError::incompatible_config(format!(
402 "bounded() currently supports only numeric columns, got categorical '{name}'"
403 )));
404 }
405 linear_terms.push(LinearTermSpec {
406 name: name.clone(),
407 feature_col: col,
408 feature_cols: vec![col],
409 categorical_levels: vec![],
410 double_penalty: false,
411 coefficient_geometry: LinearCoefficientGeometry::Bounded {
412 min: *min,
413 max: *max,
414 prior: prior.clone(),
415 },
416 coefficient_min: None,
417 coefficient_max: None,
418 });
419 }
420 ParsedTerm::RandomEffect { name } => {
421 let col = resolve_col(col_map, name)?;
422 random_terms.push(RandomEffectTermSpec {
423 name: name.clone(),
424 feature_col: col,
425 drop_first_level: false,
426 penalized: true,
427 frozen_levels: None,
428 });
429 }
430 ParsedTerm::Smooth {
431 label,
432 vars,
433 kind,
434 options,
435 } => {
436 let smooth_vars = vars.clone();
437 let by_name = options.get("by").cloned();
438 let cols = smooth_vars
448 .iter()
449 .map(|v| resolve_col(col_map, v))
450 .collect::<Result<Vec<_>, _>>()?;
451 let mut inner_options = options.clone();
452 inner_options.remove("by");
453 inner_options.remove("ordered");
457 let shape = match inner_options.remove("shape") {
463 None => ShapeConstraint::None,
464 Some(raw) => crate::smooth::parse_shape_constraint(&raw)
465 .map_err(TermBuilderError::invalid_option)?,
466 };
467 let inner_basis = build_smooth_basis(
468 *kind,
469 &smooth_vars,
470 &cols,
471 &inner_options,
472 ds,
473 inference_notes,
474 policy,
475 smooth_coordinate_count,
476 )?;
477 if let Some(by_name) = by_name {
478 let by_col = resolve_col(col_map, &by_name)?;
479 match ds.column_kinds.get(by_col).copied().ok_or_else(|| {
480 format!("internal column-kind lookup failed for by variable '{by_name}'")
481 })? {
482 ColumnKindTag::Categorical => {
483 let levels = encoded_levels_for_column(ds, ColIdx::new(by_col));
484 let penalized_group_owner_present =
497 terms.iter().any(|other| match other {
498 ParsedTerm::RandomEffect { name } => name == &by_name,
499 ParsedTerm::Linear {
500 name,
501 explicit: false,
502 ..
503 } if name == &by_name => col_map
504 .get(name)
505 .and_then(|c| ds.column_kinds.get(*c).copied())
506 .map(|kind| matches!(kind, ColumnKindTag::Categorical))
507 .unwrap_or(false),
508 _ => false,
509 });
510 if !random_terms.iter().any(|rt| rt.name == by_name)
521 && !penalized_group_owner_present
522 {
523 random_terms.push(RandomEffectTermSpec {
524 name: by_name.clone(),
525 feature_col: by_col,
526 drop_first_level: true,
527 penalized: false,
528 frozen_levels: None,
529 });
530 }
531 let frozen_levels: Vec<u64> =
536 levels.iter().map(|(bits, _)| *bits).collect();
537 smooth_terms.push(SmoothTermSpec {
538 name: label.clone(),
539 basis: SmoothBasisSpec::BySmooth {
540 smooth: Box::new(inner_basis),
541 by_kind: ByVarKind::Factor {
542 feature_col: by_col,
543 ordered: option_bool(options, "ordered").unwrap_or(false),
544 frozen_levels: Some(frozen_levels),
545 },
546 },
547 shape,
548 joint_null_rotation: None,
549 });
550 }
551 ColumnKindTag::Binary | ColumnKindTag::Continuous => {
552 smooth_terms.push(SmoothTermSpec {
553 name: label.clone(),
554 basis: SmoothBasisSpec::BySmooth {
555 smooth: Box::new(inner_basis),
556 by_kind: ByVarKind::Numeric {
557 feature_col: by_col,
558 },
559 },
560 shape,
561 joint_null_rotation: None,
562 });
563 }
564 }
565 } else {
566 smooth_terms.push(SmoothTermSpec {
567 name: label.clone(),
568 basis: inner_basis,
569 shape,
570 joint_null_rotation: None,
571 });
572 }
573 }
574 ParsedTerm::LinkWiggle { .. }
575 | ParsedTerm::TimeWiggle { .. }
576 | ParsedTerm::LinkConfig { .. }
577 | ParsedTerm::SurvivalConfig { .. } => {
578 }
580 ParsedTerm::LogSlopeSurface { .. } => {
581 return Err(TermBuilderError::malformed_formula(
582 "logslope(...) declarations must be resolved by the marginal-slope formula path before building a term spec",
583 ));
584 }
585 ParsedTerm::Interaction { vars } => {
586 let main_effect_present = |target: &str| -> bool {
619 terms.iter().any(|other| match other {
620 ParsedTerm::Linear { name, .. }
621 | ParsedTerm::BoundedLinear { name, .. }
622 | ParsedTerm::RandomEffect { name } => name == target,
623 _ => false,
624 })
625 };
626 let parent_present = |drop_var: &str| -> bool {
632 vars.iter()
633 .filter(|v| v.as_str() != drop_var)
634 .all(|v| main_effect_present(v))
635 };
636
637 let mut numeric_cols = Vec::<usize>::new();
638 let mut categorical_factors =
641 Vec::<(String, usize, Vec<(u64, String)>, bool)>::new();
642 for var in vars {
643 let col = resolve_col(col_map, var)?;
644 let kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
645 TermBuilderError::missing_column(format!(
646 "internal column-kind lookup failed for '{var}'"
647 ))
648 .to_string()
649 })?;
650 match kind {
651 ColumnKindTag::Continuous | ColumnKindTag::Binary => numeric_cols.push(col),
652 ColumnKindTag::Categorical => {
653 let mut levels = encoded_levels_for_column(ds, ColIdx::new(col));
654 let treatment_coded = parent_present(var);
658 if treatment_coded && levels.len() > 1 {
659 levels.remove(0);
660 }
661 if levels.is_empty() {
662 return Err(TermBuilderError::incompatible_config(format!(
663 "interaction `{}` references categorical column `{var}` with no usable levels",
664 vars.join(":")
665 )));
666 }
667 categorical_factors.push((var.clone(), col, levels, treatment_coded));
668 }
669 }
670 }
671
672 let label = vars.join(":");
673
674 if categorical_factors.is_empty() {
675 linear_terms.push(LinearTermSpec {
678 name: label,
679 feature_col: numeric_cols[0],
680 feature_cols: numeric_cols,
681 categorical_levels: vec![],
682 double_penalty: false,
685 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
686 coefficient_min: None,
687 coefficient_max: None,
688 });
689 inference_notes.push(format!(
690 "wired linear interaction `{}` as product of numeric columns",
691 vars.join(":")
692 ));
693 } else {
694 let mut cells: Vec<Vec<(usize, u64, String)>> = vec![Vec::new()];
699 for (_var, col, levels, _treatment_coded) in &categorical_factors {
700 let mut next = Vec::with_capacity(cells.len() * levels.len());
701 for cell in &cells {
702 for (bits, level_label) in levels {
703 let mut extended = cell.clone();
704 extended.push((*col, *bits, level_label.clone()));
705 next.push(extended);
706 }
707 }
708 cells = next;
709 }
710
711 let any_dummy_coded = categorical_factors
723 .iter()
724 .any(|(_, _, _, treatment_coded)| !*treatment_coded);
725 if numeric_cols.is_empty() && any_dummy_coded {
726 let reference_cell: Vec<(usize, u64)> = categorical_factors
729 .iter()
730 .map(|(_, col, _, _)| {
731 let levels = encoded_levels_for_column(ds, ColIdx::new(*col));
732 (*col, levels[0].0)
733 })
734 .collect();
735 cells.retain(|cell| {
736 !reference_cell.iter().all(|(rcol, rbits)| {
737 cell.iter()
738 .any(|(col, bits, _)| col == rcol && bits == rbits)
739 })
740 });
741 }
742
743 let n_cells = cells.len();
744 for cell in cells {
745 let cell_suffix = cell
746 .iter()
747 .map(|(_, _, level_label)| level_label.as_str())
748 .collect::<Vec<_>>()
749 .join(":");
750 let categorical_levels =
751 cell.iter().map(|(col, bits, _)| (*col, *bits)).collect();
752 let feature_col = numeric_cols
758 .first()
759 .copied()
760 .unwrap_or(categorical_factors[0].1);
761 linear_terms.push(LinearTermSpec {
762 name: format!("{label}:{cell_suffix}"),
763 feature_col,
764 feature_cols: numeric_cols.clone(),
765 categorical_levels,
766 double_penalty: false,
767 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
768 coefficient_min: None,
769 coefficient_max: None,
770 });
771 }
772 let all_treatment_coded = !any_dummy_coded;
773 let coding = if all_treatment_coded {
774 "treatment-coded"
775 } else {
776 "marginality-aware (full dummy / saturated)"
777 };
778 inference_notes.push(format!(
779 "wired factor-aware linear interaction `{}` as {} {} cell column(s)",
780 vars.join(":"),
781 n_cells,
782 coding
783 ));
784 }
785 }
786 }
787 }
788
789 Ok(TermCollectionSpec {
790 linear_terms,
791 random_effect_terms: random_terms,
792 smooth_terms,
793 })
794}
795
/// Splits a list-valued option into its trimmed, non-empty items.
///
/// One layer of wrapping is removed when both delimiters are present:
/// `[...]`, `c(...)`, `C(...)` or `(...)`. The remaining text is split on
/// commas. Input without a complete wrapper is split as-is.
fn split_list_option(raw: &str) -> Vec<String> {
    let trimmed = raw.trim();
    let bracket_body = trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'));
    let inner = match bracket_body {
        Some(body) => body,
        None => {
            // Try the call-style openers in priority order.
            let after_open = trimmed
                .strip_prefix("c(")
                .or_else(|| trimmed.strip_prefix("C("))
                .or_else(|| trimmed.strip_prefix('('));
            match after_open.and_then(|rest| rest.strip_suffix(')')) {
                Some(body) => body,
                None => trimmed,
            }
        }
    };

    let mut items = Vec::new();
    for piece in inner.split(',') {
        let piece = piece.trim();
        if !piece.is_empty() {
            items.push(piece.to_string());
        }
    }
    items
}
820
/// Parses a numeric option value that may be a product of symbolic factors.
///
/// Spaces are ignored and the expression is split on `*`; each factor is
/// either a plain float literal or a multiple of `pi`/`π` or `tau`/`τ`
/// (for example `2pi`, `0.5*tau`, `-pi`). The ASCII names are matched
/// case-insensitively in every position, so `2PI` is accepted just like `PI`.
/// A bare sign in front of a symbol (`-pi`, `+tau`) is read as ±1.
///
/// # Errors
/// Returns a message when the input is `none` (any case), contains an empty
/// factor (e.g. `2**pi`), or holds a factor that is not a valid number.
fn parse_numeric_expr(raw: &str) -> Result<f64, String> {
    /// Recognised constants as (ASCII name, Unicode glyph, value).
    const SYMBOLS: [(&str, &str, f64); 2] = [
        ("pi", "π", std::f64::consts::PI),
        ("tau", "τ", std::f64::consts::TAU),
    ];

    /// Splits `factor` into (coefficient text, constant value) when it ends
    /// in a known symbol; returns `None` for plain numeric factors.
    fn split_symbol(factor: &str) -> Option<(&str, f64)> {
        SYMBOLS.iter().find_map(|&(name, glyph, value)| {
            if let Some(prefix) = factor.strip_suffix(glyph) {
                return Some((prefix, value));
            }
            let cut = factor.len().checked_sub(name.len())?;
            // `get` yields `None` when `cut` is not a char boundary, which
            // keeps multi-byte input from panicking.
            let tail = factor.get(cut..)?;
            tail.eq_ignore_ascii_case(name)
                .then(|| (&factor[..cut], value))
        })
    }

    let normalized = raw.replace(' ', "");
    if normalized.eq_ignore_ascii_case("none") {
        return Err("None is not numeric".to_string());
    }

    let mut acc = 1.0f64;
    for factor in normalized.split('*') {
        if factor.is_empty() {
            return Err(format!("invalid numeric expression '{raw}'"));
        }
        let value = match split_symbol(factor) {
            Some((prefix, constant)) => {
                let coefficient = match prefix {
                    // A missing or bare-sign coefficient means ±1.
                    "" | "+" => 1.0,
                    "-" => -1.0,
                    other => other
                        .parse::<f64>()
                        .map_err(|err| format!("invalid numeric expression '{raw}': {err}"))?,
                };
                coefficient * constant
            }
            None => factor
                .parse::<f64>()
                .map_err(|err| format!("invalid numeric expression '{raw}': {err}"))?,
        };
        acc *= value;
    }
    Ok(acc)
}
868
869fn option_numeric_expr(
879 options: &BTreeMap<String, String>,
880 key: &str,
881) -> Result<Option<f64>, String> {
882 match options.get(key) {
883 None => Ok(None),
884 Some(raw) => parse_numeric_expr(raw)
885 .map(Some)
886 .map_err(|err| format!("option `{key}={raw}` is not a valid numeric value: {err}")),
887 }
888}
889
890fn parse_periods_option(
891 options: &BTreeMap<String, String>,
892 dim: usize,
893) -> Result<Option<Vec<Option<f64>>>, String> {
894 let Some(raw) = options.get("period") else {
895 return Ok(None);
896 };
897 let values = split_list_option(raw);
898 let mut periods = vec![None; dim];
899 if values.len() == 1 && dim == 1 {
900 periods[0] = Some(parse_numeric_expr(&values[0])?);
901 } else {
902 if values.len() != dim {
903 return Err(format!(
904 "period list length {} must match smooth dimension {}",
905 values.len(),
906 dim
907 ));
908 }
909 for (i, v) in values.iter().enumerate() {
910 if v.eq_ignore_ascii_case("none") {
911 continue;
912 }
913 periods[i] = Some(parse_numeric_expr(v)?);
914 }
915 }
916 Ok(Some(periods))
917}
918
919fn parse_periodic_axes_option(
920 options: &BTreeMap<String, String>,
921 dim: usize,
922) -> Result<Option<Vec<Option<f64>>>, String> {
923 let Some(raw_axes) = options.get("periodic") else {
924 return Ok(None);
925 };
926 let mut periods = parse_periods_option(options, dim)?.unwrap_or_else(|| vec![None; dim]);
927 let axes = split_list_option(raw_axes);
928 if axes.is_empty() {
929 return Ok(Some(periods));
930 }
931 for a in axes {
932 let axis = a
933 .parse::<usize>()
934 .map_err(|err| format!("invalid periodic axis '{a}': {err}"))?;
935 if axis >= dim {
936 return Err(format!(
937 "periodic axis {axis} out of range for {dim}D smooth"
938 ));
939 }
940 if periods[axis].is_none() {
941 return Err(format!(
942 "periodic axis {axis} requires period[{axis}] to be finite"
943 ));
944 }
945 }
946 let listed: std::collections::BTreeSet<usize> = split_list_option(raw_axes)
948 .into_iter()
949 .filter_map(|a| a.parse::<usize>().ok())
950 .collect();
951 for i in 0..dim {
952 if !listed.contains(&i) {
953 periods[i] = None;
954 }
955 }
956 Ok(Some(periods))
957}
958
/// Splits a list-valued option into normalized keyword items.
///
/// One layer of `[...]`, `c(...)`, `C(...)` or `(...)` wrapping is removed
/// when both delimiters are present. Each comma-separated item is trimmed,
/// stripped of surrounding double quotes and then single quotes, lowercased
/// (ASCII), and dropped when empty.
fn parse_option_list(raw: &str) -> Vec<String> {
    let trimmed = raw.trim();
    let bracket_body = trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'));
    let inner = match bracket_body {
        Some(body) => body,
        None => {
            // The first opener that matches wins, in this priority order.
            let after_open = ["c(", "C(", "("]
                .iter()
                .find_map(|open| trimmed.strip_prefix(*open));
            after_open
                .and_then(|rest| rest.strip_suffix(')'))
                .unwrap_or(trimmed)
        }
    };

    let mut items = Vec::new();
    for piece in inner.split(',') {
        let unquoted = piece.trim().trim_matches('"').trim_matches('\'');
        let normalized = unquoted.to_ascii_lowercase();
        if !normalized.is_empty() {
            items.push(normalized);
        }
    }
    items
}
992
993fn parse_periodic_axes(
994 options: &BTreeMap<String, String>,
995 dim: usize,
996) -> Result<Vec<bool>, String> {
997 let mut axes = vec![false; dim];
998 if let Some(raw) = options.get("periodic").or_else(|| options.get("cyclic")) {
999 let lowered = raw.trim().to_ascii_lowercase();
1000 match lowered.as_str() {
1001 "true" | "yes" | "y" => {
1002 axes.fill(true);
1003 return Ok(axes);
1004 }
1005 "false" | "no" | "n" => return Ok(axes),
1006 _ => {}
1007 }
1008 for axis_raw in parse_option_list(raw) {
1009 let axis = axis_raw
1010 .parse::<usize>()
1011 .map_err(|err| format!("invalid periodic axis '{axis_raw}': {err}"))?;
1012 if axis >= dim {
1013 return Err(format!(
1014 "periodic axis {axis} out of range for {dim}D smooth"
1015 ));
1016 }
1017 axes[axis] = true;
1018 }
1019 }
1020 if let Some(raw) = options.get("boundary").or_else(|| options.get("bc")) {
1021 let boundary = parse_option_list(raw);
1022 if boundary.len() == dim {
1023 for (axis, value) in boundary.iter().enumerate() {
1024 if matches!(value.as_str(), "periodic" | "cyclic" | "cc") {
1025 axes[axis] = true;
1026 }
1027 }
1028 } else if dim == 1
1029 && matches!(
1030 boundary.first().map(String::as_str),
1031 Some("periodic" | "cyclic" | "cc")
1032 )
1033 {
1034 axes[0] = true;
1035 }
1036 }
1037 Ok(axes)
1038}
1039
1040fn parse_optional_numeric_list(
1041 options: &BTreeMap<String, String>,
1042 keys: &[&str],
1043 dim: usize,
1044) -> Result<Vec<Option<f64>>, String> {
1045 let Some(raw) = keys.iter().find_map(|key| options.get(*key)) else {
1046 return Ok(vec![None; dim]);
1047 };
1048 let values = split_list_option(raw);
1049 let mut out = vec![None; dim];
1050 if values.len() == 1 && dim == 1 {
1051 if !values[0].eq_ignore_ascii_case("none") {
1052 out[0] = Some(parse_numeric_expr(&values[0])?);
1053 }
1054 return Ok(out);
1055 }
1056 if values.len() != dim {
1057 return Err(format!(
1058 "numeric option list length {} must match smooth dimension {}",
1059 values.len(),
1060 dim
1061 ));
1062 }
1063 for (i, value) in values.iter().enumerate() {
1064 if !value.eq_ignore_ascii_case("none") {
1065 out[i] = Some(parse_numeric_expr(value)?);
1066 }
1067 }
1068 Ok(out)
1069}
1070
1071fn parse_periods(
1072 options: &BTreeMap<String, String>,
1073 periodic_axes: &[bool],
1074) -> Result<Vec<Option<f64>>, String> {
1075 let dim = periodic_axes.len();
1076 let lone_periodic_broadcast = options
1081 .get("period")
1082 .or_else(|| options.get("periods"))
1083 .and_then(|raw| {
1084 let values = split_list_option(raw);
1085 if values.len() != 1 || dim <= 1 {
1086 return None;
1087 }
1088 let mut iter = periodic_axes.iter().enumerate().filter(|(_, p)| **p);
1089 let first = iter.next()?;
1090 if iter.next().is_some() {
1091 return None;
1092 }
1093 Some((first.0, values.into_iter().next().unwrap()))
1094 });
1095 let periods = if let Some((axis, value)) = lone_periodic_broadcast {
1096 let mut out = vec![None; dim];
1097 if !value.eq_ignore_ascii_case("none") {
1098 out[axis] = Some(parse_numeric_expr(&value)?);
1099 }
1100 out
1101 } else {
1102 parse_optional_numeric_list(options, &["period", "periods"], dim)?
1103 };
1104 for (axis, (periodic, period)) in periodic_axes.iter().zip(periods.iter()).enumerate() {
1105 if *periodic
1106 && let Some(value) = period
1107 && (!value.is_finite() || *value <= 0.0)
1108 {
1109 return Err(format!(
1110 "period for periodic axis {axis} must be finite and positive, got {value}"
1111 ));
1112 }
1113 }
1114 Ok(periods)
1115}
1116
1117fn parse_period_origins(
1118 options: &BTreeMap<String, String>,
1119 periodic_axes: &[bool],
1120) -> Result<Vec<Option<f64>>, String> {
1121 parse_optional_numeric_list(
1122 options,
1123 &[
1124 "origin",
1125 "origins",
1126 "period_origin",
1127 "period-origin",
1128 "domain_origin",
1129 ],
1130 periodic_axes.len(),
1131 )
1132}
1133
1134fn parse_tensor_periodic_axes(
1142 options: &BTreeMap<String, String>,
1143 dim: usize,
1144) -> Result<Vec<bool>, String> {
1145 let mut axes = vec![false; dim];
1146 if let Some(raw) = options.get("periodic").or_else(|| options.get("cyclic")) {
1147 let lowered = raw.trim().to_ascii_lowercase();
1148 match lowered.as_str() {
1149 "true" | "yes" | "y" => {
1150 axes.fill(true);
1151 }
1152 "false" | "no" | "n" => {
1153 }
1155 _ => {
1156 let entries = parse_option_list(raw);
1157 let all_bool = !entries.is_empty()
1158 && entries.iter().all(|v| {
1159 matches!(
1160 v.as_str(),
1161 "true" | "yes" | "y" | "false" | "no" | "n" | "none"
1162 )
1163 });
1164 if all_bool {
1165 if entries.len() != dim {
1166 return Err(format!(
1167 "periodic list length {} must match smooth dimension {}",
1168 entries.len(),
1169 dim
1170 ));
1171 }
1172 for (i, v) in entries.iter().enumerate() {
1173 axes[i] = matches!(v.as_str(), "true" | "yes" | "y");
1174 }
1175 } else {
1176 for axis_raw in entries {
1177 let axis = axis_raw
1178 .parse::<usize>()
1179 .map_err(|err| format!("invalid periodic axis '{axis_raw}': {err}"))?;
1180 if axis >= dim {
1181 return Err(format!(
1182 "periodic axis {axis} out of range for {dim}D smooth"
1183 ));
1184 }
1185 axes[axis] = true;
1186 }
1187 }
1188 }
1189 }
1190 }
1191 if let Some(raw) = options.get("boundary").or_else(|| options.get("bc")) {
1192 let boundary = parse_option_list(raw);
1193 if boundary.len() == dim {
1194 for (axis, value) in boundary.iter().enumerate() {
1195 if matches!(value.as_str(), "periodic" | "cyclic" | "cc") {
1196 axes[axis] = true;
1197 }
1198 }
1199 }
1200 }
1201 Ok(axes)
1202}
1203
1204fn tensor_k_axis_option_axis(
1205 key: &str,
1206 cols: &[usize],
1207 ds: &Dataset,
1208) -> Result<Option<usize>, String> {
1209 let Some(suffix) = key.strip_prefix("k_") else {
1210 return Ok(None);
1211 };
1212 if suffix.is_empty() {
1213 return Err("tensor k axis option must be named k_<axis> or k_<variable>".to_string());
1214 }
1215 if let Ok(axis) = suffix.parse::<usize>() {
1216 return if axis < cols.len() {
1217 Ok(Some(axis))
1218 } else {
1219 Err(format!(
1220 "tensor k axis option `{key}` references axis {axis}, but the smooth has {} margins",
1221 cols.len()
1222 ))
1223 };
1224 }
1225
1226 let mut matches = cols
1227 .iter()
1228 .enumerate()
1229 .filter(|(_, col)| ds.headers.get(**col).is_some_and(|name| name == suffix))
1230 .map(|(axis, _)| axis);
1231 let first = matches.next();
1232 if matches.next().is_some() {
1233 return Err(format!(
1234 "tensor k axis option `{key}` matches more than one margin named `{suffix}`"
1235 ));
1236 }
1237 first.map(Some).ok_or_else(|| {
1238 let margin_names = cols
1239 .iter()
1240 .enumerate()
1241 .map(|(axis, col)| {
1242 let name = ds
1243 .headers
1244 .get(*col)
1245 .map(String::as_str)
1246 .unwrap_or("<unnamed>");
1247 format!("{axis}:{name}")
1248 })
1249 .collect::<Vec<_>>()
1250 .join(", ");
1251 format!(
1252 "tensor k axis option `{key}` does not match a margin index or name; tensor margins are [{margin_names}]"
1253 )
1254 })
1255}
1256
/// Reports whether `key` has the form `k_<something>` with a non-empty suffix.
fn is_tensor_k_axis_option_key(key: &str) -> bool {
    matches!(key.strip_prefix("k_"), Some(suffix) if !suffix.is_empty())
}
1261
1262fn parse_tensor_k_list(
1266 options: &BTreeMap<String, String>,
1267 cols: &[usize],
1268 ds: &Dataset,
1269) -> Result<(Vec<usize>, bool), String> {
1270 let mut axis_values = vec![None; cols.len()];
1271 let mut saw_axis_alias = false;
1272 for (key, value) in options {
1273 let Some(axis) = tensor_k_axis_option_axis(key, cols, ds)? else {
1274 continue;
1275 };
1276 saw_axis_alias = true;
1277 if axis_values[axis].is_some() {
1278 return Err(format!("tensor k axis {axis} is specified more than once"));
1279 }
1280 let k: usize = value
1281 .parse()
1282 .map_err(|err| format!("invalid tensor k option `{key}={value}`: {err}"))?;
1283 axis_values[axis] = Some(k);
1284 }
1285
1286 let raw = options
1287 .get("k")
1288 .or_else(|| options.get("basis_dim"))
1289 .or_else(|| options.get("basis-dim"))
1290 .or_else(|| options.get("basisdim"));
1291 if saw_axis_alias {
1292 if raw.is_some() {
1293 return Err(
1294 "tensor k axis aliases cannot be combined with k= or basis_dim=".to_string(),
1295 );
1296 }
1297 if let Some(missing_axis) = axis_values.iter().position(Option::is_none) {
1298 let margin_name = cols
1299 .get(missing_axis)
1300 .and_then(|col| ds.headers.get(*col))
1301 .map(String::as_str)
1302 .unwrap_or("<unnamed>");
1303 return Err(format!(
1304 "tensor k axis aliases must specify every margin; missing axis {missing_axis} ({margin_name})"
1305 ));
1306 }
1307 return Ok((
1308 axis_values
1309 .into_iter()
1310 .map(|k| k.expect("missing axis values rejected above"))
1311 .collect(),
1312 false,
1313 ));
1314 }
1315 let Some(raw) = raw else {
1316 let inferred = heuristic_tensor_margin_knots(cols, ds);
1317 return Ok((inferred, true));
1318 };
1319 let entries = split_list_option(raw);
1320 if entries.len() == 1 {
1321 let k: usize = entries[0]
1322 .parse()
1323 .map_err(|err| format!("invalid tensor k '{}': {err}", entries[0]))?;
1324 return Ok((vec![k; cols.len()], false));
1325 }
1326 if entries.len() != cols.len() {
1327 return Err(format!(
1328 "tensor k list length {} must match smooth dimension {}",
1329 entries.len(),
1330 cols.len()
1331 ));
1332 }
1333 let mut out = Vec::with_capacity(entries.len());
1334 for entry in entries {
1335 let k: usize = entry
1336 .parse()
1337 .map_err(|err| format!("invalid tensor k '{entry}': {err}"))?;
1338 out.push(k);
1339 }
1340 Ok((out, false))
1341}
1342
1343fn parse_tensor_identifiability(
1352 options: &BTreeMap<String, String>,
1353 kind: SmoothKind,
1354) -> Result<TensorBSplineIdentifiability, String> {
1355 let Some(raw) = options.get("identifiability").map(String::as_str) else {
1356 return Ok(match kind {
1357 SmoothKind::Ti => TensorBSplineIdentifiability::MarginalSumToZero,
1358 _ => TensorBSplineIdentifiability::default(),
1359 });
1360 };
1361 match raw.trim().to_ascii_lowercase().as_str() {
1362 "none" => Ok(TensorBSplineIdentifiability::None),
1363 "sum_tozero" | "sum-to-zero" | "center_sum_tozero" | "center-sum-to-zero" | "centered"
1364 | "sumtozero" => Ok(TensorBSplineIdentifiability::SumToZero),
1365 "marginal_sum_tozero" | "marginal-sum-to-zero" | "marginal_sumtozero"
1366 | "marginalsumtozero" | "interaction" => {
1367 Ok(TensorBSplineIdentifiability::MarginalSumToZero)
1368 }
1369 other => Err(TermBuilderError::unsupported_feature(format!(
1370 "invalid tensor identifiability '{other}'; expected one of: none, sum_tozero, marginal_sum_tozero"
1371 ))
1372 .to_string()),
1373 }
1374}
1375
1376fn bspline_boundary_declares_periodic_axis(options: &BTreeMap<String, String>) -> bool {
1377 options
1378 .get("boundary")
1379 .or_else(|| options.get("bc"))
1380 .map(|raw| {
1381 parse_option_list(raw)
1382 .into_iter()
1383 .any(|value| matches!(value.as_str(), "periodic" | "cyclic" | "cc"))
1384 })
1385 .unwrap_or(false)
1386}
1387
/// Maps a smooth-type alias to the canonical name used for dispatch.
///
/// The comparison is exact and case-sensitive. Any name that is not a known
/// alias (including names that are already canonical) is returned unchanged.
pub(crate) fn canonicalize_smooth_type(raw: &str) -> &str {
    // (alias, canonical name) pairs.
    const ALIASES: [(&str, &str); 8] = [
        ("tp", "tps"),
        ("gp", "matern"),
        ("curv", "curvature"),
        ("constant_curvature", "curvature"),
        ("mkappa", "curvature"),
        ("mjs", "measurejet"),
        ("measure_jet", "measurejet"),
        ("web", "measurejet"),
    ];

    match ALIASES.iter().find(|(alias, _)| *alias == raw) {
        Some(&(_, canonical)) => canonical,
        None => raw,
    }
}
1430
1431pub(crate) fn tensor_margin_bs_is_supported(margin_bs: &str) -> bool {
1442 matches!(
1443 canonicalize_smooth_type(margin_bs),
1444 "tps" | "ps" | "bs" | "bspline" | "cr" | "cs" | "cc" | "cp" | "cyclic"
1445 )
1446}
1447
/// Reports whether a smooth's options mention periodicity in any form.
///
/// The result is `true` when either the `periodic` or the `cyclic` key is
/// present (the associated value is not inspected), or when the boundary
/// value contains the substring `periodic` or `cyclic`, compared
/// ASCII-case-insensitively. The boundary value is taken from `boundary`,
/// falling back to `bc` only when `boundary` is absent.
pub(crate) fn smooth_options_declare_periodic(options: &BTreeMap<String, String>) -> bool {
    // Presence of either key is enough on its own.
    if options.contains_key("periodic") || options.contains_key("cyclic") {
        return true;
    }

    match options.get("boundary").or_else(|| options.get("bc")) {
        Some(boundary) => {
            let lowered = boundary.to_ascii_lowercase();
            lowered.contains("periodic") || lowered.contains("cyclic")
        }
        None => false,
    }
}
1465
1466pub(crate) fn bs_selector_is_vector(raw: &str) -> bool {
1483 let trimmed = raw.trim();
1484 let bracketed = (trimmed.starts_with('[') && trimmed.ends_with(']'))
1485 || (trimmed.starts_with("c(") || trimmed.starts_with("C(")) && trimmed.ends_with(')')
1486 || (trimmed.starts_with('(') && trimmed.ends_with(')'));
1487 bracketed && !parse_option_list(trimmed).is_empty()
1488}
1489
1490pub fn resolve_smooth_type_name(
1491 kind: SmoothKind,
1492 n_cols: usize,
1493 options: &BTreeMap<String, String>,
1494) -> String {
1495 let selector = options.get("type").or_else(|| options.get("bs"));
1496 if let Some(raw) = selector
1501 && bs_selector_is_vector(raw)
1502 && matches!(kind, SmoothKind::Te | SmoothKind::Ti | SmoothKind::T2)
1503 {
1504 return "tensor".to_string();
1505 }
1506 selector
1507 .map(|s| canonicalize_smooth_type(&s.to_ascii_lowercase()).to_string())
1508 .unwrap_or_else(|| match kind {
1509 SmoothKind::Te | SmoothKind::Ti | SmoothKind::T2 => "tensor".to_string(),
1510 SmoothKind::S if n_cols == 1 => "bspline".to_string(),
1511 SmoothKind::S if smooth_options_declare_periodic(options) => "tensor".to_string(),
1515 SmoothKind::S => "tps".to_string(),
1516 })
1517}
1518
/// Reports whether `canonical_type` is one of `tps`, `matern` or `duchon`.
///
/// The comparison is exact and case-sensitive; aliases such as `tp` are not
/// resolved here, so the caller is expected to pass an already-canonical name.
pub fn smooth_type_uses_spatial_center_heuristic(canonical_type: &str) -> bool {
    const CENTER_HEURISTIC_TYPES: [&str; 3] = ["tps", "matern", "duchon"];
    CENTER_HEURISTIC_TYPES
        .iter()
        .any(|name| *name == canonical_type)
}
1530
1531pub fn build_smooth_basis(
1532 kind: SmoothKind,
1533 vars: &[String],
1534 cols: &[usize],
1535 options: &BTreeMap<String, String>,
1536 ds: &Dataset,
1537 inference_notes: &mut Vec<String>,
1538 policy: &ResourcePolicy,
1539 smooth_coordinate_count: usize,
1540) -> Result<SmoothBasisSpec, String> {
1541 let coord_cols: Vec<(&String, usize)> = vars
1557 .iter()
1558 .zip(cols.iter().copied())
1559 .filter(|(_, col)| !matches!(ds.column_kinds.get(*col), Some(ColumnKindTag::Categorical)))
1560 .collect();
1561 if !coord_cols.is_empty() {
1562 let views: Vec<ArrayView1<'_, f64>> = coord_cols
1563 .iter()
1564 .map(|(_, col)| ds.values.column(*col))
1565 .collect();
1566 let n_rows = views[0].len();
1567 let mut distinct_points = std::collections::HashSet::<Vec<u64>>::new();
1568 for r in 0..n_rows {
1569 let key: Vec<u64> = views
1570 .iter()
1571 .map(|v| {
1572 let x = v[r];
1573 let norm = if x == 0.0 { 0.0 } else { x };
1574 norm.to_bits()
1575 })
1576 .collect();
1577 distinct_points.insert(key);
1578 if distinct_points.len() > 1 {
1579 break;
1580 }
1581 }
1582 if distinct_points.len() <= 1 {
1583 return Err(TermBuilderError::degenerate_data(if coord_cols.len() == 1 {
1584 let var = coord_cols[0].0;
1585 format!(
1586 "smooth term over '{var}' has only one unique value in the training data \
1587 — a smooth on a constant column is degenerate and would only fit the response mean. \
1588 Remove `{var}` from the smooth, drop the term, or check the data."
1589 )
1590 } else {
1591 let names = coord_cols
1592 .iter()
1593 .map(|(v, _)| v.as_str())
1594 .collect::<Vec<_>>()
1595 .join(", ");
1596 format!(
1597 "smooth term over ({names}) has only one unique joint coordinate in the training \
1598 data — every coordinate is constant, so the smooth is degenerate and would only \
1599 fit the response mean. Drop the term or check the data."
1600 )
1601 })
1602 .to_string());
1603 }
1604 }
1605 if let Some(by_name) = options.get("by").cloned() {
1606 let by_col = options
1607 .get("__by_col")
1608 .and_then(|raw| raw.parse::<usize>().ok())
1609 .or_else(|| vars.iter().position(|v| v == &by_name).map(|idx| cols[idx]))
1610 .ok_or_else(|| format!("unknown by= column '{by_name}'"))?;
1611 let mut inner_options = options.clone();
1612 inner_options.remove("by");
1613 inner_options.remove("__by_col");
1614 inner_options.remove("id");
1615 let inner = build_smooth_basis(
1616 kind,
1617 vars,
1618 cols,
1619 &inner_options,
1620 ds,
1621 inference_notes,
1622 policy,
1623 smooth_coordinate_count,
1624 )?;
1625 let by_kind = match ds.column_kinds.get(by_col).copied() {
1626 Some(ColumnKindTag::Categorical) => ByVarKind::Factor {
1627 feature_col: by_col,
1628 ordered: option_bool(options, "ordered").unwrap_or(false),
1629 frozen_levels: None,
1630 },
1631 Some(ColumnKindTag::Continuous | ColumnKindTag::Binary) => ByVarKind::Numeric {
1632 feature_col: by_col,
1633 },
1634 None => {
1635 return Err(format!(
1636 "internal column-kind lookup failed for by='{by_name}'"
1637 ));
1638 }
1639 };
1640 return Ok(SmoothBasisSpec::BySmooth {
1641 smooth: Box::new(inner),
1642 by_kind,
1643 });
1644 }
1645
1646 let smooth_double_penalty = option_bool(options, "double_penalty").unwrap_or(true);
1647 let type_opt = resolve_smooth_type_name(kind, cols.len(), options);
1648
1649 if matches!(type_opt.as_str(), "fs" | "sz" | "re") {
1650 validate_known_options(
1651 type_opt.as_str(),
1652 options,
1653 &[
1654 "type",
1655 "bs",
1656 "k",
1657 "basis_dim",
1658 "basis-dim",
1659 "basisdim",
1660 "knots",
1661 "knot_placement",
1662 "knot-placement",
1663 "knotplacement",
1664 "degree",
1665 "penalty_order",
1666 "m",
1667 "double_penalty",
1668 "ordered",
1669 ],
1670 )?;
1671 if cols.len() != 2 {
1672 return Err(format!(
1673 "{} factor-smooth currently expects exactly two variables (one numeric, one categorical)",
1674 type_opt
1675 ));
1676 }
1677 let kinds = cols
1678 .iter()
1679 .map(|&c| ds.column_kinds.get(c).copied())
1680 .collect::<Vec<_>>();
1681 let (cont_idx, group_idx) = if type_opt == "re" {
1682 match (kinds[0], kinds[1]) {
1684 (Some(ColumnKindTag::Categorical), _) => (1usize, 0usize),
1685 (_, Some(ColumnKindTag::Categorical)) => (0usize, 1usize),
1686 _ => (1usize, 0usize),
1687 }
1688 } else {
1689 match (kinds[0], kinds[1]) {
1690 (_, Some(ColumnKindTag::Categorical)) => (0usize, 1usize),
1691 (Some(ColumnKindTag::Categorical), _) => (1usize, 0usize),
1692 _ => {
1693 return Err(format!(
1694 "{} factor-smooth requires one categorical factor variable",
1695 type_opt
1696 ));
1697 }
1698 }
1699 };
1700 let c = cols[cont_idx];
1701 let (minv, maxv) = col_minmax(ds.values.column(c))?;
1702 let degree = if type_opt == "re" {
1703 1
1704 } else {
1705 option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE)
1706 };
1707 let pooled_internal = heuristic_knots_for_column(ds.values.column(c));
1727 let default_internal = if type_opt == "re" {
1728 0
1741 } else {
1742 let min_group_resolution =
1743 min_per_group_unique_count(ds.values.column(c), ds.values.column(cols[group_idx]));
1744 let basis_cap = min_group_resolution.saturating_sub(2).max(degree + 2);
1752 let internal_cap = basis_cap.saturating_sub(degree + 1);
1753 let capped = pooled_internal.min(internal_cap.max(1));
1754 let fs_default_internal = FACTOR_SMOOTH_DEFAULT_BASIS_DIM
1770 .saturating_sub(degree + 1)
1771 .max(1);
1772 capped.min(fs_default_internal)
1773 };
1774 let (n_knots, _, effective_degree) =
1775 parse_ps_internal_knots(options, degree, default_internal)?;
1776 let penalty_order = option_usize(options, "penalty_order")
1777 .unwrap_or(if effective_degree > 1 { 2 } else { 1 })
1778 .min(effective_degree);
1779 let marginal_knotspec = resolve_nonperiodic_bspline_knotspec(
1813 options,
1814 ds.values.column(c),
1815 (minv, maxv),
1816 effective_degree,
1817 n_knots,
1818 )?;
1819 let marginal = BSplineBasisSpec {
1820 degree: effective_degree,
1821 penalty_order,
1822 knotspec: marginal_knotspec,
1823 double_penalty: option_bool(options, "double_penalty")
1834 .unwrap_or(type_opt.as_str() != "sz"),
1835 identifiability: BSplineIdentifiability::None,
1836 boundary_conditions: Default::default(),
1837 boundary: OneDimensionalBoundary::Open,
1838 };
1839 let flavour = match type_opt.as_str() {
1840 "fs" => FactorSmoothFlavour::Fs {
1841 m_null_penalty_orders: vec![
1842 option_usize(options, "m").unwrap_or(DEFAULT_PENALTY_ORDER),
1843 ],
1844 },
1845 "sz" => FactorSmoothFlavour::Sz,
1846 "re" => FactorSmoothFlavour::Re,
1847 other => {
1849 return Err(format!(
1850 "internal: factor-smooth flavour dispatch reached unexpected type `{}`",
1851 other
1852 ));
1853 }
1854 };
1855 return Ok(SmoothBasisSpec::FactorSmooth {
1856 spec: FactorSmoothSpec {
1857 continuous_cols: vec![c],
1858 group_col: cols[group_idx],
1859 marginal,
1860 flavour,
1861 group_frozen_levels: None,
1862 frozen_global_orthogonality: None,
1863 },
1864 });
1865 }
1866
1867 match type_opt.as_str() {
1868 "cyclic" | "cc" | "cp" | "cyclic-ps" => {
1869 validate_known_options(
1870 "cyclic",
1871 options,
1872 &[
1873 "type",
1874 "bs",
1875 "by",
1876 "k",
1877 "basis_dim",
1878 "basis-dim",
1879 "basisdim",
1880 "degree",
1881 "penalty_order",
1882 "period",
1883 "periods",
1884 "period_start",
1885 "period_end",
1886 "start",
1887 "end",
1888 "origin",
1889 "origins",
1890 "period_origin",
1891 "period-origin",
1892 "domain_origin",
1893 "double_penalty",
1894 "id",
1895 "__by_col",
1896 "identifiability",
1897 ],
1898 )?;
1899 if cols.len() != 1 {
1900 return Err(format!(
1901 "periodic smooth expects one variable, got {}",
1902 cols.len()
1903 ));
1904 }
1905 let c = cols[0];
1906 let (minv, maxv) = col_minmax(ds.values.column(c))?;
1907 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
1908 let mut default_internal = heuristic_knots_for_column(ds.values.column(c));
1909 if ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
1910 default_internal = default_internal.min(1);
1911 }
1912 let cyclic_default_basis_cap = CYCLIC_DEFAULT_BASIS_DIM.max(degree + 1);
1928 let default_basis = (default_internal + degree + 1).min(cyclic_default_basis_cap);
1929 let num_basis = option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
1930 .unwrap_or(default_basis);
1931 if num_basis < degree + 1 {
1932 return Err(format!(
1933 "periodic smooth: k={} too small for degree {}; expected k >= {}",
1934 num_basis,
1935 degree,
1936 degree + 1
1937 ));
1938 }
1939 let periodic_axes = [true];
1950 let periods = parse_periods(options, &periodic_axes)?;
1951 let origins = parse_period_origins(options, &periodic_axes)?;
1952 let (domain_start, period) = if let Some(p) = periods[0] {
1953 (origins[0].unwrap_or(minv), p)
1954 } else {
1955 parse_periodic_domain_1d(options, minv, maxv)?
1956 };
1957 Ok(SmoothBasisSpec::BSpline1D {
1958 feature_col: c,
1959 spec: BSplineBasisSpec {
1960 degree,
1961 penalty_order: option_usize(options, "penalty_order")
1962 .unwrap_or(DEFAULT_PENALTY_ORDER),
1963 knotspec: BSplineKnotSpec::PeriodicUniform {
1964 data_range: (domain_start, domain_start + period),
1965 num_basis,
1966 },
1967 double_penalty: smooth_double_penalty,
1968 identifiability: BSplineIdentifiability::default(),
1969 boundary_conditions: Default::default(),
1970 boundary: OneDimensionalBoundary::Cyclic {
1971 start: domain_start,
1972 end: domain_start + period,
1973 },
1974 },
1975 })
1976 }
1977 "bspline" | "ps" | "p-spline" | "cr" | "cs" => {
1978 let validation_name = match type_opt.as_str() {
1992 "cr" => "cr",
1993 "cs" => "cs",
1994 _ => "bspline",
1995 };
1996 validate_known_options(
1997 validation_name,
1998 options,
1999 &[
2000 "type",
2001 "bs",
2002 "by",
2003 "k",
2004 "basis_dim",
2005 "basis-dim",
2006 "basisdim",
2007 "knots",
2008 "knot_placement",
2009 "knot-placement",
2010 "knotplacement",
2011 "degree",
2012 "penalty_order",
2013 "boundary",
2014 "bc",
2015 "boundary_conditions",
2016 "bc_left",
2017 "bc_right",
2018 "left_bc",
2019 "right_bc",
2020 "start_bc",
2021 "end_bc",
2022 "side",
2023 "anchor",
2024 "anchor_value",
2025 "value",
2026 "anchor_left",
2027 "left_anchor",
2028 "anchor_right",
2029 "right_anchor",
2030 "periodic",
2031 "period",
2032 "periods",
2033 "period_start",
2034 "period_end",
2035 "origin",
2036 "double_penalty",
2037 "by",
2038 "id",
2039 "__by_col",
2040 "identifiability",
2041 "by",
2042 ],
2043 )?;
2044 if cols.len() != 1 {
2045 return Err(TermBuilderError::incompatible_config(format!(
2046 "bspline smooth expects one variable, got {}",
2047 cols.len()
2048 ))
2049 .to_string());
2050 }
2051 let c = cols[0];
2052 let (minv, maxv) = col_minmax(ds.values.column(c))?;
2053 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
2054 let default_internal = heuristic_knots_for_column(ds.values.column(c));
2055 let (mut n_knots, inferred, effective_degree) =
2056 parse_ps_internal_knots(options, degree, default_internal)?;
2057 let periodic_axes = parse_periodic_axes(options, 1).map_err(|e| e.to_string())?;
2058 if periodic_axes[0] && effective_degree != degree {
2063 return Err(TermBuilderError::invalid_option(format!(
2064 "periodic smooth: k={} too small for degree {}; expected k >= {}",
2065 effective_degree + 1,
2066 degree,
2067 degree + 1
2068 ))
2069 .to_string());
2070 }
2071 if inferred && ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2072 n_knots = n_knots.min(1);
2073 }
2074 if inferred {
2075 let unique = unique_count_column(ds.values.column(c));
2076 let ceiling = ((unique as f64).cbrt() as usize).max(20);
2077 inference_notes.push(format!(
2078 "Automatically set {} internal knots for smooth '{}' from {} unique values (rule: clamp(unique/4, 4..max(20, cbrt(unique))) = clamp(unique/4, 4..{})). Override with knots=... or k=....",
2079 n_knots,
2080 vars.join(","),
2081 unique,
2082 ceiling,
2083 ));
2084 }
2085 let boundary_conditions =
2086 if periodic_axes[0] && bspline_boundary_declares_periodic_axis(options) {
2087 BSplineBoundaryConditions::default()
2088 } else {
2089 parse_bspline_boundary_conditions(options).map_err(|e| e.to_string())?
2090 };
2091 let periods = parse_periods(options, &periodic_axes).map_err(|e| e.to_string())?;
2092 let origins =
2093 parse_period_origins(options, &periodic_axes).map_err(|e| e.to_string())?;
2094 let (knotspec, boundary) = if periodic_axes[0] {
2095 if !boundary_conditions.is_free() {
2096 return Err(TermBuilderError::incompatible_config(
2097 "periodic B-splines cannot also declare endpoint boundary conditions",
2098 )
2099 .to_string());
2100 }
2101 {
2102 let (domain_start, p_value) = if periods[0].is_some() {
2103 (origins[0].unwrap_or(minv), periods[0].unwrap())
2104 } else {
2105 parse_periodic_domain_1d(options, minv, maxv).map_err(|e| e.to_string())?
2106 };
2107 let domain_end = domain_start + p_value;
2108 (
2109 BSplineKnotSpec::PeriodicUniform {
2110 data_range: (domain_start, domain_end),
2111 num_basis: n_knots + effective_degree + 1,
2112 },
2113 OneDimensionalBoundary::Cyclic {
2114 start: domain_start,
2115 end: domain_end,
2116 },
2117 )
2118 }
2119 } else if type_opt == "cr" || type_opt == "cs" {
2120 let k_cr = (n_knots + effective_degree + 1).max(CR_MIN_KNOTS);
2137 let knotspec = match capped_cr_marginal_knotspec(
2138 ds.values.column(c),
2139 k_cr,
2140 &vars.join(","),
2141 inference_notes,
2142 )? {
2143 Some(cr_knotspec) => cr_knotspec,
2144 None => resolve_nonperiodic_bspline_knotspec(
2145 options,
2146 ds.values.column(c),
2147 (minv, maxv),
2148 effective_degree,
2149 n_knots,
2150 )?,
2151 };
2152 (knotspec, parse_cyclic_boundary(options, minv, maxv)?)
2153 } else {
2154 (
2155 resolve_nonperiodic_bspline_knotspec(
2156 options,
2157 ds.values.column(c),
2158 (minv, maxv),
2159 effective_degree,
2160 n_knots,
2161 )?,
2162 parse_cyclic_boundary(options, minv, maxv)?,
2163 )
2164 };
2165 let double_penalty = if type_opt == "cr" {
2169 option_bool(options, "double_penalty").unwrap_or(false)
2170 } else {
2171 smooth_double_penalty
2172 };
2173 let penalty_order = option_usize(options, "penalty_order")
2178 .unwrap_or(DEFAULT_PENALTY_ORDER)
2179 .min(effective_degree);
2180 Ok(SmoothBasisSpec::BSpline1D {
2181 feature_col: c,
2182 spec: BSplineBasisSpec {
2183 degree: effective_degree,
2184 penalty_order,
2185 knotspec,
2186 double_penalty,
2187 identifiability: BSplineIdentifiability::default(),
2188 boundary,
2189 boundary_conditions,
2190 },
2191 })
2192 }
2193 "tps" | "thinplate" | "thin-plate" => {
2194 validate_known_options(
2195 "thinplate",
2196 options,
2197 &[
2198 SECONDARY_CENTER_CAP_OPTION,
2199 "type",
2200 "bs",
2201 "by",
2202 "length_scale",
2203 "centers",
2204 "k",
2205 "basis_dim",
2206 "basis-dim",
2207 "basisdim",
2208 "knots",
2209 "include_intercept",
2210 "double_penalty",
2211 "by",
2212 "id",
2213 "__by_col",
2214 "identifiability",
2215 "by",
2216 "scale_dims",
2217 ],
2218 )?;
2219 let plan = plan_spatial_basis(
2220 ds.values.nrows(),
2221 cols.len(),
2222 CenterCountRequest::Default,
2223 DuchonNullspaceOrder::Linear,
2224 option_bool(options, "scale_dims").unwrap_or(false),
2225 policy,
2226 )
2227 .map_err(|e| e.to_string())?;
2228 let default_centers = plan.centers;
2238 let centers = parse_countwith_basis_alias(
2239 options,
2240 "centers",
2241 cap_default_spatial_centers(options, default_centers),
2242 )?;
2243 let center_strategy = if has_explicit_countwith_basis_alias(options, "centers") {
2244 spatial_center_strategy_for_dimension(centers, cols.len())
2245 } else {
2246 auto_spatial_center_strategy(centers, cols.len())
2247 };
2248 Ok(SmoothBasisSpec::ThinPlate {
2249 feature_cols: cols.to_vec(),
2250 spec: ThinPlateBasisSpec {
2251 center_strategy,
2252 periodic: parse_periodic_axes_option(options, cols.len())?,
2253 length_scale: option_f64(options, "length_scale").unwrap_or(0.0),
2261 double_penalty: smooth_double_penalty,
2262 identifiability: parse_spatial_identifiability(options)
2263 .map_err(|e| e.to_string())?,
2264 radial_reparam: None,
2265 },
2266 input_scales: None,
2267 })
2268 }
2269 "sphere" | "s2" | "sos" => {
2270 validate_known_options(
2271 "sphere",
2272 options,
2273 &[
2274 "type",
2275 "bs",
2276 "by",
2277 "centers",
2278 "k",
2279 "basis_dim",
2280 "basis-dim",
2281 "basisdim",
2282 "knots",
2283 "penalty_order",
2284 "m",
2285 "double_penalty",
2286 "id",
2287 "__by_col",
2288 "kernel",
2289 "method",
2290 "radians",
2291 "units",
2292 "degree",
2293 "l",
2294 "max_degree",
2295 "max-degree",
2296 ],
2297 )?;
2298 if cols.len() != 2 {
2299 return Err(format!(
2300 "sphere smooth expects exactly two variables (lat, lon), got {}",
2301 cols.len()
2302 ));
2303 }
2304 let radians = option_bool(options, "radians").unwrap_or_else(|| {
2305 options
2306 .get("units")
2307 .map(|u| u.eq_ignore_ascii_case("radian") || u.eq_ignore_ascii_case("radians"))
2308 .unwrap_or(false)
2309 });
2310 let degree_requested = options.contains_key("degree")
2316 || options.contains_key("l")
2317 || options.contains_key("max_degree")
2318 || options.contains_key("max-degree");
2319 let kernel = options
2320 .get("kernel")
2321 .or_else(|| options.get("method"))
2322 .map(|raw| strip_quotes(raw).trim().to_ascii_lowercase())
2323 .unwrap_or_else(|| {
2324 if degree_requested {
2325 "harmonic".to_string()
2326 } else {
2327 "sobolev".to_string()
2328 }
2329 });
2330 let (method, wahba_kernel) = match kernel.as_str() {
2331 "sobolev" | "wahba" | "wahba_sobolev" | "wahba-sobolev" => {
2332 (SphereMethod::Wahba, SphereWahbaKernel::Sobolev)
2333 }
2334 "pseudo" | "mgcv" | "sos" | "wahba_pseudo" | "wahba-pseudo" => {
2335 (SphereMethod::Wahba, SphereWahbaKernel::Pseudo)
2336 }
2337 "harmonic" | "spherical_harmonic" | "spherical-harmonic" => {
2338 (SphereMethod::Harmonic, SphereWahbaKernel::Sobolev)
2339 }
2340 other => {
2341 return Err(format!(
2342 "unsupported sphere kernel '{other}'; expected sobolev, pseudo, or harmonic"
2343 ));
2344 }
2345 };
2346 let max_degree = if matches!(method, SphereMethod::Harmonic) {
2347 let degree =
2348 option_usize_any(options, &["degree", "l", "max_degree", "max-degree"])
2349 .or_else(|| option_usize(options, "centers"))
2350 .or_else(|| {
2351 option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
2352 .and_then(|k| (1..=128).find(|&l| l * (l + 2) >= k))
2353 })
2354 .unwrap_or_else(|| default_spherical_harmonic_degree(ds.values.nrows()));
2355 if degree == 0 {
2356 return Err("sphere smooth requires degree/max_degree >= 1".to_string());
2357 }
2358 if degree > 32 {
2359 return Err(format!(
2360 "sphere smooth max_degree={} is too large for the dense harmonic engine (limit 32)",
2361 degree
2362 ));
2363 }
2364 Some(degree)
2365 } else {
2366 None
2367 };
2368 let penalty_order = option_usize(options, "penalty_order")
2369 .or_else(|| option_usize(options, "m"))
2370 .unwrap_or(DEFAULT_PENALTY_ORDER);
2371 let center_strategy = if matches!(method, SphereMethod::Wahba) {
2372 let mut centers = parse_countwith_basis_alias(
2373 options,
2374 "centers",
2375 default_num_centers(ds.values.nrows(), cols.len()),
2376 )?;
2377 if penalty_order >= 4 {
2378 centers = centers.max(30);
2379 }
2380 CenterStrategy::FarthestPoint {
2381 num_centers: centers,
2382 }
2383 } else {
2384 CenterStrategy::FarthestPoint { num_centers: 0 }
2385 };
2386 Ok(SmoothBasisSpec::Sphere {
2387 feature_cols: cols.to_vec(),
2388 spec: SphericalSplineBasisSpec {
2389 center_strategy,
2390 penalty_order,
2391 double_penalty: smooth_double_penalty,
2392 radians,
2393 method,
2394 max_degree,
2395 wahba_kernel,
2396 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2397 },
2398 })
2399 }
2400 "curvature" => {
2401 validate_known_options(
2407 "curvature",
2408 options,
2409 &[
2410 "type",
2411 "bs",
2412 "by",
2413 "centers",
2414 "k",
2415 "basis_dim",
2416 "basis-dim",
2417 "basisdim",
2418 "knots",
2419 "kappa",
2420 "length_scale",
2421 "double_penalty",
2422 "id",
2423 "__by_col",
2424 ],
2425 )?;
2426 let kappa = option_f64(options, "kappa").unwrap_or(0.0);
2427 if !kappa.is_finite() {
2428 return Err("curvature smooth requires a finite kappa".to_string());
2429 }
2430 let length_scale = option_f64(options, "length_scale").unwrap_or(0.0);
2431 if !length_scale.is_finite() || length_scale < 0.0 {
2432 return Err(format!(
2433 "curvature smooth length_scale must be positive (or omitted for auto); got {length_scale}"
2434 ));
2435 }
2436 let centers = parse_countwith_basis_alias(
2437 options,
2438 "centers",
2439 default_num_centers(ds.values.nrows(), cols.len()),
2440 )?;
2441 if centers < 2 {
2442 return Err("curvature smooth requires at least 2 centers".to_string());
2443 }
2444 Ok(SmoothBasisSpec::ConstantCurvature {
2445 feature_cols: cols.to_vec(),
2446 spec: ConstantCurvatureBasisSpec {
2447 center_strategy: CenterStrategy::FarthestPoint {
2448 num_centers: centers,
2449 },
2450 kappa,
2451 length_scale,
2454 double_penalty: option_bool(options, "double_penalty").unwrap_or(false),
2461 identifiability: ConstantCurvatureIdentifiability::CenterSumToZero,
2462 },
2463 })
2464 }
2465 "measurejet" => {
2466 validate_known_options(
2472 "measurejet",
2473 options,
2474 &[
2475 "type",
2476 "bs",
2477 "by",
2478 "centers",
2479 "k",
2480 "basis_dim",
2481 "basis-dim",
2482 "basisdim",
2483 "knots",
2484 "s",
2485 "alpha",
2486 "tau",
2487 "scales",
2488 "length_scale",
2489 "double_penalty",
2490 "multiscale",
2491 "learn_length_scale",
2492 "id",
2493 "__by_col",
2494 ],
2495 )?;
2496 let order_s = option_f64(options, "s").unwrap_or(0.0);
2497 if !(order_s.is_finite() && (order_s == 0.0 || (order_s > 0.0 && order_s < 2.0))) {
2500 return Err(format!(
2501 "measurejet smooth s must lie in (0, 2) (or be omitted for auto); got {order_s}"
2502 ));
2503 }
2504 let alpha =
2512 option_f64(options, "alpha").unwrap_or(MeasureJetBasisSpec::default().alpha);
2513 if !alpha.is_finite() {
2514 return Err("measurejet smooth requires a finite alpha".to_string());
2515 }
2516 let tau0 = option_f64(options, "tau").unwrap_or(1e-3);
2517 if !(tau0.is_finite() && tau0 >= 0.0) {
2518 return Err(format!(
2519 "measurejet smooth tau must be finite and nonnegative; got {tau0}"
2520 ));
2521 }
2522 let num_scales = option_usize(options, "scales").unwrap_or(0);
2523 let length_scale = option_f64(options, "length_scale").unwrap_or(0.0);
2524 if !length_scale.is_finite() || length_scale < 0.0 {
2525 return Err(format!(
2526 "measurejet smooth length_scale must be positive (or omitted for auto); got {length_scale}"
2527 ));
2528 }
2529 let centers = parse_countwith_basis_alias(
2530 options,
2531 "centers",
2532 default_num_centers(ds.values.nrows(), cols.len()),
2533 )?;
2534 if centers < 3 {
2535 return Err("measurejet smooth requires at least 3 centers".to_string());
2536 }
2537 let multiscale = option_bool(options, "multiscale").unwrap_or(false);
2541 let learn_length_scale = option_bool(options, "learn_length_scale").unwrap_or(false);
2546 Ok(SmoothBasisSpec::MeasureJet {
2547 feature_cols: cols.to_vec(),
2548 spec: MeasureJetBasisSpec {
2549 center_strategy: CenterStrategy::FarthestPoint {
2550 num_centers: centers,
2551 },
2552 order_s,
2553 alpha,
2554 tau0,
2555 num_scales,
2556 length_scale,
2559 double_penalty: smooth_double_penalty,
2560 learn_length_scale,
2561 multiscale,
2562 identifiability: MeasureJetIdentifiability::CenterSumToZero,
2563 frozen_quadrature: None,
2564 },
2565 input_scales: None,
2566 })
2567 }
2568 "matern" => {
2569 validate_known_options(
2574 "matern",
2575 options,
2576 &[
2577 SECONDARY_CENTER_CAP_OPTION,
2578 "type",
2579 "bs",
2580 "by",
2581 "nu",
2582 "length_scale",
2583 "centers",
2584 "k",
2585 "basis_dim",
2586 "basis-dim",
2587 "basisdim",
2588 "knots",
2589 "include_intercept",
2590 "double_penalty",
2591 "by",
2592 "id",
2593 "__by_col",
2594 "identifiability",
2595 "by",
2596 "scale_dims",
2597 ],
2598 )?;
2599 let plan = plan_spatial_basis(
2600 ds.values.nrows(),
2601 cols.len(),
2602 CenterCountRequest::Default,
2603 DuchonNullspaceOrder::Zero,
2604 option_bool(options, "scale_dims").unwrap_or(false),
2605 policy,
2606 )
2607 .map_err(|e| e.to_string())?;
2608 let centers = parse_countwith_basis_alias(
2609 options,
2610 "centers",
2611 cap_default_spatial_centers(
2612 options,
2613 default_matern_center_count(ds.values.nrows(), cols.len(), plan.centers),
2614 ),
2615 )?;
2616 let center_strategy = if has_explicit_countwith_basis_alias(options, "centers") {
2617 spatial_center_strategy_for_dimension(centers, cols.len())
2618 } else {
2619 auto_spatial_center_strategy(centers, cols.len())
2620 };
2621 let nu = parse_matern_nu(options.get("nu").map(String::as_str).unwrap_or("5/2"))?;
2622 if matches!(nu, MaternNu::Half) && cols.len() >= 2 {
2628 return Err(TermBuilderError::unsupported_feature(format!(
2629 "matern() with nu=1/2 is not supported for d>=2 (got {} covariates): \
2630 the exponential kernel's Laplacian is singular at center collisions, \
2631 which makes the operator-collocation penalty non-invertible. \
2632 Choose nu>=3/2 (e.g. nu=3/2 or the default nu=5/2) for multi-dimensional smooths.",
2633 cols.len()
2634 ))
2635 .to_string());
2636 }
2637 let aniso_log_scales = if option_bool(options, "scale_dims").unwrap_or(false) {
2638 Some(vec![0.0; cols.len()])
2639 } else {
2640 None
2641 };
2642 Ok(SmoothBasisSpec::Matern {
2643 feature_cols: cols.to_vec(),
2644 spec: MaternBasisSpec {
2645 center_strategy,
2646 periodic: parse_periodic_axes_option(options, cols.len())?,
2647 length_scale: option_f64(options, "length_scale").unwrap_or(0.0),
2665 nu,
2666 include_intercept: option_bool(options, "include_intercept").unwrap_or(false),
2667 double_penalty: smooth_double_penalty,
2668 identifiability: parse_matern_identifiability(options)
2669 .map_err(|e| e.to_string())?,
2670 aniso_log_scales,
2671 nullspace_shrinkage_survived: None,
2676 },
2677 input_scales: None,
2678 })
2679 }
2680 "duchon" => {
2681 validate_known_options(
2682 "duchon",
2683 options,
2684 &[
2685 SECONDARY_CENTER_CAP_OPTION,
2686 "type",
2687 "bs",
2688 "by",
2689 "length_scale",
2690 "centers",
2691 "k",
2692 "basis_dim",
2693 "basis-dim",
2694 "basisdim",
2695 "knots",
2696 "power",
2697 "p",
2698 "nullspace_order",
2699 "order",
2700 "identifiability",
2701 "by",
2702 "periodic",
2703 "cyclic",
2704 "period",
2705 "period_start",
2706 "period_end",
2707 "scale_dims",
2708 "double_penalty",
2709 "by",
2710 "id",
2711 "__by_col",
2712 ],
2713 )?;
2714 if options.contains_key("double_penalty") {
2715 return Err(TermBuilderError::incompatible_config(format!(
2716 "Duchon smooth '{}' does not support double_penalty; the Duchon smoother already ships its native reproducing-norm penalty plus a null-space shrinkage ridge.",
2717 vars.join(", ")
2718 ))
2719 .to_string());
2720 }
2721 let requested_nullspace_order = parse_duchon_order(options)?;
2722 let length_scale = option_f64_strict(options, "length_scale")?;
2723 let (nullspace_order, power) = match parse_duchon_power_policy(options)? {
2736 DuchonPowerPolicy::Explicit(req_power) => {
2737 if length_scale.is_some() && req_power.fract() != 0.0 {
2738 return Err(TermBuilderError::incompatible_config(format!(
2739 "hybrid Duchon-Matern smooth '{}' (length_scale=...) requires an integer power, got power={}; \
2740 drop length_scale to use the scale-free structural kernel with a fractional power.",
2741 vars.join(", "),
2742 req_power,
2743 ))
2744 .to_string());
2745 }
2746 (requested_nullspace_order, req_power)
2747 }
2748 DuchonPowerPolicy::CubicStructuralDefault => {
2749 match length_scale {
2756 None => crate::basis::duchon_cubic_default(cols.len()),
2757 Some(_) => {
2758 let max_op = crate::basis::duchon_max_active_operator_derivative_order(
2779 &DuchonOperatorPenaltySpec::default(),
2780 );
2781 let (ns, s) = crate::basis::resolve_duchon_orders(
2782 cols.len(),
2783 requested_nullspace_order,
2784 max_op,
2785 length_scale,
2786 );
2787 (ns, s as f64)
2788 }
2789 }
2790 }
2791 };
2792 let plan = plan_spatial_basis(
2793 ds.values.nrows(),
2794 cols.len(),
2795 CenterCountRequest::Default,
2796 nullspace_order,
2797 option_bool(options, "scale_dims").unwrap_or(false),
2798 policy,
2799 )
2800 .map_err(|e| e.to_string())?;
2801 let centers_explicit = has_explicit_countwith_basis_alias(options, "centers");
2802 let requested_centers = parse_countwith_basis_alias(
2803 options,
2804 "centers",
2805 cap_default_spatial_centers(options, plan.centers),
2806 )?;
2807 let polynomial_cols = match nullspace_order {
2808 DuchonNullspaceOrder::Zero => 1,
2809 DuchonNullspaceOrder::Linear => cols.len() + 1,
2810 DuchonNullspaceOrder::Degree(degree) => {
2811 crate::basis::duchon_nullspace_dimension(cols.len(), degree)
2812 }
2813 };
2814 if requested_centers <= polynomial_cols {
2815 return Err(TermBuilderError::incompatible_config(format!(
2816 "Duchon smooth '{}' requested basis dimension {} but order={:?} in {}D needs {} polynomial null-space columns; choose centers/k > {}",
2817 vars.join(", "),
2818 requested_centers,
2819 nullspace_order,
2820 cols.len(),
2821 polynomial_cols,
2822 polynomial_cols,
2823 ))
2824 .to_string());
2825 }
2826 let mut centers = requested_centers;
2827 if !centers_explicit && ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2828 centers = centers.max(polynomial_cols + 4);
2829 }
2830 let center_strategy = if centers_explicit {
2831 spatial_center_strategy_for_dimension(centers, cols.len())
2832 } else {
2833 auto_spatial_center_strategy(centers, cols.len())
2834 };
2835 let aniso_log_scales = if option_bool(options, "scale_dims").unwrap_or(false) {
2836 Some(vec![0.0; cols.len()])
2837 } else {
2838 None
2839 };
2840 let operator_penalties = DuchonOperatorPenaltySpec::default();
2843 Ok(SmoothBasisSpec::Duchon {
2844 feature_cols: cols.to_vec(),
2845 spec: DuchonBasisSpec {
2846 center_strategy,
2847 periodic: parse_periodic_axes_option(options, cols.len())?,
2848 length_scale,
2849 power,
2850 nullspace_order,
2851 identifiability: parse_spatial_identifiability(options)
2852 .map_err(|e| e.to_string())?,
2853 aniso_log_scales,
2854 operator_penalties,
2855 boundary: if cols.len() == 1 {
2856 let c = cols[0];
2857 let (minv, maxv) = col_minmax(ds.values.column(c))?;
2858 parse_cyclic_boundary(options, minv, maxv)?
2859 } else {
2860 OneDimensionalBoundary::Open
2861 },
2862 radial_reparam: None,
2863 },
2864 input_scales: None,
2865 })
2866 }
2867 "tensor" | "te" | "ti" | "t2" => {
2868 validate_known_options(
2869 "tensor",
2870 options,
2871 &[
2872 "type",
2873 "bs",
2874 "by",
2875 "k",
2876 "basis_dim",
2877 "basis-dim",
2878 "basisdim",
2879 "knot_placement",
2880 "knot-placement",
2881 "knotplacement",
2882 "degree",
2883 "penalty_order",
2884 "double_penalty",
2885 "periodic",
2886 "cyclic",
2887 "period",
2888 "periods",
2889 "period_start",
2890 "period_end",
2891 "origin",
2892 "origins",
2893 "period_origin",
2894 "period-origin",
2895 "domain_origin",
2896 "boundary",
2897 "bc",
2898 "identifiability",
2899 "id",
2900 "__by_col",
2901 ],
2902 )?;
2903 if cols.len() < 2 {
2904 return Err(TermBuilderError::incompatible_config(format!(
2905 "tensor smooth expects at least 2 variables, got {}",
2906 cols.len()
2907 ))
2908 .to_string());
2909 }
2910 let dim = cols.len();
2911
2912 if let Some(raw) = options.get("bs").or_else(|| options.get("type"))
2935 && bs_selector_is_vector(raw)
2936 {
2937 let per_margin = parse_option_list(raw);
2938 if per_margin.len() != dim {
2939 return Err(TermBuilderError::invalid_option(format!(
2940 "tensor smooth per-margin bs vector has {} entries but the smooth has {} margins",
2941 per_margin.len(),
2942 dim
2943 ))
2944 .to_string());
2945 }
2946 for (axis, margin_bs) in per_margin.iter().enumerate() {
2947 if !tensor_margin_bs_is_supported(margin_bs) {
2948 return Err(TermBuilderError::unsupported_feature(format!(
2949 "tensor smooth margin {axis} basis '{margin_bs}' is not a supported penalized-spline margin; \
2950 tensor margins accept tp/tps/ps/bs/cr/cc"
2951 ))
2952 .to_string());
2953 }
2954 }
2955 }
2956 let periodic_axes = parse_tensor_periodic_axes(options, dim)?;
2957 let periods_opt = parse_periods(options, &periodic_axes)?;
2958 let origins_opt = parse_period_origins(options, &periodic_axes)?;
2959 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
2960 let penalty_order =
2961 option_usize(options, "penalty_order").unwrap_or(if degree > 1 { 2 } else { 1 });
2962 let (mut k_list, k_inferred) = parse_tensor_k_list(options, cols, ds)?;
2963 if ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2964 for k in &mut k_list {
2965 *k = (*k).min(degree + 2);
2966 }
2967 }
2968 if k_inferred {
2969 inference_notes.push(format!(
2970 "Automatically set per-margin basis sizes {:?} for tensor smooth '{}' \
2971 (dimension-aware tensor budget: total ∏k kept near the mgcv-te default \
2972 and within the data support, distributed geometrically across margins and \
2973 capped per margin by each column's resolution). \
2974 Override with k=<int> or k=[k0,k1,...].",
2975 k_list,
2976 vars.join(",")
2977 ));
2978 }
2979 let per_axis_bs: Vec<Option<String>> =
2992 match options.get("bs").or_else(|| options.get("type")) {
2993 Some(raw) if bs_selector_is_vector(raw) => {
2994 let list = parse_option_list(raw);
2995 (0..dim).map(|a| list.get(a).cloned()).collect()
2996 }
2997 Some(raw) => {
2998 let scalar = raw
2999 .trim()
3000 .trim_matches('"')
3001 .trim_matches('\'')
3002 .to_ascii_lowercase();
3003 vec![Some(scalar); dim]
3004 }
3005 None => vec![None; dim],
3006 };
3007 let margin_wants_cr = |bs: &Option<String>| -> bool {
3013 matches!(
3014 bs.as_deref(),
3015 None | Some("cr") | Some("cs") | Some("tp") | Some("tps")
3016 )
3017 };
3018 let mut margins: Vec<BSplineBasisSpec> = Vec::with_capacity(dim);
3019 let mut emitted_periods: Vec<Option<f64>> = Vec::with_capacity(dim);
3020 for axis in 0..dim {
3021 let c = cols[axis];
3022 let (data_min, data_max) = col_minmax(ds.values.column(c))?;
3023 let k_requested = k_list[axis];
3039 let n_distinct_axis = unique_count_column(ds.values.column(c));
3040 let k_axis = k_requested.min(n_distinct_axis).max(2);
3041 if k_axis < k_requested {
3042 log::info!(
3043 "tensor smooth: margin axis {axis} requested k={k_requested}, but the \
3044 covariate has only {n_distinct_axis} distinct value(s); reducing this \
3045 margin to k={k_axis} (mgcv-style data-support cap on the per-axis basis)."
3046 );
3047 }
3048 if k_axis < 2 {
3061 return Err(TermBuilderError::invalid_option(format!(
3062 "tensor smooth: k[{axis}]={k_axis} too small; tensor margins require k >= 2"
3063 ))
3064 .to_string());
3065 }
3066 if periodic_axes[axis] && k_axis < degree + 1 {
3067 return Err(TermBuilderError::invalid_option(format!(
3068 "tensor smooth: periodic axis {axis} requires k >= {} for degree {degree}, got k={k_axis}",
3069 degree + 1
3070 ))
3071 .to_string());
3072 }
3073 let effective_degree = degree.min(k_axis - 1).max(1);
3074 let effective_penalty_order = penalty_order.min(effective_degree);
3075 let (knotspec, boundary, axis_period) = if periodic_axes[axis] {
3076 let period_value = periods_opt[axis].ok_or_else(|| {
3077 format!(
3078 "tensor smooth axis {axis} is periodic but no period was supplied; \
3079 pass period=<value> (scalar) or period=[..., <value>, ...]"
3080 )
3081 })?;
3082 if !period_value.is_finite() || period_value <= 0.0 {
3083 return Err(format!(
3084 "tensor smooth axis {axis}: period must be a positive finite value, got {period_value}"
3085 ));
3086 }
3087 let domain_start = origins_opt[axis].unwrap_or(data_min);
3088 let domain_end = domain_start + period_value;
3089 (
3090 BSplineKnotSpec::PeriodicUniform {
3091 data_range: (domain_start, domain_end),
3092 num_basis: k_axis,
3093 },
3094 OneDimensionalBoundary::Cyclic {
3095 start: domain_start,
3096 end: domain_end,
3097 },
3098 Some(period_value),
3099 )
3100 } else if margin_wants_cr(&per_axis_bs[axis]) && k_axis >= 3 {
3101 let cr_knots =
3111 crate::basis::select_cr_knots(ds.values.column(c), k_axis)
3112 .map_err(|e| e.to_string())?;
3113 (
3114 BSplineKnotSpec::NaturalCubicRegression { knots: cr_knots },
3115 OneDimensionalBoundary::Open,
3116 None,
3117 )
3118 } else {
3119 let num_internal_knots = if effective_degree < degree {
3126 k_axis.saturating_sub(effective_degree + 1)
3127 } else {
3128 k_axis.saturating_sub(degree + 1).max(1)
3129 };
3130 let knotspec = match parse_knot_placement(options)? {
3131 crate::basis::BSplineKnotPlacement::Uniform => BSplineKnotSpec::Generate {
3132 data_range: (data_min, data_max),
3133 num_internal_knots,
3134 },
3135 crate::basis::BSplineKnotPlacement::Quantile => {
3136 crate::basis::auto_knot_vector_1d_quantile(
3137 ds.values.column(c),
3138 num_internal_knots,
3139 effective_degree,
3140 )
3141 .map_err(|e| e.to_string())?;
3142 BSplineKnotSpec::Automatic {
3143 num_internal_knots: Some(num_internal_knots),
3144 placement: crate::basis::BSplineKnotPlacement::Quantile,
3145 }
3146 }
3147 };
3148 (knotspec, OneDimensionalBoundary::Open, None)
3149 };
3150 let is_cr_margin =
3156 matches!(knotspec, BSplineKnotSpec::NaturalCubicRegression { .. });
3157 let margin_double_penalty =
3158 is_cr_margin && matches!(per_axis_bs[axis].as_deref(), Some("cs"));
3159 margins.push(BSplineBasisSpec {
3160 degree: effective_degree,
3161 penalty_order: effective_penalty_order,
3162 knotspec,
3163 double_penalty: margin_double_penalty,
3164 identifiability: BSplineIdentifiability::None,
3165 boundary,
3166 boundary_conditions: BSplineBoundaryConditions::default(),
3167 });
3168 emitted_periods.push(axis_period);
3169 }
3170 let canon_cols: Vec<usize> = {
3191 let mut perm: Vec<usize> = (0..dim).collect();
3192 perm.sort_by_key(|&a| cols[a]);
3193 if perm.iter().enumerate().any(|(i, &a)| i != a) {
3194 margins = perm.iter().map(|&a| margins[a].clone()).collect();
3195 emitted_periods = perm.iter().map(|&a| emitted_periods[a]).collect();
3196 }
3197 perm.iter().map(|&a| cols[a]).collect()
3198 };
3199 let any_periodic = emitted_periods.iter().any(|p| p.is_some());
3200 let periods_vec = if any_periodic {
3201 emitted_periods
3202 } else {
3203 Vec::new()
3204 };
3205 let tensor_double_penalty = option_bool(options, "double_penalty").unwrap_or(false);
3221 Ok(SmoothBasisSpec::TensorBSpline {
3222 feature_cols: canon_cols,
3223 spec: TensorBSplineSpec {
3224 marginalspecs: margins,
3225 periods: periods_vec,
3226 double_penalty: tensor_double_penalty,
3227 identifiability: parse_tensor_identifiability(options, kind)?,
3228 penalty_decomposition: if matches!(kind, SmoothKind::T2)
3238 || type_opt.as_str() == "t2"
3239 {
3240 TensorBSplinePenaltyDecomposition::Separable
3241 } else {
3242 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
3243 },
3244 },
3245 })
3246 }
3247 "pca" => {
3248 validate_known_options(
3249 "pca",
3250 options,
3251 &[
3252 "type",
3253 "bs",
3254 "by",
3255 "k",
3256 "basis_dim",
3257 "basis-dim",
3258 "basisdim",
3259 "lazy_path",
3260 "path",
3261 "pca_basis_path",
3262 "chunk_size",
3263 "smooth_penalty",
3264 "centered",
3265 "double_penalty",
3266 "id",
3267 "__by_col",
3268 ],
3269 )?;
3270 let path = options
3271 .get("lazy_path")
3272 .or_else(|| options.get("pca_basis_path"))
3273 .or_else(|| options.get("path"))
3274 .map(|raw| PathBuf::from(strip_quotes(raw)));
3275 let Some(path) = path else {
3276 return Err(TermBuilderError::incompatible_config(
3277 "pca smooth requires lazy_path=... on the formula path",
3278 )
3279 .to_string());
3280 };
3281 let k = option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
3282 .unwrap_or(0);
3283 let chunk_size = option_usize(options, "chunk_size").unwrap_or(DEFAULT_PCA_CHUNK_SIZE);
3284 Ok(SmoothBasisSpec::Pca {
3285 feature_cols: cols.to_vec(),
3286 basis_matrix: Array2::<f64>::zeros((cols.len(), k)),
3287 centered: option_bool(options, "centered").unwrap_or(true),
3288 smooth_penalty: option_f64(options, "smooth_penalty").unwrap_or(1.0),
3289 center_mean: None,
3290 pca_basis_path: Some(path),
3291 chunk_size,
3292 })
3293 }
3294 other => Err(TermBuilderError::unsupported_feature(format!(
3295 "unsupported smooth type '{other}'"
3296 ))
3297 .to_string()),
3298 }
3299}
3300
3301pub fn enable_scale_dimensions(spec: &mut TermCollectionSpec) {
3303 for smooth in spec.smooth_terms.iter_mut() {
3304 promote_thin_plate_for_scale_dimensions(&mut smooth.basis);
3311 match &mut smooth.basis {
3312 SmoothBasisSpec::Matern {
3313 feature_cols,
3314 spec: matern,
3315 ..
3316 } => {
3317 if matern.aniso_log_scales.is_none() {
3318 let d = feature_cols.len();
3319 matern.aniso_log_scales = Some(vec![0.0; d]);
3320 }
3321 }
3322 SmoothBasisSpec::Duchon {
3323 feature_cols,
3324 spec: duchon,
3325 ..
3326 } => {
3327 if duchon.aniso_log_scales.is_none() {
3328 let d = feature_cols.len();
3329 duchon.aniso_log_scales = Some(vec![0.0; d]);
3330 }
3331 }
3332 _ => {}
3333 }
3334 }
3335}
3336
3337fn promote_thin_plate_for_scale_dimensions(basis: &mut SmoothBasisSpec) {
3372 let SmoothBasisSpec::ThinPlate {
3373 feature_cols,
3374 spec,
3375 input_scales,
3376 } = &*basis
3377 else {
3378 return;
3379 };
3380 let d = feature_cols.len();
3381 if d <= 1 {
3382 return;
3383 }
3384 let m = thin_plate_penalty_order(d);
3389 let nullspace_order = match m {
3390 0 | 1 => DuchonNullspaceOrder::Zero,
3391 2 => DuchonNullspaceOrder::Linear,
3392 _ => DuchonNullspaceOrder::Degree(m - 1),
3393 };
3394 let duchon_spec = DuchonBasisSpec {
3395 center_strategy: spec.center_strategy.clone(),
3396 periodic: spec.periodic.clone(),
3397 length_scale: None,
3402 power: 0.0,
3404 nullspace_order,
3405 identifiability: spec.identifiability.clone(),
3406 aniso_log_scales: Some(vec![0.0; d]),
3410 operator_penalties: DuchonOperatorPenaltySpec::default(),
3411 boundary: OneDimensionalBoundary::Open,
3412 radial_reparam: None,
3413 };
3414 let feature_cols = feature_cols.clone();
3415 let input_scales = input_scales.clone();
3416 *basis = SmoothBasisSpec::Duchon {
3419 feature_cols,
3420 spec: duchon_spec,
3421 input_scales,
3422 };
3423}
3424
3425pub fn spatial_center_strategy_for_dimension(num_centers: usize, d: usize) -> CenterStrategy {
3430 if d <= 3 {
3431 CenterStrategy::FarthestPoint { num_centers }
3438 } else {
3439 default_spatial_center_strategy(num_centers, d)
3440 }
3441}
3442
3443pub fn col_minmax(col: ArrayView1<'_, f64>) -> Result<(f64, f64), String> {
3444 let min = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
3445 let max = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
3446 if !min.is_finite() || !max.is_finite() {
3447 return Err(TermBuilderError::degenerate_data(
3448 "non-finite data encountered while inferring knot range",
3449 )
3450 .to_string());
3451 }
3452 if (max - min).abs() < 1e-12 {
3453 Ok((min, min + 1e-6))
3454 } else {
3455 Ok((min, max))
3456 }
3457}
3458
3459pub fn unique_count_column(col: ArrayView1<'_, f64>) -> usize {
3460 use std::collections::HashSet;
3461 let mut set = HashSet::<u64>::with_capacity(col.len());
3462 for &v in col {
3463 let norm = if v == 0.0 { 0.0 } else { v };
3464 set.insert(norm.to_bits());
3465 }
3466 set.len().max(1)
3467}
3468
/// Smallest knot count accepted for a cubic-regression marginal.
///
/// `capped_cr_marginal_knotspec` returns `None` (with an inference note)
/// when the data-capped knot count falls below this value.
pub(crate) const CR_MIN_KNOTS: usize = 3;
3475
3476fn capped_cr_marginal_knotspec(
3503 col: ArrayView1<'_, f64>,
3504 k_cr_requested: usize,
3505 label: &str,
3506 inference_notes: &mut Vec<String>,
3507) -> Result<Option<BSplineKnotSpec>, String> {
3508 let n_distinct = unique_count_column(col);
3509 let k_cr = k_cr_requested.min(n_distinct);
3510 if k_cr < CR_MIN_KNOTS {
3511 inference_notes.push(format!(
3512 "Smooth '{label}': cubic-regression ('cr'/'cs'/'sz') basis requested k={k_cr_requested}, \
3513 but the covariate has only {n_distinct} distinct value(s) — too few to support a cubic \
3514 regression spline (needs >= {CR_MIN_KNOTS} distinct values). Degraded to the linear \
3515 B-spline marginal the default basis builds on the same data."
3516 ));
3517 return Ok(None);
3518 }
3519 if k_cr < k_cr_requested {
3520 inference_notes.push(format!(
3521 "Smooth '{label}': cubic-regression ('cr'/'cs'/'sz') basis reduced from k={k_cr_requested} \
3522 to k={k_cr} to match the covariate's {n_distinct} distinct value(s) (mgcv-style \
3523 data-support cap; a cr basis cannot place more value-knots than the data has)."
3524 ));
3525 }
3526 let cr_knots = crate::basis::select_cr_knots(col, k_cr).map_err(|e| e.to_string())?;
3527 Ok(Some(BSplineKnotSpec::NaturalCubicRegression {
3528 knots: cr_knots,
3529 }))
3530}
3531
3532fn min_per_group_unique_count(
3539 feature_col: ArrayView1<'_, f64>,
3540 group_col: ArrayView1<'_, f64>,
3541) -> usize {
3542 use std::collections::{HashMap, HashSet};
3543 let mut per_group: HashMap<u64, HashSet<u64>> = HashMap::new();
3544 for (xi, gi) in feature_col.iter().zip(group_col.iter()) {
3545 let xnorm = if *xi == 0.0 { 0.0 } else { *xi };
3546 let gnorm = if *gi == 0.0 { 0.0 } else { *gi };
3547 per_group
3548 .entry(gnorm.to_bits())
3549 .or_default()
3550 .insert(xnorm.to_bits());
3551 }
3552 per_group
3553 .values()
3554 .map(|s| s.len())
3555 .min()
3556 .unwrap_or(1)
3557 .max(1)
3558}
3559
3560pub fn heuristic_knots_for_column(col: ArrayView1<'_, f64>) -> usize {
3588 const MAX_DEFAULT_INTERNAL_KNOTS: usize = 8;
3591 let unique = unique_count_column(col);
3592 (unique / 4).clamp(4, MAX_DEFAULT_INTERNAL_KNOTS)
3593}
3594
3595fn heuristic_tensor_margin_knots(cols: &[usize], ds: &Dataset) -> Vec<usize> {
3616 let d = cols.len().max(1);
3617 let degree = DEFAULT_BSPLINE_DEGREE;
3618 let min_k = degree + 2; let n = ds.values.nrows();
3620
3621 let per_margin_cap: Vec<usize> = cols
3625 .iter()
3626 .map(|&c| heuristic_knots_for_column(ds.values.column(c)).max(min_k))
3627 .collect();
3628
3629 let mgcv_like_per_margin = match d {
3636 2 => 7usize,
3637 3 => 5usize,
3638 _ => 4usize,
3639 };
3640 let mgcv_like_total = (mgcv_like_per_margin as f64).powi(d as i32);
3641 let data_budget = (n as f64) * 0.8;
3642 let p_target = mgcv_like_total
3643 .max(min_k.pow(d as u32) as f64)
3644 .min(data_budget);
3645
3646 let geo_per_margin = p_target.powf(1.0 / d as f64).round() as usize;
3649 let unclamped: Vec<usize> = per_margin_cap
3650 .iter()
3651 .map(|&cap| geo_per_margin.clamp(min_k, cap))
3652 .collect();
3653
3654 let mut k_list = unclamped;
3659 loop {
3660 let product: f64 = k_list.iter().map(|&k| k as f64).product();
3661 if product >= p_target {
3662 break;
3663 }
3664 let Some(idx) = k_list
3667 .iter()
3668 .zip(per_margin_cap.iter())
3669 .enumerate()
3670 .filter(|&(_, (k, cap))| k < cap)
3671 .max_by_key(|&(_, (k, cap))| (cap - k, *cap))
3672 .map(|(i, _)| i)
3673 else {
3674 break;
3675 };
3676 k_list[idx] += 1;
3677 }
3678 k_list
3679}
3680
3681pub fn heuristic_centers(n: usize, d: usize) -> usize {
3682 default_num_centers(n, d)
3683}
3684
3685fn parse_endpoint_side(
3690 value: &str,
3691 context: &str,
3692) -> Result<BSplineEndpointBoundaryCondition, String> {
3693 match value.trim().to_ascii_lowercase().as_str() {
3694 "" | "none" | "open" | "unconstrained" | "free" => {
3695 Ok(BSplineEndpointBoundaryCondition::Free)
3696 }
3697 "clamped" | "clamp" | "zero_derivative" | "zero-derivative" => {
3698 Ok(BSplineEndpointBoundaryCondition::Clamped)
3699 }
3700 "anchored" | "anchor" | "zero" | "zero_value" | "zero-value" => {
3701 Ok(BSplineEndpointBoundaryCondition::Anchored { value: 0.0 })
3702 }
3703 other => Err(format!(
3704 "unsupported {context} boundary condition '{other}'; expected free, clamped, or anchored"
3705 )),
3706 }
3707}
3708
3709fn boundary_anchor_value(
3710 options: &BTreeMap<String, String>,
3711 side: &str,
3712 fallback: Option<f64>,
3713) -> Option<f64> {
3714 [
3715 format!("anchor_{side}"),
3716 format!("{side}_anchor"),
3717 format!("anchor-value-{side}"),
3718 ]
3719 .iter()
3720 .find_map(|key| option_f64(options, key))
3721 .or(fallback)
3722}
3723
3724fn apply_anchor_value(
3725 cond: BSplineEndpointBoundaryCondition,
3726 value: Option<f64>,
3727) -> BSplineEndpointBoundaryCondition {
3728 match cond {
3729 BSplineEndpointBoundaryCondition::Anchored { .. } => {
3730 BSplineEndpointBoundaryCondition::Anchored {
3731 value: value.unwrap_or(0.0),
3732 }
3733 }
3734 other => other,
3735 }
3736}
3737
3738fn parse_bspline_boundary_conditions(
3739 options: &BTreeMap<String, String>,
3740) -> Result<BSplineBoundaryConditions, String> {
3741 let fallback_anchor = option_f64(options, "anchor")
3742 .or_else(|| option_f64(options, "anchor_value"))
3743 .or_else(|| option_f64(options, "value"));
3744 let global_boundary_conditions = options
3745 .get("boundary_conditions")
3746 .or_else(|| options.get("bc"));
3747 let mut boundary_conditions = BSplineBoundaryConditions::default();
3748
3749 if let Some(raw_boundary_conditions) = global_boundary_conditions {
3750 let cond = parse_endpoint_side(raw_boundary_conditions, "boundary_conditions")?;
3751 let side = options
3752 .get("side")
3753 .map(|s| s.trim().to_ascii_lowercase())
3754 .unwrap_or_else(|| "both".to_string());
3755 match side.as_str() {
3756 "both" | "all" | "endpoints" => {
3757 boundary_conditions.left = cond;
3758 boundary_conditions.right = cond;
3759 }
3760 "left" | "start" | "lower" => boundary_conditions.left = cond,
3761 "right" | "end" | "upper" => boundary_conditions.right = cond,
3762 other => {
3763 return Err(format!(
3764 "unsupported B-spline boundary side '{other}'; expected left, right, or both"
3765 ));
3766 }
3767 }
3768 }
3769
3770 if let Some(raw) = options
3771 .get("bc_left")
3772 .or_else(|| options.get("left_bc"))
3773 .or_else(|| options.get("bc_start"))
3774 .or_else(|| options.get("start_bc"))
3775 {
3776 boundary_conditions.left = parse_endpoint_side(raw, "left endpoint")?;
3777 }
3778 if let Some(raw) = options
3779 .get("bc_right")
3780 .or_else(|| options.get("right_bc"))
3781 .or_else(|| options.get("bc_end"))
3782 .or_else(|| options.get("end_bc"))
3783 {
3784 boundary_conditions.right = parse_endpoint_side(raw, "right endpoint")?;
3785 }
3786
3787 boundary_conditions.left = apply_anchor_value(
3788 boundary_conditions.left,
3789 boundary_anchor_value(options, "left", fallback_anchor),
3790 );
3791 boundary_conditions.right = apply_anchor_value(
3792 boundary_conditions.right,
3793 boundary_anchor_value(options, "right", fallback_anchor),
3794 );
3795
3796 reject_nonzero_anchor("left", boundary_conditions.left)?;
3804 reject_nonzero_anchor("right", boundary_conditions.right)?;
3805
3806 Ok(boundary_conditions)
3807}
3808
3809fn reject_nonzero_anchor(side: &str, cond: BSplineEndpointBoundaryCondition) -> Result<(), String> {
3810 if let BSplineEndpointBoundaryCondition::Anchored { value } = cond {
3811 if value.abs() > 1e-12 {
3812 return Err(format!(
3813 "non-zero {side} anchor {value} requires an affine offset term that is not yet supported; only anchored value 0 is accepted at parse time"
3814 ));
3815 }
3816 }
3817 Ok(())
3818}
3819
3820fn parse_ps_internal_knots(
3834 options: &BTreeMap<String, String>,
3835 degree: usize,
3836 default_internal_knots: usize,
3837) -> Result<(usize, bool, usize), String> {
3838 const MIN_EXPRESSIVE_INTERNAL_KNOTS: usize = 2;
3839 let knots_internal = if knots_option_is_list(options) {
3849 None
3850 } else {
3851 option_usize_strict(options, "knots")?
3852 };
3853 let basis_dim = option_usize_any_strict(options, &["k", "basis_dim", "basis-dim", "basisdim"])?;
3854 if knots_internal.is_some() && basis_dim.is_some() {
3855 return Err(TermBuilderError::incompatible_config(
3856 "ps/bspline smooth: specify either knots=<internal_knots> or k=<basis_dim> (not both)",
3857 )
3858 .to_string());
3859 }
3860 if let Some(k) = basis_dim {
3861 if k < 2 {
3862 return Err(TermBuilderError::invalid_option(format!(
3863 "ps/bspline smooth: k={} too small; B-spline basis requires k >= 2",
3864 k
3865 ))
3866 .to_string());
3867 }
3868 let effective_degree = degree.min(k - 1).max(1);
3874 let num_internal_knots = if effective_degree < degree {
3875 k.saturating_sub(effective_degree + 1)
3878 } else {
3879 (k - degree - 1).max(MIN_EXPRESSIVE_INTERNAL_KNOTS)
3880 };
3881 Ok((num_internal_knots, false, effective_degree))
3882 } else {
3883 Ok((
3884 knots_internal.unwrap_or(default_internal_knots),
3885 knots_internal.is_none(),
3886 degree,
3887 ))
3888 }
3889}
3890
/// Reports whether the `knots` option is written as a list rather than a
/// scalar, i.e. its trimmed value starts with `[`, `c(`, `C(` or `(`.
///
/// Returns `false` when the option is absent.
fn knots_option_is_list(options: &BTreeMap<String, String>) -> bool {
    const LIST_OPENERS: [&str; 4] = ["[", "c(", "C(", "("];

    options.get("knots").is_some_and(|raw| {
        let trimmed = raw.trim();
        LIST_OPENERS
            .iter()
            .any(|opener| trimmed.starts_with(*opener))
    })
}
3905
3906fn parse_explicit_internal_knots(
3911 options: &BTreeMap<String, String>,
3912) -> Result<Option<Vec<f64>>, String> {
3913 if !knots_option_is_list(options) {
3914 return Ok(None);
3915 }
3916 let raw = options
3917 .get("knots")
3918 .expect("knots_option_is_list implies the key is present");
3919 let tokens = split_list_option(raw);
3920 if tokens.is_empty() {
3921 return Err(TermBuilderError::invalid_option(format!(
3922 "knots={raw} is an empty list; supply at least one internal knot position \
3923 (e.g. knots=[0.2, 0.5, 0.8]) or a scalar count (e.g. knots=8)"
3924 ))
3925 .to_string());
3926 }
3927 let mut positions = Vec::with_capacity(tokens.len());
3928 for tok in &tokens {
3929 let value = parse_numeric_expr(tok).map_err(|err| {
3930 TermBuilderError::invalid_option(format!(
3931 "knots list entry '{tok}' is not a numeric position: {err}"
3932 ))
3933 .to_string()
3934 })?;
3935 positions.push(value);
3936 }
3937 Ok(Some(positions))
3938}
3939
3940fn parse_knot_placement(
3946 options: &BTreeMap<String, String>,
3947) -> Result<crate::basis::BSplineKnotPlacement, String> {
3948 use crate::basis::BSplineKnotPlacement;
3949 match options
3950 .get("knot_placement")
3951 .or_else(|| options.get("knot-placement"))
3952 .or_else(|| options.get("knotplacement"))
3953 {
3954 None => Ok(BSplineKnotPlacement::Uniform),
3955 Some(raw) => match raw
3956 .trim()
3957 .trim_matches('"')
3958 .trim_matches('\'')
3959 .to_ascii_lowercase()
3960 .as_str()
3961 {
3962 "uniform" | "even" | "equal" => Ok(BSplineKnotPlacement::Uniform),
3963 "quantile" | "quantiles" | "data" | "empirical" => Ok(BSplineKnotPlacement::Quantile),
3964 other => Err(TermBuilderError::invalid_option(format!(
3965 "knot_placement={other} is not recognised; expected \"uniform\" or \"quantile\""
3966 ))
3967 .to_string()),
3968 },
3969 }
3970}
3971
3972fn resolve_nonperiodic_bspline_knotspec(
3983 options: &BTreeMap<String, String>,
3984 data: ArrayView1<'_, f64>,
3985 data_range: (f64, f64),
3986 degree: usize,
3987 n_knots: usize,
3988) -> Result<BSplineKnotSpec, String> {
3989 use crate::basis::{BSplineKnotPlacement, clamped_knot_vector_from_internal_positions};
3990 if let Some(positions) = parse_explicit_internal_knots(options)? {
3991 if option_usize_any_strict(options, &["k", "basis_dim", "basis-dim", "basisdim"])?.is_some()
3992 {
3993 return Err(TermBuilderError::incompatible_config(
3994 "ps/bspline smooth: specify either explicit knots=[...] positions or \
3995 k=<basis_dim> (not both); the basis size is fixed by the knot vector",
3996 )
3997 .to_string());
3998 }
3999 let knots = clamped_knot_vector_from_internal_positions(data_range, &positions, degree)
4000 .map_err(|e| e.to_string())?;
4001 return Ok(BSplineKnotSpec::Provided(knots));
4002 }
4003 match parse_knot_placement(options)? {
4004 BSplineKnotPlacement::Uniform => Ok(BSplineKnotSpec::Generate {
4005 data_range,
4006 num_internal_knots: n_knots,
4007 }),
4008 BSplineKnotPlacement::Quantile => {
4009 crate::basis::auto_knot_vector_1d_quantile(data, n_knots, degree)
4013 .map_err(|e| e.to_string())?;
4014 Ok(BSplineKnotSpec::Automatic {
4015 num_internal_knots: Some(n_knots),
4016 placement: BSplineKnotPlacement::Quantile,
4017 })
4018 }
4019 }
4020}
4021
4022pub fn validate_known_options(
4028 term_name: &str,
4029 options: &BTreeMap<String, String>,
4030 known: &[&str],
4031) -> Result<(), String> {
4032 let known_set: std::collections::BTreeSet<&&str> = known.iter().collect();
4033 for key in options.keys() {
4034 if !known_set.contains(&key.as_str()) {
4035 if term_name == "tensor" && is_tensor_k_axis_option_key(key) {
4036 continue;
4037 }
4038 let key_l = key.to_ascii_lowercase();
4040 let mut suggestions: Vec<&str> = known
4041 .iter()
4042 .filter(|k| {
4043 let kl = k.to_ascii_lowercase();
4044 kl.contains(&key_l) || key_l.contains(&kl) || {
4045 let n = kl
4046 .chars()
4047 .zip(key_l.chars())
4048 .take_while(|(a, b)| a == b)
4049 .count();
4050 n >= 3
4051 }
4052 })
4053 .copied()
4054 .collect();
4055 suggestions.sort_unstable();
4056 suggestions.dedup();
4057 let hint = if suggestions.is_empty() {
4058 String::new()
4059 } else {
4060 format!(" — did you mean one of [{}]?", suggestions.join(", "))
4061 };
4062 return Err(TermBuilderError::invalid_option(format!(
4063 "{term_name}() does not accept option `{key}`{hint}. Valid options: [{}]",
4064 {
4065 let mut sorted = known.to_vec();
4066 sorted.sort_unstable();
4067 sorted.join(", ")
4068 }
4069 ))
4070 .to_string());
4071 }
4072 }
4073 Ok(())
4074}
4075
/// Option key holding an upper bound on the *default* spatial center count.
///
/// Read by `cap_default_spatial_centers`; when the key is absent the default
/// count is used unchanged.
pub const SECONDARY_CENTER_CAP_OPTION: &str = "__secondary_center_cap";
4086
4087pub(crate) fn cap_default_spatial_centers(
4092 options: &BTreeMap<String, String>,
4093 default_count: usize,
4094) -> usize {
4095 match option_usize(options, SECONDARY_CENTER_CAP_OPTION) {
4096 Some(cap) => default_count.min(cap),
4097 None => default_count,
4098 }
4099}
4100
/// Default center count for a Matern smooth.
///
/// The planned count is raised to at least `min(d + 4, n)` and never drops
/// below 1.
fn default_matern_center_count(n: usize, d: usize, planned_count: usize) -> usize {
    // Floor for small samples: d + 4 centers, but no more than there are rows.
    let floor = n.min(d + 4).max(1);
    planned_count.max(floor)
}
4111
4112pub fn parse_countwith_basis_alias(
4113 options: &BTreeMap<String, String>,
4114 primarykey: &str,
4115 default_count: usize,
4116) -> Result<usize, String> {
4117 let primary = option_usize_strict(options, primarykey)?;
4122 let basis_dim = option_usize_any_strict(
4123 options,
4124 &["k", "basis_dim", "basis-dim", "basisdim", "knots"],
4125 )?;
4126 if primary.is_some() && basis_dim.is_some() {
4127 return Err(TermBuilderError::incompatible_config(format!(
4128 "specify either {}=<count> or k=<basis_dim> (not both)",
4129 primarykey
4130 ))
4131 .to_string());
4132 }
4133 Ok(primary.or(basis_dim).unwrap_or(default_count))
4134}
4135
/// Reports whether the caller set a basis size explicitly, either under
/// `primarykey` or under one of the aliases `k`, `basis_dim`, `basis-dim`,
/// `basisdim`, `knots`.
///
/// Only key presence is checked; the values are not parsed.
pub fn has_explicit_countwith_basis_alias(
    options: &BTreeMap<String, String>,
    primarykey: &str,
) -> bool {
    const BASIS_DIM_ALIASES: [&str; 5] = ["k", "basis_dim", "basis-dim", "basisdim", "knots"];

    options.keys().any(|key| {
        let name = key.as_str();
        name == primarykey || BASIS_DIM_ALIASES.contains(&name)
    })
}
4145
4146pub fn parse_cyclic_boundary(
4147 options: &BTreeMap<String, String>,
4148 minv: f64,
4149 maxv: f64,
4150) -> Result<OneDimensionalBoundary, String> {
4151 let cyclic = option_bool(options, "cyclic")
4152 .or_else(|| option_bool(options, "periodic"))
4153 .unwrap_or(false);
4154 if !cyclic {
4155 return Ok(OneDimensionalBoundary::Open);
4156 }
4157 let start = match option_numeric_expr(options, "period_start")? {
4158 Some(v) => v,
4159 None => option_numeric_expr(options, "start")?.unwrap_or(minv),
4160 };
4161 let end = match option_numeric_expr(options, "period_end")? {
4162 Some(v) => v,
4163 None => option_numeric_expr(options, "end")?.unwrap_or(maxv),
4164 };
4165 if end <= start {
4166 return Err(format!(
4167 "cyclic smooth requires period_end/end ({end}) > period_start/start ({start})"
4168 ));
4169 }
4170 Ok(OneDimensionalBoundary::Cyclic { start, end })
4171}
4172
4173pub fn parse_periodic_domain_1d(
4180 options: &BTreeMap<String, String>,
4181 minv: f64,
4182 maxv: f64,
4183) -> Result<(f64, f64), String> {
4184 let start = match option_numeric_expr(options, "period_start")? {
4185 Some(v) => v,
4186 None => option_numeric_expr(options, "start")?.unwrap_or(minv),
4187 };
4188 let end = match option_numeric_expr(options, "period_end")? {
4189 Some(v) => v,
4190 None => option_numeric_expr(options, "end")?.unwrap_or(maxv),
4191 };
4192 if !(start.is_finite() && end.is_finite()) {
4193 return Err(format!(
4194 "periodic smooth domain requires finite endpoints, got ({start}, {end})"
4195 ));
4196 }
4197 if end <= start {
4198 return Err(format!(
4199 "periodic smooth requires period_end/end ({end}) > period_start/start ({start})"
4200 ));
4201 }
4202 Ok((start, end - start))
4203}
4204
4205fn parse_matern_nu(raw: &str) -> Result<MaternNu, String> {
4206 let trimmed = raw.trim();
4207 let lowered = trimmed.to_ascii_lowercase();
4208 match lowered.as_str() {
4209 "1/2" | "0.5" | "half" => return Ok(MaternNu::Half),
4210 "3/2" | "1.5" => return Ok(MaternNu::ThreeHalves),
4211 "5/2" | "2.5" => return Ok(MaternNu::FiveHalves),
4212 "7/2" | "3.5" => return Ok(MaternNu::SevenHalves),
4213 "9/2" | "4.5" => return Ok(MaternNu::NineHalves),
4214 _ => {}
4215 }
4216
4217 let value = if let Some((num, den)) = trimmed.split_once('/') {
4218 let num = num
4219 .trim()
4220 .parse::<f64>()
4221 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?;
4222 let den = den
4223 .trim()
4224 .parse::<f64>()
4225 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?;
4226 if den == 0.0 || !num.is_finite() || !den.is_finite() {
4227 return Err(unsupported_matern_nu_message(raw));
4228 }
4229 num / den
4230 } else {
4231 trimmed
4232 .parse::<f64>()
4233 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?
4234 };
4235
4236 const TOL: f64 = 1e-12;
4237 if (value - 0.5).abs() <= TOL {
4238 Ok(MaternNu::Half)
4239 } else if (value - 1.5).abs() <= TOL {
4240 Ok(MaternNu::ThreeHalves)
4241 } else if (value - 2.5).abs() <= TOL {
4242 Ok(MaternNu::FiveHalves)
4243 } else if (value - 3.5).abs() <= TOL {
4244 Ok(MaternNu::SevenHalves)
4245 } else if (value - 4.5).abs() <= TOL {
4246 Ok(MaternNu::NineHalves)
4247 } else {
4248 Err(unsupported_matern_nu_message(raw))
4249 }
4250}
4251
4252fn unsupported_matern_nu_message(raw: &str) -> String {
4253 TermBuilderError::unsupported_feature(format!(
4254 "unsupported Matern nu '{raw}'; supported half-integer values are 1/2, 3/2, 5/2, 7/2, and 9/2"
4255 ))
4256 .to_string()
4257}
4258
4259#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
4260pub enum DuchonPowerPolicy {
4261 Explicit(f64),
4262 CubicStructuralDefault,
4266}
4267
4268pub fn parse_duchon_power_policy(
4269 options: &BTreeMap<String, String>,
4270) -> Result<DuchonPowerPolicy, String> {
4271 if let Some(raw_nu) = options.get("nu") {
4272 return Err(TermBuilderError::incompatible_config(format!(
4273 "Duchon smooths use power=<number>, not nu='{}'. Use power=1.5, power=2, etc.",
4274 raw_nu
4275 ))
4276 .to_string());
4277 }
4278 match options.get("power") {
4279 Some(raw) => {
4280 let value = raw.parse::<f64>().map_err(|err| {
4281 TermBuilderError::invalid_option(format!(
4282 "invalid Duchon power '{}'; expected a non-negative number such as power=1.5 or power=2: {}",
4283 raw, err
4284 ))
4285 .to_string()
4286 })?;
4287 if !value.is_finite() || value < 0.0 {
4288 return Err(TermBuilderError::invalid_option(format!(
4289 "invalid Duchon power '{}'; expected a finite non-negative number such as power=1.5 or power=2",
4290 raw
4291 ))
4292 .to_string());
4293 }
4294 Ok(DuchonPowerPolicy::Explicit(value))
4295 }
4296 None => Ok(DuchonPowerPolicy::CubicStructuralDefault),
4297 }
4298}
4299
4300pub fn parse_duchon_power(options: &BTreeMap<String, String>) -> Result<f64, String> {
4301 match parse_duchon_power_policy(options)? {
4302 DuchonPowerPolicy::Explicit(power) => Ok(power),
4303 DuchonPowerPolicy::CubicStructuralDefault => Ok(1.5),
4309 }
4310}
4311
4312pub fn parse_duchon_order(
4313 options: &BTreeMap<String, String>,
4314) -> Result<DuchonNullspaceOrder, String> {
4315 match options.get("order") {
4316 None => Ok(DuchonNullspaceOrder::Linear),
4320 Some(raw) => match raw.parse::<usize>() {
4321 Ok(0) => Ok(DuchonNullspaceOrder::Zero),
4322 Ok(1) => Ok(DuchonNullspaceOrder::Linear),
4323 Ok(other) => Ok(DuchonNullspaceOrder::Degree(other)),
4324 Err(_) => Err(TermBuilderError::invalid_option(format!(
4325 "invalid Duchon order '{}'; expected a non-negative integer such as order=0, order=1, or order=2",
4326 raw
4327 ))
4328 .to_string()),
4329 },
4330 }
4331}
4332
4333fn parse_matern_identifiability(
4334 options: &BTreeMap<String, String>,
4335) -> Result<MaternIdentifiability, TermBuilderError> {
4336 let Some(raw) = options.get("identifiability").map(String::as_str) else {
4337 return Ok(MaternIdentifiability::default());
4338 };
4339 match raw.trim().to_ascii_lowercase().as_str() {
4340 "none" => Ok(MaternIdentifiability::None),
4341 "sum_tozero" | "sum-to-zero" | "center_sum_tozero" | "center-sum-to-zero" | "centered" => {
4342 Ok(MaternIdentifiability::CenterSumToZero)
4343 }
4344 "linear" | "center_linear_orthogonal" | "center-linear-orthogonal" => {
4345 Ok(MaternIdentifiability::CenterLinearOrthogonal)
4346 }
4347 other => Err(TermBuilderError::unsupported_feature(format!(
4348 "invalid Matérn identifiability '{other}'; expected one of: none, sum_tozero, linear"
4349 ))),
4350 }
4351}
4352
4353fn parse_spatial_identifiability(
4354 options: &BTreeMap<String, String>,
4355) -> Result<SpatialIdentifiability, TermBuilderError> {
4356 let Some(raw) = options.get("identifiability").map(String::as_str) else {
4357 return Ok(SpatialIdentifiability::default());
4358 };
4359 match raw.trim().to_ascii_lowercase().as_str() {
4360 "none" => Ok(SpatialIdentifiability::None),
4361 "orthogonal"
4362 | "orthogonal_to_parametric"
4363 | "orthogonal-to-parametric"
4364 | "parametric_orthogonal" => Ok(SpatialIdentifiability::OrthogonalToParametric),
4365 "frozen" => Err(TermBuilderError::unsupported_feature(
4366 "spatial identifiability 'frozen' is internal-only; use none or orthogonal_to_parametric",
4367 )),
4368 other => Err(TermBuilderError::unsupported_feature(format!(
4369 "invalid spatial identifiability '{other}'; expected one of: none, orthogonal_to_parametric"
4370 ))),
4371 }
4372}
4373
4374#[cfg(test)]
4375mod tests {
4376 use super::*;
4377 use crate::inference::formula_dsl::parse_formula;
4378 use gam_data::{DataSchema, SchemaColumn};
4379 use ndarray::Array2;
4380 use std::collections::BTreeMap;
4381
4382 fn continuous_dataset(headers: &[&str], rows: Vec<Vec<f64>>) -> Dataset {
4383 let nrows = rows.len();
4384 let ncols = headers.len();
4385 let values = Array2::from_shape_vec(
4386 (nrows, ncols),
4387 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
4388 )
4389 .expect("rectangular test data");
4390 Dataset {
4391 headers: headers.iter().map(|name| name.to_string()).collect(),
4392 values,
4393 schema: DataSchema {
4394 columns: headers
4395 .iter()
4396 .map(|name| SchemaColumn {
4397 name: name.to_string(),
4398 kind: ColumnKindTag::Continuous,
4399 levels: vec![],
4400 })
4401 .collect(),
4402 },
4403 column_kinds: vec![ColumnKindTag::Continuous; ncols],
4404 }
4405 }
4406
4407 fn factor_dataset() -> Dataset {
4408 let rows = (0..24)
4409 .map(|i| {
4410 let x = i as f64 / 23.0;
4411 let g = (i % 2) as f64;
4412 vec![x + g, x, g]
4413 })
4414 .collect::<Vec<_>>();
4415 Dataset {
4416 headers: vec!["y".into(), "x".into(), "g".into()],
4417 values: Array2::from_shape_vec(
4418 (rows.len(), 3),
4419 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
4420 )
4421 .expect("rectangular factor test data"),
4422 schema: DataSchema {
4423 columns: vec![
4424 SchemaColumn {
4425 name: "y".into(),
4426 kind: ColumnKindTag::Continuous,
4427 levels: vec![],
4428 },
4429 SchemaColumn {
4430 name: "x".into(),
4431 kind: ColumnKindTag::Continuous,
4432 levels: vec![],
4433 },
4434 SchemaColumn {
4435 name: "g".into(),
4436 kind: ColumnKindTag::Categorical,
4437 levels: vec!["a".into(), "b".into()],
4438 },
4439 ],
4440 },
4441 column_kinds: vec![
4442 ColumnKindTag::Continuous,
4443 ColumnKindTag::Continuous,
4444 ColumnKindTag::Categorical,
4445 ],
4446 }
4447 }
4448
4449 #[test]
4457 fn default_univariate_thinplate_basis_dim_is_modest() {
4458 let n = 300usize;
4461 let rows: Vec<Vec<f64>> = (0..n)
4462 .map(|i| {
4463 let x = -3.0 + 6.0 * (i as f64) / ((n - 1) as f64);
4464 vec![x.sin(), x]
4465 })
4466 .collect();
4467 let ds = continuous_dataset(&["y", "x"], rows);
4468
4469 let mut options = BTreeMap::new();
4470 options.insert("bs".to_string(), "tp".to_string());
4471
4472 let mut notes = Vec::new();
4473 let basis = build_smooth_basis(
4474 SmoothKind::S,
4475 &["x".to_string()],
4476 &[1],
4477 &options,
4478 &ds,
4479 &mut notes,
4480 &ResourcePolicy::default_library(),
4481 1,
4482 )
4483 .expect("build default univariate tp smooth");
4484
4485 let centers = match &basis {
4486 SmoothBasisSpec::ThinPlate { spec, .. } => match &spec.center_strategy {
4487 CenterStrategy::Auto(inner) => match inner.as_ref() {
4488 CenterStrategy::FarthestPoint { num_centers }
4489 | CenterStrategy::EqualMass { num_centers }
4490 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4491 | CenterStrategy::KMeans { num_centers, .. } => *num_centers,
4492 other => panic!("unexpected auto inner center strategy: {other:?}"),
4493 },
4494 CenterStrategy::FarthestPoint { num_centers }
4495 | CenterStrategy::EqualMass { num_centers }
4496 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4497 | CenterStrategy::KMeans { num_centers, .. } => *num_centers,
4498 other => panic!("unexpected center strategy: {other:?}"),
4499 },
4500 other => panic!("expected ThinPlate basis, got {other:?}"),
4501 };
4502
4503 assert!(
4507 centers >= 1,
4508 "default univariate tp must still build a usable basis (centers={centers})",
4509 );
4510 }
4511
4512 #[test]
4523 fn default_matern_2d_seeds_resolving_length_scale_not_overscaled_diameter() {
4524 let side = 24usize; let mut rows: Vec<Vec<f64>> = Vec::with_capacity(side * side);
4529 for i in 0..side {
4530 for j in 0..side {
4531 let x1 = i as f64 / (side - 1) as f64; let x2 = j as f64 / (side - 1) as f64; let y = (6.0 * x1).sin() * (6.0 * x2).cos();
4534 rows.push(vec![y, x1, x2]);
4535 }
4536 }
4537 let n = rows.len();
4538 let ds = continuous_dataset(&["y", "x1", "x2"], rows);
4539
4540 let mut options = BTreeMap::new();
4541 options.insert("bs".to_string(), "gp".to_string()); let mut notes = Vec::new();
4543 let mut basis = build_smooth_basis(
4544 SmoothKind::S,
4545 &["x1".to_string(), "x2".to_string()],
4546 &[1, 2],
4547 &options,
4548 &ds,
4549 &mut notes,
4550 &ResourcePolicy::default_library(),
4551 1,
4552 )
4553 .expect("build default 2-D matern smooth");
4554
4555 let (feature_cols, seeded_length_scale) = match &basis {
4557 SmoothBasisSpec::Matern {
4558 feature_cols, spec, ..
4559 } => (feature_cols.clone(), spec.length_scale),
4560 other => panic!("expected Matern basis, got {other:?}"),
4561 };
4562 assert_eq!(
4563 seeded_length_scale, 0.0,
4564 "default matern() must leave length_scale at the 0.0 auto sentinel \
4565 (got {seeded_length_scale}); a non-zero diameter default re-enters the \
4566 over-smoothed basin and disables the planner's wiggly-side auto-init",
4567 );
4568
4569 crate::smooth::auto_init_length_scale_in_basis(ds.values.view(), &mut basis);
4573 let realized = match &basis {
4574 SmoothBasisSpec::Matern { spec, .. } => spec.length_scale,
4575 other => panic!("expected Matern basis after auto-init, got {other:?}"),
4576 };
4577 let expected =
4578 crate::smooth::auto_initial_length_scale(ds.values.view(), &feature_cols);
4579 assert!(
4580 (realized - expected).abs() <= 1e-12,
4581 "auto-init must seed the wiggly-side length scale max_range/sqrt(n) \
4582 (expected {expected}, got {realized})",
4583 );
4584
4585 let max_range = 1.0_f64; assert!(
4590 realized < max_range / 4.0,
4591 "matern seed length_scale {realized} must be in the resolving regime, \
4592 not the over-smoothed diameter corner (n={n}, max_range≈{max_range})",
4593 );
4594 }
4595
4596 fn inferred_tensor_basis_product(ds: &Dataset) -> usize {
4597 let parsed = parse_formula("y ~ te(theta, h)").expect("parse tensor formula");
4598 let col_map = ds.column_map();
4599 let mut notes = Vec::new();
4600 let terms = build_termspec(
4601 &parsed.terms,
4602 ds,
4603 &col_map,
4604 &mut notes,
4605 &ResourcePolicy::default_library(),
4606 )
4607 .expect("build tensor termspec");
4608 let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
4609 panic!("expected tensor smooth");
4610 };
4611 spec.marginalspecs
4612 .iter()
4613 .map(|marginal| match marginal.knotspec {
4614 BSplineKnotSpec::Generate {
4615 num_internal_knots, ..
4616 } => num_internal_knots + marginal.degree + 1,
4617 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
4618 BSplineKnotSpec::Automatic {
4619 num_internal_knots: Some(num_internal_knots),
4620 ..
4621 } => num_internal_knots + marginal.degree + 1,
4622 BSplineKnotSpec::Automatic {
4623 num_internal_knots: None,
4624 ..
4625 } => panic!("test helper cannot infer automatic knot count"),
4626 BSplineKnotSpec::Provided(ref knots) => {
4627 knots.len().saturating_sub(marginal.degree + 1)
4628 }
4629 BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
4631 })
4632 .product()
4633 }
4634
4635 fn tensor_margin_basis_sizes(ds: &Dataset, formula: &str) -> Vec<usize> {
4636 let parsed = parse_formula(formula).expect("parse tensor formula");
4637 let col_map = ds.column_map();
4638 let mut notes = Vec::new();
4639 let terms = build_termspec(
4640 &parsed.terms,
4641 ds,
4642 &col_map,
4643 &mut notes,
4644 &ResourcePolicy::default_library(),
4645 )
4646 .expect("build tensor termspec");
4647 let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
4648 panic!("expected tensor smooth");
4649 };
4650 spec.marginalspecs
4651 .iter()
4652 .map(|marginal| match marginal.knotspec {
4653 BSplineKnotSpec::Generate {
4654 num_internal_knots, ..
4655 } => num_internal_knots + marginal.degree + 1,
4656 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
4657 BSplineKnotSpec::Automatic {
4658 num_internal_knots: Some(num_internal_knots),
4659 ..
4660 } => num_internal_knots + marginal.degree + 1,
4661 BSplineKnotSpec::Automatic {
4662 num_internal_knots: None,
4663 ..
4664 } => panic!("test helper cannot infer automatic knot count"),
4665 BSplineKnotSpec::Provided(ref knots) => {
4666 knots.len().saturating_sub(marginal.degree + 1)
4667 }
4668 BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
4670 })
4671 .collect()
4672 }
4673
4674 #[test]
4675 fn validate_known_options_lists_valid_option_names_for_unknown_parameter() {
4676 let mut options = BTreeMap::new();
4677 options.insert("lengt_scale".to_string(), "0.25".to_string());
4678 let err = validate_known_options(
4679 "matern",
4680 &options,
4681 &["type", "bs", "length_scale", "centers", "k", "nu"],
4682 )
4683 .expect_err("unknown smooth option should be rejected");
4684 assert!(
4685 err.contains("matern() does not accept option `lengt_scale`"),
4686 "error should name the invalid option, got: {err}"
4687 );
4688 assert!(
4689 err.contains("did you mean one of [length_scale]"),
4690 "error should suggest the closest valid option, got: {err}"
4691 );
4692 assert!(
4693 err.contains("Valid options: ["),
4694 "error should list valid option names, got: {err}"
4695 );
4696 }
4697
4698 #[test]
4699 fn tensor_k_accepts_square_bracket_per_margin_list() {
4700 let ds = continuous_dataset(
4701 &["y", "x", "z"],
4702 (0..40)
4703 .map(|i| {
4704 let x = i as f64 / 39.0;
4705 let z = ((i * 7) % 40) as f64 / 39.0;
4706 vec![x.sin() + z.cos(), x, z]
4707 })
4708 .collect(),
4709 );
4710
4711 assert_eq!(
4712 tensor_margin_basis_sizes(&ds, "y ~ te(x, z, k=[5, 6])"),
4713 vec![5, 6],
4714 "square-bracket k lists should materialize the requested per-margin values"
4715 );
4716 }
4717
4718 #[test]
4719 fn parse_cylinder_periodic_options_match_requested_forms() {
4720 let mut opts = BTreeMap::new();
4721 opts.insert("periodic".to_string(), "[0]".to_string());
4722 opts.insert("period".to_string(), "[2*pi, None]".to_string());
4723 let axes = parse_periodic_axes(&opts, 2).expect("axes");
4724 let periods = parse_periods(&opts, &axes).expect("periods");
4725 assert_eq!(axes, vec![true, false]);
4726 assert!((periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4727 assert_eq!(periods[1], None);
4728
4729 let mut boundary_opts = BTreeMap::new();
4730 boundary_opts.insert(
4731 "boundary".to_string(),
4732 "['periodic', 'natural']".to_string(),
4733 );
4734 boundary_opts.insert("period".to_string(), "[2*pi, None]".to_string());
4735 let boundary_axes = parse_periodic_axes(&boundary_opts, 2).expect("boundary axes");
4736 let boundary_periods =
4737 parse_periods(&boundary_opts, &boundary_axes).expect("boundary periods");
4738 assert_eq!(boundary_axes, vec![true, false]);
4739 assert!((boundary_periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4740 assert_eq!(boundary_periods[1], None);
4741
4742 let mut unicode_opts = BTreeMap::new();
4743 unicode_opts.insert("periodic".to_string(), "[0,1]".to_string());
4744 unicode_opts.insert("period".to_string(), "[2π, τ]".to_string());
4745 let unicode_axes = parse_periodic_axes(&unicode_opts, 2).expect("unicode axes");
4746 let unicode_periods = parse_periods(&unicode_opts, &unicode_axes).expect("unicode periods");
4747 assert_eq!(unicode_axes, vec![true, true]);
4748 assert!((unicode_periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4749 assert!((unicode_periods[1].unwrap() - std::f64::consts::TAU).abs() < 1e-12);
4750 }
4751
4752 #[test]
4753 fn parse_single_axis_periodic_zero_as_axis_not_false() {
4754 let mut opts = BTreeMap::new();
4755 opts.insert("periodic".to_string(), "[0]".to_string());
4756 opts.insert("period".to_string(), "2*pi".to_string());
4757 opts.insert("origin".to_string(), "0".to_string());
4758 let axes = parse_periodic_axes(&opts, 1).expect("axes");
4759 let periods = parse_periods(&opts, &axes).expect("periods");
4760 let origins = parse_period_origins(&opts, &axes).expect("origins");
4761 assert_eq!(axes, vec![true]);
4762 assert!((periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4763 assert_eq!(origins[0], Some(0.0));
4764 }
4765
4766 #[test]
4767 fn one_dimensional_bspline_accepts_boundary_periodic() {
4768 let ds = continuous_dataset(
4769 &["y", "theta"],
4770 (0..16)
4771 .map(|i| {
4772 let theta = std::f64::consts::TAU * i as f64 / 16.0;
4773 vec![theta.sin(), theta]
4774 })
4775 .collect(),
4776 );
4777 let parsed = parse_formula("y ~ s(theta, boundary=periodic, period=2*pi, origin=0, k=8)")
4778 .expect("parse");
4779 let col_map = ds.column_map();
4780 let mut notes = Vec::new();
4781 let terms = build_termspec(
4782 &parsed.terms,
4783 &ds,
4784 &col_map,
4785 &mut notes,
4786 &gam_runtime::resource::ResourcePolicy::default_library(),
4787 )
4788 .expect("periodic boundary should build");
4789 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4790 panic!("expected 1D B-spline");
4791 };
4792 assert!(matches!(
4793 &spec.knotspec,
4794 BSplineKnotSpec::PeriodicUniform {
4795 data_range,
4796 num_basis: 8
4797 } if *data_range == (0.0, std::f64::consts::TAU)
4798 ));
4799 }
4800
4801 #[test]
4802 fn univariate_smooth_accepts_mgcv_cubic_regression_aliases() {
4803 let ds = continuous_dataset(
4804 &["y", "x"],
4805 (0..32)
4806 .map(|i| {
4807 let x = i as f64 / 31.0;
4808 vec![x * x, x]
4809 })
4810 .collect(),
4811 );
4812 let col_map = ds.column_map();
4813
4814 for (selector, expect_double_penalty) in [("cr", false), ("cs", true)] {
4815 let formula = format!("y ~ s(x, bs='{selector}')");
4816 let parsed = parse_formula(&formula).expect("parse cr/cs smooth");
4817 let mut notes = Vec::new();
4818 let terms = build_termspec(
4819 &parsed.terms,
4820 &ds,
4821 &col_map,
4822 &mut notes,
4823 &gam_runtime::resource::ResourcePolicy::default_library(),
4824 )
4825 .unwrap_or_else(|err| panic!("bs='{selector}' must build a 1-D smooth, got: {err:?}"));
4826 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4827 panic!(
4828 "bs='{selector}' must lower to a BSpline1D; got {:?}",
4829 terms.smooth_terms[0].basis
4830 );
4831 };
4832 assert_eq!(
4833 spec.double_penalty, expect_double_penalty,
4834 "bs='{selector}' must default double_penalty to mgcv's convention \
4835 (cr=no-shrinkage, cs=shrinkage); got double_penalty={}",
4836 spec.double_penalty
4837 );
4838 }
4839 }
4840
4841 #[test]
4842 fn univariate_ps_small_k_degree_reduces_through_build() {
4843 let ds = continuous_dataset(
4852 &["y", "x"],
4853 (0..32)
4854 .map(|i| {
4855 let x = i as f64 / 31.0;
4856 vec![x * x, x]
4857 })
4858 .collect(),
4859 );
4860 let col_map = ds.column_map();
4861
4862 for formula in ["y ~ s(x, bs='ps', k=3)", "y ~ s(x, k=3)"] {
4863 let parsed = parse_formula(formula).expect("parse small-k ps/cr smooth");
4864 let mut notes = Vec::new();
4865 let terms = build_termspec(
4866 &parsed.terms,
4867 &ds,
4868 &col_map,
4869 &mut notes,
4870 &gam_runtime::resource::ResourcePolicy::default_library(),
4871 )
4872 .unwrap_or_else(|err| {
4873 panic!("`{formula}` must degree-reduce, not error; got: {err:?}")
4874 });
4875 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4876 panic!(
4877 "`{formula}` must lower to a BSpline1D; got {:?}",
4878 terms.smooth_terms[0].basis
4879 );
4880 };
4881 assert_eq!(
4882 spec.degree, 2,
4883 "`{formula}` must drop the cubic default to a quadratic basis"
4884 );
4885 let num_internal = match &spec.knotspec {
4886 BSplineKnotSpec::Generate {
4887 num_internal_knots, ..
4888 } => *num_internal_knots,
4889 BSplineKnotSpec::Automatic {
4890 num_internal_knots: Some(n),
4891 ..
4892 } => *n,
4893 other => panic!("`{formula}` unexpected knotspec: {other:?}"),
4894 };
4895 assert_eq!(
4896 num_internal, 0,
4897 "`{formula}` must have zero internal knots (num_basis = k = 3)"
4898 );
4899 assert!(
4901 spec.penalty_order >= 1 && spec.penalty_order <= spec.degree,
4902 "`{formula}` penalty_order {} must satisfy 1 <= order <= degree={}",
4903 spec.penalty_order,
4904 spec.degree
4905 );
4906 }
4907 }
4908
4909 #[test]
4910 fn formula_shape_constraint_round_trips_and_rejects_bogus() {
4911 let ds = continuous_dataset(
4912 &["y", "x"],
4913 (0..32)
4914 .map(|i| {
4915 let x = i as f64 / 31.0;
4916 vec![x * x, x]
4917 })
4918 .collect(),
4919 );
4920 let col_map = ds.column_map();
4921
4922 let parsed =
4923 parse_formula("y ~ s(x, shape=monotone_increasing)").expect("parse monotone smooth");
4924 let mut notes = Vec::new();
4925 let terms = build_termspec(
4926 &parsed.terms,
4927 &ds,
4928 &col_map,
4929 &mut notes,
4930 &gam_runtime::resource::ResourcePolicy::default_library(),
4931 )
4932 .expect("monotone smooth should build");
4933 assert_eq!(
4934 terms.smooth_terms[0].shape,
4935 ShapeConstraint::MonotoneIncreasing
4936 );
4937
4938 let parsed_bad = parse_formula("y ~ s(x, shape=bogus)").expect("parse bogus shape");
4939 let mut notes_bad = Vec::new();
4940 let err = build_termspec(
4941 &parsed_bad.terms,
4942 &ds,
4943 &col_map,
4944 &mut notes_bad,
4945 &gam_runtime::resource::ResourcePolicy::default_library(),
4946 )
4947 .expect_err("bogus shape must error");
4948 assert!(
4949 format!("{err:?}").contains("unknown shape constraint"),
4950 "got: {err:?}"
4951 );
4952 }
4953
4954 #[test]
4955 fn default_sphere_smooth_uses_spherical_farthest_point_centers() {
4956 let ds = continuous_dataset(
4957 &["y", "lat", "lon"],
4958 (0..24)
4959 .map(|i| {
4960 let t = i as f64 / 24.0;
4961 let lat = -60.0 + 120.0 * t;
4962 let lon = -180.0 + 360.0 * ((7 * i) % 24) as f64 / 24.0;
4963 vec![lat.to_radians().sin(), lat, lon]
4964 })
4965 .collect(),
4966 );
4967 let parsed = parse_formula("y ~ sphere(lat, lon)").expect("parse");
4968 let col_map = ds.column_map();
4969 let mut notes = Vec::new();
4970 let terms = build_termspec(
4971 &parsed.terms,
4972 &ds,
4973 &col_map,
4974 &mut notes,
4975 &gam_runtime::resource::ResourcePolicy::default_library(),
4976 )
4977 .expect("build sphere termspec");
4978 let SmoothBasisSpec::Sphere { spec, .. } = &terms.smooth_terms[0].basis else {
4979 panic!("expected sphere term");
4980 };
4981 assert!(matches!(
4982 spec.center_strategy,
4983 CenterStrategy::FarthestPoint { .. }
4984 ));
4985 }
4986
4987 #[test]
4988 fn one_dimensional_duchon_defaults_to_scale_free_length_scale() {
4989 let ds = continuous_dataset(
4990 &["y", "x"],
4991 (0..32)
4992 .map(|i| {
4993 let x = i as f64 / 31.0;
4994 vec![(std::f64::consts::TAU * x).sin(), x]
4995 })
4996 .collect(),
4997 );
4998 let parsed = parse_formula("y ~ duchon(x)").expect("parse");
4999 let col_map = ds.column_map();
5000 let mut notes = Vec::new();
5001 let terms = build_termspec(
5002 &parsed.terms,
5003 &ds,
5004 &col_map,
5005 &mut notes,
5006 &gam_runtime::resource::ResourcePolicy::default_library(),
5007 )
5008 .expect("build default duchon termspec");
5009 let SmoothBasisSpec::Duchon { spec, .. } = &terms.smooth_terms[0].basis else {
5010 panic!("expected Duchon term");
5011 };
5012 assert_eq!(spec.length_scale, None);
5013 }
5014
5015 #[test]
5016 fn one_dimensional_duchon_length_scale_opts_into_hybrid_mode() {
5017 let ds = continuous_dataset(
5018 &["y", "x"],
5019 (0..32)
5020 .map(|i| {
5021 let x = i as f64 / 31.0;
5022 vec![(std::f64::consts::TAU * x).sin(), x]
5023 })
5024 .collect(),
5025 );
5026 let parsed = parse_formula("y ~ duchon(x, length_scale=0.25)").expect("parse");
5027 let col_map = ds.column_map();
5028 let mut notes = Vec::new();
5029 let terms = build_termspec(
5030 &parsed.terms,
5031 &ds,
5032 &col_map,
5033 &mut notes,
5034 &gam_runtime::resource::ResourcePolicy::default_library(),
5035 )
5036 .expect("build hybrid duchon termspec");
5037 let SmoothBasisSpec::Duchon { spec, .. } = &terms.smooth_terms[0].basis else {
5038 panic!("expected Duchon term");
5039 };
5040 assert_eq!(spec.length_scale, Some(0.25));
5041 }
5042
5043 #[test]
5044 fn parse_matern_nu_accepts_equivalent_half_integer_forms() {
5045 let cases = [
5046 ("1/2", MaternNu::Half),
5047 (" 1 / 2 ", MaternNu::Half),
5048 (".5", MaternNu::Half),
5049 ("0.50", MaternNu::Half),
5050 ("half", MaternNu::Half),
5051 ("3 / 2", MaternNu::ThreeHalves),
5052 ("1.50", MaternNu::ThreeHalves),
5053 ("5 / 2", MaternNu::FiveHalves),
5054 ("2.500000000000", MaternNu::FiveHalves),
5055 ("7 / 2", MaternNu::SevenHalves),
5056 ("3.50", MaternNu::SevenHalves),
5057 ("9 / 2", MaternNu::NineHalves),
5058 ("4.50", MaternNu::NineHalves),
5059 ];
5060 for (raw, expected) in cases {
5061 let parsed = parse_matern_nu(raw).expect(raw);
5062 assert!(
5063 matches!(
5064 (parsed, expected),
5065 (MaternNu::Half, MaternNu::Half)
5066 | (MaternNu::ThreeHalves, MaternNu::ThreeHalves)
5067 | (MaternNu::FiveHalves, MaternNu::FiveHalves)
5068 | (MaternNu::SevenHalves, MaternNu::SevenHalves)
5069 | (MaternNu::NineHalves, MaternNu::NineHalves)
5070 ),
5071 "parsed {raw:?} as {parsed:?}, expected {expected:?}"
5072 );
5073 }
5074 }
5075
5076 #[test]
5077 fn parse_matern_nu_rejects_unsupported_or_invalid_values() {
5078 for raw in ["1", "2", "11/2", "1/0", "nan", "fast"] {
5079 let err = parse_matern_nu(raw).expect_err(raw);
5080 assert!(
5081 err.contains("supported half-integer values"),
5082 "unexpected error for {raw:?}: {err}"
5083 );
5084 }
5085 }
5086
5087 #[test]
5088 fn parse_ps_k_promotes_underexpressive_cubic_basis() {
5089 let mut opts = BTreeMap::new();
5090 opts.insert("k".to_string(), "4".to_string());
5091 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=4");
5092 assert_eq!(internal, 2);
5093 assert_eq!(eff_degree, 3);
5094 assert!(!inferred);
5095
5096 opts.insert("k".to_string(), "6".to_string());
5097 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=6");
5098 assert_eq!(internal, 2);
5099 assert_eq!(eff_degree, 3);
5100 assert!(!inferred);
5101
5102 opts.insert("k".to_string(), "10".to_string());
5103 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=10");
5104 assert_eq!(internal, 6);
5105 assert_eq!(eff_degree, 3);
5106 assert!(!inferred);
5107 }
5108
5109 #[test]
5110 fn parse_ps_internal_knots_drops_degree_for_small_k() {
5111 let mut opts = BTreeMap::new();
5116 opts.insert("k".to_string(), "3".to_string());
5117 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=3");
5118 assert_eq!(eff_degree, 2);
5119 assert_eq!(internal, 0);
5120 assert!(!inferred);
5121
5122 opts.insert("k".to_string(), "2".to_string());
5125 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=2");
5126 assert_eq!(eff_degree, 1);
5127 assert_eq!(internal, 0);
5128 assert!(!inferred);
5129
5130 opts.insert("k".to_string(), "1".to_string());
5134 let err = parse_ps_internal_knots(&opts, 3, 20)
5135 .expect_err("k=1 is below the irreducible spline floor");
5136 assert!(err.contains("requires k >= 2"), "unexpected error: {err}");
5137
5138 opts.insert("k".to_string(), "4".to_string());
5141 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=4");
5142 assert_eq!(eff_degree, 3);
5143 assert_eq!(internal, 2);
5144 assert!(!inferred);
5145 }
5146
5147 #[test]
5148 fn factor_smooth_marginal_degree_reduces_for_small_k() {
5149 let ds = factor_dataset();
5150 let col_map = ds.column_map();
5151
5152 for (k, expected_degree) in [(3usize, 2usize), (2usize, 1usize)] {
5153 let parsed =
5154 parse_formula(&format!("y ~ s(x, g, bs=fs, k={k})")).expect("parse factor smooth");
5155 let mut notes = Vec::new();
5156 let terms = build_termspec(
5157 &parsed.terms,
5158 &ds,
5159 &col_map,
5160 &mut notes,
5161 &gam_runtime::resource::ResourcePolicy::default_library(),
5162 )
5163 .unwrap_or_else(|err| panic!("fs k={k} should degree-reduce, got: {err:?}"));
5164 let SmoothBasisSpec::FactorSmooth { spec } = &terms.smooth_terms[0].basis else {
5165 panic!(
5166 "expected factor smooth, got {:?}",
5167 terms.smooth_terms[0].basis
5168 );
5169 };
5170 assert_eq!(spec.marginal.degree, expected_degree);
5171 assert!(
5172 spec.marginal.penalty_order <= spec.marginal.degree,
5173 "penalty_order {} must be clamped to degree {}",
5174 spec.marginal.penalty_order,
5175 spec.marginal.degree
5176 );
5177 let basis_size = match spec.marginal.knotspec {
5178 BSplineKnotSpec::Generate {
5179 num_internal_knots, ..
5180 } => num_internal_knots + spec.marginal.degree + 1,
5181 BSplineKnotSpec::Automatic {
5182 num_internal_knots: Some(num_internal_knots),
5183 ..
5184 } => num_internal_knots + spec.marginal.degree + 1,
5185 ref other => panic!("unexpected factor-smooth knotspec: {other:?}"),
5186 };
5187 assert_eq!(basis_size, k);
5188 }
5189 }
5190
5191 fn ternary_factor_dataset() -> Dataset {
5194 let rows = (0..120)
5195 .map(|i| {
5196 let x = (i % 3) as f64;
5197 let g = (i % 2) as f64;
5198 vec![x + g, x, g]
5199 })
5200 .collect::<Vec<_>>();
5201 Dataset {
5202 headers: vec!["y".into(), "x".into(), "g".into()],
5203 values: Array2::from_shape_vec(
5204 (rows.len(), 3),
5205 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5206 )
5207 .expect("rectangular ternary factor test data"),
5208 schema: DataSchema {
5209 columns: vec![
5210 SchemaColumn {
5211 name: "y".into(),
5212 kind: ColumnKindTag::Continuous,
5213 levels: vec![],
5214 },
5215 SchemaColumn {
5216 name: "x".into(),
5217 kind: ColumnKindTag::Continuous,
5218 levels: vec![],
5219 },
5220 SchemaColumn {
5221 name: "g".into(),
5222 kind: ColumnKindTag::Categorical,
5223 levels: vec!["a".into(), "b".into()],
5224 },
5225 ],
5226 },
5227 column_kinds: vec![
5228 ColumnKindTag::Continuous,
5229 ColumnKindTag::Continuous,
5230 ColumnKindTag::Categorical,
5231 ],
5232 }
5233 }
5234
5235 #[test]
5236 fn univariate_cr_smooth_caps_knots_to_data_support() {
5237 let ds = continuous_dataset(
5243 &["y", "x"],
5244 (0..90)
5245 .map(|i| vec![(i % 3) as f64, (i % 3) as f64])
5246 .collect(),
5247 );
5248 let col_map = ds.column_map();
5249 let parsed = parse_formula("y ~ s(x, bs=cr, k=10)").expect("parse cr smooth");
5250 let mut notes = Vec::new();
5251 let terms = build_termspec(
5252 &parsed.terms,
5253 &ds,
5254 &col_map,
5255 &mut notes,
5256 &gam_runtime::resource::ResourcePolicy::default_library(),
5257 )
5258 .expect("cr k=10 must cap to data support instead of erroring");
5259 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
5260 panic!("expected BSpline1D for s(x, bs=cr)");
5261 };
5262 let BSplineKnotSpec::NaturalCubicRegression { knots } = &spec.knotspec else {
5263 panic!("expected cr knotspec, got {:?}", spec.knotspec);
5264 };
5265 assert_eq!(knots.len(), 3, "cr basis not capped to 3 distinct values");
5267 assert_eq!(knots.as_slice().unwrap(), &[0.0, 1.0, 2.0]);
5268 assert!(
5270 notes.iter().any(|n| n.contains("data-support cap")),
5271 "cap not reported in inference notes: {notes:?}"
5272 );
5273 }
5274
/// `s(x, bs=cr, k=10)` on a two-valued covariate must build a 1-D B-spline
/// term whose knotspec is not the natural cubic regression variant, and the
/// degradation must be mentioned in the inference notes.
#[test]
fn univariate_cr_smooth_binary_covariate_degrades_to_bspline() {
    // 80 rows alternating 0 / 1 in both the response and the covariate.
    let rows: Vec<Vec<f64>> = (0..80)
        .map(|i| {
            let bit = (i % 2) as f64;
            vec![bit, bit]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x"], rows);
    let col_map = ds.column_map();
    let parsed = parse_formula("y ~ s(x, bs=cr, k=10)").expect("parse cr smooth");
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("binary cr must degrade to B-spline instead of erroring");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => spec,
        _ => panic!("expected BSpline1D for s(x, bs=cr)"),
    };

    let built_cr = matches!(
        spec.knotspec,
        BSplineKnotSpec::NaturalCubicRegression { .. }
    );
    assert!(
        !built_cr,
        "binary covariate must NOT build a cr basis, got {:?}",
        spec.knotspec
    );

    let degradation_reported = notes
        .iter()
        .any(|n| n.contains("Degraded to the linear B-spline"));
    assert!(
        degradation_reported,
        "degradation not reported in inference notes: {notes:?}"
    );
}
5315
/// `s(x, g, bs=sz, k=10)` on the ternary factor fixture must build a factor
/// smooth whose marginal knotspec is not the natural cubic regression variant.
#[test]
fn sz_factor_smooth_low_cardinality_uses_bspline_marginal() {
    let ds = ternary_factor_dataset();
    let col_map = ds.column_map();
    let parsed = parse_formula("y ~ s(x, g, bs=sz, k=10)").expect("parse sz factor smooth");
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("sz on a ternary covariate must build (B-spline marginal), not hard-fail");

    let factor_spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::FactorSmooth { spec } => spec,
        _ => panic!("expected FactorSmooth for s(x, g, bs=sz)"),
    };

    let marginal_is_cr = matches!(
        factor_spec.marginal.knotspec,
        BSplineKnotSpec::NaturalCubicRegression { .. }
    );
    assert!(
        !marginal_is_cr,
        "sz marginal must be a B-spline (curvature-capable), not the \
         natural-BC cr basis; got {:?}",
        factor_spec.marginal.knotspec
    );
}
5351
5352 fn continuous_x_factor_dataset(n: usize, n_groups: usize) -> Dataset {
5357 let rows = (0..n)
5358 .map(|i| {
5359 let x = i as f64 / (n as f64 - 1.0);
5360 let g = (i % n_groups) as f64;
5361 vec![x + g, x, g]
5362 })
5363 .collect::<Vec<_>>();
5364 let levels: Vec<String> = (0..n_groups).map(|k| format!("g{k}")).collect();
5365 Dataset {
5366 headers: vec!["y".into(), "x".into(), "g".into()],
5367 values: Array2::from_shape_vec(
5368 (rows.len(), 3),
5369 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5370 )
5371 .expect("rectangular continuous-x factor data"),
5372 schema: DataSchema {
5373 columns: vec![
5374 SchemaColumn {
5375 name: "y".into(),
5376 kind: ColumnKindTag::Continuous,
5377 levels: vec![],
5378 },
5379 SchemaColumn {
5380 name: "x".into(),
5381 kind: ColumnKindTag::Continuous,
5382 levels: vec![],
5383 },
5384 SchemaColumn {
5385 name: "g".into(),
5386 kind: ColumnKindTag::Categorical,
5387 levels,
5388 },
5389 ],
5390 },
5391 column_kinds: vec![
5392 ColumnKindTag::Continuous,
5393 ColumnKindTag::Continuous,
5394 ColumnKindTag::Categorical,
5395 ],
5396 }
5397 }
5398
5399 fn factor_smooth_spec_for(formula: &str, ds: &Dataset) -> FactorSmoothSpec {
5400 let col_map = ds.column_map();
5401 let parsed = parse_formula(formula).expect("parse factor smooth formula");
5402 let mut notes = Vec::new();
5403 let terms = build_termspec(
5404 &parsed.terms,
5405 ds,
5406 &col_map,
5407 &mut notes,
5408 &gam_runtime::resource::ResourcePolicy::default_library(),
5409 )
5410 .expect("build factor smooth term");
5411 let SmoothBasisSpec::FactorSmooth { spec } = &terms.smooth_terms[0].basis else {
5412 panic!("expected FactorSmooth basis for `{formula}`");
5413 };
5414 spec.clone()
5415 }
5416
/// Builds the same `s(x, g, k=8)` factor smooth with `bs=sz` and `bs=fs` and
/// compares the results: equal penalty counts, at least two penalties on the
/// `sz` build, a strictly narrower `sz` design, and per-penalty metadata
/// vectors on the `sz` build that match its penalty count.
#[test]
fn sz_factor_smooth_carries_null_space_ridge_like_fs() {
    let ds = continuous_x_factor_dataset(180, 4);
    let mut workspace = crate::basis::BasisWorkspace::new();

    // Build order (sz first, then fs) is kept because both share `workspace`.
    let sz_spec = factor_smooth_spec_for("y ~ s(x, g, bs=sz, k=8)", &ds);
    let sz_built =
        crate::smooth::build_factor_smooth(ds.values.view(), &sz_spec, "sz_term", &mut workspace)
            .expect("build sz factor smooth");

    let fs_spec = factor_smooth_spec_for("y ~ s(x, g, bs=fs, k=8)", &ds);
    let fs_built =
        crate::smooth::build_factor_smooth(ds.values.view(), &fs_spec, "fs_term", &mut workspace)
            .expect("build fs factor smooth");

    // Falls back to the fixture's group count when no frozen levels are stored.
    let n_levels = sz_spec
        .group_frozen_levels
        .as_ref()
        .map_or(4, |levels| levels.len());
    assert!(n_levels >= 3, "test needs >=3 groups, got {n_levels}");

    let sz_penalty_count = sz_built.penalties.len();
    let fs_penalty_count = fs_built.penalties.len();

    assert_eq!(
        sz_penalty_count, fs_penalty_count,
        "sz must carry the same number of penalties as fs (wiggliness + one \
         null-space ridge per marginal null direction); sz had {} (only the \
         wiggliness penalties => null space unpenalized => over-smoothed), fs \
         had {}",
        sz_penalty_count, fs_penalty_count,
    );

    assert!(
        sz_penalty_count >= 2,
        "sz deviation block carries no null-space ridge (penalties={}); the \
         null space is unpenalized and REML over-smooths the deviations",
        sz_penalty_count,
    );

    assert!(
        sz_built.dim < fs_built.dim,
        "sz design width {} must be strictly less than fs width {} \
         (zero-sum contrast drops one level block)",
        sz_built.dim,
        fs_built.dim,
    );

    // Every per-penalty vector must line up with the penalty list.
    assert_eq!(sz_penalty_count, sz_built.nullspaces.len());
    assert_eq!(sz_penalty_count, sz_built.penaltyinfo.len());
    assert_eq!(sz_penalty_count, sz_built.null_eigenvectors.len());
}
5511
5512 fn factor_dataset_l3() -> Dataset {
5523 let rows = (0..30)
5525 .map(|i| {
5526 let x = i as f64 / 29.0;
5527 let g = (i % 3) as f64;
5528 vec![x + g, x, g]
5529 })
5530 .collect::<Vec<_>>();
5531 Dataset {
5532 headers: vec!["y".into(), "x".into(), "g".into()],
5533 values: Array2::from_shape_vec(
5534 (rows.len(), 3),
5535 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5536 )
5537 .expect("rectangular L=3 factor test data"),
5538 schema: DataSchema {
5539 columns: vec![
5540 SchemaColumn {
5541 name: "y".into(),
5542 kind: ColumnKindTag::Continuous,
5543 levels: vec![],
5544 },
5545 SchemaColumn {
5546 name: "x".into(),
5547 kind: ColumnKindTag::Continuous,
5548 levels: vec![],
5549 },
5550 SchemaColumn {
5551 name: "g".into(),
5552 kind: ColumnKindTag::Categorical,
5553 levels: vec!["a".into(), "b".into(), "c".into()],
5554 },
5555 ],
5556 },
5557 column_kinds: vec![
5558 ColumnKindTag::Continuous,
5559 ColumnKindTag::Continuous,
5560 ColumnKindTag::Categorical,
5561 ],
5562 }
5563 }
5564
/// Counts the random-effect terms named `g` produced by `s(x, by=g)` with and
/// without an additional bare `+ g`; both formulas must yield exactly one.
#[test]
fn factor_by_smooth_plus_bare_categorical_does_not_duplicate_factor_block() {
    let ds = factor_dataset_l3();
    let col_map = ds.column_map();

    // Number of random-effect terms named "g" in the built term spec.
    let count_g_blocks = |formula: &str| -> usize {
        let parsed = parse_formula(formula).expect("parse by-smooth formula");
        let mut notes = Vec::new();
        let built = build_termspec(
            &parsed.terms,
            &ds,
            &col_map,
            &mut notes,
            &ResourcePolicy::default_library(),
        );
        let terms = match built {
            Ok(terms) => terms,
            Err(err) => panic!("`{formula}` must build, got: {err:?}"),
        };
        terms
            .random_effect_terms
            .iter()
            .filter(|block| block.name == "g")
            .count()
    };

    let by_only = count_g_blocks("y ~ s(x, by=g, k=10)");
    assert_eq!(
        by_only, 1,
        "`y ~ s(x, by=g)` must produce exactly one `g` design block"
    );

    let by_plus_bare = count_g_blocks("y ~ s(x, by=g, k=10) + g");
    assert_eq!(
        by_plus_bare, 1,
        "`y ~ s(x, by=g) + g` must collapse to ONE `g` block (#1457): the bare \
         `+ g` already owns the factor's level offsets, so the `by=` branch \
         must not add a second, treatment-coded main effect"
    );

    assert_eq!(
        by_plus_bare, by_only,
        "the bare `+ g` collision must add zero extra `g` blocks (#1457)"
    );
}
5614
/// The `boundary`, `periods` and `origins` option keys must parse into two
/// periodic axes with periods 7 and 24 and origins 0 and -12.
#[test]
fn parse_tensor_periods_and_origins_aliases() {
    let opts: BTreeMap<String, String> = [
        ("boundary", "['periodic', 'periodic']"),
        ("periods", "[7, 24]"),
        ("origins", "[0, -12]"),
    ]
    .iter()
    .map(|&(key, value)| (key.to_string(), value.to_string()))
    .collect();

    let axes = parse_periodic_axes(&opts, 2).expect("axes");
    let periods = parse_periods(&opts, &axes).expect("periods");
    let origins = parse_period_origins(&opts, &axes).expect("origins");

    assert_eq!(axes, vec![true, true]);
    assert_eq!(periods, vec![Some(7.0), Some(24.0)]);
    assert_eq!(origins, vec![Some(0.0), Some(-12.0)]);
}
5631
/// `te(theta, h, ..., k=[9,5])` must build a tensor B-spline whose two
/// margins have basis sizes 9 and 5 respectively.
#[test]
fn tensor_smooth_honors_per_margin_k_list() {
    let ds = continuous_dataset(
        &["y", "theta", "h"],
        (0..20)
            .map(|i| {
                let theta = std::f64::consts::TAU * i as f64 / 20.0;
                let h = -1.0 + 2.0 * (i % 5) as f64 / 4.0;
                vec![theta.cos() + h, theta, h]
            })
            .collect(),
    );
    let parsed = parse_formula(
        "y ~ te(theta, h, periodic=[0], period=[2*pi, None], origin=[0, None], k=[9,5])",
    )
    .expect("parse tensor formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let terms = build_termspec(
        &parsed.terms,
        &ds,
        &col_map,
        &mut notes,
        &gam_runtime::resource::ResourcePolicy::default_library(),
    )
    .expect("build tensor terms");
    let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
        panic!("expected tensor B-spline");
    };
    // Basis size per margin, derived from whichever knotspec was chosen.
    let dims = spec
        .marginalspecs
        .iter()
        .map(|m| match &m.knotspec {
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => num_internal_knots + m.degree + 1,
            BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
            // Report the offending knotspec so a failure is diagnosable
            // without re-running under a debugger.
            other => panic!("unexpected tensor marginal knotspec: {other:?}"),
        })
        .collect::<Vec<_>>();
    assert_eq!(dims, vec![9, 5]);
}
5677
/// `te(x, y, k_x=9, k_y=5)` must yield per-margin basis sizes `[9, 5]` as
/// reported by `tensor_margin_basis_sizes`.
#[test]
fn tensor_smooth_honors_per_margin_k_axis_aliases() {
    // 12 rows: resp = x = t, y = 1 - t, with t evenly spaced on [0, 1].
    let rows: Vec<Vec<f64>> = (0..12)
        .map(|i| {
            let t = i as f64 / 11.0;
            vec![t, t, 1.0 - t]
        })
        .collect();
    let ds = continuous_dataset(&["resp", "x", "y"], rows);

    let sizes = tensor_margin_basis_sizes(&ds, "resp ~ te(x, y, k_x=9, k_y=5)");
    assert_eq!(
        sizes,
        vec![9, 5],
        "k_<margin> aliases should materialize requested per-margin values"
    );
}
5695
/// `te(x, b, k=[5, 2])` with a two-valued `b` must build: the continuous
/// margin keeps degree 3 with basis size 5, while the binary margin drops to
/// degree 1 with basis size 2 and a penalty order within `1..=degree`.
#[test]
fn tensor_smooth_low_cardinality_axis_falls_back_to_lower_degree_basis() {
    let ds = continuous_dataset(
        &["y", "x", "b"],
        (0..40)
            .map(|i| {
                let x = i as f64 / 39.0;
                let b = (i % 2) as f64;
                vec![x.sin() + 0.5 * b, x, b]
            })
            .collect(),
    );
    let parsed = parse_formula("y ~ te(x, b, k=[5, 2])").expect("parse tensor with k=[5,2]");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let terms = build_termspec(
        &parsed.terms,
        &ds,
        &col_map,
        &mut notes,
        &gam_runtime::resource::ResourcePolicy::default_library(),
    )
    .expect("build tensor with binary margin");
    let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
        panic!("expected tensor B-spline for te(x, b)");
    };
    let continuous = &spec.marginalspecs[0];
    let binary = &spec.marginalspecs[1];
    assert_eq!(continuous.degree, 3);
    assert_eq!(binary.degree, 1);
    assert!(
        binary.penalty_order >= 1 && binary.penalty_order <= binary.degree,
        "binary margin penalty_order {} must satisfy 1 <= order <= degree={}",
        binary.penalty_order,
        binary.degree
    );
    // Basis size of one margin, derived from whichever knotspec was chosen.
    let basis_size = |m: &BSplineBasisSpec| match &m.knotspec {
        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
        BSplineKnotSpec::Generate {
            num_internal_knots, ..
        } => num_internal_knots + m.degree + 1,
        BSplineKnotSpec::Automatic {
            num_internal_knots: Some(n),
            ..
        } => n + m.degree + 1,
        BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
        // Report the offending knotspec so a failure is diagnosable.
        other => panic!("unexpected tensor marginal knotspec: {other:?}"),
    };
    assert_eq!(basis_size(continuous), 5);
    assert_eq!(basis_size(binary), 2);
}
5758
/// `te(x, b, k=5)` with a two-valued `b` must build: the binary margin ends
/// up with basis size 2 and degree 1, the continuous margin keeps size 5.
#[test]
fn tensor_smooth_uniform_k_is_capped_to_a_low_cardinality_margins_distinct_values() {
    let rows: Vec<Vec<f64>> = (0..40)
        .map(|i| {
            let x = i as f64 / 39.0;
            let b = (i % 2) as f64;
            vec![x.sin() + 0.5 * b, x, b]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x", "b"], rows);
    let parsed = parse_formula("y ~ te(x, b, k=5)").expect("parse tensor with uniform k=5");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("uniform k=5 must auto-cap the binary margin instead of erroring");

    let tensor = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!("expected tensor B-spline for te(x, b)"),
    };

    // Basis size of one margin, derived from whichever knotspec was chosen.
    let margin_width = |margin: &BSplineBasisSpec| match &margin.knotspec {
        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
        BSplineKnotSpec::Generate {
            num_internal_knots, ..
        } => num_internal_knots + margin.degree + 1,
        BSplineKnotSpec::Automatic {
            num_internal_knots: Some(n),
            ..
        } => n + margin.degree + 1,
        BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
        other => panic!("unexpected tensor marginal knotspec: {other:?}"),
    };

    let binary_margin = &tensor.marginalspecs[1];
    assert_eq!(margin_width(binary_margin), 2);
    assert_eq!(binary_margin.degree, 1);
    assert_eq!(margin_width(&tensor.marginalspecs[0]), 5);
}
5812
/// `te(x1, x2, bs=c('tp','tp'), k=c(5,5))` must build a tensor B-spline whose
/// two margins each have basis size 5.
#[test]
fn tensor_all_tp_margins_with_per_margin_k_routes_to_bspline_tensor() {
    let rows: Vec<Vec<f64>> = (0..32)
        .map(|i| {
            let t = i as f64 / 31.0;
            vec![t.sin(), t, 1.0 - t]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x1", "x2"], rows);
    let parsed =
        parse_formula("y ~ te(x1, x2, bs=c('tp','tp'), k=c(5,5))").expect("parse tensor");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build tensor terms with per-margin k");

    let tensor = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!(
            "expected B-spline tensor when k=c(5,5) is supplied with bs=c('tp','tp'), got {:?}",
            terms.smooth_terms[0].basis
        ),
    };

    // Exhaustive on purpose: a new knotspec variant must be handled here.
    let mut dims = Vec::with_capacity(tensor.marginalspecs.len());
    for margin in tensor.marginalspecs.iter() {
        let width = match &margin.knotspec {
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => num_internal_knots + margin.degree + 1,
            BSplineKnotSpec::Automatic {
                num_internal_knots: Some(num_internal_knots),
                ..
            } => num_internal_knots + margin.degree + 1,
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
            BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(margin.degree + 1),
            BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
            BSplineKnotSpec::Automatic {
                num_internal_knots: None,
                ..
            } => panic!("test cannot infer automatic knot count"),
        };
        dims.push(width);
    }
    assert_eq!(dims, vec![5, 5]);
}
5883
/// `te(x1, x2, bs=c('tp','tp'))` without any `k` must build a tensor B-spline
/// carrying exactly two marginal specs.
#[test]
fn tensor_all_tp_margins_without_per_margin_k_builds_anisotropic_tensor() {
    let rows: Vec<Vec<f64>> = (0..32)
        .map(|i| {
            let t = i as f64 / 31.0;
            vec![t.sin(), t, 1.0 - t]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x1", "x2"], rows);
    let parsed = parse_formula("y ~ te(x1, x2, bs=c('tp','tp'))").expect("parse tensor");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build tensor terms without per-margin k");

    let tensor = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!(
            "te(...,bs=c('tp','tp')) must route to an anisotropic tensor product, not a \
             silent isotropic thin-plate substitution; got {:?}",
            terms.smooth_terms[0].basis
        ),
    };

    let margin_count = tensor.marginalspecs.len();
    assert_eq!(
        margin_count, 2,
        "tp tensor must carry one penalized B-spline margin per axis"
    );
}
5926
/// With only 12 rows and five smooths, an explicit `s(x1, k=10)` must still
/// produce a `Generate` knotspec with 6 internal knots.
#[test]
fn explicit_basis_sizes_are_not_small_n_clamped() {
    let rows: Vec<Vec<f64>> = (0..12)
        .map(|i| {
            let x = i as f64 / 11.0;
            vec![x.sin(), x, x * x, x + 0.1, 1.0 - x, (2.0 * x).sin()]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x1", "x2", "x3", "x4", "x5"], rows);
    let parsed = parse_formula("y ~ s(x1, k=10) + s(x2) + s(x3) + s(x4) + s(x5)")
        .expect("parse multi-smooth formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build multi-smooth terms");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => spec,
        _ => panic!("expected first smooth to be B-spline"),
    };
    assert!(matches!(
        &spec.knotspec,
        BSplineKnotSpec::Generate {
            num_internal_knots: 6,
            ..
        }
    ));
}
5961
/// With only 12 rows and five smooths, an explicit `duchon(x1, centers=3)`
/// must still produce a farthest-point center strategy with 3 centers.
#[test]
fn explicit_duchon_centers_are_not_small_n_bumped() {
    let rows: Vec<Vec<f64>> = (0..12)
        .map(|i| {
            let x = i as f64 / 11.0;
            vec![x.sin(), x, x * x, x + 0.1, 1.0 - x, (2.0 * x).sin()]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x1", "x2", "x3", "x4", "x5"], rows);
    let parsed = parse_formula("y ~ duchon(x1, centers=3) + s(x2) + s(x3) + s(x4) + s(x5)")
        .expect("parse multi-smooth formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build multi-smooth terms");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => spec,
        _ => panic!("expected first smooth to be Duchon"),
    };
    assert!(matches!(
        spec.center_strategy,
        CenterStrategy::FarthestPoint { num_centers: 3 }
    ));
}
5999
/// Repeating every row of a 50 x 16 coordinate grid twelve times must leave
/// the value returned by `inferred_tensor_basis_product` unchanged.
#[test]
fn inferred_tensor_basis_cap_uses_coordinate_support_not_duplicate_rows() {
    // 50 theta values crossed with 16 h values, theta-major order.
    let unique_rows: Vec<Vec<f64>> = (0..50)
        .flat_map(|i| {
            let theta = i as f64 / 50.0;
            (0..16).map(move |j| {
                let h = -1.0 + 2.0 * (j as f64) / 15.0;
                let y = theta.cos() + h;
                vec![y, theta, h]
            })
        })
        .collect();

    // Twelve back-to-back copies of the full grid.
    let repeated_rows: Vec<Vec<f64>> = (0..12)
        .flat_map(|_| unique_rows.iter().cloned())
        .collect();

    let unique = continuous_dataset(&["y", "theta", "h"], unique_rows);
    let repeated = continuous_dataset(&["y", "theta", "h"], repeated_rows);

    let unique_basis = inferred_tensor_basis_product(&unique);
    let repeated_basis = inferred_tensor_basis_product(&repeated);

    assert_eq!(
        unique_basis, repeated_basis,
        "duplicating existing tensor coordinates must not inflate inferred basis width"
    );
}
6027
/// For `te(x1, x2, x3)` with no explicit `k`, the product of the inferred
/// per-margin basis sizes must not exceed 216 at either n = 60 or n = 2000.
#[test]
fn inferred_three_dim_tensor_basis_stays_bounded_for_reml_selection() {
    // Builds the 3-D tensor term on `n` synthetic rows and returns the product
    // of its per-margin basis sizes.
    let make = |n: usize| -> usize {
        let rows: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                let f = i as f64 / n as f64;
                vec![f.sin(), f, (2.0 * f).cos(), (3.0 * f) % 1.0]
            })
            .collect();
        let ds = continuous_dataset(&["y", "x1", "x2", "x3"], rows);
        let parsed = parse_formula("y ~ te(x1, x2, x3)").expect("parse 3-D tensor");
        let col_map = ds.column_map();
        let mut notes = Vec::new();
        let terms = build_termspec(
            &parsed.terms,
            &ds,
            &col_map,
            &mut notes,
            &ResourcePolicy::default_library(),
        )
        .expect("build 3-D tensor termspec");
        let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
            panic!("expected tensor smooth");
        };
        spec.marginalspecs
            .iter()
            .map(|m| match &m.knotspec {
                BSplineKnotSpec::Generate {
                    num_internal_knots, ..
                }
                | BSplineKnotSpec::Automatic {
                    num_internal_knots: Some(num_internal_knots),
                    ..
                } => num_internal_knots + m.degree + 1,
                BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
                // Report the offending knotspec so a failure is diagnosable.
                other => panic!("unexpected tensor margin knotspec: {other:?}"),
            })
            .product()
    };

    // Evaluate each size once so the failure message reports the value that
    // was actually tested, without rebuilding the term spec.
    let small_n_width = make(60);
    assert!(
        small_n_width <= 216,
        "3-D te at small n must stay near the mgcv te default, got {small_n_width}"
    );

    let large_n_width = make(2000);
    assert!(
        large_n_width <= 216,
        "3-D te at large n must not blow ∏k toward the data size, got {large_n_width}"
    );
}
6089
/// Checks three option sets against `parse_bspline_boundary_conditions`:
/// a non-zero left anchor and a non-zero right anchor are both rejected with
/// an error naming the side and value, while `start_bc=clamped` plus
/// `end_bc=zero` parses to a clamped left end and a right end anchored at 0.
#[test]
fn parse_bspline_boundary_conditions_and_side_selector() {
    // Builds an owned option map from borrowed key/value pairs.
    let options = |pairs: &[(&str, &str)]| -> BTreeMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    };

    // Case 1: anchored left side with a non-zero anchor value.
    let left_anchor_opts = options(&[
        ("boundary_conditions", "anchored"),
        ("side", "left"),
        ("anchor", "2.5"),
    ]);
    let err = parse_bspline_boundary_conditions(&left_anchor_opts)
        .expect_err("non-zero left anchor must be rejected")
        .to_string();
    assert!(
        err.contains("left") && err.contains("2.5"),
        "rejection should name the affected side and value: {err}"
    );

    // Case 2: per-side conditions with a non-zero right anchor.
    let right_anchor_opts = options(&[
        ("start_bc", "clamped"),
        ("end_bc", "zero"),
        ("right_anchor", "-1.0"),
    ]);
    let err = parse_bspline_boundary_conditions(&right_anchor_opts)
        .expect_err("non-zero right anchor must be rejected")
        .to_string();
    assert!(
        err.contains("right") && err.contains("-1"),
        "rejection should name the affected side and value: {err}"
    );

    // Case 3: the same per-side conditions without any anchor value.
    let plain_opts = options(&[("start_bc", "clamped"), ("end_bc", "zero")]);
    let parsed = parse_bspline_boundary_conditions(&plain_opts).expect("boundary conditions");
    assert!(matches!(
        parsed.left,
        BSplineEndpointBoundaryCondition::Clamped
    ));
    assert!(matches!(
        parsed.right,
        BSplineEndpointBoundaryCondition::Anchored { value } if value.abs() < 1e-12
    ));
}
6139
/// `y ~ x:g` on the factor fixture must build two linear interaction terms,
/// each carrying `x` as its only feature column and a single gate on `g`.
/// Every realized column must equal `x` on rows whose `g` matches the gate
/// and 0 elsewhere, and the two gates must be the bit patterns of 0.0 and 1.0.
#[test]
fn categorical_by_numeric_interaction_expands_treatment_coded_cells() {
    let ds = factor_dataset();
    let parsed = parse_formula("y ~ x:g").expect("parse `y ~ x:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("factor-aware `x:g` interaction must build, not error");

    assert_eq!(
        terms.linear_terms.len(),
        2,
        "interaction-only `x:g` keeps ALL factor levels (full dummy coding): one slope column per group"
    );

    let x_col = *col_map.get("x").expect("x column");
    let g_col = *col_map.get("g").expect("g column");
    let n = ds.values.nrows();

    // Gate bit patterns seen so far; each level may appear only once.
    let mut seen_bits = std::collections::HashSet::new();
    for term in terms.linear_terms.iter() {
        assert!(
            term.is_interaction(),
            "the categorical-by-numeric cell is a Wilkinson-Rogers interaction"
        );
        assert_eq!(term.feature_cols, vec![x_col]);
        assert_eq!(term.categorical_levels.len(), 1);

        let (gate_col, gate_bits) = term.categorical_levels[0];
        assert_eq!(gate_col, g_col);
        let first_sighting = seen_bits.insert(gate_bits);
        assert!(first_sighting, "each level appears once");

        let column = term
            .realized_design_column(ds.values.view())
            .expect("realize cell column");
        assert_eq!(column.len(), n);

        for row in 0..n {
            let x = ds.values[[row, x_col]];
            let g = ds.values[[row, g_col]];
            // Rows outside the gated level contribute nothing.
            let expected = if g.to_bits() == gate_bits { x } else { 0.0 };
            let got = column[row];
            assert!(
                (got - expected).abs() < 1e-12,
                "row {row}: g={g}, x={x}, expected {expected}, got {}",
                got
            );
        }
    }

    assert!(seen_bits.contains(&0.0_f64.to_bits()));
    assert!(seen_bits.contains(&1.0_f64.to_bits()));
}
6211
/// `y ~ x + x:g` on the factor fixture must build exactly one interaction
/// term, with `x` as its only feature column and a single gate on `g` whose
/// bit pattern is that of 1.0.
#[test]
fn categorical_by_numeric_interaction_keeps_treatment_coding_with_parent() {
    let ds = factor_dataset();
    let parsed = parse_formula("y ~ x + x:g").expect("parse `y ~ x + x:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("`x + x:g` must build");

    let x_col = *col_map.get("x").expect("x column");
    let g_col = *col_map.get("g").expect("g column");

    let mut interaction_cells = Vec::new();
    for candidate in terms.linear_terms.iter() {
        if candidate.is_interaction() {
            interaction_cells.push(candidate);
        }
    }
    assert_eq!(
        interaction_cells.len(),
        1,
        "with `x` present, `x:g` is treatment-coded → one cell (reference dropped)"
    );

    let cell = interaction_cells[0];
    assert_eq!(cell.feature_cols, vec![x_col]);
    assert_eq!(cell.categorical_levels.len(), 1);
    let (gate_col, gate_bits) = cell.categorical_levels[0];
    assert_eq!(gate_col, g_col);
    assert_eq!(gate_bits, 1.0_f64.to_bits());
}
6255
/// `y ~ f:g` with a 3-level `f` and a 2-level `g` must build five linear
/// terms: every cross cell except the one gated on `f = 0.0, g = 0.0`. Each
/// term has no feature columns and two gates, realizes to a 0/1 indicator of
/// its cell, and is observed at least once in the data.
#[test]
fn categorical_by_categorical_interaction_expands_full_cross_cells() {
    let n = 30usize;

    // Row-major storage of (y, f, g) with f = i % 3 and g = i % 2.
    let mut flat = Vec::with_capacity(n * 3);
    for i in 0..n {
        let y = (i as f64).sin();
        let f = (i % 3) as f64;
        let g = (i % 2) as f64;
        flat.extend([y, f, g]);
    }
    let values = Array2::from_shape_vec((n, 3), flat).expect("rectangular cross-factor data");

    let column = |name: &'static str, kind: ColumnKindTag, levels: Vec<String>| SchemaColumn {
        name: name.into(),
        kind,
        levels,
    };
    let ds = Dataset {
        headers: vec!["y".into(), "f".into(), "g".into()],
        values,
        schema: DataSchema {
            columns: vec![
                column("y", ColumnKindTag::Continuous, vec![]),
                column(
                    "f",
                    ColumnKindTag::Categorical,
                    vec!["f0".into(), "f1".into(), "f2".into()],
                ),
                column(
                    "g",
                    ColumnKindTag::Categorical,
                    vec!["g0".into(), "g1".into()],
                ),
            ],
        },
        column_kinds: vec![
            ColumnKindTag::Continuous,
            ColumnKindTag::Categorical,
            ColumnKindTag::Categorical,
        ],
    };

    let parsed = parse_formula("y ~ f:g").expect("parse `y ~ f:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("factor-by-factor `f:g` interaction must build, not error");

    assert_eq!(
        terms.linear_terms.len(),
        5,
        "saturated 3*2 = 6 cross cells minus one reference cell (f0:g0) = 5"
    );

    let f_col = *col_map.get("f").expect("f column");
    let g_col = *col_map.get("g").expect("g column");
    // Bit patterns of the reference levels.
    let f0 = 0.0_f64.to_bits();
    let g0 = 0.0_f64.to_bits();

    let mut emitted = std::collections::HashSet::new();
    for term in terms.linear_terms.iter() {
        assert!(term.feature_cols.is_empty());
        assert_eq!(term.categorical_levels.len(), 2);

        // Index the gates by column so f and g can be looked up in any order.
        let mut gates = std::collections::HashMap::new();
        for &(col, bits) in &term.categorical_levels {
            gates.insert(col, bits);
        }
        let f_bits = *gates.get(&f_col).expect("f gate present");
        let g_bits = *gates.get(&g_col).expect("g gate present");

        let is_reference_cell = f_bits == f0 && g_bits == g0;
        assert!(
            !is_reference_cell,
            "the reference cell f0:g0 must be absorbed by the intercept, not emitted"
        );
        emitted.insert((f_bits, g_bits));

        let column = term
            .realized_design_column(ds.values.view())
            .expect("realize cross cell");
        for row in 0..n {
            let in_cell = ds.values[[row, f_col]].to_bits() == f_bits
                && ds.values[[row, g_col]].to_bits() == g_bits;
            let expected = if in_cell { 1.0 } else { 0.0 };
            assert!(
                (column[row] - expected).abs() < 1e-12,
                "row {row}: expected {expected}, got {}",
                column[row]
            );
        }
        assert!(
            column.iter().any(|&v| v == 1.0),
            "each cross cell must be observed in the data"
        );
    }

    // Every non-reference (f, g) combination must have been emitted.
    let f_levels = [0.0_f64.to_bits(), 1.0_f64.to_bits(), 2.0_f64.to_bits()];
    let g_levels = [0.0_f64.to_bits(), 1.0_f64.to_bits()];
    for fb in f_levels.iter().copied() {
        for gb in g_levels.iter().copied() {
            if fb == f0 && gb == g0 {
                continue;
            }
            assert!(
                emitted.contains(&(fb, gb)),
                "saturated cross cell must be present"
            );
        }
    }
}
6391}