1use crate::basis::{
2 BSplineIdentifiability, CenterStrategy, ConstantCurvatureIdentifiability,
3 MaternIdentifiability, MeasureJetIdentifiability, SpatialIdentifiability,
4 SphericalSplineIdentifiability,
5};
6
7use super::{ByVarKind, SmoothBasisSpec, SmoothTermSpec, TensorBSplineIdentifiability};
8
9use std::collections::BTreeSet;
10
11fn smooth_basis_feature_cols(basis: &SmoothBasisSpec) -> Vec<usize> {
12 match basis {
13 SmoothBasisSpec::ByVariable { inner, by_col, .. }
14 | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
15 let mut cols = smooth_basis_feature_cols(inner);
16 cols.push(*by_col);
17 cols.sort_unstable();
18 cols.dedup();
19 cols
20 }
21 SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_feature_cols(smooth),
22 SmoothBasisSpec::BSpline1D { feature_col, .. } => vec![*feature_col],
23 SmoothBasisSpec::ThinPlate { feature_cols, .. }
24 | SmoothBasisSpec::Sphere { feature_cols, .. }
25 | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
26 | SmoothBasisSpec::Matern { feature_cols, .. }
27 | SmoothBasisSpec::MeasureJet { feature_cols, .. }
28 | SmoothBasisSpec::Duchon { feature_cols, .. }
29 | SmoothBasisSpec::Pca { feature_cols, .. }
30 | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
31 SmoothBasisSpec::FactorSmooth { spec } => {
32 let mut cols = spec.continuous_cols.clone();
33 cols.push(spec.group_col);
34 cols.sort_unstable();
35 cols.dedup();
36 cols
37 }
38 }
39}
40
41pub fn smooth_term_feature_cols(term: &SmoothTermSpec) -> Vec<usize> {
42 smooth_basis_feature_cols(&term.basis)
43}
44
45fn smooth_basis_family_rank(term: &SmoothTermSpec) -> u8 {
46 match &term.basis {
47 SmoothBasisSpec::ByVariable { inner, .. }
48 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
49 smooth_basis_family_rank(&SmoothTermSpec {
50 name: term.name.clone(),
51 basis: (**inner).clone(),
52 shape: term.shape,
53 joint_null_rotation: None,
54 })
55 }
56 SmoothBasisSpec::BSpline1D { .. } => 0,
57 SmoothBasisSpec::TensorBSpline { .. } => 1,
58 SmoothBasisSpec::ThinPlate { .. } => 2,
59 SmoothBasisSpec::Sphere { .. } => 3,
60 SmoothBasisSpec::Matern { .. } => 4,
61 SmoothBasisSpec::Duchon { .. } => 5,
62 SmoothBasisSpec::Pca { .. } => 6,
63 SmoothBasisSpec::ConstantCurvature { .. } => 8,
64 SmoothBasisSpec::MeasureJet { .. } => 9,
65 SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_family_rank(&SmoothTermSpec {
66 name: term.name.clone(),
67 basis: (**smooth).clone(),
68 shape: term.shape,
69 joint_null_rotation: None,
70 }),
71 SmoothBasisSpec::FactorSmooth { .. } => 7,
72 }
73}
74
75pub fn smooth_has_frozen_identifiability(term: &SmoothTermSpec) -> bool {
76 match &term.basis {
77 SmoothBasisSpec::ByVariable { inner, .. }
78 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
79 smooth_has_frozen_identifiability(&SmoothTermSpec {
80 name: term.name.clone(),
81 basis: (**inner).clone(),
82 shape: term.shape,
83 joint_null_rotation: None,
84 })
85 }
86 SmoothBasisSpec::BSpline1D { spec, .. } => {
87 matches!(
88 spec.identifiability,
89 BSplineIdentifiability::FrozenTransform { .. }
90 )
91 }
92 SmoothBasisSpec::ThinPlate { spec, .. } => matches!(
93 spec.identifiability,
94 SpatialIdentifiability::FrozenTransform { .. }
95 ),
96 SmoothBasisSpec::Sphere { spec, .. } => {
97 matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
98 || matches!(
99 spec.identifiability,
100 SphericalSplineIdentifiability::FrozenTransform { .. }
101 )
102 }
103 SmoothBasisSpec::ConstantCurvature { spec, .. } => {
104 matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
105 || matches!(
106 spec.identifiability,
107 ConstantCurvatureIdentifiability::FrozenTransform { .. }
108 )
109 }
110 SmoothBasisSpec::MeasureJet { spec, .. } => {
111 matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
112 || matches!(
113 spec.identifiability,
114 MeasureJetIdentifiability::FrozenTransform { .. }
115 )
116 }
117 SmoothBasisSpec::Matern { spec, .. } => matches!(
118 spec.identifiability,
119 MaternIdentifiability::FrozenTransform { .. }
120 ),
121 SmoothBasisSpec::BySmooth { by_kind, .. } => match by_kind {
122 ByVarKind::Factor { frozen_levels, .. } => frozen_levels.is_some(),
123 ByVarKind::Numeric { .. } => true,
124 },
125 SmoothBasisSpec::FactorSmooth { spec } => spec.group_frozen_levels.is_some(),
126 SmoothBasisSpec::Duchon { spec, .. } => matches!(
127 spec.identifiability,
128 SpatialIdentifiability::FrozenTransform { .. }
129 ),
130 SmoothBasisSpec::Pca {
131 centered,
132 center_mean,
133 pca_basis_path,
134 ..
135 } => !*centered || center_mean.is_some() || pca_basis_path.is_some(),
136 SmoothBasisSpec::TensorBSpline { spec, .. } => matches!(
137 spec.identifiability,
138 TensorBSplineIdentifiability::FrozenTransform { .. }
139 ),
140 }
141}
142
143fn compare_smooth_ownership_priority(
144 lhs_idx: usize,
145 lhs: &SmoothTermSpec,
146 rhs_idx: usize,
147 rhs: &SmoothTermSpec,
148) -> std::cmp::Ordering {
149 let lhs_cols = smooth_term_feature_cols(lhs);
150 let rhs_cols = smooth_term_feature_cols(rhs);
151 lhs_cols
152 .len()
153 .cmp(&rhs_cols.len())
154 .then_with(|| lhs_cols.cmp(&rhs_cols))
155 .then_with(|| smooth_basis_family_rank(lhs).cmp(&smooth_basis_family_rank(rhs)))
156 .then_with(|| lhs.name.cmp(&rhs.name))
157 .then(lhs_idx.cmp(&rhs_idx))
158}
159
160fn factor_by_level_gate_of(term: &SmoothTermSpec) -> Option<(usize, u64)> {
169 match &term.basis {
170 SmoothBasisSpec::ByVariable {
171 by_col,
172 by: crate::smooth::ByVariableSpec::Level { value_bits, .. },
173 ..
174 } => Some((*by_col, *value_bits)),
175 _ => None,
176 }
177}
178
179fn smooth_is_owned_by_prior_term(owner: &SmoothTermSpec, target: &SmoothTermSpec) -> bool {
180 if let Some(target_gate) = factor_by_level_gate_of(target) {
188 if factor_by_level_gate_of(owner) != Some(target_gate) {
189 return false;
190 }
191 }
192 let owner_features = smooth_term_feature_cols(owner)
193 .into_iter()
194 .collect::<BTreeSet<_>>();
195 let target_features = smooth_term_feature_cols(target)
196 .into_iter()
197 .collect::<BTreeSet<_>>();
198 owner_features.is_subset(&target_features)
199}
200
/// Result of [`analyze_smooth_ownership`] for a slice of smooth terms.
///
/// All indices stored here are positions in the slice that was analyzed.
pub struct SmoothStructureAnalysis {
    /// Term indices sorted by ownership priority, highest-priority term first.
    pub ownership_order: Vec<usize>,
    /// Feature columns of each term; entry `i` belongs to input term `i`.
    pub term_feature_cols: Vec<Vec<usize>>,
    /// For input term `i`, the indices of the terms that come before it in
    /// `ownership_order` and own it, listed in ownership order.
    pub term_owners: Vec<Vec<usize>>,
}
223
224pub fn analyze_smooth_ownership(smoothspecs: &[SmoothTermSpec]) -> SmoothStructureAnalysis {
228 let term_feature_cols: Vec<Vec<usize>> =
229 smoothspecs.iter().map(smooth_term_feature_cols).collect();
230
231 let mut ownership_order: Vec<usize> = (0..smoothspecs.len()).collect();
232 ownership_order.sort_by(|&lhs, &rhs| {
233 compare_smooth_ownership_priority(lhs, &smoothspecs[lhs], rhs, &smoothspecs[rhs])
234 });
235
236 let mut term_owners = vec![Vec::<usize>::new(); smoothspecs.len()];
237 for (pos, &target_idx) in ownership_order.iter().enumerate() {
238 let target = &smoothspecs[target_idx];
239 term_owners[target_idx] = ownership_order[..pos]
240 .iter()
241 .copied()
242 .filter(|&owner_idx| smooth_is_owned_by_prior_term(&smoothspecs[owner_idx], target))
243 .collect();
244 }
245
246 SmoothStructureAnalysis {
247 ownership_order,
248 term_feature_cols,
249 term_owners,
250 }
251}
252
// Unit tests for the ownership analysis, focused on how by-level gates
// interact with the feature-subset ownership rule.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{BSplineBasisSpec, BSplineKnotSpec, OneDimensionalBoundary};
    use crate::smooth::{BySmoothKind, ByVariableSpec, ShapeConstraint};

    /// Builds a 1-D B-spline basis on `feature_col` with fixed, arbitrary settings.
    fn bspline(feature_col: usize) -> SmoothBasisSpec {
        SmoothBasisSpec::BSpline1D {
            feature_col,
            spec: BSplineBasisSpec {
                degree: 3,
                penalty_order: 2,
                knotspec: BSplineKnotSpec::Generate {
                    data_range: (0.0, 1.0),
                    num_internal_knots: 5,
                },
                double_penalty: false,
                identifiability: BSplineIdentifiability::None,
                boundary: OneDimensionalBoundary::Open,
                boundary_conditions: Default::default(),
            },
        }
    }

    /// Wraps `basis` in an unconstrained term named `name`.
    fn term(name: &str, basis: SmoothBasisSpec) -> SmoothTermSpec {
        SmoothTermSpec {
            name: name.to_string(),
            basis,
            shape: ShapeConstraint::None,
            joint_null_rotation: None,
        }
    }

    /// Builds a B-spline term on `feature_col` gated to one level
    /// (`level_bits`) of the by column `by_col`.
    fn level_by_term(
        name: &str,
        feature_col: usize,
        by_col: usize,
        level_bits: u64,
    ) -> SmoothTermSpec {
        term(
            name,
            SmoothBasisSpec::ByVariable {
                inner: Box::new(bspline(feature_col)),
                by_col,
                kind: BySmoothKind::Level { level_bits },
                by: ByVariableSpec::Level {
                    value_bits: level_bits,
                    label: name.to_string(),
                },
            },
        )
    }

    // An ungated smooth on the same feature column must not be reported as an
    // owner of a by-level-gated smooth, even though its columns are a subset.
    #[test]
    fn ungated_smooth_does_not_own_factor_by_level_smooth() {
        let specs = vec![term("s(x)", bspline(0)), level_by_term("s(x):B", 0, 1, 42)];

        let analysis = analyze_smooth_ownership(&specs);

        assert_eq!(
            analysis.term_owners[1],
            Vec::<usize>::new(),
            "ungated s(x) must not own the row-gated by-factor deviation smooth"
        );
    }

    // When both terms carry the same gate (same by column and level bits), the
    // usual feature-subset rule applies: the 1-D term owns the tensor term.
    #[test]
    fn same_factor_by_level_gate_keeps_normal_subset_ownership() {
        let specs = vec![
            level_by_term("s(x):B", 0, 2, 42),
            term(
                "te(x,z):B",
                SmoothBasisSpec::ByVariable {
                    inner: Box::new(SmoothBasisSpec::TensorBSpline {
                        feature_cols: vec![0, 1],
                        spec: Default::default(),
                    }),
                    by_col: 2,
                    kind: BySmoothKind::Level { level_bits: 42 },
                    by: ByVariableSpec::Level {
                        value_bits: 42,
                        label: "B".to_string(),
                    },
                },
            ),
        ];

        let analysis = analyze_smooth_ownership(&specs);

        assert_eq!(
            analysis.term_owners[1],
            vec![0],
            "matching by-level gates may still use the ordinary nested smooth ownership rule"
        );
    }
}