use crate::basis::{
BSplineBasisSpec, BSplineIdentifiability, BSplineKnotSpec, BasisBuildResult, BasisError,
BasisMetadata, BasisPsiDerivativeResult, BasisPsiSecondDerivativeResult, CenterStrategy,
CenterStrategyKind, DuchonBasisSpec, DuchonNullspaceOrder, DuchonOperatorPenaltySpec,
KroneckerFactoredBasis, MaternBasisSpec, MaternIdentifiability, PenaltyCandidate, PenaltyInfo,
PenaltySource, SpatialIdentifiability, ThinPlateBasisSpec, apply_sum_to_zero_constraint,
build_bspline_basis_1d, build_duchon_basis, build_duchon_basis_log_kappa_aniso_derivatives,
build_duchon_basis_log_kappa_derivatives, build_duchon_basiswithworkspace, build_matern_basis,
build_matern_basis_log_kappa_aniso_derivatives, build_matern_basis_log_kappa_derivatives,
build_matern_basiswithworkspace, build_matern_collocation_operator_matrices,
build_thin_plate_basis, build_thin_plate_basis_log_kappa_derivatives, center_strategy_is_auto,
center_strategy_kind, center_strategy_num_centers, center_strategy_with_num_centers,
estimate_penalty_nullity, filter_active_penalty_candidates,
filter_active_penalty_candidates_with_ops, initial_aniso_contrasts,
orthogonality_transform_for_design, pairwise_distance_bounds, pairwise_distance_bounds_sampled,
points_in_aniso_y_space, select_centers_by_strategy,
};
use crate::construction::{
kronecker_logdet_and_derivatives, kronecker_marginal_eigensystems, kronecker_product,
};
use crate::custom_family::{
BlockGeometryDirectionalDerivative, BlockWorkingSet, BlockwiseFitOptions, CustomFamily,
CustomFamilyBlockPsiDerivative, CustomFamilyWarmStart, ExactNewtonJointPsiTerms,
ExactNewtonOuterObjective, FamilyEvaluation, ParameterBlockSpec, ParameterBlockState,
PenaltyMatrix, evaluate_custom_family_joint_hyper, evaluate_custom_family_joint_hyper_efs,
fit_custom_family,
};
use crate::estimate::{
EstimationError, ExternalOptimOptions, FitInference, FitOptions, FittedLinkState, PenaltySpec,
UnifiedFitResult, UnifiedFitResultParts, fit_gamwith_heuristic_lambdas,
reml::DirectionalHyperParam,
};
use crate::faer_ndarray::{fast_atb, fast_atv};
use crate::families::strategy::{FamilyStrategy, strategy_for_family};
use crate::matrix::{
BlockDesignOperator, CoefficientTransformOperator, DesignBlock, DesignMatrix,
RandomEffectOperator, SymmetricMatrix, TensorProductDesignOperator,
};
use crate::mixture_link::{
logit_inverse_link_jet5, state_from_beta_logisticspec, state_from_sasspec, state_fromspec,
};
use crate::pirls::LinearInequalityConstraints;
use crate::types::{
InverseLink, LatentCLogLogState, LikelihoodFamily, MixtureLinkState, SasLinkState,
};
use faer::sparse::{SparseColMat, Triplet};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, BTreeSet};
use std::f64;
use std::ops::Range;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
/// Produces a short human-readable description of how many centers a
/// thin-plate term requested, for use in error messages.
fn describe_thin_plate_center_request(strategy: &CenterStrategy) -> String {
    match strategy {
        // `Auto` wraps another strategy; describe the wrapped one.
        CenterStrategy::Auto(inner) => describe_thin_plate_center_request(inner),
        CenterStrategy::UserProvided(centers) => {
            format!("{} centers", centers.nrows())
        }
        CenterStrategy::UniformGrid { points_per_dim } => {
            format!("uniform grid with {points_per_dim} points per dimension")
        }
        // All count-based strategies report the same way.
        CenterStrategy::EqualMass { num_centers }
        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
        | CenterStrategy::FarthestPoint { num_centers }
        | CenterStrategy::KMeans { num_centers, .. } => format!("{num_centers} centers"),
    }
}
/// Rewrites low-level basis errors about insufficient thin-plate
/// knots/centers into a friendlier message that names the term, the
/// covariate count, the requested center strategy, and the true minimum.
///
/// Any error that does not match the recognized message shapes is
/// returned unchanged.
fn rewrite_thin_plate_knots_error(
    err: BasisError,
    termname: &str,
    feature_count: usize,
    spec: &ThinPlateBasisSpec,
) -> BasisError {
    match err {
        // The two upstream message shapes ("thin-plate spline requires at
        // least ... centers/knots to span" and "requested ... knots but
        // only ...") previously had duplicated arm bodies; merge the guards.
        BasisError::InvalidInput(msg)
            if (msg.contains("thin-plate spline requires at least")
                && (msg.contains("centers to span") || msg.contains("knots to span")))
                || (msg.starts_with("requested ") && msg.contains(" knots but only ")) =>
        {
            let min_centers = crate::basis::thin_plate_polynomial_basis_dimension(feature_count);
            let requested = describe_thin_plate_center_request(&spec.center_strategy);
            BasisError::InvalidInput(format!(
                "joint TPS term '{termname}' over {feature_count} covariates with {requested} is invalid; minimum centers is {min_centers}"
            ))
        }
        other => other,
    }
}
/// Shape restriction requested for a smooth term's fitted function.
/// How the constraint is enforced is decided downstream (see
/// `SmoothTerm::lower_bounds_local` / `linear_constraints_local`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShapeConstraint {
    /// No shape restriction.
    None,
    /// Monotone increasing fit.
    MonotoneIncreasing,
    /// Monotone decreasing fit.
    MonotoneDecreasing,
    /// Convex fit.
    Convex,
    /// Concave fit.
    Concave,
}
/// Specification of the basis used by one smooth term.
///
/// `feature_col` / `feature_cols` index into the data matrix columns the
/// term is built over. `input_scales`, where present, optionally carries
/// per-feature scales (serde-defaulted to `None` for older payloads).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SmoothBasisSpec {
    /// Univariate B-spline basis over a single column.
    BSpline1D {
        feature_col: usize,
        spec: BSplineBasisSpec,
    },
    /// Thin-plate spline over one or more columns.
    ThinPlate {
        feature_cols: Vec<usize>,
        spec: ThinPlateBasisSpec,
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Matérn radial basis over one or more columns.
    Matern {
        feature_cols: Vec<usize>,
        spec: MaternBasisSpec,
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Duchon spline over one or more columns.
    Duchon {
        feature_cols: Vec<usize>,
        spec: DuchonBasisSpec,
        #[serde(default)]
        input_scales: Option<Vec<f64>>,
    },
    /// Tensor product of marginal B-spline bases, one per column.
    TensorBSpline {
        feature_cols: Vec<usize>,
        spec: TensorBSplineSpec,
    },
}
/// Specification for a tensor-product B-spline smooth: one marginal
/// B-spline spec per dimension, plus penalty and identifiability options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorBSplineSpec {
    /// One marginal basis spec per tensor dimension.
    pub marginalspecs: Vec<BSplineBasisSpec>,
    /// Whether an extra (double) penalty is added; see
    /// `KroneckerPenaltySystem::has_double_penalty`.
    pub double_penalty: bool,
    /// Identifiability handling (serde-defaults to `SumToZero`).
    #[serde(default)]
    pub identifiability: TensorBSplineIdentifiability,
}
/// Identifiability handling for a tensor B-spline term.
///
/// `SumToZero` is the default (also used by serde via `#[serde(default)]`
/// on `TensorBSplineSpec::identifiability`); `FrozenTransform` carries a
/// pre-computed constraint transform so a frozen model rebuilds identically.
///
/// Uses the `#[derive(Default)]` + `#[default]` pattern already used by
/// `LinearCoefficientGeometry` in this file, replacing the manual impl.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum TensorBSplineIdentifiability {
    /// No identifiability constraint.
    None,
    /// Sum-to-zero constraint (the default).
    #[default]
    SumToZero,
    /// A frozen, pre-computed constraint transform.
    FrozenTransform { transform: Array2<f64> },
}
/// User-facing specification of one smooth term: a name, a basis, and an
/// optional shape constraint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothTermSpec {
    /// Term name, used in error messages and range lookups.
    pub name: String,
    /// Which basis to build and over which data columns.
    pub basis: SmoothBasisSpec,
    /// Shape restriction for the fitted function.
    pub shape: ShapeConstraint,
}
/// A built smooth term: the realized penalties, metadata, and constraint
/// material for one term, with `coeff_range` locating its coefficients in
/// the combined coefficient vector.
#[derive(Debug, Clone)]
pub struct SmoothTerm {
    pub name: String,
    /// This term's slice of the global coefficient vector.
    pub coeff_range: Range<usize>,
    pub shape: ShapeConstraint,
    /// Penalty matrices local to this term's coefficient block.
    pub penalties_local: Vec<Array2<f64>>,
    /// Estimated penalty nullspace dimension, one per local penalty.
    pub nullspace_dims: Vec<usize>,
    pub penaltyinfo_local: Vec<PenaltyInfo>,
    pub metadata: BasisMetadata,
    /// Per-coefficient lower bounds, if the shape constraint needs them.
    pub lower_bounds_local: Option<Array1<f64>>,
    /// Linear inequality constraints local to this term, if any.
    pub linear_constraints_local: Option<LinearInequalityConstraints>,
    /// Present when the term's penalty factors as a Kronecker product
    /// (consumed by `TermCollectionDesign::kronecker_penalty_system`).
    pub kronecker_factored: Option<KroneckerFactoredBasis>,
}
/// Metadata for one active penalty block, tagged with its global index
/// and (when it belongs to a named term) the term name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyBlockInfo {
    /// Position of this penalty in the global penalty list.
    pub global_index: usize,
    /// Owning term name, if the penalty is term-scoped.
    pub termname: Option<String>,
    pub penalty: PenaltyInfo,
}
/// Metadata for a penalty block that was dropped (e.g. filtered as
/// inactive); kept for diagnostics. Same shape as `PenaltyBlockInfo`
/// minus the global index, which a dropped penalty no longer has.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DroppedPenaltyBlockInfo {
    /// Owning term name, if the penalty was term-scoped.
    pub termname: Option<String>,
    pub penalty: PenaltyInfo,
}
/// The assembled design for all smooth terms: per-term design matrices,
/// the blockwise penalties over their columns, and the built terms with
/// any coefficient bounds / linear constraints they impose.
#[derive(Debug, Clone)]
pub struct SmoothDesign {
    /// One design matrix per smooth term (all share the same row count).
    pub term_designs: Vec<DesignMatrix>,
    pub penalties: Vec<BlockwisePenalty>,
    /// Nullspace dimension per penalty.
    pub nullspace_dims: Vec<usize>,
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    pub terms: Vec<SmoothTerm>,
    /// Combined per-coefficient lower bounds, if any term has them.
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    /// Combined linear inequality constraints, if any term has them.
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
impl SmoothDesign {
pub fn total_smooth_cols(&self) -> usize {
self.term_designs.iter().map(DesignMatrix::ncols).sum()
}
pub fn nrows(&self) -> usize {
self.term_designs.first().map_or(0, DesignMatrix::nrows)
}
}
/// A smooth design before any post-processing; structurally identical to
/// `SmoothDesign` (see the `From` impl below) — the distinct type marks
/// the pre-processing stage in the pipeline.
#[derive(Debug, Clone)]
pub struct RawSmoothDesign {
    pub term_designs: Vec<DesignMatrix>,
    pub penalties: Vec<BlockwisePenalty>,
    pub nullspace_dims: Vec<usize>,
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    pub terms: Vec<SmoothTerm>,
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    pub linear_constraints: Option<LinearInequalityConstraints>,
}
impl RawSmoothDesign {
pub fn total_smooth_cols(&self) -> usize {
self.term_designs.iter().map(DesignMatrix::ncols).sum()
}
pub fn nrows(&self) -> usize {
self.term_designs.first().map_or(0, DesignMatrix::nrows)
}
}
impl From<RawSmoothDesign> for SmoothDesign {
fn from(value: RawSmoothDesign) -> Self {
Self {
term_designs: value.term_designs,
penalties: value.penalties,
nullspace_dims: value.nullspace_dims,
penaltyinfo: value.penaltyinfo,
dropped_penaltyinfo: value.dropped_penaltyinfo,
terms: value.terms,
coefficient_lower_bounds: value.coefficient_lower_bounds,
linear_constraints: value.linear_constraints,
}
}
}
/// Prior placed on a bounded linear coefficient (see
/// `LinearCoefficientGeometry::Bounded`). `Beta { a, b }` parameters are
/// validated in `TermCollectionSpec::validate_frozen` (both must be
/// finite and >= 1).
///
/// Uses the `#[derive(Default)]` + `#[default]` pattern already used by
/// `LinearCoefficientGeometry` in this file, replacing the manual impl.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum BoundedCoefficientPriorSpec {
    /// No prior (the default).
    #[default]
    None,
    /// Uniform prior over the bounded interval.
    Uniform,
    /// Beta(a, b) prior over the bounded interval.
    Beta { a: f64, b: f64 },
}
/// Geometry of a linear term's coefficient: unconstrained (default) or
/// restricted to a bounded interval with an optional prior.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum LinearCoefficientGeometry {
    /// No restriction on the coefficient (the default).
    #[default]
    Unconstrained,
    /// Coefficient restricted to `[min, max]`; validated in
    /// `validate_frozen` to require finite bounds with `min < max`.
    Bounded {
        min: f64,
        max: f64,
        /// Optional prior over the interval (serde-defaults to `None`).
        #[serde(default)]
        prior: BoundedCoefficientPriorSpec,
    },
}
/// Specification of a single linear (parametric) term.
///
/// NOTE(review): both `coefficient_geometry` and the legacy
/// `coefficient_min`/`coefficient_max` pair exist; `validate_frozen`
/// checks them independently — confirm which one downstream fitting uses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearTermSpec {
    pub name: String,
    /// Data column this term reads.
    pub feature_col: usize,
    /// Whether the term carries a double penalty (serde-defaults to true).
    #[serde(default = "default_linear_term_double_penalty")]
    pub double_penalty: bool,
    /// Coefficient geometry (serde-defaults to `Unconstrained`).
    #[serde(default)]
    pub coefficient_geometry: LinearCoefficientGeometry,
    /// Optional coefficient lower bound (legacy-style field).
    #[serde(default)]
    pub coefficient_min: Option<f64>,
    /// Optional coefficient upper bound (legacy-style field).
    #[serde(default)]
    pub coefficient_max: Option<f64>,
}
/// serde default for `LinearTermSpec::double_penalty`: enabled.
const fn default_linear_term_double_penalty() -> bool {
    true
}
/// Specification of a random-effect term keyed by a categorical column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomEffectTermSpec {
    pub name: String,
    /// Data column holding the grouping variable.
    pub feature_col: usize,
    /// Whether to drop the first level (reference coding).
    pub drop_first_level: bool,
    /// Frozen level set; required by `validate_frozen` so a saved model
    /// rebuilds with an identical level ordering.
    #[serde(default)]
    pub frozen_levels: Option<Vec<u64>>,
}
/// A full model specification: the linear, random-effect, and smooth
/// terms that together define a term collection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermCollectionSpec {
    pub linear_terms: Vec<LinearTermSpec>,
    pub random_effect_terms: Vec<RandomEffectTermSpec>,
    pub smooth_terms: Vec<SmoothTermSpec>,
}
impl TermCollectionSpec {
    /// Checks that every term is fully "frozen" — i.e. contains no
    /// data-dependent choices left to resolve — so a saved spec rebuilds
    /// deterministically. `label` prefixes every error message.
    ///
    /// Frozen means: linear bounds are finite and consistent; B-spline
    /// knots are `Provided`; spatial centers are `UserProvided` with a
    /// non-data-dependent identifiability; tensor marginals have
    /// `Provided` knots and non-`SumToZero` identifiability; random
    /// effects carry `frozen_levels`.
    ///
    /// Returns the first violation found as `Err(String)`.
    pub fn validate_frozen(&self, label: &str) -> Result<(), String> {
        for linear in &self.linear_terms {
            // Combined check first (min > max or either non-finite), then
            // each bound individually so a lone non-finite bound is caught.
            // NOTE(review): this uses `min > max` while the Bounded geometry
            // below uses `min >= max` — confirm the asymmetry is intended.
            if let (Some(min), Some(max)) = (linear.coefficient_min, linear.coefficient_max)
                && (!min.is_finite() || !max.is_finite() || min > max)
            {
                return Err(format!(
                    "{label} linear term '{}' has invalid coefficient constraint [{min}, {max}]",
                    linear.name
                ));
            }
            if let Some(min) = linear.coefficient_min
                && !min.is_finite()
            {
                return Err(format!(
                    "{label} linear term '{}' has non-finite coefficient minimum {min}",
                    linear.name
                ));
            }
            if let Some(max) = linear.coefficient_max
                && !max.is_finite()
            {
                return Err(format!(
                    "{label} linear term '{}' has non-finite coefficient maximum {max}",
                    linear.name
                ));
            }
            if let LinearCoefficientGeometry::Bounded { min, max, prior } =
                &linear.coefficient_geometry
            {
                if !min.is_finite() || !max.is_finite() || min >= max {
                    return Err(format!(
                        "{label} bounded term '{}' has invalid bounds [{min}, {max}]",
                        linear.name
                    ));
                }
                match prior {
                    BoundedCoefficientPriorSpec::None | BoundedCoefficientPriorSpec::Uniform => {}
                    BoundedCoefficientPriorSpec::Beta { a, b } => {
                        // Beta parameters must be finite and >= 1.
                        if !a.is_finite() || !b.is_finite() || *a < 1.0 || *b < 1.0 {
                            return Err(format!(
                                "{label} bounded term '{}' has invalid Beta prior ({a}, {b})",
                                linear.name
                            ));
                        }
                    }
                }
            }
        }
        for st in &self.smooth_terms {
            match &st.basis {
                SmoothBasisSpec::BSpline1D { spec, .. } => {
                    if !matches!(spec.knotspec, BSplineKnotSpec::Provided(_)) {
                        return Err(format!(
                            "{label} term '{}' is not frozen: BSpline knotspec must be Provided",
                            st.name
                        ));
                    }
                }
                SmoothBasisSpec::ThinPlate { spec, .. } => {
                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
                        return Err(format!(
                            "{label} term '{}' is not frozen: ThinPlate centers must be UserProvided",
                            st.name
                        ));
                    }
                    // OrthogonalToParametric depends on the fitting data,
                    // so it cannot appear in a frozen spec.
                    if matches!(
                        spec.identifiability,
                        SpatialIdentifiability::OrthogonalToParametric
                    ) {
                        return Err(format!(
                            "{label} term '{}' is not frozen: ThinPlate identifiability must be FrozenTransform or None",
                            st.name
                        ));
                    }
                }
                SmoothBasisSpec::Matern { spec, .. } => {
                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
                        return Err(format!(
                            "{label} term '{}' is not frozen: Matern centers must be UserProvided",
                            st.name
                        ));
                    }
                }
                SmoothBasisSpec::Duchon { spec, .. } => {
                    if !matches!(spec.center_strategy, CenterStrategy::UserProvided(_)) {
                        return Err(format!(
                            "{label} term '{}' is not frozen: Duchon centers must be UserProvided",
                            st.name
                        ));
                    }
                    if matches!(
                        spec.identifiability,
                        SpatialIdentifiability::OrthogonalToParametric
                    ) {
                        return Err(format!(
                            "{label} term '{}' is not frozen: Duchon identifiability must be FrozenTransform or None",
                            st.name
                        ));
                    }
                }
                SmoothBasisSpec::TensorBSpline { spec, .. } => {
                    for (dim, marginal) in spec.marginalspecs.iter().enumerate() {
                        if !matches!(marginal.knotspec, BSplineKnotSpec::Provided(_)) {
                            return Err(format!(
                                "{label} term '{}' dim {} is not frozen: tensor marginal knotspec must be Provided",
                                st.name, dim
                            ));
                        }
                    }
                    // SumToZero is computed from data; a frozen tensor term
                    // must carry the transform explicitly (or use None).
                    if matches!(
                        spec.identifiability,
                        TensorBSplineIdentifiability::SumToZero
                    ) {
                        return Err(format!(
                            "{label} term '{}' is not frozen: tensor identifiability must be FrozenTransform or None",
                            st.name
                        ));
                    }
                }
            }
        }
        for rt in &self.random_effect_terms {
            if rt.frozen_levels.is_none() {
                return Err(format!(
                    "{label} random-effect term '{}' is not frozen: missing frozen_levels",
                    rt.name
                ));
            }
        }
        Ok(())
    }
}
/// Structural hint about a penalty block, allowing downstream code to use
/// specialized paths instead of treating `local` as a dense matrix.
#[derive(Debug, Clone)]
pub enum PenaltyStructureHint {
    /// Penalty is `scale * I` on its block.
    Ridge(f64),
    /// Penalty factors as a Kronecker product of the given marginals.
    Kronecker(Vec<Array2<f64>>),
}
/// A penalty that acts on a contiguous block of coefficients:
/// `local` is the dense penalty over `col_range` of the global vector.
/// Debug is hand-implemented below to avoid dumping the full matrix.
#[derive(Clone)]
pub struct BlockwisePenalty {
    /// Global coefficient columns this penalty covers.
    pub col_range: Range<usize>,
    /// Dense penalty matrix, square with side `col_range.len()`.
    pub local: Array2<f64>,
    /// Optional structure hint (ridge / Kronecker) for fast paths.
    pub structure_hint: Option<PenaltyStructureHint>,
    /// Optional operator form of the penalty.
    pub op: Option<std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>>,
}
// Manual Debug: prints the matrix dimensions instead of its contents
// (the dense `local` can be large), and the op's dimension in place of
// the non-Debug trait object.
impl std::fmt::Debug for BlockwisePenalty {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BlockwisePenalty")
            .field("col_range", &self.col_range)
            .field(
                "local",
                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
            )
            .field("structure_hint", &self.structure_hint)
            .field("op", &self.op.as_ref().map(|o| o.dim()))
            .finish()
    }
}
impl BlockwisePenalty {
    /// Creates a penalty block from a dense local matrix; `local` must be
    /// square with side `col_range.len()` (debug-asserted).
    pub fn new(col_range: Range<usize>, local: Array2<f64>) -> Self {
        debug_assert_eq!(col_range.len(), local.nrows());
        debug_assert_eq!(col_range.len(), local.ncols());
        Self {
            col_range,
            local,
            structure_hint: None,
            op: None,
        }
    }
    /// Attaches (or clears) an operator form of this penalty.
    pub fn with_op(
        mut self,
        op: Option<std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>>,
    ) -> Self {
        self.op = op;
        self
    }
    /// Builds a ridge penalty (`scale * I`) over `col_range`, recording
    /// the structure hint so fast paths can skip the dense matrix.
    pub fn ridge(col_range: Range<usize>, scale: f64) -> Self {
        let block_size = col_range.len();
        // Idiomatic scaled identity (replaces a manual diagonal-fill loop).
        let local = Array2::<f64>::eye(block_size) * scale;
        Self {
            col_range,
            local,
            structure_hint: Some(PenaltyStructureHint::Ridge(scale)),
            op: None,
        }
    }
    /// Builds a penalty whose dense form is `local` and which additionally
    /// records its Kronecker factors as a structure hint.
    pub fn kronecker(
        col_range: Range<usize>,
        local: Array2<f64>,
        factors: Vec<Array2<f64>>,
    ) -> Self {
        debug_assert_eq!(col_range.len(), local.nrows());
        debug_assert_eq!(col_range.len(), local.ncols());
        Self {
            col_range,
            local,
            structure_hint: Some(PenaltyStructureHint::Kronecker(factors)),
            op: None,
        }
    }
    /// Embeds the local penalty into a `p_total x p_total` zero matrix.
    ///
    /// # Panics
    /// Panics if `col_range` does not fit inside `p_total` or `local`'s
    /// shape disagrees with `col_range`.
    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
        let mut g = Array2::<f64>::zeros((p_total, p_total));
        let r = &self.col_range;
        assert!(
            r.end <= p_total && self.local.nrows() == r.len() && self.local.ncols() == r.len(),
            "BlockwisePenalty::to_global shape invariant violated: \
             col_range={}..{}, local={}x{}, p_total={}",
            r.start,
            r.end,
            self.local.nrows(),
            self.local.ncols(),
            p_total,
        );
        g.slice_mut(s![r.start..r.end, r.start..r.end])
            .assign(&self.local);
        g
    }
    /// Number of coefficients this penalty covers.
    #[inline]
    pub fn block_size(&self) -> usize {
        self.col_range.len()
    }
}
/// Accumulates `sum_i lambda_i * S_i` into a dense `p_total x p_total`
/// matrix, adding each penalty only on its own coefficient block.
pub fn weighted_blockwise_penalty_sum(
    penalties: &[BlockwisePenalty],
    lambdas: &[f64],
    p_total: usize,
) -> Array2<f64> {
    debug_assert_eq!(penalties.len(), lambdas.len());
    let mut accum = Array2::<f64>::zeros((p_total, p_total));
    for (penalty, weight) in penalties.iter().zip(lambdas.iter().copied()) {
        let range = &penalty.col_range;
        accum
            .slice_mut(s![range.start..range.end, range.start..range.end])
            .scaled_add(weight, &penalty.local);
    }
    accum
}
/// Kronecker-structured penalty system: marginal penalties, their cached
/// eigendecompositions, and the marginal dimensions whose product is the
/// total coefficient count.
#[derive(Debug, Clone)]
pub struct KroneckerPenaltySystem {
    /// One marginal penalty matrix per tensor dimension.
    pub marginal_penalties: Vec<Array2<f64>>,
    /// Cached `(eigenvalues, eigenvectors)` per marginal penalty.
    pub marginal_eigensystems: Vec<(Array1<f64>, Array2<f64>)>,
    /// Coefficient count per dimension; product = total coefficients.
    pub marginal_dims: Vec<usize>,
    /// Whether an extra (double) penalty contributes one more lambda.
    pub has_double_penalty: bool,
}
impl KroneckerPenaltySystem {
    /// Builds the system and eagerly eigendecomposes each marginal
    /// penalty (cached in `marginal_eigensystems`).
    ///
    /// # Errors
    /// `DimensionMismatch` if penalty and dim counts disagree;
    /// `InvalidInput` if an eigendecomposition fails.
    pub fn new(
        marginal_penalties: Vec<Array2<f64>>,
        marginal_dims: Vec<usize>,
        has_double_penalty: bool,
    ) -> Result<Self, BasisError> {
        if marginal_penalties.len() != marginal_dims.len() {
            return Err(BasisError::DimensionMismatch(format!(
                "KroneckerPenaltySystem: {} penalties vs {} dims",
                marginal_penalties.len(),
                marginal_dims.len()
            )));
        }
        let eigensystems =
            kronecker_marginal_eigensystems(&marginal_penalties, "KroneckerPenaltySystem")
                .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
        Ok(Self {
            marginal_penalties,
            marginal_eigensystems: eigensystems,
            marginal_dims,
            has_double_penalty,
        })
    }
    /// Total coefficient count: product of the marginal dimensions.
    pub fn p_total(&self) -> usize {
        self.marginal_dims.iter().copied().product()
    }
    /// Number of tensor dimensions.
    pub fn ndim(&self) -> usize {
        self.marginal_dims.len()
    }
    /// Number of smoothing parameters: one per dimension, plus one more
    /// when the double penalty is active.
    pub fn num_penalties(&self) -> usize {
        self.marginal_dims.len() + if self.has_double_penalty { 1 } else { 0 }
    }
    /// Log-determinant of the weighted penalty sum plus its first and
    /// second derivatives w.r.t. the lambdas, computed from the cached
    /// marginal eigenvalues (no dense assembly).
    ///
    /// # Panics
    /// Panics if `lambdas.len() != self.num_penalties()`.
    pub fn logdet_and_derivatives(
        &self,
        lambdas: &[f64],
        ridge: f64,
    ) -> (f64, Array1<f64>, Array2<f64>) {
        let n_pen = self.num_penalties();
        assert_eq!(lambdas.len(), n_pen, "lambda count mismatch");
        let marginal_evals: Vec<_> = self
            .marginal_eigensystems
            .iter()
            .map(|(evals, _)| evals.view())
            .collect();
        kronecker_logdet_and_derivatives(
            &marginal_evals,
            &self.marginal_dims,
            lambdas,
            self.has_double_penalty,
            ridge,
        )
    }
}
/// The fully assembled design for a term collection: the combined design
/// matrix, all penalties with their metadata, constraint material, and
/// the column ranges of each term family.
#[derive(Clone, Debug)]
pub struct TermCollectionDesign {
    /// Combined design matrix (all term families).
    pub design: DesignMatrix,
    pub penalties: Vec<BlockwisePenalty>,
    pub nullspace_dims: Vec<usize>,
    pub penaltyinfo: Vec<PenaltyBlockInfo>,
    pub dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// Column range of the intercept.
    pub intercept_range: Range<usize>,
    /// Per-linear-term name and column range.
    pub linear_ranges: Vec<(String, Range<usize>)>,
    /// Per-random-effect-term name and column range.
    pub random_effect_ranges: Vec<(String, Range<usize>)>,
    /// Per-random-effect-term name and frozen level values.
    pub random_effect_levels: Vec<(String, Vec<u64>)>,
    /// The assembled smooth-term design.
    pub smooth: SmoothDesign,
}
impl TermCollectionDesign {
    /// Converts every blockwise penalty into a full-size `PenaltyMatrix`.
    pub fn penalties_as_penalty_matrix(&self) -> Vec<crate::custom_family::PenaltyMatrix> {
        let p = self.design.ncols();
        let mut out = Vec::with_capacity(self.penalties.len());
        for bp in &self.penalties {
            out.push(crate::custom_family::PenaltyMatrix::from_blockwise(
                bp.clone(),
                p,
            ));
        }
        out
    }
    /// Number of penalty blocks.
    #[inline]
    pub fn num_penalties(&self) -> usize {
        self.penalties.len()
    }
    /// Returns a Kronecker penalty system when exactly one smooth term is
    /// Kronecker-factored and no other smooth terms exist; `None`
    /// otherwise (including when eigendecomposition fails).
    pub fn kronecker_penalty_system(&self) -> Option<KroneckerPenaltySystem> {
        let mut factored = self
            .smooth
            .terms
            .iter()
            .filter_map(|t| t.kronecker_factored.as_ref());
        // Require exactly one factored term.
        let kron = factored.next()?;
        if factored.next().is_some() {
            return None;
        }
        // Any unfactored smooth term disqualifies the fast path.
        if self
            .smooth
            .terms
            .iter()
            .any(|t| t.kronecker_factored.is_none())
        {
            return None;
        }
        KroneckerPenaltySystem::new(
            kron.marginal_penalties.clone(),
            kron.marginal_dims.clone(),
            kron.has_double_penalty,
        )
        .ok()
    }
}
/// A fitted term collection: the fit result, the design it was fit on,
/// and optional adaptive-regularization diagnostics.
#[derive(Clone)]
pub struct FittedTermCollection {
    pub fit: UnifiedFitResult,
    pub design: TermCollectionDesign,
    pub adaptive_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
}
/// Like `FittedTermCollection`, but also carries the resolved spec
/// (with data-dependent choices fixed) so the model can be rebuilt.
#[derive(Clone)]
pub struct FittedTermCollectionWithSpec {
    pub fit: UnifiedFitResult,
    pub design: TermCollectionDesign,
    /// Spec after resolution — see `TermCollectionSpec::validate_frozen`.
    pub resolvedspec: TermCollectionSpec,
    pub adaptive_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
}
/// Per-term spatial map produced by adaptive regularization: collocation
/// points and inverse weight fields evaluated at them.
/// NOTE(review): the exact semantics of the three weight fields (mag /
/// grad / lap) are defined by the adaptive-regularization code — confirm
/// there before relying on them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveSpatialMap {
    /// Name of the smooth term this map belongs to.
    pub termname: String,
    /// Data columns the term is built over.
    pub feature_cols: Vec<usize>,
    /// Points at which the weights were evaluated.
    pub collocation_points: Array2<f64>,
    pub inv_magweight: Array1<f64>,
    pub invgradweight: Array1<f64>,
    pub inv_lapweight: Array1<f64>,
}
/// Diagnostics from an adaptive-regularization run: the final epsilon
/// parameters, iteration counts, convergence flag, and per-term maps.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveRegularizationDiagnostics {
    pub epsilon_0: f64,
    pub epsilon_g: f64,
    pub epsilon_c: f64,
    /// Outer iterations over the epsilon parameters.
    pub epsilon_outer_iterations: usize,
    /// Majorize-minimize iterations (presumably — TODO confirm).
    pub mm_iterations: usize,
    pub converged: bool,
    pub maps: Vec<AdaptiveSpatialMap>,
}
/// Centering/scaling applied to one linear-term design column.
#[derive(Debug, Clone)]
struct LinearColumnConditioning {
    /// Design column index the conditioning applies to.
    col_idx: usize,
    mean: f64,
    scale: f64,
}
/// Conditioning applied to the linear part of a fit: the intercept
/// column plus per-column centering/scaling records.
#[derive(Debug, Clone, Default)]
struct LinearFitConditioning {
    /// Design column index of the intercept.
    intercept_idx: usize,
    columns: Vec<LinearColumnConditioning>,
}
/// Per-term derivative material for spatial hyperparameter (ψ = log κ)
/// optimization: design and penalty derivatives local to one term's
/// coefficient block, plus optional anisotropic cross-term and implicit
/// operator machinery.
/// NOTE(review): field semantics (first vs second ψ derivatives, the
/// aniso grouping) are determined by the code that builds and consumes
/// this struct — not visible here; confirm before relying on them.
#[derive(Clone)]
pub(crate) struct SpatialPsiDerivative {
    pub penalty_index: usize,
    pub penalty_indices: Vec<usize>,
    /// This term's slice of the global coefficient vector.
    pub global_range: Range<usize>,
    /// Total coefficient count of the full design.
    pub total_p: usize,
    /// dX/dψ restricted to this term's columns (presumably — TODO confirm).
    pub x_psi_local: Array2<f64>,
    pub s_psi_components_local: Vec<Array2<f64>>,
    /// d²X/dψ² restricted to this term's columns (presumably — TODO confirm).
    pub x_psi_psi_local: Array2<f64>,
    pub s_psi_psi_components_local: Vec<Array2<f64>>,
    /// Group id shared by terms coupled through anisotropic axes, if any.
    pub aniso_group_id: Option<usize>,
    pub aniso_cross_designs: Option<Vec<(usize, Array2<f64>)>>,
    /// Lazily computes cross-penalty matrices for a given axis index.
    pub aniso_cross_penalty_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError> + Send + Sync + 'static,
        >,
    >,
    pub implicit_operator: Option<std::sync::Arc<crate::terms::basis::ImplicitDesignPsiDerivative>>,
    pub implicit_axis: usize,
}
/// Flat vector of spatial ψ = log κ coordinates for a set of smooth
/// terms, segmented by `dims_per_term`: term `i` owns the next
/// `dims_per_term[i]` entries of `values` (1 for isotropic terms, d for
/// anisotropic ones). Invariant: `values.len() == dims_per_term.sum()`.
#[derive(Debug, Clone)]
pub(crate) struct SpatialLogKappaCoords {
    values: Array1<f64>,
    dims_per_term: Vec<usize>,
}
impl SpatialLogKappaCoords {
    /// Builds coords from raw values and per-term dimension counts;
    /// debug-asserts the length invariant.
    pub(crate) fn new_with_dims(values: Array1<f64>, dims_per_term: Vec<usize>) -> Self {
        debug_assert_eq!(
            values.len(),
            dims_per_term.iter().sum::<usize>(),
            "SpatialLogKappaCoords: values length {} != sum of dims_per_term {}",
            values.len(),
            dims_per_term.iter().sum::<usize>(),
        );
        Self {
            values,
            dims_per_term,
        }
    }
    /// Isotropic seeding: one ψ = -ln(length_scale) per term, with the
    /// length scale clamped to the options window (terms without a
    /// length scale fall back to `min_length_scale`).
    pub(crate) fn from_length_scales(
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let mut out = Array1::<f64>::zeros(term_indices.len());
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            let length_scale = get_spatial_length_scale(spec, term_idx)
                .unwrap_or(options.min_length_scale)
                .clamp(options.min_length_scale, options.max_length_scale);
            out[slot] = -length_scale.ln();
        }
        Self {
            values: out,
            dims_per_term: vec![1; term_indices.len()],
        }
    }
    /// Anisotropic-aware seeding. Pure-Duchon aniso terms contribute the
    /// first d-1 centered log-scales (the last is implied by the
    /// zero-sum constraint); other aniso terms contribute ψ̄ + η_a per
    /// axis; isotropic terms contribute a single ψ̄.
    pub(crate) fn from_length_scales_aniso(
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let mut vals = Vec::new();
        let mut dims = Vec::new();
        for &term_idx in term_indices {
            if is_pure_duchon_aniso_term(spec, term_idx) {
                let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
                let eta = center_aniso_log_scales(
                    &get_spatial_aniso_log_scales(spec, term_idx).unwrap_or_else(|| vec![0.0; d]),
                );
                // Only d-1 free coordinates; the last η is -(sum of rest).
                for eta_a in eta.into_iter().take(d.saturating_sub(1)) {
                    vals.push(eta_a);
                }
                dims.push(d.saturating_sub(1).max(1));
                continue;
            }
            let length_scale = get_spatial_length_scale(spec, term_idx)
                .unwrap_or(options.min_length_scale)
                .clamp(options.min_length_scale, options.max_length_scale);
            let psi_bar = -length_scale.ln();
            let aniso = get_spatial_aniso_log_scales(spec, term_idx);
            let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(1);
            match aniso {
                // Full per-axis coordinates only when the aniso vector is
                // complete and the term is truly multivariate.
                Some(ref eta) if eta.len() == d && d > 1 => {
                    let eta = center_aniso_log_scales(eta);
                    for &eta_a in &eta {
                        vals.push(psi_bar + eta_a);
                    }
                    dims.push(d);
                }
                _ => {
                    vals.push(psi_bar);
                    dims.push(1);
                }
            }
        }
        Self {
            values: Array1::from_vec(vals),
            dims_per_term: dims,
        }
    }
    /// Option-window lower bounds replicated across all coordinates
    /// (test helper).
    #[cfg(test)]
    pub(crate) fn lower_bounds_aniso(
        dims_per_term: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let total: usize = dims_per_term.iter().sum();
        Self {
            values: Array1::<f64>::from_elem(total, -options.max_length_scale.ln()),
            dims_per_term: dims_per_term.to_vec(),
        }
    }
    /// Option-window upper bounds replicated across all coordinates
    /// (test helper).
    #[cfg(test)]
    pub(crate) fn upper_bounds_aniso(
        dims_per_term: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let total: usize = dims_per_term.iter().sum();
        Self {
            values: Array1::<f64>::from_elem(total, -options.min_length_scale.ln()),
            dims_per_term: dims_per_term.to_vec(),
        }
    }
    /// Isotropic data-derived lower bounds, one per term (see
    /// `spatial_term_psi_bounds`).
    pub(crate) fn lower_bounds_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let mut values = Array1::<f64>::zeros(term_indices.len());
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).0;
        }
        Self {
            values,
            dims_per_term: vec![1; term_indices.len()],
        }
    }
    /// Isotropic data-derived upper bounds, one per term.
    pub(crate) fn upper_bounds_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        let mut values = Array1::<f64>::zeros(term_indices.len());
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            values[slot] = spatial_term_psi_bounds(data, spec, term_idx, options).1;
        }
        Self {
            values,
            dims_per_term: vec![1; term_indices.len()],
        }
    }
    /// Anisotropic data-derived lower bounds: each axis gets the term's
    /// scalar ψ lower bound shifted by its centered aniso offset.
    /// Pure-Duchon aniso terms use the raw options window instead.
    pub(crate) fn lower_bounds_aniso_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        dims_per_term: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        debug_assert_eq!(term_indices.len(), dims_per_term.len());
        let total: usize = dims_per_term.iter().sum();
        let mut values = Array1::<f64>::zeros(total);
        let options_lo = -options.max_length_scale.ln();
        let mut cursor = 0;
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            let d = dims_per_term[slot];
            let psi_lo = if is_pure_duchon_aniso_term(spec, term_idx) {
                options_lo
            } else {
                spatial_term_psi_bounds(data, spec, term_idx, options).0
            };
            let axis_offsets = if is_pure_duchon_aniso_term(spec, term_idx) || d <= 1 {
                vec![0.0; d]
            } else {
                get_spatial_aniso_log_scales(spec, term_idx)
                    .filter(|eta| eta.len() == d)
                    .map(|eta| center_aniso_log_scales(&eta))
                    .unwrap_or_else(|| vec![0.0; d])
            };
            for offset in 0..d {
                values[cursor + offset] = psi_lo + axis_offsets[offset];
            }
            cursor += d;
        }
        Self {
            values,
            dims_per_term: dims_per_term.to_vec(),
        }
    }
    /// Anisotropic data-derived upper bounds; mirrors
    /// `lower_bounds_aniso_from_data` with the upper end of the window.
    pub(crate) fn upper_bounds_aniso_from_data(
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        dims_per_term: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        debug_assert_eq!(term_indices.len(), dims_per_term.len());
        let total: usize = dims_per_term.iter().sum();
        let mut values = Array1::<f64>::zeros(total);
        let options_hi = -options.min_length_scale.ln();
        let mut cursor = 0;
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            let d = dims_per_term[slot];
            let psi_hi = if is_pure_duchon_aniso_term(spec, term_idx) {
                options_hi
            } else {
                spatial_term_psi_bounds(data, spec, term_idx, options).1
            };
            let axis_offsets = if is_pure_duchon_aniso_term(spec, term_idx) || d <= 1 {
                vec![0.0; d]
            } else {
                get_spatial_aniso_log_scales(spec, term_idx)
                    .filter(|eta| eta.len() == d)
                    .map(|eta| center_aniso_log_scales(&eta))
                    .unwrap_or_else(|| vec![0.0; d])
            };
            for offset in 0..d {
                values[cursor + offset] = psi_hi + axis_offsets[offset];
            }
            cursor += d;
        }
        Self {
            values,
            dims_per_term: dims_per_term.to_vec(),
        }
    }
    /// Re-centers each term's coordinates around a fresh data-derived
    /// seed ψ̄, preserving the per-axis offsets around the old mean.
    /// Terms with a user-fixed length scale (seed = None) and
    /// pure-Duchon aniso terms are left untouched.
    pub(crate) fn reseed_from_data(
        mut self,
        data: ArrayView2<'_, f64>,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
        options: &SpatialLengthScaleOptimizationOptions,
    ) -> Self {
        debug_assert_eq!(term_indices.len(), self.dims_per_term.len());
        let mut cursor = 0;
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            let d = self.dims_per_term[slot];
            let Some(psi_bar_new) = spatial_term_psi_seed(data, spec, term_idx, options) else {
                cursor += d;
                continue;
            };
            if is_pure_duchon_aniso_term(spec, term_idx) {
                cursor += d;
                continue;
            }
            if d == 0 {
                continue;
            }
            let current: Vec<f64> = self.values.slice(s![cursor..cursor + d]).to_vec();
            let psi_bar_old = current.iter().sum::<f64>() / d as f64;
            // Shift every axis by (new mean - old mean), keeping offsets.
            for (offset, &old_value) in current.iter().enumerate() {
                self.values[cursor + offset] = psi_bar_new + (old_value - psi_bar_old);
            }
            cursor += d;
        }
        self
    }
    /// Clamps each coordinate into [lower, upper]; coordinates whose
    /// bound pair is non-finite are skipped. Logs a summary when any
    /// coordinate had to be projected.
    pub(crate) fn clamp_to_bounds(
        mut self,
        lower: &SpatialLogKappaCoords,
        upper: &SpatialLogKappaCoords,
    ) -> Self {
        debug_assert_eq!(self.values.len(), lower.values.len());
        debug_assert_eq!(self.values.len(), upper.values.len());
        let mut n_projected = 0usize;
        let mut worst_delta = 0.0_f64;
        for idx in 0..self.values.len() {
            let lo = lower.values[idx];
            let hi = upper.values[idx];
            if !(lo.is_finite() && hi.is_finite()) {
                continue;
            }
            let v = self.values[idx];
            if v < lo {
                worst_delta = worst_delta.max(lo - v);
                self.values[idx] = lo;
                n_projected += 1;
            } else if v > hi {
                worst_delta = worst_delta.max(v - hi);
                self.values[idx] = hi;
                n_projected += 1;
            }
        }
        if n_projected > 0 {
            log::info!(
                "[spatial-kappa] projected {n_projected}/{} ψ seed coords into data-derived bounds \
                 (worst excess={worst_delta:.3} log units); user length_scale falls outside \
                 [2/r_max, 1e2/r_min] geometry window",
                self.values.len()
            );
        }
        self
    }
    /// Extracts the ψ coordinates stored at the tail of a packed θ
    /// vector, starting at `start`.
    pub(crate) fn from_theta_tail_with_dims(
        theta: &Array1<f64>,
        start: usize,
        dims_per_term: Vec<usize>,
    ) -> Self {
        let total: usize = dims_per_term.iter().sum();
        Self {
            values: theta.slice(s![start..start + total]).to_owned(),
            dims_per_term,
        }
    }
    /// Total number of flat ψ coordinates.
    pub(crate) fn len(&self) -> usize {
        self.values.len()
    }
    /// Per-term coordinate counts.
    pub(crate) fn dims_per_term(&self) -> &[usize] {
        &self.dims_per_term
    }
    /// Flat offset at which term `term_idx`'s coordinates begin.
    fn term_offset(&self, term_idx: usize) -> usize {
        self.dims_per_term[..term_idx].iter().sum()
    }
    /// Slice of flat coordinates belonging to term `term_idx`.
    fn term_slice(&self, term_idx: usize) -> &[f64] {
        let offset = self.term_offset(term_idx);
        let d = self.dims_per_term[term_idx];
        &self.values.as_slice().unwrap()[offset..offset + d]
    }
    /// The flat coordinate vector.
    pub(crate) fn as_array(&self) -> &Array1<f64> {
        &self.values
    }
    /// Splits at term index `mid` (not flat index) into two coord sets.
    pub(crate) fn split_at(&self, mid: usize) -> (Self, Self) {
        let flat_mid: usize = self.dims_per_term[..mid].iter().sum();
        (
            Self {
                values: self.values.slice(s![0..flat_mid]).to_owned(),
                dims_per_term: self.dims_per_term[..mid].to_vec(),
            },
            Self {
                values: self.values.slice(s![flat_mid..]).to_owned(),
                dims_per_term: self.dims_per_term[mid..].to_vec(),
            },
        )
    }
    /// Writes these ψ coordinates back into a cloned spec, converting
    /// each term's slice to a length scale and/or aniso log-scales via
    /// `spatial_term_psi_to_length_scale_and_aniso`.
    ///
    /// # Errors
    /// Fails if `term_indices` disagrees with `dims_per_term`, or if a
    /// setter rejects the update.
    pub(crate) fn apply_tospec(
        &self,
        spec: &TermCollectionSpec,
        term_indices: &[usize],
    ) -> Result<TermCollectionSpec, EstimationError> {
        if term_indices.len() != self.dims_per_term.len() {
            return Err(EstimationError::InvalidInput(format!(
                "SpatialLogKappaCoords::apply_tospec: term count mismatch: \
                 term_indices={} dims_per_term={}",
                term_indices.len(),
                self.dims_per_term.len()
            )));
        }
        let mut updated = spec.clone();
        for (slot, &term_idx) in term_indices.iter().enumerate() {
            let psi = self.term_slice(slot);
            let d = self.dims_per_term[slot];
            let (next_length_scale, next_aniso) =
                spatial_term_psi_to_length_scale_and_aniso(spec, term_idx, psi);
            // NOTE(review): the outer `d == 1 || is_some()` guard is
            // redundant with the inner `if let Some` — confirm intent.
            if d == 1 || next_length_scale.is_some() {
                if let Some(length_scale) = next_length_scale {
                    set_spatial_length_scale(&mut updated, term_idx, length_scale)?;
                }
            }
            if let Some(eta) = next_aniso {
                set_spatial_aniso_log_scales(&mut updated, term_idx, eta)?;
            }
        }
        Ok(updated)
    }
}
/// Centers anisotropic log-scales to zero mean, snapping values within
/// 1e-15 of the mean to exactly 0.0. Inputs of length <= 1 pass through.
fn center_aniso_log_scales(eta: &[f64]) -> Vec<f64> {
    match eta.len() {
        0 | 1 => eta.to_vec(),
        n => {
            let mean = eta.iter().sum::<f64>() / n as f64;
            eta.iter()
                .map(|&value| {
                    let shifted = value - mean;
                    // Same comparison direction as before so NaN behavior
                    // is preserved.
                    if shifted.abs() <= 1e-15 { 0.0 } else { shifted }
                })
                .collect()
        }
    }
}
/// True when term `term_idx` is a multivariate Duchon term with no
/// isotropic length scale and a complete anisotropic log-scale vector
/// (one entry per feature column).
fn is_pure_duchon_aniso_term(spec: &TermCollectionSpec, term_idx: usize) -> bool {
    let Some(term) = spec.smooth_terms.get(term_idx) else {
        return false;
    };
    let SmoothBasisSpec::Duchon {
        feature_cols,
        spec: duchon,
        ..
    } = &term.basis
    else {
        return false;
    };
    if duchon.length_scale.is_some() || feature_cols.len() <= 1 {
        return false;
    }
    duchon
        .aniso_log_scales
        .as_ref()
        .is_some_and(|eta| eta.len() == feature_cols.len())
}
/// Whether term `term_idx` exposes spatial hyperparameters that can be
/// optimized: either it has a length scale, or it is a pure Duchon
/// anisotropic term.
fn spatial_term_supports_hyper_optimization(spec: &TermCollectionSpec, term_idx: usize) -> bool {
    if get_spatial_length_scale(spec, term_idx).is_some() {
        return true;
    }
    is_pure_duchon_aniso_term(spec, term_idx)
}
/// True when term `term_idx` has a length scale but no active (complete,
/// non-empty) anisotropic log-scale vector — i.e. its κ is a single
/// locked isotropic value.
pub fn spatial_term_has_locked_kappa(spec: &TermCollectionSpec, term_idx: usize) -> bool {
    if get_spatial_length_scale(spec, term_idx).is_none() {
        return false;
    }
    let aniso_active = match (
        get_spatial_aniso_log_scales(spec, term_idx),
        get_spatial_feature_dim(spec, term_idx),
    ) {
        (Some(eta), Some(d)) => eta.len() == d && d > 0,
        _ => false,
    };
    !aniso_active
}
/// Data-derived (lo, hi) bounds for a term's ψ = log κ.
///
/// The geometry window is [ln(2/r_max), ln(1e2/r_min)] where (r_min,
/// r_max) are pairwise distance bounds over the term's centers (exact)
/// or the standardized data (sampled), computed in anisotropic y-space
/// when a complete aniso vector exists. The result is intersected with
/// the options window; if the intersection is empty or distances are
/// unavailable, the raw options window is returned.
fn spatial_term_psi_bounds(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
    term_idx: usize,
    options: &SpatialLengthScaleOptimizationOptions,
) -> (f64, f64) {
    // Fallback: the pure options window (-ln max_ls, -ln min_ls).
    let fallback = (
        -options.max_length_scale.ln(),
        -options.min_length_scale.ln(),
    );
    let Some(term) = spec.smooth_terms.get(term_idx) else {
        return fallback;
    };
    let aniso = get_spatial_aniso_log_scales(spec, term_idx);
    let r_bounds = match spatial_term_center_strategy(term) {
        // With user-provided centers (>= 2), use exact pairwise bounds.
        Some(CenterStrategy::UserProvided(centers)) if centers.nrows() >= 2 => {
            match aniso.as_deref() {
                Some(eta) if eta.len() == centers.ncols() => {
                    let y = points_in_aniso_y_space(centers.view(), eta);
                    pairwise_distance_bounds(y.view())
                }
                _ => pairwise_distance_bounds(centers.view()),
            }
        }
        // Otherwise fall back to sampled bounds over standardized data.
        _ => standardized_spatial_term_data(data, term)
            .ok()
            .and_then(|x| match aniso.as_deref() {
                Some(eta) if eta.len() == x.ncols() => {
                    let y = points_in_aniso_y_space(x.view(), eta);
                    pairwise_distance_bounds_sampled(y.view())
                }
                _ => pairwise_distance_bounds_sampled(x.view()),
            }),
    };
    let Some((r_min, r_max)) = r_bounds else {
        return fallback;
    };
    let psi_lo_data = (2.0 / r_max).ln();
    let psi_hi_data = (1e2 / r_min).ln();
    // Intersect data window with options window.
    let psi_lo = psi_lo_data.max(fallback.0);
    let psi_hi = psi_hi_data.min(fallback.1);
    if psi_lo >= psi_hi {
        return fallback;
    }
    (psi_lo, psi_hi)
}
fn spatial_term_psi_seed(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
    term_idx: usize,
    options: &SpatialLengthScaleOptimizationOptions,
) -> Option<f64> {
    // Terms that already fix a length scale need no psi seed.
    if get_spatial_length_scale(spec, term_idx).is_some() {
        return None;
    }
    // Start the search at the midpoint of the (data-informed) psi interval.
    let (lo, hi) = spatial_term_psi_bounds(data, spec, term_idx, options);
    Some(0.5 * (lo + hi))
}
fn spatial_term_psi_to_length_scale_and_aniso(
    spec: &TermCollectionSpec,
    term_idx: usize,
    psi: &[f64],
) -> (Option<f64>, Option<Vec<f64>>) {
    // Pure Duchon anisotropy: psi carries d-1 free contrasts; the last axis is
    // the negated sum so the contrasts stay centered. No global length scale.
    if is_pure_duchon_aniso_term(spec, term_idx) {
        let d = get_spatial_feature_dim(spec, term_idx).unwrap_or(psi.len());
        let free = d.saturating_sub(1);
        let mut eta = vec![0.0; d];
        for (slot, &value) in eta[..free].iter_mut().zip(psi.iter()) {
            *slot = value;
        }
        if d > 1 {
            eta[d - 1] = -eta[..d - 1].iter().sum::<f64>();
        }
        return (None, Some(eta));
    }
    // Isotropic cases: a single psi (or none) encodes length = exp(-psi).
    match psi {
        [] => (Some(1.0), None),
        [only] => (Some((-only).exp()), None),
        _ => {
            // Several psi entries: the mean fixes the global scale and the
            // residuals become per-axis log-scale contrasts.
            let psi_bar = psi.iter().sum::<f64>() / psi.len() as f64;
            let contrasts: Vec<f64> = psi.iter().map(|&value| value - psi_bar).collect();
            (Some((-psi_bar).exp()), Some(contrasts))
        }
    }
}
pub fn get_spatial_aniso_log_scales(
    spec: &TermCollectionSpec,
    term_idx: usize,
) -> Option<Vec<f64>> {
    // Only Matérn and Duchon bases carry per-axis log-scale contrasts.
    let term = spec.smooth_terms.get(term_idx)?;
    match &term.basis {
        SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales.clone(),
        SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales.clone(),
        _ => None,
    }
}
fn get_spatial_feature_dim(spec: &TermCollectionSpec, term_idx: usize) -> Option<usize> {
    // Spatial bases report their dimensionality as the number of feature columns.
    match &spec.smooth_terms.get(term_idx)?.basis {
        SmoothBasisSpec::ThinPlate { feature_cols, .. }
        | SmoothBasisSpec::Matern { feature_cols, .. }
        | SmoothBasisSpec::Duchon { feature_cols, .. } => Some(feature_cols.len()),
        _ => None,
    }
}
// Logs one multi-line info! entry per term that carries anisotropy contrasts,
// listing each axis's contrast and — when a global length scale exists — the
// implied per-axis length and kappa.
pub fn log_spatial_aniso_scales(spec: &TermCollectionSpec) {
    for (term_idx, term) in spec.smooth_terms.iter().enumerate() {
        // Matérn always has a concrete length scale; Duchon's is optional
        // (absent for pure shape anisotropy).
        let (aniso, length_scale) = match &term.basis {
            SmoothBasisSpec::Matern { spec, .. } => {
                (spec.aniso_log_scales.as_ref(), Some(spec.length_scale))
            }
            SmoothBasisSpec::Duchon { spec, .. } => {
                (spec.aniso_log_scales.as_ref(), spec.length_scale)
            }
            _ => (None, None),
        };
        let Some(eta) = aniso else { continue };
        if eta.is_empty() {
            continue;
        }
        // Header line; per-axis detail lines are appended below.
        let mut lines = match length_scale {
            Some(ls) => format!(
                "[spatial-kappa] term {} (\"{}\"): anisotropic length scales optimized (global length_scale={:.4})",
                term_idx, term.name, ls
            ),
            None => format!(
                "[spatial-kappa] term {} (\"{}\"): pure Duchon shape anisotropy optimized",
                term_idx, term.name
            ),
        };
        for (a, &eta_a) in eta.iter().enumerate() {
            if let Some(ls) = length_scale {
                // Per-axis length shrinks as exp(-eta); kappa = 1/length grows
                // as exp(+eta).
                let length_a = ls * (-eta_a).exp();
                let kappa_a = (1.0 / ls) * eta_a.exp();
                lines.push_str(&format!(
                    "\n axis {}: eta={:+.4}, length={:.4}, kappa={:.4}",
                    a, eta_a, length_a, kappa_a
                ));
            } else {
                lines.push_str(&format!("\n axis {}: eta={:+.4}", a, eta_a));
            }
        }
        log::info!("{}", lines);
    }
}
fn set_spatial_aniso_log_scales(
    spec: &mut TermCollectionSpec,
    term_idx: usize,
    eta: Vec<f64>,
) -> Result<(), EstimationError> {
    // Re-center so the contrasts sum to zero before storing them.
    let centered = center_aniso_log_scales(&eta);
    let term = spec.smooth_terms.get_mut(term_idx).ok_or_else(|| {
        EstimationError::InvalidInput(format!(
            "spatial aniso_log_scales term index {term_idx} out of range"
        ))
    })?;
    match &mut term.basis {
        SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales = Some(centered),
        SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales = Some(centered),
        _ => {
            return Err(EstimationError::InvalidInput(format!(
                "term '{}' does not support aniso_log_scales",
                term.name
            )));
        }
    }
    Ok(())
}
pub(crate) fn sync_aniso_contrasts_from_metadata(
    spec: &mut TermCollectionSpec,
    design: &SmoothDesign,
) {
    // Copy anisotropy contrasts from the built design metadata back into the
    // term spec so subsequent rebuilds start from the fitted values.
    for (term_idx, term) in design.terms.iter().enumerate() {
        let meta_aniso = match &term.metadata {
            BasisMetadata::Matern {
                aniso_log_scales, ..
            }
            | BasisMetadata::Duchon {
                aniso_log_scales, ..
            } => aniso_log_scales.clone(),
            _ => None,
        };
        // Single-axis contrasts are degenerate; only multi-axis ones are
        // synced. Failures (e.g. a non-spatial term at this index) are
        // deliberately ignored.
        if let Some(eta) = meta_aniso.filter(|eta| eta.len() > 1) {
            let _ = set_spatial_aniso_log_scales(spec, term_idx, eta);
        }
    }
}
/// Knobs for the outer optimization of spatial kernel length scales.
#[derive(Debug, Clone)]
pub struct SpatialLengthScaleOptimizationOptions {
    /// Master switch for the length-scale optimization.
    pub enabled: bool,
    /// Maximum number of outer iterations.
    pub max_outer_iter: usize,
    /// Relative convergence tolerance; must be > 0 and finite (see `validate`).
    pub rel_tol: f64,
    /// Step size in log(length-scale) space; must be > 0 and finite.
    pub log_step: f64,
    /// Smallest admissible length scale; caps psi = -ln(length) from above.
    pub min_length_scale: f64,
    /// Largest admissible length scale; caps psi from below.
    pub max_length_scale: f64,
    /// Row count above which pilot computations subsample the data —
    /// presumably; confirm against the optimizer that reads this field.
    pub pilot_subsample_threshold: usize,
}
impl Default for SpatialLengthScaleOptimizationOptions {
    /// Defaults: optimization enabled, a broad [1e-3, 1e3] length-scale
    /// window, ln(2)-sized steps, and a 10k-row pilot-subsample threshold.
    fn default() -> Self {
        Self {
            enabled: true,
            min_length_scale: 1e-3,
            max_length_scale: 1e3,
            // One step doubles or halves the length scale.
            log_step: std::f64::consts::LN_2,
            rel_tol: 1e-4,
            max_outer_iter: 80,
            pilot_subsample_threshold: 10_000,
        }
    }
}
impl SpatialLengthScaleOptimizationOptions {
    /// Checks that every numeric knob is usable, returning a human-readable
    /// message for the first violated requirement (in field-declaration order:
    /// min/max length scale, their ordering, rel_tol, log_step).
    pub fn validate(&self) -> Result<(), String> {
        // Shared check for the three "> 0 and finite" fields; the message
        // text matches the per-field originals exactly.
        fn require_positive_finite(name: &str, value: f64) -> Result<(), String> {
            if value.is_finite() && value > 0.0 {
                Ok(())
            } else {
                Err(format!(
                    "SpatialLengthScaleOptimizationOptions::{name} must be > 0 and finite, got {value}"
                ))
            }
        }
        require_positive_finite("min_length_scale", self.min_length_scale)?;
        require_positive_finite("max_length_scale", self.max_length_scale)?;
        if self.min_length_scale >= self.max_length_scale {
            return Err(format!(
                "SpatialLengthScaleOptimizationOptions requires min_length_scale < max_length_scale, got min={} max={}",
                self.min_length_scale, self.max_length_scale
            ));
        }
        require_positive_finite("rel_tol", self.rel_tol)?;
        require_positive_finite("log_step", self.log_step)
    }
}
/// Bookkeeping for one random-effect (factor) design block.
#[derive(Debug, Clone)]
struct RandomEffectBlock {
    /// Name of the random-effect term.
    name: String,
    /// For each data row, the kept-group index it maps to, or `None` when the
    /// row's level was dropped.
    group_ids: Vec<Option<usize>>,
    /// Number of distinct kept groups (design columns for this block).
    num_groups: usize,
    /// Encoded identifiers of the retained levels — presumably the original
    /// factor codes; confirm at the construction site.
    kept_levels: Vec<u64>,
}
/// Magnitudes at or below this threshold are treated as structural zeros when
/// counting/emitting sparse entries.
const BLOCK_SPARSE_ZERO_EPS: f64 = 1e-12;
/// Maximum fill fraction at which a purely dense set of blocks is still worth
/// assembling as a sparse matrix.
const BLOCK_SPARSE_MAX_DENSITY: f64 = 0.20;
fn blocks_have_intrinsic_sparse_structure(blocks: &[DesignBlock]) -> bool {
    // Explicitly sparse blocks and random-effect indicator blocks are sparse
    // by construction, regardless of their measured density.
    for block in blocks {
        if matches!(block, DesignBlock::Sparse(_) | DesignBlock::RandomEffect(_)) {
            return true;
        }
    }
    false
}
fn sparse_compatible_block_nnz(block: &DesignBlock) -> Option<usize> {
    // Nonzero count this block would contribute to a sparse assembly, or
    // `None` when the block cannot report one (unmaterialized dense block).
    match block {
        DesignBlock::Intercept(n) => Some(*n),
        DesignBlock::RandomEffect(op) => {
            // One indicator entry per row that still belongs to a kept group.
            Some(op.group_ids.iter().flatten().count())
        }
        DesignBlock::Sparse(sparse) => Some(sparse.val().len()),
        DesignBlock::Dense(dense) => {
            let matrix = dense.as_dense_ref()?;
            let nnz = matrix
                .iter()
                .filter(|entry| entry.abs() > BLOCK_SPARSE_ZERO_EPS)
                .count();
            Some(nnz)
        }
    }
}
// Attempts to assemble the concatenated design blocks as one CSC sparse
// matrix. Returns Ok(None) when sparse assembly is impossible or not
// worthwhile; the caller then falls back to a lazy block operator.
fn try_build_sparse_design_from_blocks(
    blocks: &[DesignBlock],
) -> Result<Option<DesignMatrix>, BasisError> {
    if blocks.is_empty() {
        return Ok(None);
    }
    let nrows = blocks[0].nrows();
    let ncols: usize = blocks.iter().map(DesignBlock::ncols).sum();
    // Tiny or empty designs gain nothing from sparse storage.
    if nrows == 0 || ncols == 0 || ncols <= 32 {
        return Ok(None);
    }
    // Blocks that are sparse by construction always keep sparse storage;
    // otherwise a global density cap decides whether assembly pays off.
    let preserve_sparse_storage = blocks_have_intrinsic_sparse_structure(blocks);
    let sparse_nnz_limit = if preserve_sparse_storage {
        usize::MAX
    } else {
        let total_cells = nrows.saturating_mul(ncols);
        ((total_cells as f64) * BLOCK_SPARSE_MAX_DENSITY).floor() as usize
    };
    // First pass: count nonzeros, bailing out early once the cap is exceeded
    // or a block cannot report a count (unmaterialized dense block).
    let mut nnz = 0usize;
    for block in blocks {
        let block_nnz = if let Some(block_nnz) = sparse_compatible_block_nnz(block) {
            block_nnz
        } else {
            return Ok(None);
        };
        nnz = nnz.saturating_add(block_nnz);
        if nnz > sparse_nnz_limit {
            return Ok(None);
        }
    }
    // Second pass: emit (row, col, value) triplets block by block, shifting
    // each block to its column offset in the concatenated design.
    let mut triplets = Vec::<Triplet<usize, usize, f64>>::with_capacity(nnz);
    let mut col_offset = 0usize;
    for block in blocks {
        match block {
            DesignBlock::Intercept(n) => {
                // Single all-ones column.
                for row in 0..*n {
                    triplets.push(Triplet::new(row, col_offset, 1.0));
                }
            }
            DesignBlock::RandomEffect(op) => {
                // Indicator matrix: one 1.0 per row with an assigned group.
                for (row, group_id) in op.group_ids.iter().enumerate() {
                    if let Some(group) = group_id {
                        triplets.push(Triplet::new(row, col_offset + group, 1.0));
                    }
                }
            }
            DesignBlock::Sparse(sparse) => {
                // Walk the CSC structure, dropping explicit near-zero entries.
                let (symbolic, values) = sparse.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                for col in 0..sparse.ncols() {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let value = values[idx];
                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
                            triplets.push(Triplet::new(row_idx[idx], col_offset + col, value));
                        }
                    }
                }
            }
            DesignBlock::Dense(dense) => {
                // The nnz-counting pass already required a materialized dense
                // block, so this error is defensive.
                let matrix = dense.as_dense_ref().ok_or_else(|| {
                    BasisError::InvalidInput(
                        "sparse-compatible block assembly requires materialized dense blocks"
                            .to_string(),
                    )
                })?;
                for row in 0..matrix.nrows() {
                    for col in 0..matrix.ncols() {
                        let value = matrix[[row, col]];
                        if value.abs() > BLOCK_SPARSE_ZERO_EPS {
                            triplets.push(Triplet::new(row, col_offset + col, value));
                        }
                    }
                }
            }
        }
        col_offset += block.ncols();
    }
    let sparse = SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).map_err(|_| {
        BasisError::SparseCreation("failed to assemble sparse term-collection design".to_string())
    })?;
    Ok(Some(DesignMatrix::Sparse(
        crate::matrix::SparseDesignMatrix::new(sparse),
    )))
}
fn assemble_term_collection_design_matrix(
    blocks: Vec<DesignBlock>,
) -> Result<DesignMatrix, BasisError> {
    // Prefer a concrete sparse matrix when the blocks allow it; otherwise wrap
    // the blocks in a lazy block-design operator presented as a dense design.
    match try_build_sparse_design_from_blocks(&blocks)? {
        Some(sparse) => Ok(sparse),
        None => {
            let block_op = BlockDesignOperator::new(blocks).map_err(|e| {
                BasisError::InvalidInput(format!("failed to build block design operator: {e}"))
            })?;
            Ok(DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                Arc::new(block_op),
            )))
        }
    }
}
fn compute_spatial_input_scales(x: ArrayView2<'_, f64>) -> Option<Vec<f64>> {
    // Per-column sample standard deviations (ddof = 1), floored at 1e-12 so
    // later divisions stay safe. Returns `None` for 1-D inputs or fewer than
    // two rows, where standardization is skipped.
    let d = x.ncols();
    let n = x.nrows() as f64;
    if d <= 1 || n < 2.0 {
        return None;
    }
    let scales = (0..d)
        .map(|j| {
            let col = x.column(j);
            let mean = col.sum() / n;
            let ss: f64 = col.iter().map(|&v| (v - mean) * (v - mean)).sum();
            (ss / (n - 1.0)).sqrt().max(1e-12)
        })
        .collect();
    Some(scales)
}
fn apply_input_standardization(x: &mut Array2<f64>, scales: &[f64]) {
    // Divide every column by its scale (via the reciprocal, as the scales are
    // strictly positive). Panics if `scales` has fewer entries than columns.
    let ncols = x.ncols();
    for j in 0..ncols {
        let inv = scales[j].recip();
        x.column_mut(j).mapv_inplace(|v| v * inv);
    }
}
fn geometric_mean_scale(scales: &[f64]) -> f64 {
    // Geometric mean computed in the log domain; the empty slice yields the
    // neutral scale 1.0.
    match scales {
        [] => 1.0,
        _ => {
            let mean_log = scales.iter().map(|s| s.ln()).sum::<f64>() / scales.len() as f64;
            mean_log.exp()
        }
    }
}
fn compensate_length_scale_for_standardization(length_scale: f64, scales: &[f64]) -> f64 {
    // Inputs get divided by their per-axis scales, so a user-supplied length
    // scale must shrink by the geometric-mean scale to retain its meaning in
    // standardized coordinates. Degenerate geometric means leave it unchanged.
    let sigma_geom = geometric_mean_scale(scales);
    if sigma_geom.is_finite() && sigma_geom > 0.0 {
        length_scale / sigma_geom
    } else {
        length_scale
    }
}
fn compensate_optional_length_scale_for_standardization(
    length_scale: Option<f64>,
    scales: &[f64],
) -> Option<f64> {
    // Option-lifted form of the scalar compensation above.
    let ls = length_scale?;
    Some(compensate_length_scale_for_standardization(ls, scales))
}
fn select_columns(data: ArrayView2<'_, f64>, cols: &[usize]) -> Result<Array2<f64>, BasisError> {
    // Gathers the requested columns (in the given order; repeats allowed) into
    // a freshly owned matrix, validating every index up front.
    let p = data.ncols();
    if let Some(&bad) = cols.iter().find(|&&c| c >= p) {
        return Err(BasisError::DimensionMismatch(format!(
            "feature column {bad} is out of bounds for data with {p} columns"
        )));
    }
    let mut out = Array2::<f64>::zeros((data.nrows(), cols.len()));
    for (dst, &src) in cols.iter().enumerate() {
        out.column_mut(dst).assign(&data.column(src));
    }
    Ok(out)
}
/// Key grouping spatial smooth terms that may share one center set: equal keys
/// mean the terms see the same (standardized) inputs and request the same
/// center-selection procedure, so planning centers once is safe.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct JointSpatialCenterGroupKey {
    /// Data columns the smooth is defined over.
    feature_cols: Vec<usize>,
    /// Which center-selection algorithm is requested.
    strategy_kind: CenterStrategyKind,
    /// Algorithm-specific knob (k-means max_iter or grid points-per-dim; 0 otherwise).
    strategy_aux: usize,
    /// Number of centers the term asked for.
    requested_num_centers: usize,
    /// Bit patterns of the per-axis input scales (`None` when auto-computed);
    /// stored as bits so the key can derive `Eq`/`Ord` despite f64 fields.
    input_scale_bits: Option<Vec<u64>>,
}
fn spatial_term_min_center_count(term: &SmoothTermSpec) -> usize {
    // Minimum number of centers a spatial term needs — driven by the size of
    // its polynomial nullspace; everything else degrades to 1.
    match &term.basis {
        SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.len() + 1,
        SmoothBasisSpec::Duchon {
            feature_cols, spec, ..
        } => {
            let d = feature_cols.len();
            match spec.nullspace_order {
                crate::basis::DuchonNullspaceOrder::Zero => 1,
                crate::basis::DuchonNullspaceOrder::Linear => d + 1,
                crate::basis::DuchonNullspaceOrder::Degree(degree) => {
                    crate::basis::duchon_nullspace_dimension(d, degree)
                }
            }
        }
        // Matérn and non-spatial bases have no such requirement.
        _ => 1,
    }
}
fn spatial_term_group_key(term: &SmoothTermSpec) -> Option<JointSpatialCenterGroupKey> {
    // Pull out the pieces that decide whether two spatial terms can share
    // centers; non-spatial bases yield no key.
    let (feature_cols, strategy, input_scales) = match &term.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            spec,
            input_scales,
        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
        SmoothBasisSpec::Matern {
            feature_cols,
            spec,
            input_scales,
        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
        SmoothBasisSpec::Duchon {
            feature_cols,
            spec,
            input_scales,
        } => (feature_cols, &spec.center_strategy, input_scales.as_ref()),
        _ => return None,
    };
    // The auxiliary knob lives at most one `Auto` level down: unwrap that
    // level once, then read the strategy-specific parameter (0 when absent).
    let base = match strategy {
        CenterStrategy::Auto(inner) => inner.as_ref(),
        other => other,
    };
    let strategy_aux = match base {
        CenterStrategy::KMeans { max_iter, .. } => *max_iter,
        CenterStrategy::UniformGrid { points_per_dim } => *points_per_dim,
        _ => 0,
    };
    // f64 scales are keyed by their bit patterns so the key stays Eq/Ord.
    let input_scale_bits =
        input_scales.map(|values| values.iter().map(|value| value.to_bits()).collect());
    Some(JointSpatialCenterGroupKey {
        feature_cols: feature_cols.clone(),
        strategy_kind: center_strategy_kind(strategy),
        strategy_aux,
        requested_num_centers: center_strategy_num_centers(strategy)?,
        input_scale_bits,
    })
}
fn spatial_term_center_strategy(term: &SmoothTermSpec) -> Option<&CenterStrategy> {
    // Only the radial-basis smooths carry a center-selection strategy.
    let strategy = match &term.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => &spec.center_strategy,
        SmoothBasisSpec::Matern { spec, .. } => &spec.center_strategy,
        SmoothBasisSpec::Duchon { spec, .. } => &spec.center_strategy,
        _ => return None,
    };
    Some(strategy)
}
fn set_spatial_term_centers(
    term: &mut SmoothTermSpec,
    centers: Array2<f64>,
) -> Result<(), BasisError> {
    // Pin the term to an explicit center set, replacing whatever selection
    // strategy it used before.
    let pinned = CenterStrategy::UserProvided(centers);
    match &mut term.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => spec.center_strategy = pinned,
        SmoothBasisSpec::Matern { spec, .. } => spec.center_strategy = pinned,
        SmoothBasisSpec::Duchon { spec, .. } => spec.center_strategy = pinned,
        _ => {
            return Err(BasisError::InvalidInput(format!(
                "term '{}' does not support spatial center planning",
                term.name
            )));
        }
    }
    Ok(())
}
fn standardized_spatial_term_data(
    data: ArrayView2<'_, f64>,
    term: &SmoothTermSpec,
) -> Result<Array2<f64>, BasisError> {
    // Slice out the term's feature columns and standardize them: explicit
    // `input_scales` take precedence, otherwise scales are estimated from the
    // data; when neither is available the columns pass through untouched.
    let (feature_cols, input_scales) = match &term.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            input_scales,
            ..
        }
        | SmoothBasisSpec::Matern {
            feature_cols,
            input_scales,
            ..
        }
        | SmoothBasisSpec::Duchon {
            feature_cols,
            input_scales,
            ..
        } => (feature_cols, input_scales.as_ref()),
        _ => {
            return Err(BasisError::InvalidInput(format!(
                "term '{}' is not a spatial smooth",
                term.name
            )));
        }
    };
    let mut x = select_columns(data, feature_cols)?;
    let scales = match input_scales {
        Some(explicit) => Some(explicit.clone()),
        None => compute_spatial_input_scales(x.view()),
    };
    if let Some(scales) = scales {
        apply_input_standardization(&mut x, &scales);
    }
    Ok(x)
}
// Plans SHARED spatial centers: terms over the same feature columns with
// identical auto center-selection settings get one jointly computed center
// set, pinned into every member as `UserProvided` centers. Terms with
// user-pinned centers or unsupported strategies are left untouched.
fn plan_joint_spatial_centers_for_term_blocks(
    data: ArrayView2<'_, f64>,
    term_blocks: &[Vec<SmoothTermSpec>],
) -> Result<Vec<Vec<SmoothTermSpec>>, BasisError> {
    let mut planned_blocks = term_blocks.to_vec();
    let n = data.nrows();
    // BTreeMap keeps group iteration deterministic across runs.
    let mut groups: BTreeMap<JointSpatialCenterGroupKey, Vec<(usize, usize)>> = BTreeMap::new();
    for (block_idx, terms) in planned_blocks.iter().enumerate() {
        for (term_idx, term) in terms.iter().enumerate() {
            let Some(strategy) = spatial_term_center_strategy(term) else {
                continue;
            };
            // Only auto strategies may be re-planned.
            if !center_strategy_is_auto(strategy) {
                continue;
            }
            let Some(group_key) = spatial_term_group_key(term) else {
                continue;
            };
            // Restrict sharing to the data-driven selection algorithms.
            if !matches!(
                group_key.strategy_kind,
                CenterStrategyKind::EqualMass
                | CenterStrategyKind::EqualMassCovarRepresentative
                | CenterStrategyKind::FarthestPoint
                | CenterStrategyKind::KMeans
            ) {
                continue;
            }
            if center_strategy_num_centers(strategy).is_none() {
                continue;
            }
            groups
                .entry(group_key)
                .or_default()
                .push((block_idx, term_idx));
        }
    }
    for (group_key, members) in groups {
        // Sharing only helps when at least two terms agree.
        if members.len() < 2 {
            continue;
        }
        // Every member must remain representable: take the largest minimum
        // center count (nullspace requirement) over the group.
        let min_required = members
            .iter()
            .map(|&(block_idx, term_idx)| {
                spatial_term_min_center_count(&planned_blocks[block_idx][term_idx])
            })
            .max()
            .unwrap_or(1);
        // Honor the request, raised to the minimum and capped by the data size.
        let joint_centers = group_key
            .requested_num_centers
            .max(min_required)
            .min(n.max(1));
        // Any member works as the prototype: equal keys imply identical
        // standardized inputs and strategy settings.
        let (first_block_idx, first_term_idx) = members[0];
        let prototype = &planned_blocks[first_block_idx][first_term_idx];
        let standardized = standardized_spatial_term_data(data, prototype)?;
        let strategy = spatial_term_center_strategy(prototype).ok_or_else(|| {
            BasisError::InvalidInput(format!(
                "term '{}' lost its spatial center strategy during joint planning",
                prototype.name
            ))
        })?;
        let joint_strategy = center_strategy_with_num_centers(strategy, joint_centers)?;
        let shared_centers = select_centers_by_strategy(standardized.view(), &joint_strategy)?;
        log::info!(
            "sharing {} spatial centers across {} smooth terms over columns {:?} (requested {} centers)",
            shared_centers.nrows(),
            members.len(),
            group_key.feature_cols,
            group_key.requested_num_centers,
        );
        for (block_idx, term_idx) in members {
            set_spatial_term_centers(
                &mut planned_blocks[block_idx][term_idx],
                shared_centers.clone(),
            )?;
        }
    }
    Ok(planned_blocks)
}
fn linear_conditioning_chunk_rows(n: usize, p: usize) -> usize {
    // Choose a row-chunk size so a chunk of an n×p f64 design stays near
    // 8 MiB, clamped to [256, 65536] rows and never exceeding the row count.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 256;
    const MAX_ROWS: usize = 65_536;
    let row_cap = n.max(1);
    if p == 0 {
        // Degenerate width: one chunk covers everything.
        return row_cap;
    }
    let bytes_per_row = p * 8;
    (TARGET_BYTES / bytes_per_row)
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(row_cap)
}
// Centering/scaling ("conditioning") of selected linear design columns.
//
// Notation used below: let A be the p×p map that is the identity except, for
// each conditioned column j, A[j,j] = 1/scale_j and
// A[intercept, j] = -mean_j/scale_j, so that (with an all-ones intercept
// column) X_internal = X_original · A and, per `backtransform_beta`,
// beta_original = A · beta_internal.
impl LinearFitConditioning {
    /// Computes per-column mean and standard deviation for `selected_cols`,
    /// streaming the design in row chunks so block/lazy-backed designs are
    /// never fully materialized. Columns with non-finite or near-zero
    /// variance fall back to identity conditioning (mean 0, scale 1).
    /// Note: the variance here divides by n (population), unlike the n-1 used
    /// in `compute_spatial_input_scales`.
    fn from_columns(design: &TermCollectionDesign, selected_cols: &[usize]) -> Self {
        const SCALE_EPS: f64 = 1e-12;
        let n = design.design.nrows();
        let p = design.design.ncols();
        let mut columns = Vec::with_capacity(selected_cols.len());
        if n == 0 || selected_cols.is_empty() {
            return Self {
                intercept_idx: design.intercept_range.start,
                columns,
            };
        }
        let chunk_rows = linear_conditioning_chunk_rows(n, p);
        // First streaming pass: column sums -> means.
        let mut sums = vec![0.0_f64; selected_cols.len()];
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let chunk = design
                .design
                .try_row_chunk(start..end)
                .expect("LinearFitConditioning::from_columns row chunk failed");
            for (k, &col_idx) in selected_cols.iter().enumerate() {
                let column = chunk.column(col_idx);
                for &v in column.iter() {
                    sums[k] += v;
                }
            }
        }
        let inv_n = 1.0_f64 / n as f64;
        let means: Vec<f64> = sums.iter().map(|&s| s * inv_n).collect();
        // Second streaming pass: squared deviations -> variances.
        let mut sq_devs = vec![0.0_f64; selected_cols.len()];
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let chunk = design
                .design
                .try_row_chunk(start..end)
                .expect("LinearFitConditioning::from_columns row chunk failed");
            for (k, &col_idx) in selected_cols.iter().enumerate() {
                let mean_k = means[k];
                let column = chunk.column(col_idx);
                for &v in column.iter() {
                    let d = v - mean_k;
                    sq_devs[k] += d * d;
                }
            }
        }
        for (k, &col_idx) in selected_cols.iter().enumerate() {
            let mean = means[k];
            let var = sq_devs[k] * inv_n;
            // Degenerate columns get identity conditioning instead of a
            // near-zero divisor.
            let (mean, scale) = if var.is_finite() && var > SCALE_EPS * SCALE_EPS {
                (mean, var.sqrt())
            } else {
                (0.0, 1.0)
            };
            columns.push(LinearColumnConditioning {
                col_idx,
                mean,
                scale,
            });
        }
        Self {
            intercept_idx: design.intercept_range.start,
            columns,
        }
    }
    /// Returns a copy of `design` with each conditioned column centered
    /// (scalar mean subtracted) and divided by its scale.
    fn apply_to_design(&self, design: &Array2<f64>) -> Array2<f64> {
        let mut out = design.clone();
        for col in &self.columns {
            {
                let mut dst = out.column_mut(col.col_idx);
                dst -= col.mean;
            }
            if col.scale != 1.0 {
                out.column_mut(col.col_idx).mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }
    /// Right-multiplies `mat` by A via column operations: for each conditioned
    /// column j, col_j <- (col_j - mean_j * col_intercept) / scale_j. The
    /// intercept column itself is never a conditioned column, so re-reading it
    /// each iteration is safe.
    fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            let intercept_col = out.column(intercept).to_owned();
            let mut target = out.column_mut(col.col_idx);
            target -= &(intercept_col * col.mean);
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }
    /// Left-multiplies `mat` by A-transpose via row operations:
    /// row_j <- (row_j - mean_j * row_intercept) / scale_j.
    fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            let interceptrow = out.row(intercept).to_owned();
            let mut target = out.row_mut(col.col_idx);
            target -= &(interceptrow * col.mean);
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v / col.scale);
            }
        }
        out
    }
    /// Column-space "B" map, intended as the inverse of the `_a` transform:
    /// col_j <- (col_j + mean_j * col_intercept) * scale_j.
    /// NOTE(review): the mean term is added BEFORE scaling, so composing this
    /// with the `_a` op gives col + mean*(scale-1)*intercept rather than the
    /// identity whenever scale != 1 (the exact inverse would scale first, then
    /// add mean*intercept). Confirm whether this ordering is intentional.
    fn transform_matrix_columnswith_b(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            let intercept_col = out.column(intercept).to_owned();
            let mut target = out.column_mut(col.col_idx);
            if col.mean != 0.0 {
                target += &(intercept_col * col.mean);
            }
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v * col.scale);
            }
        }
        out
    }
    /// Row-space counterpart of `transform_matrix_columnswith_b`:
    /// row_j <- (row_j + mean_j * row_intercept) * scale_j.
    /// NOTE(review): same mean-before-scale ordering caveat as the column
    /// version above.
    fn transform_matrixrowswith_b_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            let interceptrow = out.row(intercept).to_owned();
            let mut target = out.row_mut(col.col_idx);
            if col.mean != 0.0 {
                target += &(interceptrow * col.mean);
            }
            if col.scale != 1.0 {
                target.mapv_inplace(|v| v * col.scale);
            }
        }
        out
    }
    /// Maps blockwise penalties into the conditioned (internal) basis.
    /// Penalties whose column range touches a conditioned column are expanded
    /// to a dense global matrix and transformed as S_internal = Aᵀ S A (the
    /// congruence matching beta_original = A beta_internal); untouched blocks
    /// pass through as blockwise specs.
    fn transform_blockwise_penalties_to_internal(
        &self,
        penalties: &[BlockwisePenalty],
        p: usize,
    ) -> Vec<PenaltySpec> {
        let conditioning_cols: std::collections::HashSet<usize> =
            self.columns.iter().map(|c| c.col_idx).collect();
        penalties
            .iter()
            .map(|bp| {
                let overlaps =
                    (bp.col_range.start..bp.col_range.end).any(|j| conditioning_cols.contains(&j));
                if overlaps {
                    let global = bp.to_global(p);
                    let right = self.transform_matrix_columnswith_a(&global);
                    let transformed = self.transform_matrixrowswith_a_transpose(&right);
                    PenaltySpec::Dense(transformed)
                } else {
                    PenaltySpec::from_blockwise(bp.clone())
                }
            })
            .collect()
    }
    /// Maps internal coefficients back to the original basis:
    /// beta_orig[j] = beta_int[j] / scale_j for each conditioned column, and
    /// the intercept absorbs the centering offsets
    /// (beta_orig[intercept] -= mean_j * beta_int[j] / scale_j). This is
    /// beta_original = A beta_internal.
    fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
        let mut beta = beta_internal.clone();
        let intercept = self.intercept_idx;
        for col in &self.columns {
            beta[intercept] -= beta_internal[col.col_idx] * col.mean / col.scale;
            beta[col.col_idx] = beta_internal[col.col_idx] / col.scale;
        }
        beta
    }
    /// Congruence transform of the penalized Hessian using the `_b` maps
    /// (Bᵀ H_internal B); see the ordering note on the `_b` helpers.
    fn transform_penalized_hessian_to_original(&self, h_internal: &Array2<f64>) -> Array2<f64> {
        let right = self.transform_matrix_columnswith_b(h_internal);
        self.transform_matrixrowswith_b_transpose(&right)
    }
    /// Congruence transform of the covariance using the `_a` maps, i.e.
    /// Aᵀ Σ_internal A.
    /// NOTE(review): with beta_original = A beta_internal (see
    /// `backtransform_beta`), the pushed-forward covariance would be
    /// A Σ Aᵀ, which differs from Aᵀ Σ A for this non-symmetric A. Verify the
    /// intended convention before relying on intercept rows/columns of the
    /// result.
    fn backtransform_covariance(&self, cov_internal: &Array2<f64>) -> Array2<f64> {
        let right = self.transform_matrix_columnswith_a(cov_internal);
        self.transform_matrixrowswith_a_transpose(&right)
    }
    /// Maps original-space box bounds for one coefficient into internal space:
    /// beta_internal = scale * beta_original for conditioned columns, so both
    /// bounds are multiplied by the column's scale; unconditioned columns pass
    /// through unchanged.
    fn internal_bounds_for(&self, col_idx: usize, min: f64, max: f64) -> (f64, f64) {
        if let Some(col) = self.columns.iter().find(|c| c.col_idx == col_idx) {
            (min * col.scale, max * col.scale)
        } else {
            (min, max)
        }
    }
}
fn cumulative_exp(values: &Array1<f64>, sign: f64) -> Array1<f64> {
    // out[i] = sign * sum_{k <= i} exp(values[k]): a signed running sum of
    // exponentials (guaranteed monotone in magnitude since exp > 0).
    let mut running = 0.0;
    let prefix: Vec<f64> = values
        .iter()
        .map(|&v| {
            running += v.exp();
            sign * running
        })
        .collect();
    Array1::from_vec(prefix)
}
fn second_cumulative_exp(values: &Array1<f64>, sign: f64) -> Array1<f64> {
    // Running sum of the signed cumulative-exp sequence — the discrete
    // analogue of integrating twice.
    let first = cumulative_exp(values, sign);
    let mut running = 0.0;
    let doubled: Vec<f64> = first
        .iter()
        .map(|&v| {
            running += v;
            running
        })
        .collect();
    Array1::from_vec(doubled)
}
fn cumulative_sum_transform_matrix(dim: usize, order: usize, sign: f64) -> Array2<f64> {
    // T = sign * L^order, where L is the dim×dim lower-triangular matrix of
    // ones (the discrete cumulative-sum operator). Applying T to a coefficient
    // vector cumulative-sums it `order` times and optionally flips the sign.
    //
    // Fix: the original rebuilt L inside the loop on every iteration even
    // though it never changes; construct it once and reuse it.
    let mut lower = Array2::<f64>::zeros((dim, dim));
    for i in 0..dim {
        for j in 0..=i {
            lower[[i, j]] = 1.0;
        }
    }
    let mut t = Array2::<f64>::eye(dim);
    for _ in 0..order {
        t = t.dot(&lower);
    }
    if sign < 0.0 {
        t.mapv_inplace(|v| -v);
    }
    t
}
fn shape_order_and_sign(shape: ShapeConstraint) -> Option<(usize, f64)> {
    // Encode a shape constraint as (difference order, direction sign):
    // monotonicity constrains first differences, convexity/concavity second
    // differences; the sign picks increasing/convex (+) vs decreasing/concave (-).
    let encoded = match shape {
        ShapeConstraint::None => return None,
        ShapeConstraint::MonotoneIncreasing => (1, 1.0),
        ShapeConstraint::MonotoneDecreasing => (1, -1.0),
        ShapeConstraint::Convex => (2, 1.0),
        ShapeConstraint::Concave => (2, -1.0),
    };
    Some(encoded)
}
fn shape_lower_bounds_local(shape: ShapeConstraint, dim: usize) -> Option<Array1<f64>> {
    // In the cumulative-sum reparameterization the first `order` coefficients
    // stay free (level/slope) while the remaining ones must be non-negative.
    let (order, _) = shape_order_and_sign(shape)?;
    let bounds = Array1::from_shape_fn(dim, |j| if j < order { f64::NEG_INFINITY } else { 0.0 });
    Some(bounds)
}
fn shape_supports_basis(term: &SmoothTermSpec) -> bool {
    // Tensor-product B-splines are the only basis without shape support.
    match term.basis {
        SmoothBasisSpec::TensorBSpline { .. } => false,
        _ => true,
    }
}
// Pins the identifiability transform of raw (not-yet-constrained) thin-plate
// and Duchon metadata to the identity, so later rebuilds reproduce the raw
// basis columns unchanged. Metadata that already carries a transform — and
// every other basis kind — passes through untouched.
fn freeze_raw_spatial_metadata(metadata: BasisMetadata, raw_cols: usize) -> BasisMetadata {
    match metadata {
        BasisMetadata::ThinPlate {
            centers,
            length_scale,
            identifiability_transform: None,
            input_scales,
            radial_reparam,
        } => BasisMetadata::ThinPlate {
            centers,
            length_scale,
            // Identity transform of the raw column count = "frozen, unchanged".
            identifiability_transform: Some(Array2::eye(raw_cols)),
            input_scales,
            radial_reparam,
        },
        BasisMetadata::Duchon {
            centers,
            length_scale,
            power,
            nullspace_order,
            identifiability_transform: None,
            input_scales,
            aniso_log_scales,
        } => BasisMetadata::Duchon {
            centers,
            length_scale,
            power,
            nullspace_order,
            identifiability_transform: Some(Array2::eye(raw_cols)),
            input_scales,
            aniso_log_scales,
        },
        other => other,
    }
}
// Rebuilds the Matérn collocation operator penalties (mass / tension /
// stiffness) from stored metadata, symmetrizes and normalizes each Gram
// matrix, and filters out candidates rejected by the active-penalty filter.
fn matern_operator_penalty_triplet_from_metadata(
    metadata: &BasisMetadata,
) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
    let BasisMetadata::Matern {
        centers,
        length_scale,
        nu,
        include_intercept,
        identifiability_transform,
        aniso_log_scales,
        ..
    } = metadata
    else {
        return Err(BasisError::InvalidInput(
            "Matérn operator penalties require Matérn metadata".to_string(),
        ));
    };
    let ops = build_matern_collocation_operator_matrices(
        centers.view(),
        None,
        *length_scale,
        *nu,
        *include_intercept,
        identifiability_transform.as_ref().map(|z| z.view()),
        aniso_log_scales.as_deref(),
    )?;
    let mut candidates = Vec::with_capacity(3);
    // Gram matrices DᵀD for the 0th/1st/2nd-order operator matrices.
    for (raw, source) in [
        (ops.d0.t().dot(&ops.d0), PenaltySource::OperatorMass),
        (ops.d1.t().dot(&ops.d1), PenaltySource::OperatorTension),
        (ops.d2.t().dot(&ops.d2), PenaltySource::OperatorStiffness),
    ] {
        // Force exact symmetry before normalization (dot products can drift).
        let sym = (&raw + &raw.t()) * 0.5;
        let (matrix, normalization_scale) = normalize_penalty_in_constrained_space(&sym);
        candidates.push(PenaltyCandidate {
            matrix,
            // Operator Gram matrices are taken as full rank here.
            nullspace_dim_hint: 0,
            source,
            normalization_scale,
            kronecker_factors: None,
            op: None,
        });
    }
    filter_active_penalty_candidates(candidates)
}
fn shape_uses_box_reparameterization(basis: &SmoothBasisSpec) -> bool {
    // Only 1-D B-splines get the box-constrained treatment; other bases rely
    // on linear inequality constraints instead.
    if let SmoothBasisSpec::BSpline1D { .. } = basis {
        return true;
    }
    false
}
// Builds a uniform evaluation grid spanning [min(x), max(x)] on which shape
// (monotonicity/curvature) constraints are discretized. Requires non-empty,
// finite input with at least two distinct values.
fn build_shape_constraint_grid_1d(x: ArrayView1<'_, f64>) -> Result<Array1<f64>, BasisError> {
    if x.is_empty() {
        return Err(BasisError::InvalidInput(
            "shape-constrained smooth requires non-empty covariate values".to_string(),
        ));
    }
    if x.iter().any(|v| !v.is_finite()) {
        return Err(BasisError::InvalidInput(
            "shape-constrained smooth requires finite covariate values".to_string(),
        ));
    }
    let mut x_sorted: Vec<f64> = x.iter().copied().collect();
    // NaNs were rejected above, so partial_cmp never actually falls back.
    x_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    // Deduplicate with a relative tolerance so numerically equal values
    // collapse to one location.
    let mut x_unique: Vec<f64> = Vec::with_capacity(x_sorted.len());
    let mut last: Option<f64> = None;
    for v in x_sorted {
        let take = match last {
            None => true,
            Some(prev) => (v - prev).abs() > 1e-12 * prev.abs().max(v.abs()).max(1.0),
        };
        if take {
            x_unique.push(v);
            last = Some(v);
        }
    }
    if x_unique.len() < 2 {
        return Err(BasisError::InvalidInput(
            "shape-constrained smooth requires at least two unique covariate values".to_string(),
        ));
    }
    let min_x = x_unique[0];
    let max_x = *x_unique
        .last()
        .expect("x_unique has at least two elements by construction");
    if (max_x - min_x).abs() <= 1e-12 {
        return Err(BasisError::InvalidInput(
            "shape-constrained smooth requires non-degenerate covariate range".to_string(),
        ));
    }
    // Grid resolution tracks the unique-point count, bounded to [96, 320] so
    // the constraint system stays both informative and small.
    let target_points = x_unique.len().clamp(96, 320);
    let mut grid = Array1::<f64>::zeros(target_points);
    let denom = (target_points - 1) as f64;
    for i in 0..target_points {
        let t = i as f64 / denom;
        grid[i] = min_x + t * (max_x - min_x);
    }
    Ok(grid)
}
// Rebuilds the term's basis on a fresh uniform 1-D grid over `axis_col`,
// using the FROZEN metadata (knots / centers / identifiability transform) so
// the grid design lives in exactly the same coefficient space as the fitted
// design. Returns (grid points, grid design matrix).
fn build_shape_constraint_design_1d(
    data: ArrayView2<'_, f64>,
    term: &SmoothTermSpec,
    metadata: &BasisMetadata,
    axis_col: usize,
) -> Result<(Array1<f64>, Array2<f64>), BasisError> {
    let x_grid = build_shape_constraint_grid_1d(data.column(axis_col))?;
    // The spatial builders expect an (n, 1) matrix rather than a vector.
    let grid_2d = x_grid
        .clone()
        .into_shape_with_order((x_grid.len(), 1))
        .map_err(|e| {
            BasisError::InvalidInput(format!(
                "failed to construct 1D shape grid matrix for term '{}': {e}",
                term.name
            ))
        })?;
    // Every arm pins the stored metadata (Provided knots / UserProvided
    // centers / FrozenTransform) and disables double_penalty — only the
    // design matrix is needed here.
    let design = match (&term.basis, metadata) {
        (
            SmoothBasisSpec::BSpline1D { spec, .. },
            BasisMetadata::BSpline1D {
                knots,
                identifiability_transform,
            },
        ) => {
            let evalspec = BSplineBasisSpec {
                degree: spec.degree,
                penalty_order: spec.penalty_order,
                knotspec: BSplineKnotSpec::Provided(knots.clone()),
                double_penalty: false,
                identifiability: identifiability_transform
                    .as_ref()
                    .map(|z| BSplineIdentifiability::FrozenTransform {
                        transform: z.clone(),
                    })
                    .unwrap_or(BSplineIdentifiability::None),
            };
            build_bspline_basis_1d(x_grid.view(), &evalspec)?
                .design
                .to_dense()
        }
        (
            SmoothBasisSpec::ThinPlate { .. },
            BasisMetadata::ThinPlate {
                centers,
                length_scale,
                identifiability_transform,
                radial_reparam,
                ..
            },
        ) => {
            let evalspec = ThinPlateBasisSpec {
                center_strategy: crate::basis::CenterStrategy::UserProvided(centers.clone()),
                length_scale: *length_scale,
                double_penalty: false,
                identifiability: identifiability_transform
                    .as_ref()
                    .map(|z| SpatialIdentifiability::FrozenTransform {
                        transform: z.clone(),
                    })
                    .unwrap_or(SpatialIdentifiability::None),
                radial_reparam: radial_reparam.clone(),
            };
            build_thin_plate_basis(grid_2d.view(), &evalspec)?
                .design
                .to_dense()
        }
        (
            SmoothBasisSpec::Matern { .. },
            BasisMetadata::Matern {
                centers,
                length_scale,
                nu,
                include_intercept,
                identifiability_transform,
                aniso_log_scales,
                ..
            },
        ) => {
            let ident = identifiability_transform
                .as_ref()
                .map(|z| MaternIdentifiability::FrozenTransform {
                    transform: z.clone(),
                })
                .unwrap_or(MaternIdentifiability::None);
            let evalspec = MaternBasisSpec {
                center_strategy: crate::basis::CenterStrategy::UserProvided(centers.clone()),
                length_scale: *length_scale,
                nu: *nu,
                include_intercept: *include_intercept,
                double_penalty: false,
                identifiability: ident,
                aniso_log_scales: aniso_log_scales.clone(),
            };
            build_matern_basis(grid_2d.view(), &evalspec)?
                .design
                .to_dense()
        }
        (
            SmoothBasisSpec::Duchon { spec, .. },
            BasisMetadata::Duchon {
                centers,
                length_scale,
                power,
                nullspace_order,
                identifiability_transform,
                aniso_log_scales,
                ..
            },
        ) => {
            let evalspec = DuchonBasisSpec {
                center_strategy: crate::basis::CenterStrategy::UserProvided(centers.clone()),
                length_scale: *length_scale,
                power: *power,
                nullspace_order: *nullspace_order,
                // Unlike the other arms, absence of a frozen transform falls
                // back to the term's own identifiability spec.
                identifiability: identifiability_transform
                    .as_ref()
                    .map(|z| SpatialIdentifiability::FrozenTransform {
                        transform: z.clone(),
                    })
                    .unwrap_or_else(|| spec.identifiability.clone()),
                aniso_log_scales: aniso_log_scales.clone(),
                operator_penalties: spec.operator_penalties.clone(),
            };
            build_duchon_basis(grid_2d.view(), &evalspec)?
                .design
                .to_dense()
        }
        _ => {
            return Err(BasisError::InvalidInput(format!(
                "shape-constraint grid reconstruction metadata mismatch for term '{}'",
                term.name
            )));
        }
    };
    Ok((x_grid, design))
}
// Builds linear inequality constraints A·beta >= 0 enforcing the shape
// (monotone: first differences, convex/concave: second differences, with the
// sign flipped for the decreasing/concave variants) of the fitted curve at
// the given covariate locations. Rows of `design_local` at numerically equal
// x values are averaged into one location before differencing.
fn build_shape_linear_constraints_1d(
    x: ArrayView1<'_, f64>,
    design_local: ArrayView2<'_, f64>,
    shape: ShapeConstraint,
) -> Result<Option<LinearInequalityConstraints>, BasisError> {
    let (order, sign) = match shape_order_and_sign(shape) {
        Some(v) => v,
        None => return Ok(None),
    };
    let n = x.len();
    let p = design_local.ncols();
    if n == 0 || p == 0 {
        return Ok(None);
    }
    if x.iter().any(|v| !v.is_finite()) {
        return Err(BasisError::InvalidInput(
            "shape-constrained smooth requires finite covariate values".to_string(),
        ));
    }
    // Visit rows in ascending-x order (NaNs excluded above).
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by(|&i, &j| x[i].partial_cmp(&x[j]).unwrap_or(std::cmp::Ordering::Equal));
    let x_scale = x.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs())).max(1.0);
    let x_tol = 1e-12 * x_scale;
    // Collapse ties: average the design rows of all observations sharing a
    // covariate location (within tolerance).
    let mut collapsedrows: Vec<Array1<f64>> = Vec::new();
    let mut group_sum = Array1::<f64>::zeros(p);
    let mut group_count = 0usize;
    let mut last_x: Option<f64> = None;
    for &r in &idx {
        let xr = x[r];
        let start_new = match last_x {
            None => false,
            Some(prev) => (xr - prev).abs() > x_tol,
        };
        if start_new {
            if group_count > 0 {
                collapsedrows.push(group_sum.mapv(|v| v / group_count as f64));
            }
            group_sum.fill(0.0);
            group_count = 0;
        }
        group_sum += &design_local.row(r).to_owned();
        group_count += 1;
        last_x = Some(xr);
    }
    // Flush the final group.
    if group_count > 0 {
        collapsedrows.push(group_sum.mapv(|v| v / group_count as f64));
    }
    let m = collapsedrows.len();
    // Need at least order+1 distinct locations to form one difference row.
    if m <= order {
        return Err(BasisError::InvalidInput(format!(
            "shape-constrained smooth requires at least {} unique covariate locations; found {}",
            order + 1,
            m
        )));
    }
    // First (order 1) or second (order 2) difference of the collapsed design
    // rows; `sign < 0` flips the inequality direction. Rows with (near-)zero
    // norm carry no information and are dropped.
    let q_raw = m - order;
    let mut arows: Vec<Array1<f64>> = Vec::with_capacity(q_raw);
    for i in 0..q_raw {
        let row = if order == 1 {
            &collapsedrows[i + 1] - &collapsedrows[i]
        } else {
            &collapsedrows[i + 2] - &collapsedrows[i + 1].mapv(|v| 2.0 * v) + &collapsedrows[i]
        };
        let mut row_signed = row;
        if sign < 0.0 {
            row_signed.mapv_inplace(|v| -v);
        }
        let norm = row_signed.dot(&row_signed).sqrt();
        if norm > 1e-12 {
            arows.push(row_signed);
        }
    }
    if arows.is_empty() {
        return Ok(None);
    }
    // Stack the surviving rows into A with a zero right-hand side b.
    let mut a = Array2::<f64>::zeros((arows.len(), p));
    for (i, row) in arows.iter().enumerate() {
        a.row_mut(i).assign(row);
    }
    let b = Array1::<f64>::zeros(a.nrows());
    Ok(Some(LinearInequalityConstraints { a, b }))
}
/// Converts a vector of per-coefficient lower bounds into linear inequality
/// constraints `A·β ≥ b`, emitting one unit row per finite bound.
///
/// Non-finite entries (typically `NEG_INFINITY` for "unbounded") contribute no
/// row; returns `None` when no coefficient is bounded.
fn linear_constraints_from_lower_bounds_global(
    lower_bounds: &Array1<f64>,
) -> Option<LinearInequalityConstraints> {
    let p = lower_bounds.len();
    // Keep only coefficients with a finite lower bound, remembering positions.
    let bounded: Vec<(usize, f64)> = lower_bounds
        .iter()
        .enumerate()
        .filter(|(_, v)| v.is_finite())
        .map(|(i, &v)| (i, v))
        .collect();
    if bounded.is_empty() {
        return None;
    }
    let mut a = Array2::<f64>::zeros((bounded.len(), p));
    let mut b = Array1::<f64>::zeros(bounded.len());
    // Row r encodes β[col] ≥ lb as a unit selector with rhs lb.
    for (r, &(col, lb)) in bounded.iter().enumerate() {
        a[[r, col]] = 1.0;
        b[r] = lb;
    }
    Some(LinearInequalityConstraints { a, b })
}
/// Merges two optional linear inequality constraint systems by stacking their
/// rows vertically (`first`'s rows on top of `second`'s).
///
/// Returns whichever side is present when the other is `None`. When both are
/// present but their column counts disagree, returns `None`.
fn merge_linear_constraints_global(
    first: Option<LinearInequalityConstraints>,
    second: Option<LinearInequalityConstraints>,
) -> Option<LinearInequalityConstraints> {
    let (c1, c2) = match (first, second) {
        (None, None) => return None,
        (Some(c), None) | (None, Some(c)) => return Some(c),
        (Some(c1), Some(c2)) => (c1, c2),
    };
    let p = c1.a.ncols();
    if c2.a.ncols() != p {
        // NOTE(review): a width mismatch silently discards BOTH constraint
        // sets (callers see "no constraints") — confirm this is intended.
        return None;
    }
    let m1 = c1.a.nrows();
    let m2 = c2.a.nrows();
    let mut a = Array2::<f64>::zeros((m1 + m2, p));
    a.slice_mut(s![0..m1, ..]).assign(&c1.a);
    a.slice_mut(s![m1..(m1 + m2), ..]).assign(&c2.a);
    let mut b = Array1::<f64>::zeros(m1 + m2);
    b.slice_mut(s![0..m1]).assign(&c1.b);
    b.slice_mut(s![m1..(m1 + m2)]).assign(&c2.b);
    Some(LinearInequalityConstraints { a, b })
}
/// Symmetrizes a penalty matrix, projects it onto the PSD cone, and rescales
/// it to unit Frobenius norm.
///
/// Returns the normalized matrix together with the scale that was divided out;
/// when the norm is zero or non-finite the projected matrix is returned
/// unscaled with a reported scale of `1.0`.
fn normalize_penalty_in_constrained_space(matrix: &Array2<f64>) -> (Array2<f64>, f64) {
    // Enforce exact symmetry before the PSD projection.
    let sym = (matrix + &matrix.t().to_owned()) * 0.5;
    let psd = crate::terms::basis::project_penalty_to_psd_cone(&sym);
    let frob = psd.iter().fold(0.0_f64, |acc, v| acc + v * v).sqrt();
    if frob.is_finite() && frob > 0.0 {
        (psd.mapv(|v| v / frob), frob)
    } else {
        (psd, 1.0)
    }
}
/// Builds a tensor-product B-spline basis over `feature_cols`.
///
/// Each marginal 1-D B-spline is built *unconstrained*; identifiability
/// (sum-to-zero or a frozen transform) is applied to the realized tensor
/// design as a whole. One Kronecker-structured penalty candidate is emitted
/// per marginal dimension, plus an optional global ridge when
/// `spec.double_penalty` is set.
///
/// When `spec.identifiability` is `None`, the dense tensor design is never
/// materialized: the result wraps a `TensorProductDesignOperator` over the
/// marginal designs and also carries the Kronecker-factored bookkeeping.
///
/// # Errors
/// `BasisError` on empty/mismatched feature columns, out-of-bounds columns,
/// marginal-build failures, or an identifiability transform of the wrong size.
fn build_tensor_bspline_basis(
    data: ArrayView2<'_, f64>,
    feature_cols: &[usize],
    spec: &TensorBSplineSpec,
) -> Result<BasisBuildResult, BasisError> {
    if feature_cols.is_empty() {
        return Err(BasisError::InvalidInput(
            "TensorBSpline requires at least one feature column".to_string(),
        ));
    }
    if feature_cols.len() != spec.marginalspecs.len() {
        return Err(BasisError::DimensionMismatch(format!(
            "TensorBSpline feature/spec mismatch: feature_cols={}, marginalspecs={}",
            feature_cols.len(),
            spec.marginalspecs.len()
        )));
    }
    let p = data.ncols();
    for &c in feature_cols {
        if c >= p {
            return Err(BasisError::DimensionMismatch(format!(
                "tensor feature column {c} is out of bounds for data with {p} columns"
            )));
        }
    }
    let mut marginal_knots = Vec::<Array1<f64>>::with_capacity(feature_cols.len());
    let mut marginal_degrees = Vec::<usize>::with_capacity(feature_cols.len());
    let mut marginalnum_basis = Vec::<usize>::with_capacity(feature_cols.len());
    let mut marginal_penalties = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
    let mut marginal_designs = Vec::<Array2<f64>>::with_capacity(feature_cols.len());
    // Build each marginal basis without its own identifiability constraint;
    // the constraint (if any) is applied to the tensor product below.
    for (dim, (&col, marginalspec)) in feature_cols
        .iter()
        .zip(spec.marginalspecs.iter())
        .enumerate()
    {
        let mut marginal_unconstrained = marginalspec.clone();
        marginal_unconstrained.identifiability = BSplineIdentifiability::None;
        let built = build_bspline_basis_1d(data.column(col), &marginal_unconstrained)?;
        let knots = match built.metadata {
            BasisMetadata::BSpline1D { knots, .. } => knots,
            _ => {
                return Err(BasisError::InvalidInput(format!(
                    "internal TensorBSpline error at dim {dim}: expected BSpline1D metadata"
                )));
            }
        };
        marginal_knots.push(knots);
        marginal_degrees.push(marginalspec.degree);
        marginalnum_basis.push(built.design.ncols());
        marginal_designs.push(built.design.to_dense());
        // Only the first (primary) marginal penalty is used for the tensor.
        marginal_penalties.push(
            built
                .penalties
                .first()
                .ok_or_else(|| {
                    BasisError::InvalidInput(format!(
                        "internal TensorBSpline error at dim {dim}: missing marginal penalty"
                    ))
                })?
                .clone(),
        );
        // Sanity check only: the marginal nullspace dim must exist.
        built.nullspace_dims.first().ok_or_else(|| {
            BasisError::InvalidInput(format!(
                "internal TensorBSpline error at dim {dim}: missing marginal nullspace dim"
            ))
        })?;
    }
    let total_cols: usize = marginalnum_basis.iter().product();
    // Realize the dense tensor design only when an identifiability transform
    // will need it; otherwise the basis stays factored.
    let mut dense_design = (!matches!(spec.identifiability, TensorBSplineIdentifiability::None))
        .then(|| tensor_product_design_from_marginals(&marginal_designs))
        .transpose()?;
    let mut candidates = Vec::<PenaltyCandidate>::with_capacity(
        marginal_penalties.len() + if spec.double_penalty { 1 } else { 0 },
    );
    // Per-dimension penalty: S_dim = I ⊗ … ⊗ S_j ⊗ … ⊗ I, i.e. the Kronecker
    // product with the marginal penalty at position `dim` and identities elsewhere.
    for dim in 0..marginal_penalties.len() {
        let mut s_dim = Array2::<f64>::eye(1);
        let mut factors = Vec::<Array2<f64>>::with_capacity(marginalnum_basis.len());
        for (j, &qj) in marginalnum_basis.iter().enumerate() {
            let factor = if j == dim {
                marginal_penalties[j].clone()
            } else {
                Array2::<f64>::eye(qj)
            };
            factors.push(factor.clone());
            s_dim = kronecker_product(&s_dim, &factor);
        }
        candidates.push(PenaltyCandidate {
            matrix: s_dim,
            nullspace_dim_hint: 0,
            source: PenaltySource::TensorMarginal { dim },
            normalization_scale: 1.0,
            kronecker_factors: Some(factors),
            op: None,
        });
    }
    if spec.double_penalty {
        // Extra full-rank ridge over the whole tensor block.
        candidates.push(PenaltyCandidate {
            matrix: Array2::<f64>::eye(total_cols),
            nullspace_dim_hint: 0,
            source: PenaltySource::TensorGlobalRidge,
            normalization_scale: 1.0,
            kronecker_factors: None,
            op: None,
        });
    }
    // Resolve the identifiability transform Z (if any).
    let z_opt = match &spec.identifiability {
        TensorBSplineIdentifiability::None => None,
        TensorBSplineIdentifiability::SumToZero => {
            if total_cols < 2 {
                return Err(BasisError::InvalidInput(
                    "TensorBSpline requires at least 2 basis coefficients to enforce sum-to-zero identifiability".to_string(),
                ));
            }
            let dense_design_ref = dense_design.as_ref().ok_or_else(|| {
                BasisError::InvalidInput(
                    "tensor sum-to-zero identifiability requires a realized basis".to_string(),
                )
            })?;
            let (_, z) = apply_sum_to_zero_constraint(dense_design_ref.view(), None)?;
            Some(z)
        }
        TensorBSplineIdentifiability::FrozenTransform { transform } => {
            if transform.nrows() != total_cols {
                return Err(BasisError::DimensionMismatch(format!(
                    "frozen tensor identifiability transform mismatch: design has {} columns but transform has {} rows",
                    total_cols,
                    transform.nrows()
                )));
            }
            Some(transform.clone())
        }
    };
    if let Some(z) = z_opt.as_ref() {
        // Project the design and every penalty into the constrained space:
        // X ← X·Z and S ← Zᵀ·S·Z, re-normalized after projection.
        let dense = dense_design.as_mut().ok_or_else(|| {
            BasisError::InvalidInput(
                "tensor identifiability transform requires a realized basis".to_string(),
            )
        })?;
        *dense = dense.dot(z);
        candidates = candidates
            .into_iter()
            .map(|candidate| -> Result<PenaltyCandidate, BasisError> {
                let zt_s = z.t().dot(&candidate.matrix);
                let matrix = zt_s.dot(z);
                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
                Ok(PenaltyCandidate {
                    nullspace_dim_hint: candidate.nullspace_dim_hint,
                    matrix,
                    source: candidate.source,
                    normalization_scale: candidate.normalization_scale * c_new,
                    kronecker_factors: candidate.kronecker_factors.clone(),
                    op: candidate.op.clone(),
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
    }
    let (penalties, nullspace_dims, penaltyinfo, ops) =
        filter_active_penalty_candidates_with_ops(candidates)?;
    let design = if let Some(dense_design) = dense_design {
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(dense_design))
    } else {
        // No transform requested: expose the design as a lazy tensor-product
        // operator over the marginal design matrices.
        let marginals: Vec<Arc<Array2<f64>>> = marginal_designs
            .iter()
            .map(|m| Arc::new(m.clone()))
            .collect();
        let op = TensorProductDesignOperator::new(marginals).map_err(|e| {
            BasisError::InvalidInput(format!("TensorProductDesignOperator build failed: {e}"))
        })?;
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(op)))
    };
    Ok(BasisBuildResult {
        design,
        penalties,
        nullspace_dims,
        penaltyinfo,
        ops,
        metadata: BasisMetadata::TensorBSpline {
            feature_cols: feature_cols.to_vec(),
            knots: marginal_knots,
            degrees: marginal_degrees,
            identifiability_transform: z_opt,
        },
        // Kronecker-factored bookkeeping is only valid in the unconstrained
        // (untransformed) coefficient space.
        kronecker_factored: if matches!(spec.identifiability, TensorBSplineIdentifiability::None) {
            Some(KroneckerFactoredBasis {
                marginal_designs,
                marginal_penalties,
                marginal_dims: marginalnum_basis.clone(),
                has_double_penalty: spec.double_penalty,
            })
        } else {
            None
        },
    })
}
/// Realizes the row-wise Kronecker (tensor) product of the marginal design
/// matrices: row `i` of the output is `kron(B1[i,:], B2[i,:], …, Bd[i,:])`.
///
/// # Errors
/// `BasisError::InvalidInput` when no marginal is given;
/// `BasisError::DimensionMismatch` when marginal row counts disagree or the
/// total column count overflows `usize`.
fn tensor_product_design_from_marginals(
    marginal_designs: &[Array2<f64>],
) -> Result<Array2<f64>, BasisError> {
    let first = marginal_designs.first().ok_or_else(|| {
        BasisError::InvalidInput("TensorBSpline requires at least one marginal basis".to_string())
    })?;
    let n = first.nrows();
    for (i, b) in marginal_designs.iter().enumerate().skip(1) {
        if b.nrows() != n {
            return Err(BasisError::DimensionMismatch(format!(
                "tensor marginal row mismatch at dim {i}: expected {n}, got {}",
                b.nrows()
            )));
        }
    }
    // Guard against overflow of the tensor basis width before allocating.
    let total_cols = marginal_designs.iter().try_fold(1usize, |acc, b| {
        acc.checked_mul(b.ncols())
            .ok_or_else(|| BasisError::DimensionMismatch("tensor basis too large".to_string()))
    })?;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    // Rows are independent, so expand them in parallel and flatten row-major.
    let flat: Vec<f64> = (0..n)
        .into_par_iter()
        .flat_map_iter(|i| {
            // Grow the running Kronecker row one marginal at a time.
            let mut acc = vec![1.0_f64];
            for b in marginal_designs {
                let q = b.ncols();
                let mut expanded = Vec::with_capacity(acc.len() * q);
                for &left in &acc {
                    for col in 0..q {
                        expanded.push(left * b[[i, col]]);
                    }
                }
                acc = expanded;
            }
            acc.into_iter()
        })
        .collect();
    Array2::<f64>::from_shape_vec((n, total_cols), flat)
        .map_err(|e| BasisError::DimensionMismatch(format!("tensor design assembly: {e}")))
}
/// Builds the bookkeeping for one random-effect (group-indicator) term: the
/// kept level set and each observation's group column index.
///
/// Levels are keyed by the raw `f64` bit pattern of the group column
/// (`f64::to_bits`), so levels must match bit-for-bit between fit and
/// prediction. NOTE(review): this makes `0.0` and `-0.0` distinct levels, and
/// the "first" level under `drop_first_level` is smallest by *bit pattern*,
/// not necessarily numerically (negative codes sort after positive ones) —
/// confirm group codes are always non-negative integers.
///
/// With `spec.frozen_levels` (re-fit/prediction) the stored level set is
/// reused verbatim; otherwise levels come from the data, optionally dropping
/// the first for identifiability. Observations whose level is not kept map to
/// `None` (no random-effect column).
///
/// # Errors
/// `BasisError` on an out-of-bounds feature column, non-finite group values,
/// or an empty level set.
fn build_random_effect_block(
    data: ArrayView2<'_, f64>,
    spec: &RandomEffectTermSpec,
) -> Result<RandomEffectBlock, BasisError> {
    let n = data.nrows();
    let p = data.ncols();
    if spec.feature_col >= p {
        return Err(BasisError::DimensionMismatch(format!(
            "random-effect term '{}' feature column {} out of bounds for {} columns",
            spec.name, spec.feature_col, p
        )));
    }
    let col = data.column(spec.feature_col);
    if col.iter().any(|v| !v.is_finite()) {
        return Err(BasisError::InvalidInput(format!(
            "random-effect term '{}' contains non-finite group values",
            spec.name
        )));
    }
    // Frozen path: reuse the previously stored level bit patterns verbatim.
    let mut kept_levels: Vec<u64> = if let Some(levels) = spec.frozen_levels.as_ref() {
        if levels.is_empty() {
            return Err(BasisError::InvalidInput(format!(
                "random-effect term '{}' has empty frozen_levels",
                spec.name
            )));
        }
        levels.clone()
    } else {
        // Fresh path: collect the distinct observed levels, ordered by bit
        // pattern (BTreeSet over u64).
        let mut levels_set = BTreeSet::<u64>::new();
        for &v in col {
            levels_set.insert(v.to_bits());
        }
        if levels_set.is_empty() {
            return Err(BasisError::InvalidInput(format!(
                "random-effect term '{}' has no observed levels",
                spec.name
            )));
        }
        let levels: Vec<u64> = levels_set.into_iter().collect();
        // Optionally drop the first level for identifiability, but never the
        // only level.
        let start_idx = if spec.drop_first_level && levels.len() > 1 {
            1usize
        } else {
            0usize
        };
        levels[start_idx..].to_vec()
    };
    // Sort + dedup so binary_search below is valid on either path.
    kept_levels.sort_unstable();
    kept_levels.dedup();
    if kept_levels.is_empty() {
        return Err(BasisError::InvalidInput(format!(
            "random-effect term '{}' drops all levels; keep at least one level",
            spec.name
        )));
    }
    let q = kept_levels.len();
    let mut group_ids = Vec::with_capacity(n);
    // Map each observation to its group column; unseen/dropped levels → None.
    for &v in col {
        let bits = v.to_bits();
        group_ids.push(kept_levels.binary_search(&bits).ok());
    }
    Ok(RandomEffectBlock {
        name: spec.name.clone(),
        group_ids,
        num_groups: q,
        kept_levels,
    })
}
impl SmoothDesign {
    /// Maps unconstrained working coefficients into the shape-constrained
    /// coefficient space of a single term.
    ///
    /// `ShapeConstraint::None` is the identity map; monotone shapes apply a
    /// cumulative-exponential transform and convex/concave shapes a
    /// second-order cumulative-exponential transform, with the sign argument
    /// (`+1.0` / `-1.0`) selecting the direction.
    ///
    /// # Errors
    /// `BasisError::InvalidInput` when the coefficient vector is empty.
    pub fn map_term_coefficients(
        unconstrained: &Array1<f64>,
        shape: ShapeConstraint,
    ) -> Result<Array1<f64>, BasisError> {
        if unconstrained.is_empty() {
            return Err(BasisError::InvalidInput(
                "unconstrained coefficient vector cannot be empty".to_string(),
            ));
        }
        Ok(match shape {
            ShapeConstraint::None => unconstrained.clone(),
            ShapeConstraint::MonotoneIncreasing => cumulative_exp(unconstrained, 1.0),
            ShapeConstraint::MonotoneDecreasing => cumulative_exp(unconstrained, -1.0),
            ShapeConstraint::Convex => second_cumulative_exp(unconstrained, 1.0),
            ShapeConstraint::Concave => second_cumulative_exp(unconstrained, -1.0),
        })
    }
}
/// Per-term intermediate produced by `build_single_local_smooth_term`, in the
/// term's *local* coefficient space (columns `0..dim`), before assembly into
/// the global smooth design.
struct LocalSmoothTermBuild {
    // Number of local coefficient columns for this term.
    dim: usize,
    // Local design matrix (dense or operator-backed).
    design: DesignMatrix,
    // Active, normalized penalty matrices in local coordinates.
    penalties: Vec<Array2<f64>>,
    // Optional matrix-free penalty operators, aligned with `penalties`.
    ops: Vec<Option<std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>>>,
    // Nullspace dimension per active penalty, aligned with `penalties`.
    nullspaces: Vec<usize>,
    // Metadata for the penalties that survived the final filtering pass.
    penaltyinfo: Vec<PenaltyInfo>,
    // Penalties already dropped before the final filtering pass (reporting only).
    pre_dropped_penaltyinfo: Vec<PenaltyInfo>,
    // Basis metadata needed to re-evaluate the term later.
    metadata: BasisMetadata,
    // Local A·β ≥ b shape constraints, when the shape is not box-reparameterized.
    linear_constraints: Option<LinearInequalityConstraints>,
    // True when the shape constraint uses the cumulative-sum coefficient
    // reparameterization instead of linear inequalities.
    box_reparam: bool,
    // Kronecker-factored bookkeeping (tensor terms without shape constraints).
    kronecker_factored: Option<KroneckerFactoredBasis>,
}
/// Builds one smooth term in its local coefficient space: the design, the
/// normalized penalties (with optional matrix-free operators), metadata, and
/// any shape-constraint machinery.
///
/// Shape constraints are handled two ways:
/// * bases supporting the box reparameterization get a cumulative-sum
///   coefficient transform (design and penalties are mapped through it);
/// * otherwise finite-difference linear inequality constraints are built on a
///   1-D evaluation grid.
///
/// Spatial inputs (ThinPlate/Matern/Duchon) are standardized — frozen scales
/// from the spec are reused when present — and the length scale is compensated
/// accordingly; the metadata records the *original* length scale plus the
/// applied input scales so the term can be rebuilt at prediction time.
///
/// # Errors
/// `BasisError` on unsupported shape/basis combinations, out-of-bounds
/// columns, underlying basis-build failures, or internal penalty-metadata
/// inconsistencies.
fn build_single_local_smooth_term(
    data: ArrayView2<'_, f64>,
    term: &SmoothTermSpec,
    workspace: &mut crate::basis::BasisWorkspace,
) -> Result<LocalSmoothTermBuild, BasisError> {
    if term.shape != ShapeConstraint::None && !shape_supports_basis(term) {
        return Err(BasisError::InvalidInput(format!(
            "ShapeConstraint::{:?} is unsupported for term '{}'",
            term.shape, term.name
        )));
    }
    // For shape-constrained spatial bases: the single covariate axis the
    // inequality constraints are evaluated along.
    let mut shape_axis_col: Option<usize> = None;
    let mut built: BasisBuildResult = match &term.basis {
        SmoothBasisSpec::BSpline1D { feature_col, spec } => {
            if *feature_col >= data.ncols() {
                return Err(BasisError::DimensionMismatch(format!(
                    "term '{}' feature column {} out of bounds for {} columns",
                    term.name,
                    feature_col,
                    data.ncols()
                )));
            }
            let mut spec_local = spec.clone();
            // Shape constraints need the raw (unconstrained) basis.
            if term.shape != ShapeConstraint::None {
                spec_local.identifiability = BSplineIdentifiability::None;
            }
            build_bspline_basis_1d(data.column(*feature_col), &spec_local)?
        }
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            spec,
            input_scales,
        } => {
            if term.shape != ShapeConstraint::None {
                if feature_cols.len() != 1 {
                    return Err(BasisError::InvalidInput(format!(
                        "ShapeConstraint::{:?} for term '{}' on ThinPlate basis requires exactly 1 feature axis; found {}",
                        term.shape,
                        term.name,
                        feature_cols.len()
                    )));
                }
                shape_axis_col = Some(feature_cols[0]);
            }
            let mut x = select_columns(data, feature_cols)?;
            // Standardize inputs; frozen scales (from a previous fit) win over
            // freshly computed ones. The length scale is compensated to keep
            // the kernel equivalent in the standardized space.
            let (scales, length_scale_eff) = if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                (
                    Some(s.clone()),
                    compensate_length_scale_for_standardization(spec.length_scale, s),
                )
            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
                apply_input_standardization(&mut x, &s);
                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
                (Some(s), l_eff)
            } else {
                (None, spec.length_scale)
            };
            let mut spec_local = spec.clone();
            spec_local.length_scale = length_scale_eff;
            // Orthogonal-to-parametric identifiability is deferred to the
            // global pass; build the raw basis here.
            if matches!(
                spec_local.identifiability,
                SpatialIdentifiability::OrthogonalToParametric
            ) {
                spec_local.identifiability = SpatialIdentifiability::None;
            }
            let mut result = build_thin_plate_basis(x.view(), &spec_local).map_err(|err| {
                rewrite_thin_plate_knots_error(err, &term.name, feature_cols.len(), spec)
            })?;
            // Record the ORIGINAL length scale and the applied input scales in
            // the metadata so the term can be re-evaluated identically later.
            match &mut result.metadata {
                BasisMetadata::ThinPlate {
                    input_scales: ms,
                    length_scale,
                    ..
                } => {
                    *ms = scales;
                    *length_scale = spec.length_scale;
                }
                BasisMetadata::Duchon {
                    input_scales: ms,
                    length_scale,
                    ..
                } => {
                    *ms = scales;
                    *length_scale = Some(spec.length_scale);
                }
                _ => {}
            }
            result
        }
        SmoothBasisSpec::Matern {
            feature_cols,
            spec,
            input_scales,
        } => {
            if term.shape != ShapeConstraint::None {
                if feature_cols.len() != 1 {
                    return Err(BasisError::InvalidInput(format!(
                        "ShapeConstraint::{:?} for term '{}' on Matern basis requires exactly 1 feature axis; found {}",
                        term.shape,
                        term.name,
                        feature_cols.len()
                    )));
                }
                shape_axis_col = Some(feature_cols[0]);
            }
            let mut x = select_columns(data, feature_cols)?;
            // Same standardization/length-scale compensation as ThinPlate.
            let (scales, length_scale_eff) = if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                (
                    Some(s.clone()),
                    compensate_length_scale_for_standardization(spec.length_scale, s),
                )
            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
                apply_input_standardization(&mut x, &s);
                let l_eff = compensate_length_scale_for_standardization(spec.length_scale, &s);
                (Some(s), l_eff)
            } else {
                (None, spec.length_scale)
            };
            let mut spec_local = spec.clone();
            spec_local.length_scale = length_scale_eff;
            let mut result = build_matern_basiswithworkspace(x.view(), &spec_local, workspace)?;
            if let BasisMetadata::Matern {
                input_scales,
                length_scale,
                ..
            } = &mut result.metadata
            {
                *input_scales = scales;
                *length_scale = spec.length_scale;
            }
            result
        }
        SmoothBasisSpec::Duchon {
            feature_cols,
            spec,
            input_scales,
        } => {
            if term.shape != ShapeConstraint::None {
                if feature_cols.len() != 1 {
                    return Err(BasisError::InvalidInput(format!(
                        "ShapeConstraint::{:?} for term '{}' on Duchon basis requires exactly 1 feature axis; found {}",
                        term.shape,
                        term.name,
                        feature_cols.len()
                    )));
                }
                shape_axis_col = Some(feature_cols[0]);
            }
            let mut x = select_columns(data, feature_cols)?;
            // Duchon's length scale is optional, hence the *_optional_* helper.
            let (scales, length_scale_eff) = if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                (
                    Some(s.clone()),
                    compensate_optional_length_scale_for_standardization(spec.length_scale, s),
                )
            } else if let Some(s) = compute_spatial_input_scales(x.view()) {
                apply_input_standardization(&mut x, &s);
                let l_eff =
                    compensate_optional_length_scale_for_standardization(spec.length_scale, &s);
                (Some(s), l_eff)
            } else {
                (None, spec.length_scale)
            };
            let mut spec_local = spec.clone();
            spec_local.length_scale = length_scale_eff;
            // Orthogonal-to-parametric identifiability is deferred to the
            // global pass, as for ThinPlate.
            if matches!(
                spec_local.identifiability,
                SpatialIdentifiability::OrthogonalToParametric
            ) {
                spec_local.identifiability = SpatialIdentifiability::None;
            }
            let mut result = build_duchon_basiswithworkspace(x.view(), &spec_local, workspace)?;
            if let BasisMetadata::Duchon {
                input_scales,
                length_scale,
                ..
            } = &mut result.metadata
            {
                *input_scales = scales;
                *length_scale = spec.length_scale;
            }
            result
        }
        SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
            build_tensor_bspline_basis(data, feature_cols, spec)?
        }
    };
    // Matern terms replace their penalties with the operator-derived triplet
    // reconstructed from the metadata.
    match &term.basis {
        SmoothBasisSpec::Matern { .. } => {
            let (penalties, nullspace_dims, penaltyinfo) =
                matern_operator_penalty_triplet_from_metadata(&built.metadata)?;
            built.penalties = penalties;
            built.nullspace_dims = nullspace_dims;
            built.penaltyinfo = penaltyinfo;
        }
        _ => {}
    }
    let p_local = built.design.ncols();
    let mut metadata = built.metadata.clone();
    // Kronecker factorization is incompatible with shape transforms.
    let kron_factored = if term.shape == ShapeConstraint::None {
        built.kronecker_factored
    } else {
        None
    };
    let mut design_t = built.design;
    let mut penalties_t: Vec<Array2<f64>> = built.penalties;
    let mut ops_t: Vec<Option<std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>>> = built.ops;
    // Terms awaiting the deferred orthogonal-to-parametric constraint get
    // their metadata frozen in the raw (pre-transform) space.
    if matches!(
        spatial_identifiability_policy(term),
        Some(SpatialIdentifiability::OrthogonalToParametric)
    ) {
        metadata = freeze_raw_spatial_metadata(metadata, design_t.ncols());
    }
    // Split penalty metadata into active entries (aligned with penalties_t)
    // and already-dropped entries (reporting only).
    let active_penaltyinfo_t = built
        .penaltyinfo
        .iter()
        .filter(|info| info.active)
        .cloned()
        .collect::<Vec<_>>();
    let pre_dropped_penaltyinfo_t = built
        .penaltyinfo
        .iter()
        .filter(|info| !info.active)
        .cloned()
        .collect::<Vec<_>>();
    let use_box_reparam =
        term.shape != ShapeConstraint::None && shape_uses_box_reparameterization(&term.basis);
    // Box reparameterization: fold the cumulative-sum transform T into the
    // design (X ← X·T) and penalties (S ← Tᵀ·S·T); operators are invalidated.
    if let Some((order, sign)) = shape_order_and_sign(term.shape)
        && use_box_reparam
    {
        let t = cumulative_sum_transform_matrix(p_local, order, sign);
        let inner_dense = match design_t {
            DesignMatrix::Dense(d) => d,
            DesignMatrix::Sparse(sp) => crate::matrix::DenseDesignMatrix::from(
                sp.try_to_dense_arc("shape-constrained coefficient transform")
                    .map_err(BasisError::InvalidInput)?,
            ),
        };
        let coeff_op = crate::matrix::CoefficientTransformOperator::new(inner_dense, t.clone())
            .map_err(|e| BasisError::InvalidInput(format!("CoefficientTransformOperator: {e}")))?;
        design_t = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(coeff_op)));
        penalties_t = penalties_t
            .into_iter()
            .map(|s_local| {
                let tt_s = t.t().dot(&s_local);
                tt_s.dot(&t)
            })
            .collect();
        ops_t = vec![None; penalties_t.len()];
    }
    if penalties_t.len() != active_penaltyinfo_t.len() {
        return Err(BasisError::InvalidInput(format!(
            "internal penalty metadata mismatch for term '{}': active penalties={}, active infos={}",
            term.name,
            penalties_t.len(),
            active_penaltyinfo_t.len()
        )));
    }
    if ops_t.len() != penalties_t.len() {
        ops_t = vec![None; penalties_t.len()];
    }
    // Re-normalize every penalty and rescale any matrix-free operator by the
    // same factor so matrix and operator stay consistent.
    let penalty_candidates = penalties_t
        .into_iter()
        .zip(active_penaltyinfo_t.into_iter())
        .zip(ops_t.into_iter())
        .map(
            |((matrix, info), op_in)| -> Result<PenaltyCandidate, BasisError> {
                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
                let scaled_op = if c_new > 0.0 && c_new.is_finite() {
                    op_in.map(|op| {
                        std::sync::Arc::new(crate::terms::penalty_op::ScaledPenaltyOp::new(
                            op,
                            1.0 / c_new,
                        ))
                            as std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>
                    })
                } else {
                    None
                };
                Ok(PenaltyCandidate {
                    nullspace_dim_hint: info.nullspace_dim_hint,
                    matrix,
                    source: info.source,
                    normalization_scale: info.normalization_scale * c_new,
                    kronecker_factors: None,
                    op: scaled_op,
                })
            },
        )
        .collect::<Result<Vec<_>, _>>()?;
    let (penalties_t, nullspaces_t, penaltyinfo_t, ops_t) =
        crate::terms::basis::filter_active_penalty_candidates_with_ops(penalty_candidates)?;
    // Non-box shapes: build inequality constraints on a 1-D evaluation grid.
    let linear_constraints_local = if term.shape != ShapeConstraint::None && !use_box_reparam {
        let axis = shape_axis_col.ok_or_else(|| {
            BasisError::InvalidInput(format!(
                "internal shape-constraint axis missing for term '{}'",
                term.name
            ))
        })?;
        let (x_shape_eval, design_shape_eval) =
            build_shape_constraint_design_1d(data, term, &metadata, axis)?;
        build_shape_linear_constraints_1d(
            x_shape_eval.view(),
            design_shape_eval.view(),
            term.shape,
        )?
    } else {
        None
    };
    Ok(LocalSmoothTermBuild {
        dim: p_local,
        design: design_t,
        penalties: penalties_t,
        ops: ops_t,
        nullspaces: nullspaces_t,
        penaltyinfo: penaltyinfo_t,
        pre_dropped_penaltyinfo: pre_dropped_penaltyinfo_t,
        metadata,
        linear_constraints: linear_constraints_local,
        box_reparam: use_box_reparam,
        kronecker_factored: kron_factored,
    })
}
/// Builds the raw smooth design for `terms` using a freshly allocated basis
/// workspace. Convenience wrapper over [`build_smooth_design_withworkspace`].
pub fn build_smooth_design(
    data: ArrayView2<'_, f64>,
    terms: &[SmoothTermSpec],
) -> Result<RawSmoothDesign, BasisError> {
    let mut workspace = crate::basis::BasisWorkspace::new();
    build_smooth_design_withworkspace(data, terms, &mut workspace)
}
/// Builds the raw (pre-global-identifiability) smooth design for all terms
/// using the supplied workspace's policy.
///
/// Terms are first passed through the joint spatial center planner (so spatial
/// terms can coordinate their center selections), then built in parallel —
/// each rayon task gets its own workspace cloned from `workspace`'s policy.
/// The per-term local results are concatenated column-wise, and penalties,
/// nullspace dims, penalty metadata, coefficient lower bounds, and linear
/// inequality constraints are lifted from local into global column
/// coordinates.
///
/// # Errors
/// `BasisError` when the planner returns no blocks, any per-term build fails,
/// or per-term penalty metadata is internally inconsistent.
pub fn build_smooth_design_withworkspace(
    data: ArrayView2<'_, f64>,
    terms: &[SmoothTermSpec],
    workspace: &mut crate::basis::BasisWorkspace,
) -> Result<RawSmoothDesign, BasisError> {
    // Plan shared spatial centers across the whole term block.
    let mut planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &[terms.to_vec()])?;
    let planned_terms = planned_blocks.pop().ok_or_else(|| {
        BasisError::InvalidInput(
            "joint spatial center planner returned no smooth blocks".to_string(),
        )
    })?;
    let policy = workspace.policy().clone();
    // Build each term in parallel with a private workspace per task.
    let mut local_builds: Vec<LocalSmoothTermBuild> = {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        planned_terms
            .into_par_iter()
            .map(|term| {
                let mut term_workspace = crate::basis::BasisWorkspace::with_policy(policy.clone());
                build_single_local_smooth_term(data, &term, &mut term_workspace)
            })
            .collect::<Result<Vec<_>, _>>()?
    };
    let local_dims: Vec<usize> = local_builds.iter().map(|built| built.dim).collect();
    let local_designs: Vec<DesignMatrix> = local_builds
        .iter()
        .map(|built| built.design.clone())
        .collect();
    let total_p: usize = local_dims.iter().sum();
    let mut terms_out = Vec::<SmoothTerm>::with_capacity(terms.len());
    let mut penalties_global = Vec::<BlockwisePenalty>::new();
    let mut nullspace_dims_global = Vec::<usize>::new();
    let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
    let mut dropped_penaltyinfo_global = Vec::<DroppedPenaltyBlockInfo>::new();
    let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
    let mut any_bounds = false;
    let mut linear_constraintsrows: Vec<Array1<f64>> = Vec::new();
    let mut linear_constraints_b: Vec<f64> = Vec::new();
    // Walk terms in order, assigning each a contiguous global column range.
    let mut col_start = 0usize;
    for (idx, term) in terms.iter().enumerate() {
        let p_local = local_dims[idx];
        let col_end = col_start + p_local;
        // Box-reparameterized shapes get per-coefficient lower bounds instead
        // of linear inequality constraints.
        let lb_local = if local_builds[idx].box_reparam {
            shape_lower_bounds_local(term.shape, p_local)
        } else {
            None
        };
        let activeinfos = local_builds[idx]
            .penaltyinfo
            .iter()
            .filter(|info| info.active)
            .collect::<Vec<_>>();
        if activeinfos.len() != local_builds[idx].penalties.len() {
            return Err(BasisError::InvalidInput(format!(
                "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
                term.name,
                activeinfos.len(),
                local_builds[idx].penalties.len()
            )));
        }
        // Lift each active local penalty (matrix + optional operator +
        // nullspace dim + metadata) into the term's global column range.
        for (((s_local, &ns), info), op_local) in local_builds[idx]
            .penalties
            .iter()
            .zip(local_builds[idx].nullspaces.iter())
            .zip(activeinfos.into_iter())
            .zip(local_builds[idx].ops.iter())
        {
            let global_index = penalties_global.len();
            penalties_global.push(
                BlockwisePenalty::new(col_start..col_end, s_local.clone())
                    .with_op(op_local.clone()),
            );
            nullspace_dims_global.push(ns);
            let mut penalty = info.clone();
            penalty.nullspace_dim_hint = ns;
            penaltyinfo_global.push(PenaltyBlockInfo {
                global_index,
                termname: Some(term.name.clone()),
                penalty,
            });
        }
        // Report penalties dropped during the final filtering pass...
        for info in local_builds[idx]
            .penaltyinfo
            .iter()
            .filter(|info| !info.active)
        {
            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
                termname: Some(term.name.clone()),
                penalty: info.clone(),
            });
        }
        // ...and those dropped even earlier, during the local build.
        for info in &local_builds[idx].pre_dropped_penaltyinfo {
            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
                termname: Some(term.name.clone()),
                penalty: info.clone(),
            });
        }
        terms_out.push(SmoothTerm {
            name: term.name.clone(),
            coeff_range: col_start..col_end,
            shape: term.shape,
            penalties_local: local_builds[idx].penalties.clone(),
            nullspace_dims: local_builds[idx].nullspaces.clone(),
            penaltyinfo_local: local_builds[idx].penaltyinfo.clone(),
            metadata: local_builds[idx].metadata.clone(),
            lower_bounds_local: lb_local.clone(),
            linear_constraints_local: local_builds[idx].linear_constraints.clone(),
            kronecker_factored: local_builds[idx].kronecker_factored.take(),
        });
        // Embed the term's local inequality rows into total_p-wide rows.
        if let Some(lin_local) = &local_builds[idx].linear_constraints {
            for r in 0..lin_local.a.nrows() {
                let mut row = Array1::<f64>::zeros(total_p);
                row.slice_mut(s![col_start..col_end])
                    .assign(&lin_local.a.row(r));
                linear_constraintsrows.push(row);
                linear_constraints_b.push(lin_local.b[r]);
            }
        }
        if let Some(lb_local) = lb_local {
            coefficient_lower_bounds
                .slice_mut(s![col_start..col_end])
                .assign(&lb_local);
            any_bounds = true;
        }
        col_start = col_end;
    }
    debug_assert_eq!(
        penalties_global.len(),
        nullspace_dims_global.len(),
        "global smooth penalty/nullspace bookkeeping diverged"
    );
    debug_assert_eq!(
        penalties_global.len(),
        penaltyinfo_global.len(),
        "global smooth penalty metadata bookkeeping diverged"
    );
    Ok(RawSmoothDesign {
        term_designs: local_designs,
        penalties: penalties_global,
        nullspace_dims: nullspace_dims_global,
        penaltyinfo: penaltyinfo_global,
        dropped_penaltyinfo: dropped_penaltyinfo_global,
        terms: terms_out,
        coefficient_lower_bounds: if any_bounds {
            Some(coefficient_lower_bounds)
        } else {
            None
        },
        linear_constraints: if linear_constraintsrows.is_empty() {
            None
        } else {
            let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
            for (i, row) in linear_constraintsrows.iter().enumerate() {
                a.row_mut(i).assign(row);
            }
            Some(LinearInequalityConstraints {
                a,
                b: Array1::from_vec(linear_constraints_b),
            })
        },
    })
}
fn build_term_collection_design_inner(
data: ArrayView2<'_, f64>,
spec: &TermCollectionSpec,
) -> Result<TermCollectionDesign, BasisError> {
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
let n = data.nrows();
let p_data = data.ncols();
let p_intercept = 1usize;
let p_lin = spec.linear_terms.len();
let (smooth_raw_result, (random_blocks_result, linear_block_result)) = rayon::join(
|| build_smooth_design(data, &spec.smooth_terms),
|| {
rayon::join(
|| {
spec.random_effect_terms
.par_iter()
.map(|term| build_random_effect_block(data, term))
.collect::<Vec<_>>()
.into_iter()
.collect::<Result<Vec<_>, _>>()
},
|| -> Result<Option<Array2<f64>>, BasisError> {
if p_lin == 0 {
return Ok(None);
}
let linear_columns = (0..p_lin)
.into_par_iter()
.map(|j| {
let linear = &spec.linear_terms[j];
if linear.feature_col >= p_data {
return Err(BasisError::DimensionMismatch(format!(
"linear term '{}' feature column {} out of bounds for {} columns",
linear.name, linear.feature_col, p_data
)));
}
Ok(data.column(linear.feature_col).to_owned())
})
.collect::<Vec<_>>()
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
let mut out = Array2::<f64>::zeros((n, p_lin));
for (j, column) in linear_columns.iter().enumerate() {
out.column_mut(j).assign(column);
}
Ok(Some(out))
},
)
},
);
let smooth_raw = smooth_raw_result?;
let random_blocks = random_blocks_result?;
let linear_block = linear_block_result?;
let smooth = apply_global_smooth_identifiability(
smooth_raw,
data,
&spec.linear_terms,
&spec.smooth_terms,
)?;
let p_rand: usize = random_blocks.iter().map(|b| b.num_groups).sum();
let p_smooth = smooth.total_smooth_cols();
let p_total = p_intercept + p_lin + p_rand + p_smooth;
let mut linear_ranges = Vec::<(String, Range<usize>)>::with_capacity(p_lin);
for (j, linear) in spec.linear_terms.iter().enumerate() {
let col = p_intercept + j;
linear_ranges.push((linear.name.clone(), col..(col + 1)));
}
let mut random_effect_ranges =
Vec::<(String, Range<usize>)>::with_capacity(random_blocks.len());
let mut random_effect_levels = Vec::<(String, Vec<u64>)>::with_capacity(random_blocks.len());
let mut col_cursor = p_intercept + p_lin;
for block in &random_blocks {
let q = block.num_groups;
let end = col_cursor + q;
random_effect_ranges.push((block.name.clone(), col_cursor..end));
random_effect_levels.push((block.name.clone(), block.kept_levels.clone()));
col_cursor = end;
}
let mut blocks = Vec::<DesignBlock>::new();
blocks.push(DesignBlock::Intercept(n));
if let Some(lin_block) = linear_block {
blocks.push(DesignBlock::Dense(crate::matrix::DenseDesignMatrix::from(
lin_block,
)));
}
for block in &random_blocks {
let re_op = RandomEffectOperator::new(block.group_ids.clone(), block.num_groups);
blocks.push(DesignBlock::RandomEffect(Arc::new(re_op)));
}
if p_smooth > 0 {
for term_design in &smooth.term_designs {
match term_design {
DesignMatrix::Dense(dense) => blocks.push(DesignBlock::Dense(dense.clone())),
DesignMatrix::Sparse(sparse) => blocks.push(DesignBlock::Sparse(sparse.clone())),
}
}
}
let design = assemble_term_collection_design_matrix(blocks)?;
let mut penalties = Vec::<BlockwisePenalty>::new();
let mut nullspace_dims = Vec::<usize>::new();
let mut penaltyinfo = Vec::<PenaltyBlockInfo>::new();
let mut dropped_penaltyinfo = Vec::<DroppedPenaltyBlockInfo>::new();
let mut coefficient_lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
let mut any_bounds = false;
let mut linear_constraintrows = Vec::<Array1<f64>>::new();
let mut linear_constraint_b = Vec::<f64>::new();
let mut penalized_linear_cols = Vec::<usize>::new();
for (j, linear) in spec.linear_terms.iter().enumerate() {
let col = p_intercept + j;
if let Some(lb) = linear.coefficient_min {
let mut row = Array1::<f64>::zeros(p_total);
row[col] = 1.0;
linear_constraintrows.push(row);
linear_constraint_b.push(lb);
}
if let Some(ub) = linear.coefficient_max {
let mut row = Array1::<f64>::zeros(p_total);
row[col] = -1.0;
linear_constraintrows.push(row);
linear_constraint_b.push(-ub);
}
if linear.double_penalty {
penalized_linear_cols.push(col);
}
}
if !penalized_linear_cols.is_empty() {
let min_col = *penalized_linear_cols.iter().min().unwrap();
let max_col = *penalized_linear_cols.iter().max().unwrap();
let block_range = min_col..(max_col + 1);
let block_size = block_range.len();
let mut s_local = Array2::<f64>::zeros((block_size, block_size));
for &col in &penalized_linear_cols {
let local_idx = col - min_col;
s_local[[local_idx, local_idx]] = 1.0;
}
let global_index = penalties.len();
penalties.push(BlockwisePenalty::new(block_range, s_local));
nullspace_dims.push(0);
penaltyinfo.push(PenaltyBlockInfo {
global_index,
termname: Some("linear".to_string()),
penalty: PenaltyInfo {
source: PenaltySource::Other("LinearDoublePenaltyGroup".to_string()),
original_index: 0,
active: true,
effective_rank: penalized_linear_cols.len(),
dropped_reason: None,
nullspace_dim_hint: 0,
normalization_scale: 1.0,
kronecker_factors: None,
},
});
}
for (re_idx, (name, range)) in random_effect_ranges.iter().enumerate() {
if range.is_empty() {
continue;
}
let block_size = range.len();
let global_index = penalties.len();
penalties.push(BlockwisePenalty::ridge(range.clone(), 1.0));
nullspace_dims.push(0);
penaltyinfo.push(PenaltyBlockInfo {
global_index,
termname: Some(name.clone()),
penalty: PenaltyInfo {
source: PenaltySource::Other(format!("RandomEffectRidge({name})")),
original_index: re_idx,
active: true,
effective_rank: block_size,
dropped_reason: None,
nullspace_dim_hint: 0,
normalization_scale: 1.0,
kronecker_factors: None,
},
});
}
if smooth.penaltyinfo.len() != smooth.penalties.len() {
return Err(BasisError::InvalidInput(format!(
"smooth penalty metadata mismatch: penalties={}, metadata={}",
smooth.penalties.len(),
smooth.penaltyinfo.len()
)));
}
let smooth_start = p_intercept + p_lin + p_rand;
for ((bp_smooth, &ns), localinfo) in smooth
.penalties
.iter()
.zip(smooth.nullspace_dims.iter())
.zip(smooth.penaltyinfo.iter())
{
let global_index = penalties.len();
let offset_range =
(bp_smooth.col_range.start + smooth_start)..(bp_smooth.col_range.end + smooth_start);
let bp = if let Some(factors) = localinfo.penalty.kronecker_factors.as_ref() {
BlockwisePenalty::kronecker(offset_range, bp_smooth.local.clone(), factors.clone())
.with_op(bp_smooth.op.clone())
} else if matches!(localinfo.penalty.source, PenaltySource::TensorGlobalRidge)
|| matches!(
localinfo.penalty.source,
PenaltySource::Other(ref s) if s.starts_with("RandomEffectRidge")
)
{
BlockwisePenalty::ridge(offset_range, 1.0)
} else {
BlockwisePenalty::new(offset_range, bp_smooth.local.clone())
.with_op(bp_smooth.op.clone())
};
penalties.push(bp);
nullspace_dims.push(ns);
let mut penalty = localinfo.penalty.clone();
penalty.nullspace_dim_hint = ns;
penaltyinfo.push(PenaltyBlockInfo {
global_index,
termname: localinfo.termname.clone(),
penalty,
});
}
dropped_penaltyinfo.extend(smooth.dropped_penaltyinfo.iter().cloned());
debug_assert_eq!(
penalties.len(),
nullspace_dims.len(),
"term-collection penalty/nullspace bookkeeping diverged"
);
debug_assert_eq!(
penalties.len(),
penaltyinfo.len(),
"term-collection penalty metadata bookkeeping diverged"
);
if let Some(lb_smooth) = smooth.coefficient_lower_bounds.as_ref() {
let start = p_intercept + p_lin + p_rand;
coefficient_lower_bounds
.slice_mut(s![start..(start + p_smooth)])
.assign(lb_smooth);
any_bounds = true;
}
if let Some(lin_smooth) = smooth.linear_constraints.as_ref() {
let mut a_global = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
let start = p_intercept + p_lin + p_rand;
a_global
.slice_mut(s![.., start..(start + p_smooth)])
.assign(&lin_smooth.a);
for r in 0..a_global.nrows() {
linear_constraintrows.push(a_global.row(r).to_owned());
linear_constraint_b.push(lin_smooth.b[r]);
}
}
let lower_bound_constraints = if any_bounds {
linear_constraints_from_lower_bounds_global(&coefficient_lower_bounds)
} else {
None
};
let explicit_linear_constraints = if linear_constraintrows.is_empty() {
None
} else {
let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), p_total));
for (i, row) in linear_constraintrows.iter().enumerate() {
a.row_mut(i).assign(row);
}
Some(LinearInequalityConstraints {
a,
b: Array1::from_vec(linear_constraint_b),
})
};
let linear_constraints =
merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
Ok(TermCollectionDesign {
design,
penalties,
nullspace_dims,
penaltyinfo,
dropped_penaltyinfo,
coefficient_lower_bounds: if any_bounds {
Some(coefficient_lower_bounds)
} else {
None
},
linear_constraints,
intercept_range: 0..1,
linear_ranges,
random_effect_ranges,
random_effect_levels,
smooth,
})
}
/// Build a [`TermCollectionDesign`] for a single spec.
///
/// Runs the joint spatial-center planner over the spec's smooth terms
/// (as a one-element batch) before realizing the design, so center
/// placement matches the multi-spec joint path.
pub fn build_term_collection_design(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
) -> Result<TermCollectionDesign, BasisError> {
    // The planner works on batches of smooth-term blocks; hand it a
    // one-element batch and take the single planned block back out.
    let batch = vec![spec.smooth_terms.clone()];
    let mut planned = plan_joint_spatial_centers_for_term_blocks(data, &batch)?;
    let planned_smooth_terms = planned.pop().ok_or_else(|| {
        BasisError::InvalidInput(
            "joint spatial center planner returned no smooth terms for single-spec build"
                .to_string(),
        )
    })?;
    let mut planned_spec = spec.clone();
    planned_spec.smooth_terms = planned_smooth_terms;
    build_term_collection_design_inner(data, &planned_spec)
}
/// Build designs for several specs at once, planning spatial centers
/// jointly across all of their smooth-term blocks.
pub fn build_term_collection_designs_joint(
    data: ArrayView2<'_, f64>,
    specs: &[TermCollectionSpec],
) -> Result<Vec<TermCollectionDesign>, BasisError> {
    let smooth_blocks: Vec<_> = specs.iter().map(|s| s.smooth_terms.clone()).collect();
    let planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &smooth_blocks)?;
    // Pair each spec with its planned smooth terms, then realize designs
    // in order; the first failure short-circuits the whole batch.
    specs
        .iter()
        .zip(planned_blocks)
        .map(|(spec, planned_terms)| {
            let mut planned_spec = spec.clone();
            planned_spec.smooth_terms = planned_terms;
            build_term_collection_design_inner(data, &planned_spec)
        })
        .collect()
}
/// Jointly build designs for `specs`, then freeze each spec against its
/// realized design, returning both the designs and the resolved specs.
pub fn build_term_collection_designs_and_freeze_joint(
    data: ArrayView2<'_, f64>,
    specs: &[TermCollectionSpec],
) -> Result<(Vec<TermCollectionDesign>, Vec<TermCollectionSpec>), EstimationError> {
    let designs = build_term_collection_designs_joint(data, specs)?;
    let resolved_specs = specs
        .iter()
        .zip(designs.iter())
        .map(|(spec, design)| freeze_term_collection_from_design(spec, design))
        .collect::<Result<Vec<_>, _>>()?;
    Ok((designs, resolved_specs))
}
/// Data-matrix columns consumed by a smooth term's basis.
fn smooth_term_feature_cols(term: &SmoothTermSpec) -> Vec<usize> {
    match &term.basis {
        // A 1-D B-spline reads exactly one feature column.
        SmoothBasisSpec::BSpline1D { feature_col, .. } => vec![*feature_col],
        // Every multivariate basis carries an explicit column list.
        SmoothBasisSpec::ThinPlate { feature_cols, .. } => feature_cols.clone(),
        SmoothBasisSpec::Matern { feature_cols, .. } => feature_cols.clone(),
        SmoothBasisSpec::Duchon { feature_cols, .. } => feature_cols.clone(),
        SmoothBasisSpec::TensorBSpline { feature_cols, .. } => feature_cols.clone(),
    }
}
fn smooth_basis_family_rank(term: &SmoothTermSpec) -> u8 {
match &term.basis {
SmoothBasisSpec::BSpline1D { .. } => 0,
SmoothBasisSpec::TensorBSpline { .. } => 1,
SmoothBasisSpec::ThinPlate { .. } => 2,
SmoothBasisSpec::Matern { .. } => 3,
SmoothBasisSpec::Duchon { .. } => 4,
}
}
/// Reports whether the term's identifiability policy is a pre-computed
/// (frozen) transform. Each basis family carries its own identifiability
/// enum, so the five arms cannot be collapsed into one pattern.
fn smooth_has_frozen_identifiability(term: &SmoothTermSpec) -> bool {
    match &term.basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => {
            matches!(spec.identifiability, BSplineIdentifiability::FrozenTransform { .. })
        }
        SmoothBasisSpec::ThinPlate { spec, .. } => {
            matches!(spec.identifiability, SpatialIdentifiability::FrozenTransform { .. })
        }
        SmoothBasisSpec::Matern { spec, .. } => {
            matches!(spec.identifiability, MaternIdentifiability::FrozenTransform { .. })
        }
        SmoothBasisSpec::Duchon { spec, .. } => {
            matches!(spec.identifiability, SpatialIdentifiability::FrozenTransform { .. })
        }
        SmoothBasisSpec::TensorBSpline { spec, .. } => {
            matches!(spec.identifiability, TensorBSplineIdentifiability::FrozenTransform { .. })
        }
    }
}
/// Total order over smooth terms used for ownership/processing order:
/// fewer feature columns first, then lexicographic column lists, then
/// basis family rank, then name, and finally the original index so the
/// ordering is always a strict total order.
fn compare_smooth_ownership_priority(
    lhs_idx: usize,
    lhs: &SmoothTermSpec,
    rhs_idx: usize,
    rhs: &SmoothTermSpec,
) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    let lhs_cols = smooth_term_feature_cols(lhs);
    let rhs_cols = smooth_term_feature_cols(rhs);
    let by_arity = lhs_cols.len().cmp(&rhs_cols.len());
    if by_arity != Ordering::Equal {
        return by_arity;
    }
    let by_cols = lhs_cols.cmp(&rhs_cols);
    if by_cols != Ordering::Equal {
        return by_cols;
    }
    smooth_basis_family_rank(lhs)
        .cmp(&smooth_basis_family_rank(rhs))
        .then_with(|| lhs.name.cmp(&rhs.name))
        .then(lhs_idx.cmp(&rhs_idx))
}
/// An earlier-processed term "owns" a later one when every feature column
/// of the owner also appears among the target's feature columns (subset test).
fn smooth_is_owned_by_prior_term(owner: &SmoothTermSpec, target: &SmoothTermSpec) -> bool {
    let target_features: BTreeSet<usize> =
        smooth_term_feature_cols(target).into_iter().collect();
    smooth_term_feature_cols(owner)
        .iter()
        .all(|col| target_features.contains(col))
}
/// Assemble the dense `n`-row constraint block `[parametric | owners…]`
/// that a dependent smooth must be made orthogonal to.
///
/// Owner designs are realized in bounded row chunks so operator-backed
/// `DesignMatrix` values never have to materialize all `n` rows at once.
fn build_constraint_block(
    n: usize,
    parametric_block: Option<&Array2<f64>>,
    owner_blocks: &[&DesignMatrix],
) -> Result<Array2<f64>, BasisError> {
    const CHUNK: usize = 1024;
    let total_cols = parametric_block.map_or(0, |mat| mat.ncols())
        + owner_blocks.iter().map(|d| d.ncols()).sum::<usize>();
    let mut block = Array2::<f64>::zeros((n, total_cols));
    let mut col = 0usize;
    if let Some(parametric) = parametric_block {
        let next = col + parametric.ncols();
        block.slice_mut(s![.., col..next]).assign(parametric);
        col = next;
    }
    for owner in owner_blocks {
        let next = col + owner.ncols();
        let mut row = 0usize;
        while row < n {
            let row_end = (row + CHUNK).min(n);
            let chunk = owner
                .try_row_chunk(row..row_end)
                .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
            block
                .slice_mut(s![row..row_end, col..next])
                .assign(&chunk);
            row = row_end;
        }
        col = next;
    }
    Ok(block)
}
/// Relative Frobenius magnitude of `lhsᵀ·rhs` against `‖lhs‖_F·‖rhs‖_F`,
/// accumulated in bounded row chunks; near zero when the two designs are
/// mutually orthogonal. Errors when row counts disagree.
fn design_cross_relative_residual(
    lhs: &DesignMatrix,
    rhs: &DesignMatrix,
) -> Result<f64, BasisError> {
    const CHUNK: usize = 1024;
    let n = lhs.nrows();
    if rhs.nrows() != n {
        return Err(BasisError::ConstraintMatrixRowMismatch {
            basisrows: n,
            constraintrows: rhs.nrows(),
        });
    }
    let mut cross = Array2::<f64>::zeros((lhs.ncols(), rhs.ncols()));
    let mut lhs_sumsq = 0.0f64;
    let mut rhs_sumsq = 0.0f64;
    let mut start = 0usize;
    while start < n {
        let end = (start + CHUNK).min(n);
        let lhs_chunk = lhs
            .try_row_chunk(start..end)
            .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
        let rhs_chunk = rhs
            .try_row_chunk(start..end)
            .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
        cross += &crate::faer_ndarray::fast_atb(&lhs_chunk, &rhs_chunk);
        lhs_sumsq += lhs_chunk.iter().map(|v| v * v).sum::<f64>();
        rhs_sumsq += rhs_chunk.iter().map(|v| v * v).sum::<f64>();
        start = end;
    }
    let num = cross.iter().map(|v| v * v).sum::<f64>().sqrt();
    // Floor the denominator so an all-zero design cannot divide by zero.
    let denom = (lhs_sumsq.sqrt() * rhs_sumsq.sqrt()).max(1e-300);
    Ok(num / denom)
}
/// True when any linear term reads a feature column that the smooth
/// term's basis also consumes.
fn smooth_has_overlapping_linear_terms(
    linear_terms: &[LinearTermSpec],
    termspec: &SmoothTermSpec,
) -> bool {
    let feature_cols = smooth_term_feature_cols(termspec);
    for linear in linear_terms {
        if feature_cols.contains(&linear.feature_col) {
            return true;
        }
    }
    false
}
/// Feature columns shared between the smooth term's basis and the linear
/// terms, deduplicated while preserving first-occurrence order.
fn smooth_intrinsic_parametric_feature_cols(
    linear_terms: &[LinearTermSpec],
    term: &SmoothTermSpec,
) -> Vec<usize> {
    let feature_cols = smooth_term_feature_cols(term);
    let mut owned: Vec<usize> = Vec::new();
    let shared = linear_terms
        .iter()
        .map(|linear| linear.feature_col)
        .filter(|col| feature_cols.contains(col));
    for col in shared {
        // Linear-term lists are tiny, so a linear dedup scan is fine here.
        if !owned.contains(&col) {
            owned.push(col);
        }
    }
    owned
}
/// Apply global (cross-term) identifiability constraints to a raw smooth
/// design, producing a constrained `SmoothDesign`.
///
/// Terms are processed in a deterministic ownership order (see
/// `compare_smooth_ownership_priority`). Each term is made orthogonal to
/// (a) a parametric block of intercept/linear columns when policy or
/// column overlap requires it, and (b) already-processed "owner" terms
/// whose feature columns it subsumes and with which it measurably
/// overlaps. Penalties, linear inequality constraints, bounds, and basis
/// metadata are pushed through the same per-term transform `Z`, then all
/// per-term pieces are reassembled into global (column-offset) form.
fn apply_global_smooth_identifiability(
    smooth: RawSmoothDesign,
    data: ArrayView2<'_, f64>,
    linear_terms: &[LinearTermSpec],
    smoothspecs: &[SmoothTermSpec],
) -> Result<SmoothDesign, BasisError> {
    if smoothspecs.len() != smooth.terms.len() {
        return Err(BasisError::DimensionMismatch(format!(
            "smooth spec count ({}) does not match built term count ({})",
            smoothspecs.len(),
            smooth.terms.len()
        )));
    }
    if smooth.terms.is_empty() {
        return Ok(smooth.into());
    }
    // Per-term outputs, indexed by the term's ORIGINAL position (not the
    // ownership processing order), so downstream assembly stays stable.
    let mut local_designs = vec![None; smooth.terms.len()];
    let mut local_penalties = vec![Vec::<Array2<f64>>::new(); smooth.terms.len()];
    let mut local_nullspaces = vec![Vec::<usize>::new(); smooth.terms.len()];
    let mut local_penaltyinfo = vec![Vec::<PenaltyInfo>::new(); smooth.terms.len()];
    let mut local_metadata = vec![None; smooth.terms.len()];
    let mut local_dims = vec![0usize; smooth.terms.len()];
    let mut local_linear_constraints = vec![None; smooth.terms.len()];
    // Owners must be processed before the terms that depend on them.
    let mut ownership_order: Vec<usize> = (0..smooth.terms.len()).collect();
    ownership_order.sort_by(|&lhs, &rhs| {
        compare_smooth_ownership_priority(lhs, &smoothspecs[lhs], rhs, &smoothspecs[rhs])
    });
    let mut processed_owner_indices = Vec::<usize>::with_capacity(smooth.terms.len());
    use rayon::iter::{
        IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
    };
    for &idx in &ownership_order {
        let term = &smooth.terms[idx];
        let termspec = &smoothspecs[idx];
        let design_local = smooth.term_designs[idx].clone();
        // Frozen transforms and lower-bounded coefficients must not be
        // re-reparameterized here.
        let skip_global_transform =
            smooth_has_frozen_identifiability(termspec) || term.lower_bounds_local.is_some();
        let owner_indices = if skip_global_transform {
            Vec::new()
        } else {
            let overlap_tol = 1e-10;
            // Evaluate cross-residuals against candidate owners in
            // parallel; only owners with non-negligible overlap count.
            let owner_cross_checks = processed_owner_indices
                .iter()
                .copied()
                .filter(|&owner_idx| {
                    smooth_is_owned_by_prior_term(&smoothspecs[owner_idx], termspec)
                })
                .collect::<Vec<_>>()
                .into_par_iter()
                .map(|owner_idx| {
                    let owner_design = local_designs[owner_idx]
                        .as_ref()
                        .expect("owner design must be available before dependent smooth");
                    design_cross_relative_residual(&design_local, owner_design)
                        .map(|rel| (owner_idx, rel))
                })
                .collect::<Vec<_>>();
            let mut out = Vec::new();
            for check in owner_cross_checks {
                let (owner_idx, rel) = check?;
                if rel > overlap_tol {
                    out.push(owner_idx);
                }
            }
            out
        };
        let owner_blocks = owner_indices
            .iter()
            .map(|owner_idx| {
                local_designs[*owner_idx]
                    .as_ref()
                    .expect("owner design must be available before dependent smooth")
            })
            .collect::<Vec<_>>();
        // The parametric block is needed when linear terms overlap this
        // smooth's columns or the term's policy explicitly demands it.
        let needs_parametric_block = !skip_global_transform
            && (smooth_has_overlapping_linear_terms(linear_terms, termspec)
                || !smooth_intrinsic_parametric_feature_cols(linear_terms, termspec).is_empty()
                || matches!(
                    spatial_identifiability_policy(termspec),
                    Some(SpatialIdentifiability::OrthogonalToParametric)
                ));
        let parametric_block = if !needs_parametric_block {
            None
        } else {
            Some(build_parametric_constraint_block_for_term(
                data,
                linear_terms,
                termspec,
            )?)
        };
        let c_local =
            if skip_global_transform || (parametric_block.is_none() && owner_blocks.is_empty()) {
                None
            } else {
                Some(build_constraint_block(
                    data.nrows(),
                    parametric_block.as_ref(),
                    &owner_blocks,
                )?)
            };
        let z_opt = if skip_global_transform {
            None
        } else {
            match maybe_smooth_identifiability_transform(
                termspec,
                &design_local,
                c_local.as_ref().map(|mat| mat.view()),
            ) {
                Ok(z_opt) => z_opt,
                // A collapsed nullspace against owner blocks means this term
                // is fully absorbed by its owners: keep zero columns.
                Err(BasisError::ConstraintNullspaceCollapsed { .. }) if !owner_blocks.is_empty() => {
                    Some(Array2::zeros((design_local.ncols(), 0)))
                }
                Err(err) => return Err(err),
            }
        };
        let design_constrained = if let Some(z) = z_opt.as_ref() {
            apply_smooth_transform_to_design(design_local, z, &term.name)?
        } else {
            design_local
        };
        // Verify the transform actually achieved orthogonality.
        if let Some(c_ref) = c_local.as_ref() {
            let rel =
                orthogonality_relative_residual_for_design(&design_constrained, c_ref.view())?;
            let tol = 1e-8;
            if rel > tol {
                return Err(BasisError::InvalidInput(format!(
                    "smooth orthogonality residual too large for term '{}': {:.3e} > {:.1e}",
                    term.name, rel, tol
                )));
            }
        }
        let active_penaltyinfo = term
            .penaltyinfo_local
            .iter()
            .filter(|info| info.active)
            .cloned()
            .collect::<Vec<_>>();
        if active_penaltyinfo.len() != term.penalties_local.len() {
            return Err(BasisError::InvalidInput(format!(
                "internal penalty metadata mismatch for term '{}': activeinfos={}, penalties={}",
                term.name,
                active_penaltyinfo.len(),
                term.penalties_local.len()
            )));
        }
        // Push each local penalty through the transform: S -> ZᵀSZ.
        let penalties_constrained: Vec<Array2<f64>> = term
            .penalties_local
            .par_iter()
            .map(|s_local| {
                if let Some(z) = z_opt.as_ref() {
                    let zt_s = z.t().dot(s_local);
                    zt_s.dot(z)
                } else {
                    s_local.clone()
                }
            })
            .collect();
        let penalty_candidates = penalties_constrained
            .into_par_iter()
            .zip(active_penaltyinfo.into_par_iter())
            .map(|(matrix, info)| {
                // Renormalize in the constrained space; fold the rescale
                // factor into the running normalization scale.
                let (matrix, c_new) = normalize_penalty_in_constrained_space(&matrix);
                PenaltyCandidate {
                    nullspace_dim_hint: info.nullspace_dim_hint,
                    matrix,
                    source: info.source,
                    normalization_scale: info.normalization_scale * c_new,
                    kronecker_factors: None,
                    op: None,
                }
            })
            .collect::<Vec<_>>();
        let (penalties_constrained, nullspace_constrained, penaltyinfo_constrained) =
            filter_active_penalty_candidates(penalty_candidates)?;
        // Linear inequality constraints transform as A -> A·Z.
        let linear_constraints_constrained =
            if let Some(lin_local) = term.linear_constraints_local.as_ref() {
                if let Some(z) = z_opt.as_ref() {
                    Some(LinearInequalityConstraints {
                        a: lin_local.a.dot(z),
                        b: lin_local.b.clone(),
                    })
                } else {
                    Some(lin_local.clone())
                }
            } else {
                None
            };
        local_dims[idx] = design_constrained.ncols();
        local_designs[idx] = Some(design_constrained);
        local_penalties[idx] = penalties_constrained;
        local_nullspaces[idx] = nullspace_constrained;
        local_penaltyinfo[idx] = penaltyinfo_constrained;
        local_linear_constraints[idx] = linear_constraints_constrained;
        local_metadata[idx] = Some(with_identifiability_transform(
            &term.metadata,
            z_opt.as_ref(),
        )?);
        processed_owner_indices.push(idx);
    }
    // ---- Reassemble per-term pieces into global, column-offset form. ----
    let total_p: usize = local_dims.iter().sum();
    let mut terms_out = Vec::<SmoothTerm>::with_capacity(smooth.terms.len());
    let mut penalties_global = Vec::<BlockwisePenalty>::new();
    let mut nullspace_dims_global = Vec::<usize>::new();
    let mut penaltyinfo_global = Vec::<PenaltyBlockInfo>::new();
    let mut dropped_penaltyinfo_global = smooth.dropped_penaltyinfo.clone();
    let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
    let mut any_bounds = false;
    let mut linear_constraintsrows: Vec<Array1<f64>> = Vec::new();
    let mut linear_constraints_b: Vec<f64> = Vec::new();
    let mut col_start = 0usize;
    for idx in 0..smooth.terms.len() {
        let p_local = local_dims[idx];
        let col_end = col_start + p_local;
        let activeinfos = local_penaltyinfo[idx]
            .iter()
            .filter(|info| info.active)
            .collect::<Vec<_>>();
        if activeinfos.len() != local_penalties[idx].len() {
            return Err(BasisError::InvalidInput(format!(
                "internal penalty info mismatch for term '{}': activeinfos={}, penalties={}",
                smooth.terms[idx].name,
                activeinfos.len(),
                local_penalties[idx].len()
            )));
        }
        for ((s_local, &ns), info) in local_penalties[idx]
            .iter()
            .zip(local_nullspaces[idx].iter())
            .zip(activeinfos.into_iter())
        {
            let global_index = penalties_global.len();
            penalties_global.push(BlockwisePenalty::new(col_start..col_end, s_local.clone()));
            nullspace_dims_global.push(ns);
            let mut penalty = info.clone();
            penalty.nullspace_dim_hint = ns;
            penaltyinfo_global.push(PenaltyBlockInfo {
                global_index,
                termname: Some(smooth.terms[idx].name.clone()),
                penalty,
            });
        }
        // Inactive penalties are recorded as dropped, not discarded.
        for info in local_penaltyinfo[idx].iter().filter(|info| !info.active) {
            dropped_penaltyinfo_global.push(DroppedPenaltyBlockInfo {
                termname: Some(smooth.terms[idx].name.clone()),
                penalty: info.clone(),
            });
        }
        terms_out.push(SmoothTerm {
            name: smooth.terms[idx].name.clone(),
            coeff_range: col_start..col_end,
            shape: smooth.terms[idx].shape,
            penalties_local: local_penalties[idx].clone(),
            nullspace_dims: local_nullspaces[idx].clone(),
            penaltyinfo_local: local_penaltyinfo[idx].clone(),
            metadata: local_metadata[idx]
                .clone()
                .expect("local metadata must exist for every smooth term"),
            lower_bounds_local: smooth.terms[idx].lower_bounds_local.clone(),
            linear_constraints_local: local_linear_constraints[idx].clone(),
            kronecker_factored: None,
        });
        // Embed each term's inequality rows into total_p-wide rows.
        if let Some(lin_local) = &local_linear_constraints[idx] {
            for r in 0..lin_local.a.nrows() {
                let mut row = Array1::<f64>::zeros(total_p);
                row.slice_mut(s![col_start..col_end])
                    .assign(&lin_local.a.row(r));
                linear_constraintsrows.push(row);
                linear_constraints_b.push(lin_local.b[r]);
            }
        }
        if let Some(lb_local) = smooth.terms[idx].lower_bounds_local.as_ref()
            && lb_local.len() == p_local
        {
            coefficient_lower_bounds
                .slice_mut(s![col_start..col_end])
                .assign(lb_local);
            any_bounds = true;
        }
        col_start = col_end;
    }
    debug_assert_eq!(
        penalties_global.len(),
        nullspace_dims_global.len(),
        "globally reparameterized smooth penalty/nullspace bookkeeping diverged"
    );
    debug_assert_eq!(
        penalties_global.len(),
        penaltyinfo_global.len(),
        "globally reparameterized smooth penalty metadata bookkeeping diverged"
    );
    Ok(SmoothDesign {
        term_designs: local_designs
            .into_iter()
            .map(|design| design.expect("local design must exist for every smooth term"))
            .collect(),
        penalties: penalties_global,
        nullspace_dims: nullspace_dims_global,
        penaltyinfo: penaltyinfo_global,
        dropped_penaltyinfo: dropped_penaltyinfo_global,
        terms: terms_out,
        coefficient_lower_bounds: if any_bounds {
            Some(coefficient_lower_bounds)
        } else {
            None
        },
        linear_constraints: if linear_constraintsrows.is_empty() {
            None
        } else {
            let mut a = Array2::<f64>::zeros((linear_constraintsrows.len(), total_p));
            for (i, row) in linear_constraintsrows.iter().enumerate() {
                a.row_mut(i).assign(row);
            }
            Some(LinearInequalityConstraints {
                a,
                b: Array1::from_vec(linear_constraints_b),
            })
        },
    })
}
/// Build the dense parametric block `[1 | x_{j1} | x_{j2} | …]` that a
/// smooth term must remain orthogonal to: an intercept column followed by
/// every raw feature column the term shares with the linear terms.
///
/// # Errors
/// Returns `BasisError::DimensionMismatch` when any implicated feature
/// column index lies outside `data`'s column range.
fn build_parametric_constraint_block_for_term(
    data: ArrayView2<'_, f64>,
    linear_terms: &[LinearTermSpec],
    termspec: &SmoothTermSpec,
) -> Result<Array2<f64>, BasisError> {
    let n = data.nrows();
    let p_data = data.ncols();
    let feature_cols = smooth_term_feature_cols(termspec);
    let mut parametric_cols = smooth_intrinsic_parametric_feature_cols(linear_terms, termspec);
    // Repaired: this loop originally iterated `¶metric_cols` (an HTML-entity
    // mangling of `&parametric_cols`), which does not compile.
    for &feature_col in &parametric_cols {
        if feature_col >= p_data {
            return Err(BasisError::DimensionMismatch(format!(
                "smooth term feature column {feature_col} out of bounds for {p_data} columns"
            )));
        }
    }
    // Extend with any overlapping linear-term columns not already present,
    // bounds-checking each one with its own (named) error message.
    for linear in linear_terms
        .iter()
        .filter(|linear| feature_cols.contains(&linear.feature_col))
    {
        if linear.feature_col >= p_data {
            return Err(BasisError::DimensionMismatch(format!(
                "linear term '{}' feature column {} out of bounds for {} columns",
                linear.name, linear.feature_col, p_data
            )));
        }
        if !parametric_cols.contains(&linear.feature_col) {
            parametric_cols.push(linear.feature_col);
        }
    }
    // Column 0 is the intercept; subsequent columns copy raw features.
    let mut c = Array2::<f64>::zeros((n, 1 + parametric_cols.len()));
    c.column_mut(0).fill(1.0);
    for (j, &feature_col) in parametric_cols.iter().enumerate() {
        c.column_mut(j + 1).assign(&data.column(feature_col));
    }
    Ok(c)
}
/// Right-multiply a term's design by an identifiability transform,
/// always yielding a dense design.
///
/// Dense inputs are wrapped in a lazy coefficient-transform operator;
/// sparse inputs are densified first and multiplied eagerly.
fn apply_smooth_transform_to_design(
    design_local: DesignMatrix,
    transform: &Array2<f64>,
    termname: &str,
) -> Result<DesignMatrix, BasisError> {
    match design_local {
        DesignMatrix::Dense(inner) => {
            // Defer the multiplication: the operator applies the transform
            // on demand instead of materializing B·Z up front.
            let op = CoefficientTransformOperator::new(inner, transform.clone()).map_err(|e| {
                BasisError::InvalidInput(format!(
                    "smooth identifiability transform failed for term '{termname}': {e}"
                ))
            })?;
            let wrapped = crate::matrix::DenseDesignMatrix::from(Arc::new(op));
            Ok(DesignMatrix::Dense(wrapped))
        }
        DesignMatrix::Sparse(inner) => {
            // Sparse storage cannot host the lazy operator; realize B first.
            let dense_b = inner
                .try_to_dense_arc("smooth identifiability sparse transform")
                .map_err(BasisError::InvalidInput)?;
            let product = dense_b.as_ref().dot(transform);
            Ok(DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                product,
            )))
        }
    }
}
/// Compute `designᵀ · C` in bounded row chunks without materializing the
/// whole design. Errors when `C`'s row count differs from the design's.
fn design_constraint_cross(
    design: &DesignMatrix,
    constraint_matrix: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, BasisError> {
    const CHUNK: usize = 1024;
    let n = design.nrows();
    if constraint_matrix.nrows() != n {
        return Err(BasisError::ConstraintMatrixRowMismatch {
            basisrows: n,
            constraintrows: constraint_matrix.nrows(),
        });
    }
    let mut cross = Array2::<f64>::zeros((design.ncols(), constraint_matrix.ncols()));
    let mut start = 0usize;
    while start < n {
        let end = (start + CHUNK).min(n);
        let design_chunk = design
            .try_row_chunk(start..end)
            .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
        let constraint_chunk = constraint_matrix.slice(s![start..end, ..]).to_owned();
        cross += &crate::faer_ndarray::fast_atb(&design_chunk, &constraint_chunk);
        start = end;
    }
    Ok(cross)
}
/// Frobenius norm of a design matrix, accumulated chunk-by-chunk so lazy
/// designs are never fully materialized.
fn design_frobenius_norm(design: &DesignMatrix) -> Result<f64, BasisError> {
    const CHUNK: usize = 1024;
    let n = design.nrows();
    let mut sumsq = 0.0f64;
    let mut start = 0usize;
    while start < n {
        let end = (start + CHUNK).min(n);
        let chunk = design
            .try_row_chunk(start..end)
            .map_err(|e| BasisError::InvalidInput(e.to_string()))?;
        sumsq = chunk.iter().fold(sumsq, |acc, v| acc + v * v);
        start = end;
    }
    Ok(sumsq.sqrt())
}
/// Decide the identifiability transform for one smooth term.
///
/// A frozen spatial policy wins outright (after a dimension check). With
/// no constraint block — or an empty one — no transform is required;
/// otherwise an orthogonality transform is derived against the block.
fn maybe_smooth_identifiability_transform(
    termspec: &SmoothTermSpec,
    design_local: &DesignMatrix,
    constraint_block: Option<ArrayView2<'_, f64>>,
) -> Result<Option<Array2<f64>>, BasisError> {
    if let Some(SpatialIdentifiability::FrozenTransform { transform }) =
        spatial_identifiability_policy(termspec)
    {
        return if design_local.ncols() == transform.nrows() {
            Ok(Some(transform.clone()))
        } else {
            Err(BasisError::DimensionMismatch(format!(
                "frozen spatial identifiability transform mismatch: design has {} columns but transform has {} rows",
                design_local.ncols(),
                transform.nrows()
            )))
        };
    }
    match constraint_block {
        None => Ok(None),
        Some(c) if c.ncols() == 0 => Ok(None),
        Some(c) => orthogonality_transform_for_design(design_local, c, None).map(Some),
    }
}
/// The spatial identifiability policy for basis families that have one
/// (thin-plate and Duchon); `None` for every other family.
fn spatial_identifiability_policy(termspec: &SmoothTermSpec) -> Option<&SpatialIdentifiability> {
    match &termspec.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => Some(&spec.identifiability),
        SmoothBasisSpec::Duchon { spec, .. } => Some(&spec.identifiability),
        SmoothBasisSpec::BSpline1D { .. }
        | SmoothBasisSpec::Matern { .. }
        | SmoothBasisSpec::TensorBSpline { .. } => None,
    }
}
/// Compose two optional identifiability transforms.
///
/// With both present: when inner dimensions line up the result is
/// `existing · extra`; when the two have identical shapes the newer one
/// replaces the older; any other shape pairing is an error. A single
/// transform passes through as a clone; none yields `None`.
fn compose_identifiability_transforms(
    existing: Option<&Array2<f64>>,
    extra: Option<&Array2<f64>>,
) -> Result<Option<Array2<f64>>, BasisError> {
    let (lhs, rhs) = match (existing, extra) {
        (None, None) => return Ok(None),
        (Some(only), None) | (None, Some(only)) => return Ok(Some(only.clone())),
        (Some(lhs), Some(rhs)) => (lhs, rhs),
    };
    if lhs.ncols() == rhs.nrows() {
        Ok(Some(lhs.dot(rhs)))
    } else if lhs.dim() == rhs.dim() {
        // Same shape and not composable: the extra transform supersedes.
        Ok(Some(rhs.clone()))
    } else {
        Err(BasisError::DimensionMismatch(format!(
            "identifiability transform mismatch: existing is {}x{}, extra is {}x{}",
            lhs.nrows(),
            lhs.ncols(),
            rhs.nrows(),
            rhs.ncols(),
        )))
    }
}
/// Return a copy of `metadata` whose `identifiability_transform` has the
/// given `transform` composed onto it (via
/// `compose_identifiability_transforms`); every other field is cloned
/// unchanged. Each basis-metadata variant is handled explicitly.
fn with_identifiability_transform(
    metadata: &BasisMetadata,
    transform: Option<&Array2<f64>>,
) -> Result<BasisMetadata, BasisError> {
    match metadata {
        BasisMetadata::BSpline1D {
            knots,
            identifiability_transform,
        } => Ok(BasisMetadata::BSpline1D {
            knots: knots.clone(),
            identifiability_transform: compose_identifiability_transforms(
                identifiability_transform.as_ref(),
                transform,
            )?,
        }),
        BasisMetadata::ThinPlate {
            centers,
            length_scale,
            identifiability_transform,
            input_scales,
            radial_reparam,
        } => Ok(BasisMetadata::ThinPlate {
            centers: centers.clone(),
            length_scale: *length_scale,
            identifiability_transform: compose_identifiability_transforms(
                identifiability_transform.as_ref(),
                transform,
            )?,
            input_scales: input_scales.clone(),
            radial_reparam: radial_reparam.clone(),
        }),
        BasisMetadata::Matern {
            centers,
            length_scale,
            nu,
            include_intercept,
            identifiability_transform,
            input_scales,
            aniso_log_scales,
        } => Ok(BasisMetadata::Matern {
            centers: centers.clone(),
            length_scale: *length_scale,
            nu: *nu,
            include_intercept: *include_intercept,
            identifiability_transform: compose_identifiability_transforms(
                identifiability_transform.as_ref(),
                transform,
            )?,
            input_scales: input_scales.clone(),
            aniso_log_scales: aniso_log_scales.clone(),
        }),
        BasisMetadata::Duchon {
            centers,
            length_scale,
            power,
            nullspace_order,
            identifiability_transform,
            input_scales,
            aniso_log_scales,
        } => Ok(BasisMetadata::Duchon {
            centers: centers.clone(),
            length_scale: *length_scale,
            power: *power,
            nullspace_order: *nullspace_order,
            input_scales: input_scales.clone(),
            aniso_log_scales: aniso_log_scales.clone(),
            identifiability_transform: compose_identifiability_transforms(
                identifiability_transform.as_ref(),
                transform,
            )?,
        }),
        BasisMetadata::TensorBSpline {
            feature_cols,
            knots,
            degrees,
            identifiability_transform,
        } => Ok(BasisMetadata::TensorBSpline {
            feature_cols: feature_cols.clone(),
            knots: knots.clone(),
            degrees: degrees.clone(),
            identifiability_transform: compose_identifiability_transforms(
                identifiability_transform.as_ref(),
                transform,
            )?,
        }),
    }
}
/// Relative residual `‖Bᵀ·C‖_F / (‖B‖_F · ‖C‖_F)`: near zero when the
/// design `B` is orthogonal to the constraint block `C`.
fn orthogonality_relative_residual_for_design(
    design: &DesignMatrix,
    constraint_matrix: ArrayView2<'_, f64>,
) -> Result<f64, BasisError> {
    let cross = design_constraint_cross(design, constraint_matrix)?;
    let num = cross.iter().map(|v| v * v).sum::<f64>().sqrt();
    let b_norm = design_frobenius_norm(design)?;
    let c_norm = constraint_matrix.iter().map(|v| v * v).sum::<f64>().sqrt();
    // Floor the denominator so empty/zero blocks cannot divide by zero.
    Ok(num / (b_norm * c_norm).max(1e-300))
}
/// Fit a term collection from raw data and a spec.
///
/// Convenience wrapper over the heuristic-lambda variant with no
/// externally supplied smoothing-parameter seeds.
pub fn fit_term_collection_forspec(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    spec: &TermCollectionSpec,
    family: LikelihoodFamily,
    options: &FitOptions,
) -> Result<FittedTermCollection, EstimationError> {
    let heuristic_lambdas: Option<&[f64]> = None;
    fit_term_collection_forspecwith_heuristic_lambdas(
        data,
        y,
        weights,
        offset,
        spec,
        heuristic_lambdas,
        family,
        options,
    )
}
/// Realize the design for `spec`, then fit on it. Heuristic lambdas, when
/// provided, are forwarded to seed the smoothing-parameter search.
fn fit_term_collection_forspecwith_heuristic_lambdas(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    spec: &TermCollectionSpec,
    heuristic_lambdas: Option<&[f64]>,
    family: LikelihoodFamily,
    options: &FitOptions,
) -> Result<FittedTermCollection, EstimationError> {
    let realized = build_term_collection_design(data, spec)?;
    fit_term_collection_on_realized_design(
        y,
        weights,
        offset,
        spec,
        &realized,
        heuristic_lambdas,
        family,
        options,
    )
}
/// True when at least one linear term declares bounded coefficient
/// geometry, which routes the fit through the constrained path.
fn has_bounded_linear_terms(spec: &TermCollectionSpec) -> bool {
    for term in &spec.linear_terms {
        if matches!(
            term.coefficient_geometry,
            LinearCoefficientGeometry::Bounded { .. }
        ) {
            return true;
        }
    }
    false
}
/// Fit a term collection on an already-realized design.
///
/// Dispatch order:
/// 1. Bounded linear coefficients force the constrained fitting path.
/// 2. Otherwise a standard GAM fit runs, followed by a feasibility check
///    on the design's constraints.
/// 3. If adaptive regularization is enabled AND spatial operator caches
///    can be extracted for this fit, an exact spatial adaptive
///    regularization pass refines the result; otherwise the plain fit is
///    returned as-is.
fn fit_term_collection_on_realized_design(
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    heuristic_lambdas: Option<&[f64]>,
    family: LikelihoodFamily,
    options: &FitOptions,
) -> Result<FittedTermCollection, EstimationError> {
    if has_bounded_linear_terms(spec) {
        return fit_bounded_term_collection_with_design(
            y,
            weights,
            offset,
            spec,
            design,
            heuristic_lambdas,
            family,
            options,
        );
    }
    // Base options may be adapted to the realized design (see
    // `adaptive_fit_options_base`) before the unconstrained fit.
    let base_fit_opts = adaptive_fit_options_base(options, design);
    let fitted = FittedTermCollection {
        fit: fit_gamwith_heuristic_lambdas(
            design.design.clone(),
            y,
            weights,
            offset,
            &design.penalties,
            heuristic_lambdas,
            family,
            &base_fit_opts,
        )?,
        design: design.clone(),
        adaptive_diagnostics: None,
    };
    // Verify the unconstrained solution respects any declared constraints.
    enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
    let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
    if !adaptive_opts.enabled {
        return Ok(fitted);
    }
    let runtime_caches = extract_spatial_operator_runtime_caches(spec, &fitted.design)?;
    if runtime_caches.is_empty() {
        // Nothing spatial to adapt; the base fit stands.
        return Ok(fitted);
    }
    fit_term_collectionwith_exact_spatial_adaptive_regularization(
        fitted,
        y,
        weights,
        offset,
        family,
        options,
        &runtime_caches,
    )
}
/// Per-term cache of spatial operator data reused across the adaptive
/// regularization passes.
#[derive(Clone)]
struct SpatialOperatorRuntimeCache {
    // Name of the smooth term this cache belongs to.
    termname: String,
    // Data columns the term's basis reads.
    feature_cols: Vec<usize>,
    // The term's coefficient slice in the global coefficient vector.
    coeff_global_range: Range<usize>,
    // Global indices of the term's three penalty blocks.
    mass_penalty_global_idx: usize,
    tension_penalty_global_idx: usize,
    stiffness_penalty_global_idx: usize,
    // Collocation operator matrices; d0/d1/d2 presumably correspond to the
    // mass/tension/stiffness (value/gradient/Laplacian) operators — confirm
    // against the collocation builder.
    d0: Array2<f64>,
    d1: Array2<f64>,
    d2: Array2<f64>,
    // Points at which the operators are evaluated.
    collocation_points: Array2<f64>,
    // Spatial dimension of the term.
    dimension: usize,
}
/// Adaptive regularization weights per collocation point. Only the
/// inverse weights are needed at runtime; the forward weights are kept
/// for test-time inspection.
#[derive(Clone)]
struct SpatialAdaptiveWeights {
    #[cfg(test)]
    magweight: Array1<f64>,
    #[cfg(test)]
    gradweight: Array1<f64>,
    #[cfg(test)]
    lapweight: Array1<f64>,
    // Inverse magnitude/gradient/Laplacian weights used in the fit.
    inv_magweight: Array1<f64>,
    invgradweight: Array1<f64>,
    inv_lapweight: Array1<f64>,
}
/// State of a scalar Charbonnier penalty ρ(t) = √(t² + ε²) − ε evaluated
/// elementwise over a signal vector.
#[derive(Clone)]
struct CharbonnierScalarBlockState {
    // Raw signal values t.
    signal: Array1<f64>,
    // Cached radii r = √(t² + ε²), the shared factor in all derivatives.
    radius: Array1<f64>,
    // Smoothing parameter ε (floored at 1e-12 on construction).
    epsilon: f64,
}
impl CharbonnierScalarBlockState {
    /// Build the state from a signal vector, caching r = √(t² + ε²).
    /// ε is floored at 1e-12 so every radius is strictly positive.
    fn from_signal(signal: Array1<f64>, epsilon: f64) -> Self {
        let eps = epsilon.max(1e-12);
        let radius = signal.mapv(|t| (t * t + eps * eps).sqrt());
        Self {
            signal,
            radius,
            epsilon: eps,
        }
    }
    /// |t| per element.
    fn absolute_signal(&self) -> Array1<f64> {
        self.signal.mapv(f64::abs)
    }
    /// Total penalty Σ ρ(t) = Σ (r − ε).
    fn penalty_value(&self) -> f64 {
        self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
    }
    /// dρ/dt = t / r per element.
    fn betagradient_coeff(&self) -> Array1<f64> {
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| t / r),
        )
    }
    /// d²ρ/dt² = ε² / r³ per element (always positive: ρ is convex).
    fn betahessian_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        self.radius.mapv(|r| eps2 / r.powi(3))
    }
    /// ∂ρ/∂(log ε) = ε²/r − ε per element (derivative w.r.t. log ε, i.e.
    /// ε·∂ρ/∂ε).
    fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        self.radius.mapv(|r| eps2 / r - epsilon)
    }
    /// Mixed derivative ∂²ρ/∂t ∂(log ε) = −ε² t / r³ per element.
    fn log_epsilon_betagradient_coeff(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| -eps2 * t / r.powi(3)),
        )
    }
    /// ∂²ρ/∂(log ε)² = 2ε²/r − ε⁴/r³ − ε per element.
    fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        let eps4 = eps2 * eps2;
        self.radius
            .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
    }
    /// Half-quadratic surrogate weights w = clamp(1/r, floor, ceiling)
    /// and their elementwise inverses.
    fn surrogateweights(
        &self,
        weight_floor: f64,
        weight_ceiling: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        let weight = self
            .radius
            .mapv(|r| (1.0 / r).clamp(weight_floor, weight_ceiling));
        let invweight = weight.mapv(|u| 1.0 / u);
        (weight, invweight)
    }
    /// Directional derivative of the Hessian diagonal along a signal
    /// perturbation q: ∂(ε²/r³)/∂t · q = −3ε² t q / r⁵ per element.
    fn directionalhessian_diag(&self, direction_signal: &Array1<f64>) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(direction_signal.iter())
                .zip(self.radius.iter())
                .map(|((t, q), r)| -3.0 * eps2 * t * q / r.powi(5)),
        )
    }
    /// ∂(d²ρ/dt²)/∂(log ε) = 2ε²/r³ − 3ε⁴/r⁵ per element.
    fn log_epsilon_betahessian_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(_, r)| 2.0 * eps2 / r.powi(3) - 3.0 * eps4 / r.powi(5)),
        )
    }
    /// ∂²(dρ/dt)/∂(log ε)² = ε² t (ε² − 2t²) / r⁵ per element.
    fn log_epsilon_beta_mixed_second_coeff(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| eps2 * t * (eps2 - 2.0 * t * t) / r.powi(5)),
        )
    }
    /// ∂²(d²ρ/dt²)/∂(log ε)² = 4ε²/r³ − 18ε⁴/r⁵ + 15ε⁶/r⁷ per element.
    fn log_epsilon_betahessian_second_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        let eps6 = eps4 * eps2;
        Array1::from_iter(
            self.radius.iter().map(|r| {
                4.0 * eps2 / r.powi(3) - 18.0 * eps4 / r.powi(5) + 15.0 * eps6 / r.powi(7)
            }),
        )
    }
    /// Directional derivative (along q) of the log-ε Hessian diagonal:
    /// (−6ε² t / r⁵ + 15ε⁴ t / r⁷) · q per element.
    fn log_epsilon_betahessian_directional_diag(
        &self,
        direction_signal: &Array1<f64>,
    ) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(direction_signal.iter())
                .zip(self.radius.iter())
                .map(|((t, q), r)| (-6.0 * eps2 * t / r.powi(5) + 15.0 * eps4 * t / r.powi(7)) * q),
        )
    }
}
/// Charbonnier penalty state for a grouped (vector-valued) operator signal:
/// one block of `signal_blocks.ncols()` values per collocation point.
#[derive(Clone)]
struct CharbonnierGroupedBlockState {
    // Euclidean norm of each row of `signal_blocks` (one entry per block).
    norm: Array1<f64>,
    // Smoothed per-block norm: sqrt(norm^2 + epsilon^2).
    radius: Array1<f64>,
    // Raw signal values; one row per block, one column per component.
    signal_blocks: Array2<f64>,
    // Charbonnier smoothing parameter (floored at 1e-12 on construction).
    epsilon: f64,
}
impl CharbonnierGroupedBlockState {
    /// Builds the grouped state from per-block signal rows.
    ///
    /// `epsilon` is floored at 1e-12. `norm` stores each row's Euclidean
    /// norm; `radius` stores the Charbonnier-smoothed norm
    /// sqrt(norm^2 + eps^2).
    fn from_signal_blocks(signal_blocks: Array2<f64>, epsilon: f64) -> Self {
        let eps = epsilon.max(1e-12);
        let norm = Array1::from_iter(
            signal_blocks
                .rows()
                .into_iter()
                .map(|row| row.iter().map(|v| v * v).sum::<f64>().sqrt()),
        );
        let radius = norm.mapv(|g| (g * g + eps * eps).sqrt());
        Self {
            norm,
            radius,
            signal_blocks,
            epsilon: eps,
        }
    }
    /// Charbonnier penalty value: sum over blocks of (radius - epsilon).
    fn penalty_value(&self) -> f64 {
        self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
    }
    /// Per-block Euclidean norms of the raw signal (a clone of `norm`).
    fn norm_signal(&self) -> Array1<f64> {
        self.norm.clone()
    }
    /// Gradient of the penalty w.r.t. each block's signal: v_k / r_k per row.
    fn betagradient_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let scale = 1.0 / self.radius[k];
            row.mapv_inplace(|v| v * scale);
        }
        out
    }
    /// Per-block Hessian of the penalty: I / r_k - v_k v_k^T / r_k^3.
    fn betahessian_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            // Start from the identity scaled by 1/r, then subtract the
            // rank-one outer-product correction.
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|v| v / self.radius[k]);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] -= row[i] * row[j] / self.radius[k].powi(3);
                }
            }
            out.push(block);
        }
        out
    }
    /// Per-block derivative of the penalty in log(epsilon): eps^2 / r - eps.
    fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        self.radius.mapv(|r| eps2 / r - epsilon)
    }
    /// Mixed derivative blocks d^2/(d beta d log eps): -eps^2 v_k / r_k^3 per row.
    fn log_epsilon_betagradient_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        let eps2 = self.epsilon * self.epsilon;
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let scale = -eps2 / self.radius[k].powi(3);
            row.mapv_inplace(|v| v * scale);
        }
        out
    }
    /// Second derivative of the penalty in log(epsilon) per block:
    /// 2 eps^2 / r - eps^4 / r^3 - eps.
    fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        let eps4 = eps2 * eps2;
        self.radius
            .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
    }
    /// Clamped surrogate IRLS weights 1/r and their reciprocals; each weight
    /// is clamped to `[weight_floor, weight_ceiling]` before inversion.
    fn surrogateweights(
        &self,
        weight_floor: f64,
        weight_ceiling: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        let weight = self
            .radius
            .mapv(|r| (1.0 / r).clamp(weight_floor, weight_ceiling));
        let invweight = weight.mapv(|u| 1.0 / u);
        (weight, invweight)
    }
    /// Directional derivative of the per-block Hessian along `direction_blocks`
    /// (row q_k per block):
    /// -(v.q) I / r^3 - (q v^T + v q^T) / r^3 + 3 (v.q) v v^T / r^5.
    fn directionalhessian_blocks(&self, direction_blocks: &Array2<f64>) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, (v, q)) in self
            .signal_blocks
            .rows()
            .into_iter()
            .zip(direction_blocks.rows().into_iter())
            .enumerate()
        {
            let dim = v.len();
            let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
            let r3 = self.radius[k].powi(3);
            let r5 = self.radius[k].powi(5);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|x| -dot * x / r3);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] -= (q[i] * v[j] + v[i] * q[j]) / r3;
                    block[[i, j]] += 3.0 * dot * v[i] * v[j] / r5;
                }
            }
            out.push(block);
        }
        out
    }
    /// Per-block d(Hessian)/d(log eps):
    /// -eps^2 I / r^3 + 3 eps^2 v v^T / r^5.
    fn log_epsilon_betahessian_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            let r3 = self.radius[k].powi(3);
            let r5 = self.radius[k].powi(5);
            let mut block = Array2::<f64>::eye(dim);
            let eps2 = self.epsilon * self.epsilon;
            block.mapv_inplace(|v| -eps2 * v / r3);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * row[i] * row[j] / r5;
                }
            }
            out.push(block);
        }
        out
    }
    /// Mixed second-order coefficient rows:
    /// eps^2 (eps^2 - 2 |v|^2) v / r^5 per block.
    fn log_epsilon_beta_mixed_second_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        let eps2 = self.epsilon * self.epsilon;
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let norm2 = self.norm[k] * self.norm[k];
            let scale = eps2 * (eps2 - 2.0 * norm2) / self.radius[k].powi(5);
            row.mapv_inplace(|v| v * scale);
        }
        out
    }
    /// Per-block second log(epsilon) derivative of the Hessian:
    /// eps^2 (eps^2 - 2 |v|^2) I / r^5 + 3 eps^2 (2 |v|^2 - 3 eps^2) v v^T / r^7.
    fn log_epsilon_betahessian_second_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        let eps2 = self.epsilon * self.epsilon;
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            let norm2 = self.norm[k] * self.norm[k];
            let r5 = self.radius[k].powi(5);
            let r7 = self.radius[k].powi(7);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|v| eps2 * (eps2 - 2.0 * norm2) * v / r5);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * (2.0 * norm2 - 3.0 * eps2) * row[i] * row[j] / r7;
                }
            }
            out.push(block);
        }
        out
    }
    /// Directional d(Hessian)/d(log eps) along `direction_blocks`:
    /// 3 eps^2 (v.q) I / r^5 + 3 eps^2 (q v^T + v q^T) / r^5
    /// - 15 eps^2 (v.q) v v^T / r^7.
    fn log_epsilon_betahessian_directional_blocks(
        &self,
        direction_blocks: &Array2<f64>,
    ) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        let eps2 = self.epsilon * self.epsilon;
        for (k, (v, q)) in self
            .signal_blocks
            .rows()
            .into_iter()
            .zip(direction_blocks.rows().into_iter())
            .enumerate()
        {
            let dim = v.len();
            let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
            let r5 = self.radius[k].powi(5);
            let r7 = self.radius[k].powi(7);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|x| 3.0 * eps2 * dot * x / r5);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * (q[i] * v[j] + v[i] * q[j]) / r5;
                    block[[i, j]] -= 15.0 * eps2 * dot * v[i] * v[j] / r7;
                }
            }
            out.push(block);
        }
        out
    }
}
/// Chains a scalar per-row coefficient vector through the operator:
/// returns `operator^T * coeff`.
fn scalar_operatorgradient(operator: &Array2<f64>, coeff: &Array1<f64>) -> Array1<f64> {
    let transposed = operator.t();
    transposed.dot(coeff)
}
/// Symmetrized weighted Gram matrix 0.5 * (G + G^T) with
/// G = operator^T * diag(diag) * operator.
fn scalar_operatorhessian(operator: &Array2<f64>, diag: &Array1<f64>) -> Array2<f64> {
    // Scale each operator row by its diagonal weight before forming the Gram.
    let mut weighted = operator.clone();
    for (k, &w) in diag.iter().enumerate() {
        weighted.row_mut(k).mapv_inplace(|v| w * v);
    }
    let gram = operator.t().dot(&weighted);
    let gram_t = gram.t().to_owned();
    (&gram + &gram_t) * 0.5
}
/// Assembles the gradient of a grouped (vector-valued) operator penalty.
///
/// `d1` stacks one `dimension`-row operator block per group; the result is
/// `sum_k G_k^T b_k`, where `G_k` is block k of `d1` and `b_k` is row k of
/// `blocks`.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the block width or the
/// stacked row count of `d1` does not match `blocks`.
fn grouped_operatorgradient(
    d1: &Array2<f64>,
    dimension: usize,
    blocks: &Array2<f64>,
) -> Result<Array1<f64>, EstimationError> {
    if blocks.ncols() != dimension {
        return Err(EstimationError::InvalidInput(format!(
            "grouped gradient block dimension mismatch: got {}, expected {dimension}",
            blocks.ncols()
        )));
    }
    if d1.nrows() != blocks.nrows() * dimension {
        return Err(EstimationError::InvalidInput(format!(
            "grouped gradient row mismatch: D1 has {} rows, blocks imply {}",
            d1.nrows(),
            blocks.nrows() * dimension
        )));
    }
    let mut out = Array1::<f64>::zeros(d1.ncols());
    for k in 0..blocks.nrows() {
        // Borrow the k-th operator block and block row as views; ndarray's
        // generic `Dot` impls accept views, so no owned copies are needed.
        let gk = d1.slice(s![k * dimension..(k + 1) * dimension, ..]);
        out += &gk.t().dot(&blocks.row(k));
    }
    Ok(out)
}
/// Assembles the Hessian of a grouped operator penalty:
/// 0.5 * (H + H^T) with H = sum_k G_k^T B_k G_k, where `G_k` is block k of
/// `d1` and `B_k` the matching `dimension` x `dimension` block Hessian.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the stacked row count of `d1`
/// or any block's shape disagrees with `dimension`.
fn grouped_operatorhessian(
    d1: &Array2<f64>,
    dimension: usize,
    blocks: &[Array2<f64>],
) -> Result<Array2<f64>, EstimationError> {
    if d1.nrows() != blocks.len() * dimension {
        return Err(EstimationError::InvalidInput(format!(
            "grouped Hessian row mismatch: D1 has {} rows, blocks imply {}",
            d1.nrows(),
            blocks.len() * dimension
        )));
    }
    let p = d1.ncols();
    let mut out = Array2::<f64>::zeros((p, p));
    for (k, block) in blocks.iter().enumerate() {
        if block.nrows() != dimension || block.ncols() != dimension {
            return Err(EstimationError::InvalidInput(format!(
                "grouped Hessian block {k} has shape {}x{}, expected {}x{}",
                block.nrows(),
                block.ncols(),
                dimension,
                dimension
            )));
        }
        // Work directly on views: ndarray's `Dot` impls are generic over
        // storage, so the previous `.to_owned()` copy per block is unnecessary.
        let gk = d1.slice(s![k * dimension..(k + 1) * dimension, ..]);
        out += &gk.t().dot(&block.dot(&gk));
    }
    // Symmetrize against floating-point drift; `&out + &out.t()` adds the
    // transposed view without materializing an owned transpose first.
    Ok((&out + &out.t()) * 0.5)
}
/// Exact Charbonnier penalty state for one spatial term, split across the
/// three collocation operators (value, gradient, curvature).
#[derive(Clone)]
struct SpatialPenaltyExactState {
    // Scalar (D0) magnitude penalty state.
    magnitude: CharbonnierScalarBlockState,
    // Grouped (D1) gradient penalty state; one block per collocation point.
    gradient: CharbonnierGroupedBlockState,
    // Grouped (D2) curvature penalty state; one block per collocation point.
    curvature: CharbonnierGroupedBlockState,
}
/// Reshapes a flat collocation-gradient vector (row-major, `dimension`
/// entries per point) into a `(points, dimension)` matrix.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when `dimension` is zero or the
/// vector length is not a multiple of `dimension`.
fn collocationgradient_blocks(
    gradrows: &Array1<f64>,
    dimension: usize,
) -> Result<Array2<f64>, EstimationError> {
    if dimension == 0 || gradrows.len() % dimension != 0 {
        return Err(EstimationError::InvalidInput(format!(
            "invalid collocation gradient layout: rows={}, dimension={dimension}",
            gradrows.len()
        )));
    }
    let points = gradrows.len() / dimension;
    let mut out = Array2::<f64>::zeros((points, dimension));
    for (flat_idx, &value) in gradrows.iter().enumerate() {
        out[[flat_idx / dimension, flat_idx % dimension]] = value;
    }
    Ok(out)
}
/// Reshapes a flat collocation-Hessian vector into a `(points, dimension^2)`
/// matrix, one flattened `dimension` x `dimension` block per row.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` on `dimension^2` overflow, a zero
/// block size, or a length that is not a multiple of the block size.
fn collocationhessian_blocks(
    hessianrows: &Array1<f64>,
    dimension: usize,
) -> Result<Array2<f64>, EstimationError> {
    let block_dim = dimension.checked_mul(dimension).ok_or_else(|| {
        EstimationError::InvalidInput("invalid collocation Hessian dimension overflow".to_string())
    })?;
    if block_dim == 0 || hessianrows.len() % block_dim != 0 {
        return Err(EstimationError::InvalidInput(format!(
            "invalid collocation Hessian layout: rows={}, dimension={dimension}",
            hessianrows.len()
        )));
    }
    let points = hessianrows.len() / block_dim;
    let mut out = Array2::<f64>::zeros((points, block_dim));
    for (flat_idx, &value) in hessianrows.iter().enumerate() {
        out[[flat_idx / block_dim, flat_idx % block_dim]] = value;
    }
    Ok(out)
}
impl SpatialPenaltyExactState {
    /// Builds the exact penalty state for one term from its local coefficient
    /// slice, applying the D0/D1/D2 collocation operators from `cache` with
    /// the per-operator epsilons `[magnitude, gradient, curvature]`.
    fn from_beta_local(
        beta_local: ArrayView1<'_, f64>,
        cache: &SpatialOperatorRuntimeCache,
        epsilons: [f64; 3],
    ) -> Result<Self, EstimationError> {
        let magnitude =
            CharbonnierScalarBlockState::from_signal(cache.d0.dot(&beta_local), epsilons[0]);
        let gradient_blocks =
            collocationgradient_blocks(&cache.d1.dot(&beta_local), cache.dimension)?;
        let curvature_blocks =
            collocationhessian_blocks(&cache.d2.dot(&beta_local), cache.dimension)?;
        Ok(Self {
            magnitude,
            gradient: CharbonnierGroupedBlockState::from_signal_blocks(
                gradient_blocks,
                epsilons[1],
            ),
            curvature: CharbonnierGroupedBlockState::from_signal_blocks(
                curvature_blocks,
                epsilons[2],
            ),
        })
    }
    /// Absolute operator magnitudes at every collocation point:
    /// (|D0 beta|, per-point gradient norms, per-point curvature norms).
    fn absolute_collocation_magnitudes(&self) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
        let mag = self.magnitude.absolute_signal();
        let grad = self.gradient.norm_signal();
        let curv = self.curvature.norm_signal();
        (mag, grad, curv)
    }
}
/// Linear-interpolation quantile of an ascending-sorted slice.
///
/// `q` is clamped to [0, 1]; an empty slice yields 0.0. Fractional positions
/// interpolate linearly between the two neighboring order statistics.
fn quantile_from_sorted(sorted: &[f64], q: f64) -> f64 {
    if sorted.is_empty() {
        return 0.0;
    }
    let last_index = sorted.len().saturating_sub(1);
    let pos = q.clamp(0.0, 1.0) * last_index as f64;
    let lower = pos.floor() as usize;
    let upper = pos.ceil() as usize;
    if lower == upper {
        return sorted[lower];
    }
    let frac = pos - lower as f64;
    sorted[lower] * (1.0 - frac) + sorted[upper] * frac
}
/// Derives a robust Charbonnier epsilon from sampled operator magnitudes.
///
/// The scale is the largest of the median, the MAD-based scale (1.4826 * MAD),
/// and the upper quartile of the finite non-negative samples; if that
/// collapses below a tolerance, it falls back to max(q95, RMS), and finally
/// to the configured minimum (floored at 1e-12).
fn robust_epsilon_from_samples(values: &[f64], min_epsilon_cfg: f64) -> f64 {
    let s_min = min_epsilon_cfg.max(1e-12);
    let mut clean: Vec<f64> = values
        .iter()
        .copied()
        .filter(|v| v.is_finite() && *v >= 0.0)
        .collect();
    // Covers both an empty input and an input with no usable samples.
    if clean.is_empty() {
        return s_min;
    }
    clean.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let median = quantile_from_sorted(&clean, 0.5);
    let q75 = quantile_from_sorted(&clean, 0.75);
    let q95 = quantile_from_sorted(&clean, 0.95);
    let mut abs_dev: Vec<f64> = clean
        .iter()
        .map(|v| (v - median).abs())
        .filter(|v| v.is_finite())
        .collect();
    abs_dev.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let mad = 1.4826 * quantile_from_sorted(&abs_dev, 0.5);
    let mut scale = median.max(mad).max(q75);
    // Tolerance below which the robust scale is considered degenerate.
    let delta = (f64::EPSILON.sqrt() * q95.max(1.0))
        .max(min_epsilon_cfg)
        .max(1e-12);
    if scale <= delta {
        // Degenerate robust scale: fall back to the larger of q95 and RMS.
        let mean_sq = clean.iter().map(|v| v * v).sum::<f64>() / clean.len() as f64;
        scale = q95.max(mean_sq.sqrt());
    }
    if scale <= delta {
        scale = s_min;
    }
    scale.max(s_min)
}
/// Collects runtime collocation-operator caches (D0/D1/D2) for every smooth
/// term that carries the full mass/tension/stiffness operator-penalty triple.
///
/// Terms missing any of the three active operator penalties, or whose basis
/// is neither Matern nor Duchon, are skipped. Operator matrices are divided
/// by the square roots of the matching penalty normalization scales so that
/// quadratic forms in the operator rows line up with the normalized
/// penalties.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when operator construction fails
/// or an operator's column count disagrees with a term's coefficient range.
fn extract_spatial_operator_runtime_caches(
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
) -> Result<Vec<SpatialOperatorRuntimeCache>, EstimationError> {
    // Smooth-term coefficients occupy the trailing columns of the design;
    // this is their starting column.
    let smooth_start = design
        .design
        .ncols()
        .saturating_sub(design.smooth.total_smooth_cols());
    let mut out = Vec::<SpatialOperatorRuntimeCache>::new();
    for (term_idx, (termspec, term_fit)) in spec
        .smooth_terms
        .iter()
        .zip(design.smooth.terms.iter())
        .enumerate()
    {
        // Base index of this term's penalties in the global penalty list;
        // terms without one contribute no operator cache.
        let Some(global_base_idx) = smooth_term_penalty_index(spec, design, term_idx) else {
            continue;
        };
        // Locate each operator penalty's position among the ACTIVE local
        // penalties (inactive entries do not consume a global slot) and
        // remember its normalization scale.
        let mut active_local_idx = 0usize;
        let mut mass_local_idx = None;
        let mut tension_local_idx = None;
        let mut stiffness_local_idx = None;
        let mut mass_norm = None;
        let mut tension_norm = None;
        let mut stiffness_norm = None;
        for info in &term_fit.penaltyinfo_local {
            if !info.active {
                continue;
            }
            match info.source {
                PenaltySource::OperatorMass => {
                    mass_local_idx = Some(active_local_idx);
                    mass_norm = Some(info.normalization_scale);
                }
                PenaltySource::OperatorTension => {
                    tension_local_idx = Some(active_local_idx);
                    tension_norm = Some(info.normalization_scale);
                }
                PenaltySource::OperatorStiffness => {
                    stiffness_local_idx = Some(active_local_idx);
                    stiffness_norm = Some(info.normalization_scale);
                }
                _ => {}
            }
            active_local_idx += 1;
        }
        // Require the full mass/tension/stiffness triple; otherwise this
        // term is not operator-penalized and is skipped.
        let (
            Some(mass_local),
            Some(tension_local),
            Some(stiffness_local),
            Some(mass_scale),
            Some(tension_scale),
            Some(stiffness_scale),
        ) = (
            mass_local_idx,
            tension_local_idx,
            stiffness_local_idx,
            mass_norm,
            tension_norm,
            stiffness_norm,
        )
        else {
            continue;
        };
        let mass_global_idx = global_base_idx + mass_local;
        let tension_global_idx = global_base_idx + tension_local;
        let stiffness_global_idx = global_base_idx + stiffness_local;
        // Build the collocation operator matrices for the supported bases,
        // mirroring the basis parameters recorded in the fit metadata.
        let (feature_cols, mut d0, mut d1, mut d2, collocation_points, dim) =
            match (&termspec.basis, &term_fit.metadata) {
                (
                    SmoothBasisSpec::Matern { feature_cols, .. },
                    BasisMetadata::Matern {
                        centers,
                        length_scale,
                        nu,
                        include_intercept,
                        identifiability_transform,
                        aniso_log_scales,
                        ..
                    },
                ) => {
                    let ops = build_matern_collocation_operator_matrices(
                        centers.view(),
                        None,
                        *length_scale,
                        *nu,
                        *include_intercept,
                        identifiability_transform.as_ref().map(|z| z.view()),
                        aniso_log_scales.as_deref(),
                    )?;
                    (
                        feature_cols.clone(),
                        ops.d0,
                        ops.d1,
                        ops.d2,
                        ops.collocation_points,
                        centers.ncols(),
                    )
                }
                (
                    SmoothBasisSpec::Duchon { feature_cols, .. },
                    BasisMetadata::Duchon {
                        centers,
                        length_scale,
                        power,
                        nullspace_order,
                        identifiability_transform,
                        aniso_log_scales,
                        ..
                    },
                ) => {
                    let mut ws = crate::basis::BasisWorkspace::default();
                    let ops =
                        crate::basis::build_duchon_collocation_operator_matriceswithworkspace(
                            centers.view(),
                            None,
                            *length_scale,
                            *power,
                            *nullspace_order,
                            aniso_log_scales.as_deref(),
                            identifiability_transform.as_ref().map(|z| z.view()),
                            2,
                            &mut ws,
                        )?;
                    (
                        feature_cols.clone(),
                        ops.d0,
                        ops.d1,
                        ops.d2,
                        ops.collocation_points,
                        centers.ncols(),
                    )
                }
                // Other basis/metadata combinations carry no operator cache.
                _ => continue,
            };
        // Rescale operators by sqrt of each penalty's normalization scale
        // (sqrt because the penalties are quadratic in the operator rows);
        // the 1e-12 floor guards against a zero or negative recorded scale.
        let mass_scale = mass_scale.max(1e-12).sqrt();
        let tension_scale = tension_scale.max(1e-12).sqrt();
        let stiffness_scale = stiffness_scale.max(1e-12).sqrt();
        d0.mapv_inplace(|v| v / mass_scale);
        d1.mapv_inplace(|v| v / tension_scale);
        d2.mapv_inplace(|v| v / stiffness_scale);
        // Translate the term-local coefficient range into global design
        // columns and verify every operator matches it.
        let coeff_global_range =
            (smooth_start + term_fit.coeff_range.start)..(smooth_start + term_fit.coeff_range.end);
        if d0.ncols() != coeff_global_range.len()
            || d1.ncols() != coeff_global_range.len()
            || d2.ncols() != coeff_global_range.len()
        {
            return Err(EstimationError::InvalidInput(format!(
                "spatial operator dimension mismatch for term '{}': D0 cols={}, D1 cols={}, D2 cols={}, coeffs={}",
                term_fit.name,
                d0.ncols(),
                d1.ncols(),
                d2.ncols(),
                coeff_global_range.len()
            )));
        }
        out.push(SpatialOperatorRuntimeCache {
            termname: term_fit.name.clone(),
            feature_cols,
            coeff_global_range,
            mass_penalty_global_idx: mass_global_idx,
            tension_penalty_global_idx: tension_global_idx,
            stiffness_penalty_global_idx: stiffness_global_idx,
            d0,
            d1,
            d2,
            collocation_points,
            dimension: dim,
        });
    }
    Ok(out)
}
/// Computes surrogate adaptive penalty weights for every cached spatial term
/// at the given coefficient vector.
///
/// For each cache, the exact Charbonnier state is built from the term's
/// local coefficient slice, and clamped 1/r surrogate weights are derived
/// for the magnitude, gradient, and curvature operators. The forward
/// weights are retained only in test builds (for inspection); production
/// code consumes just the inverse weights.
fn compute_spatial_adaptiveweights_for_beta(
    beta: &Array1<f64>,
    caches: &[SpatialOperatorRuntimeCache],
    epsilon_0: f64,
    epsilon_g: f64,
    epsilon_c: f64,
    weight_floor: f64,
    weight_ceiling: f64,
) -> Result<Vec<SpatialAdaptiveWeights>, EstimationError> {
    caches
        .iter()
        .map(|cache| {
            let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
            let exact = SpatialPenaltyExactState::from_beta_local(
                beta_local,
                cache,
                [epsilon_0, epsilon_g, epsilon_c],
            )?;
            // Magnitude (D0) weights; the test/not-test pairs below keep the
            // forward weight binding only when it is actually used.
            #[cfg(test)]
            let (u_0, inv_0) = exact
                .magnitude
                .surrogateweights(weight_floor, weight_ceiling);
            #[cfg(not(test))]
            let (_, inv_0) = exact
                .magnitude
                .surrogateweights(weight_floor, weight_ceiling);
            // Gradient (D1) weights.
            #[cfg(test)]
            let (u_g, inv_g) = exact
                .gradient
                .surrogateweights(weight_floor, weight_ceiling);
            #[cfg(not(test))]
            let (_, inv_g) = exact
                .gradient
                .surrogateweights(weight_floor, weight_ceiling);
            // Curvature (D2) weights.
            #[cfg(test)]
            let (u_c, inv_c) = exact
                .curvature
                .surrogateweights(weight_floor, weight_ceiling);
            #[cfg(not(test))]
            let (_, inv_c) = exact
                .curvature
                .surrogateweights(weight_floor, weight_ceiling);
            Ok(SpatialAdaptiveWeights {
                #[cfg(test)]
                magweight: u_0,
                #[cfg(test)]
                gradweight: u_g,
                #[cfg(test)]
                lapweight: u_c,
                inv_magweight: inv_0,
                invgradweight: inv_g,
                inv_lapweight: inv_c,
            })
        })
        .collect()
}
/// Estimates initial Charbonnier epsilons for the magnitude, gradient, and
/// curvature penalties from the baseline coefficients, by pooling operator
/// magnitudes over all cached terms and taking a robust scale of each pool.
fn compute_initial_epsilons(
    beta: &Array1<f64>,
    caches: &[SpatialOperatorRuntimeCache],
    min_epsilon: f64,
) -> Result<(f64, f64, f64), EstimationError> {
    let mut magnitude_samples = Vec::<f64>::new();
    let mut gradient_samples = Vec::<f64>::new();
    let mut curvature_samples = Vec::<f64>::new();
    for cache in caches {
        let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
        // Use the minimum epsilon here; only the absolute magnitudes are
        // read, not the smoothed penalties.
        let exact =
            SpatialPenaltyExactState::from_beta_local(beta_local, cache, [min_epsilon; 3])?;
        let (mag, grad, curv) = exact.absolute_collocation_magnitudes();
        magnitude_samples.extend(mag.iter().copied());
        gradient_samples.extend(grad.iter().copied());
        curvature_samples.extend(curv.iter().copied());
    }
    Ok((
        robust_epsilon_from_samples(&magnitude_samples, min_epsilon),
        robust_epsilon_from_samples(&gradient_samples, min_epsilon),
        robust_epsilon_from_samples(&curvature_samples, min_epsilon),
    ))
}
/// Set of global penalty indices owned by the exact spatial adaptive
/// machinery (mass, tension, and stiffness slots of every cache).
fn exact_spatial_adaptive_penalty_index_set(
    caches: &[SpatialOperatorRuntimeCache],
) -> BTreeSet<usize> {
    caches
        .iter()
        .flat_map(|cache| {
            [
                cache.mass_penalty_global_idx,
                cache.tension_penalty_global_idx,
                cache.stiffness_penalty_global_idx,
            ]
        })
        .collect()
}
/// Lays out the adaptive hyperparameter vector: three log-lambda slots per
/// cache (magnitude, gradient, curvature) followed by three shared
/// log-epsilon slots (attributed to cache 0).
fn build_spatial_adaptive_hyperspecs(cache_count: usize) -> Vec<SpatialAdaptiveHyperSpec> {
    let mut specs = Vec::with_capacity(cache_count * 3 + 3);
    for cache_index in 0..cache_count {
        specs.extend([
            SpatialAdaptiveHyperSpec {
                cache_index,
                kind: SpatialAdaptiveHyperKind::LogLambdaMagnitude,
            },
            SpatialAdaptiveHyperSpec {
                cache_index,
                kind: SpatialAdaptiveHyperKind::LogLambdaGradient,
            },
            SpatialAdaptiveHyperSpec {
                cache_index,
                kind: SpatialAdaptiveHyperKind::LogLambdaCurvature,
            },
        ]);
    }
    specs.extend([
        SpatialAdaptiveHyperSpec {
            cache_index: 0,
            kind: SpatialAdaptiveHyperKind::LogEpsilonMagnitude,
        },
        SpatialAdaptiveHyperSpec {
            cache_index: 0,
            kind: SpatialAdaptiveHyperKind::LogEpsilonGradient,
        },
        SpatialAdaptiveHyperSpec {
            cache_index: 0,
            kind: SpatialAdaptiveHyperKind::LogEpsilonCurvature,
        },
    ]);
    specs
}
/// Embeds a local penalty block into an otherwise-zero
/// `total_dim` x `total_dim` matrix at `coeff_range` on both axes.
fn penalty_matrixwith_local_block(
    total_dim: usize,
    coeff_range: Range<usize>,
    local: &Array2<f64>,
) -> Array2<f64> {
    let mut global = Array2::<f64>::zeros((total_dim, total_dim));
    let mut target = global.slice_mut(s![coeff_range.clone(), coeff_range]);
    target.assign(local);
    global
}
fn fit_term_collectionwith_exact_spatial_adaptive_regularization(
baseline: FittedTermCollection,
y: ArrayView1<'_, f64>,
weights: ArrayView1<'_, f64>,
offset: ArrayView1<'_, f64>,
family: LikelihoodFamily,
options: &FitOptions,
runtime_caches: &[SpatialOperatorRuntimeCache],
) -> Result<FittedTermCollection, EstimationError> {
let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
let adaptive_penalty_indices = exact_spatial_adaptive_penalty_index_set(runtime_caches);
let p_total = baseline.design.design.ncols();
struct RetainedPenaltySetup {
global_idx: usize,
global_penalty: Array2<f64>,
nullspace_dim: usize,
log_lambda: f64,
col_range: Range<usize>,
hessian_piece: Array2<f64>,
}
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
let retained_setups = baseline
.design
.penalties
.par_iter()
.enumerate()
.map(|(idx, bp)| {
if adaptive_penalty_indices.contains(&idx) {
return None;
}
let lambda = baseline.fit.lambdas[idx];
Some(RetainedPenaltySetup {
global_idx: idx,
global_penalty: bp.to_global(p_total),
nullspace_dim: baseline
.design
.nullspace_dims
.get(idx)
.copied()
.unwrap_or(0),
log_lambda: lambda.max(1e-12).ln(),
col_range: bp.col_range.clone(),
hessian_piece: bp.local.mapv(|v| lambda * v),
})
})
.collect::<Vec<_>>();
let retained_count = retained_setups
.iter()
.filter(|setup| setup.is_some())
.count();
let mut retained_penalties = Vec::<Array2<f64>>::with_capacity(retained_count);
let mut retained_nullspace_dims = Vec::<usize>::with_capacity(retained_count);
let mut retained_log_lambdas = Vec::<f64>::with_capacity(retained_count);
let mut retained_global_indices = Vec::<usize>::with_capacity(retained_count);
let mut fixed_quadratichessian = Array2::<f64>::zeros((p_total, p_total));
for setup in retained_setups.into_iter().flatten() {
retained_penalties.push(setup.global_penalty);
retained_nullspace_dims.push(setup.nullspace_dim);
retained_log_lambdas.push(setup.log_lambda);
retained_global_indices.push(setup.global_idx);
fixed_quadratichessian
.slice_mut(s![setup.col_range.clone(), setup.col_range])
.scaled_add(1.0, &setup.hessian_piece);
}
let (eps_0_init, eps_g_init, eps_c_init) = compute_initial_epsilons(
&baseline.fit.beta,
runtime_caches,
adaptive_opts.min_epsilon,
)?;
let mut initial_theta =
Array1::<f64>::zeros(retained_penalties.len() + runtime_caches.len() * 3 + 3);
for (idx, value) in retained_log_lambdas.iter().enumerate() {
initial_theta[idx] = *value;
}
let adaptive_log_lambda_components = runtime_caches
.par_iter()
.map(|cache| {
[
baseline.fit.lambdas[cache.mass_penalty_global_idx]
.max(1e-12)
.ln(),
baseline.fit.lambdas[cache.tension_penalty_global_idx]
.max(1e-12)
.ln(),
baseline.fit.lambdas[cache.stiffness_penalty_global_idx]
.max(1e-12)
.ln(),
]
})
.collect::<Vec<_>>();
let mut at = retained_penalties.len();
for logs in &adaptive_log_lambda_components {
initial_theta[at] = logs[0];
initial_theta[at + 1] = logs[1];
initial_theta[at + 2] = logs[2];
at += 3;
}
initial_theta[at] = eps_0_init.max(adaptive_opts.min_epsilon).ln();
initial_theta[at + 1] = eps_g_init.max(adaptive_opts.min_epsilon).ln();
initial_theta[at + 2] = eps_c_init.max(adaptive_opts.min_epsilon).ln();
let hyperspecs = build_spatial_adaptive_hyperspecs(runtime_caches.len());
let zero_psi_op: std::sync::Arc<dyn crate::custom_family::CustomFamilyPsiDerivativeOperator> =
std::sync::Arc::new(crate::custom_family::ZeroPsiDerivativeOperator::new(
baseline.design.design.nrows(),
baseline.design.design.ncols(),
));
let derivative_blocks = vec![
hyperspecs
.par_iter()
.map(|_| CustomFamilyBlockPsiDerivative {
penalty_index: None,
x_psi: Array2::<f64>::zeros((0, 0)),
s_psi: Array2::<f64>::zeros((0, 0)),
s_psi_components: None,
s_psi_penalty_components: None,
x_psi_psi: None,
s_psi_psi: None,
s_psi_psi_components: None,
s_psi_psi_penalty_components: None,
implicit_operator: Some(std::sync::Arc::clone(&zero_psi_op)),
implicit_axis: 0,
implicit_group_id: None,
})
.collect::<Vec<_>>(),
];
let mixture_link_state = options
.mixture_link
.clone()
.as_ref()
.map(state_fromspec)
.transpose()
.map_err(EstimationError::InvalidInput)?;
let sas_link_state = options
.sas_link
.map(|spec| {
if matches!(family, LikelihoodFamily::BinomialBetaLogistic) {
state_from_beta_logisticspec(spec)
} else {
state_from_sasspec(spec)
}
})
.transpose()
.map_err(EstimationError::InvalidInput)?;
let latent_cloglog_state = options.latent_cloglog;
let shared_y = Arc::new(y.to_owned());
let sharedweights = Arc::new(weights.to_owned());
let shared_design = baseline
.design
.design
.try_to_dense_arc("spatial adaptive exact hyperfit design")
.map_err(EstimationError::InvalidInput)?;
let shared_offset = Arc::new(offset.to_owned());
let shared_runtime_caches = Arc::new(runtime_caches.to_vec());
let shared_hyperspecs = Arc::new(hyperspecs.clone());
let zero_quadratic = Arc::new(Array2::<f64>::zeros((
baseline.design.design.ncols(),
baseline.design.design.ncols(),
)));
let base_family = SpatialAdaptiveExactFamily {
family,
latent_cloglog_state,
mixture_link_state: mixture_link_state.clone(),
sas_link_state,
y: shared_y.clone(),
weights: sharedweights.clone(),
design: shared_design.clone(),
offset: shared_offset.clone(),
linear_constraints: baseline.design.linear_constraints.clone(),
runtime_caches: shared_runtime_caches.clone(),
adaptive_params: Vec::new(),
fixed_quadratichessian: zero_quadratic.clone(),
hyperspecs: shared_hyperspecs.clone(),
exact_eval_cache: Arc::new(Mutex::new(None)),
};
let rho_dim = retained_penalties.len();
let operator_slots_end = rho_dim + runtime_caches.len() * 3;
const UNIFIED_LOG_WINDOW: f64 = 6.0;
const RETAINED_LAMBDA_LOG_LOWER_FLOOR: f64 = -30.0;
const RETAINED_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
const OPERATOR_LAMBDA_LOG_LOWER_FLOOR: f64 = -10.0;
const OPERATOR_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
let epsilon_floor_log = adaptive_opts.min_epsilon.max(1e-12).ln();
let anchored_bound = |idx: usize, sign: f64| -> f64 {
let raw = initial_theta[idx] + sign * UNIFIED_LOG_WINDOW;
if idx < rho_dim {
raw.clamp(
RETAINED_LAMBDA_LOG_LOWER_FLOOR,
RETAINED_LAMBDA_LOG_UPPER_CAP,
)
} else if idx < operator_slots_end {
raw.clamp(
OPERATOR_LAMBDA_LOG_LOWER_FLOOR,
OPERATOR_LAMBDA_LOG_UPPER_CAP,
)
} else {
raw.max(epsilon_floor_log)
}
};
let eps_lower =
Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, -1.0)));
let eps_upper = Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, 1.0)));
let blockspec = ParameterBlockSpec {
name: "eta".to_string(),
design: baseline.design.design.clone(),
offset: offset.to_owned(),
penalties: retained_penalties
.iter()
.cloned()
.map(PenaltyMatrix::Dense)
.collect(),
nullspace_dims: retained_nullspace_dims.clone(),
initial_log_lambdas: Array1::from_vec(retained_log_lambdas.clone()),
initial_beta: Some(baseline.fit.beta.clone()),
};
let screening_cap = Arc::new(AtomicUsize::new(0));
let outer_opts = BlockwiseFitOptions {
inner_max_cycles: options.max_iter,
inner_tol: options.tol,
outer_max_iter: options.max_iter,
outer_tol: options.tol,
compute_covariance: false,
screening_max_inner_iterations: Some(Arc::clone(&screening_cap)),
..BlockwiseFitOptions::default()
};
use crate::solver::outer_strategy::{
DeclaredHessianForm, Derivative, HessianResult, OuterEval, OuterProblem,
};
struct SpatialAdaptiveOuterState {
warm_cache: Option<CustomFamilyWarmStart>,
last_eval: Option<(
Array1<f64>,
f64,
Array1<f64>,
HessianResult,
CustomFamilyWarmStart,
)>,
}
let n_theta = initial_theta.len();
let theta_bounds = Some((eps_lower.clone(), eps_upper.clone()));
let clamp_theta = {
let lo = eps_lower;
let hi = eps_upper;
move |theta: &Array1<f64>| -> Array1<f64> {
let mut clamped = theta.clone();
for i in 0..clamped.len() {
clamped[i] = clamped[i].clamp(lo[i], hi[i]);
}
clamped
}
};
let decode_theta = |theta: &Array1<f64>| -> (Array1<f64>, Vec<SpatialAdaptiveTermHyperParams>) {
let rho = theta.slice(s![..rho_dim]).to_owned();
let adaptive_lambda_start = rho_dim;
let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
let eps = [
theta[adaptive_lambda_end].exp(),
theta[adaptive_lambda_end + 1].exp(),
theta[adaptive_lambda_end + 2].exp(),
];
let adaptive_params = runtime_caches
.iter()
.enumerate()
.map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
lambda: [
theta[adaptive_lambda_start + cache_idx * 3].exp(),
theta[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
theta[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
],
epsilon: eps,
})
.collect::<Vec<_>>();
(rho, adaptive_params)
};
let analytic_outer_hessian_available =
crate::custom_family::joint_exact_analytic_outer_hessian_available()
&& base_family
.exact_outer_derivative_order(std::slice::from_ref(&blockspec), &outer_opts)
.has_hessian()
&& crate::custom_family::exact_newton_outer_geometry_supports_second_order_solver(
&base_family,
);
let outer_max_iter = crate::custom_family::cost_gated_first_order_max_iter(
options.max_iter,
base_family.coefficient_gradient_cost(std::slice::from_ref(&blockspec)),
analytic_outer_hessian_available,
);
if outer_max_iter < options.max_iter {
log::info!(
"[OUTER] exact spatial adaptive regularization: first-order work gate reduced outer_max_iter {} -> {}",
options.max_iter,
outer_max_iter,
);
}
let problem = OuterProblem::new(n_theta)
.with_gradient(Derivative::Analytic)
.with_hessian(if analytic_outer_hessian_available {
DeclaredHessianForm::Either
} else {
DeclaredHessianForm::Unavailable
})
.with_fallback_policy(crate::solver::outer_strategy::FallbackPolicy::Disabled)
.with_psi_dim(n_theta.saturating_sub(rho_dim))
.with_tolerance(options.tol)
.with_max_iter(outer_max_iter)
.with_seed_config(crate::seeding::SeedConfig::default())
.with_screening_cap(Arc::clone(&screening_cap))
.with_initial_rho(initial_theta.clone());
let problem = if let Some((lo, hi)) = theta_bounds {
problem.with_bounds(lo, hi)
} else {
problem
};
let eval_outer = |st: &mut SpatialAdaptiveOuterState,
theta: &Array1<f64>,
order: crate::solver::outer_strategy::OuterEvalOrder|
-> Result<OuterEval, EstimationError> {
let theta = clamp_theta(theta);
if let Some((cached_theta, cached_cost, cached_grad, cached_hess, cached_warm)) =
&st.last_eval
&& cached_theta.len() == theta.len()
&& cached_theta
.iter()
.zip(theta.iter())
.all(|(&a, &b)| (a - b).abs() <= 1e-12)
&& (!matches!(
order,
crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian
) || analytic_outer_hessian_available)
{
st.warm_cache = Some(cached_warm.clone());
return Ok(OuterEval {
cost: *cached_cost,
gradient: cached_grad.clone(),
hessian: if matches!(
order,
crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian
) && analytic_outer_hessian_available
{
cached_hess.clone()
} else {
HessianResult::Unavailable
},
});
}
let (rho, adaptive_params) = decode_theta(&theta);
let family_eval = base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
let need_hessian = matches!(
order,
crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian
) && analytic_outer_hessian_available;
let result = evaluate_custom_family_joint_hyper(
&family_eval,
std::slice::from_ref(&blockspec),
&outer_opts,
&rho,
&derivative_blocks,
st.warm_cache.as_ref(),
if need_hessian {
crate::solver::estimate::reml::unified::EvalMode::ValueGradientHessian
} else {
crate::solver::estimate::reml::unified::EvalMode::ValueAndGradient
},
)
.map_err(|e| {
EstimationError::RemlOptimizationFailed(format!("spatial adaptive eval failed: {e}"))
})?;
if !result.inner_converged {
st.warm_cache = Some(result.warm_start.clone());
return Err(EstimationError::RemlOptimizationFailed(
"exact spatial adaptive inner solve did not converge".to_string(),
));
}
if !result.objective.is_finite() || result.gradient.iter().any(|v| !v.is_finite()) {
return Err(EstimationError::RemlOptimizationFailed(
"exact spatial adaptive objective returned non-finite values".to_string(),
));
}
let hessian_result = if need_hessian {
if !result.outer_hessian.is_analytic() {
return Err(EstimationError::RemlOptimizationFailed(
"exact spatial adaptive objective did not return an exact outer Hessian"
.to_string(),
));
}
match result.outer_hessian.dim() {
Some(dim) if dim == theta.len() => {}
Some(dim) => {
return Err(EstimationError::RemlOptimizationFailed(format!(
"exact spatial adaptive outer Hessian dimension mismatch: got {dim}, expected {}",
theta.len(),
)));
}
None => {
return Err(EstimationError::RemlOptimizationFailed(
"exact spatial adaptive objective did not report an outer Hessian dimension"
.to_string(),
));
}
}
st.last_eval = Some((
theta.clone(),
result.objective,
result.gradient.clone(),
result.outer_hessian.clone(),
result.warm_start.clone(),
));
result.outer_hessian
} else {
HessianResult::Unavailable
};
st.warm_cache = Some(result.warm_start);
Ok(OuterEval {
cost: result.objective,
gradient: result.gradient,
hessian: hessian_result,
})
};
let mut obj = problem.build_objective_with_eval_order(
SpatialAdaptiveOuterState {
warm_cache: None,
last_eval: None,
},
|st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
let theta = clamp_theta(theta);
let (rho, adaptive_params) = decode_theta(&theta);
let family_eval =
base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
let result = evaluate_custom_family_joint_hyper(
&family_eval,
std::slice::from_ref(&blockspec),
&outer_opts,
&rho,
&derivative_blocks,
st.warm_cache.as_ref(),
crate::solver::estimate::reml::unified::EvalMode::ValueOnly,
)
.map_err(|e| {
EstimationError::RemlOptimizationFailed(format!(
"spatial adaptive cost eval failed: {e}"
))
})?;
if !result.inner_converged {
st.warm_cache = Some(result.warm_start);
return Err(EstimationError::RemlOptimizationFailed(
"exact spatial adaptive cost inner solve did not converge".to_string(),
));
}
st.warm_cache = Some(result.warm_start);
Ok(result.objective)
},
|st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
eval_outer(
st,
theta,
if analytic_outer_hessian_available {
crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian
} else {
crate::solver::outer_strategy::OuterEvalOrder::ValueAndGradient
},
)
},
|st: &mut SpatialAdaptiveOuterState,
theta: &Array1<f64>,
order: crate::solver::outer_strategy::OuterEvalOrder| {
eval_outer(st, theta, order)
},
Some(|st: &mut SpatialAdaptiveOuterState| {
st.warm_cache = None;
st.last_eval = None;
}),
Some(|st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
let theta = clamp_theta(theta);
let (rho, adaptive_params) = decode_theta(&theta);
let family_eval =
base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
let result = evaluate_custom_family_joint_hyper_efs(
&family_eval,
std::slice::from_ref(&blockspec),
&outer_opts,
&rho,
&derivative_blocks,
st.warm_cache.as_ref(),
)
.map_err(|e| {
EstimationError::RemlOptimizationFailed(format!(
"spatial adaptive EFS eval failed: {e}"
))
})?;
if !result.inner_converged {
st.warm_cache = Some(result.warm_start);
return Err(EstimationError::RemlOptimizationFailed(
"exact spatial adaptive EFS inner solve did not converge".to_string(),
));
}
st.warm_cache = Some(result.warm_start);
Ok(result.efs_eval)
}),
);
let outer_result = problem
.run(&mut obj, "exact spatial adaptive regularization")
.map_err(|e| {
EstimationError::InvalidInput(format!(
"exact spatial adaptive outer optimization failed: {e}"
))
})?;
let outer_iterations = outer_result.iterations;
let outer_grad_norm = outer_result.final_grad_norm;
let theta_star = outer_result.rho;
let rho_star = theta_star.slice(s![..rho_dim]).to_owned();
let adaptive_lambda_start = rho_dim;
let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
let eps_star = [
theta_star[adaptive_lambda_end].exp(),
theta_star[adaptive_lambda_end + 1].exp(),
theta_star[adaptive_lambda_end + 2].exp(),
];
let adaptive_params = runtime_caches
.iter()
.enumerate()
.map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
lambda: [
theta_star[adaptive_lambda_start + cache_idx * 3].exp(),
theta_star[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
theta_star[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
],
epsilon: eps_star,
})
.collect::<Vec<_>>();
let mut fixed_total = Array2::<f64>::zeros((
baseline.design.design.ncols(),
baseline.design.design.ncols(),
));
for (idx, penalty) in retained_penalties.iter().enumerate() {
fixed_total.scaled_add(rho_star[idx].exp(), penalty);
}
let final_family =
base_family.with_adaptive_params(adaptive_params.clone(), Arc::new(fixed_total.clone()));
let final_blockspec = ParameterBlockSpec {
name: "eta".to_string(),
design: baseline.design.design.clone(),
offset: offset.to_owned(),
penalties: vec![],
nullspace_dims: vec![],
initial_log_lambdas: Array1::zeros(0),
initial_beta: Some(baseline.fit.beta.clone()),
};
let final_fit = fit_custom_family(
&final_family,
&[final_blockspec],
&BlockwiseFitOptions {
inner_max_cycles: options.max_iter,
inner_tol: options.tol,
outer_max_iter: 1,
outer_tol: options.tol,
compute_covariance: true,
..BlockwiseFitOptions::default()
},
)
.map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
let beta = final_fit.block_states[0].beta.clone();
let final_eval = final_family
.exact_evaluation(&beta)
.map_err(EstimationError::InvalidInput)?;
let penalized_hessian = final_eval
.totalobjectivehessian(&final_family.design)
.map_err(EstimationError::InvalidInput)?;
let beta_covariance = final_fit.covariance_conditional.clone();
let beta_standard_errors = beta_covariance
.as_ref()
.map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));
let mut full_lambdas = baseline.fit.lambdas.clone();
for (idx, &global_idx) in retained_global_indices.iter().enumerate() {
full_lambdas[global_idx] = rho_star[idx].exp();
}
for (cache_idx, cache) in runtime_caches.iter().enumerate() {
full_lambdas[cache.mass_penalty_global_idx] = adaptive_params[cache_idx].lambda[0];
full_lambdas[cache.tension_penalty_global_idx] = adaptive_params[cache_idx].lambda[1];
full_lambdas[cache.stiffness_penalty_global_idx] = adaptive_params[cache_idx].lambda[2];
}
let deviance = match family {
LikelihoodFamily::GaussianIdentity => y
.iter()
.zip(final_eval.obs.mu.iter())
.zip(weights.iter())
.map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
.sum(),
_ => -2.0 * final_eval.obs.log_likelihood,
};
let mut local_penalty_blocks =
Vec::<PenaltySpec>::with_capacity(baseline.design.penalties.len());
for (global_idx, bp) in baseline.design.penalties.iter().enumerate() {
if adaptive_penalty_indices.contains(&global_idx) {
let cache = runtime_caches
.iter()
.find(|cache| {
cache.mass_penalty_global_idx == global_idx
|| cache.tension_penalty_global_idx == global_idx
|| cache.stiffness_penalty_global_idx == global_idx
})
.ok_or_else(|| {
EstimationError::InvalidInput(format!(
"missing runtime cache for adaptive penalty index {global_idx}"
))
})?;
let cache_idx = runtime_caches
.iter()
.position(|c| {
c.mass_penalty_global_idx == global_idx
|| c.tension_penalty_global_idx == global_idx
|| c.stiffness_penalty_global_idx == global_idx
})
.ok_or_else(|| {
EstimationError::InvalidInput(format!(
"missing adaptive cache position for penalty index {global_idx}"
))
})?;
let state = &final_eval.adaptive_states[cache_idx];
let local = if cache.mass_penalty_global_idx == global_idx {
scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag())
.mapv(|v| adaptive_params[cache_idx].lambda[0] * v)
} else if cache.tension_penalty_global_idx == global_idx {
grouped_operatorhessian(
&cache.d1,
cache.dimension,
&state.gradient.betahessian_blocks(),
)?
.mapv(|v| adaptive_params[cache_idx].lambda[1] * v)
} else {
grouped_operatorhessian(
&cache.d2,
cache.dimension * cache.dimension,
&state.curvature.betahessian_blocks(),
)?
.mapv(|v| adaptive_params[cache_idx].lambda[2] * v)
};
local_penalty_blocks.push(PenaltySpec::Dense(penalty_matrixwith_local_block(
baseline.design.design.ncols(),
cache.coeff_global_range.clone(),
&local,
)));
} else {
local_penalty_blocks.push(PenaltySpec::Dense(
bp.to_global(p_total).mapv(|v| v * full_lambdas[global_idx]),
));
}
}
let (edf_by_block, edf_total) = if let Some(cov) = beta_covariance.as_ref() {
exact_bounded_edf(
&local_penalty_blocks,
&Array1::from_elem(local_penalty_blocks.len(), 1.0),
cov,
)?
} else {
(vec![0.0; local_penalty_blocks.len()], 0.0)
};
let stable_penalty_term =
2.0 * final_eval.adaptive_penalty_value + beta.dot(&fixed_total.dot(&beta));
let standard_deviation = match family {
LikelihoodFamily::GaussianIdentity => {
let denom = (y.len() as f64 - edf_total).max(1.0);
(deviance / denom).sqrt()
}
_ => 1.0,
};
let maps = compute_spatial_adaptiveweights_for_beta(
&beta,
runtime_caches,
eps_star[0],
eps_star[1],
eps_star[2],
adaptive_opts.weight_floor,
adaptive_opts.weight_ceiling,
)?
.into_iter()
.zip(runtime_caches.iter())
.map(|(w, cache)| AdaptiveSpatialMap {
termname: cache.termname.clone(),
feature_cols: cache.feature_cols.clone(),
collocation_points: cache.collocation_points.clone(),
inv_magweight: w.inv_magweight,
invgradweight: w.invgradweight,
inv_lapweight: w.inv_lapweight,
})
.collect::<Vec<_>>();
let fitted_link = match family {
LikelihoodFamily::BinomialLatentCLogLog => FittedLinkState::LatentCLogLog {
state: latent_cloglog_state
.expect("BinomialLatentCLogLog requires an explicit latent-cloglog state"),
},
LikelihoodFamily::BinomialMixture => mixture_link_state
.clone()
.map(|state| FittedLinkState::Mixture {
state,
covariance: None,
})
.unwrap_or(FittedLinkState::Standard(None)),
LikelihoodFamily::BinomialSas => sas_link_state
.map(|state| FittedLinkState::Sas {
state,
covariance: None,
})
.unwrap_or(FittedLinkState::Standard(None)),
LikelihoodFamily::BinomialBetaLogistic => sas_link_state
.map(|state| FittedLinkState::BetaLogistic {
state,
covariance: None,
})
.unwrap_or(FittedLinkState::Standard(None)),
_ => FittedLinkState::Standard(None),
};
let max_abs_eta = final_eval
.obs
.eta
.iter()
.fold(0.0_f64, |acc, &v| acc.max(v.abs()));
let fitted = FittedTermCollection {
fit: {
let log_lambdas = full_lambdas.mapv(|v| v.max(1e-300).ln());
let inf = FitInference {
edf_by_block,
edf_total,
smoothing_correction: None,
penalized_hessian: penalized_hessian.clone(),
working_weights: final_eval.obs.fisherweight.clone(),
working_response: {
let mut out = final_eval.obs.eta.clone();
for i in 0..out.len() {
let wi = final_eval.obs.fisherweight[i].max(1e-12);
out[i] += final_eval.obs.score[i] / wi;
}
out
},
reparam_qs: None,
beta_covariance,
beta_standard_errors,
beta_covariance_corrected: None,
beta_standard_errors_corrected: None,
bias_correction_beta: None,
};
let geometry = Some(crate::estimate::FitGeometry {
penalized_hessian,
working_weights: inf.working_weights.clone(),
working_response: inf.working_response.clone(),
});
let covariance_conditional = inf.beta_covariance.clone();
let pirls_status_val = if final_fit.outer_converged {
crate::pirls::PirlsStatus::Converged
} else {
crate::pirls::PirlsStatus::StalledAtValidMinimum
};
UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
blocks: vec![crate::estimate::FittedBlock {
beta: beta.clone(),
role: crate::estimate::BlockRole::Mean,
edf: edf_total,
lambdas: full_lambdas.clone(),
}],
log_lambdas,
lambdas: full_lambdas,
likelihood_family: Some(family),
likelihood_scale: family.default_scale_metadata(),
log_likelihood_normalization:
crate::types::LogLikelihoodNormalization::UserProvided,
log_likelihood: final_eval.obs.log_likelihood,
deviance,
reml_score: final_fit.penalized_objective,
stable_penalty_term,
penalized_objective: final_fit.penalized_objective,
outer_iterations,
outer_converged: final_fit.outer_converged,
outer_gradient_norm: outer_grad_norm,
standard_deviation,
covariance_conditional,
covariance_corrected: None,
inference: Some(inf),
fitted_link,
geometry,
block_states: Vec::new(),
pirls_status: pirls_status_val,
max_abs_eta,
constraint_kkt: None,
artifacts: crate::estimate::FitArtifacts {
pirls: None,
..Default::default()
},
inner_cycles: 0,
})?
},
design: baseline.design,
adaptive_diagnostics: Some(AdaptiveRegularizationDiagnostics {
epsilon_0: eps_star[0],
epsilon_g: eps_star[1],
epsilon_c: eps_star[2],
epsilon_outer_iterations: outer_iterations,
mm_iterations: 0,
converged: final_fit.outer_converged,
maps,
}),
};
enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
Ok(fitted)
}
#[cfg(test)]
fn weighted_operator_gram_from_d1(
    d1: &Array2<f64>,
    weight: &Array1<f64>,
    dimension: usize,
) -> Array2<f64> {
    // Scale each collocation point's block of `dimension` rows by sqrt(weight)
    // so the Gram product below carries the weight linearly.
    let mut scaled = d1.clone();
    for (k, &wk) in weight.iter().enumerate() {
        let root = wk.sqrt();
        for axis in 0..dimension {
            scaled
                .row_mut(k * dimension + axis)
                .mapv_inplace(|v| v * root);
        }
    }
    // Symmetrize explicitly to wash out floating-point asymmetry.
    let gram = scaled.t().dot(&scaled);
    (&gram + &gram.t().to_owned()) * 0.5
}
#[cfg(test)]
fn weighted_operator_gram_from_d2(d2: &Array2<f64>, weight: &Array1<f64>) -> Array2<f64> {
    let mut scaled = d2.clone();
    // Rows per collocation point; panics unless D2's row count divides evenly.
    let block_dim = d2
        .nrows()
        .checked_div(weight.len().max(1))
        .filter(|block| *block > 0 && *block * weight.len() == d2.nrows())
        .expect("D2 row count must be an integer number of curvature rows per collocation point");
    for (k, &wk) in weight.iter().enumerate() {
        let root = wk.sqrt();
        for local in 0..block_dim {
            scaled
                .row_mut(k * block_dim + local)
                .mapv_inplace(|v| v * root);
        }
    }
    // Symmetrize explicitly to wash out floating-point asymmetry.
    let gram = scaled.t().dot(&scaled);
    (&gram + &gram.t().to_owned()) * 0.5
}
/// Builds the `FitOptions` used by adaptive refits: copies the caller's
/// options while wiring in the design's nullspace / constraint / Kronecker
/// structure and disabling nested adaptive regularization.
fn adaptive_fit_options_base(options: &FitOptions, design: &TermCollectionDesign) -> FitOptions {
    // First smooth term carrying a Kronecker factorization, if any.
    let kronecker_factored = design
        .smooth
        .terms
        .iter()
        .find_map(|t| t.kronecker_factored.clone());
    FitOptions {
        latent_cloglog: options.latent_cloglog,
        mixture_link: options.mixture_link.clone(),
        optimize_mixture: options.optimize_mixture,
        sas_link: options.sas_link,
        optimize_sas: options.optimize_sas,
        compute_inference: options.compute_inference,
        max_iter: options.max_iter,
        tol: options.tol,
        nullspace_dims: design.nullspace_dims.clone(),
        linear_constraints: design.linear_constraints.clone(),
        firth_bias_reduction: options.firth_bias_reduction,
        // Adaptive regularization is orchestrated by the caller, never nested.
        adaptive_regularization: None,
        penalty_shrinkage_floor: options.penalty_shrinkage_floor,
        rho_prior: options.rho_prior.clone(),
        kronecker_penalty_system: design.kronecker_penalty_system(),
        kronecker_factored,
    }
}
/// Metadata for one linear term whose coefficient is restricted to a bounded
/// interval (see the `bounded_latent_*` helpers below for the mapping).
#[derive(Clone)]
struct BoundedLinearTermMeta {
    // Design-matrix column occupied by this bounded term.
    col_idx: usize,
    // Lower bound of the allowed coefficient range.
    min: f64,
    // Upper bound of the allowed coefficient range.
    max: f64,
    // Prior placed on the coefficient's unit-interval coordinate.
    prior: BoundedCoefficientPriorSpec,
}
/// Observation model plus design data for fits containing bounded linear
/// terms.
#[derive(Clone)]
struct BoundedLinearFamily {
    // Likelihood family of the observations.
    family: LikelihoodFamily,
    // Optional latent-cloglog link state.
    latent_cloglog_state: Option<LatentCLogLogState>,
    // Optional mixture link state.
    mixture_link_state: Option<MixtureLinkState>,
    // Optional SAS / beta-logistic link state.
    sas_link_state: Option<SasLinkState>,
    // Response vector.
    y: Array1<f64>,
    // Observation weights.
    weights: Array1<f64>,
    // Full design matrix.
    design: Array2<f64>,
    // NOTE(review): presumably the design with bounded columns zeroed out —
    // confirm against the constructor, which is outside this view.
    designzeroed: Array2<f64>,
    // Linear predictor offset.
    offset: Array1<f64>,
    // Bound/prior metadata for each bounded coefficient.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
/// Per-observation quantities for a standard likelihood family evaluated at a
/// fixed linear predictor (filled in by `evaluate_standard_familyobservations`).
#[derive(Clone)]
struct StandardFamilyObservationState {
    // Linear predictor eta.
    eta: Array1<f64>,
    // Mean response mu = g^{-1}(eta), clamped for binomial families.
    mu: Array1<f64>,
    // Score d(log-likelihood)/d(eta), per observation.
    score: Array1<f64>,
    // Fisher (expected information) weights, floored away from zero.
    fisherweight: Array1<f64>,
    // Negated observed Hessian of the log-likelihood w.r.t. eta.
    neghessian_eta: Array1<f64>,
    // d/d(eta) of `neghessian_eta` (third-order term for exact Newton).
    neghessian_eta_derivative: Array1<f64>,
    // Total weighted log-likelihood over all observations.
    log_likelihood: f64,
}
/// Logit (log-odds) of `z`, with the argument clamped into the open unit
/// interval so the result stays finite at the boundaries.
fn bounded_logit(z: f64) -> f64 {
    let p = z.clamp(1e-12, 1.0 - 1e-12);
    let odds = p / (1.0 - p);
    odds.ln()
}
/// Numerically stable logistic function `1 / (1 + e^{-theta})`: the
/// exponential is always taken of a non-positive argument, so large |theta|
/// cannot overflow.
fn stable_sigmoid(theta: f64) -> f64 {
    if theta < 0.0 {
        let e = theta.exp();
        e / (1.0 + e)
    } else {
        let e = (-theta).exp();
        1.0 / (1.0 + e)
    }
}
/// Numerically stable softplus `ln(1 + e^x)`. For positive `x` the identity
/// `softplus(x) = x + softplus(-x)` keeps the exponential from overflowing.
fn stable_softplus(x: f64) -> f64 {
    match x > 0.0 {
        true => (-x).exp().ln_1p() + x,
        false => x.exp().ln_1p(),
    }
}
/// Maps an unconstrained latent value onto a coefficient in `[min, max]` via
/// a scaled logistic transform.
///
/// Returns `(beta, z, d beta / d theta)` where `z = sigmoid(theta)` is the
/// unit-interval coordinate.
fn bounded_latent_to_user(theta: f64, min: f64, max: f64) -> (f64, f64, f64) {
    let width = max - min;
    let z = stable_sigmoid(theta);
    let slope = width * z * (1.0 - z);
    (min + width * z, z, slope)
}
/// Value and first three theta-derivatives of the bounded-coefficient map
/// `beta(theta) = min + (max - min) * sigmoid(theta)`.
///
/// Returns `(beta, z, beta', beta'', beta''')`.
fn bounded_latent_derivatives(theta: f64, min: f64, max: f64) -> (f64, f64, f64, f64, f64) {
    let width = max - min;
    let z = stable_sigmoid(theta);
    // s = z(1-z) is the sigmoid's first derivative.
    let s = z * (1.0 - z);
    let first = width * s;
    let second = width * s * (1.0 - 2.0 * z);
    let third = width * s * (1.0 - 6.0 * z + 6.0 * z * z);
    (min + width * z, z, first, second, third)
}
/// Log-prior contribution of a bounded coefficient under the configured prior
/// on its unit-interval coordinate `z = sigmoid(theta)`.
///
/// Returns `(log p, d(log p)/d theta, negated second theta-derivative, and the
/// theta-derivative of that negated Hessian)`. `None` contributes nothing.
fn bounded_prior_terms(theta: f64, prior: &BoundedCoefficientPriorSpec) -> (f64, f64, f64, f64) {
    let (a, b) = match prior {
        BoundedCoefficientPriorSpec::None => return (0.0, 0.0, 0.0, 0.0),
        BoundedCoefficientPriorSpec::Uniform => (1.0, 1.0),
        BoundedCoefficientPriorSpec::Beta { a, b } => (*a, *b),
    };
    // Clamp z away from 0 and 1 so the logarithms stay finite.
    let z = stable_sigmoid(theta).clamp(1e-12, 1.0 - 1e-12);
    let ab = a + b;
    let logp = a * z.ln() + b * (1.0 - z).ln();
    let grad = a - ab * z;
    let neghess = ab * z * (1.0 - z);
    let neghess_slope = ab * z * (1.0 - z) * (1.0 - 2.0 * z);
    (logp, grad, neghess, neghess_slope)
}
/// Evaluates per-observation likelihood quantities (mean, score, Fisher
/// weights, eta-Hessian terms, total log-likelihood) for a standard family at
/// a fixed linear predictor `eta`.
///
/// Errors when input lengths disagree or when the family does not support
/// bounded linear terms (Poisson, Gamma, Royston-Parmar).
fn evaluate_standard_familyobservations(
    family: LikelihoodFamily,
    latent_cloglog_state: Option<&LatentCLogLogState>,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
    y: &Array1<f64>,
    weights: &Array1<f64>,
    eta: &Array1<f64>,
) -> Result<StandardFamilyObservationState, EstimationError> {
    // Floors keeping probabilities and derivative weights away from zero.
    const PROB_EPS: f64 = 1e-10;
    const MU_DERIV_EPS: f64 = 1e-12;
    let n = y.len();
    if weights.len() != n || eta.len() != n {
        return Err(EstimationError::InvalidInput(
            "bounded family observation size mismatch".to_string(),
        ));
    }
    let mut mu = Array1::<f64>::zeros(n);
    let mut score = Array1::<f64>::zeros(n);
    let mut fisherweight = Array1::<f64>::zeros(n);
    let mut neghessian_eta = Array1::<f64>::zeros(n);
    let mut neghessian_eta_derivative = Array1::<f64>::zeros(n);
    let mut log_likelihood = 0.0;
    for i in 0..n {
        // Negative weights are treated as zero.
        let w = weights[i].max(0.0);
        let yi = y[i];
        let eta_i = eta[i];
        match family {
            // Identity link: mu = eta, constant curvature w.
            LikelihoodFamily::GaussianIdentity => {
                let resid = yi - eta_i;
                mu[i] = eta_i;
                score[i] = w * resid;
                fisherweight[i] = w.max(MU_DERIV_EPS);
                neghessian_eta[i] = w;
                neghessian_eta_derivative[i] = 0.0;
                log_likelihood += -0.5 * w * resid * resid;
            }
            // Canonical logit link: observed and expected information agree.
            LikelihoodFamily::BinomialLogit => {
                let jet = logit_inverse_link_jet5(eta_i);
                mu[i] = jet.mu;
                score[i] = w * (yi - jet.mu);
                fisherweight[i] = jet.d1.max(MU_DERIV_EPS);
                neghessian_eta[i] = jet.d1;
                neghessian_eta_derivative[i] = jet.d2;
                // Softplus forms of ln(mu) and ln(1-mu) avoid catastrophic
                // cancellation for extreme eta.
                let logmu = -stable_softplus(-eta_i);
                let log_one_minusmu = -stable_softplus(eta_i);
                log_likelihood += w * (yi * logmu + (1.0 - yi) * log_one_minusmu);
            }
            // Non-canonical binomial links: full chain-rule expansion through
            // the inverse link's first three derivatives.
            LikelihoodFamily::BinomialProbit
            | LikelihoodFamily::BinomialCLogLog
            | LikelihoodFamily::BinomialLatentCLogLog
            | LikelihoodFamily::BinomialSas
            | LikelihoodFamily::BinomialBetaLogistic
            | LikelihoodFamily::BinomialMixture => {
                // Pick the specialized inverse link when a matching state is
                // supplied; otherwise fall back to the family default.
                let inverse_link = if let Some(state) = latent_cloglog_state {
                    Some(InverseLink::LatentCLogLog(*state))
                } else if let Some(state) = mixture_link_state {
                    Some(InverseLink::Mixture(state.clone()))
                } else if let Some(state) = sas_link_state {
                    Some(
                        if matches!(family, LikelihoodFamily::BinomialBetaLogistic) {
                            InverseLink::BetaLogistic(*state)
                        } else {
                            InverseLink::Sas(*state)
                        },
                    )
                } else {
                    None
                };
                let jet =
                    strategy_for_family(family, inverse_link.as_ref()).inverse_link_jet(eta_i)?;
                let mu_i_raw = jet.mu;
                let dmu_deta_raw = jet.d1;
                let mu_i: f64 = mu_i_raw.clamp(PROB_EPS, 1.0 - PROB_EPS);
                let dmu_deta = dmu_deta_raw.max(MU_DERIV_EPS);
                let d2mu_deta2 = jet.d2;
                let d3mu_deta3 = jet.d3;
                let var = (mu_i * (1.0 - mu_i)).max(PROB_EPS);
                // Bernoulli log-likelihood derivatives with respect to mu.
                let lmu = (yi - mu_i) / var;
                let lmumu = -(yi / (mu_i * mu_i)) - ((1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i)));
                let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i)
                    - 2.0 * (1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i) * (1.0 - mu_i));
                mu[i] = mu_i;
                score[i] = w * lmu * dmu_deta;
                fisherweight[i] = (w * dmu_deta * dmu_deta / var).max(MU_DERIV_EPS);
                // Observed (not expected) information in eta.
                neghessian_eta[i] = -w * (lmumu * dmu_deta * dmu_deta + lmu * d2mu_deta2);
                // Third-order chain rule (Faa di Bruno) for d/d(eta) of the
                // negated Hessian.
                neghessian_eta_derivative[i] = -w
                    * (lmumumu * dmu_deta * dmu_deta * dmu_deta
                        + 3.0 * lmumu * dmu_deta * d2mu_deta2
                        + lmu * d3mu_deta3);
                log_likelihood += w * (yi * mu_i.ln() + (1.0 - yi) * (1.0 - mu_i).ln());
            }
            LikelihoodFamily::PoissonLog => {
                return Err(EstimationError::InvalidInput(
                    "bounded linear terms are not supported for PoissonLog fits".to_string(),
                ));
            }
            LikelihoodFamily::GammaLog => {
                return Err(EstimationError::InvalidInput(
                    "bounded linear terms are not supported for GammaLog fits".to_string(),
                ));
            }
            LikelihoodFamily::RoystonParmar => {
                return Err(EstimationError::InvalidInput(
                    "bounded linear terms are not supported for survival model fits".to_string(),
                ));
            }
        }
    }
    Ok(StandardFamilyObservationState {
        eta: eta.clone(),
        mu,
        score,
        fisherweight,
        neghessian_eta,
        neghessian_eta_derivative,
        log_likelihood,
    })
}
/// Identifies one outer hyperparameter of the spatial adaptive penalty:
/// a log smoothing weight (lambda) or a log regularization floor (epsilon),
/// each for the magnitude / gradient / curvature component.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveHyperKind {
    // Per-term log lambda for the magnitude (D0) component.
    LogLambdaMagnitude,
    // Per-term log lambda for the gradient (D1) component.
    LogLambdaGradient,
    // Per-term log lambda for the curvature (D2) component.
    LogLambdaCurvature,
    // Shared log epsilon for the magnitude component.
    LogEpsilonMagnitude,
    // Shared log epsilon for the gradient component.
    LogEpsilonGradient,
    // Shared log epsilon for the curvature component.
    LogEpsilonCurvature,
}
impl SpatialAdaptiveHyperKind {
fn component_index(self) -> usize {
match self {
SpatialAdaptiveHyperKind::LogLambdaMagnitude
| SpatialAdaptiveHyperKind::LogEpsilonMagnitude => 0,
SpatialAdaptiveHyperKind::LogLambdaGradient
| SpatialAdaptiveHyperKind::LogEpsilonGradient => 1,
SpatialAdaptiveHyperKind::LogLambdaCurvature
| SpatialAdaptiveHyperKind::LogEpsilonCurvature => 2,
}
}
fn is_log_lambda(self) -> bool {
matches!(
self,
SpatialAdaptiveHyperKind::LogLambdaMagnitude
| SpatialAdaptiveHyperKind::LogLambdaGradient
| SpatialAdaptiveHyperKind::LogLambdaCurvature
)
}
fn is_log_epsilon(self) -> bool {
matches!(
self,
SpatialAdaptiveHyperKind::LogEpsilonMagnitude
| SpatialAdaptiveHyperKind::LogEpsilonGradient
| SpatialAdaptiveHyperKind::LogEpsilonCurvature
)
}
}
/// One entry in the outer hyperparameter layout: which adaptive term it
/// belongs to and which kind of hyperparameter it is.
#[derive(Clone, Copy, Debug)]
struct SpatialAdaptiveHyperSpec {
    // Index into the runtime caches (one per adaptive spatial term).
    cache_index: usize,
    // Lambda/epsilon kind and component of this hyperparameter.
    kind: SpatialAdaptiveHyperKind,
}
/// Classifies the explicit second-order (hyper x hyper) interaction a pair of
/// spatial adaptive hyperparameters contributes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveExplicitSecondOrderKind {
    // The cross term is identically zero.
    StructuralZero,
    // Lambda-lambda interaction within a single adaptive term.
    LocalAlphaAlpha,
    // Mixed lambda-epsilon interaction.
    LocalAlphaEta,
    // Epsilon-epsilon interaction (epsilon is shared across terms).
    SharedEtaEta,
}
impl SpatialAdaptiveHyperSpec {
    /// Component slot (0 = magnitude, 1 = gradient, 2 = curvature) of the
    /// underlying hyperparameter kind.
    fn component_index(self) -> usize {
        self.kind.component_index()
    }
    /// Classifies the explicit second-order interaction between two
    /// hyperparameters. Pairs on different components are structurally zero;
    /// lambda-lambda pairs interact only within the same term.
    fn explicit_second_order_kind(self, other: Self) -> SpatialAdaptiveExplicitSecondOrderKind {
        if self.component_index() != other.component_index() {
            return SpatialAdaptiveExplicitSecondOrderKind::StructuralZero;
        }
        let (a_lambda, a_eps) = (self.kind.is_log_lambda(), self.kind.is_log_epsilon());
        let (b_lambda, b_eps) = (other.kind.is_log_lambda(), other.kind.is_log_epsilon());
        if a_lambda && b_lambda && !a_eps && !b_eps {
            if self.cache_index == other.cache_index {
                SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha
            } else {
                SpatialAdaptiveExplicitSecondOrderKind::StructuralZero
            }
        } else if (a_lambda && !b_lambda && !a_eps && b_eps)
            || (!a_lambda && b_lambda && a_eps && !b_eps)
        {
            SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta
        } else if !a_lambda && !b_lambda && a_eps && b_eps {
            SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta
        } else {
            SpatialAdaptiveExplicitSecondOrderKind::StructuralZero
        }
    }
}
/// Per-term adaptive hyperparameters on the natural (exponentiated) scale.
/// Arrays are indexed 0 = magnitude, 1 = gradient, 2 = curvature.
#[derive(Clone, Debug)]
struct SpatialAdaptiveTermHyperParams {
    // Smoothing weights per component.
    lambda: [f64; 3],
    // Regularization floors per component.
    epsilon: [f64; 3],
}
/// Exact evaluation of the spatially adaptive objective at a fixed beta:
/// observation-level quantities plus value/gradient/Hessian of both the
/// adaptive penalty and the fixed quadratic penalty.
#[derive(Clone)]
struct SpatialAdaptiveExactEvaluation {
    // Per-observation likelihood state.
    obs: StandardFamilyObservationState,
    // Exact per-term penalty states, parallel to the runtime caches.
    adaptive_states: Vec<SpatialPenaltyExactState>,
    // Total adaptive penalty value.
    adaptive_penalty_value: f64,
    // Gradient of the adaptive penalty with respect to beta.
    adaptive_penaltygradient: Array1<f64>,
    // Hessian of the adaptive penalty with respect to beta.
    adaptive_penaltyhessian: Array2<f64>,
    // Value of the fixed quadratic penalty 0.5 * beta^T H beta.
    fixed_quadraticvalue: f64,
    // Gradient H beta of the fixed quadratic penalty.
    fixed_quadraticgradient: Array1<f64>,
    // Hessian H of the fixed quadratic penalty.
    fixed_quadratichessian: Array2<f64>,
}
/// Memoized exact evaluation together with the beta it was computed at, so a
/// repeated request at the same beta can reuse the shared evaluation.
#[derive(Clone)]
struct CachedSpatialAdaptiveExactEvaluation {
    // Coefficient vector the cached evaluation corresponds to.
    beta: Array1<f64>,
    // Shared evaluation result.
    eval: Arc<SpatialAdaptiveExactEvaluation>,
}
impl SpatialAdaptiveExactEvaluation {
    /// Combined adaptive + fixed quadratic penalty value.
    fn total_penalty_value(&self) -> f64 {
        self.fixed_quadraticvalue + self.adaptive_penalty_value
    }
    /// Gradient of the combined penalty with respect to beta.
    fn total_penaltygradient(&self) -> Array1<f64> {
        let mut grad = self.adaptive_penaltygradient.clone();
        grad += &self.fixed_quadraticgradient;
        grad
    }
    /// Hessian of the combined penalty with respect to beta.
    fn total_penaltyhessian(&self) -> Array2<f64> {
        let mut hess = self.adaptive_penaltyhessian.clone();
        hess += &self.fixed_quadratichessian;
        hess
    }
    /// Full penalized objective Hessian: X^T diag(neghessian_eta) X plus the
    /// combined penalty Hessian.
    fn totalobjectivehessian(&self, design: &Array2<f64>) -> Result<Array2<f64>, String> {
        let penalty = self.total_penaltyhessian();
        let mut hess = xt_diag_x_dense(design.view(), self.obs.neghessian_eta.view())?;
        hess += &penalty;
        Ok(hess)
    }
}
/// Exact-Newton evaluation family for spatially adaptive penalties: bundles
/// the observation model, design data, per-term operator caches, current
/// hyperparameters, and a memoized exact evaluation.
#[derive(Clone)]
struct SpatialAdaptiveExactFamily {
    // Observation likelihood family.
    family: LikelihoodFamily,
    // Optional link states for the specialized binomial links.
    latent_cloglog_state: Option<LatentCLogLogState>,
    mixture_link_state: Option<MixtureLinkState>,
    sas_link_state: Option<SasLinkState>,
    // Response, weights, design matrix, and offset (shared, read-only).
    y: Arc<Array1<f64>>,
    weights: Arc<Array1<f64>>,
    design: Arc<Array2<f64>>,
    offset: Arc<Array1<f64>>,
    // Optional linear inequality constraints on beta.
    linear_constraints: Option<LinearInequalityConstraints>,
    // Precomputed operator caches, one per adaptive spatial term.
    runtime_caches: Arc<Vec<SpatialOperatorRuntimeCache>>,
    // Current (lambda, epsilon) hyperparameters, one block per cache.
    adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
    // Hessian of the fixed (non-adaptive) quadratic penalty.
    fixed_quadratichessian: Arc<Array2<f64>>,
    // Layout of the outer hyperparameter vector.
    hyperspecs: Arc<Vec<SpatialAdaptiveHyperSpec>>,
    // Memoized exact evaluation keyed by the beta it was computed at.
    exact_eval_cache: Arc<Mutex<Option<CachedSpatialAdaptiveExactEvaluation>>>,
}
impl SpatialAdaptiveExactFamily {
/// Clones the family with new adaptive hyperparameters and a new fixed
/// quadratic penalty. The exact-evaluation cache is reset, since any cached
/// evaluation is stale under the new hyperparameters.
fn with_adaptive_params(
    &self,
    adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
    fixed_quadratichessian: Arc<Array2<f64>>,
) -> Self {
    Self {
        adaptive_params,
        fixed_quadratichessian,
        exact_eval_cache: Arc::new(Mutex::new(None)),
        ..self.clone()
    }
}
/// Linear predictor `eta = X beta + offset`.
fn total_eta(&self, beta: &Array1<f64>) -> Array1<f64> {
    let xb = crate::faer_ndarray::fast_av(self.design.as_ref(), beta);
    xb + self.offset.as_ref()
}
/// Value and gradient of the fixed quadratic penalty `0.5 * beta^T H beta`.
fn fixed_quadratic_terms(&self, beta: &Array1<f64>) -> (f64, Array1<f64>) {
    let gradient = self.fixed_quadratichessian.dot(beta);
    let half_quad = 0.5 * beta.dot(&gradient);
    (half_quad, gradient)
}
/// Total adaptive penalty value at `beta`, summing the lambda-weighted
/// magnitude/gradient/curvature components over all adaptive terms. Errors if
/// a term has no matching hyperparameter block.
fn adaptive_penalty_value_only(&self, beta: &Array1<f64>) -> Result<f64, String> {
    let mut total = 0.0;
    for (idx, cache) in self.runtime_caches.iter().enumerate() {
        let hyper = self.adaptive_params.get(idx).ok_or_else(|| {
            format!(
                "missing adaptive parameter block for cache {}",
                cache.termname
            )
        })?;
        let local = beta.slice(s![cache.coeff_global_range.clone()]);
        let exact = SpatialPenaltyExactState::from_beta_local(local, cache, hyper.epsilon)
            .map_err(|e| e.to_string())?;
        let [l0, l1, l2] = hyper.lambda;
        total += l0 * exact.magnitude.penalty_value()
            + l1 * exact.gradient.penalty_value()
            + l2 * exact.curvature.penalty_value();
    }
    Ok(total)
}
/// Zero-initialized (gradient, Hessian) pair of full coefficient dimension.
fn zero_hyper_parts(&self) -> (Array1<f64>, Array2<f64>) {
    let p = self.design.ncols();
    (Array1::zeros(p), Array2::zeros((p, p)))
}
/// Embeds a term-local gradient/Hessian pair into full-dimension arrays,
/// zero everywhere outside the term's coefficient range.
fn embed_local_hyper_parts(
    &self,
    coeff_range: &Range<usize>,
    local_grad: &Array1<f64>,
    local_hess: &Array2<f64>,
) -> (Array1<f64>, Array2<f64>) {
    let (mut grad, mut hess) = self.zero_hyper_parts();
    let r = coeff_range.clone();
    grad.slice_mut(s![r.clone()]).assign(local_grad);
    hess.slice_mut(s![r.clone(), r]).assign(local_hess);
    (grad, hess)
}
/// Embeds a term-local Hessian block into a full-dimension zero matrix.
fn embed_local_hyper_hessian(
    &self,
    coeff_range: &Range<usize>,
    local_hess: &Array2<f64>,
) -> Array2<f64> {
    let p = self.design.ncols();
    let mut full = Array2::<f64>::zeros((p, p));
    let r = coeff_range.clone();
    full.slice_mut(s![r.clone(), r]).assign(local_hess);
    full
}
/// First-order hyper-derivative parts of one adaptive term with respect to
/// its log-lambda for `component` (0 = magnitude, 1 = gradient,
/// 2 = curvature).
///
/// Returns `(penalty value, mixed beta-gradient, beta-Hessian)` of the
/// selected component, scaled by that component's lambda and embedded into the
/// full coefficient space.
fn adaptive_block_parts(
    &self,
    eval: &SpatialAdaptiveExactEvaluation,
    cache_idx: usize,
    component: usize,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    let cache = self
        .runtime_caches
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
    let params = self
        .adaptive_params
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
    let state = eval
        .adaptive_states
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
    match component {
        // Magnitude component: scalar D0 collocation operator.
        0 => {
            let lambda = params.lambda[0];
            let beta_mixed_local = lambda
                * scalar_operatorgradient(&cache.d0, &state.magnitude.betagradient_coeff());
            let betahessian_local =
                lambda * scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag());
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.magnitude.penalty_value(),
                beta_mixed,
                betahessian,
            ))
        }
        // Gradient component: D1 rows grouped by spatial dimension.
        1 => {
            let lambda = params.lambda[1];
            let beta_mixed_local = lambda
                * grouped_operatorgradient(
                    &cache.d1,
                    cache.dimension,
                    &state.gradient.betagradient_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let betahessian_local = lambda
                * grouped_operatorhessian(
                    &cache.d1,
                    cache.dimension,
                    &state.gradient.betahessian_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.gradient.penalty_value(),
                beta_mixed,
                betahessian,
            ))
        }
        // Curvature component: D2 rows grouped by dimension^2 entries.
        2 => {
            let lambda = params.lambda[2];
            let beta_mixed_local = lambda
                * grouped_operatorgradient(
                    &cache.d2,
                    cache.dimension * cache.dimension,
                    &state.curvature.betagradient_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let betahessian_local = lambda
                * grouped_operatorhessian(
                    &cache.d2,
                    cache.dimension * cache.dimension,
                    &state.curvature.betahessian_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.curvature.penalty_value(),
                beta_mixed,
                betahessian,
            ))
        }
        _ => Err(format!("invalid adaptive component index {}", component)),
    }
}
/// First-order hyper-derivative parts of one adaptive term with respect to
/// the log-epsilon of `component` (0 = magnitude, 1 = gradient,
/// 2 = curvature).
///
/// Mirrors `adaptive_block_parts`, but uses the state's log-epsilon
/// derivative quantities instead of the plain beta derivatives.
fn adaptive_block_log_epsilon_parts(
    &self,
    eval: &SpatialAdaptiveExactEvaluation,
    cache_idx: usize,
    component: usize,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    let cache = self
        .runtime_caches
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
    let params = self
        .adaptive_params
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
    let state = eval
        .adaptive_states
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
    match component {
        // Magnitude component: scalar D0 collocation operator.
        0 => {
            let lambda = params.lambda[0];
            let beta_mixed_local = lambda
                * scalar_operatorgradient(
                    &cache.d0,
                    &state.magnitude.log_epsilon_betagradient_coeff(),
                );
            let betahessian_local = lambda
                * scalar_operatorhessian(
                    &cache.d0,
                    &state.magnitude.log_epsilon_betahessian_diag(),
                );
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.magnitude.log_epsilon_gradient_terms().sum(),
                beta_mixed,
                betahessian,
            ))
        }
        // Gradient component: D1 rows grouped by spatial dimension.
        1 => {
            let lambda = params.lambda[1];
            let beta_mixed_local = lambda
                * grouped_operatorgradient(
                    &cache.d1,
                    cache.dimension,
                    &state.gradient.log_epsilon_betagradient_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let betahessian_local = lambda
                * grouped_operatorhessian(
                    &cache.d1,
                    cache.dimension,
                    &state.gradient.log_epsilon_betahessian_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.gradient.log_epsilon_gradient_terms().sum(),
                beta_mixed,
                betahessian,
            ))
        }
        // Curvature component: D2 rows grouped by dimension^2 entries.
        2 => {
            let lambda = params.lambda[2];
            let beta_mixed_local = lambda
                * grouped_operatorgradient(
                    &cache.d2,
                    cache.dimension * cache.dimension,
                    &state.curvature.log_epsilon_betagradient_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let betahessian_local = lambda
                * grouped_operatorhessian(
                    &cache.d2,
                    cache.dimension * cache.dimension,
                    &state.curvature.log_epsilon_betahessian_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.curvature.log_epsilon_gradient_terms().sum(),
                beta_mixed,
                betahessian,
            ))
        }
        _ => Err(format!("invalid adaptive component index {}", component)),
    }
}
/// Second-order hyper-derivative parts of one adaptive term with respect to
/// the log-epsilon of `component` (0 = magnitude, 1 = gradient,
/// 2 = curvature).
///
/// Mirrors `adaptive_block_log_epsilon_parts`, but draws on the state's
/// second log-epsilon derivative quantities.
fn adaptive_block_log_epsilon_second_parts(
    &self,
    eval: &SpatialAdaptiveExactEvaluation,
    cache_idx: usize,
    component: usize,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    let cache = self
        .runtime_caches
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
    let params = self
        .adaptive_params
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
    let state = eval
        .adaptive_states
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
    match component {
        // Magnitude component: scalar D0 collocation operator.
        0 => {
            let lambda = params.lambda[0];
            let beta_mixed_local = lambda
                * scalar_operatorgradient(
                    &cache.d0,
                    &state.magnitude.log_epsilon_beta_mixed_second_coeff(),
                );
            let betahessian_local = lambda
                * scalar_operatorhessian(
                    &cache.d0,
                    &state.magnitude.log_epsilon_betahessian_second_diag(),
                );
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.magnitude.log_epsilon_hessian_terms().sum(),
                beta_mixed,
                betahessian,
            ))
        }
        // Gradient component: D1 rows grouped by spatial dimension.
        1 => {
            let lambda = params.lambda[1];
            let beta_mixed_local = lambda
                * grouped_operatorgradient(
                    &cache.d1,
                    cache.dimension,
                    &state.gradient.log_epsilon_beta_mixed_second_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let betahessian_local = lambda
                * grouped_operatorhessian(
                    &cache.d1,
                    cache.dimension,
                    &state.gradient.log_epsilon_betahessian_second_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.gradient.log_epsilon_hessian_terms().sum(),
                beta_mixed,
                betahessian,
            ))
        }
        // Curvature component: D2 rows grouped by dimension^2 entries.
        2 => {
            let lambda = params.lambda[2];
            let beta_mixed_local = lambda
                * grouped_operatorgradient(
                    &cache.d2,
                    cache.dimension * cache.dimension,
                    &state.curvature.log_epsilon_beta_mixed_second_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let betahessian_local = lambda
                * grouped_operatorhessian(
                    &cache.d2,
                    cache.dimension * cache.dimension,
                    &state.curvature.log_epsilon_betahessian_second_blocks(),
                )
                .map_err(|e| e.to_string())?;
            let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
                &cache.coeff_global_range,
                &beta_mixed_local,
                &betahessian_local,
            );
            Ok((
                lambda * state.curvature.log_epsilon_hessian_terms().sum(),
                beta_mixed,
                betahessian,
            ))
        }
        _ => Err(format!("invalid adaptive component index {}", component)),
    }
}
fn adaptive_shared_log_epsilon_parts(
&self,
eval: &SpatialAdaptiveExactEvaluation,
component: usize,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
let (mut score, mut hessian) = self.zero_hyper_parts();
let mut objective = 0.0;
for cache_idx in 0..self.runtime_caches.len() {
let (local_objective, local_score, local_hessian) =
self.adaptive_block_log_epsilon_parts(eval, cache_idx, component)?;
objective += local_objective;
score += &local_score;
hessian += &local_hessian;
}
Ok((objective, score, hessian))
}
fn adaptive_shared_log_epsilon_second_parts(
&self,
eval: &SpatialAdaptiveExactEvaluation,
component: usize,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
let (mut score, mut hessian) = self.zero_hyper_parts();
let mut objective = 0.0;
for cache_idx in 0..self.runtime_caches.len() {
let (local_objective, local_score, local_hessian) =
self.adaptive_block_log_epsilon_second_parts(eval, cache_idx, component)?;
objective += local_objective;
score += &local_score;
hessian += &local_hessian;
}
Ok((objective, score, hessian))
}
fn adaptive_shared_log_epsilon_drift(
&self,
eval: &SpatialAdaptiveExactEvaluation,
component: usize,
direction: &Array1<f64>,
) -> Result<Array2<f64>, String> {
let total_dim = self.design.ncols();
let mut total = Array2::<f64>::zeros((total_dim, total_dim));
for cache_idx in 0..self.runtime_caches.len() {
total +=
&self.adaptive_block_log_epsilon_drift(eval, cache_idx, component, direction)?;
}
Ok(total)
}
fn adaptive_explicit_second_order_parts(
&self,
eval: &SpatialAdaptiveExactEvaluation,
left: SpatialAdaptiveHyperSpec,
right: SpatialAdaptiveHyperSpec,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
match left.explicit_second_order_kind(right) {
SpatialAdaptiveExplicitSecondOrderKind::StructuralZero => {
let (score, hessian) = self.zero_hyper_parts();
Ok((0.0, score, hessian))
}
SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha => {
self.adaptive_block_parts(eval, left.cache_index, left.component_index())
}
SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta => {
let local_alpha = if left.kind.is_log_lambda() {
left
} else {
right
};
self.adaptive_block_log_epsilon_parts(
eval,
local_alpha.cache_index,
local_alpha.component_index(),
)
}
SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta => {
self.adaptive_shared_log_epsilon_second_parts(eval, left.component_index())
}
}
}
    /// Directional derivative ("drift") of one block's lambda-scaled penalty
    /// Hessian along `direction`, for a single penalty component.
    ///
    /// The global `direction` is restricted to the block's coefficient range,
    /// mapped through the block's collocation operator (`d0`/`d1`/`d2`), and
    /// fed to the exact state's directional-Hessian terms; the lambda-scaled
    /// result is embedded back into global coordinates via
    /// `embed_local_hyper_hessian`. `component`: 0 = magnitude, 1 = gradient,
    /// 2 = curvature; anything else is an error.
    fn adaptive_block_drift(
        &self,
        eval: &SpatialAdaptiveExactEvaluation,
        cache_idx: usize,
        component: usize,
        direction: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let cache = self
            .runtime_caches
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
        let params = self
            .adaptive_params
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
        let state = eval
            .adaptive_states
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
        // Only the block's own coefficients participate in its penalty.
        let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);
        let local_hessian = match component {
            // Magnitude: scalar operator d0 applied to the local direction.
            0 => {
                let d0_u = cache.d0.dot(&direction_local);
                params.lambda[0]
                    * scalar_operatorhessian(
                        &cache.d0,
                        &state.magnitude.directionalhessian_diag(&d0_u),
                    )
            }
            // Gradient: grouped operator d1, blocks of size `dimension`.
            1 => {
                let d1_u = cache.d1.dot(&direction_local);
                params.lambda[1]
                    * grouped_operatorhessian(
                        &cache.d1,
                        cache.dimension,
                        &state.gradient.directionalhessian_blocks(
                            &collocationgradient_blocks(&d1_u, cache.dimension)
                                .map_err(|e| e.to_string())?,
                        ),
                    )
                    .map_err(|e| e.to_string())?
            }
            // Curvature: grouped operator d2, blocks of size `dimension^2`.
            2 => {
                let d2_u = cache.d2.dot(&direction_local);
                params.lambda[2]
                    * grouped_operatorhessian(
                        &cache.d2,
                        cache.dimension * cache.dimension,
                        &state.curvature.directionalhessian_blocks(
                            &collocationhessian_blocks(&d2_u, cache.dimension)
                                .map_err(|e| e.to_string())?,
                        ),
                    )
                    .map_err(|e| e.to_string())?
            }
            _ => return Err(format!("invalid adaptive component index {}", component)),
        };
        Ok(self.embed_local_hyper_hessian(&cache.coeff_global_range, &local_hessian))
    }
    /// Log-epsilon variant of `adaptive_block_drift`: directional derivative
    /// of one block's lambda-scaled penalty Hessian along `direction`, but
    /// using the exact state's `log_epsilon_betahessian_directional_*` terms
    /// instead of the plain directional-Hessian terms.
    ///
    /// `component`: 0 = magnitude (`d0`), 1 = gradient (`d1`, block size
    /// `dimension`), 2 = curvature (`d2`, block size
    /// `dimension * dimension`); anything else is an error.
    fn adaptive_block_log_epsilon_drift(
        &self,
        eval: &SpatialAdaptiveExactEvaluation,
        cache_idx: usize,
        component: usize,
        direction: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let cache = self
            .runtime_caches
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
        let params = self
            .adaptive_params
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
        let state = eval
            .adaptive_states
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
        // Restrict the global direction to this block's coefficients.
        let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);
        let local_hessian = match component {
            0 => {
                let d0_u = cache.d0.dot(&direction_local);
                params.lambda[0]
                    * scalar_operatorhessian(
                        &cache.d0,
                        &state
                            .magnitude
                            .log_epsilon_betahessian_directional_diag(&d0_u),
                    )
            }
            1 => {
                let d1_u = cache.d1.dot(&direction_local);
                params.lambda[1]
                    * grouped_operatorhessian(
                        &cache.d1,
                        cache.dimension,
                        &state.gradient.log_epsilon_betahessian_directional_blocks(
                            &collocationgradient_blocks(&d1_u, cache.dimension)
                                .map_err(|e| e.to_string())?,
                        ),
                    )
                    .map_err(|e| e.to_string())?
            }
            2 => {
                let d2_u = cache.d2.dot(&direction_local);
                params.lambda[2]
                    * grouped_operatorhessian(
                        &cache.d2,
                        cache.dimension * cache.dimension,
                        &state.curvature.log_epsilon_betahessian_directional_blocks(
                            &collocationhessian_blocks(&d2_u, cache.dimension)
                                .map_err(|e| e.to_string())?,
                        ),
                    )
                    .map_err(|e| e.to_string())?
            }
            _ => return Err(format!("invalid adaptive component index {}", component)),
        };
        Ok(self.embed_local_hyper_hessian(&cache.coeff_global_range, &local_hessian))
    }
fn adaptive_hyper_parts(
&self,
eval: &SpatialAdaptiveExactEvaluation,
hyper: SpatialAdaptiveHyperSpec,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
match hyper.kind {
SpatialAdaptiveHyperKind::LogLambdaMagnitude => {
let (mut beta_mixed, mut betahessian) = self.zero_hyper_parts();
let cache = self.runtime_caches.get(hyper.cache_index).ok_or_else(|| {
format!("adaptive cache index {} out of bounds", hyper.cache_index)
})?;
let params = self.adaptive_params.get(hyper.cache_index).ok_or_else(|| {
format!(
"adaptive hyperparameter block {} out of bounds",
hyper.cache_index
)
})?;
let state = eval.adaptive_states.get(hyper.cache_index).ok_or_else(|| {
format!(
"adaptive exact state index {} out of bounds",
hyper.cache_index
)
})?;
let lambda = params.lambda[0];
let beta_mixed_local = lambda
* scalar_operatorgradient(&cache.d0, &state.magnitude.betagradient_coeff());
let betahessian_local =
lambda * scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag());
beta_mixed
.slice_mut(s![cache.coeff_global_range.clone()])
.assign(&beta_mixed_local);
betahessian
.slice_mut(s![
cache.coeff_global_range.clone(),
cache.coeff_global_range.clone()
])
.assign(&betahessian_local);
Ok((
lambda * state.magnitude.penalty_value(),
beta_mixed,
betahessian,
))
}
SpatialAdaptiveHyperKind::LogLambdaGradient => {
let (mut beta_mixed, mut betahessian) = self.zero_hyper_parts();
let cache = self.runtime_caches.get(hyper.cache_index).ok_or_else(|| {
format!("adaptive cache index {} out of bounds", hyper.cache_index)
})?;
let params = self.adaptive_params.get(hyper.cache_index).ok_or_else(|| {
format!(
"adaptive hyperparameter block {} out of bounds",
hyper.cache_index
)
})?;
let state = eval.adaptive_states.get(hyper.cache_index).ok_or_else(|| {
format!(
"adaptive exact state index {} out of bounds",
hyper.cache_index
)
})?;
let lambda = params.lambda[1];
let beta_mixed_local = lambda
* grouped_operatorgradient(
&cache.d1,
cache.dimension,
&state.gradient.betagradient_blocks(),
)
.map_err(|e| e.to_string())?;
let betahessian_local = lambda
* grouped_operatorhessian(
&cache.d1,
cache.dimension,
&state.gradient.betahessian_blocks(),
)
.map_err(|e| e.to_string())?;
beta_mixed
.slice_mut(s![cache.coeff_global_range.clone()])
.assign(&beta_mixed_local);
betahessian
.slice_mut(s![
cache.coeff_global_range.clone(),
cache.coeff_global_range.clone()
])
.assign(&betahessian_local);
Ok((
lambda * state.gradient.penalty_value(),
beta_mixed,
betahessian,
))
}
SpatialAdaptiveHyperKind::LogLambdaCurvature => {
let (mut beta_mixed, mut betahessian) = self.zero_hyper_parts();
let cache = self.runtime_caches.get(hyper.cache_index).ok_or_else(|| {
format!("adaptive cache index {} out of bounds", hyper.cache_index)
})?;
let params = self.adaptive_params.get(hyper.cache_index).ok_or_else(|| {
format!(
"adaptive hyperparameter block {} out of bounds",
hyper.cache_index
)
})?;
let state = eval.adaptive_states.get(hyper.cache_index).ok_or_else(|| {
format!(
"adaptive exact state index {} out of bounds",
hyper.cache_index
)
})?;
let lambda = params.lambda[2];
let beta_mixed_local = lambda
* grouped_operatorgradient(
&cache.d2,
cache.dimension * cache.dimension,
&state.curvature.betagradient_blocks(),
)
.map_err(|e| e.to_string())?;
let betahessian_local = lambda
* grouped_operatorhessian(
&cache.d2,
cache.dimension * cache.dimension,
&state.curvature.betahessian_blocks(),
)
.map_err(|e| e.to_string())?;
beta_mixed
.slice_mut(s![cache.coeff_global_range.clone()])
.assign(&beta_mixed_local);
betahessian
.slice_mut(s![
cache.coeff_global_range.clone(),
cache.coeff_global_range.clone()
])
.assign(&betahessian_local);
Ok((
lambda * state.curvature.penalty_value(),
beta_mixed,
betahessian,
))
}
SpatialAdaptiveHyperKind::LogEpsilonMagnitude
| SpatialAdaptiveHyperKind::LogEpsilonGradient
| SpatialAdaptiveHyperKind::LogEpsilonCurvature => {
self.adaptive_shared_log_epsilon_parts(eval, hyper.component_index())
}
}
}
    /// Rebuilds the full exact evaluation at `beta`, bypassing the cache.
    ///
    /// Evaluates the observation state at the total linear predictor, then
    /// for every runtime cache accumulates the lambda-weighted adaptive
    /// penalty value, gradient, and Hessian (magnitude, gradient, and
    /// curvature pieces) into global coefficient coordinates, retaining each
    /// per-block `SpatialPenaltyExactState` for later hyperparameter
    /// derivatives. Fixed quadratic terms are evaluated last and bundled
    /// into the returned `SpatialAdaptiveExactEvaluation`.
    fn exact_evaluation_uncached(
        &self,
        beta: &Array1<f64>,
    ) -> Result<SpatialAdaptiveExactEvaluation, String> {
        let eta = self.total_eta(beta);
        let obs = evaluate_standard_familyobservations(
            self.family,
            self.latent_cloglog_state.as_ref(),
            self.mixture_link_state.as_ref(),
            self.sas_link_state.as_ref(),
            &self.y,
            &self.weights,
            &eta,
        )
        .map_err(|e| e.to_string())?;
        let p = beta.len();
        let mut penalty_value = 0.0;
        let mut penaltygradient = Array1::<f64>::zeros(p);
        let mut penaltyhessian = Array2::<f64>::zeros((p, p));
        let mut adaptive_states = Vec::with_capacity(self.runtime_caches.len());
        for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
            let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
                format!(
                    "missing adaptive parameter block for cache {}",
                    cache.termname
                )
            })?;
            // Each block sees only its own slice of the coefficient vector.
            let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
            let state =
                SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
                    .map_err(|e| e.to_string())?;
            // Gradients of the three penalty pieces w.r.t. local beta.
            let g0 = scalar_operatorgradient(&cache.d0, &state.magnitude.betagradient_coeff());
            let gg = grouped_operatorgradient(
                &cache.d1,
                cache.dimension,
                &state.gradient.betagradient_blocks(),
            )
            .map_err(|e| e.to_string())?;
            let gc = grouped_operatorgradient(
                &cache.d2,
                cache.dimension * cache.dimension,
                &state.curvature.betagradient_blocks(),
            )
            .map_err(|e| e.to_string())?;
            // Matching Hessians of the three penalty pieces.
            let h0 = scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag());
            let hg = grouped_operatorhessian(
                &cache.d1,
                cache.dimension,
                &state.gradient.betahessian_blocks(),
            )
            .map_err(|e| e.to_string())?;
            let hc = grouped_operatorhessian(
                &cache.d2,
                cache.dimension * cache.dimension,
                &state.curvature.betahessian_blocks(),
            )
            .map_err(|e| e.to_string())?;
            let lambda0 = params.lambda[0];
            let lambdag = params.lambda[1];
            let lambdac = params.lambda[2];
            penalty_value += lambda0 * state.magnitude.penalty_value();
            penalty_value += lambdag * state.gradient.penalty_value();
            penalty_value += lambdac * state.curvature.penalty_value();
            let range = cache.coeff_global_range.clone();
            // Scatter the lambda-scaled local gradient into the global vector.
            {
                let mut grad_local = penaltygradient.slice_mut(s![range.clone()]);
                grad_local += &(g0.mapv(|v| lambda0 * v));
                grad_local += &(gg.mapv(|v| lambdag * v));
                grad_local += &(gc.mapv(|v| lambdac * v));
            }
            // Scatter the lambda-scaled local Hessian onto the diagonal block.
            {
                let mut h_local = penaltyhessian.slice_mut(s![range.clone(), range]);
                h_local += &h0.mapv(|v| lambda0 * v);
                h_local += &hg.mapv(|v| lambdag * v);
                h_local += &hc.mapv(|v| lambdac * v);
            }
            adaptive_states.push(state);
        }
        let (fixed_quadraticvalue, fixed_quadraticgradient) = self.fixed_quadratic_terms(beta);
        Ok(SpatialAdaptiveExactEvaluation {
            obs,
            adaptive_states,
            adaptive_penalty_value: penalty_value,
            adaptive_penaltygradient: penaltygradient,
            adaptive_penaltyhessian: penaltyhessian,
            fixed_quadraticvalue,
            fixed_quadraticgradient,
            fixed_quadratichessian: self.fixed_quadratichessian.as_ref().clone(),
        })
    }
fn exact_evaluation(
&self,
beta: &Array1<f64>,
) -> Result<Arc<SpatialAdaptiveExactEvaluation>, String> {
{
let cache = self
.exact_eval_cache
.lock()
.map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
if let Some(cached) = cache.as_ref()
&& cached.beta.len() == beta.len()
&& cached
.beta
.iter()
.zip(beta.iter())
.all(|(&left, &right)| left == right)
{
return Ok(Arc::clone(&cached.eval));
}
}
let eval = Arc::new(self.exact_evaluation_uncached(beta)?);
let mut cache = self
.exact_eval_cache
.lock()
.map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
*cache = Some(CachedSpatialAdaptiveExactEvaluation {
beta: beta.clone(),
eval: Arc::clone(&eval),
});
Ok(eval)
}
    /// Directional derivative of the total objective Hessian along
    /// `direction`, given a precomputed exact evaluation.
    ///
    /// The likelihood part is `Xᵀ diag(neghessian_eta_derivative ⊙ (X u)) X`;
    /// each block then adds its lambda-scaled penalty directional Hessian
    /// (magnitude/gradient/curvature) in place on its coefficient sub-range.
    fn exacthessian_directional_derivative_from_evaluation(
        &self,
        _: &Array1<f64>,
        eval: &SpatialAdaptiveExactEvaluation,
        direction: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        // Push the coefficient direction through the design: d_eta = X u.
        let d_eta = crate::faer_ndarray::fast_av(self.design.as_ref(), direction);
        let mut total = xt_diag_x_dense(
            self.design.view(),
            (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
        )?;
        for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
            let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
                format!(
                    "missing adaptive parameter block for cache {}",
                    cache.termname
                )
            })?;
            let state = eval
                .adaptive_states
                .get(cache_idx)
                .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
            // Apply the block's collocation operators to the local direction.
            let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);
            let d0_u = cache.d0.dot(&direction_local);
            let d1_u = cache.d1.dot(&direction_local);
            let d2_u = cache.d2.dot(&direction_local);
            let h0 =
                scalar_operatorhessian(&cache.d0, &state.magnitude.directionalhessian_diag(&d0_u))
                    .mapv(|v| params.lambda[0] * v);
            let hg = grouped_operatorhessian(
                &cache.d1,
                cache.dimension,
                &state.gradient.directionalhessian_blocks(
                    &collocationgradient_blocks(&d1_u, cache.dimension)
                        .map_err(|e| e.to_string())?,
                ),
            )
            .map_err(|e| e.to_string())?
            .mapv(|v| params.lambda[1] * v);
            let hc = grouped_operatorhessian(
                &cache.d2,
                cache.dimension * cache.dimension,
                &state.curvature.directionalhessian_blocks(
                    &collocationhessian_blocks(&d2_u, cache.dimension)
                        .map_err(|e| e.to_string())?,
                ),
            )
            .map_err(|e| e.to_string())?
            .mapv(|v| params.lambda[2] * v);
            // Add the penalty pieces onto the block's diagonal sub-block.
            let range = cache.coeff_global_range.clone();
            let mut local = total.slice_mut(s![range.clone(), range]);
            local += &h0;
            local += &hg;
            local += &hc;
        }
        Ok(total)
    }
}
/// Exact-Newton `CustomFamily` implementation backed by cached exact
/// evaluations of the spatial adaptive penalty. All methods operate on a
/// single coefficient block.
impl CustomFamily for SpatialAdaptiveExactFamily {
    /// Penalized log-likelihood plus the exact-Newton working set (gradient
    /// and dense Hessian) at the single block's coefficients.
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
        let eval = self.exact_evaluation(beta)?;
        // Gradient: Xᵀ score, minus the total penalty gradient.
        let mut gradient = fast_atv(&self.design, &eval.obs.score);
        gradient -= &eval.total_penaltygradient();
        // Hessian: Xᵀ diag(neghessian_eta) X, plus the total penalty Hessian.
        let mut hessian = xt_diag_x_dense(self.design.view(), eval.obs.neghessian_eta.view())?;
        hessian += &eval.total_penaltyhessian();
        Ok(FamilyEvaluation {
            log_likelihood: eval.obs.log_likelihood - eval.total_penalty_value(),
            blockworking_sets: vec![BlockWorkingSet::ExactNewton {
                gradient,
                hessian: SymmetricMatrix::Dense(hessian),
            }],
        })
    }
    /// Penalized log-likelihood only, using the state's stored `eta` and a
    /// value-only penalty evaluation (no gradients/Hessians are formed).
    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
        let state = expect_single_block_state(block_states, "spatial adaptive exact family")?;
        let beta = &state.beta;
        let obs = evaluate_standard_familyobservations(
            self.family,
            self.latent_cloglog_state.as_ref(),
            self.mixture_link_state.as_ref(),
            self.sas_link_state.as_ref(),
            &self.y,
            &self.weights,
            &state.eta,
        )
        .map_err(|e| e.to_string())?;
        let adaptive_penalty = self.adaptive_penalty_value_only(beta)?;
        let (fixed_quadratic, _) = self.fixed_quadratic_terms(beta);
        Ok(obs.log_likelihood - adaptive_penalty - fixed_quadratic)
    }
    /// This family optimizes the strict pseudo-Laplace outer objective.
    fn exact_newton_outerobjective(&self) -> ExactNewtonOuterObjective {
        ExactNewtonOuterObjective::StrictPseudoLaplace
    }
    /// A semidefinite joint Hessian is acceptable for this family.
    fn exact_newton_allows_semidefinitehessian(&self) -> bool {
        true
    }
    /// Full joint objective Hessian at the current coefficients.
    fn exact_newton_joint_hessian(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<Option<Array2<f64>>, String> {
        let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
        let eval = self.exact_evaluation(beta)?;
        Ok(Some(eval.totalobjectivehessian(&self.design)?))
    }
    /// Per-block entry point: with a single block this simply delegates to
    /// the joint directional derivative after validating `block_idx`.
    fn exact_newton_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        d_beta: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
        self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
    }
    /// Directional derivative of the joint Hessian along `d_beta_flat`.
    fn exact_newton_joint_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
        // The direction must match the coefficient dimension exactly.
        if d_beta_flat.len() != beta.len() {
            return Err(format!(
                "spatial adaptive exact family direction length mismatch: got {}, expected {}",
                d_beta_flat.len(),
                beta.len()
            ));
        }
        let eval = self.exact_evaluation(beta)?;
        Ok(Some(
            self.exacthessian_directional_derivative_from_evaluation(beta, &eval, d_beta_flat)?,
        ))
    }
    /// Exposes the family's linear inequality constraints for block 0, if any.
    fn block_linear_constraints(
        &self,
        _: &[ParameterBlockState],
        block_idx: usize,
        _: &ParameterBlockSpec,
    ) -> Result<Option<LinearInequalityConstraints>, String> {
        expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
        Ok(self.linear_constraints.clone())
    }
    /// First-order hyperparameter (psi) terms for one hyperparameter index:
    /// explicit objective derivative plus mixed score and Hessian parts.
    fn exact_newton_joint_psi_terms(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
        psi_index: usize,
    ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
        if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
            return Err(format!(
                "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
                block_states.len(),
                specs.len(),
                derivative_blocks.len()
            ));
        }
        // Bounds checks only: the payload itself is not used, but the index
        // must be valid in both the derivative payload and the hyper specs.
        derivative_blocks[0]
            .get(psi_index)
            .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
        let hyper = self
            .hyperspecs
            .get(psi_index)
            .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
        let beta = &block_states[0].beta;
        let eval = self.exact_evaluation(beta)?;
        let (direct, beta_mixed, betahessian_explicit) =
            self.adaptive_hyper_parts(&eval, *hyper)?;
        Ok(Some(ExactNewtonJointPsiTerms {
            objective_psi: direct,
            score_psi: beta_mixed,
            hessian_psi: betahessian_explicit,
            hessian_psi_operator: None,
        }))
    }
    /// Explicit second-order hyperparameter terms for a pair of psi indices.
    fn exact_newton_joint_psisecond_order_terms(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
        psi_i: usize,
        psi_j: usize,
    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String> {
        if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
            return Err(format!(
                "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
                block_states.len(),
                specs.len(),
                derivative_blocks.len()
            ));
        }
        // Validate both psi indices against payload and hyper specs.
        derivative_blocks[0]
            .get(psi_i)
            .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
        derivative_blocks[0]
            .get(psi_j)
            .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
        let hyper_i = self
            .hyperspecs
            .get(psi_i)
            .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
        let hyper_j = self
            .hyperspecs
            .get(psi_j)
            .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
        let beta = &block_states[0].beta;
        let eval = self.exact_evaluation(beta)?;
        let (objective_psi_psi, score_psi_psi, hessian_psi_psi) =
            self.adaptive_explicit_second_order_parts(&eval, *hyper_i, *hyper_j)?;
        Ok(Some(
            crate::custom_family::ExactNewtonJointPsiSecondOrderTerms {
                objective_psi_psi,
                score_psi_psi,
                hessian_psi_psi,
                hessian_psi_psi_operator: None,
            },
        ))
    }
    /// Directional derivative of the psi-Hessian for one hyperparameter:
    /// local log-lambda kinds use the block drift, shared log-epsilon kinds
    /// aggregate the drift over all blocks.
    fn exact_newton_joint_psihessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
        psi_index: usize,
        direction: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
            return Err(format!(
                "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
                block_states.len(),
                specs.len(),
                derivative_blocks.len()
            ));
        }
        let beta = &block_states[0].beta;
        if direction.len() != beta.len() {
            return Err(format!(
                "spatial adaptive exact family direction length mismatch: got {}, expected {}",
                direction.len(),
                beta.len()
            ));
        }
        derivative_blocks[0]
            .get(psi_index)
            .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
        let hyper = self
            .hyperspecs
            .get(psi_index)
            .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
        let eval = self.exact_evaluation(beta)?;
        let drift = match hyper.kind {
            SpatialAdaptiveHyperKind::LogLambdaMagnitude
            | SpatialAdaptiveHyperKind::LogLambdaGradient
            | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_drift(
                &eval,
                hyper.cache_index,
                hyper.kind.component_index(),
                direction,
            )?,
            SpatialAdaptiveHyperKind::LogEpsilonMagnitude
            | SpatialAdaptiveHyperKind::LogEpsilonGradient
            | SpatialAdaptiveHyperKind::LogEpsilonCurvature => self
                .adaptive_shared_log_epsilon_drift(
                    &eval,
                    hyper.kind.component_index(),
                    direction,
                )?,
        };
        Ok(Some(drift))
    }
}
/// Validates that exactly one block state was supplied and returns it;
/// any other count yields a descriptive error mentioning `family_name`.
fn expect_single_block_state<'a>(
    block_states: &'a [ParameterBlockState],
    family_name: &str,
) -> Result<&'a ParameterBlockState, String> {
    match block_states {
        [single] => Ok(single),
        other => Err(format!(
            "{family_name} expects 1 block, got {}",
            other.len()
        )),
    }
}
/// Ensures the caller addressed block index zero; any other index yields a
/// descriptive error that embeds `family_name` and the optional `context`.
fn expect_block_idx_zero(block_idx: usize, family_name: &str, context: &str) -> Result<(), String> {
    match block_idx {
        0 => Ok(()),
        other => Err(format!(
            "{family_name} expects block_idx 0{context}, got {other}"
        )),
    }
}
/// Helper computations for the bounded-linear family: mapping latent
/// coefficients through per-term bounded transforms and assembling exact
/// Newton quantities in latent coordinates.
impl BoundedLinearFamily {
    /// Per-coefficient transform data at `latent_beta`.
    ///
    /// Returns five length-p vectors:
    /// - `beta_user`: latent coefficients with each bounded term mapped to
    ///   its user-space value;
    /// - `jac_diag`: db/dθ of the bounded transform (1 for unbounded cols);
    /// - `second_diag`: d²b/dθ² (0 for unbounded cols);
    /// - `third_diag`: d³b/dθ³ (0 for unbounded cols);
    /// - `priorthird`: derivative of each term's prior negative Hessian
    ///   (0 for unbounded cols).
    fn bounded_term_derivative_data(
        &self,
        latent_beta: &Array1<f64>,
    ) -> (
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
    ) {
        let p = latent_beta.len();
        let mut beta_user = latent_beta.clone();
        let mut jac_diag = Array1::<f64>::ones(p);
        let mut second_diag = Array1::<f64>::zeros(p);
        let mut third_diag = Array1::<f64>::zeros(p);
        let mut priorthird = Array1::<f64>::zeros(p);
        for term in &self.bounded_terms {
            // Transform derivatives for this term's (min, max) bounds.
            let (beta, _, db_dtheta, d2b_dtheta2, d3b_dtheta3) =
                bounded_latent_derivatives(latent_beta[term.col_idx], term.min, term.max);
            beta_user[term.col_idx] = beta;
            jac_diag[term.col_idx] = db_dtheta;
            second_diag[term.col_idx] = d2b_dtheta2;
            third_diag[term.col_idx] = d3b_dtheta3;
            // Third-order prior information used by directional derivatives.
            let (_, _, _, prior_neghess_derivative) =
                bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
            priorthird[term.col_idx] = prior_neghess_derivative;
        }
        (beta_user, jac_diag, second_diag, third_diag, priorthird)
    }
    /// Convenience accessor: only the user-space beta and the transform
    /// Jacobian diagonal.
    fn user_beta_and_jacobian(&self, latent_beta: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
        let (beta_user, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
        (beta_user, jac_diag)
    }
    /// Offset that absorbs the bounded terms' contribution to eta: the base
    /// offset plus each bounded column scaled by its user-space value.
    fn nonlinear_offset_from_latent(&self, latent_beta: &Array1<f64>) -> Array1<f64> {
        let mut offset = self.offset.clone();
        for term in &self.bounded_terms {
            let (beta, _, _) =
                bounded_latent_to_user(latent_beta[term.col_idx], term.min, term.max);
            offset += &(self.design.column(term.col_idx).to_owned() * beta);
        }
        offset
    }
    /// Design matrix in latent coordinates: each bounded column of the user
    /// design is rescaled by the transform Jacobian (chain rule).
    fn effective_design_for_latent(&self, jac_diag: &Array1<f64>) -> Array2<f64> {
        let mut x_eff = self.design.clone();
        for term in &self.bounded_terms {
            let scaled = self.design.column(term.col_idx).to_owned() * jac_diag[term.col_idx];
            x_eff.column_mut(term.col_idx).assign(&scaled);
        }
        x_eff
    }
    /// Exact negative Hessian and gradient in latent coordinates.
    ///
    /// Returns `(obs, hessian, gradient, prior_loglik, second_diag,
    /// third_diag, priorthird)`, where the Hessian includes the chain-rule
    /// curvature correction for bounded terms and the prior negative
    /// Hessian; the gradient includes the prior gradient.
    fn exacthessian_andgradient(
        &self,
        latent_beta: &Array1<f64>,
    ) -> Result<
        (
            StandardFamilyObservationState,
            Array2<f64>,
            Array1<f64>,
            f64,
            Array1<f64>,
            Array1<f64>,
            Array1<f64>,
        ),
        String,
    > {
        let (_, jac_diag, second_diag, third_diag, priorthird) =
            self.bounded_term_derivative_data(latent_beta);
        let x_eff = self.effective_design_for_latent(&jac_diag);
        // eta = (design with bounded columns zeroed) * latent beta, plus the
        // offset carrying the bounded terms' user-space contribution.
        let eta =
            self.designzeroed.dot(latent_beta) + self.nonlinear_offset_from_latent(latent_beta);
        let obs = evaluate_standard_familyobservations(
            self.family,
            self.latent_cloglog_state.as_ref(),
            self.mixture_link_state.as_ref(),
            self.sas_link_state.as_ref(),
            &self.y,
            &self.weights,
            &eta,
        )
        .map_err(|e| e.to_string())?;
        // Accumulate the bounded terms' prior log-likelihood, gradient, and
        // diagonal negative Hessian.
        let mut priorgrad = Array1::<f64>::zeros(latent_beta.len());
        let mut prior_neghess = Array2::<f64>::zeros((latent_beta.len(), latent_beta.len()));
        let mut prior_loglik = 0.0;
        for term in &self.bounded_terms {
            let (logp, grad, neghess, _) =
                bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
            prior_loglik += logp;
            priorgrad[term.col_idx] += grad;
            prior_neghess[[term.col_idx, term.col_idx]] += neghess;
        }
        let mut hessian = xt_diag_x_dense(x_eff.view(), obs.neghessian_eta.view())?;
        let mut gradient = fast_atv(&x_eff, &obs.score);
        // Chain-rule curvature correction: subtract score·d²b/dθ² on each
        // bounded diagonal.
        for term in &self.bounded_terms {
            let score_beta = self.design.column(term.col_idx).dot(&obs.score);
            hessian[[term.col_idx, term.col_idx]] -= score_beta * second_diag[term.col_idx];
        }
        hessian += &prior_neghess;
        gradient += &priorgrad;
        Ok((
            obs,
            hessian,
            gradient,
            prior_loglik,
            second_diag,
            third_diag,
            priorthird,
        ))
    }
    /// Trimmed view of `exacthessian_andgradient` without the higher-order
    /// derivative diagnostics.
    fn evaluation_from_latent(
        &self,
        latent_beta: &Array1<f64>,
    ) -> Result<
        (
            StandardFamilyObservationState,
            Array2<f64>,
            Array1<f64>,
            f64,
        ),
        String,
    > {
        let (obs, hessian, gradient, prior_loglik, _, _, _) =
            self.exacthessian_andgradient(latent_beta)?;
        Ok((obs, hessian, gradient, prior_loglik))
    }
}
/// Exact-Newton `CustomFamily` implementation for the bounded-linear family.
/// Coefficients live in latent (unbounded) space; the bounded transforms and
/// their priors are folded into the gradient/Hessian.
impl CustomFamily for BoundedLinearFamily {
    /// Log-likelihood (including bounded-term priors) plus the exact-Newton
    /// working set at the latent coefficients.
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
        let (obs, hessian, gradient, prior_loglik) = self.evaluation_from_latent(latent_beta)?;
        Ok(FamilyEvaluation {
            log_likelihood: obs.log_likelihood + prior_loglik,
            blockworking_sets: vec![BlockWorkingSet::ExactNewton {
                gradient,
                hessian: SymmetricMatrix::Dense(hessian),
            }],
        })
    }
    /// Full joint Hessian at the latent coefficients.
    fn exact_newton_joint_hessian(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<Option<Array2<f64>>, String> {
        let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
        let (_, hessian, _, _) = self.evaluation_from_latent(latent_beta)?;
        Ok(Some(hessian))
    }
    /// Per-block entry point: single block, so delegates to the joint
    /// directional derivative after validating `block_idx`.
    fn exact_newton_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        d_beta: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        expect_block_idx_zero(block_idx, "bounded linear family", "")?;
        self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
    }
    /// Directional derivative of the joint Hessian along `d_beta_flat`,
    /// including the drift of the effective design itself (the Jacobian of
    /// the bounded transform changes with beta).
    fn exact_newton_joint_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
        if d_beta_flat.len() != latent_beta.len() {
            return Err(format!(
                "bounded linear family directional derivative length mismatch: got {}, expected {}",
                d_beta_flat.len(),
                latent_beta.len()
            ));
        }
        let (obs, _, _, _, second_diag, third_diag, priorthird) =
            self.exacthessian_andgradient(latent_beta)?;
        let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
        let x_eff = self.effective_design_for_latent(&jac_diag);
        // Drift of eta along the direction, through the effective design.
        let deta = x_eff.dot(d_beta_flat);
        let d_neghess_eta = &obs.neghessian_eta_derivative * &deta;
        // Drift of the effective design: bounded columns scale by
        // d²b/dθ² · d_beta on that column; other columns are constant.
        let mut dx_eff = Array2::<f64>::zeros(x_eff.raw_dim());
        for term in &self.bounded_terms {
            let scale = second_diag[term.col_idx] * d_beta_flat[term.col_idx];
            if scale != 0.0 {
                let col = self.design.column(term.col_idx).to_owned() * scale;
                dx_eff.column_mut(term.col_idx).assign(&col);
            }
        }
        // Term 1: Xᵀ diag(d neghess_eta) X with X fixed.
        let mut dhessian = xt_diag_x_dense(x_eff.view(), d_neghess_eta.view())?;
        // Term 2: cross terms from the drifting design,
        // dXᵀ W X + Xᵀ W dX, accumulated densely.
        let mut wxdx = Array2::<f64>::zeros((x_eff.ncols(), x_eff.ncols()));
        for i in 0..x_eff.nrows() {
            let wi = obs.neghessian_eta[i];
            if wi == 0.0 {
                continue;
            }
            for a in 0..x_eff.ncols() {
                let xa = x_eff[[i, a]];
                for b in 0..x_eff.ncols() {
                    wxdx[[a, b]] += wi * (dx_eff[[i, a]] * x_eff[[i, b]] + xa * dx_eff[[i, b]]);
                }
            }
        }
        dhessian += &wxdx;
        // Term 3: drift of the chain-rule diagonal correction and of the
        // prior negative Hessian, per bounded term.
        let d_score = -&obs.neghessian_eta * &deta;
        for term in &self.bounded_terms {
            let score_beta = self.design.column(term.col_idx).dot(&obs.score);
            let d_score_beta = self.design.column(term.col_idx).dot(&d_score);
            dhessian[[term.col_idx, term.col_idx]] -= d_score_beta * second_diag[term.col_idx]
                + score_beta * third_diag[term.col_idx] * d_beta_flat[term.col_idx];
            dhessian[[term.col_idx, term.col_idx]] +=
                priorthird[term.col_idx] * d_beta_flat[term.col_idx];
        }
        Ok(Some(dhessian))
    }
    /// Reports the block geometry: the zeroed design plus the state-dependent
    /// offset carrying the bounded terms' contribution. With no state yet,
    /// returns the base offset.
    fn block_geometry(
        &self,
        block_states: &[ParameterBlockState],
        spec: &ParameterBlockSpec,
    ) -> Result<(DesignMatrix, Array1<f64>), String> {
        if block_states.is_empty() {
            return Ok((
                DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                    self.designzeroed.clone(),
                )),
                self.offset.clone(),
            ));
        }
        let offset = self.nonlinear_offset_from_latent(
            &expect_single_block_state(block_states, "bounded linear family")?.beta,
        );
        // Sanity check: the spec's design must match our column count.
        let x = if spec.design.ncols() == self.designzeroed.ncols() {
            self.designzeroed.clone()
        } else {
            return Err("bounded linear family design column mismatch".to_string());
        };
        Ok((
            DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x)),
            offset,
        ))
    }
    /// The offset depends on the current latent beta, so the geometry is
    /// state-dependent.
    fn block_geometry_is_dynamic(&self) -> bool {
        true
    }
    /// Directional derivative of the block geometry: the design is constant
    /// (zeroed columns), only the offset drifts through the bounded terms'
    /// Jacobian.
    fn block_geometry_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        spec: &ParameterBlockSpec,
        d_beta: &Array1<f64>,
    ) -> Result<Option<BlockGeometryDirectionalDerivative>, String> {
        expect_block_idx_zero(
            block_idx,
            "bounded linear family",
            " for geometry derivative",
        )?;
        expect_single_block_state(block_states, "bounded linear family")?;
        if d_beta.len() != spec.design.ncols() {
            return Err(format!(
                "bounded linear family geometry derivative direction mismatch: got {}, expected {}",
                d_beta.len(),
                spec.design.ncols()
            ));
        }
        let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(&block_states[0].beta);
        let mut d_offset = Array1::<f64>::zeros(self.offset.len());
        // Fast exit when no bounded term actually moves along the direction.
        let has_drift = self
            .bounded_terms
            .iter()
            .any(|term| jac_diag[term.col_idx] != 0.0 && d_beta[term.col_idx] != 0.0);
        if !has_drift {
            return Ok(Some(BlockGeometryDirectionalDerivative {
                d_design: None,
                d_offset,
            }));
        }
        for term in &self.bounded_terms {
            let col = term.col_idx;
            let drift = jac_diag[col] * d_beta[col];
            if drift != 0.0 {
                d_offset += &(self.design.column(col).to_owned() * drift);
            }
        }
        Ok(Some(BlockGeometryDirectionalDerivative {
            d_design: None,
            d_offset,
        }))
    }
}
/// Picks a row-chunk size for the streaming weighted-Gram path: roughly a
/// 2 MiB f64 working set per chunk, clamped to [512, 2048] rows.
#[inline]
fn dense_diag_gram_chunkrows(p: usize) -> usize {
    const TARGET_BYTES: usize = 2 * 1024 * 1024;
    const ROW_BOUNDS: (usize, usize) = (512, 2048);
    // Guard against p == 0 so the division below is always defined.
    let row_bytes = p.max(1) * std::mem::size_of::<f64>();
    let ideal_rows = TARGET_BYTES / row_bytes;
    ideal_rows.clamp(ROW_BOUNDS.0, ROW_BOUNDS.1)
}
/// Computes the weighted Gram matrix `Xᵀ · diag(w) · X` for a dense design.
///
/// Small problems scale every row in parallel and take a single product;
/// larger ones stream over fixed-size row blocks, reusing one scratch
/// buffer, to bound peak memory.
fn xt_diag_x_dense(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
    if x.nrows() != w.len() {
        return Err("xt_diag_x_dense row mismatch".to_string());
    }
    let (n, p) = x.dim();
    if n == 0 || p == 0 {
        return Ok(Array2::<f64>::zeros((p, p)));
    }
    // Below this size the whole weighted copy is materialized at once.
    const STREAMING_BYTES_THRESHOLD: usize = 8 * 1024 * 1024;
    let full_copy_bytes = n
        .checked_mul(p)
        .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
        .unwrap_or(usize::MAX);
    if full_copy_bytes <= STREAMING_BYTES_THRESHOLD {
        let mut scaled = x.to_owned();
        ndarray::Zip::from(scaled.rows_mut())
            .and(w)
            .par_for_each(|mut row, wi| row *= *wi);
        return Ok(fast_atb(&x, &scaled));
    }
    // Streaming path: blocks of rows through one reusable scratch buffer.
    let block = dense_diag_gram_chunkrows(p).min(n);
    let mut scratch = Array2::<f64>::zeros((block, p));
    let mut gram = Array2::<f64>::zeros((p, p));
    for start in (0..n).step_by(block) {
        let rows = block.min(n - start);
        let x_block = x.slice(s![start..start + rows, ..]);
        {
            let mut dst = scratch.slice_mut(s![0..rows, ..]);
            for r in 0..rows {
                let wi = w[start + r];
                if wi == 0.0 {
                    // Scratch is reused across blocks; stale values must go.
                    dst.row_mut(r).fill(0.0);
                } else {
                    for c in 0..p {
                        dst[[r, c]] = x_block[[r, c]] * wi;
                    }
                }
            }
        }
        gram += &fast_atb(&x_block, &scratch.slice(s![0..rows, ..]));
    }
    Ok(gram)
}
/// Trace of the matrix product `A·B` for equal-size square matrices,
/// computed without forming the product: `tr(AB) = Σ_i Σ_j A[i,j]·B[j,i]`.
fn trace_of_dense_product(a: &Array2<f64>, b: &Array2<f64>) -> Result<f64, String> {
    let both_square = a.nrows() == a.ncols() && b.nrows() == b.ncols();
    if !both_square || a.nrows() != b.nrows() {
        return Err("trace_of_dense_product dimension mismatch".to_string());
    }
    // indexed_iter walks A in logical row-major order, matching the
    // summation order of the equivalent nested loops.
    let total: f64 = a
        .indexed_iter()
        .map(|((i, j), &aij)| aij * b[[j, i]])
        .sum();
    Ok(total)
}
/// Exact per-block and total effective degrees of freedom for the bounded
/// fit path, using the latent-scale posterior covariance.
///
/// For penalty k with smoothing parameter λ_k the block EDF is
/// `rank(S_k) − λ_k·tr(Cov·S_k)`, clamped to `[0, rank(S_k)]`; the total is
/// `p − Σ_k λ_k·tr(Cov·S_k)`, clamped below by the nullity of the summed
/// penalty `S_λ` (unpenalized directions always count fully).
fn exact_bounded_edf(
    penalties: &[PenaltySpec],
    lambdas: &Array1<f64>,
    latent_cov: &Array2<f64>,
) -> Result<(Vec<f64>, f64), EstimationError> {
    if penalties.len() != lambdas.len() {
        return Err(EstimationError::InvalidInput(format!(
            "bounded EDF penalty/lambda mismatch: {} penalties vs {} lambdas",
            penalties.len(),
            lambdas.len()
        )));
    }
    if latent_cov.nrows() != latent_cov.ncols() {
        return Err(EstimationError::InvalidInput(
            "bounded EDF covariance must be square".to_string(),
        ));
    }
    let p = latent_cov.nrows();
    let mut s_lambda = Array2::<f64>::zeros((p, p));
    let mut edf_by_block = Vec::with_capacity(penalties.len());
    let mut trace_sum = 0.0;
    for (spec_k, &lambda_k) in penalties.iter().zip(lambdas.iter()) {
        // Accumulate λ_k·S_k into S_λ, then compute this penalty's rank and
        // λ_k·tr(Cov·S_k); blockwise penalties touch only their column range.
        let (rank_k, trace_k) = match spec_k {
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                s_lambda
                    .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
                    .scaled_add(lambda_k, local);
                let nullity = estimate_penalty_nullity(local).map_err(|e| {
                    EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
                })?;
                let cov_block =
                    latent_cov.slice(ndarray::s![col_range.clone(), col_range.clone()]);
                let trace = lambda_k
                    * trace_of_dense_product(&cov_block.to_owned(), local)
                        .map_err(EstimationError::InvalidInput)?;
                (local.nrows().saturating_sub(nullity), trace)
            }
            PenaltySpec::Dense(m) => {
                s_lambda.scaled_add(lambda_k, m);
                let nullity = estimate_penalty_nullity(m).map_err(|e| {
                    EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
                })?;
                let trace = lambda_k
                    * trace_of_dense_product(latent_cov, m)
                        .map_err(EstimationError::InvalidInput)?;
                (p.saturating_sub(nullity), trace)
            }
        };
        trace_sum += trace_k;
        let rank_k = rank_k as f64;
        edf_by_block.push((rank_k - trace_k).clamp(0.0, rank_k));
    }
    let nullity_total = estimate_penalty_nullity(&s_lambda)
        .map_err(|e| EstimationError::InvalidInput(format!("bounded EDF nullity failed: {e}")))?
        as f64;
    let edf_total = (p as f64 - trace_sum).clamp(nullity_total, p as f64);
    Ok((edf_by_block, edf_total))
}
/// Fit a term collection that contains `bounded()` linear terms.
///
/// Bounded coefficients live on a latent (logit) scale: their design columns
/// are zeroed in the block design handed to the custom-family solver, and
/// their effect re-enters through a coefficient-dependent offset managed by
/// `BoundedLinearFamily`. After the latent fit, coefficients and covariance
/// are mapped back to the user scale via the bounded-transform Jacobian and
/// the conditioning back-transform.
fn fit_bounded_term_collection_with_design(
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    heuristic_lambdas: Option<&[f64]>,
    family: LikelihoodFamily,
    options: &FitOptions,
) -> Result<FittedTermCollection, EstimationError> {
    // Columns eligible for fit conditioning: plain (non-double-penalty)
    // linear-term columns, addressed right after the intercept block.
    let conditioning_cols: Vec<usize> = spec
        .linear_terms
        .iter()
        .enumerate()
        .filter_map(|(j, linear)| {
            (!linear.double_penalty).then_some(design.intercept_range.end + j)
        })
        .collect();
    let conditioning = LinearFitConditioning::from_columns(&design, &conditioning_cols);
    let dense_design = design.design.to_dense_cow();
    let fit_design = conditioning.apply_to_design(&dense_design);
    let fit_penalties = conditioning
        .transform_blockwise_penalties_to_internal(&design.penalties, design.design.ncols());
    if design.linear_constraints.is_some() {
        return Err(EstimationError::InvalidInput(
            "bounded() terms are not yet compatible with explicit linear constraints".to_string(),
        ));
    }
    // Collect metadata for each bounded linear term, translating the
    // user-scale bounds into the conditioned internal coordinates.
    let mut bounded_terms = Vec::<BoundedLinearTermMeta>::new();
    for (j, term) in spec.linear_terms.iter().enumerate() {
        // bounded + double_penalty together is rejected outright.
        if term.double_penalty
            && matches!(
                term.coefficient_geometry,
                LinearCoefficientGeometry::Bounded { .. }
            )
        {
            return Err(EstimationError::InvalidInput(format!(
                "bounded linear term '{}' cannot also use double_penalty",
                term.name
            )));
        }
        if let LinearCoefficientGeometry::Bounded { min, max, prior } =
            term.coefficient_geometry.clone()
        {
            let col_idx = design.intercept_range.end + j;
            let (min_internal, max_internal) = conditioning.internal_bounds_for(col_idx, min, max);
            bounded_terms.push(BoundedLinearTermMeta {
                col_idx,
                min: min_internal,
                max: max_internal,
                prior,
            });
        }
    }
    if bounded_terms.is_empty() {
        return Err(EstimationError::InvalidInput(
            "internal bounded fit path called with no bounded terms".to_string(),
        ));
    }
    // Zero the bounded columns in the solver design (their effect flows
    // through the offset instead) and start each latent coefficient at the
    // bound midpoint, logit(0.5).
    let mut designzeroed = fit_design.clone();
    let mut initial_beta = Array1::<f64>::zeros(fit_design.ncols());
    for term in &bounded_terms {
        designzeroed.column_mut(term.col_idx).fill(0.0);
        initial_beta[term.col_idx] = bounded_logit(0.5);
    }
    let initial_log_lambdas = heuristic_lambdas
        .map(|vals| Array1::from_vec(vals.to_vec()))
        .unwrap_or_else(|| Array1::zeros(fit_penalties.len()));
    if initial_log_lambdas.len() != fit_penalties.len() {
        return Err(EstimationError::InvalidInput(format!(
            "heuristic lambda length mismatch for bounded model: got {}, expected {}",
            initial_log_lambdas.len(),
            fit_penalties.len()
        )));
    }
    // Adapter implementing the custom-family callbacks for this bounded fit.
    let family_adapter = BoundedLinearFamily {
        family,
        latent_cloglog_state: options.latent_cloglog,
        mixture_link_state: options
            .mixture_link
            .clone()
            .as_ref()
            .map(state_fromspec)
            .transpose()
            .map_err(EstimationError::InvalidInput)?,
        sas_link_state: options
            .sas_link
            .map(|spec| {
                if matches!(family, LikelihoodFamily::BinomialBetaLogistic) {
                    state_from_beta_logisticspec(spec)
                } else {
                    state_from_sasspec(spec)
                }
            })
            .transpose()
            .map_err(EstimationError::InvalidInput)?,
        y: y.to_owned(),
        weights: weights.to_owned(),
        design: fit_design.clone(),
        designzeroed: designzeroed.clone(),
        offset: offset.to_owned(),
        bounded_terms: bounded_terms.clone(),
    };
    // A single "eta" parameter block carries the zeroed design and the
    // conditioned penalties.
    let blockspec = ParameterBlockSpec {
        name: "eta".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(designzeroed)),
        offset: offset.to_owned(),
        penalties: fit_penalties
            .iter()
            .map(|ps| match ps {
                PenaltySpec::Block {
                    local, col_range, ..
                } => PenaltyMatrix::Blockwise {
                    local: local.clone(),
                    col_range: col_range.clone(),
                    total_dim: design.design.ncols(),
                },
                PenaltySpec::Dense(m) => PenaltyMatrix::Dense(m.clone()),
            })
            .collect(),
        nullspace_dims: design.nullspace_dims.clone(),
        initial_log_lambdas,
        initial_beta: Some(initial_beta),
    };
    let fit = fit_custom_family(
        &family_adapter,
        &[blockspec],
        &BlockwiseFitOptions {
            inner_max_cycles: options.max_iter,
            inner_tol: options.tol,
            outer_max_iter: options.max_iter,
            outer_tol: options.tol,
            ..BlockwiseFitOptions::default()
        },
    )
    .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
    // Map latent results back to user scale: beta through the bounded
    // transform, covariance via the delta method (diagonal Jacobian applied
    // on both sides), then undo the conditioning transform.
    let latent_beta = fit.block_states[0].beta.clone();
    let (beta_user_internal, jac_diag) = family_adapter.user_beta_and_jacobian(&latent_beta);
    let beta_user = conditioning.backtransform_beta(&beta_user_internal);
    let latent_cov = fit.covariance_conditional.clone();
    let beta_covariance = latent_cov.as_ref().map(|cov| {
        let mut out = cov.clone();
        let jac_col = jac_diag.view().insert_axis(ndarray::Axis(1));
        let jacrow = jac_diag.view().insert_axis(ndarray::Axis(0));
        out *= &(jac_col.to_owned() * jacrow);
        conditioning.backtransform_covariance(&out)
    });
    // Clamp tiny negative diagonal entries to zero before the sqrt.
    let beta_standard_errors = beta_covariance
        .as_ref()
        .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));
    let (eta_state, h_data, _, _) = family_adapter
        .evaluation_from_latent(&latent_beta)
        .map_err(EstimationError::InvalidInput)?;
    // Assemble S_lambda = Σ_k λ_k S_k in internal coordinates, form the
    // penalized Hessian H + S_λ, and map it back to the original space.
    let p_fit = fit_design.ncols();
    let mut s_lambda_internal = Array2::<f64>::zeros((p_fit, p_fit));
    for (k, penalty) in fit_penalties.iter().enumerate() {
        match penalty {
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                s_lambda_internal
                    .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
                    .scaled_add(fit.lambdas[k], local);
            }
            PenaltySpec::Dense(m) => {
                s_lambda_internal.scaled_add(fit.lambdas[k], m);
            }
        }
    }
    let mut penalized_hessian = h_data.clone();
    penalized_hessian += &s_lambda_internal;
    let penalized_hessian =
        conditioning.transform_penalized_hessian_to_original(&penalized_hessian);
    // Penalty term is reported on the original (user) coordinates.
    let s_lambda_original = weighted_blockwise_penalty_sum(
        &design.penalties,
        fit.lambdas.as_slice().unwrap(),
        design.design.ncols(),
    );
    let penalty_term = beta_user.dot(&s_lambda_original.dot(&beta_user));
    // Gaussian reports the weighted RSS; every other family uses -2·logL.
    let deviance = match family {
        LikelihoodFamily::GaussianIdentity => y
            .iter()
            .zip(eta_state.mu.iter())
            .zip(weights.iter())
            .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
            .sum(),
        _ => -2.0 * eta_state.log_likelihood,
    };
    let (edf_by_block, edf_total) = if let Some(cov) = latent_cov.as_ref() {
        exact_bounded_edf(&fit_penalties, &fit.lambdas, cov)?
    } else {
        (vec![0.0; fit_penalties.len()], 0.0)
    };
    // IRLS-style working response eta + score/weight, with the Fisher
    // weights floored at 1e-12 to avoid division blow-ups.
    let geometry = Some(crate::estimate::FitGeometry {
        penalized_hessian: penalized_hessian.clone(),
        working_weights: eta_state.fisherweight.clone(),
        working_response: {
            let mut working_response = eta_state.eta.clone();
            for i in 0..working_response.len() {
                let wi = eta_state.fisherweight[i].max(1e-12);
                working_response[i] += eta_state.score[i] / wi;
            }
            working_response
        },
    });
    let max_abs_eta = eta_state
        .eta
        .iter()
        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    Ok(FittedTermCollection {
        fit: {
            // Guard against lambda == 0 before taking logs.
            let log_lambdas = fit.lambdas.mapv(|v| v.max(1e-300).ln());
            let inf = FitInference {
                edf_by_block,
                edf_total,
                smoothing_correction: None,
                penalized_hessian: penalized_hessian.clone(),
                working_weights: eta_state.fisherweight.clone(),
                working_response: {
                    let mut working_response = eta_state.eta.clone();
                    for i in 0..working_response.len() {
                        let wi = eta_state.fisherweight[i].max(1e-12);
                        working_response[i] += eta_state.score[i] / wi;
                    }
                    working_response
                },
                reparam_qs: None,
                beta_covariance,
                beta_standard_errors,
                beta_covariance_corrected: None,
                beta_standard_errors_corrected: None,
                bias_correction_beta: None,
            };
            let covariance_conditional = inf.beta_covariance.clone();
            let pirls_status_val = if fit.outer_converged {
                crate::pirls::PirlsStatus::Converged
            } else {
                crate::pirls::PirlsStatus::StalledAtValidMinimum
            };
            UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
                blocks: vec![crate::estimate::FittedBlock {
                    beta: beta_user.clone(),
                    role: crate::estimate::BlockRole::Mean,
                    edf: edf_total,
                    lambdas: fit.lambdas.clone(),
                }],
                log_lambdas,
                lambdas: fit.lambdas,
                likelihood_family: Some(family),
                likelihood_scale: family.default_scale_metadata(),
                log_likelihood_normalization:
                    crate::types::LogLikelihoodNormalization::UserProvided,
                log_likelihood: eta_state.log_likelihood,
                deviance,
                reml_score: fit.penalized_objective,
                stable_penalty_term: penalty_term,
                penalized_objective: fit.penalized_objective,
                outer_iterations: fit.outer_iterations,
                outer_converged: fit.outer_converged,
                outer_gradient_norm: fit.outer_gradient_norm,
                standard_deviation: 1.0,
                covariance_conditional,
                covariance_corrected: None,
                inference: Some(inf),
                fitted_link: crate::estimate::FittedLinkState::Standard(None),
                geometry,
                block_states: Vec::new(),
                pirls_status: pirls_status_val,
                max_abs_eta,
                constraint_kkt: None,
                artifacts: crate::estimate::FitArtifacts {
                    pirls: None,
                    ..Default::default()
                },
                inner_cycles: 0,
            })?
        },
        design: design.clone(),
        adaptive_diagnostics: None,
    })
}
/// Post-fit feasibility check: verify every smooth term's local lower-bound
/// and linear-inequality constraints within a tolerance of 1e-7, collecting
/// all violations into one `ParameterConstraintViolation` error (with KKT
/// diagnostics appended when available).
fn enforce_term_constraint_feasibility(
    design: &TermCollectionDesign,
    fit: &UnifiedFitResult,
) -> Result<(), EstimationError> {
    let tol = 1e-7;
    // Smooth coefficients occupy the trailing columns of the global design.
    let smooth_start = design
        .design
        .ncols()
        .saturating_sub(design.smooth.total_smooth_cols());
    let mut violations: Vec<String> = Vec::new();
    for term in &design.smooth.terms {
        let gr = (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
        let beta_local = fit.beta.slice(s![gr.clone()]).to_owned();
        if let Some(lb) = term.lower_bounds_local.as_ref() {
            // Largest amount by which a finite lower bound exceeds its
            // coefficient; strict comparison keeps the first offending index
            // on ties.
            let (worst, worst_idx) = (0..lb.len().min(beta_local.len()))
                .filter(|&i| lb[i].is_finite())
                .map(|i| ((lb[i] - beta_local[i]).max(0.0), i))
                .fold((0.0_f64, 0usize), |acc, cand| {
                    if cand.0 > acc.0 { cand } else { acc }
                });
            if worst > tol {
                violations.push(format!(
                    "term='{}' kind=lower-bound maxviolation={:.3e} coeff_index={}",
                    term.name, worst, worst_idx
                ));
            }
        }
        if let Some(lin) = term.linear_constraints_local.as_ref() {
            // Inequality rows require A·β ≥ b; negative slack is a violation.
            let slack = lin.a.dot(&beta_local) - &lin.b;
            let (worst, worstrow) = slack
                .iter()
                .enumerate()
                .map(|(i, &v)| ((-v).max(0.0), i))
                .fold((0.0_f64, 0usize), |acc, cand| {
                    if cand.0 > acc.0 { cand } else { acc }
                });
            if worst > tol {
                violations.push(format!(
                    "term='{}' kind=linear-inequality maxviolation={:.3e} row={}",
                    term.name, worst, worstrow
                ));
            }
        }
    }
    if violations.is_empty() {
        return Ok(());
    }
    let mut msg = format!(
        "constraint violation after fit ({} violating term constraints): {}",
        violations.len(),
        violations.join(" | ")
    );
    if let Some(kkt) = fit.constraint_kkt.as_ref() {
        msg.push_str(&format!(
            "; KKT[primal={:.3e}, dual={:.3e}, comp={:.3e}, stat={:.3e}]",
            kkt.primal_feasibility, kkt.dual_feasibility, kkt.complementarity, kkt.stationarity
        ));
    }
    Err(EstimationError::ParameterConstraintViolation(msg))
}
/// Pick up to `target_size` row indices from `data`.
///
/// When the spec contains a spatial smooth term (thin-plate, Matérn, or
/// Duchon), rows are stratified across a coarse grid over the spatial
/// bounding box so the subsample covers the domain; otherwise a plain
/// seeded random subsample is taken. Deterministic for fixed inputs (all
/// randomness comes from a fixed-seed `StdRng`); returned indices are
/// sorted ascending.
fn stratified_spatial_subsample(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
    target_size: usize,
) -> Vec<usize> {
    use rand::SeedableRng;
    use rand::rngs::StdRng;
    use rand::seq::SliceRandom;
    let n = data.nrows();
    if n <= target_size {
        return (0..n).collect();
    }
    // Feature columns of the first spatial smooth term, if any.
    let spatial_cols: Option<Vec<usize>> =
        spec.smooth_terms.iter().find_map(|term| match &term.basis {
            SmoothBasisSpec::ThinPlate { feature_cols, .. }
            | SmoothBasisSpec::Matern { feature_cols, .. }
            | SmoothBasisSpec::Duchon { feature_cols, .. } => {
                if feature_cols.len() >= 1 {
                    Some(feature_cols.clone())
                } else {
                    None
                }
            }
            _ => None,
        });
    let mut rng = StdRng::seed_from_u64(42);
    let cols = match spatial_cols {
        Some(c) if !c.is_empty() => c,
        _ => {
            // No spatial term: uniform random subsample.
            let mut indices: Vec<usize> = (0..n).collect();
            indices.shuffle(&mut rng);
            indices.truncate(target_size);
            indices.sort_unstable();
            return indices;
        }
    };
    // Axis-aligned bounding box of the spatial coordinates.
    let d = cols.len();
    let mut mins = vec![f64::INFINITY; d];
    let mut maxs = vec![f64::NEG_INFINITY; d];
    for i in 0..n {
        for (ax, &col) in cols.iter().enumerate() {
            let v = data[[i, col]];
            if v < mins[ax] {
                mins[ax] = v;
            }
            if v > maxs[ax] {
                maxs[ax] = v;
            }
        }
    }
    // Aim for roughly 5 samples per grid cell; cells_per_axis is the d-th
    // root of the target cell count, at least 1.
    let total_cells_target = (target_size / 5).max(1);
    let cells_per_axis = ((total_cells_target as f64).powf(1.0 / d as f64)).ceil() as usize;
    let cells_per_axis = cells_per_axis.max(1);
    // Bucket every row into its grid cell.
    let mut cell_members: std::collections::HashMap<Vec<usize>, Vec<usize>> =
        std::collections::HashMap::new();
    for i in 0..n {
        let mut cell_key = Vec::with_capacity(d);
        for (ax, &col) in cols.iter().enumerate() {
            let range = maxs[ax] - mins[ax];
            let cell = if range <= 0.0 {
                0
            } else {
                let frac = (data[[i, col]] - mins[ax]) / range;
                (frac * cells_per_axis as f64).floor() as usize
            };
            // Clamp so frac == 1.0 lands in the last cell.
            cell_key.push(cell.min(cells_per_axis - 1));
        }
        cell_members.entry(cell_key).or_default().push(i);
    }
    // Distribute the budget over cells proportionally to cell population
    // (at least one row per visited cell). Cells are visited in sorted key
    // order so the result is deterministic for a given seed.
    let mut selected: Vec<usize> = Vec::with_capacity(target_size);
    let mut remaining_budget = target_size;
    let mut remaining_population = n;
    let mut cells: Vec<(Vec<usize>, Vec<usize>)> = cell_members.into_iter().collect();
    cells.sort_by(|a, b| a.0.cmp(&b.0));
    for (_, members) in &mut cells {
        if remaining_budget == 0 {
            break;
        }
        let alloc = ((members.len() as f64 / remaining_population as f64) * remaining_budget as f64)
            .round() as usize;
        let alloc = alloc.max(1).min(members.len()).min(remaining_budget);
        members.shuffle(&mut rng);
        selected.extend_from_slice(&members[..alloc]);
        remaining_budget = remaining_budget.saturating_sub(alloc);
        remaining_population = remaining_population.saturating_sub(members.len());
    }
    // Per-cell rounding can overshoot; trim uniformly at random if so.
    if selected.len() > target_size {
        selected.shuffle(&mut rng);
        selected.truncate(target_size);
    }
    selected.sort_unstable();
    selected
}
/// Gather the rows of `data` named by `indices`, in order, into a new matrix.
fn sampled_rows(data: ArrayView2<'_, f64>, indices: &[usize]) -> Array2<f64> {
    Array2::from_shape_fn((indices.len(), data.ncols()), |(r, c)| {
        data[[indices[r], c]]
    })
}
/// Return the user-supplied center matrix of a spatial smooth term when its
/// center strategy is `CenterStrategy::UserProvided`; `None` otherwise.
fn spatial_term_user_centers(term: &SmoothTermSpec) -> Option<ArrayView2<'_, f64>> {
    if let Some(CenterStrategy::UserProvided(centers)) = spatial_term_center_strategy(term) {
        Some(centers.view())
    } else {
        None
    }
}
/// Center per-axis log-scale contrasts, returning `None` unless `values` has
/// the expected multi-axis length (`expected_dim > 1`) and every entry is
/// finite.
fn finite_centered_axis_contrasts(values: &[f64], expected_dim: usize) -> Option<Vec<f64>> {
    let shape_ok = expected_dim > 1 && values.len() == expected_dim;
    let all_finite = values.iter().all(|value| value.is_finite());
    (shape_ok && all_finite).then(|| center_aniso_log_scales(values))
}
/// Blend anisotropy axis contrasts derived from the planned centers with
/// contrasts derived from the pilot data (when available) for a spatial term
/// with more than one feature axis.
///
/// Returns the centered, all-finite blended contrasts, or `None` when the
/// term is effectively one-dimensional or the center-derived contrasts are
/// unusable.
fn blended_pilot_axis_contrasts(
    pilot_data: ArrayView2<'_, f64>,
    term: &SmoothTermSpec,
    centers: ArrayView2<'_, f64>,
) -> Option<Vec<f64>> {
    let d = centers.ncols();
    if d <= 1 {
        return None;
    }
    let center_eta = initial_aniso_contrasts(centers);
    // Data-derived contrasts are best-effort: standardization or finiteness
    // failures simply drop the data half of the blend.
    let data_eta = standardized_spatial_term_data(pilot_data, term)
        .ok()
        .and_then(|x| finite_centered_axis_contrasts(&initial_aniso_contrasts(x.view()), d));
    // Fix: this argument read `¢er_eta` — HTML-entity mojibake of
    // `&center_eta` (`&cent` -> `¢`) — which is not a valid identifier.
    let center_eta = finite_centered_axis_contrasts(&center_eta, d)?;
    let blended = match data_eta {
        Some(data_eta) => center_eta
            .iter()
            .zip(data_eta.iter())
            .map(|(&from_centers, &from_data)| 0.5 * (from_centers + from_data))
            .collect::<Vec<_>>(),
        None => center_eta,
    };
    finite_centered_axis_contrasts(&blended, d)
}
/// Re-seed the spatial length-scale (ψ / log-κ) coordinates from pilot data
/// and clamp them into data-derived bounds, returning the updated spec.
///
/// The anisotropic coordinate constructors are used whenever any eligible
/// term carries per-axis log scales; otherwise the isotropic variants apply.
fn apply_pilot_spatial_psi_reseed(
    pilot_data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
    spatial_terms: &[usize],
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<TermCollectionSpec, EstimationError> {
    let dims_per_term = spatial_dims_per_term(spec, spatial_terms);
    let use_aniso = has_aniso_terms(spec, spatial_terms);
    // Start from the spec's current length scales, then re-seed from the
    // pilot sample.
    let seed = if use_aniso {
        SpatialLogKappaCoords::from_length_scales_aniso(spec, spatial_terms, kappa_options)
    } else {
        SpatialLogKappaCoords::from_length_scales(spec, spatial_terms, kappa_options)
    }
    .reseed_from_data(pilot_data, spec, spatial_terms, kappa_options);
    let (lower, upper) = if use_aniso {
        (
            SpatialLogKappaCoords::lower_bounds_aniso_from_data(
                pilot_data,
                spec,
                spatial_terms,
                &dims_per_term,
                kappa_options,
            ),
            SpatialLogKappaCoords::upper_bounds_aniso_from_data(
                pilot_data,
                spec,
                spatial_terms,
                &dims_per_term,
                kappa_options,
            ),
        )
    } else {
        (
            SpatialLogKappaCoords::lower_bounds_from_data(
                pilot_data,
                spec,
                spatial_terms,
                kappa_options,
            ),
            SpatialLogKappaCoords::upper_bounds_from_data(
                pilot_data,
                spec,
                spatial_terms,
                kappa_options,
            ),
        )
    };
    seed.clamp_to_bounds(&lower, &upper)
        .apply_tospec(spec, spatial_terms)
}
/// Initialize per-axis anisotropy (log-scale contrasts) for spatial terms
/// from a stratified pilot subsample, ahead of the full-data length-scale
/// optimization.
///
/// Runs up to two geometry passes: each replans joint spatial centers on the
/// pilot data, blends center- and data-derived axis contrasts into a working
/// copy of the spec, then reseeds the ψ/log-κ coordinates. Returns the count
/// of terms whose anisotropy was updated (counted on the first pass only);
/// `spec` is overwritten only when at least one term updated.
pub(crate) fn apply_spatial_anisotropy_pilot_initializer(
    data: ArrayView2<'_, f64>,
    spec: &mut TermCollectionSpec,
    spatial_terms: &[usize],
    target_size: usize,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> usize {
    // Only worthwhile when the full data is substantially larger than the
    // pilot, and only for anisotropic terms.
    if target_size == 0 || data.nrows() <= target_size.saturating_mul(2) || spatial_terms.is_empty()
    {
        return 0;
    }
    if !has_aniso_terms(spec, spatial_terms) {
        return 0;
    }
    let indices = stratified_spatial_subsample(data, spec, target_size);
    let pilot_data = sampled_rows(data, &indices);
    let mut working = spec.clone();
    let mut updated_terms = 0usize;
    const GEOMETRY_UPDATES: usize = 2;
    for pass in 0..GEOMETRY_UPDATES {
        // Re-plan centers on the pilot sample under the current working spec.
        let planned_terms = match plan_joint_spatial_centers_for_term_blocks(
            pilot_data.view(),
            &[working.smooth_terms.clone()],
        )
        .and_then(|mut blocks| {
            blocks.pop().ok_or_else(|| {
                BasisError::InvalidInput(
                    "pilot geometry initializer produced no smooth-term block".to_string(),
                )
            })
        }) {
            Ok(terms) => terms,
            Err(err) => {
                log::warn!(
                    "[spatial-kappa] pilot geometry initializer skipped after center planning failed: {err}"
                );
                return updated_terms;
            }
        };
        for &term_idx in spatial_terms {
            // Skip terms without usable multi-axis anisotropy state or
            // without user-provided centers from planning.
            let Some(current_eta) = get_spatial_aniso_log_scales(&working, term_idx) else {
                continue;
            };
            let Some(d) = get_spatial_feature_dim(&working, term_idx) else {
                continue;
            };
            if d <= 1 || current_eta.len() != d {
                continue;
            }
            let Some(planned_term) = planned_terms.get(term_idx) else {
                continue;
            };
            let Some(centers) = spatial_term_user_centers(planned_term) else {
                continue;
            };
            let Some(eta) = blended_pilot_axis_contrasts(pilot_data.view(), planned_term, centers)
            else {
                continue;
            };
            if set_spatial_aniso_log_scales(&mut working, term_idx, eta).is_ok() {
                // Count each term at most once, on the first pass only.
                updated_terms += usize::from(pass == 0);
            }
        }
        // Reseed ψ under the updated anisotropy; a failure keeps whatever
        // progress earlier passes made.
        match apply_pilot_spatial_psi_reseed(
            pilot_data.view(),
            &working,
            spatial_terms,
            kappa_options,
        ) {
            Ok(updated) => {
                working = updated;
            }
            Err(err) => {
                log::warn!(
                    "[spatial-kappa] pilot geometry ψ reseed skipped after deterministic initializer error: {err}"
                );
                break;
            }
        }
    }
    if updated_terms > 0 {
        log::info!(
            "[spatial-kappa] initialized anisotropy from {}-row pilot geometry for {} spatial term(s); proceeding to full-data optimization",
            indices.len(),
            updated_terms
        );
        *spec = working;
    }
    updated_terms
}
/// Indices of smooth terms whose spatial length scale (κ) supports
/// hyper-optimization.
pub(crate) fn spatial_length_scale_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
    (0..spec.smooth_terms.len())
        .filter(|&idx| spatial_term_supports_hyper_optimization(spec, idx))
        .collect()
}
/// True when no spatial term calls for κ optimization: every smooth term
/// either does not support hyper-optimization at all, or has its κ locked.
pub fn all_spatial_terms_kappa_fixed(spec: &TermCollectionSpec) -> bool {
    (0..spec.smooth_terms.len()).all(|idx| {
        if spatial_term_supports_hyper_optimization(spec, idx) {
            spatial_term_has_locked_kappa(spec, idx)
        } else {
            true
        }
    })
}
/// Scalar score used to rank fits: the REML score when finite, otherwise the
/// penalized half-deviance `0.5·deviance + 0.5·penalty`, falling back to
/// `+∞` when even that is non-finite.
fn fit_score(fit: &UnifiedFitResult) -> f64 {
    match fit.reml_score {
        reml if reml.is_finite() => reml,
        _ => {
            let fallback = 0.5 * fit.deviance + 0.5 * fit.stable_penalty_term;
            if fallback.is_finite() {
                fallback
            } else {
                f64::INFINITY
            }
        }
    }
}
/// Unwrap the outcome of a spatial-κ optimization attempt, requiring that it
/// produced a value and did not worsen the REML score (beyond 1e-10 slack
/// relative to `initial_score`).
fn require_successful_spatial_optimization_result<T>(
    initial_score: f64,
    result: Result<Option<(T, f64)>, EstimationError>,
) -> Result<T, EstimationError> {
    // Underlying error: wrap with context.
    let maybe = result.map_err(|err| {
        EstimationError::RemlOptimizationFailed(format!(
            "spatial kappa optimization failed: {err}"
        ))
    })?;
    // Missing result: optimization was unavailable for some eligible term.
    let (value, exact_score) = maybe.ok_or_else(|| {
        EstimationError::RemlOptimizationFailed(
            "spatial kappa optimization is unavailable for one or more eligible spatial terms"
                .to_string(),
        )
    })?;
    if exact_score <= initial_score + 1e-10 {
        Ok(value)
    } else {
        Err(EstimationError::RemlOptimizationFailed(format!(
            "spatial kappa optimization made REML score worse ({initial_score:.6e} -> {exact_score:.6e})"
        )))
    }
}
/// Assemble the external-optimizer options for one design: forwards the
/// user-facing `FitOptions` fields and attaches design-derived structure
/// (nullspace dims, linear constraints, Kronecker penalty system/basis).
fn external_opts_for_design(
    family: LikelihoodFamily,
    design: &TermCollectionDesign,
    options: &FitOptions,
) -> ExternalOptimOptions {
    ExternalOptimOptions {
        family,
        latent_cloglog: options.latent_cloglog,
        mixture_link: options.mixture_link.clone(),
        optimize_mixture: options.optimize_mixture,
        sas_link: options.sas_link,
        optimize_sas: options.optimize_sas,
        compute_inference: options.compute_inference,
        max_iter: options.max_iter,
        tol: options.tol,
        nullspace_dims: design.nullspace_dims.clone(),
        linear_constraints: design.linear_constraints.clone(),
        firth_bias_reduction: Some(options.firth_bias_reduction),
        penalty_shrinkage_floor: options.penalty_shrinkage_floor,
        rho_prior: options.rho_prior.clone(),
        kronecker_penalty_system: design.kronecker_penalty_system(),
        // First smooth term carrying a Kronecker-factored basis wins, if any.
        kronecker_factored: design
            .smooth
            .terms
            .iter()
            .find_map(|t| t.kronecker_factored.clone()),
    }
}
/// Thin wrapper around `ExternalJointHyperEvaluator::evaluate_with_order`
/// that plugs in this design's matrices, penalties, nullspace dims, and
/// constraints, tagging the call site for diagnostics. Returns the outer
/// objective value, its gradient, and the Hessian result at `theta`.
fn evaluate_joint_reml_outer_eval_at_theta(
    evaluator: &mut crate::estimate::ExternalJointHyperEvaluator<'_>,
    design: &TermCollectionDesign,
    theta: &Array1<f64>,
    rho_dim: usize,
    hyper_dirs: Vec<crate::estimate::reml::DirectionalHyperParam>,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
    order: crate::solver::outer_strategy::OuterEvalOrder,
) -> Result<
    (
        f64,
        Array1<f64>,
        crate::solver::outer_strategy::HessianResult,
    ),
    EstimationError,
> {
    evaluator.evaluate_with_order(
        &design.design,
        &design.penalties,
        &design.nullspace_dims,
        design.linear_constraints.clone(),
        theta,
        rho_dim,
        hyper_dirs,
        warm_start_beta,
        "evaluate_joint_reml_outer_eval_at_theta",
        order,
    )
}
/// Thin wrapper around `ExternalJointHyperEvaluator::evaluate_efs` that
/// plugs in this design's matrices, penalties, nullspace dims, and
/// constraints, tagging the call site for diagnostics.
fn evaluate_joint_reml_efs_at_theta(
    evaluator: &mut crate::estimate::ExternalJointHyperEvaluator<'_>,
    design: &TermCollectionDesign,
    theta: &Array1<f64>,
    rho_dim: usize,
    hyper_dirs: Vec<crate::estimate::reml::DirectionalHyperParam>,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    evaluator.evaluate_efs(
        &design.design,
        &design.penalties,
        &design.nullspace_dims,
        design.linear_constraints.clone(),
        theta,
        rho_dim,
        hyper_dirs,
        warm_start_beta,
        "evaluate_joint_reml_efs_at_theta",
    )
}
/// Whether an exact outer Hessian is available for joint spatial
/// hyper-optimization. Currently unconditionally true; the (ignored)
/// parameters keep call sites family- and design-aware should this ever
/// become conditional.
fn exact_joint_spatial_outer_hessian_available(
    _family: LikelihoodFamily,
    _design: &TermCollectionDesign,
) -> bool {
    true
}
/// Global index of the first penalty belonging to smooth term `term_idx`
/// within the design's flattened penalty list, or `None` when the term is
/// out of range or carries no local penalties.
///
/// Flattened penalty order: double-penalty linear terms (at most one slot),
/// then non-empty random-effect ranges, then each smooth term's local
/// penalties in term order.
fn smooth_term_penalty_index(
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    term_idx: usize,
) -> Option<usize> {
    let term = design.smooth.terms.get(term_idx)?;
    spec.smooth_terms.get(term_idx)?;
    if term.penalties_local.is_empty() {
        return None;
    }
    let linear_penalties = usize::from(spec.linear_terms.iter().any(|t| t.double_penalty));
    let random_penalties = design
        .random_effect_ranges
        .iter()
        .filter(|(_, range)| !range.is_empty())
        .count();
    // Penalties contributed by smooth terms preceding this one.
    let preceding_smooth: usize = design.smooth.terms[..term_idx]
        .iter()
        .map(|t| t.penalties_local.len())
        .sum();
    Some(linear_penalties + random_penalties + preceding_smooth)
}
/// Build the isotropic ψ (log-κ) derivative info for one spatial smooth
/// term, wiring the basis-level design/penalty derivatives to this term's
/// global coefficient range and its slot in the flattened penalty list.
///
/// Returns `Ok(None)` whenever the term does not support ψ derivatives or
/// any derivative component is missing, empty, or misaligned.
fn try_build_spatial_term_log_kappa_derivativeinfo(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    term_idx: usize,
) -> Result<Option<SpatialPsiDerivative>, EstimationError> {
    let Some((
        global_range,
        total_p,
        x_psi_local,
        s_psi_local_check,
        x_psi_psi_local,
        s_psi_psi_local,
        s_psi_components_local,
        s_psi_psi_components_local,
        implicit_operator,
    )) = try_build_spatial_term_log_kappa_derivative(data, resolvedspec, design, term_idx)?
    else {
        return Ok(None);
    };
    let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
        return Ok(None);
    };
    // First- and second-derivative penalty component lists must be non-empty
    // and aligned one-to-one.
    if s_psi_components_local.is_empty() || s_psi_psi_components_local.is_empty() {
        return Ok(None);
    }
    if s_psi_components_local.len() != s_psi_psi_components_local.len() {
        return Ok(None);
    }
    let penalty_indices = (0..s_psi_components_local.len())
        .map(|j| penalty_start + j)
        .collect::<Vec<_>>();
    let penalty_index = penalty_indices[0];
    // Aggregate derivative matrices must actually carry rows.
    if s_psi_local_check.nrows() == 0 || s_psi_psi_local.nrows() == 0 {
        return Ok(None);
    }
    Ok(Some(SpatialPsiDerivative {
        penalty_index,
        penalty_indices,
        global_range,
        total_p,
        x_psi_local,
        s_psi_components_local,
        x_psi_psi_local,
        s_psi_psi_components_local,
        // Isotropic path: no anisotropy grouping or cross-axis terms.
        aniso_group_id: None,
        aniso_cross_designs: None,
        aniso_cross_penalty_provider: None,
        implicit_operator,
        implicit_axis: 0,
    }))
}
/// Build ψ (log-κ) derivative info for every listed spatial term,
/// dispatching each term to the anisotropic multi-axis path when it carries
/// per-axis log scales matching its feature dimension (with `d > 1`).
///
/// Returns `Ok(None)` as soon as any term cannot provide derivatives: the
/// caller needs a complete set to run hyper-optimization.
pub(crate) fn try_build_spatial_log_kappa_derivativeinfo_list(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    spatial_terms: &[usize],
) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
    let mut out = Vec::new();
    let mut aniso_gid = 0usize;
    for &term_idx in spatial_terms {
        let aniso = get_spatial_aniso_log_scales(resolvedspec, term_idx);
        let dim = get_spatial_feature_dim(resolvedspec, term_idx);
        let is_aniso = match (&aniso, dim) {
            (Some(eta), Some(d)) => eta.len() == d && d > 1,
            _ => false,
        };
        if is_aniso {
            // Anisotropic term: one derivative entry per feature axis, all
            // sharing a fresh anisotropy group id.
            match try_build_spatial_term_log_kappa_aniso_derivativeinfos(
                data,
                resolvedspec,
                design,
                term_idx,
                aniso_gid,
            )? {
                Some(entries) => {
                    aniso_gid += 1;
                    out.extend(entries);
                }
                None => return Ok(None),
            }
            continue;
        }
        match try_build_spatial_term_log_kappa_derivativeinfo(
            data, resolvedspec, design, term_idx,
        )? {
            Some(info) => out.push(info),
            None => return Ok(None),
        }
    }
    Ok(Some(out))
}
/// Build per-axis ψ derivative entries for one anisotropic spatial term
/// (Matérn or Duchon with per-axis log scales).
///
/// Produces `d` entries — one per feature axis — all tagged with the same
/// `aniso_group_id`. Cross-axis second derivatives are wired either through
/// explicit design matrices or lazily through a cross-penalty provider when
/// an implicit operator represents the designs.
fn try_build_spatial_term_log_kappa_aniso_derivativeinfos(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    term_idx: usize,
    aniso_group_id: usize,
) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
    let smooth_term = match design.smooth.terms.get(term_idx) {
        Some(t) => t,
        None => return Ok(None),
    };
    let termspec = match resolvedspec.smooth_terms.get(term_idx) {
        Some(t) => t,
        None => return Ok(None),
    };
    // Anisotropic ψ derivatives exist only for Matérn and Duchon bases;
    // inputs are standardized before the basis derivative build when the
    // spec carries input scales.
    let mut aniso_result = match &termspec.basis {
        SmoothBasisSpec::Matern {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
            }
            build_matern_basis_log_kappa_aniso_derivatives(x.view(), spec)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Duchon {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
            }
            build_duchon_basis_log_kappa_aniso_derivatives(x.view(), spec)
                .map_err(EstimationError::from)?
        }
        _ => return Ok(None),
    };
    // Axis count: from the implicit operator when present, otherwise from
    // the explicit per-axis first-derivative designs.
    let d = if let Some(ref op) = aniso_result.implicit_operator {
        op.n_axes()
    } else if !aniso_result.design_first.is_empty() {
        aniso_result.design_first.len()
    } else {
        0
    };
    if d == 0 {
        return Ok(None);
    }
    let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
        return Ok(None);
    };
    // This term's coefficient range within the global design.
    let p_total = design.design.ncols();
    let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
    let global_range = (smooth_start + smooth_term.coeff_range.start)
        ..(smooth_start + smooth_term.coeff_range.end);
    let num_penalties = aniso_result.penalties_first[0].len();
    let penalty_indices: Vec<usize> = (0..num_penalties).map(|j| penalty_start + j).collect();
    let penalties_cross_provider = aniso_result.penalties_cross_provider.clone();
    // Implicit mode: design derivatives live in the operator, so per-axis
    // design matrices become empty placeholders.
    let use_implicit_design = aniso_result.design_first.is_empty();
    let implicit_op_arc = aniso_result
        .implicit_operator
        .as_ref()
        .map(|op| std::sync::Arc::new(op.clone()));
    let mut entries = Vec::with_capacity(d);
    for a in 0..d {
        let (x_psi_local, x_psi_psi_local) = if use_implicit_design {
            (Array2::<f64>::zeros((0, 0)), Array2::<f64>::zeros((0, 0)))
        } else {
            // Move axis a's derivative designs out of the bundle; each axis
            // is consumed exactly once, so take() is safe.
            let x_first = std::mem::take(&mut aniso_result.design_first[a]);
            let x_second = std::mem::take(&mut aniso_result.design_second_diag[a]);
            if x_first.ncols() != smooth_term.coeff_range.len() {
                return Ok(None);
            }
            (x_first, x_second)
        };
        let s_psi_components = std::mem::take(&mut aniso_result.penalties_first[a]);
        let s_psi_psi_components = std::mem::take(&mut aniso_result.penalties_second_diag[a]);
        // Cross-axis second-derivative designs: empty placeholders under the
        // implicit operator, otherwise gathered from the (pa, pb) pair list
        // wherever this axis participates.
        let cross_designs = if implicit_op_arc.is_some() {
            let mut cd = Vec::with_capacity(d - 1);
            for b in 0..d {
                if b == a {
                    continue;
                }
                cd.push((b, Array2::<f64>::zeros((0, 0))));
            }
            cd
        } else if !aniso_result.design_second_cross.is_empty() {
            let mut cd = Vec::new();
            for (cross_idx, &(pa, pb)) in aniso_result.design_second_cross_pairs.iter().enumerate()
            {
                if pa == a {
                    cd.push((pb, aniso_result.design_second_cross[cross_idx].clone()));
                } else if pb == a {
                    cd.push((pa, aniso_result.design_second_cross[cross_idx].clone()));
                }
            }
            cd
        } else {
            Vec::new()
        };
        // Lazy evaluator for cross-axis penalty second derivatives of the
        // unordered pair (a, b_axis); yields nothing for b_axis == a or when
        // no provider exists.
        let cross_penalty_provider = if d > 1 {
            let penalties_cross_provider = penalties_cross_provider.clone();
            Some(std::sync::Arc::new(
                move |b_axis: usize| -> Result<Vec<Array2<f64>>, EstimationError> {
                    if b_axis == a {
                        return Ok(Vec::new());
                    }
                    let (axis_lo, axis_hi) = if a < b_axis { (a, b_axis) } else { (b_axis, a) };
                    if let Some(provider) = penalties_cross_provider.as_ref() {
                        provider
                            .evaluate(axis_lo, axis_hi)
                            .map_err(EstimationError::from)
                    } else {
                        Ok(Vec::new())
                    }
                },
            )
                as std::sync::Arc<
                    dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError>
                        + Send
                        + Sync
                        + 'static,
                >)
        } else {
            None
        };
        entries.push(SpatialPsiDerivative {
            penalty_index: penalty_indices[0],
            penalty_indices: penalty_indices.clone(),
            global_range: global_range.clone(),
            total_p: p_total,
            x_psi_local,
            s_psi_components_local: s_psi_components,
            x_psi_psi_local,
            s_psi_psi_components_local: s_psi_psi_components,
            aniso_group_id: Some(aniso_group_id),
            aniso_cross_designs: if cross_designs.is_empty() {
                None
            } else {
                Some(cross_designs)
            },
            aniso_cross_penalty_provider: cross_penalty_provider,
            implicit_operator: implicit_op_arc.clone(),
            implicit_axis: a,
        });
    }
    Ok(Some(entries))
}
/// Builds the analytic log-kappa (length-scale) derivative bundle for a single
/// spatial smooth term.
///
/// Returns `Ok(None)` whenever the term cannot participate in analytic
/// length-scale optimization: the term index is out of range, the basis is a
/// B-spline (no length-scale parameter), or derivative shapes do not match the
/// fitted coefficient block. On success the tuple carries, in order: the
/// term's global coefficient range, the total design width, the first design
/// derivative, the summed first penalty derivative, the second design
/// derivative, the summed second penalty derivative, the raw per-penalty
/// first-derivative components, the raw per-penalty second-derivative
/// components, and an optional implicit design-derivative operator (used in
/// place of the embedded dense derivatives when present).
fn try_build_spatial_term_log_kappa_derivative(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    term_idx: usize,
) -> Result<
    Option<(
        Range<usize>,
        usize,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        Vec<Array2<f64>>,
        Vec<Array2<f64>>,
        Option<std::sync::Arc<crate::terms::basis::ImplicitDesignPsiDerivative>>,
    )>,
    EstimationError,
> {
    // Both the fitted design term and its resolved spec must exist; otherwise
    // there is nothing to differentiate.
    let smooth_term = match design.smooth.terms.get(term_idx) {
        Some(term) => term,
        None => return Ok(None),
    };
    let termspec = match resolvedspec.smooth_terms.get(term_idx) {
        Some(term) => term,
        None => return Ok(None),
    };
    // Dispatch on basis kind: each spatial basis family exposes its own
    // log-kappa derivative builder; B-spline bases carry no length scale.
    let derivative_bundle = match &termspec.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            // Apply the same input standardization used when the basis was built,
            // so derivatives are taken in the standardized coordinate system.
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
            }
            build_thin_plate_basis_log_kappa_derivatives(x.view(), spec)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Matern {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
            }
            build_matern_basis_log_kappa_derivatives(x.view(), spec)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Duchon {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
            }
            build_duchon_basis_log_kappa_derivatives(x.view(), spec)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::BSpline1D { .. } | SmoothBasisSpec::TensorBSpline { .. } => {
            return Ok(None);
        }
    };
    // The implicit operator (if any) lives on the bundle itself; the per-order
    // results are expected to leave their own implicit slots empty (asserted below).
    let implicit_operator = derivative_bundle.implicit_operator.map(std::sync::Arc::new);
    let BasisPsiDerivativeResult {
        design_derivative: local_x_psi,
        penalties_derivative: local_s_psi,
        implicit_operator: local_implicit_first_unused,
    } = derivative_bundle.first;
    let BasisPsiSecondDerivativeResult {
        designsecond_derivative: local_x_psi_psi,
        penaltiessecond_derivative: local_s_psi_psi,
        implicit_operator: local_implicit_second_unused,
    } = derivative_bundle.second;
    debug_assert!(local_implicit_first_unused.is_none());
    debug_assert!(local_implicit_second_unused.is_none());
    // Shape checks: every derivative must align with the fitted coefficient
    // block; any mismatch disables analytic optimization for this term.
    if let Some(ref op) = implicit_operator {
        if op.p_out() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
    } else {
        if local_x_psi.ncols() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
        if local_x_psi_psi.ncols() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
    }
    if local_s_psi.is_empty() || local_s_psi.len() != local_s_psi_psi.len() {
        return Ok(None);
    }
    if local_s_psi.iter().any(|s| {
        s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
    }) {
        return Ok(None);
    }
    if local_s_psi_psi.iter().any(|s| {
        s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
    }) {
        return Ok(None);
    }
    // Translate the term-local coefficient range into global design columns;
    // smooth columns occupy the trailing block of the design matrix.
    let p_total = design.design.ncols();
    let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
    let global_range = (smooth_start + smooth_term.coeff_range.start)
        ..(smooth_start + smooth_term.coeff_range.end);
    Ok(Some((
        global_range,
        p_total,
        local_x_psi,
        // Sum the per-penalty first derivatives into a single matrix.
        local_s_psi.iter().fold(
            Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
            |acc, m| acc + m,
        ),
        local_x_psi_psi,
        // Same summation for the second derivatives.
        local_s_psi_psi.iter().fold(
            Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
            |acc, m| acc + m,
        ),
        local_s_psi,
        local_s_psi_psi,
        implicit_operator,
    )))
}
/// Builds the directional hyper-parameters for the spatial log-kappa
/// coordinates, or returns `Ok(None)` when the per-term derivative info
/// cannot be assembled.
fn try_build_spatial_log_kappa_hyper_dirs(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    spatial_terms: &[usize],
) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
    match try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, spatial_terms)?
    {
        Some(info_list) => {
            let dirs = spatial_log_kappa_hyper_dirs_frominfo_list(info_list)?;
            Ok(Some(dirs))
        }
        None => Ok(None),
    }
}
/// Converts the per-term `SpatialPsiDerivative` entries into the
/// `DirectionalHyperParam` directions consumed by the joint REML outer
/// optimizer — one direction per log-kappa coordinate.
///
/// Anisotropic terms occupy contiguous runs of hyper axes belonging to the
/// same `aniso_group_id`; cross-axis second derivatives (design and penalty)
/// are wired between sibling axes within a group.
fn spatial_log_kappa_hyper_dirs_frominfo_list(
    info_list: Vec<SpatialPsiDerivative>,
) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
    use crate::estimate::reml::ImplicitDerivLevel;
    use std::collections::HashMap;
    let log_kappa_dim = info_list.len();
    // Map each anisotropy group id to the hyper-axis indices it owns, so
    // cross-derivatives can address sibling axes in the same group.
    let group_ids: Vec<Option<usize>> = info_list.iter().map(|e| e.aniso_group_id).collect();
    let mut group_indices_map: HashMap<usize, Vec<usize>> = HashMap::new();
    for (idx, gid) in group_ids.iter().enumerate() {
        if let Some(g) = gid {
            group_indices_map.entry(*g).or_default().push(idx);
        }
    }
    let mut hyper_dirs = Vec::with_capacity(log_kappa_dim);
    for (i, info) in info_list.into_iter().enumerate() {
        let SpatialPsiDerivative {
            penalty_index: _,
            penalty_indices,
            global_range,
            total_p,
            x_psi_local,
            s_psi_components_local,
            x_psi_psi_local,
            s_psi_psi_components_local,
            aniso_group_id,
            aniso_cross_designs,
            aniso_cross_penalty_provider,
            implicit_operator,
            implicit_axis,
        } = info;
        // Diagonal second design derivative for this axis; an implicit
        // operator is preferred over a dense embedded matrix when available.
        let mut xsecond = vec![None; log_kappa_dim];
        xsecond[i] = Some(if let Some(ref op) = implicit_operator {
            crate::estimate::reml::HyperDesignDerivative::from_implicit(
                op.clone(),
                ImplicitDerivLevel::SecondDiag(implicit_axis),
                global_range.clone(),
                total_p,
            )
        } else {
            crate::estimate::reml::HyperDesignDerivative::from_embedded(
                x_psi_psi_local,
                global_range.clone(),
                total_p,
            )
        });
        // Off-diagonal (cross-axis) second design derivatives within the
        // anisotropy group: slot j = group base index + partner axis offset.
        if let Some(cross_designs) = aniso_cross_designs {
            if let Some(gid) = aniso_group_id {
                let base = group_indices_map
                    .get(&gid)
                    .and_then(|v| v.first().copied())
                    .unwrap_or(i);
                for (b_axis, cross_mat) in cross_designs.into_iter() {
                    let j = base + b_axis;
                    if j < log_kappa_dim {
                        xsecond[j] = Some(if let Some(ref op) = implicit_operator {
                            crate::estimate::reml::HyperDesignDerivative::from_implicit(
                                op.clone(),
                                ImplicitDerivLevel::SecondCross(implicit_axis, b_axis),
                                global_range.clone(),
                                total_p,
                            )
                        } else {
                            crate::estimate::reml::HyperDesignDerivative::from_embedded(
                                cross_mat,
                                global_range.clone(),
                                total_p,
                            )
                        });
                    }
                }
            }
        }
        // Embed the per-penalty first/second derivative components into the
        // full coefficient space, keeping their penalty indices paired.
        let s_components = penalty_indices
            .iter()
            .copied()
            .zip(s_psi_components_local.into_iter().map(|local| {
                crate::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    total_p,
                )
            }))
            .collect::<Vec<_>>();
        let s2_components = penalty_indices
            .iter()
            .copied()
            .zip(s_psi_psi_components_local.into_iter().map(|local| {
                crate::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    total_p,
                )
            }))
            .collect::<Vec<_>>();
        let mut ssecond_components = vec![None; log_kappa_dim];
        ssecond_components[i] = Some(s2_components);
        // Cross-axis penalty second derivatives are produced lazily through a
        // provider closure, so they are only computed when the optimizer asks.
        let mut penaltysecond_partner_indices: Option<Vec<usize>> = None;
        let penaltysecond_component_provider =
            if let (Some(provider), Some(gid)) = (aniso_cross_penalty_provider, aniso_group_id) {
                let group_indices = group_indices_map.get(&gid).cloned().unwrap_or_default();
                let axis_in_group =
                    group_indices
                        .iter()
                        .position(|&idx| idx == i)
                        .ok_or_else(|| {
                            EstimationError::InvalidInput(format!(
                                "missing spatial hyper axis {} in anisotropy group {}",
                                i, gid
                            ))
                        })?;
                penaltysecond_partner_indices = Some(
                    group_indices
                        .iter()
                        .copied()
                        .filter(|&idx| idx != i)
                        .collect(),
                );
                let penalty_indices_inner = penalty_indices.clone();
                let global_range_inner = global_range.clone();
                let total_p_inner = total_p;
                let group_indices_inner = group_indices;
                Some(std::sync::Arc::new(
                    move |j: usize| -> Result<
                        Option<Vec<crate::estimate::reml::PenaltyDerivativeComponent>>,
                        EstimationError,
                    > {
                        // Only sibling axes in the same group have cross terms;
                        // the diagonal is handled by ssecond_components above.
                        let Some(other_axis_in_group) =
                            group_indices_inner.iter().position(|&idx| idx == j)
                        else {
                            return Ok(None);
                        };
                        if other_axis_in_group == axis_in_group {
                            return Ok(None);
                        }
                        let cross_pens = provider(other_axis_in_group)?;
                        if cross_pens.is_empty() {
                            return Ok(None);
                        }
                        Ok(Some(
                            penalty_indices_inner
                                .iter()
                                .copied()
                                .zip(cross_pens.into_iter().map(|local| {
                                    crate::estimate::reml::HyperPenaltyDerivative::from_embedded(
                                        local,
                                        global_range_inner.clone(),
                                        total_p_inner,
                                    )
                                }))
                                .map(|(penalty_index, matrix)| {
                                    crate::estimate::reml::PenaltyDerivativeComponent {
                                        penalty_index,
                                        matrix,
                                    }
                                })
                                .collect(),
                        ))
                    },
                )
                    as std::sync::Arc<
                        dyn Fn(
                            usize,
                        ) -> Result<
                            Option<Vec<crate::estimate::reml::PenaltyDerivativeComponent>>,
                            EstimationError,
                        > + Send
                            + Sync
                            + 'static,
                    >)
            } else {
                None
            };
        // First-order design derivative for this axis.
        let x_first_hyper = if let Some(ref op) = implicit_operator {
            crate::estimate::reml::HyperDesignDerivative::from_implicit(
                op.clone(),
                ImplicitDerivLevel::First(implicit_axis),
                global_range.clone(),
                total_p,
            )
        } else {
            crate::estimate::reml::HyperDesignDerivative::from_embedded(
                x_psi_local,
                global_range.clone(),
                total_p,
            )
        };
        let mut dir = DirectionalHyperParam::new_compact(
            x_first_hyper,
            s_components,
            Some(xsecond),
            Some(ssecond_components),
        )?
        .not_penalty_like();
        if let Some(provider) = penaltysecond_component_provider {
            dir = dir.with_penaltysecond_component_provider(provider);
        }
        if let Some(partner_indices) = penaltysecond_partner_indices {
            dir = dir.with_penaltysecond_partner_indices(partner_indices);
        }
        hyper_dirs.push(dir);
    }
    Ok(hyper_dirs)
}
/// Returns the number of log-kappa hyper coordinates each spatial term
/// contributes: `d` axes for an anisotropic term of feature dimension `d > 1`,
/// `d - 1` (at least one) for a pure anisotropic Duchon term, and one
/// otherwise.
pub(crate) fn spatial_dims_per_term(
    resolvedspec: &TermCollectionSpec,
    spatial_terms: &[usize],
) -> Vec<usize> {
    let mut dims = Vec::with_capacity(spatial_terms.len());
    for &term_idx in spatial_terms {
        // Fall back to a single coordinate when the feature dim is unknown.
        let d = get_spatial_feature_dim(resolvedspec, term_idx).unwrap_or(1);
        let dim = if is_pure_duchon_aniso_term(resolvedspec, term_idx) {
            d.saturating_sub(1).max(1)
        } else if d > 1 && get_spatial_aniso_log_scales(resolvedspec, term_idx).is_some() {
            d
        } else {
            1
        };
        dims.push(dim);
    }
    dims
}
/// Reports whether any spatial term carries more than one anisotropic
/// log-scale (i.e. genuinely anisotropic geometry).
fn has_aniso_terms(resolvedspec: &TermCollectionSpec, spatial_terms: &[usize]) -> bool {
    for &term_idx in spatial_terms {
        if let Some(eta) = get_spatial_aniso_log_scales(resolvedspec, term_idx) {
            if eta.len() > 1 {
                return true;
            }
        }
    }
    false
}
/// Caches the incrementally-realized design for a single exact-joint
/// optimization block, memoizing the most recent theta evaluation so repeated
/// probes at the same point (e.g. during line search) are free.
#[derive(Debug)]
struct SingleBlockExactJointDesignCache<'d> {
    // Incrementally re-realizes the term-collection design as log-kappa moves.
    realizer: FrozenTermCollectionIncrementalRealizer<'d>,
    // The theta the realizer currently reflects; `None` until first use.
    current_theta: Option<Array1<f64>>,
    // Memoized cost-only evaluation at `current_theta`.
    last_cost: Option<f64>,
    // Memoized full (cost, gradient, Hessian) evaluation at `current_theta`.
    last_eval: Option<(
        f64,
        Array1<f64>,
        crate::solver::outer_strategy::HessianResult,
    )>,
    // Indices of the spatial smooth terms driven by the theta tail.
    spatial_terms: Vec<usize>,
    // Number of leading log-lambda (rho) coordinates in theta.
    rho_dim: usize,
    // Per-term count of log-kappa coordinates in the theta tail.
    dims_per_term: Vec<usize>,
}
impl<'d> SingleBlockExactJointDesignCache<'d> {
    /// Wraps a fresh incremental realizer around the frozen spec/design pair;
    /// all memoized state starts empty.
    fn new(
        data: ArrayView2<'d, f64>,
        spec: TermCollectionSpec,
        design: TermCollectionDesign,
        spatial_terms: Vec<usize>,
        rho_dim: usize,
        dims_per_term: Vec<usize>,
    ) -> Result<Self, String> {
        Ok(Self {
            realizer: FrozenTermCollectionIncrementalRealizer::new(data, spec, design)?,
            current_theta: None,
            last_cost: None,
            last_eval: None,
            spatial_terms,
            rho_dim,
            dims_per_term,
        })
    }
    /// Re-realizes the design for `theta` unless the cache already reflects it.
    /// Invalidates the memoized cost/eval whenever theta changes.
    fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
        if self
            .current_theta
            .as_ref()
            .is_some_and(|cached| theta_values_match(cached, theta))
        {
            return Ok(());
        }
        let t_ensure = std::time::Instant::now();
        // Only the theta tail (past rho_dim) carries log-kappa coordinates.
        let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
            theta,
            self.rho_dim,
            self.dims_per_term.clone(),
        );
        self.realizer
            .apply_log_kappa(&log_kappa, &self.spatial_terms)?;
        log::info!(
            "[STAGE] ensure_theta (apply_log_kappa, {} terms): {:.3}s",
            self.spatial_terms.len(),
            t_ensure.elapsed().as_secs_f64(),
        );
        self.current_theta = Some(theta.clone());
        self.last_cost = None;
        self.last_eval = None;
        Ok(())
    }
    /// Returns the memoized cost at `theta`, preferring the full evaluation's
    /// value over the cost-only record; `None` on a theta mismatch.
    fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
        if self
            .current_theta
            .as_ref()
            .is_some_and(|cached| theta_values_match(cached, theta))
        {
            self.last_eval
                .as_ref()
                .map(|cached| cached.0)
                .or(self.last_cost)
        } else {
            None
        }
    }
    /// Returns the memoized full (cost, gradient, Hessian) evaluation at
    /// `theta`, or `None` when the cache is stale.
    fn memoized_eval(
        &self,
        theta: &Array1<f64>,
    ) -> Option<(
        f64,
        Array1<f64>,
        crate::solver::outer_strategy::HessianResult,
    )> {
        if self
            .current_theta
            .as_ref()
            .is_some_and(|cached| theta_values_match(cached, theta))
        {
            self.last_eval.clone()
        } else {
            None
        }
    }
    /// Stores a full evaluation (and mirrors its cost into the cost slot).
    fn store_eval(
        &mut self,
        eval: (
            f64,
            Array1<f64>,
            crate::solver::outer_strategy::HessianResult,
        ),
    ) {
        self.last_cost = Some(eval.0);
        self.last_eval = Some(eval);
    }
    /// Stores a cost-only evaluation for the current theta.
    fn store_cost(&mut self, cost: f64) {
        self.last_cost = Some(cost);
    }
    /// The spec the realizer currently holds.
    fn spec(&self) -> &TermCollectionSpec {
        self.realizer.spec()
    }
    /// The design realized at the current theta.
    fn design(&self) -> &TermCollectionDesign {
        self.realizer.design()
    }
}
/// Attempts the exact joint optimization of smoothing parameters (rho) and
/// spatial length scales (log-kappa) for the given fitted model.
///
/// Returns `Ok(None)` when the model has no spatial terms or when analytic
/// hyper-directions cannot be built. Otherwise runs the anisotropic or
/// isotropic joint optimizer and returns either the re-fitted model at the
/// optimized hyper-parameters or the frozen baseline if the joint candidate
/// did not improve the profiled score (within an acceptance tolerance).
fn try_exact_joint_spatial_length_scale_optimization(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    best: &FittedTermCollection,
    family: LikelihoodFamily,
    options: &FitOptions,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
    spatial_terms: &[usize],
) -> Result<Option<FittedTermCollectionWithSpec>, EstimationError> {
    if spatial_terms.is_empty() {
        return Ok(None);
    }
    kappa_options
        .validate()
        .map_err(EstimationError::InvalidInput)?;
    // Probe once at the baseline: if analytic hyper-directions cannot be
    // built here, the joint optimizer cannot run at all.
    if try_build_spatial_log_kappa_hyper_dirs(data, resolvedspec, &best.design, spatial_terms)?
        .is_none()
    {
        return Ok(None);
    }
    // Box constraint on log-lambda coordinates during joint optimization.
    const JOINT_RHO_BOUND: f64 = 12.0;
    let rho_dim = best.fit.lambdas.len();
    let dims_per_term = spatial_dims_per_term(resolvedspec, spatial_terms);
    let use_aniso = has_aniso_terms(resolvedspec, spatial_terms);
    // Seed log-kappa from the current length scales, then re-seed from data
    // and clamp into the data-derived box.
    let log_kappa0 = if use_aniso {
        SpatialLogKappaCoords::from_length_scales_aniso(resolvedspec, spatial_terms, kappa_options)
    } else {
        SpatialLogKappaCoords::from_length_scales(resolvedspec, spatial_terms, kappa_options)
    };
    let log_kappa0 = log_kappa0.reseed_from_data(data, resolvedspec, spatial_terms, kappa_options);
    let log_kappa_lower = if use_aniso {
        SpatialLogKappaCoords::lower_bounds_aniso_from_data(
            data,
            resolvedspec,
            spatial_terms,
            &dims_per_term,
            kappa_options,
        )
    } else {
        SpatialLogKappaCoords::lower_bounds_from_data(
            data,
            resolvedspec,
            spatial_terms,
            kappa_options,
        )
    };
    let log_kappa_upper = if use_aniso {
        SpatialLogKappaCoords::upper_bounds_aniso_from_data(
            data,
            resolvedspec,
            spatial_terms,
            &dims_per_term,
            kappa_options,
        )
    } else {
        SpatialLogKappaCoords::upper_bounds_from_data(
            data,
            resolvedspec,
            spatial_terms,
            kappa_options,
        )
    };
    let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
    // Pack theta = [log-lambda (rho) | log-kappa] with matching bounds.
    let setup = ExactJointHyperSetup::new(
        best.fit.lambdas.mapv(f64::ln),
        Array1::<f64>::from_elem(rho_dim, -JOINT_RHO_BOUND),
        Array1::<f64>::from_elem(rho_dim, JOINT_RHO_BOUND),
        log_kappa0,
        log_kappa_lower,
        log_kappa_upper,
    );
    let theta0 = setup.theta0();
    let lower = setup.lower();
    let upper = setup.upper();
    let (theta_star, joint_final_value) = if use_aniso {
        try_exact_joint_spatial_aniso_optimization(
            data,
            y,
            weights,
            offset,
            resolvedspec,
            &best.design,
            family,
            options,
            spatial_terms,
            &dims_per_term,
            &theta0,
            &lower,
            &upper,
            rho_dim,
            kappa_options,
        )?
    } else {
        try_exact_joint_spatial_isotropic_optimization(
            data,
            y,
            weights,
            offset,
            resolvedspec,
            &best.design,
            family,
            options,
            spatial_terms,
            &dims_per_term,
            &theta0,
            &lower,
            &upper,
            rho_dim,
            kappa_options,
        )?
    };
    // Accept the joint candidate only if it does not worsen the baseline
    // score beyond a tolerance scaled to the score magnitude.
    let baseline_score = fit_score(&best.fit);
    let baseline_result = FittedTermCollectionWithSpec {
        fit: best.fit.clone(),
        design: best.design.clone(),
        resolvedspec: resolvedspec.clone(),
        adaptive_diagnostics: best.adaptive_diagnostics.clone(),
    };
    let accept_tol = options
        .tol
        .max(1e-8 * baseline_score.abs())
        .max(1e-12);
    if joint_final_value > baseline_score + accept_tol {
        log::info!(
            "[spatial-kappa] exact joint spatial candidate worsened the profiled score (joint={:.6e}, baseline={:.6e}, tol={:.2e}); keeping the frozen baseline geometry",
            joint_final_value,
            baseline_score,
            accept_tol,
        );
        return Ok(Some(baseline_result));
    }
    // Unpack the optimum, bake log-kappa* into the spec, and re-fit at rho*.
    let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
    let log_kappa_star =
        SpatialLogKappaCoords::from_theta_tail_with_dims(&theta_star, rho_dim, dims_per_term);
    let resolvedspec = log_kappa_star.apply_tospec(resolvedspec, spatial_terms)?;
    let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
        data,
        y,
        weights,
        offset,
        &resolvedspec,
        rho_star.as_slice(),
        family,
        options,
    )?;
    let mut fit = optimized.fit;
    // Report the joint objective value as the final score of the re-fit.
    fit.reml_score = joint_final_value;
    let optimized_result = FittedTermCollectionWithSpec {
        fit,
        design: optimized.design,
        resolvedspec,
        adaptive_diagnostics: optimized.adaptive_diagnostics,
    };
    Ok(Some(optimized_result))
}
/// Runs the exact joint REML optimization over [rho | anisotropic log-kappa],
/// returning the optimal theta and the final objective value.
///
/// Uses an incrementally-realized design cache plus an external joint hyper
/// evaluator; analytic second-order geometry is only requested when the
/// family/design combination supports an analytic outer Hessian.
fn try_exact_joint_spatial_aniso_optimization(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    baseline_design: &TermCollectionDesign,
    family: LikelihoodFamily,
    options: &FitOptions,
    spatial_terms: &[usize],
    dims_per_term: &[usize],
    theta0: &Array1<f64>,
    lower: &Array1<f64>,
    upper: &Array1<f64>,
    rho_dim: usize,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<(Array1<f64>, f64), EstimationError> {
    assert!(lower.len() == theta0.len() && upper.len() == theta0.len());
    assert!(baseline_design.smooth.terms.len() >= spatial_terms.len());
    use crate::solver::outer_strategy::{
        DeclaredHessianForm, Derivative, OuterEval, OuterEvalOrder,
    };
    let theta_dim = theta0.len();
    // psi = anisotropic log-kappa coordinates (theta tail).
    let psi_dim = theta_dim - rho_dim;
    let analytic_outer_hessian_available =
        exact_joint_spatial_outer_hessian_available(family, baseline_design);
    if !analytic_outer_hessian_available {
        log::info!(
            "[spatial-aniso-joint] analytic outer Hessian unavailable for family/design; routing without second-order geometry (psi_dim={psi_dim})"
        );
    }
    let prefer_gradient_only = false;
    log::trace!(
        "[spatial-aniso-joint] starting analytic optimization: rho_dim={}, psi_dim={}, dims_per_term={:?}",
        rho_dim,
        psi_dim,
        dims_per_term,
    );
    // Shared mutable state for all objective callbacks: the design cache and
    // the external evaluator operating on the realized design.
    struct AnisoJointContext<'d> {
        data: ArrayView2<'d, f64>,
        rho_dim: usize,
        cache: SingleBlockExactJointDesignCache<'d>,
        evaluator: crate::estimate::ExternalJointHyperEvaluator<'d>,
    }
    impl<'d> AnisoJointContext<'d> {
        /// Full evaluation (cost, gradient, optional Hessian) at `theta`,
        /// served from the memo cache when the cached order suffices.
        fn eval_full(
            &mut self,
            theta: &Array1<f64>,
            order: OuterEvalOrder,
            analytic_outer_hessian_available: bool,
        ) -> Result<
            (
                f64,
                Array1<f64>,
                crate::solver::outer_strategy::HessianResult,
            ),
            EstimationError,
        > {
            let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
                && analytic_outer_hessian_available;
            // A cached value satisfies the request unless a second-order
            // evaluation is needed and the cache only holds a non-analytic one.
            if let Some(eval) = self.cache.memoized_eval(theta) {
                let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
                if cached_satisfies_order {
                    return Ok(eval);
                }
            }
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
            // Hyper-directions must be rebuilt at each new psi because the
            // realized design (and thus the derivatives) changed.
            let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
                self.data,
                self.cache.spec(),
                self.cache.design(),
                &self.cache.spatial_terms,
            )?
            .ok_or_else(|| {
                EstimationError::InvalidInput(
                    "failed to build aniso hyper_dirs at current psi".to_string(),
                )
            })?;
            let eval = evaluate_joint_reml_outer_eval_at_theta(
                &mut self.evaluator,
                self.cache.design(),
                theta,
                self.rho_dim,
                hyper_dirs,
                None,
                if allow_second_order {
                    order
                } else {
                    OuterEvalOrder::ValueAndGradient
                },
            );
            if let Ok(ref value) = eval {
                self.cache.store_eval(value.clone());
            }
            eval
        }
        /// EFS-style evaluation at `theta` (used as an alternative update rule).
        fn eval_efs(
            &mut self,
            theta: &Array1<f64>,
        ) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
            let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
                self.data,
                self.cache.spec(),
                self.cache.design(),
                &self.cache.spatial_terms,
            )?
            .ok_or_else(|| {
                EstimationError::InvalidInput(
                    "failed to build aniso hyper_dirs for exact-joint EFS".to_string(),
                )
            })?;
            evaluate_joint_reml_efs_at_theta(
                &mut self.evaluator,
                self.cache.design(),
                theta,
                self.rho_dim,
                hyper_dirs,
                None,
            )
        }
        /// Cost-only evaluation; any failure maps to +inf so the outer
        /// strategy treats the point as infeasible rather than aborting.
        fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
            if let Some(cost) = self.cache.memoized_cost(theta) {
                return cost;
            }
            if self.cache.ensure_theta(theta).is_err() {
                return f64::INFINITY;
            }
            let result = {
                let design = self.cache.design();
                self.evaluator.evaluate_cost_only(
                    &design.design,
                    &design.penalties,
                    &design.nullspace_dims,
                    design.linear_constraints.clone(),
                    theta,
                    self.rho_dim,
                    None,
                    "spatial-aniso-joint cost-only",
                )
            };
            match result {
                Ok(cost) => {
                    self.cache.store_cost(cost);
                    cost
                }
                Err(_) => f64::INFINITY,
            }
        }
        /// Drops all memoized state (used when the strategy restarts).
        fn reset(&mut self) {
            self.cache.current_theta = None;
            self.cache.last_cost = None;
            self.cache.last_eval = None;
        }
    }
    let mut ctx = AnisoJointContext {
        data,
        rho_dim,
        cache: SingleBlockExactJointDesignCache::new(
            data,
            resolvedspec.clone(),
            baseline_design.clone(),
            spatial_terms.to_vec(),
            rho_dim,
            dims_per_term.to_vec(),
        )
        .map_err(EstimationError::InvalidInput)?,
        evaluator: crate::estimate::ExternalJointHyperEvaluator::new(
            y,
            weights,
            &baseline_design.design,
            offset,
            &baseline_design.penalties,
            &external_opts_for_design(family, baseline_design, options),
            "spatial-aniso-joint",
        )?,
    };
    // Configure the multistart outer problem; declare the Hessian form based
    // on whether analytic second-order information is actually available.
    let problem = exact_joint_multistart_outer_problem(
        theta0,
        lower,
        upper,
        rho_dim,
        psi_dim,
        theta_dim,
        Derivative::Analytic,
        if analytic_outer_hessian_available {
            DeclaredHessianForm::Either
        } else {
            DeclaredHessianForm::Unavailable
        },
        prefer_gradient_only,
        false,
        seed_risk_profile_for_likelihood_family(family),
        kappa_options.rel_tol.max(1e-6),
        kappa_options.max_outer_iter.max(1),
        Some(kappa_options.log_step.clamp(0.25, 1.0)),
        None,
    );
    let eval_outer = |ctx: &mut &mut AnisoJointContext<'_>,
                      theta: &Array1<f64>,
                      order: OuterEvalOrder|
     -> Result<OuterEval, EstimationError> {
        match ctx.eval_full(theta, order, analytic_outer_hessian_available) {
            Ok((cost, grad, hess)) => Ok(OuterEval {
                cost,
                gradient: grad,
                hessian: hess,
            }),
            Err(err) => Err(err),
        }
    };
    let mut obj = problem.build_objective_with_eval_order(
        &mut ctx,
        |ctx: &mut &mut AnisoJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
        |ctx: &mut &mut AnisoJointContext<'_>, theta: &Array1<f64>| {
            eval_outer(
                ctx,
                theta,
                if analytic_outer_hessian_available {
                    OuterEvalOrder::ValueGradientHessian
                } else {
                    OuterEvalOrder::ValueAndGradient
                },
            )
        },
        |ctx: &mut &mut AnisoJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
            eval_outer(ctx, theta, order)
        },
        Some(|ctx: &mut &mut AnisoJointContext<'_>| {
            ctx.reset();
        }),
        Some(|ctx: &mut &mut AnisoJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
    );
    let result = problem.run(&mut obj, "aniso-psi joint REML").map_err(|e| {
        EstimationError::InvalidInput(format!(
            "anisotropic analytic optimization failed after exhausting strategy fallbacks: {e}"
        ))
    })?;
    log::trace!(
        "[spatial-aniso-joint] converged in {} iterations, final_value={:.6e}, grad_norm={:.6e}",
        result.iterations,
        result.final_value,
        result.final_grad_norm,
    );
    let theta_star = result.rho;
    Ok((theta_star, result.final_value))
}
/// Runs the exact joint REML optimization over [rho | isotropic log-kappa],
/// returning the optimal theta and the final objective value.
///
/// Structural twin of `try_exact_joint_spatial_aniso_optimization`, but with
/// one log-kappa coordinate per spatial term instead of per anisotropy axis.
fn try_exact_joint_spatial_isotropic_optimization(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    baseline_design: &TermCollectionDesign,
    family: LikelihoodFamily,
    options: &FitOptions,
    spatial_terms: &[usize],
    dims_per_term: &[usize],
    theta0: &Array1<f64>,
    lower: &Array1<f64>,
    upper: &Array1<f64>,
    rho_dim: usize,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<(Array1<f64>, f64), EstimationError> {
    assert!(lower.len() == theta0.len() && upper.len() == theta0.len());
    assert!(baseline_design.smooth.terms.len() >= spatial_terms.len());
    use crate::solver::outer_strategy::{
        DeclaredHessianForm, Derivative, OuterEval, OuterEvalOrder,
    };
    let theta_dim = theta0.len();
    // kappa = isotropic log-kappa coordinates (theta tail).
    let kappa_dim = theta_dim - rho_dim;
    let analytic_outer_hessian_available =
        exact_joint_spatial_outer_hessian_available(family, baseline_design);
    let prefer_gradient_only = false;
    log::trace!(
        "[spatial-iso-joint] starting analytic optimization: rho_dim={}, kappa_dim={}, dims_per_term={:?}",
        rho_dim,
        kappa_dim,
        dims_per_term,
    );
    // Shared mutable state for all objective callbacks: the design cache and
    // the external evaluator operating on the realized design.
    struct IsoJointContext<'d> {
        data: ArrayView2<'d, f64>,
        rho_dim: usize,
        cache: SingleBlockExactJointDesignCache<'d>,
        evaluator: crate::estimate::ExternalJointHyperEvaluator<'d>,
    }
    impl<'d> IsoJointContext<'d> {
        /// Full evaluation (cost, gradient, optional Hessian) at `theta`,
        /// served from the memo cache when the cached order suffices.
        fn eval_full(
            &mut self,
            theta: &Array1<f64>,
            order: OuterEvalOrder,
            analytic_outer_hessian_available: bool,
        ) -> Result<
            (
                f64,
                Array1<f64>,
                crate::solver::outer_strategy::HessianResult,
            ),
            EstimationError,
        > {
            let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
                && analytic_outer_hessian_available;
            // A cached value satisfies the request unless a second-order
            // evaluation is needed and the cache only holds a non-analytic one.
            if let Some(eval) = self.cache.memoized_eval(theta) {
                let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
                if cached_satisfies_order {
                    return Ok(eval);
                }
            }
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
            // Hyper-directions must be rebuilt at each new kappa because the
            // realized design (and thus the derivatives) changed.
            let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
                self.data,
                self.cache.spec(),
                self.cache.design(),
                &self.cache.spatial_terms,
            )?
            .ok_or_else(|| {
                EstimationError::InvalidInput(
                    "failed to build isotropic hyper_dirs at current kappa".to_string(),
                )
            })?;
            let eval = evaluate_joint_reml_outer_eval_at_theta(
                &mut self.evaluator,
                self.cache.design(),
                theta,
                self.rho_dim,
                hyper_dirs,
                None,
                if allow_second_order {
                    order
                } else {
                    OuterEvalOrder::ValueAndGradient
                },
            );
            if let Ok(ref value) = eval {
                self.cache.store_eval(value.clone());
            }
            eval
        }
        /// EFS-style evaluation at `theta` (used as an alternative update rule).
        fn eval_efs(
            &mut self,
            theta: &Array1<f64>,
        ) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
            let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
                self.data,
                self.cache.spec(),
                self.cache.design(),
                &self.cache.spatial_terms,
            )?
            .ok_or_else(|| {
                EstimationError::InvalidInput(
                    "failed to build isotropic hyper_dirs for exact-joint EFS".to_string(),
                )
            })?;
            evaluate_joint_reml_efs_at_theta(
                &mut self.evaluator,
                self.cache.design(),
                theta,
                self.rho_dim,
                hyper_dirs,
                None,
            )
        }
        /// Cost-only evaluation; any failure maps to +inf so the outer
        /// strategy treats the point as infeasible rather than aborting.
        fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
            if let Some(cost) = self.cache.memoized_cost(theta) {
                return cost;
            }
            if self.cache.ensure_theta(theta).is_err() {
                return f64::INFINITY;
            }
            let result = {
                let design = self.cache.design();
                self.evaluator.evaluate_cost_only(
                    &design.design,
                    &design.penalties,
                    &design.nullspace_dims,
                    design.linear_constraints.clone(),
                    theta,
                    self.rho_dim,
                    None,
                    "spatial-iso-joint cost-only",
                )
            };
            match result {
                Ok(cost) => {
                    self.cache.store_cost(cost);
                    cost
                }
                Err(_) => f64::INFINITY,
            }
        }
        /// Drops all memoized state (used when the strategy restarts).
        fn reset(&mut self) {
            self.cache.current_theta = None;
            self.cache.last_cost = None;
            self.cache.last_eval = None;
        }
    }
    let mut ctx = IsoJointContext {
        data,
        rho_dim,
        cache: SingleBlockExactJointDesignCache::new(
            data,
            resolvedspec.clone(),
            baseline_design.clone(),
            spatial_terms.to_vec(),
            rho_dim,
            dims_per_term.to_vec(),
        )
        .map_err(EstimationError::InvalidInput)?,
        evaluator: crate::estimate::ExternalJointHyperEvaluator::new(
            y,
            weights,
            &baseline_design.design,
            offset,
            &baseline_design.penalties,
            &external_opts_for_design(family, baseline_design, options),
            "spatial-iso-joint",
        )?,
    };
    // Configure the multistart outer problem; declare the Hessian form based
    // on whether analytic second-order information is actually available.
    let problem = exact_joint_multistart_outer_problem(
        theta0,
        lower,
        upper,
        rho_dim,
        kappa_dim,
        theta_dim,
        Derivative::Analytic,
        if analytic_outer_hessian_available {
            DeclaredHessianForm::Either
        } else {
            DeclaredHessianForm::Unavailable
        },
        prefer_gradient_only,
        false,
        seed_risk_profile_for_likelihood_family(family),
        kappa_options.rel_tol.max(1e-6),
        kappa_options.max_outer_iter.max(1),
        Some(kappa_options.log_step.clamp(0.25, 1.0)),
        None,
    );
    let eval_outer = |ctx: &mut &mut IsoJointContext<'_>,
                      theta: &Array1<f64>,
                      order: OuterEvalOrder|
     -> Result<OuterEval, EstimationError> {
        match ctx.eval_full(theta, order, analytic_outer_hessian_available) {
            Ok((cost, grad, hess)) => Ok(OuterEval {
                cost,
                gradient: grad,
                hessian: hess,
            }),
            Err(err) => Err(err),
        }
    };
    let mut obj = problem.build_objective_with_eval_order(
        &mut ctx,
        |ctx: &mut &mut IsoJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
        |ctx: &mut &mut IsoJointContext<'_>, theta: &Array1<f64>| {
            eval_outer(
                ctx,
                theta,
                if analytic_outer_hessian_available {
                    OuterEvalOrder::ValueGradientHessian
                } else {
                    OuterEvalOrder::ValueAndGradient
                },
            )
        },
        |ctx: &mut &mut IsoJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
            eval_outer(ctx, theta, order)
        },
        Some(|ctx: &mut &mut IsoJointContext<'_>| {
            ctx.reset();
        }),
        Some(|ctx: &mut &mut IsoJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
    );
    let result = problem.run(&mut obj, "iso-kappa joint REML").map_err(|e| {
        EstimationError::InvalidInput(format!(
            "isotropic analytic optimization failed after exhausting strategy fallbacks: {e}"
        ))
    })?;
    log::trace!(
        "[spatial-iso-joint] converged in {} iterations, final_value={:.6e}, grad_norm={:.6e}",
        result.iterations,
        result.final_value,
        result.final_grad_norm,
    );
    Ok((result.rho, result.final_value))
}
/// Writes `length_scale` into the spatial basis spec of term `term_idx`.
///
/// Errors when the index is out of range or when the term's basis does not
/// carry a length scale (B-spline bases).
fn set_spatial_length_scale(
    spec: &mut TermCollectionSpec,
    term_idx: usize,
    length_scale: f64,
) -> Result<(), EstimationError> {
    let term = spec.smooth_terms.get_mut(term_idx).ok_or_else(|| {
        EstimationError::InvalidInput(format!(
            "spatial length-scale term index {term_idx} out of range"
        ))
    })?;
    match &mut term.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => spec.length_scale = length_scale,
        SmoothBasisSpec::Matern { spec, .. } => spec.length_scale = length_scale,
        // Duchon stores the length scale as an Option.
        SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale = Some(length_scale),
        _ => {
            return Err(EstimationError::InvalidInput(format!(
                "term '{}' does not expose a spatial length scale",
                term.name
            )));
        }
    }
    Ok(())
}
/// Reads the spatial length scale of term `term_idx`, if the term exists and
/// its basis carries one (Duchon's is optional; B-splines have none).
pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
    let term = spec.smooth_terms.get(term_idx)?;
    match &term.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
        SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
        SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
        _ => None,
    }
}
pub fn freeze_term_collection_from_design(
spec: &TermCollectionSpec,
design: &TermCollectionDesign,
) -> Result<TermCollectionSpec, EstimationError> {
if spec.smooth_terms.len() != design.smooth.terms.len() {
return Err(EstimationError::InvalidInput(format!(
"freeze mismatch: smooth spec count {} != design smooth term count {}",
spec.smooth_terms.len(),
design.smooth.terms.len()
)));
}
if spec.random_effect_terms.len() != design.random_effect_levels.len() {
return Err(EstimationError::InvalidInput(format!(
"freeze mismatch: random-effect spec count {} != design random-effect term count {}",
spec.random_effect_terms.len(),
design.random_effect_levels.len()
)));
}
let mut frozen = spec.clone();
for (term, fitted) in frozen
.smooth_terms
.iter_mut()
.zip(design.smooth.terms.iter())
{
if matches!(&term.basis, SmoothBasisSpec::ThinPlate { .. })
&& matches!(&fitted.metadata, BasisMetadata::Duchon { .. })
{
let (feature_cols, original_identifiability) = match &term.basis {
SmoothBasisSpec::ThinPlate {
feature_cols, spec, ..
} => (feature_cols.clone(), spec.identifiability.clone()),
_ => unreachable!("guarded by the matches! above"),
};
term.basis = SmoothBasisSpec::Duchon {
feature_cols,
spec: DuchonBasisSpec {
center_strategy: CenterStrategy::FarthestPoint { num_centers: 0 },
length_scale: None,
power: 0,
nullspace_order: DuchonNullspaceOrder::Zero,
identifiability: original_identifiability,
aniso_log_scales: None,
operator_penalties: DuchonOperatorPenaltySpec::default(),
},
input_scales: None,
};
}
match (&mut term.basis, &fitted.metadata) {
(
SmoothBasisSpec::BSpline1D { spec: s, .. },
BasisMetadata::BSpline1D {
knots,
identifiability_transform,
},
) => {
s.knotspec = BSplineKnotSpec::Provided(knots.clone());
s.identifiability = match identifiability_transform {
Some(z) => BSplineIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => BSplineIdentifiability::None,
};
}
(
SmoothBasisSpec::ThinPlate {
spec: s,
input_scales,
..
},
BasisMetadata::ThinPlate {
centers,
length_scale,
identifiability_transform,
input_scales: meta_scales,
radial_reparam,
..
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.length_scale = *length_scale;
s.identifiability = match identifiability_transform {
Some(z) => SpatialIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => match &s.identifiability {
SpatialIdentifiability::FrozenTransform { .. } => s.identifiability.clone(),
_ => SpatialIdentifiability::None,
},
};
s.radial_reparam = radial_reparam.clone();
*input_scales = meta_scales.clone();
}
(
SmoothBasisSpec::ThinPlate { feature_cols, .. },
BasisMetadata::Duchon {
centers,
length_scale,
power,
nullspace_order,
identifiability_transform,
input_scales: meta_scales,
aniso_log_scales: meta_aniso,
},
) => {
let identifiability = match identifiability_transform {
Some(z) => SpatialIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => SpatialIdentifiability::None,
};
term.basis = SmoothBasisSpec::Duchon {
feature_cols: feature_cols.clone(),
spec: DuchonBasisSpec {
center_strategy: crate::basis::CenterStrategy::UserProvided(
centers.clone(),
),
length_scale: *length_scale,
power: *power,
nullspace_order: *nullspace_order,
identifiability,
aniso_log_scales: meta_aniso.clone(),
operator_penalties: Default::default(),
},
input_scales: meta_scales.clone(),
};
}
(
SmoothBasisSpec::Matern {
spec: s,
input_scales,
..
},
BasisMetadata::Matern {
centers,
length_scale,
nu,
include_intercept,
identifiability_transform,
input_scales: meta_scales,
aniso_log_scales: meta_aniso,
..
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.length_scale = *length_scale;
s.nu = *nu;
s.include_intercept = *include_intercept;
s.identifiability = match identifiability_transform {
Some(z) => MaternIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => MaternIdentifiability::None,
};
s.aniso_log_scales = meta_aniso.clone();
*input_scales = meta_scales.clone();
}
(
SmoothBasisSpec::Duchon {
spec: s,
input_scales,
..
},
BasisMetadata::Duchon {
centers,
length_scale,
power,
nullspace_order,
identifiability_transform,
input_scales: meta_scales,
aniso_log_scales: meta_aniso,
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.length_scale = *length_scale;
s.power = *power;
s.nullspace_order = *nullspace_order;
s.identifiability = match identifiability_transform {
Some(z) => SpatialIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => match &s.identifiability {
SpatialIdentifiability::FrozenTransform { .. } => s.identifiability.clone(),
_ => SpatialIdentifiability::None,
},
};
s.aniso_log_scales = meta_aniso.clone();
*input_scales = meta_scales.clone();
}
(
SmoothBasisSpec::TensorBSpline {
feature_cols,
spec: s,
},
BasisMetadata::TensorBSpline {
feature_cols: fitted_cols,
knots,
degrees,
identifiability_transform,
},
) => {
if s.marginalspecs.len() != knots.len() || s.marginalspecs.len() != degrees.len() {
return Err(EstimationError::InvalidInput(format!(
"tensor freeze mismatch for '{}': marginalspecs={}, knots={}, degrees={}",
term.name,
s.marginalspecs.len(),
knots.len(),
degrees.len()
)));
}
*feature_cols = fitted_cols.clone();
for i in 0..s.marginalspecs.len() {
s.marginalspecs[i].degree = degrees[i];
s.marginalspecs[i].knotspec = BSplineKnotSpec::Provided(knots[i].clone());
}
s.identifiability = match identifiability_transform {
Some(z) => TensorBSplineIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => TensorBSplineIdentifiability::None,
};
}
_ => {
return Err(EstimationError::InvalidInput(format!(
"smooth metadata/spec type mismatch while freezing term '{}'",
term.name
)));
}
}
}
for (idx, rt) in frozen.random_effect_terms.iter_mut().enumerate() {
let (_, kept_levels) = &design.random_effect_levels[idx];
rt.frozen_levels = Some(kept_levels.clone());
}
Ok(frozen)
}
#[derive(Debug, Clone)]
/// Result of building exactly one smooth term in isolation: its local design
/// matrix, the fully built term, and any penalty blocks dropped during the
/// build.
struct SingleSmoothTermRealization {
// Design matrix for this term alone (columns are term-local).
design_local: DesignMatrix,
// The built smooth term (penalties, metadata, constraint caches, ...).
term: SmoothTerm,
// Penalty blocks dropped while building this single term.
dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
}
impl SingleSmoothTermRealization {
    /// Returns clones of the term's local penalty-info entries whose
    /// `active` flag is set, in their original order.
    fn active_penaltyinfo(&self) -> Vec<PenaltyInfo> {
        let mut active = Vec::new();
        for info in &self.term.penaltyinfo_local {
            if info.active {
                active.push(info.clone());
            }
        }
        active
    }
}
/// Builds the design and realization for exactly one smooth term spec.
fn build_single_smooth_term_realization(
    data: ArrayView2<'_, f64>,
    termspec: &SmoothTermSpec,
) -> Result<SingleSmoothTermRealization, BasisError> {
    // Wrap the single spec as a one-element slice so the multi-term builder
    // can be reused, then unpack its single result.
    finish_single_smooth_term_realization(build_smooth_design(
        data,
        std::slice::from_ref(termspec),
    )?)
}
/// Like `build_single_smooth_term_realization`, but reuses a caller-supplied
/// basis workspace to avoid reallocating scratch buffers on repeated rebuilds.
fn build_single_smooth_term_realization_withworkspace(
    data: ArrayView2<'_, f64>,
    termspec: &SmoothTermSpec,
    workspace: &mut crate::basis::BasisWorkspace,
) -> Result<SingleSmoothTermRealization, BasisError> {
    let single = std::slice::from_ref(termspec);
    let raw = build_smooth_design_withworkspace(data, single, workspace)?;
    finish_single_smooth_term_realization(raw)
}
/// Unpacks a raw one-term smooth build into a `SingleSmoothTermRealization`,
/// erroring if the builder returned no term or no term design.
fn finish_single_smooth_term_realization(
    raw: RawSmoothDesign,
) -> Result<SingleSmoothTermRealization, BasisError> {
    let term = match raw.terms.into_iter().next() {
        Some(t) => t,
        None => {
            return Err(BasisError::InvalidInput(
                "single-term smooth build returned no term".to_string(),
            ));
        }
    };
    let design_local = match raw.term_designs.into_iter().next() {
        Some(d) => d,
        None => {
            return Err(BasisError::InvalidInput(
                "single-term smooth build returned no term design".to_string(),
            ));
        }
    };
    Ok(SingleSmoothTermRealization {
        design_local,
        term,
        dropped_penaltyinfo: raw.dropped_penaltyinfo,
    })
}
/// Rebuilds the smooth design's derived state — global coefficient lower
/// bounds, stacked linear inequality constraints, and the flattened
/// dropped-penalty list — from the per-term caches, after one or more terms
/// were replaced in place.
///
/// `dropped_penaltyinfo_by_term` must hold one entry per smooth term, in term
/// order. Returns a descriptive error string when any cached width disagrees
/// with the current coefficient ranges.
fn rebuild_smooth_auxiliary_state(
smooth: &mut SmoothDesign,
dropped_penaltyinfo_by_term: &[Vec<DroppedPenaltyBlockInfo>],
) -> Result<(), String> {
if dropped_penaltyinfo_by_term.len() != smooth.terms.len() {
return Err(format!(
"smooth dropped-penalty cache mismatch: terms={}, dropped_sets={}",
smooth.terms.len(),
dropped_penaltyinfo_by_term.len()
));
}
let total_p = smooth.total_smooth_cols();
// Bounds default to -inf (unbounded); only terms with cached local bounds
// overwrite their coefficient slice.
let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
let mut any_bounds = false;
let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
let mut linear_constraint_b: Vec<f64> = Vec::new();
for term in &smooth.terms {
let range = term.coeff_range.clone();
if let Some(lb_local) = term.lower_bounds_local.as_ref() {
if lb_local.len() != range.len() {
return Err(format!(
"smooth lower-bound cache mismatch for term '{}': bounds={}, coeffs={}",
term.name,
lb_local.len(),
range.len()
));
}
coefficient_lower_bounds
.slice_mut(s![range.clone()])
.assign(lb_local);
any_bounds = true;
}
if let Some(lin_local) = term.linear_constraints_local.as_ref() {
if lin_local.a.ncols() != range.len() {
return Err(format!(
"smooth linear-constraint cache mismatch for term '{}': cols={}, coeffs={}",
term.name,
lin_local.a.ncols(),
range.len()
));
}
// Embed each term-local constraint row into a full-width row so all
// terms' constraints can be stacked into one global system.
for r in 0..lin_local.a.nrows() {
let mut row = Array1::<f64>::zeros(total_p);
row.slice_mut(s![range.clone()]).assign(&lin_local.a.row(r));
linear_constraintrows.push(row);
linear_constraint_b.push(lin_local.b[r]);
}
}
}
// Keep `None` when nothing contributed so downstream code can skip the
// constrained code paths entirely.
smooth.coefficient_lower_bounds = if any_bounds {
Some(coefficient_lower_bounds)
} else {
None
};
smooth.linear_constraints = if linear_constraintrows.is_empty() {
None
} else {
let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), total_p));
for (i, row) in linear_constraintrows.iter().enumerate() {
a.row_mut(i).assign(row);
}
Some(LinearInequalityConstraints {
a,
b: Array1::from_vec(linear_constraint_b),
})
};
// Flatten the per-term dropped-penalty records into one list, term order.
smooth.dropped_penaltyinfo = dropped_penaltyinfo_by_term
.iter()
.flat_map(|infos| infos.iter().cloned())
.collect();
Ok(())
}
/// Rebuilds the collection-level constraint state after the smooth sub-design
/// changed: global lower bounds, stacked inequality constraints (linear-term
/// box constraints plus the smooth block's constraints widened to the full
/// design), and the dropped-penalty list.
///
/// Assumes the smooth columns occupy the trailing columns of the full design
/// matrix. Returns a descriptive error string on any bookkeeping mismatch.
fn rebuild_term_collection_auxiliary_state(
spec: &TermCollectionSpec,
design: &mut TermCollectionDesign,
) -> Result<(), String> {
if spec.linear_terms.len() != design.linear_ranges.len() {
return Err(format!(
"term-collection linear bookkeeping mismatch: spec_terms={}, design_ranges={}",
spec.linear_terms.len(),
design.linear_ranges.len()
));
}
let p_total = design.design.ncols();
// First column of the trailing smooth block.
let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
let mut coefficient_lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
let mut any_bounds = false;
let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
let mut linear_constraint_b: Vec<f64> = Vec::new();
for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
if range.len() != 1 {
return Err(format!(
"linear term '{}' expected one coefficient column, found {}",
linear.name,
range.len()
));
}
let col = range.start;
// Box constraints become inequality rows; the (+1, lb) / (-1, -ub)
// sign pattern indicates an a·β >= b convention.
if let Some(lb) = linear.coefficient_min {
let mut row = Array1::<f64>::zeros(p_total);
row[col] = 1.0;
linear_constraintrows.push(row);
linear_constraint_b.push(lb);
}
if let Some(ub) = linear.coefficient_max {
let mut row = Array1::<f64>::zeros(p_total);
row[col] = -1.0;
linear_constraintrows.push(row);
linear_constraint_b.push(-ub);
}
}
if let Some(lb_smooth) = design.smooth.coefficient_lower_bounds.as_ref() {
if lb_smooth.len() != design.smooth.total_smooth_cols() {
return Err(format!(
"smooth lower-bound width mismatch: bounds={}, smooth_cols={}",
lb_smooth.len(),
design.smooth.total_smooth_cols()
));
}
// Copy smooth bounds into the trailing slice of the global vector.
coefficient_lower_bounds
.slice_mut(s![
smooth_start..(smooth_start + design.smooth.total_smooth_cols())
])
.assign(lb_smooth);
any_bounds = true;
}
if let Some(lin_smooth) = design.smooth.linear_constraints.as_ref() {
if lin_smooth.a.ncols() != design.smooth.total_smooth_cols() {
return Err(format!(
"smooth linear-constraint width mismatch: cols={}, smooth_cols={}",
lin_smooth.a.ncols(),
design.smooth.total_smooth_cols()
));
}
// Embed the smooth-local constraint matrix at the smooth column offset.
let mut a_global = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
a_global
.slice_mut(s![
..,
smooth_start..(smooth_start + design.smooth.total_smooth_cols())
])
.assign(&lin_smooth.a);
for r in 0..a_global.nrows() {
linear_constraintrows.push(a_global.row(r).to_owned());
linear_constraint_b.push(lin_smooth.b[r]);
}
}
// Lower bounds are additionally converted to constraint rows and merged
// with the explicit ones below.
let lower_bound_constraints = if any_bounds {
linear_constraints_from_lower_bounds_global(&coefficient_lower_bounds)
} else {
None
};
let explicit_linear_constraints = if linear_constraintrows.is_empty() {
None
} else {
let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), p_total));
for (i, row) in linear_constraintrows.iter().enumerate() {
a.row_mut(i).assign(row);
}
Some(LinearInequalityConstraints {
a,
b: Array1::from_vec(linear_constraint_b),
})
};
design.coefficient_lower_bounds = if any_bounds {
Some(coefficient_lower_bounds)
} else {
None
};
design.linear_constraints =
merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
design.dropped_penaltyinfo = design.smooth.dropped_penaltyinfo.clone();
Ok(())
}
/// Exact (bitwise) equality of two hyperparameter vectors. Comparing bit
/// patterns makes identical NaNs match while `0.0` and `-0.0` do not, so any
/// representational change is detected.
fn theta_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
    if left.len() != right.len() {
        return false;
    }
    left.iter()
        .zip(right.iter())
        .all(|(l, r)| l.to_bits() == r.to_bits())
}
/// Bitwise equality for optional anisotropy log-scale slices.
///
/// Two `None`s match; two slices match only when they have equal length and
/// every element has an identical bit pattern (identical NaNs compare equal,
/// `0.0` vs `-0.0` do not). Mixed `Some`/`None` never matches.
fn spatial_aniso_matches(left: Option<&[f64]>, right: Option<&[f64]>) -> bool {
    match (left, right) {
        (None, None) => true,
        (Some(a), Some(b)) if a.len() == b.len() => {
            a.iter().zip(b).all(|(x, y)| x.to_bits() == y.to_bits())
        }
        _ => false,
    }
}
/// Bitwise equality for optional length scales: both absent, or both present
/// with identical bit patterns (identical NaNs match; `0.0` vs `-0.0` do not).
fn spatial_length_scale_matches(left: Option<f64>, right: Option<f64>) -> bool {
    // Mapping to raw bits makes Option's own `==` implement exactly the
    // (None, None) / bitwise-equal-Some semantics of the original match.
    left.map(f64::to_bits) == right.map(f64::to_bits)
}
/// Caches everything needed to re-realize a frozen term collection
/// incrementally when spatial hyperparameters (log-κ) change: only affected
/// smooth terms are rebuilt, while fixed blocks and penalty bookkeeping are
/// reused.
struct FrozenTermCollectionIncrementalRealizer<'d> {
// Training data the designs are built from (borrowed for the cache's life).
data: ArrayView2<'d, f64>,
spec: TermCollectionSpec,
design: TermCollectionDesign,
// Non-smooth blocks (intercept, linear columns, random effects), built once.
fixed_blocks: Vec<DesignBlock>,
// Per-term dropped-penalty records, refreshed whenever a term is rebuilt.
dropped_penaltyinfo_by_term: Vec<Vec<DroppedPenaltyBlockInfo>>,
// Per-term contiguous index ranges into the smooth penalty list ...
smooth_penalty_ranges: Vec<Range<usize>>,
// ... and into the full (fixed + smooth) penalty list.
full_penalty_ranges: Vec<Range<usize>>,
// Scratch buffers reused across basis rebuilds.
basisworkspace: crate::basis::BasisWorkspace,
}
impl<'d> std::fmt::Debug for FrozenTermCollectionIncrementalRealizer<'d> {
    /// Compact debug view: only the data shape and fixed-block count, since
    /// the cached designs and workspace are far too large to print usefully.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data_shape = (self.data.nrows(), self.data.ncols());
        let mut dbg = f.debug_struct("FrozenTermCollectionIncrementalRealizer");
        dbg.field("data_shape", &data_shape);
        dbg.field("fixed_blocks", &self.fixed_blocks.len());
        dbg.finish_non_exhaustive()
    }
}
impl<'d> FrozenTermCollectionIncrementalRealizer<'d> {
/// Builds an incremental realizer from a frozen spec/design pair.
///
/// Validates the penalty bookkeeping (one contiguous penalty range per smooth
/// term, with smooth penalties trailing any fixed penalties in the full
/// design), caches the fixed (non-smooth) design blocks, and rebuilds each
/// smooth term once to capture its dropped-penalty records and confirm the
/// cached realization still matches the frozen design. Returns a descriptive
/// error string on any bookkeeping mismatch.
fn new(
    data: ArrayView2<'d, f64>,
    spec: TermCollectionSpec,
    design: TermCollectionDesign,
) -> Result<Self, String> {
    if spec.smooth_terms.len() != design.smooth.terms.len() {
        return Err(format!(
            "incremental realizer smooth term mismatch: spec_terms={}, design_terms={}",
            spec.smooth_terms.len(),
            design.smooth.terms.len()
        ));
    }
    // Assign each term a contiguous range into the smooth penalty list.
    let mut smooth_cursor = 0usize;
    let mut smooth_penalty_ranges = Vec::with_capacity(design.smooth.terms.len());
    for term in &design.smooth.terms {
        let next = smooth_cursor + term.penalties_local.len();
        smooth_penalty_ranges.push(smooth_cursor..next);
        smooth_cursor = next;
    }
    if smooth_cursor != design.smooth.penalties.len() {
        return Err(format!(
            "incremental realizer smooth penalty mismatch: ranged={}, actual={}",
            smooth_cursor,
            design.smooth.penalties.len()
        ));
    }
    // Smooth penalties occupy the tail of the full penalty list; everything
    // before them belongs to fixed blocks.
    let fixed_penalty_offset = design
        .penalties
        .len()
        .checked_sub(design.smooth.penalties.len())
        .ok_or_else(|| {
            "incremental realizer encountered invalid penalty bookkeeping".to_string()
        })?;
    let full_penalty_ranges = smooth_penalty_ranges
        .iter()
        .map(|range| (fixed_penalty_offset + range.start)..(fixed_penalty_offset + range.end))
        .collect::<Vec<_>>();
    let fixed_blocks = build_term_collection_fixed_blocks(data, &spec)
        .map_err(|e| format!("failed to cache fixed term-collection blocks: {e}"))?;
    let mut dropped_penaltyinfo_by_term = Vec::with_capacity(spec.smooth_terms.len());
    for (term_idx, termspec) in spec.smooth_terms.iter().enumerate() {
        let realization =
            build_single_smooth_term_realization(data, termspec).map_err(|e| {
                format!(
                    "failed to build cached realization for smooth term '{}' (index {}): {e}",
                    termspec.name, term_idx
                )
            })?;
        let expected_cols = design.smooth.terms[term_idx].coeff_range.len();
        if realization.design_local.ncols() != expected_cols {
            return Err(format!(
                "cached realization width mismatch for term '{}': cached_cols={}, design_cols={}",
                termspec.name,
                realization.design_local.ncols(),
                expected_cols
            ));
        }
        // Compute the active-penalty count once instead of materializing the
        // cloned info vector twice (once for the check, once for the message).
        let cached_penalties = realization.active_penaltyinfo().len();
        if cached_penalties != design.smooth.terms[term_idx].penalties_local.len() {
            return Err(format!(
                "cached realization penalty mismatch for term '{}': cached_penalties={}, design_penalties={}",
                termspec.name,
                cached_penalties,
                design.smooth.terms[term_idx].penalties_local.len()
            ));
        }
        dropped_penaltyinfo_by_term.push(realization.dropped_penaltyinfo);
    }
    Ok(Self {
        data,
        spec,
        design,
        fixed_blocks,
        dropped_penaltyinfo_by_term,
        smooth_penalty_ranges,
        full_penalty_ranges,
        basisworkspace: crate::basis::BasisWorkspace::new(),
    })
}
/// Read-only access to the (possibly κ-updated) term-collection spec.
fn spec(&self) -> &TermCollectionSpec {
&self.spec
}
/// Read-only access to the (possibly κ-updated) realized design.
fn design(&self) -> &TermCollectionDesign {
&self.design
}
/// Applies new per-term log-κ values to the listed smooth terms, rebuilding
/// only terms whose values actually changed (bitwise). When anything changed,
/// the full design matrix and all derived auxiliary state are refreshed so
/// the cached design stays consistent.
fn apply_log_kappa(
&mut self,
log_kappa: &SpatialLogKappaCoords,
term_indices: &[usize],
) -> Result<(), String> {
if term_indices.len() != log_kappa.dims_per_term().len() {
return Err(format!(
"incremental realizer log-kappa term mismatch: term_indices={}, dims_per_term={}",
term_indices.len(),
log_kappa.dims_per_term().len()
));
}
let mut any_changed = false;
for (slot, &term_idx) in term_indices.iter().enumerate() {
// `slot` indexes into `log_kappa`; `term_idx` into the spec's terms.
any_changed |= self.apply_log_kappa_to_term(term_idx, log_kappa.term_slice(slot))?;
}
if any_changed {
self.refresh_full_design_operator()?;
rebuild_smooth_auxiliary_state(
&mut self.design.smooth,
&self.dropped_penaltyinfo_by_term,
)?;
rebuild_term_collection_auxiliary_state(&self.spec, &mut self.design)?;
}
Ok(())
}
/// Applies one term's ψ (log-κ) vector. Returns `Ok(false)` when the implied
/// length scale and anisotropy are bitwise-identical to the current spec
/// (nothing to do); `Ok(true)` after the term has been rebuilt in place.
fn apply_log_kappa_to_term(&mut self, term_idx: usize, psi: &[f64]) -> Result<bool, String> {
if !spatial_term_supports_hyper_optimization(&self.spec, term_idx) {
return Err(format!(
"incremental realizer term {term_idx} does not expose spatial hyperparameters"
));
}
let current_length_scale = get_spatial_length_scale(&self.spec, term_idx);
let current_aniso = get_spatial_aniso_log_scales(&self.spec, term_idx);
let (next_length_scale, next_aniso) =
spatial_term_psi_to_length_scale_and_aniso(&self.spec, term_idx, psi);
// Bitwise comparison: any representational change (even NaN payloads or
// signed zeros) triggers a rebuild.
let same_length = spatial_length_scale_matches(current_length_scale, next_length_scale);
let same_aniso = spatial_aniso_matches(current_aniso.as_deref(), next_aniso.as_deref());
if same_length && same_aniso {
return Ok(false);
}
if let Some(length_scale) = next_length_scale {
set_spatial_length_scale(&mut self.spec, term_idx, length_scale)
.map_err(|e| e.to_string())?;
}
if let Some(eta) = next_aniso {
set_spatial_aniso_log_scales(&mut self.spec, term_idx, eta)
.map_err(|e| e.to_string())?;
}
// Clone the updated spec so the rebuild does not alias `self.spec`.
let termspec = self
.spec
.smooth_terms
.get(term_idx)
.ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
.clone();
let realization = build_single_smooth_term_realization_withworkspace(
self.data,
&termspec,
&mut self.basisworkspace,
)
.map_err(|e| {
format!(
"failed to rebuild smooth term '{}' during incremental κ realization: {e}",
termspec.name
)
})?;
self.replace_term_realization(term_idx, realization)?;
Ok(true)
}
/// Swaps a freshly rebuilt single-term realization into the cached design in
/// place: the term's design columns, its penalty matrices (in both the smooth
/// and the full penalty lists), nullspace dims, penalty infos, and constraint
/// caches. The term's column/penalty topology must be unchanged; any mismatch
/// is reported as an error and nothing past that point is mutated.
fn replace_term_realization(
&mut self,
term_idx: usize,
realization: SingleSmoothTermRealization,
) -> Result<(), String> {
let t_replace = std::time::Instant::now();
let SingleSmoothTermRealization {
design_local,
term,
dropped_penaltyinfo,
} = realization;
let SmoothTerm {
name,
penalties_local,
nullspace_dims,
penaltyinfo_local,
metadata,
lower_bounds_local,
linear_constraints_local,
..
} = term;
let coeff_range = self
.design
.smooth
.terms
.get(term_idx)
.ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
.coeff_range
.clone();
// The rebuilt term must keep the same column footprint ...
if design_local.ncols() != coeff_range.len() {
return Err(format!(
"incremental realizer width mismatch for term {}: rebuilt_cols={}, cached_cols={}",
term_idx,
design_local.ncols(),
coeff_range.len()
));
}
// ... and the same row count as the assembled full design.
if design_local.nrows() != self.design.design.nrows() {
return Err(format!(
"incremental realizer row mismatch for term {}: rebuilt_rows={}, design_rows={}",
term_idx,
design_local.nrows(),
self.design.design.nrows()
));
}
let active_penaltyinfo = penaltyinfo_local
.iter()
.filter(|info| info.active)
.cloned()
.collect::<Vec<_>>();
let smooth_penalty_range = self
.smooth_penalty_ranges
.get(term_idx)
.ok_or_else(|| {
format!("incremental realizer missing smooth penalty range for term {term_idx}")
})?
.clone();
let full_penalty_range = self
.full_penalty_ranges
.get(term_idx)
.ok_or_else(|| {
format!("incremental realizer missing full penalty range for term {term_idx}")
})?
.clone();
// Penalty topology (counts of penalties/infos/nullspaces) must match the
// ranges computed at construction time.
if active_penaltyinfo.len() != smooth_penalty_range.len()
|| penalties_local.len() != smooth_penalty_range.len()
|| nullspace_dims.len() != smooth_penalty_range.len()
{
return Err(format!(
"incremental realizer topology changed for term '{}': penalties={}, infos={}, nullspaces={}, cached_penalties={}",
name,
penalties_local.len(),
active_penaltyinfo.len(),
nullspace_dims.len(),
smooth_penalty_range.len()
));
}
self.design.smooth.term_designs[term_idx] = design_local;
for (offset, penalty_local) in penalties_local.iter().enumerate() {
let smooth_penalty_idx = smooth_penalty_range.start + offset;
let full_penalty_idx = full_penalty_range.start + offset;
let nullspace_dim = nullspace_dims[offset];
let penalty_info = active_penaltyinfo[offset].clone();
if penalty_local.nrows() != coeff_range.len()
|| penalty_local.ncols() != coeff_range.len()
{
return Err(format!(
"incremental realizer penalty shape mismatch for term '{}' penalty {}: \
penalty is {}x{} but coeff_range has {} columns",
name,
offset,
penalty_local.nrows(),
penalty_local.ncols(),
coeff_range.len()
));
}
// Overwrite the stored penalties in place (assign) so surrounding
// storage and indices stay valid.
let smooth_penalty = self
.design
.smooth
.penalties
.get_mut(smooth_penalty_idx)
.ok_or_else(|| {
format!(
"incremental realizer smooth penalty {} out of range for term {}",
smooth_penalty_idx, term_idx
)
})?;
smooth_penalty.local.assign(penalty_local);
let full_bp = self
.design
.penalties
.get_mut(full_penalty_idx)
.ok_or_else(|| {
format!(
"incremental realizer full penalty {} out of range for term {}",
full_penalty_idx, term_idx
)
})?;
full_bp.local.assign(penalty_local);
self.design.smooth.nullspace_dims[smooth_penalty_idx] = nullspace_dim;
self.design.nullspace_dims[full_penalty_idx] = nullspace_dim;
self.design.smooth.penaltyinfo[smooth_penalty_idx].global_index = smooth_penalty_idx;
self.design.smooth.penaltyinfo[smooth_penalty_idx].termname = Some(name.clone());
self.design.smooth.penaltyinfo[smooth_penalty_idx].penalty = penalty_info.clone();
self.design.penaltyinfo[full_penalty_idx].global_index = full_penalty_idx;
self.design.penaltyinfo[full_penalty_idx].termname = Some(name.clone());
self.design.penaltyinfo[full_penalty_idx].penalty = penalty_info;
}
let target_term = self.design.smooth.terms.get_mut(term_idx).ok_or_else(|| {
format!("incremental realizer smooth term {term_idx} disappeared during replacement")
})?;
target_term.penalties_local = penalties_local;
target_term.nullspace_dims = nullspace_dims;
target_term.penaltyinfo_local = penaltyinfo_local;
target_term.metadata = metadata;
target_term.lower_bounds_local = lower_bounds_local;
target_term.linear_constraints_local = linear_constraints_local;
self.dropped_penaltyinfo_by_term[term_idx] = dropped_penaltyinfo;
log::info!(
"[STAGE] smooth basis rebuild (term {}, '{}', cols={}): {:.3}s",
term_idx,
target_term.name,
coeff_range.len(),
t_replace.elapsed().as_secs_f64(),
);
Ok(())
}
/// Reassembles the full design matrix from the cached fixed blocks followed
/// by the (possibly rebuilt) per-term smooth designs.
fn refresh_full_design_operator(&mut self) -> Result<(), String> {
    let term_designs = &self.design.smooth.term_designs;
    let mut blocks: Vec<DesignBlock> =
        Vec::with_capacity(self.fixed_blocks.len() + term_designs.len());
    blocks.extend(self.fixed_blocks.iter().cloned());
    blocks.extend(term_designs.iter().map(DesignBlock::from));
    self.design.design = assemble_term_collection_design_matrix(blocks)
        .map_err(|e| format!("failed to refresh term-collection design: {e}"))?;
    Ok(())
}
}
/// Builds the design blocks that do not depend on spatial hyperparameters:
/// the intercept, one dense column per linear term, and one operator per
/// random-effect term. These are cached once and reused across κ rebuilds.
fn build_term_collection_fixed_blocks(
data: ArrayView2<'_, f64>,
spec: &TermCollectionSpec,
) -> Result<Vec<DesignBlock>, BasisError> {
let mut blocks = Vec::<DesignBlock>::new();
blocks.push(DesignBlock::Intercept(data.nrows()));
if !spec.linear_terms.is_empty() {
// Validate every feature column before allocating the dense block.
for linear in &spec.linear_terms {
if linear.feature_col >= data.ncols() {
return Err(BasisError::DimensionMismatch(format!(
"linear term '{}' feature column {} out of bounds for {} columns",
linear.name,
linear.feature_col,
data.ncols()
)));
}
}
let mut linear_block = Array2::<f64>::zeros((data.nrows(), spec.linear_terms.len()));
for (j, linear) in spec.linear_terms.iter().enumerate() {
linear_block
.column_mut(j)
.assign(&data.column(linear.feature_col));
}
blocks.push(DesignBlock::Dense(crate::matrix::DenseDesignMatrix::from(
linear_block,
)));
}
for term in &spec.random_effect_terms {
let block = build_random_effect_block(data, term)?;
let re_op = RandomEffectOperator::new(block.group_ids, block.num_groups);
blocks.push(DesignBlock::RandomEffect(Arc::new(re_op)));
}
Ok(blocks)
}
/// Output of a spatial length-scale optimization: the specs with resolved
/// hyperparameters, the matching realized designs, and the final fit.
pub struct SpatialLengthScaleOptimizationResult<FitOut> {
pub resolved_specs: Vec<TermCollectionSpec>,
pub designs: Vec<TermCollectionDesign>,
pub fit: FitOut,
}
#[derive(Debug, Clone)]
/// Seeds and box bounds for the exact-joint hyperparameter vector, laid out
/// as [rho | log_kappa | auxiliary].
pub struct ExactJointHyperSetup {
rho0: Array1<f64>,
rho_lower: Array1<f64>,
rho_upper: Array1<f64>,
log_kappa0: SpatialLogKappaCoords,
log_kappa_lower: SpatialLogKappaCoords,
log_kappa_upper: SpatialLogKappaCoords,
// Trailing auxiliary parameters; empty unless attached by the builder.
auxiliary0: Array1<f64>,
auxiliary_lower: Array1<f64>,
auxiliary_upper: Array1<f64>,
}
impl ExactJointHyperSetup {
/// Clamps each seed entry into `[rho_lower[i], rho_upper[i]]`; non-finite
/// entries are replaced by `0.0` clamped into the same interval.
///
/// NOTE(review): `f64::clamp` panics when a bound pair is inverted or NaN —
/// callers are expected to supply well-ordered finite bounds.
fn sanitize_rho_seed(
    rho0: Array1<f64>,
    rho_lower: &Array1<f64>,
    rho_upper: &Array1<f64>,
) -> Array1<f64> {
    let sanitize_entry = |(idx, &raw): (usize, &f64)| -> f64 {
        let lo = rho_lower[idx];
        let hi = rho_upper[idx];
        // Evaluated unconditionally, matching the original behavior even
        // when `raw` is finite.
        let fallback = 0.0_f64.clamp(lo, hi);
        if raw.is_finite() { raw.clamp(lo, hi) } else { fallback }
    };
    Array1::from_iter(rho0.iter().enumerate().map(sanitize_entry))
}
/// Creates a setup with no auxiliary parameters. The rho seed is sanitized
/// (clamped into bounds, non-finite entries replaced) before being stored;
/// log-κ seeds and bounds are taken as given.
pub(crate) fn new(
rho0: Array1<f64>,
rho_lower: Array1<f64>,
rho_upper: Array1<f64>,
log_kappa0: SpatialLogKappaCoords,
log_kappa_lower: SpatialLogKappaCoords,
log_kappa_upper: SpatialLogKappaCoords,
) -> Self {
let rho0 = Self::sanitize_rho_seed(rho0, &rho_lower, &rho_upper);
Self {
rho0,
rho_lower,
rho_upper,
log_kappa0,
log_kappa_lower,
log_kappa_upper,
auxiliary0: Array1::zeros(0),
auxiliary_lower: Array1::zeros(0),
auxiliary_upper: Array1::zeros(0),
}
}
/// Attaches trailing auxiliary parameters (builder style). The seed is
/// sanitized against the given bounds exactly like the rho seed.
///
/// # Panics
/// Panics if either bound vector's length differs from `auxiliary0`'s.
pub(crate) fn with_auxiliary(
mut self,
auxiliary0: Array1<f64>,
auxiliary_lower: Array1<f64>,
auxiliary_upper: Array1<f64>,
) -> Self {
assert_eq!(
auxiliary0.len(),
auxiliary_lower.len(),
"auxiliary lower bound length mismatch"
);
assert_eq!(
auxiliary0.len(),
auxiliary_upper.len(),
"auxiliary upper bound length mismatch"
);
self.auxiliary0 = Self::sanitize_rho_seed(auxiliary0, &auxiliary_lower, &auxiliary_upper);
self.auxiliary_lower = auxiliary_lower;
self.auxiliary_upper = auxiliary_upper;
self
}
/// Number of leading log-smoothing (rho) parameters.
pub(crate) fn rho_dim(&self) -> usize {
self.rho0.len()
}
/// Total number of spatial log-κ parameters across all terms.
pub(crate) fn log_kappa_dim(&self) -> usize {
self.log_kappa0.len()
}
/// Number of trailing auxiliary parameters.
pub(crate) fn auxiliary_dim(&self) -> usize {
self.auxiliary0.len()
}
pub(crate) fn theta0(&self) -> Array1<f64> {
let mut out =
Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
out.slice_mut(s![..self.rho_dim()]).assign(&self.rho0);
out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
.assign(self.log_kappa0.as_array());
out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
.assign(&self.auxiliary0);
out
}
pub(crate) fn lower(&self) -> Array1<f64> {
let mut out =
Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_lower);
out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
.assign(self.log_kappa_lower.as_array());
out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
.assign(&self.auxiliary_lower);
out
}
pub(crate) fn upper(&self) -> Array1<f64> {
let mut out =
Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_upper);
out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
.assign(self.log_kappa_upper.as_array());
out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
.assign(&self.auxiliary_upper);
out
}
/// Per-term log-κ dimension counts, cloned from the initial coordinates.
pub(crate) fn log_kappa_dims_per_term(&self) -> Vec<usize> {
self.log_kappa0.dims_per_term().to_vec()
}
}
/// Memoizing cache for exact-joint evaluations: one incremental realizer per
/// block, plus the theta the designs currently reflect and the last cost/eval
/// computed at that theta.
struct ExactJointDesignCache<'d> {
realizers: Vec<FrozenTermCollectionIncrementalRealizer<'d>>,
// For each block, the global smooth-term indices it owns.
block_term_indices: Vec<Vec<usize>>,
// Theta the realizers were last synchronized to (None before first use).
current_theta: Option<Array1<f64>>,
// Cost-only memo; `last_eval` takes precedence when present.
last_cost: Option<f64>,
// Full (cost, gradient, Hessian) memo for `current_theta`.
last_eval: Option<(
f64,
Array1<f64>,
crate::solver::outer_strategy::HessianResult,
)>,
rho_dim: usize,
// Per-term log-κ dimensions; their sum equals `log_kappa_dim`.
all_dims: Vec<usize>,
log_kappa_dim: usize,
// Number of terms per block, in the same order as `realizers`.
block_term_counts: Vec<usize>,
}
impl<'d> ExactJointDesignCache<'d> {
/// Builds the cache from per-block (spec, design, term-index) triples.
/// `rho_dim` counts the leading smoothing parameters in theta; `all_dims`
/// lists the per-term log-κ dimensions (their sum is the log-κ tail width).
fn new(
data: ArrayView2<'d, f64>,
blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)>,
rho_dim: usize,
all_dims: Vec<usize>,
) -> Result<Self, String> {
let n_blocks = blocks.len();
let mut realizers = Vec::with_capacity(n_blocks);
let mut block_term_indices = Vec::with_capacity(n_blocks);
let mut block_term_counts = Vec::with_capacity(n_blocks);
for (spec, design, terms) in blocks {
block_term_counts.push(terms.len());
block_term_indices.push(terms);
realizers.push(FrozenTermCollectionIncrementalRealizer::new(
data, spec, design,
)?);
}
Ok(Self {
realizers,
block_term_indices,
current_theta: None,
last_cost: None,
last_eval: None,
rho_dim,
log_kappa_dim: all_dims.iter().sum(),
all_dims,
block_term_counts,
})
}
/// Makes the cached designs consistent with `theta` = [rho | log_kappa | aux].
/// Fast no-op when `theta` is bitwise-identical to the cached vector;
/// otherwise splits the log-κ tail across blocks, applies it, and invalidates
/// the memoized cost/eval.
fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
if self
.current_theta
.as_ref()
.is_some_and(|cached| theta_values_match(cached, theta))
{
return Ok(());
}
let t_ensure = std::time::Instant::now();
let kappa_theta_len = self.rho_dim + self.log_kappa_dim;
// Trailing auxiliary entries (if any) are ignored here; only the
// rho + log-κ prefix drives design rebuilds.
if theta.len() < kappa_theta_len {
return Err(format!(
"exact-joint theta length mismatch: got {}, expected at least {} (rho_dim={}, log_kappa_dim={})",
theta.len(),
kappa_theta_len,
self.rho_dim,
self.log_kappa_dim
));
}
let theta_kappa = theta.slice(s![..kappa_theta_len]).to_owned();
let full_log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
&theta_kappa,
self.rho_dim,
self.all_dims.clone(),
);
let n = self.realizers.len();
// Hand each block its chunk of per-term coordinates, in block order.
let mut remaining = full_log_kappa;
for block_idx in 0..n {
let count = self.block_term_counts[block_idx];
if block_idx < n - 1 {
let (block_lk, rest) = remaining.split_at(count);
self.realizers[block_idx]
.apply_log_kappa(&block_lk, &self.block_term_indices[block_idx])?;
remaining = rest;
} else {
// Last block consumes whatever remains without splitting.
self.realizers[block_idx]
.apply_log_kappa(&remaining, &self.block_term_indices[block_idx])?;
}
}
log::info!(
"[STAGE] ensure_theta (n-block, {} blocks, {} realizers): {:.3}s",
n,
self.realizers.len(),
t_ensure.elapsed().as_secs_f64(),
);
self.current_theta = Some(theta.clone());
// Any memoized results were computed for the previous theta.
self.last_cost = None;
self.last_eval = None;
Ok(())
}
/// Returns the memoized cost for `theta`, preferring the full eval's cost
/// over the cost-only memo. `None` when `theta` differs (bitwise) from the
/// cached vector.
fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
    let cached = self.current_theta.as_ref()?;
    if !theta_values_match(cached, theta) {
        return None;
    }
    match self.last_eval.as_ref() {
        Some(eval) => Some(eval.0),
        None => self.last_cost,
    }
}
/// Returns a clone of the memoized (cost, gradient, Hessian) triple when
/// `theta` matches (bitwise) the cached vector and a full eval was stored.
fn memoized_eval(
    &self,
    theta: &Array1<f64>,
) -> Option<(
    f64,
    Array1<f64>,
    crate::solver::outer_strategy::HessianResult,
)> {
    let cached = self.current_theta.as_ref()?;
    theta_values_match(cached, theta)
        .then(|| self.last_eval.clone())
        .flatten()
}
/// Stores a full (cost, gradient, Hessian) evaluation, also mirroring the
/// cost into the cost-only memo.
fn store_eval(
    &mut self,
    eval: (
        f64,
        Array1<f64>,
        crate::solver::outer_strategy::HessianResult,
    ),
) {
    let cost = eval.0;
    self.last_cost = Some(cost);
    self.last_eval = Some(eval);
}
/// Stores a cost-only result, but only when `theta` matches (bitwise) the
/// theta the designs currently reflect; otherwise the value is discarded.
fn store_cost_only(&mut self, theta: &Array1<f64>, cost: f64) {
    let matches_current = self
        .current_theta
        .as_ref()
        .is_some_and(|cached| theta_values_match(cached, theta));
    if matches_current {
        self.last_cost = Some(cost);
    }
}
/// Current per-block specs, in block order.
fn specs(&self) -> Vec<&TermCollectionSpec> {
    let mut out = Vec::with_capacity(self.realizers.len());
    for realizer in &self.realizers {
        out.push(realizer.spec());
    }
    out
}
/// Current per-block realized designs, in block order.
fn designs(&self) -> Vec<&TermCollectionDesign> {
    let mut out = Vec::with_capacity(self.realizers.len());
    for realizer in &self.realizers {
        out.push(realizer.design());
    }
    out
}
}
/// Maps a likelihood family to the risk profile used when generating
/// multistart seeds: Gaussian and Royston–Parmar (survival) families get
/// dedicated profiles; every other family uses the generalized-linear one.
pub(crate) fn seed_risk_profile_for_likelihood_family(
family: LikelihoodFamily,
) -> crate::seeding::SeedRiskProfile {
match family {
LikelihoodFamily::GaussianIdentity => crate::seeding::SeedRiskProfile::Gaussian,
LikelihoodFamily::RoystonParmar => crate::seeding::SeedRiskProfile::Survival,
LikelihoodFamily::BinomialLogit
| LikelihoodFamily::BinomialProbit
| LikelihoodFamily::BinomialCLogLog
| LikelihoodFamily::BinomialLatentCLogLog
| LikelihoodFamily::BinomialSas
| LikelihoodFamily::BinomialBetaLogistic
| LikelihoodFamily::BinomialMixture
| LikelihoodFamily::PoissonLog
| LikelihoodFamily::GammaLog => crate::seeding::SeedRiskProfile::GeneralizedLinear,
}
}
/// Assembles the outer optimization problem for the exact-joint multistart
/// search over theta = [rho | log_kappa | auxiliary].
///
/// The heuristic seed exponentiates the leading `rho_dim` entries
/// (log-smoothing parameters back to lambda scale, per
/// `with_heuristic_lambdas`) before handing them to the seeder.
/// Panics if `rho_dim` exceeds `theta0.len()` (slice out of range).
pub(crate) fn exact_joint_multistart_outer_problem(
theta0: &Array1<f64>,
lower: &Array1<f64>,
upper: &Array1<f64>,
rho_dim: usize,
auxiliary_dim: usize,
n_params: usize,
gradient: crate::solver::outer_strategy::Derivative,
hessian: crate::solver::outer_strategy::DeclaredHessianForm,
prefer_gradient_only: bool,
disable_fixed_point: bool,
risk_profile: crate::seeding::SeedRiskProfile,
tolerance: f64,
max_iter: usize,
bfgs_step_cap: Option<f64>,
screening_cap: Option<Arc<AtomicUsize>>,
) -> crate::solver::outer_strategy::OuterProblem {
let mut seed_heuristic = theta0.to_vec();
for value in &mut seed_heuristic[..rho_dim] {
*value = value.exp();
}
let mut problem = crate::solver::outer_strategy::OuterProblem::new(n_params)
.with_gradient(gradient)
.with_hessian(hessian)
.with_prefer_gradient_only(prefer_gradient_only)
.with_disable_fixed_point(disable_fixed_point)
.with_fallback_policy(crate::solver::outer_strategy::FallbackPolicy::Automatic)
.with_psi_dim(auxiliary_dim)
.with_tolerance(tolerance)
.with_max_iter(max_iter)
.with_bounds(lower.clone(), upper.clone())
.with_initial_rho(theta0.clone())
.with_bfgs_step_cap(bfgs_step_cap)
.with_seed_config(crate::seeding::SeedConfig {
max_seeds: 4,
seed_budget: 2,
risk_profile,
num_auxiliary_trailing: auxiliary_dim,
..Default::default()
})
.with_rho_bound(12.0)
.with_heuristic_lambdas(seed_heuristic);
// Screening is optional; when enabled, the initial rho is screened too.
if let Some(screening_cap) = screening_cap {
problem = problem
.with_screening_cap(screening_cap)
.with_screen_initial_rho(true);
}
problem
}
/// Jointly optimizes spatial length-scale (log-kappa) hyperparameters with the
/// remaining outer hyperparameters via an exact-joint outer problem, then runs
/// the final coefficient fit at the optimum.
///
/// `fit_fn` performs the final fit at a given theta; `exact_fn` evaluates the
/// exact joint outer objective (value/gradient/Hessian) restricted to a
/// `RowSet`; `exact_efs_fn` supplies EFS evaluations for the outer strategy.
/// Returns the resolved specs, the designs realized at the optimum, and the
/// final fit. Errors are surfaced as `String`s.
pub fn optimize_spatial_length_scale_exact_joint<FitOut, FitFn, ExactFn, ExactEfsFn>(
    data: ArrayView2<'_, f64>,
    block_specs: &[TermCollectionSpec],
    block_term_indices: &[Vec<usize>],
    kappa_options: &SpatialLengthScaleOptimizationOptions,
    joint_setup: &ExactJointHyperSetup,
    seed_risk_profile: crate::seeding::SeedRiskProfile,
    analytic_joint_gradient_available: bool,
    analytic_joint_hessian_available: bool,
    disable_fixed_point: bool,
    screening_cap: Option<Arc<AtomicUsize>>,
    outer_derivative_policy: crate::families::custom_family::OuterDerivativePolicy,
    mut fit_fn: FitFn,
    mut exact_fn: ExactFn,
    mut exact_efs_fn: ExactEfsFn,
) -> Result<SpatialLengthScaleOptimizationResult<FitOut>, String>
where
    FitOut: Clone,
    FitFn: FnMut(
        &Array1<f64>,
        &[TermCollectionSpec],
        &[TermCollectionDesign],
    ) -> Result<FitOut, String>,
    ExactFn: FnMut(
        &Array1<f64>,
        &[TermCollectionSpec],
        &[TermCollectionDesign],
        crate::solver::estimate::reml::unified::EvalMode,
        &crate::families::row_kernel::RowSet,
    ) -> Result<
        (
            f64,
            Array1<f64>,
            crate::solver::outer_strategy::HessianResult,
        ),
        String,
    >,
    ExactEfsFn: FnMut(
        &Array1<f64>,
        &[TermCollectionSpec],
        &[TermCollectionDesign],
    ) -> Result<crate::solver::outer_strategy::EfsEval, String>,
{
    let n_blocks = block_specs.len();
    if block_term_indices.len() != n_blocks {
        return Err(format!(
            "block_specs ({}) and block_term_indices ({}) length mismatch",
            n_blocks,
            block_term_indices.len()
        ));
    }
    let log_kappa_dim = joint_setup.log_kappa_dim();
    // Fast path: nothing to optimize jointly — freeze designs once and fit.
    if joint_setup.auxiliary_dim() == 0 && (!kappa_options.enabled || log_kappa_dim == 0) {
        let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
            data, block_specs,
        )
        .map_err(|e| {
            format!("failed to build and freeze joint block designs during exact joint kappa optimization: {e}")
        })?;
        let theta0 = joint_setup.theta0();
        let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
        let design_refs: Vec<TermCollectionDesign> = designs.clone();
        let fit = fit_fn(&theta0, &spec_refs, &design_refs)?;
        return Ok(SpatialLengthScaleOptimizationResult {
            resolved_specs,
            designs,
            fit,
        });
    }
    let theta0 = joint_setup.theta0();
    let lower = joint_setup.lower();
    let upper = joint_setup.upper();
    // theta packs [rho | auxiliary/psi | log-kappa]; the log-kappa block sits
    // in the trailing coordinates, so theta must be at least that long and
    // the bound vectors must match its length.
    if theta0.len() < log_kappa_dim || lower.len() != theta0.len() || upper.len() != theta0.len() {
        return Err(format!(
            "invalid exact joint theta setup: theta0={}, lower={}, upper={}, required_log_kappa_dim={}",
            theta0.len(),
            lower.len(),
            upper.len(),
            log_kappa_dim
        ));
    }
    let rho_dim = joint_setup.rho_dim();
    let all_dims = joint_setup.log_kappa_dims_per_term();
    // Bootstrap: realize designs at the initial spec so the per-theta design
    // cache below has a baseline to perturb from.
    let (boot_designs, best_specs) = build_term_collection_designs_and_freeze_joint(
        data,
        block_specs,
    )
    .map_err(|e| {
        format!(
            "failed to build and freeze joint block designs during exact joint kappa bootstrap: {e}"
        )
    })?;
    let policy_hessian_form = outer_derivative_policy.declared_hessian_form();
    // The analytic outer Hessian is only usable when the derivative policy
    // declares a form the outer strategy can consume.
    let analytic_outer_hessian_available = analytic_joint_hessian_available
        && matches!(
            policy_hessian_form,
            crate::solver::outer_strategy::DeclaredHessianForm::Either
                | crate::solver::outer_strategy::DeclaredHessianForm::Dense
                | crate::solver::outer_strategy::DeclaredHessianForm::Operator { .. }
        );
    let prefer_gradient_only = !analytic_outer_hessian_available;
    let theta_dim = theta0.len();
    let psi_dim = theta_dim - rho_dim;
    let cache_blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)> = best_specs
        .iter()
        .zip(boot_designs.iter())
        .zip(block_term_indices.iter())
        .map(|((spec, design), terms)| (spec.clone(), design.clone(), terms.clone()))
        .collect();
    // Mutable optimizer state: a memoizing design/evaluation cache keyed by
    // theta, shared by all objective closures below.
    struct NBlockExactJointState<'d> {
        cache: ExactJointDesignCache<'d>,
    }
    let mut state = NBlockExactJointState {
        cache: ExactJointDesignCache::new(data, cache_blocks, rho_dim, all_dims.clone())?,
    };
    // Staged (pilot + polish) subsampling thresholds for large n.
    const KAPPA_SUBSAMPLE_PILOT_TRIGGER_N: usize = 30_000;
    const KAPPA_PILOT_K: usize = 5_000;
    const KAPPA_POLISH_K: usize = 25_000;
    const KAPPA_POLISH_TRIGGER_N: usize = 100_000;
    // Suppress unused warnings: which constants are reached depends on the
    // staged-kappa branch taken at runtime.
    let _ = KAPPA_SUBSAMPLE_PILOT_TRIGGER_N;
    let _ = KAPPA_POLISH_TRIGGER_N;
    let _ = KAPPA_POLISH_K;
    let _ = KAPPA_PILOT_K;
    let n_total = data.nrows();
    let use_staged_kappa = outer_derivative_policy.should_use_staged_kappa(n_total);
    if use_staged_kappa {
        log::info!(
            "[KAPPA-STAGED] auto-engaging pilot+polish schedule: n={} pilot_k={} polish_k={}",
            n_total,
            KAPPA_PILOT_K,
            KAPPA_POLISH_K,
        );
    }
    /// Deterministic uniform subsample of `k_target` distinct row indices out
    /// of `n_total`, using Floyd's combination sampling driven by a
    /// splitmix64 stream so the same seed reproduces the same subset.
    fn build_uniform_pilot_subsample(
        n_total: usize,
        k_target: usize,
        seed: u64,
    ) -> crate::families::marginal_slope_shared::OuterScoreSubsample {
        use crate::families::marginal_slope_shared::OuterScoreSubsample;
        let k = k_target.min(n_total);
        if k == 0 || n_total == 0 {
            return OuterScoreSubsample::new(Vec::new(), n_total, seed);
        }
        let mut mask: Vec<usize> = Vec::with_capacity(k);
        let mut state = seed.wrapping_add(0x9E3779B97F4A7C15);
        // splitmix64 step: cheap deterministic pseudo-random generator.
        let splitmix = |s: &mut u64| -> u64 {
            *s = s.wrapping_add(0x9E3779B97F4A7C15);
            let mut z = *s;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
            z ^ (z >> 31)
        };
        let mut taken = std::collections::HashSet::with_capacity(k);
        // Floyd's algorithm: for j in n-k..n draw r in [0, j]; if r is
        // already taken, take j instead. `insert` returns false when the
        // value was already present, so `!insert(r)` means "r collided".
        for j in (n_total - k)..n_total {
            let r = (splitmix(&mut state) % (j as u64 + 1)) as usize;
            if !taken.insert(r) {
                taken.insert(j);
                mask.push(j);
            } else {
                mask.push(r);
            }
        }
        // Floyd guarantees distinct indices; dedup is a belt-and-braces pass.
        mask.sort_unstable();
        mask.dedup();
        OuterScoreSubsample::new(mask, n_total, seed)
    }
    // Row set the exact evaluations run against: the pilot subsample under
    // the staged schedule, otherwise all rows. RefCell because the objective
    // closures below only get shared access.
    let current_row_set: std::cell::RefCell<crate::families::row_kernel::RowSet> =
        if use_staged_kappa {
            let pilot = build_uniform_pilot_subsample(n_total, KAPPA_PILOT_K, n_total as u64);
            std::cell::RefCell::new(crate::families::row_kernel::RowSet::Subsample {
                rows: std::sync::Arc::clone(&pilot.rows),
                n_full: n_total,
            })
        } else {
            std::cell::RefCell::new(crate::families::row_kernel::RowSet::All)
        };
    // The FnMut callbacks are shared by multiple closures; RefCell arbitrates
    // the mutable borrow at runtime.
    let exact_fn_cell = std::cell::RefCell::new(&mut exact_fn);
    let exact_efs_fn_cell = std::cell::RefCell::new(&mut exact_efs_fn);
    use std::cell::Cell;
    // Per-phase instrumentation: call counts and wall-clock totals.
    let kphase_cost_calls: Cell<usize> = Cell::new(0);
    let kphase_cost_total_s: Cell<f64> = Cell::new(0.0);
    let kphase_eval_calls: Cell<usize> = Cell::new(0);
    let kphase_eval_total_s: Cell<f64> = Cell::new(0.0);
    let kphase_efs_calls: Cell<usize> = Cell::new(0);
    let kphase_efs_total_s: Cell<f64> = Cell::new(0.0);
    let kphase_optim_start = std::time::Instant::now();
    let kphase_log_kappa_dim = log_kappa_dim;
    // Norms for logging: full-theta norm and the norm of the trailing
    // log-kappa block only.
    let kphase_log_norms = |theta: &Array1<f64>| -> (f64, f64) {
        let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
        let log_kappa_norm = if kphase_log_kappa_dim > 0 && theta.len() >= kphase_log_kappa_dim {
            let start = theta.len() - kphase_log_kappa_dim;
            theta.iter().skip(start).map(|v| v * v).sum::<f64>().sqrt()
        } else {
            0.0
        };
        (theta_norm, log_kappa_norm)
    };
    use crate::solver::outer_strategy::{
        DeclaredHessianForm, Derivative, OuterEval, OuterEvalOrder,
    };
    let problem = exact_joint_multistart_outer_problem(
        &theta0,
        &lower,
        &upper,
        rho_dim,
        psi_dim,
        theta_dim,
        if analytic_joint_gradient_available {
            Derivative::Analytic
        } else {
            Derivative::Unavailable
        },
        if analytic_outer_hessian_available {
            DeclaredHessianForm::Either
        } else {
            DeclaredHessianForm::Unavailable
        },
        prefer_gradient_only,
        disable_fixed_point,
        seed_risk_profile,
        kappa_options.rel_tol.max(1e-6),
        kappa_options.max_outer_iter.max(1),
        Some(kappa_options.log_step.clamp(0.25, 1.0)),
        screening_cap.clone(),
    );
    // Snapshot helpers: clone the cache's current specs/designs into owned
    // vectors for the exact evaluators.
    fn collect_specs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionSpec> {
        cache.specs().into_iter().cloned().collect()
    }
    fn collect_designs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionDesign> {
        cache.designs().into_iter().cloned().collect()
    }
    let result = {
        // Shared gradient/Hessian path: consult the memo cache first, then
        // realize designs at theta and call the exact evaluator.
        let eval_outer = |ctx: &mut &mut NBlockExactJointState<'_>,
                          theta: &Array1<f64>,
                          order: OuterEvalOrder|
         -> Result<OuterEval, EstimationError> {
            if let Some((cost, grad, hess)) = ctx.cache.memoized_eval(theta) {
                // Only reuse a cached triple when it carries enough
                // derivative information for the requested order.
                let cached_satisfies_order = match order {
                    OuterEvalOrder::ValueAndGradient => true,
                    OuterEvalOrder::ValueGradientHessian => hess.is_analytic(),
                };
                if cached_satisfies_order {
                    if !cost.is_finite() {
                        return Ok(OuterEval::infeasible(theta.len()));
                    }
                    if grad.iter().any(|v| !v.is_finite()) {
                        return Err(EstimationError::RemlOptimizationFailed(
                            "n-block exact-joint gradient contained non-finite values".to_string(),
                        ));
                    }
                    return Ok(OuterEval {
                        cost,
                        gradient: grad,
                        hessian: hess,
                    });
                }
            }
            // A failed design realization is treated as an infeasible point
            // (not a hard error) so the optimizer can back off.
            if let Err(err) = ctx.cache.ensure_theta(theta) {
                log::warn!(
                    "[OUTER] n-block exact-joint spatial: ensure_theta failed during gradient evaluation: {err}"
                );
                return Ok(OuterEval::infeasible(theta.len()));
            }
            let specs = collect_specs(&ctx.cache);
            let designs = collect_designs(&ctx.cache);
            // The policy may clamp the requested order; only ask for a
            // Hessian when one is analytically available.
            let clamped = outer_derivative_policy.order_for_evaluation(order);
            let need_hessian = matches!(clamped, OuterEvalOrder::ValueGradientHessian)
                && analytic_outer_hessian_available;
            let eval_mode = if need_hessian {
                crate::solver::estimate::reml::unified::EvalMode::ValueGradientHessian
            } else {
                crate::solver::estimate::reml::unified::EvalMode::ValueAndGradient
            };
            let _t0 = std::time::Instant::now();
            let row_set_borrow = current_row_set.borrow();
            let result = (&mut *exact_fn_cell.borrow_mut())(
                theta,
                &specs,
                &designs,
                eval_mode,
                &*row_set_borrow,
            );
            drop(row_set_borrow);
            let elapsed_s = _t0.elapsed().as_secs_f64();
            kphase_eval_calls.set(kphase_eval_calls.get() + 1);
            kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
            let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
            log::info!(
                "[KAPPA-PHASE] phase=eval_outer call={} order={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
                kphase_eval_calls.get(),
                order,
                theta_norm,
                log_kappa_norm,
                elapsed_s,
            );
            match result {
                Ok((cost, grad, hess)) => {
                    // Memoize before the finiteness checks so infeasible
                    // points are not re-evaluated either.
                    ctx.cache.store_eval((cost, grad.clone(), hess.clone()));
                    if !cost.is_finite() {
                        return Ok(OuterEval::infeasible(theta.len()));
                    }
                    if grad.iter().any(|v| !v.is_finite()) {
                        return Err(EstimationError::RemlOptimizationFailed(
                            "n-block exact-joint gradient contained non-finite values".to_string(),
                        ));
                    }
                    Ok(OuterEval {
                        cost,
                        gradient: grad,
                        hessian: hess,
                    })
                }
                Err(err) => {
                    log::warn!(
                        "[OUTER] n-block exact-joint spatial: exact evaluation failed: {err}"
                    );
                    Ok(OuterEval::infeasible(theta.len()))
                }
            }
        };
        // Wire the outer objective: cost-only, gradient at the default order,
        // gradient at an explicit order, no per-iteration hook, and EFS.
        let mut obj = problem.build_objective_with_eval_order(
            &mut state,
            |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
                if let Some(cost) = ctx.cache.memoized_cost(theta) {
                    return Ok(cost);
                }
                if let Err(err) = ctx.cache.ensure_theta(theta) {
                    log::warn!(
                        "[OUTER] n-block exact-joint spatial: ensure_theta failed during cost evaluation: {err}"
                    );
                    return Ok(f64::INFINITY);
                }
                let specs = collect_specs(&ctx.cache);
                let designs = collect_designs(&ctx.cache);
                let _t0 = std::time::Instant::now();
                let row_set_borrow = current_row_set.borrow();
                let result = (&mut *exact_fn_cell.borrow_mut())(
                    theta,
                    &specs,
                    &designs,
                    crate::solver::estimate::reml::unified::EvalMode::ValueOnly,
                    &*row_set_borrow,
                );
                drop(row_set_borrow);
                let elapsed_s = _t0.elapsed().as_secs_f64();
                kphase_cost_calls.set(kphase_cost_calls.get() + 1);
                kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
                let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
                log::info!(
                    "[KAPPA-PHASE] phase=cost call={} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
                    kphase_cost_calls.get(),
                    theta_norm,
                    log_kappa_norm,
                    elapsed_s,
                );
                match result {
                    Ok((cost, _grad, _hess)) => {
                        ctx.cache.store_cost_only(theta, cost);
                        Ok(cost)
                    }
                    Err(err) => {
                        // Evaluation failure maps to +inf so the optimizer
                        // treats the point as infeasible rather than aborting.
                        log::warn!(
                            "[OUTER] n-block exact-joint spatial: exact cost evaluation failed: {err}"
                        );
                        Ok(f64::INFINITY)
                    }
                }
            },
            |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
                eval_outer(
                    ctx,
                    theta,
                    if analytic_outer_hessian_available {
                        OuterEvalOrder::ValueGradientHessian
                    } else {
                        OuterEvalOrder::ValueAndGradient
                    },
                )
            },
            |ctx: &mut &mut NBlockExactJointState<'_>,
             theta: &Array1<f64>,
             order: OuterEvalOrder| { eval_outer(ctx, theta, order) },
            None::<fn(&mut &mut NBlockExactJointState<'_>)>,
            Some(
                |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
                    ctx.cache
                        .ensure_theta(theta)
                        .map_err(EstimationError::InvalidInput)?;
                    let specs = collect_specs(&ctx.cache);
                    let designs = collect_designs(&ctx.cache);
                    let _t0 = std::time::Instant::now();
                    let eval_result = (&mut *exact_efs_fn_cell.borrow_mut())(
                        theta,
                        &specs,
                        &designs,
                    );
                    let elapsed_s = _t0.elapsed().as_secs_f64();
                    kphase_efs_calls.set(kphase_efs_calls.get() + 1);
                    kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
                    let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
                    log::info!(
                        "[KAPPA-PHASE] phase=efs call={} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
                        kphase_efs_calls.get(),
                        theta_norm,
                        log_kappa_norm,
                        elapsed_s,
                    );
                    let eval = eval_result.map_err(EstimationError::RemlOptimizationFailed)?;
                    Ok(eval)
                },
            ),
        );
        problem
            .run(&mut obj, "n-block exact-joint spatial")
            .map_err(|e| e.to_string())?
    };
    let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
    log::info!(
        "[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} optim_total_s={:.4}",
        kphase_log_kappa_dim,
        kphase_cost_calls.get(),
        kphase_cost_total_s.get(),
        kphase_eval_calls.get(),
        kphase_eval_total_s.get(),
        kphase_efs_calls.get(),
        kphase_efs_total_s.get(),
        kphase_total_s,
    );
    let theta_star = result.rho;
    // Polish stage: for very large n, rotate to a larger subsample and
    // evaluate once at theta_star. The evaluation result is intentionally
    // discarded — presumably this warms internal caches/state inside
    // `exact_fn` before the full-data fit; TODO confirm with the evaluator.
    if use_staged_kappa && n_total >= KAPPA_POLISH_TRIGGER_N {
        let polish = build_uniform_pilot_subsample(
            n_total,
            KAPPA_POLISH_K,
            (n_total as u64).wrapping_add(0xA5A5A5A5),
        );
        *current_row_set.borrow_mut() = crate::families::row_kernel::RowSet::Subsample {
            rows: std::sync::Arc::clone(&polish.rows),
            n_full: n_total,
        };
        log::info!(
            "[KAPPA-STAGED] rotating to polish subsample: k={} at theta_star",
            polish.rows.len(),
        );
        state.cache.ensure_theta(&theta_star)?;
        {
            let specs = collect_specs(&state.cache);
            let designs = collect_designs(&state.cache);
            let row_set_borrow = current_row_set.borrow();
            let _polish_eval = exact_fn(
                &theta_star,
                &specs,
                &designs,
                crate::solver::estimate::reml::unified::EvalMode::ValueAndGradient,
                &*row_set_borrow,
            );
        }
    }
    // The final coefficient fit always uses every row.
    *current_row_set.borrow_mut() = crate::families::row_kernel::RowSet::All;
    if use_staged_kappa {
        log::info!(
            "[KAPPA-STAGED] rotating to full data for final coefficient fit (n={})",
            n_total,
        );
    }
    state.cache.ensure_theta(&theta_star)?;
    let resolved_specs: Vec<TermCollectionSpec> = collect_specs(&state.cache);
    let designs: Vec<TermCollectionDesign> = collect_designs(&state.cache);
    let fit = fit_fn(&theta_star, &resolved_specs, &designs)?;
    for spec in &resolved_specs {
        log_spatial_aniso_scales(spec);
    }
    Ok(SpatialLengthScaleOptimizationResult {
        resolved_specs,
        designs,
        fit,
    })
}
/// Fits a term collection, optionally optimizing spatial length-scale (kappa)
/// hyperparameters for any spatial smooth terms in `spec`.
///
/// Validates inputs and kappa options, optionally applies a deterministic
/// pilot-geometry anisotropy initialization for large n, performs an initial
/// fit, and — when spatial terms remain after freezing — hands off to the
/// exact-joint length-scale optimizer. The returned spec is frozen so that a
/// replay reproduces the realized design.
pub fn fit_term_collectionwith_spatial_length_scale_optimization(
    data: ArrayView2<'_, f64>,
    y: Array1<f64>,
    weights: Array1<f64>,
    offset: Array1<f64>,
    spec: &TermCollectionSpec,
    family: LikelihoodFamily,
    options: &FitOptions,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<FittedTermCollectionWithSpec, EstimationError> {
    let mut resolvedspec = spec.clone();
    let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
    let n = data.nrows();
    // All per-row vectors must align with the data matrix.
    if !(y.len() == n && weights.len() == n && offset.len() == n) {
        return Err(EstimationError::InvalidInput(format!(
            "fit_term_collectionwith_spatial_length_scale_optimization row mismatch: n={}, y={}, weights={}, offset={}",
            n,
            y.len(),
            weights.len(),
            offset.len()
        )));
    }
    // Fast path: kappa optimization disabled or no spatial terms — plain fit
    // followed by a freeze of the realized design.
    if !kappa_options.enabled || spatial_terms.is_empty() {
        let out = fit_term_collection_forspec(
            data,
            y.view(),
            weights.view(),
            offset.view(),
            &resolvedspec,
            family,
            options,
        )?;
        let resolvedspec = freeze_term_collection_from_design(&resolvedspec, &out.design)?;
        return Ok(FittedTermCollectionWithSpec {
            fit: out.fit,
            design: out.design,
            resolvedspec,
            adaptive_diagnostics: out.adaptive_diagnostics,
        });
    }
    // Validate kappa optimization options before doing any work.
    if kappa_options.max_outer_iter == 0 {
        return Err(EstimationError::InvalidInput(
            "spatial kappa optimization requires max_outer_iter >= 1".to_string(),
        ));
    }
    if !(kappa_options.log_step.is_finite() && kappa_options.log_step > 0.0) {
        return Err(EstimationError::InvalidInput(
            "spatial kappa optimization requires log_step > 0".to_string(),
        ));
    }
    if !(kappa_options.min_length_scale.is_finite()
        && kappa_options.max_length_scale.is_finite()
        && kappa_options.min_length_scale > 0.0
        && kappa_options.max_length_scale >= kappa_options.min_length_scale)
    {
        return Err(EstimationError::InvalidInput(
            "spatial kappa optimization requires valid positive length_scale bounds".to_string(),
        ));
    }
    // For very large n, seed anisotropy deterministically from pilot geometry
    // before the initial fit.
    let pilot_threshold = kappa_options.pilot_subsample_threshold;
    if pilot_threshold > 0 && n > pilot_threshold * 2 {
        log::info!(
            "[spatial-kappa] n={n} exceeds pilot threshold {}; using pilot geometry only for deterministic anisotropy initialization",
            pilot_threshold * 2,
        );
        apply_spatial_anisotropy_pilot_initializer(
            data,
            &mut resolvedspec,
            &spatial_terms,
            pilot_threshold,
            kappa_options,
        );
    }
    // Initial fit at the (possibly pilot-initialized) spec.
    let best = fit_term_collection_forspec(
        data,
        y.view(),
        weights.view(),
        offset.view(),
        &resolvedspec,
        family,
        options,
    )?;
    resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
    // Recompute spatial term indices on the frozen spec — freezing can alter
    // how terms resolve (e.g. basis auto-promotion), so the pre-freeze
    // indices must not be reused.
    let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
    sync_aniso_contrasts_from_metadata(&mut resolvedspec, &best.design.smooth);
    let initial_score = fit_score(&best.fit);
    if !initial_score.is_finite() {
        log::debug!("[spatial-kappa] initial profiled score is non-finite");
    }
    if !spatial_terms.is_empty() {
        // Exact-joint path: jointly optimize length scales and refit; the
        // wrapper enforces that the result improves on `initial_score`.
        let exact_joint = require_successful_spatial_optimization_result(
            initial_score,
            try_exact_joint_spatial_length_scale_optimization(
                data,
                y.view(),
                weights.view(),
                offset.view(),
                &resolvedspec,
                &best,
                family,
                options,
                kappa_options,
                &spatial_terms,
            )
            .map(|opt| {
                opt.map(|fit| {
                    let score = fit_score(&fit.fit);
                    (fit, score)
                })
            }),
        )?;
        log_spatial_aniso_scales(&exact_joint.resolvedspec);
        return Ok(exact_joint);
    }
    // No spatial terms survived freezing: return the initial fit as-is.
    Ok(FittedTermCollectionWithSpec {
        fit: best.fit,
        design: best.design,
        resolvedspec,
        adaptive_diagnostics: best.adaptive_diagnostics,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::basis::{
BSplineBasisSpec, BSplineIdentifiability, BSplineKnotSpec, CenterStrategy, DuchonBasisSpec,
DuchonNullspaceOrder, DuchonOperatorPenaltySpec, MaternBasisSpec, MaternIdentifiability,
MaternNu, SpatialIdentifiability, ThinPlateBasisSpec,
};
use crate::estimate::AdaptiveRegularizationOptions;
use crate::faer_ndarray::{FaerEigh, FaerSvd};
use ndarray::array;
use rand::RngExt;
use rand::SeedableRng;
use rand::rngs::StdRng;
/// Asserts that a spatial psi-derivative has the expected coefficient width,
/// checking the implicit operator when present and the dense matrix otherwise.
fn assert_spatial_derivative_width(
    label: &str,
    dense: &Array2<f64>,
    implicit: Option<&crate::terms::basis::ImplicitDesignPsiDerivative>,
    expected: usize,
) {
    match implicit {
        Some(op) => assert_eq!(
            op.p_out(),
            expected,
            "{label} implicit derivative width should match term coefficient width"
        ),
        None => assert_eq!(
            dense.ncols(),
            expected,
            "{label} dense derivative width should match term coefficient width"
        ),
    }
}
/// Numerical rank of `x`: the number of singular values above the usual
/// `max(m, n) * eps * sigma_max` threshold (floored at eps-scale via max(1.0)).
fn numerical_rank(x: &Array2<f64>) -> usize {
    let (_, singular_values, _) = x
        .svd(false, false)
        .expect("SVD should succeed in rank test");
    let largest = singular_values.iter().copied().fold(0.0_f64, f64::max);
    let scale = x.nrows().max(x.ncols()).max(1) as f64;
    let threshold = scale * f64::EPSILON * largest.max(1.0);
    singular_values.iter().filter(|&&sv| sv > threshold).count()
}
/// Euclidean norm of the residual of `y` after orthogonal projection onto the
/// numerical column space of `x` (the span of the first `numerical_rank(x)`
/// left singular vectors).
fn residual_norm_to_column_space(x: &Array2<f64>, y: &Array1<f64>) -> f64 {
    let (u_opt, _, _) = x
        .svd(true, false)
        .expect("SVD should succeed in projection residual test");
    let u = u_opt.expect("left singular vectors should be present");
    let rank = numerical_rank(x);
    let mut proj = Array1::<f64>::zeros(y.len());
    for j in 0..rank.min(u.ncols()) {
        let uj = u.column(j);
        let coeff = uj.dot(y);
        // Accumulate coeff * u_j in place; avoids the per-column `to_owned()`
        // allocation of the original `proj += &(&uj.to_owned() * coeff)`.
        proj.scaled_add(coeff, &uj);
    }
    let resid = y - &proj;
    resid.dot(&resid).sqrt()
}
/// Builds an `ExactJointHyperSetup` spanning two blocks (mean, then noise),
/// stacking each block's log-kappa coordinates and bounds in that order.
fn two_block_exact_joint_hyper_setup(
    meanspec: &TermCollectionSpec,
    noisespec: &TermCollectionSpec,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> ExactJointHyperSetup {
    // Resolve log-kappa coordinates for one block, choosing the anisotropic
    // parameterization whenever the block has anisotropic spatial terms.
    let coords_for = |spec: &TermCollectionSpec, terms: &[usize]| {
        if has_aniso_terms(spec, terms) {
            SpatialLogKappaCoords::from_length_scales_aniso(spec, terms, kappa_options)
        } else {
            SpatialLogKappaCoords::from_length_scales(spec, terms, kappa_options)
        }
    };
    let mean_terms = spatial_length_scale_term_indices(meanspec);
    let noise_terms = spatial_length_scale_term_indices(noisespec);
    let mean_log_kappa = coords_for(meanspec, &mean_terms);
    let noise_log_kappa = coords_for(noisespec, &noise_terms);
    // Mean-block dimensions come first, noise-block dimensions follow.
    let dims_per_term: Vec<usize> = mean_log_kappa
        .dims_per_term()
        .iter()
        .copied()
        .chain(noise_log_kappa.dims_per_term().iter().copied())
        .collect();
    // Sanity: the coordinate dims must agree with the spec-derived dims.
    debug_assert_eq!(
        dims_per_term,
        spatial_dims_per_term(meanspec, &mean_terms)
            .iter()
            .copied()
            .chain(spatial_dims_per_term(noisespec, &noise_terms).iter().copied())
            .collect::<Vec<_>>()
    );
    let stacked = Array1::from_iter(
        mean_log_kappa
            .as_array()
            .iter()
            .chain(noise_log_kappa.as_array().iter())
            .copied(),
    );
    let log_kappa0 = SpatialLogKappaCoords::new_with_dims(stacked, dims_per_term.clone());
    ExactJointHyperSetup::new(
        Array1::zeros(0),
        Array1::zeros(0),
        Array1::zeros(0),
        log_kappa0,
        SpatialLogKappaCoords::lower_bounds_aniso(&dims_per_term, kappa_options),
        SpatialLogKappaCoords::upper_bounds_aniso(&dims_per_term, kappa_options),
    )
}
/// Largest elementwise absolute difference between two equally-shaped matrices.
fn max_abs_diff_matrix(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
    assert_eq!(a.dim(), b.dim());
    let mut worst = 0.0_f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        worst = worst.max((x - y).abs());
    }
    worst
}
/// Builds the design at fit time, freezes the spec from it, rebuilds from the
/// frozen spec, and asserts the realized design matrix is unchanged.
fn assert_frozen_replay_matches_fit(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
    label: &str,
) {
    let baseline = build_term_collection_design(data, spec).expect("fit-time design");
    let frozen =
        freeze_term_collection_from_design(spec, &baseline).expect("freeze term collection");
    let replayed = build_term_collection_design(data, &frozen).expect("replay design");
    let max_abs =
        max_abs_diff_matrix(&baseline.design.to_dense(), &replayed.design.to_dense());
    assert!(
        max_abs <= 1e-10,
        "{label} frozen replay changed realized design: max_abs={max_abs}"
    );
}
/// Dense reference for the Kronecker-structured penalty pseudo-log-determinant
/// and its gradient/Hessian with respect to log-smoothing parameters.
///
/// Builds S = Σ_k λ_k (I ⊗ … ⊗ S_k ⊗ … ⊗ I) (+ ridge·I), eigendecomposes it,
/// and accumulates log|S|_+ plus first/second derivatives over the strictly
/// positive eigenvalues. Returns `(logdet, grad, hess)`.
fn dense_kronecker_pseudo_logdet_reference(
    marginal_penalties: &[Array2<f64>],
    lambdas: &[f64],
    ridge: f64,
) -> (f64, Array1<f64>, Array2<f64>) {
    let p_total: usize = marginal_penalties
        .iter()
        .map(|penalty| penalty.nrows())
        .product();
    // Precompute the full Kronecker expansion of each marginal penalty once.
    // The original rebuilt these inside the eigenvalue loops, redoing the
    // same expensive Kronecker products O(#eigs * #axes) times.
    let kron_terms: Vec<Array2<f64>> = marginal_penalties
        .iter()
        .enumerate()
        .map(|(axis, penalty)| {
            let mut kron_term = Array2::<f64>::eye(1);
            for (other_axis, other_penalty) in marginal_penalties.iter().enumerate() {
                let factor = if axis == other_axis {
                    penalty.clone()
                } else {
                    Array2::<f64>::eye(other_penalty.nrows())
                };
                kron_term = crate::construction::kronecker_product(&kron_term, &factor);
            }
            kron_term
        })
        .collect();
    // Assemble the dense total penalty S.
    let mut s_dense = Array2::<f64>::zeros((p_total, p_total));
    for (axis, kron_term) in kron_terms.iter().enumerate() {
        s_dense.scaled_add(lambdas[axis], kron_term);
    }
    if ridge > 0.0 {
        for idx in 0..p_total {
            s_dense[[idx, idx]] += ridge;
        }
    }
    let (evals_dense, evecs_dense): (Array1<f64>, Array2<f64>) = s_dense
        .eigh(faer::Side::Lower)
        .expect("dense Kronecker eigh");
    // Only strictly positive eigenvalues enter the pseudo-log-determinant.
    let tol = 1e-12;
    let positive_indices: Vec<usize> = evals_dense
        .iter()
        .enumerate()
        .filter_map(|(idx, &value)| (value > tol).then_some(idx))
        .collect();
    let logdet = positive_indices
        .iter()
        .map(|&idx| evals_dense[idx].ln())
        .sum();
    let mut grad = Array1::<f64>::zeros(lambdas.len());
    let mut hess = Array2::<f64>::zeros((lambdas.len(), lambdas.len()));
    for &eig_idx in &positive_indices {
        let eigval = evals_dense[eig_idx];
        let eigvec = evecs_dense.column(eig_idx).to_owned();
        // c_k = lambda_k * v' (term_k) v, computed once per eigenpair for
        // every axis instead of being recomputed per axis pair.
        let c: Vec<f64> = kron_terms
            .iter()
            .enumerate()
            .map(|(axis, kron_term)| lambdas[axis] * eigvec.dot(&kron_term.dot(&eigvec)))
            .collect();
        for axis in 0..lambdas.len() {
            let ck = c[axis];
            grad[axis] += ck / eigval;
            // Diagonal term: d/d(log lambda_k) of c_k / mu.
            hess[[axis, axis]] += ck / eigval - (ck * ck) / (eigval * eigval);
            for other_axis in (axis + 1)..lambdas.len() {
                let off = -(ck * c[other_axis]) / (eigval * eigval);
                hess[[axis, other_axis]] += off;
                hess[[other_axis, axis]] += off;
            }
        }
    }
    (logdet, grad, hess)
}
/// Largest elementwise absolute difference between two equal-length vectors.
fn max_abs_diff_vector(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    assert_eq!(a.len(), b.len());
    let mut worst = 0.0_f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        worst = worst.max((x - y).abs());
    }
    worst
}
#[test]
fn kronecker_penalty_system_logdet_matches_dense_reference() {
    // Two second-difference-style marginal penalties of sizes 3 and 4.
    let dim_a = 3usize;
    let dim_b = 4usize;
    let penalty_a = array![[1.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]];
    let penalty_b = array![
        [1.0, -1.0, 0.0, 0.0],
        [-1.0, 2.0, -1.0, 0.0],
        [0.0, -1.0, 2.0, -1.0],
        [0.0, 0.0, -1.0, 1.0]
    ];
    let marginal_penalties = vec![penalty_a, penalty_b];
    let lambdas = vec![2.5, 1.3];
    let ridge = 0.0;
    // Factored path under test.
    let system =
        KroneckerPenaltySystem::new(marginal_penalties.clone(), vec![dim_a, dim_b], false)
            .expect("KroneckerPenaltySystem");
    let (logdet, grad, hess) = system.logdet_and_derivatives(&lambdas, ridge);
    // Dense reference path.
    let (dense_logdet, dense_grad, dense_hess) =
        dense_kronecker_pseudo_logdet_reference(&marginal_penalties, &lambdas, ridge);
    assert!(
        (logdet - dense_logdet).abs() < 1e-8,
        "KroneckerPenaltySystem logdet mismatch: factored={} dense={}",
        logdet,
        dense_logdet
    );
    let grad_diff = max_abs_diff_vector(&grad, &dense_grad);
    assert!(
        grad_diff < 1e-8,
        "KroneckerPenaltySystem gradient mismatch: max diff={grad_diff}"
    );
    let hess_diff = max_abs_diff_matrix(&hess, &dense_hess);
    assert!(
        hess_diff < 1e-8,
        "KroneckerPenaltySystem Hessian mismatch: max diff={hess_diff}"
    );
}
/// Asserts that two realized `TermCollectionDesign`s are equivalent: design
/// matrix entries, penalty blocks and column ranges, nullspace dimensions,
/// per-penalty metadata, optional coefficient lower bounds, and optional
/// linear constraints. `label` prefixes every failure message.
fn assert_term_collection_designs_match(
    left: &TermCollectionDesign,
    right: &TermCollectionDesign,
    label: &str,
) {
    // Dense design matrices must agree to tight numerical tolerance.
    let left_design = left.design.to_dense();
    let right_design = right.design.to_dense();
    let design_diff = max_abs_diff_matrix(&left_design, &right_design);
    assert!(
        design_diff <= 1e-10,
        "{label} design mismatch max_abs={design_diff}"
    );
    // Penalty blocks: same count, same column ranges, same local matrices.
    assert_eq!(
        left.penalties.len(),
        right.penalties.len(),
        "{label} penalty count mismatch"
    );
    for (idx, (lp, rp)) in left
        .penalties
        .iter()
        .zip(right.penalties.iter())
        .enumerate()
    {
        assert_eq!(
            lp.col_range, rp.col_range,
            "{label} penalty {idx} col_range mismatch"
        );
        let penalty_diff = max_abs_diff_matrix(&lp.local, &rp.local);
        assert!(
            penalty_diff <= 1e-10,
            "{label} penalty {idx} mismatch max_abs={penalty_diff}"
        );
    }
    assert_eq!(
        left.nullspace_dims, right.nullspace_dims,
        "{label} nullspace dims mismatch"
    );
    // Per-penalty metadata must match field by field.
    assert_eq!(
        left.penaltyinfo.len(),
        right.penaltyinfo.len(),
        "{label} penaltyinfo length mismatch"
    );
    for (idx, (linfo, rinfo)) in left
        .penaltyinfo
        .iter()
        .zip(right.penaltyinfo.iter())
        .enumerate()
    {
        assert_eq!(
            linfo.termname, rinfo.termname,
            "{label} penaltyinfo termname mismatch at {idx}"
        );
        assert_eq!(
            linfo.penalty.source, rinfo.penalty.source,
            "{label} penalty source mismatch at {idx}"
        );
        assert_eq!(
            linfo.penalty.active, rinfo.penalty.active,
            "{label} penalty active mismatch at {idx}"
        );
        assert_eq!(
            linfo.penalty.effective_rank, rinfo.penalty.effective_rank,
            "{label} penalty rank mismatch at {idx}"
        );
        assert_eq!(
            linfo.penalty.nullspace_dim_hint, rinfo.penalty.nullspace_dim_hint,
            "{label} penalty nullspace hint mismatch at {idx}"
        );
        // Normalization scales are floats: compare with tolerance, not Eq.
        assert!(
            (linfo.penalty.normalization_scale - rinfo.penalty.normalization_scale).abs()
                <= 1e-10,
            "{label} penalty normalization mismatch at {idx}"
        );
    }
    // Optional coefficient lower bounds: both present (and equal) or both absent.
    match (
        left.coefficient_lower_bounds.as_ref(),
        right.coefficient_lower_bounds.as_ref(),
    ) {
        (Some(lb_left), Some(lb_right)) => {
            let diff = max_abs_diff_vector(lb_left, lb_right);
            assert!(diff <= 1e-10, "{label} lower-bound mismatch max_abs={diff}");
        }
        (None, None) => {}
        _ => panic!("{label} lower-bound presence mismatch"),
    }
    // Optional linear constraints: both present (A and b equal) or both absent.
    match (
        left.linear_constraints.as_ref(),
        right.linear_constraints.as_ref(),
    ) {
        (Some(c_left), Some(c_right)) => {
            let a_diff = max_abs_diff_matrix(&c_left.a, &c_right.a);
            let b_diff = max_abs_diff_vector(&c_left.b, &c_right.b);
            assert!(
                a_diff <= 1e-10,
                "{label} linear-constraint A mismatch max_abs={a_diff}"
            );
            assert!(
                b_diff <= 1e-10,
                "{label} linear-constraint b mismatch max_abs={b_diff}"
            );
        }
        (None, None) => {}
        _ => panic!("{label} linear-constraint presence mismatch"),
    }
}
#[test]
fn smooth_design_assembles_terms_and_penalties() {
    let data = array![
        [0.0, 0.0, 0.2],
        [0.2, 0.1, 0.4],
        [0.4, 0.2, 0.6],
        [0.6, 0.4, 0.7],
        [0.8, 0.7, 0.9],
        [1.0, 1.0, 1.1]
    ];
    // One univariate B-spline term and one bivariate thin-plate term, both
    // with double penalties.
    let term_specs = vec![
        SmoothTermSpec {
            name: "s_x0".to_string(),
            basis: SmoothBasisSpec::BSpline1D {
                feature_col: 0,
                spec: BSplineBasisSpec {
                    degree: 3,
                    penalty_order: 2,
                    knotspec: BSplineKnotSpec::Generate {
                        data_range: (0.0, 1.0),
                        num_internal_knots: 4,
                    },
                    double_penalty: true,
                    identifiability: BSplineIdentifiability::default(),
                },
            },
            shape: ShapeConstraint::None,
        },
        SmoothTermSpec {
            name: "tps_x1x2".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![1, 2],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    length_scale: 1.0,
                    double_penalty: true,
                    identifiability: SpatialIdentifiability::default(),
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        },
    ];
    let design = build_smooth_design(data.view(), &term_specs).unwrap();
    assert_eq!(design.nrows(), data.nrows());
    assert_eq!(design.terms.len(), 2);
    // Two terms, each double-penalized: four penalty blocks in total.
    assert_eq!(design.penalties.len(), 4);
    assert_eq!(design.nullspace_dims.len(), 4);
    for penalty in &design.penalties {
        assert_eq!(penalty.local.nrows(), penalty.block_size());
        assert_eq!(penalty.local.ncols(), penalty.block_size());
        assert!(penalty.col_range.end <= design.total_smooth_cols());
    }
}
#[test]
fn shape_mapping_monotone_increasing_is_non_decreasing() {
    let theta = array![-1.0, 0.5, -0.2, 0.3];
    let beta = SmoothDesign::map_term_coefficients(&theta, ShapeConstraint::MonotoneIncreasing)
        .unwrap();
    // Every consecutive pair must be non-decreasing.
    for (prev, next) in beta.iter().zip(beta.iter().skip(1)) {
        assert!(next >= prev);
    }
}
#[test]
fn build_smooth_design_rejectsmultiaxis_spatial_shape_constraints() {
    let data = array![[0.0, 0.0], [0.5, 0.2], [1.0, 0.4], [1.5, 0.6],];
    // A two-axis thin-plate term with a monotone constraint is invalid:
    // shape constraints apply only to single-axis smooths.
    let term_specs = vec![SmoothTermSpec {
        name: "tps_shape".to_string(),
        basis: SmoothBasisSpec::ThinPlate {
            feature_cols: vec![0, 1],
            spec: ThinPlateBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 3 },
                length_scale: 1.0,
                double_penalty: false,
                identifiability: SpatialIdentifiability::default(),
                radial_reparam: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::MonotoneIncreasing,
    }];
    let err =
        build_smooth_design(data.view(), &term_specs).expect_err("shape should be rejected");
    let BasisError::InvalidInput(msg) = err else {
        panic!("unexpected error: {err:?}");
    };
    assert!(msg.contains("requires exactly 1 feature axis"));
}
#[test]
fn build_smooth_design_accepts_monotone_thin_plate_1dwith_linear_constraints() {
    let data = array![[0.0], [0.15], [0.35], [0.5], [0.65], [0.85], [1.0]];
    // Single-axis thin-plate with a monotone shape: accepted, and realized
    // through linear constraints rather than coefficient lower bounds.
    let term_specs = vec![SmoothTermSpec {
        name: "mono_tps".to_string(),
        basis: SmoothBasisSpec::ThinPlate {
            feature_cols: vec![0],
            spec: ThinPlateBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                length_scale: 1.0,
                double_penalty: false,
                identifiability: SpatialIdentifiability::default(),
                radial_reparam: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::MonotoneIncreasing,
    }];
    let design =
        build_smooth_design(data.view(), &term_specs).expect("shape-constrained thin-plate");
    assert!(design.coefficient_lower_bounds.is_none());
    let constraints = design
        .linear_constraints
        .as_ref()
        .expect("linear constraints should be generated");
    assert!(constraints.a.nrows() > 0);
    assert_eq!(constraints.a.ncols(), design.total_smooth_cols());
    assert_eq!(constraints.b.len(), constraints.a.nrows());
}
#[test]
fn build_smooth_design_auto_promotes_thin_plate_below_canonical_polynomial_dimension() {
    let data = array![
        [0.0, 0.0, 0.0],
        [0.2, 0.1, 0.3],
        [0.4, 0.3, 0.5],
        [0.7, 0.6, 0.8],
    ];
    // Too few centers for the canonical thin-plate polynomial space in three
    // dimensions: construction should fall back to a Duchon basis.
    let term_specs = vec![SmoothTermSpec {
        name: "thinplate(pc1, pc2, pc3)".to_string(),
        basis: SmoothBasisSpec::ThinPlate {
            feature_cols: vec![0, 1, 2],
            spec: ThinPlateBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 3 },
                length_scale: 1.0,
                double_penalty: false,
                identifiability: SpatialIdentifiability::default(),
                radial_reparam: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    }];
    let design = build_smooth_design(data.view(), &term_specs)
        .expect("auto-promotion to Duchon should succeed at infeasible canonical (d, k)");
    let metadata = &design
        .terms
        .first()
        .expect("at least one smooth term")
        .metadata;
    assert!(
        matches!(metadata, BasisMetadata::Duchon { .. }),
        "expected Duchon metadata after auto-promotion, got {metadata:?}"
    );
}
#[test]
fn freeze_term_collection_handles_thin_plate_auto_promotion_to_duchon() {
    let mut rng = StdRng::seed_from_u64(20260504);
    let n = 200usize;
    // Fill a 200x5 matrix with uniform draws in (-1, 1). Standard-layout
    // iteration is row-major, so the RNG stream matches an explicit
    // (row, col) double loop.
    let mut data = Array2::<f64>::zeros((n, 5));
    for cell in data.iter_mut() {
        *cell = rng.random_range(-1.0..1.0);
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "thinplate(pc1, pc2, pc3, pc4, pc5)".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0, 1, 2, 3, 4],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 10 },
                    length_scale: 1.0,
                    double_penalty: false,
                    identifiability: SpatialIdentifiability::default(),
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit_design = build_term_collection_design(data.view(), &spec).expect("fit-time design");
    let metadata = &fit_design
        .smooth
        .terms
        .first()
        .expect("at least one smooth term")
        .metadata;
    assert!(
        matches!(metadata, BasisMetadata::Duchon { .. }),
        "expected auto-promotion to Duchon, got {metadata:?}"
    );
    // Freezing must bridge the (ThinPlate spec, Duchon metadata) pair.
    let frozen = freeze_term_collection_from_design(&spec, &fit_design).expect(
        "freeze must succeed across the auto-promoted (ThinPlate spec, Duchon metadata) pair",
    );
    assert!(
        matches!(frozen.smooth_terms[0].basis, SmoothBasisSpec::Duchon { .. }),
        "frozen spec should reflect the auto-promotion as a Duchon variant"
    );
    let replay_design =
        build_term_collection_design(data.view(), &frozen).expect("replay design");
    let mut max_abs = 0.0_f64;
    for (&a, &b) in fit_design
        .design
        .to_dense()
        .iter()
        .zip(replay_design.design.to_dense().iter())
    {
        max_abs = max_abs.max((a - b).abs());
    }
    assert!(
        max_abs <= 1e-10,
        "auto-promoted frozen replay changed realized design: max_abs={max_abs}"
    );
}
#[test]
fn build_smooth_design_accepts_monotone_matern_1dwith_linear_constraints() {
    // A monotone-increasing shape constraint on a 1-D Matérn smooth should
    // be expressed as linear inequality constraints, not as per-coefficient
    // lower bounds.
    let data = array![[0.0], [0.2], [0.4], [0.6], [0.8], [1.0]];
    let matern_spec = MaternBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
        length_scale: 0.7,
        nu: MaternNu::FiveHalves,
        include_intercept: false,
        double_penalty: false,
        identifiability: MaternIdentifiability::CenterSumToZero,
        aniso_log_scales: None,
    };
    let terms = vec![SmoothTermSpec {
        name: "mono_matern".to_string(),
        basis: SmoothBasisSpec::Matern {
            feature_cols: vec![0],
            spec: matern_spec,
            input_scales: None,
        },
        shape: ShapeConstraint::MonotoneIncreasing,
    }];
    let design = build_smooth_design(data.view(), &terms).expect("shape-constrained Matérn");
    assert!(design.coefficient_lower_bounds.is_none());
    let constraints = design
        .linear_constraints
        .as_ref()
        .expect("linear constraints should be generated");
    // Non-empty constraint system with one column per smooth coefficient
    // and a matching right-hand side.
    assert!(constraints.a.nrows() > 0);
    assert_eq!(constraints.a.ncols(), design.total_smooth_cols());
    assert_eq!(constraints.b.len(), constraints.a.nrows());
}
#[test]
fn build_smooth_design_accepts_monotone_duchon_1dwith_linear_constraints() {
    // Same contract as the Matérn case: a monotone-increasing 1-D Duchon
    // smooth yields linear inequality constraints, not coefficient bounds.
    let data = array![[0.0], [0.2], [0.4], [0.6], [0.8], [1.0]];
    let duchon_spec = DuchonBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
        length_scale: Some(0.9),
        power: 5,
        nullspace_order: DuchonNullspaceOrder::Linear,
        identifiability: SpatialIdentifiability::OrthogonalToParametric,
        aniso_log_scales: None,
        operator_penalties: DuchonOperatorPenaltySpec::default(),
    };
    let terms = vec![SmoothTermSpec {
        name: "mono_duchon".to_string(),
        basis: SmoothBasisSpec::Duchon {
            feature_cols: vec![0],
            spec: duchon_spec,
            input_scales: None,
        },
        shape: ShapeConstraint::MonotoneIncreasing,
    }];
    let design = build_smooth_design(data.view(), &terms).expect("shape-constrained Duchon");
    assert!(design.coefficient_lower_bounds.is_none());
    let constraints = design
        .linear_constraints
        .as_ref()
        .expect("linear constraints should be generated");
    assert!(constraints.a.nrows() > 0);
    assert_eq!(constraints.a.ncols(), design.total_smooth_cols());
    assert_eq!(constraints.b.len(), constraints.a.nrows());
}
#[test]
fn build_smooth_design_accepts_monotone_bsplinewith_bounds() {
    // For B-splines, monotonicity is enforced through coefficient lower
    // bounds: the first coefficient is unbounded below, every later one is
    // bounded at zero (see the asserts below).
    let data = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
    let terms = vec![SmoothTermSpec {
        name: "mono_bs".to_string(),
        basis: SmoothBasisSpec::BSpline1D {
            feature_col: 0,
            spec: BSplineBasisSpec {
                degree: 3,
                penalty_order: 2,
                knotspec: BSplineKnotSpec::Generate {
                    data_range: (0.0, 1.0),
                    num_internal_knots: 3,
                },
                double_penalty: false,
                identifiability: BSplineIdentifiability::default(),
            },
        },
        shape: ShapeConstraint::MonotoneIncreasing,
    }];
    let design = build_smooth_design(data.view(), &terms).expect("shape-constrained bspline");
    let bounds = design
        .coefficient_lower_bounds
        .as_ref()
        .expect("lower bounds should be generated");
    assert_eq!(bounds.len(), design.total_smooth_cols());
    // Equality with NEG_INFINITY holds iff the value is exactly -inf, which
    // is what the original infinite-and-negative check verified.
    assert_eq!(bounds[0], f64::NEG_INFINITY);
    assert!(bounds.iter().skip(1).all(|&b| b == 0.0));
}
#[test]
fn term_collection_design_combines_linear_and_smooth() {
    // One explicit linear term plus one 2-D ThinPlate smooth, both with
    // double penalty: the assembled design carries an intercept, the linear
    // block, the smooth block, and three penalty candidates with matching
    // nullspace-dimension entries.
    let data = array![
        [0.0, 0.0, 0.2],
        [0.2, 0.1, 0.4],
        [0.4, 0.2, 0.6],
        [0.6, 0.4, 0.7],
        [0.8, 0.7, 0.9],
        [1.0, 1.0, 1.1]
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "lin_x0".to_string(),
            feature_col: 0,
            double_penalty: true,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        }],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "tps_x1x2".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![1, 2],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    length_scale: 1.0,
                    double_penalty: true,
                    identifiability: SpatialIdentifiability::default(),
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).unwrap();
    let design_dense = design.design.to_dense();
    assert_eq!(design.design.nrows(), data.nrows());
    assert_eq!(design.intercept_range, 0..1);
    // The intercept column must be identically one.
    assert!(
        design_dense
            .column(design.intercept_range.start)
            .iter()
            .all(|&v: &f64| (v - 1.0).abs() < 1e-12)
    );
    assert!(design.design.ncols() >= 2);
    assert_eq!(design.linear_ranges.len(), 1);
    assert_eq!(design.random_effect_ranges.len(), 0);
    // Split into separate asserts (they were jammed onto one line) so each
    // check reads — and fails — independently.
    assert_eq!(design.penalties.len(), 3);
    assert_eq!(design.nullspace_dims.len(), 3);
}
#[test]
fn term_collection_design_keeps_intercept_plus_bspline_sparse() {
    // A rich B-spline (32 internal knots) over 96 rows: the assembled full
    // design must remain in the sparse representation even though the
    // intercept column itself is dense.
    let n = 96usize;
    let grid = Array1::linspace(0.0, 1.0, n);
    let mut data = Array2::<f64>::zeros((n, 1));
    data.column_mut(0).assign(&grid);
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "s_x".to_string(),
            basis: SmoothBasisSpec::BSpline1D {
                feature_col: 0,
                spec: BSplineBasisSpec {
                    degree: 3,
                    penalty_order: 2,
                    knotspec: BSplineKnotSpec::Generate {
                        data_range: (0.0, 1.0),
                        num_internal_knots: 32,
                    },
                    double_penalty: false,
                    identifiability: BSplineIdentifiability::default(),
                },
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design =
        build_term_collection_design(data.view(), &spec).expect("term collection design");
    assert!(
        matches!(design.design, DesignMatrix::Sparse(_)),
        "expected sparse full design, got {:?}",
        design.design
    );
}
#[test]
fn spatial_smooth_columns_do_not_duplicate_global_intercept() {
    // After identifiability handling, no smooth column may be a constant
    // all-ones column (which would alias the global intercept).
    let data = array![
        [0.0, 0.0],
        [0.2, 0.1],
        [0.4, 0.3],
        [0.6, 0.6],
        [0.8, 0.7],
        [1.0, 1.0],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "tps_xy".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0, 1],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    length_scale: 1.0,
                    double_penalty: false,
                    identifiability: SpatialIdentifiability::default(),
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).unwrap();
    let dense = design.design.to_dense();
    let smooth_start = 1usize;
    let smooth_end = smooth_start + design.smooth.total_smooth_cols();
    for col in smooth_start..smooth_end {
        // De Morgan of "all entries ~= 1": at least one entry must differ
        // from 1 beyond tolerance.
        assert!(
            dense
                .column(col)
                .iter()
                .any(|&v: &f64| (v - 1.0).abs() >= 1e-12),
            "smooth column {col} unexpectedly duplicated intercept"
        );
    }
}
#[test]
fn spatial_smooth_drops_matching_linear_trend_columns() {
    // With linear(x0) present, the 2-D TPS smooth is reduced to two columns
    // and none of them may reproduce the linear term's column verbatim.
    let data = array![
        [0.0, 0.1],
        [0.2, 0.0],
        [0.3, 0.4],
        [0.5, 0.2],
        [0.7, 0.9],
        [1.0, 0.8],
        [1.2, 1.1],
        [1.4, 1.3],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "lin_x0".to_string(),
            feature_col: 0,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        }],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "tps_xy".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0, 1],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    length_scale: 1.0,
                    double_penalty: false,
                    identifiability: SpatialIdentifiability::default(),
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).unwrap();
    assert_eq!(design.smooth.total_smooth_cols(), 2);
    let dense = design.design.to_dense_cow();
    let linear_col = dense.column(design.linear_ranges[0].1.start).to_owned();
    let smooth_start = 1 + spec.linear_terms.len();
    let smooth_end = smooth_start + design.smooth.total_smooth_cols();
    for col in smooth_start..smooth_end {
        // At least one entry must differ from the linear column beyond
        // tolerance (De Morgan of the "all entries equal" check).
        let differs_somewhere = dense
            .column(col)
            .iter()
            .zip(linear_col.iter())
            .any(|(&a, &b)| (a - b).abs() >= 1e-12);
        assert!(
            differs_somewhere,
            "smooth column {col} unexpectedly duplicated linear term column"
        );
    }
}
#[test]
fn spatial_option5_is_orthogonal_to_parametric_block() {
    // Option 5 residualizes the smooth block against the model-owned
    // parametric columns [1 | x0]; the relative cross-product norm must
    // vanish to numerical precision.
    let data = array![
        [0.0, 0.1],
        [0.2, 0.0],
        [0.3, 0.4],
        [0.5, 0.2],
        [0.7, 0.9],
        [1.0, 0.8],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "lin_x0".to_string(),
            feature_col: 0,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        }],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "tps_xy".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0, 1],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    length_scale: 1.0,
                    double_penalty: false,
                    identifiability: SpatialIdentifiability::OrthogonalToParametric,
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).unwrap();
    let design_dense = design.design.to_dense();
    let n = data.nrows();
    // Parametric block C = [1 | x0].
    let mut parametric = Array2::<f64>::zeros((n, 2));
    parametric.column_mut(0).fill(1.0);
    parametric.column_mut(1).assign(&data.column(0));
    let smooth_start = 1 + spec.linear_terms.len();
    let smooth_block = design_dense
        .slice(s![
            ..,
            smooth_start..(smooth_start + design.smooth.total_smooth_cols())
        ])
        .to_owned();
    // Shared Frobenius-norm helper for the three norms below.
    let frob = |m: &Array2<f64>| m.iter().map(|v| v * v).sum::<f64>().sqrt();
    let cross = smooth_block.t().dot(&parametric);
    let rel = frob(&cross) / (frob(&smooth_block) * frob(&parametric)).max(1e-300);
    assert!(
        rel <= 1e-10,
        "smooth residual against model-owned parametric block too large: {rel}"
    );
}
#[test]
fn thin_plate_default_identifiability_centers_against_intercept_only_without_linear_terms() {
    // Two independent 1-D ThinPlate terms in a formula with no linear
    // terms: each term's realized columns must be orthogonal to the
    // intercept alone (there is no other parametric block to own axes).
    // NOTE(review): the name says "default identifiability" but the spec
    // pins `SpatialIdentifiability::OrthogonalToParametric` explicitly —
    // presumably that is the default variant; confirm against the enum.
    let data = array![
        [-1.9, -1.2],
        [-1.3, -0.7],
        [-0.8, -0.4],
        [-0.2, 0.1],
        [0.0, 0.3],
        [0.4, 0.5],
        [0.9, 0.8],
        [1.4, 1.1],
        [1.9, 1.5],
        [2.3, 1.8],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        // One univariate TPS term per feature column (0 and 1).
        smooth_terms: (0..2)
            .map(|feature| SmoothTermSpec {
                name: format!("tps_x{feature}"),
                basis: SmoothBasisSpec::ThinPlate {
                    feature_cols: vec![feature],
                    spec: ThinPlateBasisSpec {
                        center_strategy: CenterStrategy::EqualMass { num_centers: 4 },
                        length_scale: 1.0,
                        double_penalty: false,
                        identifiability: SpatialIdentifiability::OrthogonalToParametric,
                        radial_reparam: None,
                    },
                    input_scales: None,
                },
                shape: ShapeConstraint::None,
            })
            .collect(),
    };
    let design = build_term_collection_design(data.view(), &spec).unwrap();
    let design_dense = design.design.to_dense();
    // Column 0 of the full design is the intercept; smooth blocks start
    // right after it (plus any linear terms — none here).
    let smooth_start = 1 + spec.linear_terms.len();
    let intercept = Array2::<f64>::ones((data.nrows(), 1));
    for (term_idx, term) in design.smooth.terms.iter().enumerate() {
        // Slice this term's columns out of the assembled design using its
        // coefficient range, offset by the parametric prefix.
        let block = design_dense
            .slice(s![
                ..,
                (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end)
            ])
            .to_owned();
        // Relative Frobenius norm of Bᵀ·1 — zero when the block is centered.
        let cross = block.t().dot(&intercept);
        let num = cross.iter().map(|v| v * v).sum::<f64>().sqrt();
        let block_norm = block.iter().map(|v| v * v).sum::<f64>().sqrt();
        let intercept_norm = intercept.iter().map(|v| v * v).sum::<f64>().sqrt();
        let rel = num / (block_norm * intercept_norm).max(1e-300);
        assert!(
            rel <= 1e-10,
            "ThinPlate term {term_idx} should be centered against the intercept (no linear terms in formula); got rel={rel:.3e}"
        );
    }
}
#[test]
fn spatial_option5_does_not_overconstrain_on_nonoverlapping_linear_terms() {
    // Eleven linear terms on columns 5..16 plus two 1-D smooths on columns
    // 1 and 2: since the linear and smooth axes do not overlap, term-local
    // Option 5 must not over-project the smooth blocks, and the build
    // should succeed.
    let n = 40usize;
    let p = 16usize;
    let data = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        (i as f64) * 0.03 + (j as f64) * 0.11 + ((i * (j + 1)) as f64) * 1e-3
    });
    // Both smooth terms share the same spec apart from the feature column.
    let tps_term = |col: usize| SmoothTermSpec {
        name: format!("tps_pc{col}"),
        basis: SmoothBasisSpec::ThinPlate {
            feature_cols: vec![col],
            spec: ThinPlateBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 12 },
                length_scale: 1.0,
                double_penalty: true,
                identifiability: SpatialIdentifiability::OrthogonalToParametric,
                radial_reparam: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    };
    let spec = TermCollectionSpec {
        linear_terms: (5..16)
            .map(|j| LinearTermSpec {
                name: format!("pc{j}"),
                feature_col: j,
                double_penalty: false,
                coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
                coefficient_min: None,
                coefficient_max: None,
            })
            .collect(),
        random_effect_terms: vec![],
        smooth_terms: vec![tps_term(1), tps_term(2)],
    };
    let out = build_term_collection_design(data.view(), &spec);
    assert!(
        out.is_ok(),
        "term-local Option 5 should not over-constrain non-overlapping smooth/linear terms: {:?}",
        out.err()
    );
}
#[test]
fn overlapping_linear_term_residualizes_bspline_smooth() {
    // linear(x) shares its axis with the B-spline smooth s(x); the smooth's
    // realized columns must therefore be orthogonal to span([1, x]).
    let data = array![
        [0.0],
        [0.1],
        [0.2],
        [0.3],
        [0.4],
        [0.5],
        [0.6],
        [0.7],
        [0.8],
        [0.9],
        [1.0],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "x".to_string(),
            feature_col: 0,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        }],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "s_x".to_string(),
            basis: SmoothBasisSpec::BSpline1D {
                feature_col: 0,
                spec: BSplineBasisSpec {
                    degree: 3,
                    penalty_order: 2,
                    knotspec: BSplineKnotSpec::Generate {
                        data_range: (0.0, 1.0),
                        num_internal_knots: 4,
                    },
                    double_penalty: false,
                    identifiability: BSplineIdentifiability::default(),
                },
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("bspline design");
    // Parametric block [1 | x] assembled directly from the data column.
    let parametric = Array2::<f64>::from_shape_fn((data.nrows(), 2), |(i, j)| {
        if j == 0 { 1.0 } else { data[[i, 0]] }
    });
    let rel = orthogonality_relative_residual_for_design(
        &design.smooth.term_designs[0],
        parametric.view(),
    )
    .expect("orthogonality residual");
    assert!(
        rel <= 1e-10,
        "B-spline smooth should be orthogonal to [1, x] when linear(x) is present; rel={rel}"
    );
}
#[test]
fn standalone_tps_keeps_centered_linear_nullspace() {
    // With user-provided centers and no parametric linear terms, Option-5
    // identifiability should only remove the intercept direction: the term
    // keeps two columns and reports a one-dimensional nullspace.
    let data = array![[-1.5], [-0.7], [0.2], [0.8], [1.6]];
    let centers = array![[-1.5], [0.2], [1.6]];
    let term = SmoothTermSpec {
        name: "s_x".to_string(),
        basis: SmoothBasisSpec::ThinPlate {
            feature_cols: vec![0],
            spec: ThinPlateBasisSpec {
                center_strategy: CenterStrategy::UserProvided(centers),
                length_scale: 1.0,
                double_penalty: false,
                identifiability: SpatialIdentifiability::OrthogonalToParametric,
                radial_reparam: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    };
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![term],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("tps design");
    assert_eq!(design.smooth.term_designs[0].ncols(), 2);
    assert_eq!(design.smooth.nullspace_dims, vec![1]);
    let intercept = Array2::<f64>::ones((data.nrows(), 1));
    let rel = orthogonality_relative_residual_for_design(
        &design.smooth.term_designs[0],
        intercept.view(),
    )
    .expect("intercept residual");
    assert!(
        rel <= 1e-10,
        "standalone TPS should be centered against the intercept while retaining its linear nullspace; rel={rel}"
    );
}
#[test]
fn spatial_parametric_ownership_projects_only_explicit_linear_axes() {
    // Ownership is resolved per feature column: a linear term on x0 claims
    // only the x0 axis of the 2-D smooth's nullspace, never x1.
    let smooth = SmoothTermSpec {
        name: "s_xy".to_string(),
        basis: SmoothBasisSpec::ThinPlate {
            feature_cols: vec![0, 1],
            spec: ThinPlateBasisSpec {
                center_strategy: CenterStrategy::EqualMass { num_centers: 4 },
                length_scale: 1.0,
                double_penalty: false,
                identifiability: SpatialIdentifiability::OrthogonalToParametric,
                radial_reparam: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    };
    let linear_terms = vec![LinearTermSpec {
        name: "x0".to_string(),
        feature_col: 0,
        double_penalty: false,
        coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
        coefficient_min: None,
        coefficient_max: None,
    }];
    let owned_cols = smooth_intrinsic_parametric_feature_cols(&linear_terms, &smooth);
    assert_eq!(
        owned_cols,
        vec![0],
        "a linear term on x0 should not claim the smooth's x1 nullspace"
    );
}
#[test]
fn hierarchical_smooth_ownership_is_order_independent_for_bspline_and_duchon() {
    // A 1-D B-spline on x0 and a 2-D Duchon on (x0, x1) overlap on the x0
    // axis. Hierarchical ownership should residualize the multivariate
    // Duchon block against the owned 1-D spline space, and the realized
    // Duchon design must be identical regardless of the order in which the
    // user lists the two terms.
    let data = array![
        [0.00, 0.00],
        [0.10, 0.15],
        [0.18, 0.30],
        [0.27, 0.10],
        [0.35, 0.55],
        [0.46, 0.25],
        [0.54, 0.70],
        [0.63, 0.40],
        [0.72, 0.85],
        [0.81, 0.60],
        [0.90, 0.95],
        [1.00, 0.75],
    ];
    let bspline_term = SmoothTermSpec {
        name: "s_x".to_string(),
        basis: SmoothBasisSpec::BSpline1D {
            feature_col: 0,
            spec: BSplineBasisSpec {
                degree: 3,
                penalty_order: 2,
                knotspec: BSplineKnotSpec::Generate {
                    data_range: (0.0, 1.0),
                    num_internal_knots: 5,
                },
                double_penalty: false,
                identifiability: BSplineIdentifiability::default(),
            },
        },
        shape: ShapeConstraint::None,
    };
    let duchon_term = SmoothTermSpec {
        name: "duchon_xy".to_string(),
        basis: SmoothBasisSpec::Duchon {
            feature_cols: vec![0, 1],
            spec: DuchonBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                length_scale: Some(1.0),
                power: 5,
                nullspace_order: DuchonNullspaceOrder::Linear,
                identifiability: SpatialIdentifiability::default(),
                aniso_log_scales: None,
                operator_penalties: DuchonOperatorPenaltySpec::default(),
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    };
    // Same collection, opposite term orders.
    let spec_a = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![duchon_term.clone(), bspline_term.clone()],
    };
    let spec_b = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![bspline_term, duchon_term],
    };
    let design_a = build_term_collection_design(data.view(), &spec_a).expect("design a");
    let design_b = build_term_collection_design(data.view(), &spec_b).expect("design b");
    // In both orders, the Duchon block must be orthogonal to the realized
    // B-spline block (the "owner" of the shared x0 axis).
    for design in [&design_a, &design_b] {
        let owner_idx = design
            .smooth
            .terms
            .iter()
            .position(|term| term.name == "s_x")
            .expect("owner term");
        let target_idx = design
            .smooth
            .terms
            .iter()
            .position(|term| term.name == "duchon_xy")
            .expect("target term");
        let owner_dense = design.smooth.term_designs[owner_idx].to_dense();
        let rel = orthogonality_relative_residual_for_design(
            &design.smooth.term_designs[target_idx],
            owner_dense.view(),
        )
        .expect("orthogonality residual");
        assert!(
            rel <= 1e-10,
            "multivariate Duchon term should be residualized against owned 1D spline space; rel={rel}"
        );
    }
    // The realized Duchon designs from the two orderings must agree
    // element-wise (terms are located by name since their positions differ).
    let duchon_a_idx = design_a
        .smooth
        .terms
        .iter()
        .position(|term| term.name == "duchon_xy")
        .expect("duchon in design a");
    let duchon_b_idx = design_b
        .smooth
        .terms
        .iter()
        .position(|term| term.name == "duchon_xy")
        .expect("duchon in design b");
    let duchon_a = design_a.smooth.term_designs[duchon_a_idx].to_dense();
    let duchon_b = design_b.smooth.term_designs[duchon_b_idx].to_dense();
    assert_eq!(duchon_a.dim(), duchon_b.dim());
    let max_abs = duchon_a
        .iter()
        .zip(duchon_b.iter())
        .map(|(lhs, rhs)| (lhs - rhs).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs <= 1e-10,
        "hierarchical ownership should not depend on user term order; max_abs={max_abs}"
    );
}
#[test]
fn freeze_roundtrip_preserves_hierarchical_smooth_transforms() {
    // Hierarchical setup: linear(x) + a 2-D Duchon on (x0, x1) + a 1-D
    // B-spline on x0 all overlap on the x0 axis. Freezing the fitted
    // collection and rebuilding from the frozen spec must reproduce the
    // full realized design on the training rows to within 1e-10.
    // NOTE(review): `power: 1` differs from the `power: 5` used by the
    // neighboring Duchon tests — presumably intentional for this roundtrip;
    // confirm it is a valid (d, power) combination for the Duchon builder.
    let data = array![
        [0.00, 0.00],
        [0.10, 0.15],
        [0.18, 0.30],
        [0.27, 0.10],
        [0.35, 0.55],
        [0.46, 0.25],
        [0.54, 0.70],
        [0.63, 0.40],
        [0.72, 0.85],
        [0.81, 0.60],
        [0.90, 0.95],
        [1.00, 0.75],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "x".to_string(),
            feature_col: 0,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        }],
        random_effect_terms: vec![],
        smooth_terms: vec![
            SmoothTermSpec {
                name: "duchon_xy".to_string(),
                basis: SmoothBasisSpec::Duchon {
                    feature_cols: vec![0, 1],
                    spec: DuchonBasisSpec {
                        center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                        length_scale: Some(1.0),
                        power: 1,
                        nullspace_order: DuchonNullspaceOrder::Linear,
                        identifiability: SpatialIdentifiability::default(),
                        aniso_log_scales: None,
                        operator_penalties: DuchonOperatorPenaltySpec::default(),
                    },
                    input_scales: None,
                },
                shape: ShapeConstraint::None,
            },
            SmoothTermSpec {
                name: "s_x".to_string(),
                basis: SmoothBasisSpec::BSpline1D {
                    feature_col: 0,
                    spec: BSplineBasisSpec {
                        degree: 3,
                        penalty_order: 2,
                        knotspec: BSplineKnotSpec::Generate {
                            data_range: (0.0, 1.0),
                            num_internal_knots: 5,
                        },
                        double_penalty: false,
                        identifiability: BSplineIdentifiability::default(),
                    },
                },
                shape: ShapeConstraint::None,
            },
        ],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("fit-time design");
    // Freeze from the realized design, then rebuild from the frozen spec.
    let frozen =
        freeze_term_collection_from_design(&spec, &design).expect("freeze hierarchical design");
    let replay = build_term_collection_design(data.view(), &frozen).expect("replay design");
    let dense_fit = design.design.to_dense();
    let dense_replay = replay.design.to_dense();
    assert_eq!(dense_fit.dim(), dense_replay.dim());
    // Element-wise max absolute deviation between fit and replay designs.
    let max_abs = dense_fit
        .iter()
        .zip(dense_replay.iter())
        .map(|(lhs, rhs)| (lhs - rhs).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs <= 1e-10,
        "frozen hierarchical transforms should replay exactly on the training data; max_abs={max_abs}"
    );
}
#[test]
fn spatial_option5_preserves_lazy_thin_plate_terms_at_large_scale() {
    // At 17k rows x 2k user-provided centers the TPS block must stay in the
    // lazy dense representation after Option-5 residualization, and the
    // residual against [1 | x] must still be (numerically) zero.
    let n = 17_000usize;
    let k = 2_000usize;
    let data = Array2::<f64>::from_shape_fn((n, 1), |(i, _)| i as f64 / (n - 1) as f64);
    let centers = Array2::<f64>::from_shape_fn((k, 1), |(j, _)| j as f64 / (k - 1) as f64);
    let spec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "x".to_string(),
            feature_col: 0,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        }],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "tps_x".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::UserProvided(centers),
                    length_scale: 1.0,
                    double_penalty: false,
                    identifiability: SpatialIdentifiability::OrthogonalToParametric,
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design =
        build_term_collection_design(data.view(), &spec).expect("large option-5 design");
    assert!(matches!(
        &design.smooth.term_designs[0],
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::Lazy(_))
    ));
    // Parametric block [1 | x] for the residual check.
    let parametric = Array2::<f64>::from_shape_fn((n, 2), |(i, j)| {
        if j == 0 { 1.0 } else { data[[i, 0]] }
    });
    let rel = orthogonality_relative_residual_for_design(
        &design.smooth.term_designs[0],
        parametric.view(),
    )
    .expect("orthogonality residual");
    assert!(rel <= 1e-8, "lazy option-5 residual too large: {rel}");
}
#[test]
fn spatial_frozen_transform_rebuild_is_exact_on_trainingrows() {
    // Fit once with Option-5 identifiability, pull the realized centers,
    // length scale, and frozen orthogonality transform out of the term
    // metadata, then rebuild from a `FrozenTransform` spec and require a
    // bit-exact (1e-12) match of every term design on the training rows.
    let data = array![
        [0.0, 0.1],
        [0.2, 0.0],
        [0.3, 0.4],
        [0.5, 0.2],
        [0.7, 0.9],
        [1.0, 0.8],
    ];
    let fitspec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "lin_x0".to_string(),
            feature_col: 0,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        }],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "tps_xy".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0, 1],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    length_scale: 1.0,
                    double_penalty: false,
                    identifiability: SpatialIdentifiability::OrthogonalToParametric,
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit_design = build_term_collection_design(data.view(), &fitspec).unwrap();
    // Extract the realized state from the fit-time metadata; Option 5 must
    // have stored an identifiability transform.
    let term_meta = &fit_design.smooth.terms[0].metadata;
    let (centers, length_scale, z) = match term_meta {
        BasisMetadata::ThinPlate {
            centers,
            length_scale,
            identifiability_transform,
            ..
        } => (
            centers.clone(),
            *length_scale,
            identifiability_transform
                .clone()
                .expect("fit-time Option 5 should store transform"),
        ),
        other => panic!("unexpected metadata variant: {other:?}"),
    };
    // Rebuild with the frozen state: user-provided centers, the realized
    // length scale, and the captured transform.
    let frozenspec = TermCollectionSpec {
        linear_terms: fitspec.linear_terms.clone(),
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "tps_xy".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0, 1],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::UserProvided(centers),
                    length_scale,
                    double_penalty: false,
                    identifiability: SpatialIdentifiability::FrozenTransform { transform: z },
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let frozen_design = build_term_collection_design(data.view(), &frozenspec).unwrap();
    assert_eq!(
        fit_design.smooth.term_designs.len(),
        frozen_design.smooth.term_designs.len(),
        "frozen transform rebuild term count mismatch"
    );
    // Max absolute deviation across all paired term designs. Folding each
    // pair directly avoids the former collect-into-Vec per pair
    // (clippy::needless_collect) while computing the same maximum.
    let max_abs = fit_design
        .smooth
        .term_designs
        .iter()
        .zip(frozen_design.smooth.term_designs.iter())
        .map(|(a, b)| {
            let a_dense = a.to_dense();
            let b_dense = b.to_dense();
            assert_eq!(a_dense.dim(), b_dense.dim());
            a_dense
                .iter()
                .zip(b_dense.iter())
                .map(|(&x, &y)| (x - y).abs())
                .fold(0.0_f64, f64::max)
        })
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs <= 1e-12,
        "frozen transform rebuild mismatch max_abs={max_abs}"
    );
}
#[test]
fn frozen_spatial_replay_preserves_standardized_length_scale_compensation() {
    // The two input features deliberately live on very different scales
    // (x0 spans ~[0, 1.1], x1 spans ~[0, 4.4]) so that any internal
    // standardization and length-scale compensation is exercised. Each of
    // the three spatial bases (thin-plate, Matérn, Duchon) is checked via
    // the shared freeze/replay helper `assert_frozen_replay_matches_fit`.
    let n = 16usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        let t = i as f64 / (n as f64 - 1.0);
        // Mildly nonlinear, deterministic fill on mismatched scales.
        data[[i, 0]] = 0.07 * i as f64 + 0.02 * (3.0 * t).sin();
        data[[i, 1]] = 4.0 * t + 0.35 * (5.0 * t).cos();
    }
    let tps_spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "tps_xy".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0, 1],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                    length_scale: 1.3,
                    double_penalty: true,
                    identifiability: SpatialIdentifiability::OrthogonalToParametric,
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    assert_frozen_replay_matches_fit(data.view(), &tps_spec, "thin-plate");
    let matern_spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern_xy".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                    length_scale: 1.1,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    assert_frozen_replay_matches_fit(data.view(), &matern_spec, "matern");
    let duchon_spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon_xy".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                    length_scale: Some(1.4),
                    power: 5,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::OrthogonalToParametric,
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    assert_frozen_replay_matches_fit(data.view(), &duchon_spec, "duchon");
}
#[test]
fn term_collection_design_adds_random_effect_dummy_blockwithridge() {
    // A three-level grouping factor (values 0/1/2 in column 1) becomes a
    // dummy block: intercept + 3 indicator columns, one (ridge) penalty,
    // zero nullspace dimension, and each row's indicators sum to one.
    let data = array![
        [0.1, 0.0],
        [0.2, 1.0],
        [0.3, 0.0],
        [0.4, 2.0],
        [0.5, 1.0],
        [0.6, 2.0],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![RandomEffectTermSpec {
            name: "id".to_string(),
            feature_col: 1,
            drop_first_level: false,
            frozen_levels: None,
        }],
        smooth_terms: vec![],
    };
    let design = build_term_collection_design(data.view(), &spec).unwrap();
    assert_eq!(design.intercept_range, 0..1);
    assert_eq!(design.design.ncols(), 4);
    assert_eq!(design.random_effect_ranges.len(), 1);
    assert_eq!(design.penalties.len(), 1);
    assert_eq!(design.nullspace_dims, vec![0]);
    let (_, range) = &design.random_effect_ranges[0];
    let dense = design.design.to_dense_cow();
    for row in dense.rows() {
        let indicator_sum: f64 = row.slice(s![range.clone()]).sum();
        assert!((indicator_sum - 1.0).abs() < 1e-12);
    }
}
#[test]
fn matern_smooth_buildswith_double_penalty_in_high_dim() {
    // A 10-D Matérn smooth with double penalty must build cleanly and
    // report three penalty candidates with matching nullspace-dim entries.
    let n = 12usize;
    let d = 10usize;
    let data = Array2::<f64>::from_shape_fn((n, d), |(i, j)| {
        (i as f64) * 0.1 + (j as f64) * 0.03
    });
    let terms = vec![SmoothTermSpec {
        name: "matern_x".to_string(),
        basis: SmoothBasisSpec::Matern {
            feature_cols: (0..d).collect(),
            spec: MaternBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 5 },
                length_scale: 0.75,
                nu: MaternNu::FiveHalves,
                include_intercept: false,
                double_penalty: true,
                identifiability: MaternIdentifiability::CenterSumToZero,
                aniso_log_scales: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    }];
    let design = build_smooth_design(data.view(), &terms).unwrap();
    assert_eq!(design.nrows(), n);
    assert_eq!(design.terms.len(), 1);
    assert_eq!(design.penalties.len(), 3);
    assert_eq!(design.nullspace_dims.len(), 3);
}
#[test]
fn duchon_linear_nullspace_builds_and_reports_nullspace_dim() {
    // A 10-D Duchon smooth with a linear nullspace order must build and
    // report three penalty candidates with matching nullspace-dim entries.
    let n = 20usize;
    let d = 10usize;
    let mut data = Array2::<f64>::zeros((n, d));
    for i in 0..n {
        for j in 0..d {
            // SplitMix64-style integer bit-mix: a deterministic pseudo-random
            // fill in [0, 1) without pulling in an RNG dependency. The exact
            // constants and shift/multiply order matter for reproducibility.
            let mut key = (i as u64).wrapping_mul(0x9E3779B97F4A7C15);
            key ^= (j as u64).wrapping_mul(0xBF58476D1CE4E5B9);
            key = (key ^ (key >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
            key = (key ^ (key >> 27)).wrapping_mul(0x94D049BB133111EB);
            let v = ((key ^ (key >> 31)) as f64) / (u64::MAX as f64);
            data[[i, j]] = v;
        }
    }
    let terms = vec![SmoothTermSpec {
        name: "duchon_x".to_string(),
        basis: SmoothBasisSpec::Duchon {
            feature_cols: (0..d).collect(),
            spec: DuchonBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 12 },
                length_scale: Some(0.9),
                power: 5,
                nullspace_order: DuchonNullspaceOrder::Linear,
                identifiability: SpatialIdentifiability::default(),
                aniso_log_scales: None,
                operator_penalties: DuchonOperatorPenaltySpec::default(),
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    }];
    let sd = build_smooth_design(data.view(), &terms).unwrap();
    assert_eq!(sd.nrows(), n);
    assert_eq!(sd.terms.len(), 1);
    assert_eq!(sd.penalties.len(), 3);
    assert_eq!(sd.nullspace_dims.len(), 3);
}
#[test]
fn joint_duchon_orderzero_raw_smooth_build_preserves_unconstrained_basis() {
    // A 4-D Duchon term with a zero-order nullspace: the raw smooth build
    // keeps all 4 columns and records a frozen identifiability transform in
    // the metadata once the basis is built.
    let n = 12usize;
    let d = 4usize;
    let data = Array2::<f64>::from_shape_fn((n, d), |(i, j)| {
        (i as f64) * 0.13 + (j as f64) * 0.17
    });
    let terms = vec![SmoothTermSpec {
        name: "duchon_joint".to_string(),
        basis: SmoothBasisSpec::Duchon {
            feature_cols: (0..d).collect(),
            spec: DuchonBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                length_scale: Some(1.0),
                power: 3,
                nullspace_order: DuchonNullspaceOrder::Zero,
                identifiability: SpatialIdentifiability::default(),
                aniso_log_scales: None,
                operator_penalties: DuchonOperatorPenaltySpec::default(),
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    }];
    let design = build_smooth_design(data.view(), &terms).expect("joint duchon build");
    assert_eq!(design.total_smooth_cols(), 4);
    let metadata = &design.terms[0].metadata;
    if let BasisMetadata::Duchon {
        identifiability_transform,
        ..
    } = metadata
    {
        assert!(
            identifiability_transform.is_some(),
            "raw smooth build should freeze Duchon orthogonality once the basis is built"
        );
    } else {
        panic!("expected Duchon metadata, got {metadata:?}");
    }
}
#[test]
fn term_collection_joint_duchon_carries_frozen_transform_into_metadata() {
    // 12 x 4 deterministic affine grid of inputs.
    let data =
        Array2::<f64>::from_shape_fn((12, 4), |(i, j)| (i as f64) * 0.13 + (j as f64) * 0.17);
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon_joint".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1, 2, 3],
                spec: DuchonBasisSpec {
                    power: 3,
                    length_scale: Some(1.0),
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    nullspace_order: DuchonNullspaceOrder::Zero,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design =
        build_term_collection_design(data.view(), &spec).expect("term collection design");
    let term = &design.smooth.terms[0];
    // The constrained term drops one column: 4 raw -> 3 coefficients, so the
    // stored transform must map 4 rows onto 3 columns.
    assert_eq!(term.coeff_range.len(), 3);
    match &term.metadata {
        BasisMetadata::Duchon {
            identifiability_transform,
            ..
        } => {
            let z = identifiability_transform
                .as_ref()
                .expect("term collection should store frozen Duchon transform");
            assert_eq!(z.nrows(), 4);
            assert_eq!(z.ncols(), 3);
        }
        other => panic!("expected Duchon metadata, got {other:?}"),
    }
}
#[test]
fn adaptive_cache_respects_frozen_joint_duchon_transform() {
    // 12 x 4 deterministic affine grid of inputs.
    let data =
        Array2::<f64>::from_shape_fn((12, 4), |(i, j)| (i as f64) * 0.13 + (j as f64) * 0.17);
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon_joint".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1, 2, 3],
                spec: DuchonBasisSpec {
                    power: 3,
                    length_scale: Some(1.0),
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    nullspace_order: DuchonNullspaceOrder::Zero,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design =
        build_term_collection_design(data.view(), &spec).expect("term collection design");
    let caches =
        extract_spatial_operator_runtime_caches(&spec, &design).expect("adaptive caches");
    // One spatial term yields one cache whose global coefficient range mirrors
    // the design's coefficient range for that term.
    assert_eq!(caches.len(), 1);
    assert_eq!(
        caches[0].coeff_global_range.len(),
        design.smooth.terms[0].coeff_range.len()
    );
}
#[test]
fn frozen_joint_duchonspec_rebuild_keeps_adaptive_cache_in_sync() {
    // 12 x 4 deterministic affine grid of inputs.
    let data =
        Array2::<f64>::from_shape_fn((12, 4), |(i, j)| (i as f64) * 0.13 + (j as f64) * 0.17);
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon_joint".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1, 2, 3],
                spec: DuchonBasisSpec {
                    power: 3,
                    length_scale: Some(1.0),
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    nullspace_order: DuchonNullspaceOrder::Zero,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    // Build, freeze, rebuild: the frozen spec must reproduce a design whose
    // adaptive runtime caches still line up with the term layout.
    let base = build_term_collection_design(data.view(), &spec).expect("base design");
    let frozen = freeze_term_collection_from_design(&spec, &base).expect("freeze spec");
    let rebuilt = build_term_collection_design(data.view(), &frozen).expect("rebuilt design");
    let caches =
        extract_spatial_operator_runtime_caches(&frozen, &rebuilt).expect("adaptive caches");
    assert_eq!(caches.len(), 1);
    assert_eq!(caches[0].termname, "duchon_joint");
    assert_eq!(rebuilt.smooth.terms[0].coeff_range.len(), 3);
}
#[test]
fn frozen_joint_maternspec_rebuild_keeps_adaptive_cache_in_sync() {
    let n = 12usize;
    // Two deterministic feature columns: linear ramp and a sinusoid.
    let data = Array2::<f64>::from_shape_fn((n, 2), |(i, j)| {
        if j == 0 {
            i as f64 * 0.13
        } else {
            (i as f64 * 0.17).sin()
        }
    });
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern_joint".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    length_scale: 1.0,
                    nu: MaternNu::FiveHalves,
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    // Freeze-and-rebuild round trip must keep the runtime caches aligned with
    // the rebuilt term layout.
    let base = build_term_collection_design(data.view(), &spec).expect("base design");
    let frozen = freeze_term_collection_from_design(&spec, &base).expect("freeze spec");
    let rebuilt = build_term_collection_design(data.view(), &frozen).expect("rebuilt design");
    let caches =
        extract_spatial_operator_runtime_caches(&frozen, &rebuilt).expect("adaptive caches");
    assert_eq!(caches.len(), 1);
    assert_eq!(caches[0].termname, "matern_joint");
    assert_eq!(rebuilt.smooth.terms.len(), 1);
    assert!(!rebuilt.smooth.terms[0].coeff_range.is_empty());
}
#[test]
fn tensor_bspline_term_builds_te_style_design_and_penalties() {
    let n = 10usize;
    // x on a uniform [0, 1] grid; y = x^2.
    let data = Array2::<f64>::from_shape_fn((n, 2), |(i, j)| {
        let t = i as f64 / (n as f64 - 1.0);
        if j == 0 { t } else { t.powi(2) }
    });
    // Both marginals share degree/penalty order; only the knot count differs.
    let marginal = |num_internal_knots| BSplineBasisSpec {
        degree: 3,
        penalty_order: 2,
        knotspec: BSplineKnotSpec::Generate {
            data_range: (0.0, 1.0),
            num_internal_knots,
        },
        double_penalty: false,
        identifiability: BSplineIdentifiability::default(),
    };
    let terms = vec![SmoothTermSpec {
        name: "te_xy".to_string(),
        basis: SmoothBasisSpec::TensorBSpline {
            feature_cols: vec![0, 1],
            spec: TensorBSplineSpec {
                marginalspecs: vec![marginal(3), marginal(2)],
                double_penalty: true,
                identifiability: TensorBSplineIdentifiability::default(),
            },
        },
        shape: ShapeConstraint::None,
    }];
    let sd = build_smooth_design(data.view(), &terms).unwrap();
    assert_eq!(sd.nrows(), n);
    assert_eq!(sd.terms.len(), 1);
    assert_eq!(sd.penalties.len(), 3);
    assert_eq!(sd.nullspace_dims.len(), 3);
    // Each block penalty is square on its own block and lies inside the design.
    for bp in &sd.penalties {
        assert!(bp.local.nrows() == bp.block_size());
        assert!(bp.col_range.end <= sd.total_smooth_cols());
    }
}
#[test]
fn tensor_bspline_design_matches_extended_marginal_kronecker_product() {
    // Points deliberately outside [0, 1] exercise the knot-range extension path.
    let data = array![[-0.2, 0.1], [0.2, 0.4], [0.7, 0.8], [1.2, 1.1],];
    // Unconstrained marginals; only the knot count differs between axes.
    let marginal = |num_internal_knots| BSplineBasisSpec {
        degree: 3,
        penalty_order: 2,
        knotspec: BSplineKnotSpec::Generate {
            data_range: (0.0, 1.0),
            num_internal_knots,
        },
        double_penalty: false,
        identifiability: BSplineIdentifiability::None,
    };
    let spec_x = marginal(3);
    let spec_y = marginal(2);
    let mx = build_bspline_basis_1d(data.column(0), &spec_x)
        .unwrap()
        .design
        .to_dense();
    let my = build_bspline_basis_1d(data.column(1), &spec_y)
        .unwrap()
        .design
        .to_dense();
    let expected = tensor_product_design_from_marginals(&[mx, my]).unwrap();
    let term = SmoothTermSpec {
        name: "te_xy".to_string(),
        basis: SmoothBasisSpec::TensorBSpline {
            feature_cols: vec![0, 1],
            spec: TensorBSplineSpec {
                marginalspecs: vec![spec_x, spec_y],
                double_penalty: false,
                identifiability: TensorBSplineIdentifiability::None,
            },
        },
        shape: ShapeConstraint::None,
    };
    let got = build_smooth_design(data.view(), &[term])
        .unwrap()
        .term_designs
        .into_iter()
        .next()
        .unwrap()
        .to_dense();
    // The tensor term must reproduce the explicit row-wise Kronecker construction.
    assert_eq!(got.dim(), expected.dim());
    for ((i, j), &want) in expected.indexed_iter() {
        assert!((got[[i, j]] - want).abs() < 1e-10);
    }
}
#[test]
fn tensor_bspline_design_is_identifiable_against_global_intercept() {
    let n = 120usize;
    // Deterministic inputs: ramp and a sinusoid of it.
    let data = Array2::<f64>::from_shape_fn((n, 2), |(i, j)| {
        let t = i as f64 / (n as f64 - 1.0);
        if j == 0 { t } else { (3.0 * t).sin() }
    });
    // The two marginals differ only in their data range.
    let marginal = |data_range| BSplineBasisSpec {
        degree: 3,
        penalty_order: 2,
        knotspec: BSplineKnotSpec::Generate {
            data_range,
            num_internal_knots: 6,
        },
        double_penalty: false,
        identifiability: BSplineIdentifiability::default(),
    };
    let tensor_term = SmoothTermSpec {
        name: "te_xy".to_string(),
        basis: SmoothBasisSpec::TensorBSpline {
            feature_cols: vec![0, 1],
            spec: TensorBSplineSpec {
                marginalspecs: vec![marginal((0.0, 1.0)), marginal((-1.0, 1.0))],
                double_penalty: false,
                identifiability: TensorBSplineIdentifiability::default(),
            },
        },
        shape: ShapeConstraint::None,
    };
    let sd = build_smooth_design(data.view(), &[tensor_term.clone()]).unwrap();
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![tensor_term],
    };
    let full = build_term_collection_design(data.view(), &spec).unwrap();
    let ones = Array1::<f64>::ones(n);
    // Assemble the smooth design by horizontally concatenating the term blocks.
    let dense_terms: Vec<_> = sd.term_designs.iter().map(|d| d.to_dense()).collect();
    let views: Vec<_> = dense_terms.iter().map(|d| d.view()).collect();
    let assembled = ndarray::concatenate(ndarray::Axis(1), &views).unwrap();
    // The tensor block alone must leave a visible residual against the constant
    // vector, while the full term-collection design reproduces it essentially
    // exactly.
    assert!(residual_norm_to_column_space(&assembled, &ones) > 1e-6);
    assert!(residual_norm_to_column_space(&full.design.to_dense(), &ones) < 1e-8);
}
#[test]
fn spatial_length_scale_optimization_monotone_improves_or_keeps_score_for_matern_two_feature() {
// Fits a three-feature Matérn smooth whose starting length scale (20.0) is
// deliberately poor, then reruns the fit through the spatial length-scale
// optimizer and checks: (1) the score never worsens (compared lower-is-better
// per the assertion below), (2) the resolved length scale is finite and stays
// inside the configured [1e-3, 1e3] bounds, and (3) the resolved spec comes
// back frozen (UserProvided centers + FrozenTransform identifiability) rather
// than the FarthestPoint/CenterSumToZero it started with.
let n = 60usize;
let d = 3usize;
let mut data = Array2::<f64>::zeros((n, d));
let mut y = Array1::<f64>::zeros(n);
// Deterministic features (ramp, sine, cosine) and a smooth response over them.
for i in 0..n {
let x0 = i as f64 / (n as f64 - 1.0);
let x1 = (i as f64 * 0.13).sin();
let x2 = (i as f64 * 0.07).cos();
data[[i, 0]] = x0;
data[[i, 1]] = x1;
data[[i, 2]] = x2;
y[i] = (2.5 * x0).sin() + 0.4 * x1 - 0.2 * x2;
}
// A single Matérn term over all three features; length_scale = 20.0 is the
// intentionally bad initial value the optimizer should improve on (or keep).
let spec = TermCollectionSpec {
linear_terms: vec![],
random_effect_terms: vec![],
smooth_terms: vec![SmoothTermSpec {
name: "matern".to_string(),
basis: SmoothBasisSpec::Matern {
feature_cols: vec![0, 1, 2],
spec: MaternBasisSpec {
center_strategy: CenterStrategy::FarthestPoint { num_centers: 12 },
length_scale: 20.0,
nu: MaternNu::FiveHalves,
include_intercept: false,
double_penalty: true,
identifiability: MaternIdentifiability::CenterSumToZero,
aniso_log_scales: None,
},
input_scales: None,
},
shape: ShapeConstraint::None,
}],
};
// Plain fit options: inference on, no link/mixture extras, a small penalty
// shrinkage floor, moderate iteration budget.
let fit_opts = FitOptions {
latent_cloglog: None,
mixture_link: None,
optimize_mixture: false,
sas_link: None,
optimize_sas: false,
compute_inference: true,
max_iter: 40,
tol: 1e-6,
nullspace_dims: vec![],
linear_constraints: None,
firth_bias_reduction: false,
adaptive_regularization: None,
penalty_shrinkage_floor: Some(1e-6),
rho_prior: Default::default(),
kronecker_penalty_system: None,
kronecker_factored: None,
};
let weights = Array1::ones(n);
let offset = Array1::zeros(n);
// Reference fit at the fixed (bad) length scale.
let baseline = fit_term_collection_forspec(
data.view(),
y.view(),
weights.view(),
offset.view(),
&spec,
LikelihoodFamily::GaussianIdentity,
&fit_opts,
)
.expect("baseline fit should succeed");
let baseline_score = fit_score(&baseline.fit);
// Same fit with the outer length-scale (κ) optimization enabled: 2 outer
// iterations, factor-of-2 log steps, search bounds [1e-3, 1e3].
let optimized = fit_term_collectionwith_spatial_length_scale_optimization(
data.view(),
y.clone(),
weights.clone(),
offset.clone(),
&spec,
LikelihoodFamily::GaussianIdentity,
&fit_opts,
&SpatialLengthScaleOptimizationOptions {
enabled: true,
max_outer_iter: 2,
rel_tol: 1e-5,
log_step: std::f64::consts::LN_2,
min_length_scale: 1e-3,
max_length_scale: 1e3,
pilot_subsample_threshold: 0,
},
)
.expect("optimized fit should succeed");
let optimized_score = fit_score(&optimized.fit);
// Monotonicity: the optimizer may keep the baseline but must never worsen it
// (small slack for floating-point noise).
assert!(optimized_score <= baseline_score + 1e-10);
// The resolved length scale must be finite and within the search bounds.
let ls = match &optimized.resolvedspec.smooth_terms[0].basis {
SmoothBasisSpec::Matern { spec, .. } => spec.length_scale,
_ => panic!("expected Matérn term"),
};
assert!(ls.is_finite() && (1e-3..=1e3).contains(&ls));
// The resolved spec must come back frozen: concrete user-provided centers and
// a frozen identifiability transform.
match &optimized.resolvedspec.smooth_terms[0].basis {
SmoothBasisSpec::Matern { spec, .. } => {
assert!(matches!(
spec.center_strategy,
CenterStrategy::UserProvided(_)
));
assert!(matches!(
spec.identifiability,
MaternIdentifiability::FrozenTransform { .. }
));
}
_ => panic!("expected Matérn term"),
}
}
#[test]
fn exact_joint_two_block_spatial_length_scale_freezes_matern_centers() {
// Runs optimize_spatial_length_scale_exact_joint over two blocks (mean + noise)
// built from the same Matérn structure, with stub evaluator closures that only
// check the shapes they are handed and return cheap dummy values. The check is
// purely structural: after the joint κ pass, both resolved specs must carry
// frozen centers (UserProvided) and a FrozenTransform identifiability.
let n = 40usize;
let mut data = Array2::<f64>::zeros((n, 2));
// Deterministic two-feature inputs: ramp and sinusoid.
for i in 0..n {
let x0 = i as f64 / (n as f64 - 1.0);
let x1 = (i as f64 * 0.21).sin();
data[[i, 0]] = x0;
data[[i, 1]] = x1;
}
// Shared Matérn term template; only the name and starting length scale differ
// between the two blocks.
let matern_term = |name: &str, length_scale: f64| SmoothTermSpec {
name: name.to_string(),
basis: SmoothBasisSpec::Matern {
feature_cols: vec![0, 1],
spec: MaternBasisSpec {
center_strategy: CenterStrategy::FarthestPoint { num_centers: 8 },
length_scale,
nu: MaternNu::FiveHalves,
include_intercept: false,
double_penalty: true,
identifiability: MaternIdentifiability::CenterSumToZero,
aniso_log_scales: None,
},
input_scales: None,
},
shape: ShapeConstraint::None,
};
let meanspec = TermCollectionSpec {
linear_terms: vec![],
random_effect_terms: vec![],
smooth_terms: vec![matern_term("mean_matern", 0.8)],
};
let noisespec = TermCollectionSpec {
linear_terms: vec![],
random_effect_terms: vec![],
smooth_terms: vec![matern_term("noise_matern", 1.1)],
};
// Single outer κ iteration, factor-of-2 log steps, bounds [1e-3, 1e3].
let kappa_options = SpatialLengthScaleOptimizationOptions {
enabled: true,
max_outer_iter: 1,
rel_tol: 1e-6,
log_step: std::f64::consts::LN_2,
min_length_scale: 1e-3,
max_length_scale: 1e3,
pilot_subsample_threshold: 0,
};
let joint_setup = two_block_exact_joint_hyper_setup(&meanspec, &noisespec, &kappa_options);
let theta_dim = joint_setup.theta0().len();
let mean_terms = spatial_length_scale_term_indices(&meanspec);
let noise_terms = spatial_length_scale_term_indices(&noisespec);
// Declares exact second-order outer derivatives with zero predicted work and
// no subsampling capability.
let policy = crate::families::custom_family::OuterDerivativePolicy {
capability: crate::families::custom_family::ExactOuterDerivativeOrder::Second,
predicted_hessian_work: 0,
predicted_gradient_work: 0,
subsample_capable: false,
};
let solved = optimize_spatial_length_scale_exact_joint(
data.view(),
&[meanspec.clone(), noisespec.clone()],
&[mean_terms, noise_terms],
&kappa_options,
&joint_setup,
crate::seeding::SeedRiskProfile::Gaussian,
true,
true,
false,
None,
policy,
// Pilot objective stub: validates θ/spec arity and returns a deterministic
// scalar assembled from the block designs' dimensions.
|theta, specs, designs| {
assert_eq!(theta.len(), theta_dim);
assert_eq!(specs.len(), 2);
Ok(designs[0].design.ncols() as f64
+ designs[1].design.ncols() as f64
+ designs[0].penalties.len() as f64
+ designs[1].penalties.len() as f64)
},
// Outer evaluation stub: zero cost and gradient; hands back an all-zero
// analytic Hessian only when a Hessian-order evaluation is requested.
|theta, specs, designs, eval_mode, _row_set| {
assert_eq!(theta.len(), theta_dim);
assert_eq!(specs.len(), 2);
assert!(!designs.is_empty());
Ok((
0.0,
Array1::zeros(theta_dim),
if matches!(
eval_mode,
crate::solver::estimate::reml::unified::EvalMode::ValueGradientHessian
) {
crate::solver::outer_strategy::HessianResult::Analytic(Array2::zeros((
theta_dim, theta_dim,
)))
} else {
crate::solver::outer_strategy::HessianResult::Unavailable
},
))
},
// EFS evaluation stub: zero cost, zero steps, no β/ψ extras.
|theta, specs, designs| {
assert_eq!(theta.len(), theta_dim);
assert_eq!(specs.len(), 2);
assert!(!designs.is_empty());
Ok(crate::solver::outer_strategy::EfsEval {
cost: 0.0,
steps: vec![0.0; theta_dim],
beta: None,
psi_gradient: None,
psi_indices: None,
})
},
)
.expect("exact joint two-block κ optimization should succeed");
// Both blocks must come back frozen, regardless of the stubbed objective values.
for resolved in [&solved.resolved_specs[0], &solved.resolved_specs[1]] {
match &resolved.smooth_terms[0].basis {
SmoothBasisSpec::Matern { spec, .. } => {
assert!(matches!(
spec.center_strategy,
CenterStrategy::UserProvided(_)
));
assert!(matches!(
spec.identifiability,
MaternIdentifiability::FrozenTransform { .. }
));
}
_ => panic!("expected Matérn term"),
}
}
}
#[test]
fn exact_joint_spatial_outer_hessian_available_for_dense_non_gaussian_designs() {
    // Large, purely linear (dense) design: n just above 5000 rows, p = 100.
    let (n, p) = (5_001usize, 100usize);
    let data = Array2::from_shape_fn((n, p), |(i, j)| ((i + j + 1) as f64).sin());
    let linear_terms = (0..p)
        .map(|j| LinearTermSpec {
            name: format!("x{j}"),
            feature_col: j,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        })
        .collect();
    let spec = TermCollectionSpec {
        linear_terms,
        random_effect_terms: vec![],
        smooth_terms: vec![],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("design");
    // The exact joint outer Hessian must be reported available for both a
    // non-Gaussian and a Gaussian family on this design.
    for family in [
        LikelihoodFamily::BinomialLogit,
        LikelihoodFamily::GaussianIdentity,
    ] {
        assert!(exact_joint_spatial_outer_hessian_available(family, &design));
    }
}
#[test]
fn spatial_aniso_joint_exact_hessian_materializes_small_case() {
// Small anisotropic Matérn case (two log-κ axes via aniso_log_scales) run
// through the exact joint outer evaluation: build + freeze the design, assemble
// θ = [ρ | log-κ], request value+gradient+Hessian, and verify the materialized
// Hessian is finite, symmetric, and carries non-zero curvature in the ψ block.
let n = 18usize;
let mut data = Array2::<f64>::zeros((n, 2));
// Deterministic inputs: ramp plus a sinusoid.
for i in 0..n {
let t = i as f64 / (n as f64 - 1.0);
data[[i, 0]] = t;
data[[i, 1]] = (0.41 * i as f64).sin();
}
// Smooth sinusoidal response with a constant offset.
let y = Array1::from_iter((0..n).map(|i| {
let t = i as f64 / (n as f64 - 1.0);
0.4 + (2.0 * std::f64::consts::PI * t).sin()
}));
let weights = Array1::ones(n);
let offset = Array1::zeros(n);
// Anisotropic Matérn: aniso_log_scales supplies one log scale per feature axis.
let spec = TermCollectionSpec {
linear_terms: vec![],
random_effect_terms: vec![],
smooth_terms: vec![SmoothTermSpec {
name: "matern_aniso".to_string(),
basis: SmoothBasisSpec::Matern {
feature_cols: vec![0, 1],
spec: MaternBasisSpec {
center_strategy: CenterStrategy::FarthestPoint { num_centers: 5 },
length_scale: 0.85,
nu: MaternNu::FiveHalves,
include_intercept: false,
double_penalty: true,
identifiability: MaternIdentifiability::CenterSumToZero,
aniso_log_scales: Some(vec![0.2, -0.2]),
},
input_scales: None,
},
shape: ShapeConstraint::None,
}],
};
// Tight inner solve, inference off — only outer-Hessian behavior matters here.
let fit_opts = FitOptions {
latent_cloglog: None,
mixture_link: None,
optimize_mixture: false,
sas_link: None,
optimize_sas: false,
compute_inference: false,
max_iter: 120,
tol: 1e-10,
nullspace_dims: vec![],
linear_constraints: None,
firth_bias_reduction: false,
adaptive_regularization: None,
penalty_shrinkage_floor: None,
rho_prior: Default::default(),
kronecker_penalty_system: None,
kronecker_factored: None,
};
let design = build_term_collection_design(data.view(), &spec).expect("design");
let frozen = freeze_term_collection_from_design(&spec, &design).expect("freeze");
let spatial_terms = spatial_length_scale_term_indices(&frozen);
let dims_per_term = spatial_dims_per_term(&frozen, &spatial_terms);
// Two feature axes -> two log-κ coordinates for the single aniso term.
assert_eq!(dims_per_term, vec![2]);
let rho_dim = design.penalties.len();
// θ layout: [ρ (one per penalty) | ψ (log-κ)], with ψ seeded from the frozen
// spec's anisotropic length scales and ρ staggered around zero.
let log_kappa0 = SpatialLogKappaCoords::from_length_scales_aniso(
&frozen,
&spatial_terms,
&SpatialLengthScaleOptimizationOptions::default(),
);
let mut theta = Array1::<f64>::zeros(rho_dim + log_kappa0.as_array().len());
for j in 0..rho_dim {
theta[j] = -0.15 + 0.07 * j as f64;
}
theta.slice_mut(s![rho_dim..]).assign(log_kappa0.as_array());
let external_opts =
external_opts_for_design(LikelihoodFamily::GaussianIdentity, &design, &fit_opts);
let mut cache = SingleBlockExactJointDesignCache::new(
data.view(),
frozen,
design.clone(),
spatial_terms,
rho_dim,
dims_per_term,
)
.expect("single-block cache");
let mut evaluator = crate::estimate::ExternalJointHyperEvaluator::new(
y.view(),
weights.view(),
&design.design,
offset.view(),
&design.penalties,
&external_opts,
"small aniso Hessian finite-difference evaluator",
)
.expect("evaluator");
// Applies θ to the cached design, builds the spatial log-κ hyper-direction
// inputs at that θ, and runs the joint outer evaluation at the requested order.
let eval_at = |theta: &Array1<f64>,
cache: &mut SingleBlockExactJointDesignCache<'_>,
evaluator: &mut crate::estimate::ExternalJointHyperEvaluator<'_>,
order: crate::solver::outer_strategy::OuterEvalOrder| {
cache.ensure_theta(theta).expect("theta applied");
let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
data.view(),
cache.spec(),
cache.design(),
&cache.spatial_terms,
)
.expect("hyper dirs build")
.expect("hyper dirs present");
evaluate_joint_reml_outer_eval_at_theta(
evaluator,
cache.design(),
theta,
rho_dim,
hyper_dirs,
None,
order,
)
.expect("outer eval")
};
let (_, gradient, hessian_result) = eval_at(
&theta,
&mut cache,
&mut evaluator,
crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian,
);
let hessian = hessian_result
.materialize_dense()
.expect("hessian materializes")
.expect("hessian present");
// Shape, finiteness, and symmetry of the full θ-Hessian.
assert_eq!(hessian.nrows(), theta.len());
assert_eq!(hessian.ncols(), theta.len());
assert!(hessian.iter().all(|value| value.is_finite()));
assert!(gradient.iter().all(|value| value.is_finite()));
let symmetry_diff = max_abs_diff_matrix(&hessian, &hessian.t().to_owned());
assert!(
symmetry_diff <= 1e-10,
"small aniso exact Hessian should be symmetric, max diff={symmetry_diff}"
);
// The ψ×ψ sub-block must carry genuine curvature, not just zeros.
let psi_block = hessian.slice(s![rho_dim.., rho_dim..]).to_owned();
assert!(
psi_block.iter().any(|value| value.abs() > 1e-10),
"small aniso exact Hessian should carry non-zero ψ curvature"
);
}
#[test]
fn iso_kappa_duchon_binomial_probit_joint_gradient_matches_finite_difference() {
    // Validates the analytic joint outer gradient over θ = [ρ | ψ] (smoothing
    // log-parameters followed by the iso-κ log length-scale coordinate) against
    // central finite differences of the cost-only evaluation, for a 1-D Duchon
    // smooth under a Binomial/probit likelihood.
    let n = 80usize;
    let mut data = Array2::<f64>::zeros((n, 1));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let t = i as f64 / (n as f64 - 1.0);
        data[[i, 0]] = t;
        // Deterministic pseudo-noise (sinusoid in i), thresholded to {0, 1}.
        let eta = 1.4 * (2.0 * std::f64::consts::PI * t).sin() + 0.5 * (t - 0.5);
        y[i] = if eta + 0.7 * (3.7 * (i as f64) + 1.0).sin() > 0.0 {
            1.0
        } else {
            0.0
        };
    }
    let weights = Array1::ones(n);
    let offset = Array1::zeros(n);
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon_1d".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 8 },
                    length_scale: Some(1.0),
                    power: 1,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    // Very tight inner tolerances so inner-solve error does not pollute the
    // finite-difference comparison.
    let fit_opts = FitOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        compute_inference: false,
        max_iter: 200,
        tol: 1e-12,
        nullspace_dims: vec![],
        linear_constraints: None,
        firth_bias_reduction: false,
        adaptive_regularization: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    let design = build_term_collection_design(data.view(), &spec).expect("design");
    let frozen = freeze_term_collection_from_design(&spec, &design).expect("freeze");
    let spatial_terms = spatial_length_scale_term_indices(&frozen);
    let dims_per_term = spatial_dims_per_term(&frozen, &spatial_terms);
    assert_eq!(dims_per_term, vec![1], "1D Duchon should expose one log-κ");
    let rho_dim = design.penalties.len();
    let psi_dim: usize = dims_per_term.iter().sum();
    assert!(psi_dim >= 1, "test requires at least one log-κ axis");
    let external_opts =
        external_opts_for_design(LikelihoodFamily::BinomialProbit, &design, &fit_opts);
    let mut cache = SingleBlockExactJointDesignCache::new(
        data.view(),
        frozen.clone(),
        design.clone(),
        spatial_terms.clone(),
        rho_dim,
        dims_per_term.clone(),
    )
    .expect("single-block cache");
    let mut evaluator = crate::estimate::ExternalJointHyperEvaluator::new(
        y.view(),
        weights.view(),
        &design.design,
        offset.view(),
        &design.penalties,
        &external_opts,
        "iso-kappa Duchon BinomialProbit FD evaluator",
    )
    .expect("evaluator");
    // Cost-only evaluation used as the finite-difference oracle.
    let cost_at = |theta: &Array1<f64>,
                   cache: &mut SingleBlockExactJointDesignCache<'_>,
                   evaluator: &mut crate::estimate::ExternalJointHyperEvaluator<'_>|
     -> f64 {
        cache.ensure_theta(theta).expect("ensure_theta");
        let design = cache.design();
        evaluator
            .evaluate_cost_only(
                &design.design,
                &design.penalties,
                &design.nullspace_dims,
                design.linear_constraints.clone(),
                theta,
                rho_dim,
                None,
                "iso-kappa Duchon FD cost-only",
            )
            .expect("cost-only eval")
    };
    // Analytic cost + gradient at the same θ via the joint outer evaluation.
    let analytic_at = |theta: &Array1<f64>,
                       cache: &mut SingleBlockExactJointDesignCache<'_>,
                       evaluator: &mut crate::estimate::ExternalJointHyperEvaluator<'_>|
     -> (f64, Array1<f64>) {
        cache.ensure_theta(theta).expect("ensure_theta");
        let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
            data.view(),
            cache.spec(),
            cache.design(),
            &cache.spatial_terms,
        )
        .expect("hyper dirs build")
        .expect("hyper dirs present");
        let (cost, grad, _hess) = evaluate_joint_reml_outer_eval_at_theta(
            evaluator,
            cache.design(),
            theta,
            rho_dim,
            hyper_dirs,
            None,
            crate::solver::outer_strategy::OuterEvalOrder::ValueAndGradient,
        )
        .expect("outer eval");
        (cost, grad)
    };
    let theta_dim = rho_dim + psi_dim;
    // Probe points: the origin, a ψ-only shift, a ρ-only base point, and a
    // combined shift. (`Array1::zeros` already zero-fills, so no explicit
    // zero-writing loops are needed.)
    let theta_zero = Array1::<f64>::zeros(theta_dim);
    let mut theta_base = Array1::<f64>::zeros(theta_dim);
    for j in 0..rho_dim {
        theta_base[j] = 0.2 - 0.1 * j as f64;
    }
    let mut theta_psi_only = Array1::<f64>::zeros(theta_dim);
    for k in 0..psi_dim {
        theta_psi_only[rho_dim + k] = 0.4;
    }
    let mut theta_alt = theta_base.clone();
    for j in 0..rho_dim {
        theta_alt[j] = 1.0 + 0.05 * j as f64;
    }
    for k in 0..psi_dim {
        theta_alt[rho_dim + k] = 0.4;
    }
    let h = 1e-5_f64;
    let rel_tol = 5e-3_f64;
    // Collect all disagreements so a failure reports every bad coordinate at once.
    let mut violations: Vec<String> = Vec::new();
    for (label, theta) in [
        ("zero", &theta_zero),
        ("psi_only", &theta_psi_only),
        ("base", &theta_base),
        ("alt", &theta_alt),
    ] {
        let (cost_an, grad_an) = analytic_at(theta, &mut cache, &mut evaluator);
        assert!(
            cost_an.is_finite(),
            "analytic cost not finite at {label}: cost={cost_an:?}"
        );
        eprintln!(
            "[fd_iso_kappa_duchon {label}] cost_analytic={cost_an:+.6e} \
             grad_norm_analytic={:.6e}",
            grad_an.iter().map(|g| g * g).sum::<f64>().sqrt()
        );
        // Central finite difference in each θ coordinate.
        for j in 0..theta_dim {
            let mut plus = theta.clone();
            plus[j] += h;
            let mut minus = theta.clone();
            minus[j] -= h;
            let cp = cost_at(&plus, &mut cache, &mut evaluator);
            let cm = cost_at(&minus, &mut cache, &mut evaluator);
            let fd = (cp - cm) / (2.0 * h);
            // Relative error with a floor so near-zero gradients don't blow it up.
            let denom = fd.abs().max(grad_an[j].abs()).max(1e-3);
            let rel = (grad_an[j] - fd).abs() / denom;
            let kind = if j < rho_dim { "rho" } else { "psi" };
            eprintln!(
                "[fd_iso_kappa_duchon {label}] {kind} j={j} analytic={:+.6e} fd={:+.6e} \
                 cp={:+.12e} cm={:+.12e} dcost={:+.6e} abs_err={:.3e} rel_err={:.3e}",
                grad_an[j],
                fd,
                cp,
                cm,
                cp - cm,
                (grad_an[j] - fd).abs(),
                rel
            );
            assert!(
                cp.is_finite() && cm.is_finite(),
                "non-finite cost in FD at {label} j={j}: cp={cp}, cm={cm}"
            );
            if rel >= rel_tol {
                violations.push(format!(
                    "{label} {kind} j={j}: analytic={:+.6e}, fd={:+.6e}, rel_err={:.3e}",
                    grad_an[j], fd, rel
                ));
            }
        }
    }
    assert!(
        violations.is_empty(),
        "FD/analytic gradient disagreement at rel_tol={rel_tol:.1e}:\n {}",
        violations.join("\n ")
    );
}
#[test]
fn spatial_aniso_joint_large_psi_dim_keeps_second_order_route() {
    use crate::solver::outer_strategy as os;
    // A large ψ share (31 of 40 parameters) must still plan the second-order
    // (Arc + analytic-Hessian) route, and the routing log must say matrix-free.
    let cap = os::OuterCapability {
        gradient: os::Derivative::Analytic,
        hessian: os::DeclaredHessianForm::Either,
        n_params: 40,
        psi_dim: 31,
        fixed_point_available: true,
        barrier_config: None,
        prefer_gradient_only: false,
        disable_fixed_point: false,
    };
    let route = os::plan(&cap);
    assert_eq!(route.solver, os::Solver::Arc);
    assert_eq!(route.hessian_source, os::HessianSource::Analytic);
    assert!(route.routing_log_line().contains("matrix-free=true"));
}
#[test]
fn exact_joint_spatial_outer_hessian_available_for_sparse_designs() {
    // A 1-D cubic B-spline with many internal knots yields a sparse design.
    let n = 96usize;
    let mut data = Array2::<f64>::zeros((n, 1));
    data.column_mut(0).assign(&Array1::linspace(0.0, 1.0, n));
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "s_x".to_string(),
            basis: SmoothBasisSpec::BSpline1D {
                feature_col: 0,
                spec: BSplineBasisSpec {
                    degree: 3,
                    penalty_order: 2,
                    knotspec: BSplineKnotSpec::Generate {
                        data_range: (0.0, 1.0),
                        num_internal_knots: 32,
                    },
                    identifiability: BSplineIdentifiability::default(),
                    double_penalty: false,
                },
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("design");
    // The design must actually come out sparse, and the exact joint outer
    // Hessian must still be reported available for a non-Gaussian family.
    assert!(matches!(design.design, DesignMatrix::Sparse(_)));
    assert!(exact_joint_spatial_outer_hessian_available(
        LikelihoodFamily::BinomialLogit,
        &design,
    ));
}
#[test]
fn iso_kappa_duchon_dx_dpsi_matches_fd() {
    // Checks the implicit-operator derivative dX/dψ (ψ = iso log-κ) of a 1-D
    // Duchon design against central finite differences of a full design rebuild,
    // in both forward (column) and transpose (v-product) form.
    let n = 80usize;
    let mut data = Array2::<f64>::zeros((n, 1));
    for i in 0..n {
        data[[i, 0]] = i as f64 / (n as f64 - 1.0);
    }
    let spec_orig = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon_1d".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 8 },
                    length_scale: Some(1.0),
                    power: 1,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec_orig).expect("design");
    let frozen = freeze_term_collection_from_design(&spec_orig, &design).expect("freeze");
    // FD oracle: rebuild the frozen design at length scale exp(-ψ).
    let build_design_at = |psi: f64| -> Array2<f64> {
        let mut s = frozen.clone();
        if let SmoothBasisSpec::Duchon {
            spec: ref mut duchon, ..
        } = s.smooth_terms[0].basis
        {
            duchon.length_scale = Some((-psi).exp());
        }
        let d = build_term_collection_design(data.view(), &s).expect("rebuild");
        d.design.to_dense()
    };
    let psi_eval = 0.0_f64;
    let duchon_spec = if let SmoothBasisSpec::Duchon { spec: ref s, .. } =
        frozen.smooth_terms[0].basis
    {
        s.clone()
    } else {
        panic!("expected Duchon");
    };
    let mut duchon_spec_at = duchon_spec.clone();
    duchon_spec_at.length_scale = Some((-psi_eval).exp());
    let bundle = crate::basis::build_duchon_basis_log_kappa_derivatives(
        data.view(),
        &duchon_spec_at,
    )
    .expect("derivatives");
    let op = bundle.implicit_operator.expect("implicit operator");
    let p = op.p_out();
    let h = 1e-4_f64;
    let x_plus = build_design_at(psi_eval + h);
    let x_minus = build_design_at(psi_eval - h);
    eprintln!(
        "[DXDPSI_FD] X(+h)[0,0..3]={:?} X(-h)[0,0..3]={:?}",
        x_plus.row(0).iter().take(3).copied().collect::<Vec<_>>(),
        x_minus.row(0).iter().take(3).copied().collect::<Vec<_>>(),
    );
    eprintln!(
        "[DXDPSI_FD] X(+h) shape={:?} X(-h) shape={:?} p_out={}",
        x_plus.shape(), x_minus.shape(), p,
    );
    let x_at = build_design_at(psi_eval);
    let orig_design = build_term_collection_design(data.view(), &spec_orig).expect("rebuild orig");
    eprintln!(
        "[DXDPSI_FD] X(psi_eval) shape={:?} orig_design.ncols={}",
        x_at.shape(),
        orig_design.design.ncols(),
    );
    // Materialize the analytic dX/dψ column by column via forward_mul on unit
    // vectors (p columns on the operator's output side).
    let mut analytic = Array2::<f64>::zeros((n, p));
    let mut basisv = Array1::<f64>::zeros(p);
    for j in 0..p {
        basisv[j] = 1.0;
        let col = op.forward_mul(0, &basisv.view()).expect("forward_mul");
        analytic.column_mut(j).assign(&col);
        basisv[j] = 0.0;
    }
    // The smooth block is taken to start at column 1 of the rebuilt full design
    // (one leading non-smooth column — TODO confirm against the design layout).
    let smooth_start = 1usize;
    // Transpose check: (dX/dψ)ᵀ v against the FD of Xᵀ applied to the same v.
    let v_test = Array1::<f64>::from_shape_fn(n, |i| (i as f64 * 0.07).sin());
    let analytic_tv = op.transpose_mul(0, &v_test.view()).expect("transpose_mul");
    let fd_tv_full = (&x_plus.t() - &x_minus.t()) / (2.0 * h);
    let fd_tv = fd_tv_full.dot(&v_test);
    let fd_tv_smooth = fd_tv.slice(s![smooth_start..(smooth_start + p)]).to_owned();
    let max_tv_diff = analytic_tv
        .iter()
        .zip(fd_tv_smooth.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f64, f64::max);
    let max_tv_abs = analytic_tv
        .iter()
        .map(|v| v.abs())
        .fold(0.0f64, f64::max);
    eprintln!(
        "[DXDPSI_TV] max|analytic_tv - fd_tv|={:.3e} max|analytic_tv|={:.3e}",
        max_tv_diff, max_tv_abs
    );
    eprintln!(
        "[DXDPSI_TV] analytic_tv={:?}",
        analytic_tv.iter().take(p).copied().collect::<Vec<_>>()
    );
    eprintln!(
        "[DXDPSI_TV] fd_tv_smooth={:?}",
        fd_tv_smooth.iter().take(p).copied().collect::<Vec<_>>()
    );
    // Forward check: element-wise comparison over the smooth block.
    let fd_full = (&x_plus - &x_minus) / (2.0 * h);
    let fd = fd_full.slice(s![.., smooth_start..(smooth_start + p)]).to_owned();
    let mut max_diff = 0.0_f64;
    let mut max_abs = 0.0_f64;
    for i in 0..n {
        for j in 0..p {
            let d = (analytic[[i, j]] - fd[[i, j]]).abs();
            if d > max_diff {
                max_diff = d;
            }
            if analytic[[i, j]].abs() > max_abs {
                max_abs = analytic[[i, j]].abs();
            }
        }
    }
    eprintln!(
        "[DXDPSI_FD] max|analytic - fd|={:.3e} max|analytic|={:.3e}",
        max_diff, max_abs
    );
    eprintln!(
        "[DXDPSI_FD] analytic[0,..]={:?}",
        analytic.row(0).iter().take(p).copied().collect::<Vec<_>>(),
    );
    eprintln!(
        "[DXDPSI_FD] fd[0,..]={:?}",
        fd.row(0).iter().take(p).copied().collect::<Vec<_>>(),
    );
    // Relative tolerance on the max element, floored for the tiny-magnitude case.
    assert!(max_diff < 5e-3 * max_abs.max(1e-3), "dX/dψ mismatch");
}
// Verifies that the joint build-and-freeze path resolves `CenterStrategy::Auto`
// spatial centers once and shares the resolved centers across parameter blocks:
// both blocks' frozen Matérn specs must carry identical `UserProvided` centers,
// and the shared center count must match a separately frozen single-block build.
#[test]
fn joint_build_and_freeze_shares_auto_spatial_centers_across_blocks() {
    let n = 400usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        // Column 0: uniform grid on [0, 1]; column 1: deterministic oscillation.
        data[[i, 0]] = i as f64 / (n as f64 - 1.0);
        data[[i, 1]] = (i as f64 * 0.19).sin();
    }
    // Builds an identical Auto-centered Matérn term under two different names,
    // so both blocks request the same center-selection work.
    let matern_term = |name: &str| SmoothTermSpec {
        name: name.to_string(),
        basis: SmoothBasisSpec::Matern {
            feature_cols: vec![0, 1],
            spec: MaternBasisSpec {
                // Auto wraps FarthestPoint: the joint build must resolve this
                // to concrete centers and freeze them as UserProvided.
                center_strategy: CenterStrategy::Auto(Box::new(
                    CenterStrategy::FarthestPoint { num_centers: 8 },
                )),
                length_scale: 0.8,
                nu: MaternNu::FiveHalves,
                include_intercept: false,
                double_penalty: true,
                identifiability: MaternIdentifiability::CenterSumToZero,
                aniso_log_scales: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    };
    let marginalspec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![matern_term("marginal")],
    };
    let logslopespec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![matern_term("logslope")],
    };
    // Joint build + freeze over both blocks at once.
    let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
        data.view(),
        &[marginalspec.clone(), logslopespec.clone()],
    )
    .expect("joint build and freeze should succeed");
    assert_eq!(designs.len(), 2);
    assert_eq!(resolved_specs.len(), 2);
    // Pulls the frozen center matrix out of a resolved spec; panics if the
    // spec was not fully frozen to UserProvided centers.
    let extract_centers = |spec: &TermCollectionSpec| match &spec.smooth_terms[0].basis {
        SmoothBasisSpec::Matern { spec, .. } => match &spec.center_strategy {
            CenterStrategy::UserProvided(centers) => centers.clone(),
            other => panic!("expected frozen user-provided centers, got {other:?}"),
        },
        other => panic!("expected Matérn term, got {other:?}"),
    };
    let marginal_centers = extract_centers(&resolved_specs[0]);
    let logslope_centers = extract_centers(&resolved_specs[1]);
    // Reference: a standalone (non-joint) build + freeze of the marginal block.
    let separate_marginal_design =
        build_term_collection_design(data.view(), &marginalspec).expect("separate marginal");
    let separate_marginal =
        freeze_term_collection_from_design(&marginalspec, &separate_marginal_design)
            .expect("freeze separate marginal");
    let separate_marginal_centers = extract_centers(&separate_marginal);
    // Centers must be bit-identical across blocks (shared resolution), have the
    // term's spatial dimensionality, and match the standalone center count.
    assert_eq!(marginal_centers, logslope_centers);
    assert_eq!(marginal_centers.ncols(), 2);
    assert_eq!(marginal_centers.nrows(), separate_marginal_centers.nrows());
}
// Exercises the exact-joint optimizer's fast path for a two-block model with NO
// spatial terms (only P-splines + a random effect): the optimizer should still
// return resolved specs that are fully frozen (knots Provided, random-effect
// levels frozen). The three closures are stub cost/eval/EFS callbacks that only
// check the shapes they are handed and return trivial values.
#[test]
fn exact_joint_two_block_no_spatial_fast_path_returns_fully_frozen_specs() {
    let n = 24usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        let t = i as f64 / (n as f64 - 1.0);
        // Column 0: uniform grid; column 1: a 3-level grouping factor.
        data[[i, 0]] = t;
        data[[i, 1]] = (i % 3) as f64;
    }
    // Identical 1-D P-spline term under two names (one per block).
    let pspline_term = |name: &str| SmoothTermSpec {
        name: name.to_string(),
        basis: SmoothBasisSpec::BSpline1D {
            feature_col: 0,
            spec: BSplineBasisSpec {
                degree: 3,
                penalty_order: 2,
                // Generate-mode knots: freezing should turn these into
                // `BSplineKnotSpec::Provided`.
                knotspec: BSplineKnotSpec::Generate {
                    data_range: (0.0, 1.0),
                    num_internal_knots: 3,
                },
                double_penalty: true,
                identifiability: BSplineIdentifiability::None,
            },
        },
        shape: ShapeConstraint::None,
    };
    let random_effect = RandomEffectTermSpec {
        name: "grp".to_string(),
        feature_col: 1,
        drop_first_level: false,
        // Unfrozen on input; the fast path must freeze the observed levels.
        frozen_levels: None,
    };
    let meanspec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![random_effect.clone()],
        smooth_terms: vec![pspline_term("mean_ps")],
    };
    let noisespec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![random_effect],
        smooth_terms: vec![pspline_term("noise_ps")],
    };
    let kappa_options = SpatialLengthScaleOptimizationOptions {
        enabled: true,
        max_outer_iter: 1,
        rel_tol: 1e-6,
        log_step: std::f64::consts::LN_2,
        min_length_scale: 1e-3,
        max_length_scale: 1e3,
        pilot_subsample_threshold: 0,
    };
    let joint_setup = two_block_exact_joint_hyper_setup(&meanspec, &noisespec, &kappa_options);
    let theta_dim = joint_setup.theta0().len();
    let mean_terms = spatial_length_scale_term_indices(&meanspec);
    let noise_terms = spatial_length_scale_term_indices(&noisespec);
    // Precondition for the fast path: neither block has a spatial term.
    assert!(mean_terms.is_empty());
    assert!(noise_terms.is_empty());
    let policy = crate::families::custom_family::OuterDerivativePolicy {
        capability: crate::families::custom_family::ExactOuterDerivativeOrder::Second,
        predicted_hessian_work: 0,
        predicted_gradient_work: 0,
        subsample_capable: false,
    };
    let solved = optimize_spatial_length_scale_exact_joint(
        data.view(),
        &[meanspec.clone(), noisespec.clone()],
        &[mean_terms, noise_terms],
        &kappa_options,
        &joint_setup,
        crate::seeding::SeedRiskProfile::Gaussian,
        true,
        true,
        false,
        None,
        policy,
        // Stub cost callback: verifies argument shapes, returns a dummy cost.
        |theta, specs, designs| {
            assert_eq!(theta.len(), theta_dim);
            assert_eq!(specs.len(), 2);
            assert_eq!(designs.len(), 2);
            Ok(designs[0].design.ncols() as f64 + designs[1].design.ncols() as f64)
        },
        // Stub outer-eval callback: zero cost/gradient; supplies a zero
        // analytic Hessian only when the eval mode actually requests one.
        |theta, specs, designs, eval_mode, _row_set| {
            assert_eq!(theta.len(), theta_dim);
            assert_eq!(specs.len(), 2);
            assert_eq!(designs.len(), 2);
            Ok((
                0.0,
                Array1::zeros(theta_dim),
                if matches!(
                    eval_mode,
                    crate::solver::estimate::reml::unified::EvalMode::ValueGradientHessian
                ) {
                    crate::solver::outer_strategy::HessianResult::Analytic(Array2::zeros((
                        theta_dim, theta_dim,
                    )))
                } else {
                    crate::solver::outer_strategy::HessianResult::Unavailable
                },
            ))
        },
        // Stub EFS callback: zero cost and zero steps of the full theta dim.
        |theta, specs, designs| {
            assert_eq!(theta.len(), theta_dim);
            assert_eq!(specs.len(), 2);
            assert_eq!(designs.len(), 2);
            Ok(crate::solver::outer_strategy::EfsEval {
                cost: 0.0,
                steps: vec![0.0; theta_dim],
                beta: None,
                psi_gradient: None,
                psi_indices: None,
            })
        },
    )
    .expect("exact joint no-spatial fast path should succeed");
    // Both resolved specs must validate as fully frozen: Provided knots and
    // frozen random-effect levels.
    for resolved in [&solved.resolved_specs[0], &solved.resolved_specs[1]] {
        resolved
            .validate_frozen("resolvedspec")
            .expect("exact joint no-spatial fast path should fully freeze specs");
        match &resolved.smooth_terms[0].basis {
            SmoothBasisSpec::BSpline1D { spec, .. } => {
                assert!(matches!(spec.knotspec, BSplineKnotSpec::Provided(_)));
            }
            _ => panic!("expected P-spline term"),
        }
        assert!(
            resolved.random_effect_terms[0].frozen_levels.is_some(),
            "random-effect levels should be frozen in exact joint no-spatial fast path"
        );
    }
}
// Checks that `FrozenTermCollectionIncrementalRealizer::apply_log_kappa` (an
// in-place spatial-term update) produces the same design as a full rebuild from
// the updated spec, AND that it touches only the spatial block: linear,
// random-effect, and non-spatial smooth columns must be byte-stable while the
// spatial columns must actually change.
#[test]
fn incremental_frozen_realizer_matches_unified_full_rebuild() {
    let n = 24usize;
    let mut data = Array2::<f64>::zeros((n, 4));
    for i in 0..n {
        let t = i as f64 / (n as f64 - 1.0);
        // col 0: grid; col 1: oscillation; col 2: 3-level factor; col 3: t^2.
        data[[i, 0]] = t;
        data[[i, 1]] = (0.35 * i as f64).sin();
        data[[i, 2]] = (i % 3) as f64;
        data[[i, 3]] = t * t;
    }
    // One term of every kind so each design block can be checked separately.
    let spec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "lin".to_string(),
            feature_col: 1,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: Some(-0.5),
            coefficient_max: None,
        }],
        random_effect_terms: vec![RandomEffectTermSpec {
            name: "grp".to_string(),
            feature_col: 2,
            drop_first_level: false,
            frozen_levels: None,
        }],
        smooth_terms: vec![
            // Term 0: anisotropic Matérn — the spatial term the realizer updates.
            SmoothTermSpec {
                name: "spatial".to_string(),
                basis: SmoothBasisSpec::Matern {
                    feature_cols: vec![0, 1],
                    spec: MaternBasisSpec {
                        center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                        length_scale: 0.8,
                        nu: MaternNu::FiveHalves,
                        include_intercept: false,
                        double_penalty: true,
                        identifiability: MaternIdentifiability::CenterSumToZero,
                        aniso_log_scales: Some(vec![0.15, -0.15]),
                    },
                    input_scales: None,
                },
                shape: ShapeConstraint::None,
            },
            // Term 1: shape-constrained B-spline — must stay untouched.
            SmoothTermSpec {
                name: "mono".to_string(),
                basis: SmoothBasisSpec::BSpline1D {
                    feature_col: 3,
                    spec: BSplineBasisSpec {
                        degree: 3,
                        penalty_order: 2,
                        knotspec: BSplineKnotSpec::Generate {
                            data_range: (0.0, 1.0),
                            num_internal_knots: 3,
                        },
                        double_penalty: false,
                        identifiability: BSplineIdentifiability::None,
                    },
                },
                shape: ShapeConstraint::MonotoneIncreasing,
            },
        ],
    };
    let base_design = build_term_collection_design(data.view(), &spec).expect("base design");
    let frozen = freeze_term_collection_from_design(&spec, &base_design).expect("freeze");
    let frozen_design =
        build_term_collection_design(data.view(), &frozen).expect("frozen design");
    let spatial_terms = spatial_length_scale_term_indices(&frozen);
    assert_eq!(spatial_terms, vec![0]);
    // Smooth columns sit at the tail of the full design; this is their offset.
    let smooth_start = frozen_design.design.ncols() - frozen_design.smooth.total_smooth_cols();
    // Snapshot BEFORE the incremental update, for the stability comparisons.
    let fixed_before = frozen_design.design.clone();
    let nonspatial_range = frozen_design.smooth.terms[1].coeff_range.clone();
    let full_nonspatial_range =
        (smooth_start + nonspatial_range.start)..(smooth_start + nonspatial_range.end);
    let mut realizer = FrozenTermCollectionIncrementalRealizer::new(
        data.view(),
        frozen.clone(),
        frozen_design.clone(),
    )
    .expect("incremental realizer");
    // New log-kappa (2 dims: isotropic log length-scale + anisotropy contrasts
    // folded per the term's dims; vec![2] declares 2 coords for term 0).
    let updated_log_kappa = SpatialLogKappaCoords::new_with_dims(array![0.30, -0.20], vec![2]);
    let updated_spec = updated_log_kappa
        .apply_tospec(&frozen, &spatial_terms)
        .expect("updated spec");
    // Incremental path under test.
    realizer
        .apply_log_kappa(&updated_log_kappa, &spatial_terms)
        .expect("incremental update");
    // Reference path: full rebuild from the updated spec.
    let rebuilt =
        build_term_collection_design(data.view(), &updated_spec).expect("rebuilt design");
    assert_term_collection_designs_match(realizer.design(), &rebuilt, "incremental realizer");
    let linear_range = frozen_design.linear_ranges[0].1.clone();
    let random_range = frozen_design.random_effect_ranges[0].1.clone();
    let fixed_before_dense = fixed_before.to_dense();
    let updated_full_dense = realizer.design().design.to_dense();
    // Per-block max-abs differences between pre- and post-update designs.
    let linear_diff = max_abs_diff_matrix(
        &fixed_before_dense
            .slice(s![.., linear_range.clone()])
            .to_owned(),
        &updated_full_dense.slice(s![.., linear_range]).to_owned(),
    );
    let random_diff = max_abs_diff_matrix(
        &fixed_before_dense
            .slice(s![.., random_range.clone()])
            .to_owned(),
        &updated_full_dense.slice(s![.., random_range]).to_owned(),
    );
    let nonspatial_diff = max_abs_diff_matrix(
        &fixed_before_dense
            .slice(s![.., full_nonspatial_range.clone()])
            .to_owned(),
        &updated_full_dense
            .slice(s![.., full_nonspatial_range.clone()])
            .to_owned(),
    );
    let spatial_range = frozen_design.smooth.terms[0].coeff_range.clone();
    let full_spatial_range =
        (smooth_start + spatial_range.start)..(smooth_start + spatial_range.end);
    let spatial_change = max_abs_diff_matrix(
        &fixed_before_dense
            .slice(s![.., full_spatial_range.clone()])
            .to_owned(),
        &updated_full_dense
            .slice(s![.., full_spatial_range])
            .to_owned(),
    );
    // Non-spatial blocks: numerically unchanged. Spatial block: visibly changed.
    assert!(
        linear_diff <= 1e-12,
        "linear block changed max_abs={linear_diff}"
    );
    assert!(
        random_diff <= 1e-12,
        "random-effect block changed max_abs={random_diff}"
    );
    assert!(
        nonspatial_diff <= 1e-12,
        "unchanged smooth block changed max_abs={nonspatial_diff}"
    );
    assert!(
        spatial_change > 1e-8,
        "spatial block did not update max_abs={spatial_change}"
    );
}
// Validates `ExactJointDesignCache` memoization semantics for two blocks:
// memoized cost/eval start empty at a fresh theta, a stored eval round-trips
// exactly at that theta, and moving theta (here the first log-kappa coordinate)
// invalidates the memo while the cached designs track a full rebuild.
#[test]
fn two_block_exact_joint_design_cache_clears_memo_on_theta_change() {
    let n = 20usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (0.19 * i as f64).sin();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
    }
    // Matérn term factory parameterized by name and length scale, so the two
    // blocks start from distinct spatial hyperparameters.
    let matern_term = |name: &str, length_scale: f64| SmoothTermSpec {
        name: name.to_string(),
        basis: SmoothBasisSpec::Matern {
            feature_cols: vec![0, 1],
            spec: MaternBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 5 },
                length_scale,
                nu: MaternNu::FiveHalves,
                include_intercept: false,
                double_penalty: true,
                identifiability: MaternIdentifiability::CenterSumToZero,
                aniso_log_scales: None,
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    };
    let meanspec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![matern_term("mean", 0.7)],
    };
    let noisespec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![matern_term("noise", 1.1)],
    };
    let kappa_options = SpatialLengthScaleOptimizationOptions {
        enabled: true,
        max_outer_iter: 1,
        rel_tol: 1e-6,
        log_step: std::f64::consts::LN_2,
        min_length_scale: 1e-3,
        max_length_scale: 1e3,
        pilot_subsample_threshold: 0,
    };
    let joint_setup = two_block_exact_joint_hyper_setup(&meanspec, &noisespec, &kappa_options);
    let theta0 = joint_setup.theta0();
    let mean_design = build_term_collection_design(data.view(), &meanspec).expect("mean");
    let noise_design = build_term_collection_design(data.view(), &noisespec).expect("noise");
    let mean_frozen =
        freeze_term_collection_from_design(&meanspec, &mean_design).expect("freeze mean");
    let noise_frozen =
        freeze_term_collection_from_design(&noisespec, &noise_design).expect("freeze noise");
    let mean_term_indices = spatial_length_scale_term_indices(&mean_frozen);
    let noise_term_indices = spatial_length_scale_term_indices(&noise_frozen);
    let mut cache = ExactJointDesignCache::new(
        data.view(),
        vec![
            (
                mean_frozen.clone(),
                mean_design.clone(),
                mean_term_indices.clone(),
            ),
            (
                noise_frozen.clone(),
                noise_design.clone(),
                noise_term_indices.clone(),
            ),
        ],
        joint_setup.rho_dim(),
        joint_setup.log_kappa_dims_per_term(),
    )
    .expect("n-block cache");
    // Fresh theta: designs realized, but nothing memoized yet.
    cache.ensure_theta(&theta0).expect("initial theta");
    assert!(cache.memoized_cost(&theta0).is_none());
    assert!(cache.memoized_eval(&theta0).is_none());
    // Store a synthetic (cost, gradient, Hessian) eval and read it back.
    let eval = (
        2.25,
        Array1::<f64>::ones(theta0.len()),
        crate::solver::outer_strategy::HessianResult::Analytic(Array2::<f64>::eye(
            theta0.len(),
        )),
    );
    cache.store_eval(eval.clone());
    let cached_eval = cache.memoized_eval(&theta0).expect("cached eval");
    assert!((cached_eval.0 - eval.0).abs() <= 1e-12);
    assert_eq!(cached_eval.1, eval.1);
    assert_eq!(
        cached_eval
            .2
            .materialize_dense()
            .expect("materialize cached hessian"),
        eval.2
            .materialize_dense()
            .expect("materialize eval hessian"),
    );
    // Perturb the first log-kappa coordinate (the tail of theta starts at
    // rho_dim); the memo must be cleared for the new theta.
    let mut theta1 = theta0.clone();
    theta1[joint_setup.rho_dim()] += 0.25;
    cache.ensure_theta(&theta1).expect("updated theta");
    assert!(cache.memoized_cost(&theta1).is_none());
    assert!(cache.memoized_eval(&theta1).is_none());
    // Reference designs: apply the same log-kappa tail to each frozen spec and
    // rebuild from scratch; the cache's designs must match.
    let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
        &theta1,
        joint_setup.rho_dim(),
        joint_setup.log_kappa_dims_per_term(),
    );
    let mean_terms = spatial_length_scale_term_indices(&mean_frozen);
    let noise_terms = spatial_length_scale_term_indices(&noise_frozen);
    let (mean_lk, noise_lk) = log_kappa.split_at(mean_terms.len());
    let mean_updated = mean_lk
        .apply_tospec(&mean_frozen, &mean_terms)
        .expect("mean updated spec");
    let noise_updated = noise_lk
        .apply_tospec(&noise_frozen, &noise_terms)
        .expect("noise updated spec");
    let mean_rebuilt =
        build_term_collection_design(data.view(), &mean_updated).expect("mean rebuilt");
    let noise_rebuilt =
        build_term_collection_design(data.view(), &noise_updated).expect("noise rebuilt");
    let cache_designs = cache.designs();
    assert_term_collection_designs_match(cache_designs[0], &mean_rebuilt, "mean cache");
    assert_term_collection_designs_match(cache_designs[1], &noise_rebuilt, "noise cache");
}
// Single-block analogue of the two-block cache test above: verifies that
// `SingleBlockExactJointDesignCache` starts with an empty memo, round-trips a
// stored eval at a fixed theta, clears the memo when the log-kappa coordinate
// in the theta tail moves, and keeps its design equal to a full rebuild.
#[test]
fn single_block_exact_joint_design_cache_clears_memo_on_theta_change() {
    let n = 22usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (0.23 * i as f64).cos();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                    length_scale: 0.9,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("design");
    let frozen = freeze_term_collection_from_design(&spec, &design).expect("freeze spec");
    let spatial_terms = spatial_length_scale_term_indices(&frozen);
    // theta layout: [rho (one per penalty) | log-kappa tail (1 dim here)].
    let rho_dim = design.penalties.len();
    let dims_per_term = vec![1];
    let mut theta0 = Array1::<f64>::zeros(rho_dim + 1);
    // Seed the tail with log-kappa = -ln(length_scale) of the frozen term.
    theta0[rho_dim] = -get_spatial_length_scale(&frozen, spatial_terms[0])
        .expect("length scale")
        .ln();
    let mut cache = SingleBlockExactJointDesignCache::new(
        data.view(),
        frozen.clone(),
        design.clone(),
        spatial_terms.clone(),
        rho_dim,
        dims_per_term.clone(),
    )
    .expect("single-block cache");
    // Fresh theta: nothing memoized yet.
    cache.ensure_theta(&theta0).expect("initial theta");
    assert!(cache.memoized_cost(&theta0).is_none());
    assert!(cache.memoized_eval(&theta0).is_none());
    // Store a synthetic eval and verify exact round-trip at the same theta.
    let eval = (
        0.5,
        Array1::<f64>::ones(theta0.len()),
        crate::solver::outer_strategy::HessianResult::Analytic(Array2::<f64>::eye(
            theta0.len(),
        )),
    );
    cache.store_eval(eval.clone());
    let cached_eval = cache.memoized_eval(&theta0).expect("cached eval");
    assert!((cached_eval.0 - eval.0).abs() <= 1e-12);
    assert_eq!(cached_eval.1, eval.1);
    assert_eq!(
        cached_eval
            .2
            .materialize_dense()
            .expect("materialize cached hessian"),
        eval.2
            .materialize_dense()
            .expect("materialize eval hessian"),
    );
    // Move the log-kappa coordinate: memo must be invalidated.
    let mut theta1 = theta0.clone();
    theta1[rho_dim] += 0.35;
    cache.ensure_theta(&theta1).expect("updated theta");
    assert!(cache.memoized_cost(&theta1).is_none());
    assert!(cache.memoized_eval(&theta1).is_none());
    // Reference: apply theta1's log-kappa tail to the frozen spec and rebuild;
    // the cache's realized design must match.
    let updated_log_kappa =
        SpatialLogKappaCoords::from_theta_tail_with_dims(&theta1, rho_dim, dims_per_term);
    let updated_spec = updated_log_kappa
        .apply_tospec(&frozen, &spatial_terms)
        .expect("updated spec");
    let rebuilt =
        build_term_collection_design(data.view(), &updated_spec).expect("rebuilt design");
    assert_term_collection_designs_match(cache.design(), &rebuilt, "single-block cache");
}
// Confirms that a long-lived `ExternalJointHyperEvaluator` (constructed once
// against the initial design and reused across theta updates) produces the
// same cost/gradient/Hessian and EFS results as an evaluator constructed fresh
// at each theta. This guards against stale internal state leaking across
// `ensure_theta` transitions.
#[test]
fn external_joint_evaluator_reuse_matches_fresh_state_after_theta_update() {
    let n = 26usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (0.21 * i as f64).sin();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
        // Deterministic smooth response: sinusoid in x0 plus a linear x1 term.
        y[i] = (2.0 * std::f64::consts::PI * x0).sin() + 0.35 * x1;
    }
    // One linear term + one Matérn spatial term.
    let spec = TermCollectionSpec {
        linear_terms: vec![LinearTermSpec {
            name: "x0".to_string(),
            feature_col: 0,
            double_penalty: false,
            coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
            coefficient_min: None,
            coefficient_max: None,
        }],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                    length_scale: 0.85,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let weights = Array1::ones(n);
    let offset = Array1::zeros(n);
    let fit_opts = FitOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        compute_inference: false,
        max_iter: 40,
        tol: 1e-7,
        nullspace_dims: vec![],
        linear_constraints: None,
        firth_bias_reduction: false,
        adaptive_regularization: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    let design = build_term_collection_design(data.view(), &spec).expect("design");
    let frozen = freeze_term_collection_from_design(&spec, &design).expect("freeze");
    let spatial_terms = spatial_length_scale_term_indices(&frozen);
    let dims_per_term = spatial_dims_per_term(&frozen, &spatial_terms);
    let rho_dim = design.penalties.len();
    // theta layout: [rho block | log-kappa tail]; seed rho with a mild ramp
    // and the tail with -ln(length_scale) of the frozen spatial term.
    let mut theta0 = Array1::<f64>::zeros(rho_dim + dims_per_term.iter().sum::<usize>());
    for j in 0..rho_dim {
        theta0[j] = 0.2 - 0.1 * j as f64;
    }
    theta0[rho_dim] = -get_spatial_length_scale(&frozen, spatial_terms[0])
        .expect("length scale")
        .ln();
    // Second theta differs only in the log-kappa coordinate.
    let mut theta1 = theta0.clone();
    theta1[rho_dim] += 0.3;
    let external_opts =
        external_opts_for_design(LikelihoodFamily::GaussianIdentity, &design, &fit_opts);
    let mut cache = SingleBlockExactJointDesignCache::new(
        data.view(),
        frozen,
        design.clone(),
        spatial_terms,
        rho_dim,
        dims_per_term,
    )
    .expect("single-block cache");
    // The evaluator under test: constructed ONCE against the initial design
    // and reused at every theta below.
    let mut reused = crate::estimate::ExternalJointHyperEvaluator::new(
        y.view(),
        weights.view(),
        &design.design,
        offset.view(),
        &design.penalties,
        &external_opts,
        "reused evaluator",
    )
    .expect("reused evaluator");
    // Compares the reused evaluator against a freshly constructed one at a
    // given theta, for both the outer eval (value/gradient/Hessian) and the
    // EFS eval. Tolerances are tight: reuse must be numerically equivalent.
    let compare_eval =
        |theta: &Array1<f64>,
         cache: &mut SingleBlockExactJointDesignCache<'_>,
         reused: &mut crate::estimate::ExternalJointHyperEvaluator<'_>| {
            cache.ensure_theta(theta).expect("theta applied");
            // Hyper-direction bundle for the spatial log-kappa coordinate,
            // rebuilt from the cache's current spec/design.
            let build_hyper_dirs = || {
                try_build_spatial_log_kappa_hyper_dirs(
                    data.view(),
                    cache.spec(),
                    cache.design(),
                    &cache.spatial_terms,
                )
                .expect("hyper dirs build")
                .expect("hyper dirs present")
            };
            let reused_eval = evaluate_joint_reml_outer_eval_at_theta(
                reused,
                cache.design(),
                theta,
                rho_dim,
                build_hyper_dirs(),
                None,
                crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian,
            )
            .expect("reused eval");
            // Fresh evaluator, built from the CURRENT cache design.
            let fresh_opts = external_opts_for_design(
                LikelihoodFamily::GaussianIdentity,
                cache.design(),
                &fit_opts,
            );
            let mut fresh = crate::estimate::ExternalJointHyperEvaluator::new(
                y.view(),
                weights.view(),
                &cache.design().design,
                offset.view(),
                &cache.design().penalties,
                &fresh_opts,
                "fresh evaluator",
            )
            .expect("fresh evaluator");
            let fresh_eval = evaluate_joint_reml_outer_eval_at_theta(
                &mut fresh,
                cache.design(),
                theta,
                rho_dim,
                build_hyper_dirs(),
                None,
                crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian,
            )
            .expect("fresh eval");
            // Cost, gradient, Hessian must agree to tight tolerances.
            let cost_diff = (reused_eval.0 - fresh_eval.0).abs();
            assert!(cost_diff <= 1e-10, "cost mismatch: {cost_diff}");
            let grad_diff = reused_eval
                .1
                .iter()
                .zip(fresh_eval.1.iter())
                .map(|(left, right)| (left - right).abs())
                .fold(0.0_f64, f64::max);
            assert!(grad_diff <= 1e-9, "gradient mismatch: {grad_diff}");
            let reused_hess = reused_eval
                .2
                .materialize_dense()
                .expect("reused hessian materializes")
                .expect("reused hessian present");
            let fresh_hess = fresh_eval
                .2
                .materialize_dense()
                .expect("fresh hessian materializes")
                .expect("fresh hessian present");
            let hess_diff = max_abs_diff_matrix(&reused_hess, &fresh_hess);
            assert!(hess_diff <= 1e-9, "hessian mismatch: {hess_diff}");
            // Same comparison for the EFS evaluation path.
            let reused_efs = evaluate_joint_reml_efs_at_theta(
                reused,
                cache.design(),
                theta,
                rho_dim,
                build_hyper_dirs(),
                None,
            )
            .expect("reused EFS eval");
            let mut fresh_efs_eval = crate::estimate::ExternalJointHyperEvaluator::new(
                y.view(),
                weights.view(),
                &cache.design().design,
                offset.view(),
                &cache.design().penalties,
                &fresh_opts,
                "fresh EFS evaluator",
            )
            .expect("fresh EFS evaluator");
            let fresh_efs = evaluate_joint_reml_efs_at_theta(
                &mut fresh_efs_eval,
                cache.design(),
                theta,
                rho_dim,
                build_hyper_dirs(),
                None,
            )
            .expect("fresh EFS eval");
            let efs_cost_diff = (reused_efs.cost - fresh_efs.cost).abs();
            assert!(efs_cost_diff <= 1e-10, "EFS cost mismatch: {efs_cost_diff}");
            assert_eq!(reused_efs.steps.len(), fresh_efs.steps.len());
            let efs_step_diff = reused_efs
                .steps
                .iter()
                .zip(fresh_efs.steps.iter())
                .map(|(left, right)| (left - right).abs())
                .fold(0.0_f64, f64::max);
            assert!(efs_step_diff <= 1e-9, "EFS step mismatch: {efs_step_diff}");
        };
    // First comparison warms the reused evaluator; the second verifies it
    // remains exact after the theta (design) transition.
    compare_eval(&theta0, &mut cache, &mut reused);
    compare_eval(&theta1, &mut cache, &mut reused);
}
// The exact Matérn log-kappa derivative must be built strictly from the term's
// declared `feature_cols` (here just column 0). The data matrix deliberately
// carries many extra, unrelated columns; none of them may break or leak into
// the derivative computation.
#[test]
fn exact_matern_log_kappa_derivative_uses_feature_columns_only() {
    let n = 24usize;
    let p = 17usize;
    let mut data = Array2::<f64>::zeros((n, p));
    for row in 0..n {
        // Column 0 is the only feature the term declares; the remaining
        // columns are deterministic noise that must be ignored.
        data[[row, 0]] = row as f64 / (n as f64 - 1.0);
        for col in 1..p {
            data[[row, col]] = ((row + col) as f64 * 0.13).sin();
        }
    }
    let spec = TermCollectionSpec {
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            shape: ShapeConstraint::None,
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0],
                input_scales: None,
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
                    length_scale: 0.4,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
            },
        }],
        linear_terms: vec![],
        random_effect_terms: vec![],
    };
    let design = build_term_collection_design(data.view(), &spec)
        .expect("baseline Matérn design should build");
    let frozenspec = freeze_term_collection_from_design(&spec, &design)
        .expect("freezing Matérn centers from design should succeed");
    // Freezing must pin concrete centers with the term-local width (1 column).
    let SmoothBasisSpec::Matern { spec: frozen_basis, .. } = &frozenspec.smooth_terms[0].basis
    else {
        panic!("expected Matérn term");
    };
    let CenterStrategy::UserProvided(centers) = &frozen_basis.center_strategy else {
        panic!("expected frozen user-provided centers");
    };
    assert_eq!(centers.ncols(), 1, "frozen centers should stay term-local");
    // The derivative builder must succeed on the wide data matrix and report
    // an exact derivative for the Matérn term.
    let derivative =
        try_build_spatial_term_log_kappa_derivative(data.view(), &frozenspec, &design, 0);
    assert!(
        derivative.is_ok(),
        "exact Matérn log-kappa derivative should use only feature_cols; got {derivative:?}"
    );
    let inner = derivative.expect("derivative call should succeed");
    assert!(
        inner.is_some(),
        "Matérn term should expose an exact derivative"
    );
}
// ThinPlate analogue of the Matérn feature-column test: the exact log-kappa
// first and second derivative builders must operate only on the term's two
// declared feature columns despite many extra data columns, and the derivative
// matrices must have term-local dimensions matching the coefficient range.
#[test]
fn exact_thin_plate_log_kappa_derivative_uses_feature_columns_only() {
    let n = 28usize;
    let p = 15usize;
    let mut data = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (0.17 * i as f64).sin();
        // Columns 0-1 are the spatial features; the rest are distractors.
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
        for j in 2..p {
            data[[i, j]] = ((i + 3 * j) as f64 * 0.07).cos();
        }
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "thinplate".to_string(),
            basis: SmoothBasisSpec::ThinPlate {
                feature_cols: vec![0, 1],
                spec: ThinPlateBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 7 },
                    length_scale: 0.7,
                    double_penalty: true,
                    identifiability: SpatialIdentifiability::default(),
                    radial_reparam: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec)
        .expect("baseline ThinPlate design should build");
    let frozenspec = freeze_term_collection_from_design(&spec, &design)
        .expect("freezing ThinPlate centers from design should succeed");
    // Frozen centers must be term-local: 2 columns, one per feature.
    match &frozenspec.smooth_terms[0].basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => match &spec.center_strategy {
            CenterStrategy::UserProvided(centers) => {
                assert_eq!(centers.ncols(), 2, "frozen centers should stay term-local");
            }
            _ => panic!("expected frozen user-provided centers"),
        },
        _ => panic!("expected ThinPlate term"),
    }
    let smooth_term = &design.smooth.terms[0];
    let termspec = &frozenspec.smooth_terms[0];
    // First derivative, built directly from the selected feature columns only.
    let BasisPsiDerivativeResult {
        design_derivative: local_x_psi,
        penalties_derivative: local_s_psi,
        ..
    } = match &termspec.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols, spec, ..
        } => {
            let x = select_columns(data.view(), feature_cols)
                .expect("select ThinPlate feature cols");
            crate::basis::build_thin_plate_basis_log_kappa_derivative(x.view(), spec)
                .expect("direct ThinPlate derivative should build")
        }
        _ => panic!("expected ThinPlate term"),
    };
    // Second derivative, same column selection.
    let BasisPsiSecondDerivativeResult {
        designsecond_derivative: local_x_psi_psi,
        penaltiessecond_derivative: local_s_psi_psi,
        ..
    } = match &termspec.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols, spec, ..
        } => {
            let x = select_columns(data.view(), feature_cols)
                .expect("select ThinPlate feature cols");
            crate::basis::build_thin_plate_basis_log_kappasecond_derivative(x.view(), spec)
                .expect("direct ThinPlate second derivative should build")
        }
        _ => panic!("expected ThinPlate term"),
    };
    // Design derivatives span exactly the term's coefficient columns, and every
    // penalty derivative is a square matrix of the same width.
    assert_eq!(local_x_psi.ncols(), smooth_term.coeff_range.len());
    assert_eq!(local_x_psi_psi.ncols(), smooth_term.coeff_range.len());
    assert!(!local_s_psi.is_empty());
    assert_eq!(local_s_psi.len(), local_s_psi_psi.len());
    assert!(local_s_psi.iter().all(|s| {
        s.nrows() == smooth_term.coeff_range.len() && s.ncols() == smooth_term.coeff_range.len()
    }));
    assert!(local_s_psi_psi.iter().all(|s| {
        s.nrows() == smooth_term.coeff_range.len() && s.ncols() == smooth_term.coeff_range.len()
    }));
    // The high-level entry point must likewise succeed and report an exact
    // derivative on the wide data matrix.
    let derivative =
        try_build_spatial_term_log_kappa_derivative(data.view(), &frozenspec, &design, 0);
    assert!(
        derivative.is_ok(),
        "exact ThinPlate log-kappa derivative should use only feature_cols; got {derivative:?}"
    );
    let derivative = derivative.expect("derivative call should succeed");
    assert!(
        derivative.is_some(),
        "ThinPlate term should expose an exact derivative"
    );
}
// Duchon analogue of the feature-column tests: the log-kappa derivative bundle
// must be built from the term's two feature columns only. Duchon additionally
// exposes an implicit operator on the bundle (not on the first/second
// derivative results), so widths are checked via `assert_spatial_derivative_width`
// against that shared operator.
#[test]
fn exact_duchon_log_kappa_derivative_uses_feature_columns_only() {
    let n = 28usize;
    let p = 15usize;
    let mut data = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (0.21 * i as f64).cos();
        // Columns 0-1 are the spatial features; the rest are distractors.
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
        for j in 2..p {
            data[[i, j]] = ((i + 2 * j) as f64 * 0.09).sin();
        }
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 7 },
                    length_scale: Some(0.7),
                    power: 1,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec)
        .expect("baseline Duchon design should build");
    let frozenspec = freeze_term_collection_from_design(&spec, &design)
        .expect("freezing Duchon centers from design should succeed");
    // Frozen centers must be term-local: 2 columns, one per feature.
    match &frozenspec.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => match &spec.center_strategy {
            CenterStrategy::UserProvided(centers) => {
                assert_eq!(centers.ncols(), 2, "frozen centers should stay term-local");
            }
            _ => panic!("expected frozen user-provided centers"),
        },
        _ => panic!("expected Duchon term"),
    }
    let smooth_term = &design.smooth.terms[0];
    let termspec = &frozenspec.smooth_terms[0];
    // Build the full derivative bundle directly from the selected feature cols.
    let derivative_bundle = match &termspec.basis {
        SmoothBasisSpec::Duchon {
            feature_cols, spec, ..
        } => {
            let x =
                select_columns(data.view(), feature_cols).expect("select Duchon feature cols");
            build_duchon_basis_log_kappa_derivatives(x.view(), spec)
                .expect("direct Duchon derivative bundle should build")
        }
        _ => panic!("expected Duchon term"),
    };
    // The implicit operator lives on the bundle; the per-order results must
    // NOT carry their own (asserted below).
    let local_implicit = derivative_bundle.implicit_operator;
    let BasisPsiDerivativeResult {
        design_derivative: local_x_psi,
        penalties_derivative: local_s_psi,
        implicit_operator: local_implicit_psi_unused,
    } = derivative_bundle.first;
    let BasisPsiSecondDerivativeResult {
        designsecond_derivative: local_x_psi_psi,
        penaltiessecond_derivative: local_s_psi_psi,
        implicit_operator: local_implicit_psi_psi_unused,
    } = derivative_bundle.second;
    assert!(local_implicit_psi_unused.is_none());
    assert!(local_implicit_psi_psi_unused.is_none());
    // Widths must match the term's coefficient range, accounting for the
    // implicit operator when present.
    assert_spatial_derivative_width(
        "Duchon first log-kappa",
        &local_x_psi,
        local_implicit.as_ref(),
        smooth_term.coeff_range.len(),
    );
    assert_spatial_derivative_width(
        "Duchon second log-kappa",
        &local_x_psi_psi,
        local_implicit.as_ref(),
        smooth_term.coeff_range.len(),
    );
    // Penalty derivatives: non-empty, paired across orders, and square with
    // the term-local width.
    assert!(!local_s_psi.is_empty());
    assert_eq!(local_s_psi.len(), local_s_psi_psi.len());
    assert!(local_s_psi.iter().all(|s| {
        s.nrows() == smooth_term.coeff_range.len() && s.ncols() == smooth_term.coeff_range.len()
    }));
    assert!(local_s_psi_psi.iter().all(|s| {
        s.nrows() == smooth_term.coeff_range.len() && s.ncols() == smooth_term.coeff_range.len()
    }));
    // High-level entry point must succeed and expose an exact derivative.
    let derivative =
        try_build_spatial_term_log_kappa_derivative(data.view(), &frozenspec, &design, 0);
    assert!(
        derivative.is_ok(),
        "exact Duchon log-kappa derivative should use only feature_cols; got {derivative:?}"
    );
    let derivative = derivative.expect("derivative call should succeed");
    assert!(
        derivative.is_some(),
        "Duchon term should expose an exact derivative"
    );
}
#[test]
fn spatial_length_scale_optimization_monotone_improves_or_keeps_score_for_matern() {
    // End-to-end check: running the spatial length-scale (kappa) optimizer on a
    // Matérn term must never make the fit score worse than the baseline fit,
    // and must return a fully frozen spec (explicit centers + frozen transform).
    let n = 60usize;
    let d = 2usize;
    let mut data = Array2::<f64>::zeros((n, d));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        // Deterministic synthetic surface: smooth in x0 with a sinusoidal x1 column.
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (i as f64 * 0.17).sin();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
        y[i] = (3.0 * x0).cos() + 0.35 * x1;
    }
    // Single isotropic Matérn term; the initial length scale (12.0) is far from
    // the data scale so the optimizer has room to improve (or at worst keep) the score.
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 12 },
                    length_scale: 12.0,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit_opts = FitOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        compute_inference: true,
        max_iter: 40,
        tol: 1e-6,
        nullspace_dims: vec![],
        linear_constraints: None,
        firth_bias_reduction: false,
        adaptive_regularization: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    let weights = Array1::ones(n);
    let offset = Array1::zeros(n);
    // Baseline: fit at the fixed initial length scale, no kappa optimization.
    let baseline = fit_term_collection_forspec(
        data.view(),
        y.view(),
        weights.view(),
        offset.view(),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &fit_opts,
    )
    .expect("baseline fit should succeed");
    let baseline_score = fit_score(&baseline.fit);
    // Optimized: same problem, but with the spatial length-scale outer loop enabled.
    let optimized = fit_term_collectionwith_spatial_length_scale_optimization(
        data.view(),
        y.clone(),
        weights.clone(),
        offset.clone(),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &fit_opts,
        &SpatialLengthScaleOptimizationOptions {
            enabled: true,
            max_outer_iter: 2,
            rel_tol: 1e-5,
            log_step: std::f64::consts::LN_2,
            min_length_scale: 1e-3,
            max_length_scale: 1e3,
            pilot_subsample_threshold: 0,
        },
    )
    .expect("optimized fit should succeed");
    let optimized_score = fit_score(&optimized.fit);
    // Monotonicity: the optimizer must improve or keep the score (small fp slack).
    assert!(optimized_score <= baseline_score + 1e-10);
    let ls = match &optimized.resolvedspec.smooth_terms[0].basis {
        SmoothBasisSpec::Matern { spec, .. } => spec.length_scale,
        _ => panic!("expected Matérn term"),
    };
    // The resolved length scale must be finite and stay within the optimizer bounds.
    assert!(ls.is_finite() && (1e-3..=1e3).contains(&ls));
    match &optimized.resolvedspec.smooth_terms[0].basis {
        SmoothBasisSpec::Matern { spec, .. } => {
            // After optimization the spec must be frozen: data-driven center
            // selection replaced by explicit centers, and the identifiability
            // constraint captured as a frozen transform.
            assert!(matches!(
                spec.center_strategy,
                CenterStrategy::UserProvided(_)
            ));
            assert!(matches!(
                spec.identifiability,
                MaternIdentifiability::FrozenTransform { .. }
            ));
        }
        _ => panic!("expected Matérn term"),
    }
}
#[test]
fn spatial_length_scale_optimization_runs_binomial_logit_matern_with_exact_laml_derivatives() {
    // Smoke test: the spatial kappa optimizer must run to completion on a
    // binomial-logit Matérn fit using the exact (non-TK) LAML derivative path.
    // Only successful completion is asserted, via the final `expect`.
    let n = 80usize;
    let d = 2usize;
    let mut data = Array2::<f64>::zeros((n, d));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (i as f64 * 0.19).cos();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
        let eta = -0.15 + 0.45 * x0 - 0.25 * x1 + 0.10 * (6.0 * x0).sin();
        let mu = 1.0 / (1.0 + (-eta).exp());
        // Deterministic pseudo-random uniform in (0, 1) via a fixed modular
        // sequence, so labels are reproducible without an RNG dependency.
        let u = (((i * 37 + 17) % 101) as f64 + 0.5) / 101.0;
        y[i] = if u < mu { 1.0 } else { 0.0 };
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 10 },
                    length_scale: 1.8,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit_opts = FitOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        compute_inference: true,
        max_iter: 60,
        tol: 1e-6,
        nullspace_dims: vec![],
        linear_constraints: None,
        firth_bias_reduction: false,
        adaptive_regularization: None,
        // Non-Gaussian family: a small shrinkage floor stabilizes the penalized fit.
        penalty_shrinkage_floor: Some(1e-6),
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    let weights = Array1::ones(n);
    let offset = Array1::zeros(n);
    fit_term_collectionwith_spatial_length_scale_optimization(
        data.view(),
        y,
        weights,
        offset,
        &spec,
        LikelihoodFamily::BinomialLogit,
        &fit_opts,
        &SpatialLengthScaleOptimizationOptions {
            enabled: true,
            max_outer_iter: 2,
            rel_tol: 1e-5,
            log_step: std::f64::consts::LN_2,
            min_length_scale: 1e-3,
            max_length_scale: 1e3,
            pilot_subsample_threshold: 0,
        },
    )
    .expect("standard binomial-logit spatial kappa optimization should use exact non-TK LAML derivatives");
}
#[test]
fn spatial_kappa_result_requires_exact_availability() {
    // An Ok(None) exact-optimizer outcome (no exact result available) must be
    // reported to the caller as an error mentioning the unavailability.
    let outcome = require_successful_spatial_optimization_result::<()>(0.0, Ok(None));
    let msg = outcome
        .expect_err("missing exact spatial result must be surfaced")
        .to_string();
    assert!(msg.contains("unavailable"), "unexpected error: {msg}");
}
#[test]
fn spatial_kappa_result_rejects_worse_exact_score() {
    // An exact result whose score (1.5) is worse than the baseline (1.0) must
    // be rejected, and the error must report both scores in scientific form.
    let outcome = require_successful_spatial_optimization_result(1.0, Ok(Some(((), 1.5))));
    let msg = outcome
        .expect_err("worse exact spatial result must be rejected")
        .to_string();
    for needle in ["made REML score worse", "1.000000e0", "1.500000e0"] {
        assert!(msg.contains(needle), "unexpected error: {msg}");
    }
}
#[test]
fn spatial_kappa_result_surfaces_optimizer_failure() {
    // A hard failure from the exact spatial optimizer must be propagated, and
    // the resulting error must carry both the context and the inner message.
    let inner = EstimationError::InvalidInput("boom".to_string());
    let msg = require_successful_spatial_optimization_result::<()>(0.0, Err(inner))
        .expect_err("exact spatial optimizer failure must be surfaced")
        .to_string();
    for needle in ["spatial kappa optimization failed", "boom"] {
        assert!(msg.contains(needle), "unexpected error: {msg}");
    }
}
#[test]
fn duchon_terms_participate_in_kappa_optimization() {
    // Duchon terms with a length scale must (a) be detected as spatial
    // length-scale terms, (b) expose an exact log-kappa derivative after
    // freezing, and (c) come back fully frozen from the kappa-optimized fit.
    let data = array![
        [0.0, 0.1, 0.2],
        [0.2, 0.0, 0.4],
        [0.4, 0.3, 0.1],
        [0.6, 0.5, 0.7],
        [0.8, 0.7, 0.3],
        [1.0, 0.9, 0.8],
    ];
    // Note: the term uses only columns 0 and 1 even though the data has three
    // columns, so the derivative path must honor feature_cols.
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
                    length_scale: Some(0.9),
                    power: 1,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    // (a) Term 0 participates in the length-scale optimization.
    assert_eq!(spatial_length_scale_term_indices(&spec), vec![0]);
    let fit_opts = FitOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        compute_inference: true,
        max_iter: 40,
        tol: 1e-6,
        nullspace_dims: vec![],
        linear_constraints: None,
        firth_bias_reduction: false,
        adaptive_regularization: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    let y = Array1::linspace(0.0, 1.0, data.nrows());
    let weights = Array1::ones(data.nrows());
    let offset = Array1::zeros(data.nrows());
    let design = build_term_collection_design(data.view(), &spec)
        .expect("baseline Duchon design should build");
    let frozenspec = freeze_term_collection_from_design(&spec, &design)
        .expect("freezing Duchon centers from design should succeed");
    // (b) The frozen spec must expose an exact log-kappa derivative for term 0.
    let derivative =
        try_build_spatial_term_log_kappa_derivative(data.view(), &frozenspec, &design, 0);
    assert!(
        derivative
            .expect("Duchon exact derivative call should succeed")
            .is_some(),
        "Duchon term should expose an exact derivative"
    );
    let optimized = fit_term_collectionwith_spatial_length_scale_optimization(
        data.view(),
        y,
        weights,
        offset,
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &fit_opts,
        &SpatialLengthScaleOptimizationOptions::default(),
    )
    .expect("Duchon fit should use exact κ optimization");
    let optimized_ls = match &optimized.resolvedspec.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
        _ => panic!("expected Duchon term"),
    };
    // The term started with an explicit length scale, so one must survive.
    assert!(optimized_ls.is_some());
    match &optimized.resolvedspec.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => {
            // (c) Frozen output: explicit centers and frozen identifiability.
            assert!(matches!(
                spec.center_strategy,
                CenterStrategy::UserProvided(_)
            ));
            assert!(matches!(
                spec.identifiability,
                SpatialIdentifiability::FrozenTransform { .. }
            ));
        }
        _ => panic!("expected Duchon term"),
    }
}
#[test]
fn pure_duchon_scale_dimensions_participate_without_length_scale() {
    // A pure Duchon term (length_scale: None) must still participate in the
    // kappa optimization once scale dimensions are enabled, which seeds a
    // zero anisotropy vector while leaving the length scale unset.
    let duchon_basis = SmoothBasisSpec::Duchon {
        feature_cols: vec![0, 1],
        spec: DuchonBasisSpec {
            center_strategy: CenterStrategy::FarthestPoint { num_centers: 4 },
            length_scale: None,
            power: 1,
            nullspace_order: DuchonNullspaceOrder::Linear,
            identifiability: SpatialIdentifiability::default(),
            aniso_log_scales: None,
            operator_penalties: DuchonOperatorPenaltySpec::default(),
        },
        input_scales: None,
    };
    let mut spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "pure_duchon".to_string(),
            basis: duchon_basis,
            shape: ShapeConstraint::None,
        }],
    };
    crate::term_builder::enable_scale_dimensions(&mut spec);
    assert_eq!(spatial_length_scale_term_indices(&spec), vec![0]);
    let SmoothBasisSpec::Duchon { spec, .. } = &spec.smooth_terms[0].basis else {
        panic!("expected Duchon term");
    };
    assert_eq!(spec.length_scale, None);
    assert_eq!(spec.aniso_log_scales.as_deref(), Some(&[0.0, 0.0][..]));
}
#[test]
fn pure_duchon_apply_tospec_preserves_length_scale_none() {
    // Applying a kappa coordinate update to a pure Duchon term must keep the
    // length scale as None and decode the single free coordinate into a
    // sum-to-zero anisotropy pair (+0.4, -0.4).
    let centers = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "pure_duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::UserProvided(centers),
                    length_scale: None,
                    power: 1,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: Some(vec![0.0, 0.0]),
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let coords = SpatialLogKappaCoords::new_with_dims(array![0.4], vec![1]);
    let updated = coords
        .apply_tospec(&spec, &[0])
        .expect("pure Duchon update should succeed");
    let SmoothBasisSpec::Duchon { spec, .. } = &updated.smooth_terms[0].basis else {
        panic!("expected Duchon term");
    };
    assert_eq!(spec.length_scale, None);
    let eta = spec
        .aniso_log_scales
        .as_ref()
        .expect("pure Duchon should keep aniso");
    // Decoded anisotropy stays centered: components sum to zero.
    assert!((eta.iter().sum::<f64>()).abs() <= 1e-12);
    assert!((eta[0] - 0.4).abs() <= 1e-12);
    assert!((eta[1] + 0.4).abs() <= 1e-12);
}
#[test]
fn pure_duchon_from_length_scales_aniso_centers_existing_eta() {
    // For a pure Duchon term (length_scale: None) the kappa coordinates are the
    // first d-1 components of the sum-to-zero-centered anisotropy vector.
    // The expected values are derived from `eta` here instead of hard-coded
    // 17-digit decimal literals, so the test documents the centering rule and
    // stays correct if the seed vector is ever changed.
    let eta = [0.7_f64, 0.2, 0.1];
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "pure_duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1, 2],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::UserProvided(array![
                        [0.0, 0.0, 0.0],
                        [1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                    ]),
                    length_scale: None,
                    power: 1,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::None,
                    aniso_log_scales: Some(eta.to_vec()),
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let coords = SpatialLogKappaCoords::from_length_scales_aniso(
        &spec,
        &[0],
        &SpatialLengthScaleOptimizationOptions::default(),
    );
    // Three anisotropy dimensions yield two free contrasts for the term.
    assert_eq!(coords.dims_per_term(), &[2]);
    // Centering: subtract the mean so the (implicit) full vector sums to zero.
    let mean = eta.iter().sum::<f64>() / eta.len() as f64;
    let expected = [eta[0] - mean, eta[1] - mean];
    for (got, want) in coords.as_array().iter().zip(expected.iter()) {
        assert!((got - want).abs() <= 1e-12);
    }
}
#[test]
fn from_length_scales_aniso_keeps_nonaniso_spatial_terms_scalar() {
    // Mixed term collection: an anisotropic Matérn term contributes d = 2
    // coordinates, while an isotropic one contributes a single scalar
    // log-kappa coordinate.
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![
            SmoothTermSpec {
                name: "matern_aniso".to_string(),
                basis: SmoothBasisSpec::Matern {
                    feature_cols: vec![0, 1],
                    spec: MaternBasisSpec {
                        center_strategy: CenterStrategy::UserProvided(array![
                            [0.0, 0.0],
                            [1.0, 0.0],
                            [0.0, 1.0],
                        ]),
                        length_scale: 0.5,
                        nu: MaternNu::FiveHalves,
                        include_intercept: false,
                        double_penalty: false,
                        identifiability: MaternIdentifiability::None,
                        aniso_log_scales: Some(vec![0.3, -0.3]),
                    },
                    input_scales: None,
                },
                shape: ShapeConstraint::None,
            },
            SmoothTermSpec {
                name: "matern_iso".to_string(),
                basis: SmoothBasisSpec::Matern {
                    feature_cols: vec![0, 1],
                    spec: MaternBasisSpec {
                        center_strategy: CenterStrategy::UserProvided(array![
                            [0.0, 0.0],
                            [1.0, 0.0],
                            [0.0, 1.0],
                        ]),
                        length_scale: 0.25,
                        nu: MaternNu::ThreeHalves,
                        include_intercept: false,
                        double_penalty: false,
                        identifiability: MaternIdentifiability::None,
                        aniso_log_scales: None,
                    },
                    input_scales: None,
                },
                shape: ShapeConstraint::None,
            },
        ],
    };
    let term_indices = [0usize, 1usize];
    let coords = SpatialLogKappaCoords::from_length_scales_aniso(
        &spec,
        &term_indices,
        &SpatialLengthScaleOptimizationOptions::default(),
    );
    assert_eq!(spatial_dims_per_term(&spec, &term_indices), vec![2, 1]);
    assert_eq!(coords.dims_per_term(), &[2, 1]);
    // log-kappa = -ln(length_scale). Parenthesized explicitly: a bare
    // `-0.5_f64.ln()` parses as `-(0.5_f64.ln())` because the method call binds
    // tighter than unary minus, which clippy::precedence flags as ambiguous.
    let log_kappa_aniso = -(0.5_f64.ln());
    let log_kappa_iso = -(0.25_f64.ln());
    // Anisotropic term: per-axis coordinates are log-kappa ± the eta contrast.
    let expected = [log_kappa_aniso + 0.3, log_kappa_aniso - 0.3, log_kappa_iso];
    for (got, want) in coords.as_array().iter().zip(expected.iter()) {
        assert!((got - want).abs() <= 1e-12);
    }
}
#[test]
fn aniso_bounds_clamp_preserves_in_range_global_length_scale_and_eta() {
    // Clamping anisotropic kappa coordinates to data-derived bounds must be a
    // no-op when the seed (global length scale 1.0, eta ±3.0) is already in
    // range, and the round-trip through apply_tospec must recover both exactly.
    let data = array![[0.0, 0.0], [1.0, 0.2], [0.1, 1.0], [1.1, 1.2]];
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern_aniso".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::UserProvided(array![
                        [0.0, 0.0],
                        [1.0, 0.0],
                        [0.0, 1.0],
                        [1.0, 1.0],
                    ]),
                    length_scale: 1.0,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::None,
                    // Large but in-range anisotropy contrast to stress the bounds.
                    aniso_log_scales: Some(vec![3.0, -3.0]),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let options = SpatialLengthScaleOptimizationOptions {
        enabled: true,
        max_outer_iter: 1,
        rel_tol: 1e-6,
        log_step: std::f64::consts::LN_2,
        // Bounds chosen in log space: ls in [e^-2, e^1], so ls = 1.0 is interior.
        min_length_scale: (-2.0_f64).exp(),
        max_length_scale: 1.0_f64.exp(),
        pilot_subsample_threshold: 0,
    };
    let spatial_terms = vec![0];
    let dims_per_term = spatial_dims_per_term(&spec, &spatial_terms);
    let seed = SpatialLogKappaCoords::from_length_scales_aniso(&spec, &spatial_terms, &options);
    let lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
        data.view(),
        &spec,
        &spatial_terms,
        &dims_per_term,
        &options,
    );
    let upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
        data.view(),
        &spec,
        &spatial_terms,
        &dims_per_term,
        &options,
    );
    let projected = seed.clone().clamp_to_bounds(&lower, &upper);
    // In-range seed: clamping must leave the coordinates untouched.
    assert_eq!(projected.as_array(), seed.as_array());
    let updated = projected
        .apply_tospec(&spec, &spatial_terms)
        .expect("aniso projection should decode");
    match &updated.smooth_terms[0].basis {
        SmoothBasisSpec::Matern { spec, .. } => {
            // Round-trip must preserve both the global length scale and eta.
            assert!((spec.length_scale - 1.0).abs() <= 1e-12);
            let eta = spec
                .aniso_log_scales
                .as_ref()
                .expect("anisotropy should be preserved");
            assert!((eta[0] - 3.0).abs() <= 1e-12);
            assert!((eta[1] + 3.0).abs() <= 1e-12);
        }
        _ => panic!("expected Matérn term"),
    }
}
#[test]
fn pure_duchon_aniso_fit_optimizes_without_introducing_hybrid_scale() {
    // A pure Duchon term (length_scale: None) with anisotropy enabled must be
    // optimizable end to end without the optimizer introducing a hybrid global
    // length scale: after the fit, length_scale stays None and aniso remains set.
    let data = array![
        [0.0, 0.1, 0.2],
        [0.2, 0.0, 0.4],
        [0.4, 0.3, 0.1],
        [0.6, 0.5, 0.7],
        [0.8, 0.7, 0.3],
        [1.0, 0.9, 0.8],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "pure_duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1, 2],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 5 },
                    length_scale: None,
                    power: 1,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    // Neutral anisotropy seed: one log scale per feature column.
                    aniso_log_scales: Some(vec![0.0, 0.0, 0.0]),
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit_opts = FitOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        compute_inference: true,
        max_iter: 40,
        tol: 1e-6,
        nullspace_dims: vec![],
        linear_constraints: None,
        firth_bias_reduction: false,
        adaptive_regularization: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    let optimized = fit_term_collectionwith_spatial_length_scale_optimization(
        data.view(),
        Array1::linspace(0.0, 1.0, data.nrows()),
        Array1::ones(data.nrows()),
        Array1::zeros(data.nrows()),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &fit_opts,
        &SpatialLengthScaleOptimizationOptions::default(),
    )
    .expect("pure Duchon anisotropic fit should optimize");
    match &optimized.resolvedspec.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => {
            // No hybrid scale: the optimizer must not materialize a length scale.
            assert_eq!(spec.length_scale, None);
            assert!(
                spec.aniso_log_scales.is_some(),
                "pure Duchon anisotropy should remain enabled"
            );
        }
        _ => panic!("expected Duchon term"),
    }
}
#[test]
fn spatial_anisotropy_pilot_initializer_seeds_geometry_without_fit() {
    // The pilot initializer inspects only the data geometry (no model fit) and
    // must seed a nonzero, sum-to-zero anisotropy contrast plus a positive
    // finite length scale when the axes have very different spreads.
    // Column 0 spans [0, 1]; column 1 spans a much narrower [0, 0.21] band.
    let data = Array2::from_shape_fn((32, 2), |(i, j)| {
        if j == 0 {
            i as f64 / 31.0
        } else {
            ((i % 8) as f64) * 0.03
        }
    });
    let mut spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "pc_matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::UserProvided(array![
                        [0.0, 0.0],
                        [1.0, 0.0],
                        [0.0, 0.05],
                        [1.0, 0.05],
                    ]),
                    length_scale: 1.0,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::None,
                    // Neutral seed; the initializer should overwrite this in place.
                    aniso_log_scales: Some(vec![0.0, 0.0]),
                },
                input_scales: Some(vec![1.0, 1.0]),
            },
            shape: ShapeConstraint::None,
        }],
    };
    let spatial_terms = spatial_length_scale_term_indices(&spec);
    let updated = apply_spatial_anisotropy_pilot_initializer(
        data.view(),
        &mut spec,
        &spatial_terms,
        8,
        &SpatialLengthScaleOptimizationOptions::default(),
    );
    // Exactly the single Matérn term should have been re-seeded.
    assert_eq!(updated, 1);
    match &spec.smooth_terms[0].basis {
        SmoothBasisSpec::Matern { spec, .. } => {
            let eta = spec
                .aniso_log_scales
                .as_ref()
                .expect("pilot initializer should preserve anisotropy");
            assert_eq!(eta.len(), 2);
            // Contrast stays centered (sum-to-zero) ...
            assert!((eta[0] + eta[1]).abs() <= 1e-12);
            // ... but must be nontrivial given the unequal axis spreads.
            assert!(
                eta.iter().any(|value| value.abs() > 1e-6),
                "pilot geometry should seed nonzero axis contrast"
            );
            assert!(spec.length_scale.is_finite() && spec.length_scale > 0.0);
        }
        _ => panic!("expected Matern term"),
    }
}
#[test]
fn pure_duchon_aniso_penalties_stay_symmetric_through_freeze_and_cache() {
    // Penalty matrices for anisotropic pure Duchon terms must remain symmetric
    // (to 1e-10) at every stage: initial build, after freezing/rebuilding, and
    // after a theta update through the single-block exact-joint design cache —
    // including after the stable reparameterization.
    // Largest |S[i,j] - S[j,i]| over the strict lower triangle (0 for symmetric S).
    fn max_asymmetry(matrix: &Array2<f64>) -> f64 {
        let n = matrix.nrows().min(matrix.ncols());
        let mut max_asym = 0.0_f64;
        for i in 0..n {
            for j in 0..i {
                max_asym = max_asym.max((matrix[[i, j]] - matrix[[j, i]]).abs());
            }
        }
        max_asym
    }
    // Asserts each local penalty block in the design is symmetric.
    fn assert_design_penalties_symmetric(label: &str, design: &TermCollectionDesign) {
        for (penalty_idx, penalty) in design.penalties.iter().enumerate() {
            let max_asym = max_asymmetry(&penalty.local);
            assert!(
                max_asym <= 1e-10,
                "{label} penalty {penalty_idx} asymmetry too large: {max_asym:.3e}"
            );
        }
    }
    // Pushes the design's penalties through canonicalization and the stable
    // reparameterization (all lambdas = 1) and checks the transformed total
    // penalty is still symmetric.
    fn assert_reparam_penalty_symmetric(label: &str, design: &TermCollectionDesign) {
        let p_total = design.design.ncols();
        let penalty_specs = design
            .penalties
            .iter()
            .map(|penalty| crate::estimate::PenaltySpec::Dense(penalty.to_global(p_total)))
            .collect::<Vec<_>>();
        let (canonical_penalties, _) = crate::construction::canonicalize_penalty_specs(
            &penalty_specs,
            &design.nullspace_dims,
            p_total,
            label,
        )
        .expect("canonicalize penalties");
        let invariant = crate::construction::precompute_reparam_invariant_from_canonical(
            &canonical_penalties,
            p_total,
        )
        .expect("reparam invariant");
        let lambdas = vec![1.0; canonical_penalties.len()];
        let reparam = crate::construction::stable_reparameterizationwith_invariant(
            &canonical_penalties,
            &lambdas,
            p_total,
            &invariant,
            None,
        )
        .expect("stable reparameterization");
        let max_asym = max_asymmetry(&reparam.s_transformed);
        assert!(
            max_asym <= 1e-10,
            "{label} transformed penalty asymmetry too large: {max_asym:.3e}"
        );
    }
    let data = array![
        [0.0, 0.1, 0.2],
        [0.2, 0.0, 0.4],
        [0.4, 0.3, 0.1],
        [0.6, 0.5, 0.7],
        [0.8, 0.7, 0.3],
        [1.0, 0.9, 0.8],
    ];
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "pure_duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0, 1, 2],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 5 },
                    length_scale: None,
                    power: 1,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: Some(vec![0.0, 0.0, 0.0]),
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    // Stage 1: initial build.
    let base_design = build_term_collection_design(data.view(), &spec).expect("base design");
    assert_design_penalties_symmetric("base", &base_design);
    assert_reparam_penalty_symmetric("base", &base_design);
    // Stage 2: freeze then rebuild from the frozen spec.
    let frozen = freeze_term_collection_from_design(&spec, &base_design).expect("freeze spec");
    let frozen_design =
        build_term_collection_design(data.view(), &frozen).expect("frozen rebuild");
    assert_design_penalties_symmetric("frozen", &frozen_design);
    assert_reparam_penalty_symmetric("frozen", &frozen_design);
    // Stage 3: exact-joint cache with a nontrivial theta tail (the two aniso
    // contrast coordinates appended after the rho block).
    let spatial_terms = spatial_length_scale_term_indices(&frozen);
    let rho_dim = frozen_design.penalties.len();
    let dims_per_term = vec![2];
    let mut theta = Array1::<f64>::zeros(rho_dim + 2);
    theta[rho_dim] = 0.2;
    theta[rho_dim + 1] = -0.2;
    let mut cache = SingleBlockExactJointDesignCache::new(
        data.view(),
        frozen.clone(),
        frozen_design.clone(),
        spatial_terms,
        rho_dim,
        dims_per_term,
    )
    .expect("single-block cache");
    cache.ensure_theta(&theta).expect("updated theta");
    assert_design_penalties_symmetric("cache", cache.design());
    assert_reparam_penalty_symmetric("cache", cache.design());
}
#[test]
fn single_block_no_spatial_fast_path_returns_fully_frozen_spec() {
    // When the model has no spatial terms, the single-block fast path of the
    // kappa-optimization entry point must still return a fully frozen spec:
    // B-spline knots materialized and random-effect levels captured.
    let n = 48usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let t = i as f64 / (n as f64 - 1.0);
        data[[i, 0]] = t;
        // Column 1 is a 4-level grouping factor for the random effect.
        data[[i, 1]] = (i % 4) as f64;
        y[i] = 0.5 + 1.5 * t;
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![RandomEffectTermSpec {
            name: "grp".to_string(),
            feature_col: 1,
            drop_first_level: false,
            frozen_levels: None,
        }],
        smooth_terms: vec![SmoothTermSpec {
            name: "ps".to_string(),
            basis: SmoothBasisSpec::BSpline1D {
                feature_col: 0,
                spec: BSplineBasisSpec {
                    degree: 3,
                    penalty_order: 2,
                    // Data-driven knots: the fast path must freeze these into
                    // an explicit Provided knot vector.
                    knotspec: BSplineKnotSpec::Generate {
                        data_range: (0.0, 1.0),
                        num_internal_knots: 4,
                    },
                    double_penalty: true,
                    identifiability: BSplineIdentifiability::None,
                },
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit_opts = FitOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        compute_inference: true,
        max_iter: 40,
        tol: 1e-6,
        nullspace_dims: vec![],
        linear_constraints: None,
        firth_bias_reduction: false,
        adaptive_regularization: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    let fitted = fit_term_collectionwith_spatial_length_scale_optimization(
        data.view(),
        y,
        Array1::ones(n),
        Array1::zeros(n),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &fit_opts,
        &SpatialLengthScaleOptimizationOptions::default(),
    )
    .expect("single-block no-spatial fit should succeed");
    // Whole-spec freeze check first, then the individual freeze artifacts.
    fitted
        .resolvedspec
        .validate_frozen("resolvedspec")
        .expect("single-block no-spatial fast path should fully freeze specs");
    match &fitted.resolvedspec.smooth_terms[0].basis {
        SmoothBasisSpec::BSpline1D { spec, .. } => {
            assert!(matches!(spec.knotspec, BSplineKnotSpec::Provided(_)));
        }
        _ => panic!("expected P-spline term"),
    }
    assert!(
        fitted.resolvedspec.random_effect_terms[0]
            .frozen_levels
            .is_some(),
        "random-effect levels should be frozen in single-block no-spatial fast path"
    );
}
#[test]
fn exact_joint_two_block_spatial_length_scale_freezes_duchon_centers() {
    // Two-block (mean + noise) exact-joint spatial optimization with stubbed
    // objective/gradient/EFS callbacks: the optimizer must run and must hand
    // back frozen Duchon specs (explicit centers + frozen identifiability)
    // for both blocks. The callbacks only validate the shapes they receive.
    let n = 40usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (i as f64 * 0.19).cos();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
    }
    // Both blocks use the same Duchon term shape, differing only in name and
    // initial length scale.
    let duchon_term = |name: &str, length_scale: f64| SmoothTermSpec {
        name: name.to_string(),
        basis: SmoothBasisSpec::Duchon {
            feature_cols: vec![0, 1],
            spec: DuchonBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 8 },
                length_scale: Some(length_scale),
                power: 3,
                nullspace_order: DuchonNullspaceOrder::Linear,
                identifiability: SpatialIdentifiability::default(),
                aniso_log_scales: None,
                operator_penalties: DuchonOperatorPenaltySpec::default(),
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    };
    let meanspec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![duchon_term("mean_duchon", 0.8)],
    };
    let noisespec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![duchon_term("noise_duchon", 1.1)],
    };
    let kappa_options = SpatialLengthScaleOptimizationOptions {
        enabled: true,
        max_outer_iter: 1,
        rel_tol: 1e-6,
        log_step: std::f64::consts::LN_2,
        min_length_scale: 1e-3,
        max_length_scale: 1e3,
        pilot_subsample_threshold: 0,
    };
    let joint_setup = two_block_exact_joint_hyper_setup(&meanspec, &noisespec, &kappa_options);
    let theta_dim = joint_setup.theta0().len();
    let mean_terms = spatial_length_scale_term_indices(&meanspec);
    let noise_terms = spatial_length_scale_term_indices(&noisespec);
    // Advertise exact second-order outer derivatives so the exact path is taken.
    let policy = crate::families::custom_family::OuterDerivativePolicy {
        capability: crate::families::custom_family::ExactOuterDerivativeOrder::Second,
        predicted_hessian_work: 0,
        predicted_gradient_work: 0,
        subsample_capable: false,
    };
    let solved = optimize_spatial_length_scale_exact_joint(
        data.view(),
        &[meanspec.clone(), noisespec.clone()],
        &[mean_terms, noise_terms],
        &kappa_options,
        &joint_setup,
        crate::seeding::SeedRiskProfile::Gaussian,
        true,
        true,
        false,
        None,
        policy,
        // Stub objective: a deterministic function of the rebuilt design shapes,
        // so the optimizer has something finite to compare across theta values.
        |theta, specs, designs| {
            assert_eq!(theta.len(), theta_dim);
            assert_eq!(specs.len(), 2);
            Ok(designs[0].design.ncols() as f64
                + designs[1].design.ncols() as f64
                + designs[0].penalties.len() as f64
                + designs[1].penalties.len() as f64)
        },
        // Stub gradient/Hessian: zeros; an analytic Hessian is only produced
        // when the optimizer asks for value+gradient+Hessian.
        |theta, specs, designs, eval_mode, _row_set| {
            assert_eq!(theta.len(), theta_dim);
            assert_eq!(specs.len(), 2);
            assert!(!designs.is_empty());
            Ok((
                0.0,
                Array1::zeros(theta_dim),
                if matches!(
                    eval_mode,
                    crate::solver::estimate::reml::unified::EvalMode::ValueGradientHessian
                ) {
                    crate::solver::outer_strategy::HessianResult::Analytic(Array2::zeros((
                        theta_dim, theta_dim,
                    )))
                } else {
                    crate::solver::outer_strategy::HessianResult::Unavailable
                },
            ))
        },
        // Stub EFS evaluation: zero cost and zero steps.
        |theta, specs, designs| {
            assert_eq!(theta.len(), theta_dim);
            assert_eq!(specs.len(), 2);
            assert!(!designs.is_empty());
            Ok(crate::solver::outer_strategy::EfsEval {
                cost: 0.0,
                steps: vec![0.0; theta_dim],
                beta: None,
                psi_gradient: None,
                psi_indices: None,
            })
        },
    )
    .expect("exact joint two-block spatial length-scale optimization should succeed");
    // Both resolved blocks must come back frozen.
    for resolved in [&solved.resolved_specs[0], &solved.resolved_specs[1]] {
        match &resolved.smooth_terms[0].basis {
            SmoothBasisSpec::Duchon { spec, .. } => {
                assert!(matches!(
                    spec.center_strategy,
                    CenterStrategy::UserProvided(_)
                ));
                assert!(matches!(
                    spec.identifiability,
                    SpatialIdentifiability::FrozenTransform { .. }
                ));
            }
            _ => panic!("expected Duchon term"),
        }
    }
}
#[test]
fn joint_build_and_cache_rebuild_frozen_pure_duchon_blocks() {
    // Round-trip consistency for two pure-Duchon blocks (mean + noise):
    // (1) joint build + freeze, (2) rebuilding from the frozen specs matches a
    // direct per-block rebuild, (3) the ExactJointDesignCache rebuild at an
    // updated theta matches manually applying the same log-kappa update.
    let n = 72usize;
    let d = 5usize;
    let mut data = Array2::<f64>::zeros((n, d));
    for i in 0..n {
        let t = i as f64 / (n as f64 - 1.0);
        data[[i, 0]] = t;
        data[[i, 1]] = (0.17 * i as f64).sin();
        data[[i, 2]] = (0.11 * i as f64).cos();
        data[[i, 3]] = ((i % 7) as f64) / 6.0;
        data[[i, 4]] = t * (1.0 - t);
    }
    let pure_duchon_term = |name: &str| SmoothTermSpec {
        name: name.to_string(),
        basis: SmoothBasisSpec::Duchon {
            feature_cols: (0..d).collect(),
            spec: DuchonBasisSpec {
                center_strategy: CenterStrategy::FarthestPoint { num_centers: 24 },
                length_scale: None,
                power: 2,
                nullspace_order: DuchonNullspaceOrder::Linear,
                identifiability: SpatialIdentifiability::default(),
                aniso_log_scales: Some(vec![0.0; d]),
                operator_penalties: DuchonOperatorPenaltySpec::default(),
            },
            input_scales: None,
        },
        shape: ShapeConstraint::None,
    };
    let meanspec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![pure_duchon_term("mean_pure_duchon")],
    };
    let noisespec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![pure_duchon_term("noise_pure_duchon")],
    };
    let (boot_designs, frozen_specs) = build_term_collection_designs_and_freeze_joint(
        data.view(),
        &[meanspec.clone(), noisespec.clone()],
    )
    .expect("initial joint pure Duchon build");
    assert_eq!(boot_designs.len(), 2);
    assert_eq!(frozen_specs.len(), 2);
    // 24 centers yield 23 coefficients per term — presumably one dimension is
    // removed by the identifiability constraint; confirm against the builder.
    assert_eq!(boot_designs[0].smooth.terms[0].coeff_range.len(), 23);
    assert_eq!(boot_designs[1].smooth.terms[0].coeff_range.len(), 23);
    let (rebuilt_designs, refrozen_specs) =
        build_term_collection_designs_and_freeze_joint(data.view(), &frozen_specs)
            .expect("rebuilding frozen joint pure Duchon specs should succeed");
    assert_eq!(rebuilt_designs.len(), 2);
    assert_eq!(refrozen_specs.len(), 2);
    for idx in 0..2 {
        // Joint rebuild of a frozen block must equal a direct per-block rebuild.
        let direct = build_term_collection_design(data.view(), &frozen_specs[idx])
            .expect("direct frozen pure Duchon rebuild");
        assert_term_collection_designs_match(
            &rebuilt_designs[idx],
            &direct,
            if idx == 0 {
                "mean pure Duchon frozen rebuild"
            } else {
                "noise pure Duchon frozen rebuild"
            },
        );
        assert_eq!(rebuilt_designs[idx].smooth.terms[0].coeff_range.len(), 23);
        match &refrozen_specs[idx].smooth_terms[0].basis {
            SmoothBasisSpec::Duchon { spec, .. } => {
                assert!(matches!(
                    spec.identifiability,
                    SpatialIdentifiability::FrozenTransform { .. }
                ));
            }
            _ => panic!("expected Duchon term"),
        }
    }
    let kappa_options = SpatialLengthScaleOptimizationOptions {
        enabled: true,
        max_outer_iter: 1,
        rel_tol: 1e-6,
        log_step: std::f64::consts::LN_2,
        min_length_scale: 1e-3,
        max_length_scale: 1e3,
        pilot_subsample_threshold: 0,
    };
    let joint_setup =
        two_block_exact_joint_hyper_setup(&frozen_specs[0], &frozen_specs[1], &kappa_options);
    assert!(joint_setup.log_kappa_dim() > 0);
    let mean_term_indices = spatial_length_scale_term_indices(&frozen_specs[0]);
    let noise_term_indices = spatial_length_scale_term_indices(&frozen_specs[1]);
    let mut cache = ExactJointDesignCache::new(
        data.view(),
        vec![
            (
                frozen_specs[0].clone(),
                rebuilt_designs[0].clone(),
                mean_term_indices.clone(),
            ),
            (
                frozen_specs[1].clone(),
                rebuilt_designs[1].clone(),
                noise_term_indices.clone(),
            ),
        ],
        joint_setup.rho_dim(),
        joint_setup.log_kappa_dims_per_term(),
    )
    .expect("pure Duchon exact-joint cache");
    // Perturb the first two psi (log-kappa) coordinates past the rho block.
    let mut theta1 = joint_setup.theta0();
    let psi_start = joint_setup.rho_dim();
    theta1[psi_start] += 0.25;
    theta1[psi_start + 1] -= 0.15;
    cache
        .ensure_theta(&theta1)
        .expect("pure Duchon cache theta update");
    // Manually decode the same theta tail and apply it to each frozen spec ...
    let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
        &theta1,
        joint_setup.rho_dim(),
        joint_setup.log_kappa_dims_per_term(),
    );
    let (mean_lk, noise_lk) = log_kappa.split_at(mean_term_indices.len());
    let mean_updated = mean_lk
        .apply_tospec(&frozen_specs[0], &mean_term_indices)
        .expect("updated mean pure Duchon spec");
    let noise_updated = noise_lk
        .apply_tospec(&frozen_specs[1], &noise_term_indices)
        .expect("updated noise pure Duchon spec");
    let mean_rebuilt =
        build_term_collection_design(data.view(), &mean_updated).expect("mean rebuilt");
    let noise_rebuilt =
        build_term_collection_design(data.view(), &noise_updated).expect("noise rebuilt");
    // ... and the cache's rebuilt designs must match the manual rebuilds.
    let cache_designs = cache.designs();
    assert_term_collection_designs_match(
        cache_designs[0],
        &mean_rebuilt,
        "mean pure Duchon cache",
    );
    assert_term_collection_designs_match(
        cache_designs[1],
        &noise_rebuilt,
        "noise pure Duchon cache",
    );
}
#[test]
fn bounded_linear_gaussian_fit_respects_interval() {
    // A linear coefficient with Bounded geometry must stay inside its [0, 0.5]
    // interval even though the data-generating slope (0.8) lies outside it,
    // and the estimate should still move into the positive interior.
    let n = 64usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let x = -1.0 + 2.0 * (i as f64) / ((n - 1) as f64);
        let z = (i as f64) / ((n - 1) as f64);
        data[[i, 0]] = x;
        data[[i, 1]] = z;
        // True slope on x is 0.8 — deliberately above the 0.5 upper bound.
        y[i] = 0.25 + 0.8 * x + 0.05 * z;
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![
            LinearTermSpec {
                name: "x".to_string(),
                feature_col: 0,
                double_penalty: false,
                // Bounded coefficient with a Beta(2, 2) prior over the interval.
                coefficient_geometry: LinearCoefficientGeometry::Bounded {
                    min: 0.0,
                    max: 0.5,
                    prior: BoundedCoefficientPriorSpec::Beta { a: 2.0, b: 2.0 },
                },
                coefficient_min: None,
                coefficient_max: None,
            },
            LinearTermSpec {
                name: "z".to_string(),
                feature_col: 1,
                double_penalty: false,
                coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
                coefficient_min: None,
                coefficient_max: None,
            },
        ],
        random_effect_terms: vec![],
        smooth_terms: vec![],
    };
    // Spatial optimization is irrelevant here; it is explicitly disabled.
    let fitted = fit_term_collectionwith_spatial_length_scale_optimization(
        data.view(),
        y,
        Array1::ones(n),
        Array1::zeros(n),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 40,
            tol: 1e-6,
            nullspace_dims: vec![],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: None,
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
        &SpatialLengthScaleOptimizationOptions {
            enabled: false,
            ..SpatialLengthScaleOptimizationOptions::default()
        },
    )
    .expect("bounded gaussian fit");
    // Index of the bounded coefficient in the fitted beta vector.
    let bounded_idx = fitted.design.linear_ranges[0].1.start;
    let estimate = fitted.fit.beta[bounded_idx];
    assert!(
        (0.0..=0.5).contains(&estimate),
        "bounded coefficient escaped interval: {estimate}"
    );
    assert!(
        estimate > 0.1,
        "bounded coefficient should move into the positive interior, got {estimate}"
    );
}
#[test]
fn term_collection_design_emits_linear_coefficient_constraints() {
    // A linear term carrying coefficient_min / coefficient_max must surface a
    // pair of inequality rows in the built design's linear_constraints.
    let data = array![[0.0], [1.0], [2.0], [3.0]];
    let term = LinearTermSpec {
        name: "x".to_string(),
        feature_col: 0,
        double_penalty: false,
        coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
        coefficient_min: Some(0.0),
        coefficient_max: Some(1.0),
    };
    let spec = TermCollectionSpec {
        linear_terms: vec![term],
        random_effect_terms: vec![],
        smooth_terms: vec![],
    };
    let built = build_term_collection_design(data.view(), &spec).expect("design");
    let col = built.linear_ranges[0].1.start;
    let cons = built.linear_constraints.expect("constraints");
    // Exactly two constraint rows spanning the full design width.
    assert_eq!(cons.a.nrows(), 2);
    assert_eq!(cons.a.ncols(), built.design.ncols());
    // Row 0 encodes the lower bound (beta >= 0), row 1 the negated upper
    // bound (-beta >= -1).
    assert_eq!(cons.a[[0, col]], 1.0);
    assert_eq!(cons.b[0], 0.0);
    assert_eq!(cons.a[[1, col]], -1.0);
    assert_eq!(cons.b[1], -1.0);
}
#[test]
fn linear_termspec_defaults_to_penalizedwhen_field_is_omitted() {
    // Deserializing a minimal JSON spec must default `double_penalty` to true
    // and the coefficient geometry to Unconstrained.
    let parsed: LinearTermSpec = serde_json::from_str(r#"{"name":"x","feature_col":0}"#)
        .expect("deserialize linear term");
    assert!(parsed.double_penalty);
    let geometry_is_free = matches!(
        parsed.coefficient_geometry,
        LinearCoefficientGeometry::Unconstrained
    );
    assert!(geometry_is_free);
}
#[test]
fn linear_double_penalties_share_one_globalridge_block() {
    // Two linear terms with `double_penalty: true` should not each receive a
    // separate ridge penalty: the builder is expected to pool them into one
    // shared "linear" penalty block of effective rank 2.
    let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
    let spec = TermCollectionSpec {
        linear_terms: vec![
            LinearTermSpec {
                name: "x1".to_string(),
                feature_col: 0,
                double_penalty: true,
                coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
                coefficient_min: None,
                coefficient_max: None,
            },
            LinearTermSpec {
                name: "x2".to_string(),
                feature_col: 1,
                double_penalty: true,
                coefficient_geometry: LinearCoefficientGeometry::Unconstrained,
                coefficient_min: None,
                coefficient_max: None,
            },
        ],
        random_effect_terms: vec![],
        smooth_terms: vec![],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("design");
    // Exactly one pooled penalty block, labelled "linear", covering both
    // coefficients (effective rank 2).
    assert_eq!(design.penalties.len(), 1);
    assert_eq!(design.penaltyinfo.len(), 1);
    assert_eq!(design.penaltyinfo[0].termname.as_deref(), Some("linear"));
    assert_eq!(design.penaltyinfo[0].penalty.effective_rank, 2);
    let x1 = design.linear_ranges[0].1.start;
    let x2 = design.linear_ranges[1].1.start;
    let bp = &design.penalties[0];
    // The local penalty matrix is indexed relative to the block's column
    // range; both coefficients should get a unit diagonal (identity ridge).
    let x1_local = x1 - bp.col_range.start;
    let x2_local = x2 - bp.col_range.start;
    assert_eq!(bp.local[[x1_local, x1_local]], 1.0);
    assert_eq!(bp.local[[x2_local, x2_local]], 1.0);
}
#[test]
fn bounded_uniform_prior_matches_beta_one_one_terms() {
    // A Uniform prior is exactly Beta(1, 1): all four returned objective
    // terms must agree to machine precision.
    let theta = 0.7;
    let uniform = bounded_prior_terms(theta, &BoundedCoefficientPriorSpec::Uniform);
    let beta11 =
        bounded_prior_terms(theta, &BoundedCoefficientPriorSpec::Beta { a: 1.0, b: 1.0 });
    let pairs = [
        (uniform.0, beta11.0),
        (uniform.1, beta11.1),
        (uniform.2, beta11.2),
        (uniform.3, beta11.3),
    ];
    for (lhs, rhs) in pairs {
        assert!((lhs - rhs).abs() < 1e-12);
    }
}
#[test]
fn boundednone_prior_has_no_extra_latentobjective_terms() {
    // Prior `None` must contribute nothing to the latent objective, while the
    // Uniform prior contributes finite, non-trivial terms at the same theta.
    let theta = 0.7;
    assert_eq!(
        bounded_prior_terms(theta, &BoundedCoefficientPriorSpec::None),
        (0.0, 0.0, 0.0, 0.0)
    );
    let uniform = bounded_prior_terms(theta, &BoundedCoefficientPriorSpec::Uniform);
    let (value, gradient, curvature, extra) = uniform;
    assert!(value.is_finite());
    assert!(value < 0.0);
    assert!(gradient.abs() > 1e-6);
    assert!(curvature > 0.0);
    assert!(extra.is_finite());
}
#[test]
fn exact_bounded_edf_matches_trace_formula_for_simple_penalty() {
    // Scalar sanity check: with a 1x1 identity penalty, lambda = 0.25 and
    // covariance 2.0, both the per-block and total EDF should equal 0.5.
    let penalties = vec![PenaltySpec::Dense(Array2::eye(1))];
    let lambdas = array![0.25];
    let cov = array![[2.0]];
    let (per_block, total) =
        exact_bounded_edf(&penalties, &lambdas, &cov).expect("exact bounded edf");
    assert_eq!(per_block.len(), 1);
    for value in [per_block[0], total] {
        assert!((value - 0.5).abs() < 1e-12);
    }
}
#[test]
fn bounded_joint_hessian_directional_derivative_matches_finite_difference() {
    // Checks the analytic directional derivative of the exact-Newton joint
    // Hessian for a bounded linear family against a central finite
    // difference along `direction`.
    let x = array![[0.2, -1.0], [0.8, 0.5], [1.1, 1.2], [1.7, -0.3]];
    let y = array![0.4, 1.0, 1.7, 2.2];
    let weights = Array1::ones(y.len());
    let family = BoundedLinearFamily {
        family: LikelihoodFamily::GaussianIdentity,
        latent_cloglog_state: None,
        mixture_link_state: None,
        sas_link_state: None,
        y: y.clone(),
        weights: weights.clone(),
        design: x.clone(),
        // Copy of the design with the bounded column zeroed out.
        designzeroed: {
            let mut dz = x.clone();
            dz.column_mut(0).fill(0.0);
            dz
        },
        offset: Array1::zeros(y.len()),
        bounded_terms: vec![BoundedLinearTermMeta {
            col_idx: 0,
            min: 0.0,
            max: 1.0,
            prior: BoundedCoefficientPriorSpec::Uniform,
        }],
    };
    let state = vec![ParameterBlockState {
        beta: array![0.4, -0.2],
        eta: Array1::zeros(y.len()),
    }];
    let direction = array![0.3, -0.4];
    let analytic = family
        .exact_newton_joint_hessian_directional_derivative(&state, &direction)
        .expect("analytic derivative")
        .expect("joint derivative");
    // Central finite difference of the joint Hessian along `direction`.
    let h = 1e-6;
    let plus_state = vec![ParameterBlockState {
        beta: &state[0].beta + &(direction.clone() * h),
        eta: Array1::zeros(y.len()),
    }];
    let minus_state = vec![ParameterBlockState {
        beta: &state[0].beta - &(direction.clone() * h),
        eta: Array1::zeros(y.len()),
    }];
    let plus = family
        .exact_newton_joint_hessian(&plus_state)
        .expect("plus hessian")
        .expect("plus exact hessian");
    let minus = family
        .exact_newton_joint_hessian(&minus_state)
        .expect("minus hessian")
        .expect("minus exact hessian");
    let fd = (plus - minus) / (2.0 * h);
    for i in 0..analytic.nrows() {
        for j in 0..analytic.ncols() {
            // Only compare signs where the finite-difference value is clearly
            // above the magnitude tolerance: near zero the FD estimate is
            // dominated by rounding noise and f64::signum maps -0.0 to -1.0,
            // so a strict sign comparison could fail spuriously even though
            // the magnitude check below passes.
            if fd[[i, j]].abs() > 1e-5 {
                assert_eq!(
                    analytic[[i, j]].signum(),
                    fd[[i, j]].signum(),
                    "directional derivative sign mismatch at ({i},{j}): analytic={}, fd={}",
                    analytic[[i, j]],
                    fd[[i, j]]
                );
            }
            assert!(
                (analytic[[i, j]] - fd[[i, j]]).abs() < 1e-5,
                "directional derivative mismatch at ({i},{j}): analytic={}, fd={}",
                analytic[[i, j]],
                fd[[i, j]]
            );
        }
    }
}
#[test]
fn adaptive_initial_epsilons_use_mean_fallbackwhen_median_is_tiny() {
    // All operator rows are ~1e-10, so a median-based epsilon seed is
    // essentially zero; per the test name the routine should fall back to a
    // mean-based seed, and in any case every epsilon must respect the
    // requested 1e-8 floor.
    let cache = SpatialOperatorRuntimeCache {
        termname: "matern".to_string(),
        feature_cols: vec![0, 1],
        coeff_global_range: 0..2,
        mass_penalty_global_idx: 0,
        tension_penalty_global_idx: 1,
        stiffness_penalty_global_idx: 2,
        // Tiny D0 / D1 / D2 operator matrices (magnitude, gradient,
        // second-derivative channels) for two collocation points.
        d0: array![[5e-10, 0.0], [6e-10, 0.0]],
        d1: array![[1e-10, 0.0], [0.0, 1e-10], [2e-10, 0.0], [0.0, 2e-10]],
        d2: array![
            [3e-10, 0.0],
            [0.0, 0.0],
            [0.0, 0.0],
            [3e-10, 0.0],
            [4e-10, 0.0],
            [0.0, 0.0],
            [0.0, 0.0],
            [4e-10, 0.0],
        ],
        collocation_points: array![[0.0, 0.0], [1.0, 1.0]],
        dimension: 2,
    };
    let beta = array![1.0, 1.0];
    let (eps_0, eps_g, eps_c) =
        compute_initial_epsilons(&beta, &[cache], 1e-8).expect("initial epsilons");
    // Magnitude / gradient / curvature epsilons all at or above the floor.
    assert!(eps_0 >= 1e-8);
    assert!(eps_g >= 1e-8);
    assert!(eps_c >= 1e-8);
}
#[test]
fn adaptive_exact_psigradient_symmetrizes_nearly_symmetrichessian() {
    // Builds a toy adaptive-exact family whose fixed quadratic Hessian has
    // unequal off-diagonal entries (0.1 vs 3.0); the exact joint psi-term
    // evaluation is expected to symmetrize it rather than fail or produce a
    // non-finite objective term.
    let family = SpatialAdaptiveExactFamily {
        family: LikelihoodFamily::GaussianIdentity,
        latent_cloglog_state: None,
        mixture_link_state: None,
        sas_link_state: None,
        y: Arc::new(array![0.0, 0.0]),
        weights: Arc::new(array![1.0, 1.0]),
        design: Arc::new(array![[1.0, 0.0], [0.0, 1.0]]),
        offset: Arc::new(array![0.0, 0.0]),
        linear_constraints: None,
        // Minimal one-term runtime cache with identity operator matrices.
        runtime_caches: Arc::new(vec![SpatialOperatorRuntimeCache {
            termname: "toy".to_string(),
            feature_cols: vec![0],
            coeff_global_range: 0..2,
            mass_penalty_global_idx: 0,
            tension_penalty_global_idx: 1,
            stiffness_penalty_global_idx: 2,
            d0: array![[1.0, 0.0], [0.0, 1.0]],
            d1: array![[1.0, 0.0], [0.0, 1.0]],
            d2: array![[1.0, 0.0], [0.0, 1.0]],
            collocation_points: array![[0.0], [1.0]],
            dimension: 1,
        }]),
        adaptive_params: vec![SpatialAdaptiveTermHyperParams {
            lambda: [1.0, 1.0, 1.0],
            epsilon: [1.0, 1.0, 1.0],
        }],
        // Deliberately non-symmetric off-diagonal entries.
        fixed_quadratichessian: Arc::new(array![[0.0, 0.1], [3.0, 0.0]]),
        hyperspecs: Arc::new(build_spatial_adaptive_hyperspecs(1)),
        exact_eval_cache: Arc::new(Mutex::new(None)),
    };
    let spec = ParameterBlockSpec {
        name: "toy".to_string(),
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![
            [1.0, 0.0],
            [0.0, 1.0]
        ])),
        offset: array![0.0, 0.0],
        penalties: vec![],
        nullspace_dims: vec![],
        initial_log_lambdas: Array1::zeros(0),
        initial_beta: Some(array![0.0, 0.0]),
    };
    // Placeholder psi-derivative block: all explicit sensitivities zero.
    let deriv = CustomFamilyBlockPsiDerivative {
        penalty_index: None,
        x_psi: Array2::zeros((2, 2)),
        s_psi: Array2::zeros((2, 2)),
        s_psi_components: None,
        s_psi_penalty_components: None,
        x_psi_psi: None,
        s_psi_psi: None,
        s_psi_psi_components: None,
        s_psi_psi_penalty_components: None,
        implicit_operator: None,
        implicit_axis: 0,
        implicit_group_id: None,
    };
    let state = vec![ParameterBlockState {
        beta: array![0.0, 0.0],
        eta: array![0.0, 0.0],
    }];
    let gradient = family
        .exact_newton_joint_psi_terms(&state, std::slice::from_ref(&spec), &[vec![deriv]], 0)
        .expect("adaptive joint psi terms should tolerate nearly symmetric Hessian")
        .expect("adaptive joint psi terms should be present")
        .objective_psi;
    assert!(
        gradient.is_finite(),
        "expected finite adaptive joint psi objective term after symmetrization, got {gradient}"
    );
}
#[test]
fn adaptiveweighted_operator_grams_are_symmetric_and_psd() {
    // The weighted Gram matrices built from stacked first- and
    // second-derivative operator rows must be symmetric, and their quadratic
    // forms non-negative (PSD spot-check along one direction).
    let d1 = array![
        [1.0, 0.0, 2.0],
        [0.5, 1.0, 0.0],
        [0.0, 1.0, 1.0],
        [1.5, 0.0, 0.5],
    ];
    let d2 = array![
        [1.0, 0.0, 1.0],
        [0.0, 1.0, 2.0],
        [2.0, 0.0, 0.5],
        [0.0, 0.5, 1.0],
        [0.5, 0.0, 1.5],
        [1.0, 1.0, 0.0],
        [0.0, 2.0, 1.0],
        [1.5, 0.0, 0.0],
    ];
    let weight = array![2.0, 3.0];
    let gram1 = weighted_operator_gram_from_d1(&d1, &weight, 2);
    let gram2 = weighted_operator_gram_from_d2(&d2, &weight);
    for gram in [&gram1, &gram2] {
        for row in 0..gram.nrows() {
            for col in 0..gram.ncols() {
                assert!((gram[[row, col]] - gram[[col, row]]).abs() < 1e-10);
            }
        }
    }
    let v = array![0.2, -0.3, 0.5];
    for gram in [&gram1, &gram2] {
        let quad = v.dot(&gram.dot(&v));
        assert!(quad >= -1e-10);
    }
}
#[test]
fn adaptiveweight_clamp_is_applied_in_u_space() {
    // With all-zero operator matrices the raw adaptive weights degenerate;
    // with a ceiling of 1e2 every weight channel must be clamped to exactly
    // 1e2, and the stored inverse must be the reciprocal of the *clamped*
    // value (1e-2) -- i.e. the clamp happens before inversion (per the test
    // name, in u-space).
    let cache = SpatialOperatorRuntimeCache {
        termname: "matern".to_string(),
        feature_cols: vec![0, 1],
        coeff_global_range: 0..2,
        mass_penalty_global_idx: 0,
        tension_penalty_global_idx: 1,
        stiffness_penalty_global_idx: 2,
        // Degenerate (all-zero) operator rows for a single collocation point.
        d0: array![[0.0, 0.0]],
        d1: array![[0.0, 0.0], [0.0, 0.0]],
        d2: array![[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
        collocation_points: array![[0.0, 0.0]],
        dimension: 2,
    };
    let beta = array![0.0, 0.0];
    let out =
        compute_spatial_adaptiveweights_for_beta(&beta, &[cache], 1e-8, 1e-8, 1e-8, 1e-8, 1e2)
            .expect("adaptive weights");
    assert_eq!(out.len(), 1);
    // All three channels clamp to the ceiling ...
    assert!((out[0].magweight[0] - 1e2).abs() < 1e-12);
    assert!((out[0].gradweight[0] - 1e2).abs() < 1e-12);
    assert!((out[0].lapweight[0] - 1e2).abs() < 1e-12);
    // ... and the inverses are reciprocals of the clamped values.
    assert!((out[0].inv_magweight[0] - 1e-2).abs() < 1e-12);
    assert!((out[0].invgradweight[0] - 1e-2).abs() < 1e-12);
    assert!((out[0].inv_lapweight[0] - 1e-2).abs() < 1e-12);
}
#[test]
fn adaptiveweight_inverse_consistencywithout_clamp() {
    // With floor/ceiling wide enough never to bind (1e-12 .. 1e12), each
    // weight channel and its stored inverse must be exact reciprocals at
    // every collocation point.
    let cache = SpatialOperatorRuntimeCache {
        termname: "matern".to_string(),
        feature_cols: vec![0, 1],
        coeff_global_range: 0..2,
        mass_penalty_global_idx: 0,
        tension_penalty_global_idx: 1,
        stiffness_penalty_global_idx: 2,
        d0: array![[1.0, 0.0], [2.0, 0.0]],
        d1: array![[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 2.0]],
        d2: array![
            [1.0, 0.0],
            [0.0, 0.0],
            [0.0, 0.0],
            [0.0, 0.0],
            [2.0, 0.0],
            [0.0, 0.0],
            [0.0, 0.0],
            [0.0, 0.0],
        ],
        collocation_points: array![[0.0, 0.0], [1.0, 1.0]],
        dimension: 2,
    };
    let beta = array![1.0, 1.0];
    let out = compute_spatial_adaptiveweights_for_beta(
        &beta,
        &[cache],
        1e-6,
        1e-6,
        1e-6,
        1e-12,
        1e12,
    )
    .expect("adaptive weights");
    assert_eq!(out.len(), 1);
    // weight * inverse == 1 for the magnitude, gradient, and Laplacian
    // channels at every collocation point.
    for k in 0..out[0].gradweight.len() {
        let p0 = out[0].magweight[k] * out[0].inv_magweight[k];
        let pg = out[0].gradweight[k] * out[0].invgradweight[k];
        let pc = out[0].lapweight[k] * out[0].inv_lapweight[k];
        assert!((p0 - 1.0).abs() < 1e-10, "mag pair mismatch at {k}: {p0}");
        assert!((pg - 1.0).abs() < 1e-10, "grad pair mismatch at {k}: {pg}");
        assert!((pc - 1.0).abs() < 1e-10, "lap pair mismatch at {k}: {pc}");
    }
}
#[test]
fn adaptiveweight_is_monotone_in_signal_magnitude() {
    // A larger coefficient vector produces a larger local signal, which
    // should yield *smaller* adaptive weights (less smoothing where the
    // signal is strong) and correspondingly larger inverse weights.
    let cache = SpatialOperatorRuntimeCache {
        termname: "matern".to_string(),
        feature_cols: vec![0, 1],
        coeff_global_range: 0..2,
        mass_penalty_global_idx: 0,
        tension_penalty_global_idx: 1,
        stiffness_penalty_global_idx: 2,
        d0: array![[1.0, 0.0]],
        d1: array![[1.0, 0.0], [0.0, 1.0]],
        d2: array![[1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
        collocation_points: array![[0.0, 0.0]],
        dimension: 2,
    };
    let beta_small = array![0.25, 0.25];
    let beta_large = array![2.0, 2.0];
    let small = compute_spatial_adaptiveweights_for_beta(
        &beta_small,
        std::slice::from_ref(&cache),
        1e-8,
        1e-8,
        1e-8,
        1e-12,
        1e12,
    )
    .expect("small adaptive weights");
    let large = compute_spatial_adaptiveweights_for_beta(
        &beta_large,
        &[cache],
        1e-8,
        1e-8,
        1e-8,
        1e-12,
        1e12,
    )
    .expect("large adaptive weights");
    // Weights strictly decrease with signal magnitude ...
    assert!(small[0].magweight[0] > large[0].magweight[0]);
    assert!(small[0].gradweight[0] > large[0].gradweight[0]);
    assert!(small[0].lapweight[0] > large[0].lapweight[0]);
    // ... and their inverses strictly increase.
    assert!(small[0].inv_magweight[0] < large[0].inv_magweight[0]);
    assert!(small[0].invgradweight[0] < large[0].invgradweight[0]);
    assert!(small[0].inv_lapweight[0] < large[0].inv_lapweight[0]);
}
#[test]
fn exact_spatial_adaptive_regularization_fit_runswithout_mm() {
    // End-to-end adaptive-regularization fit on a 2-D Matern smooth. The
    // exact path is expected to bypass MM (majorize-minimize) iterations
    // entirely: the diagnostics should report zero MM iterations while still
    // producing finite epsilons, one adaptive map, and a finite fit.
    let n = 48usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (0.19 * i as f64).sin();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
        y[i] = (3.0 * x0).sin() + 0.25 * x1;
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 10 },
                    length_scale: 0.7,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit = fit_term_collection_forspec(
        data.view(),
        y.view(),
        Array1::ones(n).view(),
        Array1::zeros(n).view(),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 25,
            tol: 1e-5,
            nullspace_dims: vec![],
            linear_constraints: None,
            firth_bias_reduction: false,
            // Adaptive regularization enabled with small iteration budgets;
            // the exact path should not need MM anyway.
            adaptive_regularization: Some(AdaptiveRegularizationOptions {
                enabled: true,
                max_mm_iter: 4,
                beta_rel_tol: 1e-4,
                max_epsilon_outer_iter: 2,
                epsilon_log_step: std::f64::consts::LN_2,
                min_epsilon: 1e-6,
                weight_floor: 1e-8,
                weight_ceiling: 1e8,
            }),
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("exact adaptive spatial fit should succeed");
    let diag = fit
        .adaptive_diagnostics
        .as_ref()
        .expect("adaptive diagnostics should be present");
    // The exact path must not perform any MM iterations.
    assert_eq!(diag.mm_iterations, 0);
    assert!(diag.epsilon_0.is_finite() && diag.epsilon_0 > 0.0);
    assert!(diag.epsilon_g.is_finite() && diag.epsilon_g > 0.0);
    assert!(diag.epsilon_c.is_finite() && diag.epsilon_c > 0.0);
    assert_eq!(diag.maps.len(), 1);
    assert!(fit.fit.beta.iter().all(|v| v.is_finite()));
    assert!(fit.fit.reml_score.is_finite());
}
#[test]
fn exact_spatial_adaptive_binomial_sas_fit_preserves_link_state() {
    // Adaptive-regularization fit under the BinomialSas family with a fixed
    // (optimize_sas: false) SAS link. After the fit, the fitted link must
    // still be the SAS variant with finite epsilon / log_delta state;
    // `covariance` is expected to be None here (presumably because the SAS
    // parameters were not optimized -- see the matching assert).
    let n = 36usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let x0 = -1.0 + 2.0 * (i as f64 / (n as f64 - 1.0));
        let x1 = (0.23 * i as f64).sin();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
        let eta = 0.55 * x0 - 0.2 * x1 + 0.1 * x0 * x1;
        let p = 1.0 / (1.0 + (-eta).exp());
        // Deterministic pseudo-random threshold in [0, 1) so the test is
        // reproducible without an RNG dependency.
        let u = ((i * 37 + 13) % 100) as f64 / 100.0;
        y[i] = if u < p { 1.0 } else { 0.0 };
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 10 },
                    length_scale: 0.7,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit = fit_term_collection_forspec(
        data.view(),
        y.view(),
        Array1::ones(n).view(),
        Array1::zeros(n).view(),
        &spec,
        LikelihoodFamily::BinomialSas,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            // Fixed SAS link parameters; not optimized during the fit.
            sas_link: Some(crate::types::SasLinkSpec {
                initial_epsilon: 0.1,
                initial_log_delta: -0.2,
            }),
            optimize_sas: false,
            compute_inference: true,
            max_iter: 15,
            tol: 1e-5,
            nullspace_dims: vec![],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: Some(AdaptiveRegularizationOptions {
                enabled: true,
                max_mm_iter: 4,
                beta_rel_tol: 1e-4,
                max_epsilon_outer_iter: 2,
                epsilon_log_step: std::f64::consts::LN_2,
                min_epsilon: 1e-6,
                weight_floor: 1e-8,
                weight_ceiling: 1e8,
            }),
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("exact adaptive SAS fit should succeed");
    match fit.fit.fitted_link {
        FittedLinkState::Sas { state, covariance } => {
            assert!(state.epsilon.is_finite());
            assert!(state.log_delta.is_finite());
            assert!(state.delta.is_finite() && state.delta > 0.0);
            assert!(covariance.is_none());
        }
        other => panic!("expected SAS link parameters, got {other:?}"),
    }
}
#[test]
fn exact_spatial_adaptive_joint_hypergradient_matches_finite_difference() {
    // Validates the analytic gradient and Hessian of the joint
    // (lambda, epsilon) hyperparameter objective for the exact spatial
    // adaptive family against central finite differences, around a baseline
    // (non-adaptive) Matern fit.
    let n = 36usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let x0 = i as f64 / (n as f64 - 1.0);
        let x1 = (0.31 * i as f64).sin();
        data[[i, 0]] = x0;
        data[[i, 1]] = x1;
        y[i] = (4.0 * x0).sin() + 0.35 * x1 + 0.2 * ((x0 - 0.55) * 18.0).tanh();
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 8 },
                    length_scale: 0.6,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    // Baseline fit without adaptive regularization; supplies the design,
    // beta and lambdas the hyper evaluation is linearized around.
    let baseline = fit_term_collection_forspec(
        data.view(),
        y.view(),
        Array1::ones(n).view(),
        Array1::zeros(n).view(),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 30,
            tol: 1e-6,
            nullspace_dims: vec![],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: None,
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("baseline fit");
    let runtime_caches = extract_spatial_operator_runtime_caches(&spec, &baseline.design)
        .expect("runtime caches");
    assert_eq!(runtime_caches.len(), 1);
    let adaptive_opts = AdaptiveRegularizationOptions::default();
    let (eps_0_init, eps_g_init, eps_c_init) = compute_initial_epsilons(
        &baseline.fit.beta,
        &runtime_caches,
        adaptive_opts.min_epsilon,
    )
    .expect("initial epsilons");
    let hyperspecs = build_spatial_adaptive_hyperspecs(runtime_caches.len());
    let zero_psi_op: std::sync::Arc<
        dyn crate::custom_family::CustomFamilyPsiDerivativeOperator,
    > = std::sync::Arc::new(crate::custom_family::ZeroPsiDerivativeOperator::new(
        baseline.design.design.nrows(),
        baseline.design.design.ncols(),
    ));
    // One placeholder psi-derivative block per hyperparameter; all explicit
    // derivative matrices are empty and the shared zero operator supplies no
    // explicit psi sensitivity. (The original `.enumerate()` here discarded
    // both the index and the item -- a plain `.map(|_| ...)` is equivalent.)
    let derivative_blocks = vec![
        hyperspecs
            .iter()
            .map(|_| CustomFamilyBlockPsiDerivative {
                penalty_index: None,
                x_psi: Array2::<f64>::zeros((0, 0)),
                s_psi: Array2::<f64>::zeros((0, 0)),
                s_psi_components: None,
                s_psi_penalty_components: None,
                x_psi_psi: None,
                s_psi_psi: None,
                s_psi_psi_components: None,
                s_psi_psi_penalty_components: None,
                implicit_operator: Some(std::sync::Arc::clone(&zero_psi_op)),
                implicit_axis: 0,
                implicit_group_id: None,
            })
            .collect::<Vec<_>>(),
    ];
    let base_family = SpatialAdaptiveExactFamily {
        family: LikelihoodFamily::GaussianIdentity,
        latent_cloglog_state: None,
        mixture_link_state: None,
        sas_link_state: None,
        y: Arc::new(y.clone()),
        weights: Arc::new(Array1::ones(n)),
        design: baseline.design.design.to_dense_arc(),
        offset: Arc::new(Array1::zeros(n)),
        linear_constraints: baseline.design.linear_constraints.clone(),
        runtime_caches: Arc::new(runtime_caches.clone()),
        adaptive_params: Vec::new(),
        fixed_quadratichessian: Arc::new(Array2::<f64>::zeros((
            baseline.design.design.ncols(),
            baseline.design.design.ncols(),
        ))),
        hyperspecs: Arc::new(hyperspecs),
        exact_eval_cache: Arc::new(Mutex::new(None)),
    };
    let blockspec = ParameterBlockSpec {
        name: "eta".to_string(),
        design: baseline.design.design.clone(),
        offset: Array1::zeros(n),
        penalties: vec![],
        nullspace_dims: vec![],
        initial_log_lambdas: Array1::zeros(0),
        initial_beta: Some(baseline.fit.beta.clone()),
    };
    let outer_opts = BlockwiseFitOptions {
        inner_max_cycles: 30,
        inner_tol: 1e-6,
        outer_max_iter: 30,
        outer_tol: 1e-6,
        compute_covariance: false,
        ..BlockwiseFitOptions::default()
    };
    // theta = [log lambda_0, log lambda_g, log lambda_c,
    //          log eps_0,    log eps_g,    log eps_c].
    let evaluate_theta = |theta: &Array1<f64>, need_hessian: bool| {
        let family = base_family.with_adaptive_params(
            vec![SpatialAdaptiveTermHyperParams {
                lambda: [theta[0].exp(), theta[1].exp(), theta[2].exp()],
                epsilon: [theta[3].exp(), theta[4].exp(), theta[5].exp()],
            }],
            Arc::new(Array2::<f64>::zeros((
                baseline.design.design.ncols(),
                baseline.design.design.ncols(),
            ))),
        );
        evaluate_custom_family_joint_hyper(
            &family,
            std::slice::from_ref(&blockspec),
            &outer_opts,
            &Array1::zeros(0),
            &derivative_blocks,
            None,
            if need_hessian {
                crate::solver::estimate::reml::unified::EvalMode::ValueGradientHessian
            } else {
                crate::solver::estimate::reml::unified::EvalMode::ValueAndGradient
            },
        )
        .expect("joint hyper eval")
    };
    // Evaluation point: baseline lambdas and initial epsilons (floored at
    // 1e-6 before taking logs).
    let theta = array![
        baseline.fit.lambdas[runtime_caches[0].mass_penalty_global_idx]
            .max(1e-6)
            .ln(),
        baseline.fit.lambdas[runtime_caches[0].tension_penalty_global_idx]
            .max(1e-6)
            .ln(),
        baseline.fit.lambdas[runtime_caches[0].stiffness_penalty_global_idx]
            .max(1e-6)
            .ln(),
        eps_0_init.max(1e-6).ln(),
        eps_g_init.max(1e-6).ln(),
        eps_c_init.max(1e-6).ln(),
    ];
    let analytic = evaluate_theta(&theta, true);
    assert_eq!(analytic.gradient.len(), theta.len());
    assert!(
        analytic.outer_hessian.is_analytic(),
        "adaptive joint hyper evaluation must expose exact Hessian curvature"
    );
    assert_eq!(
        analytic.outer_hessian.dim(),
        Some(theta.len()),
        "adaptive joint hyper Hessian must span all lambda/epsilon coordinates"
    );
    let analytic_hessian = analytic
        .outer_hessian
        .clone()
        .materialize_dense()
        .expect("adaptive joint hyper Hessian should materialize")
        .expect("adaptive joint hyper Hessian should be present");
    // Central finite differences of the objective (for the gradient) and of
    // the gradient (for the Hessian columns).
    let h = 1e-5;
    for j in 0..theta.len() {
        let mut plus = theta.clone();
        plus[j] += h;
        let mut minus = theta.clone();
        minus[j] -= h;
        let fd = (evaluate_theta(&plus, false).objective
            - evaluate_theta(&minus, false).objective)
            / (2.0 * h);
        assert!(
            (analytic.gradient[j] - fd).abs() < 5e-3 * (1.0 + fd.abs()),
            "adaptive joint hypergradient mismatch at {j}: analytic={}, fd={fd}",
            analytic.gradient[j]
        );
        let grad_fd = (evaluate_theta(&plus, false).gradient
            - evaluate_theta(&minus, false).gradient)
            / (2.0 * h);
        for i in 0..theta.len() {
            assert!(
                (analytic_hessian[[i, j]] - grad_fd[i]).abs() < 5e-2 * (1.0 + grad_fd[i].abs()),
                "adaptive joint hyper-Hessian mismatch at ({i},{j}): analytic={}, fd={}",
                analytic_hessian[[i, j]],
                grad_fd[i]
            );
        }
    }
}
#[test]
fn exact_spatial_adaptive_1dobjective_profile_has_finite_gradient_lambda_surface() {
    // Profiles the joint hyper objective over the gradient-channel lambda
    // only (mass and stiffness lambdas pinned near zero). The objective and
    // gradient must stay finite across several decades of lambda_g, and the
    // profile must not be flat (identifiability).
    let n = 96usize;
    let mut data = Array2::<f64>::zeros((n, 1));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let x = i as f64 / (n as f64 - 1.0);
        data[[i, 0]] = x;
        // Smooth oscillation plus a steep sigmoid step at x = 0.5.
        y[i] = 0.12 * (2.0 * std::f64::consts::PI * x).sin()
            + 0.05 * (5.0 * std::f64::consts::PI * x).cos()
            + 1.4 / (1.0 + (-(x - 0.5) / 0.012).exp());
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 31 },
                    length_scale: Some(1.0),
                    power: 2,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    // Non-adaptive baseline fit supplying the design and beta.
    let baseline = fit_term_collection_forspec(
        data.view(),
        y.view(),
        Array1::ones(n).view(),
        Array1::zeros(n).view(),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 20,
            tol: 1e-6,
            nullspace_dims: vec![],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: None,
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("baseline fit");
    let runtime_caches = extract_spatial_operator_runtime_caches(&spec, &baseline.design)
        .expect("runtime caches");
    assert_eq!(runtime_caches.len(), 1);
    let (eps_0, eps_g, eps_c) =
        compute_initial_epsilons(&baseline.fit.beta, &runtime_caches, 1e-8)
            .expect("initial epsilons");
    let hyperspecs = build_spatial_adaptive_hyperspecs(runtime_caches.len());
    let zero_psi_op: std::sync::Arc<
        dyn crate::custom_family::CustomFamilyPsiDerivativeOperator,
    > = std::sync::Arc::new(crate::custom_family::ZeroPsiDerivativeOperator::new(
        baseline.design.design.nrows(),
        baseline.design.design.ncols(),
    ));
    // One placeholder psi-derivative block per hyperparameter; all explicit
    // derivative matrices are empty and the shared zero operator supplies no
    // explicit psi sensitivity. (The original `.enumerate()` here discarded
    // both the index and the item -- a plain `.map(|_| ...)` is equivalent.)
    let derivative_blocks = vec![
        hyperspecs
            .iter()
            .map(|_| CustomFamilyBlockPsiDerivative {
                penalty_index: None,
                x_psi: Array2::<f64>::zeros((0, 0)),
                s_psi: Array2::<f64>::zeros((0, 0)),
                s_psi_components: None,
                s_psi_penalty_components: None,
                x_psi_psi: None,
                s_psi_psi: None,
                s_psi_psi_components: None,
                s_psi_psi_penalty_components: None,
                implicit_operator: Some(std::sync::Arc::clone(&zero_psi_op)),
                implicit_axis: 0,
                implicit_group_id: None,
            })
            .collect::<Vec<_>>(),
    ];
    let base_family = SpatialAdaptiveExactFamily {
        family: LikelihoodFamily::GaussianIdentity,
        latent_cloglog_state: None,
        mixture_link_state: None,
        sas_link_state: None,
        y: Arc::new(y.clone()),
        weights: Arc::new(Array1::ones(n)),
        design: baseline.design.design.to_dense_arc(),
        offset: Arc::new(Array1::zeros(n)),
        linear_constraints: baseline.design.linear_constraints.clone(),
        runtime_caches: Arc::new(runtime_caches.clone()),
        adaptive_params: Vec::new(),
        fixed_quadratichessian: Arc::new(Array2::<f64>::zeros((
            baseline.design.design.ncols(),
            baseline.design.design.ncols(),
        ))),
        hyperspecs: Arc::new(hyperspecs),
        exact_eval_cache: Arc::new(Mutex::new(None)),
    };
    let blockspec = ParameterBlockSpec {
        name: "eta".to_string(),
        design: baseline.design.design.clone(),
        offset: Array1::zeros(n),
        penalties: vec![],
        nullspace_dims: vec![],
        initial_log_lambdas: Array1::zeros(0),
        initial_beta: Some(baseline.fit.beta.clone()),
    };
    let outer_opts = BlockwiseFitOptions {
        inner_max_cycles: 20,
        inner_tol: 1e-6,
        outer_max_iter: 20,
        outer_tol: 1e-6,
        compute_covariance: false,
        ..BlockwiseFitOptions::default()
    };
    // Only the gradient-channel lambda varies; mass/stiffness lambdas are
    // pinned at 1e-12 and the epsilons are held at their initial values.
    let evaluate_theta = |log_lambda_g: f64| {
        let family = base_family.with_adaptive_params(
            vec![SpatialAdaptiveTermHyperParams {
                lambda: [1e-12, log_lambda_g.exp(), 1e-12],
                epsilon: [eps_0, eps_g, eps_c],
            }],
            Arc::new(Array2::<f64>::zeros((
                baseline.design.design.ncols(),
                baseline.design.design.ncols(),
            ))),
        );
        evaluate_custom_family_joint_hyper(
            &family,
            std::slice::from_ref(&blockspec),
            &outer_opts,
            &Array1::zeros(0),
            &derivative_blocks,
            None,
            crate::solver::estimate::reml::unified::EvalMode::ValueAndGradient,
        )
        .expect("joint hyper eval")
    };
    let low = evaluate_theta((1e-8_f64).ln());
    let mid = evaluate_theta((1e-4_f64).ln());
    let high = evaluate_theta((1e-2_f64).ln());
    for (label, eval) in [("low", &low), ("mid", &mid), ("high", &high)] {
        assert!(
            eval.objective.is_finite(),
            "{label} gradient-lambda profile objective is not finite: {}",
            eval.objective
        );
        assert!(
            eval.gradient.iter().all(|v| v.is_finite()),
            "{label} gradient-lambda profile gradient contains non-finite entries: {:?}",
            eval.gradient
        );
    }
    // The profile must actually depend on lambda_g.
    assert!(
        (low.objective - high.objective).abs() > 1e-8,
        "gradient-lambda profile should remain identifiable: low={}, high={}",
        low.objective,
        high.objective
    );
}
#[test]
fn exact_spatial_adaptive_high_center_duchon_fit_no_longer_fails_in_outer_solver() {
    // Regression test: an adaptive-regularization Duchon fit with a large
    // center count (120 centers on 320 observations) previously failed in
    // the outer solver (per the test name). It must now complete with finite
    // coefficients, deviance, EDF, and epsilon diagnostics.
    let n = 320usize;
    let mut data = Array2::<f64>::zeros((n, 1));
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let x = i as f64 / (n as f64 - 1.0);
        data[[i, 0]] = x;
        // Smooth oscillation plus a steep sigmoid step at x = 0.5.
        y[i] = 0.12 * (2.0 * std::f64::consts::PI * x).sin()
            + 0.05 * (5.0 * std::f64::consts::PI * x).cos()
            + 1.4 / (1.0 + (-(x - 0.5) / 0.012).exp());
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0],
                spec: DuchonBasisSpec {
                    // Deliberately high center count to stress the solver.
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 120 },
                    length_scale: Some(1.0),
                    power: 2,
                    nullspace_order: DuchonNullspaceOrder::Linear,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let fit = fit_term_collection_forspec(
        data.view(),
        y.view(),
        Array1::ones(n).view(),
        Array1::zeros(n).view(),
        &spec,
        LikelihoodFamily::GaussianIdentity,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 40,
            tol: 1e-6,
            nullspace_dims: vec![],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: Some(AdaptiveRegularizationOptions {
                enabled: true,
                max_mm_iter: 10,
                beta_rel_tol: 1e-3,
                max_epsilon_outer_iter: 4,
                epsilon_log_step: std::f64::consts::LN_2,
                min_epsilon: 1e-8,
                weight_floor: 1e-8,
                weight_ceiling: 1e8,
            }),
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("high-center adaptive Duchon fit should not fail");
    assert!(fit.fit.beta.iter().all(|v| v.is_finite()));
    assert!(fit.fit.deviance.is_finite());
    assert!(fit.fit.edf_total().is_some_and(f64::is_finite));
    let diag = fit
        .adaptive_diagnostics
        .as_ref()
        .expect("adaptive diagnostics should be present");
    assert!(diag.epsilon_0.is_finite() && diag.epsilon_0 > 0.0);
    assert!(diag.epsilon_g.is_finite() && diag.epsilon_g > 0.0);
    assert!(diag.epsilon_c.is_finite() && diag.epsilon_c > 0.0);
}
#[test]
fn binomial_logit_tail_curvature_uses_stable_exact_formula() {
    // Deep in the logit tails the negative Hessian (and its eta-derivative)
    // must match the exact inverse-link jet values to near machine precision,
    // without overflow or underflow.
    let eta = array![30.0, 30.0, -30.0, -30.0, 40.0, -40.0];
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
    let weights = Array1::ones(eta.len());
    let obs = evaluate_standard_familyobservations(
        LikelihoodFamily::BinomialLogit,
        None,
        None,
        None,
        &y,
        &weights,
        &eta,
    )
    .expect("stable logit observations");
    for i in 0..eta.len() {
        let jet = logit_inverse_link_jet5(eta[i]);
        let curvature = obs.neghessian_eta[i];
        let dcurvature = obs.neghessian_eta_derivative[i];
        // Relative comparison against the jet's first and second derivatives.
        assert!(
            (curvature - jet.d1).abs() <= 1e-12 * (1.0 + jet.d1.abs()),
            "eta={} y={} curvature={} target={}",
            eta[i],
            y[i],
            curvature,
            jet.d1
        );
        assert!(
            (dcurvature - jet.d2).abs() <= 1e-12 * (1.0 + jet.d2.abs()),
            "eta={} y={} dcurvature={} target={}",
            eta[i],
            y[i],
            dcurvature,
            jet.d2
        );
        let all_finite =
            curvature.is_finite() && dcurvature.is_finite() && obs.log_likelihood.is_finite();
        assert!(
            all_finite,
            "expected finite logit tail observation state at eta={} y={}",
            eta[i],
            y[i]
        );
    }
}
#[test]
fn non_logit_binomial_tailobservations_stay_finite() {
    // Moderate-tail predictors for the non-logit binomial links.
    let eta = array![12.0, -12.0, 18.0, -18.0];
    let y = array![0.0, 1.0, 1.0, 0.0];
    let weights = Array1::ones(eta.len());
    let families = [
        LikelihoodFamily::BinomialProbit,
        LikelihoodFamily::BinomialCLogLog,
    ];
    for family in families {
        let obs =
            evaluate_standard_familyobservations(family, None, None, None, &y, &weights, &eta)
                .expect("tail observations");
        assert!(obs.log_likelihood.is_finite(), "family={family:?}");
        // Every derivative channel must stay finite in the tails.
        let all_finite = obs.score.iter().all(|v| v.is_finite())
            && obs.neghessian_eta.iter().all(|v| v.is_finite())
            && obs.neghessian_eta_derivative.iter().all(|v| v.is_finite());
        assert!(all_finite, "family={family:?}");
    }
}
#[test]
fn two_block_exact_joint_setup_sanitizes_non_finite_rho_seed() {
    // Infinite entries in the rho seed must be sanitized to finite defaults.
    let setup = ExactJointHyperSetup::new(
        array![f64::NEG_INFINITY, 0.25, f64::INFINITY],
        array![-12.0, -12.0, -12.0],
        array![12.0, 12.0, 12.0],
        SpatialLogKappaCoords::new_with_dims(array![0.5], vec![1]),
        SpatialLogKappaCoords::new_with_dims(array![-2.0], vec![1]),
        SpatialLogKappaCoords::new_with_dims(array![2.0], vec![1]),
    );
    let theta0 = setup.theta0();
    assert!(theta0.iter().all(|v| v.is_finite()));
    // Non-finite rho entries collapse to 0.0; finite ones pass through unchanged.
    let expected = [0.0, 0.25, 0.0, 0.5];
    for (idx, &want) in expected.iter().enumerate() {
        assert_eq!(theta0[idx], want);
    }
}
#[test]
fn extracted_spatial_runtime_cache_matches_normalized_design_penalties() {
    let n = 24usize;
    let mut data = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        data[[i, 0]] = i as f64 / (n as f64 - 1.0);
        data[[i, 1]] = (0.23 * i as f64).cos();
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "matern".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: MaternBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 7 },
                    length_scale: 0.8,
                    nu: MaternNu::FiveHalves,
                    include_intercept: false,
                    double_penalty: true,
                    identifiability: MaternIdentifiability::CenterSumToZero,
                    aniso_log_scales: None,
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("design");
    let caches =
        extract_spatial_operator_runtime_caches(&spec, &design).expect("runtime caches");
    assert_eq!(caches.len(), 1);
    let cache = &caches[0];
    let p_total = design.design.ncols();
    // Symmetrized Gram matrix of an operator factor: 0.5 * (DᵀD + (DᵀD)ᵀ).
    let sym_gram = |d: &Array2<f64>| {
        let raw = d.t().dot(d);
        (&raw + &raw.t()) * 0.5
    };
    // Max-abs gap between the cache-reconstructed penalty (embedded into the
    // global coefficient layout) and the corresponding design penalty.
    let max_abs_gap = |d: &Array2<f64>, penalty_idx: usize| {
        let embedded = penalty_matrixwith_local_block(
            p_total,
            cache.coeff_global_range.clone(),
            &sym_gram(d),
        );
        (&embedded - &design.penalties[penalty_idx].to_global(p_total))
            .iter()
            .map(|v| v.abs())
            .fold(0.0_f64, f64::max)
    };
    let err0 = max_abs_gap(&cache.d0, cache.mass_penalty_global_idx);
    let err1 = max_abs_gap(&cache.d1, cache.tension_penalty_global_idx);
    let err2 = max_abs_gap(&cache.d2, cache.stiffness_penalty_global_idx);
    assert!(err0 < 1e-8, "mass penalty mismatch too large: {err0}");
    assert!(err1 < 1e-8, "tension penalty mismatch too large: {err1}");
    assert!(err2 < 1e-8, "stiffness penalty mismatch too large: {err2}");
}
#[test]
fn extracted_duchon_spatial_runtime_cache_matches_normalized_design_penalties() {
    let n = 32usize;
    let mut data = Array2::<f64>::zeros((n, 1));
    for i in 0..n {
        data[[i, 0]] = i as f64 / (n as f64 - 1.0);
    }
    let spec = TermCollectionSpec {
        linear_terms: vec![],
        random_effect_terms: vec![],
        smooth_terms: vec![SmoothTermSpec {
            name: "duchon".to_string(),
            basis: SmoothBasisSpec::Duchon {
                feature_cols: vec![0],
                spec: DuchonBasisSpec {
                    center_strategy: CenterStrategy::FarthestPoint { num_centers: 11 },
                    length_scale: Some(0.8),
                    power: 2,
                    nullspace_order: DuchonNullspaceOrder::Zero,
                    identifiability: SpatialIdentifiability::default(),
                    aniso_log_scales: None,
                    operator_penalties: DuchonOperatorPenaltySpec::default(),
                },
                input_scales: None,
            },
            shape: ShapeConstraint::None,
        }],
    };
    let design = build_term_collection_design(data.view(), &spec).expect("design");
    // The design is expected to expose exactly three operator penalties
    // (mass, tension, stiffness — see the index fields checked below).
    assert_eq!(design.penalties.len(), 3);
    let caches =
        extract_spatial_operator_runtime_caches(&spec, &design).expect("runtime caches");
    assert_eq!(caches.len(), 1);
    let cache = &caches[0];
    let p_total = design.design.ncols();
    // Symmetrized Gram matrix of an operator factor: 0.5 * (DᵀD + (DᵀD)ᵀ).
    let sym_gram = |d: &Array2<f64>| {
        let raw = d.t().dot(d);
        (&raw + &raw.t()) * 0.5
    };
    // Max-abs gap between the cache-reconstructed penalty (embedded into the
    // global coefficient layout) and the corresponding design penalty.
    let max_abs_gap = |d: &Array2<f64>, penalty_idx: usize| {
        let embedded = penalty_matrixwith_local_block(
            p_total,
            cache.coeff_global_range.clone(),
            &sym_gram(d),
        );
        (&embedded - &design.penalties[penalty_idx].to_global(p_total))
            .iter()
            .map(|v| v.abs())
            .fold(0.0_f64, f64::max)
    };
    let err0 = max_abs_gap(&cache.d0, cache.mass_penalty_global_idx);
    let err1 = max_abs_gap(&cache.d1, cache.tension_penalty_global_idx);
    let err2 = max_abs_gap(&cache.d2, cache.stiffness_penalty_global_idx);
    assert!(
        err0 < 1e-8,
        "Duchon mass penalty mismatch too large: {err0}"
    );
    assert!(
        err1 < 1e-8,
        "Duchon tension penalty mismatch too large: {err1}"
    );
    assert!(
        err2 < 1e-8,
        "Duchon stiffness penalty mismatch too large: {err2}"
    );
}
#[test]
fn spatial_adaptive_explicit_second_order_kind_matches_block_sparsity() {
    // Shorthand constructor for hyper-parameter specs.
    let spec = |cache_index: usize, kind: SpatialAdaptiveHyperKind| SpatialAdaptiveHyperSpec {
        cache_index,
        kind,
    };
    let alpha_mass_0 = spec(0, SpatialAdaptiveHyperKind::LogLambdaMagnitude);
    let alpha_mass_1 = spec(1, SpatialAdaptiveHyperKind::LogLambdaMagnitude);
    let alpha_grad_0 = spec(0, SpatialAdaptiveHyperKind::LogLambdaGradient);
    let eta_mass = spec(0, SpatialAdaptiveHyperKind::LogEpsilonMagnitude);
    let eta_grad = spec(0, SpatialAdaptiveHyperKind::LogEpsilonGradient);
    // Same cache and same alpha kind couples in the local (alpha, alpha) block.
    assert_eq!(
        alpha_mass_0.explicit_second_order_kind(alpha_mass_0),
        SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha
    );
    // Different caches or different alpha channels are structurally zero.
    assert_eq!(
        alpha_mass_0.explicit_second_order_kind(alpha_mass_1),
        SpatialAdaptiveExplicitSecondOrderKind::StructuralZero
    );
    assert_eq!(
        alpha_mass_0.explicit_second_order_kind(alpha_grad_0),
        SpatialAdaptiveExplicitSecondOrderKind::StructuralZero
    );
    // Alpha/eta cross blocks couple locally, symmetrically in argument order.
    assert_eq!(
        alpha_mass_1.explicit_second_order_kind(eta_mass),
        SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta
    );
    assert_eq!(
        eta_mass.explicit_second_order_kind(alpha_mass_1),
        SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta
    );
    // Eta paired with itself shares a dense block; distinct eta kinds do not.
    assert_eq!(
        eta_mass.explicit_second_order_kind(eta_mass),
        SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta
    );
    assert_eq!(
        eta_mass.explicit_second_order_kind(eta_grad),
        SpatialAdaptiveExplicitSecondOrderKind::StructuralZero
    );
}
#[test]
fn scalar_charbonnier_exact_derivatives_match_finite_difference() {
    let signal = array![0.7, -1.1];
    let epsilon = 0.3;
    let state = CharbonnierScalarBlockState::from_signal(signal.clone(), epsilon);
    let h = 1e-5;
    // Penalty value as a function of the signal vector alone.
    let value = |x: &Array1<f64>| {
        CharbonnierScalarBlockState::from_signal(x.clone(), epsilon).penalty_value()
    };
    for i in 0..signal.len() {
        let mut bumped_up = signal.clone();
        bumped_up[i] += h;
        let mut bumped_down = signal.clone();
        bumped_down[i] -= h;
        let v_plus = value(&bumped_up);
        let v_minus = value(&bumped_down);
        // Central first and second differences along coordinate i.
        let gradfd = (v_plus - v_minus) / (2.0 * h);
        let hessfd = (v_plus - 2.0 * value(&signal) + v_minus) / (h * h);
        assert!((state.betagradient_coeff()[i] - gradfd).abs() < 1e-6);
        assert!((state.betahessian_diag()[i] - hessfd).abs() < 1e-4);
    }
}
#[test]
fn grouped_charbonnier_exactgradient_matches_finite_difference() {
    let blocks = array![[0.8, -0.4], [0.3, 0.9]];
    let epsilon = 0.25;
    let state = CharbonnierGroupedBlockState::from_signal_blocks(blocks.clone(), epsilon);
    let h = 1e-6;
    // Penalty value as a function of the signal blocks alone.
    let value = |x: &Array2<f64>| {
        CharbonnierGroupedBlockState::from_signal_blocks(x.clone(), epsilon).penalty_value()
    };
    let analytic = state.betagradient_blocks();
    // Compare every analytic gradient entry with a central difference.
    for ((k, axis), &exact) in analytic.indexed_iter() {
        let mut up = blocks.clone();
        up[[k, axis]] += h;
        let mut down = blocks.clone();
        down[[k, axis]] -= h;
        let gradfd = (value(&up) - value(&down)) / (2.0 * h);
        assert!((exact - gradfd).abs() < 1e-6);
    }
}
#[test]
fn scalar_charbonnier_log_epsilon_derivatives_match_finite_difference() {
    let signal = array![0.4, -0.9];
    let epsilon = 0.35_f64;
    let eta = epsilon.ln();
    let state = CharbonnierScalarBlockState::from_signal(signal.clone(), epsilon);
    let h = 1e-5;
    // Rebuild the block state at a perturbed log-epsilon value.
    let at_eta =
        |eta_value: f64| CharbonnierScalarBlockState::from_signal(signal.clone(), eta_value.exp());
    // Penalty value as a function of log-epsilon.
    let value = |eta_value: f64| at_eta(eta_value).penalty_value();
    let gradfd = (value(eta + h) - value(eta - h)) / (2.0 * h);
    let hessfd = (value(eta + h) - 2.0 * value(eta) + value(eta - h)) / (h * h);
    assert!((state.log_epsilon_gradient_terms().sum() - gradfd).abs() < 1e-6);
    assert!((state.log_epsilon_hessian_terms().sum() - hessfd).abs() < 1e-4);
    // Mixed derivative: d/d(log eps) of the beta gradient.
    let mixedfd =
        (&at_eta(eta + h).betagradient_coeff() - &at_eta(eta - h).betagradient_coeff())
            / (2.0 * h);
    for (analytic, fd) in state
        .log_epsilon_betagradient_coeff()
        .iter()
        .zip(mixedfd.iter())
    {
        assert!((analytic - fd).abs() < 1e-6);
    }
    // d/d(log eps) of the beta Hessian diagonal.
    let betahessfd =
        (&at_eta(eta + h).betahessian_diag() - &at_eta(eta - h).betahessian_diag()) / (2.0 * h);
    for (analytic, fd) in state
        .log_epsilon_betahessian_diag()
        .iter()
        .zip(betahessfd.iter())
    {
        assert!((analytic - fd).abs() < 1e-5);
    }
    // Second mixed derivative: d/d(log eps) of the log-eps beta gradient.
    let second_mixedfd = (&at_eta(eta + h).log_epsilon_betagradient_coeff()
        - &at_eta(eta - h).log_epsilon_betagradient_coeff())
        / (2.0 * h);
    for (analytic, fd) in state
        .log_epsilon_beta_mixed_second_coeff()
        .iter()
        .zip(second_mixedfd.iter())
    {
        assert!((analytic - fd).abs() < 1e-5);
    }
    // Second log-eps derivative of the beta Hessian diagonal.
    let second_hessfd = (&at_eta(eta + h).log_epsilon_betahessian_diag()
        - &at_eta(eta - h).log_epsilon_betahessian_diag())
        / (2.0 * h);
    for (analytic, fd) in state
        .log_epsilon_betahessian_second_diag()
        .iter()
        .zip(second_hessfd.iter())
    {
        assert!((analytic - fd).abs() < 1e-4);
    }
}
#[test]
fn scalar_charbonnier_directionalhessian_matches_finite_difference() {
    let signal = array![0.5, -0.6];
    let epsilon = 0.2;
    let direction = array![0.3, -0.1];
    let state = CharbonnierScalarBlockState::from_signal(signal.clone(), epsilon);
    let h = 1e-6;
    let analytic = state.directionalhessian_diag(&direction);
    // Hessian diagonal after stepping `step` along `direction`.
    let evalhess = |step: f64| {
        let shifted = &signal + &(direction.mapv(|v| step * v));
        CharbonnierScalarBlockState::from_signal(shifted, epsilon).betahessian_diag()
    };
    let fd = (&evalhess(h) - &evalhess(-h)) / (2.0 * h);
    for (got, want) in analytic.iter().zip(fd.iter()) {
        assert!((got - want).abs() < 1e-5);
    }
}
#[test]
fn scalar_charbonnier_log_epsilon_directionalhessian_matches_finite_difference() {
    let signal = array![0.5, -0.6];
    let epsilon = 0.2;
    let direction = array![0.3, -0.1];
    let state = CharbonnierScalarBlockState::from_signal(signal.clone(), epsilon);
    let h = 1e-6;
    let analytic = state.log_epsilon_betahessian_directional_diag(&direction);
    // Log-eps Hessian diagonal after stepping `step` along `direction`.
    let evalhess = |step: f64| {
        let shifted = &signal + &(direction.mapv(|v| step * v));
        CharbonnierScalarBlockState::from_signal(shifted, epsilon)
            .log_epsilon_betahessian_diag()
    };
    let fd = (&evalhess(h) - &evalhess(-h)) / (2.0 * h);
    for (got, want) in analytic.iter().zip(fd.iter()) {
        assert!((got - want).abs() < 1e-4);
    }
}
#[test]
fn grouped_charbonnier_log_epsilon_derivatives_match_finite_difference() {
    let blocks = array![[0.7, -0.2], [0.1, 0.8]];
    let epsilon = 0.3_f64;
    let eta = epsilon.ln();
    let state = CharbonnierGroupedBlockState::from_signal_blocks(blocks.clone(), epsilon);
    let h = 1e-5;
    // Rebuild the grouped state at a perturbed log-epsilon value.
    let at_eta = |eta_value: f64| {
        CharbonnierGroupedBlockState::from_signal_blocks(blocks.clone(), eta_value.exp())
    };
    let value = |eta_value: f64| at_eta(eta_value).penalty_value();
    let gradfd = (value(eta + h) - value(eta - h)) / (2.0 * h);
    let hessfd = (value(eta + h) - 2.0 * value(eta) + value(eta - h)) / (h * h);
    assert!((state.log_epsilon_gradient_terms().sum() - gradfd).abs() < 1e-6);
    assert!((state.log_epsilon_hessian_terms().sum() - hessfd).abs() < 1e-4);
    // Mixed derivative: d/d(log eps) of the beta gradient blocks.
    let mixedfd = (&at_eta(eta + h).betagradient_blocks()
        - &at_eta(eta - h).betagradient_blocks())
        / (2.0 * h);
    let analytic_mixed = state.log_epsilon_betagradient_blocks();
    for (idx, fd_entry) in mixedfd.indexed_iter() {
        assert!((analytic_mixed[idx] - fd_entry).abs() < 1e-6);
    }
    // Log-eps derivative of each per-group beta Hessian block.
    let plus_hess = at_eta(eta + h).betahessian_blocks();
    let minus_hess = at_eta(eta - h).betahessian_blocks();
    let analytic_hess = state.log_epsilon_betahessian_blocks();
    for ((exact, up), down) in analytic_hess.iter().zip(&plus_hess).zip(&minus_hess) {
        let fd = (up - down) / (2.0 * h);
        for (idx, fd_entry) in fd.indexed_iter() {
            assert!((exact[idx] - fd_entry).abs() < 1e-5);
        }
    }
    // Second mixed derivative via the log-eps gradient blocks.
    let second_mixedfd = (&at_eta(eta + h).log_epsilon_betagradient_blocks()
        - &at_eta(eta - h).log_epsilon_betagradient_blocks())
        / (2.0 * h);
    let analytic_second_mixed = state.log_epsilon_beta_mixed_second_blocks();
    for (idx, fd_entry) in second_mixedfd.indexed_iter() {
        assert!((analytic_second_mixed[idx] - fd_entry).abs() < 1e-5);
    }
    // Second log-eps derivative of each Hessian block.
    let plus_log_hess = at_eta(eta + h).log_epsilon_betahessian_blocks();
    let minus_log_hess = at_eta(eta - h).log_epsilon_betahessian_blocks();
    let analytic_second_hess = state.log_epsilon_betahessian_second_blocks();
    for ((exact, up), down) in analytic_second_hess
        .iter()
        .zip(&plus_log_hess)
        .zip(&minus_log_hess)
    {
        let fd = (up - down) / (2.0 * h);
        for (idx, fd_entry) in fd.indexed_iter() {
            assert!((exact[idx] - fd_entry).abs() < 1e-4);
        }
    }
}
#[test]
fn grouped_charbonnier_directionalhessian_matches_finite_difference() {
    let blocks = array![[0.6, -0.2], [0.4, 0.5]];
    let direction = array![[0.1, -0.3], [0.2, 0.15]];
    let epsilon = 0.4;
    let state = CharbonnierGroupedBlockState::from_signal_blocks(blocks.clone(), epsilon);
    let analytic = state.directionalhessian_blocks(&direction);
    let h = 1e-6;
    // Hessian blocks after stepping `step` along `direction`.
    let evalhess = |step: f64| {
        let shifted = &blocks + &(direction.mapv(|v| step * v));
        CharbonnierGroupedBlockState::from_signal_blocks(shifted, epsilon).betahessian_blocks()
    };
    let stepped_up = evalhess(h);
    let stepped_down = evalhess(-h);
    for ((exact, up), down) in analytic.iter().zip(&stepped_up).zip(&stepped_down) {
        let fd = (up - down) / (2.0 * h);
        for (idx, fd_entry) in fd.indexed_iter() {
            assert!((exact[idx] - fd_entry).abs() < 1e-5);
        }
    }
}
#[test]
fn grouped_charbonnier_log_epsilon_directionalhessian_matches_finite_difference() {
    let blocks = array![[0.6, -0.2], [0.4, 0.5]];
    let direction = array![[0.1, -0.3], [0.2, 0.15]];
    let epsilon = 0.4;
    let state = CharbonnierGroupedBlockState::from_signal_blocks(blocks.clone(), epsilon);
    let analytic = state.log_epsilon_betahessian_directional_blocks(&direction);
    let h = 1e-6;
    // Log-eps Hessian blocks after stepping `step` along `direction`.
    let evalhess = |step: f64| {
        let shifted = &blocks + &(direction.mapv(|v| step * v));
        CharbonnierGroupedBlockState::from_signal_blocks(shifted, epsilon)
            .log_epsilon_betahessian_blocks()
    };
    let stepped_up = evalhess(h);
    let stepped_down = evalhess(-h);
    for ((exact, up), down) in analytic.iter().zip(&stepped_up).zip(&stepped_down) {
        let fd = (up - down) / (2.0 * h);
        for (idx, fd_entry) in fd.indexed_iter() {
            assert!((exact[idx] - fd_entry).abs() < 1e-4);
        }
    }
}
#[test]
fn grouped_charbonnier_directionalhessian_blocks_are_symmetric() {
    let blocks = array![[0.6, -0.2], [0.4, 0.5]];
    let direction = array![[0.1, -0.3], [0.2, 0.15]];
    let epsilon = 0.4;
    let analytic = CharbonnierGroupedBlockState::from_signal_blocks(blocks, epsilon)
        .directionalhessian_blocks(&direction);
    // Each per-group directional Hessian block must equal its own transpose.
    for (k, block) in analytic.iter().enumerate() {
        for ((i, j), &entry) in block.indexed_iter() {
            assert!(
                (entry - block[[j, i]]).abs() < 1e-12,
                "directional Hessian block {k} is not symmetric at ({i},{j})"
            );
        }
    }
}
#[test]
fn scalar_charbonnier_local_quadratic_curvature_matches_transition_scale() {
    // At the zero signal the curvature should reduce to 1/epsilon.
    let signal = array![0.0, 0.0, 0.0];
    let small = CharbonnierScalarBlockState::from_signal(signal.clone(), 1e-3);
    let large = CharbonnierScalarBlockState::from_signal(signal, 1e3);
    let small_diag = small.betahessian_diag();
    let large_diag = large.betahessian_diag();
    for (&a, &b) in small_diag.iter().zip(large_diag.iter()) {
        assert!(
            (a - 1e3).abs() < 1e-7,
            "small-epsilon curvature should be 1/eps, got {a}"
        );
        assert!(
            (b - 1e-3).abs() < 1e-13,
            "large-epsilon curvature should be 1/eps, got {b}"
        );
    }
}
#[test]
fn grouped_charbonnier_local_quadratic_curvature_matches_transition_scale() {
    // At the zero signal the grouped curvature blocks should collapse to I/epsilon.
    let blocks = array![[0.0, 0.0], [0.0, 0.0]];
    let small = CharbonnierGroupedBlockState::from_signal_blocks(blocks.clone(), 1e-3);
    let large = CharbonnierGroupedBlockState::from_signal_blocks(blocks, 1e3);
    for (small_block, large_block) in small
        .betahessian_blocks()
        .into_iter()
        .zip(large.betahessian_blocks().into_iter())
    {
        let eye = Array2::<f64>::eye(small_block.nrows());
        let small_gap = (&small_block - &eye.mapv(|v| 1e3 * v)).mapv(f64::abs).sum();
        let large_gap = (&large_block - &eye.mapv(|v| 1e-3 * v)).mapv(f64::abs).sum();
        assert!(
            small_gap < 1e-7,
            "small-epsilon grouped curvature should equal I/eps"
        );
        assert!(
            large_gap < 1e-13,
            "large-epsilon grouped curvature should equal I/eps"
        );
    }
}
#[test]
fn scalar_charbonnier_small_signal_matches_local_half_quadratic() {
    // Tiny signals should sit deep in the quadratic regime for any epsilon.
    let signal = array![1e-5, -2e-5, 3e-5];
    // Loop-invariant half-quadratic numerator, hoisted out of the epsilon sweep.
    let quadratic = 0.5 * signal.iter().map(|v| v * v).sum::<f64>();
    for &epsilon in &[1e-3, 1e-1, 1.0, 1e2] {
        let value =
            CharbonnierScalarBlockState::from_signal(signal.clone(), epsilon).penalty_value();
        let target = quadratic / epsilon;
        let rel = (value - target).abs() / target.max(1e-20);
        assert!(
            rel < 5e-3,
            "scalar Charbonnier should match 0.5*t^2/eps locally: eps={epsilon}, value={value}, target={target}, rel={rel}"
        );
    }
}
#[test]
fn grouped_charbonnier_small_signal_matches_local_half_quadratic() {
    // Tiny grouped signals should also sit deep in the quadratic regime.
    let blocks = array![[1e-5, -2e-5], [3e-5, 4e-5]];
    // Loop-invariant half-quadratic numerator, hoisted out of the epsilon sweep.
    let quadratic = 0.5 * blocks.iter().map(|v| v * v).sum::<f64>();
    for &epsilon in &[1e-3, 1e-1, 1.0, 1e2] {
        let value = CharbonnierGroupedBlockState::from_signal_blocks(blocks.clone(), epsilon)
            .penalty_value();
        let target = quadratic / epsilon;
        let rel = (value - target).abs() / target.max(1e-20);
        assert!(
            rel < 5e-3,
            "grouped Charbonnier should match 0.5*||v||^2/eps locally: eps={epsilon}, value={value}, target={target}, rel={rel}"
        );
    }
}
#[test]
fn adaptive_diagnostics_json_roundtrip_preserves_shapes() {
    let diag = AdaptiveRegularizationDiagnostics {
        epsilon_0: 0.01,
        epsilon_g: 0.02,
        epsilon_c: 0.03,
        epsilon_outer_iterations: 2,
        mm_iterations: 3,
        converged: true,
        maps: vec![AdaptiveSpatialMap {
            termname: "matern".to_string(),
            feature_cols: vec![0, 1],
            collocation_points: array![[1.0, 2.0], [3.0, 4.0]],
            inv_magweight: array![0.05, 0.15],
            invgradweight: array![0.1, 0.2],
            inv_lapweight: array![0.3, 0.4],
        }],
    };
    // Serialize, spot-check the JSON layout, then round-trip back.
    let payload = serde_json::to_value(&diag).expect("serialize diagnostics");
    assert_eq!(payload["mm_iterations"].as_u64(), Some(3));
    let dim_len = payload["maps"][0]["collocation_points"]["dim"]
        .as_array()
        .map(|v| v.len());
    assert_eq!(dim_len, Some(2));
    let decoded: AdaptiveRegularizationDiagnostics =
        serde_json::from_value(payload).expect("deserialize diagnostics");
    assert_eq!(decoded.epsilon_outer_iterations, 2);
    assert_eq!(decoded.mm_iterations, 3);
    assert!(decoded.converged);
    assert_eq!(decoded.maps.len(), 1);
    let map = &decoded.maps[0];
    assert_eq!(map.collocation_points.nrows(), 2);
    assert_eq!(map.collocation_points.ncols(), 2);
    assert_eq!(map.invgradweight.len(), 2);
    assert_eq!(map.inv_lapweight.len(), 2);
}
}