// gam_terms/smooth/structure_analysis.rs

1use crate::basis::{
2    BSplineIdentifiability, CenterStrategy, ConstantCurvatureIdentifiability,
3    MaternIdentifiability, MeasureJetIdentifiability, SpatialIdentifiability,
4    SphericalSplineIdentifiability,
5};
6
7use super::{
8    ByVarKind, FactorSmoothFlavour, SmoothBasisSpec, SmoothTermSpec, TensorBSplineIdentifiability,
9};
10
11use std::collections::BTreeSet;
12
13fn smooth_basis_feature_cols(basis: &SmoothBasisSpec) -> Vec<usize> {
14    match basis {
15        SmoothBasisSpec::ByVariable { inner, by_col, .. }
16        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
17            let mut cols = smooth_basis_feature_cols(inner);
18            cols.push(*by_col);
19            cols.sort_unstable();
20            cols.dedup();
21            cols
22        }
23        SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_feature_cols(smooth),
24        SmoothBasisSpec::BSpline1D { feature_col, .. } => vec![*feature_col],
25        SmoothBasisSpec::ThinPlate { feature_cols, .. }
26        | SmoothBasisSpec::Sphere { feature_cols, .. }
27        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
28        | SmoothBasisSpec::Matern { feature_cols, .. }
29        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
30        | SmoothBasisSpec::Duchon { feature_cols, .. }
31        | SmoothBasisSpec::Pca { feature_cols, .. }
32        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
33        SmoothBasisSpec::FactorSmooth { spec } => {
34            let mut cols = spec.continuous_cols.clone();
35            cols.push(spec.group_col);
36            cols.sort_unstable();
37            cols.dedup();
38            cols
39        }
40    }
41}
42
43pub fn smooth_term_feature_cols(term: &SmoothTermSpec) -> Vec<usize> {
44    smooth_basis_feature_cols(&term.basis)
45}
46
47fn smooth_basis_family_rank(term: &SmoothTermSpec) -> u8 {
48    match &term.basis {
49        SmoothBasisSpec::ByVariable { inner, .. }
50        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
51            smooth_basis_family_rank(&SmoothTermSpec {
52                name: term.name.clone(),
53                basis: (**inner).clone(),
54                shape: term.shape,
55                joint_null_rotation: None,
56            })
57        }
58        SmoothBasisSpec::BSpline1D { .. } => 0,
59        SmoothBasisSpec::TensorBSpline { .. } => 1,
60        SmoothBasisSpec::ThinPlate { .. } => 2,
61        SmoothBasisSpec::Sphere { .. } => 3,
62        SmoothBasisSpec::Matern { .. } => 4,
63        SmoothBasisSpec::Duchon { .. } => 5,
64        SmoothBasisSpec::Pca { .. } => 6,
65        SmoothBasisSpec::ConstantCurvature { .. } => 8,
66        SmoothBasisSpec::MeasureJet { .. } => 9,
67        SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_family_rank(&SmoothTermSpec {
68            name: term.name.clone(),
69            basis: (**smooth).clone(),
70            shape: term.shape,
71            joint_null_rotation: None,
72        }),
73        SmoothBasisSpec::FactorSmooth { .. } => 7,
74    }
75}
76
77pub fn smooth_has_frozen_identifiability(term: &SmoothTermSpec) -> bool {
78    match &term.basis {
79        SmoothBasisSpec::ByVariable { inner, .. }
80        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
81            smooth_has_frozen_identifiability(&SmoothTermSpec {
82                name: term.name.clone(),
83                basis: (**inner).clone(),
84                shape: term.shape,
85                joint_null_rotation: None,
86            })
87        }
88        SmoothBasisSpec::BSpline1D { spec, .. } => {
89            matches!(
90                spec.identifiability,
91                BSplineIdentifiability::FrozenTransform { .. }
92            )
93        }
94        SmoothBasisSpec::ThinPlate { spec, .. } => matches!(
95            spec.identifiability,
96            SpatialIdentifiability::FrozenTransform { .. }
97        ),
98        SmoothBasisSpec::Sphere { spec, .. } => {
99            matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
100                || matches!(
101                    spec.identifiability,
102                    SphericalSplineIdentifiability::FrozenTransform { .. }
103                )
104        }
105        SmoothBasisSpec::ConstantCurvature { spec, .. } => {
106            matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
107                || matches!(
108                    spec.identifiability,
109                    ConstantCurvatureIdentifiability::FrozenTransform { .. }
110                )
111        }
112        SmoothBasisSpec::MeasureJet { spec, .. } => {
113            matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
114                || matches!(
115                    spec.identifiability,
116                    MeasureJetIdentifiability::FrozenTransform { .. }
117                )
118        }
119        SmoothBasisSpec::Matern { spec, .. } => matches!(
120            spec.identifiability,
121            MaternIdentifiability::FrozenTransform { .. }
122        ),
123        SmoothBasisSpec::BySmooth { by_kind, .. } => match by_kind {
124            ByVarKind::Factor { frozen_levels, .. } => frozen_levels.is_some(),
125            ByVarKind::Numeric { .. } => true,
126        },
127        SmoothBasisSpec::FactorSmooth { spec } => spec.group_frozen_levels.is_some(),
128        SmoothBasisSpec::Duchon { spec, .. } => matches!(
129            spec.identifiability,
130            SpatialIdentifiability::FrozenTransform { .. }
131        ),
132        SmoothBasisSpec::Pca {
133            centered,
134            center_mean,
135            pca_basis_path,
136            ..
137        } => !*centered || center_mean.is_some() || pca_basis_path.is_some(),
138        SmoothBasisSpec::TensorBSpline { spec, .. } => matches!(
139            spec.identifiability,
140            TensorBSplineIdentifiability::FrozenTransform { .. }
141        ),
142    }
143}
144
145fn compare_smooth_ownership_priority(
146    lhs_idx: usize,
147    lhs: &SmoothTermSpec,
148    rhs_idx: usize,
149    rhs: &SmoothTermSpec,
150) -> std::cmp::Ordering {
151    let lhs_cols = smooth_term_feature_cols(lhs);
152    let rhs_cols = smooth_term_feature_cols(rhs);
153    lhs_cols
154        .len()
155        .cmp(&rhs_cols.len())
156        .then_with(|| lhs_cols.cmp(&rhs_cols))
157        .then_with(|| smooth_basis_family_rank(lhs).cmp(&smooth_basis_family_rank(rhs)))
158        .then_with(|| lhs.name.cmp(&rhs.name))
159        .then(lhs_idx.cmp(&rhs_idx))
160}
161
162/// The `(by_col, level_bits)` row-gate of a factor-`by=` level smooth
163/// (`s(x, by=fac)`, treatment-contrast level), or `None` for any other smooth
164/// (including numeric-`by` scaling, which is NOT row-gated).
165///
166/// A level-gated smooth's design is zero on every row outside its level, so its
167/// columns are NOT in the column span of an un-gated (full-support) smooth on
168/// the same covariate. Ownership/orthogonalization must therefore skip it
169/// (otherwise the per-group deviation is residualized away to zero — #1276).
170fn factor_by_level_gate_of(term: &SmoothTermSpec) -> Option<(usize, u64)> {
171    match &term.basis {
172        SmoothBasisSpec::ByVariable {
173            by_col,
174            by: crate::smooth::ByVariableSpec::Level { value_bits, .. },
175            ..
176        } => Some((*by_col, *value_bits)),
177        _ => None,
178    }
179}
180
181/// The grouping column of a sum-to-zero factor *deviation* smooth
182/// (`s(g, x, bs="sz")`, lowered to either `FactorSmooth { Sz }` or the internal
183/// `FactorSumToZero`), or `None` for any other smooth.
184///
185/// An sz smooth's per-level design columns are sum-to-zero ACROSS the grouping
186/// factor at every covariate value (the last level's block is `-Σ` of the
187/// others, so each contrast column's across-group sum vanishes pointwise). They
188/// are therefore orthogonal to — never spanned by — any owner smooth that does
189/// not itself vary with that factor. Ownership/orthogonalization against such an
190/// owner (e.g. the shared `s(x)` in the canonical `s(x) + s(g, x, bs="sz")`)
191/// would residualize the genuine deviation away, collapsing every group's curve
192/// to a constant offset (#1605 — the same failure family as the factor-`by`
193/// level gate #1276). The sz block is self-identified by its own sum-to-zero
194/// contrast + penalty, so it needs no owner block.
195///
196/// This is specific to the `Sz` flavour: `Fs`/`Re` random-effect factor smooths
197/// carry a non-zero group mean that genuinely overlaps a shared `s(x)`, so their
198/// deliberate `s(x) + fs` residualization (#978) is preserved.
199fn factor_sum_to_zero_group_col(term: &SmoothTermSpec) -> Option<usize> {
200    match &term.basis {
201        SmoothBasisSpec::FactorSumToZero { by_col, .. } => Some(*by_col),
202        SmoothBasisSpec::FactorSmooth { spec }
203            if matches!(spec.flavour, FactorSmoothFlavour::Sz) =>
204        {
205            Some(spec.group_col)
206        }
207        _ => None,
208    }
209}
210
211fn smooth_is_owned_by_prior_term(owner: &SmoothTermSpec, target: &SmoothTermSpec) -> bool {
212    // A factor-`by=` level smooth is row-gated (zero off its level), so its
213    // columns lie outside the span of any owner that is not gated to the SAME
214    // (by_col, level): the un-gated population smooth `s(x)` does not span the
215    // group deviation `s(x, by=g==level)`. Residualizing the gated deviation
216    // against the population smooth collapses it to zero (#1276). Identifiability
217    // of the deviation comes from its own factor-level gate + penalty, handled
218    // by `factor_by_level_gate` in design construction — not from ownership.
219    if let Some(target_gate) = factor_by_level_gate_of(target) {
220        if factor_by_level_gate_of(owner) != Some(target_gate) {
221            return false;
222        }
223    }
224    // A sum-to-zero factor deviation smooth is orthogonal to any owner that does
225    // not vary with its grouping factor (its columns sum to zero across that
226    // factor). Such an owner cannot span it, so skip ownership — otherwise the
227    // deviation is residualized down to a per-group constant and the curve
228    // shape is lost (#1605).
229    if let Some(group_col) = factor_sum_to_zero_group_col(target) {
230        let owner_features = smooth_term_feature_cols(owner)
231            .into_iter()
232            .collect::<BTreeSet<_>>();
233        if !owner_features.contains(&group_col) {
234            return false;
235        }
236    }
237    let owner_features = smooth_term_feature_cols(owner)
238        .into_iter()
239        .collect::<BTreeSet<_>>();
240    let target_features = smooth_term_feature_cols(target)
241        .into_iter()
242        .collect::<BTreeSet<_>>();
243    owner_features.is_subset(&target_features)
244}
245
/// Static (spec-only) description of the hierarchical smooth-ownership decomposition.
///
/// This is the single source of truth for the deterministic ownership policy that
/// `apply_global_smooth_identifiability` uses during the fit: the processing order of
/// smooth terms, the feature columns each term spans, and the candidate lower-order owners
/// of each term (nested/duplicate feature sets). The fit engine consumes this structure and
/// additionally applies a numerical cross-residual overlap test on the realized design
/// columns; the CLI structure-warning path consumes the same structure for diagnostic
/// messages, so both paths agree on which smooths own which subspaces.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmoothStructureAnalysis {
    /// Smooth-term indices in ownership-processing order: terms with fewer feature columns
    /// come first; ties are broken by the column lists, basis-family rank, term name and
    /// finally the original index.
    pub ownership_order: Vec<usize>,
    /// `term_feature_cols[idx]` are the feature columns of smooth term `idx` as reported by
    /// `smooth_term_feature_cols` (indexed by the original smooth-term index, not by
    /// `ownership_order`). By-variable / factor bases yield a sorted, deduplicated list;
    /// plain multi-column bases yield their `feature_cols` in the order stored in the spec.
    pub term_feature_cols: Vec<Vec<usize>>,
    /// `term_owners[idx]` are the indices of prior (in `ownership_order`) smooth terms whose
    /// feature set is a subset of term `idx`'s feature set, i.e. candidate owners of `idx`.
    /// Two kinds of target never take an incompatible owner: a factor-`by=` level smooth is
    /// only owned by terms gated to the same `(by_col, level)`, and a sum-to-zero factor
    /// deviation smooth is only owned by terms that include its grouping column.
    /// The list is given in ownership-processing order.
    pub term_owners: Vec<Vec<usize>>,
}
268
269/// Compute the static hierarchical smooth-ownership decomposition from the smooth-term specs.
270///
271/// `smoothspecs` is the same slice that `apply_global_smooth_identifiability` receives.
272pub fn analyze_smooth_ownership(smoothspecs: &[SmoothTermSpec]) -> SmoothStructureAnalysis {
273    let term_feature_cols: Vec<Vec<usize>> =
274        smoothspecs.iter().map(smooth_term_feature_cols).collect();
275
276    let mut ownership_order: Vec<usize> = (0..smoothspecs.len()).collect();
277    ownership_order.sort_by(|&lhs, &rhs| {
278        compare_smooth_ownership_priority(lhs, &smoothspecs[lhs], rhs, &smoothspecs[rhs])
279    });
280
281    let mut term_owners = vec![Vec::<usize>::new(); smoothspecs.len()];
282    for (pos, &target_idx) in ownership_order.iter().enumerate() {
283        let target = &smoothspecs[target_idx];
284        term_owners[target_idx] = ownership_order[..pos]
285            .iter()
286            .copied()
287            .filter(|&owner_idx| smooth_is_owned_by_prior_term(&smoothspecs[owner_idx], target))
288            .collect();
289    }
290
291    SmoothStructureAnalysis {
292        ownership_order,
293        term_feature_cols,
294        term_owners,
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{BSplineBasisSpec, BSplineKnotSpec, OneDimensionalBoundary};
    use crate::smooth::{BySmoothKind, ByVariableSpec, ShapeConstraint};

    /// 1-D B-spline basis on `feature_col` with fixed test settings.
    fn bspline(feature_col: usize) -> SmoothBasisSpec {
        let spec = BSplineBasisSpec {
            degree: 3,
            penalty_order: 2,
            knotspec: BSplineKnotSpec::Generate {
                data_range: (0.0, 1.0),
                num_internal_knots: 5,
            },
            double_penalty: false,
            identifiability: BSplineIdentifiability::None,
            boundary: OneDimensionalBoundary::Open,
            boundary_conditions: Default::default(),
        };
        SmoothBasisSpec::BSpline1D { feature_col, spec }
    }

    /// Unconstrained smooth term wrapping `basis`.
    fn term(name: &str, basis: SmoothBasisSpec) -> SmoothTermSpec {
        SmoothTermSpec {
            name: name.to_string(),
            basis,
            shape: ShapeConstraint::None,
            joint_null_rotation: None,
        }
    }

    /// Wraps `inner` in a factor-`by=` level gate on `(by_col, level_bits)`.
    fn level_gated(
        inner: SmoothBasisSpec,
        by_col: usize,
        level_bits: u64,
        label: &str,
    ) -> SmoothBasisSpec {
        SmoothBasisSpec::ByVariable {
            inner: Box::new(inner),
            by_col,
            kind: BySmoothKind::Level { level_bits },
            by: ByVariableSpec::Level {
                value_bits: level_bits,
                label: label.to_string(),
            },
        }
    }

    /// Level-gated 1-D B-spline term; the gate label reuses the term name.
    fn level_by_term(
        name: &str,
        feature_col: usize,
        by_col: usize,
        level_bits: u64,
    ) -> SmoothTermSpec {
        term(
            name,
            level_gated(bspline(feature_col), by_col, level_bits, name),
        )
    }

    #[test]
    fn ungated_smooth_does_not_own_factor_by_level_smooth() {
        let population = term("s(x)", bspline(0));
        let deviation = level_by_term("s(x):B", 0, 1, 42);

        let analysis = analyze_smooth_ownership(&[population, deviation]);

        assert!(
            analysis.term_owners[1].is_empty(),
            "ungated s(x) must not own the row-gated by-factor deviation smooth"
        );
    }

    #[test]
    fn same_factor_by_level_gate_keeps_normal_subset_ownership() {
        let narrow = level_by_term("s(x):B", 0, 2, 42);
        let tensor_inner = SmoothBasisSpec::TensorBSpline {
            feature_cols: vec![0, 1],
            spec: Default::default(),
        };
        let wide = term("te(x,z):B", level_gated(tensor_inner, 2, 42, "B"));

        let analysis = analyze_smooth_ownership(&[narrow, wide]);

        assert_eq!(
            analysis.term_owners[1],
            vec![0],
            "matching by-level gates may still use the ordinary nested smooth ownership rule"
        );
    }
}