Skip to main content

gam_terms/smooth/
structure_analysis.rs

1use crate::basis::{
2    BSplineIdentifiability, CenterStrategy, ConstantCurvatureIdentifiability,
3    MaternIdentifiability, MeasureJetIdentifiability, SpatialIdentifiability,
4    SphericalSplineIdentifiability,
5};
6
7use super::{ByVarKind, SmoothBasisSpec, SmoothTermSpec, TensorBSplineIdentifiability};
8
9use std::collections::BTreeSet;
10
11fn smooth_basis_feature_cols(basis: &SmoothBasisSpec) -> Vec<usize> {
12    match basis {
13        SmoothBasisSpec::ByVariable { inner, by_col, .. }
14        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
15            let mut cols = smooth_basis_feature_cols(inner);
16            cols.push(*by_col);
17            cols.sort_unstable();
18            cols.dedup();
19            cols
20        }
21        SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_feature_cols(smooth),
22        SmoothBasisSpec::BSpline1D { feature_col, .. } => vec![*feature_col],
23        SmoothBasisSpec::ThinPlate { feature_cols, .. }
24        | SmoothBasisSpec::Sphere { feature_cols, .. }
25        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
26        | SmoothBasisSpec::Matern { feature_cols, .. }
27        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
28        | SmoothBasisSpec::Duchon { feature_cols, .. }
29        | SmoothBasisSpec::Pca { feature_cols, .. }
30        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
31        SmoothBasisSpec::FactorSmooth { spec } => {
32            let mut cols = spec.continuous_cols.clone();
33            cols.push(spec.group_col);
34            cols.sort_unstable();
35            cols.dedup();
36            cols
37        }
38    }
39}
40
41pub fn smooth_term_feature_cols(term: &SmoothTermSpec) -> Vec<usize> {
42    smooth_basis_feature_cols(&term.basis)
43}
44
45fn smooth_basis_family_rank(term: &SmoothTermSpec) -> u8 {
46    match &term.basis {
47        SmoothBasisSpec::ByVariable { inner, .. }
48        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
49            smooth_basis_family_rank(&SmoothTermSpec {
50                name: term.name.clone(),
51                basis: (**inner).clone(),
52                shape: term.shape,
53                joint_null_rotation: None,
54            })
55        }
56        SmoothBasisSpec::BSpline1D { .. } => 0,
57        SmoothBasisSpec::TensorBSpline { .. } => 1,
58        SmoothBasisSpec::ThinPlate { .. } => 2,
59        SmoothBasisSpec::Sphere { .. } => 3,
60        SmoothBasisSpec::Matern { .. } => 4,
61        SmoothBasisSpec::Duchon { .. } => 5,
62        SmoothBasisSpec::Pca { .. } => 6,
63        SmoothBasisSpec::ConstantCurvature { .. } => 8,
64        SmoothBasisSpec::MeasureJet { .. } => 9,
65        SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_family_rank(&SmoothTermSpec {
66            name: term.name.clone(),
67            basis: (**smooth).clone(),
68            shape: term.shape,
69            joint_null_rotation: None,
70        }),
71        SmoothBasisSpec::FactorSmooth { .. } => 7,
72    }
73}
74
75pub fn smooth_has_frozen_identifiability(term: &SmoothTermSpec) -> bool {
76    match &term.basis {
77        SmoothBasisSpec::ByVariable { inner, .. }
78        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
79            smooth_has_frozen_identifiability(&SmoothTermSpec {
80                name: term.name.clone(),
81                basis: (**inner).clone(),
82                shape: term.shape,
83                joint_null_rotation: None,
84            })
85        }
86        SmoothBasisSpec::BSpline1D { spec, .. } => {
87            matches!(
88                spec.identifiability,
89                BSplineIdentifiability::FrozenTransform { .. }
90            )
91        }
92        SmoothBasisSpec::ThinPlate { spec, .. } => matches!(
93            spec.identifiability,
94            SpatialIdentifiability::FrozenTransform { .. }
95        ),
96        SmoothBasisSpec::Sphere { spec, .. } => {
97            matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
98                || matches!(
99                    spec.identifiability,
100                    SphericalSplineIdentifiability::FrozenTransform { .. }
101                )
102        }
103        SmoothBasisSpec::ConstantCurvature { spec, .. } => {
104            matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
105                || matches!(
106                    spec.identifiability,
107                    ConstantCurvatureIdentifiability::FrozenTransform { .. }
108                )
109        }
110        SmoothBasisSpec::MeasureJet { spec, .. } => {
111            matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
112                || matches!(
113                    spec.identifiability,
114                    MeasureJetIdentifiability::FrozenTransform { .. }
115                )
116        }
117        SmoothBasisSpec::Matern { spec, .. } => matches!(
118            spec.identifiability,
119            MaternIdentifiability::FrozenTransform { .. }
120        ),
121        SmoothBasisSpec::BySmooth { by_kind, .. } => match by_kind {
122            ByVarKind::Factor { frozen_levels, .. } => frozen_levels.is_some(),
123            ByVarKind::Numeric { .. } => true,
124        },
125        SmoothBasisSpec::FactorSmooth { spec } => spec.group_frozen_levels.is_some(),
126        SmoothBasisSpec::Duchon { spec, .. } => matches!(
127            spec.identifiability,
128            SpatialIdentifiability::FrozenTransform { .. }
129        ),
130        SmoothBasisSpec::Pca {
131            centered,
132            center_mean,
133            pca_basis_path,
134            ..
135        } => !*centered || center_mean.is_some() || pca_basis_path.is_some(),
136        SmoothBasisSpec::TensorBSpline { spec, .. } => matches!(
137            spec.identifiability,
138            TensorBSplineIdentifiability::FrozenTransform { .. }
139        ),
140    }
141}
142
143fn compare_smooth_ownership_priority(
144    lhs_idx: usize,
145    lhs: &SmoothTermSpec,
146    rhs_idx: usize,
147    rhs: &SmoothTermSpec,
148) -> std::cmp::Ordering {
149    let lhs_cols = smooth_term_feature_cols(lhs);
150    let rhs_cols = smooth_term_feature_cols(rhs);
151    lhs_cols
152        .len()
153        .cmp(&rhs_cols.len())
154        .then_with(|| lhs_cols.cmp(&rhs_cols))
155        .then_with(|| smooth_basis_family_rank(lhs).cmp(&smooth_basis_family_rank(rhs)))
156        .then_with(|| lhs.name.cmp(&rhs.name))
157        .then(lhs_idx.cmp(&rhs_idx))
158}
159
160/// The `(by_col, level_bits)` row-gate of a factor-`by=` level smooth
161/// (`s(x, by=fac)`, treatment-contrast level), or `None` for any other smooth
162/// (including numeric-`by` scaling, which is NOT row-gated).
163///
164/// A level-gated smooth's design is zero on every row outside its level, so its
165/// columns are NOT in the column span of an un-gated (full-support) smooth on
166/// the same covariate. Ownership/orthogonalization must therefore skip it
167/// (otherwise the per-group deviation is residualized away to zero — #1276).
168fn factor_by_level_gate_of(term: &SmoothTermSpec) -> Option<(usize, u64)> {
169    match &term.basis {
170        SmoothBasisSpec::ByVariable {
171            by_col,
172            by: crate::smooth::ByVariableSpec::Level { value_bits, .. },
173            ..
174        } => Some((*by_col, *value_bits)),
175        _ => None,
176    }
177}
178
179fn smooth_is_owned_by_prior_term(owner: &SmoothTermSpec, target: &SmoothTermSpec) -> bool {
180    // A factor-`by=` level smooth is row-gated (zero off its level), so its
181    // columns lie outside the span of any owner that is not gated to the SAME
182    // (by_col, level): the un-gated population smooth `s(x)` does not span the
183    // group deviation `s(x, by=g==level)`. Residualizing the gated deviation
184    // against the population smooth collapses it to zero (#1276). Identifiability
185    // of the deviation comes from its own factor-level gate + penalty, handled
186    // by `factor_by_level_gate` in design construction — not from ownership.
187    if let Some(target_gate) = factor_by_level_gate_of(target) {
188        if factor_by_level_gate_of(owner) != Some(target_gate) {
189            return false;
190        }
191    }
192    let owner_features = smooth_term_feature_cols(owner)
193        .into_iter()
194        .collect::<BTreeSet<_>>();
195    let target_features = smooth_term_feature_cols(target)
196        .into_iter()
197        .collect::<BTreeSet<_>>();
198    owner_features.is_subset(&target_features)
199}
200
/// Static (spec-only) description of the hierarchical smooth-ownership decomposition.
///
/// This is the single source of truth for the deterministic ownership policy that
/// `apply_global_smooth_identifiability` uses during the fit: the processing order of
/// smooth terms, the feature columns each term spans, and the candidate lower-order owners
/// of each term (nested/duplicate feature sets). The fit engine consumes this structure and
/// additionally applies a numerical cross-residual overlap test on the realized design
/// columns; the CLI structure-warning path consumes the same structure for diagnostic
/// messages, so both paths agree on which smooths own which subspaces.
///
/// The type holds plain index vectors only, so it derives `Debug`, `Clone`, `PartialEq`
/// and `Eq` to make it loggable and directly comparable in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmoothStructureAnalysis {
    /// Smooth-term indices in ownership-processing order: ascending by number of feature
    /// columns, then by the column lists, basis-family rank, term name, and original
    /// index. Lower-order / narrower smooths therefore come first and own their subspaces.
    pub ownership_order: Vec<usize>,
    /// `term_feature_cols[idx]` are the feature columns that smooth term `idx` spans, as
    /// returned by `smooth_term_feature_cols` (indexed by the original smooth-term index,
    /// not by `ownership_order`). By-variable, factor sum-to-zero and factor-smooth terms
    /// yield a sorted, deduplicated list; the other families report their declared columns
    /// in spec order.
    pub term_feature_cols: Vec<Vec<usize>>,
    /// `term_owners[idx]` are the indices of prior (in `ownership_order`) smooth terms that
    /// are candidate owners of `idx`: their feature set is a subset of term `idx`'s feature
    /// set and, when `idx` is a factor-`by=` level smooth, they are gated to the same
    /// `(by_col, level)`. The list is given in ownership-processing order.
    pub term_owners: Vec<Vec<usize>>,
}
223
224/// Compute the static hierarchical smooth-ownership decomposition from the smooth-term specs.
225///
226/// `smoothspecs` is the same slice that `apply_global_smooth_identifiability` receives.
227pub fn analyze_smooth_ownership(smoothspecs: &[SmoothTermSpec]) -> SmoothStructureAnalysis {
228    let term_feature_cols: Vec<Vec<usize>> =
229        smoothspecs.iter().map(smooth_term_feature_cols).collect();
230
231    let mut ownership_order: Vec<usize> = (0..smoothspecs.len()).collect();
232    ownership_order.sort_by(|&lhs, &rhs| {
233        compare_smooth_ownership_priority(lhs, &smoothspecs[lhs], rhs, &smoothspecs[rhs])
234    });
235
236    let mut term_owners = vec![Vec::<usize>::new(); smoothspecs.len()];
237    for (pos, &target_idx) in ownership_order.iter().enumerate() {
238        let target = &smoothspecs[target_idx];
239        term_owners[target_idx] = ownership_order[..pos]
240            .iter()
241            .copied()
242            .filter(|&owner_idx| smooth_is_owned_by_prior_term(&smoothspecs[owner_idx], target))
243            .collect();
244    }
245
246    SmoothStructureAnalysis {
247        ownership_order,
248        term_feature_cols,
249        term_owners,
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{BSplineBasisSpec, BSplineKnotSpec, OneDimensionalBoundary};
    use crate::smooth::{BySmoothKind, ByVariableSpec, ShapeConstraint};

    /// Cubic 1-D B-spline basis on `feature_col` with no identifiability constraint.
    fn bspline(feature_col: usize) -> SmoothBasisSpec {
        let spec = BSplineBasisSpec {
            degree: 3,
            penalty_order: 2,
            knotspec: BSplineKnotSpec::Generate {
                data_range: (0.0, 1.0),
                num_internal_knots: 5,
            },
            double_penalty: false,
            identifiability: BSplineIdentifiability::None,
            boundary: OneDimensionalBoundary::Open,
            boundary_conditions: Default::default(),
        };
        SmoothBasisSpec::BSpline1D { feature_col, spec }
    }

    /// Unconstrained smooth term named `name` over `basis`.
    fn term(name: &str, basis: SmoothBasisSpec) -> SmoothTermSpec {
        SmoothTermSpec {
            name: name.to_owned(),
            basis,
            shape: ShapeConstraint::None,
            joint_null_rotation: None,
        }
    }

    /// Wraps `inner` in a factor-`by=` level gate on `by_col` for level `level_bits`.
    fn gated(
        inner: SmoothBasisSpec,
        by_col: usize,
        level_bits: u64,
        label: &str,
    ) -> SmoothBasisSpec {
        SmoothBasisSpec::ByVariable {
            inner: Box::new(inner),
            by_col,
            kind: BySmoothKind::Level { level_bits },
            by: ByVariableSpec::Level {
                value_bits: level_bits,
                label: label.to_owned(),
            },
        }
    }

    /// Level-gated 1-D B-spline term; the gate label is the term name itself.
    fn level_by_term(
        name: &str,
        feature_col: usize,
        by_col: usize,
        level_bits: u64,
    ) -> SmoothTermSpec {
        term(name, gated(bspline(feature_col), by_col, level_bits, name))
    }

    #[test]
    fn ungated_smooth_does_not_own_factor_by_level_smooth() {
        let population = term("s(x)", bspline(0));
        let deviation = level_by_term("s(x):B", 0, 1, 42);

        let analysis = analyze_smooth_ownership(&[population, deviation]);

        assert!(
            analysis.term_owners[1].is_empty(),
            "ungated s(x) must not own the row-gated by-factor deviation smooth"
        );
    }

    #[test]
    fn same_factor_by_level_gate_keeps_normal_subset_ownership() {
        let tensor = SmoothBasisSpec::TensorBSpline {
            feature_cols: vec![0, 1],
            spec: Default::default(),
        };
        let specs = vec![
            level_by_term("s(x):B", 0, 2, 42),
            term("te(x,z):B", gated(tensor, 2, 42, "B")),
        ];

        let analysis = analyze_smooth_ownership(&specs);

        assert_eq!(
            analysis.term_owners[1],
            [0],
            "matching by-level gates may still use the ordinary nested smooth ownership rule"
        );
    }
}