1use crate::basis::{
2 BSplineIdentifiability, CenterStrategy, ConstantCurvatureIdentifiability,
3 MaternIdentifiability, MeasureJetIdentifiability, SpatialIdentifiability,
4 SphericalSplineIdentifiability,
5};
6
7use super::{
8 ByVarKind, FactorSmoothFlavour, SmoothBasisSpec, SmoothTermSpec, TensorBSplineIdentifiability,
9};
10
11use std::collections::BTreeSet;
12
13fn smooth_basis_feature_cols(basis: &SmoothBasisSpec) -> Vec<usize> {
14 match basis {
15 SmoothBasisSpec::ByVariable { inner, by_col, .. }
16 | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
17 let mut cols = smooth_basis_feature_cols(inner);
18 cols.push(*by_col);
19 cols.sort_unstable();
20 cols.dedup();
21 cols
22 }
23 SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_feature_cols(smooth),
24 SmoothBasisSpec::BSpline1D { feature_col, .. } => vec![*feature_col],
25 SmoothBasisSpec::ThinPlate { feature_cols, .. }
26 | SmoothBasisSpec::Sphere { feature_cols, .. }
27 | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
28 | SmoothBasisSpec::Matern { feature_cols, .. }
29 | SmoothBasisSpec::MeasureJet { feature_cols, .. }
30 | SmoothBasisSpec::Duchon { feature_cols, .. }
31 | SmoothBasisSpec::Pca { feature_cols, .. }
32 | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
33 SmoothBasisSpec::FactorSmooth { spec } => {
34 let mut cols = spec.continuous_cols.clone();
35 cols.push(spec.group_col);
36 cols.sort_unstable();
37 cols.dedup();
38 cols
39 }
40 }
41}
42
43pub fn smooth_term_feature_cols(term: &SmoothTermSpec) -> Vec<usize> {
44 smooth_basis_feature_cols(&term.basis)
45}
46
47fn smooth_basis_family_rank(term: &SmoothTermSpec) -> u8 {
48 match &term.basis {
49 SmoothBasisSpec::ByVariable { inner, .. }
50 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
51 smooth_basis_family_rank(&SmoothTermSpec {
52 name: term.name.clone(),
53 basis: (**inner).clone(),
54 shape: term.shape,
55 joint_null_rotation: None,
56 })
57 }
58 SmoothBasisSpec::BSpline1D { .. } => 0,
59 SmoothBasisSpec::TensorBSpline { .. } => 1,
60 SmoothBasisSpec::ThinPlate { .. } => 2,
61 SmoothBasisSpec::Sphere { .. } => 3,
62 SmoothBasisSpec::Matern { .. } => 4,
63 SmoothBasisSpec::Duchon { .. } => 5,
64 SmoothBasisSpec::Pca { .. } => 6,
65 SmoothBasisSpec::ConstantCurvature { .. } => 8,
66 SmoothBasisSpec::MeasureJet { .. } => 9,
67 SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_family_rank(&SmoothTermSpec {
68 name: term.name.clone(),
69 basis: (**smooth).clone(),
70 shape: term.shape,
71 joint_null_rotation: None,
72 }),
73 SmoothBasisSpec::FactorSmooth { .. } => 7,
74 }
75}
76
77pub fn smooth_has_frozen_identifiability(term: &SmoothTermSpec) -> bool {
78 match &term.basis {
79 SmoothBasisSpec::ByVariable { inner, .. }
80 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
81 smooth_has_frozen_identifiability(&SmoothTermSpec {
82 name: term.name.clone(),
83 basis: (**inner).clone(),
84 shape: term.shape,
85 joint_null_rotation: None,
86 })
87 }
88 SmoothBasisSpec::BSpline1D { spec, .. } => {
89 matches!(
90 spec.identifiability,
91 BSplineIdentifiability::FrozenTransform { .. }
92 )
93 }
94 SmoothBasisSpec::ThinPlate { spec, .. } => matches!(
95 spec.identifiability,
96 SpatialIdentifiability::FrozenTransform { .. }
97 ),
98 SmoothBasisSpec::Sphere { spec, .. } => {
99 matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
100 || matches!(
101 spec.identifiability,
102 SphericalSplineIdentifiability::FrozenTransform { .. }
103 )
104 }
105 SmoothBasisSpec::ConstantCurvature { spec, .. } => {
106 matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
107 || matches!(
108 spec.identifiability,
109 ConstantCurvatureIdentifiability::FrozenTransform { .. }
110 )
111 }
112 SmoothBasisSpec::MeasureJet { spec, .. } => {
113 matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
114 || matches!(
115 spec.identifiability,
116 MeasureJetIdentifiability::FrozenTransform { .. }
117 )
118 }
119 SmoothBasisSpec::Matern { spec, .. } => matches!(
120 spec.identifiability,
121 MaternIdentifiability::FrozenTransform { .. }
122 ),
123 SmoothBasisSpec::BySmooth { by_kind, .. } => match by_kind {
124 ByVarKind::Factor { frozen_levels, .. } => frozen_levels.is_some(),
125 ByVarKind::Numeric { .. } => true,
126 },
127 SmoothBasisSpec::FactorSmooth { spec } => spec.group_frozen_levels.is_some(),
128 SmoothBasisSpec::Duchon { spec, .. } => matches!(
129 spec.identifiability,
130 SpatialIdentifiability::FrozenTransform { .. }
131 ),
132 SmoothBasisSpec::Pca {
133 centered,
134 center_mean,
135 pca_basis_path,
136 ..
137 } => !*centered || center_mean.is_some() || pca_basis_path.is_some(),
138 SmoothBasisSpec::TensorBSpline { spec, .. } => matches!(
139 spec.identifiability,
140 TensorBSplineIdentifiability::FrozenTransform { .. }
141 ),
142 }
143}
144
145fn compare_smooth_ownership_priority(
146 lhs_idx: usize,
147 lhs: &SmoothTermSpec,
148 rhs_idx: usize,
149 rhs: &SmoothTermSpec,
150) -> std::cmp::Ordering {
151 let lhs_cols = smooth_term_feature_cols(lhs);
152 let rhs_cols = smooth_term_feature_cols(rhs);
153 lhs_cols
154 .len()
155 .cmp(&rhs_cols.len())
156 .then_with(|| lhs_cols.cmp(&rhs_cols))
157 .then_with(|| smooth_basis_family_rank(lhs).cmp(&smooth_basis_family_rank(rhs)))
158 .then_with(|| lhs.name.cmp(&rhs.name))
159 .then(lhs_idx.cmp(&rhs_idx))
160}
161
162fn factor_by_level_gate_of(term: &SmoothTermSpec) -> Option<(usize, u64)> {
171 match &term.basis {
172 SmoothBasisSpec::ByVariable {
173 by_col,
174 by: crate::smooth::ByVariableSpec::Level { value_bits, .. },
175 ..
176 } => Some((*by_col, *value_bits)),
177 _ => None,
178 }
179}
180
181fn factor_sum_to_zero_group_col(term: &SmoothTermSpec) -> Option<usize> {
200 match &term.basis {
201 SmoothBasisSpec::FactorSumToZero { by_col, .. } => Some(*by_col),
202 SmoothBasisSpec::FactorSmooth { spec }
203 if matches!(spec.flavour, FactorSmoothFlavour::Sz) =>
204 {
205 Some(spec.group_col)
206 }
207 _ => None,
208 }
209}
210
211fn smooth_is_owned_by_prior_term(owner: &SmoothTermSpec, target: &SmoothTermSpec) -> bool {
212 if let Some(target_gate) = factor_by_level_gate_of(target) {
220 if factor_by_level_gate_of(owner) != Some(target_gate) {
221 return false;
222 }
223 }
224 if let Some(group_col) = factor_sum_to_zero_group_col(target) {
230 let owner_features = smooth_term_feature_cols(owner)
231 .into_iter()
232 .collect::<BTreeSet<_>>();
233 if !owner_features.contains(&group_col) {
234 return false;
235 }
236 }
237 let owner_features = smooth_term_feature_cols(owner)
238 .into_iter()
239 .collect::<BTreeSet<_>>();
240 let target_features = smooth_term_feature_cols(target)
241 .into_iter()
242 .collect::<BTreeSet<_>>();
243 owner_features.is_subset(&target_features)
244}
245
/// Result of [`analyze_smooth_ownership`]: per-term feature columns plus the ownership
/// relation between smooth terms.
///
/// Every index stored here is a position in the slice that was analysed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmoothStructureAnalysis {
    /// Every term index, sorted by `compare_smooth_ownership_priority`: fewer feature
    /// columns first, with ties broken by the columns, the basis family rank, the name,
    /// and finally the input position.
    pub ownership_order: Vec<usize>,
    /// `term_feature_cols[i]` is `smooth_term_feature_cols(&smoothspecs[i])`.
    pub term_feature_cols: Vec<Vec<usize>>,
    /// `term_owners[i]` lists the terms that precede term `i` in `ownership_order` and own
    /// it according to `smooth_is_owned_by_prior_term`, listed in ownership order.
    pub term_owners: Vec<Vec<usize>>,
}
268
269pub fn analyze_smooth_ownership(smoothspecs: &[SmoothTermSpec]) -> SmoothStructureAnalysis {
273 let term_feature_cols: Vec<Vec<usize>> =
274 smoothspecs.iter().map(smooth_term_feature_cols).collect();
275
276 let mut ownership_order: Vec<usize> = (0..smoothspecs.len()).collect();
277 ownership_order.sort_by(|&lhs, &rhs| {
278 compare_smooth_ownership_priority(lhs, &smoothspecs[lhs], rhs, &smoothspecs[rhs])
279 });
280
281 let mut term_owners = vec![Vec::<usize>::new(); smoothspecs.len()];
282 for (pos, &target_idx) in ownership_order.iter().enumerate() {
283 let target = &smoothspecs[target_idx];
284 term_owners[target_idx] = ownership_order[..pos]
285 .iter()
286 .copied()
287 .filter(|&owner_idx| smooth_is_owned_by_prior_term(&smoothspecs[owner_idx], target))
288 .collect();
289 }
290
291 SmoothStructureAnalysis {
292 ownership_order,
293 term_feature_cols,
294 term_owners,
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{BSplineBasisSpec, BSplineKnotSpec, OneDimensionalBoundary};
    use crate::smooth::{BySmoothKind, ByVariableSpec, ShapeConstraint};

    /// Plain cubic B-spline basis on a single feature column.
    fn bspline(feature_col: usize) -> SmoothBasisSpec {
        SmoothBasisSpec::BSpline1D {
            feature_col,
            spec: BSplineBasisSpec {
                degree: 3,
                penalty_order: 2,
                knotspec: BSplineKnotSpec::Generate {
                    data_range: (0.0, 1.0),
                    num_internal_knots: 5,
                },
                double_penalty: false,
                identifiability: BSplineIdentifiability::None,
                boundary: OneDimensionalBoundary::Open,
                boundary_conditions: Default::default(),
            },
        }
    }

    /// Unconstrained smooth term wrapping `basis`.
    fn term(name: &str, basis: SmoothBasisSpec) -> SmoothTermSpec {
        SmoothTermSpec {
            name: name.to_string(),
            basis,
            shape: ShapeConstraint::None,
            joint_null_rotation: None,
        }
    }

    /// B-spline smooth of `feature_col`, gated on level `level_bits` of factor `by_col`.
    fn level_by_term(
        name: &str,
        feature_col: usize,
        by_col: usize,
        level_bits: u64,
    ) -> SmoothTermSpec {
        term(
            name,
            SmoothBasisSpec::ByVariable {
                inner: Box::new(bspline(feature_col)),
                by_col,
                kind: BySmoothKind::Level { level_bits },
                by: ByVariableSpec::Level {
                    value_bits: level_bits,
                    label: name.to_string(),
                },
            },
        )
    }

    #[test]
    fn ungated_smooth_does_not_own_factor_by_level_smooth() {
        let specs = vec![term("s(x)", bspline(0)), level_by_term("s(x):B", 0, 1, 42)];

        let analysis = analyze_smooth_ownership(&specs);

        assert_eq!(
            analysis.term_owners[1],
            Vec::<usize>::new(),
            "ungated s(x) must not own the row-gated by-factor deviation smooth"
        );
    }

    #[test]
    fn same_factor_by_level_gate_keeps_normal_subset_ownership() {
        let specs = vec![
            level_by_term("s(x):B", 0, 2, 42),
            term(
                "te(x,z):B",
                SmoothBasisSpec::ByVariable {
                    inner: Box::new(SmoothBasisSpec::TensorBSpline {
                        feature_cols: vec![0, 1],
                        spec: Default::default(),
                    }),
                    by_col: 2,
                    kind: BySmoothKind::Level { level_bits: 42 },
                    by: ByVariableSpec::Level {
                        value_bits: 42,
                        label: "B".to_string(),
                    },
                },
            ),
        ];

        let analysis = analyze_smooth_ownership(&specs);

        assert_eq!(
            analysis.term_owners[1],
            vec![0],
            "matching by-level gates may still use the ordinary nested smooth ownership rule"
        );
    }

    #[test]
    fn ungated_lower_dimensional_smooth_owns_ungated_superset() {
        // Input order deliberately puts the larger term first.
        let specs = vec![
            term(
                "te(x,z)",
                SmoothBasisSpec::TensorBSpline {
                    feature_cols: vec![0, 1],
                    spec: Default::default(),
                },
            ),
            term("s(x)", bspline(0)),
        ];

        let analysis = analyze_smooth_ownership(&specs);

        assert_eq!(analysis.term_feature_cols, vec![vec![0, 1], vec![0]]);
        assert_eq!(
            analysis.ownership_order,
            vec![1, 0],
            "terms with fewer feature columns are ordered first regardless of input position"
        );
        assert_eq!(
            analysis.term_owners[0],
            vec![1],
            "s(x) covers a subset of te(x,z)'s columns and neither is gated"
        );
        assert!(analysis.term_owners[1].is_empty());
    }

    #[test]
    fn different_factor_by_level_gates_do_not_own_each_other() {
        let specs = vec![
            level_by_term("s(x):A", 0, 2, 7),
            level_by_term("s(x):B", 0, 2, 42),
        ];

        let analysis = analyze_smooth_ownership(&specs);

        // Same columns and family, so the name breaks the tie.
        assert_eq!(analysis.ownership_order, vec![0, 1]);
        assert!(
            analysis.term_owners[1].is_empty(),
            "a smooth gated on a different factor level must not own the target"
        );
    }
}