1use std::collections::{BTreeMap, BTreeSet, HashMap};
8use std::path::PathBuf;
9
10use ndarray::{Array2, ArrayView1};
11
12use crate::basis::{
13 BSplineBasisSpec, BSplineBoundaryConditions, BSplineEndpointBoundaryCondition,
14 BSplineIdentifiability, BSplineKnotSpec, CenterCountRequest, CenterStrategy,
15 ConstantCurvatureBasisSpec, ConstantCurvatureIdentifiability, DuchonBasisSpec,
16 DuchonNullspaceOrder, DuchonOperatorPenaltySpec, MaternBasisSpec, MaternIdentifiability,
17 MaternNu, MeasureJetBasisSpec, MeasureJetIdentifiability, OneDimensionalBoundary,
18 SpatialIdentifiability, SphereMethod, SphereWahbaKernel, SphericalSplineBasisSpec,
19 SphericalSplineIdentifiability, ThinPlateBasisSpec, auto_spatial_center_strategy,
20 default_num_centers, default_spatial_center_strategy, default_spherical_harmonic_degree,
21 plan_spatial_basis, thin_plate_penalty_order,
22};
23use crate::inference::formula_dsl::{
24 ParsedTerm, SmoothKind, option_bool, option_f64, option_f64_strict, option_usize,
25 option_usize_any, option_usize_any_strict, option_usize_strict, strip_quotes,
26};
27use crate::smooth::{
28 ByVarKind, FactorSmoothFlavour, FactorSmoothSpec, LinearCoefficientGeometry, LinearTermSpec,
29 RandomEffectTermSpec, ShapeConstraint, SmoothBasisSpec, SmoothTermSpec,
30 TensorBSplineIdentifiability, TensorBSplinePenaltyDecomposition, TensorBSplineSpec,
31 TermCollectionSpec,
32};
33use gam_problem::types::ColIdx;
34use gam_data::{ColumnKindTag, DataError, EncodedDataset as Dataset};
35use gam_runtime::resource::ResourcePolicy;
36
/// Default B-spline polynomial degree (3).
/// NOTE(review): the consumers of this constant are outside this chunk;
/// presumably it is the fallback when a smooth gives no explicit degree — confirm.
const DEFAULT_BSPLINE_DEGREE: usize = 3;

/// Default penalty order (2).
/// NOTE(review): usage is not visible here; presumably the difference/derivative
/// order of the smoothing penalty — confirm against the basis builders.
const DEFAULT_PENALTY_ORDER: usize = 2;

/// Default basis dimension for cyclic smooths (12).
/// NOTE(review): usage is not visible in this chunk — confirm where it applies.
const CYCLIC_DEFAULT_BASIS_DIM: usize = 12;

/// Default basis dimension for factor-smooth terms (10).
/// NOTE(review): usage is not visible in this chunk — confirm where it applies.
const FACTOR_SMOOTH_DEFAULT_BASIS_DIM: usize = 10;

/// Default chunk size for PCA processing (4096).
/// NOTE(review): unit (rows per chunk?) is not established by this chunk — confirm.
const DEFAULT_PCA_CHUNK_SIZE: usize = 4096;
65
/// Error type returned while turning parsed formula terms into term specs.
///
/// Every variant except [`TermBuilderError::ColumnNotFound`] carries a single
/// human-readable `reason`, which is exactly what `Display` prints for it.
#[derive(Clone, Debug)]
pub enum TermBuilderError {
    /// A column-related failure described only by a message. Non-lookup
    /// `DataError` variants are also mapped here by `From<DataError>`.
    MissingColumn { reason: String },
    /// A column name could not be resolved. The fields mirror
    /// `DataError::ColumnNotFound`, and `Display` delegates to that type so the
    /// rendered message matches the data layer's.
    ColumnNotFound {
        /// The column name that was requested.
        name: String,
        /// Optional role of the column, forwarded verbatim to `DataError`.
        role: Option<String>,
        /// Column names forwarded from `DataError::ColumnNotFound`.
        available: Vec<String>,
        /// Suggested names forwarded from `DataError::ColumnNotFound`.
        similar: Vec<String>,
        /// Hint flag forwarded from `DataError::ColumnNotFound`.
        tsv_hint: bool,
    },
    /// The requested combination of terms/options cannot be used together.
    /// Bare `String` errors are converted into this variant by `From<String>`.
    IncompatibleConfig { reason: String },
    /// An option value could not be interpreted.
    InvalidOption { reason: String },
    /// The request names something this builder does not support.
    UnsupportedFeature { reason: String },
    /// NOTE(review): not constructed in the visible chunk; presumably raised
    /// when the data cannot support the requested term — confirm.
    DegenerateData { reason: String },
    /// The formula is structurally invalid for this code path.
    MalformedFormula { reason: String },
}
114
115impl std::fmt::Display for TermBuilderError {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 match self {
118 TermBuilderError::MissingColumn { reason }
119 | TermBuilderError::IncompatibleConfig { reason }
120 | TermBuilderError::InvalidOption { reason }
121 | TermBuilderError::UnsupportedFeature { reason }
122 | TermBuilderError::DegenerateData { reason }
123 | TermBuilderError::MalformedFormula { reason } => f.write_str(reason),
124 TermBuilderError::ColumnNotFound {
130 name,
131 role,
132 available,
133 similar,
134 tsv_hint,
135 } => {
136 let canonical = DataError::ColumnNotFound {
137 name: name.clone(),
138 role: role.clone(),
139 available: available.clone(),
140 similar: similar.clone(),
141 tsv_hint: *tsv_hint,
142 };
143 std::fmt::Display::fmt(&canonical, f)
144 }
145 }
146 }
147}
148
149impl From<TermBuilderError> for String {
150 fn from(err: TermBuilderError) -> String {
151 err.to_string()
152 }
153}
154
155impl From<String> for TermBuilderError {
162 fn from(reason: String) -> Self {
163 Self::IncompatibleConfig { reason }
164 }
165}
166
167impl From<DataError> for TermBuilderError {
174 fn from(err: DataError) -> Self {
175 match err {
176 DataError::ColumnNotFound {
177 name,
178 role,
179 available,
180 similar,
181 tsv_hint,
182 } => Self::ColumnNotFound {
183 name,
184 role,
185 available,
186 similar,
187 tsv_hint,
188 },
189 DataError::SchemaMismatch { reason }
190 | DataError::ParseError { reason }
191 | DataError::EncodingFailure { reason }
192 | DataError::EmptyInput { reason }
193 | DataError::InvalidValue { reason } => Self::MissingColumn { reason },
194 }
195 }
196}
197
198impl TermBuilderError {
200 #[inline]
201 fn missing_column(reason: impl Into<String>) -> Self {
202 TermBuilderError::MissingColumn {
203 reason: reason.into(),
204 }
205 }
206 #[inline]
207 fn incompatible_config(reason: impl Into<String>) -> Self {
208 TermBuilderError::IncompatibleConfig {
209 reason: reason.into(),
210 }
211 }
212 #[inline]
213 fn invalid_option(reason: impl Into<String>) -> Self {
214 TermBuilderError::InvalidOption {
215 reason: reason.into(),
216 }
217 }
218 #[inline]
219 fn unsupported_feature(reason: impl Into<String>) -> Self {
220 TermBuilderError::UnsupportedFeature {
221 reason: reason.into(),
222 }
223 }
224 #[inline]
225 fn degenerate_data(reason: impl Into<String>) -> Self {
226 TermBuilderError::DegenerateData {
227 reason: reason.into(),
228 }
229 }
230 #[inline]
231 fn malformed_formula(reason: impl Into<String>) -> Self {
232 TermBuilderError::MalformedFormula {
233 reason: reason.into(),
234 }
235 }
236}
237
238pub fn resolve_col(col_map: &HashMap<String, usize>, name: &str) -> Result<usize, DataError> {
249 col_map
250 .get(name)
251 .copied()
252 .ok_or_else(|| DataError::column_not_found(col_map, name, None))
253}
254
255pub fn resolve_role_col(
260 col_map: &HashMap<String, usize>,
261 name: &str,
262 role: &str,
263) -> Result<usize, DataError> {
264 col_map
265 .get(name)
266 .copied()
267 .ok_or_else(|| DataError::column_not_found(col_map, name, Some(role)))
268}
269
270fn encoded_levels_for_column(ds: &Dataset, col: ColIdx) -> Vec<(u64, String)> {
271 let mut seen = BTreeSet::<u64>::new();
272 for value in ds.values.column(col.get()) {
273 if value.is_finite() {
274 seen.insert(value.to_bits());
275 }
276 }
277 let schema_levels = ds
278 .schema
279 .columns
280 .get(col.get())
281 .map(|column| column.levels.as_slice())
282 .unwrap_or(&[]);
283 seen.into_iter()
284 .enumerate()
285 .map(|(idx, bits)| {
286 let fallback = format!("level{}", idx + 1);
287 let label = schema_levels.get(idx).cloned().unwrap_or(fallback);
288 (bits, label)
289 })
290 .collect()
291}
292
/// Returns a copy of `col_map` in which `alias` also resolves to the index of
/// `target_column`.
///
/// The alias is only added when `target_column` exists and `alias` is not
/// already a key; an existing `alias` entry is never overwritten.
pub fn column_map_with_alias(
    col_map: &HashMap<String, usize>,
    alias: &str,
    target_column: &str,
) -> HashMap<String, usize> {
    let mut extended = col_map.clone();
    match col_map.get(target_column) {
        Some(&target_idx) if !extended.contains_key(alias) => {
            extended.insert(alias.to_owned(), target_idx);
        }
        _ => {}
    }
    extended
}
304
305pub fn build_termspec(
310 terms: &[ParsedTerm],
311 ds: &Dataset,
312 col_map: &HashMap<String, usize>,
313 inference_notes: &mut Vec<String>,
314 policy: &ResourcePolicy,
315) -> Result<TermCollectionSpec, TermBuilderError> {
316 let mut linear_terms = Vec::<LinearTermSpec>::new();
317 let mut random_terms = Vec::<RandomEffectTermSpec>::new();
318 let mut smooth_terms = Vec::<SmoothTermSpec>::new();
319 let smooth_coordinate_count = terms
320 .iter()
321 .map(|term| match term {
322 ParsedTerm::Smooth { vars, .. } => vars.len(),
323 _ => 0,
324 })
325 .sum::<usize>();
326
327 for t in terms {
328 match t {
329 ParsedTerm::Linear {
330 name,
331 explicit,
332 coefficient_min,
333 coefficient_max,
334 } => {
335 let col = resolve_col(col_map, name)?;
336 let auto_kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
337 TermBuilderError::missing_column(format!(
338 "internal column-kind lookup failed for '{name}'"
339 ))
340 .to_string()
341 })?;
342 if *explicit {
343 linear_terms.push(LinearTermSpec {
344 name: name.clone(),
345 feature_col: col,
346 feature_cols: vec![col],
347 categorical_levels: vec![],
348 double_penalty: false,
351 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
352 coefficient_min: *coefficient_min,
353 coefficient_max: *coefficient_max,
354 });
355 } else {
356 match auto_kind {
357 ColumnKindTag::Continuous | ColumnKindTag::Binary => {
358 linear_terms.push(LinearTermSpec {
359 name: name.clone(),
360 feature_col: col,
361 feature_cols: vec![col],
362 categorical_levels: vec![],
363 double_penalty: false,
365 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
366 coefficient_min: *coefficient_min,
367 coefficient_max: *coefficient_max,
368 });
369 }
370 ColumnKindTag::Categorical => {
371 if coefficient_min.is_some() || coefficient_max.is_some() {
372 return Err(TermBuilderError::incompatible_config(format!(
373 "coefficient constraints are not supported for categorical auto-random-effect term '{name}'; use group({name}) or an unconstrained numeric term"
374 )));
375 }
376 random_terms.push(RandomEffectTermSpec {
377 name: name.clone(),
378 feature_col: col,
379 drop_first_level: false,
380 penalized: true,
381 frozen_levels: None,
382 });
383 }
384 }
385 }
386 }
387 ParsedTerm::BoundedLinear {
388 name,
389 min,
390 max,
391 prior,
392 } => {
393 let col = resolve_col(col_map, name)?;
394 let auto_kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
395 TermBuilderError::missing_column(format!(
396 "internal column-kind lookup failed for '{name}'"
397 ))
398 .to_string()
399 })?;
400 if !matches!(auto_kind, ColumnKindTag::Continuous | ColumnKindTag::Binary) {
401 return Err(TermBuilderError::incompatible_config(format!(
402 "bounded() currently supports only numeric columns, got categorical '{name}'"
403 )));
404 }
405 linear_terms.push(LinearTermSpec {
406 name: name.clone(),
407 feature_col: col,
408 feature_cols: vec![col],
409 categorical_levels: vec![],
410 double_penalty: false,
411 coefficient_geometry: LinearCoefficientGeometry::Bounded {
412 min: *min,
413 max: *max,
414 prior: prior.clone(),
415 },
416 coefficient_min: None,
417 coefficient_max: None,
418 });
419 }
420 ParsedTerm::RandomEffect { name } => {
421 let col = resolve_col(col_map, name)?;
422 random_terms.push(RandomEffectTermSpec {
423 name: name.clone(),
424 feature_col: col,
425 drop_first_level: false,
426 penalized: true,
427 frozen_levels: None,
428 });
429 }
430 ParsedTerm::Smooth {
431 label,
432 vars,
433 kind,
434 options,
435 } => {
436 let smooth_vars = vars.clone();
437 let by_name = options.get("by").cloned();
438 let cols = smooth_vars
448 .iter()
449 .map(|v| resolve_col(col_map, v))
450 .collect::<Result<Vec<_>, _>>()?;
451 let mut inner_options = options.clone();
452 inner_options.remove("by");
453 inner_options.remove("ordered");
457 let shape = match inner_options.remove("shape") {
463 None => ShapeConstraint::None,
464 Some(raw) => crate::smooth::parse_shape_constraint(&raw)
465 .map_err(TermBuilderError::invalid_option)?,
466 };
467 let inner_basis = build_smooth_basis(
468 *kind,
469 &smooth_vars,
470 &cols,
471 &inner_options,
472 ds,
473 inference_notes,
474 policy,
475 smooth_coordinate_count,
476 )?;
477 if let Some(by_name) = by_name {
478 let by_col = resolve_col(col_map, &by_name)?;
479 match ds.column_kinds.get(by_col).copied().ok_or_else(|| {
480 format!("internal column-kind lookup failed for by variable '{by_name}'")
481 })? {
482 ColumnKindTag::Categorical => {
483 let levels = encoded_levels_for_column(ds, ColIdx::new(by_col));
484 let penalized_group_owner_present =
497 terms.iter().any(|other| match other {
498 ParsedTerm::RandomEffect { name } => name == &by_name,
499 ParsedTerm::Linear {
500 name,
501 explicit: false,
502 ..
503 } if name == &by_name => col_map
504 .get(name)
505 .and_then(|c| ds.column_kinds.get(*c).copied())
506 .map(|kind| matches!(kind, ColumnKindTag::Categorical))
507 .unwrap_or(false),
508 _ => false,
509 });
510 if !random_terms.iter().any(|rt| rt.name == by_name)
521 && !penalized_group_owner_present
522 {
523 random_terms.push(RandomEffectTermSpec {
524 name: by_name.clone(),
525 feature_col: by_col,
526 drop_first_level: true,
527 penalized: false,
528 frozen_levels: None,
529 });
530 }
531 let frozen_levels: Vec<u64> =
536 levels.iter().map(|(bits, _)| *bits).collect();
537 smooth_terms.push(SmoothTermSpec {
538 name: label.clone(),
539 basis: SmoothBasisSpec::BySmooth {
540 smooth: Box::new(inner_basis),
541 by_kind: ByVarKind::Factor {
542 feature_col: by_col,
543 ordered: option_bool(options, "ordered").unwrap_or(false),
544 frozen_levels: Some(frozen_levels),
545 },
546 },
547 shape,
548 joint_null_rotation: None,
549 });
550 }
551 ColumnKindTag::Binary | ColumnKindTag::Continuous => {
552 smooth_terms.push(SmoothTermSpec {
553 name: label.clone(),
554 basis: SmoothBasisSpec::BySmooth {
555 smooth: Box::new(inner_basis),
556 by_kind: ByVarKind::Numeric {
557 feature_col: by_col,
558 },
559 },
560 shape,
561 joint_null_rotation: None,
562 });
563 }
564 }
565 } else {
566 smooth_terms.push(SmoothTermSpec {
567 name: label.clone(),
568 basis: inner_basis,
569 shape,
570 joint_null_rotation: None,
571 });
572 }
573 }
574 ParsedTerm::LinkWiggle { .. }
575 | ParsedTerm::TimeWiggle { .. }
576 | ParsedTerm::LinkConfig { .. }
577 | ParsedTerm::SurvivalConfig { .. } => {
578 }
580 ParsedTerm::LogSlopeSurface { .. } => {
581 return Err(TermBuilderError::malformed_formula(
582 "logslope(...) declarations must be resolved by the marginal-slope formula path before building a term spec",
583 ));
584 }
585 ParsedTerm::Interaction { vars } => {
586 let main_effect_present = |target: &str| -> bool {
619 terms.iter().any(|other| match other {
620 ParsedTerm::Linear { name, .. }
621 | ParsedTerm::BoundedLinear { name, .. }
622 | ParsedTerm::RandomEffect { name } => name == target,
623 _ => false,
624 })
625 };
626 let parent_present = |drop_var: &str| -> bool {
632 vars.iter()
633 .filter(|v| v.as_str() != drop_var)
634 .all(|v| main_effect_present(v))
635 };
636
637 let mut numeric_cols = Vec::<usize>::new();
638 let mut categorical_factors =
641 Vec::<(String, usize, Vec<(u64, String)>, bool)>::new();
642 for var in vars {
643 let col = resolve_col(col_map, var)?;
644 let kind = ds.column_kinds.get(col).copied().ok_or_else(|| {
645 TermBuilderError::missing_column(format!(
646 "internal column-kind lookup failed for '{var}'"
647 ))
648 .to_string()
649 })?;
650 match kind {
651 ColumnKindTag::Continuous | ColumnKindTag::Binary => numeric_cols.push(col),
652 ColumnKindTag::Categorical => {
653 let mut levels = encoded_levels_for_column(ds, ColIdx::new(col));
654 let treatment_coded = parent_present(var);
658 if treatment_coded && levels.len() > 1 {
659 levels.remove(0);
660 }
661 if levels.is_empty() {
662 return Err(TermBuilderError::incompatible_config(format!(
663 "interaction `{}` references categorical column `{var}` with no usable levels",
664 vars.join(":")
665 )));
666 }
667 categorical_factors.push((var.clone(), col, levels, treatment_coded));
668 }
669 }
670 }
671
672 let label = vars.join(":");
673
674 if categorical_factors.is_empty() {
675 linear_terms.push(LinearTermSpec {
678 name: label,
679 feature_col: numeric_cols[0],
680 feature_cols: numeric_cols,
681 categorical_levels: vec![],
682 double_penalty: false,
685 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
686 coefficient_min: None,
687 coefficient_max: None,
688 });
689 inference_notes.push(format!(
690 "wired linear interaction `{}` as product of numeric columns",
691 vars.join(":")
692 ));
693 } else {
694 let mut cells: Vec<Vec<(usize, u64, String)>> = vec![Vec::new()];
699 for (_var, col, levels, _treatment_coded) in &categorical_factors {
700 let mut next = Vec::with_capacity(cells.len() * levels.len());
701 for cell in &cells {
702 for (bits, level_label) in levels {
703 let mut extended = cell.clone();
704 extended.push((*col, *bits, level_label.clone()));
705 next.push(extended);
706 }
707 }
708 cells = next;
709 }
710
711 let any_dummy_coded = categorical_factors
723 .iter()
724 .any(|(_, _, _, treatment_coded)| !*treatment_coded);
725 if numeric_cols.is_empty() && any_dummy_coded {
726 let reference_cell: Vec<(usize, u64)> = categorical_factors
729 .iter()
730 .map(|(_, col, _, _)| {
731 let levels = encoded_levels_for_column(ds, ColIdx::new(*col));
732 (*col, levels[0].0)
733 })
734 .collect();
735 cells.retain(|cell| {
736 !reference_cell.iter().all(|(rcol, rbits)| {
737 cell.iter()
738 .any(|(col, bits, _)| col == rcol && bits == rbits)
739 })
740 });
741 }
742
743 let n_cells = cells.len();
744 for cell in cells {
745 let cell_suffix = cell
746 .iter()
747 .map(|(_, _, level_label)| level_label.as_str())
748 .collect::<Vec<_>>()
749 .join(":");
750 let categorical_levels =
751 cell.iter().map(|(col, bits, _)| (*col, *bits)).collect();
752 let feature_col = numeric_cols
758 .first()
759 .copied()
760 .unwrap_or(categorical_factors[0].1);
761 linear_terms.push(LinearTermSpec {
762 name: format!("{label}:{cell_suffix}"),
763 feature_col,
764 feature_cols: numeric_cols.clone(),
765 categorical_levels,
766 double_penalty: false,
767 coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
768 coefficient_min: None,
769 coefficient_max: None,
770 });
771 }
772 let all_treatment_coded = !any_dummy_coded;
773 let coding = if all_treatment_coded {
774 "treatment-coded"
775 } else {
776 "marginality-aware (full dummy / saturated)"
777 };
778 inference_notes.push(format!(
779 "wired factor-aware linear interaction `{}` as {} {} cell column(s)",
780 vars.join(":"),
781 n_cells,
782 coding
783 ));
784 }
785 }
786 }
787 }
788
789 Ok(TermCollectionSpec {
790 linear_terms,
791 random_effect_terms: random_terms,
792 smooth_terms,
793 })
794}
795
/// Splits a list-valued option into its trimmed, non-empty entries.
///
/// One layer of `[...]`, `c(...)`, `C(...)` or `(...)` wrapping is removed
/// when both the opening and closing delimiter are present; otherwise the
/// trimmed input is split as-is. Entries are separated by commas.
fn split_list_option(raw: &str) -> Vec<String> {
    let text = raw.trim();

    let square_body = text
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'));
    let inner = match square_body {
        Some(body) => body,
        None => {
            let opened = text
                .strip_prefix("c(")
                .or_else(|| text.strip_prefix("C("))
                .or_else(|| text.strip_prefix('('));
            match opened.and_then(|rest| rest.strip_suffix(')')) {
                Some(body) => body,
                None => text,
            }
        }
    };

    let mut items = Vec::new();
    for piece in inner.split(',') {
        let piece = piece.trim();
        if !piece.is_empty() {
            items.push(piece.to_string());
        }
    }
    items
}
820
/// Parses a small numeric expression: a product of factors separated by `*`.
///
/// Each factor is either a float literal or a named constant (`pi`/`π`,
/// `tau`/`τ`) with an optional leading coefficient, e.g. `2pi`, `0.5*tau`,
/// `-pi`. ASCII constant names are matched case-insensitively in both the
/// bare and the coefficient form. A coefficient consisting only of a sign
/// (`+` or `-`) means `+1` / `-1`. Spaces are ignored.
///
/// # Errors
/// Returns a message when the input is `none` (any ASCII case), contains an
/// empty factor (e.g. `2**3` or an empty string), or a factor/coefficient is
/// not a valid float.
fn parse_numeric_expr(raw: &str) -> Result<f64, String> {
    /// Splits `factor` into `(coefficient text, constant value)` when it ends
    /// with a recognised constant name; returns `None` otherwise.
    fn split_named_constant(factor: &str) -> Option<(&str, f64)> {
        let ascii_names = [
            ("pi", std::f64::consts::PI),
            ("tau", std::f64::consts::TAU),
        ];
        for &(name, value) in ascii_names.iter() {
            let Some(split_at) = factor.len().checked_sub(name.len()) else {
                continue;
            };
            // `str::get` yields `None` when `split_at` is not a char boundary,
            // so multi-byte input can never cause a slicing panic here.
            if let (Some(prefix), Some(suffix)) = (factor.get(..split_at), factor.get(split_at..))
            {
                if suffix.eq_ignore_ascii_case(name) {
                    return Some((prefix, value));
                }
            }
        }
        let symbols = [
            ("π", std::f64::consts::PI),
            ("τ", std::f64::consts::TAU),
        ];
        for &(symbol, value) in symbols.iter() {
            if let Some(prefix) = factor.strip_suffix(symbol) {
                return Some((prefix, value));
            }
        }
        None
    }

    let normalized = raw.replace(' ', "");
    if normalized.eq_ignore_ascii_case("none") {
        return Err("None is not numeric".to_string());
    }

    let mut acc = 1.0f64;
    for factor in normalized.split('*') {
        if factor.is_empty() {
            return Err(format!("invalid numeric expression '{raw}'"));
        }
        let value = match split_named_constant(factor) {
            Some((prefix, constant)) => {
                // A bare sign is a valid coefficient: `-pi` means `-1 * pi`.
                let coefficient = match prefix {
                    "" | "+" => 1.0,
                    "-" => -1.0,
                    digits => digits
                        .parse::<f64>()
                        .map_err(|err| format!("invalid numeric expression '{raw}': {err}"))?,
                };
                coefficient * constant
            }
            None => factor
                .parse::<f64>()
                .map_err(|err| format!("invalid numeric expression '{raw}': {err}"))?,
        };
        acc *= value;
    }
    Ok(acc)
}
868
869fn option_numeric_expr(
879 options: &BTreeMap<String, String>,
880 key: &str,
881) -> Result<Option<f64>, String> {
882 match options.get(key) {
883 None => Ok(None),
884 Some(raw) => parse_numeric_expr(raw)
885 .map(Some)
886 .map_err(|err| format!("option `{key}={raw}` is not a valid numeric value: {err}")),
887 }
888}
889
890fn parse_periods_option(
891 options: &BTreeMap<String, String>,
892 dim: usize,
893) -> Result<Option<Vec<Option<f64>>>, String> {
894 let Some(raw) = options.get("period") else {
895 return Ok(None);
896 };
897 let values = split_list_option(raw);
898 let mut periods = vec![None; dim];
899 if values.len() == 1 && dim == 1 {
900 periods[0] = Some(parse_numeric_expr(&values[0])?);
901 } else {
902 if values.len() != dim {
903 return Err(format!(
904 "period list length {} must match smooth dimension {}",
905 values.len(),
906 dim
907 ));
908 }
909 for (i, v) in values.iter().enumerate() {
910 if v.eq_ignore_ascii_case("none") {
911 continue;
912 }
913 periods[i] = Some(parse_numeric_expr(v)?);
914 }
915 }
916 Ok(Some(periods))
917}
918
919fn parse_periodic_axes_option(
920 options: &BTreeMap<String, String>,
921 dim: usize,
922) -> Result<Option<Vec<Option<f64>>>, String> {
923 let Some(raw_axes) = options.get("periodic") else {
924 return Ok(None);
925 };
926 let mut periods = parse_periods_option(options, dim)?.unwrap_or_else(|| vec![None; dim]);
927 let axes = split_list_option(raw_axes);
928 if axes.is_empty() {
929 return Ok(Some(periods));
930 }
931 for a in axes {
932 let axis = a
933 .parse::<usize>()
934 .map_err(|err| format!("invalid periodic axis '{a}': {err}"))?;
935 if axis >= dim {
936 return Err(format!(
937 "periodic axis {axis} out of range for {dim}D smooth"
938 ));
939 }
940 if periods[axis].is_none() {
941 return Err(format!(
942 "periodic axis {axis} requires period[{axis}] to be finite"
943 ));
944 }
945 }
946 let listed: std::collections::BTreeSet<usize> = split_list_option(raw_axes)
948 .into_iter()
949 .filter_map(|a| a.parse::<usize>().ok())
950 .collect();
951 for i in 0..dim {
952 if !listed.contains(&i) {
953 periods[i] = None;
954 }
955 }
956 Ok(Some(periods))
957}
958
/// Splits a list-valued option into normalised tokens.
///
/// One layer of `[...]`, `c(...)`, `C(...)` or `(...)` wrapping is removed
/// when both delimiters are present. Each comma-separated entry is trimmed,
/// stripped of surrounding double and then single quotes, lower-cased
/// (ASCII), and dropped if it ends up empty.
fn parse_option_list(raw: &str) -> Vec<String> {
    let text = raw.trim();

    let inner = if let Some(body) = text
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
    {
        body
    } else {
        let opened = text
            .strip_prefix("c(")
            .or_else(|| text.strip_prefix("C("))
            .or_else(|| text.strip_prefix('('));
        match opened.and_then(|rest| rest.strip_suffix(')')) {
            Some(body) => body,
            None => text,
        }
    };

    let mut tokens = Vec::new();
    for piece in inner.split(',') {
        let cleaned = piece
            .trim()
            .trim_matches('"')
            .trim_matches('\'')
            .to_ascii_lowercase();
        if !cleaned.is_empty() {
            tokens.push(cleaned);
        }
    }
    tokens
}
992
993fn parse_periodic_axes(
994 options: &BTreeMap<String, String>,
995 dim: usize,
996) -> Result<Vec<bool>, String> {
997 let mut axes = vec![false; dim];
998 if let Some(raw) = options.get("periodic").or_else(|| options.get("cyclic")) {
999 let lowered = raw.trim().to_ascii_lowercase();
1000 match lowered.as_str() {
1001 "true" | "yes" | "y" => {
1002 axes.fill(true);
1003 return Ok(axes);
1004 }
1005 "false" | "no" | "n" => return Ok(axes),
1006 _ => {}
1007 }
1008 for axis_raw in parse_option_list(raw) {
1009 let axis = axis_raw
1010 .parse::<usize>()
1011 .map_err(|err| format!("invalid periodic axis '{axis_raw}': {err}"))?;
1012 if axis >= dim {
1013 return Err(format!(
1014 "periodic axis {axis} out of range for {dim}D smooth"
1015 ));
1016 }
1017 axes[axis] = true;
1018 }
1019 }
1020 if let Some(raw) = options.get("boundary").or_else(|| options.get("bc")) {
1021 let boundary = parse_option_list(raw);
1022 if boundary.len() == dim {
1023 for (axis, value) in boundary.iter().enumerate() {
1024 if matches!(value.as_str(), "periodic" | "cyclic" | "cc") {
1025 axes[axis] = true;
1026 }
1027 }
1028 } else if dim == 1
1029 && matches!(
1030 boundary.first().map(String::as_str),
1031 Some("periodic" | "cyclic" | "cc")
1032 )
1033 {
1034 axes[0] = true;
1035 }
1036 }
1037 Ok(axes)
1038}
1039
1040fn parse_optional_numeric_list(
1041 options: &BTreeMap<String, String>,
1042 keys: &[&str],
1043 dim: usize,
1044) -> Result<Vec<Option<f64>>, String> {
1045 let Some(raw) = keys.iter().find_map(|key| options.get(*key)) else {
1046 return Ok(vec![None; dim]);
1047 };
1048 let values = split_list_option(raw);
1049 let mut out = vec![None; dim];
1050 if values.len() == 1 && dim == 1 {
1051 if !values[0].eq_ignore_ascii_case("none") {
1052 out[0] = Some(parse_numeric_expr(&values[0])?);
1053 }
1054 return Ok(out);
1055 }
1056 if values.len() != dim {
1057 return Err(format!(
1058 "numeric option list length {} must match smooth dimension {}",
1059 values.len(),
1060 dim
1061 ));
1062 }
1063 for (i, value) in values.iter().enumerate() {
1064 if !value.eq_ignore_ascii_case("none") {
1065 out[i] = Some(parse_numeric_expr(value)?);
1066 }
1067 }
1068 Ok(out)
1069}
1070
1071fn parse_periods(
1072 options: &BTreeMap<String, String>,
1073 periodic_axes: &[bool],
1074) -> Result<Vec<Option<f64>>, String> {
1075 let dim = periodic_axes.len();
1076 let lone_periodic_broadcast = options
1081 .get("period")
1082 .or_else(|| options.get("periods"))
1083 .and_then(|raw| {
1084 let values = split_list_option(raw);
1085 if values.len() != 1 || dim <= 1 {
1086 return None;
1087 }
1088 let mut iter = periodic_axes.iter().enumerate().filter(|(_, p)| **p);
1089 let first = iter.next()?;
1090 if iter.next().is_some() {
1091 return None;
1092 }
1093 Some((first.0, values.into_iter().next().unwrap()))
1094 });
1095 let periods = if let Some((axis, value)) = lone_periodic_broadcast {
1096 let mut out = vec![None; dim];
1097 if !value.eq_ignore_ascii_case("none") {
1098 out[axis] = Some(parse_numeric_expr(&value)?);
1099 }
1100 out
1101 } else {
1102 parse_optional_numeric_list(options, &["period", "periods"], dim)?
1103 };
1104 for (axis, (periodic, period)) in periodic_axes.iter().zip(periods.iter()).enumerate() {
1105 if *periodic
1106 && let Some(value) = period
1107 && (!value.is_finite() || *value <= 0.0)
1108 {
1109 return Err(format!(
1110 "period for periodic axis {axis} must be finite and positive, got {value}"
1111 ));
1112 }
1113 }
1114 Ok(periods)
1115}
1116
1117fn parse_period_origins(
1118 options: &BTreeMap<String, String>,
1119 periodic_axes: &[bool],
1120) -> Result<Vec<Option<f64>>, String> {
1121 parse_optional_numeric_list(
1122 options,
1123 &[
1124 "origin",
1125 "origins",
1126 "period_origin",
1127 "period-origin",
1128 "domain_origin",
1129 ],
1130 periodic_axes.len(),
1131 )
1132}
1133
1134fn parse_tensor_periodic_axes(
1142 options: &BTreeMap<String, String>,
1143 dim: usize,
1144) -> Result<Vec<bool>, String> {
1145 let mut axes = vec![false; dim];
1146 if let Some(raw) = options.get("periodic").or_else(|| options.get("cyclic")) {
1147 let lowered = raw.trim().to_ascii_lowercase();
1148 match lowered.as_str() {
1149 "true" | "yes" | "y" => {
1150 axes.fill(true);
1151 }
1152 "false" | "no" | "n" => {
1153 }
1155 _ => {
1156 let entries = parse_option_list(raw);
1157 let all_bool = !entries.is_empty()
1158 && entries.iter().all(|v| {
1159 matches!(
1160 v.as_str(),
1161 "true" | "yes" | "y" | "false" | "no" | "n" | "none"
1162 )
1163 });
1164 if all_bool {
1165 if entries.len() != dim {
1166 return Err(format!(
1167 "periodic list length {} must match smooth dimension {}",
1168 entries.len(),
1169 dim
1170 ));
1171 }
1172 for (i, v) in entries.iter().enumerate() {
1173 axes[i] = matches!(v.as_str(), "true" | "yes" | "y");
1174 }
1175 } else {
1176 for axis_raw in entries {
1177 let axis = axis_raw
1178 .parse::<usize>()
1179 .map_err(|err| format!("invalid periodic axis '{axis_raw}': {err}"))?;
1180 if axis >= dim {
1181 return Err(format!(
1182 "periodic axis {axis} out of range for {dim}D smooth"
1183 ));
1184 }
1185 axes[axis] = true;
1186 }
1187 }
1188 }
1189 }
1190 }
1191 if let Some(raw) = options.get("boundary").or_else(|| options.get("bc")) {
1192 let boundary = parse_option_list(raw);
1193 if boundary.len() == dim {
1194 for (axis, value) in boundary.iter().enumerate() {
1195 if matches!(value.as_str(), "periodic" | "cyclic" | "cc") {
1196 axes[axis] = true;
1197 }
1198 }
1199 }
1200 }
1201 Ok(axes)
1202}
1203
1204fn tensor_k_axis_option_axis(
1205 key: &str,
1206 cols: &[usize],
1207 ds: &Dataset,
1208) -> Result<Option<usize>, String> {
1209 let Some(suffix) = key.strip_prefix("k_") else {
1210 return Ok(None);
1211 };
1212 if suffix.is_empty() {
1213 return Err("tensor k axis option must be named k_<axis> or k_<variable>".to_string());
1214 }
1215 if let Ok(axis) = suffix.parse::<usize>() {
1216 return if axis < cols.len() {
1217 Ok(Some(axis))
1218 } else {
1219 Err(format!(
1220 "tensor k axis option `{key}` references axis {axis}, but the smooth has {} margins",
1221 cols.len()
1222 ))
1223 };
1224 }
1225
1226 let mut matches = cols
1227 .iter()
1228 .enumerate()
1229 .filter(|(_, col)| ds.headers.get(**col).is_some_and(|name| name == suffix))
1230 .map(|(axis, _)| axis);
1231 let first = matches.next();
1232 if matches.next().is_some() {
1233 return Err(format!(
1234 "tensor k axis option `{key}` matches more than one margin named `{suffix}`"
1235 ));
1236 }
1237 first.map(Some).ok_or_else(|| {
1238 let margin_names = cols
1239 .iter()
1240 .enumerate()
1241 .map(|(axis, col)| {
1242 let name = ds
1243 .headers
1244 .get(*col)
1245 .map(String::as_str)
1246 .unwrap_or("<unnamed>");
1247 format!("{axis}:{name}")
1248 })
1249 .collect::<Vec<_>>()
1250 .join(", ");
1251 format!(
1252 "tensor k axis option `{key}` does not match a margin index or name; tensor margins are [{margin_names}]"
1253 )
1254 })
1255}
1256
/// Returns `true` when `key` has the form `k_<something>` with a non-empty
/// suffix after the `k_` prefix.
fn is_tensor_k_axis_option_key(key: &str) -> bool {
    matches!(key.strip_prefix("k_"), Some(suffix) if !suffix.is_empty())
}
1261
1262fn parse_tensor_k_list(
1266 options: &BTreeMap<String, String>,
1267 cols: &[usize],
1268 ds: &Dataset,
1269) -> Result<(Vec<usize>, bool), String> {
1270 let mut axis_values = vec![None; cols.len()];
1271 let mut saw_axis_alias = false;
1272 for (key, value) in options {
1273 let Some(axis) = tensor_k_axis_option_axis(key, cols, ds)? else {
1274 continue;
1275 };
1276 saw_axis_alias = true;
1277 if axis_values[axis].is_some() {
1278 return Err(format!("tensor k axis {axis} is specified more than once"));
1279 }
1280 let k: usize = value
1281 .parse()
1282 .map_err(|err| format!("invalid tensor k option `{key}={value}`: {err}"))?;
1283 axis_values[axis] = Some(k);
1284 }
1285
1286 let raw = options
1287 .get("k")
1288 .or_else(|| options.get("basis_dim"))
1289 .or_else(|| options.get("basis-dim"))
1290 .or_else(|| options.get("basisdim"));
1291 if saw_axis_alias {
1292 if raw.is_some() {
1293 return Err(
1294 "tensor k axis aliases cannot be combined with k= or basis_dim=".to_string(),
1295 );
1296 }
1297 if let Some(missing_axis) = axis_values.iter().position(Option::is_none) {
1298 let margin_name = cols
1299 .get(missing_axis)
1300 .and_then(|col| ds.headers.get(*col))
1301 .map(String::as_str)
1302 .unwrap_or("<unnamed>");
1303 return Err(format!(
1304 "tensor k axis aliases must specify every margin; missing axis {missing_axis} ({margin_name})"
1305 ));
1306 }
1307 return Ok((
1308 axis_values
1309 .into_iter()
1310 .map(|k| k.expect("missing axis values rejected above"))
1311 .collect(),
1312 false,
1313 ));
1314 }
1315 let Some(raw) = raw else {
1316 let inferred = heuristic_tensor_margin_knots(cols, ds);
1317 return Ok((inferred, true));
1318 };
1319 let entries = split_list_option(raw);
1320 if entries.len() == 1 {
1321 let k: usize = entries[0]
1322 .parse()
1323 .map_err(|err| format!("invalid tensor k '{}': {err}", entries[0]))?;
1324 return Ok((vec![k; cols.len()], false));
1325 }
1326 if entries.len() != cols.len() {
1327 return Err(format!(
1328 "tensor k list length {} must match smooth dimension {}",
1329 entries.len(),
1330 cols.len()
1331 ));
1332 }
1333 let mut out = Vec::with_capacity(entries.len());
1334 for entry in entries {
1335 let k: usize = entry
1336 .parse()
1337 .map_err(|err| format!("invalid tensor k '{entry}': {err}"))?;
1338 out.push(k);
1339 }
1340 Ok((out, false))
1341}
1342
1343fn parse_tensor_identifiability(
1352 options: &BTreeMap<String, String>,
1353 kind: SmoothKind,
1354) -> Result<TensorBSplineIdentifiability, String> {
1355 let Some(raw) = options.get("identifiability").map(String::as_str) else {
1356 return Ok(match kind {
1357 SmoothKind::Ti => TensorBSplineIdentifiability::MarginalSumToZero,
1358 _ => TensorBSplineIdentifiability::default(),
1359 });
1360 };
1361 match raw.trim().to_ascii_lowercase().as_str() {
1362 "none" => Ok(TensorBSplineIdentifiability::None),
1363 "sum_tozero" | "sum-to-zero" | "center_sum_tozero" | "center-sum-to-zero" | "centered"
1364 | "sumtozero" => Ok(TensorBSplineIdentifiability::SumToZero),
1365 "marginal_sum_tozero" | "marginal-sum-to-zero" | "marginal_sumtozero"
1366 | "marginalsumtozero" | "interaction" => {
1367 Ok(TensorBSplineIdentifiability::MarginalSumToZero)
1368 }
1369 other => Err(TermBuilderError::unsupported_feature(format!(
1370 "invalid tensor identifiability '{other}'; expected one of: none, sum_tozero, marginal_sum_tozero"
1371 ))
1372 .to_string()),
1373 }
1374}
1375
1376fn bspline_boundary_declares_periodic_axis(options: &BTreeMap<String, String>) -> bool {
1377 options
1378 .get("boundary")
1379 .or_else(|| options.get("bc"))
1380 .map(|raw| {
1381 parse_option_list(raw)
1382 .into_iter()
1383 .any(|value| matches!(value.as_str(), "periodic" | "cyclic" | "cc"))
1384 })
1385 .unwrap_or(false)
1386}
1387
/// Maps a smooth-type alias onto its canonical name.
///
/// Matching is exact and case-sensitive. Any name that is not a listed alias
/// (including names that are already canonical) is returned unchanged.
pub(crate) fn canonicalize_smooth_type(raw: &str) -> &str {
    // (alias, canonical name) pairs; anything not listed passes through as-is.
    const ALIASES: [(&str, &str); 8] = [
        ("tp", "tps"),
        ("gp", "matern"),
        ("curv", "curvature"),
        ("constant_curvature", "curvature"),
        ("mkappa", "curvature"),
        ("mjs", "measurejet"),
        ("measure_jet", "measurejet"),
        ("web", "measurejet"),
    ];
    for &(alias, canonical) in ALIASES.iter() {
        if alias == raw {
            return canonical;
        }
    }
    raw
}
1430
1431pub(crate) fn tensor_margin_bs_is_supported(margin_bs: &str) -> bool {
1442 matches!(
1443 canonicalize_smooth_type(margin_bs),
1444 "tps" | "ps" | "bs" | "bspline" | "cr" | "cs" | "cc" | "cp" | "cyclic"
1445 )
1446}
1447
/// Reports whether a smooth's options declare periodicity.
///
/// The answer is `true` when either the `periodic` or the `cyclic` key is
/// present (the value is not inspected), or when the `boundary` option, or `bc`
/// if `boundary` is absent, contains the substring `periodic` or `cyclic`
/// after ASCII lowercasing.
pub(crate) fn smooth_options_declare_periodic(options: &BTreeMap<String, String>) -> bool {
    // Key presence alone counts as a declaration.
    if ["periodic", "cyclic"]
        .iter()
        .any(|key| options.contains_key(*key))
    {
        return true;
    }
    match options.get("boundary").or_else(|| options.get("bc")) {
        Some(boundary) => {
            let lowered = boundary.to_ascii_lowercase();
            lowered.contains("periodic") || lowered.contains("cyclic")
        }
        None => false,
    }
}
1465
1466pub(crate) fn bs_selector_is_vector(raw: &str) -> bool {
1483 let trimmed = raw.trim();
1484 let bracketed = (trimmed.starts_with('[') && trimmed.ends_with(']'))
1485 || (trimmed.starts_with("c(") || trimmed.starts_with("C(")) && trimmed.ends_with(')')
1486 || (trimmed.starts_with('(') && trimmed.ends_with(')'));
1487 bracketed && !parse_option_list(trimmed).is_empty()
1488}
1489
1490pub fn resolve_smooth_type_name(
1491 kind: SmoothKind,
1492 n_cols: usize,
1493 options: &BTreeMap<String, String>,
1494) -> String {
1495 let selector = options.get("type").or_else(|| options.get("bs"));
1496 if let Some(raw) = selector
1501 && bs_selector_is_vector(raw)
1502 && matches!(kind, SmoothKind::Te | SmoothKind::Ti | SmoothKind::T2)
1503 {
1504 return "tensor".to_string();
1505 }
1506 selector
1507 .map(|s| canonicalize_smooth_type(&s.to_ascii_lowercase()).to_string())
1508 .unwrap_or_else(|| match kind {
1509 SmoothKind::Te | SmoothKind::Ti | SmoothKind::T2 => "tensor".to_string(),
1510 SmoothKind::S if n_cols == 1 => "bspline".to_string(),
1511 SmoothKind::S if smooth_options_declare_periodic(options) => "tensor".to_string(),
1515 SmoothKind::S => "tps".to_string(),
1516 })
1517}
1518
/// Reports whether `canonical_type` is exactly `"tps"`, `"matern"`, or `"duchon"`.
///
/// No canonicalization or case folding is performed here, so aliases such as
/// `"tp"` return `false`.
pub fn smooth_type_uses_spatial_center_heuristic(canonical_type: &str) -> bool {
    const SPATIAL_TYPES: [&str; 3] = ["tps", "matern", "duchon"];
    SPATIAL_TYPES.iter().any(|&name| name == canonical_type)
}
1530
1531pub fn build_smooth_basis(
1532 kind: SmoothKind,
1533 vars: &[String],
1534 cols: &[usize],
1535 options: &BTreeMap<String, String>,
1536 ds: &Dataset,
1537 inference_notes: &mut Vec<String>,
1538 policy: &ResourcePolicy,
1539 smooth_coordinate_count: usize,
1540) -> Result<SmoothBasisSpec, String> {
1541 let coord_cols: Vec<(&String, usize)> = vars
1557 .iter()
1558 .zip(cols.iter().copied())
1559 .filter(|(_, col)| !matches!(ds.column_kinds.get(*col), Some(ColumnKindTag::Categorical)))
1560 .collect();
1561 if !coord_cols.is_empty() {
1562 let views: Vec<ArrayView1<'_, f64>> = coord_cols
1563 .iter()
1564 .map(|(_, col)| ds.values.column(*col))
1565 .collect();
1566 let n_rows = views[0].len();
1567 let mut distinct_points = std::collections::HashSet::<Vec<u64>>::new();
1568 for r in 0..n_rows {
1569 let key: Vec<u64> = views
1570 .iter()
1571 .map(|v| {
1572 let x = v[r];
1573 let norm = if x == 0.0 { 0.0 } else { x };
1574 norm.to_bits()
1575 })
1576 .collect();
1577 distinct_points.insert(key);
1578 if distinct_points.len() > 1 {
1579 break;
1580 }
1581 }
1582 if distinct_points.len() <= 1 {
1583 return Err(TermBuilderError::degenerate_data(if coord_cols.len() == 1 {
1584 let var = coord_cols[0].0;
1585 format!(
1586 "smooth term over '{var}' has only one unique value in the training data \
1587 — a smooth on a constant column is degenerate and would only fit the response mean. \
1588 Remove `{var}` from the smooth, drop the term, or check the data."
1589 )
1590 } else {
1591 let names = coord_cols
1592 .iter()
1593 .map(|(v, _)| v.as_str())
1594 .collect::<Vec<_>>()
1595 .join(", ");
1596 format!(
1597 "smooth term over ({names}) has only one unique joint coordinate in the training \
1598 data — every coordinate is constant, so the smooth is degenerate and would only \
1599 fit the response mean. Drop the term or check the data."
1600 )
1601 })
1602 .to_string());
1603 }
1604 }
1605 if let Some(by_name) = options.get("by").cloned() {
1606 let by_col = options
1607 .get("__by_col")
1608 .and_then(|raw| raw.parse::<usize>().ok())
1609 .or_else(|| vars.iter().position(|v| v == &by_name).map(|idx| cols[idx]))
1610 .ok_or_else(|| format!("unknown by= column '{by_name}'"))?;
1611 let mut inner_options = options.clone();
1612 inner_options.remove("by");
1613 inner_options.remove("__by_col");
1614 inner_options.remove("id");
1615 let inner = build_smooth_basis(
1616 kind,
1617 vars,
1618 cols,
1619 &inner_options,
1620 ds,
1621 inference_notes,
1622 policy,
1623 smooth_coordinate_count,
1624 )?;
1625 let by_kind = match ds.column_kinds.get(by_col).copied() {
1626 Some(ColumnKindTag::Categorical) => ByVarKind::Factor {
1627 feature_col: by_col,
1628 ordered: option_bool(options, "ordered").unwrap_or(false),
1629 frozen_levels: None,
1630 },
1631 Some(ColumnKindTag::Continuous | ColumnKindTag::Binary) => ByVarKind::Numeric {
1632 feature_col: by_col,
1633 },
1634 None => {
1635 return Err(format!(
1636 "internal column-kind lookup failed for by='{by_name}'"
1637 ));
1638 }
1639 };
1640 return Ok(SmoothBasisSpec::BySmooth {
1641 smooth: Box::new(inner),
1642 by_kind,
1643 });
1644 }
1645
1646 let smooth_double_penalty = option_bool(options, "double_penalty").unwrap_or(true);
1647 let type_opt = resolve_smooth_type_name(kind, cols.len(), options);
1648
1649 if matches!(type_opt.as_str(), "fs" | "sz" | "re") {
1650 validate_known_options(
1651 type_opt.as_str(),
1652 options,
1653 &[
1654 "type",
1655 "bs",
1656 "k",
1657 "basis_dim",
1658 "basis-dim",
1659 "basisdim",
1660 "knots",
1661 "knot_placement",
1662 "knot-placement",
1663 "knotplacement",
1664 "degree",
1665 "penalty_order",
1666 "m",
1667 "double_penalty",
1668 "ordered",
1669 ],
1670 )?;
1671 if cols.len() != 2 {
1672 return Err(format!(
1673 "{} factor-smooth currently expects exactly two variables (one numeric, one categorical)",
1674 type_opt
1675 ));
1676 }
1677 let kinds = cols
1678 .iter()
1679 .map(|&c| ds.column_kinds.get(c).copied())
1680 .collect::<Vec<_>>();
1681 let (cont_idx, group_idx) = if type_opt == "re" {
1682 match (kinds[0], kinds[1]) {
1684 (Some(ColumnKindTag::Categorical), _) => (1usize, 0usize),
1685 (_, Some(ColumnKindTag::Categorical)) => (0usize, 1usize),
1686 _ => (1usize, 0usize),
1687 }
1688 } else {
1689 match (kinds[0], kinds[1]) {
1690 (_, Some(ColumnKindTag::Categorical)) => (0usize, 1usize),
1691 (Some(ColumnKindTag::Categorical), _) => (1usize, 0usize),
1692 _ => {
1693 return Err(format!(
1694 "{} factor-smooth requires one categorical factor variable",
1695 type_opt
1696 ));
1697 }
1698 }
1699 };
1700 let c = cols[cont_idx];
1701 let (minv, maxv) = col_minmax(ds.values.column(c))?;
1702 let degree = if type_opt == "re" {
1703 1
1704 } else {
1705 option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE)
1706 };
1707 let pooled_internal = heuristic_knots_for_column(ds.values.column(c));
1727 let default_internal = if type_opt == "re" {
1728 0
1741 } else {
1742 let min_group_resolution =
1743 min_per_group_unique_count(ds.values.column(c), ds.values.column(cols[group_idx]));
1744 let basis_cap = min_group_resolution.saturating_sub(2).max(degree + 2);
1752 let internal_cap = basis_cap.saturating_sub(degree + 1);
1753 let capped = pooled_internal.min(internal_cap.max(1));
1754 let fs_default_internal = FACTOR_SMOOTH_DEFAULT_BASIS_DIM
1770 .saturating_sub(degree + 1)
1771 .max(1);
1772 capped.min(fs_default_internal)
1773 };
1774 let (n_knots, _, effective_degree) =
1775 parse_ps_internal_knots(options, degree, default_internal)?;
1776 let penalty_order = option_usize(options, "penalty_order")
1777 .unwrap_or(if effective_degree > 1 { 2 } else { 1 })
1778 .min(effective_degree);
1779 let marginal_knotspec = resolve_nonperiodic_bspline_knotspec(
1813 options,
1814 ds.values.column(c),
1815 (minv, maxv),
1816 effective_degree,
1817 n_knots,
1818 )?;
1819 let marginal = BSplineBasisSpec {
1820 degree: effective_degree,
1821 penalty_order,
1822 knotspec: marginal_knotspec,
1823 double_penalty: option_bool(options, "double_penalty")
1834 .unwrap_or(type_opt.as_str() != "sz"),
1835 identifiability: BSplineIdentifiability::None,
1836 boundary_conditions: Default::default(),
1837 boundary: OneDimensionalBoundary::Open,
1838 };
1839 let flavour = match type_opt.as_str() {
1840 "fs" => FactorSmoothFlavour::Fs {
1841 m_null_penalty_orders: vec![
1842 option_usize(options, "m").unwrap_or(DEFAULT_PENALTY_ORDER),
1843 ],
1844 },
1845 "sz" => FactorSmoothFlavour::Sz,
1846 "re" => FactorSmoothFlavour::Re,
1847 other => {
1849 return Err(format!(
1850 "internal: factor-smooth flavour dispatch reached unexpected type `{}`",
1851 other
1852 ));
1853 }
1854 };
1855 return Ok(SmoothBasisSpec::FactorSmooth {
1856 spec: FactorSmoothSpec {
1857 continuous_cols: vec![c],
1858 group_col: cols[group_idx],
1859 marginal,
1860 flavour,
1861 group_frozen_levels: None,
1862 frozen_global_orthogonality: None,
1863 },
1864 });
1865 }
1866
1867 match type_opt.as_str() {
1868 "cyclic" | "cc" | "cp" | "cyclic-ps" => {
1869 validate_known_options(
1870 "cyclic",
1871 options,
1872 &[
1873 "type",
1874 "bs",
1875 "by",
1876 "k",
1877 "basis_dim",
1878 "basis-dim",
1879 "basisdim",
1880 "degree",
1881 "penalty_order",
1882 "period",
1883 "periods",
1884 "period_start",
1885 "period_end",
1886 "start",
1887 "end",
1888 "origin",
1889 "origins",
1890 "period_origin",
1891 "period-origin",
1892 "domain_origin",
1893 "double_penalty",
1894 "id",
1895 "__by_col",
1896 "identifiability",
1897 ],
1898 )?;
1899 if cols.len() != 1 {
1900 return Err(format!(
1901 "periodic smooth expects one variable, got {}",
1902 cols.len()
1903 ));
1904 }
1905 let c = cols[0];
1906 let (minv, maxv) = col_minmax(ds.values.column(c))?;
1907 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
1908 let mut default_internal = heuristic_knots_for_column(ds.values.column(c));
1909 if ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
1910 default_internal = default_internal.min(1);
1911 }
1912 let cyclic_default_basis_cap = CYCLIC_DEFAULT_BASIS_DIM.max(degree + 1);
1926 let default_basis = (default_internal + degree + 1).min(cyclic_default_basis_cap);
1927 let num_basis = option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
1928 .unwrap_or(default_basis);
1929 if num_basis < degree + 1 {
1930 return Err(format!(
1931 "periodic smooth: k={} too small for degree {}; expected k >= {}",
1932 num_basis,
1933 degree,
1934 degree + 1
1935 ));
1936 }
1937 let periodic_axes = [true];
1948 let periods = parse_periods(options, &periodic_axes)?;
1949 let origins = parse_period_origins(options, &periodic_axes)?;
1950 let (domain_start, period) = if let Some(p) = periods[0] {
1951 (origins[0].unwrap_or(minv), p)
1952 } else {
1953 parse_periodic_domain_1d(options, minv, maxv)?
1954 };
1955 Ok(SmoothBasisSpec::BSpline1D {
1956 feature_col: c,
1957 spec: BSplineBasisSpec {
1958 degree,
1959 penalty_order: option_usize(options, "penalty_order")
1960 .unwrap_or(DEFAULT_PENALTY_ORDER),
1961 knotspec: BSplineKnotSpec::PeriodicUniform {
1962 data_range: (domain_start, domain_start + period),
1963 num_basis,
1964 },
1965 double_penalty: smooth_double_penalty,
1966 identifiability: BSplineIdentifiability::default(),
1967 boundary_conditions: Default::default(),
1968 boundary: OneDimensionalBoundary::Cyclic {
1969 start: domain_start,
1970 end: domain_start + period,
1971 },
1972 },
1973 })
1974 }
1975 "bspline" | "ps" | "p-spline" | "cr" | "cs" => {
1976 let validation_name = match type_opt.as_str() {
1990 "cr" => "cr",
1991 "cs" => "cs",
1992 _ => "bspline",
1993 };
1994 validate_known_options(
1995 validation_name,
1996 options,
1997 &[
1998 "type",
1999 "bs",
2000 "by",
2001 "k",
2002 "basis_dim",
2003 "basis-dim",
2004 "basisdim",
2005 "knots",
2006 "knot_placement",
2007 "knot-placement",
2008 "knotplacement",
2009 "degree",
2010 "penalty_order",
2011 "boundary",
2012 "bc",
2013 "boundary_conditions",
2014 "bc_left",
2015 "bc_right",
2016 "left_bc",
2017 "right_bc",
2018 "start_bc",
2019 "end_bc",
2020 "side",
2021 "anchor",
2022 "anchor_value",
2023 "value",
2024 "anchor_left",
2025 "left_anchor",
2026 "anchor_right",
2027 "right_anchor",
2028 "periodic",
2029 "period",
2030 "periods",
2031 "period_start",
2032 "period_end",
2033 "origin",
2034 "double_penalty",
2035 "by",
2036 "id",
2037 "__by_col",
2038 "identifiability",
2039 "by",
2040 ],
2041 )?;
2042 if cols.len() != 1 {
2043 return Err(TermBuilderError::incompatible_config(format!(
2044 "bspline smooth expects one variable, got {}",
2045 cols.len()
2046 ))
2047 .to_string());
2048 }
2049 let c = cols[0];
2050 let (minv, maxv) = col_minmax(ds.values.column(c))?;
2051 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
2052 let default_internal = heuristic_knots_for_column(ds.values.column(c));
2053 let (mut n_knots, inferred, effective_degree) =
2054 parse_ps_internal_knots(options, degree, default_internal)?;
2055 let periodic_axes = parse_periodic_axes(options, 1).map_err(|e| e.to_string())?;
2056 if periodic_axes[0] && effective_degree != degree {
2061 return Err(TermBuilderError::invalid_option(format!(
2062 "periodic smooth: k={} too small for degree {}; expected k >= {}",
2063 effective_degree + 1,
2064 degree,
2065 degree + 1
2066 ))
2067 .to_string());
2068 }
2069 if inferred && ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2070 n_knots = n_knots.min(1);
2071 }
2072 if inferred {
2073 let unique = unique_count_column(ds.values.column(c));
2074 let ceiling = ((unique as f64).cbrt() as usize).max(20);
2075 inference_notes.push(format!(
2076 "Automatically set {} internal knots for smooth '{}' from {} unique values (rule: clamp(unique/4, 4..max(20, cbrt(unique))) = clamp(unique/4, 4..{})). Override with knots=... or k=....",
2077 n_knots,
2078 vars.join(","),
2079 unique,
2080 ceiling,
2081 ));
2082 }
2083 let boundary_conditions =
2084 if periodic_axes[0] && bspline_boundary_declares_periodic_axis(options) {
2085 BSplineBoundaryConditions::default()
2086 } else {
2087 parse_bspline_boundary_conditions(options).map_err(|e| e.to_string())?
2088 };
2089 let periods = parse_periods(options, &periodic_axes).map_err(|e| e.to_string())?;
2090 let origins =
2091 parse_period_origins(options, &periodic_axes).map_err(|e| e.to_string())?;
2092 let (knotspec, boundary) = if periodic_axes[0] {
2093 if !boundary_conditions.is_free() {
2094 return Err(TermBuilderError::incompatible_config(
2095 "periodic B-splines cannot also declare endpoint boundary conditions",
2096 )
2097 .to_string());
2098 }
2099 {
2100 let (domain_start, p_value) = if periods[0].is_some() {
2101 (origins[0].unwrap_or(minv), periods[0].unwrap())
2102 } else {
2103 parse_periodic_domain_1d(options, minv, maxv).map_err(|e| e.to_string())?
2104 };
2105 let domain_end = domain_start + p_value;
2106 (
2107 BSplineKnotSpec::PeriodicUniform {
2108 data_range: (domain_start, domain_end),
2109 num_basis: n_knots + effective_degree + 1,
2110 },
2111 OneDimensionalBoundary::Cyclic {
2112 start: domain_start,
2113 end: domain_end,
2114 },
2115 )
2116 }
2117 } else if type_opt == "cr" || type_opt == "cs" {
2118 let k_cr = (n_knots + effective_degree + 1).max(CR_MIN_KNOTS);
2135 let knotspec = match capped_cr_marginal_knotspec(
2136 ds.values.column(c),
2137 k_cr,
2138 &vars.join(","),
2139 inference_notes,
2140 )? {
2141 Some(cr_knotspec) => cr_knotspec,
2142 None => resolve_nonperiodic_bspline_knotspec(
2143 options,
2144 ds.values.column(c),
2145 (minv, maxv),
2146 effective_degree,
2147 n_knots,
2148 )?,
2149 };
2150 (knotspec, parse_cyclic_boundary(options, minv, maxv)?)
2151 } else {
2152 (
2153 resolve_nonperiodic_bspline_knotspec(
2154 options,
2155 ds.values.column(c),
2156 (minv, maxv),
2157 effective_degree,
2158 n_knots,
2159 )?,
2160 parse_cyclic_boundary(options, minv, maxv)?,
2161 )
2162 };
2163 let double_penalty = if type_opt == "cr" {
2167 option_bool(options, "double_penalty").unwrap_or(false)
2168 } else {
2169 smooth_double_penalty
2170 };
2171 let penalty_order = option_usize(options, "penalty_order")
2176 .unwrap_or(DEFAULT_PENALTY_ORDER)
2177 .min(effective_degree);
2178 Ok(SmoothBasisSpec::BSpline1D {
2179 feature_col: c,
2180 spec: BSplineBasisSpec {
2181 degree: effective_degree,
2182 penalty_order,
2183 knotspec,
2184 double_penalty,
2185 identifiability: BSplineIdentifiability::default(),
2186 boundary,
2187 boundary_conditions,
2188 },
2189 })
2190 }
2191 "tps" | "thinplate" | "thin-plate" => {
2192 validate_known_options(
2193 "thinplate",
2194 options,
2195 &[
2196 SECONDARY_CENTER_CAP_OPTION,
2197 "type",
2198 "bs",
2199 "by",
2200 "length_scale",
2201 "centers",
2202 "k",
2203 "basis_dim",
2204 "basis-dim",
2205 "basisdim",
2206 "knots",
2207 "include_intercept",
2208 "double_penalty",
2209 "by",
2210 "id",
2211 "__by_col",
2212 "identifiability",
2213 "by",
2214 "scale_dims",
2215 ],
2216 )?;
2217 let plan = plan_spatial_basis(
2218 ds.values.nrows(),
2219 cols.len(),
2220 CenterCountRequest::Default,
2221 DuchonNullspaceOrder::Linear,
2222 option_bool(options, "scale_dims").unwrap_or(false),
2223 policy,
2224 )
2225 .map_err(|e| e.to_string())?;
2226 let default_centers = plan.centers;
2236 let centers = parse_countwith_basis_alias(
2237 options,
2238 "centers",
2239 cap_default_spatial_centers(options, default_centers),
2240 )?;
2241 let center_strategy = if has_explicit_countwith_basis_alias(options, "centers") {
2242 spatial_center_strategy_for_dimension(centers, cols.len())
2243 } else {
2244 auto_spatial_center_strategy(centers, cols.len())
2245 };
2246 Ok(SmoothBasisSpec::ThinPlate {
2247 feature_cols: cols.to_vec(),
2248 spec: ThinPlateBasisSpec {
2249 center_strategy,
2250 periodic: parse_periodic_axes_option(options, cols.len())?,
2251 length_scale: option_f64(options, "length_scale").unwrap_or(0.0),
2259 double_penalty: smooth_double_penalty,
2260 identifiability: parse_spatial_identifiability(options)
2261 .map_err(|e| e.to_string())?,
2262 radial_reparam: None,
2263 },
2264 input_scales: None,
2265 })
2266 }
2267 "sphere" | "s2" | "sos" => {
2268 validate_known_options(
2269 "sphere",
2270 options,
2271 &[
2272 "type",
2273 "bs",
2274 "by",
2275 "centers",
2276 "k",
2277 "basis_dim",
2278 "basis-dim",
2279 "basisdim",
2280 "knots",
2281 "penalty_order",
2282 "m",
2283 "double_penalty",
2284 "id",
2285 "__by_col",
2286 "kernel",
2287 "method",
2288 "radians",
2289 "units",
2290 "degree",
2291 "l",
2292 "max_degree",
2293 "max-degree",
2294 ],
2295 )?;
2296 if cols.len() != 2 {
2297 return Err(format!(
2298 "sphere smooth expects exactly two variables (lat, lon), got {}",
2299 cols.len()
2300 ));
2301 }
2302 let radians = option_bool(options, "radians").unwrap_or_else(|| {
2303 options
2304 .get("units")
2305 .map(|u| u.eq_ignore_ascii_case("radian") || u.eq_ignore_ascii_case("radians"))
2306 .unwrap_or(false)
2307 });
2308 let degree_requested = options.contains_key("degree")
2314 || options.contains_key("l")
2315 || options.contains_key("max_degree")
2316 || options.contains_key("max-degree");
2317 let kernel = options
2318 .get("kernel")
2319 .or_else(|| options.get("method"))
2320 .map(|raw| strip_quotes(raw).trim().to_ascii_lowercase())
2321 .unwrap_or_else(|| {
2322 if degree_requested {
2323 "harmonic".to_string()
2324 } else {
2325 "sobolev".to_string()
2326 }
2327 });
2328 let (method, wahba_kernel) = match kernel.as_str() {
2329 "sobolev" | "wahba" | "wahba_sobolev" | "wahba-sobolev" => {
2330 (SphereMethod::Wahba, SphereWahbaKernel::Sobolev)
2331 }
2332 "pseudo" | "mgcv" | "sos" | "wahba_pseudo" | "wahba-pseudo" => {
2333 (SphereMethod::Wahba, SphereWahbaKernel::Pseudo)
2334 }
2335 "harmonic" | "spherical_harmonic" | "spherical-harmonic" => {
2336 (SphereMethod::Harmonic, SphereWahbaKernel::Sobolev)
2337 }
2338 other => {
2339 return Err(format!(
2340 "unsupported sphere kernel '{other}'; expected sobolev, pseudo, or harmonic"
2341 ));
2342 }
2343 };
2344 let max_degree = if matches!(method, SphereMethod::Harmonic) {
2345 let degree =
2346 option_usize_any(options, &["degree", "l", "max_degree", "max-degree"])
2347 .or_else(|| option_usize(options, "centers"))
2348 .or_else(|| {
2349 option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
2350 .and_then(|k| (1..=128).find(|&l| l * (l + 2) >= k))
2351 })
2352 .unwrap_or_else(|| default_spherical_harmonic_degree(ds.values.nrows()));
2353 if degree == 0 {
2354 return Err("sphere smooth requires degree/max_degree >= 1".to_string());
2355 }
2356 if degree > 32 {
2357 return Err(format!(
2358 "sphere smooth max_degree={} is too large for the dense harmonic engine (limit 32)",
2359 degree
2360 ));
2361 }
2362 Some(degree)
2363 } else {
2364 None
2365 };
2366 let penalty_order = option_usize(options, "penalty_order")
2367 .or_else(|| option_usize(options, "m"))
2368 .unwrap_or(DEFAULT_PENALTY_ORDER);
2369 let center_strategy = if matches!(method, SphereMethod::Wahba) {
2370 let mut centers = parse_countwith_basis_alias(
2371 options,
2372 "centers",
2373 default_num_centers(ds.values.nrows(), cols.len()),
2374 )?;
2375 if penalty_order >= 4 {
2376 centers = centers.max(30);
2377 }
2378 CenterStrategy::FarthestPoint {
2379 num_centers: centers,
2380 }
2381 } else {
2382 CenterStrategy::FarthestPoint { num_centers: 0 }
2383 };
2384 Ok(SmoothBasisSpec::Sphere {
2385 feature_cols: cols.to_vec(),
2386 spec: SphericalSplineBasisSpec {
2387 center_strategy,
2388 penalty_order,
2389 double_penalty: smooth_double_penalty,
2390 radians,
2391 method,
2392 max_degree,
2393 wahba_kernel,
2394 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2395 },
2396 })
2397 }
2398 "curvature" => {
2399 validate_known_options(
2405 "curvature",
2406 options,
2407 &[
2408 "type",
2409 "bs",
2410 "by",
2411 "centers",
2412 "k",
2413 "basis_dim",
2414 "basis-dim",
2415 "basisdim",
2416 "knots",
2417 "kappa",
2418 "length_scale",
2419 "double_penalty",
2420 "id",
2421 "__by_col",
2422 ],
2423 )?;
2424 let kappa = option_f64(options, "kappa").unwrap_or(0.0);
2425 if !kappa.is_finite() {
2426 return Err("curvature smooth requires a finite kappa".to_string());
2427 }
2428 let length_scale = option_f64(options, "length_scale").unwrap_or(0.0);
2429 if !length_scale.is_finite() || length_scale < 0.0 {
2430 return Err(format!(
2431 "curvature smooth length_scale must be positive (or omitted for auto); got {length_scale}"
2432 ));
2433 }
2434 let centers = parse_countwith_basis_alias(
2435 options,
2436 "centers",
2437 default_num_centers(ds.values.nrows(), cols.len()),
2438 )?;
2439 if centers < 2 {
2440 return Err("curvature smooth requires at least 2 centers".to_string());
2441 }
2442 Ok(SmoothBasisSpec::ConstantCurvature {
2443 feature_cols: cols.to_vec(),
2444 spec: ConstantCurvatureBasisSpec {
2445 center_strategy: CenterStrategy::FarthestPoint {
2446 num_centers: centers,
2447 },
2448 kappa,
2449 length_scale,
2452 double_penalty: option_bool(options, "double_penalty").unwrap_or(false),
2459 identifiability: ConstantCurvatureIdentifiability::CenterSumToZero,
2460 },
2461 })
2462 }
2463 "measurejet" => {
2464 validate_known_options(
2470 "measurejet",
2471 options,
2472 &[
2473 "type",
2474 "bs",
2475 "by",
2476 "centers",
2477 "k",
2478 "basis_dim",
2479 "basis-dim",
2480 "basisdim",
2481 "knots",
2482 "s",
2483 "alpha",
2484 "tau",
2485 "scales",
2486 "length_scale",
2487 "double_penalty",
2488 "multiscale",
2489 "learn_length_scale",
2490 "id",
2491 "__by_col",
2492 ],
2493 )?;
2494 let order_s = option_f64(options, "s").unwrap_or(0.0);
2495 if !(order_s.is_finite() && (order_s == 0.0 || (order_s > 0.0 && order_s < 2.0))) {
2498 return Err(format!(
2499 "measurejet smooth s must lie in (0, 2) (or be omitted for auto); got {order_s}"
2500 ));
2501 }
2502 let alpha =
2510 option_f64(options, "alpha").unwrap_or(MeasureJetBasisSpec::default().alpha);
2511 if !alpha.is_finite() {
2512 return Err("measurejet smooth requires a finite alpha".to_string());
2513 }
2514 let tau0 = option_f64(options, "tau").unwrap_or(1e-3);
2515 if !(tau0.is_finite() && tau0 >= 0.0) {
2516 return Err(format!(
2517 "measurejet smooth tau must be finite and nonnegative; got {tau0}"
2518 ));
2519 }
2520 let num_scales = option_usize(options, "scales").unwrap_or(0);
2521 let length_scale = option_f64(options, "length_scale").unwrap_or(0.0);
2522 if !length_scale.is_finite() || length_scale < 0.0 {
2523 return Err(format!(
2524 "measurejet smooth length_scale must be positive (or omitted for auto); got {length_scale}"
2525 ));
2526 }
2527 let centers = parse_countwith_basis_alias(
2528 options,
2529 "centers",
2530 default_num_centers(ds.values.nrows(), cols.len()),
2531 )?;
2532 if centers < 3 {
2533 return Err("measurejet smooth requires at least 3 centers".to_string());
2534 }
2535 let multiscale = option_bool(options, "multiscale").unwrap_or(false);
2539 let learn_length_scale = option_bool(options, "learn_length_scale").unwrap_or(false);
2544 Ok(SmoothBasisSpec::MeasureJet {
2545 feature_cols: cols.to_vec(),
2546 spec: MeasureJetBasisSpec {
2547 center_strategy: CenterStrategy::FarthestPoint {
2548 num_centers: centers,
2549 },
2550 order_s,
2551 alpha,
2552 tau0,
2553 num_scales,
2554 length_scale,
2557 double_penalty: smooth_double_penalty,
2558 learn_length_scale,
2559 multiscale,
2560 identifiability: MeasureJetIdentifiability::CenterSumToZero,
2561 frozen_quadrature: None,
2562 },
2563 input_scales: None,
2564 })
2565 }
2566 "matern" => {
2567 validate_known_options(
2572 "matern",
2573 options,
2574 &[
2575 SECONDARY_CENTER_CAP_OPTION,
2576 "type",
2577 "bs",
2578 "by",
2579 "nu",
2580 "length_scale",
2581 "centers",
2582 "k",
2583 "basis_dim",
2584 "basis-dim",
2585 "basisdim",
2586 "knots",
2587 "include_intercept",
2588 "double_penalty",
2589 "by",
2590 "id",
2591 "__by_col",
2592 "identifiability",
2593 "by",
2594 "scale_dims",
2595 ],
2596 )?;
2597 let plan = plan_spatial_basis(
2598 ds.values.nrows(),
2599 cols.len(),
2600 CenterCountRequest::Default,
2601 DuchonNullspaceOrder::Zero,
2602 option_bool(options, "scale_dims").unwrap_or(false),
2603 policy,
2604 )
2605 .map_err(|e| e.to_string())?;
2606 let centers = parse_countwith_basis_alias(
2607 options,
2608 "centers",
2609 cap_default_spatial_centers(
2610 options,
2611 default_matern_center_count(ds.values.nrows(), cols.len(), plan.centers),
2612 ),
2613 )?;
2614 let center_strategy = if has_explicit_countwith_basis_alias(options, "centers") {
2615 spatial_center_strategy_for_dimension(centers, cols.len())
2616 } else {
2617 auto_spatial_center_strategy(centers, cols.len())
2618 };
2619 let nu = parse_matern_nu(options.get("nu").map(String::as_str).unwrap_or("5/2"))?;
2620 if matches!(nu, MaternNu::Half) && cols.len() >= 2 {
2626 return Err(TermBuilderError::unsupported_feature(format!(
2627 "matern() with nu=1/2 is not supported for d>=2 (got {} covariates): \
2628 the exponential kernel's Laplacian is singular at center collisions, \
2629 which makes the operator-collocation penalty non-invertible. \
2630 Choose nu>=3/2 (e.g. nu=3/2 or the default nu=5/2) for multi-dimensional smooths.",
2631 cols.len()
2632 ))
2633 .to_string());
2634 }
2635 let aniso_log_scales = if option_bool(options, "scale_dims").unwrap_or(false) {
2636 Some(vec![0.0; cols.len()])
2637 } else {
2638 None
2639 };
2640 Ok(SmoothBasisSpec::Matern {
2641 feature_cols: cols.to_vec(),
2642 spec: MaternBasisSpec {
2643 center_strategy,
2644 periodic: parse_periodic_axes_option(options, cols.len())?,
2645 length_scale: option_f64(options, "length_scale").unwrap_or(0.0),
2663 nu,
2664 include_intercept: option_bool(options, "include_intercept").unwrap_or(false),
2665 double_penalty: smooth_double_penalty,
2666 identifiability: parse_matern_identifiability(options)
2667 .map_err(|e| e.to_string())?,
2668 aniso_log_scales,
2669 nullspace_shrinkage_survived: None,
2674 },
2675 input_scales: None,
2676 })
2677 }
2678 "duchon" => {
2679 validate_known_options(
2680 "duchon",
2681 options,
2682 &[
2683 SECONDARY_CENTER_CAP_OPTION,
2684 "type",
2685 "bs",
2686 "by",
2687 "length_scale",
2688 "centers",
2689 "k",
2690 "basis_dim",
2691 "basis-dim",
2692 "basisdim",
2693 "knots",
2694 "power",
2695 "p",
2696 "nullspace_order",
2697 "order",
2698 "identifiability",
2699 "by",
2700 "periodic",
2701 "cyclic",
2702 "period",
2703 "period_start",
2704 "period_end",
2705 "scale_dims",
2706 "double_penalty",
2707 "by",
2708 "id",
2709 "__by_col",
2710 ],
2711 )?;
2712 if options.contains_key("double_penalty") {
2713 return Err(TermBuilderError::incompatible_config(format!(
2714 "Duchon smooth '{}' does not support double_penalty; the Duchon smoother already ships its native reproducing-norm penalty plus a null-space shrinkage ridge.",
2715 vars.join(", ")
2716 ))
2717 .to_string());
2718 }
2719 let requested_nullspace_order = parse_duchon_order(options)?;
2720 let length_scale = option_f64_strict(options, "length_scale")?;
2721 let (nullspace_order, power) = match parse_duchon_power_policy(options)? {
2734 DuchonPowerPolicy::Explicit(req_power) => {
2735 if length_scale.is_some() && req_power.fract() != 0.0 {
2736 return Err(TermBuilderError::incompatible_config(format!(
2737 "hybrid Duchon-Matern smooth '{}' (length_scale=...) requires an integer power, got power={}; \
2738 drop length_scale to use the scale-free structural kernel with a fractional power.",
2739 vars.join(", "),
2740 req_power,
2741 ))
2742 .to_string());
2743 }
2744 (requested_nullspace_order, req_power)
2745 }
2746 DuchonPowerPolicy::CubicStructuralDefault => {
2747 match length_scale {
2754 None => crate::basis::duchon_cubic_default(cols.len()),
2755 Some(_) => {
2756 let max_op = crate::basis::duchon_max_active_operator_derivative_order(
2777 &DuchonOperatorPenaltySpec::default(),
2778 );
2779 let (ns, s) = crate::basis::resolve_duchon_orders(
2780 cols.len(),
2781 requested_nullspace_order,
2782 max_op,
2783 length_scale,
2784 );
2785 (ns, s as f64)
2786 }
2787 }
2788 }
2789 };
2790 let plan = plan_spatial_basis(
2791 ds.values.nrows(),
2792 cols.len(),
2793 CenterCountRequest::Default,
2794 nullspace_order,
2795 option_bool(options, "scale_dims").unwrap_or(false),
2796 policy,
2797 )
2798 .map_err(|e| e.to_string())?;
2799 let centers_explicit = has_explicit_countwith_basis_alias(options, "centers");
2800 let requested_centers = parse_countwith_basis_alias(
2801 options,
2802 "centers",
2803 cap_default_spatial_centers(options, plan.centers),
2804 )?;
2805 let polynomial_cols = match nullspace_order {
2806 DuchonNullspaceOrder::Zero => 1,
2807 DuchonNullspaceOrder::Linear => cols.len() + 1,
2808 DuchonNullspaceOrder::Degree(degree) => {
2809 crate::basis::duchon_nullspace_dimension(cols.len(), degree)
2810 }
2811 };
2812 if requested_centers <= polynomial_cols {
2813 return Err(TermBuilderError::incompatible_config(format!(
2814 "Duchon smooth '{}' requested basis dimension {} but order={:?} in {}D needs {} polynomial null-space columns; choose centers/k > {}",
2815 vars.join(", "),
2816 requested_centers,
2817 nullspace_order,
2818 cols.len(),
2819 polynomial_cols,
2820 polynomial_cols,
2821 ))
2822 .to_string());
2823 }
2824 let mut centers = requested_centers;
2825 if !centers_explicit && ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2826 centers = centers.max(polynomial_cols + 4);
2827 }
2828 let center_strategy = if centers_explicit {
2829 spatial_center_strategy_for_dimension(centers, cols.len())
2830 } else {
2831 auto_spatial_center_strategy(centers, cols.len())
2832 };
2833 let aniso_log_scales = if option_bool(options, "scale_dims").unwrap_or(false) {
2834 Some(vec![0.0; cols.len()])
2835 } else {
2836 None
2837 };
2838 let operator_penalties = DuchonOperatorPenaltySpec::default();
2841 Ok(SmoothBasisSpec::Duchon {
2842 feature_cols: cols.to_vec(),
2843 spec: DuchonBasisSpec {
2844 center_strategy,
2845 periodic: parse_periodic_axes_option(options, cols.len())?,
2846 length_scale,
2847 power,
2848 nullspace_order,
2849 identifiability: parse_spatial_identifiability(options)
2850 .map_err(|e| e.to_string())?,
2851 aniso_log_scales,
2852 operator_penalties,
2853 boundary: if cols.len() == 1 {
2854 let c = cols[0];
2855 let (minv, maxv) = col_minmax(ds.values.column(c))?;
2856 parse_cyclic_boundary(options, minv, maxv)?
2857 } else {
2858 OneDimensionalBoundary::Open
2859 },
2860 radial_reparam: None,
2861 },
2862 input_scales: None,
2863 })
2864 }
2865 "tensor" | "te" | "ti" | "t2" => {
2866 validate_known_options(
2867 "tensor",
2868 options,
2869 &[
2870 "type",
2871 "bs",
2872 "by",
2873 "k",
2874 "basis_dim",
2875 "basis-dim",
2876 "basisdim",
2877 "knot_placement",
2878 "knot-placement",
2879 "knotplacement",
2880 "degree",
2881 "penalty_order",
2882 "double_penalty",
2883 "periodic",
2884 "cyclic",
2885 "period",
2886 "periods",
2887 "period_start",
2888 "period_end",
2889 "origin",
2890 "origins",
2891 "period_origin",
2892 "period-origin",
2893 "domain_origin",
2894 "boundary",
2895 "bc",
2896 "identifiability",
2897 "id",
2898 "__by_col",
2899 ],
2900 )?;
2901 if cols.len() < 2 {
2902 return Err(TermBuilderError::incompatible_config(format!(
2903 "tensor smooth expects at least 2 variables, got {}",
2904 cols.len()
2905 ))
2906 .to_string());
2907 }
2908 let dim = cols.len();
2909
2910 if let Some(raw) = options.get("bs").or_else(|| options.get("type"))
2933 && bs_selector_is_vector(raw)
2934 {
2935 let per_margin = parse_option_list(raw);
2936 if per_margin.len() != dim {
2937 return Err(TermBuilderError::invalid_option(format!(
2938 "tensor smooth per-margin bs vector has {} entries but the smooth has {} margins",
2939 per_margin.len(),
2940 dim
2941 ))
2942 .to_string());
2943 }
2944 for (axis, margin_bs) in per_margin.iter().enumerate() {
2945 if !tensor_margin_bs_is_supported(margin_bs) {
2946 return Err(TermBuilderError::unsupported_feature(format!(
2947 "tensor smooth margin {axis} basis '{margin_bs}' is not a supported penalized-spline margin; \
2948 tensor margins accept tp/tps/ps/bs/cr/cc"
2949 ))
2950 .to_string());
2951 }
2952 }
2953 }
2954 let periodic_axes = parse_tensor_periodic_axes(options, dim)?;
2955 let periods_opt = parse_periods(options, &periodic_axes)?;
2956 let origins_opt = parse_period_origins(options, &periodic_axes)?;
2957 let degree = option_usize(options, "degree").unwrap_or(DEFAULT_BSPLINE_DEGREE);
2958 let penalty_order =
2959 option_usize(options, "penalty_order").unwrap_or(if degree > 1 { 2 } else { 1 });
2960 let (mut k_list, k_inferred) = parse_tensor_k_list(options, cols, ds)?;
2961 if ds.values.nrows() <= 32 && smooth_coordinate_count >= 5 {
2962 for k in &mut k_list {
2963 *k = (*k).min(degree + 2);
2964 }
2965 }
2966 if k_inferred {
2967 inference_notes.push(format!(
2968 "Automatically set per-margin basis sizes {:?} for tensor smooth '{}' \
2969 (dimension-aware tensor budget: total ∏k kept near the mgcv-te default \
2970 and within the data support, distributed geometrically across margins and \
2971 capped per margin by each column's resolution). \
2972 Override with k=<int> or k=[k0,k1,...].",
2973 k_list,
2974 vars.join(",")
2975 ));
2976 }
2977 let per_axis_bs: Vec<Option<String>> =
2990 match options.get("bs").or_else(|| options.get("type")) {
2991 Some(raw) if bs_selector_is_vector(raw) => {
2992 let list = parse_option_list(raw);
2993 (0..dim).map(|a| list.get(a).cloned()).collect()
2994 }
2995 Some(raw) => {
2996 let scalar = raw
2997 .trim()
2998 .trim_matches('"')
2999 .trim_matches('\'')
3000 .to_ascii_lowercase();
3001 vec![Some(scalar); dim]
3002 }
3003 None => vec![None; dim],
3004 };
3005 let margin_wants_cr = |bs: &Option<String>| -> bool {
3011 matches!(
3012 bs.as_deref(),
3013 None | Some("cr") | Some("cs") | Some("tp") | Some("tps")
3014 )
3015 };
3016 let mut margins: Vec<BSplineBasisSpec> = Vec::with_capacity(dim);
3017 let mut emitted_periods: Vec<Option<f64>> = Vec::with_capacity(dim);
3018 for axis in 0..dim {
3019 let c = cols[axis];
3020 let (data_min, data_max) = col_minmax(ds.values.column(c))?;
3021 let k_requested = k_list[axis];
3037 let n_distinct_axis = unique_count_column(ds.values.column(c));
3038 let k_axis = k_requested.min(n_distinct_axis).max(2);
3039 if k_axis < k_requested {
3040 log::info!(
3041 "tensor smooth: margin axis {axis} requested k={k_requested}, but the \
3042 covariate has only {n_distinct_axis} distinct value(s); reducing this \
3043 margin to k={k_axis} (mgcv-style data-support cap on the per-axis basis)."
3044 );
3045 }
3046 if k_axis < 2 {
3059 return Err(TermBuilderError::invalid_option(format!(
3060 "tensor smooth: k[{axis}]={k_axis} too small; tensor margins require k >= 2"
3061 ))
3062 .to_string());
3063 }
3064 if periodic_axes[axis] && k_axis < degree + 1 {
3065 return Err(TermBuilderError::invalid_option(format!(
3066 "tensor smooth: periodic axis {axis} requires k >= {} for degree {degree}, got k={k_axis}",
3067 degree + 1
3068 ))
3069 .to_string());
3070 }
3071 let effective_degree = degree.min(k_axis - 1).max(1);
3072 let effective_penalty_order = penalty_order.min(effective_degree);
3073 let (knotspec, boundary, axis_period) = if periodic_axes[axis] {
3074 let period_value = periods_opt[axis].ok_or_else(|| {
3075 format!(
3076 "tensor smooth axis {axis} is periodic but no period was supplied; \
3077 pass period=<value> (scalar) or period=[..., <value>, ...]"
3078 )
3079 })?;
3080 if !period_value.is_finite() || period_value <= 0.0 {
3081 return Err(format!(
3082 "tensor smooth axis {axis}: period must be a positive finite value, got {period_value}"
3083 ));
3084 }
3085 let domain_start = origins_opt[axis].unwrap_or(data_min);
3086 let domain_end = domain_start + period_value;
3087 (
3088 BSplineKnotSpec::PeriodicUniform {
3089 data_range: (domain_start, domain_end),
3090 num_basis: k_axis,
3091 },
3092 OneDimensionalBoundary::Cyclic {
3093 start: domain_start,
3094 end: domain_end,
3095 },
3096 Some(period_value),
3097 )
3098 } else if margin_wants_cr(&per_axis_bs[axis]) && k_axis >= 3 {
3099 let cr_knots =
3109 crate::basis::select_cr_knots(ds.values.column(c), k_axis)
3110 .map_err(|e| e.to_string())?;
3111 (
3112 BSplineKnotSpec::NaturalCubicRegression { knots: cr_knots },
3113 OneDimensionalBoundary::Open,
3114 None,
3115 )
3116 } else {
3117 let num_internal_knots = if effective_degree < degree {
3124 k_axis.saturating_sub(effective_degree + 1)
3125 } else {
3126 k_axis.saturating_sub(degree + 1).max(1)
3127 };
3128 let knotspec = match parse_knot_placement(options)? {
3129 crate::basis::BSplineKnotPlacement::Uniform => BSplineKnotSpec::Generate {
3130 data_range: (data_min, data_max),
3131 num_internal_knots,
3132 },
3133 crate::basis::BSplineKnotPlacement::Quantile => {
3134 crate::basis::auto_knot_vector_1d_quantile(
3135 ds.values.column(c),
3136 num_internal_knots,
3137 effective_degree,
3138 )
3139 .map_err(|e| e.to_string())?;
3140 BSplineKnotSpec::Automatic {
3141 num_internal_knots: Some(num_internal_knots),
3142 placement: crate::basis::BSplineKnotPlacement::Quantile,
3143 }
3144 }
3145 };
3146 (knotspec, OneDimensionalBoundary::Open, None)
3147 };
3148 let is_cr_margin =
3154 matches!(knotspec, BSplineKnotSpec::NaturalCubicRegression { .. });
3155 let margin_double_penalty =
3156 is_cr_margin && matches!(per_axis_bs[axis].as_deref(), Some("cs"));
3157 margins.push(BSplineBasisSpec {
3158 degree: effective_degree,
3159 penalty_order: effective_penalty_order,
3160 knotspec,
3161 double_penalty: margin_double_penalty,
3162 identifiability: BSplineIdentifiability::None,
3163 boundary,
3164 boundary_conditions: BSplineBoundaryConditions::default(),
3165 });
3166 emitted_periods.push(axis_period);
3167 }
3168 let canon_cols: Vec<usize> = {
3189 let mut perm: Vec<usize> = (0..dim).collect();
3190 perm.sort_by_key(|&a| cols[a]);
3191 if perm.iter().enumerate().any(|(i, &a)| i != a) {
3192 margins = perm.iter().map(|&a| margins[a].clone()).collect();
3193 emitted_periods = perm.iter().map(|&a| emitted_periods[a]).collect();
3194 }
3195 perm.iter().map(|&a| cols[a]).collect()
3196 };
3197 let any_periodic = emitted_periods.iter().any(|p| p.is_some());
3198 let periods_vec = if any_periodic {
3199 emitted_periods
3200 } else {
3201 Vec::new()
3202 };
3203 let tensor_double_penalty = option_bool(options, "double_penalty").unwrap_or(false);
3219 Ok(SmoothBasisSpec::TensorBSpline {
3220 feature_cols: canon_cols,
3221 spec: TensorBSplineSpec {
3222 marginalspecs: margins,
3223 periods: periods_vec,
3224 double_penalty: tensor_double_penalty,
3225 identifiability: parse_tensor_identifiability(options, kind)?,
3226 penalty_decomposition: if matches!(kind, SmoothKind::T2)
3236 || type_opt.as_str() == "t2"
3237 {
3238 TensorBSplinePenaltyDecomposition::Separable
3239 } else {
3240 TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
3241 },
3242 },
3243 })
3244 }
3245 "pca" => {
3246 validate_known_options(
3247 "pca",
3248 options,
3249 &[
3250 "type",
3251 "bs",
3252 "by",
3253 "k",
3254 "basis_dim",
3255 "basis-dim",
3256 "basisdim",
3257 "lazy_path",
3258 "path",
3259 "pca_basis_path",
3260 "chunk_size",
3261 "smooth_penalty",
3262 "centered",
3263 "double_penalty",
3264 "id",
3265 "__by_col",
3266 ],
3267 )?;
3268 let path = options
3269 .get("lazy_path")
3270 .or_else(|| options.get("pca_basis_path"))
3271 .or_else(|| options.get("path"))
3272 .map(|raw| PathBuf::from(strip_quotes(raw)));
3273 let Some(path) = path else {
3274 return Err(TermBuilderError::incompatible_config(
3275 "pca smooth requires lazy_path=... on the formula path",
3276 )
3277 .to_string());
3278 };
3279 let k = option_usize_any(options, &["k", "basis_dim", "basis-dim", "basisdim"])
3280 .unwrap_or(0);
3281 let chunk_size = option_usize(options, "chunk_size").unwrap_or(DEFAULT_PCA_CHUNK_SIZE);
3282 Ok(SmoothBasisSpec::Pca {
3283 feature_cols: cols.to_vec(),
3284 basis_matrix: Array2::<f64>::zeros((cols.len(), k)),
3285 centered: option_bool(options, "centered").unwrap_or(true),
3286 smooth_penalty: option_f64(options, "smooth_penalty").unwrap_or(1.0),
3287 center_mean: None,
3288 pca_basis_path: Some(path),
3289 chunk_size,
3290 })
3291 }
3292 other => Err(TermBuilderError::unsupported_feature(format!(
3293 "unsupported smooth type '{other}'"
3294 ))
3295 .to_string()),
3296 }
3297}
3298
3299pub fn enable_scale_dimensions(spec: &mut TermCollectionSpec) {
3301 for smooth in spec.smooth_terms.iter_mut() {
3302 promote_thin_plate_for_scale_dimensions(&mut smooth.basis);
3309 match &mut smooth.basis {
3310 SmoothBasisSpec::Matern {
3311 feature_cols,
3312 spec: matern,
3313 ..
3314 } => {
3315 if matern.aniso_log_scales.is_none() {
3316 let d = feature_cols.len();
3317 matern.aniso_log_scales = Some(vec![0.0; d]);
3318 }
3319 }
3320 SmoothBasisSpec::Duchon {
3321 feature_cols,
3322 spec: duchon,
3323 ..
3324 } => {
3325 if duchon.aniso_log_scales.is_none() {
3326 let d = feature_cols.len();
3327 duchon.aniso_log_scales = Some(vec![0.0; d]);
3328 }
3329 }
3330 _ => {}
3331 }
3332 }
3333}
3334
3335fn promote_thin_plate_for_scale_dimensions(basis: &mut SmoothBasisSpec) {
3370 let SmoothBasisSpec::ThinPlate {
3371 feature_cols,
3372 spec,
3373 input_scales,
3374 } = &*basis
3375 else {
3376 return;
3377 };
3378 let d = feature_cols.len();
3379 if d <= 1 {
3380 return;
3381 }
3382 let m = thin_plate_penalty_order(d);
3387 let nullspace_order = match m {
3388 0 | 1 => DuchonNullspaceOrder::Zero,
3389 2 => DuchonNullspaceOrder::Linear,
3390 _ => DuchonNullspaceOrder::Degree(m - 1),
3391 };
3392 let duchon_spec = DuchonBasisSpec {
3393 center_strategy: spec.center_strategy.clone(),
3394 periodic: spec.periodic.clone(),
3395 length_scale: None,
3400 power: 0.0,
3402 nullspace_order,
3403 identifiability: spec.identifiability.clone(),
3404 aniso_log_scales: Some(vec![0.0; d]),
3408 operator_penalties: DuchonOperatorPenaltySpec::default(),
3409 boundary: OneDimensionalBoundary::Open,
3410 radial_reparam: None,
3411 };
3412 let feature_cols = feature_cols.clone();
3413 let input_scales = input_scales.clone();
3414 *basis = SmoothBasisSpec::Duchon {
3417 feature_cols,
3418 spec: duchon_spec,
3419 input_scales,
3420 };
3421}
3422
3423pub fn spatial_center_strategy_for_dimension(num_centers: usize, d: usize) -> CenterStrategy {
3428 if d <= 3 {
3429 CenterStrategy::FarthestPoint { num_centers }
3436 } else {
3437 default_spatial_center_strategy(num_centers, d)
3438 }
3439}
3440
3441pub fn col_minmax(col: ArrayView1<'_, f64>) -> Result<(f64, f64), String> {
3442 let min = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
3443 let max = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
3444 if !min.is_finite() || !max.is_finite() {
3445 return Err(TermBuilderError::degenerate_data(
3446 "non-finite data encountered while inferring knot range",
3447 )
3448 .to_string());
3449 }
3450 if (max - min).abs() < 1e-12 {
3451 Ok((min, min + 1e-6))
3452 } else {
3453 Ok((min, max))
3454 }
3455}
3456
3457pub fn unique_count_column(col: ArrayView1<'_, f64>) -> usize {
3458 use std::collections::HashSet;
3459 let mut set = HashSet::<u64>::with_capacity(col.len());
3460 for &v in col {
3461 let norm = if v == 0.0 { 0.0 } else { v };
3462 set.insert(norm.to_bits());
3463 }
3464 set.len().max(1)
3465}
3466
3467pub(crate) const CR_MIN_KNOTS: usize = 3;
3473
3474fn capped_cr_marginal_knotspec(
3501 col: ArrayView1<'_, f64>,
3502 k_cr_requested: usize,
3503 label: &str,
3504 inference_notes: &mut Vec<String>,
3505) -> Result<Option<BSplineKnotSpec>, String> {
3506 let n_distinct = unique_count_column(col);
3507 let k_cr = k_cr_requested.min(n_distinct);
3508 if k_cr < CR_MIN_KNOTS {
3509 inference_notes.push(format!(
3510 "Smooth '{label}': cubic-regression ('cr'/'cs'/'sz') basis requested k={k_cr_requested}, \
3511 but the covariate has only {n_distinct} distinct value(s) — too few to support a cubic \
3512 regression spline (needs >= {CR_MIN_KNOTS} distinct values). Degraded to the linear \
3513 B-spline marginal the default basis builds on the same data."
3514 ));
3515 return Ok(None);
3516 }
3517 if k_cr < k_cr_requested {
3518 inference_notes.push(format!(
3519 "Smooth '{label}': cubic-regression ('cr'/'cs'/'sz') basis reduced from k={k_cr_requested} \
3520 to k={k_cr} to match the covariate's {n_distinct} distinct value(s) (mgcv-style \
3521 data-support cap; a cr basis cannot place more value-knots than the data has)."
3522 ));
3523 }
3524 let cr_knots = crate::basis::select_cr_knots(col, k_cr).map_err(|e| e.to_string())?;
3525 Ok(Some(BSplineKnotSpec::NaturalCubicRegression {
3526 knots: cr_knots,
3527 }))
3528}
3529
3530fn min_per_group_unique_count(
3537 feature_col: ArrayView1<'_, f64>,
3538 group_col: ArrayView1<'_, f64>,
3539) -> usize {
3540 use std::collections::{HashMap, HashSet};
3541 let mut per_group: HashMap<u64, HashSet<u64>> = HashMap::new();
3542 for (xi, gi) in feature_col.iter().zip(group_col.iter()) {
3543 let xnorm = if *xi == 0.0 { 0.0 } else { *xi };
3544 let gnorm = if *gi == 0.0 { 0.0 } else { *gi };
3545 per_group
3546 .entry(gnorm.to_bits())
3547 .or_default()
3548 .insert(xnorm.to_bits());
3549 }
3550 per_group
3551 .values()
3552 .map(|s| s.len())
3553 .min()
3554 .unwrap_or(1)
3555 .max(1)
3556}
3557
3558pub fn heuristic_knots_for_column(col: ArrayView1<'_, f64>) -> usize {
3563 let unique = unique_count_column(col);
3564 let ceiling = ((unique as f64).cbrt() as usize).max(20);
3565 (unique / 4).clamp(4, ceiling)
3566}
3567
3568fn heuristic_tensor_margin_knots(cols: &[usize], ds: &Dataset) -> Vec<usize> {
3588 let d = cols.len().max(1);
3589 let degree = DEFAULT_BSPLINE_DEGREE;
3590 let min_k = degree + 2; let n = ds.values.nrows();
3592
3593 let per_margin_cap: Vec<usize> = cols
3597 .iter()
3598 .map(|&c| heuristic_knots_for_column(ds.values.column(c)).max(min_k))
3599 .collect();
3600
3601 let mgcv_like_per_margin = match d {
3608 2 => 7usize,
3609 3 => 5usize,
3610 _ => 4usize,
3611 };
3612 let mgcv_like_total = (mgcv_like_per_margin as f64).powi(d as i32);
3613 let data_budget = (n as f64) * 0.8;
3614 let p_target = mgcv_like_total
3615 .max(min_k.pow(d as u32) as f64)
3616 .min(data_budget);
3617
3618 let geo_per_margin = p_target.powf(1.0 / d as f64).round() as usize;
3621 let unclamped: Vec<usize> = per_margin_cap
3622 .iter()
3623 .map(|&cap| geo_per_margin.clamp(min_k, cap))
3624 .collect();
3625
3626 let mut k_list = unclamped;
3631 loop {
3632 let product: f64 = k_list.iter().map(|&k| k as f64).product();
3633 if product >= p_target {
3634 break;
3635 }
3636 let Some(idx) = k_list
3639 .iter()
3640 .zip(per_margin_cap.iter())
3641 .enumerate()
3642 .filter(|&(_, (k, cap))| k < cap)
3643 .max_by_key(|&(_, (k, cap))| (cap - k, *cap))
3644 .map(|(i, _)| i)
3645 else {
3646 break;
3647 };
3648 k_list[idx] += 1;
3649 }
3650 k_list
3651}
3652
/// Returns the default number of spatial centers.
///
/// Forwards `n` and `d` unchanged to `default_num_centers`.
/// NOTE(review): presumably `n` is the row count and `d` the number of
/// coordinates; confirm against `default_num_centers`.
pub fn heuristic_centers(n: usize, d: usize) -> usize {
    default_num_centers(n, d)
}
3656
3657fn parse_endpoint_side(
3662 value: &str,
3663 context: &str,
3664) -> Result<BSplineEndpointBoundaryCondition, String> {
3665 match value.trim().to_ascii_lowercase().as_str() {
3666 "" | "none" | "open" | "unconstrained" | "free" => {
3667 Ok(BSplineEndpointBoundaryCondition::Free)
3668 }
3669 "clamped" | "clamp" | "zero_derivative" | "zero-derivative" => {
3670 Ok(BSplineEndpointBoundaryCondition::Clamped)
3671 }
3672 "anchored" | "anchor" | "zero" | "zero_value" | "zero-value" => {
3673 Ok(BSplineEndpointBoundaryCondition::Anchored { value: 0.0 })
3674 }
3675 other => Err(format!(
3676 "unsupported {context} boundary condition '{other}'; expected free, clamped, or anchored"
3677 )),
3678 }
3679}
3680
3681fn boundary_anchor_value(
3682 options: &BTreeMap<String, String>,
3683 side: &str,
3684 fallback: Option<f64>,
3685) -> Option<f64> {
3686 [
3687 format!("anchor_{side}"),
3688 format!("{side}_anchor"),
3689 format!("anchor-value-{side}"),
3690 ]
3691 .iter()
3692 .find_map(|key| option_f64(options, key))
3693 .or(fallback)
3694}
3695
3696fn apply_anchor_value(
3697 cond: BSplineEndpointBoundaryCondition,
3698 value: Option<f64>,
3699) -> BSplineEndpointBoundaryCondition {
3700 match cond {
3701 BSplineEndpointBoundaryCondition::Anchored { .. } => {
3702 BSplineEndpointBoundaryCondition::Anchored {
3703 value: value.unwrap_or(0.0),
3704 }
3705 }
3706 other => other,
3707 }
3708}
3709
3710fn parse_bspline_boundary_conditions(
3711 options: &BTreeMap<String, String>,
3712) -> Result<BSplineBoundaryConditions, String> {
3713 let fallback_anchor = option_f64(options, "anchor")
3714 .or_else(|| option_f64(options, "anchor_value"))
3715 .or_else(|| option_f64(options, "value"));
3716 let global_boundary_conditions = options
3717 .get("boundary_conditions")
3718 .or_else(|| options.get("bc"));
3719 let mut boundary_conditions = BSplineBoundaryConditions::default();
3720
3721 if let Some(raw_boundary_conditions) = global_boundary_conditions {
3722 let cond = parse_endpoint_side(raw_boundary_conditions, "boundary_conditions")?;
3723 let side = options
3724 .get("side")
3725 .map(|s| s.trim().to_ascii_lowercase())
3726 .unwrap_or_else(|| "both".to_string());
3727 match side.as_str() {
3728 "both" | "all" | "endpoints" => {
3729 boundary_conditions.left = cond;
3730 boundary_conditions.right = cond;
3731 }
3732 "left" | "start" | "lower" => boundary_conditions.left = cond,
3733 "right" | "end" | "upper" => boundary_conditions.right = cond,
3734 other => {
3735 return Err(format!(
3736 "unsupported B-spline boundary side '{other}'; expected left, right, or both"
3737 ));
3738 }
3739 }
3740 }
3741
3742 if let Some(raw) = options
3743 .get("bc_left")
3744 .or_else(|| options.get("left_bc"))
3745 .or_else(|| options.get("bc_start"))
3746 .or_else(|| options.get("start_bc"))
3747 {
3748 boundary_conditions.left = parse_endpoint_side(raw, "left endpoint")?;
3749 }
3750 if let Some(raw) = options
3751 .get("bc_right")
3752 .or_else(|| options.get("right_bc"))
3753 .or_else(|| options.get("bc_end"))
3754 .or_else(|| options.get("end_bc"))
3755 {
3756 boundary_conditions.right = parse_endpoint_side(raw, "right endpoint")?;
3757 }
3758
3759 boundary_conditions.left = apply_anchor_value(
3760 boundary_conditions.left,
3761 boundary_anchor_value(options, "left", fallback_anchor),
3762 );
3763 boundary_conditions.right = apply_anchor_value(
3764 boundary_conditions.right,
3765 boundary_anchor_value(options, "right", fallback_anchor),
3766 );
3767
3768 reject_nonzero_anchor("left", boundary_conditions.left)?;
3776 reject_nonzero_anchor("right", boundary_conditions.right)?;
3777
3778 Ok(boundary_conditions)
3779}
3780
3781fn reject_nonzero_anchor(side: &str, cond: BSplineEndpointBoundaryCondition) -> Result<(), String> {
3782 if let BSplineEndpointBoundaryCondition::Anchored { value } = cond {
3783 if value.abs() > 1e-12 {
3784 return Err(format!(
3785 "non-zero {side} anchor {value} requires an affine offset term that is not yet supported; only anchored value 0 is accepted at parse time"
3786 ));
3787 }
3788 }
3789 Ok(())
3790}
3791
3792fn parse_ps_internal_knots(
3806 options: &BTreeMap<String, String>,
3807 degree: usize,
3808 default_internal_knots: usize,
3809) -> Result<(usize, bool, usize), String> {
3810 const MIN_EXPRESSIVE_INTERNAL_KNOTS: usize = 2;
3811 let knots_internal = if knots_option_is_list(options) {
3821 None
3822 } else {
3823 option_usize_strict(options, "knots")?
3824 };
3825 let basis_dim = option_usize_any_strict(options, &["k", "basis_dim", "basis-dim", "basisdim"])?;
3826 if knots_internal.is_some() && basis_dim.is_some() {
3827 return Err(TermBuilderError::incompatible_config(
3828 "ps/bspline smooth: specify either knots=<internal_knots> or k=<basis_dim> (not both)",
3829 )
3830 .to_string());
3831 }
3832 if let Some(k) = basis_dim {
3833 if k < 2 {
3834 return Err(TermBuilderError::invalid_option(format!(
3835 "ps/bspline smooth: k={} too small; B-spline basis requires k >= 2",
3836 k
3837 ))
3838 .to_string());
3839 }
3840 let effective_degree = degree.min(k - 1).max(1);
3846 let num_internal_knots = if effective_degree < degree {
3847 k.saturating_sub(effective_degree + 1)
3850 } else {
3851 (k - degree - 1).max(MIN_EXPRESSIVE_INTERNAL_KNOTS)
3852 };
3853 Ok((num_internal_knots, false, effective_degree))
3854 } else {
3855 Ok((
3856 knots_internal.unwrap_or(default_internal_knots),
3857 knots_internal.is_none(),
3858 degree,
3859 ))
3860 }
3861}
3862
/// Reports whether the `knots` option is written as a list, not a scalar.
///
/// After trimming whitespace, the value counts as a list when it starts with
/// `[`, `c(`, `C(` or `(`. A missing `knots` key yields `false`.
fn knots_option_is_list(options: &BTreeMap<String, String>) -> bool {
    const LIST_OPENERS: [&str; 4] = ["[", "c(", "C(", "("];
    match options.get("knots") {
        Some(raw) => {
            let trimmed = raw.trim();
            LIST_OPENERS
                .iter()
                .any(|opener| trimmed.starts_with(*opener))
        }
        None => false,
    }
}
3877
3878fn parse_explicit_internal_knots(
3883 options: &BTreeMap<String, String>,
3884) -> Result<Option<Vec<f64>>, String> {
3885 if !knots_option_is_list(options) {
3886 return Ok(None);
3887 }
3888 let raw = options
3889 .get("knots")
3890 .expect("knots_option_is_list implies the key is present");
3891 let tokens = split_list_option(raw);
3892 if tokens.is_empty() {
3893 return Err(TermBuilderError::invalid_option(format!(
3894 "knots={raw} is an empty list; supply at least one internal knot position \
3895 (e.g. knots=[0.2, 0.5, 0.8]) or a scalar count (e.g. knots=8)"
3896 ))
3897 .to_string());
3898 }
3899 let mut positions = Vec::with_capacity(tokens.len());
3900 for tok in &tokens {
3901 let value = parse_numeric_expr(tok).map_err(|err| {
3902 TermBuilderError::invalid_option(format!(
3903 "knots list entry '{tok}' is not a numeric position: {err}"
3904 ))
3905 .to_string()
3906 })?;
3907 positions.push(value);
3908 }
3909 Ok(Some(positions))
3910}
3911
3912fn parse_knot_placement(
3918 options: &BTreeMap<String, String>,
3919) -> Result<crate::basis::BSplineKnotPlacement, String> {
3920 use crate::basis::BSplineKnotPlacement;
3921 match options
3922 .get("knot_placement")
3923 .or_else(|| options.get("knot-placement"))
3924 .or_else(|| options.get("knotplacement"))
3925 {
3926 None => Ok(BSplineKnotPlacement::Uniform),
3927 Some(raw) => match raw
3928 .trim()
3929 .trim_matches('"')
3930 .trim_matches('\'')
3931 .to_ascii_lowercase()
3932 .as_str()
3933 {
3934 "uniform" | "even" | "equal" => Ok(BSplineKnotPlacement::Uniform),
3935 "quantile" | "quantiles" | "data" | "empirical" => Ok(BSplineKnotPlacement::Quantile),
3936 other => Err(TermBuilderError::invalid_option(format!(
3937 "knot_placement={other} is not recognised; expected \"uniform\" or \"quantile\""
3938 ))
3939 .to_string()),
3940 },
3941 }
3942}
3943
3944fn resolve_nonperiodic_bspline_knotspec(
3955 options: &BTreeMap<String, String>,
3956 data: ArrayView1<'_, f64>,
3957 data_range: (f64, f64),
3958 degree: usize,
3959 n_knots: usize,
3960) -> Result<BSplineKnotSpec, String> {
3961 use crate::basis::{BSplineKnotPlacement, clamped_knot_vector_from_internal_positions};
3962 if let Some(positions) = parse_explicit_internal_knots(options)? {
3963 if option_usize_any_strict(options, &["k", "basis_dim", "basis-dim", "basisdim"])?.is_some()
3964 {
3965 return Err(TermBuilderError::incompatible_config(
3966 "ps/bspline smooth: specify either explicit knots=[...] positions or \
3967 k=<basis_dim> (not both); the basis size is fixed by the knot vector",
3968 )
3969 .to_string());
3970 }
3971 let knots = clamped_knot_vector_from_internal_positions(data_range, &positions, degree)
3972 .map_err(|e| e.to_string())?;
3973 return Ok(BSplineKnotSpec::Provided(knots));
3974 }
3975 match parse_knot_placement(options)? {
3976 BSplineKnotPlacement::Uniform => Ok(BSplineKnotSpec::Generate {
3977 data_range,
3978 num_internal_knots: n_knots,
3979 }),
3980 BSplineKnotPlacement::Quantile => {
3981 crate::basis::auto_knot_vector_1d_quantile(data, n_knots, degree)
3985 .map_err(|e| e.to_string())?;
3986 Ok(BSplineKnotSpec::Automatic {
3987 num_internal_knots: Some(n_knots),
3988 placement: BSplineKnotPlacement::Quantile,
3989 })
3990 }
3991 }
3992}
3993
3994pub fn validate_known_options(
4000 term_name: &str,
4001 options: &BTreeMap<String, String>,
4002 known: &[&str],
4003) -> Result<(), String> {
4004 let known_set: std::collections::BTreeSet<&&str> = known.iter().collect();
4005 for key in options.keys() {
4006 if !known_set.contains(&key.as_str()) {
4007 if term_name == "tensor" && is_tensor_k_axis_option_key(key) {
4008 continue;
4009 }
4010 let key_l = key.to_ascii_lowercase();
4012 let mut suggestions: Vec<&str> = known
4013 .iter()
4014 .filter(|k| {
4015 let kl = k.to_ascii_lowercase();
4016 kl.contains(&key_l) || key_l.contains(&kl) || {
4017 let n = kl
4018 .chars()
4019 .zip(key_l.chars())
4020 .take_while(|(a, b)| a == b)
4021 .count();
4022 n >= 3
4023 }
4024 })
4025 .copied()
4026 .collect();
4027 suggestions.sort_unstable();
4028 suggestions.dedup();
4029 let hint = if suggestions.is_empty() {
4030 String::new()
4031 } else {
4032 format!(" — did you mean one of [{}]?", suggestions.join(", "))
4033 };
4034 return Err(TermBuilderError::invalid_option(format!(
4035 "{term_name}() does not accept option `{key}`{hint}. Valid options: [{}]",
4036 {
4037 let mut sorted = known.to_vec();
4038 sorted.sort_unstable();
4039 sorted.join(", ")
4040 }
4041 ))
4042 .to_string());
4043 }
4044 }
4045 Ok(())
4046}
4047
4048pub const SECONDARY_CENTER_CAP_OPTION: &str = "__secondary_center_cap";
4058
4059pub(crate) fn cap_default_spatial_centers(
4064 options: &BTreeMap<String, String>,
4065 default_count: usize,
4066) -> usize {
4067 match option_usize(options, SECONDARY_CENTER_CAP_OPTION) {
4068 Some(cap) => default_count.min(cap),
4069 None => default_count,
4070 }
4071}
4072
/// Default Matern center count with a small-sample floor.
///
/// The planned count is raised to at least `min(d + 4, n)` and never drops
/// below 1.
fn default_matern_center_count(n: usize, d: usize, planned_count: usize) -> usize {
    let small_sample_floor = n.min(d + 4);
    planned_count.max(1).max(small_sample_floor)
}
4083
4084pub fn parse_countwith_basis_alias(
4085 options: &BTreeMap<String, String>,
4086 primarykey: &str,
4087 default_count: usize,
4088) -> Result<usize, String> {
4089 let primary = option_usize_strict(options, primarykey)?;
4094 let basis_dim = option_usize_any_strict(
4095 options,
4096 &["k", "basis_dim", "basis-dim", "basisdim", "knots"],
4097 )?;
4098 if primary.is_some() && basis_dim.is_some() {
4099 return Err(TermBuilderError::incompatible_config(format!(
4100 "specify either {}=<count> or k=<basis_dim> (not both)",
4101 primarykey
4102 ))
4103 .to_string());
4104 }
4105 Ok(primary.or(basis_dim).unwrap_or(default_count))
4106}
4107
/// Reports whether a count was supplied, under the primary key or an alias.
///
/// Only key presence is checked, not whether the value parses. The aliases
/// are `k`, `basis_dim`, `basis-dim`, `basisdim` and `knots`.
pub fn has_explicit_countwith_basis_alias(
    options: &BTreeMap<String, String>,
    primarykey: &str,
) -> bool {
    const BASIS_DIM_ALIASES: [&str; 5] = ["k", "basis_dim", "basis-dim", "basisdim", "knots"];
    if options.contains_key(primarykey) {
        return true;
    }
    BASIS_DIM_ALIASES
        .iter()
        .any(|alias| options.contains_key(*alias))
}
4117
4118pub fn parse_cyclic_boundary(
4119 options: &BTreeMap<String, String>,
4120 minv: f64,
4121 maxv: f64,
4122) -> Result<OneDimensionalBoundary, String> {
4123 let cyclic = option_bool(options, "cyclic")
4124 .or_else(|| option_bool(options, "periodic"))
4125 .unwrap_or(false);
4126 if !cyclic {
4127 return Ok(OneDimensionalBoundary::Open);
4128 }
4129 let start = match option_numeric_expr(options, "period_start")? {
4130 Some(v) => v,
4131 None => option_numeric_expr(options, "start")?.unwrap_or(minv),
4132 };
4133 let end = match option_numeric_expr(options, "period_end")? {
4134 Some(v) => v,
4135 None => option_numeric_expr(options, "end")?.unwrap_or(maxv),
4136 };
4137 if end <= start {
4138 return Err(format!(
4139 "cyclic smooth requires period_end/end ({end}) > period_start/start ({start})"
4140 ));
4141 }
4142 Ok(OneDimensionalBoundary::Cyclic { start, end })
4143}
4144
4145pub fn parse_periodic_domain_1d(
4152 options: &BTreeMap<String, String>,
4153 minv: f64,
4154 maxv: f64,
4155) -> Result<(f64, f64), String> {
4156 let start = match option_numeric_expr(options, "period_start")? {
4157 Some(v) => v,
4158 None => option_numeric_expr(options, "start")?.unwrap_or(minv),
4159 };
4160 let end = match option_numeric_expr(options, "period_end")? {
4161 Some(v) => v,
4162 None => option_numeric_expr(options, "end")?.unwrap_or(maxv),
4163 };
4164 if !(start.is_finite() && end.is_finite()) {
4165 return Err(format!(
4166 "periodic smooth domain requires finite endpoints, got ({start}, {end})"
4167 ));
4168 }
4169 if end <= start {
4170 return Err(format!(
4171 "periodic smooth requires period_end/end ({end}) > period_start/start ({start})"
4172 ));
4173 }
4174 Ok((start, end - start))
4175}
4176
4177fn parse_matern_nu(raw: &str) -> Result<MaternNu, String> {
4178 let trimmed = raw.trim();
4179 let lowered = trimmed.to_ascii_lowercase();
4180 match lowered.as_str() {
4181 "1/2" | "0.5" | "half" => return Ok(MaternNu::Half),
4182 "3/2" | "1.5" => return Ok(MaternNu::ThreeHalves),
4183 "5/2" | "2.5" => return Ok(MaternNu::FiveHalves),
4184 "7/2" | "3.5" => return Ok(MaternNu::SevenHalves),
4185 "9/2" | "4.5" => return Ok(MaternNu::NineHalves),
4186 _ => {}
4187 }
4188
4189 let value = if let Some((num, den)) = trimmed.split_once('/') {
4190 let num = num
4191 .trim()
4192 .parse::<f64>()
4193 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?;
4194 let den = den
4195 .trim()
4196 .parse::<f64>()
4197 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?;
4198 if den == 0.0 || !num.is_finite() || !den.is_finite() {
4199 return Err(unsupported_matern_nu_message(raw));
4200 }
4201 num / den
4202 } else {
4203 trimmed
4204 .parse::<f64>()
4205 .map_err(|err| format!("{}: {err}", unsupported_matern_nu_message(raw)))?
4206 };
4207
4208 const TOL: f64 = 1e-12;
4209 if (value - 0.5).abs() <= TOL {
4210 Ok(MaternNu::Half)
4211 } else if (value - 1.5).abs() <= TOL {
4212 Ok(MaternNu::ThreeHalves)
4213 } else if (value - 2.5).abs() <= TOL {
4214 Ok(MaternNu::FiveHalves)
4215 } else if (value - 3.5).abs() <= TOL {
4216 Ok(MaternNu::SevenHalves)
4217 } else if (value - 4.5).abs() <= TOL {
4218 Ok(MaternNu::NineHalves)
4219 } else {
4220 Err(unsupported_matern_nu_message(raw))
4221 }
4222}
4223
4224fn unsupported_matern_nu_message(raw: &str) -> String {
4225 TermBuilderError::unsupported_feature(format!(
4226 "unsupported Matern nu '{raw}'; supported half-integer values are 1/2, 3/2, 5/2, 7/2, and 9/2"
4227 ))
4228 .to_string()
4229}
4230
/// Outcome of resolving the Duchon `power` smooth option.
///
/// Produced by `parse_duchon_power_policy`; `parse_duchon_power` collapses it
/// to a plain `f64`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum DuchonPowerPolicy {
    /// The caller supplied `power=<number>`; carries the parsed finite,
    /// non-negative value.
    Explicit(f64),
    /// No `power` option was given. `parse_duchon_power` maps this to `1.5`.
    CubicStructuralDefault,
}
4239
4240pub fn parse_duchon_power_policy(
4241 options: &BTreeMap<String, String>,
4242) -> Result<DuchonPowerPolicy, String> {
4243 if let Some(raw_nu) = options.get("nu") {
4244 return Err(TermBuilderError::incompatible_config(format!(
4245 "Duchon smooths use power=<number>, not nu='{}'. Use power=1.5, power=2, etc.",
4246 raw_nu
4247 ))
4248 .to_string());
4249 }
4250 match options.get("power") {
4251 Some(raw) => {
4252 let value = raw.parse::<f64>().map_err(|err| {
4253 TermBuilderError::invalid_option(format!(
4254 "invalid Duchon power '{}'; expected a non-negative number such as power=1.5 or power=2: {}",
4255 raw, err
4256 ))
4257 .to_string()
4258 })?;
4259 if !value.is_finite() || value < 0.0 {
4260 return Err(TermBuilderError::invalid_option(format!(
4261 "invalid Duchon power '{}'; expected a finite non-negative number such as power=1.5 or power=2",
4262 raw
4263 ))
4264 .to_string());
4265 }
4266 Ok(DuchonPowerPolicy::Explicit(value))
4267 }
4268 None => Ok(DuchonPowerPolicy::CubicStructuralDefault),
4269 }
4270}
4271
4272pub fn parse_duchon_power(options: &BTreeMap<String, String>) -> Result<f64, String> {
4273 match parse_duchon_power_policy(options)? {
4274 DuchonPowerPolicy::Explicit(power) => Ok(power),
4275 DuchonPowerPolicy::CubicStructuralDefault => Ok(1.5),
4281 }
4282}
4283
4284pub fn parse_duchon_order(
4285 options: &BTreeMap<String, String>,
4286) -> Result<DuchonNullspaceOrder, String> {
4287 match options.get("order") {
4288 None => Ok(DuchonNullspaceOrder::Linear),
4292 Some(raw) => match raw.parse::<usize>() {
4293 Ok(0) => Ok(DuchonNullspaceOrder::Zero),
4294 Ok(1) => Ok(DuchonNullspaceOrder::Linear),
4295 Ok(other) => Ok(DuchonNullspaceOrder::Degree(other)),
4296 Err(_) => Err(TermBuilderError::invalid_option(format!(
4297 "invalid Duchon order '{}'; expected a non-negative integer such as order=0, order=1, or order=2",
4298 raw
4299 ))
4300 .to_string()),
4301 },
4302 }
4303}
4304
4305fn parse_matern_identifiability(
4306 options: &BTreeMap<String, String>,
4307) -> Result<MaternIdentifiability, TermBuilderError> {
4308 let Some(raw) = options.get("identifiability").map(String::as_str) else {
4309 return Ok(MaternIdentifiability::default());
4310 };
4311 match raw.trim().to_ascii_lowercase().as_str() {
4312 "none" => Ok(MaternIdentifiability::None),
4313 "sum_tozero" | "sum-to-zero" | "center_sum_tozero" | "center-sum-to-zero" | "centered" => {
4314 Ok(MaternIdentifiability::CenterSumToZero)
4315 }
4316 "linear" | "center_linear_orthogonal" | "center-linear-orthogonal" => {
4317 Ok(MaternIdentifiability::CenterLinearOrthogonal)
4318 }
4319 other => Err(TermBuilderError::unsupported_feature(format!(
4320 "invalid Matérn identifiability '{other}'; expected one of: none, sum_tozero, linear"
4321 ))),
4322 }
4323}
4324
4325fn parse_spatial_identifiability(
4326 options: &BTreeMap<String, String>,
4327) -> Result<SpatialIdentifiability, TermBuilderError> {
4328 let Some(raw) = options.get("identifiability").map(String::as_str) else {
4329 return Ok(SpatialIdentifiability::default());
4330 };
4331 match raw.trim().to_ascii_lowercase().as_str() {
4332 "none" => Ok(SpatialIdentifiability::None),
4333 "orthogonal"
4334 | "orthogonal_to_parametric"
4335 | "orthogonal-to-parametric"
4336 | "parametric_orthogonal" => Ok(SpatialIdentifiability::OrthogonalToParametric),
4337 "frozen" => Err(TermBuilderError::unsupported_feature(
4338 "spatial identifiability 'frozen' is internal-only; use none or orthogonal_to_parametric",
4339 )),
4340 other => Err(TermBuilderError::unsupported_feature(format!(
4341 "invalid spatial identifiability '{other}'; expected one of: none, orthogonal_to_parametric"
4342 ))),
4343 }
4344}
4345
4346#[cfg(test)]
4347mod tests {
4348 use super::*;
4349 use crate::inference::formula_dsl::parse_formula;
4350 use gam_data::{DataSchema, SchemaColumn};
4351 use ndarray::Array2;
4352 use std::collections::BTreeMap;
4353
4354 fn continuous_dataset(headers: &[&str], rows: Vec<Vec<f64>>) -> Dataset {
4355 let nrows = rows.len();
4356 let ncols = headers.len();
4357 let values = Array2::from_shape_vec(
4358 (nrows, ncols),
4359 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
4360 )
4361 .expect("rectangular test data");
4362 Dataset {
4363 headers: headers.iter().map(|name| name.to_string()).collect(),
4364 values,
4365 schema: DataSchema {
4366 columns: headers
4367 .iter()
4368 .map(|name| SchemaColumn {
4369 name: name.to_string(),
4370 kind: ColumnKindTag::Continuous,
4371 levels: vec![],
4372 })
4373 .collect(),
4374 },
4375 column_kinds: vec![ColumnKindTag::Continuous; ncols],
4376 }
4377 }
4378
4379 fn factor_dataset() -> Dataset {
4380 let rows = (0..24)
4381 .map(|i| {
4382 let x = i as f64 / 23.0;
4383 let g = (i % 2) as f64;
4384 vec![x + g, x, g]
4385 })
4386 .collect::<Vec<_>>();
4387 Dataset {
4388 headers: vec!["y".into(), "x".into(), "g".into()],
4389 values: Array2::from_shape_vec(
4390 (rows.len(), 3),
4391 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
4392 )
4393 .expect("rectangular factor test data"),
4394 schema: DataSchema {
4395 columns: vec![
4396 SchemaColumn {
4397 name: "y".into(),
4398 kind: ColumnKindTag::Continuous,
4399 levels: vec![],
4400 },
4401 SchemaColumn {
4402 name: "x".into(),
4403 kind: ColumnKindTag::Continuous,
4404 levels: vec![],
4405 },
4406 SchemaColumn {
4407 name: "g".into(),
4408 kind: ColumnKindTag::Categorical,
4409 levels: vec!["a".into(), "b".into()],
4410 },
4411 ],
4412 },
4413 column_kinds: vec![
4414 ColumnKindTag::Continuous,
4415 ColumnKindTag::Continuous,
4416 ColumnKindTag::Categorical,
4417 ],
4418 }
4419 }
4420
4421 #[test]
4429 fn default_univariate_thinplate_basis_dim_is_modest() {
4430 let n = 300usize;
4433 let rows: Vec<Vec<f64>> = (0..n)
4434 .map(|i| {
4435 let x = -3.0 + 6.0 * (i as f64) / ((n - 1) as f64);
4436 vec![x.sin(), x]
4437 })
4438 .collect();
4439 let ds = continuous_dataset(&["y", "x"], rows);
4440
4441 let mut options = BTreeMap::new();
4442 options.insert("bs".to_string(), "tp".to_string());
4443
4444 let mut notes = Vec::new();
4445 let basis = build_smooth_basis(
4446 SmoothKind::S,
4447 &["x".to_string()],
4448 &[1],
4449 &options,
4450 &ds,
4451 &mut notes,
4452 &ResourcePolicy::default_library(),
4453 1,
4454 )
4455 .expect("build default univariate tp smooth");
4456
4457 let centers = match &basis {
4458 SmoothBasisSpec::ThinPlate { spec, .. } => match &spec.center_strategy {
4459 CenterStrategy::Auto(inner) => match inner.as_ref() {
4460 CenterStrategy::FarthestPoint { num_centers }
4461 | CenterStrategy::EqualMass { num_centers }
4462 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4463 | CenterStrategy::KMeans { num_centers, .. } => *num_centers,
4464 other => panic!("unexpected auto inner center strategy: {other:?}"),
4465 },
4466 CenterStrategy::FarthestPoint { num_centers }
4467 | CenterStrategy::EqualMass { num_centers }
4468 | CenterStrategy::EqualMassCovarRepresentative { num_centers }
4469 | CenterStrategy::KMeans { num_centers, .. } => *num_centers,
4470 other => panic!("unexpected center strategy: {other:?}"),
4471 },
4472 other => panic!("expected ThinPlate basis, got {other:?}"),
4473 };
4474
4475 assert!(
4479 centers >= 1,
4480 "default univariate tp must still build a usable basis (centers={centers})",
4481 );
4482 }
4483
4484 #[test]
4495 fn default_matern_2d_seeds_resolving_length_scale_not_overscaled_diameter() {
4496 let side = 24usize; let mut rows: Vec<Vec<f64>> = Vec::with_capacity(side * side);
4501 for i in 0..side {
4502 for j in 0..side {
4503 let x1 = i as f64 / (side - 1) as f64; let x2 = j as f64 / (side - 1) as f64; let y = (6.0 * x1).sin() * (6.0 * x2).cos();
4506 rows.push(vec![y, x1, x2]);
4507 }
4508 }
4509 let n = rows.len();
4510 let ds = continuous_dataset(&["y", "x1", "x2"], rows);
4511
4512 let mut options = BTreeMap::new();
4513 options.insert("bs".to_string(), "gp".to_string()); let mut notes = Vec::new();
4515 let mut basis = build_smooth_basis(
4516 SmoothKind::S,
4517 &["x1".to_string(), "x2".to_string()],
4518 &[1, 2],
4519 &options,
4520 &ds,
4521 &mut notes,
4522 &ResourcePolicy::default_library(),
4523 1,
4524 )
4525 .expect("build default 2-D matern smooth");
4526
4527 let (feature_cols, seeded_length_scale) = match &basis {
4529 SmoothBasisSpec::Matern {
4530 feature_cols, spec, ..
4531 } => (feature_cols.clone(), spec.length_scale),
4532 other => panic!("expected Matern basis, got {other:?}"),
4533 };
4534 assert_eq!(
4535 seeded_length_scale, 0.0,
4536 "default matern() must leave length_scale at the 0.0 auto sentinel \
4537 (got {seeded_length_scale}); a non-zero diameter default re-enters the \
4538 over-smoothed basin and disables the planner's wiggly-side auto-init",
4539 );
4540
4541 crate::smooth::auto_init_length_scale_in_basis(ds.values.view(), &mut basis);
4545 let realized = match &basis {
4546 SmoothBasisSpec::Matern { spec, .. } => spec.length_scale,
4547 other => panic!("expected Matern basis after auto-init, got {other:?}"),
4548 };
4549 let expected =
4550 crate::smooth::auto_initial_length_scale(ds.values.view(), &feature_cols);
4551 assert!(
4552 (realized - expected).abs() <= 1e-12,
4553 "auto-init must seed the wiggly-side length scale max_range/sqrt(n) \
4554 (expected {expected}, got {realized})",
4555 );
4556
4557 let max_range = 1.0_f64; assert!(
4562 realized < max_range / 4.0,
4563 "matern seed length_scale {realized} must be in the resolving regime, \
4564 not the over-smoothed diameter corner (n={n}, max_range≈{max_range})",
4565 );
4566 }
4567
4568 fn inferred_tensor_basis_product(ds: &Dataset) -> usize {
4569 let parsed = parse_formula("y ~ te(theta, h)").expect("parse tensor formula");
4570 let col_map = ds.column_map();
4571 let mut notes = Vec::new();
4572 let terms = build_termspec(
4573 &parsed.terms,
4574 ds,
4575 &col_map,
4576 &mut notes,
4577 &ResourcePolicy::default_library(),
4578 )
4579 .expect("build tensor termspec");
4580 let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
4581 panic!("expected tensor smooth");
4582 };
4583 spec.marginalspecs
4584 .iter()
4585 .map(|marginal| match marginal.knotspec {
4586 BSplineKnotSpec::Generate {
4587 num_internal_knots, ..
4588 } => num_internal_knots + marginal.degree + 1,
4589 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
4590 BSplineKnotSpec::Automatic {
4591 num_internal_knots: Some(num_internal_knots),
4592 ..
4593 } => num_internal_knots + marginal.degree + 1,
4594 BSplineKnotSpec::Automatic {
4595 num_internal_knots: None,
4596 ..
4597 } => panic!("test helper cannot infer automatic knot count"),
4598 BSplineKnotSpec::Provided(ref knots) => {
4599 knots.len().saturating_sub(marginal.degree + 1)
4600 }
4601 BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
4603 })
4604 .product()
4605 }
4606
4607 fn tensor_margin_basis_sizes(ds: &Dataset, formula: &str) -> Vec<usize> {
4608 let parsed = parse_formula(formula).expect("parse tensor formula");
4609 let col_map = ds.column_map();
4610 let mut notes = Vec::new();
4611 let terms = build_termspec(
4612 &parsed.terms,
4613 ds,
4614 &col_map,
4615 &mut notes,
4616 &ResourcePolicy::default_library(),
4617 )
4618 .expect("build tensor termspec");
4619 let SmoothBasisSpec::TensorBSpline { spec, .. } = &terms.smooth_terms[0].basis else {
4620 panic!("expected tensor smooth");
4621 };
4622 spec.marginalspecs
4623 .iter()
4624 .map(|marginal| match marginal.knotspec {
4625 BSplineKnotSpec::Generate {
4626 num_internal_knots, ..
4627 } => num_internal_knots + marginal.degree + 1,
4628 BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
4629 BSplineKnotSpec::Automatic {
4630 num_internal_knots: Some(num_internal_knots),
4631 ..
4632 } => num_internal_knots + marginal.degree + 1,
4633 BSplineKnotSpec::Automatic {
4634 num_internal_knots: None,
4635 ..
4636 } => panic!("test helper cannot infer automatic knot count"),
4637 BSplineKnotSpec::Provided(ref knots) => {
4638 knots.len().saturating_sub(marginal.degree + 1)
4639 }
4640 BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
4642 })
4643 .collect()
4644 }
4645
4646 #[test]
4647 fn validate_known_options_lists_valid_option_names_for_unknown_parameter() {
4648 let mut options = BTreeMap::new();
4649 options.insert("lengt_scale".to_string(), "0.25".to_string());
4650 let err = validate_known_options(
4651 "matern",
4652 &options,
4653 &["type", "bs", "length_scale", "centers", "k", "nu"],
4654 )
4655 .expect_err("unknown smooth option should be rejected");
4656 assert!(
4657 err.contains("matern() does not accept option `lengt_scale`"),
4658 "error should name the invalid option, got: {err}"
4659 );
4660 assert!(
4661 err.contains("did you mean one of [length_scale]"),
4662 "error should suggest the closest valid option, got: {err}"
4663 );
4664 assert!(
4665 err.contains("Valid options: ["),
4666 "error should list valid option names, got: {err}"
4667 );
4668 }
4669
4670 #[test]
4671 fn tensor_k_accepts_square_bracket_per_margin_list() {
4672 let ds = continuous_dataset(
4673 &["y", "x", "z"],
4674 (0..40)
4675 .map(|i| {
4676 let x = i as f64 / 39.0;
4677 let z = ((i * 7) % 40) as f64 / 39.0;
4678 vec![x.sin() + z.cos(), x, z]
4679 })
4680 .collect(),
4681 );
4682
4683 assert_eq!(
4684 tensor_margin_basis_sizes(&ds, "y ~ te(x, z, k=[5, 6])"),
4685 vec![5, 6],
4686 "square-bracket k lists should materialize the requested per-margin values"
4687 );
4688 }
4689
4690 #[test]
4691 fn parse_cylinder_periodic_options_match_requested_forms() {
4692 let mut opts = BTreeMap::new();
4693 opts.insert("periodic".to_string(), "[0]".to_string());
4694 opts.insert("period".to_string(), "[2*pi, None]".to_string());
4695 let axes = parse_periodic_axes(&opts, 2).expect("axes");
4696 let periods = parse_periods(&opts, &axes).expect("periods");
4697 assert_eq!(axes, vec![true, false]);
4698 assert!((periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4699 assert_eq!(periods[1], None);
4700
4701 let mut boundary_opts = BTreeMap::new();
4702 boundary_opts.insert(
4703 "boundary".to_string(),
4704 "['periodic', 'natural']".to_string(),
4705 );
4706 boundary_opts.insert("period".to_string(), "[2*pi, None]".to_string());
4707 let boundary_axes = parse_periodic_axes(&boundary_opts, 2).expect("boundary axes");
4708 let boundary_periods =
4709 parse_periods(&boundary_opts, &boundary_axes).expect("boundary periods");
4710 assert_eq!(boundary_axes, vec![true, false]);
4711 assert!((boundary_periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4712 assert_eq!(boundary_periods[1], None);
4713
4714 let mut unicode_opts = BTreeMap::new();
4715 unicode_opts.insert("periodic".to_string(), "[0,1]".to_string());
4716 unicode_opts.insert("period".to_string(), "[2π, τ]".to_string());
4717 let unicode_axes = parse_periodic_axes(&unicode_opts, 2).expect("unicode axes");
4718 let unicode_periods = parse_periods(&unicode_opts, &unicode_axes).expect("unicode periods");
4719 assert_eq!(unicode_axes, vec![true, true]);
4720 assert!((unicode_periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4721 assert!((unicode_periods[1].unwrap() - std::f64::consts::TAU).abs() < 1e-12);
4722 }
4723
4724 #[test]
4725 fn parse_single_axis_periodic_zero_as_axis_not_false() {
4726 let mut opts = BTreeMap::new();
4727 opts.insert("periodic".to_string(), "[0]".to_string());
4728 opts.insert("period".to_string(), "2*pi".to_string());
4729 opts.insert("origin".to_string(), "0".to_string());
4730 let axes = parse_periodic_axes(&opts, 1).expect("axes");
4731 let periods = parse_periods(&opts, &axes).expect("periods");
4732 let origins = parse_period_origins(&opts, &axes).expect("origins");
4733 assert_eq!(axes, vec![true]);
4734 assert!((periods[0].unwrap() - 2.0 * std::f64::consts::PI).abs() < 1e-12);
4735 assert_eq!(origins[0], Some(0.0));
4736 }
4737
4738 #[test]
4739 fn one_dimensional_bspline_accepts_boundary_periodic() {
4740 let ds = continuous_dataset(
4741 &["y", "theta"],
4742 (0..16)
4743 .map(|i| {
4744 let theta = std::f64::consts::TAU * i as f64 / 16.0;
4745 vec![theta.sin(), theta]
4746 })
4747 .collect(),
4748 );
4749 let parsed = parse_formula("y ~ s(theta, boundary=periodic, period=2*pi, origin=0, k=8)")
4750 .expect("parse");
4751 let col_map = ds.column_map();
4752 let mut notes = Vec::new();
4753 let terms = build_termspec(
4754 &parsed.terms,
4755 &ds,
4756 &col_map,
4757 &mut notes,
4758 &gam_runtime::resource::ResourcePolicy::default_library(),
4759 )
4760 .expect("periodic boundary should build");
4761 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4762 panic!("expected 1D B-spline");
4763 };
4764 assert!(matches!(
4765 &spec.knotspec,
4766 BSplineKnotSpec::PeriodicUniform {
4767 data_range,
4768 num_basis: 8
4769 } if *data_range == (0.0, std::f64::consts::TAU)
4770 ));
4771 }
4772
4773 #[test]
4774 fn univariate_smooth_accepts_mgcv_cubic_regression_aliases() {
4775 let ds = continuous_dataset(
4776 &["y", "x"],
4777 (0..32)
4778 .map(|i| {
4779 let x = i as f64 / 31.0;
4780 vec![x * x, x]
4781 })
4782 .collect(),
4783 );
4784 let col_map = ds.column_map();
4785
4786 for (selector, expect_double_penalty) in [("cr", false), ("cs", true)] {
4787 let formula = format!("y ~ s(x, bs='{selector}')");
4788 let parsed = parse_formula(&formula).expect("parse cr/cs smooth");
4789 let mut notes = Vec::new();
4790 let terms = build_termspec(
4791 &parsed.terms,
4792 &ds,
4793 &col_map,
4794 &mut notes,
4795 &gam_runtime::resource::ResourcePolicy::default_library(),
4796 )
4797 .unwrap_or_else(|err| panic!("bs='{selector}' must build a 1-D smooth, got: {err:?}"));
4798 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4799 panic!(
4800 "bs='{selector}' must lower to a BSpline1D; got {:?}",
4801 terms.smooth_terms[0].basis
4802 );
4803 };
4804 assert_eq!(
4805 spec.double_penalty, expect_double_penalty,
4806 "bs='{selector}' must default double_penalty to mgcv's convention \
4807 (cr=no-shrinkage, cs=shrinkage); got double_penalty={}",
4808 spec.double_penalty
4809 );
4810 }
4811 }
4812
4813 #[test]
4814 fn univariate_ps_small_k_degree_reduces_through_build() {
4815 let ds = continuous_dataset(
4824 &["y", "x"],
4825 (0..32)
4826 .map(|i| {
4827 let x = i as f64 / 31.0;
4828 vec![x * x, x]
4829 })
4830 .collect(),
4831 );
4832 let col_map = ds.column_map();
4833
4834 for formula in ["y ~ s(x, bs='ps', k=3)", "y ~ s(x, k=3)"] {
4835 let parsed = parse_formula(formula).expect("parse small-k ps/cr smooth");
4836 let mut notes = Vec::new();
4837 let terms = build_termspec(
4838 &parsed.terms,
4839 &ds,
4840 &col_map,
4841 &mut notes,
4842 &gam_runtime::resource::ResourcePolicy::default_library(),
4843 )
4844 .unwrap_or_else(|err| {
4845 panic!("`{formula}` must degree-reduce, not error; got: {err:?}")
4846 });
4847 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
4848 panic!(
4849 "`{formula}` must lower to a BSpline1D; got {:?}",
4850 terms.smooth_terms[0].basis
4851 );
4852 };
4853 assert_eq!(
4854 spec.degree, 2,
4855 "`{formula}` must drop the cubic default to a quadratic basis"
4856 );
4857 let num_internal = match &spec.knotspec {
4858 BSplineKnotSpec::Generate {
4859 num_internal_knots, ..
4860 } => *num_internal_knots,
4861 BSplineKnotSpec::Automatic {
4862 num_internal_knots: Some(n),
4863 ..
4864 } => *n,
4865 other => panic!("`{formula}` unexpected knotspec: {other:?}"),
4866 };
4867 assert_eq!(
4868 num_internal, 0,
4869 "`{formula}` must have zero internal knots (num_basis = k = 3)"
4870 );
4871 assert!(
4873 spec.penalty_order >= 1 && spec.penalty_order <= spec.degree,
4874 "`{formula}` penalty_order {} must satisfy 1 <= order <= degree={}",
4875 spec.penalty_order,
4876 spec.degree
4877 );
4878 }
4879 }
4880
4881 #[test]
4882 fn formula_shape_constraint_round_trips_and_rejects_bogus() {
4883 let ds = continuous_dataset(
4884 &["y", "x"],
4885 (0..32)
4886 .map(|i| {
4887 let x = i as f64 / 31.0;
4888 vec![x * x, x]
4889 })
4890 .collect(),
4891 );
4892 let col_map = ds.column_map();
4893
4894 let parsed =
4895 parse_formula("y ~ s(x, shape=monotone_increasing)").expect("parse monotone smooth");
4896 let mut notes = Vec::new();
4897 let terms = build_termspec(
4898 &parsed.terms,
4899 &ds,
4900 &col_map,
4901 &mut notes,
4902 &gam_runtime::resource::ResourcePolicy::default_library(),
4903 )
4904 .expect("monotone smooth should build");
4905 assert_eq!(
4906 terms.smooth_terms[0].shape,
4907 ShapeConstraint::MonotoneIncreasing
4908 );
4909
4910 let parsed_bad = parse_formula("y ~ s(x, shape=bogus)").expect("parse bogus shape");
4911 let mut notes_bad = Vec::new();
4912 let err = build_termspec(
4913 &parsed_bad.terms,
4914 &ds,
4915 &col_map,
4916 &mut notes_bad,
4917 &gam_runtime::resource::ResourcePolicy::default_library(),
4918 )
4919 .expect_err("bogus shape must error");
4920 assert!(
4921 format!("{err:?}").contains("unknown shape constraint"),
4922 "got: {err:?}"
4923 );
4924 }
4925
4926 #[test]
4927 fn default_sphere_smooth_uses_spherical_farthest_point_centers() {
4928 let ds = continuous_dataset(
4929 &["y", "lat", "lon"],
4930 (0..24)
4931 .map(|i| {
4932 let t = i as f64 / 24.0;
4933 let lat = -60.0 + 120.0 * t;
4934 let lon = -180.0 + 360.0 * ((7 * i) % 24) as f64 / 24.0;
4935 vec![lat.to_radians().sin(), lat, lon]
4936 })
4937 .collect(),
4938 );
4939 let parsed = parse_formula("y ~ sphere(lat, lon)").expect("parse");
4940 let col_map = ds.column_map();
4941 let mut notes = Vec::new();
4942 let terms = build_termspec(
4943 &parsed.terms,
4944 &ds,
4945 &col_map,
4946 &mut notes,
4947 &gam_runtime::resource::ResourcePolicy::default_library(),
4948 )
4949 .expect("build sphere termspec");
4950 let SmoothBasisSpec::Sphere { spec, .. } = &terms.smooth_terms[0].basis else {
4951 panic!("expected sphere term");
4952 };
4953 assert!(matches!(
4954 spec.center_strategy,
4955 CenterStrategy::FarthestPoint { .. }
4956 ));
4957 }
4958
4959 #[test]
4960 fn one_dimensional_duchon_defaults_to_scale_free_length_scale() {
4961 let ds = continuous_dataset(
4962 &["y", "x"],
4963 (0..32)
4964 .map(|i| {
4965 let x = i as f64 / 31.0;
4966 vec![(std::f64::consts::TAU * x).sin(), x]
4967 })
4968 .collect(),
4969 );
4970 let parsed = parse_formula("y ~ duchon(x)").expect("parse");
4971 let col_map = ds.column_map();
4972 let mut notes = Vec::new();
4973 let terms = build_termspec(
4974 &parsed.terms,
4975 &ds,
4976 &col_map,
4977 &mut notes,
4978 &gam_runtime::resource::ResourcePolicy::default_library(),
4979 )
4980 .expect("build default duchon termspec");
4981 let SmoothBasisSpec::Duchon { spec, .. } = &terms.smooth_terms[0].basis else {
4982 panic!("expected Duchon term");
4983 };
4984 assert_eq!(spec.length_scale, None);
4985 }
4986
4987 #[test]
4988 fn one_dimensional_duchon_length_scale_opts_into_hybrid_mode() {
4989 let ds = continuous_dataset(
4990 &["y", "x"],
4991 (0..32)
4992 .map(|i| {
4993 let x = i as f64 / 31.0;
4994 vec![(std::f64::consts::TAU * x).sin(), x]
4995 })
4996 .collect(),
4997 );
4998 let parsed = parse_formula("y ~ duchon(x, length_scale=0.25)").expect("parse");
4999 let col_map = ds.column_map();
5000 let mut notes = Vec::new();
5001 let terms = build_termspec(
5002 &parsed.terms,
5003 &ds,
5004 &col_map,
5005 &mut notes,
5006 &gam_runtime::resource::ResourcePolicy::default_library(),
5007 )
5008 .expect("build hybrid duchon termspec");
5009 let SmoothBasisSpec::Duchon { spec, .. } = &terms.smooth_terms[0].basis else {
5010 panic!("expected Duchon term");
5011 };
5012 assert_eq!(spec.length_scale, Some(0.25));
5013 }
5014
5015 #[test]
5016 fn parse_matern_nu_accepts_equivalent_half_integer_forms() {
5017 let cases = [
5018 ("1/2", MaternNu::Half),
5019 (" 1 / 2 ", MaternNu::Half),
5020 (".5", MaternNu::Half),
5021 ("0.50", MaternNu::Half),
5022 ("half", MaternNu::Half),
5023 ("3 / 2", MaternNu::ThreeHalves),
5024 ("1.50", MaternNu::ThreeHalves),
5025 ("5 / 2", MaternNu::FiveHalves),
5026 ("2.500000000000", MaternNu::FiveHalves),
5027 ("7 / 2", MaternNu::SevenHalves),
5028 ("3.50", MaternNu::SevenHalves),
5029 ("9 / 2", MaternNu::NineHalves),
5030 ("4.50", MaternNu::NineHalves),
5031 ];
5032 for (raw, expected) in cases {
5033 let parsed = parse_matern_nu(raw).expect(raw);
5034 assert!(
5035 matches!(
5036 (parsed, expected),
5037 (MaternNu::Half, MaternNu::Half)
5038 | (MaternNu::ThreeHalves, MaternNu::ThreeHalves)
5039 | (MaternNu::FiveHalves, MaternNu::FiveHalves)
5040 | (MaternNu::SevenHalves, MaternNu::SevenHalves)
5041 | (MaternNu::NineHalves, MaternNu::NineHalves)
5042 ),
5043 "parsed {raw:?} as {parsed:?}, expected {expected:?}"
5044 );
5045 }
5046 }
5047
5048 #[test]
5049 fn parse_matern_nu_rejects_unsupported_or_invalid_values() {
5050 for raw in ["1", "2", "11/2", "1/0", "nan", "fast"] {
5051 let err = parse_matern_nu(raw).expect_err(raw);
5052 assert!(
5053 err.contains("supported half-integer values"),
5054 "unexpected error for {raw:?}: {err}"
5055 );
5056 }
5057 }
5058
5059 #[test]
5060 fn parse_ps_k_promotes_underexpressive_cubic_basis() {
5061 let mut opts = BTreeMap::new();
5062 opts.insert("k".to_string(), "4".to_string());
5063 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=4");
5064 assert_eq!(internal, 2);
5065 assert_eq!(eff_degree, 3);
5066 assert!(!inferred);
5067
5068 opts.insert("k".to_string(), "6".to_string());
5069 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=6");
5070 assert_eq!(internal, 2);
5071 assert_eq!(eff_degree, 3);
5072 assert!(!inferred);
5073
5074 opts.insert("k".to_string(), "10".to_string());
5075 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=10");
5076 assert_eq!(internal, 6);
5077 assert_eq!(eff_degree, 3);
5078 assert!(!inferred);
5079 }
5080
5081 #[test]
5082 fn parse_ps_internal_knots_drops_degree_for_small_k() {
5083 let mut opts = BTreeMap::new();
5088 opts.insert("k".to_string(), "3".to_string());
5089 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=3");
5090 assert_eq!(eff_degree, 2);
5091 assert_eq!(internal, 0);
5092 assert!(!inferred);
5093
5094 opts.insert("k".to_string(), "2".to_string());
5097 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=2");
5098 assert_eq!(eff_degree, 1);
5099 assert_eq!(internal, 0);
5100 assert!(!inferred);
5101
5102 opts.insert("k".to_string(), "1".to_string());
5106 let err = parse_ps_internal_knots(&opts, 3, 20)
5107 .expect_err("k=1 is below the irreducible spline floor");
5108 assert!(err.contains("requires k >= 2"), "unexpected error: {err}");
5109
5110 opts.insert("k".to_string(), "4".to_string());
5113 let (internal, inferred, eff_degree) = parse_ps_internal_knots(&opts, 3, 20).expect("k=4");
5114 assert_eq!(eff_degree, 3);
5115 assert_eq!(internal, 2);
5116 assert!(!inferred);
5117 }
5118
5119 #[test]
5120 fn factor_smooth_marginal_degree_reduces_for_small_k() {
5121 let ds = factor_dataset();
5122 let col_map = ds.column_map();
5123
5124 for (k, expected_degree) in [(3usize, 2usize), (2usize, 1usize)] {
5125 let parsed =
5126 parse_formula(&format!("y ~ s(x, g, bs=fs, k={k})")).expect("parse factor smooth");
5127 let mut notes = Vec::new();
5128 let terms = build_termspec(
5129 &parsed.terms,
5130 &ds,
5131 &col_map,
5132 &mut notes,
5133 &gam_runtime::resource::ResourcePolicy::default_library(),
5134 )
5135 .unwrap_or_else(|err| panic!("fs k={k} should degree-reduce, got: {err:?}"));
5136 let SmoothBasisSpec::FactorSmooth { spec } = &terms.smooth_terms[0].basis else {
5137 panic!(
5138 "expected factor smooth, got {:?}",
5139 terms.smooth_terms[0].basis
5140 );
5141 };
5142 assert_eq!(spec.marginal.degree, expected_degree);
5143 assert!(
5144 spec.marginal.penalty_order <= spec.marginal.degree,
5145 "penalty_order {} must be clamped to degree {}",
5146 spec.marginal.penalty_order,
5147 spec.marginal.degree
5148 );
5149 let basis_size = match spec.marginal.knotspec {
5150 BSplineKnotSpec::Generate {
5151 num_internal_knots, ..
5152 } => num_internal_knots + spec.marginal.degree + 1,
5153 BSplineKnotSpec::Automatic {
5154 num_internal_knots: Some(num_internal_knots),
5155 ..
5156 } => num_internal_knots + spec.marginal.degree + 1,
5157 ref other => panic!("unexpected factor-smooth knotspec: {other:?}"),
5158 };
5159 assert_eq!(basis_size, k);
5160 }
5161 }
5162
5163 fn ternary_factor_dataset() -> Dataset {
5166 let rows = (0..120)
5167 .map(|i| {
5168 let x = (i % 3) as f64;
5169 let g = (i % 2) as f64;
5170 vec![x + g, x, g]
5171 })
5172 .collect::<Vec<_>>();
5173 Dataset {
5174 headers: vec!["y".into(), "x".into(), "g".into()],
5175 values: Array2::from_shape_vec(
5176 (rows.len(), 3),
5177 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5178 )
5179 .expect("rectangular ternary factor test data"),
5180 schema: DataSchema {
5181 columns: vec![
5182 SchemaColumn {
5183 name: "y".into(),
5184 kind: ColumnKindTag::Continuous,
5185 levels: vec![],
5186 },
5187 SchemaColumn {
5188 name: "x".into(),
5189 kind: ColumnKindTag::Continuous,
5190 levels: vec![],
5191 },
5192 SchemaColumn {
5193 name: "g".into(),
5194 kind: ColumnKindTag::Categorical,
5195 levels: vec!["a".into(), "b".into()],
5196 },
5197 ],
5198 },
5199 column_kinds: vec![
5200 ColumnKindTag::Continuous,
5201 ColumnKindTag::Continuous,
5202 ColumnKindTag::Categorical,
5203 ],
5204 }
5205 }
5206
5207 #[test]
5208 fn univariate_cr_smooth_caps_knots_to_data_support() {
5209 let ds = continuous_dataset(
5215 &["y", "x"],
5216 (0..90)
5217 .map(|i| vec![(i % 3) as f64, (i % 3) as f64])
5218 .collect(),
5219 );
5220 let col_map = ds.column_map();
5221 let parsed = parse_formula("y ~ s(x, bs=cr, k=10)").expect("parse cr smooth");
5222 let mut notes = Vec::new();
5223 let terms = build_termspec(
5224 &parsed.terms,
5225 &ds,
5226 &col_map,
5227 &mut notes,
5228 &gam_runtime::resource::ResourcePolicy::default_library(),
5229 )
5230 .expect("cr k=10 must cap to data support instead of erroring");
5231 let SmoothBasisSpec::BSpline1D { spec, .. } = &terms.smooth_terms[0].basis else {
5232 panic!("expected BSpline1D for s(x, bs=cr)");
5233 };
5234 let BSplineKnotSpec::NaturalCubicRegression { knots } = &spec.knotspec else {
5235 panic!("expected cr knotspec, got {:?}", spec.knotspec);
5236 };
5237 assert_eq!(knots.len(), 3, "cr basis not capped to 3 distinct values");
5239 assert_eq!(knots.as_slice().unwrap(), &[0.0, 1.0, 2.0]);
5240 assert!(
5242 notes.iter().any(|n| n.contains("data-support cap")),
5243 "cap not reported in inference notes: {notes:?}"
5244 );
5245 }
5246
/// A `cr` smooth on a two-valued covariate must build without error, must not
/// end up with a natural cubic regression knot spec, and must record the
/// degradation in the inference notes.
#[test]
fn univariate_cr_smooth_binary_covariate_degrades_to_bspline() {
    // 80 rows; both columns alternate between 0 and 1.
    let rows: Vec<Vec<f64>> = (0..80)
        .map(|i| {
            let bit = (i % 2) as f64;
            vec![bit, bit]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x"], rows);
    let col_map = ds.column_map();
    let parsed = parse_formula("y ~ s(x, bs=cr, k=10)").expect("parse cr smooth");
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("binary cr must degrade to B-spline instead of erroring");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => spec,
        _ => panic!("expected BSpline1D for s(x, bs=cr)"),
    };

    let built_cr = matches!(
        spec.knotspec,
        BSplineKnotSpec::NaturalCubicRegression { .. }
    );
    assert!(
        !built_cr,
        "binary covariate must NOT build a cr basis, got {:?}",
        spec.knotspec
    );

    let degradation_reported = notes
        .iter()
        .any(|note| note.contains("Degraded to the linear B-spline"));
    assert!(
        degradation_reported,
        "degradation not reported in inference notes: {notes:?}"
    );
}
5287
/// `s(x, g, bs=sz, k=10)` on the ternary-factor fixture must build a
/// `FactorSmooth` whose marginal knot spec is not the natural cubic
/// regression variant.
#[test]
fn sz_factor_smooth_low_cardinality_uses_bspline_marginal() {
    let ds = ternary_factor_dataset();
    let col_map = ds.column_map();
    let parsed = parse_formula("y ~ s(x, g, bs=sz, k=10)").expect("parse sz factor smooth");
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("sz on a ternary covariate must build (B-spline marginal), not hard-fail");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::FactorSmooth { spec } => spec,
        _ => panic!("expected FactorSmooth for s(x, g, bs=sz)"),
    };

    let marginal_is_cr = matches!(
        spec.marginal.knotspec,
        BSplineKnotSpec::NaturalCubicRegression { .. }
    );
    assert!(
        !marginal_is_cr,
        "sz marginal must be a B-spline (curvature-capable), not the \
         natural-BC cr basis; got {:?}",
        spec.marginal.knotspec
    );
}
5323
5324 fn continuous_x_factor_dataset(n: usize, n_groups: usize) -> Dataset {
5329 let rows = (0..n)
5330 .map(|i| {
5331 let x = i as f64 / (n as f64 - 1.0);
5332 let g = (i % n_groups) as f64;
5333 vec![x + g, x, g]
5334 })
5335 .collect::<Vec<_>>();
5336 let levels: Vec<String> = (0..n_groups).map(|k| format!("g{k}")).collect();
5337 Dataset {
5338 headers: vec!["y".into(), "x".into(), "g".into()],
5339 values: Array2::from_shape_vec(
5340 (rows.len(), 3),
5341 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5342 )
5343 .expect("rectangular continuous-x factor data"),
5344 schema: DataSchema {
5345 columns: vec![
5346 SchemaColumn {
5347 name: "y".into(),
5348 kind: ColumnKindTag::Continuous,
5349 levels: vec![],
5350 },
5351 SchemaColumn {
5352 name: "x".into(),
5353 kind: ColumnKindTag::Continuous,
5354 levels: vec![],
5355 },
5356 SchemaColumn {
5357 name: "g".into(),
5358 kind: ColumnKindTag::Categorical,
5359 levels,
5360 },
5361 ],
5362 },
5363 column_kinds: vec![
5364 ColumnKindTag::Continuous,
5365 ColumnKindTag::Continuous,
5366 ColumnKindTag::Categorical,
5367 ],
5368 }
5369 }
5370
5371 fn factor_smooth_spec_for(formula: &str, ds: &Dataset) -> FactorSmoothSpec {
5372 let col_map = ds.column_map();
5373 let parsed = parse_formula(formula).expect("parse factor smooth formula");
5374 let mut notes = Vec::new();
5375 let terms = build_termspec(
5376 &parsed.terms,
5377 ds,
5378 &col_map,
5379 &mut notes,
5380 &gam_runtime::resource::ResourcePolicy::default_library(),
5381 )
5382 .expect("build factor smooth term");
5383 let SmoothBasisSpec::FactorSmooth { spec } = &terms.smooth_terms[0].basis else {
5384 panic!("expected FactorSmooth basis for `{formula}`");
5385 };
5386 spec.clone()
5387 }
5388
/// Builds the same factor smooth with `bs=sz` and `bs=fs` and compares the
/// results: the penalty counts must match, `sz` must have at least two
/// penalties and a strictly narrower design, and the per-penalty metadata
/// vectors of the `sz` build must all have the same length as its penalties.
#[test]
fn sz_factor_smooth_carries_null_space_ridge_like_fs() {
    let ds = continuous_x_factor_dataset(180, 4);
    let mut workspace = crate::basis::BasisWorkspace::new();

    let sz_spec = factor_smooth_spec_for("y ~ s(x, g, bs=sz, k=8)", &ds);
    let sz_built =
        crate::smooth::build_factor_smooth(ds.values.view(), &sz_spec, "sz_term", &mut workspace)
            .expect("build sz factor smooth");

    let fs_spec = factor_smooth_spec_for("y ~ s(x, g, bs=fs, k=8)", &ds);
    let fs_built =
        crate::smooth::build_factor_smooth(ds.values.view(), &fs_spec, "fs_term", &mut workspace)
            .expect("build fs factor smooth");

    // Fall back to the fixture's group count when no frozen levels are recorded.
    let n_levels = sz_spec
        .group_frozen_levels
        .as_ref()
        .map_or(4, |levels| levels.len());
    assert!(n_levels >= 3, "test needs >=3 groups, got {n_levels}");

    let sz_penalty_count = sz_built.penalties.len();
    let fs_penalty_count = fs_built.penalties.len();

    assert_eq!(
        sz_penalty_count, fs_penalty_count,
        "sz must carry the same number of penalties as fs (wiggliness + one \
         null-space ridge per marginal null direction); sz had {} (only the \
         wiggliness penalties => null space unpenalized => over-smoothed), fs \
         had {}",
        sz_penalty_count, fs_penalty_count,
    );

    assert!(
        sz_penalty_count >= 2,
        "sz deviation block carries no null-space ridge (penalties={}); the \
         null space is unpenalized and REML over-smooths the deviations",
        sz_penalty_count,
    );

    let (sz_width, fs_width) = (sz_built.dim, fs_built.dim);
    assert!(
        sz_width < fs_width,
        "sz design width {} must be strictly less than fs width {} \
         (zero-sum contrast drops one level block)",
        sz_width,
        fs_width,
    );

    // Every per-penalty metadata vector must line up with the penalty list.
    assert_eq!(sz_penalty_count, sz_built.nullspaces.len());
    assert_eq!(sz_penalty_count, sz_built.penaltyinfo.len());
    assert_eq!(sz_penalty_count, sz_built.null_eigenvectors.len());
}
5483
5484 fn factor_dataset_l3() -> Dataset {
5495 let rows = (0..30)
5497 .map(|i| {
5498 let x = i as f64 / 29.0;
5499 let g = (i % 3) as f64;
5500 vec![x + g, x, g]
5501 })
5502 .collect::<Vec<_>>();
5503 Dataset {
5504 headers: vec!["y".into(), "x".into(), "g".into()],
5505 values: Array2::from_shape_vec(
5506 (rows.len(), 3),
5507 rows.into_iter().flat_map(|row| row.into_iter()).collect(),
5508 )
5509 .expect("rectangular L=3 factor test data"),
5510 schema: DataSchema {
5511 columns: vec![
5512 SchemaColumn {
5513 name: "y".into(),
5514 kind: ColumnKindTag::Continuous,
5515 levels: vec![],
5516 },
5517 SchemaColumn {
5518 name: "x".into(),
5519 kind: ColumnKindTag::Continuous,
5520 levels: vec![],
5521 },
5522 SchemaColumn {
5523 name: "g".into(),
5524 kind: ColumnKindTag::Categorical,
5525 levels: vec!["a".into(), "b".into(), "c".into()],
5526 },
5527 ],
5528 },
5529 column_kinds: vec![
5530 ColumnKindTag::Continuous,
5531 ColumnKindTag::Continuous,
5532 ColumnKindTag::Categorical,
5533 ],
5534 }
5535 }
5536
/// Counts the random-effect terms named `g` produced by `s(x, by=g)` with and
/// without an extra bare `+ g`, and requires exactly one in both cases.
#[test]
fn factor_by_smooth_plus_bare_categorical_does_not_duplicate_factor_block() {
    let ds = factor_dataset_l3();
    let col_map = ds.column_map();

    // Builds `formula` and returns how many random-effect terms are named `g`.
    let g_blocks = |formula: &str| -> usize {
        let parsed = parse_formula(formula).expect("parse by-smooth formula");
        let mut notes = Vec::new();
        let policy = ResourcePolicy::default_library();
        let terms = match build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy) {
            Ok(built) => built,
            Err(err) => panic!("`{formula}` must build, got: {err:?}"),
        };
        terms
            .random_effect_terms
            .iter()
            .filter(|block| block.name == "g")
            .count()
    };

    let by_only = g_blocks("y ~ s(x, by=g, k=10)");
    assert_eq!(
        by_only, 1,
        "`y ~ s(x, by=g)` must produce exactly one `g` design block"
    );

    let by_plus_bare = g_blocks("y ~ s(x, by=g, k=10) + g");
    assert_eq!(
        by_plus_bare, 1,
        "`y ~ s(x, by=g) + g` must collapse to ONE `g` block (#1457): the bare \
         `+ g` already owns the factor's level offsets, so the `by=` branch \
         must not add a second, treatment-coded main effect"
    );

    assert_eq!(
        by_plus_bare, by_only,
        "the bare `+ g` collision must add zero extra `g` blocks (#1457)"
    );
}
5586
/// The `boundary`, `periods` and `origins` option spellings must parse into
/// two periodic axes with periods 7 and 24 and origins 0 and -12.
#[test]
fn parse_tensor_periods_and_origins_aliases() {
    let opts = BTreeMap::from([
        (
            "boundary".to_string(),
            "['periodic', 'periodic']".to_string(),
        ),
        ("periods".to_string(), "[7, 24]".to_string()),
        ("origins".to_string(), "[0, -12]".to_string()),
    ]);

    let axes = parse_periodic_axes(&opts, 2).expect("axes");
    let periods = parse_periods(&opts, &axes).expect("periods");
    let origins = parse_period_origins(&opts, &axes).expect("origins");

    assert_eq!(axes, vec![true, true]);
    assert_eq!(periods, vec![Some(7.0), Some(24.0)]);
    assert_eq!(origins, vec![Some(0.0), Some(-12.0)]);
}
5603
/// A tensor smooth given `k=[9,5]` must end up with marginal basis sizes of
/// 9 and 5, in that order.
#[test]
fn tensor_smooth_honors_per_margin_k_list() {
    let rows: Vec<Vec<f64>> = (0..20)
        .map(|i| {
            let theta = std::f64::consts::TAU * i as f64 / 20.0;
            let h = -1.0 + 2.0 * (i % 5) as f64 / 4.0;
            vec![theta.cos() + h, theta, h]
        })
        .collect();
    let ds = continuous_dataset(&["y", "theta", "h"], rows);
    let parsed = parse_formula(
        "y ~ te(theta, h, periodic=[0], period=[2*pi, None], origin=[0, None], k=[9,5])",
    )
    .expect("parse tensor formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build tensor terms");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!("expected tensor B-spline"),
    };

    // Basis size implied by each margin's knot spec.
    let mut dims = Vec::new();
    for margin in spec.marginalspecs.iter() {
        let width = match &margin.knotspec {
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => *num_internal_knots + margin.degree + 1,
            BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
            _ => panic!("unexpected tensor marginal knotspec"),
        };
        dims.push(width);
    }
    assert_eq!(dims, vec![9, 5]);
}
5649
/// The `k_x=` / `k_y=` per-margin aliases must yield marginal basis sizes of
/// 9 and 5 for `te(x, y, ...)`.
#[test]
fn tensor_smooth_honors_per_margin_k_axis_aliases() {
    let rows: Vec<Vec<f64>> = (0..12)
        .map(|i| {
            let t = i as f64 / 11.0;
            vec![t, t, 1.0 - t]
        })
        .collect();
    let ds = continuous_dataset(&["resp", "x", "y"], rows);

    let sizes = tensor_margin_basis_sizes(&ds, "resp ~ te(x, y, k_x=9, k_y=5)");
    assert_eq!(
        sizes,
        vec![9, 5],
        "k_<margin> aliases should materialize requested per-margin values"
    );
}
5667
/// For `te(x, b, k=[5, 2])` with a two-valued `b`, the continuous margin must
/// keep degree 3 and basis size 5, while the binary margin must have degree 1,
/// a penalty order within `1..=degree`, and basis size 2.
#[test]
fn tensor_smooth_low_cardinality_axis_falls_back_to_lower_degree_basis() {
    let rows: Vec<Vec<f64>> = (0..40)
        .map(|i| {
            let x = i as f64 / 39.0;
            let b = (i % 2) as f64;
            vec![x.sin() + 0.5 * b, x, b]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x", "b"], rows);
    let parsed = parse_formula("y ~ te(x, b, k=[5, 2])").expect("parse tensor with k=[5,2]");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build tensor with binary margin");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!("expected tensor B-spline for te(x, b)"),
    };
    let continuous = &spec.marginalspecs[0];
    let binary = &spec.marginalspecs[1];

    assert_eq!(continuous.degree, 3);
    assert_eq!(binary.degree, 1);

    let order_in_range = (1..=binary.degree).contains(&binary.penalty_order);
    assert!(
        order_in_range,
        "binary margin penalty_order {} must satisfy 1 <= order <= degree={}",
        binary.penalty_order,
        binary.degree
    );

    // Basis size implied by a margin's knot spec.
    let basis_size = |margin: &BSplineBasisSpec| match &margin.knotspec {
        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
        BSplineKnotSpec::Generate {
            num_internal_knots, ..
        } => *num_internal_knots + margin.degree + 1,
        BSplineKnotSpec::Automatic {
            num_internal_knots: Some(n),
            ..
        } => *n + margin.degree + 1,
        BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
        _ => panic!("unexpected tensor marginal knotspec"),
    };
    assert_eq!(basis_size(continuous), 5);
    assert_eq!(basis_size(binary), 2);
}
5730
/// For `te(x, b, k=5)` with a two-valued `b`, the build must succeed with the
/// binary margin at basis size 2 and degree 1, and the continuous margin at
/// basis size 5.
#[test]
fn tensor_smooth_uniform_k_is_capped_to_a_low_cardinality_margins_distinct_values() {
    /// Basis size implied by a margin's knot spec.
    fn basis_size(m: &BSplineBasisSpec) -> usize {
        match m.knotspec {
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            } => num_internal_knots + m.degree + 1,
            BSplineKnotSpec::Automatic {
                num_internal_knots: Some(n),
                ..
            } => n + m.degree + 1,
            BSplineKnotSpec::NaturalCubicRegression { ref knots } => knots.len(),
            ref other => panic!("unexpected tensor marginal knotspec: {other:?}"),
        }
    }

    let rows: Vec<Vec<f64>> = (0..40)
        .map(|i| {
            let x = i as f64 / 39.0;
            let b = (i % 2) as f64;
            vec![x.sin() + 0.5 * b, x, b]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x", "b"], rows);
    let parsed = parse_formula("y ~ te(x, b, k=5)").expect("parse tensor with uniform k=5");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("uniform k=5 must auto-cap the binary margin instead of erroring");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        _ => panic!("expected tensor B-spline for te(x, b)"),
    };

    let binary = &spec.marginalspecs[1];
    assert_eq!(basis_size(binary), 2);
    assert_eq!(binary.degree, 1);

    let continuous = &spec.marginalspecs[0];
    assert_eq!(basis_size(continuous), 5);
}
5784
/// `te(x1, x2, bs=c('tp','tp'), k=c(5,5))` must build a `TensorBSpline` whose
/// two margins each have basis size 5.
#[test]
fn tensor_all_tp_margins_with_per_margin_k_routes_to_bspline_tensor() {
    let rows: Vec<Vec<f64>> = (0..32)
        .map(|i| {
            let t = i as f64 / 31.0;
            vec![t.sin(), t, 1.0 - t]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x1", "x2"], rows);
    let parsed =
        parse_formula("y ~ te(x1, x2, bs=c('tp','tp'), k=c(5,5))").expect("parse tensor");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build tensor terms with per-margin k");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        other => panic!(
            "expected B-spline tensor when k=c(5,5) is supplied with bs=c('tp','tp'), got {:?}",
            other
        ),
    };

    // Basis size implied by each margin's knot spec; every variant is listed
    // explicitly so the match stays exhaustive.
    let mut dims = Vec::with_capacity(spec.marginalspecs.len());
    for margin in spec.marginalspecs.iter() {
        let width = match &margin.knotspec {
            BSplineKnotSpec::Generate {
                num_internal_knots, ..
            }
            | BSplineKnotSpec::Automatic {
                num_internal_knots: Some(num_internal_knots),
                ..
            } => *num_internal_knots + margin.degree + 1,
            BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
            BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(margin.degree + 1),
            BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
            BSplineKnotSpec::Automatic {
                num_internal_knots: None,
                ..
            } => panic!("test cannot infer automatic knot count"),
        };
        dims.push(width);
    }
    assert_eq!(dims, vec![5, 5]);
}
5855
/// `te(x1, x2, bs=c('tp','tp'))` without any `k` must still build a
/// `TensorBSpline` carrying exactly two marginal specs.
#[test]
fn tensor_all_tp_margins_without_per_margin_k_builds_anisotropic_tensor() {
    let rows: Vec<Vec<f64>> = (0..32)
        .map(|i| {
            let t = i as f64 / 31.0;
            vec![t.sin(), t, 1.0 - t]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x1", "x2"], rows);
    let parsed = parse_formula("y ~ te(x1, x2, bs=c('tp','tp'))").expect("parse tensor");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build tensor terms without per-margin k");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
        other => panic!(
            "te(...,bs=c('tp','tp')) must route to an anisotropic tensor product, not a \
             silent isotropic thin-plate substitution; got {:?}",
            other
        ),
    };

    let margin_count = spec.marginalspecs.len();
    assert_eq!(
        margin_count, 2,
        "tp tensor must carry one penalized B-spline margin per axis"
    );
}
5898
/// With only 12 rows and five smooths, an explicit `k=10` on the first smooth
/// must still produce a `Generate` knot spec with 6 internal knots.
#[test]
fn explicit_basis_sizes_are_not_small_n_clamped() {
    let rows: Vec<Vec<f64>> = (0..12)
        .map(|i| {
            let x = i as f64 / 11.0;
            vec![x.sin(), x, x * x, x + 0.1, 1.0 - x, (2.0 * x).sin()]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x1", "x2", "x3", "x4", "x5"], rows);
    let parsed = parse_formula("y ~ s(x1, k=10) + s(x2) + s(x3) + s(x4) + s(x5)")
        .expect("parse multi-smooth formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build multi-smooth terms");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => spec,
        _ => panic!("expected first smooth to be B-spline"),
    };
    assert!(matches!(
        &spec.knotspec,
        BSplineKnotSpec::Generate {
            num_internal_knots: 6,
            ..
        }
    ));
}
5933
/// With only 12 rows and five smooths, an explicit `centers=3` on the Duchon
/// term must still produce a farthest-point strategy with exactly 3 centers.
#[test]
fn explicit_duchon_centers_are_not_small_n_bumped() {
    let rows: Vec<Vec<f64>> = (0..12)
        .map(|i| {
            let x = i as f64 / 11.0;
            vec![x.sin(), x, x * x, x + 0.1, 1.0 - x, (2.0 * x).sin()]
        })
        .collect();
    let ds = continuous_dataset(&["y", "x1", "x2", "x3", "x4", "x5"], rows);
    let parsed = parse_formula("y ~ duchon(x1, centers=3) + s(x2) + s(x3) + s(x4) + s(x5)")
        .expect("parse multi-smooth formula");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("build multi-smooth terms");

    let spec = match &terms.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => spec,
        _ => panic!("expected first smooth to be Duchon"),
    };
    assert!(matches!(
        spec.center_strategy,
        CenterStrategy::FarthestPoint { num_centers: 3 }
    ));
}
5971
/// Repeating every row of a 50 x 16 coordinate grid twelve times must not
/// change the inferred tensor basis product.
#[test]
fn inferred_tensor_basis_cap_uses_coordinate_support_not_duplicate_rows() {
    // One row per (theta, h) grid point, theta-major.
    let unique_rows: Vec<Vec<f64>> = (0..50)
        .flat_map(|i| {
            let theta = i as f64 / 50.0;
            (0..16).map(move |j| {
                let h = -1.0 + 2.0 * (j as f64) / 15.0;
                let y = theta.cos() + h;
                vec![y, theta, h]
            })
        })
        .collect();

    // Twelve back-to-back copies of the grid.
    let repeated_rows: Vec<Vec<f64>> = (0..12)
        .flat_map(|_| unique_rows.iter().cloned())
        .collect();

    let unique = continuous_dataset(&["y", "theta", "h"], unique_rows);
    let repeated = continuous_dataset(&["y", "theta", "h"], repeated_rows);

    let unique_basis = inferred_tensor_basis_product(&unique);
    let repeated_basis = inferred_tensor_basis_product(&repeated);

    assert_eq!(
        unique_basis, repeated_basis,
        "duplicating existing tensor coordinates must not inflate inferred basis width"
    );
}
5999
/// The product of inferred marginal basis sizes for `te(x1, x2, x3)` must not
/// exceed 216 at either 60 or 2000 rows.
#[test]
fn inferred_three_dim_tensor_basis_stays_bounded_for_reml_selection() {
    // Builds an `n`-row dataset and returns the product of the marginal basis
    // sizes chosen for the 3-D tensor smooth.
    let make = |n: usize| -> usize {
        let rows: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                let f = i as f64 / n as f64;
                vec![f.sin(), f, (2.0 * f).cos(), (3.0 * f) % 1.0]
            })
            .collect();
        let ds = continuous_dataset(&["y", "x1", "x2", "x3"], rows);
        let parsed = parse_formula("y ~ te(x1, x2, x3)").expect("parse 3-D tensor");
        let col_map = ds.column_map();
        let mut notes = Vec::new();
        let policy = ResourcePolicy::default_library();
        let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
            .expect("build 3-D tensor termspec");
        let spec = match &terms.smooth_terms[0].basis {
            SmoothBasisSpec::TensorBSpline { spec, .. } => spec,
            _ => panic!("expected tensor smooth"),
        };
        spec.marginalspecs
            .iter()
            .map(|margin| match &margin.knotspec {
                BSplineKnotSpec::Generate {
                    num_internal_knots, ..
                }
                | BSplineKnotSpec::Automatic {
                    num_internal_knots: Some(num_internal_knots),
                    ..
                } => *num_internal_knots + margin.degree + 1,
                BSplineKnotSpec::NaturalCubicRegression { knots } => knots.len(),
                _ => panic!("unexpected tensor margin knotspec"),
            })
            .product::<usize>()
    };

    let small_n_width = make(60);
    assert!(
        small_n_width <= 216,
        "3-D te at small n must stay near the mgcv te default, got {}",
        small_n_width
    );

    let large_n_width = make(2000);
    assert!(
        large_n_width <= 216,
        "3-D te at large n must not blow ∏k toward the data size, got {}",
        large_n_width
    );
}
6061
/// Non-zero anchors must be rejected with an error naming the side and value,
/// for both the `side`/`anchor` and the `right_anchor` spellings, while
/// `start_bc=clamped, end_bc=zero` must parse to a clamped left end and a
/// right end anchored at (numerically) zero.
#[test]
fn parse_bspline_boundary_conditions_and_side_selector() {
    // Turns borrowed key/value pairs into the owned option map the parser takes.
    let options = |pairs: &[(&str, &str)]| -> BTreeMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    };

    // Generic selector with a non-zero anchor on the left side.
    let left_opts = options(&[
        ("boundary_conditions", "anchored"),
        ("side", "left"),
        ("anchor", "2.5"),
    ]);
    let err = parse_bspline_boundary_conditions(&left_opts)
        .expect_err("non-zero left anchor must be rejected")
        .to_string();
    assert!(
        err.contains("left") && err.contains("2.5"),
        "rejection should name the affected side and value: {err}"
    );

    // Per-side spelling with a non-zero right anchor.
    let right_opts = options(&[
        ("start_bc", "clamped"),
        ("end_bc", "zero"),
        ("right_anchor", "-1.0"),
    ]);
    let err = parse_bspline_boundary_conditions(&right_opts)
        .expect_err("non-zero right anchor must be rejected")
        .to_string();
    assert!(
        err.contains("right") && err.contains("-1"),
        "rejection should name the affected side and value: {err}"
    );

    // Same per-side spelling without an explicit anchor value.
    let valid_opts = options(&[("start_bc", "clamped"), ("end_bc", "zero")]);
    let parsed = parse_bspline_boundary_conditions(&valid_opts).expect("boundary conditions");
    assert!(matches!(
        parsed.left,
        BSplineEndpointBoundaryCondition::Clamped
    ));
    assert!(matches!(
        parsed.right,
        BSplineEndpointBoundaryCondition::Anchored { value } if value.abs() < 1e-12
    ));
}
6111
/// `y ~ x:g` on the factor fixture must produce two linear interaction terms,
/// each carrying `x` as its only feature and gated on one level of `g`. Every
/// realized column must equal `x` on rows of its level and 0 elsewhere, and
/// the gate values 0.0 and 1.0 must both appear.
#[test]
fn categorical_by_numeric_interaction_expands_treatment_coded_cells() {
    let ds = factor_dataset();
    let parsed = parse_formula("y ~ x:g").expect("parse `y ~ x:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("factor-aware `x:g` interaction must build, not error");

    let cell_count = terms.linear_terms.len();
    assert_eq!(
        cell_count, 2,
        "interaction-only `x:g` keeps ALL factor levels (full dummy coding): one slope column per group"
    );

    let x_col = *col_map.get("x").expect("x column");
    let g_col = *col_map.get("g").expect("g column");
    let n = ds.values.nrows();

    // Bit patterns of the gate levels seen so far; each may occur only once.
    let mut seen_bits = std::collections::HashSet::new();
    for term in &terms.linear_terms {
        assert!(
            term.is_interaction(),
            "the categorical-by-numeric cell is a Wilkinson-Rogers interaction"
        );
        assert_eq!(term.feature_cols, vec![x_col]);
        assert_eq!(term.categorical_levels.len(), 1);
        let (gate_col, gate_bits) = term.categorical_levels[0];
        assert_eq!(gate_col, g_col);
        assert!(seen_bits.insert(gate_bits), "each level appears once");

        let column = term
            .realized_design_column(ds.values.view())
            .expect("realize cell column");
        assert_eq!(column.len(), n);
        for (row, &actual) in column.iter().enumerate() {
            let x = ds.values[[row, x_col]];
            let g = ds.values[[row, g_col]];
            let expected = if g.to_bits() == gate_bits { x } else { 0.0 };
            assert!(
                (actual - expected).abs() < 1e-12,
                "row {row}: g={g}, x={x}, expected {expected}, got {}",
                actual
            );
        }
    }
    assert!(seen_bits.contains(&0.0_f64.to_bits()));
    assert!(seen_bits.contains(&1.0_f64.to_bits()));
}
6183
/// `y ~ x + x:g` on the factor fixture must produce exactly one interaction
/// term, with `x` as its only feature and a single gate on `g` at level 1.0.
#[test]
fn categorical_by_numeric_interaction_keeps_treatment_coding_with_parent() {
    let ds = factor_dataset();
    let parsed = parse_formula("y ~ x + x:g").expect("parse `y ~ x + x:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("`x + x:g` must build");

    let x_col = *col_map.get("x").expect("x column");
    let g_col = *col_map.get("g").expect("g column");

    let mut interaction_cells = Vec::new();
    for candidate in terms.linear_terms.iter() {
        if candidate.is_interaction() {
            interaction_cells.push(candidate);
        }
    }
    assert_eq!(
        interaction_cells.len(),
        1,
        "with `x` present, `x:g` is treatment-coded → one cell (reference dropped)"
    );

    let cell = interaction_cells[0];
    assert_eq!(cell.feature_cols, vec![x_col]);
    assert_eq!(cell.categorical_levels.len(), 1);
    let (gate_col, gate_bits) = cell.categorical_levels[0];
    assert_eq!(gate_col, g_col);
    assert_eq!(gate_bits, 1.0_f64.to_bits());
}
6227
/// `y ~ f:g` with a 3-level `f` and a 2-level `g` must produce five linear
/// terms: every (f, g) cross cell except `f0:g0`. Each term must have no
/// feature columns and one gate per factor, its realized column must be the
/// 0/1 indicator of its cell, and each cell must occur in at least one row.
#[test]
fn categorical_by_categorical_interaction_expands_full_cross_cells() {
    let n = 30usize;

    // Row-major (y, f, g) triples: f cycles through 0..3, g through 0..2.
    let flat: Vec<f64> = (0..n)
        .flat_map(|i| [(i as f64).sin(), (i % 3) as f64, (i % 2) as f64])
        .collect();
    let values = Array2::from_shape_vec((n, 3), flat).expect("rectangular cross-factor data");

    let schema_columns = vec![
        SchemaColumn {
            name: "y".into(),
            kind: ColumnKindTag::Continuous,
            levels: vec![],
        },
        SchemaColumn {
            name: "f".into(),
            kind: ColumnKindTag::Categorical,
            levels: vec!["f0".into(), "f1".into(), "f2".into()],
        },
        SchemaColumn {
            name: "g".into(),
            kind: ColumnKindTag::Categorical,
            levels: vec!["g0".into(), "g1".into()],
        },
    ];
    let ds = Dataset {
        headers: vec!["y".into(), "f".into(), "g".into()],
        values,
        schema: DataSchema {
            columns: schema_columns,
        },
        column_kinds: vec![
            ColumnKindTag::Continuous,
            ColumnKindTag::Categorical,
            ColumnKindTag::Categorical,
        ],
    };

    let parsed = parse_formula("y ~ f:g").expect("parse `y ~ f:g`");
    let col_map = ds.column_map();
    let mut notes = Vec::new();
    let policy = ResourcePolicy::default_library();
    let terms = build_termspec(&parsed.terms, &ds, &col_map, &mut notes, &policy)
        .expect("factor-by-factor `f:g` interaction must build, not error");

    let cell_count = terms.linear_terms.len();
    assert_eq!(
        cell_count, 5,
        "saturated 3*2 = 6 cross cells minus one reference cell (f0:g0) = 5"
    );

    let f_col = *col_map.get("f").expect("f column");
    let g_col = *col_map.get("g").expect("g column");
    // Bit patterns of the first level of each factor (the reference cell).
    let f0 = 0.0_f64.to_bits();
    let g0 = 0.0_f64.to_bits();

    let mut emitted = std::collections::HashSet::new();
    for term in &terms.linear_terms {
        assert!(term.feature_cols.is_empty());
        assert_eq!(term.categorical_levels.len(), 2);

        // Gate level keyed by the factor column it applies to.
        let gates: std::collections::HashMap<_, _> =
            term.categorical_levels.iter().copied().collect();
        let f_bits = *gates.get(&f_col).expect("f gate present");
        let g_bits = *gates.get(&g_col).expect("g gate present");
        assert!(
            f_bits != f0 || g_bits != g0,
            "the reference cell f0:g0 must be absorbed by the intercept, not emitted"
        );
        emitted.insert((f_bits, g_bits));

        let column = term
            .realized_design_column(ds.values.view())
            .expect("realize cross cell");
        for row in 0..n {
            let in_cell = ds.values[[row, f_col]].to_bits() == f_bits
                && ds.values[[row, g_col]].to_bits() == g_bits;
            let expected = if in_cell { 1.0 } else { 0.0 };
            assert!(
                (column[row] - expected).abs() < 1e-12,
                "row {row}: expected {expected}, got {}",
                column[row]
            );
        }
        let observed = column.iter().any(|v| *v == 1.0);
        assert!(observed, "each cross cell must be observed in the data");
    }

    // Every non-reference cell of the full cross must have been emitted.
    let f_levels = [0.0_f64.to_bits(), 1.0_f64.to_bits(), 2.0_f64.to_bits()];
    let g_levels = [0.0_f64.to_bits(), 1.0_f64.to_bits()];
    let non_reference_cells = f_levels
        .iter()
        .flat_map(|&fb| g_levels.iter().map(move |&gb| (fb, gb)))
        .filter(|&cell| cell != (f0, g0));
    for cell in non_reference_cells {
        assert!(
            emitted.contains(&cell),
            "saturated cross cell must be present"
        );
    }
}
6363}