use coefficient_transforms::{
convex_divided_difference_transform_matrix, cumulative_exp, cumulative_sum_transform_matrix,
second_cumulative_exp,
};
pub use error::SmoothError;
use input_standardization::{
apply_input_standardization, compensate_length_scale_for_standardization,
compensate_optional_length_scale_for_standardization, compute_spatial_input_scales,
};
use shape_constraints::{
build_shape_constraint_design_1d, build_shape_linear_constraints_1d,
linear_constraints_from_lower_bounds_global, merge_linear_constraints_global,
shape_lower_bounds_local, shape_order_and_sign, shape_supports_basis,
shape_uses_box_reparameterization,
};
/// Renders a short human-readable description of how many centers a strategy asks for.
///
/// `Auto` strategies are described by the strategy they wrap; grid strategies report
/// the per-dimension resolution instead of a center count.
fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
    match strategy {
        CenterStrategy::UniformGrid { points_per_dim } => {
            format!("uniform grid with {points_per_dim} points per dimension")
        }
        CenterStrategy::UserProvided(centers) => format!("{} centers", centers.nrows()),
        CenterStrategy::EqualMass { num_centers: count }
        | CenterStrategy::EqualMassCovarRepresentative { num_centers: count }
        | CenterStrategy::FarthestPoint { num_centers: count }
        | CenterStrategy::KMeans {
            num_centers: count, ..
        } => format!("{count} centers"),
        // `Auto` only wraps another strategy; describe whatever it delegates to.
        CenterStrategy::Auto(delegate) => describe_thin_plate_center_request(delegate),
    }
}
/// Replaces low-level "too few centers/knots" basis errors with a term-level message.
///
/// Errors whose text matches one of the known center-shortfall messages are rewritten
/// to name the joint TPS term, its covariate count, the requested center strategy,
/// and the minimum number of centers. Every other error is returned unchanged.
fn rewrite_thin_plate_knots_error(
    err: BasisError,
    termname: &str,
    feature_count: usize,
    spec: &ThinPlateBasisSpec,
) -> BasisError {
    /// True when `msg` is one of the center/knot shortfall messages emitted by the basis builders.
    fn is_center_shortfall(msg: &str) -> bool {
        let spans_requirement = msg.contains("thin-plate spline requires at least")
            && (msg.contains("centers to span") || msg.contains("knots to span"));
        let too_few_knots = msg.starts_with("requested ") && msg.contains(" knots but only ");
        spans_requirement || too_few_knots
    }

    match err {
        BasisError::InvalidInput(msg) if is_center_shortfall(msg.as_str()) => {
            let requested = describe_thin_plate_center_request(&spec.center_strategy);
            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
            BasisError::InvalidInput(format!(
                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
            ))
        }
        untouched => untouched,
    }
}
/// Shape restriction that can be requested for a smooth term.
///
/// Parsed from user text by `parse_shape_constraint`; `dsl_str` gives the canonical
/// spelling that is written back into formulas.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShapeConstraint {
/// No shape restriction.
None,
/// Canonical DSL name `monotone_increasing`.
MonotoneIncreasing,
/// Canonical DSL name `monotone_decreasing`.
MonotoneDecreasing,
/// Canonical DSL name `convex`.
Convex,
/// Canonical DSL name `concave`.
Concave,
}
/// Parses a user-supplied shape-constraint name.
///
/// Matching ignores surrounding whitespace and ASCII case and treats `-` like `_`.
/// Several aliases are accepted per constraint (e.g. `"mpi"`, `"increasing"`); the
/// empty string means no constraint.
///
/// # Errors
/// Returns a message naming the normalized input and the canonical choices when the
/// name is not recognized.
pub fn parse_shape_constraint(raw: &str) -> Result<ShapeConstraint, String> {
    const ALIASES: [(&[&str], ShapeConstraint); 5] = [
        (&["", "none"], ShapeConstraint::None),
        (
            &[
                "monotone_increasing",
                "monotonic_increasing",
                "increasing",
                "mono_inc",
                "mpi",
            ],
            ShapeConstraint::MonotoneIncreasing,
        ),
        (
            &[
                "monotone_decreasing",
                "monotonic_decreasing",
                "decreasing",
                "mono_dec",
                "mpd",
            ],
            ShapeConstraint::MonotoneDecreasing,
        ),
        (&["convex", "cvx"], ShapeConstraint::Convex),
        (&["concave", "ccv"], ShapeConstraint::Concave),
    ];
    let key = raw.trim().to_ascii_lowercase().replace('-', "_");
    let other = key.as_str();
    ALIASES
        .iter()
        .find(|(names, _)| names.contains(&other))
        .map(|&(_, constraint)| constraint)
        .ok_or_else(|| {
            format!(
                "unknown shape constraint {other:?}; expected one of \
                 \"none\", \"monotone_increasing\", \"monotone_decreasing\", \
                 \"convex\", \"concave\""
            )
        })
}
impl ShapeConstraint {
pub fn dsl_str(&self) -> &'static str {
match self {
ShapeConstraint::None => "none",
ShapeConstraint::MonotoneIncreasing => "monotone_increasing",
ShapeConstraint::MonotoneDecreasing => "monotone_decreasing",
ShapeConstraint::Convex => "convex",
ShapeConstraint::Concave => "concave",
}
}
}
/// Call heads that `apply_shape_constraints_to_formula` treats as smooth terms.
///
/// A head only counts when it starts at an identifier boundary and is followed,
/// after optional whitespace, by `(`. Because of that `(` check, shorter entries
/// such as `"s"` never shadow longer ones such as `"smooth"`.
const SMOOTH_HEAD_KEYWORDS: [&str; 11] = [
"s",
"smooth",
"te",
"tensor",
"thinplate",
"tps",
"duchon",
"matern",
"sphere",
"bs",
"bspline",
];
/// Locates the next smooth-term head at or after index `start` of `chars`.
///
/// A head is an entry of `SMOOTH_HEAD_KEYWORDS` that starts at an identifier boundary
/// (not preceded by an ASCII alphanumeric character or `_`) and is followed, after
/// optional whitespace, by `(`. Returns `(head_start, paren_open)`.
fn shape_formula_find_head(chars: &[char], start: usize) -> Option<(usize, usize)> {
    let n = chars.len();
    let is_ident = |c: char| c.is_ascii_alphanumeric() || c == '_';
    (start..n).find_map(|p| {
        if p > 0 && is_ident(chars[p - 1]) {
            return None;
        }
        SMOOTH_HEAD_KEYWORDS.iter().find_map(|kw| {
            let klen = kw.chars().count();
            // Compare char-by-char instead of allocating a String per candidate position.
            if !chars.get(p..p + klen)?.iter().copied().eq(kw.chars()) {
                return None;
            }
            let q = (p + klen..n)
                .find(|&q| !chars[q].is_whitespace())
                .unwrap_or(n);
            (q < n && chars[q] == '(').then_some((p, q))
        })
    })
}

/// Returns the index of the `)` balancing the `(` immediately before `body_start`.
///
/// Parentheses inside single- or double-quoted text are ignored (escape sequences are
/// not recognized). Returns `None` when the parenthesis is never closed.
fn shape_formula_find_close(chars: &[char], body_start: usize) -> Option<usize> {
    let mut depth = 1i32;
    let mut in_str: Option<char> = None;
    for (j, &ch) in chars.iter().enumerate().skip(body_start) {
        if let Some(quote) = in_str {
            if ch == quote {
                in_str = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' => in_str = Some(ch),
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    return Some(j);
                }
            }
            _ => {}
        }
    }
    None
}

/// Rewrites `formula` so that selected smooth terms carry a `shape=` argument.
///
/// Each `(term, kind)` pair in `constraints` names a smooth term as it appears in the
/// formula. Terms are compared with all whitespace removed, so `"s(x,k=5)"` matches
/// `s( x , k=5 )`. `kind` is any name accepted by `parse_shape_constraint`.
///
/// A matching term `head(args)` becomes `head(args, shape=<kind>)`, or
/// `head(shape=<kind>)` when it has no arguments. Only recognized smooth calls are
/// inspected, and smooth calls nested inside another recognized term's arguments are
/// copied through untouched. When the same term is listed more than once, the last
/// entry wins; a trailing `"none"` therefore cancels an earlier constraint.
///
/// # Errors
/// Returns an error if a constraint name fails to parse, or if a term that still has
/// a non-`none` constraint does not occur in the formula.
pub fn apply_shape_constraints_to_formula(
    formula: &str,
    constraints: &[(String, String)],
) -> Result<String, String> {
    use std::collections::{BTreeMap, BTreeSet};
    if constraints.is_empty() {
        return Ok(formula.to_string());
    }
    let strip_ws = |s: &str| -> String { s.chars().filter(|c| !c.is_whitespace()).collect() };
    // Whitespace-free term text -> canonical DSL spelling of the requested shape.
    let mut wanted: BTreeMap<String, &'static str> = BTreeMap::new();
    // Whitespace-free term text -> first spelling supplied by the caller (for errors).
    let mut originals: BTreeMap<String, String> = BTreeMap::new();
    for (key, kind_raw) in constraints {
        let kind = parse_shape_constraint(kind_raw)?;
        let nk = strip_ws(key);
        originals.entry(nk.clone()).or_insert_with(|| key.clone());
        if kind == ShapeConstraint::None {
            // Last entry wins: an explicit "none" cancels any earlier request.
            wanted.remove(&nk);
        } else {
            wanted.insert(nk, kind.dsl_str());
        }
    }
    if wanted.is_empty() {
        return Ok(formula.to_string());
    }
    let chars: Vec<char> = formula.chars().collect();
    let n = chars.len();
    let mut out = String::with_capacity(formula.len() + 32);
    let mut matched: BTreeSet<String> = BTreeSet::new();
    let mut i = 0usize;
    while i < n {
        let Some((head_start, paren_open)) = shape_formula_find_head(&chars, i) else {
            out.extend(chars[i..].iter());
            break;
        };
        out.extend(chars[i..head_start].iter());
        let body_start = paren_open + 1;
        let Some(close) = shape_formula_find_close(&chars, body_start) else {
            // Unbalanced parentheses: copy the remainder verbatim.
            out.extend(chars[head_start..].iter());
            break;
        };
        let term_text: String = chars[head_start..=close].iter().collect();
        let key_norm = strip_ws(&term_text);
        match wanted.get(&key_norm) {
            None => out.push_str(&term_text),
            Some(kind) => {
                // Keep the head exactly as written (including any space before `(`).
                out.extend(chars[head_start..body_start].iter());
                let inside: String = chars[body_start..close].iter().collect();
                let inside = inside.trim();
                if !inside.is_empty() {
                    out.push_str(inside);
                    out.push_str(", ");
                }
                out.push_str("shape=");
                out.push_str(kind);
                out.push(')');
                matched.insert(key_norm);
            }
        }
        i = close + 1;
    }
    let mut missing: Vec<String> = wanted
        .keys()
        .filter(|k| !matched.contains(*k))
        .map(|k| originals.get(k).cloned().unwrap_or_else(|| k.clone()))
        .collect();
    if !missing.is_empty() {
        missing.sort();
        return Err(format!(
            "shape constraints referenced smooth term(s) not found in formula: {}",
            missing.join(", ")
        ));
    }
    Ok(out)
}
/// Kind of by-variable attached to a `SmoothBasisSpec::ByVariable` term.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BySmoothKind {
/// Numeric by-variable (presumably the smooth is multiplied by it — confirm).
Numeric,
/// Factor level selector. NOTE(review): `level_bits` is presumably the `f64::to_bits`
/// pattern of the selected level, like `LinearTermSpec::categorical_levels` — confirm.
Level { level_bits: u64 },
}
/// Serializable description of how to build the basis for one smooth term.
///
/// Wrapper variants (`ByVariable`, `FactorSumToZero`, `BySmooth`) hold another spec;
/// the remaining variants are concrete bases over one or more data columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SmoothBasisSpec {
/// `inner` combined with the by-variable stored in data column `by_col`.
/// NOTE(review): `kind` and `by` both distinguish numeric vs level by-variables —
/// confirm which one is authoritative.
ByVariable {
inner: Box<SmoothBasisSpec>,
by_col: usize,
kind: BySmoothKind,
by: ByVariableSpec,
},
/// Sum-to-zero factor expansion of `inner` over `levels` of the factor in `by_col`.
/// Frozen specs must list at least two levels.
FactorSumToZero {
inner: Box<SmoothBasisSpec>,
by_col: usize,
levels: Vec<u64>,
/// Optional frozen orthogonality transform; semantics are defined by the builder.
#[serde(default)]
frozen_global_orthogonality: Option<Array2<f64>>,
},
/// One-dimensional B-spline smooth of data column `feature_col`.
BSpline1D {
feature_col: usize,
spec: BSplineBasisSpec,
},
/// `smooth` modulated by a numeric or factor by-variable; nesting is rejected when
/// validating frozen specs.
BySmooth {
smooth: Box<SmoothBasisSpec>,
by_kind: ByVarKind,
},
/// Factor-smooth interaction; its columns live inside `spec`.
FactorSmooth { spec: FactorSmoothSpec },
/// Thin-plate spline over `feature_cols`.
ThinPlate {
feature_cols: Vec<usize>,
spec: ThinPlateBasisSpec,
/// Optional per-feature input scales; presumably produced by input
/// standardization (`compute_spatial_input_scales`) — confirm.
#[serde(default)]
input_scales: Option<Vec<f64>>,
},
/// Spherical spline over `feature_cols`.
Sphere {
feature_cols: Vec<usize>,
spec: SphericalSplineBasisSpec,
},
/// Constant-curvature basis over `feature_cols`.
ConstantCurvature {
feature_cols: Vec<usize>,
spec: ConstantCurvatureBasisSpec,
},
/// Matern radial basis over `feature_cols`.
Matern {
feature_cols: Vec<usize>,
spec: MaternBasisSpec,
/// Optional per-feature input scales (see the ThinPlate note).
#[serde(default)]
input_scales: Option<Vec<f64>>,
},
/// MeasureJet basis over `feature_cols`.
MeasureJet {
feature_cols: Vec<usize>,
spec: MeasureJetBasisSpec,
/// Optional per-feature input scales (see the ThinPlate note).
#[serde(default)]
input_scales: Option<Vec<f64>>,
},
/// Duchon spline over `feature_cols`.
Duchon {
feature_cols: Vec<usize>,
spec: DuchonBasisSpec,
/// Optional per-feature input scales (see the ThinPlate note).
#[serde(default)]
input_scales: Option<Vec<f64>>,
},
/// Projection onto a precomputed PCA basis (`basis_matrix`, one column per component).
Pca {
feature_cols: Vec<usize>,
basis_matrix: Array2<f64>,
/// When true, a frozen spec needs `center_mean` or `pca_basis_path`.
centered: bool,
/// Defaults to 1.0 when absent from serialized input.
#[serde(default = "default_pca_smooth_penalty")]
smooth_penalty: f64,
#[serde(default)]
center_mean: Option<Array1<f64>>,
#[serde(default)]
pca_basis_path: Option<PathBuf>,
/// Defaults to 4096 when absent from serialized input.
#[serde(default = "default_pca_chunk_size")]
chunk_size: usize,
},
/// Tensor product of B-spline marginals over `feature_cols`.
TensorBSpline {
feature_cols: Vec<usize>,
spec: TensorBSplineSpec,
},
}
impl SmoothBasisSpec {
    /// Heuristic lower bound on the number of sample rows needed to build this basis.
    ///
    /// - Wrappers (`ByVariable`, `BySmooth`) defer to the wrapped basis.
    /// - `FactorSumToZero` multiplies the inner bound by `max(levels - 1, 1)`.
    /// - B-spline based terms use `bspline_basis_min_rows`.
    /// - Radial and spatial bases use a fixed floor of 5.
    /// - PCA needs at least one row per basis column.
    /// - Tensor B-splines sum their marginal bounds (each at least 1) and apply the
    ///   same floor of 5.
    pub fn min_sample_rows(&self) -> usize {
        const RADIAL_FLOOR: usize = 5;
        match self {
            Self::ByVariable { inner, .. } => inner.min_sample_rows(),
            Self::FactorSumToZero { inner, levels, .. } => {
                let inner_min = inner.min_sample_rows();
                let lvls = levels.len().saturating_sub(1).max(1);
                inner_min.saturating_mul(lvls)
            }
            Self::BSpline1D { spec, .. } => bspline_basis_min_rows(spec),
            Self::BySmooth { smooth, .. } => smooth.min_sample_rows(),
            Self::FactorSmooth { spec } => bspline_basis_min_rows(&spec.marginal),
            Self::ThinPlate { .. }
            | Self::Sphere { .. }
            | Self::ConstantCurvature { .. }
            | Self::Matern { .. }
            | Self::MeasureJet { .. }
            | Self::Duchon { .. } => RADIAL_FLOOR,
            Self::Pca { basis_matrix, .. } => basis_matrix.ncols().max(1),
            Self::TensorBSpline { spec, .. } => spec
                .marginalspecs
                .iter()
                .map(|marginal| bspline_basis_min_rows(marginal).max(1))
                .fold(0usize, usize::saturating_add)
                .max(RADIAL_FLOOR),
        }
    }

    /// Stable identifier of the basis variant, used in structural fingerprints.
    pub fn structural_kind(&self) -> &'static str {
        match self {
            Self::ByVariable { .. } => "by_variable",
            Self::FactorSumToZero { .. } => "factor_sum_to_zero",
            Self::BSpline1D { .. } => "bspline_1d",
            Self::BySmooth { .. } => "by_smooth",
            Self::FactorSmooth { .. } => "factor_smooth",
            Self::ThinPlate { .. } => "thin_plate",
            Self::Sphere { .. } => "sphere",
            Self::ConstantCurvature { .. } => "constant_curvature",
            Self::Matern { .. } => "matern",
            Self::MeasureJet { .. } => "measurejet",
            Self::Duchon { .. } => "duchon",
            Self::Pca { .. } => "pca",
            Self::TensorBSpline { .. } => "tensor_bspline",
        }
    }

    /// Covariate columns the smooth itself is built over, used in structural fingerprints.
    ///
    /// Wrapper variants report the columns of the wrapped basis; grouping/by columns
    /// (`by_col`, the by-variable of `BySmooth`, and the factor-smooth `group_col`)
    /// are not included.
    pub fn structural_feature_cols(&self) -> Vec<usize> {
        match self {
            Self::ByVariable { inner, .. } | Self::FactorSumToZero { inner, .. } => {
                inner.structural_feature_cols()
            }
            Self::BySmooth { smooth, .. } => smooth.structural_feature_cols(),
            // Previously empty, which made factor smooths over different covariates
            // indistinguishable in the structural shape hash.
            Self::FactorSmooth { spec } => spec.continuous_cols.clone(),
            Self::BSpline1D { feature_col, .. } => vec![*feature_col],
            Self::ThinPlate { feature_cols, .. }
            | Self::Sphere { feature_cols, .. }
            | Self::ConstantCurvature { feature_cols, .. }
            | Self::Matern { feature_cols, .. }
            | Self::MeasureJet { feature_cols, .. }
            | Self::Duchon { feature_cols, .. }
            | Self::Pca { feature_cols, .. }
            | Self::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
        }
    }
}
/// Minimum row count for a B-spline basis described by `spec`.
///
/// The column count is derived from the knot specification and floored at
/// `degree + 2`. With `double_penalty` enabled, the requirement drops to
/// `min(2, columns)` (never below 1).
fn bspline_basis_min_rows(spec: &crate::terms::basis::BSplineBasisSpec) -> usize {
    use crate::terms::basis::BSplineKnotSpec;
    const DOUBLE_PENALTY_FLOOR: usize = 2;
    let degree_floor = spec.degree + 2;
    let basis_columns = match &spec.knotspec {
        BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1).max(1),
        BSplineKnotSpec::PeriodicUniform { num_basis, .. } => *num_basis,
        BSplineKnotSpec::Generate {
            num_internal_knots, ..
        } => *num_internal_knots + spec.degree + 1,
        BSplineKnotSpec::Automatic {
            num_internal_knots: Some(internal),
            ..
        } => *internal + spec.degree + 1,
        // Unknown knot count: assume the smallest basis of this degree.
        BSplineKnotSpec::Automatic {
            num_internal_knots: None,
            ..
        } => degree_floor,
    };
    let basis_columns = basis_columns.max(degree_floor);
    if !spec.double_penalty {
        return basis_columns;
    }
    DOUBLE_PENALTY_FLOOR.min(basis_columns).max(1)
}
/// By-variable attached to a `SmoothBasisSpec::ByVariable` term.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ByVariableSpec {
/// Numeric by-variable.
Numeric,
/// One factor level. NOTE(review): `value_bits` is presumably the `f64::to_bits`
/// encoding of the level and `label` its display name — confirm.
Level { value_bits: u64, label: String },
}
/// One marginal of a tensor-product term: a B-spline basis or a categorical marginal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorMarginalSpec {
BSpline(BSplineBasisSpec),
/// Categorical marginal. NOTE(review): `feature_col_offset` presumably indexes into
/// the owning term's feature columns rather than the data matrix — confirm.
Categorical {
feature_col_offset: usize,
drop_first_level: bool,
center_for_identifiability: bool,
frozen_levels: Option<Vec<u64>>,
},
}
/// By-variable of a `SmoothBasisSpec::BySmooth` term; `feature_col` is a data column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ByVarKind {
Numeric {
feature_col: usize,
},
/// Factor by-variable; frozen specs must carry `frozen_levels`.
Factor {
feature_col: usize,
ordered: bool,
frozen_levels: Option<Vec<u64>>,
},
}
/// Factor-smooth interaction: a B-spline `marginal` over `continuous_cols`, replicated
/// per level of the factor in `group_col`.
/// NOTE(review): the "replicated per level" reading comes from the field names — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactorSmoothSpec {
pub continuous_cols: Vec<usize>,
pub group_col: usize,
/// Frozen specs must use Provided or PeriodicUniform knots here.
pub marginal: BSplineBasisSpec,
pub flavour: FactorSmoothFlavour,
/// Required (Some) for frozen specs.
pub group_frozen_levels: Option<Vec<u64>>,
#[serde(default)]
pub frozen_global_orthogonality: Option<Array2<f64>>,
}
/// Construction flavour of a factor smooth.
/// NOTE(review): the names presumably mirror mgcv's `fs`, `sz` and `re` factor-smooth
/// bases — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FactorSmoothFlavour {
Fs { m_null_penalty_orders: Vec<usize> },
Sz,
Re,
}
/// Tensor-product B-spline term: one `BSplineBasisSpec` per dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorBSplineSpec {
pub marginalspecs: Vec<BSplineBasisSpec>,
/// Optional per-dimension period; empty when absent from serialized input.
#[serde(default)]
pub periods: Vec<Option<f64>>,
pub double_penalty: bool,
/// Defaults to `SumToZero`; frozen specs must use `FrozenTransform` or `None`.
#[serde(default)]
pub identifiability: TensorBSplineIdentifiability,
/// Defaults to `MarginalKroneckerSum`.
#[serde(default)]
pub penalty_decomposition: TensorBSplinePenaltyDecomposition,
}
/// Identifiability constraint applied to a tensor B-spline term (default `SumToZero`).
///
/// `SumToZero` and `MarginalSumToZero` are rejected by `validate_frozen`, which
/// requires the realized `FrozenTransform` or `None`.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum TensorBSplineIdentifiability {
None,
#[default]
SumToZero,
MarginalSumToZero,
FrozenTransform {
transform: Array2<f64>,
},
}
/// How a tensor B-spline's penalty is split into blocks (default `MarginalKroneckerSum`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorBSplinePenaltyDecomposition {
#[default]
MarginalKroneckerSum,
Separable,
}
/// Named smooth term: its basis specification, shape constraint and optional
/// joint-null rotation (absent from older serialized specs).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothTermSpec {
pub name: String,
pub basis: SmoothBasisSpec,
pub shape: ShapeConstraint,
#[serde(default)]
pub joint_null_rotation: Option<crate::terms::basis::JointNullRotation>,
}
/// Realized smooth term produced when a `SmoothTermSpec` is built against data.
#[derive(Debug, Clone)]
pub struct SmoothTerm {
pub name: String,
/// Presumably the columns this term occupies in the combined design — confirm.
pub coeff_range: Range<usize>,
pub shape: ShapeConstraint,
/// Penalty matrices in the term's local coefficient coordinates.
pub penalties_local: Vec<Array2<f64>>,
pub nullspace_dims: Vec<usize>,
pub penaltyinfo_local: Vec<PenaltyInfo>,
pub metadata: BasisMetadata,
pub lower_bounds_local: Option<Array1<f64>>,
pub linear_constraints_local: Option<LinearInequalityConstraints>,
pub kronecker_factored: Option<KroneckerFactoredBasis>,
/// When present, prediction designs are right-multiplied by `rotation`
/// (see `apply_rotation_to_predict`).
pub joint_null_rotation: Option<crate::terms::basis::JointNullRotation>,
pub unabsorbed_global_orthogonality: Option<Array2<f64>>,
}
impl SmoothTerm {
    /// Replays the fit-time joint-null rotation on a raw prediction design.
    ///
    /// Without a stored rotation, `x_new_raw` is returned unchanged. Otherwise the
    /// result is `x_new_raw · rotation`.
    ///
    /// # Errors
    /// Fails with a dimension error when `x_new_raw` does not have as many columns as
    /// the rotation has rows.
    pub fn apply_rotation_to_predict(
        &self,
        x_new_raw: Array2<f64>,
    ) -> Result<Array2<f64>, BasisError> {
        match self.joint_null_rotation.as_ref() {
            None => Ok(x_new_raw),
            Some(rot) => {
                let raw_cols = x_new_raw.ncols();
                let expected_cols = rot.rotation.nrows();
                if raw_cols != expected_cols {
                    crate::bail_dim_basis!(
                        "joint-null rotation replay for term '{}': raw design has {} columns, \
                         rotation expects {} (the raw basis builder must emit the same column \
                         count as at fit time)",
                        self.name,
                        raw_cols,
                        expected_cols,
                    );
                }
                Ok(crate::linalg::faer_ndarray::fast_ab(&x_new_raw, &rot.rotation))
            }
        }
    }
}
/// Metadata for one retained penalty block: its index in the global penalty list,
/// the owning term's name (if any), and the penalty description.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyBlockInfo {
pub global_index: usize,
pub termname: Option<String>,
pub penalty: PenaltyInfo,
}
/// Metadata for a penalty block that was dropped; it has no global index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DroppedPenaltyBlockInfo {
pub termname: Option<String>,
pub penalty: PenaltyInfo,
}
/// Assembled smooth part of a model: per-term designs, penalties, term metadata and
/// optional coefficient constraints.
///
/// Has the same fields as `RawSmoothDesign`, and converts from it field by field.
#[derive(Debug, Clone)]
pub struct SmoothDesign {
pub term_designs: Vec<DesignMatrix>,
pub penalties: Vec<BlockwisePenalty>,
pub nullspace_dims: Vec<usize>,
pub penaltyinfo: Vec<PenaltyBlockInfo>,
pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
pub terms: Vec<SmoothTerm>,
pub coefficient_lower_bounds: Option<Array1<f64>>,
pub linear_constraints: Option<LinearInequalityConstraints>,
}
impl SmoothDesign {
pub fn total_smooth_cols(&self) -> usize {
self.term_designs.iter().map(DesignMatrix::ncols).sum()
}
pub fn nrows(&self) -> usize {
self.term_designs.first().map_or(0, DesignMatrix::nrows)
}
}
/// Smooth design with the same fields as `SmoothDesign`; it converts into one via `From`.
#[derive(Debug, Clone)]
pub struct RawSmoothDesign {
pub term_designs: Vec<DesignMatrix>,
pub penalties: Vec<BlockwisePenalty>,
pub nullspace_dims: Vec<usize>,
pub penaltyinfo: Vec<PenaltyBlockInfo>,
pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
pub terms: Vec<SmoothTerm>,
pub coefficient_lower_bounds: Option<Array1<f64>>,
pub linear_constraints: Option<LinearInequalityConstraints>,
}
impl RawSmoothDesign {
pub fn total_smooth_cols(&self) -> usize {
self.term_designs.iter().map(DesignMatrix::ncols).sum()
}
pub fn nrows(&self) -> usize {
self.term_designs.first().map_or(0, DesignMatrix::nrows)
}
}
impl From<RawSmoothDesign> for SmoothDesign {
fn from(value: RawSmoothDesign) -> Self {
Self {
term_designs: value.term_designs,
penalties: value.penalties,
nullspace_dims: value.nullspace_dims,
penaltyinfo: value.penaltyinfo,
dropped_penaltyinfo: value.dropped_penaltyinfo,
terms: value.terms,
coefficient_lower_bounds: value.coefficient_lower_bounds,
linear_constraints: value.linear_constraints,
}
}
}
/// Prior placed on a bounded linear coefficient (default `None`).
///
/// `validate_frozen` requires Beta parameters to be finite and at least 1.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum BoundedCoefficientPriorSpec {
#[default]
None,
Uniform,
Beta {
a: f64,
b: f64,
},
}
/// Geometry of a linear coefficient: unconstrained (default) or bounded in
/// `[min, max]` with an optional prior. Frozen specs need finite bounds with `min < max`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum LinearCoefficientGeometry {
#[default]
Unconstrained,
Bounded {
min: f64,
max: f64,
#[serde(default)]
prior: BoundedCoefficientPriorSpec,
},
}
/// Parametric (linear) term.
///
/// Its design column is the product of `feature_cols` (or the single `feature_col`
/// when that list is empty), zeroed wherever a categorical gate in
/// `categorical_levels` does not match.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearTermSpec {
pub name: String,
pub feature_col: usize,
#[serde(default)]
pub feature_cols: Vec<usize>,
/// `(data column, f64::to_bits of the required level)` pairs.
#[serde(default)]
pub categorical_levels: Vec<(usize, u64)>,
/// Defaults to `false`.
#[serde(default = "default_linear_term_double_penalty")]
pub double_penalty: bool,
#[serde(default)]
pub coefficient_geometry: LinearCoefficientGeometry,
#[serde(default)]
pub coefficient_min: Option<f64>,
#[serde(default)]
pub coefficient_max: Option<f64>,
}
impl LinearTermSpec {
    /// Columns whose product forms the numeric part of the term.
    ///
    /// Falls back to the single `feature_col` when `feature_cols` is empty.
    pub fn effective_feature_cols(&self) -> Vec<usize> {
        match self.feature_cols.as_slice() {
            [] => vec![self.feature_col],
            cols => cols.to_vec(),
        }
    }

    /// True when the term has any categorical gate or multiplies more than one column.
    pub fn is_interaction(&self) -> bool {
        !self.categorical_levels.is_empty() || self.feature_cols.len() > 1
    }

    /// Builds the realized design column for this term from `data`.
    ///
    /// Without categorical gates, the column is the elementwise product of
    /// `effective_feature_cols()`. With gates, it is the product of `feature_cols`
    /// (all ones if that list is empty). In both cases, entries whose gate column does
    /// not match the stored level bits are then set to 0.
    ///
    /// # Errors
    /// Returns a message when any referenced column index is out of bounds.
    pub fn realized_design_column(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
        let n_cols = data.ncols();
        let check_col = |col: usize| -> Result<(), String> {
            if col < n_cols {
                return Ok(());
            }
            Err(format!(
                "linear term '{}' feature column {} out of bounds for {} columns",
                self.name, col, n_cols
            ))
        };

        let mut column = if !self.categorical_levels.is_empty() {
            let mut product = Array1::<f64>::ones(data.nrows());
            for &col in &self.feature_cols {
                check_col(col)?;
                product *= &data.column(col);
            }
            product
        } else {
            let cols = self.effective_feature_cols();
            cols.iter().try_for_each(|&col| check_col(col))?;
            // `effective_feature_cols` is never empty, so `cols[0]` exists.
            let mut product = data.column(cols[0]).to_owned();
            for &col in &cols[1..] {
                product *= &data.column(col);
            }
            product
        };

        for &(gate_col, level_bits) in &self.categorical_levels {
            check_col(gate_col)?;
            column
                .iter_mut()
                .zip(data.column(gate_col).iter())
                .filter(|(_, gate)| gate.to_bits() != level_bits)
                .for_each(|(entry, _)| *entry = 0.0);
        }
        Ok(column)
    }
}
/// Serde default for `LinearTermSpec::double_penalty` (off).
const fn default_linear_term_double_penalty() -> bool {
false
}
/// Serde default for the `smooth_penalty` field of `SmoothBasisSpec::Pca`.
const fn default_pca_smooth_penalty() -> f64 {
1.0
}
/// Serde default for the `chunk_size` field of `SmoothBasisSpec::Pca`.
const fn default_pca_chunk_size() -> usize {
4096
}
/// Random-effect term over the factor stored in data column `feature_col`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomEffectTermSpec {
pub name: String,
pub feature_col: usize,
pub drop_first_level: bool,
/// Defaults to `true` when absent from serialized input.
#[serde(default = "default_random_effect_penalized")]
pub penalized: bool,
/// Required (Some) for frozen specs.
#[serde(default)]
pub frozen_levels: Option<Vec<u64>>,
}
/// Serde default for `RandomEffectTermSpec::penalized` (on).
fn default_random_effect_penalized() -> bool {
true
}
/// Checks that a frozen MeasureJet vector has `expected` entries, all finite and positive.
///
/// # Errors
/// Returns a message naming `label`, `term_name` and `field` when the length is wrong
/// or any value is non-finite or not strictly positive. The length check runs first.
fn validate_measure_jet_positive_vec_len(
    label: &str,
    term_name: &str,
    field: &str,
    values: &[f64],
    expected: usize,
) -> Result<(), String> {
    let found = values.len();
    if found != expected {
        return Err(SmoothError::invalid_config(format!(
            "{label} term '{term_name}' frozen MeasureJet {field} has length {found}, expected {expected}"
        ))
        .into());
    }
    let all_positive = values.iter().all(|v| v.is_finite() && *v > 0.0);
    if all_positive {
        Ok(())
    } else {
        Err(SmoothError::invalid_config(format!(
            "{label} term '{term_name}' frozen MeasureJet {field} values must be positive and finite"
        ))
        .into())
    }
}
/// Full term specification of a model: linear, random-effect and smooth terms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermCollectionSpec {
pub linear_terms: Vec<LinearTermSpec>,
pub random_effect_terms: Vec<RandomEffectTermSpec>,
pub smooth_terms: Vec<SmoothTermSpec>,
}
/// Lightweight frozen-state check used for bases wrapped by other bases.
///
/// Recurses through `ByVariable` / `FactorSumToZero` wrappers. It checks only that
/// BSpline knots are Provided or PeriodicUniform, and that ThinPlate centers are
/// UserProvided with an identifiability other than `OrthogonalToParametric`. Every
/// other basis is accepted as-is.
fn validate_smooth_basis_frozen(
    basis: &SmoothBasisSpec,
    label: &str,
    term_name: &str,
) -> Result<(), String> {
    match basis {
        SmoothBasisSpec::ByVariable { inner, .. }
        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
            validate_smooth_basis_frozen(inner, label, term_name)
        }
        SmoothBasisSpec::BSpline1D { spec, .. } => {
            let knots_frozen = matches!(
                spec.knotspec,
                BSplineKnotSpec::Provided(_) | BSplineKnotSpec::PeriodicUniform { .. }
            );
            if knots_frozen {
                Ok(())
            } else {
                Err(format!(
                    "{label} term '{term_name}' is not frozen: BSpline knotspec must be Provided or PeriodicUniform"
                ))
            }
        }
        SmoothBasisSpec::ThinPlate { spec, .. } => {
            let centers_frozen = matches!(spec.center_strategy, CenterStrategy::UserProvided(_));
            let identifiability_frozen = !matches!(
                spec.identifiability,
                SpatialIdentifiability::OrthogonalToParametric
            );
            if !centers_frozen {
                Err(format!(
                    "{label} term '{term_name}' is not frozen: ThinPlate centers must be UserProvided"
                ))
            } else if !identifiability_frozen {
                Err(format!(
                    "{label} term '{term_name}' is not frozen: ThinPlate identifiability must be FrozenTransform or None"
                ))
            } else {
                Ok(())
            }
        }
        _ => Ok(()),
    }
}
impl TermCollectionSpec {
pub fn write_structural_shape_hash(&self, h: &mut crate::warm_start::Fingerprinter) {
h.write_str("term-collection");
h.write_usize(self.linear_terms.len());
for linear in &self.linear_terms {
h.write_str(&linear.name);
}
h.write_usize(self.random_effect_terms.len());
h.write_usize(self.smooth_terms.len());
for smooth in &self.smooth_terms {
h.write_str(&smooth.name);
h.write_str(smooth.basis.structural_kind());
for col in smooth.basis.structural_feature_cols() {
h.write_usize(col);
}
}
}
pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
for linear in &self.linear_terms {
if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
&& (!min.is_finite() || !max.is_finite() || min > max)
{
return Err(SmoothError::invalid_config(format!(
"{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
linear.name
))
.into());
}
if let Some(min) = linear.coefficient_min
&& !min.is_finite()
{
return Err(SmoothError::invalid_config(format!(
"{label} linear term '{}' has non-finite coefficient minimum {min}",
linear.name
))
.into());
}
if let Some(max) = linear.coefficient_max
&& !max.is_finite()
{
return Err(SmoothError::invalid_config(format!(
"{label} linear term '{}' has non-finite coefficient maximum {max}",
linear.name
))
.into());
}
if let LinearCoefficientGeometry::Bounded { min, max, prior } =
&linear.coefficient_geometry
{
if !min.is_finite() || !max.is_finite() || min >= max {
return Err(SmoothError::invalid_config(format!(
"{label} bounded term '{}' has invalid bounds [{min}, {max}]",
linear.name
))
.into());
}
match prior {
BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
BoundedCoefficientPriorSpec::Beta { a, b } => {
if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
return Err(SmoothError::invalid_config(format!(
"{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
linear.name
))
.into());
}
}
}
}
}
for st in &self.smooth_terms {
match &st.basis {
SmoothBasisSpec::ByVariable { inner, .. } => {
validate_smooth_basis_frozen(inner, label, &st.name)?;
let nested = SmoothTermSpec {
name: st.name.clone(),
basis: (**inner).clone(),
shape: st.shape,
joint_null_rotation: None,
};
TermCollectionSpec {
linear_terms: Vec::new(),
random_effect_terms: Vec::new(),
smooth_terms: vec![nested],
}
.validate_frozen(label)?;
}
SmoothBasisSpec::FactorSumToZero { inner, levels, .. } => {
if levels.len() < 2 {
return Err(format!(
"{label} term '{}' has invalid frozen sz levels",
st.name
));
}
validate_smooth_basis_frozen(inner, label, &st.name)?;
}
SmoothBasisSpec::BSpline1D { spec, .. } => {
if !matches!(
spec.knotspec,
BSplineKnotSpec::Provided(_) | BSplineKnotSpec::PeriodicUniform { .. }
) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: BSpline knotspec must be Provided or PeriodicUniform",
st.name
))
.into());
}
}
SmoothBasisSpec::ThinPlate { spec, .. } => {
if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
st.name
))
.into());
}
if matches!(
spec.identifiability,
SpatialIdentifiability::OrthogonalToParametric
) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
st.name
))
.into());
}
}
SmoothBasisSpec::Sphere { spec, .. } => {
if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: Sphere centers must be UserProvided",
st.name
))
.into());
}
if matches!(spec.method, crate::basis::SphereMethod::Harmonic)
&& spec.max_degree.is_none_or(|d| d == 0)
{
return Err(format!(
"{label} term '{}' is not frozen: sphere max_degree must be positive",
st.name
));
}
}
SmoothBasisSpec::ConstantCurvature { spec, .. } => {
if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: ConstantCurvature centers must be UserProvided",
st.name
))
.into());
}
if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: ConstantCurvature length_scale must be the realized positive value",
st.name
))
.into());
}
}
SmoothBasisSpec::MeasureJet { spec, .. } => {
let centers = match &spec.center_strategy {
CenterStrategy::UserProvided(centers) => centers,
_ => {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: MeasureJet centers must be UserProvided",
st.name
))
.into());
}
};
if centers.nrows() == 0 {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: MeasureJet centers are empty",
st.name
))
.into());
}
if !(spec.length_scale.is_finite() && spec.length_scale > 0.0) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: MeasureJet length_scale must be the realized positive value",
st.name
))
.into());
}
let frozen = spec.frozen_quadrature.as_ref().ok_or_else(|| {
SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: MeasureJet frozen_quadrature payload is missing",
st.name
))
})?;
if frozen.masses.len() != centers.nrows() {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' frozen MeasureJet has {} masses for {} centers",
st.name,
frozen.masses.len(),
centers.nrows()
))
.into());
}
let total_mass = frozen.masses.sum();
if frozen
.masses
.iter()
.any(|mass| !(mass.is_finite() && *mass >= 0.0))
|| !(total_mass.is_finite() && total_mass > 0.0)
{
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' frozen MeasureJet masses must be finite, nonnegative, and have positive total mass",
st.name
))
.into());
}
let n_levels = frozen.eps_band.len();
if n_levels == 0
|| frozen
.eps_band
.iter()
.any(|eps| !(eps.is_finite() && *eps > 0.0))
{
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' frozen MeasureJet eps_band must be nonempty, finite, and positive",
st.name
))
.into());
}
for (idx, pair) in frozen.eps_band.windows(2).enumerate() {
if pair[1] <= pair[0] {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' frozen MeasureJet eps_band is not strictly ascending at {idx}: {} then {}",
st.name,
pair[0],
pair[1]
))
.into());
}
}
validate_measure_jet_positive_vec_len(
label,
&st.name,
"support_means",
&frozen.support_means,
n_levels,
)?;
let per_level = crate::basis::measure_jet_multiscale_mode(spec);
if per_level {
validate_measure_jet_positive_vec_len(
label,
&st.name,
"penalty_normalization_scales",
&frozen.penalty_normalization_scales,
n_levels,
)?;
validate_measure_jet_positive_vec_len(
label,
&st.name,
"raw_penalty_normalization_scales",
&frozen.raw_penalty_normalization_scales,
n_levels,
)?;
if frozen.fused_penalty_normalization_scale.is_some() {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' per-level MeasureJet must not carry a fused penalty normalization scale",
st.name
))
.into());
}
} else {
if !frozen.penalty_normalization_scales.is_empty()
|| !frozen.raw_penalty_normalization_scales.is_empty()
{
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' fused MeasureJet must not carry per-level penalty normalization scales",
st.name
))
.into());
}
match frozen.fused_penalty_normalization_scale {
Some(scale) if scale.is_finite() && scale > 0.0 => {}
Some(scale) => {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' fused MeasureJet penalty normalization scale must be positive and finite, got {scale}",
st.name
))
.into());
}
None => {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' fused MeasureJet is missing its penalty normalization scale",
st.name
))
.into());
}
}
}
}
SmoothBasisSpec::Matern { spec, .. } => {
if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: Matern centers must be UserProvided",
st.name
))
.into());
}
}
SmoothBasisSpec::Duchon { spec, .. } => {
if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: Duchon centers must be UserProvided",
st.name
))
.into());
}
if matches!(
spec.identifiability,
SpatialIdentifiability::OrthogonalToParametric
) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
st.name
))
.into());
}
}
SmoothBasisSpec::Pca {
centered,
center_mean,
pca_basis_path,
..
} => {
if *centered && center_mean.is_none() && pca_basis_path.is_none() {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: centered Pca missing center_mean",
st.name
))
.into());
}
}
SmoothBasisSpec::BySmooth { smooth, by_kind } => {
if let SmoothBasisSpec::BySmooth { .. } = smooth.as_ref() {
return Err(format!("{label} term '{}' has nested by-smooths", st.name));
}
match by_kind {
ByVarKind::Numeric { .. } => {}
ByVarKind::Factor { frozen_levels, .. } if frozen_levels.is_none() => {
return Err(format!(
"{label} term '{}' is not frozen: by-factor levels missing",
st.name
));
}
ByVarKind::Factor { .. } => {}
}
let nested = TermCollectionSpec {
linear_terms: vec![],
random_effect_terms: vec![],
smooth_terms: vec![SmoothTermSpec {
name: st.name.clone(),
basis: (**smooth).clone(),
shape: st.shape,
joint_null_rotation: None,
}],
};
nested.validate_frozen(label)?;
}
SmoothBasisSpec::FactorSmooth { spec } => {
if spec.group_frozen_levels.is_none() {
return Err(format!(
"{label} term '{}' is not frozen: factor-smooth levels missing",
st.name
));
}
if !matches!(
spec.marginal.knotspec,
BSplineKnotSpec::Provided(_) | BSplineKnotSpec::PeriodicUniform { .. }
) {
return Err(format!(
"{label} term '{}' is not frozen: factor-smooth marginal knots missing",
st.name
));
}
}
SmoothBasisSpec::TensorBSpline { spec, .. } => {
for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
if !matches!(
marginal.knotspec,
BSplineKnotSpec::Provided(_) | BSplineKnotSpec::PeriodicUniform { .. }
) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided or PeriodicUniform",
st.name, dim
))
.into());
}
}
if matches!(
spec.identifiability,
TensorBSplineIdentifiability::SumToZero
| TensorBSplineIdentifiability::MarginalSumToZero
) {
return Err(SmoothError::invalid_config(format!(
"{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
st.name
))
.into());
}
}
}
}
for rt in &self.random_effect_terms {
if rt.frozen_levels.is_none() {
return Err(SmoothError::invalid_config(format!(
"{label} random-effect term '{}' is not frozen: missing frozen_levels",
rt.name
))
.into());
}
}
Ok(())
}
pub fn remap_feature_columns<E, F>(&self, mut remap: F) -> Result<TermCollectionSpec, E>
where
F: FnMut(usize) -> Result<usize, E>,
{
let mut out = self.clone();
for lt in &mut out.linear_terms {
lt.feature_col = remap(lt.feature_col)?;
for fc in lt.feature_cols.iter_mut() {
*fc = remap(*fc)?;
}
for (col, _bits) in lt.categorical_levels.iter_mut() {
*col = remap(*col)?;
}
}
for rt in &mut out.random_effect_terms {
rt.feature_col = remap(rt.feature_col)?;
}
for st in &mut out.smooth_terms {
remap_smooth_basis_feature_columns(&mut st.basis, &mut remap)?;
}
Ok(out)
}
}
/// Rewrites, in place, every feature-column index referenced by `basis`.
///
/// Wrapper variants (`ByVariable`, `FactorSumToZero`, `BySmooth`) remap their own
/// `by` column first and then recurse into the wrapped basis. For `FactorSmooth`,
/// the continuous columns are remapped before the group column. The first error
/// from `remap` is returned immediately, which can leave `basis` partially
/// rewritten; `remap_feature_columns` only calls this on a clone.
fn remap_smooth_basis_feature_columns<E, F>(
    basis: &mut SmoothBasisSpec,
    remap: &mut F,
) -> Result<(), E>
where
    F: FnMut(usize) -> Result<usize, E>,
{
    match basis {
        SmoothBasisSpec::BSpline1D { feature_col, .. } => {
            *feature_col = remap(*feature_col)?;
        }
        SmoothBasisSpec::ByVariable { inner, by_col, .. }
        | SmoothBasisSpec::FactorSumToZero { inner, by_col, .. } => {
            *by_col = remap(*by_col)?;
            remap_smooth_basis_feature_columns(inner, remap)?;
        }
        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
            // Both `by` kinds carry the column under the same field name.
            let (ByVarKind::Numeric { feature_col } | ByVarKind::Factor { feature_col, .. }) =
                by_kind;
            *feature_col = remap(*feature_col)?;
            remap_smooth_basis_feature_columns(smooth, remap)?;
        }
        SmoothBasisSpec::FactorSmooth { spec } => {
            spec.continuous_cols
                .iter_mut()
                .try_for_each(|col| remap(*col).map(|mapped| *col = mapped))?;
            spec.group_col = remap(spec.group_col)?;
        }
        SmoothBasisSpec::ThinPlate { feature_cols, .. }
        | SmoothBasisSpec::Sphere { feature_cols, .. }
        | SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
        | SmoothBasisSpec::Matern { feature_cols, .. }
        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
        | SmoothBasisSpec::Duchon { feature_cols, .. }
        | SmoothBasisSpec::Pca { feature_cols, .. }
        | SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
            feature_cols
                .iter_mut()
                .try_for_each(|col| remap(*col).map(|mapped| *col = mapped))?;
        }
    }
    Ok(())
}
/// Optional structural description attached to a [`BlockwisePenalty`].
///
/// `BlockwisePenalty::ridge` sets `Ridge(scale)`. `BlockwisePenalty::kronecker` sets
/// `Kronecker(factors)` using the factor matrices supplied by the caller.
#[derive(Debug, Clone)]
pub enum PenaltyStructureHint {
/// The local block is `scale · I` (built by `BlockwisePenalty::ridge`).
Ridge(f64),
/// Factor matrices of a Kronecker-structured block. Presumably `local` equals their
/// Kronecker product; this is NOT checked by `BlockwisePenalty::kronecker` (TODO confirm).
Kronecker(Vec<Array2<f64>>),
}
/// A penalty matrix that is non-zero only on one square diagonal block of the
/// coefficient vector.
///
/// `to_global` embeds `local` at `[col_range, col_range]` of a `p_total × p_total`
/// zero matrix.
#[derive(Clone)]
pub struct BlockwisePenalty {
/// Coefficient columns covered by the block (used on both axes).
pub col_range: Range<usize>,
/// The `col_range.len() × col_range.len()` local penalty block.
pub local: Array2<f64>,
/// Prior mean for the penalized coefficients; constructors default it to `Zero`.
pub prior_mean: crate::solver::estimate::CoefficientPriorMean,
/// Optional structure (ridge / Kronecker) that specialized code may exploit.
pub structure_hint: Option<PenaltyStructureHint>,
/// Optional operator form of the penalty; `Debug` shows only its `dim()`.
pub op: Option<std::sync::Arc<dyn crate::terms::analytic_penalties::PenaltyOp>>,
}
impl std::fmt::Debug for BlockwisePenalty {
    /// Summarises the penalty without dumping matrix contents.
    ///
    /// The local block is printed as `rows×cols` and the optional operator as its `dim()`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (rows, cols) = self.local.dim();
        let op_dim = self.op.as_ref().map(|op| op.dim());
        let mut out = f.debug_struct("BlockwisePenalty");
        out.field("col_range", &self.col_range);
        out.field("local", &format_args!("{rows}×{cols}"));
        out.field("prior_mean", &self.prior_mean);
        out.field("structure_hint", &self.structure_hint);
        out.field("op", &op_dim);
        out.finish()
    }
}
impl BlockwisePenalty {
pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
assert_eq!(col_range.len(), local.nrows());
assert_eq!(col_range.len(), local.ncols());
Self {
col_range,
local,
prior_mean: crate::solver::estimate::CoefficientPriorMean::Zero,
structure_hint: None,
op: None,
}
}
pub fn with_prior_mean(
mut self,
prior_mean: crate::solver::estimate::CoefficientPriorMean,
) -> Self {
self.prior_mean = prior_mean;
self
}
pub fn with_op(
mut self,
op: Option<std::sync::Arc<dyn crate::terms::analytic_penalties::PenaltyOp>>,
) -> Self {
self.op = op;
self
}
pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
let block_size = col_range.len();
let mut local = Array2::<f64>::zeros((block_size, block_size));
for i in 0..block_size {
local[[i, i]] = scale;
}
Self {
col_range,
local,
prior_mean: crate::solver::estimate::CoefficientPriorMean::Zero,
structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
op: None,
}
}
pub fn kronecker(
col_range: Range<usize>,
local: Array2<f64>,
factors: Vec<Array2<f64>>,
) -> Self {
assert_eq!(col_range.len(), local.nrows());
assert_eq!(col_range.len(), local.ncols());
Self {
col_range,
local,
prior_mean: crate::solver::estimate::CoefficientPriorMean::Zero,
structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
op: None,
}
}
pub fn to_global(&self, p_total: usize) -> Array2<f64> {
let mut g = Array2::<f64>::zeros((p_total, p_total));
let r = &self.col_range;
assert!(
r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
"BlockwisePenalty::to_global shape invariant violated: \
col_range={}..{}, local={}x{}, p_total={}",
r.start,
r.end,
self.local.nrows(),
self.local.ncols(),
p_total,
);
g.slice_mut(s![r.start..r.end, r.start..r.end])
.assign(&self.local);
g
}
pub(crate) fn to_penalty_matrix(
&self,
total_dim: usize,
) -> crate::custom_family::PenaltyMatrix {
crate::custom_family::PenaltyMatrix::Blockwise {
local: self.local.clone(),
col_range: self.col_range.clone(),
total_dim,
}
}
#[inline]
pub fn block_size(&self) -> usize {
self.col_range.len()
}
}
/// Assembles the dense `p_total × p_total` matrix `Σ_i lambdas[i] · S_i`.
///
/// Each `S_i` is `penalties[i].local` embedded at `penalties[i].col_range` on both axes.
///
/// # Panics
/// - if `penalties.len() != lambdas.len()`;
/// - if any lambda is non-finite or negative;
/// - if any `col_range` extends past `p_total`;
/// - if any `local` block is not square with side `col_range.len()`. This matches the
///   invariant `BlockwisePenalty::to_global` enforces. Without it, `scaled_add` would
///   silently broadcast a mis-sized block (e.g. `1×1`) across the whole range.
pub fn weighted_blockwise_penalty_sum(
    penalties: &[BlockwisePenalty],
    lambdas: &[f64],
    p_total: usize,
) -> Array2<f64> {
    assert_eq!(penalties.len(), lambdas.len());
    for (idx, &lam) in lambdas.iter().enumerate() {
        assert!(
            lam.is_finite() && lam >= 0.0,
            "weighted_blockwise_penalty_sum: lambdas[{idx}] = {lam} is invalid (must be finite and non-negative; negative smoothing parameters violate S_λ ⪰ 0)",
        );
    }
    for (idx, bp) in penalties.iter().enumerate() {
        let r = &bp.col_range;
        assert!(
            r.end <= p_total,
            "weighted_blockwise_penalty_sum: penalties[{idx}] col_range {:?} exceeds p_total = {p_total}",
            r,
        );
        assert!(
            bp.local.nrows() == r.len() && bp.local.ncols() == r.len(),
            "weighted_blockwise_penalty_sum: penalties[{idx}] local block is {}x{} but col_range {:?} spans {} columns",
            bp.local.nrows(),
            bp.local.ncols(),
            r,
            r.len(),
        );
    }
    let mut out = Array2::<f64>::zeros((p_total, p_total));
    for (bp, &lam) in penalties.iter().zip(lambdas) {
        let r = &bp.col_range;
        out.slice_mut(s![r.start..r.end, r.start..r.end])
            .scaled_add(lam, &bp.local);
    }
    out
}
/// Kronecker-structured penalty description for a single tensor-product smooth.
///
/// Each marginal penalty is eigendecomposed once, in [`KroneckerPenaltySystem::new`].
/// Log-determinant quantities can then be evaluated mode by mode: the mode at
/// multi-index `(i_1, …, i_d)` has eigenvalue `Σ_k λ_k · e_{k,i_k}`, plus the optional
/// double penalty on jointly-null modes. The full `p_total × p_total` matrix is never formed.
#[derive(Debug, Clone)]
pub struct KroneckerPenaltySystem {
/// Marginal penalty matrices, one per tensor dimension.
pub marginal_penalties: Vec<Array2<f64>>,
/// `(eigenvalues, eigenvectors)` of each marginal penalty, as produced by
/// `kronecker_marginal_eigensystems`.
pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
/// Basis size of each marginal; `p_total()` is their product.
pub marginal_dims: Vec<usize>,
/// Whether an extra penalty acts on the jointly-null modes (adds one λ).
pub has_double_penalty: bool,
}
impl KroneckerPenaltySystem {
    /// Validates the marginal penalties and eigendecomposes each one.
    ///
    /// # Errors
    /// Returns a `BasisError` in three cases:
    /// - the number of penalties and dims differ;
    /// - a marginal penalty is not a `dim × dim` square matrix for its declared
    ///   marginal dimension;
    /// - `kronecker_marginal_eigensystems` fails (its error is wrapped as
    ///   `BasisError::InvalidInput`).
    ///
    /// The shape check matters because the mode loop indexes the eigenvalues of
    /// penalty `k` by `0..marginal_dims[k]`.
    pub fn new(
        marginal_penalties: Vec<Array2<f64>>,
        marginal_dims: Vec<usize>,
        has_double_penalty: bool,
    ) -> Result<Self, BasisError> {
        if marginal_penalties.len() != marginal_dims.len() {
            crate::bail_dim_basis!(
                "KroneckerPenaltySystem: {} penalties vs {} dims",
                marginal_penalties.len(),
                marginal_dims.len()
            );
        }
        for (k, (penalty, &dim)) in marginal_penalties.iter().zip(&marginal_dims).enumerate() {
            if penalty.nrows() != dim || penalty.ncols() != dim {
                crate::bail_dim_basis!(
                    "KroneckerPenaltySystem: marginal penalty {} is {}x{} but marginal dim is {}",
                    k,
                    penalty.nrows(),
                    penalty.ncols(),
                    dim
                );
            }
        }
        let eigensystems =
            kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
                .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
        Ok(Self {
            marginal_penalties,
            marginal_eigensystems: eigensystems,
            marginal_dims,
            has_double_penalty,
        })
    }

    /// Total tensor-product basis size (product of the marginal dims).
    pub fn p_total(&self) -> usize {
        self.marginal_dims.iter().copied().product()
    }

    /// Number of tensor dimensions.
    pub fn ndim(&self) -> usize {
        self.marginal_dims.len()
    }

    /// Number of smoothing parameters: one per marginal, plus one for the double penalty.
    pub fn num_penalties(&self) -> usize {
        self.marginal_dims.len() + usize::from(self.has_double_penalty)
    }

    /// Delegates to `kronecker_logdet_and_derivatives` with the marginal eigenvalues.
    ///
    /// NOTE(review): presumably returns `(log|S_λ|, gradient, Hessian)`; the exact
    /// conventions (ridge handling, derivative variable) live in that function.
    ///
    /// # Panics
    /// If `lambdas.len() != self.num_penalties()`.
    pub fn logdet_and_derivatives(
        &self,
        lambdas: &[f64],
        ridge: f64,
    ) -> (f64, Array1<f64>, Array2<f64>) {
        let n_pen = self.num_penalties();
        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
        let marginal_evals: Vec<_> = self
            .marginal_eigensystems
            .iter()
            .map(|(evals, _)| evals.view())
            .collect();
        kronecker_logdet_and_derivatives(
            &marginal_evals,
            &self.marginal_dims,
            lambdas,
            self.has_double_penalty,
            ridge,
        )
    }

    /// Pseudo-log-determinant, rank, gradient and Hessian of the Kronecker penalty.
    ///
    /// For each mode `(i_1, …, i_d)`, the eigenvalue is built as follows:
    /// - start from `σ = Σ_k λ_k e_{k,i_k}`;
    /// - if the structural eigenvalue `Σ_k e_{k,i_k}` is ≤ 1e-12 and a double penalty
    ///   is present, add `λ_d`;
    /// - otherwise (the mode is structurally non-null), add `ridge`.
    ///
    /// Only modes with `σ > 1e-12` contribute to the log-determinant and the rank.
    ///
    /// The gradient and Hessian are derivatives of `Σ ln σ` with respect to
    /// `ρ_k = ln λ_k`. Here `c_k = ∂σ/∂ρ_k` is `λ_k e_{k,i_k}` for marginals and
    /// `λ_d` (jointly-null modes only) for the double penalty.
    ///
    /// If any marginal has size 0, the system has no modes and zeros are returned.
    ///
    /// # Panics
    /// If `lambdas.len() != self.num_penalties()`.
    pub fn logdet_rank_and_derivatives(
        &self,
        lambdas: &[f64],
        ridge: f64,
    ) -> (f64, usize, Array1<f64>, Array2<f64>) {
        const EIGENVALUE_POSITIVITY_FLOOR: f64 = 1e-12;
        const STRUCTURAL_ZERO_FLOOR: f64 = 1e-12;
        let n_pen = self.num_penalties();
        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
        let d = self.marginal_dims.len();
        let mut logdet = 0.0;
        let mut rank = 0usize;
        let mut grad = Array1::<f64>::zeros(n_pen);
        let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
        // An empty marginal means p_total == 0. Without this early return, the mode
        // loop below would index an empty eigenvalue vector.
        if self.marginal_dims.contains(&0) {
            return (logdet, rank, grad, hess);
        }
        // contrib[k] = c_k for the current mode; computed once per mode instead of
        // once per (k, l) Hessian entry.
        let mut contrib = vec![0.0_f64; n_pen];
        let mut multi_idx = vec![0usize; d];
        loop {
            let mut sigma = 0.0;
            let mut structural_sigma = 0.0;
            for k in 0..d {
                let eigenvalue = self.marginal_eigensystems[k].0[multi_idx[k]];
                structural_sigma += eigenvalue;
                contrib[k] = lambdas[k] * eigenvalue;
                sigma += contrib[k];
            }
            let joint_null = structural_sigma <= STRUCTURAL_ZERO_FLOOR;
            if self.has_double_penalty {
                // The double penalty only acts on jointly-null modes.
                contrib[d] = if joint_null { lambdas[d] } else { 0.0 };
                if joint_null {
                    sigma += lambdas[d];
                }
            }
            if structural_sigma > STRUCTURAL_ZERO_FLOOR {
                sigma += ridge;
            }
            if sigma > EIGENVALUE_POSITIVITY_FLOOR {
                rank += 1;
                logdet += sigma.ln();
                let inv_sigma = 1.0 / sigma;
                let inv_sigma2 = inv_sigma * inv_sigma;
                for k in 0..n_pen {
                    let ck = contrib[k];
                    grad[k] += ck * inv_sigma;
                    hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
                    for l in (k + 1)..n_pen {
                        let off = -ck * contrib[l] * inv_sigma2;
                        hess[[k, l]] += off;
                        hess[[l, k]] += off;
                    }
                }
            }
            // Advance the row-major multi-index (last dimension fastest); stop once
            // every digit has wrapped around.
            let mut advanced = false;
            for dim in (0..d).rev() {
                multi_idx[dim] += 1;
                if multi_idx[dim] < self.marginal_dims[dim] {
                    advanced = true;
                    break;
                }
                multi_idx[dim] = 0;
            }
            if !advanced {
                break;
            }
        }
        (logdet, rank, grad, hess)
    }
}
#[cfg(test)]
mod kronecker_penalty_system_tests {
    use super::KroneckerPenaltySystem;
    use ndarray::array;

    /// Two diagonal 2×2 marginals, each with a one-dimensional null space, plus a
    /// double penalty. Only one Kronecker mode is jointly null, so the
    /// double-penalty gradient must be exactly 1 and its Hessian diagonal 0.
    #[test]
    fn double_penalty_rank_derivatives_use_only_joint_null_space() {
        let system = KroneckerPenaltySystem::new(
            vec![
                array![[0.0, 0.0], [0.0, 2.0]],
                array![[0.0, 0.0], [0.0, 3.0]],
            ],
            vec![2usize, 2usize],
            true,
        )
        .unwrap();
        let lambdas = [5.0, 7.0, 11.0];
        let (logdet, rank, grad, hess) = system.logdet_rank_and_derivatives(&lambdas, 0.0);
        // The four mode eigenvalues are:
        // - 11 (jointly null, double penalty only);
        // - 5·2 = 10;
        // - 7·3 = 21;
        // - 5·2 + 7·3 = 31.
        let expected_logdet: f64 = [11.0_f64, 21.0, 10.0, 31.0]
            .iter()
            .map(|sigma| sigma.ln())
            .sum();
        assert_eq!(rank, 4);
        assert!((logdet - expected_logdet).abs() <= 1e-12);
        let double_grad = grad[2];
        assert!(
            (double_grad - 1.0).abs() <= 1e-12,
            "double-penalty rank derivative must count only the joint null mode, got {}",
            double_grad
        );
        assert!(hess[[2, 2]].abs() <= 1e-12);
    }
}
/// Assembled design for a term collection: the design matrix plus the penalty,
/// constraint and column-range bookkeeping that accompanies it.
#[derive(Clone, Debug)]
pub struct TermCollectionDesign {
/// Assembled design matrix. Its column count is used as the total coefficient
/// dimension by `penalties_as_penalty_matrix`.
pub design: DesignMatrix,
/// Blockwise penalties over the coefficient vector (`num_penalties` reports their count).
pub penalties: Vec<BlockwisePenalty>,
/// Null-space dimensions; presumably one per entry of `penalties` (verify).
pub nullspace_dims: Vec<usize>,
/// Metadata for the retained penalty blocks.
pub penaltyinfo: Vec<PenaltyBlockInfo>,
/// Metadata for penalty blocks dropped during assembly (assumed from the name; confirm).
pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
/// Optional per-coefficient lower bounds.
pub coefficient_lower_bounds: Option<Array1<f64>>,
/// Optional linear inequality constraints on the coefficients.
pub linear_constraints: Option<LinearInequalityConstraints>,
/// Columns of `design` holding the intercept.
pub intercept_range: Range<usize>,
/// Presumably `(term name, columns of design)` for each linear term.
pub linear_ranges: Vec<(String, Range<usize>)>,
/// Presumably `(term name, columns of design)` for each random-effect term.
pub random_effect_ranges: Vec<(String, Range<usize>)>,
/// Presumably `(term name, level codes)` for each random-effect term; verify encoding.
pub random_effect_levels: Vec<(String, Vec<u64>)>,
/// Smooth-term design; `kronecker_penalty_system` inspects `smooth.terms`.
pub smooth: SmoothDesign,
}
impl TermCollectionDesign {
    /// Converts every blockwise penalty to a `PenaltyMatrix` sized to `design.ncols()`.
    pub fn penalties_as_penalty_matrix(&self) -> Vec<crate::custom_family::PenaltyMatrix> {
        let total_dim = self.design.ncols();
        let mut matrices = Vec::with_capacity(self.penalties.len());
        for penalty in &self.penalties {
            matrices.push(penalty.to_penalty_matrix(total_dim));
        }
        matrices
    }

    /// Number of blockwise penalties in the design.
    #[inline]
    pub fn num_penalties(&self) -> usize {
        self.penalties.len()
    }

    /// Realizes coefficient groups against this design (delegates to the free function).
    pub fn realize_coefficient_groups(
        &self,
        groups: &[CoefficientGroupSpec],
        base_prior: &crate::types::RhoPrior,
    ) -> Result<RealizedCoefficientGroups, BasisError> {
        realize_coefficient_groups(self, groups, base_prior)
    }

    /// Returns a Kronecker penalty system when it fully describes the smooth part.
    ///
    /// This requires exactly one smooth term, and that term must carry a
    /// Kronecker-factored basis. Any non-Kronecker smooth term, a second Kronecker
    /// term, no smooth terms at all, or a failed `KroneckerPenaltySystem::new` all
    /// yield `None`.
    pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
        let mut only_kron: Option<&KroneckerFactoredBasis> = None;
        for term in self.smooth.terms.iter() {
            match (term.kronecker_factored.as_ref(), only_kron) {
                // A plain smooth term or a second Kronecker term disqualifies the design.
                (None, _) | (Some(_), Some(_)) => return None,
                (Some(kron), None) => only_kron = Some(kron),
            }
        }
        let kron = only_kron?;
        KroneckerPenaltySystem::new(
            kron.marginal_penalties.clone(),
            kron.marginal_dims.clone(),
            kron.has_double_penalty,
        )
        .ok()
    }
}
/// A fitted term collection: the solver result plus the design it was fitted on.
#[derive(Clone)]
pub struct FittedTermCollection {
pub fit: UnifiedFitResult,
pub design: TermCollectionDesign,
/// Diagnostics from adaptive regularization; presumably `None` when it did not run.
pub adaptive_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
}
/// Like `FittedTermCollection`, but also carrying the resolved term-collection spec.
#[derive(Clone)]
pub struct FittedTermCollectionWithSpec {
pub fit: UnifiedFitResult,
pub design: TermCollectionDesign,
/// The spec as resolved for this fit. NOTE(review): presumably the frozen form to
/// reuse at prediction time; confirm against callers.
pub resolvedspec: TermCollectionSpec,
/// Diagnostics from adaptive regularization; presumably `None` when it did not run.
pub adaptive_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
}
/// Configuration bundle for a latent-coordinate smooth term.
///
/// NOTE(review): field semantics below are inferred from names and types only; confirm
/// against the latent-coordinate solver.
#[derive(Clone)]
pub struct StandardLatentCoordConfig {
/// Shared latent coordinate values.
pub values: std::sync::Arc<crate::terms::latent::LatentCoordValues>,
/// Index of the smooth term these coordinates feed.
pub term_index: crate::types::SmoothTermIdx,
/// Feature columns associated with the latent coordinates.
pub feature_cols: Vec<usize>,
/// Manifold on which the latent coordinates live.
pub manifold: crate::terms::latent::LatentManifold,
/// Presumably: whether `manifold` was chosen automatically.
pub manifold_auto: bool,
pub retraction_registry: crate::solver::latent_cache::LatentRetractionRegistry,
pub analytic_penalties: Option<std::sync::Arc<crate::terms::AnalyticPenaltyRegistry>>,
}
/// Per-term spatial weight map reported by adaptive regularization.
///
/// NOTE(review): the weight semantics are inferred from field names (inverse weights
/// for the magnitude / gradient / Laplacian penalty parts at each collocation point);
/// verify against the adaptive-regularization code.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveSpatialMap {
pub termname: String,
pub feature_cols: Vec<usize>,
/// Presumably one row per collocation point.
pub collocation_points: Array2<f64>,
pub inv_magweight: Array1<f64>,
pub invgradweight: Array1<f64>,
pub inv_lapweight: Array1<f64>,
}
/// Summary of an adaptive-regularization run.
///
/// NOTE(review): the `epsilon_*` meanings are not visible here; presumably the final
/// smoothing floors for the magnitude / gradient / curvature parts (confirm).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveRegularizationDiagnostics {
pub epsilon_0: f64,
pub epsilon_g: f64,
pub epsilon_c: f64,
/// Number of outer iterations spent adjusting the epsilons.
pub epsilon_outer_iterations: usize,
/// Number of majorize-minimize iterations (assumed from the name).
pub mm_iterations: usize,
pub converged: bool,
/// One weight map per adaptively regularized spatial term.
pub maps: Vec<AdaptiveSpatialMap>,
}
/// Conditioning applied to one linear design column.
///
/// Presumably the column was centred by `mean` and divided by `scale` before fitting;
/// TODO confirm against the code that applies and undoes it.
#[derive(Debug, Clone)]
struct LinearColumnConditioning {
col_idx: usize,
mean: f64,
scale: f64,
}
/// Column conditioning for a linear fit, together with the intercept column index.
///
/// NOTE(review): presumably the intercept absorbs the centring shifts when
/// coefficients are mapped back to the original scale; confirm.
#[derive(Debug, Clone, Default)]
struct LinearFitConditioning {
intercept_idx: usize,
columns: Vec<LinearColumnConditioning>,
}
/// First and second derivatives of one spatial term's design and penalties with respect
/// to a ψ (log-κ) coordinate.
///
/// NOTE(review): everything below is inferred from field names and types; verify
/// against the hyper-gradient code that builds and consumes this struct.
/// - `*_local` matrices are restricted to the term's coefficient block, which sits at
///   `global_range` inside a `total_p`-wide coefficient vector.
/// - `penalty_index` / `penalty_indices` locate the term's penalties in the global list.
/// - The `aniso_*` fields carry cross-axis terms for per-axis ψ.
/// - The `implicit_*` fields describe an operator-form design derivative used instead
///   of dense matrices.
#[derive(Clone)]
pub(crate) struct SpatialPsiDerivative {
pub penalty_index: usize,
pub penalty_indices: Vec<usize>,
pub global_range: Range<usize>,
pub total_p: usize,
pub x_psi_local: Array2<f64>,
pub s_psi_components_local: Vec<Array2<f64>>,
pub x_psi_psi_local: Array2<f64>,
pub s_psi_psi_components_local: Vec<Array2<f64>>,
pub aniso_group_id: Option<usize>,
pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
pub aniso_cross_penalty_provider: Option<
std::sync::Arc<
dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
>,
>,
pub implicit_operator: Option<std::sync::Arc<crate::terms::basis::ImplicitDesignPsiDerivative>>,
pub implicit_axis: usize,
}
/// Flat vector of spatial hyper-coordinates, partitioned by term.
///
/// Term `t` owns `dims_per_term[t]` consecutive entries of `values`, starting at
/// `dims_per_term[..t].sum()`. What an entry means depends on the term kind:
/// - generic spatial terms: ψ = -ln ℓ, one entry per axis when anisotropic;
/// - constant-curvature terms: κ;
/// - measure-jet terms: the measure-jet dials.
#[derive(Debug, Clone)]
pub(crate) struct SpatialLogKappaCoords {
values: Array1<f64>,
dims_per_term: Vec<usize>,
}
/// Selects which end of the per-coordinate box `aniso_bounds_from_data` produces.
#[derive(Clone, Copy)]
enum AnisoBoundEnd {
Lower,
Upper,
}
impl SpatialLogKappaCoords {
    /// Wraps a flat coordinate vector together with the width owned by each term.
    ///
    /// # Panics
    /// If `values.len()` differs from the sum of `dims_per_term`.
    pub(crate) fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
        assert_eq!(
            values.len(),
            dims_per_term.iter().sum::<usize>(),
            "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
            values.len(),
            dims_per_term.iter().sum::<usize>(),
        );
        Self {
            values,
            dims_per_term,
        }
    }

    /// Isotropic seed with one coordinate per term.
    ///
    /// Constant-curvature terms are seeded with their current κ. Every other term gets
    /// ψ = -ln ℓ, where ℓ is computed as follows:
    /// - take the stored length scale, or `options.min_length_scale` when none is set;
    /// - clamp it into `[min_length_scale, max_length_scale]`.
    pub(crate) fn from_length_scales(
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let mut out = Array1::<f64>::zeros(term_indices.len());
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
                out[slot] = cc.kappa;
                continue;
            }
            let length_scale = get_spatial_length_scale(spec, term_idx)
                .unwrap_or(options.min_length_scale)
                .clamp(options.min_length_scale, options.max_length_scale);
            out[slot] = -length_scale.ln();
        }
        Self {
            values: out,
            dims_per_term: vec![1; term_indices.len()],
        }
    }

    /// Anisotropic seed with a variable number of coordinates per term.
    ///
    /// Each term contributes:
    /// - measure-jet terms: their dial seed (`measure_jet_psi_seed`);
    /// - constant-curvature terms: their κ (one coordinate);
    /// - per-axis ψ terms: `ψ̄ + η_a` for each axis, where ψ̄ = -ln ℓ and η is the
    ///   centred `aniso_log_scales`;
    /// - any other term: ψ̄ alone.
    ///
    /// The width recorded for every term equals the number of values pushed for it.
    /// That keeps the `values`/`dims_per_term` invariant intact even when a term's
    /// `aniso_log_scales` length disagrees with its feature dimension; such a
    /// mismatch is logged as a warning.
    pub(crate) fn from_length_scales_aniso(
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let mut vals = Vec::new();
        let mut dims = Vec::new();
        for &term_idx in term_indices {
            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
                let seed = measure_jet_psi_seed(mj);
                dims.push(seed.len());
                vals.extend(seed);
                continue;
            }
            if let Some(cc) = constant_curvature_term_spec(spec, term_idx) {
                vals.push(cc.kappa);
                dims.push(1);
                continue;
            }
            let length_scale = get_spatial_length_scale(spec, term_idx)
                .unwrap_or(options.min_length_scale)
                .clamp(options.min_length_scale, options.max_length_scale);
            let psi_bar = -length_scale.ln();
            if spatial_term_uses_per_axis_psi(spec, term_idx) {
                let feature_dim = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
                let eta_raw = get_spatial_aniso_log_scales(spec, term_idx)
                    .expect("predicate guarantees aniso_log_scales is Some");
                let eta = center_aniso_log_scales(&eta_raw);
                if eta.len() != feature_dim {
                    log::warn!(
                        "[spatial-kappa] term {term_idx}: aniso_log_scales has {} entries but the term has {feature_dim} feature columns; using {} per-axis ψ coordinates",
                        eta.len(),
                        eta.len()
                    );
                }
                // Record exactly as many coordinates as are pushed, so later slicing
                // by `dims_per_term` stays aligned.
                dims.push(eta.len());
                vals.extend(eta.iter().map(|&eta_a| psi_bar + eta_a));
            } else {
                vals.push(psi_bar);
                dims.push(1);
            }
        }
        Self {
            values: Array1::from_vec(vals),
            dims_per_term: dims,
        }
    }

    /// Lower ends of the data-derived isotropic bounds (one coordinate per term).
    pub(crate) fn lower_bounds_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let mut values = Array1::<f64>::zeros(term_indices.len());
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
        }
        Self {
            values,
            dims_per_term: vec![1; term_indices.len()],
        }
    }

    /// Upper ends of the data-derived isotropic bounds (one coordinate per term).
    pub(crate) fn upper_bounds_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let mut values = Array1::<f64>::zeros(term_indices.len());
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
        }
        Self {
            values,
            dims_per_term: vec![1; term_indices.len()],
        }
    }

    /// Lower ends of the anisotropic bounds, laid out like `dims_per_term`.
    pub(crate) fn lower_bounds_aniso_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        dims_per_term: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        Self::aniso_bounds_from_data(
            data,
            spec,
            term_indices,
            dims_per_term,
            options,
            AnisoBoundEnd::Lower,
        )
    }

    /// Upper ends of the anisotropic bounds, laid out like `dims_per_term`.
    pub(crate) fn upper_bounds_aniso_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        dims_per_term: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        Self::aniso_bounds_from_data(
            data,
            spec,
            term_indices,
            dims_per_term,
            options,
            AnisoBoundEnd::Upper,
        )
    }

    /// Shared implementation of the anisotropic bound builders.
    ///
    /// Bounds are filled per term kind:
    /// - measure-jet terms: their fixed dial bounds, truncated to the term's width;
    /// - constant-curvature terms: the κ window, in the first coordinate only;
    /// - generic terms: the isotropic ψ bound shifted by the centred per-axis η. The
    ///   shift is zero when `d <= 1` or when η's length differs from `d`.
    ///
    /// Any coordinate not written stays 0.
    fn aniso_bounds_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        dims_per_term: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
        end: AnisoBoundEnd,
    ) -> Self {
        assert_eq!(term_indices.len(), dims_per_term.len());
        let total: usize = dims_per_term.iter().sum();
        let mut values = Array1::<f64>::zeros(total);
        let mut cursor = 0;
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            let d = dims_per_term[slot];
            if let Some(mj) = measure_jet_term_spec(spec, term_idx) {
                let bounds = measure_jet_psi_bound_values(mj, matches!(end, AnisoBoundEnd::Upper));
                for (offset, bound) in bounds.into_iter().enumerate().take(d) {
                    values[cursor + offset] = bound;
                }
                cursor += d;
                continue;
            }
            if constant_curvature_term_spec(spec, term_idx).is_some() {
                let (lo, hi) = constant_curvature_kappa_bounds(data, spec, term_idx);
                if d >= 1 {
                    values[cursor] = match end {
                        AnisoBoundEnd::Lower => lo,
                        AnisoBoundEnd::Upper => hi,
                    };
                }
                cursor += d;
                continue;
            }
            let psi_bound = {
                let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
                match end {
                    AnisoBoundEnd::Lower => lo,
                    AnisoBoundEnd::Upper => hi,
                }
            };
            let axis_offsets = if d <= 1 {
                vec![0.0; d]
            } else {
                get_spatial_aniso_log_scales(spec, term_idx)
                    .filter(|eta| eta.len() == d)
                    .map(|eta| center_aniso_log_scales(&eta))
                    .unwrap_or_else(|| vec![0.0; d])
            };
            for offset in 0..d {
                values[cursor + offset] = psi_bound + axis_offsets[offset];
            }
            cursor += d;
        }
        Self {
            values,
            dims_per_term: dims_per_term.to_vec(),
        }
    }

    /// Re-centres generic spatial terms on a data-derived ψ seed.
    ///
    /// Each such term's block is shifted so its mean equals the seed, which preserves
    /// the per-axis offsets. The following are left untouched:
    /// - measure-jet terms;
    /// - constant-curvature terms;
    /// - terms for which `spatial_term_psi_seed` returns `None` (a user-fixed length scale).
    pub(crate) fn reseed_from_data(
        mut self,
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        assert_eq!(term_indices.len(), self.dims_per_term.len());
        let mut cursor = 0;
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            let d = self.dims_per_term[slot];
            if measure_jet_term_spec(spec, term_idx).is_some() {
                cursor += d;
                continue;
            }
            if constant_curvature_term_spec(spec, term_idx).is_some() {
                cursor += d;
                continue;
            }
            let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
                cursor += d;
                continue;
            };
            if d == 0 {
                continue;
            }
            let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
            let psi_bar_old = current.iter().sum::<f64>() / d as f64;
            for (offset, &old_value) in current.iter().enumerate() {
                self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
            }
            cursor += d;
        }
        self
    }

    /// Projects each coordinate into `[lower, upper]`.
    ///
    /// Coordinates whose bounds are not both finite are skipped. The number of
    /// projected coordinates is logged.
    ///
    /// # Panics
    /// If the three vectors differ in length.
    pub(crate) fn clamp_to_bounds(
        mut self,
        lower: &SpatialLogKappaCoords,
        upper: &SpatialLogKappaCoords,
    ) -> Self {
        assert_eq!(self.values.len(), lower.values.len());
        assert_eq!(self.values.len(), upper.values.len());
        let mut n_projected = 0usize;
        let mut worst_delta = 0.0_f64;
        for idx in 0..self.values.len() {
            let lo = lower.values[idx];
            let hi = upper.values[idx];
            if !(lo.is_finite() && hi.is_finite()) {
                continue;
            }
            let v = self.values[idx];
            if v < lo {
                worst_delta = worst_delta.max(lo - v);
                self.values[idx] = lo;
                n_projected += 1;
            } else if v > hi {
                worst_delta = worst_delta.max(v - hi);
                self.values[idx] = hi;
                n_projected += 1;
            }
        }
        if n_projected > 0 {
            log::info!(
                "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
                 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
                 [{KERNEL_RANGE_MIN_DIAMETER_FRACTION}/r_max, {KERNEL_RANGE_MAX_SPACING_MULTIPLE}/r_min] geometry window",
                self.values.len()
            );
        }
        self
    }

    /// Takes `sum(dims_per_term)` entries of `theta` starting at `start`.
    ///
    /// The entries are copied element-wise, so the stored vector is always contiguous
    /// and `term_slice` can borrow it as a plain slice. A plain `to_owned()` of a view
    /// may keep a non-standard (e.g. reversed) layout.
    ///
    /// # Panics
    /// If `start + sum(dims_per_term)` exceeds `theta.len()`.
    pub(crate) fn from_theta_tail_with_dims(
        theta: &Array1<f64>,
        start: usize,
        dims_per_term: Vec<usize>,
    ) -> Self {
        let total: usize = dims_per_term.iter().sum();
        let values: Array1<f64> = theta
            .slice(s![start..start + total])
            .iter()
            .copied()
            .collect();
        Self {
            values,
            dims_per_term,
        }
    }

    /// Total number of coordinates across all terms.
    pub(crate) fn len(&self) -> usize {
        self.values.len()
    }

    /// Number of coordinates owned by each term.
    pub(crate) fn dims_per_term(&self) -> &[usize] {
        &self.dims_per_term
    }

    /// Flat index of the first coordinate owned by slot `term_idx`.
    fn term_offset(&self, term_idx: usize) -> usize {
        self.dims_per_term[..term_idx].iter().sum()
    }

    /// Coordinates owned by slot `term_idx`.
    fn term_slice(&self, term_idx: usize) -> &[f64] {
        let offset = self.term_offset(term_idx);
        let d = self.dims_per_term[term_idx];
        let all = self
            .values
            .as_slice()
            .expect("SpatialLogKappaCoords values must be stored contiguously");
        &all[offset..offset + d]
    }

    /// Borrows the flat coordinate vector.
    pub(crate) fn as_array(&self) -> &Array1<f64> {
        &self.values
    }

    /// Splits into the first `mid` terms and the rest.
    pub(crate) fn split_at(&self, mid: usize) -> (Self, Self) {
        let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
        (
            Self {
                values: self.values.slice(s![0..flat_mid]).to_owned(),
                dims_per_term: self.dims_per_term[..mid].to_vec(),
            },
            Self {
                values: self.values.slice(s![flat_mid..]).to_owned(),
                dims_per_term: self.dims_per_term[mid..].to_vec(),
            },
        )
    }

    /// Writes the coordinates back into a clone of `spec`.
    ///
    /// Measure-jet dials and constant-curvature κ are written through their
    /// dedicated setters. For generic spatial terms, ψ is converted to a length scale
    /// plus an optional anisotropy vector. Generic terms that own zero coordinates are
    /// skipped: otherwise the conversion would treat the empty slice as ψ = 0 and
    /// overwrite the term's length scale with 1.
    ///
    /// # Errors
    /// Returns an error if `term_indices` and `dims_per_term` differ in length, or if
    /// any setter fails.
    pub(crate) fn apply_tospec(
        &self,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
    ) -> Result<TermCollectionSpec, EstimationError> {
        if term_indices.len() != self.dims_per_term.len() {
            crate::bail_invalid_estim!(
                "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
                 term_indices={} dims_per_term={}",
                term_indices.len(),
                self.dims_per_term.len()
            );
        }
        let mut updated = spec.clone();
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            let psi = self.term_slice(slot);
            if measure_jet_term_spec(&updated, term_idx).is_some() {
                set_measure_jet_psi_dials(&mut updated, term_idx, psi)?;
                continue;
            }
            if constant_curvature_term_spec(&updated, term_idx).is_some() {
                set_constant_curvature_kappa(&mut updated, term_idx, psi)?;
                continue;
            }
            if psi.is_empty() {
                // Nothing was optimised for this term; keep its current length scale.
                continue;
            }
            let (next_length_scale, next_aniso) = spatial_term_psi_to_length_scale_and_aniso(psi);
            if let Some(length_scale) = next_length_scale {
                set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
            }
            if let Some(eta) = next_aniso {
                set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
            }
        }
        Ok(updated)
    }
}
/// Subtracts the mean from `eta` so the anisotropy offsets sum to zero.
///
/// Entries within 1e-15 of zero after centring are snapped to exactly 0.0. An empty
/// input yields an empty vector. A single-axis input centres to `[0.0]`, matching
/// `aniso_bounds_from_data`, which uses a zero offset whenever `d <= 1`. Previously a
/// lone value was returned uncentred, and it leaked into the seeded ψ and then into
/// the length scale on write-back.
pub(crate) fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    if eta.is_empty() {
        return Vec::new();
    }
    let mean = eta.iter().sum::<f64>() / eta.len() as f64;
    eta.iter()
        .map(|&v| {
            let centered = v - mean;
            if centered.abs() <= 1e-15 {
                0.0
            } else {
                centered
            }
        })
        .collect()
}
/// Whether the term at `term_idx` takes part in spatial ψ/κ hyper-optimization.
///
/// The answer depends on the basis:
/// - thin-plate: never;
/// - Matérn and constant-curvature: always;
/// - measure-jet: when `measure_jet_enrolls_psi` says so;
/// - anything else (including a missing term): when it has a spatial length scale.
fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
    match spec.smooth_terms.get(term_idx).map(|term| &term.basis) {
        Some(SmoothBasisSpec::ThinPlate { .. }) => false,
        Some(SmoothBasisSpec::Matern { .. }) => true,
        Some(SmoothBasisSpec::MeasureJet { spec: mj, .. }) => measure_jet_enrolls_psi(mj),
        Some(SmoothBasisSpec::ConstantCurvature { .. }) => true,
        _ => get_spatial_length_scale(spec, term_idx).is_some(),
    }
}
/// Returns the measure-jet basis spec of term `term_idx`, if that term is a measure-jet smooth.
fn measure_jet_term_spec(
    spec: &TermCollectionSpec,
    term_idx: usize,
) -> Option<&crate::basis::MeasureJetBasisSpec> {
    let term = spec.smooth_terms.get(term_idx)?;
    if let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &term.basis {
        Some(mj)
    } else {
        None
    }
}
/// Whether a measure-jet term exposes any ψ dials.
///
/// It does when it learns its length scale, or when it is in multiscale mode with a
/// positive `tau0`.
fn measure_jet_enrolls_psi(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
    if measure_jet_learns_length_scale(mj) {
        return true;
    }
    mj.tau0 > 0.0 && crate::basis::measure_jet_multiscale_mode(mj)
}
fn measure_jet_learns_length_scale(mj: &crate::basis::MeasureJetBasisSpec) -> bool {
mj.learn_length_scale
}
// Box bounds `(lower, upper)` for the measure-jet dials (see `measure_jet_psi_bound_values`).
// `alpha` is optimised on its raw scale.
const MEASURE_JET_PSI_ALPHA_BOUNDS: (f64, f64) = (-1.0, 3.0);
// ln τ bounds: (ln 1e-8, ln 1e2).
const MEASURE_JET_PSI_LN_TAU_BOUNDS: (f64, f64) = (-18.420680743952367, 4.605170185988092);
// ln ℓ bounds: (ln 1e-3, ln 1e2). Note the measure-jet coordinate is +ln ℓ, unlike the
// ψ = -ln ℓ convention used for other spatial terms.
const MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS: (f64, f64) = (-6.907755278982137, 4.605170185988092);
/// Number of penalty dials (`alpha`, `ln τ`) a measure-jet term exposes: 2 in multiscale mode, else 0.
fn measure_jet_penalty_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
    2 * usize::from(crate::basis::measure_jet_multiscale_mode(mj))
}
/// Total number of ψ dials of a measure-jet term: optional `ln ℓ` plus the penalty dials.
fn measure_jet_psi_dim(mj: &crate::basis::MeasureJetBasisSpec) -> usize {
    let length_scale_dials = if measure_jet_learns_length_scale(mj) { 1 } else { 0 };
    length_scale_dials + measure_jet_penalty_psi_dim(mj)
}
/// Current measure-jet dial values, in order `[ln ℓ?, alpha?, ln τ?]`.
///
/// A non-positive (or NaN) stored length scale seeds `ln ℓ` at `ln 1 = 0`. `tau0` is
/// floored at `f64::MIN_POSITIVE` before taking the log.
fn measure_jet_psi_seed(mj: &crate::basis::MeasureJetBasisSpec) -> Vec<f64> {
    let mut seed = Vec::with_capacity(measure_jet_psi_dim(mj));
    if measure_jet_learns_length_scale(mj) {
        let ell = Some(mj.length_scale)
            .filter(|&ell| ell > 0.0)
            .unwrap_or(1.0);
        seed.push(ell.ln());
    }
    if measure_jet_penalty_psi_dim(mj) > 0 {
        seed.push(mj.alpha);
        seed.push(mj.tau0.max(f64::MIN_POSITIVE).ln());
    }
    seed
}
/// Lower (`upper == false`) or upper bounds for each measure-jet dial.
///
/// The bounds follow the same order as `measure_jet_psi_seed`.
fn measure_jet_psi_bound_values(mj: &crate::basis::MeasureJetBasisSpec, upper: bool) -> Vec<f64> {
    let mut ranges: Vec<(f64, f64)> = Vec::with_capacity(measure_jet_psi_dim(mj));
    if measure_jet_learns_length_scale(mj) {
        ranges.push(MEASURE_JET_PSI_LN_LENGTH_SCALE_BOUNDS);
    }
    if measure_jet_penalty_psi_dim(mj) > 0 {
        ranges.push(MEASURE_JET_PSI_ALPHA_BOUNDS);
        ranges.push(MEASURE_JET_PSI_LN_TAU_BOUNDS);
    }
    ranges
        .into_iter()
        .map(|(lo, hi)| if upper { hi } else { lo })
        .collect()
}
/// Writes ψ dials `[ln ℓ?, alpha?, ln τ?]` back into a measure-jet spec.
///
/// Returns `Ok(true)` if any stored value changed.
///
/// # Errors
/// - the length of `psi` does not match `measure_jet_psi_dim`;
/// - the decoded ℓ is non-finite or non-positive;
/// - the decoded `alpha`/`τ` are non-finite, or `τ` is non-positive.
fn apply_measure_jet_psi(
    mj: &mut crate::basis::MeasureJetBasisSpec,
    psi: &[f64],
) -> Result<bool, EstimationError> {
    let expected_dials = measure_jet_psi_dim(mj);
    if psi.len() != expected_dials {
        crate::bail_invalid_estim!(
            "measure-jet ψ write-back dimension mismatch: got {} values for a {}-dial term",
            psi.len(),
            expected_dials
        );
    }
    let mut remaining = psi;
    let mut changed = false;
    if measure_jet_learns_length_scale(mj) {
        let next_ell = remaining[0].exp();
        remaining = &remaining[1..];
        if !(next_ell.is_finite() && next_ell > 0.0) {
            crate::bail_invalid_estim!(
                "measure-jet ψ write-back produced a non-finite/non-positive length_scale (ℓ={next_ell})"
            );
        }
        if next_ell != mj.length_scale {
            mj.length_scale = next_ell;
            changed = true;
        }
    }
    if measure_jet_penalty_psi_dim(mj) > 0 {
        let next_alpha = remaining[0];
        let next_tau = remaining[1].exp();
        if !(next_alpha.is_finite() && next_tau.is_finite() && next_tau > 0.0) {
            crate::bail_invalid_estim!(
                "measure-jet ψ write-back produced non-finite dials (alpha={next_alpha}, tau={next_tau})"
            );
        }
        if next_alpha != mj.alpha {
            mj.alpha = next_alpha;
            changed = true;
        }
        if next_tau != mj.tau0 {
            mj.tau0 = next_tau;
            changed = true;
        }
    }
    Ok(changed)
}
fn set_measure_jet_psi_dials(
spec: &mut TermCollectionSpec,
term_idx: usize,
psi: &[f64],
) -> Result<bool, EstimationError> {
let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
crate::bail_invalid_estim!("measure-jet ψ write-back: term index {term_idx} out of range");
};
set_single_term_measure_jet_psi_dials(term, psi)
}
fn set_single_term_measure_jet_psi_dials(
term: &mut SmoothTermSpec,
psi: &[f64],
) -> Result<bool, EstimationError> {
let SmoothBasisSpec::MeasureJet { spec: mj, .. } = &mut term.basis else {
crate::bail_invalid_estim!("measure-jet ψ write-back targeted a non-measure-jet term");
};
apply_measure_jet_psi(mj, psi)
}
/// Returns the constant-curvature basis spec of term `term_idx`, if that term is one.
fn constant_curvature_term_spec(
    spec: &TermCollectionSpec,
    term_idx: usize,
) -> Option<&crate::basis::ConstantCurvatureBasisSpec> {
    let term = spec.smooth_terms.get(term_idx)?;
    if let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &term.basis {
        Some(cc)
    } else {
        None
    }
}
// The κ window is |κ| ≤ this fraction / max_r², where max_r² is the largest squared
// radius of the term's data rows (see `constant_curvature_kappa_bounds`).
const CONSTANT_CURVATURE_KAPPA_CHART_FRACTION: f64 = 0.5;
// Floor on max_r² so the κ window stays finite when every point sits at the origin.
const CONSTANT_CURVATURE_MIN_CHART_RADIUS2: f64 = 1e-8;
/// Symmetric κ window `(-h, h)` for a constant-curvature term.
///
/// `h = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r²`, where `max_r²` is computed
/// as follows:
/// - for each data row, sum the squares of the finite, in-range feature values;
/// - take the largest such sum, floored at `CONSTANT_CURVATURE_MIN_CHART_RADIUS2`.
///
/// Returns `(-1.0, 1.0)` if the term is not constant-curvature.
fn constant_curvature_kappa_bounds(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
    term_idx: usize,
) -> (f64, f64) {
    let Some(SmoothBasisSpec::ConstantCurvature { feature_cols, .. }) =
        spec.smooth_terms.get(term_idx).map(|t| &t.basis)
    else {
        return (-1.0, 1.0);
    };
    let max_r2 = data
        .outer_iter()
        .map(|row| {
            feature_cols
                .iter()
                .filter_map(|&c| row.get(c).copied())
                .filter(|v| v.is_finite())
                .fold(0.0_f64, |acc, v| acc + v * v)
        })
        .fold(CONSTANT_CURVATURE_MIN_CHART_RADIUS2, |best, r2| {
            if r2 > best {
                r2
            } else {
                best
            }
        });
    let half = CONSTANT_CURVATURE_KAPPA_CHART_FRACTION / max_r2;
    (-half, half)
}
fn set_constant_curvature_kappa(
spec: &mut TermCollectionSpec,
term_idx: usize,
psi: &[f64],
) -> Result<bool, EstimationError> {
let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
crate::bail_invalid_estim!(
"constant-curvature κ write-back: term index {term_idx} out of range"
);
};
set_single_term_constant_curvature_kappa(term, psi)
}
fn set_single_term_constant_curvature_kappa(
term: &mut SmoothTermSpec,
psi: &[f64],
) -> Result<bool, EstimationError> {
if psi.len() != 1 {
crate::bail_invalid_estim!(
"constant-curvature κ write-back expects exactly one value, got {}",
psi.len()
);
}
let next_kappa = psi[0];
if !next_kappa.is_finite() {
crate::bail_invalid_estim!(
"constant-curvature κ write-back produced a non-finite κ = {next_kappa}"
);
}
let SmoothBasisSpec::ConstantCurvature { spec: cc, .. } = &mut term.basis else {
crate::bail_invalid_estim!(
"constant-curvature κ write-back targeted a non-constant-curvature term"
);
};
if cc.kappa != next_kappa {
cc.kappa = next_kappa;
Ok(true)
} else {
Ok(false)
}
}
/// True when the term has a fixed spatial length scale and does not use per-axis ψ.
pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
    let has_length_scale = get_spatial_length_scale(spec, term_idx).is_some();
    has_length_scale && !spatial_term_uses_per_axis_psi(spec, term_idx)
}
// Data-derived κ = e^ψ = 1/ℓ window used by `spatial_term_psi_bounds`:
//   κ ≥ KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max  (so ℓ ≤ r_max / 2),
//   κ ≤ KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min   (so ℓ ≥ r_min / 100),
// where r_min / r_max are pairwise-distance bounds of the term's points.
const KERNEL_RANGE_MIN_DIAMETER_FRACTION: f64 = 2.0;
const KERNEL_RANGE_MAX_SPACING_MULTIPLE: f64 = 1e2;
/// Data-informed `(ψ_lo, ψ_hi)` box for a spatial term's log-κ coordinate.
///
/// - Constant-curvature terms delegate to `constant_curvature_kappa_bounds`.
/// - Otherwise pairwise-distance bounds are computed from user-provided centres
///   (when there are at least 2) or from the standardized term data (sampled). When
///   the term's anisotropy matches the point dimension, points are first mapped into
///   the anisotropic space.
/// - The geometry window `[ln(2/r_max), ln(100/r_min)]` is intersected with the
///   options' window `[-ln max_length_scale, -ln min_length_scale]`.
/// - The options' window is returned if distances are unavailable or the
///   intersection is empty.
fn spatial_term_psi_bounds(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
    term_idx: usize,
    options: &SpatialLengthScaleOptimizationOptions,
) -> (f64, f64) {
    let fallback = (
        -options.max_length_scale.ln(),
        -options.min_length_scale.ln(),
    );
    if constant_curvature_term_spec(spec, term_idx).is_some() {
        return constant_curvature_kappa_bounds(data, spec, term_idx);
    }
    let Some(term) = spec.smooth_terms.get(term_idx) else {
        return fallback;
    };
    let aniso = get_spatial_aniso_log_scales(spec, term_idx);
    // Anisotropy applies only when its length matches the point dimension.
    let matching_eta = |ncols: usize| aniso.as_deref().filter(|eta| eta.len() == ncols);
    let r_bounds = match spatial_term_center_strategy(term) {
        Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
            match matching_eta(centers.ncols()) {
                Some(eta) => {
                    pairwise_distance_bounds(points_in_aniso_y_space(centers.view(), eta).view())
                }
                None => pairwise_distance_bounds(centers.view()),
            }
        }
        _ => standardized_spatial_term_data(data, term)
            .ok()
            .and_then(|x| match matching_eta(x.ncols()) {
                Some(eta) => pairwise_distance_bounds_sampled(
                    points_in_aniso_y_space(x.view(), eta).view(),
                ),
                None => pairwise_distance_bounds_sampled(x.view()),
            }),
    };
    let Some((r_min, r_max)) = r_bounds else {
        return fallback;
    };
    let psi_lo = (KERNEL_RANGE_MIN_DIAMETER_FRACTION / r_max).ln().max(fallback.0);
    let psi_hi = (KERNEL_RANGE_MAX_SPACING_MULTIPLE / r_min).ln().min(fallback.1);
    if psi_lo >= psi_hi {
        fallback
    } else {
        (psi_lo, psi_hi)
    }
}
/// Midpoint of the data-informed ψ box.
///
/// Returns `None` when the term already has a spatial length scale, which is then
/// respected as-is.
fn spatial_term_psi_seed(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
    term_idx: usize,
    options: &SpatialLengthScaleOptimizationOptions,
) -> Option<f64> {
    get_spatial_length_scale(spec, term_idx).is_none().then(|| {
        let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
        0.5 * (psi_lo + psi_hi)
    })
}
/// Splits ψ coordinates into a global length scale and optional per-axis offsets.
///
/// - Zero or one coordinate: `(Some(exp(-ψ)), None)`, where an empty slice counts as ψ = 0.
/// - Several coordinates: `ℓ = exp(-mean ψ)`, and the offsets are `ψ_a - mean ψ`.
fn spatial_term_psi_to_length_scale_and_aniso(psi: &[f64]) -> (Option<f64>, Option<Vec<f64>>) {
    let [_, _, ..] = psi else {
        let psi_iso = psi.first().copied().unwrap_or(0.0);
        return (Some((-psi_iso).exp()), None);
    };
    let psi_bar = psi.iter().sum::<f64>() / psi.len() as f64;
    let offsets = psi.iter().map(|&value| value - psi_bar).collect();
    (Some((-psi_bar).exp()), Some(offsets))
}
/// Returns a copy of the term's `aniso_log_scales`.
///
/// Only Matérn and Duchon bases carry them; every other term, or a missing term,
/// yields `None`.
pub fn get_spatial_aniso_log_scales(
    spec: &TermCollectionSpec,
    term_idx: usize,
) -> Option<Vec<f64>> {
    let term = spec.smooth_terms.get(term_idx)?;
    let scales = match &term.basis {
        SmoothBasisSpec::Matern { spec: matern, .. } => matern.aniso_log_scales.as_ref(),
        SmoothBasisSpec::Duchon { spec: duchon, .. } => duchon.aniso_log_scales.as_ref(),
        _ => None,
    };
    scales.cloned()
}
/// Number of covariates (feature columns) used by spatial smooth `term_idx`.
///
/// Returns `None` for non-spatial bases or an out-of-range index.
fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
    let term = spec.smooth_terms.get(term_idx)?;
    match &term.basis {
        SmoothBasisSpec::ThinPlate { feature_cols, .. }
        | SmoothBasisSpec::Matern { feature_cols, .. }
        | SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
        _ => None,
    }
}
/// Emits one `log::info!` record per Matérn/Duchon term that has non-empty
/// anisotropy log-scales.
///
/// Each record has a header line, followed by one line per axis. When the term
/// has a global length scale `ls`, each axis line also reports
/// `length = ls * exp(-eta)` and `kappa = exp(eta) / ls`. For a Duchon term
/// without a length scale, only `eta` is reported.
pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
    for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
        let (aniso, length_scale) = match &term.basis {
            SmoothBasisSpec::Matern { spec, .. } => {
                (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
            }
            SmoothBasisSpec::Duchon { spec, .. } => {
                (spec.aniso_log_scales.as_ref(), spec.length_scale)
            }
            _ => continue,
        };
        let Some(eta) = aniso.filter(|values| !values.is_empty()) else {
            continue;
        };

        let header = if let Some(ls) = length_scale {
            format!(
                "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
                term_idx, term.name, ls
            )
        } else {
            format!(
                "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
                term_idx, term.name
            )
        };

        let axis_lines: String = eta
            .iter()
            .enumerate()
            .map(|(a, &eta_a)| match length_scale {
                Some(ls) => {
                    let length_a = ls * (-eta_a).exp();
                    let kappa_a = (1.0 / ls) * eta_a.exp();
                    format!(
                        "\n axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
                        a, eta_a, length_a, kappa_a
                    )
                }
                None => format!("\n axis {}: eta={:+.4}", a, eta_a),
            })
            .collect();

        log::info!("{}{}", header, axis_lines);
    }
}
fn set_spatial_aniso_log_scales(
spec: &mut TermCollectionSpec,
term_idx: usize,
eta: Vec<f64>,
) -> Result<(), EstimationError> {
let eta = center_aniso_log_scales(&eta);
let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
crate::bail_invalid_estim!("spatial aniso_log_scales term index {term_idx} out of range");
};
match &mut term.basis {
SmoothBasisSpec::Matern { spec, .. } => {
spec.aniso_log_scales = Some(eta);
Ok(())
}
SmoothBasisSpec::Duchon { spec, .. } => {
spec.aniso_log_scales = Some(eta);
Ok(())
}
_ => Err(EstimationError::InvalidInput(format!(
"term '{}' does not support aniso_log_scales",
term.name
))),
}
}
/// Copies anisotropy contrasts from built Matérn/Duchon metadata back into
/// `spec`.
///
/// The `i`-th design term is written to the `i`-th smooth term of `spec`. Only
/// vectors with more than one entry are copied. Write failures (for example, an
/// index mismatch) are deliberately ignored: this is a best-effort sync.
pub(crate) fn sync_aniso_contrasts_from_metadata(
    spec: &mut TermCollectionSpec,
    design: &SmoothDesign,
) {
    let contrasts = design
        .terms
        .iter()
        .enumerate()
        .filter_map(|(term_idx, term)| {
            let eta = match &term.metadata {
                BasisMetadata::Matern {
                    aniso_log_scales, ..
                }
                | BasisMetadata::Duchon {
                    aniso_log_scales, ..
                } => aniso_log_scales.as_ref()?,
                _ => return None,
            };
            (eta.len() > 1).then(|| (term_idx, eta.clone()))
        });
    for (term_idx, eta) in contrasts {
        let _ = set_spatial_aniso_log_scales(spec, term_idx, eta);
    }
}
/// Tuning knobs for spatial length-scale optimization.
///
/// `validate` requires `min_length_scale`, `max_length_scale`, `rel_tol` and
/// `log_step` to be finite and positive, and `min_length_scale <
/// max_length_scale`.
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// Presumably switches length-scale optimization on/off — TODO confirm at call sites.
    pub enabled: bool,
    /// Presumably the cap on outer optimizer iterations (not checked by `validate`).
    pub max_outer_iter: usize,
    /// Relative convergence tolerance; must be finite and > 0.
    pub rel_tol: f64,
    /// Step size in log space; must be finite and > 0 (default `ln 2`).
    pub log_step: f64,
    /// Lower length-scale limit; must be finite, > 0 and < `max_length_scale`.
    pub min_length_scale: f64,
    /// Upper length-scale limit; must be finite and > 0.
    pub max_length_scale: f64,
    /// NOTE(review): presumably the row count above which a subsampled pilot
    /// fit is used — confirm at call sites (not checked by `validate`).
    pub pilot_subsample_threshold: usize,
}
impl Default for SpatialLengthScaleOptimizationOptions {
    /// Optimization on, 80 outer iterations, `rel_tol = 1e-4`, log-step `ln 2`,
    /// length scales searched in `[1e-3, 1e3]`, pilot threshold 10 000 rows.
    fn default() -> Self {
        let (min_length_scale, max_length_scale) = (1.0e-3, 1.0e3);
        Self {
            min_length_scale,
            max_length_scale,
            log_step: std::f64::consts::LN_2,
            rel_tol: 1.0e-4,
            max_outer_iter: 80,
            pilot_subsample_threshold: 10_000,
            enabled: true,
        }
    }
}
impl SpatialLengthScaleOptimizationOptions {
    /// Checks that the numeric options are usable.
    ///
    /// # Errors
    /// Returns the message of a `SmoothError::invalid_config`, converted to
    /// `String`, for the first failing check. Checks run in this order:
    /// 1. `min_length_scale` is finite and > 0.
    /// 2. `max_length_scale` is finite and > 0.
    /// 3. `min_length_scale < max_length_scale`.
    /// 4. `rel_tol` is finite and > 0.
    /// 5. `log_step` is finite and > 0.
    ///
    /// `enabled`, `max_outer_iter` and `pilot_subsample_threshold` are not
    /// inspected.
    pub fn validate(&self) -> Result<(), String> {
        let require_positive_finite = |field: &str, value: f64| -> Result<(), String> {
            if value.is_finite() && value > 0.0 {
                return Ok(());
            }
            Err(SmoothError::invalid_config(format!(
                "SpatialLengthScaleOptimizationOptions::{field} must be > 0 and finite, got {value}"
            ))
            .into())
        };

        require_positive_finite("min_length_scale", self.min_length_scale)?;
        require_positive_finite("max_length_scale", self.max_length_scale)?;
        if self.min_length_scale >= self.max_length_scale {
            return Err(SmoothError::invalid_config(format!(
                "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
                self.min_length_scale, self.max_length_scale
            ))
            .into());
        }
        require_positive_finite("rel_tol", self.rel_tol)?;
        require_positive_finite("log_step", self.log_step)
    }
}
/// Per-term grouping information for a random-effect term.
///
/// NOTE(review): field meanings below are inferred from names; confirm against
/// the constructor.
#[derive(Debug, Clone)]
struct RandomEffectBlock {
    /// Term name.
    name: String,
    /// Presumably the group index of each row; `None` for rows not assigned to a kept level.
    group_ids: Vec<Option<usize>>,
    /// Presumably the number of distinct (kept) groups, i.e. the column count.
    num_groups: usize,
    /// Presumably identifiers of the retained factor levels.
    kept_levels: Vec<u64>,
}
/// Entries with absolute value at or below this are treated as structural
/// zeros when counting and emitting sparse design entries.
const BLOCK_SPARSE_ZERO_EPS: f64 = 1e-12;
/// When no block is intrinsically sparse, a sparse design is only built if the
/// non-zero count does not exceed this fraction of all `nrows * ncols` cells.
const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.20;
/// True when any block is a sparse or random-effect block.
///
/// In that case, sparse assembly is preferred regardless of overall density.
fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
    for block in blocks {
        if let DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_) = block {
            return true;
        }
    }
    false
}
/// Number of non-zero entries a block would contribute to a sparse design.
///
/// Counts by block kind:
/// * Intercept: one entry per row.
/// * Random effect: one entry per row that has a group.
/// * Sparse: the length of its stored value array.
/// * Dense: entries with magnitude above `BLOCK_SPARSE_ZERO_EPS`.
///
/// Returns `None` for dense blocks that are not materialized.
fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
    let nnz = match block {
        DesignBlock::Intercept(rows) => *rows,
        DesignBlock::RandomEffect(op) => op.group_ids.iter().flatten().count(),
        DesignBlock::Sparse(sparse) => sparse.val().len(),
        DesignBlock::Dense(dense) => {
            let matrix = dense.as_dense_ref()?;
            matrix
                .iter()
                .filter(|value| value.abs() > BLOCK_SPARSE_ZERO_EPS)
                .count()
        }
    };
    Some(nnz)
}
/// Attempts to assemble the horizontally concatenated `blocks` into a single
/// sparse (column-compressed) design matrix.
///
/// Returns `Ok(None)`, meaning the caller should fall back to a non-sparse
/// representation, when any of these hold:
/// * there are no blocks;
/// * there are zero rows or zero columns, or 32 or fewer total columns;
/// * some block's non-zero count cannot be determined
///   (`sparse_compatible_block_nnz` returns `None`);
/// * no block is intrinsically sparse and the total non-zero count exceeds
///   `BLOCK_SPARSE_MAX_DENSITY` of all cells.
///
/// # Errors
/// * `BasisError::InvalidInput` if a dense block is not materialized in the
///   fill pass.
/// * `BasisError::SparseCreation` if triplet assembly fails.
fn try_build_sparse_design_from_blocks(
    blocks: &[DesignBlock],
) -> Result<Option<DesignMatrix>, BasisError> {
    if blocks.is_empty() {
        return Ok(None);
    }
    // NOTE(review): the row count is taken from the first block; equal row
    // counts across blocks are assumed, not checked here.
    let nrows = blocks[0].nrows();
    let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
    // Narrow designs are not worth converting.
    if nrows == 0 || ncols == 0 || ncols <= 32 {
        return Ok(None);
    }
    // Intrinsically sparse blocks lift the density cap entirely.
    let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
    let sparse_nnz_limit = if preserve_sparse_storage {
        usize::MAX
    } else {
        let total_cells = nrows.saturating_mul(ncols);
        ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
    };
    // Pass 1: count non-zeros, bailing out early once the cap is exceeded.
    let mut nnz = 0usize;
    for block in blocks {
        let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
            block_nnz
        } else {
            return Ok(None);
        };
        nnz = nnz.saturating_add(block_nnz);
        if nnz > sparse_nnz_limit {
            return Ok(None);
        }
    }
    // Pass 2: emit (row, global column, value) triplets; `col_offset` is the
    // first global column of the current block.
    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
    let mut col_offset = 0usize;
    for block in blocks {
        match block {
            DesignBlock::Intercept(n) => {
                // Ones in the block's first column for rows 0..n.
                for row in 0..*n {
                    triplets.push(Triplet::new(row, col_offset, 1.0));
                }
            }
            DesignBlock::RandomEffect(op) => {
                // One-hot indicator per grouped row; ungrouped rows stay empty.
                for (row, group_id) in op.group_ids.iter().enumerate() {
                    if let Some(group) = group_id {
                        triplets.push(Triplet::new(row, col_offset + group, 1.0));
                    }
                }
            }
            DesignBlock::Sparse(sparse) => {
                // NOTE(review): walking `col_ptr[col]..col_ptr[col + 1]` assumes
                // compressed storage (no per-column nnz array) — confirm for all
                // producers of sparse blocks.
                let (symbolic, values) = sparse.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                for col in 0..sparse.ncols() {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let value = values[idx];
                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
                            triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
                        }
                    }
                }
            }
            DesignBlock::Dense(dense) => {
                let matrix = dense.as_dense_ref().ok_or_else(|| {
                    BasisError::InvalidInput(
                        "sparse-compatible block assembly requires materialized dense blocks"
                            .to_string(),
                    )
                })?;
                for row in 0..matrix.nrows() {
                    for col in 0..matrix.ncols() {
                        let value = matrix[[row, col]];
                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
                            triplets.push(Triplet::new(row, col_offset + col, value));
                        }
                    }
                }
            }
        }
        col_offset += block.ncols();
    }
    let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
        BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
    })?;
    Ok(Some(DesignMatrix::Sparse(
        crate::matrix::SparseDesignMatrix::new(sparse),
    )))
}
/// Builds the full term-collection design from its column blocks.
///
/// Prefers a sparse matrix when `try_build_sparse_design_from_blocks` accepts
/// the blocks. Otherwise the blocks are wrapped in a `BlockDesignOperator` and
/// exposed as a dense design.
///
/// # Errors
/// * Errors from sparse assembly are propagated.
/// * Operator construction failures become `BasisError::InvalidInput`.
fn assemble_term_collection_design_matrix(
    blocks: Vec<DesignBlock>,
) -> Result<DesignMatrix, BasisError> {
    let maybe_sparse = try_build_sparse_design_from_blocks(&blocks)?;
    match maybe_sparse {
        Some(design) => Ok(design),
        None => {
            let operator = BlockDesignOperator::new(blocks).map_err(|e| {
                BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
            })?;
            let shared = Arc::new(operator);
            Ok(DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                shared,
            )))
        }
    }
}
/// Copies the listed columns of `data`, in order, into a new owned matrix.
///
/// # Errors
/// Raises a dimension error (via `bail_dim_basis!`) naming the first index in
/// `cols` that is `>= data.ncols()`.
fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
    let p = data.ncols();
    if let Some(&c) = cols.iter().find(|&&col| col >= p) {
        crate::bail_dim_basis!("feature column {c} is out of bounds for data with {p} columns");
    }
    let mut selected = Array2::<f64>::zeros((data.nrows(), cols.len()));
    for (dst_idx, source) in cols.iter().map(|&c| data.column(c)).enumerate() {
        selected.column_mut(dst_idx).assign(&source);
    }
    Ok(selected)
}
/// Short label for a non-finite float: `"NaN"`, `"+Inf"` or `"-Inf"`.
///
/// Intended for non-finite inputs. A finite value would be labelled by its
/// sign as `"+Inf"` or `"-Inf"`.
fn nonfinite_value_label(value: f64) -> &'static str {
    match (value.is_nan(), value.is_sign_positive()) {
        (true, _) => "NaN",
        (false, true) => "+Inf",
        (false, false) => "-Inf",
    }
}
/// Ensures column `feature_col` exists in `data` and contains only finite
/// values.
///
/// # Errors
/// * A dimension error if the column index is out of bounds.
/// * An invalid-input error naming the first offending row and whether it
///   holds NaN, +Inf or -Inf.
fn validate_term_feature_column_finite(
    data: ArrayView2<'_, f64>,
    term_kind: &str,
    term_name: &str,
    feature_col: usize,
) -> Result<(), BasisError> {
    let p = data.ncols();
    if feature_col >= p {
        crate::bail_dim_basis!(
            "{term_kind} term '{term_name}' feature column {feature_col} out of bounds for {p} columns"
        );
    }
    let first_bad = data
        .column(feature_col)
        .iter()
        .enumerate()
        .find(|(_, value)| !value.is_finite())
        .map(|(row, &value)| (row, value));
    if let Some((row, value)) = first_bad {
        crate::bail_invalid_basis!(
            "{term_kind} term '{term_name}' feature column {feature_col} row {row} contains non-finite value {}",
            nonfinite_value_label(value)
        );
    }
    Ok(())
}
/// Validates every feature column referenced by every smooth term.
///
/// Stops at the first term/column that fails.
fn validate_smooth_terms_finite_inputs(
    data: ArrayView2<'_, f64>,
    terms: &[SmoothTermSpec],
) -> Result<(), BasisError> {
    terms.iter().try_for_each(|term| {
        smooth_term_feature_cols(term)
            .into_iter()
            .try_for_each(|feature_col| {
                validate_term_feature_column_finite(data, "smooth", &term.name, feature_col)
            })
    })
}
/// Checks that every input column used by the term collection is finite.
///
/// Terms are checked in this order: linear, then random-effect, then smooth.
fn validate_term_collection_finite_inputs(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
) -> Result<(), BasisError> {
    spec.linear_terms.iter().try_for_each(|term| {
        validate_term_feature_column_finite(data, "linear", &term.name, term.feature_col)
    })?;
    spec.random_effect_terms.iter().try_for_each(|term| {
        validate_term_feature_column_finite(data, "random-effect", &term.name, term.feature_col)
    })?;
    validate_smooth_terms_finite_inputs(data, &spec.smooth_terms)
}
/// Key under which spatial smooth terms may share one jointly selected center
/// set.
///
/// Used as a `BTreeMap` key during joint center planning, hence the `Ord`
/// derive. Two terms share centers only if every field matches.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct JointSpatialCenterGroupKey {
    /// Data columns the term is built on.
    feature_cols: Vec<usize>,
    /// Kind of center strategy (as reported by `center_strategy_kind`).
    strategy_kind: CenterStrategyKind,
    /// Strategy parameter: k-means `max_iter`, grid `points_per_dim`, else 0.
    strategy_aux: usize,
    /// Number of centers the term asked for.
    requested_num_centers: usize,
    /// Bit patterns of the input standardization scales, if fixed.
    input_scale_bits: Option<Vec<u64>>,
}
/// Minimum number of centers a spatial smooth needs to span its polynomial
/// null space.
///
/// * Thin-plate: the polynomial basis dimension given by
///   `thin_plate_polynomial_basis_dimension`. This is the same minimum the
///   thin-plate error rewriting reports to users, so joint center planning
///   never requests fewer centers than basis construction accepts.
/// * Duchon: 1 for a zero-order null space, `d + 1` for a linear one, and
///   `duchon_nullspace_dimension(d, degree)` for an explicit degree.
/// * Matérn and all other bases: 1.
fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
    match &term.basis {
        SmoothBasisSpec::ThinPlate { feature_cols, .. } => {
            crate::basis::thin_plate_polynomial_basis_dimension(feature_cols.len())
        }
        SmoothBasisSpec::Duchon {
            feature_cols, spec, ..
        } => match spec.nullspace_order {
            crate::basis::DuchonNullspaceOrder::Zero => 1,
            crate::basis::DuchonNullspaceOrder::Linear => feature_cols.len() + 1,
            crate::basis::DuchonNullspaceOrder::Degree(degree) => {
                crate::basis::duchon_nullspace_dimension(feature_cols.len(), degree)
            }
        },
        // Matérn has no polynomial null space.
        SmoothBasisSpec::Matern { .. } => 1,
        _ => 1,
    }
}
fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
let (feature_cols, strategy, input_scales) = match &term.basis {
SmoothBasisSpec::ThinPlate {
feature_cols,
spec,
input_scales,
} => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
SmoothBasisSpec::Matern {
feature_cols,
spec,
input_scales,
} => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
SmoothBasisSpec::Duchon {
feature_cols,
spec,
input_scales,
} => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
_ => return None,
};
let strategy_kind = center_strategy_kind(strategy);
let strategy_aux = match strategy {
CenterStrategy::Auto(inner) => match inner.as_ref() {
CenterStrategy::KMeans { max_iter, .. } => *max_iter,
CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
_ => 0,
},
CenterStrategy::KMeans { max_iter, .. } => *max_iter,
CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
_ => 0,
};
Some(JointSpatialCenterGroupKey {
feature_cols: feature_cols.clone(),
strategy_kind,
strategy_aux,
requested_num_centers: center_strategy_num_centers(strategy)?,
input_scale_bits: input_scales
.map(|values| values.iter().map(|value| value.to_bits()).collect()),
})
}
/// Borrow of the center strategy of a thin-plate, Matérn or Duchon term.
///
/// Returns `None` for any other basis kind.
fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
    let strategy = match &term.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => &spec.center_strategy,
        SmoothBasisSpec::Matern { spec, .. } => &spec.center_strategy,
        SmoothBasisSpec::Duchon { spec, .. } => &spec.center_strategy,
        _ => return None,
    };
    Some(strategy)
}
/// Replaces a spatial term's center strategy with explicit `centers`.
///
/// # Errors
/// `BasisError::InvalidInput` when the term is not thin-plate, Matérn or
/// Duchon.
fn set_spatial_term_centers(
    term: &mut SmoothTermSpec,
    centers: Array2<f64>,
) -> Result<(), BasisError> {
    let slot = match &mut term.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => &mut spec.center_strategy,
        SmoothBasisSpec::Matern { spec, .. } => &mut spec.center_strategy,
        SmoothBasisSpec::Duchon { spec, .. } => &mut spec.center_strategy,
        _ => {
            return Err(BasisError::InvalidInput(format!(
                "term '{}' does not support spatial center planning",
                term.name
            )));
        }
    };
    *slot = CenterStrategy::UserProvided(centers);
    Ok(())
}
/// Extracts a spatial term's feature columns and standardizes them.
///
/// The term's stored `input_scales` are used when present. Otherwise, scales
/// are computed from the data with `compute_spatial_input_scales`, and the
/// columns are left as-is if that yields none.
///
/// # Errors
/// * The term is not a spatial smooth.
/// * A feature column is out of bounds.
fn standardized_spatial_term_data(
    data: ArrayView2<'_, f64>,
    term: &SmoothTermSpec,
) -> Result<Array2<f64>, BasisError> {
    let (feature_cols, input_scales) = match &term.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            input_scales,
            ..
        }
        | SmoothBasisSpec::Matern {
            feature_cols,
            input_scales,
            ..
        }
        | SmoothBasisSpec::Duchon {
            feature_cols,
            input_scales,
            ..
        } => (feature_cols, input_scales.as_ref()),
        _ => {
            crate::bail_invalid_basis!("term '{}' is not a spatial smooth", term.name);
        }
    };
    let mut standardized = select_columns(data, feature_cols)?;
    match input_scales {
        Some(stored) => {
            apply_input_standardization(&mut standardized, stored);
        }
        None => {
            if let Some(derived) = compute_spatial_input_scales(standardized.view()) {
                apply_input_standardization(&mut standardized, &derived);
            }
        }
    }
    Ok(standardized)
}
/// Plans shared spatial centers across term blocks, then auto-initializes
/// zero length scales.
///
/// Steps:
/// 1. Work on a clone of `term_blocks`.
/// 2. Collect spatial terms that use an auto center strategy of kind
///    EqualMass, EqualMassCovarRepresentative, FarthestPoint or KMeans, and
///    that have an explicit center count. Group them by
///    `JointSpatialCenterGroupKey`.
/// 3. For each group with at least two members, select one center set from the
///    first member's standardized data.
/// 4. Assign that set to every member as `CenterStrategy::UserProvided`.
/// 5. Run `auto_init_length_scale_in_place` on every term.
///
/// # Errors
/// Errors from standardization, strategy rewriting, center selection or
/// center assignment are propagated.
fn plan_joint_spatial_centers_for_term_blocks(
    data: ArrayView2<'_, f64>,
    term_blocks: &[Vec<SmoothTermSpec>],
) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
    let mut planned_blocks = term_blocks.to_vec();
    let n = data.nrows();
    // Group key -> (block index, term index within block) of each member.
    let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
    for (block_idx, terms) in planned_blocks.iter().enumerate() {
        for (term_idx, term) in terms.iter().enumerate() {
            let Some(strategy) = spatial_term_center_strategy(term) else {
                continue;
            };
            // Explicit (non-auto) strategies are left untouched.
            if !center_strategy_is_auto(strategy) {
                continue;
            }
            let Some(group_key) = spatial_term_group_key(term) else {
                continue;
            };
            // Only data-driven selection strategies are shared.
            if !matches!(
                group_key.strategy_kind,
                CenterStrategyKind::EqualMass
                    | CenterStrategyKind::EqualMassCovarRepresentative
                    | CenterStrategyKind::FarthestPoint
                    | CenterStrategyKind::KMeans
            ) {
                continue;
            }
            // Redundant with the `?` inside `spatial_term_group_key`, kept as a guard.
            if center_strategy_num_centers(strategy).is_none() {
                continue;
            }
            groups
                .entry(group_key)
                .or_default()
                .push((block_idx, term_idx));
        }
    }
    for (group_key, members) in groups {
        // Nothing to share for a singleton group.
        if members.len() < 2 {
            continue;
        }
        let min_required = members
            .iter()
            .map(|&(block_idx, term_idx)| {
                spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
            })
            .max()
            .unwrap_or(1);
        // Requested count, raised to the strictest member minimum, capped at
        // the number of rows (but at least 1).
        let joint_centers = group_key
            .requested_num_centers
            .max(min_required)
            .min(n.max(1));
        // The first member acts as the prototype for data and strategy.
        let (first_block_idx, first_term_idx) = members[0];
        let prototype = &planned_blocks[first_block_idx][first_term_idx];
        let standardized = standardized_spatial_term_data(data, prototype)?;
        let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
            BasisError::InvalidInput(format!(
                "term '{}' lost its spatial center strategy during joint planning",
                prototype.name
            ))
        })?;
        let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
        let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
        log::info!(
            "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
            shared_centers.nrows(),
            members.len(),
            group_key.feature_cols,
            group_key.requested_num_centers,
        );
        for (block_idx, term_idx) in members {
            set_spatial_term_centers(
                &mut planned_blocks[block_idx][term_idx],
                shared_centers.clone(),
            )?;
        }
    }
    // Fill in zero (unset) length scales for every term, grouped or not.
    for block in planned_blocks.iter_mut() {
        for term in block.iter_mut() {
            auto_init_length_scale_in_place(data, term);
        }
    }
    Ok(planned_blocks)
}
/// Heuristic starting length scale: the largest finite range among the
/// in-bounds `feature_cols`, divided by `sqrt(n)`.
///
/// The result is clamped to `[1e-6, max_range]`. Returns `1.0` in these
/// cases:
/// * no rows or no feature columns;
/// * the largest range is zero;
/// * the largest range is non-finite.
///
/// Non-finite entries and out-of-bounds column indices are ignored.
fn auto_initial_length_scale(data: ArrayView2<'_, f64>, feature_cols: &[usize]) -> f64 {
    const LENGTH_SCALE_FLOOR: f64 = 1e-6;
    let n = data.nrows();
    if n == 0 || feature_cols.is_empty() {
        return 1.0;
    }
    let max_range = feature_cols
        .iter()
        .copied()
        .filter(|&c| c < data.ncols())
        .map(|c| {
            let (lo, hi) = data
                .column(c)
                .iter()
                .copied()
                .filter(|v| v.is_finite())
                .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                    (if v < lo { v } else { lo }, if v > hi { v } else { hi })
                });
            if hi > lo { hi - lo } else { 0.0 }
        })
        .fold(0.0_f64, |best, r| if r > best { r } else { best });
    if !max_range.is_finite() || max_range <= 0.0 {
        return 1.0;
    }
    let init = max_range / (n as f64).sqrt();
    init.max(LENGTH_SCALE_FLOOR).min(max_range)
}
/// Applies `auto_init_length_scale_in_basis` to the term's basis.
fn auto_init_length_scale_in_place(data: ArrayView2<'_, f64>, term: &mut SmoothTermSpec) {
    let basis = &mut term.basis;
    auto_init_length_scale_in_basis(data, basis);
}
/// Replaces a zero length scale on Matérn or thin-plate bases with
/// `auto_initial_length_scale` computed from `data`.
///
/// Recurses into the wrapped basis of `ByVariable`, `FactorSumToZero` and
/// `BySmooth`. All other bases, and non-zero length scales, are left
/// unchanged.
fn auto_init_length_scale_in_basis(data: ArrayView2<'_, f64>, basis: &mut SmoothBasisSpec) {
    match basis {
        SmoothBasisSpec::Matern {
            feature_cols, spec, ..
        } if spec.length_scale == 0.0 => {
            spec.length_scale = auto_initial_length_scale(data, feature_cols);
        }
        SmoothBasisSpec::ThinPlate {
            feature_cols, spec, ..
        } if spec.length_scale == 0.0 => {
            spec.length_scale = auto_initial_length_scale(data, feature_cols);
        }
        SmoothBasisSpec::ByVariable { inner, .. }
        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
            auto_init_length_scale_in_basis(data, inner)
        }
        SmoothBasisSpec::BySmooth { smooth, .. } => auto_init_length_scale_in_basis(data, smooth),
        _ => {}
    }
}
/// Column centering/scaling used to condition a linear fit.
///
/// For each conditioned column `j` with `(mean_j, scale_j)`, the internal
/// design is `Z = X A`, where:
/// * `z_j = (x_j - mean_j * x_int) / scale_j`, with `x_int` the intercept
///   column;
/// * `A` equals the identity except `A[j, j] = 1 / scale_j` and
///   `A[int, j] = -mean_j / scale_j`.
///
/// The methods below map penalties, coefficients, Hessians and bounds between
/// the original and internal parameterizations using `A` and `A^{-1}`.
impl LinearFitConditioning {
    /// Computes per-column mean and population standard deviation of the
    /// selected design columns.
    ///
    /// Uses two streaming passes over row chunks sized by
    /// `row_chunk_for_byte_budget`. A column whose variance is non-finite or
    /// not above `SCALE_EPS^2` gets the identity transform `(0.0, 1.0)`.
    ///
    /// # Panics
    /// Panics if the design fails to produce a row chunk.
    fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
        const SCALE_EPS: f64 = 1e-12;
        let n = design.design.nrows();
        let p = design.design.ncols();
        let mut columns = Vec::with_capacity(selected_cols.len());
        if n == 0 || selected_cols.is_empty() {
            return Self {
                intercept_idx: design.intercept_range.start,
                columns,
            };
        }
        let chunk_rows = crate::linalg::utils::row_chunk_for_byte_budget(n, p);
        // Pass 1: column sums -> means.
        let mut sums = vec![0.0_f64; selected_cols.len()];
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let chunk = design
                .design
                .try_row_chunk(start..end)
                .expect("LinearFitConditioning::from_columns row chunk failed");
            for (k, &col_idx) in selected_cols.iter().enumerate() {
                let column = chunk.column(col_idx);
                for &v in column.iter() {
                    sums[k] += v;
                }
            }
        }
        let inv_n = 1.0_f64 / n as f64;
        let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
        // Pass 2: squared deviations about the mean (more stable than E[x^2] - E[x]^2).
        let mut sq_devs = vec![0.0_f64; selected_cols.len()];
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let chunk = design
                .design
                .try_row_chunk(start..end)
                .expect("LinearFitConditioning::from_columns row chunk failed");
            for (k, &col_idx) in selected_cols.iter().enumerate() {
                let mean_k = means[k];
                let column = chunk.column(col_idx);
                for &v in column.iter() {
                    let d = v - mean_k;
                    sq_devs[k] += d * d;
                }
            }
        }
        for (k, &col_idx) in selected_cols.iter().enumerate() {
            let mean = means[k];
            let var = sq_devs[k] * inv_n;
            // Near-constant columns are left untouched rather than blown up.
            let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
                (mean, var.sqrt())
            } else {
                (0.0, 1.0)
            };
            columns.push(LinearColumnConditioning {
                col_idx,
                mean,
                scale,
            });
        }
        Self {
            intercept_idx: design.intercept_range.start,
            columns,
        }
    }

    /// Returns `design` with each conditioned column centered by its mean and
    /// divided by its scale.
    fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
        let mut out = design.clone();
        for col in &self.columns {
            {
                let mut dst = out.column_mut(col.col_idx);
                dst -= col.mean;
            }
            if col.scale != 1.0 {
                out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Computes `mat * A`: each conditioned column `j` becomes
    /// `(mat[:, j] - mean_j * mat[:, int]) / scale_j`.
    fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            let intercept_col = out.column(intercept).to_owned();
            let mut target = out.column_mut(col.col_idx);
            target -= &(intercept_col * col.mean);
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Computes `A^T * mat`: each conditioned row `j` becomes
    /// `(mat[j, :] - mean_j * mat[int, :]) / scale_j`.
    fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            let interceptrow = out.row(intercept).to_owned();
            let mut target = out.row_mut(col.col_idx);
            target -= &(interceptrow * col.mean);
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }

    /// Computes `A^{-T} * mat_internal`: row `j` becomes
    /// `scale_j * row_j + mean_j * row_int`, using the unmodified intercept row.
    fn left_multiply_by_m_inv_transpose(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
        let mut out = mat_internal.clone();
        let intercept = self.intercept_idx;
        let interceptrow_snapshot = mat_internal.row(intercept).to_owned();
        for col in &self.columns {
            if col.scale != 1.0 {
                out.row_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
            }
            if col.mean != 0.0 {
                let mut target = out.row_mut(col.col_idx);
                target += &(&interceptrow_snapshot * col.mean);
            }
        }
        out
    }

    /// Computes `mat_internal * A^{-1}`: column `j` becomes
    /// `scale_j * col_j + mean_j * col_int`, using the unmodified intercept
    /// column.
    fn right_multiply_by_m_inv(&self, mat_internal: &Array2<f64>) -> Array2<f64> {
        let mut out = mat_internal.clone();
        let intercept = self.intercept_idx;
        let intercept_col_snapshot = mat_internal.column(intercept).to_owned();
        for col in &self.columns {
            if col.scale != 1.0 {
                out.column_mut(col.col_idx).mapv_inplace(|v| v * col.scale);
            }
            if col.mean != 0.0 {
                let mut target = out.column_mut(col.col_idx);
                target += &(&intercept_col_snapshot * col.mean);
            }
        }
        out
    }

    /// Maps original-space penalties `S` to internal penalties `A^T S A`.
    ///
    /// `A` differs from the identity in the conditioned columns, in both the
    /// conditioned rows and the intercept row. A penalty therefore changes
    /// when its column range touches a conditioned column, or touches the
    /// intercept column while any column is conditioned. Such penalties are
    /// expanded to dense `p x p` and transformed. All others are passed
    /// through in blockwise form unchanged.
    fn transform_blockwise_penalties_to_internal(
        &self,
        penalties: &[BlockwisePenalty],
        p: usize,
    ) -> Vec<PenaltySpec> {
        let conditioning_cols: std::collections::HashSet<usize> =
            self.columns.iter().map(|c| c.col_idx).collect();
        penalties
            .iter()
            .map(|bp| {
                let block_cols = bp.col_range.start..bp.col_range.end;
                let touches_intercept =
                    !self.columns.is_empty() && block_cols.contains(&self.intercept_idx);
                let overlaps =
                    touches_intercept || block_cols.clone().any(|j| conditioning_cols.contains(&j));
                if overlaps {
                    let global = bp.to_global(p);
                    let right = self.transform_matrix_columnswith_a(&global);
                    let transformed = self.transform_matrixrowswith_a_transpose(&right);
                    PenaltySpec::Dense(transformed)
                } else {
                    PenaltySpec::from_blockwise(bp.clone())
                }
            })
            .collect()
    }

    /// Maps internal coefficients back to the original scale: `beta = A * beta_internal`.
    fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
        let mut beta = beta_internal.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
            beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
        }
        beta
    }

    /// Maps an internal penalized Hessian back to original coordinates:
    /// `A^{-T} H_internal A^{-1}`.
    fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
        let right = self.right_multiply_by_m_inv(h_internal);
        self.left_multiply_by_m_inv_transpose(&right)
    }

    /// Converts original-scale coefficient bounds for `col_idx` to the internal
    /// scale (`beta_internal_j = scale_j * beta_j`); unconditioned columns pass
    /// through.
    fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
        if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
            (min * col.scale, max * col.scale)
        } else {
            (min, max)
        }
    }
}
/// Pins an explicit identity identifiability transform on raw spatial
/// metadata.
///
/// Thin-plate and Duchon metadata whose `identifiability_transform` is `None`
/// get `Some(I)`, where `I` is the `raw_cols x raw_cols` identity. All other
/// fields are moved through unchanged. Metadata of any other kind, or with a
/// transform already set, is returned as-is.
fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
    match metadata {
        BasisMetadata::ThinPlate {
            centers,
            length_scale,
            periodic,
            identifiability_transform: None,
            input_scales,
            radial_reparam,
        } => BasisMetadata::ThinPlate {
            centers,
            length_scale,
            periodic,
            identifiability_transform: Some(Array2::eye(raw_cols)),
            input_scales,
            radial_reparam,
        },
        BasisMetadata::Duchon {
            centers,
            length_scale,
            periodic,
            power,
            nullspace_order,
            identifiability_transform: None,
            input_scales,
            aniso_log_scales,
            operator_collocation_points,
        } => BasisMetadata::Duchon {
            centers,
            length_scale,
            periodic,
            power,
            nullspace_order,
            identifiability_transform: Some(Array2::eye(raw_cols)),
            input_scales,
            aniso_log_scales,
            operator_collocation_points,
        },
        // Non-spatial metadata, or a transform already present: unchanged.
        other => other,
    }
}
fn matern_operator_penalty_triplet_from_metadata(
metadata: &BasisMetadata,
) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
let BasisMetadata::Matern {
centers,
length_scale,
periodic,
nu,
include_intercept,
identifiability_transform,
aniso_log_scales,
input_scales,
..
} = metadata
else {
crate::bail_invalid_basis!("Matérn operator penalties require Matérn metadata");
};
let penalty_length_scale = match input_scales.as_deref() {
Some(scales) => compensate_length_scale_for_standardization(*length_scale, scales),
None => *length_scale,
};
let penalty_centers = crate::basis::expand_periodic_centers(centers, periodic.as_deref())?;
let ops = build_matern_collocation_operator_matrices(
penalty_centers.view(),
None,
penalty_length_scale,
*nu,
*include_intercept,
identifiability_transform.as_ref().map(|z| z.view()),
aniso_log_scales.as_deref(),
)?;
const ORDER_EPS: f64 = 1e-9;
let d = penalty_centers.ncols();
let m = nu.half_integer_value() + 0.5 * d as f64;
let mut candidates = Vec::with_capacity(3);
for (raw, source, min_order) in [
(ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass, 0.0),
(ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension, 1.0),
(
ops.d2.t().dot(&ops.d2),
PenaltySource::OperatorStiffness,
2.0,
),
] {
if min_order > 0.0 && m <= min_order + ORDER_EPS {
continue;
}
let sym = (&raw + &raw.t()) * 0.5;
let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
candidates.push(PenaltyCandidate {
matrix,
nullspace_dim_hint: 0,
source,
normalization_scale,
kronecker_factors: None,
op: None,
});
}
filter_active_penalty_candidates(candidates)
}
/// Symmetrizes `matrix`, projects it onto the PSD cone, and scales it to unit
/// Frobenius norm.
///
/// Returns the normalized matrix and the norm it was divided by. If that norm
/// is zero or non-finite, the projected matrix is returned unscaled with
/// factor `1.0`.
fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
    let symmetric = (matrix + &matrix.t().to_owned()) * 0.5;
    let projected = crate::terms::basis::project_penalty_to_psd_cone(&symmetric);
    let frobenius = projected.iter().map(|v| v * v).sum::<f64>().sqrt();
    let usable = frobenius.is_finite() && frobenius > 0.0;
    if !usable {
        return (projected, 1.0);
    }
    (projected.mapv(|v| v / frobenius), frobenius)
}
/// Builds a periodic Fourier marginal basis for a tensor-product term.
///
/// Basis layout, with `q = max(requested_cols, 3)` columns:
/// * Column 0 is the constant 1.
/// * Column `c >= 1` carries harmonic `h = (c + 1) / 2` at angle
///   `2π x / period`: `cos(h·angle)` when `c` is odd, `sin(h·angle)` when `c`
///   is even. For even `q`, the last column is therefore the cosine of
///   harmonic `q / 2`.
///
/// The penalty is diagonal: `h^(2·order)` on each harmonic column, with
/// `order = max(penalty_order, 1)`, and zero on the constant. The returned
/// "knots" are `q` evenly spaced points on `[0, period]`.
///
/// # Errors
/// Returns an error (via `bail_invalid_basis!`) when `period` is not finite
/// and positive.
fn build_periodic_fourier_margin(
    x: ArrayView1<'_, f64>,
    period: f64,
    requested_cols: usize,
    penalty_order: usize,
) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>), BasisError> {
    if !period.is_finite() || period <= 0.0 {
        crate::bail_invalid_basis!(
            "periodic tensor margin requires finite positive period, got {period}"
        );
    }
    let q = requested_cols.max(3);
    // Columns (1, 2) -> harmonic 1, (3, 4) -> harmonic 2, ...
    let harmonic_of = |c: usize| (c + 1) / 2;

    let mut basis = Array2::<f64>::zeros((x.len(), q));
    for (mut row, &xi) in basis.outer_iter_mut().zip(x.iter()) {
        let angle = 2.0 * std::f64::consts::PI * xi / period;
        row[0] = 1.0;
        for c in 1..q {
            let phase = harmonic_of(c) as f64 * angle;
            row[c] = if c % 2 == 1 { phase.cos() } else { phase.sin() };
        }
    }

    let exponent = 2 * penalty_order.max(1) as i32;
    let mut penalty = Array2::<f64>::zeros((q, q));
    for c in 1..q {
        penalty[[c, c]] = (harmonic_of(c) as f64).powi(exponent);
    }

    let knots = Array1::linspace(0.0, period, q);
    Ok((basis, penalty, knots))
}
/// Row-wise tensor (Kronecker) product of sparse marginal designs.
///
/// Row `i` of the result is `⊗_d marginal_d[i, :]`. The product has
/// `prod(ncols_d)` columns, with the last marginal varying fastest: the
/// column index is `Σ_d c_d * stride_d`, where `stride_last = 1`.
///
/// The computation:
/// * Marginals are converted to row-major (CSR) once.
/// * Rows are processed in parallel in chunks of 1024.
/// * Per-chunk triplets are concatenated in chunk order.
///
/// A row that is empty in any marginal yields an empty output row.
///
/// # Errors
/// * No marginals are given.
/// * Marginal row counts differ.
/// * The column count overflows `usize`.
/// * CSR conversion or final sparse assembly fails.
fn tensor_product_design_from_sparse_marginals(
    marginal_sparse: &[&SparseColMat<usize, f64>],
) -> Result<SparseColMat<usize, f64>, BasisError> {
    if marginal_sparse.is_empty() {
        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
    }
    let n = marginal_sparse[0].nrows();
    for (i, m) in marginal_sparse.iter().enumerate().skip(1) {
        if m.nrows() != n {
            crate::bail_dim_basis!(
                "tensor sparse marginal row mismatch at dim {i}: expected {n}, got {}",
                m.nrows()
            );
        }
    }
    let dims: Vec<usize> = marginal_sparse.iter().map(|m| m.ncols()).collect();
    // Overflow-checked product of marginal widths.
    let total_cols = dims.iter().try_fold(1usize, |acc, &q| {
        acc.checked_mul(q)
            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
    })?;
    // Row-major (last dimension fastest) strides.
    let mut strides = vec![1usize; dims.len()];
    for d in (0..dims.len().saturating_sub(1)).rev() {
        strides[d] = strides[d + 1]
            .checked_mul(dims[d + 1])
            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))?;
    }
    use faer::sparse::SparseRowMat;
    // CSR gives O(1) access to each row's non-zeros.
    let csrs: Vec<SparseRowMat<usize, f64>> = marginal_sparse
        .iter()
        .enumerate()
        .map(|(d, m)| {
            m.as_ref().to_row_major().map_err(|e| {
                BasisError::SparseCreation(format!(
                    "tensor sparse marginal {d} CSR conversion failed: {e:?}"
                ))
            })
        })
        .collect::<Result<Vec<_>, _>>()?;
    let row_ptrs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().row_ptr()).collect();
    let col_idxs: Vec<&[usize]> = csrs.iter().map(|c| c.symbolic().col_idx()).collect();
    let vals: Vec<&[f64]> = csrs.iter().map(|c| c.val()).collect();
    use rayon::prelude::*;
    const CHUNK: usize = 1024;
    let num_chunks = n.div_ceil(CHUNK);
    let per_chunk: Vec<Vec<Triplet<usize, usize, f64>>> = (0..num_chunks)
        .into_par_iter()
        .map(|chunk_idx| {
            let row_start = chunk_idx * CHUNK;
            let row_end = (row_start + CHUNK).min(n);
            let mut chunk_triplets = Vec::<Triplet<usize, usize, f64>>::new();
            // Double-buffered partial products: (column index, value) pairs.
            let mut cur_cols = Vec::<usize>::with_capacity(64);
            let mut cur_vals = Vec::<f64>::with_capacity(64);
            let mut next_cols = Vec::<usize>::with_capacity(64);
            let mut next_vals = Vec::<f64>::with_capacity(64);
            for i in row_start..row_end {
                // Start from the empty product: column 0, value 1.
                cur_cols.clear();
                cur_vals.clear();
                cur_cols.push(0);
                cur_vals.push(1.0);
                let mut row_is_zero = false;
                for d in 0..dims.len() {
                    let row_start_d = row_ptrs[d][i];
                    let row_end_d = row_ptrs[d][i + 1];
                    // Any empty marginal row zeroes the whole product row.
                    if row_start_d == row_end_d {
                        row_is_zero = true;
                        break;
                    }
                    let stride = strides[d];
                    next_cols.clear();
                    next_vals.clear();
                    next_cols.reserve(cur_cols.len() * (row_end_d - row_start_d));
                    next_vals.reserve(cur_vals.len() * (row_end_d - row_start_d));
                    // Extend every partial product by every non-zero of marginal d.
                    for (&prev_col, &prev_val) in cur_cols.iter().zip(cur_vals.iter()) {
                        for ptr in row_start_d..row_end_d {
                            let cj = col_idxs[d][ptr];
                            let vj = vals[d][ptr];
                            next_cols.push(prev_col + cj * stride);
                            next_vals.push(prev_val * vj);
                        }
                    }
                    std::mem::swap(&mut cur_cols, &mut next_cols);
                    std::mem::swap(&mut cur_vals, &mut next_vals);
                }
                if row_is_zero {
                    continue;
                }
                for (&col, &val) in cur_cols.iter().zip(cur_vals.iter()) {
                    chunk_triplets.push(Triplet::new(i, col, val));
                }
            }
            chunk_triplets
        })
        .collect();
    let total_nnz: usize = per_chunk.iter().map(Vec::len).sum();
    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(total_nnz);
    for chunk in per_chunk {
        triplets.extend(chunk);
    }
    SparseColMat::try_new_from_triplets(n, total_cols, &triplets).map_err(|e| {
        BasisError::SparseCreation(format!(
            "failed to assemble sparse tensor product design: {e:?}"
        ))
    })
}
/// Orthogonal projectors splitting one tensor margin's coefficient space by
/// its penalty.
struct TensorMarginRangeNullProjectors {
    /// Projector onto eigenvectors with eigenvalue above the analysis tolerance (penalized range).
    range: Array2<f64>,
    /// Projector onto the remaining eigenvectors (penalty null space).
    null: Array2<f64>,
}
/// Returns `B Bᵀ`, where `B` consists of the selected `columns`.
///
/// This is an orthogonal projector when those columns are orthonormal. With no
/// indices, the result is the `nrows x nrows` zero matrix.
fn projector_from_columns(columns: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
    let dim = columns.nrows();
    match indices {
        [] => Array2::<f64>::zeros((dim, dim)),
        selected => {
            let basis = columns.select(Axis(1), selected);
            basis.dot(&basis.t())
        }
    }
}
/// For each normalized marginal penalty, builds projectors onto its penalized
/// range and its null space.
///
/// Eigenvalues above the tolerance reported by `analyze_penalty_block` define
/// the range; all remaining eigenvectors span the null space.
///
/// # Errors
/// * Penalty analysis fails.
/// * A margin's penalty has rank zero, so there is nothing to split.
fn tensor_margin_range_null_projectors(
    normalized_marginal_penalties: &[(Array2<f64>, f64)],
) -> Result<Vec<TensorMarginRangeNullProjectors>, BasisError> {
    let mut projectors = Vec::with_capacity(normalized_marginal_penalties.len());
    for (dim, (penalty, _normalization)) in normalized_marginal_penalties.iter().enumerate() {
        let analysis = crate::terms::basis::analyze_penalty_block(penalty)?;
        if analysis.rank == 0 {
            crate::bail_invalid_basis!(
                "t2 separable tensor penalty margin {dim} has rank-zero penalty; \
                 cannot split penalized and null subspaces"
            );
        }
        let mut range_idx = Vec::new();
        let mut null_idx = Vec::new();
        for (idx, &eigenvalue) in analysis.eigenvalues.iter().enumerate() {
            let bucket = if eigenvalue > analysis.tol {
                &mut range_idx
            } else {
                &mut null_idx
            };
            bucket.push(idx);
        }
        projectors.push(TensorMarginRangeNullProjectors {
            range: projector_from_columns(&analysis.eigenvectors, &range_idx),
            null: projector_from_columns(&analysis.eigenvectors, &null_idx),
        });
    }
    Ok(projectors)
}
/// Builds a tensor-product smooth from one marginal basis per feature column.
///
/// For each `(feature_cols[dim], spec.marginalspecs[dim])` pair, a marginal basis is built.
/// - If `spec.periods[dim]` is `Some(period)`, it is a periodic Fourier margin via
///   `build_periodic_fourier_margin`.
/// - Otherwise it is a 1-D B-spline margin whose own identifiability constraint is
///   disabled; identifiability is imposed on the tensor product instead.
///
/// The marginal penalties are normalized and combined according to
/// `spec.penalty_decomposition`:
/// - `MarginalKroneckerSum`: one penalty per margin, `I ⊗ … ⊗ S_dim ⊗ … ⊗ I`. When
///   `spec.double_penalty` is set and a shrinkage penalty is produced, a null-space
///   shrinkage penalty for the summed matrix is added.
/// - `Separable`: each margin's penalty is split into range/null projectors. One
///   penalty is emitted per non-empty subset of "range" margins, i.e. `2^d - 1`
///   penalties. When `spec.double_penalty` is set, the all-null Kronecker product is
///   added as a global ridge.
///
/// The tensor design is only materialized densely when `spec.identifiability` is not
/// `None`. The identifiability transform (`SumToZero`, `MarginalSumToZero`, or
/// `FrozenTransform`) is then applied to that design and to every penalty. Otherwise,
/// the design is:
/// - sparse, when every marginal has a sparse B-spline design;
/// - a lazy `TensorProductDesignOperator`, in all remaining cases.
///
/// Kronecker-factored metadata is returned only for `identifiability == None` with the
/// `MarginalKroneckerSum` decomposition.
///
/// # Errors
/// Returns `BasisError` in these cases:
/// - `feature_cols` is empty;
/// - feature/spec/period lengths disagree;
/// - a feature column is out of bounds;
/// - a marginal build, penalty analysis, or identifiability step fails.
fn build_tensor_bspline_basis(
    data: ArrayView2<'_, f64>,
    feature_cols: &[usize],
    spec: &TensorBSplineSpec,
) -> Result<BasisBuildResult, BasisError> {
    if feature_cols.is_empty() {
        crate::bail_invalid_basis!("TensorBSpline requires at least one feature column");
    }
    if feature_cols.len() != spec.marginalspecs.len() {
        crate::bail_dim_basis!(
            "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
            feature_cols.len(),
            spec.marginalspecs.len()
        );
    }
    // `periods` is either empty (no periodic margins) or has exactly one entry per feature.
    if !spec.periods.is_empty() && spec.periods.len() != feature_cols.len() {
        crate::bail_dim_basis!(
            "TensorBSpline periods length {} does not match feature count {}",
            spec.periods.len(),
            feature_cols.len()
        );
    }
    let p = data.ncols();
    for &c in feature_cols {
        if c >= p {
            crate::bail_dim_basis!(
                "tensor feature column {c} is out of bounds for data with {p} columns"
            );
        }
    }
    // Per-margin outputs; every vector below is indexed by margin position `dim`.
    let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
    let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
    let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
    let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
    let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
    let mut marginal_effective_periods = Vec::<Option<f64>>::with_capacity(feature_cols.len());
    // `Some` only for B-spline margins whose built design was sparse.
    let mut marginal_sparse =
        Vec::<Option<SparseColMat<usize, f64>>>::with_capacity(feature_cols.len());
    for (dim, (&col, marginalspec)) in feature_cols
        .iter()
        .zip(spec.marginalspecs.iter())
        .enumerate()
    {
        if let Some(period) = spec.periods.get(dim).and_then(|p| *p) {
            // Periodic margin: translate the B-spline knot specification into the number
            // of basis columns requested from the Fourier margin.
            let requested_cols = match marginalspec.knotspec {
                BSplineKnotSpec::Generate {
                    num_internal_knots, ..
                } => num_internal_knots + marginalspec.degree + 1,
                BSplineKnotSpec::Provided(ref knots) => {
                    knots.len().saturating_sub(marginalspec.degree + 1)
                }
                BSplineKnotSpec::Automatic {
                    num_internal_knots, ..
                } => {
                    const DEFAULT_AUTOMATIC_INTERNAL_KNOTS: usize = 8;
                    num_internal_knots.unwrap_or(DEFAULT_AUTOMATIC_INTERNAL_KNOTS)
                        + marginalspec.degree
                        + 1
                }
                BSplineKnotSpec::PeriodicUniform { num_basis, .. } => num_basis,
            };
            let (basis, penalty, knots) = build_periodic_fourier_margin(
                data.column(col),
                period,
                requested_cols,
                marginalspec.penalty_order,
            )?;
            marginal_knots.push(knots);
            marginal_degrees.push(marginalspec.degree);
            marginalnum_basis.push(basis.ncols());
            marginal_designs.push(basis);
            marginal_penalties.push(penalty);
            marginal_sparse.push(None);
            marginal_effective_periods.push(Some(period));
        } else {
            // Non-periodic margin: build an unconstrained 1-D B-spline. Identifiability
            // is handled on the full tensor product below, not per margin.
            let mut marginal_unconstrained = marginalspec.clone();
            marginal_unconstrained.identifiability = BSplineIdentifiability::None;
            let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
            let knots = match built.metadata {
                BasisMetadata::BSpline1D { knots, .. } => knots,
                _ => {
                    crate::bail_invalid_basis!(
                        "internal TensorBSpline error at dim {dim}: expected BSpline1D metadata"
                    );
                }
            };
            marginal_knots.push(knots);
            marginal_degrees.push(marginalspec.degree);
            marginalnum_basis.push(built.design.ncols());
            // Keep a sparse copy (when available) for the sparse tensor-design path.
            let sparse_view: Option<SparseColMat<usize, f64>> =
                built.design.as_sparse().map(|sd| {
                    let inner: &SparseColMat<usize, f64> = sd;
                    inner.clone()
                });
            marginal_sparse.push(sparse_view);
            marginal_designs.push(built.design.to_dense());
            // Only the first penalty of the marginal build is used.
            marginal_penalties.push(
                built
                    .penalties
                    .first()
                    .ok_or_else(|| {
                        BasisError::InvalidInput(format!(
                            "internal TensorBSpline error at dim {dim}: missing marginal penalty"
                        ))
                    })?
                    .clone(),
            );
            // Sanity check only: the nullspace dimension must exist, but its value is
            // not used here.
            built.nullspace_dims.first().ok_or_else(|| {
                BasisError::InvalidInput(format!(
                    "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
                ))
            })?;
            // A periodic-uniform knot spec implies a period equal to its data range.
            let implied_period = match marginalspec.knotspec {
                BSplineKnotSpec::PeriodicUniform { data_range, .. } => {
                    Some(data_range.1 - data_range.0)
                }
                _ => None,
            };
            marginal_effective_periods.push(implied_period);
        }
    }
    let total_cols: usize = marginalnum_basis.iter().product();
    // The dense tensor design is only materialized when an identifiability transform
    // must be applied to it.
    let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
        .then(|| tensor_product_design_from_marginals(&marginal_designs))
        .transpose()?;
    // Capacity is only a hint: the Separable branch actually emits 2^d - 1 (+1) candidates.
    let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
        match spec.penalty_decomposition {
            TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => marginal_penalties.len(),
            TensorBSplinePenaltyDecomposition::Separable => marginal_penalties.len() * 2,
        } + if spec.double_penalty { 1 } else { 0 },
    );
    let normalized_marginal_penalties: Vec<(Array2<f64>, f64)> = marginal_penalties
        .iter()
        .map(normalize_penalty_in_constrained_space)
        .collect();
    let mut kronecker_marginal_penalties =
        Vec::<Array2<f64>>::with_capacity(normalized_marginal_penalties.len());
    match spec.penalty_decomposition {
        TensorBSplinePenaltyDecomposition::MarginalKroneckerSum => {
            // Running sum of all per-margin penalties, used for the optional
            // null-space shrinkage penalty.
            let mut marginal_kron_sum = Array2::<f64>::zeros((total_cols, total_cols));
            for dim in 0..normalized_marginal_penalties.len() {
                // s_dim = I ⊗ … ⊗ S_dim ⊗ … ⊗ I (identity on every other margin).
                let mut s_dim = Array2::<f64>::eye(1);
                let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
                for (j, &qj) in marginalnum_basis.iter().enumerate() {
                    let factor = if j == dim {
                        normalized_marginal_penalties[j].0.clone()
                    } else {
                        Array2::<f64>::eye(qj)
                    };
                    factors.push(factor.clone());
                    s_dim = kronecker_product(&s_dim, &factor);
                }
                if dim == kronecker_marginal_penalties.len() {
                    kronecker_marginal_penalties.push(normalized_marginal_penalties[dim].0.clone());
                }
                marginal_kron_sum += &s_dim;
                candidates.push(PenaltyCandidate {
                    matrix: s_dim,
                    nullspace_dim_hint: 0,
                    source: PenaltySource::TensorMarginal { dim },
                    normalization_scale: normalized_marginal_penalties[dim].1,
                    kronecker_factors: Some(factors),
                    op: None,
                });
            }
            if spec.double_penalty
                && let Some(shrink) =
                    crate::terms::basis::build_nullspace_shrinkage_penalty(&marginal_kron_sum)?
            {
                let (matrix, normalization_scale) =
                    normalize_penalty_in_constrained_space(&shrink.sym_penalty);
                candidates.push(PenaltyCandidate {
                    matrix,
                    nullspace_dim_hint: 0,
                    source: PenaltySource::TensorGlobalRidge,
                    normalization_scale,
                    kronecker_factors: None,
                    op: None,
                });
            }
        }
        TensorBSplinePenaltyDecomposition::Separable => {
            let projectors = tensor_margin_range_null_projectors(&normalized_marginal_penalties)?;
            // One candidate per non-empty subset of margins, encoded as a bitmask;
            // `checked_shl` fails once the shift reaches the width of `usize`.
            let n_masks = 1usize.checked_shl(projectors.len() as u32).ok_or_else(|| {
                BasisError::InvalidInput(format!(
                    "t2 separable tensor penalty supports at most {} margins, got {}",
                    usize::BITS - 1,
                    projectors.len()
                ))
            })?;
            for mask in 1..n_masks {
                // Bit `dim` set => range projector for that margin; clear => null projector.
                let mut matrix = Array2::<f64>::eye(1);
                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
                let mut penalized_margins = Vec::<usize>::new();
                for (dim, projector) in projectors.iter().enumerate() {
                    let use_range = ((mask >> dim) & 1) == 1;
                    let factor = if use_range {
                        penalized_margins.push(dim);
                        projector.range.clone()
                    } else {
                        projector.null.clone()
                    };
                    matrix = kronecker_product(&matrix, &factor);
                    factors.push(factor);
                }
                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
                candidates.push(PenaltyCandidate {
                    matrix,
                    nullspace_dim_hint: 0,
                    source: PenaltySource::TensorSeparable { penalized_margins },
                    normalization_scale,
                    kronecker_factors: Some(factors),
                    op: None,
                });
            }
            if spec.double_penalty {
                // Mask 0 (null projector on every margin) becomes the global ridge.
                let mut matrix = Array2::<f64>::eye(1);
                let mut factors = Vec::<Array2<f64>>::with_capacity(projectors.len());
                for projector in &projectors {
                    matrix = kronecker_product(&matrix, &projector.null);
                    factors.push(projector.null.clone());
                }
                let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&matrix);
                candidates.push(PenaltyCandidate {
                    matrix,
                    nullspace_dim_hint: 0,
                    source: PenaltySource::TensorGlobalRidge,
                    normalization_scale,
                    kronecker_factors: Some(factors),
                    op: None,
                });
            }
        }
    }
    // Identifiability transform `z` (total_cols x reduced_cols), if any.
    let z_opt = match &spec.identifiability {
        TensorBSplineIdentifiability::None => None,
        TensorBSplineIdentifiability::SumToZero => {
            if total_cols < 2 {
                crate::bail_invalid_basis!(
                    "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability"
                );
            }
            let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
                BasisError::InvalidInput(
                    "tensor sum-to-zero identifiability requires a realized basis".to_string(),
                )
            })?;
            let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
            let gauge = crate::solver::gauge::Gauge::sum_to_zero(z);
            Some(gauge.block_transform(0))
        }
        TensorBSplineIdentifiability::MarginalSumToZero => {
            // ti()-style interaction: remove each margin's main effect by Kroneckering
            // the per-margin sum-to-zero transforms.
            if marginal_designs.len() < 2 {
                crate::bail_invalid_basis!(
                    "tensor interaction (ti) identifiability requires at least 2 margins"
                );
            }
            let mut z = Array2::<f64>::eye(1);
            for (dim, marginal) in marginal_designs.iter().enumerate() {
                if marginal.ncols() < 2 {
                    crate::bail_invalid_basis!(
                        "tensor interaction (ti) margin {dim} has fewer than 2 basis functions; \
                         cannot remove its marginal main effect"
                    );
                }
                let (_, z_dim) = apply_sum_to_zero_constraint(marginal.view(), None)?;
                let gauge_dim = crate::solver::gauge::Gauge::sum_to_zero(z_dim);
                let z_dim = gauge_dim.block_transform(0);
                z = kronecker_product(&z, &z_dim);
            }
            Some(z)
        }
        TensorBSplineIdentifiability::FrozenTransform { transform } => {
            if transform.nrows() != total_cols {
                crate::bail_dim_basis!(
                    "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
                    total_cols,
                    transform.nrows()
                );
            }
            Some(transform.clone())
        }
    };
    if let Some(z) = z_opt.as_ref() {
        // Restrict the dense design and every penalty to the identifiable subspace.
        let gauge = crate::solver::gauge::Gauge::from_block_transforms(&[z.clone()]);
        let dense = dense_design.as_mut().ok_or_else(|| {
            BasisError::InvalidInput(
                "tensor identifiability transform requires a realized basis".to_string(),
            )
        })?;
        let restricted_design = gauge.restrict_design(dense);
        *dense = restricted_design;
        candidates = candidates
            .into_iter()
            .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
                let matrix = gauge.restrict_penalty(&candidate.matrix);
                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
                // Per-margin penalties are multiplied back by `c_new` and keep their
                // original normalization scale. Other penalties stay renormalized and
                // fold `c_new` into their scale. NOTE(review): this presumably undoes the
                // renormalization for marginal penalties; confirm against
                // `normalize_penalty_in_constrained_space`.
                let preserve_margin_scale =
                    matches!(&candidate.source, PenaltySource::TensorMarginal { .. });
                let (matrix, normalization_scale) = if preserve_margin_scale {
                    (matrix.mapv(|v| v * c_new), candidate.normalization_scale)
                } else {
                    (matrix, candidate.normalization_scale * c_new)
                };
                Ok(PenaltyCandidate {
                    nullspace_dim_hint: candidate.nullspace_dim_hint,
                    matrix,
                    source: candidate.source,
                    normalization_scale,
                    // Restriction destroys the Kronecker structure.
                    kronecker_factors: None,
                    op: candidate.op.clone(),
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
    }
    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
        filter_active_penalty_candidates_with_ops(candidates)?;
    let identifiability_is_none =
        matches!(spec.identifiability, TensorBSplineIdentifiability::None);
    let all_marginals_sparse = marginal_sparse.iter().all(Option::is_some);
    // Design representation: dense if materialized above, sparse if every marginal is
    // sparse and unconstrained, otherwise a lazy row-wise Kronecker operator.
    let design = if let Some(dense_design) = dense_design {
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(dense_design))
    } else if identifiability_is_none && all_marginals_sparse {
        let sparse_marginals: Vec<&SparseColMat<usize, f64>> = marginal_sparse
            .iter()
            .map(|m| m.as_ref().expect("all_marginals_sparse just verified"))
            .collect();
        let sparse_design = tensor_product_design_from_sparse_marginals(&sparse_marginals)?;
        DesignMatrix::Sparse(crate::matrix::SparseDesignMatrix::new(sparse_design))
    } else {
        let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
            .iter()
            .map(|m| Arc::new(m.clone()))
            .collect();
        let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
            BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
        })?;
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(op)))
    };
    Ok(BasisBuildResult {
        design,
        penalties,
        nullspace_dims,
        penaltyinfo,
        ops,
        null_eigenvectors,
        joint_null_rotation: None,
        metadata: BasisMetadata::TensorBSpline {
            feature_cols: feature_cols.to_vec(),
            knots: marginal_knots,
            degrees: marginal_degrees,
            periods: marginal_effective_periods,
            identifiability_transform: z_opt,
        },
        // Kronecker-factored form is only valid when no identifiability transform was
        // applied and the penalties are per-margin Kronecker sums.
        kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None)
            && matches!(
                spec.penalty_decomposition,
                TensorBSplinePenaltyDecomposition::MarginalKroneckerSum
            ) {
            Some(KroneckerFactoredBasis {
                marginal_designs,
                marginal_penalties: kronecker_marginal_penalties,
                marginal_dims: marginalnum_basis.clone(),
                has_double_penalty: spec.double_penalty,
            })
        } else {
            None
        },
    })
}
/// Materializes the row-wise Kronecker (tensor) product of the marginal designs.
///
/// Output row `i` is `B_0[i, :] ⊗ B_1[i, :] ⊗ … ⊗ B_{d-1}[i, :]`, with the last margin's
/// column index varying fastest. The number of columns is the product of the marginal
/// column counts. Rows are filled in parallel blocks of `ROW_BLOCK` rows.
///
/// Marginal designs may be in any memory layout. Rows are read through the ndarray view
/// rather than as contiguous slices, so column-major (Fortran-ordered) inputs are
/// supported.
///
/// # Errors
/// - `InvalidInput` when `marginal_designs` is empty.
/// - `DimensionMismatch` when the marginals disagree on row count, or when the product
///   of their column counts overflows `usize`.
fn tensor_product_design_from_marginals(
    marginal_designs: &[Array2<f64>],
) -> Result<Array2<f64>, BasisError> {
    /// Rows handed to each parallel task.
    const ROW_BLOCK: usize = 1024;
    if marginal_designs.is_empty() {
        crate::bail_invalid_basis!("TensorBSpline requires at least one marginal basis");
    }
    let n = marginal_designs[0].nrows();
    for (i, b) in marginal_designs.iter().enumerate().skip(1) {
        if b.nrows() != n {
            crate::bail_dim_basis!(
                "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
                b.nrows()
            );
        }
    }
    let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
        acc.checked_mul(b.ncols())
            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
    })?;
    use ndarray::parallel::prelude::*;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let mut design = Array2::<f64>::zeros((n, total_cols));
    design
        .axis_chunks_iter_mut(ndarray::Axis(0), ROW_BLOCK)
        .into_par_iter()
        .enumerate()
        .for_each(|(chunk_idx, mut block)| {
            let row_offset = chunk_idx * ROW_BLOCK;
            // Scratch buffers reused for every row in this block.
            let mut cur = Vec::<f64>::with_capacity(total_cols);
            let mut next = Vec::<f64>::with_capacity(total_cols);
            for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
                let i = row_offset + local_i;
                cur.clear();
                cur.push(1.0);
                for b in marginal_designs {
                    let q = b.ncols();
                    let b_row = b.row(i);
                    next.clear();
                    next.resize(cur.len() * q, 0.0);
                    // Read the marginal row through the view instead of `as_slice()`.
                    // Rows of a column-major (Fortran-order) array are not contiguous,
                    // and `as_slice()` would return `None` for them.
                    for (a_idx, &aval) in cur.iter().enumerate() {
                        let off = a_idx * q;
                        for (dst, &bval) in next[off..off + q].iter_mut().zip(b_row.iter()) {
                            *dst = aval * bval;
                        }
                    }
                    std::mem::swap(&mut cur, &mut next);
                }
                let out_slice = out_row
                    .as_slice_mut()
                    .expect("design row is contiguous in C-major Array2");
                out_slice.copy_from_slice(&cur);
            }
        });
    Ok(design)
}
/// Builds the grouping structure for a random-effect term.
///
/// Group values are read from `spec.feature_col` and identified by their exact `f64` bit
/// pattern (`to_bits`).
/// - With `spec.frozen_levels`, those levels are used verbatim, in order.
/// - Otherwise, the distinct observed bit patterns are used, sorted ascending as `u64`.
///   The first one is dropped when `spec.drop_first_level` is set and more than one
///   level exists.
///
/// Each row maps to the column index of its level, or `None` when its value is not a
/// kept level.
///
/// # Errors
/// Returns `BasisError` in these cases:
/// - the column is out of bounds;
/// - any group value is non-finite;
/// - the frozen or observed level set is empty;
/// - no levels remain after dropping;
/// - the frozen levels contain duplicates.
fn build_random_effect_block(
    data: ArrayView2<'_, f64>,
    spec: &RandomEffectTermSpec,
) -> Result<RandomEffectBlock, BasisError> {
    let n_rows = data.nrows();
    let n_cols = data.ncols();
    if spec.feature_col >= n_cols {
        crate::bail_dim_basis!(
            "random-effect term '{}' feature column {} out of bounds for {} columns",
            spec.name,
            spec.feature_col,
            n_cols
        );
    }
    let groups = data.column(spec.feature_col);
    if !groups.iter().all(|g| g.is_finite()) {
        crate::bail_invalid_basis!(
            "random-effect term '{}' contains non-finite group values",
            spec.name
        );
    }
    let kept_levels: Vec<u64> = match spec.frozen_levels.as_ref() {
        Some(frozen) => {
            if frozen.is_empty() {
                crate::bail_invalid_basis!(
                    "random-effect term '{}' has empty frozen_levels",
                    spec.name
                );
            }
            frozen.clone()
        }
        None => {
            let observed: Vec<u64> = groups
                .iter()
                .map(|g| g.to_bits())
                .collect::<BTreeSet<u64>>()
                .into_iter()
                .collect();
            if observed.is_empty() {
                crate::bail_invalid_basis!("random-effect term '{}' has no observed levels", spec.name);
            }
            // Drop the reference level only if at least one other level remains.
            let skip = usize::from(spec.drop_first_level && observed.len() > 1);
            observed.into_iter().skip(skip).collect()
        }
    };
    if kept_levels.is_empty() {
        crate::bail_invalid_basis!(
            "random-effect term '{}' drops all levels; keep at least one level",
            spec.name
        );
    }
    let num_groups = kept_levels.len();
    let mut level_to_col = BTreeMap::<u64, usize>::new();
    for (idx, &bits) in kept_levels.iter().enumerate() {
        if level_to_col.insert(bits, idx).is_some() {
            crate::bail_invalid_basis!(
                "random-effect term '{}' has duplicate frozen level bits {bits}",
                spec.name
            );
        }
    }
    let mut group_ids = Vec::with_capacity(n_rows);
    group_ids.extend(groups.iter().map(|g| level_to_col.get(&g.to_bits()).copied()));
    Ok(RandomEffectBlock {
        name: spec.name.clone(),
        group_ids,
        num_groups,
        kept_levels,
    })
}
impl SmoothDesign {
    /// Maps unconstrained coefficients onto the parameterization of `shape`.
    ///
    /// - `None`: returned unchanged (cloned).
    /// - Monotone shapes use `cumulative_exp` with sign `+1.0` (increasing) or `-1.0`
    ///   (decreasing).
    /// - `Convex` / `Concave` use `second_cumulative_exp` with sign `+1.0` / `-1.0`.
    ///
    /// # Errors
    /// Returns `BasisError::InvalidInput` when `unconstrained` is empty.
    pub fn map_term_coefficients(
        unconstrained: &Array1<f64>,
        shape: ShapeConstraint,
    ) -> Result<Array1<f64>, BasisError> {
        if unconstrained.is_empty() {
            crate::bail_invalid_basis!("unconstrained coefficient vector cannot be empty");
        }
        match shape {
            ShapeConstraint::None => Ok(unconstrained.clone()),
            ShapeConstraint::MonotoneIncreasing => Ok(cumulative_exp(unconstrained, 1.0)),
            ShapeConstraint::MonotoneDecreasing => Ok(cumulative_exp(unconstrained, -1.0)),
            ShapeConstraint::Convex => Ok(second_cumulative_exp(unconstrained, 1.0)),
            ShapeConstraint::Concave => Ok(second_cumulative_exp(unconstrained, -1.0)),
        }
    }
}
/// Output of building a single smooth term: its design, penalties with their
/// per-penalty metadata, and optional constraint/structure information.
struct LocalSmoothTermBuild {
    // Design and its description.
    design: DesignMatrix,
    metadata: BasisMetadata,
    // NOTE(review): presumably the term's coefficient dimension; confirm at construction sites.
    dim: usize,
    // Penalty matrices plus per-penalty data (presumably indexed in parallel with
    // `penalties`; confirm).
    penalties: Vec<Array2<f64>>,
    ops: Vec<Option<std::sync::Arc<dyn crate::terms::analytic_penalties::PenaltyOp>>>,
    nullspaces: Vec<usize>,
    null_eigenvectors: Vec<Option<Array2<f64>>>,
    penaltyinfo: Vec<PenaltyInfo>,
    pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
    joint_null_rotation: Option<crate::terms::basis::JointNullRotation>,
    // Shape-constraint and structural information.
    linear_constraints: Option<LinearInequalityConstraints>,
    box_reparam: bool,
    kronecker_factored: Option<KroneckerFactoredBasis>,
}
/// Lazy design operator over a memory-mapped, row-major, little-endian `float64` `.npy`
/// score matrix. Elements are read on demand instead of being loaded into memory.
#[derive(Clone)]
struct PcaScoresMemmapDesignOperator {
    /// Read-only mapping of the whole `.npy` file, shared between clones.
    mmap: Arc<memmap2::Mmap>,
    /// Matrix shape declared by the `.npy` header.
    nrows: usize,
    ncols: usize,
    /// Byte offset of the first matrix element (just past the header).
    data_offset: usize,
    /// Maximum number of rows processed per streaming block.
    chunk_size: usize,
}
impl PcaScoresMemmapDesignOperator {
    /// Memory-maps the `.npy` file at `path` and checks that it contains the full
    /// `nrows x ncols` float64 payload promised by its header.
    ///
    /// `chunk_size` is clamped to at least 1.
    ///
    /// # Errors
    /// Returns `BasisError::InvalidInput` in these cases:
    /// - the file cannot be opened or mapped;
    /// - the header is invalid;
    /// - the declared shape overflows;
    /// - the file is shorter than the header requires.
    fn open(path: PathBuf, chunk_size: usize) -> Result<Self, BasisError> {
        let file = File::open(&path).map_err(|err| {
            BasisError::InvalidInput(format!(
                "failed to open lazy Pca .npy scores '{}': {err}",
                path.display()
            ))
        })?;
        // SAFETY: the mapping is only ever read. memmap2 requires that the backing file
        // is not modified or truncated by anyone while it is mapped. This assumes the
        // scores file is treated as immutable for the operator's lifetime.
        let mapping = unsafe { memmap2::Mmap::map(&file) }.map_err(|err| {
            BasisError::InvalidInput(format!(
                "failed to memmap lazy Pca .npy scores '{}': {err}",
                path.display()
            ))
        })?;
        let (data_offset, nrows, ncols) = parse_f64_2d_npy_header(&mapping, &path)?;
        // The payload size saturates on overflow. `data_offset` always includes the
        // non-empty .npy preamble, so a saturated size makes `checked_add` fail.
        let payload_bytes = nrows.saturating_mul(ncols).saturating_mul(8);
        let expected = data_offset.checked_add(payload_bytes).ok_or_else(|| {
            BasisError::InvalidInput(format!(
                "lazy Pca .npy scores '{}' shape is too large",
                path.display()
            ))
        })?;
        let available = mapping.len();
        if available < expected {
            crate::bail_invalid_basis!(
                "lazy Pca .npy scores '{}' is truncated: header expects {} bytes, file has {}",
                path.display(),
                expected,
                available
            );
        }
        Ok(Self {
            nrows,
            ncols,
            data_offset,
            chunk_size: chunk_size.max(1),
            mmap: Arc::new(mapping),
        })
    }

    /// Reads element `(row, col)` of the row-major payload as a little-endian `f64`.
    ///
    /// # Panics
    /// Panics if the element falls outside the mapping. Callers in this file keep
    /// `row < nrows` and `col < ncols`, and `open` verified that range fits in the file.
    fn value(&self, row: usize, col: usize) -> f64 {
        let start = self.data_offset + 8 * (row * self.ncols + col);
        let mut raw = [0_u8; 8];
        raw.copy_from_slice(&self.mmap[start..start + 8]);
        f64::from_le_bytes(raw)
    }

    /// Rows per streaming block: `chunk_size`, capped at the row count (never below 1).
    fn chunk_rows(&self) -> usize {
        self.nrows.max(1).min(self.chunk_size)
    }
}
impl LinearOperator for PcaScoresMemmapDesignOperator {
    fn nrows(&self) -> usize {
        self.nrows
    }

    fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns `X v`, streaming score rows from the memory map in row blocks.
    ///
    /// # Panics
    /// Panics if `vector.len() != ncols`.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        assert_eq!(
            vector.len(),
            self.ncols,
            "lazy Pca apply vector length mismatch"
        );
        let block = self.chunk_rows();
        let mut result = Array1::<f64>::zeros(self.nrows);
        for first in (0..self.nrows).step_by(block) {
            let last = (first + block).min(self.nrows);
            for r in first..last {
                result[r] = (0..self.ncols)
                    .fold(0.0, |dot, c| dot + self.value(r, c) * vector[c]);
            }
        }
        result
    }

    /// Returns `Xᵀ v`, streaming score rows from the memory map in row blocks.
    ///
    /// Rows whose entry in `vector` equals `0.0` are skipped without being read.
    ///
    /// # Panics
    /// Panics if `vector.len() != nrows`.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        assert_eq!(
            vector.len(),
            self.nrows,
            "lazy Pca apply_transpose vector length mismatch"
        );
        let block = self.chunk_rows();
        let mut result = Array1::<f64>::zeros(self.ncols);
        for first in (0..self.nrows).step_by(block) {
            let last = (first + block).min(self.nrows);
            for r in first..last {
                let coeff = vector[r];
                if coeff != 0.0 {
                    for (c, slot) in result.iter_mut().enumerate() {
                        *slot += coeff * self.value(r, c);
                    }
                }
            }
        }
        result
    }

    /// Returns the full weighted Gram matrix `Xᵀ W X` with `W = diag(weights)`.
    ///
    /// Only the upper triangle is accumulated; it is then mirrored into the lower
    /// triangle. Rows with zero weight and zero score entries are skipped.
    ///
    /// # Errors
    /// Returns a message when `weights.len() != nrows`.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        if weights.len() != self.nrows {
            return Err(format!(
                "lazy Pca diag_xtw_x weight length mismatch: weights={}, nrows={}",
                weights.len(),
                self.nrows
            ));
        }
        let k = self.ncols;
        let block = self.chunk_rows();
        let mut gram = Array2::<f64>::zeros((k, k));
        for first in (0..self.nrows).step_by(block) {
            let last = (first + block).min(self.nrows);
            for r in first..last {
                let w = weights[r];
                if w != 0.0 {
                    for a in 0..k {
                        let xa = self.value(r, a);
                        if xa != 0.0 {
                            let wxa = w * xa;
                            for b in a..k {
                                gram[[a, b]] += wxa * self.value(r, b);
                            }
                        }
                    }
                }
            }
        }
        for b in 1..k {
            for a in 0..b {
                gram[[b, a]] = gram[[a, b]];
            }
        }
        Ok(gram)
    }

    /// Returns `(Xᵀ W X + P + ridge·I) v`.
    ///
    /// - `W = diag(max(weights, 0))`: negative and NaN weights are treated as zero.
    /// - `P` is the optional `penalty`.
    /// - The ridge term is added only when `ridge > 0`.
    ///
    /// # Panics
    /// Panics if `weights.len() != nrows` or `vector.len() != ncols`.
    fn apply_weighted_normal(
        &self,
        weights: &Array1<f64>,
        vector: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge: f64,
    ) -> Array1<f64> {
        assert_eq!(
            weights.len(),
            self.nrows,
            "lazy Pca weighted-normal weight mismatch"
        );
        assert_eq!(
            vector.len(),
            self.ncols,
            "lazy Pca weighted-normal vector mismatch"
        );
        let block = self.chunk_rows();
        let mut result = Array1::<f64>::zeros(self.ncols);
        for first in (0..self.nrows).step_by(block) {
            let last = (first + block).min(self.nrows);
            for r in first..last {
                let w = weights[r].max(0.0);
                if w == 0.0 {
                    continue;
                }
                let xv = (0..self.ncols).fold(0.0, |dot, c| dot + self.value(r, c) * vector[c]);
                if xv == 0.0 {
                    continue;
                }
                let scaled = w * xv;
                for (c, slot) in result.iter_mut().enumerate() {
                    *slot += scaled * self.value(r, c);
                }
            }
        }
        if let Some(pen) = penalty {
            result += &pen.dot(vector);
        }
        if ridge > 0.0 {
            result += &vector.mapv(|x| ridge * x);
        }
        result
    }
}
impl DenseDesignOperator for PcaScoresMemmapDesignOperator {
    /// Returns `Xᵀ W y` with `W = diag(weights)`, streaming score rows in blocks.
    ///
    /// Rows where `weights[r] * y[r] == 0.0` are skipped without being read.
    ///
    /// # Errors
    /// Returns a message when `weights` or `y` does not have `nrows` entries.
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        let lengths_agree = weights.len() == self.nrows && y.len() == self.nrows;
        if !lengths_agree {
            return Err(format!(
                "lazy Pca compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
                weights.len(),
                y.len(),
                self.nrows
            ));
        }
        let block = self.chunk_rows();
        let mut result = Array1::<f64>::zeros(self.ncols);
        for first in (0..self.nrows).step_by(block) {
            let last = (first + block).min(self.nrows);
            for r in first..last {
                let coeff = weights[r] * y[r];
                if coeff != 0.0 {
                    for (c, slot) in result.iter_mut().enumerate() {
                        *slot += coeff * self.value(r, c);
                    }
                }
            }
        }
        Ok(result)
    }

    /// Copies rows `rows` of the score matrix into `out`.
    ///
    /// # Errors
    /// Returns `MissingRowChunk` in two cases:
    /// - the range is reversed or extends past `nrows`;
    /// - `out` does not have the shape `(rows.len(), ncols)`.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        let Range { start, end } = rows;
        if start > end || end > self.nrows {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "lazy Pca row range out of bounds",
            });
        }
        if out.ncols() != self.ncols || out.nrows() != end - start {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "lazy Pca row_chunk_into shape mismatch",
            });
        }
        for ((local, col), slot) in out.indexed_iter_mut() {
            *slot = self.value(start + local, col);
        }
        Ok(())
    }

    /// Materializes the whole score matrix in memory.
    ///
    /// # Panics
    /// Panics if materialization fails. The full, correctly shaped range passed here
    /// never triggers either error branch of `row_chunk_into`.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((self.nrows, self.ncols));
        self.row_chunk_into(0..self.nrows, dense.view_mut())
            .expect("lazy Pca full materialization failed");
        dense
    }
}
fn parse_f64_2d_npy_header(
bytes: &[u8],
path: &PathBuf,
) -> Result<(usize, usize, usize), BasisError> {
if bytes.len() < 10 || &bytes[0..6] != b"\x93NUMPY" {
crate::bail_invalid_basis!("lazy Pca scores '{}' is not a .npy file", path.display());
}
let major = bytes[6];
let header_len = match major {
1 => u16::from_le_bytes([bytes[8], bytes[9]]) as usize,
2 | 3 => {
if bytes.len() < 12 {
crate::bail_invalid_basis!(
"lazy Pca scores '{}' has a truncated .npy header",
path.display()
);
}
u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize
}
other => {
crate::bail_invalid_basis!(
"lazy Pca scores '{}' uses unsupported .npy version {}",
path.display(),
other
);
}
};
let header_start = if major == 1 { 10 } else { 12 };
let data_offset = header_start + header_len;
if bytes.len() < data_offset {
crate::bail_invalid_basis!(
"lazy Pca scores '{}' has a truncated .npy header",
path.display()
);
}
let header = std::str::from_utf8(&bytes[header_start..data_offset]).map_err(|err| {
BasisError::InvalidInput(format!(
"lazy Pca scores '{}' has a non-UTF8 .npy header: {err}",
path.display()
))
})?;
if !(header.contains("'descr': '<f8'")
|| header.contains("\"descr\": \"<f8\"")
|| header.contains("'descr': '|f8'")
|| header.contains("\"descr\": \"|f8\""))
{
crate::bail_invalid_basis!(
"lazy Pca scores '{}' must be float64 little-endian .npy",
path.display()
);
}
if header.contains("True") {
crate::bail_invalid_basis!(
"lazy Pca scores '{}' must be C-contiguous, not Fortran-ordered",
path.display()
);
}
let shape_pos = header.find("shape").ok_or_else(|| {
BasisError::InvalidInput(format!(
"lazy Pca scores '{}' .npy header is missing shape",
path.display()
))
})?;
let open = header[shape_pos..].find('(').ok_or_else(|| {
BasisError::InvalidInput(format!(
"lazy Pca scores '{}' .npy header has malformed shape",
path.display()
))
})? + shape_pos;
let close = header[open..].find(')').ok_or_else(|| {
BasisError::InvalidInput(format!(
"lazy Pca scores '{}' .npy header has malformed shape",
path.display()
))
})? + open;
let dims = header[open + 1..close]
.split(',')
.map(str::trim)
.filter(|part| !part.is_empty())
.map(|part| part.parse::<usize>())
.collect::<Result<Vec<_>, _>>()
.map_err(|err| {
BasisError::InvalidInput(format!(
"lazy Pca scores '{}' .npy shape is not integral: {err}",
path.display()
))
})?;
if dims.len() != 2 {
crate::bail_invalid_basis!(
"lazy Pca scores '{}' must have shape (N, K), got {:?}",
path.display(),
dims
);
}
Ok((data_offset, dims[0], dims[1]))
}
/// Returns the per-column mean of `x`: each row is summed in order, then the total is
/// divided by the row count.
///
/// # Errors
/// Returns `BasisError::InvalidInput` when `x` has no rows.
fn pca_center_mean(x: ArrayView2<'_, f64>) -> Result<Array1<f64>, BasisError> {
    let n_rows = x.nrows();
    if n_rows == 0 {
        crate::bail_invalid_basis!("Pca basis requires at least one row to compute center mean");
    }
    let column_sums = x
        .rows()
        .into_iter()
        .fold(Array1::<f64>::zeros(x.ncols()), |acc, row| acc + &row);
    Ok(column_sums / n_rows as f64)
}
fn build_pca_smooth_basis(
data: ArrayView2<'_, f64>,
feature_cols: &[usize],
basis_matrix: &Array2<f64>,
centered: bool,
smooth_penalty: f64,
center_mean: Option<&Array1<f64>>,
pca_basis_path: Option<&PathBuf>,
chunk_size: usize,
) -> Result<BasisBuildResult, BasisError> {
if let Some(path) = pca_basis_path {
let op = PcaScoresMemmapDesignOperator::open(path.clone(), chunk_size)?;
if op.nrows != data.nrows() {
crate::bail_dim_basis!(
"lazy Pca scores row mismatch: .npy has {}, data has {}",
op.nrows,
data.nrows()
);
}
let k = op.ncols;
let mut penalty = Array2::<f64>::eye(k);
penalty.mapv_inplace(|v| v * smooth_penalty);
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
matrix: penalty,
nullspace_dim_hint: 0,
source: PenaltySource::Other("PcaRidge".to_string()),
normalization_scale: 1.0,
kronecker_factors: None,
op: None,
}])?;
return Ok(BasisBuildResult {
design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(op))),
penalties,
nullspace_dims,
penaltyinfo,
ops,
null_eigenvectors,
joint_null_rotation: None,
metadata: BasisMetadata::Pca {
feature_cols: feature_cols.to_vec(),
basis_matrix: basis_matrix.clone(),
centered,
smooth_penalty,
center_mean: center_mean.cloned(),
pca_basis_path: Some(path.clone()),
chunk_size: chunk_size.max(1),
},
kronecker_factored: None,
});
}
if basis_matrix.nrows() != feature_cols.len() {
crate::bail_dim_basis!(
"Pca basis row mismatch: basis rows={}, feature columns={}",
basis_matrix.nrows(),
feature_cols.len()
);
}
let mut x = select_columns(data, feature_cols)?;
let mean = if centered {
match center_mean {
Some(mean) => mean.clone(),
None => pca_center_mean(x.view())?,
}
} else {
Array1::<f64>::zeros(feature_cols.len())
};
if centered {
for mut row in x.rows_mut() {
row -= &mean;
}
}
let design = fast_ab(&x, basis_matrix);
let k = basis_matrix.ncols();
let mut penalty = Array2::<f64>::eye(k);
penalty.mapv_inplace(|v| v * smooth_penalty);
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(vec![PenaltyCandidate {
matrix: penalty,
nullspace_dim_hint: 0,
source: PenaltySource::Other("PcaRidge".to_string()),
normalization_scale: 1.0,
kronecker_factors: None,
op: None,
}])?;
Ok(BasisBuildResult {
design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(design)),
penalties,
nullspace_dims,
penaltyinfo,
ops,
null_eigenvectors,
joint_null_rotation: None,
metadata: BasisMetadata::Pca {
feature_cols: feature_cols.to_vec(),
basis_matrix: basis_matrix.clone(),
centered,
smooth_penalty,
center_mean: centered.then_some(mean),
pca_basis_path: None,
chunk_size: chunk_size.max(1),
},
kronecker_factored: None,
})
}
/// Gates a built smooth term's design rows by a `by` variable.
///
/// Each design row is multiplied by a per-row weight:
/// - for `ByVariableSpec::Numeric`, the raw value of column `by_col`;
/// - for `ByVariableSpec::Level`, `1.0` when that value's bit pattern equals
///   `value_bits`, and `0.0` otherwise.
///
/// The design is densified (chunk by chunk) to apply the gating. Any Kronecker-factored
/// representation is dropped, because it no longer describes the gated design.
///
/// # Errors
/// Returns `BasisError` in these cases:
/// - `by_col` is out of bounds;
/// - any weight is non-finite;
/// - densifying the design fails.
fn apply_by_variable_to_local_build(
    mut built: LocalSmoothTermBuild,
    data: ArrayView2<'_, f64>,
    by_col: usize,
    by: &ByVariableSpec,
    term_name: &str,
) -> Result<LocalSmoothTermBuild, BasisError> {
    if by_col >= data.ncols() {
        crate::bail_dim_basis!(
            "by-variable smooth term '{term_name}' references column {by_col}, but data has {} columns",
            data.ncols()
        );
    }
    let by_values = data.column(by_col);
    let gate = match by {
        ByVariableSpec::Numeric => by_values.to_owned(),
        ByVariableSpec::Level { value_bits, .. } => {
            let target = *value_bits;
            by_values.mapv(|value| if value.to_bits() == target { 1.0 } else { 0.0 })
        }
    };
    if !gate.iter().all(|value| value.is_finite()) {
        crate::bail_invalid_basis!(
            "by-variable smooth term '{term_name}' has non-finite by-column values"
        );
    }
    let mut gated = built
        .design
        .try_to_dense_by_chunks("by-variable smooth row gating")
        .map_err(BasisError::InvalidInput)?;
    gated
        .rows_mut()
        .into_iter()
        .zip(gate.iter())
        .for_each(|(mut row, &factor)| row.mapv_inplace(|value| value * factor));
    built.design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(gated));
    built.kronecker_factored = None;
    Ok(built)
}
/// Checks that a smooth's `by` kind agrees with the `by` variable specification.
///
/// The two agree when both are numeric, or when both are level-based with identical
/// level bit patterns.
///
/// # Errors
/// Returns `BasisError::InvalidInput` for any other combination.
fn ensure_by_variable_specs_match(
    kind: &BySmoothKind,
    by: &ByVariableSpec,
    term_name: &str,
) -> Result<(), BasisError> {
    let consistent = match (kind, by) {
        (BySmoothKind::Numeric, ByVariableSpec::Numeric) => true,
        (BySmoothKind::Level { level_bits }, ByVariableSpec::Level { value_bits, .. }) => {
            level_bits == value_bits
        }
        _ => false,
    };
    if consistent {
        return Ok(());
    }
    Err(BasisError::InvalidInput(format!(
        "by-variable smooth term '{term_name}' has inconsistent by-variable specifications"
    )))
}
fn build_factor_smooth(
data: ArrayView2<'_, f64>,
spec: &FactorSmoothSpec,
term_name: &str,
workspace: &mut crate::basis::BasisWorkspace,
) -> Result<LocalSmoothTermBuild, BasisError> {
if spec.continuous_cols.len() != 1 {
crate::bail_invalid_basis!(
"factor smooth term '{}' currently supports exactly one continuous covariate; found {}",
term_name,
spec.continuous_cols.len()
);
}
let feature_col = spec.continuous_cols[0];
let group_col = spec.group_col;
if feature_col >= data.ncols() || group_col >= data.ncols() {
crate::bail_dim_basis!(
"factor smooth term '{}' references columns ({}, {}) out of bounds for {} columns",
term_name,
feature_col,
group_col,
data.ncols()
);
}
if matches!(spec.flavour, FactorSmoothFlavour::Sz) {
let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
let inner = SmoothBasisSpec::BSpline1D {
feature_col,
spec: factor_smooth_marginal_for_replay(&spec.marginal),
};
let sz_term = SmoothTermSpec {
name: term_name.to_string(),
basis: SmoothBasisSpec::FactorSumToZero {
inner: Box::new(inner),
by_col: group_col,
levels,
frozen_global_orthogonality: None,
},
shape: ShapeConstraint::None,
joint_null_rotation: None,
};
return build_single_local_smooth_term(data, &sz_term, workspace);
}
let levels = resolve_factor_smooth_levels(data, group_col, spec, term_name)?;
let n_levels = levels.len();
if n_levels < 2 {
crate::bail_invalid_basis!(
"factor smooth term '{}' requires at least two grouping levels; found {}",
term_name,
n_levels
);
}
let use_per_dim_null = matches!(
&spec.flavour,
FactorSmoothFlavour::Fs { m_null_penalty_orders }
if m_null_penalty_orders.iter().copied().max().unwrap_or(0) >= 1
);
let mut marginal_spec = factor_smooth_marginal_for_replay(&spec.marginal);
if use_per_dim_null {
marginal_spec.double_penalty = false;
}
let inner_term = SmoothTermSpec {
name: format!("{term_name}::marginal"),
basis: SmoothBasisSpec::BSpline1D {
feature_col,
spec: marginal_spec,
},
shape: ShapeConstraint::None,
joint_null_rotation: None,
};
let inner = build_single_local_smooth_term(data, &inner_term, workspace)?;
let base = inner
.design
.try_to_dense_by_chunks("factor smooth marginal")
.map_err(BasisError::InvalidInput)?;
let n = base.nrows();
let p = base.ncols();
let q = p * n_levels;
let mut dense = Array2::<f64>::zeros((n, q));
for i in 0..n {
let bits = data[[i, group_col]].to_bits();
let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
BasisError::InvalidInput(format!(
"factor smooth term '{term_name}' saw an unseen grouping level at row {}",
i + 1
))
})?;
let start = level_idx * p;
dense
.slice_mut(s![i, start..start + p])
.assign(&base.row(i));
}
let marginal_penalties: Vec<Array2<f64>> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
vec![Array2::<f64>::eye(p)]
} else {
inner.penalties.clone()
};
let marginal_penaltyinfo: Vec<PenaltyInfo> = if matches!(spec.flavour, FactorSmoothFlavour::Re)
{
vec![PenaltyInfo {
source: PenaltySource::Primary,
original_index: 0,
active: true,
effective_rank: p,
dropped_reason: None,
nullspace_dim_hint: 0,
normalization_scale: 1.0,
kronecker_factors: None,
}]
} else {
inner.penaltyinfo.clone()
};
if marginal_penalties.len() != marginal_penaltyinfo.len() {
crate::bail_invalid_basis!(
"internal factor-smooth penalty metadata mismatch for term '{}': penalties={}, infos={}",
term_name,
marginal_penalties.len(),
marginal_penaltyinfo.len()
);
}
let mut penalties = Vec::<Array2<f64>>::with_capacity(marginal_penalties.len());
let mut penaltyinfo = Vec::<PenaltyInfo>::with_capacity(marginal_penalties.len());
for (penalty_pos, s_inner) in marginal_penalties.iter().enumerate() {
let mut s_big = Array2::<f64>::zeros((q, q));
for level in 0..n_levels {
let start = level * p;
s_big
.slice_mut(s![start..start + p, start..start + p])
.assign(s_inner);
}
let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
let mut info = marginal_penaltyinfo[penalty_pos].clone();
info.original_index = penalty_pos;
info.normalization_scale *= factor_smooth_scale;
info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(n_levels);
info.kronecker_factors = None;
penalties.push(s_big);
penaltyinfo.push(info);
}
let mut nullspaces: Vec<usize> = if matches!(spec.flavour, FactorSmoothFlavour::Re) {
vec![0]
} else {
inner
.nullspaces
.iter()
.map(|ns| ns.saturating_mul(n_levels))
.collect()
};
if use_per_dim_null
&& let Some(Some(z)) = inner.null_eigenvectors.first()
&& z.nrows() == p
{
for k in 0..z.ncols() {
let zk = z.column(k);
let mut p_k = Array2::<f64>::zeros((p, p));
for a in 0..p {
for b in 0..p {
p_k[[a, b]] = zk[a] * zk[b];
}
}
let mut s_null = Array2::<f64>::zeros((q, q));
for level in 0..n_levels {
let start = level * p;
s_null
.slice_mut(s![start..start + p, start..start + p])
.assign(&p_k);
}
let (s_null, null_scale) = normalize_penalty_in_constrained_space(&s_null);
let null_block = crate::terms::basis::analyze_penalty_block_with_op(&s_null, None)?;
if null_block.rank > 0 {
let original_index = penalties.len();
penalties.push(null_block.sym_penalty);
nullspaces.push(null_block.nullity);
penaltyinfo.push(PenaltyInfo {
source: PenaltySource::Primary,
original_index,
active: true,
effective_rank: null_block.rank,
dropped_reason: None,
nullspace_dim_hint: null_block.nullity,
normalization_scale: null_scale,
kronecker_factors: None,
});
}
}
}
let null_eigenvectors = crate::terms::basis::recompute_null_eigenvectors(&penalties)?;
let joint_null_rotation = crate::terms::basis::compute_joint_null_rotation(&penalties)?;
let (knots, degree, periodic) = match &inner.metadata {
BasisMetadata::BSpline1D {
knots,
periodic,
degree,
..
} => (
knots.clone(),
degree.unwrap_or(spec.marginal.degree),
*periodic,
),
other => {
crate::bail_invalid_basis!(
"factor smooth term '{}' produced an unexpected marginal metadata variant {:?}",
term_name,
other
);
}
};
let flavour_tag = match &spec.flavour {
FactorSmoothFlavour::Fs { .. } => "fs",
FactorSmoothFlavour::Sz => "sz",
FactorSmoothFlavour::Re => "re",
}
.to_string();
let metadata = BasisMetadata::FactorSmooth {
continuous_cols: spec.continuous_cols.clone(),
group_col,
knots,
degree,
periodic,
group_levels: levels,
flavour: flavour_tag,
};
let ops = vec![None; penalties.len()];
Ok(LocalSmoothTermBuild {
dim: q,
design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(dense)),
penalties,
ops,
nullspaces,
null_eigenvectors,
joint_null_rotation,
penaltyinfo,
pre_dropped_penaltyinfo: Vec::new(),
metadata,
linear_constraints: None,
box_reparam: false,
kronecker_factored: None,
})
}
/// Resolve the ordered list of grouping levels (stored as `f64` bit patterns)
/// for a factor smooth term.
///
/// When `spec.group_frozen_levels` is set it is returned verbatim, after
/// checking that it is non-empty. Otherwise the values of
/// `data[:, group_col]` are sorted numerically and deduplicated by exact bit
/// pattern, which is the same equality the callers use when mapping rows to
/// levels.
///
/// # Errors
/// Returns `BasisError::InvalidInput` if the frozen level list is empty.
fn resolve_factor_smooth_levels(
    data: ArrayView2<'_, f64>,
    group_col: usize,
    spec: &FactorSmoothSpec,
    term_name: &str,
) -> Result<Vec<u64>, BasisError> {
    if let Some(frozen) = &spec.group_frozen_levels {
        if frozen.is_empty() {
            crate::bail_invalid_basis!(
                "factor smooth term '{}' has an empty frozen level list",
                term_name
            );
        }
        return Ok(frozen.clone());
    }
    let mut bits: Vec<u64> = data.column(group_col).iter().map(|v| v.to_bits()).collect();
    // `total_cmp` is a total order whose equality coincides with bit equality, so
    // every repeated bit pattern ends up adjacent and `dedup` removes it. A
    // `partial_cmp`-based comparator treats -0.0 == 0.0 (leaving interleaved runs
    // such as [0.0, -0.0, 0.0] with duplicate levels) and is not a total order
    // in the presence of NaN.
    bits.sort_unstable_by(|a, b| f64::from_bits(*a).total_cmp(&f64::from_bits(*b)));
    bits.dedup();
    Ok(bits)
}
/// Return a copy of the factor-smooth marginal B-spline spec with its
/// identifiability constraint switched off (`BSplineIdentifiability::None`);
/// every other setting of `marginal` is carried over unchanged.
fn factor_smooth_marginal_for_replay(marginal: &BSplineBasisSpec) -> BSplineBasisSpec {
    let mut unconstrained_spec = marginal.clone();
    unconstrained_spec.identifiability = BSplineIdentifiability::None;
    unconstrained_spec
}
/// Build one smooth term in its local (term-only) coefficient space.
///
/// Dispatches on `term.basis` to the matching basis builder.
///
/// `ByVariable`, `FactorSumToZero` and `FactorSmooth` bases return early with
/// their own post-processing. All other bases go through the shared tail:
/// - Matern operator-penalty replacement;
/// - shape constraints, either via a box reparameterization (coefficient
///   transform) or via linear inequality constraints;
/// - per-penalty normalization;
/// - filtering of inactive penalty candidates;
/// - the joint null-space rotation, unless persisted, frozen or
///   Kronecker-factored.
///
/// # Errors
/// Returns `BasisError` in these cases:
/// - a column index is out of bounds;
/// - a shape constraint is unsupported for the basis;
/// - an underlying basis builder fails;
/// - internal penalty/metadata bookkeeping is inconsistent.
fn build_single_local_smooth_term(
    data: ArrayView2<'_, f64>,
    term: &SmoothTermSpec,
    workspace: &mut crate::basis::BasisWorkspace,
) -> Result<LocalSmoothTermBuild, BasisError> {
    if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
        crate::bail_invalid_basis!(
            "ShapeConstraint::{:?} is unsupported for term '{}'",
            term.shape,
            term.name
        );
    }
    // By-variable smooths wrap an inner basis: build the inner term (keeping the
    // shape constraint) and apply the by-variable modulation to that local build.
    if let SmoothBasisSpec::ByVariable {
        inner,
        by_col,
        kind,
        by,
    } = &term.basis
    {
        ensure_by_variable_specs_match(kind, by, &term.name)?;
        let inner_term = SmoothTermSpec {
            name: term.name.clone(),
            basis: (**inner).clone(),
            shape: term.shape,
            joint_null_rotation: None,
        };
        let built = build_single_local_smooth_term(data, &inner_term, workspace)?;
        return apply_by_variable_to_local_build(built, data, *by_col, by, &term.name);
    }
    // Data column along which a non-box-reparameterized shape constraint is
    // enforced; only the single-axis radial bases below set it.
    let mut shape_axis_col: Option<usize> = None;
    let mut built: BasisBuildResult = match &term.basis {
        SmoothBasisSpec::FactorSumToZero {
            inner,
            by_col,
            levels,
            ..
        } => {
            if *by_col >= data.ncols() {
                crate::bail_dim_basis!(
                    "term '{}' by column {} out of bounds for {} columns",
                    term.name,
                    by_col,
                    data.ncols()
                );
            }
            if levels.len() < 2 {
                crate::bail_invalid_basis!(
                    "sum-to-zero factor smooth term '{}' requires at least two levels",
                    term.name
                );
            }
            if term.shape != ShapeConstraint::None {
                crate::bail_invalid_basis!(
                    "ShapeConstraint::{:?} is unsupported for sum-to-zero factor smooth term '{}'",
                    term.shape,
                    term.name
                );
            }
            let inner_term = SmoothTermSpec {
                name: format!("{}::inner", term.name),
                basis: (**inner).clone(),
                shape: ShapeConstraint::None,
                joint_null_rotation: None,
            };
            let mut inner_built = build_single_local_smooth_term(data, &inner_term, workspace)?;
            let base = inner_built
                .design
                .try_to_dense_by_chunks("sum-to-zero factor smooth")
                .map_err(BasisError::InvalidInput)?;
            let n = base.nrows();
            let p = base.ncols();
            let l_minus_one = levels.len() - 1;
            // Sum-to-zero coding: levels 0..L-1 get their own block of columns;
            // rows of the last level carry the negated inner row in every block.
            let mut dense = Array2::<f64>::zeros((n, p * l_minus_one));
            for i in 0..n {
                let bits = data[[i, *by_col]].to_bits();
                let level_idx = levels.iter().position(|b| *b == bits).ok_or_else(|| {
                    BasisError::InvalidInput(format!(
                        "sum-to-zero factor smooth term '{}' saw an unseen level at row {}",
                        term.name,
                        i + 1
                    ))
                })?;
                if level_idx < l_minus_one {
                    let start = level_idx * p;
                    dense
                        .slice_mut(s![i, start..start + p])
                        .assign(&base.row(i));
                } else {
                    for level in 0..l_minus_one {
                        let start = level * p;
                        dense
                            .slice_mut(s![i, start..start + p])
                            .assign(&base.row(i).mapv(|v| -v));
                    }
                }
            }
            let mut penalties = Vec::<Array2<f64>>::with_capacity(inner_built.penalties.len());
            let active_penalty_indices = inner_built
                .penaltyinfo
                .iter()
                .enumerate()
                .filter_map(|(idx, info)| info.active.then_some(idx))
                .collect::<Vec<_>>();
            if active_penalty_indices.len() != inner_built.penalties.len() {
                crate::bail_invalid_basis!(
                    "internal sz penalty metadata mismatch: activeinfos={}, penalties={}",
                    active_penalty_indices.len(),
                    inner_built.penalties.len()
                );
            }
            for (penalty_pos, s_inner) in inner_built.penalties.iter().enumerate() {
                // With the last level's coefficients equal to minus the sum of the
                // others, sum_l b_l' S b_l expands to a block matrix with 2*S on
                // the diagonal blocks and S on the off-diagonal blocks.
                let mut s_big = Array2::<f64>::zeros((p * l_minus_one, p * l_minus_one));
                for a in 0..l_minus_one {
                    for b in 0..l_minus_one {
                        let factor = if a == b { 2.0 } else { 1.0 };
                        let mut block = s_big.slice_mut(s![a * p..(a + 1) * p, b * p..(b + 1) * p]);
                        block.assign(&s_inner.mapv(|v| v * factor));
                    }
                }
                let (s_big, factor_smooth_scale) = normalize_penalty_in_constrained_space(&s_big);
                let info = &mut inner_built.penaltyinfo[active_penalty_indices[penalty_pos]];
                info.normalization_scale *= factor_smooth_scale;
                // The enlarged penalty no longer matches the inner metadata: scale the
                // nullspace hint like `nullspaces` below, and drop Kronecker factors,
                // which describe only the p x p inner penalty (mirrors the `fs` path).
                info.nullspace_dim_hint = info.nullspace_dim_hint.saturating_mul(l_minus_one);
                info.kronecker_factors = None;
                penalties.push(s_big);
            }
            inner_built.dim = p * l_minus_one;
            inner_built.design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(dense));
            inner_built.penalties = penalties;
            inner_built.ops = vec![None; inner_built.penalties.len()];
            inner_built.nullspaces = inner_built
                .nullspaces
                .iter()
                .map(|ns| ns.saturating_mul(l_minus_one))
                .collect();
            inner_built.null_eigenvectors =
                crate::terms::basis::recompute_null_eigenvectors(&inner_built.penalties)?;
            inner_built.joint_null_rotation =
                crate::terms::basis::compute_joint_null_rotation(&inner_built.penalties)?;
            inner_built.kronecker_factored = None;
            return Ok(inner_built);
        }
        SmoothBasisSpec::BSpline1D { feature_col, spec } => {
            if *feature_col >= data.ncols() {
                crate::bail_dim_basis!(
                    "term '{}' feature column {} out of bounds for {} columns",
                    term.name,
                    feature_col,
                    data.ncols()
                );
            }
            let mut spec_local = spec.clone();
            if term.shape != ShapeConstraint::None {
                spec_local.identifiability = BSplineIdentifiability::None;
            }
            build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
        }
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            spec,
            input_scales,
        } => {
            if term.shape != ShapeConstraint::None {
                if feature_cols.len() != 1 {
                    crate::bail_invalid_basis!(
                        "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
                        term.shape,
                        term.name,
                        feature_cols.len()
                    );
                }
                shape_axis_col = Some(feature_cols[0]);
            }
            let mut x = select_columns(data, feature_cols)?;
            let (scales, length_scale_eff) = if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                (
                    Some(s.clone()),
                    compensate_length_scale_for_standardization(spec.length_scale, s),
                )
            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
                apply_input_standardization(&mut x, &s);
                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
                (Some(s), l_eff)
            } else {
                (None, spec.length_scale)
            };
            let mut spec_local = spec.clone();
            spec_local.length_scale = length_scale_eff;
            if matches!(
                spec_local.identifiability,
                SpatialIdentifiability::OrthogonalToParametric
            ) {
                spec_local.identifiability = SpatialIdentifiability::None;
            }
            let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
                rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
            })?;
            // Record the user-facing (unstandardized) length scale in the metadata.
            match &mut result.metadata {
                BasisMetadata::ThinPlate {
                    input_scales: ms,
                    length_scale,
                    ..
                } => {
                    *ms = scales;
                    *length_scale = spec.length_scale;
                }
                BasisMetadata::Duchon {
                    input_scales: ms,
                    length_scale,
                    ..
                } => {
                    *ms = scales;
                    *length_scale = Some(spec.length_scale);
                }
                _ => {}
            }
            result
        }
        SmoothBasisSpec::Sphere { feature_cols, spec } => {
            if term.shape != ShapeConstraint::None {
                crate::bail_invalid_basis!(
                    "ShapeConstraint::{:?} for term '{}' is not supported on spherical splines",
                    term.shape,
                    term.name
                );
            }
            let x = select_columns(data, feature_cols)?;
            build_spherical_spline_basis(x.view(), spec)?
        }
        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
            if term.shape != ShapeConstraint::None {
                crate::bail_invalid_basis!(
                    "ShapeConstraint::{:?} for term '{}' is not supported on constant-curvature smooths",
                    term.shape,
                    term.name
                );
            }
            let x = select_columns(data, feature_cols)?;
            build_constant_curvature_basis(x.view(), spec)?
        }
        SmoothBasisSpec::MeasureJet {
            feature_cols,
            spec,
            input_scales,
        } => {
            if term.shape != ShapeConstraint::None {
                crate::bail_invalid_basis!(
                    "ShapeConstraint::{:?} for term '{}' is not supported on measure-jet smooths",
                    term.shape,
                    term.name
                );
            }
            let mut x = select_columns(data, feature_cols)?;
            let (scales, length_scale_eff) = if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                (Some(s.clone()), spec.length_scale)
            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
                apply_input_standardization(&mut x, &s);
                let l_eff = if spec.length_scale > 0.0 {
                    compensate_length_scale_for_standardization(spec.length_scale, &s)
                } else {
                    spec.length_scale
                };
                (Some(s), l_eff)
            } else {
                (None, spec.length_scale)
            };
            let mut spec_local = spec.clone();
            spec_local.length_scale = length_scale_eff;
            let mut result = build_measure_jet_basis(x.view(), &spec_local)?;
            if let BasisMetadata::MeasureJet {
                input_scales: ms, ..
            } = &mut result.metadata
            {
                *ms = scales;
            }
            result
        }
        SmoothBasisSpec::Matern {
            feature_cols,
            spec,
            input_scales,
        } => {
            if term.shape != ShapeConstraint::None {
                if feature_cols.len() != 1 {
                    crate::bail_invalid_basis!(
                        "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
                        term.shape,
                        term.name,
                        feature_cols.len()
                    );
                }
                shape_axis_col = Some(feature_cols[0]);
            }
            let mut x = select_columns(data, feature_cols)?;
            let (scales, length_scale_eff) = if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                (
                    Some(s.clone()),
                    compensate_length_scale_for_standardization(spec.length_scale, s),
                )
            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
                apply_input_standardization(&mut x, &s);
                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
                (Some(s), l_eff)
            } else {
                (None, spec.length_scale)
            };
            let mut spec_local = spec.clone();
            spec_local.length_scale = length_scale_eff;
            let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
            if let BasisMetadata::Matern {
                input_scales,
                length_scale,
                ..
            } = &mut result.metadata
            {
                *input_scales = scales;
                *length_scale = spec.length_scale;
            }
            result
        }
        SmoothBasisSpec::Duchon {
            feature_cols,
            spec,
            input_scales,
        } => {
            if term.shape != ShapeConstraint::None {
                if feature_cols.len() != 1 {
                    crate::bail_invalid_basis!(
                        "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
                        term.shape,
                        term.name,
                        feature_cols.len()
                    );
                }
                shape_axis_col = Some(feature_cols[0]);
            }
            let mut x = select_columns(data, feature_cols)?;
            let (scales, length_scale_eff) = if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                (
                    Some(s.clone()),
                    compensate_optional_length_scale_for_standardization(spec.length_scale, s),
                )
            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
                apply_input_standardization(&mut x, &s);
                let l_eff =
                    compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
                (Some(s), l_eff)
            } else {
                (None, spec.length_scale)
            };
            let mut spec_local = spec.clone();
            spec_local.length_scale = length_scale_eff;
            if matches!(
                spec_local.identifiability,
                SpatialIdentifiability::OrthogonalToParametric
            ) {
                spec_local.identifiability = SpatialIdentifiability::None;
            }
            let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
            if let BasisMetadata::Duchon {
                input_scales,
                length_scale,
                ..
            } = &mut result.metadata
            {
                *input_scales = scales;
                *length_scale = spec.length_scale;
            }
            result
        }
        SmoothBasisSpec::Pca {
            feature_cols,
            basis_matrix,
            centered,
            smooth_penalty,
            center_mean,
            pca_basis_path,
            chunk_size,
        } => {
            if term.shape != ShapeConstraint::None {
                crate::bail_invalid_basis!(
                    "ShapeConstraint::{:?} for term '{}' is not supported on Pca basis",
                    term.shape,
                    term.name
                );
            }
            build_pca_smooth_basis(
                data,
                feature_cols,
                basis_matrix,
                *centered,
                *smooth_penalty,
                center_mean.as_ref(),
                pca_basis_path.as_ref(),
                *chunk_size,
            )?
        }
        SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
            build_tensor_bspline_basis(data, feature_cols, spec)?
        }
        SmoothBasisSpec::ByVariable { .. } => {
            crate::bail_invalid_basis!(
                "internal: ByVariable smooths must return before inner basis dispatch"
            );
        }
        SmoothBasisSpec::BySmooth { .. } => {
            crate::bail_invalid_basis!(
                "internal: BySmooth smooths must be lowered to ByVariable before inner basis dispatch"
            );
        }
        SmoothBasisSpec::FactorSmooth { spec } => {
            if term.shape != ShapeConstraint::None {
                crate::bail_invalid_basis!(
                    "ShapeConstraint::{:?} is unsupported for factor smooth term '{}'",
                    term.shape,
                    term.name
                );
            }
            return build_factor_smooth(data, spec, &term.name, workspace);
        }
    };
    if let SmoothBasisSpec::Matern { .. } = &term.basis {
        let (penalties, nullspace_dims, penaltyinfo) =
            matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
        built.penalties = penalties;
        built.nullspace_dims = nullspace_dims;
        built.penaltyinfo = penaltyinfo;
    }
    let p_local = built.design.ncols();
    let mut metadata = built.metadata.clone();
    // Shape constraints rewrite the design/penalties, so any Kronecker
    // factorization of the raw basis no longer applies.
    let kron_factored = if term.shape == ShapeConstraint::None {
        built.kronecker_factored
    } else {
        None
    };
    let mut design_t = built.design;
    let mut penalties_t: Vec<Array2<f64>> = built.penalties;
    let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::terms::analytic_penalties::PenaltyOp>>> =
        built.ops;
    if matches!(
        spatial_identifiability_policy(term),
        Some(SpatialIdentifiability::OrthogonalToParametric)
    ) {
        metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
    }
    let active_penaltyinfo_t = built
        .penaltyinfo
        .iter()
        .filter(|info| info.active)
        .cloned()
        .collect::<Vec<_>>();
    let pre_dropped_penaltyinfo_t = built
        .penaltyinfo
        .iter()
        .filter(|info| !info.active)
        .cloned()
        .collect::<Vec<_>>();
    let use_box_reparam =
        term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
    if let Some((order, sign)) = shape_order_and_sign(term.shape)
        && use_box_reparam
    {
        // Coefficient transform T such that beta = T * gamma, with the shape
        // constraint expressed as simple bounds on gamma.
        let t = if order == 2 {
            let bspline_meta = match &metadata {
                BasisMetadata::BSpline1D {
                    knots,
                    degree,
                    periodic,
                    ..
                } if periodic.is_none() => Some((knots.clone(), degree.unwrap_or(0))),
                _ => None,
            };
            match bspline_meta {
                Some((knots, degree)) if degree >= 1 => {
                    let greville = crate::basis::compute_greville_abscissae(&knots, degree)?;
                    if greville.len() != p_local {
                        crate::bail_invalid_basis!(
                            "shape-constraint Greville abscissae count {} does not match basis dim {} for term '{}'",
                            greville.len(),
                            p_local,
                            term.name
                        );
                    }
                    convex_divided_difference_transform_matrix(&greville, sign)?
                }
                _ => cumulative_sum_transform_matrix(p_local, order, sign),
            }
        } else {
            cumulative_sum_transform_matrix(p_local, order, sign)
        };
        let inner_dense = match design_t {
            DesignMatrix::Dense(d) => d,
            DesignMatrix::Sparse(sp) => crate::matrix::DenseDesignMatrix::from(
                sp.try_to_dense_arc("shape-constrained coefficient transform")
                    .map_err(BasisError::InvalidInput)?,
            ),
        };
        let coeff_op = crate::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
            .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
        design_t = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
        if penalties_t.len() != active_penaltyinfo_t.len() {
            crate::bail_invalid_basis!(
                "internal box-reparam penalty/info mismatch for term '{}': penalties={}, infos={}",
                term.name,
                penalties_t.len(),
                active_penaltyinfo_t.len()
            );
        }
        // Penalties transform as T' S T; the double-penalty ridge is instead
        // rebuilt from the nullspace of the transformed primary wiggliness penalty.
        let transformed_wiggliness = penalties_t
            .iter()
            .zip(active_penaltyinfo_t.iter())
            .find(|(_, info)| !matches!(info.source, PenaltySource::DoublePenaltyNullspace))
            .map(|(s_local, _)| {
                let tt_s = fast_atb(&t, s_local);
                fast_ab(&tt_s, &t)
            });
        let mut rebuilt = Vec::with_capacity(penalties_t.len());
        for (s_local, info) in penalties_t.iter().zip(active_penaltyinfo_t.iter()) {
            if matches!(info.source, PenaltySource::DoublePenaltyNullspace) {
                let s_wiggle_t = transformed_wiggliness.as_ref().ok_or_else(|| {
                    BasisError::InvalidInput(format!(
                        "box-reparam term '{}' has a double-penalty ridge but no primary wiggliness penalty to derive its nullspace from",
                        term.name
                    ))
                })?;
                let ridge = crate::terms::basis::build_nullspace_shrinkage_penalty(s_wiggle_t)?
                    .map(|shrink| shrink.sym_penalty)
                    .unwrap_or_else(|| Array2::<f64>::zeros((p_local, p_local)));
                rebuilt.push(ridge);
            } else {
                let tt_s = fast_atb(&t, s_local);
                rebuilt.push(fast_ab(&tt_s, &t));
            }
        }
        penalties_t = rebuilt;
        ops_t = vec![None; penalties_t.len()];
    }
    if penalties_t.len() != active_penaltyinfo_t.len() {
        crate::bail_invalid_basis!(
            "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
            term.name,
            penalties_t.len(),
            active_penaltyinfo_t.len()
        );
    }
    if ops_t.len() != penalties_t.len() {
        ops_t = vec![None; penalties_t.len()];
    }
    let penalty_candidates = penalties_t
        .into_iter()
        .zip(active_penaltyinfo_t.into_iter())
        .zip(ops_t.into_iter())
        .map(
            |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
                // Tensor marginal penalties keep their relative scale: undo the
                // normalization on the matrix instead of folding it into metadata.
                let preserve_margin_scale =
                    matches!(&info.source, PenaltySource::TensorMarginal { .. });
                let (matrix, normalization_scale, op_scale, kronecker_scale) =
                    if preserve_margin_scale {
                        (
                            matrix.mapv(|v| v * c_new),
                            info.normalization_scale,
                            1.0,
                            1.0,
                        )
                    } else {
                        (
                            matrix,
                            info.normalization_scale * c_new,
                            1.0 / c_new,
                            1.0 / c_new,
                        )
                    };
                let scaled_op = if op_scale > 0.0 && op_scale.is_finite() {
                    op_in.map(|op| {
                        std::sync::Arc::new(crate::terms::analytic_penalties::ScaledPenaltyOp::new(
                            op, op_scale,
                        ))
                            as std::sync::Arc<dyn crate::terms::analytic_penalties::PenaltyOp>
                    })
                } else {
                    None
                };
                let kronecker_factors = info.kronecker_factors.map(|mut factors| {
                    if let Some(first) = factors.first_mut() {
                        first.mapv_inplace(|v| v * kronecker_scale);
                    }
                    factors
                });
                Ok(PenaltyCandidate {
                    nullspace_dim_hint: info.nullspace_dim_hint,
                    matrix,
                    source: info.source,
                    normalization_scale,
                    kronecker_factors,
                    op: scaled_op,
                })
            },
        )
        .collect::<Result<Vec<_>, _>>()?;
    let (penalties_t, nullspaces_t, penaltyinfo_t, null_eigenvectors_t, ops_t) =
        crate::terms::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
    // Shape constraints that are not handled by box reparameterization become
    // linear inequality constraints evaluated along the single feature axis.
    let shape_linear_constraints = if term.shape != ShapeConstraint::None && !use_box_reparam {
        let axis = shape_axis_col.ok_or_else(|| {
            BasisError::InvalidInput(format!(
                "internal shape-constraint axis missing for term '{}'",
                term.name
            ))
        })?;
        let (x_shape_eval, design_shape_eval) =
            build_shape_constraint_design_1d(data, term, &metadata, axis)?;
        build_shape_linear_constraints_1d(
            x_shape_eval.view(),
            design_shape_eval.view(),
            term.shape,
        )?
    } else {
        None
    };
    let linear_constraints_local = merge_linear_constraints_global(shape_linear_constraints, None);
    let joint_null_rotation = match term.joint_null_rotation.clone() {
        Some(persisted) => Some(persisted),
        None if smooth_has_frozen_identifiability(term) => None,
        None if kron_factored.is_some() => None,
        None => crate::terms::basis::compute_joint_null_rotation(&penalties_t)?,
    };
    Ok(LocalSmoothTermBuild {
        dim: p_local,
        design: design_t,
        penalties: penalties_t,
        ops: ops_t,
        nullspaces: nullspaces_t,
        null_eigenvectors: null_eigenvectors_t,
        joint_null_rotation,
        penaltyinfo: penaltyinfo_t,
        pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
        metadata,
        linear_constraints: linear_constraints_local,
        box_reparam: use_box_reparam,
        kronecker_factored: kron_factored,
    })
}
/// Build the raw smooth design for `terms` using a freshly created
/// [`crate::basis::BasisWorkspace`].
///
/// Convenience entry point that forwards to
/// [`build_smooth_design_withworkspace`] (including its input validation).
pub fn build_smooth_design(
    data: ArrayView2<'_, f64>,
    terms: &[SmoothTermSpec],
) -> Result<RawSmoothDesign, BasisError> {
    build_smooth_design_withworkspace(data, terms, &mut crate::basis::BasisWorkspace::new())
}
/// Build the raw smooth design for `terms`, reusing the caller's `workspace`.
///
/// `validate_smooth_terms_finite_inputs` runs first; the unvalidated builder
/// is only invoked once that check succeeds.
pub fn build_smooth_design_withworkspace(
    data: ArrayView2<'_, f64>,
    terms: &[SmoothTermSpec],
    workspace: &mut crate::basis::BasisWorkspace,
) -> Result<RawSmoothDesign, BasisError> {
    validate_smooth_terms_finite_inputs(data, terms)?;
    let design = build_smooth_design_withworkspace_unvalidated(data, terms, workspace)?;
    Ok(design)
}
/// Assemble the global raw smooth design from per-term local builds without
/// re-running input validation.
///
/// Processing steps:
/// 1. Route the terms through the joint spatial center planner.
/// 2. Build each planned term in parallel, each with its own workspace that
///    shares `workspace`'s policy.
/// 3. Lay the local builds out side by side in coefficient space.
///
/// For every term, this function also:
/// - applies its joint null-space rotation when the term has neither lower
///   bounds nor linear constraints;
/// - lifts its penalties to block-wise global penalties;
/// - records active and dropped penalty metadata;
/// - scatters local lower bounds and linear inequality rows into global
///   coordinates.
///
/// # Errors
/// Returns `BasisError` if planning or any term build fails. It also errors
/// when a term's local bookkeeping is inconsistent with its penalties or
/// dimension:
/// - planned term count;
/// - active penalty infos;
/// - nullspace dims;
/// - linear-constraint shape.
fn build_smooth_design_withworkspace_unvalidated(
    data: ArrayView2<'_, f64>,
    terms: &[SmoothTermSpec],
    workspace: &mut crate::basis::BasisWorkspace,
) -> Result<RawSmoothDesign, BasisError> {
    let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
    let planned_terms = planned_blocks.pop().ok_or_else(|| {
        BasisError::InvalidInput(
            "joint spatial center planner returned no smooth blocks".to_string(),
        )
    })?;
    // The builds are zipped with `terms` below; a count mismatch would silently
    // drop terms and leave `total_p` inconsistent with the emitted designs.
    if planned_terms.len() != terms.len() {
        crate::bail_invalid_basis!(
            "joint spatial center planner returned {} terms for {} requested smooth terms",
            planned_terms.len(),
            terms.len()
        );
    }
    let policy = workspace.policy().clone();
    let local_builds: Vec<LocalSmoothTermBuild> = {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        planned_terms
            .into_par_iter()
            .map(|term| {
                let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
                build_single_local_smooth_term(data, &term, &mut term_workspace)
            })
            .collect::<Result<Vec<_>, _>>()?
    };
    let total_p: usize = local_builds.iter().map(|built| built.dim).sum();
    let mut local_designs: Vec<DesignMatrix> = Vec::with_capacity(local_builds.len());
    let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
    let mut penalties_global = Vec::<BlockwisePenalty>::new();
    let mut nullspace_dims_global = Vec::<usize>::new();
    let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
    let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
    let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
    let mut any_bounds = false;
    let mut linear_constraintsrows: Vec<(usize, usize, Array1<f64>)> = Vec::new();
    let mut linear_constraints_b: Vec<f64> = Vec::new();
    let mut col_start = 0usize;
    for (term, mut built) in terms.iter().zip(local_builds.into_iter()) {
        let p_local = built.dim;
        let col_end = col_start + p_local;
        let lb_local = if built.box_reparam {
            shape_lower_bounds_local(term.shape, p_local)
        } else {
            None
        };
        // The rotation is only applied to unconstrained terms: bounds and linear
        // constraints are expressed in the unrotated coefficient basis.
        let applied_rotation: Option<crate::terms::basis::JointNullRotation> = match (
            built.joint_null_rotation.take(),
            lb_local.is_some(),
            built.linear_constraints.is_some(),
        ) {
            (Some(rot), false, false) => {
                let q = &rot.rotation;
                let dense = built
                    .design
                    .try_to_dense_by_chunks("joint-null absorption rotation")
                    .map_err(BasisError::InvalidInput)?;
                let rotated = crate::linalg::faer_ndarray::fast_ab(&dense, q);
                built.design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(rotated));
                built.penalties = built
                    .penalties
                    .into_iter()
                    .map(|s_local| {
                        let qt_s = crate::linalg::faer_ndarray::fast_atb(q, &s_local);
                        crate::linalg::faer_ndarray::fast_ab(&qt_s, q)
                    })
                    .collect();
                built.ops = vec![None; built.penalties.len()];
                built.kronecker_factored = None;
                Some(rot)
            }
            (Some(_), _, _) => None,
            (None, _, _) => None,
        };
        // Penalties are zipped with nullspaces and ops below; a shorter list would
        // silently truncate the zip and drop penalties without any error.
        if built.nullspaces.len() != built.penalties.len() {
            crate::bail_invalid_basis!(
                "internal nullspace bookkeeping mismatch for term '{}': nullspaces={}, penalties={}",
                term.name,
                built.nullspaces.len(),
                built.penalties.len()
            );
        }
        if built.ops.len() != built.penalties.len() {
            built.ops = vec![None; built.penalties.len()];
        }
        let activeinfos = built
            .penaltyinfo
            .iter()
            .filter(|info| info.active)
            .collect::<Vec<_>>();
        if activeinfos.len() != built.penalties.len() {
            crate::bail_invalid_basis!(
                "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
                term.name,
                activeinfos.len(),
                built.penalties.len()
            );
        }
        for (((s_local, &ns), info), op_local) in built
            .penalties
            .iter()
            .zip(built.nullspaces.iter())
            .zip(activeinfos.into_iter())
            .zip(built.ops.iter())
        {
            let global_index = penalties_global.len();
            penalties_global.push(
                BlockwisePenalty::new(col_start..col_end, s_local.clone())
                    .with_op(op_local.clone()),
            );
            nullspace_dims_global.push(ns);
            let mut penalty = info.clone();
            penalty.nullspace_dim_hint = ns;
            penaltyinfo_global.push(PenaltyBlockInfo {
                global_index,
                termname: Some(term.name.clone()),
                penalty,
            });
        }
        for info in built.penaltyinfo.iter().filter(|info| !info.active) {
            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
                termname: Some(term.name.clone()),
                penalty: info.clone(),
            });
        }
        for info in &built.pre_dropped_penaltyinfo {
            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
                termname: Some(term.name.clone()),
                penalty: info.clone(),
            });
        }
        if let Some(lin_local) = &built.linear_constraints {
            // Each local row is scattered into columns col_start..col_end of the
            // global matrix; a width or rhs-length mismatch would otherwise panic.
            if lin_local.a.ncols() != p_local || lin_local.b.len() != lin_local.a.nrows() {
                crate::bail_invalid_basis!(
                    "internal linear constraint shape mismatch for term '{}': a={}x{}, b={}, basis dim={}",
                    term.name,
                    lin_local.a.nrows(),
                    lin_local.a.ncols(),
                    lin_local.b.len(),
                    p_local
                );
            }
            for r in 0..lin_local.a.nrows() {
                linear_constraintsrows.push((col_start, col_end, lin_local.a.row(r).to_owned()));
                linear_constraints_b.push(lin_local.b[r]);
            }
        }
        if let Some(lb_local) = &lb_local {
            coefficient_lower_bounds
                .slice_mut(s![col_start..col_end])
                .assign(lb_local);
            any_bounds = true;
        }
        local_designs.push(built.design);
        terms_out.push(SmoothTerm {
            name: term.name.clone(),
            coeff_range: col_start..col_end,
            shape: term.shape,
            penalties_local: built.penalties,
            nullspace_dims: built.nullspaces,
            penaltyinfo_local: built.penaltyinfo,
            metadata: built.metadata,
            lower_bounds_local: lb_local,
            linear_constraints_local: built.linear_constraints,
            kronecker_factored: built.kronecker_factored.take(),
            joint_null_rotation: applied_rotation,
            unabsorbed_global_orthogonality: None,
        });
        col_start = col_end;
    }
    assert_eq!(
        penalties_global.len(),
        nullspace_dims_global.len(),
        "global smooth penalty/nullspace bookkeeping diverged"
    );
    assert_eq!(
        penalties_global.len(),
        penaltyinfo_global.len(),
        "global smooth penalty metadata bookkeeping diverged"
    );
    Ok(RawSmoothDesign {
        term_designs: local_designs,
        penalties: penalties_global,
        nullspace_dims: nullspace_dims_global,
        penaltyinfo: penaltyinfo_global,
        dropped_penaltyinfo: dropped_penaltyinfo_global,
        terms: terms_out,
        coefficient_lower_bounds: if any_bounds {
            Some(coefficient_lower_bounds)
        } else {
            None
        },
        linear_constraints: if linear_constraintsrows.is_empty() {
            None
        } else {
            let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
            for (i, (cs, ce, values)) in linear_constraintsrows.iter().enumerate() {
                a.row_mut(i).slice_mut(s![*cs..*ce]).assign(values);
            }
            Some(LinearInequalityConstraints {
                a,
                b: Array1::from_vec(linear_constraints_b),
            })
        },
    })
}