fn try_build_spatial_term_log_kappa_derivative(
data: ArrayView2<'_, f64>,
resolvedspec: &TermCollectionSpec,
design: &TermCollectionDesign,
term_idx: usize,
) -> Result<
Option<(
Range<usize>,
usize,
Array2<f64>,
Array2<f64>,
Array2<f64>,
Array2<f64>,
Vec<Array2<f64>>,
Vec<Array2<f64>>,
Option<std::sync::Arc<crate::terms::basis::ImplicitDesignPsiDerivative>>,
)>,
EstimationError,
> {
let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
return Ok(None);
};
let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
return Ok(None);
};
let derivative_bundle = match &termspec.basis {
SmoothBasisSpec::ThinPlate {
feature_cols,
spec,
input_scales,
} => {
let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
let mut spec_local = spec.clone();
if let Some(s) = input_scales {
apply_input_standardization(&mut x, s);
spec_local.length_scale =
compensate_length_scale_for_standardization(spec.length_scale, s);
}
build_thin_plate_basis_log_kappa_derivatives(x.view(), &spec_local)
.map_err(EstimationError::from)?
}
SmoothBasisSpec::Sphere { .. } => return Ok(None),
SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
let x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
build_constant_curvature_basis_kappa_derivatives(x.view(), spec)
.map_err(EstimationError::from)?
}
SmoothBasisSpec::MeasureJet { .. } => return Ok(None),
SmoothBasisSpec::Matern {
feature_cols,
spec,
input_scales,
} => {
let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
let mut spec_local = spec.clone();
if let Some(s) = input_scales {
apply_input_standardization(&mut x, s);
spec_local.length_scale =
compensate_length_scale_for_standardization(spec.length_scale, s);
}
build_matern_basis_log_kappa_derivatives(x.view(), &spec_local)
.map_err(EstimationError::from)?
}
SmoothBasisSpec::Duchon {
feature_cols,
spec,
input_scales,
} => {
let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
let mut spec_local = spec.clone();
if let Some(s) = input_scales {
apply_input_standardization(&mut x, s);
spec_local.length_scale =
compensate_optional_length_scale_for_standardization(spec.length_scale, s);
}
let BasisMetadata::Duchon {
centers,
identifiability_transform,
operator_collocation_points,
radial_reparam,
..
} = &smooth_term.metadata
else {
return Ok(None);
};
if spec_local.radial_reparam.is_none() {
spec_local.radial_reparam = radial_reparam.clone();
}
crate::basis::build_duchon_basis_log_kappa_derivativeswith_collocationwithworkspace(
x.view(),
&spec_local,
centers.view(),
identifiability_transform.as_ref(),
operator_collocation_points
.as_ref()
.map(|points| points.view()),
&mut BasisWorkspace::default(),
)
.map_err(EstimationError::from)?
}
SmoothBasisSpec::BSpline1D { .. }
| SmoothBasisSpec::TensorBSpline { .. }
| SmoothBasisSpec::ByVariable { .. }
| SmoothBasisSpec::FactorSumToZero { .. }
| SmoothBasisSpec::BySmooth { .. }
| SmoothBasisSpec::FactorSmooth { .. }
| SmoothBasisSpec::Pca { .. } => {
return Ok(None);
}
};
let mut implicit_operator = derivative_bundle.implicit_operator;
let BasisPsiDerivativeResult {
design_derivative: mut local_x_psi,
penalties_derivative: mut local_s_psi,
implicit_operator: local_implicit_first_unused,
} = derivative_bundle.first;
let BasisPsiSecondDerivativeResult {
designsecond_derivative: mut local_x_psi_psi,
penaltiessecond_derivative: mut local_s_psi_psi,
implicit_operator: local_implicit_second_unused,
} = derivative_bundle.second;
assert!(local_implicit_first_unused.is_none());
assert!(local_implicit_second_unused.is_none());
if let Some(rotation) = smooth_term.joint_null_rotation.as_ref() {
let q = &rotation.rotation;
if let Some(op) = implicit_operator.take() {
implicit_operator = Some(op.append_full_transform(q).map_err(EstimationError::from)?);
} else {
if local_x_psi.ncols() != q.nrows() || local_x_psi_psi.ncols() != q.nrows() {
return Ok(None);
}
local_x_psi = fast_ab(&local_x_psi, q);
local_x_psi_psi = fast_ab(&local_x_psi_psi, q);
}
let rotate_penalty = |s_local: Array2<f64>| -> Option<Array2<f64>> {
if s_local.nrows() != q.nrows() || s_local.ncols() != q.nrows() {
return None;
}
let qt_s = crate::linalg::faer_ndarray::fast_atb(q, &s_local);
Some(crate::linalg::faer_ndarray::fast_ab(&qt_s, q))
};
let Some(rotated_s_psi) = local_s_psi
.into_iter()
.map(|s| rotate_penalty(s))
.collect::<Option<Vec<_>>>()
else {
return Ok(None);
};
local_s_psi = rotated_s_psi;
let Some(rotated_s_psi_psi) = local_s_psi_psi
.into_iter()
.map(|s| rotate_penalty(s))
.collect::<Option<Vec<_>>>()
else {
return Ok(None);
};
local_s_psi_psi = rotated_s_psi_psi;
}
let implicit_operator = implicit_operator.map(std::sync::Arc::new);
if let Some(ref op) = implicit_operator {
if op.p_out() != smooth_term.coeff_range.len() {
return Ok(None);
}
} else {
if local_x_psi.ncols() != smooth_term.coeff_range.len() {
return Ok(None);
}
if local_x_psi_psi.ncols() != smooth_term.coeff_range.len() {
return Ok(None);
}
}
if local_s_psi.is_empty() || local_s_psi.len() != local_s_psi_psi.len() {
return Ok(None);
}
if local_s_psi.iter().any(|s| {
s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
}) {
return Ok(None);
}
if local_s_psi_psi.iter().any(|s| {
s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
}) {
return Ok(None);
}
let p_total = design.design.ncols();
let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
let global_range = (smooth_start + smooth_term.coeff_range.start)
..(smooth_start + smooth_term.coeff_range.end);
Ok(Some((
global_range,
p_total,
local_x_psi,
local_s_psi.iter().fold(
Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
|acc, m| acc + m,
),
local_x_psi_psi,
local_s_psi_psi.iter().fold(
Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
|acc, m| acc + m,
),
local_s_psi,
local_s_psi_psi,
implicit_operator,
)))
}
/// Builds the directional hyper-parameters for the log-kappa coordinates of `spatial_terms`.
///
/// Returns `Ok(None)` when the per-term derivative info list cannot be built, otherwise the
/// directions assembled by `spatial_log_kappa_hyper_dirs_frominfo_list`.
///
/// # Errors
/// Propagates failures from building the info list or assembling the directions.
fn try_build_spatial_log_kappa_hyper_dirs(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    spatial_terms: &[usize],
) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
    match try_build_spatial_log_kappa_derivativeinfo_list(
        data,
        resolvedspec,
        design,
        spatial_terms,
    )? {
        Some(info_list) => {
            let dirs = spatial_log_kappa_hyper_dirs_frominfo_list(info_list)?;
            Ok(Some(dirs))
        }
        None => Ok(None),
    }
}
pub(crate) fn try_build_latent_coord_hyper_dirs(
latent: std::sync::Arc<crate::terms::latent::LatentCoordValues>,
resolvedspec: &TermCollectionSpec,
design: &TermCollectionDesign,
latent_terms: &[crate::types::SmoothTermIdx],
analytic_rho_count: usize,
) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
if latent_terms.is_empty() || latent.is_empty() {
return Ok(None);
}
if latent_terms.len() != 1 {
crate::bail_invalid_estim!(
"LatentCoord standard-fit hyper_dirs currently require exactly one latent smooth term"
.to_string(),
);
}
let term_idx = latent_terms[0];
let smooth_term = design.smooth.terms.get(term_idx.get()).ok_or_else(|| {
EstimationError::InvalidInput(format!(
"LatentCoord term index {term_idx} out of bounds for realized smooth design"
))
})?;
let termspec = resolvedspec
.smooth_terms
.get(term_idx.get())
.ok_or_else(|| {
EstimationError::InvalidInput(format!(
"LatentCoord term index {term_idx} out of bounds for resolved smooth spec"
))
})?;
let p_total = design.design.ncols();
let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
let global_range = (smooth_start + smooth_term.coeff_range.start)
..(smooth_start + smooth_term.coeff_range.end);
let operator = match (&termspec.basis, &smooth_term.metadata) {
(
SmoothBasisSpec::Matern { .. },
BasisMetadata::Matern {
centers,
length_scale,
nu,
include_intercept,
identifiability_transform,
..
},
) => crate::terms::basis::LatentCoordDesignDerivative::new_matern(
latent.clone(),
std::sync::Arc::new(centers.clone()),
*length_scale,
*nu,
*include_intercept,
identifiability_transform.clone(),
)
.map_err(EstimationError::from)?,
(
SmoothBasisSpec::Duchon { .. },
BasisMetadata::Duchon {
centers,
length_scale,
power,
nullspace_order,
identifiability_transform,
..
},
) => crate::terms::basis::LatentCoordDesignDerivative::new_duchon(
latent.clone(),
std::sync::Arc::new(centers.clone()),
*length_scale,
*power,
*nullspace_order,
identifiability_transform.clone(),
)
.map_err(EstimationError::from)?,
(
SmoothBasisSpec::Sphere { .. },
BasisMetadata::Sphere {
centers,
penalty_order,
method,
constraint_transform,
..
},
) if matches!(*method, crate::basis::SphereMethod::Wahba) => {
crate::terms::basis::LatentCoordDesignDerivative::new_sphere(
latent.clone(),
std::sync::Arc::new(centers.clone()),
*penalty_order,
constraint_transform.clone(),
)
.map_err(EstimationError::from)?
}
(
SmoothBasisSpec::BSpline1D { spec, .. },
BasisMetadata::BSpline1D {
knots,
identifiability_transform,
periodic,
degree: meta_degree,
..
},
) => {
let effective_degree = meta_degree.unwrap_or(spec.degree);
if let Some((domain_start, period, num_basis)) = periodic {
crate::terms::basis::LatentCoordDesignDerivative::new_periodic_bspline(
latent.clone(),
(*domain_start, *domain_start + *period),
effective_degree,
*num_basis,
identifiability_transform.clone(),
)
.map_err(EstimationError::from)?
} else {
crate::terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
latent.clone(),
vec![knots.clone()],
vec![effective_degree],
identifiability_transform.clone(),
)
.map_err(EstimationError::from)?
}
}
(
SmoothBasisSpec::TensorBSpline { .. },
BasisMetadata::TensorBSpline {
knots,
degrees,
identifiability_transform,
..
},
) => crate::terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
latent.clone(),
knots.clone(),
degrees.clone(),
identifiability_transform.clone(),
)
.map_err(EstimationError::from)?,
(SmoothBasisSpec::Pca { .. }, BasisMetadata::Pca { basis_matrix, .. }) => {
crate::terms::basis::LatentCoordDesignDerivative::new_pca(
latent.clone(),
std::sync::Arc::new(basis_matrix.clone()),
)
.map_err(EstimationError::from)?
}
_ => return Ok(None),
};
if operator.p_out() != global_range.len() {
crate::bail_invalid_estim!(
"LatentCoord derivative width mismatch for term '{}': operator p={}, coeff range={}",
smooth_term.name,
operator.p_out(),
global_range.len()
);
}
let operator = std::sync::Arc::new(operator);
let mut hyper_dirs = Vec::with_capacity(operator.n_axes());
for flat_axis in 0..operator.n_axes() {
let dir = DirectionalHyperParam::new_compact(
crate::estimate::reml::HyperDesignDerivative::from_latent_coord(
operator.clone(),
flat_axis,
global_range.clone(),
p_total,
),
Vec::new(),
None,
None,
)?
.not_penalty_like();
hyper_dirs.push(dir);
}
let direct_dim = latent_coord_direct_hyper_count(latent.id_mode(), latent.latent_dim());
if analytic_rho_count + direct_dim > 0 {
let zero_x = crate::estimate::reml::HyperDesignDerivative::from(Array2::<f64>::zeros((
design.design.nrows(),
p_total,
)));
for _ in 0..analytic_rho_count {
hyper_dirs.push(
DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
.not_penalty_like(),
);
}
for _ in 0..direct_dim {
hyper_dirs.push(
DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
.not_penalty_like(),
);
}
}
Ok(Some(hyper_dirs))
}
/// Number of direct hyper-parameters implied by a latent identification mode.
///
/// An `Auto` prior strength contributes one slot and a `Fixed` strength none; dimension
/// selection contributes `latent_dim` slots; an auxiliary-outcome head contributes its
/// coefficient count plus `latent_dim`.
fn latent_coord_direct_hyper_count(
    id_mode: &crate::terms::latent::LatentIdMode,
    latent_dim: usize,
) -> usize {
    use crate::terms::latent::{AuxPriorStrength, LatentIdMode};
    // One optimized slot for an auto-tuned strength, none for a fixed one.
    let strength_slots = |strength: &AuxPriorStrength| -> usize {
        if matches!(strength, AuxPriorStrength::Auto) {
            1
        } else {
            0
        }
    };
    match id_mode {
        LatentIdMode::None => 0,
        LatentIdMode::DimSelection { .. } => latent_dim,
        LatentIdMode::AuxPrior { strength, .. }
        | LatentIdMode::IsometryToReference { strength, .. } => strength_slots(strength),
        LatentIdMode::AuxPriorDimSelection { strength, .. } => {
            latent_dim + strength_slots(strength)
        }
        LatentIdMode::AuxOutcome { head, .. } => head.n_coeffs(latent_dim) + latent_dim,
    }
}
/// Initial values for the direct hyper-parameters of a latent identification mode.
///
/// The layout is a block of leading zeros (one slot for an `Auto` strength, or one per
/// auxiliary-outcome head coefficient) followed, for the modes that carry per-axis precisions,
/// by the seed appended by `append_latent_ard_seed`.
///
/// # Errors
/// Returns the error from `append_latent_ard_seed` when a supplied `init_log_precision` does not
/// have `latent_dim` entries.
fn latent_coord_initial_direct_hypers(
    id_mode: &crate::terms::latent::LatentIdMode,
    latent_dim: usize,
) -> Result<Array1<f64>, EstimationError> {
    use crate::terms::latent::{AuxPriorStrength, LatentIdMode};
    let auto_slots =
        |strength: &AuxPriorStrength| usize::from(matches!(strength, AuxPriorStrength::Auto));

    let mut values: Vec<f64> =
        Vec::with_capacity(latent_coord_direct_hyper_count(id_mode, latent_dim));

    // (number of leading zero slots, optional per-axis precision seed to append afterwards)
    let (leading_zeros, ard_init) = match id_mode {
        LatentIdMode::AuxPrior { strength, .. }
        | LatentIdMode::IsometryToReference { strength, .. } => (auto_slots(strength), None),
        LatentIdMode::AuxPriorDimSelection {
            strength,
            init_log_precision,
            ..
        } => (auto_slots(strength), Some(init_log_precision)),
        LatentIdMode::DimSelection { init_log_precision } => (0, Some(init_log_precision)),
        LatentIdMode::AuxOutcome {
            head,
            init_log_precision,
        } => (head.n_coeffs(latent_dim), Some(init_log_precision)),
        LatentIdMode::None => (0, None),
    };

    values.extend(std::iter::repeat_n(0.0, leading_zeros));
    if let Some(init_log_precision) = ard_init {
        append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
    }
    Ok(Array1::from_vec(values))
}
/// Appends `latent_dim` per-axis seed values to `values`.
///
/// Uses the entries of `init` when supplied, otherwise `latent_dim` zeros.
///
/// # Errors
/// Returns an invalid-input error when `init` is supplied with a length other than `latent_dim`.
fn append_latent_ard_seed(
    values: &mut Vec<f64>,
    init: Option<&Array1<f64>>,
    latent_dim: usize,
) -> Result<(), EstimationError> {
    let Some(seed) = init else {
        // No user-supplied seed: start every axis at zero.
        values.extend(std::iter::repeat_n(0.0, latent_dim));
        return Ok(());
    };
    if seed.len() != latent_dim {
        crate::bail_invalid_estim!(
            "latent dim_selection init_log_precision length mismatch: got {}, expected {}",
            seed.len(),
            latent_dim
        );
    }
    values.extend(seed.iter().copied());
    Ok(())
}
/// Additive contribution to the outer objective: a scalar cost and its gradient.
///
/// The producers in this file allocate `gradient` with the same length as the full `theta`
/// vector, so it can be added element-wise to the base gradient.
struct LatentIdObjectiveContribution {
/// Scalar amount added to the objective value.
cost: f64,
/// Gradient of `cost`, indexed like `theta`.
gradient: Array1<f64>,
}
fn latent_id_objective_contribution(
theta: &Array1<f64>,
rho_dim: usize,
analytic_rho_count: usize,
latent: &crate::terms::latent::LatentCoordValues,
) -> Result<LatentIdObjectiveContribution, EstimationError> {
use crate::terms::latent::{AuxPriorStrength, LatentIdMode, aux_prior_targets};
let n_obs = latent.n_obs();
let latent_dim = latent.latent_dim();
let flat_len = latent.len();
let mut gradient = Array1::<f64>::zeros(theta.len());
let t_start = rho_dim;
let direct_start = t_start + flat_len + analytic_rho_count;
if theta.len() < direct_start {
crate::bail_invalid_estim!(
"latent-coordinate theta too short for id objective: got {}, need at least {}",
theta.len(),
direct_start
);
}
let t = latent.as_matrix();
let mut cost = 0.0;
let mut cursor = direct_start;
match latent.id_mode() {
LatentIdMode::AuxPrior {
u,
family,
strength,
}
| LatentIdMode::AuxPriorDimSelection {
u,
family,
strength,
..
} => {
let (log_mu, mu) = match strength {
AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
AuxPriorStrength::Auto => {
let log_mu = theta[cursor];
cursor += 1;
(log_mu, log_mu.exp())
}
};
let targets = aux_prior_targets(t.view(), u.view(), *family)
.map_err(EstimationError::InvalidInput)?;
let residual = &t - &targets;
let q = residual.iter().map(|v| v * v).sum::<f64>();
let k = (n_obs * latent_dim) as f64;
cost += 0.5 * mu * q - 0.5 * k * log_mu;
let projected_residual = aux_prior_targets(residual.view(), u.view(), *family)
.map_err(EstimationError::InvalidInput)?;
let grad_base = residual - projected_residual;
for n in 0..n_obs {
for axis in 0..latent_dim {
gradient[t_start + n * latent_dim + axis] += mu * grad_base[[n, axis]];
}
}
if matches!(strength, AuxPriorStrength::Auto) {
gradient[direct_start] += 0.5 * mu * q - 0.5 * k;
}
}
LatentIdMode::IsometryToReference { reference, strength } => {
if reference.dim() != (n_obs, latent_dim) {
crate::bail_invalid_estim!(
"IsometryToReference reference shape {:?} must equal (n_obs, latent_dim) = ({}, {})",
reference.dim(),
n_obs,
latent_dim
);
}
let mu_slot = cursor;
let (log_mu, mu) = match strength {
AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
AuxPriorStrength::Auto => {
let log_mu = theta[cursor];
cursor += 1;
(log_mu, log_mu.exp())
}
};
let residual = &t - reference;
let q = residual.iter().map(|v| v * v).sum::<f64>();
let k = (n_obs * latent_dim) as f64;
cost += 0.5 * mu * q - 0.5 * k * log_mu;
for n in 0..n_obs {
for axis in 0..latent_dim {
gradient[t_start + n * latent_dim + axis] += mu * residual[[n, axis]];
}
}
if matches!(strength, AuxPriorStrength::Auto) {
gradient[mu_slot] += 0.5 * mu * q - 0.5 * k;
}
}
LatentIdMode::AuxOutcome { head, .. } => {
let n_coeffs = head.n_coeffs(latent_dim);
let coeffs = theta
.slice(ndarray::s![cursor..cursor + n_coeffs])
.to_owned();
let (head_nll, grad_coeffs, grad_t) = head
.neg_loglik_and_grad(t.view(), coeffs.view())
.map_err(EstimationError::InvalidInput)?;
cost += head_nll;
for (offset, &g) in grad_coeffs.iter().enumerate() {
gradient[cursor + offset] += g;
}
for n in 0..n_obs {
for axis in 0..latent_dim {
gradient[t_start + n * latent_dim + axis] += grad_t[[n, axis]];
}
}
cursor += n_coeffs;
}
LatentIdMode::DimSelection { .. } | LatentIdMode::None => {}
}
match latent.id_mode() {
LatentIdMode::AuxPriorDimSelection { .. }
| LatentIdMode::DimSelection { .. }
| LatentIdMode::AuxOutcome { .. } => {
for axis in 0..latent_dim {
let log_alpha = theta[cursor + axis];
let alpha = log_alpha.exp();
let mut q_axis = 0.0;
for n in 0..n_obs {
let flat_idx = n * latent_dim + axis;
let value = latent.as_flat()[flat_idx];
q_axis += value * value;
gradient[t_start + flat_idx] += alpha * value;
}
cost += 0.5 * alpha * q_axis - 0.5 * n_obs as f64 * log_alpha;
gradient[cursor + axis] += 0.5 * alpha * q_axis - 0.5 * n_obs as f64;
}
cursor += latent_dim;
}
LatentIdMode::AuxPrior { .. }
| LatentIdMode::IsometryToReference { .. }
| LatentIdMode::None => {}
}
if cursor != theta.len() {
crate::bail_invalid_estim!(
"latent-coordinate direct hyperparameter length mismatch: consumed {}, theta len {}",
cursor,
theta.len()
);
}
Ok(LatentIdObjectiveContribution { cost, gradient })
}
/// Adds the latent identification objective to an evaluated `(cost, gradient, hessian)` triple.
///
/// The gradient length is validated before anything is written, so `eval` is left untouched
/// when an error is returned. Because this contribution supplies no second derivatives, an
/// analytic Hessian in `eval.2` is downgraded to `Unavailable`.
///
/// # Errors
/// * any error from `latent_id_objective_contribution`,
/// * the base gradient and the contribution gradient differ in length.
fn add_latent_id_objective_to_eval(
    theta: &Array1<f64>,
    rho_dim: usize,
    analytic_rho_count: usize,
    latent: &crate::terms::latent::LatentCoordValues,
    eval: &mut (
        f64,
        Array1<f64>,
        crate::solver::rho_optimizer::HessianResult,
    ),
) -> Result<(), EstimationError> {
    let contribution =
        latent_id_objective_contribution(theta, rho_dim, analytic_rho_count, latent)?;
    // Validate first: failing after `eval.0` was bumped would leave cost and gradient out of sync.
    if eval.1.len() != contribution.gradient.len() {
        crate::bail_invalid_estim!(
            "latent-coordinate REML gradient length mismatch: base={}, id={}",
            eval.1.len(),
            contribution.gradient.len()
        );
    }
    eval.0 += contribution.cost;
    eval.1 += &contribution.gradient;
    if eval.2.is_analytic() {
        eval.2 = crate::solver::rho_optimizer::HessianResult::Unavailable;
    }
    Ok(())
}
/// Cost and gradient contributed by the registry's `Psi`-tier analytic penalties.
///
/// `theta` is read with the layout `[rho (rho_dim) | latent coords (latent.len()) | penalty rho
/// (registry.total_rho_count()) | ...]`. Each `Psi` penalty adds its value to the cost, its
/// target gradient to the latent-coordinate block and its rho gradient to its own rho slots.
/// `Beta` and `Rho` tier penalties contribute nothing here.
///
/// # Errors
/// * `theta` is too short for the layout,
/// * a penalty returns a target or rho gradient of unexpected length.
fn analytic_penalty_objective_contribution(
    theta: &Array1<f64>,
    rho_dim: usize,
    latent: &crate::terms::latent::LatentCoordValues,
    registry: &crate::terms::AnalyticPenaltyRegistry,
) -> Result<LatentIdObjectiveContribution, EstimationError> {
    let flat_len = latent.len();
    let t_start = rho_dim;
    let t_end = t_start + flat_len;
    let rho_start = t_end;
    let rho_end = rho_start + registry.total_rho_count();
    if theta.len() < rho_end {
        crate::bail_invalid_estim!(
            "latent-coordinate theta too short for analytic penalties: got {}, need at least {}",
            theta.len(),
            rho_end
        );
    }
    let target_t = theta.slice(s![t_start..t_end]);
    let rho = theta.slice(s![rho_start..rho_end]);
    let mut gradient = Array1::<f64>::zeros(theta.len());
    let mut cost = 0.0_f64;
    let layout = registry.rho_layout();
    for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout) {
        let rho_local = rho.slice(s![rho_slice.clone()]);
        match tier {
            // Only the Psi tier feeds the latent-coordinate objective.
            crate::terms::PenaltyTier::Beta | crate::terms::PenaltyTier::Rho => continue,
            crate::terms::PenaltyTier::Psi => {}
        }

        cost += penalty.value(target_t.view(), rho_local);

        let target_grad = penalty.grad_target(target_t.view(), rho_local);
        if target_grad.len() != flat_len {
            crate::bail_invalid_estim!(
                "analytic penalty {name:?} gradient length mismatch: got {}, expected {}",
                target_grad.len(),
                flat_len
            );
        }
        for (offset, g) in target_grad.iter().enumerate() {
            gradient[t_start + offset] += *g;
        }

        let rho_grad = penalty.grad_rho(target_t.view(), rho_local);
        if rho_grad.len() != rho_slice.len() {
            crate::bail_invalid_estim!(
                "analytic penalty {name:?} rho-gradient length mismatch: got {}, expected {}",
                rho_grad.len(),
                rho_slice.len()
            );
        }
        // This penalty's rho slots start at `rho_slice.start` within the penalty-rho block.
        let rho_base = rho_start + rho_slice.start;
        for (offset, g) in rho_grad.iter().enumerate() {
            gradient[rho_base + offset] += *g;
        }
    }
    Ok(LatentIdObjectiveContribution { cost, gradient })
}
/// Adds the `Psi`-tier analytic-penalty Hessian to the latent-coordinate block of `eval.2`.
///
/// Only a dense `Analytic` Hessian is updated in place. Any other result that still reports
/// itself as analytic is replaced by `Unavailable`, and everything else is left as is. For each
/// `Psi` penalty the diagonal shortcut is used when the penalty provides one; otherwise the
/// block is assembled column by column from Hessian-vector products with unit vectors.
///
/// # Errors
/// * `theta` is too short for the `[rho | latent coords | penalty rho]` layout,
/// * the Hessian is not `theta.len() x theta.len()`,
/// * a penalty returns a diagonal or Hessian-vector product of unexpected length.
fn add_analytic_penalty_hessian_to_eval(
    theta: &Array1<f64>,
    rho_dim: usize,
    latent: &crate::terms::latent::LatentCoordValues,
    registry: &crate::terms::AnalyticPenaltyRegistry,
    eval: &mut (
        f64,
        Array1<f64>,
        crate::solver::rho_optimizer::HessianResult,
    ),
) -> Result<(), EstimationError> {
    let flat_len = latent.len();
    let t_start = rho_dim;
    let t_end = t_start + flat_len;
    let rho_start = t_end;
    let rho_end = rho_start + registry.total_rho_count();
    if theta.len() < rho_end {
        crate::bail_invalid_estim!(
            "latent-coordinate theta too short for analytic penalty Hessian: got {}, need at least {}",
            theta.len(),
            rho_end
        );
    }
    let crate::solver::rho_optimizer::HessianResult::Analytic(hessian) = &mut eval.2 else {
        // Not a dense analytic matrix: it cannot be updated in place.
        if eval.2.is_analytic() {
            eval.2 = crate::solver::rho_optimizer::HessianResult::Unavailable;
        }
        return Ok(());
    };
    let n_theta = theta.len();
    if hessian.dim() != (n_theta, n_theta) {
        crate::bail_invalid_estim!(
            "analytic penalty Hessian target shape mismatch: got {}x{}, expected {}x{}",
            hessian.nrows(),
            hessian.ncols(),
            theta.len(),
            theta.len()
        );
    }
    let target_t = theta.slice(s![t_start..t_end]);
    let rho = theta.slice(s![rho_start..rho_end]);
    let layout = registry.rho_layout();
    for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout) {
        let rho_local = rho.slice(s![rho_slice]);
        match tier {
            crate::terms::PenaltyTier::Psi => {}
            _ => continue,
        }

        match penalty.hessian_diag(target_t.view(), rho_local) {
            Some(diag) => {
                if diag.len() != flat_len {
                    crate::bail_invalid_estim!(
                        "analytic penalty Hessian diagonal length mismatch: got {}, expected {}",
                        diag.len(),
                        flat_len
                    );
                }
                for (offset, d) in diag.iter().enumerate() {
                    hessian[[t_start + offset, t_start + offset]] += *d;
                }
            }
            None => {
                // Probe with one unit vector per column; the buffer is reset after each use.
                let mut unit = Array1::<f64>::zeros(flat_len);
                for col in 0..flat_len {
                    unit[col] = 1.0;
                    let column = penalty.hvp(target_t.view(), rho_local, unit.view());
                    if column.len() != flat_len {
                        crate::bail_invalid_estim!(
                            "analytic penalty Hessian-vector length mismatch: got {}, expected {}",
                            column.len(),
                            flat_len
                        );
                    }
                    for (row, value) in column.iter().enumerate() {
                        hessian[[t_start + row, t_start + col]] += *value;
                    }
                    unit[col] = 0.0;
                }
            }
        }
    }
    Ok(())
}
/// Adds the analytic-penalty cost, gradient and Hessian to an evaluated triple.
///
/// The gradient length is validated before the cost or gradient is written, so a length
/// mismatch leaves `eval` untouched. The Hessian update runs last; if it fails, the cost and
/// gradient have already been updated.
///
/// # Errors
/// * any error from `analytic_penalty_objective_contribution`,
/// * the base gradient and the contribution gradient differ in length,
/// * any error from `add_analytic_penalty_hessian_to_eval`.
fn add_analytic_penalty_objective_to_eval(
    theta: &Array1<f64>,
    rho_dim: usize,
    latent: &crate::terms::latent::LatentCoordValues,
    registry: &crate::terms::AnalyticPenaltyRegistry,
    eval: &mut (
        f64,
        Array1<f64>,
        crate::solver::rho_optimizer::HessianResult,
    ),
) -> Result<(), EstimationError> {
    let contribution = analytic_penalty_objective_contribution(theta, rho_dim, latent, registry)?;
    // Validate first: failing after `eval.0` was bumped would leave cost and gradient out of sync.
    if eval.1.len() != contribution.gradient.len() {
        crate::bail_invalid_estim!(
            "latent-coordinate REML gradient length mismatch: base={}, analytic_penalty={}",
            eval.1.len(),
            contribution.gradient.len()
        );
    }
    eval.0 += contribution.cost;
    eval.1 += &contribution.gradient;
    add_analytic_penalty_hessian_to_eval(theta, rho_dim, latent, registry, eval)
}
/// Converts per-axis spatial psi-derivative records into directional hyper-parameters.
///
/// One `DirectionalHyperParam` is produced per entry of `info_list`, in order. Each direction
/// carries:
/// * the first design derivative (implicit operator when present, embedded matrix otherwise),
/// * the first penalty derivatives paired with `penalty_indices`,
/// * a second-derivative row with the diagonal entry at its own index and, for entries in an
///   anisotropy group, cross entries at the partner axes,
/// * optionally a lazy provider for cross penalty second derivatives plus the partner indices.
///
/// # Errors
/// Returns an error when an entry with a cross-penalty provider is not listed in its own
/// anisotropy group, or when `DirectionalHyperParam::new_compact` fails.
fn spatial_log_kappa_hyper_dirs_frominfo_list(
info_list: Vec<SpatialPsiDerivative>,
) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
use crate::estimate::reml::ImplicitDerivLevel;
use std::collections::HashMap;
let log_kappa_dim = info_list.len();
// Map each anisotropy group id to the positions (in `info_list` order) of its member axes.
let group_ids: Vec<Option<usize>> = info_list.iter().map(|e| e.aniso_group_id).collect();
let mut group_indices_map: HashMap<usize, Vec<usize>> = HashMap::new();
for (idx, gid) in group_ids.iter().enumerate() {
if let Some(g) = gid {
group_indices_map.entry(*g).or_default().push(idx);
}
}
let mut hyper_dirs = Vec::with_capacity(log_kappa_dim);
for (i, info) in info_list.into_iter().enumerate() {
let SpatialPsiDerivative {
penalty_index: _,
penalty_indices,
global_range,
total_p,
x_psi_local,
s_psi_components_local,
x_psi_psi_local,
s_psi_psi_components_local,
aniso_group_id,
aniso_cross_designs,
aniso_cross_penalty_provider,
implicit_operator,
implicit_axis,
} = info;
// Second design derivatives of this axis against every hyper axis; only the diagonal and
// (below) the in-group cross entries are populated.
let mut xsecond = vec![None; log_kappa_dim];
xsecond[i] = Some(if let Some(ref op) = implicit_operator {
crate::estimate::reml::HyperDesignDerivative::from_implicit(
op.clone(),
ImplicitDerivLevel::SecondDiag(implicit_axis),
global_range.clone(),
total_p,
)
} else {
crate::estimate::reml::HyperDesignDerivative::from_embedded(
x_psi_psi_local,
global_range.clone(),
total_p,
)
});
if let Some(cross_designs) = aniso_cross_designs {
if let Some(gid) = aniso_group_id {
// NOTE(review): `base + b_axis` assumes the group's axes occupy consecutive positions
// starting at the group's first index — confirm against the info-list builder.
let base = group_indices_map
.get(&gid)
.and_then(|v| v.first().copied())
.unwrap_or(i);
for (b_axis, cross_mat) in cross_designs.into_iter() {
let j = base + b_axis;
// Cross entries that would land outside the hyper-axis range are dropped.
if j < log_kappa_dim {
xsecond[j] = Some(if let Some(ref op) = implicit_operator {
crate::estimate::reml::HyperDesignDerivative::from_implicit(
op.clone(),
ImplicitDerivLevel::SecondCross(implicit_axis, b_axis),
global_range.clone(),
total_p,
)
} else {
crate::estimate::reml::HyperDesignDerivative::from_embedded(
cross_mat,
global_range.clone(),
total_p,
)
});
}
}
}
}
// First penalty derivatives, each paired positionally with its penalty index.
let s_components = penalty_indices
.iter()
.copied()
.zip(s_psi_components_local.into_iter().map(|local| {
crate::estimate::reml::HyperPenaltyDerivative::from_embedded(
local,
global_range.clone(),
total_p,
)
}))
.collect::<Vec<_>>();
// Diagonal second penalty derivatives, paired the same way.
let s2_components = penalty_indices
.iter()
.copied()
.zip(s_psi_psi_components_local.into_iter().map(|local| {
crate::estimate::reml::HyperPenaltyDerivative::from_embedded(
local,
global_range.clone(),
total_p,
)
}))
.collect::<Vec<_>>();
let mut ssecond_components = vec![None; log_kappa_dim];
ssecond_components[i] = Some(s2_components);
let mut penaltysecond_partner_indices: Option<Vec<usize>> = None;
// Cross penalty second derivatives are supplied lazily through a provider closure, and only
// when this axis has both a cross-penalty provider and an anisotropy group.
let penaltysecond_component_provider =
if let (Some(provider), Some(gid)) = (aniso_cross_penalty_provider, aniso_group_id) {
let group_indices = group_indices_map.get(&gid).cloned().unwrap_or_default();
let axis_in_group =
group_indices
.iter()
.position(|&idx| idx == i)
.ok_or_else(|| {
EstimationError::InvalidInput(format!(
"missing spatial hyper axis {} in anisotropy group {}",
i, gid
))
})?;
// Partners are every other axis of the same group.
penaltysecond_partner_indices = Some(
group_indices
.iter()
.copied()
.filter(|&idx| idx != i)
.collect(),
);
let penalty_indices_inner = penalty_indices.clone();
let global_range_inner = global_range.clone();
let total_p_inner = total_p;
let group_indices_inner = group_indices;
Some(std::sync::Arc::new(
move |j: usize| -> Result<
Option<Vec<crate::estimate::reml::PenaltyDerivativeComponent>>,
EstimationError,
> {
// Translate the global hyper index `j` into a position within this group; indices
// outside the group, or equal to this axis, have no cross term.
let Some(other_axis_in_group) =
group_indices_inner.iter().position(|&idx| idx == j)
else {
return Ok(None);
};
if other_axis_in_group == axis_in_group {
return Ok(None);
}
let cross_pens = provider(other_axis_in_group)?;
if cross_pens.is_empty() {
return Ok(None);
}
Ok(Some(
penalty_indices_inner
.iter()
.copied()
.zip(cross_pens.into_iter().map(|local| {
crate::estimate::reml::HyperPenaltyDerivative::from_embedded(
local,
global_range_inner.clone(),
total_p_inner,
)
}))
.map(|(penalty_index, matrix)| {
crate::estimate::reml::PenaltyDerivativeComponent {
penalty_index,
matrix,
}
})
.collect(),
))
},
)
as std::sync::Arc<
dyn Fn(
usize,
) -> Result<
Option<Vec<crate::estimate::reml::PenaltyDerivativeComponent>>,
EstimationError,
> + Send
+ Sync
+ 'static,
>)
} else {
None
};
// First design derivative: implicit operator when available, embedded dense matrix otherwise.
let x_first_hyper = if let Some(ref op) = implicit_operator {
crate::estimate::reml::HyperDesignDerivative::from_implicit(
op.clone(),
ImplicitDerivLevel::First(implicit_axis),
global_range.clone(),
total_p,
)
} else {
crate::estimate::reml::HyperDesignDerivative::from_embedded(
x_psi_local,
global_range.clone(),
total_p,
)
};
let mut dir = DirectionalHyperParam::new_compact(
x_first_hyper,
s_components,
Some(xsecond),
Some(ssecond_components),
)?
.not_penalty_like();
if let Some(provider) = penaltysecond_component_provider {
dir = dir.with_penaltysecond_component_provider(provider);
}
if let Some(partner_indices) = penaltysecond_partner_indices {
dir = dir.with_penaltysecond_partner_indices(partner_indices);
}
hyper_dirs.push(dir);
}
Ok(hyper_dirs)
}
/// Reports whether a spatial term is parameterized with one psi coordinate per feature axis.
///
/// Measure-jet terms defer to `measure_jet_enrolls_psi`. Any other term qualifies only when it
/// has more than one feature dimension, carries anisotropic log-scales with exactly one entry
/// per dimension, and is not a Duchon basis.
fn spatial_term_uses_per_axis_psi(resolvedspec: &TermCollectionSpec, term_idx: usize) -> bool {
    if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
        return measure_jet_enrolls_psi(mj);
    }
    let dim = match get_spatial_feature_dim(resolvedspec, term_idx) {
        Some(d) if d > 1 => d,
        _ => return false,
    };
    match get_spatial_aniso_log_scales(resolvedspec, term_idx) {
        Some(eta) if eta.len() == dim => {}
        _ => return false,
    }
    let is_duchon = matches!(
        resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis),
        Some(SmoothBasisSpec::Duchon { .. })
    );
    !is_duchon
}
/// Number of psi coordinates contributed by each listed spatial term, in input order.
///
/// Measure-jet terms report `measure_jet_psi_dim`. Per-axis terms report their feature
/// dimension (falling back to 1 when it is unavailable). Every other term reports 1.
pub(crate) fn spatial_dims_per_term(
    resolvedspec: &TermCollectionSpec,
    spatial_terms: &[usize],
) -> Vec<usize> {
    let mut dims = Vec::with_capacity(spatial_terms.len());
    for &term_idx in spatial_terms {
        let dim = match measure_jet_term_spec(resolvedspec, term_idx) {
            Some(mj) => measure_jet_psi_dim(mj),
            None if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) => {
                get_spatial_feature_dim(resolvedspec, term_idx).unwrap_or(1)
            }
            None => 1,
        };
        dims.push(dim);
    }
    dims
}
/// Returns `true` when at least one of `spatial_terms` uses per-axis psi coordinates.
fn has_aniso_terms(resolvedspec: &TermCollectionSpec, spatial_terms: &[usize]) -> bool {
    for &term_idx in spatial_terms {
        if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
            return true;
        }
    }
    false
}
/// Expands to the theta-keyed memoization methods `memoized_cost`, `memoized_eval` and
/// `store_eval`.
///
/// The expansion site must be an `impl` block for a type with the fields `current_theta`,
/// `last_eval` and `last_cost` that these methods read and write.
macro_rules! impl_exact_joint_theta_memo {
    () => {
        /// Cached cost for `theta`, if `theta` matches the currently applied theta.
        ///
        /// Prefers the cost stored with the last full evaluation and falls back to the
        /// cost-only cache.
        fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
            match self.current_theta.as_ref() {
                Some(cached) if theta_values_match(cached, theta) => {
                    match self.last_eval.as_ref() {
                        Some(eval) => Some(eval.0),
                        None => self.last_cost,
                    }
                }
                _ => None,
            }
        }
        /// Cached full evaluation for `theta`, if `theta` matches the currently applied theta.
        fn memoized_eval(
            &self,
            theta: &Array1<f64>,
        ) -> Option<(
            f64,
            Array1<f64>,
            crate::solver::rho_optimizer::HessianResult,
        )> {
            match self.current_theta.as_ref() {
                Some(cached) if theta_values_match(cached, theta) => self.last_eval.clone(),
                _ => None,
            }
        }
        /// Records `eval` as the latest evaluation, updating the cost-only cache as well.
        fn store_eval(
            &mut self,
            eval: (
                f64,
                Array1<f64>,
                crate::solver::rho_optimizer::HessianResult,
            ),
        ) {
            let cost = eval.0;
            self.last_eval = Some(eval);
            self.last_cost = Some(cost);
        }
    };
}
/// Per-fit cache tying an incrementally re-realized term-collection design to the outer `theta`
/// it was realized at, together with the most recent objective evaluation and hyper directions.
struct SingleBlockExactJointDesignCache<'d> {
/// Incremental realizer that holds the spec and design and applies new log-kappa values.
realizer: FrozenTermCollectionIncrementalRealizer<'d>,
/// Theta most recently applied to the realizer by `ensure_theta`.
current_theta: Option<Array1<f64>>,
/// Theta at which `last_eval` / `last_cost` were stored; cleared whenever a new theta is applied.
last_eval_theta: Option<Array1<f64>>,
/// Cost of the most recent stored evaluation.
last_cost: Option<f64>,
/// Most recent stored evaluation tuple (presumably cost, gradient, Hessian result — confirm
/// against the evaluator).
last_eval: Option<(
f64,
Array1<f64>,
crate::solver::rho_optimizer::HessianResult,
)>,
/// Hyper directions keyed by the realizer's design revision they were built for.
cached_hyper_dirs: Option<(u64, Vec<DirectionalHyperParam>)>,
/// Indices of the spatial smooth terms whose log-kappa coordinates are driven by theta.
spatial_terms: Vec<usize>,
/// Offset in theta at which the log-kappa / psi tail begins.
rho_dim: usize,
/// Per-term coordinate counts used to split the theta tail.
dims_per_term: Vec<usize>,
}
impl<'d> SingleBlockExactJointDesignCache<'d> {
fn new(
data: ArrayView2<'d, f64>,
spec: TermCollectionSpec,
design: TermCollectionDesign,
spatial_terms: Vec<usize>,
rho_dim: usize,
dims_per_term: Vec<usize>,
) -> Result<Self, String> {
Ok(Self {
realizer: FrozenTermCollectionIncrementalRealizer::new(data, spec, design)?,
current_theta: None,
last_eval_theta: None,
last_cost: None,
last_eval: None,
cached_hyper_dirs: None,
spatial_terms,
rho_dim,
dims_per_term,
})
}
fn design_revision(&self) -> u64 {
self.realizer.design_revision()
}
fn hyper_dirs_for_current_design(
&mut self,
data: ArrayView2<'_, f64>,
kind: SpatialHyperKind,
) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
let revision = self.realizer.design_revision();
if let Some((cached_rev, dirs)) = self.cached_hyper_dirs.as_ref()
&& *cached_rev == revision
{
return Ok(dirs.clone());
}
let dirs = try_build_spatial_log_kappa_hyper_dirs(
data,
self.realizer.spec(),
self.realizer.design(),
&self.spatial_terms,
)?
.ok_or_else(|| {
EstimationError::InvalidInput(format!(
"failed to build {} hyper_dirs at current {}",
kind.adjective(),
kind.coord_name(),
))
})?;
self.cached_hyper_dirs = Some((revision, dirs.clone()));
Ok(dirs)
}
/// Builds a single directional hyper-parameter that carries only penalty derivatives
/// (its design derivative is zero) evaluated at the ψ tail of `theta`.
///
/// The ψ tail is `theta[rho_dim..]`. The realizer supplies the canonical penalty
/// derivatives for that ψ; each one is embedded into the full coefficient space and
/// tagged with its penalty index.
///
/// # Errors
/// Returns `InvalidInput` when `theta` is not contiguous, when it is shorter than
/// `rho_dim`, or when the realizer rejects the request; propagates the error from
/// `DirectionalHyperParam::new_compact`.
fn nfree_tensor_gradient_hyper_dirs(
    &mut self,
    theta: &Array1<f64>,
) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
    let rho_dim = self.rho_dim;
    let flat = theta.as_slice().ok_or_else(|| {
        EstimationError::InvalidInput(
            "nfree_tensor_gradient_hyper_dirs: theta is not contiguous".to_string(),
        )
    })?;
    // A θ shorter than the ρ prefix used to panic on the range index; surface it as an
    // error instead, since this function already reports failures through `Result`.
    let psi = flat.get(rho_dim..).ok_or_else(|| {
        EstimationError::InvalidInput(format!(
            "nfree_tensor_gradient_hyper_dirs: theta length {} is shorter than rho_dim {}",
            flat.len(),
            rho_dim
        ))
    })?;
    let (global_range, p_total, s_psi_components) = self
        .realizer
        .canonical_penalty_derivatives_at_psi(&self.spatial_terms, psi)
        .map_err(EstimationError::InvalidInput)?;
    // The design does not move along this direction, so its derivative is all zeros.
    let zero_x = crate::estimate::reml::HyperDesignDerivative::zero(
        self.realizer.design().design.nrows(),
        p_total,
    );
    let components = s_psi_components
        .into_iter()
        .enumerate()
        .map(|(penalty_index, local)| {
            (
                penalty_index,
                crate::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    p_total,
                ),
            )
        })
        .collect::<Vec<_>>();
    let dir = DirectionalHyperParam::new_compact(zero_x, components, None, None)?
        .not_penalty_like();
    Ok(vec![dir])
}
/// Makes the realizer reflect `theta`, doing nothing when `theta` is already applied.
///
/// On a change, the ψ tail of `theta` is decoded into log-κ coordinates and applied to
/// the realizer; the applied θ is then recorded and every memoized evaluation is dropped.
///
/// # Errors
/// Propagates the error string from `apply_log_kappa`.
fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
    if let Some(applied) = &self.current_theta {
        if theta_values_match(applied, theta) {
            return Ok(());
        }
    }
    let started = std::time::Instant::now();
    let dims = self.dims_per_term.clone();
    let coords = SpatialLogKappaCoords::from_theta_tail_with_dims(theta, self.rho_dim, dims);
    self.realizer.apply_log_kappa(&coords, &self.spatial_terms)?;
    log::info!(
        "[STAGE] ensure_theta (apply_log_kappa, {} terms): {:.3}s",
        self.spatial_terms.len(),
        started.elapsed().as_secs_f64(),
    );
    // The design moved, so nothing memoized for the previous θ is valid any more.
    self.last_eval = None;
    self.last_cost = None;
    self.last_eval_theta = None;
    self.current_theta = Some(theta.clone());
    Ok(())
}
/// Returns the cost memoized for `theta`, if the last stored evaluation was taken at it.
///
/// Prefers the cost inside the stored full evaluation and falls back to the cost-only slot.
fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
    let stored_at = self.last_eval_theta.as_ref()?;
    if !theta_values_match(stored_at, theta) {
        return None;
    }
    match self.last_eval.as_ref() {
        Some(eval) => Some(eval.0),
        None => self.last_cost,
    }
}
/// Returns a clone of the full evaluation memoized for `theta`, or `None` when the last
/// stored evaluation was taken at a different θ (or was cost-only).
fn memoized_eval(
    &self,
    theta: &Array1<f64>,
) -> Option<(
    f64,
    Array1<f64>,
    crate::solver::rho_optimizer::HessianResult,
)> {
    let stored_at = self.last_eval_theta.as_ref()?;
    if !theta_values_match(stored_at, theta) {
        return None;
    }
    self.last_eval.clone()
}
/// Memoizes a full evaluation taken at `theta`; the cost slot mirrors `eval.0`.
fn store_eval_at(
    &mut self,
    theta: &Array1<f64>,
    eval: (
        f64,
        Array1<f64>,
        crate::solver::rho_optimizer::HessianResult,
    ),
) {
    let cost = eval.0;
    self.last_eval = Some(eval);
    self.last_cost = Some(cost);
    self.last_eval_theta = Some(theta.clone());
}
/// Memoizes a cost-only evaluation taken at `theta`.
///
/// Any previously stored full evaluation is cleared, so `memoized_eval` returns `None`
/// until `store_eval_at` is called again.
fn store_cost_at(&mut self, theta: &Array1<f64>, cost: f64) {
self.last_eval_theta = Some(theta.clone());
self.last_cost = Some(cost);
self.last_eval = None;
}
/// Borrows the spec currently held by the realizer.
fn spec(&self) -> &TermCollectionSpec {
self.realizer.spec()
}
/// Borrows the design currently held by the realizer.
fn design(&self) -> &TermCollectionDesign {
self.realizer.design()
}
/// Forwards to the realizer's `supports_nfree_penalty_rekey` for this cache's spatial terms.
fn supports_nfree_penalty_rekey(&self) -> bool {
self.realizer
.supports_nfree_penalty_rekey(&self.spatial_terms)
}
/// Forwards to the realizer's `supports_nfree_gradient_only_routing` for this cache's
/// spatial terms.
fn supports_nfree_gradient_only_routing(&self) -> bool {
self.realizer
.supports_nfree_gradient_only_routing(&self.spatial_terms)
}
/// Asks the realizer for the canonical penalties evaluated at the ψ tail of `theta`
/// (`theta[rho_dim..]`).
///
/// # Errors
/// Returns an error string when `theta` is not contiguous, when it is shorter than
/// `rho_dim`, or when the realizer fails.
fn canonical_penalties_at(
    &mut self,
    theta: &Array1<f64>,
) -> Result<(Vec<crate::construction::CanonicalPenalty>, Vec<usize>), String> {
    let rho_dim = self.rho_dim;
    let flat = theta
        .as_slice()
        .ok_or_else(|| "canonical_penalties_at: theta is not contiguous".to_string())?;
    // Checked range lookup: a too-short θ is reported instead of panicking on the index.
    let psi = flat.get(rho_dim..).ok_or_else(|| {
        format!(
            "canonical_penalties_at: theta length {} is shorter than rho_dim {}",
            flat.len(),
            rho_dim
        )
    })?;
    self.realizer
        .canonical_penalties_at_psi(&self.spatial_terms, psi)
}
}
/// Cache used while optimizing latent coordinates for one smooth term.
///
/// It owns a mutable copy of the data matrix (latent coordinates are written into it),
/// the spec and current design, a keyed design cache, and memo slots for the latest
/// objective evaluation.
struct SingleBlockLatentCoordDesignCache {
/// Owned data matrix; `ensure_theta` overwrites the `feature_cols` columns with the
/// latent coordinates decoded from θ.
data: Array2<f64>,
/// Term-collection spec used whenever the design is rebuilt.
spec: TermCollectionSpec,
/// Design for the currently realized θ; replaced from the design-cache lookup.
design: TermCollectionDesign,
/// θ most recently realized by `ensure_theta`.
current_theta: Option<Array1<f64>>,
/// Latent values rebuilt from the most recently realized θ.
current_latent: Option<std::sync::Arc<crate::terms::latent::LatentCoordValues>>,
/// Directional hyper-parameters belonging to the currently realized design.
current_hyper_dirs: Option<Vec<crate::estimate::reml::DirectionalHyperParam>>,
/// Entry id of the design-cache entry that backs `design`.
current_design_cache_id: Option<u64>,
/// Keyed cache of rebuilt designs and their hyper directions.
latent_design_cache: crate::solver::latent_cache::LatentDesignCache,
/// Cost from the most recent `store_eval` / `store_cost`.
last_cost: Option<f64>,
/// Full evaluation tuple from the most recent `store_eval`; element 0 is the cost.
last_eval: Option<(
f64,
Array1<f64>,
crate::solver::rho_optimizer::HessianResult,
)>,
/// Index of the latent-coordinate smooth term within the spec/design.
term_index: crate::types::SmoothTermIdx,
/// Columns of `data` that receive the latent coordinates, one per latent axis.
feature_cols: Vec<usize>,
/// Offset of the latent block inside θ (number of leading entries before it).
rho_dim: usize,
/// Number of observations (rows of latent values).
n_obs: usize,
/// Number of latent axes per observation.
latent_dim: usize,
/// Identification mode copied from the initial latent values and reused on every rebuild.
id_mode: crate::terms::latent::LatentIdMode,
/// Manifold copied from the initial latent values and reused on every rebuild.
manifold: crate::terms::latent::LatentManifold,
/// Retraction registry copied from the initial latent values and reused on every rebuild.
retraction_registry: crate::solver::latent_cache::LatentRetractionRegistry,
/// Latent id copied from the initial latent values and reused on every rebuild.
latent_id: u64,
/// Optional analytic-penalty registry taken from the latent configuration.
analytic_penalties: Option<std::sync::Arc<crate::terms::AnalyticPenaltyRegistry>>,
/// `total_rho_count()` of the registry, or 0 when there is none.
analytic_rho_count: usize,
/// Wrapping counter bumped by `ensure_theta` whenever the realized design may have changed.
design_revision: u64,
/// Outer-iteration stamp recorded when `last_cost` / `last_eval` were stored.
last_outer_iter: Option<u64>,
}
impl SingleBlockLatentCoordDesignCache {
/// Validates the latent configuration against `spec` and `data`, then builds an
/// unrealized cache (no θ applied yet, revision 0).
///
/// # Errors
/// Returns a dimension-mismatch error string when the term index is out of range for
/// the spec's smooth terms, when the number of feature columns differs from the latent
/// dimension, or when the latent row count differs from the number of data rows.
fn new(
    data: Array2<f64>,
    spec: TermCollectionSpec,
    design: TermCollectionDesign,
    latent: &StandardLatentCoordConfig,
    rho_dim: usize,
) -> Result<Self, String> {
    let values = &latent.values;

    let smooth_count = spec.smooth_terms.len();
    if latent.term_index.get() >= smooth_count {
        let msg = format!(
            "latent-coordinate term index {} out of bounds for {} smooth terms",
            latent.term_index, smooth_count
        );
        return Err(SmoothError::dimension_mismatch(msg).into());
    }

    let feature_width = latent.feature_cols.len();
    if feature_width != values.latent_dim() {
        let msg = format!(
            "latent-coordinate feature width mismatch: feature_cols={}, latent_dim={}",
            feature_width,
            values.latent_dim()
        );
        return Err(SmoothError::dimension_mismatch(msg).into());
    }

    if values.n_obs() != data.nrows() {
        let msg = format!(
            "latent-coordinate row mismatch: latent n={}, data n={}",
            values.n_obs(),
            data.nrows()
        );
        return Err(SmoothError::dimension_mismatch(msg).into());
    }

    let analytic_rho_count = match latent.analytic_penalties.as_ref() {
        Some(registry) => registry.total_rho_count(),
        None => 0,
    };

    Ok(Self {
        // Inputs owned by the cache.
        data,
        spec,
        design,
        // Nothing has been realized or evaluated yet.
        current_theta: None,
        current_latent: None,
        current_hyper_dirs: None,
        current_design_cache_id: None,
        latent_design_cache: crate::solver::latent_cache::LatentDesignCache::default(),
        last_cost: None,
        last_eval: None,
        // Layout and identity copied from the latent configuration.
        term_index: latent.term_index,
        feature_cols: latent.feature_cols.clone(),
        rho_dim,
        n_obs: values.n_obs(),
        latent_dim: values.latent_dim(),
        id_mode: values.id_mode().clone(),
        manifold: values.manifold().clone(),
        retraction_registry: values.retraction_registry().clone(),
        latent_id: values.latent_id(),
        analytic_penalties: latent.analytic_penalties.clone(),
        analytic_rho_count,
        design_revision: 0,
        last_outer_iter: None,
    })
}
/// Returns the cache's own design revision counter (bumped inside `ensure_theta`).
fn design_revision(&self) -> u64 {
self.design_revision
}
/// Borrows the design for the currently realized θ.
fn design(&self) -> &TermCollectionDesign {
&self.design
}
/// Returns a shared handle to the latent values of the realized θ.
///
/// # Errors
/// Fails when no θ has been realized yet.
fn latent(&self) -> Result<std::sync::Arc<crate::terms::latent::LatentCoordValues>, String> {
    match &self.current_latent {
        Some(values) => Ok(std::sync::Arc::clone(values)),
        None => Err("latent-coordinate cache has not been realized".to_string()),
    }
}
/// Returns a clone of the optional analytic-penalty registry handle.
fn analytic_penalties(&self) -> Option<std::sync::Arc<crate::terms::AnalyticPenaltyRegistry>> {
self.analytic_penalties.clone()
}
/// Returns the analytic-penalty ρ count captured at construction (0 without a registry).
fn analytic_penalty_rho_count(&self) -> usize {
self.analytic_rho_count
}
/// Returns a clone of the directional hyper-parameters for the realized θ.
///
/// # Errors
/// Fails when no θ has been realized yet.
fn hyper_dirs(&self) -> Result<Vec<crate::estimate::reml::DirectionalHyperParam>, String> {
    match &self.current_hyper_dirs {
        Some(dirs) => Ok(dirs.clone()),
        None => Err("latent-coordinate hyper_dirs cache has not been realized".to_string()),
    }
}
/// Derives the design-cache key component describing the latent term's realized basis.
///
/// The spec's basis variant and the realized design's metadata variant for
/// `term_index` must agree; the matching pair is translated into a
/// `LatentBasisKind` built from the metadata.
///
/// # Errors
/// Fails when `term_index` is out of range for the design or the spec, when a centered
/// in-memory PCA basis has no `center_mean`, or when the (spec, metadata) pair is not
/// one of the combinations handled below.
fn latent_basis_kind(&self) -> Result<crate::solver::latent_cache::LatentBasisKind, String> {
let smooth_term = self
.design
.smooth
.terms
.get(self.term_index.get())
.ok_or_else(|| {
SmoothError::dimension_mismatch(format!(
"LatentCoord term index {} out of bounds for realized smooth design",
self.term_index
))
})?;
let termspec = self
.spec
.smooth_terms
.get(self.term_index.get())
.ok_or_else(|| {
SmoothError::dimension_mismatch(format!(
"LatentCoord term index {} out of bounds for resolved smooth spec",
self.term_index
))
})?;
match (&termspec.basis, &smooth_term.metadata) {
// Matern: missing anisotropy scales default to zeros, one per center column.
(
SmoothBasisSpec::Matern { .. },
BasisMetadata::Matern {
centers,
length_scale,
nu,
aniso_log_scales,
..
},
) => Ok(crate::solver::latent_cache::LatentBasisKind::Matern {
centers: centers.clone(),
length_scale: *length_scale,
nu: *nu,
aniso_log_scales: aniso_log_scales
.clone()
.unwrap_or_else(|| vec![0.0; centers.ncols()]),
chunk_size: crate::basis::auto_streaming_chunk_size_for_dense(
self.n_obs,
centers.nrows(),
),
}),
// Duchon: same zero default for missing anisotropy scales; no chunk size is set here.
(
SmoothBasisSpec::Duchon { .. },
BasisMetadata::Duchon {
centers,
length_scale,
power,
nullspace_order,
aniso_log_scales,
..
},
) => Ok(crate::solver::latent_cache::LatentBasisKind::Duchon {
centers: centers.clone(),
length_scale: *length_scale,
power: *power,
nullspace_order: *nullspace_order,
aniso_log_scales: aniso_log_scales
.clone()
.unwrap_or_else(|| vec![0.0; centers.ncols()]),
}),
// Sphere: only the Wahba method is keyed; other methods fall to the error arm.
(
SmoothBasisSpec::Sphere { .. },
BasisMetadata::Sphere {
centers,
penalty_order,
method,
..
},
) if matches!(*method, crate::basis::SphereMethod::Wahba) => {
Ok(crate::solver::latent_cache::LatentBasisKind::Sphere {
centers: centers.clone(),
penalty_order: *penalty_order,
chunk_size: crate::basis::auto_streaming_chunk_size_for_dense(
self.n_obs,
centers.nrows(),
),
})
}
(
SmoothBasisSpec::BSpline1D { spec, .. },
BasisMetadata::BSpline1D {
knots,
periodic,
degree: meta_degree,
..
},
) => {
// The degree recorded in the metadata wins; the spec's degree is the fallback.
let effective_degree = meta_degree.unwrap_or(spec.degree);
if let Some((domain_start, period, num_basis)) = periodic {
Ok(
crate::solver::latent_cache::LatentBasisKind::PeriodicBspline {
domain_start: *domain_start,
period: *period,
degree: effective_degree,
num_basis: *num_basis,
chunk_size: crate::basis::auto_streaming_chunk_size_for_dense(
self.n_obs, *num_basis,
),
},
)
} else {
// A non-periodic 1-D spline is keyed as a one-axis tensor B-spline; the basis
// count used for chunk sizing is estimated as knots - (degree + 1), floored at 0.
let num_basis_est = knots.len().saturating_sub(effective_degree + 1);
Ok(
crate::solver::latent_cache::LatentBasisKind::TensorBspline {
knots: vec![knots.clone()],
degrees: vec![effective_degree],
chunk_size: crate::basis::auto_streaming_chunk_size_for_dense(
self.n_obs,
num_basis_est,
),
},
)
}
}
// NOTE(review): unlike the 1-D case above, the tensor case passes `chunk_size: None`
// — confirm this asymmetry is intended.
(
SmoothBasisSpec::TensorBSpline { .. },
BasisMetadata::TensorBSpline { knots, degrees, .. },
) => Ok(
crate::solver::latent_cache::LatentBasisKind::TensorBspline {
knots: knots.clone(),
degrees: degrees.clone(),
chunk_size: None,
},
),
(
SmoothBasisSpec::Pca { .. },
BasisMetadata::Pca {
basis_matrix,
centered,
smooth_penalty,
center_mean,
pca_basis_path,
chunk_size,
..
},
) => {
// The mean fingerprint is only part of the key for a centered basis that is not
// backed by a basis path; in that case the mean itself must be present.
let center_mean_fingerprint = if *centered && pca_basis_path.is_none() {
let mean = center_mean.as_ref().ok_or_else(|| {
SmoothError::invalid_config(
"latent-coordinate Pca cache key requires center_mean when centered",
)
})?;
Some(crate::solver::latent_cache::pca_center_mean_fingerprint(
mean,
))
} else {
None
};
Ok(crate::solver::latent_cache::LatentBasisKind::Pca {
basis_matrix: basis_matrix.clone(),
centered: *centered,
center_mean_fingerprint,
smooth_penalty: *smooth_penalty,
pca_basis_path: pca_basis_path.clone(),
chunk_size: *chunk_size,
})
}
// Any other (spec, metadata) pairing cannot be keyed.
_ => Err(SmoothError::invalid_config(
"latent-coordinate design cache could not key the realized latent smooth basis"
.to_string(),
)
.into()),
}
}
/// Realizes `theta`: decodes its latent block, writes the coordinates into `data`, and
/// obtains the matching design and hyper directions from the design cache (rebuilding
/// them on a miss). Does nothing when `theta` is already realized.
///
/// On success the memoized cost/evaluation is cleared and `design_revision` has been
/// bumped if the latent values changed or a different cache entry now backs the design.
///
/// # Errors
/// Fails when `theta` has the wrong length, when the basis cannot be keyed, when the
/// context digest or the rebuild fails, or when the resulting design has a different
/// column count than the current one.
fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
if self
.current_theta
.as_ref()
.is_some_and(|cached| theta_values_match(cached, theta))
{
return Ok(());
}
// θ layout checked here: rho_dim entries, then n_obs * latent_dim latent coordinates,
// then the analytic-penalty ρ's and the direct hyper-parameters.
let latent_flat_len = self.n_obs * self.latent_dim;
let direct_hyper_count = latent_coord_direct_hyper_count(&self.id_mode, self.latent_dim);
let expected =
self.rho_dim + latent_flat_len + self.analytic_rho_count + direct_hyper_count;
if theta.len() != expected {
return Err(SmoothError::dimension_mismatch(format!(
"latent-coordinate theta length mismatch: got {}, expected {} (rho_dim={}, n={}, d={}, analytic_rhos={}, direct_hypers={})",
theta.len(),
expected,
self.rho_dim,
self.n_obs,
self.latent_dim,
self.analytic_rho_count,
direct_hyper_count
))
.into());
}
// Copy out the latent block and rebuild the latent values with the identity settings
// captured at construction.
let flat = theta
.slice(s![self.rho_dim..self.rho_dim + latent_flat_len])
.to_owned();
let latent = std::sync::Arc::new(
crate::terms::latent::LatentCoordValues::from_flat_with_manifold_and_retraction_and_id(
flat,
self.n_obs,
self.latent_dim,
self.id_mode.clone(),
self.manifold.clone(),
self.retraction_registry.clone(),
self.latent_id,
),
);
// With nothing realized yet, the values count as changed.
let latent_values_changed = self
.current_latent
.as_ref()
.map(|cached| !latent_values_match(cached.as_flat(), latent.as_flat()))
.unwrap_or(true);
if latent_values_changed {
self.latent_design_cache.invalidate_all();
self.current_design_cache_id = None;
self.design_revision = self.design_revision.wrapping_add(1);
}
// Write the coordinates into the data matrix; the flat layout is row-major
// (observation-major), one feature column per latent axis.
for n in 0..self.n_obs {
for axis in 0..self.latent_dim {
let col = self.feature_cols[axis];
self.data[[n, col]] = latent.as_flat()[n * self.latent_dim + axis];
}
}
let basis_kind = self.latent_basis_kind()?;
// Column count of the design held before the lookup; a rebuilt design must match it.
let rebuilt_width = self.design.design.ncols();
let spec = self.spec.clone();
let term_index = self.term_index;
let analytic_rho_count = self.analytic_rho_count;
let data = self.data.view();
let design_context_digest =
crate::solver::latent_cache::latent_design_context_cache_digest(
data,
&spec,
term_index,
analytic_rho_count,
&self.feature_cols,
)
.map_err(|e| e.to_string())?;
// The closure is the cache's compute path: rebuild the design from the updated data,
// check its width, and build the latent-coordinate hyper directions for it.
let lookup = self
.latent_design_cache
.lookup_or_compute(latent.clone(), basis_kind, design_context_digest, || {
let rebuilt = build_term_collection_design(data, &spec).map_err(|e| {
EstimationError::InvalidInput(format!(
"failed to rebuild latent-coordinate design: {e}"
))
})?;
if rebuilt.design.ncols() != rebuilt_width {
crate::bail_invalid_estim!(
"latent-coordinate design topology changed: rebuilt p={}, cached p={}",
rebuilt.design.ncols(),
rebuilt_width
);
}
let hyper_dirs = try_build_latent_coord_hyper_dirs(
latent.clone(),
&spec,
&rebuilt,
&[term_index],
analytic_rho_count,
)?
.ok_or_else(|| {
EstimationError::InvalidInput(
"failed to build latent-coordinate hyper_dirs".to_string(),
)
})?;
Ok(crate::solver::latent_cache::ComputedLatentDesign {
design: rebuilt,
hyper_dirs,
})
})
.map_err(|e| e.to_string())?;
// The same width check is repeated on whatever the lookup returned.
if lookup.cached.design.design.ncols() != self.design.design.ncols() {
return Err(SmoothError::dimension_mismatch(format!(
"latent-coordinate design topology changed: rebuilt p={}, cached p={}",
lookup.cached.design.design.ncols(),
self.design.design.ncols()
))
.into());
}
self.design = lookup.cached.design.clone();
self.current_hyper_dirs = Some(lookup.cached.hyper_dirs.clone());
self.current_latent = Some(latent);
self.current_theta = Some(theta.clone());
// A new θ invalidates everything memoized for the previous one.
self.last_cost = None;
self.last_eval = None;
self.last_outer_iter = None;
// The revision was already bumped above when the latent values changed; bump it here
// only when the values are unchanged but a different cache entry now backs the design.
if !latent_values_changed && self.current_design_cache_id != Some(lookup.entry_id) {
self.design_revision = self.design_revision.wrapping_add(1);
}
self.current_design_cache_id = Some(lookup.entry_id);
Ok(())
}
/// Returns the memoized cost for `theta`, provided `theta` is the realized θ and the
/// memo was stored during the current outer iteration.
///
/// Prefers the cost inside the stored full evaluation and falls back to the cost-only slot.
fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
    let realized = self.current_theta.as_ref()?;
    if !theta_values_match(realized, theta) {
        return None;
    }
    let iter_now = crate::solver::estimate::reml::outer_eval::current_outer_iter();
    if self.last_outer_iter != Some(iter_now) {
        return None;
    }
    match self.last_eval.as_ref() {
        Some(eval) => Some(eval.0),
        None => self.last_cost,
    }
}
/// Returns a clone of the memoized full evaluation for `theta`, provided `theta` is the
/// realized θ and the memo was stored during the current outer iteration.
fn memoized_eval(
    &self,
    theta: &Array1<f64>,
) -> Option<(
    f64,
    Array1<f64>,
    crate::solver::rho_optimizer::HessianResult,
)> {
    let realized = self.current_theta.as_ref()?;
    if !theta_values_match(realized, theta) {
        return None;
    }
    let iter_now = crate::solver::estimate::reml::outer_eval::current_outer_iter();
    if self.last_outer_iter != Some(iter_now) {
        return None;
    }
    self.last_eval.clone()
}
/// Memoizes a full evaluation for the realized θ and stamps it with the current outer
/// iteration; the cost slot mirrors `eval.0`.
fn store_eval(
    &mut self,
    eval: (
        f64,
        Array1<f64>,
        crate::solver::rho_optimizer::HessianResult,
    ),
) {
    let cost = eval.0;
    self.last_cost = Some(cost);
    self.last_eval = Some(eval);
    let stamp = crate::solver::estimate::reml::outer_eval::current_outer_iter();
    self.last_outer_iter = Some(stamp);
}
/// Memoizes a cost-only evaluation for the realized θ and stamps it with the current
/// outer iteration.
///
/// Any previously stored full evaluation is dropped, mirroring `store_cost_at` on
/// `SingleBlockExactJointDesignCache`.
fn store_cost(&mut self, cost: f64) {
    // Re-stamping `last_outer_iter` while keeping an older `last_eval` would let an
    // evaluation from a previous outer iteration pass the freshness check in
    // `memoized_eval`, and `memoized_cost` would return its cost instead of `cost`.
    self.last_eval = None;
    self.last_cost = Some(cost);
    self.last_outer_iter =
        Some(crate::solver::estimate::reml::outer_eval::current_outer_iter());
}
/// Returns the cache to its unrealized state: forgets the realized θ, latent values,
/// hyper directions, cache entry id, and every memoized evaluation.
///
/// `design`, `data` and `design_revision` are left untouched here.
fn reset(&mut self) {
self.current_theta = None;
self.current_latent = None;
self.current_hyper_dirs = None;
self.current_design_cache_id = None;
// NOTE(review): `ensure_theta` calls `invalidate_all()` on this cache, whereas this
// calls `invalidate()` — confirm the two differ intentionally.
self.latent_design_cache.invalidate();
self.last_cost = None;
self.last_eval = None;
self.last_outer_iter = None;
}
}
/// Scores a constant-curvature smooth term at a fixed curvature `kappa`.
///
/// Two routes are used:
/// * Gaussian-identity family with (numerically) unit weights and zero offset: the
///   score comes directly from `constant_curvature_honest_profiled_reml_score` on the
///   term's feature columns.
/// * Otherwise: the whole term collection is refitted with `kappa` written into the
///   term's spec and spatial length-scale optimization disabled, and the fit's score is
///   returned.
///
/// # Errors
/// Fails when `kappa` is not finite, when `term_idx` does not name a constant-curvature
/// smooth, when scoring or fitting fails, or when the resulting score is not finite.
pub fn fixed_kappa_profiled_reml_score(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    term_idx: usize,
    kappa: f64,
    family: LikelihoodSpec,
    options: &FitOptions,
) -> Result<f64, EstimationError> {
    if !kappa.is_finite() {
        crate::bail_invalid_estim!("fixed-κ profiled score probed a non-finite κ = {kappa}");
    }

    let term_basis = resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis);
    let Some(SmoothBasisSpec::ConstantCurvature {
        feature_cols,
        spec: cc_spec,
        ..
    }) = term_basis
    else {
        crate::bail_invalid_estim!(
            "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
        )
    };
    let mut probe_basis = cc_spec.clone();
    probe_basis.kappa = kappa;

    let unit_weights = weights.iter().all(|&w| (w - 1.0).abs() <= 1e-12);
    let no_offset = offset.iter().all(|&o| o.abs() <= 1e-12);
    let direct_route = family == LikelihoodSpec::gaussian_identity() && unit_weights && no_offset;

    if direct_route {
        let x_term = select_columns(data, feature_cols).map_err(EstimationError::from)?;
        let honest = crate::basis::constant_curvature_honest_profiled_reml_score(
            x_term.view(),
            y,
            &probe_basis,
        );
        let score = match honest {
            Ok(value) => value,
            Err(e) => {
                return Err(EstimationError::InvalidInput(format!(
                    "fixed-κ honest profiled-REML score at κ={kappa} failed: {e}"
                )));
            }
        };
        if !score.is_finite() {
            crate::bail_invalid_estim!(
                "fixed-κ honest profiled-REML score at κ={kappa} is non-finite"
            );
        }
        return Ok(score);
    }

    // General route: refit with κ pinned in a private copy of the spec.
    let mut probe_spec = resolvedspec.clone();
    if let Some(SmoothBasisSpec::ConstantCurvature { spec, .. }) = probe_spec
        .smooth_terms
        .get_mut(term_idx)
        .map(|t| &mut t.basis)
    {
        spec.kappa = kappa;
    } else {
        crate::bail_invalid_estim!(
            "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
        );
    }
    let fixed_kappa_options = SpatialLengthScaleOptimizationOptions {
        enabled: false,
        ..SpatialLengthScaleOptimizationOptions::default()
    };
    let fit = fit_term_collectionwith_spatial_length_scale_optimization(
        data,
        y.to_owned(),
        weights.to_owned(),
        offset.to_owned(),
        &probe_spec,
        family,
        options,
        &fixed_kappa_options,
    )?;
    let score = fit_score(&fit.fit);
    if score.is_finite() {
        Ok(score)
    } else {
        crate::bail_invalid_estim!("fixed-κ profiled fit at κ={kappa} returned a non-finite score");
    }
}
/// Grid-searches κ for a constant-curvature term and returns the κ with the lowest
/// κ-fair sign score.
///
/// The grid has `GRID_STEPS + 1` equally spaced points covering
/// `[kappa_min, kappa_max]` as reported by `constant_curvature_kappa_bounds`. Probes
/// that fail, or that return a non-finite score, are logged and skipped. Ties keep the
/// first (smallest-κ) probe.
///
/// Returns `None` when the bounds are not a finite, non-empty interval, when
/// `term_idx` is not a constant-curvature smooth, when the feature columns cannot be
/// selected, or when no probe produced a finite score.
fn constant_curvature_kappa_fair_argmin(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    term_idx: usize,
) -> Option<f64> {
    let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
    if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
        return None;
    }
    let (feature_cols, base_spec) =
        match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
            Some(SmoothBasisSpec::ConstantCurvature {
                feature_cols, spec, ..
            }) => (feature_cols, spec.clone()),
            _ => return None,
        };
    let x_term = match select_columns(data, feature_cols) {
        Ok(x) => x,
        Err(e) => {
            log::info!("[spatial-kappa] #1464 κ-fair argmin column select failed ({e}); skipping");
            return None;
        }
    };
    const GRID_STEPS: usize = 24;
    // (best score so far, κ that produced it)
    let mut best: Option<(f64, f64)> = None;
    for i in 0..=GRID_STEPS {
        let t = i as f64 / GRID_STEPS as f64;
        let kappa = kappa_min + (kappa_max - kappa_min) * t;
        let mut probe_spec = base_spec.clone();
        probe_spec.kappa = kappa;
        match crate::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec)
        {
            // A NaN accepted as the incumbent can never be beaten (`score < NaN` is
            // always false), so non-finite scores must not enter the comparison.
            Ok(score) if !score.is_finite() => {
                log::info!(
                    "[spatial-kappa] #1464 κ-fair argmin probe at κ={kappa:.4} returned a non-finite score ({score}); skipping"
                );
            }
            Ok(score) => {
                if best.as_ref().is_none_or(|(b, _)| score < *b) {
                    best = Some((score, kappa));
                }
            }
            Err(e) => {
                log::info!(
                    "[spatial-kappa] #1464 κ-fair argmin probe at κ={kappa:.4} failed ({e}); skipping"
                );
            }
        }
    }
    best.map(|(score, kappa)| {
        log::info!(
            "[spatial-kappa] #1464 κ-fair argmin κ̂={kappa:.4} (κ-fair score={score:.6e}) for term {term_idx}"
        );
        kappa
    })
}
/// Picks a κ seed for a constant-curvature term by scoring five probes with the κ-fair
/// sign score and returning the probe with the lowest score.
///
/// The probes are `kappa_min`, `0.5 * kappa_min`, `0.0`, `0.5 * kappa_max` and
/// `kappa_max`, with the bounds taken from `constant_curvature_kappa_bounds`. Probes
/// that fail, or that return a non-finite score, are logged and skipped. Ties keep the
/// earlier probe in that order.
///
/// Returns `None` when the bounds are not a finite, non-empty interval, when
/// `term_idx` is not a constant-curvature smooth, when the feature columns cannot be
/// selected, or when no probe produced a finite score.
fn select_constant_curvature_kappa_sign_seed(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    term_idx: usize,
) -> Option<f64> {
    let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
    if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
        return None;
    }
    let (feature_cols, base_spec) =
        match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
            Some(SmoothBasisSpec::ConstantCurvature {
                feature_cols, spec, ..
            }) => (feature_cols, spec.clone()),
            _ => return None,
        };
    let x_term = match select_columns(data, feature_cols) {
        Ok(x) => x,
        Err(e) => {
            log::info!("[spatial-kappa] #1464 sign-basin scan column select failed ({e}); skipping");
            return None;
        }
    };
    // NOTE(review): the 0.0 probe is used regardless of whether it lies inside
    // [kappa_min, kappa_max] — presumably the bounds straddle zero; confirm.
    let probes = [
        kappa_min,
        0.5 * kappa_min,
        0.0,
        0.5 * kappa_max,
        kappa_max,
    ];
    // (best score so far, κ that produced it)
    let mut best: Option<(f64, f64)> = None;
    for &kappa in &probes {
        let mut probe_spec = base_spec.clone();
        probe_spec.kappa = kappa;
        match crate::basis::constant_curvature_kappa_fair_sign_score(
            x_term.view(),
            y,
            &probe_spec,
        ) {
            // A NaN accepted as the incumbent can never be beaten (`score < NaN` is
            // always false), so non-finite scores must not enter the comparison.
            Ok(score) if !score.is_finite() => {
                log::info!(
                    "[spatial-kappa] #1464 sign-basin probe at κ={kappa:.4} returned a non-finite score ({score}); skipping"
                );
            }
            Ok(score) => {
                if best.as_ref().is_none_or(|(b, _)| score < *b) {
                    best = Some((score, kappa));
                }
            }
            Err(e) => {
                log::info!(
                    "[spatial-kappa] #1464 sign-basin probe at κ={kappa:.4} failed ({e}); skipping"
                );
            }
        }
    }
    best.map(|(score, kappa)| {
        log::info!(
            "[spatial-kappa] #1464 κ-fair sign-basin scan selected κ_seed={kappa:.4} \
             (κ-fair score={score:.6e}) for term {term_idx}"
        );
        kappa
    })
}
/// Attempts the exact joint optimization of the smoothing parameters (ρ, the leading
/// `rho_dim` entries of θ) together with the spatial log-κ / ψ coordinates of
/// `spatial_terms`, starting from the already fitted `best`.
///
/// Outcomes:
/// * `Ok(None)` — there are no spatial terms, or the directional hyper-parameters
///   cannot be built for `best.design`.
/// * `Ok(Some(..))` with a fixed-κ refit — every spatial term is constant-curvature and
///   the κ-fair argmin picked a negative κ for at least one of them.
/// * `Ok(Some(..))` with the frozen baseline geometry — the joint solve did not
///   converge, its objective is worse than the baseline score beyond the acceptance
///   tolerance, or the optimized fit collapsed in effective degrees of freedom while the
///   baseline did not.
/// * `Ok(Some(..))` with the optimized geometry otherwise.
///
/// # Errors
/// Propagates option-validation failures and any error from the fits or the joint solve.
fn try_exact_joint_spatial_length_scale_optimization(
data: ArrayView2<'_, f64>,
y: ArrayView1<'_, f64>,
weights: ArrayView1<'_, f64>,
offset: ArrayView1<'_, f64>,
resolvedspec: &TermCollectionSpec,
best: &FittedTermCollection,
family: LikelihoodSpec,
options: &FitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
spatial_terms: &[usize],
) -> Result<Option<FittedTermCollectionWithSpec>, EstimationError> {
if spatial_terms.is_empty() {
return Ok(None);
}
kappa_options
.validate()
.map_err(EstimationError::InvalidInput)?;
// --- Fixed-κ shortcut: only when every spatial term is constant-curvature. ---
let cc_term_set = constant_curvature_term_indices(resolvedspec);
let all_spatial_are_cc =
!cc_term_set.is_empty() && spatial_terms.iter().all(|t| cc_term_set.contains(t));
if all_spatial_are_cc {
let mut fixed_kappa_spec = resolvedspec.clone();
let mut any_kappa_chosen = false;
for &term_idx in spatial_terms {
// Only a strictly negative argmin is accepted; other results leave the term's κ as is.
if let Some(kappa_hat) =
constant_curvature_kappa_fair_argmin(data, y, resolvedspec, term_idx)
.filter(|&k| k < 0.0)
{
if let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) = fixed_kappa_spec
.smooth_terms
.get_mut(term_idx)
.map(|t| &mut t.basis)
{
cc.kappa = kappa_hat;
any_kappa_chosen = true;
log::info!(
"[spatial-kappa] #1464 term {term_idx}: fixed κ̂ = {kappa_hat:.4} from κ-fair argmin (hyperbolic basin; profiling ρ only)"
);
}
}
}
if any_kappa_chosen {
let baseline_score = fit_score(&best.fit);
let fitted = fit_term_collection_forspec(
data,
y,
weights,
offset,
&fixed_kappa_spec,
family.clone(),
options,
)?;
let frozen_spec =
freeze_term_collection_from_design(&fixed_kappa_spec, &fitted.design)?;
let mut fit = fitted.fit;
// NOTE(review): the refit's reported score is overwritten with the score of `best`
// (fitted under the original spec), not the refit's own score — confirm intended.
fit.reml_score = baseline_score;
return Ok(Some(FittedTermCollectionWithSpec {
fit,
design: fitted.design,
resolvedspec: frozen_spec,
adaptive_diagnostics: fitted.adaptive_diagnostics,
kappa_timing: None,
}));
}
}
// --- Joint path: requires directional hyper-parameters for the baseline design. ---
if try_build_spatial_log_kappa_hyper_dirs(data, resolvedspec, &best.design, spatial_terms)?
.is_none()
{
if !constant_curvature_term_indices(resolvedspec).is_empty() {
log::info!(
"[#1464-trace] try_exact_joint RETURNED None (hyper_dirs unavailable); \
κ̂ comes from a NON-joint path"
);
}
return Ok(None);
}
if !constant_curvature_term_indices(resolvedspec).is_empty() {
log::info!(
"[#1464-trace] try_exact_joint ENTERED for {} spatial term(s); CC present",
spatial_terms.len()
);
}
// ρ box: the upper bound is `crate::estimate::RHO_BOUND` when any constant-curvature
// term exists and `JOINT_RHO_BOUND` otherwise; the lower bound passed to the setup
// below is always `-JOINT_RHO_BOUND`.
const JOINT_RHO_BOUND: f64 = 12.0;
let rho_dim = best.fit.lambdas.len();
let has_constant_curvature_term = !constant_curvature_term_indices(resolvedspec).is_empty();
let rho_upper_bound = if has_constant_curvature_term {
crate::estimate::RHO_BOUND
} else {
JOINT_RHO_BOUND
};
// Initial ψ coordinates: derived from the spec's length scales, then reseeded from data.
let dims_per_term = spatial_dims_per_term(resolvedspec, spatial_terms);
let use_aniso = has_aniso_terms(resolvedspec, spatial_terms);
let log_kappa0 = if use_aniso {
SpatialLogKappaCoords::from_length_scales_aniso(resolvedspec, spatial_terms, kappa_options)
} else {
SpatialLogKappaCoords::from_length_scales(resolvedspec, spatial_terms, kappa_options)
};
let mut log_kappa0 =
log_kappa0.reseed_from_data(data, resolvedspec, spatial_terms, kappa_options);
// For each constant-curvature term, a sign-basin scan may override the slot's seed;
// the (slot, seed) pairs are remembered so the bounds can be pinned below.
let mut cc_sign_seeds: Vec<(usize, f64)> = Vec::new();
if has_constant_curvature_term {
for (slot, &term_idx) in spatial_terms.iter().enumerate() {
if constant_curvature_term_spec(resolvedspec, term_idx).is_none() {
continue;
}
let scan = select_constant_curvature_kappa_sign_seed(
data,
y,
resolvedspec,
term_idx,
);
match scan {
Some(kappa_seed) => {
log::info!(
"[#1464-trace] term {term_idx}: κ-fair sign-basin scan picked κ_seed = {kappa_seed}"
);
log_kappa0.set_scalar_slot(slot, kappa_seed);
cc_sign_seeds.push((slot, kappa_seed));
}
None => {
log::info!(
"[#1464-trace] term {term_idx}: fixed-κ sign-basin scan returned NONE (no seed applied)"
);
}
}
}
}
let log_kappa_lower = if use_aniso {
SpatialLogKappaCoords::lower_bounds_aniso_from_data(
data,
resolvedspec,
spatial_terms,
&dims_per_term,
kappa_options,
)
} else {
SpatialLogKappaCoords::lower_bounds_from_data(
data,
resolvedspec,
spatial_terms,
kappa_options,
)
};
let log_kappa_upper = if use_aniso {
SpatialLogKappaCoords::upper_bounds_aniso_from_data(
data,
resolvedspec,
spatial_terms,
&dims_per_term,
kappa_options,
)
} else {
SpatialLogKappaCoords::upper_bounds_from_data(
data,
resolvedspec,
spatial_terms,
kappa_options,
)
};
let mut log_kappa_lower = log_kappa_lower;
let mut log_kappa_upper = log_kappa_upper;
// Pin lower == upper == seed for every scanned slot whose seed is non-zero.
// NOTE(review): the "FROZE" trace below is also emitted for a zero seed, in which
// case the bounds are left at their data-derived values — confirm the wording.
for &(slot, kappa_seed) in &cc_sign_seeds {
if kappa_seed != 0.0 {
log_kappa_lower.set_scalar_slot(slot, kappa_seed);
log_kappa_upper.set_scalar_slot(slot, kappa_seed);
}
log::info!(
"[#1464-trace] slot {slot}: FROZE joint ψ coordinate at κ_seed={kappa_seed} \
(window [{}, {}]); raw fit_score is sign-blind so the κ-fair scan is authoritative",
log_kappa_lower.as_array()[log_kappa_lower.dims_per_term()[..slot].iter().sum::<usize>()],
log_kappa_upper.as_array()[log_kappa_upper.dims_per_term()[..slot].iter().sum::<usize>()],
);
}
let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
// Assemble θ0 and its box: ρ starts at ln(λ) of the baseline fit.
let setup = ExactJointHyperSetup::new(
best.fit.lambdas.mapv(f64::ln),
Array1::<f64>::from_elem(rho_dim, -JOINT_RHO_BOUND),
Array1::<f64>::from_elem(rho_dim, rho_upper_bound),
log_kappa0,
log_kappa_lower,
log_kappa_upper,
);
let theta0 = setup.theta0();
let lower = setup.lower();
let upper = setup.upper();
let kind = if use_aniso {
SpatialHyperKind::Anisotropic
} else {
SpatialHyperKind::Isotropic
};
let (outcome, kappa_timing) = run_exact_joint_spatial_optimization(
kind,
data,
y,
weights,
offset,
resolvedspec,
&best.design,
family.clone(),
options,
spatial_terms,
&dims_per_term,
&theta0,
&lower,
&upper,
rho_dim,
kappa_options,
)?;
let baseline_score = fit_score(&best.fit);
// A non-converged solve falls back to the baseline geometry.
let (theta_star, joint_final_value) = match outcome {
SpatialJointOutcome::Optimized {
theta_star,
final_value,
} => (theta_star, final_value),
SpatialJointOutcome::NonConverged {
iterations,
final_value,
final_grad_norm,
} => {
if has_constant_curvature_term {
log::info!(
"[#1464-trace] joint solve NONCONVERGED (iters={iterations}, \
final_value={final_value}); returning FROZEN BASELINE geometry \
(κ̂ = spec default, NOT the joint candidate)"
);
}
log::info!(
"[spatial-kappa] joint spatial optimization did not converge \
(iterations={}, final_objective={:.6e}, final_grad_norm={}); \
keeping the frozen baseline geometry",
iterations,
final_value,
final_grad_norm.map_or_else(|| "n/a".to_string(), |g| format!("{g:.3e}")),
);
return Ok(Some(fit_frozen_baseline_geometry(
data,
y,
weights,
offset,
resolvedspec,
best,
family,
options,
baseline_score,
Some(kappa_timing),
)?));
}
};
// Acceptance test: the joint objective may exceed the baseline score by at most
// max(options.tol, 1e-8 * |baseline|, 1e-12).
let accept_tol = options.tol.max(1e-8 * baseline_score.abs()).max(1e-12);
if joint_final_value > baseline_score + accept_tol {
if has_constant_curvature_term {
log::info!(
"[#1464-trace] joint candidate WORSENED score (joint={joint_final_value}, \
baseline={baseline_score}); returning FROZEN BASELINE geometry \
(κ̂ = spec default, NOT the joint candidate)"
);
}
log::info!(
"[spatial-kappa] exact joint spatial candidate worsened the profiled score (joint={:.6e}, baseline={:.6e}, tol={:.2e}); keeping the frozen baseline geometry",
joint_final_value,
baseline_score,
accept_tol,
);
return Ok(Some(fit_frozen_baseline_geometry(
data,
y,
weights,
offset,
resolvedspec,
best,
family,
options,
baseline_score,
Some(kappa_timing),
)?));
}
// Decode θ*: exp of the leading rho_dim entries, and the ψ tail as log-κ coordinates.
let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
let log_kappa_star =
SpatialLogKappaCoords::from_theta_tail_with_dims(&theta_star, rho_dim, dims_per_term);
if has_constant_curvature_term {
let star = log_kappa_star.as_array();
let dims = log_kappa_star.dims_per_term();
for (slot, &term_idx) in spatial_terms.iter().enumerate() {
if constant_curvature_term_spec(resolvedspec, term_idx).is_some() {
let off: usize = dims[..slot].iter().sum();
log::info!(
"[#1464-trace] term {term_idx}: joint solver CONVERGED ψ-tail κ = {} \
(this is the optimised candidate; joint_final_value={joint_final_value})",
star[off]
);
}
}
}
// Refit at the optimized geometry, passing exp(ρ*) as the heuristic lambdas.
let baseline_spec = resolvedspec;
let optimized_spec = log_kappa_star.apply_tospec(resolvedspec, spatial_terms)?;
let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
data,
y,
weights,
offset,
&optimized_spec,
rho_star.as_slice(),
family.clone(),
options,
)?;
// Collapse guard: if the optimized fit's total EDF is below the floor, fit the
// baseline geometry and prefer it when its EDF exceeds the optimized one by the margin.
let optimized_edf = optimized.fit.inference.as_ref().map(|inf| inf.edf_total);
if let Some(opt_edf) = optimized_edf
&& opt_edf < SPATIAL_COLLAPSE_EDF_FLOOR
{
let baseline = fit_frozen_baseline_geometry(
data,
y,
weights,
offset,
baseline_spec,
best,
family.clone(),
options,
baseline_score,
Some(kappa_timing),
)?;
let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
if let Some(base_edf) = baseline_edf
&& base_edf >= opt_edf + SPATIAL_COLLAPSE_EDF_MARGIN
{
log::info!(
"[spatial-kappa] joint candidate collapsed to the null (edf={opt_edf:.3}); \
baseline geometry retains edf={base_edf:.3} — keeping the frozen baseline",
);
return Ok(Some(baseline));
}
}
// The optimized fit reports the joint solver's final objective as its score.
let mut fit = optimized.fit;
fit.reml_score = joint_final_value;
let optimized_result = FittedTermCollectionWithSpec {
fit,
design: optimized.design,
resolvedspec: optimized_spec,
adaptive_diagnostics: optimized.adaptive_diagnostics,
kappa_timing: Some(kappa_timing),
};
Ok(Some(optimized_result))
}
/// A fit whose total effective degrees of freedom falls below this value is treated as a
/// collapse candidate by the spatial collapse guards in this file.
const SPATIAL_COLLAPSE_EDF_FLOOR: f64 = 2.5;
/// Minimum amount by which the reference fit's total EDF must exceed the collapse
/// candidate's EDF before the reference fit is preferred.
const SPATIAL_COLLAPSE_EDF_MARGIN: f64 = 1.0;
/// Refits the term collection at the unmodified `resolvedspec`, warm-started from the
/// lambdas of `best`, and packages the result with `baseline_score` as its score.
///
/// If the warm-started fit's total EDF is below `SPATIAL_COLLAPSE_EDF_FLOOR` while
/// `best` has at least `SPATIAL_COLLAPSE_EDF_MARGIN` more, the warm-started fit is
/// discarded and the collection is fitted again without the lambda warm start.
///
/// # Errors
/// Propagates errors from either fit.
fn fit_frozen_baseline_geometry(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    best: &FittedTermCollection,
    family: LikelihoodSpec,
    options: &FitOptions,
    baseline_score: f64,
    kappa_timing: Option<SpatialLengthScaleOptimizationTiming>,
) -> Result<FittedTermCollectionWithSpec, EstimationError> {
    let warm_started = fit_term_collection_forspecwith_heuristic_lambdas(
        data,
        y,
        weights,
        offset,
        resolvedspec,
        best.fit.lambdas.as_slice(),
        family.clone(),
        options,
    )?;

    let certified_edf = best.fit.inference.as_ref().map(|inf| inf.edf_total);
    let warm_edf = warm_started.fit.inference.as_ref().map(|inf| inf.edf_total);
    // `Some((best_edf, base_edf))` exactly when both EDFs are known and the warm start collapsed.
    let collapse = match (certified_edf, warm_edf) {
        (Some(best_edf), Some(base_edf)) => {
            let collapsed = base_edf < SPATIAL_COLLAPSE_EDF_FLOOR
                && best_edf >= base_edf + SPATIAL_COLLAPSE_EDF_MARGIN;
            collapsed.then_some((best_edf, base_edf))
        }
        _ => None,
    };

    let chosen = if let Some((best_edf, base_edf)) = collapse {
        log::info!(
            "[spatial-kappa] warm-started frozen baseline collapsed (edf={base_edf:.3}) \
             below the certified baseline (edf={best_edf:.3}); refitting from scratch",
        );
        fit_term_collection_forspec(data, y, weights, offset, resolvedspec, family, options)?
    } else {
        warm_started
    };

    let mut fit = chosen.fit;
    fit.reml_score = baseline_score;
    Ok(FittedTermCollectionWithSpec {
        fit,
        design: chosen.design,
        resolvedspec: resolvedspec.clone(),
        adaptive_diagnostics: chosen.adaptive_diagnostics,
        kappa_timing,
    })
}
/// Which flavour of spatial hyper-parameter the joint optimization is working on.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum SpatialHyperKind {
    Anisotropic,
    Isotropic,
}
impl SpatialHyperKind {
    /// `[label, adjective, coordinate name]` for the variant, kept in a single table.
    const fn names(self) -> [&'static str; 3] {
        match self {
            Self::Anisotropic => ["spatial-aniso-joint", "anisotropic", "psi"],
            Self::Isotropic => ["spatial-iso-joint", "isotropic", "kappa"],
        }
    }

    /// Short machine-style label for the variant.
    fn label(self) -> &'static str {
        self.names()[0]
    }

    /// Adjective used in human-readable messages.
    fn adjective(self) -> &'static str {
        self.names()[1]
    }

    /// Name of the coordinate being optimized, used in messages.
    fn coord_name(self) -> &'static str {
        self.names()[2]
    }
}
/// Owned copies of the GLM inputs that `SpatialJointContext::frozen_glm_working_state`
/// feeds to `evaluate_standard_familyobservations`.
struct SpatialFrozenGlmInputs {
/// Response vector.
y: Array1<f64>,
/// Observation weights.
weights: Array1<f64>,
/// Offset added to the linear predictor; its length must match the predictor's.
offset: Array1<f64>,
/// Likelihood specification used to evaluate the observations.
family: LikelihoodSpec,
}
/// Reports whether `family` qualifies for the frozen-GLM tensor path: it must not be
/// Gaussian-identity, and its response must be Binomial, Poisson, Gamma, or Negative
/// Binomial.
fn frozen_glm_tensor_eligible_family(family: &LikelihoodSpec) -> bool {
    if family.is_gaussian_identity() {
        return false;
    }
    matches!(
        &family.response,
        ResponseFamily::Binomial
            | ResponseFamily::Poisson
            | ResponseFamily::Gamma
            | ResponseFamily::NegativeBinomial { .. }
    )
}
/// State shared by the callbacks of the exact joint spatial optimization.
struct SpatialJointContext<'d> {
/// Borrowed data matrix.
data: ArrayView2<'d, f64>,
/// Number of leading θ entries before the ψ coordinate(s).
rho_dim: usize,
/// Isotropic vs anisotropic flavour of the optimization.
kind: SpatialHyperKind,
/// Design cache that realizes θ and memoizes evaluations.
cache: SingleBlockExactJointDesignCache<'d>,
/// External joint hyper-parameter evaluator queried for ψ-tensor coverage and fast-path support.
evaluator: crate::estimate::ExternalJointHyperEvaluator<'d>,
/// GLM inputs for the frozen-weight path; `frozen_glm_working_state` returns `None` without them.
frozen_glm_inputs: Option<SpatialFrozenGlmInputs>,
/// NOTE(review): presumably the (lower, upper) ψ range for the frozen-GLM tensor — confirm.
frozen_glm_psi_bounds: Option<(f64, f64)>,
/// NOTE(review): presumably the built frozen-weight Gram tensor, once available — confirm.
frozen_glm_tensor: Option<crate::solver::glm_sufficient_lane::FrozenWeightGramTensor>,
/// NOTE(review): presumably records that a tensor build was already tried — confirm.
frozen_glm_tensor_attempted: bool,
/// Memo pairing a coefficient vector with the weights computed for it.
frozen_glm_weight_memo: Option<(Array1<f64>, Array1<f64>)>,
}
/// Outcome of the individual gates that decide whether the n-free fast path may be taken.
#[derive(Clone, Copy, Debug, Default)]
struct NfreeSkipGateStatus {
    shape: bool,
    value: bool,
    gradient: bool,
    penalty: bool,
    revision: bool,
    second_order: bool,
}
impl NfreeSkipGateStatus {
    /// Returns `true` when every required gate passes.
    ///
    /// `second_order` is a veto: when set, the answer is always `false`. The `gradient`
    /// gate only counts when `require_gradient` is set.
    fn would_skip(self, require_gradient: bool) -> bool {
        if self.second_order {
            return false;
        }
        let gradient_ok = self.gradient || !require_gradient;
        [
            self.shape,
            self.value,
            gradient_ok,
            self.penalty,
            self.revision,
        ]
        .iter()
        .all(|&gate| gate)
    }
}
impl<'d> SpatialJointContext<'d> {
/// Evaluates each n-free skip gate for `theta` without changing any state.
///
/// `allow_second_order` is copied verbatim into the `second_order` field.
/// When `require_gradient` is `false` the gradient gate is reported as
/// passed for a well-shaped θ; for a mis-shaped θ both the value and the
/// gradient gate are reported as failed.
fn nfree_skip_gate_status(
    &self,
    theta: &Array1<f64>,
    allow_second_order: bool,
    require_gradient: bool,
) -> NfreeSkipGateStatus {
    let evaluator = &self.evaluator;
    let shape = theta.len() == self.rho_dim + 1;
    // Only read ψ once the shape check has established that the index exists.
    let trial_psi = if shape {
        Some(theta[self.rho_dim])
    } else {
        None
    };
    let (value, gradient) = match trial_psi {
        Some(psi) => {
            let value_ok = evaluator.psi_gram_tensor_covers(psi)
                && evaluator.psi_gram_tensor_covers_skip(psi);
            let gradient_ok =
                !require_gradient || evaluator.psi_gram_tensor_covers_gradient(psi);
            (value_ok, gradient_ok)
        }
        None => (false, false),
    };
    let penalty = evaluator.supports_nfree_penalty_rekey();
    let revision = evaluator.nfree_fast_path_revision().is_some();
    NfreeSkipGateStatus {
        shape,
        value,
        gradient,
        penalty,
        revision,
        second_order: allow_second_order,
    }
}
/// Computes the Fisher weights and working response at coefficients `beta`
/// on the currently cached design.
///
/// Returns `Ok(None)` when the frozen-GLM lane is disabled or `beta` does
/// not match the design's column count. Returns an error when the linear
/// predictor and the stored offset disagree in length, or when the family
/// evaluation fails.
fn frozen_glm_working_state(
    &self,
    beta: &Array1<f64>,
) -> Result<Option<(Array1<f64>, Array1<f64>)>, EstimationError> {
    let inputs = match self.frozen_glm_inputs.as_ref() {
        Some(inputs) => inputs,
        None => return Ok(None),
    };
    if self.cache.design().design.ncols() != beta.len() {
        return Ok(None);
    }
    let mut eta = self.cache.design().design.matrixvectormultiply(beta);
    if eta.len() != inputs.offset.len() {
        crate::bail_invalid_estim!(
            "frozen GLM tensor warm-state row mismatch: eta={}, offset={}",
            eta.len(),
            inputs.offset.len()
        );
    }
    eta += &inputs.offset;
    let obs = evaluate_standard_familyobservations(
        inputs.family.clone(),
        None,
        None,
        None,
        &inputs.y,
        &inputs.weights,
        &eta,
    )?;
    // Working response: eta + score / weight, with the weight floored at
    // 1e-12 so a vanishing Fisher weight cannot divide by zero.
    let mut working_response = obs.eta.clone();
    for (row, z_row) in working_response.iter_mut().enumerate() {
        let floored_weight = obs.fisherweight[row].max(1e-12);
        *z_row += obs.score[row] / floored_weight;
    }
    Ok(Some((obs.fisherweight, working_response)))
}
fn frozen_glm_trial_weights(
&mut self,
beta: &Array1<f64>,
) -> Result<Option<Array1<f64>>, EstimationError> {
if let Some((memo_beta, memo_w)) = self.frozen_glm_weight_memo.as_ref()
&& memo_beta.len() == beta.len()
&& memo_beta
.iter()
.zip(beta.iter())
.all(|(a, b)| a.to_bits() == b.to_bits())
{
return Ok(Some(memo_w.clone()));
}
match self.frozen_glm_working_state(beta)? {
Some((current_w, _)) => {
self.frozen_glm_weight_memo = Some((beta.clone(), current_w.clone()));
Ok(Some(current_w))
}
None => Ok(None),
}
}
/// Builds the frozen-weight GLM ψ tensor the first time a warm-start
/// coefficient vector is available, then restores the design cache to
/// `theta`.
///
/// The call is a no-op when a tensor already exists, when a build was
/// already attempted, when the lane has no ψ bounds, or when `warm_beta` is
/// `None` (in that last case the build stays eligible for a later call).
/// A θ of the wrong shape or an unavailable working state marks the build
/// as attempted so it is not retried.
///
/// # Errors
/// Propagates failures from `frozen_glm_working_state` and from restoring
/// the cache to `theta` after the probes.
fn ensure_frozen_glm_tensor(
    &mut self,
    theta: &Array1<f64>,
    warm_beta: Option<&Array1<f64>>,
) -> Result<(), EstimationError> {
    if self.frozen_glm_tensor.is_some() || self.frozen_glm_tensor_attempted {
        return Ok(());
    }
    let Some((psi_lo, psi_hi)) = self.frozen_glm_psi_bounds else {
        return Ok(());
    };
    if theta.len() != self.rho_dim + 1 {
        self.frozen_glm_tensor_attempted = true;
        return Ok(());
    }
    let Some(beta) = warm_beta else {
        return Ok(());
    };
    let Some((frozen_w, working_z)) = self.frozen_glm_working_state(beta)? else {
        self.frozen_glm_tensor_attempted = true;
        return Ok(());
    };
    let theta_probe_base = theta.clone();
    let rho_dim = self.rho_dim;
    // Split the borrow so the probe closure can drive the cache while the
    // evaluator builds the tensor.
    let Self {
        cache, evaluator, ..
    } = self;
    let tensor = evaluator.build_frozen_glm_gram_tensor(
        |psi| {
            let mut theta_probe = theta_probe_base.clone();
            theta_probe[rho_dim] = psi;
            cache.ensure_theta(&theta_probe)?;
            Ok(cache.design().design.clone())
        },
        frozen_w.view(),
        working_z.view(),
        psi_lo,
        psi_hi,
    );
    // Record the outcome before touching the cache again: if the restore
    // below fails, the finished tensor must not be thrown away and the
    // build must not be repeated on every subsequent trial.
    self.frozen_glm_tensor_attempted = true;
    let certified = tensor.is_some();
    self.frozen_glm_tensor = tensor;
    if certified {
        log::info!(
            "[STAGE] {} certified frozen-W GLM ψ tensor over [{psi_lo:.3}, {psi_hi:.3}]",
            self.kind.label(),
        );
    } else {
        log::info!(
            "[STAGE] {} frozen-W GLM ψ tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]",
            self.kind.label(),
        );
    }
    // The probes left the cache at some other ψ; put it back at the trial θ.
    self.cache
        .ensure_theta(theta)
        .map_err(EstimationError::InvalidInput)?;
    Ok(())
}
/// Stages (or clears) the frozen-weight Gram matrix and its ψ-derivative
/// pair on the evaluator for the trial `theta`.
///
/// Both stages are always written: they are `None` unless θ has exactly one
/// spatial coordinate, the tensor contains ψ, a warm-start `beta` yields
/// current weights, and the respective tensor check passes. The derivative
/// pair is only considered when `allow_gradient` is set.
fn stage_frozen_glm_trial_statistics(
    &mut self,
    theta: &Array1<f64>,
    warm_beta: Option<&Array1<f64>>,
    allow_gradient: bool,
) -> Result<(), EstimationError> {
    const FROZEN_GLM_WEIGHT_DRIFT_RTOL: f64 = 1e-3;
    let kind = self.kind;
    let mut gram_stage: Option<Array2<f64>> = None;
    let mut deriv_stage: Option<(Array2<f64>, Array1<f64>)> = None;
    if theta.len() == self.rho_dim + 1 {
        let psi = theta[self.rho_dim];
        let in_window = match self.frozen_glm_tensor.as_ref() {
            Some(tensor) => tensor.contains(psi),
            None => false,
        };
        // Current weights are only worth computing when the tensor can serve ψ.
        let trial_w = match warm_beta {
            Some(beta) if in_window => self.frozen_glm_trial_weights(beta)?,
            _ => None,
        };
        if let Some((tensor, trial_w)) = self.frozen_glm_tensor.as_ref().zip(trial_w.as_ref()) {
            if tensor.weight_drift_within(trial_w.view(), FROZEN_GLM_WEIGHT_DRIFT_RTOL) {
                gram_stage = Some(tensor.gram_at(psi));
                log::debug!(
                    "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
                     first-Fisher-step XᵀWX n-free (weight drift within tol)",
                    kind.label(),
                );
            }
            if allow_gradient && tensor.contains_for_gradient(psi) {
                if let Some((dgram_dpsi, drhs_dpsi)) =
                    tensor.gradient_pair_if_sound(psi, trial_w.view())
                {
                    deriv_stage = Some((dgram_dpsi, drhs_dpsi));
                    log::debug!(
                        "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
                         ψ-gradient (∂G/∂ψ, ∂b/∂ψ) n-free (gradient weight drift within \
                         tight tol); B_j stays exact",
                        kind.label(),
                    );
                }
            }
        }
    }
    self.evaluator.stage_glm_first_step_gram(gram_stage);
    self.evaluator.stage_glm_psi_gram_deriv(deriv_stage);
    Ok(())
}
/// Evaluates cost, gradient, and (when permitted) the Hessian at `theta`.
///
/// A second-order evaluation is performed only when `order` asks for it
/// AND `analytic_outer_hessian_available` is set; every other request is
/// evaluated as value-and-gradient. A memoized result is returned when it
/// satisfies the requested order. When the ψ-gram tensor covers the trial
/// on all lanes, the design is not re-realized and the fast-path revision
/// is reported instead of the cache's design revision.
///
/// # Errors
/// Propagates failures from realizing the design, preparing the frozen GLM
/// tensor/staging, building hyper directions, and the evaluation itself.
/// Successful results are stored in the cache memo.
fn eval_full(
    &mut self,
    theta: &Array1<f64>,
    order: crate::solver::rho_optimizer::OuterEvalOrder,
    analytic_outer_hessian_available: bool,
) -> Result<
    (
        f64,
        Array1<f64>,
        crate::solver::rho_optimizer::HessianResult,
    ),
    EstimationError,
> {
    use crate::solver::rho_optimizer::OuterEvalOrder;
    let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
        && analytic_outer_hessian_available;
    if let Some(eval) = self.cache.memoized_eval(theta) {
        // A memo without an analytic Hessian cannot answer a second-order request.
        let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
        if cached_satisfies_order {
            return Ok(eval);
        }
    }
    let kind = self.kind;
    let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
    let skip_design_realization = !allow_second_order && theta.len() == self.rho_dim + 1 && {
        let psi = theta[self.rho_dim];
        self.evaluator.psi_gram_tensor_covers(psi)
            && self.evaluator.psi_gram_tensor_covers_gradient(psi)
            && self.evaluator.psi_gram_tensor_covers_skip(psi)
            && self.evaluator.supports_nfree_penalty_rekey()
            && nfree_fast_path_revision.is_some()
    };
    if skip_design_realization {
        log::debug!(
            "[STAGE] {} eval_full at psi={:.6}: skipping n×k design re-realization \
             + reconditioning — criterion/gradient/inner-solve served n-free from \
             the certified ψ-gram tensor (GaussianFixedCache + k-space ψ-derivatives)",
            kind.label(),
            theta[self.rho_dim],
        );
    } else {
        self.cache
            .ensure_theta(theta)
            .map_err(EstimationError::InvalidInput)?;
    }
    let warm_beta = self.evaluator.current_beta();
    self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref())?;
    self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), !allow_second_order)?;
    let hyper_dirs = if skip_design_realization {
        self.cache.nfree_tensor_gradient_hyper_dirs(theta)?
    } else {
        self.cache.hyper_dirs_for_current_design(self.data, kind)?
    };
    let design_revision = if skip_design_realization {
        nfree_fast_path_revision
    } else {
        Some(self.cache.design_revision())
    };
    if self.evaluator.supports_nfree_penalty_rekey() {
        match self.cache.canonical_penalties_at(theta) {
            Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
            Err(e) => {
                // This branch is reachable for any θ shape, so read ψ with a
                // checked lookup: a diagnostic must never panic on a short θ.
                log::warn!(
                    "[STAGE] {} eval_full at psi={:.6}: exact n-free S(ψ) rebuild failed \
                     ({e}); clearing stage (eval falls to slow path)",
                    kind.label(),
                    theta.get(self.rho_dim).copied().unwrap_or(f64::NAN),
                );
                self.evaluator.stage_fast_path_penalty(None);
            }
        }
    }
    let eval = evaluate_joint_reml_outer_eval_at_theta(
        &mut self.evaluator,
        self.cache.design(),
        theta,
        self.rho_dim,
        hyper_dirs,
        warm_beta.as_ref().map(|b| b.view()),
        if allow_second_order {
            order
        } else {
            OuterEvalOrder::ValueAndGradient
        },
        design_revision,
    );
    if let Ok(ref value) = eval {
        self.cache.store_eval_at(theta, value.clone());
    }
    eval
}
/// Evaluates the EFS update quantities at `theta` on a freshly realized
/// design.
///
/// # Errors
/// Fails when the design cannot be realized at `theta`, when the hyper
/// directions cannot be built (reported as `InvalidInput`), or when the
/// evaluation itself fails.
fn eval_efs(
    &mut self,
    theta: &Array1<f64>,
) -> Result<crate::solver::rho_optimizer::EfsEval, EstimationError> {
    if let Err(reason) = self.cache.ensure_theta(theta) {
        return Err(EstimationError::InvalidInput(reason));
    }
    let kind = self.kind;
    // NOTE(review): unlike `eval_full`, the directions here always come from
    // `try_build_spatial_log_kappa_hyper_dirs`, independent of `kind` —
    // confirm this is intended for the anisotropic case.
    let built = try_build_spatial_log_kappa_hyper_dirs(
        self.data,
        self.cache.spec(),
        self.cache.design(),
        &self.cache.spatial_terms,
    )?;
    let Some(hyper_dirs) = built else {
        return Err(EstimationError::InvalidInput(format!(
            "failed to build {} hyper_dirs for exact-joint EFS",
            kind.adjective(),
        )));
    };
    let design_revision = Some(self.cache.design_revision());
    let warm_beta = self.evaluator.current_beta();
    let warm_view = warm_beta.as_ref().map(|b| b.view());
    evaluate_joint_reml_efs_at_theta(
        &mut self.evaluator,
        self.cache.design(),
        theta,
        self.rho_dim,
        hyper_dirs,
        warm_view,
        design_revision,
    )
}
/// Value-only probe of the objective at `theta`.
///
/// Returns the memoized cost when one exists. Returns `f64::INFINITY`
/// (never an error) when the trial ψ lies outside an attached ψ-gram
/// tensor's window, when the design cannot be realized, or when the cost
/// evaluation fails; the out-of-window case is also memoized. Failures in
/// the frozen-GLM tensor setup or staging are logged and the probe falls
/// back to unstaged evaluation.
///
/// Each swallowed failure is logged at debug level so an infinite cost can
/// be traced back to its cause.
fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
    if let Some(cost) = self.cache.memoized_cost(theta) {
        return cost;
    }
    let probe_start = std::time::Instant::now();
    let cost_label = self.kind.label();
    // ψ for diagnostics only; NaN when θ is too short to hold it.
    let psi_for_log = theta.get(self.rho_dim).copied().unwrap_or(f64::NAN);
    // Euclidean distance from the θ the cache currently holds (NaN if none).
    let psi_distance = self
        .cache
        .current_theta
        .as_ref()
        .filter(|reference| reference.len() == theta.len())
        .map(|reference| {
            reference
                .iter()
                .zip(theta.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f64>()
                .sqrt()
        })
        .unwrap_or(f64::NAN);
    let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
    let skip_value_realization = theta.len() == self.rho_dim + 1 && {
        let psi = theta[self.rho_dim];
        self.evaluator.psi_gram_tensor_covers(psi)
            && self.evaluator.psi_gram_tensor_covers_skip(psi)
            && self.evaluator.supports_nfree_penalty_rekey()
            && nfree_fast_path_revision.is_some()
    };
    // A trial outside an attached tensor's window is rejected outright.
    if theta.len() == self.rho_dim + 1
        && self.evaluator.has_psi_gram_tensor()
        && !self.evaluator.psi_gram_tensor_covers(theta[self.rho_dim])
    {
        self.cache.store_cost_at(theta, f64::INFINITY);
        return f64::INFINITY;
    }
    if !skip_value_realization {
        if let Err(err) = self.cache.ensure_theta(theta) {
            log::debug!(
                "[STAGE] {cost_label} value-probe at psi={psi_for_log:.6}: design realization \
                 failed ({err}); reporting infinite cost",
            );
            return f64::INFINITY;
        }
    }
    if self.evaluator.supports_nfree_penalty_rekey() {
        match self.cache.canonical_penalties_at(theta) {
            Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
            Err(err) => {
                log::debug!(
                    "[STAGE] {cost_label} value-probe at psi={psi_for_log:.6}: exact n-free \
                     S(ψ) rebuild failed ({err}); clearing stage",
                );
                self.evaluator.stage_fast_path_penalty(None);
            }
        }
    }
    let warm_beta = self.evaluator.current_beta();
    // Staging only runs when the tensor setup succeeded; either failure
    // takes the same fallback, differing only in the phrase that is logged.
    let staging_failure = match self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref()) {
        Err(err) => Some(("tensor setup", err)),
        Ok(()) => self
            .stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), false)
            .err()
            .map(|err| ("staging", err)),
    };
    if let Some((what, err)) = staging_failure {
        log::warn!(
            "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM {what} failed ({err}); \
             falling back to exact streamed Gram",
            cost_label,
            psi_for_log,
        );
        self.evaluator.stage_glm_first_step_gram(None);
        self.evaluator.stage_glm_psi_gram_deriv(None);
    }
    let design_revision = if skip_value_realization {
        nfree_fast_path_revision
    } else {
        Some(self.cache.design_revision())
    };
    let result = {
        let design = self.cache.design();
        self.evaluator.evaluate_cost_only(
            &design.design,
            &design.penalties,
            &design.nullspace_dims,
            design.linear_constraints.clone(),
            theta,
            self.rho_dim,
            warm_beta.as_ref().map(|b| b.view()),
            cost_label,
            design_revision,
        )
    };
    match result {
        Ok(cost) => {
            log::debug!(
                "[STAGE] {cost_label} value-probe (order=Value): elapsed={:.3}s \
                 cost={cost:.6e} trial_theta_distance={psi_distance:.3e}",
                probe_start.elapsed().as_secs_f64(),
            );
            self.cache.store_cost_at(theta, cost);
            cost
        }
        Err(err) => {
            log::debug!(
                "[STAGE] {cost_label} value-probe at psi={psi_for_log:.6}: cost evaluation \
                 failed ({err}); reporting infinite cost",
            );
            f64::INFINITY
        }
    }
}
/// Forgets the cache's current θ and its memoized cost/evaluation so the
/// next call recomputes from scratch.
fn reset(&mut self) {
    let cache = &mut self.cache;
    // Memoized results first, then the θ bookkeeping; the fields are independent.
    cache.last_eval = None;
    cache.last_cost = None;
    cache.last_eval_theta = None;
    cache.current_theta = None;
}
}
/// Result of `run_exact_joint_spatial_optimization`.
enum SpatialJointOutcome {
/// The optimizer converged, or stopped with a gradient small enough to be
/// accepted under the relative-to-cost criterion.
Optimized {
/// Final θ reported by the optimizer.
theta_star: Array1<f64>,
/// Objective value at `theta_star`.
final_value: f64,
},
/// The optimizer stopped without converging but with a finite objective.
NonConverged {
/// Iterations performed before stopping.
iterations: usize,
/// Last objective value.
final_value: f64,
/// Last gradient norm, when the optimizer reported one.
final_grad_norm: Option<f64>,
},
}
/// Returns `(‖θ‖₂, ‖θ[rho_dim..]‖₂)`: the Euclidean norm of the whole vector
/// and of the entries from index `rho_dim` onward.
fn kphase_log_norms(theta: &Array1<f64>, rho_dim: usize) -> (f64, f64) {
    // Norm of the tail that starts `start` entries into θ.
    let tail_norm = |start: usize| -> f64 {
        theta
            .iter()
            .skip(start)
            .map(|entry| entry * entry)
            .sum::<f64>()
            .sqrt()
    };
    (tail_norm(0), tail_norm(rho_dim))
}
/// Runs the exact joint outer optimization over θ, whose first `rho_dim`
/// entries are labelled `rho[i]` and whose remaining `coord_dim` entries are
/// the spatial coordinates for `kind`.
///
/// Steps, in order: build the `SpatialJointContext`; for a Gaussian-identity
/// family with a single spatial coordinate, try to attach a ψ-gram tensor
/// and possibly switch to gradient-only routing; optionally run a
/// finite-difference audit of the outer gradient; prime the evaluator at
/// `theta0`; build and run the multistart outer problem while collecting
/// per-phase call counts and timings.
///
/// Returns the outcome together with the collected timing record.
///
/// # Panics
/// Panics when `lower`/`upper` differ in length from `theta0`, or when the
/// baseline design has fewer smooth terms than `spatial_terms`.
///
/// # Errors
/// Fails when the cache or evaluator cannot be constructed, when the priming
/// evaluation fails, when the optimizer exhausts its strategies, or when the
/// run ends unconverged with a non-finite objective.
fn run_exact_joint_spatial_optimization(
kind: SpatialHyperKind,
data: ArrayView2<'_, f64>,
y: ArrayView1<'_, f64>,
weights: ArrayView1<'_, f64>,
offset: ArrayView1<'_, f64>,
resolvedspec: &TermCollectionSpec,
baseline_design: &TermCollectionDesign,
family: LikelihoodSpec,
options: &FitOptions,
spatial_terms: &[usize],
dims_per_term: &[usize],
theta0: &Array1<f64>,
lower: &Array1<f64>,
upper: &Array1<f64>,
rho_dim: usize,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<(SpatialJointOutcome, SpatialLengthScaleOptimizationTiming), EstimationError> {
let label = kind.label();
// Shape preconditions; violations are treated as programming errors.
assert!(
lower.len() == theta0.len() && upper.len() == theta0.len(),
"spatial hyperparameter bounds must match theta length: lower_len={}, upper_len={}, theta_len={}",
lower.len(),
upper.len(),
theta0.len()
);
assert!(
baseline_design.smooth.terms.len() >= spatial_terms.len(),
"baseline design must have at least one smooth term per spatial term: baseline_terms={}, spatial_terms={}",
baseline_design.smooth.terms.len(),
spatial_terms.len()
);
use crate::solver::rho_optimizer::{
DeclaredHessianForm, Derivative, OuterEval, OuterEvalOrder,
};
let theta_dim = theta0.len();
// NOTE(review): unchecked usize subtraction — callers presumably guarantee
// rho_dim <= theta0.len(); confirm.
let coord_dim = theta_dim - rho_dim;
let analytic_outer_hessian_available =
exact_joint_spatial_outer_hessian_available(&family, baseline_design);
if !analytic_outer_hessian_available {
log::info!(
"[{label}] analytic outer Hessian unavailable for family/design; routing without second-order geometry (coord_dim={coord_dim})"
);
}
// Gradient-only routing is preferred once θ exceeds the second-order cap;
// the n-free lane below may also force it.
let mut prefer_gradient_only = theta_dim > EXACT_JOINT_SECOND_ORDER_THETA_CAP;
if prefer_gradient_only {
log::info!(
"[{label}] joint θ-dim {theta_dim} exceeds the exact pair-Hessian budget \
({EXACT_JOINT_SECOND_ORDER_THETA_CAP}); routing gradient-only quasi-Newton"
);
}
let mut suppress_outer_hessian_for_nfree = false;
log::trace!(
"[{}] starting analytic optimization: rho_dim={}, coord_dim={}, dims_per_term={:?}",
label,
rho_dim,
coord_dim,
dims_per_term,
);
// Shared state for every objective callback. The frozen-GLM fields are
// populated only for a single spatial coordinate and an eligible family.
let mut ctx = SpatialJointContext {
data,
rho_dim,
kind,
cache: SingleBlockExactJointDesignCache::new(
data,
resolvedspec.clone(),
baseline_design.clone(),
spatial_terms.to_vec(),
rho_dim,
dims_per_term.to_vec(),
)
.map_err(EstimationError::InvalidInput)?,
evaluator: crate::estimate::ExternalJointHyperEvaluator::new(
y,
weights,
&baseline_design.design,
offset,
&baseline_design.penalties,
&external_opts_for_design(&family, baseline_design, options),
label,
)?,
frozen_glm_inputs: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
Some(SpatialFrozenGlmInputs {
y: y.to_owned(),
weights: weights.to_owned(),
offset: offset.to_owned(),
family: family.clone(),
})
} else {
None
},
frozen_glm_psi_bounds: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
Some((lower[rho_dim], upper[rho_dim]))
} else {
None
},
frozen_glm_tensor: None,
frozen_glm_tensor_attempted: false,
frozen_glm_weight_memo: None,
};
// Gaussian-identity lane: try to certify a ψ-gram tensor over the bound
// window by probing the design at several ψ through the cache.
let nfree_penalty_capable = coord_dim == 1
&& family.is_gaussian_identity()
&& ctx.cache.supports_nfree_penalty_rekey();
if nfree_penalty_capable {
let psi_lo = lower[rho_dim];
let psi_hi = upper[rho_dim];
// Offset-adjusted response y - offset handed to the tensor build.
let z = Array1::from_iter(y.iter().zip(offset.iter()).map(|(yi, oi)| yi - oi));
let theta_probe_base = theta0.clone();
let SpatialJointContext {
cache, evaluator, ..
} = &mut ctx;
let attached = evaluator.build_and_set_psi_gram_tensor(
|psi| {
let mut theta_probe = theta_probe_base.clone();
theta_probe[rho_dim] = psi;
cache.ensure_theta(&theta_probe)?;
Ok(cache.design().design.clone())
},
weights,
z.view(),
psi_lo,
psi_hi,
);
if attached {
log::info!(
"[{label}] certified ψ-gram tensor over [{psi_lo:.3}, {psi_hi:.3}]: \
in-window trials assemble Gaussian sufficient statistics n-free"
);
let gradient_covers_full_window = evaluator.psi_gram_tensor_covers_gradient(psi_lo)
&& evaluator.psi_gram_tensor_covers_gradient(psi_hi);
if gradient_covers_full_window {
log::info!(
"[{label}] certified ψ-gram tensor gradient lane covers the full \
optimizer window [{psi_lo:.3}, {psi_hi:.3}]"
);
} else {
log::info!(
"[{label}] ψ-gram tensor value lane certified, but the gradient lane \
does not cover the full optimizer window [{psi_lo:.3}, {psi_hi:.3}]; \
keeping exact streamed kappa routing"
);
}
evaluator.set_supports_nfree_penalty_rekey(true);
log::info!(
"[{label}] exact n-free ψ-penalty re-key enabled over [{psi_lo:.3}, \
{psi_hi:.3}]: in-window fast-path trials rebuild S(ψ) n-free from frozen \
geometry (no reset_surface)"
);
} else {
log::info!(
"[{label}] ψ-gram tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]; \
keeping the exact per-trial path"
);
}
// Only when both window endpoints are gradient-covered is the analytic
// outer Hessian suppressed in favour of gradient-only routing.
if attached
&& evaluator.psi_gram_tensor_covers_gradient(psi_lo)
&& evaluator.psi_gram_tensor_covers_gradient(psi_hi)
&& evaluator.supports_nfree_penalty_rekey()
&& cache.supports_nfree_gradient_only_routing()
{
suppress_outer_hessian_for_nfree = true;
prefer_gradient_only = true;
log::info!(
"[{label}] n-free Gaussian ψ-lane armed; suppressing the analytic outer \
Hessian and routing gradient-only (BFGS) so the κ outer loop never realizes \
the O(n) second-order slab — n-independent outer loop (#1033)"
);
}
} else if coord_dim == 1 && family.is_gaussian_identity() {
log::info!(
"[{label}] exact n-free ψ-penalty re-key unavailable; skipping ψ-gram tensor \
attachment so value, gradient, and Hessian remain on the same exact streamed \
objective"
);
}
// Finite-difference audit of the outer gradient, limited to small problems.
// NOTE(review): the gate line below is logged at warn level on every call,
// and its `analytic_grad` field prints analytic_outer_hessian_available —
// confirm both are intended.
const OUTER_FD_AUDIT_MAX_N: usize = 4_000; const OUTER_FD_AUDIT_MAX_THETA_DIM: usize = 32; let n_total = data.nrows();
let outer_fd_audit_eligible = analytic_outer_hessian_available && n_total <= OUTER_FD_AUDIT_MAX_N && theta_dim <= OUTER_FD_AUDIT_MAX_THETA_DIM; log::warn!(
"[OUTER-FD-AUDIT/spatial-exact-joint] gate eligible={outer_fd_audit_eligible} \
analytic_grad={analytic_outer_hessian_available} n_total={n_total} \
theta_dim={theta_dim} rho_dim={rho_dim} psi_dim={coord_dim}"
);
if outer_fd_audit_eligible {
// The audit is wrapped in an immediately-invoked closure so any failure
// becomes a logged string instead of aborting the optimization.
let audit = (|| -> Result<crate::solver::rho_optimizer::OuterGradientFdAudit, String> {
let mut eval_at = |theta: &Array1<f64>,
mode: crate::solver::estimate::reml::reml_outer_engine::EvalMode|
-> Result<
(
f64,
Array1<f64>,
crate::solver::rho_optimizer::HessianResult,
),
String,
> {
use crate::solver::estimate::reml::reml_outer_engine::EvalMode;
let order = if matches!(mode, EvalMode::ValueGradientHessian) {
OuterEvalOrder::ValueGradientHessian
} else {
OuterEvalOrder::Value
};
ctx.eval_full(theta, order, analytic_outer_hessian_available)
.map_err(|e| format!("fd-audit eval_full: {e}"))
};
let rho_dim_audit = rho_dim;
let label_fn = move |i: usize| -> String {
if i < rho_dim_audit {
format!("rho[{i}]")
} else {
format!("psi_kappa[{}]", i - rho_dim_audit)
}
};
crate::solver::rho_optimizer::outer_gradient_fd_audit(
theta0,
1e-4,
label_fn,
&mut eval_at,
)
})();
match audit {
Ok(audit) => audit.log_verdict("spatial-exact-joint"),
Err(e) => log::warn!("[OUTER-FD-AUDIT/spatial-exact-joint] skipped: {e}"),
}
}
// Priming evaluation at theta0; its result is discarded, only errors matter.
let kphase_prime_order = if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
OuterEvalOrder::ValueGradientHessian
} else {
OuterEvalOrder::ValueAndGradient
};
let kphase_prime_start = std::time::Instant::now();
drop(ctx.eval_full(theta0, kphase_prime_order, analytic_outer_hessian_available)?);
log::info!(
"[KAPPA-PHASE-PRIME] order={:?} elapsed_s={:.4} slow_path_resets_total={} design_revision={}",
kphase_prime_order,
kphase_prime_start.elapsed().as_secs_f64(),
ctx.evaluator.slow_path_reset_count(),
ctx.cache.design_revision(),
);
// Counters live in Cells so the objective closures below can update them
// through shared captures.
let kphase_cost_calls = std::cell::Cell::new(0usize);
let kphase_eval_calls = std::cell::Cell::new(0usize);
let kphase_efs_calls = std::cell::Cell::new(0usize);
let kphase_cost_total_s = std::cell::Cell::new(0.0);
let kphase_eval_total_s = std::cell::Cell::new(0.0);
let kphase_efs_total_s = std::cell::Cell::new(0.0);
let kphase_nfree_miss_shape = std::cell::Cell::new(0u64);
let kphase_nfree_miss_value = std::cell::Cell::new(0u64);
let kphase_nfree_miss_gradient = std::cell::Cell::new(0u64);
let kphase_nfree_miss_penalty = std::cell::Cell::new(0u64);
let kphase_nfree_miss_revision = std::cell::Cell::new(0u64);
let kphase_nfree_miss_second_order = std::cell::Cell::new(0u64);
let kphase_nfree_miss_other = std::cell::Cell::new(0u64);
let kphase_optim_start = std::time::Instant::now();
let kphase_log_kappa_dim = coord_dim;
// Baselines so the summary reports deltas for the optimization phase only.
let kphase_slow_resets_start = ctx.evaluator.slow_path_reset_count();
let kphase_design_revision_start = ctx.cache.design_revision();
let problem = exact_joint_multistart_outer_problem(
theta0,
lower,
upper,
rho_dim,
coord_dim,
theta_dim,
Derivative::Analytic,
if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
DeclaredHessianForm::Either
} else {
DeclaredHessianForm::Unavailable
},
prefer_gradient_only,
suppress_outer_hessian_for_nfree,
seed_risk_profile_for_likelihood_family(&family),
kappa_options.rel_tol.max(1e-6),
kappa_options.max_outer_iter.max(1),
Some(5.0),
Some(kappa_options.log_step.clamp(0.25, 1.0)),
None,
Some((data.nrows(), baseline_design.design.ncols())),
!constant_curvature_term_indices(resolvedspec).is_empty(),
);
// Shared derivative-evaluation callback: times the call, attributes any
// slow-path resets to the first failing skip gate, and maps recoverable
// trial-point errors to an infeasible evaluation.
let eval_outer = |ctx: &mut &mut SpatialJointContext<'_>,
theta: &Array1<f64>,
order: OuterEvalOrder|
-> Result<OuterEval, EstimationError> {
let t0 = std::time::Instant::now();
let allow_second_order_for_call = matches!(order, OuterEvalOrder::ValueGradientHessian)
&& analytic_outer_hessian_available;
let gate = ctx.nfree_skip_gate_status(theta, allow_second_order_for_call, true);
let resets_before = ctx.evaluator.slow_path_reset_count();
let raw = ctx.eval_full(theta, order, analytic_outer_hessian_available);
let reset_delta = ctx
.evaluator
.slow_path_reset_count()
.saturating_sub(resets_before);
if reset_delta > 0 {
if !gate.shape {
kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
}
if gate.shape && !gate.value {
kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
}
if gate.shape && gate.value && !gate.gradient {
kphase_nfree_miss_gradient.set(kphase_nfree_miss_gradient.get() + reset_delta);
}
if gate.shape && gate.value && gate.gradient && !gate.penalty {
kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
}
if gate.shape && gate.value && gate.gradient && gate.penalty && !gate.revision {
kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
}
if gate.shape
&& gate.value
&& gate.gradient
&& gate.penalty
&& gate.revision
&& gate.second_order
{
kphase_nfree_miss_second_order
.set(kphase_nfree_miss_second_order.get() + reset_delta);
}
// Every gate passed yet a reset still happened: unexplained miss.
if gate.would_skip(true) {
kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
}
}
let elapsed_s = t0.elapsed().as_secs_f64();
kphase_eval_calls.set(kphase_eval_calls.get() + 1);
kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
log::info!(
"[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
kphase_eval_calls.get(),
order,
Some(ctx.cache.design_revision()),
theta_norm,
log_kappa_norm,
elapsed_s,
);
match raw {
Ok((cost, grad, hess)) => Ok(OuterEval {
cost,
gradient: grad,
hessian: hess,
inner_beta_hint: None,
}),
Err(err) if is_recoverable_trial_point_error(&err) => {
log::debug!(
"[{label}] trial point infeasible (kernel design \
not constructible at theta={theta:?}): {err}; retreating",
);
Ok(OuterEval::infeasible(theta_dim))
}
Err(err) => Err(err),
}
};
// Objective wiring, in argument order: cost-only probe, default-order
// evaluation, explicit-order evaluation, reset hook, EFS evaluation.
let mut obj = problem.build_objective_with_eval_order(
&mut ctx,
|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
let t0 = std::time::Instant::now();
let gate = ctx.nfree_skip_gate_status(theta, false, false);
let resets_before = ctx.evaluator.slow_path_reset_count();
let cost = ctx.eval_cost(theta);
let reset_delta = ctx
.evaluator
.slow_path_reset_count()
.saturating_sub(resets_before);
if reset_delta > 0 {
if !gate.shape {
kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
}
if gate.shape && !gate.value {
kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
}
if gate.shape && gate.value && !gate.penalty {
kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
}
if gate.shape && gate.value && gate.penalty && !gate.revision {
kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
}
if gate.would_skip(false) {
kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
}
}
let elapsed_s = t0.elapsed().as_secs_f64();
kphase_cost_calls.set(kphase_cost_calls.get() + 1);
kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
log::info!(
"[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
kphase_cost_calls.get(),
Some(ctx.cache.design_revision()),
theta_norm,
log_kappa_norm,
elapsed_s,
);
Ok(cost)
},
|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
eval_outer(
ctx,
theta,
if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
OuterEvalOrder::ValueGradientHessian
} else {
OuterEvalOrder::ValueAndGradient
},
)
},
|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
eval_outer(ctx, theta, order)
},
Some(|ctx: &mut &mut SpatialJointContext<'_>| {
ctx.reset();
}),
Some(|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
let t0 = std::time::Instant::now();
let eval = ctx.eval_efs(theta);
let elapsed_s = t0.elapsed().as_secs_f64();
kphase_efs_calls.set(kphase_efs_calls.get() + 1);
kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
log::info!(
"[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
kphase_efs_calls.get(),
Some(ctx.cache.design_revision()),
theta_norm,
log_kappa_norm,
elapsed_s,
);
eval
}),
);
let run_label = match kind {
SpatialHyperKind::Anisotropic => "aniso-psi joint REML",
SpatialHyperKind::Isotropic => "iso-kappa joint REML",
};
let result = problem.run(&mut obj, run_label).map_err(|e| {
EstimationError::InvalidInput(format!(
"{} analytic optimization failed after exhausting strategy fallbacks: {e}",
kind.adjective(),
))
})?;
// Release the objective's borrow of `ctx` before reading its counters.
drop(obj);
let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
let kphase_slow_resets = ctx
.evaluator
.slow_path_reset_count()
.saturating_sub(kphase_slow_resets_start);
let kphase_design_revision_delta = ctx
.cache
.design_revision()
.saturating_sub(kphase_design_revision_start);
log::info!(
"[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} slow_path_resets={} design_revision_delta={} nfree_miss_shape={} nfree_miss_value={} nfree_miss_gradient={} nfree_miss_penalty={} nfree_miss_revision={} nfree_miss_second_order={} nfree_miss_other={} optim_total_s={:.4}",
kphase_log_kappa_dim,
kphase_cost_calls.get(),
kphase_cost_total_s.get(),
kphase_eval_calls.get(),
kphase_eval_total_s.get(),
kphase_efs_calls.get(),
kphase_efs_total_s.get(),
kphase_slow_resets,
kphase_design_revision_delta,
kphase_nfree_miss_shape.get(),
kphase_nfree_miss_value.get(),
kphase_nfree_miss_gradient.get(),
kphase_nfree_miss_penalty.get(),
kphase_nfree_miss_revision.get(),
kphase_nfree_miss_second_order.get(),
kphase_nfree_miss_other.get(),
kphase_total_s,
);
// Same figures as the summary line, returned to the caller.
let timing = SpatialLengthScaleOptimizationTiming {
log_kappa_dim: kphase_log_kappa_dim,
cost_calls: kphase_cost_calls.get(),
cost_total_s: kphase_cost_total_s.get(),
eval_calls: kphase_eval_calls.get(),
eval_total_s: kphase_eval_total_s.get(),
efs_calls: kphase_efs_calls.get(),
efs_total_s: kphase_efs_total_s.get(),
slow_path_resets: kphase_slow_resets,
design_revision_delta: kphase_design_revision_delta,
nfree_miss_shape: kphase_nfree_miss_shape.get(),
nfree_miss_value: kphase_nfree_miss_value.get(),
nfree_miss_gradient: kphase_nfree_miss_gradient.get(),
nfree_miss_penalty: kphase_nfree_miss_penalty.get(),
nfree_miss_revision: kphase_nfree_miss_revision.get(),
nfree_miss_second_order: kphase_nfree_miss_second_order.get(),
nfree_miss_other: kphase_nfree_miss_other.get(),
optim_total_s: kphase_total_s,
};
// Unconverged runs: accept when the gradient norm is within
// tol * (1 + |f|); otherwise report NonConverged for a finite objective
// and fail for a non-finite one.
if !result.converged {
let rel_to_cost_threshold = options.tol * (1.0_f64 + result.final_value.abs());
if let Some(final_grad) = result
.final_grad_norm
.filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
{
log::info!(
"[{}] outer optimization hit max_iter={} but \
projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
(τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
relative-to-cost REML convergence criterion.",
label,
result.iterations,
final_grad,
rel_to_cost_threshold,
options.tol,
result.final_value.abs(),
);
} else if result.final_value.is_finite() {
log::warn!(
"[{}] {} did not converge after {} iterations \
(final_objective={:.6e}, final_grad_norm={}); keeping the \
frozen baseline geometry instead of aborting the fit.",
label,
kind.adjective(),
result.iterations,
result.final_value,
result.final_grad_norm_report(),
);
return Ok((
SpatialJointOutcome::NonConverged {
iterations: result.iterations,
final_value: result.final_value,
final_grad_norm: result.final_grad_norm,
},
timing,
));
} else {
crate::bail_invalid_estim!(
"{} analytic optimization diverged after {} iterations (final_objective={:.6e}, final_grad_norm={})",
kind.adjective(),
result.iterations,
result.final_value,
result.final_grad_norm_report(),
);
}
}
log::trace!(
"[{}] converged in {} iterations, final_value={:.6e}, grad_norm={}",
label,
result.iterations,
result.final_value,
result.final_grad_norm_report(),
);
let theta_star = result.rho;
Ok((
SpatialJointOutcome::Optimized {
theta_star,
final_value: result.final_value,
},
timing,
))
}
fn set_spatial_length_scale(
spec: &mut TermCollectionSpec,
term_idx: usize,
length_scale: f64,
) -> Result<(), EstimationError> {
let Some(term) = spec.smooth_terms.get_mut(term_idx) else {
crate::bail_invalid_estim!("spatial length-scale term index {term_idx} out of range");
};
match &mut term.basis {
SmoothBasisSpec::ThinPlate { spec, .. } => {
spec.length_scale = length_scale;
Ok(())
}
SmoothBasisSpec::Matern { spec, .. } => {
spec.length_scale = length_scale;
Ok(())
}
SmoothBasisSpec::Duchon { spec, .. } => {
spec.length_scale = Some(length_scale);
Ok(())
}
_ => Err(EstimationError::InvalidInput(format!(
"term '{}' does not expose a spatial length scale",
term.name
))),
}
}
/// Overwrites the length scale of a single smooth term.
///
/// # Errors
/// Returns `InvalidInput` when the term's basis is not ThinPlate, Matern,
/// or Duchon.
fn set_single_term_spatial_length_scale(
    term: &mut SmoothTermSpec,
    length_scale: f64,
) -> Result<(), EstimationError> {
    match &mut term.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => spec.length_scale = length_scale,
        SmoothBasisSpec::Matern { spec, .. } => spec.length_scale = length_scale,
        // Duchon stores its length scale as an Option.
        SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale = Some(length_scale),
        _ => {
            return Err(EstimationError::InvalidInput(format!(
                "term '{}' does not expose a spatial length scale",
                term.name
            )));
        }
    }
    Ok(())
}
/// Stores anisotropic log scales on a Matern or Duchon term after passing
/// them through `center_aniso_log_scales`.
///
/// # Errors
/// Returns `InvalidInput` for any other basis.
fn set_single_term_spatial_aniso_log_scales(
    term: &mut SmoothTermSpec,
    eta: Vec<f64>,
) -> Result<(), EstimationError> {
    let centered = center_aniso_log_scales(&eta);
    match &mut term.basis {
        SmoothBasisSpec::Matern { spec, .. } => spec.aniso_log_scales = Some(centered),
        SmoothBasisSpec::Duchon { spec, .. } => spec.aniso_log_scales = Some(centered),
        _ => {
            return Err(EstimationError::InvalidInput(format!(
                "term '{}' does not support aniso_log_scales",
                term.name
            )));
        }
    }
    Ok(())
}
/// Returns the length scale of smooth term `term_idx`, if it has one.
///
/// Yields `None` when the index is out of range, when the basis is not
/// ThinPlate, Matern, or Duchon, or when a Duchon term has no length scale
/// set.
pub fn get_spatial_length_scale(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
    let term = spec.smooth_terms.get(term_idx)?;
    match &term.basis {
        SmoothBasisSpec::ThinPlate { spec, .. } => Some(spec.length_scale),
        SmoothBasisSpec::Matern { spec, .. } => Some(spec.length_scale),
        SmoothBasisSpec::Duchon { spec, .. } => spec.length_scale,
        _ => None,
    }
}
/// Turns off `learn_length_scale` on every MeasureJet term that has it
/// enabled and returns how many terms were changed.
pub fn freeze_measure_jet_length_scale_learning(spec: &mut TermCollectionSpec) -> usize {
    let mut frozen = 0;
    for term in spec.smooth_terms.iter_mut() {
        match &mut term.basis {
            SmoothBasisSpec::MeasureJet { spec: mj, .. } if mj.learn_length_scale => {
                mj.learn_length_scale = false;
                frozen += 1;
            }
            _ => {}
        }
    }
    frozen
}
/// Returns the `kappa` of the constant-curvature term at `term_idx`, or `None`
/// when `constant_curvature_term_spec` finds no such term.
pub fn get_constant_curvature_kappa(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
    let cc = constant_curvature_term_spec(spec, term_idx)?;
    Some(cc.kappa)
}
/// Lists, in ascending order, the smooth-term indices for which
/// `constant_curvature_term_spec` returns a spec.
pub fn constant_curvature_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
    let mut indices = Vec::new();
    for idx in 0..spec.smooth_terms.len() {
        if constant_curvature_term_spec(spec, idx).is_some() {
            indices.push(idx);
        }
    }
    indices
}
fn freeze_smooth_basis_from_metadata(
basis: &mut SmoothBasisSpec,
metadata: &BasisMetadata,
term_name: &str,
) -> Result<(), EstimationError> {
match (&mut *basis, metadata) {
(SmoothBasisSpec::ByVariable { inner, .. }, meta)
| (SmoothBasisSpec::FactorSumToZero { inner, .. }, meta) => {
freeze_smooth_basis_from_metadata(inner, meta, term_name)?;
}
(
SmoothBasisSpec::BSpline1D { spec: s, .. },
BasisMetadata::BSpline1D {
knots,
identifiability_transform,
periodic,
degree: meta_degree,
..
},
) => {
if let Some(d) = meta_degree {
s.degree = *d;
}
s.knotspec = periodic
.map(
|(domain_start, period, num_basis)| BSplineKnotSpec::PeriodicUniform {
data_range: (domain_start, domain_start + period),
num_basis,
},
)
.unwrap_or_else(|| BSplineKnotSpec::Provided(knots.clone()));
s.identifiability = match identifiability_transform {
Some(z) => BSplineIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => BSplineIdentifiability::None,
};
}
(
SmoothBasisSpec::ThinPlate {
spec: s,
input_scales,
..
},
BasisMetadata::ThinPlate {
centers,
length_scale,
periodic: meta_periodic,
identifiability_transform,
input_scales: meta_scales,
radial_reparam,
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.length_scale = *length_scale;
s.identifiability = match identifiability_transform {
Some(z) => SpatialIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => match &s.identifiability {
SpatialIdentifiability::FrozenTransform { .. } => s.identifiability.clone(),
_ => SpatialIdentifiability::None,
},
};
s.radial_reparam = radial_reparam.clone();
s.periodic = meta_periodic.clone();
*input_scales = meta_scales.clone();
}
(
SmoothBasisSpec::ThinPlate { feature_cols, .. },
BasisMetadata::Duchon {
centers,
length_scale,
periodic: meta_periodic,
power,
nullspace_order,
identifiability_transform,
input_scales: meta_scales,
aniso_log_scales: meta_aniso,
radial_reparam: meta_radial_reparam,
..
},
) => {
let identifiability = match identifiability_transform {
Some(z) => SpatialIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => SpatialIdentifiability::None,
};
*basis = SmoothBasisSpec::Duchon {
feature_cols: feature_cols.clone(),
spec: DuchonBasisSpec {
periodic: meta_periodic.clone(),
center_strategy: crate::basis::CenterStrategy::UserProvided(centers.clone()),
length_scale: *length_scale,
power: *power,
nullspace_order: *nullspace_order,
identifiability,
aniso_log_scales: meta_aniso.clone(),
operator_penalties: Default::default(),
boundary: OneDimensionalBoundary::Open,
radial_reparam: meta_radial_reparam.clone(),
},
input_scales: meta_scales.clone(),
};
}
(
SmoothBasisSpec::Sphere { spec: s, .. },
BasisMetadata::Sphere {
centers,
penalty_order,
method,
max_degree,
wahba_kernel,
constraint_transform,
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.penalty_order = *penalty_order;
s.method = *method;
s.max_degree = *max_degree;
s.wahba_kernel = *wahba_kernel;
s.identifiability = match constraint_transform {
Some(z) => SphericalSplineIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => SphericalSplineIdentifiability::CenterSumToZero,
};
}
(
SmoothBasisSpec::ConstantCurvature { spec: s, .. },
BasisMetadata::ConstantCurvature {
centers,
kappa,
length_scale,
constraint_transform,
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.kappa = *kappa;
s.length_scale = *length_scale;
s.identifiability = match constraint_transform {
Some(z) => ConstantCurvatureIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => ConstantCurvatureIdentifiability::CenterSumToZero,
};
}
(
SmoothBasisSpec::MeasureJet {
spec: s,
input_scales,
..
},
BasisMetadata::MeasureJet {
centers,
input_scales: meta_scales,
length_scale,
eps_band,
order_s,
alpha,
tau0,
masses,
support_means,
penalty_normalization_scales,
raw_penalty_normalization_scales,
fused_penalty_normalization_scale,
constraint_transform,
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.length_scale = *length_scale;
s.order_s = *order_s;
s.alpha = *alpha;
s.tau0 = *tau0;
s.num_scales = eps_band.len();
s.frozen_quadrature = Some(MeasureJetFrozenQuadrature {
masses: masses.clone(),
eps_band: eps_band.clone(),
support_means: support_means.clone(),
penalty_normalization_scales: penalty_normalization_scales.clone(),
raw_penalty_normalization_scales: raw_penalty_normalization_scales.clone(),
fused_penalty_normalization_scale: *fused_penalty_normalization_scale,
});
s.identifiability = match constraint_transform {
Some(z) => MeasureJetIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => MeasureJetIdentifiability::CenterSumToZero,
};
*input_scales = meta_scales.clone();
}
(
SmoothBasisSpec::Matern {
spec: s,
input_scales,
..
},
BasisMetadata::Matern {
centers,
length_scale,
periodic: meta_periodic,
nu,
include_intercept,
identifiability_transform,
input_scales: meta_scales,
aniso_log_scales: meta_aniso,
nullspace_shrinkage_survived: meta_nullspace_survived,
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.length_scale = *length_scale;
s.nu = *nu;
s.include_intercept = *include_intercept;
s.identifiability = match identifiability_transform {
Some(z) => MaternIdentifiability::FrozenTransform {
transform: z.clone(),
nullspace_shrinkage_survived: Some(*meta_nullspace_survived),
},
None => MaternIdentifiability::None,
};
s.aniso_log_scales = meta_aniso.clone();
s.periodic = meta_periodic.clone();
*input_scales = meta_scales.clone();
}
(
SmoothBasisSpec::Duchon {
spec: s,
input_scales,
..
},
BasisMetadata::Duchon {
centers,
length_scale,
periodic: meta_periodic,
power,
nullspace_order,
identifiability_transform,
input_scales: meta_scales,
aniso_log_scales: meta_aniso,
radial_reparam,
..
},
) => {
s.center_strategy = crate::basis::CenterStrategy::UserProvided(centers.clone());
s.length_scale = *length_scale;
s.power = *power;
s.nullspace_order = *nullspace_order;
s.identifiability = match identifiability_transform {
Some(z) => SpatialIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => match &s.identifiability {
SpatialIdentifiability::FrozenTransform { .. } => s.identifiability.clone(),
_ => SpatialIdentifiability::None,
},
};
s.aniso_log_scales = meta_aniso.clone();
s.periodic = meta_periodic.clone();
*input_scales = meta_scales.clone();
s.radial_reparam = radial_reparam.clone();
}
(
SmoothBasisSpec::Sphere { spec: s, .. },
BasisMetadata::SphereHarmonics {
max_degree,
radians,
},
) => {
s.max_degree = Some(*max_degree);
s.radians = *radians;
}
(
SmoothBasisSpec::TensorBSpline {
feature_cols,
spec: s,
},
BasisMetadata::TensorBSpline {
feature_cols: fitted_cols,
knots,
degrees,
periods,
identifiability_transform,
},
) => {
if s.marginalspecs.len() != knots.len() || s.marginalspecs.len() != degrees.len() {
crate::bail_invalid_estim!(
"tensor freeze mismatch for '{}': marginalspecs={}, knots={}, degrees={}",
term_name,
s.marginalspecs.len(),
knots.len(),
degrees.len()
);
}
*feature_cols = fitted_cols.clone();
for i in 0..s.marginalspecs.len() {
s.marginalspecs[i].degree = degrees[i];
s.marginalspecs[i].knotspec = match (periods[i], knots[i].len()) {
(Some(period), num_basis) if num_basis >= 1 => {
let domain_start = knots[i][0];
BSplineKnotSpec::PeriodicUniform {
data_range: (domain_start, domain_start + period),
num_basis,
}
}
_ => BSplineKnotSpec::Provided(knots[i].clone()),
};
}
s.identifiability = match identifiability_transform {
Some(z) => TensorBSplineIdentifiability::FrozenTransform {
transform: z.clone(),
},
None => TensorBSplineIdentifiability::None,
};
}
(
SmoothBasisSpec::FactorSmooth { spec: s },
BasisMetadata::FactorSmooth {
knots,
degree,
periodic,
group_levels,
..
},
) => {
s.marginal.knotspec = periodic
.map(
|(domain_start, period, num_basis)| BSplineKnotSpec::PeriodicUniform {
data_range: (domain_start, domain_start + period),
num_basis,
},
)
.unwrap_or_else(|| BSplineKnotSpec::Provided(knots.clone()));
s.marginal.degree = *degree;
s.group_frozen_levels = Some(group_levels.clone());
}
(
SmoothBasisSpec::BySmooth { smooth, by_kind },
BasisMetadata::FactorSmooth {
knots,
degree,
periodic,
group_levels,
..
},
) => {
if let ByVarKind::Factor { frozen_levels, .. } = by_kind {
*frozen_levels = Some(group_levels.clone());
}
if let SmoothBasisSpec::BSpline1D { spec: inner, .. } = smooth.as_mut() {
inner.degree = *degree;
inner.knotspec = periodic
.map(
|(domain_start, period, num_basis)| BSplineKnotSpec::PeriodicUniform {
data_range: (domain_start, domain_start + period),
num_basis,
},
)
.unwrap_or_else(|| BSplineKnotSpec::Provided(knots.clone()));
inner.identifiability = BSplineIdentifiability::None;
}
}
(
SmoothBasisSpec::BySmooth { smooth, by_kind },
BasisMetadata::BySmooth { inner, levels, .. },
) => {
if let ByVarKind::Factor { frozen_levels, .. } = by_kind
&& let Some(levels) = levels
{
*frozen_levels = Some(levels.clone());
}
freeze_smooth_basis_from_metadata(smooth, inner, term_name)?;
}
(SmoothBasisSpec::BySmooth { smooth, .. }, metadata) => {
freeze_smooth_basis_from_metadata(smooth, metadata, term_name)?;
}
_ => {
crate::bail_invalid_estim!(
"smooth metadata/spec type mismatch while freezing term '{}'",
term_name
);
}
}
Ok(())
}
/// Produces a copy of `spec` with every smooth and random-effect term frozen to
/// what `design` actually realized.
///
/// For each smooth term this copies the fitted joint-null rotation, freezes the
/// basis from the fitted metadata, and — when the fitted term carries an
/// unabsorbed global-orthogonality transform — stores it on the spec. Each
/// random-effect term receives the levels recorded in the design.
///
/// # Errors
/// Fails when the smooth or random-effect term counts differ between `spec` and
/// `design`, when a basis cannot be frozen from its metadata, or when a term has
/// an unabsorbed global-orthogonality transform but its basis kind cannot hold it.
pub fn freeze_term_collection_from_design(
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
) -> Result<TermCollectionSpec, EstimationError> {
    let fitted_terms = &design.smooth.terms;
    if spec.smooth_terms.len() != fitted_terms.len() {
        crate::bail_invalid_estim!(
            "freeze mismatch: smooth spec count {} != design smooth term count {}",
            spec.smooth_terms.len(),
            fitted_terms.len()
        );
    }
    let fitted_levels = &design.random_effect_levels;
    if spec.random_effect_terms.len() != fitted_levels.len() {
        crate::bail_invalid_estim!(
            "freeze mismatch: random-effect spec count {} != design random-effect term count {}",
            spec.random_effect_terms.len(),
            fitted_levels.len()
        );
    }

    let mut frozen = spec.clone();

    for (term, fitted) in frozen.smooth_terms.iter_mut().zip(fitted_terms.iter()) {
        term.joint_null_rotation = fitted.joint_null_rotation.clone();
        freeze_smooth_basis_from_metadata(&mut term.basis, &fitted.metadata, &term.name)?;

        let Some(z) = fitted.unabsorbed_global_orthogonality.as_ref() else {
            continue;
        };
        match &mut term.basis {
            SmoothBasisSpec::FactorSumToZero {
                frozen_global_orthogonality,
                ..
            } => *frozen_global_orthogonality = Some(z.clone()),
            SmoothBasisSpec::FactorSmooth { spec } => {
                spec.frozen_global_orthogonality = Some(z.clone());
            }
            _ => {
                crate::bail_invalid_estim!(
                    "freeze: term '{}' carries an unabsorbed global-orthogonality transform but its basis kind has no frozen carrier for it",
                    term.name
                );
            }
        }
    }

    // Counts were checked above, so the zip pairs every random-effect term
    // with its design entry.
    for (rt, (_, kept_levels)) in frozen
        .random_effect_terms
        .iter_mut()
        .zip(fitted_levels.iter())
    {
        rt.frozen_levels = Some(kept_levels.clone());
    }

    Ok(frozen)
}
/// The realized pieces of one smooth term built in isolation.
#[derive(Debug, Clone)]
struct SingleSmoothTermRealization {
/// Design matrix for this term alone (term-local columns).
design_local: DesignMatrix,
/// The realized smooth term, including its local penalties and metadata.
term: SmoothTerm,
/// Penalty blocks that were dropped while building this term.
dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
}
impl SingleSmoothTermRealization {
    /// Returns clones of the term's local penalty-info entries whose `active`
    /// flag is set, in their original order.
    fn active_penaltyinfo(&self) -> Vec<PenaltyInfo> {
        let mut active = Vec::new();
        for info in &self.term.penaltyinfo_local {
            if info.active {
                active.push(info.clone());
            }
        }
        active
    }
}
/// Builds the smooth design for exactly one term spec and unpacks it into a
/// [`SingleSmoothTermRealization`].
///
/// # Errors
/// Propagates any error from `build_smooth_design` or from unpacking its result.
fn build_single_smooth_term_realization(
    data: ArrayView2<'_, f64>,
    termspec: &SmoothTermSpec,
) -> Result<SingleSmoothTermRealization, BasisError> {
    // Present the single spec as a one-element slice without cloning it.
    let single_spec = std::slice::from_ref(termspec);
    let raw_design = build_smooth_design(data, single_spec)?;
    finish_single_smooth_term_realization(raw_design)
}
/// Extracts the first term and first term design from a raw smooth build.
///
/// Any further terms or designs in `raw` are discarded; the dropped-penalty
/// list is carried over unchanged.
///
/// # Errors
/// Returns `BasisError::InvalidInput` when `raw` holds no term, or (checked
/// second) no term design.
fn finish_single_smooth_term_realization(
    raw: RawSmoothDesign,
) -> Result<SingleSmoothTermRealization, BasisError> {
    let Some(term) = raw.terms.into_iter().next() else {
        return Err(BasisError::InvalidInput(
            "single-term smooth build returned no term".to_string(),
        ));
    };
    let Some(design_local) = raw.term_designs.into_iter().next() else {
        return Err(BasisError::InvalidInput(
            "single-term smooth build returned no term design".to_string(),
        ));
    };
    Ok(SingleSmoothTermRealization {
        design_local,
        term,
        dropped_penaltyinfo: raw.dropped_penaltyinfo,
    })
}
/// Packages a term-local build into a [`SingleSmoothTermRealization`].
///
/// Steps:
/// 1. Derive shape lower bounds, but only when the build used the box reparameterization.
/// 2. Verify that the number of active penalty infos equals the number of penalties.
/// 3. Collect dropped penalty blocks: inactive infos first, then the pre-dropped ones.
/// 4. If the build carries a joint-null rotation and the term has neither lower
///    bounds nor linear constraints, apply it: the design becomes `X·Q` (dense),
///    every penalty becomes `Qᵀ·S·Q`, and the penalty operators and Kronecker
///    factorization are cleared. Otherwise the rotation is discarded.
///
/// # Errors
/// Returns a message when the penalty bookkeeping is inconsistent or when the
/// design cannot be densified for the rotation.
fn wrap_local_build_as_realization(
    mut local: LocalSmoothTermBuild,
    termspec: &SmoothTermSpec,
) -> Result<SingleSmoothTermRealization, String> {
    let n_coeffs = local.dim;
    let lower_bounds = match local.box_reparam {
        true => shape_lower_bounds_local(termspec.shape, n_coeffs),
        false => None,
    };

    let n_active = local.penaltyinfo.iter().filter(|info| info.active).count();
    let n_penalties = local.penalties.len();
    if n_active != n_penalties {
        return Err(format!(
            "internal penalty info mismatch for term '{}': active_infos={}, penalties={}",
            termspec.name, n_active, n_penalties
        ));
    }

    let dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo> = local
        .penaltyinfo
        .iter()
        .filter(|info| !info.active)
        .chain(&local.pre_dropped_penaltyinfo)
        .map(|info| DroppedPenaltyBlockInfo {
            termname: Some(termspec.name.clone()),
            penalty: info.clone(),
        })
        .collect();

    // The rotation is only absorbed for unconstrained terms.
    let pending_rotation = local.joint_null_rotation.take();
    let unconstrained = lower_bounds.is_none() && local.linear_constraints.is_none();
    let applied_rotation: Option<crate::terms::basis::JointNullRotation> = match pending_rotation {
        Some(rot) if unconstrained => {
            let rotation = &rot.rotation;
            let dense_design = local
                .design
                .try_to_dense_by_chunks("joint-null absorption rotation (single realization)")
                .map_err(|e| {
                    format!(
                        "joint-null absorption rotation: dense conversion failed for term '{}': {}",
                        termspec.name, e
                    )
                })?;
            let rotated_design = crate::linalg::faer_ndarray::fast_ab(&dense_design, rotation);
            local.design =
                DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(rotated_design));
            local.penalties = local
                .penalties
                .into_iter()
                .map(|penalty| {
                    let left = crate::linalg::faer_ndarray::fast_atb(rotation, &penalty);
                    crate::linalg::faer_ndarray::fast_ab(&left, rotation)
                })
                .collect();
            // Structures tied to the unrotated basis no longer apply.
            local.ops = vec![None; local.penalties.len()];
            local.kronecker_factored = None;
            Some(rot)
        }
        _ => None,
    };

    let term = SmoothTerm {
        name: termspec.name.clone(),
        coeff_range: 0..n_coeffs,
        shape: termspec.shape,
        penalties_local: local.penalties.clone(),
        nullspace_dims: local.nullspaces.clone(),
        penaltyinfo_local: local.penaltyinfo.clone(),
        metadata: local.metadata.clone(),
        lower_bounds_local: lower_bounds,
        linear_constraints_local: local.linear_constraints.clone(),
        kronecker_factored: local.kronecker_factored.take(),
        joint_null_rotation: applied_rotation,
        unabsorbed_global_orthogonality: None,
    };

    Ok(SingleSmoothTermRealization {
        design_local: local.design,
        term,
        dropped_penaltyinfo,
    })
}
/// Returns a copy of `termspec` whose spatial geometry is pinned to `metadata`.
///
/// Supported pairings are Matérn, Duchon and thin-plate specs with metadata of
/// the same kind; any other pairing yields `None`. For a supported pairing:
/// * the center strategy becomes the fitted centers (user-provided);
/// * the spec's input scales are filled from the metadata only if the spec has none;
/// * Matérn only: a metadata identifiability transform, when present, is stored
///   as a frozen transform along with the recorded nullspace-shrinkage flag.
fn freeze_geometry_from_metadata(
    termspec: &SmoothTermSpec,
    metadata: &BasisMetadata,
) -> Option<SmoothTermSpec> {
    let mut frozen = termspec.clone();
    match (&mut frozen.basis, metadata) {
        (
            SmoothBasisSpec::Matern {
                spec,
                input_scales: spec_scales,
                ..
            },
            BasisMetadata::Matern {
                centers,
                input_scales: meta_scales,
                identifiability_transform,
                nullspace_shrinkage_survived,
                ..
            },
        ) => {
            spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
            if spec_scales.is_none() {
                *spec_scales = meta_scales.clone();
            }
            if let Some(transform) = identifiability_transform {
                spec.identifiability = MaternIdentifiability::FrozenTransform {
                    transform: transform.clone(),
                    nullspace_shrinkage_survived: Some(*nullspace_shrinkage_survived),
                };
            }
        }
        (
            SmoothBasisSpec::Duchon {
                spec,
                input_scales: spec_scales,
                ..
            },
            BasisMetadata::Duchon {
                centers,
                input_scales: meta_scales,
                ..
            },
        ) => {
            spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
            if spec_scales.is_none() {
                *spec_scales = meta_scales.clone();
            }
        }
        (
            SmoothBasisSpec::ThinPlate {
                spec,
                input_scales: spec_scales,
                ..
            },
            BasisMetadata::ThinPlate {
                centers,
                input_scales: meta_scales,
                ..
            },
        ) => {
            spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
            if spec_scales.is_none() {
                *spec_scales = meta_scales.clone();
            }
        }
        _ => return None,
    }
    Some(frozen)
}
/// Recomputes the smooth design's aggregate lower bounds, linear constraints and
/// dropped-penalty list from its per-term state.
///
/// * Lower bounds: a vector over all smooth columns, `-inf` by default, with each
///   term's local bounds written into its coefficient range. Stored as `None`
///   when no term has bounds.
/// * Linear constraints: each term's local rows are embedded into full-width
///   rows (zeros elsewhere) and stacked in term order. Stored as `None` when
///   there are no rows.
/// * Dropped penalties: the per-term lists concatenated in term order.
///
/// # Errors
/// Returns a dimension-mismatch message when the number of dropped-penalty sets
/// differs from the number of terms, or when a term's local bounds/constraints
/// do not match the width of its coefficient range.
fn rebuild_smooth_auxiliary_state(
    smooth: &mut SmoothDesign,
    dropped_penaltyinfo_by_term: &[Vec<DroppedPenaltyBlockInfo>],
) -> Result<(), String> {
    let n_terms = smooth.terms.len();
    let n_dropped_sets = dropped_penaltyinfo_by_term.len();
    if n_dropped_sets != n_terms {
        return Err(SmoothError::dimension_mismatch(format!(
            "smooth dropped-penalty cache mismatch: terms={}, dropped_sets={}",
            n_terms, n_dropped_sets
        ))
        .into());
    }

    let total_p = smooth.total_smooth_cols();
    let mut lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
    let mut has_bounds = false;
    let mut constraint_rows: Vec<Array1<f64>> = Vec::new();
    let mut constraint_rhs: Vec<f64> = Vec::new();

    for term in smooth.terms.iter() {
        let cols = term.coeff_range.clone();
        let width = cols.len();

        if let Some(lb_local) = &term.lower_bounds_local {
            if lb_local.len() != width {
                return Err(SmoothError::dimension_mismatch(format!(
                    "smooth lower-bound cache mismatch for term '{}': bounds={}, coeffs={}",
                    term.name,
                    lb_local.len(),
                    width
                ))
                .into());
            }
            lower_bounds.slice_mut(s![cols.clone()]).assign(lb_local);
            has_bounds = true;
        }

        if let Some(lin_local) = &term.linear_constraints_local {
            let local_cols = lin_local.a.ncols();
            if local_cols != width {
                return Err(SmoothError::dimension_mismatch(format!(
                    "smooth linear-constraint cache mismatch for term '{}': cols={}, coeffs={}",
                    term.name, local_cols, width
                ))
                .into());
            }
            for r in 0..lin_local.a.nrows() {
                // Embed the local row into a full-width row of zeros.
                let mut full_row = Array1::<f64>::zeros(total_p);
                full_row
                    .slice_mut(s![cols.clone()])
                    .assign(&lin_local.a.row(r));
                constraint_rows.push(full_row);
                constraint_rhs.push(lin_local.b[r]);
            }
        }
    }

    smooth.coefficient_lower_bounds = has_bounds.then_some(lower_bounds);

    smooth.linear_constraints = if !constraint_rows.is_empty() {
        let mut stacked = Array2::<f64>::zeros((constraint_rows.len(), total_p));
        for (i, full_row) in constraint_rows.iter().enumerate() {
            stacked.row_mut(i).assign(full_row);
        }
        Some(LinearInequalityConstraints {
            a: stacked,
            b: Array1::from_vec(constraint_rhs),
        })
    } else {
        None
    };

    smooth.dropped_penaltyinfo = dropped_penaltyinfo_by_term
        .iter()
        .flatten()
        .cloned()
        .collect();
    Ok(())
}
/// Recomputes the full design's lower bounds, linear constraints and
/// dropped-penalty list from the linear-term specs and the smooth sub-design.
///
/// The smooth columns are taken to be the trailing `total_smooth_cols()` columns
/// of the full design. Constraint rows are collected in this order:
/// 1. per linear term, a `+1` row with right-hand side `coefficient_min` and a
///    `-1` row with right-hand side `-coefficient_max` (each only when set);
/// 2. the smooth design's linear constraints, embedded into the smooth columns.
///
/// Smooth lower bounds are copied into a full-width bound vector and also turned
/// into constraints via `linear_constraints_from_lower_bounds_global`, which are
/// merged with the explicit rows by `merge_linear_constraints_global`.
///
/// # Errors
/// Returns a dimension-mismatch message when the linear-term bookkeeping
/// disagrees with the spec, when a linear term does not own exactly one column,
/// or when the smooth bounds/constraints do not span the smooth columns.
fn rebuild_term_collection_auxiliary_state(
    spec: &TermCollectionSpec,
    design: &mut TermCollectionDesign,
) -> Result<(), String> {
    let n_linear_spec = spec.linear_terms.len();
    let n_linear_ranges = design.linear_ranges.len();
    if n_linear_spec != n_linear_ranges {
        return Err(SmoothError::dimension_mismatch(format!(
            "term-collection linear bookkeeping mismatch: spec_terms={}, design_ranges={}",
            n_linear_spec, n_linear_ranges
        ))
        .into());
    }

    let p_total = design.design.ncols();
    let n_smooth = design.smooth.total_smooth_cols();
    let smooth_start = p_total.saturating_sub(n_smooth);

    let mut lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
    let mut has_bounds = false;
    let mut constraint_rows: Vec<Array1<f64>> = Vec::new();
    let mut constraint_rhs: Vec<f64> = Vec::new();

    for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
        let width = range.len();
        if width != 1 {
            return Err(SmoothError::dimension_mismatch(format!(
                "linear term '{}' expected one coefficient column, found {}",
                linear.name, width
            ))
            .into());
        }
        let col = range.start;
        if let Some(lb) = linear.coefficient_min {
            let mut min_row = Array1::<f64>::zeros(p_total);
            min_row[col] = 1.0;
            constraint_rows.push(min_row);
            constraint_rhs.push(lb);
        }
        if let Some(ub) = linear.coefficient_max {
            // An upper bound is stored as the negated row with negated bound.
            let mut max_row = Array1::<f64>::zeros(p_total);
            max_row[col] = -1.0;
            constraint_rows.push(max_row);
            constraint_rhs.push(-ub);
        }
    }

    if let Some(lb_smooth) = &design.smooth.coefficient_lower_bounds {
        if lb_smooth.len() != n_smooth {
            return Err(SmoothError::dimension_mismatch(format!(
                "smooth lower-bound width mismatch: bounds={}, smooth_cols={}",
                lb_smooth.len(),
                n_smooth
            ))
            .into());
        }
        lower_bounds
            .slice_mut(s![smooth_start..(smooth_start + n_smooth)])
            .assign(lb_smooth);
        has_bounds = true;
    }

    if let Some(lin_smooth) = &design.smooth.linear_constraints {
        if lin_smooth.a.ncols() != n_smooth {
            return Err(SmoothError::dimension_mismatch(format!(
                "smooth linear-constraint width mismatch: cols={}, smooth_cols={}",
                lin_smooth.a.ncols(),
                n_smooth
            ))
            .into());
        }
        // Embed the smooth constraint matrix into the full column space.
        let mut embedded = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
        embedded
            .slice_mut(s![.., smooth_start..(smooth_start + n_smooth)])
            .assign(&lin_smooth.a);
        for r in 0..embedded.nrows() {
            constraint_rows.push(embedded.row(r).to_owned());
            constraint_rhs.push(lin_smooth.b[r]);
        }
    }

    let lower_bound_constraints = match has_bounds {
        true => linear_constraints_from_lower_bounds_global(&lower_bounds),
        false => None,
    };

    let explicit_linear_constraints = if !constraint_rows.is_empty() {
        let mut stacked = Array2::<f64>::zeros((constraint_rows.len(), p_total));
        for (i, full_row) in constraint_rows.iter().enumerate() {
            stacked.row_mut(i).assign(full_row);
        }
        Some(LinearInequalityConstraints {
            a: stacked,
            b: Array1::from_vec(constraint_rhs),
        })
    } else {
        None
    };

    design.coefficient_lower_bounds = has_bounds.then_some(lower_bounds);
    design.linear_constraints =
        merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
    design.dropped_penaltyinfo = design.smooth.dropped_penaltyinfo.clone();
    Ok(())
}
/// Bitwise equality of two parameter vectors.
///
/// Values are compared through `f64::to_bits`, so `0.0` and `-0.0` differ and
/// NaNs with identical payloads compare equal. Vectors of different length
/// never match.
fn theta_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
    if left.len() != right.len() {
        return false;
    }
    left.iter()
        .zip(right.iter())
        .all(|(l, r)| l.to_bits() == r.to_bits())
}
/// Compares two latent-value vectors by delegating to `theta_values_match`
/// (same length and identical `f64` bit patterns element by element).
fn latent_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
theta_values_match(left, right)
}
/// Bitwise equality of two optional anisotropy vectors.
///
/// Two `None`s match; a `None` never matches a `Some`. Two slices match when
/// they have the same length and every pair of elements has identical
/// `f64::to_bits` (so `0.0 != -0.0`, and identical NaN payloads match).
fn spatial_aniso_matches(left: Option<&[f64]>, right: Option<&[f64]>) -> bool {
    let (lhs, rhs) = match (left, right) {
        (Some(lhs), Some(rhs)) => (lhs, rhs),
        (None, None) => return true,
        _ => return false,
    };
    if lhs.len() != rhs.len() {
        return false;
    }
    lhs.iter()
        .zip(rhs)
        .all(|(x, y)| x.to_bits() == y.to_bits())
}
/// Bitwise equality of two optional length scales.
///
/// Both sides are mapped to their `f64::to_bits` representation and compared as
/// `Option<u64>`: `None` matches only `None`, and `Some` values match only when
/// their bit patterns are identical (`0.0 != -0.0`; identical NaNs match).
fn spatial_length_scale_matches(left: Option<f64>, right: Option<f64>) -> bool {
    left.map(f64::to_bits) == right.map(f64::to_bits)
}
/// Holds a term-collection spec together with its realized design and the
/// cached bookkeeping built by `new` (fixed blocks, per-term penalty ranges and
/// dropped-penalty lists).
struct FrozenTermCollectionIncrementalRealizer<'d> {
/// Borrowed data matrix the design was realized from.
data: ArrayView2<'d, f64>,
/// Term-collection spec owned by the realizer.
spec: TermCollectionSpec,
/// Realized design matching `spec`.
design: TermCollectionDesign,
/// Cached result of `build_term_collection_fixed_blocks` for `spec`.
fixed_blocks: Vec<DesignBlock>,
/// Dropped penalty blocks, one list per smooth term (same order as the spec).
dropped_penaltyinfo_by_term: Vec<Vec<DroppedPenaltyBlockInfo>>,
/// Per smooth term, its index range into the smooth design's penalty list.
smooth_penalty_ranges: Vec<Range<usize>>,
/// `smooth_penalty_ranges` shifted by the number of non-smooth penalties that
/// precede them in the full design's penalty list.
full_penalty_ranges: Vec<Range<usize>>,
/// Scratch workspace handed to basis/penalty builders.
basisworkspace: crate::basis::BasisWorkspace,
/// One optional frozen-geometry spec per smooth term; all `None` after `new`.
spatial_realization_geometry: Vec<Option<SmoothTermSpec>>,
/// Revision counter for the design; starts at 0.
design_revision: u64,
}
impl<'d> std::fmt::Debug for FrozenTermCollectionIncrementalRealizer<'d> {
    /// Prints a compact summary: the data shape and the number of cached fixed
    /// blocks. All other fields are elided (non-exhaustive output).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data_shape = (self.data.nrows(), self.data.ncols());
        let n_fixed_blocks = self.fixed_blocks.len();
        let mut summary = f.debug_struct("FrozenTermCollectionIncrementalRealizer");
        summary.field("data_shape", &data_shape);
        summary.field("fixed_blocks", &n_fixed_blocks);
        summary.finish_non_exhaustive()
    }
}
impl<'d> FrozenTermCollectionIncrementalRealizer<'d> {
/// Builds a realizer from an already-realized `design` and its `spec`.
///
/// Beyond storing its inputs, construction:
/// * assigns each smooth term a contiguous range into the smooth penalty list
///   (and the same range shifted into the full penalty list);
/// * caches the fixed term-collection blocks;
/// * rebuilds every smooth term on its own to capture its dropped penalty
///   blocks, checking that the rebuilt width and active-penalty count agree
///   with the supplied design.
///
/// # Errors
/// Returns a message when spec and design disagree on the number of smooth
/// terms, when the penalty bookkeeping is inconsistent, when the fixed blocks
/// or a single-term realization cannot be built, or when a rebuilt term does
/// not match the design.
fn new(
    data: ArrayView2<'d, f64>,
    spec: TermCollectionSpec,
    design: TermCollectionDesign,
) -> Result<Self, String> {
    let n_spec_terms = spec.smooth_terms.len();
    let n_design_terms = design.smooth.terms.len();
    if n_spec_terms != n_design_terms {
        return Err(SmoothError::dimension_mismatch(format!(
            "incremental realizer smooth term mismatch: spec_terms={}, design_terms={}",
            n_spec_terms, n_design_terms
        ))
        .into());
    }

    // Lay the per-term penalty counts end to end.
    let mut smooth_penalty_ranges = Vec::with_capacity(n_design_terms);
    let mut smooth_cursor = 0usize;
    for term in design.smooth.terms.iter() {
        let start = smooth_cursor;
        smooth_cursor += term.penalties_local.len();
        smooth_penalty_ranges.push(start..smooth_cursor);
    }
    let n_smooth_penalties = design.smooth.penalties.len();
    if smooth_cursor != n_smooth_penalties {
        return Err(SmoothError::dimension_mismatch(format!(
            "incremental realizer smooth penalty mismatch: ranged={}, actual={}",
            smooth_cursor, n_smooth_penalties
        ))
        .into());
    }

    // Smooth penalties sit after the non-smooth ones in the full list.
    let Some(fixed_penalty_offset) = design.penalties.len().checked_sub(n_smooth_penalties)
    else {
        return Err("incremental realizer encountered invalid penalty bookkeeping".to_string());
    };
    let full_penalty_ranges: Vec<_> = smooth_penalty_ranges
        .iter()
        .map(|range| (fixed_penalty_offset + range.start)..(fixed_penalty_offset + range.end))
        .collect();

    let fixed_blocks = build_term_collection_fixed_blocks(data, &spec)
        .map_err(|e| format!("failed to cache fixed term-collection blocks: {e}"))?;

    let mut dropped_penaltyinfo_by_term = Vec::with_capacity(n_spec_terms);
    for (term_idx, (termspec, design_term)) in spec
        .smooth_terms
        .iter()
        .zip(design.smooth.terms.iter())
        .enumerate()
    {
        let realization = build_single_smooth_term_realization(data, termspec).map_err(|e| {
            format!(
                "failed to build cached realization for smooth term '{}' (index {}): {e}",
                termspec.name, term_idx
            )
        })?;

        let cached_cols = realization.design_local.ncols();
        let expected_cols = design_term.coeff_range.len();
        if cached_cols != expected_cols {
            return Err(SmoothError::dimension_mismatch(format!(
                "cached realization width mismatch for term '{}': cached_cols={}, design_cols={}",
                termspec.name, cached_cols, expected_cols
            ))
            .into());
        }

        let cached_penalties = realization.active_penaltyinfo().len();
        let expected_penalties = design_term.penalties_local.len();
        if cached_penalties != expected_penalties {
            return Err(SmoothError::dimension_mismatch(format!(
                "cached realization penalty mismatch for term '{}': cached_penalties={}, design_penalties={}",
                termspec.name, cached_penalties, expected_penalties
            ))
            .into());
        }

        dropped_penaltyinfo_by_term.push(realization.dropped_penaltyinfo);
    }

    Ok(Self {
        data,
        spatial_realization_geometry: vec![None; n_spec_terms],
        spec,
        design,
        fixed_blocks,
        dropped_penaltyinfo_by_term,
        smooth_penalty_ranges,
        full_penalty_ranges,
        basisworkspace: crate::basis::BasisWorkspace::new(),
        design_revision: 0,
    })
}
/// Returns the current value of the design revision counter.
fn design_revision(&self) -> u64 {
self.design_revision
}
/// Borrows the term-collection spec held by the realizer.
fn spec(&self) -> &TermCollectionSpec {
&self.spec
}
/// Borrows the realized design held by the realizer.
fn design(&self) -> &TermCollectionDesign {
&self.design
}
/// Reports whether the n-free penalty re-key path applies.
///
/// True only when `spatial_terms` names exactly one term, that index exists in
/// the realized design, and the term's metadata is Duchon or thin-plate.
fn supports_nfree_penalty_rekey(&self, spatial_terms: &[usize]) -> bool {
    let &[term_idx] = spatial_terms else {
        return false;
    };
    let Some(term) = self.design.smooth.terms.get(term_idx) else {
        return false;
    };
    matches!(
        &term.metadata,
        BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. }
    )
}
/// Reports whether n-free gradient-only routing applies.
///
/// True only when `spatial_terms` names exactly one term, that index exists in
/// the realized design, and the term's metadata is Duchon or thin-plate.
fn supports_nfree_gradient_only_routing(&self, spatial_terms: &[usize]) -> bool {
    let &[term_idx] = spatial_terms else {
        return false;
    };
    let Some(term) = self.design.smooth.terms.get(term_idx) else {
        return false;
    };
    matches!(
        &term.metadata,
        BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. }
    )
}
/// Rebuilds the penalty blocks of a single spatial term at the hyperparameter
/// point `psi` and returns them in canonical form, using only the frozen
/// metadata of the realized term (centers, transforms, scales) rather than the
/// data matrix.
///
/// `spatial_terms` must contain exactly one smooth-term index. `psi` is decoded
/// by `spatial_term_psi_to_length_scale_and_aniso` into an optional length
/// scale and optional anisotropy log-scales.
///
/// Returns whatever `canonicalize_penalty_specs` yields for the rebuilt blocks:
/// the canonical penalties plus a `Vec<usize>`.
///
/// # Errors
/// Returns a message when `spatial_terms` does not hold exactly one index, when
/// the index is out of range, when Matérn / thin-plate terms have no length
/// scale in `psi`, when the metadata kind is unsupported, when the number of
/// rebuilt blocks differs from the number of design penalties, or when a helper
/// fails.
fn canonical_penalties_at_psi(
&mut self,
spatial_terms: &[usize],
psi: &[f64],
) -> Result<(Vec<crate::construction::CanonicalPenalty>, Vec<usize>), String> {
if spatial_terms.len() != 1 {
return Err(format!(
"n-free penalty re-key requires exactly one spatial term, found {}",
spatial_terms.len()
));
}
let term_idx = spatial_terms[0];
let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
let termspec =
self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
format!("spatial term {term_idx} out of range for n-free penalty")
})?;
let term = self
.design
.smooth
.terms
.get(term_idx)
.ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
let p_total = self.design.design.ncols();
// Dispatch on the realized term's metadata; each branch yields the term-local
// penalty matrices together with their nullspace dimensions.
let (locals, nullspace_dims): (Vec<Array2<f64>>, Vec<usize>) = match &term.metadata {
BasisMetadata::Duchon {
centers,
identifiability_transform,
operator_collocation_points,
power,
nullspace_order,
aniso_log_scales,
input_scales,
radial_reparam,
..
} => {
// Operator-penalty settings come from the spec; a non-Duchon spec falls
// back to the default settings.
let operator_penalties = match &termspec.basis {
SmoothBasisSpec::Duchon { spec, .. } => spec.operator_penalties.clone(),
_ => crate::basis::DuchonOperatorPenaltySpec::default(),
};
// When the term was fit on standardized inputs, the length scale is
// passed through the standardization-compensation helper first.
let effective_ls = match input_scales.as_deref() {
Some(scales) => {
compensate_optional_length_scale_for_standardization(ls_opt, scales)
}
None => ls_opt,
};
// NOTE(review): this branch passes the metadata's `aniso_log_scales` and
// never consults `aniso_from_psi`, unlike the Matérn branch below —
// confirm that is intended.
crate::basis::duchon_penalties_at_length_scale(
centers.view(),
identifiability_transform.as_ref(),
operator_collocation_points.as_ref().map(|p| p.view()),
&operator_penalties,
*power,
*nullspace_order,
aniso_log_scales.as_deref(),
radial_reparam.as_ref(),
effective_ls,
&mut self.basisworkspace,
)
.map_err(|e| e.to_string())?
}
BasisMetadata::Matern {
centers,
periodic,
nu,
include_intercept,
identifiability_transform,
aniso_log_scales,
input_scales,
..
} => {
let ls = ls_opt.ok_or_else(|| {
"Matérn n-free penalty re-key requires a finite length-scale".to_string()
})?;
let effective_ls = match input_scales.as_deref() {
Some(scales) => compensate_length_scale_for_standardization(ls, scales),
None => ls,
};
// Anisotropy decoded from `psi` takes precedence over the metadata's.
let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
let (penalties, nullspace_dims, _info) =
matern_operator_penalty_triplet_at_length_scale(
centers.view(),
periodic.as_deref(),
identifiability_transform.as_ref(),
*nu,
*include_intercept,
aniso_for_penalty,
effective_ls,
)
.map_err(|e| e.to_string())?;
(penalties, nullspace_dims)
}
BasisMetadata::ThinPlate {
centers,
identifiability_transform,
radial_reparam,
..
} => {
let ls = ls_opt.ok_or_else(|| {
"thin-plate n-free penalty re-key requires a finite length-scale".to_string()
})?;
// `double_penalty` is read from the spec; a non-thin-plate spec is
// treated as `false`.
let double_penalty = match &termspec.basis {
SmoothBasisSpec::ThinPlate { spec, .. } => spec.double_penalty,
_ => false,
};
// NOTE(review): `ls` is used as-is here; no input-scale compensation is
// applied, unlike the Duchon and Matérn branches — confirm that is intended.
crate::basis::thin_plate_penalties_at_length_scale(
centers.view(),
identifiability_transform.as_ref(),
radial_reparam.as_ref(),
ls,
double_penalty,
&mut self.basisworkspace,
)
.map_err(|e| e.to_string())?
}
other => {
return Err(format!(
"n-free penalty re-key unsupported for basis metadata {:?}",
std::mem::discriminant(other)
));
}
};
// The existing design penalties act as templates: each rebuilt local matrix
// inherits the column range, prior mean, structure hint and operator of the
// template at the same position.
// NOTE(review): the count check compares against every penalty in the design,
// not only this term's range — presumably the design is expected to contain
// just this term's penalties; verify against callers.
let templates = &self.design.penalties;
if templates.len() != locals.len() {
return Err(format!(
"n-free penalty re-key produced {} blocks but the frozen design carries {} \
— penalty topology is not ψ-stable",
locals.len(),
templates.len()
));
}
let specs: Vec<crate::estimate::PenaltySpec> = templates
.iter()
.zip(locals.into_iter())
.map(|(tmpl, local)| crate::estimate::PenaltySpec::Block {
local,
col_range: tmpl.col_range.clone(),
prior_mean: tmpl.prior_mean.clone(),
structure_hint: tmpl.structure_hint.clone(),
op: tmpl.op.clone(),
})
.collect();
crate::construction::canonicalize_penalty_specs(
&specs,
&nullspace_dims,
p_total,
"nfree-psi-penalty",
)
.map_err(|e| e.to_string())
}
/// Evaluates the first-order ψ-derivatives of the penalty blocks of the single
/// spatial term in `spatial_terms` at the hyperparameter point `psi`.
///
/// The derivatives are rebuilt from the frozen term metadata (centers,
/// identifiability transform, ...) and the term spec; the data rows held by
/// the realizer are not read here.
///
/// # Returns
/// `(global_range, p_total, locals)`:
/// * `global_range` – the term's coefficient range shifted to full-design columns,
/// * `p_total` – number of columns of the full design,
/// * `locals` – one derivative matrix per penalty block of the frozen design.
///
/// # Errors
/// Returns a message when `spatial_terms` does not hold exactly one index, the
/// index is out of range, the basis metadata is not Duchon / Matérn / thin-plate,
/// a required length-scale is missing, a basis builder fails, or the number of
/// derivative blocks differs from the number of frozen penalties.
fn canonical_penalty_derivatives_at_psi(
    &mut self,
    spatial_terms: &[usize],
    psi: &[f64],
) -> Result<(Range<usize>, usize, Vec<Array2<f64>>), String> {
    if spatial_terms.len() != 1 {
        return Err(format!(
            "n-free penalty derivative re-key requires exactly one spatial term, found {}",
            spatial_terms.len()
        ));
    }
    let term_idx = spatial_terms[0];
    let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
    let termspec = self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
        format!("spatial term {term_idx} out of range for n-free penalty derivative")
    })?;
    let term = self
        .design
        .smooth
        .terms
        .get(term_idx)
        .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
    let p_total = self.design.design.ncols();
    // Smooth columns occupy the trailing `total_smooth_cols()` columns of the design.
    let smooth_start = p_total.saturating_sub(self.design.smooth.total_smooth_cols());
    let global_range =
        (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
    let locals = match &term.metadata {
        BasisMetadata::Duchon {
            centers,
            identifiability_transform,
            operator_collocation_points,
            power,
            nullspace_order,
            aniso_log_scales,
            input_scales,
            radial_reparam,
            ..
        } => {
            let mut spec = match &termspec.basis {
                SmoothBasisSpec::Duchon { spec, .. } => spec.clone(),
                _ => {
                    return Err(
                        "Duchon n-free penalty derivative requires a Duchon term spec"
                            .to_string(),
                    );
                }
            };
            let effective_ls = match input_scales.as_deref() {
                Some(scales) => {
                    compensate_optional_length_scale_for_standardization(ls_opt, scales)
                }
                None => ls_opt,
            };
            // The realized metadata, not the user spec, is authoritative for the
            // structural fields below.
            spec.length_scale = effective_ls;
            spec.power = *power;
            spec.nullspace_order = *nullspace_order;
            // NOTE(review): unlike the Matérn branch, the anisotropy decoded from
            // `psi` is not consulted here; only the metadata value is used —
            // confirm this is intended.
            spec.aniso_log_scales = aniso_log_scales.clone();
            spec.radial_reparam = radial_reparam.clone();
            if spec.length_scale.is_none() {
                return Err(
                    "Duchon n-free penalty derivative requires a hybrid length-scale"
                        .to_string(),
                );
            }
            // Operator penalties are collocated at dedicated points when present,
            // otherwise at the centers themselves.
            let collocation = operator_collocation_points
                .as_ref()
                .map(|points| points.view())
                .unwrap_or_else(|| centers.view());
            let (_native_sources, mut first, _native_second) =
                crate::basis::build_duchon_native_penalty_psi_derivatives(
                    centers.view(),
                    &spec,
                    identifiability_transform.as_ref(),
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?;
            let (_operator_sources, operator_first, _operator_second) =
                crate::basis::build_duchon_operator_penalty_psi_derivatives(
                    collocation,
                    centers.view(),
                    &spec,
                    identifiability_transform.as_ref(),
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?;
            // Native blocks first, operator blocks appended after them.
            first.extend(operator_first);
            first
        }
        BasisMetadata::Matern {
            centers,
            periodic,
            nu,
            include_intercept,
            identifiability_transform,
            aniso_log_scales,
            input_scales,
            ..
        } => {
            let ls = ls_opt.ok_or_else(|| {
                "Matérn n-free penalty derivative requires a finite length-scale".to_string()
            })?;
            let effective_ls = match input_scales.as_deref() {
                Some(scales) => compensate_length_scale_for_standardization(ls, scales),
                None => ls,
            };
            let penalty_centers =
                crate::basis::expand_periodic_centers(&centers.to_owned(), periodic.as_deref())
                    .map_err(|e| e.to_string())?;
            // Anisotropy decoded from `psi` wins over the value frozen in metadata.
            let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
            let (first, _second) = crate::basis::build_matern_operator_penalty_psi_derivatives(
                penalty_centers.view(),
                effective_ls,
                *nu,
                *include_intercept,
                identifiability_transform.as_ref(),
                aniso_for_penalty,
            )
            .map_err(|e| e.to_string())?;
            first
        }
        BasisMetadata::ThinPlate {
            centers,
            identifiability_transform,
            radial_reparam,
            ..
        } => {
            let ls = ls_opt.ok_or_else(|| {
                "thin-plate n-free penalty derivative requires a finite length-scale"
                    .to_string()
            })?;
            let mut spec = match &termspec.basis {
                SmoothBasisSpec::ThinPlate { spec, .. } => spec.clone(),
                _ => {
                    return Err(
                        "thin-plate n-free penalty derivative requires a ThinPlate term spec"
                            .to_string(),
                    );
                }
            };
            spec.length_scale = ls;
            // Only fall back to the metadata reparametrization when the spec has none.
            if spec.radial_reparam.is_none() {
                spec.radial_reparam = radial_reparam.clone();
            }
            let (primary, _primary_second) =
                crate::basis::build_thin_plate_penalty_psi_derivativeswithworkspace(
                    centers.view(),
                    &spec,
                    identifiability_transform.as_ref(),
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?;
            if self.design.penalties.len() > 1 {
                // The second block gets an all-zero derivative of matching shape;
                // build it first so `primary` can be moved instead of cloned.
                let zero_block = Array2::<f64>::zeros(primary.raw_dim());
                vec![primary, zero_block]
            } else {
                vec![primary]
            }
        }
        other => {
            return Err(format!(
                "n-free penalty derivative re-key unsupported for basis metadata {:?}",
                std::mem::discriminant(other)
            ));
        }
    };
    // The derivative list must line up one-to-one with the frozen penalty blocks.
    if locals.len() != self.design.penalties.len() {
        return Err(format!(
            "n-free penalty derivative re-key produced {} blocks but the frozen design carries {} \
             — penalty topology is not ψ-stable",
            locals.len(),
            self.design.penalties.len()
        ));
    }
    Ok((global_range, p_total, locals))
}
/// Pushes one slice of `log_kappa` into each term listed in `term_indices`
/// (slot `i` of the coordinates goes to `term_indices[i]`).
///
/// When at least one term was rebuilt, the full design operator and the
/// auxiliary smooth / term-collection state are refreshed and the design
/// revision counter is advanced by one (wrapping).
///
/// # Errors
/// Fails when the number of term indices differs from the number of per-term
/// coordinate groups, or when any per-term update or refresh step fails.
fn apply_log_kappa(
    &mut self,
    log_kappa: &SpatialLogKappaCoords,
    term_indices: &[usize],
) -> Result<(), String> {
    let coord_terms = log_kappa.dims_per_term().len();
    if term_indices.len() != coord_terms {
        return Err(SmoothError::dimension_mismatch(format!(
            "incremental realizer log-kappa term mismatch: term_indices={}, dims_per_term={}",
            term_indices.len(),
            coord_terms
        ))
        .into());
    }

    let mut rebuilt_any = false;
    for (slot, term_idx) in term_indices.iter().copied().enumerate() {
        // Every term is visited even after a rebuild has already been recorded.
        let rebuilt = self.apply_log_kappa_to_term(term_idx, log_kappa.term_slice(slot))?;
        rebuilt_any = rebuilt_any || rebuilt;
    }
    if !rebuilt_any {
        return Ok(());
    }

    self.refresh_full_design_operator()?;
    rebuild_smooth_auxiliary_state(
        &mut self.design.smooth,
        &self.dropped_penaltyinfo_by_term,
    )?;
    rebuild_term_collection_auxiliary_state(&self.spec, &mut self.design)?;
    self.design_revision = self.design_revision.wrapping_add(1);
    Ok(())
}
/// Applies the hyperparameter slice `psi` to smooth term `term_idx` and, when
/// the term's spatial hyperparameters actually change, rebuilds the term's
/// local basis and splices it into the cached design.
///
/// Returns `Ok(true)` when the term was rebuilt and `Ok(false)` when the
/// requested values were already in effect (nothing is rebuilt in that case).
///
/// # Errors
/// Fails when the term does not expose spatial hyperparameters, when a spec
/// setter fails, when the geometry slot or term index is out of range, or when
/// the local rebuild / replacement fails.
///
/// NOTE(review): the live spec is updated before the fallible rebuild and is
/// not rolled back on error, so after a failure the spec may describe values
/// the cached design was never rebuilt for.
fn apply_log_kappa_to_term(&mut self, term_idx: usize, psi: &[f64]) -> Result<bool, String> {
if !spatial_term_supports_hyper_optimization(&self.spec, term_idx) {
return Err(SmoothError::invalid_config(format!(
"incremental realizer term {term_idx} does not expose spatial hyperparameters"
))
.into());
}
// Measure-jet and constant-curvature terms have their own ψ setters; every
// other term is handled through length-scale / anisotropy.
let measure_jet_term = measure_jet_term_spec(&self.spec, term_idx).is_some();
let constant_curvature_term = constant_curvature_term_spec(&self.spec, term_idx).is_some();
let mut next_length_scale = None;
let mut next_aniso: Option<Vec<f64>> = None;
if measure_jet_term {
// A `false` return from the setter is treated as "nothing to rebuild".
if !set_measure_jet_psi_dials(&mut self.spec, term_idx, psi)
.map_err(|e| e.to_string())?
{
return Ok(false);
}
} else if constant_curvature_term {
// Same convention as above: `false` short-circuits without a rebuild.
if !set_constant_curvature_kappa(&mut self.spec, term_idx, psi)
.map_err(|e| e.to_string())?
{
return Ok(false);
}
} else {
let current_length_scale = get_spatial_length_scale(&self.spec, term_idx);
let current_aniso = get_spatial_aniso_log_scales(&self.spec, term_idx);
let (ls, eta) = spatial_term_psi_to_length_scale_and_aniso(psi);
next_length_scale = ls;
next_aniso = eta;
// Skip the rebuild when both the length-scale and the anisotropy already match.
let same_length = spatial_length_scale_matches(current_length_scale, next_length_scale);
let same_aniso = spatial_aniso_matches(current_aniso.as_deref(), next_aniso.as_deref());
if same_length && same_aniso {
return Ok(false);
}
// Record the new values on the live spec (only the components ψ provides).
if let Some(length_scale) = next_length_scale {
set_spatial_length_scale(&mut self.spec, term_idx, length_scale)
.map_err(|e| e.to_string())?;
}
if let Some(eta) = next_aniso.clone() {
set_spatial_aniso_log_scales(&mut self.spec, term_idx, eta)
.map_err(|e| e.to_string())?;
}
}
let geometry_slot = self
.spatial_realization_geometry
.get(term_idx)
.ok_or_else(|| format!("incremental realizer geometry slot {term_idx} out of range"))?;
// Build from the frozen-geometry spec when one has been cached; otherwise
// start from a copy of the live term spec.
let mut build_spec = match geometry_slot {
Some(cached) => cached.clone(),
None => self
.spec
.smooth_terms
.get(term_idx)
.ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
.clone(),
};
// Apply the same ψ values to the build spec (a cached spec still carries the
// values it was frozen with).
if measure_jet_term {
set_single_term_measure_jet_psi_dials(&mut build_spec, psi)
.map_err(|e| e.to_string())?;
} else if constant_curvature_term {
set_single_term_constant_curvature_kappa(&mut build_spec, psi)
.map_err(|e| e.to_string())?;
} else {
if let Some(length_scale) = next_length_scale {
set_single_term_spatial_length_scale(&mut build_spec, length_scale)
.map_err(|e| e.to_string())?;
}
if let Some(eta) = next_aniso {
set_single_term_spatial_aniso_log_scales(&mut build_spec, eta)
.map_err(|e| e.to_string())?;
}
}
let termname = build_spec.name.clone();
let local = build_single_local_smooth_term(
self.data,
&build_spec,
&mut self.basisworkspace,
)
.map_err(|e| {
format!(
"failed to rebuild smooth term '{termname}' during incremental κ realization: {e}"
)
})?;
// First successful rebuild for this term: freeze the geometry derived from
// the realized metadata so later rebuilds reuse it.
if self.spatial_realization_geometry[term_idx].is_none()
&& let Some(frozen) = freeze_geometry_from_metadata(&build_spec, &local.metadata)
{
// For Matérn terms, copy the frozen identifiability and center strategy
// onto the live spec as well.
if let (
SmoothBasisSpec::Matern {
spec: frozen_spec, ..
},
Some(SmoothBasisSpec::Matern {
spec: live_spec, ..
}),
) = (
&frozen.basis,
self.spec
.smooth_terms
.get_mut(term_idx)
.map(|t| &mut t.basis),
) {
live_spec.identifiability = frozen_spec.identifiability.clone();
live_spec.center_strategy = frozen_spec.center_strategy.clone();
}
self.spatial_realization_geometry[term_idx] = Some(frozen);
}
let realization = wrap_local_build_as_realization(local, &build_spec)?;
self.replace_term_realization(term_idx, realization)?;
Ok(true)
}
/// Splices a freshly rebuilt realization of smooth term `term_idx` into the
/// cached design: the term's design block, its penalty matrices (both in the
/// smooth-only and the full penalty lists), nullspace dimensions, penalty
/// metadata, and the per-term bookkeeping fields.
///
/// The rebuilt term must have the same column count, row count and penalty
/// topology as the cached one. Every check runs before the first write, so a
/// rejected realization leaves the cached design untouched.
///
/// # Errors
/// Fails when `term_idx` is out of range, when widths / rows / penalty counts /
/// penalty shapes do not match the cached layout, or when a cached penalty
/// index derived from the stored ranges is out of range.
fn replace_term_realization(
    &mut self,
    term_idx: usize,
    realization: SingleSmoothTermRealization,
) -> Result<(), String> {
    let t_replace = std::time::Instant::now();
    let SingleSmoothTermRealization {
        design_local,
        term,
        dropped_penaltyinfo,
    } = realization;
    let SmoothTerm {
        name,
        penalties_local,
        nullspace_dims,
        penaltyinfo_local,
        metadata,
        lower_bounds_local,
        linear_constraints_local,
        joint_null_rotation,
        ..
    } = term;
    let coeff_range = self
        .design
        .smooth
        .terms
        .get(term_idx)
        .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
        .coeff_range
        .clone();
    if design_local.ncols() != coeff_range.len() {
        return Err(SmoothError::dimension_mismatch(format!(
            "incremental realizer width mismatch for term {}: rebuilt_cols={}, cached_cols={}",
            term_idx,
            design_local.ncols(),
            coeff_range.len()
        ))
        .into());
    }
    if design_local.nrows() != self.design.design.nrows() {
        return Err(SmoothError::dimension_mismatch(format!(
            "incremental realizer row mismatch for term {}: rebuilt_rows={}, design_rows={}",
            term_idx,
            design_local.nrows(),
            self.design.design.nrows()
        ))
        .into());
    }
    // Only active penalty descriptors have a slot in the global penalty lists.
    let active_penaltyinfo = penaltyinfo_local
        .iter()
        .filter(|info| info.active)
        .cloned()
        .collect::<Vec<_>>();
    let smooth_penalty_range = self
        .smooth_penalty_ranges
        .get(term_idx)
        .ok_or_else(|| {
            format!("incremental realizer missing smooth penalty range for term {term_idx}")
        })?
        .clone();
    let full_penalty_range = self
        .full_penalty_ranges
        .get(term_idx)
        .ok_or_else(|| {
            format!("incremental realizer missing full penalty range for term {term_idx}")
        })?
        .clone();
    if active_penaltyinfo.len() != smooth_penalty_range.len()
        || penalties_local.len() != smooth_penalty_range.len()
        || nullspace_dims.len() != smooth_penalty_range.len()
    {
        return Err(SmoothError::dimension_mismatch(format!(
            "incremental realizer topology changed for term '{}': penalties={}, infos={}, nullspaces={}, cached_penalties={}",
            name,
            penalties_local.len(),
            active_penaltyinfo.len(),
            nullspace_dims.len(),
            smooth_penalty_range.len()
        ))
        .into());
    }

    // Validation pass: check every penalty's shape and target slots up front so
    // that an error cannot leave the design half-updated. Checks run in the same
    // per-penalty order as the writes below, so the first reported error is the
    // one a write-as-you-go loop would have reported.
    let block_width = coeff_range.len();
    for (offset, penalty_local) in penalties_local.iter().enumerate() {
        if penalty_local.nrows() != block_width || penalty_local.ncols() != block_width {
            return Err(SmoothError::dimension_mismatch(format!(
                "incremental realizer penalty shape mismatch for term '{}' penalty {}: \
                 penalty is {}x{} but coeff_range has {} columns",
                name,
                offset,
                penalty_local.nrows(),
                penalty_local.ncols(),
                coeff_range.len()
            ))
            .into());
        }
        let smooth_penalty_idx = smooth_penalty_range.start + offset;
        if smooth_penalty_idx >= self.design.smooth.penalties.len() {
            return Err(format!(
                "incremental realizer smooth penalty {} out of range for term {}",
                smooth_penalty_idx, term_idx
            ));
        }
        let full_penalty_idx = full_penalty_range.start + offset;
        if full_penalty_idx >= self.design.penalties.len() {
            return Err(format!(
                "incremental realizer full penalty {} out of range for term {}",
                full_penalty_idx, term_idx
            ));
        }
    }

    // Write pass: from here on the cached design is mutated.
    self.design.smooth.term_designs[term_idx] = design_local;
    // `active_penaltyinfo` has exactly `penalties_local.len()` entries (checked
    // above), so the zip visits every penalty.
    for (offset, (penalty_local, penalty_info)) in penalties_local
        .iter()
        .zip(active_penaltyinfo)
        .enumerate()
    {
        let smooth_penalty_idx = smooth_penalty_range.start + offset;
        let full_penalty_idx = full_penalty_range.start + offset;
        let nullspace_dim = nullspace_dims[offset];
        let smooth_penalty = self
            .design
            .smooth
            .penalties
            .get_mut(smooth_penalty_idx)
            .ok_or_else(|| {
                format!(
                    "incremental realizer smooth penalty {} out of range for term {}",
                    smooth_penalty_idx, term_idx
                )
            })?;
        smooth_penalty.local.assign(penalty_local);
        let full_bp = self
            .design
            .penalties
            .get_mut(full_penalty_idx)
            .ok_or_else(|| {
                format!(
                    "incremental realizer full penalty {} out of range for term {}",
                    full_penalty_idx, term_idx
                )
            })?;
        full_bp.local.assign(penalty_local);
        self.design.smooth.nullspace_dims[smooth_penalty_idx] = nullspace_dim;
        self.design.nullspace_dims[full_penalty_idx] = nullspace_dim;
        self.design.smooth.penaltyinfo[smooth_penalty_idx].global_index = smooth_penalty_idx;
        self.design.smooth.penaltyinfo[smooth_penalty_idx].termname = Some(name.clone());
        self.design.smooth.penaltyinfo[smooth_penalty_idx].penalty = penalty_info.clone();
        self.design.penaltyinfo[full_penalty_idx].global_index = full_penalty_idx;
        self.design.penaltyinfo[full_penalty_idx].termname = Some(name.clone());
        self.design.penaltyinfo[full_penalty_idx].penalty = penalty_info;
    }
    let target_term = self.design.smooth.terms.get_mut(term_idx).ok_or_else(|| {
        format!("incremental realizer smooth term {term_idx} disappeared during replacement")
    })?;
    // The cached term keeps its own name and coefficient range; everything
    // derived from the rebuild is replaced.
    target_term.penalties_local = penalties_local;
    target_term.nullspace_dims = nullspace_dims;
    target_term.penaltyinfo_local = penaltyinfo_local;
    target_term.metadata = metadata;
    target_term.lower_bounds_local = lower_bounds_local;
    target_term.linear_constraints_local = linear_constraints_local;
    target_term.joint_null_rotation = joint_null_rotation;
    self.dropped_penaltyinfo_by_term[term_idx] = dropped_penaltyinfo;
    log::info!(
        "[STAGE] smooth basis rebuild (term {}, '{}', cols={}): {:.3}s",
        term_idx,
        target_term.name,
        coeff_range.len(),
        t_replace.elapsed().as_secs_f64(),
    );
    Ok(())
}
/// Reassembles the full term-collection design from the cached fixed blocks
/// followed by one block per smooth term design, and stores it on the design.
///
/// # Errors
/// Returns a prefixed message when assembling the design matrix fails.
fn refresh_full_design_operator(&mut self) -> Result<(), String> {
    let fixed = self.fixed_blocks.iter().cloned();
    let smooth = self
        .design
        .smooth
        .term_designs
        .iter()
        .map(|term_design| DesignBlock::from(term_design));
    // Fixed blocks come first, smooth term blocks after them, in stored order.
    let blocks: Vec<DesignBlock> = fixed.chain(smooth).collect();
    let assembled = assemble_term_collection_design_matrix(blocks)
        .map_err(|e| format!("failed to refresh term-collection design: {e}"))?;
    self.design.design = assembled;
    Ok(())
}
}
/// Builds the design blocks of a term collection that do not come from smooth
/// terms, in this order: an intercept block (omitted when the spec has a
/// one-sided anchored B-spline), one dense block holding all linear-term
/// columns (omitted when there are none), and one random-effect operator block
/// per random-effect term.
///
/// # Errors
/// Propagates failures from realizing a linear column (wrapped as
/// `BasisError::InvalidInput`) or from building a random-effect block.
fn build_term_collection_fixed_blocks(
    data: ArrayView2<'_, f64>,
    spec: &TermCollectionSpec,
) -> Result<Vec<DesignBlock>, BasisError> {
    let n_rows = data.nrows();
    let mut fixed: Vec<DesignBlock> = Vec::new();

    let wants_intercept = !term_collection_has_one_sided_anchored_bspline(spec);
    if wants_intercept {
        fixed.push(DesignBlock::Intercept(n_rows));
    }

    let n_linear = spec.linear_terms.len();
    if n_linear > 0 {
        // One column per linear term, filled in declaration order.
        let mut dense = Array2::<f64>::zeros((n_rows, n_linear));
        for (col_idx, linear_term) in spec.linear_terms.iter().enumerate() {
            let realized = linear_term
                .realized_design_column(data)
                .map_err(BasisError::InvalidInput)?;
            dense.column_mut(col_idx).assign(&realized);
        }
        let dense_design = crate::matrix::DenseDesignMatrix::from(dense);
        fixed.push(DesignBlock::Dense(dense_design));
    }

    for re_term in spec.random_effect_terms.iter() {
        let built = build_random_effect_block(data, re_term)?;
        let operator = RandomEffectOperator::new(built.group_ids, built.num_groups);
        fixed.push(DesignBlock::RandomEffect(Arc::new(operator)));
    }

    Ok(fixed)
}
/// Outcome of a spatial length-scale optimization run, generic over the
/// caller-defined fit payload `FitOut`.
pub struct SpatialLengthScaleOptimizationResult<FitOut> {
/// Term-collection specs after resolution, one per block.
pub resolved_specs: Vec<TermCollectionSpec>,
/// Realized designs matching `resolved_specs`.
pub designs: Vec<TermCollectionDesign>,
/// Value produced by the caller-supplied fit callback.
pub fit: FitOut,
/// Optional timing breakdown; `None` when no timing was collected.
pub timing: Option<SpatialLengthScaleOptimizationTiming>,
}
/// Starting point and box bounds for the exact-joint outer parameter vector θ,
/// stored as three consecutive segments: ρ, log-κ, and auxiliary parameters.
#[derive(Debug, Clone)]
pub struct ExactJointHyperSetup {
// ρ segment: seed and lower/upper bounds.
rho0: Array1<f64>,
rho_lower: Array1<f64>,
rho_upper: Array1<f64>,
// log-κ segment: seed and lower/upper bounds, grouped per spatial term.
log_kappa0: SpatialLogKappaCoords,
log_kappa_lower: SpatialLogKappaCoords,
log_kappa_upper: SpatialLogKappaCoords,
// Trailing auxiliary segment: seed and lower/upper bounds (may be empty).
auxiliary0: Array1<f64>,
auxiliary_lower: Array1<f64>,
auxiliary_upper: Array1<f64>,
}
impl ExactJointHyperSetup {
    /// Returns a copy of `rho0` with every entry forced into its
    /// `[rho_lower[i], rho_upper[i]]` box; non-finite entries are replaced by
    /// `0.0` before clamping.
    ///
    /// # Panics
    /// Panics when a bound vector is shorter than `rho0`, or when a bound pair
    /// is inverted or NaN (the `f64::clamp` precondition).
    fn sanitize_rho_seed(
        rho0: Array1<f64>,
        rho_lower: &Array1<f64>,
        rho_upper: &Array1<f64>,
    ) -> Array1<f64> {
        rho0.iter()
            .enumerate()
            .map(|(idx, &raw)| {
                let lo = rho_lower[idx];
                let hi = rho_upper[idx];
                // Non-finite seeds start from zero; either way the value ends
                // up inside the box.
                let seed = if raw.is_finite() { raw } else { 0.0_f64 };
                seed.clamp(lo, hi)
            })
            .collect()
    }

    /// Creates a setup with ρ and log-κ segments and an empty auxiliary
    /// segment. The ρ seed is sanitized against its bounds.
    pub(crate) fn new(
        rho0: Array1<f64>,
        rho_lower: Array1<f64>,
        rho_upper: Array1<f64>,
        log_kappa0: SpatialLogKappaCoords,
        log_kappa_lower: SpatialLogKappaCoords,
        log_kappa_upper: SpatialLogKappaCoords,
    ) -> Self {
        let rho0 = Self::sanitize_rho_seed(rho0, &rho_lower, &rho_upper);
        let empty = || Array1::<f64>::zeros(0);
        Self {
            rho0,
            rho_lower,
            rho_upper,
            log_kappa0,
            log_kappa_lower,
            log_kappa_upper,
            auxiliary0: empty(),
            auxiliary_lower: empty(),
            auxiliary_upper: empty(),
        }
    }

    /// Attaches a trailing auxiliary segment; the seed is sanitized against
    /// the supplied bounds.
    ///
    /// # Panics
    /// Panics when either bound vector differs in length from `auxiliary0`.
    pub(crate) fn with_auxiliary(
        mut self,
        auxiliary0: Array1<f64>,
        auxiliary_lower: Array1<f64>,
        auxiliary_upper: Array1<f64>,
    ) -> Self {
        let expected = auxiliary0.len();
        assert_eq!(
            expected,
            auxiliary_lower.len(),
            "auxiliary lower bound length mismatch"
        );
        assert_eq!(
            expected,
            auxiliary_upper.len(),
            "auxiliary upper bound length mismatch"
        );
        let sanitized = Self::sanitize_rho_seed(auxiliary0, &auxiliary_lower, &auxiliary_upper);
        self.auxiliary0 = sanitized;
        self.auxiliary_lower = auxiliary_lower;
        self.auxiliary_upper = auxiliary_upper;
        self
    }

    /// Length of the ρ segment.
    pub(crate) fn rho_dim(&self) -> usize {
        self.rho0.len()
    }

    /// Length of the log-κ segment.
    pub(crate) fn log_kappa_dim(&self) -> usize {
        self.log_kappa0.len()
    }

    /// Length of the auxiliary segment.
    pub(crate) fn auxiliary_dim(&self) -> usize {
        self.auxiliary0.len()
    }

    /// Lays out `[rho | log_kappa | auxiliary]` in one vector, using the seed
    /// segment lengths as the layout for all three callers.
    fn concat_theta_segments(
        &self,
        rho: &Array1<f64>,
        log_kappa: &SpatialLogKappaCoords,
        auxiliary: &Array1<f64>,
    ) -> Array1<f64> {
        let rho_end = self.rho_dim();
        let kappa_end = rho_end + self.log_kappa_dim();
        let mut out = Array1::<f64>::zeros(kappa_end + self.auxiliary_dim());
        out.slice_mut(s![..rho_end]).assign(rho);
        out.slice_mut(s![rho_end..kappa_end])
            .assign(log_kappa.as_array());
        out.slice_mut(s![kappa_end..]).assign(auxiliary);
        out
    }

    /// Initial θ: `[rho0 | log_kappa0 | auxiliary0]`.
    pub(crate) fn theta0(&self) -> Array1<f64> {
        self.concat_theta_segments(&self.rho0, &self.log_kappa0, &self.auxiliary0)
    }

    /// Lower bounds in θ layout.
    pub(crate) fn lower(&self) -> Array1<f64> {
        self.concat_theta_segments(&self.rho_lower, &self.log_kappa_lower, &self.auxiliary_lower)
    }

    /// Upper bounds in θ layout.
    pub(crate) fn upper(&self) -> Array1<f64> {
        self.concat_theta_segments(&self.rho_upper, &self.log_kappa_upper, &self.auxiliary_upper)
    }

    /// Number of log-κ coordinates owned by each spatial term, as an owned vector.
    pub(crate) fn log_kappa_dims_per_term(&self) -> Vec<usize> {
        self.log_kappa0.dims_per_term().to_vec()
    }
}
/// Per-block incremental realizers plus a one-entry memo of the last θ applied
/// to them and the values evaluated there.
struct ExactJointDesignCache<'d> {
// One incremental realizer per block.
realizers: Vec<FrozenTermCollectionIncrementalRealizer<'d>>,
// For each block, the smooth-term indices that receive log-κ coordinates.
block_term_indices: Vec<Vec<usize>>,
// θ the realizers were last brought to; `None` until the first successful update.
current_theta: Option<Array1<f64>>,
// Cost memoized for `current_theta`, if any.
last_cost: Option<f64>,
// Full evaluation (cost, gradient, Hessian result) memoized for `current_theta`, if any.
last_eval: Option<(
f64,
Array1<f64>,
crate::solver::rho_optimizer::HessianResult,
)>,
// Length of the leading ρ segment of θ.
rho_dim: usize,
// Log-κ coordinate count per spatial term, across all blocks.
all_dims: Vec<usize>,
// Sum of `all_dims`.
log_kappa_dim: usize,
// Number of spatial terms in each block (lengths of `block_term_indices`).
block_term_counts: Vec<usize>,
}
impl<'d> ExactJointDesignCache<'d> {
    /// Builds one incremental realizer per `(spec, design, term indices)` block.
    ///
    /// `rho_dim` is the length of the leading ρ segment of θ and `all_dims`
    /// lists the log-κ coordinate count of every spatial term across blocks.
    ///
    /// # Errors
    /// Propagates the first realizer construction failure.
    fn new(
        data: ArrayView2<'d, f64>,
        blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)>,
        rho_dim: usize,
        all_dims: Vec<usize>,
    ) -> Result<Self, String> {
        let n_blocks = blocks.len();
        let mut realizers = Vec::with_capacity(n_blocks);
        let mut block_term_indices = Vec::with_capacity(n_blocks);
        let mut block_term_counts = Vec::with_capacity(n_blocks);
        for (spec, design, terms) in blocks {
            block_term_counts.push(terms.len());
            block_term_indices.push(terms);
            realizers.push(FrozenTermCollectionIncrementalRealizer::new(
                data, spec, design,
            )?);
        }
        Ok(Self {
            realizers,
            block_term_indices,
            current_theta: None,
            last_cost: None,
            last_eval: None,
            rho_dim,
            log_kappa_dim: all_dims.iter().sum(),
            all_dims,
            block_term_counts,
        })
    }

    /// Brings every realizer to the log-κ coordinates encoded in `theta`.
    ///
    /// Returns immediately when `theta` matches the θ already applied.
    /// Otherwise the log-κ segment (after the first `rho_dim` entries) is split
    /// block by block and applied; on success `theta` becomes the current θ and
    /// the memoized cost / evaluation are cleared.
    ///
    /// # Errors
    /// Fails when `theta` is shorter than `rho_dim + log_kappa_dim` or when a
    /// realizer rejects its coordinates. After a realizer failure the cache
    /// holds no current θ and no memoized values, so the next call re-applies
    /// its θ in full instead of trusting partially updated designs.
    fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
        if self
            .current_theta
            .as_ref()
            .is_some_and(|cached| theta_values_match(cached, theta))
        {
            return Ok(());
        }
        let t_ensure = std::time::Instant::now();
        let kappa_theta_len = self.rho_dim + self.log_kappa_dim;
        if theta.len() < kappa_theta_len {
            return Err(SmoothError::dimension_mismatch(format!(
                "exact-joint theta length mismatch: got {}, expected at least {} (rho_dim={}, log_kappa_dim={})",
                theta.len(),
                kappa_theta_len,
                self.rho_dim,
                self.log_kappa_dim
            ))
            .into());
        }
        let theta_kappa = theta.slice(s![..kappa_theta_len]).to_owned();
        let full_log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
            &theta_kappa,
            self.rho_dim,
            self.all_dims.clone(),
        );
        // Drop the memo before the first realizer is touched. If a later block
        // fails, earlier blocks have already been rebuilt, and the previous θ
        // must no longer be reported as the state of the cached designs.
        self.current_theta = None;
        self.last_cost = None;
        self.last_eval = None;
        let n = self.realizers.len();
        let mut remaining = full_log_kappa;
        for block_idx in 0..n {
            let count = self.block_term_counts[block_idx];
            if block_idx < n - 1 {
                // Peel this block's terms off the front of the coordinates.
                let (block_lk, rest) = remaining.split_at(count);
                self.realizers[block_idx]
                    .apply_log_kappa(&block_lk, &self.block_term_indices[block_idx])?;
                remaining = rest;
            } else {
                // The last block takes whatever is left.
                self.realizers[block_idx]
                    .apply_log_kappa(&remaining, &self.block_term_indices[block_idx])?;
            }
        }
        log::info!(
            "[STAGE] ensure_theta (n-block, {} blocks, {} realizers): {:.3}s",
            n,
            self.realizers.len(),
            t_ensure.elapsed().as_secs_f64(),
        );
        self.current_theta = Some(theta.clone());
        Ok(())
    }

    impl_exact_joint_theta_memo!();

    /// Memoizes `cost` for `theta`, but only when `theta` is the θ currently
    /// applied to the realizers; otherwise the call is a no-op.
    fn store_cost_only(&mut self, theta: &Array1<f64>, cost: f64) {
        if self
            .current_theta
            .as_ref()
            .is_some_and(|cached| theta_values_match(cached, theta))
        {
            self.last_cost = Some(cost);
        }
    }

    /// Borrowed specs of all blocks, in block order.
    fn specs(&self) -> Vec<&TermCollectionSpec> {
        self.realizers.iter().map(|r| r.spec()).collect()
    }

    /// Borrowed designs of all blocks, in block order.
    fn designs(&self) -> Vec<&TermCollectionDesign> {
        self.realizers.iter().map(|r| r.design()).collect()
    }

    /// Wrapping sum of the per-realizer design revision counters.
    fn design_revision(&self) -> u64 {
        self.realizers
            .iter()
            .fold(0u64, |acc, r| acc.wrapping_add(r.design_revision()))
    }
}
pub(crate) fn seed_risk_profile_for_likelihood_family(
family: &LikelihoodSpec,
) -> crate::seeding::SeedRiskProfile {
match &family.response {
ResponseFamily::Gaussian => crate::seeding::SeedRiskProfile::Gaussian,
ResponseFamily::RoystonParmar => crate::seeding::SeedRiskProfile::Survival,
ResponseFamily::Binomial
| ResponseFamily::Poisson
| ResponseFamily::Tweedie { .. }
| ResponseFamily::NegativeBinomial { .. }
| ResponseFamily::Beta { .. }
| ResponseFamily::Gamma => crate::seeding::SeedRiskProfile::GeneralizedLinear,
}
}
// NOTE(review): not referenced in this part of the file; the name suggests an
// upper limit on the θ dimension for second-order outer steps — confirm at the
// use sites.
const EXACT_JOINT_SECOND_ORDER_THETA_CAP: usize = 8;
/// Seed-generation settings for the exact-joint outer problem.
///
/// Starts from the default `SeedConfig` with the given risk profile and
/// `auxiliary_dim` trailing auxiliary coordinates, then applies the per-profile
/// seed count, seed budget and (for non-Gaussian profiles) the screening
/// inner-iteration limit.
fn exact_joint_seed_config(
    risk_profile: crate::seeding::SeedRiskProfile,
    auxiliary_dim: usize,
) -> crate::seeding::SeedConfig {
    let base = crate::seeding::SeedConfig {
        risk_profile,
        num_auxiliary_trailing: auxiliary_dim,
        ..Default::default()
    };
    match risk_profile {
        // Gaussian keeps the default screening iteration limit.
        crate::seeding::SeedRiskProfile::Gaussian => crate::seeding::SeedConfig {
            max_seeds: 4,
            seed_budget: 2,
            ..base
        },
        crate::seeding::SeedRiskProfile::GeneralizedLinear => crate::seeding::SeedConfig {
            max_seeds: 1,
            seed_budget: 1,
            screen_max_inner_iterations: 8,
            ..base
        },
        crate::seeding::SeedRiskProfile::Survival => crate::seeding::SeedConfig {
            max_seeds: 8,
            seed_budget: 4,
            screen_max_inner_iterations: 8,
            ..base
        },
    }
}
#[cfg(test)]
mod exact_joint_seed_config_tests {
    use super::*;

    /// Generalized-linear and survival profiles pin their seed count, seed
    /// budget, screening iteration limit and auxiliary tail length.
    #[test]
    fn exact_joint_marginal_slope_profiles_get_deeper_startup_validation() {
        let glm = exact_joint_seed_config(crate::seeding::SeedRiskProfile::GeneralizedLinear, 2);
        assert_eq!(
            (
                glm.max_seeds,
                glm.seed_budget,
                glm.screen_max_inner_iterations,
                glm.num_auxiliary_trailing,
            ),
            (1, 1, 8, 2)
        );

        let surv = exact_joint_seed_config(crate::seeding::SeedRiskProfile::Survival, 3);
        assert_eq!(
            (
                surv.max_seeds,
                surv.seed_budget,
                surv.screen_max_inner_iterations,
                surv.num_auxiliary_trailing,
            ),
            (8, 4, 8, 3)
        );
    }

    /// The Gaussian profile pins its seed count and budget and leaves the
    /// screening iteration limit at the `SeedConfig` default.
    #[test]
    fn exact_joint_gaussian_keeps_tight_historical_multistart_budget() {
        let default_screen_iters =
            crate::seeding::SeedConfig::default().screen_max_inner_iterations;
        let cfg = exact_joint_seed_config(crate::seeding::SeedRiskProfile::Gaussian, 1);
        assert_eq!((cfg.max_seeds, cfg.seed_budget), (4, 2));
        assert_eq!(cfg.screen_max_inner_iterations, default_screen_iters);
        assert_eq!(cfg.num_auxiliary_trailing, 1);
    }
}
/// Assembles the `OuterProblem` description for the exact-joint multistart
/// outer optimization over θ.
///
/// * `theta0`, `lower`, `upper` – seed and box bounds in θ layout.
/// * `rho_dim` – number of leading entries of `theta0` that are exponentiated
///   to form the heuristic λ seeds (the rest are passed through unchanged).
/// * `auxiliary_dim` – forwarded as the ψ dimension and as the number of
///   trailing auxiliary coordinates of the seed configuration.
/// * `has_constant_curvature` – when set, the ρ ceiling is
///   `crate::estimate::RHO_BOUND` (also written into the upper seed bound)
///   instead of `12.0`.
/// * `profiled_objective_size` – optional `(n_obs, p_cols)`; when present the
///   objective scale, problem size, ARC regularization and operator trust
///   radius are configured.
/// * `screening_cap` – optional shared cap; when present, screening is enabled
///   including the initial ρ.
///
/// # Panics
/// Panics when `rho_dim` exceeds `theta0.len()`.
pub(crate) fn exact_joint_multistart_outer_problem(
    theta0: &Array1<f64>,
    lower: &Array1<f64>,
    upper: &Array1<f64>,
    rho_dim: usize,
    auxiliary_dim: usize,
    n_params: usize,
    gradient: crate::solver::rho_optimizer::Derivative,
    hessian: crate::solver::rho_optimizer::DeclaredHessianForm,
    prefer_gradient_only: bool,
    disable_fixed_point: bool,
    risk_profile: crate::seeding::SeedRiskProfile,
    tolerance: f64,
    max_iter: usize,
    bfgs_step_cap: Option<f64>,
    bfgs_step_cap_psi: Option<f64>,
    screening_cap: Option<Arc<AtomicUsize>>,
    profiled_objective_size: Option<(usize, usize)>,
    has_constant_curvature: bool,
) -> crate::solver::rho_optimizer::OuterProblem {
    // Heuristic λ seeds: the ρ prefix is mapped through exp, the tail is kept.
    let mut heuristic_lambdas = theta0.to_vec();
    heuristic_lambdas[..rho_dim]
        .iter_mut()
        .for_each(|entry| *entry = entry.exp());

    let rho_ceiling = if has_constant_curvature {
        crate::estimate::RHO_BOUND
    } else {
        12.0
    };

    let mut seed_config = exact_joint_seed_config(risk_profile, auxiliary_dim);
    if has_constant_curvature {
        // Only the upper seed bound is widened; the lower bound is untouched.
        seed_config.bounds.1 = rho_ceiling;
    }

    let base = crate::solver::rho_optimizer::OuterProblem::new(n_params)
        .with_gradient(gradient)
        .with_hessian(hessian)
        .with_prefer_gradient_only(prefer_gradient_only)
        .with_disable_fixed_point(disable_fixed_point)
        .with_fallback_policy(crate::solver::rho_optimizer::FallbackPolicy::Automatic)
        .with_psi_dim(auxiliary_dim)
        .with_tolerance(tolerance)
        .with_max_iter(max_iter)
        .with_bounds(lower.clone(), upper.clone())
        .with_initial_rho(theta0.clone())
        .with_bfgs_step_cap(bfgs_step_cap)
        .with_bfgs_step_cap_psi(bfgs_step_cap_psi)
        .with_seed_config(seed_config)
        .with_rho_bound(rho_ceiling)
        .with_heuristic_lambdas(heuristic_lambdas);

    let sized = match profiled_objective_size {
        Some((n_obs, p_cols)) => base
            .with_objective_scale(Some(n_obs as f64))
            .with_problem_size(n_obs, p_cols)
            .with_arc_initial_regularization(Some(0.25))
            .with_operator_initial_trust_radius(Some(4.0)),
        None => base,
    };

    match screening_cap {
        Some(cap) => sized
            .with_screening_cap(cap)
            .with_screen_initial_rho(true),
        None => sized,
    }
}
/// Returns `true` when `message` contains one of the known failure texts that
/// this check classifies as recoverable. Matching is a case-sensitive
/// substring search.
fn kappa_phase_failure_is_fixed_kappa_recoverable(message: &str) -> bool {
    const RECOVERABLE_MARKERS: [&str; 3] = [
        "no candidate seeds passed outer startup validation",
        "joint hyper rho dimension mismatch",
        "objective returned a non-finite cost",
    ];
    RECOVERABLE_MARKERS
        .iter()
        .any(|marker| message.contains(marker))
}
pub fn optimize_spatial_length_scale_exact_joint<FitOut, FitFn, ExactFn, ExactEfsFn, SeedFn>(
data: ArrayView2<'_, f64>,
block_specs: &[TermCollectionSpec],
block_term_indices: &[Vec<usize>],
kappa_options: &SpatialLengthScaleOptimizationOptions,
joint_setup: &ExactJointHyperSetup,
seed_risk_profile: crate::seeding::SeedRiskProfile,
analytic_joint_gradient_available: bool,
analytic_joint_hessian_available: bool,
disable_fixed_point: bool,
screening_cap: Option<Arc<AtomicUsize>>,
outer_derivative_policy: crate::families::custom_family::OuterDerivativePolicy,
mut fit_fn: FitFn,
mut exact_fn: ExactFn,
mut exact_efs_fn: ExactEfsFn,
mut seed_inner_beta_fn: SeedFn,
) -> Result<SpatialLengthScaleOptimizationResult<FitOut>, String>
where
FitOut: Clone,
FitFn: FnMut(
&Array1<f64>,
&[TermCollectionSpec],
&[TermCollectionDesign],
) -> Result<FitOut, String>,
ExactFn: FnMut(
&Array1<f64>,
&[TermCollectionSpec],
&[TermCollectionDesign],
crate::solver::estimate::reml::reml_outer_engine::EvalMode,
&crate::outer_subsample::RowSet,
) -> Result<
(
f64,
Array1<f64>,
crate::solver::rho_optimizer::HessianResult,
),
String,
>,
ExactEfsFn: FnMut(
&Array1<f64>,
&[TermCollectionSpec],
&[TermCollectionDesign],
) -> Result<crate::solver::rho_optimizer::EfsEval, String>,
SeedFn:
FnMut(&Array1<f64>) -> Result<crate::solver::rho_optimizer::SeedOutcome, EstimationError>,
{
let n_blocks = block_specs.len();
if block_term_indices.len() != n_blocks {
return Err(SmoothError::dimension_mismatch(format!(
"block_specs ({}) and block_term_indices ({}) length mismatch",
n_blocks,
block_term_indices.len()
))
.into());
}
let log_kappa_dim = joint_setup.log_kappa_dim();
log::warn!(
"[OUTER-FD-AUDIT/spatial-exact-joint] driver entry: aux_dim={} log_kappa_dim={} kappa_enabled={} rho_dim={} theta0_len={}",
joint_setup.auxiliary_dim(),
log_kappa_dim,
kappa_options.enabled,
joint_setup.rho_dim(),
joint_setup.theta0().len()
);
if joint_setup.auxiliary_dim() == 0 && (!kappa_options.enabled || log_kappa_dim == 0) {
log::warn!(
"[OUTER-FD-AUDIT/spatial-exact-joint] taking FAST path (no outer theta optimization in this driver)"
);
let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
data, block_specs,
)
.map_err(|e| {
format!("failed to build and freeze joint block designs during exact joint kappa optimization: {e}")
})?;
let theta0 = joint_setup.theta0();
let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
let design_refs: Vec<TermCollectionDesign> = designs.clone();
let fit = fit_fn(&theta0, &spec_refs, &design_refs)?;
return Ok(SpatialLengthScaleOptimizationResult {
resolved_specs,
designs,
fit,
timing: None,
});
}
let theta0 = joint_setup.theta0();
let lower = joint_setup.lower();
let upper = joint_setup.upper();
if theta0.len() < log_kappa_dim || lower.len() != theta0.len() || upper.len() != theta0.len() {
return Err(SmoothError::dimension_mismatch(format!(
"invalid exact joint theta setup: theta0={}, lower={}, upper={}, required_log_kappa_dim={}",
theta0.len(),
lower.len(),
upper.len(),
log_kappa_dim
))
.into());
}
let rho_dim = joint_setup.rho_dim();
let all_dims = joint_setup.log_kappa_dims_per_term();
let (boot_designs, best_specs) = build_term_collection_designs_and_freeze_joint(
data,
block_specs,
)
.map_err(|e| {
format!(
"failed to build and freeze joint block designs during exact joint kappa bootstrap: {e}"
)
})?;
let policy_hessian_form = outer_derivative_policy.declared_hessian_form();
let analytic_outer_hessian_available = analytic_joint_hessian_available
&& matches!(
policy_hessian_form,
crate::solver::rho_optimizer::DeclaredHessianForm::Either
| crate::solver::rho_optimizer::DeclaredHessianForm::Dense
| crate::solver::rho_optimizer::DeclaredHessianForm::Operator { .. }
);
let prefer_gradient_only = !analytic_outer_hessian_available;
let theta_dim = theta0.len();
let psi_dim = theta_dim - rho_dim;
let cache_blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)> = best_specs
.iter()
.zip(boot_designs.iter())
.zip(block_term_indices.iter())
.map(|((spec, design), terms)| (spec.clone(), design.clone(), terms.clone()))
.collect();
struct NBlockExactJointState<'d> {
cache: ExactJointDesignCache<'d>,
}
let mut state = NBlockExactJointState {
cache: ExactJointDesignCache::new(data, cache_blocks, rho_dim, all_dims.clone())?,
};
const KAPPA_PILOT_K: usize = 5_000;
const KAPPA_POLISH_K: usize = 25_000;
const KAPPA_POLISH_TRIGGER_N: usize = 100_000;
let n_total = data.nrows();
let use_staged_kappa = outer_derivative_policy.should_use_staged_kappa(n_total);
if use_staged_kappa {
log::info!(
"[KAPPA-STAGED] auto-engaging pilot+polish schedule: n={} pilot_k={} polish_k={}",
n_total,
KAPPA_PILOT_K,
KAPPA_POLISH_K,
);
}
fn build_uniform_pilot_subsample(
n_total: usize,
k_target: usize,
seed: u64,
) -> crate::outer_subsample::OuterScoreSubsample {
use crate::outer_subsample::OuterScoreSubsample;
let k = k_target.min(n_total);
if k == 0 || n_total == 0 {
return OuterScoreSubsample::from_uniform_inclusion_mask(Vec::new(), n_total, seed);
}
let mut mask: Vec<usize> = Vec::with_capacity(k);
let mut state = seed.wrapping_add(0x9E3779B97F4A7C15);
let splitmix = |s: &mut u64| -> u64 { crate::linalg::utils::splitmix64(s) };
let mut taken = std::collections::HashSet::with_capacity(k);
for j in (n_total - k)..n_total {
let r = (splitmix(&mut state) % (j as u64 + 1)) as usize;
if !taken.insert(r) {
taken.insert(j);
mask.push(j);
} else {
mask.push(r);
}
}
mask.sort_unstable();
mask.dedup();
OuterScoreSubsample::from_uniform_inclusion_mask(mask, n_total, seed)
}
let current_row_set: std::cell::RefCell<crate::outer_subsample::RowSet> = if use_staged_kappa {
let pilot = build_uniform_pilot_subsample(n_total, KAPPA_PILOT_K, n_total as u64);
std::cell::RefCell::new(crate::outer_subsample::RowSet::Subsample {
rows: std::sync::Arc::clone(&pilot.rows),
n_full: n_total,
})
} else {
std::cell::RefCell::new(crate::outer_subsample::RowSet::All)
};
let exact_fn_cell = std::cell::RefCell::new(&mut exact_fn);
let exact_efs_fn_cell = std::cell::RefCell::new(&mut exact_efs_fn);
use std::cell::Cell;
let kphase_cost_calls: Cell<usize> = Cell::new(0);
let kphase_cost_total_s: Cell<f64> = Cell::new(0.0);
let kphase_eval_calls: Cell<usize> = Cell::new(0);
let kphase_eval_total_s: Cell<f64> = Cell::new(0.0);
let kphase_efs_calls: Cell<usize> = Cell::new(0);
let kphase_efs_total_s: Cell<f64> = Cell::new(0.0);
let kphase_optim_start = std::time::Instant::now();
let kphase_log_kappa_dim = log_kappa_dim;
let kphase_log_norms = |theta: &Array1<f64>| -> (f64, f64) {
let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
let log_kappa_norm = if kphase_log_kappa_dim > 0 && theta.len() >= kphase_log_kappa_dim {
let start = theta.len() - kphase_log_kappa_dim;
theta.iter().skip(start).map(|v| v * v).sum::<f64>().sqrt()
} else {
0.0
};
(theta_norm, log_kappa_norm)
};
use crate::solver::rho_optimizer::{
DeclaredHessianForm, Derivative, OuterEval, OuterEvalOrder,
};
let joint_p_cols: usize = boot_designs
.iter()
.map(|d| d.design.ncols())
.sum::<usize>()
.max(1);
let problem = exact_joint_multistart_outer_problem(
&theta0,
&lower,
&upper,
rho_dim,
psi_dim,
theta_dim,
if analytic_joint_gradient_available {
Derivative::Analytic
} else {
Derivative::Unavailable
},
if analytic_outer_hessian_available {
DeclaredHessianForm::Either
} else {
DeclaredHessianForm::Unavailable
},
prefer_gradient_only,
disable_fixed_point,
seed_risk_profile,
kappa_options.rel_tol.max(1e-6),
kappa_options.max_outer_iter.max(1),
Some(5.0),
Some(kappa_options.log_step.clamp(0.25, 1.0)),
screening_cap.clone(),
Some((n_total, joint_p_cols)),
block_specs
.iter()
.any(|s| !constant_curvature_term_indices(s).is_empty()),
);
/// Snapshots the specs currently held by `cache` as an owned vector
/// (one clone per spec), so callers can pass them on without keeping the
/// cache borrowed.
fn collect_specs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionSpec> {
    let mut owned_specs = Vec::new();
    owned_specs.extend(cache.specs().into_iter().cloned());
    owned_specs
}
/// Snapshots the designs currently held by `cache` as an owned vector
/// (one clone per design), so callers can pass them on without keeping the
/// cache borrowed.
fn collect_designs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionDesign> {
    let mut owned_designs = Vec::new();
    owned_designs.extend(cache.designs().into_iter().cloned());
    owned_designs
}
let result = {
let eval_outer = |ctx: &mut &mut NBlockExactJointState<'_>,
theta: &Array1<f64>,
order: OuterEvalOrder|
-> Result<OuterEval, EstimationError> {
if let Some((cost, grad, hess)) = ctx.cache.memoized_eval(theta) {
let cached_satisfies_order = match order {
OuterEvalOrder::Value => true,
OuterEvalOrder::ValueAndGradient => true,
OuterEvalOrder::ValueGradientHessian => hess.is_analytic(),
};
if cached_satisfies_order {
if !cost.is_finite() {
return Ok(OuterEval::infeasible(theta.len()));
}
if grad.iter().any(|v| !v.is_finite()) {
return Ok(OuterEval::infeasible(theta.len()));
}
return Ok(OuterEval {
cost,
gradient: grad,
hessian: hess,
inner_beta_hint: None,
});
}
}
if let Err(err) = ctx.cache.ensure_theta(theta) {
log::warn!(
"[OUTER] n-block exact-joint spatial: ensure_theta failed during gradient evaluation: {err}"
);
return Ok(OuterEval::infeasible(theta.len()));
}
let design_revision = Some(ctx.cache.design_revision());
let specs = collect_specs(&ctx.cache);
let designs = collect_designs(&ctx.cache);
let clamped = outer_derivative_policy.order_for_evaluation(order);
let need_hessian = matches!(clamped, OuterEvalOrder::ValueGradientHessian)
&& analytic_outer_hessian_available;
let eval_mode = if need_hessian {
crate::solver::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
} else {
crate::solver::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
};
let t0 = std::time::Instant::now();
let result = {
let row_set_borrow = current_row_set.borrow();
(*exact_fn_cell.borrow_mut())(theta, &specs, &designs, eval_mode, &row_set_borrow)
};
let elapsed_s = t0.elapsed().as_secs_f64();
kphase_eval_calls.set(kphase_eval_calls.get() + 1);
kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
log::info!(
"[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
kphase_eval_calls.get(),
order,
design_revision,
theta_norm,
log_kappa_norm,
elapsed_s,
);
match result {
Ok((cost, grad, hess)) => {
ctx.cache.store_eval((cost, grad.clone(), hess.clone()));
if !cost.is_finite() {
return Ok(OuterEval::infeasible(theta.len()));
}
if grad.iter().any(|v| !v.is_finite()) {
return Ok(OuterEval::infeasible(theta.len()));
}
Ok(OuterEval {
cost,
gradient: grad,
hessian: hess,
inner_beta_hint: None,
})
}
Err(err) => {
log::warn!(
"[OUTER] n-block exact-joint spatial: exact evaluation failed: {err}"
);
Ok(OuterEval::infeasible(theta.len()))
}
}
};
let obj = problem.build_objective_with_eval_order(
&mut state,
|ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
if let Some(cost) = ctx.cache.memoized_cost(theta) {
return Ok(cost);
}
if let Err(err) = ctx.cache.ensure_theta(theta) {
log::warn!(
"[OUTER] n-block exact-joint spatial: ensure_theta failed during cost evaluation: {err}"
);
return Ok(f64::INFINITY);
}
let design_revision = Some(ctx.cache.design_revision());
let specs = collect_specs(&ctx.cache);
let designs = collect_designs(&ctx.cache);
let t0 = std::time::Instant::now();
let result = {
let row_set_borrow = current_row_set.borrow();
(*exact_fn_cell.borrow_mut())(
theta,
&specs,
&designs,
crate::solver::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
&row_set_borrow,
)
};
let elapsed_s = t0.elapsed().as_secs_f64();
kphase_cost_calls.set(kphase_cost_calls.get() + 1);
kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
log::info!(
"[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
kphase_cost_calls.get(),
design_revision,
theta_norm,
log_kappa_norm,
elapsed_s,
);
match result {
Ok((cost, _grad, _hess)) => {
ctx.cache.store_cost_only(theta, cost);
Ok(cost)
}
Err(err) => {
log::warn!(
"[OUTER] n-block exact-joint spatial: exact cost evaluation failed: {err}"
);
Ok(f64::INFINITY)
}
}
},
|ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
eval_outer(
ctx,
theta,
if analytic_outer_hessian_available {
OuterEvalOrder::ValueGradientHessian
} else {
OuterEvalOrder::ValueAndGradient
},
)
},
|ctx: &mut &mut NBlockExactJointState<'_>,
theta: &Array1<f64>,
order: OuterEvalOrder| { eval_outer(ctx, theta, order) },
None::<fn(&mut &mut NBlockExactJointState<'_>)>,
Some(
|ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
ctx.cache
.ensure_theta(theta)
.map_err(EstimationError::InvalidInput)?;
let design_revision = Some(ctx.cache.design_revision());
let specs = collect_specs(&ctx.cache);
let designs = collect_designs(&ctx.cache);
let t0 = std::time::Instant::now();
let eval_result = (*exact_efs_fn_cell.borrow_mut())(
theta,
&specs,
&designs,
);
let elapsed_s = t0.elapsed().as_secs_f64();
kphase_efs_calls.set(kphase_efs_calls.get() + 1);
kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
log::info!(
"[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
kphase_efs_calls.get(),
design_revision,
theta_norm,
log_kappa_norm,
elapsed_s,
);
let eval = eval_result.map_err(EstimationError::RemlOptimizationFailed)?;
Ok(eval)
},
),
);
let mut obj = obj.with_seed_inner_state(
move |_ctx: &mut &mut NBlockExactJointState<'_>, beta: &Array1<f64>| {
(seed_inner_beta_fn)(beta)
},
);
match problem.run(&mut obj, "n-block exact-joint spatial") {
Ok(result) => result,
Err(e) => {
let message = e.to_string();
if kappa_phase_failure_is_fixed_kappa_recoverable(&message) {
drop(obj);
log::warn!(
"[KAPPA-PHASE] length-scale optimization could not validate any seed \
({message}); falling back to a FIXED bootstrap κ (skipping κ \
optimization) and fitting there — a real model at the initial \
length-scale rather than raising (gam#787/#860)."
);
let (designs, resolved_specs) =
build_term_collection_designs_and_freeze_joint(data, block_specs).map_err(
|build_err| {
format!(
"fixed-κ fallback failed to build and freeze joint block \
designs after κ optimization could not validate a seed \
({message}): {build_err}"
)
},
)?;
let fixed_theta0 = joint_setup.theta0();
let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
let design_refs: Vec<TermCollectionDesign> = designs.clone();
let fit = fit_fn(&fixed_theta0, &spec_refs, &design_refs)?;
return Ok(SpatialLengthScaleOptimizationResult {
resolved_specs,
designs,
fit,
timing: None,
});
}
return Err(message);
}
}
};
let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
log::info!(
"[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} optim_total_s={:.4}",
kphase_log_kappa_dim,
kphase_cost_calls.get(),
kphase_cost_total_s.get(),
kphase_eval_calls.get(),
kphase_eval_total_s.get(),
kphase_efs_calls.get(),
kphase_efs_total_s.get(),
kphase_total_s,
);
let timing = SpatialLengthScaleOptimizationTiming {
log_kappa_dim: kphase_log_kappa_dim,
cost_calls: kphase_cost_calls.get(),
cost_total_s: kphase_cost_total_s.get(),
eval_calls: kphase_eval_calls.get(),
eval_total_s: kphase_eval_total_s.get(),
efs_calls: kphase_efs_calls.get(),
efs_total_s: kphase_efs_total_s.get(),
slow_path_resets: 0,
design_revision_delta: 0,
nfree_miss_shape: 0,
nfree_miss_value: 0,
nfree_miss_gradient: 0,
nfree_miss_penalty: 0,
nfree_miss_revision: 0,
nfree_miss_second_order: 0,
nfree_miss_other: 0,
optim_total_s: kphase_total_s,
};
let theta_star = result.rho;
if use_staged_kappa && n_total >= KAPPA_POLISH_TRIGGER_N {
let polish = build_uniform_pilot_subsample(
n_total,
KAPPA_POLISH_K,
(n_total as u64).wrapping_add(0xA5A5A5A5),
);
*current_row_set.borrow_mut() = crate::outer_subsample::RowSet::Subsample {
rows: std::sync::Arc::clone(&polish.rows),
n_full: n_total,
};
log::info!(
"[KAPPA-STAGED] rotating to polish subsample: k={} at theta_star",
polish.rows.len(),
);
state.cache.ensure_theta(&theta_star)?;
let (polish_cost, polish_grad, _) = {
let specs = collect_specs(&state.cache);
let designs = collect_designs(&state.cache);
let row_set_borrow = current_row_set.borrow();
exact_fn(
&theta_star,
&specs,
&designs,
crate::solver::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
&row_set_borrow,
)?
};
if !polish_cost.is_finite() || polish_grad.iter().any(|value| !value.is_finite()) {
return Err(
"polish subsample exact-joint evaluation produced non-finite objective pieces"
.to_string(),
);
}
}
*current_row_set.borrow_mut() = crate::outer_subsample::RowSet::All;
if use_staged_kappa {
log::info!(
"[KAPPA-STAGED] rotating to full data for final coefficient fit (n={})",
n_total,
);
}
state.cache.ensure_theta(&theta_star)?;
let resolved_specs: Vec<TermCollectionSpec> = collect_specs(&state.cache);
let designs: Vec<TermCollectionDesign> = collect_designs(&state.cache);
let fit = fit_fn(&theta_star, &resolved_specs, &designs)?;
for spec in &resolved_specs {
log_spatial_aniso_scales(spec);
}
Ok(SpatialLengthScaleOptimizationResult {
resolved_specs,
designs,
fit,
timing: Some(timing),
})
}
/// Jointly optimizes the log smoothing parameters and the latent coordinates
/// of a single-block term collection, then refits at the optimum.
///
/// The outer parameter vector `theta` is laid out as
/// `[rho (rho_dim) | flat latent values | analytic-penalty rho | direct hypers]`:
/// * `rho` is seeded with `ln(best.fit.lambdas)`;
/// * the latent slice is seeded from `latent.values` (and overwritten by any
///   persisted latent values the evaluator can load);
/// * the analytic-penalty slots are left at their zero initialization;
/// * the direct hypers come from `latent_coord_initial_direct_hypers`.
///
/// Every coordinate is boxed to `[-12, 12]`, except the latent slice, which
/// uses `±(max(1, max|latent value|) + 10)`.
///
/// After the optimizer converges, the optimized latent values are written
/// into the `latent.feature_cols` columns of a copy of `data`, the model is
/// refitted with `exp(rho*)` as lambdas, the latent values are persisted via
/// the evaluator, and the fit's `reml_score` / `penalized_objective` are
/// replaced by the optimizer's final objective value.
///
/// # Errors
/// Returns an error when the latent block is empty, when the design cache or
/// evaluator cannot be built, when the optimizer fails or does not converge,
/// or when the final refit fails.
fn try_exact_joint_latent_coord_optimization(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    best: &FittedTermCollection,
    family: LikelihoodSpec,
    options: &FitOptions,
    latent: &StandardLatentCoordConfig,
) -> Result<FittedTermCollectionWithSpec, EstimationError> {
    use crate::solver::rho_optimizer::{
        DeclaredHessianForm, Derivative, OuterEval, OuterEvalOrder,
    };
    let rho_dim = best.fit.lambdas.len();
    let latent_flat_dim = latent.values.len();
    if latent_flat_dim == 0 {
        crate::bail_invalid_estim!(
            "latent-coordinate optimization requires a non-empty latent block"
        );
    }
    let direct_hypers =
        latent_coord_initial_direct_hypers(latent.values.id_mode(), latent.values.latent_dim())?;
    let analytic_rho_count = latent
        .analytic_penalties
        .as_ref()
        .map_or(0, |registry| registry.total_rho_count());
    // Everything after the rho block: latent coordinates, analytic-penalty
    // rho slots, and direct hypers.
    let latent_coord_ext_dim = latent_flat_dim + analytic_rho_count + direct_hypers.len();
    let mut theta0 = Array1::<f64>::zeros(rho_dim + latent_coord_ext_dim);
    theta0
        .slice_mut(s![..rho_dim])
        .assign(&best.fit.lambdas.mapv(f64::ln));
    theta0
        .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
        .assign(latent.values.as_flat());
    if !direct_hypers.is_empty() {
        let direct_start = rho_dim + latent_flat_dim + analytic_rho_count;
        theta0
            .slice_mut(s![direct_start..direct_start + direct_hypers.len()])
            .assign(&direct_hypers);
    }
    let mut lower = Array1::<f64>::from_elem(theta0.len(), -12.0);
    let mut upper = Array1::<f64>::from_elem(theta0.len(), 12.0);
    // Latent coordinates get a box sized from their initial magnitude rather
    // than the generic +/-12 box used for the remaining coordinates.
    let latent_bound = latent
        .values
        .as_flat()
        .iter()
        .fold(1.0_f64, |acc, &v| acc.max(v.abs()))
        + 10.0;
    for axis in rho_dim..rho_dim + latent_flat_dim {
        lower[axis] = -latent_bound;
        upper[axis] = latent_bound;
    }
    /// Mutable state shared by the cost / gradient / EFS callbacks.
    struct LatentJointContext<'d> {
        rho_dim: usize,
        cache: SingleBlockLatentCoordDesignCache,
        evaluator: crate::estimate::ExternalJointHyperEvaluator<'d>,
    }
    impl<'d> LatentJointContext<'d> {
        /// Evaluates cost, gradient and Hessian result at `theta`, adding the
        /// analytic-penalty and latent-id objective contributions on top of
        /// the joint REML evaluation, and memoizes the outcome in the cache.
        fn eval_full(
            &mut self,
            theta: &Array1<f64>,
            order: OuterEvalOrder,
        ) -> Result<
            (
                f64,
                Array1<f64>,
                crate::solver::rho_optimizer::HessianResult,
            ),
            EstimationError,
        > {
            // A memoized evaluation is returned as-is, whatever `order` asks for.
            if let Some(eval) = self.cache.memoized_eval(theta) {
                return Ok(eval);
            }
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
            let hyper_dirs = self
                .cache
                .hyper_dirs()
                .map_err(EstimationError::InvalidInput)?;
            let design_revision = Some(self.cache.design_revision());
            let registry_for_key = self.cache.analytic_penalties();
            self.evaluator
                .set_analytic_penalty_registry(registry_for_key.as_deref());
            let mut eval = evaluate_joint_reml_outer_eval_at_theta(
                &mut self.evaluator,
                self.cache.design(),
                theta,
                self.rho_dim,
                hyper_dirs,
                None,
                order,
                design_revision,
            )?;
            let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
            if let Some(registry) = registry_for_key {
                // Work on a private copy so the weight schedule for the current
                // outer iteration does not mutate the cached registry.
                let mut registry = registry.as_ref().clone();
                registry.apply_weight_schedules(
                    crate::solver::estimate::reml::outer_eval::current_outer_iter() as usize,
                );
                add_analytic_penalty_objective_to_eval(
                    theta,
                    self.rho_dim,
                    latent.as_ref(),
                    &registry,
                    &mut eval,
                )?;
            }
            add_latent_id_objective_to_eval(
                theta,
                self.rho_dim,
                self.cache.analytic_penalty_rho_count(),
                latent.as_ref(),
                &mut eval,
            )?;
            self.cache.store_eval(eval.clone());
            Ok(eval)
        }
        /// Evaluates the EFS quantities at `theta`, adding the analytic-penalty
        /// cost and its gradient entries at the psi indices when a registry
        /// is present.
        fn eval_efs(
            &mut self,
            theta: &Array1<f64>,
        ) -> Result<crate::solver::rho_optimizer::EfsEval, EstimationError> {
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
            let hyper_dirs = self
                .cache
                .hyper_dirs()
                .map_err(EstimationError::InvalidInput)?;
            let registry_for_key = self.cache.analytic_penalties();
            self.evaluator
                .set_analytic_penalty_registry(registry_for_key.as_deref());
            let mut efs = evaluate_joint_reml_efs_at_theta(
                &mut self.evaluator,
                self.cache.design(),
                theta,
                self.rho_dim,
                hyper_dirs,
                None,
                Some(self.cache.design_revision()),
            )?;
            if let Some(registry) = registry_for_key {
                let mut registry = registry.as_ref().clone();
                registry.apply_weight_schedules(
                    crate::solver::estimate::reml::outer_eval::current_outer_iter() as usize,
                );
                let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
                let contribution = analytic_penalty_objective_contribution(
                    theta,
                    self.rho_dim,
                    latent.as_ref(),
                    &registry,
                )?;
                efs.cost += contribution.cost;
                if let (Some(psi_gradient), Some(psi_indices)) =
                    (efs.psi_gradient.as_mut(), efs.psi_indices.as_ref())
                {
                    if psi_gradient.len() != psi_indices.len() {
                        crate::bail_invalid_estim!(
                            "latent-coordinate analytic penalty EFS psi gradient length mismatch: gradient={}, indices={}",
                            psi_gradient.len(),
                            psi_indices.len()
                        );
                    }
                    // `psi_indices` maps each local psi slot to its position in theta.
                    for (local_idx, &theta_idx) in psi_indices.iter().enumerate() {
                        psi_gradient[local_idx] += contribution.gradient[theta_idx];
                    }
                }
            }
            Ok(efs)
        }
        /// Cost-only evaluation at `theta`. Any failure along the way is
        /// reported as `f64::INFINITY` instead of an error.
        fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
            if let Some(cost) = self.cache.memoized_cost(theta) {
                return cost;
            }
            if self.cache.ensure_theta(theta).is_err() {
                return f64::INFINITY;
            }
            let design_revision = Some(self.cache.design_revision());
            let registry_for_key = self.cache.analytic_penalties();
            self.evaluator
                .set_analytic_penalty_registry(registry_for_key.as_deref());
            let result = {
                let design = self.cache.design();
                self.evaluator.evaluate_cost_only(
                    &design.design,
                    &design.penalties,
                    &design.nullspace_dims,
                    design.linear_constraints.clone(),
                    theta,
                    self.rho_dim,
                    None,
                    "latent-coordinate-joint cost-only",
                    design_revision,
                )
            };
            match result {
                Ok(cost) => {
                    let latent = match self.cache.latent() {
                        Ok(latent) => latent,
                        Err(_) => return f64::INFINITY,
                    };
                    let contribution = match latent_id_objective_contribution(
                        theta,
                        self.rho_dim,
                        self.cache.analytic_penalty_rho_count(),
                        latent.as_ref(),
                    ) {
                        Ok(contribution) => contribution,
                        Err(_) => return f64::INFINITY,
                    };
                    let cost = cost + contribution.cost;
                    let cost = if let Some(registry) = registry_for_key {
                        let mut registry = registry.as_ref().clone();
                        registry.apply_weight_schedules(
                            crate::solver::estimate::reml::outer_eval::current_outer_iter()
                                as usize,
                        );
                        match analytic_penalty_objective_contribution(
                            theta,
                            self.rho_dim,
                            latent.as_ref(),
                            &registry,
                        ) {
                            Ok(contribution) => cost + contribution.cost,
                            Err(_) => return f64::INFINITY,
                        }
                    } else {
                        cost
                    };
                    self.cache.store_cost(cost);
                    cost
                }
                Err(_) => f64::INFINITY,
            }
        }
    }
    let mut ctx = LatentJointContext {
        rho_dim,
        cache: SingleBlockLatentCoordDesignCache::new(
            data.to_owned(),
            resolvedspec.clone(),
            best.design.clone(),
            latent,
            rho_dim,
        )
        .map_err(EstimationError::InvalidInput)?,
        evaluator: crate::estimate::ExternalJointHyperEvaluator::new(
            y,
            weights,
            &best.design.design,
            offset,
            &best.design.penalties,
            &external_opts_for_design(&family, &best.design, options),
            "latent-coordinate-joint",
        )?,
    };
    let registry_for_key = ctx.cache.analytic_penalties();
    ctx.evaluator
        .set_analytic_penalty_registry(registry_for_key.as_deref());
    ctx.evaluator
        .set_persistent_latent_values_fingerprint(latent.values.id_mode());
    // If the evaluator has persisted latent values of matching shape, start
    // the latent slice of theta0 from them instead of `latent.values`.
    if let Some(cached_t) = ctx
        .evaluator
        .load_persistent_latent_values(latent.values.n_obs(), latent.values.latent_dim())
    {
        for (dst, src) in theta0
            .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
            .iter_mut()
            .zip(cached_t.iter())
        {
            *dst = *src;
        }
    }
    let problem = exact_joint_multistart_outer_problem(
        &theta0,
        &lower,
        &upper,
        rho_dim,
        latent_coord_ext_dim,
        theta0.len(),
        Derivative::Analytic,
        DeclaredHessianForm::Unavailable,
        false,
        false,
        seed_risk_profile_for_likelihood_family(&family),
        options.tol,
        options.max_iter.max(1),
        Some(5.0),
        Some(0.5),
        None,
        Some((data.nrows(), best.design.design.ncols().max(1))),
        !constant_curvature_term_indices(resolvedspec).is_empty(),
    );
    // Shared adapter turning the context's tuple result into an `OuterEval`.
    let eval_outer = |ctx: &mut &mut LatentJointContext<'_>,
                      theta: &Array1<f64>,
                      order: OuterEvalOrder|
     -> Result<OuterEval, EstimationError> {
        let (cost, gradient, hessian) = ctx.eval_full(theta, order)?;
        Ok(OuterEval {
            cost,
            gradient,
            hessian,
            inner_beta_hint: None,
        })
    };
    let result = {
        let mut obj = problem.build_objective_with_eval_order(
            &mut ctx,
            |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
            |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| {
                eval_outer(ctx, theta, OuterEvalOrder::ValueAndGradient)
            },
            |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
                eval_outer(ctx, theta, order)
            },
            Some(|ctx: &mut &mut LatentJointContext<'_>| {
                ctx.cache.reset();
            }),
            Some(|ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
        );
        problem
            .run(&mut obj, "latent-coordinate joint REML")
            .map_err(|e| {
                EstimationError::InvalidInput(format!(
                    "latent-coordinate joint optimization failed after exhausting strategy fallbacks: {e}"
                ))
            })?
    };
    if !result.converged {
        crate::bail_invalid_estim!(
            "latent-coordinate joint optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
            result.iterations,
            result.final_value,
            result.final_grad_norm_report(),
        );
    }
    let theta_star = result.rho;
    let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
    let mut final_data = data.to_owned();
    let flat_t = theta_star
        .slice(s![rho_dim..rho_dim + latent_flat_dim])
        .to_owned();
    let mut fitted_latent_values =
        Array2::<f64>::zeros((latent.values.n_obs(), latent.values.latent_dim()));
    // The flat latent slice is indexed as `n * latent_dim + axis`; copy it
    // into both the fitted-latent matrix and the data columns it replaces.
    for n in 0..latent.values.n_obs() {
        for axis in 0..latent.values.latent_dim() {
            let value = flat_t[n * latent.values.latent_dim() + axis];
            fitted_latent_values[[n, axis]] = value;
            final_data[[n, latent.feature_cols[axis]]] = value;
        }
    }
    let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
        final_data.view(),
        y,
        weights,
        offset,
        resolvedspec,
        rho_star.as_slice(),
        family,
        options,
    )?;
    ctx.evaluator
        .store_persistent_latent_values(&fitted_latent_values);
    let mut fit = optimized.fit;
    // Report the joint optimizer's final objective rather than the refit's own score.
    fit.reml_score = result.final_value;
    fit.penalized_objective = result.final_value;
    Ok(FittedTermCollectionWithSpec {
        fit,
        design: optimized.design,
        resolvedspec: resolvedspec.clone(),
        adaptive_diagnostics: optimized.adaptive_diagnostics,
        kappa_timing: None,
    })
}
/// Fits `spec` once to obtain a baseline, freezes the spec against the
/// baseline design, and then runs the exact joint latent-coordinate
/// optimization starting from that baseline.
///
/// # Errors
/// Returns an error when `y`, `weights` or `offset` do not have one entry per
/// row of `data`, or when the baseline fit, the spec freeze, or the joint
/// optimization fails.
pub fn fit_term_collectionwith_latent_coord_optimization(
    data: ArrayView2<'_, f64>,
    y: Array1<f64>,
    weights: Array1<f64>,
    offset: Array1<f64>,
    spec: &TermCollectionSpec,
    latent: &StandardLatentCoordConfig,
    family: LikelihoodSpec,
    options: &FitOptions,
) -> Result<FittedTermCollectionWithSpec, EstimationError> {
    let row_count = data.nrows();
    let rows_disagree =
        y.len() != row_count || weights.len() != row_count || offset.len() != row_count;
    if rows_disagree {
        crate::bail_invalid_estim!(
            "fit_term_collectionwith_latent_coord_optimization row mismatch: n={}, y={}, weights={}, offset={}",
            row_count,
            y.len(),
            weights.len(),
            offset.len()
        );
    }

    // Borrow the owned inputs once; the views are reused for both fits.
    let y_view = y.view();
    let weights_view = weights.view();
    let offset_view = offset.view();

    let baseline = fit_term_collection_forspec(
        data,
        y_view,
        weights_view,
        offset_view,
        spec,
        family.clone(),
        options,
    )?;
    let frozen_spec = freeze_term_collection_from_design(spec, &baseline.design)?;
    try_exact_joint_latent_coord_optimization(
        data,
        y_view,
        weights_view,
        offset_view,
        &frozen_spec,
        &baseline,
        family,
        options,
        latent,
    )
}
/// Fits a term collection while optimizing the spatial length scales (κ) of
/// its spatial smooth terms.
///
/// Flow, in order:
/// 1. Validate that `y`, `weights` and `offset` have one entry per data row.
/// 2. If κ optimization is disabled or the spec has no spatial terms, do a
///    plain fit, freeze the spec against the resulting design and return.
/// 3. Validate `kappa_options` (`max_outer_iter`, `log_step`, length-scale
///    bounds).
/// 4. Seed the spec: optional pilot anisotropy initializer for large `n`,
///    response-aware anisotropy seed, and a κ seed for constant-curvature
///    terms.
/// 5. Run a baseline fit with `superseded_fit_options(options)`, freeze the
///    spec and sync anisotropy contrasts from the design metadata.
/// 6. If no spatial terms remain after freezing, refit with the baseline
///    lambdas and return; otherwise run the exact joint spatial optimization
///    and return its result.
///
/// # Errors
/// Returns an error on row-count mismatch, invalid `kappa_options`, or when
/// any fit / freeze / optimization step fails.
pub fn fit_term_collectionwith_spatial_length_scale_optimization(
    data: ArrayView2<'_, f64>,
    y: Array1<f64>,
    weights: Array1<f64>,
    offset: Array1<f64>,
    spec: &TermCollectionSpec,
    family: LikelihoodSpec,
    options: &FitOptions,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<FittedTermCollectionWithSpec, EstimationError> {
    let mut resolvedspec = spec.clone();
    let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
    let n = data.nrows();
    if !(y.len() == n && weights.len() == n && offset.len() == n) {
        crate::bail_invalid_estim!(
            "fit_term_collectionwith_spatial_length_scale_optimization row mismatch: n={}, y={}, weights={}, offset={}",
            n,
            y.len(),
            weights.len(),
            offset.len()
        );
    }
    if !kappa_options.enabled || spatial_terms.is_empty() {
        let out = fit_term_collection_forspec(
            data,
            y.view(),
            weights.view(),
            offset.view(),
            &resolvedspec,
            family,
            options,
        )?;
        let resolvedspec = freeze_term_collection_from_design(&resolvedspec, &out.design)?;
        return Ok(FittedTermCollectionWithSpec {
            fit: out.fit,
            design: out.design,
            resolvedspec,
            adaptive_diagnostics: out.adaptive_diagnostics,
            kappa_timing: None,
        });
    }
    if kappa_options.max_outer_iter == 0 {
        crate::bail_invalid_estim!("spatial kappa optimization requires max_outer_iter >= 1");
    }
    if !(kappa_options.log_step.is_finite() && kappa_options.log_step > 0.0) {
        crate::bail_invalid_estim!("spatial kappa optimization requires log_step > 0");
    }
    if !(kappa_options.min_length_scale.is_finite()
        && kappa_options.max_length_scale.is_finite()
        && kappa_options.min_length_scale > 0.0
        && kappa_options.max_length_scale >= kappa_options.min_length_scale)
    {
        crate::bail_invalid_estim!(
            "spatial kappa optimization requires valid positive length_scale bounds"
        );
    }
    let pilot_threshold = kappa_options.pilot_subsample_threshold;
    // Saturate instead of `* 2`: a very large threshold would otherwise
    // overflow (panic in debug builds, wrap around in release builds).
    let pilot_cutoff = pilot_threshold.saturating_mul(2);
    if pilot_threshold > 0 && n > pilot_cutoff {
        log::info!(
            "[spatial-kappa] n={n} exceeds pilot threshold {}; using pilot geometry only for deterministic anisotropy initialization",
            pilot_cutoff,
        );
        apply_spatial_anisotropy_pilot_initializer(
            data,
            &mut resolvedspec,
            &spatial_terms,
            pilot_threshold,
            kappa_options,
        );
    }
    apply_response_aware_anisotropy_seed(data, y.view(), &mut resolvedspec, &spatial_terms);
    // Pin each constant-curvature term's starting κ to the non-zero seed
    // chosen by the sign scan, when one is available.
    for term_idx in constant_curvature_term_indices(&resolvedspec) {
        if let Some(kappa_seed) =
            select_constant_curvature_kappa_sign_seed(data, y.view(), &resolvedspec, term_idx)
            && kappa_seed != 0.0
            && let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) =
                resolvedspec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis)
        {
            log::info!(
                "[#1464] pinned CC term {term_idx} baseline κ to κ-fair scan value {kappa_seed} \
(raw profiled REML is sign-blind; scan is authoritative for the sign)"
            );
            cc.kappa = kappa_seed;
        }
    }
    let baseline_options = superseded_fit_options(options);
    let best = fit_term_collection_forspec(
        data,
        y.view(),
        weights.view(),
        offset.view(),
        &resolvedspec,
        family.clone(),
        &baseline_options,
    )?;
    resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
    // Recompute the spatial term indices against the frozen spec.
    let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
    sync_aniso_contrasts_from_metadata(&mut resolvedspec, &best.design.smooth);
    if spatial_terms.is_empty() {
        let fitted = fit_term_collection_forspecwith_heuristic_lambdas(
            data,
            y.view(),
            weights.view(),
            offset.view(),
            &resolvedspec,
            best.fit.lambdas.as_slice(),
            family,
            options,
        )?;
        return Ok(FittedTermCollectionWithSpec {
            fit: fitted.fit,
            design: fitted.design,
            resolvedspec,
            adaptive_diagnostics: fitted.adaptive_diagnostics,
            kappa_timing: None,
        });
    }
    let initial_score = fit_score(&best.fit);
    if !initial_score.is_finite() {
        log::debug!("[spatial-kappa] initial profiled score is non-finite");
    }
    let exact_joint = require_successful_spatial_optimization_result(
        initial_score,
        try_exact_joint_spatial_length_scale_optimization(
            data,
            y.view(),
            weights.view(),
            offset.view(),
            &resolvedspec,
            &best,
            family,
            options,
            kappa_options,
            &spatial_terms,
        )
        .map(|opt| {
            // Pair each successful fit with its score for the acceptance check.
            opt.map(|fit| {
                let score = fit_score(&fit.fit);
                (fit, score)
            })
        }),
    )?;
    log_spatial_aniso_scales(&exact_joint.resolvedspec);
    Ok(exact_joint)
}
/// Result bundle returned by `curvature_inference_forspec` for one
/// constant-curvature smooth term.
#[derive(Clone, Debug)]
pub struct CurvatureInference {
/// Index of the smooth term the inference was computed for.
pub term_idx: usize,
/// κ read from the resolved spec for that term (the point estimate the
/// interval and test are centred on).
pub kappa_hat: f64,
/// Interval produced by `profile_ci_walk` around `kappa_hat`.
pub ci: crate::geometry::curvature_estimand::KappaProfileCi,
/// Outcome of `flatness_lr_test` evaluated at `kappa_hat`.
pub flatness: crate::geometry::curvature_estimand::FlatnessTest,
}
/// Builds a profile interval and a flatness test for the curvature κ of the
/// constant-curvature smooth at `term_idx`.
///
/// The profile criterion `V_p(κ)` is memoized per exact κ bit pattern. When
/// the fitted κ is negative and the term's columns can be selected, the
/// κ-fair sign score is used as the criterion; otherwise each probe runs a
/// fixed-κ profiled REML fit.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the term is not a
/// constant-curvature smooth, or when the interval walk or the flatness test
/// reports a failure.
pub fn curvature_inference_forspec(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    term_idx: usize,
    family: LikelihoodSpec,
    options: &FitOptions,
    level: f64,
) -> Result<CurvatureInference, EstimationError> {
    let Some(kappa_hat) = get_constant_curvature_kappa(resolvedspec, term_idx) else {
        return Err(EstimationError::InvalidInput(format!(
            "curvature_inference_forspec: term {term_idx} is not a constant-curvature smooth"
        )));
    };
    let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);

    // Inputs for the κ-fair criterion; only prepared when the fitted κ is negative.
    let cc_fair_inputs: Option<(Array2<f64>, crate::basis::ConstantCurvatureBasisSpec)> =
        match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
            Some(SmoothBasisSpec::ConstantCurvature {
                feature_cols, spec, ..
            }) if kappa_hat < 0.0 => select_columns(data, feature_cols)
                .ok()
                .map(|x| (x, spec.clone())),
            _ => None,
        };

    // Memo of criterion values keyed by the exact bit pattern of κ.
    let v_p_cache: std::cell::RefCell<std::collections::HashMap<u64, f64>> =
        std::cell::RefCell::new(std::collections::HashMap::new());
    let v_p = |kappa: f64| -> Result<f64, String> {
        if !kappa.is_finite() {
            return Err(format!("V_p probed a non-finite κ = {kappa}"));
        }
        let key = kappa.to_bits();
        let memoized = v_p_cache.borrow().get(&key).copied();
        if let Some(hit) = memoized {
            return Ok(hit);
        }
        let score = match cc_fair_inputs.as_ref() {
            Some((x_term, base_spec)) => {
                let mut probe_spec = base_spec.clone();
                probe_spec.kappa = kappa;
                crate::basis::constant_curvature_kappa_fair_sign_score(
                    x_term.view(),
                    y,
                    &probe_spec,
                )
                .map_err(|e| format!("κ-fair criterion at κ={kappa} failed: {e}"))?
            }
            None => fixed_kappa_profiled_reml_score(
                data,
                y,
                weights,
                offset,
                resolvedspec,
                term_idx,
                kappa,
                family.clone(),
                options,
            )
            .map_err(|e| format!("V_p fixed-κ fit at κ={kappa} failed: {e}"))?,
        };
        v_p_cache.borrow_mut().insert(key, score);
        Ok(score)
    };

    // Central second difference of the criterion at κ̂; NaN when any probe fails.
    let step = (1e-3 * (kappa_max - kappa_min)).max(1e-4);
    let above = v_p(kappa_hat + step);
    let centre = v_p(kappa_hat);
    let below = v_p(kappa_hat - step);
    let v_pp = if let (Ok(vp), Ok(v0), Ok(vm)) = (above, centre, below) {
        (vp - 2.0 * v0 + vm) / (step * step)
    } else {
        f64::NAN
    };

    let ci = crate::geometry::curvature_estimand::profile_ci_walk(
        &v_p, kappa_hat, v_pp, kappa_min, kappa_max, level, 1e-4,
    )
    .map_err(EstimationError::InvalidInput)?;
    let flatness = crate::geometry::curvature_estimand::flatness_lr_test(&v_p, kappa_hat)
        .map_err(EstimationError::InvalidInput)?;

    Ok(CurvatureInference {
        term_idx,
        kappa_hat,
        ci,
        flatness,
    })
}
/// Which small-sample correction was applied to a smooth term's
/// likelihood-ratio statistic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmoothLrCorrection {
    /// Lawley LR correction that also accounts for variation in the
    /// estimated smoothing parameters.
    LawleyLrEstimatedLambda,
    /// Lawley LR correction with the smoothing parameters held fixed.
    LawleyLrFixedLambda,
    /// No correction applied.
    None,
}

impl SmoothLrCorrection {
    /// Stable snake_case identifier for this correction kind.
    pub fn label(self) -> &'static str {
        match self {
            Self::None => "none",
            Self::LawleyLrFixedLambda => "lawley_lr_fixed_lambda",
            Self::LawleyLrEstimatedLambda => "lawley_lr_estimated_lambda",
        }
    }
}
/// Per-smooth-term likelihood-ratio inference record.
///
/// NOTE(review): the field notes below are inferred from the same-named
/// locals in `smooth_term_lr_inference_forspec`; the constructor itself is
/// not visible here — confirm before relying on them.
#[derive(Clone, Debug)]
pub struct SmoothTermLrInference {
/// Name of the smooth term.
pub name: String,
/// Index of the smooth term in the fitted design.
pub term_idx: usize,
/// Uncorrected LR statistic — presumably `2 * (ll_full - ll_null)` floored at 0,
/// NaN when the null fit is unavailable.
pub statistic_lr: f64,
/// Reference degrees of freedom of the chi-squared reference distribution.
pub ref_df: f64,
/// Bartlett factor actually applied — presumably 1.0 when no correction is used.
pub bartlett_factor: f64,
/// Fixed-λ (conditional) Bartlett factor — presumably only `Some` when the
/// estimated-λ correction replaced it.
pub bartlett_factor_conditional: Option<f64>,
/// Extra mean shift attributed to smoothing-parameter variation — presumably
/// total shift minus the conditional shift.
pub rho_variation_shift: Option<f64>,
/// LR statistic after the correction.
pub statistic_corrected: f64,
/// P-value of `statistic_lr` against the chi-squared reference.
pub p_value_uncorrected: f64,
/// P-value of the corrected statistic.
pub p_value_corrected: f64,
/// Flag marking the correction as material — presumably judged against
/// `SMOOTH_LR_MATERIAL_THRESHOLD`; TODO confirm.
pub material: bool,
/// Which correction was applied.
pub correction: SmoothLrCorrection,
}
/// Materiality threshold (0.10) for smooth-term LR corrections.
/// NOTE(review): presumably compared against the size of the correction to
/// set `SmoothTermLrInference::material`; the use site is not visible here —
/// confirm.
pub const SMOOTH_LR_MATERIAL_THRESHOLD: f64 = 0.10;
/// Expands each block penalty into a full `p_total x p_total` matrix holding
/// `lambda * local` on the penalty's diagonal block and zeros elsewhere.
///
/// One component is produced per `(penalty, lambda)` pair, in input order.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the penalty and lambda counts
/// differ, when a lambda is negative or non-finite, or when a penalty's
/// column range ends past `p_total`.
fn fitted_rho_penalty_components(
    penalties: &[BlockwisePenalty],
    lambdas: &[f64],
    p_total: usize,
) -> Result<Vec<crate::inference::lawley::RhoPenaltyComponent>, EstimationError> {
    if penalties.len() != lambdas.len() {
        return Err(EstimationError::InvalidInput(format!(
            "smooth_term_lr_inference: penalty/lambda count mismatch ({} penalties, {} lambdas)",
            penalties.len(),
            lambdas.len()
        )));
    }

    penalties
        .iter()
        .zip(lambdas)
        .enumerate()
        .map(|(idx, (penalty, &lambda))| {
            let lambda_is_valid = lambda.is_finite() && lambda >= 0.0;
            if !lambda_is_valid {
                return Err(EstimationError::InvalidInput(format!(
                    "smooth_term_lr_inference: lambda[{idx}] is invalid: {lambda}"
                )));
            }

            let cols = &penalty.col_range;
            if cols.end > p_total {
                return Err(EstimationError::InvalidInput(format!(
                    "smooth_term_lr_inference: penalty[{idx}] range {:?} exceeds coefficient dimension {p_total}",
                    cols
                )));
            }

            // Embed the scaled local penalty into its diagonal block.
            let mut s_component = Array2::<f64>::zeros((p_total, p_total));
            s_component
                .slice_mut(s![cols.start..cols.end, cols.start..cols.end])
                .scaled_add(lambda, &penalty.local);
            Ok(crate::inference::lawley::RhoPenaltyComponent { s_component })
        })
        .collect()
}
pub fn smooth_term_lr_inference_forspec(
data: ArrayView2<'_, f64>,
y: ArrayView1<'_, f64>,
weights: ArrayView1<'_, f64>,
offset: ArrayView1<'_, f64>,
resolvedspec: &TermCollectionSpec,
family: LikelihoodSpec,
options: &FitOptions,
) -> Result<Vec<SmoothTermLrInference>, EstimationError> {
use crate::inference::lawley::{
LAWLEY_PAIR_MATRIX_MAX_ROWS, known_scale_expected_jets_with_dispersion,
lawley_lr_bartlett_factor, lawley_lr_mean_shift_with_rho_variation,
};
let n = data.nrows();
let full = fit_term_collection_forspec(
data,
y,
weights,
offset,
resolvedspec,
family.clone(),
options,
)?;
let ll_full = full.fit.log_likelihood;
let p_total = full.design.design.ncols();
let lambdas = full.fit.lambdas.as_slice().ok_or_else(|| {
EstimationError::InvalidInput(
"smooth_term_lr_inference: non-contiguous lambda vector".to_string(),
)
})?;
let s_lambda = weighted_blockwise_penalty_sum(&full.design.penalties, lambdas, p_total);
let rho_penalty_components =
fitted_rho_penalty_components(&full.design.penalties, lambdas, p_total)?;
let rho_covariance = full.fit.artifacts.rho_covariance.as_ref().filter(|cov| {
cov.nrows() == rho_penalty_components.len() && cov.ncols() == rho_penalty_components.len()
});
let full_design_dense = full.design.design.to_dense();
let influence = full.fit.coefficient_influence();
let family_disp = lawley_dispersion_for_family(&family, &full.fit);
let mut penalty_cursor = full.design.random_effect_ranges.len();
let mut out = Vec::<SmoothTermLrInference>::new();
for (term_idx, design_term) in full.design.smooth.terms.iter().enumerate() {
let k = design_term.penalties_local.len();
let block_start = penalty_cursor;
penalty_cursor += k;
if design_term.shape != ShapeConstraint::None {
continue;
}
let coeff_range = design_term.coeff_range.clone();
if coeff_range.start >= coeff_range.end || coeff_range.end > p_total {
continue;
}
let edf = full.fit.per_term_edf(coeff_range.clone(), block_start, k);
let ref_df = wood_reference_df(influence, &coeff_range).unwrap_or(edf.max(1e-12));
if !(ref_df.is_finite() && ref_df > 0.0) {
continue;
}
let mut null_spec = resolvedspec.clone();
let Some(spec_pos) = null_spec
.smooth_terms
.iter()
.position(|t| t.name == design_term.name)
else {
continue;
};
null_spec.smooth_terms.remove(spec_pos);
let null_fit = fit_term_collection_forspec(
data,
y,
weights,
offset,
&null_spec,
family.clone(),
options,
);
let (statistic_lr, eta_null) = match null_fit {
Ok(null) if null.fit.log_likelihood.is_finite() => {
let w = (2.0 * (ll_full - null.fit.log_likelihood)).max(0.0);
let mut eta = null.design.design.dot(&null.fit.beta);
eta += &offset;
(w, Some(eta))
}
_ => (f64::NAN, None),
};
let chi2 = statrs::distribution::ChiSquared::new(ref_df).ok();
let p_uncorrected = match (chi2.as_ref(), statistic_lr.is_finite()) {
(Some(dist), true) => {
use statrs::distribution::ContinuousCDF;
(1.0 - dist.cdf(statistic_lr)).clamp(0.0, 1.0)
}
_ => f64::NAN,
};
let mut bartlett_factor = 1.0;
let mut bartlett_factor_conditional = None;
let mut rho_variation_shift = None;
let mut statistic_corrected = statistic_lr;
let mut p_corrected = p_uncorrected;
let mut correction = SmoothLrCorrection::None;
if let (Some(eta), true, true) = (
eta_null.as_ref(),
statistic_lr.is_finite(),
n <= LAWLEY_PAIR_MATRIX_MAX_ROWS,
) {
let kappas: Option<Vec<_>> = (0..n)
.map(|i| {
known_scale_expected_jets_with_dispersion(&family, eta[i], family_disp)
.and_then(|jets| jets.kappas().ok())
})
.collect();
if let (Some(kappas), Some(dist)) = (kappas, chi2.as_ref()) {
let fixed_factor = lawley_lr_bartlett_factor(
full_design_dense.view(),
&kappas,
Some(s_lambda.view()),
coeff_range.clone(),
ref_df,
);
if let Ok(c_cond) = fixed_factor
&& c_cond.is_finite()
&& c_cond > 0.0
{
let mut c_applied = c_cond;
correction = SmoothLrCorrection::LawleyLrFixedLambda;
if let Some(cov) = rho_covariance
&& let Ok(total_shift) = lawley_lr_mean_shift_with_rho_variation(
full_design_dense.view(),
&kappas,
s_lambda.view(),
coeff_range.clone(),
&rho_penalty_components,
cov.view(),
)
{
let mean_w = ref_df + total_shift;
if let Some(c_est) =
crate::inference::higher_order::bartlett_factor_from_mean(
mean_w, ref_df,
)
&& c_est.is_finite()
&& c_est > 0.0
{
let conditional_shift = (c_cond - 1.0) * ref_df;
c_applied = c_est;
bartlett_factor_conditional = Some(c_cond);
rho_variation_shift = Some(total_shift - conditional_shift);
correction = SmoothLrCorrection::LawleyLrEstimatedLambda;
}
}
use statrs::distribution::ContinuousCDF;
bartlett_factor = c_applied;
statistic_corrected = statistic_lr / c_applied;
p_corrected = (1.0 - dist.cdf(statistic_corrected)).clamp(0.0, 1.0);
}
}
}
let material = match correction {
SmoothLrCorrection::LawleyLrEstimatedLambda
| SmoothLrCorrection::LawleyLrFixedLambda => {
let factor_move = (bartlett_factor - 1.0).abs();
let p_denom = p_uncorrected.max(p_corrected).max(f64::MIN_POSITIVE);
let p_move = if p_uncorrected.is_finite() && p_corrected.is_finite() {
(p_corrected - p_uncorrected).abs() / p_denom
} else {
0.0
};
factor_move > SMOOTH_LR_MATERIAL_THRESHOLD || p_move > SMOOTH_LR_MATERIAL_THRESHOLD
}
SmoothLrCorrection::None => false,
};
out.push(SmoothTermLrInference {
name: design_term.name.clone(),
term_idx,
statistic_lr,
ref_df,
bartlett_factor,
bartlett_factor_conditional,
rho_variation_shift,
statistic_corrected,
p_value_uncorrected: p_uncorrected,
p_value_corrected: p_corrected,
material,
correction,
});
}
Ok(out)
}
/// Picks the dispersion value handed to the Lawley/Bartlett correction for the
/// given likelihood family.
///
/// * Gaussian: the square of `fit.standard_deviation`, floored at
///   `f64::MIN_POSITIVE` so the result is never zero. `f64::max` returns the
///   non-NaN operand, so a NaN product also collapses to that floor.
/// * Gamma: the reciprocal of `fit.standard_deviation` when that value is
///   finite and strictly positive; the code treats the field as the shape
///   parameter here.
///   NOTE(review): confirm the field really stores the Gamma shape for Gamma fits.
/// * Every other family, and a Gamma fit with an unusable value: `1.0`.
fn lawley_dispersion_for_family(family: &LikelihoodSpec, fit: &UnifiedFitResult) -> f64 {
    // Both family-specific branches read the same fitted field; fetch it once.
    let fitted_scale = fit.standard_deviation;
    match family.response {
        crate::types::ResponseFamily::Gaussian => {
            f64::max(fitted_scale * fitted_scale, f64::MIN_POSITIVE)
        }
        // A Gamma value that fails the guard falls through to the unit default below.
        crate::types::ResponseFamily::Gamma if fitted_scale.is_finite() && fitted_scale > 0.0 => {
            fitted_scale.recip()
        }
        _ => 1.0,
    }
}
/// Reference degrees of freedom for one coefficient block of the influence
/// matrix, computed as `tr(B)^2 / tr(B·B)` where `B` is the square diagonal
/// block `influence[start..end, start..end]`.
///
/// Returns `None` when:
/// * no influence matrix is supplied,
/// * `coeff_range` is empty or extends past either dimension of the matrix,
/// * either trace is non-finite or not strictly positive.
///
/// Otherwise returns the ratio floored at `1e-12`.
fn wood_reference_df(influence: Option<&Array2<f64>>, coeff_range: &Range<usize>) -> Option<f64> {
    let f = influence?;
    let (start, end) = (coeff_range.start, coeff_range.end);
    if start >= end || end > f.nrows() || end > f.ncols() {
        return None;
    }
    let block = f.slice(s![start..end, start..end]);
    let tr = block.diag().sum();
    // tr(B·B) = Σ_ij B[i,j]·B[j,i]. Pairing each element with its transposed
    // counterpart yields the trace in O(k²) without materialising the k×k
    // product, which only ever had its diagonal read.
    let block_t = block.t();
    let tr2 = block
        .iter()
        .zip(block_t.iter())
        .map(|(a, b)| a * b)
        .sum::<f64>();
    (tr.is_finite() && tr2.is_finite() && tr > 0.0 && tr2 > 0.0).then(|| (tr * tr / tr2).max(1e-12))
}