1fn try_build_spatial_term_log_kappa_derivative(
2 data: ArrayView2<'_, f64>,
3 resolvedspec: &TermCollectionSpec,
4 design: &TermCollectionDesign,
5 term_idx: usize,
6) -> Result<
7 Option<(
8 Range<usize>,
9 usize,
10 Array2<f64>,
11 Array2<f64>,
12 Array2<f64>,
13 Array2<f64>,
14 Vec<Array2<f64>>,
15 Vec<Array2<f64>>,
16 Option<std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>>,
17 )>,
18 EstimationError,
19> {
20 let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
21 return Ok(None);
22 };
23 let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
24 return Ok(None);
25 };
26
27 let derivative_bundle = match &termspec.basis {
28 SmoothBasisSpec::ThinPlate {
29 feature_cols,
30 spec,
31 input_scales,
32 } => {
33 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
34 let mut spec_local = spec.clone();
35 if let Some(s) = input_scales {
36 apply_input_standardization(&mut x, s);
37 spec_local.length_scale =
38 compensate_length_scale_for_standardization(spec.length_scale, s);
39 }
40 build_thin_plate_basis_log_kappa_derivatives(x.view(), &spec_local)
41 .map_err(EstimationError::from)?
42 }
43 SmoothBasisSpec::Sphere { .. } => return Ok(None),
44 SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
53 let x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
54 build_constant_curvature_basis_kappa_derivatives(x.view(), spec)
55 .map_err(EstimationError::from)?
56 }
57 SmoothBasisSpec::MeasureJet { .. } => return Ok(None),
63 SmoothBasisSpec::Matern {
64 feature_cols,
65 spec,
66 input_scales,
67 } => {
68 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
69 let mut spec_local = spec.clone();
70 if let Some(s) = input_scales {
71 apply_input_standardization(&mut x, s);
72 spec_local.length_scale =
73 compensate_length_scale_for_standardization(spec.length_scale, s);
74 }
75 spec_local.double_penalty = false;
90 build_matern_basis_log_kappa_derivatives(x.view(), &spec_local)
91 .map_err(EstimationError::from)?
92 }
93 SmoothBasisSpec::Duchon {
94 feature_cols,
95 spec,
96 input_scales,
97 } => {
98 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
99 let mut spec_local = spec.clone();
100 if let Some(s) = input_scales {
101 apply_input_standardization(&mut x, s);
102 spec_local.length_scale =
103 compensate_optional_length_scale_for_standardization(spec.length_scale, s);
104 }
105 let BasisMetadata::Duchon {
106 centers,
107 identifiability_transform,
108 operator_collocation_points,
109 radial_reparam,
110 ..
111 } = &smooth_term.metadata
112 else {
113 return Ok(None);
114 };
115 if spec_local.radial_reparam.is_none() {
118 spec_local.radial_reparam = radial_reparam.clone();
119 }
120 gam_terms::basis::build_duchon_basis_log_kappa_derivativeswith_collocationwithworkspace(
121 x.view(),
122 &spec_local,
123 centers.view(),
124 identifiability_transform.as_ref(),
125 operator_collocation_points
126 .as_ref()
127 .map(|points| points.view()),
128 &mut BasisWorkspace::default(),
129 )
130 .map_err(EstimationError::from)?
131 }
132 SmoothBasisSpec::BSpline1D { .. }
133 | SmoothBasisSpec::TensorBSpline { .. }
134 | SmoothBasisSpec::ByVariable { .. }
135 | SmoothBasisSpec::FactorSumToZero { .. }
136 | SmoothBasisSpec::BySmooth { .. }
137 | SmoothBasisSpec::FactorSmooth { .. }
138 | SmoothBasisSpec::Pca { .. } => {
139 return Ok(None);
140 }
141 };
142 let mut implicit_operator = derivative_bundle.implicit_operator;
143 let BasisPsiDerivativeResult {
144 design_derivative: mut local_x_psi,
145 penalties_derivative: mut local_s_psi,
146 implicit_operator: local_implicit_first_unused,
147 } = derivative_bundle.first;
148 let BasisPsiSecondDerivativeResult {
149 designsecond_derivative: mut local_x_psi_psi,
150 penaltiessecond_derivative: mut local_s_psi_psi,
151 implicit_operator: local_implicit_second_unused,
152 } = derivative_bundle.second;
153 assert!(local_implicit_first_unused.is_none());
154 assert!(local_implicit_second_unused.is_none());
155
156 if let SmoothBasisSpec::Matern {
160 feature_cols,
161 spec,
162 input_scales,
163 } = &termspec.basis
164 {
165 let mut xf = select_columns(data, feature_cols).map_err(EstimationError::from)?;
166 let mut sp = spec.clone();
167 if let Some(s) = input_scales {
168 apply_input_standardization(&mut xf, s.as_slice());
169 sp.length_scale =
170 compensate_length_scale_for_standardization(spec.length_scale, s.as_slice());
171 }
172 sp.double_penalty = false;
173 let ls0 = sp.length_scale;
175 let h = 1e-6_f64;
176 let value_design = |delta: f64| -> Option<Array2<f64>> {
177 let mut s2 = sp.clone();
178 s2.length_scale = ls0 * (-delta).exp();
179 gam_terms::basis::build_matern_basis(xf.view(), &s2)
180 .ok()
181 .map(|b| b.design.to_dense())
182 };
183 if let (Some(dp), Some(dm)) = (value_design(h), value_design(-h)) {
184 if dp.shape() == local_x_psi.shape() {
185 let num = (&dp - &dm) / (2.0 * h);
186 let err = (&local_x_psi - &num)
187 .mapv(f64::abs)
188 .iter()
189 .fold(0.0_f64, |a, &b| a.max(b));
190 let anorm = local_x_psi.iter().map(|v| v * v).sum::<f64>().sqrt();
191 log::warn!(
192 "[OUTER-FD-AUDIT XPSIFD-1122] x_psi_max_abs_err={err:.3e} |x_psi|={anorm:.3e} \
193 shape={:?} centers_frozen={}",
194 local_x_psi.shape(),
195 match &sp.center_strategy {
196 gam_terms::basis::CenterStrategy::UserProvided(c) => c.nrows(),
197 _ => 0,
198 },
199 );
200 if let Some(d0) = value_design(0.0) {
204 let p_total = design.design.ncols();
205 let smooth_start =
206 p_total.saturating_sub(design.smooth.total_smooth_cols());
207 let g0 = smooth_start + smooth_term.coeff_range.start;
208 let g1 = smooth_start + smooth_term.coeff_range.end;
209 let realized = design.design.to_dense();
210 if g1 <= realized.ncols() && d0.ncols() == (g1 - g0) {
211 let block = realized.slice(ndarray::s![.., g0..g1]).to_owned();
212 let dmax = (&block - &d0)
213 .mapv(f64::abs)
214 .iter()
215 .fold(0.0_f64, |a, &b| a.max(b));
216 let bnorm = block.iter().map(|v| v * v).sum::<f64>().sqrt();
218 let d0norm = d0.iter().map(|v| v * v).sum::<f64>().sqrt();
219 log::warn!(
220 "[OUTER-FD-AUDIT XDESIGN-1122] realized_vs_valuebuild_max_abs={dmax:.3e} \
221 |realized_block|={bnorm:.4e} |value_build|={d0norm:.4e} \
222 block_shape={:?} g0={g0} g1={g1} p_total={p_total}",
223 block.shape(),
224 );
225 } else {
226 log::warn!(
227 "[OUTER-FD-AUDIT XDESIGN-1122] shape/range mismatch realized_cols={} g0={g0} g1={g1} d0_cols={}",
228 realized.ncols(),
229 d0.ncols()
230 );
231 }
232 }
233 } else {
234 log::warn!(
235 "[OUTER-FD-AUDIT XPSIFD-1122] shape mismatch analytic={:?} value={:?}",
236 local_x_psi.shape(),
237 dp.shape()
238 );
239 }
240 }
241 }
242
243 {
253 use gam_terms::smooth::matern_operator_penalty_triplet_at_length_scale;
254 if let BasisMetadata::Matern {
255 centers,
256 length_scale,
257 periodic,
258 nu,
259 include_intercept,
260 identifiability_transform,
261 aniso_log_scales,
262 input_scales,
263 ..
264 } = &smooth_term.metadata
265 {
266 let ls_eff = match input_scales.as_deref() {
271 Some(s) => compensate_length_scale_for_standardization(*length_scale, s),
272 None => *length_scale,
273 };
274 let h = 1e-6_f64;
276 let summed_at = |delta: f64| -> Option<Array2<f64>> {
277 let ls = ls_eff * (-delta).exp();
278 let (mats, _ranks, _info) = matern_operator_penalty_triplet_at_length_scale(
279 centers.view(),
280 periodic.as_deref(),
281 identifiability_transform.as_ref(),
282 *nu,
283 *include_intercept,
284 aniso_log_scales.as_deref(),
285 ls,
286 )
287 .ok()?;
288 let p = mats.first().map(|m: &Array2<f64>| m.nrows()).unwrap_or(0);
289 let mut acc = Array2::<f64>::zeros((p, p));
290 for m in &mats {
291 if m.shape() == acc.shape() {
292 acc += m;
293 }
294 }
295 Some(acc)
296 };
297 let analytic_sum = local_s_psi.iter().fold(
298 Array2::<f64>::zeros((
299 smooth_term.coeff_range.len(),
300 smooth_term.coeff_range.len(),
301 )),
302 |acc, m| {
303 if m.shape() == acc.shape() {
304 acc + m
305 } else {
306 acc
307 }
308 },
309 );
310 if let (Some(sp), Some(sm), Some(s0)) =
311 (summed_at(h), summed_at(-h), summed_at(0.0))
312 {
313 if sp.shape() == analytic_sum.shape() {
314 let num = (&sp - &sm) / (2.0 * h);
315 let err = (&analytic_sum - &num)
316 .mapv(f64::abs)
317 .iter()
318 .fold(0.0_f64, |a, &b| a.max(b));
319 let anorm = analytic_sum.iter().map(|v| v * v).sum::<f64>().sqrt();
320 let nnorm = num.iter().map(|v| v * v).sum::<f64>().sqrt();
321 log::warn!(
325 "[OUTER-FD-AUDIT SPSIFD-1122] s_psi_max_abs_err={err:.3e} \
326 |analytic|={anorm:.3e} |fd|={nnorm:.3e} blocks={} \
327 shape={:?} |S0|={:.3e}",
328 local_s_psi.len(),
329 analytic_sum.shape(),
330 s0.iter().map(|v| v * v).sum::<f64>().sqrt(),
331 );
332 let per_block = |delta: f64| -> Option<Vec<Array2<f64>>> {
335 let ls = ls_eff * (-delta).exp();
336 matern_operator_penalty_triplet_at_length_scale(
337 centers.view(),
338 periodic.as_deref(),
339 identifiability_transform.as_ref(),
340 *nu,
341 *include_intercept,
342 aniso_log_scales.as_deref(),
343 ls,
344 )
345 .ok()
346 .map(|(m, _, _)| m)
347 };
348 if let (Some(bp), Some(bm)) = (per_block(h), per_block(-h)) {
349 for (bi, (ap, am)) in bp.iter().zip(bm.iter()).enumerate() {
350 if bi < local_s_psi.len()
351 && ap.shape() == local_s_psi[bi].shape()
352 {
353 let bnum = (ap - am) / (2.0 * h);
354 let berr = (&local_s_psi[bi] - &bnum)
355 .mapv(f64::abs)
356 .iter()
357 .fold(0.0_f64, |a, &b| a.max(b));
358 let banorm =
359 local_s_psi[bi].iter().map(|v| v * v).sum::<f64>().sqrt();
360 log::warn!(
361 "[OUTER-FD-AUDIT SPSIFD-1122 BLOCK] block={bi} max_abs_err={berr:.3e} |analytic|={banorm:.3e}"
362 );
363 }
364 }
365 }
366 } else {
367 log::warn!(
368 "[OUTER-FD-AUDIT SPSIFD-1122] shape mismatch analytic={:?} fd={:?}",
369 analytic_sum.shape(),
370 sp.shape()
371 );
372 }
373 }
374 }
375 }
376
377 log::warn!(
380 "[OUTER-FD-AUDIT TEMP-ROT-1122] analytic joint_null_rotation={} nullity={} x_psi={}x{}",
381 smooth_term.joint_null_rotation.is_some(),
382 smooth_term
383 .joint_null_rotation
384 .as_ref()
385 .map(|r| r.joint_nullity)
386 .unwrap_or(0),
387 local_x_psi.nrows(),
388 local_x_psi.ncols(),
389 );
390 if let Some(rotation) = smooth_term.joint_null_rotation.as_ref() {
391 let q = &rotation.rotation;
392 if let Some(op) = implicit_operator.take() {
393 implicit_operator = Some(op.append_full_transform(q).map_err(EstimationError::from)?);
394 } else {
395 if local_x_psi.ncols() != q.nrows() || local_x_psi_psi.ncols() != q.nrows() {
396 return Ok(None);
397 }
398 local_x_psi = fast_ab(&local_x_psi, q);
399 local_x_psi_psi = fast_ab(&local_x_psi_psi, q);
400 }
401 let rotate_penalty = |s_local: Array2<f64>| -> Option<Array2<f64>> {
402 if s_local.nrows() != q.nrows() || s_local.ncols() != q.nrows() {
403 return None;
404 }
405 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
406 Some(gam_linalg::faer_ndarray::fast_ab(&qt_s, q))
407 };
408 let Some(rotated_s_psi) = local_s_psi
409 .into_iter()
410 .map(|s| rotate_penalty(s))
411 .collect::<Option<Vec<_>>>()
412 else {
413 return Ok(None);
414 };
415 local_s_psi = rotated_s_psi;
416 let Some(rotated_s_psi_psi) = local_s_psi_psi
417 .into_iter()
418 .map(|s| rotate_penalty(s))
419 .collect::<Option<Vec<_>>>()
420 else {
421 return Ok(None);
422 };
423 local_s_psi_psi = rotated_s_psi_psi;
424 }
425 let implicit_operator = implicit_operator.map(std::sync::Arc::new);
426
427 if let Some(ref op) = implicit_operator {
428 if op.p_out() != smooth_term.coeff_range.len() {
429 return Ok(None);
430 }
431 } else {
432 if local_x_psi.ncols() != smooth_term.coeff_range.len() {
433 return Ok(None);
434 }
435 if local_x_psi_psi.ncols() != smooth_term.coeff_range.len() {
436 return Ok(None);
437 }
438 }
439 if local_s_psi.is_empty() || local_s_psi.len() != local_s_psi_psi.len() {
440 return Ok(None);
441 }
442 if local_s_psi.iter().any(|s| {
443 s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
444 }) {
445 return Ok(None);
446 }
447 if local_s_psi_psi.iter().any(|s| {
448 s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
449 }) {
450 return Ok(None);
451 }
452
453 let p_total = design.design.ncols();
454 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
455 let global_range = (smooth_start + smooth_term.coeff_range.start)
456 ..(smooth_start + smooth_term.coeff_range.end);
457
458 Ok(Some((
459 global_range,
460 p_total,
461 local_x_psi,
462 local_s_psi.iter().fold(
463 Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
464 |acc, m| acc + m,
465 ),
466 local_x_psi_psi,
467 local_s_psi_psi.iter().fold(
468 Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
469 |acc, m| acc + m,
470 ),
471 local_s_psi,
472 local_s_psi_psi,
473 implicit_operator,
474 )))
475}
476
477fn try_build_spatial_log_kappa_hyper_dirs(
478 data: ArrayView2<'_, f64>,
479 resolvedspec: &TermCollectionSpec,
480 design: &TermCollectionDesign,
481 spatial_terms: &[usize],
482) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
483 let Some(info_list) =
490 try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, spatial_terms)?
491 else {
492 return Ok(None);
493 };
494 Ok(Some(spatial_log_kappa_hyper_dirs_frominfo_list(info_list)?))
495}
496
497pub(crate) fn try_build_latent_coord_hyper_dirs(
498 latent: std::sync::Arc<gam_terms::latent::LatentCoordValues>,
499 resolvedspec: &TermCollectionSpec,
500 design: &TermCollectionDesign,
501 latent_terms: &[gam_problem::types::SmoothTermIdx],
502 analytic_rho_count: usize,
503) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
504 if latent_terms.is_empty() || latent.is_empty() {
505 return Ok(None);
506 }
507 if latent_terms.len() != 1 {
508 crate::bail_invalid_estim!(
509 "LatentCoord standard-fit hyper_dirs currently require exactly one latent smooth term"
510 .to_string(),
511 );
512 }
513 let term_idx = latent_terms[0];
514 let smooth_term = design.smooth.terms.get(term_idx.get()).ok_or_else(|| {
515 EstimationError::InvalidInput(format!(
516 "LatentCoord term index {term_idx} out of bounds for realized smooth design"
517 ))
518 })?;
519 let termspec = resolvedspec
520 .smooth_terms
521 .get(term_idx.get())
522 .ok_or_else(|| {
523 EstimationError::InvalidInput(format!(
524 "LatentCoord term index {term_idx} out of bounds for resolved smooth spec"
525 ))
526 })?;
527 let p_total = design.design.ncols();
528 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
529 let global_range = (smooth_start + smooth_term.coeff_range.start)
530 ..(smooth_start + smooth_term.coeff_range.end);
531
532 let operator = match (&termspec.basis, &smooth_term.metadata) {
537 (
538 SmoothBasisSpec::Matern { .. },
539 BasisMetadata::Matern {
540 centers,
541 length_scale,
542 nu,
543 include_intercept,
544 identifiability_transform,
545 ..
546 },
547 ) => gam_terms::basis::LatentCoordDesignDerivative::new_matern(
548 latent.clone(),
549 std::sync::Arc::new(centers.clone()),
550 *length_scale,
551 *nu,
552 *include_intercept,
553 identifiability_transform.clone(),
554 )
555 .map_err(EstimationError::from)?,
556 (
557 SmoothBasisSpec::Duchon { .. },
558 BasisMetadata::Duchon {
559 centers,
560 length_scale,
561 power,
562 nullspace_order,
563 identifiability_transform,
564 ..
565 },
566 ) => gam_terms::basis::LatentCoordDesignDerivative::new_duchon(
567 latent.clone(),
568 std::sync::Arc::new(centers.clone()),
569 *length_scale,
570 *power,
571 *nullspace_order,
572 identifiability_transform.clone(),
573 )
574 .map_err(EstimationError::from)?,
575 (
576 SmoothBasisSpec::Sphere { .. },
577 BasisMetadata::Sphere {
578 centers,
579 penalty_order,
580 method,
581 constraint_transform,
582 ..
583 },
584 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
585 gam_terms::basis::LatentCoordDesignDerivative::new_sphere(
586 latent.clone(),
587 std::sync::Arc::new(centers.clone()),
588 *penalty_order,
589 constraint_transform.clone(),
590 )
591 .map_err(EstimationError::from)?
592 }
593 (
594 SmoothBasisSpec::BSpline1D { spec, .. },
595 BasisMetadata::BSpline1D {
596 knots,
597 identifiability_transform,
598 periodic,
599 degree: meta_degree,
600 ..
601 },
602 ) => {
603 let effective_degree = meta_degree.unwrap_or(spec.degree);
607 if let Some((domain_start, period, num_basis)) = periodic {
608 gam_terms::basis::LatentCoordDesignDerivative::new_periodic_bspline(
609 latent.clone(),
610 (*domain_start, *domain_start + *period),
611 effective_degree,
612 *num_basis,
613 identifiability_transform.clone(),
614 )
615 .map_err(EstimationError::from)?
616 } else {
617 gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
618 latent.clone(),
619 vec![knots.clone()],
620 vec![effective_degree],
621 identifiability_transform.clone(),
622 )
623 .map_err(EstimationError::from)?
624 }
625 }
626 (
627 SmoothBasisSpec::TensorBSpline { .. },
628 BasisMetadata::TensorBSpline {
629 knots,
630 degrees,
631 identifiability_transform,
632 ..
633 },
634 ) => gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
635 latent.clone(),
636 knots.clone(),
637 degrees.clone(),
638 identifiability_transform.clone(),
639 )
640 .map_err(EstimationError::from)?,
641 (SmoothBasisSpec::Pca { .. }, BasisMetadata::Pca { basis_matrix, .. }) => {
642 gam_terms::basis::LatentCoordDesignDerivative::new_pca(
643 latent.clone(),
644 std::sync::Arc::new(basis_matrix.clone()),
645 )
646 .map_err(EstimationError::from)?
647 }
648 _ => return Ok(None),
649 };
650 if operator.p_out() != global_range.len() {
651 crate::bail_invalid_estim!(
652 "LatentCoord derivative width mismatch for term '{}': operator p={}, coeff range={}",
653 smooth_term.name,
654 operator.p_out(),
655 global_range.len()
656 );
657 }
658 let operator = std::sync::Arc::new(operator);
659 let mut hyper_dirs = Vec::with_capacity(operator.n_axes());
660 for flat_axis in 0..operator.n_axes() {
661 let dir = DirectionalHyperParam::new_compact(
662 gam_solve::estimate::reml::HyperDesignDerivative::from_latent_coord(
663 operator.clone(),
664 flat_axis,
665 global_range.clone(),
666 p_total,
667 ),
668 Vec::new(),
669 None,
670 None,
671 )?
672 .not_penalty_like();
673 hyper_dirs.push(dir);
674 }
675 let direct_dim = latent_coord_direct_hyper_count(latent.id_mode(), latent.latent_dim());
676 if analytic_rho_count + direct_dim > 0 {
677 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::from(Array2::<f64>::zeros((
678 design.design.nrows(),
679 p_total,
680 )));
681 for _ in 0..analytic_rho_count {
682 hyper_dirs.push(
683 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
684 .not_penalty_like(),
685 );
686 }
687 for _ in 0..direct_dim {
688 hyper_dirs.push(
689 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
690 .not_penalty_like(),
691 );
692 }
693 }
694 Ok(Some(hyper_dirs))
695}
696
697fn latent_coord_direct_hyper_count(
698 id_mode: &gam_terms::latent::LatentIdMode,
699 latent_dim: usize,
700) -> usize {
701 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
702 match id_mode {
703 LatentIdMode::AuxPrior { strength, .. } => match strength {
704 AuxPriorStrength::Auto => 1,
705 AuxPriorStrength::Fixed(_) => 0,
706 },
707 LatentIdMode::AuxPriorDimSelection { strength, .. } => {
708 latent_dim
709 + match strength {
710 AuxPriorStrength::Auto => 1,
711 AuxPriorStrength::Fixed(_) => 0,
712 }
713 }
714 LatentIdMode::DimSelection { .. } => latent_dim,
715 LatentIdMode::IsometryToReference { strength, .. } => match strength {
718 AuxPriorStrength::Auto => 1,
719 AuxPriorStrength::Fixed(_) => 0,
720 },
721 LatentIdMode::AuxOutcome { head, .. } => head.n_coeffs(latent_dim) + latent_dim,
724 LatentIdMode::None => 0,
725 }
726}
727
728fn latent_coord_initial_direct_hypers(
729 id_mode: &gam_terms::latent::LatentIdMode,
730 latent_dim: usize,
731) -> Result<Array1<f64>, EstimationError> {
732 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
733 let mut values = Vec::with_capacity(latent_coord_direct_hyper_count(id_mode, latent_dim));
734 match id_mode {
735 LatentIdMode::AuxPrior { strength, .. } => {
736 if matches!(strength, AuxPriorStrength::Auto) {
737 values.push(0.0);
738 }
739 }
740 LatentIdMode::AuxPriorDimSelection {
741 strength,
742 init_log_precision,
743 ..
744 } => {
745 if matches!(strength, AuxPriorStrength::Auto) {
746 values.push(0.0);
747 }
748 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
749 }
750 LatentIdMode::DimSelection { init_log_precision } => {
751 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
752 }
753 LatentIdMode::IsometryToReference { strength, .. } => {
754 if matches!(strength, AuxPriorStrength::Auto) {
755 values.push(0.0);
756 }
757 }
758 LatentIdMode::AuxOutcome {
759 head,
760 init_log_precision,
761 } => {
762 values.extend(std::iter::repeat_n(0.0, head.n_coeffs(latent_dim)));
766 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
767 }
768 LatentIdMode::None => {}
769 }
770 Ok(Array1::from_vec(values))
771}
772
773fn append_latent_ard_seed(
774 values: &mut Vec<f64>,
775 init: Option<&Array1<f64>>,
776 latent_dim: usize,
777) -> Result<(), EstimationError> {
778 if let Some(init) = init {
779 if init.len() != latent_dim {
780 crate::bail_invalid_estim!(
781 "latent dim_selection init_log_precision length mismatch: got {}, expected {}",
782 init.len(),
783 latent_dim
784 );
785 }
786 values.extend(init.iter().copied());
787 } else {
788 values.extend(std::iter::repeat_n(0.0, latent_dim));
789 }
790 Ok(())
791}
792
/// Additive contribution of an auxiliary objective term to an evaluated outer objective.
///
/// Used both by the latent-identification objective and by the analytic penalties.
struct LatentIdObjectiveContribution {
    /// Value added to the objective.
    cost: f64,
    /// Gradient with respect to the full `theta` vector (same length as `theta`).
    gradient: Array1<f64>,
}
797
798fn latent_id_objective_contribution(
799 theta: &Array1<f64>,
800 rho_dim: usize,
801 analytic_rho_count: usize,
802 latent: &gam_terms::latent::LatentCoordValues,
803) -> Result<LatentIdObjectiveContribution, EstimationError> {
804 use gam_terms::latent::{AuxPriorStrength, LatentIdMode, aux_prior_targets};
805 let n_obs = latent.n_obs();
806 let latent_dim = latent.latent_dim();
807 let flat_len = latent.len();
808 let mut gradient = Array1::<f64>::zeros(theta.len());
809 let t_start = rho_dim;
810 let direct_start = t_start + flat_len + analytic_rho_count;
811 if theta.len() < direct_start {
812 crate::bail_invalid_estim!(
813 "latent-coordinate theta too short for id objective: got {}, need at least {}",
814 theta.len(),
815 direct_start
816 );
817 }
818 let t = latent.as_matrix();
819 let mut cost = 0.0;
820 let mut cursor = direct_start;
821
822 match latent.id_mode() {
823 LatentIdMode::AuxPrior {
824 u,
825 family,
826 strength,
827 }
828 | LatentIdMode::AuxPriorDimSelection {
829 u,
830 family,
831 strength,
832 ..
833 } => {
834 let (log_mu, mu) = match strength {
835 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
836 AuxPriorStrength::Auto => {
837 let log_mu = theta[cursor];
838 cursor += 1;
839 (log_mu, log_mu.exp())
840 }
841 };
842 let targets = aux_prior_targets(t.view(), u.view(), *family)
843 .map_err(EstimationError::InvalidInput)?;
844 let residual = &t - &targets;
845 let q = residual.iter().map(|v| v * v).sum::<f64>();
846 let k = (n_obs * latent_dim) as f64;
853 cost += 0.5 * mu * q - 0.5 * k * log_mu;
854
855 let projected_residual = aux_prior_targets(residual.view(), u.view(), *family)
856 .map_err(EstimationError::InvalidInput)?;
857 let grad_base = residual - projected_residual;
858 for n in 0..n_obs {
859 for axis in 0..latent_dim {
860 gradient[t_start + n * latent_dim + axis] += mu * grad_base[[n, axis]];
861 }
862 }
863 if matches!(strength, AuxPriorStrength::Auto) {
864 gradient[direct_start] += 0.5 * mu * q - 0.5 * k;
865 }
866 }
867 LatentIdMode::IsometryToReference { reference, strength } => {
868 if reference.dim() != (n_obs, latent_dim) {
875 crate::bail_invalid_estim!(
876 "IsometryToReference reference shape {:?} must equal (n_obs, latent_dim) = ({}, {})",
877 reference.dim(),
878 n_obs,
879 latent_dim
880 );
881 }
882 let mu_slot = cursor;
883 let (log_mu, mu) = match strength {
884 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
885 AuxPriorStrength::Auto => {
886 let log_mu = theta[cursor];
887 cursor += 1;
888 (log_mu, log_mu.exp())
889 }
890 };
891 let residual = &t - reference;
892 let q = residual.iter().map(|v| v * v).sum::<f64>();
893 let k = (n_obs * latent_dim) as f64;
897 cost += 0.5 * mu * q - 0.5 * k * log_mu;
898 for n in 0..n_obs {
899 for axis in 0..latent_dim {
900 gradient[t_start + n * latent_dim + axis] += mu * residual[[n, axis]];
901 }
902 }
903 if matches!(strength, AuxPriorStrength::Auto) {
904 gradient[mu_slot] += 0.5 * mu * q - 0.5 * k;
905 }
906 }
907 LatentIdMode::AuxOutcome { head, .. } => {
908 let n_coeffs = head.n_coeffs(latent_dim);
916 let coeffs = theta
917 .slice(ndarray::s![cursor..cursor + n_coeffs])
918 .to_owned();
919 let (head_nll, grad_coeffs, grad_t) = head
920 .neg_loglik_and_grad(t.view(), coeffs.view())
921 .map_err(EstimationError::InvalidInput)?;
922 cost += head_nll;
923 for (offset, &g) in grad_coeffs.iter().enumerate() {
924 gradient[cursor + offset] += g;
925 }
926 for n in 0..n_obs {
927 for axis in 0..latent_dim {
928 gradient[t_start + n * latent_dim + axis] += grad_t[[n, axis]];
929 }
930 }
931 cursor += n_coeffs;
932 }
933 LatentIdMode::DimSelection { .. } | LatentIdMode::None => {}
934 }
935
936 match latent.id_mode() {
937 LatentIdMode::AuxPriorDimSelection { .. }
938 | LatentIdMode::DimSelection { .. }
939 | LatentIdMode::AuxOutcome { .. } => {
940 for axis in 0..latent_dim {
941 let log_alpha = theta[cursor + axis];
942 let alpha = log_alpha.exp();
943 let mut q_axis = 0.0;
944 for n in 0..n_obs {
945 let flat_idx = n * latent_dim + axis;
946 let value = latent.as_flat()[flat_idx];
947 q_axis += value * value;
948 gradient[t_start + flat_idx] += alpha * value;
949 }
950 cost += 0.5 * alpha * q_axis - 0.5 * n_obs as f64 * log_alpha;
951 gradient[cursor + axis] += 0.5 * alpha * q_axis - 0.5 * n_obs as f64;
952 }
953 cursor += latent_dim;
954 }
955 LatentIdMode::AuxPrior { .. }
956 | LatentIdMode::IsometryToReference { .. }
957 | LatentIdMode::None => {}
958 }
959
960 if cursor != theta.len() {
961 crate::bail_invalid_estim!(
962 "latent-coordinate direct hyperparameter length mismatch: consumed {}, theta len {}",
963 cursor,
964 theta.len()
965 );
966 }
967 Ok(LatentIdObjectiveContribution { cost, gradient })
968}
969
970fn add_latent_id_objective_to_eval(
971 theta: &Array1<f64>,
972 rho_dim: usize,
973 analytic_rho_count: usize,
974 latent: &gam_terms::latent::LatentCoordValues,
975 eval: &mut (
976 f64,
977 Array1<f64>,
978 gam_problem::HessianResult,
979 ),
980) -> Result<(), EstimationError> {
981 let contribution =
982 latent_id_objective_contribution(theta, rho_dim, analytic_rho_count, latent)?;
983 eval.0 += contribution.cost;
984 if eval.1.len() != contribution.gradient.len() {
985 crate::bail_invalid_estim!(
986 "latent-coordinate REML gradient length mismatch: base={}, id={}",
987 eval.1.len(),
988 contribution.gradient.len()
989 );
990 }
991 eval.1 += &contribution.gradient;
992 if eval.2.is_analytic() {
993 eval.2 = gam_problem::HessianResult::Unavailable;
994 }
995 Ok(())
996}
997
998fn analytic_penalty_objective_contribution(
999 theta: &Array1<f64>,
1000 rho_dim: usize,
1001 latent: &gam_terms::latent::LatentCoordValues,
1002 registry: &gam_terms::AnalyticPenaltyRegistry,
1003) -> Result<LatentIdObjectiveContribution, EstimationError> {
1004 let flat_len = latent.len();
1005 let t_start = rho_dim;
1006 let t_end = t_start + flat_len;
1007 let rho_start = t_end;
1008 let rho_end = rho_start + registry.total_rho_count();
1009 if theta.len() < rho_end {
1010 crate::bail_invalid_estim!(
1011 "latent-coordinate theta too short for analytic penalties: got {}, need at least {}",
1012 theta.len(),
1013 rho_end
1014 );
1015 }
1016 let target_t = theta.slice(s![t_start..t_end]);
1017 let rho = theta.slice(s![rho_start..rho_end]);
1018 let mut cost = 0.0_f64;
1019 let mut gradient = Array1::<f64>::zeros(theta.len());
1020 for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(registry.rho_layout()) {
1021 let rho_local = rho.slice(s![rho_slice.clone()]);
1022 match tier {
1023 gam_terms::PenaltyTier::Psi => {
1024 cost += penalty.value(target_t.view(), rho_local);
1025 let grad = penalty.grad_target(target_t.view(), rho_local);
1026 if grad.len() != flat_len {
1027 crate::bail_invalid_estim!(
1028 "analytic penalty {name:?} gradient length mismatch: got {}, expected {}",
1029 grad.len(),
1030 flat_len
1031 );
1032 }
1033 for i in 0..flat_len {
1034 gradient[t_start + i] += grad[i];
1035 }
1036 let grad_rho_local = penalty.grad_rho(target_t.view(), rho_local);
1037 if grad_rho_local.len() != rho_slice.len() {
1038 crate::bail_invalid_estim!(
1039 "analytic penalty {name:?} rho-gradient length mismatch: got {}, expected {}",
1040 grad_rho_local.len(),
1041 rho_slice.len()
1042 );
1043 }
1044 for local_idx in 0..grad_rho_local.len() {
1045 gradient[rho_start + rho_slice.start + local_idx] += grad_rho_local[local_idx];
1046 }
1047 }
1048 gam_terms::PenaltyTier::Beta => {}
1049 gam_terms::PenaltyTier::Rho => {}
1050 }
1051 }
1052 Ok(LatentIdObjectiveContribution { cost, gradient })
1053}
1054
1055fn add_analytic_penalty_hessian_to_eval(
1056 theta: &Array1<f64>,
1057 rho_dim: usize,
1058 latent: &gam_terms::latent::LatentCoordValues,
1059 registry: &gam_terms::AnalyticPenaltyRegistry,
1060 eval: &mut (
1061 f64,
1062 Array1<f64>,
1063 gam_problem::HessianResult,
1064 ),
1065) -> Result<(), EstimationError> {
1066 let flat_len = latent.len();
1067 let t_start = rho_dim;
1068 let t_end = t_start + flat_len;
1069 let rho_start = t_end;
1070 let rho_end = rho_start + registry.total_rho_count();
1071 if theta.len() < rho_end {
1072 crate::bail_invalid_estim!(
1073 "latent-coordinate theta too short for analytic penalty Hessian: got {}, need at least {}",
1074 theta.len(),
1075 rho_end
1076 );
1077 }
1078 let gam_problem::HessianResult::Analytic(hessian) = &mut eval.2 else {
1079 if eval.2.is_analytic() {
1080 eval.2 = gam_problem::HessianResult::Unavailable;
1081 }
1082 return Ok(());
1083 };
1084 if hessian.dim() != (theta.len(), theta.len()) {
1085 crate::bail_invalid_estim!(
1086 "analytic penalty Hessian target shape mismatch: got {}x{}, expected {}x{}",
1087 hessian.nrows(),
1088 hessian.ncols(),
1089 theta.len(),
1090 theta.len()
1091 );
1092 }
1093 let target_t = theta.slice(s![t_start..t_end]);
1094 let rho = theta.slice(s![rho_start..rho_end]);
1095 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(registry.rho_layout())
1096 {
1097 let rho_local = rho.slice(s![rho_slice]);
1098 if !matches!(tier, gam_terms::PenaltyTier::Psi) {
1099 continue;
1100 }
1101 if let Some(diag) = penalty.hessian_diag(target_t.view(), rho_local) {
1102 if diag.len() != flat_len {
1103 crate::bail_invalid_estim!(
1104 "analytic penalty Hessian diagonal length mismatch: got {}, expected {}",
1105 diag.len(),
1106 flat_len
1107 );
1108 }
1109 for i in 0..flat_len {
1110 hessian[[t_start + i, t_start + i]] += diag[i];
1111 }
1112 continue;
1113 }
1114 let mut probe = Array1::<f64>::zeros(flat_len);
1115 for col in 0..flat_len {
1116 probe[col] = 1.0;
1117 let hv = penalty.hvp(target_t.view(), rho_local, probe.view());
1118 if hv.len() != flat_len {
1119 crate::bail_invalid_estim!(
1120 "analytic penalty Hessian-vector length mismatch: got {}, expected {}",
1121 hv.len(),
1122 flat_len
1123 );
1124 }
1125 for row in 0..flat_len {
1126 hessian[[t_start + row, t_start + col]] += hv[row];
1127 }
1128 probe[col] = 0.0;
1129 }
1130 }
1131 Ok(())
1132}
1133
1134fn add_analytic_penalty_objective_to_eval(
1135 theta: &Array1<f64>,
1136 rho_dim: usize,
1137 latent: &gam_terms::latent::LatentCoordValues,
1138 registry: &gam_terms::AnalyticPenaltyRegistry,
1139 eval: &mut (
1140 f64,
1141 Array1<f64>,
1142 gam_problem::HessianResult,
1143 ),
1144) -> Result<(), EstimationError> {
1145 let contribution = analytic_penalty_objective_contribution(theta, rho_dim, latent, registry)?;
1146 eval.0 += contribution.cost;
1147 if eval.1.len() != contribution.gradient.len() {
1148 crate::bail_invalid_estim!(
1149 "latent-coordinate REML gradient length mismatch: base={}, analytic_penalty={}",
1150 eval.1.len(),
1151 contribution.gradient.len()
1152 );
1153 }
1154 eval.1 += &contribution.gradient;
1155 add_analytic_penalty_hessian_to_eval(theta, rho_dim, latent, registry, eval)?;
1156 Ok(())
1157}
1158
1159fn spatial_log_kappa_hyper_dirs_frominfo_list(
1160 info_list: Vec<SpatialPsiDerivative>,
1161) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1162 use gam_solve::estimate::reml::ImplicitDerivLevel;
1163 use std::collections::HashMap;
1164
1165 let log_kappa_dim = info_list.len();
1166 let group_ids: Vec<Option<usize>> = info_list.iter().map(|e| e.aniso_group_id).collect();
1172 let mut group_indices_map: HashMap<usize, Vec<usize>> = HashMap::new();
1173 for (idx, gid) in group_ids.iter().enumerate() {
1174 if let Some(g) = gid {
1175 group_indices_map.entry(*g).or_default().push(idx);
1176 }
1177 }
1178
1179 let mut hyper_dirs = Vec::with_capacity(log_kappa_dim);
1180 for (i, info) in info_list.into_iter().enumerate() {
1181 let SpatialPsiDerivative {
1182 penalty_index: _,
1183 penalty_indices,
1184 global_range,
1185 total_p,
1186 x_psi_local,
1187 s_psi_components_local,
1188 x_psi_psi_local,
1189 s_psi_psi_components_local,
1190 aniso_group_id,
1191 aniso_cross_designs,
1192 aniso_cross_penalty_provider,
1193 implicit_operator,
1194 implicit_axis,
1195 } = info;
1196
1197 let mut xsecond = vec![None; log_kappa_dim];
1198 xsecond[i] = Some(if let Some(ref op) = implicit_operator {
1200 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
1201 op.clone(),
1202 ImplicitDerivLevel::SecondDiag(implicit_axis),
1203 global_range.clone(),
1204 total_p,
1205 )
1206 } else {
1207 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
1208 x_psi_psi_local,
1209 global_range.clone(),
1210 total_p,
1211 )
1212 });
1213 if let Some(cross_designs) = aniso_cross_designs {
1215 if let Some(gid) = aniso_group_id {
1219 let base = group_indices_map
1220 .get(&gid)
1221 .and_then(|v| v.first().copied())
1222 .unwrap_or(i);
1223 for (b_axis, cross_mat) in cross_designs.into_iter() {
1224 let j = base + b_axis;
1225 if j < log_kappa_dim {
1226 xsecond[j] = Some(if let Some(ref op) = implicit_operator {
1227 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
1228 op.clone(),
1229 ImplicitDerivLevel::SecondCross(implicit_axis, b_axis),
1230 global_range.clone(),
1231 total_p,
1232 )
1233 } else {
1234 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
1235 cross_mat,
1236 global_range.clone(),
1237 total_p,
1238 )
1239 });
1240 }
1241 }
1242 }
1243 }
1244 let s_components = penalty_indices
1245 .iter()
1246 .copied()
1247 .zip(s_psi_components_local.into_iter().map(|local| {
1248 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1249 local,
1250 global_range.clone(),
1251 total_p,
1252 )
1253 }))
1254 .collect::<Vec<_>>();
1255 let s2_components = penalty_indices
1256 .iter()
1257 .copied()
1258 .zip(s_psi_psi_components_local.into_iter().map(|local| {
1259 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1260 local,
1261 global_range.clone(),
1262 total_p,
1263 )
1264 }))
1265 .collect::<Vec<_>>();
1266 let mut ssecond_components = vec![None; log_kappa_dim];
1267 ssecond_components[i] = Some(s2_components);
1268 let mut penaltysecond_partner_indices: Option<Vec<usize>> = None;
1269 let penaltysecond_component_provider =
1270 if let (Some(provider), Some(gid)) = (aniso_cross_penalty_provider, aniso_group_id) {
1271 let group_indices = group_indices_map.get(&gid).cloned().unwrap_or_default();
1272 let axis_in_group =
1273 group_indices
1274 .iter()
1275 .position(|&idx| idx == i)
1276 .ok_or_else(|| {
1277 EstimationError::InvalidInput(format!(
1278 "missing spatial hyper axis {} in anisotropy group {}",
1279 i, gid
1280 ))
1281 })?;
1282 penaltysecond_partner_indices = Some(
1283 group_indices
1284 .iter()
1285 .copied()
1286 .filter(|&idx| idx != i)
1287 .collect(),
1288 );
1289 let penalty_indices_inner = penalty_indices.clone();
1290 let global_range_inner = global_range.clone();
1291 let total_p_inner = total_p;
1292 let group_indices_inner = group_indices;
1293 Some(std::sync::Arc::new(
1294 move |j: usize| -> Result<
1295 Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
1296 EstimationError,
1297 > {
1298 let Some(other_axis_in_group) =
1299 group_indices_inner.iter().position(|&idx| idx == j)
1300 else {
1301 return Ok(None);
1302 };
1303 if other_axis_in_group == axis_in_group {
1304 return Ok(None);
1305 }
1306 let cross_pens = provider(other_axis_in_group)?;
1307 if cross_pens.is_empty() {
1308 return Ok(None);
1309 }
1310 Ok(Some(
1311 penalty_indices_inner
1312 .iter()
1313 .copied()
1314 .zip(cross_pens.into_iter().map(|local| {
1315 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1316 local,
1317 global_range_inner.clone(),
1318 total_p_inner,
1319 )
1320 }))
1321 .map(|(penalty_index, matrix)| {
1322 gam_solve::estimate::reml::PenaltyDerivativeComponent {
1323 penalty_index,
1324 matrix,
1325 }
1326 })
1327 .collect(),
1328 ))
1329 },
1330 )
1331 as std::sync::Arc<
1332 dyn Fn(
1333 usize,
1334 ) -> Result<
1335 Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
1336 EstimationError,
1337 > + Send
1338 + Sync
1339 + 'static,
1340 >)
1341 } else {
1342 None
1343 };
1344 let x_first_hyper = if let Some(ref op) = implicit_operator {
1347 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
1348 op.clone(),
1349 ImplicitDerivLevel::First(implicit_axis),
1350 global_range.clone(),
1351 total_p,
1352 )
1353 } else {
1354 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
1355 x_psi_local,
1356 global_range.clone(),
1357 total_p,
1358 )
1359 };
1360 let mut dir = DirectionalHyperParam::new_compact(
1361 x_first_hyper,
1362 s_components,
1363 Some(xsecond),
1364 Some(ssecond_components),
1365 )?
1366 .not_penalty_like();
1367 if let Some(provider) = penaltysecond_component_provider {
1368 dir = dir.with_penaltysecond_component_provider(provider);
1369 }
1370 if let Some(partner_indices) = penaltysecond_partner_indices {
1371 dir = dir.with_penaltysecond_partner_indices(partner_indices);
1372 }
1373 hyper_dirs.push(dir);
1374 }
1375 Ok(hyper_dirs)
1376}
1377
1378pub(crate) fn spatial_dims_per_term(
1384 resolvedspec: &TermCollectionSpec,
1385 spatial_terms: &[usize],
1386) -> Vec<usize> {
1387 spatial_terms
1388 .iter()
1389 .map(|&term_idx| {
1390 if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
1391 measure_jet_psi_dim(mj)
1394 } else if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
1395 get_spatial_feature_dim(resolvedspec, term_idx).unwrap_or(1)
1396 } else {
1397 1
1398 }
1399 })
1400 .collect()
1401}
1402
1403fn has_aniso_terms(resolvedspec: &TermCollectionSpec, spatial_terms: &[usize]) -> bool {
1407 spatial_terms
1408 .iter()
1409 .any(|&term_idx| spatial_term_uses_per_axis_psi(resolvedspec, term_idx))
1410}
1411
// Generates `memoized_cost`, `memoized_eval` and `store_eval` for a cache type with
// fields `current_theta: Option<Array1<f64>>`, `last_cost: Option<f64>` and
// `last_eval: Option<(f64, Array1<f64>, gam_problem::HessianResult)>`.
//
// A memo hit requires `theta` to match `current_theta`. `store_eval` records no θ of
// its own. NOTE(review): presumably the expanding type clears `last_cost`/`last_eval`
// whenever `current_theta` changes — confirm at each expansion site.
macro_rules! impl_exact_joint_theta_memo {
    () => {
        // Cached cost at `theta`; prefers the full evaluation's cost over `last_cost`.
        fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
            let cached_theta = self.current_theta.as_ref()?;
            if !theta_values_match(cached_theta, theta) {
                return None;
            }
            match &self.last_eval {
                Some(eval) => Some(eval.0),
                None => self.last_cost,
            }
        }

        // Cached full evaluation at `theta`, if one was stored.
        fn memoized_eval(
            &self,
            theta: &Array1<f64>,
        ) -> Option<(
            f64,
            Array1<f64>,
            gam_problem::HessianResult,
        )> {
            let cached_theta = self.current_theta.as_ref()?;
            if theta_values_match(cached_theta, theta) {
                self.last_eval.clone()
            } else {
                None
            }
        }

        // Records a full evaluation (and its cost) for the current θ.
        fn store_eval(
            &mut self,
            eval: (
                f64,
                Array1<f64>,
                gam_problem::HessianResult,
            ),
        ) {
            let cost = eval.0;
            self.last_eval = Some(eval);
            self.last_cost = Some(cost);
        }
    };
}
1466
/// Design cache for a single-block fit whose θ is `[ρ (rho_dim entries) | log-κ tail]`.
///
/// The design lives in an incremental realizer that `ensure_theta` moves to a new
/// log-κ point. Objective values are memoised separately, keyed by the θ they were
/// stored at (`last_eval_theta`).
struct SingleBlockExactJointDesignCache<'d> {
    /// Incremental realizer that owns the resolved spec and the current design.
    realizer: FrozenTermCollectionIncrementalRealizer<'d>,
    /// θ whose log-κ tail was last applied to `realizer` (`None` before the first call).
    current_theta: Option<Array1<f64>>,
    /// θ at which `last_cost` / `last_eval` were stored; memo lookups compare against it.
    last_eval_theta: Option<Array1<f64>>,
    /// Most recently stored cost.
    last_cost: Option<f64>,
    /// Most recently stored full evaluation (cost, gradient, Hessian result).
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Hyper-directions built for a design, tagged with that design's revision.
    cached_hyper_dirs: Option<(u64, Vec<DirectionalHyperParam>)>,
    /// Indices of the smooth terms whose log-κ is optimised.
    spatial_terms: Vec<usize>,
    /// Number of leading ρ entries in θ before the log-κ tail.
    rho_dim: usize,
    /// Log-κ coordinates per spatial term (passed to
    /// `SpatialLogKappaCoords::from_theta_tail_with_dims`).
    dims_per_term: Vec<usize>,
}
1499
1500impl<'d> SingleBlockExactJointDesignCache<'d> {
1501 fn new(
1502 data: ArrayView2<'d, f64>,
1503 spec: TermCollectionSpec,
1504 design: TermCollectionDesign,
1505 spatial_terms: Vec<usize>,
1506 rho_dim: usize,
1507 dims_per_term: Vec<usize>,
1508 ) -> Result<Self, String> {
1509 Ok(Self {
1510 realizer: FrozenTermCollectionIncrementalRealizer::new(data, spec, design)?,
1511 current_theta: None,
1512 last_eval_theta: None,
1513 last_cost: None,
1514 last_eval: None,
1515 cached_hyper_dirs: None,
1516 spatial_terms,
1517 rho_dim,
1518 dims_per_term,
1519 })
1520 }
1521
1522 fn design_revision(&self) -> u64 {
1523 self.realizer.design_revision()
1524 }
1525
1526 fn hyper_dirs_for_current_design(
1536 &mut self,
1537 data: ArrayView2<'_, f64>,
1538 kind: SpatialHyperKind,
1539 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1540 let revision = self.realizer.design_revision();
1541 if let Some((cached_rev, dirs)) = self.cached_hyper_dirs.as_ref()
1542 && *cached_rev == revision
1543 {
1544 return Ok(dirs.clone());
1545 }
1546 let dirs = try_build_spatial_log_kappa_hyper_dirs(
1547 data,
1548 self.realizer.spec(),
1549 self.realizer.design(),
1550 &self.spatial_terms,
1551 )?
1552 .ok_or_else(|| {
1553 EstimationError::InvalidInput(format!(
1554 "failed to build {} hyper_dirs at current {}",
1555 kind.adjective(),
1556 kind.coord_name(),
1557 ))
1558 })?;
1559 self.cached_hyper_dirs = Some((revision, dirs.clone()));
1560 Ok(dirs)
1561 }
1562
1563 fn nfree_tensor_gradient_hyper_dirs(
1564 &mut self,
1565 theta: &Array1<f64>,
1566 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1567 let psi = &theta.as_slice().ok_or_else(|| {
1568 EstimationError::InvalidInput(
1569 "nfree_tensor_gradient_hyper_dirs: theta is not contiguous".to_string(),
1570 )
1571 })?[self.rho_dim..];
1572 let (global_range, p_total, s_psi_components) = self
1573 .realizer
1574 .canonical_penalty_derivatives_at_psi(&self.spatial_terms, psi)
1575 .map_err(EstimationError::InvalidInput)?;
1576 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::zero(
1577 self.realizer.design().design.nrows(),
1578 p_total,
1579 );
1580 let components = s_psi_components
1581 .into_iter()
1582 .enumerate()
1583 .map(|(penalty_index, local)| {
1584 (
1585 penalty_index,
1586 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1587 local,
1588 global_range.clone(),
1589 p_total,
1590 ),
1591 )
1592 })
1593 .collect::<Vec<_>>();
1594 Ok(DirectionalHyperParam::new_compact(zero_x, components, None, None)?.not_penalty_like())
1595 .map(|dir| vec![dir])
1596 }
1597
1598 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1599 if self
1600 .current_theta
1601 .as_ref()
1602 .is_some_and(|cached| theta_values_match(cached, theta))
1603 {
1604 return Ok(());
1605 }
1606 let t_ensure = std::time::Instant::now();
1607 let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
1608 theta,
1609 self.rho_dim,
1610 self.dims_per_term.clone(),
1611 );
1612 self.realizer
1613 .apply_log_kappa(&log_kappa, &self.spatial_terms)?;
1614 log::info!(
1615 "[STAGE] ensure_theta (apply_log_kappa, {} terms): {:.3}s",
1616 self.spatial_terms.len(),
1617 t_ensure.elapsed().as_secs_f64(),
1618 );
1619 self.current_theta = Some(theta.clone());
1620 self.last_eval_theta = None;
1621 self.last_cost = None;
1622 self.last_eval = None;
1623 Ok(())
1624 }
1625
1626 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1633 if self
1634 .last_eval_theta
1635 .as_ref()
1636 .is_some_and(|cached| theta_values_match(cached, theta))
1637 {
1638 self.last_eval
1639 .as_ref()
1640 .map(|cached| cached.0)
1641 .or(self.last_cost)
1642 } else {
1643 None
1644 }
1645 }
1646
1647 fn memoized_eval(
1648 &self,
1649 theta: &Array1<f64>,
1650 ) -> Option<(
1651 f64,
1652 Array1<f64>,
1653 gam_problem::HessianResult,
1654 )> {
1655 if self
1656 .last_eval_theta
1657 .as_ref()
1658 .is_some_and(|cached| theta_values_match(cached, theta))
1659 {
1660 self.last_eval.clone()
1661 } else {
1662 None
1663 }
1664 }
1665
1666 fn store_eval_at(
1670 &mut self,
1671 theta: &Array1<f64>,
1672 eval: (
1673 f64,
1674 Array1<f64>,
1675 gam_problem::HessianResult,
1676 ),
1677 ) {
1678 self.last_eval_theta = Some(theta.clone());
1679 self.last_cost = Some(eval.0);
1680 self.last_eval = Some(eval);
1681 }
1682
1683 fn store_cost_at(&mut self, theta: &Array1<f64>, cost: f64) {
1686 self.last_eval_theta = Some(theta.clone());
1687 self.last_cost = Some(cost);
1688 self.last_eval = None;
1692 }
1693
1694 fn spec(&self) -> &TermCollectionSpec {
1695 self.realizer.spec()
1696 }
1697
1698 fn design(&self) -> &TermCollectionDesign {
1699 self.realizer.design()
1700 }
1701
1702 fn supports_nfree_penalty_rekey(&self) -> bool {
1708 self.realizer
1709 .supports_nfree_penalty_rekey(&self.spatial_terms)
1710 }
1711
1712 fn supports_nfree_gradient_only_routing(&self) -> bool {
1713 self.realizer
1714 .supports_nfree_gradient_only_routing(&self.spatial_terms)
1715 }
1716
1717 fn canonical_penalties_at(
1727 &mut self,
1728 theta: &Array1<f64>,
1729 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
1730 let psi = &theta
1731 .as_slice()
1732 .ok_or_else(|| "canonical_penalties_at: theta is not contiguous".to_string())?
1733 [self.rho_dim..];
1734 self.realizer
1735 .canonical_penalties_at_psi(&self.spatial_terms, psi)
1736 }
1737}
1738
/// Design cache for a single-block fit in which smooth term `term_index` is evaluated
/// at latent coordinates that are themselves part of θ.
///
/// `ensure_theta` writes the latent block of θ into the `feature_cols` columns of the
/// owned `data`, then rebuilds the design and hyper-directions or reuses them from
/// `latent_design_cache`. Objective memos are only valid for `current_theta` within
/// the outer iteration recorded in `last_outer_iter`.
struct SingleBlockLatentCoordDesignCache {
    /// Owned copy of the data; the latent feature columns are overwritten in place.
    data: Array2<f64>,
    /// Resolved term-collection spec used to rebuild the design.
    spec: TermCollectionSpec,
    /// Currently realized design (replaced by each successful `ensure_theta`).
    design: TermCollectionDesign,
    /// θ at which `design` was last realized.
    current_theta: Option<Array1<f64>>,
    /// Latent coordinates decoded from `current_theta`.
    current_latent: Option<std::sync::Arc<gam_terms::latent::LatentCoordValues>>,
    /// Hyper-directions matching `design`.
    current_hyper_dirs: Option<Vec<gam_solve::estimate::reml::DirectionalHyperParam>>,
    /// Entry id in `latent_design_cache` that produced `design`.
    current_design_cache_id: Option<u64>,
    /// Cache of rebuilt designs and their hyper-directions.
    latent_design_cache: gam_solve::latent_cache::LatentDesignCache,
    /// Memoised cost (valid only while `last_outer_iter` is current).
    last_cost: Option<f64>,
    /// Memoised full evaluation: cost, gradient, Hessian result.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Smooth term whose inputs are the latent coordinates.
    term_index: gam_problem::types::SmoothTermIdx,
    /// Data columns receiving the latent axes; axis `k` goes to `feature_cols[k]`.
    feature_cols: Vec<usize>,
    /// Number of leading ρ entries in θ.
    rho_dim: usize,
    /// Number of observations carrying latent coordinates.
    n_obs: usize,
    /// Number of latent coordinates per observation.
    latent_dim: usize,
    /// ID mode copied from the initial latent values; used to rebuild them from θ.
    id_mode: gam_terms::latent::LatentIdMode,
    /// Manifold copied from the initial latent values.
    manifold: gam_terms::latent::LatentManifold,
    /// Retraction registry copied from the initial latent values.
    retraction_registry: gam_solve::latent_cache::LatentRetractionRegistry,
    /// Latent id copied from the initial latent values.
    latent_id: u64,
    /// Optional analytic penalties on the latent coordinates.
    analytic_penalties: Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>>,
    /// Total ρ count of `analytic_penalties` (0 when absent).
    analytic_rho_count: usize,
    /// Wrapping counter bumped when the latent values or the backing cache entry change.
    design_revision: u64,
    /// Outer iteration at which the memo was last stored.
    last_outer_iter: Option<u64>,
}
1771
1772impl SingleBlockLatentCoordDesignCache {
1773 fn new(
1774 data: Array2<f64>,
1775 spec: TermCollectionSpec,
1776 design: TermCollectionDesign,
1777 latent: &StandardLatentCoordConfig,
1778 rho_dim: usize,
1779 ) -> Result<Self, String> {
1780 if latent.term_index.get() >= spec.smooth_terms.len() {
1781 return Err(SmoothError::dimension_mismatch(format!(
1782 "latent-coordinate term index {} out of bounds for {} smooth terms",
1783 latent.term_index,
1784 spec.smooth_terms.len()
1785 ))
1786 .into());
1787 }
1788 if latent.feature_cols.len() != latent.values.latent_dim() {
1789 return Err(SmoothError::dimension_mismatch(format!(
1790 "latent-coordinate feature width mismatch: feature_cols={}, latent_dim={}",
1791 latent.feature_cols.len(),
1792 latent.values.latent_dim()
1793 ))
1794 .into());
1795 }
1796 if latent.values.n_obs() != data.nrows() {
1797 return Err(SmoothError::dimension_mismatch(format!(
1798 "latent-coordinate row mismatch: latent n={}, data n={}",
1799 latent.values.n_obs(),
1800 data.nrows()
1801 ))
1802 .into());
1803 }
1804 let analytic_rho_count = latent
1805 .analytic_penalties
1806 .as_ref()
1807 .map_or(0, |registry| registry.total_rho_count());
1808 Ok(Self {
1809 data,
1810 spec,
1811 design,
1812 current_theta: None,
1813 current_latent: None,
1814 current_hyper_dirs: None,
1815 current_design_cache_id: None,
1816 latent_design_cache: gam_solve::latent_cache::LatentDesignCache::default(),
1817 last_cost: None,
1818 last_eval: None,
1819 term_index: latent.term_index,
1820 feature_cols: latent.feature_cols.clone(),
1821 rho_dim,
1822 n_obs: latent.values.n_obs(),
1823 latent_dim: latent.values.latent_dim(),
1824 id_mode: latent.values.id_mode().clone(),
1825 manifold: latent.values.manifold().clone(),
1826 retraction_registry: latent.values.retraction_registry().clone(),
1827 latent_id: latent.values.latent_id(),
1828 analytic_penalties: latent.analytic_penalties.clone(),
1829 analytic_rho_count,
1830 design_revision: 0,
1831 last_outer_iter: None,
1832 })
1833 }
1834
1835 fn design_revision(&self) -> u64 {
1836 self.design_revision
1837 }
1838
1839 fn design(&self) -> &TermCollectionDesign {
1840 &self.design
1841 }
1842
1843 fn latent(&self) -> Result<std::sync::Arc<gam_terms::latent::LatentCoordValues>, String> {
1844 self.current_latent
1845 .as_ref()
1846 .cloned()
1847 .ok_or_else(|| "latent-coordinate cache has not been realized".to_string())
1848 }
1849
1850 fn analytic_penalties(&self) -> Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>> {
1851 self.analytic_penalties.clone()
1852 }
1853
1854 fn analytic_penalty_rho_count(&self) -> usize {
1855 self.analytic_rho_count
1856 }
1857
1858 fn hyper_dirs(&self) -> Result<Vec<gam_solve::estimate::reml::DirectionalHyperParam>, String> {
1859 self.current_hyper_dirs
1860 .as_ref()
1861 .cloned()
1862 .ok_or_else(|| "latent-coordinate hyper_dirs cache has not been realized".to_string())
1863 }
1864
    /// Builds the `LatentBasisKind` used to key `latent_design_cache` for the latent
    /// smooth term. It pairs the term's resolved basis spec with its realized metadata.
    ///
    /// Supported pairings: Matern, Duchon, Sphere (Wahba method only), BSpline1D
    /// (periodic or not), TensorBSpline and Pca. Any other pairing is an error.
    ///
    /// # Errors
    /// Returns a message if `term_index` is out of range for the design or spec, if a
    /// centered Pca basis without `pca_basis_path` has no `center_mean`, or if the
    /// basis/metadata pairing is unsupported.
    fn latent_basis_kind(&self) -> Result<gam_solve::latent_cache::LatentBasisKind, String> {
        let smooth_term = self
            .design
            .smooth
            .terms
            .get(self.term_index.get())
            .ok_or_else(|| {
                SmoothError::dimension_mismatch(format!(
                    "LatentCoord term index {} out of bounds for realized smooth design",
                    self.term_index
                ))
            })?;
        let termspec = self
            .spec
            .smooth_terms
            .get(self.term_index.get())
            .ok_or_else(|| {
                SmoothError::dimension_mismatch(format!(
                    "LatentCoord term index {} out of bounds for resolved smooth spec",
                    self.term_index
                ))
            })?;
        match (&termspec.basis, &smooth_term.metadata) {
            (
                SmoothBasisSpec::Matern { .. },
                BasisMetadata::Matern {
                    centers,
                    length_scale,
                    nu,
                    aniso_log_scales,
                    ..
                },
            ) => Ok(gam_solve::latent_cache::LatentBasisKind::Matern {
                centers: centers.clone(),
                length_scale: *length_scale,
                nu: *nu,
                // Absent anisotropy is keyed as all-zero log scales, one per center column.
                aniso_log_scales: aniso_log_scales
                    .clone()
                    .unwrap_or_else(|| vec![0.0; centers.ncols()]),
                chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
                    self.n_obs,
                    centers.nrows(),
                ),
            }),
            (
                SmoothBasisSpec::Duchon { .. },
                BasisMetadata::Duchon {
                    centers,
                    length_scale,
                    power,
                    nullspace_order,
                    aniso_log_scales,
                    ..
                },
            ) => Ok(gam_solve::latent_cache::LatentBasisKind::Duchon {
                centers: centers.clone(),
                length_scale: *length_scale,
                power: *power,
                nullspace_order: *nullspace_order,
                aniso_log_scales: aniso_log_scales
                    .clone()
                    .unwrap_or_else(|| vec![0.0; centers.ncols()]),
            }),
            // Only the Wahba sphere method is keyable; other methods fall through to
            // the unsupported-pairing error below.
            (
                SmoothBasisSpec::Sphere { .. },
                BasisMetadata::Sphere {
                    centers,
                    penalty_order,
                    method,
                    ..
                },
            ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
                Ok(gam_solve::latent_cache::LatentBasisKind::Sphere {
                    centers: centers.clone(),
                    penalty_order: *penalty_order,
                    chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
                        self.n_obs,
                        centers.nrows(),
                    ),
                })
            }
            (
                SmoothBasisSpec::BSpline1D { spec, .. },
                BasisMetadata::BSpline1D {
                    knots,
                    periodic,
                    degree: meta_degree,
                    ..
                },
            ) => {
                // Prefer the degree recorded in the realized metadata over the spec's.
                let effective_degree = meta_degree.unwrap_or(spec.degree);
                if let Some((domain_start, period, num_basis)) = periodic {
                    Ok(
                        gam_solve::latent_cache::LatentBasisKind::PeriodicBspline {
                            domain_start: *domain_start,
                            period: *period,
                            degree: effective_degree,
                            num_basis: *num_basis,
                            chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
                                self.n_obs, *num_basis,
                            ),
                        },
                    )
                } else {
                    // B-spline count for a knot vector of length m is m - degree - 1;
                    // saturating_sub guards against short knot vectors.
                    let num_basis_est = knots.len().saturating_sub(effective_degree + 1);
                    Ok(
                        gam_solve::latent_cache::LatentBasisKind::TensorBspline {
                            knots: vec![knots.clone()],
                            degrees: vec![effective_degree],
                            chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
                                self.n_obs,
                                num_basis_est,
                            ),
                        },
                    )
                }
            }
            (
                SmoothBasisSpec::TensorBSpline { .. },
                BasisMetadata::TensorBSpline { knots, degrees, .. },
            ) => Ok(
                gam_solve::latent_cache::LatentBasisKind::TensorBspline {
                    knots: knots.clone(),
                    degrees: degrees.clone(),
                    chunk_size: None,
                },
            ),
            (
                SmoothBasisSpec::Pca { .. },
                BasisMetadata::Pca {
                    basis_matrix,
                    centered,
                    smooth_penalty,
                    center_mean,
                    pca_basis_path,
                    chunk_size,
                    ..
                },
            ) => {
                // A centered in-memory PCA basis is keyed by a fingerprint of its
                // centering mean, so that mean must be present.
                let center_mean_fingerprint = if *centered && pca_basis_path.is_none() {
                    let mean = center_mean.as_ref().ok_or_else(|| {
                        SmoothError::invalid_config(
                            "latent-coordinate Pca cache key requires center_mean when centered",
                        )
                    })?;
                    Some(gam_solve::latent_cache::pca_center_mean_fingerprint(
                        mean,
                    ))
                } else {
                    None
                };
                Ok(gam_solve::latent_cache::LatentBasisKind::Pca {
                    basis_matrix: basis_matrix.clone(),
                    centered: *centered,
                    center_mean_fingerprint,
                    smooth_penalty: *smooth_penalty,
                    pca_basis_path: pca_basis_path.clone(),
                    chunk_size: *chunk_size,
                })
            }
            _ => Err(SmoothError::invalid_config(
                "latent-coordinate design cache could not key the realized latent smooth basis"
                    .to_string(),
            )
            .into()),
        }
    }
2035
    /// Realizes the latent-coordinate design for `theta`.
    ///
    /// `theta` must have length
    /// `rho_dim + n_obs * latent_dim + analytic_rho_count + direct_hyper_count`. Only the
    /// latent block `theta[rho_dim..rho_dim + n_obs * latent_dim]` is read here. Those
    /// coordinates are written into the `feature_cols` columns of `self.data`. The
    /// design and hyper-directions are then fetched from, or computed into,
    /// `latent_design_cache`, keyed by the latent values, the basis kind and a design
    /// context digest.
    ///
    /// On success the realized state (`design`, `current_hyper_dirs`,
    /// `current_latent`, `current_theta`, `current_design_cache_id`) is replaced and
    /// the objective memo is cleared. `design_revision` is bumped when the latent
    /// values change. It is also bumped when the values are unchanged but a different
    /// cache entry now backs the design. A θ matching `current_theta` is a no-op.
    ///
    /// NOTE(review): if an error is returned after the coordinate write-back, the state
    /// is inconsistent. `self.data` already holds the new coordinates, while
    /// `current_theta`/`current_latent` still describe the previous point. The cache
    /// may also have been invalidated and `design_revision` bumped. Confirm callers
    /// discard the cache after such an error.
    fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
        // Already realized at this θ.
        if self
            .current_theta
            .as_ref()
            .is_some_and(|cached| theta_values_match(cached, theta))
        {
            return Ok(());
        }
        let latent_flat_len = self.n_obs * self.latent_dim;
        let direct_hyper_count = latent_coord_direct_hyper_count(&self.id_mode, self.latent_dim);
        let expected =
            self.rho_dim + latent_flat_len + self.analytic_rho_count + direct_hyper_count;
        if theta.len() != expected {
            return Err(SmoothError::dimension_mismatch(format!(
                "latent-coordinate theta length mismatch: got {}, expected {} (rho_dim={}, n={}, d={}, analytic_rhos={}, direct_hypers={})",
                theta.len(),
                expected,
                self.rho_dim,
                self.n_obs,
                self.latent_dim,
                self.analytic_rho_count,
                direct_hyper_count
            ))
            .into());
        }
        // Decode the latent block of θ into LatentCoordValues with the cached settings.
        let flat = theta
            .slice(s![self.rho_dim..self.rho_dim + latent_flat_len])
            .to_owned();
        let latent = std::sync::Arc::new(
            gam_terms::latent::LatentCoordValues::from_flat_with_manifold_and_retraction_and_id(
                flat,
                self.n_obs,
                self.latent_dim,
                self.id_mode.clone(),
                self.manifold.clone(),
                self.retraction_registry.clone(),
                self.latent_id,
            ),
        );
        // New (or first) latent values: every cached design is stale.
        let latent_values_changed = self
            .current_latent
            .as_ref()
            .map(|cached| !latent_values_match(cached.as_flat(), latent.as_flat()))
            .unwrap_or(true);
        if latent_values_changed {
            self.latent_design_cache.invalidate_all();
            self.current_design_cache_id = None;
            self.design_revision = self.design_revision.wrapping_add(1);
        }
        // Write the coordinates (row-major per observation) into the feature columns.
        for n in 0..self.n_obs {
            for axis in 0..self.latent_dim {
                let col = self.feature_cols[axis];
                self.data[[n, col]] = latent.as_flat()[n * self.latent_dim + axis];
            }
        }

        let basis_kind = self.latent_basis_kind()?;
        let rebuilt_width = self.design.design.ncols();
        let spec = self.spec.clone();
        let term_index = self.term_index;
        let analytic_rho_count = self.analytic_rho_count;
        let data = self.data.view();
        let design_context_digest =
            gam_solve::latent_cache::latent_design_context_cache_digest(
                data,
                &spec,
                term_index,
                analytic_rho_count,
                &self.feature_cols,
            )
            .map_err(|e| e.to_string())?;
        // On a cache miss, rebuild the full term-collection design from the updated
        // data and require the column count to be unchanged.
        let lookup = self
            .latent_design_cache
            .lookup_or_compute(latent.clone(), basis_kind, design_context_digest, || {
                let rebuilt = build_term_collection_design(data, &spec).map_err(|e| {
                    EstimationError::InvalidInput(format!(
                        "failed to rebuild latent-coordinate design: {e}"
                    ))
                })?;
                if rebuilt.design.ncols() != rebuilt_width {
                    crate::bail_invalid_estim!(
                        "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
                        rebuilt.design.ncols(),
                        rebuilt_width
                    );
                }
                let hyper_dirs = try_build_latent_coord_hyper_dirs(
                    latent.clone(),
                    &spec,
                    &rebuilt,
                    &[term_index],
                    analytic_rho_count,
                )?
                .ok_or_else(|| {
                    EstimationError::InvalidInput(
                        "failed to build latent-coordinate hyper_dirs".to_string(),
                    )
                })?;
                Ok(gam_solve::latent_cache::ComputedLatentDesign {
                    design: rebuilt,
                    hyper_dirs,
                })
            })
            .map_err(|e| e.to_string())?;
        // Same width check for designs served from the cache.
        if lookup.cached.design.design.ncols() != self.design.design.ncols() {
            return Err(SmoothError::dimension_mismatch(format!(
                "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
                lookup.cached.design.design.ncols(),
                self.design.design.ncols()
            ))
            .into());
        }
        // Commit the realized state and drop memoised objective values.
        self.design = lookup.cached.design.clone();
        self.current_hyper_dirs = Some(lookup.cached.hyper_dirs.clone());
        self.current_latent = Some(latent);
        self.current_theta = Some(theta.clone());
        self.last_cost = None;
        self.last_eval = None;
        self.last_outer_iter = None;
        // Same latent values but a different cache entry now backs the design.
        if !latent_values_changed && self.current_design_cache_id != Some(lookup.entry_id) {
            self.design_revision = self.design_revision.wrapping_add(1);
        }
        self.current_design_cache_id = Some(lookup.entry_id);
        Ok(())
    }
2161
2162 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
2163 if self
2164 .current_theta
2165 .as_ref()
2166 .is_some_and(|cached| theta_values_match(cached, theta))
2167 && self.last_outer_iter
2168 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
2169 {
2170 self.last_eval
2171 .as_ref()
2172 .map(|cached| cached.0)
2173 .or(self.last_cost)
2174 } else {
2175 None
2176 }
2177 }
2178
2179 fn memoized_eval(
2180 &self,
2181 theta: &Array1<f64>,
2182 ) -> Option<(
2183 f64,
2184 Array1<f64>,
2185 gam_problem::HessianResult,
2186 )> {
2187 if self
2188 .current_theta
2189 .as_ref()
2190 .is_some_and(|cached| theta_values_match(cached, theta))
2191 && self.last_outer_iter
2192 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
2193 {
2194 self.last_eval.clone()
2195 } else {
2196 None
2197 }
2198 }
2199
2200 fn store_eval(
2201 &mut self,
2202 eval: (
2203 f64,
2204 Array1<f64>,
2205 gam_problem::HessianResult,
2206 ),
2207 ) {
2208 self.last_cost = Some(eval.0);
2209 self.last_eval = Some(eval);
2210 self.last_outer_iter =
2211 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
2212 }
2213
2214 fn store_cost(&mut self, cost: f64) {
2215 self.last_cost = Some(cost);
2216 self.last_outer_iter =
2217 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
2218 }
2219
2220 fn reset(&mut self) {
2221 self.current_theta = None;
2222 self.current_latent = None;
2223 self.current_hyper_dirs = None;
2224 self.current_design_cache_id = None;
2225 self.latent_design_cache.invalidate();
2226 self.last_cost = None;
2227 self.last_eval = None;
2228 self.last_outer_iter = None;
2229 }
2230}
2231
2232pub fn fixed_kappa_profiled_reml_score(
2248 data: ArrayView2<'_, f64>,
2249 y: ArrayView1<'_, f64>,
2250 weights: ArrayView1<'_, f64>,
2251 offset: ArrayView1<'_, f64>,
2252 resolvedspec: &TermCollectionSpec,
2253 term_idx: usize,
2254 kappa: f64,
2255 family: LikelihoodSpec,
2256 options: &FitOptions,
2257) -> Result<f64, EstimationError> {
2258 if !kappa.is_finite() {
2259 crate::bail_invalid_estim!("fixed-κ profiled score probed a non-finite κ = {kappa}");
2260 }
2261 let (feature_cols, mut probe_basis) = match resolvedspec
2264 .smooth_terms
2265 .get(term_idx)
2266 .map(|t| &t.basis)
2267 {
2268 Some(SmoothBasisSpec::ConstantCurvature {
2269 feature_cols, spec, ..
2270 }) => (feature_cols.clone(), spec.clone()),
2271 _ => {
2272 crate::bail_invalid_estim!(
2273 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2274 )
2275 }
2276 };
2277 probe_basis.kappa = kappa;
2278
2279 let is_unweighted = weights.iter().all(|&w| (w - 1.0).abs() <= 1e-12);
2299 let is_zero_offset = offset.iter().all(|&o| o.abs() <= 1e-12);
2300 if family == LikelihoodSpec::gaussian_identity() && is_unweighted && is_zero_offset {
2301 let x_term = select_columns(data, &feature_cols).map_err(EstimationError::from)?;
2302 let score =
2303 gam_terms::basis::constant_curvature_honest_profiled_reml_score(x_term.view(), y, &probe_basis)
2304 .map_err(|e| {
2305 EstimationError::InvalidInput(format!(
2306 "fixed-κ honest profiled-REML score at κ={kappa} failed: {e}"
2307 ))
2308 })?;
2309 if !score.is_finite() {
2310 crate::bail_invalid_estim!(
2311 "fixed-κ honest profiled-REML score at κ={kappa} is non-finite"
2312 );
2313 }
2314 return Ok(score);
2315 }
2316
2317 let mut probe_spec = resolvedspec.clone();
2319 match probe_spec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis) {
2320 Some(SmoothBasisSpec::ConstantCurvature { spec, .. }) => spec.kappa = kappa,
2321 _ => {
2322 crate::bail_invalid_estim!(
2323 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2324 )
2325 }
2326 }
2327 let fixed_kappa_options = SpatialLengthScaleOptimizationOptions {
2328 enabled: false,
2329 ..SpatialLengthScaleOptimizationOptions::default()
2330 };
2331 let fit = fit_term_collectionwith_spatial_length_scale_optimization(
2332 data,
2333 y.to_owned(),
2334 weights.to_owned(),
2335 offset.to_owned(),
2336 &probe_spec,
2337 family,
2338 options,
2339 &fixed_kappa_options,
2340 )?;
2341 let score = fit_score(&fit.fit);
2342 if !score.is_finite() {
2343 crate::bail_invalid_estim!("fixed-κ profiled fit at κ={kappa} returned a non-finite score");
2344 }
2345 Ok(score)
2346}
2347
2348fn constant_curvature_kappa_fair_argmin(
2373 data: ArrayView2<'_, f64>,
2374 y: ArrayView1<'_, f64>,
2375 resolvedspec: &TermCollectionSpec,
2376 term_idx: usize,
2377) -> Option<f64> {
2378 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2379 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2380 return None;
2381 }
2382 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2383 Some(SmoothBasisSpec::ConstantCurvature {
2384 feature_cols, spec, ..
2385 }) => (feature_cols, spec.clone()),
2386 _ => return None,
2387 };
2388 let x_term = match select_columns(data, feature_cols) {
2389 Ok(x) => x,
2390 Err(e) => {
2391 log::info!("[spatial-kappa] #1464 κ-fair argmin column select failed ({e}); skipping");
2392 return None;
2393 }
2394 };
2395 const GRID_STEPS: usize = 24;
2401 let mut best: Option<(f64, f64)> = None; for i in 0..=GRID_STEPS {
2403 let t = i as f64 / GRID_STEPS as f64;
2404 let kappa = kappa_min + (kappa_max - kappa_min) * t;
2405 let mut probe_spec = base_spec.clone();
2406 probe_spec.kappa = kappa;
2407 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec) {
2408 Ok(score) => {
2409 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2410 best = Some((score, kappa));
2411 }
2412 }
2413 Err(e) => {
2414 log::info!(
2415 "[spatial-kappa] #1464 κ-fair argmin probe at κ={kappa:.4} failed ({e}); skipping"
2416 );
2417 }
2418 }
2419 }
2420 best.map(|(score, kappa)| {
2421 log::info!(
2422 "[spatial-kappa] #1464 κ-fair argmin κ̂={kappa:.4} (κ-fair score={score:.6e}) for term {term_idx}"
2423 );
2424 kappa
2425 })
2426}
2427
2428fn select_constant_curvature_kappa_sign_seed(
2436 data: ArrayView2<'_, f64>,
2437 y: ArrayView1<'_, f64>,
2438 resolvedspec: &TermCollectionSpec,
2439 term_idx: usize,
2440) -> Option<f64> {
2441 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2442 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2443 return None;
2444 }
2445 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2457 Some(SmoothBasisSpec::ConstantCurvature {
2458 feature_cols, spec, ..
2459 }) => (feature_cols, spec.clone()),
2460 _ => return None,
2461 };
2462 let x_term = match select_columns(data, feature_cols) {
2463 Ok(x) => x,
2464 Err(e) => {
2465 log::info!("[spatial-kappa] #1464 sign-basin scan column select failed ({e}); skipping");
2466 return None;
2467 }
2468 };
2469 let probes = [
2473 kappa_min,
2474 0.5 * kappa_min,
2475 0.0,
2476 0.5 * kappa_max,
2477 kappa_max,
2478 ];
2479 let mut best: Option<(f64, f64)> = None; for &kappa in &probes {
2481 let mut probe_spec = base_spec.clone();
2482 probe_spec.kappa = kappa;
2483 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(
2484 x_term.view(),
2485 y,
2486 &probe_spec,
2487 ) {
2488 Ok(score) => {
2489 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2490 best = Some((score, kappa));
2491 }
2492 }
2493 Err(e) => {
2494 log::info!(
2495 "[spatial-kappa] #1464 sign-basin probe at κ={kappa:.4} failed ({e}); skipping"
2496 );
2497 }
2498 }
2499 }
2500 best.map(|(score, kappa)| {
2501 log::info!(
2502 "[spatial-kappa] #1464 κ-fair sign-basin scan selected κ_seed={kappa:.4} \
2503 (κ-fair score={score:.6e}) for term {term_idx}"
2504 );
2505 kappa
2506 })
2507}
2508
2509const SPATIAL_RANGE_PRESCAN_GRID: usize = 7;
2512
2513fn prescan_isotropic_spatial_range_seed(
2545 data: ArrayView2<'_, f64>,
2546 y: ArrayView1<'_, f64>,
2547 weights: ArrayView1<'_, f64>,
2548 offset: ArrayView1<'_, f64>,
2549 resolvedspec: &TermCollectionSpec,
2550 baseline_score: f64,
2551 family: &LikelihoodSpec,
2552 options: &FitOptions,
2553 kappa_options: &SpatialLengthScaleOptimizationOptions,
2554 spatial_terms: &[usize],
2555) -> Result<Vec<(usize, f64)>, EstimationError> {
2556 if has_aniso_terms(resolvedspec, spatial_terms)
2558 || !constant_curvature_term_indices(resolvedspec).is_empty()
2559 {
2560 return Ok(Vec::new());
2561 }
2562 let dims = spatial_dims_per_term(resolvedspec, spatial_terms);
2563 let mut working = resolvedspec.clone();
2567 let mut best_score = if baseline_score.is_finite() {
2568 baseline_score
2569 } else {
2570 f64::INFINITY
2571 };
2572 let mut overrides: Vec<(usize, f64)> = Vec::new();
2573 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2574 if dims.get(slot).copied().unwrap_or(1) != 1 {
2577 continue;
2578 }
2579 if get_spatial_length_scale(&working, term_idx).is_none() {
2582 continue;
2583 }
2584 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &working, term_idx, kappa_options);
2585 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2586 continue;
2587 }
2588 let mut term_best: Option<f64> = None;
2589 for g in 0..SPATIAL_RANGE_PRESCAN_GRID {
2590 let frac = g as f64 / (SPATIAL_RANGE_PRESCAN_GRID - 1) as f64;
2591 let psi = psi_lo + (psi_hi - psi_lo) * frac;
2592 let ls = (-psi).exp();
2596 if !ls.is_finite() || ls <= 0.0 {
2597 continue;
2598 }
2599 let mut probe = working.clone();
2600 if set_spatial_length_scale(&mut probe, term_idx, ls).is_err() {
2601 continue;
2602 }
2603 let fit = match fit_term_collection_forspec(
2612 data,
2613 y,
2614 weights,
2615 offset,
2616 &probe,
2617 family.clone(),
2618 options,
2619 ) {
2620 Ok(fit) => fit,
2621 Err(_) => continue,
2624 };
2625 let score = fit_score(&fit.fit);
2626 if score.is_finite() && score < best_score - 1e-7 * best_score.abs().max(1.0) {
2629 best_score = score;
2630 term_best = Some(ls);
2631 }
2632 }
2633 if let Some(ls) = term_best {
2634 set_spatial_length_scale(&mut working, term_idx, ls)?;
2635 overrides.push((term_idx, ls));
2636 log::info!(
2637 "[spatial-kappa] #1074 range pre-scan: term {term_idx} re-seeded at \
2638 length_scale={ls:.5} (profiled REML {best_score:.5}, was {baseline_score:.5})"
2639 );
2640 }
2641 }
2642 Ok(overrides)
2643}
2644
2645const JOINT_RESTART_WINDOW_FRACTIONS: [f64; 5] = [0.0, 0.2, 0.45, 0.7, 1.0];
2654
2655fn joint_solve_from_window_fraction(
2671 data: ArrayView2<'_, f64>,
2672 y: ArrayView1<'_, f64>,
2673 weights: ArrayView1<'_, f64>,
2674 offset: ArrayView1<'_, f64>,
2675 base_spec: &TermCollectionSpec,
2676 spatial_terms: &[usize],
2677 fraction: f64,
2678 family: &LikelihoodSpec,
2679 options: &FitOptions,
2680 baseline_options: &FitOptions,
2681 kappa_options: &SpatialLengthScaleOptimizationOptions,
2682) -> Result<Option<(FittedTermCollectionWithSpec, f64)>, EstimationError> {
2683 let mut seed_spec = base_spec.clone();
2684 let mut any_set = false;
2685 for &term_idx in spatial_terms {
2686 if get_spatial_length_scale(&seed_spec, term_idx).is_none() {
2687 continue;
2688 }
2689 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &seed_spec, term_idx, kappa_options);
2690 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2691 continue;
2692 }
2693 let psi = psi_lo + (psi_hi - psi_lo) * fraction;
2694 let ls = (-psi).exp();
2695 if !ls.is_finite() || ls <= 0.0 {
2696 continue;
2697 }
2698 if set_spatial_length_scale(&mut seed_spec, term_idx, ls).is_ok() {
2699 any_set = true;
2700 }
2701 }
2702 if !any_set {
2703 return Ok(None);
2704 }
2705 let seed_best = match fit_term_collection_forspec(
2709 data,
2710 y,
2711 weights,
2712 offset,
2713 &seed_spec,
2714 family.clone(),
2715 baseline_options,
2716 ) {
2717 Ok(fit) => fit,
2718 Err(_) => return Ok(None),
2719 };
2720 let seed_spec = freeze_term_collection_from_design(&seed_spec, &seed_best.design)?;
2721 let seed_terms = spatial_length_scale_term_indices(&seed_spec);
2724 if seed_terms.is_empty() {
2725 let score = fit_score(&seed_best.fit);
2726 return Ok(Some((
2727 FittedTermCollectionWithSpec {
2728 fit: seed_best.fit,
2729 design: seed_best.design,
2730 resolvedspec: seed_spec,
2731 adaptive_diagnostics: seed_best.adaptive_diagnostics,
2732 kappa_timing: None,
2733 },
2734 score,
2735 )));
2736 }
2737 let joint = try_exact_joint_spatial_length_scale_optimization(
2738 data,
2739 y,
2740 weights,
2741 offset,
2742 &seed_spec,
2743 &seed_best,
2744 family.clone(),
2745 options,
2746 kappa_options,
2747 &seed_terms,
2748 )?;
2749 match joint {
2750 Some(fit) => {
2751 let score = fit_score(&fit.fit);
2752 Ok(Some((fit, score)))
2753 }
2754 None => {
2757 let score = fit_score(&seed_best.fit);
2758 Ok(Some((
2759 FittedTermCollectionWithSpec {
2760 fit: seed_best.fit,
2761 design: seed_best.design,
2762 resolvedspec: seed_spec,
2763 adaptive_diagnostics: seed_best.adaptive_diagnostics,
2764 kappa_timing: None,
2765 },
2766 score,
2767 )))
2768 }
2769 }
2770}
2771
/// Jointly optimises the smoothing parameters ρ = log λ and the spatial
/// geometry coordinates of `spatial_terms`, starting from the already-fitted
/// `best` model. The geometry coordinates are log-κ/ψ, or κ for
/// constant-curvature (CC) terms.
///
/// Flow, as implemented below:
/// 1. If every spatial term is CC, each term's κ is first chosen by the κ-fair
///    grid argmin. If any hyperbolic κ̂ (κ < 0) is found, the model is refitted
///    at those fixed κ values (only ρ is profiled) and that fit is returned
///    immediately.
/// 2. If the hyper-direction bundle cannot be built, returns `Ok(None)`.
/// 3. Builds θ₀ = [ρ₀, ψ₀] and its box bounds. CC κ coordinates are seeded
///    from the κ-fair sign-basin scan; a nonzero seed pins that coordinate's
///    bounds. The exact joint optimiser is then run.
/// 4. If the optimiser does not converge, or its score is worse than the
///    baseline by more than `accept_tol`, the result falls back to
///    `fit_frozen_baseline_geometry`.
/// 5. Otherwise the model is refitted at θ*. The baseline is kept instead if
///    that refit collapses (total EDF below `SPATIAL_COLLAPSE_EDF_FLOOR`)
///    while the baseline retains at least `SPATIAL_COLLAPSE_EDF_MARGIN` more
///    EDF.
///
/// # Errors
/// Propagates invalid `kappa_options` and errors from hyper-direction
/// construction, the joint optimiser, and the refits.
fn try_exact_joint_spatial_length_scale_optimization(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    best: &FittedTermCollection,
    family: LikelihoodSpec,
    options: &FitOptions,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
    spatial_terms: &[usize],
) -> Result<Option<FittedTermCollectionWithSpec>, EstimationError> {
    if spatial_terms.is_empty() {
        return Ok(None);
    }
    kappa_options
        .validate()
        .map_err(EstimationError::InvalidInput)?;

    // Fast path when every spatial term is constant-curvature: fix κ at the
    // κ-fair argmin (hyperbolic basin only) and profile ρ alone.
    let cc_term_set = constant_curvature_term_indices(resolvedspec);
    let all_spatial_are_cc =
        !cc_term_set.is_empty() && spatial_terms.iter().all(|t| cc_term_set.contains(t));
    if all_spatial_are_cc {
        let mut fixed_kappa_spec = resolvedspec.clone();
        let mut any_kappa_chosen = false;
        for &term_idx in spatial_terms {
            // Only a negative (hyperbolic) κ̂ is fixed here; other terms keep their spec κ.
            if let Some(kappa_hat) =
                constant_curvature_kappa_fair_argmin(data, y, resolvedspec, term_idx)
                    .filter(|&k| k < 0.0)
            {
                if let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) = fixed_kappa_spec
                    .smooth_terms
                    .get_mut(term_idx)
                    .map(|t| &mut t.basis)
                {
                    cc.kappa = kappa_hat;
                    any_kappa_chosen = true;
                    log::info!(
                        "[spatial-kappa] #1464 term {term_idx}: fixed κ̂ = {kappa_hat:.4} from κ-fair argmin (hyperbolic basin; profiling ρ only)"
                    );
                }
            }
        }
        if any_kappa_chosen {
            let baseline_score = fit_score(&best.fit);
            let fitted = fit_term_collection_forspec(
                data,
                y,
                weights,
                offset,
                &fixed_kappa_spec,
                family.clone(),
                options,
            )?;
            let frozen_spec =
                freeze_term_collection_from_design(&fixed_kappa_spec, &fitted.design)?;
            let mut fit = fitted.fit;
            // NOTE(review): the returned fit reports the incoming baseline score,
            // not its own score at the fixed κ̂ — presumably deliberate so callers
            // compare against an unchanged reference; confirm intent.
            fit.reml_score = baseline_score;
            return Ok(Some(FittedTermCollectionWithSpec {
                fit,
                design: fitted.design,
                resolvedspec: frozen_spec,
                adaptive_diagnostics: fitted.adaptive_diagnostics,
                kappa_timing: None,
            }));
        }
    }

    // Without hyper-directions the joint solver cannot run; let the caller use its non-joint path.
    if try_build_spatial_log_kappa_hyper_dirs(data, resolvedspec, &best.design, spatial_terms)?
        .is_none()
    {
        if !constant_curvature_term_indices(resolvedspec).is_empty() {
            log::info!(
                "[#1464-trace] try_exact_joint RETURNED None (hyper_dirs unavailable); \
                 κ̂ comes from a NON-joint path"
            );
        }
        return Ok(None);
    }
    if !constant_curvature_term_indices(resolvedspec).is_empty() {
        log::info!(
            "[#1464-trace] try_exact_joint ENTERED for {} spatial term(s); CC present",
            spatial_terms.len()
        );
    }

    // ρ box: lower bound is always -JOINT_RHO_BOUND; the upper bound widens to
    // the solver-wide RHO_BOUND when a CC term is present.
    const JOINT_RHO_BOUND: f64 = 12.0;
    let rho_dim = best.fit.lambdas.len();

    let has_constant_curvature_term = !constant_curvature_term_indices(resolvedspec).is_empty();
    let rho_upper_bound = if has_constant_curvature_term {
        gam_solve::estimate::RHO_BOUND
    } else {
        JOINT_RHO_BOUND
    };

    let dims_per_term = spatial_dims_per_term(resolvedspec, spatial_terms);
    let use_aniso = has_aniso_terms(resolvedspec, spatial_terms);

    // Initial spatial coordinates from the current length scales, then re-seeded from the data.
    let log_kappa0 = if use_aniso {
        SpatialLogKappaCoords::from_length_scales_aniso(resolvedspec, spatial_terms, kappa_options)
    } else {
        SpatialLogKappaCoords::from_length_scales(resolvedspec, spatial_terms, kappa_options)
    };
    let mut log_kappa0 =
        log_kappa0.reseed_from_data(data, resolvedspec, spatial_terms, kappa_options);
    // For CC terms, overwrite the slot's start value with the κ-fair sign-basin
    // seed and remember (slot, seed) for pinning the bounds below.
    let mut cc_sign_seeds: Vec<(usize, f64)> = Vec::new();
    if has_constant_curvature_term {
        for (slot, &term_idx) in spatial_terms.iter().enumerate() {
            if constant_curvature_term_spec(resolvedspec, term_idx).is_none() {
                continue;
            }
            let scan = select_constant_curvature_kappa_sign_seed(
                data,
                y,
                resolvedspec,
                term_idx,
            );
            match scan {
                Some(kappa_seed) => {
                    log::info!(
                        "[#1464-trace] term {term_idx}: κ-fair sign-basin scan picked κ_seed = {kappa_seed}"
                    );
                    log_kappa0.set_scalar_slot(slot, kappa_seed);
                    cc_sign_seeds.push((slot, kappa_seed));
                }
                None => {
                    log::info!(
                        "[#1464-trace] term {term_idx}: fixed-κ sign-basin scan returned NONE (no seed applied)"
                    );
                }
            }
        }
    }
    let log_kappa_lower = if use_aniso {
        SpatialLogKappaCoords::lower_bounds_aniso_from_data(
            data,
            resolvedspec,
            spatial_terms,
            &dims_per_term,
            kappa_options,
        )
    } else {
        SpatialLogKappaCoords::lower_bounds_from_data(
            data,
            resolvedspec,
            spatial_terms,
            kappa_options,
        )
    };
    let log_kappa_upper = if use_aniso {
        SpatialLogKappaCoords::upper_bounds_aniso_from_data(
            data,
            resolvedspec,
            spatial_terms,
            &dims_per_term,
            kappa_options,
        )
    } else {
        SpatialLogKappaCoords::upper_bounds_from_data(
            data,
            resolvedspec,
            spatial_terms,
            kappa_options,
        )
    };
    let mut log_kappa_lower = log_kappa_lower;
    let mut log_kappa_upper = log_kappa_upper;
    // A nonzero CC seed sets lower == upper == seed for that slot, pinning the
    // coordinate; a zero seed leaves the data-driven window unchanged.
    // NOTE(review): the "FROZE" trace below is logged for zero seeds too,
    // although their bounds were not pinned.
    for &(slot, kappa_seed) in &cc_sign_seeds {
        if kappa_seed != 0.0 {
            log_kappa_lower.set_scalar_slot(slot, kappa_seed);
            log_kappa_upper.set_scalar_slot(slot, kappa_seed);
        }
        log::info!(
            "[#1464-trace] slot {slot}: FROZE joint ψ coordinate at κ_seed={kappa_seed} \
             (window [{}, {}]); raw fit_score is sign-blind so the κ-fair scan is authoritative",
            log_kappa_lower.as_array()[log_kappa_lower.dims_per_term()[..slot].iter().sum::<usize>()],
            log_kappa_upper.as_array()[log_kappa_upper.dims_per_term()[..slot].iter().sum::<usize>()],
        );
    }
    let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
    // θ = [log λ (rho_dim entries), spatial coordinates]; ρ starts at the baseline's log λ.
    let setup = ExactJointHyperSetup::new(
        best.fit.lambdas.mapv(f64::ln),
        Array1::<f64>::from_elem(rho_dim, -JOINT_RHO_BOUND),
        Array1::<f64>::from_elem(rho_dim, rho_upper_bound),
        log_kappa0,
        log_kappa_lower,
        log_kappa_upper,
    );

    let theta0 = setup.theta0();
    let lower = setup.lower();
    let upper = setup.upper();

    let kind = if use_aniso {
        SpatialHyperKind::Anisotropic
    } else {
        SpatialHyperKind::Isotropic
    };
    let (outcome, kappa_timing) = run_exact_joint_spatial_optimization(
        kind,
        data,
        y,
        weights,
        offset,
        resolvedspec,
        &best.design,
        family.clone(),
        options,
        spatial_terms,
        &dims_per_term,
        &theta0,
        &lower,
        &upper,
        rho_dim,
        kappa_options,
    )?;

    let baseline_score = fit_score(&best.fit);

    let (theta_star, joint_final_value) = match outcome {
        SpatialJointOutcome::Optimized {
            theta_star,
            final_value,
        } => (theta_star, final_value),
        // Non-convergence: do not trust the candidate; refit the frozen baseline geometry.
        SpatialJointOutcome::NonConverged {
            iterations,
            final_value,
            final_grad_norm,
        } => {
            if has_constant_curvature_term {
                log::info!(
                    "[#1464-trace] joint solve NONCONVERGED (iters={iterations}, \
                     final_value={final_value}); returning FROZEN BASELINE geometry \
                     (κ̂ = spec default, NOT the joint candidate)"
                );
            }
            log::info!(
                "[spatial-kappa] joint spatial optimization did not converge \
                 (iterations={}, final_objective={:.6e}, final_grad_norm={}); \
                 keeping the frozen baseline geometry",
                iterations,
                final_value,
                final_grad_norm.map_or_else(|| "n/a".to_string(), |g| format!("{g:.3e}")),
            );
            return Ok(Some(fit_frozen_baseline_geometry(
                data,
                y,
                weights,
                offset,
                resolvedspec,
                best,
                family,
                options,
                baseline_score,
                Some(kappa_timing),
            )?));
        }
    };

    // Accept the joint candidate only if it does not worsen the score beyond a
    // tolerance of max(options.tol, 1e-8·|baseline|, 1e-12).
    let accept_tol = options.tol.max(1e-8 * baseline_score.abs()).max(1e-12);
    if joint_final_value > baseline_score + accept_tol {
        if has_constant_curvature_term {
            log::info!(
                "[#1464-trace] joint candidate WORSENED score (joint={joint_final_value}, \
                 baseline={baseline_score}); returning FROZEN BASELINE geometry \
                 (κ̂ = spec default, NOT the joint candidate)"
            );
        }
        log::info!(
            "[spatial-kappa] exact joint spatial candidate worsened the profiled score (joint={:.6e}, baseline={:.6e}, tol={:.2e}); keeping the frozen baseline geometry",
            joint_final_value,
            baseline_score,
            accept_tol,
        );
        return Ok(Some(fit_frozen_baseline_geometry(
            data,
            y,
            weights,
            offset,
            resolvedspec,
            best,
            family,
            options,
            baseline_score,
            Some(kappa_timing),
        )?));
    }

    // Split θ* into λ* = exp(ρ*) and the spatial tail.
    let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
    let log_kappa_star =
        SpatialLogKappaCoords::from_theta_tail_with_dims(&theta_star, rho_dim, dims_per_term);
    if has_constant_curvature_term {
        let star = log_kappa_star.as_array();
        let dims = log_kappa_star.dims_per_term();
        for (slot, &term_idx) in spatial_terms.iter().enumerate() {
            if constant_curvature_term_spec(resolvedspec, term_idx).is_some() {
                // Offset of this slot's first coordinate in the flattened tail.
                let off: usize = dims[..slot].iter().sum();
                log::info!(
                    "[#1464-trace] term {term_idx}: joint solver CONVERGED ψ-tail κ = {} \
                     (this is the optimised candidate; joint_final_value={joint_final_value})",
                    star[off]
                );
            }
        }
    }
    let baseline_spec = resolvedspec;
    let optimized_spec = log_kappa_star.apply_tospec(resolvedspec, spatial_terms)?;
    // Refit at the optimised geometry, warm-started from λ*.
    let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
        data,
        y,
        weights,
        offset,
        &optimized_spec,
        rho_star.as_slice(),
        family.clone(),
        options,
    )?;

    // Collapse guard: if the optimised refit has near-null EDF but the frozen
    // baseline keeps clearly more EDF, keep the baseline instead.
    let optimized_edf = optimized.fit.inference.as_ref().map(|inf| inf.edf_total);
    if let Some(opt_edf) = optimized_edf
        && opt_edf < SPATIAL_COLLAPSE_EDF_FLOOR
    {
        let baseline = fit_frozen_baseline_geometry(
            data,
            y,
            weights,
            offset,
            baseline_spec,
            best,
            family.clone(),
            options,
            baseline_score,
            Some(kappa_timing),
        )?;
        let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
        if let Some(base_edf) = baseline_edf
            && base_edf >= opt_edf + SPATIAL_COLLAPSE_EDF_MARGIN
        {
            log::info!(
                "[spatial-kappa] joint candidate collapsed to the null (edf={opt_edf:.3}); \
                 baseline geometry retains edf={base_edf:.3} — keeping the frozen baseline",
            );
            return Ok(Some(baseline));
        }
    }

    // The accepted fit reports the joint optimiser's final objective as its score.
    let mut fit = optimized.fit;
    fit.reml_score = joint_final_value;
    let optimized_result = FittedTermCollectionWithSpec {
        fit,
        design: optimized.design,
        resolvedspec: optimized_spec,
        adaptive_diagnostics: optimized.adaptive_diagnostics,
        kappa_timing: Some(kappa_timing),
    };

    Ok(Some(optimized_result))
}
3292
3293const SPATIAL_COLLAPSE_EDF_FLOOR: f64 = 2.5;
3297
3298const SPATIAL_COLLAPSE_EDF_MARGIN: f64 = 1.0;
3303
3304fn fit_frozen_baseline_geometry(
3340 data: ArrayView2<'_, f64>,
3341 y: ArrayView1<'_, f64>,
3342 weights: ArrayView1<'_, f64>,
3343 offset: ArrayView1<'_, f64>,
3344 resolvedspec: &TermCollectionSpec,
3345 best: &FittedTermCollection,
3346 family: LikelihoodSpec,
3347 options: &FitOptions,
3348 baseline_score: f64,
3349 kappa_timing: Option<SpatialLengthScaleOptimizationTiming>,
3350) -> Result<FittedTermCollectionWithSpec, EstimationError> {
3351 let baseline = fit_term_collection_forspecwith_heuristic_lambdas(
3352 data,
3353 y,
3354 weights,
3355 offset,
3356 resolvedspec,
3357 best.fit.lambdas.as_slice(),
3358 family.clone(),
3359 options,
3360 )?;
3361 let best_edf = best.fit.inference.as_ref().map(|inf| inf.edf_total);
3366 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
3367 let baseline = match (best_edf, baseline_edf) {
3368 (Some(best_edf), Some(base_edf))
3369 if base_edf < SPATIAL_COLLAPSE_EDF_FLOOR
3370 && best_edf >= base_edf + SPATIAL_COLLAPSE_EDF_MARGIN =>
3371 {
3372 log::info!(
3373 "[spatial-kappa] warm-started frozen baseline collapsed (edf={base_edf:.3}) \
3374 below the certified baseline (edf={best_edf:.3}); refitting from scratch",
3375 );
3376 fit_term_collection_forspec(data, y, weights, offset, resolvedspec, family, options)?
3377 }
3378 _ => baseline,
3379 };
3380 let mut fit = baseline.fit;
3381 fit.reml_score = baseline_score;
3382 Ok(FittedTermCollectionWithSpec {
3383 fit,
3384 design: baseline.design,
3385 resolvedspec: resolvedspec.clone(),
3386 adaptive_diagnostics: baseline.adaptive_diagnostics,
3387 kappa_timing,
3388 })
3389}
3390
/// Parameterisation of the spatial hyper-coordinates in the exact joint solve.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum SpatialHyperKind {
    Anisotropic,
    Isotropic,
}

impl SpatialHyperKind {
    /// `(label, adjective, coord_name)` strings for this kind, kept in one table.
    fn names(self) -> (&'static str, &'static str, &'static str) {
        match self {
            Self::Anisotropic => ("spatial-aniso-joint", "anisotropic", "psi"),
            Self::Isotropic => ("spatial-iso-joint", "isotropic", "kappa"),
        }
    }

    /// Short tag used in `[STAGE]` log lines.
    fn label(self) -> &'static str {
        self.names().0
    }

    /// Human-readable adjective used in error messages.
    fn adjective(self) -> &'static str {
        self.names().1
    }

    /// Name of the spatial coordinate for this kind.
    fn coord_name(self) -> &'static str {
        self.names().2
    }
}
3435
/// Response-side inputs kept so the joint context can recompute GLM working
/// weights and working response at an arbitrary β (see
/// `SpatialJointContext::frozen_glm_working_state`).
struct SpatialFrozenGlmInputs {
    /// Observed response.
    y: Array1<f64>,
    /// Prior observation weights.
    weights: Array1<f64>,
    /// Linear-predictor offset; must have one entry per design row.
    offset: Array1<f64>,
    /// Likelihood used to evaluate the family observations.
    family: LikelihoodSpec,
}
3447
3448fn frozen_glm_tensor_eligible_family(family: &LikelihoodSpec) -> bool {
3465 !family.is_gaussian_identity()
3466 && matches!(
3467 &family.response,
3468 ResponseFamily::Binomial
3469 | ResponseFamily::Poisson
3470 | ResponseFamily::Gamma
3471 | ResponseFamily::NegativeBinomial { .. }
3472 )
3473}
3474
/// Mutable state for one exact joint (ρ, spatial-coordinate) optimisation:
/// the design cache, the joint evaluator, and the optional frozen-W GLM fast path.
struct SpatialJointContext<'d> {
    /// Raw data matrix, passed to the spatial hyper-direction builders.
    data: ArrayView2<'d, f64>,
    /// Number of leading ρ coordinates in θ. When `θ.len() == rho_dim + 1`,
    /// `θ[rho_dim]` is the single ψ coordinate.
    rho_dim: usize,
    /// Selects the hyper-direction parameterisation and the log labels.
    kind: SpatialHyperKind,
    /// Design cache re-realized at a θ via `ensure_theta`; also memoizes evaluations.
    cache: SingleBlockExactJointDesignCache<'d>,
    /// Joint REML evaluator; receives the staged n-free statistics.
    evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
    /// Inputs for recomputing GLM working weights; `None` disables the frozen-W path.
    frozen_glm_inputs: Option<SpatialFrozenGlmInputs>,
    /// ψ interval `(lo, hi)` for the frozen-W tensor build; `None` skips the build.
    frozen_glm_psi_bounds: Option<(f64, f64)>,
    /// Certified frozen-W GLM ψ-Gram tensor, once built.
    frozen_glm_tensor: Option<gam_solve::glm_sufficient_lane::FrozenWeightGramTensor>,
    /// Set once a tensor build was tried (or ruled out), so it is not retried.
    frozen_glm_tensor_attempted: bool,
    /// `(β, W)` memo of trial working weights, keyed on β's exact bit pattern.
    frozen_glm_weight_memo: Option<(Array1<f64>, Array1<f64>)>,
}
3498
/// Breakdown of the conditions that gate the n-free "skip design realization"
/// fast path at one θ.
#[derive(Clone, Copy, Debug, Default)]
struct NfreeSkipGateStatus {
    /// θ has exactly one ψ coordinate after the ρ block.
    shape: bool,
    /// The ψ-Gram tensor covers ψ for value evaluation and skipping.
    value: bool,
    /// The tensor covers ψ for the gradient (or no gradient is required).
    gradient: bool,
    /// The evaluator can re-key the penalty without the n×k design.
    penalty: bool,
    /// An n-free fast-path revision is available.
    revision: bool,
    /// Second-order (Hessian) evaluation was requested; this blocks the skip.
    second_order: bool,
}

impl NfreeSkipGateStatus {
    /// `true` when every gate passes: all structural gates hold, the gradient
    /// gate holds if `require_gradient`, and no second-order evaluation is needed.
    fn would_skip(self, require_gradient: bool) -> bool {
        let structural = self.shape && self.value && self.penalty && self.revision;
        let gradient_ok = self.gradient || !require_gradient;
        structural && gradient_ok && !self.second_order
    }
}
3519
3520impl<'d> SpatialJointContext<'d> {
3521 fn nfree_skip_gate_status(
3522 &self,
3523 theta: &Array1<f64>,
3524 allow_second_order: bool,
3525 require_gradient: bool,
3526 ) -> NfreeSkipGateStatus {
3527 let shape = theta.len() == self.rho_dim + 1;
3528 let (value, gradient) = if shape {
3529 let psi = theta[self.rho_dim];
3530 (
3531 self.evaluator.psi_gram_tensor_covers(psi)
3532 && self.evaluator.psi_gram_tensor_covers_skip(psi),
3533 !require_gradient || self.evaluator.psi_gram_tensor_covers_gradient(psi),
3534 )
3535 } else {
3536 (false, false)
3537 };
3538 NfreeSkipGateStatus {
3539 shape,
3540 value,
3541 gradient,
3542 penalty: self.evaluator.supports_nfree_penalty_rekey(),
3543 revision: self.evaluator.nfree_fast_path_revision().is_some(),
3544 second_order: allow_second_order,
3545 }
3546 }
3547
3548 fn frozen_glm_working_state(
3549 &self,
3550 beta: &Array1<f64>,
3551 ) -> Result<Option<(Array1<f64>, Array1<f64>)>, EstimationError> {
3552 let Some(inputs) = self.frozen_glm_inputs.as_ref() else {
3553 return Ok(None);
3554 };
3555 if beta.len() != self.cache.design().design.ncols() {
3556 return Ok(None);
3557 }
3558 let mut eta = self.cache.design().design.matrixvectormultiply(beta);
3559 if eta.len() != inputs.offset.len() {
3560 crate::bail_invalid_estim!(
3561 "frozen GLM tensor warm-state row mismatch: eta={}, offset={}",
3562 eta.len(),
3563 inputs.offset.len()
3564 );
3565 }
3566 eta += &inputs.offset;
3567 let obs = evaluate_standard_familyobservations(
3568 inputs.family.clone(),
3569 None,
3570 None,
3571 None,
3572 &inputs.y,
3573 &inputs.weights,
3574 &eta,
3575 )?;
3576 let mut working_response = obs.eta.clone();
3577 for i in 0..working_response.len() {
3578 let wi = obs.fisherweight[i].max(1e-12);
3579 working_response[i] += obs.score[i] / wi;
3580 }
3581 Ok(Some((obs.fisherweight, working_response)))
3582 }
3583
3584 fn frozen_glm_trial_weights(
3593 &mut self,
3594 beta: &Array1<f64>,
3595 ) -> Result<Option<Array1<f64>>, EstimationError> {
3596 if let Some((memo_beta, memo_w)) = self.frozen_glm_weight_memo.as_ref()
3597 && memo_beta.len() == beta.len()
3598 && memo_beta
3599 .iter()
3600 .zip(beta.iter())
3601 .all(|(a, b)| a.to_bits() == b.to_bits())
3602 {
3603 return Ok(Some(memo_w.clone()));
3604 }
3605 match self.frozen_glm_working_state(beta)? {
3606 Some((current_w, _)) => {
3607 self.frozen_glm_weight_memo = Some((beta.clone(), current_w.clone()));
3608 Ok(Some(current_w))
3609 }
3610 None => Ok(None),
3611 }
3612 }
3613
    /// Builds the frozen-W GLM ψ-Gram tensor at most once per context.
    ///
    /// Early exits:
    /// - No-op when a tensor already exists, a build was already attempted, or
    ///   `frozen_glm_psi_bounds` is unset.
    /// - If θ does not carry exactly one ψ coordinate
    ///   (`theta.len() != rho_dim + 1`), or the working state is unavailable,
    ///   the build is marked as attempted without building.
    /// - A missing `warm_beta` returns *without* marking, so a later call that
    ///   has a β can still build.
    ///
    /// The working weights and response are frozen at `warm_beta`. The
    /// design-realization callback given to the builder writes each probed ψ
    /// into a copy of `theta` and re-realizes the cache there. Afterwards the
    /// cache is restored to `theta`.
    ///
    /// # Errors
    /// Propagates working-state errors and a failure to restore the cache to `theta`.
    fn ensure_frozen_glm_tensor(
        &mut self,
        theta: &Array1<f64>,
        warm_beta: Option<&Array1<f64>>,
    ) -> Result<(), EstimationError> {
        if self.frozen_glm_tensor.is_some() || self.frozen_glm_tensor_attempted {
            return Ok(());
        }
        let Some((psi_lo, psi_hi)) = self.frozen_glm_psi_bounds else {
            return Ok(());
        };
        if theta.len() != self.rho_dim + 1 {
            self.frozen_glm_tensor_attempted = true;
            return Ok(());
        }
        // Not marked as attempted: retry once a warm β is available.
        let Some(beta) = warm_beta else {
            return Ok(());
        };
        let Some((frozen_w, working_z)) = self.frozen_glm_working_state(beta)? else {
            self.frozen_glm_tensor_attempted = true;
            return Ok(());
        };
        let theta_probe_base = theta.clone();
        let rho_dim = self.rho_dim;
        // Split the borrow so the callback can mutate `cache` while `evaluator` builds.
        let Self {
            cache, evaluator, ..
        } = self;
        let tensor = evaluator.build_frozen_glm_gram_tensor(
            |psi| {
                let mut theta_probe = theta_probe_base.clone();
                theta_probe[rho_dim] = psi;
                cache.ensure_theta(&theta_probe)?;
                Ok(cache.design().design.clone())
            },
            frozen_w.view(),
            working_z.view(),
            psi_lo,
            psi_hi,
        );
        // Probing moved the cache; restore it to the caller's θ.
        // NOTE(review): if this restore fails, `frozen_glm_tensor_attempted`
        // stays false and the (expensive) build is retried on the next call.
        self.cache
            .ensure_theta(theta)
            .map_err(EstimationError::InvalidInput)?;
        self.frozen_glm_tensor_attempted = true;
        if let Some(tensor) = tensor {
            self.frozen_glm_tensor = Some(tensor);
            log::info!(
                "[STAGE] {} certified frozen-W GLM ψ tensor over [{psi_lo:.3}, {psi_hi:.3}]",
                self.kind.label(),
            );
        } else {
            log::info!(
                "[STAGE] {} frozen-W GLM ψ tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]",
                self.kind.label(),
            );
        }
        Ok(())
    }
3677
    /// Stages per-trial n-free GLM statistics from the frozen-W tensor into the evaluator.
    ///
    /// Both stages are always written, with `None` when unavailable, so values
    /// from a previous trial are cleared. Staging only happens when all of the
    /// following hold:
    /// - θ has exactly one ψ coordinate;
    /// - the tensor contains that ψ;
    /// - trial weights can be computed from `warm_beta`.
    ///
    /// Within that, two quantities are staged independently:
    /// - the first Fisher-step Gram, if the trial weights are within
    ///   `FROZEN_GLM_WEIGHT_DRIFT_RTOL` of the frozen weights;
    /// - the ψ-derivative pair `(∂G/∂ψ, ∂b/∂ψ)`, if `allow_gradient` holds, the
    ///   tensor covers ψ for gradients, and the tensor judges the pair sound.
    ///
    /// # Errors
    /// Propagates errors from computing the trial weights.
    fn stage_frozen_glm_trial_statistics(
        &mut self,
        theta: &Array1<f64>,
        warm_beta: Option<&Array1<f64>>,
        allow_gradient: bool,
    ) -> Result<(), EstimationError> {
        let kind = self.kind;
        let mut staged_gram: Option<Array2<f64>> = None;
        let mut staged_deriv: Option<(Array2<f64>, Array1<f64>)> = None;
        if theta.len() == self.rho_dim + 1 {
            let psi = theta[self.rho_dim];
            let tensor_covers = self
                .frozen_glm_tensor
                .as_ref()
                .is_some_and(|t| t.contains(psi));
            // Only pay for the trial weights when the tensor could actually serve this ψ.
            let current_w = if tensor_covers {
                match warm_beta {
                    Some(beta) => self.frozen_glm_trial_weights(beta)?,
                    None => None,
                }
            } else {
                None
            };
            if let (Some(tensor), Some(current_w)) =
                (self.frozen_glm_tensor.as_ref(), current_w.as_ref())
            {
                const FROZEN_GLM_WEIGHT_DRIFT_RTOL: f64 = 1e-3;
                if tensor.weight_drift_within(current_w.view(), FROZEN_GLM_WEIGHT_DRIFT_RTOL) {
                    staged_gram = Some(tensor.gram_at(psi));
                    log::debug!(
                        "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
                         first-Fisher-step XᵀWX n-free (weight drift within tol)",
                        kind.label(),
                    );
                }
                // The gradient pair has its own soundness check, independent of the Gram drift test above.
                if allow_gradient
                    && tensor.contains_for_gradient(psi)
                    && let Some((dgram_dpsi, drhs_dpsi)) =
                        tensor.gradient_pair_if_sound(psi, current_w.view())
                {
                    staged_deriv = Some((dgram_dpsi, drhs_dpsi));
                    log::debug!(
                        "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
                         ψ-gradient (∂G/∂ψ, ∂b/∂ψ) n-free (gradient weight drift within \
                         tight tol); B_j stays exact",
                        kind.label(),
                    );
                }
            }
        }
        self.evaluator.stage_glm_first_step_gram(staged_gram);
        self.evaluator.stage_glm_psi_gram_deriv(staged_deriv);
        Ok(())
    }
3738
3739 fn eval_full(
3741 &mut self,
3742 theta: &Array1<f64>,
3743 order: gam_solve::rho_optimizer::OuterEvalOrder,
3744 analytic_outer_hessian_available: bool,
3745 ) -> Result<
3746 (
3747 f64,
3748 Array1<f64>,
3749 gam_problem::HessianResult,
3750 ),
3751 EstimationError,
3752 > {
3753 use gam_solve::rho_optimizer::OuterEvalOrder;
3754 let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
3755 && analytic_outer_hessian_available;
3756 if let Some(eval) = self.cache.memoized_eval(theta) {
3757 let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
3758 if cached_satisfies_order {
3759 return Ok(eval);
3760 }
3761 }
3762 let kind = self.kind;
3763 let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
3799 let skip_design_realization = !allow_second_order && theta.len() == self.rho_dim + 1 && {
3800 let psi = theta[self.rho_dim];
3801 self.evaluator.psi_gram_tensor_covers(psi)
3802 && self.evaluator.psi_gram_tensor_covers_gradient(psi)
3809 && self.evaluator.psi_gram_tensor_covers_skip(psi)
3826 && self.evaluator.supports_nfree_penalty_rekey()
3831 && nfree_fast_path_revision.is_some()
3832 };
3833 let skip_design_realization = false && skip_design_realization;
3837 log::warn!(
3838 "[OUTER-FD-AUDIT TEMP-SKIPOFF-1122] skip_design_realization={skip_design_realization}"
3839 );
3840 if skip_design_realization {
3841 log::debug!(
3842 "[STAGE] {} eval_full at psi={:.6}: skipping n×k design re-realization \
3843 + reconditioning — criterion/gradient/inner-solve served n-free from \
3844 the certified ψ-gram tensor (GaussianFixedCache + k-space ψ-derivatives)",
3845 kind.label(),
3846 theta[self.rho_dim],
3847 );
3848 } else {
3849 self.cache
3850 .ensure_theta(theta)
3851 .map_err(EstimationError::InvalidInput)?;
3852 }
3853 let warm_beta = self.evaluator.current_beta();
3854 self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref())?;
3855 self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), !allow_second_order)?;
3863 let hyper_dirs = if skip_design_realization {
3870 self.cache.nfree_tensor_gradient_hyper_dirs(theta)?
3871 } else {
3872 self.cache.hyper_dirs_for_current_design(self.data, kind)?
3873 };
3874
3875 let design_revision = if skip_design_realization {
3876 nfree_fast_path_revision
3877 } else {
3878 Some(self.cache.design_revision())
3879 };
3880 if self.evaluator.supports_nfree_penalty_rekey() {
3894 match self.cache.canonical_penalties_at(theta) {
3895 Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
3896 Err(e) => {
3897 log::warn!(
3898 "[STAGE] {} eval_full at psi={:.6}: exact n-free S(ψ) rebuild failed \
3899 ({e}); clearing stage (eval falls to slow path)",
3900 kind.label(),
3901 theta[self.rho_dim],
3902 );
3903 self.evaluator.stage_fast_path_penalty(None);
3904 }
3905 }
3906 }
3907 let eval = evaluate_joint_reml_outer_eval_at_theta(
3914 &mut self.evaluator,
3915 self.cache.design(),
3916 theta,
3917 self.rho_dim,
3918 hyper_dirs,
3919 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3920 if allow_second_order {
3921 order
3922 } else {
3923 OuterEvalOrder::ValueAndGradient
3924 },
3925 design_revision,
3926 );
3927 if let Ok(ref value) = eval {
3928 self.cache.store_eval_at(theta, value.clone());
3929 }
3930 eval
3931 }
3932
3933 fn eval_efs(
3934 &mut self,
3935 theta: &Array1<f64>,
3936 ) -> Result<gam_problem::EfsEval, EstimationError> {
3937 self.cache
3938 .ensure_theta(theta)
3939 .map_err(EstimationError::InvalidInput)?;
3940 let kind = self.kind;
3941 let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
3942 self.data,
3943 self.cache.spec(),
3944 self.cache.design(),
3945 &self.cache.spatial_terms,
3946 )?
3947 .ok_or_else(|| {
3948 EstimationError::InvalidInput(format!(
3949 "failed to build {} hyper_dirs for exact-joint EFS",
3950 kind.adjective(),
3951 ))
3952 })?;
3953 let design_revision = Some(self.cache.design_revision());
3954 let warm_beta = self.evaluator.current_beta();
3955 evaluate_joint_reml_efs_at_theta(
3956 &mut self.evaluator,
3957 self.cache.design(),
3958 theta,
3959 self.rho_dim,
3960 hyper_dirs,
3961 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3962 design_revision,
3963 )
3964 }
3965
3966 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
3972 if let Some(cost) = self.cache.memoized_cost(theta) {
3973 return cost;
3974 }
3975 let probe_start = std::time::Instant::now();
3990 let psi_distance = self
3991 .cache
3992 .current_theta
3993 .as_ref()
3994 .filter(|reference| reference.len() == theta.len())
3995 .map(|reference| {
3996 reference
3997 .iter()
3998 .zip(theta.iter())
3999 .map(|(a, b)| (a - b) * (a - b))
4000 .sum::<f64>()
4001 .sqrt()
4002 })
4003 .unwrap_or(f64::NAN);
4004 let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
4018 let skip_value_realization = theta.len() == self.rho_dim + 1 && {
4019 let psi = theta[self.rho_dim];
4020 self.evaluator.psi_gram_tensor_covers(psi)
4021 && self.evaluator.psi_gram_tensor_covers_skip(psi)
4030 && self.evaluator.supports_nfree_penalty_rekey()
4035 && nfree_fast_path_revision.is_some()
4036 };
4037 if theta.len() == self.rho_dim + 1
4038 && self.evaluator.has_psi_gram_tensor()
4039 && !self.evaluator.psi_gram_tensor_covers(theta[self.rho_dim])
4040 {
4041 self.cache.store_cost_at(theta, f64::INFINITY);
4042 return f64::INFINITY;
4043 }
4044 if !skip_value_realization && self.cache.ensure_theta(theta).is_err() {
4045 return f64::INFINITY;
4046 }
4047 if self.evaluator.supports_nfree_penalty_rekey() {
4053 match self.cache.canonical_penalties_at(theta) {
4054 Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
4055 Err(_) => self.evaluator.stage_fast_path_penalty(None),
4056 }
4057 }
4058 let warm_beta = self.evaluator.current_beta();
4059 if let Err(err) = self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref()) {
4060 log::warn!(
4061 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM tensor setup failed ({err}); \
4062 falling back to exact streamed Gram",
4063 self.kind.label(),
4064 if theta.len() > self.rho_dim {
4065 theta[self.rho_dim]
4066 } else {
4067 f64::NAN
4068 },
4069 );
4070 self.evaluator.stage_glm_first_step_gram(None);
4071 self.evaluator.stage_glm_psi_gram_deriv(None);
4072 } else if let Err(err) =
4073 self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), false)
4074 {
4075 log::warn!(
4076 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM staging failed ({err}); \
4077 falling back to exact streamed Gram",
4078 self.kind.label(),
4079 if theta.len() > self.rho_dim {
4080 theta[self.rho_dim]
4081 } else {
4082 f64::NAN
4083 },
4084 );
4085 self.evaluator.stage_glm_first_step_gram(None);
4086 self.evaluator.stage_glm_psi_gram_deriv(None);
4087 }
4088 let design_revision = if skip_value_realization {
4089 nfree_fast_path_revision
4090 } else {
4091 Some(self.cache.design_revision())
4092 };
4093 let cost_label = self.kind.label();
4094 let result = {
4095 let design = self.cache.design();
4096 self.evaluator.evaluate_cost_only(
4097 &design.design,
4098 &design.penalties,
4099 &design.nullspace_dims,
4100 design.linear_constraints.clone(),
4101 theta,
4102 self.rho_dim,
4103 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
4104 cost_label,
4105 design_revision,
4106 )
4107 };
4108 match result {
4109 Ok(cost) => {
4110 log::debug!(
4111 "[STAGE] {cost_label} value-probe (order=Value): elapsed={:.3}s \
4112 cost={cost:.6e} trial_theta_distance={psi_distance:.3e}",
4113 probe_start.elapsed().as_secs_f64(),
4114 );
4115 self.cache.store_cost_at(theta, cost);
4116 cost
4117 }
4118 Err(_) => f64::INFINITY,
4119 }
4120 }
4121
4122 fn reset(&mut self) {
4123 self.cache.current_theta = None;
4124 self.cache.last_eval_theta = None;
4125 self.cache.last_cost = None;
4126 self.cache.last_eval = None;
4127 }
4128}
4129
4130enum SpatialJointOutcome {
4163 Optimized {
4167 theta_star: Array1<f64>,
4168 final_value: f64,
4169 },
4170 NonConverged {
4174 iterations: usize,
4175 final_value: f64,
4176 final_grad_norm: Option<f64>,
4177 },
4178}
4179
4180fn kphase_log_norms(theta: &Array1<f64>, rho_dim: usize) -> (f64, f64) {
4181 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
4182 let log_kappa_norm = theta
4183 .iter()
4184 .skip(rho_dim)
4185 .map(|v| v * v)
4186 .sum::<f64>()
4187 .sqrt();
4188 (theta_norm, log_kappa_norm)
4189}
4190
4191fn run_exact_joint_spatial_optimization(
4192 kind: SpatialHyperKind,
4193 data: ArrayView2<'_, f64>,
4194 y: ArrayView1<'_, f64>,
4195 weights: ArrayView1<'_, f64>,
4196 offset: ArrayView1<'_, f64>,
4197 resolvedspec: &TermCollectionSpec,
4198 baseline_design: &TermCollectionDesign,
4199 family: LikelihoodSpec,
4200 options: &FitOptions,
4201 spatial_terms: &[usize],
4202 dims_per_term: &[usize],
4203 theta0: &Array1<f64>,
4204 lower: &Array1<f64>,
4205 upper: &Array1<f64>,
4206 rho_dim: usize,
4207 kappa_options: &SpatialLengthScaleOptimizationOptions,
4208) -> Result<(SpatialJointOutcome, SpatialLengthScaleOptimizationTiming), EstimationError> {
4209 let label = kind.label();
4210 assert!(
4212 lower.len() == theta0.len() && upper.len() == theta0.len(),
4213 "spatial hyperparameter bounds must match theta length: lower_len={}, upper_len={}, theta_len={}",
4214 lower.len(),
4215 upper.len(),
4216 theta0.len()
4217 );
4218 assert!(
4219 baseline_design.smooth.terms.len() >= spatial_terms.len(),
4220 "baseline design must have at least one smooth term per spatial term: baseline_terms={}, spatial_terms={}",
4221 baseline_design.smooth.terms.len(),
4222 spatial_terms.len()
4223 );
4224 use gam_solve::rho_optimizer::OuterEvalOrder;
4225 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
4226
4227 let theta_dim = theta0.len();
4228 let coord_dim = theta_dim - rho_dim;
4231 let analytic_outer_hessian_available =
4241 exact_joint_spatial_outer_hessian_available(&family, baseline_design);
4242 if !analytic_outer_hessian_available {
4243 log::info!(
4244 "[{label}] analytic outer Hessian unavailable for family/design; routing without second-order geometry (coord_dim={coord_dim})"
4245 );
4246 }
4247 let mut prefer_gradient_only = theta_dim > EXACT_JOINT_SECOND_ORDER_THETA_CAP;
4253 if prefer_gradient_only {
4254 log::info!(
4255 "[{label}] joint θ-dim {theta_dim} exceeds the exact pair-Hessian budget \
4256 ({EXACT_JOINT_SECOND_ORDER_THETA_CAP}); routing gradient-only quasi-Newton"
4257 );
4258 }
4259 let mut suppress_outer_hessian_for_nfree = false;
4269
4270 log::trace!(
4271 "[{}] starting analytic optimization: rho_dim={}, coord_dim={}, dims_per_term={:?}",
4272 label,
4273 rho_dim,
4274 coord_dim,
4275 dims_per_term,
4276 );
4277
4278 let mut ctx = SpatialJointContext {
4279 data,
4280 rho_dim,
4281 kind,
4282 cache: SingleBlockExactJointDesignCache::new(
4283 data,
4284 resolvedspec.clone(),
4285 baseline_design.clone(),
4286 spatial_terms.to_vec(),
4287 rho_dim,
4288 dims_per_term.to_vec(),
4289 )
4290 .map_err(EstimationError::InvalidInput)?,
4291 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
4292 y,
4293 weights,
4294 &baseline_design.design,
4295 offset,
4296 &baseline_design.penalties,
4297 &external_opts_for_design(&family, baseline_design, options),
4298 label,
4299 )?,
4300 frozen_glm_inputs: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
4301 Some(SpatialFrozenGlmInputs {
4302 y: y.to_owned(),
4303 weights: weights.to_owned(),
4304 offset: offset.to_owned(),
4305 family: family.clone(),
4306 })
4307 } else {
4308 None
4309 },
4310 frozen_glm_psi_bounds: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
4311 Some((lower[rho_dim], upper[rho_dim]))
4312 } else {
4313 None
4314 },
4315 frozen_glm_tensor: None,
4316 frozen_glm_tensor_attempted: false,
4317 frozen_glm_weight_memo: None,
4318 };
4319
4320 let mut psi_rank_stable_floor: Option<f64> = None;
4343 let mut psi_rank_stable_ceiling: Option<f64> = None;
4352 let nfree_penalty_capable = coord_dim == 1
4353 && family.is_gaussian_identity()
4354 && ctx.cache.supports_nfree_penalty_rekey();
4355 if nfree_penalty_capable {
4356 let psi_lo = lower[rho_dim];
4357 let psi_hi = upper[rho_dim];
4358 let z = Array1::from_iter(y.iter().zip(offset.iter()).map(|(yi, oi)| yi - oi));
4359 let theta_probe_base = theta0.clone();
4360 let SpatialJointContext {
4363 cache, evaluator, ..
4364 } = &mut ctx;
4365 let attached = evaluator.build_and_set_psi_gram_tensor(
4366 |psi| {
4367 let mut theta_probe = theta_probe_base.clone();
4368 theta_probe[rho_dim] = psi;
4369 cache.ensure_theta(&theta_probe)?;
4370 Ok(cache.design().design.clone())
4371 },
4372 weights,
4373 z.view(),
4374 psi_lo,
4375 psi_hi,
4376 );
4377 if attached {
4378 log::info!(
4379 "[{label}] certified ψ-gram tensor over [{psi_lo:.3}, {psi_hi:.3}]: \
4380 in-window trials assemble Gaussian sufficient statistics n-free"
4381 );
4382 let psi_anchor = theta0[rho_dim];
4387 psi_rank_stable_floor = evaluator
4388 .psi_gram_rank_stable_floor(psi_anchor)
4389 .filter(|&f| f.is_finite() && f > psi_lo && f < psi_anchor);
4390 log::info!(
4391 "[KAPPA-PHASE-FLOOR] n_rows={} psi_lo={psi_lo:.6} psi_anchor={psi_anchor:.6} \
4392 rank_stable_floor={:?} lifted={}",
4393 data.nrows(),
4394 evaluator.psi_gram_rank_stable_floor(psi_anchor),
4395 psi_rank_stable_floor.is_some(),
4396 );
4397 if let Some(floor) = psi_rank_stable_floor {
4398 log::info!(
4399 "[{label}] rank-stable κ-floor ψ_floor={floor:.6} > window floor \
4400 ψ_lo={psi_lo:.6}: lifting the optimizer lower bound to keep every \
4401 in-window trial on the n-free design-realization skip (#1033). The \
4402 conditioned Gram is rank-deficient below ψ_floor (longest-length-scale \
4403 radial mode collapses into the nullspace), where the skip is soundly \
4404 refused; that band drifts with n via the sample-std standardization, \
4405 so this n-free k-space floor is the n-independent fix."
4406 );
4407 }
4408 psi_rank_stable_ceiling = evaluator
4417 .psi_gram_rank_stable_ceiling(psi_anchor)
4418 .filter(|&c| c.is_finite() && c < psi_hi && c > psi_anchor);
4419 log::info!(
4420 "[KAPPA-PHASE-CEIL] n_rows={} psi_hi={psi_hi:.6} psi_anchor={psi_anchor:.6} \
4421 rank_stable_ceiling={:?} clamped={}",
4422 data.nrows(),
4423 evaluator.psi_gram_rank_stable_ceiling(psi_anchor),
4424 psi_rank_stable_ceiling.is_some(),
4425 );
4426 if let Some(ceiling) = psi_rank_stable_ceiling {
4427 log::info!(
4428 "[{label}] rank-stable κ-ceiling ψ_ceil={ceiling:.6} < window ceiling \
4429 ψ_hi={psi_hi:.6}: clamping the optimizer upper bound to keep every \
4430 in-window trial on the n-free design-realization skip (#1033). The \
4431 conditioned Gram is rank-deficient above ψ_ceil (longest-frequency \
4432 radial mode goes collinear), where the skip is soundly refused; a \
4433 line-search overshoot there trips the O(n) reset_surface lane (and the \
4434 deficient pinning ψ it records resets the next in-band trial too)."
4435 );
4436 }
4437 let gradient_covers_full_window = evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4438 && evaluator.psi_gram_tensor_covers_gradient(psi_hi);
4439 if gradient_covers_full_window {
4440 log::info!(
4441 "[{label}] certified ψ-gram tensor gradient lane covers the full \
4442 optimizer window [{psi_lo:.3}, {psi_hi:.3}]"
4443 );
4444 } else {
4445 log::info!(
4446 "[{label}] ψ-gram tensor value lane certified, but the gradient lane \
4447 does not cover the full optimizer window [{psi_lo:.3}, {psi_hi:.3}]; \
4448 keeping exact streamed kappa routing"
4449 );
4450 }
4451 evaluator.set_supports_nfree_penalty_rekey(true);
4471 log::info!(
4472 "[{label}] exact n-free ψ-penalty re-key enabled over [{psi_lo:.3}, \
4473 {psi_hi:.3}]: in-window fast-path trials rebuild S(ψ) n-free from frozen \
4474 geometry (no reset_surface)"
4475 );
4476 } else {
4477 log::info!(
4478 "[{label}] ψ-gram tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]; \
4479 keeping the exact per-trial path"
4480 );
4481 }
4482 if attached
4503 && evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4504 && evaluator.psi_gram_tensor_covers_gradient(psi_hi)
4505 && evaluator.supports_nfree_penalty_rekey()
4506 && cache.supports_nfree_gradient_only_routing()
4507 {
4508 suppress_outer_hessian_for_nfree = true;
4509 prefer_gradient_only = true;
4510 log::info!(
4511 "[{label}] n-free Gaussian ψ-lane armed; suppressing the analytic outer \
4512 Hessian and routing gradient-only (BFGS) so the κ outer loop never realizes \
4513 the O(n) second-order slab — n-independent outer loop (#1033)"
4514 );
4515 }
4516 } else if coord_dim == 1 && family.is_gaussian_identity() {
4517 log::info!(
4518 "[{label}] exact n-free ψ-penalty re-key unavailable; skipping ψ-gram tensor \
4519 attachment so value, gradient, and Hessian remain on the same exact streamed \
4520 objective"
4521 );
4522 }
4523
4524 const OUTER_FD_AUDIT_MAX_N: usize = 4_000; const OUTER_FD_AUDIT_MAX_THETA_DIM: usize = 32; let n_total = data.nrows();
4552 let outer_fd_audit_eligible = log::log_enabled!(log::Level::Info) && analytic_outer_hessian_available && n_total <= OUTER_FD_AUDIT_MAX_N && theta_dim <= OUTER_FD_AUDIT_MAX_THETA_DIM; log::info!(
4557 "[OUTER-FD-AUDIT/spatial-exact-joint] gate eligible={outer_fd_audit_eligible} \
4558 analytic_grad={analytic_outer_hessian_available} n_total={n_total} \
4559 theta_dim={theta_dim} rho_dim={rho_dim} psi_dim={coord_dim}"
4560 );
4561 if outer_fd_audit_eligible {
4562 let audit = (|| -> Result<gam_solve::rho_optimizer::OuterGradientFdAudit, String> {
4564 let mut eval_at = |theta: &Array1<f64>,
4565 mode: gam_solve::estimate::reml::reml_outer_engine::EvalMode|
4566 -> Result<
4567 (
4568 f64,
4569 Array1<f64>,
4570 gam_problem::HessianResult,
4571 ),
4572 String,
4573 > {
4574 use gam_solve::estimate::reml::reml_outer_engine::EvalMode;
4575 let order = if matches!(mode, EvalMode::ValueGradientHessian) {
4576 OuterEvalOrder::ValueGradientHessian
4577 } else {
4578 OuterEvalOrder::Value
4579 };
4580 ctx.eval_full(theta, order, analytic_outer_hessian_available)
4581 .map_err(|e| format!("fd-audit eval_full: {e}"))
4582 };
4583 let rho_dim_audit = rho_dim;
4584 let label_fn = move |i: usize| -> String {
4585 if i < rho_dim_audit {
4586 format!("rho[{i}]")
4587 } else {
4588 format!("psi_kappa[{}]", i - rho_dim_audit)
4589 }
4590 };
4591 gam_solve::rho_optimizer::outer_gradient_fd_audit(
4592 theta0,
4594 1e-4,
4595 label_fn,
4596 &mut eval_at,
4597 )
4598 })();
4599 match audit {
4601 Ok(audit) => audit.log_verdict("spatial-exact-joint"),
4602 Err(e) => log::warn!("[OUTER-FD-AUDIT/spatial-exact-joint] skipped: {e}"),
4603 }
4604 }
4605
4606 let kphase_prime_order = if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4607 OuterEvalOrder::ValueGradientHessian
4608 } else {
4609 OuterEvalOrder::ValueAndGradient
4610 };
4611 let kphase_prime_start = std::time::Instant::now();
4612 drop(ctx.eval_full(theta0, kphase_prime_order, analytic_outer_hessian_available)?);
4613 log::info!(
4614 "[KAPPA-PHASE-PRIME] n_rows={} order={:?} elapsed_s={:.4} slow_path_resets_total={} design_revision={}",
4615 data.nrows(),
4616 kphase_prime_order,
4617 kphase_prime_start.elapsed().as_secs_f64(),
4618 ctx.evaluator.slow_path_reset_count(),
4619 ctx.cache.design_revision(),
4620 );
4621
4622 let kphase_cost_calls = std::cell::Cell::new(0usize);
4623 let kphase_eval_calls = std::cell::Cell::new(0usize);
4624 let kphase_efs_calls = std::cell::Cell::new(0usize);
4625 let kphase_cost_total_s = std::cell::Cell::new(0.0);
4626 let kphase_eval_total_s = std::cell::Cell::new(0.0);
4627 let kphase_efs_total_s = std::cell::Cell::new(0.0);
4628 let kphase_nfree_miss_shape = std::cell::Cell::new(0u64);
4629 let kphase_nfree_miss_value = std::cell::Cell::new(0u64);
4630 let kphase_nfree_miss_gradient = std::cell::Cell::new(0u64);
4631 let kphase_nfree_miss_penalty = std::cell::Cell::new(0u64);
4632 let kphase_nfree_miss_revision = std::cell::Cell::new(0u64);
4633 let kphase_nfree_miss_second_order = std::cell::Cell::new(0u64);
4634 let kphase_nfree_miss_other = std::cell::Cell::new(0u64);
4635 let kphase_optim_start = std::time::Instant::now();
4636 let kphase_log_kappa_dim = coord_dim;
4637 let kphase_slow_resets_start = ctx.evaluator.slow_path_reset_count();
4638 let kphase_design_revision_start = ctx.cache.design_revision();
4639
4640 let lower_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_floor {
4647 Some(floor) if coord_dim == 1 && floor > lower[rho_dim] => {
4648 let mut lifted = lower.clone();
4649 lifted[rho_dim] = floor;
4650 std::borrow::Cow::Owned(lifted)
4651 }
4652 _ => std::borrow::Cow::Borrowed(lower),
4653 };
4654 let lower = lower_effective.as_ref();
4655
4656 let upper_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_ceiling {
4664 Some(ceiling) if coord_dim == 1 && ceiling < upper[rho_dim] => {
4665 let mut clamped = upper.clone();
4666 clamped[rho_dim] = ceiling;
4667 std::borrow::Cow::Owned(clamped)
4668 }
4669 _ => std::borrow::Cow::Borrowed(upper),
4670 };
4671 let upper = upper_effective.as_ref();
4672
4673 let problem = exact_joint_multistart_outer_problem(
4674 theta0,
4675 lower,
4676 upper,
4677 rho_dim,
4678 coord_dim,
4679 theta_dim,
4680 Derivative::Analytic,
4681 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4682 DeclaredHessianForm::Either
4683 } else {
4684 DeclaredHessianForm::Unavailable
4689 },
4690 prefer_gradient_only,
4691 suppress_outer_hessian_for_nfree,
4702 seed_risk_profile_for_likelihood_family(&family),
4703 kappa_options.rel_tol.max(1e-6),
4704 kappa_options.max_outer_iter.max(1),
4705 Some(5.0),
4709 Some(kappa_options.log_step.clamp(0.25, 1.0)),
4711 None,
4712 Some((data.nrows(), baseline_design.design.ncols())),
4717 !constant_curvature_term_indices(resolvedspec).is_empty(),
4721 );
4722
4723 let eval_outer = |ctx: &mut &mut SpatialJointContext<'_>,
4724 theta: &Array1<f64>,
4725 order: OuterEvalOrder|
4726 -> Result<OuterEval, EstimationError> {
4727 let t0 = std::time::Instant::now();
4728 let allow_second_order_for_call = matches!(order, OuterEvalOrder::ValueGradientHessian)
4729 && analytic_outer_hessian_available;
4730 let gate = ctx.nfree_skip_gate_status(theta, allow_second_order_for_call, true);
4731 let resets_before = ctx.evaluator.slow_path_reset_count();
4732 let raw = ctx.eval_full(theta, order, analytic_outer_hessian_available);
4733 let reset_delta = ctx
4734 .evaluator
4735 .slow_path_reset_count()
4736 .saturating_sub(resets_before);
4737 if reset_delta > 0 {
4738 if !gate.shape {
4739 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4740 }
4741 if gate.shape && !gate.value {
4742 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4743 }
4744 if gate.shape && gate.value && !gate.gradient {
4745 kphase_nfree_miss_gradient.set(kphase_nfree_miss_gradient.get() + reset_delta);
4746 }
4747 if gate.shape && gate.value && gate.gradient && !gate.penalty {
4748 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4749 }
4750 if gate.shape && gate.value && gate.gradient && gate.penalty && !gate.revision {
4751 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4752 }
4753 if gate.shape
4754 && gate.value
4755 && gate.gradient
4756 && gate.penalty
4757 && gate.revision
4758 && gate.second_order
4759 {
4760 kphase_nfree_miss_second_order
4761 .set(kphase_nfree_miss_second_order.get() + reset_delta);
4762 }
4763 if gate.would_skip(true) {
4764 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4765 }
4766 }
4767 let elapsed_s = t0.elapsed().as_secs_f64();
4768 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
4769 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
4770 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4771 log::info!(
4772 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4773 kphase_eval_calls.get(),
4774 order,
4775 Some(ctx.cache.design_revision()),
4776 theta_norm,
4777 log_kappa_norm,
4778 elapsed_s,
4779 );
4780 match raw {
4781 Ok((cost, grad, hess)) => Ok(OuterEval {
4782 cost,
4783 gradient: grad,
4784 hessian: hess,
4785 inner_beta_hint: None,
4786 }),
4787 Err(err) if is_recoverable_trial_point_error(&err) => {
4795 log::debug!(
4796 "[{label}] trial point infeasible (kernel design \
4797 not constructible at theta={theta:?}): {err}; retreating",
4798 );
4799 Ok(OuterEval::infeasible(theta_dim))
4800 }
4801 Err(err) => Err(err),
4802 }
4803 };
4804
4805 let mut obj = problem.build_objective_with_eval_order(
4806 &mut ctx,
4807 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4808 let t0 = std::time::Instant::now();
4809 let gate = ctx.nfree_skip_gate_status(theta, false, false);
4810 let resets_before = ctx.evaluator.slow_path_reset_count();
4811 let cost = ctx.eval_cost(theta);
4812 let reset_delta = ctx
4813 .evaluator
4814 .slow_path_reset_count()
4815 .saturating_sub(resets_before);
4816 if reset_delta > 0 {
4817 if !gate.shape {
4818 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4819 }
4820 if gate.shape && !gate.value {
4821 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4822 }
4823 if gate.shape && gate.value && !gate.penalty {
4824 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4825 }
4826 if gate.shape && gate.value && gate.penalty && !gate.revision {
4827 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4828 }
4829 if gate.would_skip(false) {
4830 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4831 }
4832 }
4833 let elapsed_s = t0.elapsed().as_secs_f64();
4834 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
4835 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
4836 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4837 log::info!(
4838 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4839 kphase_cost_calls.get(),
4840 Some(ctx.cache.design_revision()),
4841 theta_norm,
4842 log_kappa_norm,
4843 elapsed_s,
4844 );
4845 Ok(cost)
4846 },
4847 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4848 eval_outer(
4849 ctx,
4850 theta,
4851 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4861 OuterEvalOrder::ValueGradientHessian
4862 } else {
4863 OuterEvalOrder::ValueAndGradient
4864 },
4865 )
4866 },
4867 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
4868 eval_outer(ctx, theta, order)
4869 },
4870 Some(|ctx: &mut &mut SpatialJointContext<'_>| {
4871 ctx.reset();
4872 }),
4873 Some(|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4874 let t0 = std::time::Instant::now();
4875 let eval = ctx.eval_efs(theta);
4876 let elapsed_s = t0.elapsed().as_secs_f64();
4877 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
4878 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
4879 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4880 log::info!(
4881 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4882 kphase_efs_calls.get(),
4883 Some(ctx.cache.design_revision()),
4884 theta_norm,
4885 log_kappa_norm,
4886 elapsed_s,
4887 );
4888 eval
4889 }),
4890 );
4891
4892 let run_label = match kind {
4893 SpatialHyperKind::Anisotropic => "aniso-psi joint REML",
4894 SpatialHyperKind::Isotropic => "iso-kappa joint REML",
4895 };
4896 let result = problem.run(&mut obj, run_label).map_err(|e| {
4897 EstimationError::InvalidInput(format!(
4898 "{} analytic optimization failed after exhausting strategy fallbacks: {e}",
4899 kind.adjective(),
4900 ))
4901 })?;
4902 drop(obj);
4903 let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
4904 let kphase_slow_resets = ctx
4905 .evaluator
4906 .slow_path_reset_count()
4907 .saturating_sub(kphase_slow_resets_start);
4908 let kphase_design_revision_delta = ctx
4909 .cache
4910 .design_revision()
4911 .saturating_sub(kphase_design_revision_start);
4912 log::info!(
4913 "[KAPPA-PHASE-SUMMARY] n_rows={} log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} slow_path_resets={} design_revision_delta={} nfree_miss_shape={} nfree_miss_value={} nfree_miss_gradient={} nfree_miss_penalty={} nfree_miss_revision={} nfree_miss_second_order={} nfree_miss_other={} optim_total_s={:.4}",
4914 data.nrows(),
4915 kphase_log_kappa_dim,
4916 kphase_cost_calls.get(),
4917 kphase_cost_total_s.get(),
4918 kphase_eval_calls.get(),
4919 kphase_eval_total_s.get(),
4920 kphase_efs_calls.get(),
4921 kphase_efs_total_s.get(),
4922 kphase_slow_resets,
4923 kphase_design_revision_delta,
4924 kphase_nfree_miss_shape.get(),
4925 kphase_nfree_miss_value.get(),
4926 kphase_nfree_miss_gradient.get(),
4927 kphase_nfree_miss_penalty.get(),
4928 kphase_nfree_miss_revision.get(),
4929 kphase_nfree_miss_second_order.get(),
4930 kphase_nfree_miss_other.get(),
4931 kphase_total_s,
4932 );
4933 let timing = SpatialLengthScaleOptimizationTiming {
4934 log_kappa_dim: kphase_log_kappa_dim,
4935 cost_calls: kphase_cost_calls.get(),
4936 cost_total_s: kphase_cost_total_s.get(),
4937 eval_calls: kphase_eval_calls.get(),
4938 eval_total_s: kphase_eval_total_s.get(),
4939 efs_calls: kphase_efs_calls.get(),
4940 efs_total_s: kphase_efs_total_s.get(),
4941 slow_path_resets: kphase_slow_resets,
4942 design_revision_delta: kphase_design_revision_delta,
4943 nfree_miss_shape: kphase_nfree_miss_shape.get(),
4944 nfree_miss_value: kphase_nfree_miss_value.get(),
4945 nfree_miss_gradient: kphase_nfree_miss_gradient.get(),
4946 nfree_miss_penalty: kphase_nfree_miss_penalty.get(),
4947 nfree_miss_revision: kphase_nfree_miss_revision.get(),
4948 nfree_miss_second_order: kphase_nfree_miss_second_order.get(),
4949 nfree_miss_other: kphase_nfree_miss_other.get(),
4950 optim_total_s: kphase_total_s,
4951 };
4952 if !result.converged {
4953 let rel_to_cost_threshold = options.tol * (1.0_f64 + result.final_value.abs());
4964 if let Some(final_grad) = result
4965 .final_grad_norm
4966 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
4967 {
4968 log::info!(
4969 "[{}] outer optimization hit max_iter={} but \
4970 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
4971 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
4972 relative-to-cost REML convergence criterion.",
4973 label,
4974 result.iterations,
4975 final_grad,
4976 rel_to_cost_threshold,
4977 options.tol,
4978 result.final_value.abs(),
4979 );
4980 } else if result.final_value.is_finite() {
4981 log::warn!(
4996 "[{}] {} did not converge after {} iterations \
4997 (final_objective={:.6e}, final_grad_norm={}); keeping the \
4998 frozen baseline geometry instead of aborting the fit.",
4999 label,
5000 kind.adjective(),
5001 result.iterations,
5002 result.final_value,
5003 result.final_grad_norm_report(),
5004 );
5005 return Ok((
5006 SpatialJointOutcome::NonConverged {
5007 iterations: result.iterations,
5008 final_value: result.final_value,
5009 final_grad_norm: result.final_grad_norm,
5010 },
5011 timing,
5012 ));
5013 } else {
5014 crate::bail_invalid_estim!(
5019 "{} analytic optimization diverged after {} iterations (final_objective={:.6e}, final_grad_norm={})",
5020 kind.adjective(),
5021 result.iterations,
5022 result.final_value,
5023 result.final_grad_norm_report(),
5024 );
5025 }
5026 }
5027 log::trace!(
5028 "[{}] converged in {} iterations, final_value={:.6e}, grad_norm={}",
5029 label,
5030 result.iterations,
5031 result.final_value,
5032 result.final_grad_norm_report(),
5033 );
5034 let theta_star = result.rho;
5038 Ok((
5039 SpatialJointOutcome::Optimized {
5040 theta_star,
5041 final_value: result.final_value,
5042 },
5043 timing,
5044 ))
5045}
5046
5047fn set_single_term_spatial_length_scale(
5051 term: &mut SmoothTermSpec,
5052 length_scale: f64,
5053) -> Result<(), EstimationError> {
5054 match &mut term.basis {
5055 SmoothBasisSpec::ThinPlate { spec, .. } => {
5056 spec.length_scale = length_scale;
5057 Ok(())
5058 }
5059 SmoothBasisSpec::Matern { spec, .. } => {
5060 spec.length_scale = length_scale;
5061 Ok(())
5062 }
5063 SmoothBasisSpec::Duchon { spec, .. } => {
5064 spec.length_scale = Some(length_scale);
5065 Ok(())
5066 }
5067 _ => Err(EstimationError::InvalidInput(format!(
5068 "term '{}' does not expose a spatial length scale",
5069 term.name
5070 ))),
5071 }
5072}
5073
5074fn set_single_term_spatial_aniso_log_scales(
5078 term: &mut SmoothTermSpec,
5079 eta: Vec<f64>,
5080) -> Result<(), EstimationError> {
5081 let eta = center_aniso_log_scales(&eta);
5082 match &mut term.basis {
5083 SmoothBasisSpec::Matern { spec, .. } => {
5084 spec.aniso_log_scales = Some(eta);
5085 Ok(())
5086 }
5087 SmoothBasisSpec::Duchon { spec, .. } => {
5088 spec.aniso_log_scales = Some(eta);
5089 Ok(())
5090 }
5091 _ => Err(EstimationError::InvalidInput(format!(
5092 "term '{}' does not support aniso_log_scales",
5093 term.name
5094 ))),
5095 }
5096}
5097
5098pub fn get_constant_curvature_kappa(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
5117 constant_curvature_term_spec(spec, term_idx).map(|cc| cc.kappa)
5118}
5119
5120pub fn constant_curvature_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
5122 (0..spec.smooth_terms.len())
5123 .filter(|&idx| constant_curvature_term_spec(spec, idx).is_some())
5124 .collect()
5125}
5126
5127
/// One smooth term built in isolation: its own design block, the realized
/// `SmoothTerm`, and the penalty blocks that were dropped while building it.
#[derive(Debug, Clone)]
struct SingleSmoothTermRealization {
    /// Design matrix for this term alone (presumably in the term's local column
    /// coordinates, per the name — confirm against callers).
    design_local: DesignMatrix,
    /// The realized term; its `penaltyinfo_local` lists both active and inactive penalties.
    term: SmoothTerm,
    /// Penalty blocks not carried as active penalties for this term.
    dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
}
5134
5135impl SingleSmoothTermRealization {
5136 fn active_penaltyinfo(&self) -> Vec<PenaltyInfo> {
5137 self.term
5138 .penaltyinfo_local
5139 .iter()
5140 .filter(|info| info.active)
5141 .cloned()
5142 .collect()
5143 }
5144}
5145
5146fn build_single_smooth_term_realization(
5147 data: ArrayView2<'_, f64>,
5148 termspec: &SmoothTermSpec,
5149) -> Result<SingleSmoothTermRealization, BasisError> {
5150 let raw = build_smooth_design(data, std::slice::from_ref(termspec))?;
5151 finish_single_smooth_term_realization(raw)
5152}
5153
5154fn finish_single_smooth_term_realization(
5155 raw: RawSmoothDesign,
5156) -> Result<SingleSmoothTermRealization, BasisError> {
5157 let RawSmoothDesign {
5158 term_designs,
5159 dropped_penaltyinfo,
5160 terms,
5161 ..
5162 } = raw;
5163 let term = terms.into_iter().next().ok_or_else(|| {
5164 BasisError::InvalidInput("single-term smooth build returned no term".to_string())
5165 })?;
5166 let design = term_designs.into_iter().next().ok_or_else(|| {
5167 BasisError::InvalidInput("single-term smooth build returned no term design".to_string())
5168 })?;
5169
5170 Ok(SingleSmoothTermRealization {
5171 design_local: design,
5172 term,
5173 dropped_penaltyinfo,
5174 })
5175}
5176
5177fn wrap_local_build_as_realization(
5184 mut local: LocalSmoothTermBuild,
5185 termspec: &SmoothTermSpec,
5186) -> Result<SingleSmoothTermRealization, String> {
5187 let p_local = local.dim;
5188 let lb_local = if local.box_reparam {
5189 shape_lower_bounds_local(termspec.shape, p_local)
5190 } else {
5191 None
5192 };
5193
5194 let active_count = local.penaltyinfo.iter().filter(|info| info.active).count();
5195 if active_count != local.penalties.len() {
5196 return Err(format!(
5197 "internal penalty info mismatch for term '{}': active_infos={}, penalties={}",
5198 termspec.name,
5199 active_count,
5200 local.penalties.len()
5201 ));
5202 }
5203
5204 let mut dropped_penaltyinfo = Vec::<DroppedPenaltyBlockInfo>::new();
5205 for info in local.penaltyinfo.iter().filter(|info| !info.active) {
5206 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
5207 termname: Some(termspec.name.clone()),
5208 penalty: info.clone(),
5209 });
5210 }
5211 for info in &local.pre_dropped_penaltyinfo {
5212 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
5213 termname: Some(termspec.name.clone()),
5214 penalty: info.clone(),
5215 });
5216 }
5217
5218 let applied_rotation: Option<gam_terms::basis::JointNullRotation> = match (
5222 local.joint_null_rotation.take(),
5223 lb_local.is_some(),
5224 local.linear_constraints.is_some(),
5225 ) {
5226 (Some(rot), false, false) => {
5227 let q = &rot.rotation;
5228 let dense = local
5229 .design
5230 .try_to_dense_by_chunks("joint-null absorption rotation (single realization)")
5231 .map_err(|e| {
5232 format!(
5233 "joint-null absorption rotation: dense conversion failed for term '{}': {}",
5234 termspec.name, e
5235 )
5236 })?;
5237 let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
5238 local.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
5239 local.penalties = local
5240 .penalties
5241 .into_iter()
5242 .map(|s_local| {
5243 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
5244 gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
5245 })
5246 .collect();
5247 local.ops = vec![None; local.penalties.len()];
5248 local.kronecker_factored = None;
5249 Some(rot)
5250 }
5251 (Some(_), _, _) => None,
5252 (None, _, _) => None,
5253 };
5254
5255 let smooth_term = SmoothTerm {
5256 name: termspec.name.clone(),
5257 coeff_range: 0..p_local,
5258 shape: termspec.shape,
5259 penalties_local: local.penalties.clone(),
5260 nullspace_dims: local.nullspaces.clone(),
5261 penaltyinfo_local: local.penaltyinfo.clone(),
5262 metadata: local.metadata.clone(),
5263 lower_bounds_local: lb_local,
5264 linear_constraints_local: local.linear_constraints.clone(),
5265 kronecker_factored: local.kronecker_factored.take(),
5266 joint_null_rotation: applied_rotation,
5267 unabsorbed_global_orthogonality: None,
5270 };
5271
5272 Ok(SingleSmoothTermRealization {
5273 design_local: local.design,
5274 term: smooth_term,
5275 dropped_penaltyinfo,
5276 })
5277}
5278
5279fn freeze_geometry_from_metadata(
5290 termspec: &SmoothTermSpec,
5291 metadata: &BasisMetadata,
5292) -> Option<SmoothTermSpec> {
5293 let mut frozen = termspec.clone();
5294 match (&mut frozen.basis, metadata) {
5295 (
5296 SmoothBasisSpec::Matern {
5297 spec,
5298 input_scales: spec_scales,
5299 ..
5300 },
5301 BasisMetadata::Matern {
5302 centers,
5303 input_scales: meta_scales,
5304 identifiability_transform,
5305 nullspace_shrinkage_survived,
5306 ..
5307 },
5308 ) => {
5309 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5310 if spec_scales.is_none()
5311 && let Some(s) = meta_scales.clone()
5312 {
5313 *spec_scales = Some(s);
5314 }
5315 if let Some(transform) = identifiability_transform.clone() {
5333 spec.identifiability = MaternIdentifiability::FrozenTransform {
5334 transform,
5335 nullspace_shrinkage_survived: Some(*nullspace_shrinkage_survived),
5336 };
5337 }
5338 Some(frozen)
5339 }
5340 (
5341 SmoothBasisSpec::Duchon {
5342 spec,
5343 input_scales: spec_scales,
5344 ..
5345 },
5346 BasisMetadata::Duchon {
5347 centers,
5348 input_scales: meta_scales,
5349 ..
5350 },
5351 ) => {
5352 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5353 if spec_scales.is_none()
5354 && let Some(s) = meta_scales.clone()
5355 {
5356 *spec_scales = Some(s);
5357 }
5358 Some(frozen)
5359 }
5360 (
5361 SmoothBasisSpec::ThinPlate {
5362 spec,
5363 input_scales: spec_scales,
5364 ..
5365 },
5366 BasisMetadata::ThinPlate {
5367 centers,
5368 input_scales: meta_scales,
5369 ..
5370 },
5371 ) => {
5372 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5373 if spec_scales.is_none()
5374 && let Some(s) = meta_scales.clone()
5375 {
5376 *spec_scales = Some(s);
5377 }
5378 Some(frozen)
5379 }
5380 _ => None,
5383 }
5384}
5385
5386fn rebuild_smooth_auxiliary_state(
5387 smooth: &mut SmoothDesign,
5388 dropped_penaltyinfo_by_term: &[Vec<DroppedPenaltyBlockInfo>],
5389) -> Result<(), String> {
5390 if dropped_penaltyinfo_by_term.len() != smooth.terms.len() {
5391 return Err(SmoothError::dimension_mismatch(format!(
5392 "smooth dropped-penalty cache mismatch: terms={}, dropped_sets={}",
5393 smooth.terms.len(),
5394 dropped_penaltyinfo_by_term.len()
5395 ))
5396 .into());
5397 }
5398
5399 let total_p = smooth.total_smooth_cols();
5400 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
5401 let mut any_bounds = false;
5402 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5403 let mut linear_constraint_b: Vec<f64> = Vec::new();
5404
5405 for term in &smooth.terms {
5406 let range = term.coeff_range.clone();
5407 if let Some(lb_local) = term.lower_bounds_local.as_ref() {
5408 if lb_local.len() != range.len() {
5409 return Err(SmoothError::dimension_mismatch(format!(
5410 "smooth lower-bound cache mismatch for term '{}': bounds={}, coeffs={}",
5411 term.name,
5412 lb_local.len(),
5413 range.len()
5414 ))
5415 .into());
5416 }
5417 coefficient_lower_bounds
5418 .slice_mut(s![range.clone()])
5419 .assign(lb_local);
5420 any_bounds = true;
5421 }
5422 if let Some(lin_local) = term.linear_constraints_local.as_ref() {
5423 if lin_local.a.ncols() != range.len() {
5424 return Err(SmoothError::dimension_mismatch(format!(
5425 "smooth linear-constraint cache mismatch for term '{}': cols={}, coeffs={}",
5426 term.name,
5427 lin_local.a.ncols(),
5428 range.len()
5429 ))
5430 .into());
5431 }
5432 for r in 0..lin_local.a.nrows() {
5433 let mut row = Array1::<f64>::zeros(total_p);
5434 row.slice_mut(s![range.clone()]).assign(&lin_local.a.row(r));
5435 linear_constraintrows.push(row);
5436 linear_constraint_b.push(lin_local.b[r]);
5437 }
5438 }
5439 }
5440
5441 smooth.coefficient_lower_bounds = if any_bounds {
5442 Some(coefficient_lower_bounds)
5443 } else {
5444 None
5445 };
5446 smooth.linear_constraints = if linear_constraintrows.is_empty() {
5447 None
5448 } else {
5449 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), total_p));
5450 for (i, row) in linear_constraintrows.iter().enumerate() {
5451 a.row_mut(i).assign(row);
5452 }
5453 Some(LinearInequalityConstraints {
5454 a,
5455 b: Array1::from_vec(linear_constraint_b),
5456 })
5457 };
5458 smooth.dropped_penaltyinfo = dropped_penaltyinfo_by_term
5459 .iter()
5460 .flat_map(|infos| infos.iter().cloned())
5461 .collect();
5462 Ok(())
5463}
5464
5465fn rebuild_term_collection_auxiliary_state(
5466 spec: &TermCollectionSpec,
5467 design: &mut TermCollectionDesign,
5468) -> Result<(), String> {
5469 if spec.linear_terms.len() != design.linear_ranges.len() {
5470 return Err(SmoothError::dimension_mismatch(format!(
5471 "term-collection linear bookkeeping mismatch: spec_terms={}, design_ranges={}",
5472 spec.linear_terms.len(),
5473 design.linear_ranges.len()
5474 ))
5475 .into());
5476 }
5477
5478 let p_total = design.design.ncols();
5479 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
5480 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
5481 let mut any_bounds = false;
5482 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5483 let mut linear_constraint_b: Vec<f64> = Vec::new();
5484
5485 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
5486 if range.len() != 1 {
5487 return Err(SmoothError::dimension_mismatch(format!(
5488 "linear term '{}' expected one coefficient column, found {}",
5489 linear.name,
5490 range.len()
5491 ))
5492 .into());
5493 }
5494 let col = range.start;
5495 if let Some(lb) = linear.coefficient_min {
5496 let mut row = Array1::<f64>::zeros(p_total);
5497 row[col] = 1.0;
5498 linear_constraintrows.push(row);
5499 linear_constraint_b.push(lb);
5500 }
5501 if let Some(ub) = linear.coefficient_max {
5502 let mut row = Array1::<f64>::zeros(p_total);
5503 row[col] = -1.0;
5504 linear_constraintrows.push(row);
5505 linear_constraint_b.push(-ub);
5506 }
5507 }
5508
5509 if let Some(lb_smooth) = design.smooth.coefficient_lower_bounds.as_ref() {
5510 if lb_smooth.len() != design.smooth.total_smooth_cols() {
5511 return Err(SmoothError::dimension_mismatch(format!(
5512 "smooth lower-bound width mismatch: bounds={}, smooth_cols={}",
5513 lb_smooth.len(),
5514 design.smooth.total_smooth_cols()
5515 ))
5516 .into());
5517 }
5518 coefficient_lower_bounds
5519 .slice_mut(s![
5520 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5521 ])
5522 .assign(lb_smooth);
5523 any_bounds = true;
5524 }
5525 if let Some(lin_smooth) = design.smooth.linear_constraints.as_ref() {
5526 if lin_smooth.a.ncols() != design.smooth.total_smooth_cols() {
5527 return Err(SmoothError::dimension_mismatch(format!(
5528 "smooth linear-constraint width mismatch: cols={}, smooth_cols={}",
5529 lin_smooth.a.ncols(),
5530 design.smooth.total_smooth_cols()
5531 ))
5532 .into());
5533 }
5534 let mut a_global = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
5535 a_global
5536 .slice_mut(s![
5537 ..,
5538 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5539 ])
5540 .assign(&lin_smooth.a);
5541 for r in 0..a_global.nrows() {
5542 linear_constraintrows.push(a_global.row(r).to_owned());
5543 linear_constraint_b.push(lin_smooth.b[r]);
5544 }
5545 }
5546
5547 let lower_bound_constraints = if any_bounds {
5548 linear_constraints_from_lower_bounds_global(&coefficient_lower_bounds)
5549 } else {
5550 None
5551 };
5552 let explicit_linear_constraints = if linear_constraintrows.is_empty() {
5553 None
5554 } else {
5555 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), p_total));
5556 for (i, row) in linear_constraintrows.iter().enumerate() {
5557 a.row_mut(i).assign(row);
5558 }
5559 Some(LinearInequalityConstraints {
5560 a,
5561 b: Array1::from_vec(linear_constraint_b),
5562 })
5563 };
5564
5565 design.coefficient_lower_bounds = if any_bounds {
5566 Some(coefficient_lower_bounds)
5567 } else {
5568 None
5569 };
5570 design.linear_constraints =
5571 merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
5572 design.dropped_penaltyinfo = design.smooth.dropped_penaltyinfo.clone();
5573 Ok(())
5574}
5575
5576fn theta_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5577 left.len() == right.len()
5578 && left
5579 .iter()
5580 .zip(right.iter())
5581 .all(|(&l, &r)| l.to_bits() == r.to_bits())
5582}
5583
/// Compares two latent-value vectors for exact equality.
///
/// Delegates to `theta_values_match`, so it applies the same rule: equal
/// length and bit-identical elements.
fn latent_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
    theta_values_match(left, right)
}
5587
/// Bitwise comparison of two optional anisotropy log-scale vectors.
///
/// Two absent vectors match. A present vector never matches an absent one.
/// Two present vectors match only if they have equal length and bit-identical
/// elements.
fn spatial_aniso_matches(left: Option<&[f64]>, right: Option<&[f64]>) -> bool {
    let (Some(lhs), Some(rhs)) = (left, right) else {
        return left.is_none() && right.is_none();
    };
    lhs.len() == rhs.len()
        && lhs
            .iter()
            .map(|x| x.to_bits())
            .eq(rhs.iter().map(|y| y.to_bits()))
}
5600
/// Bitwise comparison of two optional length-scales.
///
/// Both must be absent, or both present with identical IEEE-754 bit patterns.
fn spatial_length_scale_matches(left: Option<f64>, right: Option<f64>) -> bool {
    left.map(f64::to_bits) == right.map(f64::to_bits)
}
5608
5609struct FrozenTermCollectionIncrementalRealizer<'d> {
5610 data: ArrayView2<'d, f64>,
5611 spec: TermCollectionSpec,
5612 design: TermCollectionDesign,
5613 fixed_blocks: Vec<DesignBlock>,
5614 dropped_penaltyinfo_by_term: Vec<Vec<DroppedPenaltyBlockInfo>>,
5615 smooth_penalty_ranges: Vec<Range<usize>>,
5616 full_penalty_ranges: Vec<Range<usize>>,
5617 basisworkspace: gam_terms::basis::BasisWorkspace,
5621 spatial_realization_geometry: Vec<Option<SmoothTermSpec>>,
5634 design_revision: u64,
5640}
5641
5642impl<'d> std::fmt::Debug for FrozenTermCollectionIncrementalRealizer<'d> {
5643 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5644 f.debug_struct("FrozenTermCollectionIncrementalRealizer")
5645 .field("data_shape", &(self.data.nrows(), self.data.ncols()))
5646 .field("fixed_blocks", &self.fixed_blocks.len())
5647 .finish_non_exhaustive()
5648 }
5649}
5650
5651impl<'d> FrozenTermCollectionIncrementalRealizer<'d> {
5652 fn new(
5653 data: ArrayView2<'d, f64>,
5654 spec: TermCollectionSpec,
5655 design: TermCollectionDesign,
5656 ) -> Result<Self, String> {
5657 if spec.smooth_terms.len() != design.smooth.terms.len() {
5658 return Err(SmoothError::dimension_mismatch(format!(
5659 "incremental realizer smooth term mismatch: spec_terms={}, design_terms={}",
5660 spec.smooth_terms.len(),
5661 design.smooth.terms.len()
5662 ))
5663 .into());
5664 }
5665
5666 let mut smooth_cursor = 0usize;
5667 let mut smooth_penalty_ranges = Vec::with_capacity(design.smooth.terms.len());
5668 for term in &design.smooth.terms {
5669 let next = smooth_cursor + term.penalties_local.len();
5670 smooth_penalty_ranges.push(smooth_cursor..next);
5671 smooth_cursor = next;
5672 }
5673 if smooth_cursor != design.smooth.penalties.len() {
5674 return Err(SmoothError::dimension_mismatch(format!(
5675 "incremental realizer smooth penalty mismatch: ranged={}, actual={}",
5676 smooth_cursor,
5677 design.smooth.penalties.len()
5678 ))
5679 .into());
5680 }
5681
5682 let fixed_penalty_offset = design
5683 .penalties
5684 .len()
5685 .checked_sub(design.smooth.penalties.len())
5686 .ok_or_else(|| {
5687 "incremental realizer encountered invalid penalty bookkeeping".to_string()
5688 })?;
5689 let full_penalty_ranges = smooth_penalty_ranges
5690 .iter()
5691 .map(|range| (fixed_penalty_offset + range.start)..(fixed_penalty_offset + range.end))
5692 .collect::<Vec<_>>();
5693 let fixed_blocks = build_term_collection_fixed_blocks(data, &spec)
5694 .map_err(|e| format!("failed to cache fixed term-collection blocks: {e}"))?;
5695
5696 let mut dropped_penaltyinfo_by_term = Vec::with_capacity(spec.smooth_terms.len());
5697 for (term_idx, termspec) in spec.smooth_terms.iter().enumerate() {
5698 let realization =
5699 build_single_smooth_term_realization(data, termspec).map_err(|e| {
5700 format!(
5701 "failed to build cached realization for smooth term '{}' (index {}): {e}",
5702 termspec.name, term_idx
5703 )
5704 })?;
5705 let expected_cols = design.smooth.terms[term_idx].coeff_range.len();
5706 if realization.design_local.ncols() != expected_cols {
5707 return Err(SmoothError::dimension_mismatch(format!(
5708 "cached realization width mismatch for term '{}': cached_cols={}, design_cols={}",
5709 termspec.name,
5710 realization.design_local.ncols(),
5711 expected_cols
5712 ))
5713 .into());
5714 }
5715 if realization.active_penaltyinfo().len()
5716 != design.smooth.terms[term_idx].penalties_local.len()
5717 {
5718 return Err(SmoothError::dimension_mismatch(format!(
5719 "cached realization penalty mismatch for term '{}': cached_penalties={}, design_penalties={}",
5720 termspec.name,
5721 realization.active_penaltyinfo().len(),
5722 design.smooth.terms[term_idx].penalties_local.len()
5723 ))
5724 .into());
5725 }
5726 dropped_penaltyinfo_by_term.push(realization.dropped_penaltyinfo);
5727 }
5728
5729 let geometry_slots = spec.smooth_terms.len();
5730 Ok(Self {
5731 data,
5732 spec,
5733 design,
5734 fixed_blocks,
5735 dropped_penaltyinfo_by_term,
5736 smooth_penalty_ranges,
5737 full_penalty_ranges,
5738 basisworkspace: gam_terms::basis::BasisWorkspace::new(),
5739 spatial_realization_geometry: vec![None; geometry_slots],
5740 design_revision: 0,
5741 })
5742 }
5743
5744 fn design_revision(&self) -> u64 {
5745 self.design_revision
5746 }
5747
5748 fn spec(&self) -> &TermCollectionSpec {
5749 &self.spec
5750 }
5751
5752 fn design(&self) -> &TermCollectionDesign {
5753 &self.design
5754 }
5755
5756 fn supports_nfree_penalty_rekey(&self, spatial_terms: &[usize]) -> bool {
5797 if spatial_terms.len() != 1 {
5798 return false;
5799 }
5800 let term_idx = spatial_terms[0];
5801 matches!(
5802 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5803 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5804 )
5805 }
5806
5807 fn supports_nfree_gradient_only_routing(&self, spatial_terms: &[usize]) -> bool {
5816 if spatial_terms.len() != 1 {
5817 return false;
5818 }
5819 let term_idx = spatial_terms[0];
5820 matches!(
5821 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5822 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5823 )
5824 }
5825
5826 fn canonical_penalties_at_psi(
5839 &mut self,
5840 spatial_terms: &[usize],
5841 psi: &[f64],
5842 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
5843 if spatial_terms.len() != 1 {
5844 return Err(format!(
5845 "n-free penalty re-key requires exactly one spatial term, found {}",
5846 spatial_terms.len()
5847 ));
5848 }
5849 let term_idx = spatial_terms[0];
5850 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
5856 let termspec =
5859 self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
5860 format!("spatial term {term_idx} out of range for n-free penalty")
5861 })?;
5862 let term = self
5863 .design
5864 .smooth
5865 .terms
5866 .get(term_idx)
5867 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
5868 let p_total = self.design.design.ncols();
5871 let (locals, nullspace_dims): (Vec<Array2<f64>>, Vec<usize>) = match &term.metadata {
5872 BasisMetadata::Duchon {
5873 centers,
5874 identifiability_transform,
5875 operator_collocation_points,
5876 power,
5877 nullspace_order,
5878 aniso_log_scales,
5879 input_scales,
5880 radial_reparam,
5881 ..
5882 } => {
5883 let operator_penalties = match &termspec.basis {
5884 SmoothBasisSpec::Duchon { spec, .. } => spec.operator_penalties.clone(),
5885 _ => gam_terms::basis::DuchonOperatorPenaltySpec::default(),
5886 };
5887 let effective_ls = match input_scales.as_deref() {
5894 Some(scales) => {
5895 compensate_optional_length_scale_for_standardization(ls_opt, scales)
5896 }
5897 None => ls_opt,
5898 };
5899 gam_terms::basis::duchon_penalties_at_length_scale(
5900 centers.view(),
5901 identifiability_transform.as_ref(),
5902 operator_collocation_points.as_ref().map(|p| p.view()),
5903 &operator_penalties,
5904 *power,
5905 *nullspace_order,
5906 aniso_log_scales.as_deref(),
5907 radial_reparam.as_ref(),
5908 effective_ls,
5909 &mut self.basisworkspace,
5910 )
5911 .map_err(|e| e.to_string())?
5912 }
5913 BasisMetadata::Matern {
5914 centers,
5915 periodic,
5916 nu,
5917 include_intercept,
5918 identifiability_transform,
5919 aniso_log_scales,
5920 input_scales,
5921 ..
5922 } => {
5923 let ls = ls_opt.ok_or_else(|| {
5930 "Matérn n-free penalty re-key requires a finite length-scale".to_string()
5931 })?;
5932 let effective_ls = match input_scales.as_deref() {
5933 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
5934 None => ls,
5935 };
5936 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
5937 let (penalties, nullspace_dims, _info) =
5948 matern_operator_penalty_triplet_at_length_scale(
5949 centers.view(),
5950 periodic.as_deref(),
5951 identifiability_transform.as_ref(),
5952 *nu,
5953 *include_intercept,
5954 aniso_for_penalty,
5955 effective_ls,
5956 )
5957 .map_err(|e| e.to_string())?;
5958 (penalties, nullspace_dims)
5959 }
5960 BasisMetadata::ThinPlate {
5961 centers,
5962 identifiability_transform,
5963 radial_reparam,
5964 ..
5965 } => {
5966 let ls = ls_opt.ok_or_else(|| {
5967 "thin-plate n-free penalty re-key requires a finite length-scale".to_string()
5968 })?;
5969 let double_penalty = match &termspec.basis {
5970 SmoothBasisSpec::ThinPlate { spec, .. } => spec.double_penalty,
5971 _ => false,
5972 };
5973 gam_terms::basis::thin_plate_penalties_at_length_scale(
5974 centers.view(),
5975 identifiability_transform.as_ref(),
5976 radial_reparam.as_ref(),
5977 ls,
5978 double_penalty,
5979 &mut self.basisworkspace,
5980 )
5981 .map_err(|e| e.to_string())?
5982 }
5983 other => {
5984 return Err(format!(
5985 "n-free penalty re-key unsupported for basis metadata {:?}",
5986 std::mem::discriminant(other)
5987 ));
5988 }
5989 };
5990 let templates = &self.design.penalties;
5995 if templates.len() != locals.len() {
5996 return Err(format!(
5997 "n-free penalty re-key produced {} blocks but the frozen design carries {} \
5998 — penalty topology is not ψ-stable",
5999 locals.len(),
6000 templates.len()
6001 ));
6002 }
6003 let specs: Vec<gam_solve::estimate::PenaltySpec> = templates
6004 .iter()
6005 .zip(locals.into_iter())
6006 .map(|(tmpl, local)| gam_solve::estimate::PenaltySpec::Block {
6007 local,
6008 col_range: tmpl.col_range.clone(),
6009 prior_mean: tmpl.prior_mean.clone(),
6010 structure_hint: tmpl.structure_hint.clone(),
6011 op: tmpl.op.clone(),
6012 })
6013 .collect();
6014 gam_terms::construction::canonicalize_penalty_specs(
6015 &specs,
6016 &nullspace_dims,
6017 p_total,
6018 "nfree-psi-penalty",
6019 )
6020 .map_err(|e| e.to_string())
6021 }
6022
6023 fn canonical_penalty_derivatives_at_psi(
6024 &mut self,
6025 spatial_terms: &[usize],
6026 psi: &[f64],
6027 ) -> Result<(Range<usize>, usize, Vec<Array2<f64>>), String> {
6028 if spatial_terms.len() != 1 {
6029 return Err(format!(
6030 "n-free penalty derivative re-key requires exactly one spatial term, found {}",
6031 spatial_terms.len()
6032 ));
6033 }
6034 let term_idx = spatial_terms[0];
6035 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
6036 let termspec = self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
6037 format!("spatial term {term_idx} out of range for n-free penalty derivative")
6038 })?;
6039 let term = self
6040 .design
6041 .smooth
6042 .terms
6043 .get(term_idx)
6044 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
6045 let p_total = self.design.design.ncols();
6046 let smooth_start = p_total.saturating_sub(self.design.smooth.total_smooth_cols());
6047 let global_range =
6048 (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
6049
6050 let locals = match &term.metadata {
6051 BasisMetadata::Duchon {
6052 centers,
6053 identifiability_transform,
6054 operator_collocation_points,
6055 power,
6056 nullspace_order,
6057 aniso_log_scales,
6058 input_scales,
6059 radial_reparam,
6060 ..
6061 } => {
6062 let mut spec = match &termspec.basis {
6063 SmoothBasisSpec::Duchon { spec, .. } => spec.clone(),
6064 _ => {
6065 return Err(
6066 "Duchon n-free penalty derivative requires a Duchon term spec"
6067 .to_string(),
6068 );
6069 }
6070 };
6071 let effective_ls = match input_scales.as_deref() {
6072 Some(scales) => {
6073 compensate_optional_length_scale_for_standardization(ls_opt, scales)
6074 }
6075 None => ls_opt,
6076 };
6077 spec.length_scale = effective_ls;
6078 spec.power = *power;
6079 spec.nullspace_order = *nullspace_order;
6080 spec.aniso_log_scales = aniso_log_scales.clone();
6081 spec.radial_reparam = radial_reparam.clone();
6084 if spec.length_scale.is_none() {
6085 return Err(
6086 "Duchon n-free penalty derivative requires a hybrid length-scale"
6087 .to_string(),
6088 );
6089 }
6090 let collocation = operator_collocation_points
6091 .as_ref()
6092 .map(|points| points.view())
6093 .unwrap_or_else(|| centers.view());
6094 let (_native_sources, mut first, _native_second) =
6095 gam_terms::basis::build_duchon_native_penalty_psi_derivatives(
6096 centers.view(),
6097 &spec,
6098 identifiability_transform.as_ref(),
6099 &mut self.basisworkspace,
6100 )
6101 .map_err(|e| e.to_string())?;
6102 let (_operator_sources, operator_first, _operator_second) =
6103 gam_terms::basis::build_duchon_operator_penalty_psi_derivatives(
6104 collocation,
6105 centers.view(),
6106 &spec,
6107 identifiability_transform.as_ref(),
6108 &mut self.basisworkspace,
6109 )
6110 .map_err(|e| e.to_string())?;
6111 first.extend(operator_first);
6112 first
6113 }
6114 BasisMetadata::Matern {
6115 centers,
6116 periodic,
6117 nu,
6118 include_intercept,
6119 identifiability_transform,
6120 aniso_log_scales,
6121 input_scales,
6122 ..
6123 } => {
6124 let ls = ls_opt.ok_or_else(|| {
6125 "Matérn n-free penalty derivative requires a finite length-scale".to_string()
6126 })?;
6127 let effective_ls = match input_scales.as_deref() {
6128 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
6129 None => ls,
6130 };
6131 let penalty_centers =
6132 gam_terms::basis::expand_periodic_centers(¢ers.to_owned(), periodic.as_deref())
6133 .map_err(|e| e.to_string())?;
6134 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
6135 let (first, _second) = gam_terms::basis::build_matern_operator_penalty_psi_derivatives(
6136 penalty_centers.view(),
6137 effective_ls,
6138 *nu,
6139 *include_intercept,
6140 identifiability_transform.as_ref(),
6141 aniso_for_penalty,
6142 )
6143 .map_err(|e| e.to_string())?;
6144 first
6145 }
6146 BasisMetadata::ThinPlate {
6147 centers,
6148 identifiability_transform,
6149 radial_reparam,
6150 ..
6151 } => {
6152 let ls = ls_opt.ok_or_else(|| {
6153 "thin-plate n-free penalty derivative requires a finite length-scale"
6154 .to_string()
6155 })?;
6156 let mut spec = match &termspec.basis {
6157 SmoothBasisSpec::ThinPlate { spec, .. } => spec.clone(),
6158 _ => {
6159 return Err(
6160 "thin-plate n-free penalty derivative requires a ThinPlate term spec"
6161 .to_string(),
6162 );
6163 }
6164 };
6165 spec.length_scale = ls;
6166 if spec.radial_reparam.is_none() {
6167 spec.radial_reparam = radial_reparam.clone();
6168 }
6169 let (primary, _primary_second) =
6170 gam_terms::basis::build_thin_plate_penalty_psi_derivativeswithworkspace(
6171 centers.view(),
6172 &spec,
6173 identifiability_transform.as_ref(),
6174 &mut self.basisworkspace,
6175 )
6176 .map_err(|e| e.to_string())?;
6177 if self.design.penalties.len() > 1 {
6178 vec![primary.clone(), Array2::<f64>::zeros(primary.raw_dim())]
6179 } else {
6180 vec![primary]
6181 }
6182 }
6183 other => {
6184 return Err(format!(
6185 "n-free penalty derivative re-key unsupported for basis metadata {:?}",
6186 std::mem::discriminant(other)
6187 ));
6188 }
6189 };
6190 if locals.len() != self.design.penalties.len() {
6191 return Err(format!(
6192 "n-free penalty derivative re-key produced {} blocks but the frozen design carries {} \
6193 — penalty topology is not ψ-stable",
6194 locals.len(),
6195 self.design.penalties.len()
6196 ));
6197 }
6198 Ok((global_range, p_total, locals))
6199 }
6200
6201 fn apply_log_kappa(
6202 &mut self,
6203 log_kappa: &SpatialLogKappaCoords,
6204 term_indices: &[usize],
6205 ) -> Result<(), String> {
6206 if term_indices.len() != log_kappa.dims_per_term().len() {
6207 return Err(SmoothError::dimension_mismatch(format!(
6208 "incremental realizer log-kappa term mismatch: term_indices={}, dims_per_term={}",
6209 term_indices.len(),
6210 log_kappa.dims_per_term().len()
6211 ))
6212 .into());
6213 }
6214
6215 let mut any_changed = false;
6216 for (slot, &term_idx) in term_indices.iter().enumerate() {
6217 any_changed |= self.apply_log_kappa_to_term(term_idx, log_kappa.term_slice(slot))?;
6218 }
6219
6220 if any_changed {
6221 self.refresh_full_design_operator()?;
6222 rebuild_smooth_auxiliary_state(
6223 &mut self.design.smooth,
6224 &self.dropped_penaltyinfo_by_term,
6225 )?;
6226 rebuild_term_collection_auxiliary_state(&self.spec, &mut self.design)?;
6227 self.design_revision = self.design_revision.wrapping_add(1);
6228 }
6229 Ok(())
6230 }
6231
 /// Moves one spatial smooth term to the hyperparameter point `psi`.
 ///
 /// When that changes the term, the term is rebuilt against `self.data` and the
 /// new realization is spliced into the design.
 ///
 /// Returns `Ok(false)` without rebuilding in either of these cases:
 /// - the relevant setter reports no change (measure-jet dials,
 ///   constant-curvature kappa);
 /// - the decoded length-scale and anisotropy are bit-identical to the spec's
 ///   current values.
 ///
 /// Returns `Ok(true)` after a rebuild.
 ///
 /// # Errors
 /// Fails in any of these cases:
 /// - the term does not expose spatial hyperparameters;
 /// - a setter rejects `psi`;
 /// - the geometry slot or spec index is out of range;
 /// - the rebuild fails;
 /// - the rebuilt realization cannot be wrapped or spliced in.
 ///
 /// NOTE(review): `self.spec` is updated before the rebuild. If the rebuild or
 /// the final splice then fails, the spec already reports the new ψ while the
 /// design still holds the old realization. A retry at the same ψ would
 /// presumably return `Ok(false)` and leave the design stale. Confirm that
 /// callers never retry after an error.
 fn apply_log_kappa_to_term(&mut self, term_idx: usize, psi: &[f64]) -> Result<bool, String> {
     if !spatial_term_supports_hyper_optimization(&self.spec, term_idx) {
         return Err(SmoothError::invalid_config(format!(
             "incremental realizer term {term_idx} does not expose spatial hyperparameters"
         ))
         .into());
     }
     // Exactly one of three update paths applies: measure-jet dials,
     // constant-curvature kappa, or length-scale/anisotropy.
     let measure_jet_term = measure_jet_term_spec(&self.spec, term_idx).is_some();
     let constant_curvature_term = constant_curvature_term_spec(&self.spec, term_idx).is_some();
     let mut next_length_scale = None;
     let mut next_aniso: Option<Vec<f64>> = None;
     if measure_jet_term {
         if !set_measure_jet_psi_dials(&mut self.spec, term_idx, psi)
             .map_err(|e| e.to_string())?
         {
             return Ok(false);
         }
     } else if constant_curvature_term {
         if !set_constant_curvature_kappa(&mut self.spec, term_idx, psi)
             .map_err(|e| e.to_string())?
         {
             return Ok(false);
         }
     } else {
         let current_length_scale = get_spatial_length_scale(&self.spec, term_idx);
         let current_aniso = get_spatial_aniso_log_scales(&self.spec, term_idx);
         let (ls, eta) = spatial_term_psi_to_length_scale_and_aniso(psi);
         next_length_scale = ls;
         next_aniso = eta;
         // Bitwise comparison: any change in ψ, however small, triggers a rebuild.
         let same_length = spatial_length_scale_matches(current_length_scale, next_length_scale);
         let same_aniso = spatial_aniso_matches(current_aniso.as_deref(), next_aniso.as_deref());
         if same_length && same_aniso {
             return Ok(false);
         }
         if let Some(length_scale) = next_length_scale {
             set_spatial_length_scale(&mut self.spec, term_idx, length_scale)
                 .map_err(|e| e.to_string())?;
         }
         if let Some(eta) = next_aniso.clone() {
             set_spatial_aniso_log_scales(&mut self.spec, term_idx, eta)
                 .map_err(|e| e.to_string())?;
         }
     }

     // Build from the geometry-frozen spec once one has been captured for this
     // term; otherwise build from the live spec.
     let geometry_slot = self
         .spatial_realization_geometry
         .get(term_idx)
         .ok_or_else(|| format!("incremental realizer geometry slot {term_idx} out of range"))?;
     let mut build_spec = match geometry_slot {
         Some(cached) => cached.clone(),
         None => self
             .spec
             .smooth_terms
             .get(term_idx)
             .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
             .clone(),
     };
     // Apply the same ψ to the single-term build spec.
     if measure_jet_term {
         set_single_term_measure_jet_psi_dials(&mut build_spec, psi)
             .map_err(|e| e.to_string())?;
     } else if constant_curvature_term {
         set_single_term_constant_curvature_kappa(&mut build_spec, psi)
             .map_err(|e| e.to_string())?;
     } else {
         if let Some(length_scale) = next_length_scale {
             set_single_term_spatial_length_scale(&mut build_spec, length_scale)
                 .map_err(|e| e.to_string())?;
         }
         if let Some(eta) = next_aniso {
             set_single_term_spatial_aniso_log_scales(&mut build_spec, eta)
                 .map_err(|e| e.to_string())?;
         }
     }

     let termname = build_spec.name.clone();
     let local = build_single_local_smooth_term(
         self.data,
         &build_spec,
         &mut self.basisworkspace,
     )
     .map_err(|e| {
         format!(
             "failed to rebuild smooth term '{termname}' during incremental κ realization: {e}"
         )
     })?;

     // On the first rebuild of this term, capture its realized geometry so that
     // later rebuilds reuse it. For Matérn, also mirror the frozen
     // identifiability and center strategy onto the live spec.
     if self.spatial_realization_geometry[term_idx].is_none()
         && let Some(frozen) = freeze_geometry_from_metadata(&build_spec, &local.metadata)
     {
         if let (
             SmoothBasisSpec::Matern {
                 spec: frozen_spec, ..
             },
             Some(SmoothBasisSpec::Matern {
                 spec: live_spec, ..
             }),
         ) = (
             &frozen.basis,
             self.spec
                 .smooth_terms
                 .get_mut(term_idx)
                 .map(|t| &mut t.basis),
         ) {
             live_spec.identifiability = frozen_spec.identifiability.clone();
             live_spec.center_strategy = frozen_spec.center_strategy.clone();
         }
         self.spatial_realization_geometry[term_idx] = Some(frozen);
     }

     let realization = wrap_local_build_as_realization(local, &build_spec)?;
     self.replace_term_realization(term_idx, realization)?;
     Ok(true)
 }
6383
6384 fn replace_term_realization(
6385 &mut self,
6386 term_idx: usize,
6387 realization: SingleSmoothTermRealization,
6388 ) -> Result<(), String> {
6389 let t_replace = std::time::Instant::now();
6390 let SingleSmoothTermRealization {
6391 design_local,
6392 term,
6393 dropped_penaltyinfo,
6394 } = realization;
6395 let SmoothTerm {
6396 name,
6397 penalties_local,
6398 nullspace_dims,
6399 penaltyinfo_local,
6400 metadata,
6401 lower_bounds_local,
6402 linear_constraints_local,
6403 joint_null_rotation,
6404 ..
6405 } = term;
6406 let coeff_range = self
6407 .design
6408 .smooth
6409 .terms
6410 .get(term_idx)
6411 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
6412 .coeff_range
6413 .clone();
6414 if design_local.ncols() != coeff_range.len() {
6415 return Err(SmoothError::dimension_mismatch(format!(
6416 "incremental realizer width mismatch for term {}: rebuilt_cols={}, cached_cols={}",
6417 term_idx,
6418 design_local.ncols(),
6419 coeff_range.len()
6420 ))
6421 .into());
6422 }
6423 if design_local.nrows() != self.design.design.nrows() {
6424 return Err(SmoothError::dimension_mismatch(format!(
6425 "incremental realizer row mismatch for term {}: rebuilt_rows={}, design_rows={}",
6426 term_idx,
6427 design_local.nrows(),
6428 self.design.design.nrows()
6429 ))
6430 .into());
6431 }
6432
6433 let active_penaltyinfo = penaltyinfo_local
6434 .iter()
6435 .filter(|info| info.active)
6436 .cloned()
6437 .collect::<Vec<_>>();
6438 let smooth_penalty_range = self
6439 .smooth_penalty_ranges
6440 .get(term_idx)
6441 .ok_or_else(|| {
6442 format!("incremental realizer missing smooth penalty range for term {term_idx}")
6443 })?
6444 .clone();
6445 let full_penalty_range = self
6446 .full_penalty_ranges
6447 .get(term_idx)
6448 .ok_or_else(|| {
6449 format!("incremental realizer missing full penalty range for term {term_idx}")
6450 })?
6451 .clone();
6452 if active_penaltyinfo.len() != smooth_penalty_range.len()
6453 || penalties_local.len() != smooth_penalty_range.len()
6454 || nullspace_dims.len() != smooth_penalty_range.len()
6455 {
6456 return Err(SmoothError::dimension_mismatch(format!(
6457 "incremental realizer topology changed for term '{}': penalties={}, infos={}, nullspaces={}, cached_penalties={}",
6458 name,
6459 penalties_local.len(),
6460 active_penaltyinfo.len(),
6461 nullspace_dims.len(),
6462 smooth_penalty_range.len()
6463 ))
6464 .into());
6465 }
6466
6467 self.design.smooth.term_designs[term_idx] = design_local;
6468
6469 for (offset, penalty_local) in penalties_local.iter().enumerate() {
6470 let smooth_penalty_idx = smooth_penalty_range.start + offset;
6471 let full_penalty_idx = full_penalty_range.start + offset;
6472 let nullspace_dim = nullspace_dims[offset];
6473 let penalty_info = active_penaltyinfo[offset].clone();
6474
6475 if penalty_local.nrows() != coeff_range.len()
6476 || penalty_local.ncols() != coeff_range.len()
6477 {
6478 return Err(SmoothError::dimension_mismatch(format!(
6479 "incremental realizer penalty shape mismatch for term '{}' penalty {}: \
6480 penalty is {}x{} but coeff_range has {} columns",
6481 name,
6482 offset,
6483 penalty_local.nrows(),
6484 penalty_local.ncols(),
6485 coeff_range.len()
6486 ))
6487 .into());
6488 }
6489
6490 let smooth_penalty = self
6491 .design
6492 .smooth
6493 .penalties
6494 .get_mut(smooth_penalty_idx)
6495 .ok_or_else(|| {
6496 format!(
6497 "incremental realizer smooth penalty {} out of range for term {}",
6498 smooth_penalty_idx, term_idx
6499 )
6500 })?;
6501 smooth_penalty.local.assign(penalty_local);
6504
6505 let full_bp = self
6506 .design
6507 .penalties
6508 .get_mut(full_penalty_idx)
6509 .ok_or_else(|| {
6510 format!(
6511 "incremental realizer full penalty {} out of range for term {}",
6512 full_penalty_idx, term_idx
6513 )
6514 })?;
6515 full_bp.local.assign(penalty_local);
6518
6519 self.design.smooth.nullspace_dims[smooth_penalty_idx] = nullspace_dim;
6520 self.design.nullspace_dims[full_penalty_idx] = nullspace_dim;
6521
6522 self.design.smooth.penaltyinfo[smooth_penalty_idx].global_index = smooth_penalty_idx;
6523 self.design.smooth.penaltyinfo[smooth_penalty_idx].termname = Some(name.clone());
6524 self.design.smooth.penaltyinfo[smooth_penalty_idx].penalty = penalty_info.clone();
6525
6526 self.design.penaltyinfo[full_penalty_idx].global_index = full_penalty_idx;
6527 self.design.penaltyinfo[full_penalty_idx].termname = Some(name.clone());
6528 self.design.penaltyinfo[full_penalty_idx].penalty = penalty_info;
6529 }
6530
6531 let target_term = self.design.smooth.terms.get_mut(term_idx).ok_or_else(|| {
6532 format!("incremental realizer smooth term {term_idx} disappeared during replacement")
6533 })?;
6534 target_term.penalties_local = penalties_local;
6535 target_term.nullspace_dims = nullspace_dims;
6536 target_term.penaltyinfo_local = penaltyinfo_local;
6537 target_term.metadata = metadata;
6538 target_term.lower_bounds_local = lower_bounds_local;
6539 target_term.linear_constraints_local = linear_constraints_local;
6540 target_term.joint_null_rotation = joint_null_rotation;
6541 self.dropped_penaltyinfo_by_term[term_idx] = dropped_penaltyinfo;
6542 log::info!(
6543 "[STAGE] smooth basis rebuild (term {}, '{}', cols={}): {:.3}s",
6544 term_idx,
6545 target_term.name,
6546 coeff_range.len(),
6547 t_replace.elapsed().as_secs_f64(),
6548 );
6549 Ok(())
6550 }
6551
6552 fn refresh_full_design_operator(&mut self) -> Result<(), String> {
6553 let mut blocks = Vec::<DesignBlock>::with_capacity(
6554 self.fixed_blocks.len() + self.design.smooth.term_designs.len(),
6555 );
6556 blocks.extend(self.fixed_blocks.iter().cloned());
6557 for term_design in &self.design.smooth.term_designs {
6558 blocks.push(DesignBlock::from(term_design));
6559 }
6560 self.design.design = assemble_term_collection_design_matrix(blocks)
6561 .map_err(|e| format!("failed to refresh term-collection design: {e}"))?;
6562 Ok(())
6563 }
6564}
6565
6566fn build_term_collection_fixed_blocks(
6567 data: ArrayView2<'_, f64>,
6568 spec: &TermCollectionSpec,
6569) -> Result<Vec<DesignBlock>, BasisError> {
6570 let mut blocks = Vec::<DesignBlock>::new();
6571 if !term_collection_has_one_sided_anchored_bspline(spec) {
6572 blocks.push(DesignBlock::Intercept(data.nrows()));
6573 }
6574
6575 if !spec.linear_terms.is_empty() {
6576 let mut linear_block = Array2::<f64>::zeros((data.nrows(), spec.linear_terms.len()));
6577 for (j, linear) in spec.linear_terms.iter().enumerate() {
6578 let column = linear
6582 .realized_design_column(data)
6583 .map_err(BasisError::InvalidInput)?;
6584 linear_block.column_mut(j).assign(&column);
6585 }
6586 blocks.push(DesignBlock::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
6587 linear_block,
6588 )));
6589 }
6590
6591 for term in &spec.random_effect_terms {
6592 let block = build_random_effect_block(data, term)?;
6593 let re_op = RandomEffectOperator::new(block.group_ids, block.num_groups);
6594 blocks.push(DesignBlock::RandomEffect(Arc::new(re_op)));
6595 }
6596
6597 Ok(blocks)
6598}
6599
/// Outcome of a spatial length-scale (kappa) optimization over term-collection blocks.
///
/// `optimize_spatial_length_scale_exact_joint` returns this type. On its fast path
/// (no outer theta to optimize) the fields are:
/// - `resolved_specs` and `designs`: the frozen joint specs and designs;
/// - `fit`: the caller's fit callback evaluated at the setup's theta seed;
/// - `timing`: `None`.
pub struct SpatialLengthScaleOptimizationResult<FitOut> {
    /// Resolved (frozen) term-collection specs; presumably one per input block — confirm.
    pub resolved_specs: Vec<TermCollectionSpec>,
    /// Term-collection designs that accompany `resolved_specs`.
    pub designs: Vec<TermCollectionDesign>,
    /// Output of the caller-supplied fit callback.
    pub fit: FitOut,
    /// Optional timing breakdown; `None` on the fast path.
    pub timing: Option<SpatialLengthScaleOptimizationTiming>,
}
6610
/// Seeds and box bounds for the exact-joint outer hyperparameter vector `theta`.
///
/// `theta` is laid out as `[rho | log_kappa | auxiliary]` (see `theta0`, `lower`,
/// `upper`). The auxiliary segment is empty unless `with_auxiliary` is called.
#[derive(Debug, Clone)]
pub struct ExactJointHyperSetup {
    /// Seed for the leading rho segment; clamped into `[rho_lower, rho_upper]` by `new`.
    rho0: Array1<f64>,
    /// Lower bounds for the rho segment.
    rho_lower: Array1<f64>,
    /// Upper bounds for the rho segment.
    rho_upper: Array1<f64>,
    /// Seed for the log-kappa segment (stored as given, not clamped).
    log_kappa0: SpatialLogKappaCoords,
    /// Lower bounds for the log-kappa segment.
    log_kappa_lower: SpatialLogKappaCoords,
    /// Upper bounds for the log-kappa segment.
    log_kappa_upper: SpatialLogKappaCoords,
    /// Seed for the trailing auxiliary segment (empty by default).
    auxiliary0: Array1<f64>,
    /// Lower bounds for the auxiliary segment.
    auxiliary_lower: Array1<f64>,
    /// Upper bounds for the auxiliary segment.
    auxiliary_upper: Array1<f64>,
}
6624
6625impl ExactJointHyperSetup {
6626 fn sanitize_rho_seed(
6627 rho0: Array1<f64>,
6628 rho_lower: &Array1<f64>,
6629 rho_upper: &Array1<f64>,
6630 ) -> Array1<f64> {
6631 Array1::from_iter(rho0.iter().enumerate().map(|(idx, &value)| {
6632 let lo = rho_lower[idx];
6633 let hi = rho_upper[idx];
6634 let fallback = 0.0_f64.clamp(lo, hi);
6635 if value.is_finite() {
6636 value.clamp(lo, hi)
6637 } else {
6638 fallback
6639 }
6640 }))
6641 }
6642
6643 pub(crate) fn new(
6644 rho0: Array1<f64>,
6645 rho_lower: Array1<f64>,
6646 rho_upper: Array1<f64>,
6647 log_kappa0: SpatialLogKappaCoords,
6648 log_kappa_lower: SpatialLogKappaCoords,
6649 log_kappa_upper: SpatialLogKappaCoords,
6650 ) -> Self {
6651 let rho0 = Self::sanitize_rho_seed(rho0, &rho_lower, &rho_upper);
6652 Self {
6653 rho0,
6654 rho_lower,
6655 rho_upper,
6656 log_kappa0,
6657 log_kappa_lower,
6658 log_kappa_upper,
6659 auxiliary0: Array1::zeros(0),
6660 auxiliary_lower: Array1::zeros(0),
6661 auxiliary_upper: Array1::zeros(0),
6662 }
6663 }
6664
6665 pub(crate) fn with_auxiliary(
6666 mut self,
6667 auxiliary0: Array1<f64>,
6668 auxiliary_lower: Array1<f64>,
6669 auxiliary_upper: Array1<f64>,
6670 ) -> Self {
6671 assert_eq!(
6672 auxiliary0.len(),
6673 auxiliary_lower.len(),
6674 "auxiliary lower bound length mismatch"
6675 );
6676 assert_eq!(
6677 auxiliary0.len(),
6678 auxiliary_upper.len(),
6679 "auxiliary upper bound length mismatch"
6680 );
6681 self.auxiliary0 = Self::sanitize_rho_seed(auxiliary0, &auxiliary_lower, &auxiliary_upper);
6682 self.auxiliary_lower = auxiliary_lower;
6683 self.auxiliary_upper = auxiliary_upper;
6684 self
6685 }
6686
6687 pub(crate) fn rho_dim(&self) -> usize {
6688 self.rho0.len()
6689 }
6690
6691 pub(crate) fn log_kappa_dim(&self) -> usize {
6692 self.log_kappa0.len()
6693 }
6694
6695 pub(crate) fn auxiliary_dim(&self) -> usize {
6696 self.auxiliary0.len()
6697 }
6698
6699 pub(crate) fn theta0(&self) -> Array1<f64> {
6700 let mut out =
6701 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6702 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho0);
6703 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6704 .assign(self.log_kappa0.as_array());
6705 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6706 .assign(&self.auxiliary0);
6707 out
6708 }
6709
6710 pub(crate) fn lower(&self) -> Array1<f64> {
6711 let mut out =
6712 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6713 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_lower);
6714 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6715 .assign(self.log_kappa_lower.as_array());
6716 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6717 .assign(&self.auxiliary_lower);
6718 out
6719 }
6720
6721 pub(crate) fn upper(&self) -> Array1<f64> {
6722 let mut out =
6723 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6724 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_upper);
6725 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6726 .assign(self.log_kappa_upper.as_array());
6727 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6728 .assign(&self.auxiliary_upper);
6729 out
6730 }
6731
6732 pub(crate) fn log_kappa_dims_per_term(&self) -> Vec<usize> {
6734 self.log_kappa0.dims_per_term().to_vec()
6735 }
6736}
6737
/// Per-block incremental realizers plus a single-entry memo keyed by theta, used by
/// the exact-joint outer objective.
struct ExactJointDesignCache<'d> {
    /// One frozen incremental realizer per block, in block order.
    realizers: Vec<FrozenTermCollectionIncrementalRealizer<'d>>,
    /// For each block, the smooth-term indices that receive that block's log-kappa slice.
    block_term_indices: Vec<Vec<usize>>,
    /// Theta most recently applied to the realizers by `ensure_theta`, if any.
    current_theta: Option<Array1<f64>>,
    /// Memoized cost; only stored when it belongs to `current_theta`.
    last_cost: Option<f64>,
    /// Memoized (cost, gradient, Hessian) evaluation for `current_theta`.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Number of leading rho coordinates in theta.
    rho_dim: usize,
    /// Log-kappa dimension of each term, in theta order.
    all_dims: Vec<usize>,
    /// Total log-kappa dimension (sum of `all_dims`).
    log_kappa_dim: usize,
    /// Number of terms per block; used to split the log-kappa coordinates across blocks.
    block_term_counts: Vec<usize>,
}
6758
6759impl<'d> ExactJointDesignCache<'d> {
6760 fn new(
6761 data: ArrayView2<'d, f64>,
6762 blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)>,
6763 rho_dim: usize,
6764 all_dims: Vec<usize>,
6765 ) -> Result<Self, String> {
6766 let n_blocks = blocks.len();
6767 let mut realizers = Vec::with_capacity(n_blocks);
6768 let mut block_term_indices = Vec::with_capacity(n_blocks);
6769 let mut block_term_counts = Vec::with_capacity(n_blocks);
6770
6771 for (spec, design, terms) in blocks {
6772 block_term_counts.push(terms.len());
6773 block_term_indices.push(terms);
6774 realizers.push(FrozenTermCollectionIncrementalRealizer::new(
6775 data, spec, design,
6776 )?);
6777 }
6778
6779 Ok(Self {
6780 realizers,
6781 block_term_indices,
6782 current_theta: None,
6783 last_cost: None,
6784 last_eval: None,
6785 rho_dim,
6786 log_kappa_dim: all_dims.iter().sum(),
6787 all_dims,
6788 block_term_counts,
6789 })
6790 }
6791
6792 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
6793 if self
6794 .current_theta
6795 .as_ref()
6796 .is_some_and(|cached| theta_values_match(cached, theta))
6797 {
6798 return Ok(());
6799 }
6800
6801 let t_ensure = std::time::Instant::now();
6802 let kappa_theta_len = self.rho_dim + self.log_kappa_dim;
6803 if theta.len() < kappa_theta_len {
6804 return Err(SmoothError::dimension_mismatch(format!(
6805 "exact-joint theta length mismatch: got {}, expected at least {} (rho_dim={}, log_kappa_dim={})",
6806 theta.len(),
6807 kappa_theta_len,
6808 self.rho_dim,
6809 self.log_kappa_dim
6810 ))
6811 .into());
6812 }
6813 let theta_kappa = theta.slice(s![..kappa_theta_len]).to_owned();
6814 let full_log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
6815 &theta_kappa,
6816 self.rho_dim,
6817 self.all_dims.clone(),
6818 );
6819
6820 let n = self.realizers.len();
6824 let mut remaining = full_log_kappa;
6825 for block_idx in 0..n {
6826 let count = self.block_term_counts[block_idx];
6827 if block_idx < n - 1 {
6828 let (block_lk, rest) = remaining.split_at(count);
6829 self.realizers[block_idx]
6830 .apply_log_kappa(&block_lk, &self.block_term_indices[block_idx])?;
6831 remaining = rest;
6832 } else {
6833 self.realizers[block_idx]
6835 .apply_log_kappa(&remaining, &self.block_term_indices[block_idx])?;
6836 }
6837 }
6838
6839 log::info!(
6840 "[STAGE] ensure_theta (n-block, {} blocks, {} realizers): {:.3}s",
6841 n,
6842 self.realizers.len(),
6843 t_ensure.elapsed().as_secs_f64(),
6844 );
6845 self.current_theta = Some(theta.clone());
6846 self.last_cost = None;
6847 self.last_eval = None;
6848 Ok(())
6849 }
6850
6851 impl_exact_joint_theta_memo!();
6852
6853 fn store_cost_only(&mut self, theta: &Array1<f64>, cost: f64) {
6859 if self
6860 .current_theta
6861 .as_ref()
6862 .is_some_and(|cached| theta_values_match(cached, theta))
6863 {
6864 self.last_cost = Some(cost);
6865 }
6866 }
6867
6868 fn specs(&self) -> Vec<&TermCollectionSpec> {
6869 self.realizers.iter().map(|r| r.spec()).collect()
6870 }
6871
6872 fn designs(&self) -> Vec<&TermCollectionDesign> {
6873 self.realizers.iter().map(|r| r.design()).collect()
6874 }
6875
6876 fn design_revision(&self) -> u64 {
6886 self.realizers
6887 .iter()
6888 .fold(0u64, |acc, r| acc.wrapping_add(r.design_revision()))
6889 }
6890}
6891
6892pub(crate) fn seed_risk_profile_for_likelihood_family(
6893 family: &LikelihoodSpec,
6894) -> gam_problem::SeedRiskProfile {
6895 match &family.response {
6896 ResponseFamily::Gaussian => gam_problem::SeedRiskProfile::Gaussian,
6897 ResponseFamily::RoystonParmar => gam_problem::SeedRiskProfile::Survival,
6898 ResponseFamily::Binomial
6899 | ResponseFamily::Poisson
6900 | ResponseFamily::Tweedie { .. }
6901 | ResponseFamily::NegativeBinomial { .. }
6902 | ResponseFamily::Beta { .. }
6903 | ResponseFamily::Gamma => gam_problem::SeedRiskProfile::GeneralizedLinear,
6904 }
6905}
6906
// NOTE(review): not referenced in this section of the file. From its name this looks
// like a theta-dimension cap for second-order (Hessian-based) exact-joint outer
// optimization; confirm how callers apply it.
const EXACT_JOINT_SECOND_ORDER_THETA_CAP: usize = 8;
6915
6916fn exact_joint_seed_config(
6917 risk_profile: gam_problem::SeedRiskProfile,
6918 auxiliary_dim: usize,
6919) -> gam_problem::SeedConfig {
6920 let mut config = gam_problem::SeedConfig {
6921 risk_profile,
6922 num_auxiliary_trailing: auxiliary_dim,
6923 ..Default::default()
6924 };
6925 match risk_profile {
6926 gam_problem::SeedRiskProfile::Gaussian
6927 | gam_problem::SeedRiskProfile::GaussianLocationScale => {
6928 config.max_seeds = 4;
6929 config.seed_budget = 2;
6930 }
6931 gam_problem::SeedRiskProfile::GeneralizedLinear => {
6932 config.max_seeds = 1;
6937 config.seed_budget = 1;
6938 config.screen_max_inner_iterations = 8;
6939 }
6940 gam_problem::SeedRiskProfile::Survival => {
6941 config.max_seeds = 8;
6947 config.seed_budget = 4;
6948 config.screen_max_inner_iterations = 8;
6949 }
6950 }
6951 config
6952}
6953
#[cfg(test)]
mod exact_joint_seed_config_tests {
    use super::*;

    /// Non-Gaussian profiles: GLM screens a single seed; survival fans out wider.
    /// Both use 8 screening inner iterations and carry the auxiliary count through.
    #[test]
    fn exact_joint_marginal_slope_profiles_get_deeper_startup_validation() {
        let glm = exact_joint_seed_config(gam_problem::SeedRiskProfile::GeneralizedLinear, 2);
        assert_eq!(
            (
                glm.max_seeds,
                glm.seed_budget,
                glm.screen_max_inner_iterations,
                glm.num_auxiliary_trailing,
            ),
            (1, 1, 8, 2)
        );

        let surv = exact_joint_seed_config(gam_problem::SeedRiskProfile::Survival, 3);
        assert_eq!(
            (
                surv.max_seeds,
                surv.seed_budget,
                surv.screen_max_inner_iterations,
                surv.num_auxiliary_trailing,
            ),
            (8, 4, 8, 3)
        );
    }

    /// Gaussian keeps a 4-seed / 2-budget multistart and the default screening depth.
    #[test]
    fn exact_joint_gaussian_keeps_tight_historical_multistart_budget() {
        let default_screen = gam_problem::SeedConfig::default().screen_max_inner_iterations;
        let gauss = exact_joint_seed_config(gam_problem::SeedRiskProfile::Gaussian, 1);
        assert_eq!(
            (
                gauss.max_seeds,
                gauss.seed_budget,
                gauss.screen_max_inner_iterations,
                gauss.num_auxiliary_trailing,
            ),
            (4, 2, default_screen, 1)
        );
    }
}
6985
#[cfg(test)]
mod wood_reference_df_tests {
    use super::*;

    /// F = diag(0.9, 0.4).
    /// tr(F) = 1.3 and tr(F^2) = 0.81 + 0.16 = 0.97, so edf1 = 2 * 1.3 - 0.97 = 1.63.
    #[test]
    fn edf1_equals_two_trace_minus_trace_of_square() {
        let influence = ndarray::array![[0.9_f64, 0.0], [0.0, 0.4]];
        let edf = 1.3;
        let got = wood_reference_df(Some(&influence), &(0..2)).unwrap();
        assert!(
            (got - 1.63).abs() < 1e-12,
            "edf1 should be 2*tr - tr(F^2) = 1.63, got {got}"
        );
        assert!(got >= edf - 1e-12, "edf1 {got} must be >= edf {edf}");
    }

    /// Large off-diagonals: tr(F) = 1.0 but tr(F^2) = 0.25 + 1600 + 1600 + 0.25.
    /// The raw 2*tr - tr(F^2) is therefore hugely negative; the result must still be
    /// at least tr.
    #[test]
    fn edf1_never_collapses_below_edf_when_offdiagonals_blow_up() {
        let influence = ndarray::array![[0.5_f64, 40.0], [40.0, 0.5]];
        let tr = 1.0_f64;
        let got = wood_reference_df(Some(&influence), &(0..2)).unwrap();
        assert!(
            got >= tr - 1e-12,
            "edf1 must be floored at edf (=tr={tr}) even when tr(F^2) explodes, got {got}"
        );
        assert!(got.is_finite() && got > 0.0, "edf1 must stay finite/positive");
    }

    /// `None` is expected in three cases:
    /// - no matrix is supplied;
    /// - the matrix has zero trace;
    /// - the coefficient range is wider than the matrix.
    #[test]
    fn returns_none_on_nonpositive_or_missing_trace() {
        assert!(wood_reference_df(None, &(0..2)).is_none());

        let all_zero = ndarray::array![[0.0_f64, 0.0], [0.0, 0.0]];
        assert!(wood_reference_df(Some(&all_zero), &(0..2)).is_none());

        let half_identity = ndarray::array![[0.5_f64, 0.0], [0.0, 0.5]];
        assert!(wood_reference_df(Some(&half_identity), &(0..5)).is_none());
    }
}
7043
7044pub(crate) fn exact_joint_multistart_outer_problem(
7045 theta0: &Array1<f64>,
7046 lower: &Array1<f64>,
7047 upper: &Array1<f64>,
7048 rho_dim: usize,
7049 auxiliary_dim: usize,
7050 n_params: usize,
7051 gradient: gam_problem::Derivative,
7052 hessian: gam_problem::DeclaredHessianForm,
7053 prefer_gradient_only: bool,
7054 disable_fixed_point: bool,
7055 risk_profile: gam_problem::SeedRiskProfile,
7056 tolerance: f64,
7057 max_iter: usize,
7058 bfgs_step_cap: Option<f64>,
7067 bfgs_step_cap_psi: Option<f64>,
7068 screening_cap: Option<Arc<AtomicUsize>>,
7069 profiled_objective_size: Option<(usize, usize)>,
7090 has_constant_curvature: bool,
7099) -> gam_solve::rho_optimizer::OuterProblem {
7100 let mut seed_heuristic = theta0.to_vec();
7101 for value in &mut seed_heuristic[..rho_dim] {
7102 *value = value.exp();
7103 }
7104 let rho_ceiling = if has_constant_curvature {
7109 gam_solve::estimate::RHO_BOUND
7110 } else {
7111 12.0
7112 };
7113 let mut problem = gam_solve::rho_optimizer::OuterProblem::new(n_params)
7114 .with_gradient(gradient)
7115 .with_hessian(hessian)
7116 .with_prefer_gradient_only(prefer_gradient_only)
7117 .with_disable_fixed_point(disable_fixed_point)
7118 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Automatic)
7128 .with_psi_dim(auxiliary_dim)
7129 .with_tolerance(tolerance)
7130 .with_max_iter(max_iter)
7131 .with_bounds(lower.clone(), upper.clone())
7132 .with_initial_rho(theta0.clone())
7133 .with_bfgs_step_cap(bfgs_step_cap)
7134 .with_bfgs_step_cap_psi(bfgs_step_cap_psi)
7135 .with_seed_config({
7136 let mut sc = exact_joint_seed_config(risk_profile, auxiliary_dim);
7137 if has_constant_curvature {
7138 sc.bounds = (sc.bounds.0, rho_ceiling);
7142 }
7161 sc
7162 })
7163 .with_rho_bound(rho_ceiling)
7164 .with_heuristic_lambdas(seed_heuristic);
7165 if let Some((n_obs, p_cols)) = profiled_objective_size {
7166 problem = problem
7174 .with_objective_scale(Some(n_obs as f64))
7175 .with_problem_size(n_obs, p_cols)
7176 .with_arc_initial_regularization(Some(0.25))
7177 .with_operator_initial_trust_radius(Some(4.0));
7178 }
7179 if let Some(screening_cap) = screening_cap {
7180 problem = problem
7181 .with_screening_cap(screening_cap)
7182 .with_screen_initial_rho(true);
7183 }
7184 problem
7185}
7186
/// Whether a kappa-phase failure message matches a known failure after which the fit
/// may proceed with kappa held fixed.
///
/// Detection is by substring match on three known message fragments.
fn kappa_phase_failure_is_fixed_kappa_recoverable(message: &str) -> bool {
    const RECOVERABLE_MARKERS: [&str; 3] = [
        "no candidate seeds passed outer startup validation",
        "joint hyper rho dimension mismatch",
        "objective returned a non-finite cost",
    ];
    RECOVERABLE_MARKERS
        .iter()
        .any(|&marker| message.contains(marker))
}
7202
7203pub fn optimize_spatial_length_scale_exact_joint<FitOut, FitFn, ExactFn, ExactEfsFn, SeedFn>(
7204 data: ArrayView2<'_, f64>,
7205 block_specs: &[TermCollectionSpec],
7206 block_term_indices: &[Vec<usize>],
7207 kappa_options: &SpatialLengthScaleOptimizationOptions,
7208 joint_setup: &ExactJointHyperSetup,
7209 seed_risk_profile: gam_problem::SeedRiskProfile,
7210 analytic_joint_gradient_available: bool,
7211 analytic_joint_hessian_available: bool,
7212 disable_fixed_point: bool,
7213 screening_cap: Option<Arc<AtomicUsize>>,
7214 outer_derivative_policy: gam_model_api::families::custom_family::OuterDerivativePolicy,
7215 mut fit_fn: FitFn,
7216 mut exact_fn: ExactFn,
7217 mut exact_efs_fn: ExactEfsFn,
7218 mut seed_inner_beta_fn: SeedFn,
7219) -> Result<SpatialLengthScaleOptimizationResult<FitOut>, String>
7220where
7221 FitOut: Clone,
7222 FitFn: FnMut(
7223 &Array1<f64>,
7224 &[TermCollectionSpec],
7225 &[TermCollectionDesign],
7226 ) -> Result<FitOut, String>,
7227 ExactFn: FnMut(
7228 &Array1<f64>,
7229 &[TermCollectionSpec],
7230 &[TermCollectionDesign],
7231 gam_solve::estimate::reml::reml_outer_engine::EvalMode,
7232 &gam_problem::outer_subsample::RowSet,
7233 ) -> Result<
7234 (
7235 f64,
7236 Array1<f64>,
7237 gam_problem::HessianResult,
7238 ),
7239 String,
7240 >,
7241 ExactEfsFn: FnMut(
7242 &Array1<f64>,
7243 &[TermCollectionSpec],
7244 &[TermCollectionDesign],
7245 ) -> Result<gam_problem::EfsEval, String>,
7246 SeedFn:
7247 FnMut(&Array1<f64>) -> Result<gam_solve::rho_optimizer::SeedOutcome, EstimationError>,
7248{
7249 let n_blocks = block_specs.len();
7250 if block_term_indices.len() != n_blocks {
7251 return Err(SmoothError::dimension_mismatch(format!(
7252 "block_specs ({}) and block_term_indices ({}) length mismatch",
7253 n_blocks,
7254 block_term_indices.len()
7255 ))
7256 .into());
7257 }
7258
7259 let log_kappa_dim = joint_setup.log_kappa_dim();
7260
7261 log::warn!(
7262 "[OUTER-FD-AUDIT/spatial-exact-joint] driver entry: aux_dim={} log_kappa_dim={} kappa_enabled={} rho_dim={} theta0_len={}",
7263 joint_setup.auxiliary_dim(),
7264 log_kappa_dim,
7265 kappa_options.enabled,
7266 joint_setup.rho_dim(),
7267 joint_setup.theta0().len()
7268 );
7269
7270 if joint_setup.auxiliary_dim() == 0 && (!kappa_options.enabled || log_kappa_dim == 0) {
7274 log::warn!(
7275 "[OUTER-FD-AUDIT/spatial-exact-joint] taking FAST path (no outer theta optimization in this driver)"
7276 );
7277 let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
7278 data, block_specs,
7279 )
7280 .map_err(|e| {
7281 format!("failed to build and freeze joint block designs during exact joint kappa optimization: {e}")
7282 })?;
7283 let theta0 = joint_setup.theta0();
7284
7285 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7287 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7288 let fit = fit_fn(&theta0, &spec_refs, &design_refs)?;
7289 return Ok(SpatialLengthScaleOptimizationResult {
7290 resolved_specs,
7291 designs,
7292 fit,
7293 timing: None,
7294 });
7295 }
7296
7297 let theta0 = joint_setup.theta0();
7301 let lower = joint_setup.lower();
7302 let upper = joint_setup.upper();
7303 if theta0.len() < log_kappa_dim || lower.len() != theta0.len() || upper.len() != theta0.len() {
7304 return Err(SmoothError::dimension_mismatch(format!(
7305 "invalid exact joint theta setup: theta0={}, lower={}, upper={}, required_log_kappa_dim={}",
7306 theta0.len(),
7307 lower.len(),
7308 upper.len(),
7309 log_kappa_dim
7310 ))
7311 .into());
7312 }
7313 let rho_dim = joint_setup.rho_dim();
7314 let all_dims = joint_setup.log_kappa_dims_per_term();
7315
7316 let (boot_designs, best_specs) = build_term_collection_designs_and_freeze_joint(
7318 data,
7319 block_specs,
7320 )
7321 .map_err(|e| {
7322 format!(
7323 "failed to build and freeze joint block designs during exact joint kappa bootstrap: {e}"
7324 )
7325 })?;
7326 let policy_hessian_form = outer_derivative_policy.declared_hessian_form();
7336 let analytic_outer_hessian_available = analytic_joint_hessian_available
7337 && matches!(
7338 policy_hessian_form,
7339 gam_problem::DeclaredHessianForm::Either
7340 | gam_problem::DeclaredHessianForm::Dense
7341 | gam_problem::DeclaredHessianForm::Operator { .. }
7342 );
7343 let prefer_gradient_only = !analytic_outer_hessian_available;
7344
7345 let theta_dim = theta0.len();
7346 let psi_dim = theta_dim - rho_dim;
7347
7348 let cache_blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)> = best_specs
7350 .iter()
7351 .zip(boot_designs.iter())
7352 .zip(block_term_indices.iter())
7353 .map(|((spec, design), terms)| (spec.clone(), design.clone(), terms.clone()))
7354 .collect();
7355
7356 struct NBlockExactJointState<'d> {
7357 cache: ExactJointDesignCache<'d>,
7358 }
7359
7360 let mut state = NBlockExactJointState {
7361 cache: ExactJointDesignCache::new(data, cache_blocks, rho_dim, all_dims.clone())?,
7362 };
7363
7364 const KAPPA_PILOT_K: usize = 5_000;
7389 const KAPPA_POLISH_K: usize = 25_000;
7390 const KAPPA_POLISH_TRIGGER_N: usize = 100_000;
7391
7392 let n_total = data.nrows();
7393 let use_staged_kappa = outer_derivative_policy.should_use_staged_kappa(n_total);
7394 if use_staged_kappa {
7395 log::info!(
7396 "[KAPPA-STAGED] auto-engaging pilot+polish schedule: n={} pilot_k={} polish_k={}",
7397 n_total,
7398 KAPPA_PILOT_K,
7399 KAPPA_POLISH_K,
7400 );
7401 }
7402
7403 fn build_uniform_pilot_subsample(
7420 n_total: usize,
7421 k_target: usize,
7422 seed: u64,
7423 ) -> gam_problem::outer_subsample::OuterScoreSubsample {
7424 use gam_problem::outer_subsample::OuterScoreSubsample;
7425 let k = k_target.min(n_total);
7426 if k == 0 || n_total == 0 {
7427 return OuterScoreSubsample::from_uniform_inclusion_mask(Vec::new(), n_total, seed);
7428 }
7429 let mut mask: Vec<usize> = Vec::with_capacity(k);
7433 let mut state = seed.wrapping_add(0x9E3779B97F4A7C15);
7435 let splitmix = |s: &mut u64| -> u64 { gam_linalg::utils::splitmix64(s) };
7436 let mut taken = std::collections::HashSet::with_capacity(k);
7437 for j in (n_total - k)..n_total {
7438 let r = (splitmix(&mut state) % (j as u64 + 1)) as usize;
7439 if !taken.insert(r) {
7440 taken.insert(j);
7441 mask.push(j);
7442 } else {
7443 mask.push(r);
7444 }
7445 }
7446 mask.sort_unstable();
7447 mask.dedup();
7448 OuterScoreSubsample::from_uniform_inclusion_mask(mask, n_total, seed)
7449 }
7450
7451 let current_row_set: std::cell::RefCell<gam_problem::outer_subsample::RowSet> = if use_staged_kappa {
7452 let pilot = build_uniform_pilot_subsample(n_total, KAPPA_PILOT_K, n_total as u64);
7453 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::Subsample {
7454 rows: std::sync::Arc::clone(&pilot.rows),
7455 n_full: n_total,
7456 })
7457 } else {
7458 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::All)
7459 };
7460
7461 let exact_fn_cell = std::cell::RefCell::new(&mut exact_fn);
7462 let exact_efs_fn_cell = std::cell::RefCell::new(&mut exact_efs_fn);
7463
7464 use std::cell::Cell;
7479 let kphase_cost_calls: Cell<usize> = Cell::new(0);
7480 let kphase_cost_total_s: Cell<f64> = Cell::new(0.0);
7481 let kphase_eval_calls: Cell<usize> = Cell::new(0);
7482 let kphase_eval_total_s: Cell<f64> = Cell::new(0.0);
7483 let kphase_efs_calls: Cell<usize> = Cell::new(0);
7484 let kphase_efs_total_s: Cell<f64> = Cell::new(0.0);
7485 let kphase_optim_start = std::time::Instant::now();
7486 let kphase_log_kappa_dim = log_kappa_dim;
7487 let kphase_log_norms = |theta: &Array1<f64>| -> (f64, f64) {
7488 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
7489 let log_kappa_norm = if kphase_log_kappa_dim > 0 && theta.len() >= kphase_log_kappa_dim {
7490 let start = theta.len() - kphase_log_kappa_dim;
7491 theta.iter().skip(start).map(|v| v * v).sum::<f64>().sqrt()
7492 } else {
7493 0.0
7494 };
7495 (theta_norm, log_kappa_norm)
7496 };
7497
7498 use gam_solve::rho_optimizer::OuterEvalOrder;
7499 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7500
7501 let joint_p_cols: usize = boot_designs
7505 .iter()
7506 .map(|d| d.design.ncols())
7507 .sum::<usize>()
7508 .max(1);
7509
7510 let problem = exact_joint_multistart_outer_problem(
7511 &theta0,
7512 &lower,
7513 &upper,
7514 rho_dim,
7515 psi_dim,
7516 theta_dim,
7517 if analytic_joint_gradient_available {
7518 Derivative::Analytic
7519 } else {
7520 Derivative::Unavailable
7521 },
7522 if analytic_outer_hessian_available {
7523 DeclaredHessianForm::Either
7524 } else {
7525 DeclaredHessianForm::Unavailable
7526 },
7527 prefer_gradient_only,
7528 disable_fixed_point,
7529 seed_risk_profile,
7530 kappa_options.rel_tol.max(1e-6),
7531 kappa_options.max_outer_iter.max(1),
7532 Some(5.0),
7534 Some(kappa_options.log_step.clamp(0.25, 1.0)),
7536 screening_cap.clone(),
7537 Some((n_total, joint_p_cols)),
7540 block_specs
7543 .iter()
7544 .any(|s| !constant_curvature_term_indices(s).is_empty()),
7545 );
7546
7547 fn collect_specs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionSpec> {
7549 cache.specs().into_iter().cloned().collect()
7550 }
7551 fn collect_designs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionDesign> {
7552 cache.designs().into_iter().cloned().collect()
7553 }
7554
7555 let result = {
7556 let eval_outer = |ctx: &mut &mut NBlockExactJointState<'_>,
7557 theta: &Array1<f64>,
7558 order: OuterEvalOrder|
7559 -> Result<OuterEval, EstimationError> {
7560 if let Some((cost, grad, hess)) = ctx.cache.memoized_eval(theta) {
7561 let cached_satisfies_order = match order {
7562 OuterEvalOrder::Value => true,
7563 OuterEvalOrder::ValueAndGradient => true,
7564 OuterEvalOrder::ValueGradientHessian => hess.is_analytic(),
7565 };
7566 if cached_satisfies_order {
7567 if !cost.is_finite() {
7568 return Ok(OuterEval::infeasible(theta.len()));
7569 }
7570 if grad.iter().any(|v| !v.is_finite()) {
7583 return Ok(OuterEval::infeasible(theta.len()));
7584 }
7585 return Ok(OuterEval {
7586 cost,
7587 gradient: grad,
7588 hessian: hess,
7589 inner_beta_hint: None,
7590 });
7591 }
7592 }
7593 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7610 return Ok(OuterEval::infeasible(theta.len()));
7611 }
7612 if let Err(err) = ctx.cache.ensure_theta(theta) {
7613 log::warn!(
7614 "[OUTER] n-block exact-joint spatial: ensure_theta failed during gradient evaluation: {err}"
7615 );
7616 return Ok(OuterEval::infeasible(theta.len()));
7617 }
7618 let design_revision = Some(ctx.cache.design_revision());
7619 let specs = collect_specs(&ctx.cache);
7620 let designs = collect_designs(&ctx.cache);
7621 let clamped = outer_derivative_policy.order_for_evaluation(order);
7629 let need_hessian = matches!(clamped, OuterEvalOrder::ValueGradientHessian)
7630 && analytic_outer_hessian_available;
7631 let eval_mode = if need_hessian {
7632 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
7633 } else {
7634 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
7635 };
7636 let t0 = std::time::Instant::now();
7637 let result = {
7638 let row_set_borrow = current_row_set.borrow();
7639 (*exact_fn_cell.borrow_mut())(theta, &specs, &designs, eval_mode, &row_set_borrow)
7640 };
7641 let elapsed_s = t0.elapsed().as_secs_f64();
7642 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
7643 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
7644 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7645 log::info!(
7646 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7647 kphase_eval_calls.get(),
7648 order,
7649 design_revision,
7650 theta_norm,
7651 log_kappa_norm,
7652 elapsed_s,
7653 );
7654 match result {
7655 Ok((cost, grad, hess)) => {
7656 ctx.cache.store_eval((cost, grad.clone(), hess.clone()));
7657 if !cost.is_finite() {
7658 return Ok(OuterEval::infeasible(theta.len()));
7659 }
7660 if grad.iter().any(|v| !v.is_finite()) {
7673 return Ok(OuterEval::infeasible(theta.len()));
7674 }
7675 Ok(OuterEval {
7676 cost,
7677 gradient: grad,
7678 hessian: hess,
7679 inner_beta_hint: None,
7680 })
7681 }
7682 Err(err) => {
7683 log::warn!(
7684 "[OUTER] n-block exact-joint spatial: exact evaluation failed: {err}"
7685 );
7686 Ok(OuterEval::infeasible(theta.len()))
7687 }
7688 }
7689 };
7690
7691 let obj = problem.build_objective_with_eval_order(
7692 &mut state,
7693 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7694 if let Some(cost) = ctx.cache.memoized_cost(theta) {
7695 return Ok(cost);
7696 }
7697 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7705 return Ok(f64::INFINITY);
7706 }
7707 if let Err(err) = ctx.cache.ensure_theta(theta) {
7708 log::warn!(
7709 "[OUTER] n-block exact-joint spatial: ensure_theta failed during cost evaluation: {err}"
7710 );
7711 return Ok(f64::INFINITY);
7712 }
7713 let design_revision = Some(ctx.cache.design_revision());
7714 let specs = collect_specs(&ctx.cache);
7715 let designs = collect_designs(&ctx.cache);
7716 let t0 = std::time::Instant::now();
7723 let result = {
7724 let row_set_borrow = current_row_set.borrow();
7725 (*exact_fn_cell.borrow_mut())(
7726 theta,
7727 &specs,
7728 &designs,
7729 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
7730 &row_set_borrow,
7731 )
7732 };
7733 let elapsed_s = t0.elapsed().as_secs_f64();
7734 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
7735 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
7736 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7737 log::info!(
7738 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7739 kphase_cost_calls.get(),
7740 design_revision,
7741 theta_norm,
7742 log_kappa_norm,
7743 elapsed_s,
7744 );
7745 match result {
7746 Ok((cost, _grad, _hess)) => {
7747 ctx.cache.store_cost_only(theta, cost);
7753 Ok(cost)
7754 }
7755 Err(err) => {
7756 log::warn!(
7757 "[OUTER] n-block exact-joint spatial: exact cost evaluation failed: {err}"
7758 );
7759 Ok(f64::INFINITY)
7760 }
7761 }
7762 },
7763 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7764 eval_outer(
7765 ctx,
7766 theta,
7767 if analytic_outer_hessian_available {
7768 OuterEvalOrder::ValueGradientHessian
7769 } else {
7770 OuterEvalOrder::ValueAndGradient
7771 },
7772 )
7773 },
7774 |ctx: &mut &mut NBlockExactJointState<'_>,
7775 theta: &Array1<f64>,
7776 order: OuterEvalOrder| { eval_outer(ctx, theta, order) },
7777 None::<fn(&mut &mut NBlockExactJointState<'_>)>,
7778 Some(
7779 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7780 ctx.cache
7781 .ensure_theta(theta)
7782 .map_err(EstimationError::InvalidInput)?;
7783 let design_revision = Some(ctx.cache.design_revision());
7784 let specs = collect_specs(&ctx.cache);
7785 let designs = collect_designs(&ctx.cache);
7786 let t0 = std::time::Instant::now();
7787 let eval_result = (*exact_efs_fn_cell.borrow_mut())(
7788 theta,
7789 &specs,
7790 &designs,
7791 );
7792 let elapsed_s = t0.elapsed().as_secs_f64();
7793 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
7794 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
7795 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7796 log::info!(
7797 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7798 kphase_efs_calls.get(),
7799 design_revision,
7800 theta_norm,
7801 log_kappa_norm,
7802 elapsed_s,
7803 );
7804 let eval = eval_result.map_err(EstimationError::RemlOptimizationFailed)?;
7805 Ok(eval)
7806 },
7807 ),
7808 );
7809 let mut obj = obj.with_seed_inner_state(
7810 move |_ctx: &mut &mut NBlockExactJointState<'_>, beta: &Array1<f64>| {
7811 (seed_inner_beta_fn)(beta)
7812 },
7813 );
7814
7815 match problem.run(&mut obj, "n-block exact-joint spatial") {
7816 Ok(result) => result,
7817 Err(e) => {
7818 let message = e.to_string();
7819 if kappa_phase_failure_is_fixed_kappa_recoverable(&message) {
7839 drop(obj);
7840 log::warn!(
7841 "[KAPPA-PHASE] length-scale optimization could not validate any seed \
7842 ({message}); falling back to a FIXED bootstrap κ (skipping κ \
7843 optimization) and fitting there — a real model at the initial \
7844 length-scale rather than raising (gam#787/#860)."
7845 );
7846 let (designs, resolved_specs) =
7847 build_term_collection_designs_and_freeze_joint(data, block_specs).map_err(
7848 |build_err| {
7849 format!(
7850 "fixed-κ fallback failed to build and freeze joint block \
7851 designs after κ optimization could not validate a seed \
7852 ({message}): {build_err}"
7853 )
7854 },
7855 )?;
7856 let fixed_theta0 = joint_setup.theta0();
7857 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7858 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7859 let fit = fit_fn(&fixed_theta0, &spec_refs, &design_refs)?;
7860 return Ok(SpatialLengthScaleOptimizationResult {
7861 resolved_specs,
7862 designs,
7863 fit,
7864 timing: None,
7865 });
7866 }
7867 return Err(message);
7868 }
7869 }
7870 }; let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
7880 log::info!(
7881 "[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} optim_total_s={:.4}",
7882 kphase_log_kappa_dim,
7883 kphase_cost_calls.get(),
7884 kphase_cost_total_s.get(),
7885 kphase_eval_calls.get(),
7886 kphase_eval_total_s.get(),
7887 kphase_efs_calls.get(),
7888 kphase_efs_total_s.get(),
7889 kphase_total_s,
7890 );
7891 let timing = SpatialLengthScaleOptimizationTiming {
7892 log_kappa_dim: kphase_log_kappa_dim,
7893 cost_calls: kphase_cost_calls.get(),
7894 cost_total_s: kphase_cost_total_s.get(),
7895 eval_calls: kphase_eval_calls.get(),
7896 eval_total_s: kphase_eval_total_s.get(),
7897 efs_calls: kphase_efs_calls.get(),
7898 efs_total_s: kphase_efs_total_s.get(),
7899 slow_path_resets: 0,
7900 design_revision_delta: 0,
7901 nfree_miss_shape: 0,
7902 nfree_miss_value: 0,
7903 nfree_miss_gradient: 0,
7904 nfree_miss_penalty: 0,
7905 nfree_miss_revision: 0,
7906 nfree_miss_second_order: 0,
7907 nfree_miss_other: 0,
7908 optim_total_s: kphase_total_s,
7909 };
7910
7911 let theta_star = result.rho;
7912
7913 if use_staged_kappa && n_total >= KAPPA_POLISH_TRIGGER_N {
7930 let polish = build_uniform_pilot_subsample(
7931 n_total,
7932 KAPPA_POLISH_K,
7933 (n_total as u64).wrapping_add(0xA5A5A5A5),
7934 );
7935 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::Subsample {
7936 rows: std::sync::Arc::clone(&polish.rows),
7937 n_full: n_total,
7938 };
7939 log::info!(
7940 "[KAPPA-STAGED] rotating to polish subsample: k={} at theta_star",
7941 polish.rows.len(),
7942 );
7943 state.cache.ensure_theta(&theta_star)?;
7947 let (polish_cost, polish_grad, _) = {
7948 let specs = collect_specs(&state.cache);
7949 let designs = collect_designs(&state.cache);
7950 let row_set_borrow = current_row_set.borrow();
7951 exact_fn(
7952 &theta_star,
7953 &specs,
7954 &designs,
7955 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
7956 &row_set_borrow,
7957 )?
7958 };
7959 if !polish_cost.is_finite() || polish_grad.iter().any(|value| !value.is_finite()) {
7960 return Err(
7961 "polish subsample exact-joint evaluation produced non-finite objective pieces"
7962 .to_string(),
7963 );
7964 }
7965 }
7966 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::All;
7967 if use_staged_kappa {
7968 log::info!(
7969 "[KAPPA-STAGED] rotating to full data for final coefficient fit (n={})",
7970 n_total,
7971 );
7972 }
7973
7974 state.cache.ensure_theta(&theta_star)?;
7975
7976 let resolved_specs: Vec<TermCollectionSpec> = collect_specs(&state.cache);
7977 let designs: Vec<TermCollectionDesign> = collect_designs(&state.cache);
7978
7979 let fit = fit_fn(&theta_star, &resolved_specs, &designs)?;
7980
7981 for spec in &resolved_specs {
7982 log_spatial_aniso_scales(spec);
7983 }
7984
7985 Ok(SpatialLengthScaleOptimizationResult {
7986 resolved_specs,
7987 designs,
7988 fit,
7989 timing: Some(timing),
7990 })
7991}
7992
7993fn try_exact_joint_latent_coord_optimization(
7994 data: ArrayView2<'_, f64>,
7995 y: ArrayView1<'_, f64>,
7996 weights: ArrayView1<'_, f64>,
7997 offset: ArrayView1<'_, f64>,
7998 resolvedspec: &TermCollectionSpec,
7999 best: &FittedTermCollection,
8000 family: LikelihoodSpec,
8001 options: &FitOptions,
8002 latent: &StandardLatentCoordConfig,
8003) -> Result<FittedTermCollectionWithSpec, EstimationError> {
8004 use gam_solve::rho_optimizer::OuterEvalOrder;
8005 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
8006
8007 let rho_dim = best.fit.lambdas.len();
8008 let latent_flat_dim = latent.values.len();
8009 if latent_flat_dim == 0 {
8010 crate::bail_invalid_estim!(
8011 "latent-coordinate optimization requires a non-empty latent block"
8012 );
8013 }
8014 let direct_hypers =
8015 latent_coord_initial_direct_hypers(latent.values.id_mode(), latent.values.latent_dim())?;
8016 let analytic_rho_count = latent
8017 .analytic_penalties
8018 .as_ref()
8019 .map_or(0, |registry| registry.total_rho_count());
8020 let latent_coord_ext_dim = latent_flat_dim + analytic_rho_count + direct_hypers.len();
8021
8022 let mut theta0 = Array1::<f64>::zeros(rho_dim + latent_coord_ext_dim);
8023 theta0
8024 .slice_mut(s![..rho_dim])
8025 .assign(&best.fit.lambdas.mapv(f64::ln));
8026 theta0
8027 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
8028 .assign(latent.values.as_flat());
8029 if !direct_hypers.is_empty() {
8030 let direct_start = rho_dim + latent_flat_dim + analytic_rho_count;
8031 theta0
8032 .slice_mut(s![direct_start..direct_start + direct_hypers.len()])
8033 .assign(&direct_hypers);
8034 }
8035
8036 let mut lower = Array1::<f64>::from_elem(theta0.len(), -12.0);
8037 let mut upper = Array1::<f64>::from_elem(theta0.len(), 12.0);
8038 let latent_bound = latent
8039 .values
8040 .as_flat()
8041 .iter()
8042 .fold(1.0_f64, |acc, &v| acc.max(v.abs()))
8043 + 10.0;
8044 for axis in rho_dim..rho_dim + latent_flat_dim {
8045 lower[axis] = -latent_bound;
8046 upper[axis] = latent_bound;
8047 }
8048
8049 struct LatentJointContext<'d> {
8050 rho_dim: usize,
8051 cache: SingleBlockLatentCoordDesignCache,
8052 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
8053 }
8054
8055 impl<'d> LatentJointContext<'d> {
8056 fn eval_full(
8057 &mut self,
8058 theta: &Array1<f64>,
8059 order: OuterEvalOrder,
8060 ) -> Result<
8061 (
8062 f64,
8063 Array1<f64>,
8064 gam_problem::HessianResult,
8065 ),
8066 EstimationError,
8067 > {
8068 if let Some(eval) = self.cache.memoized_eval(theta) {
8069 return Ok(eval);
8070 }
8071 self.cache
8072 .ensure_theta(theta)
8073 .map_err(EstimationError::InvalidInput)?;
8074 let hyper_dirs = self
8075 .cache
8076 .hyper_dirs()
8077 .map_err(EstimationError::InvalidInput)?;
8078 let design_revision = Some(self.cache.design_revision());
8079 let registry_for_key = self.cache.analytic_penalties();
8080 self.evaluator
8081 .set_analytic_penalty_registry(registry_for_key.as_deref());
8082 let mut eval = evaluate_joint_reml_outer_eval_at_theta(
8083 &mut self.evaluator,
8084 self.cache.design(),
8085 theta,
8086 self.rho_dim,
8087 hyper_dirs,
8088 None,
8089 order,
8090 design_revision,
8091 )?;
8092 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
8093 if let Some(registry) = registry_for_key {
8094 let mut registry = registry.as_ref().clone();
8095 registry.apply_weight_schedules(
8096 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
8097 );
8098 add_analytic_penalty_objective_to_eval(
8099 theta,
8100 self.rho_dim,
8101 latent.as_ref(),
8102 ®istry,
8103 &mut eval,
8104 )?;
8105 }
8106 add_latent_id_objective_to_eval(
8107 theta,
8108 self.rho_dim,
8109 self.cache.analytic_penalty_rho_count(),
8110 latent.as_ref(),
8111 &mut eval,
8112 )?;
8113 self.cache.store_eval(eval.clone());
8114 Ok(eval)
8115 }
8116
8117 fn eval_efs(
8118 &mut self,
8119 theta: &Array1<f64>,
8120 ) -> Result<gam_problem::EfsEval, EstimationError> {
8121 self.cache
8122 .ensure_theta(theta)
8123 .map_err(EstimationError::InvalidInput)?;
8124 let hyper_dirs = self
8125 .cache
8126 .hyper_dirs()
8127 .map_err(EstimationError::InvalidInput)?;
8128 let registry_for_key = self.cache.analytic_penalties();
8129 self.evaluator
8130 .set_analytic_penalty_registry(registry_for_key.as_deref());
8131 let mut efs = evaluate_joint_reml_efs_at_theta(
8132 &mut self.evaluator,
8133 self.cache.design(),
8134 theta,
8135 self.rho_dim,
8136 hyper_dirs,
8137 None,
8138 Some(self.cache.design_revision()),
8139 )?;
8140 if let Some(registry) = registry_for_key {
8141 let mut registry = registry.as_ref().clone();
8142 registry.apply_weight_schedules(
8143 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
8144 );
8145 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
8146 let contribution = analytic_penalty_objective_contribution(
8147 theta,
8148 self.rho_dim,
8149 latent.as_ref(),
8150 ®istry,
8151 )?;
8152 efs.cost += contribution.cost;
8153 if let (Some(psi_gradient), Some(psi_indices)) =
8154 (efs.psi_gradient.as_mut(), efs.psi_indices.as_ref())
8155 {
8156 if psi_gradient.len() != psi_indices.len() {
8157 crate::bail_invalid_estim!(
8158 "latent-coordinate analytic penalty EFS psi gradient length mismatch: gradient={}, indices={}",
8159 psi_gradient.len(),
8160 psi_indices.len()
8161 );
8162 }
8163 for (local_idx, &theta_idx) in psi_indices.iter().enumerate() {
8164 psi_gradient[local_idx] += contribution.gradient[theta_idx];
8165 }
8166 }
8167 }
8168 Ok(efs)
8169 }
8170
8171 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
8172 if let Some(cost) = self.cache.memoized_cost(theta) {
8173 return cost;
8174 }
8175 if self.cache.ensure_theta(theta).is_err() {
8176 return f64::INFINITY;
8177 }
8178 let design_revision = Some(self.cache.design_revision());
8179 let registry_for_key = self.cache.analytic_penalties();
8180 self.evaluator
8181 .set_analytic_penalty_registry(registry_for_key.as_deref());
8182 let result = {
8183 let design = self.cache.design();
8184 self.evaluator.evaluate_cost_only(
8185 &design.design,
8186 &design.penalties,
8187 &design.nullspace_dims,
8188 design.linear_constraints.clone(),
8189 theta,
8190 self.rho_dim,
8191 None,
8192 "latent-coordinate-joint cost-only",
8193 design_revision,
8194 )
8195 };
8196 match result {
8197 Ok(cost) => {
8198 let latent = match self.cache.latent() {
8199 Ok(latent) => latent,
8200 Err(_) => return f64::INFINITY,
8201 };
8202 let contribution = match latent_id_objective_contribution(
8203 theta,
8204 self.rho_dim,
8205 self.cache.analytic_penalty_rho_count(),
8206 latent.as_ref(),
8207 ) {
8208 Ok(contribution) => contribution,
8209 Err(_) => return f64::INFINITY,
8210 };
8211 let cost = cost + contribution.cost;
8212 let cost = if let Some(registry) = registry_for_key {
8213 let mut registry = registry.as_ref().clone();
8214 registry.apply_weight_schedules(
8215 gam_solve::estimate::reml::outer_eval::current_outer_iter()
8216 as usize,
8217 );
8218 match analytic_penalty_objective_contribution(
8219 theta,
8220 self.rho_dim,
8221 latent.as_ref(),
8222 ®istry,
8223 ) {
8224 Ok(contribution) => cost + contribution.cost,
8225 Err(_) => return f64::INFINITY,
8226 }
8227 } else {
8228 cost
8229 };
8230 self.cache.store_cost(cost);
8231 cost
8232 }
8233 Err(_) => f64::INFINITY,
8234 }
8235 }
8236 }
8237
8238 let mut ctx = LatentJointContext {
8239 rho_dim,
8240 cache: SingleBlockLatentCoordDesignCache::new(
8241 data.to_owned(),
8242 resolvedspec.clone(),
8243 best.design.clone(),
8244 latent,
8245 rho_dim,
8246 )
8247 .map_err(EstimationError::InvalidInput)?,
8248 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
8249 y,
8250 weights,
8251 &best.design.design,
8252 offset,
8253 &best.design.penalties,
8254 &external_opts_for_design(&family, &best.design, options),
8255 "latent-coordinate-joint",
8256 )?,
8257 };
8258 let registry_for_key = ctx.cache.analytic_penalties();
8259 ctx.evaluator
8260 .set_analytic_penalty_registry(registry_for_key.as_deref());
8261 ctx.evaluator
8262 .set_persistent_latent_values_fingerprint(latent.values.id_mode());
8263 if let Some(cached_t) = ctx
8264 .evaluator
8265 .load_persistent_latent_values(latent.values.n_obs(), latent.values.latent_dim())
8266 {
8267 let cached_t: Array2<f64> = cached_t;
8268 for (dst, src) in theta0
8269 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
8270 .iter_mut()
8271 .zip(cached_t.iter())
8272 {
8273 *dst = *src;
8274 }
8275 }
8276
8277 let problem = exact_joint_multistart_outer_problem(
8278 &theta0,
8279 &lower,
8280 &upper,
8281 rho_dim,
8282 latent_coord_ext_dim,
8283 theta0.len(),
8284 Derivative::Analytic,
8285 DeclaredHessianForm::Unavailable,
8286 false,
8287 false,
8288 seed_risk_profile_for_likelihood_family(&family),
8289 options.tol,
8290 options.max_iter.max(1),
8291 Some(5.0),
8292 Some(0.5),
8293 None,
8294 Some((data.nrows(), best.design.design.ncols().max(1))),
8297 !constant_curvature_term_indices(resolvedspec).is_empty(),
8300 );
8301
8302 let eval_outer = |ctx: &mut &mut LatentJointContext<'_>,
8303 theta: &Array1<f64>,
8304 order: OuterEvalOrder|
8305 -> Result<OuterEval, EstimationError> {
8306 let (cost, gradient, hessian) = ctx.eval_full(theta, order)?;
8307 Ok(OuterEval {
8308 cost,
8309 gradient,
8310 hessian,
8311 inner_beta_hint: None,
8312 })
8313 };
8314
8315 let result = {
8316 let mut obj = problem.build_objective_with_eval_order(
8317 &mut ctx,
8318 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
8319 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| {
8320 eval_outer(ctx, theta, OuterEvalOrder::ValueAndGradient)
8321 },
8322 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
8323 eval_outer(ctx, theta, order)
8324 },
8325 Some(|ctx: &mut &mut LatentJointContext<'_>| {
8326 ctx.cache.reset();
8327 }),
8328 Some(|ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
8329 );
8330
8331 problem
8332 .run(&mut obj, "latent-coordinate joint REML")
8333 .map_err(|e| {
8334 EstimationError::InvalidInput(format!(
8335 "latent-coordinate joint optimization failed after exhausting strategy fallbacks: {e}"
8336 ))
8337 })?
8338 };
8339 if !result.converged {
8340 crate::bail_invalid_estim!(
8341 "latent-coordinate joint optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
8342 result.iterations,
8343 result.final_value,
8344 result.final_grad_norm_report(),
8345 );
8346 }
8347
8348 let theta_star = result.rho;
8349 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
8350 let mut final_data = data.to_owned();
8351 let flat_t = theta_star
8352 .slice(s![rho_dim..rho_dim + latent_flat_dim])
8353 .to_owned();
8354 let mut fitted_latent_values =
8355 Array2::<f64>::zeros((latent.values.n_obs(), latent.values.latent_dim()));
8356 for n in 0..latent.values.n_obs() {
8357 for axis in 0..latent.values.latent_dim() {
8358 let value = flat_t[n * latent.values.latent_dim() + axis];
8359 fitted_latent_values[[n, axis]] = value;
8360 final_data[[n, latent.feature_cols[axis]]] = value;
8361 }
8362 }
8363 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
8364 final_data.view(),
8365 y,
8366 weights,
8367 offset,
8368 resolvedspec,
8369 rho_star.as_slice(),
8370 family,
8371 options,
8372 )?;
8373 ctx.evaluator
8374 .store_persistent_latent_values(&fitted_latent_values);
8375 let mut fit = optimized.fit;
8376 fit.reml_score = result.final_value;
8377 fit.penalized_objective = result.final_value;
8378 Ok(FittedTermCollectionWithSpec {
8379 fit,
8380 design: optimized.design,
8381 resolvedspec: resolvedspec.clone(),
8382 adaptive_diagnostics: optimized.adaptive_diagnostics,
8383 kappa_timing: None,
8384 })
8385}
8386
8387pub fn fit_term_collectionwith_latent_coord_optimization(
8388 data: ArrayView2<'_, f64>,
8389 y: Array1<f64>,
8390 weights: Array1<f64>,
8391 offset: Array1<f64>,
8392 spec: &TermCollectionSpec,
8393 latent: &StandardLatentCoordConfig,
8394 family: LikelihoodSpec,
8395 options: &FitOptions,
8396) -> Result<FittedTermCollectionWithSpec, EstimationError> {
8397 let n = data.nrows();
8398 if !(y.len() == n && weights.len() == n && offset.len() == n) {
8399 crate::bail_invalid_estim!(
8400 "fit_term_collectionwith_latent_coord_optimization row mismatch: n={}, y={}, weights={}, offset={}",
8401 n,
8402 y.len(),
8403 weights.len(),
8404 offset.len()
8405 );
8406 }
8407 let best = fit_term_collection_forspec(
8408 data,
8409 y.view(),
8410 weights.view(),
8411 offset.view(),
8412 spec,
8413 family.clone(),
8414 options,
8415 )?;
8416 let resolvedspec = freeze_term_collection_from_design(spec, &best.design)?;
8417 try_exact_joint_latent_coord_optimization(
8418 data,
8419 y.view(),
8420 weights.view(),
8421 offset.view(),
8422 &resolvedspec,
8423 &best,
8424 family,
8425 options,
8426 latent,
8427 )
8428}
8429
/// Fits a term collection and, when enabled, optimizes the spatial length-scale (κ)
/// geometry of its spatial smooth terms.
///
/// Pipeline:
/// 1. If `kappa_options.enabled` is false or there are no spatial length-scale terms,
///    perform a single fit, freeze the spec from its design, and return.
/// 2. Validate `kappa_options`: `max_outer_iter >= 1`, finite `log_step > 0`, and finite
///    positive length-scale bounds with `max >= min`.
/// 3. Seed the geometry:
///    - pilot-subsample anisotropy initialization when `n` exceeds twice the pilot
///      threshold;
///    - a response-aware anisotropy seed;
///    - a κ-fair sign seed for each constant-curvature term.
/// 4. Fit a baseline with `superseded_fit_options(options)` and freeze the spec from it.
///    Then run an isotropic range pre-scan; when it proposes overrides, refit and refreeze.
/// 5. If no spatial terms remain after freezing, refit through
///    `fit_term_collection_forspecwith_heuristic_lambdas` using the baseline lambdas and
///    return.
/// 6. Run the exact joint length-scale optimization. If the pre-scan improved the geometry
///    but the joint solve produced no candidate, return the pre-scan fit. Otherwise the
///    joint result must succeed.
/// 7. If the joint solve did not improve on the baseline score and left the seed length
///    scales unchanged, run a ψ-window multistart rescue and keep the best-scoring
///    candidate. This applies only to non-anisotropic, non-constant-curvature setups.
///
/// # Errors
/// Returns an invalid-input error on a row-count mismatch between `data`, `y`, `weights`
/// and `offset`, or on invalid `kappa_options`. Errors from the fits, spec freezes,
/// pre-scan and joint optimization are propagated.
pub fn fit_term_collectionwith_spatial_length_scale_optimization(
    data: ArrayView2<'_, f64>,
    y: Array1<f64>,
    weights: Array1<f64>,
    offset: Array1<f64>,
    spec: &TermCollectionSpec,
    family: LikelihoodSpec,
    options: &FitOptions,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<FittedTermCollectionWithSpec, EstimationError> {
    let mut resolvedspec = spec.clone();
    let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
    let n = data.nrows();
    if !(y.len() == n && weights.len() == n && offset.len() == n) {
        crate::bail_invalid_estim!(
            "fit_term_collectionwith_spatial_length_scale_optimization row mismatch: n={}, y={}, weights={}, offset={}",
            n,
            y.len(),
            weights.len(),
            offset.len()
        );
    }
    // Nothing to optimize: a plain fit with the spec frozen from its design.
    if !kappa_options.enabled || spatial_terms.is_empty() {
        let out = fit_term_collection_forspec(
            data,
            y.view(),
            weights.view(),
            offset.view(),
            &resolvedspec,
            family,
            options,
        )?;
        let resolvedspec = freeze_term_collection_from_design(&resolvedspec, &out.design)?;
        return Ok(FittedTermCollectionWithSpec {
            fit: out.fit,
            design: out.design,
            resolvedspec,
            adaptive_diagnostics: out.adaptive_diagnostics,
            kappa_timing: None,
        });
    }
    // Option validation is only required when κ optimization will actually run.
    if kappa_options.max_outer_iter == 0 {
        crate::bail_invalid_estim!("spatial kappa optimization requires max_outer_iter >= 1");
    }
    if !(kappa_options.log_step.is_finite() && kappa_options.log_step > 0.0) {
        crate::bail_invalid_estim!("spatial kappa optimization requires log_step > 0");
    }
    if !(kappa_options.min_length_scale.is_finite()
        && kappa_options.max_length_scale.is_finite()
        && kappa_options.min_length_scale > 0.0
        && kappa_options.max_length_scale >= kappa_options.min_length_scale)
    {
        crate::bail_invalid_estim!(
            "spatial kappa optimization requires valid positive length_scale bounds"
        );
    }

    // Large data: derive only the initial anisotropy from a pilot subsample.
    // NOTE(review): `pilot_threshold * 2` is unchecked and would overflow for thresholds
    // above usize::MAX / 2; `saturating_mul` may be worth considering.
    let pilot_threshold = kappa_options.pilot_subsample_threshold;
    if pilot_threshold > 0 && n > pilot_threshold * 2 {
        log::info!(
            "[spatial-kappa] n={n} exceeds pilot threshold {}; using pilot geometry only for deterministic anisotropy initialization",
            pilot_threshold * 2,
        );
        apply_spatial_anisotropy_pilot_initializer(
            data,
            &mut resolvedspec,
            &spatial_terms,
            pilot_threshold,
            kappa_options,
        );
    }

    apply_response_aware_anisotropy_seed(data, y.view(), &mut resolvedspec, &spatial_terms);

    // Pin each constant-curvature term's baseline κ to the nonzero κ-fair scan value (if any);
    // per the log message below, profiled REML alone cannot determine the sign of κ.
    for term_idx in constant_curvature_term_indices(&resolvedspec) {
        if let Some(kappa_seed) =
            select_constant_curvature_kappa_sign_seed(data, y.view(), &resolvedspec, term_idx)
            && kappa_seed != 0.0
            && let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) =
                resolvedspec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis)
        {
            log::info!(
                "[#1464] pinned CC term {term_idx} baseline κ to κ-fair scan value {kappa_seed} \
                 (raw profiled REML is sign-blind; scan is authoritative for the sign)"
            );
            cc.kappa = kappa_seed;
        }
    }

    // Baseline fit at the seeded geometry; its frozen spec is the starting point for κ search.
    let baseline_options = superseded_fit_options(options);
    let mut best = fit_term_collection_forspec(
        data,
        y.view(),
        weights.view(),
        offset.view(),
        &resolvedspec,
        family.clone(),
        &baseline_options,
    )?;
    resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
    // Freezing can change which terms carry a spatial length scale, so recompute the list.
    let mut spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
    sync_aniso_contrasts_from_metadata(&mut resolvedspec, &best.design.smooth);
    let mut prescan_improved = false;
    // Isotropic range pre-scan: any override it returns replaces the term's length scale,
    // followed by a refit and refreeze at the new geometry.
    if !spatial_terms.is_empty() {
        let baseline_score = fit_score(&best.fit);
        let range_overrides = prescan_isotropic_spatial_range_seed(
            data,
            y.view(),
            weights.view(),
            offset.view(),
            &resolvedspec,
            baseline_score,
            &family,
            &baseline_options,
            kappa_options,
            &spatial_terms,
        )?;
        if !range_overrides.is_empty() {
            prescan_improved = true;
            for (term_idx, length_scale) in range_overrides {
                set_spatial_length_scale(&mut resolvedspec, term_idx, length_scale)?;
            }
            best = fit_term_collection_forspec(
                data,
                y.view(),
                weights.view(),
                offset.view(),
                &resolvedspec,
                family.clone(),
                &baseline_options,
            )?;
            resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
            spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
        }
    }
    // No spatial terms left to optimize: refit with the caller's options, starting from the
    // baseline lambdas.
    if spatial_terms.is_empty() {
        let fitted = fit_term_collection_forspecwith_heuristic_lambdas(
            data,
            y.view(),
            weights.view(),
            offset.view(),
            &resolvedspec,
            best.fit.lambdas.as_slice(),
            family,
            options,
        )?;
        return Ok(FittedTermCollectionWithSpec {
            fit: fitted.fit,
            design: fitted.design,
            resolvedspec,
            adaptive_diagnostics: fitted.adaptive_diagnostics,
            kappa_timing: None,
        });
    }
    let initial_score = fit_score(&best.fit);
    if !initial_score.is_finite() {
        log::debug!("[spatial-kappa] initial profiled score is non-finite");
    }
    // Snapshot the pre-joint length scales so we can later detect a joint solve that left
    // the geometry where it started.
    let seed_length_scales: Vec<(usize, f64)> = spatial_terms
        .iter()
        .filter_map(|&t| get_spatial_length_scale(&resolvedspec, t).map(|ls| (t, ls)))
        .collect();
    let joint_result = try_exact_joint_spatial_length_scale_optimization(
        data,
        y.view(),
        weights.view(),
        offset.view(),
        &resolvedspec,
        &best,
        family.clone(),
        options,
        kappa_options,
        &spatial_terms,
    )
    .map(|opt| {
        opt.map(|fit| {
            let score = fit_score(&fit.fit);
            (fit, score)
        })
    });
    // If the pre-scan already improved the geometry, a failed or empty joint polish falls
    // back to the pre-scan fit instead of surfacing an error.
    let exact_joint = if prescan_improved && !matches!(joint_result, Ok(Some(_))) {
        let reason = match &joint_result {
            Err(e) => format!("error: {e}"),
            _ => "unavailable".to_string(),
        };
        log::info!(
            "[spatial-kappa] #1074 joint polish yielded no usable candidate \
             ({reason}); returning the multi-start pre-scan geometry (REML {initial_score:.5})"
        );
        FittedTermCollectionWithSpec {
            fit: best.fit,
            design: best.design,
            resolvedspec,
            adaptive_diagnostics: best.adaptive_diagnostics,
            kappa_timing: None,
        }
    } else {
        require_successful_spatial_optimization_result(initial_score, joint_result)?
    };

    // Multistart rescue: when the joint solve neither improved the score (by a relative
    // 1e-7 margin) nor moved the seed length scales (within a relative 1e-6), retry from
    // each ψ-window fraction. Anisotropic and constant-curvature setups are excluded.
    let exact_joint = {
        let primary_score = fit_score(&exact_joint.fit);
        let improved = primary_score.is_finite()
            && initial_score.is_finite()
            && primary_score < initial_score - 1e-7 * initial_score.abs().max(1.0);
        let base_spec = exact_joint.resolvedspec.clone();
        let geometry_unchanged = !seed_length_scales.is_empty()
            && seed_length_scales.iter().all(|&(t, seed_ls)| {
                get_spatial_length_scale(&base_spec, t)
                    .is_some_and(|ls| (ls - seed_ls).abs() <= 1e-6 * seed_ls.abs().max(1.0))
            });
        let eligible = !improved
            && geometry_unchanged
            && !has_aniso_terms(&base_spec, &spatial_terms)
            && constant_curvature_term_indices(&base_spec).is_empty()
            && spatial_terms
                .iter()
                .any(|&t| get_spatial_length_scale(&base_spec, t).is_some());
        if eligible {
            log::info!(
                "[spatial-kappa] #1688 joint solve stalled at REML {primary_score:.5} \
                 (no improvement over baseline {initial_score:.5}); running ψ-window \
                 multistart rescue across {} seeds",
                JOINT_RESTART_WINDOW_FRACTIONS.len(),
            );
            let mut best_fit = exact_joint;
            let mut best_score = primary_score;
            for &fraction in JOINT_RESTART_WINDOW_FRACTIONS.iter() {
                match joint_solve_from_window_fraction(
                    data,
                    y.view(),
                    weights.view(),
                    offset.view(),
                    &base_spec,
                    &spatial_terms,
                    fraction,
                    &family,
                    options,
                    &baseline_options,
                    kappa_options,
                ) {
                    // Accept a candidate only if it is finite and strictly better than the
                    // incumbent (or the incumbent itself is non-finite).
                    Ok(Some((candidate, score))) => {
                        if score.is_finite()
                            && (!best_score.is_finite()
                                || score < best_score - 1e-7 * best_score.abs().max(1.0))
                        {
                            log::info!(
                                "[spatial-kappa] #1688 multistart seed (ψ-window {fraction:.2}) \
                                 reached REML {score:.5}, improving on {best_score:.5}",
                            );
                            best_score = score;
                            best_fit = candidate;
                        }
                    }
                    Ok(None) => {}
                    // A failing seed is logged and skipped; it never aborts the rescue.
                    Err(e) => {
                        log::info!(
                            "[spatial-kappa] #1688 multistart seed (ψ-window {fraction:.2}) \
                             failed ({e}); skipping"
                        );
                    }
                }
            }
            best_fit
        } else {
            exact_joint
        }
    };

    log_spatial_aniso_scales(&exact_joint.resolvedspec);
    Ok(exact_joint)
}
8818
/// Inference summary for the curvature κ of a single constant-curvature smooth term.
///
/// NOTE(review): presumably produced by `curvature_inference_forspec`; the construction
/// site is not visible here, so confirm the field semantics there.
#[derive(Clone, Debug)]
pub struct CurvatureInference {
    /// Index of the smooth term this summary describes. Presumably an index into
    /// `TermCollectionSpec::smooth_terms`, as used by `curvature_inference_forspec`.
    pub term_idx: usize,
    /// Point estimate of κ for the term. Presumably the κ read from the resolved spec.
    pub kappa_hat: f64,
    /// Profile confidence interval for κ. Presumably the result of `profile_ci_walk` at the
    /// requested level.
    pub ci: gam_geometry::curvature_estimand::KappaProfileCi,
    /// Flatness test result for the term. Presumably a test of κ = 0 — TODO confirm.
    pub flatness: gam_geometry::curvature_estimand::FlatnessTest,
}
8838
8839pub fn curvature_inference_forspec(
8857 data: ArrayView2<'_, f64>,
8858 y: ArrayView1<'_, f64>,
8859 weights: ArrayView1<'_, f64>,
8860 offset: ArrayView1<'_, f64>,
8861 resolvedspec: &TermCollectionSpec,
8862 term_idx: usize,
8863 family: LikelihoodSpec,
8864 options: &FitOptions,
8865 level: f64,
8866) -> Result<CurvatureInference, EstimationError> {
8867 let kappa_hat = get_constant_curvature_kappa(resolvedspec, term_idx).ok_or_else(|| {
8868 EstimationError::InvalidInput(format!(
8869 "curvature_inference_forspec: term {term_idx} is not a constant-curvature smooth"
8870 ))
8871 })?;
8872 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
8873
8874 let cc_fair_inputs: Option<(Array2<f64>, gam_terms::basis::ConstantCurvatureBasisSpec)> =
8899 if kappa_hat < 0.0 {
8900 match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
8901 Some(SmoothBasisSpec::ConstantCurvature {
8902 feature_cols, spec, ..
8903 }) => select_columns(data, feature_cols)
8904 .ok()
8905 .map(|x| (x, spec.clone())),
8906 _ => None,
8907 }
8908 } else {
8909 None
8910 };
8911
8912 let v_p_cache: std::cell::RefCell<std::collections::HashMap<u64, f64>> =
8917 std::cell::RefCell::new(std::collections::HashMap::new());
8918 let v_p = |kappa: f64| -> Result<f64, String> {
8919 if !kappa.is_finite() {
8920 return Err(format!("V_p probed a non-finite κ = {kappa}"));
8921 }
8922 let key = kappa.to_bits();
8923 if let Some(&cached) = v_p_cache.borrow().get(&key) {
8924 return Ok(cached);
8925 }
8926 let score = if let Some((x_term, base_spec)) = &cc_fair_inputs {
8927 let mut probe_spec = base_spec.clone();
8928 probe_spec.kappa = kappa;
8929 gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec)
8930 .map_err(|e| format!("κ-fair criterion at κ={kappa} failed: {e}"))?
8931 } else {
8932 fixed_kappa_profiled_reml_score(
8933 data,
8934 y,
8935 weights,
8936 offset,
8937 resolvedspec,
8938 term_idx,
8939 kappa,
8940 family.clone(),
8941 options,
8942 )
8943 .map_err(|e| format!("V_p fixed-κ fit at κ={kappa} failed: {e}"))?
8944 };
8945 v_p_cache.borrow_mut().insert(key, score);
8946 Ok(score)
8947 };
8948
8949 let h = (1e-3 * (kappa_max - kappa_min)).max(1e-4);
8953 let v_pp = match (v_p(kappa_hat + h), v_p(kappa_hat), v_p(kappa_hat - h)) {
8954 (Ok(vp), Ok(v0), Ok(vm)) => (vp - 2.0 * v0 + vm) / (h * h),
8955 _ => f64::NAN, };
8957
8958 let ci = gam_geometry::curvature_estimand::profile_ci_walk(
8959 &v_p, kappa_hat, v_pp, kappa_min, kappa_max, level, 1e-4,
8960 )
8961 .map_err(EstimationError::InvalidInput)?;
8962 let flatness = gam_geometry::curvature_estimand::flatness_lr_test(&v_p, kappa_hat)
8963 .map_err(EstimationError::InvalidInput)?;
8964
8965 Ok(CurvatureInference {
8966 term_idx,
8967 kappa_hat,
8968 ci,
8969 flatness,
8970 })
8971}
8972
/// Small-sample correction that was applied to a smooth-term likelihood-ratio statistic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmoothLrCorrection {
    /// Lawley/Bartlett factor derived from the LR mean shift that also includes the
    /// variability of the estimated log smoothing parameters.
    LawleyLrEstimatedLambda,
    /// Lawley/Bartlett factor computed with the smoothing parameters held fixed at
    /// their fitted values.
    LawleyLrFixedLambda,
    /// No correction applied; the corrected statistic equals the raw one.
    None,
}

impl SmoothLrCorrection {
    /// Stable snake_case identifier for this correction, suitable for reports.
    pub fn label(self) -> &'static str {
        match self {
            Self::LawleyLrEstimatedLambda => "lawley_lr_estimated_lambda",
            Self::LawleyLrFixedLambda => "lawley_lr_fixed_lambda",
            Self::None => "none",
        }
    }
}
9002
/// Likelihood-ratio test of one smooth term: the full model versus the model refitted
/// with that term removed, optionally Bartlett-corrected.
#[derive(Clone, Debug)]
pub struct SmoothTermLrInference {
    /// Name of the smooth term under test.
    pub name: String,
    /// Position of the term in the fitted design's smooth-term list.
    pub term_idx: usize,
    /// Raw LR statistic `W = max(0, 2 (ℓ_full − ℓ_null))`; NaN when the null refit
    /// fails or its log-likelihood is non-finite.
    pub statistic_lr: f64,
    /// Degrees of freedom of the χ² reference distribution (always ≥ 1).
    pub ref_df: f64,
    /// Bartlett factor `c` actually applied to `W` (1.0 when no correction was applied).
    pub bartlett_factor: f64,
    /// The fixed-λ (conditional) factor, reported only when the estimated-λ factor
    /// superseded it.
    pub bartlett_factor_conditional: Option<f64>,
    /// Part of the LR mean shift attributed to smoothing-parameter variation, i.e. the
    /// total shift minus the fixed-λ shift; set only for the estimated-λ correction.
    pub rho_variation_shift: Option<f64>,
    /// `W / c`; equals `statistic_lr` when no correction was applied.
    pub statistic_corrected: f64,
    /// Upper-tail χ²(`ref_df`) probability of `statistic_lr`.
    pub p_value_uncorrected: f64,
    /// Upper-tail χ²(`ref_df`) probability of `statistic_corrected`.
    pub p_value_corrected: f64,
    /// Whether the correction moved either `|c − 1|` or the relative p-value by more
    /// than `SMOOTH_LR_MATERIAL_THRESHOLD`; always false without a correction.
    pub material: bool,
    /// Which correction, if any, produced `bartlett_factor`.
    pub correction: SmoothLrCorrection,
}

/// Threshold above which a Bartlett correction counts as "material". It applies to
/// both `|c − 1|` and the relative p-value change
/// `|p_corr − p_uncorr| / max(p_uncorr, p_corr)`.
pub const SMOOTH_LR_MATERIAL_THRESHOLD: f64 = 0.10;
9053
9054fn fitted_rho_penalty_components(
9060 penalties: &[BlockwisePenalty],
9061 lambdas: &[f64],
9062 p_total: usize,
9063) -> Result<Vec<gam_terms::inference::lawley::RhoPenaltyComponent>, EstimationError> {
9064 if penalties.len() != lambdas.len() {
9065 return Err(EstimationError::InvalidInput(format!(
9066 "smooth_term_lr_inference: penalty/lambda count mismatch ({} penalties, {} lambdas)",
9067 penalties.len(),
9068 lambdas.len()
9069 )));
9070 }
9071 let mut components = Vec::with_capacity(penalties.len());
9072 for (idx, (penalty, &lambda)) in penalties.iter().zip(lambdas.iter()).enumerate() {
9073 if !(lambda.is_finite() && lambda >= 0.0) {
9074 return Err(EstimationError::InvalidInput(format!(
9075 "smooth_term_lr_inference: lambda[{idx}] is invalid: {lambda}"
9076 )));
9077 }
9078 let r = &penalty.col_range;
9079 if r.end > p_total {
9080 return Err(EstimationError::InvalidInput(format!(
9081 "smooth_term_lr_inference: penalty[{idx}] range {:?} exceeds coefficient dimension {p_total}",
9082 r
9083 )));
9084 }
9085 let mut s_component = Array2::<f64>::zeros((p_total, p_total));
9086 s_component
9087 .slice_mut(s![r.start..r.end, r.start..r.end])
9088 .scaled_add(lambda, &penalty.local);
9089 components.push(gam_terms::inference::lawley::RhoPenaltyComponent { s_component });
9090 }
9091 Ok(components)
9092}
9093
9094pub fn smooth_term_lr_inference_forspec(
9139 data: ArrayView2<'_, f64>,
9140 y: ArrayView1<'_, f64>,
9141 weights: ArrayView1<'_, f64>,
9142 offset: ArrayView1<'_, f64>,
9143 resolvedspec: &TermCollectionSpec,
9144 family: LikelihoodSpec,
9145 options: &FitOptions,
9146) -> Result<Vec<SmoothTermLrInference>, EstimationError> {
9147 use gam_terms::inference::lawley::{
9148 LAWLEY_PAIR_MATRIX_MAX_ROWS, known_scale_expected_jets_with_dispersion,
9149 lawley_lr_bartlett_factor, lawley_lr_mean_shift_with_rho_variation,
9150 };
9151
9152 let n = data.nrows();
9153 let full = fit_term_collection_forspec(
9156 data,
9157 y,
9158 weights,
9159 offset,
9160 resolvedspec,
9161 family.clone(),
9162 options,
9163 )?;
9164 let ll_full = full.fit.log_likelihood;
9165 let p_total = full.design.design.ncols();
9166 let lambdas = full.fit.lambdas.as_slice().ok_or_else(|| {
9167 EstimationError::InvalidInput(
9168 "smooth_term_lr_inference: non-contiguous lambda vector".to_string(),
9169 )
9170 })?;
9171 let s_lambda = weighted_blockwise_penalty_sum(&full.design.penalties, lambdas, p_total);
9172 let rho_penalty_components =
9173 fitted_rho_penalty_components(&full.design.penalties, lambdas, p_total)?;
9174 let rho_covariance = full.fit.artifacts.rho_covariance.as_ref().filter(|cov| {
9175 cov.nrows() == rho_penalty_components.len() && cov.ncols() == rho_penalty_components.len()
9176 });
9177 let full_design_dense = full.design.design.to_dense();
9179 let influence = full.fit.coefficient_influence();
9180 let family_disp = lawley_dispersion_for_family(&family, &full.fit);
9181
9182 let mut penalty_cursor = full.design.random_effect_ranges.len();
9185 let mut out = Vec::<SmoothTermLrInference>::new();
9186 for (term_idx, design_term) in full.design.smooth.terms.iter().enumerate() {
9187 let k = design_term.penalties_local.len();
9188 let block_start = penalty_cursor;
9189 penalty_cursor += k;
9190 if design_term.shape != ShapeConstraint::None {
9193 continue;
9194 }
9195 let coeff_range = design_term.coeff_range.clone();
9196 if coeff_range.start >= coeff_range.end || coeff_range.end > p_total {
9197 continue;
9198 }
9199 let edf = full.fit.per_term_edf(coeff_range.clone(), block_start, k);
9211 let null_dim = design_term.wald_unpenalized_dim();
9231 let edf_floor = (null_dim.max(1)) as f64;
9283 let untrusted_edf_collapse = !full.fit.outer_converged && edf < edf_floor;
9284 let unconverged_dim_floor = if untrusted_edf_collapse {
9285 coeff_range.len() as f64
9286 } else {
9287 0.0
9288 };
9289 let ref_df = wood_reference_df(influence, &coeff_range)
9290 .unwrap_or(0.0)
9291 .max(edf)
9292 .max(null_dim as f64)
9293 .max(unconverged_dim_floor)
9294 .max(1.0);
9295 if !(ref_df.is_finite() && ref_df > 0.0) {
9296 continue;
9297 }
9298
9299 let mut null_spec = resolvedspec.clone();
9302 let Some(spec_pos) = null_spec
9303 .smooth_terms
9304 .iter()
9305 .position(|t| t.name == design_term.name)
9306 else {
9307 continue;
9308 };
9309 null_spec.smooth_terms.remove(spec_pos);
9310 let null_fit = fit_term_collection_forspec(
9311 data,
9312 y,
9313 weights,
9314 offset,
9315 &null_spec,
9316 family.clone(),
9317 options,
9318 );
9319 let (statistic_lr, eta_null) = match null_fit {
9320 Ok(null) if null.fit.log_likelihood.is_finite() => {
9321 let w = (2.0 * (ll_full - null.fit.log_likelihood)).max(0.0);
9322 let mut eta = null.design.design.dot(&null.fit.beta);
9326 eta += &offset;
9327 (w, Some(eta))
9328 }
9329 _ => (f64::NAN, None),
9330 };
9331
9332 let chi2 = statrs::distribution::ChiSquared::new(ref_df).ok();
9333 let p_uncorrected = match (chi2.as_ref(), statistic_lr.is_finite()) {
9334 (Some(dist), true) => {
9335 use statrs::distribution::ContinuousCDF;
9336 (1.0 - dist.cdf(statistic_lr)).clamp(0.0, 1.0)
9337 }
9338 _ => f64::NAN,
9339 };
9340
9341 let mut bartlett_factor = 1.0;
9345 let mut bartlett_factor_conditional = None;
9346 let mut rho_variation_shift = None;
9347 let mut statistic_corrected = statistic_lr;
9348 let mut p_corrected = p_uncorrected;
9349 let mut correction = SmoothLrCorrection::None;
9350 if let (Some(eta), true, true) = (
9351 eta_null.as_ref(),
9352 statistic_lr.is_finite(),
9353 n <= LAWLEY_PAIR_MATRIX_MAX_ROWS,
9354 ) {
9355 let kappas: Option<Vec<_>> = (0..n)
9356 .map(|i| {
9357 known_scale_expected_jets_with_dispersion(&family, eta[i], family_disp)
9358 .and_then(|jets| jets.kappas().ok())
9359 })
9360 .collect();
9361 if let (Some(kappas), Some(dist)) = (kappas, chi2.as_ref()) {
9362 let fixed_factor = lawley_lr_bartlett_factor(
9363 full_design_dense.view(),
9364 &kappas,
9365 Some(s_lambda.view()),
9366 coeff_range.clone(),
9367 ref_df,
9368 );
9369 if let Ok(c_cond) = fixed_factor
9370 && c_cond.is_finite()
9371 && c_cond > 0.0
9372 {
9373 let mut c_applied = c_cond;
9374 correction = SmoothLrCorrection::LawleyLrFixedLambda;
9375 if let Some(cov) = rho_covariance
9376 && let Ok(total_shift) = lawley_lr_mean_shift_with_rho_variation(
9377 full_design_dense.view(),
9378 &kappas,
9379 s_lambda.view(),
9380 coeff_range.clone(),
9381 &rho_penalty_components,
9382 cov.view(),
9383 )
9384 {
9385 let mean_w = ref_df + total_shift;
9386 if let Some(c_est) =
9387 gam_terms::inference::higher_order::bartlett_factor_from_mean(
9388 mean_w, ref_df,
9389 )
9390 && c_est.is_finite()
9391 && c_est > 0.0
9392 {
9393 let conditional_shift = (c_cond - 1.0) * ref_df;
9394 c_applied = c_est;
9395 bartlett_factor_conditional = Some(c_cond);
9396 rho_variation_shift = Some(total_shift - conditional_shift);
9397 correction = SmoothLrCorrection::LawleyLrEstimatedLambda;
9398 }
9399 }
9400 use statrs::distribution::ContinuousCDF;
9401 bartlett_factor = c_applied;
9402 statistic_corrected = statistic_lr / c_applied;
9403 p_corrected = (1.0 - dist.cdf(statistic_corrected)).clamp(0.0, 1.0);
9404 }
9405 }
9406 }
9407
9408 let material = match correction {
9414 SmoothLrCorrection::LawleyLrEstimatedLambda
9415 | SmoothLrCorrection::LawleyLrFixedLambda => {
9416 let factor_move = (bartlett_factor - 1.0).abs();
9417 let p_denom = p_uncorrected.max(p_corrected).max(f64::MIN_POSITIVE);
9418 let p_move = if p_uncorrected.is_finite() && p_corrected.is_finite() {
9419 (p_corrected - p_uncorrected).abs() / p_denom
9420 } else {
9421 0.0
9422 };
9423 factor_move > SMOOTH_LR_MATERIAL_THRESHOLD || p_move > SMOOTH_LR_MATERIAL_THRESHOLD
9424 }
9425 SmoothLrCorrection::None => false,
9426 };
9427
9428 out.push(SmoothTermLrInference {
9429 name: design_term.name.clone(),
9430 term_idx,
9431 statistic_lr,
9432 ref_df,
9433 bartlett_factor,
9434 bartlett_factor_conditional,
9435 rho_variation_shift,
9436 statistic_corrected,
9437 p_value_uncorrected: p_uncorrected,
9438 p_value_corrected: p_corrected,
9439 material,
9440 correction,
9441 });
9442 }
9443 Ok(out)
9444}
9445
9446fn lawley_dispersion_for_family(family: &LikelihoodSpec, fit: &UnifiedFitResult) -> f64 {
9449 match family.response {
9450 gam_spec::ResponseFamily::Gaussian => {
9451 let sd = fit.standard_deviation;
9452 (sd * sd).max(f64::MIN_POSITIVE)
9453 }
9454 gam_spec::ResponseFamily::Gamma => {
9455 let shape = fit.standard_deviation;
9456 if shape.is_finite() && shape > 0.0 {
9457 1.0 / shape
9458 } else {
9459 1.0
9460 }
9461 }
9462 _ => 1.0,
9463 }
9464}
9465
9466fn wood_reference_df(influence: Option<&Array2<f64>>, coeff_range: &Range<usize>) -> Option<f64> {
9490 let f = influence?;
9491 let (start, end) = (coeff_range.start, coeff_range.end);
9492 if start >= end || end > f.nrows() || end > f.ncols() {
9493 return None;
9494 }
9495 let block = f.slice(s![start..end, start..end]);
9496 let tr = (0..block.nrows()).map(|i| block[[i, i]]).sum::<f64>();
9497 let tr2 = block.dot(&block).diag().sum();
9498 (tr.is_finite() && tr2.is_finite() && tr > 0.0)
9499 .then(|| (2.0 * tr - tr2).max(tr).max(1e-12))
9500}