/// Computes the first and second psi (log-kappa) derivatives of one spatial
/// smooth term's design and penalty matrices, expressed in the realized term's
/// local coefficient block.
///
/// Handled bases: thin-plate, constant-curvature, Matérn and Duchon (each via
/// its own derivative builder). Every other basis kind yields `Ok(None)`, as
/// does a Duchon term whose realized metadata is not `BasisMetadata::Duchon`,
/// an out-of-range `term_idx`, or any produced matrix whose shape does not
/// match the term's coefficient block.
///
/// On success the tuple is, in order:
/// 0. the term's column range inside the full design matrix;
/// 1. `p_total`, the full design width;
/// 2. the first design derivative (local columns);
/// 3. the element-wise sum of the per-penalty first penalty derivatives;
/// 4. the second design derivative (local columns);
/// 5. the element-wise sum of the per-penalty second penalty derivatives;
/// 6. the individual first penalty derivatives;
/// 7. the individual second penalty derivatives;
/// 8. an optional implicit (operator-form) design derivative. When it is
///    `Some`, the dense design derivatives in 2 and 4 were neither rotated
///    nor width-checked here.
///
/// # Errors
/// Propagates failures from column selection and from the basis derivative
/// builders, and from appending the joint null rotation to the implicit operator.
///
/// # Panics
/// Panics if a per-order derivative result carries its own implicit operator
/// (only the combined bundle is expected to).
fn try_build_spatial_term_log_kappa_derivative(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    term_idx: usize,
) -> Result<
    Option<(
        Range<usize>,
        usize,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        Vec<Array2<f64>>,
        Vec<Array2<f64>>,
        Option<std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>>,
    )>,
    EstimationError,
> {
    let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
        return Ok(None);
    };
    let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
        return Ok(None);
    };

    // Build the raw (unrotated) derivative bundle for the term's basis kind.
    let derivative_bundle = match &termspec.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            // When the spec carries input scales, standardize the selected columns
            // and adjust the length scale to compensate for that rescaling.
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_length_scale_for_standardization(spec.length_scale, s);
            }
            build_thin_plate_basis_log_kappa_derivatives(x.view(), &spec_local)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Sphere { .. } => return Ok(None),
        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
            // No input standardization is applied for this basis.
            let x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            build_constant_curvature_basis_kappa_derivatives(x.view(), spec)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::MeasureJet { .. } => return Ok(None),
        SmoothBasisSpec::Matern {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_length_scale_for_standardization(spec.length_scale, s);
            }
            // The double penalty is switched off for the derivative build.
            // NOTE(review): rationale is not visible in this chunk — confirm.
            spec_local.double_penalty = false;
            build_matern_basis_log_kappa_derivatives(x.view(), &spec_local)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Duchon {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_optional_length_scale_for_standardization(spec.length_scale, s);
            }
            // Duchon derivatives are built against the realized term's centers,
            // identifiability transform and collocation points.
            let BasisMetadata::Duchon {
                centers,
                identifiability_transform,
                operator_collocation_points,
                radial_reparam,
                ..
            } = &smooth_term.metadata
            else {
                return Ok(None);
            };
            // Fall back to the realized term's radial reparameterization when the
            // spec does not carry one.
            if spec_local.radial_reparam.is_none() {
                spec_local.radial_reparam = radial_reparam.clone();
            }
            gam_terms::basis::build_duchon_basis_log_kappa_derivativeswith_collocationwithworkspace(
                x.view(),
                &spec_local,
                centers.view(),
                identifiability_transform.as_ref(),
                operator_collocation_points
                    .as_ref()
                    .map(|points| points.view()),
                &mut BasisWorkspace::default(),
            )
            .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::BSpline1D { .. }
        | SmoothBasisSpec::TensorBSpline { .. }
        | SmoothBasisSpec::ByVariable { .. }
        | SmoothBasisSpec::FactorSumToZero { .. }
        | SmoothBasisSpec::BySmooth { .. }
        | SmoothBasisSpec::FactorSmooth { .. }
        | SmoothBasisSpec::Pca { .. } => {
            return Ok(None);
        }
    };
    let mut implicit_operator = derivative_bundle.implicit_operator;
    let BasisPsiDerivativeResult {
        design_derivative: mut local_x_psi,
        penalties_derivative: mut local_s_psi,
        implicit_operator: local_implicit_first_unused,
    } = derivative_bundle.first;
    let BasisPsiSecondDerivativeResult {
        designsecond_derivative: mut local_x_psi_psi,
        penaltiessecond_derivative: mut local_s_psi_psi,
        implicit_operator: local_implicit_second_unused,
    } = derivative_bundle.second;
    // Only the combined bundle is expected to carry an implicit operator.
    assert!(local_implicit_first_unused.is_none());
    assert!(local_implicit_second_unused.is_none());

    // If the realized term applies a joint null-space rotation Q, map design
    // derivatives to X·Q (or append Q to the implicit operator) and penalty
    // derivatives to Qᵀ·S·Q. Any shape mismatch with Q yields Ok(None).
    if let Some(rotation) = smooth_term.joint_null_rotation.as_ref() {
        let q = &rotation.rotation;
        if let Some(op) = implicit_operator.take() {
            implicit_operator = Some(op.append_full_transform(q).map_err(EstimationError::from)?);
        } else {
            if local_x_psi.ncols() != q.nrows() || local_x_psi_psi.ncols() != q.nrows() {
                return Ok(None);
            }
            local_x_psi = fast_ab(&local_x_psi, q);
            local_x_psi_psi = fast_ab(&local_x_psi_psi, q);
        }
        // Returns None when the penalty is not square with Q's row count.
        let rotate_penalty = |s_local: Array2<f64>| -> Option<Array2<f64>> {
            if s_local.nrows() != q.nrows() || s_local.ncols() != q.nrows() {
                return None;
            }
            let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
            Some(gam_linalg::faer_ndarray::fast_ab(&qt_s, q))
        };
        let Some(rotated_s_psi) = local_s_psi
            .into_iter()
            .map(|s| rotate_penalty(s))
            .collect::<Option<Vec<_>>>()
        else {
            return Ok(None);
        };
        local_s_psi = rotated_s_psi;
        let Some(rotated_s_psi_psi) = local_s_psi_psi
            .into_iter()
            .map(|s| rotate_penalty(s))
            .collect::<Option<Vec<_>>>()
        else {
            return Ok(None);
        };
        local_s_psi_psi = rotated_s_psi_psi;
    }
    let implicit_operator = implicit_operator.map(std::sync::Arc::new);

    // Shape validation against the realized coefficient block. With an implicit
    // operator only its output width is checked; the dense design derivatives
    // are not inspected in that case.
    if let Some(ref op) = implicit_operator {
        if op.p_out() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
    } else {
        if local_x_psi.ncols() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
        if local_x_psi_psi.ncols() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
    }
    // First and second penalty derivative lists must be non-empty, paired and
    // square with the coefficient block.
    if local_s_psi.is_empty() || local_s_psi.len() != local_s_psi_psi.len() {
        return Ok(None);
    }
    if local_s_psi.iter().any(|s| {
        s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
    }) {
        return Ok(None);
    }
    if local_s_psi_psi.iter().any(|s| {
        s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
    }) {
        return Ok(None);
    }

    // Smooth columns occupy the trailing `total_smooth_cols()` columns of the
    // full design; offset the term's local range accordingly.
    let p_total = design.design.ncols();
    let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
    let global_range = (smooth_start + smooth_term.coeff_range.start)
        ..(smooth_start + smooth_term.coeff_range.end);

    Ok(Some((
        global_range,
        p_total,
        local_x_psi,
        // Summed first penalty derivative.
        local_s_psi.iter().fold(
            Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
            |acc, m| acc + m,
        ),
        local_x_psi_psi,
        // Summed second penalty derivative.
        local_s_psi_psi.iter().fold(
            Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
            |acc, m| acc + m,
        ),
        local_s_psi,
        local_s_psi_psi,
        implicit_operator,
    )))
}
242
243fn try_build_spatial_log_kappa_hyper_dirs(
244 data: ArrayView2<'_, f64>,
245 resolvedspec: &TermCollectionSpec,
246 design: &TermCollectionDesign,
247 spatial_terms: &[usize],
248) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
249 let Some(info_list) =
256 try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, spatial_terms)?
257 else {
258 return Ok(None);
259 };
260 Ok(Some(spatial_log_kappa_hyper_dirs_frominfo_list(info_list)?))
261}
262
263pub(crate) fn try_build_latent_coord_hyper_dirs(
264 latent: std::sync::Arc<gam_terms::latent::LatentCoordValues>,
265 resolvedspec: &TermCollectionSpec,
266 design: &TermCollectionDesign,
267 latent_terms: &[gam_problem::types::SmoothTermIdx],
268 analytic_rho_count: usize,
269) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
270 if latent_terms.is_empty() || latent.is_empty() {
271 return Ok(None);
272 }
273 if latent_terms.len() != 1 {
274 crate::bail_invalid_estim!(
275 "LatentCoord standard-fit hyper_dirs currently require exactly one latent smooth term"
276 .to_string(),
277 );
278 }
279 let term_idx = latent_terms[0];
280 let smooth_term = design.smooth.terms.get(term_idx.get()).ok_or_else(|| {
281 EstimationError::InvalidInput(format!(
282 "LatentCoord term index {term_idx} out of bounds for realized smooth design"
283 ))
284 })?;
285 let termspec = resolvedspec
286 .smooth_terms
287 .get(term_idx.get())
288 .ok_or_else(|| {
289 EstimationError::InvalidInput(format!(
290 "LatentCoord term index {term_idx} out of bounds for resolved smooth spec"
291 ))
292 })?;
293 let p_total = design.design.ncols();
294 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
295 let global_range = (smooth_start + smooth_term.coeff_range.start)
296 ..(smooth_start + smooth_term.coeff_range.end);
297
298 let operator = match (&termspec.basis, &smooth_term.metadata) {
303 (
304 SmoothBasisSpec::Matern { .. },
305 BasisMetadata::Matern {
306 centers,
307 length_scale,
308 nu,
309 include_intercept,
310 identifiability_transform,
311 ..
312 },
313 ) => gam_terms::basis::LatentCoordDesignDerivative::new_matern(
314 latent.clone(),
315 std::sync::Arc::new(centers.clone()),
316 *length_scale,
317 *nu,
318 *include_intercept,
319 identifiability_transform.clone(),
320 )
321 .map_err(EstimationError::from)?,
322 (
323 SmoothBasisSpec::Duchon { .. },
324 BasisMetadata::Duchon {
325 centers,
326 length_scale,
327 power,
328 nullspace_order,
329 identifiability_transform,
330 ..
331 },
332 ) => gam_terms::basis::LatentCoordDesignDerivative::new_duchon(
333 latent.clone(),
334 std::sync::Arc::new(centers.clone()),
335 *length_scale,
336 *power,
337 *nullspace_order,
338 identifiability_transform.clone(),
339 )
340 .map_err(EstimationError::from)?,
341 (
342 SmoothBasisSpec::Sphere { .. },
343 BasisMetadata::Sphere {
344 centers,
345 penalty_order,
346 method,
347 constraint_transform,
348 ..
349 },
350 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
351 gam_terms::basis::LatentCoordDesignDerivative::new_sphere(
352 latent.clone(),
353 std::sync::Arc::new(centers.clone()),
354 *penalty_order,
355 constraint_transform.clone(),
356 )
357 .map_err(EstimationError::from)?
358 }
359 (
360 SmoothBasisSpec::BSpline1D { spec, .. },
361 BasisMetadata::BSpline1D {
362 knots,
363 identifiability_transform,
364 periodic,
365 degree: meta_degree,
366 ..
367 },
368 ) => {
369 let effective_degree = meta_degree.unwrap_or(spec.degree);
373 if let Some((domain_start, period, num_basis)) = periodic {
374 gam_terms::basis::LatentCoordDesignDerivative::new_periodic_bspline(
375 latent.clone(),
376 (*domain_start, *domain_start + *period),
377 effective_degree,
378 *num_basis,
379 identifiability_transform.clone(),
380 )
381 .map_err(EstimationError::from)?
382 } else {
383 gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
384 latent.clone(),
385 vec![knots.clone()],
386 vec![effective_degree],
387 identifiability_transform.clone(),
388 )
389 .map_err(EstimationError::from)?
390 }
391 }
392 (
393 SmoothBasisSpec::TensorBSpline { .. },
394 BasisMetadata::TensorBSpline {
395 knots,
396 degrees,
397 identifiability_transform,
398 ..
399 },
400 ) => gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
401 latent.clone(),
402 knots.clone(),
403 degrees.clone(),
404 identifiability_transform.clone(),
405 )
406 .map_err(EstimationError::from)?,
407 (SmoothBasisSpec::Pca { .. }, BasisMetadata::Pca { basis_matrix, .. }) => {
408 gam_terms::basis::LatentCoordDesignDerivative::new_pca(
409 latent.clone(),
410 std::sync::Arc::new(basis_matrix.clone()),
411 )
412 .map_err(EstimationError::from)?
413 }
414 _ => return Ok(None),
415 };
416 if operator.p_out() != global_range.len() {
417 crate::bail_invalid_estim!(
418 "LatentCoord derivative width mismatch for term '{}': operator p={}, coeff range={}",
419 smooth_term.name,
420 operator.p_out(),
421 global_range.len()
422 );
423 }
424 let operator = std::sync::Arc::new(operator);
425 let mut hyper_dirs = Vec::with_capacity(operator.n_axes());
426 for flat_axis in 0..operator.n_axes() {
427 let dir = DirectionalHyperParam::new_compact(
428 gam_solve::estimate::reml::HyperDesignDerivative::from_latent_coord(
429 operator.clone(),
430 flat_axis,
431 global_range.clone(),
432 p_total,
433 ),
434 Vec::new(),
435 None,
436 None,
437 )?
438 .not_penalty_like();
439 hyper_dirs.push(dir);
440 }
441 let direct_dim = latent_coord_direct_hyper_count(latent.id_mode(), latent.latent_dim());
442 if analytic_rho_count + direct_dim > 0 {
443 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::from(Array2::<f64>::zeros((
444 design.design.nrows(),
445 p_total,
446 )));
447 for _ in 0..analytic_rho_count {
448 hyper_dirs.push(
449 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
450 .not_penalty_like(),
451 );
452 }
453 for _ in 0..direct_dim {
454 hyper_dirs.push(
455 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
456 .not_penalty_like(),
457 );
458 }
459 }
460 Ok(Some(hyper_dirs))
461}
462
463fn latent_coord_direct_hyper_count(
464 id_mode: &gam_terms::latent::LatentIdMode,
465 latent_dim: usize,
466) -> usize {
467 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
468 match id_mode {
469 LatentIdMode::AuxPrior { strength, .. } => match strength {
470 AuxPriorStrength::Auto => 1,
471 AuxPriorStrength::Fixed(_) => 0,
472 },
473 LatentIdMode::AuxPriorDimSelection { strength, .. } => {
474 latent_dim
475 + match strength {
476 AuxPriorStrength::Auto => 1,
477 AuxPriorStrength::Fixed(_) => 0,
478 }
479 }
480 LatentIdMode::DimSelection { .. } => latent_dim,
481 LatentIdMode::IsometryToReference { strength, .. } => match strength {
484 AuxPriorStrength::Auto => 1,
485 AuxPriorStrength::Fixed(_) => 0,
486 },
487 LatentIdMode::AuxOutcome { head, .. } => head.n_coeffs(latent_dim) + latent_dim,
490 LatentIdMode::None => 0,
491 }
492}
493
494fn latent_coord_initial_direct_hypers(
495 id_mode: &gam_terms::latent::LatentIdMode,
496 latent_dim: usize,
497) -> Result<Array1<f64>, EstimationError> {
498 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
499 let mut values = Vec::with_capacity(latent_coord_direct_hyper_count(id_mode, latent_dim));
500 match id_mode {
501 LatentIdMode::AuxPrior { strength, .. } => {
502 if matches!(strength, AuxPriorStrength::Auto) {
503 values.push(0.0);
504 }
505 }
506 LatentIdMode::AuxPriorDimSelection {
507 strength,
508 init_log_precision,
509 ..
510 } => {
511 if matches!(strength, AuxPriorStrength::Auto) {
512 values.push(0.0);
513 }
514 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
515 }
516 LatentIdMode::DimSelection { init_log_precision } => {
517 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
518 }
519 LatentIdMode::IsometryToReference { strength, .. } => {
520 if matches!(strength, AuxPriorStrength::Auto) {
521 values.push(0.0);
522 }
523 }
524 LatentIdMode::AuxOutcome {
525 head,
526 init_log_precision,
527 } => {
528 values.extend(std::iter::repeat_n(0.0, head.n_coeffs(latent_dim)));
532 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
533 }
534 LatentIdMode::None => {}
535 }
536 Ok(Array1::from_vec(values))
537}
538
539fn append_latent_ard_seed(
540 values: &mut Vec<f64>,
541 init: Option<&Array1<f64>>,
542 latent_dim: usize,
543) -> Result<(), EstimationError> {
544 if let Some(init) = init {
545 if init.len() != latent_dim {
546 crate::bail_invalid_estim!(
547 "latent dim_selection init_log_precision length mismatch: got {}, expected {}",
548 init.len(),
549 latent_dim
550 );
551 }
552 values.extend(init.iter().copied());
553 } else {
554 values.extend(std::iter::repeat_n(0.0, latent_dim));
555 }
556 Ok(())
557}
558
559struct LatentIdObjectiveContribution {
560 cost: f64,
561 gradient: Array1<f64>,
562}
563
564fn latent_id_objective_contribution(
565 theta: &Array1<f64>,
566 rho_dim: usize,
567 analytic_rho_count: usize,
568 latent: &gam_terms::latent::LatentCoordValues,
569) -> Result<LatentIdObjectiveContribution, EstimationError> {
570 use gam_terms::latent::{AuxPriorStrength, LatentIdMode, aux_prior_targets};
571 let n_obs = latent.n_obs();
572 let latent_dim = latent.latent_dim();
573 let flat_len = latent.len();
574 let mut gradient = Array1::<f64>::zeros(theta.len());
575 let t_start = rho_dim;
576 let direct_start = t_start + flat_len + analytic_rho_count;
577 if theta.len() < direct_start {
578 crate::bail_invalid_estim!(
579 "latent-coordinate theta too short for id objective: got {}, need at least {}",
580 theta.len(),
581 direct_start
582 );
583 }
584 let t = latent.as_matrix();
585 let mut cost = 0.0;
586 let mut cursor = direct_start;
587
588 match latent.id_mode() {
589 LatentIdMode::AuxPrior {
590 u,
591 family,
592 strength,
593 }
594 | LatentIdMode::AuxPriorDimSelection {
595 u,
596 family,
597 strength,
598 ..
599 } => {
600 let (log_mu, mu) = match strength {
601 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
602 AuxPriorStrength::Auto => {
603 let log_mu = theta[cursor];
604 cursor += 1;
605 (log_mu, log_mu.exp())
606 }
607 };
608 let targets = aux_prior_targets(t.view(), u.view(), *family)
609 .map_err(EstimationError::InvalidInput)?;
610 let residual = &t - &targets;
611 let q = residual.iter().map(|v| v * v).sum::<f64>();
612 let k = (n_obs * latent_dim) as f64;
619 cost += 0.5 * mu * q - 0.5 * k * log_mu;
620
621 let projected_residual = aux_prior_targets(residual.view(), u.view(), *family)
622 .map_err(EstimationError::InvalidInput)?;
623 let grad_base = residual - projected_residual;
624 for n in 0..n_obs {
625 for axis in 0..latent_dim {
626 gradient[t_start + n * latent_dim + axis] += mu * grad_base[[n, axis]];
627 }
628 }
629 if matches!(strength, AuxPriorStrength::Auto) {
630 gradient[direct_start] += 0.5 * mu * q - 0.5 * k;
631 }
632 }
633 LatentIdMode::IsometryToReference { reference, strength } => {
634 if reference.dim() != (n_obs, latent_dim) {
641 crate::bail_invalid_estim!(
642 "IsometryToReference reference shape {:?} must equal (n_obs, latent_dim) = ({}, {})",
643 reference.dim(),
644 n_obs,
645 latent_dim
646 );
647 }
648 let mu_slot = cursor;
649 let (log_mu, mu) = match strength {
650 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
651 AuxPriorStrength::Auto => {
652 let log_mu = theta[cursor];
653 cursor += 1;
654 (log_mu, log_mu.exp())
655 }
656 };
657 let residual = &t - reference;
658 let q = residual.iter().map(|v| v * v).sum::<f64>();
659 let k = (n_obs * latent_dim) as f64;
663 cost += 0.5 * mu * q - 0.5 * k * log_mu;
664 for n in 0..n_obs {
665 for axis in 0..latent_dim {
666 gradient[t_start + n * latent_dim + axis] += mu * residual[[n, axis]];
667 }
668 }
669 if matches!(strength, AuxPriorStrength::Auto) {
670 gradient[mu_slot] += 0.5 * mu * q - 0.5 * k;
671 }
672 }
673 LatentIdMode::AuxOutcome { head, .. } => {
674 let n_coeffs = head.n_coeffs(latent_dim);
682 let coeffs = theta
683 .slice(ndarray::s![cursor..cursor + n_coeffs])
684 .to_owned();
685 let (head_nll, grad_coeffs, grad_t) = head
686 .neg_loglik_and_grad(t.view(), coeffs.view())
687 .map_err(EstimationError::InvalidInput)?;
688 cost += head_nll;
689 for (offset, &g) in grad_coeffs.iter().enumerate() {
690 gradient[cursor + offset] += g;
691 }
692 for n in 0..n_obs {
693 for axis in 0..latent_dim {
694 gradient[t_start + n * latent_dim + axis] += grad_t[[n, axis]];
695 }
696 }
697 cursor += n_coeffs;
698 }
699 LatentIdMode::DimSelection { .. } | LatentIdMode::None => {}
700 }
701
702 match latent.id_mode() {
703 LatentIdMode::AuxPriorDimSelection { .. }
704 | LatentIdMode::DimSelection { .. }
705 | LatentIdMode::AuxOutcome { .. } => {
706 for axis in 0..latent_dim {
707 let log_alpha = theta[cursor + axis];
708 let alpha = log_alpha.exp();
709 let mut q_axis = 0.0;
710 for n in 0..n_obs {
711 let flat_idx = n * latent_dim + axis;
712 let value = latent.as_flat()[flat_idx];
713 q_axis += value * value;
714 gradient[t_start + flat_idx] += alpha * value;
715 }
716 cost += 0.5 * alpha * q_axis - 0.5 * n_obs as f64 * log_alpha;
717 gradient[cursor + axis] += 0.5 * alpha * q_axis - 0.5 * n_obs as f64;
718 }
719 cursor += latent_dim;
720 }
721 LatentIdMode::AuxPrior { .. }
722 | LatentIdMode::IsometryToReference { .. }
723 | LatentIdMode::None => {}
724 }
725
726 if cursor != theta.len() {
727 crate::bail_invalid_estim!(
728 "latent-coordinate direct hyperparameter length mismatch: consumed {}, theta len {}",
729 cursor,
730 theta.len()
731 );
732 }
733 Ok(LatentIdObjectiveContribution { cost, gradient })
734}
735
736fn add_latent_id_objective_to_eval(
737 theta: &Array1<f64>,
738 rho_dim: usize,
739 analytic_rho_count: usize,
740 latent: &gam_terms::latent::LatentCoordValues,
741 eval: &mut (
742 f64,
743 Array1<f64>,
744 gam_problem::HessianResult,
745 ),
746) -> Result<(), EstimationError> {
747 let contribution =
748 latent_id_objective_contribution(theta, rho_dim, analytic_rho_count, latent)?;
749 eval.0 += contribution.cost;
750 if eval.1.len() != contribution.gradient.len() {
751 crate::bail_invalid_estim!(
752 "latent-coordinate REML gradient length mismatch: base={}, id={}",
753 eval.1.len(),
754 contribution.gradient.len()
755 );
756 }
757 eval.1 += &contribution.gradient;
758 if eval.2.is_analytic() {
759 eval.2 = gam_problem::HessianResult::Unavailable;
760 }
761 Ok(())
762}
763
764fn analytic_penalty_objective_contribution(
765 theta: &Array1<f64>,
766 rho_dim: usize,
767 latent: &gam_terms::latent::LatentCoordValues,
768 registry: &gam_terms::AnalyticPenaltyRegistry,
769) -> Result<LatentIdObjectiveContribution, EstimationError> {
770 let flat_len = latent.len();
771 let t_start = rho_dim;
772 let t_end = t_start + flat_len;
773 let rho_start = t_end;
774 let rho_end = rho_start + registry.total_rho_count();
775 if theta.len() < rho_end {
776 crate::bail_invalid_estim!(
777 "latent-coordinate theta too short for analytic penalties: got {}, need at least {}",
778 theta.len(),
779 rho_end
780 );
781 }
782 let target_t = theta.slice(s![t_start..t_end]);
783 let rho = theta.slice(s![rho_start..rho_end]);
784 let mut cost = 0.0_f64;
785 let mut gradient = Array1::<f64>::zeros(theta.len());
786 for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(registry.rho_layout()) {
787 let rho_local = rho.slice(s![rho_slice.clone()]);
788 match tier {
789 gam_terms::PenaltyTier::Psi => {
790 cost += penalty.value(target_t.view(), rho_local);
791 let grad = penalty.grad_target(target_t.view(), rho_local);
792 if grad.len() != flat_len {
793 crate::bail_invalid_estim!(
794 "analytic penalty {name:?} gradient length mismatch: got {}, expected {}",
795 grad.len(),
796 flat_len
797 );
798 }
799 for i in 0..flat_len {
800 gradient[t_start + i] += grad[i];
801 }
802 let grad_rho_local = penalty.grad_rho(target_t.view(), rho_local);
803 if grad_rho_local.len() != rho_slice.len() {
804 crate::bail_invalid_estim!(
805 "analytic penalty {name:?} rho-gradient length mismatch: got {}, expected {}",
806 grad_rho_local.len(),
807 rho_slice.len()
808 );
809 }
810 for local_idx in 0..grad_rho_local.len() {
811 gradient[rho_start + rho_slice.start + local_idx] += grad_rho_local[local_idx];
812 }
813 }
814 gam_terms::PenaltyTier::Beta => {}
815 gam_terms::PenaltyTier::Rho => {}
816 }
817 }
818 Ok(LatentIdObjectiveContribution { cost, gradient })
819}
820
821fn add_analytic_penalty_hessian_to_eval(
822 theta: &Array1<f64>,
823 rho_dim: usize,
824 latent: &gam_terms::latent::LatentCoordValues,
825 registry: &gam_terms::AnalyticPenaltyRegistry,
826 eval: &mut (
827 f64,
828 Array1<f64>,
829 gam_problem::HessianResult,
830 ),
831) -> Result<(), EstimationError> {
832 let flat_len = latent.len();
833 let t_start = rho_dim;
834 let t_end = t_start + flat_len;
835 let rho_start = t_end;
836 let rho_end = rho_start + registry.total_rho_count();
837 if theta.len() < rho_end {
838 crate::bail_invalid_estim!(
839 "latent-coordinate theta too short for analytic penalty Hessian: got {}, need at least {}",
840 theta.len(),
841 rho_end
842 );
843 }
844 let gam_problem::HessianResult::Analytic(hessian) = &mut eval.2 else {
845 if eval.2.is_analytic() {
846 eval.2 = gam_problem::HessianResult::Unavailable;
847 }
848 return Ok(());
849 };
850 if hessian.dim() != (theta.len(), theta.len()) {
851 crate::bail_invalid_estim!(
852 "analytic penalty Hessian target shape mismatch: got {}x{}, expected {}x{}",
853 hessian.nrows(),
854 hessian.ncols(),
855 theta.len(),
856 theta.len()
857 );
858 }
859 let target_t = theta.slice(s![t_start..t_end]);
860 let rho = theta.slice(s![rho_start..rho_end]);
861 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(registry.rho_layout())
862 {
863 let rho_local = rho.slice(s![rho_slice]);
864 if !matches!(tier, gam_terms::PenaltyTier::Psi) {
865 continue;
866 }
867 if let Some(diag) = penalty.hessian_diag(target_t.view(), rho_local) {
868 if diag.len() != flat_len {
869 crate::bail_invalid_estim!(
870 "analytic penalty Hessian diagonal length mismatch: got {}, expected {}",
871 diag.len(),
872 flat_len
873 );
874 }
875 for i in 0..flat_len {
876 hessian[[t_start + i, t_start + i]] += diag[i];
877 }
878 continue;
879 }
880 let mut probe = Array1::<f64>::zeros(flat_len);
881 for col in 0..flat_len {
882 probe[col] = 1.0;
883 let hv = penalty.hvp(target_t.view(), rho_local, probe.view());
884 if hv.len() != flat_len {
885 crate::bail_invalid_estim!(
886 "analytic penalty Hessian-vector length mismatch: got {}, expected {}",
887 hv.len(),
888 flat_len
889 );
890 }
891 for row in 0..flat_len {
892 hessian[[t_start + row, t_start + col]] += hv[row];
893 }
894 probe[col] = 0.0;
895 }
896 }
897 Ok(())
898}
899
900fn add_analytic_penalty_objective_to_eval(
901 theta: &Array1<f64>,
902 rho_dim: usize,
903 latent: &gam_terms::latent::LatentCoordValues,
904 registry: &gam_terms::AnalyticPenaltyRegistry,
905 eval: &mut (
906 f64,
907 Array1<f64>,
908 gam_problem::HessianResult,
909 ),
910) -> Result<(), EstimationError> {
911 let contribution = analytic_penalty_objective_contribution(theta, rho_dim, latent, registry)?;
912 eval.0 += contribution.cost;
913 if eval.1.len() != contribution.gradient.len() {
914 crate::bail_invalid_estim!(
915 "latent-coordinate REML gradient length mismatch: base={}, analytic_penalty={}",
916 eval.1.len(),
917 contribution.gradient.len()
918 );
919 }
920 eval.1 += &contribution.gradient;
921 add_analytic_penalty_hessian_to_eval(theta, rho_dim, latent, registry, eval)?;
922 Ok(())
923}
924
925fn spatial_log_kappa_hyper_dirs_frominfo_list(
926 info_list: Vec<SpatialPsiDerivative>,
927) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
928 use gam_solve::estimate::reml::ImplicitDerivLevel;
929 use std::collections::HashMap;
930
931 let log_kappa_dim = info_list.len();
932 let group_ids: Vec<Option<usize>> = info_list.iter().map(|e| e.aniso_group_id).collect();
938 let mut group_indices_map: HashMap<usize, Vec<usize>> = HashMap::new();
939 for (idx, gid) in group_ids.iter().enumerate() {
940 if let Some(g) = gid {
941 group_indices_map.entry(*g).or_default().push(idx);
942 }
943 }
944
945 let mut hyper_dirs = Vec::with_capacity(log_kappa_dim);
946 for (i, info) in info_list.into_iter().enumerate() {
947 let SpatialPsiDerivative {
948 penalty_index: _,
949 penalty_indices,
950 global_range,
951 total_p,
952 x_psi_local,
953 s_psi_components_local,
954 x_psi_psi_local,
955 s_psi_psi_components_local,
956 aniso_group_id,
957 aniso_cross_designs,
958 aniso_cross_penalty_provider,
959 implicit_operator,
960 implicit_axis,
961 } = info;
962
963 let mut xsecond = vec![None; log_kappa_dim];
964 xsecond[i] = Some(if let Some(ref op) = implicit_operator {
966 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
967 op.clone(),
968 ImplicitDerivLevel::SecondDiag(implicit_axis),
969 global_range.clone(),
970 total_p,
971 )
972 } else {
973 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
974 x_psi_psi_local,
975 global_range.clone(),
976 total_p,
977 )
978 });
979 if let Some(cross_designs) = aniso_cross_designs {
981 if let Some(gid) = aniso_group_id {
985 let base = group_indices_map
986 .get(&gid)
987 .and_then(|v| v.first().copied())
988 .unwrap_or(i);
989 for (b_axis, cross_mat) in cross_designs.into_iter() {
990 let j = base + b_axis;
991 if j < log_kappa_dim {
992 xsecond[j] = Some(if let Some(ref op) = implicit_operator {
993 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
994 op.clone(),
995 ImplicitDerivLevel::SecondCross(implicit_axis, b_axis),
996 global_range.clone(),
997 total_p,
998 )
999 } else {
1000 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
1001 cross_mat,
1002 global_range.clone(),
1003 total_p,
1004 )
1005 });
1006 }
1007 }
1008 }
1009 }
1010 let s_components = penalty_indices
1011 .iter()
1012 .copied()
1013 .zip(s_psi_components_local.into_iter().map(|local| {
1014 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1015 local,
1016 global_range.clone(),
1017 total_p,
1018 )
1019 }))
1020 .collect::<Vec<_>>();
1021 let s2_components = penalty_indices
1022 .iter()
1023 .copied()
1024 .zip(s_psi_psi_components_local.into_iter().map(|local| {
1025 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1026 local,
1027 global_range.clone(),
1028 total_p,
1029 )
1030 }))
1031 .collect::<Vec<_>>();
1032 let mut ssecond_components = vec![None; log_kappa_dim];
1033 ssecond_components[i] = Some(s2_components);
1034 let mut penaltysecond_partner_indices: Option<Vec<usize>> = None;
1035 let penaltysecond_component_provider =
1036 if let (Some(provider), Some(gid)) = (aniso_cross_penalty_provider, aniso_group_id) {
1037 let group_indices = group_indices_map.get(&gid).cloned().unwrap_or_default();
1038 let axis_in_group =
1039 group_indices
1040 .iter()
1041 .position(|&idx| idx == i)
1042 .ok_or_else(|| {
1043 EstimationError::InvalidInput(format!(
1044 "missing spatial hyper axis {} in anisotropy group {}",
1045 i, gid
1046 ))
1047 })?;
1048 penaltysecond_partner_indices = Some(
1049 group_indices
1050 .iter()
1051 .copied()
1052 .filter(|&idx| idx != i)
1053 .collect(),
1054 );
1055 let penalty_indices_inner = penalty_indices.clone();
1056 let global_range_inner = global_range.clone();
1057 let total_p_inner = total_p;
1058 let group_indices_inner = group_indices;
1059 Some(std::sync::Arc::new(
1060 move |j: usize| -> Result<
1061 Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
1062 EstimationError,
1063 > {
1064 let Some(other_axis_in_group) =
1065 group_indices_inner.iter().position(|&idx| idx == j)
1066 else {
1067 return Ok(None);
1068 };
1069 if other_axis_in_group == axis_in_group {
1070 return Ok(None);
1071 }
1072 let cross_pens = provider(other_axis_in_group)?;
1073 if cross_pens.is_empty() {
1074 return Ok(None);
1075 }
1076 Ok(Some(
1077 penalty_indices_inner
1078 .iter()
1079 .copied()
1080 .zip(cross_pens.into_iter().map(|local| {
1081 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1082 local,
1083 global_range_inner.clone(),
1084 total_p_inner,
1085 )
1086 }))
1087 .map(|(penalty_index, matrix)| {
1088 gam_solve::estimate::reml::PenaltyDerivativeComponent {
1089 penalty_index,
1090 matrix,
1091 }
1092 })
1093 .collect(),
1094 ))
1095 },
1096 )
1097 as std::sync::Arc<
1098 dyn Fn(
1099 usize,
1100 ) -> Result<
1101 Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
1102 EstimationError,
1103 > + Send
1104 + Sync
1105 + 'static,
1106 >)
1107 } else {
1108 None
1109 };
1110 let x_first_hyper = if let Some(ref op) = implicit_operator {
1113 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
1114 op.clone(),
1115 ImplicitDerivLevel::First(implicit_axis),
1116 global_range.clone(),
1117 total_p,
1118 )
1119 } else {
1120 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
1121 x_psi_local,
1122 global_range.clone(),
1123 total_p,
1124 )
1125 };
1126 let mut dir = DirectionalHyperParam::new_compact(
1127 x_first_hyper,
1128 s_components,
1129 Some(xsecond),
1130 Some(ssecond_components),
1131 )?
1132 .not_penalty_like();
1133 if let Some(provider) = penaltysecond_component_provider {
1134 dir = dir.with_penaltysecond_component_provider(provider);
1135 }
1136 if let Some(partner_indices) = penaltysecond_partner_indices {
1137 dir = dir.with_penaltysecond_partner_indices(partner_indices);
1138 }
1139 hyper_dirs.push(dir);
1140 }
1141 Ok(hyper_dirs)
1142}
1143
1144pub(crate) fn spatial_dims_per_term(
1150 resolvedspec: &TermCollectionSpec,
1151 spatial_terms: &[usize],
1152) -> Vec<usize> {
1153 spatial_terms
1154 .iter()
1155 .map(|&term_idx| {
1156 if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
1157 measure_jet_psi_dim(mj)
1160 } else if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
1161 get_spatial_feature_dim(resolvedspec, term_idx).unwrap_or(1)
1162 } else {
1163 1
1164 }
1165 })
1166 .collect()
1167}
1168
1169fn has_aniso_terms(resolvedspec: &TermCollectionSpec, spatial_terms: &[usize]) -> bool {
1173 spatial_terms
1174 .iter()
1175 .any(|&term_idx| spatial_term_uses_per_axis_psi(resolvedspec, term_idx))
1176}
1177
macro_rules! impl_exact_joint_theta_memo {
    // Expands to theta-keyed memo helpers for a type that has `current_theta`,
    // `last_cost` and `last_eval` fields.
    () => {
        /// Cached cost for `theta`, or `None` unless `theta` matches `current_theta`.
        ///
        /// The cost embedded in `last_eval` takes precedence over `last_cost`.
        fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
            let theta_is_current = self
                .current_theta
                .as_ref()
                .is_some_and(|stored| theta_values_match(stored, theta));
            if !theta_is_current {
                return None;
            }
            match self.last_eval.as_ref() {
                Some((cost, _, _)) => Some(*cost),
                None => self.last_cost,
            }
        }

        /// Cloned full evaluation for `theta`, or `None` unless `theta` matches
        /// `current_theta`.
        fn memoized_eval(
            &self,
            theta: &Array1<f64>,
        ) -> Option<(
            f64,
            Array1<f64>,
            gam_problem::HessianResult,
        )> {
            let theta_is_current = self
                .current_theta
                .as_ref()
                .is_some_and(|stored| theta_values_match(stored, theta));
            if !theta_is_current {
                return None;
            }
            self.last_eval.clone()
        }

        /// Record a full evaluation; its first element also becomes `last_cost`.
        fn store_eval(
            &mut self,
            eval: (
                f64,
                Array1<f64>,
                gam_problem::HessianResult,
            ),
        ) {
            let cost = eval.0;
            self.last_cost = Some(cost);
            self.last_eval = Some(eval);
        }
    };
}
1232
/// Cache around a frozen incremental realizer.
///
/// It holds the design realized at the most recently applied `theta`, the most
/// recent cost/evaluation (keyed by the theta it was stored at), and the spatial
/// hyper-parameter directions for the current design revision.
struct SingleBlockExactJointDesignCache<'d> {
    /// Holds the current spec/design; its log-κ state is updated by `ensure_theta`.
    realizer: FrozenTermCollectionIncrementalRealizer<'d>,
    /// Theta whose tail was last applied to `realizer`.
    current_theta: Option<Array1<f64>>,
    /// Theta at which `last_cost` / `last_eval` were stored.
    last_eval_theta: Option<Array1<f64>>,
    /// Most recently stored cost.
    last_cost: Option<f64>,
    /// Most recently stored full evaluation. The first element is the cost; the
    /// remaining elements are presumably gradient and Hessian (TODO confirm).
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// `(design revision, directions)` built for that revision.
    cached_hyper_dirs: Option<(u64, Vec<DirectionalHyperParam>)>,
    /// Smooth-term indices whose log-κ / psi coordinates live in theta's tail.
    spatial_terms: Vec<usize>,
    /// Number of leading theta entries skipped before the spatial tail.
    rho_dim: usize,
    /// Per-term coordinate counts, aligned with `spatial_terms`.
    dims_per_term: Vec<usize>,
}
1265
1266impl<'d> SingleBlockExactJointDesignCache<'d> {
1267 fn new(
1268 data: ArrayView2<'d, f64>,
1269 spec: TermCollectionSpec,
1270 design: TermCollectionDesign,
1271 spatial_terms: Vec<usize>,
1272 rho_dim: usize,
1273 dims_per_term: Vec<usize>,
1274 ) -> Result<Self, String> {
1275 Ok(Self {
1276 realizer: FrozenTermCollectionIncrementalRealizer::new(data, spec, design)?,
1277 current_theta: None,
1278 last_eval_theta: None,
1279 last_cost: None,
1280 last_eval: None,
1281 cached_hyper_dirs: None,
1282 spatial_terms,
1283 rho_dim,
1284 dims_per_term,
1285 })
1286 }
1287
1288 fn design_revision(&self) -> u64 {
1289 self.realizer.design_revision()
1290 }
1291
1292 fn hyper_dirs_for_current_design(
1302 &mut self,
1303 data: ArrayView2<'_, f64>,
1304 kind: SpatialHyperKind,
1305 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1306 let revision = self.realizer.design_revision();
1307 if let Some((cached_rev, dirs)) = self.cached_hyper_dirs.as_ref()
1308 && *cached_rev == revision
1309 {
1310 return Ok(dirs.clone());
1311 }
1312 let dirs = try_build_spatial_log_kappa_hyper_dirs(
1313 data,
1314 self.realizer.spec(),
1315 self.realizer.design(),
1316 &self.spatial_terms,
1317 )?
1318 .ok_or_else(|| {
1319 EstimationError::InvalidInput(format!(
1320 "failed to build {} hyper_dirs at current {}",
1321 kind.adjective(),
1322 kind.coord_name(),
1323 ))
1324 })?;
1325 self.cached_hyper_dirs = Some((revision, dirs.clone()));
1326 Ok(dirs)
1327 }
1328
1329 fn nfree_tensor_gradient_hyper_dirs(
1330 &mut self,
1331 theta: &Array1<f64>,
1332 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1333 let psi = &theta.as_slice().ok_or_else(|| {
1334 EstimationError::InvalidInput(
1335 "nfree_tensor_gradient_hyper_dirs: theta is not contiguous".to_string(),
1336 )
1337 })?[self.rho_dim..];
1338 let (global_range, p_total, s_psi_components) = self
1339 .realizer
1340 .canonical_penalty_derivatives_at_psi(&self.spatial_terms, psi)
1341 .map_err(EstimationError::InvalidInput)?;
1342 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::zero(
1343 self.realizer.design().design.nrows(),
1344 p_total,
1345 );
1346 let components = s_psi_components
1347 .into_iter()
1348 .enumerate()
1349 .map(|(penalty_index, local)| {
1350 (
1351 penalty_index,
1352 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1353 local,
1354 global_range.clone(),
1355 p_total,
1356 ),
1357 )
1358 })
1359 .collect::<Vec<_>>();
1360 Ok(DirectionalHyperParam::new_compact(zero_x, components, None, None)?.not_penalty_like())
1361 .map(|dir| vec![dir])
1362 }
1363
1364 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1365 if self
1366 .current_theta
1367 .as_ref()
1368 .is_some_and(|cached| theta_values_match(cached, theta))
1369 {
1370 return Ok(());
1371 }
1372 let t_ensure = std::time::Instant::now();
1373 let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
1374 theta,
1375 self.rho_dim,
1376 self.dims_per_term.clone(),
1377 );
1378 self.realizer
1379 .apply_log_kappa(&log_kappa, &self.spatial_terms)?;
1380 log::info!(
1381 "[STAGE] ensure_theta (apply_log_kappa, {} terms): {:.3}s",
1382 self.spatial_terms.len(),
1383 t_ensure.elapsed().as_secs_f64(),
1384 );
1385 self.current_theta = Some(theta.clone());
1386 self.last_eval_theta = None;
1387 self.last_cost = None;
1388 self.last_eval = None;
1389 Ok(())
1390 }
1391
1392 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1399 if self
1400 .last_eval_theta
1401 .as_ref()
1402 .is_some_and(|cached| theta_values_match(cached, theta))
1403 {
1404 self.last_eval
1405 .as_ref()
1406 .map(|cached| cached.0)
1407 .or(self.last_cost)
1408 } else {
1409 None
1410 }
1411 }
1412
1413 fn memoized_eval(
1414 &self,
1415 theta: &Array1<f64>,
1416 ) -> Option<(
1417 f64,
1418 Array1<f64>,
1419 gam_problem::HessianResult,
1420 )> {
1421 if self
1422 .last_eval_theta
1423 .as_ref()
1424 .is_some_and(|cached| theta_values_match(cached, theta))
1425 {
1426 self.last_eval.clone()
1427 } else {
1428 None
1429 }
1430 }
1431
1432 fn store_eval_at(
1436 &mut self,
1437 theta: &Array1<f64>,
1438 eval: (
1439 f64,
1440 Array1<f64>,
1441 gam_problem::HessianResult,
1442 ),
1443 ) {
1444 self.last_eval_theta = Some(theta.clone());
1445 self.last_cost = Some(eval.0);
1446 self.last_eval = Some(eval);
1447 }
1448
1449 fn store_cost_at(&mut self, theta: &Array1<f64>, cost: f64) {
1452 self.last_eval_theta = Some(theta.clone());
1453 self.last_cost = Some(cost);
1454 self.last_eval = None;
1458 }
1459
1460 fn spec(&self) -> &TermCollectionSpec {
1461 self.realizer.spec()
1462 }
1463
1464 fn design(&self) -> &TermCollectionDesign {
1465 self.realizer.design()
1466 }
1467
1468 fn supports_nfree_penalty_rekey(&self) -> bool {
1474 self.realizer
1475 .supports_nfree_penalty_rekey(&self.spatial_terms)
1476 }
1477
1478 fn supports_nfree_gradient_only_routing(&self) -> bool {
1479 self.realizer
1480 .supports_nfree_gradient_only_routing(&self.spatial_terms)
1481 }
1482
1483 fn canonical_penalties_at(
1493 &mut self,
1494 theta: &Array1<f64>,
1495 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
1496 let psi = &theta
1497 .as_slice()
1498 .ok_or_else(|| "canonical_penalties_at: theta is not contiguous".to_string())?
1499 [self.rho_dim..];
1500 self.realizer
1501 .canonical_penalties_at_psi(&self.spatial_terms, psi)
1502 }
1503}
1504
/// Theta-keyed cache for a single smooth term whose input columns are latent
/// coordinates that are themselves part of the optimised `theta`.
///
/// The cache holds the design and hyper-parameter directions rebuilt for the
/// current latent values, plus the most recent cost/evaluation, which is stamped
/// with the outer iteration it was stored in.
struct SingleBlockLatentCoordDesignCache {
    /// Owned data matrix; `ensure_theta` overwrites the `feature_cols` columns
    /// with the current latent values.
    data: Array2<f64>,
    spec: TermCollectionSpec,
    /// Design realized for `current_theta`.
    design: TermCollectionDesign,
    /// Theta last realized by `ensure_theta`.
    current_theta: Option<Array1<f64>>,
    /// Latent values decoded from `current_theta`.
    current_latent: Option<std::sync::Arc<gam_terms::latent::LatentCoordValues>>,
    /// Hyper-parameter directions built alongside `design`.
    current_hyper_dirs: Option<Vec<gam_solve::estimate::reml::DirectionalHyperParam>>,
    /// Entry id in `latent_design_cache` that `design` came from.
    current_design_cache_id: Option<u64>,
    latent_design_cache: gam_solve::latent_cache::LatentDesignCache,
    last_cost: Option<f64>,
    /// Most recent full evaluation. The first element is the cost; the rest are
    /// presumably gradient and Hessian (TODO confirm).
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Smooth term whose inputs are latent.
    term_index: gam_problem::types::SmoothTermIdx,
    /// Data columns that receive the latent coordinates (one per latent axis).
    feature_cols: Vec<usize>,
    /// Number of leading theta entries before the flattened latent block.
    rho_dim: usize,
    n_obs: usize,
    latent_dim: usize,
    id_mode: gam_terms::latent::LatentIdMode,
    manifold: gam_terms::latent::LatentManifold,
    retraction_registry: gam_solve::latent_cache::LatentRetractionRegistry,
    latent_id: u64,
    analytic_penalties: Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>>,
    /// Total rho count of `analytic_penalties` (0 when absent).
    analytic_rho_count: usize,
    /// Counter bumped by `ensure_theta` whenever the realized design may have changed.
    design_revision: u64,
    /// Outer iteration in which `last_cost` / `last_eval` were stored.
    last_outer_iter: Option<u64>,
}
1537
1538impl SingleBlockLatentCoordDesignCache {
1539 fn new(
1540 data: Array2<f64>,
1541 spec: TermCollectionSpec,
1542 design: TermCollectionDesign,
1543 latent: &StandardLatentCoordConfig,
1544 rho_dim: usize,
1545 ) -> Result<Self, String> {
1546 if latent.term_index.get() >= spec.smooth_terms.len() {
1547 return Err(SmoothError::dimension_mismatch(format!(
1548 "latent-coordinate term index {} out of bounds for {} smooth terms",
1549 latent.term_index,
1550 spec.smooth_terms.len()
1551 ))
1552 .into());
1553 }
1554 if latent.feature_cols.len() != latent.values.latent_dim() {
1555 return Err(SmoothError::dimension_mismatch(format!(
1556 "latent-coordinate feature width mismatch: feature_cols={}, latent_dim={}",
1557 latent.feature_cols.len(),
1558 latent.values.latent_dim()
1559 ))
1560 .into());
1561 }
1562 if latent.values.n_obs() != data.nrows() {
1563 return Err(SmoothError::dimension_mismatch(format!(
1564 "latent-coordinate row mismatch: latent n={}, data n={}",
1565 latent.values.n_obs(),
1566 data.nrows()
1567 ))
1568 .into());
1569 }
1570 let analytic_rho_count = latent
1571 .analytic_penalties
1572 .as_ref()
1573 .map_or(0, |registry| registry.total_rho_count());
1574 Ok(Self {
1575 data,
1576 spec,
1577 design,
1578 current_theta: None,
1579 current_latent: None,
1580 current_hyper_dirs: None,
1581 current_design_cache_id: None,
1582 latent_design_cache: gam_solve::latent_cache::LatentDesignCache::default(),
1583 last_cost: None,
1584 last_eval: None,
1585 term_index: latent.term_index,
1586 feature_cols: latent.feature_cols.clone(),
1587 rho_dim,
1588 n_obs: latent.values.n_obs(),
1589 latent_dim: latent.values.latent_dim(),
1590 id_mode: latent.values.id_mode().clone(),
1591 manifold: latent.values.manifold().clone(),
1592 retraction_registry: latent.values.retraction_registry().clone(),
1593 latent_id: latent.values.latent_id(),
1594 analytic_penalties: latent.analytic_penalties.clone(),
1595 analytic_rho_count,
1596 design_revision: 0,
1597 last_outer_iter: None,
1598 })
1599 }
1600
1601 fn design_revision(&self) -> u64 {
1602 self.design_revision
1603 }
1604
1605 fn design(&self) -> &TermCollectionDesign {
1606 &self.design
1607 }
1608
1609 fn latent(&self) -> Result<std::sync::Arc<gam_terms::latent::LatentCoordValues>, String> {
1610 self.current_latent
1611 .as_ref()
1612 .cloned()
1613 .ok_or_else(|| "latent-coordinate cache has not been realized".to_string())
1614 }
1615
1616 fn analytic_penalties(&self) -> Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>> {
1617 self.analytic_penalties.clone()
1618 }
1619
1620 fn analytic_penalty_rho_count(&self) -> usize {
1621 self.analytic_rho_count
1622 }
1623
1624 fn hyper_dirs(&self) -> Result<Vec<gam_solve::estimate::reml::DirectionalHyperParam>, String> {
1625 self.current_hyper_dirs
1626 .as_ref()
1627 .cloned()
1628 .ok_or_else(|| "latent-coordinate hyper_dirs cache has not been realized".to_string())
1629 }
1630
1631 fn latent_basis_kind(&self) -> Result<gam_solve::latent_cache::LatentBasisKind, String> {
1632 let smooth_term = self
1633 .design
1634 .smooth
1635 .terms
1636 .get(self.term_index.get())
1637 .ok_or_else(|| {
1638 SmoothError::dimension_mismatch(format!(
1639 "LatentCoord term index {} out of bounds for realized smooth design",
1640 self.term_index
1641 ))
1642 })?;
1643 let termspec = self
1644 .spec
1645 .smooth_terms
1646 .get(self.term_index.get())
1647 .ok_or_else(|| {
1648 SmoothError::dimension_mismatch(format!(
1649 "LatentCoord term index {} out of bounds for resolved smooth spec",
1650 self.term_index
1651 ))
1652 })?;
1653 match (&termspec.basis, &smooth_term.metadata) {
1654 (
1655 SmoothBasisSpec::Matern { .. },
1656 BasisMetadata::Matern {
1657 centers,
1658 length_scale,
1659 nu,
1660 aniso_log_scales,
1661 ..
1662 },
1663 ) => Ok(gam_solve::latent_cache::LatentBasisKind::Matern {
1664 centers: centers.clone(),
1665 length_scale: *length_scale,
1666 nu: *nu,
1667 aniso_log_scales: aniso_log_scales
1668 .clone()
1669 .unwrap_or_else(|| vec![0.0; centers.ncols()]),
1670 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1671 self.n_obs,
1672 centers.nrows(),
1673 ),
1674 }),
1675 (
1676 SmoothBasisSpec::Duchon { .. },
1677 BasisMetadata::Duchon {
1678 centers,
1679 length_scale,
1680 power,
1681 nullspace_order,
1682 aniso_log_scales,
1683 ..
1684 },
1685 ) => Ok(gam_solve::latent_cache::LatentBasisKind::Duchon {
1686 centers: centers.clone(),
1687 length_scale: *length_scale,
1688 power: *power,
1689 nullspace_order: *nullspace_order,
1690 aniso_log_scales: aniso_log_scales
1691 .clone()
1692 .unwrap_or_else(|| vec![0.0; centers.ncols()]),
1693 }),
1694 (
1695 SmoothBasisSpec::Sphere { .. },
1696 BasisMetadata::Sphere {
1697 centers,
1698 penalty_order,
1699 method,
1700 ..
1701 },
1702 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
1703 Ok(gam_solve::latent_cache::LatentBasisKind::Sphere {
1704 centers: centers.clone(),
1705 penalty_order: *penalty_order,
1706 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1707 self.n_obs,
1708 centers.nrows(),
1709 ),
1710 })
1711 }
1712 (
1713 SmoothBasisSpec::BSpline1D { spec, .. },
1714 BasisMetadata::BSpline1D {
1715 knots,
1716 periodic,
1717 degree: meta_degree,
1718 ..
1719 },
1720 ) => {
1721 let effective_degree = meta_degree.unwrap_or(spec.degree);
1725 if let Some((domain_start, period, num_basis)) = periodic {
1726 Ok(
1727 gam_solve::latent_cache::LatentBasisKind::PeriodicBspline {
1728 domain_start: *domain_start,
1729 period: *period,
1730 degree: effective_degree,
1731 num_basis: *num_basis,
1732 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1733 self.n_obs, *num_basis,
1734 ),
1735 },
1736 )
1737 } else {
1738 let num_basis_est = knots.len().saturating_sub(effective_degree + 1);
1739 Ok(
1740 gam_solve::latent_cache::LatentBasisKind::TensorBspline {
1741 knots: vec![knots.clone()],
1742 degrees: vec![effective_degree],
1743 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1744 self.n_obs,
1745 num_basis_est,
1746 ),
1747 },
1748 )
1749 }
1750 }
1751 (
1752 SmoothBasisSpec::TensorBSpline { .. },
1753 BasisMetadata::TensorBSpline { knots, degrees, .. },
1754 ) => Ok(
1755 gam_solve::latent_cache::LatentBasisKind::TensorBspline {
1756 knots: knots.clone(),
1757 degrees: degrees.clone(),
1758 chunk_size: None,
1759 },
1760 ),
1761 (
1762 SmoothBasisSpec::Pca { .. },
1763 BasisMetadata::Pca {
1764 basis_matrix,
1765 centered,
1766 smooth_penalty,
1767 center_mean,
1768 pca_basis_path,
1769 chunk_size,
1770 ..
1771 },
1772 ) => {
1773 let center_mean_fingerprint = if *centered && pca_basis_path.is_none() {
1774 let mean = center_mean.as_ref().ok_or_else(|| {
1775 SmoothError::invalid_config(
1776 "latent-coordinate Pca cache key requires center_mean when centered",
1777 )
1778 })?;
1779 Some(gam_solve::latent_cache::pca_center_mean_fingerprint(
1780 mean,
1781 ))
1782 } else {
1783 None
1784 };
1785 Ok(gam_solve::latent_cache::LatentBasisKind::Pca {
1786 basis_matrix: basis_matrix.clone(),
1787 centered: *centered,
1788 center_mean_fingerprint,
1789 smooth_penalty: *smooth_penalty,
1790 pca_basis_path: pca_basis_path.clone(),
1791 chunk_size: *chunk_size,
1792 })
1793 }
1794 _ => Err(SmoothError::invalid_config(
1795 "latent-coordinate design cache could not key the realized latent smooth basis"
1796 .to_string(),
1797 )
1798 .into()),
1799 }
1800 }
1801
1802 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1803 if self
1804 .current_theta
1805 .as_ref()
1806 .is_some_and(|cached| theta_values_match(cached, theta))
1807 {
1808 return Ok(());
1809 }
1810 let latent_flat_len = self.n_obs * self.latent_dim;
1811 let direct_hyper_count = latent_coord_direct_hyper_count(&self.id_mode, self.latent_dim);
1812 let expected =
1813 self.rho_dim + latent_flat_len + self.analytic_rho_count + direct_hyper_count;
1814 if theta.len() != expected {
1815 return Err(SmoothError::dimension_mismatch(format!(
1816 "latent-coordinate theta length mismatch: got {}, expected {} (rho_dim={}, n={}, d={}, analytic_rhos={}, direct_hypers={})",
1817 theta.len(),
1818 expected,
1819 self.rho_dim,
1820 self.n_obs,
1821 self.latent_dim,
1822 self.analytic_rho_count,
1823 direct_hyper_count
1824 ))
1825 .into());
1826 }
1827 let flat = theta
1828 .slice(s![self.rho_dim..self.rho_dim + latent_flat_len])
1829 .to_owned();
1830 let latent = std::sync::Arc::new(
1831 gam_terms::latent::LatentCoordValues::from_flat_with_manifold_and_retraction_and_id(
1832 flat,
1833 self.n_obs,
1834 self.latent_dim,
1835 self.id_mode.clone(),
1836 self.manifold.clone(),
1837 self.retraction_registry.clone(),
1838 self.latent_id,
1839 ),
1840 );
1841 let latent_values_changed = self
1842 .current_latent
1843 .as_ref()
1844 .map(|cached| !latent_values_match(cached.as_flat(), latent.as_flat()))
1845 .unwrap_or(true);
1846 if latent_values_changed {
1847 self.latent_design_cache.invalidate_all();
1848 self.current_design_cache_id = None;
1849 self.design_revision = self.design_revision.wrapping_add(1);
1850 }
1851 for n in 0..self.n_obs {
1852 for axis in 0..self.latent_dim {
1853 let col = self.feature_cols[axis];
1854 self.data[[n, col]] = latent.as_flat()[n * self.latent_dim + axis];
1855 }
1856 }
1857
1858 let basis_kind = self.latent_basis_kind()?;
1859 let rebuilt_width = self.design.design.ncols();
1860 let spec = self.spec.clone();
1861 let term_index = self.term_index;
1862 let analytic_rho_count = self.analytic_rho_count;
1863 let data = self.data.view();
1864 let design_context_digest =
1865 gam_solve::latent_cache::latent_design_context_cache_digest(
1866 data,
1867 &spec,
1868 term_index,
1869 analytic_rho_count,
1870 &self.feature_cols,
1871 )
1872 .map_err(|e| e.to_string())?;
1873 let lookup = self
1874 .latent_design_cache
1875 .lookup_or_compute(latent.clone(), basis_kind, design_context_digest, || {
1876 let rebuilt = build_term_collection_design(data, &spec).map_err(|e| {
1877 EstimationError::InvalidInput(format!(
1878 "failed to rebuild latent-coordinate design: {e}"
1879 ))
1880 })?;
1881 if rebuilt.design.ncols() != rebuilt_width {
1882 crate::bail_invalid_estim!(
1883 "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
1884 rebuilt.design.ncols(),
1885 rebuilt_width
1886 );
1887 }
1888 let hyper_dirs = try_build_latent_coord_hyper_dirs(
1889 latent.clone(),
1890 &spec,
1891 &rebuilt,
1892 &[term_index],
1893 analytic_rho_count,
1894 )?
1895 .ok_or_else(|| {
1896 EstimationError::InvalidInput(
1897 "failed to build latent-coordinate hyper_dirs".to_string(),
1898 )
1899 })?;
1900 Ok(gam_solve::latent_cache::ComputedLatentDesign {
1901 design: rebuilt,
1902 hyper_dirs,
1903 })
1904 })
1905 .map_err(|e| e.to_string())?;
1906 if lookup.cached.design.design.ncols() != self.design.design.ncols() {
1907 return Err(SmoothError::dimension_mismatch(format!(
1908 "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
1909 lookup.cached.design.design.ncols(),
1910 self.design.design.ncols()
1911 ))
1912 .into());
1913 }
1914 self.design = lookup.cached.design.clone();
1915 self.current_hyper_dirs = Some(lookup.cached.hyper_dirs.clone());
1916 self.current_latent = Some(latent);
1917 self.current_theta = Some(theta.clone());
1918 self.last_cost = None;
1919 self.last_eval = None;
1920 self.last_outer_iter = None;
1921 if !latent_values_changed && self.current_design_cache_id != Some(lookup.entry_id) {
1922 self.design_revision = self.design_revision.wrapping_add(1);
1923 }
1924 self.current_design_cache_id = Some(lookup.entry_id);
1925 Ok(())
1926 }
1927
1928 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1929 if self
1930 .current_theta
1931 .as_ref()
1932 .is_some_and(|cached| theta_values_match(cached, theta))
1933 && self.last_outer_iter
1934 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
1935 {
1936 self.last_eval
1937 .as_ref()
1938 .map(|cached| cached.0)
1939 .or(self.last_cost)
1940 } else {
1941 None
1942 }
1943 }
1944
1945 fn memoized_eval(
1946 &self,
1947 theta: &Array1<f64>,
1948 ) -> Option<(
1949 f64,
1950 Array1<f64>,
1951 gam_problem::HessianResult,
1952 )> {
1953 if self
1954 .current_theta
1955 .as_ref()
1956 .is_some_and(|cached| theta_values_match(cached, theta))
1957 && self.last_outer_iter
1958 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
1959 {
1960 self.last_eval.clone()
1961 } else {
1962 None
1963 }
1964 }
1965
1966 fn store_eval(
1967 &mut self,
1968 eval: (
1969 f64,
1970 Array1<f64>,
1971 gam_problem::HessianResult,
1972 ),
1973 ) {
1974 self.last_cost = Some(eval.0);
1975 self.last_eval = Some(eval);
1976 self.last_outer_iter =
1977 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
1978 }
1979
1980 fn store_cost(&mut self, cost: f64) {
1981 self.last_cost = Some(cost);
1982 self.last_outer_iter =
1983 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
1984 }
1985
1986 fn reset(&mut self) {
1987 self.current_theta = None;
1988 self.current_latent = None;
1989 self.current_hyper_dirs = None;
1990 self.current_design_cache_id = None;
1991 self.latent_design_cache.invalidate();
1992 self.last_cost = None;
1993 self.last_eval = None;
1994 self.last_outer_iter = None;
1995 }
1996}
1997
1998pub fn fixed_kappa_profiled_reml_score(
2014 data: ArrayView2<'_, f64>,
2015 y: ArrayView1<'_, f64>,
2016 weights: ArrayView1<'_, f64>,
2017 offset: ArrayView1<'_, f64>,
2018 resolvedspec: &TermCollectionSpec,
2019 term_idx: usize,
2020 kappa: f64,
2021 family: LikelihoodSpec,
2022 options: &FitOptions,
2023) -> Result<f64, EstimationError> {
2024 if !kappa.is_finite() {
2025 crate::bail_invalid_estim!("fixed-κ profiled score probed a non-finite κ = {kappa}");
2026 }
2027 let (feature_cols, mut probe_basis) = match resolvedspec
2030 .smooth_terms
2031 .get(term_idx)
2032 .map(|t| &t.basis)
2033 {
2034 Some(SmoothBasisSpec::ConstantCurvature {
2035 feature_cols, spec, ..
2036 }) => (feature_cols.clone(), spec.clone()),
2037 _ => {
2038 crate::bail_invalid_estim!(
2039 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2040 )
2041 }
2042 };
2043 probe_basis.kappa = kappa;
2044
2045 let is_unweighted = weights.iter().all(|&w| (w - 1.0).abs() <= 1e-12);
2065 let is_zero_offset = offset.iter().all(|&o| o.abs() <= 1e-12);
2066 if family == LikelihoodSpec::gaussian_identity() && is_unweighted && is_zero_offset {
2067 let x_term = select_columns(data, &feature_cols).map_err(EstimationError::from)?;
2068 let score =
2069 gam_terms::basis::constant_curvature_honest_profiled_reml_score(x_term.view(), y, &probe_basis)
2070 .map_err(|e| {
2071 EstimationError::InvalidInput(format!(
2072 "fixed-κ honest profiled-REML score at κ={kappa} failed: {e}"
2073 ))
2074 })?;
2075 if !score.is_finite() {
2076 crate::bail_invalid_estim!(
2077 "fixed-κ honest profiled-REML score at κ={kappa} is non-finite"
2078 );
2079 }
2080 return Ok(score);
2081 }
2082
2083 let mut probe_spec = resolvedspec.clone();
2085 match probe_spec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis) {
2086 Some(SmoothBasisSpec::ConstantCurvature { spec, .. }) => spec.kappa = kappa,
2087 _ => {
2088 crate::bail_invalid_estim!(
2089 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2090 )
2091 }
2092 }
2093 let fixed_kappa_options = SpatialLengthScaleOptimizationOptions {
2094 enabled: false,
2095 ..SpatialLengthScaleOptimizationOptions::default()
2096 };
2097 let fit = fit_term_collectionwith_spatial_length_scale_optimization(
2098 data,
2099 y.to_owned(),
2100 weights.to_owned(),
2101 offset.to_owned(),
2102 &probe_spec,
2103 family,
2104 options,
2105 &fixed_kappa_options,
2106 )?;
2107 let score = fit_score(&fit.fit);
2108 if !score.is_finite() {
2109 crate::bail_invalid_estim!("fixed-κ profiled fit at κ={kappa} returned a non-finite score");
2110 }
2111 Ok(score)
2112}
2113
2114fn constant_curvature_kappa_fair_argmin(
2139 data: ArrayView2<'_, f64>,
2140 y: ArrayView1<'_, f64>,
2141 resolvedspec: &TermCollectionSpec,
2142 term_idx: usize,
2143) -> Option<f64> {
2144 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2145 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2146 return None;
2147 }
2148 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2149 Some(SmoothBasisSpec::ConstantCurvature {
2150 feature_cols, spec, ..
2151 }) => (feature_cols, spec.clone()),
2152 _ => return None,
2153 };
2154 let x_term = match select_columns(data, feature_cols) {
2155 Ok(x) => x,
2156 Err(e) => {
2157 log::info!("[spatial-kappa] #1464 κ-fair argmin column select failed ({e}); skipping");
2158 return None;
2159 }
2160 };
2161 const GRID_STEPS: usize = 24;
2167 let mut best: Option<(f64, f64)> = None; for i in 0..=GRID_STEPS {
2169 let t = i as f64 / GRID_STEPS as f64;
2170 let kappa = kappa_min + (kappa_max - kappa_min) * t;
2171 let mut probe_spec = base_spec.clone();
2172 probe_spec.kappa = kappa;
2173 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec) {
2174 Ok(score) => {
2175 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2176 best = Some((score, kappa));
2177 }
2178 }
2179 Err(e) => {
2180 log::info!(
2181 "[spatial-kappa] #1464 κ-fair argmin probe at κ={kappa:.4} failed ({e}); skipping"
2182 );
2183 }
2184 }
2185 }
2186 best.map(|(score, kappa)| {
2187 log::info!(
2188 "[spatial-kappa] #1464 κ-fair argmin κ̂={kappa:.4} (κ-fair score={score:.6e}) for term {term_idx}"
2189 );
2190 kappa
2191 })
2192}
2193
2194fn select_constant_curvature_kappa_sign_seed(
2202 data: ArrayView2<'_, f64>,
2203 y: ArrayView1<'_, f64>,
2204 resolvedspec: &TermCollectionSpec,
2205 term_idx: usize,
2206) -> Option<f64> {
2207 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2208 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2209 return None;
2210 }
2211 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2223 Some(SmoothBasisSpec::ConstantCurvature {
2224 feature_cols, spec, ..
2225 }) => (feature_cols, spec.clone()),
2226 _ => return None,
2227 };
2228 let x_term = match select_columns(data, feature_cols) {
2229 Ok(x) => x,
2230 Err(e) => {
2231 log::info!("[spatial-kappa] #1464 sign-basin scan column select failed ({e}); skipping");
2232 return None;
2233 }
2234 };
2235 let probes = [
2239 kappa_min,
2240 0.5 * kappa_min,
2241 0.0,
2242 0.5 * kappa_max,
2243 kappa_max,
2244 ];
2245 let mut best: Option<(f64, f64)> = None; for &kappa in &probes {
2247 let mut probe_spec = base_spec.clone();
2248 probe_spec.kappa = kappa;
2249 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(
2250 x_term.view(),
2251 y,
2252 &probe_spec,
2253 ) {
2254 Ok(score) => {
2255 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2256 best = Some((score, kappa));
2257 }
2258 }
2259 Err(e) => {
2260 log::info!(
2261 "[spatial-kappa] #1464 sign-basin probe at κ={kappa:.4} failed ({e}); skipping"
2262 );
2263 }
2264 }
2265 }
2266 best.map(|(score, kappa)| {
2267 log::info!(
2268 "[spatial-kappa] #1464 κ-fair sign-basin scan selected κ_seed={kappa:.4} \
2269 (κ-fair score={score:.6e}) for term {term_idx}"
2270 );
2271 kappa
2272 })
2273}
2274
// Number of evenly spaced psi probes, endpoints included, that
// `prescan_isotropic_spatial_range_seed` fits per eligible term. Must be at least
// 2, because the grid fraction divides by `SPATIAL_RANGE_PRESCAN_GRID - 1`.
const SPATIAL_RANGE_PRESCAN_GRID: usize = 7;
2278
2279fn prescan_isotropic_spatial_range_seed(
2311 data: ArrayView2<'_, f64>,
2312 y: ArrayView1<'_, f64>,
2313 weights: ArrayView1<'_, f64>,
2314 offset: ArrayView1<'_, f64>,
2315 resolvedspec: &TermCollectionSpec,
2316 baseline_score: f64,
2317 family: &LikelihoodSpec,
2318 options: &FitOptions,
2319 kappa_options: &SpatialLengthScaleOptimizationOptions,
2320 spatial_terms: &[usize],
2321) -> Result<Vec<(usize, f64)>, EstimationError> {
2322 if has_aniso_terms(resolvedspec, spatial_terms)
2324 || !constant_curvature_term_indices(resolvedspec).is_empty()
2325 {
2326 return Ok(Vec::new());
2327 }
2328 let dims = spatial_dims_per_term(resolvedspec, spatial_terms);
2329 let mut working = resolvedspec.clone();
2333 let mut best_score = if baseline_score.is_finite() {
2334 baseline_score
2335 } else {
2336 f64::INFINITY
2337 };
2338 let mut overrides: Vec<(usize, f64)> = Vec::new();
2339 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2340 if dims.get(slot).copied().unwrap_or(1) != 1 {
2343 continue;
2344 }
2345 if get_spatial_length_scale(&working, term_idx).is_none() {
2348 continue;
2349 }
2350 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &working, term_idx, kappa_options);
2351 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2352 continue;
2353 }
2354 let mut term_best: Option<f64> = None;
2355 for g in 0..SPATIAL_RANGE_PRESCAN_GRID {
2356 let frac = g as f64 / (SPATIAL_RANGE_PRESCAN_GRID - 1) as f64;
2357 let psi = psi_lo + (psi_hi - psi_lo) * frac;
2358 let ls = (-psi).exp();
2362 if !ls.is_finite() || ls <= 0.0 {
2363 continue;
2364 }
2365 let mut probe = working.clone();
2366 if set_spatial_length_scale(&mut probe, term_idx, ls).is_err() {
2367 continue;
2368 }
2369 let fit = match fit_term_collection_forspec(
2378 data,
2379 y,
2380 weights,
2381 offset,
2382 &probe,
2383 family.clone(),
2384 options,
2385 ) {
2386 Ok(fit) => fit,
2387 Err(_) => continue,
2390 };
2391 let score = fit_score(&fit.fit);
2392 if score.is_finite() && score < best_score - 1e-7 * best_score.abs().max(1.0) {
2395 best_score = score;
2396 term_best = Some(ls);
2397 }
2398 }
2399 if let Some(ls) = term_best {
2400 set_spatial_length_scale(&mut working, term_idx, ls)?;
2401 overrides.push((term_idx, ls));
2402 log::info!(
2403 "[spatial-kappa] #1074 range pre-scan: term {term_idx} re-seeded at \
2404 length_scale={ls:.5} (profiled REML {best_score:.5}, was {baseline_score:.5})"
2405 );
2406 }
2407 }
2408 Ok(overrides)
2409}
2410
// Five fractions in [0, 1], from 0.0 to 1.0 inclusive. Presumably each is passed
// as `fraction` to `joint_solve_from_window_fraction`, which places the seed at
// `psi_lo + (psi_hi - psi_lo) * fraction`. Confirm at the call site, which is not
// visible here.
const JOINT_RESTART_WINDOW_FRACTIONS: [f64; 5] = [0.0, 0.2, 0.45, 0.7, 1.0];
2420
2421fn joint_solve_from_window_fraction(
2437 data: ArrayView2<'_, f64>,
2438 y: ArrayView1<'_, f64>,
2439 weights: ArrayView1<'_, f64>,
2440 offset: ArrayView1<'_, f64>,
2441 base_spec: &TermCollectionSpec,
2442 spatial_terms: &[usize],
2443 fraction: f64,
2444 family: &LikelihoodSpec,
2445 options: &FitOptions,
2446 baseline_options: &FitOptions,
2447 kappa_options: &SpatialLengthScaleOptimizationOptions,
2448) -> Result<Option<(FittedTermCollectionWithSpec, f64)>, EstimationError> {
2449 let mut seed_spec = base_spec.clone();
2450 let mut any_set = false;
2451 for &term_idx in spatial_terms {
2452 if get_spatial_length_scale(&seed_spec, term_idx).is_none() {
2453 continue;
2454 }
2455 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &seed_spec, term_idx, kappa_options);
2456 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2457 continue;
2458 }
2459 let psi = psi_lo + (psi_hi - psi_lo) * fraction;
2460 let ls = (-psi).exp();
2461 if !ls.is_finite() || ls <= 0.0 {
2462 continue;
2463 }
2464 if set_spatial_length_scale(&mut seed_spec, term_idx, ls).is_ok() {
2465 any_set = true;
2466 }
2467 }
2468 if !any_set {
2469 return Ok(None);
2470 }
2471 let seed_best = match fit_term_collection_forspec(
2475 data,
2476 y,
2477 weights,
2478 offset,
2479 &seed_spec,
2480 family.clone(),
2481 baseline_options,
2482 ) {
2483 Ok(fit) => fit,
2484 Err(_) => return Ok(None),
2485 };
2486 let seed_spec = freeze_term_collection_from_design(&seed_spec, &seed_best.design)?;
2487 let seed_terms = spatial_length_scale_term_indices(&seed_spec);
2490 if seed_terms.is_empty() {
2491 let score = fit_score(&seed_best.fit);
2492 return Ok(Some((
2493 FittedTermCollectionWithSpec {
2494 fit: seed_best.fit,
2495 design: seed_best.design,
2496 resolvedspec: seed_spec,
2497 adaptive_diagnostics: seed_best.adaptive_diagnostics,
2498 kappa_timing: None,
2499 },
2500 score,
2501 )));
2502 }
2503 let joint = try_exact_joint_spatial_length_scale_optimization(
2504 data,
2505 y,
2506 weights,
2507 offset,
2508 &seed_spec,
2509 &seed_best,
2510 family.clone(),
2511 options,
2512 kappa_options,
2513 &seed_terms,
2514 )?;
2515 match joint {
2516 Some(fit) => {
2517 let score = fit_score(&fit.fit);
2518 Ok(Some((fit, score)))
2519 }
2520 None => {
2523 let score = fit_score(&seed_best.fit);
2524 Ok(Some((
2525 FittedTermCollectionWithSpec {
2526 fit: seed_best.fit,
2527 design: seed_best.design,
2528 resolvedspec: seed_spec,
2529 adaptive_diagnostics: seed_best.adaptive_diagnostics,
2530 kappa_timing: None,
2531 },
2532 score,
2533 )))
2534 }
2535 }
2536}
2537
2538fn try_exact_joint_spatial_length_scale_optimization(
2539 data: ArrayView2<'_, f64>,
2540 y: ArrayView1<'_, f64>,
2541 weights: ArrayView1<'_, f64>,
2542 offset: ArrayView1<'_, f64>,
2543 resolvedspec: &TermCollectionSpec,
2544 best: &FittedTermCollection,
2545 family: LikelihoodSpec,
2546 options: &FitOptions,
2547 kappa_options: &SpatialLengthScaleOptimizationOptions,
2548 spatial_terms: &[usize],
2549) -> Result<Option<FittedTermCollectionWithSpec>, EstimationError> {
2550 if spatial_terms.is_empty() {
2551 return Ok(None);
2552 }
2553 kappa_options
2558 .validate()
2559 .map_err(EstimationError::InvalidInput)?;
2560
2561 let cc_term_set = constant_curvature_term_indices(resolvedspec);
2581 let all_spatial_are_cc =
2582 !cc_term_set.is_empty() && spatial_terms.iter().all(|t| cc_term_set.contains(t));
2583 if all_spatial_are_cc {
2584 let mut fixed_kappa_spec = resolvedspec.clone();
2585 let mut any_kappa_chosen = false;
2586 for &term_idx in spatial_terms {
2587 if let Some(kappa_hat) =
2598 constant_curvature_kappa_fair_argmin(data, y, resolvedspec, term_idx)
2599 .filter(|&k| k < 0.0)
2600 {
2601 if let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) = fixed_kappa_spec
2602 .smooth_terms
2603 .get_mut(term_idx)
2604 .map(|t| &mut t.basis)
2605 {
2606 cc.kappa = kappa_hat;
2607 any_kappa_chosen = true;
2608 log::info!(
2609 "[spatial-kappa] #1464 term {term_idx}: fixed κ̂ = {kappa_hat:.4} from κ-fair argmin (hyperbolic basin; profiling ρ only)"
2610 );
2611 }
2612 }
2613 }
2614 if any_kappa_chosen {
2615 let baseline_score = fit_score(&best.fit);
2619 let fitted = fit_term_collection_forspec(
2620 data,
2621 y,
2622 weights,
2623 offset,
2624 &fixed_kappa_spec,
2625 family.clone(),
2626 options,
2627 )?;
2628 let frozen_spec =
2629 freeze_term_collection_from_design(&fixed_kappa_spec, &fitted.design)?;
2630 let mut fit = fitted.fit;
2631 fit.reml_score = baseline_score;
2643 return Ok(Some(FittedTermCollectionWithSpec {
2644 fit,
2645 design: fitted.design,
2646 resolvedspec: frozen_spec,
2647 adaptive_diagnostics: fitted.adaptive_diagnostics,
2648 kappa_timing: None,
2649 }));
2650 }
2651 }
2652
2653 if try_build_spatial_log_kappa_hyper_dirs(data, resolvedspec, &best.design, spatial_terms)?
2654 .is_none()
2655 {
2656 if !constant_curvature_term_indices(resolvedspec).is_empty() {
2657 log::info!(
2658 "[#1464-trace] try_exact_joint RETURNED None (hyper_dirs unavailable); \
2659 κ̂ comes from a NON-joint path"
2660 );
2661 }
2662 return Ok(None);
2663 }
2664 if !constant_curvature_term_indices(resolvedspec).is_empty() {
2665 log::info!(
2666 "[#1464-trace] try_exact_joint ENTERED for {} spatial term(s); CC present",
2667 spatial_terms.len()
2668 );
2669 }
2670
2671 const JOINT_RHO_BOUND: f64 = 12.0;
2672 let rho_dim = best.fit.lambdas.len();
2673
2674 let has_constant_curvature_term = !constant_curvature_term_indices(resolvedspec).is_empty();
2688 let rho_upper_bound = if has_constant_curvature_term {
2689 gam_solve::estimate::RHO_BOUND
2690 } else {
2691 JOINT_RHO_BOUND
2692 };
2693
2694 let dims_per_term = spatial_dims_per_term(resolvedspec, spatial_terms);
2696 let use_aniso = has_aniso_terms(resolvedspec, spatial_terms);
2697
2698 let log_kappa0 = if use_aniso {
2703 SpatialLogKappaCoords::from_length_scales_aniso(resolvedspec, spatial_terms, kappa_options)
2704 } else {
2705 SpatialLogKappaCoords::from_length_scales(resolvedspec, spatial_terms, kappa_options)
2706 };
2707 let mut log_kappa0 =
2710 log_kappa0.reseed_from_data(data, resolvedspec, spatial_terms, kappa_options);
2711 let mut cc_sign_seeds: Vec<(usize, f64)> = Vec::new();
2727 if has_constant_curvature_term {
2728 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2729 if constant_curvature_term_spec(resolvedspec, term_idx).is_none() {
2730 continue;
2731 }
2732 let scan = select_constant_curvature_kappa_sign_seed(
2733 data,
2734 y,
2735 resolvedspec,
2736 term_idx,
2737 );
2738 match scan {
2743 Some(kappa_seed) => {
2744 log::info!(
2745 "[#1464-trace] term {term_idx}: κ-fair sign-basin scan picked κ_seed = {kappa_seed}"
2746 );
2747 log_kappa0.set_scalar_slot(slot, kappa_seed);
2748 cc_sign_seeds.push((slot, kappa_seed));
2749 }
2750 None => {
2751 log::info!(
2752 "[#1464-trace] term {term_idx}: fixed-κ sign-basin scan returned NONE (no seed applied)"
2753 );
2754 }
2755 }
2756 }
2757 }
2758 let log_kappa_lower = if use_aniso {
2759 SpatialLogKappaCoords::lower_bounds_aniso_from_data(
2760 data,
2761 resolvedspec,
2762 spatial_terms,
2763 &dims_per_term,
2764 kappa_options,
2765 )
2766 } else {
2767 SpatialLogKappaCoords::lower_bounds_from_data(
2768 data,
2769 resolvedspec,
2770 spatial_terms,
2771 kappa_options,
2772 )
2773 };
2774 let log_kappa_upper = if use_aniso {
2775 SpatialLogKappaCoords::upper_bounds_aniso_from_data(
2776 data,
2777 resolvedspec,
2778 spatial_terms,
2779 &dims_per_term,
2780 kappa_options,
2781 )
2782 } else {
2783 SpatialLogKappaCoords::upper_bounds_from_data(
2784 data,
2785 resolvedspec,
2786 spatial_terms,
2787 kappa_options,
2788 )
2789 };
2790 let mut log_kappa_lower = log_kappa_lower;
2814 let mut log_kappa_upper = log_kappa_upper;
2815 for &(slot, kappa_seed) in &cc_sign_seeds {
2816 if kappa_seed != 0.0 {
2817 log_kappa_lower.set_scalar_slot(slot, kappa_seed);
2818 log_kappa_upper.set_scalar_slot(slot, kappa_seed);
2819 }
2820 log::info!(
2821 "[#1464-trace] slot {slot}: FROZE joint ψ coordinate at κ_seed={kappa_seed} \
2822 (window [{}, {}]); raw fit_score is sign-blind so the κ-fair scan is authoritative",
2823 log_kappa_lower.as_array()[log_kappa_lower.dims_per_term()[..slot].iter().sum::<usize>()],
2824 log_kappa_upper.as_array()[log_kappa_upper.dims_per_term()[..slot].iter().sum::<usize>()],
2825 );
2826 }
2827 let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
2830 let setup = ExactJointHyperSetup::new(
2831 best.fit.lambdas.mapv(f64::ln),
2832 Array1::<f64>::from_elem(rho_dim, -JOINT_RHO_BOUND),
2833 Array1::<f64>::from_elem(rho_dim, rho_upper_bound),
2834 log_kappa0,
2835 log_kappa_lower,
2836 log_kappa_upper,
2837 );
2838
2839 let theta0 = setup.theta0();
2840 let lower = setup.lower();
2841 let upper = setup.upper();
2842
2843 let kind = if use_aniso {
2855 SpatialHyperKind::Anisotropic
2856 } else {
2857 SpatialHyperKind::Isotropic
2858 };
2859 let (outcome, kappa_timing) = run_exact_joint_spatial_optimization(
2860 kind,
2861 data,
2862 y,
2863 weights,
2864 offset,
2865 resolvedspec,
2866 &best.design,
2867 family.clone(),
2868 options,
2869 spatial_terms,
2870 &dims_per_term,
2871 &theta0,
2872 &lower,
2873 &upper,
2874 rho_dim,
2875 kappa_options,
2876 )?;
2877
2878 let baseline_score = fit_score(&best.fit);
2879
2880 let (theta_star, joint_final_value) = match outcome {
2890 SpatialJointOutcome::Optimized {
2891 theta_star,
2892 final_value,
2893 } => (theta_star, final_value),
2894 SpatialJointOutcome::NonConverged {
2895 iterations,
2896 final_value,
2897 final_grad_norm,
2898 } => {
2899 if has_constant_curvature_term {
2900 log::info!(
2901 "[#1464-trace] joint solve NONCONVERGED (iters={iterations}, \
2902 final_value={final_value}); returning FROZEN BASELINE geometry \
2903 (κ̂ = spec default, NOT the joint candidate)"
2904 );
2905 }
2906 log::info!(
2907 "[spatial-kappa] joint spatial optimization did not converge \
2908 (iterations={}, final_objective={:.6e}, final_grad_norm={}); \
2909 keeping the frozen baseline geometry",
2910 iterations,
2911 final_value,
2912 final_grad_norm.map_or_else(|| "n/a".to_string(), |g| format!("{g:.3e}")),
2913 );
2914 return Ok(Some(fit_frozen_baseline_geometry(
2915 data,
2916 y,
2917 weights,
2918 offset,
2919 resolvedspec,
2920 best,
2921 family,
2922 options,
2923 baseline_score,
2924 Some(kappa_timing),
2925 )?));
2926 }
2927 };
2928
2929 let accept_tol = options.tol.max(1e-8 * baseline_score.abs()).max(1e-12);
2934 if joint_final_value > baseline_score + accept_tol {
2935 if has_constant_curvature_term {
2936 log::info!(
2937 "[#1464-trace] joint candidate WORSENED score (joint={joint_final_value}, \
2938 baseline={baseline_score}); returning FROZEN BASELINE geometry \
2939 (κ̂ = spec default, NOT the joint candidate)"
2940 );
2941 }
2942 log::info!(
2943 "[spatial-kappa] exact joint spatial candidate worsened the profiled score (joint={:.6e}, baseline={:.6e}, tol={:.2e}); keeping the frozen baseline geometry",
2944 joint_final_value,
2945 baseline_score,
2946 accept_tol,
2947 );
2948 return Ok(Some(fit_frozen_baseline_geometry(
2949 data,
2950 y,
2951 weights,
2952 offset,
2953 resolvedspec,
2954 best,
2955 family,
2956 options,
2957 baseline_score,
2958 Some(kappa_timing),
2959 )?));
2960 }
2961
2962 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
2963 let log_kappa_star =
2964 SpatialLogKappaCoords::from_theta_tail_with_dims(&theta_star, rho_dim, dims_per_term);
2965 if has_constant_curvature_term {
2971 let star = log_kappa_star.as_array();
2972 let dims = log_kappa_star.dims_per_term();
2973 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2974 if constant_curvature_term_spec(resolvedspec, term_idx).is_some() {
2975 let off: usize = dims[..slot].iter().sum();
2976 log::info!(
2977 "[#1464-trace] term {term_idx}: joint solver CONVERGED ψ-tail κ = {} \
2978 (this is the optimised candidate; joint_final_value={joint_final_value})",
2979 star[off]
2980 );
2981 }
2982 }
2983 }
2984 let baseline_spec = resolvedspec;
2988 let optimized_spec = log_kappa_star.apply_tospec(resolvedspec, spatial_terms)?;
2989 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
2990 data,
2991 y,
2992 weights,
2993 offset,
2994 &optimized_spec,
2995 rho_star.as_slice(),
2996 family.clone(),
2997 options,
2998 )?;
2999
3000 let optimized_edf = optimized.fit.inference.as_ref().map(|inf| inf.edf_total);
3014 if let Some(opt_edf) = optimized_edf
3015 && opt_edf < SPATIAL_COLLAPSE_EDF_FLOOR
3016 {
3017 let baseline = fit_frozen_baseline_geometry(
3018 data,
3019 y,
3020 weights,
3021 offset,
3022 baseline_spec,
3023 best,
3024 family.clone(),
3025 options,
3026 baseline_score,
3027 Some(kappa_timing),
3028 )?;
3029 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
3030 if let Some(base_edf) = baseline_edf
3031 && base_edf >= opt_edf + SPATIAL_COLLAPSE_EDF_MARGIN
3032 {
3033 log::info!(
3034 "[spatial-kappa] joint candidate collapsed to the null (edf={opt_edf:.3}); \
3035 baseline geometry retains edf={base_edf:.3} — keeping the frozen baseline",
3036 );
3037 return Ok(Some(baseline));
3038 }
3039 }
3042
3043 let mut fit = optimized.fit;
3047 fit.reml_score = joint_final_value;
3048 let optimized_result = FittedTermCollectionWithSpec {
3049 fit,
3050 design: optimized.design,
3051 resolvedspec: optimized_spec,
3052 adaptive_diagnostics: optimized.adaptive_diagnostics,
3053 kappa_timing: Some(kappa_timing),
3054 };
3055
3056 Ok(Some(optimized_result))
3057}
3058
/// Total EDF below which a spatial fit is treated as having collapsed toward the
/// null model. Used by the joint-candidate and frozen-baseline collapse checks.
const SPATIAL_COLLAPSE_EDF_FLOOR: f64 = 2.5;

/// Minimum EDF advantage the reference fit must hold over a collapsed fit before
/// the collapse is acted on (a baseline is kept, or a refit is triggered).
const SPATIAL_COLLAPSE_EDF_MARGIN: f64 = 1.0;
3069
3070fn fit_frozen_baseline_geometry(
3106 data: ArrayView2<'_, f64>,
3107 y: ArrayView1<'_, f64>,
3108 weights: ArrayView1<'_, f64>,
3109 offset: ArrayView1<'_, f64>,
3110 resolvedspec: &TermCollectionSpec,
3111 best: &FittedTermCollection,
3112 family: LikelihoodSpec,
3113 options: &FitOptions,
3114 baseline_score: f64,
3115 kappa_timing: Option<SpatialLengthScaleOptimizationTiming>,
3116) -> Result<FittedTermCollectionWithSpec, EstimationError> {
3117 let baseline = fit_term_collection_forspecwith_heuristic_lambdas(
3118 data,
3119 y,
3120 weights,
3121 offset,
3122 resolvedspec,
3123 best.fit.lambdas.as_slice(),
3124 family.clone(),
3125 options,
3126 )?;
3127 let best_edf = best.fit.inference.as_ref().map(|inf| inf.edf_total);
3132 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
3133 let baseline = match (best_edf, baseline_edf) {
3134 (Some(best_edf), Some(base_edf))
3135 if base_edf < SPATIAL_COLLAPSE_EDF_FLOOR
3136 && best_edf >= base_edf + SPATIAL_COLLAPSE_EDF_MARGIN =>
3137 {
3138 log::info!(
3139 "[spatial-kappa] warm-started frozen baseline collapsed (edf={base_edf:.3}) \
3140 below the certified baseline (edf={best_edf:.3}); refitting from scratch",
3141 );
3142 fit_term_collection_forspec(data, y, weights, offset, resolvedspec, family, options)?
3143 }
3144 _ => baseline,
3145 };
3146 let mut fit = baseline.fit;
3147 fit.reml_score = baseline_score;
3148 Ok(FittedTermCollectionWithSpec {
3149 fit,
3150 design: baseline.design,
3151 resolvedspec: resolvedspec.clone(),
3152 adaptive_diagnostics: baseline.adaptive_diagnostics,
3153 kappa_timing,
3154 })
3155}
3156
/// Which parameterisation the joint spatial hyper-optimiser uses for the
/// non-ρ tail of θ: one coordinate per axis (anisotropic) or per term (isotropic).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SpatialHyperKind {
    Anisotropic,
    Isotropic,
}

impl SpatialHyperKind {
    /// Short tag used as the log/stage label for this kind of joint solve.
    fn label(self) -> &'static str {
        match self {
            Self::Isotropic => "spatial-iso-joint",
            Self::Anisotropic => "spatial-aniso-joint",
        }
    }

    /// Human-readable adjective used in error messages.
    fn adjective(self) -> &'static str {
        match self {
            Self::Isotropic => "isotropic",
            Self::Anisotropic => "anisotropic",
        }
    }

    /// Name of the spatial coordinate block for this kind.
    fn coord_name(self) -> &'static str {
        match self {
            Self::Isotropic => "kappa",
            Self::Anisotropic => "psi",
        }
    }
}
3201
/// Owned copies of the response data that the frozen-W GLM tensor lane needs in
/// order to re-evaluate family observations (Fisher weights, working response)
/// at arbitrary coefficients.
struct SpatialFrozenGlmInputs {
    /// Response vector.
    y: Array1<f64>,
    /// Observation weights passed to the family evaluation.
    weights: Array1<f64>,
    /// Offset added to `X β` to form the linear predictor.
    offset: Array1<f64>,
    /// Likelihood used to evaluate observations.
    family: LikelihoodSpec,
}
3213
3214fn frozen_glm_tensor_eligible_family(family: &LikelihoodSpec) -> bool {
3231 !family.is_gaussian_identity()
3232 && matches!(
3233 &family.response,
3234 ResponseFamily::Binomial
3235 | ResponseFamily::Poisson
3236 | ResponseFamily::Gamma
3237 | ResponseFamily::NegativeBinomial { .. }
3238 )
3239}
3240
/// Mutable state shared by the outer objective callbacks of the exact joint
/// spatial (ρ, κ/ψ) optimisation.
struct SpatialJointContext<'d> {
    /// Raw data matrix, used when hyper-directions are rebuilt from the design.
    data: ArrayView2<'d, f64>,
    /// Number of leading ρ = log λ coordinates in θ. In the single-ψ layout,
    /// `theta[rho_dim]` is the ψ coordinate.
    rho_dim: usize,
    /// Iso/anisotropic parameterisation; drives labels and hyper-direction kind.
    kind: SpatialHyperKind,
    /// Design cache realised at the current θ; also memoises evals and costs.
    cache: SingleBlockExactJointDesignCache<'d>,
    /// REML evaluator for the joint hyperparameters.
    evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
    /// GLM data for the frozen-W tensor lane. NOTE(review): populated only for a
    /// single spatial coordinate and an eligible family at the construction site;
    /// confirm if constructed elsewhere.
    frozen_glm_inputs: Option<SpatialFrozenGlmInputs>,
    /// ψ interval `(lo, hi)` over which the frozen-W tensor is built.
    frozen_glm_psi_bounds: Option<(f64, f64)>,
    /// Certified frozen-W GLM ψ-gram tensor, built lazily.
    frozen_glm_tensor: Option<gam_solve::glm_sufficient_lane::FrozenWeightGramTensor>,
    /// Set once a tensor build has been attempted (successful or not), preventing retries.
    frozen_glm_tensor_attempted: bool,
    /// Last `(beta, fisher_weights)` pair, memoised on the exact bits of `beta`.
    frozen_glm_weight_memo: Option<(Array1<f64>, Array1<f64>)>,
}
3264
/// Per-gate outcome of the check deciding whether an outer evaluation may skip
/// re-realising the n×k design and be served n-free from the ψ-gram tensor.
#[derive(Debug, Default, Clone, Copy)]
struct NfreeSkipGateStatus {
    /// θ has the single-ψ layout (`rho_dim + 1` entries).
    shape: bool,
    /// The ψ-gram tensor covers ψ for both value and skip purposes.
    value: bool,
    /// The tensor covers ψ for gradients, or no gradient was required.
    gradient: bool,
    /// The evaluator supports n-free penalty re-keying.
    penalty: bool,
    /// The evaluator exposes an n-free fast-path revision.
    revision: bool,
    /// Second-order geometry was requested; this always blocks the skip.
    second_order: bool,
}

impl NfreeSkipGateStatus {
    /// True when every required gate is open and no second-order geometry was
    /// requested. The gradient gate is only consulted when `require_gradient`.
    fn would_skip(self, require_gradient: bool) -> bool {
        if self.second_order {
            return false;
        }
        let gradient_ok = !require_gradient || self.gradient;
        [self.shape, self.value, gradient_ok, self.penalty, self.revision]
            .into_iter()
            .all(|gate| gate)
    }
}
3285
3286impl<'d> SpatialJointContext<'d> {
3287 fn nfree_skip_gate_status(
3288 &self,
3289 theta: &Array1<f64>,
3290 allow_second_order: bool,
3291 require_gradient: bool,
3292 ) -> NfreeSkipGateStatus {
3293 let shape = theta.len() == self.rho_dim + 1;
3294 let (value, gradient) = if shape {
3295 let psi = theta[self.rho_dim];
3296 (
3297 self.evaluator.psi_gram_tensor_covers(psi)
3298 && self.evaluator.psi_gram_tensor_covers_skip(psi),
3299 !require_gradient || self.evaluator.psi_gram_tensor_covers_gradient(psi),
3300 )
3301 } else {
3302 (false, false)
3303 };
3304 NfreeSkipGateStatus {
3305 shape,
3306 value,
3307 gradient,
3308 penalty: self.evaluator.supports_nfree_penalty_rekey(),
3309 revision: self.evaluator.nfree_fast_path_revision().is_some(),
3310 second_order: allow_second_order,
3311 }
3312 }
3313
3314 fn frozen_glm_working_state(
3315 &self,
3316 beta: &Array1<f64>,
3317 ) -> Result<Option<(Array1<f64>, Array1<f64>)>, EstimationError> {
3318 let Some(inputs) = self.frozen_glm_inputs.as_ref() else {
3319 return Ok(None);
3320 };
3321 if beta.len() != self.cache.design().design.ncols() {
3322 return Ok(None);
3323 }
3324 let mut eta = self.cache.design().design.matrixvectormultiply(beta);
3325 if eta.len() != inputs.offset.len() {
3326 crate::bail_invalid_estim!(
3327 "frozen GLM tensor warm-state row mismatch: eta={}, offset={}",
3328 eta.len(),
3329 inputs.offset.len()
3330 );
3331 }
3332 eta += &inputs.offset;
3333 let obs = evaluate_standard_familyobservations(
3334 inputs.family.clone(),
3335 None,
3336 None,
3337 None,
3338 &inputs.y,
3339 &inputs.weights,
3340 &eta,
3341 )?;
3342 let mut working_response = obs.eta.clone();
3343 for i in 0..working_response.len() {
3344 let wi = obs.fisherweight[i].max(1e-12);
3345 working_response[i] += obs.score[i] / wi;
3346 }
3347 Ok(Some((obs.fisherweight, working_response)))
3348 }
3349
3350 fn frozen_glm_trial_weights(
3359 &mut self,
3360 beta: &Array1<f64>,
3361 ) -> Result<Option<Array1<f64>>, EstimationError> {
3362 if let Some((memo_beta, memo_w)) = self.frozen_glm_weight_memo.as_ref()
3363 && memo_beta.len() == beta.len()
3364 && memo_beta
3365 .iter()
3366 .zip(beta.iter())
3367 .all(|(a, b)| a.to_bits() == b.to_bits())
3368 {
3369 return Ok(Some(memo_w.clone()));
3370 }
3371 match self.frozen_glm_working_state(beta)? {
3372 Some((current_w, _)) => {
3373 self.frozen_glm_weight_memo = Some((beta.clone(), current_w.clone()));
3374 Ok(Some(current_w))
3375 }
3376 None => Ok(None),
3377 }
3378 }
3379
3380 fn ensure_frozen_glm_tensor(
3381 &mut self,
3382 theta: &Array1<f64>,
3383 warm_beta: Option<&Array1<f64>>,
3384 ) -> Result<(), EstimationError> {
3385 if self.frozen_glm_tensor.is_some() || self.frozen_glm_tensor_attempted {
3386 return Ok(());
3387 }
3388 let Some((psi_lo, psi_hi)) = self.frozen_glm_psi_bounds else {
3389 return Ok(());
3390 };
3391 if theta.len() != self.rho_dim + 1 {
3392 self.frozen_glm_tensor_attempted = true;
3393 return Ok(());
3394 }
3395 let Some(beta) = warm_beta else {
3396 return Ok(());
3397 };
3398 let Some((frozen_w, working_z)) = self.frozen_glm_working_state(beta)? else {
3399 self.frozen_glm_tensor_attempted = true;
3400 return Ok(());
3401 };
3402 let theta_probe_base = theta.clone();
3403 let rho_dim = self.rho_dim;
3404 let Self {
3411 cache, evaluator, ..
3412 } = self;
3413 let tensor = evaluator.build_frozen_glm_gram_tensor(
3414 |psi| {
3415 let mut theta_probe = theta_probe_base.clone();
3416 theta_probe[rho_dim] = psi;
3417 cache.ensure_theta(&theta_probe)?;
3418 Ok(cache.design().design.clone())
3419 },
3420 frozen_w.view(),
3421 working_z.view(),
3422 psi_lo,
3423 psi_hi,
3424 );
3425 self.cache
3426 .ensure_theta(theta)
3427 .map_err(EstimationError::InvalidInput)?;
3428 self.frozen_glm_tensor_attempted = true;
3429 if let Some(tensor) = tensor {
3430 self.frozen_glm_tensor = Some(tensor);
3431 log::info!(
3432 "[STAGE] {} certified frozen-W GLM ψ tensor over [{psi_lo:.3}, {psi_hi:.3}]",
3433 self.kind.label(),
3434 );
3435 } else {
3436 log::info!(
3437 "[STAGE] {} frozen-W GLM ψ tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]",
3438 self.kind.label(),
3439 );
3440 }
3441 Ok(())
3442 }
3443
3444 fn stage_frozen_glm_trial_statistics(
3445 &mut self,
3446 theta: &Array1<f64>,
3447 warm_beta: Option<&Array1<f64>>,
3448 allow_gradient: bool,
3449 ) -> Result<(), EstimationError> {
3450 let kind = self.kind;
3451 let mut staged_gram: Option<Array2<f64>> = None;
3452 let mut staged_deriv: Option<(Array2<f64>, Array1<f64>)> = None;
3453 if theta.len() == self.rho_dim + 1 {
3454 let psi = theta[self.rho_dim];
3455 let tensor_covers = self
3462 .frozen_glm_tensor
3463 .as_ref()
3464 .is_some_and(|t| t.contains(psi));
3465 let current_w = if tensor_covers {
3466 match warm_beta {
3467 Some(beta) => self.frozen_glm_trial_weights(beta)?,
3468 None => None,
3469 }
3470 } else {
3471 None
3472 };
3473 if let (Some(tensor), Some(current_w)) =
3474 (self.frozen_glm_tensor.as_ref(), current_w.as_ref())
3475 {
3476 const FROZEN_GLM_WEIGHT_DRIFT_RTOL: f64 = 1e-3;
3477 if tensor.weight_drift_within(current_w.view(), FROZEN_GLM_WEIGHT_DRIFT_RTOL) {
3478 staged_gram = Some(tensor.gram_at(psi));
3479 log::debug!(
3480 "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
3481 first-Fisher-step XᵀWX n-free (weight drift within tol)",
3482 kind.label(),
3483 );
3484 }
3485 if allow_gradient
3486 && tensor.contains_for_gradient(psi)
3487 && let Some((dgram_dpsi, drhs_dpsi)) =
3488 tensor.gradient_pair_if_sound(psi, current_w.view())
3489 {
3490 staged_deriv = Some((dgram_dpsi, drhs_dpsi));
3491 log::debug!(
3492 "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
3493 ψ-gradient (∂G/∂ψ, ∂b/∂ψ) n-free (gradient weight drift within \
3494 tight tol); B_j stays exact",
3495 kind.label(),
3496 );
3497 }
3498 }
3499 }
3500 self.evaluator.stage_glm_first_step_gram(staged_gram);
3501 self.evaluator.stage_glm_psi_gram_deriv(staged_deriv);
3502 Ok(())
3503 }
3504
3505 fn eval_full(
3507 &mut self,
3508 theta: &Array1<f64>,
3509 order: gam_solve::rho_optimizer::OuterEvalOrder,
3510 analytic_outer_hessian_available: bool,
3511 ) -> Result<
3512 (
3513 f64,
3514 Array1<f64>,
3515 gam_problem::HessianResult,
3516 ),
3517 EstimationError,
3518 > {
3519 use gam_solve::rho_optimizer::OuterEvalOrder;
3520 let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
3521 && analytic_outer_hessian_available;
3522 if let Some(eval) = self.cache.memoized_eval(theta) {
3523 let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
3524 if cached_satisfies_order {
3525 return Ok(eval);
3526 }
3527 }
3528 let kind = self.kind;
3529 let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
3565 let skip_design_realization = !allow_second_order && theta.len() == self.rho_dim + 1 && {
3566 let psi = theta[self.rho_dim];
3567 self.evaluator.psi_gram_tensor_covers(psi)
3568 && self.evaluator.psi_gram_tensor_covers_gradient(psi)
3575 && self.evaluator.psi_gram_tensor_covers_skip(psi)
3592 && self.evaluator.supports_nfree_penalty_rekey()
3597 && nfree_fast_path_revision.is_some()
3598 };
3599 if skip_design_realization {
3600 log::debug!(
3601 "[STAGE] {} eval_full at psi={:.6}: skipping n×k design re-realization \
3602 + reconditioning — criterion/gradient/inner-solve served n-free from \
3603 the certified ψ-gram tensor (GaussianFixedCache + k-space ψ-derivatives)",
3604 kind.label(),
3605 theta[self.rho_dim],
3606 );
3607 } else {
3608 self.cache
3609 .ensure_theta(theta)
3610 .map_err(EstimationError::InvalidInput)?;
3611 }
3612 let warm_beta = self.evaluator.current_beta();
3613 self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref())?;
3614 self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), !allow_second_order)?;
3622 let hyper_dirs = if skip_design_realization {
3629 self.cache.nfree_tensor_gradient_hyper_dirs(theta)?
3630 } else {
3631 self.cache.hyper_dirs_for_current_design(self.data, kind)?
3632 };
3633
3634 let design_revision = if skip_design_realization {
3635 nfree_fast_path_revision
3636 } else {
3637 Some(self.cache.design_revision())
3638 };
3639 if self.evaluator.supports_nfree_penalty_rekey() {
3653 match self.cache.canonical_penalties_at(theta) {
3654 Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
3655 Err(e) => {
3656 log::warn!(
3657 "[STAGE] {} eval_full at psi={:.6}: exact n-free S(ψ) rebuild failed \
3658 ({e}); clearing stage (eval falls to slow path)",
3659 kind.label(),
3660 theta[self.rho_dim],
3661 );
3662 self.evaluator.stage_fast_path_penalty(None);
3663 }
3664 }
3665 }
3666 let eval = evaluate_joint_reml_outer_eval_at_theta(
3673 &mut self.evaluator,
3674 self.cache.design(),
3675 theta,
3676 self.rho_dim,
3677 hyper_dirs,
3678 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3679 if allow_second_order {
3680 order
3681 } else {
3682 OuterEvalOrder::ValueAndGradient
3683 },
3684 design_revision,
3685 );
3686 if let Ok(ref value) = eval {
3687 self.cache.store_eval_at(theta, value.clone());
3688 }
3689 eval
3690 }
3691
3692 fn eval_efs(
3693 &mut self,
3694 theta: &Array1<f64>,
3695 ) -> Result<gam_problem::EfsEval, EstimationError> {
3696 self.cache
3697 .ensure_theta(theta)
3698 .map_err(EstimationError::InvalidInput)?;
3699 let kind = self.kind;
3700 let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
3701 self.data,
3702 self.cache.spec(),
3703 self.cache.design(),
3704 &self.cache.spatial_terms,
3705 )?
3706 .ok_or_else(|| {
3707 EstimationError::InvalidInput(format!(
3708 "failed to build {} hyper_dirs for exact-joint EFS",
3709 kind.adjective(),
3710 ))
3711 })?;
3712 let design_revision = Some(self.cache.design_revision());
3713 let warm_beta = self.evaluator.current_beta();
3714 evaluate_joint_reml_efs_at_theta(
3715 &mut self.evaluator,
3716 self.cache.design(),
3717 theta,
3718 self.rho_dim,
3719 hyper_dirs,
3720 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3721 design_revision,
3722 )
3723 }
3724
3725 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
3731 if let Some(cost) = self.cache.memoized_cost(theta) {
3732 return cost;
3733 }
3734 let probe_start = std::time::Instant::now();
3749 let psi_distance = self
3750 .cache
3751 .current_theta
3752 .as_ref()
3753 .filter(|reference| reference.len() == theta.len())
3754 .map(|reference| {
3755 reference
3756 .iter()
3757 .zip(theta.iter())
3758 .map(|(a, b)| (a - b) * (a - b))
3759 .sum::<f64>()
3760 .sqrt()
3761 })
3762 .unwrap_or(f64::NAN);
3763 let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
3777 let skip_value_realization = theta.len() == self.rho_dim + 1 && {
3778 let psi = theta[self.rho_dim];
3779 self.evaluator.psi_gram_tensor_covers(psi)
3780 && self.evaluator.psi_gram_tensor_covers_skip(psi)
3789 && self.evaluator.supports_nfree_penalty_rekey()
3794 && nfree_fast_path_revision.is_some()
3795 };
3796 if theta.len() == self.rho_dim + 1
3797 && self.evaluator.has_psi_gram_tensor()
3798 && !self.evaluator.psi_gram_tensor_covers(theta[self.rho_dim])
3799 {
3800 self.cache.store_cost_at(theta, f64::INFINITY);
3801 return f64::INFINITY;
3802 }
3803 if !skip_value_realization && self.cache.ensure_theta(theta).is_err() {
3804 return f64::INFINITY;
3805 }
3806 if self.evaluator.supports_nfree_penalty_rekey() {
3812 match self.cache.canonical_penalties_at(theta) {
3813 Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
3814 Err(_) => self.evaluator.stage_fast_path_penalty(None),
3815 }
3816 }
3817 let warm_beta = self.evaluator.current_beta();
3818 if let Err(err) = self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref()) {
3819 log::warn!(
3820 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM tensor setup failed ({err}); \
3821 falling back to exact streamed Gram",
3822 self.kind.label(),
3823 if theta.len() > self.rho_dim {
3824 theta[self.rho_dim]
3825 } else {
3826 f64::NAN
3827 },
3828 );
3829 self.evaluator.stage_glm_first_step_gram(None);
3830 self.evaluator.stage_glm_psi_gram_deriv(None);
3831 } else if let Err(err) =
3832 self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), false)
3833 {
3834 log::warn!(
3835 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM staging failed ({err}); \
3836 falling back to exact streamed Gram",
3837 self.kind.label(),
3838 if theta.len() > self.rho_dim {
3839 theta[self.rho_dim]
3840 } else {
3841 f64::NAN
3842 },
3843 );
3844 self.evaluator.stage_glm_first_step_gram(None);
3845 self.evaluator.stage_glm_psi_gram_deriv(None);
3846 }
3847 let design_revision = if skip_value_realization {
3848 nfree_fast_path_revision
3849 } else {
3850 Some(self.cache.design_revision())
3851 };
3852 let cost_label = self.kind.label();
3853 let result = {
3854 let design = self.cache.design();
3855 self.evaluator.evaluate_cost_only(
3856 &design.design,
3857 &design.penalties,
3858 &design.nullspace_dims,
3859 design.linear_constraints.clone(),
3860 theta,
3861 self.rho_dim,
3862 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3863 cost_label,
3864 design_revision,
3865 )
3866 };
3867 match result {
3868 Ok(cost) => {
3869 log::debug!(
3870 "[STAGE] {cost_label} value-probe (order=Value): elapsed={:.3}s \
3871 cost={cost:.6e} trial_theta_distance={psi_distance:.3e}",
3872 probe_start.elapsed().as_secs_f64(),
3873 );
3874 self.cache.store_cost_at(theta, cost);
3875 cost
3876 }
3877 Err(_) => f64::INFINITY,
3878 }
3879 }
3880
3881 fn reset(&mut self) {
3882 self.cache.current_theta = None;
3883 self.cache.last_eval_theta = None;
3884 self.cache.last_cost = None;
3885 self.cache.last_eval = None;
3886 }
3887}
3888
/// Result of the exact joint spatial optimisation run.
enum SpatialJointOutcome {
    /// The optimiser produced an accepted optimum.
    Optimized {
        /// Optimal θ: `rho_dim` log-λ entries followed by the spatial coordinates.
        theta_star: Array1<f64>,
        /// Objective value at `theta_star`, compared against the baseline score.
        final_value: f64,
    },
    /// The optimiser stopped without converging; callers keep the baseline.
    NonConverged {
        /// Iterations performed before stopping.
        iterations: usize,
        /// Objective value at the last iterate.
        final_value: f64,
        /// Gradient norm at the last iterate, if one was available.
        final_grad_norm: Option<f64>,
    },
}
3938
3939fn kphase_log_norms(theta: &Array1<f64>, rho_dim: usize) -> (f64, f64) {
3940 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
3941 let log_kappa_norm = theta
3942 .iter()
3943 .skip(rho_dim)
3944 .map(|v| v * v)
3945 .sum::<f64>()
3946 .sqrt();
3947 (theta_norm, log_kappa_norm)
3948}
3949
3950fn run_exact_joint_spatial_optimization(
3951 kind: SpatialHyperKind,
3952 data: ArrayView2<'_, f64>,
3953 y: ArrayView1<'_, f64>,
3954 weights: ArrayView1<'_, f64>,
3955 offset: ArrayView1<'_, f64>,
3956 resolvedspec: &TermCollectionSpec,
3957 baseline_design: &TermCollectionDesign,
3958 family: LikelihoodSpec,
3959 options: &FitOptions,
3960 spatial_terms: &[usize],
3961 dims_per_term: &[usize],
3962 theta0: &Array1<f64>,
3963 lower: &Array1<f64>,
3964 upper: &Array1<f64>,
3965 rho_dim: usize,
3966 kappa_options: &SpatialLengthScaleOptimizationOptions,
3967) -> Result<(SpatialJointOutcome, SpatialLengthScaleOptimizationTiming), EstimationError> {
3968 let label = kind.label();
3969 assert!(
3971 lower.len() == theta0.len() && upper.len() == theta0.len(),
3972 "spatial hyperparameter bounds must match theta length: lower_len={}, upper_len={}, theta_len={}",
3973 lower.len(),
3974 upper.len(),
3975 theta0.len()
3976 );
3977 assert!(
3978 baseline_design.smooth.terms.len() >= spatial_terms.len(),
3979 "baseline design must have at least one smooth term per spatial term: baseline_terms={}, spatial_terms={}",
3980 baseline_design.smooth.terms.len(),
3981 spatial_terms.len()
3982 );
3983 use gam_solve::rho_optimizer::OuterEvalOrder;
3984 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
3985
3986 let theta_dim = theta0.len();
3987 let coord_dim = theta_dim - rho_dim;
3990 let analytic_outer_hessian_available =
4000 exact_joint_spatial_outer_hessian_available(&family, baseline_design);
4001 if !analytic_outer_hessian_available {
4002 log::info!(
4003 "[{label}] analytic outer Hessian unavailable for family/design; routing without second-order geometry (coord_dim={coord_dim})"
4004 );
4005 }
4006 let mut prefer_gradient_only = theta_dim > EXACT_JOINT_SECOND_ORDER_THETA_CAP;
4012 if prefer_gradient_only {
4013 log::info!(
4014 "[{label}] joint θ-dim {theta_dim} exceeds the exact pair-Hessian budget \
4015 ({EXACT_JOINT_SECOND_ORDER_THETA_CAP}); routing gradient-only quasi-Newton"
4016 );
4017 }
4018 let mut suppress_outer_hessian_for_nfree = false;
4028
4029 log::trace!(
4030 "[{}] starting analytic optimization: rho_dim={}, coord_dim={}, dims_per_term={:?}",
4031 label,
4032 rho_dim,
4033 coord_dim,
4034 dims_per_term,
4035 );
4036
4037 let mut ctx = SpatialJointContext {
4038 data,
4039 rho_dim,
4040 kind,
4041 cache: SingleBlockExactJointDesignCache::new(
4042 data,
4043 resolvedspec.clone(),
4044 baseline_design.clone(),
4045 spatial_terms.to_vec(),
4046 rho_dim,
4047 dims_per_term.to_vec(),
4048 )
4049 .map_err(EstimationError::InvalidInput)?,
4050 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
4051 y,
4052 weights,
4053 &baseline_design.design,
4054 offset,
4055 &baseline_design.penalties,
4056 &external_opts_for_design(&family, baseline_design, options),
4057 label,
4058 )?,
4059 frozen_glm_inputs: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
4060 Some(SpatialFrozenGlmInputs {
4061 y: y.to_owned(),
4062 weights: weights.to_owned(),
4063 offset: offset.to_owned(),
4064 family: family.clone(),
4065 })
4066 } else {
4067 None
4068 },
4069 frozen_glm_psi_bounds: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
4070 Some((lower[rho_dim], upper[rho_dim]))
4071 } else {
4072 None
4073 },
4074 frozen_glm_tensor: None,
4075 frozen_glm_tensor_attempted: false,
4076 frozen_glm_weight_memo: None,
4077 };
4078
4079 let mut psi_rank_stable_floor: Option<f64> = None;
4102 let mut psi_rank_stable_ceiling: Option<f64> = None;
4111 let nfree_penalty_capable = coord_dim == 1
4112 && family.is_gaussian_identity()
4113 && ctx.cache.supports_nfree_penalty_rekey();
4114 if nfree_penalty_capable {
4115 let psi_lo = lower[rho_dim];
4116 let psi_hi = upper[rho_dim];
4117 let z = Array1::from_iter(y.iter().zip(offset.iter()).map(|(yi, oi)| yi - oi));
4118 let theta_probe_base = theta0.clone();
4119 let SpatialJointContext {
4122 cache, evaluator, ..
4123 } = &mut ctx;
4124 let attached = evaluator.build_and_set_psi_gram_tensor(
4125 |psi| {
4126 let mut theta_probe = theta_probe_base.clone();
4127 theta_probe[rho_dim] = psi;
4128 cache.ensure_theta(&theta_probe)?;
4129 Ok(cache.design().design.clone())
4130 },
4131 weights,
4132 z.view(),
4133 psi_lo,
4134 psi_hi,
4135 );
4136 if attached {
4137 log::info!(
4138 "[{label}] certified ψ-gram tensor over [{psi_lo:.3}, {psi_hi:.3}]: \
4139 in-window trials assemble Gaussian sufficient statistics n-free"
4140 );
4141 let psi_anchor = theta0[rho_dim];
4146 psi_rank_stable_floor = evaluator
4147 .psi_gram_rank_stable_floor(psi_anchor)
4148 .filter(|&f| f.is_finite() && f > psi_lo && f < psi_anchor);
4149 log::info!(
4150 "[KAPPA-PHASE-FLOOR] n_rows={} psi_lo={psi_lo:.6} psi_anchor={psi_anchor:.6} \
4151 rank_stable_floor={:?} lifted={}",
4152 data.nrows(),
4153 evaluator.psi_gram_rank_stable_floor(psi_anchor),
4154 psi_rank_stable_floor.is_some(),
4155 );
4156 if let Some(floor) = psi_rank_stable_floor {
4157 log::info!(
4158 "[{label}] rank-stable κ-floor ψ_floor={floor:.6} > window floor \
4159 ψ_lo={psi_lo:.6}: lifting the optimizer lower bound to keep every \
4160 in-window trial on the n-free design-realization skip (#1033). The \
4161 conditioned Gram is rank-deficient below ψ_floor (longest-length-scale \
4162 radial mode collapses into the nullspace), where the skip is soundly \
4163 refused; that band drifts with n via the sample-std standardization, \
4164 so this n-free k-space floor is the n-independent fix."
4165 );
4166 }
4167 psi_rank_stable_ceiling = evaluator
4176 .psi_gram_rank_stable_ceiling(psi_anchor)
4177 .filter(|&c| c.is_finite() && c < psi_hi && c > psi_anchor);
4178 log::info!(
4179 "[KAPPA-PHASE-CEIL] n_rows={} psi_hi={psi_hi:.6} psi_anchor={psi_anchor:.6} \
4180 rank_stable_ceiling={:?} clamped={}",
4181 data.nrows(),
4182 evaluator.psi_gram_rank_stable_ceiling(psi_anchor),
4183 psi_rank_stable_ceiling.is_some(),
4184 );
4185 if let Some(ceiling) = psi_rank_stable_ceiling {
4186 log::info!(
4187 "[{label}] rank-stable κ-ceiling ψ_ceil={ceiling:.6} < window ceiling \
4188 ψ_hi={psi_hi:.6}: clamping the optimizer upper bound to keep every \
4189 in-window trial on the n-free design-realization skip (#1033). The \
4190 conditioned Gram is rank-deficient above ψ_ceil (longest-frequency \
4191 radial mode goes collinear), where the skip is soundly refused; a \
4192 line-search overshoot there trips the O(n) reset_surface lane (and the \
4193 deficient pinning ψ it records resets the next in-band trial too)."
4194 );
4195 }
4196 let gradient_covers_full_window = evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4197 && evaluator.psi_gram_tensor_covers_gradient(psi_hi);
4198 if gradient_covers_full_window {
4199 log::info!(
4200 "[{label}] certified ψ-gram tensor gradient lane covers the full \
4201 optimizer window [{psi_lo:.3}, {psi_hi:.3}]"
4202 );
4203 } else {
4204 log::info!(
4205 "[{label}] ψ-gram tensor value lane certified, but the gradient lane \
4206 does not cover the full optimizer window [{psi_lo:.3}, {psi_hi:.3}]; \
4207 keeping exact streamed kappa routing"
4208 );
4209 }
4210 evaluator.set_supports_nfree_penalty_rekey(true);
4230 log::info!(
4231 "[{label}] exact n-free ψ-penalty re-key enabled over [{psi_lo:.3}, \
4232 {psi_hi:.3}]: in-window fast-path trials rebuild S(ψ) n-free from frozen \
4233 geometry (no reset_surface)"
4234 );
4235 } else {
4236 log::info!(
4237 "[{label}] ψ-gram tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]; \
4238 keeping the exact per-trial path"
4239 );
4240 }
4241 if attached
4262 && evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4263 && evaluator.psi_gram_tensor_covers_gradient(psi_hi)
4264 && evaluator.supports_nfree_penalty_rekey()
4265 && cache.supports_nfree_gradient_only_routing()
4266 {
4267 suppress_outer_hessian_for_nfree = true;
4268 prefer_gradient_only = true;
4269 log::info!(
4270 "[{label}] n-free Gaussian ψ-lane armed; suppressing the analytic outer \
4271 Hessian and routing gradient-only (BFGS) so the κ outer loop never realizes \
4272 the O(n) second-order slab — n-independent outer loop (#1033)"
4273 );
4274 }
4275 } else if coord_dim == 1 && family.is_gaussian_identity() {
4276 log::info!(
4277 "[{label}] exact n-free ψ-penalty re-key unavailable; skipping ψ-gram tensor \
4278 attachment so value, gradient, and Hessian remain on the same exact streamed \
4279 objective"
4280 );
4281 }
4282
4283 const OUTER_FD_AUDIT_MAX_N: usize = 4_000; const OUTER_FD_AUDIT_MAX_THETA_DIM: usize = 32; let n_total = data.nrows();
4311 let outer_fd_audit_eligible = log::log_enabled!(log::Level::Info) && analytic_outer_hessian_available && n_total <= OUTER_FD_AUDIT_MAX_N && theta_dim <= OUTER_FD_AUDIT_MAX_THETA_DIM; log::info!(
4316 "[OUTER-FD-AUDIT/spatial-exact-joint] gate eligible={outer_fd_audit_eligible} \
4317 analytic_grad={analytic_outer_hessian_available} n_total={n_total} \
4318 theta_dim={theta_dim} rho_dim={rho_dim} psi_dim={coord_dim}"
4319 );
4320 if outer_fd_audit_eligible {
4321 let audit = (|| -> Result<gam_solve::rho_optimizer::OuterGradientFdAudit, String> {
4323 let mut eval_at = |theta: &Array1<f64>,
4324 mode: gam_solve::estimate::reml::reml_outer_engine::EvalMode|
4325 -> Result<
4326 (
4327 f64,
4328 Array1<f64>,
4329 gam_problem::HessianResult,
4330 ),
4331 String,
4332 > {
4333 use gam_solve::estimate::reml::reml_outer_engine::EvalMode;
4334 let order = if matches!(mode, EvalMode::ValueGradientHessian) {
4335 OuterEvalOrder::ValueGradientHessian
4336 } else {
4337 OuterEvalOrder::Value
4338 };
4339 ctx.eval_full(theta, order, analytic_outer_hessian_available)
4340 .map_err(|e| format!("fd-audit eval_full: {e}"))
4341 };
4342 let rho_dim_audit = rho_dim;
4343 let label_fn = move |i: usize| -> String {
4344 if i < rho_dim_audit {
4345 format!("rho[{i}]")
4346 } else {
4347 format!("psi_kappa[{}]", i - rho_dim_audit)
4348 }
4349 };
4350 gam_solve::rho_optimizer::outer_gradient_fd_audit(
4351 theta0,
4353 1e-4,
4354 label_fn,
4355 &mut eval_at,
4356 )
4357 })();
4358 match audit {
4360 Ok(audit) => audit.log_verdict("spatial-exact-joint"),
4361 Err(e) => log::warn!("[OUTER-FD-AUDIT/spatial-exact-joint] skipped: {e}"),
4362 }
4363 }
4364
4365 let kphase_prime_order = if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4366 OuterEvalOrder::ValueGradientHessian
4367 } else {
4368 OuterEvalOrder::ValueAndGradient
4369 };
4370 let kphase_prime_start = std::time::Instant::now();
4371 drop(ctx.eval_full(theta0, kphase_prime_order, analytic_outer_hessian_available)?);
4372 log::info!(
4373 "[KAPPA-PHASE-PRIME] n_rows={} order={:?} elapsed_s={:.4} slow_path_resets_total={} design_revision={}",
4374 data.nrows(),
4375 kphase_prime_order,
4376 kphase_prime_start.elapsed().as_secs_f64(),
4377 ctx.evaluator.slow_path_reset_count(),
4378 ctx.cache.design_revision(),
4379 );
4380
4381 let kphase_cost_calls = std::cell::Cell::new(0usize);
4382 let kphase_eval_calls = std::cell::Cell::new(0usize);
4383 let kphase_efs_calls = std::cell::Cell::new(0usize);
4384 let kphase_cost_total_s = std::cell::Cell::new(0.0);
4385 let kphase_eval_total_s = std::cell::Cell::new(0.0);
4386 let kphase_efs_total_s = std::cell::Cell::new(0.0);
4387 let kphase_nfree_miss_shape = std::cell::Cell::new(0u64);
4388 let kphase_nfree_miss_value = std::cell::Cell::new(0u64);
4389 let kphase_nfree_miss_gradient = std::cell::Cell::new(0u64);
4390 let kphase_nfree_miss_penalty = std::cell::Cell::new(0u64);
4391 let kphase_nfree_miss_revision = std::cell::Cell::new(0u64);
4392 let kphase_nfree_miss_second_order = std::cell::Cell::new(0u64);
4393 let kphase_nfree_miss_other = std::cell::Cell::new(0u64);
4394 let kphase_optim_start = std::time::Instant::now();
4395 let kphase_log_kappa_dim = coord_dim;
4396 let kphase_slow_resets_start = ctx.evaluator.slow_path_reset_count();
4397 let kphase_design_revision_start = ctx.cache.design_revision();
4398
4399 let lower_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_floor {
4406 Some(floor) if coord_dim == 1 && floor > lower[rho_dim] => {
4407 let mut lifted = lower.clone();
4408 lifted[rho_dim] = floor;
4409 std::borrow::Cow::Owned(lifted)
4410 }
4411 _ => std::borrow::Cow::Borrowed(lower),
4412 };
4413 let lower = lower_effective.as_ref();
4414
4415 let upper_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_ceiling {
4423 Some(ceiling) if coord_dim == 1 && ceiling < upper[rho_dim] => {
4424 let mut clamped = upper.clone();
4425 clamped[rho_dim] = ceiling;
4426 std::borrow::Cow::Owned(clamped)
4427 }
4428 _ => std::borrow::Cow::Borrowed(upper),
4429 };
4430 let upper = upper_effective.as_ref();
4431
4432 let problem = exact_joint_multistart_outer_problem(
4433 theta0,
4434 lower,
4435 upper,
4436 rho_dim,
4437 coord_dim,
4438 theta_dim,
4439 Derivative::Analytic,
4440 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4441 DeclaredHessianForm::Either
4442 } else {
4443 DeclaredHessianForm::Unavailable
4448 },
4449 prefer_gradient_only,
4450 suppress_outer_hessian_for_nfree,
4461 seed_risk_profile_for_likelihood_family(&family),
4462 kappa_options.rel_tol.max(1e-6),
4463 kappa_options.max_outer_iter.max(1),
4464 Some(5.0),
4468 Some(kappa_options.log_step.clamp(0.25, 1.0)),
4470 None,
4471 Some((data.nrows(), baseline_design.design.ncols())),
4476 !constant_curvature_term_indices(resolvedspec).is_empty(),
4480 );
4481
4482 let eval_outer = |ctx: &mut &mut SpatialJointContext<'_>,
4483 theta: &Array1<f64>,
4484 order: OuterEvalOrder|
4485 -> Result<OuterEval, EstimationError> {
4486 let t0 = std::time::Instant::now();
4487 let allow_second_order_for_call = matches!(order, OuterEvalOrder::ValueGradientHessian)
4488 && analytic_outer_hessian_available;
4489 let gate = ctx.nfree_skip_gate_status(theta, allow_second_order_for_call, true);
4490 let resets_before = ctx.evaluator.slow_path_reset_count();
4491 let raw = ctx.eval_full(theta, order, analytic_outer_hessian_available);
4492 let reset_delta = ctx
4493 .evaluator
4494 .slow_path_reset_count()
4495 .saturating_sub(resets_before);
4496 if reset_delta > 0 {
4497 if !gate.shape {
4498 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4499 }
4500 if gate.shape && !gate.value {
4501 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4502 }
4503 if gate.shape && gate.value && !gate.gradient {
4504 kphase_nfree_miss_gradient.set(kphase_nfree_miss_gradient.get() + reset_delta);
4505 }
4506 if gate.shape && gate.value && gate.gradient && !gate.penalty {
4507 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4508 }
4509 if gate.shape && gate.value && gate.gradient && gate.penalty && !gate.revision {
4510 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4511 }
4512 if gate.shape
4513 && gate.value
4514 && gate.gradient
4515 && gate.penalty
4516 && gate.revision
4517 && gate.second_order
4518 {
4519 kphase_nfree_miss_second_order
4520 .set(kphase_nfree_miss_second_order.get() + reset_delta);
4521 }
4522 if gate.would_skip(true) {
4523 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4524 }
4525 }
4526 let elapsed_s = t0.elapsed().as_secs_f64();
4527 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
4528 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
4529 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4530 log::info!(
4531 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4532 kphase_eval_calls.get(),
4533 order,
4534 Some(ctx.cache.design_revision()),
4535 theta_norm,
4536 log_kappa_norm,
4537 elapsed_s,
4538 );
4539 match raw {
4540 Ok((cost, grad, hess)) => Ok(OuterEval {
4541 cost,
4542 gradient: grad,
4543 hessian: hess,
4544 inner_beta_hint: None,
4545 }),
4546 Err(err) if is_recoverable_trial_point_error(&err) => {
4554 log::debug!(
4555 "[{label}] trial point infeasible (kernel design \
4556 not constructible at theta={theta:?}): {err}; retreating",
4557 );
4558 Ok(OuterEval::infeasible(theta_dim))
4559 }
4560 Err(err) => Err(err),
4561 }
4562 };
4563
4564 let mut obj = problem.build_objective_with_eval_order(
4565 &mut ctx,
4566 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4567 let t0 = std::time::Instant::now();
4568 let gate = ctx.nfree_skip_gate_status(theta, false, false);
4569 let resets_before = ctx.evaluator.slow_path_reset_count();
4570 let cost = ctx.eval_cost(theta);
4571 let reset_delta = ctx
4572 .evaluator
4573 .slow_path_reset_count()
4574 .saturating_sub(resets_before);
4575 if reset_delta > 0 {
4576 if !gate.shape {
4577 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4578 }
4579 if gate.shape && !gate.value {
4580 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4581 }
4582 if gate.shape && gate.value && !gate.penalty {
4583 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4584 }
4585 if gate.shape && gate.value && gate.penalty && !gate.revision {
4586 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4587 }
4588 if gate.would_skip(false) {
4589 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4590 }
4591 }
4592 let elapsed_s = t0.elapsed().as_secs_f64();
4593 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
4594 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
4595 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4596 log::info!(
4597 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4598 kphase_cost_calls.get(),
4599 Some(ctx.cache.design_revision()),
4600 theta_norm,
4601 log_kappa_norm,
4602 elapsed_s,
4603 );
4604 Ok(cost)
4605 },
4606 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4607 eval_outer(
4608 ctx,
4609 theta,
4610 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4620 OuterEvalOrder::ValueGradientHessian
4621 } else {
4622 OuterEvalOrder::ValueAndGradient
4623 },
4624 )
4625 },
4626 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
4627 eval_outer(ctx, theta, order)
4628 },
4629 Some(|ctx: &mut &mut SpatialJointContext<'_>| {
4630 ctx.reset();
4631 }),
4632 Some(|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4633 let t0 = std::time::Instant::now();
4634 let eval = ctx.eval_efs(theta);
4635 let elapsed_s = t0.elapsed().as_secs_f64();
4636 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
4637 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
4638 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4639 log::info!(
4640 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4641 kphase_efs_calls.get(),
4642 Some(ctx.cache.design_revision()),
4643 theta_norm,
4644 log_kappa_norm,
4645 elapsed_s,
4646 );
4647 eval
4648 }),
4649 );
4650
4651 let run_label = match kind {
4652 SpatialHyperKind::Anisotropic => "aniso-psi joint REML",
4653 SpatialHyperKind::Isotropic => "iso-kappa joint REML",
4654 };
4655 let result = problem.run(&mut obj, run_label).map_err(|e| {
4656 EstimationError::InvalidInput(format!(
4657 "{} analytic optimization failed after exhausting strategy fallbacks: {e}",
4658 kind.adjective(),
4659 ))
4660 })?;
4661 drop(obj);
4662 let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
4663 let kphase_slow_resets = ctx
4664 .evaluator
4665 .slow_path_reset_count()
4666 .saturating_sub(kphase_slow_resets_start);
4667 let kphase_design_revision_delta = ctx
4668 .cache
4669 .design_revision()
4670 .saturating_sub(kphase_design_revision_start);
4671 log::info!(
4672 "[KAPPA-PHASE-SUMMARY] n_rows={} log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} slow_path_resets={} design_revision_delta={} nfree_miss_shape={} nfree_miss_value={} nfree_miss_gradient={} nfree_miss_penalty={} nfree_miss_revision={} nfree_miss_second_order={} nfree_miss_other={} optim_total_s={:.4}",
4673 data.nrows(),
4674 kphase_log_kappa_dim,
4675 kphase_cost_calls.get(),
4676 kphase_cost_total_s.get(),
4677 kphase_eval_calls.get(),
4678 kphase_eval_total_s.get(),
4679 kphase_efs_calls.get(),
4680 kphase_efs_total_s.get(),
4681 kphase_slow_resets,
4682 kphase_design_revision_delta,
4683 kphase_nfree_miss_shape.get(),
4684 kphase_nfree_miss_value.get(),
4685 kphase_nfree_miss_gradient.get(),
4686 kphase_nfree_miss_penalty.get(),
4687 kphase_nfree_miss_revision.get(),
4688 kphase_nfree_miss_second_order.get(),
4689 kphase_nfree_miss_other.get(),
4690 kphase_total_s,
4691 );
4692 let timing = SpatialLengthScaleOptimizationTiming {
4693 log_kappa_dim: kphase_log_kappa_dim,
4694 cost_calls: kphase_cost_calls.get(),
4695 cost_total_s: kphase_cost_total_s.get(),
4696 eval_calls: kphase_eval_calls.get(),
4697 eval_total_s: kphase_eval_total_s.get(),
4698 efs_calls: kphase_efs_calls.get(),
4699 efs_total_s: kphase_efs_total_s.get(),
4700 slow_path_resets: kphase_slow_resets,
4701 design_revision_delta: kphase_design_revision_delta,
4702 nfree_miss_shape: kphase_nfree_miss_shape.get(),
4703 nfree_miss_value: kphase_nfree_miss_value.get(),
4704 nfree_miss_gradient: kphase_nfree_miss_gradient.get(),
4705 nfree_miss_penalty: kphase_nfree_miss_penalty.get(),
4706 nfree_miss_revision: kphase_nfree_miss_revision.get(),
4707 nfree_miss_second_order: kphase_nfree_miss_second_order.get(),
4708 nfree_miss_other: kphase_nfree_miss_other.get(),
4709 optim_total_s: kphase_total_s,
4710 };
4711 if !result.converged {
4712 let rel_to_cost_threshold = options.tol * (1.0_f64 + result.final_value.abs());
4723 if let Some(final_grad) = result
4724 .final_grad_norm
4725 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
4726 {
4727 log::info!(
4728 "[{}] outer optimization hit max_iter={} but \
4729 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
4730 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
4731 relative-to-cost REML convergence criterion.",
4732 label,
4733 result.iterations,
4734 final_grad,
4735 rel_to_cost_threshold,
4736 options.tol,
4737 result.final_value.abs(),
4738 );
4739 } else if result.final_value.is_finite() {
4740 log::warn!(
4755 "[{}] {} did not converge after {} iterations \
4756 (final_objective={:.6e}, final_grad_norm={}); keeping the \
4757 frozen baseline geometry instead of aborting the fit.",
4758 label,
4759 kind.adjective(),
4760 result.iterations,
4761 result.final_value,
4762 result.final_grad_norm_report(),
4763 );
4764 return Ok((
4765 SpatialJointOutcome::NonConverged {
4766 iterations: result.iterations,
4767 final_value: result.final_value,
4768 final_grad_norm: result.final_grad_norm,
4769 },
4770 timing,
4771 ));
4772 } else {
4773 crate::bail_invalid_estim!(
4778 "{} analytic optimization diverged after {} iterations (final_objective={:.6e}, final_grad_norm={})",
4779 kind.adjective(),
4780 result.iterations,
4781 result.final_value,
4782 result.final_grad_norm_report(),
4783 );
4784 }
4785 }
4786 log::trace!(
4787 "[{}] converged in {} iterations, final_value={:.6e}, grad_norm={}",
4788 label,
4789 result.iterations,
4790 result.final_value,
4791 result.final_grad_norm_report(),
4792 );
4793 let theta_star = result.rho;
4797 Ok((
4798 SpatialJointOutcome::Optimized {
4799 theta_star,
4800 final_value: result.final_value,
4801 },
4802 timing,
4803 ))
4804}
4805
4806fn set_single_term_spatial_length_scale(
4810 term: &mut SmoothTermSpec,
4811 length_scale: f64,
4812) -> Result<(), EstimationError> {
4813 match &mut term.basis {
4814 SmoothBasisSpec::ThinPlate { spec, .. } => {
4815 spec.length_scale = length_scale;
4816 Ok(())
4817 }
4818 SmoothBasisSpec::Matern { spec, .. } => {
4819 spec.length_scale = length_scale;
4820 Ok(())
4821 }
4822 SmoothBasisSpec::Duchon { spec, .. } => {
4823 spec.length_scale = Some(length_scale);
4824 Ok(())
4825 }
4826 _ => Err(EstimationError::InvalidInput(format!(
4827 "term '{}' does not expose a spatial length scale",
4828 term.name
4829 ))),
4830 }
4831}
4832
4833fn set_single_term_spatial_aniso_log_scales(
4837 term: &mut SmoothTermSpec,
4838 eta: Vec<f64>,
4839) -> Result<(), EstimationError> {
4840 let eta = center_aniso_log_scales(&eta);
4841 match &mut term.basis {
4842 SmoothBasisSpec::Matern { spec, .. } => {
4843 spec.aniso_log_scales = Some(eta);
4844 Ok(())
4845 }
4846 SmoothBasisSpec::Duchon { spec, .. } => {
4847 spec.aniso_log_scales = Some(eta);
4848 Ok(())
4849 }
4850 _ => Err(EstimationError::InvalidInput(format!(
4851 "term '{}' does not support aniso_log_scales",
4852 term.name
4853 ))),
4854 }
4855}
4856
4857pub fn get_constant_curvature_kappa(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
4876 constant_curvature_term_spec(spec, term_idx).map(|cc| cc.kappa)
4877}
4878
4879pub fn constant_curvature_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
4881 (0..spec.smooth_terms.len())
4882 .filter(|&idx| constant_curvature_term_spec(spec, idx).is_some())
4883 .collect()
4884}
4885
4886
4887#[derive(Debug, Clone)]
4888struct SingleSmoothTermRealization {
4889 design_local: DesignMatrix,
4890 term: SmoothTerm,
4891 dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
4892}
4893
4894impl SingleSmoothTermRealization {
4895 fn active_penaltyinfo(&self) -> Vec<PenaltyInfo> {
4896 self.term
4897 .penaltyinfo_local
4898 .iter()
4899 .filter(|info| info.active)
4900 .cloned()
4901 .collect()
4902 }
4903}
4904
4905fn build_single_smooth_term_realization(
4906 data: ArrayView2<'_, f64>,
4907 termspec: &SmoothTermSpec,
4908) -> Result<SingleSmoothTermRealization, BasisError> {
4909 let raw = build_smooth_design(data, std::slice::from_ref(termspec))?;
4910 finish_single_smooth_term_realization(raw)
4911}
4912
4913fn finish_single_smooth_term_realization(
4914 raw: RawSmoothDesign,
4915) -> Result<SingleSmoothTermRealization, BasisError> {
4916 let RawSmoothDesign {
4917 term_designs,
4918 dropped_penaltyinfo,
4919 terms,
4920 ..
4921 } = raw;
4922 let term = terms.into_iter().next().ok_or_else(|| {
4923 BasisError::InvalidInput("single-term smooth build returned no term".to_string())
4924 })?;
4925 let design = term_designs.into_iter().next().ok_or_else(|| {
4926 BasisError::InvalidInput("single-term smooth build returned no term design".to_string())
4927 })?;
4928
4929 Ok(SingleSmoothTermRealization {
4930 design_local: design,
4931 term,
4932 dropped_penaltyinfo,
4933 })
4934}
4935
4936fn wrap_local_build_as_realization(
4943 mut local: LocalSmoothTermBuild,
4944 termspec: &SmoothTermSpec,
4945) -> Result<SingleSmoothTermRealization, String> {
4946 let p_local = local.dim;
4947 let lb_local = if local.box_reparam {
4948 shape_lower_bounds_local(termspec.shape, p_local)
4949 } else {
4950 None
4951 };
4952
4953 let active_count = local.penaltyinfo.iter().filter(|info| info.active).count();
4954 if active_count != local.penalties.len() {
4955 return Err(format!(
4956 "internal penalty info mismatch for term '{}': active_infos={}, penalties={}",
4957 termspec.name,
4958 active_count,
4959 local.penalties.len()
4960 ));
4961 }
4962
4963 let mut dropped_penaltyinfo = Vec::<DroppedPenaltyBlockInfo>::new();
4964 for info in local.penaltyinfo.iter().filter(|info| !info.active) {
4965 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
4966 termname: Some(termspec.name.clone()),
4967 penalty: info.clone(),
4968 });
4969 }
4970 for info in &local.pre_dropped_penaltyinfo {
4971 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
4972 termname: Some(termspec.name.clone()),
4973 penalty: info.clone(),
4974 });
4975 }
4976
4977 let applied_rotation: Option<gam_terms::basis::JointNullRotation> = match (
4981 local.joint_null_rotation.take(),
4982 lb_local.is_some(),
4983 local.linear_constraints.is_some(),
4984 ) {
4985 (Some(rot), false, false) => {
4986 let q = &rot.rotation;
4987 let dense = local
4988 .design
4989 .try_to_dense_by_chunks("joint-null absorption rotation (single realization)")
4990 .map_err(|e| {
4991 format!(
4992 "joint-null absorption rotation: dense conversion failed for term '{}': {}",
4993 termspec.name, e
4994 )
4995 })?;
4996 let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
4997 local.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
4998 local.penalties = local
4999 .penalties
5000 .into_iter()
5001 .map(|s_local| {
5002 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
5003 gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
5004 })
5005 .collect();
5006 local.ops = vec![None; local.penalties.len()];
5007 local.kronecker_factored = None;
5008 Some(rot)
5009 }
5010 (Some(_), _, _) => None,
5011 (None, _, _) => None,
5012 };
5013
5014 let smooth_term = SmoothTerm {
5015 name: termspec.name.clone(),
5016 coeff_range: 0..p_local,
5017 shape: termspec.shape,
5018 penalties_local: local.penalties.clone(),
5019 nullspace_dims: local.nullspaces.clone(),
5020 penaltyinfo_local: local.penaltyinfo.clone(),
5021 metadata: local.metadata.clone(),
5022 lower_bounds_local: lb_local,
5023 linear_constraints_local: local.linear_constraints.clone(),
5024 kronecker_factored: local.kronecker_factored.take(),
5025 joint_null_rotation: applied_rotation,
5026 unabsorbed_global_orthogonality: None,
5029 };
5030
5031 Ok(SingleSmoothTermRealization {
5032 design_local: local.design,
5033 term: smooth_term,
5034 dropped_penaltyinfo,
5035 })
5036}
5037
5038fn freeze_geometry_from_metadata(
5049 termspec: &SmoothTermSpec,
5050 metadata: &BasisMetadata,
5051) -> Option<SmoothTermSpec> {
5052 let mut frozen = termspec.clone();
5053 match (&mut frozen.basis, metadata) {
5054 (
5055 SmoothBasisSpec::Matern {
5056 spec,
5057 input_scales: spec_scales,
5058 ..
5059 },
5060 BasisMetadata::Matern {
5061 centers,
5062 input_scales: meta_scales,
5063 identifiability_transform,
5064 nullspace_shrinkage_survived,
5065 ..
5066 },
5067 ) => {
5068 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5069 if spec_scales.is_none()
5070 && let Some(s) = meta_scales.clone()
5071 {
5072 *spec_scales = Some(s);
5073 }
5074 if let Some(transform) = identifiability_transform.clone() {
5092 spec.identifiability = MaternIdentifiability::FrozenTransform {
5093 transform,
5094 nullspace_shrinkage_survived: Some(*nullspace_shrinkage_survived),
5095 };
5096 }
5097 Some(frozen)
5098 }
5099 (
5100 SmoothBasisSpec::Duchon {
5101 spec,
5102 input_scales: spec_scales,
5103 ..
5104 },
5105 BasisMetadata::Duchon {
5106 centers,
5107 input_scales: meta_scales,
5108 ..
5109 },
5110 ) => {
5111 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5112 if spec_scales.is_none()
5113 && let Some(s) = meta_scales.clone()
5114 {
5115 *spec_scales = Some(s);
5116 }
5117 Some(frozen)
5118 }
5119 (
5120 SmoothBasisSpec::ThinPlate {
5121 spec,
5122 input_scales: spec_scales,
5123 ..
5124 },
5125 BasisMetadata::ThinPlate {
5126 centers,
5127 input_scales: meta_scales,
5128 ..
5129 },
5130 ) => {
5131 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
5132 if spec_scales.is_none()
5133 && let Some(s) = meta_scales.clone()
5134 {
5135 *spec_scales = Some(s);
5136 }
5137 Some(frozen)
5138 }
5139 _ => None,
5142 }
5143}
5144
5145fn rebuild_smooth_auxiliary_state(
5146 smooth: &mut SmoothDesign,
5147 dropped_penaltyinfo_by_term: &[Vec<DroppedPenaltyBlockInfo>],
5148) -> Result<(), String> {
5149 if dropped_penaltyinfo_by_term.len() != smooth.terms.len() {
5150 return Err(SmoothError::dimension_mismatch(format!(
5151 "smooth dropped-penalty cache mismatch: terms={}, dropped_sets={}",
5152 smooth.terms.len(),
5153 dropped_penaltyinfo_by_term.len()
5154 ))
5155 .into());
5156 }
5157
5158 let total_p = smooth.total_smooth_cols();
5159 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
5160 let mut any_bounds = false;
5161 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5162 let mut linear_constraint_b: Vec<f64> = Vec::new();
5163
5164 for term in &smooth.terms {
5165 let range = term.coeff_range.clone();
5166 if let Some(lb_local) = term.lower_bounds_local.as_ref() {
5167 if lb_local.len() != range.len() {
5168 return Err(SmoothError::dimension_mismatch(format!(
5169 "smooth lower-bound cache mismatch for term '{}': bounds={}, coeffs={}",
5170 term.name,
5171 lb_local.len(),
5172 range.len()
5173 ))
5174 .into());
5175 }
5176 coefficient_lower_bounds
5177 .slice_mut(s![range.clone()])
5178 .assign(lb_local);
5179 any_bounds = true;
5180 }
5181 if let Some(lin_local) = term.linear_constraints_local.as_ref() {
5182 if lin_local.a.ncols() != range.len() {
5183 return Err(SmoothError::dimension_mismatch(format!(
5184 "smooth linear-constraint cache mismatch for term '{}': cols={}, coeffs={}",
5185 term.name,
5186 lin_local.a.ncols(),
5187 range.len()
5188 ))
5189 .into());
5190 }
5191 for r in 0..lin_local.a.nrows() {
5192 let mut row = Array1::<f64>::zeros(total_p);
5193 row.slice_mut(s![range.clone()]).assign(&lin_local.a.row(r));
5194 linear_constraintrows.push(row);
5195 linear_constraint_b.push(lin_local.b[r]);
5196 }
5197 }
5198 }
5199
5200 smooth.coefficient_lower_bounds = if any_bounds {
5201 Some(coefficient_lower_bounds)
5202 } else {
5203 None
5204 };
5205 smooth.linear_constraints = if linear_constraintrows.is_empty() {
5206 None
5207 } else {
5208 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), total_p));
5209 for (i, row) in linear_constraintrows.iter().enumerate() {
5210 a.row_mut(i).assign(row);
5211 }
5212 Some(LinearInequalityConstraints {
5213 a,
5214 b: Array1::from_vec(linear_constraint_b),
5215 })
5216 };
5217 smooth.dropped_penaltyinfo = dropped_penaltyinfo_by_term
5218 .iter()
5219 .flat_map(|infos| infos.iter().cloned())
5220 .collect();
5221 Ok(())
5222}
5223
5224fn rebuild_term_collection_auxiliary_state(
5225 spec: &TermCollectionSpec,
5226 design: &mut TermCollectionDesign,
5227) -> Result<(), String> {
5228 if spec.linear_terms.len() != design.linear_ranges.len() {
5229 return Err(SmoothError::dimension_mismatch(format!(
5230 "term-collection linear bookkeeping mismatch: spec_terms={}, design_ranges={}",
5231 spec.linear_terms.len(),
5232 design.linear_ranges.len()
5233 ))
5234 .into());
5235 }
5236
5237 let p_total = design.design.ncols();
5238 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
5239 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
5240 let mut any_bounds = false;
5241 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5242 let mut linear_constraint_b: Vec<f64> = Vec::new();
5243
5244 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
5245 if range.len() != 1 {
5246 return Err(SmoothError::dimension_mismatch(format!(
5247 "linear term '{}' expected one coefficient column, found {}",
5248 linear.name,
5249 range.len()
5250 ))
5251 .into());
5252 }
5253 let col = range.start;
5254 if let Some(lb) = linear.coefficient_min {
5255 let mut row = Array1::<f64>::zeros(p_total);
5256 row[col] = 1.0;
5257 linear_constraintrows.push(row);
5258 linear_constraint_b.push(lb);
5259 }
5260 if let Some(ub) = linear.coefficient_max {
5261 let mut row = Array1::<f64>::zeros(p_total);
5262 row[col] = -1.0;
5263 linear_constraintrows.push(row);
5264 linear_constraint_b.push(-ub);
5265 }
5266 }
5267
5268 if let Some(lb_smooth) = design.smooth.coefficient_lower_bounds.as_ref() {
5269 if lb_smooth.len() != design.smooth.total_smooth_cols() {
5270 return Err(SmoothError::dimension_mismatch(format!(
5271 "smooth lower-bound width mismatch: bounds={}, smooth_cols={}",
5272 lb_smooth.len(),
5273 design.smooth.total_smooth_cols()
5274 ))
5275 .into());
5276 }
5277 coefficient_lower_bounds
5278 .slice_mut(s![
5279 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5280 ])
5281 .assign(lb_smooth);
5282 any_bounds = true;
5283 }
5284 if let Some(lin_smooth) = design.smooth.linear_constraints.as_ref() {
5285 if lin_smooth.a.ncols() != design.smooth.total_smooth_cols() {
5286 return Err(SmoothError::dimension_mismatch(format!(
5287 "smooth linear-constraint width mismatch: cols={}, smooth_cols={}",
5288 lin_smooth.a.ncols(),
5289 design.smooth.total_smooth_cols()
5290 ))
5291 .into());
5292 }
5293 let mut a_global = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
5294 a_global
5295 .slice_mut(s![
5296 ..,
5297 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5298 ])
5299 .assign(&lin_smooth.a);
5300 for r in 0..a_global.nrows() {
5301 linear_constraintrows.push(a_global.row(r).to_owned());
5302 linear_constraint_b.push(lin_smooth.b[r]);
5303 }
5304 }
5305
5306 let lower_bound_constraints = if any_bounds {
5307 linear_constraints_from_lower_bounds_global(&coefficient_lower_bounds)
5308 } else {
5309 None
5310 };
5311 let explicit_linear_constraints = if linear_constraintrows.is_empty() {
5312 None
5313 } else {
5314 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), p_total));
5315 for (i, row) in linear_constraintrows.iter().enumerate() {
5316 a.row_mut(i).assign(row);
5317 }
5318 Some(LinearInequalityConstraints {
5319 a,
5320 b: Array1::from_vec(linear_constraint_b),
5321 })
5322 };
5323
5324 design.coefficient_lower_bounds = if any_bounds {
5325 Some(coefficient_lower_bounds)
5326 } else {
5327 None
5328 };
5329 design.linear_constraints =
5330 merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
5331 design.dropped_penaltyinfo = design.smooth.dropped_penaltyinfo.clone();
5332 Ok(())
5333}
5334
5335fn theta_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5336 left.len() == right.len()
5337 && left
5338 .iter()
5339 .zip(right.iter())
5340 .all(|(&l, &r)| l.to_bits() == r.to_bits())
5341}
5342
5343fn latent_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5344 theta_values_match(left, right)
5345}
5346
/// Compares two optional anisotropy log-scale vectors bit for bit.
///
/// Two absent vectors match. An absent vector never matches a present one, even an empty one.
/// Present vectors match when they have equal lengths and identical IEEE-754 bit patterns
/// element by element.
fn spatial_aniso_matches(left: Option<&[f64]>, right: Option<&[f64]>) -> bool {
    let Some((lhs, rhs)) = left.zip(right) else {
        // At least one side is absent: equal only if both are.
        return left.is_none() && right.is_none();
    };
    lhs.len() == rhs.len()
        && lhs
            .iter()
            .map(|x| x.to_bits())
            .eq(rhs.iter().map(|y| y.to_bits()))
}
5359
/// Compares two optional length-scales bit for bit.
///
/// Both absent match. An absent value never matches a present one. Present values match only
/// when their IEEE-754 bit patterns are identical, so `0.0` and `-0.0` differ.
fn spatial_length_scale_matches(left: Option<f64>, right: Option<f64>) -> bool {
    left.map(f64::to_bits) == right.map(f64::to_bits)
}
5367
/// Wraps a realized term collection so that spatial hyperparameters can be changed by
/// rebuilding only the affected smooth terms, instead of re-realizing the whole design.
struct FrozenTermCollectionIncrementalRealizer<'d> {
    /// Borrowed data matrix; smooth terms are rebuilt from it.
    data: ArrayView2<'d, f64>,
    /// Live term-collection spec; its smooth-term hyperparameters are updated in place when
    /// terms are rebuilt.
    spec: TermCollectionSpec,
    /// Current realized design, patched term by term.
    design: TermCollectionDesign,
    /// Non-smooth design blocks (intercept, linear block, random effects), built once in `new`.
    /// They are reused every time the full design operator is reassembled.
    fixed_blocks: Vec<DesignBlock>,
    /// Dropped-penalty info for each smooth term, in term order.
    dropped_penaltyinfo_by_term: Vec<Vec<DroppedPenaltyBlockInfo>>,
    /// For each smooth term, its index range into `design.smooth.penalties`.
    smooth_penalty_ranges: Vec<Range<usize>>,
    /// For each smooth term, its index range into `design.penalties`. These are the smooth
    /// ranges shifted past the non-smooth penalties that precede them.
    full_penalty_ranges: Vec<Range<usize>>,
    /// Scratch workspace shared by all basis and penalty builds.
    basisworkspace: gam_terms::basis::BasisWorkspace,
    /// For each smooth term, a spec whose geometry was frozen from that term's first
    /// incremental rebuild; `None` until the term is first rebuilt.
    spatial_realization_geometry: Vec<Option<SmoothTermSpec>>,
    /// Counter bumped (with wrapping) each time `apply_log_kappa` changes the design.
    design_revision: u64,
}
5400
5401impl<'d> std::fmt::Debug for FrozenTermCollectionIncrementalRealizer<'d> {
5402 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5403 f.debug_struct("FrozenTermCollectionIncrementalRealizer")
5404 .field("data_shape", &(self.data.nrows(), self.data.ncols()))
5405 .field("fixed_blocks", &self.fixed_blocks.len())
5406 .finish_non_exhaustive()
5407 }
5408}
5409
5410impl<'d> FrozenTermCollectionIncrementalRealizer<'d> {
5411 fn new(
5412 data: ArrayView2<'d, f64>,
5413 spec: TermCollectionSpec,
5414 design: TermCollectionDesign,
5415 ) -> Result<Self, String> {
5416 if spec.smooth_terms.len() != design.smooth.terms.len() {
5417 return Err(SmoothError::dimension_mismatch(format!(
5418 "incremental realizer smooth term mismatch: spec_terms={}, design_terms={}",
5419 spec.smooth_terms.len(),
5420 design.smooth.terms.len()
5421 ))
5422 .into());
5423 }
5424
5425 let mut smooth_cursor = 0usize;
5426 let mut smooth_penalty_ranges = Vec::with_capacity(design.smooth.terms.len());
5427 for term in &design.smooth.terms {
5428 let next = smooth_cursor + term.penalties_local.len();
5429 smooth_penalty_ranges.push(smooth_cursor..next);
5430 smooth_cursor = next;
5431 }
5432 if smooth_cursor != design.smooth.penalties.len() {
5433 return Err(SmoothError::dimension_mismatch(format!(
5434 "incremental realizer smooth penalty mismatch: ranged={}, actual={}",
5435 smooth_cursor,
5436 design.smooth.penalties.len()
5437 ))
5438 .into());
5439 }
5440
5441 let fixed_penalty_offset = design
5442 .penalties
5443 .len()
5444 .checked_sub(design.smooth.penalties.len())
5445 .ok_or_else(|| {
5446 "incremental realizer encountered invalid penalty bookkeeping".to_string()
5447 })?;
5448 let full_penalty_ranges = smooth_penalty_ranges
5449 .iter()
5450 .map(|range| (fixed_penalty_offset + range.start)..(fixed_penalty_offset + range.end))
5451 .collect::<Vec<_>>();
5452 let fixed_blocks = build_term_collection_fixed_blocks(data, &spec)
5453 .map_err(|e| format!("failed to cache fixed term-collection blocks: {e}"))?;
5454
5455 let mut dropped_penaltyinfo_by_term = Vec::with_capacity(spec.smooth_terms.len());
5456 for (term_idx, termspec) in spec.smooth_terms.iter().enumerate() {
5457 let realization =
5458 build_single_smooth_term_realization(data, termspec).map_err(|e| {
5459 format!(
5460 "failed to build cached realization for smooth term '{}' (index {}): {e}",
5461 termspec.name, term_idx
5462 )
5463 })?;
5464 let expected_cols = design.smooth.terms[term_idx].coeff_range.len();
5465 if realization.design_local.ncols() != expected_cols {
5466 return Err(SmoothError::dimension_mismatch(format!(
5467 "cached realization width mismatch for term '{}': cached_cols={}, design_cols={}",
5468 termspec.name,
5469 realization.design_local.ncols(),
5470 expected_cols
5471 ))
5472 .into());
5473 }
5474 if realization.active_penaltyinfo().len()
5475 != design.smooth.terms[term_idx].penalties_local.len()
5476 {
5477 return Err(SmoothError::dimension_mismatch(format!(
5478 "cached realization penalty mismatch for term '{}': cached_penalties={}, design_penalties={}",
5479 termspec.name,
5480 realization.active_penaltyinfo().len(),
5481 design.smooth.terms[term_idx].penalties_local.len()
5482 ))
5483 .into());
5484 }
5485 dropped_penaltyinfo_by_term.push(realization.dropped_penaltyinfo);
5486 }
5487
5488 let geometry_slots = spec.smooth_terms.len();
5489 Ok(Self {
5490 data,
5491 spec,
5492 design,
5493 fixed_blocks,
5494 dropped_penaltyinfo_by_term,
5495 smooth_penalty_ranges,
5496 full_penalty_ranges,
5497 basisworkspace: gam_terms::basis::BasisWorkspace::new(),
5498 spatial_realization_geometry: vec![None; geometry_slots],
5499 design_revision: 0,
5500 })
5501 }
5502
5503 fn design_revision(&self) -> u64 {
5504 self.design_revision
5505 }
5506
5507 fn spec(&self) -> &TermCollectionSpec {
5508 &self.spec
5509 }
5510
5511 fn design(&self) -> &TermCollectionDesign {
5512 &self.design
5513 }
5514
5515 fn supports_nfree_penalty_rekey(&self, spatial_terms: &[usize]) -> bool {
5530 if spatial_terms.len() != 1 {
5531 return false;
5532 }
5533 let term_idx = spatial_terms[0];
5534 matches!(
5535 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5536 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5537 )
5538 }
5539
5540 fn supports_nfree_gradient_only_routing(&self, spatial_terms: &[usize]) -> bool {
5549 if spatial_terms.len() != 1 {
5550 return false;
5551 }
5552 let term_idx = spatial_terms[0];
5553 matches!(
5554 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5555 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5556 )
5557 }
5558
5559 fn canonical_penalties_at_psi(
5572 &mut self,
5573 spatial_terms: &[usize],
5574 psi: &[f64],
5575 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
5576 if spatial_terms.len() != 1 {
5577 return Err(format!(
5578 "n-free penalty re-key requires exactly one spatial term, found {}",
5579 spatial_terms.len()
5580 ));
5581 }
5582 let term_idx = spatial_terms[0];
5583 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
5589 let termspec =
5592 self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
5593 format!("spatial term {term_idx} out of range for n-free penalty")
5594 })?;
5595 let term = self
5596 .design
5597 .smooth
5598 .terms
5599 .get(term_idx)
5600 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
5601 let p_total = self.design.design.ncols();
5604 let (locals, nullspace_dims): (Vec<Array2<f64>>, Vec<usize>) = match &term.metadata {
5605 BasisMetadata::Duchon {
5606 centers,
5607 identifiability_transform,
5608 operator_collocation_points,
5609 power,
5610 nullspace_order,
5611 aniso_log_scales,
5612 input_scales,
5613 radial_reparam,
5614 ..
5615 } => {
5616 let operator_penalties = match &termspec.basis {
5617 SmoothBasisSpec::Duchon { spec, .. } => spec.operator_penalties.clone(),
5618 _ => gam_terms::basis::DuchonOperatorPenaltySpec::default(),
5619 };
5620 let effective_ls = match input_scales.as_deref() {
5627 Some(scales) => {
5628 compensate_optional_length_scale_for_standardization(ls_opt, scales)
5629 }
5630 None => ls_opt,
5631 };
5632 gam_terms::basis::duchon_penalties_at_length_scale(
5633 centers.view(),
5634 identifiability_transform.as_ref(),
5635 operator_collocation_points.as_ref().map(|p| p.view()),
5636 &operator_penalties,
5637 *power,
5638 *nullspace_order,
5639 aniso_log_scales.as_deref(),
5640 radial_reparam.as_ref(),
5641 effective_ls,
5642 &mut self.basisworkspace,
5643 )
5644 .map_err(|e| e.to_string())?
5645 }
5646 BasisMetadata::Matern {
5647 centers,
5648 periodic,
5649 nu,
5650 include_intercept,
5651 identifiability_transform,
5652 aniso_log_scales,
5653 input_scales,
5654 ..
5655 } => {
5656 let ls = ls_opt.ok_or_else(|| {
5663 "Matérn n-free penalty re-key requires a finite length-scale".to_string()
5664 })?;
5665 let effective_ls = match input_scales.as_deref() {
5666 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
5667 None => ls,
5668 };
5669 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
5670 let (penalties, nullspace_dims, _info) =
5681 matern_operator_penalty_triplet_at_length_scale(
5682 centers.view(),
5683 periodic.as_deref(),
5684 identifiability_transform.as_ref(),
5685 *nu,
5686 *include_intercept,
5687 aniso_for_penalty,
5688 effective_ls,
5689 )
5690 .map_err(|e| e.to_string())?;
5691 (penalties, nullspace_dims)
5692 }
5693 BasisMetadata::ThinPlate {
5694 centers,
5695 identifiability_transform,
5696 radial_reparam,
5697 ..
5698 } => {
5699 let ls = ls_opt.ok_or_else(|| {
5700 "thin-plate n-free penalty re-key requires a finite length-scale".to_string()
5701 })?;
5702 let double_penalty = match &termspec.basis {
5703 SmoothBasisSpec::ThinPlate { spec, .. } => spec.double_penalty,
5704 _ => false,
5705 };
5706 gam_terms::basis::thin_plate_penalties_at_length_scale(
5707 centers.view(),
5708 identifiability_transform.as_ref(),
5709 radial_reparam.as_ref(),
5710 ls,
5711 double_penalty,
5712 &mut self.basisworkspace,
5713 )
5714 .map_err(|e| e.to_string())?
5715 }
5716 other => {
5717 return Err(format!(
5718 "n-free penalty re-key unsupported for basis metadata {:?}",
5719 std::mem::discriminant(other)
5720 ));
5721 }
5722 };
5723 let templates = &self.design.penalties;
5728 if templates.len() != locals.len() {
5729 return Err(format!(
5730 "n-free penalty re-key produced {} blocks but the frozen design carries {} \
5731 — penalty topology is not ψ-stable",
5732 locals.len(),
5733 templates.len()
5734 ));
5735 }
5736 let specs: Vec<gam_solve::estimate::PenaltySpec> = templates
5737 .iter()
5738 .zip(locals.into_iter())
5739 .map(|(tmpl, local)| gam_solve::estimate::PenaltySpec::Block {
5740 local,
5741 col_range: tmpl.col_range.clone(),
5742 prior_mean: tmpl.prior_mean.clone(),
5743 structure_hint: tmpl.structure_hint.clone(),
5744 op: tmpl.op.clone(),
5745 })
5746 .collect();
5747 gam_terms::construction::canonicalize_penalty_specs(
5748 &specs,
5749 &nullspace_dims,
5750 p_total,
5751 "nfree-psi-penalty",
5752 )
5753 .map_err(|e| e.to_string())
5754 }
5755
5756 fn canonical_penalty_derivatives_at_psi(
5757 &mut self,
5758 spatial_terms: &[usize],
5759 psi: &[f64],
5760 ) -> Result<(Range<usize>, usize, Vec<Array2<f64>>), String> {
5761 if spatial_terms.len() != 1 {
5762 return Err(format!(
5763 "n-free penalty derivative re-key requires exactly one spatial term, found {}",
5764 spatial_terms.len()
5765 ));
5766 }
5767 let term_idx = spatial_terms[0];
5768 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
5769 let termspec = self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
5770 format!("spatial term {term_idx} out of range for n-free penalty derivative")
5771 })?;
5772 let term = self
5773 .design
5774 .smooth
5775 .terms
5776 .get(term_idx)
5777 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
5778 let p_total = self.design.design.ncols();
5779 let smooth_start = p_total.saturating_sub(self.design.smooth.total_smooth_cols());
5780 let global_range =
5781 (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
5782
5783 let locals = match &term.metadata {
5784 BasisMetadata::Duchon {
5785 centers,
5786 identifiability_transform,
5787 operator_collocation_points,
5788 power,
5789 nullspace_order,
5790 aniso_log_scales,
5791 input_scales,
5792 radial_reparam,
5793 ..
5794 } => {
5795 let mut spec = match &termspec.basis {
5796 SmoothBasisSpec::Duchon { spec, .. } => spec.clone(),
5797 _ => {
5798 return Err(
5799 "Duchon n-free penalty derivative requires a Duchon term spec"
5800 .to_string(),
5801 );
5802 }
5803 };
5804 let effective_ls = match input_scales.as_deref() {
5805 Some(scales) => {
5806 compensate_optional_length_scale_for_standardization(ls_opt, scales)
5807 }
5808 None => ls_opt,
5809 };
5810 spec.length_scale = effective_ls;
5811 spec.power = *power;
5812 spec.nullspace_order = *nullspace_order;
5813 spec.aniso_log_scales = aniso_log_scales.clone();
5814 spec.radial_reparam = radial_reparam.clone();
5817 if spec.length_scale.is_none() {
5818 return Err(
5819 "Duchon n-free penalty derivative requires a hybrid length-scale"
5820 .to_string(),
5821 );
5822 }
5823 let collocation = operator_collocation_points
5824 .as_ref()
5825 .map(|points| points.view())
5826 .unwrap_or_else(|| centers.view());
5827 let (_native_sources, mut first, _native_second) =
5828 gam_terms::basis::build_duchon_native_penalty_psi_derivatives(
5829 centers.view(),
5830 &spec,
5831 identifiability_transform.as_ref(),
5832 &mut self.basisworkspace,
5833 )
5834 .map_err(|e| e.to_string())?;
5835 let (_operator_sources, operator_first, _operator_second) =
5836 gam_terms::basis::build_duchon_operator_penalty_psi_derivatives(
5837 collocation,
5838 centers.view(),
5839 &spec,
5840 identifiability_transform.as_ref(),
5841 &mut self.basisworkspace,
5842 )
5843 .map_err(|e| e.to_string())?;
5844 first.extend(operator_first);
5845 first
5846 }
5847 BasisMetadata::Matern {
5848 centers,
5849 periodic,
5850 nu,
5851 include_intercept,
5852 identifiability_transform,
5853 aniso_log_scales,
5854 input_scales,
5855 ..
5856 } => {
5857 let ls = ls_opt.ok_or_else(|| {
5858 "Matérn n-free penalty derivative requires a finite length-scale".to_string()
5859 })?;
5860 let effective_ls = match input_scales.as_deref() {
5861 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
5862 None => ls,
5863 };
5864 let penalty_centers =
5865 gam_terms::basis::expand_periodic_centers(¢ers.to_owned(), periodic.as_deref())
5866 .map_err(|e| e.to_string())?;
5867 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
5868 let (first, _second) = gam_terms::basis::build_matern_operator_penalty_psi_derivatives(
5869 penalty_centers.view(),
5870 effective_ls,
5871 *nu,
5872 *include_intercept,
5873 identifiability_transform.as_ref(),
5874 aniso_for_penalty,
5875 )
5876 .map_err(|e| e.to_string())?;
5877 first
5878 }
5879 BasisMetadata::ThinPlate {
5880 centers,
5881 identifiability_transform,
5882 radial_reparam,
5883 ..
5884 } => {
5885 let ls = ls_opt.ok_or_else(|| {
5886 "thin-plate n-free penalty derivative requires a finite length-scale"
5887 .to_string()
5888 })?;
5889 let mut spec = match &termspec.basis {
5890 SmoothBasisSpec::ThinPlate { spec, .. } => spec.clone(),
5891 _ => {
5892 return Err(
5893 "thin-plate n-free penalty derivative requires a ThinPlate term spec"
5894 .to_string(),
5895 );
5896 }
5897 };
5898 spec.length_scale = ls;
5899 if spec.radial_reparam.is_none() {
5900 spec.radial_reparam = radial_reparam.clone();
5901 }
5902 let (primary, _primary_second) =
5903 gam_terms::basis::build_thin_plate_penalty_psi_derivativeswithworkspace(
5904 centers.view(),
5905 &spec,
5906 identifiability_transform.as_ref(),
5907 &mut self.basisworkspace,
5908 )
5909 .map_err(|e| e.to_string())?;
5910 if self.design.penalties.len() > 1 {
5911 vec![primary.clone(), Array2::<f64>::zeros(primary.raw_dim())]
5912 } else {
5913 vec![primary]
5914 }
5915 }
5916 other => {
5917 return Err(format!(
5918 "n-free penalty derivative re-key unsupported for basis metadata {:?}",
5919 std::mem::discriminant(other)
5920 ));
5921 }
5922 };
5923 if locals.len() != self.design.penalties.len() {
5924 return Err(format!(
5925 "n-free penalty derivative re-key produced {} blocks but the frozen design carries {} \
5926 — penalty topology is not ψ-stable",
5927 locals.len(),
5928 self.design.penalties.len()
5929 ));
5930 }
5931 Ok((global_range, p_total, locals))
5932 }
5933
5934 fn apply_log_kappa(
5935 &mut self,
5936 log_kappa: &SpatialLogKappaCoords,
5937 term_indices: &[usize],
5938 ) -> Result<(), String> {
5939 if term_indices.len() != log_kappa.dims_per_term().len() {
5940 return Err(SmoothError::dimension_mismatch(format!(
5941 "incremental realizer log-kappa term mismatch: term_indices={}, dims_per_term={}",
5942 term_indices.len(),
5943 log_kappa.dims_per_term().len()
5944 ))
5945 .into());
5946 }
5947
5948 let mut any_changed = false;
5949 for (slot, &term_idx) in term_indices.iter().enumerate() {
5950 any_changed |= self.apply_log_kappa_to_term(term_idx, log_kappa.term_slice(slot))?;
5951 }
5952
5953 if any_changed {
5954 self.refresh_full_design_operator()?;
5955 rebuild_smooth_auxiliary_state(
5956 &mut self.design.smooth,
5957 &self.dropped_penaltyinfo_by_term,
5958 )?;
5959 rebuild_term_collection_auxiliary_state(&self.spec, &mut self.design)?;
5960 self.design_revision = self.design_revision.wrapping_add(1);
5961 }
5962 Ok(())
5963 }
5964
5965 fn apply_log_kappa_to_term(&mut self, term_idx: usize, psi: &[f64]) -> Result<bool, String> {
5966 if !spatial_term_supports_hyper_optimization(&self.spec, term_idx) {
5967 return Err(SmoothError::invalid_config(format!(
5968 "incremental realizer term {term_idx} does not expose spatial hyperparameters"
5969 ))
5970 .into());
5971 }
5972 let measure_jet_term = measure_jet_term_spec(&self.spec, term_idx).is_some();
5976 let constant_curvature_term = constant_curvature_term_spec(&self.spec, term_idx).is_some();
5980 let mut next_length_scale = None;
5981 let mut next_aniso: Option<Vec<f64>> = None;
5982 if measure_jet_term {
5983 if !set_measure_jet_psi_dials(&mut self.spec, term_idx, psi)
5984 .map_err(|e| e.to_string())?
5985 {
5986 return Ok(false);
5987 }
5988 } else if constant_curvature_term {
5989 if !set_constant_curvature_kappa(&mut self.spec, term_idx, psi)
5990 .map_err(|e| e.to_string())?
5991 {
5992 return Ok(false);
5993 }
5994 } else {
5995 let current_length_scale = get_spatial_length_scale(&self.spec, term_idx);
5996 let current_aniso = get_spatial_aniso_log_scales(&self.spec, term_idx);
5997 let (ls, eta) = spatial_term_psi_to_length_scale_and_aniso(psi);
5998 next_length_scale = ls;
5999 next_aniso = eta;
6000 let same_length = spatial_length_scale_matches(current_length_scale, next_length_scale);
6001 let same_aniso = spatial_aniso_matches(current_aniso.as_deref(), next_aniso.as_deref());
6002 if same_length && same_aniso {
6003 return Ok(false);
6004 }
6005 if let Some(length_scale) = next_length_scale {
6006 set_spatial_length_scale(&mut self.spec, term_idx, length_scale)
6007 .map_err(|e| e.to_string())?;
6008 }
6009 if let Some(eta) = next_aniso.clone() {
6010 set_spatial_aniso_log_scales(&mut self.spec, term_idx, eta)
6011 .map_err(|e| e.to_string())?;
6012 }
6013 }
6014
6015 let geometry_slot = self
6026 .spatial_realization_geometry
6027 .get(term_idx)
6028 .ok_or_else(|| format!("incremental realizer geometry slot {term_idx} out of range"))?;
6029 let mut build_spec = match geometry_slot {
6030 Some(cached) => cached.clone(),
6031 None => self
6032 .spec
6033 .smooth_terms
6034 .get(term_idx)
6035 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
6036 .clone(),
6037 };
6038 if measure_jet_term {
6039 set_single_term_measure_jet_psi_dials(&mut build_spec, psi)
6043 .map_err(|e| e.to_string())?;
6044 } else if constant_curvature_term {
6045 set_single_term_constant_curvature_kappa(&mut build_spec, psi)
6050 .map_err(|e| e.to_string())?;
6051 } else {
6052 if let Some(length_scale) = next_length_scale {
6053 set_single_term_spatial_length_scale(&mut build_spec, length_scale)
6054 .map_err(|e| e.to_string())?;
6055 }
6056 if let Some(eta) = next_aniso {
6057 set_single_term_spatial_aniso_log_scales(&mut build_spec, eta)
6058 .map_err(|e| e.to_string())?;
6059 }
6060 }
6061
6062 let termname = build_spec.name.clone();
6063 let local = build_single_local_smooth_term(
6064 self.data,
6065 &build_spec,
6066 &mut self.basisworkspace,
6067 )
6068 .map_err(|e| {
6069 format!(
6070 "failed to rebuild smooth term '{termname}' during incremental κ realization: {e}"
6071 )
6072 })?;
6073
6074 if self.spatial_realization_geometry[term_idx].is_none()
6079 && let Some(frozen) = freeze_geometry_from_metadata(&build_spec, &local.metadata)
6080 {
6081 if let (
6093 SmoothBasisSpec::Matern {
6094 spec: frozen_spec, ..
6095 },
6096 Some(SmoothBasisSpec::Matern {
6097 spec: live_spec, ..
6098 }),
6099 ) = (
6100 &frozen.basis,
6101 self.spec
6102 .smooth_terms
6103 .get_mut(term_idx)
6104 .map(|t| &mut t.basis),
6105 ) {
6106 live_spec.identifiability = frozen_spec.identifiability.clone();
6107 live_spec.center_strategy = frozen_spec.center_strategy.clone();
6108 }
6109 self.spatial_realization_geometry[term_idx] = Some(frozen);
6110 }
6111
6112 let realization = wrap_local_build_as_realization(local, &build_spec)?;
6113 self.replace_term_realization(term_idx, realization)?;
6114 Ok(true)
6115 }
6116
6117 fn replace_term_realization(
6118 &mut self,
6119 term_idx: usize,
6120 realization: SingleSmoothTermRealization,
6121 ) -> Result<(), String> {
6122 let t_replace = std::time::Instant::now();
6123 let SingleSmoothTermRealization {
6124 design_local,
6125 term,
6126 dropped_penaltyinfo,
6127 } = realization;
6128 let SmoothTerm {
6129 name,
6130 penalties_local,
6131 nullspace_dims,
6132 penaltyinfo_local,
6133 metadata,
6134 lower_bounds_local,
6135 linear_constraints_local,
6136 joint_null_rotation,
6137 ..
6138 } = term;
6139 let coeff_range = self
6140 .design
6141 .smooth
6142 .terms
6143 .get(term_idx)
6144 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
6145 .coeff_range
6146 .clone();
6147 if design_local.ncols() != coeff_range.len() {
6148 return Err(SmoothError::dimension_mismatch(format!(
6149 "incremental realizer width mismatch for term {}: rebuilt_cols={}, cached_cols={}",
6150 term_idx,
6151 design_local.ncols(),
6152 coeff_range.len()
6153 ))
6154 .into());
6155 }
6156 if design_local.nrows() != self.design.design.nrows() {
6157 return Err(SmoothError::dimension_mismatch(format!(
6158 "incremental realizer row mismatch for term {}: rebuilt_rows={}, design_rows={}",
6159 term_idx,
6160 design_local.nrows(),
6161 self.design.design.nrows()
6162 ))
6163 .into());
6164 }
6165
6166 let active_penaltyinfo = penaltyinfo_local
6167 .iter()
6168 .filter(|info| info.active)
6169 .cloned()
6170 .collect::<Vec<_>>();
6171 let smooth_penalty_range = self
6172 .smooth_penalty_ranges
6173 .get(term_idx)
6174 .ok_or_else(|| {
6175 format!("incremental realizer missing smooth penalty range for term {term_idx}")
6176 })?
6177 .clone();
6178 let full_penalty_range = self
6179 .full_penalty_ranges
6180 .get(term_idx)
6181 .ok_or_else(|| {
6182 format!("incremental realizer missing full penalty range for term {term_idx}")
6183 })?
6184 .clone();
6185 if active_penaltyinfo.len() != smooth_penalty_range.len()
6186 || penalties_local.len() != smooth_penalty_range.len()
6187 || nullspace_dims.len() != smooth_penalty_range.len()
6188 {
6189 return Err(SmoothError::dimension_mismatch(format!(
6190 "incremental realizer topology changed for term '{}': penalties={}, infos={}, nullspaces={}, cached_penalties={}",
6191 name,
6192 penalties_local.len(),
6193 active_penaltyinfo.len(),
6194 nullspace_dims.len(),
6195 smooth_penalty_range.len()
6196 ))
6197 .into());
6198 }
6199
6200 self.design.smooth.term_designs[term_idx] = design_local;
6201
6202 for (offset, penalty_local) in penalties_local.iter().enumerate() {
6203 let smooth_penalty_idx = smooth_penalty_range.start + offset;
6204 let full_penalty_idx = full_penalty_range.start + offset;
6205 let nullspace_dim = nullspace_dims[offset];
6206 let penalty_info = active_penaltyinfo[offset].clone();
6207
6208 if penalty_local.nrows() != coeff_range.len()
6209 || penalty_local.ncols() != coeff_range.len()
6210 {
6211 return Err(SmoothError::dimension_mismatch(format!(
6212 "incremental realizer penalty shape mismatch for term '{}' penalty {}: \
6213 penalty is {}x{} but coeff_range has {} columns",
6214 name,
6215 offset,
6216 penalty_local.nrows(),
6217 penalty_local.ncols(),
6218 coeff_range.len()
6219 ))
6220 .into());
6221 }
6222
6223 let smooth_penalty = self
6224 .design
6225 .smooth
6226 .penalties
6227 .get_mut(smooth_penalty_idx)
6228 .ok_or_else(|| {
6229 format!(
6230 "incremental realizer smooth penalty {} out of range for term {}",
6231 smooth_penalty_idx, term_idx
6232 )
6233 })?;
6234 smooth_penalty.local.assign(penalty_local);
6237
6238 let full_bp = self
6239 .design
6240 .penalties
6241 .get_mut(full_penalty_idx)
6242 .ok_or_else(|| {
6243 format!(
6244 "incremental realizer full penalty {} out of range for term {}",
6245 full_penalty_idx, term_idx
6246 )
6247 })?;
6248 full_bp.local.assign(penalty_local);
6251
6252 self.design.smooth.nullspace_dims[smooth_penalty_idx] = nullspace_dim;
6253 self.design.nullspace_dims[full_penalty_idx] = nullspace_dim;
6254
6255 self.design.smooth.penaltyinfo[smooth_penalty_idx].global_index = smooth_penalty_idx;
6256 self.design.smooth.penaltyinfo[smooth_penalty_idx].termname = Some(name.clone());
6257 self.design.smooth.penaltyinfo[smooth_penalty_idx].penalty = penalty_info.clone();
6258
6259 self.design.penaltyinfo[full_penalty_idx].global_index = full_penalty_idx;
6260 self.design.penaltyinfo[full_penalty_idx].termname = Some(name.clone());
6261 self.design.penaltyinfo[full_penalty_idx].penalty = penalty_info;
6262 }
6263
6264 let target_term = self.design.smooth.terms.get_mut(term_idx).ok_or_else(|| {
6265 format!("incremental realizer smooth term {term_idx} disappeared during replacement")
6266 })?;
6267 target_term.penalties_local = penalties_local;
6268 target_term.nullspace_dims = nullspace_dims;
6269 target_term.penaltyinfo_local = penaltyinfo_local;
6270 target_term.metadata = metadata;
6271 target_term.lower_bounds_local = lower_bounds_local;
6272 target_term.linear_constraints_local = linear_constraints_local;
6273 target_term.joint_null_rotation = joint_null_rotation;
6274 self.dropped_penaltyinfo_by_term[term_idx] = dropped_penaltyinfo;
6275 log::info!(
6276 "[STAGE] smooth basis rebuild (term {}, '{}', cols={}): {:.3}s",
6277 term_idx,
6278 target_term.name,
6279 coeff_range.len(),
6280 t_replace.elapsed().as_secs_f64(),
6281 );
6282 Ok(())
6283 }
6284
6285 fn refresh_full_design_operator(&mut self) -> Result<(), String> {
6286 let mut blocks = Vec::<DesignBlock>::with_capacity(
6287 self.fixed_blocks.len() + self.design.smooth.term_designs.len(),
6288 );
6289 blocks.extend(self.fixed_blocks.iter().cloned());
6290 for term_design in &self.design.smooth.term_designs {
6291 blocks.push(DesignBlock::from(term_design));
6292 }
6293 self.design.design = assemble_term_collection_design_matrix(blocks)
6294 .map_err(|e| format!("failed to refresh term-collection design: {e}"))?;
6295 Ok(())
6296 }
6297}
6298
6299fn build_term_collection_fixed_blocks(
6300 data: ArrayView2<'_, f64>,
6301 spec: &TermCollectionSpec,
6302) -> Result<Vec<DesignBlock>, BasisError> {
6303 let mut blocks = Vec::<DesignBlock>::new();
6304 if !term_collection_has_one_sided_anchored_bspline(spec) {
6305 blocks.push(DesignBlock::Intercept(data.nrows()));
6306 }
6307
6308 if !spec.linear_terms.is_empty() {
6309 let mut linear_block = Array2::<f64>::zeros((data.nrows(), spec.linear_terms.len()));
6310 for (j, linear) in spec.linear_terms.iter().enumerate() {
6311 let column = linear
6315 .realized_design_column(data)
6316 .map_err(BasisError::InvalidInput)?;
6317 linear_block.column_mut(j).assign(&column);
6318 }
6319 blocks.push(DesignBlock::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
6320 linear_block,
6321 )));
6322 }
6323
6324 for term in &spec.random_effect_terms {
6325 let block = build_random_effect_block(data, term)?;
6326 let re_op = RandomEffectOperator::new(block.group_ids, block.num_groups);
6327 blocks.push(DesignBlock::RandomEffect(Arc::new(re_op)));
6328 }
6329
6330 Ok(blocks)
6331}
6332
/// Outcome of a spatial length-scale (κ) optimization run.
///
/// `resolved_specs` and `designs` are parallel vectors; in
/// `optimize_spatial_length_scale_exact_joint` they hold one entry per block.
pub struct SpatialLengthScaleOptimizationResult<FitOut> {
    /// Term-collection specs as resolved/frozen at the returned hyperparameters.
    pub resolved_specs: Vec<TermCollectionSpec>,
    /// Designs realized for `resolved_specs`, index-aligned with it.
    pub designs: Vec<TermCollectionDesign>,
    /// Result of the caller-supplied final fit at the returned hyperparameters.
    pub fit: FitOut,
    /// Outer-loop timing counters; `None` when no κ optimization actually ran
    /// (fast path or fixed-κ fallback in `optimize_spatial_length_scale_exact_joint`).
    pub timing: Option<SpatialLengthScaleOptimizationTiming>,
}
6343
/// Seed values and box bounds for the joint outer hyperparameter vector θ.
///
/// θ is laid out as `[ρ | log κ | auxiliary]`; `theta0`, `lower` and `upper`
/// concatenate the corresponding fields in that order.
#[derive(Debug, Clone)]
pub struct ExactJointHyperSetup {
    /// ρ seed, already clamped into `[rho_lower, rho_upper]` by the constructor.
    rho0: Array1<f64>,
    /// Lower bounds for the ρ coordinates.
    rho_lower: Array1<f64>,
    /// Upper bounds for the ρ coordinates.
    rho_upper: Array1<f64>,
    /// Spatial log-κ seed, stored as given.
    log_kappa0: SpatialLogKappaCoords,
    /// Lower bounds for the log-κ coordinates.
    log_kappa_lower: SpatialLogKappaCoords,
    /// Upper bounds for the log-κ coordinates.
    log_kappa_upper: SpatialLogKappaCoords,
    /// Auxiliary seed (empty unless set via `with_auxiliary`, which clamps it).
    auxiliary0: Array1<f64>,
    /// Lower bounds for the auxiliary coordinates.
    auxiliary_lower: Array1<f64>,
    /// Upper bounds for the auxiliary coordinates.
    auxiliary_upper: Array1<f64>,
}
6357
6358impl ExactJointHyperSetup {
6359 fn sanitize_rho_seed(
6360 rho0: Array1<f64>,
6361 rho_lower: &Array1<f64>,
6362 rho_upper: &Array1<f64>,
6363 ) -> Array1<f64> {
6364 Array1::from_iter(rho0.iter().enumerate().map(|(idx, &value)| {
6365 let lo = rho_lower[idx];
6366 let hi = rho_upper[idx];
6367 let fallback = 0.0_f64.clamp(lo, hi);
6368 if value.is_finite() {
6369 value.clamp(lo, hi)
6370 } else {
6371 fallback
6372 }
6373 }))
6374 }
6375
6376 pub(crate) fn new(
6377 rho0: Array1<f64>,
6378 rho_lower: Array1<f64>,
6379 rho_upper: Array1<f64>,
6380 log_kappa0: SpatialLogKappaCoords,
6381 log_kappa_lower: SpatialLogKappaCoords,
6382 log_kappa_upper: SpatialLogKappaCoords,
6383 ) -> Self {
6384 let rho0 = Self::sanitize_rho_seed(rho0, &rho_lower, &rho_upper);
6385 Self {
6386 rho0,
6387 rho_lower,
6388 rho_upper,
6389 log_kappa0,
6390 log_kappa_lower,
6391 log_kappa_upper,
6392 auxiliary0: Array1::zeros(0),
6393 auxiliary_lower: Array1::zeros(0),
6394 auxiliary_upper: Array1::zeros(0),
6395 }
6396 }
6397
6398 pub(crate) fn with_auxiliary(
6399 mut self,
6400 auxiliary0: Array1<f64>,
6401 auxiliary_lower: Array1<f64>,
6402 auxiliary_upper: Array1<f64>,
6403 ) -> Self {
6404 assert_eq!(
6405 auxiliary0.len(),
6406 auxiliary_lower.len(),
6407 "auxiliary lower bound length mismatch"
6408 );
6409 assert_eq!(
6410 auxiliary0.len(),
6411 auxiliary_upper.len(),
6412 "auxiliary upper bound length mismatch"
6413 );
6414 self.auxiliary0 = Self::sanitize_rho_seed(auxiliary0, &auxiliary_lower, &auxiliary_upper);
6415 self.auxiliary_lower = auxiliary_lower;
6416 self.auxiliary_upper = auxiliary_upper;
6417 self
6418 }
6419
6420 pub(crate) fn rho_dim(&self) -> usize {
6421 self.rho0.len()
6422 }
6423
6424 pub(crate) fn log_kappa_dim(&self) -> usize {
6425 self.log_kappa0.len()
6426 }
6427
6428 pub(crate) fn auxiliary_dim(&self) -> usize {
6429 self.auxiliary0.len()
6430 }
6431
6432 pub(crate) fn theta0(&self) -> Array1<f64> {
6433 let mut out =
6434 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6435 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho0);
6436 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6437 .assign(self.log_kappa0.as_array());
6438 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6439 .assign(&self.auxiliary0);
6440 out
6441 }
6442
6443 pub(crate) fn lower(&self) -> Array1<f64> {
6444 let mut out =
6445 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6446 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_lower);
6447 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6448 .assign(self.log_kappa_lower.as_array());
6449 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6450 .assign(&self.auxiliary_lower);
6451 out
6452 }
6453
6454 pub(crate) fn upper(&self) -> Array1<f64> {
6455 let mut out =
6456 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6457 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_upper);
6458 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6459 .assign(self.log_kappa_upper.as_array());
6460 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6461 .assign(&self.auxiliary_upper);
6462 out
6463 }
6464
6465 pub(crate) fn log_kappa_dims_per_term(&self) -> Vec<usize> {
6467 self.log_kappa0.dims_per_term().to_vec()
6468 }
6469}
6470
/// Per-block realizer cache for the exact-joint outer loop.
///
/// Holds one frozen incremental realizer per block and remembers the θ the realizers
/// were last moved to, so repeated evaluations at the same θ skip re-realization.
struct ExactJointDesignCache<'d> {
    // One realizer per block, in block order.
    realizers: Vec<FrozenTermCollectionIncrementalRealizer<'d>>,
    // Term indices handed to each block's `apply_log_kappa`.
    block_term_indices: Vec<Vec<usize>>,
    // θ the realizers currently reflect; `None` before the first successful `ensure_theta`.
    current_theta: Option<Array1<f64>>,
    // Value-only cost memoized for `current_theta` (reset whenever θ changes).
    last_cost: Option<f64>,
    // (cost, gradient, Hessian) memoized for `current_theta` (reset whenever θ changes).
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    // Number of leading ρ coordinates in θ.
    rho_dim: usize,
    // Log-κ dimension per driven term, concatenated across blocks in block order.
    all_dims: Vec<usize>,
    // Sum of `all_dims`.
    log_kappa_dim: usize,
    // Length of `block_term_indices[i]` for each block; used to split log-κ per block.
    block_term_counts: Vec<usize>,
}
6491
6492impl<'d> ExactJointDesignCache<'d> {
6493 fn new(
6494 data: ArrayView2<'d, f64>,
6495 blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)>,
6496 rho_dim: usize,
6497 all_dims: Vec<usize>,
6498 ) -> Result<Self, String> {
6499 let n_blocks = blocks.len();
6500 let mut realizers = Vec::with_capacity(n_blocks);
6501 let mut block_term_indices = Vec::with_capacity(n_blocks);
6502 let mut block_term_counts = Vec::with_capacity(n_blocks);
6503
6504 for (spec, design, terms) in blocks {
6505 block_term_counts.push(terms.len());
6506 block_term_indices.push(terms);
6507 realizers.push(FrozenTermCollectionIncrementalRealizer::new(
6508 data, spec, design,
6509 )?);
6510 }
6511
6512 Ok(Self {
6513 realizers,
6514 block_term_indices,
6515 current_theta: None,
6516 last_cost: None,
6517 last_eval: None,
6518 rho_dim,
6519 log_kappa_dim: all_dims.iter().sum(),
6520 all_dims,
6521 block_term_counts,
6522 })
6523 }
6524
6525 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
6526 if self
6527 .current_theta
6528 .as_ref()
6529 .is_some_and(|cached| theta_values_match(cached, theta))
6530 {
6531 return Ok(());
6532 }
6533
6534 let t_ensure = std::time::Instant::now();
6535 let kappa_theta_len = self.rho_dim + self.log_kappa_dim;
6536 if theta.len() < kappa_theta_len {
6537 return Err(SmoothError::dimension_mismatch(format!(
6538 "exact-joint theta length mismatch: got {}, expected at least {} (rho_dim={}, log_kappa_dim={})",
6539 theta.len(),
6540 kappa_theta_len,
6541 self.rho_dim,
6542 self.log_kappa_dim
6543 ))
6544 .into());
6545 }
6546 let theta_kappa = theta.slice(s![..kappa_theta_len]).to_owned();
6547 let full_log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
6548 &theta_kappa,
6549 self.rho_dim,
6550 self.all_dims.clone(),
6551 );
6552
6553 let n = self.realizers.len();
6557 let mut remaining = full_log_kappa;
6558 for block_idx in 0..n {
6559 let count = self.block_term_counts[block_idx];
6560 if block_idx < n - 1 {
6561 let (block_lk, rest) = remaining.split_at(count);
6562 self.realizers[block_idx]
6563 .apply_log_kappa(&block_lk, &self.block_term_indices[block_idx])?;
6564 remaining = rest;
6565 } else {
6566 self.realizers[block_idx]
6568 .apply_log_kappa(&remaining, &self.block_term_indices[block_idx])?;
6569 }
6570 }
6571
6572 log::info!(
6573 "[STAGE] ensure_theta (n-block, {} blocks, {} realizers): {:.3}s",
6574 n,
6575 self.realizers.len(),
6576 t_ensure.elapsed().as_secs_f64(),
6577 );
6578 self.current_theta = Some(theta.clone());
6579 self.last_cost = None;
6580 self.last_eval = None;
6581 Ok(())
6582 }
6583
6584 impl_exact_joint_theta_memo!();
6585
6586 fn store_cost_only(&mut self, theta: &Array1<f64>, cost: f64) {
6592 if self
6593 .current_theta
6594 .as_ref()
6595 .is_some_and(|cached| theta_values_match(cached, theta))
6596 {
6597 self.last_cost = Some(cost);
6598 }
6599 }
6600
6601 fn specs(&self) -> Vec<&TermCollectionSpec> {
6602 self.realizers.iter().map(|r| r.spec()).collect()
6603 }
6604
6605 fn designs(&self) -> Vec<&TermCollectionDesign> {
6606 self.realizers.iter().map(|r| r.design()).collect()
6607 }
6608
6609 fn design_revision(&self) -> u64 {
6619 self.realizers
6620 .iter()
6621 .fold(0u64, |acc, r| acc.wrapping_add(r.design_revision()))
6622 }
6623}
6624
6625pub(crate) fn seed_risk_profile_for_likelihood_family(
6626 family: &LikelihoodSpec,
6627) -> gam_problem::SeedRiskProfile {
6628 match &family.response {
6629 ResponseFamily::Gaussian => gam_problem::SeedRiskProfile::Gaussian,
6630 ResponseFamily::RoystonParmar => gam_problem::SeedRiskProfile::Survival,
6631 ResponseFamily::Binomial
6632 | ResponseFamily::Poisson
6633 | ResponseFamily::Tweedie { .. }
6634 | ResponseFamily::NegativeBinomial { .. }
6635 | ResponseFamily::Beta { .. }
6636 | ResponseFamily::Gamma => gam_problem::SeedRiskProfile::GeneralizedLinear,
6637 }
6638}
6639
/// θ-dimension cap (8) associated with second-order exact-joint outer work.
/// NOTE(review): no use is visible in this chunk; presumably θ vectors longer than this
/// skip a second-order (Hessian) path — confirm at the call sites.
const EXACT_JOINT_SECOND_ORDER_THETA_CAP: usize = 8;
6648
6649fn exact_joint_seed_config(
6650 risk_profile: gam_problem::SeedRiskProfile,
6651 auxiliary_dim: usize,
6652) -> gam_problem::SeedConfig {
6653 let mut config = gam_problem::SeedConfig {
6654 risk_profile,
6655 num_auxiliary_trailing: auxiliary_dim,
6656 ..Default::default()
6657 };
6658 match risk_profile {
6659 gam_problem::SeedRiskProfile::Gaussian
6660 | gam_problem::SeedRiskProfile::GaussianLocationScale => {
6661 config.max_seeds = 4;
6662 config.seed_budget = 2;
6663 }
6664 gam_problem::SeedRiskProfile::GeneralizedLinear => {
6665 config.max_seeds = 1;
6670 config.seed_budget = 1;
6671 config.screen_max_inner_iterations = 8;
6672 }
6673 gam_problem::SeedRiskProfile::Survival => {
6674 config.max_seeds = 8;
6680 config.seed_budget = 4;
6681 config.screen_max_inner_iterations = 8;
6682 }
6683 }
6684 config
6685}
6686
#[cfg(test)]
mod exact_joint_seed_config_tests {
    use super::*;

    /// Non-Gaussian profiles get their own budgets plus a deeper screening pass.
    #[test]
    fn exact_joint_marginal_slope_profiles_get_deeper_startup_validation() {
        let bms = exact_joint_seed_config(gam_problem::SeedRiskProfile::GeneralizedLinear, 2);
        assert_eq!(bms.max_seeds, 1);
        assert_eq!(bms.seed_budget, 1);
        assert_eq!(bms.screen_max_inner_iterations, 8);
        assert_eq!(bms.num_auxiliary_trailing, 2);

        let survival = exact_joint_seed_config(gam_problem::SeedRiskProfile::Survival, 3);
        assert_eq!(survival.max_seeds, 8);
        assert_eq!(survival.seed_budget, 4);
        assert_eq!(survival.screen_max_inner_iterations, 8);
        assert_eq!(survival.num_auxiliary_trailing, 3);
    }

    /// The Gaussian profile keeps the small multistart budget and default screening.
    #[test]
    fn exact_joint_gaussian_keeps_tight_historical_multistart_budget() {
        let gaussian = exact_joint_seed_config(gam_problem::SeedRiskProfile::Gaussian, 1);
        assert_eq!(gaussian.max_seeds, 4);
        assert_eq!(gaussian.seed_budget, 2);
        assert_eq!(
            gaussian.screen_max_inner_iterations,
            gam_problem::SeedConfig::default().screen_max_inner_iterations
        );
        assert_eq!(gaussian.num_auxiliary_trailing, 1);
    }

    /// The location-scale Gaussian profile shares the plain Gaussian budget and keeps
    /// the default screening depth (previously an untested match arm).
    #[test]
    fn exact_joint_gaussian_location_scale_shares_gaussian_budget() {
        let gls =
            exact_joint_seed_config(gam_problem::SeedRiskProfile::GaussianLocationScale, 0);
        assert_eq!(gls.max_seeds, 4);
        assert_eq!(gls.seed_budget, 2);
        assert_eq!(
            gls.screen_max_inner_iterations,
            gam_problem::SeedConfig::default().screen_max_inner_iterations
        );
        assert_eq!(gls.num_auxiliary_trailing, 0);
    }
}
6718
6719pub(crate) fn exact_joint_multistart_outer_problem(
6720 theta0: &Array1<f64>,
6721 lower: &Array1<f64>,
6722 upper: &Array1<f64>,
6723 rho_dim: usize,
6724 auxiliary_dim: usize,
6725 n_params: usize,
6726 gradient: gam_problem::Derivative,
6727 hessian: gam_problem::DeclaredHessianForm,
6728 prefer_gradient_only: bool,
6729 disable_fixed_point: bool,
6730 risk_profile: gam_problem::SeedRiskProfile,
6731 tolerance: f64,
6732 max_iter: usize,
6733 bfgs_step_cap: Option<f64>,
6742 bfgs_step_cap_psi: Option<f64>,
6743 screening_cap: Option<Arc<AtomicUsize>>,
6744 profiled_objective_size: Option<(usize, usize)>,
6765 has_constant_curvature: bool,
6774) -> gam_solve::rho_optimizer::OuterProblem {
6775 let mut seed_heuristic = theta0.to_vec();
6776 for value in &mut seed_heuristic[..rho_dim] {
6777 *value = value.exp();
6778 }
6779 let rho_ceiling = if has_constant_curvature {
6784 gam_solve::estimate::RHO_BOUND
6785 } else {
6786 12.0
6787 };
6788 let mut problem = gam_solve::rho_optimizer::OuterProblem::new(n_params)
6789 .with_gradient(gradient)
6790 .with_hessian(hessian)
6791 .with_prefer_gradient_only(prefer_gradient_only)
6792 .with_disable_fixed_point(disable_fixed_point)
6793 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Automatic)
6803 .with_psi_dim(auxiliary_dim)
6804 .with_tolerance(tolerance)
6805 .with_max_iter(max_iter)
6806 .with_bounds(lower.clone(), upper.clone())
6807 .with_initial_rho(theta0.clone())
6808 .with_bfgs_step_cap(bfgs_step_cap)
6809 .with_bfgs_step_cap_psi(bfgs_step_cap_psi)
6810 .with_seed_config({
6811 let mut sc = exact_joint_seed_config(risk_profile, auxiliary_dim);
6812 if has_constant_curvature {
6813 sc.bounds = (sc.bounds.0, rho_ceiling);
6817 }
6836 sc
6837 })
6838 .with_rho_bound(rho_ceiling)
6839 .with_heuristic_lambdas(seed_heuristic);
6840 if let Some((n_obs, p_cols)) = profiled_objective_size {
6841 problem = problem
6849 .with_objective_scale(Some(n_obs as f64))
6850 .with_problem_size(n_obs, p_cols)
6851 .with_arc_initial_regularization(Some(0.25))
6852 .with_operator_initial_trust_radius(Some(4.0));
6853 }
6854 if let Some(screening_cap) = screening_cap {
6855 problem = problem
6856 .with_screening_cap(screening_cap)
6857 .with_screen_initial_rho(true);
6858 }
6859 problem
6860}
6861
/// Reports whether an outer κ-optimization failure should fall back to fitting at
/// the fixed bootstrap κ instead of surfacing as an error.
///
/// Matching is a plain substring search of `message` against a fixed set of known
/// recoverable failure phrases.
fn kappa_phase_failure_is_fixed_kappa_recoverable(message: &str) -> bool {
    const RECOVERABLE_MARKERS: [&str; 3] = [
        "no candidate seeds passed outer startup validation",
        "joint hyper rho dimension mismatch",
        "objective returned a non-finite cost",
    ];
    RECOVERABLE_MARKERS
        .iter()
        .any(|marker| message.contains(*marker))
}
6877
6878pub fn optimize_spatial_length_scale_exact_joint<FitOut, FitFn, ExactFn, ExactEfsFn, SeedFn>(
6879 data: ArrayView2<'_, f64>,
6880 block_specs: &[TermCollectionSpec],
6881 block_term_indices: &[Vec<usize>],
6882 kappa_options: &SpatialLengthScaleOptimizationOptions,
6883 joint_setup: &ExactJointHyperSetup,
6884 seed_risk_profile: gam_problem::SeedRiskProfile,
6885 analytic_joint_gradient_available: bool,
6886 analytic_joint_hessian_available: bool,
6887 disable_fixed_point: bool,
6888 screening_cap: Option<Arc<AtomicUsize>>,
6889 outer_derivative_policy: gam_model_api::families::custom_family::OuterDerivativePolicy,
6890 mut fit_fn: FitFn,
6891 mut exact_fn: ExactFn,
6892 mut exact_efs_fn: ExactEfsFn,
6893 mut seed_inner_beta_fn: SeedFn,
6894) -> Result<SpatialLengthScaleOptimizationResult<FitOut>, String>
6895where
6896 FitOut: Clone,
6897 FitFn: FnMut(
6898 &Array1<f64>,
6899 &[TermCollectionSpec],
6900 &[TermCollectionDesign],
6901 ) -> Result<FitOut, String>,
6902 ExactFn: FnMut(
6903 &Array1<f64>,
6904 &[TermCollectionSpec],
6905 &[TermCollectionDesign],
6906 gam_solve::estimate::reml::reml_outer_engine::EvalMode,
6907 &gam_problem::outer_subsample::RowSet,
6908 ) -> Result<
6909 (
6910 f64,
6911 Array1<f64>,
6912 gam_problem::HessianResult,
6913 ),
6914 String,
6915 >,
6916 ExactEfsFn: FnMut(
6917 &Array1<f64>,
6918 &[TermCollectionSpec],
6919 &[TermCollectionDesign],
6920 ) -> Result<gam_problem::EfsEval, String>,
6921 SeedFn:
6922 FnMut(&Array1<f64>) -> Result<gam_solve::rho_optimizer::SeedOutcome, EstimationError>,
6923{
6924 let n_blocks = block_specs.len();
6925 if block_term_indices.len() != n_blocks {
6926 return Err(SmoothError::dimension_mismatch(format!(
6927 "block_specs ({}) and block_term_indices ({}) length mismatch",
6928 n_blocks,
6929 block_term_indices.len()
6930 ))
6931 .into());
6932 }
6933
6934 let log_kappa_dim = joint_setup.log_kappa_dim();
6935
6936 log::warn!(
6937 "[OUTER-FD-AUDIT/spatial-exact-joint] driver entry: aux_dim={} log_kappa_dim={} kappa_enabled={} rho_dim={} theta0_len={}",
6938 joint_setup.auxiliary_dim(),
6939 log_kappa_dim,
6940 kappa_options.enabled,
6941 joint_setup.rho_dim(),
6942 joint_setup.theta0().len()
6943 );
6944
6945 if joint_setup.auxiliary_dim() == 0 && (!kappa_options.enabled || log_kappa_dim == 0) {
6949 log::warn!(
6950 "[OUTER-FD-AUDIT/spatial-exact-joint] taking FAST path (no outer theta optimization in this driver)"
6951 );
6952 let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
6953 data, block_specs,
6954 )
6955 .map_err(|e| {
6956 format!("failed to build and freeze joint block designs during exact joint kappa optimization: {e}")
6957 })?;
6958 let theta0 = joint_setup.theta0();
6959
6960 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
6962 let design_refs: Vec<TermCollectionDesign> = designs.clone();
6963 let fit = fit_fn(&theta0, &spec_refs, &design_refs)?;
6964 return Ok(SpatialLengthScaleOptimizationResult {
6965 resolved_specs,
6966 designs,
6967 fit,
6968 timing: None,
6969 });
6970 }
6971
6972 let theta0 = joint_setup.theta0();
6976 let lower = joint_setup.lower();
6977 let upper = joint_setup.upper();
6978 if theta0.len() < log_kappa_dim || lower.len() != theta0.len() || upper.len() != theta0.len() {
6979 return Err(SmoothError::dimension_mismatch(format!(
6980 "invalid exact joint theta setup: theta0={}, lower={}, upper={}, required_log_kappa_dim={}",
6981 theta0.len(),
6982 lower.len(),
6983 upper.len(),
6984 log_kappa_dim
6985 ))
6986 .into());
6987 }
6988 let rho_dim = joint_setup.rho_dim();
6989 let all_dims = joint_setup.log_kappa_dims_per_term();
6990
6991 let (boot_designs, best_specs) = build_term_collection_designs_and_freeze_joint(
6993 data,
6994 block_specs,
6995 )
6996 .map_err(|e| {
6997 format!(
6998 "failed to build and freeze joint block designs during exact joint kappa bootstrap: {e}"
6999 )
7000 })?;
7001 let policy_hessian_form = outer_derivative_policy.declared_hessian_form();
7011 let analytic_outer_hessian_available = analytic_joint_hessian_available
7012 && matches!(
7013 policy_hessian_form,
7014 gam_problem::DeclaredHessianForm::Either
7015 | gam_problem::DeclaredHessianForm::Dense
7016 | gam_problem::DeclaredHessianForm::Operator { .. }
7017 );
7018 let prefer_gradient_only = !analytic_outer_hessian_available;
7019
7020 let theta_dim = theta0.len();
7021 let psi_dim = theta_dim - rho_dim;
7022
7023 let cache_blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)> = best_specs
7025 .iter()
7026 .zip(boot_designs.iter())
7027 .zip(block_term_indices.iter())
7028 .map(|((spec, design), terms)| (spec.clone(), design.clone(), terms.clone()))
7029 .collect();
7030
7031 struct NBlockExactJointState<'d> {
7032 cache: ExactJointDesignCache<'d>,
7033 }
7034
7035 let mut state = NBlockExactJointState {
7036 cache: ExactJointDesignCache::new(data, cache_blocks, rho_dim, all_dims.clone())?,
7037 };
7038
7039 const KAPPA_PILOT_K: usize = 5_000;
7064 const KAPPA_POLISH_K: usize = 25_000;
7065 const KAPPA_POLISH_TRIGGER_N: usize = 100_000;
7066
7067 let n_total = data.nrows();
7068 let use_staged_kappa = outer_derivative_policy.should_use_staged_kappa(n_total);
7069 if use_staged_kappa {
7070 log::info!(
7071 "[KAPPA-STAGED] auto-engaging pilot+polish schedule: n={} pilot_k={} polish_k={}",
7072 n_total,
7073 KAPPA_PILOT_K,
7074 KAPPA_POLISH_K,
7075 );
7076 }
7077
7078 fn build_uniform_pilot_subsample(
7095 n_total: usize,
7096 k_target: usize,
7097 seed: u64,
7098 ) -> gam_problem::outer_subsample::OuterScoreSubsample {
7099 use gam_problem::outer_subsample::OuterScoreSubsample;
7100 let k = k_target.min(n_total);
7101 if k == 0 || n_total == 0 {
7102 return OuterScoreSubsample::from_uniform_inclusion_mask(Vec::new(), n_total, seed);
7103 }
7104 let mut mask: Vec<usize> = Vec::with_capacity(k);
7108 let mut state = seed.wrapping_add(0x9E3779B97F4A7C15);
7110 let splitmix = |s: &mut u64| -> u64 { gam_linalg::utils::splitmix64(s) };
7111 let mut taken = std::collections::HashSet::with_capacity(k);
7112 for j in (n_total - k)..n_total {
7113 let r = (splitmix(&mut state) % (j as u64 + 1)) as usize;
7114 if !taken.insert(r) {
7115 taken.insert(j);
7116 mask.push(j);
7117 } else {
7118 mask.push(r);
7119 }
7120 }
7121 mask.sort_unstable();
7122 mask.dedup();
7123 OuterScoreSubsample::from_uniform_inclusion_mask(mask, n_total, seed)
7124 }
7125
7126 let current_row_set: std::cell::RefCell<gam_problem::outer_subsample::RowSet> = if use_staged_kappa {
7127 let pilot = build_uniform_pilot_subsample(n_total, KAPPA_PILOT_K, n_total as u64);
7128 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::Subsample {
7129 rows: std::sync::Arc::clone(&pilot.rows),
7130 n_full: n_total,
7131 })
7132 } else {
7133 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::All)
7134 };
7135
7136 let exact_fn_cell = std::cell::RefCell::new(&mut exact_fn);
7137 let exact_efs_fn_cell = std::cell::RefCell::new(&mut exact_efs_fn);
7138
7139 use std::cell::Cell;
7154 let kphase_cost_calls: Cell<usize> = Cell::new(0);
7155 let kphase_cost_total_s: Cell<f64> = Cell::new(0.0);
7156 let kphase_eval_calls: Cell<usize> = Cell::new(0);
7157 let kphase_eval_total_s: Cell<f64> = Cell::new(0.0);
7158 let kphase_efs_calls: Cell<usize> = Cell::new(0);
7159 let kphase_efs_total_s: Cell<f64> = Cell::new(0.0);
7160 let kphase_optim_start = std::time::Instant::now();
7161 let kphase_log_kappa_dim = log_kappa_dim;
7162 let kphase_log_norms = |theta: &Array1<f64>| -> (f64, f64) {
7163 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
7164 let log_kappa_norm = if kphase_log_kappa_dim > 0 && theta.len() >= kphase_log_kappa_dim {
7165 let start = theta.len() - kphase_log_kappa_dim;
7166 theta.iter().skip(start).map(|v| v * v).sum::<f64>().sqrt()
7167 } else {
7168 0.0
7169 };
7170 (theta_norm, log_kappa_norm)
7171 };
7172
7173 use gam_solve::rho_optimizer::OuterEvalOrder;
7174 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7175
7176 let joint_p_cols: usize = boot_designs
7180 .iter()
7181 .map(|d| d.design.ncols())
7182 .sum::<usize>()
7183 .max(1);
7184
7185 let problem = exact_joint_multistart_outer_problem(
7186 &theta0,
7187 &lower,
7188 &upper,
7189 rho_dim,
7190 psi_dim,
7191 theta_dim,
7192 if analytic_joint_gradient_available {
7193 Derivative::Analytic
7194 } else {
7195 Derivative::Unavailable
7196 },
7197 if analytic_outer_hessian_available {
7198 DeclaredHessianForm::Either
7199 } else {
7200 DeclaredHessianForm::Unavailable
7201 },
7202 prefer_gradient_only,
7203 disable_fixed_point,
7204 seed_risk_profile,
7205 kappa_options.rel_tol.max(1e-6),
7206 kappa_options.max_outer_iter.max(1),
7207 Some(5.0),
7209 Some(kappa_options.log_step.clamp(0.25, 1.0)),
7211 screening_cap.clone(),
7212 Some((n_total, joint_p_cols)),
7215 block_specs
7218 .iter()
7219 .any(|s| !constant_curvature_term_indices(s).is_empty()),
7220 );
7221
7222 fn collect_specs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionSpec> {
7224 cache.specs().into_iter().cloned().collect()
7225 }
7226 fn collect_designs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionDesign> {
7227 cache.designs().into_iter().cloned().collect()
7228 }
7229
7230 let result = {
7231 let eval_outer = |ctx: &mut &mut NBlockExactJointState<'_>,
7232 theta: &Array1<f64>,
7233 order: OuterEvalOrder|
7234 -> Result<OuterEval, EstimationError> {
7235 if let Some((cost, grad, hess)) = ctx.cache.memoized_eval(theta) {
7236 let cached_satisfies_order = match order {
7237 OuterEvalOrder::Value => true,
7238 OuterEvalOrder::ValueAndGradient => true,
7239 OuterEvalOrder::ValueGradientHessian => hess.is_analytic(),
7240 };
7241 if cached_satisfies_order {
7242 if !cost.is_finite() {
7243 return Ok(OuterEval::infeasible(theta.len()));
7244 }
7245 if grad.iter().any(|v| !v.is_finite()) {
7258 return Ok(OuterEval::infeasible(theta.len()));
7259 }
7260 return Ok(OuterEval {
7261 cost,
7262 gradient: grad,
7263 hessian: hess,
7264 inner_beta_hint: None,
7265 });
7266 }
7267 }
7268 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7285 return Ok(OuterEval::infeasible(theta.len()));
7286 }
7287 if let Err(err) = ctx.cache.ensure_theta(theta) {
7288 log::warn!(
7289 "[OUTER] n-block exact-joint spatial: ensure_theta failed during gradient evaluation: {err}"
7290 );
7291 return Ok(OuterEval::infeasible(theta.len()));
7292 }
7293 let design_revision = Some(ctx.cache.design_revision());
7294 let specs = collect_specs(&ctx.cache);
7295 let designs = collect_designs(&ctx.cache);
7296 let clamped = outer_derivative_policy.order_for_evaluation(order);
7304 let need_hessian = matches!(clamped, OuterEvalOrder::ValueGradientHessian)
7305 && analytic_outer_hessian_available;
7306 let eval_mode = if need_hessian {
7307 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
7308 } else {
7309 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
7310 };
7311 let t0 = std::time::Instant::now();
7312 let result = {
7313 let row_set_borrow = current_row_set.borrow();
7314 (*exact_fn_cell.borrow_mut())(theta, &specs, &designs, eval_mode, &row_set_borrow)
7315 };
7316 let elapsed_s = t0.elapsed().as_secs_f64();
7317 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
7318 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
7319 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7320 log::info!(
7321 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7322 kphase_eval_calls.get(),
7323 order,
7324 design_revision,
7325 theta_norm,
7326 log_kappa_norm,
7327 elapsed_s,
7328 );
7329 match result {
7330 Ok((cost, grad, hess)) => {
7331 ctx.cache.store_eval((cost, grad.clone(), hess.clone()));
7332 if !cost.is_finite() {
7333 return Ok(OuterEval::infeasible(theta.len()));
7334 }
7335 if grad.iter().any(|v| !v.is_finite()) {
7348 return Ok(OuterEval::infeasible(theta.len()));
7349 }
7350 Ok(OuterEval {
7351 cost,
7352 gradient: grad,
7353 hessian: hess,
7354 inner_beta_hint: None,
7355 })
7356 }
7357 Err(err) => {
7358 log::warn!(
7359 "[OUTER] n-block exact-joint spatial: exact evaluation failed: {err}"
7360 );
7361 Ok(OuterEval::infeasible(theta.len()))
7362 }
7363 }
7364 };
7365
7366 let obj = problem.build_objective_with_eval_order(
7367 &mut state,
7368 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7369 if let Some(cost) = ctx.cache.memoized_cost(theta) {
7370 return Ok(cost);
7371 }
7372 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7380 return Ok(f64::INFINITY);
7381 }
7382 if let Err(err) = ctx.cache.ensure_theta(theta) {
7383 log::warn!(
7384 "[OUTER] n-block exact-joint spatial: ensure_theta failed during cost evaluation: {err}"
7385 );
7386 return Ok(f64::INFINITY);
7387 }
7388 let design_revision = Some(ctx.cache.design_revision());
7389 let specs = collect_specs(&ctx.cache);
7390 let designs = collect_designs(&ctx.cache);
7391 let t0 = std::time::Instant::now();
7398 let result = {
7399 let row_set_borrow = current_row_set.borrow();
7400 (*exact_fn_cell.borrow_mut())(
7401 theta,
7402 &specs,
7403 &designs,
7404 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
7405 &row_set_borrow,
7406 )
7407 };
7408 let elapsed_s = t0.elapsed().as_secs_f64();
7409 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
7410 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
7411 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7412 log::info!(
7413 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7414 kphase_cost_calls.get(),
7415 design_revision,
7416 theta_norm,
7417 log_kappa_norm,
7418 elapsed_s,
7419 );
7420 match result {
7421 Ok((cost, _grad, _hess)) => {
7422 ctx.cache.store_cost_only(theta, cost);
7428 Ok(cost)
7429 }
7430 Err(err) => {
7431 log::warn!(
7432 "[OUTER] n-block exact-joint spatial: exact cost evaluation failed: {err}"
7433 );
7434 Ok(f64::INFINITY)
7435 }
7436 }
7437 },
7438 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7439 eval_outer(
7440 ctx,
7441 theta,
7442 if analytic_outer_hessian_available {
7443 OuterEvalOrder::ValueGradientHessian
7444 } else {
7445 OuterEvalOrder::ValueAndGradient
7446 },
7447 )
7448 },
7449 |ctx: &mut &mut NBlockExactJointState<'_>,
7450 theta: &Array1<f64>,
7451 order: OuterEvalOrder| { eval_outer(ctx, theta, order) },
7452 None::<fn(&mut &mut NBlockExactJointState<'_>)>,
7453 Some(
7454 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7455 ctx.cache
7456 .ensure_theta(theta)
7457 .map_err(EstimationError::InvalidInput)?;
7458 let design_revision = Some(ctx.cache.design_revision());
7459 let specs = collect_specs(&ctx.cache);
7460 let designs = collect_designs(&ctx.cache);
7461 let t0 = std::time::Instant::now();
7462 let eval_result = (*exact_efs_fn_cell.borrow_mut())(
7463 theta,
7464 &specs,
7465 &designs,
7466 );
7467 let elapsed_s = t0.elapsed().as_secs_f64();
7468 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
7469 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
7470 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7471 log::info!(
7472 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7473 kphase_efs_calls.get(),
7474 design_revision,
7475 theta_norm,
7476 log_kappa_norm,
7477 elapsed_s,
7478 );
7479 let eval = eval_result.map_err(EstimationError::RemlOptimizationFailed)?;
7480 Ok(eval)
7481 },
7482 ),
7483 );
7484 let mut obj = obj.with_seed_inner_state(
7485 move |_ctx: &mut &mut NBlockExactJointState<'_>, beta: &Array1<f64>| {
7486 (seed_inner_beta_fn)(beta)
7487 },
7488 );
7489
7490 match problem.run(&mut obj, "n-block exact-joint spatial") {
7491 Ok(result) => result,
7492 Err(e) => {
7493 let message = e.to_string();
7494 if kappa_phase_failure_is_fixed_kappa_recoverable(&message) {
7514 drop(obj);
7515 log::warn!(
7516 "[KAPPA-PHASE] length-scale optimization could not validate any seed \
7517 ({message}); falling back to a FIXED bootstrap κ (skipping κ \
7518 optimization) and fitting there — a real model at the initial \
7519 length-scale rather than raising (gam#787/#860)."
7520 );
7521 let (designs, resolved_specs) =
7522 build_term_collection_designs_and_freeze_joint(data, block_specs).map_err(
7523 |build_err| {
7524 format!(
7525 "fixed-κ fallback failed to build and freeze joint block \
7526 designs after κ optimization could not validate a seed \
7527 ({message}): {build_err}"
7528 )
7529 },
7530 )?;
7531 let fixed_theta0 = joint_setup.theta0();
7532 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7533 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7534 let fit = fit_fn(&fixed_theta0, &spec_refs, &design_refs)?;
7535 return Ok(SpatialLengthScaleOptimizationResult {
7536 resolved_specs,
7537 designs,
7538 fit,
7539 timing: None,
7540 });
7541 }
7542 return Err(message);
7543 }
7544 }
7545 }; let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
7555 log::info!(
7556 "[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} optim_total_s={:.4}",
7557 kphase_log_kappa_dim,
7558 kphase_cost_calls.get(),
7559 kphase_cost_total_s.get(),
7560 kphase_eval_calls.get(),
7561 kphase_eval_total_s.get(),
7562 kphase_efs_calls.get(),
7563 kphase_efs_total_s.get(),
7564 kphase_total_s,
7565 );
7566 let timing = SpatialLengthScaleOptimizationTiming {
7567 log_kappa_dim: kphase_log_kappa_dim,
7568 cost_calls: kphase_cost_calls.get(),
7569 cost_total_s: kphase_cost_total_s.get(),
7570 eval_calls: kphase_eval_calls.get(),
7571 eval_total_s: kphase_eval_total_s.get(),
7572 efs_calls: kphase_efs_calls.get(),
7573 efs_total_s: kphase_efs_total_s.get(),
7574 slow_path_resets: 0,
7575 design_revision_delta: 0,
7576 nfree_miss_shape: 0,
7577 nfree_miss_value: 0,
7578 nfree_miss_gradient: 0,
7579 nfree_miss_penalty: 0,
7580 nfree_miss_revision: 0,
7581 nfree_miss_second_order: 0,
7582 nfree_miss_other: 0,
7583 optim_total_s: kphase_total_s,
7584 };
7585
7586 let theta_star = result.rho;
7587
7588 if use_staged_kappa && n_total >= KAPPA_POLISH_TRIGGER_N {
7605 let polish = build_uniform_pilot_subsample(
7606 n_total,
7607 KAPPA_POLISH_K,
7608 (n_total as u64).wrapping_add(0xA5A5A5A5),
7609 );
7610 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::Subsample {
7611 rows: std::sync::Arc::clone(&polish.rows),
7612 n_full: n_total,
7613 };
7614 log::info!(
7615 "[KAPPA-STAGED] rotating to polish subsample: k={} at theta_star",
7616 polish.rows.len(),
7617 );
7618 state.cache.ensure_theta(&theta_star)?;
7622 let (polish_cost, polish_grad, _) = {
7623 let specs = collect_specs(&state.cache);
7624 let designs = collect_designs(&state.cache);
7625 let row_set_borrow = current_row_set.borrow();
7626 exact_fn(
7627 &theta_star,
7628 &specs,
7629 &designs,
7630 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
7631 &row_set_borrow,
7632 )?
7633 };
7634 if !polish_cost.is_finite() || polish_grad.iter().any(|value| !value.is_finite()) {
7635 return Err(
7636 "polish subsample exact-joint evaluation produced non-finite objective pieces"
7637 .to_string(),
7638 );
7639 }
7640 }
7641 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::All;
7642 if use_staged_kappa {
7643 log::info!(
7644 "[KAPPA-STAGED] rotating to full data for final coefficient fit (n={})",
7645 n_total,
7646 );
7647 }
7648
7649 state.cache.ensure_theta(&theta_star)?;
7650
7651 let resolved_specs: Vec<TermCollectionSpec> = collect_specs(&state.cache);
7652 let designs: Vec<TermCollectionDesign> = collect_designs(&state.cache);
7653
7654 let fit = fit_fn(&theta_star, &resolved_specs, &designs)?;
7655
7656 for spec in &resolved_specs {
7657 log_spatial_aniso_scales(spec);
7658 }
7659
7660 Ok(SpatialLengthScaleOptimizationResult {
7661 resolved_specs,
7662 designs,
7663 fit,
7664 timing: Some(timing),
7665 })
7666}
7667
7668fn try_exact_joint_latent_coord_optimization(
7669 data: ArrayView2<'_, f64>,
7670 y: ArrayView1<'_, f64>,
7671 weights: ArrayView1<'_, f64>,
7672 offset: ArrayView1<'_, f64>,
7673 resolvedspec: &TermCollectionSpec,
7674 best: &FittedTermCollection,
7675 family: LikelihoodSpec,
7676 options: &FitOptions,
7677 latent: &StandardLatentCoordConfig,
7678) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7679 use gam_solve::rho_optimizer::OuterEvalOrder;
7680 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7681
7682 let rho_dim = best.fit.lambdas.len();
7683 let latent_flat_dim = latent.values.len();
7684 if latent_flat_dim == 0 {
7685 crate::bail_invalid_estim!(
7686 "latent-coordinate optimization requires a non-empty latent block"
7687 );
7688 }
7689 let direct_hypers =
7690 latent_coord_initial_direct_hypers(latent.values.id_mode(), latent.values.latent_dim())?;
7691 let analytic_rho_count = latent
7692 .analytic_penalties
7693 .as_ref()
7694 .map_or(0, |registry| registry.total_rho_count());
7695 let latent_coord_ext_dim = latent_flat_dim + analytic_rho_count + direct_hypers.len();
7696
7697 let mut theta0 = Array1::<f64>::zeros(rho_dim + latent_coord_ext_dim);
7698 theta0
7699 .slice_mut(s![..rho_dim])
7700 .assign(&best.fit.lambdas.mapv(f64::ln));
7701 theta0
7702 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
7703 .assign(latent.values.as_flat());
7704 if !direct_hypers.is_empty() {
7705 let direct_start = rho_dim + latent_flat_dim + analytic_rho_count;
7706 theta0
7707 .slice_mut(s![direct_start..direct_start + direct_hypers.len()])
7708 .assign(&direct_hypers);
7709 }
7710
7711 let mut lower = Array1::<f64>::from_elem(theta0.len(), -12.0);
7712 let mut upper = Array1::<f64>::from_elem(theta0.len(), 12.0);
7713 let latent_bound = latent
7714 .values
7715 .as_flat()
7716 .iter()
7717 .fold(1.0_f64, |acc, &v| acc.max(v.abs()))
7718 + 10.0;
7719 for axis in rho_dim..rho_dim + latent_flat_dim {
7720 lower[axis] = -latent_bound;
7721 upper[axis] = latent_bound;
7722 }
7723
7724 struct LatentJointContext<'d> {
7725 rho_dim: usize,
7726 cache: SingleBlockLatentCoordDesignCache,
7727 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
7728 }
7729
7730 impl<'d> LatentJointContext<'d> {
7731 fn eval_full(
7732 &mut self,
7733 theta: &Array1<f64>,
7734 order: OuterEvalOrder,
7735 ) -> Result<
7736 (
7737 f64,
7738 Array1<f64>,
7739 gam_problem::HessianResult,
7740 ),
7741 EstimationError,
7742 > {
7743 if let Some(eval) = self.cache.memoized_eval(theta) {
7744 return Ok(eval);
7745 }
7746 self.cache
7747 .ensure_theta(theta)
7748 .map_err(EstimationError::InvalidInput)?;
7749 let hyper_dirs = self
7750 .cache
7751 .hyper_dirs()
7752 .map_err(EstimationError::InvalidInput)?;
7753 let design_revision = Some(self.cache.design_revision());
7754 let registry_for_key = self.cache.analytic_penalties();
7755 self.evaluator
7756 .set_analytic_penalty_registry(registry_for_key.as_deref());
7757 let mut eval = evaluate_joint_reml_outer_eval_at_theta(
7758 &mut self.evaluator,
7759 self.cache.design(),
7760 theta,
7761 self.rho_dim,
7762 hyper_dirs,
7763 None,
7764 order,
7765 design_revision,
7766 )?;
7767 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
7768 if let Some(registry) = registry_for_key {
7769 let mut registry = registry.as_ref().clone();
7770 registry.apply_weight_schedules(
7771 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
7772 );
7773 add_analytic_penalty_objective_to_eval(
7774 theta,
7775 self.rho_dim,
7776 latent.as_ref(),
7777 ®istry,
7778 &mut eval,
7779 )?;
7780 }
7781 add_latent_id_objective_to_eval(
7782 theta,
7783 self.rho_dim,
7784 self.cache.analytic_penalty_rho_count(),
7785 latent.as_ref(),
7786 &mut eval,
7787 )?;
7788 self.cache.store_eval(eval.clone());
7789 Ok(eval)
7790 }
7791
7792 fn eval_efs(
7793 &mut self,
7794 theta: &Array1<f64>,
7795 ) -> Result<gam_problem::EfsEval, EstimationError> {
7796 self.cache
7797 .ensure_theta(theta)
7798 .map_err(EstimationError::InvalidInput)?;
7799 let hyper_dirs = self
7800 .cache
7801 .hyper_dirs()
7802 .map_err(EstimationError::InvalidInput)?;
7803 let registry_for_key = self.cache.analytic_penalties();
7804 self.evaluator
7805 .set_analytic_penalty_registry(registry_for_key.as_deref());
7806 let mut efs = evaluate_joint_reml_efs_at_theta(
7807 &mut self.evaluator,
7808 self.cache.design(),
7809 theta,
7810 self.rho_dim,
7811 hyper_dirs,
7812 None,
7813 Some(self.cache.design_revision()),
7814 )?;
7815 if let Some(registry) = registry_for_key {
7816 let mut registry = registry.as_ref().clone();
7817 registry.apply_weight_schedules(
7818 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
7819 );
7820 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
7821 let contribution = analytic_penalty_objective_contribution(
7822 theta,
7823 self.rho_dim,
7824 latent.as_ref(),
7825 ®istry,
7826 )?;
7827 efs.cost += contribution.cost;
7828 if let (Some(psi_gradient), Some(psi_indices)) =
7829 (efs.psi_gradient.as_mut(), efs.psi_indices.as_ref())
7830 {
7831 if psi_gradient.len() != psi_indices.len() {
7832 crate::bail_invalid_estim!(
7833 "latent-coordinate analytic penalty EFS psi gradient length mismatch: gradient={}, indices={}",
7834 psi_gradient.len(),
7835 psi_indices.len()
7836 );
7837 }
7838 for (local_idx, &theta_idx) in psi_indices.iter().enumerate() {
7839 psi_gradient[local_idx] += contribution.gradient[theta_idx];
7840 }
7841 }
7842 }
7843 Ok(efs)
7844 }
7845
7846 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
7847 if let Some(cost) = self.cache.memoized_cost(theta) {
7848 return cost;
7849 }
7850 if self.cache.ensure_theta(theta).is_err() {
7851 return f64::INFINITY;
7852 }
7853 let design_revision = Some(self.cache.design_revision());
7854 let registry_for_key = self.cache.analytic_penalties();
7855 self.evaluator
7856 .set_analytic_penalty_registry(registry_for_key.as_deref());
7857 let result = {
7858 let design = self.cache.design();
7859 self.evaluator.evaluate_cost_only(
7860 &design.design,
7861 &design.penalties,
7862 &design.nullspace_dims,
7863 design.linear_constraints.clone(),
7864 theta,
7865 self.rho_dim,
7866 None,
7867 "latent-coordinate-joint cost-only",
7868 design_revision,
7869 )
7870 };
7871 match result {
7872 Ok(cost) => {
7873 let latent = match self.cache.latent() {
7874 Ok(latent) => latent,
7875 Err(_) => return f64::INFINITY,
7876 };
7877 let contribution = match latent_id_objective_contribution(
7878 theta,
7879 self.rho_dim,
7880 self.cache.analytic_penalty_rho_count(),
7881 latent.as_ref(),
7882 ) {
7883 Ok(contribution) => contribution,
7884 Err(_) => return f64::INFINITY,
7885 };
7886 let cost = cost + contribution.cost;
7887 let cost = if let Some(registry) = registry_for_key {
7888 let mut registry = registry.as_ref().clone();
7889 registry.apply_weight_schedules(
7890 gam_solve::estimate::reml::outer_eval::current_outer_iter()
7891 as usize,
7892 );
7893 match analytic_penalty_objective_contribution(
7894 theta,
7895 self.rho_dim,
7896 latent.as_ref(),
7897 ®istry,
7898 ) {
7899 Ok(contribution) => cost + contribution.cost,
7900 Err(_) => return f64::INFINITY,
7901 }
7902 } else {
7903 cost
7904 };
7905 self.cache.store_cost(cost);
7906 cost
7907 }
7908 Err(_) => f64::INFINITY,
7909 }
7910 }
7911 }
7912
7913 let mut ctx = LatentJointContext {
7914 rho_dim,
7915 cache: SingleBlockLatentCoordDesignCache::new(
7916 data.to_owned(),
7917 resolvedspec.clone(),
7918 best.design.clone(),
7919 latent,
7920 rho_dim,
7921 )
7922 .map_err(EstimationError::InvalidInput)?,
7923 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
7924 y,
7925 weights,
7926 &best.design.design,
7927 offset,
7928 &best.design.penalties,
7929 &external_opts_for_design(&family, &best.design, options),
7930 "latent-coordinate-joint",
7931 )?,
7932 };
7933 let registry_for_key = ctx.cache.analytic_penalties();
7934 ctx.evaluator
7935 .set_analytic_penalty_registry(registry_for_key.as_deref());
7936 ctx.evaluator
7937 .set_persistent_latent_values_fingerprint(latent.values.id_mode());
7938 if let Some(cached_t) = ctx
7939 .evaluator
7940 .load_persistent_latent_values(latent.values.n_obs(), latent.values.latent_dim())
7941 {
7942 let cached_t: Array2<f64> = cached_t;
7943 for (dst, src) in theta0
7944 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
7945 .iter_mut()
7946 .zip(cached_t.iter())
7947 {
7948 *dst = *src;
7949 }
7950 }
7951
7952 let problem = exact_joint_multistart_outer_problem(
7953 &theta0,
7954 &lower,
7955 &upper,
7956 rho_dim,
7957 latent_coord_ext_dim,
7958 theta0.len(),
7959 Derivative::Analytic,
7960 DeclaredHessianForm::Unavailable,
7961 false,
7962 false,
7963 seed_risk_profile_for_likelihood_family(&family),
7964 options.tol,
7965 options.max_iter.max(1),
7966 Some(5.0),
7967 Some(0.5),
7968 None,
7969 Some((data.nrows(), best.design.design.ncols().max(1))),
7972 !constant_curvature_term_indices(resolvedspec).is_empty(),
7975 );
7976
7977 let eval_outer = |ctx: &mut &mut LatentJointContext<'_>,
7978 theta: &Array1<f64>,
7979 order: OuterEvalOrder|
7980 -> Result<OuterEval, EstimationError> {
7981 let (cost, gradient, hessian) = ctx.eval_full(theta, order)?;
7982 Ok(OuterEval {
7983 cost,
7984 gradient,
7985 hessian,
7986 inner_beta_hint: None,
7987 })
7988 };
7989
7990 let result = {
7991 let mut obj = problem.build_objective_with_eval_order(
7992 &mut ctx,
7993 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
7994 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| {
7995 eval_outer(ctx, theta, OuterEvalOrder::ValueAndGradient)
7996 },
7997 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
7998 eval_outer(ctx, theta, order)
7999 },
8000 Some(|ctx: &mut &mut LatentJointContext<'_>| {
8001 ctx.cache.reset();
8002 }),
8003 Some(|ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
8004 );
8005
8006 problem
8007 .run(&mut obj, "latent-coordinate joint REML")
8008 .map_err(|e| {
8009 EstimationError::InvalidInput(format!(
8010 "latent-coordinate joint optimization failed after exhausting strategy fallbacks: {e}"
8011 ))
8012 })?
8013 };
8014 if !result.converged {
8015 crate::bail_invalid_estim!(
8016 "latent-coordinate joint optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
8017 result.iterations,
8018 result.final_value,
8019 result.final_grad_norm_report(),
8020 );
8021 }
8022
8023 let theta_star = result.rho;
8024 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
8025 let mut final_data = data.to_owned();
8026 let flat_t = theta_star
8027 .slice(s![rho_dim..rho_dim + latent_flat_dim])
8028 .to_owned();
8029 let mut fitted_latent_values =
8030 Array2::<f64>::zeros((latent.values.n_obs(), latent.values.latent_dim()));
8031 for n in 0..latent.values.n_obs() {
8032 for axis in 0..latent.values.latent_dim() {
8033 let value = flat_t[n * latent.values.latent_dim() + axis];
8034 fitted_latent_values[[n, axis]] = value;
8035 final_data[[n, latent.feature_cols[axis]]] = value;
8036 }
8037 }
8038 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
8039 final_data.view(),
8040 y,
8041 weights,
8042 offset,
8043 resolvedspec,
8044 rho_star.as_slice(),
8045 family,
8046 options,
8047 )?;
8048 ctx.evaluator
8049 .store_persistent_latent_values(&fitted_latent_values);
8050 let mut fit = optimized.fit;
8051 fit.reml_score = result.final_value;
8052 fit.penalized_objective = result.final_value;
8053 Ok(FittedTermCollectionWithSpec {
8054 fit,
8055 design: optimized.design,
8056 resolvedspec: resolvedspec.clone(),
8057 adaptive_diagnostics: optimized.adaptive_diagnostics,
8058 kappa_timing: None,
8059 })
8060}
8061
8062pub fn fit_term_collectionwith_latent_coord_optimization(
8063 data: ArrayView2<'_, f64>,
8064 y: Array1<f64>,
8065 weights: Array1<f64>,
8066 offset: Array1<f64>,
8067 spec: &TermCollectionSpec,
8068 latent: &StandardLatentCoordConfig,
8069 family: LikelihoodSpec,
8070 options: &FitOptions,
8071) -> Result<FittedTermCollectionWithSpec, EstimationError> {
8072 let n = data.nrows();
8073 if !(y.len() == n && weights.len() == n && offset.len() == n) {
8074 crate::bail_invalid_estim!(
8075 "fit_term_collectionwith_latent_coord_optimization row mismatch: n={}, y={}, weights={}, offset={}",
8076 n,
8077 y.len(),
8078 weights.len(),
8079 offset.len()
8080 );
8081 }
8082 let best = fit_term_collection_forspec(
8083 data,
8084 y.view(),
8085 weights.view(),
8086 offset.view(),
8087 spec,
8088 family.clone(),
8089 options,
8090 )?;
8091 let resolvedspec = freeze_term_collection_from_design(spec, &best.design)?;
8092 try_exact_joint_latent_coord_optimization(
8093 data,
8094 y.view(),
8095 weights.view(),
8096 offset.view(),
8097 &resolvedspec,
8098 &best,
8099 family,
8100 options,
8101 latent,
8102 )
8103}
8104
8105pub fn fit_term_collectionwith_spatial_length_scale_optimization(
8106 data: ArrayView2<'_, f64>,
8107 y: Array1<f64>,
8108 weights: Array1<f64>,
8109 offset: Array1<f64>,
8110 spec: &TermCollectionSpec,
8111 family: LikelihoodSpec,
8112 options: &FitOptions,
8113 kappa_options: &SpatialLengthScaleOptimizationOptions,
8114) -> Result<FittedTermCollectionWithSpec, EstimationError> {
8115 let mut resolvedspec = spec.clone();
8131 let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8132 let n = data.nrows();
8133 if !(y.len() == n && weights.len() == n && offset.len() == n) {
8134 crate::bail_invalid_estim!(
8135 "fit_term_collectionwith_spatial_length_scale_optimization row mismatch: n={}, y={}, weights={}, offset={}",
8136 n,
8137 y.len(),
8138 weights.len(),
8139 offset.len()
8140 );
8141 }
8142 if !kappa_options.enabled || spatial_terms.is_empty() {
8143 let out = fit_term_collection_forspec(
8144 data,
8145 y.view(),
8146 weights.view(),
8147 offset.view(),
8148 &resolvedspec,
8149 family,
8150 options,
8151 )?;
8152 let resolvedspec = freeze_term_collection_from_design(&resolvedspec, &out.design)?;
8153 return Ok(FittedTermCollectionWithSpec {
8154 fit: out.fit,
8155 design: out.design,
8156 resolvedspec,
8157 adaptive_diagnostics: out.adaptive_diagnostics,
8158 kappa_timing: None,
8159 });
8160 }
8161 if kappa_options.max_outer_iter == 0 {
8162 crate::bail_invalid_estim!("spatial kappa optimization requires max_outer_iter >= 1");
8163 }
8164 if !(kappa_options.log_step.is_finite() && kappa_options.log_step > 0.0) {
8165 crate::bail_invalid_estim!("spatial kappa optimization requires log_step > 0");
8166 }
8167 if !(kappa_options.min_length_scale.is_finite()
8168 && kappa_options.max_length_scale.is_finite()
8169 && kappa_options.min_length_scale > 0.0
8170 && kappa_options.max_length_scale >= kappa_options.min_length_scale)
8171 {
8172 crate::bail_invalid_estim!(
8173 "spatial kappa optimization requires valid positive length_scale bounds"
8174 );
8175 }
8176
8177 let pilot_threshold = kappa_options.pilot_subsample_threshold;
8178 if pilot_threshold > 0 && n > pilot_threshold * 2 {
8179 log::info!(
8180 "[spatial-kappa] n={n} exceeds pilot threshold {}; using pilot geometry only for deterministic anisotropy initialization",
8181 pilot_threshold * 2,
8182 );
8183 apply_spatial_anisotropy_pilot_initializer(
8184 data,
8185 &mut resolvedspec,
8186 &spatial_terms,
8187 pilot_threshold,
8188 kappa_options,
8189 );
8190 }
8191
8192 apply_response_aware_anisotropy_seed(data, y.view(), &mut resolvedspec, &spatial_terms);
8201
8202 for term_idx in constant_curvature_term_indices(&resolvedspec) {
8220 if let Some(kappa_seed) =
8221 select_constant_curvature_kappa_sign_seed(data, y.view(), &resolvedspec, term_idx)
8222 && kappa_seed != 0.0
8223 && let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) =
8224 resolvedspec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis)
8225 {
8226 log::info!(
8227 "[#1464] pinned CC term {term_idx} baseline κ to κ-fair scan value {kappa_seed} \
8228 (raw profiled REML is sign-blind; scan is authoritative for the sign)"
8229 );
8230 cc.kappa = kappa_seed;
8231 }
8232 }
8233
8234 let baseline_options = superseded_fit_options(options);
8235 let mut best = fit_term_collection_forspec(
8236 data,
8237 y.view(),
8238 weights.view(),
8239 offset.view(),
8240 &resolvedspec,
8241 family.clone(),
8242 &baseline_options,
8243 )?;
8244 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8245 let mut spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8255 sync_aniso_contrasts_from_metadata(&mut resolvedspec, &best.design.smooth);
8259 let mut prescan_improved = false;
8266 if !spatial_terms.is_empty() {
8267 let baseline_score = fit_score(&best.fit);
8268 let range_overrides = prescan_isotropic_spatial_range_seed(
8269 data,
8270 y.view(),
8271 weights.view(),
8272 offset.view(),
8273 &resolvedspec,
8274 baseline_score,
8275 &family,
8276 &baseline_options,
8277 kappa_options,
8278 &spatial_terms,
8279 )?;
8280 if !range_overrides.is_empty() {
8281 prescan_improved = true;
8282 for (term_idx, length_scale) in range_overrides {
8283 set_spatial_length_scale(&mut resolvedspec, term_idx, length_scale)?;
8284 }
8285 best = fit_term_collection_forspec(
8289 data,
8290 y.view(),
8291 weights.view(),
8292 offset.view(),
8293 &resolvedspec,
8294 family.clone(),
8295 &baseline_options,
8296 )?;
8297 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8298 spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8302 }
8303 }
8304 if spatial_terms.is_empty() {
8305 let fitted = fit_term_collection_forspecwith_heuristic_lambdas(
8306 data,
8307 y.view(),
8308 weights.view(),
8309 offset.view(),
8310 &resolvedspec,
8311 best.fit.lambdas.as_slice(),
8312 family,
8313 options,
8314 )?;
8315 return Ok(FittedTermCollectionWithSpec {
8316 fit: fitted.fit,
8317 design: fitted.design,
8318 resolvedspec,
8319 adaptive_diagnostics: fitted.adaptive_diagnostics,
8320 kappa_timing: None,
8321 });
8322 }
8323 let initial_score = fit_score(&best.fit);
8324 if !initial_score.is_finite() {
8325 log::debug!("[spatial-kappa] initial profiled score is non-finite");
8326 }
8327 let seed_length_scales: Vec<(usize, f64)> = spatial_terms
8334 .iter()
8335 .filter_map(|&t| get_spatial_length_scale(&resolvedspec, t).map(|ls| (t, ls)))
8336 .collect();
8337 let joint_result = try_exact_joint_spatial_length_scale_optimization(
8338 data,
8339 y.view(),
8340 weights.view(),
8341 offset.view(),
8342 &resolvedspec,
8343 &best,
8344 family.clone(),
8345 options,
8346 kappa_options,
8347 &spatial_terms,
8348 )
8349 .map(|opt| {
8350 opt.map(|fit| {
8351 let score = fit_score(&fit.fit);
8352 (fit, score)
8353 })
8354 });
8355 let exact_joint = if prescan_improved && !matches!(joint_result, Ok(Some(_))) {
8365 let reason = match &joint_result {
8366 Err(e) => format!("error: {e}"),
8367 _ => "unavailable".to_string(),
8368 };
8369 log::info!(
8370 "[spatial-kappa] #1074 joint polish yielded no usable candidate \
8371 ({reason}); returning the multi-start pre-scan geometry (REML {initial_score:.5})"
8372 );
8373 FittedTermCollectionWithSpec {
8374 fit: best.fit,
8375 design: best.design,
8376 resolvedspec,
8377 adaptive_diagnostics: best.adaptive_diagnostics,
8378 kappa_timing: None,
8379 }
8380 } else {
8381 require_successful_spatial_optimization_result(initial_score, joint_result)?
8382 };
8383
8384 let exact_joint = {
8411 let primary_score = fit_score(&exact_joint.fit);
8412 let improved = primary_score.is_finite()
8413 && initial_score.is_finite()
8414 && primary_score < initial_score - 1e-7 * initial_score.abs().max(1.0);
8415 let base_spec = exact_joint.resolvedspec.clone();
8420 let geometry_unchanged = !seed_length_scales.is_empty()
8423 && seed_length_scales.iter().all(|&(t, seed_ls)| {
8424 get_spatial_length_scale(&base_spec, t)
8425 .is_some_and(|ls| (ls - seed_ls).abs() <= 1e-6 * seed_ls.abs().max(1.0))
8426 });
8427 let eligible = !improved
8428 && geometry_unchanged
8429 && !has_aniso_terms(&base_spec, &spatial_terms)
8430 && constant_curvature_term_indices(&base_spec).is_empty()
8431 && spatial_terms
8432 .iter()
8433 .any(|&t| get_spatial_length_scale(&base_spec, t).is_some());
8434 if eligible {
8435 log::info!(
8436 "[spatial-kappa] #1688 joint solve stalled at REML {primary_score:.5} \
8437 (no improvement over baseline {initial_score:.5}); running ψ-window \
8438 multistart rescue across {} seeds",
8439 JOINT_RESTART_WINDOW_FRACTIONS.len(),
8440 );
8441 let mut best_fit = exact_joint;
8442 let mut best_score = primary_score;
8444 for &fraction in JOINT_RESTART_WINDOW_FRACTIONS.iter() {
8445 match joint_solve_from_window_fraction(
8446 data,
8447 y.view(),
8448 weights.view(),
8449 offset.view(),
8450 &base_spec,
8451 &spatial_terms,
8452 fraction,
8453 &family,
8454 options,
8455 &baseline_options,
8456 kappa_options,
8457 ) {
8458 Ok(Some((candidate, score))) => {
8459 if score.is_finite()
8460 && (!best_score.is_finite()
8461 || score < best_score - 1e-7 * best_score.abs().max(1.0))
8462 {
8463 log::info!(
8464 "[spatial-kappa] #1688 multistart seed (ψ-window {fraction:.2}) \
8465 reached REML {score:.5}, improving on {best_score:.5}",
8466 );
8467 best_score = score;
8468 best_fit = candidate;
8469 }
8470 }
8471 Ok(None) => {}
8473 Err(e) => {
8477 log::info!(
8478 "[spatial-kappa] #1688 multistart seed (ψ-window {fraction:.2}) \
8479 failed ({e}); skipping"
8480 );
8481 }
8482 }
8483 }
8484 best_fit
8485 } else {
8486 exact_joint
8487 }
8488 };
8489
8490 log_spatial_aniso_scales(&exact_joint.resolvedspec);
8491 Ok(exact_joint)
8492}
8493
/// Profile-likelihood inference for the curvature κ of one constant-curvature
/// smooth term, as returned by `curvature_inference_forspec`.
#[derive(Clone, Debug)]
pub struct CurvatureInference {
    /// Index of the constant-curvature term in the spec's `smooth_terms`.
    pub term_idx: usize,
    /// κ read from the resolved spec; the point at which the profile is centred.
    pub kappa_hat: f64,
    /// Profile confidence interval for κ produced by `profile_ci_walk`.
    pub ci: gam_geometry::curvature_estimand::KappaProfileCi,
    /// Likelihood-ratio flatness test produced by `flatness_lr_test`
    /// (NOTE(review): presumably tests κ = 0 against `kappa_hat`; confirm).
    pub flatness: gam_geometry::curvature_estimand::FlatnessTest,
}
8513
8514pub fn curvature_inference_forspec(
8532 data: ArrayView2<'_, f64>,
8533 y: ArrayView1<'_, f64>,
8534 weights: ArrayView1<'_, f64>,
8535 offset: ArrayView1<'_, f64>,
8536 resolvedspec: &TermCollectionSpec,
8537 term_idx: usize,
8538 family: LikelihoodSpec,
8539 options: &FitOptions,
8540 level: f64,
8541) -> Result<CurvatureInference, EstimationError> {
8542 let kappa_hat = get_constant_curvature_kappa(resolvedspec, term_idx).ok_or_else(|| {
8543 EstimationError::InvalidInput(format!(
8544 "curvature_inference_forspec: term {term_idx} is not a constant-curvature smooth"
8545 ))
8546 })?;
8547 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
8548
8549 let cc_fair_inputs: Option<(Array2<f64>, gam_terms::basis::ConstantCurvatureBasisSpec)> =
8574 if kappa_hat < 0.0 {
8575 match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
8576 Some(SmoothBasisSpec::ConstantCurvature {
8577 feature_cols, spec, ..
8578 }) => select_columns(data, feature_cols)
8579 .ok()
8580 .map(|x| (x, spec.clone())),
8581 _ => None,
8582 }
8583 } else {
8584 None
8585 };
8586
8587 let v_p_cache: std::cell::RefCell<std::collections::HashMap<u64, f64>> =
8592 std::cell::RefCell::new(std::collections::HashMap::new());
8593 let v_p = |kappa: f64| -> Result<f64, String> {
8594 if !kappa.is_finite() {
8595 return Err(format!("V_p probed a non-finite κ = {kappa}"));
8596 }
8597 let key = kappa.to_bits();
8598 if let Some(&cached) = v_p_cache.borrow().get(&key) {
8599 return Ok(cached);
8600 }
8601 let score = if let Some((x_term, base_spec)) = &cc_fair_inputs {
8602 let mut probe_spec = base_spec.clone();
8603 probe_spec.kappa = kappa;
8604 gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec)
8605 .map_err(|e| format!("κ-fair criterion at κ={kappa} failed: {e}"))?
8606 } else {
8607 fixed_kappa_profiled_reml_score(
8608 data,
8609 y,
8610 weights,
8611 offset,
8612 resolvedspec,
8613 term_idx,
8614 kappa,
8615 family.clone(),
8616 options,
8617 )
8618 .map_err(|e| format!("V_p fixed-κ fit at κ={kappa} failed: {e}"))?
8619 };
8620 v_p_cache.borrow_mut().insert(key, score);
8621 Ok(score)
8622 };
8623
8624 let h = (1e-3 * (kappa_max - kappa_min)).max(1e-4);
8628 let v_pp = match (v_p(kappa_hat + h), v_p(kappa_hat), v_p(kappa_hat - h)) {
8629 (Ok(vp), Ok(v0), Ok(vm)) => (vp - 2.0 * v0 + vm) / (h * h),
8630 _ => f64::NAN, };
8632
8633 let ci = gam_geometry::curvature_estimand::profile_ci_walk(
8634 &v_p, kappa_hat, v_pp, kappa_min, kappa_max, level, 1e-4,
8635 )
8636 .map_err(EstimationError::InvalidInput)?;
8637 let flatness = gam_geometry::curvature_estimand::flatness_lr_test(&v_p, kappa_hat)
8638 .map_err(EstimationError::InvalidInput)?;
8639
8640 Ok(CurvatureInference {
8641 term_idx,
8642 kappa_hat,
8643 ci,
8644 flatness,
8645 })
8646}
8647
/// Which small-sample correction was applied to a smooth-term likelihood-ratio
/// statistic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmoothLrCorrection {
    /// Lawley/Bartlett factor derived from a mean shift that also accounts for
    /// variability of the estimated smoothing parameters (uses the fitted ρ
    /// covariance).
    LawleyLrEstimatedLambda,
    /// Lawley/Bartlett factor computed with the fitted smoothing parameters held
    /// fixed.
    LawleyLrFixedLambda,
    /// No correction was applied (presumably the corrected statistic and p-value
    /// then equal the uncorrected ones).
    None,
}

impl SmoothLrCorrection {
    /// Returns the fixed snake_case identifier naming this correction.
    pub fn label(self) -> &'static str {
        match self {
            Self::LawleyLrEstimatedLambda => "lawley_lr_estimated_lambda",
            Self::LawleyLrFixedLambda => "lawley_lr_fixed_lambda",
            Self::None => "none",
        }
    }
}
8677
/// Likelihood-ratio inference for one smooth term, comparing the full model against
/// the model refit without that term (see `smooth_term_lr_inference_forspec`).
#[derive(Clone, Debug)]
pub struct SmoothTermLrInference {
    /// Name of the smooth term (NOTE(review): presumably the design term's name;
    /// confirm at the construction site).
    pub name: String,
    /// Index of the term among the fitted design's smooth terms.
    pub term_idx: usize,
    /// LR statistic `2·(ℓ_full − ℓ_null)`, clamped at 0. NaN when the null fit failed
    /// or had a non-finite log-likelihood.
    pub statistic_lr: f64,
    /// Reference χ² degrees of freedom: the maximum of the Wood reference df (0 if
    /// unavailable), the term's EDF, and `max(null_dim, 1)`.
    pub ref_df: f64,
    /// Bartlett factor applied to the statistic; 1.0 when no correction is applied.
    pub bartlett_factor: f64,
    /// Fixed-λ (conditional) Bartlett factor. Recorded only when the estimated-λ
    /// correction replaced it.
    pub bartlett_factor_conditional: Option<f64>,
    /// Part of the LR mean shift attributed to ρ variability: the total shift minus
    /// the conditional shift `(c_cond − 1)·ref_df`. Set only with the estimated-λ
    /// correction.
    pub rho_variation_shift: Option<f64>,
    /// Bartlett-corrected statistic; starts equal to `statistic_lr`.
    pub statistic_corrected: f64,
    /// Upper-tail χ²(`ref_df`) p-value of `statistic_lr`, clamped to [0, 1]. NaN when
    /// unavailable.
    pub p_value_uncorrected: f64,
    /// p-value of the corrected statistic; starts equal to the uncorrected p-value.
    pub p_value_corrected: f64,
    /// Whether the correction is considered material (NOTE(review): presumably
    /// judged against `SMOOTH_LR_MATERIAL_THRESHOLD`; confirm).
    pub material: bool,
    /// Which correction produced the corrected statistic.
    pub correction: SmoothLrCorrection,
}
8723
/// Threshold for flagging a smooth-term LR correction as material.
/// NOTE(review): presumably this is compared against the correction's effect on the
/// p-value when setting `SmoothTermLrInference::material`; confirm at the use site.
pub const SMOOTH_LR_MATERIAL_THRESHOLD: f64 = 0.10;
8728
8729fn fitted_rho_penalty_components(
8735 penalties: &[BlockwisePenalty],
8736 lambdas: &[f64],
8737 p_total: usize,
8738) -> Result<Vec<gam_terms::inference::lawley::RhoPenaltyComponent>, EstimationError> {
8739 if penalties.len() != lambdas.len() {
8740 return Err(EstimationError::InvalidInput(format!(
8741 "smooth_term_lr_inference: penalty/lambda count mismatch ({} penalties, {} lambdas)",
8742 penalties.len(),
8743 lambdas.len()
8744 )));
8745 }
8746 let mut components = Vec::with_capacity(penalties.len());
8747 for (idx, (penalty, &lambda)) in penalties.iter().zip(lambdas.iter()).enumerate() {
8748 if !(lambda.is_finite() && lambda >= 0.0) {
8749 return Err(EstimationError::InvalidInput(format!(
8750 "smooth_term_lr_inference: lambda[{idx}] is invalid: {lambda}"
8751 )));
8752 }
8753 let r = &penalty.col_range;
8754 if r.end > p_total {
8755 return Err(EstimationError::InvalidInput(format!(
8756 "smooth_term_lr_inference: penalty[{idx}] range {:?} exceeds coefficient dimension {p_total}",
8757 r
8758 )));
8759 }
8760 let mut s_component = Array2::<f64>::zeros((p_total, p_total));
8761 s_component
8762 .slice_mut(s![r.start..r.end, r.start..r.end])
8763 .scaled_add(lambda, &penalty.local);
8764 components.push(gam_terms::inference::lawley::RhoPenaltyComponent { s_component });
8765 }
8766 Ok(components)
8767}
8768
8769pub fn smooth_term_lr_inference_forspec(
8814 data: ArrayView2<'_, f64>,
8815 y: ArrayView1<'_, f64>,
8816 weights: ArrayView1<'_, f64>,
8817 offset: ArrayView1<'_, f64>,
8818 resolvedspec: &TermCollectionSpec,
8819 family: LikelihoodSpec,
8820 options: &FitOptions,
8821) -> Result<Vec<SmoothTermLrInference>, EstimationError> {
8822 use gam_terms::inference::lawley::{
8823 LAWLEY_PAIR_MATRIX_MAX_ROWS, known_scale_expected_jets_with_dispersion,
8824 lawley_lr_bartlett_factor, lawley_lr_mean_shift_with_rho_variation,
8825 };
8826
8827 let n = data.nrows();
8828 let full = fit_term_collection_forspec(
8831 data,
8832 y,
8833 weights,
8834 offset,
8835 resolvedspec,
8836 family.clone(),
8837 options,
8838 )?;
8839 let ll_full = full.fit.log_likelihood;
8840 let p_total = full.design.design.ncols();
8841 let lambdas = full.fit.lambdas.as_slice().ok_or_else(|| {
8842 EstimationError::InvalidInput(
8843 "smooth_term_lr_inference: non-contiguous lambda vector".to_string(),
8844 )
8845 })?;
8846 let s_lambda = weighted_blockwise_penalty_sum(&full.design.penalties, lambdas, p_total);
8847 let rho_penalty_components =
8848 fitted_rho_penalty_components(&full.design.penalties, lambdas, p_total)?;
8849 let rho_covariance = full.fit.artifacts.rho_covariance.as_ref().filter(|cov| {
8850 cov.nrows() == rho_penalty_components.len() && cov.ncols() == rho_penalty_components.len()
8851 });
8852 let full_design_dense = full.design.design.to_dense();
8854 let influence = full.fit.coefficient_influence();
8855 let family_disp = lawley_dispersion_for_family(&family, &full.fit);
8856
8857 let mut penalty_cursor = full.design.random_effect_ranges.len();
8860 let mut out = Vec::<SmoothTermLrInference>::new();
8861 for (term_idx, design_term) in full.design.smooth.terms.iter().enumerate() {
8862 let k = design_term.penalties_local.len();
8863 let block_start = penalty_cursor;
8864 penalty_cursor += k;
8865 if design_term.shape != ShapeConstraint::None {
8868 continue;
8869 }
8870 let coeff_range = design_term.coeff_range.clone();
8871 if coeff_range.start >= coeff_range.end || coeff_range.end > p_total {
8872 continue;
8873 }
8874 let edf = full.fit.per_term_edf(coeff_range.clone(), block_start, k);
8886 let null_dim: usize = design_term.nullspace_dims.iter().sum();
8892 let ref_df = wood_reference_df(influence, &coeff_range)
8910 .unwrap_or(0.0)
8911 .max(edf)
8912 .max(null_dim.max(1) as f64)
8913 .max(1e-12);
8914 if !(ref_df.is_finite() && ref_df > 0.0) {
8915 continue;
8916 }
8917
8918 let mut null_spec = resolvedspec.clone();
8921 let Some(spec_pos) = null_spec
8922 .smooth_terms
8923 .iter()
8924 .position(|t| t.name == design_term.name)
8925 else {
8926 continue;
8927 };
8928 null_spec.smooth_terms.remove(spec_pos);
8929 let null_fit = fit_term_collection_forspec(
8930 data,
8931 y,
8932 weights,
8933 offset,
8934 &null_spec,
8935 family.clone(),
8936 options,
8937 );
8938 let (statistic_lr, eta_null) = match null_fit {
8939 Ok(null) if null.fit.log_likelihood.is_finite() => {
8940 let w = (2.0 * (ll_full - null.fit.log_likelihood)).max(0.0);
8941 let mut eta = null.design.design.dot(&null.fit.beta);
8945 eta += &offset;
8946 (w, Some(eta))
8947 }
8948 _ => (f64::NAN, None),
8949 };
8950
8951 let chi2 = statrs::distribution::ChiSquared::new(ref_df).ok();
8952 let p_uncorrected = match (chi2.as_ref(), statistic_lr.is_finite()) {
8953 (Some(dist), true) => {
8954 use statrs::distribution::ContinuousCDF;
8955 (1.0 - dist.cdf(statistic_lr)).clamp(0.0, 1.0)
8956 }
8957 _ => f64::NAN,
8958 };
8959
8960 let mut bartlett_factor = 1.0;
8964 let mut bartlett_factor_conditional = None;
8965 let mut rho_variation_shift = None;
8966 let mut statistic_corrected = statistic_lr;
8967 let mut p_corrected = p_uncorrected;
8968 let mut correction = SmoothLrCorrection::None;
8969 if let (Some(eta), true, true) = (
8970 eta_null.as_ref(),
8971 statistic_lr.is_finite(),
8972 n <= LAWLEY_PAIR_MATRIX_MAX_ROWS,
8973 ) {
8974 let kappas: Option<Vec<_>> = (0..n)
8975 .map(|i| {
8976 known_scale_expected_jets_with_dispersion(&family, eta[i], family_disp)
8977 .and_then(|jets| jets.kappas().ok())
8978 })
8979 .collect();
8980 if let (Some(kappas), Some(dist)) = (kappas, chi2.as_ref()) {
8981 let fixed_factor = lawley_lr_bartlett_factor(
8982 full_design_dense.view(),
8983 &kappas,
8984 Some(s_lambda.view()),
8985 coeff_range.clone(),
8986 ref_df,
8987 );
8988 if let Ok(c_cond) = fixed_factor
8989 && c_cond.is_finite()
8990 && c_cond > 0.0
8991 {
8992 let mut c_applied = c_cond;
8993 correction = SmoothLrCorrection::LawleyLrFixedLambda;
8994 if let Some(cov) = rho_covariance
8995 && let Ok(total_shift) = lawley_lr_mean_shift_with_rho_variation(
8996 full_design_dense.view(),
8997 &kappas,
8998 s_lambda.view(),
8999 coeff_range.clone(),
9000 &rho_penalty_components,
9001 cov.view(),
9002 )
9003 {
9004 let mean_w = ref_df + total_shift;
9005 if let Some(c_est) =
9006 gam_terms::inference::higher_order::bartlett_factor_from_mean(
9007 mean_w, ref_df,
9008 )
9009 && c_est.is_finite()
9010 && c_est > 0.0
9011 {
9012 let conditional_shift = (c_cond - 1.0) * ref_df;
9013 c_applied = c_est;
9014 bartlett_factor_conditional = Some(c_cond);
9015 rho_variation_shift = Some(total_shift - conditional_shift);
9016 correction = SmoothLrCorrection::LawleyLrEstimatedLambda;
9017 }
9018 }
9019 use statrs::distribution::ContinuousCDF;
9020 bartlett_factor = c_applied;
9021 statistic_corrected = statistic_lr / c_applied;
9022 p_corrected = (1.0 - dist.cdf(statistic_corrected)).clamp(0.0, 1.0);
9023 }
9024 }
9025 }
9026
9027 let material = match correction {
9033 SmoothLrCorrection::LawleyLrEstimatedLambda
9034 | SmoothLrCorrection::LawleyLrFixedLambda => {
9035 let factor_move = (bartlett_factor - 1.0).abs();
9036 let p_denom = p_uncorrected.max(p_corrected).max(f64::MIN_POSITIVE);
9037 let p_move = if p_uncorrected.is_finite() && p_corrected.is_finite() {
9038 (p_corrected - p_uncorrected).abs() / p_denom
9039 } else {
9040 0.0
9041 };
9042 factor_move > SMOOTH_LR_MATERIAL_THRESHOLD || p_move > SMOOTH_LR_MATERIAL_THRESHOLD
9043 }
9044 SmoothLrCorrection::None => false,
9045 };
9046
9047 out.push(SmoothTermLrInference {
9048 name: design_term.name.clone(),
9049 term_idx,
9050 statistic_lr,
9051 ref_df,
9052 bartlett_factor,
9053 bartlett_factor_conditional,
9054 rho_variation_shift,
9055 statistic_corrected,
9056 p_value_uncorrected: p_uncorrected,
9057 p_value_corrected: p_corrected,
9058 material,
9059 correction,
9060 });
9061 }
9062 Ok(out)
9063}
9064
9065fn lawley_dispersion_for_family(family: &LikelihoodSpec, fit: &UnifiedFitResult) -> f64 {
9068 match family.response {
9069 gam_spec::ResponseFamily::Gaussian => {
9070 let sd = fit.standard_deviation;
9071 (sd * sd).max(f64::MIN_POSITIVE)
9072 }
9073 gam_spec::ResponseFamily::Gamma => {
9074 let shape = fit.standard_deviation;
9075 if shape.is_finite() && shape > 0.0 {
9076 1.0 / shape
9077 } else {
9078 1.0
9079 }
9080 }
9081 _ => 1.0,
9082 }
9083}
9084
9085fn wood_reference_df(influence: Option<&Array2<f64>>, coeff_range: &Range<usize>) -> Option<f64> {
9091 let f = influence?;
9092 let (start, end) = (coeff_range.start, coeff_range.end);
9093 if start >= end || end > f.nrows() || end > f.ncols() {
9094 return None;
9095 }
9096 let block = f.slice(s![start..end, start..end]);
9097 let tr = (0..block.nrows()).map(|i| block[[i, i]]).sum::<f64>();
9098 let tr2 = block.dot(&block).diag().sum();
9099 (tr.is_finite() && tr2.is_finite() && tr > 0.0 && tr2 > 0.0).then(|| (tr * tr / tr2).max(1e-12))
9100}