use crate::basis::{
BSplineIdentifiability, CenterStrategy, ConstantCurvatureIdentifiability,
MaternIdentifiability, MeasureJetIdentifiability, SpatialIdentifiability,
SphericalSplineIdentifiability,
};
use super::{ByVarKind, SmoothBasisSpec, SmoothTermSpec, TensorBSplineIdentifiability};
use std::collections::BTreeSet;
/// Collects the data-column indices referenced by `basis`.
///
/// * By-variable style wrappers (`ByVariable`, `FactorSumToZero`) contribute the
///   columns of the basis they wrap plus their own `by_col`; that list is
///   returned sorted and deduplicated.
/// * `BySmooth` forwards to the smooth it wraps.
/// * `FactorSmooth` contributes its continuous columns plus the group column,
///   sorted and deduplicated.
/// * Every other basis returns its stored column(s) as-is; no sorting or
///   deduplication is applied to those here.
fn smooth_basis_feature_cols(basis: &SmoothBasisSpec) -> Vec<usize> {
    // Adds one more column to `columns` and normalises the result.
    fn with_extra_column(mut columns: Vec<usize>, extra: usize) -> Vec<usize> {
        columns.push(extra);
        columns.sort_unstable();
        columns.dedup();
        columns
    }

    match basis {
        SmoothBasisSpec::BSpline1D { feature_col, .. } => vec![*feature_col],
        SmoothBasisSpec::ThinPlate { feature_cols, .. }
        | SmoothBasisSpec::Sphere { feature_cols, .. }
        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
        | SmoothBasisSpec::Matern { feature_cols, .. }
        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
        | SmoothBasisSpec::Duchon { feature_cols, .. }
        | SmoothBasisSpec::Pca { feature_cols, .. }
        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
        SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_feature_cols(smooth),
        SmoothBasisSpec::ByVariable { inner, by_col, .. }
        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
            with_extra_column(smooth_basis_feature_cols(inner), *by_col)
        }
        SmoothBasisSpec::FactorSmooth { spec } => {
            with_extra_column(spec.continuous_cols.clone(), spec.group_col)
        }
    }
}
pub fn smooth_term_feature_cols(term: &SmoothTermSpec) -> Vec<usize> {
smooth_basis_feature_cols(&term.basis)
}
fn smooth_basis_family_rank(term: &SmoothTermSpec) -> u8 {
match &term.basis {
SmoothBasisSpec::ByVariable { inner, .. }
| SmoothBasisSpec::FactorSumToZero { inner, .. } => {
smooth_basis_family_rank(&SmoothTermSpec {
name: term.name.clone(),
basis: (**inner).clone(),
shape: term.shape,
joint_null_rotation: None,
})
}
SmoothBasisSpec::BSpline1D { .. } => 0,
SmoothBasisSpec::TensorBSpline { .. } => 1,
SmoothBasisSpec::ThinPlate { .. } => 2,
SmoothBasisSpec::Sphere { .. } => 3,
SmoothBasisSpec::Matern { .. } => 4,
SmoothBasisSpec::Duchon { .. } => 5,
SmoothBasisSpec::Pca { .. } => 6,
SmoothBasisSpec::ConstantCurvature { .. } => 8,
SmoothBasisSpec::MeasureJet { .. } => 9,
SmoothBasisSpec::BySmooth { smooth, .. } => smooth_basis_family_rank(&SmoothTermSpec {
name: term.name.clone(),
basis: (**smooth).clone(),
shape: term.shape,
joint_null_rotation: None,
}),
SmoothBasisSpec::FactorSmooth { .. } => 7,
}
}
pub fn smooth_has_frozen_identifiability(term: &SmoothTermSpec) -> bool {
match &term.basis {
SmoothBasisSpec::ByVariable { inner, .. }
| SmoothBasisSpec::FactorSumToZero { inner, .. } => {
smooth_has_frozen_identifiability(&SmoothTermSpec {
name: term.name.clone(),
basis: (**inner).clone(),
shape: term.shape,
joint_null_rotation: None,
})
}
SmoothBasisSpec::BSpline1D { spec, .. } => {
matches!(
spec.identifiability,
BSplineIdentifiability::FrozenTransform { .. }
)
}
SmoothBasisSpec::ThinPlate { spec, .. } => matches!(
spec.identifiability,
SpatialIdentifiability::FrozenTransform { .. }
),
SmoothBasisSpec::Sphere { spec, .. } => {
matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
|| matches!(
spec.identifiability,
SphericalSplineIdentifiability::FrozenTransform { .. }
)
}
SmoothBasisSpec::ConstantCurvature { spec, .. } => {
matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
|| matches!(
spec.identifiability,
ConstantCurvatureIdentifiability::FrozenTransform { .. }
)
}
SmoothBasisSpec::MeasureJet { spec, .. } => {
matches!(spec.center_strategy, CenterStrategy::UserProvided(_))
|| matches!(
spec.identifiability,
MeasureJetIdentifiability::FrozenTransform { .. }
)
}
SmoothBasisSpec::Matern { spec, .. } => matches!(
spec.identifiability,
MaternIdentifiability::FrozenTransform { .. }
),
SmoothBasisSpec::BySmooth { by_kind, .. } => match by_kind {
ByVarKind::Factor { frozen_levels, .. } => frozen_levels.is_some(),
ByVarKind::Numeric { .. } => true,
},
SmoothBasisSpec::FactorSmooth { spec } => spec.group_frozen_levels.is_some(),
SmoothBasisSpec::Duchon { spec, .. } => matches!(
spec.identifiability,
SpatialIdentifiability::FrozenTransform { .. }
),
SmoothBasisSpec::Pca {
centered,
center_mean,
pca_basis_path,
..
} => !*centered || center_mean.is_some() || pca_basis_path.is_some(),
SmoothBasisSpec::TensorBSpline { spec, .. } => matches!(
spec.identifiability,
TensorBSplineIdentifiability::FrozenTransform { .. }
),
}
}
/// Orders two smooth terms by ownership priority (`Less` sorts first).
///
/// Keys, in decreasing significance:
/// 1. number of feature columns (fewer first),
/// 2. the feature-column lists themselves, compared lexicographically,
/// 3. basis family rank,
/// 4. term name,
/// 5. position of the term in the caller's list (`lhs_idx` / `rhs_idx`).
fn compare_smooth_ownership_priority(
    lhs_idx: usize,
    lhs: &SmoothTermSpec,
    rhs_idx: usize,
    rhs: &SmoothTermSpec,
) -> std::cmp::Ordering {
    let left_cols = smooth_term_feature_cols(lhs);
    let right_cols = smooth_term_feature_cols(rhs);

    // Tuple comparison covers keys 1 and 2 in one step.
    let by_columns = (left_cols.len(), &left_cols).cmp(&(right_cols.len(), &right_cols));
    if by_columns != std::cmp::Ordering::Equal {
        return by_columns;
    }

    // The family rank is only computed once the column keys tie.
    let by_family = smooth_basis_family_rank(lhs).cmp(&smooth_basis_family_rank(rhs));
    if by_family != std::cmp::Ordering::Equal {
        return by_family;
    }

    let by_name = lhs.name.cmp(&rhs.name);
    if by_name != std::cmp::Ordering::Equal {
        return by_name;
    }

    lhs_idx.cmp(&rhs_idx)
}
/// Extracts the by-level gate of `term`, if it has one.
///
/// Returns `Some((by_col, value_bits))` when the term's basis is a `ByVariable`
/// whose `by` spec is `ByVariableSpec::Level`; returns `None` for every other
/// basis, including bases that merely wrap such a `ByVariable`.
fn factor_by_level_gate_of(term: &SmoothTermSpec) -> Option<(usize, u64)> {
    if let SmoothBasisSpec::ByVariable {
        by_col,
        by: crate::smooth::ByVariableSpec::Level { value_bits, .. },
        ..
    } = &term.basis
    {
        Some((*by_col, *value_bits))
    } else {
        None
    }
}
/// Decides whether `owner` owns `target`.
///
/// Two conditions must both hold:
/// * if `target` carries a by-level gate, `owner` must carry the identical
///   gate (same by column and same value bits); an ungated `target` imposes no
///   gate requirement;
/// * every feature column of `owner` must also be a feature column of
///   `target`.
fn smooth_is_owned_by_prior_term(owner: &SmoothTermSpec, target: &SmoothTermSpec) -> bool {
    let gate_matches = match factor_by_level_gate_of(target) {
        Some(required_gate) => factor_by_level_gate_of(owner) == Some(required_gate),
        None => true,
    };
    if !gate_matches {
        return false;
    }

    let target_columns: BTreeSet<usize> = smooth_term_feature_cols(target).into_iter().collect();
    smooth_term_feature_cols(owner)
        .iter()
        .all(|column| target_columns.contains(column))
}
/// Result of [`analyze_smooth_ownership`].
///
/// All indices refer to positions in the slice of smooth terms that was
/// analysed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmoothStructureAnalysis {
    /// Term indices sorted by ownership priority, highest priority first.
    pub ownership_order: Vec<usize>,
    /// Feature columns of each term, indexed by term position.
    pub term_feature_cols: Vec<Vec<usize>>,
    /// For each term (indexed by term position), the indices of the
    /// higher-priority terms that own it, listed in ownership order.
    pub term_owners: Vec<Vec<usize>>,
}
/// Computes the ownership structure of a list of smooth terms.
///
/// Terms are first ranked with `compare_smooth_ownership_priority`. Each term
/// is then assigned, as owners, those higher-ranked terms for which
/// `smooth_is_owned_by_prior_term` holds; a term can never be owned by a term
/// ranked after it. The per-term feature columns are reported alongside.
///
/// An empty `smoothspecs` yields an analysis whose three lists are all empty.
pub fn analyze_smooth_ownership(smoothspecs: &[SmoothTermSpec]) -> SmoothStructureAnalysis {
    let term_count = smoothspecs.len();

    let mut ownership_order: Vec<usize> = Vec::with_capacity(term_count);
    ownership_order.extend(0..term_count);
    ownership_order.sort_by(|&a, &b| {
        compare_smooth_ownership_priority(a, &smoothspecs[a], b, &smoothspecs[b])
    });

    let mut term_owners: Vec<Vec<usize>> = vec![Vec::new(); term_count];
    for (rank, &target_idx) in ownership_order.iter().enumerate() {
        let target = &smoothspecs[target_idx];
        // Only terms ranked strictly before the target are owner candidates.
        let mut owners = Vec::new();
        for &candidate_idx in &ownership_order[..rank] {
            if smooth_is_owned_by_prior_term(&smoothspecs[candidate_idx], target) {
                owners.push(candidate_idx);
            }
        }
        term_owners[target_idx] = owners;
    }

    let term_feature_cols: Vec<Vec<usize>> =
        smoothspecs.iter().map(smooth_term_feature_cols).collect();

    SmoothStructureAnalysis {
        ownership_order,
        term_feature_cols,
        term_owners,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{BSplineBasisSpec, BSplineKnotSpec, OneDimensionalBoundary};
    use crate::smooth::{BySmoothKind, ByVariableSpec, ShapeConstraint};

    /// Degree-3 B-spline basis on `feature_col` with no identifiability
    /// constraint.
    fn bspline(feature_col: usize) -> SmoothBasisSpec {
        let spec = BSplineBasisSpec {
            degree: 3,
            penalty_order: 2,
            knotspec: BSplineKnotSpec::Generate {
                data_range: (0.0, 1.0),
                num_internal_knots: 5,
            },
            double_penalty: false,
            identifiability: BSplineIdentifiability::None,
            boundary: OneDimensionalBoundary::Open,
            boundary_conditions: Default::default(),
        };
        SmoothBasisSpec::BSpline1D { feature_col, spec }
    }

    /// Wraps `basis` in an unconstrained term called `name`.
    fn term(name: &str, basis: SmoothBasisSpec) -> SmoothTermSpec {
        SmoothTermSpec {
            name: name.to_owned(),
            basis,
            shape: ShapeConstraint::None,
            joint_null_rotation: None,
        }
    }

    /// Gates `inner` on a single level (`level_bits`) of the column `by_col`.
    fn gated_on_level(
        inner: SmoothBasisSpec,
        by_col: usize,
        level_bits: u64,
        label: &str,
    ) -> SmoothBasisSpec {
        SmoothBasisSpec::ByVariable {
            inner: Box::new(inner),
            by_col,
            kind: BySmoothKind::Level { level_bits },
            by: ByVariableSpec::Level {
                value_bits: level_bits,
                label: label.to_owned(),
            },
        }
    }

    /// B-spline term on `feature_col`, gated on one level of `by_col`; the
    /// level label is the term name.
    fn level_by_term(
        name: &str,
        feature_col: usize,
        by_col: usize,
        level_bits: u64,
    ) -> SmoothTermSpec {
        let gated = gated_on_level(bspline(feature_col), by_col, level_bits, name);
        term(name, gated)
    }

    #[test]
    fn ungated_smooth_does_not_own_factor_by_level_smooth() {
        let ungated = term("s(x)", bspline(0));
        let gated = level_by_term("s(x):B", 0, 1, 42);
        let specs = vec![ungated, gated];

        let analysis = analyze_smooth_ownership(&specs);

        let no_owners: Vec<usize> = Vec::new();
        assert_eq!(
            analysis.term_owners[1],
            no_owners,
            "ungated s(x) must not own the row-gated by-factor deviation smooth"
        );
    }

    #[test]
    fn same_factor_by_level_gate_keeps_normal_subset_ownership() {
        let tensor = SmoothBasisSpec::TensorBSpline {
            feature_cols: vec![0, 1],
            spec: Default::default(),
        };
        let specs = vec![
            level_by_term("s(x):B", 0, 2, 42),
            term("te(x,z):B", gated_on_level(tensor, 2, 42, "B")),
        ];

        let analysis = analyze_smooth_ownership(&specs);

        let expected_owners: Vec<usize> = vec![0];
        assert_eq!(
            analysis.term_owners[1],
            expected_owners,
            "matching by-level gates may still use the ordinary nested smooth ownership rule"
        );
    }
}