/// Builds the first- and second-order ψ (log-κ) derivatives of one spatial smooth
/// term's design matrix and penalty matrices, expressed in the term's realized
/// coefficient basis.
///
/// Derivatives are built here for thin-plate, constant-curvature, Matérn, and
/// Duchon bases. Every other basis — and any term whose derivative shapes do not
/// line up with the realized coefficient block — yields `Ok(None)` so the caller
/// can skip analytic ψ derivatives for that term.
///
/// When the realized term carries a `joint_null_rotation` `Q`, dense design
/// derivatives are right-multiplied by `Q` (an implicit operator instead gets `Q`
/// appended as a full transform) and each penalty derivative `S` becomes `QᵀSQ`.
///
/// # Returns
/// `Ok(Some(..))` with, in order:
/// 0. the term's column range in the full design matrix,
/// 1. the full design width `p_total`,
/// 2. `∂X/∂ψ` (local columns),
/// 3. `Σ_k ∂S_k/∂ψ`,
/// 4. `∂²X/∂ψ²` (local columns),
/// 5. `Σ_k ∂²S_k/∂ψ²`,
/// 6. the per-penalty `∂S_k/∂ψ` matrices,
/// 7. the per-penalty `∂²S_k/∂ψ²` matrices,
/// 8. an optional implicit design-derivative operator. When it is present, the
///    dense design derivatives are neither rotated nor width-checked here
///    (presumably the operator supersedes them — confirm at the consumer).
///
/// # Errors
/// Propagates failures from column selection, from the basis derivative
/// builders, and from appending the rotation to an implicit operator.
///
/// # Panics
/// If a basis builder returns an implicit operator inside the first- or
/// second-derivative sub-results; only the bundle-level operator is expected.
fn try_build_spatial_term_log_kappa_derivative(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    term_idx: usize,
) -> Result<
    Option<(
        Range<usize>,
        usize,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        Vec<Array2<f64>>,
        Vec<Array2<f64>>,
        Option<std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>>,
    )>,
    EstimationError,
> {
    // Unknown term indices are treated as "no analytic derivative", not as errors.
    let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
        return Ok(None);
    };
    let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
        return Ok(None);
    };

    let derivative_bundle = match &termspec.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            // Re-apply the input standardization used for this term and compensate
            // the configured length scale for it.
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_length_scale_for_standardization(spec.length_scale, s);
            }
            build_thin_plate_basis_log_kappa_derivatives(x.view(), &spec_local)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Sphere { .. } => return Ok(None),
        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
            // No input standardization is applied for this basis.
            let x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            build_constant_curvature_basis_kappa_derivatives(x.view(), spec)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::MeasureJet { .. } => return Ok(None),
        SmoothBasisSpec::Matern {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_length_scale_for_standardization(spec.length_scale, s);
            }
            // Derivatives are requested with the double penalty disabled
            // (presumably because that extra penalty is handled separately — confirm).
            spec_local.double_penalty = false;
            build_matern_basis_log_kappa_derivatives(x.view(), &spec_local)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Duchon {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_optional_length_scale_for_standardization(spec.length_scale, s);
            }
            // The Duchon derivative must reuse the realized centers, constraint
            // transform and collocation points, so it needs Duchon metadata.
            let BasisMetadata::Duchon {
                centers,
                identifiability_transform,
                operator_collocation_points,
                radial_reparam,
                ..
            } = &smooth_term.metadata
            else {
                return Ok(None);
            };
            // Fall back to the radial reparameterization recorded at realization
            // time when the spec does not pin one.
            if spec_local.radial_reparam.is_none() {
                spec_local.radial_reparam = radial_reparam.clone();
            }
            gam_terms::basis::build_duchon_basis_log_kappa_derivativeswith_collocationwithworkspace(
                x.view(),
                &spec_local,
                centers.view(),
                identifiability_transform.as_ref(),
                operator_collocation_points
                    .as_ref()
                    .map(|points| points.view()),
                &mut BasisWorkspace::default(),
            )
            .map_err(EstimationError::from)?
        }
        // Bases without a ψ (log-κ) hyperparameter path here.
        SmoothBasisSpec::BSpline1D { .. }
        | SmoothBasisSpec::TensorBSpline { .. }
        | SmoothBasisSpec::ByVariable { .. }
        | SmoothBasisSpec::FactorSumToZero { .. }
        | SmoothBasisSpec::BySmooth { .. }
        | SmoothBasisSpec::FactorSmooth { .. }
        | SmoothBasisSpec::Pca { .. } => {
            return Ok(None);
        }
    };
    // The implicit operator, if any, lives on the bundle; the per-order results
    // must not carry their own (enforced by the asserts below).
    let mut implicit_operator = derivative_bundle.implicit_operator;
    let BasisPsiDerivativeResult {
        design_derivative: mut local_x_psi,
        penalties_derivative: mut local_s_psi,
        implicit_operator: local_implicit_first_unused,
    } = derivative_bundle.first;
    let BasisPsiSecondDerivativeResult {
        designsecond_derivative: mut local_x_psi_psi,
        penaltiessecond_derivative: mut local_s_psi_psi,
        implicit_operator: local_implicit_second_unused,
    } = derivative_bundle.second;
    assert!(local_implicit_first_unused.is_none());
    assert!(local_implicit_second_unused.is_none());

    // Map local-basis derivatives into the realized (rotated) coefficient basis.
    if let Some(rotation) = smooth_term.joint_null_rotation.as_ref() {
        let q = &rotation.rotation;
        if let Some(op) = implicit_operator.take() {
            implicit_operator = Some(op.append_full_transform(q).map_err(EstimationError::from)?);
        } else {
            // Dense path: X' = X Q; shape mismatches mean "unsupported", not an error.
            if local_x_psi.ncols() != q.nrows() || local_x_psi_psi.ncols() != q.nrows() {
                return Ok(None);
            }
            local_x_psi = fast_ab(&local_x_psi, q);
            local_x_psi_psi = fast_ab(&local_x_psi_psi, q);
        }
        // Penalties transform by congruence: S' = Qᵀ S Q. `None` flags a shape mismatch.
        let rotate_penalty = |s_local: Array2<f64>| -> Option<Array2<f64>> {
            if s_local.nrows() != q.nrows() || s_local.ncols() != q.nrows() {
                return None;
            }
            let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
            Some(gam_linalg::faer_ndarray::fast_ab(&qt_s, q))
        };
        let Some(rotated_s_psi) = local_s_psi
            .into_iter()
            .map(|s| rotate_penalty(s))
            .collect::<Option<Vec<_>>>()
        else {
            return Ok(None);
        };
        local_s_psi = rotated_s_psi;
        let Some(rotated_s_psi_psi) = local_s_psi_psi
            .into_iter()
            .map(|s| rotate_penalty(s))
            .collect::<Option<Vec<_>>>()
        else {
            return Ok(None);
        };
        local_s_psi_psi = rotated_s_psi_psi;
    }
    let implicit_operator = implicit_operator.map(std::sync::Arc::new);

    // Every derivative must match the realized coefficient width of this term;
    // anything else is treated as unsupported.
    if let Some(ref op) = implicit_operator {
        if op.p_out() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
    } else {
        if local_x_psi.ncols() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
        if local_x_psi_psi.ncols() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
    }
    // First- and second-order penalty lists must be non-empty and paired.
    if local_s_psi.is_empty() || local_s_psi.len() != local_s_psi_psi.len() {
        return Ok(None);
    }
    if local_s_psi.iter().any(|s| {
        s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
    }) {
        return Ok(None);
    }
    if local_s_psi_psi.iter().any(|s| {
        s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
    }) {
        return Ok(None);
    }

    // Smooth columns occupy the trailing `total_smooth_cols()` columns of the
    // full design, so the term's local range is offset by `smooth_start`.
    let p_total = design.design.ncols();
    let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
    let global_range = (smooth_start + smooth_term.coeff_range.start)
        ..(smooth_start + smooth_term.coeff_range.end);

    Ok(Some((
        global_range,
        p_total,
        local_x_psi,
        // Summed first-order penalty derivative.
        local_s_psi.iter().fold(
            Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
            |acc, m| acc + m,
        ),
        local_x_psi_psi,
        // Summed second-order penalty derivative.
        local_s_psi_psi.iter().fold(
            Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
            |acc, m| acc + m,
        ),
        local_s_psi,
        local_s_psi_psi,
        implicit_operator,
    )))
}
242
243fn try_build_spatial_log_kappa_hyper_dirs(
244 data: ArrayView2<'_, f64>,
245 resolvedspec: &TermCollectionSpec,
246 design: &TermCollectionDesign,
247 spatial_terms: &[usize],
248) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
249 let Some(info_list) =
256 try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, spatial_terms)?
257 else {
258 return Ok(None);
259 };
260 Ok(Some(spatial_log_kappa_hyper_dirs_frominfo_list(info_list)?))
261}
262
263pub(crate) fn try_build_latent_coord_hyper_dirs(
264 latent: std::sync::Arc<gam_terms::latent::LatentCoordValues>,
265 resolvedspec: &TermCollectionSpec,
266 design: &TermCollectionDesign,
267 latent_terms: &[gam_problem::types::SmoothTermIdx],
268 analytic_rho_count: usize,
269) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
270 if latent_terms.is_empty() || latent.is_empty() {
271 return Ok(None);
272 }
273 if latent_terms.len() != 1 {
274 crate::bail_invalid_estim!(
275 "LatentCoord standard-fit hyper_dirs currently require exactly one latent smooth term"
276 .to_string(),
277 );
278 }
279 let term_idx = latent_terms[0];
280 let smooth_term = design.smooth.terms.get(term_idx.get()).ok_or_else(|| {
281 EstimationError::InvalidInput(format!(
282 "LatentCoord term index {term_idx} out of bounds for realized smooth design"
283 ))
284 })?;
285 let termspec = resolvedspec
286 .smooth_terms
287 .get(term_idx.get())
288 .ok_or_else(|| {
289 EstimationError::InvalidInput(format!(
290 "LatentCoord term index {term_idx} out of bounds for resolved smooth spec"
291 ))
292 })?;
293 let p_total = design.design.ncols();
294 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
295 let global_range = (smooth_start + smooth_term.coeff_range.start)
296 ..(smooth_start + smooth_term.coeff_range.end);
297
298 let operator = match (&termspec.basis, &smooth_term.metadata) {
303 (
304 SmoothBasisSpec::Matern { .. },
305 BasisMetadata::Matern {
306 centers,
307 length_scale,
308 nu,
309 include_intercept,
310 identifiability_transform,
311 ..
312 },
313 ) => gam_terms::basis::LatentCoordDesignDerivative::new_matern(
314 latent.clone(),
315 std::sync::Arc::new(centers.clone()),
316 *length_scale,
317 *nu,
318 *include_intercept,
319 identifiability_transform.clone(),
320 )
321 .map_err(EstimationError::from)?,
322 (
323 SmoothBasisSpec::Duchon { .. },
324 BasisMetadata::Duchon {
325 centers,
326 length_scale,
327 power,
328 nullspace_order,
329 identifiability_transform,
330 ..
331 },
332 ) => gam_terms::basis::LatentCoordDesignDerivative::new_duchon(
333 latent.clone(),
334 std::sync::Arc::new(centers.clone()),
335 *length_scale,
336 *power,
337 *nullspace_order,
338 identifiability_transform.clone(),
339 )
340 .map_err(EstimationError::from)?,
341 (
342 SmoothBasisSpec::Sphere { .. },
343 BasisMetadata::Sphere {
344 centers,
345 penalty_order,
346 method,
347 constraint_transform,
348 ..
349 },
350 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
351 gam_terms::basis::LatentCoordDesignDerivative::new_sphere(
352 latent.clone(),
353 std::sync::Arc::new(centers.clone()),
354 *penalty_order,
355 constraint_transform.clone(),
356 )
357 .map_err(EstimationError::from)?
358 }
359 (
360 SmoothBasisSpec::BSpline1D { spec, .. },
361 BasisMetadata::BSpline1D {
362 knots,
363 identifiability_transform,
364 periodic,
365 degree: meta_degree,
366 ..
367 },
368 ) => {
369 let effective_degree = meta_degree.unwrap_or(spec.degree);
373 if let Some((domain_start, period, num_basis)) = periodic {
374 gam_terms::basis::LatentCoordDesignDerivative::new_periodic_bspline(
375 latent.clone(),
376 (*domain_start, *domain_start + *period),
377 effective_degree,
378 *num_basis,
379 identifiability_transform.clone(),
380 )
381 .map_err(EstimationError::from)?
382 } else {
383 gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
384 latent.clone(),
385 vec![knots.clone()],
386 vec![effective_degree],
387 identifiability_transform.clone(),
388 )
389 .map_err(EstimationError::from)?
390 }
391 }
392 (
393 SmoothBasisSpec::TensorBSpline { .. },
394 BasisMetadata::TensorBSpline {
395 knots,
396 degrees,
397 identifiability_transform,
398 ..
399 },
400 ) => gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
401 latent.clone(),
402 knots.clone(),
403 degrees.clone(),
404 identifiability_transform.clone(),
405 )
406 .map_err(EstimationError::from)?,
407 (SmoothBasisSpec::Pca { .. }, BasisMetadata::Pca { basis_matrix, .. }) => {
408 gam_terms::basis::LatentCoordDesignDerivative::new_pca(
409 latent.clone(),
410 std::sync::Arc::new(basis_matrix.clone()),
411 )
412 .map_err(EstimationError::from)?
413 }
414 _ => return Ok(None),
415 };
416 if operator.p_out() != global_range.len() {
417 crate::bail_invalid_estim!(
418 "LatentCoord derivative width mismatch for term '{}': operator p={}, coeff range={}",
419 smooth_term.name,
420 operator.p_out(),
421 global_range.len()
422 );
423 }
424 let operator = std::sync::Arc::new(operator);
425 let mut hyper_dirs = Vec::with_capacity(operator.n_axes());
426 for flat_axis in 0..operator.n_axes() {
427 let dir = DirectionalHyperParam::new_compact(
428 gam_solve::estimate::reml::HyperDesignDerivative::from_latent_coord(
429 operator.clone(),
430 flat_axis,
431 global_range.clone(),
432 p_total,
433 ),
434 Vec::new(),
435 None,
436 None,
437 )?
438 .not_penalty_like();
439 hyper_dirs.push(dir);
440 }
441 let direct_dim = latent_coord_direct_hyper_count(latent.id_mode(), latent.latent_dim());
442 if analytic_rho_count + direct_dim > 0 {
443 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::from(Array2::<f64>::zeros((
444 design.design.nrows(),
445 p_total,
446 )));
447 for _ in 0..analytic_rho_count {
448 hyper_dirs.push(
449 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
450 .not_penalty_like(),
451 );
452 }
453 for _ in 0..direct_dim {
454 hyper_dirs.push(
455 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
456 .not_penalty_like(),
457 );
458 }
459 }
460 Ok(Some(hyper_dirs))
461}
462
463fn latent_coord_direct_hyper_count(
464 id_mode: &gam_terms::latent::LatentIdMode,
465 latent_dim: usize,
466) -> usize {
467 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
468 match id_mode {
469 LatentIdMode::AuxPrior { strength, .. } => match strength {
470 AuxPriorStrength::Auto => 1,
471 AuxPriorStrength::Fixed(_) => 0,
472 },
473 LatentIdMode::AuxPriorDimSelection { strength, .. } => {
474 latent_dim
475 + match strength {
476 AuxPriorStrength::Auto => 1,
477 AuxPriorStrength::Fixed(_) => 0,
478 }
479 }
480 LatentIdMode::DimSelection { .. } => latent_dim,
481 LatentIdMode::IsometryToReference { strength, .. } => match strength {
484 AuxPriorStrength::Auto => 1,
485 AuxPriorStrength::Fixed(_) => 0,
486 },
487 LatentIdMode::AuxOutcome { head, .. } => head.n_coeffs(latent_dim) + latent_dim,
490 LatentIdMode::None => 0,
491 }
492}
493
494fn latent_coord_initial_direct_hypers(
495 id_mode: &gam_terms::latent::LatentIdMode,
496 latent_dim: usize,
497) -> Result<Array1<f64>, EstimationError> {
498 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
499 let mut values = Vec::with_capacity(latent_coord_direct_hyper_count(id_mode, latent_dim));
500 match id_mode {
501 LatentIdMode::AuxPrior { strength, .. } => {
502 if matches!(strength, AuxPriorStrength::Auto) {
503 values.push(0.0);
504 }
505 }
506 LatentIdMode::AuxPriorDimSelection {
507 strength,
508 init_log_precision,
509 ..
510 } => {
511 if matches!(strength, AuxPriorStrength::Auto) {
512 values.push(0.0);
513 }
514 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
515 }
516 LatentIdMode::DimSelection { init_log_precision } => {
517 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
518 }
519 LatentIdMode::IsometryToReference { strength, .. } => {
520 if matches!(strength, AuxPriorStrength::Auto) {
521 values.push(0.0);
522 }
523 }
524 LatentIdMode::AuxOutcome {
525 head,
526 init_log_precision,
527 } => {
528 values.extend(std::iter::repeat_n(0.0, head.n_coeffs(latent_dim)));
532 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
533 }
534 LatentIdMode::None => {}
535 }
536 Ok(Array1::from_vec(values))
537}
538
539fn append_latent_ard_seed(
540 values: &mut Vec<f64>,
541 init: Option<&Array1<f64>>,
542 latent_dim: usize,
543) -> Result<(), EstimationError> {
544 if let Some(init) = init {
545 if init.len() != latent_dim {
546 crate::bail_invalid_estim!(
547 "latent dim_selection init_log_precision length mismatch: got {}, expected {}",
548 init.len(),
549 latent_dim
550 );
551 }
552 values.extend(init.iter().copied());
553 } else {
554 values.extend(std::iter::repeat_n(0.0, latent_dim));
555 }
556 Ok(())
557}
558
/// An additive term for the hyperparameter objective: a scalar cost plus its
/// gradient with respect to the full hyperparameter vector `theta`.
struct LatentIdObjectiveContribution {
    /// Amount added to the scalar objective.
    cost: f64,
    /// Amount added to the objective gradient; the producers here allocate it with
    /// `theta.len()` entries.
    gradient: Array1<f64>,
}
563
564fn latent_id_objective_contribution(
565 theta: &Array1<f64>,
566 rho_dim: usize,
567 analytic_rho_count: usize,
568 latent: &gam_terms::latent::LatentCoordValues,
569) -> Result<LatentIdObjectiveContribution, EstimationError> {
570 use gam_terms::latent::{AuxPriorStrength, LatentIdMode, aux_prior_targets};
571 let n_obs = latent.n_obs();
572 let latent_dim = latent.latent_dim();
573 let flat_len = latent.len();
574 let mut gradient = Array1::<f64>::zeros(theta.len());
575 let t_start = rho_dim;
576 let direct_start = t_start + flat_len + analytic_rho_count;
577 if theta.len() < direct_start {
578 crate::bail_invalid_estim!(
579 "latent-coordinate theta too short for id objective: got {}, need at least {}",
580 theta.len(),
581 direct_start
582 );
583 }
584 let t = latent.as_matrix();
585 let mut cost = 0.0;
586 let mut cursor = direct_start;
587
588 match latent.id_mode() {
589 LatentIdMode::AuxPrior {
590 u,
591 family,
592 strength,
593 }
594 | LatentIdMode::AuxPriorDimSelection {
595 u,
596 family,
597 strength,
598 ..
599 } => {
600 let (log_mu, mu) = match strength {
601 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
602 AuxPriorStrength::Auto => {
603 let log_mu = theta[cursor];
604 cursor += 1;
605 (log_mu, log_mu.exp())
606 }
607 };
608 let targets = aux_prior_targets(t.view(), u.view(), *family)
609 .map_err(EstimationError::InvalidInput)?;
610 let residual = &t - &targets;
611 let q = residual.iter().map(|v| v * v).sum::<f64>();
612 let k = (n_obs * latent_dim) as f64;
619 cost += 0.5 * mu * q - 0.5 * k * log_mu;
620
621 let projected_residual = aux_prior_targets(residual.view(), u.view(), *family)
622 .map_err(EstimationError::InvalidInput)?;
623 let grad_base = residual - projected_residual;
624 for n in 0..n_obs {
625 for axis in 0..latent_dim {
626 gradient[t_start + n * latent_dim + axis] += mu * grad_base[[n, axis]];
627 }
628 }
629 if matches!(strength, AuxPriorStrength::Auto) {
630 gradient[direct_start] += 0.5 * mu * q - 0.5 * k;
631 }
632 }
633 LatentIdMode::IsometryToReference { reference, strength } => {
634 if reference.dim() != (n_obs, latent_dim) {
641 crate::bail_invalid_estim!(
642 "IsometryToReference reference shape {:?} must equal (n_obs, latent_dim) = ({}, {})",
643 reference.dim(),
644 n_obs,
645 latent_dim
646 );
647 }
648 let mu_slot = cursor;
649 let (log_mu, mu) = match strength {
650 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
651 AuxPriorStrength::Auto => {
652 let log_mu = theta[cursor];
653 cursor += 1;
654 (log_mu, log_mu.exp())
655 }
656 };
657 let residual = &t - reference;
658 let q = residual.iter().map(|v| v * v).sum::<f64>();
659 let k = (n_obs * latent_dim) as f64;
663 cost += 0.5 * mu * q - 0.5 * k * log_mu;
664 for n in 0..n_obs {
665 for axis in 0..latent_dim {
666 gradient[t_start + n * latent_dim + axis] += mu * residual[[n, axis]];
667 }
668 }
669 if matches!(strength, AuxPriorStrength::Auto) {
670 gradient[mu_slot] += 0.5 * mu * q - 0.5 * k;
671 }
672 }
673 LatentIdMode::AuxOutcome { head, .. } => {
674 let n_coeffs = head.n_coeffs(latent_dim);
682 let coeffs = theta
683 .slice(ndarray::s![cursor..cursor + n_coeffs])
684 .to_owned();
685 let (head_nll, grad_coeffs, grad_t) = head
686 .neg_loglik_and_grad(t.view(), coeffs.view())
687 .map_err(EstimationError::InvalidInput)?;
688 cost += head_nll;
689 for (offset, &g) in grad_coeffs.iter().enumerate() {
690 gradient[cursor + offset] += g;
691 }
692 for n in 0..n_obs {
693 for axis in 0..latent_dim {
694 gradient[t_start + n * latent_dim + axis] += grad_t[[n, axis]];
695 }
696 }
697 cursor += n_coeffs;
698 }
699 LatentIdMode::DimSelection { .. } | LatentIdMode::None => {}
700 }
701
702 match latent.id_mode() {
703 LatentIdMode::AuxPriorDimSelection { .. }
704 | LatentIdMode::DimSelection { .. }
705 | LatentIdMode::AuxOutcome { .. } => {
706 for axis in 0..latent_dim {
707 let log_alpha = theta[cursor + axis];
708 let alpha = log_alpha.exp();
709 let mut q_axis = 0.0;
710 for n in 0..n_obs {
711 let flat_idx = n * latent_dim + axis;
712 let value = latent.as_flat()[flat_idx];
713 q_axis += value * value;
714 gradient[t_start + flat_idx] += alpha * value;
715 }
716 cost += 0.5 * alpha * q_axis - 0.5 * n_obs as f64 * log_alpha;
717 gradient[cursor + axis] += 0.5 * alpha * q_axis - 0.5 * n_obs as f64;
718 }
719 cursor += latent_dim;
720 }
721 LatentIdMode::AuxPrior { .. }
722 | LatentIdMode::IsometryToReference { .. }
723 | LatentIdMode::None => {}
724 }
725
726 if cursor != theta.len() {
727 crate::bail_invalid_estim!(
728 "latent-coordinate direct hyperparameter length mismatch: consumed {}, theta len {}",
729 cursor,
730 theta.len()
731 );
732 }
733 Ok(LatentIdObjectiveContribution { cost, gradient })
734}
735
736fn add_latent_id_objective_to_eval(
737 theta: &Array1<f64>,
738 rho_dim: usize,
739 analytic_rho_count: usize,
740 latent: &gam_terms::latent::LatentCoordValues,
741 eval: &mut (
742 f64,
743 Array1<f64>,
744 gam_problem::HessianResult,
745 ),
746) -> Result<(), EstimationError> {
747 let contribution =
748 latent_id_objective_contribution(theta, rho_dim, analytic_rho_count, latent)?;
749 eval.0 += contribution.cost;
750 if eval.1.len() != contribution.gradient.len() {
751 crate::bail_invalid_estim!(
752 "latent-coordinate REML gradient length mismatch: base={}, id={}",
753 eval.1.len(),
754 contribution.gradient.len()
755 );
756 }
757 eval.1 += &contribution.gradient;
758 if eval.2.is_analytic() {
759 eval.2 = gam_problem::HessianResult::Unavailable;
760 }
761 Ok(())
762}
763
764fn analytic_penalty_objective_contribution(
765 theta: &Array1<f64>,
766 rho_dim: usize,
767 latent: &gam_terms::latent::LatentCoordValues,
768 registry: &gam_terms::AnalyticPenaltyRegistry,
769) -> Result<LatentIdObjectiveContribution, EstimationError> {
770 let flat_len = latent.len();
771 let t_start = rho_dim;
772 let t_end = t_start + flat_len;
773 let rho_start = t_end;
774 let rho_end = rho_start + registry.total_rho_count();
775 if theta.len() < rho_end {
776 crate::bail_invalid_estim!(
777 "latent-coordinate theta too short for analytic penalties: got {}, need at least {}",
778 theta.len(),
779 rho_end
780 );
781 }
782 let target_t = theta.slice(s![t_start..t_end]);
783 let rho = theta.slice(s![rho_start..rho_end]);
784 let mut cost = 0.0_f64;
785 let mut gradient = Array1::<f64>::zeros(theta.len());
786 for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(registry.rho_layout()) {
787 let rho_local = rho.slice(s![rho_slice.clone()]);
788 match tier {
789 gam_terms::PenaltyTier::Psi => {
790 cost += penalty.value(target_t.view(), rho_local);
791 let grad = penalty.grad_target(target_t.view(), rho_local);
792 if grad.len() != flat_len {
793 crate::bail_invalid_estim!(
794 "analytic penalty {name:?} gradient length mismatch: got {}, expected {}",
795 grad.len(),
796 flat_len
797 );
798 }
799 for i in 0..flat_len {
800 gradient[t_start + i] += grad[i];
801 }
802 let grad_rho_local = penalty.grad_rho(target_t.view(), rho_local);
803 if grad_rho_local.len() != rho_slice.len() {
804 crate::bail_invalid_estim!(
805 "analytic penalty {name:?} rho-gradient length mismatch: got {}, expected {}",
806 grad_rho_local.len(),
807 rho_slice.len()
808 );
809 }
810 for local_idx in 0..grad_rho_local.len() {
811 gradient[rho_start + rho_slice.start + local_idx] += grad_rho_local[local_idx];
812 }
813 }
814 gam_terms::PenaltyTier::Beta => {}
815 gam_terms::PenaltyTier::Rho => {}
816 }
817 }
818 Ok(LatentIdObjectiveContribution { cost, gradient })
819}
820
821fn add_analytic_penalty_hessian_to_eval(
822 theta: &Array1<f64>,
823 rho_dim: usize,
824 latent: &gam_terms::latent::LatentCoordValues,
825 registry: &gam_terms::AnalyticPenaltyRegistry,
826 eval: &mut (
827 f64,
828 Array1<f64>,
829 gam_problem::HessianResult,
830 ),
831) -> Result<(), EstimationError> {
832 let flat_len = latent.len();
833 let t_start = rho_dim;
834 let t_end = t_start + flat_len;
835 let rho_start = t_end;
836 let rho_end = rho_start + registry.total_rho_count();
837 if theta.len() < rho_end {
838 crate::bail_invalid_estim!(
839 "latent-coordinate theta too short for analytic penalty Hessian: got {}, need at least {}",
840 theta.len(),
841 rho_end
842 );
843 }
844 let gam_problem::HessianResult::Analytic(hessian) = &mut eval.2 else {
845 if eval.2.is_analytic() {
846 eval.2 = gam_problem::HessianResult::Unavailable;
847 }
848 return Ok(());
849 };
850 if hessian.dim() != (theta.len(), theta.len()) {
851 crate::bail_invalid_estim!(
852 "analytic penalty Hessian target shape mismatch: got {}x{}, expected {}x{}",
853 hessian.nrows(),
854 hessian.ncols(),
855 theta.len(),
856 theta.len()
857 );
858 }
859 let target_t = theta.slice(s![t_start..t_end]);
860 let rho = theta.slice(s![rho_start..rho_end]);
861 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(registry.rho_layout())
862 {
863 let rho_local = rho.slice(s![rho_slice]);
864 if !matches!(tier, gam_terms::PenaltyTier::Psi) {
865 continue;
866 }
867 if let Some(diag) = penalty.hessian_diag(target_t.view(), rho_local) {
868 if diag.len() != flat_len {
869 crate::bail_invalid_estim!(
870 "analytic penalty Hessian diagonal length mismatch: got {}, expected {}",
871 diag.len(),
872 flat_len
873 );
874 }
875 for i in 0..flat_len {
876 hessian[[t_start + i, t_start + i]] += diag[i];
877 }
878 continue;
879 }
880 let mut probe = Array1::<f64>::zeros(flat_len);
881 for col in 0..flat_len {
882 probe[col] = 1.0;
883 let hv = penalty.hvp(target_t.view(), rho_local, probe.view());
884 if hv.len() != flat_len {
885 crate::bail_invalid_estim!(
886 "analytic penalty Hessian-vector length mismatch: got {}, expected {}",
887 hv.len(),
888 flat_len
889 );
890 }
891 for row in 0..flat_len {
892 hessian[[t_start + row, t_start + col]] += hv[row];
893 }
894 probe[col] = 0.0;
895 }
896 }
897 Ok(())
898}
899
900fn add_analytic_penalty_objective_to_eval(
901 theta: &Array1<f64>,
902 rho_dim: usize,
903 latent: &gam_terms::latent::LatentCoordValues,
904 registry: &gam_terms::AnalyticPenaltyRegistry,
905 eval: &mut (
906 f64,
907 Array1<f64>,
908 gam_problem::HessianResult,
909 ),
910) -> Result<(), EstimationError> {
911 let contribution = analytic_penalty_objective_contribution(theta, rho_dim, latent, registry)?;
912 eval.0 += contribution.cost;
913 if eval.1.len() != contribution.gradient.len() {
914 crate::bail_invalid_estim!(
915 "latent-coordinate REML gradient length mismatch: base={}, analytic_penalty={}",
916 eval.1.len(),
917 contribution.gradient.len()
918 );
919 }
920 eval.1 += &contribution.gradient;
921 add_analytic_penalty_hessian_to_eval(theta, rho_dim, latent, registry, eval)?;
922 Ok(())
923}
924
/// Converts per-axis spatial ψ derivative records into directional hyperparameters.
///
/// There is one `DirectionalHyperParam` per entry of `info_list`, in the same order.
/// For entry `i`, the direction carries:
/// - the first-order design derivative: `ImplicitDerivLevel::First` on the implicit
///   operator when one is present, otherwise the embedded dense `x_psi_local`;
/// - the first-order penalty components, embedded into the global range;
/// - second-order design derivatives in a `log_kappa_dim`-long table:
///   - slot `i` holds the diagonal term;
///   - for anisotropy groups, slot `base + b_axis` holds the cross term with axis
///     `b_axis`, where `base` is the first list index in the group;
/// - second-order penalty components in slot `i`;
/// - for anisotropy groups that supply a cross-penalty provider, a lazy provider of
///   cross second-order penalty components plus the list of partner indices.
///
/// # Errors
/// - an entry is missing from its own anisotropy group's index list;
/// - `DirectionalHyperParam::new_compact` fails.
fn spatial_log_kappa_hyper_dirs_frominfo_list(
    info_list: Vec<SpatialPsiDerivative>,
) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
    use gam_solve::estimate::reml::ImplicitDerivLevel;
    use std::collections::HashMap;

    let log_kappa_dim = info_list.len();
    // Map each anisotropy group id to the list indices (in order) of its axes.
    let group_ids: Vec<Option<usize>> = info_list.iter().map(|e| e.aniso_group_id).collect();
    let mut group_indices_map: HashMap<usize, Vec<usize>> = HashMap::new();
    for (idx, gid) in group_ids.iter().enumerate() {
        if let Some(g) = gid {
            group_indices_map.entry(*g).or_default().push(idx);
        }
    }

    let mut hyper_dirs = Vec::with_capacity(log_kappa_dim);
    for (i, info) in info_list.into_iter().enumerate() {
        let SpatialPsiDerivative {
            penalty_index: _,
            penalty_indices,
            global_range,
            total_p,
            x_psi_local,
            s_psi_components_local,
            x_psi_psi_local,
            s_psi_psi_components_local,
            aniso_group_id,
            aniso_cross_designs,
            aniso_cross_penalty_provider,
            implicit_operator,
            implicit_axis,
        } = info;

        // Second-order design derivatives indexed by partner hyper index; `None`
        // means no design second derivative is supplied for that pair here.
        let mut xsecond = vec![None; log_kappa_dim];
        xsecond[i] = Some(if let Some(ref op) = implicit_operator {
            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                op.clone(),
                ImplicitDerivLevel::SecondDiag(implicit_axis),
                global_range.clone(),
                total_p,
            )
        } else {
            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                x_psi_psi_local,
                global_range.clone(),
                total_p,
            )
        });
        if let Some(cross_designs) = aniso_cross_designs {
            if let Some(gid) = aniso_group_id {
                // Group axes are addressed relative to the group's first list index.
                let base = group_indices_map
                    .get(&gid)
                    .and_then(|v| v.first().copied())
                    .unwrap_or(i);
                for (b_axis, cross_mat) in cross_designs.into_iter() {
                    let j = base + b_axis;
                    // Out-of-range partners are silently skipped.
                    if j < log_kappa_dim {
                        xsecond[j] = Some(if let Some(ref op) = implicit_operator {
                            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                                op.clone(),
                                ImplicitDerivLevel::SecondCross(implicit_axis, b_axis),
                                global_range.clone(),
                                total_p,
                            )
                        } else {
                            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                                cross_mat,
                                global_range.clone(),
                                total_p,
                            )
                        });
                    }
                }
            }
        }
        // First-order penalty derivatives, paired with their penalty indices.
        let s_components = penalty_indices
            .iter()
            .copied()
            .zip(s_psi_components_local.into_iter().map(|local| {
                gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    total_p,
                )
            }))
            .collect::<Vec<_>>();
        // Diagonal second-order penalty derivatives.
        let s2_components = penalty_indices
            .iter()
            .copied()
            .zip(s_psi_psi_components_local.into_iter().map(|local| {
                gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    total_p,
                )
            }))
            .collect::<Vec<_>>();
        let mut ssecond_components = vec![None; log_kappa_dim];
        ssecond_components[i] = Some(s2_components);
        // Cross second-order penalty derivatives are produced lazily, per partner
        // hyper index `j`, by a provider closure built only for anisotropy groups.
        let mut penaltysecond_partner_indices: Option<Vec<usize>> = None;
        let penaltysecond_component_provider =
            if let (Some(provider), Some(gid)) = (aniso_cross_penalty_provider, aniso_group_id) {
                let group_indices = group_indices_map.get(&gid).cloned().unwrap_or_default();
                let axis_in_group =
                    group_indices
                        .iter()
                        .position(|&idx| idx == i)
                        .ok_or_else(|| {
                            EstimationError::InvalidInput(format!(
                                "missing spatial hyper axis {} in anisotropy group {}",
                                i, gid
                            ))
                        })?;
                // Every other axis of the same group is a cross-derivative partner.
                penaltysecond_partner_indices = Some(
                    group_indices
                        .iter()
                        .copied()
                        .filter(|&idx| idx != i)
                        .collect(),
                );
                let penalty_indices_inner = penalty_indices.clone();
                let global_range_inner = global_range.clone();
                let total_p_inner = total_p;
                let group_indices_inner = group_indices;
                Some(std::sync::Arc::new(
                    move |j: usize| -> Result<
                        Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
                        EstimationError,
                    > {
                        // `j` outside this group, or `j` == this axis: no cross term.
                        let Some(other_axis_in_group) =
                            group_indices_inner.iter().position(|&idx| idx == j)
                        else {
                            return Ok(None);
                        };
                        if other_axis_in_group == axis_in_group {
                            return Ok(None);
                        }
                        // The provider is queried by the partner's within-group axis.
                        let cross_pens = provider(other_axis_in_group)?;
                        if cross_pens.is_empty() {
                            return Ok(None);
                        }
                        Ok(Some(
                            penalty_indices_inner
                                .iter()
                                .copied()
                                .zip(cross_pens.into_iter().map(|local| {
                                    gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                                        local,
                                        global_range_inner.clone(),
                                        total_p_inner,
                                    )
                                }))
                                .map(|(penalty_index, matrix)| {
                                    gam_solve::estimate::reml::PenaltyDerivativeComponent {
                                        penalty_index,
                                        matrix,
                                    }
                                })
                                .collect(),
                        ))
                    },
                )
                    as std::sync::Arc<
                        dyn Fn(
                                usize,
                            ) -> Result<
                                Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
                                EstimationError,
                            > + Send
                            + Sync
                            + 'static,
                    >)
            } else {
                None
            };
        // First-order design derivative: implicit operator if present, else dense.
        let x_first_hyper = if let Some(ref op) = implicit_operator {
            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                op.clone(),
                ImplicitDerivLevel::First(implicit_axis),
                global_range.clone(),
                total_p,
            )
        } else {
            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                x_psi_local,
                global_range.clone(),
                total_p,
            )
        };
        let mut dir = DirectionalHyperParam::new_compact(
            x_first_hyper,
            s_components,
            Some(xsecond),
            Some(ssecond_components),
        )?
        .not_penalty_like();
        if let Some(provider) = penaltysecond_component_provider {
            dir = dir.with_penaltysecond_component_provider(provider);
        }
        if let Some(partner_indices) = penaltysecond_partner_indices {
            dir = dir.with_penaltysecond_partner_indices(partner_indices);
        }
        hyper_dirs.push(dir);
    }
    Ok(hyper_dirs)
}
1143
1144pub(crate) fn spatial_dims_per_term(
1150 resolvedspec: &TermCollectionSpec,
1151 spatial_terms: &[usize],
1152) -> Vec<usize> {
1153 spatial_terms
1154 .iter()
1155 .map(|&term_idx| {
1156 if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
1157 measure_jet_psi_dim(mj)
1160 } else if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
1161 get_spatial_feature_dim(resolvedspec, term_idx).unwrap_or(1)
1162 } else {
1163 1
1164 }
1165 })
1166 .collect()
1167}
1168
1169fn has_aniso_terms(resolvedspec: &TermCollectionSpec, spatial_terms: &[usize]) -> bool {
1173 spatial_terms
1174 .iter()
1175 .any(|&term_idx| spatial_term_uses_per_axis_psi(resolvedspec, term_idx))
1176}
1177
/// Expands to `memoized_cost`, `memoized_eval` and `store_eval` methods for a cache type that
/// has `current_theta`, `last_cost` and `last_eval` fields.
///
/// Memo hits are keyed on `current_theta`; `store_eval` does not record θ itself, so the
/// implementing type is expected to reset the memo when it changes `current_theta`.
macro_rules! impl_exact_joint_theta_memo {
    () => {
        // Cost at `theta`: a full eval's cost wins over a cost-only entry.
        fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
            let theta_is_current = self
                .current_theta
                .as_ref()
                .is_some_and(|cached| theta_values_match(cached, theta));
            if !theta_is_current {
                return None;
            }
            match self.last_eval.as_ref() {
                Some(cached) => Some(cached.0),
                None => self.last_cost,
            }
        }

        // Full (cost, vector, Hessian) eval at `theta`, if one was stored for the current θ.
        fn memoized_eval(
            &self,
            theta: &Array1<f64>,
        ) -> Option<(
            f64,
            Array1<f64>,
            gam_problem::HessianResult,
        )> {
            let theta_is_current = self
                .current_theta
                .as_ref()
                .is_some_and(|cached| theta_values_match(cached, theta));
            if theta_is_current {
                self.last_eval.clone()
            } else {
                None
            }
        }

        // Records a full eval; its cost is mirrored into `last_cost`.
        fn store_eval(
            &mut self,
            eval: (
                f64,
                Array1<f64>,
                gam_problem::HessianResult,
            ),
        ) {
            let cost = eval.0;
            self.last_eval = Some(eval);
            self.last_cost = Some(cost);
        }
    };
}
1232
/// Memoizing design cache for the single-block exact joint (ρ, ψ) outer optimization.
///
/// Wraps a frozen incremental realizer that re-applies spatial log-κ (ψ) coordinates taken
/// from the tail of θ, and remembers the most recent cost / full evaluation and hyper-directions.
struct SingleBlockExactJointDesignCache<'d> {
    /// Owns the spec/design and applies ψ updates via `apply_log_kappa`.
    realizer: FrozenTermCollectionIncrementalRealizer<'d>,
    /// θ whose ψ tail is currently realized in `realizer`.
    current_theta: Option<Array1<f64>>,
    /// θ at which `last_cost` / `last_eval` were recorded.
    last_eval_theta: Option<Array1<f64>>,
    /// Most recently stored cost (cost-only or from a full eval).
    last_cost: Option<f64>,
    /// Most recent full evaluation; `.0` is the cost. NOTE(review): the `Array1` is presumably
    /// the gradient — confirm against the store sites.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Hyper-directions tagged with the realizer design revision they were built at.
    cached_hyper_dirs: Option<(u64, Vec<DirectionalHyperParam>)>,
    /// Smooth-term indices that carry ψ coordinates.
    spatial_terms: Vec<usize>,
    /// Number of leading θ entries before the ψ block (`theta[rho_dim..]` is ψ).
    rho_dim: usize,
    /// ψ coordinates contributed by each entry of `spatial_terms`.
    dims_per_term: Vec<usize>,
}
1265
1266impl<'d> SingleBlockExactJointDesignCache<'d> {
1267 fn new(
1268 data: ArrayView2<'d, f64>,
1269 spec: TermCollectionSpec,
1270 design: TermCollectionDesign,
1271 spatial_terms: Vec<usize>,
1272 rho_dim: usize,
1273 dims_per_term: Vec<usize>,
1274 ) -> Result<Self, String> {
1275 Ok(Self {
1276 realizer: FrozenTermCollectionIncrementalRealizer::new(data, spec, design)?,
1277 current_theta: None,
1278 last_eval_theta: None,
1279 last_cost: None,
1280 last_eval: None,
1281 cached_hyper_dirs: None,
1282 spatial_terms,
1283 rho_dim,
1284 dims_per_term,
1285 })
1286 }
1287
1288 fn design_revision(&self) -> u64 {
1289 self.realizer.design_revision()
1290 }
1291
1292 fn hyper_dirs_for_current_design(
1302 &mut self,
1303 data: ArrayView2<'_, f64>,
1304 kind: SpatialHyperKind,
1305 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1306 let revision = self.realizer.design_revision();
1307 if let Some((cached_rev, dirs)) = self.cached_hyper_dirs.as_ref()
1308 && *cached_rev == revision
1309 {
1310 return Ok(dirs.clone());
1311 }
1312 let dirs = try_build_spatial_log_kappa_hyper_dirs(
1313 data,
1314 self.realizer.spec(),
1315 self.realizer.design(),
1316 &self.spatial_terms,
1317 )?
1318 .ok_or_else(|| {
1319 EstimationError::InvalidInput(format!(
1320 "failed to build {} hyper_dirs at current {}",
1321 kind.adjective(),
1322 kind.coord_name(),
1323 ))
1324 })?;
1325 self.cached_hyper_dirs = Some((revision, dirs.clone()));
1326 Ok(dirs)
1327 }
1328
1329 fn nfree_tensor_gradient_hyper_dirs(
1330 &mut self,
1331 theta: &Array1<f64>,
1332 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1333 let psi = &theta.as_slice().ok_or_else(|| {
1334 EstimationError::InvalidInput(
1335 "nfree_tensor_gradient_hyper_dirs: theta is not contiguous".to_string(),
1336 )
1337 })?[self.rho_dim..];
1338 let (global_range, p_total, s_psi_components) = self
1339 .realizer
1340 .canonical_penalty_derivatives_at_psi(&self.spatial_terms, psi)
1341 .map_err(EstimationError::InvalidInput)?;
1342 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::zero(
1343 self.realizer.design().design.nrows(),
1344 p_total,
1345 );
1346 let components = s_psi_components
1347 .into_iter()
1348 .enumerate()
1349 .map(|(penalty_index, local)| {
1350 (
1351 penalty_index,
1352 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1353 local,
1354 global_range.clone(),
1355 p_total,
1356 ),
1357 )
1358 })
1359 .collect::<Vec<_>>();
1360 Ok(DirectionalHyperParam::new_compact(zero_x, components, None, None)?.not_penalty_like())
1361 .map(|dir| vec![dir])
1362 }
1363
1364 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1365 if self
1366 .current_theta
1367 .as_ref()
1368 .is_some_and(|cached| theta_values_match(cached, theta))
1369 {
1370 return Ok(());
1371 }
1372 let t_ensure = std::time::Instant::now();
1373 let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
1374 theta,
1375 self.rho_dim,
1376 self.dims_per_term.clone(),
1377 );
1378 self.realizer
1379 .apply_log_kappa(&log_kappa, &self.spatial_terms)?;
1380 log::info!(
1381 "[STAGE] ensure_theta (apply_log_kappa, {} terms): {:.3}s",
1382 self.spatial_terms.len(),
1383 t_ensure.elapsed().as_secs_f64(),
1384 );
1385 self.current_theta = Some(theta.clone());
1386 self.last_eval_theta = None;
1387 self.last_cost = None;
1388 self.last_eval = None;
1389 Ok(())
1390 }
1391
1392 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1399 if self
1400 .last_eval_theta
1401 .as_ref()
1402 .is_some_and(|cached| theta_values_match(cached, theta))
1403 {
1404 self.last_eval
1405 .as_ref()
1406 .map(|cached| cached.0)
1407 .or(self.last_cost)
1408 } else {
1409 None
1410 }
1411 }
1412
1413 fn memoized_eval(
1414 &self,
1415 theta: &Array1<f64>,
1416 ) -> Option<(
1417 f64,
1418 Array1<f64>,
1419 gam_problem::HessianResult,
1420 )> {
1421 if self
1422 .last_eval_theta
1423 .as_ref()
1424 .is_some_and(|cached| theta_values_match(cached, theta))
1425 {
1426 self.last_eval.clone()
1427 } else {
1428 None
1429 }
1430 }
1431
1432 fn store_eval_at(
1436 &mut self,
1437 theta: &Array1<f64>,
1438 eval: (
1439 f64,
1440 Array1<f64>,
1441 gam_problem::HessianResult,
1442 ),
1443 ) {
1444 self.last_eval_theta = Some(theta.clone());
1445 self.last_cost = Some(eval.0);
1446 self.last_eval = Some(eval);
1447 }
1448
1449 fn store_cost_at(&mut self, theta: &Array1<f64>, cost: f64) {
1452 self.last_eval_theta = Some(theta.clone());
1453 self.last_cost = Some(cost);
1454 self.last_eval = None;
1458 }
1459
1460 fn spec(&self) -> &TermCollectionSpec {
1461 self.realizer.spec()
1462 }
1463
1464 fn design(&self) -> &TermCollectionDesign {
1465 self.realizer.design()
1466 }
1467
1468 fn supports_nfree_penalty_rekey(&self) -> bool {
1474 self.realizer
1475 .supports_nfree_penalty_rekey(&self.spatial_terms)
1476 }
1477
1478 fn supports_nfree_gradient_only_routing(&self) -> bool {
1479 self.realizer
1480 .supports_nfree_gradient_only_routing(&self.spatial_terms)
1481 }
1482
1483 fn canonical_penalties_at(
1493 &mut self,
1494 theta: &Array1<f64>,
1495 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
1496 let psi = &theta
1497 .as_slice()
1498 .ok_or_else(|| "canonical_penalties_at: theta is not contiguous".to_string())?
1499 [self.rho_dim..];
1500 self.realizer
1501 .canonical_penalties_at_psi(&self.spatial_terms, psi)
1502 }
1503}
1504
/// Design cache for a single-block fit whose θ carries free latent coordinates for one smooth
/// term.
///
/// The latent block of θ is written into `feature_cols` of an owned copy of the data, and the
/// design plus hyper-directions are rebuilt (or fetched from `latent_design_cache`) whenever θ
/// changes.
struct SingleBlockLatentCoordDesignCache {
    /// Owned data copy; the latent feature columns are overwritten in `ensure_theta`.
    data: Array2<f64>,
    spec: TermCollectionSpec,
    /// Design realized for `current_theta`.
    design: TermCollectionDesign,
    current_theta: Option<Array1<f64>>,
    current_latent: Option<std::sync::Arc<gam_terms::latent::LatentCoordValues>>,
    current_hyper_dirs: Option<Vec<gam_solve::estimate::reml::DirectionalHyperParam>>,
    /// Entry id in `latent_design_cache` that `design` was taken from.
    current_design_cache_id: Option<u64>,
    latent_design_cache: gam_solve::latent_cache::LatentDesignCache,
    last_cost: Option<f64>,
    /// Most recent full evaluation; `.0` is the cost.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Index of the latent smooth term.
    term_index: gam_problem::types::SmoothTermIdx,
    /// Data columns receiving the latent coordinates, one per latent axis.
    feature_cols: Vec<usize>,
    /// Number of leading θ entries before the flattened latent block.
    rho_dim: usize,
    n_obs: usize,
    latent_dim: usize,
    /// Identification mode / manifold / retraction / id copied from the initial latent values
    /// and reused when rebuilding `LatentCoordValues` from θ.
    id_mode: gam_terms::latent::LatentIdMode,
    manifold: gam_terms::latent::LatentManifold,
    retraction_registry: gam_solve::latent_cache::LatentRetractionRegistry,
    latent_id: u64,
    analytic_penalties: Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>>,
    /// Total ρ count of `analytic_penalties` (0 when absent).
    analytic_rho_count: usize,
    /// Counter bumped whenever the realized design changes.
    design_revision: u64,
    /// Outer iteration at which `last_cost` / `last_eval` were stored.
    last_outer_iter: Option<u64>,
}
1537
1538impl SingleBlockLatentCoordDesignCache {
1539 fn new(
1540 data: Array2<f64>,
1541 spec: TermCollectionSpec,
1542 design: TermCollectionDesign,
1543 latent: &StandardLatentCoordConfig,
1544 rho_dim: usize,
1545 ) -> Result<Self, String> {
1546 if latent.term_index.get() >= spec.smooth_terms.len() {
1547 return Err(SmoothError::dimension_mismatch(format!(
1548 "latent-coordinate term index {} out of bounds for {} smooth terms",
1549 latent.term_index,
1550 spec.smooth_terms.len()
1551 ))
1552 .into());
1553 }
1554 if latent.feature_cols.len() != latent.values.latent_dim() {
1555 return Err(SmoothError::dimension_mismatch(format!(
1556 "latent-coordinate feature width mismatch: feature_cols={}, latent_dim={}",
1557 latent.feature_cols.len(),
1558 latent.values.latent_dim()
1559 ))
1560 .into());
1561 }
1562 if latent.values.n_obs() != data.nrows() {
1563 return Err(SmoothError::dimension_mismatch(format!(
1564 "latent-coordinate row mismatch: latent n={}, data n={}",
1565 latent.values.n_obs(),
1566 data.nrows()
1567 ))
1568 .into());
1569 }
1570 let analytic_rho_count = latent
1571 .analytic_penalties
1572 .as_ref()
1573 .map_or(0, |registry| registry.total_rho_count());
1574 Ok(Self {
1575 data,
1576 spec,
1577 design,
1578 current_theta: None,
1579 current_latent: None,
1580 current_hyper_dirs: None,
1581 current_design_cache_id: None,
1582 latent_design_cache: gam_solve::latent_cache::LatentDesignCache::default(),
1583 last_cost: None,
1584 last_eval: None,
1585 term_index: latent.term_index,
1586 feature_cols: latent.feature_cols.clone(),
1587 rho_dim,
1588 n_obs: latent.values.n_obs(),
1589 latent_dim: latent.values.latent_dim(),
1590 id_mode: latent.values.id_mode().clone(),
1591 manifold: latent.values.manifold().clone(),
1592 retraction_registry: latent.values.retraction_registry().clone(),
1593 latent_id: latent.values.latent_id(),
1594 analytic_penalties: latent.analytic_penalties.clone(),
1595 analytic_rho_count,
1596 design_revision: 0,
1597 last_outer_iter: None,
1598 })
1599 }
1600
1601 fn design_revision(&self) -> u64 {
1602 self.design_revision
1603 }
1604
1605 fn design(&self) -> &TermCollectionDesign {
1606 &self.design
1607 }
1608
1609 fn latent(&self) -> Result<std::sync::Arc<gam_terms::latent::LatentCoordValues>, String> {
1610 self.current_latent
1611 .as_ref()
1612 .cloned()
1613 .ok_or_else(|| "latent-coordinate cache has not been realized".to_string())
1614 }
1615
1616 fn analytic_penalties(&self) -> Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>> {
1617 self.analytic_penalties.clone()
1618 }
1619
1620 fn analytic_penalty_rho_count(&self) -> usize {
1621 self.analytic_rho_count
1622 }
1623
1624 fn hyper_dirs(&self) -> Result<Vec<gam_solve::estimate::reml::DirectionalHyperParam>, String> {
1625 self.current_hyper_dirs
1626 .as_ref()
1627 .cloned()
1628 .ok_or_else(|| "latent-coordinate hyper_dirs cache has not been realized".to_string())
1629 }
1630
1631 fn latent_basis_kind(&self) -> Result<gam_solve::latent_cache::LatentBasisKind, String> {
1632 let smooth_term = self
1633 .design
1634 .smooth
1635 .terms
1636 .get(self.term_index.get())
1637 .ok_or_else(|| {
1638 SmoothError::dimension_mismatch(format!(
1639 "LatentCoord term index {} out of bounds for realized smooth design",
1640 self.term_index
1641 ))
1642 })?;
1643 let termspec = self
1644 .spec
1645 .smooth_terms
1646 .get(self.term_index.get())
1647 .ok_or_else(|| {
1648 SmoothError::dimension_mismatch(format!(
1649 "LatentCoord term index {} out of bounds for resolved smooth spec",
1650 self.term_index
1651 ))
1652 })?;
1653 match (&termspec.basis, &smooth_term.metadata) {
1654 (
1655 SmoothBasisSpec::Matern { .. },
1656 BasisMetadata::Matern {
1657 centers,
1658 length_scale,
1659 nu,
1660 aniso_log_scales,
1661 ..
1662 },
1663 ) => Ok(gam_solve::latent_cache::LatentBasisKind::Matern {
1664 centers: centers.clone(),
1665 length_scale: *length_scale,
1666 nu: *nu,
1667 aniso_log_scales: aniso_log_scales
1668 .clone()
1669 .unwrap_or_else(|| vec![0.0; centers.ncols()]),
1670 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1671 self.n_obs,
1672 centers.nrows(),
1673 ),
1674 }),
1675 (
1676 SmoothBasisSpec::Duchon { .. },
1677 BasisMetadata::Duchon {
1678 centers,
1679 length_scale,
1680 power,
1681 nullspace_order,
1682 aniso_log_scales,
1683 ..
1684 },
1685 ) => Ok(gam_solve::latent_cache::LatentBasisKind::Duchon {
1686 centers: centers.clone(),
1687 length_scale: *length_scale,
1688 power: *power,
1689 nullspace_order: *nullspace_order,
1690 aniso_log_scales: aniso_log_scales
1691 .clone()
1692 .unwrap_or_else(|| vec![0.0; centers.ncols()]),
1693 }),
1694 (
1695 SmoothBasisSpec::Sphere { .. },
1696 BasisMetadata::Sphere {
1697 centers,
1698 penalty_order,
1699 method,
1700 ..
1701 },
1702 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
1703 Ok(gam_solve::latent_cache::LatentBasisKind::Sphere {
1704 centers: centers.clone(),
1705 penalty_order: *penalty_order,
1706 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1707 self.n_obs,
1708 centers.nrows(),
1709 ),
1710 })
1711 }
1712 (
1713 SmoothBasisSpec::BSpline1D { spec, .. },
1714 BasisMetadata::BSpline1D {
1715 knots,
1716 periodic,
1717 degree: meta_degree,
1718 ..
1719 },
1720 ) => {
1721 let effective_degree = meta_degree.unwrap_or(spec.degree);
1725 if let Some((domain_start, period, num_basis)) = periodic {
1726 Ok(
1727 gam_solve::latent_cache::LatentBasisKind::PeriodicBspline {
1728 domain_start: *domain_start,
1729 period: *period,
1730 degree: effective_degree,
1731 num_basis: *num_basis,
1732 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1733 self.n_obs, *num_basis,
1734 ),
1735 },
1736 )
1737 } else {
1738 let num_basis_est = knots.len().saturating_sub(effective_degree + 1);
1739 Ok(
1740 gam_solve::latent_cache::LatentBasisKind::TensorBspline {
1741 knots: vec![knots.clone()],
1742 degrees: vec![effective_degree],
1743 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1744 self.n_obs,
1745 num_basis_est,
1746 ),
1747 },
1748 )
1749 }
1750 }
1751 (
1752 SmoothBasisSpec::TensorBSpline { .. },
1753 BasisMetadata::TensorBSpline { knots, degrees, .. },
1754 ) => Ok(
1755 gam_solve::latent_cache::LatentBasisKind::TensorBspline {
1756 knots: knots.clone(),
1757 degrees: degrees.clone(),
1758 chunk_size: None,
1759 },
1760 ),
1761 (
1762 SmoothBasisSpec::Pca { .. },
1763 BasisMetadata::Pca {
1764 basis_matrix,
1765 centered,
1766 smooth_penalty,
1767 center_mean,
1768 pca_basis_path,
1769 chunk_size,
1770 ..
1771 },
1772 ) => {
1773 let center_mean_fingerprint = if *centered && pca_basis_path.is_none() {
1774 let mean = center_mean.as_ref().ok_or_else(|| {
1775 SmoothError::invalid_config(
1776 "latent-coordinate Pca cache key requires center_mean when centered",
1777 )
1778 })?;
1779 Some(gam_solve::latent_cache::pca_center_mean_fingerprint(
1780 mean,
1781 ))
1782 } else {
1783 None
1784 };
1785 Ok(gam_solve::latent_cache::LatentBasisKind::Pca {
1786 basis_matrix: basis_matrix.clone(),
1787 centered: *centered,
1788 center_mean_fingerprint,
1789 smooth_penalty: *smooth_penalty,
1790 pca_basis_path: pca_basis_path.clone(),
1791 chunk_size: *chunk_size,
1792 })
1793 }
1794 _ => Err(SmoothError::invalid_config(
1795 "latent-coordinate design cache could not key the realized latent smooth basis"
1796 .to_string(),
1797 )
1798 .into()),
1799 }
1800 }
1801
1802 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1803 if self
1804 .current_theta
1805 .as_ref()
1806 .is_some_and(|cached| theta_values_match(cached, theta))
1807 {
1808 return Ok(());
1809 }
1810 let latent_flat_len = self.n_obs * self.latent_dim;
1811 let direct_hyper_count = latent_coord_direct_hyper_count(&self.id_mode, self.latent_dim);
1812 let expected =
1813 self.rho_dim + latent_flat_len + self.analytic_rho_count + direct_hyper_count;
1814 if theta.len() != expected {
1815 return Err(SmoothError::dimension_mismatch(format!(
1816 "latent-coordinate theta length mismatch: got {}, expected {} (rho_dim={}, n={}, d={}, analytic_rhos={}, direct_hypers={})",
1817 theta.len(),
1818 expected,
1819 self.rho_dim,
1820 self.n_obs,
1821 self.latent_dim,
1822 self.analytic_rho_count,
1823 direct_hyper_count
1824 ))
1825 .into());
1826 }
1827 let flat = theta
1828 .slice(s![self.rho_dim..self.rho_dim + latent_flat_len])
1829 .to_owned();
1830 let latent = std::sync::Arc::new(
1831 gam_terms::latent::LatentCoordValues::from_flat_with_manifold_and_retraction_and_id(
1832 flat,
1833 self.n_obs,
1834 self.latent_dim,
1835 self.id_mode.clone(),
1836 self.manifold.clone(),
1837 self.retraction_registry.clone(),
1838 self.latent_id,
1839 ),
1840 );
1841 let latent_values_changed = self
1842 .current_latent
1843 .as_ref()
1844 .map(|cached| !latent_values_match(cached.as_flat(), latent.as_flat()))
1845 .unwrap_or(true);
1846 if latent_values_changed {
1847 self.latent_design_cache.invalidate_all();
1848 self.current_design_cache_id = None;
1849 self.design_revision = self.design_revision.wrapping_add(1);
1850 }
1851 for n in 0..self.n_obs {
1852 for axis in 0..self.latent_dim {
1853 let col = self.feature_cols[axis];
1854 self.data[[n, col]] = latent.as_flat()[n * self.latent_dim + axis];
1855 }
1856 }
1857
1858 let basis_kind = self.latent_basis_kind()?;
1859 let rebuilt_width = self.design.design.ncols();
1860 let spec = self.spec.clone();
1861 let term_index = self.term_index;
1862 let analytic_rho_count = self.analytic_rho_count;
1863 let data = self.data.view();
1864 let design_context_digest =
1865 gam_solve::latent_cache::latent_design_context_cache_digest(
1866 data,
1867 &spec,
1868 term_index,
1869 analytic_rho_count,
1870 &self.feature_cols,
1871 )
1872 .map_err(|e| e.to_string())?;
1873 let lookup = self
1874 .latent_design_cache
1875 .lookup_or_compute(latent.clone(), basis_kind, design_context_digest, || {
1876 let rebuilt = build_term_collection_design(data, &spec).map_err(|e| {
1877 EstimationError::InvalidInput(format!(
1878 "failed to rebuild latent-coordinate design: {e}"
1879 ))
1880 })?;
1881 if rebuilt.design.ncols() != rebuilt_width {
1882 crate::bail_invalid_estim!(
1883 "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
1884 rebuilt.design.ncols(),
1885 rebuilt_width
1886 );
1887 }
1888 let hyper_dirs = try_build_latent_coord_hyper_dirs(
1889 latent.clone(),
1890 &spec,
1891 &rebuilt,
1892 &[term_index],
1893 analytic_rho_count,
1894 )?
1895 .ok_or_else(|| {
1896 EstimationError::InvalidInput(
1897 "failed to build latent-coordinate hyper_dirs".to_string(),
1898 )
1899 })?;
1900 Ok(gam_solve::latent_cache::ComputedLatentDesign {
1901 design: rebuilt,
1902 hyper_dirs,
1903 })
1904 })
1905 .map_err(|e| e.to_string())?;
1906 if lookup.cached.design.design.ncols() != self.design.design.ncols() {
1907 return Err(SmoothError::dimension_mismatch(format!(
1908 "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
1909 lookup.cached.design.design.ncols(),
1910 self.design.design.ncols()
1911 ))
1912 .into());
1913 }
1914 self.design = lookup.cached.design.clone();
1915 self.current_hyper_dirs = Some(lookup.cached.hyper_dirs.clone());
1916 self.current_latent = Some(latent);
1917 self.current_theta = Some(theta.clone());
1918 self.last_cost = None;
1919 self.last_eval = None;
1920 self.last_outer_iter = None;
1921 if !latent_values_changed && self.current_design_cache_id != Some(lookup.entry_id) {
1922 self.design_revision = self.design_revision.wrapping_add(1);
1923 }
1924 self.current_design_cache_id = Some(lookup.entry_id);
1925 Ok(())
1926 }
1927
1928 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1929 if self
1930 .current_theta
1931 .as_ref()
1932 .is_some_and(|cached| theta_values_match(cached, theta))
1933 && self.last_outer_iter
1934 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
1935 {
1936 self.last_eval
1937 .as_ref()
1938 .map(|cached| cached.0)
1939 .or(self.last_cost)
1940 } else {
1941 None
1942 }
1943 }
1944
1945 fn memoized_eval(
1946 &self,
1947 theta: &Array1<f64>,
1948 ) -> Option<(
1949 f64,
1950 Array1<f64>,
1951 gam_problem::HessianResult,
1952 )> {
1953 if self
1954 .current_theta
1955 .as_ref()
1956 .is_some_and(|cached| theta_values_match(cached, theta))
1957 && self.last_outer_iter
1958 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
1959 {
1960 self.last_eval.clone()
1961 } else {
1962 None
1963 }
1964 }
1965
1966 fn store_eval(
1967 &mut self,
1968 eval: (
1969 f64,
1970 Array1<f64>,
1971 gam_problem::HessianResult,
1972 ),
1973 ) {
1974 self.last_cost = Some(eval.0);
1975 self.last_eval = Some(eval);
1976 self.last_outer_iter =
1977 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
1978 }
1979
1980 fn store_cost(&mut self, cost: f64) {
1981 self.last_cost = Some(cost);
1982 self.last_outer_iter =
1983 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
1984 }
1985
1986 fn reset(&mut self) {
1987 self.current_theta = None;
1988 self.current_latent = None;
1989 self.current_hyper_dirs = None;
1990 self.current_design_cache_id = None;
1991 self.latent_design_cache.invalidate();
1992 self.last_cost = None;
1993 self.last_eval = None;
1994 self.last_outer_iter = None;
1995 }
1996}
1997
1998pub fn fixed_kappa_profiled_reml_score(
2014 data: ArrayView2<'_, f64>,
2015 y: ArrayView1<'_, f64>,
2016 weights: ArrayView1<'_, f64>,
2017 offset: ArrayView1<'_, f64>,
2018 resolvedspec: &TermCollectionSpec,
2019 term_idx: usize,
2020 kappa: f64,
2021 family: LikelihoodSpec,
2022 options: &FitOptions,
2023) -> Result<f64, EstimationError> {
2024 if !kappa.is_finite() {
2025 crate::bail_invalid_estim!("fixed-κ profiled score probed a non-finite κ = {kappa}");
2026 }
2027 let (feature_cols, mut probe_basis) = match resolvedspec
2030 .smooth_terms
2031 .get(term_idx)
2032 .map(|t| &t.basis)
2033 {
2034 Some(SmoothBasisSpec::ConstantCurvature {
2035 feature_cols, spec, ..
2036 }) => (feature_cols.clone(), spec.clone()),
2037 _ => {
2038 crate::bail_invalid_estim!(
2039 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2040 )
2041 }
2042 };
2043 probe_basis.kappa = kappa;
2044
2045 let is_unweighted = weights.iter().all(|&w| (w - 1.0).abs() <= 1e-12);
2065 let is_zero_offset = offset.iter().all(|&o| o.abs() <= 1e-12);
2066 if family == LikelihoodSpec::gaussian_identity() && is_unweighted && is_zero_offset {
2067 let x_term = select_columns(data, &feature_cols).map_err(EstimationError::from)?;
2068 let score =
2069 gam_terms::basis::constant_curvature_honest_profiled_reml_score(x_term.view(), y, &probe_basis)
2070 .map_err(|e| {
2071 EstimationError::InvalidInput(format!(
2072 "fixed-κ honest profiled-REML score at κ={kappa} failed: {e}"
2073 ))
2074 })?;
2075 if !score.is_finite() {
2076 crate::bail_invalid_estim!(
2077 "fixed-κ honest profiled-REML score at κ={kappa} is non-finite"
2078 );
2079 }
2080 return Ok(score);
2081 }
2082
2083 let mut probe_spec = resolvedspec.clone();
2085 match probe_spec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis) {
2086 Some(SmoothBasisSpec::ConstantCurvature { spec, .. }) => spec.kappa = kappa,
2087 _ => {
2088 crate::bail_invalid_estim!(
2089 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2090 )
2091 }
2092 }
2093 let fixed_kappa_options = SpatialLengthScaleOptimizationOptions {
2094 enabled: false,
2095 ..SpatialLengthScaleOptimizationOptions::default()
2096 };
2097 let fit = fit_term_collectionwith_spatial_length_scale_optimization(
2098 data,
2099 y.to_owned(),
2100 weights.to_owned(),
2101 offset.to_owned(),
2102 &probe_spec,
2103 family,
2104 options,
2105 &fixed_kappa_options,
2106 )?;
2107 let score = fit_score(&fit.fit);
2108 if !score.is_finite() {
2109 crate::bail_invalid_estim!("fixed-κ profiled fit at κ={kappa} returned a non-finite score");
2110 }
2111 Ok(score)
2112}
2113
2114fn constant_curvature_kappa_fair_argmin(
2139 data: ArrayView2<'_, f64>,
2140 y: ArrayView1<'_, f64>,
2141 resolvedspec: &TermCollectionSpec,
2142 term_idx: usize,
2143) -> Option<f64> {
2144 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2145 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2146 return None;
2147 }
2148 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2149 Some(SmoothBasisSpec::ConstantCurvature {
2150 feature_cols, spec, ..
2151 }) => (feature_cols, spec.clone()),
2152 _ => return None,
2153 };
2154 let x_term = match select_columns(data, feature_cols) {
2155 Ok(x) => x,
2156 Err(e) => {
2157 log::info!("[spatial-kappa] #1464 κ-fair argmin column select failed ({e}); skipping");
2158 return None;
2159 }
2160 };
2161 const GRID_STEPS: usize = 24;
2167 let mut best: Option<(f64, f64)> = None; for i in 0..=GRID_STEPS {
2169 let t = i as f64 / GRID_STEPS as f64;
2170 let kappa = kappa_min + (kappa_max - kappa_min) * t;
2171 let mut probe_spec = base_spec.clone();
2172 probe_spec.kappa = kappa;
2173 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec) {
2174 Ok(score) => {
2175 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2176 best = Some((score, kappa));
2177 }
2178 }
2179 Err(e) => {
2180 log::info!(
2181 "[spatial-kappa] #1464 κ-fair argmin probe at κ={kappa:.4} failed ({e}); skipping"
2182 );
2183 }
2184 }
2185 }
2186 best.map(|(score, kappa)| {
2187 log::info!(
2188 "[spatial-kappa] #1464 κ-fair argmin κ̂={kappa:.4} (κ-fair score={score:.6e}) for term {term_idx}"
2189 );
2190 kappa
2191 })
2192}
2193
2194fn select_constant_curvature_kappa_sign_seed(
2202 data: ArrayView2<'_, f64>,
2203 y: ArrayView1<'_, f64>,
2204 resolvedspec: &TermCollectionSpec,
2205 term_idx: usize,
2206) -> Option<f64> {
2207 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2208 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2209 return None;
2210 }
2211 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2223 Some(SmoothBasisSpec::ConstantCurvature {
2224 feature_cols, spec, ..
2225 }) => (feature_cols, spec.clone()),
2226 _ => return None,
2227 };
2228 let x_term = match select_columns(data, feature_cols) {
2229 Ok(x) => x,
2230 Err(e) => {
2231 log::info!("[spatial-kappa] #1464 sign-basin scan column select failed ({e}); skipping");
2232 return None;
2233 }
2234 };
2235 let probes = [
2239 kappa_min,
2240 0.5 * kappa_min,
2241 0.0,
2242 0.5 * kappa_max,
2243 kappa_max,
2244 ];
2245 let mut best: Option<(f64, f64)> = None; for &kappa in &probes {
2247 let mut probe_spec = base_spec.clone();
2248 probe_spec.kappa = kappa;
2249 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(
2250 x_term.view(),
2251 y,
2252 &probe_spec,
2253 ) {
2254 Ok(score) => {
2255 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2256 best = Some((score, kappa));
2257 }
2258 }
2259 Err(e) => {
2260 log::info!(
2261 "[spatial-kappa] #1464 sign-basin probe at κ={kappa:.4} failed ({e}); skipping"
2262 );
2263 }
2264 }
2265 }
2266 best.map(|(score, kappa)| {
2267 log::info!(
2268 "[spatial-kappa] #1464 κ-fair sign-basin scan selected κ_seed={kappa:.4} \
2269 (κ-fair score={score:.6e}) for term {term_idx}"
2270 );
2271 kappa
2272 })
2273}
2274
/// Number of evenly spaced ψ probes per term (both window endpoints included) in the isotropic
/// spatial-range pre-scan.
const SPATIAL_RANGE_PRESCAN_GRID: usize = 7;
2278
2279fn prescan_isotropic_spatial_range_seed(
2311 data: ArrayView2<'_, f64>,
2312 y: ArrayView1<'_, f64>,
2313 weights: ArrayView1<'_, f64>,
2314 offset: ArrayView1<'_, f64>,
2315 resolvedspec: &TermCollectionSpec,
2316 baseline_score: f64,
2317 family: &LikelihoodSpec,
2318 options: &FitOptions,
2319 kappa_options: &SpatialLengthScaleOptimizationOptions,
2320 spatial_terms: &[usize],
2321) -> Result<Vec<(usize, f64)>, EstimationError> {
2322 if has_aniso_terms(resolvedspec, spatial_terms)
2324 || !constant_curvature_term_indices(resolvedspec).is_empty()
2325 {
2326 return Ok(Vec::new());
2327 }
2328 let dims = spatial_dims_per_term(resolvedspec, spatial_terms);
2329 let mut working = resolvedspec.clone();
2333 let mut best_score = if baseline_score.is_finite() {
2334 baseline_score
2335 } else {
2336 f64::INFINITY
2337 };
2338 let mut overrides: Vec<(usize, f64)> = Vec::new();
2339 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2340 if dims.get(slot).copied().unwrap_or(1) != 1 {
2343 continue;
2344 }
2345 if get_spatial_length_scale(&working, term_idx).is_none() {
2348 continue;
2349 }
2350 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &working, term_idx, kappa_options);
2351 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2352 continue;
2353 }
2354 let mut term_best: Option<f64> = None;
2355 for g in 0..SPATIAL_RANGE_PRESCAN_GRID {
2356 let frac = g as f64 / (SPATIAL_RANGE_PRESCAN_GRID - 1) as f64;
2357 let psi = psi_lo + (psi_hi - psi_lo) * frac;
2358 let ls = (-psi).exp();
2362 if !ls.is_finite() || ls <= 0.0 {
2363 continue;
2364 }
2365 let mut probe = working.clone();
2366 if set_spatial_length_scale(&mut probe, term_idx, ls).is_err() {
2367 continue;
2368 }
2369 let fit = match fit_term_collection_forspec(
2378 data,
2379 y,
2380 weights,
2381 offset,
2382 &probe,
2383 family.clone(),
2384 options,
2385 ) {
2386 Ok(fit) => fit,
2387 Err(_) => continue,
2390 };
2391 let score = fit_score(&fit.fit);
2392 if score.is_finite() && score < best_score - 1e-7 * best_score.abs().max(1.0) {
2395 best_score = score;
2396 term_best = Some(ls);
2397 }
2398 }
2399 if let Some(ls) = term_best {
2400 set_spatial_length_scale(&mut working, term_idx, ls)?;
2401 overrides.push((term_idx, ls));
2402 log::info!(
2403 "[spatial-kappa] #1074 range pre-scan: term {term_idx} re-seeded at \
2404 length_scale={ls:.5} (profiled REML {best_score:.5}, was {baseline_score:.5})"
2405 );
2406 }
2407 }
2408 Ok(overrides)
2409}
2410
/// Relative positions in [0, 1] inside each spatial term's ψ window (0 = lower ψ bound,
/// 1 = upper). NOTE(review): presumably each is passed as the `fraction` argument of
/// `joint_solve_from_window_fraction` to seed one restart — the caller is not visible here.
const JOINT_RESTART_WINDOW_FRACTIONS: [f64; 5] = [0.0, 0.2, 0.45, 0.7, 1.0];
2420
2421fn joint_solve_from_window_fraction(
2437 data: ArrayView2<'_, f64>,
2438 y: ArrayView1<'_, f64>,
2439 weights: ArrayView1<'_, f64>,
2440 offset: ArrayView1<'_, f64>,
2441 base_spec: &TermCollectionSpec,
2442 spatial_terms: &[usize],
2443 fraction: f64,
2444 family: &LikelihoodSpec,
2445 options: &FitOptions,
2446 baseline_options: &FitOptions,
2447 kappa_options: &SpatialLengthScaleOptimizationOptions,
2448) -> Result<Option<(FittedTermCollectionWithSpec, f64)>, EstimationError> {
2449 let mut seed_spec = base_spec.clone();
2450 let mut any_set = false;
2451 for &term_idx in spatial_terms {
2452 if get_spatial_length_scale(&seed_spec, term_idx).is_none() {
2453 continue;
2454 }
2455 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &seed_spec, term_idx, kappa_options);
2456 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2457 continue;
2458 }
2459 let psi = psi_lo + (psi_hi - psi_lo) * fraction;
2460 let ls = (-psi).exp();
2461 if !ls.is_finite() || ls <= 0.0 {
2462 continue;
2463 }
2464 if set_spatial_length_scale(&mut seed_spec, term_idx, ls).is_ok() {
2465 any_set = true;
2466 }
2467 }
2468 if !any_set {
2469 return Ok(None);
2470 }
2471 let seed_best = match fit_term_collection_forspec(
2475 data,
2476 y,
2477 weights,
2478 offset,
2479 &seed_spec,
2480 family.clone(),
2481 baseline_options,
2482 ) {
2483 Ok(fit) => fit,
2484 Err(_) => return Ok(None),
2485 };
2486 let seed_spec = freeze_term_collection_from_design(&seed_spec, &seed_best.design)?;
2487 let seed_terms = spatial_length_scale_term_indices(&seed_spec);
2490 if seed_terms.is_empty() {
2491 let score = fit_score(&seed_best.fit);
2492 return Ok(Some((
2493 FittedTermCollectionWithSpec {
2494 fit: seed_best.fit,
2495 design: seed_best.design,
2496 resolvedspec: seed_spec,
2497 adaptive_diagnostics: seed_best.adaptive_diagnostics,
2498 kappa_timing: None,
2499 },
2500 score,
2501 )));
2502 }
2503 let joint = try_exact_joint_spatial_length_scale_optimization(
2504 data,
2505 y,
2506 weights,
2507 offset,
2508 &seed_spec,
2509 &seed_best,
2510 family.clone(),
2511 options,
2512 kappa_options,
2513 &seed_terms,
2514 )?;
2515 match joint {
2516 Some(fit) => {
2517 let score = fit_score(&fit.fit);
2518 Ok(Some((fit, score)))
2519 }
2520 None => {
2523 let score = fit_score(&seed_best.fit);
2524 Ok(Some((
2525 FittedTermCollectionWithSpec {
2526 fit: seed_best.fit,
2527 design: seed_best.design,
2528 resolvedspec: seed_spec,
2529 adaptive_diagnostics: seed_best.adaptive_diagnostics,
2530 kappa_timing: None,
2531 },
2532 score,
2533 )))
2534 }
2535 }
2536}
2537
2538fn try_exact_joint_spatial_length_scale_optimization(
2539 data: ArrayView2<'_, f64>,
2540 y: ArrayView1<'_, f64>,
2541 weights: ArrayView1<'_, f64>,
2542 offset: ArrayView1<'_, f64>,
2543 resolvedspec: &TermCollectionSpec,
2544 best: &FittedTermCollection,
2545 family: LikelihoodSpec,
2546 options: &FitOptions,
2547 kappa_options: &SpatialLengthScaleOptimizationOptions,
2548 spatial_terms: &[usize],
2549) -> Result<Option<FittedTermCollectionWithSpec>, EstimationError> {
2550 if spatial_terms.is_empty() {
2551 return Ok(None);
2552 }
2553 kappa_options
2558 .validate()
2559 .map_err(EstimationError::InvalidInput)?;
2560
2561 let cc_term_set = constant_curvature_term_indices(resolvedspec);
2581 let all_spatial_are_cc =
2582 !cc_term_set.is_empty() && spatial_terms.iter().all(|t| cc_term_set.contains(t));
2583 if all_spatial_are_cc {
2584 let mut fixed_kappa_spec = resolvedspec.clone();
2585 let mut any_kappa_chosen = false;
2586 for &term_idx in spatial_terms {
2587 if let Some(kappa_hat) =
2598 constant_curvature_kappa_fair_argmin(data, y, resolvedspec, term_idx)
2599 .filter(|&k| k < 0.0)
2600 {
2601 if let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) = fixed_kappa_spec
2602 .smooth_terms
2603 .get_mut(term_idx)
2604 .map(|t| &mut t.basis)
2605 {
2606 cc.kappa = kappa_hat;
2607 any_kappa_chosen = true;
2608 log::info!(
2609 "[spatial-kappa] #1464 term {term_idx}: fixed κ̂ = {kappa_hat:.4} from κ-fair argmin (hyperbolic basin; profiling ρ only)"
2610 );
2611 }
2612 }
2613 }
2614 if any_kappa_chosen {
2615 let baseline_score = fit_score(&best.fit);
2619 let fitted = fit_term_collection_forspec(
2620 data,
2621 y,
2622 weights,
2623 offset,
2624 &fixed_kappa_spec,
2625 family.clone(),
2626 options,
2627 )?;
2628 let frozen_spec =
2629 freeze_term_collection_from_design(&fixed_kappa_spec, &fitted.design)?;
2630 let mut fit = fitted.fit;
2631 fit.reml_score = baseline_score;
2643 return Ok(Some(FittedTermCollectionWithSpec {
2644 fit,
2645 design: fitted.design,
2646 resolvedspec: frozen_spec,
2647 adaptive_diagnostics: fitted.adaptive_diagnostics,
2648 kappa_timing: None,
2649 }));
2650 }
2651 }
2652
2653 if try_build_spatial_log_kappa_hyper_dirs(data, resolvedspec, &best.design, spatial_terms)?
2654 .is_none()
2655 {
2656 if !constant_curvature_term_indices(resolvedspec).is_empty() {
2657 log::info!(
2658 "[#1464-trace] try_exact_joint RETURNED None (hyper_dirs unavailable); \
2659 κ̂ comes from a NON-joint path"
2660 );
2661 }
2662 return Ok(None);
2663 }
2664 if !constant_curvature_term_indices(resolvedspec).is_empty() {
2665 log::info!(
2666 "[#1464-trace] try_exact_joint ENTERED for {} spatial term(s); CC present",
2667 spatial_terms.len()
2668 );
2669 }
2670
2671 const JOINT_RHO_BOUND: f64 = 12.0;
2672 let rho_dim = best.fit.lambdas.len();
2673
2674 let has_constant_curvature_term = !constant_curvature_term_indices(resolvedspec).is_empty();
2688 let rho_upper_bound = if has_constant_curvature_term {
2689 gam_solve::estimate::RHO_BOUND
2690 } else {
2691 JOINT_RHO_BOUND
2692 };
2693
2694 let dims_per_term = spatial_dims_per_term(resolvedspec, spatial_terms);
2696 let use_aniso = has_aniso_terms(resolvedspec, spatial_terms);
2697
2698 let log_kappa0 = if use_aniso {
2703 SpatialLogKappaCoords::from_length_scales_aniso(resolvedspec, spatial_terms, kappa_options)
2704 } else {
2705 SpatialLogKappaCoords::from_length_scales(resolvedspec, spatial_terms, kappa_options)
2706 };
2707 let mut log_kappa0 =
2710 log_kappa0.reseed_from_data(data, resolvedspec, spatial_terms, kappa_options);
2711 let mut cc_sign_seeds: Vec<(usize, f64)> = Vec::new();
2727 if has_constant_curvature_term {
2728 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2729 if constant_curvature_term_spec(resolvedspec, term_idx).is_none() {
2730 continue;
2731 }
2732 let scan = select_constant_curvature_kappa_sign_seed(
2733 data,
2734 y,
2735 resolvedspec,
2736 term_idx,
2737 );
2738 match scan {
2743 Some(kappa_seed) => {
2744 log::info!(
2745 "[#1464-trace] term {term_idx}: κ-fair sign-basin scan picked κ_seed = {kappa_seed}"
2746 );
2747 log_kappa0.set_scalar_slot(slot, kappa_seed);
2748 cc_sign_seeds.push((slot, kappa_seed));
2749 }
2750 None => {
2751 log::info!(
2752 "[#1464-trace] term {term_idx}: fixed-κ sign-basin scan returned NONE (no seed applied)"
2753 );
2754 }
2755 }
2756 }
2757 }
2758 let log_kappa_lower = if use_aniso {
2759 SpatialLogKappaCoords::lower_bounds_aniso_from_data(
2760 data,
2761 resolvedspec,
2762 spatial_terms,
2763 &dims_per_term,
2764 kappa_options,
2765 )
2766 } else {
2767 SpatialLogKappaCoords::lower_bounds_from_data(
2768 data,
2769 resolvedspec,
2770 spatial_terms,
2771 kappa_options,
2772 )
2773 };
2774 let log_kappa_upper = if use_aniso {
2775 SpatialLogKappaCoords::upper_bounds_aniso_from_data(
2776 data,
2777 resolvedspec,
2778 spatial_terms,
2779 &dims_per_term,
2780 kappa_options,
2781 )
2782 } else {
2783 SpatialLogKappaCoords::upper_bounds_from_data(
2784 data,
2785 resolvedspec,
2786 spatial_terms,
2787 kappa_options,
2788 )
2789 };
2790 let mut log_kappa_lower = log_kappa_lower;
2814 let mut log_kappa_upper = log_kappa_upper;
2815 for &(slot, kappa_seed) in &cc_sign_seeds {
2816 if kappa_seed != 0.0 {
2817 log_kappa_lower.set_scalar_slot(slot, kappa_seed);
2818 log_kappa_upper.set_scalar_slot(slot, kappa_seed);
2819 }
2820 log::info!(
2821 "[#1464-trace] slot {slot}: FROZE joint ψ coordinate at κ_seed={kappa_seed} \
2822 (window [{}, {}]); raw fit_score is sign-blind so the κ-fair scan is authoritative",
2823 log_kappa_lower.as_array()[log_kappa_lower.dims_per_term()[..slot].iter().sum::<usize>()],
2824 log_kappa_upper.as_array()[log_kappa_upper.dims_per_term()[..slot].iter().sum::<usize>()],
2825 );
2826 }
2827 let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
2830 let setup = ExactJointHyperSetup::new(
2831 best.fit.lambdas.mapv(f64::ln),
2832 Array1::<f64>::from_elem(rho_dim, -JOINT_RHO_BOUND),
2833 Array1::<f64>::from_elem(rho_dim, rho_upper_bound),
2834 log_kappa0,
2835 log_kappa_lower,
2836 log_kappa_upper,
2837 );
2838
2839 let theta0 = setup.theta0();
2840 let lower = setup.lower();
2841 let upper = setup.upper();
2842
2843 let kind = if use_aniso {
2855 SpatialHyperKind::Anisotropic
2856 } else {
2857 SpatialHyperKind::Isotropic
2858 };
2859 let (outcome, kappa_timing) = run_exact_joint_spatial_optimization(
2860 kind,
2861 data,
2862 y,
2863 weights,
2864 offset,
2865 resolvedspec,
2866 &best.design,
2867 family.clone(),
2868 options,
2869 spatial_terms,
2870 &dims_per_term,
2871 &theta0,
2872 &lower,
2873 &upper,
2874 rho_dim,
2875 kappa_options,
2876 )?;
2877
2878 let baseline_score = fit_score(&best.fit);
2879
2880 let (theta_star, joint_final_value) = match outcome {
2890 SpatialJointOutcome::Optimized {
2891 theta_star,
2892 final_value,
2893 } => (theta_star, final_value),
2894 SpatialJointOutcome::NonConverged {
2895 iterations,
2896 final_value,
2897 final_grad_norm,
2898 } => {
2899 if has_constant_curvature_term {
2900 log::info!(
2901 "[#1464-trace] joint solve NONCONVERGED (iters={iterations}, \
2902 final_value={final_value}); returning FROZEN BASELINE geometry \
2903 (κ̂ = spec default, NOT the joint candidate)"
2904 );
2905 }
2906 log::info!(
2907 "[spatial-kappa] joint spatial optimization did not converge \
2908 (iterations={}, final_objective={:.6e}, final_grad_norm={}); \
2909 keeping the frozen baseline geometry",
2910 iterations,
2911 final_value,
2912 final_grad_norm.map_or_else(|| "n/a".to_string(), |g| format!("{g:.3e}")),
2913 );
2914 return Ok(Some(fit_frozen_baseline_geometry(
2915 data,
2916 y,
2917 weights,
2918 offset,
2919 resolvedspec,
2920 best,
2921 family,
2922 options,
2923 baseline_score,
2924 Some(kappa_timing),
2925 )?));
2926 }
2927 };
2928
2929 let accept_tol = options.tol.max(1e-8 * baseline_score.abs()).max(1e-12);
2934 if joint_final_value > baseline_score + accept_tol {
2935 if has_constant_curvature_term {
2936 log::info!(
2937 "[#1464-trace] joint candidate WORSENED score (joint={joint_final_value}, \
2938 baseline={baseline_score}); returning FROZEN BASELINE geometry \
2939 (κ̂ = spec default, NOT the joint candidate)"
2940 );
2941 }
2942 log::info!(
2943 "[spatial-kappa] exact joint spatial candidate worsened the profiled score (joint={:.6e}, baseline={:.6e}, tol={:.2e}); keeping the frozen baseline geometry",
2944 joint_final_value,
2945 baseline_score,
2946 accept_tol,
2947 );
2948 return Ok(Some(fit_frozen_baseline_geometry(
2949 data,
2950 y,
2951 weights,
2952 offset,
2953 resolvedspec,
2954 best,
2955 family,
2956 options,
2957 baseline_score,
2958 Some(kappa_timing),
2959 )?));
2960 }
2961
2962 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
2963 let log_kappa_star =
2964 SpatialLogKappaCoords::from_theta_tail_with_dims(&theta_star, rho_dim, dims_per_term);
2965 if has_constant_curvature_term {
2971 let star = log_kappa_star.as_array();
2972 let dims = log_kappa_star.dims_per_term();
2973 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2974 if constant_curvature_term_spec(resolvedspec, term_idx).is_some() {
2975 let off: usize = dims[..slot].iter().sum();
2976 log::info!(
2977 "[#1464-trace] term {term_idx}: joint solver CONVERGED ψ-tail κ = {} \
2978 (this is the optimised candidate; joint_final_value={joint_final_value})",
2979 star[off]
2980 );
2981 }
2982 }
2983 }
2984 let baseline_spec = resolvedspec;
2988 let optimized_spec = log_kappa_star.apply_tospec(resolvedspec, spatial_terms)?;
2989 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
2990 data,
2991 y,
2992 weights,
2993 offset,
2994 &optimized_spec,
2995 rho_star.as_slice(),
2996 family.clone(),
2997 options,
2998 )?;
2999
3000 let optimized_edf = optimized.fit.inference.as_ref().map(|inf| inf.edf_total);
3014 if let Some(opt_edf) = optimized_edf
3015 && opt_edf < SPATIAL_COLLAPSE_EDF_FLOOR
3016 {
3017 let baseline = fit_frozen_baseline_geometry(
3018 data,
3019 y,
3020 weights,
3021 offset,
3022 baseline_spec,
3023 best,
3024 family.clone(),
3025 options,
3026 baseline_score,
3027 Some(kappa_timing),
3028 )?;
3029 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
3030 if let Some(base_edf) = baseline_edf
3031 && base_edf >= opt_edf + SPATIAL_COLLAPSE_EDF_MARGIN
3032 {
3033 log::info!(
3034 "[spatial-kappa] joint candidate collapsed to the null (edf={opt_edf:.3}); \
3035 baseline geometry retains edf={base_edf:.3} — keeping the frozen baseline",
3036 );
3037 return Ok(Some(baseline));
3038 }
3039 }
3042
3043 let mut fit = optimized.fit;
3047 fit.reml_score = joint_final_value;
3048 let optimized_result = FittedTermCollectionWithSpec {
3049 fit,
3050 design: optimized.design,
3051 resolvedspec: optimized_spec,
3052 adaptive_diagnostics: optimized.adaptive_diagnostics,
3053 kappa_timing: Some(kappa_timing),
3054 };
3055
3056 Ok(Some(optimized_result))
3057}
3058
/// Total EDF below which a spatial refit is treated as having collapsed
/// toward the null model. Both the joint-candidate collapse guard and the
/// warm-start check in `fit_frozen_baseline_geometry` compare `edf_total`
/// against this floor.
const SPATIAL_COLLAPSE_EDF_FLOOR: f64 = 2.5;

/// Minimum EDF advantage the comparison fit must hold over a collapsed fit
/// (`other_edf >= collapsed_edf + MARGIN`) before the collapse guard swaps
/// fits.
const SPATIAL_COLLAPSE_EDF_MARGIN: f64 = 1.0;
3069
3070fn fit_frozen_baseline_geometry(
3106 data: ArrayView2<'_, f64>,
3107 y: ArrayView1<'_, f64>,
3108 weights: ArrayView1<'_, f64>,
3109 offset: ArrayView1<'_, f64>,
3110 resolvedspec: &TermCollectionSpec,
3111 best: &FittedTermCollection,
3112 family: LikelihoodSpec,
3113 options: &FitOptions,
3114 baseline_score: f64,
3115 kappa_timing: Option<SpatialLengthScaleOptimizationTiming>,
3116) -> Result<FittedTermCollectionWithSpec, EstimationError> {
3117 let baseline = fit_term_collection_forspecwith_heuristic_lambdas(
3118 data,
3119 y,
3120 weights,
3121 offset,
3122 resolvedspec,
3123 best.fit.lambdas.as_slice(),
3124 family.clone(),
3125 options,
3126 )?;
3127 let best_edf = best.fit.inference.as_ref().map(|inf| inf.edf_total);
3132 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
3133 let baseline = match (best_edf, baseline_edf) {
3134 (Some(best_edf), Some(base_edf))
3135 if base_edf < SPATIAL_COLLAPSE_EDF_FLOOR
3136 && best_edf >= base_edf + SPATIAL_COLLAPSE_EDF_MARGIN =>
3137 {
3138 log::info!(
3139 "[spatial-kappa] warm-started frozen baseline collapsed (edf={base_edf:.3}) \
3140 below the certified baseline (edf={best_edf:.3}); refitting from scratch",
3141 );
3142 fit_term_collection_forspec(data, y, weights, offset, resolvedspec, family, options)?
3143 }
3144 _ => baseline,
3145 };
3146 let mut fit = baseline.fit;
3147 fit.reml_score = baseline_score;
3148 Ok(FittedTermCollectionWithSpec {
3149 fit,
3150 design: baseline.design,
3151 resolvedspec: resolvedspec.clone(),
3152 adaptive_diagnostics: baseline.adaptive_diagnostics,
3153 kappa_timing,
3154 })
3155}
3156
/// Flavour of spatial hyperparameter driven by the exact joint optimizer.
/// `try_exact_joint_spatial_length_scale_optimization` picks `Anisotropic`
/// when `has_aniso_terms` reports anisotropic spatial terms, and `Isotropic`
/// otherwise.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum SpatialHyperKind {
    Anisotropic,
    Isotropic,
}

impl SpatialHyperKind {
    /// All display spellings for this kind as `(label, adjective, coord_name)`.
    /// Keeping them in one table stops the three accessors from drifting apart.
    const fn names(self) -> (&'static str, &'static str, &'static str) {
        match self {
            Self::Anisotropic => ("spatial-aniso-joint", "anisotropic", "psi"),
            Self::Isotropic => ("spatial-iso-joint", "isotropic", "kappa"),
        }
    }

    /// Short tag used in `[STAGE]`/`[label]` log prefixes and evaluator labels.
    fn label(self) -> &'static str {
        self.names().0
    }

    /// Human-readable adjective for error messages.
    fn adjective(self) -> &'static str {
        self.names().1
    }

    /// Name of the spatial coordinate this kind optimizes.
    fn coord_name(self) -> &'static str {
        self.names().2
    }
}
3201
/// Owned copies of the response-side inputs needed to recompute the IRLS
/// working state (η, Fisher weights, working response) for the frozen-W GLM
/// ψ-Gram tensor.
struct SpatialFrozenGlmInputs {
    /// Observed response.
    y: Array1<f64>,
    /// Prior observation weights.
    weights: Array1<f64>,
    /// Linear-predictor offset added to Xβ.
    offset: Array1<f64>,
    /// Likelihood used to evaluate Fisher weights and scores.
    family: LikelihoodSpec,
}
3213
3214fn frozen_glm_tensor_eligible_family(family: &LikelihoodSpec) -> bool {
3231 !family.is_gaussian_identity()
3232 && matches!(
3233 &family.response,
3234 ResponseFamily::Binomial
3235 | ResponseFamily::Poisson
3236 | ResponseFamily::Gamma
3237 | ResponseFamily::NegativeBinomial { .. }
3238 )
3239}
3240
/// Mutable state threaded through the exact joint (ρ, spatial-coordinate)
/// optimization. It holds:
/// * the design cache that re-realizes the spatial basis for a given θ and
///   memoizes evaluations;
/// * the joint hyperparameter evaluator;
/// * the optional frozen-W GLM ψ-Gram tensor machinery.
struct SpatialJointContext<'d> {
    /// Covariate matrix; used when rebuilding spatial hyper-directions.
    data: ArrayView2<'d, f64>,
    /// Number of leading ρ (log-λ) entries in θ; spatial coordinates follow.
    rho_dim: usize,
    /// Isotropic vs anisotropic coordinates; used for labels and hyper-dirs.
    kind: SpatialHyperKind,
    /// Design realized at the current θ plus memoized costs/evaluations.
    cache: SingleBlockExactJointDesignCache<'d>,
    /// Evaluates cost / value-gradient(-Hessian) / EFS on cache designs.
    evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
    /// Inputs for recomputing IRLS working state; `None` disables the tensor.
    frozen_glm_inputs: Option<SpatialFrozenGlmInputs>,
    /// ψ interval over which to certify the tensor; `None` disables it.
    frozen_glm_psi_bounds: Option<(f64, f64)>,
    /// The certified frozen-W GLM ψ-Gram tensor, once built.
    frozen_glm_tensor: Option<gam_solve::glm_sufficient_lane::FrozenWeightGramTensor>,
    /// Set once a build attempt has finished or been ruled out; prevents retries.
    frozen_glm_tensor_attempted: bool,
    /// Single-entry memo `(β, W(β))` used by `frozen_glm_trial_weights`.
    frozen_glm_weight_memo: Option<(Array1<f64>, Array1<f64>)>,
}
3264
/// Gate-by-gate record of whether a joint evaluation may skip re-realizing the
/// n×k design and instead be served n-free from the evaluator's ψ-gram tensor.
#[derive(Clone, Copy, Debug, Default)]
struct NfreeSkipGateStatus {
    /// θ has exactly `rho_dim + 1` entries (a single scalar ψ).
    shape: bool,
    /// The ψ-gram tensor covers ψ for values and for skipping.
    value: bool,
    /// Gradient coverage holds, or no gradient was required.
    gradient: bool,
    /// The evaluator can rekey the penalty n-free.
    penalty: bool,
    /// An n-free fast-path revision is available.
    revision: bool,
    /// Second-order evaluation was requested; this alone blocks skipping.
    second_order: bool,
}

impl NfreeSkipGateStatus {
    /// `true` when every gate admits the n-free path.
    ///
    /// The `gradient` gate is consulted only when `require_gradient` is set.
    /// A requested second-order evaluation always vetoes the skip.
    fn would_skip(self, require_gradient: bool) -> bool {
        let Self {
            shape,
            value,
            gradient,
            penalty,
            revision,
            second_order,
        } = self;
        let gradient_satisfied = gradient || !require_gradient;
        !second_order && shape && value && penalty && revision && gradient_satisfied
    }
}
3285
/// Evaluation entry points used by the exact joint spatial optimizer: full
/// (value/gradient/Hessian) evaluations, EFS evaluations, cost-only probes,
/// and the frozen-W GLM ψ-Gram tensor staging that accelerates them.
impl<'d> SpatialJointContext<'d> {
    /// Break down, gate by gate, whether an evaluation at `theta` could skip
    /// the n×k design re-realization and be served n-free from the evaluator's
    /// ψ-gram tensor.
    ///
    /// `NfreeSkipGateStatus::would_skip(true)` on the result with the same
    /// `allow_second_order` is the same conjunction that `eval_full` computes
    /// inline as `skip_design_realization`.
    fn nfree_skip_gate_status(
        &self,
        theta: &Array1<f64>,
        allow_second_order: bool,
        require_gradient: bool,
    ) -> NfreeSkipGateStatus {
        // The n-free path only exists for a single scalar ψ after the ρ block.
        let shape = theta.len() == self.rho_dim + 1;
        let (value, gradient) = if shape {
            let psi = theta[self.rho_dim];
            (
                self.evaluator.psi_gram_tensor_covers(psi)
                    && self.evaluator.psi_gram_tensor_covers_skip(psi),
                !require_gradient || self.evaluator.psi_gram_tensor_covers_gradient(psi),
            )
        } else {
            (false, false)
        };
        NfreeSkipGateStatus {
            shape,
            value,
            gradient,
            penalty: self.evaluator.supports_nfree_penalty_rekey(),
            revision: self.evaluator.nfree_fast_path_revision().is_some(),
            second_order: allow_second_order,
        }
    }

    /// Compute the IRLS working state at coefficients `beta` on the cache's
    /// *current* design.
    ///
    /// Steps: η = Xβ + offset; W is the family's Fisher weights at η; the
    /// working response is z_i = η_i + score_i / max(W_i, 1e-12), where η is
    /// the value reported back by the family evaluation.
    ///
    /// Returns `Ok(None)` when no frozen-GLM inputs are configured, or when
    /// `beta` does not match the design's column count. Otherwise returns
    /// `Ok(Some((W, z)))`.
    ///
    /// # Errors
    /// Fails when the design's row count disagrees with the stored offset, or
    /// when the family observation evaluation fails.
    fn frozen_glm_working_state(
        &self,
        beta: &Array1<f64>,
    ) -> Result<Option<(Array1<f64>, Array1<f64>)>, EstimationError> {
        let Some(inputs) = self.frozen_glm_inputs.as_ref() else {
            return Ok(None);
        };
        if beta.len() != self.cache.design().design.ncols() {
            return Ok(None);
        }
        let mut eta = self.cache.design().design.matrixvectormultiply(beta);
        if eta.len() != inputs.offset.len() {
            crate::bail_invalid_estim!(
                "frozen GLM tensor warm-state row mismatch: eta={}, offset={}",
                eta.len(),
                inputs.offset.len()
            );
        }
        eta += &inputs.offset;
        let obs = evaluate_standard_familyobservations(
            inputs.family.clone(),
            None,
            None,
            None,
            &inputs.y,
            &inputs.weights,
            &eta,
        )?;
        let mut working_response = obs.eta.clone();
        for i in 0..working_response.len() {
            // Floor W so near-zero Fisher weights cannot blow up z.
            let wi = obs.fisherweight[i].max(1e-12);
            working_response[i] += obs.score[i] / wi;
        }
        Ok(Some((obs.fisherweight, working_response)))
    }

    /// Fisher weights W at `beta`, memoized on the exact bit pattern of `beta`
    /// in a single-entry memo.
    ///
    /// Returns `Ok(None)` whenever `frozen_glm_working_state` does; nothing is
    /// memoized in that case.
    ///
    /// NOTE(review): the memo key is β alone, but W is computed from the
    /// cache's current design, which changes with θ. A memo hit after the
    /// design moved to another ψ returns W from the earlier design. Confirm
    /// that this approximation is intended for the weight-drift gate.
    fn frozen_glm_trial_weights(
        &mut self,
        beta: &Array1<f64>,
    ) -> Result<Option<Array1<f64>>, EstimationError> {
        if let Some((memo_beta, memo_w)) = self.frozen_glm_weight_memo.as_ref()
            && memo_beta.len() == beta.len()
            && memo_beta
                .iter()
                .zip(beta.iter())
                .all(|(a, b)| a.to_bits() == b.to_bits())
        {
            return Ok(Some(memo_w.clone()));
        }
        match self.frozen_glm_working_state(beta)? {
            Some((current_w, _)) => {
                self.frozen_glm_weight_memo = Some((beta.clone(), current_w.clone()));
                Ok(Some(current_w))
            }
            None => Ok(None),
        }
    }

    /// Attempt, at most once per context, to build the frozen-W GLM ψ-Gram
    /// tensor over `frozen_glm_psi_bounds`. The tensor is built from the IRLS
    /// working state at `warm_beta` on the current design.
    ///
    /// Behaviour by case:
    /// * Does nothing if a tensor already exists, an attempt was recorded, or
    ///   no ψ bounds are configured.
    /// * A θ that is not `rho_dim + 1` long, or a `None` working state, marks
    ///   the attempt as done without building.
    /// * A missing `warm_beta` returns *without* marking, so a later call can
    ///   retry once a warm β exists.
    ///
    /// The builder is given a closure that re-realizes the design at probe ψ
    /// values through the cache. The cache is moved back to `theta` afterwards.
    ///
    /// # Errors
    /// Propagates errors from `frozen_glm_working_state` and from restoring the
    /// cache to `theta`. NOTE(review): if that restore fails, the freshly
    /// built tensor is dropped and `frozen_glm_tensor_attempted` stays false,
    /// so the next call rebuilds it.
    fn ensure_frozen_glm_tensor(
        &mut self,
        theta: &Array1<f64>,
        warm_beta: Option<&Array1<f64>>,
    ) -> Result<(), EstimationError> {
        if self.frozen_glm_tensor.is_some() || self.frozen_glm_tensor_attempted {
            return Ok(());
        }
        let Some((psi_lo, psi_hi)) = self.frozen_glm_psi_bounds else {
            return Ok(());
        };
        if theta.len() != self.rho_dim + 1 {
            self.frozen_glm_tensor_attempted = true;
            return Ok(());
        }
        let Some(beta) = warm_beta else {
            return Ok(());
        };
        let Some((frozen_w, working_z)) = self.frozen_glm_working_state(beta)? else {
            self.frozen_glm_tensor_attempted = true;
            return Ok(());
        };
        let theta_probe_base = theta.clone();
        let rho_dim = self.rho_dim;
        // Split the borrow: the probe closure needs `cache` mutably while
        // `evaluator` drives the tensor build.
        let Self {
            cache, evaluator, ..
        } = self;
        let tensor = evaluator.build_frozen_glm_gram_tensor(
            |psi| {
                // Same θ, with only the scalar ψ slot replaced by the probe value.
                let mut theta_probe = theta_probe_base.clone();
                theta_probe[rho_dim] = psi;
                cache.ensure_theta(&theta_probe)?;
                Ok(cache.design().design.clone())
            },
            frozen_w.view(),
            working_z.view(),
            psi_lo,
            psi_hi,
        );
        // The probe closure may have left the cache at a probe ψ; restore θ.
        self.cache
            .ensure_theta(theta)
            .map_err(EstimationError::InvalidInput)?;
        self.frozen_glm_tensor_attempted = true;
        if let Some(tensor) = tensor {
            self.frozen_glm_tensor = Some(tensor);
            log::info!(
                "[STAGE] {} certified frozen-W GLM ψ tensor over [{psi_lo:.3}, {psi_hi:.3}]",
                self.kind.label(),
            );
        } else {
            log::info!(
                "[STAGE] {} frozen-W GLM ψ tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]",
                self.kind.label(),
            );
        }
        Ok(())
    }

    /// Stage, for the trial at `theta`, the frozen-W GLM first-Fisher-step
    /// Gram XᵀWX and the ψ-gradient pair (∂G/∂ψ, ∂b/∂ψ) on the evaluator.
    ///
    /// Both stages are overwritten on every call. A stage is `Some` only when
    /// all of the following hold:
    /// * θ carries a single ψ and the tensor contains ψ;
    /// * `warm_beta` is present and yields weights;
    /// * the stage's own check passes. For the Gram this is relative weight
    ///   drift within 1e-3. For the gradient pair it is `allow_gradient`,
    ///   `contains_for_gradient`, and `gradient_pair_if_sound`.
    ///
    /// Otherwise `None` is staged; `eval_cost`'s fallback messages describe
    /// that as the exact streamed Gram path.
    ///
    /// # Errors
    /// Propagates errors from `frozen_glm_trial_weights`.
    fn stage_frozen_glm_trial_statistics(
        &mut self,
        theta: &Array1<f64>,
        warm_beta: Option<&Array1<f64>>,
        allow_gradient: bool,
    ) -> Result<(), EstimationError> {
        let kind = self.kind;
        let mut staged_gram: Option<Array2<f64>> = None;
        let mut staged_deriv: Option<(Array2<f64>, Array1<f64>)> = None;
        if theta.len() == self.rho_dim + 1 {
            let psi = theta[self.rho_dim];
            let tensor_covers = self
                .frozen_glm_tensor
                .as_ref()
                .is_some_and(|t| t.contains(psi));
            // Only evaluate the trial weights when the tensor could serve ψ.
            let current_w = if tensor_covers {
                match warm_beta {
                    Some(beta) => self.frozen_glm_trial_weights(beta)?,
                    None => None,
                }
            } else {
                None
            };
            if let (Some(tensor), Some(current_w)) =
                (self.frozen_glm_tensor.as_ref(), current_w.as_ref())
            {
                const FROZEN_GLM_WEIGHT_DRIFT_RTOL: f64 = 1e-3;
                if tensor.weight_drift_within(current_w.view(), FROZEN_GLM_WEIGHT_DRIFT_RTOL) {
                    staged_gram = Some(tensor.gram_at(psi));
                    log::debug!(
                        "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
                         first-Fisher-step XᵀWX n-free (weight drift within tol)",
                        kind.label(),
                    );
                }
                if allow_gradient
                    && tensor.contains_for_gradient(psi)
                    && let Some((dgram_dpsi, drhs_dpsi)) =
                        tensor.gradient_pair_if_sound(psi, current_w.view())
                {
                    staged_deriv = Some((dgram_dpsi, drhs_dpsi));
                    log::debug!(
                        "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
                         ψ-gradient (∂G/∂ψ, ∂b/∂ψ) n-free (gradient weight drift within \
                         tight tol); B_j stays exact",
                        kind.label(),
                    );
                }
            }
        }
        self.evaluator.stage_glm_first_step_gram(staged_gram);
        self.evaluator.stage_glm_psi_gram_deriv(staged_deriv);
        Ok(())
    }

    /// Value, gradient and (when permitted) Hessian of the joint REML
    /// criterion at `theta`, memoized per θ in the design cache.
    ///
    /// Second order is used only when `order` is `ValueGradientHessian` *and*
    /// `analytic_outer_hessian_available`. Otherwise the evaluation runs as
    /// `ValueAndGradient`. A memoized result is reused unless second order is
    /// required and the cached Hessian is not analytic.
    ///
    /// The n×k design is not re-realized when all of these hold:
    /// * second order is not needed;
    /// * θ is a single-ψ point with value, gradient and skip coverage in the
    ///   evaluator's ψ-gram tensor;
    /// * the evaluator supports n-free penalty rekeying;
    /// * a fast-path revision exists.
    ///
    /// In that case hyper-directions come from the tensor, and the fast-path
    /// revision is passed instead of the cache's design revision.
    ///
    /// # Errors
    /// Propagates errors from cache realization, frozen-GLM tensor setup and
    /// staging, hyper-direction construction, and the evaluation itself.
    /// NOTE(review): `eval_cost` treats frozen-GLM setup/staging failures as
    /// best-effort (it logs and clears the stages); here they abort the
    /// evaluation. Confirm this asymmetry is intended.
    fn eval_full(
        &mut self,
        theta: &Array1<f64>,
        order: gam_solve::rho_optimizer::OuterEvalOrder,
        analytic_outer_hessian_available: bool,
    ) -> Result<
        (
            f64,
            Array1<f64>,
            gam_problem::HessianResult,
        ),
        EstimationError,
    > {
        use gam_solve::rho_optimizer::OuterEvalOrder;
        let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
            && analytic_outer_hessian_available;
        if let Some(eval) = self.cache.memoized_eval(theta) {
            // A cached first-order-only result cannot satisfy a second-order request.
            let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
            if cached_satisfies_order {
                return Ok(eval);
            }
        }
        let kind = self.kind;
        let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
        let skip_design_realization = !allow_second_order && theta.len() == self.rho_dim + 1 && {
            let psi = theta[self.rho_dim];
            self.evaluator.psi_gram_tensor_covers(psi)
                && self.evaluator.psi_gram_tensor_covers_gradient(psi)
                && self.evaluator.psi_gram_tensor_covers_skip(psi)
                && self.evaluator.supports_nfree_penalty_rekey()
                && nfree_fast_path_revision.is_some()
        };
        if skip_design_realization {
            // The cache may still hold an earlier θ's design in this case.
            // NOTE(review): correctness relies on the evaluator serving from
            // the tensor when given the n-free revision.
            log::debug!(
                "[STAGE] {} eval_full at psi={:.6}: skipping n×k design re-realization \
                 + reconditioning — criterion/gradient/inner-solve served n-free from \
                 the certified ψ-gram tensor (GaussianFixedCache + k-space ψ-derivatives)",
                kind.label(),
                theta[self.rho_dim],
            );
        } else {
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
        }
        let warm_beta = self.evaluator.current_beta();
        self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref())?;
        // The frozen-W gradient pair is only staged for first-order evaluations.
        self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), !allow_second_order)?;
        let hyper_dirs = if skip_design_realization {
            self.cache.nfree_tensor_gradient_hyper_dirs(theta)?
        } else {
            self.cache.hyper_dirs_for_current_design(self.data, kind)?
        };

        let design_revision = if skip_design_realization {
            nfree_fast_path_revision
        } else {
            Some(self.cache.design_revision())
        };
        if self.evaluator.supports_nfree_penalty_rekey() {
            // Stage the exact S(ψ); on failure clear the stage so the evaluator
            // falls back to its slow path.
            match self.cache.canonical_penalties_at(theta) {
                Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
                Err(e) => {
                    log::warn!(
                        "[STAGE] {} eval_full at psi={:.6}: exact n-free S(ψ) rebuild failed \
                         ({e}); clearing stage (eval falls to slow path)",
                        kind.label(),
                        theta[self.rho_dim],
                    );
                    self.evaluator.stage_fast_path_penalty(None);
                }
            }
        }
        let eval = evaluate_joint_reml_outer_eval_at_theta(
            &mut self.evaluator,
            self.cache.design(),
            theta,
            self.rho_dim,
            hyper_dirs,
            warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
            if allow_second_order {
                order
            } else {
                OuterEvalOrder::ValueAndGradient
            },
            design_revision,
        );
        // Only successful evaluations are memoized.
        if let Ok(ref value) = eval {
            self.cache.store_eval_at(theta, value.clone());
        }
        eval
    }

    /// EFS evaluation of the joint REML criterion at `theta`.
    ///
    /// Always realizes the design at θ and rebuilds the spatial
    /// hyper-directions from the cache's spec and design before evaluating.
    ///
    /// # Errors
    /// Fails if realization fails, if hyper-direction construction errors or
    /// returns nothing, or if the EFS evaluation fails.
    fn eval_efs(
        &mut self,
        theta: &Array1<f64>,
    ) -> Result<gam_problem::EfsEval, EstimationError> {
        self.cache
            .ensure_theta(theta)
            .map_err(EstimationError::InvalidInput)?;
        let kind = self.kind;
        let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
            self.data,
            self.cache.spec(),
            self.cache.design(),
            &self.cache.spatial_terms,
        )?
        .ok_or_else(|| {
            EstimationError::InvalidInput(format!(
                "failed to build {} hyper_dirs for exact-joint EFS",
                kind.adjective(),
            ))
        })?;
        let design_revision = Some(self.cache.design_revision());
        let warm_beta = self.evaluator.current_beta();
        evaluate_joint_reml_efs_at_theta(
            &mut self.evaluator,
            self.cache.design(),
            theta,
            self.rho_dim,
            hyper_dirs,
            warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
            design_revision,
        )
    }

    /// Cost-only (value) probe of the joint criterion at `theta`. This method
    /// never returns an error.
    ///
    /// Outcomes:
    /// * A memoized cost is returned directly.
    /// * `f64::INFINITY` is returned and memoized when θ is a single-ψ point
    ///   outside an existing ψ-gram tensor's coverage.
    /// * `f64::INFINITY` is returned without memoizing when design realization
    ///   or the cost evaluation fails.
    /// * Successful costs are memoized.
    ///
    /// Failures in frozen-W GLM setup or staging are logged, and both stages
    /// are cleared so the probe falls back to the exact streamed Gram.
    fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
        if let Some(cost) = self.cache.memoized_cost(theta) {
            return cost;
        }
        let probe_start = std::time::Instant::now();
        // Euclidean distance from the cache's current θ (the whole vector, not
        // just ψ); NaN when there is no comparable current θ. Diagnostics only.
        let psi_distance = self
            .cache
            .current_theta
            .as_ref()
            .filter(|reference| reference.len() == theta.len())
            .map(|reference| {
                reference
                    .iter()
                    .zip(theta.iter())
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum::<f64>()
                    .sqrt()
            })
            .unwrap_or(f64::NAN);
        let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
        // Same idea as `eval_full`'s skip predicate, minus gradient coverage.
        let skip_value_realization = theta.len() == self.rho_dim + 1 && {
            let psi = theta[self.rho_dim];
            self.evaluator.psi_gram_tensor_covers(psi)
                && self.evaluator.psi_gram_tensor_covers_skip(psi)
                && self.evaluator.supports_nfree_penalty_rekey()
                && nfree_fast_path_revision.is_some()
        };
        // Once a ψ-gram tensor exists, ψ outside its coverage is treated as infeasible.
        if theta.len() == self.rho_dim + 1
            && self.evaluator.has_psi_gram_tensor()
            && !self.evaluator.psi_gram_tensor_covers(theta[self.rho_dim])
        {
            self.cache.store_cost_at(theta, f64::INFINITY);
            return f64::INFINITY;
        }
        if !skip_value_realization && self.cache.ensure_theta(theta).is_err() {
            return f64::INFINITY;
        }
        if self.evaluator.supports_nfree_penalty_rekey() {
            match self.cache.canonical_penalties_at(theta) {
                Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
                Err(_) => self.evaluator.stage_fast_path_penalty(None),
            }
        }
        let warm_beta = self.evaluator.current_beta();
        if let Err(err) = self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref()) {
            log::warn!(
                "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM tensor setup failed ({err}); \
                 falling back to exact streamed Gram",
                self.kind.label(),
                if theta.len() > self.rho_dim {
                    theta[self.rho_dim]
                } else {
                    f64::NAN
                },
            );
            self.evaluator.stage_glm_first_step_gram(None);
            self.evaluator.stage_glm_psi_gram_deriv(None);
        } else if let Err(err) =
            self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), false)
        {
            log::warn!(
                "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM staging failed ({err}); \
                 falling back to exact streamed Gram",
                self.kind.label(),
                if theta.len() > self.rho_dim {
                    theta[self.rho_dim]
                } else {
                    f64::NAN
                },
            );
            self.evaluator.stage_glm_first_step_gram(None);
            self.evaluator.stage_glm_psi_gram_deriv(None);
        }
        let design_revision = if skip_value_realization {
            nfree_fast_path_revision
        } else {
            Some(self.cache.design_revision())
        };
        let cost_label = self.kind.label();
        let result = {
            let design = self.cache.design();
            self.evaluator.evaluate_cost_only(
                &design.design,
                &design.penalties,
                &design.nullspace_dims,
                design.linear_constraints.clone(),
                theta,
                self.rho_dim,
                warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
                cost_label,
                design_revision,
            )
        };
        match result {
            Ok(cost) => {
                log::debug!(
                    "[STAGE] {cost_label} value-probe (order=Value): elapsed={:.3}s \
                     cost={cost:.6e} trial_theta_distance={psi_distance:.3e}",
                    probe_start.elapsed().as_secs_f64(),
                );
                self.cache.store_cost_at(theta, cost);
                cost
            }
            Err(_) => f64::INFINITY,
        }
    }

    /// Clear the cache's current θ and its memoized evaluation/cost entries.
    /// The frozen-GLM tensor, its attempt flag, and the weight memo are left
    /// untouched.
    fn reset(&mut self) {
        self.cache.current_theta = None;
        self.cache.last_eval_theta = None;
        self.cache.last_cost = None;
        self.cache.last_eval = None;
    }
}
3899
/// Outcome of the exact joint spatial optimization. It is returned by
/// `run_exact_joint_spatial_optimization` paired with its timing record.
enum SpatialJointOutcome {
    /// The joint solve converged.
    Optimized {
        /// Optimal θ: the first `rho_dim` entries are ρ (log λ), followed by the
        /// spatial coordinate tail.
        theta_star: Array1<f64>,
        /// Joint objective value at `theta_star`.
        final_value: f64,
    },
    /// The joint solve stopped without converging.
    /// `try_exact_joint_spatial_length_scale_optimization` then keeps the
    /// frozen baseline geometry.
    NonConverged {
        /// Iterations performed before stopping.
        iterations: usize,
        /// Last objective value reached.
        final_value: f64,
        /// Last gradient norm, when available (logged as "n/a" otherwise).
        final_grad_norm: Option<f64>,
    },
}
3949
3950fn kphase_log_norms(theta: &Array1<f64>, rho_dim: usize) -> (f64, f64) {
3951 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
3952 let log_kappa_norm = theta
3953 .iter()
3954 .skip(rho_dim)
3955 .map(|v| v * v)
3956 .sum::<f64>()
3957 .sqrt();
3958 (theta_norm, log_kappa_norm)
3959}
3960
3961fn run_exact_joint_spatial_optimization(
3962 kind: SpatialHyperKind,
3963 data: ArrayView2<'_, f64>,
3964 y: ArrayView1<'_, f64>,
3965 weights: ArrayView1<'_, f64>,
3966 offset: ArrayView1<'_, f64>,
3967 resolvedspec: &TermCollectionSpec,
3968 baseline_design: &TermCollectionDesign,
3969 family: LikelihoodSpec,
3970 options: &FitOptions,
3971 spatial_terms: &[usize],
3972 dims_per_term: &[usize],
3973 theta0: &Array1<f64>,
3974 lower: &Array1<f64>,
3975 upper: &Array1<f64>,
3976 rho_dim: usize,
3977 kappa_options: &SpatialLengthScaleOptimizationOptions,
3978) -> Result<(SpatialJointOutcome, SpatialLengthScaleOptimizationTiming), EstimationError> {
3979 let label = kind.label();
3980 assert!(
3982 lower.len() == theta0.len() && upper.len() == theta0.len(),
3983 "spatial hyperparameter bounds must match theta length: lower_len={}, upper_len={}, theta_len={}",
3984 lower.len(),
3985 upper.len(),
3986 theta0.len()
3987 );
3988 assert!(
3989 baseline_design.smooth.terms.len() >= spatial_terms.len(),
3990 "baseline design must have at least one smooth term per spatial term: baseline_terms={}, spatial_terms={}",
3991 baseline_design.smooth.terms.len(),
3992 spatial_terms.len()
3993 );
3994 use gam_solve::rho_optimizer::OuterEvalOrder;
3995 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
3996
3997 let theta_dim = theta0.len();
3998 let coord_dim = theta_dim - rho_dim;
4001 let analytic_outer_hessian_available =
4011 exact_joint_spatial_outer_hessian_available(&family, baseline_design);
4012 if !analytic_outer_hessian_available {
4013 log::info!(
4014 "[{label}] analytic outer Hessian unavailable for family/design; routing without second-order geometry (coord_dim={coord_dim})"
4015 );
4016 }
4017 let mut prefer_gradient_only = theta_dim > EXACT_JOINT_SECOND_ORDER_THETA_CAP;
4023 if prefer_gradient_only {
4024 log::info!(
4025 "[{label}] joint θ-dim {theta_dim} exceeds the exact pair-Hessian budget \
4026 ({EXACT_JOINT_SECOND_ORDER_THETA_CAP}); routing gradient-only quasi-Newton"
4027 );
4028 }
4029 let mut suppress_outer_hessian_for_nfree = false;
4039
4040 log::trace!(
4041 "[{}] starting analytic optimization: rho_dim={}, coord_dim={}, dims_per_term={:?}",
4042 label,
4043 rho_dim,
4044 coord_dim,
4045 dims_per_term,
4046 );
4047
4048 let mut ctx = SpatialJointContext {
4049 data,
4050 rho_dim,
4051 kind,
4052 cache: SingleBlockExactJointDesignCache::new(
4053 data,
4054 resolvedspec.clone(),
4055 baseline_design.clone(),
4056 spatial_terms.to_vec(),
4057 rho_dim,
4058 dims_per_term.to_vec(),
4059 )
4060 .map_err(EstimationError::InvalidInput)?,
4061 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
4062 y,
4063 weights,
4064 &baseline_design.design,
4065 offset,
4066 &baseline_design.penalties,
4067 &external_opts_for_design(&family, baseline_design, options),
4068 label,
4069 )?,
4070 frozen_glm_inputs: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
4071 Some(SpatialFrozenGlmInputs {
4072 y: y.to_owned(),
4073 weights: weights.to_owned(),
4074 offset: offset.to_owned(),
4075 family: family.clone(),
4076 })
4077 } else {
4078 None
4079 },
4080 frozen_glm_psi_bounds: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
4081 Some((lower[rho_dim], upper[rho_dim]))
4082 } else {
4083 None
4084 },
4085 frozen_glm_tensor: None,
4086 frozen_glm_tensor_attempted: false,
4087 frozen_glm_weight_memo: None,
4088 };
4089
4090 let mut psi_rank_stable_floor: Option<f64> = None;
4113 let mut psi_rank_stable_ceiling: Option<f64> = None;
4122 let nfree_penalty_capable = coord_dim == 1
4123 && family.is_gaussian_identity()
4124 && ctx.cache.supports_nfree_penalty_rekey();
4125 if nfree_penalty_capable {
4126 let psi_lo = lower[rho_dim];
4127 let psi_hi = upper[rho_dim];
4128 let z = Array1::from_iter(y.iter().zip(offset.iter()).map(|(yi, oi)| yi - oi));
4129 let theta_probe_base = theta0.clone();
4130 let SpatialJointContext {
4133 cache, evaluator, ..
4134 } = &mut ctx;
4135 let attached = evaluator.build_and_set_psi_gram_tensor(
4136 |psi| {
4137 let mut theta_probe = theta_probe_base.clone();
4138 theta_probe[rho_dim] = psi;
4139 cache.ensure_theta(&theta_probe)?;
4140 Ok(cache.design().design.clone())
4141 },
4142 weights,
4143 z.view(),
4144 psi_lo,
4145 psi_hi,
4146 );
4147 if attached {
4148 log::info!(
4149 "[{label}] certified ψ-gram tensor over [{psi_lo:.3}, {psi_hi:.3}]: \
4150 in-window trials assemble Gaussian sufficient statistics n-free"
4151 );
4152 let psi_anchor = theta0[rho_dim];
4157 psi_rank_stable_floor = evaluator
4158 .psi_gram_rank_stable_floor(psi_anchor)
4159 .filter(|&f| f.is_finite() && f > psi_lo && f < psi_anchor);
4160 log::info!(
4161 "[KAPPA-PHASE-FLOOR] n_rows={} psi_lo={psi_lo:.6} psi_anchor={psi_anchor:.6} \
4162 rank_stable_floor={:?} lifted={}",
4163 data.nrows(),
4164 evaluator.psi_gram_rank_stable_floor(psi_anchor),
4165 psi_rank_stable_floor.is_some(),
4166 );
4167 if let Some(floor) = psi_rank_stable_floor {
4168 log::info!(
4169 "[{label}] rank-stable κ-floor ψ_floor={floor:.6} > window floor \
4170 ψ_lo={psi_lo:.6}: lifting the optimizer lower bound to keep every \
4171 in-window trial on the n-free design-realization skip (#1033). The \
4172 conditioned Gram is rank-deficient below ψ_floor (longest-length-scale \
4173 radial mode collapses into the nullspace), where the skip is soundly \
4174 refused; that band drifts with n via the sample-std standardization, \
4175 so this n-free k-space floor is the n-independent fix."
4176 );
4177 }
4178 psi_rank_stable_ceiling = evaluator
4187 .psi_gram_rank_stable_ceiling(psi_anchor)
4188 .filter(|&c| c.is_finite() && c < psi_hi && c > psi_anchor);
4189 log::info!(
4190 "[KAPPA-PHASE-CEIL] n_rows={} psi_hi={psi_hi:.6} psi_anchor={psi_anchor:.6} \
4191 rank_stable_ceiling={:?} clamped={}",
4192 data.nrows(),
4193 evaluator.psi_gram_rank_stable_ceiling(psi_anchor),
4194 psi_rank_stable_ceiling.is_some(),
4195 );
4196 if let Some(ceiling) = psi_rank_stable_ceiling {
4197 log::info!(
4198 "[{label}] rank-stable κ-ceiling ψ_ceil={ceiling:.6} < window ceiling \
4199 ψ_hi={psi_hi:.6}: clamping the optimizer upper bound to keep every \
4200 in-window trial on the n-free design-realization skip (#1033). The \
4201 conditioned Gram is rank-deficient above ψ_ceil (longest-frequency \
4202 radial mode goes collinear), where the skip is soundly refused; a \
4203 line-search overshoot there trips the O(n) reset_surface lane (and the \
4204 deficient pinning ψ it records resets the next in-band trial too)."
4205 );
4206 }
4207 let gradient_covers_full_window = evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4208 && evaluator.psi_gram_tensor_covers_gradient(psi_hi);
4209 if gradient_covers_full_window {
4210 log::info!(
4211 "[{label}] certified ψ-gram tensor gradient lane covers the full \
4212 optimizer window [{psi_lo:.3}, {psi_hi:.3}]"
4213 );
4214 } else {
4215 log::info!(
4216 "[{label}] ψ-gram tensor value lane certified, but the gradient lane \
4217 does not cover the full optimizer window [{psi_lo:.3}, {psi_hi:.3}]; \
4218 keeping exact streamed kappa routing"
4219 );
4220 }
4221 evaluator.set_supports_nfree_penalty_rekey(true);
4241 log::info!(
4242 "[{label}] exact n-free ψ-penalty re-key enabled over [{psi_lo:.3}, \
4243 {psi_hi:.3}]: in-window fast-path trials rebuild S(ψ) n-free from frozen \
4244 geometry (no reset_surface)"
4245 );
4246 } else {
4247 log::info!(
4248 "[{label}] ψ-gram tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]; \
4249 keeping the exact per-trial path"
4250 );
4251 }
4252 if attached
4273 && evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4274 && evaluator.psi_gram_tensor_covers_gradient(psi_hi)
4275 && evaluator.supports_nfree_penalty_rekey()
4276 && cache.supports_nfree_gradient_only_routing()
4277 {
4278 suppress_outer_hessian_for_nfree = true;
4279 prefer_gradient_only = true;
4280 log::info!(
4281 "[{label}] n-free Gaussian ψ-lane armed; suppressing the analytic outer \
4282 Hessian and routing gradient-only (BFGS) so the κ outer loop never realizes \
4283 the O(n) second-order slab — n-independent outer loop (#1033)"
4284 );
4285 }
4286 } else if coord_dim == 1 && family.is_gaussian_identity() {
4287 log::info!(
4288 "[{label}] exact n-free ψ-penalty re-key unavailable; skipping ψ-gram tensor \
4289 attachment so value, gradient, and Hessian remain on the same exact streamed \
4290 objective"
4291 );
4292 }
4293
4294 const OUTER_FD_AUDIT_MAX_N: usize = 4_000; const OUTER_FD_AUDIT_MAX_THETA_DIM: usize = 32; let n_total = data.nrows();
4322 let outer_fd_audit_eligible = log::log_enabled!(log::Level::Info) && analytic_outer_hessian_available && n_total <= OUTER_FD_AUDIT_MAX_N && theta_dim <= OUTER_FD_AUDIT_MAX_THETA_DIM; log::info!(
4327 "[OUTER-FD-AUDIT/spatial-exact-joint] gate eligible={outer_fd_audit_eligible} \
4328 analytic_grad={analytic_outer_hessian_available} n_total={n_total} \
4329 theta_dim={theta_dim} rho_dim={rho_dim} psi_dim={coord_dim}"
4330 );
4331 if outer_fd_audit_eligible {
4332 let audit = (|| -> Result<gam_solve::rho_optimizer::OuterGradientFdAudit, String> {
4334 let mut eval_at = |theta: &Array1<f64>,
4335 mode: gam_solve::estimate::reml::reml_outer_engine::EvalMode|
4336 -> Result<
4337 (
4338 f64,
4339 Array1<f64>,
4340 gam_problem::HessianResult,
4341 ),
4342 String,
4343 > {
4344 use gam_solve::estimate::reml::reml_outer_engine::EvalMode;
4345 let order = if matches!(mode, EvalMode::ValueGradientHessian) {
4346 OuterEvalOrder::ValueGradientHessian
4347 } else {
4348 OuterEvalOrder::Value
4349 };
4350 ctx.eval_full(theta, order, analytic_outer_hessian_available)
4351 .map_err(|e| format!("fd-audit eval_full: {e}"))
4352 };
4353 let rho_dim_audit = rho_dim;
4354 let label_fn = move |i: usize| -> String {
4355 if i < rho_dim_audit {
4356 format!("rho[{i}]")
4357 } else {
4358 format!("psi_kappa[{}]", i - rho_dim_audit)
4359 }
4360 };
4361 gam_solve::rho_optimizer::outer_gradient_fd_audit(
4362 theta0,
4364 1e-4,
4365 label_fn,
4366 &mut eval_at,
4367 )
4368 })();
4369 match audit {
4371 Ok(audit) => audit.log_verdict("spatial-exact-joint"),
4372 Err(e) => log::warn!("[OUTER-FD-AUDIT/spatial-exact-joint] skipped: {e}"),
4373 }
4374 }
4375
4376 let kphase_prime_order = if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4377 OuterEvalOrder::ValueGradientHessian
4378 } else {
4379 OuterEvalOrder::ValueAndGradient
4380 };
4381 let kphase_prime_start = std::time::Instant::now();
4382 drop(ctx.eval_full(theta0, kphase_prime_order, analytic_outer_hessian_available)?);
4383 log::info!(
4384 "[KAPPA-PHASE-PRIME] n_rows={} order={:?} elapsed_s={:.4} slow_path_resets_total={} design_revision={}",
4385 data.nrows(),
4386 kphase_prime_order,
4387 kphase_prime_start.elapsed().as_secs_f64(),
4388 ctx.evaluator.slow_path_reset_count(),
4389 ctx.cache.design_revision(),
4390 );
4391
4392 let kphase_cost_calls = std::cell::Cell::new(0usize);
4393 let kphase_eval_calls = std::cell::Cell::new(0usize);
4394 let kphase_efs_calls = std::cell::Cell::new(0usize);
4395 let kphase_cost_total_s = std::cell::Cell::new(0.0);
4396 let kphase_eval_total_s = std::cell::Cell::new(0.0);
4397 let kphase_efs_total_s = std::cell::Cell::new(0.0);
4398 let kphase_nfree_miss_shape = std::cell::Cell::new(0u64);
4399 let kphase_nfree_miss_value = std::cell::Cell::new(0u64);
4400 let kphase_nfree_miss_gradient = std::cell::Cell::new(0u64);
4401 let kphase_nfree_miss_penalty = std::cell::Cell::new(0u64);
4402 let kphase_nfree_miss_revision = std::cell::Cell::new(0u64);
4403 let kphase_nfree_miss_second_order = std::cell::Cell::new(0u64);
4404 let kphase_nfree_miss_other = std::cell::Cell::new(0u64);
4405 let kphase_optim_start = std::time::Instant::now();
4406 let kphase_log_kappa_dim = coord_dim;
4407 let kphase_slow_resets_start = ctx.evaluator.slow_path_reset_count();
4408 let kphase_design_revision_start = ctx.cache.design_revision();
4409
4410 let lower_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_floor {
4417 Some(floor) if coord_dim == 1 && floor > lower[rho_dim] => {
4418 let mut lifted = lower.clone();
4419 lifted[rho_dim] = floor;
4420 std::borrow::Cow::Owned(lifted)
4421 }
4422 _ => std::borrow::Cow::Borrowed(lower),
4423 };
4424 let lower = lower_effective.as_ref();
4425
4426 let upper_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_ceiling {
4434 Some(ceiling) if coord_dim == 1 && ceiling < upper[rho_dim] => {
4435 let mut clamped = upper.clone();
4436 clamped[rho_dim] = ceiling;
4437 std::borrow::Cow::Owned(clamped)
4438 }
4439 _ => std::borrow::Cow::Borrowed(upper),
4440 };
4441 let upper = upper_effective.as_ref();
4442
4443 let problem = exact_joint_multistart_outer_problem(
4444 theta0,
4445 lower,
4446 upper,
4447 rho_dim,
4448 coord_dim,
4449 theta_dim,
4450 Derivative::Analytic,
4451 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4452 DeclaredHessianForm::Either
4453 } else {
4454 DeclaredHessianForm::Unavailable
4459 },
4460 prefer_gradient_only,
4461 suppress_outer_hessian_for_nfree,
4472 seed_risk_profile_for_likelihood_family(&family),
4473 kappa_options.rel_tol.max(1e-6),
4474 kappa_options.max_outer_iter.max(1),
4475 Some(5.0),
4479 Some(kappa_options.log_step.clamp(0.25, 1.0)),
4481 None,
4482 Some((data.nrows(), baseline_design.design.ncols())),
4487 !constant_curvature_term_indices(resolvedspec).is_empty(),
4491 );
4492
4493 let eval_outer = |ctx: &mut &mut SpatialJointContext<'_>,
4494 theta: &Array1<f64>,
4495 order: OuterEvalOrder|
4496 -> Result<OuterEval, EstimationError> {
4497 let t0 = std::time::Instant::now();
4498 let allow_second_order_for_call = matches!(order, OuterEvalOrder::ValueGradientHessian)
4499 && analytic_outer_hessian_available;
4500 let gate = ctx.nfree_skip_gate_status(theta, allow_second_order_for_call, true);
4501 let resets_before = ctx.evaluator.slow_path_reset_count();
4502 let raw = ctx.eval_full(theta, order, analytic_outer_hessian_available);
4503 let reset_delta = ctx
4504 .evaluator
4505 .slow_path_reset_count()
4506 .saturating_sub(resets_before);
4507 if reset_delta > 0 {
4508 if !gate.shape {
4509 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4510 }
4511 if gate.shape && !gate.value {
4512 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4513 }
4514 if gate.shape && gate.value && !gate.gradient {
4515 kphase_nfree_miss_gradient.set(kphase_nfree_miss_gradient.get() + reset_delta);
4516 }
4517 if gate.shape && gate.value && gate.gradient && !gate.penalty {
4518 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4519 }
4520 if gate.shape && gate.value && gate.gradient && gate.penalty && !gate.revision {
4521 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4522 }
4523 if gate.shape
4524 && gate.value
4525 && gate.gradient
4526 && gate.penalty
4527 && gate.revision
4528 && gate.second_order
4529 {
4530 kphase_nfree_miss_second_order
4531 .set(kphase_nfree_miss_second_order.get() + reset_delta);
4532 }
4533 if gate.would_skip(true) {
4534 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4535 }
4536 }
4537 let elapsed_s = t0.elapsed().as_secs_f64();
4538 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
4539 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
4540 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4541 log::info!(
4542 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4543 kphase_eval_calls.get(),
4544 order,
4545 Some(ctx.cache.design_revision()),
4546 theta_norm,
4547 log_kappa_norm,
4548 elapsed_s,
4549 );
4550 match raw {
4551 Ok((cost, grad, hess)) => Ok(OuterEval {
4552 cost,
4553 gradient: grad,
4554 hessian: hess,
4555 inner_beta_hint: None,
4556 }),
4557 Err(err) if is_recoverable_trial_point_error(&err) => {
4565 log::debug!(
4566 "[{label}] trial point infeasible (kernel design \
4567 not constructible at theta={theta:?}): {err}; retreating",
4568 );
4569 Ok(OuterEval::infeasible(theta_dim))
4570 }
4571 Err(err) => Err(err),
4572 }
4573 };
4574
4575 let mut obj = problem.build_objective_with_eval_order(
4576 &mut ctx,
4577 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4578 let t0 = std::time::Instant::now();
4579 let gate = ctx.nfree_skip_gate_status(theta, false, false);
4580 let resets_before = ctx.evaluator.slow_path_reset_count();
4581 let cost = ctx.eval_cost(theta);
4582 let reset_delta = ctx
4583 .evaluator
4584 .slow_path_reset_count()
4585 .saturating_sub(resets_before);
4586 if reset_delta > 0 {
4587 if !gate.shape {
4588 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4589 }
4590 if gate.shape && !gate.value {
4591 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4592 }
4593 if gate.shape && gate.value && !gate.penalty {
4594 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4595 }
4596 if gate.shape && gate.value && gate.penalty && !gate.revision {
4597 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4598 }
4599 if gate.would_skip(false) {
4600 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4601 }
4602 }
4603 let elapsed_s = t0.elapsed().as_secs_f64();
4604 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
4605 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
4606 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4607 log::info!(
4608 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4609 kphase_cost_calls.get(),
4610 Some(ctx.cache.design_revision()),
4611 theta_norm,
4612 log_kappa_norm,
4613 elapsed_s,
4614 );
4615 Ok(cost)
4616 },
4617 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4618 eval_outer(
4619 ctx,
4620 theta,
4621 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4631 OuterEvalOrder::ValueGradientHessian
4632 } else {
4633 OuterEvalOrder::ValueAndGradient
4634 },
4635 )
4636 },
4637 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
4638 eval_outer(ctx, theta, order)
4639 },
4640 Some(|ctx: &mut &mut SpatialJointContext<'_>| {
4641 ctx.reset();
4642 }),
4643 Some(|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4644 let t0 = std::time::Instant::now();
4645 let eval = ctx.eval_efs(theta);
4646 let elapsed_s = t0.elapsed().as_secs_f64();
4647 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
4648 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
4649 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4650 log::info!(
4651 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4652 kphase_efs_calls.get(),
4653 Some(ctx.cache.design_revision()),
4654 theta_norm,
4655 log_kappa_norm,
4656 elapsed_s,
4657 );
4658 eval
4659 }),
4660 );
4661
4662 let run_label = match kind {
4663 SpatialHyperKind::Anisotropic => "aniso-psi joint REML",
4664 SpatialHyperKind::Isotropic => "iso-kappa joint REML",
4665 };
4666 let result = problem.run(&mut obj, run_label).map_err(|e| {
4667 EstimationError::InvalidInput(format!(
4668 "{} analytic optimization failed after exhausting strategy fallbacks: {e}",
4669 kind.adjective(),
4670 ))
4671 })?;
4672 drop(obj);
4673 let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
4674 let kphase_slow_resets = ctx
4675 .evaluator
4676 .slow_path_reset_count()
4677 .saturating_sub(kphase_slow_resets_start);
4678 let kphase_design_revision_delta = ctx
4679 .cache
4680 .design_revision()
4681 .saturating_sub(kphase_design_revision_start);
4682 log::info!(
4683 "[KAPPA-PHASE-SUMMARY] n_rows={} log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} slow_path_resets={} design_revision_delta={} nfree_miss_shape={} nfree_miss_value={} nfree_miss_gradient={} nfree_miss_penalty={} nfree_miss_revision={} nfree_miss_second_order={} nfree_miss_other={} optim_total_s={:.4}",
4684 data.nrows(),
4685 kphase_log_kappa_dim,
4686 kphase_cost_calls.get(),
4687 kphase_cost_total_s.get(),
4688 kphase_eval_calls.get(),
4689 kphase_eval_total_s.get(),
4690 kphase_efs_calls.get(),
4691 kphase_efs_total_s.get(),
4692 kphase_slow_resets,
4693 kphase_design_revision_delta,
4694 kphase_nfree_miss_shape.get(),
4695 kphase_nfree_miss_value.get(),
4696 kphase_nfree_miss_gradient.get(),
4697 kphase_nfree_miss_penalty.get(),
4698 kphase_nfree_miss_revision.get(),
4699 kphase_nfree_miss_second_order.get(),
4700 kphase_nfree_miss_other.get(),
4701 kphase_total_s,
4702 );
4703 let timing = SpatialLengthScaleOptimizationTiming {
4704 log_kappa_dim: kphase_log_kappa_dim,
4705 cost_calls: kphase_cost_calls.get(),
4706 cost_total_s: kphase_cost_total_s.get(),
4707 eval_calls: kphase_eval_calls.get(),
4708 eval_total_s: kphase_eval_total_s.get(),
4709 efs_calls: kphase_efs_calls.get(),
4710 efs_total_s: kphase_efs_total_s.get(),
4711 slow_path_resets: kphase_slow_resets,
4712 design_revision_delta: kphase_design_revision_delta,
4713 nfree_miss_shape: kphase_nfree_miss_shape.get(),
4714 nfree_miss_value: kphase_nfree_miss_value.get(),
4715 nfree_miss_gradient: kphase_nfree_miss_gradient.get(),
4716 nfree_miss_penalty: kphase_nfree_miss_penalty.get(),
4717 nfree_miss_revision: kphase_nfree_miss_revision.get(),
4718 nfree_miss_second_order: kphase_nfree_miss_second_order.get(),
4719 nfree_miss_other: kphase_nfree_miss_other.get(),
4720 optim_total_s: kphase_total_s,
4721 };
4722 if !result.converged {
4723 let rel_to_cost_threshold = options.tol * (1.0_f64 + result.final_value.abs());
4734 if let Some(final_grad) = result
4735 .final_grad_norm
4736 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
4737 {
4738 log::info!(
4739 "[{}] outer optimization hit max_iter={} but \
4740 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
4741 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
4742 relative-to-cost REML convergence criterion.",
4743 label,
4744 result.iterations,
4745 final_grad,
4746 rel_to_cost_threshold,
4747 options.tol,
4748 result.final_value.abs(),
4749 );
4750 } else if result.final_value.is_finite() {
4751 log::warn!(
4766 "[{}] {} did not converge after {} iterations \
4767 (final_objective={:.6e}, final_grad_norm={}); keeping the \
4768 frozen baseline geometry instead of aborting the fit.",
4769 label,
4770 kind.adjective(),
4771 result.iterations,
4772 result.final_value,
4773 result.final_grad_norm_report(),
4774 );
4775 return Ok((
4776 SpatialJointOutcome::NonConverged {
4777 iterations: result.iterations,
4778 final_value: result.final_value,
4779 final_grad_norm: result.final_grad_norm,
4780 },
4781 timing,
4782 ));
4783 } else {
4784 crate::bail_invalid_estim!(
4789 "{} analytic optimization diverged after {} iterations (final_objective={:.6e}, final_grad_norm={})",
4790 kind.adjective(),
4791 result.iterations,
4792 result.final_value,
4793 result.final_grad_norm_report(),
4794 );
4795 }
4796 }
4797 log::trace!(
4798 "[{}] converged in {} iterations, final_value={:.6e}, grad_norm={}",
4799 label,
4800 result.iterations,
4801 result.final_value,
4802 result.final_grad_norm_report(),
4803 );
4804 let theta_star = result.rho;
4808 Ok((
4809 SpatialJointOutcome::Optimized {
4810 theta_star,
4811 final_value: result.final_value,
4812 },
4813 timing,
4814 ))
4815}
4816
4817fn set_single_term_spatial_length_scale(
4821 term: &mut SmoothTermSpec,
4822 length_scale: f64,
4823) -> Result<(), EstimationError> {
4824 match &mut term.basis {
4825 SmoothBasisSpec::ThinPlate { spec, .. } => {
4826 spec.length_scale = length_scale;
4827 Ok(())
4828 }
4829 SmoothBasisSpec::Matern { spec, .. } => {
4830 spec.length_scale = length_scale;
4831 Ok(())
4832 }
4833 SmoothBasisSpec::Duchon { spec, .. } => {
4834 spec.length_scale = Some(length_scale);
4835 Ok(())
4836 }
4837 _ => Err(EstimationError::InvalidInput(format!(
4838 "term '{}' does not expose a spatial length scale",
4839 term.name
4840 ))),
4841 }
4842}
4843
4844fn set_single_term_spatial_aniso_log_scales(
4848 term: &mut SmoothTermSpec,
4849 eta: Vec<f64>,
4850) -> Result<(), EstimationError> {
4851 let eta = center_aniso_log_scales(&eta);
4852 match &mut term.basis {
4853 SmoothBasisSpec::Matern { spec, .. } => {
4854 spec.aniso_log_scales = Some(eta);
4855 Ok(())
4856 }
4857 SmoothBasisSpec::Duchon { spec, .. } => {
4858 spec.aniso_log_scales = Some(eta);
4859 Ok(())
4860 }
4861 _ => Err(EstimationError::InvalidInput(format!(
4862 "term '{}' does not support aniso_log_scales",
4863 term.name
4864 ))),
4865 }
4866}
4867
4868pub fn get_constant_curvature_kappa(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
4887 constant_curvature_term_spec(spec, term_idx).map(|cc| cc.kappa)
4888}
4889
4890pub fn constant_curvature_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
4892 (0..spec.smooth_terms.len())
4893 .filter(|&idx| constant_curvature_term_spec(spec, idx).is_some())
4894 .collect()
4895}
4896
4897
/// One smooth term realized on its own: its design columns, the assembled
/// `SmoothTerm`, and the penalty blocks that did not stay active.
// NOTE(review): when built via `wrap_local_build_as_realization`,
// `term.coeff_range` is `0..dim` (term-local); presumably the same holds for
// the single-term `build_smooth_design` path. Confirm before relying on it.
#[derive(Debug, Clone)]
struct SingleSmoothTermRealization {
    /// Design matrix for this term's coefficients only.
    design_local: DesignMatrix,
    /// Term metadata, local penalties, bounds, and constraints.
    term: SmoothTerm,
    /// Penalty blocks excluded from this term's active penalty set.
    dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
}
4904
4905impl SingleSmoothTermRealization {
4906 fn active_penaltyinfo(&self) -> Vec<PenaltyInfo> {
4907 self.term
4908 .penaltyinfo_local
4909 .iter()
4910 .filter(|info| info.active)
4911 .cloned()
4912 .collect()
4913 }
4914}
4915
4916fn build_single_smooth_term_realization(
4917 data: ArrayView2<'_, f64>,
4918 termspec: &SmoothTermSpec,
4919) -> Result<SingleSmoothTermRealization, BasisError> {
4920 let raw = build_smooth_design(data, std::slice::from_ref(termspec))?;
4921 finish_single_smooth_term_realization(raw)
4922}
4923
4924fn finish_single_smooth_term_realization(
4925 raw: RawSmoothDesign,
4926) -> Result<SingleSmoothTermRealization, BasisError> {
4927 let RawSmoothDesign {
4928 term_designs,
4929 dropped_penaltyinfo,
4930 terms,
4931 ..
4932 } = raw;
4933 let term = terms.into_iter().next().ok_or_else(|| {
4934 BasisError::InvalidInput("single-term smooth build returned no term".to_string())
4935 })?;
4936 let design = term_designs.into_iter().next().ok_or_else(|| {
4937 BasisError::InvalidInput("single-term smooth build returned no term design".to_string())
4938 })?;
4939
4940 Ok(SingleSmoothTermRealization {
4941 design_local: design,
4942 term,
4943 dropped_penaltyinfo,
4944 })
4945}
4946
4947fn wrap_local_build_as_realization(
4954 mut local: LocalSmoothTermBuild,
4955 termspec: &SmoothTermSpec,
4956) -> Result<SingleSmoothTermRealization, String> {
4957 let p_local = local.dim;
4958 let lb_local = if local.box_reparam {
4959 shape_lower_bounds_local(termspec.shape, p_local)
4960 } else {
4961 None
4962 };
4963
4964 let active_count = local.penaltyinfo.iter().filter(|info| info.active).count();
4965 if active_count != local.penalties.len() {
4966 return Err(format!(
4967 "internal penalty info mismatch for term '{}': active_infos={}, penalties={}",
4968 termspec.name,
4969 active_count,
4970 local.penalties.len()
4971 ));
4972 }
4973
4974 let mut dropped_penaltyinfo = Vec::<DroppedPenaltyBlockInfo>::new();
4975 for info in local.penaltyinfo.iter().filter(|info| !info.active) {
4976 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
4977 termname: Some(termspec.name.clone()),
4978 penalty: info.clone(),
4979 });
4980 }
4981 for info in &local.pre_dropped_penaltyinfo {
4982 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
4983 termname: Some(termspec.name.clone()),
4984 penalty: info.clone(),
4985 });
4986 }
4987
4988 let applied_rotation: Option<gam_terms::basis::JointNullRotation> = match (
4992 local.joint_null_rotation.take(),
4993 lb_local.is_some(),
4994 local.linear_constraints.is_some(),
4995 ) {
4996 (Some(rot), false, false) => {
4997 let q = &rot.rotation;
4998 let dense = local
4999 .design
5000 .try_to_dense_by_chunks("joint-null absorption rotation (single realization)")
5001 .map_err(|e| {
5002 format!(
5003 "joint-null absorption rotation: dense conversion failed for term '{}': {}",
5004 termspec.name, e
5005 )
5006 })?;
5007 let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
5008 local.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
5009 local.penalties = local
5010 .penalties
5011 .into_iter()
5012 .map(|s_local| {
5013 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
5014 gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
5015 })
5016 .collect();
5017 local.ops = vec![None; local.penalties.len()];
5018 local.kronecker_factored = None;
5019 Some(rot)
5020 }
5021 (Some(_), _, _) => None,
5022 (None, _, _) => None,
5023 };
5024
5025 let smooth_term = SmoothTerm {
5026 name: termspec.name.clone(),
5027 coeff_range: 0..p_local,
5028 shape: termspec.shape,
5029 penalties_local: local.penalties.clone(),
5030 nullspace_dims: local.nullspaces.clone(),
5031 penaltyinfo_local: local.penaltyinfo.clone(),
5032 metadata: local.metadata.clone(),
5033 lower_bounds_local: lb_local,
5034 linear_constraints_local: local.linear_constraints.clone(),
5035 kronecker_factored: local.kronecker_factored.take(),
5036 joint_null_rotation: applied_rotation,
5037 unabsorbed_global_orthogonality: None,
5040 };
5041
5042 Ok(SingleSmoothTermRealization {
5043 design_local: local.design,
5044 term: smooth_term,
5045 dropped_penaltyinfo,
5046 })
5047}
5048
5049fn freeze_geometry_from_metadata(
5060 termspec: &SmoothTermSpec,
5061 metadata: &BasisMetadata,
5062) -> Option<SmoothTermSpec> {
5063 let mut frozen = termspec.clone();
5064 match (&mut frozen.basis, metadata) {
5065 (
5066 SmoothBasisSpec::Matern {
5067 spec,
5068 input_scales: spec_scales,
5069 ..
5070 },
5071 BasisMetadata::Matern {
5072 centers,
5073 input_scales: meta_scales,
5074 identifiability_transform,
5075 nullspace_shrinkage_survived,
5076 ..
5077 },
5078 ) => {
5079 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5080 if spec_scales.is_none()
5081 && let Some(s) = meta_scales.clone()
5082 {
5083 *spec_scales = Some(s);
5084 }
5085 if let Some(transform) = identifiability_transform.clone() {
5103 spec.identifiability = MaternIdentifiability::FrozenTransform {
5104 transform,
5105 nullspace_shrinkage_survived: Some(*nullspace_shrinkage_survived),
5106 };
5107 }
5108 Some(frozen)
5109 }
5110 (
5111 SmoothBasisSpec::Duchon {
5112 spec,
5113 input_scales: spec_scales,
5114 ..
5115 },
5116 BasisMetadata::Duchon {
5117 centers,
5118 input_scales: meta_scales,
5119 ..
5120 },
5121 ) => {
5122 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5123 if spec_scales.is_none()
5124 && let Some(s) = meta_scales.clone()
5125 {
5126 *spec_scales = Some(s);
5127 }
5128 Some(frozen)
5129 }
5130 (
5131 SmoothBasisSpec::ThinPlate {
5132 spec,
5133 input_scales: spec_scales,
5134 ..
5135 },
5136 BasisMetadata::ThinPlate {
5137 centers,
5138 input_scales: meta_scales,
5139 ..
5140 },
5141 ) => {
5142 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5143 if spec_scales.is_none()
5144 && let Some(s) = meta_scales.clone()
5145 {
5146 *spec_scales = Some(s);
5147 }
5148 Some(frozen)
5149 }
5150 _ => None,
5153 }
5154}
5155
5156fn rebuild_smooth_auxiliary_state(
5157 smooth: &mut SmoothDesign,
5158 dropped_penaltyinfo_by_term: &[Vec<DroppedPenaltyBlockInfo>],
5159) -> Result<(), String> {
5160 if dropped_penaltyinfo_by_term.len() != smooth.terms.len() {
5161 return Err(SmoothError::dimension_mismatch(format!(
5162 "smooth dropped-penalty cache mismatch: terms={}, dropped_sets={}",
5163 smooth.terms.len(),
5164 dropped_penaltyinfo_by_term.len()
5165 ))
5166 .into());
5167 }
5168
5169 let total_p = smooth.total_smooth_cols();
5170 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
5171 let mut any_bounds = false;
5172 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5173 let mut linear_constraint_b: Vec<f64> = Vec::new();
5174
5175 for term in &smooth.terms {
5176 let range = term.coeff_range.clone();
5177 if let Some(lb_local) = term.lower_bounds_local.as_ref() {
5178 if lb_local.len() != range.len() {
5179 return Err(SmoothError::dimension_mismatch(format!(
5180 "smooth lower-bound cache mismatch for term '{}': bounds={}, coeffs={}",
5181 term.name,
5182 lb_local.len(),
5183 range.len()
5184 ))
5185 .into());
5186 }
5187 coefficient_lower_bounds
5188 .slice_mut(s![range.clone()])
5189 .assign(lb_local);
5190 any_bounds = true;
5191 }
5192 if let Some(lin_local) = term.linear_constraints_local.as_ref() {
5193 if lin_local.a.ncols() != range.len() {
5194 return Err(SmoothError::dimension_mismatch(format!(
5195 "smooth linear-constraint cache mismatch for term '{}': cols={}, coeffs={}",
5196 term.name,
5197 lin_local.a.ncols(),
5198 range.len()
5199 ))
5200 .into());
5201 }
5202 for r in 0..lin_local.a.nrows() {
5203 let mut row = Array1::<f64>::zeros(total_p);
5204 row.slice_mut(s![range.clone()]).assign(&lin_local.a.row(r));
5205 linear_constraintrows.push(row);
5206 linear_constraint_b.push(lin_local.b[r]);
5207 }
5208 }
5209 }
5210
5211 smooth.coefficient_lower_bounds = if any_bounds {
5212 Some(coefficient_lower_bounds)
5213 } else {
5214 None
5215 };
5216 smooth.linear_constraints = if linear_constraintrows.is_empty() {
5217 None
5218 } else {
5219 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), total_p));
5220 for (i, row) in linear_constraintrows.iter().enumerate() {
5221 a.row_mut(i).assign(row);
5222 }
5223 Some(LinearInequalityConstraints {
5224 a,
5225 b: Array1::from_vec(linear_constraint_b),
5226 })
5227 };
5228 smooth.dropped_penaltyinfo = dropped_penaltyinfo_by_term
5229 .iter()
5230 .flat_map(|infos| infos.iter().cloned())
5231 .collect();
5232 Ok(())
5233}
5234
5235fn rebuild_term_collection_auxiliary_state(
5236 spec: &TermCollectionSpec,
5237 design: &mut TermCollectionDesign,
5238) -> Result<(), String> {
5239 if spec.linear_terms.len() != design.linear_ranges.len() {
5240 return Err(SmoothError::dimension_mismatch(format!(
5241 "term-collection linear bookkeeping mismatch: spec_terms={}, design_ranges={}",
5242 spec.linear_terms.len(),
5243 design.linear_ranges.len()
5244 ))
5245 .into());
5246 }
5247
5248 let p_total = design.design.ncols();
5249 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
5250 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
5251 let mut any_bounds = false;
5252 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5253 let mut linear_constraint_b: Vec<f64> = Vec::new();
5254
5255 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
5256 if range.len() != 1 {
5257 return Err(SmoothError::dimension_mismatch(format!(
5258 "linear term '{}' expected one coefficient column, found {}",
5259 linear.name,
5260 range.len()
5261 ))
5262 .into());
5263 }
5264 let col = range.start;
5265 if let Some(lb) = linear.coefficient_min {
5266 let mut row = Array1::<f64>::zeros(p_total);
5267 row[col] = 1.0;
5268 linear_constraintrows.push(row);
5269 linear_constraint_b.push(lb);
5270 }
5271 if let Some(ub) = linear.coefficient_max {
5272 let mut row = Array1::<f64>::zeros(p_total);
5273 row[col] = -1.0;
5274 linear_constraintrows.push(row);
5275 linear_constraint_b.push(-ub);
5276 }
5277 }
5278
5279 if let Some(lb_smooth) = design.smooth.coefficient_lower_bounds.as_ref() {
5280 if lb_smooth.len() != design.smooth.total_smooth_cols() {
5281 return Err(SmoothError::dimension_mismatch(format!(
5282 "smooth lower-bound width mismatch: bounds={}, smooth_cols={}",
5283 lb_smooth.len(),
5284 design.smooth.total_smooth_cols()
5285 ))
5286 .into());
5287 }
5288 coefficient_lower_bounds
5289 .slice_mut(s![
5290 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5291 ])
5292 .assign(lb_smooth);
5293 any_bounds = true;
5294 }
5295 if let Some(lin_smooth) = design.smooth.linear_constraints.as_ref() {
5296 if lin_smooth.a.ncols() != design.smooth.total_smooth_cols() {
5297 return Err(SmoothError::dimension_mismatch(format!(
5298 "smooth linear-constraint width mismatch: cols={}, smooth_cols={}",
5299 lin_smooth.a.ncols(),
5300 design.smooth.total_smooth_cols()
5301 ))
5302 .into());
5303 }
5304 let mut a_global = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
5305 a_global
5306 .slice_mut(s![
5307 ..,
5308 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5309 ])
5310 .assign(&lin_smooth.a);
5311 for r in 0..a_global.nrows() {
5312 linear_constraintrows.push(a_global.row(r).to_owned());
5313 linear_constraint_b.push(lin_smooth.b[r]);
5314 }
5315 }
5316
5317 let lower_bound_constraints = if any_bounds {
5318 linear_constraints_from_lower_bounds_global(&coefficient_lower_bounds)
5319 } else {
5320 None
5321 };
5322 let explicit_linear_constraints = if linear_constraintrows.is_empty() {
5323 None
5324 } else {
5325 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), p_total));
5326 for (i, row) in linear_constraintrows.iter().enumerate() {
5327 a.row_mut(i).assign(row);
5328 }
5329 Some(LinearInequalityConstraints {
5330 a,
5331 b: Array1::from_vec(linear_constraint_b),
5332 })
5333 };
5334
5335 design.coefficient_lower_bounds = if any_bounds {
5336 Some(coefficient_lower_bounds)
5337 } else {
5338 None
5339 };
5340 design.linear_constraints =
5341 merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
5342 design.dropped_penaltyinfo = design.smooth.dropped_penaltyinfo.clone();
5343 Ok(())
5344}
5345
5346fn theta_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5347 left.len() == right.len()
5348 && left
5349 .iter()
5350 .zip(right.iter())
5351 .all(|(&l, &r)| l.to_bits() == r.to_bits())
5352}
5353
5354fn latent_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5355 theta_values_match(left, right)
5356}
5357
/// Bit-exact equality of two optional anisotropy log-scale vectors.
///
/// Two absent vectors match and present-vs-absent never matches. Two present
/// vectors match only if they have the same length and identical per-entry
/// bit patterns.
fn spatial_aniso_matches(left: Option<&[f64]>, right: Option<&[f64]>) -> bool {
    let (lhs, rhs) = match (left, right) {
        (Some(lhs), Some(rhs)) => (lhs, rhs),
        (None, None) => return true,
        _ => return false,
    };
    lhs.iter()
        .map(|v| v.to_bits())
        .eq(rhs.iter().map(|v| v.to_bits()))
}
5370
/// Bit-exact equality of two optional length scales.
///
/// Both absent means equal, exactly one absent means unequal, and two present
/// values must share the same IEEE-754 bit pattern.
fn spatial_length_scale_matches(left: Option<f64>, right: Option<f64>) -> bool {
    left.map(f64::to_bits) == right.map(f64::to_bits)
}
5378
/// Incrementally re-realizes a term-collection design when the spatial
/// hyperparameters (ψ / log κ) of individual smooth terms change.
///
/// Blocks that do not depend on spatial hyperparameters are cached once. Only
/// the smooth terms whose hyperparameters actually change are rebuilt. Their
/// design columns, penalties, and penalty metadata are then written back in
/// place into `design`.
struct FrozenTermCollectionIncrementalRealizer<'d> {
    /// Data matrix the terms are realized from.
    data: ArrayView2<'d, f64>,
    /// Live term specification; spatial dials are updated as ψ changes.
    spec: TermCollectionSpec,
    /// Realized design kept in sync with `spec`.
    design: TermCollectionDesign,
    /// Cached intercept / linear / random-effect blocks, built once in `new`.
    fixed_blocks: Vec<DesignBlock>,
    /// Per smooth term, the dropped-penalty info from its latest realization.
    dropped_penaltyinfo_by_term: Vec<Vec<DroppedPenaltyBlockInfo>>,
    /// Per smooth term, its index range into `design.smooth.penalties`.
    smooth_penalty_ranges: Vec<Range<usize>>,
    /// Per smooth term, its index range into `design.penalties`.
    full_penalty_ranges: Vec<Range<usize>>,
    /// Scratch workspace passed mutably to every basis/penalty rebuild.
    basisworkspace: gam_terms::basis::BasisWorkspace,
    /// Per smooth term, a spec with frozen realization geometry captured on the
    /// first rebuild; later rebuilds start from this spec instead of `spec`.
    spatial_realization_geometry: Vec<Option<SmoothTermSpec>>,
    /// Counter bumped (wrapping) whenever `apply_log_kappa` changes the design.
    design_revision: u64,
}
5411
5412impl<'d> std::fmt::Debug for FrozenTermCollectionIncrementalRealizer<'d> {
5413 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5414 f.debug_struct("FrozenTermCollectionIncrementalRealizer")
5415 .field("data_shape", &(self.data.nrows(), self.data.ncols()))
5416 .field("fixed_blocks", &self.fixed_blocks.len())
5417 .finish_non_exhaustive()
5418 }
5419}
5420
5421impl<'d> FrozenTermCollectionIncrementalRealizer<'d> {
5422 fn new(
5423 data: ArrayView2<'d, f64>,
5424 spec: TermCollectionSpec,
5425 design: TermCollectionDesign,
5426 ) -> Result<Self, String> {
5427 if spec.smooth_terms.len() != design.smooth.terms.len() {
5428 return Err(SmoothError::dimension_mismatch(format!(
5429 "incremental realizer smooth term mismatch: spec_terms={}, design_terms={}",
5430 spec.smooth_terms.len(),
5431 design.smooth.terms.len()
5432 ))
5433 .into());
5434 }
5435
5436 let mut smooth_cursor = 0usize;
5437 let mut smooth_penalty_ranges = Vec::with_capacity(design.smooth.terms.len());
5438 for term in &design.smooth.terms {
5439 let next = smooth_cursor + term.penalties_local.len();
5440 smooth_penalty_ranges.push(smooth_cursor..next);
5441 smooth_cursor = next;
5442 }
5443 if smooth_cursor != design.smooth.penalties.len() {
5444 return Err(SmoothError::dimension_mismatch(format!(
5445 "incremental realizer smooth penalty mismatch: ranged={}, actual={}",
5446 smooth_cursor,
5447 design.smooth.penalties.len()
5448 ))
5449 .into());
5450 }
5451
5452 let fixed_penalty_offset = design
5453 .penalties
5454 .len()
5455 .checked_sub(design.smooth.penalties.len())
5456 .ok_or_else(|| {
5457 "incremental realizer encountered invalid penalty bookkeeping".to_string()
5458 })?;
5459 let full_penalty_ranges = smooth_penalty_ranges
5460 .iter()
5461 .map(|range| (fixed_penalty_offset + range.start)..(fixed_penalty_offset + range.end))
5462 .collect::<Vec<_>>();
5463 let fixed_blocks = build_term_collection_fixed_blocks(data, &spec)
5464 .map_err(|e| format!("failed to cache fixed term-collection blocks: {e}"))?;
5465
5466 let mut dropped_penaltyinfo_by_term = Vec::with_capacity(spec.smooth_terms.len());
5467 for (term_idx, termspec) in spec.smooth_terms.iter().enumerate() {
5468 let realization =
5469 build_single_smooth_term_realization(data, termspec).map_err(|e| {
5470 format!(
5471 "failed to build cached realization for smooth term '{}' (index {}): {e}",
5472 termspec.name, term_idx
5473 )
5474 })?;
5475 let expected_cols = design.smooth.terms[term_idx].coeff_range.len();
5476 if realization.design_local.ncols() != expected_cols {
5477 return Err(SmoothError::dimension_mismatch(format!(
5478 "cached realization width mismatch for term '{}': cached_cols={}, design_cols={}",
5479 termspec.name,
5480 realization.design_local.ncols(),
5481 expected_cols
5482 ))
5483 .into());
5484 }
5485 if realization.active_penaltyinfo().len()
5486 != design.smooth.terms[term_idx].penalties_local.len()
5487 {
5488 return Err(SmoothError::dimension_mismatch(format!(
5489 "cached realization penalty mismatch for term '{}': cached_penalties={}, design_penalties={}",
5490 termspec.name,
5491 realization.active_penaltyinfo().len(),
5492 design.smooth.terms[term_idx].penalties_local.len()
5493 ))
5494 .into());
5495 }
5496 dropped_penaltyinfo_by_term.push(realization.dropped_penaltyinfo);
5497 }
5498
5499 let geometry_slots = spec.smooth_terms.len();
5500 Ok(Self {
5501 data,
5502 spec,
5503 design,
5504 fixed_blocks,
5505 dropped_penaltyinfo_by_term,
5506 smooth_penalty_ranges,
5507 full_penalty_ranges,
5508 basisworkspace: gam_terms::basis::BasisWorkspace::new(),
5509 spatial_realization_geometry: vec![None; geometry_slots],
5510 design_revision: 0,
5511 })
5512 }
5513
5514 fn design_revision(&self) -> u64 {
5515 self.design_revision
5516 }
5517
5518 fn spec(&self) -> &TermCollectionSpec {
5519 &self.spec
5520 }
5521
5522 fn design(&self) -> &TermCollectionDesign {
5523 &self.design
5524 }
5525
5526 fn supports_nfree_penalty_rekey(&self, spatial_terms: &[usize]) -> bool {
5567 if spatial_terms.len() != 1 {
5568 return false;
5569 }
5570 let term_idx = spatial_terms[0];
5571 matches!(
5572 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5573 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5574 )
5575 }
5576
5577 fn supports_nfree_gradient_only_routing(&self, spatial_terms: &[usize]) -> bool {
5586 if spatial_terms.len() != 1 {
5587 return false;
5588 }
5589 let term_idx = spatial_terms[0];
5590 matches!(
5591 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5592 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5593 )
5594 }
5595
5596 fn canonical_penalties_at_psi(
5609 &mut self,
5610 spatial_terms: &[usize],
5611 psi: &[f64],
5612 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
5613 if spatial_terms.len() != 1 {
5614 return Err(format!(
5615 "n-free penalty re-key requires exactly one spatial term, found {}",
5616 spatial_terms.len()
5617 ));
5618 }
5619 let term_idx = spatial_terms[0];
5620 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
5626 let termspec =
5629 self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
5630 format!("spatial term {term_idx} out of range for n-free penalty")
5631 })?;
5632 let term = self
5633 .design
5634 .smooth
5635 .terms
5636 .get(term_idx)
5637 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
5638 let p_total = self.design.design.ncols();
5641 let (locals, nullspace_dims): (Vec<Array2<f64>>, Vec<usize>) = match &term.metadata {
5642 BasisMetadata::Duchon {
5643 centers,
5644 identifiability_transform,
5645 operator_collocation_points,
5646 power,
5647 nullspace_order,
5648 aniso_log_scales,
5649 input_scales,
5650 radial_reparam,
5651 ..
5652 } => {
5653 let operator_penalties = match &termspec.basis {
5654 SmoothBasisSpec::Duchon { spec, .. } => spec.operator_penalties.clone(),
5655 _ => gam_terms::basis::DuchonOperatorPenaltySpec::default(),
5656 };
5657 let effective_ls = match input_scales.as_deref() {
5664 Some(scales) => {
5665 compensate_optional_length_scale_for_standardization(ls_opt, scales)
5666 }
5667 None => ls_opt,
5668 };
5669 gam_terms::basis::duchon_penalties_at_length_scale(
5670 centers.view(),
5671 identifiability_transform.as_ref(),
5672 operator_collocation_points.as_ref().map(|p| p.view()),
5673 &operator_penalties,
5674 *power,
5675 *nullspace_order,
5676 aniso_log_scales.as_deref(),
5677 radial_reparam.as_ref(),
5678 effective_ls,
5679 &mut self.basisworkspace,
5680 )
5681 .map_err(|e| e.to_string())?
5682 }
5683 BasisMetadata::Matern {
5684 centers,
5685 periodic,
5686 nu,
5687 include_intercept,
5688 identifiability_transform,
5689 aniso_log_scales,
5690 input_scales,
5691 ..
5692 } => {
5693 let ls = ls_opt.ok_or_else(|| {
5700 "Matérn n-free penalty re-key requires a finite length-scale".to_string()
5701 })?;
5702 let effective_ls = match input_scales.as_deref() {
5703 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
5704 None => ls,
5705 };
5706 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
5707 let (penalties, nullspace_dims, _info) =
5718 matern_operator_penalty_triplet_at_length_scale(
5719 centers.view(),
5720 periodic.as_deref(),
5721 identifiability_transform.as_ref(),
5722 *nu,
5723 *include_intercept,
5724 aniso_for_penalty,
5725 effective_ls,
5726 )
5727 .map_err(|e| e.to_string())?;
5728 (penalties, nullspace_dims)
5729 }
5730 BasisMetadata::ThinPlate {
5731 centers,
5732 identifiability_transform,
5733 radial_reparam,
5734 ..
5735 } => {
5736 let ls = ls_opt.ok_or_else(|| {
5737 "thin-plate n-free penalty re-key requires a finite length-scale".to_string()
5738 })?;
5739 let double_penalty = match &termspec.basis {
5740 SmoothBasisSpec::ThinPlate { spec, .. } => spec.double_penalty,
5741 _ => false,
5742 };
5743 gam_terms::basis::thin_plate_penalties_at_length_scale(
5744 centers.view(),
5745 identifiability_transform.as_ref(),
5746 radial_reparam.as_ref(),
5747 ls,
5748 double_penalty,
5749 &mut self.basisworkspace,
5750 )
5751 .map_err(|e| e.to_string())?
5752 }
5753 other => {
5754 return Err(format!(
5755 "n-free penalty re-key unsupported for basis metadata {:?}",
5756 std::mem::discriminant(other)
5757 ));
5758 }
5759 };
5760 let templates = &self.design.penalties;
5765 if templates.len() != locals.len() {
5766 return Err(format!(
5767 "n-free penalty re-key produced {} blocks but the frozen design carries {} \
5768 — penalty topology is not ψ-stable",
5769 locals.len(),
5770 templates.len()
5771 ));
5772 }
5773 let specs: Vec<gam_solve::estimate::PenaltySpec> = templates
5774 .iter()
5775 .zip(locals.into_iter())
5776 .map(|(tmpl, local)| gam_solve::estimate::PenaltySpec::Block {
5777 local,
5778 col_range: tmpl.col_range.clone(),
5779 prior_mean: tmpl.prior_mean.clone(),
5780 structure_hint: tmpl.structure_hint.clone(),
5781 op: tmpl.op.clone(),
5782 })
5783 .collect();
5784 gam_terms::construction::canonicalize_penalty_specs(
5785 &specs,
5786 &nullspace_dims,
5787 p_total,
5788 "nfree-psi-penalty",
5789 )
5790 .map_err(|e| e.to_string())
5791 }
5792
5793 fn canonical_penalty_derivatives_at_psi(
5794 &mut self,
5795 spatial_terms: &[usize],
5796 psi: &[f64],
5797 ) -> Result<(Range<usize>, usize, Vec<Array2<f64>>), String> {
5798 if spatial_terms.len() != 1 {
5799 return Err(format!(
5800 "n-free penalty derivative re-key requires exactly one spatial term, found {}",
5801 spatial_terms.len()
5802 ));
5803 }
5804 let term_idx = spatial_terms[0];
5805 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
5806 let termspec = self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
5807 format!("spatial term {term_idx} out of range for n-free penalty derivative")
5808 })?;
5809 let term = self
5810 .design
5811 .smooth
5812 .terms
5813 .get(term_idx)
5814 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
5815 let p_total = self.design.design.ncols();
5816 let smooth_start = p_total.saturating_sub(self.design.smooth.total_smooth_cols());
5817 let global_range =
5818 (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
5819
5820 let locals = match &term.metadata {
5821 BasisMetadata::Duchon {
5822 centers,
5823 identifiability_transform,
5824 operator_collocation_points,
5825 power,
5826 nullspace_order,
5827 aniso_log_scales,
5828 input_scales,
5829 radial_reparam,
5830 ..
5831 } => {
5832 let mut spec = match &termspec.basis {
5833 SmoothBasisSpec::Duchon { spec, .. } => spec.clone(),
5834 _ => {
5835 return Err(
5836 "Duchon n-free penalty derivative requires a Duchon term spec"
5837 .to_string(),
5838 );
5839 }
5840 };
5841 let effective_ls = match input_scales.as_deref() {
5842 Some(scales) => {
5843 compensate_optional_length_scale_for_standardization(ls_opt, scales)
5844 }
5845 None => ls_opt,
5846 };
5847 spec.length_scale = effective_ls;
5848 spec.power = *power;
5849 spec.nullspace_order = *nullspace_order;
5850 spec.aniso_log_scales = aniso_log_scales.clone();
5851 spec.radial_reparam = radial_reparam.clone();
5854 if spec.length_scale.is_none() {
5855 return Err(
5856 "Duchon n-free penalty derivative requires a hybrid length-scale"
5857 .to_string(),
5858 );
5859 }
5860 let collocation = operator_collocation_points
5861 .as_ref()
5862 .map(|points| points.view())
5863 .unwrap_or_else(|| centers.view());
5864 let (_native_sources, mut first, _native_second) =
5865 gam_terms::basis::build_duchon_native_penalty_psi_derivatives(
5866 centers.view(),
5867 &spec,
5868 identifiability_transform.as_ref(),
5869 &mut self.basisworkspace,
5870 )
5871 .map_err(|e| e.to_string())?;
5872 let (_operator_sources, operator_first, _operator_second) =
5873 gam_terms::basis::build_duchon_operator_penalty_psi_derivatives(
5874 collocation,
5875 centers.view(),
5876 &spec,
5877 identifiability_transform.as_ref(),
5878 &mut self.basisworkspace,
5879 )
5880 .map_err(|e| e.to_string())?;
5881 first.extend(operator_first);
5882 first
5883 }
5884 BasisMetadata::Matern {
5885 centers,
5886 periodic,
5887 nu,
5888 include_intercept,
5889 identifiability_transform,
5890 aniso_log_scales,
5891 input_scales,
5892 ..
5893 } => {
5894 let ls = ls_opt.ok_or_else(|| {
5895 "Matérn n-free penalty derivative requires a finite length-scale".to_string()
5896 })?;
5897 let effective_ls = match input_scales.as_deref() {
5898 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
5899 None => ls,
5900 };
5901 let penalty_centers =
5902 gam_terms::basis::expand_periodic_centers(¢ers.to_owned(), periodic.as_deref())
5903 .map_err(|e| e.to_string())?;
5904 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
5905 let (first, _second) = gam_terms::basis::build_matern_operator_penalty_psi_derivatives(
5906 penalty_centers.view(),
5907 effective_ls,
5908 *nu,
5909 *include_intercept,
5910 identifiability_transform.as_ref(),
5911 aniso_for_penalty,
5912 )
5913 .map_err(|e| e.to_string())?;
5914 first
5915 }
5916 BasisMetadata::ThinPlate {
5917 centers,
5918 identifiability_transform,
5919 radial_reparam,
5920 ..
5921 } => {
5922 let ls = ls_opt.ok_or_else(|| {
5923 "thin-plate n-free penalty derivative requires a finite length-scale"
5924 .to_string()
5925 })?;
5926 let mut spec = match &termspec.basis {
5927 SmoothBasisSpec::ThinPlate { spec, .. } => spec.clone(),
5928 _ => {
5929 return Err(
5930 "thin-plate n-free penalty derivative requires a ThinPlate term spec"
5931 .to_string(),
5932 );
5933 }
5934 };
5935 spec.length_scale = ls;
5936 if spec.radial_reparam.is_none() {
5937 spec.radial_reparam = radial_reparam.clone();
5938 }
5939 let (primary, _primary_second) =
5940 gam_terms::basis::build_thin_plate_penalty_psi_derivativeswithworkspace(
5941 centers.view(),
5942 &spec,
5943 identifiability_transform.as_ref(),
5944 &mut self.basisworkspace,
5945 )
5946 .map_err(|e| e.to_string())?;
5947 if self.design.penalties.len() > 1 {
5948 vec![primary.clone(), Array2::<f64>::zeros(primary.raw_dim())]
5949 } else {
5950 vec![primary]
5951 }
5952 }
5953 other => {
5954 return Err(format!(
5955 "n-free penalty derivative re-key unsupported for basis metadata {:?}",
5956 std::mem::discriminant(other)
5957 ));
5958 }
5959 };
5960 if locals.len() != self.design.penalties.len() {
5961 return Err(format!(
5962 "n-free penalty derivative re-key produced {} blocks but the frozen design carries {} \
5963 — penalty topology is not ψ-stable",
5964 locals.len(),
5965 self.design.penalties.len()
5966 ));
5967 }
5968 Ok((global_range, p_total, locals))
5969 }
5970
5971 fn apply_log_kappa(
5972 &mut self,
5973 log_kappa: &SpatialLogKappaCoords,
5974 term_indices: &[usize],
5975 ) -> Result<(), String> {
5976 if term_indices.len() != log_kappa.dims_per_term().len() {
5977 return Err(SmoothError::dimension_mismatch(format!(
5978 "incremental realizer log-kappa term mismatch: term_indices={}, dims_per_term={}",
5979 term_indices.len(),
5980 log_kappa.dims_per_term().len()
5981 ))
5982 .into());
5983 }
5984
5985 let mut any_changed = false;
5986 for (slot, &term_idx) in term_indices.iter().enumerate() {
5987 any_changed |= self.apply_log_kappa_to_term(term_idx, log_kappa.term_slice(slot))?;
5988 }
5989
5990 if any_changed {
5991 self.refresh_full_design_operator()?;
5992 rebuild_smooth_auxiliary_state(
5993 &mut self.design.smooth,
5994 &self.dropped_penaltyinfo_by_term,
5995 )?;
5996 rebuild_term_collection_auxiliary_state(&self.spec, &mut self.design)?;
5997 self.design_revision = self.design_revision.wrapping_add(1);
5998 }
5999 Ok(())
6000 }
6001
6002 fn apply_log_kappa_to_term(&mut self, term_idx: usize, psi: &[f64]) -> Result<bool, String> {
6003 if !spatial_term_supports_hyper_optimization(&self.spec, term_idx) {
6004 return Err(SmoothError::invalid_config(format!(
6005 "incremental realizer term {term_idx} does not expose spatial hyperparameters"
6006 ))
6007 .into());
6008 }
6009 let measure_jet_term = measure_jet_term_spec(&self.spec, term_idx).is_some();
6013 let constant_curvature_term = constant_curvature_term_spec(&self.spec, term_idx).is_some();
6017 let mut next_length_scale = None;
6018 let mut next_aniso: Option<Vec<f64>> = None;
6019 if measure_jet_term {
6020 if !set_measure_jet_psi_dials(&mut self.spec, term_idx, psi)
6021 .map_err(|e| e.to_string())?
6022 {
6023 return Ok(false);
6024 }
6025 } else if constant_curvature_term {
6026 if !set_constant_curvature_kappa(&mut self.spec, term_idx, psi)
6027 .map_err(|e| e.to_string())?
6028 {
6029 return Ok(false);
6030 }
6031 } else {
6032 let current_length_scale = get_spatial_length_scale(&self.spec, term_idx);
6033 let current_aniso = get_spatial_aniso_log_scales(&self.spec, term_idx);
6034 let (ls, eta) = spatial_term_psi_to_length_scale_and_aniso(psi);
6035 next_length_scale = ls;
6036 next_aniso = eta;
6037 let same_length = spatial_length_scale_matches(current_length_scale, next_length_scale);
6038 let same_aniso = spatial_aniso_matches(current_aniso.as_deref(), next_aniso.as_deref());
6039 if same_length && same_aniso {
6040 return Ok(false);
6041 }
6042 if let Some(length_scale) = next_length_scale {
6043 set_spatial_length_scale(&mut self.spec, term_idx, length_scale)
6044 .map_err(|e| e.to_string())?;
6045 }
6046 if let Some(eta) = next_aniso.clone() {
6047 set_spatial_aniso_log_scales(&mut self.spec, term_idx, eta)
6048 .map_err(|e| e.to_string())?;
6049 }
6050 }
6051
6052 let geometry_slot = self
6063 .spatial_realization_geometry
6064 .get(term_idx)
6065 .ok_or_else(|| format!("incremental realizer geometry slot {term_idx} out of range"))?;
6066 let mut build_spec = match geometry_slot {
6067 Some(cached) => cached.clone(),
6068 None => self
6069 .spec
6070 .smooth_terms
6071 .get(term_idx)
6072 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
6073 .clone(),
6074 };
6075 if measure_jet_term {
6076 set_single_term_measure_jet_psi_dials(&mut build_spec, psi)
6080 .map_err(|e| e.to_string())?;
6081 } else if constant_curvature_term {
6082 set_single_term_constant_curvature_kappa(&mut build_spec, psi)
6087 .map_err(|e| e.to_string())?;
6088 } else {
6089 if let Some(length_scale) = next_length_scale {
6090 set_single_term_spatial_length_scale(&mut build_spec, length_scale)
6091 .map_err(|e| e.to_string())?;
6092 }
6093 if let Some(eta) = next_aniso {
6094 set_single_term_spatial_aniso_log_scales(&mut build_spec, eta)
6095 .map_err(|e| e.to_string())?;
6096 }
6097 }
6098
6099 let termname = build_spec.name.clone();
6100 let local = build_single_local_smooth_term(
6101 self.data,
6102 &build_spec,
6103 &mut self.basisworkspace,
6104 )
6105 .map_err(|e| {
6106 format!(
6107 "failed to rebuild smooth term '{termname}' during incremental κ realization: {e}"
6108 )
6109 })?;
6110
6111 if self.spatial_realization_geometry[term_idx].is_none()
6116 && let Some(frozen) = freeze_geometry_from_metadata(&build_spec, &local.metadata)
6117 {
6118 if let (
6130 SmoothBasisSpec::Matern {
6131 spec: frozen_spec, ..
6132 },
6133 Some(SmoothBasisSpec::Matern {
6134 spec: live_spec, ..
6135 }),
6136 ) = (
6137 &frozen.basis,
6138 self.spec
6139 .smooth_terms
6140 .get_mut(term_idx)
6141 .map(|t| &mut t.basis),
6142 ) {
6143 live_spec.identifiability = frozen_spec.identifiability.clone();
6144 live_spec.center_strategy = frozen_spec.center_strategy.clone();
6145 }
6146 self.spatial_realization_geometry[term_idx] = Some(frozen);
6147 }
6148
6149 let realization = wrap_local_build_as_realization(local, &build_spec)?;
6150 self.replace_term_realization(term_idx, realization)?;
6151 Ok(true)
6152 }
6153
6154 fn replace_term_realization(
6155 &mut self,
6156 term_idx: usize,
6157 realization: SingleSmoothTermRealization,
6158 ) -> Result<(), String> {
6159 let t_replace = std::time::Instant::now();
6160 let SingleSmoothTermRealization {
6161 design_local,
6162 term,
6163 dropped_penaltyinfo,
6164 } = realization;
6165 let SmoothTerm {
6166 name,
6167 penalties_local,
6168 nullspace_dims,
6169 penaltyinfo_local,
6170 metadata,
6171 lower_bounds_local,
6172 linear_constraints_local,
6173 joint_null_rotation,
6174 ..
6175 } = term;
6176 let coeff_range = self
6177 .design
6178 .smooth
6179 .terms
6180 .get(term_idx)
6181 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
6182 .coeff_range
6183 .clone();
6184 if design_local.ncols() != coeff_range.len() {
6185 return Err(SmoothError::dimension_mismatch(format!(
6186 "incremental realizer width mismatch for term {}: rebuilt_cols={}, cached_cols={}",
6187 term_idx,
6188 design_local.ncols(),
6189 coeff_range.len()
6190 ))
6191 .into());
6192 }
6193 if design_local.nrows() != self.design.design.nrows() {
6194 return Err(SmoothError::dimension_mismatch(format!(
6195 "incremental realizer row mismatch for term {}: rebuilt_rows={}, design_rows={}",
6196 term_idx,
6197 design_local.nrows(),
6198 self.design.design.nrows()
6199 ))
6200 .into());
6201 }
6202
6203 let active_penaltyinfo = penaltyinfo_local
6204 .iter()
6205 .filter(|info| info.active)
6206 .cloned()
6207 .collect::<Vec<_>>();
6208 let smooth_penalty_range = self
6209 .smooth_penalty_ranges
6210 .get(term_idx)
6211 .ok_or_else(|| {
6212 format!("incremental realizer missing smooth penalty range for term {term_idx}")
6213 })?
6214 .clone();
6215 let full_penalty_range = self
6216 .full_penalty_ranges
6217 .get(term_idx)
6218 .ok_or_else(|| {
6219 format!("incremental realizer missing full penalty range for term {term_idx}")
6220 })?
6221 .clone();
6222 if active_penaltyinfo.len() != smooth_penalty_range.len()
6223 || penalties_local.len() != smooth_penalty_range.len()
6224 || nullspace_dims.len() != smooth_penalty_range.len()
6225 {
6226 return Err(SmoothError::dimension_mismatch(format!(
6227 "incremental realizer topology changed for term '{}': penalties={}, infos={}, nullspaces={}, cached_penalties={}",
6228 name,
6229 penalties_local.len(),
6230 active_penaltyinfo.len(),
6231 nullspace_dims.len(),
6232 smooth_penalty_range.len()
6233 ))
6234 .into());
6235 }
6236
6237 self.design.smooth.term_designs[term_idx] = design_local;
6238
6239 for (offset, penalty_local) in penalties_local.iter().enumerate() {
6240 let smooth_penalty_idx = smooth_penalty_range.start + offset;
6241 let full_penalty_idx = full_penalty_range.start + offset;
6242 let nullspace_dim = nullspace_dims[offset];
6243 let penalty_info = active_penaltyinfo[offset].clone();
6244
6245 if penalty_local.nrows() != coeff_range.len()
6246 || penalty_local.ncols() != coeff_range.len()
6247 {
6248 return Err(SmoothError::dimension_mismatch(format!(
6249 "incremental realizer penalty shape mismatch for term '{}' penalty {}: \
6250 penalty is {}x{} but coeff_range has {} columns",
6251 name,
6252 offset,
6253 penalty_local.nrows(),
6254 penalty_local.ncols(),
6255 coeff_range.len()
6256 ))
6257 .into());
6258 }
6259
6260 let smooth_penalty = self
6261 .design
6262 .smooth
6263 .penalties
6264 .get_mut(smooth_penalty_idx)
6265 .ok_or_else(|| {
6266 format!(
6267 "incremental realizer smooth penalty {} out of range for term {}",
6268 smooth_penalty_idx, term_idx
6269 )
6270 })?;
6271 smooth_penalty.local.assign(penalty_local);
6274
6275 let full_bp = self
6276 .design
6277 .penalties
6278 .get_mut(full_penalty_idx)
6279 .ok_or_else(|| {
6280 format!(
6281 "incremental realizer full penalty {} out of range for term {}",
6282 full_penalty_idx, term_idx
6283 )
6284 })?;
6285 full_bp.local.assign(penalty_local);
6288
6289 self.design.smooth.nullspace_dims[smooth_penalty_idx] = nullspace_dim;
6290 self.design.nullspace_dims[full_penalty_idx] = nullspace_dim;
6291
6292 self.design.smooth.penaltyinfo[smooth_penalty_idx].global_index = smooth_penalty_idx;
6293 self.design.smooth.penaltyinfo[smooth_penalty_idx].termname = Some(name.clone());
6294 self.design.smooth.penaltyinfo[smooth_penalty_idx].penalty = penalty_info.clone();
6295
6296 self.design.penaltyinfo[full_penalty_idx].global_index = full_penalty_idx;
6297 self.design.penaltyinfo[full_penalty_idx].termname = Some(name.clone());
6298 self.design.penaltyinfo[full_penalty_idx].penalty = penalty_info;
6299 }
6300
6301 let target_term = self.design.smooth.terms.get_mut(term_idx).ok_or_else(|| {
6302 format!("incremental realizer smooth term {term_idx} disappeared during replacement")
6303 })?;
6304 target_term.penalties_local = penalties_local;
6305 target_term.nullspace_dims = nullspace_dims;
6306 target_term.penaltyinfo_local = penaltyinfo_local;
6307 target_term.metadata = metadata;
6308 target_term.lower_bounds_local = lower_bounds_local;
6309 target_term.linear_constraints_local = linear_constraints_local;
6310 target_term.joint_null_rotation = joint_null_rotation;
6311 self.dropped_penaltyinfo_by_term[term_idx] = dropped_penaltyinfo;
6312 log::info!(
6313 "[STAGE] smooth basis rebuild (term {}, '{}', cols={}): {:.3}s",
6314 term_idx,
6315 target_term.name,
6316 coeff_range.len(),
6317 t_replace.elapsed().as_secs_f64(),
6318 );
6319 Ok(())
6320 }
6321
6322 fn refresh_full_design_operator(&mut self) -> Result<(), String> {
6323 let mut blocks = Vec::<DesignBlock>::with_capacity(
6324 self.fixed_blocks.len() + self.design.smooth.term_designs.len(),
6325 );
6326 blocks.extend(self.fixed_blocks.iter().cloned());
6327 for term_design in &self.design.smooth.term_designs {
6328 blocks.push(DesignBlock::from(term_design));
6329 }
6330 self.design.design = assemble_term_collection_design_matrix(blocks)
6331 .map_err(|e| format!("failed to refresh term-collection design: {e}"))?;
6332 Ok(())
6333 }
6334}
6335
6336fn build_term_collection_fixed_blocks(
6337 data: ArrayView2<'_, f64>,
6338 spec: &TermCollectionSpec,
6339) -> Result<Vec<DesignBlock>, BasisError> {
6340 let mut blocks = Vec::<DesignBlock>::new();
6341 if !term_collection_has_one_sided_anchored_bspline(spec) {
6342 blocks.push(DesignBlock::Intercept(data.nrows()));
6343 }
6344
6345 if !spec.linear_terms.is_empty() {
6346 let mut linear_block = Array2::<f64>::zeros((data.nrows(), spec.linear_terms.len()));
6347 for (j, linear) in spec.linear_terms.iter().enumerate() {
6348 let column = linear
6352 .realized_design_column(data)
6353 .map_err(BasisError::InvalidInput)?;
6354 linear_block.column_mut(j).assign(&column);
6355 }
6356 blocks.push(DesignBlock::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
6357 linear_block,
6358 )));
6359 }
6360
6361 for term in &spec.random_effect_terms {
6362 let block = build_random_effect_block(data, term)?;
6363 let re_op = RandomEffectOperator::new(block.group_ids, block.num_groups);
6364 blocks.push(DesignBlock::RandomEffect(Arc::new(re_op)));
6365 }
6366
6367 Ok(blocks)
6368}
6369
/// Outcome of a spatial length-scale (kappa) optimization run.
///
/// This is the value `optimize_spatial_length_scale_exact_joint` returns. On that
/// driver's fast path and on its fixed-kappa fallback, `timing` is `None`.
pub struct SpatialLengthScaleOptimizationResult<FitOut> {
    /// Resolved term-collection specs matching `designs`.
    /// Presumably there is one per block; confirm against the builders.
    pub resolved_specs: Vec<TermCollectionSpec>,
    /// Design matrices paired index-by-index with `resolved_specs`.
    pub designs: Vec<TermCollectionDesign>,
    /// Caller-defined fit output produced by the driver's `fit_fn` callback.
    pub fit: FitOut,
    /// Kappa-phase call counts and timings.
    /// `None` when the outer kappa optimization was skipped.
    pub timing: Option<SpatialLengthScaleOptimizationTiming>,
}
6380
/// Initial point and box bounds for the exact-joint outer hyperparameter vector.
///
/// The assembled vectors (`theta0`, `lower`, `upper`) are laid out as
/// `[rho | log_kappa | auxiliary]`. The rho and auxiliary seeds are clamped into their
/// bounds when they are set. The log-kappa seed is stored as given.
#[derive(Debug, Clone)]
pub struct ExactJointHyperSetup {
    /// Leading rho seed. Non-finite entries are replaced and all entries are clamped
    /// into `[rho_lower, rho_upper]` at construction.
    rho0: Array1<f64>,
    /// Lower bounds for the rho segment.
    rho_lower: Array1<f64>,
    /// Upper bounds for the rho segment.
    rho_upper: Array1<f64>,
    /// Log-kappa seed, which forms the middle segment of theta.
    log_kappa0: SpatialLogKappaCoords,
    /// Lower bounds for the log-kappa segment.
    log_kappa_lower: SpatialLogKappaCoords,
    /// Upper bounds for the log-kappa segment.
    log_kappa_upper: SpatialLogKappaCoords,
    /// Trailing auxiliary seed. It is empty unless `with_auxiliary` is called.
    auxiliary0: Array1<f64>,
    /// Lower bounds for the auxiliary segment.
    auxiliary_lower: Array1<f64>,
    /// Upper bounds for the auxiliary segment.
    auxiliary_upper: Array1<f64>,
}
6394
6395impl ExactJointHyperSetup {
6396 fn sanitize_rho_seed(
6397 rho0: Array1<f64>,
6398 rho_lower: &Array1<f64>,
6399 rho_upper: &Array1<f64>,
6400 ) -> Array1<f64> {
6401 Array1::from_iter(rho0.iter().enumerate().map(|(idx, &value)| {
6402 let lo = rho_lower[idx];
6403 let hi = rho_upper[idx];
6404 let fallback = 0.0_f64.clamp(lo, hi);
6405 if value.is_finite() {
6406 value.clamp(lo, hi)
6407 } else {
6408 fallback
6409 }
6410 }))
6411 }
6412
6413 pub(crate) fn new(
6414 rho0: Array1<f64>,
6415 rho_lower: Array1<f64>,
6416 rho_upper: Array1<f64>,
6417 log_kappa0: SpatialLogKappaCoords,
6418 log_kappa_lower: SpatialLogKappaCoords,
6419 log_kappa_upper: SpatialLogKappaCoords,
6420 ) -> Self {
6421 let rho0 = Self::sanitize_rho_seed(rho0, &rho_lower, &rho_upper);
6422 Self {
6423 rho0,
6424 rho_lower,
6425 rho_upper,
6426 log_kappa0,
6427 log_kappa_lower,
6428 log_kappa_upper,
6429 auxiliary0: Array1::zeros(0),
6430 auxiliary_lower: Array1::zeros(0),
6431 auxiliary_upper: Array1::zeros(0),
6432 }
6433 }
6434
6435 pub(crate) fn with_auxiliary(
6436 mut self,
6437 auxiliary0: Array1<f64>,
6438 auxiliary_lower: Array1<f64>,
6439 auxiliary_upper: Array1<f64>,
6440 ) -> Self {
6441 assert_eq!(
6442 auxiliary0.len(),
6443 auxiliary_lower.len(),
6444 "auxiliary lower bound length mismatch"
6445 );
6446 assert_eq!(
6447 auxiliary0.len(),
6448 auxiliary_upper.len(),
6449 "auxiliary upper bound length mismatch"
6450 );
6451 self.auxiliary0 = Self::sanitize_rho_seed(auxiliary0, &auxiliary_lower, &auxiliary_upper);
6452 self.auxiliary_lower = auxiliary_lower;
6453 self.auxiliary_upper = auxiliary_upper;
6454 self
6455 }
6456
6457 pub(crate) fn rho_dim(&self) -> usize {
6458 self.rho0.len()
6459 }
6460
6461 pub(crate) fn log_kappa_dim(&self) -> usize {
6462 self.log_kappa0.len()
6463 }
6464
6465 pub(crate) fn auxiliary_dim(&self) -> usize {
6466 self.auxiliary0.len()
6467 }
6468
6469 pub(crate) fn theta0(&self) -> Array1<f64> {
6470 let mut out =
6471 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6472 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho0);
6473 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6474 .assign(self.log_kappa0.as_array());
6475 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6476 .assign(&self.auxiliary0);
6477 out
6478 }
6479
6480 pub(crate) fn lower(&self) -> Array1<f64> {
6481 let mut out =
6482 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6483 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_lower);
6484 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6485 .assign(self.log_kappa_lower.as_array());
6486 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6487 .assign(&self.auxiliary_lower);
6488 out
6489 }
6490
6491 pub(crate) fn upper(&self) -> Array1<f64> {
6492 let mut out =
6493 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6494 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_upper);
6495 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6496 .assign(self.log_kappa_upper.as_array());
6497 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6498 .assign(&self.auxiliary_upper);
6499 out
6500 }
6501
6502 pub(crate) fn log_kappa_dims_per_term(&self) -> Vec<usize> {
6504 self.log_kappa0.dims_per_term().to_vec()
6505 }
6506}
6507
/// Per-block realizer state for the n-block exact-joint outer optimization.
///
/// It tracks which theta the realizers were last moved to, and memoizes the most
/// recent objective cost or full evaluation at that theta.
struct ExactJointDesignCache<'d> {
    /// One frozen incremental realizer per block.
    realizers: Vec<FrozenTermCollectionIncrementalRealizer<'d>>,
    /// For each block, the smooth-term indices passed to `apply_log_kappa`.
    block_term_indices: Vec<Vec<usize>>,
    /// Theta the realizers currently reflect. `None` before the first successful move.
    current_theta: Option<Array1<f64>>,
    /// Cost memoized at `current_theta`, if any.
    last_cost: Option<f64>,
    /// `(cost, gradient, hessian)` memoized at `current_theta`, if any.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Number of leading rho coordinates in theta.
    rho_dim: usize,
    /// Log-kappa dimension of each kappa-bearing term, in order.
    all_dims: Vec<usize>,
    /// Sum of `all_dims`.
    log_kappa_dim: usize,
    /// Number of terms owned by each block. This is the `split_at` width.
    block_term_counts: Vec<usize>,
}
6528
6529impl<'d> ExactJointDesignCache<'d> {
6530 fn new(
6531 data: ArrayView2<'d, f64>,
6532 blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)>,
6533 rho_dim: usize,
6534 all_dims: Vec<usize>,
6535 ) -> Result<Self, String> {
6536 let n_blocks = blocks.len();
6537 let mut realizers = Vec::with_capacity(n_blocks);
6538 let mut block_term_indices = Vec::with_capacity(n_blocks);
6539 let mut block_term_counts = Vec::with_capacity(n_blocks);
6540
6541 for (spec, design, terms) in blocks {
6542 block_term_counts.push(terms.len());
6543 block_term_indices.push(terms);
6544 realizers.push(FrozenTermCollectionIncrementalRealizer::new(
6545 data, spec, design,
6546 )?);
6547 }
6548
6549 Ok(Self {
6550 realizers,
6551 block_term_indices,
6552 current_theta: None,
6553 last_cost: None,
6554 last_eval: None,
6555 rho_dim,
6556 log_kappa_dim: all_dims.iter().sum(),
6557 all_dims,
6558 block_term_counts,
6559 })
6560 }
6561
6562 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
6563 if self
6564 .current_theta
6565 .as_ref()
6566 .is_some_and(|cached| theta_values_match(cached, theta))
6567 {
6568 return Ok(());
6569 }
6570
6571 let t_ensure = std::time::Instant::now();
6572 let kappa_theta_len = self.rho_dim + self.log_kappa_dim;
6573 if theta.len() < kappa_theta_len {
6574 return Err(SmoothError::dimension_mismatch(format!(
6575 "exact-joint theta length mismatch: got {}, expected at least {} (rho_dim={}, log_kappa_dim={})",
6576 theta.len(),
6577 kappa_theta_len,
6578 self.rho_dim,
6579 self.log_kappa_dim
6580 ))
6581 .into());
6582 }
6583 let theta_kappa = theta.slice(s![..kappa_theta_len]).to_owned();
6584 let full_log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
6585 &theta_kappa,
6586 self.rho_dim,
6587 self.all_dims.clone(),
6588 );
6589
6590 let n = self.realizers.len();
6594 let mut remaining = full_log_kappa;
6595 for block_idx in 0..n {
6596 let count = self.block_term_counts[block_idx];
6597 if block_idx < n - 1 {
6598 let (block_lk, rest) = remaining.split_at(count);
6599 self.realizers[block_idx]
6600 .apply_log_kappa(&block_lk, &self.block_term_indices[block_idx])?;
6601 remaining = rest;
6602 } else {
6603 self.realizers[block_idx]
6605 .apply_log_kappa(&remaining, &self.block_term_indices[block_idx])?;
6606 }
6607 }
6608
6609 log::info!(
6610 "[STAGE] ensure_theta (n-block, {} blocks, {} realizers): {:.3}s",
6611 n,
6612 self.realizers.len(),
6613 t_ensure.elapsed().as_secs_f64(),
6614 );
6615 self.current_theta = Some(theta.clone());
6616 self.last_cost = None;
6617 self.last_eval = None;
6618 Ok(())
6619 }
6620
6621 impl_exact_joint_theta_memo!();
6622
6623 fn store_cost_only(&mut self, theta: &Array1<f64>, cost: f64) {
6629 if self
6630 .current_theta
6631 .as_ref()
6632 .is_some_and(|cached| theta_values_match(cached, theta))
6633 {
6634 self.last_cost = Some(cost);
6635 }
6636 }
6637
6638 fn specs(&self) -> Vec<&TermCollectionSpec> {
6639 self.realizers.iter().map(|r| r.spec()).collect()
6640 }
6641
6642 fn designs(&self) -> Vec<&TermCollectionDesign> {
6643 self.realizers.iter().map(|r| r.design()).collect()
6644 }
6645
6646 fn design_revision(&self) -> u64 {
6656 self.realizers
6657 .iter()
6658 .fold(0u64, |acc, r| acc.wrapping_add(r.design_revision()))
6659 }
6660}
6661
6662pub(crate) fn seed_risk_profile_for_likelihood_family(
6663 family: &LikelihoodSpec,
6664) -> gam_problem::SeedRiskProfile {
6665 match &family.response {
6666 ResponseFamily::Gaussian => gam_problem::SeedRiskProfile::Gaussian,
6667 ResponseFamily::RoystonParmar => gam_problem::SeedRiskProfile::Survival,
6668 ResponseFamily::Binomial
6669 | ResponseFamily::Poisson
6670 | ResponseFamily::Tweedie { .. }
6671 | ResponseFamily::NegativeBinomial { .. }
6672 | ResponseFamily::Beta { .. }
6673 | ResponseFamily::Gamma => gam_problem::SeedRiskProfile::GeneralizedLinear,
6674 }
6675}
6676
/// Cap of 8 on the exact-joint theta dimension for second-order handling.
///
/// NOTE(review): this constant is not referenced in this part of the file. From the
/// name it presumably limits the theta dimension for which second-order (Hessian)
/// outer steps are attempted. Confirm at the use site.
const EXACT_JOINT_SECOND_ORDER_THETA_CAP: usize = 8;
6685
6686fn exact_joint_seed_config(
6687 risk_profile: gam_problem::SeedRiskProfile,
6688 auxiliary_dim: usize,
6689) -> gam_problem::SeedConfig {
6690 let mut config = gam_problem::SeedConfig {
6691 risk_profile,
6692 num_auxiliary_trailing: auxiliary_dim,
6693 ..Default::default()
6694 };
6695 match risk_profile {
6696 gam_problem::SeedRiskProfile::Gaussian
6697 | gam_problem::SeedRiskProfile::GaussianLocationScale => {
6698 config.max_seeds = 4;
6699 config.seed_budget = 2;
6700 }
6701 gam_problem::SeedRiskProfile::GeneralizedLinear => {
6702 config.max_seeds = 1;
6707 config.seed_budget = 1;
6708 config.screen_max_inner_iterations = 8;
6709 }
6710 gam_problem::SeedRiskProfile::Survival => {
6711 config.max_seeds = 8;
6717 config.seed_budget = 4;
6718 config.screen_max_inner_iterations = 8;
6719 }
6720 }
6721 config
6722}
6723
#[cfg(test)]
mod exact_joint_seed_config_tests {
    use super::*;

    #[test]
    fn exact_joint_marginal_slope_profiles_get_deeper_startup_validation() {
        // (profile, auxiliary_dim, max_seeds, seed_budget, screen_max_inner_iterations)
        let expectations = [
            (gam_problem::SeedRiskProfile::GeneralizedLinear, 2, 1, 1, 8),
            (gam_problem::SeedRiskProfile::Survival, 3, 8, 4, 8),
        ];
        for (profile, aux_dim, seeds, budget, screen_iters) in expectations {
            let config = exact_joint_seed_config(profile, aux_dim);
            assert_eq!(config.max_seeds, seeds);
            assert_eq!(config.seed_budget, budget);
            assert_eq!(config.screen_max_inner_iterations, screen_iters);
            assert_eq!(config.num_auxiliary_trailing, aux_dim);
        }
    }

    #[test]
    fn exact_joint_gaussian_keeps_tight_historical_multistart_budget() {
        let default_screen_iters =
            gam_problem::SeedConfig::default().screen_max_inner_iterations;
        let config = exact_joint_seed_config(gam_problem::SeedRiskProfile::Gaussian, 1);

        assert_eq!(config.max_seeds, 4);
        assert_eq!(config.seed_budget, 2);
        // The Gaussian profile must not override the screening depth.
        assert_eq!(config.screen_max_inner_iterations, default_screen_iters);
        assert_eq!(config.num_auxiliary_trailing, 1);
    }
}
6755
#[cfg(test)]
mod wood_reference_df_tests {
    use super::*;

    #[test]
    fn edf1_equals_two_trace_minus_trace_of_square() {
        // F = diag(0.9, 0.4), so tr(F) = 1.3 and tr(F^2) = 0.81 + 0.16 = 0.97.
        // Then edf1 = 2 * 1.3 - 0.97 = 1.63, which is also >= edf = tr(F).
        let edf = 1.3;
        let influence = ndarray::array![[0.9_f64, 0.0], [0.0, 0.4]];
        let got = wood_reference_df(Some(&influence), &(0..2)).unwrap();

        assert!(
            (got - 1.63).abs() < 1e-12,
            "edf1 should be 2*tr - tr(F^2) = 1.63, got {got}"
        );
        assert!(got >= edf - 1e-12, "edf1 {got} must be >= edf {edf}");
    }

    #[test]
    fn edf1_never_collapses_below_edf_when_offdiagonals_blow_up() {
        // Here tr(F) = 1 but tr(F^2) = 2 * 0.25 + 2 * 1600 = 3200.5. The raw
        // 2*tr - tr(F^2) is therefore hugely negative and must be floored at tr.
        let tr = 1.0_f64;
        let influence = ndarray::array![[0.5_f64, 40.0], [40.0, 0.5]];
        let got = wood_reference_df(Some(&influence), &(0..2)).unwrap();

        assert!(
            got >= tr - 1e-12,
            "edf1 must be floored at edf (=tr={tr}) even when tr(F^2) explodes, got {got}"
        );
        assert!(got.is_finite() && got > 0.0, "edf1 must stay finite/positive");
    }

    #[test]
    fn returns_none_on_nonpositive_or_missing_trace() {
        // No influence matrix at all.
        assert!(wood_reference_df(None, &(0..2)).is_none());

        // Zero trace.
        let zero = ndarray::Array2::<f64>::zeros((2, 2));
        assert!(wood_reference_df(Some(&zero), &(0..2)).is_none());

        // Coefficient range reaching past the 2x2 matrix.
        let half_identity = ndarray::Array2::<f64>::eye(2) * 0.5;
        assert!(wood_reference_df(Some(&half_identity), &(0..5)).is_none());
    }
}
6813
6814pub(crate) fn exact_joint_multistart_outer_problem(
6815 theta0: &Array1<f64>,
6816 lower: &Array1<f64>,
6817 upper: &Array1<f64>,
6818 rho_dim: usize,
6819 auxiliary_dim: usize,
6820 n_params: usize,
6821 gradient: gam_problem::Derivative,
6822 hessian: gam_problem::DeclaredHessianForm,
6823 prefer_gradient_only: bool,
6824 disable_fixed_point: bool,
6825 risk_profile: gam_problem::SeedRiskProfile,
6826 tolerance: f64,
6827 max_iter: usize,
6828 bfgs_step_cap: Option<f64>,
6837 bfgs_step_cap_psi: Option<f64>,
6838 screening_cap: Option<Arc<AtomicUsize>>,
6839 profiled_objective_size: Option<(usize, usize)>,
6860 has_constant_curvature: bool,
6869) -> gam_solve::rho_optimizer::OuterProblem {
6870 let mut seed_heuristic = theta0.to_vec();
6871 for value in &mut seed_heuristic[..rho_dim] {
6872 *value = value.exp();
6873 }
6874 let rho_ceiling = if has_constant_curvature {
6879 gam_solve::estimate::RHO_BOUND
6880 } else {
6881 12.0
6882 };
6883 let mut problem = gam_solve::rho_optimizer::OuterProblem::new(n_params)
6884 .with_gradient(gradient)
6885 .with_hessian(hessian)
6886 .with_prefer_gradient_only(prefer_gradient_only)
6887 .with_disable_fixed_point(disable_fixed_point)
6888 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Automatic)
6898 .with_psi_dim(auxiliary_dim)
6899 .with_tolerance(tolerance)
6900 .with_max_iter(max_iter)
6901 .with_bounds(lower.clone(), upper.clone())
6902 .with_initial_rho(theta0.clone())
6903 .with_bfgs_step_cap(bfgs_step_cap)
6904 .with_bfgs_step_cap_psi(bfgs_step_cap_psi)
6905 .with_seed_config({
6906 let mut sc = exact_joint_seed_config(risk_profile, auxiliary_dim);
6907 if has_constant_curvature {
6908 sc.bounds = (sc.bounds.0, rho_ceiling);
6912 }
6931 sc
6932 })
6933 .with_rho_bound(rho_ceiling)
6934 .with_heuristic_lambdas(seed_heuristic);
6935 if let Some((n_obs, p_cols)) = profiled_objective_size {
6936 problem = problem
6944 .with_objective_scale(Some(n_obs as f64))
6945 .with_problem_size(n_obs, p_cols)
6946 .with_arc_initial_regularization(Some(0.25))
6947 .with_operator_initial_trust_radius(Some(4.0));
6948 }
6949 if let Some(screening_cap) = screening_cap {
6950 problem = problem
6951 .with_screening_cap(screening_cap)
6952 .with_screen_initial_rho(true);
6953 }
6954 problem
6955}
6956
/// Reports whether an outer kappa-phase failure message allows falling back to a fit
/// at the fixed bootstrap length-scale, instead of propagating the error.
///
/// The check is a case-sensitive substring match of `message` against a fixed set of
/// known recoverable failure texts.
fn kappa_phase_failure_is_fixed_kappa_recoverable(message: &str) -> bool {
    const RECOVERABLE_FAILURE_MARKERS: [&str; 3] = [
        "no candidate seeds passed outer startup validation",
        "joint hyper rho dimension mismatch",
        "objective returned a non-finite cost",
    ];
    RECOVERABLE_FAILURE_MARKERS
        .iter()
        .copied()
        .any(|marker| message.contains(marker))
}
6972
6973pub fn optimize_spatial_length_scale_exact_joint<FitOut, FitFn, ExactFn, ExactEfsFn, SeedFn>(
6974 data: ArrayView2<'_, f64>,
6975 block_specs: &[TermCollectionSpec],
6976 block_term_indices: &[Vec<usize>],
6977 kappa_options: &SpatialLengthScaleOptimizationOptions,
6978 joint_setup: &ExactJointHyperSetup,
6979 seed_risk_profile: gam_problem::SeedRiskProfile,
6980 analytic_joint_gradient_available: bool,
6981 analytic_joint_hessian_available: bool,
6982 disable_fixed_point: bool,
6983 screening_cap: Option<Arc<AtomicUsize>>,
6984 outer_derivative_policy: gam_model_api::families::custom_family::OuterDerivativePolicy,
6985 mut fit_fn: FitFn,
6986 mut exact_fn: ExactFn,
6987 mut exact_efs_fn: ExactEfsFn,
6988 mut seed_inner_beta_fn: SeedFn,
6989) -> Result<SpatialLengthScaleOptimizationResult<FitOut>, String>
6990where
6991 FitOut: Clone,
6992 FitFn: FnMut(
6993 &Array1<f64>,
6994 &[TermCollectionSpec],
6995 &[TermCollectionDesign],
6996 ) -> Result<FitOut, String>,
6997 ExactFn: FnMut(
6998 &Array1<f64>,
6999 &[TermCollectionSpec],
7000 &[TermCollectionDesign],
7001 gam_solve::estimate::reml::reml_outer_engine::EvalMode,
7002 &gam_problem::outer_subsample::RowSet,
7003 ) -> Result<
7004 (
7005 f64,
7006 Array1<f64>,
7007 gam_problem::HessianResult,
7008 ),
7009 String,
7010 >,
7011 ExactEfsFn: FnMut(
7012 &Array1<f64>,
7013 &[TermCollectionSpec],
7014 &[TermCollectionDesign],
7015 ) -> Result<gam_problem::EfsEval, String>,
7016 SeedFn:
7017 FnMut(&Array1<f64>) -> Result<gam_solve::rho_optimizer::SeedOutcome, EstimationError>,
7018{
7019 let n_blocks = block_specs.len();
7020 if block_term_indices.len() != n_blocks {
7021 return Err(SmoothError::dimension_mismatch(format!(
7022 "block_specs ({}) and block_term_indices ({}) length mismatch",
7023 n_blocks,
7024 block_term_indices.len()
7025 ))
7026 .into());
7027 }
7028
7029 let log_kappa_dim = joint_setup.log_kappa_dim();
7030
7031 log::warn!(
7032 "[OUTER-FD-AUDIT/spatial-exact-joint] driver entry: aux_dim={} log_kappa_dim={} kappa_enabled={} rho_dim={} theta0_len={}",
7033 joint_setup.auxiliary_dim(),
7034 log_kappa_dim,
7035 kappa_options.enabled,
7036 joint_setup.rho_dim(),
7037 joint_setup.theta0().len()
7038 );
7039
7040 if joint_setup.auxiliary_dim() == 0 && (!kappa_options.enabled || log_kappa_dim == 0) {
7044 log::warn!(
7045 "[OUTER-FD-AUDIT/spatial-exact-joint] taking FAST path (no outer theta optimization in this driver)"
7046 );
7047 let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
7048 data, block_specs,
7049 )
7050 .map_err(|e| {
7051 format!("failed to build and freeze joint block designs during exact joint kappa optimization: {e}")
7052 })?;
7053 let theta0 = joint_setup.theta0();
7054
7055 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7057 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7058 let fit = fit_fn(&theta0, &spec_refs, &design_refs)?;
7059 return Ok(SpatialLengthScaleOptimizationResult {
7060 resolved_specs,
7061 designs,
7062 fit,
7063 timing: None,
7064 });
7065 }
7066
7067 let theta0 = joint_setup.theta0();
7071 let lower = joint_setup.lower();
7072 let upper = joint_setup.upper();
7073 if theta0.len() < log_kappa_dim || lower.len() != theta0.len() || upper.len() != theta0.len() {
7074 return Err(SmoothError::dimension_mismatch(format!(
7075 "invalid exact joint theta setup: theta0={}, lower={}, upper={}, required_log_kappa_dim={}",
7076 theta0.len(),
7077 lower.len(),
7078 upper.len(),
7079 log_kappa_dim
7080 ))
7081 .into());
7082 }
7083 let rho_dim = joint_setup.rho_dim();
7084 let all_dims = joint_setup.log_kappa_dims_per_term();
7085
7086 let (boot_designs, best_specs) = build_term_collection_designs_and_freeze_joint(
7088 data,
7089 block_specs,
7090 )
7091 .map_err(|e| {
7092 format!(
7093 "failed to build and freeze joint block designs during exact joint kappa bootstrap: {e}"
7094 )
7095 })?;
7096 let policy_hessian_form = outer_derivative_policy.declared_hessian_form();
7106 let analytic_outer_hessian_available = analytic_joint_hessian_available
7107 && matches!(
7108 policy_hessian_form,
7109 gam_problem::DeclaredHessianForm::Either
7110 | gam_problem::DeclaredHessianForm::Dense
7111 | gam_problem::DeclaredHessianForm::Operator { .. }
7112 );
7113 let prefer_gradient_only = !analytic_outer_hessian_available;
7114
7115 let theta_dim = theta0.len();
7116 let psi_dim = theta_dim - rho_dim;
7117
7118 let cache_blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)> = best_specs
7120 .iter()
7121 .zip(boot_designs.iter())
7122 .zip(block_term_indices.iter())
7123 .map(|((spec, design), terms)| (spec.clone(), design.clone(), terms.clone()))
7124 .collect();
7125
7126 struct NBlockExactJointState<'d> {
7127 cache: ExactJointDesignCache<'d>,
7128 }
7129
7130 let mut state = NBlockExactJointState {
7131 cache: ExactJointDesignCache::new(data, cache_blocks, rho_dim, all_dims.clone())?,
7132 };
7133
7134 const KAPPA_PILOT_K: usize = 5_000;
7159 const KAPPA_POLISH_K: usize = 25_000;
7160 const KAPPA_POLISH_TRIGGER_N: usize = 100_000;
7161
7162 let n_total = data.nrows();
7163 let use_staged_kappa = outer_derivative_policy.should_use_staged_kappa(n_total);
7164 if use_staged_kappa {
7165 log::info!(
7166 "[KAPPA-STAGED] auto-engaging pilot+polish schedule: n={} pilot_k={} polish_k={}",
7167 n_total,
7168 KAPPA_PILOT_K,
7169 KAPPA_POLISH_K,
7170 );
7171 }
7172
7173 fn build_uniform_pilot_subsample(
7190 n_total: usize,
7191 k_target: usize,
7192 seed: u64,
7193 ) -> gam_problem::outer_subsample::OuterScoreSubsample {
7194 use gam_problem::outer_subsample::OuterScoreSubsample;
7195 let k = k_target.min(n_total);
7196 if k == 0 || n_total == 0 {
7197 return OuterScoreSubsample::from_uniform_inclusion_mask(Vec::new(), n_total, seed);
7198 }
7199 let mut mask: Vec<usize> = Vec::with_capacity(k);
7203 let mut state = seed.wrapping_add(0x9E3779B97F4A7C15);
7205 let splitmix = |s: &mut u64| -> u64 { gam_linalg::utils::splitmix64(s) };
7206 let mut taken = std::collections::HashSet::with_capacity(k);
7207 for j in (n_total - k)..n_total {
7208 let r = (splitmix(&mut state) % (j as u64 + 1)) as usize;
7209 if !taken.insert(r) {
7210 taken.insert(j);
7211 mask.push(j);
7212 } else {
7213 mask.push(r);
7214 }
7215 }
7216 mask.sort_unstable();
7217 mask.dedup();
7218 OuterScoreSubsample::from_uniform_inclusion_mask(mask, n_total, seed)
7219 }
7220
7221 let current_row_set: std::cell::RefCell<gam_problem::outer_subsample::RowSet> = if use_staged_kappa {
7222 let pilot = build_uniform_pilot_subsample(n_total, KAPPA_PILOT_K, n_total as u64);
7223 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::Subsample {
7224 rows: std::sync::Arc::clone(&pilot.rows),
7225 n_full: n_total,
7226 })
7227 } else {
7228 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::All)
7229 };
7230
7231 let exact_fn_cell = std::cell::RefCell::new(&mut exact_fn);
7232 let exact_efs_fn_cell = std::cell::RefCell::new(&mut exact_efs_fn);
7233
7234 use std::cell::Cell;
7249 let kphase_cost_calls: Cell<usize> = Cell::new(0);
7250 let kphase_cost_total_s: Cell<f64> = Cell::new(0.0);
7251 let kphase_eval_calls: Cell<usize> = Cell::new(0);
7252 let kphase_eval_total_s: Cell<f64> = Cell::new(0.0);
7253 let kphase_efs_calls: Cell<usize> = Cell::new(0);
7254 let kphase_efs_total_s: Cell<f64> = Cell::new(0.0);
7255 let kphase_optim_start = std::time::Instant::now();
7256 let kphase_log_kappa_dim = log_kappa_dim;
7257 let kphase_log_norms = |theta: &Array1<f64>| -> (f64, f64) {
7258 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
7259 let log_kappa_norm = if kphase_log_kappa_dim > 0 && theta.len() >= kphase_log_kappa_dim {
7260 let start = theta.len() - kphase_log_kappa_dim;
7261 theta.iter().skip(start).map(|v| v * v).sum::<f64>().sqrt()
7262 } else {
7263 0.0
7264 };
7265 (theta_norm, log_kappa_norm)
7266 };
7267
7268 use gam_solve::rho_optimizer::OuterEvalOrder;
7269 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7270
7271 let joint_p_cols: usize = boot_designs
7275 .iter()
7276 .map(|d| d.design.ncols())
7277 .sum::<usize>()
7278 .max(1);
7279
7280 let problem = exact_joint_multistart_outer_problem(
7281 &theta0,
7282 &lower,
7283 &upper,
7284 rho_dim,
7285 psi_dim,
7286 theta_dim,
7287 if analytic_joint_gradient_available {
7288 Derivative::Analytic
7289 } else {
7290 Derivative::Unavailable
7291 },
7292 if analytic_outer_hessian_available {
7293 DeclaredHessianForm::Either
7294 } else {
7295 DeclaredHessianForm::Unavailable
7296 },
7297 prefer_gradient_only,
7298 disable_fixed_point,
7299 seed_risk_profile,
7300 kappa_options.rel_tol.max(1e-6),
7301 kappa_options.max_outer_iter.max(1),
7302 Some(5.0),
7304 Some(kappa_options.log_step.clamp(0.25, 1.0)),
7306 screening_cap.clone(),
7307 Some((n_total, joint_p_cols)),
7310 block_specs
7313 .iter()
7314 .any(|s| !constant_curvature_term_indices(s).is_empty()),
7315 );
7316
7317 fn collect_specs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionSpec> {
7319 cache.specs().into_iter().cloned().collect()
7320 }
7321 fn collect_designs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionDesign> {
7322 cache.designs().into_iter().cloned().collect()
7323 }
7324
7325 let result = {
7326 let eval_outer = |ctx: &mut &mut NBlockExactJointState<'_>,
7327 theta: &Array1<f64>,
7328 order: OuterEvalOrder|
7329 -> Result<OuterEval, EstimationError> {
7330 if let Some((cost, grad, hess)) = ctx.cache.memoized_eval(theta) {
7331 let cached_satisfies_order = match order {
7332 OuterEvalOrder::Value => true,
7333 OuterEvalOrder::ValueAndGradient => true,
7334 OuterEvalOrder::ValueGradientHessian => hess.is_analytic(),
7335 };
7336 if cached_satisfies_order {
7337 if !cost.is_finite() {
7338 return Ok(OuterEval::infeasible(theta.len()));
7339 }
7340 if grad.iter().any(|v| !v.is_finite()) {
7353 return Ok(OuterEval::infeasible(theta.len()));
7354 }
7355 return Ok(OuterEval {
7356 cost,
7357 gradient: grad,
7358 hessian: hess,
7359 inner_beta_hint: None,
7360 });
7361 }
7362 }
7363 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7380 return Ok(OuterEval::infeasible(theta.len()));
7381 }
7382 if let Err(err) = ctx.cache.ensure_theta(theta) {
7383 log::warn!(
7384 "[OUTER] n-block exact-joint spatial: ensure_theta failed during gradient evaluation: {err}"
7385 );
7386 return Ok(OuterEval::infeasible(theta.len()));
7387 }
7388 let design_revision = Some(ctx.cache.design_revision());
7389 let specs = collect_specs(&ctx.cache);
7390 let designs = collect_designs(&ctx.cache);
7391 let clamped = outer_derivative_policy.order_for_evaluation(order);
7399 let need_hessian = matches!(clamped, OuterEvalOrder::ValueGradientHessian)
7400 && analytic_outer_hessian_available;
7401 let eval_mode = if need_hessian {
7402 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
7403 } else {
7404 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
7405 };
7406 let t0 = std::time::Instant::now();
7407 let result = {
7408 let row_set_borrow = current_row_set.borrow();
7409 (*exact_fn_cell.borrow_mut())(theta, &specs, &designs, eval_mode, &row_set_borrow)
7410 };
7411 let elapsed_s = t0.elapsed().as_secs_f64();
7412 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
7413 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
7414 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7415 log::info!(
7416 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7417 kphase_eval_calls.get(),
7418 order,
7419 design_revision,
7420 theta_norm,
7421 log_kappa_norm,
7422 elapsed_s,
7423 );
7424 match result {
7425 Ok((cost, grad, hess)) => {
7426 ctx.cache.store_eval((cost, grad.clone(), hess.clone()));
7427 if !cost.is_finite() {
7428 return Ok(OuterEval::infeasible(theta.len()));
7429 }
7430 if grad.iter().any(|v| !v.is_finite()) {
7443 return Ok(OuterEval::infeasible(theta.len()));
7444 }
7445 Ok(OuterEval {
7446 cost,
7447 gradient: grad,
7448 hessian: hess,
7449 inner_beta_hint: None,
7450 })
7451 }
7452 Err(err) => {
7453 log::warn!(
7454 "[OUTER] n-block exact-joint spatial: exact evaluation failed: {err}"
7455 );
7456 Ok(OuterEval::infeasible(theta.len()))
7457 }
7458 }
7459 };
7460
7461 let obj = problem.build_objective_with_eval_order(
7462 &mut state,
7463 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7464 if let Some(cost) = ctx.cache.memoized_cost(theta) {
7465 return Ok(cost);
7466 }
7467 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7475 return Ok(f64::INFINITY);
7476 }
7477 if let Err(err) = ctx.cache.ensure_theta(theta) {
7478 log::warn!(
7479 "[OUTER] n-block exact-joint spatial: ensure_theta failed during cost evaluation: {err}"
7480 );
7481 return Ok(f64::INFINITY);
7482 }
7483 let design_revision = Some(ctx.cache.design_revision());
7484 let specs = collect_specs(&ctx.cache);
7485 let designs = collect_designs(&ctx.cache);
7486 let t0 = std::time::Instant::now();
7493 let result = {
7494 let row_set_borrow = current_row_set.borrow();
7495 (*exact_fn_cell.borrow_mut())(
7496 theta,
7497 &specs,
7498 &designs,
7499 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
7500 &row_set_borrow,
7501 )
7502 };
7503 let elapsed_s = t0.elapsed().as_secs_f64();
7504 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
7505 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
7506 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7507 log::info!(
7508 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7509 kphase_cost_calls.get(),
7510 design_revision,
7511 theta_norm,
7512 log_kappa_norm,
7513 elapsed_s,
7514 );
7515 match result {
7516 Ok((cost, _grad, _hess)) => {
7517 ctx.cache.store_cost_only(theta, cost);
7523 Ok(cost)
7524 }
7525 Err(err) => {
7526 log::warn!(
7527 "[OUTER] n-block exact-joint spatial: exact cost evaluation failed: {err}"
7528 );
7529 Ok(f64::INFINITY)
7530 }
7531 }
7532 },
7533 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7534 eval_outer(
7535 ctx,
7536 theta,
7537 if analytic_outer_hessian_available {
7538 OuterEvalOrder::ValueGradientHessian
7539 } else {
7540 OuterEvalOrder::ValueAndGradient
7541 },
7542 )
7543 },
7544 |ctx: &mut &mut NBlockExactJointState<'_>,
7545 theta: &Array1<f64>,
7546 order: OuterEvalOrder| { eval_outer(ctx, theta, order) },
7547 None::<fn(&mut &mut NBlockExactJointState<'_>)>,
7548 Some(
7549 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7550 ctx.cache
7551 .ensure_theta(theta)
7552 .map_err(EstimationError::InvalidInput)?;
7553 let design_revision = Some(ctx.cache.design_revision());
7554 let specs = collect_specs(&ctx.cache);
7555 let designs = collect_designs(&ctx.cache);
7556 let t0 = std::time::Instant::now();
7557 let eval_result = (*exact_efs_fn_cell.borrow_mut())(
7558 theta,
7559 &specs,
7560 &designs,
7561 );
7562 let elapsed_s = t0.elapsed().as_secs_f64();
7563 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
7564 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
7565 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7566 log::info!(
7567 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7568 kphase_efs_calls.get(),
7569 design_revision,
7570 theta_norm,
7571 log_kappa_norm,
7572 elapsed_s,
7573 );
7574 let eval = eval_result.map_err(EstimationError::RemlOptimizationFailed)?;
7575 Ok(eval)
7576 },
7577 ),
7578 );
7579 let mut obj = obj.with_seed_inner_state(
7580 move |_ctx: &mut &mut NBlockExactJointState<'_>, beta: &Array1<f64>| {
7581 (seed_inner_beta_fn)(beta)
7582 },
7583 );
7584
7585 match problem.run(&mut obj, "n-block exact-joint spatial") {
7586 Ok(result) => result,
7587 Err(e) => {
7588 let message = e.to_string();
7589 if kappa_phase_failure_is_fixed_kappa_recoverable(&message) {
7609 drop(obj);
7610 log::warn!(
7611 "[KAPPA-PHASE] length-scale optimization could not validate any seed \
7612 ({message}); falling back to a FIXED bootstrap κ (skipping κ \
7613 optimization) and fitting there — a real model at the initial \
7614 length-scale rather than raising (gam#787/#860)."
7615 );
7616 let (designs, resolved_specs) =
7617 build_term_collection_designs_and_freeze_joint(data, block_specs).map_err(
7618 |build_err| {
7619 format!(
7620 "fixed-κ fallback failed to build and freeze joint block \
7621 designs after κ optimization could not validate a seed \
7622 ({message}): {build_err}"
7623 )
7624 },
7625 )?;
7626 let fixed_theta0 = joint_setup.theta0();
7627 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7628 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7629 let fit = fit_fn(&fixed_theta0, &spec_refs, &design_refs)?;
7630 return Ok(SpatialLengthScaleOptimizationResult {
7631 resolved_specs,
7632 designs,
7633 fit,
7634 timing: None,
7635 });
7636 }
7637 return Err(message);
7638 }
7639 }
7640 }; let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
7650 log::info!(
7651 "[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} optim_total_s={:.4}",
7652 kphase_log_kappa_dim,
7653 kphase_cost_calls.get(),
7654 kphase_cost_total_s.get(),
7655 kphase_eval_calls.get(),
7656 kphase_eval_total_s.get(),
7657 kphase_efs_calls.get(),
7658 kphase_efs_total_s.get(),
7659 kphase_total_s,
7660 );
7661 let timing = SpatialLengthScaleOptimizationTiming {
7662 log_kappa_dim: kphase_log_kappa_dim,
7663 cost_calls: kphase_cost_calls.get(),
7664 cost_total_s: kphase_cost_total_s.get(),
7665 eval_calls: kphase_eval_calls.get(),
7666 eval_total_s: kphase_eval_total_s.get(),
7667 efs_calls: kphase_efs_calls.get(),
7668 efs_total_s: kphase_efs_total_s.get(),
7669 slow_path_resets: 0,
7670 design_revision_delta: 0,
7671 nfree_miss_shape: 0,
7672 nfree_miss_value: 0,
7673 nfree_miss_gradient: 0,
7674 nfree_miss_penalty: 0,
7675 nfree_miss_revision: 0,
7676 nfree_miss_second_order: 0,
7677 nfree_miss_other: 0,
7678 optim_total_s: kphase_total_s,
7679 };
7680
7681 let theta_star = result.rho;
7682
7683 if use_staged_kappa && n_total >= KAPPA_POLISH_TRIGGER_N {
7700 let polish = build_uniform_pilot_subsample(
7701 n_total,
7702 KAPPA_POLISH_K,
7703 (n_total as u64).wrapping_add(0xA5A5A5A5),
7704 );
7705 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::Subsample {
7706 rows: std::sync::Arc::clone(&polish.rows),
7707 n_full: n_total,
7708 };
7709 log::info!(
7710 "[KAPPA-STAGED] rotating to polish subsample: k={} at theta_star",
7711 polish.rows.len(),
7712 );
7713 state.cache.ensure_theta(&theta_star)?;
7717 let (polish_cost, polish_grad, _) = {
7718 let specs = collect_specs(&state.cache);
7719 let designs = collect_designs(&state.cache);
7720 let row_set_borrow = current_row_set.borrow();
7721 exact_fn(
7722 &theta_star,
7723 &specs,
7724 &designs,
7725 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
7726 &row_set_borrow,
7727 )?
7728 };
7729 if !polish_cost.is_finite() || polish_grad.iter().any(|value| !value.is_finite()) {
7730 return Err(
7731 "polish subsample exact-joint evaluation produced non-finite objective pieces"
7732 .to_string(),
7733 );
7734 }
7735 }
7736 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::All;
7737 if use_staged_kappa {
7738 log::info!(
7739 "[KAPPA-STAGED] rotating to full data for final coefficient fit (n={})",
7740 n_total,
7741 );
7742 }
7743
7744 state.cache.ensure_theta(&theta_star)?;
7745
7746 let resolved_specs: Vec<TermCollectionSpec> = collect_specs(&state.cache);
7747 let designs: Vec<TermCollectionDesign> = collect_designs(&state.cache);
7748
7749 let fit = fit_fn(&theta_star, &resolved_specs, &designs)?;
7750
7751 for spec in &resolved_specs {
7752 log_spatial_aniso_scales(spec);
7753 }
7754
7755 Ok(SpatialLengthScaleOptimizationResult {
7756 resolved_specs,
7757 designs,
7758 fit,
7759 timing: Some(timing),
7760 })
7761}
7762
7763fn try_exact_joint_latent_coord_optimization(
7764 data: ArrayView2<'_, f64>,
7765 y: ArrayView1<'_, f64>,
7766 weights: ArrayView1<'_, f64>,
7767 offset: ArrayView1<'_, f64>,
7768 resolvedspec: &TermCollectionSpec,
7769 best: &FittedTermCollection,
7770 family: LikelihoodSpec,
7771 options: &FitOptions,
7772 latent: &StandardLatentCoordConfig,
7773) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7774 use gam_solve::rho_optimizer::OuterEvalOrder;
7775 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7776
7777 let rho_dim = best.fit.lambdas.len();
7778 let latent_flat_dim = latent.values.len();
7779 if latent_flat_dim == 0 {
7780 crate::bail_invalid_estim!(
7781 "latent-coordinate optimization requires a non-empty latent block"
7782 );
7783 }
7784 let direct_hypers =
7785 latent_coord_initial_direct_hypers(latent.values.id_mode(), latent.values.latent_dim())?;
7786 let analytic_rho_count = latent
7787 .analytic_penalties
7788 .as_ref()
7789 .map_or(0, |registry| registry.total_rho_count());
7790 let latent_coord_ext_dim = latent_flat_dim + analytic_rho_count + direct_hypers.len();
7791
7792 let mut theta0 = Array1::<f64>::zeros(rho_dim + latent_coord_ext_dim);
7793 theta0
7794 .slice_mut(s![..rho_dim])
7795 .assign(&best.fit.lambdas.mapv(f64::ln));
7796 theta0
7797 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
7798 .assign(latent.values.as_flat());
7799 if !direct_hypers.is_empty() {
7800 let direct_start = rho_dim + latent_flat_dim + analytic_rho_count;
7801 theta0
7802 .slice_mut(s![direct_start..direct_start + direct_hypers.len()])
7803 .assign(&direct_hypers);
7804 }
7805
7806 let mut lower = Array1::<f64>::from_elem(theta0.len(), -12.0);
7807 let mut upper = Array1::<f64>::from_elem(theta0.len(), 12.0);
7808 let latent_bound = latent
7809 .values
7810 .as_flat()
7811 .iter()
7812 .fold(1.0_f64, |acc, &v| acc.max(v.abs()))
7813 + 10.0;
7814 for axis in rho_dim..rho_dim + latent_flat_dim {
7815 lower[axis] = -latent_bound;
7816 upper[axis] = latent_bound;
7817 }
7818
7819 struct LatentJointContext<'d> {
7820 rho_dim: usize,
7821 cache: SingleBlockLatentCoordDesignCache,
7822 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
7823 }
7824
7825 impl<'d> LatentJointContext<'d> {
7826 fn eval_full(
7827 &mut self,
7828 theta: &Array1<f64>,
7829 order: OuterEvalOrder,
7830 ) -> Result<
7831 (
7832 f64,
7833 Array1<f64>,
7834 gam_problem::HessianResult,
7835 ),
7836 EstimationError,
7837 > {
7838 if let Some(eval) = self.cache.memoized_eval(theta) {
7839 return Ok(eval);
7840 }
7841 self.cache
7842 .ensure_theta(theta)
7843 .map_err(EstimationError::InvalidInput)?;
7844 let hyper_dirs = self
7845 .cache
7846 .hyper_dirs()
7847 .map_err(EstimationError::InvalidInput)?;
7848 let design_revision = Some(self.cache.design_revision());
7849 let registry_for_key = self.cache.analytic_penalties();
7850 self.evaluator
7851 .set_analytic_penalty_registry(registry_for_key.as_deref());
7852 let mut eval = evaluate_joint_reml_outer_eval_at_theta(
7853 &mut self.evaluator,
7854 self.cache.design(),
7855 theta,
7856 self.rho_dim,
7857 hyper_dirs,
7858 None,
7859 order,
7860 design_revision,
7861 )?;
7862 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
7863 if let Some(registry) = registry_for_key {
7864 let mut registry = registry.as_ref().clone();
7865 registry.apply_weight_schedules(
7866 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
7867 );
7868 add_analytic_penalty_objective_to_eval(
7869 theta,
7870 self.rho_dim,
7871 latent.as_ref(),
7872 ®istry,
7873 &mut eval,
7874 )?;
7875 }
7876 add_latent_id_objective_to_eval(
7877 theta,
7878 self.rho_dim,
7879 self.cache.analytic_penalty_rho_count(),
7880 latent.as_ref(),
7881 &mut eval,
7882 )?;
7883 self.cache.store_eval(eval.clone());
7884 Ok(eval)
7885 }
7886
7887 fn eval_efs(
7888 &mut self,
7889 theta: &Array1<f64>,
7890 ) -> Result<gam_problem::EfsEval, EstimationError> {
7891 self.cache
7892 .ensure_theta(theta)
7893 .map_err(EstimationError::InvalidInput)?;
7894 let hyper_dirs = self
7895 .cache
7896 .hyper_dirs()
7897 .map_err(EstimationError::InvalidInput)?;
7898 let registry_for_key = self.cache.analytic_penalties();
7899 self.evaluator
7900 .set_analytic_penalty_registry(registry_for_key.as_deref());
7901 let mut efs = evaluate_joint_reml_efs_at_theta(
7902 &mut self.evaluator,
7903 self.cache.design(),
7904 theta,
7905 self.rho_dim,
7906 hyper_dirs,
7907 None,
7908 Some(self.cache.design_revision()),
7909 )?;
7910 if let Some(registry) = registry_for_key {
7911 let mut registry = registry.as_ref().clone();
7912 registry.apply_weight_schedules(
7913 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
7914 );
7915 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
7916 let contribution = analytic_penalty_objective_contribution(
7917 theta,
7918 self.rho_dim,
7919 latent.as_ref(),
7920 ®istry,
7921 )?;
7922 efs.cost += contribution.cost;
7923 if let (Some(psi_gradient), Some(psi_indices)) =
7924 (efs.psi_gradient.as_mut(), efs.psi_indices.as_ref())
7925 {
7926 if psi_gradient.len() != psi_indices.len() {
7927 crate::bail_invalid_estim!(
7928 "latent-coordinate analytic penalty EFS psi gradient length mismatch: gradient={}, indices={}",
7929 psi_gradient.len(),
7930 psi_indices.len()
7931 );
7932 }
7933 for (local_idx, &theta_idx) in psi_indices.iter().enumerate() {
7934 psi_gradient[local_idx] += contribution.gradient[theta_idx];
7935 }
7936 }
7937 }
7938 Ok(efs)
7939 }
7940
7941 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
7942 if let Some(cost) = self.cache.memoized_cost(theta) {
7943 return cost;
7944 }
7945 if self.cache.ensure_theta(theta).is_err() {
7946 return f64::INFINITY;
7947 }
7948 let design_revision = Some(self.cache.design_revision());
7949 let registry_for_key = self.cache.analytic_penalties();
7950 self.evaluator
7951 .set_analytic_penalty_registry(registry_for_key.as_deref());
7952 let result = {
7953 let design = self.cache.design();
7954 self.evaluator.evaluate_cost_only(
7955 &design.design,
7956 &design.penalties,
7957 &design.nullspace_dims,
7958 design.linear_constraints.clone(),
7959 theta,
7960 self.rho_dim,
7961 None,
7962 "latent-coordinate-joint cost-only",
7963 design_revision,
7964 )
7965 };
7966 match result {
7967 Ok(cost) => {
7968 let latent = match self.cache.latent() {
7969 Ok(latent) => latent,
7970 Err(_) => return f64::INFINITY,
7971 };
7972 let contribution = match latent_id_objective_contribution(
7973 theta,
7974 self.rho_dim,
7975 self.cache.analytic_penalty_rho_count(),
7976 latent.as_ref(),
7977 ) {
7978 Ok(contribution) => contribution,
7979 Err(_) => return f64::INFINITY,
7980 };
7981 let cost = cost + contribution.cost;
7982 let cost = if let Some(registry) = registry_for_key {
7983 let mut registry = registry.as_ref().clone();
7984 registry.apply_weight_schedules(
7985 gam_solve::estimate::reml::outer_eval::current_outer_iter()
7986 as usize,
7987 );
7988 match analytic_penalty_objective_contribution(
7989 theta,
7990 self.rho_dim,
7991 latent.as_ref(),
7992 ®istry,
7993 ) {
7994 Ok(contribution) => cost + contribution.cost,
7995 Err(_) => return f64::INFINITY,
7996 }
7997 } else {
7998 cost
7999 };
8000 self.cache.store_cost(cost);
8001 cost
8002 }
8003 Err(_) => f64::INFINITY,
8004 }
8005 }
8006 }
8007
8008 let mut ctx = LatentJointContext {
8009 rho_dim,
8010 cache: SingleBlockLatentCoordDesignCache::new(
8011 data.to_owned(),
8012 resolvedspec.clone(),
8013 best.design.clone(),
8014 latent,
8015 rho_dim,
8016 )
8017 .map_err(EstimationError::InvalidInput)?,
8018 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
8019 y,
8020 weights,
8021 &best.design.design,
8022 offset,
8023 &best.design.penalties,
8024 &external_opts_for_design(&family, &best.design, options),
8025 "latent-coordinate-joint",
8026 )?,
8027 };
8028 let registry_for_key = ctx.cache.analytic_penalties();
8029 ctx.evaluator
8030 .set_analytic_penalty_registry(registry_for_key.as_deref());
8031 ctx.evaluator
8032 .set_persistent_latent_values_fingerprint(latent.values.id_mode());
8033 if let Some(cached_t) = ctx
8034 .evaluator
8035 .load_persistent_latent_values(latent.values.n_obs(), latent.values.latent_dim())
8036 {
8037 let cached_t: Array2<f64> = cached_t;
8038 for (dst, src) in theta0
8039 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
8040 .iter_mut()
8041 .zip(cached_t.iter())
8042 {
8043 *dst = *src;
8044 }
8045 }
8046
8047 let problem = exact_joint_multistart_outer_problem(
8048 &theta0,
8049 &lower,
8050 &upper,
8051 rho_dim,
8052 latent_coord_ext_dim,
8053 theta0.len(),
8054 Derivative::Analytic,
8055 DeclaredHessianForm::Unavailable,
8056 false,
8057 false,
8058 seed_risk_profile_for_likelihood_family(&family),
8059 options.tol,
8060 options.max_iter.max(1),
8061 Some(5.0),
8062 Some(0.5),
8063 None,
8064 Some((data.nrows(), best.design.design.ncols().max(1))),
8067 !constant_curvature_term_indices(resolvedspec).is_empty(),
8070 );
8071
8072 let eval_outer = |ctx: &mut &mut LatentJointContext<'_>,
8073 theta: &Array1<f64>,
8074 order: OuterEvalOrder|
8075 -> Result<OuterEval, EstimationError> {
8076 let (cost, gradient, hessian) = ctx.eval_full(theta, order)?;
8077 Ok(OuterEval {
8078 cost,
8079 gradient,
8080 hessian,
8081 inner_beta_hint: None,
8082 })
8083 };
8084
8085 let result = {
8086 let mut obj = problem.build_objective_with_eval_order(
8087 &mut ctx,
8088 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
8089 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| {
8090 eval_outer(ctx, theta, OuterEvalOrder::ValueAndGradient)
8091 },
8092 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
8093 eval_outer(ctx, theta, order)
8094 },
8095 Some(|ctx: &mut &mut LatentJointContext<'_>| {
8096 ctx.cache.reset();
8097 }),
8098 Some(|ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
8099 );
8100
8101 problem
8102 .run(&mut obj, "latent-coordinate joint REML")
8103 .map_err(|e| {
8104 EstimationError::InvalidInput(format!(
8105 "latent-coordinate joint optimization failed after exhausting strategy fallbacks: {e}"
8106 ))
8107 })?
8108 };
8109 if !result.converged {
8110 crate::bail_invalid_estim!(
8111 "latent-coordinate joint optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
8112 result.iterations,
8113 result.final_value,
8114 result.final_grad_norm_report(),
8115 );
8116 }
8117
8118 let theta_star = result.rho;
8119 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
8120 let mut final_data = data.to_owned();
8121 let flat_t = theta_star
8122 .slice(s![rho_dim..rho_dim + latent_flat_dim])
8123 .to_owned();
8124 let mut fitted_latent_values =
8125 Array2::<f64>::zeros((latent.values.n_obs(), latent.values.latent_dim()));
8126 for n in 0..latent.values.n_obs() {
8127 for axis in 0..latent.values.latent_dim() {
8128 let value = flat_t[n * latent.values.latent_dim() + axis];
8129 fitted_latent_values[[n, axis]] = value;
8130 final_data[[n, latent.feature_cols[axis]]] = value;
8131 }
8132 }
8133 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
8134 final_data.view(),
8135 y,
8136 weights,
8137 offset,
8138 resolvedspec,
8139 rho_star.as_slice(),
8140 family,
8141 options,
8142 )?;
8143 ctx.evaluator
8144 .store_persistent_latent_values(&fitted_latent_values);
8145 let mut fit = optimized.fit;
8146 fit.reml_score = result.final_value;
8147 fit.penalized_objective = result.final_value;
8148 Ok(FittedTermCollectionWithSpec {
8149 fit,
8150 design: optimized.design,
8151 resolvedspec: resolvedspec.clone(),
8152 adaptive_diagnostics: optimized.adaptive_diagnostics,
8153 kappa_timing: None,
8154 })
8155}
8156
8157pub fn fit_term_collectionwith_latent_coord_optimization(
8158 data: ArrayView2<'_, f64>,
8159 y: Array1<f64>,
8160 weights: Array1<f64>,
8161 offset: Array1<f64>,
8162 spec: &TermCollectionSpec,
8163 latent: &StandardLatentCoordConfig,
8164 family: LikelihoodSpec,
8165 options: &FitOptions,
8166) -> Result<FittedTermCollectionWithSpec, EstimationError> {
8167 let n = data.nrows();
8168 if !(y.len() == n && weights.len() == n && offset.len() == n) {
8169 crate::bail_invalid_estim!(
8170 "fit_term_collectionwith_latent_coord_optimization row mismatch: n={}, y={}, weights={}, offset={}",
8171 n,
8172 y.len(),
8173 weights.len(),
8174 offset.len()
8175 );
8176 }
8177 let best = fit_term_collection_forspec(
8178 data,
8179 y.view(),
8180 weights.view(),
8181 offset.view(),
8182 spec,
8183 family.clone(),
8184 options,
8185 )?;
8186 let resolvedspec = freeze_term_collection_from_design(spec, &best.design)?;
8187 try_exact_joint_latent_coord_optimization(
8188 data,
8189 y.view(),
8190 weights.view(),
8191 offset.view(),
8192 &resolvedspec,
8193 &best,
8194 family,
8195 options,
8196 latent,
8197 )
8198}
8199
8200pub fn fit_term_collectionwith_spatial_length_scale_optimization(
8201 data: ArrayView2<'_, f64>,
8202 y: Array1<f64>,
8203 weights: Array1<f64>,
8204 offset: Array1<f64>,
8205 spec: &TermCollectionSpec,
8206 family: LikelihoodSpec,
8207 options: &FitOptions,
8208 kappa_options: &SpatialLengthScaleOptimizationOptions,
8209) -> Result<FittedTermCollectionWithSpec, EstimationError> {
8210 let mut resolvedspec = spec.clone();
8226 let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8227 let n = data.nrows();
8228 if !(y.len() == n && weights.len() == n && offset.len() == n) {
8229 crate::bail_invalid_estim!(
8230 "fit_term_collectionwith_spatial_length_scale_optimization row mismatch: n={}, y={}, weights={}, offset={}",
8231 n,
8232 y.len(),
8233 weights.len(),
8234 offset.len()
8235 );
8236 }
8237 if !kappa_options.enabled || spatial_terms.is_empty() {
8238 let out = fit_term_collection_forspec(
8239 data,
8240 y.view(),
8241 weights.view(),
8242 offset.view(),
8243 &resolvedspec,
8244 family,
8245 options,
8246 )?;
8247 let resolvedspec = freeze_term_collection_from_design(&resolvedspec, &out.design)?;
8248 return Ok(FittedTermCollectionWithSpec {
8249 fit: out.fit,
8250 design: out.design,
8251 resolvedspec,
8252 adaptive_diagnostics: out.adaptive_diagnostics,
8253 kappa_timing: None,
8254 });
8255 }
8256 if kappa_options.max_outer_iter == 0 {
8257 crate::bail_invalid_estim!("spatial kappa optimization requires max_outer_iter >= 1");
8258 }
8259 if !(kappa_options.log_step.is_finite() && kappa_options.log_step > 0.0) {
8260 crate::bail_invalid_estim!("spatial kappa optimization requires log_step > 0");
8261 }
8262 if !(kappa_options.min_length_scale.is_finite()
8263 && kappa_options.max_length_scale.is_finite()
8264 && kappa_options.min_length_scale > 0.0
8265 && kappa_options.max_length_scale >= kappa_options.min_length_scale)
8266 {
8267 crate::bail_invalid_estim!(
8268 "spatial kappa optimization requires valid positive length_scale bounds"
8269 );
8270 }
8271
8272 let pilot_threshold = kappa_options.pilot_subsample_threshold;
8273 if pilot_threshold > 0 && n > pilot_threshold * 2 {
8274 log::info!(
8275 "[spatial-kappa] n={n} exceeds pilot threshold {}; using pilot geometry only for deterministic anisotropy initialization",
8276 pilot_threshold * 2,
8277 );
8278 apply_spatial_anisotropy_pilot_initializer(
8279 data,
8280 &mut resolvedspec,
8281 &spatial_terms,
8282 pilot_threshold,
8283 kappa_options,
8284 );
8285 }
8286
8287 apply_response_aware_anisotropy_seed(data, y.view(), &mut resolvedspec, &spatial_terms);
8296
8297 for term_idx in constant_curvature_term_indices(&resolvedspec) {
8315 if let Some(kappa_seed) =
8316 select_constant_curvature_kappa_sign_seed(data, y.view(), &resolvedspec, term_idx)
8317 && kappa_seed != 0.0
8318 && let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) =
8319 resolvedspec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis)
8320 {
8321 log::info!(
8322 "[#1464] pinned CC term {term_idx} baseline κ to κ-fair scan value {kappa_seed} \
8323 (raw profiled REML is sign-blind; scan is authoritative for the sign)"
8324 );
8325 cc.kappa = kappa_seed;
8326 }
8327 }
8328
8329 let baseline_options = superseded_fit_options(options);
8330 let mut best = fit_term_collection_forspec(
8331 data,
8332 y.view(),
8333 weights.view(),
8334 offset.view(),
8335 &resolvedspec,
8336 family.clone(),
8337 &baseline_options,
8338 )?;
8339 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8340 let mut spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8350 sync_aniso_contrasts_from_metadata(&mut resolvedspec, &best.design.smooth);
8354 let mut prescan_improved = false;
8361 if !spatial_terms.is_empty() {
8362 let baseline_score = fit_score(&best.fit);
8363 let range_overrides = prescan_isotropic_spatial_range_seed(
8364 data,
8365 y.view(),
8366 weights.view(),
8367 offset.view(),
8368 &resolvedspec,
8369 baseline_score,
8370 &family,
8371 &baseline_options,
8372 kappa_options,
8373 &spatial_terms,
8374 )?;
8375 if !range_overrides.is_empty() {
8376 prescan_improved = true;
8377 for (term_idx, length_scale) in range_overrides {
8378 set_spatial_length_scale(&mut resolvedspec, term_idx, length_scale)?;
8379 }
8380 best = fit_term_collection_forspec(
8384 data,
8385 y.view(),
8386 weights.view(),
8387 offset.view(),
8388 &resolvedspec,
8389 family.clone(),
8390 &baseline_options,
8391 )?;
8392 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8393 spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8397 }
8398 }
8399 if spatial_terms.is_empty() {
8400 let fitted = fit_term_collection_forspecwith_heuristic_lambdas(
8401 data,
8402 y.view(),
8403 weights.view(),
8404 offset.view(),
8405 &resolvedspec,
8406 best.fit.lambdas.as_slice(),
8407 family,
8408 options,
8409 )?;
8410 return Ok(FittedTermCollectionWithSpec {
8411 fit: fitted.fit,
8412 design: fitted.design,
8413 resolvedspec,
8414 adaptive_diagnostics: fitted.adaptive_diagnostics,
8415 kappa_timing: None,
8416 });
8417 }
8418 let initial_score = fit_score(&best.fit);
8419 if !initial_score.is_finite() {
8420 log::debug!("[spatial-kappa] initial profiled score is non-finite");
8421 }
8422 let seed_length_scales: Vec<(usize, f64)> = spatial_terms
8429 .iter()
8430 .filter_map(|&t| get_spatial_length_scale(&resolvedspec, t).map(|ls| (t, ls)))
8431 .collect();
8432 let joint_result = try_exact_joint_spatial_length_scale_optimization(
8433 data,
8434 y.view(),
8435 weights.view(),
8436 offset.view(),
8437 &resolvedspec,
8438 &best,
8439 family.clone(),
8440 options,
8441 kappa_options,
8442 &spatial_terms,
8443 )
8444 .map(|opt| {
8445 opt.map(|fit| {
8446 let score = fit_score(&fit.fit);
8447 (fit, score)
8448 })
8449 });
8450 let exact_joint = if prescan_improved && !matches!(joint_result, Ok(Some(_))) {
8460 let reason = match &joint_result {
8461 Err(e) => format!("error: {e}"),
8462 _ => "unavailable".to_string(),
8463 };
8464 log::info!(
8465 "[spatial-kappa] #1074 joint polish yielded no usable candidate \
8466 ({reason}); returning the multi-start pre-scan geometry (REML {initial_score:.5})"
8467 );
8468 FittedTermCollectionWithSpec {
8469 fit: best.fit,
8470 design: best.design,
8471 resolvedspec,
8472 adaptive_diagnostics: best.adaptive_diagnostics,
8473 kappa_timing: None,
8474 }
8475 } else {
8476 require_successful_spatial_optimization_result(initial_score, joint_result)?
8477 };
8478
8479 let exact_joint = {
8506 let primary_score = fit_score(&exact_joint.fit);
8507 let improved = primary_score.is_finite()
8508 && initial_score.is_finite()
8509 && primary_score < initial_score - 1e-7 * initial_score.abs().max(1.0);
8510 let base_spec = exact_joint.resolvedspec.clone();
8515 let geometry_unchanged = !seed_length_scales.is_empty()
8518 && seed_length_scales.iter().all(|&(t, seed_ls)| {
8519 get_spatial_length_scale(&base_spec, t)
8520 .is_some_and(|ls| (ls - seed_ls).abs() <= 1e-6 * seed_ls.abs().max(1.0))
8521 });
8522 let eligible = !improved
8523 && geometry_unchanged
8524 && !has_aniso_terms(&base_spec, &spatial_terms)
8525 && constant_curvature_term_indices(&base_spec).is_empty()
8526 && spatial_terms
8527 .iter()
8528 .any(|&t| get_spatial_length_scale(&base_spec, t).is_some());
8529 if eligible {
8530 log::info!(
8531 "[spatial-kappa] #1688 joint solve stalled at REML {primary_score:.5} \
8532 (no improvement over baseline {initial_score:.5}); running ψ-window \
8533 multistart rescue across {} seeds",
8534 JOINT_RESTART_WINDOW_FRACTIONS.len(),
8535 );
8536 let mut best_fit = exact_joint;
8537 let mut best_score = primary_score;
8539 for &fraction in JOINT_RESTART_WINDOW_FRACTIONS.iter() {
8540 match joint_solve_from_window_fraction(
8541 data,
8542 y.view(),
8543 weights.view(),
8544 offset.view(),
8545 &base_spec,
8546 &spatial_terms,
8547 fraction,
8548 &family,
8549 options,
8550 &baseline_options,
8551 kappa_options,
8552 ) {
8553 Ok(Some((candidate, score))) => {
8554 if score.is_finite()
8555 && (!best_score.is_finite()
8556 || score < best_score - 1e-7 * best_score.abs().max(1.0))
8557 {
8558 log::info!(
8559 "[spatial-kappa] #1688 multistart seed (ψ-window {fraction:.2}) \
8560 reached REML {score:.5}, improving on {best_score:.5}",
8561 );
8562 best_score = score;
8563 best_fit = candidate;
8564 }
8565 }
8566 Ok(None) => {}
8568 Err(e) => {
8572 log::info!(
8573 "[spatial-kappa] #1688 multistart seed (ψ-window {fraction:.2}) \
8574 failed ({e}); skipping"
8575 );
8576 }
8577 }
8578 }
8579 best_fit
8580 } else {
8581 exact_joint
8582 }
8583 };
8584
8585 log_spatial_aniso_scales(&exact_joint.resolvedspec);
8586 Ok(exact_joint)
8587}
8588
/// Result of `curvature_inference_forspec` for one constant-curvature smooth term.
#[derive(Clone, Debug)]
pub struct CurvatureInference {
    /// Index of the constant-curvature term in the spec's smooth terms.
    pub term_idx: usize,
    /// Curvature κ read from the resolved spec (not re-estimated by the inference routine).
    pub kappa_hat: f64,
    /// Profile confidence interval for κ, produced by `profile_ci_walk`.
    pub ci: gam_geometry::curvature_estimand::KappaProfileCi,
    /// Flatness test at κ̂, produced by `flatness_lr_test`.
    pub flatness: gam_geometry::curvature_estimand::FlatnessTest,
}
8608
8609pub fn curvature_inference_forspec(
8627 data: ArrayView2<'_, f64>,
8628 y: ArrayView1<'_, f64>,
8629 weights: ArrayView1<'_, f64>,
8630 offset: ArrayView1<'_, f64>,
8631 resolvedspec: &TermCollectionSpec,
8632 term_idx: usize,
8633 family: LikelihoodSpec,
8634 options: &FitOptions,
8635 level: f64,
8636) -> Result<CurvatureInference, EstimationError> {
8637 let kappa_hat = get_constant_curvature_kappa(resolvedspec, term_idx).ok_or_else(|| {
8638 EstimationError::InvalidInput(format!(
8639 "curvature_inference_forspec: term {term_idx} is not a constant-curvature smooth"
8640 ))
8641 })?;
8642 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
8643
8644 let cc_fair_inputs: Option<(Array2<f64>, gam_terms::basis::ConstantCurvatureBasisSpec)> =
8669 if kappa_hat < 0.0 {
8670 match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
8671 Some(SmoothBasisSpec::ConstantCurvature {
8672 feature_cols, spec, ..
8673 }) => select_columns(data, feature_cols)
8674 .ok()
8675 .map(|x| (x, spec.clone())),
8676 _ => None,
8677 }
8678 } else {
8679 None
8680 };
8681
8682 let v_p_cache: std::cell::RefCell<std::collections::HashMap<u64, f64>> =
8687 std::cell::RefCell::new(std::collections::HashMap::new());
8688 let v_p = |kappa: f64| -> Result<f64, String> {
8689 if !kappa.is_finite() {
8690 return Err(format!("V_p probed a non-finite κ = {kappa}"));
8691 }
8692 let key = kappa.to_bits();
8693 if let Some(&cached) = v_p_cache.borrow().get(&key) {
8694 return Ok(cached);
8695 }
8696 let score = if let Some((x_term, base_spec)) = &cc_fair_inputs {
8697 let mut probe_spec = base_spec.clone();
8698 probe_spec.kappa = kappa;
8699 gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec)
8700 .map_err(|e| format!("κ-fair criterion at κ={kappa} failed: {e}"))?
8701 } else {
8702 fixed_kappa_profiled_reml_score(
8703 data,
8704 y,
8705 weights,
8706 offset,
8707 resolvedspec,
8708 term_idx,
8709 kappa,
8710 family.clone(),
8711 options,
8712 )
8713 .map_err(|e| format!("V_p fixed-κ fit at κ={kappa} failed: {e}"))?
8714 };
8715 v_p_cache.borrow_mut().insert(key, score);
8716 Ok(score)
8717 };
8718
8719 let h = (1e-3 * (kappa_max - kappa_min)).max(1e-4);
8723 let v_pp = match (v_p(kappa_hat + h), v_p(kappa_hat), v_p(kappa_hat - h)) {
8724 (Ok(vp), Ok(v0), Ok(vm)) => (vp - 2.0 * v0 + vm) / (h * h),
8725 _ => f64::NAN, };
8727
8728 let ci = gam_geometry::curvature_estimand::profile_ci_walk(
8729 &v_p, kappa_hat, v_pp, kappa_min, kappa_max, level, 1e-4,
8730 )
8731 .map_err(EstimationError::InvalidInput)?;
8732 let flatness = gam_geometry::curvature_estimand::flatness_lr_test(&v_p, kappa_hat)
8733 .map_err(EstimationError::InvalidInput)?;
8734
8735 Ok(CurvatureInference {
8736 term_idx,
8737 kappa_hat,
8738 ci,
8739 flatness,
8740 })
8741}
8742
/// Which small-sample correction, if any, was applied to a smooth-term likelihood-ratio test.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmoothLrCorrection {
    /// Lawley-type LR correction with the smoothing parameters treated as estimated
    /// (presumably including the ρ-variation term — confirm at the computation site).
    LawleyLrEstimatedLambda,
    /// Lawley-type LR correction with the smoothing parameters treated as fixed.
    LawleyLrFixedLambda,
    /// No correction; the corrected quantities equal the uncorrected ones.
    None,
}

impl SmoothLrCorrection {
    /// Stable snake_case identifier for this correction, suitable for reports.
    pub fn label(self) -> &'static str {
        match self {
            Self::None => "none",
            Self::LawleyLrFixedLambda => "lawley_lr_fixed_lambda",
            Self::LawleyLrEstimatedLambda => "lawley_lr_estimated_lambda",
        }
    }
}
8772
/// Per-term likelihood-ratio test of one smooth term against the collection with that term
/// removed, with an optional small-sample correction.
/// NOTE(review): field semantics below are inferred from `smooth_term_lr_inference_forspec`;
/// confirm against its final assignment code.
#[derive(Clone, Debug)]
pub struct SmoothTermLrInference {
    /// Name of the tested smooth term.
    pub name: String,
    /// Index of the term among the design's smooth terms (presumably `design.smooth.terms`).
    pub term_idx: usize,
    /// LR statistic, presumably `2 (ℓ_full − ℓ_null)` floored at zero (NaN when the null fit
    /// fails).
    pub statistic_lr: f64,
    /// Reference degrees of freedom for the χ² null distribution.
    pub ref_df: f64,
    /// Multiplicative Bartlett-type factor applied to the statistic (1.0 when uncorrected).
    pub bartlett_factor: f64,
    /// Bartlett factor conditional on fixed smoothing parameters, when computed.
    pub bartlett_factor_conditional: Option<f64>,
    /// Mean shift attributed to smoothing-parameter (ρ) variation, when computed.
    pub rho_variation_shift: Option<f64>,
    /// Statistic after applying `correction`.
    pub statistic_corrected: f64,
    /// Upper-tail χ²(`ref_df`) p-value of `statistic_lr` (NaN when unavailable).
    pub p_value_uncorrected: f64,
    /// Upper-tail p-value of `statistic_corrected`.
    pub p_value_corrected: f64,
    /// Whether the correction changed the result materially (presumably judged against
    /// `SMOOTH_LR_MATERIAL_THRESHOLD`).
    pub material: bool,
    /// Which correction produced the corrected fields.
    pub correction: SmoothLrCorrection,
}
8818
/// Threshold for flagging a smooth-term LR correction as material.
/// NOTE(review): presumably compared with the change the correction makes to the p-value or
/// statistic when setting `SmoothTermLrInference::material` — confirm at the use site.
pub const SMOOTH_LR_MATERIAL_THRESHOLD: f64 = 0.10;
8823
8824fn fitted_rho_penalty_components(
8830 penalties: &[BlockwisePenalty],
8831 lambdas: &[f64],
8832 p_total: usize,
8833) -> Result<Vec<gam_terms::inference::lawley::RhoPenaltyComponent>, EstimationError> {
8834 if penalties.len() != lambdas.len() {
8835 return Err(EstimationError::InvalidInput(format!(
8836 "smooth_term_lr_inference: penalty/lambda count mismatch ({} penalties, {} lambdas)",
8837 penalties.len(),
8838 lambdas.len()
8839 )));
8840 }
8841 let mut components = Vec::with_capacity(penalties.len());
8842 for (idx, (penalty, &lambda)) in penalties.iter().zip(lambdas.iter()).enumerate() {
8843 if !(lambda.is_finite() && lambda >= 0.0) {
8844 return Err(EstimationError::InvalidInput(format!(
8845 "smooth_term_lr_inference: lambda[{idx}] is invalid: {lambda}"
8846 )));
8847 }
8848 let r = &penalty.col_range;
8849 if r.end > p_total {
8850 return Err(EstimationError::InvalidInput(format!(
8851 "smooth_term_lr_inference: penalty[{idx}] range {:?} exceeds coefficient dimension {p_total}",
8852 r
8853 )));
8854 }
8855 let mut s_component = Array2::<f64>::zeros((p_total, p_total));
8856 s_component
8857 .slice_mut(s![r.start..r.end, r.start..r.end])
8858 .scaled_add(lambda, &penalty.local);
8859 components.push(gam_terms::inference::lawley::RhoPenaltyComponent { s_component });
8860 }
8861 Ok(components)
8862}
8863
8864pub fn smooth_term_lr_inference_forspec(
8909 data: ArrayView2<'_, f64>,
8910 y: ArrayView1<'_, f64>,
8911 weights: ArrayView1<'_, f64>,
8912 offset: ArrayView1<'_, f64>,
8913 resolvedspec: &TermCollectionSpec,
8914 family: LikelihoodSpec,
8915 options: &FitOptions,
8916) -> Result<Vec<SmoothTermLrInference>, EstimationError> {
8917 use gam_terms::inference::lawley::{
8918 LAWLEY_PAIR_MATRIX_MAX_ROWS, known_scale_expected_jets_with_dispersion,
8919 lawley_lr_bartlett_factor, lawley_lr_mean_shift_with_rho_variation,
8920 };
8921
8922 let n = data.nrows();
8923 let full = fit_term_collection_forspec(
8926 data,
8927 y,
8928 weights,
8929 offset,
8930 resolvedspec,
8931 family.clone(),
8932 options,
8933 )?;
8934 let ll_full = full.fit.log_likelihood;
8935 let p_total = full.design.design.ncols();
8936 let lambdas = full.fit.lambdas.as_slice().ok_or_else(|| {
8937 EstimationError::InvalidInput(
8938 "smooth_term_lr_inference: non-contiguous lambda vector".to_string(),
8939 )
8940 })?;
8941 let s_lambda = weighted_blockwise_penalty_sum(&full.design.penalties, lambdas, p_total);
8942 let rho_penalty_components =
8943 fitted_rho_penalty_components(&full.design.penalties, lambdas, p_total)?;
8944 let rho_covariance = full.fit.artifacts.rho_covariance.as_ref().filter(|cov| {
8945 cov.nrows() == rho_penalty_components.len() && cov.ncols() == rho_penalty_components.len()
8946 });
8947 let full_design_dense = full.design.design.to_dense();
8949 let influence = full.fit.coefficient_influence();
8950 let family_disp = lawley_dispersion_for_family(&family, &full.fit);
8951
8952 let mut penalty_cursor = full.design.leading_penalty_blocks_before_smooth();
8956 let mut out = Vec::<SmoothTermLrInference>::new();
8957 for (term_idx, design_term) in full.design.smooth.terms.iter().enumerate() {
8958 let k = design_term.penalties_local.len();
8959 let block_start = penalty_cursor;
8960 penalty_cursor += k;
8961 if design_term.shape != ShapeConstraint::None {
8964 continue;
8965 }
8966 let coeff_range = design_term.coeff_range.clone();
8967 if coeff_range.start >= coeff_range.end || coeff_range.end > p_total {
8968 continue;
8969 }
8970 let edf = full.fit.per_term_edf(coeff_range.clone(), block_start, k);
8982 let null_dim = design_term.wald_unpenalized_dim();
9002 let edf_floor = (null_dim.max(1)) as f64;
9054 let untrusted_edf_collapse = !full.fit.outer_converged && edf < edf_floor;
9055 let unconverged_dim_floor = if untrusted_edf_collapse {
9056 coeff_range.len() as f64
9057 } else {
9058 0.0
9059 };
9060 let ref_df = wood_reference_df(influence, &coeff_range)
9061 .unwrap_or(0.0)
9062 .max(edf)
9063 .max(null_dim as f64)
9064 .max(unconverged_dim_floor)
9065 .max(1.0);
9066 if !(ref_df.is_finite() && ref_df > 0.0) {
9067 continue;
9068 }
9069
9070 let mut null_spec = resolvedspec.clone();
9073 let Some(spec_pos) = null_spec
9074 .smooth_terms
9075 .iter()
9076 .position(|t| t.name == design_term.name)
9077 else {
9078 continue;
9079 };
9080 null_spec.smooth_terms.remove(spec_pos);
9081 let null_fit = fit_term_collection_forspec(
9082 data,
9083 y,
9084 weights,
9085 offset,
9086 &null_spec,
9087 family.clone(),
9088 options,
9089 );
9090 let (statistic_lr, eta_null) = match null_fit {
9091 Ok(null) if null.fit.log_likelihood.is_finite() => {
9092 let w = (2.0 * (ll_full - null.fit.log_likelihood)).max(0.0);
9093 let mut eta = null.design.design.dot(&null.fit.beta);
9097 eta += &offset;
9098 (w, Some(eta))
9099 }
9100 _ => (f64::NAN, None),
9101 };
9102
9103 let chi2 = statrs::distribution::ChiSquared::new(ref_df).ok();
9104 let p_uncorrected = match (chi2.as_ref(), statistic_lr.is_finite()) {
9105 (Some(dist), true) => {
9106 use statrs::distribution::ContinuousCDF;
9107 (1.0 - dist.cdf(statistic_lr)).clamp(0.0, 1.0)
9108 }
9109 _ => f64::NAN,
9110 };
9111
9112 let mut bartlett_factor = 1.0;
9116 let mut bartlett_factor_conditional = None;
9117 let mut rho_variation_shift = None;
9118 let mut statistic_corrected = statistic_lr;
9119 let mut p_corrected = p_uncorrected;
9120 let mut correction = SmoothLrCorrection::None;
9121 if let (Some(eta), true, true) = (
9122 eta_null.as_ref(),
9123 statistic_lr.is_finite(),
9124 n <= LAWLEY_PAIR_MATRIX_MAX_ROWS,
9125 ) {
9126 let kappas: Option<Vec<_>> = (0..n)
9127 .map(|i| {
9128 known_scale_expected_jets_with_dispersion(&family, eta[i], family_disp)
9129 .and_then(|jets| jets.kappas().ok())
9130 })
9131 .collect();
9132 if let (Some(kappas), Some(dist)) = (kappas, chi2.as_ref()) {
9133 let fixed_factor = lawley_lr_bartlett_factor(
9134 full_design_dense.view(),
9135 &kappas,
9136 Some(s_lambda.view()),
9137 coeff_range.clone(),
9138 ref_df,
9139 );
9140 if let Ok(c_cond) = fixed_factor
9141 && c_cond.is_finite()
9142 && c_cond > 0.0
9143 {
9144 let mut c_applied = c_cond;
9145 correction = SmoothLrCorrection::LawleyLrFixedLambda;
9146 if let Some(cov) = rho_covariance
9147 && let Ok(total_shift) = lawley_lr_mean_shift_with_rho_variation(
9148 full_design_dense.view(),
9149 &kappas,
9150 s_lambda.view(),
9151 coeff_range.clone(),
9152 &rho_penalty_components,
9153 cov.view(),
9154 )
9155 {
9156 let mean_w = ref_df + total_shift;
9157 if let Some(c_est) =
9158 gam_terms::inference::higher_order::bartlett_factor_from_mean(
9159 mean_w, ref_df,
9160 )
9161 && c_est.is_finite()
9162 && c_est > 0.0
9163 {
9164 let conditional_shift = (c_cond - 1.0) * ref_df;
9165 c_applied = c_est;
9166 bartlett_factor_conditional = Some(c_cond);
9167 rho_variation_shift = Some(total_shift - conditional_shift);
9168 correction = SmoothLrCorrection::LawleyLrEstimatedLambda;
9169 }
9170 }
9171 use statrs::distribution::ContinuousCDF;
9172 bartlett_factor = c_applied;
9173 statistic_corrected = statistic_lr / c_applied;
9174 p_corrected = (1.0 - dist.cdf(statistic_corrected)).clamp(0.0, 1.0);
9175 }
9176 }
9177 }
9178
9179 let material = match correction {
9185 SmoothLrCorrection::LawleyLrEstimatedLambda
9186 | SmoothLrCorrection::LawleyLrFixedLambda => {
9187 let factor_move = (bartlett_factor - 1.0).abs();
9188 let p_denom = p_uncorrected.max(p_corrected).max(f64::MIN_POSITIVE);
9189 let p_move = if p_uncorrected.is_finite() && p_corrected.is_finite() {
9190 (p_corrected - p_uncorrected).abs() / p_denom
9191 } else {
9192 0.0
9193 };
9194 factor_move > SMOOTH_LR_MATERIAL_THRESHOLD || p_move > SMOOTH_LR_MATERIAL_THRESHOLD
9195 }
9196 SmoothLrCorrection::None => false,
9197 };
9198
9199 out.push(SmoothTermLrInference {
9200 name: design_term.name.clone(),
9201 term_idx,
9202 statistic_lr,
9203 ref_df,
9204 bartlett_factor,
9205 bartlett_factor_conditional,
9206 rho_variation_shift,
9207 statistic_corrected,
9208 p_value_uncorrected: p_uncorrected,
9209 p_value_corrected: p_corrected,
9210 material,
9211 correction,
9212 });
9213 }
9214 Ok(out)
9215}
9216
9217fn lawley_dispersion_for_family(family: &LikelihoodSpec, fit: &UnifiedFitResult) -> f64 {
9220 match family.response {
9221 gam_spec::ResponseFamily::Gaussian => {
9222 let sd = fit.standard_deviation;
9223 (sd * sd).max(f64::MIN_POSITIVE)
9224 }
9225 gam_spec::ResponseFamily::Gamma => {
9226 let shape = fit.standard_deviation;
9227 if shape.is_finite() && shape > 0.0 {
9228 1.0 / shape
9229 } else {
9230 1.0
9231 }
9232 }
9233 _ => 1.0,
9234 }
9235}
9236
9237fn wood_reference_df(influence: Option<&Array2<f64>>, coeff_range: &Range<usize>) -> Option<f64> {
9261 let f = influence?;
9262 let (start, end) = (coeff_range.start, coeff_range.end);
9263 if start >= end || end > f.nrows() || end > f.ncols() {
9264 return None;
9265 }
9266 let block = f.slice(s![start..end, start..end]);
9267 let tr = (0..block.nrows()).map(|i| block[[i, i]]).sum::<f64>();
9268 let tr2 = block.dot(&block).diag().sum();
9269 (tr.is_finite() && tr2.is_finite() && tr > 0.0)
9270 .then(|| (2.0 * tr - tr2).max(tr).max(1e-12))
9271}