1fn try_build_spatial_term_log_kappa_derivative(
2 data: ArrayView2<'_, f64>,
3 resolvedspec: &TermCollectionSpec,
4 design: &TermCollectionDesign,
5 term_idx: usize,
6) -> Result<
7 Option<(
8 Range<usize>,
9 usize,
10 Array2<f64>,
11 Array2<f64>,
12 Array2<f64>,
13 Array2<f64>,
14 Vec<Array2<f64>>,
15 Vec<Array2<f64>>,
16 Option<std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>>,
17 )>,
18 EstimationError,
19> {
20 let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
21 return Ok(None);
22 };
23 let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
24 return Ok(None);
25 };
26
27 let derivative_bundle = match &termspec.basis {
28 SmoothBasisSpec::ThinPlate {
29 feature_cols,
30 spec,
31 input_scales,
32 } => {
33 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
34 let mut spec_local = spec.clone();
35 if let Some(s) = input_scales {
36 apply_input_standardization(&mut x, s);
37 spec_local.length_scale =
38 compensate_length_scale_for_standardization(spec.length_scale, s);
39 }
40 build_thin_plate_basis_log_kappa_derivatives(x.view(), &spec_local)
41 .map_err(EstimationError::from)?
42 }
43 SmoothBasisSpec::Sphere { .. } => return Ok(None),
44 SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
53 let x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
54 build_constant_curvature_basis_kappa_derivatives(x.view(), spec)
55 .map_err(EstimationError::from)?
56 }
57 SmoothBasisSpec::MeasureJet { .. } => return Ok(None),
63 SmoothBasisSpec::Matern {
64 feature_cols,
65 spec,
66 input_scales,
67 } => {
68 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
69 let mut spec_local = spec.clone();
70 if let Some(s) = input_scales {
71 apply_input_standardization(&mut x, s);
72 spec_local.length_scale =
73 compensate_length_scale_for_standardization(spec.length_scale, s);
74 }
75 spec_local.double_penalty = false;
90 build_matern_basis_log_kappa_derivatives(x.view(), &spec_local)
91 .map_err(EstimationError::from)?
92 }
93 SmoothBasisSpec::Duchon {
94 feature_cols,
95 spec,
96 input_scales,
97 } => {
98 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
99 let mut spec_local = spec.clone();
100 if let Some(s) = input_scales {
101 apply_input_standardization(&mut x, s);
102 spec_local.length_scale =
103 compensate_optional_length_scale_for_standardization(spec.length_scale, s);
104 }
105 let BasisMetadata::Duchon {
106 centers,
107 identifiability_transform,
108 operator_collocation_points,
109 radial_reparam,
110 ..
111 } = &smooth_term.metadata
112 else {
113 return Ok(None);
114 };
115 if spec_local.radial_reparam.is_none() {
118 spec_local.radial_reparam = radial_reparam.clone();
119 }
120 gam_terms::basis::build_duchon_basis_log_kappa_derivativeswith_collocationwithworkspace(
121 x.view(),
122 &spec_local,
123 centers.view(),
124 identifiability_transform.as_ref(),
125 operator_collocation_points
126 .as_ref()
127 .map(|points| points.view()),
128 &mut BasisWorkspace::default(),
129 )
130 .map_err(EstimationError::from)?
131 }
132 SmoothBasisSpec::BSpline1D { .. }
133 | SmoothBasisSpec::TensorBSpline { .. }
134 | SmoothBasisSpec::ByVariable { .. }
135 | SmoothBasisSpec::FactorSumToZero { .. }
136 | SmoothBasisSpec::BySmooth { .. }
137 | SmoothBasisSpec::FactorSmooth { .. }
138 | SmoothBasisSpec::Pca { .. } => {
139 return Ok(None);
140 }
141 };
142 let mut implicit_operator = derivative_bundle.implicit_operator;
143 let BasisPsiDerivativeResult {
144 design_derivative: mut local_x_psi,
145 penalties_derivative: mut local_s_psi,
146 implicit_operator: local_implicit_first_unused,
147 } = derivative_bundle.first;
148 let BasisPsiSecondDerivativeResult {
149 designsecond_derivative: mut local_x_psi_psi,
150 penaltiessecond_derivative: mut local_s_psi_psi,
151 implicit_operator: local_implicit_second_unused,
152 } = derivative_bundle.second;
153 assert!(local_implicit_first_unused.is_none());
154 assert!(local_implicit_second_unused.is_none());
155
156 if let SmoothBasisSpec::Matern {
160 feature_cols,
161 spec,
162 input_scales,
163 } = &termspec.basis
164 {
165 let mut xf = select_columns(data, feature_cols).map_err(EstimationError::from)?;
166 let mut sp = spec.clone();
167 if let Some(s) = input_scales {
168 apply_input_standardization(&mut xf, s.as_slice());
169 sp.length_scale =
170 compensate_length_scale_for_standardization(spec.length_scale, s.as_slice());
171 }
172 sp.double_penalty = false;
173 let ls0 = sp.length_scale;
175 let h = 1e-6_f64;
176 let value_design = |delta: f64| -> Option<Array2<f64>> {
177 let mut s2 = sp.clone();
178 s2.length_scale = ls0 * (-delta).exp();
179 gam_terms::basis::build_matern_basis(xf.view(), &s2)
180 .ok()
181 .map(|b| b.design.to_dense())
182 };
183 if let (Some(dp), Some(dm)) = (value_design(h), value_design(-h)) {
184 if dp.shape() == local_x_psi.shape() {
185 let num = (&dp - &dm) / (2.0 * h);
186 let err = (&local_x_psi - &num)
187 .mapv(f64::abs)
188 .iter()
189 .fold(0.0_f64, |a, &b| a.max(b));
190 let anorm = local_x_psi.iter().map(|v| v * v).sum::<f64>().sqrt();
191 log::warn!(
192 "[OUTER-FD-AUDIT XPSIFD-1122] x_psi_max_abs_err={err:.3e} |x_psi|={anorm:.3e} \
193 shape={:?} centers_frozen={}",
194 local_x_psi.shape(),
195 match &sp.center_strategy {
196 gam_terms::basis::CenterStrategy::UserProvided(c) => c.nrows(),
197 _ => 0,
198 },
199 );
200 if let Some(d0) = value_design(0.0) {
204 let p_total = design.design.ncols();
205 let smooth_start =
206 p_total.saturating_sub(design.smooth.total_smooth_cols());
207 let g0 = smooth_start + smooth_term.coeff_range.start;
208 let g1 = smooth_start + smooth_term.coeff_range.end;
209 let realized = design.design.to_dense();
210 if g1 <= realized.ncols() && d0.ncols() == (g1 - g0) {
211 let block = realized.slice(ndarray::s![.., g0..g1]).to_owned();
212 let dmax = (&block - &d0)
213 .mapv(f64::abs)
214 .iter()
215 .fold(0.0_f64, |a, &b| a.max(b));
216 let bnorm = block.iter().map(|v| v * v).sum::<f64>().sqrt();
218 let d0norm = d0.iter().map(|v| v * v).sum::<f64>().sqrt();
219 log::warn!(
220 "[OUTER-FD-AUDIT XDESIGN-1122] realized_vs_valuebuild_max_abs={dmax:.3e} \
221 |realized_block|={bnorm:.4e} |value_build|={d0norm:.4e} \
222 block_shape={:?} g0={g0} g1={g1} p_total={p_total}",
223 block.shape(),
224 );
225 } else {
226 log::warn!(
227 "[OUTER-FD-AUDIT XDESIGN-1122] shape/range mismatch realized_cols={} g0={g0} g1={g1} d0_cols={}",
228 realized.ncols(),
229 d0.ncols()
230 );
231 }
232 }
233 } else {
234 log::warn!(
235 "[OUTER-FD-AUDIT XPSIFD-1122] shape mismatch analytic={:?} value={:?}",
236 local_x_psi.shape(),
237 dp.shape()
238 );
239 }
240 }
241 }
242
243 {
253 use gam_terms::smooth::matern_operator_penalty_triplet_at_length_scale;
254 if let BasisMetadata::Matern {
255 centers,
256 length_scale,
257 periodic,
258 nu,
259 include_intercept,
260 identifiability_transform,
261 aniso_log_scales,
262 input_scales,
263 ..
264 } = &smooth_term.metadata
265 {
266 let ls_eff = match input_scales.as_deref() {
271 Some(s) => compensate_length_scale_for_standardization(*length_scale, s),
272 None => *length_scale,
273 };
274 let h = 1e-6_f64;
276 let summed_at = |delta: f64| -> Option<Array2<f64>> {
277 let ls = ls_eff * (-delta).exp();
278 let (mats, _ranks, _info) = matern_operator_penalty_triplet_at_length_scale(
279 centers.view(),
280 periodic.as_deref(),
281 identifiability_transform.as_ref(),
282 *nu,
283 *include_intercept,
284 aniso_log_scales.as_deref(),
285 ls,
286 )
287 .ok()?;
288 let p = mats.first().map(|m: &Array2<f64>| m.nrows()).unwrap_or(0);
289 let mut acc = Array2::<f64>::zeros((p, p));
290 for m in &mats {
291 if m.shape() == acc.shape() {
292 acc += m;
293 }
294 }
295 Some(acc)
296 };
297 let analytic_sum = local_s_psi.iter().fold(
298 Array2::<f64>::zeros((
299 smooth_term.coeff_range.len(),
300 smooth_term.coeff_range.len(),
301 )),
302 |acc, m| {
303 if m.shape() == acc.shape() {
304 acc + m
305 } else {
306 acc
307 }
308 },
309 );
310 if let (Some(sp), Some(sm), Some(s0)) =
311 (summed_at(h), summed_at(-h), summed_at(0.0))
312 {
313 if sp.shape() == analytic_sum.shape() {
314 let num = (&sp - &sm) / (2.0 * h);
315 let err = (&analytic_sum - &num)
316 .mapv(f64::abs)
317 .iter()
318 .fold(0.0_f64, |a, &b| a.max(b));
319 let anorm = analytic_sum.iter().map(|v| v * v).sum::<f64>().sqrt();
320 let nnorm = num.iter().map(|v| v * v).sum::<f64>().sqrt();
321 log::warn!(
325 "[OUTER-FD-AUDIT SPSIFD-1122] s_psi_max_abs_err={err:.3e} \
326 |analytic|={anorm:.3e} |fd|={nnorm:.3e} blocks={} \
327 shape={:?} |S0|={:.3e}",
328 local_s_psi.len(),
329 analytic_sum.shape(),
330 s0.iter().map(|v| v * v).sum::<f64>().sqrt(),
331 );
332 let per_block = |delta: f64| -> Option<Vec<Array2<f64>>> {
335 let ls = ls_eff * (-delta).exp();
336 matern_operator_penalty_triplet_at_length_scale(
337 centers.view(),
338 periodic.as_deref(),
339 identifiability_transform.as_ref(),
340 *nu,
341 *include_intercept,
342 aniso_log_scales.as_deref(),
343 ls,
344 )
345 .ok()
346 .map(|(m, _, _)| m)
347 };
348 if let (Some(bp), Some(bm)) = (per_block(h), per_block(-h)) {
349 for (bi, (ap, am)) in bp.iter().zip(bm.iter()).enumerate() {
350 if bi < local_s_psi.len()
351 && ap.shape() == local_s_psi[bi].shape()
352 {
353 let bnum = (ap - am) / (2.0 * h);
354 let berr = (&local_s_psi[bi] - &bnum)
355 .mapv(f64::abs)
356 .iter()
357 .fold(0.0_f64, |a, &b| a.max(b));
358 let banorm =
359 local_s_psi[bi].iter().map(|v| v * v).sum::<f64>().sqrt();
360 log::warn!(
361 "[OUTER-FD-AUDIT SPSIFD-1122 BLOCK] block={bi} max_abs_err={berr:.3e} |analytic|={banorm:.3e}"
362 );
363 }
364 }
365 }
366 } else {
367 log::warn!(
368 "[OUTER-FD-AUDIT SPSIFD-1122] shape mismatch analytic={:?} fd={:?}",
369 analytic_sum.shape(),
370 sp.shape()
371 );
372 }
373 }
374 }
375 }
376
377 log::warn!(
380 "[OUTER-FD-AUDIT TEMP-ROT-1122] analytic joint_null_rotation={} nullity={} x_psi={}x{}",
381 smooth_term.joint_null_rotation.is_some(),
382 smooth_term
383 .joint_null_rotation
384 .as_ref()
385 .map(|r| r.joint_nullity)
386 .unwrap_or(0),
387 local_x_psi.nrows(),
388 local_x_psi.ncols(),
389 );
390 if let Some(rotation) = smooth_term.joint_null_rotation.as_ref() {
391 let q = &rotation.rotation;
392 if let Some(op) = implicit_operator.take() {
393 implicit_operator = Some(op.append_full_transform(q).map_err(EstimationError::from)?);
394 } else {
395 if local_x_psi.ncols() != q.nrows() || local_x_psi_psi.ncols() != q.nrows() {
396 return Ok(None);
397 }
398 local_x_psi = fast_ab(&local_x_psi, q);
399 local_x_psi_psi = fast_ab(&local_x_psi_psi, q);
400 }
401 let rotate_penalty = |s_local: Array2<f64>| -> Option<Array2<f64>> {
402 if s_local.nrows() != q.nrows() || s_local.ncols() != q.nrows() {
403 return None;
404 }
405 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
406 Some(gam_linalg::faer_ndarray::fast_ab(&qt_s, q))
407 };
408 let Some(rotated_s_psi) = local_s_psi
409 .into_iter()
410 .map(|s| rotate_penalty(s))
411 .collect::<Option<Vec<_>>>()
412 else {
413 return Ok(None);
414 };
415 local_s_psi = rotated_s_psi;
416 let Some(rotated_s_psi_psi) = local_s_psi_psi
417 .into_iter()
418 .map(|s| rotate_penalty(s))
419 .collect::<Option<Vec<_>>>()
420 else {
421 return Ok(None);
422 };
423 local_s_psi_psi = rotated_s_psi_psi;
424 }
425 let implicit_operator = implicit_operator.map(std::sync::Arc::new);
426
427 if let Some(ref op) = implicit_operator {
428 if op.p_out() != smooth_term.coeff_range.len() {
429 return Ok(None);
430 }
431 } else {
432 if local_x_psi.ncols() != smooth_term.coeff_range.len() {
433 return Ok(None);
434 }
435 if local_x_psi_psi.ncols() != smooth_term.coeff_range.len() {
436 return Ok(None);
437 }
438 }
439 if local_s_psi.is_empty() || local_s_psi.len() != local_s_psi_psi.len() {
440 return Ok(None);
441 }
442 if local_s_psi.iter().any(|s| {
443 s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
444 }) {
445 return Ok(None);
446 }
447 if local_s_psi_psi.iter().any(|s| {
448 s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
449 }) {
450 return Ok(None);
451 }
452
453 let p_total = design.design.ncols();
454 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
455 let global_range = (smooth_start + smooth_term.coeff_range.start)
456 ..(smooth_start + smooth_term.coeff_range.end);
457
458 Ok(Some((
459 global_range,
460 p_total,
461 local_x_psi,
462 local_s_psi.iter().fold(
463 Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
464 |acc, m| acc + m,
465 ),
466 local_x_psi_psi,
467 local_s_psi_psi.iter().fold(
468 Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
469 |acc, m| acc + m,
470 ),
471 local_s_psi,
472 local_s_psi_psi,
473 implicit_operator,
474 )))
475}
476
477fn try_build_spatial_log_kappa_hyper_dirs(
478 data: ArrayView2<'_, f64>,
479 resolvedspec: &TermCollectionSpec,
480 design: &TermCollectionDesign,
481 spatial_terms: &[usize],
482) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
483 let Some(info_list) =
490 try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, spatial_terms)?
491 else {
492 return Ok(None);
493 };
494 Ok(Some(spatial_log_kappa_hyper_dirs_frominfo_list(info_list)?))
495}
496
497pub(crate) fn try_build_latent_coord_hyper_dirs(
498 latent: std::sync::Arc<gam_terms::latent::LatentCoordValues>,
499 resolvedspec: &TermCollectionSpec,
500 design: &TermCollectionDesign,
501 latent_terms: &[gam_problem::types::SmoothTermIdx],
502 analytic_rho_count: usize,
503) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
504 if latent_terms.is_empty() || latent.is_empty() {
505 return Ok(None);
506 }
507 if latent_terms.len() != 1 {
508 crate::bail_invalid_estim!(
509 "LatentCoord standard-fit hyper_dirs currently require exactly one latent smooth term"
510 .to_string(),
511 );
512 }
513 let term_idx = latent_terms[0];
514 let smooth_term = design.smooth.terms.get(term_idx.get()).ok_or_else(|| {
515 EstimationError::InvalidInput(format!(
516 "LatentCoord term index {term_idx} out of bounds for realized smooth design"
517 ))
518 })?;
519 let termspec = resolvedspec
520 .smooth_terms
521 .get(term_idx.get())
522 .ok_or_else(|| {
523 EstimationError::InvalidInput(format!(
524 "LatentCoord term index {term_idx} out of bounds for resolved smooth spec"
525 ))
526 })?;
527 let p_total = design.design.ncols();
528 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
529 let global_range = (smooth_start + smooth_term.coeff_range.start)
530 ..(smooth_start + smooth_term.coeff_range.end);
531
532 let operator = match (&termspec.basis, &smooth_term.metadata) {
537 (
538 SmoothBasisSpec::Matern { .. },
539 BasisMetadata::Matern {
540 centers,
541 length_scale,
542 nu,
543 include_intercept,
544 identifiability_transform,
545 ..
546 },
547 ) => gam_terms::basis::LatentCoordDesignDerivative::new_matern(
548 latent.clone(),
549 std::sync::Arc::new(centers.clone()),
550 *length_scale,
551 *nu,
552 *include_intercept,
553 identifiability_transform.clone(),
554 )
555 .map_err(EstimationError::from)?,
556 (
557 SmoothBasisSpec::Duchon { .. },
558 BasisMetadata::Duchon {
559 centers,
560 length_scale,
561 power,
562 nullspace_order,
563 identifiability_transform,
564 ..
565 },
566 ) => gam_terms::basis::LatentCoordDesignDerivative::new_duchon(
567 latent.clone(),
568 std::sync::Arc::new(centers.clone()),
569 *length_scale,
570 *power,
571 *nullspace_order,
572 identifiability_transform.clone(),
573 )
574 .map_err(EstimationError::from)?,
575 (
576 SmoothBasisSpec::Sphere { .. },
577 BasisMetadata::Sphere {
578 centers,
579 penalty_order,
580 method,
581 constraint_transform,
582 ..
583 },
584 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
585 gam_terms::basis::LatentCoordDesignDerivative::new_sphere(
586 latent.clone(),
587 std::sync::Arc::new(centers.clone()),
588 *penalty_order,
589 constraint_transform.clone(),
590 )
591 .map_err(EstimationError::from)?
592 }
593 (
594 SmoothBasisSpec::BSpline1D { spec, .. },
595 BasisMetadata::BSpline1D {
596 knots,
597 identifiability_transform,
598 periodic,
599 degree: meta_degree,
600 ..
601 },
602 ) => {
603 let effective_degree = meta_degree.unwrap_or(spec.degree);
607 if let Some((domain_start, period, num_basis)) = periodic {
608 gam_terms::basis::LatentCoordDesignDerivative::new_periodic_bspline(
609 latent.clone(),
610 (*domain_start, *domain_start + *period),
611 effective_degree,
612 *num_basis,
613 identifiability_transform.clone(),
614 )
615 .map_err(EstimationError::from)?
616 } else {
617 gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
618 latent.clone(),
619 vec![knots.clone()],
620 vec![effective_degree],
621 identifiability_transform.clone(),
622 )
623 .map_err(EstimationError::from)?
624 }
625 }
626 (
627 SmoothBasisSpec::TensorBSpline { .. },
628 BasisMetadata::TensorBSpline {
629 knots,
630 degrees,
631 identifiability_transform,
632 ..
633 },
634 ) => gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
635 latent.clone(),
636 knots.clone(),
637 degrees.clone(),
638 identifiability_transform.clone(),
639 )
640 .map_err(EstimationError::from)?,
641 (SmoothBasisSpec::Pca { .. }, BasisMetadata::Pca { basis_matrix, .. }) => {
642 gam_terms::basis::LatentCoordDesignDerivative::new_pca(
643 latent.clone(),
644 std::sync::Arc::new(basis_matrix.clone()),
645 )
646 .map_err(EstimationError::from)?
647 }
648 _ => return Ok(None),
649 };
650 if operator.p_out() != global_range.len() {
651 crate::bail_invalid_estim!(
652 "LatentCoord derivative width mismatch for term '{}': operator p={}, coeff range={}",
653 smooth_term.name,
654 operator.p_out(),
655 global_range.len()
656 );
657 }
658 let operator = std::sync::Arc::new(operator);
659 let mut hyper_dirs = Vec::with_capacity(operator.n_axes());
660 for flat_axis in 0..operator.n_axes() {
661 let dir = DirectionalHyperParam::new_compact(
662 gam_solve::estimate::reml::HyperDesignDerivative::from_latent_coord(
663 operator.clone(),
664 flat_axis,
665 global_range.clone(),
666 p_total,
667 ),
668 Vec::new(),
669 None,
670 None,
671 )?
672 .not_penalty_like();
673 hyper_dirs.push(dir);
674 }
675 let direct_dim = latent_coord_direct_hyper_count(latent.id_mode(), latent.latent_dim());
676 if analytic_rho_count + direct_dim > 0 {
677 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::from(Array2::<f64>::zeros((
678 design.design.nrows(),
679 p_total,
680 )));
681 for _ in 0..analytic_rho_count {
682 hyper_dirs.push(
683 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
684 .not_penalty_like(),
685 );
686 }
687 for _ in 0..direct_dim {
688 hyper_dirs.push(
689 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
690 .not_penalty_like(),
691 );
692 }
693 }
694 Ok(Some(hyper_dirs))
695}
696
697fn latent_coord_direct_hyper_count(
698 id_mode: &gam_terms::latent::LatentIdMode,
699 latent_dim: usize,
700) -> usize {
701 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
702 match id_mode {
703 LatentIdMode::AuxPrior { strength, .. } => match strength {
704 AuxPriorStrength::Auto => 1,
705 AuxPriorStrength::Fixed(_) => 0,
706 },
707 LatentIdMode::AuxPriorDimSelection { strength, .. } => {
708 latent_dim
709 + match strength {
710 AuxPriorStrength::Auto => 1,
711 AuxPriorStrength::Fixed(_) => 0,
712 }
713 }
714 LatentIdMode::DimSelection { .. } => latent_dim,
715 LatentIdMode::IsometryToReference { strength, .. } => match strength {
718 AuxPriorStrength::Auto => 1,
719 AuxPriorStrength::Fixed(_) => 0,
720 },
721 LatentIdMode::AuxOutcome { head, .. } => head.n_coeffs(latent_dim) + latent_dim,
724 LatentIdMode::None => 0,
725 }
726}
727
728fn latent_coord_initial_direct_hypers(
729 id_mode: &gam_terms::latent::LatentIdMode,
730 latent_dim: usize,
731) -> Result<Array1<f64>, EstimationError> {
732 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
733 let mut values = Vec::with_capacity(latent_coord_direct_hyper_count(id_mode, latent_dim));
734 match id_mode {
735 LatentIdMode::AuxPrior { strength, .. } => {
736 if matches!(strength, AuxPriorStrength::Auto) {
737 values.push(0.0);
738 }
739 }
740 LatentIdMode::AuxPriorDimSelection {
741 strength,
742 init_log_precision,
743 ..
744 } => {
745 if matches!(strength, AuxPriorStrength::Auto) {
746 values.push(0.0);
747 }
748 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
749 }
750 LatentIdMode::DimSelection { init_log_precision } => {
751 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
752 }
753 LatentIdMode::IsometryToReference { strength, .. } => {
754 if matches!(strength, AuxPriorStrength::Auto) {
755 values.push(0.0);
756 }
757 }
758 LatentIdMode::AuxOutcome {
759 head,
760 init_log_precision,
761 } => {
762 values.extend(std::iter::repeat_n(0.0, head.n_coeffs(latent_dim)));
766 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
767 }
768 LatentIdMode::None => {}
769 }
770 Ok(Array1::from_vec(values))
771}
772
773fn append_latent_ard_seed(
774 values: &mut Vec<f64>,
775 init: Option<&Array1<f64>>,
776 latent_dim: usize,
777) -> Result<(), EstimationError> {
778 if let Some(init) = init {
779 if init.len() != latent_dim {
780 crate::bail_invalid_estim!(
781 "latent dim_selection init_log_precision length mismatch: got {}, expected {}",
782 init.len(),
783 latent_dim
784 );
785 }
786 values.extend(init.iter().copied());
787 } else {
788 values.extend(std::iter::repeat_n(0.0, latent_dim));
789 }
790 Ok(())
791}
792
793struct LatentIdObjectiveContribution {
794 cost: f64,
795 gradient: Array1<f64>,
796}
797
798fn latent_id_objective_contribution(
799 theta: &Array1<f64>,
800 rho_dim: usize,
801 analytic_rho_count: usize,
802 latent: &gam_terms::latent::LatentCoordValues,
803) -> Result<LatentIdObjectiveContribution, EstimationError> {
804 use gam_terms::latent::{AuxPriorStrength, LatentIdMode, aux_prior_targets};
805 let n_obs = latent.n_obs();
806 let latent_dim = latent.latent_dim();
807 let flat_len = latent.len();
808 let mut gradient = Array1::<f64>::zeros(theta.len());
809 let t_start = rho_dim;
810 let direct_start = t_start + flat_len + analytic_rho_count;
811 if theta.len() < direct_start {
812 crate::bail_invalid_estim!(
813 "latent-coordinate theta too short for id objective: got {}, need at least {}",
814 theta.len(),
815 direct_start
816 );
817 }
818 let t = latent.as_matrix();
819 let mut cost = 0.0;
820 let mut cursor = direct_start;
821
822 match latent.id_mode() {
823 LatentIdMode::AuxPrior {
824 u,
825 family,
826 strength,
827 }
828 | LatentIdMode::AuxPriorDimSelection {
829 u,
830 family,
831 strength,
832 ..
833 } => {
834 let (log_mu, mu) = match strength {
835 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
836 AuxPriorStrength::Auto => {
837 let log_mu = theta[cursor];
838 cursor += 1;
839 (log_mu, log_mu.exp())
840 }
841 };
842 let targets = aux_prior_targets(t.view(), u.view(), *family)
843 .map_err(EstimationError::InvalidInput)?;
844 let residual = &t - &targets;
845 let q = residual.iter().map(|v| v * v).sum::<f64>();
846 let k = (n_obs * latent_dim) as f64;
853 cost += 0.5 * mu * q - 0.5 * k * log_mu;
854
855 let projected_residual = aux_prior_targets(residual.view(), u.view(), *family)
856 .map_err(EstimationError::InvalidInput)?;
857 let grad_base = residual - projected_residual;
858 for n in 0..n_obs {
859 for axis in 0..latent_dim {
860 gradient[t_start + n * latent_dim + axis] += mu * grad_base[[n, axis]];
861 }
862 }
863 if matches!(strength, AuxPriorStrength::Auto) {
864 gradient[direct_start] += 0.5 * mu * q - 0.5 * k;
865 }
866 }
867 LatentIdMode::IsometryToReference { reference, strength } => {
868 if reference.dim() != (n_obs, latent_dim) {
875 crate::bail_invalid_estim!(
876 "IsometryToReference reference shape {:?} must equal (n_obs, latent_dim) = ({}, {})",
877 reference.dim(),
878 n_obs,
879 latent_dim
880 );
881 }
882 let mu_slot = cursor;
883 let (log_mu, mu) = match strength {
884 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
885 AuxPriorStrength::Auto => {
886 let log_mu = theta[cursor];
887 cursor += 1;
888 (log_mu, log_mu.exp())
889 }
890 };
891 let residual = &t - reference;
892 let q = residual.iter().map(|v| v * v).sum::<f64>();
893 let k = (n_obs * latent_dim) as f64;
897 cost += 0.5 * mu * q - 0.5 * k * log_mu;
898 for n in 0..n_obs {
899 for axis in 0..latent_dim {
900 gradient[t_start + n * latent_dim + axis] += mu * residual[[n, axis]];
901 }
902 }
903 if matches!(strength, AuxPriorStrength::Auto) {
904 gradient[mu_slot] += 0.5 * mu * q - 0.5 * k;
905 }
906 }
907 LatentIdMode::AuxOutcome { head, .. } => {
908 let n_coeffs = head.n_coeffs(latent_dim);
916 let coeffs = theta
917 .slice(ndarray::s![cursor..cursor + n_coeffs])
918 .to_owned();
919 let (head_nll, grad_coeffs, grad_t) = head
920 .neg_loglik_and_grad(t.view(), coeffs.view())
921 .map_err(EstimationError::InvalidInput)?;
922 cost += head_nll;
923 for (offset, &g) in grad_coeffs.iter().enumerate() {
924 gradient[cursor + offset] += g;
925 }
926 for n in 0..n_obs {
927 for axis in 0..latent_dim {
928 gradient[t_start + n * latent_dim + axis] += grad_t[[n, axis]];
929 }
930 }
931 cursor += n_coeffs;
932 }
933 LatentIdMode::DimSelection { .. } | LatentIdMode::None => {}
934 }
935
936 match latent.id_mode() {
937 LatentIdMode::AuxPriorDimSelection { .. }
938 | LatentIdMode::DimSelection { .. }
939 | LatentIdMode::AuxOutcome { .. } => {
940 for axis in 0..latent_dim {
941 let log_alpha = theta[cursor + axis];
942 let alpha = log_alpha.exp();
943 let mut q_axis = 0.0;
944 for n in 0..n_obs {
945 let flat_idx = n * latent_dim + axis;
946 let value = latent.as_flat()[flat_idx];
947 q_axis += value * value;
948 gradient[t_start + flat_idx] += alpha * value;
949 }
950 cost += 0.5 * alpha * q_axis - 0.5 * n_obs as f64 * log_alpha;
951 gradient[cursor + axis] += 0.5 * alpha * q_axis - 0.5 * n_obs as f64;
952 }
953 cursor += latent_dim;
954 }
955 LatentIdMode::AuxPrior { .. }
956 | LatentIdMode::IsometryToReference { .. }
957 | LatentIdMode::None => {}
958 }
959
960 if cursor != theta.len() {
961 crate::bail_invalid_estim!(
962 "latent-coordinate direct hyperparameter length mismatch: consumed {}, theta len {}",
963 cursor,
964 theta.len()
965 );
966 }
967 Ok(LatentIdObjectiveContribution { cost, gradient })
968}
969
970fn add_latent_id_objective_to_eval(
971 theta: &Array1<f64>,
972 rho_dim: usize,
973 analytic_rho_count: usize,
974 latent: &gam_terms::latent::LatentCoordValues,
975 eval: &mut (
976 f64,
977 Array1<f64>,
978 gam_problem::HessianResult,
979 ),
980) -> Result<(), EstimationError> {
981 let contribution =
982 latent_id_objective_contribution(theta, rho_dim, analytic_rho_count, latent)?;
983 eval.0 += contribution.cost;
984 if eval.1.len() != contribution.gradient.len() {
985 crate::bail_invalid_estim!(
986 "latent-coordinate REML gradient length mismatch: base={}, id={}",
987 eval.1.len(),
988 contribution.gradient.len()
989 );
990 }
991 eval.1 += &contribution.gradient;
992 if eval.2.is_analytic() {
993 eval.2 = gam_problem::HessianResult::Unavailable;
994 }
995 Ok(())
996}
997
998fn analytic_penalty_objective_contribution(
999 theta: &Array1<f64>,
1000 rho_dim: usize,
1001 latent: &gam_terms::latent::LatentCoordValues,
1002 registry: &gam_terms::AnalyticPenaltyRegistry,
1003) -> Result<LatentIdObjectiveContribution, EstimationError> {
1004 let flat_len = latent.len();
1005 let t_start = rho_dim;
1006 let t_end = t_start + flat_len;
1007 let rho_start = t_end;
1008 let rho_end = rho_start + registry.total_rho_count();
1009 if theta.len() < rho_end {
1010 crate::bail_invalid_estim!(
1011 "latent-coordinate theta too short for analytic penalties: got {}, need at least {}",
1012 theta.len(),
1013 rho_end
1014 );
1015 }
1016 let target_t = theta.slice(s![t_start..t_end]);
1017 let rho = theta.slice(s![rho_start..rho_end]);
1018 let mut cost = 0.0_f64;
1019 let mut gradient = Array1::<f64>::zeros(theta.len());
1020 for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(registry.rho_layout()) {
1021 let rho_local = rho.slice(s![rho_slice.clone()]);
1022 match tier {
1023 gam_terms::PenaltyTier::Psi => {
1024 cost += penalty.value(target_t.view(), rho_local);
1025 let grad = penalty.grad_target(target_t.view(), rho_local);
1026 if grad.len() != flat_len {
1027 crate::bail_invalid_estim!(
1028 "analytic penalty {name:?} gradient length mismatch: got {}, expected {}",
1029 grad.len(),
1030 flat_len
1031 );
1032 }
1033 for i in 0..flat_len {
1034 gradient[t_start + i] += grad[i];
1035 }
1036 let grad_rho_local = penalty.grad_rho(target_t.view(), rho_local);
1037 if grad_rho_local.len() != rho_slice.len() {
1038 crate::bail_invalid_estim!(
1039 "analytic penalty {name:?} rho-gradient length mismatch: got {}, expected {}",
1040 grad_rho_local.len(),
1041 rho_slice.len()
1042 );
1043 }
1044 for local_idx in 0..grad_rho_local.len() {
1045 gradient[rho_start + rho_slice.start + local_idx] += grad_rho_local[local_idx];
1046 }
1047 }
1048 gam_terms::PenaltyTier::Beta => {}
1049 gam_terms::PenaltyTier::Rho => {}
1050 }
1051 }
1052 Ok(LatentIdObjectiveContribution { cost, gradient })
1053}
1054
1055fn add_analytic_penalty_hessian_to_eval(
1056 theta: &Array1<f64>,
1057 rho_dim: usize,
1058 latent: &gam_terms::latent::LatentCoordValues,
1059 registry: &gam_terms::AnalyticPenaltyRegistry,
1060 eval: &mut (
1061 f64,
1062 Array1<f64>,
1063 gam_problem::HessianResult,
1064 ),
1065) -> Result<(), EstimationError> {
1066 let flat_len = latent.len();
1067 let t_start = rho_dim;
1068 let t_end = t_start + flat_len;
1069 let rho_start = t_end;
1070 let rho_end = rho_start + registry.total_rho_count();
1071 if theta.len() < rho_end {
1072 crate::bail_invalid_estim!(
1073 "latent-coordinate theta too short for analytic penalty Hessian: got {}, need at least {}",
1074 theta.len(),
1075 rho_end
1076 );
1077 }
1078 let gam_problem::HessianResult::Analytic(hessian) = &mut eval.2 else {
1079 if eval.2.is_analytic() {
1080 eval.2 = gam_problem::HessianResult::Unavailable;
1081 }
1082 return Ok(());
1083 };
1084 if hessian.dim() != (theta.len(), theta.len()) {
1085 crate::bail_invalid_estim!(
1086 "analytic penalty Hessian target shape mismatch: got {}x{}, expected {}x{}",
1087 hessian.nrows(),
1088 hessian.ncols(),
1089 theta.len(),
1090 theta.len()
1091 );
1092 }
1093 let target_t = theta.slice(s![t_start..t_end]);
1094 let rho = theta.slice(s![rho_start..rho_end]);
1095 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(registry.rho_layout())
1096 {
1097 let rho_local = rho.slice(s![rho_slice]);
1098 if !matches!(tier, gam_terms::PenaltyTier::Psi) {
1099 continue;
1100 }
1101 if let Some(diag) = penalty.hessian_diag(target_t.view(), rho_local) {
1102 if diag.len() != flat_len {
1103 crate::bail_invalid_estim!(
1104 "analytic penalty Hessian diagonal length mismatch: got {}, expected {}",
1105 diag.len(),
1106 flat_len
1107 );
1108 }
1109 for i in 0..flat_len {
1110 hessian[[t_start + i, t_start + i]] += diag[i];
1111 }
1112 continue;
1113 }
1114 let mut probe = Array1::<f64>::zeros(flat_len);
1115 for col in 0..flat_len {
1116 probe[col] = 1.0;
1117 let hv = penalty.hvp(target_t.view(), rho_local, probe.view());
1118 if hv.len() != flat_len {
1119 crate::bail_invalid_estim!(
1120 "analytic penalty Hessian-vector length mismatch: got {}, expected {}",
1121 hv.len(),
1122 flat_len
1123 );
1124 }
1125 for row in 0..flat_len {
1126 hessian[[t_start + row, t_start + col]] += hv[row];
1127 }
1128 probe[col] = 0.0;
1129 }
1130 }
1131 Ok(())
1132}
1133
1134fn add_analytic_penalty_objective_to_eval(
1135 theta: &Array1<f64>,
1136 rho_dim: usize,
1137 latent: &gam_terms::latent::LatentCoordValues,
1138 registry: &gam_terms::AnalyticPenaltyRegistry,
1139 eval: &mut (
1140 f64,
1141 Array1<f64>,
1142 gam_problem::HessianResult,
1143 ),
1144) -> Result<(), EstimationError> {
1145 let contribution = analytic_penalty_objective_contribution(theta, rho_dim, latent, registry)?;
1146 eval.0 += contribution.cost;
1147 if eval.1.len() != contribution.gradient.len() {
1148 crate::bail_invalid_estim!(
1149 "latent-coordinate REML gradient length mismatch: base={}, analytic_penalty={}",
1150 eval.1.len(),
1151 contribution.gradient.len()
1152 );
1153 }
1154 eval.1 += &contribution.gradient;
1155 add_analytic_penalty_hessian_to_eval(theta, rho_dim, latent, registry, eval)?;
1156 Ok(())
1157}
1158
/// Turns a list of per-axis spatial ψ-derivative bundles into REML directional
/// hyper-parameters.
///
/// One `DirectionalHyperParam` is produced per entry of `info_list`, in the
/// same order, so the position of an entry in `info_list` is also its hyper
/// index (`i` below). Each direction carries:
///
/// * a first-order design derivative (implicit operator when one is present,
///   otherwise the embedded dense `x_psi_local`);
/// * first-order penalty derivative components paired with `penalty_indices`;
/// * a vector of second-order design derivatives of length `info_list.len()`,
///   with the diagonal slot `i` always filled and cross slots filled for
///   anisotropy-group partners when `aniso_cross_designs` is present;
/// * second-order penalty components on the diagonal slot, plus an optional
///   lazy provider for cross penalty components within an anisotropy group.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when an entry that carries a cross
/// penalty provider is not listed in its own anisotropy group, and propagates
/// any error from `DirectionalHyperParam::new_compact`.
fn spatial_log_kappa_hyper_dirs_frominfo_list(
    info_list: Vec<SpatialPsiDerivative>,
) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
    use gam_solve::estimate::reml::ImplicitDerivLevel;
    use std::collections::HashMap;

    let log_kappa_dim = info_list.len();
    // Snapshot the group ids first: `info_list` is consumed by value below.
    let group_ids: Vec<Option<usize>> = info_list.iter().map(|e| e.aniso_group_id).collect();
    // anisotropy group id -> hyper indices of its members, in list order.
    let mut group_indices_map: HashMap<usize, Vec<usize>> = HashMap::new();
    for (idx, gid) in group_ids.iter().enumerate() {
        if let Some(g) = gid {
            group_indices_map.entry(*g).or_default().push(idx);
        }
    }

    let mut hyper_dirs = Vec::with_capacity(log_kappa_dim);
    for (i, info) in info_list.into_iter().enumerate() {
        let SpatialPsiDerivative {
            penalty_index: _,
            penalty_indices,
            global_range,
            total_p,
            x_psi_local,
            s_psi_components_local,
            x_psi_psi_local,
            s_psi_psi_components_local,
            aniso_group_id,
            aniso_cross_designs,
            aniso_cross_penalty_provider,
            implicit_operator,
            implicit_axis,
        } = info;

        // Second-order design derivatives, indexed by partner hyper index.
        // Only the diagonal slot is filled unconditionally.
        let mut xsecond = vec![None; log_kappa_dim];
        xsecond[i] = Some(if let Some(ref op) = implicit_operator {
            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                op.clone(),
                ImplicitDerivLevel::SecondDiag(implicit_axis),
                global_range.clone(),
                total_p,
            )
        } else {
            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                x_psi_psi_local,
                global_range.clone(),
                total_p,
            )
        });
        if let Some(cross_designs) = aniso_cross_designs {
            if let Some(gid) = aniso_group_id {
                // First hyper index of this group; falls back to `i` when the
                // group is not in the map.
                // NOTE(review): `base + b_axis` assumes the members of a group
                // occupy consecutive hyper indices — confirm against the
                // producer of `info_list`.
                let base = group_indices_map
                    .get(&gid)
                    .and_then(|v| v.first().copied())
                    .unwrap_or(i);
                for (b_axis, cross_mat) in cross_designs.into_iter() {
                    let j = base + b_axis;
                    // Out-of-range partners are silently ignored.
                    if j < log_kappa_dim {
                        // With an implicit operator the dense `cross_mat` is
                        // not used; the operator is asked for the cross term.
                        xsecond[j] = Some(if let Some(ref op) = implicit_operator {
                            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                                op.clone(),
                                ImplicitDerivLevel::SecondCross(implicit_axis, b_axis),
                                global_range.clone(),
                                total_p,
                            )
                        } else {
                            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                                cross_mat,
                                global_range.clone(),
                                total_p,
                            )
                        });
                    }
                }
            }
        }
        // First-order penalty derivatives, paired positionally with their
        // penalty indices (`zip` stops at the shorter of the two sequences).
        let s_components = penalty_indices
            .iter()
            .copied()
            .zip(s_psi_components_local.into_iter().map(|local| {
                gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    total_p,
                )
            }))
            .collect::<Vec<_>>();
        // Second-order (diagonal) penalty derivatives, paired the same way.
        let s2_components = penalty_indices
            .iter()
            .copied()
            .zip(s_psi_psi_components_local.into_iter().map(|local| {
                gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    total_p,
                )
            }))
            .collect::<Vec<_>>();
        let mut ssecond_components = vec![None; log_kappa_dim];
        ssecond_components[i] = Some(s2_components);
        // Cross penalty second derivatives are produced lazily through a
        // provider closure, and only for entries that belong to a group and
        // carry a cross-penalty provider.
        let mut penaltysecond_partner_indices: Option<Vec<usize>> = None;
        let penaltysecond_component_provider =
            if let (Some(provider), Some(gid)) = (aniso_cross_penalty_provider, aniso_group_id) {
                let group_indices = group_indices_map.get(&gid).cloned().unwrap_or_default();
                // Position of this entry inside its own group.
                let axis_in_group =
                    group_indices
                        .iter()
                        .position(|&idx| idx == i)
                        .ok_or_else(|| {
                            EstimationError::InvalidInput(format!(
                                "missing spatial hyper axis {} in anisotropy group {}",
                                i, gid
                            ))
                        })?;
                // Every other member of the group is a potential partner.
                penaltysecond_partner_indices = Some(
                    group_indices
                        .iter()
                        .copied()
                        .filter(|&idx| idx != i)
                        .collect(),
                );
                // Owned copies moved into the closure below.
                let penalty_indices_inner = penalty_indices.clone();
                let global_range_inner = global_range.clone();
                let total_p_inner = total_p;
                let group_indices_inner = group_indices;
                Some(std::sync::Arc::new(
                    move |j: usize| -> Result<
                        Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
                        EstimationError,
                    > {
                        // `j` is a hyper index; translate it to an axis within
                        // the group. Indices outside the group have no cross term.
                        let Some(other_axis_in_group) =
                            group_indices_inner.iter().position(|&idx| idx == j)
                        else {
                            return Ok(None);
                        };
                        // The diagonal is handled by `ssecond_components`.
                        if other_axis_in_group == axis_in_group {
                            return Ok(None);
                        }
                        let cross_pens = provider(other_axis_in_group)?;
                        if cross_pens.is_empty() {
                            return Ok(None);
                        }
                        Ok(Some(
                            penalty_indices_inner
                                .iter()
                                .copied()
                                .zip(cross_pens.into_iter().map(|local| {
                                    gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                                        local,
                                        global_range_inner.clone(),
                                        total_p_inner,
                                    )
                                }))
                                .map(|(penalty_index, matrix)| {
                                    gam_solve::estimate::reml::PenaltyDerivativeComponent {
                                        penalty_index,
                                        matrix,
                                    }
                                })
                                .collect(),
                        ))
                    },
                )
                    as std::sync::Arc<
                        dyn Fn(
                                usize,
                            ) -> Result<
                                Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
                                EstimationError,
                            > + Send
                            + Sync
                            + 'static,
                    >)
            } else {
                None
            };
        // First-order design derivative: implicit operator when available,
        // otherwise the dense local block embedded at `global_range`.
        let x_first_hyper = if let Some(ref op) = implicit_operator {
            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                op.clone(),
                ImplicitDerivLevel::First(implicit_axis),
                global_range.clone(),
                total_p,
            )
        } else {
            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                x_psi_local,
                global_range.clone(),
                total_p,
            )
        };
        let mut dir = DirectionalHyperParam::new_compact(
            x_first_hyper,
            s_components,
            Some(xsecond),
            Some(ssecond_components),
        )?
        .not_penalty_like();
        if let Some(provider) = penaltysecond_component_provider {
            dir = dir.with_penaltysecond_component_provider(provider);
        }
        if let Some(partner_indices) = penaltysecond_partner_indices {
            dir = dir.with_penaltysecond_partner_indices(partner_indices);
        }
        hyper_dirs.push(dir);
    }
    Ok(hyper_dirs)
}
1377
1378pub(crate) fn spatial_dims_per_term(
1384 resolvedspec: &TermCollectionSpec,
1385 spatial_terms: &[usize],
1386) -> Vec<usize> {
1387 spatial_terms
1388 .iter()
1389 .map(|&term_idx| {
1390 if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
1391 measure_jet_psi_dim(mj)
1394 } else if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
1395 get_spatial_feature_dim(resolvedspec, term_idx).unwrap_or(1)
1396 } else {
1397 1
1398 }
1399 })
1400 .collect()
1401}
1402
1403fn has_aniso_terms(resolvedspec: &TermCollectionSpec, spatial_terms: &[usize]) -> bool {
1407 spatial_terms
1408 .iter()
1409 .any(|&term_idx| spatial_term_uses_per_axis_psi(resolvedspec, term_idx))
1410}
1411
/// Expands to the three memoisation methods shared by caches that expose
/// `current_theta`, `last_cost` and `last_eval` fields. Invoke it inside an
/// `impl` block of such a type.
macro_rules! impl_exact_joint_theta_memo {
    () => {
        /// Cached cost for `theta`, if `theta` matches the current theta.
        /// The cost stored in `last_eval` takes precedence over `last_cost`.
        fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
            let theta_is_current = match self.current_theta.as_ref() {
                Some(cached) => theta_values_match(cached, theta),
                None => false,
            };
            if !theta_is_current {
                return None;
            }
            match self.last_eval.as_ref() {
                Some(cached) => Some(cached.0),
                None => self.last_cost,
            }
        }

        /// Cached `(cost, gradient, hessian)` for `theta`, if `theta` matches
        /// the current theta and a full evaluation has been stored.
        fn memoized_eval(
            &self,
            theta: &Array1<f64>,
        ) -> Option<(
            f64,
            Array1<f64>,
            gam_problem::HessianResult,
        )> {
            let cached_theta = self.current_theta.as_ref()?;
            if !theta_values_match(cached_theta, theta) {
                return None;
            }
            self.last_eval.clone()
        }

        /// Records a full evaluation together with its cost.
        fn store_eval(
            &mut self,
            eval: (
                f64,
                Array1<f64>,
                gam_problem::HessianResult,
            ),
        ) {
            let cost = eval.0;
            self.last_eval = Some(eval);
            self.last_cost = Some(cost);
        }
    };
}
1466
/// Keeps a frozen term-collection design in step with the outer `theta`
/// vector and memoises the most recent objective evaluation.
struct SingleBlockExactJointDesignCache<'d> {
    /// Incremental realizer that owns the spec and design; `ensure_theta`
    /// re-keys it through `apply_log_kappa`.
    realizer: FrozenTermCollectionIncrementalRealizer<'d>,
    /// Theta most recently applied to `realizer` (`None` before the first apply).
    current_theta: Option<Array1<f64>>,
    /// Theta at which `last_cost` / `last_eval` were stored; the memo lookups
    /// compare against this, not against `current_theta`.
    last_eval_theta: Option<Array1<f64>>,
    /// Cost stored by `store_eval_at` or `store_cost_at`.
    last_cost: Option<f64>,
    /// Full `(cost, gradient, hessian)` triple stored by `store_eval_at`;
    /// cleared by `store_cost_at`.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Hyper directions tagged with the design revision they were built for.
    cached_hyper_dirs: Option<(u64, Vec<DirectionalHyperParam>)>,
    /// Indices of the spatial smooth terms handled by this cache.
    spatial_terms: Vec<usize>,
    /// Offset in theta at which the ψ tail begins (`theta[rho_dim..]`).
    rho_dim: usize,
    /// Per-term dimension list passed to
    /// `SpatialLogKappaCoords::from_theta_tail_with_dims`.
    dims_per_term: Vec<usize>,
}
1499
1500impl<'d> SingleBlockExactJointDesignCache<'d> {
1501 fn new(
1502 data: ArrayView2<'d, f64>,
1503 spec: TermCollectionSpec,
1504 design: TermCollectionDesign,
1505 spatial_terms: Vec<usize>,
1506 rho_dim: usize,
1507 dims_per_term: Vec<usize>,
1508 ) -> Result<Self, String> {
1509 Ok(Self {
1510 realizer: FrozenTermCollectionIncrementalRealizer::new(data, spec, design)?,
1511 current_theta: None,
1512 last_eval_theta: None,
1513 last_cost: None,
1514 last_eval: None,
1515 cached_hyper_dirs: None,
1516 spatial_terms,
1517 rho_dim,
1518 dims_per_term,
1519 })
1520 }
1521
1522 fn design_revision(&self) -> u64 {
1523 self.realizer.design_revision()
1524 }
1525
1526 fn hyper_dirs_for_current_design(
1536 &mut self,
1537 data: ArrayView2<'_, f64>,
1538 kind: SpatialHyperKind,
1539 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1540 let revision = self.realizer.design_revision();
1541 if let Some((cached_rev, dirs)) = self.cached_hyper_dirs.as_ref()
1542 && *cached_rev == revision
1543 {
1544 return Ok(dirs.clone());
1545 }
1546 let dirs = try_build_spatial_log_kappa_hyper_dirs(
1547 data,
1548 self.realizer.spec(),
1549 self.realizer.design(),
1550 &self.spatial_terms,
1551 )?
1552 .ok_or_else(|| {
1553 EstimationError::InvalidInput(format!(
1554 "failed to build {} hyper_dirs at current {}",
1555 kind.adjective(),
1556 kind.coord_name(),
1557 ))
1558 })?;
1559 self.cached_hyper_dirs = Some((revision, dirs.clone()));
1560 Ok(dirs)
1561 }
1562
1563 fn nfree_tensor_gradient_hyper_dirs(
1564 &mut self,
1565 theta: &Array1<f64>,
1566 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1567 let psi = &theta.as_slice().ok_or_else(|| {
1568 EstimationError::InvalidInput(
1569 "nfree_tensor_gradient_hyper_dirs: theta is not contiguous".to_string(),
1570 )
1571 })?[self.rho_dim..];
1572 let (global_range, p_total, s_psi_components) = self
1573 .realizer
1574 .canonical_penalty_derivatives_at_psi(&self.spatial_terms, psi)
1575 .map_err(EstimationError::InvalidInput)?;
1576 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::zero(
1577 self.realizer.design().design.nrows(),
1578 p_total,
1579 );
1580 let components = s_psi_components
1581 .into_iter()
1582 .enumerate()
1583 .map(|(penalty_index, local)| {
1584 (
1585 penalty_index,
1586 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1587 local,
1588 global_range.clone(),
1589 p_total,
1590 ),
1591 )
1592 })
1593 .collect::<Vec<_>>();
1594 Ok(DirectionalHyperParam::new_compact(zero_x, components, None, None)?.not_penalty_like())
1595 .map(|dir| vec![dir])
1596 }
1597
1598 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1599 if self
1600 .current_theta
1601 .as_ref()
1602 .is_some_and(|cached| theta_values_match(cached, theta))
1603 {
1604 return Ok(());
1605 }
1606 let t_ensure = std::time::Instant::now();
1607 let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
1608 theta,
1609 self.rho_dim,
1610 self.dims_per_term.clone(),
1611 );
1612 self.realizer
1613 .apply_log_kappa(&log_kappa, &self.spatial_terms)?;
1614 log::info!(
1615 "[STAGE] ensure_theta (apply_log_kappa, {} terms): {:.3}s",
1616 self.spatial_terms.len(),
1617 t_ensure.elapsed().as_secs_f64(),
1618 );
1619 self.current_theta = Some(theta.clone());
1620 self.last_eval_theta = None;
1621 self.last_cost = None;
1622 self.last_eval = None;
1623 Ok(())
1624 }
1625
1626 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1633 if self
1634 .last_eval_theta
1635 .as_ref()
1636 .is_some_and(|cached| theta_values_match(cached, theta))
1637 {
1638 self.last_eval
1639 .as_ref()
1640 .map(|cached| cached.0)
1641 .or(self.last_cost)
1642 } else {
1643 None
1644 }
1645 }
1646
1647 fn memoized_eval(
1648 &self,
1649 theta: &Array1<f64>,
1650 ) -> Option<(
1651 f64,
1652 Array1<f64>,
1653 gam_problem::HessianResult,
1654 )> {
1655 if self
1656 .last_eval_theta
1657 .as_ref()
1658 .is_some_and(|cached| theta_values_match(cached, theta))
1659 {
1660 self.last_eval.clone()
1661 } else {
1662 None
1663 }
1664 }
1665
1666 fn store_eval_at(
1670 &mut self,
1671 theta: &Array1<f64>,
1672 eval: (
1673 f64,
1674 Array1<f64>,
1675 gam_problem::HessianResult,
1676 ),
1677 ) {
1678 self.last_eval_theta = Some(theta.clone());
1679 self.last_cost = Some(eval.0);
1680 self.last_eval = Some(eval);
1681 }
1682
1683 fn store_cost_at(&mut self, theta: &Array1<f64>, cost: f64) {
1686 self.last_eval_theta = Some(theta.clone());
1687 self.last_cost = Some(cost);
1688 self.last_eval = None;
1692 }
1693
1694 fn spec(&self) -> &TermCollectionSpec {
1695 self.realizer.spec()
1696 }
1697
1698 fn design(&self) -> &TermCollectionDesign {
1699 self.realizer.design()
1700 }
1701
1702 fn supports_nfree_penalty_rekey(&self) -> bool {
1708 self.realizer
1709 .supports_nfree_penalty_rekey(&self.spatial_terms)
1710 }
1711
1712 fn supports_nfree_gradient_only_routing(&self) -> bool {
1713 self.realizer
1714 .supports_nfree_gradient_only_routing(&self.spatial_terms)
1715 }
1716
1717 fn canonical_penalties_at(
1727 &mut self,
1728 theta: &Array1<f64>,
1729 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
1730 let psi = &theta
1731 .as_slice()
1732 .ok_or_else(|| "canonical_penalties_at: theta is not contiguous".to_string())?
1733 [self.rho_dim..];
1734 self.realizer
1735 .canonical_penalties_at_psi(&self.spatial_terms, psi)
1736 }
1737}
1738
/// Cache that rebuilds a term-collection design whenever the latent
/// coordinates carried inside the outer `theta` vector change, and memoises
/// the most recent objective evaluation.
struct SingleBlockLatentCoordDesignCache {
    /// Owned data matrix; `ensure_theta` overwrites the `feature_cols` columns
    /// with the current latent coordinates before rebuilding.
    data: Array2<f64>,
    /// Spec used for every design rebuild.
    spec: TermCollectionSpec,
    /// Design realized for `current_theta` (initially the design passed to `new`).
    design: TermCollectionDesign,
    /// Theta most recently realized by `ensure_theta`.
    current_theta: Option<Array1<f64>>,
    /// Latent coordinates extracted from `current_theta`.
    current_latent: Option<std::sync::Arc<gam_terms::latent::LatentCoordValues>>,
    /// Hyper directions that accompany `design`.
    current_hyper_dirs: Option<Vec<gam_solve::estimate::reml::DirectionalHyperParam>>,
    /// Entry id returned by the most recent `latent_design_cache` lookup.
    current_design_cache_id: Option<u64>,
    /// Cache of computed designs and hyper directions, queried by `ensure_theta`.
    latent_design_cache: gam_solve::latent_cache::LatentDesignCache,
    /// Cost stored by `store_eval` or `store_cost`.
    last_cost: Option<f64>,
    /// Full `(cost, gradient, hessian)` triple stored by `store_eval`.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Index of the latent-coordinate smooth term inside `spec.smooth_terms`.
    term_index: gam_problem::types::SmoothTermIdx,
    /// Columns of `data` that hold the latent coordinates, one per latent axis.
    feature_cols: Vec<usize>,
    /// Offset in theta at which the flattened latent block begins.
    rho_dim: usize,
    /// Number of observations (rows) of the latent coordinates.
    n_obs: usize,
    /// Number of latent axes per observation.
    latent_dim: usize,
    /// Identification mode copied from the initial latent values.
    id_mode: gam_terms::latent::LatentIdMode,
    /// Manifold copied from the initial latent values.
    manifold: gam_terms::latent::LatentManifold,
    /// Retraction registry copied from the initial latent values.
    retraction_registry: gam_solve::latent_cache::LatentRetractionRegistry,
    /// Latent id copied from the initial latent values.
    latent_id: u64,
    /// Optional registry of analytic penalties attached to the latent term.
    analytic_penalties: Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>>,
    /// `total_rho_count()` of `analytic_penalties`, or 0 when there is none.
    analytic_rho_count: usize,
    /// Counter bumped by `ensure_theta` when the realized design changes.
    design_revision: u64,
    /// Outer iteration at which `last_cost` / `last_eval` were stored; the
    /// memo lookups require it to equal the current outer iteration.
    last_outer_iter: Option<u64>,
}
1771
1772impl SingleBlockLatentCoordDesignCache {
1773 fn new(
1774 data: Array2<f64>,
1775 spec: TermCollectionSpec,
1776 design: TermCollectionDesign,
1777 latent: &StandardLatentCoordConfig,
1778 rho_dim: usize,
1779 ) -> Result<Self, String> {
1780 if latent.term_index.get() >= spec.smooth_terms.len() {
1781 return Err(SmoothError::dimension_mismatch(format!(
1782 "latent-coordinate term index {} out of bounds for {} smooth terms",
1783 latent.term_index,
1784 spec.smooth_terms.len()
1785 ))
1786 .into());
1787 }
1788 if latent.feature_cols.len() != latent.values.latent_dim() {
1789 return Err(SmoothError::dimension_mismatch(format!(
1790 "latent-coordinate feature width mismatch: feature_cols={}, latent_dim={}",
1791 latent.feature_cols.len(),
1792 latent.values.latent_dim()
1793 ))
1794 .into());
1795 }
1796 if latent.values.n_obs() != data.nrows() {
1797 return Err(SmoothError::dimension_mismatch(format!(
1798 "latent-coordinate row mismatch: latent n={}, data n={}",
1799 latent.values.n_obs(),
1800 data.nrows()
1801 ))
1802 .into());
1803 }
1804 let analytic_rho_count = latent
1805 .analytic_penalties
1806 .as_ref()
1807 .map_or(0, |registry| registry.total_rho_count());
1808 Ok(Self {
1809 data,
1810 spec,
1811 design,
1812 current_theta: None,
1813 current_latent: None,
1814 current_hyper_dirs: None,
1815 current_design_cache_id: None,
1816 latent_design_cache: gam_solve::latent_cache::LatentDesignCache::default(),
1817 last_cost: None,
1818 last_eval: None,
1819 term_index: latent.term_index,
1820 feature_cols: latent.feature_cols.clone(),
1821 rho_dim,
1822 n_obs: latent.values.n_obs(),
1823 latent_dim: latent.values.latent_dim(),
1824 id_mode: latent.values.id_mode().clone(),
1825 manifold: latent.values.manifold().clone(),
1826 retraction_registry: latent.values.retraction_registry().clone(),
1827 latent_id: latent.values.latent_id(),
1828 analytic_penalties: latent.analytic_penalties.clone(),
1829 analytic_rho_count,
1830 design_revision: 0,
1831 last_outer_iter: None,
1832 })
1833 }
1834
1835 fn design_revision(&self) -> u64 {
1836 self.design_revision
1837 }
1838
1839 fn design(&self) -> &TermCollectionDesign {
1840 &self.design
1841 }
1842
1843 fn latent(&self) -> Result<std::sync::Arc<gam_terms::latent::LatentCoordValues>, String> {
1844 self.current_latent
1845 .as_ref()
1846 .cloned()
1847 .ok_or_else(|| "latent-coordinate cache has not been realized".to_string())
1848 }
1849
1850 fn analytic_penalties(&self) -> Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>> {
1851 self.analytic_penalties.clone()
1852 }
1853
1854 fn analytic_penalty_rho_count(&self) -> usize {
1855 self.analytic_rho_count
1856 }
1857
1858 fn hyper_dirs(&self) -> Result<Vec<gam_solve::estimate::reml::DirectionalHyperParam>, String> {
1859 self.current_hyper_dirs
1860 .as_ref()
1861 .cloned()
1862 .ok_or_else(|| "latent-coordinate hyper_dirs cache has not been realized".to_string())
1863 }
1864
1865 fn latent_basis_kind(&self) -> Result<gam_solve::latent_cache::LatentBasisKind, String> {
1866 let smooth_term = self
1867 .design
1868 .smooth
1869 .terms
1870 .get(self.term_index.get())
1871 .ok_or_else(|| {
1872 SmoothError::dimension_mismatch(format!(
1873 "LatentCoord term index {} out of bounds for realized smooth design",
1874 self.term_index
1875 ))
1876 })?;
1877 let termspec = self
1878 .spec
1879 .smooth_terms
1880 .get(self.term_index.get())
1881 .ok_or_else(|| {
1882 SmoothError::dimension_mismatch(format!(
1883 "LatentCoord term index {} out of bounds for resolved smooth spec",
1884 self.term_index
1885 ))
1886 })?;
1887 match (&termspec.basis, &smooth_term.metadata) {
1888 (
1889 SmoothBasisSpec::Matern { .. },
1890 BasisMetadata::Matern {
1891 centers,
1892 length_scale,
1893 nu,
1894 aniso_log_scales,
1895 ..
1896 },
1897 ) => Ok(gam_solve::latent_cache::LatentBasisKind::Matern {
1898 centers: centers.clone(),
1899 length_scale: *length_scale,
1900 nu: *nu,
1901 aniso_log_scales: aniso_log_scales
1902 .clone()
1903 .unwrap_or_else(|| vec![0.0; centers.ncols()]),
1904 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1905 self.n_obs,
1906 centers.nrows(),
1907 ),
1908 }),
1909 (
1910 SmoothBasisSpec::Duchon { .. },
1911 BasisMetadata::Duchon {
1912 centers,
1913 length_scale,
1914 power,
1915 nullspace_order,
1916 aniso_log_scales,
1917 ..
1918 },
1919 ) => Ok(gam_solve::latent_cache::LatentBasisKind::Duchon {
1920 centers: centers.clone(),
1921 length_scale: *length_scale,
1922 power: *power,
1923 nullspace_order: *nullspace_order,
1924 aniso_log_scales: aniso_log_scales
1925 .clone()
1926 .unwrap_or_else(|| vec![0.0; centers.ncols()]),
1927 }),
1928 (
1929 SmoothBasisSpec::Sphere { .. },
1930 BasisMetadata::Sphere {
1931 centers,
1932 penalty_order,
1933 method,
1934 ..
1935 },
1936 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
1937 Ok(gam_solve::latent_cache::LatentBasisKind::Sphere {
1938 centers: centers.clone(),
1939 penalty_order: *penalty_order,
1940 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1941 self.n_obs,
1942 centers.nrows(),
1943 ),
1944 })
1945 }
1946 (
1947 SmoothBasisSpec::BSpline1D { spec, .. },
1948 BasisMetadata::BSpline1D {
1949 knots,
1950 periodic,
1951 degree: meta_degree,
1952 ..
1953 },
1954 ) => {
1955 let effective_degree = meta_degree.unwrap_or(spec.degree);
1959 if let Some((domain_start, period, num_basis)) = periodic {
1960 Ok(
1961 gam_solve::latent_cache::LatentBasisKind::PeriodicBspline {
1962 domain_start: *domain_start,
1963 period: *period,
1964 degree: effective_degree,
1965 num_basis: *num_basis,
1966 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1967 self.n_obs, *num_basis,
1968 ),
1969 },
1970 )
1971 } else {
1972 let num_basis_est = knots.len().saturating_sub(effective_degree + 1);
1973 Ok(
1974 gam_solve::latent_cache::LatentBasisKind::TensorBspline {
1975 knots: vec![knots.clone()],
1976 degrees: vec![effective_degree],
1977 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1978 self.n_obs,
1979 num_basis_est,
1980 ),
1981 },
1982 )
1983 }
1984 }
1985 (
1986 SmoothBasisSpec::TensorBSpline { .. },
1987 BasisMetadata::TensorBSpline { knots, degrees, .. },
1988 ) => Ok(
1989 gam_solve::latent_cache::LatentBasisKind::TensorBspline {
1990 knots: knots.clone(),
1991 degrees: degrees.clone(),
1992 chunk_size: None,
1993 },
1994 ),
1995 (
1996 SmoothBasisSpec::Pca { .. },
1997 BasisMetadata::Pca {
1998 basis_matrix,
1999 centered,
2000 smooth_penalty,
2001 center_mean,
2002 pca_basis_path,
2003 chunk_size,
2004 ..
2005 },
2006 ) => {
2007 let center_mean_fingerprint = if *centered && pca_basis_path.is_none() {
2008 let mean = center_mean.as_ref().ok_or_else(|| {
2009 SmoothError::invalid_config(
2010 "latent-coordinate Pca cache key requires center_mean when centered",
2011 )
2012 })?;
2013 Some(gam_solve::latent_cache::pca_center_mean_fingerprint(
2014 mean,
2015 ))
2016 } else {
2017 None
2018 };
2019 Ok(gam_solve::latent_cache::LatentBasisKind::Pca {
2020 basis_matrix: basis_matrix.clone(),
2021 centered: *centered,
2022 center_mean_fingerprint,
2023 smooth_penalty: *smooth_penalty,
2024 pca_basis_path: pca_basis_path.clone(),
2025 chunk_size: *chunk_size,
2026 })
2027 }
2028 _ => Err(SmoothError::invalid_config(
2029 "latent-coordinate design cache could not key the realized latent smooth basis"
2030 .to_string(),
2031 )
2032 .into()),
2033 }
2034 }
2035
2036 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
2037 if self
2038 .current_theta
2039 .as_ref()
2040 .is_some_and(|cached| theta_values_match(cached, theta))
2041 {
2042 return Ok(());
2043 }
2044 let latent_flat_len = self.n_obs * self.latent_dim;
2045 let direct_hyper_count = latent_coord_direct_hyper_count(&self.id_mode, self.latent_dim);
2046 let expected =
2047 self.rho_dim + latent_flat_len + self.analytic_rho_count + direct_hyper_count;
2048 if theta.len() != expected {
2049 return Err(SmoothError::dimension_mismatch(format!(
2050 "latent-coordinate theta length mismatch: got {}, expected {} (rho_dim={}, n={}, d={}, analytic_rhos={}, direct_hypers={})",
2051 theta.len(),
2052 expected,
2053 self.rho_dim,
2054 self.n_obs,
2055 self.latent_dim,
2056 self.analytic_rho_count,
2057 direct_hyper_count
2058 ))
2059 .into());
2060 }
2061 let flat = theta
2062 .slice(s![self.rho_dim..self.rho_dim + latent_flat_len])
2063 .to_owned();
2064 let latent = std::sync::Arc::new(
2065 gam_terms::latent::LatentCoordValues::from_flat_with_manifold_and_retraction_and_id(
2066 flat,
2067 self.n_obs,
2068 self.latent_dim,
2069 self.id_mode.clone(),
2070 self.manifold.clone(),
2071 self.retraction_registry.clone(),
2072 self.latent_id,
2073 ),
2074 );
2075 let latent_values_changed = self
2076 .current_latent
2077 .as_ref()
2078 .map(|cached| !latent_values_match(cached.as_flat(), latent.as_flat()))
2079 .unwrap_or(true);
2080 if latent_values_changed {
2081 self.latent_design_cache.invalidate_all();
2082 self.current_design_cache_id = None;
2083 self.design_revision = self.design_revision.wrapping_add(1);
2084 }
2085 for n in 0..self.n_obs {
2086 for axis in 0..self.latent_dim {
2087 let col = self.feature_cols[axis];
2088 self.data[[n, col]] = latent.as_flat()[n * self.latent_dim + axis];
2089 }
2090 }
2091
2092 let basis_kind = self.latent_basis_kind()?;
2093 let rebuilt_width = self.design.design.ncols();
2094 let spec = self.spec.clone();
2095 let term_index = self.term_index;
2096 let analytic_rho_count = self.analytic_rho_count;
2097 let data = self.data.view();
2098 let design_context_digest =
2099 gam_solve::latent_cache::latent_design_context_cache_digest(
2100 data,
2101 &spec,
2102 term_index,
2103 analytic_rho_count,
2104 &self.feature_cols,
2105 )
2106 .map_err(|e| e.to_string())?;
2107 let lookup = self
2108 .latent_design_cache
2109 .lookup_or_compute(latent.clone(), basis_kind, design_context_digest, || {
2110 let rebuilt = build_term_collection_design(data, &spec).map_err(|e| {
2111 EstimationError::InvalidInput(format!(
2112 "failed to rebuild latent-coordinate design: {e}"
2113 ))
2114 })?;
2115 if rebuilt.design.ncols() != rebuilt_width {
2116 crate::bail_invalid_estim!(
2117 "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
2118 rebuilt.design.ncols(),
2119 rebuilt_width
2120 );
2121 }
2122 let hyper_dirs = try_build_latent_coord_hyper_dirs(
2123 latent.clone(),
2124 &spec,
2125 &rebuilt,
2126 &[term_index],
2127 analytic_rho_count,
2128 )?
2129 .ok_or_else(|| {
2130 EstimationError::InvalidInput(
2131 "failed to build latent-coordinate hyper_dirs".to_string(),
2132 )
2133 })?;
2134 Ok(gam_solve::latent_cache::ComputedLatentDesign {
2135 design: rebuilt,
2136 hyper_dirs,
2137 })
2138 })
2139 .map_err(|e| e.to_string())?;
2140 if lookup.cached.design.design.ncols() != self.design.design.ncols() {
2141 return Err(SmoothError::dimension_mismatch(format!(
2142 "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
2143 lookup.cached.design.design.ncols(),
2144 self.design.design.ncols()
2145 ))
2146 .into());
2147 }
2148 self.design = lookup.cached.design.clone();
2149 self.current_hyper_dirs = Some(lookup.cached.hyper_dirs.clone());
2150 self.current_latent = Some(latent);
2151 self.current_theta = Some(theta.clone());
2152 self.last_cost = None;
2153 self.last_eval = None;
2154 self.last_outer_iter = None;
2155 if !latent_values_changed && self.current_design_cache_id != Some(lookup.entry_id) {
2156 self.design_revision = self.design_revision.wrapping_add(1);
2157 }
2158 self.current_design_cache_id = Some(lookup.entry_id);
2159 Ok(())
2160 }
2161
2162 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
2163 if self
2164 .current_theta
2165 .as_ref()
2166 .is_some_and(|cached| theta_values_match(cached, theta))
2167 && self.last_outer_iter
2168 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
2169 {
2170 self.last_eval
2171 .as_ref()
2172 .map(|cached| cached.0)
2173 .or(self.last_cost)
2174 } else {
2175 None
2176 }
2177 }
2178
2179 fn memoized_eval(
2180 &self,
2181 theta: &Array1<f64>,
2182 ) -> Option<(
2183 f64,
2184 Array1<f64>,
2185 gam_problem::HessianResult,
2186 )> {
2187 if self
2188 .current_theta
2189 .as_ref()
2190 .is_some_and(|cached| theta_values_match(cached, theta))
2191 && self.last_outer_iter
2192 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
2193 {
2194 self.last_eval.clone()
2195 } else {
2196 None
2197 }
2198 }
2199
2200 fn store_eval(
2201 &mut self,
2202 eval: (
2203 f64,
2204 Array1<f64>,
2205 gam_problem::HessianResult,
2206 ),
2207 ) {
2208 self.last_cost = Some(eval.0);
2209 self.last_eval = Some(eval);
2210 self.last_outer_iter =
2211 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
2212 }
2213
2214 fn store_cost(&mut self, cost: f64) {
2215 self.last_cost = Some(cost);
2216 self.last_outer_iter =
2217 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
2218 }
2219
2220 fn reset(&mut self) {
2221 self.current_theta = None;
2222 self.current_latent = None;
2223 self.current_hyper_dirs = None;
2224 self.current_design_cache_id = None;
2225 self.latent_design_cache.invalidate();
2226 self.last_cost = None;
2227 self.last_eval = None;
2228 self.last_outer_iter = None;
2229 }
2230}
2231
2232pub fn fixed_kappa_profiled_reml_score(
2248 data: ArrayView2<'_, f64>,
2249 y: ArrayView1<'_, f64>,
2250 weights: ArrayView1<'_, f64>,
2251 offset: ArrayView1<'_, f64>,
2252 resolvedspec: &TermCollectionSpec,
2253 term_idx: usize,
2254 kappa: f64,
2255 family: LikelihoodSpec,
2256 options: &FitOptions,
2257) -> Result<f64, EstimationError> {
2258 if !kappa.is_finite() {
2259 crate::bail_invalid_estim!("fixed-κ profiled score probed a non-finite κ = {kappa}");
2260 }
2261 let (feature_cols, mut probe_basis) = match resolvedspec
2264 .smooth_terms
2265 .get(term_idx)
2266 .map(|t| &t.basis)
2267 {
2268 Some(SmoothBasisSpec::ConstantCurvature {
2269 feature_cols, spec, ..
2270 }) => (feature_cols.clone(), spec.clone()),
2271 _ => {
2272 crate::bail_invalid_estim!(
2273 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2274 )
2275 }
2276 };
2277 probe_basis.kappa = kappa;
2278
2279 let is_unweighted = weights.iter().all(|&w| (w - 1.0).abs() <= 1e-12);
2299 let is_zero_offset = offset.iter().all(|&o| o.abs() <= 1e-12);
2300 if family == LikelihoodSpec::gaussian_identity() && is_unweighted && is_zero_offset {
2301 let x_term = select_columns(data, &feature_cols).map_err(EstimationError::from)?;
2302 let score =
2303 gam_terms::basis::constant_curvature_honest_profiled_reml_score(x_term.view(), y, &probe_basis)
2304 .map_err(|e| {
2305 EstimationError::InvalidInput(format!(
2306 "fixed-κ honest profiled-REML score at κ={kappa} failed: {e}"
2307 ))
2308 })?;
2309 if !score.is_finite() {
2310 crate::bail_invalid_estim!(
2311 "fixed-κ honest profiled-REML score at κ={kappa} is non-finite"
2312 );
2313 }
2314 return Ok(score);
2315 }
2316
2317 let mut probe_spec = resolvedspec.clone();
2319 match probe_spec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis) {
2320 Some(SmoothBasisSpec::ConstantCurvature { spec, .. }) => spec.kappa = kappa,
2321 _ => {
2322 crate::bail_invalid_estim!(
2323 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2324 )
2325 }
2326 }
2327 let fixed_kappa_options = SpatialLengthScaleOptimizationOptions {
2328 enabled: false,
2329 ..SpatialLengthScaleOptimizationOptions::default()
2330 };
2331 let fit = fit_term_collectionwith_spatial_length_scale_optimization(
2332 data,
2333 y.to_owned(),
2334 weights.to_owned(),
2335 offset.to_owned(),
2336 &probe_spec,
2337 family,
2338 options,
2339 &fixed_kappa_options,
2340 )?;
2341 let score = fit_score(&fit.fit);
2342 if !score.is_finite() {
2343 crate::bail_invalid_estim!("fixed-κ profiled fit at κ={kappa} returned a non-finite score");
2344 }
2345 Ok(score)
2346}
2347
2348fn constant_curvature_kappa_fair_argmin(
2373 data: ArrayView2<'_, f64>,
2374 y: ArrayView1<'_, f64>,
2375 resolvedspec: &TermCollectionSpec,
2376 term_idx: usize,
2377) -> Option<f64> {
2378 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2379 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2380 return None;
2381 }
2382 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2383 Some(SmoothBasisSpec::ConstantCurvature {
2384 feature_cols, spec, ..
2385 }) => (feature_cols, spec.clone()),
2386 _ => return None,
2387 };
2388 let x_term = match select_columns(data, feature_cols) {
2389 Ok(x) => x,
2390 Err(e) => {
2391 log::info!("[spatial-kappa] #1464 κ-fair argmin column select failed ({e}); skipping");
2392 return None;
2393 }
2394 };
2395 const GRID_STEPS: usize = 24;
2401 let mut best: Option<(f64, f64)> = None; for i in 0..=GRID_STEPS {
2403 let t = i as f64 / GRID_STEPS as f64;
2404 let kappa = kappa_min + (kappa_max - kappa_min) * t;
2405 let mut probe_spec = base_spec.clone();
2406 probe_spec.kappa = kappa;
2407 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec) {
2408 Ok(score) => {
2409 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2410 best = Some((score, kappa));
2411 }
2412 }
2413 Err(e) => {
2414 log::info!(
2415 "[spatial-kappa] #1464 κ-fair argmin probe at κ={kappa:.4} failed ({e}); skipping"
2416 );
2417 }
2418 }
2419 }
2420 best.map(|(score, kappa)| {
2421 log::info!(
2422 "[spatial-kappa] #1464 κ-fair argmin κ̂={kappa:.4} (κ-fair score={score:.6e}) for term {term_idx}"
2423 );
2424 kappa
2425 })
2426}
2427
2428fn select_constant_curvature_kappa_sign_seed(
2436 data: ArrayView2<'_, f64>,
2437 y: ArrayView1<'_, f64>,
2438 resolvedspec: &TermCollectionSpec,
2439 term_idx: usize,
2440) -> Option<f64> {
2441 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2442 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2443 return None;
2444 }
2445 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2457 Some(SmoothBasisSpec::ConstantCurvature {
2458 feature_cols, spec, ..
2459 }) => (feature_cols, spec.clone()),
2460 _ => return None,
2461 };
2462 let x_term = match select_columns(data, feature_cols) {
2463 Ok(x) => x,
2464 Err(e) => {
2465 log::info!("[spatial-kappa] #1464 sign-basin scan column select failed ({e}); skipping");
2466 return None;
2467 }
2468 };
2469 let probes = [
2473 kappa_min,
2474 0.5 * kappa_min,
2475 0.0,
2476 0.5 * kappa_max,
2477 kappa_max,
2478 ];
2479 let mut best: Option<(f64, f64)> = None; for &kappa in &probes {
2481 let mut probe_spec = base_spec.clone();
2482 probe_spec.kappa = kappa;
2483 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(
2484 x_term.view(),
2485 y,
2486 &probe_spec,
2487 ) {
2488 Ok(score) => {
2489 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2490 best = Some((score, kappa));
2491 }
2492 }
2493 Err(e) => {
2494 log::info!(
2495 "[spatial-kappa] #1464 sign-basin probe at κ={kappa:.4} failed ({e}); skipping"
2496 );
2497 }
2498 }
2499 }
2500 best.map(|(score, kappa)| {
2501 log::info!(
2502 "[spatial-kappa] #1464 κ-fair sign-basin scan selected κ_seed={kappa:.4} \
2503 (κ-fair score={score:.6e}) for term {term_idx}"
2504 );
2505 kappa
2506 })
2507}
2508
/// Number of evenly spaced ψ grid points that
/// `prescan_isotropic_spatial_range_seed` probes per term; the first and last
/// points sit on the lower and upper ψ bounds.
const SPATIAL_RANGE_PRESCAN_GRID: usize = 7;
2512
2513fn prescan_isotropic_spatial_range_seed(
2545 data: ArrayView2<'_, f64>,
2546 y: ArrayView1<'_, f64>,
2547 weights: ArrayView1<'_, f64>,
2548 offset: ArrayView1<'_, f64>,
2549 resolvedspec: &TermCollectionSpec,
2550 baseline_score: f64,
2551 family: &LikelihoodSpec,
2552 options: &FitOptions,
2553 kappa_options: &SpatialLengthScaleOptimizationOptions,
2554 spatial_terms: &[usize],
2555) -> Result<Vec<(usize, f64)>, EstimationError> {
2556 if has_aniso_terms(resolvedspec, spatial_terms)
2558 || !constant_curvature_term_indices(resolvedspec).is_empty()
2559 {
2560 return Ok(Vec::new());
2561 }
2562 let dims = spatial_dims_per_term(resolvedspec, spatial_terms);
2563 let mut working = resolvedspec.clone();
2567 let mut best_score = if baseline_score.is_finite() {
2568 baseline_score
2569 } else {
2570 f64::INFINITY
2571 };
2572 let mut overrides: Vec<(usize, f64)> = Vec::new();
2573 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2574 if dims.get(slot).copied().unwrap_or(1) != 1 {
2577 continue;
2578 }
2579 if get_spatial_length_scale(&working, term_idx).is_none() {
2582 continue;
2583 }
2584 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &working, term_idx, kappa_options);
2585 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2586 continue;
2587 }
2588 let mut term_best: Option<f64> = None;
2589 for g in 0..SPATIAL_RANGE_PRESCAN_GRID {
2590 let frac = g as f64 / (SPATIAL_RANGE_PRESCAN_GRID - 1) as f64;
2591 let psi = psi_lo + (psi_hi - psi_lo) * frac;
2592 let ls = (-psi).exp();
2596 if !ls.is_finite() || ls <= 0.0 {
2597 continue;
2598 }
2599 let mut probe = working.clone();
2600 if set_spatial_length_scale(&mut probe, term_idx, ls).is_err() {
2601 continue;
2602 }
2603 let fit = match fit_term_collection_forspec(
2612 data,
2613 y,
2614 weights,
2615 offset,
2616 &probe,
2617 family.clone(),
2618 options,
2619 ) {
2620 Ok(fit) => fit,
2621 Err(_) => continue,
2624 };
2625 let score = fit_score(&fit.fit);
2626 if score.is_finite() && score < best_score - 1e-7 * best_score.abs().max(1.0) {
2629 best_score = score;
2630 term_best = Some(ls);
2631 }
2632 }
2633 if let Some(ls) = term_best {
2634 set_spatial_length_scale(&mut working, term_idx, ls)?;
2635 overrides.push((term_idx, ls));
2636 log::info!(
2637 "[spatial-kappa] #1074 range pre-scan: term {term_idx} re-seeded at \
2638 length_scale={ls:.5} (profiled REML {best_score:.5}, was {baseline_score:.5})"
2639 );
2640 }
2641 }
2642 Ok(overrides)
2643}
2644
/// Five fractions in `[0, 1]`, from the lower end of a window (`0.0`) to the
/// upper end (`1.0`).
///
/// NOTE(review): the consumer is not visible in this part of the file;
/// presumably each entry is passed as `fraction` to
/// `joint_solve_from_window_fraction` to restart the joint solve from a
/// different point of the ψ window — confirm against the caller.
const JOINT_RESTART_WINDOW_FRACTIONS: [f64; 5] = [0.0, 0.2, 0.45, 0.7, 1.0];
2654
2655fn joint_solve_from_window_fraction(
2671 data: ArrayView2<'_, f64>,
2672 y: ArrayView1<'_, f64>,
2673 weights: ArrayView1<'_, f64>,
2674 offset: ArrayView1<'_, f64>,
2675 base_spec: &TermCollectionSpec,
2676 spatial_terms: &[usize],
2677 fraction: f64,
2678 family: &LikelihoodSpec,
2679 options: &FitOptions,
2680 baseline_options: &FitOptions,
2681 kappa_options: &SpatialLengthScaleOptimizationOptions,
2682) -> Result<Option<(FittedTermCollectionWithSpec, f64)>, EstimationError> {
2683 let mut seed_spec = base_spec.clone();
2684 let mut any_set = false;
2685 for &term_idx in spatial_terms {
2686 if get_spatial_length_scale(&seed_spec, term_idx).is_none() {
2687 continue;
2688 }
2689 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &seed_spec, term_idx, kappa_options);
2690 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2691 continue;
2692 }
2693 let psi = psi_lo + (psi_hi - psi_lo) * fraction;
2694 let ls = (-psi).exp();
2695 if !ls.is_finite() || ls <= 0.0 {
2696 continue;
2697 }
2698 if set_spatial_length_scale(&mut seed_spec, term_idx, ls).is_ok() {
2699 any_set = true;
2700 }
2701 }
2702 if !any_set {
2703 return Ok(None);
2704 }
2705 let seed_best = match fit_term_collection_forspec(
2709 data,
2710 y,
2711 weights,
2712 offset,
2713 &seed_spec,
2714 family.clone(),
2715 baseline_options,
2716 ) {
2717 Ok(fit) => fit,
2718 Err(_) => return Ok(None),
2719 };
2720 let seed_spec = freeze_term_collection_from_design(&seed_spec, &seed_best.design)?;
2721 let seed_terms = spatial_length_scale_term_indices(&seed_spec);
2724 if seed_terms.is_empty() {
2725 let score = fit_score(&seed_best.fit);
2726 return Ok(Some((
2727 FittedTermCollectionWithSpec {
2728 fit: seed_best.fit,
2729 design: seed_best.design,
2730 resolvedspec: seed_spec,
2731 adaptive_diagnostics: seed_best.adaptive_diagnostics,
2732 kappa_timing: None,
2733 },
2734 score,
2735 )));
2736 }
2737 let joint = try_exact_joint_spatial_length_scale_optimization(
2738 data,
2739 y,
2740 weights,
2741 offset,
2742 &seed_spec,
2743 &seed_best,
2744 family.clone(),
2745 options,
2746 kappa_options,
2747 &seed_terms,
2748 )?;
2749 match joint {
2750 Some(fit) => {
2751 let score = fit_score(&fit.fit);
2752 Ok(Some((fit, score)))
2753 }
2754 None => {
2757 let score = fit_score(&seed_best.fit);
2758 Ok(Some((
2759 FittedTermCollectionWithSpec {
2760 fit: seed_best.fit,
2761 design: seed_best.design,
2762 resolvedspec: seed_spec,
2763 adaptive_diagnostics: seed_best.adaptive_diagnostics,
2764 kappa_timing: None,
2765 },
2766 score,
2767 )))
2768 }
2769 }
2770}
2771
/// Attempts the exact joint optimisation of the log smoothing parameters (ρ)
/// and the spatial hyper-coordinates of `spatial_terms`, starting from `best`.
///
/// Outline:
/// 1. Return `Ok(None)` for an empty `spatial_terms`; validate `kappa_options`.
/// 2. If every spatial term is a constant-curvature smooth and the κ-fair
///    argmin picks a negative κ for at least one of them, fix those κ values,
///    refit with `fit_term_collection_forspec`, and return that fit directly.
/// 3. Return `Ok(None)` when the hyper-direction bundle cannot be built for the
///    current design.
/// 4. Build the start point and box bounds (ρ from `best.fit.lambdas`, spatial
///    coordinates from the spec and data; constant-curvature slots are seeded
///    by the sign-basin scan) and run `run_exact_joint_spatial_optimization`.
/// 5. Return the frozen baseline geometry when the solve did not converge,
///    when its objective is worse than the baseline score beyond a tolerance,
///    or when the refit at the optimum collapses in effective degrees of
///    freedom while the baseline does not. Otherwise return the refit at the
///    optimum, with `reml_score` set to the joint objective value.
///
/// # Errors
/// Propagates invalid `kappa_options`, and failures from the fits, the
/// hyper-direction construction, the joint solve, and applying the optimised
/// coordinates back to the spec.
fn try_exact_joint_spatial_length_scale_optimization(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    best: &FittedTermCollection,
    family: LikelihoodSpec,
    options: &FitOptions,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
    spatial_terms: &[usize],
) -> Result<Option<FittedTermCollectionWithSpec>, EstimationError> {
    if spatial_terms.is_empty() {
        return Ok(None);
    }
    kappa_options
        .validate()
        .map_err(EstimationError::InvalidInput)?;

    // Fixed-κ shortcut: only taken when *all* spatial terms are
    // constant-curvature smooths.
    let cc_term_set = constant_curvature_term_indices(resolvedspec);
    let all_spatial_are_cc =
        !cc_term_set.is_empty() && spatial_terms.iter().all(|t| cc_term_set.contains(t));
    if all_spatial_are_cc {
        let mut fixed_kappa_spec = resolvedspec.clone();
        let mut any_kappa_chosen = false;
        for &term_idx in spatial_terms {
            // Only a strictly negative κ̂ from the κ-fair argmin is adopted.
            if let Some(kappa_hat) =
                constant_curvature_kappa_fair_argmin(data, y, resolvedspec, term_idx)
                    .filter(|&k| k < 0.0)
            {
                if let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) = fixed_kappa_spec
                    .smooth_terms
                    .get_mut(term_idx)
                    .map(|t| &mut t.basis)
                {
                    cc.kappa = kappa_hat;
                    any_kappa_chosen = true;
                    log::info!(
                        "[spatial-kappa] #1464 term {term_idx}: fixed κ̂ = {kappa_hat:.4} from κ-fair argmin (hyperbolic basin; profiling ρ only)"
                    );
                }
            }
        }
        if any_kappa_chosen {
            let baseline_score = fit_score(&best.fit);
            let fitted = fit_term_collection_forspec(
                data,
                y,
                weights,
                offset,
                &fixed_kappa_spec,
                family.clone(),
                options,
            )?;
            let frozen_spec =
                freeze_term_collection_from_design(&fixed_kappa_spec, &fitted.design)?;
            let mut fit = fitted.fit;
            // The returned fit reports the incoming baseline score, not the
            // refit's own score.
            // NOTE(review): presumably so a caller comparing scores does not
            // reject the fixed-κ geometry — confirm against the caller.
            fit.reml_score = baseline_score;
            return Ok(Some(FittedTermCollectionWithSpec {
                fit,
                design: fitted.design,
                resolvedspec: frozen_spec,
                adaptive_diagnostics: fitted.adaptive_diagnostics,
                kappa_timing: None,
            }));
        }
    }

    // The joint solve needs hyper-directions; without them there is no joint
    // path for this design.
    if try_build_spatial_log_kappa_hyper_dirs(data, resolvedspec, &best.design, spatial_terms)?
        .is_none()
    {
        if !constant_curvature_term_indices(resolvedspec).is_empty() {
            log::info!(
                "[#1464-trace] try_exact_joint RETURNED None (hyper_dirs unavailable); \
                 κ̂ comes from a NON-joint path"
            );
        }
        return Ok(None);
    }
    if !constant_curvature_term_indices(resolvedspec).is_empty() {
        log::info!(
            "[#1464-trace] try_exact_joint ENTERED for {} spatial term(s); CC present",
            spatial_terms.len()
        );
    }

    const JOINT_RHO_BOUND: f64 = 12.0;
    let rho_dim = best.fit.lambdas.len();

    // With a constant-curvature term present the upper ρ bound comes from
    // `gam_solve::estimate::RHO_BOUND`; the lower bound is always
    // `-JOINT_RHO_BOUND` (see the setup below).
    let has_constant_curvature_term = !constant_curvature_term_indices(resolvedspec).is_empty();
    let rho_upper_bound = if has_constant_curvature_term {
        gam_solve::estimate::RHO_BOUND
    } else {
        JOINT_RHO_BOUND
    };

    let dims_per_term = spatial_dims_per_term(resolvedspec, spatial_terms);
    let use_aniso = has_aniso_terms(resolvedspec, spatial_terms);

    // Initial spatial coordinates from the spec's length scales, then re-seeded
    // from the data.
    let log_kappa0 = if use_aniso {
        SpatialLogKappaCoords::from_length_scales_aniso(resolvedspec, spatial_terms, kappa_options)
    } else {
        SpatialLogKappaCoords::from_length_scales(resolvedspec, spatial_terms, kappa_options)
    };
    let mut log_kappa0 =
        log_kappa0.reseed_from_data(data, resolvedspec, spatial_terms, kappa_options);
    // (slot, κ_seed) for every constant-curvature term the sign-basin scan
    // produced a seed for.
    let mut cc_sign_seeds: Vec<(usize, f64)> = Vec::new();
    if has_constant_curvature_term {
        for (slot, &term_idx) in spatial_terms.iter().enumerate() {
            if constant_curvature_term_spec(resolvedspec, term_idx).is_none() {
                continue;
            }
            let scan = select_constant_curvature_kappa_sign_seed(
                data,
                y,
                resolvedspec,
                term_idx,
            );
            match scan {
                Some(kappa_seed) => {
                    log::info!(
                        "[#1464-trace] term {term_idx}: κ-fair sign-basin scan picked κ_seed = {kappa_seed}"
                    );
                    log_kappa0.set_scalar_slot(slot, kappa_seed);
                    cc_sign_seeds.push((slot, kappa_seed));
                }
                None => {
                    log::info!(
                        "[#1464-trace] term {term_idx}: fixed-κ sign-basin scan returned NONE (no seed applied)"
                    );
                }
            }
        }
    }
    let log_kappa_lower = if use_aniso {
        SpatialLogKappaCoords::lower_bounds_aniso_from_data(
            data,
            resolvedspec,
            spatial_terms,
            &dims_per_term,
            kappa_options,
        )
    } else {
        SpatialLogKappaCoords::lower_bounds_from_data(
            data,
            resolvedspec,
            spatial_terms,
            kappa_options,
        )
    };
    let log_kappa_upper = if use_aniso {
        SpatialLogKappaCoords::upper_bounds_aniso_from_data(
            data,
            resolvedspec,
            spatial_terms,
            &dims_per_term,
            kappa_options,
        )
    } else {
        SpatialLogKappaCoords::upper_bounds_from_data(
            data,
            resolvedspec,
            spatial_terms,
            kappa_options,
        )
    };
    let mut log_kappa_lower = log_kappa_lower;
    let mut log_kappa_upper = log_kappa_upper;
    // A non-zero κ seed pins its slot: lower == upper == κ_seed, so the joint
    // solve cannot move that coordinate.
    // NOTE(review): the "FROZE" trace below is emitted even when
    // `kappa_seed == 0.0`, in which case the bounds are left untouched.
    for &(slot, kappa_seed) in &cc_sign_seeds {
        if kappa_seed != 0.0 {
            log_kappa_lower.set_scalar_slot(slot, kappa_seed);
            log_kappa_upper.set_scalar_slot(slot, kappa_seed);
        }
        log::info!(
            "[#1464-trace] slot {slot}: FROZE joint ψ coordinate at κ_seed={kappa_seed} \
             (window [{}, {}]); raw fit_score is sign-blind so the κ-fair scan is authoritative",
            log_kappa_lower.as_array()[log_kappa_lower.dims_per_term()[..slot].iter().sum::<usize>()],
            log_kappa_upper.as_array()[log_kappa_upper.dims_per_term()[..slot].iter().sum::<usize>()],
        );
    }
    let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
    // θ = (ρ, spatial coordinates): ρ starts at ln(λ) of the baseline fit.
    let setup = ExactJointHyperSetup::new(
        best.fit.lambdas.mapv(f64::ln),
        Array1::<f64>::from_elem(rho_dim, -JOINT_RHO_BOUND),
        Array1::<f64>::from_elem(rho_dim, rho_upper_bound),
        log_kappa0,
        log_kappa_lower,
        log_kappa_upper,
    );

    let theta0 = setup.theta0();
    let lower = setup.lower();
    let upper = setup.upper();

    let kind = if use_aniso {
        SpatialHyperKind::Anisotropic
    } else {
        SpatialHyperKind::Isotropic
    };
    let (outcome, kappa_timing) = run_exact_joint_spatial_optimization(
        kind,
        data,
        y,
        weights,
        offset,
        resolvedspec,
        &best.design,
        family.clone(),
        options,
        spatial_terms,
        &dims_per_term,
        &theta0,
        &lower,
        &upper,
        rho_dim,
        kappa_options,
    )?;

    let baseline_score = fit_score(&best.fit);

    // A non-converged solve is never adopted: fall back to the baseline
    // geometry, carrying the timing of the attempted solve.
    let (theta_star, joint_final_value) = match outcome {
        SpatialJointOutcome::Optimized {
            theta_star,
            final_value,
        } => (theta_star, final_value),
        SpatialJointOutcome::NonConverged {
            iterations,
            final_value,
            final_grad_norm,
        } => {
            if has_constant_curvature_term {
                log::info!(
                    "[#1464-trace] joint solve NONCONVERGED (iters={iterations}, \
                     final_value={final_value}); returning FROZEN BASELINE geometry \
                     (κ̂ = spec default, NOT the joint candidate)"
                );
            }
            log::info!(
                "[spatial-kappa] joint spatial optimization did not converge \
                 (iterations={}, final_objective={:.6e}, final_grad_norm={}); \
                 keeping the frozen baseline geometry",
                iterations,
                final_value,
                final_grad_norm.map_or_else(|| "n/a".to_string(), |g| format!("{g:.3e}")),
            );
            return Ok(Some(fit_frozen_baseline_geometry(
                data,
                y,
                weights,
                offset,
                resolvedspec,
                best,
                family,
                options,
                baseline_score,
                Some(kappa_timing),
            )?));
        }
    };

    // Acceptance tolerance: the largest of the fit tolerance, a relative 1e-8
    // of the baseline score, and an absolute floor of 1e-12.
    let accept_tol = options.tol.max(1e-8 * baseline_score.abs()).max(1e-12);
    if joint_final_value > baseline_score + accept_tol {
        if has_constant_curvature_term {
            log::info!(
                "[#1464-trace] joint candidate WORSENED score (joint={joint_final_value}, \
                 baseline={baseline_score}); returning FROZEN BASELINE geometry \
                 (κ̂ = spec default, NOT the joint candidate)"
            );
        }
        log::info!(
            "[spatial-kappa] exact joint spatial candidate worsened the profiled score (joint={:.6e}, baseline={:.6e}, tol={:.2e}); keeping the frozen baseline geometry",
            joint_final_value,
            baseline_score,
            accept_tol,
        );
        return Ok(Some(fit_frozen_baseline_geometry(
            data,
            y,
            weights,
            offset,
            resolvedspec,
            best,
            family,
            options,
            baseline_score,
            Some(kappa_timing),
        )?));
    }

    // Split θ* back into λ = exp(ρ*) and the spatial coordinates.
    let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
    let log_kappa_star =
        SpatialLogKappaCoords::from_theta_tail_with_dims(&theta_star, rho_dim, dims_per_term);
    if has_constant_curvature_term {
        let star = log_kappa_star.as_array();
        let dims = log_kappa_star.dims_per_term();
        for (slot, &term_idx) in spatial_terms.iter().enumerate() {
            if constant_curvature_term_spec(resolvedspec, term_idx).is_some() {
                // Offset of this term's first coordinate within the flat tail.
                let off: usize = dims[..slot].iter().sum();
                log::info!(
                    "[#1464-trace] term {term_idx}: joint solver CONVERGED ψ-tail κ = {} \
                     (this is the optimised candidate; joint_final_value={joint_final_value})",
                    star[off]
                );
            }
        }
    }
    let baseline_spec = resolvedspec;
    let optimized_spec = log_kappa_star.apply_tospec(resolvedspec, spatial_terms)?;
    // Refit at the optimised geometry, warm-started from λ*.
    let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
        data,
        y,
        weights,
        offset,
        &optimized_spec,
        rho_star.as_slice(),
        family.clone(),
        options,
    )?;

    // Collapse guard: if the optimised refit's total edf falls below the floor
    // while the baseline geometry keeps at least `SPATIAL_COLLAPSE_EDF_MARGIN`
    // more, keep the baseline.
    let optimized_edf = optimized.fit.inference.as_ref().map(|inf| inf.edf_total);
    if let Some(opt_edf) = optimized_edf
        && opt_edf < SPATIAL_COLLAPSE_EDF_FLOOR
    {
        let baseline = fit_frozen_baseline_geometry(
            data,
            y,
            weights,
            offset,
            baseline_spec,
            best,
            family.clone(),
            options,
            baseline_score,
            Some(kappa_timing),
        )?;
        let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
        if let Some(base_edf) = baseline_edf
            && base_edf >= opt_edf + SPATIAL_COLLAPSE_EDF_MARGIN
        {
            log::info!(
                "[spatial-kappa] joint candidate collapsed to the null (edf={opt_edf:.3}); \
                 baseline geometry retains edf={base_edf:.3} — keeping the frozen baseline",
            );
            return Ok(Some(baseline));
        }
    }

    let mut fit = optimized.fit;
    // Report the joint objective value, not the refit's own score.
    fit.reml_score = joint_final_value;
    let optimized_result = FittedTermCollectionWithSpec {
        fit,
        design: optimized.design,
        resolvedspec: optimized_spec,
        adaptive_diagnostics: optimized.adaptive_diagnostics,
        kappa_timing: Some(kappa_timing),
    };

    Ok(Some(optimized_result))
}
3292
/// Total effective degrees of freedom below which a refit is treated as a
/// candidate "collapse" by the spatial collapse guards.
const SPATIAL_COLLAPSE_EDF_FLOOR: f64 = 2.5;

/// Minimum edf advantage the reference fit must hold over a sub-floor fit
/// before the collapse guards prefer the reference.
const SPATIAL_COLLAPSE_EDF_MARGIN: f64 = 1.0;
3303
3304fn fit_frozen_baseline_geometry(
3340 data: ArrayView2<'_, f64>,
3341 y: ArrayView1<'_, f64>,
3342 weights: ArrayView1<'_, f64>,
3343 offset: ArrayView1<'_, f64>,
3344 resolvedspec: &TermCollectionSpec,
3345 best: &FittedTermCollection,
3346 family: LikelihoodSpec,
3347 options: &FitOptions,
3348 baseline_score: f64,
3349 kappa_timing: Option<SpatialLengthScaleOptimizationTiming>,
3350) -> Result<FittedTermCollectionWithSpec, EstimationError> {
3351 let baseline = fit_term_collection_forspecwith_heuristic_lambdas(
3352 data,
3353 y,
3354 weights,
3355 offset,
3356 resolvedspec,
3357 best.fit.lambdas.as_slice(),
3358 family.clone(),
3359 options,
3360 )?;
3361 let best_edf = best.fit.inference.as_ref().map(|inf| inf.edf_total);
3366 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
3367 let baseline = match (best_edf, baseline_edf) {
3368 (Some(best_edf), Some(base_edf))
3369 if base_edf < SPATIAL_COLLAPSE_EDF_FLOOR
3370 && best_edf >= base_edf + SPATIAL_COLLAPSE_EDF_MARGIN =>
3371 {
3372 log::info!(
3373 "[spatial-kappa] warm-started frozen baseline collapsed (edf={base_edf:.3}) \
3374 below the certified baseline (edf={best_edf:.3}); refitting from scratch",
3375 );
3376 fit_term_collection_forspec(data, y, weights, offset, resolvedspec, family, options)?
3377 }
3378 _ => baseline,
3379 };
3380 let mut fit = baseline.fit;
3381 fit.reml_score = baseline_score;
3382 Ok(FittedTermCollectionWithSpec {
3383 fit,
3384 design: baseline.design,
3385 resolvedspec: resolvedspec.clone(),
3386 adaptive_diagnostics: baseline.adaptive_diagnostics,
3387 kappa_timing,
3388 })
3389}
3390
/// Which flavour of spatial hyper-coordinates the joint solve is working with.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum SpatialHyperKind {
    /// At least one spatial term is anisotropic.
    Anisotropic,
    /// No spatial term is anisotropic.
    Isotropic,
}

impl SpatialHyperKind {
    /// Short tag used as a prefix in stage/log messages.
    fn label(self) -> &'static str {
        match self {
            Self::Anisotropic => "spatial-aniso-joint",
            Self::Isotropic => "spatial-iso-joint",
        }
    }

    /// Human-readable adjective for error messages.
    fn adjective(self) -> &'static str {
        match self {
            Self::Anisotropic => "anisotropic",
            Self::Isotropic => "isotropic",
        }
    }

    /// Name of the spatial coordinate for this kind.
    fn coord_name(self) -> &'static str {
        match self {
            Self::Anisotropic => "psi",
            Self::Isotropic => "kappa",
        }
    }
}
3435
/// Owned copies of the GLM inputs that `SpatialJointContext` needs to
/// re-evaluate the working weights and working response at a trial `beta`.
struct SpatialFrozenGlmInputs {
    /// Response vector passed to the family observation evaluator.
    y: Array1<f64>,
    /// Prior weights passed to the family observation evaluator.
    weights: Array1<f64>,
    /// Offset added to the linear predictor `X·beta`.
    offset: Array1<f64>,
    /// Likelihood specification used to evaluate the observations.
    family: LikelihoodSpec,
}
3447
3448fn frozen_glm_tensor_eligible_family(family: &LikelihoodSpec) -> bool {
3465 !family.is_gaussian_identity()
3466 && matches!(
3467 &family.response,
3468 ResponseFamily::Binomial
3469 | ResponseFamily::Poisson
3470 | ResponseFamily::Gamma
3471 | ResponseFamily::NegativeBinomial { .. }
3472 )
3473}
3474
/// Mutable state shared by the objective callbacks (`eval_full`, `eval_efs`,
/// `eval_cost`) of one exact joint spatial solve.
struct SpatialJointContext<'d> {
    /// Raw covariate data used to rebuild hyper-directions.
    data: ArrayView2<'d, f64>,
    /// Number of leading ρ entries in θ; the spatial coordinates follow.
    rho_dim: usize,
    /// Isotropic vs anisotropic solve; used for labels and hyper-direction
    /// construction.
    kind: SpatialHyperKind,
    /// Design cache keyed on θ; `ensure_theta` brings its design to a given θ.
    cache: SingleBlockExactJointDesignCache<'d>,
    /// Evaluator that computes the joint objective and holds the current
    /// `beta` plus the staged fast-path statistics.
    evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
    /// GLM inputs for the frozen-weight path; `None` disables
    /// `frozen_glm_working_state`.
    frozen_glm_inputs: Option<SpatialFrozenGlmInputs>,
    /// ψ interval over which the frozen-weight tensor is built; `None` means
    /// no tensor is ever attempted.
    frozen_glm_psi_bounds: Option<(f64, f64)>,
    /// The frozen-weight Gram tensor, once it has been built successfully.
    frozen_glm_tensor: Option<gam_solve::glm_sufficient_lane::FrozenWeightGramTensor>,
    /// Set once a tensor build has been attempted or ruled out, so it is not
    /// retried.
    frozen_glm_tensor_attempted: bool,
    /// Last `(beta, working weights)` pair computed by
    /// `frozen_glm_trial_weights`, reused when `beta` is bit-identical.
    frozen_glm_weight_memo: Option<(Array1<f64>, Array1<f64>)>,
}
3498
/// Snapshot of the individual conditions that gate the n-free "skip design
/// realization" fast path.
#[derive(Clone, Copy, Debug, Default)]
struct NfreeSkipGateStatus {
    /// θ has exactly one coordinate after the ρ block.
    shape: bool,
    /// The evaluator reports ψ-gram tensor coverage for the value and for the
    /// skip path.
    value: bool,
    /// Gradient coverage holds, or no gradient was required.
    gradient: bool,
    /// The evaluator supports the n-free penalty re-key.
    penalty: bool,
    /// The evaluator has an n-free fast-path revision.
    revision: bool,
    /// A second-order evaluation was requested, which blocks the skip.
    second_order: bool,
}

impl NfreeSkipGateStatus {
    /// Returns `true` when every gate is open: shape, value, penalty and
    /// revision hold, the gradient gate holds if `require_gradient` is set,
    /// and no second-order evaluation was requested.
    fn would_skip(self, require_gradient: bool) -> bool {
        let gradient_ok = self.gradient || !require_gradient;
        let gates = [
            self.shape,
            self.value,
            gradient_ok,
            self.penalty,
            self.revision,
            !self.second_order,
        ];
        gates.iter().all(|&open| open)
    }
}
3519
3520impl<'d> SpatialJointContext<'d> {
3521 fn nfree_skip_gate_status(
3522 &self,
3523 theta: &Array1<f64>,
3524 allow_second_order: bool,
3525 require_gradient: bool,
3526 ) -> NfreeSkipGateStatus {
3527 let shape = theta.len() == self.rho_dim + 1;
3528 let (value, gradient) = if shape {
3529 let psi = theta[self.rho_dim];
3530 (
3531 self.evaluator.psi_gram_tensor_covers(psi)
3532 && self.evaluator.psi_gram_tensor_covers_skip(psi),
3533 !require_gradient || self.evaluator.psi_gram_tensor_covers_gradient(psi),
3534 )
3535 } else {
3536 (false, false)
3537 };
3538 NfreeSkipGateStatus {
3539 shape,
3540 value,
3541 gradient,
3542 penalty: self.evaluator.supports_nfree_penalty_rekey(),
3543 revision: self.evaluator.nfree_fast_path_revision().is_some(),
3544 second_order: allow_second_order,
3545 }
3546 }
3547
3548 fn frozen_glm_working_state(
3549 &self,
3550 beta: &Array1<f64>,
3551 ) -> Result<Option<(Array1<f64>, Array1<f64>)>, EstimationError> {
3552 let Some(inputs) = self.frozen_glm_inputs.as_ref() else {
3553 return Ok(None);
3554 };
3555 if beta.len() != self.cache.design().design.ncols() {
3556 return Ok(None);
3557 }
3558 let mut eta = self.cache.design().design.matrixvectormultiply(beta);
3559 if eta.len() != inputs.offset.len() {
3560 crate::bail_invalid_estim!(
3561 "frozen GLM tensor warm-state row mismatch: eta={}, offset={}",
3562 eta.len(),
3563 inputs.offset.len()
3564 );
3565 }
3566 eta += &inputs.offset;
3567 let obs = evaluate_standard_familyobservations(
3568 inputs.family.clone(),
3569 None,
3570 None,
3571 None,
3572 &inputs.y,
3573 &inputs.weights,
3574 &eta,
3575 )?;
3576 let mut working_response = obs.eta.clone();
3577 for i in 0..working_response.len() {
3578 let wi = obs.fisherweight[i].max(1e-12);
3579 working_response[i] += obs.score[i] / wi;
3580 }
3581 Ok(Some((obs.fisherweight, working_response)))
3582 }
3583
3584 fn frozen_glm_trial_weights(
3593 &mut self,
3594 beta: &Array1<f64>,
3595 ) -> Result<Option<Array1<f64>>, EstimationError> {
3596 if let Some((memo_beta, memo_w)) = self.frozen_glm_weight_memo.as_ref()
3597 && memo_beta.len() == beta.len()
3598 && memo_beta
3599 .iter()
3600 .zip(beta.iter())
3601 .all(|(a, b)| a.to_bits() == b.to_bits())
3602 {
3603 return Ok(Some(memo_w.clone()));
3604 }
3605 match self.frozen_glm_working_state(beta)? {
3606 Some((current_w, _)) => {
3607 self.frozen_glm_weight_memo = Some((beta.clone(), current_w.clone()));
3608 Ok(Some(current_w))
3609 }
3610 None => Ok(None),
3611 }
3612 }
3613
    /// Builds the frozen-weight GLM ψ Gram tensor at most once per context.
    ///
    /// Does nothing when a tensor already exists, a build was already
    /// attempted, or no ψ bounds are configured. When θ does not carry exactly
    /// one spatial coordinate, or the working state at `warm_beta` is
    /// unavailable, the attempt is marked as done so it is not retried. A
    /// missing `warm_beta` returns *without* marking the attempt, so a later
    /// call can still try.
    ///
    /// The build probes designs at several ψ values through the cache; the
    /// cache is restored to `theta` afterwards, whether or not the tensor was
    /// produced.
    ///
    /// # Errors
    /// Propagates failures from the working-state evaluation and from
    /// restoring the cache to `theta`.
    fn ensure_frozen_glm_tensor(
        &mut self,
        theta: &Array1<f64>,
        warm_beta: Option<&Array1<f64>>,
    ) -> Result<(), EstimationError> {
        if self.frozen_glm_tensor.is_some() || self.frozen_glm_tensor_attempted {
            return Ok(());
        }
        let Some((psi_lo, psi_hi)) = self.frozen_glm_psi_bounds else {
            return Ok(());
        };
        if theta.len() != self.rho_dim + 1 {
            self.frozen_glm_tensor_attempted = true;
            return Ok(());
        }
        // No warm start yet: leave `attempted` unset so a later call retries.
        let Some(beta) = warm_beta else {
            return Ok(());
        };
        let Some((frozen_w, working_z)) = self.frozen_glm_working_state(beta)? else {
            self.frozen_glm_tensor_attempted = true;
            return Ok(());
        };
        let theta_probe_base = theta.clone();
        let rho_dim = self.rho_dim;
        // Split the borrow so the closure can use `cache` while `evaluator`
        // is borrowed for the build call.
        let Self {
            cache, evaluator, ..
        } = self;
        let tensor = evaluator.build_frozen_glm_gram_tensor(
            // Design provider: realise the design at the probed ψ, keeping the
            // ρ block of θ fixed.
            |psi| {
                let mut theta_probe = theta_probe_base.clone();
                theta_probe[rho_dim] = psi;
                cache.ensure_theta(&theta_probe)?;
                Ok(cache.design().design.clone())
            },
            frozen_w.view(),
            working_z.view(),
            psi_lo,
            psi_hi,
        );
        // The probes moved the cache away from `theta`; bring it back before
        // anything else reads the design.
        self.cache
            .ensure_theta(theta)
            .map_err(EstimationError::InvalidInput)?;
        self.frozen_glm_tensor_attempted = true;
        if let Some(tensor) = tensor {
            self.frozen_glm_tensor = Some(tensor);
            log::info!(
                "[STAGE] {} certified frozen-W GLM ψ tensor over [{psi_lo:.3}, {psi_hi:.3}]",
                self.kind.label(),
            );
        } else {
            log::info!(
                "[STAGE] {} frozen-W GLM ψ tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]",
                self.kind.label(),
            );
        }
        Ok(())
    }
3677
    /// Stages the frozen-weight GLM statistics for the upcoming evaluation at
    /// `theta`, or clears them.
    ///
    /// Both stages are always written to the evaluator, with `None` unless the
    /// conditions below hold, so values from a previous trial are never left
    /// behind.
    ///
    /// * Gram: staged when θ has one spatial coordinate ψ, the tensor contains
    ///   ψ, a `warm_beta` is available, and the current working weights are
    ///   within `FROZEN_GLM_WEIGHT_DRIFT_RTOL` of the tensor's frozen weights.
    /// * ψ-derivative pair: staged additionally when `allow_gradient` is set,
    ///   the tensor covers ψ for gradients, and the tensor's own soundness
    ///   check returns the pair.
    ///
    /// # Errors
    /// Propagates failures from computing the current working weights.
    fn stage_frozen_glm_trial_statistics(
        &mut self,
        theta: &Array1<f64>,
        warm_beta: Option<&Array1<f64>>,
        allow_gradient: bool,
    ) -> Result<(), EstimationError> {
        let kind = self.kind;
        let mut staged_gram: Option<Array2<f64>> = None;
        let mut staged_deriv: Option<(Array2<f64>, Array1<f64>)> = None;
        if theta.len() == self.rho_dim + 1 {
            let psi = theta[self.rho_dim];
            let tensor_covers = self
                .frozen_glm_tensor
                .as_ref()
                .is_some_and(|t| t.contains(psi));
            // Only pay for the working-weight evaluation when the tensor can
            // actually serve this ψ.
            let current_w = if tensor_covers {
                match warm_beta {
                    Some(beta) => self.frozen_glm_trial_weights(beta)?,
                    None => None,
                }
            } else {
                None
            };
            if let (Some(tensor), Some(current_w)) =
                (self.frozen_glm_tensor.as_ref(), current_w.as_ref())
            {
                // Relative tolerance for the weight drift check on the Gram.
                const FROZEN_GLM_WEIGHT_DRIFT_RTOL: f64 = 1e-3;
                if tensor.weight_drift_within(current_w.view(), FROZEN_GLM_WEIGHT_DRIFT_RTOL) {
                    staged_gram = Some(tensor.gram_at(psi));
                    log::debug!(
                        "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
                         first-Fisher-step XᵀWX n-free (weight drift within tol)",
                        kind.label(),
                    );
                }
                // The gradient pair is gated independently of the Gram above.
                if allow_gradient
                    && tensor.contains_for_gradient(psi)
                    && let Some((dgram_dpsi, drhs_dpsi)) =
                        tensor.gradient_pair_if_sound(psi, current_w.view())
                {
                    staged_deriv = Some((dgram_dpsi, drhs_dpsi));
                    log::debug!(
                        "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
                         ψ-gradient (∂G/∂ψ, ∂b/∂ψ) n-free (gradient weight drift within \
                         tight tol); B_j stays exact",
                        kind.label(),
                    );
                }
            }
        }
        self.evaluator.stage_glm_first_step_gram(staged_gram);
        self.evaluator.stage_glm_psi_gram_deriv(staged_deriv);
        Ok(())
    }
3738
    /// Evaluates the joint objective at `theta`, returning
    /// `(value, gradient, hessian result)`.
    ///
    /// A second-order evaluation is performed only when `order` is
    /// `ValueGradientHessian` *and* `analytic_outer_hessian_available` is set;
    /// otherwise the evaluation is downgraded to `ValueAndGradient`.
    ///
    /// A memoised evaluation at the same θ is returned directly unless a
    /// second-order result is required and the cached Hessian is not analytic.
    /// Successful evaluations are stored in the cache.
    ///
    /// # Errors
    /// Propagates failures from realising the design, the frozen-GLM staging,
    /// building hyper-directions, and the evaluation itself.
    fn eval_full(
        &mut self,
        theta: &Array1<f64>,
        order: gam_solve::rho_optimizer::OuterEvalOrder,
        analytic_outer_hessian_available: bool,
    ) -> Result<
        (
            f64,
            Array1<f64>,
            gam_problem::HessianResult,
        ),
        EstimationError,
    > {
        use gam_solve::rho_optimizer::OuterEvalOrder;
        let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
            && analytic_outer_hessian_available;
        if let Some(eval) = self.cache.memoized_eval(theta) {
            let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
            if cached_satisfies_order {
                return Ok(eval);
            }
        }
        let kind = self.kind;
        let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
        // Gate for skipping the design re-realisation: first-order only, a
        // single spatial coordinate, and every evaluator coverage check passing.
        let skip_design_realization = !allow_second_order && theta.len() == self.rho_dim + 1 && {
            let psi = theta[self.rho_dim];
            self.evaluator.psi_gram_tensor_covers(psi)
                && self.evaluator.psi_gram_tensor_covers_gradient(psi)
                && self.evaluator.psi_gram_tensor_covers_skip(psi)
                && self.evaluator.supports_nfree_penalty_rekey()
                && nfree_fast_path_revision.is_some()
        };
        // NOTE(review): the skip path is hard-disabled here (`false && …`) and
        // a warning is logged on every evaluation. The log tag marks this as a
        // temporary audit switch; while it is in place, every `if
        // skip_design_realization` branch below is dead. Confirm whether it
        // should still be here.
        let skip_design_realization = false && skip_design_realization;
        log::warn!(
            "[OUTER-FD-AUDIT TEMP-SKIPOFF-1122] skip_design_realization={skip_design_realization}"
        );
        if skip_design_realization {
            log::debug!(
                "[STAGE] {} eval_full at psi={:.6}: skipping n×k design re-realization \
                 + reconditioning — criterion/gradient/inner-solve served n-free from \
                 the certified ψ-gram tensor (GaussianFixedCache + k-space ψ-derivatives)",
                kind.label(),
                theta[self.rho_dim],
            );
        } else {
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
        }
        // Warm start is captured before staging so the tensor build and the
        // trial statistics see the same `beta`.
        let warm_beta = self.evaluator.current_beta();
        self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref())?;
        self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), !allow_second_order)?;
        let hyper_dirs = if skip_design_realization {
            self.cache.nfree_tensor_gradient_hyper_dirs(theta)?
        } else {
            self.cache.hyper_dirs_for_current_design(self.data, kind)?
        };

        let design_revision = if skip_design_realization {
            nfree_fast_path_revision
        } else {
            Some(self.cache.design_revision())
        };
        // Stage the penalty at θ for the evaluator's fast path; on failure the
        // stage is cleared rather than failing the evaluation.
        if self.evaluator.supports_nfree_penalty_rekey() {
            match self.cache.canonical_penalties_at(theta) {
                Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
                Err(e) => {
                    // NOTE(review): `theta[self.rho_dim]` is indexed here
                    // without a length check — presumably the re-key is only
                    // supported when θ has a spatial coordinate; confirm.
                    log::warn!(
                        "[STAGE] {} eval_full at psi={:.6}: exact n-free S(ψ) rebuild failed \
                         ({e}); clearing stage (eval falls to slow path)",
                        kind.label(),
                        theta[self.rho_dim],
                    );
                    self.evaluator.stage_fast_path_penalty(None);
                }
            }
        }
        let eval = evaluate_joint_reml_outer_eval_at_theta(
            &mut self.evaluator,
            self.cache.design(),
            theta,
            self.rho_dim,
            hyper_dirs,
            warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
            if allow_second_order {
                order
            } else {
                OuterEvalOrder::ValueAndGradient
            },
            design_revision,
        );
        // Only successful evaluations are memoised.
        if let Ok(ref value) = eval {
            self.cache.store_eval_at(theta, value.clone());
        }
        eval
    }
3932
3933 fn eval_efs(
3934 &mut self,
3935 theta: &Array1<f64>,
3936 ) -> Result<gam_problem::EfsEval, EstimationError> {
3937 self.cache
3938 .ensure_theta(theta)
3939 .map_err(EstimationError::InvalidInput)?;
3940 let kind = self.kind;
3941 let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
3942 self.data,
3943 self.cache.spec(),
3944 self.cache.design(),
3945 &self.cache.spatial_terms,
3946 )?
3947 .ok_or_else(|| {
3948 EstimationError::InvalidInput(format!(
3949 "failed to build {} hyper_dirs for exact-joint EFS",
3950 kind.adjective(),
3951 ))
3952 })?;
3953 let design_revision = Some(self.cache.design_revision());
3954 let warm_beta = self.evaluator.current_beta();
3955 evaluate_joint_reml_efs_at_theta(
3956 &mut self.evaluator,
3957 self.cache.design(),
3958 theta,
3959 self.rho_dim,
3960 hyper_dirs,
3961 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3962 design_revision,
3963 )
3964 }
3965
3966 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
3972 if let Some(cost) = self.cache.memoized_cost(theta) {
3973 return cost;
3974 }
3975 let probe_start = std::time::Instant::now();
3990 let psi_distance = self
3991 .cache
3992 .current_theta
3993 .as_ref()
3994 .filter(|reference| reference.len() == theta.len())
3995 .map(|reference| {
3996 reference
3997 .iter()
3998 .zip(theta.iter())
3999 .map(|(a, b)| (a - b) * (a - b))
4000 .sum::<f64>()
4001 .sqrt()
4002 })
4003 .unwrap_or(f64::NAN);
4004 let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
4018 let skip_value_realization = theta.len() == self.rho_dim + 1 && {
4019 let psi = theta[self.rho_dim];
4020 self.evaluator.psi_gram_tensor_covers(psi)
4021 && self.evaluator.psi_gram_tensor_covers_skip(psi)
4030 && self.evaluator.supports_nfree_penalty_rekey()
4035 && nfree_fast_path_revision.is_some()
4036 };
4037 if theta.len() == self.rho_dim + 1
4038 && self.evaluator.has_psi_gram_tensor()
4039 && !self.evaluator.psi_gram_tensor_covers(theta[self.rho_dim])
4040 {
4041 self.cache.store_cost_at(theta, f64::INFINITY);
4042 return f64::INFINITY;
4043 }
4044 if !skip_value_realization && self.cache.ensure_theta(theta).is_err() {
4045 return f64::INFINITY;
4046 }
4047 if self.evaluator.supports_nfree_penalty_rekey() {
4053 match self.cache.canonical_penalties_at(theta) {
4054 Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
4055 Err(_) => self.evaluator.stage_fast_path_penalty(None),
4056 }
4057 }
4058 let warm_beta = self.evaluator.current_beta();
4059 if let Err(err) = self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref()) {
4060 log::warn!(
4061 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM tensor setup failed ({err}); \
4062 falling back to exact streamed Gram",
4063 self.kind.label(),
4064 if theta.len() > self.rho_dim {
4065 theta[self.rho_dim]
4066 } else {
4067 f64::NAN
4068 },
4069 );
4070 self.evaluator.stage_glm_first_step_gram(None);
4071 self.evaluator.stage_glm_psi_gram_deriv(None);
4072 } else if let Err(err) =
4073 self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), false)
4074 {
4075 log::warn!(
4076 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM staging failed ({err}); \
4077 falling back to exact streamed Gram",
4078 self.kind.label(),
4079 if theta.len() > self.rho_dim {
4080 theta[self.rho_dim]
4081 } else {
4082 f64::NAN
4083 },
4084 );
4085 self.evaluator.stage_glm_first_step_gram(None);
4086 self.evaluator.stage_glm_psi_gram_deriv(None);
4087 }
4088 let design_revision = if skip_value_realization {
4089 nfree_fast_path_revision
4090 } else {
4091 Some(self.cache.design_revision())
4092 };
4093 let cost_label = self.kind.label();
4094 let result = {
4095 let design = self.cache.design();
4096 self.evaluator.evaluate_cost_only(
4097 &design.design,
4098 &design.penalties,
4099 &design.nullspace_dims,
4100 design.linear_constraints.clone(),
4101 theta,
4102 self.rho_dim,
4103 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
4104 cost_label,
4105 design_revision,
4106 )
4107 };
4108 match result {
4109 Ok(cost) => {
4110 log::debug!(
4111 "[STAGE] {cost_label} value-probe (order=Value): elapsed={:.3}s \
4112 cost={cost:.6e} trial_theta_distance={psi_distance:.3e}",
4113 probe_start.elapsed().as_secs_f64(),
4114 );
4115 self.cache.store_cost_at(theta, cost);
4116 cost
4117 }
4118 Err(_) => f64::INFINITY,
4119 }
4120 }
4121
4122 fn reset(&mut self) {
4123 self.cache.current_theta = None;
4124 self.cache.last_eval_theta = None;
4125 self.cache.last_cost = None;
4126 self.cache.last_eval = None;
4127 }
4128}
4129
/// Result of `run_exact_joint_spatial_optimization`.
enum SpatialJointOutcome {
    /// The optimizer's final iterate was accepted (it either converged or met
    /// the relative-to-cost gradient criterion checked by the caller).
    Optimized {
        /// Accepted θ, taken from the optimizer result.
        theta_star: Array1<f64>,
        /// Objective value reported for the accepted iterate.
        final_value: f64,
    },
    /// The optimizer stopped without converging while the final objective was
    /// still finite; no θ is handed back in this case.
    NonConverged {
        /// Number of outer iterations the optimizer reported.
        iterations: usize,
        /// Objective value at the last iterate.
        final_value: f64,
        /// Gradient norm at the last iterate, when the optimizer reported one.
        final_grad_norm: Option<f64>,
    },
}
4179
4180fn kphase_log_norms(theta: &Array1<f64>, rho_dim: usize) -> (f64, f64) {
4181 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
4182 let log_kappa_norm = theta
4183 .iter()
4184 .skip(rho_dim)
4185 .map(|v| v * v)
4186 .sum::<f64>()
4187 .sqrt();
4188 (theta_norm, log_kappa_norm)
4189}
4190
/// Runs the exact joint outer optimization over θ for the spatial
/// hyperparameters of the given `kind`.
///
/// θ has `theta0.len()` entries: the first `rho_dim` form the ρ block and the
/// remaining `theta0.len() - rho_dim` are the spatial coordinates (called
/// `coord_dim` / `log_kappa_dim` in the diagnostics). `lower` and `upper` are
/// box bounds of the same length as `theta0`.
///
/// The function proceeds in these stages:
///
/// 1. Decide whether the analytic outer Hessian is usable and whether the
///    problem should be routed gradient-only.
/// 2. Build a `SpatialJointContext` (design cache + external evaluator) from
///    the baseline design.
/// 3. For a single spatial coordinate with a Gaussian-identity family and a
///    cache that supports the n-free penalty re-key, try to attach a ψ-gram
///    tensor to the evaluator. On success the optimizer bounds for that
///    coordinate may be narrowed to the rank-stable floor/ceiling reported by
///    the evaluator, and the analytic Hessian may be suppressed.
/// 4. When info logging is enabled and the problem is small enough, run a
///    finite-difference audit of the outer gradient at `theta0`.
/// 5. Prime the context with one full evaluation at `theta0`.
/// 6. Build the multistart outer problem, wire cost / eval / reset / EFS
///    callbacks (each instrumented with `[KAPPA-PHASE]` timing and n-free miss
///    counters) and run it.
/// 7. Classify the optimizer result.
///
/// # Returns
///
/// The outcome together with the collected timing/diagnostic counters:
/// `SpatialJointOutcome::Optimized` when the optimizer converged or when its
/// final gradient norm is finite and at most `options.tol * (1 + |f|)`;
/// `SpatialJointOutcome::NonConverged` when it did not converge but the final
/// objective is finite.
///
/// # Errors
///
/// Propagates errors from building the cache or evaluator and from the priming
/// evaluation, wraps an optimizer failure in `EstimationError::InvalidInput`,
/// and bails when the optimizer stops unconverged with a non-finite objective.
///
/// # Panics
///
/// Panics if `lower`/`upper` do not have the same length as `theta0`, or if
/// the baseline design has fewer smooth terms than `spatial_terms`.
fn run_exact_joint_spatial_optimization(
    kind: SpatialHyperKind,
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    baseline_design: &TermCollectionDesign,
    family: LikelihoodSpec,
    options: &FitOptions,
    spatial_terms: &[usize],
    dims_per_term: &[usize],
    theta0: &Array1<f64>,
    lower: &Array1<f64>,
    upper: &Array1<f64>,
    rho_dim: usize,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<(SpatialJointOutcome, SpatialLengthScaleOptimizationTiming), EstimationError> {
    let label = kind.label();
    // Precondition checks on the caller-supplied shapes.
    assert!(
        lower.len() == theta0.len() && upper.len() == theta0.len(),
        "spatial hyperparameter bounds must match theta length: lower_len={}, upper_len={}, theta_len={}",
        lower.len(),
        upper.len(),
        theta0.len()
    );
    assert!(
        baseline_design.smooth.terms.len() >= spatial_terms.len(),
        "baseline design must have at least one smooth term per spatial term: baseline_terms={}, spatial_terms={}",
        baseline_design.smooth.terms.len(),
        spatial_terms.len()
    );
    use gam_solve::rho_optimizer::OuterEvalOrder;
    use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};

    let theta_dim = theta0.len();
    // NOTE(review): this subtraction assumes rho_dim <= theta0.len(); nothing
    // visible in this function checks that — confirm against the callers.
    let coord_dim = theta_dim - rho_dim;
    let analytic_outer_hessian_available =
        exact_joint_spatial_outer_hessian_available(&family, baseline_design);
    if !analytic_outer_hessian_available {
        log::info!(
            "[{label}] analytic outer Hessian unavailable for family/design; routing without second-order geometry (coord_dim={coord_dim})"
        );
    }
    // Large joint θ is routed gradient-only regardless of Hessian availability.
    let mut prefer_gradient_only = theta_dim > EXACT_JOINT_SECOND_ORDER_THETA_CAP;
    if prefer_gradient_only {
        log::info!(
            "[{label}] joint θ-dim {theta_dim} exceeds the exact pair-Hessian budget \
             ({EXACT_JOINT_SECOND_ORDER_THETA_CAP}); routing gradient-only quasi-Newton"
        );
    }
    // Set below when the n-free Gaussian ψ-lane is fully armed.
    let mut suppress_outer_hessian_for_nfree = false;

    log::trace!(
        "[{}] starting analytic optimization: rho_dim={}, coord_dim={}, dims_per_term={:?}",
        label,
        rho_dim,
        coord_dim,
        dims_per_term,
    );

    // Shared state for every callback: the design cache, the external
    // evaluator, and (only for one spatial coordinate with an eligible family)
    // the inputs and ψ bounds for the frozen-W GLM tensor.
    let mut ctx = SpatialJointContext {
        data,
        rho_dim,
        kind,
        cache: SingleBlockExactJointDesignCache::new(
            data,
            resolvedspec.clone(),
            baseline_design.clone(),
            spatial_terms.to_vec(),
            rho_dim,
            dims_per_term.to_vec(),
        )
        .map_err(EstimationError::InvalidInput)?,
        evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
            y,
            weights,
            &baseline_design.design,
            offset,
            &baseline_design.penalties,
            &external_opts_for_design(&family, baseline_design, options),
            label,
        )?,
        frozen_glm_inputs: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
            Some(SpatialFrozenGlmInputs {
                y: y.to_owned(),
                weights: weights.to_owned(),
                offset: offset.to_owned(),
                family: family.clone(),
            })
        } else {
            None
        },
        frozen_glm_psi_bounds: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
            Some((lower[rho_dim], upper[rho_dim]))
        } else {
            None
        },
        frozen_glm_tensor: None,
        frozen_glm_tensor_attempted: false,
        frozen_glm_weight_memo: None,
    };

    // Optional tightened bounds for the single ψ coordinate; filled in only
    // when the ψ-gram tensor attaches and reports a usable floor/ceiling.
    let mut psi_rank_stable_floor: Option<f64> = None;
    let mut psi_rank_stable_ceiling: Option<f64> = None;
    let nfree_penalty_capable = coord_dim == 1
        && family.is_gaussian_identity()
        && ctx.cache.supports_nfree_penalty_rekey();
    if nfree_penalty_capable {
        let psi_lo = lower[rho_dim];
        let psi_hi = upper[rho_dim];
        // Offset-adjusted response handed to the tensor builder.
        let z = Array1::from_iter(y.iter().zip(offset.iter()).map(|(yi, oi)| yi - oi));
        let theta_probe_base = theta0.clone();
        // Split the borrow so the design-realizing closure can use `cache`
        // while `evaluator` is borrowed mutably for the call itself.
        let SpatialJointContext {
            cache, evaluator, ..
        } = &mut ctx;
        let attached = evaluator.build_and_set_psi_gram_tensor(
            |psi| {
                // Realize the design at (θ0 with its ψ entry replaced by `psi`).
                let mut theta_probe = theta_probe_base.clone();
                theta_probe[rho_dim] = psi;
                cache.ensure_theta(&theta_probe)?;
                Ok(cache.design().design.clone())
            },
            weights,
            z.view(),
            psi_lo,
            psi_hi,
        );
        if attached {
            log::info!(
                "[{label}] certified ψ-gram tensor over [{psi_lo:.3}, {psi_hi:.3}]: \
                 in-window trials assemble Gaussian sufficient statistics n-free"
            );
            let psi_anchor = theta0[rho_dim];
            // Keep the floor only when it lies strictly between the window
            // floor and the starting ψ.
            psi_rank_stable_floor = evaluator
                .psi_gram_rank_stable_floor(psi_anchor)
                .filter(|&f| f.is_finite() && f > psi_lo && f < psi_anchor);
            log::info!(
                "[KAPPA-PHASE-FLOOR] n_rows={} psi_lo={psi_lo:.6} psi_anchor={psi_anchor:.6} \
                 rank_stable_floor={:?} lifted={}",
                data.nrows(),
                evaluator.psi_gram_rank_stable_floor(psi_anchor),
                psi_rank_stable_floor.is_some(),
            );
            if let Some(floor) = psi_rank_stable_floor {
                log::info!(
                    "[{label}] rank-stable κ-floor ψ_floor={floor:.6} > window floor \
                     ψ_lo={psi_lo:.6}: lifting the optimizer lower bound to keep every \
                     in-window trial on the n-free design-realization skip (#1033). The \
                     conditioned Gram is rank-deficient below ψ_floor (longest-length-scale \
                     radial mode collapses into the nullspace), where the skip is soundly \
                     refused; that band drifts with n via the sample-std standardization, \
                     so this n-free k-space floor is the n-independent fix."
                );
            }
            // Symmetric treatment for the ceiling: strictly between the
            // starting ψ and the window ceiling.
            psi_rank_stable_ceiling = evaluator
                .psi_gram_rank_stable_ceiling(psi_anchor)
                .filter(|&c| c.is_finite() && c < psi_hi && c > psi_anchor);
            log::info!(
                "[KAPPA-PHASE-CEIL] n_rows={} psi_hi={psi_hi:.6} psi_anchor={psi_anchor:.6} \
                 rank_stable_ceiling={:?} clamped={}",
                data.nrows(),
                evaluator.psi_gram_rank_stable_ceiling(psi_anchor),
                psi_rank_stable_ceiling.is_some(),
            );
            if let Some(ceiling) = psi_rank_stable_ceiling {
                log::info!(
                    "[{label}] rank-stable κ-ceiling ψ_ceil={ceiling:.6} < window ceiling \
                     ψ_hi={psi_hi:.6}: clamping the optimizer upper bound to keep every \
                     in-window trial on the n-free design-realization skip (#1033). The \
                     conditioned Gram is rank-deficient above ψ_ceil (longest-frequency \
                     radial mode goes collinear), where the skip is soundly refused; a \
                     line-search overshoot there trips the O(n) reset_surface lane (and the \
                     deficient pinning ψ it records resets the next in-band trial too)."
                );
            }
            // Informational only: this flag selects which message is logged.
            let gradient_covers_full_window = evaluator.psi_gram_tensor_covers_gradient(psi_lo)
                && evaluator.psi_gram_tensor_covers_gradient(psi_hi);
            if gradient_covers_full_window {
                log::info!(
                    "[{label}] certified ψ-gram tensor gradient lane covers the full \
                     optimizer window [{psi_lo:.3}, {psi_hi:.3}]"
                );
            } else {
                log::info!(
                    "[{label}] ψ-gram tensor value lane certified, but the gradient lane \
                     does not cover the full optimizer window [{psi_lo:.3}, {psi_hi:.3}]; \
                     keeping exact streamed kappa routing"
                );
            }
            // The re-key is enabled whenever the tensor attached, independent
            // of the gradient-coverage message above.
            evaluator.set_supports_nfree_penalty_rekey(true);
            log::info!(
                "[{label}] exact n-free ψ-penalty re-key enabled over [{psi_lo:.3}, \
                 {psi_hi:.3}]: in-window fast-path trials rebuild S(ψ) n-free from frozen \
                 geometry (no reset_surface)"
            );
        } else {
            log::info!(
                "[{label}] ψ-gram tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]; \
                 keeping the exact per-trial path"
            );
        }
        // Only when every lane is available is the analytic Hessian dropped in
        // favour of gradient-only routing.
        if attached
            && evaluator.psi_gram_tensor_covers_gradient(psi_lo)
            && evaluator.psi_gram_tensor_covers_gradient(psi_hi)
            && evaluator.supports_nfree_penalty_rekey()
            && cache.supports_nfree_gradient_only_routing()
        {
            suppress_outer_hessian_for_nfree = true;
            prefer_gradient_only = true;
            log::info!(
                "[{label}] n-free Gaussian ψ-lane armed; suppressing the analytic outer \
                 Hessian and routing gradient-only (BFGS) so the κ outer loop never realizes \
                 the O(n) second-order slab — n-independent outer loop (#1033)"
            );
        }
    } else if coord_dim == 1 && family.is_gaussian_identity() {
        log::info!(
            "[{label}] exact n-free ψ-penalty re-key unavailable; skipping ψ-gram tensor \
             attachment so value, gradient, and Hessian remain on the same exact streamed \
             objective"
        );
    }

    // Finite-difference audit of the outer gradient at θ0. It only runs when
    // info logging is enabled, the analytic Hessian is available and the
    // problem is within the size limits below.
    const OUTER_FD_AUDIT_MAX_N: usize = 4_000;
    const OUTER_FD_AUDIT_MAX_THETA_DIM: usize = 32;
    let n_total = data.nrows();
    let outer_fd_audit_eligible = log::log_enabled!(log::Level::Info)
        && analytic_outer_hessian_available
        && n_total <= OUTER_FD_AUDIT_MAX_N
        && theta_dim <= OUTER_FD_AUDIT_MAX_THETA_DIM;
    log::info!(
        "[OUTER-FD-AUDIT/spatial-exact-joint] gate eligible={outer_fd_audit_eligible} \
         analytic_grad={analytic_outer_hessian_available} n_total={n_total} \
         theta_dim={theta_dim} rho_dim={rho_dim} psi_dim={coord_dim}"
    );
    if outer_fd_audit_eligible {
        // Wrapped in an immediately-invoked closure so an audit failure becomes
        // a logged warning instead of an error of this function.
        let audit = (|| -> Result<gam_solve::rho_optimizer::OuterGradientFdAudit, String> {
            let mut eval_at = |theta: &Array1<f64>,
                               mode: gam_solve::estimate::reml::reml_outer_engine::EvalMode|
             -> Result<
                (
                    f64,
                    Array1<f64>,
                    gam_problem::HessianResult,
                ),
                String,
            > {
                use gam_solve::estimate::reml::reml_outer_engine::EvalMode;
                // NOTE(review): every non-Hessian mode is mapped to
                // `OuterEvalOrder::Value`; confirm that `eval_full` still
                // produces the gradient the audit compares against.
                let order = if matches!(mode, EvalMode::ValueGradientHessian) {
                    OuterEvalOrder::ValueGradientHessian
                } else {
                    OuterEvalOrder::Value
                };
                ctx.eval_full(theta, order, analytic_outer_hessian_available)
                    .map_err(|e| format!("fd-audit eval_full: {e}"))
            };
            let rho_dim_audit = rho_dim;
            // Human-readable name of θ component `i` for the audit report.
            let label_fn = move |i: usize| -> String {
                if i < rho_dim_audit {
                    format!("rho[{i}]")
                } else {
                    format!("psi_kappa[{}]", i - rho_dim_audit)
                }
            };
            gam_solve::rho_optimizer::outer_gradient_fd_audit(
                theta0,
                1e-4,
                label_fn,
                &mut eval_at,
            )
        })();
        match audit {
            Ok(audit) => audit.log_verdict("spatial-exact-joint"),
            Err(e) => log::warn!("[OUTER-FD-AUDIT/spatial-exact-joint] skipped: {e}"),
        }
    }

    // Prime the context with one full evaluation at θ0; the result itself is
    // discarded, only an error is propagated.
    let kphase_prime_order = if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
        OuterEvalOrder::ValueGradientHessian
    } else {
        OuterEvalOrder::ValueAndGradient
    };
    let kphase_prime_start = std::time::Instant::now();
    drop(ctx.eval_full(theta0, kphase_prime_order, analytic_outer_hessian_available)?);
    log::info!(
        "[KAPPA-PHASE-PRIME] n_rows={} order={:?} elapsed_s={:.4} slow_path_resets_total={} design_revision={}",
        data.nrows(),
        kphase_prime_order,
        kphase_prime_start.elapsed().as_secs_f64(),
        ctx.evaluator.slow_path_reset_count(),
        ctx.cache.design_revision(),
    );

    // Diagnostic counters. `Cell`s let the callbacks below update them through
    // shared references.
    let kphase_cost_calls = std::cell::Cell::new(0usize);
    let kphase_eval_calls = std::cell::Cell::new(0usize);
    let kphase_efs_calls = std::cell::Cell::new(0usize);
    let kphase_cost_total_s = std::cell::Cell::new(0.0);
    let kphase_eval_total_s = std::cell::Cell::new(0.0);
    let kphase_efs_total_s = std::cell::Cell::new(0.0);
    let kphase_nfree_miss_shape = std::cell::Cell::new(0u64);
    let kphase_nfree_miss_value = std::cell::Cell::new(0u64);
    let kphase_nfree_miss_gradient = std::cell::Cell::new(0u64);
    let kphase_nfree_miss_penalty = std::cell::Cell::new(0u64);
    let kphase_nfree_miss_revision = std::cell::Cell::new(0u64);
    let kphase_nfree_miss_second_order = std::cell::Cell::new(0u64);
    let kphase_nfree_miss_other = std::cell::Cell::new(0u64);
    let kphase_optim_start = std::time::Instant::now();
    let kphase_log_kappa_dim = coord_dim;
    // Baselines so the summary reports deltas for the optimization run only.
    let kphase_slow_resets_start = ctx.evaluator.slow_path_reset_count();
    let kphase_design_revision_start = ctx.cache.design_revision();

    // Lift the ψ lower bound to the rank-stable floor when one was found;
    // otherwise borrow the caller's bounds unchanged.
    let lower_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_floor {
        Some(floor) if coord_dim == 1 && floor > lower[rho_dim] => {
            let mut lifted = lower.clone();
            lifted[rho_dim] = floor;
            std::borrow::Cow::Owned(lifted)
        }
        _ => std::borrow::Cow::Borrowed(lower),
    };
    let lower = lower_effective.as_ref();

    // Same for the ψ upper bound and the rank-stable ceiling.
    let upper_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_ceiling {
        Some(ceiling) if coord_dim == 1 && ceiling < upper[rho_dim] => {
            let mut clamped = upper.clone();
            clamped[rho_dim] = ceiling;
            std::borrow::Cow::Owned(clamped)
        }
        _ => std::borrow::Cow::Borrowed(upper),
    };
    let upper = upper_effective.as_ref();

    // The Hessian form is declared unavailable whenever the analytic Hessian
    // is missing or was suppressed for the n-free lane.
    let problem = exact_joint_multistart_outer_problem(
        theta0,
        lower,
        upper,
        rho_dim,
        coord_dim,
        theta_dim,
        Derivative::Analytic,
        if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
            DeclaredHessianForm::Either
        } else {
            DeclaredHessianForm::Unavailable
        },
        prefer_gradient_only,
        suppress_outer_hessian_for_nfree,
        seed_risk_profile_for_likelihood_family(&family),
        kappa_options.rel_tol.max(1e-6),
        kappa_options.max_outer_iter.max(1),
        Some(5.0),
        Some(kappa_options.log_step.clamp(0.25, 1.0)),
        None,
        Some((data.nrows(), baseline_design.design.ncols())),
        !constant_curvature_term_indices(resolvedspec).is_empty(),
    );

    // Shared implementation of the gradient/Hessian callbacks: runs
    // `eval_full`, attributes any slow-path resets it caused to a miss bucket,
    // logs timing, and converts recoverable errors into an infeasible trial.
    let eval_outer = |ctx: &mut &mut SpatialJointContext<'_>,
                      theta: &Array1<f64>,
                      order: OuterEvalOrder|
     -> Result<OuterEval, EstimationError> {
        let t0 = std::time::Instant::now();
        let allow_second_order_for_call = matches!(order, OuterEvalOrder::ValueGradientHessian)
            && analytic_outer_hessian_available;
        // Gate status is sampled before the evaluation so it describes the
        // state the evaluation started from.
        let gate = ctx.nfree_skip_gate_status(theta, allow_second_order_for_call, true);
        let resets_before = ctx.evaluator.slow_path_reset_count();
        let raw = ctx.eval_full(theta, order, analytic_outer_hessian_available);
        let reset_delta = ctx
            .evaluator
            .slow_path_reset_count()
            .saturating_sub(resets_before);
        if reset_delta > 0 {
            // Each bucket requires all earlier gates to have passed.
            if !gate.shape {
                kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
            }
            if gate.shape && !gate.value {
                kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
            }
            if gate.shape && gate.value && !gate.gradient {
                kphase_nfree_miss_gradient.set(kphase_nfree_miss_gradient.get() + reset_delta);
            }
            if gate.shape && gate.value && gate.gradient && !gate.penalty {
                kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
            }
            if gate.shape && gate.value && gate.gradient && gate.penalty && !gate.revision {
                kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
            }
            // NOTE(review): unlike the buckets above, this one tests
            // `gate.second_order` without negation. Confirm against
            // `nfree_skip_gate_status` whether a true flag means "second order
            // blocked the skip" or whether a `!` is missing here.
            if gate.shape
                && gate.value
                && gate.gradient
                && gate.penalty
                && gate.revision
                && gate.second_order
            {
                kphase_nfree_miss_second_order
                    .set(kphase_nfree_miss_second_order.get() + reset_delta);
            }
            // The gate predicted a skip, yet a reset still happened.
            if gate.would_skip(true) {
                kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
            }
        }
        let elapsed_s = t0.elapsed().as_secs_f64();
        kphase_eval_calls.set(kphase_eval_calls.get() + 1);
        kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
        let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
        log::info!(
            "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
            kphase_eval_calls.get(),
            order,
            Some(ctx.cache.design_revision()),
            theta_norm,
            log_kappa_norm,
            elapsed_s,
        );
        match raw {
            Ok((cost, grad, hess)) => Ok(OuterEval {
                cost,
                gradient: grad,
                hessian: hess,
                inner_beta_hint: None,
            }),
            // Recoverable trial-point failures are reported as an infeasible
            // evaluation rather than an error.
            Err(err) if is_recoverable_trial_point_error(&err) => {
                log::debug!(
                    "[{label}] trial point infeasible (kernel design \
                     not constructible at theta={theta:?}): {err}; retreating",
                );
                Ok(OuterEval::infeasible(theta_dim))
            }
            Err(err) => Err(err),
        }
    };

    // Callbacks, in order: cost-only probe, default-order evaluation,
    // explicit-order evaluation, reset, EFS evaluation.
    let mut obj = problem.build_objective_with_eval_order(
        &mut ctx,
        |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
            let t0 = std::time::Instant::now();
            let gate = ctx.nfree_skip_gate_status(theta, false, false);
            let resets_before = ctx.evaluator.slow_path_reset_count();
            let cost = ctx.eval_cost(theta);
            let reset_delta = ctx
                .evaluator
                .slow_path_reset_count()
                .saturating_sub(resets_before);
            if reset_delta > 0 {
                // Value probes have no gradient or second-order buckets.
                if !gate.shape {
                    kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
                }
                if gate.shape && !gate.value {
                    kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
                }
                if gate.shape && gate.value && !gate.penalty {
                    kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
                }
                if gate.shape && gate.value && gate.penalty && !gate.revision {
                    kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
                }
                if gate.would_skip(false) {
                    kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
                }
            }
            let elapsed_s = t0.elapsed().as_secs_f64();
            kphase_cost_calls.set(kphase_cost_calls.get() + 1);
            kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
            let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
            log::info!(
                "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
                kphase_cost_calls.get(),
                Some(ctx.cache.design_revision()),
                theta_norm,
                log_kappa_norm,
                elapsed_s,
            );
            Ok(cost)
        },
        |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
            eval_outer(
                ctx,
                theta,
                if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
                    OuterEvalOrder::ValueGradientHessian
                } else {
                    OuterEvalOrder::ValueAndGradient
                },
            )
        },
        |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
            eval_outer(ctx, theta, order)
        },
        Some(|ctx: &mut &mut SpatialJointContext<'_>| {
            ctx.reset();
        }),
        Some(|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
            let t0 = std::time::Instant::now();
            let eval = ctx.eval_efs(theta);
            let elapsed_s = t0.elapsed().as_secs_f64();
            kphase_efs_calls.set(kphase_efs_calls.get() + 1);
            kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
            let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
            log::info!(
                "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
                kphase_efs_calls.get(),
                Some(ctx.cache.design_revision()),
                theta_norm,
                log_kappa_norm,
                elapsed_s,
            );
            eval
        }),
    );

    let run_label = match kind {
        SpatialHyperKind::Anisotropic => "aniso-psi joint REML",
        SpatialHyperKind::Isotropic => "iso-kappa joint REML",
    };
    let result = problem.run(&mut obj, run_label).map_err(|e| {
        EstimationError::InvalidInput(format!(
            "{} analytic optimization failed after exhausting strategy fallbacks: {e}",
            kind.adjective(),
        ))
    })?;
    // Release the objective's mutable borrow of `ctx` before reading the
    // counters from it.
    drop(obj);
    let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
    let kphase_slow_resets = ctx
        .evaluator
        .slow_path_reset_count()
        .saturating_sub(kphase_slow_resets_start);
    let kphase_design_revision_delta = ctx
        .cache
        .design_revision()
        .saturating_sub(kphase_design_revision_start);
    log::info!(
        "[KAPPA-PHASE-SUMMARY] n_rows={} log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} slow_path_resets={} design_revision_delta={} nfree_miss_shape={} nfree_miss_value={} nfree_miss_gradient={} nfree_miss_penalty={} nfree_miss_revision={} nfree_miss_second_order={} nfree_miss_other={} optim_total_s={:.4}",
        data.nrows(),
        kphase_log_kappa_dim,
        kphase_cost_calls.get(),
        kphase_cost_total_s.get(),
        kphase_eval_calls.get(),
        kphase_eval_total_s.get(),
        kphase_efs_calls.get(),
        kphase_efs_total_s.get(),
        kphase_slow_resets,
        kphase_design_revision_delta,
        kphase_nfree_miss_shape.get(),
        kphase_nfree_miss_value.get(),
        kphase_nfree_miss_gradient.get(),
        kphase_nfree_miss_penalty.get(),
        kphase_nfree_miss_revision.get(),
        kphase_nfree_miss_second_order.get(),
        kphase_nfree_miss_other.get(),
        kphase_total_s,
    );
    // The same counters are returned to the caller in structured form.
    let timing = SpatialLengthScaleOptimizationTiming {
        log_kappa_dim: kphase_log_kappa_dim,
        cost_calls: kphase_cost_calls.get(),
        cost_total_s: kphase_cost_total_s.get(),
        eval_calls: kphase_eval_calls.get(),
        eval_total_s: kphase_eval_total_s.get(),
        efs_calls: kphase_efs_calls.get(),
        efs_total_s: kphase_efs_total_s.get(),
        slow_path_resets: kphase_slow_resets,
        design_revision_delta: kphase_design_revision_delta,
        nfree_miss_shape: kphase_nfree_miss_shape.get(),
        nfree_miss_value: kphase_nfree_miss_value.get(),
        nfree_miss_gradient: kphase_nfree_miss_gradient.get(),
        nfree_miss_penalty: kphase_nfree_miss_penalty.get(),
        nfree_miss_revision: kphase_nfree_miss_revision.get(),
        nfree_miss_second_order: kphase_nfree_miss_second_order.get(),
        nfree_miss_other: kphase_nfree_miss_other.get(),
        optim_total_s: kphase_total_s,
    };
    if !result.converged {
        // Accept an unconverged iterate when its gradient norm is finite and
        // small relative to the objective magnitude.
        let rel_to_cost_threshold = options.tol * (1.0_f64 + result.final_value.abs());
        if let Some(final_grad) = result
            .final_grad_norm
            .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
        {
            log::info!(
                "[{}] outer optimization hit max_iter={} but \
                 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
                 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
                 relative-to-cost REML convergence criterion.",
                label,
                result.iterations,
                final_grad,
                rel_to_cost_threshold,
                options.tol,
                result.final_value.abs(),
            );
        } else if result.final_value.is_finite() {
            // Finite objective but no convergence: report it without failing.
            log::warn!(
                "[{}] {} did not converge after {} iterations \
                 (final_objective={:.6e}, final_grad_norm={}); keeping the \
                 frozen baseline geometry instead of aborting the fit.",
                label,
                kind.adjective(),
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            );
            return Ok((
                SpatialJointOutcome::NonConverged {
                    iterations: result.iterations,
                    final_value: result.final_value,
                    final_grad_norm: result.final_grad_norm,
                },
                timing,
            ));
        } else {
            // Non-finite objective: treated as divergence.
            crate::bail_invalid_estim!(
                "{} analytic optimization diverged after {} iterations (final_objective={:.6e}, final_grad_norm={})",
                kind.adjective(),
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            );
        }
    }
    log::trace!(
        "[{}] converged in {} iterations, final_value={:.6e}, grad_norm={}",
        label,
        result.iterations,
        result.final_value,
        result.final_grad_norm_report(),
    );
    let theta_star = result.rho;
    Ok((
        SpatialJointOutcome::Optimized {
            theta_star,
            final_value: result.final_value,
        },
        timing,
    ))
}
5046
5047fn set_single_term_spatial_length_scale(
5051 term: &mut SmoothTermSpec,
5052 length_scale: f64,
5053) -> Result<(), EstimationError> {
5054 match &mut term.basis {
5055 SmoothBasisSpec::ThinPlate { spec, .. } => {
5056 spec.length_scale = length_scale;
5057 Ok(())
5058 }
5059 SmoothBasisSpec::Matern { spec, .. } => {
5060 spec.length_scale = length_scale;
5061 Ok(())
5062 }
5063 SmoothBasisSpec::Duchon { spec, .. } => {
5064 spec.length_scale = Some(length_scale);
5065 Ok(())
5066 }
5067 _ => Err(EstimationError::InvalidInput(format!(
5068 "term '{}' does not expose a spatial length scale",
5069 term.name
5070 ))),
5071 }
5072}
5073
5074fn set_single_term_spatial_aniso_log_scales(
5078 term: &mut SmoothTermSpec,
5079 eta: Vec<f64>,
5080) -> Result<(), EstimationError> {
5081 let eta = center_aniso_log_scales(&eta);
5082 match &mut term.basis {
5083 SmoothBasisSpec::Matern { spec, .. } => {
5084 spec.aniso_log_scales = Some(eta);
5085 Ok(())
5086 }
5087 SmoothBasisSpec::Duchon { spec, .. } => {
5088 spec.aniso_log_scales = Some(eta);
5089 Ok(())
5090 }
5091 _ => Err(EstimationError::InvalidInput(format!(
5092 "term '{}' does not support aniso_log_scales",
5093 term.name
5094 ))),
5095 }
5096}
5097
5098pub fn get_constant_curvature_kappa(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
5117 constant_curvature_term_spec(spec, term_idx).map(|cc| cc.kappa)
5118}
5119
5120pub fn constant_curvature_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
5122 (0..spec.smooth_terms.len())
5123 .filter(|&idx| constant_curvature_term_spec(spec, idx).is_some())
5124 .collect()
5125}
5126
5127
/// One smooth term realized on its own: its local design, the built term, and
/// the penalty blocks that were dropped while building it.
#[derive(Debug, Clone)]
struct SingleSmoothTermRealization {
    /// Design matrix of this term alone.
    design_local: DesignMatrix,
    /// The built smooth term; its `penaltyinfo_local` entries carry an
    /// `active` flag (see `active_penaltyinfo`).
    term: SmoothTerm,
    /// Penalty blocks recorded as dropped for this term.
    dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
}
5134
5135impl SingleSmoothTermRealization {
5136 fn active_penaltyinfo(&self) -> Vec<PenaltyInfo> {
5137 self.term
5138 .penaltyinfo_local
5139 .iter()
5140 .filter(|info| info.active)
5141 .cloned()
5142 .collect()
5143 }
5144}
5145
5146fn build_single_smooth_term_realization(
5147 data: ArrayView2<'_, f64>,
5148 termspec: &SmoothTermSpec,
5149) -> Result<SingleSmoothTermRealization, BasisError> {
5150 let raw = build_smooth_design(data, std::slice::from_ref(termspec))?;
5151 finish_single_smooth_term_realization(raw)
5152}
5153
5154fn finish_single_smooth_term_realization(
5155 raw: RawSmoothDesign,
5156) -> Result<SingleSmoothTermRealization, BasisError> {
5157 let RawSmoothDesign {
5158 term_designs,
5159 dropped_penaltyinfo,
5160 terms,
5161 ..
5162 } = raw;
5163 let term = terms.into_iter().next().ok_or_else(|| {
5164 BasisError::InvalidInput("single-term smooth build returned no term".to_string())
5165 })?;
5166 let design = term_designs.into_iter().next().ok_or_else(|| {
5167 BasisError::InvalidInput("single-term smooth build returned no term design".to_string())
5168 })?;
5169
5170 Ok(SingleSmoothTermRealization {
5171 design_local: design,
5172 term,
5173 dropped_penaltyinfo,
5174 })
5175}
5176
5177fn wrap_local_build_as_realization(
5184 mut local: LocalSmoothTermBuild,
5185 termspec: &SmoothTermSpec,
5186) -> Result<SingleSmoothTermRealization, String> {
5187 let p_local = local.dim;
5188 let lb_local = if local.box_reparam {
5189 shape_lower_bounds_local(termspec.shape, p_local)
5190 } else {
5191 None
5192 };
5193
5194 let active_count = local.penaltyinfo.iter().filter(|info| info.active).count();
5195 if active_count != local.penalties.len() {
5196 return Err(format!(
5197 "internal penalty info mismatch for term '{}': active_infos={}, penalties={}",
5198 termspec.name,
5199 active_count,
5200 local.penalties.len()
5201 ));
5202 }
5203
5204 let mut dropped_penaltyinfo = Vec::<DroppedPenaltyBlockInfo>::new();
5205 for info in local.penaltyinfo.iter().filter(|info| !info.active) {
5206 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
5207 termname: Some(termspec.name.clone()),
5208 penalty: info.clone(),
5209 });
5210 }
5211 for info in &local.pre_dropped_penaltyinfo {
5212 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
5213 termname: Some(termspec.name.clone()),
5214 penalty: info.clone(),
5215 });
5216 }
5217
5218 let applied_rotation: Option<gam_terms::basis::JointNullRotation> = match (
5222 local.joint_null_rotation.take(),
5223 lb_local.is_some(),
5224 local.linear_constraints.is_some(),
5225 ) {
5226 (Some(rot), false, false) => {
5227 let q = &rot.rotation;
5228 let dense = local
5229 .design
5230 .try_to_dense_by_chunks("joint-null absorption rotation (single realization)")
5231 .map_err(|e| {
5232 format!(
5233 "joint-null absorption rotation: dense conversion failed for term '{}': {}",
5234 termspec.name, e
5235 )
5236 })?;
5237 let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
5238 local.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
5239 local.penalties = local
5240 .penalties
5241 .into_iter()
5242 .map(|s_local| {
5243 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
5244 gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
5245 })
5246 .collect();
5247 local.ops = vec![None; local.penalties.len()];
5248 local.kronecker_factored = None;
5249 Some(rot)
5250 }
5251 (Some(_), _, _) => None,
5252 (None, _, _) => None,
5253 };
5254
5255 let smooth_term = SmoothTerm {
5256 name: termspec.name.clone(),
5257 coeff_range: 0..p_local,
5258 shape: termspec.shape,
5259 penalties_local: local.penalties.clone(),
5260 nullspace_dims: local.nullspaces.clone(),
5261 penaltyinfo_local: local.penaltyinfo.clone(),
5262 metadata: local.metadata.clone(),
5263 lower_bounds_local: lb_local,
5264 linear_constraints_local: local.linear_constraints.clone(),
5265 kronecker_factored: local.kronecker_factored.take(),
5266 joint_null_rotation: applied_rotation,
5267 unabsorbed_global_orthogonality: None,
5270 };
5271
5272 Ok(SingleSmoothTermRealization {
5273 design_local: local.design,
5274 term: smooth_term,
5275 dropped_penaltyinfo,
5276 })
5277}
5278
5279fn freeze_geometry_from_metadata(
5290 termspec: &SmoothTermSpec,
5291 metadata: &BasisMetadata,
5292) -> Option<SmoothTermSpec> {
5293 let mut frozen = termspec.clone();
5294 match (&mut frozen.basis, metadata) {
5295 (
5296 SmoothBasisSpec::Matern {
5297 spec,
5298 input_scales: spec_scales,
5299 ..
5300 },
5301 BasisMetadata::Matern {
5302 centers,
5303 input_scales: meta_scales,
5304 identifiability_transform,
5305 nullspace_shrinkage_survived,
5306 ..
5307 },
5308 ) => {
5309 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5310 if spec_scales.is_none()
5311 && let Some(s) = meta_scales.clone()
5312 {
5313 *spec_scales = Some(s);
5314 }
5315 if let Some(transform) = identifiability_transform.clone() {
5333 spec.identifiability = MaternIdentifiability::FrozenTransform {
5334 transform,
5335 nullspace_shrinkage_survived: Some(*nullspace_shrinkage_survived),
5336 };
5337 }
5338 Some(frozen)
5339 }
5340 (
5341 SmoothBasisSpec::Duchon {
5342 spec,
5343 input_scales: spec_scales,
5344 ..
5345 },
5346 BasisMetadata::Duchon {
5347 centers,
5348 input_scales: meta_scales,
5349 ..
5350 },
5351 ) => {
5352 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5353 if spec_scales.is_none()
5354 && let Some(s) = meta_scales.clone()
5355 {
5356 *spec_scales = Some(s);
5357 }
5358 Some(frozen)
5359 }
5360 (
5361 SmoothBasisSpec::ThinPlate {
5362 spec,
5363 input_scales: spec_scales,
5364 ..
5365 },
5366 BasisMetadata::ThinPlate {
5367 centers,
5368 input_scales: meta_scales,
5369 ..
5370 },
5371 ) => {
5372 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5373 if spec_scales.is_none()
5374 && let Some(s) = meta_scales.clone()
5375 {
5376 *spec_scales = Some(s);
5377 }
5378 Some(frozen)
5379 }
5380 _ => None,
5383 }
5384}
5385
5386fn rebuild_smooth_auxiliary_state(
5387 smooth: &mut SmoothDesign,
5388 dropped_penaltyinfo_by_term: &[Vec<DroppedPenaltyBlockInfo>],
5389) -> Result<(), String> {
5390 if dropped_penaltyinfo_by_term.len() != smooth.terms.len() {
5391 return Err(SmoothError::dimension_mismatch(format!(
5392 "smooth dropped-penalty cache mismatch: terms={}, dropped_sets={}",
5393 smooth.terms.len(),
5394 dropped_penaltyinfo_by_term.len()
5395 ))
5396 .into());
5397 }
5398
5399 let total_p = smooth.total_smooth_cols();
5400 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
5401 let mut any_bounds = false;
5402 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5403 let mut linear_constraint_b: Vec<f64> = Vec::new();
5404
5405 for term in &smooth.terms {
5406 let range = term.coeff_range.clone();
5407 if let Some(lb_local) = term.lower_bounds_local.as_ref() {
5408 if lb_local.len() != range.len() {
5409 return Err(SmoothError::dimension_mismatch(format!(
5410 "smooth lower-bound cache mismatch for term '{}': bounds={}, coeffs={}",
5411 term.name,
5412 lb_local.len(),
5413 range.len()
5414 ))
5415 .into());
5416 }
5417 coefficient_lower_bounds
5418 .slice_mut(s![range.clone()])
5419 .assign(lb_local);
5420 any_bounds = true;
5421 }
5422 if let Some(lin_local) = term.linear_constraints_local.as_ref() {
5423 if lin_local.a.ncols() != range.len() {
5424 return Err(SmoothError::dimension_mismatch(format!(
5425 "smooth linear-constraint cache mismatch for term '{}': cols={}, coeffs={}",
5426 term.name,
5427 lin_local.a.ncols(),
5428 range.len()
5429 ))
5430 .into());
5431 }
5432 for r in 0..lin_local.a.nrows() {
5433 let mut row = Array1::<f64>::zeros(total_p);
5434 row.slice_mut(s![range.clone()]).assign(&lin_local.a.row(r));
5435 linear_constraintrows.push(row);
5436 linear_constraint_b.push(lin_local.b[r]);
5437 }
5438 }
5439 }
5440
5441 smooth.coefficient_lower_bounds = if any_bounds {
5442 Some(coefficient_lower_bounds)
5443 } else {
5444 None
5445 };
5446 smooth.linear_constraints = if linear_constraintrows.is_empty() {
5447 None
5448 } else {
5449 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), total_p));
5450 for (i, row) in linear_constraintrows.iter().enumerate() {
5451 a.row_mut(i).assign(row);
5452 }
5453 Some(LinearInequalityConstraints {
5454 a,
5455 b: Array1::from_vec(linear_constraint_b),
5456 })
5457 };
5458 smooth.dropped_penaltyinfo = dropped_penaltyinfo_by_term
5459 .iter()
5460 .flat_map(|infos| infos.iter().cloned())
5461 .collect();
5462 Ok(())
5463}
5464
5465fn rebuild_term_collection_auxiliary_state(
5466 spec: &TermCollectionSpec,
5467 design: &mut TermCollectionDesign,
5468) -> Result<(), String> {
5469 if spec.linear_terms.len() != design.linear_ranges.len() {
5470 return Err(SmoothError::dimension_mismatch(format!(
5471 "term-collection linear bookkeeping mismatch: spec_terms={}, design_ranges={}",
5472 spec.linear_terms.len(),
5473 design.linear_ranges.len()
5474 ))
5475 .into());
5476 }
5477
5478 let p_total = design.design.ncols();
5479 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
5480 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
5481 let mut any_bounds = false;
5482 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5483 let mut linear_constraint_b: Vec<f64> = Vec::new();
5484
5485 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
5486 if range.len() != 1 {
5487 return Err(SmoothError::dimension_mismatch(format!(
5488 "linear term '{}' expected one coefficient column, found {}",
5489 linear.name,
5490 range.len()
5491 ))
5492 .into());
5493 }
5494 let col = range.start;
5495 if let Some(lb) = linear.coefficient_min {
5496 let mut row = Array1::<f64>::zeros(p_total);
5497 row[col] = 1.0;
5498 linear_constraintrows.push(row);
5499 linear_constraint_b.push(lb);
5500 }
5501 if let Some(ub) = linear.coefficient_max {
5502 let mut row = Array1::<f64>::zeros(p_total);
5503 row[col] = -1.0;
5504 linear_constraintrows.push(row);
5505 linear_constraint_b.push(-ub);
5506 }
5507 }
5508
5509 if let Some(lb_smooth) = design.smooth.coefficient_lower_bounds.as_ref() {
5510 if lb_smooth.len() != design.smooth.total_smooth_cols() {
5511 return Err(SmoothError::dimension_mismatch(format!(
5512 "smooth lower-bound width mismatch: bounds={}, smooth_cols={}",
5513 lb_smooth.len(),
5514 design.smooth.total_smooth_cols()
5515 ))
5516 .into());
5517 }
5518 coefficient_lower_bounds
5519 .slice_mut(s![
5520 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5521 ])
5522 .assign(lb_smooth);
5523 any_bounds = true;
5524 }
5525 if let Some(lin_smooth) = design.smooth.linear_constraints.as_ref() {
5526 if lin_smooth.a.ncols() != design.smooth.total_smooth_cols() {
5527 return Err(SmoothError::dimension_mismatch(format!(
5528 "smooth linear-constraint width mismatch: cols={}, smooth_cols={}",
5529 lin_smooth.a.ncols(),
5530 design.smooth.total_smooth_cols()
5531 ))
5532 .into());
5533 }
5534 let mut a_global = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
5535 a_global
5536 .slice_mut(s![
5537 ..,
5538 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5539 ])
5540 .assign(&lin_smooth.a);
5541 for r in 0..a_global.nrows() {
5542 linear_constraintrows.push(a_global.row(r).to_owned());
5543 linear_constraint_b.push(lin_smooth.b[r]);
5544 }
5545 }
5546
5547 let lower_bound_constraints = if any_bounds {
5548 linear_constraints_from_lower_bounds_global(&coefficient_lower_bounds)
5549 } else {
5550 None
5551 };
5552 let explicit_linear_constraints = if linear_constraintrows.is_empty() {
5553 None
5554 } else {
5555 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), p_total));
5556 for (i, row) in linear_constraintrows.iter().enumerate() {
5557 a.row_mut(i).assign(row);
5558 }
5559 Some(LinearInequalityConstraints {
5560 a,
5561 b: Array1::from_vec(linear_constraint_b),
5562 })
5563 };
5564
5565 design.coefficient_lower_bounds = if any_bounds {
5566 Some(coefficient_lower_bounds)
5567 } else {
5568 None
5569 };
5570 design.linear_constraints =
5571 merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
5572 design.dropped_penaltyinfo = design.smooth.dropped_penaltyinfo.clone();
5573 Ok(())
5574}
5575
5576fn theta_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5577 left.len() == right.len()
5578 && left
5579 .iter()
5580 .zip(right.iter())
5581 .all(|(&l, &r)| l.to_bits() == r.to_bits())
5582}
5583
5584fn latent_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5585 theta_values_match(left, right)
5586}
5587
/// Compares two optional anisotropy vectors bit for bit.
///
/// Two `None`s match; a `None` never matches a `Some`; two slices match when
/// they have the same length and every entry has the same bit pattern.
fn spatial_aniso_matches(left: Option<&[f64]>, right: Option<&[f64]>) -> bool {
    let (Some(lhs), Some(rhs)) = (left, right) else {
        // At least one side is absent: they agree only if both are.
        return left.is_none() && right.is_none();
    };
    lhs.len() == rhs.len()
        && lhs
            .iter()
            .map(|value| value.to_bits())
            .eq(rhs.iter().map(|value| value.to_bits()))
}
5600
/// Compares two optional length scales bit for bit: both absent, or both
/// present with identical bit patterns.
fn spatial_length_scale_matches(left: Option<f64>, right: Option<f64>) -> bool {
    left.map(f64::to_bits) == right.map(f64::to_bits)
}
5608
/// Keeps a term-collection spec and its realized design together so that
/// individual spatial smooth terms can be rebuilt in place when their spatial
/// hyperparameters change.
struct FrozenTermCollectionIncrementalRealizer<'d> {
    /// Borrowed data matrix handed to every (re)build of a term.
    data: ArrayView2<'d, f64>,
    /// Live spec; its spatial hyperparameters are overwritten as new values are applied.
    spec: TermCollectionSpec,
    /// Current realized design; individual term realizations are swapped inside it.
    design: TermCollectionDesign,
    /// Output of `build_term_collection_fixed_blocks`, computed once in `new`.
    fixed_blocks: Vec<DesignBlock>,
    /// Dropped penalty blocks, one list per smooth term (same order as `spec.smooth_terms`).
    dropped_penaltyinfo_by_term: Vec<Vec<DroppedPenaltyBlockInfo>>,
    /// Per-term index ranges into `design.smooth.penalties`.
    smooth_penalty_ranges: Vec<Range<usize>>,
    /// The same ranges shifted by the number of penalties that precede the
    /// smooth penalties in `design.penalties`.
    full_penalty_ranges: Vec<Range<usize>>,
    /// Scratch workspace passed to the basis builders.
    basisworkspace: gam_terms::basis::BasisWorkspace,
    /// Per-term spec with frozen geometry; `None` until the first incremental
    /// rebuild for which `freeze_geometry_from_metadata` returns a spec.
    spatial_realization_geometry: Vec<Option<SmoothTermSpec>>,
    /// Counter incremented (wrapping) each time `apply_log_kappa` changes the design.
    design_revision: u64,
}
5641
5642impl<'d> std::fmt::Debug for FrozenTermCollectionIncrementalRealizer<'d> {
5643 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5644 f.debug_struct("FrozenTermCollectionIncrementalRealizer")
5645 .field("data_shape", &(self.data.nrows(), self.data.ncols()))
5646 .field("fixed_blocks", &self.fixed_blocks.len())
5647 .finish_non_exhaustive()
5648 }
5649}
5650
5651impl<'d> FrozenTermCollectionIncrementalRealizer<'d> {
5652 fn new(
5653 data: ArrayView2<'d, f64>,
5654 spec: TermCollectionSpec,
5655 design: TermCollectionDesign,
5656 ) -> Result<Self, String> {
5657 if spec.smooth_terms.len() != design.smooth.terms.len() {
5658 return Err(SmoothError::dimension_mismatch(format!(
5659 "incremental realizer smooth term mismatch: spec_terms={}, design_terms={}",
5660 spec.smooth_terms.len(),
5661 design.smooth.terms.len()
5662 ))
5663 .into());
5664 }
5665
5666 let mut smooth_cursor = 0usize;
5667 let mut smooth_penalty_ranges = Vec::with_capacity(design.smooth.terms.len());
5668 for term in &design.smooth.terms {
5669 let next = smooth_cursor + term.penalties_local.len();
5670 smooth_penalty_ranges.push(smooth_cursor..next);
5671 smooth_cursor = next;
5672 }
5673 if smooth_cursor != design.smooth.penalties.len() {
5674 return Err(SmoothError::dimension_mismatch(format!(
5675 "incremental realizer smooth penalty mismatch: ranged={}, actual={}",
5676 smooth_cursor,
5677 design.smooth.penalties.len()
5678 ))
5679 .into());
5680 }
5681
5682 let fixed_penalty_offset = design
5683 .penalties
5684 .len()
5685 .checked_sub(design.smooth.penalties.len())
5686 .ok_or_else(|| {
5687 "incremental realizer encountered invalid penalty bookkeeping".to_string()
5688 })?;
5689 let full_penalty_ranges = smooth_penalty_ranges
5690 .iter()
5691 .map(|range| (fixed_penalty_offset + range.start)..(fixed_penalty_offset + range.end))
5692 .collect::<Vec<_>>();
5693 let fixed_blocks = build_term_collection_fixed_blocks(data, &spec)
5694 .map_err(|e| format!("failed to cache fixed term-collection blocks: {e}"))?;
5695
5696 let mut dropped_penaltyinfo_by_term = Vec::with_capacity(spec.smooth_terms.len());
5697 for (term_idx, termspec) in spec.smooth_terms.iter().enumerate() {
5698 let realization =
5699 build_single_smooth_term_realization(data, termspec).map_err(|e| {
5700 format!(
5701 "failed to build cached realization for smooth term '{}' (index {}): {e}",
5702 termspec.name, term_idx
5703 )
5704 })?;
5705 let expected_cols = design.smooth.terms[term_idx].coeff_range.len();
5706 if realization.design_local.ncols() != expected_cols {
5707 return Err(SmoothError::dimension_mismatch(format!(
5708 "cached realization width mismatch for term '{}': cached_cols={}, design_cols={}",
5709 termspec.name,
5710 realization.design_local.ncols(),
5711 expected_cols
5712 ))
5713 .into());
5714 }
5715 if realization.active_penaltyinfo().len()
5716 != design.smooth.terms[term_idx].penalties_local.len()
5717 {
5718 return Err(SmoothError::dimension_mismatch(format!(
5719 "cached realization penalty mismatch for term '{}': cached_penalties={}, design_penalties={}",
5720 termspec.name,
5721 realization.active_penaltyinfo().len(),
5722 design.smooth.terms[term_idx].penalties_local.len()
5723 ))
5724 .into());
5725 }
5726 dropped_penaltyinfo_by_term.push(realization.dropped_penaltyinfo);
5727 }
5728
5729 let geometry_slots = spec.smooth_terms.len();
5730 Ok(Self {
5731 data,
5732 spec,
5733 design,
5734 fixed_blocks,
5735 dropped_penaltyinfo_by_term,
5736 smooth_penalty_ranges,
5737 full_penalty_ranges,
5738 basisworkspace: gam_terms::basis::BasisWorkspace::new(),
5739 spatial_realization_geometry: vec![None; geometry_slots],
5740 design_revision: 0,
5741 })
5742 }
5743
    /// Current value of the design revision counter, which `apply_log_kappa`
    /// increments whenever it changes the design.
    fn design_revision(&self) -> u64 {
        self.design_revision
    }
5747
    /// Borrows the live term-collection spec held by the realizer.
    fn spec(&self) -> &TermCollectionSpec {
        &self.spec
    }
5751
    /// Borrows the currently realized term-collection design.
    fn design(&self) -> &TermCollectionDesign {
        &self.design
    }
5755
5756 fn supports_nfree_penalty_rekey(&self, spatial_terms: &[usize]) -> bool {
5771 if spatial_terms.len() != 1 {
5772 return false;
5773 }
5774 let term_idx = spatial_terms[0];
5775 matches!(
5776 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5777 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5778 )
5779 }
5780
5781 fn supports_nfree_gradient_only_routing(&self, spatial_terms: &[usize]) -> bool {
5790 if spatial_terms.len() != 1 {
5791 return false;
5792 }
5793 let term_idx = spatial_terms[0];
5794 matches!(
5795 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5796 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5797 )
5798 }
5799
    /// Recomputes the penalty blocks of a single spatial term at the
    /// hyperparameters encoded by `psi` and canonicalizes them, reusing the
    /// frozen geometry stored in the term's realized metadata.
    ///
    /// `psi` is decoded by `spatial_term_psi_to_length_scale_and_aniso` into an
    /// optional length scale and optional anisotropy log-scales. The local
    /// penalty matrices are then rebuilt per basis kind (Duchon, Matérn,
    /// thin-plate) and paired, position by position, with the column range,
    /// prior mean, structure hint and operator of the penalties already present
    /// in `self.design.penalties`.
    ///
    /// Returns whatever `canonicalize_penalty_specs` produces for those blocks.
    ///
    /// # Errors
    /// * `spatial_terms` does not hold exactly one index;
    /// * the index has no spec or no realized term;
    /// * a Matérn or thin-plate term decodes to no length scale;
    /// * the metadata is of an unsupported basis kind;
    /// * the number of rebuilt blocks differs from `self.design.penalties.len()`;
    /// * any underlying penalty builder or the canonicalization fails.
    fn canonical_penalties_at_psi(
        &mut self,
        spatial_terms: &[usize],
        psi: &[f64],
    ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
        if spatial_terms.len() != 1 {
            return Err(format!(
                "n-free penalty re-key requires exactly one spatial term, found {}",
                spatial_terms.len()
            ));
        }
        let term_idx = spatial_terms[0];
        let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
        let termspec =
            self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
                format!("spatial term {term_idx} out of range for n-free penalty")
            })?;
        let term = self
            .design
            .smooth
            .terms
            .get(term_idx)
            .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
        let p_total = self.design.design.ncols();
        let (locals, nullspace_dims): (Vec<Array2<f64>>, Vec<usize>) = match &term.metadata {
            BasisMetadata::Duchon {
                centers,
                identifiability_transform,
                operator_collocation_points,
                power,
                nullspace_order,
                aniso_log_scales,
                input_scales,
                radial_reparam,
                ..
            } => {
                // Operator-penalty configuration comes from the spec; a non-Duchon
                // spec falls back to the default configuration.
                let operator_penalties = match &termspec.basis {
                    SmoothBasisSpec::Duchon { spec, .. } => spec.operator_penalties.clone(),
                    _ => gam_terms::basis::DuchonOperatorPenaltySpec::default(),
                };
                // Adjust the length scale when the inputs were standardized.
                let effective_ls = match input_scales.as_deref() {
                    Some(scales) => {
                        compensate_optional_length_scale_for_standardization(ls_opt, scales)
                    }
                    None => ls_opt,
                };
                // NOTE(review): this arm uses the metadata's anisotropy and ignores
                // `aniso_from_psi`, unlike the Matérn arm — confirm that is intended.
                gam_terms::basis::duchon_penalties_at_length_scale(
                    centers.view(),
                    identifiability_transform.as_ref(),
                    operator_collocation_points.as_ref().map(|p| p.view()),
                    &operator_penalties,
                    *power,
                    *nullspace_order,
                    aniso_log_scales.as_deref(),
                    radial_reparam.as_ref(),
                    effective_ls,
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?
            }
            BasisMetadata::Matern {
                centers,
                periodic,
                nu,
                include_intercept,
                identifiability_transform,
                aniso_log_scales,
                input_scales,
                ..
            } => {
                let ls = ls_opt.ok_or_else(|| {
                    "Matérn n-free penalty re-key requires a finite length-scale".to_string()
                })?;
                // Adjust the length scale when the inputs were standardized.
                let effective_ls = match input_scales.as_deref() {
                    Some(scales) => compensate_length_scale_for_standardization(ls, scales),
                    None => ls,
                };
                // Anisotropy decoded from psi wins over the value frozen in the metadata.
                let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
                let (penalties, nullspace_dims, _info) =
                    matern_operator_penalty_triplet_at_length_scale(
                        centers.view(),
                        periodic.as_deref(),
                        identifiability_transform.as_ref(),
                        *nu,
                        *include_intercept,
                        aniso_for_penalty,
                        effective_ls,
                    )
                    .map_err(|e| e.to_string())?;
                (penalties, nullspace_dims)
            }
            BasisMetadata::ThinPlate {
                centers,
                identifiability_transform,
                radial_reparam,
                ..
            } => {
                let ls = ls_opt.ok_or_else(|| {
                    "thin-plate n-free penalty re-key requires a finite length-scale".to_string()
                })?;
                let double_penalty = match &termspec.basis {
                    SmoothBasisSpec::ThinPlate { spec, .. } => spec.double_penalty,
                    _ => false,
                };
                // NOTE(review): unlike the Duchon and Matérn arms, `ls` is passed on
                // without any input-scale compensation — confirm that is intended.
                gam_terms::basis::thin_plate_penalties_at_length_scale(
                    centers.view(),
                    identifiability_transform.as_ref(),
                    radial_reparam.as_ref(),
                    ls,
                    double_penalty,
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?
            }
            other => {
                return Err(format!(
                    "n-free penalty re-key unsupported for basis metadata {:?}",
                    std::mem::discriminant(other)
                ));
            }
        };
        // NOTE(review): the count is compared against every penalty of the design,
        // so this presumably assumes the spatial term owns all of them — confirm.
        let templates = &self.design.penalties;
        if templates.len() != locals.len() {
            return Err(format!(
                "n-free penalty re-key produced {} blocks but the frozen design carries {} \
                 — penalty topology is not ψ-stable",
                locals.len(),
                templates.len()
            ));
        }
        // Keep each existing penalty's placement and hints; swap in the rebuilt matrix.
        let specs: Vec<gam_solve::estimate::PenaltySpec> = templates
            .iter()
            .zip(locals.into_iter())
            .map(|(tmpl, local)| gam_solve::estimate::PenaltySpec::Block {
                local,
                col_range: tmpl.col_range.clone(),
                prior_mean: tmpl.prior_mean.clone(),
                structure_hint: tmpl.structure_hint.clone(),
                op: tmpl.op.clone(),
            })
            .collect();
        gam_terms::construction::canonicalize_penalty_specs(
            &specs,
            &nullspace_dims,
            p_total,
            "nfree-psi-penalty",
        )
        .map_err(|e| e.to_string())
    }
5996
5997 fn canonical_penalty_derivatives_at_psi(
5998 &mut self,
5999 spatial_terms: &[usize],
6000 psi: &[f64],
6001 ) -> Result<(Range<usize>, usize, Vec<Array2<f64>>), String> {
6002 if spatial_terms.len() != 1 {
6003 return Err(format!(
6004 "n-free penalty derivative re-key requires exactly one spatial term, found {}",
6005 spatial_terms.len()
6006 ));
6007 }
6008 let term_idx = spatial_terms[0];
6009 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
6010 let termspec = self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
6011 format!("spatial term {term_idx} out of range for n-free penalty derivative")
6012 })?;
6013 let term = self
6014 .design
6015 .smooth
6016 .terms
6017 .get(term_idx)
6018 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
6019 let p_total = self.design.design.ncols();
6020 let smooth_start = p_total.saturating_sub(self.design.smooth.total_smooth_cols());
6021 let global_range =
6022 (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
6023
6024 let locals = match &term.metadata {
6025 BasisMetadata::Duchon {
6026 centers,
6027 identifiability_transform,
6028 operator_collocation_points,
6029 power,
6030 nullspace_order,
6031 aniso_log_scales,
6032 input_scales,
6033 radial_reparam,
6034 ..
6035 } => {
6036 let mut spec = match &termspec.basis {
6037 SmoothBasisSpec::Duchon { spec, .. } => spec.clone(),
6038 _ => {
6039 return Err(
6040 "Duchon n-free penalty derivative requires a Duchon term spec"
6041 .to_string(),
6042 );
6043 }
6044 };
6045 let effective_ls = match input_scales.as_deref() {
6046 Some(scales) => {
6047 compensate_optional_length_scale_for_standardization(ls_opt, scales)
6048 }
6049 None => ls_opt,
6050 };
6051 spec.length_scale = effective_ls;
6052 spec.power = *power;
6053 spec.nullspace_order = *nullspace_order;
6054 spec.aniso_log_scales = aniso_log_scales.clone();
6055 spec.radial_reparam = radial_reparam.clone();
6058 if spec.length_scale.is_none() {
6059 return Err(
6060 "Duchon n-free penalty derivative requires a hybrid length-scale"
6061 .to_string(),
6062 );
6063 }
6064 let collocation = operator_collocation_points
6065 .as_ref()
6066 .map(|points| points.view())
6067 .unwrap_or_else(|| centers.view());
6068 let (_native_sources, mut first, _native_second) =
6069 gam_terms::basis::build_duchon_native_penalty_psi_derivatives(
6070 centers.view(),
6071 &spec,
6072 identifiability_transform.as_ref(),
6073 &mut self.basisworkspace,
6074 )
6075 .map_err(|e| e.to_string())?;
6076 let (_operator_sources, operator_first, _operator_second) =
6077 gam_terms::basis::build_duchon_operator_penalty_psi_derivatives(
6078 collocation,
6079 centers.view(),
6080 &spec,
6081 identifiability_transform.as_ref(),
6082 &mut self.basisworkspace,
6083 )
6084 .map_err(|e| e.to_string())?;
6085 first.extend(operator_first);
6086 first
6087 }
6088 BasisMetadata::Matern {
6089 centers,
6090 periodic,
6091 nu,
6092 include_intercept,
6093 identifiability_transform,
6094 aniso_log_scales,
6095 input_scales,
6096 ..
6097 } => {
6098 let ls = ls_opt.ok_or_else(|| {
6099 "Matérn n-free penalty derivative requires a finite length-scale".to_string()
6100 })?;
6101 let effective_ls = match input_scales.as_deref() {
6102 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
6103 None => ls,
6104 };
6105 let penalty_centers =
6106 gam_terms::basis::expand_periodic_centers(¢ers.to_owned(), periodic.as_deref())
6107 .map_err(|e| e.to_string())?;
6108 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
6109 let (first, _second) = gam_terms::basis::build_matern_operator_penalty_psi_derivatives(
6110 penalty_centers.view(),
6111 effective_ls,
6112 *nu,
6113 *include_intercept,
6114 identifiability_transform.as_ref(),
6115 aniso_for_penalty,
6116 )
6117 .map_err(|e| e.to_string())?;
6118 first
6119 }
6120 BasisMetadata::ThinPlate {
6121 centers,
6122 identifiability_transform,
6123 radial_reparam,
6124 ..
6125 } => {
6126 let ls = ls_opt.ok_or_else(|| {
6127 "thin-plate n-free penalty derivative requires a finite length-scale"
6128 .to_string()
6129 })?;
6130 let mut spec = match &termspec.basis {
6131 SmoothBasisSpec::ThinPlate { spec, .. } => spec.clone(),
6132 _ => {
6133 return Err(
6134 "thin-plate n-free penalty derivative requires a ThinPlate term spec"
6135 .to_string(),
6136 );
6137 }
6138 };
6139 spec.length_scale = ls;
6140 if spec.radial_reparam.is_none() {
6141 spec.radial_reparam = radial_reparam.clone();
6142 }
6143 let (primary, _primary_second) =
6144 gam_terms::basis::build_thin_plate_penalty_psi_derivativeswithworkspace(
6145 centers.view(),
6146 &spec,
6147 identifiability_transform.as_ref(),
6148 &mut self.basisworkspace,
6149 )
6150 .map_err(|e| e.to_string())?;
6151 if self.design.penalties.len() > 1 {
6152 vec![primary.clone(), Array2::<f64>::zeros(primary.raw_dim())]
6153 } else {
6154 vec![primary]
6155 }
6156 }
6157 other => {
6158 return Err(format!(
6159 "n-free penalty derivative re-key unsupported for basis metadata {:?}",
6160 std::mem::discriminant(other)
6161 ));
6162 }
6163 };
6164 if locals.len() != self.design.penalties.len() {
6165 return Err(format!(
6166 "n-free penalty derivative re-key produced {} blocks but the frozen design carries {} \
6167 — penalty topology is not ψ-stable",
6168 locals.len(),
6169 self.design.penalties.len()
6170 ));
6171 }
6172 Ok((global_range, p_total, locals))
6173 }
6174
6175 fn apply_log_kappa(
6176 &mut self,
6177 log_kappa: &SpatialLogKappaCoords,
6178 term_indices: &[usize],
6179 ) -> Result<(), String> {
6180 if term_indices.len() != log_kappa.dims_per_term().len() {
6181 return Err(SmoothError::dimension_mismatch(format!(
6182 "incremental realizer log-kappa term mismatch: term_indices={}, dims_per_term={}",
6183 term_indices.len(),
6184 log_kappa.dims_per_term().len()
6185 ))
6186 .into());
6187 }
6188
6189 let mut any_changed = false;
6190 for (slot, &term_idx) in term_indices.iter().enumerate() {
6191 any_changed |= self.apply_log_kappa_to_term(term_idx, log_kappa.term_slice(slot))?;
6192 }
6193
6194 if any_changed {
6195 self.refresh_full_design_operator()?;
6196 rebuild_smooth_auxiliary_state(
6197 &mut self.design.smooth,
6198 &self.dropped_penaltyinfo_by_term,
6199 )?;
6200 rebuild_term_collection_auxiliary_state(&self.spec, &mut self.design)?;
6201 self.design_revision = self.design_revision.wrapping_add(1);
6202 }
6203 Ok(())
6204 }
6205
    /// Applies the hyperparameters `psi` to one smooth term and rebuilds its
    /// realization when they differ from the current ones.
    ///
    /// Returns `Ok(true)` when the term was rebuilt and replaced in the design,
    /// and `Ok(false)` when the hyperparameters were unchanged (the measure-jet
    /// or constant-curvature setter returned `false`, or the decoded length
    /// scale and anisotropy are bitwise equal to the current ones).
    ///
    /// # Errors
    /// Fails when the term exposes no spatial hyperparameters, when a setter
    /// rejects `psi`, when the term index is out of range, or when rebuilding,
    /// wrapping or replacing the term fails. The live spec is updated before
    /// the rebuild, so it keeps the new values if a later step fails.
    fn apply_log_kappa_to_term(&mut self, term_idx: usize, psi: &[f64]) -> Result<bool, String> {
        if !spatial_term_supports_hyper_optimization(&self.spec, term_idx) {
            return Err(SmoothError::invalid_config(format!(
                "incremental realizer term {term_idx} does not expose spatial hyperparameters"
            ))
            .into());
        }
        let measure_jet_term = measure_jet_term_spec(&self.spec, term_idx).is_some();
        let constant_curvature_term = constant_curvature_term_spec(&self.spec, term_idx).is_some();
        // Only filled on the generic (length-scale / anisotropy) path below.
        let mut next_length_scale = None;
        let mut next_aniso: Option<Vec<f64>> = None;
        // Step 1: write the new hyperparameters into the live spec; bail out early if unchanged.
        if measure_jet_term {
            if !set_measure_jet_psi_dials(&mut self.spec, term_idx, psi)
                .map_err(|e| e.to_string())?
            {
                return Ok(false);
            }
        } else if constant_curvature_term {
            if !set_constant_curvature_kappa(&mut self.spec, term_idx, psi)
                .map_err(|e| e.to_string())?
            {
                return Ok(false);
            }
        } else {
            let current_length_scale = get_spatial_length_scale(&self.spec, term_idx);
            let current_aniso = get_spatial_aniso_log_scales(&self.spec, term_idx);
            let (ls, eta) = spatial_term_psi_to_length_scale_and_aniso(psi);
            next_length_scale = ls;
            next_aniso = eta;
            // Bitwise comparison: any change in the bit pattern triggers a rebuild.
            let same_length = spatial_length_scale_matches(current_length_scale, next_length_scale);
            let same_aniso = spatial_aniso_matches(current_aniso.as_deref(), next_aniso.as_deref());
            if same_length && same_aniso {
                return Ok(false);
            }
            if let Some(length_scale) = next_length_scale {
                set_spatial_length_scale(&mut self.spec, term_idx, length_scale)
                    .map_err(|e| e.to_string())?;
            }
            if let Some(eta) = next_aniso.clone() {
                set_spatial_aniso_log_scales(&mut self.spec, term_idx, eta)
                    .map_err(|e| e.to_string())?;
            }
        }

        // Step 2: pick the spec to build from — the frozen-geometry copy when one
        // has been cached, otherwise a clone of the live spec.
        let geometry_slot = self
            .spatial_realization_geometry
            .get(term_idx)
            .ok_or_else(|| format!("incremental realizer geometry slot {term_idx} out of range"))?;
        let mut build_spec = match geometry_slot {
            Some(cached) => cached.clone(),
            None => self
                .spec
                .smooth_terms
                .get(term_idx)
                .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
                .clone(),
        };
        // Apply the same hyperparameters to the build spec; this matters for the
        // cached copy, which still holds the values from when it was frozen.
        if measure_jet_term {
            set_single_term_measure_jet_psi_dials(&mut build_spec, psi)
                .map_err(|e| e.to_string())?;
        } else if constant_curvature_term {
            set_single_term_constant_curvature_kappa(&mut build_spec, psi)
                .map_err(|e| e.to_string())?;
        } else {
            if let Some(length_scale) = next_length_scale {
                set_single_term_spatial_length_scale(&mut build_spec, length_scale)
                    .map_err(|e| e.to_string())?;
            }
            if let Some(eta) = next_aniso {
                set_single_term_spatial_aniso_log_scales(&mut build_spec, eta)
                    .map_err(|e| e.to_string())?;
            }
        }

        // Step 3: rebuild the term on its own.
        let termname = build_spec.name.clone();
        let local = build_single_local_smooth_term(
            self.data,
            &build_spec,
            &mut self.basisworkspace,
        )
        .map_err(|e| {
            format!(
                "failed to rebuild smooth term '{termname}' during incremental κ realization: {e}"
            )
        })?;

        // Step 4: on the first rebuild, freeze the geometry the build produced so
        // later rebuilds reuse it (only for basis kinds the freeze helper supports).
        if self.spatial_realization_geometry[term_idx].is_none()
            && let Some(frozen) = freeze_geometry_from_metadata(&build_spec, &local.metadata)
        {
            // For Matérn terms, mirror the frozen identifiability and centers
            // into the live spec as well.
            if let (
                SmoothBasisSpec::Matern {
                    spec: frozen_spec, ..
                },
                Some(SmoothBasisSpec::Matern {
                    spec: live_spec, ..
                }),
            ) = (
                &frozen.basis,
                self.spec
                    .smooth_terms
                    .get_mut(term_idx)
                    .map(|t| &mut t.basis),
            ) {
                live_spec.identifiability = frozen_spec.identifiability.clone();
                live_spec.center_strategy = frozen_spec.center_strategy.clone();
            }
            self.spatial_realization_geometry[term_idx] = Some(frozen);
        }

        // Step 5: wrap the local build and swap it into the design.
        let realization = wrap_local_build_as_realization(local, &build_spec)?;
        self.replace_term_realization(term_idx, realization)?;
        Ok(true)
    }
6357
6358 fn replace_term_realization(
6359 &mut self,
6360 term_idx: usize,
6361 realization: SingleSmoothTermRealization,
6362 ) -> Result<(), String> {
6363 let t_replace = std::time::Instant::now();
6364 let SingleSmoothTermRealization {
6365 design_local,
6366 term,
6367 dropped_penaltyinfo,
6368 } = realization;
6369 let SmoothTerm {
6370 name,
6371 penalties_local,
6372 nullspace_dims,
6373 penaltyinfo_local,
6374 metadata,
6375 lower_bounds_local,
6376 linear_constraints_local,
6377 joint_null_rotation,
6378 ..
6379 } = term;
6380 let coeff_range = self
6381 .design
6382 .smooth
6383 .terms
6384 .get(term_idx)
6385 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
6386 .coeff_range
6387 .clone();
6388 if design_local.ncols() != coeff_range.len() {
6389 return Err(SmoothError::dimension_mismatch(format!(
6390 "incremental realizer width mismatch for term {}: rebuilt_cols={}, cached_cols={}",
6391 term_idx,
6392 design_local.ncols(),
6393 coeff_range.len()
6394 ))
6395 .into());
6396 }
6397 if design_local.nrows() != self.design.design.nrows() {
6398 return Err(SmoothError::dimension_mismatch(format!(
6399 "incremental realizer row mismatch for term {}: rebuilt_rows={}, design_rows={}",
6400 term_idx,
6401 design_local.nrows(),
6402 self.design.design.nrows()
6403 ))
6404 .into());
6405 }
6406
6407 let active_penaltyinfo = penaltyinfo_local
6408 .iter()
6409 .filter(|info| info.active)
6410 .cloned()
6411 .collect::<Vec<_>>();
6412 let smooth_penalty_range = self
6413 .smooth_penalty_ranges
6414 .get(term_idx)
6415 .ok_or_else(|| {
6416 format!("incremental realizer missing smooth penalty range for term {term_idx}")
6417 })?
6418 .clone();
6419 let full_penalty_range = self
6420 .full_penalty_ranges
6421 .get(term_idx)
6422 .ok_or_else(|| {
6423 format!("incremental realizer missing full penalty range for term {term_idx}")
6424 })?
6425 .clone();
6426 if active_penaltyinfo.len() != smooth_penalty_range.len()
6427 || penalties_local.len() != smooth_penalty_range.len()
6428 || nullspace_dims.len() != smooth_penalty_range.len()
6429 {
6430 return Err(SmoothError::dimension_mismatch(format!(
6431 "incremental realizer topology changed for term '{}': penalties={}, infos={}, nullspaces={}, cached_penalties={}",
6432 name,
6433 penalties_local.len(),
6434 active_penaltyinfo.len(),
6435 nullspace_dims.len(),
6436 smooth_penalty_range.len()
6437 ))
6438 .into());
6439 }
6440
6441 self.design.smooth.term_designs[term_idx] = design_local;
6442
6443 for (offset, penalty_local) in penalties_local.iter().enumerate() {
6444 let smooth_penalty_idx = smooth_penalty_range.start + offset;
6445 let full_penalty_idx = full_penalty_range.start + offset;
6446 let nullspace_dim = nullspace_dims[offset];
6447 let penalty_info = active_penaltyinfo[offset].clone();
6448
6449 if penalty_local.nrows() != coeff_range.len()
6450 || penalty_local.ncols() != coeff_range.len()
6451 {
6452 return Err(SmoothError::dimension_mismatch(format!(
6453 "incremental realizer penalty shape mismatch for term '{}' penalty {}: \
6454 penalty is {}x{} but coeff_range has {} columns",
6455 name,
6456 offset,
6457 penalty_local.nrows(),
6458 penalty_local.ncols(),
6459 coeff_range.len()
6460 ))
6461 .into());
6462 }
6463
6464 let smooth_penalty = self
6465 .design
6466 .smooth
6467 .penalties
6468 .get_mut(smooth_penalty_idx)
6469 .ok_or_else(|| {
6470 format!(
6471 "incremental realizer smooth penalty {} out of range for term {}",
6472 smooth_penalty_idx, term_idx
6473 )
6474 })?;
6475 smooth_penalty.local.assign(penalty_local);
6478
6479 let full_bp = self
6480 .design
6481 .penalties
6482 .get_mut(full_penalty_idx)
6483 .ok_or_else(|| {
6484 format!(
6485 "incremental realizer full penalty {} out of range for term {}",
6486 full_penalty_idx, term_idx
6487 )
6488 })?;
6489 full_bp.local.assign(penalty_local);
6492
6493 self.design.smooth.nullspace_dims[smooth_penalty_idx] = nullspace_dim;
6494 self.design.nullspace_dims[full_penalty_idx] = nullspace_dim;
6495
6496 self.design.smooth.penaltyinfo[smooth_penalty_idx].global_index = smooth_penalty_idx;
6497 self.design.smooth.penaltyinfo[smooth_penalty_idx].termname = Some(name.clone());
6498 self.design.smooth.penaltyinfo[smooth_penalty_idx].penalty = penalty_info.clone();
6499
6500 self.design.penaltyinfo[full_penalty_idx].global_index = full_penalty_idx;
6501 self.design.penaltyinfo[full_penalty_idx].termname = Some(name.clone());
6502 self.design.penaltyinfo[full_penalty_idx].penalty = penalty_info;
6503 }
6504
6505 let target_term = self.design.smooth.terms.get_mut(term_idx).ok_or_else(|| {
6506 format!("incremental realizer smooth term {term_idx} disappeared during replacement")
6507 })?;
6508 target_term.penalties_local = penalties_local;
6509 target_term.nullspace_dims = nullspace_dims;
6510 target_term.penaltyinfo_local = penaltyinfo_local;
6511 target_term.metadata = metadata;
6512 target_term.lower_bounds_local = lower_bounds_local;
6513 target_term.linear_constraints_local = linear_constraints_local;
6514 target_term.joint_null_rotation = joint_null_rotation;
6515 self.dropped_penaltyinfo_by_term[term_idx] = dropped_penaltyinfo;
6516 log::info!(
6517 "[STAGE] smooth basis rebuild (term {}, '{}', cols={}): {:.3}s",
6518 term_idx,
6519 target_term.name,
6520 coeff_range.len(),
6521 t_replace.elapsed().as_secs_f64(),
6522 );
6523 Ok(())
6524 }
6525
6526 fn refresh_full_design_operator(&mut self) -> Result<(), String> {
6527 let mut blocks = Vec::<DesignBlock>::with_capacity(
6528 self.fixed_blocks.len() + self.design.smooth.term_designs.len(),
6529 );
6530 blocks.extend(self.fixed_blocks.iter().cloned());
6531 for term_design in &self.design.smooth.term_designs {
6532 blocks.push(DesignBlock::from(term_design));
6533 }
6534 self.design.design = assemble_term_collection_design_matrix(blocks)
6535 .map_err(|e| format!("failed to refresh term-collection design: {e}"))?;
6536 Ok(())
6537 }
6538}
6539
6540fn build_term_collection_fixed_blocks(
6541 data: ArrayView2<'_, f64>,
6542 spec: &TermCollectionSpec,
6543) -> Result<Vec<DesignBlock>, BasisError> {
6544 let mut blocks = Vec::<DesignBlock>::new();
6545 if !term_collection_has_one_sided_anchored_bspline(spec) {
6546 blocks.push(DesignBlock::Intercept(data.nrows()));
6547 }
6548
6549 if !spec.linear_terms.is_empty() {
6550 let mut linear_block = Array2::<f64>::zeros((data.nrows(), spec.linear_terms.len()));
6551 for (j, linear) in spec.linear_terms.iter().enumerate() {
6552 let column = linear
6556 .realized_design_column(data)
6557 .map_err(BasisError::InvalidInput)?;
6558 linear_block.column_mut(j).assign(&column);
6559 }
6560 blocks.push(DesignBlock::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
6561 linear_block,
6562 )));
6563 }
6564
6565 for term in &spec.random_effect_terms {
6566 let block = build_random_effect_block(data, term)?;
6567 let re_op = RandomEffectOperator::new(block.group_ids, block.num_groups);
6568 blocks.push(DesignBlock::RandomEffect(Arc::new(re_op)));
6569 }
6570
6571 Ok(blocks)
6572}
6573
/// Result bundle returned by the spatial length-scale optimization drivers,
/// generic over the caller-defined fit payload `FitOut`.
pub struct SpatialLengthScaleOptimizationResult<FitOut> {
    /// Term-collection specs the returned designs were built from.
    /// NOTE(review): presumably one entry per block, parallel to `designs` — confirm.
    pub resolved_specs: Vec<TermCollectionSpec>,
    /// Term-collection designs matching `resolved_specs`.
    pub designs: Vec<TermCollectionDesign>,
    /// Value produced by the caller-supplied fit callback.
    pub fit: FitOut,
    /// Optional timing information; the exact-joint driver's fast path leaves it `None`.
    pub timing: Option<SpatialLengthScaleOptimizationTiming>,
}
6584
/// Starting point and box bounds for the joint hyperparameter vector.
///
/// The vector has three consecutive segments — `rho`, log-kappa, auxiliary —
/// which `theta0()`, `lower()` and `upper()` concatenate in that order.
#[derive(Debug, Clone)]
pub struct ExactJointHyperSetup {
    // `rho` segment: seed (sanitized into its bounds by `new`) and bounds.
    rho0: Array1<f64>,
    rho_lower: Array1<f64>,
    rho_upper: Array1<f64>,
    // Log-kappa segment: seed and bounds, each carrying its per-term dimensions.
    log_kappa0: SpatialLogKappaCoords,
    log_kappa_lower: SpatialLogKappaCoords,
    log_kappa_upper: SpatialLogKappaCoords,
    // Auxiliary segment: empty after `new`, filled by `with_auxiliary`.
    auxiliary0: Array1<f64>,
    auxiliary_lower: Array1<f64>,
    auxiliary_upper: Array1<f64>,
}
6598
6599impl ExactJointHyperSetup {
6600 fn sanitize_rho_seed(
6601 rho0: Array1<f64>,
6602 rho_lower: &Array1<f64>,
6603 rho_upper: &Array1<f64>,
6604 ) -> Array1<f64> {
6605 Array1::from_iter(rho0.iter().enumerate().map(|(idx, &value)| {
6606 let lo = rho_lower[idx];
6607 let hi = rho_upper[idx];
6608 let fallback = 0.0_f64.clamp(lo, hi);
6609 if value.is_finite() {
6610 value.clamp(lo, hi)
6611 } else {
6612 fallback
6613 }
6614 }))
6615 }
6616
6617 pub(crate) fn new(
6618 rho0: Array1<f64>,
6619 rho_lower: Array1<f64>,
6620 rho_upper: Array1<f64>,
6621 log_kappa0: SpatialLogKappaCoords,
6622 log_kappa_lower: SpatialLogKappaCoords,
6623 log_kappa_upper: SpatialLogKappaCoords,
6624 ) -> Self {
6625 let rho0 = Self::sanitize_rho_seed(rho0, &rho_lower, &rho_upper);
6626 Self {
6627 rho0,
6628 rho_lower,
6629 rho_upper,
6630 log_kappa0,
6631 log_kappa_lower,
6632 log_kappa_upper,
6633 auxiliary0: Array1::zeros(0),
6634 auxiliary_lower: Array1::zeros(0),
6635 auxiliary_upper: Array1::zeros(0),
6636 }
6637 }
6638
6639 pub(crate) fn with_auxiliary(
6640 mut self,
6641 auxiliary0: Array1<f64>,
6642 auxiliary_lower: Array1<f64>,
6643 auxiliary_upper: Array1<f64>,
6644 ) -> Self {
6645 assert_eq!(
6646 auxiliary0.len(),
6647 auxiliary_lower.len(),
6648 "auxiliary lower bound length mismatch"
6649 );
6650 assert_eq!(
6651 auxiliary0.len(),
6652 auxiliary_upper.len(),
6653 "auxiliary upper bound length mismatch"
6654 );
6655 self.auxiliary0 = Self::sanitize_rho_seed(auxiliary0, &auxiliary_lower, &auxiliary_upper);
6656 self.auxiliary_lower = auxiliary_lower;
6657 self.auxiliary_upper = auxiliary_upper;
6658 self
6659 }
6660
6661 pub(crate) fn rho_dim(&self) -> usize {
6662 self.rho0.len()
6663 }
6664
6665 pub(crate) fn log_kappa_dim(&self) -> usize {
6666 self.log_kappa0.len()
6667 }
6668
6669 pub(crate) fn auxiliary_dim(&self) -> usize {
6670 self.auxiliary0.len()
6671 }
6672
6673 pub(crate) fn theta0(&self) -> Array1<f64> {
6674 let mut out =
6675 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6676 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho0);
6677 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6678 .assign(self.log_kappa0.as_array());
6679 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6680 .assign(&self.auxiliary0);
6681 out
6682 }
6683
6684 pub(crate) fn lower(&self) -> Array1<f64> {
6685 let mut out =
6686 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6687 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_lower);
6688 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6689 .assign(self.log_kappa_lower.as_array());
6690 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6691 .assign(&self.auxiliary_lower);
6692 out
6693 }
6694
6695 pub(crate) fn upper(&self) -> Array1<f64> {
6696 let mut out =
6697 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6698 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_upper);
6699 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6700 .assign(self.log_kappa_upper.as_array());
6701 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6702 .assign(&self.auxiliary_upper);
6703 out
6704 }
6705
6706 pub(crate) fn log_kappa_dims_per_term(&self) -> Vec<usize> {
6708 self.log_kappa0.dims_per_term().to_vec()
6709 }
6710}
6711
/// Per-block incremental realizers plus a memo of the last theta they were
/// synchronized to and of the last objective values computed there.
struct ExactJointDesignCache<'d> {
    /// One incremental realizer per block, in block order.
    realizers: Vec<FrozenTermCollectionIncrementalRealizer<'d>>,
    /// For each block, the term indices handed to `apply_log_kappa`.
    block_term_indices: Vec<Vec<usize>>,
    /// Theta recorded by the last successful `ensure_theta`; `None` before the first one.
    current_theta: Option<Array1<f64>>,
    /// Cost recorded by `store_cost_only`; cleared whenever `ensure_theta` moves to a new theta.
    last_cost: Option<f64>,
    /// Memoized evaluation tuple; cleared whenever `ensure_theta` moves to a new theta.
    /// NOTE(review): presumably (cost, gradient, Hessian) as stored by the memo macro — confirm.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Number of leading theta entries that precede the log-kappa coordinates.
    rho_dim: usize,
    /// Per-term log-kappa dimensions across all blocks.
    all_dims: Vec<usize>,
    /// Sum of `all_dims`.
    log_kappa_dim: usize,
    /// Number of entries in each block's `block_term_indices` list.
    block_term_counts: Vec<usize>,
}
6732
6733impl<'d> ExactJointDesignCache<'d> {
6734 fn new(
6735 data: ArrayView2<'d, f64>,
6736 blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)>,
6737 rho_dim: usize,
6738 all_dims: Vec<usize>,
6739 ) -> Result<Self, String> {
6740 let n_blocks = blocks.len();
6741 let mut realizers = Vec::with_capacity(n_blocks);
6742 let mut block_term_indices = Vec::with_capacity(n_blocks);
6743 let mut block_term_counts = Vec::with_capacity(n_blocks);
6744
6745 for (spec, design, terms) in blocks {
6746 block_term_counts.push(terms.len());
6747 block_term_indices.push(terms);
6748 realizers.push(FrozenTermCollectionIncrementalRealizer::new(
6749 data, spec, design,
6750 )?);
6751 }
6752
6753 Ok(Self {
6754 realizers,
6755 block_term_indices,
6756 current_theta: None,
6757 last_cost: None,
6758 last_eval: None,
6759 rho_dim,
6760 log_kappa_dim: all_dims.iter().sum(),
6761 all_dims,
6762 block_term_counts,
6763 })
6764 }
6765
6766 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
6767 if self
6768 .current_theta
6769 .as_ref()
6770 .is_some_and(|cached| theta_values_match(cached, theta))
6771 {
6772 return Ok(());
6773 }
6774
6775 let t_ensure = std::time::Instant::now();
6776 let kappa_theta_len = self.rho_dim + self.log_kappa_dim;
6777 if theta.len() < kappa_theta_len {
6778 return Err(SmoothError::dimension_mismatch(format!(
6779 "exact-joint theta length mismatch: got {}, expected at least {} (rho_dim={}, log_kappa_dim={})",
6780 theta.len(),
6781 kappa_theta_len,
6782 self.rho_dim,
6783 self.log_kappa_dim
6784 ))
6785 .into());
6786 }
6787 let theta_kappa = theta.slice(s![..kappa_theta_len]).to_owned();
6788 let full_log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
6789 &theta_kappa,
6790 self.rho_dim,
6791 self.all_dims.clone(),
6792 );
6793
6794 let n = self.realizers.len();
6798 let mut remaining = full_log_kappa;
6799 for block_idx in 0..n {
6800 let count = self.block_term_counts[block_idx];
6801 if block_idx < n - 1 {
6802 let (block_lk, rest) = remaining.split_at(count);
6803 self.realizers[block_idx]
6804 .apply_log_kappa(&block_lk, &self.block_term_indices[block_idx])?;
6805 remaining = rest;
6806 } else {
6807 self.realizers[block_idx]
6809 .apply_log_kappa(&remaining, &self.block_term_indices[block_idx])?;
6810 }
6811 }
6812
6813 log::info!(
6814 "[STAGE] ensure_theta (n-block, {} blocks, {} realizers): {:.3}s",
6815 n,
6816 self.realizers.len(),
6817 t_ensure.elapsed().as_secs_f64(),
6818 );
6819 self.current_theta = Some(theta.clone());
6820 self.last_cost = None;
6821 self.last_eval = None;
6822 Ok(())
6823 }
6824
6825 impl_exact_joint_theta_memo!();
6826
6827 fn store_cost_only(&mut self, theta: &Array1<f64>, cost: f64) {
6833 if self
6834 .current_theta
6835 .as_ref()
6836 .is_some_and(|cached| theta_values_match(cached, theta))
6837 {
6838 self.last_cost = Some(cost);
6839 }
6840 }
6841
6842 fn specs(&self) -> Vec<&TermCollectionSpec> {
6843 self.realizers.iter().map(|r| r.spec()).collect()
6844 }
6845
6846 fn designs(&self) -> Vec<&TermCollectionDesign> {
6847 self.realizers.iter().map(|r| r.design()).collect()
6848 }
6849
6850 fn design_revision(&self) -> u64 {
6860 self.realizers
6861 .iter()
6862 .fold(0u64, |acc, r| acc.wrapping_add(r.design_revision()))
6863 }
6864}
6865
6866pub(crate) fn seed_risk_profile_for_likelihood_family(
6867 family: &LikelihoodSpec,
6868) -> gam_problem::SeedRiskProfile {
6869 match &family.response {
6870 ResponseFamily::Gaussian => gam_problem::SeedRiskProfile::Gaussian,
6871 ResponseFamily::RoystonParmar => gam_problem::SeedRiskProfile::Survival,
6872 ResponseFamily::Binomial
6873 | ResponseFamily::Poisson
6874 | ResponseFamily::Tweedie { .. }
6875 | ResponseFamily::NegativeBinomial { .. }
6876 | ResponseFamily::Beta { .. }
6877 | ResponseFamily::Gamma => gam_problem::SeedRiskProfile::GeneralizedLinear,
6878 }
6879}
6880
// NOTE(review): judging by the name, a cap on the theta dimension for the
// exact-joint second-order path; the consumer is outside the visible section —
// confirm how the cap is applied before changing the value.
const EXACT_JOINT_SECOND_ORDER_THETA_CAP: usize = 8;
6889
6890fn exact_joint_seed_config(
6891 risk_profile: gam_problem::SeedRiskProfile,
6892 auxiliary_dim: usize,
6893) -> gam_problem::SeedConfig {
6894 let mut config = gam_problem::SeedConfig {
6895 risk_profile,
6896 num_auxiliary_trailing: auxiliary_dim,
6897 ..Default::default()
6898 };
6899 match risk_profile {
6900 gam_problem::SeedRiskProfile::Gaussian
6901 | gam_problem::SeedRiskProfile::GaussianLocationScale => {
6902 config.max_seeds = 4;
6903 config.seed_budget = 2;
6904 }
6905 gam_problem::SeedRiskProfile::GeneralizedLinear => {
6906 config.max_seeds = 1;
6911 config.seed_budget = 1;
6912 config.screen_max_inner_iterations = 8;
6913 }
6914 gam_problem::SeedRiskProfile::Survival => {
6915 config.max_seeds = 8;
6921 config.seed_budget = 4;
6922 config.screen_max_inner_iterations = 8;
6923 }
6924 }
6925 config
6926}
6927
#[cfg(test)]
mod exact_joint_seed_config_tests {
    use super::*;

    /// Generalized-linear and survival profiles both cap screening at 8 inner
    /// iterations; they differ in seed count and budget.
    #[test]
    fn exact_joint_marginal_slope_profiles_get_deeper_startup_validation() {
        let glm_profile = gam_problem::SeedRiskProfile::GeneralizedLinear;
        let glm_config = exact_joint_seed_config(glm_profile, 2);
        assert_eq!(glm_config.max_seeds, 1);
        assert_eq!(glm_config.seed_budget, 1);
        assert_eq!(glm_config.screen_max_inner_iterations, 8);
        assert_eq!(glm_config.num_auxiliary_trailing, 2);

        let survival_profile = gam_problem::SeedRiskProfile::Survival;
        let survival_config = exact_joint_seed_config(survival_profile, 3);
        assert_eq!(survival_config.max_seeds, 8);
        assert_eq!(survival_config.seed_budget, 4);
        assert_eq!(survival_config.screen_max_inner_iterations, 8);
        assert_eq!(survival_config.num_auxiliary_trailing, 3);
    }

    /// The Gaussian profile keeps 4 seeds / budget 2 and leaves the screening
    /// depth at the `SeedConfig` default.
    #[test]
    fn exact_joint_gaussian_keeps_tight_historical_multistart_budget() {
        let gaussian_profile = gam_problem::SeedRiskProfile::Gaussian;
        let gaussian_config = exact_joint_seed_config(gaussian_profile, 1);
        assert_eq!(gaussian_config.max_seeds, 4);
        assert_eq!(gaussian_config.seed_budget, 2);
        let default_screen_depth = gam_problem::SeedConfig::default().screen_max_inner_iterations;
        assert_eq!(
            gaussian_config.screen_max_inner_iterations,
            default_screen_depth
        );
        assert_eq!(gaussian_config.num_auxiliary_trailing, 1);
    }
}
6959
#[cfg(test)]
mod wood_reference_df_tests {
    use super::*;

    /// Diagonal case: the expected reference df is `2*tr(F) - tr(F^2)`.
    #[test]
    fn edf1_equals_two_trace_minus_trace_of_square() {
        // F = diag(0.9, 0.4): tr(F) = 1.3, tr(F^2) = 0.81 + 0.16 = 0.97,
        // so 2*1.3 - 0.97 = 1.63.
        let f = ndarray::array![[0.9_f64, 0.0], [0.0, 0.4]];
        let got = wood_reference_df(Some(&f), &(0..2)).unwrap();
        assert!(
            (got - 1.63).abs() < 1e-12,
            "edf1 should be 2*tr - tr(F^2) = 1.63, got {got}"
        );
        // edf here is tr(F) = 1.3; the reference df must not fall below it.
        let edf = 1.3;
        assert!(got >= edf - 1e-12, "edf1 {got} must be >= edf {edf}");
    }

    /// Large off-diagonals make `tr(F^2)` far exceed `2*tr(F)`; the result must
    /// still be at least `tr(F)`, finite, and positive.
    #[test]
    fn edf1_never_collapses_below_edf_when_offdiagonals_blow_up() {
        // tr(F) = 1.0 while tr(F^2) = 2*(0.25 + 1600) = 3200.5.
        let f = ndarray::array![[0.5_f64, 40.0], [40.0, 0.5]];
        let tr = 1.0_f64;
        let got = wood_reference_df(Some(&f), &(0..2)).unwrap();
        assert!(
            got >= tr - 1e-12,
            "edf1 must be floored at edf (=tr={tr}) even when tr(F^2) explodes, got {got}"
        );
        assert!(got.is_finite() && got > 0.0, "edf1 must stay finite/positive");
    }

    /// Inputs for which no reference df is produced.
    #[test]
    fn returns_none_on_nonpositive_or_missing_trace() {
        // No matrix supplied.
        assert!(wood_reference_df(None, &(0..2)).is_none());
        // All-zero matrix: the trace over the range is zero.
        let zero = ndarray::array![[0.0_f64, 0.0], [0.0, 0.0]];
        assert!(wood_reference_df(Some(&zero), &(0..2)).is_none());
        // Range 0..5 on a 2x2 matrix.
        // NOTE(review): presumably rejected because the range exceeds the matrix — confirm.
        let f = ndarray::array![[0.5_f64, 0.0], [0.0, 0.5]];
        assert!(wood_reference_df(Some(&f), &(0..5)).is_none());
    }
}
7017
7018pub(crate) fn exact_joint_multistart_outer_problem(
7019 theta0: &Array1<f64>,
7020 lower: &Array1<f64>,
7021 upper: &Array1<f64>,
7022 rho_dim: usize,
7023 auxiliary_dim: usize,
7024 n_params: usize,
7025 gradient: gam_problem::Derivative,
7026 hessian: gam_problem::DeclaredHessianForm,
7027 prefer_gradient_only: bool,
7028 disable_fixed_point: bool,
7029 risk_profile: gam_problem::SeedRiskProfile,
7030 tolerance: f64,
7031 max_iter: usize,
7032 bfgs_step_cap: Option<f64>,
7041 bfgs_step_cap_psi: Option<f64>,
7042 screening_cap: Option<Arc<AtomicUsize>>,
7043 profiled_objective_size: Option<(usize, usize)>,
7064 has_constant_curvature: bool,
7073) -> gam_solve::rho_optimizer::OuterProblem {
7074 let mut seed_heuristic = theta0.to_vec();
7075 for value in &mut seed_heuristic[..rho_dim] {
7076 *value = value.exp();
7077 }
7078 let rho_ceiling = if has_constant_curvature {
7083 gam_solve::estimate::RHO_BOUND
7084 } else {
7085 12.0
7086 };
7087 let mut problem = gam_solve::rho_optimizer::OuterProblem::new(n_params)
7088 .with_gradient(gradient)
7089 .with_hessian(hessian)
7090 .with_prefer_gradient_only(prefer_gradient_only)
7091 .with_disable_fixed_point(disable_fixed_point)
7092 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Automatic)
7102 .with_psi_dim(auxiliary_dim)
7103 .with_tolerance(tolerance)
7104 .with_max_iter(max_iter)
7105 .with_bounds(lower.clone(), upper.clone())
7106 .with_initial_rho(theta0.clone())
7107 .with_bfgs_step_cap(bfgs_step_cap)
7108 .with_bfgs_step_cap_psi(bfgs_step_cap_psi)
7109 .with_seed_config({
7110 let mut sc = exact_joint_seed_config(risk_profile, auxiliary_dim);
7111 if has_constant_curvature {
7112 sc.bounds = (sc.bounds.0, rho_ceiling);
7116 }
7135 sc
7136 })
7137 .with_rho_bound(rho_ceiling)
7138 .with_heuristic_lambdas(seed_heuristic);
7139 if let Some((n_obs, p_cols)) = profiled_objective_size {
7140 problem = problem
7148 .with_objective_scale(Some(n_obs as f64))
7149 .with_problem_size(n_obs, p_cols)
7150 .with_arc_initial_regularization(Some(0.25))
7151 .with_operator_initial_trust_radius(Some(4.0));
7152 }
7153 if let Some(screening_cap) = screening_cap {
7154 problem = problem
7155 .with_screening_cap(screening_cap)
7156 .with_screen_initial_rho(true);
7157 }
7158 problem
7159}
7160
/// Reports whether a kappa-phase failure message contains one of the known
/// marker substrings (case-sensitive match anywhere in the message).
fn kappa_phase_failure_is_fixed_kappa_recoverable(message: &str) -> bool {
    const RECOVERABLE_MARKERS: [&str; 3] = [
        "no candidate seeds passed outer startup validation",
        "joint hyper rho dimension mismatch",
        "objective returned a non-finite cost",
    ];
    RECOVERABLE_MARKERS
        .iter()
        .any(|&marker| message.contains(marker))
}
7176
7177pub fn optimize_spatial_length_scale_exact_joint<FitOut, FitFn, ExactFn, ExactEfsFn, SeedFn>(
7178 data: ArrayView2<'_, f64>,
7179 block_specs: &[TermCollectionSpec],
7180 block_term_indices: &[Vec<usize>],
7181 kappa_options: &SpatialLengthScaleOptimizationOptions,
7182 joint_setup: &ExactJointHyperSetup,
7183 seed_risk_profile: gam_problem::SeedRiskProfile,
7184 analytic_joint_gradient_available: bool,
7185 analytic_joint_hessian_available: bool,
7186 disable_fixed_point: bool,
7187 screening_cap: Option<Arc<AtomicUsize>>,
7188 outer_derivative_policy: gam_model_api::families::custom_family::OuterDerivativePolicy,
7189 mut fit_fn: FitFn,
7190 mut exact_fn: ExactFn,
7191 mut exact_efs_fn: ExactEfsFn,
7192 mut seed_inner_beta_fn: SeedFn,
7193) -> Result<SpatialLengthScaleOptimizationResult<FitOut>, String>
7194where
7195 FitOut: Clone,
7196 FitFn: FnMut(
7197 &Array1<f64>,
7198 &[TermCollectionSpec],
7199 &[TermCollectionDesign],
7200 ) -> Result<FitOut, String>,
7201 ExactFn: FnMut(
7202 &Array1<f64>,
7203 &[TermCollectionSpec],
7204 &[TermCollectionDesign],
7205 gam_solve::estimate::reml::reml_outer_engine::EvalMode,
7206 &gam_problem::outer_subsample::RowSet,
7207 ) -> Result<
7208 (
7209 f64,
7210 Array1<f64>,
7211 gam_problem::HessianResult,
7212 ),
7213 String,
7214 >,
7215 ExactEfsFn: FnMut(
7216 &Array1<f64>,
7217 &[TermCollectionSpec],
7218 &[TermCollectionDesign],
7219 ) -> Result<gam_problem::EfsEval, String>,
7220 SeedFn:
7221 FnMut(&Array1<f64>) -> Result<gam_solve::rho_optimizer::SeedOutcome, EstimationError>,
7222{
7223 let n_blocks = block_specs.len();
7224 if block_term_indices.len() != n_blocks {
7225 return Err(SmoothError::dimension_mismatch(format!(
7226 "block_specs ({}) and block_term_indices ({}) length mismatch",
7227 n_blocks,
7228 block_term_indices.len()
7229 ))
7230 .into());
7231 }
7232
7233 let log_kappa_dim = joint_setup.log_kappa_dim();
7234
7235 log::warn!(
7236 "[OUTER-FD-AUDIT/spatial-exact-joint] driver entry: aux_dim={} log_kappa_dim={} kappa_enabled={} rho_dim={} theta0_len={}",
7237 joint_setup.auxiliary_dim(),
7238 log_kappa_dim,
7239 kappa_options.enabled,
7240 joint_setup.rho_dim(),
7241 joint_setup.theta0().len()
7242 );
7243
7244 if joint_setup.auxiliary_dim() == 0 && (!kappa_options.enabled || log_kappa_dim == 0) {
7248 log::warn!(
7249 "[OUTER-FD-AUDIT/spatial-exact-joint] taking FAST path (no outer theta optimization in this driver)"
7250 );
7251 let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
7252 data, block_specs,
7253 )
7254 .map_err(|e| {
7255 format!("failed to build and freeze joint block designs during exact joint kappa optimization: {e}")
7256 })?;
7257 let theta0 = joint_setup.theta0();
7258
7259 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7261 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7262 let fit = fit_fn(&theta0, &spec_refs, &design_refs)?;
7263 return Ok(SpatialLengthScaleOptimizationResult {
7264 resolved_specs,
7265 designs,
7266 fit,
7267 timing: None,
7268 });
7269 }
7270
7271 let theta0 = joint_setup.theta0();
7275 let lower = joint_setup.lower();
7276 let upper = joint_setup.upper();
7277 if theta0.len() < log_kappa_dim || lower.len() != theta0.len() || upper.len() != theta0.len() {
7278 return Err(SmoothError::dimension_mismatch(format!(
7279 "invalid exact joint theta setup: theta0={}, lower={}, upper={}, required_log_kappa_dim={}",
7280 theta0.len(),
7281 lower.len(),
7282 upper.len(),
7283 log_kappa_dim
7284 ))
7285 .into());
7286 }
7287 let rho_dim = joint_setup.rho_dim();
7288 let all_dims = joint_setup.log_kappa_dims_per_term();
7289
7290 let (boot_designs, best_specs) = build_term_collection_designs_and_freeze_joint(
7292 data,
7293 block_specs,
7294 )
7295 .map_err(|e| {
7296 format!(
7297 "failed to build and freeze joint block designs during exact joint kappa bootstrap: {e}"
7298 )
7299 })?;
7300 let policy_hessian_form = outer_derivative_policy.declared_hessian_form();
7310 let analytic_outer_hessian_available = analytic_joint_hessian_available
7311 && matches!(
7312 policy_hessian_form,
7313 gam_problem::DeclaredHessianForm::Either
7314 | gam_problem::DeclaredHessianForm::Dense
7315 | gam_problem::DeclaredHessianForm::Operator { .. }
7316 );
7317 let prefer_gradient_only = !analytic_outer_hessian_available;
7318
7319 let theta_dim = theta0.len();
7320 let psi_dim = theta_dim - rho_dim;
7321
7322 let cache_blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)> = best_specs
7324 .iter()
7325 .zip(boot_designs.iter())
7326 .zip(block_term_indices.iter())
7327 .map(|((spec, design), terms)| (spec.clone(), design.clone(), terms.clone()))
7328 .collect();
7329
7330 struct NBlockExactJointState<'d> {
7331 cache: ExactJointDesignCache<'d>,
7332 }
7333
7334 let mut state = NBlockExactJointState {
7335 cache: ExactJointDesignCache::new(data, cache_blocks, rho_dim, all_dims.clone())?,
7336 };
7337
7338 const KAPPA_PILOT_K: usize = 5_000;
7363 const KAPPA_POLISH_K: usize = 25_000;
7364 const KAPPA_POLISH_TRIGGER_N: usize = 100_000;
7365
7366 let n_total = data.nrows();
7367 let use_staged_kappa = outer_derivative_policy.should_use_staged_kappa(n_total);
7368 if use_staged_kappa {
7369 log::info!(
7370 "[KAPPA-STAGED] auto-engaging pilot+polish schedule: n={} pilot_k={} polish_k={}",
7371 n_total,
7372 KAPPA_PILOT_K,
7373 KAPPA_POLISH_K,
7374 );
7375 }
7376
7377 fn build_uniform_pilot_subsample(
7394 n_total: usize,
7395 k_target: usize,
7396 seed: u64,
7397 ) -> gam_problem::outer_subsample::OuterScoreSubsample {
7398 use gam_problem::outer_subsample::OuterScoreSubsample;
7399 let k = k_target.min(n_total);
7400 if k == 0 || n_total == 0 {
7401 return OuterScoreSubsample::from_uniform_inclusion_mask(Vec::new(), n_total, seed);
7402 }
7403 let mut mask: Vec<usize> = Vec::with_capacity(k);
7407 let mut state = seed.wrapping_add(0x9E3779B97F4A7C15);
7409 let splitmix = |s: &mut u64| -> u64 { gam_linalg::utils::splitmix64(s) };
7410 let mut taken = std::collections::HashSet::with_capacity(k);
7411 for j in (n_total - k)..n_total {
7412 let r = (splitmix(&mut state) % (j as u64 + 1)) as usize;
7413 if !taken.insert(r) {
7414 taken.insert(j);
7415 mask.push(j);
7416 } else {
7417 mask.push(r);
7418 }
7419 }
7420 mask.sort_unstable();
7421 mask.dedup();
7422 OuterScoreSubsample::from_uniform_inclusion_mask(mask, n_total, seed)
7423 }
7424
7425 let current_row_set: std::cell::RefCell<gam_problem::outer_subsample::RowSet> = if use_staged_kappa {
7426 let pilot = build_uniform_pilot_subsample(n_total, KAPPA_PILOT_K, n_total as u64);
7427 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::Subsample {
7428 rows: std::sync::Arc::clone(&pilot.rows),
7429 n_full: n_total,
7430 })
7431 } else {
7432 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::All)
7433 };
7434
7435 let exact_fn_cell = std::cell::RefCell::new(&mut exact_fn);
7436 let exact_efs_fn_cell = std::cell::RefCell::new(&mut exact_efs_fn);
7437
7438 use std::cell::Cell;
7453 let kphase_cost_calls: Cell<usize> = Cell::new(0);
7454 let kphase_cost_total_s: Cell<f64> = Cell::new(0.0);
7455 let kphase_eval_calls: Cell<usize> = Cell::new(0);
7456 let kphase_eval_total_s: Cell<f64> = Cell::new(0.0);
7457 let kphase_efs_calls: Cell<usize> = Cell::new(0);
7458 let kphase_efs_total_s: Cell<f64> = Cell::new(0.0);
7459 let kphase_optim_start = std::time::Instant::now();
7460 let kphase_log_kappa_dim = log_kappa_dim;
7461 let kphase_log_norms = |theta: &Array1<f64>| -> (f64, f64) {
7462 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
7463 let log_kappa_norm = if kphase_log_kappa_dim > 0 && theta.len() >= kphase_log_kappa_dim {
7464 let start = theta.len() - kphase_log_kappa_dim;
7465 theta.iter().skip(start).map(|v| v * v).sum::<f64>().sqrt()
7466 } else {
7467 0.0
7468 };
7469 (theta_norm, log_kappa_norm)
7470 };
7471
7472 use gam_solve::rho_optimizer::OuterEvalOrder;
7473 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7474
7475 let joint_p_cols: usize = boot_designs
7479 .iter()
7480 .map(|d| d.design.ncols())
7481 .sum::<usize>()
7482 .max(1);
7483
7484 let problem = exact_joint_multistart_outer_problem(
7485 &theta0,
7486 &lower,
7487 &upper,
7488 rho_dim,
7489 psi_dim,
7490 theta_dim,
7491 if analytic_joint_gradient_available {
7492 Derivative::Analytic
7493 } else {
7494 Derivative::Unavailable
7495 },
7496 if analytic_outer_hessian_available {
7497 DeclaredHessianForm::Either
7498 } else {
7499 DeclaredHessianForm::Unavailable
7500 },
7501 prefer_gradient_only,
7502 disable_fixed_point,
7503 seed_risk_profile,
7504 kappa_options.rel_tol.max(1e-6),
7505 kappa_options.max_outer_iter.max(1),
7506 Some(5.0),
7508 Some(kappa_options.log_step.clamp(0.25, 1.0)),
7510 screening_cap.clone(),
7511 Some((n_total, joint_p_cols)),
7514 block_specs
7517 .iter()
7518 .any(|s| !constant_curvature_term_indices(s).is_empty()),
7519 );
7520
7521 fn collect_specs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionSpec> {
7523 cache.specs().into_iter().cloned().collect()
7524 }
7525 fn collect_designs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionDesign> {
7526 cache.designs().into_iter().cloned().collect()
7527 }
7528
7529 let result = {
7530 let eval_outer = |ctx: &mut &mut NBlockExactJointState<'_>,
7531 theta: &Array1<f64>,
7532 order: OuterEvalOrder|
7533 -> Result<OuterEval, EstimationError> {
7534 if let Some((cost, grad, hess)) = ctx.cache.memoized_eval(theta) {
7535 let cached_satisfies_order = match order {
7536 OuterEvalOrder::Value => true,
7537 OuterEvalOrder::ValueAndGradient => true,
7538 OuterEvalOrder::ValueGradientHessian => hess.is_analytic(),
7539 };
7540 if cached_satisfies_order {
7541 if !cost.is_finite() {
7542 return Ok(OuterEval::infeasible(theta.len()));
7543 }
7544 if grad.iter().any(|v| !v.is_finite()) {
7557 return Ok(OuterEval::infeasible(theta.len()));
7558 }
7559 return Ok(OuterEval {
7560 cost,
7561 gradient: grad,
7562 hessian: hess,
7563 inner_beta_hint: None,
7564 });
7565 }
7566 }
7567 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7584 return Ok(OuterEval::infeasible(theta.len()));
7585 }
7586 if let Err(err) = ctx.cache.ensure_theta(theta) {
7587 log::warn!(
7588 "[OUTER] n-block exact-joint spatial: ensure_theta failed during gradient evaluation: {err}"
7589 );
7590 return Ok(OuterEval::infeasible(theta.len()));
7591 }
7592 let design_revision = Some(ctx.cache.design_revision());
7593 let specs = collect_specs(&ctx.cache);
7594 let designs = collect_designs(&ctx.cache);
7595 let clamped = outer_derivative_policy.order_for_evaluation(order);
7603 let need_hessian = matches!(clamped, OuterEvalOrder::ValueGradientHessian)
7604 && analytic_outer_hessian_available;
7605 let eval_mode = if need_hessian {
7606 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
7607 } else {
7608 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
7609 };
7610 let t0 = std::time::Instant::now();
7611 let result = {
7612 let row_set_borrow = current_row_set.borrow();
7613 (*exact_fn_cell.borrow_mut())(theta, &specs, &designs, eval_mode, &row_set_borrow)
7614 };
7615 let elapsed_s = t0.elapsed().as_secs_f64();
7616 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
7617 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
7618 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7619 log::info!(
7620 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7621 kphase_eval_calls.get(),
7622 order,
7623 design_revision,
7624 theta_norm,
7625 log_kappa_norm,
7626 elapsed_s,
7627 );
7628 match result {
7629 Ok((cost, grad, hess)) => {
7630 ctx.cache.store_eval((cost, grad.clone(), hess.clone()));
7631 if !cost.is_finite() {
7632 return Ok(OuterEval::infeasible(theta.len()));
7633 }
7634 if grad.iter().any(|v| !v.is_finite()) {
7647 return Ok(OuterEval::infeasible(theta.len()));
7648 }
7649 Ok(OuterEval {
7650 cost,
7651 gradient: grad,
7652 hessian: hess,
7653 inner_beta_hint: None,
7654 })
7655 }
7656 Err(err) => {
7657 log::warn!(
7658 "[OUTER] n-block exact-joint spatial: exact evaluation failed: {err}"
7659 );
7660 Ok(OuterEval::infeasible(theta.len()))
7661 }
7662 }
7663 };
7664
7665 let obj = problem.build_objective_with_eval_order(
7666 &mut state,
7667 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7668 if let Some(cost) = ctx.cache.memoized_cost(theta) {
7669 return Ok(cost);
7670 }
7671 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7679 return Ok(f64::INFINITY);
7680 }
7681 if let Err(err) = ctx.cache.ensure_theta(theta) {
7682 log::warn!(
7683 "[OUTER] n-block exact-joint spatial: ensure_theta failed during cost evaluation: {err}"
7684 );
7685 return Ok(f64::INFINITY);
7686 }
7687 let design_revision = Some(ctx.cache.design_revision());
7688 let specs = collect_specs(&ctx.cache);
7689 let designs = collect_designs(&ctx.cache);
7690 let t0 = std::time::Instant::now();
7697 let result = {
7698 let row_set_borrow = current_row_set.borrow();
7699 (*exact_fn_cell.borrow_mut())(
7700 theta,
7701 &specs,
7702 &designs,
7703 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
7704 &row_set_borrow,
7705 )
7706 };
7707 let elapsed_s = t0.elapsed().as_secs_f64();
7708 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
7709 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
7710 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7711 log::info!(
7712 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7713 kphase_cost_calls.get(),
7714 design_revision,
7715 theta_norm,
7716 log_kappa_norm,
7717 elapsed_s,
7718 );
7719 match result {
7720 Ok((cost, _grad, _hess)) => {
7721 ctx.cache.store_cost_only(theta, cost);
7727 Ok(cost)
7728 }
7729 Err(err) => {
7730 log::warn!(
7731 "[OUTER] n-block exact-joint spatial: exact cost evaluation failed: {err}"
7732 );
7733 Ok(f64::INFINITY)
7734 }
7735 }
7736 },
7737 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7738 eval_outer(
7739 ctx,
7740 theta,
7741 if analytic_outer_hessian_available {
7742 OuterEvalOrder::ValueGradientHessian
7743 } else {
7744 OuterEvalOrder::ValueAndGradient
7745 },
7746 )
7747 },
7748 |ctx: &mut &mut NBlockExactJointState<'_>,
7749 theta: &Array1<f64>,
7750 order: OuterEvalOrder| { eval_outer(ctx, theta, order) },
7751 None::<fn(&mut &mut NBlockExactJointState<'_>)>,
7752 Some(
7753 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7754 ctx.cache
7755 .ensure_theta(theta)
7756 .map_err(EstimationError::InvalidInput)?;
7757 let design_revision = Some(ctx.cache.design_revision());
7758 let specs = collect_specs(&ctx.cache);
7759 let designs = collect_designs(&ctx.cache);
7760 let t0 = std::time::Instant::now();
7761 let eval_result = (*exact_efs_fn_cell.borrow_mut())(
7762 theta,
7763 &specs,
7764 &designs,
7765 );
7766 let elapsed_s = t0.elapsed().as_secs_f64();
7767 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
7768 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
7769 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7770 log::info!(
7771 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7772 kphase_efs_calls.get(),
7773 design_revision,
7774 theta_norm,
7775 log_kappa_norm,
7776 elapsed_s,
7777 );
7778 let eval = eval_result.map_err(EstimationError::RemlOptimizationFailed)?;
7779 Ok(eval)
7780 },
7781 ),
7782 );
7783 let mut obj = obj.with_seed_inner_state(
7784 move |_ctx: &mut &mut NBlockExactJointState<'_>, beta: &Array1<f64>| {
7785 (seed_inner_beta_fn)(beta)
7786 },
7787 );
7788
7789 match problem.run(&mut obj, "n-block exact-joint spatial") {
7790 Ok(result) => result,
7791 Err(e) => {
7792 let message = e.to_string();
7793 if kappa_phase_failure_is_fixed_kappa_recoverable(&message) {
7813 drop(obj);
7814 log::warn!(
7815 "[KAPPA-PHASE] length-scale optimization could not validate any seed \
7816 ({message}); falling back to a FIXED bootstrap κ (skipping κ \
7817 optimization) and fitting there — a real model at the initial \
7818 length-scale rather than raising (gam#787/#860)."
7819 );
7820 let (designs, resolved_specs) =
7821 build_term_collection_designs_and_freeze_joint(data, block_specs).map_err(
7822 |build_err| {
7823 format!(
7824 "fixed-κ fallback failed to build and freeze joint block \
7825 designs after κ optimization could not validate a seed \
7826 ({message}): {build_err}"
7827 )
7828 },
7829 )?;
7830 let fixed_theta0 = joint_setup.theta0();
7831 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7832 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7833 let fit = fit_fn(&fixed_theta0, &spec_refs, &design_refs)?;
7834 return Ok(SpatialLengthScaleOptimizationResult {
7835 resolved_specs,
7836 designs,
7837 fit,
7838 timing: None,
7839 });
7840 }
7841 return Err(message);
7842 }
7843 }
7844 }; let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
7854 log::info!(
7855 "[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} optim_total_s={:.4}",
7856 kphase_log_kappa_dim,
7857 kphase_cost_calls.get(),
7858 kphase_cost_total_s.get(),
7859 kphase_eval_calls.get(),
7860 kphase_eval_total_s.get(),
7861 kphase_efs_calls.get(),
7862 kphase_efs_total_s.get(),
7863 kphase_total_s,
7864 );
7865 let timing = SpatialLengthScaleOptimizationTiming {
7866 log_kappa_dim: kphase_log_kappa_dim,
7867 cost_calls: kphase_cost_calls.get(),
7868 cost_total_s: kphase_cost_total_s.get(),
7869 eval_calls: kphase_eval_calls.get(),
7870 eval_total_s: kphase_eval_total_s.get(),
7871 efs_calls: kphase_efs_calls.get(),
7872 efs_total_s: kphase_efs_total_s.get(),
7873 slow_path_resets: 0,
7874 design_revision_delta: 0,
7875 nfree_miss_shape: 0,
7876 nfree_miss_value: 0,
7877 nfree_miss_gradient: 0,
7878 nfree_miss_penalty: 0,
7879 nfree_miss_revision: 0,
7880 nfree_miss_second_order: 0,
7881 nfree_miss_other: 0,
7882 optim_total_s: kphase_total_s,
7883 };
7884
7885 let theta_star = result.rho;
7886
7887 if use_staged_kappa && n_total >= KAPPA_POLISH_TRIGGER_N {
7904 let polish = build_uniform_pilot_subsample(
7905 n_total,
7906 KAPPA_POLISH_K,
7907 (n_total as u64).wrapping_add(0xA5A5A5A5),
7908 );
7909 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::Subsample {
7910 rows: std::sync::Arc::clone(&polish.rows),
7911 n_full: n_total,
7912 };
7913 log::info!(
7914 "[KAPPA-STAGED] rotating to polish subsample: k={} at theta_star",
7915 polish.rows.len(),
7916 );
7917 state.cache.ensure_theta(&theta_star)?;
7921 let (polish_cost, polish_grad, _) = {
7922 let specs = collect_specs(&state.cache);
7923 let designs = collect_designs(&state.cache);
7924 let row_set_borrow = current_row_set.borrow();
7925 exact_fn(
7926 &theta_star,
7927 &specs,
7928 &designs,
7929 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
7930 &row_set_borrow,
7931 )?
7932 };
7933 if !polish_cost.is_finite() || polish_grad.iter().any(|value| !value.is_finite()) {
7934 return Err(
7935 "polish subsample exact-joint evaluation produced non-finite objective pieces"
7936 .to_string(),
7937 );
7938 }
7939 }
7940 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::All;
7941 if use_staged_kappa {
7942 log::info!(
7943 "[KAPPA-STAGED] rotating to full data for final coefficient fit (n={})",
7944 n_total,
7945 );
7946 }
7947
7948 state.cache.ensure_theta(&theta_star)?;
7949
7950 let resolved_specs: Vec<TermCollectionSpec> = collect_specs(&state.cache);
7951 let designs: Vec<TermCollectionDesign> = collect_designs(&state.cache);
7952
7953 let fit = fit_fn(&theta_star, &resolved_specs, &designs)?;
7954
7955 for spec in &resolved_specs {
7956 log_spatial_aniso_scales(spec);
7957 }
7958
7959 Ok(SpatialLengthScaleOptimizationResult {
7960 resolved_specs,
7961 designs,
7962 fit,
7963 timing: Some(timing),
7964 })
7965}
7966
7967fn try_exact_joint_latent_coord_optimization(
7968 data: ArrayView2<'_, f64>,
7969 y: ArrayView1<'_, f64>,
7970 weights: ArrayView1<'_, f64>,
7971 offset: ArrayView1<'_, f64>,
7972 resolvedspec: &TermCollectionSpec,
7973 best: &FittedTermCollection,
7974 family: LikelihoodSpec,
7975 options: &FitOptions,
7976 latent: &StandardLatentCoordConfig,
7977) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7978 use gam_solve::rho_optimizer::OuterEvalOrder;
7979 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7980
7981 let rho_dim = best.fit.lambdas.len();
7982 let latent_flat_dim = latent.values.len();
7983 if latent_flat_dim == 0 {
7984 crate::bail_invalid_estim!(
7985 "latent-coordinate optimization requires a non-empty latent block"
7986 );
7987 }
7988 let direct_hypers =
7989 latent_coord_initial_direct_hypers(latent.values.id_mode(), latent.values.latent_dim())?;
7990 let analytic_rho_count = latent
7991 .analytic_penalties
7992 .as_ref()
7993 .map_or(0, |registry| registry.total_rho_count());
7994 let latent_coord_ext_dim = latent_flat_dim + analytic_rho_count + direct_hypers.len();
7995
7996 let mut theta0 = Array1::<f64>::zeros(rho_dim + latent_coord_ext_dim);
7997 theta0
7998 .slice_mut(s![..rho_dim])
7999 .assign(&best.fit.lambdas.mapv(f64::ln));
8000 theta0
8001 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
8002 .assign(latent.values.as_flat());
8003 if !direct_hypers.is_empty() {
8004 let direct_start = rho_dim + latent_flat_dim + analytic_rho_count;
8005 theta0
8006 .slice_mut(s![direct_start..direct_start + direct_hypers.len()])
8007 .assign(&direct_hypers);
8008 }
8009
8010 let mut lower = Array1::<f64>::from_elem(theta0.len(), -12.0);
8011 let mut upper = Array1::<f64>::from_elem(theta0.len(), 12.0);
8012 let latent_bound = latent
8013 .values
8014 .as_flat()
8015 .iter()
8016 .fold(1.0_f64, |acc, &v| acc.max(v.abs()))
8017 + 10.0;
8018 for axis in rho_dim..rho_dim + latent_flat_dim {
8019 lower[axis] = -latent_bound;
8020 upper[axis] = latent_bound;
8021 }
8022
8023 struct LatentJointContext<'d> {
8024 rho_dim: usize,
8025 cache: SingleBlockLatentCoordDesignCache,
8026 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
8027 }
8028
8029 impl<'d> LatentJointContext<'d> {
8030 fn eval_full(
8031 &mut self,
8032 theta: &Array1<f64>,
8033 order: OuterEvalOrder,
8034 ) -> Result<
8035 (
8036 f64,
8037 Array1<f64>,
8038 gam_problem::HessianResult,
8039 ),
8040 EstimationError,
8041 > {
8042 if let Some(eval) = self.cache.memoized_eval(theta) {
8043 return Ok(eval);
8044 }
8045 self.cache
8046 .ensure_theta(theta)
8047 .map_err(EstimationError::InvalidInput)?;
8048 let hyper_dirs = self
8049 .cache
8050 .hyper_dirs()
8051 .map_err(EstimationError::InvalidInput)?;
8052 let design_revision = Some(self.cache.design_revision());
8053 let registry_for_key = self.cache.analytic_penalties();
8054 self.evaluator
8055 .set_analytic_penalty_registry(registry_for_key.as_deref());
8056 let mut eval = evaluate_joint_reml_outer_eval_at_theta(
8057 &mut self.evaluator,
8058 self.cache.design(),
8059 theta,
8060 self.rho_dim,
8061 hyper_dirs,
8062 None,
8063 order,
8064 design_revision,
8065 )?;
8066 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
8067 if let Some(registry) = registry_for_key {
8068 let mut registry = registry.as_ref().clone();
8069 registry.apply_weight_schedules(
8070 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
8071 );
8072 add_analytic_penalty_objective_to_eval(
8073 theta,
8074 self.rho_dim,
8075 latent.as_ref(),
8076 ®istry,
8077 &mut eval,
8078 )?;
8079 }
8080 add_latent_id_objective_to_eval(
8081 theta,
8082 self.rho_dim,
8083 self.cache.analytic_penalty_rho_count(),
8084 latent.as_ref(),
8085 &mut eval,
8086 )?;
8087 self.cache.store_eval(eval.clone());
8088 Ok(eval)
8089 }
8090
8091 fn eval_efs(
8092 &mut self,
8093 theta: &Array1<f64>,
8094 ) -> Result<gam_problem::EfsEval, EstimationError> {
8095 self.cache
8096 .ensure_theta(theta)
8097 .map_err(EstimationError::InvalidInput)?;
8098 let hyper_dirs = self
8099 .cache
8100 .hyper_dirs()
8101 .map_err(EstimationError::InvalidInput)?;
8102 let registry_for_key = self.cache.analytic_penalties();
8103 self.evaluator
8104 .set_analytic_penalty_registry(registry_for_key.as_deref());
8105 let mut efs = evaluate_joint_reml_efs_at_theta(
8106 &mut self.evaluator,
8107 self.cache.design(),
8108 theta,
8109 self.rho_dim,
8110 hyper_dirs,
8111 None,
8112 Some(self.cache.design_revision()),
8113 )?;
8114 if let Some(registry) = registry_for_key {
8115 let mut registry = registry.as_ref().clone();
8116 registry.apply_weight_schedules(
8117 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
8118 );
8119 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
8120 let contribution = analytic_penalty_objective_contribution(
8121 theta,
8122 self.rho_dim,
8123 latent.as_ref(),
8124 ®istry,
8125 )?;
8126 efs.cost += contribution.cost;
8127 if let (Some(psi_gradient), Some(psi_indices)) =
8128 (efs.psi_gradient.as_mut(), efs.psi_indices.as_ref())
8129 {
8130 if psi_gradient.len() != psi_indices.len() {
8131 crate::bail_invalid_estim!(
8132 "latent-coordinate analytic penalty EFS psi gradient length mismatch: gradient={}, indices={}",
8133 psi_gradient.len(),
8134 psi_indices.len()
8135 );
8136 }
8137 for (local_idx, &theta_idx) in psi_indices.iter().enumerate() {
8138 psi_gradient[local_idx] += contribution.gradient[theta_idx];
8139 }
8140 }
8141 }
8142 Ok(efs)
8143 }
8144
8145 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
8146 if let Some(cost) = self.cache.memoized_cost(theta) {
8147 return cost;
8148 }
8149 if self.cache.ensure_theta(theta).is_err() {
8150 return f64::INFINITY;
8151 }
8152 let design_revision = Some(self.cache.design_revision());
8153 let registry_for_key = self.cache.analytic_penalties();
8154 self.evaluator
8155 .set_analytic_penalty_registry(registry_for_key.as_deref());
8156 let result = {
8157 let design = self.cache.design();
8158 self.evaluator.evaluate_cost_only(
8159 &design.design,
8160 &design.penalties,
8161 &design.nullspace_dims,
8162 design.linear_constraints.clone(),
8163 theta,
8164 self.rho_dim,
8165 None,
8166 "latent-coordinate-joint cost-only",
8167 design_revision,
8168 )
8169 };
8170 match result {
8171 Ok(cost) => {
8172 let latent = match self.cache.latent() {
8173 Ok(latent) => latent,
8174 Err(_) => return f64::INFINITY,
8175 };
8176 let contribution = match latent_id_objective_contribution(
8177 theta,
8178 self.rho_dim,
8179 self.cache.analytic_penalty_rho_count(),
8180 latent.as_ref(),
8181 ) {
8182 Ok(contribution) => contribution,
8183 Err(_) => return f64::INFINITY,
8184 };
8185 let cost = cost + contribution.cost;
8186 let cost = if let Some(registry) = registry_for_key {
8187 let mut registry = registry.as_ref().clone();
8188 registry.apply_weight_schedules(
8189 gam_solve::estimate::reml::outer_eval::current_outer_iter()
8190 as usize,
8191 );
8192 match analytic_penalty_objective_contribution(
8193 theta,
8194 self.rho_dim,
8195 latent.as_ref(),
8196 ®istry,
8197 ) {
8198 Ok(contribution) => cost + contribution.cost,
8199 Err(_) => return f64::INFINITY,
8200 }
8201 } else {
8202 cost
8203 };
8204 self.cache.store_cost(cost);
8205 cost
8206 }
8207 Err(_) => f64::INFINITY,
8208 }
8209 }
8210 }
8211
8212 let mut ctx = LatentJointContext {
8213 rho_dim,
8214 cache: SingleBlockLatentCoordDesignCache::new(
8215 data.to_owned(),
8216 resolvedspec.clone(),
8217 best.design.clone(),
8218 latent,
8219 rho_dim,
8220 )
8221 .map_err(EstimationError::InvalidInput)?,
8222 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
8223 y,
8224 weights,
8225 &best.design.design,
8226 offset,
8227 &best.design.penalties,
8228 &external_opts_for_design(&family, &best.design, options),
8229 "latent-coordinate-joint",
8230 )?,
8231 };
8232 let registry_for_key = ctx.cache.analytic_penalties();
8233 ctx.evaluator
8234 .set_analytic_penalty_registry(registry_for_key.as_deref());
8235 ctx.evaluator
8236 .set_persistent_latent_values_fingerprint(latent.values.id_mode());
8237 if let Some(cached_t) = ctx
8238 .evaluator
8239 .load_persistent_latent_values(latent.values.n_obs(), latent.values.latent_dim())
8240 {
8241 let cached_t: Array2<f64> = cached_t;
8242 for (dst, src) in theta0
8243 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
8244 .iter_mut()
8245 .zip(cached_t.iter())
8246 {
8247 *dst = *src;
8248 }
8249 }
8250
8251 let problem = exact_joint_multistart_outer_problem(
8252 &theta0,
8253 &lower,
8254 &upper,
8255 rho_dim,
8256 latent_coord_ext_dim,
8257 theta0.len(),
8258 Derivative::Analytic,
8259 DeclaredHessianForm::Unavailable,
8260 false,
8261 false,
8262 seed_risk_profile_for_likelihood_family(&family),
8263 options.tol,
8264 options.max_iter.max(1),
8265 Some(5.0),
8266 Some(0.5),
8267 None,
8268 Some((data.nrows(), best.design.design.ncols().max(1))),
8271 !constant_curvature_term_indices(resolvedspec).is_empty(),
8274 );
8275
8276 let eval_outer = |ctx: &mut &mut LatentJointContext<'_>,
8277 theta: &Array1<f64>,
8278 order: OuterEvalOrder|
8279 -> Result<OuterEval, EstimationError> {
8280 let (cost, gradient, hessian) = ctx.eval_full(theta, order)?;
8281 Ok(OuterEval {
8282 cost,
8283 gradient,
8284 hessian,
8285 inner_beta_hint: None,
8286 })
8287 };
8288
8289 let result = {
8290 let mut obj = problem.build_objective_with_eval_order(
8291 &mut ctx,
8292 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
8293 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| {
8294 eval_outer(ctx, theta, OuterEvalOrder::ValueAndGradient)
8295 },
8296 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
8297 eval_outer(ctx, theta, order)
8298 },
8299 Some(|ctx: &mut &mut LatentJointContext<'_>| {
8300 ctx.cache.reset();
8301 }),
8302 Some(|ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
8303 );
8304
8305 problem
8306 .run(&mut obj, "latent-coordinate joint REML")
8307 .map_err(|e| {
8308 EstimationError::InvalidInput(format!(
8309 "latent-coordinate joint optimization failed after exhausting strategy fallbacks: {e}"
8310 ))
8311 })?
8312 };
8313 if !result.converged {
8314 crate::bail_invalid_estim!(
8315 "latent-coordinate joint optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
8316 result.iterations,
8317 result.final_value,
8318 result.final_grad_norm_report(),
8319 );
8320 }
8321
8322 let theta_star = result.rho;
8323 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
8324 let mut final_data = data.to_owned();
8325 let flat_t = theta_star
8326 .slice(s![rho_dim..rho_dim + latent_flat_dim])
8327 .to_owned();
8328 let mut fitted_latent_values =
8329 Array2::<f64>::zeros((latent.values.n_obs(), latent.values.latent_dim()));
8330 for n in 0..latent.values.n_obs() {
8331 for axis in 0..latent.values.latent_dim() {
8332 let value = flat_t[n * latent.values.latent_dim() + axis];
8333 fitted_latent_values[[n, axis]] = value;
8334 final_data[[n, latent.feature_cols[axis]]] = value;
8335 }
8336 }
8337 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
8338 final_data.view(),
8339 y,
8340 weights,
8341 offset,
8342 resolvedspec,
8343 rho_star.as_slice(),
8344 family,
8345 options,
8346 )?;
8347 ctx.evaluator
8348 .store_persistent_latent_values(&fitted_latent_values);
8349 let mut fit = optimized.fit;
8350 fit.reml_score = result.final_value;
8351 fit.penalized_objective = result.final_value;
8352 Ok(FittedTermCollectionWithSpec {
8353 fit,
8354 design: optimized.design,
8355 resolvedspec: resolvedspec.clone(),
8356 adaptive_diagnostics: optimized.adaptive_diagnostics,
8357 kappa_timing: None,
8358 })
8359}
8360
8361pub fn fit_term_collectionwith_latent_coord_optimization(
8362 data: ArrayView2<'_, f64>,
8363 y: Array1<f64>,
8364 weights: Array1<f64>,
8365 offset: Array1<f64>,
8366 spec: &TermCollectionSpec,
8367 latent: &StandardLatentCoordConfig,
8368 family: LikelihoodSpec,
8369 options: &FitOptions,
8370) -> Result<FittedTermCollectionWithSpec, EstimationError> {
8371 let n = data.nrows();
8372 if !(y.len() == n && weights.len() == n && offset.len() == n) {
8373 crate::bail_invalid_estim!(
8374 "fit_term_collectionwith_latent_coord_optimization row mismatch: n={}, y={}, weights={}, offset={}",
8375 n,
8376 y.len(),
8377 weights.len(),
8378 offset.len()
8379 );
8380 }
8381 let best = fit_term_collection_forspec(
8382 data,
8383 y.view(),
8384 weights.view(),
8385 offset.view(),
8386 spec,
8387 family.clone(),
8388 options,
8389 )?;
8390 let resolvedspec = freeze_term_collection_from_design(spec, &best.design)?;
8391 try_exact_joint_latent_coord_optimization(
8392 data,
8393 y.view(),
8394 weights.view(),
8395 offset.view(),
8396 &resolvedspec,
8397 &best,
8398 family,
8399 options,
8400 latent,
8401 )
8402}
8403
8404pub fn fit_term_collectionwith_spatial_length_scale_optimization(
8405 data: ArrayView2<'_, f64>,
8406 y: Array1<f64>,
8407 weights: Array1<f64>,
8408 offset: Array1<f64>,
8409 spec: &TermCollectionSpec,
8410 family: LikelihoodSpec,
8411 options: &FitOptions,
8412 kappa_options: &SpatialLengthScaleOptimizationOptions,
8413) -> Result<FittedTermCollectionWithSpec, EstimationError> {
8414 let mut resolvedspec = spec.clone();
8430 let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8431 let n = data.nrows();
8432 if !(y.len() == n && weights.len() == n && offset.len() == n) {
8433 crate::bail_invalid_estim!(
8434 "fit_term_collectionwith_spatial_length_scale_optimization row mismatch: n={}, y={}, weights={}, offset={}",
8435 n,
8436 y.len(),
8437 weights.len(),
8438 offset.len()
8439 );
8440 }
8441 if !kappa_options.enabled || spatial_terms.is_empty() {
8442 let out = fit_term_collection_forspec(
8443 data,
8444 y.view(),
8445 weights.view(),
8446 offset.view(),
8447 &resolvedspec,
8448 family,
8449 options,
8450 )?;
8451 let resolvedspec = freeze_term_collection_from_design(&resolvedspec, &out.design)?;
8452 return Ok(FittedTermCollectionWithSpec {
8453 fit: out.fit,
8454 design: out.design,
8455 resolvedspec,
8456 adaptive_diagnostics: out.adaptive_diagnostics,
8457 kappa_timing: None,
8458 });
8459 }
8460 if kappa_options.max_outer_iter == 0 {
8461 crate::bail_invalid_estim!("spatial kappa optimization requires max_outer_iter >= 1");
8462 }
8463 if !(kappa_options.log_step.is_finite() && kappa_options.log_step > 0.0) {
8464 crate::bail_invalid_estim!("spatial kappa optimization requires log_step > 0");
8465 }
8466 if !(kappa_options.min_length_scale.is_finite()
8467 && kappa_options.max_length_scale.is_finite()
8468 && kappa_options.min_length_scale > 0.0
8469 && kappa_options.max_length_scale >= kappa_options.min_length_scale)
8470 {
8471 crate::bail_invalid_estim!(
8472 "spatial kappa optimization requires valid positive length_scale bounds"
8473 );
8474 }
8475
8476 let pilot_threshold = kappa_options.pilot_subsample_threshold;
8477 if pilot_threshold > 0 && n > pilot_threshold * 2 {
8478 log::info!(
8479 "[spatial-kappa] n={n} exceeds pilot threshold {}; using pilot geometry only for deterministic anisotropy initialization",
8480 pilot_threshold * 2,
8481 );
8482 apply_spatial_anisotropy_pilot_initializer(
8483 data,
8484 &mut resolvedspec,
8485 &spatial_terms,
8486 pilot_threshold,
8487 kappa_options,
8488 );
8489 }
8490
8491 apply_response_aware_anisotropy_seed(data, y.view(), &mut resolvedspec, &spatial_terms);
8500
8501 for term_idx in constant_curvature_term_indices(&resolvedspec) {
8519 if let Some(kappa_seed) =
8520 select_constant_curvature_kappa_sign_seed(data, y.view(), &resolvedspec, term_idx)
8521 && kappa_seed != 0.0
8522 && let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) =
8523 resolvedspec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis)
8524 {
8525 log::info!(
8526 "[#1464] pinned CC term {term_idx} baseline κ to κ-fair scan value {kappa_seed} \
8527 (raw profiled REML is sign-blind; scan is authoritative for the sign)"
8528 );
8529 cc.kappa = kappa_seed;
8530 }
8531 }
8532
8533 let baseline_options = superseded_fit_options(options);
8534 let mut best = fit_term_collection_forspec(
8535 data,
8536 y.view(),
8537 weights.view(),
8538 offset.view(),
8539 &resolvedspec,
8540 family.clone(),
8541 &baseline_options,
8542 )?;
8543 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8544 let mut spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8554 sync_aniso_contrasts_from_metadata(&mut resolvedspec, &best.design.smooth);
8558 let mut prescan_improved = false;
8565 if !spatial_terms.is_empty() {
8566 let baseline_score = fit_score(&best.fit);
8567 let range_overrides = prescan_isotropic_spatial_range_seed(
8568 data,
8569 y.view(),
8570 weights.view(),
8571 offset.view(),
8572 &resolvedspec,
8573 baseline_score,
8574 &family,
8575 &baseline_options,
8576 kappa_options,
8577 &spatial_terms,
8578 )?;
8579 if !range_overrides.is_empty() {
8580 prescan_improved = true;
8581 for (term_idx, length_scale) in range_overrides {
8582 set_spatial_length_scale(&mut resolvedspec, term_idx, length_scale)?;
8583 }
8584 best = fit_term_collection_forspec(
8588 data,
8589 y.view(),
8590 weights.view(),
8591 offset.view(),
8592 &resolvedspec,
8593 family.clone(),
8594 &baseline_options,
8595 )?;
8596 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8597 spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8601 }
8602 }
8603 if spatial_terms.is_empty() {
8604 let fitted = fit_term_collection_forspecwith_heuristic_lambdas(
8605 data,
8606 y.view(),
8607 weights.view(),
8608 offset.view(),
8609 &resolvedspec,
8610 best.fit.lambdas.as_slice(),
8611 family,
8612 options,
8613 )?;
8614 return Ok(FittedTermCollectionWithSpec {
8615 fit: fitted.fit,
8616 design: fitted.design,
8617 resolvedspec,
8618 adaptive_diagnostics: fitted.adaptive_diagnostics,
8619 kappa_timing: None,
8620 });
8621 }
8622 let initial_score = fit_score(&best.fit);
8623 if !initial_score.is_finite() {
8624 log::debug!("[spatial-kappa] initial profiled score is non-finite");
8625 }
8626 let seed_length_scales: Vec<(usize, f64)> = spatial_terms
8633 .iter()
8634 .filter_map(|&t| get_spatial_length_scale(&resolvedspec, t).map(|ls| (t, ls)))
8635 .collect();
8636 let joint_result = try_exact_joint_spatial_length_scale_optimization(
8637 data,
8638 y.view(),
8639 weights.view(),
8640 offset.view(),
8641 &resolvedspec,
8642 &best,
8643 family.clone(),
8644 options,
8645 kappa_options,
8646 &spatial_terms,
8647 )
8648 .map(|opt| {
8649 opt.map(|fit| {
8650 let score = fit_score(&fit.fit);
8651 (fit, score)
8652 })
8653 });
8654 let exact_joint = if prescan_improved && !matches!(joint_result, Ok(Some(_))) {
8664 let reason = match &joint_result {
8665 Err(e) => format!("error: {e}"),
8666 _ => "unavailable".to_string(),
8667 };
8668 log::info!(
8669 "[spatial-kappa] #1074 joint polish yielded no usable candidate \
8670 ({reason}); returning the multi-start pre-scan geometry (REML {initial_score:.5})"
8671 );
8672 FittedTermCollectionWithSpec {
8673 fit: best.fit,
8674 design: best.design,
8675 resolvedspec,
8676 adaptive_diagnostics: best.adaptive_diagnostics,
8677 kappa_timing: None,
8678 }
8679 } else {
8680 require_successful_spatial_optimization_result(initial_score, joint_result)?
8681 };
8682
8683 let exact_joint = {
8710 let primary_score = fit_score(&exact_joint.fit);
8711 let improved = primary_score.is_finite()
8712 && initial_score.is_finite()
8713 && primary_score < initial_score - 1e-7 * initial_score.abs().max(1.0);
8714 let base_spec = exact_joint.resolvedspec.clone();
8719 let geometry_unchanged = !seed_length_scales.is_empty()
8722 && seed_length_scales.iter().all(|&(t, seed_ls)| {
8723 get_spatial_length_scale(&base_spec, t)
8724 .is_some_and(|ls| (ls - seed_ls).abs() <= 1e-6 * seed_ls.abs().max(1.0))
8725 });
8726 let eligible = !improved
8727 && geometry_unchanged
8728 && !has_aniso_terms(&base_spec, &spatial_terms)
8729 && constant_curvature_term_indices(&base_spec).is_empty()
8730 && spatial_terms
8731 .iter()
8732 .any(|&t| get_spatial_length_scale(&base_spec, t).is_some());
8733 if eligible {
8734 log::info!(
8735 "[spatial-kappa] #1688 joint solve stalled at REML {primary_score:.5} \
8736 (no improvement over baseline {initial_score:.5}); running ψ-window \
8737 multistart rescue across {} seeds",
8738 JOINT_RESTART_WINDOW_FRACTIONS.len(),
8739 );
8740 let mut best_fit = exact_joint;
8741 let mut best_score = primary_score;
8743 for &fraction in JOINT_RESTART_WINDOW_FRACTIONS.iter() {
8744 match joint_solve_from_window_fraction(
8745 data,
8746 y.view(),
8747 weights.view(),
8748 offset.view(),
8749 &base_spec,
8750 &spatial_terms,
8751 fraction,
8752 &family,
8753 options,
8754 &baseline_options,
8755 kappa_options,
8756 ) {
8757 Ok(Some((candidate, score))) => {
8758 if score.is_finite()
8759 && (!best_score.is_finite()
8760 || score < best_score - 1e-7 * best_score.abs().max(1.0))
8761 {
8762 log::info!(
8763 "[spatial-kappa] #1688 multistart seed (ψ-window {fraction:.2}) \
8764 reached REML {score:.5}, improving on {best_score:.5}",
8765 );
8766 best_score = score;
8767 best_fit = candidate;
8768 }
8769 }
8770 Ok(None) => {}
8772 Err(e) => {
8776 log::info!(
8777 "[spatial-kappa] #1688 multistart seed (ψ-window {fraction:.2}) \
8778 failed ({e}); skipping"
8779 );
8780 }
8781 }
8782 }
8783 best_fit
8784 } else {
8785 exact_joint
8786 }
8787 };
8788
8789 log_spatial_aniso_scales(&exact_joint.resolvedspec);
8790 Ok(exact_joint)
8791}
8792
#[derive(Clone, Debug)]
/// Inference summary for the curvature parameter κ of one smooth term.
///
/// This is the success type returned by `curvature_inference_forspec`.
pub struct CurvatureInference {
    /// Index of the smooth term within the term collection.
    pub term_idx: usize,
    /// Point value of κ for the term.
    // NOTE(review): presumably the κ stored in the resolved spec for `term_idx`
    // at the time of inference; confirm against the constructor site.
    pub kappa_hat: f64,
    /// Profile confidence interval for κ.
    // NOTE(review): presumably the result of
    // `gam_geometry::curvature_estimand::profile_ci_walk`; confirm.
    pub ci: gam_geometry::curvature_estimand::KappaProfileCi,
    /// Flatness test result for the term.
    // NOTE(review): the null hypothesis is not visible from this definition
    // (presumably κ = 0); confirm against `FlatnessTest`.
    pub flatness: gam_geometry::curvature_estimand::FlatnessTest,
}
8812
8813pub fn curvature_inference_forspec(
8831 data: ArrayView2<'_, f64>,
8832 y: ArrayView1<'_, f64>,
8833 weights: ArrayView1<'_, f64>,
8834 offset: ArrayView1<'_, f64>,
8835 resolvedspec: &TermCollectionSpec,
8836 term_idx: usize,
8837 family: LikelihoodSpec,
8838 options: &FitOptions,
8839 level: f64,
8840) -> Result<CurvatureInference, EstimationError> {
8841 let kappa_hat = get_constant_curvature_kappa(resolvedspec, term_idx).ok_or_else(|| {
8842 EstimationError::InvalidInput(format!(
8843 "curvature_inference_forspec: term {term_idx} is not a constant-curvature smooth"
8844 ))
8845 })?;
8846 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
8847
8848 let cc_fair_inputs: Option<(Array2<f64>, gam_terms::basis::ConstantCurvatureBasisSpec)> =
8873 if kappa_hat < 0.0 {
8874 match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
8875 Some(SmoothBasisSpec::ConstantCurvature {
8876 feature_cols, spec, ..
8877 }) => select_columns(data, feature_cols)
8878 .ok()
8879 .map(|x| (x, spec.clone())),
8880 _ => None,
8881 }
8882 } else {
8883 None
8884 };
8885
8886 let v_p_cache: std::cell::RefCell<std::collections::HashMap<u64, f64>> =
8891 std::cell::RefCell::new(std::collections::HashMap::new());
8892 let v_p = |kappa: f64| -> Result<f64, String> {
8893 if !kappa.is_finite() {
8894 return Err(format!("V_p probed a non-finite κ = {kappa}"));
8895 }
8896 let key = kappa.to_bits();
8897 if let Some(&cached) = v_p_cache.borrow().get(&key) {
8898 return Ok(cached);
8899 }
8900 let score = if let Some((x_term, base_spec)) = &cc_fair_inputs {
8901 let mut probe_spec = base_spec.clone();
8902 probe_spec.kappa = kappa;
8903 gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec)
8904 .map_err(|e| format!("κ-fair criterion at κ={kappa} failed: {e}"))?
8905 } else {
8906 fixed_kappa_profiled_reml_score(
8907 data,
8908 y,
8909 weights,
8910 offset,
8911 resolvedspec,
8912 term_idx,
8913 kappa,
8914 family.clone(),
8915 options,
8916 )
8917 .map_err(|e| format!("V_p fixed-κ fit at κ={kappa} failed: {e}"))?
8918 };
8919 v_p_cache.borrow_mut().insert(key, score);
8920 Ok(score)
8921 };
8922
8923 let h = (1e-3 * (kappa_max - kappa_min)).max(1e-4);
8927 let v_pp = match (v_p(kappa_hat + h), v_p(kappa_hat), v_p(kappa_hat - h)) {
8928 (Ok(vp), Ok(v0), Ok(vm)) => (vp - 2.0 * v0 + vm) / (h * h),
8929 _ => f64::NAN, };
8931
8932 let ci = gam_geometry::curvature_estimand::profile_ci_walk(
8933 &v_p, kappa_hat, v_pp, kappa_min, kappa_max, level, 1e-4,
8934 )
8935 .map_err(EstimationError::InvalidInput)?;
8936 let flatness = gam_geometry::curvature_estimand::flatness_lr_test(&v_p, kappa_hat)
8937 .map_err(EstimationError::InvalidInput)?;
8938
8939 Ok(CurvatureInference {
8940 term_idx,
8941 kappa_hat,
8942 ci,
8943 flatness,
8944 })
8945}
8946
/// Which small-sample correction, if any, was applied to a smooth-term
/// likelihood-ratio statistic by `smooth_term_lr_inference_forspec`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmoothLrCorrection {
    /// Bartlett factor derived from the mean shift that also accounts for
    /// variation in the smoothing parameters (requires a usable ρ covariance).
    LawleyLrEstimatedLambda,
    /// Bartlett factor from `lawley_lr_bartlett_factor` with the fitted
    /// penalty treated as fixed.
    LawleyLrFixedLambda,
    /// No correction was applied; corrected and uncorrected values coincide.
    None,
}
8965
8966impl SmoothLrCorrection {
8967 pub fn label(self) -> &'static str {
8969 match self {
8970 SmoothLrCorrection::LawleyLrEstimatedLambda => "lawley_lr_estimated_lambda",
8971 SmoothLrCorrection::LawleyLrFixedLambda => "lawley_lr_fixed_lambda",
8972 SmoothLrCorrection::None => "none",
8973 }
8974 }
8975}
8976
/// Likelihood-ratio inference for a single smooth term, as produced by
/// `smooth_term_lr_inference_forspec`.
#[derive(Clone, Debug)]
pub struct SmoothTermLrInference {
    /// Name of the smooth term (copied from the design term).
    pub name: String,
    /// Position of the term among the fitted design's smooth terms.
    pub term_idx: usize,
    /// Uncorrected LR statistic `2 * (ll_full - ll_null)`, floored at zero;
    /// NaN when the term-dropped refit is unavailable.
    pub statistic_lr: f64,
    /// Reference degrees of freedom of the chi-squared comparison
    /// distribution (always at least 1).
    pub ref_df: f64,
    /// Bartlett factor actually applied to the statistic; `1.0` when no
    /// correction was applied.
    pub bartlett_factor: f64,
    /// Fixed-lambda Bartlett factor, reported only when the estimated-lambda
    /// correction replaced it as the applied factor.
    pub bartlett_factor_conditional: Option<f64>,
    /// Part of the mean shift beyond the fixed-lambda shift, reported only
    /// with the estimated-lambda correction.
    pub rho_variation_shift: Option<f64>,
    /// `statistic_lr` divided by `bartlett_factor` (equal to `statistic_lr`
    /// when no correction was applied).
    pub statistic_corrected: f64,
    /// Upper-tail chi-squared p-value of `statistic_lr`; NaN if unavailable.
    pub p_value_uncorrected: f64,
    /// Upper-tail chi-squared p-value of `statistic_corrected`.
    pub p_value_corrected: f64,
    /// Whether the correction moved the Bartlett factor or the p-value by
    /// more than `SMOOTH_LR_MATERIAL_THRESHOLD`; always `false` without a
    /// correction.
    pub material: bool,
    /// Which correction was applied.
    pub correction: SmoothLrCorrection,
}
9022
/// Threshold above which a smooth-term LR correction is flagged as material:
/// compared against both `|bartlett_factor - 1|` and the relative change in
/// the p-value.
pub const SMOOTH_LR_MATERIAL_THRESHOLD: f64 = 0.10;
9027
9028fn fitted_rho_penalty_components(
9034 penalties: &[BlockwisePenalty],
9035 lambdas: &[f64],
9036 p_total: usize,
9037) -> Result<Vec<gam_terms::inference::lawley::RhoPenaltyComponent>, EstimationError> {
9038 if penalties.len() != lambdas.len() {
9039 return Err(EstimationError::InvalidInput(format!(
9040 "smooth_term_lr_inference: penalty/lambda count mismatch ({} penalties, {} lambdas)",
9041 penalties.len(),
9042 lambdas.len()
9043 )));
9044 }
9045 let mut components = Vec::with_capacity(penalties.len());
9046 for (idx, (penalty, &lambda)) in penalties.iter().zip(lambdas.iter()).enumerate() {
9047 if !(lambda.is_finite() && lambda >= 0.0) {
9048 return Err(EstimationError::InvalidInput(format!(
9049 "smooth_term_lr_inference: lambda[{idx}] is invalid: {lambda}"
9050 )));
9051 }
9052 let r = &penalty.col_range;
9053 if r.end > p_total {
9054 return Err(EstimationError::InvalidInput(format!(
9055 "smooth_term_lr_inference: penalty[{idx}] range {:?} exceeds coefficient dimension {p_total}",
9056 r
9057 )));
9058 }
9059 let mut s_component = Array2::<f64>::zeros((p_total, p_total));
9060 s_component
9061 .slice_mut(s![r.start..r.end, r.start..r.end])
9062 .scaled_add(lambda, &penalty.local);
9063 components.push(gam_terms::inference::lawley::RhoPenaltyComponent { s_component });
9064 }
9065 Ok(components)
9066}
9067
9068pub fn smooth_term_lr_inference_forspec(
9113 data: ArrayView2<'_, f64>,
9114 y: ArrayView1<'_, f64>,
9115 weights: ArrayView1<'_, f64>,
9116 offset: ArrayView1<'_, f64>,
9117 resolvedspec: &TermCollectionSpec,
9118 family: LikelihoodSpec,
9119 options: &FitOptions,
9120) -> Result<Vec<SmoothTermLrInference>, EstimationError> {
9121 use gam_terms::inference::lawley::{
9122 LAWLEY_PAIR_MATRIX_MAX_ROWS, known_scale_expected_jets_with_dispersion,
9123 lawley_lr_bartlett_factor, lawley_lr_mean_shift_with_rho_variation,
9124 };
9125
9126 let n = data.nrows();
9127 let full = fit_term_collection_forspec(
9130 data,
9131 y,
9132 weights,
9133 offset,
9134 resolvedspec,
9135 family.clone(),
9136 options,
9137 )?;
9138 let ll_full = full.fit.log_likelihood;
9139 let p_total = full.design.design.ncols();
9140 let lambdas = full.fit.lambdas.as_slice().ok_or_else(|| {
9141 EstimationError::InvalidInput(
9142 "smooth_term_lr_inference: non-contiguous lambda vector".to_string(),
9143 )
9144 })?;
9145 let s_lambda = weighted_blockwise_penalty_sum(&full.design.penalties, lambdas, p_total);
9146 let rho_penalty_components =
9147 fitted_rho_penalty_components(&full.design.penalties, lambdas, p_total)?;
9148 let rho_covariance = full.fit.artifacts.rho_covariance.as_ref().filter(|cov| {
9149 cov.nrows() == rho_penalty_components.len() && cov.ncols() == rho_penalty_components.len()
9150 });
9151 let full_design_dense = full.design.design.to_dense();
9153 let influence = full.fit.coefficient_influence();
9154 let family_disp = lawley_dispersion_for_family(&family, &full.fit);
9155
9156 let mut penalty_cursor = full.design.random_effect_ranges.len();
9159 let mut out = Vec::<SmoothTermLrInference>::new();
9160 for (term_idx, design_term) in full.design.smooth.terms.iter().enumerate() {
9161 let k = design_term.penalties_local.len();
9162 let block_start = penalty_cursor;
9163 penalty_cursor += k;
9164 if design_term.shape != ShapeConstraint::None {
9167 continue;
9168 }
9169 let coeff_range = design_term.coeff_range.clone();
9170 if coeff_range.start >= coeff_range.end || coeff_range.end > p_total {
9171 continue;
9172 }
9173 let edf = full.fit.per_term_edf(coeff_range.clone(), block_start, k);
9185 let null_dim = design_term.wald_unpenalized_dim();
9205 let edf_floor = (null_dim.max(1)) as f64;
9257 let untrusted_edf_collapse = !full.fit.outer_converged && edf < edf_floor;
9258 let unconverged_dim_floor = if untrusted_edf_collapse {
9259 coeff_range.len() as f64
9260 } else {
9261 0.0
9262 };
9263 let ref_df = wood_reference_df(influence, &coeff_range)
9264 .unwrap_or(0.0)
9265 .max(edf)
9266 .max(null_dim as f64)
9267 .max(unconverged_dim_floor)
9268 .max(1.0);
9269 if !(ref_df.is_finite() && ref_df > 0.0) {
9270 continue;
9271 }
9272
9273 let mut null_spec = resolvedspec.clone();
9276 let Some(spec_pos) = null_spec
9277 .smooth_terms
9278 .iter()
9279 .position(|t| t.name == design_term.name)
9280 else {
9281 continue;
9282 };
9283 null_spec.smooth_terms.remove(spec_pos);
9284 let null_fit = fit_term_collection_forspec(
9285 data,
9286 y,
9287 weights,
9288 offset,
9289 &null_spec,
9290 family.clone(),
9291 options,
9292 );
9293 let (statistic_lr, eta_null) = match null_fit {
9294 Ok(null) if null.fit.log_likelihood.is_finite() => {
9295 let w = (2.0 * (ll_full - null.fit.log_likelihood)).max(0.0);
9296 let mut eta = null.design.design.dot(&null.fit.beta);
9300 eta += &offset;
9301 (w, Some(eta))
9302 }
9303 _ => (f64::NAN, None),
9304 };
9305
9306 let chi2 = statrs::distribution::ChiSquared::new(ref_df).ok();
9307 let p_uncorrected = match (chi2.as_ref(), statistic_lr.is_finite()) {
9308 (Some(dist), true) => {
9309 use statrs::distribution::ContinuousCDF;
9310 (1.0 - dist.cdf(statistic_lr)).clamp(0.0, 1.0)
9311 }
9312 _ => f64::NAN,
9313 };
9314
9315 let mut bartlett_factor = 1.0;
9319 let mut bartlett_factor_conditional = None;
9320 let mut rho_variation_shift = None;
9321 let mut statistic_corrected = statistic_lr;
9322 let mut p_corrected = p_uncorrected;
9323 let mut correction = SmoothLrCorrection::None;
9324 if let (Some(eta), true, true) = (
9325 eta_null.as_ref(),
9326 statistic_lr.is_finite(),
9327 n <= LAWLEY_PAIR_MATRIX_MAX_ROWS,
9328 ) {
9329 let kappas: Option<Vec<_>> = (0..n)
9330 .map(|i| {
9331 known_scale_expected_jets_with_dispersion(&family, eta[i], family_disp)
9332 .and_then(|jets| jets.kappas().ok())
9333 })
9334 .collect();
9335 if let (Some(kappas), Some(dist)) = (kappas, chi2.as_ref()) {
9336 let fixed_factor = lawley_lr_bartlett_factor(
9337 full_design_dense.view(),
9338 &kappas,
9339 Some(s_lambda.view()),
9340 coeff_range.clone(),
9341 ref_df,
9342 );
9343 if let Ok(c_cond) = fixed_factor
9344 && c_cond.is_finite()
9345 && c_cond > 0.0
9346 {
9347 let mut c_applied = c_cond;
9348 correction = SmoothLrCorrection::LawleyLrFixedLambda;
9349 if let Some(cov) = rho_covariance
9350 && let Ok(total_shift) = lawley_lr_mean_shift_with_rho_variation(
9351 full_design_dense.view(),
9352 &kappas,
9353 s_lambda.view(),
9354 coeff_range.clone(),
9355 &rho_penalty_components,
9356 cov.view(),
9357 )
9358 {
9359 let mean_w = ref_df + total_shift;
9360 if let Some(c_est) =
9361 gam_terms::inference::higher_order::bartlett_factor_from_mean(
9362 mean_w, ref_df,
9363 )
9364 && c_est.is_finite()
9365 && c_est > 0.0
9366 {
9367 let conditional_shift = (c_cond - 1.0) * ref_df;
9368 c_applied = c_est;
9369 bartlett_factor_conditional = Some(c_cond);
9370 rho_variation_shift = Some(total_shift - conditional_shift);
9371 correction = SmoothLrCorrection::LawleyLrEstimatedLambda;
9372 }
9373 }
9374 use statrs::distribution::ContinuousCDF;
9375 bartlett_factor = c_applied;
9376 statistic_corrected = statistic_lr / c_applied;
9377 p_corrected = (1.0 - dist.cdf(statistic_corrected)).clamp(0.0, 1.0);
9378 }
9379 }
9380 }
9381
9382 let material = match correction {
9388 SmoothLrCorrection::LawleyLrEstimatedLambda
9389 | SmoothLrCorrection::LawleyLrFixedLambda => {
9390 let factor_move = (bartlett_factor - 1.0).abs();
9391 let p_denom = p_uncorrected.max(p_corrected).max(f64::MIN_POSITIVE);
9392 let p_move = if p_uncorrected.is_finite() && p_corrected.is_finite() {
9393 (p_corrected - p_uncorrected).abs() / p_denom
9394 } else {
9395 0.0
9396 };
9397 factor_move > SMOOTH_LR_MATERIAL_THRESHOLD || p_move > SMOOTH_LR_MATERIAL_THRESHOLD
9398 }
9399 SmoothLrCorrection::None => false,
9400 };
9401
9402 out.push(SmoothTermLrInference {
9403 name: design_term.name.clone(),
9404 term_idx,
9405 statistic_lr,
9406 ref_df,
9407 bartlett_factor,
9408 bartlett_factor_conditional,
9409 rho_variation_shift,
9410 statistic_corrected,
9411 p_value_uncorrected: p_uncorrected,
9412 p_value_corrected: p_corrected,
9413 material,
9414 correction,
9415 });
9416 }
9417 Ok(out)
9418}
9419
9420fn lawley_dispersion_for_family(family: &LikelihoodSpec, fit: &UnifiedFitResult) -> f64 {
9423 match family.response {
9424 gam_spec::ResponseFamily::Gaussian => {
9425 let sd = fit.standard_deviation;
9426 (sd * sd).max(f64::MIN_POSITIVE)
9427 }
9428 gam_spec::ResponseFamily::Gamma => {
9429 let shape = fit.standard_deviation;
9430 if shape.is_finite() && shape > 0.0 {
9431 1.0 / shape
9432 } else {
9433 1.0
9434 }
9435 }
9436 _ => 1.0,
9437 }
9438}
9439
9440fn wood_reference_df(influence: Option<&Array2<f64>>, coeff_range: &Range<usize>) -> Option<f64> {
9464 let f = influence?;
9465 let (start, end) = (coeff_range.start, coeff_range.end);
9466 if start >= end || end > f.nrows() || end > f.ncols() {
9467 return None;
9468 }
9469 let block = f.slice(s![start..end, start..end]);
9470 let tr = (0..block.nrows()).map(|i| block[[i, i]]).sum::<f64>();
9471 let tr2 = block.dot(&block).diag().sum();
9472 (tr.is_finite() && tr2.is_finite() && tr > 0.0)
9473 .then(|| (2.0 * tr - tr2).max(tr).max(1e-12))
9474}