1fn try_build_spatial_term_log_kappa_derivative(
2 data: ArrayView2<'_, f64>,
3 resolvedspec: &TermCollectionSpec,
4 design: &TermCollectionDesign,
5 term_idx: usize,
6) -> Result<
7 Option<(
8 Range<usize>,
9 usize,
10 Array2<f64>,
11 Array2<f64>,
12 Array2<f64>,
13 Array2<f64>,
14 Vec<Array2<f64>>,
15 Vec<Array2<f64>>,
16 Option<std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>>,
17 )>,
18 EstimationError,
19> {
20 let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
21 return Ok(None);
22 };
23 let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
24 return Ok(None);
25 };
26
27 let derivative_bundle = match &termspec.basis {
28 SmoothBasisSpec::ThinPlate {
29 feature_cols,
30 spec,
31 input_scales,
32 } => {
33 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
34 let mut spec_local = spec.clone();
35 if let Some(s) = input_scales {
36 apply_input_standardization(&mut x, s);
37 spec_local.length_scale =
38 compensate_length_scale_for_standardization(spec.length_scale, s);
39 }
40 build_thin_plate_basis_log_kappa_derivatives(x.view(), &spec_local)
41 .map_err(EstimationError::from)?
42 }
43 SmoothBasisSpec::Sphere { .. } => return Ok(None),
44 SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
53 let x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
54 build_constant_curvature_basis_kappa_derivatives(x.view(), spec)
55 .map_err(EstimationError::from)?
56 }
57 SmoothBasisSpec::MeasureJet { .. } => return Ok(None),
63 SmoothBasisSpec::Matern {
64 feature_cols,
65 spec,
66 input_scales,
67 } => {
68 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
69 let mut spec_local = spec.clone();
70 if let Some(s) = input_scales {
71 apply_input_standardization(&mut x, s);
72 spec_local.length_scale =
73 compensate_length_scale_for_standardization(spec.length_scale, s);
74 }
75 build_matern_basis_log_kappa_derivatives(x.view(), &spec_local)
76 .map_err(EstimationError::from)?
77 }
78 SmoothBasisSpec::Duchon {
79 feature_cols,
80 spec,
81 input_scales,
82 } => {
83 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
84 let mut spec_local = spec.clone();
85 if let Some(s) = input_scales {
86 apply_input_standardization(&mut x, s);
87 spec_local.length_scale =
88 compensate_optional_length_scale_for_standardization(spec.length_scale, s);
89 }
90 let BasisMetadata::Duchon {
91 centers,
92 identifiability_transform,
93 operator_collocation_points,
94 radial_reparam,
95 ..
96 } = &smooth_term.metadata
97 else {
98 return Ok(None);
99 };
100 if spec_local.radial_reparam.is_none() {
103 spec_local.radial_reparam = radial_reparam.clone();
104 }
105 gam_terms::basis::build_duchon_basis_log_kappa_derivativeswith_collocationwithworkspace(
106 x.view(),
107 &spec_local,
108 centers.view(),
109 identifiability_transform.as_ref(),
110 operator_collocation_points
111 .as_ref()
112 .map(|points| points.view()),
113 &mut BasisWorkspace::default(),
114 )
115 .map_err(EstimationError::from)?
116 }
117 SmoothBasisSpec::BSpline1D { .. }
118 | SmoothBasisSpec::TensorBSpline { .. }
119 | SmoothBasisSpec::ByVariable { .. }
120 | SmoothBasisSpec::FactorSumToZero { .. }
121 | SmoothBasisSpec::BySmooth { .. }
122 | SmoothBasisSpec::FactorSmooth { .. }
123 | SmoothBasisSpec::Pca { .. } => {
124 return Ok(None);
125 }
126 };
127 let mut implicit_operator = derivative_bundle.implicit_operator;
128 let BasisPsiDerivativeResult {
129 design_derivative: mut local_x_psi,
130 penalties_derivative: mut local_s_psi,
131 implicit_operator: local_implicit_first_unused,
132 } = derivative_bundle.first;
133 let BasisPsiSecondDerivativeResult {
134 designsecond_derivative: mut local_x_psi_psi,
135 penaltiessecond_derivative: mut local_s_psi_psi,
136 implicit_operator: local_implicit_second_unused,
137 } = derivative_bundle.second;
138 assert!(local_implicit_first_unused.is_none());
139 assert!(local_implicit_second_unused.is_none());
140
141 if let Some(rotation) = smooth_term.joint_null_rotation.as_ref() {
142 let q = &rotation.rotation;
143 if let Some(op) = implicit_operator.take() {
144 implicit_operator = Some(op.append_full_transform(q).map_err(EstimationError::from)?);
145 } else {
146 if local_x_psi.ncols() != q.nrows() || local_x_psi_psi.ncols() != q.nrows() {
147 return Ok(None);
148 }
149 local_x_psi = fast_ab(&local_x_psi, q);
150 local_x_psi_psi = fast_ab(&local_x_psi_psi, q);
151 }
152 let rotate_penalty = |s_local: Array2<f64>| -> Option<Array2<f64>> {
153 if s_local.nrows() != q.nrows() || s_local.ncols() != q.nrows() {
154 return None;
155 }
156 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
157 Some(gam_linalg::faer_ndarray::fast_ab(&qt_s, q))
158 };
159 let Some(rotated_s_psi) = local_s_psi
160 .into_iter()
161 .map(|s| rotate_penalty(s))
162 .collect::<Option<Vec<_>>>()
163 else {
164 return Ok(None);
165 };
166 local_s_psi = rotated_s_psi;
167 let Some(rotated_s_psi_psi) = local_s_psi_psi
168 .into_iter()
169 .map(|s| rotate_penalty(s))
170 .collect::<Option<Vec<_>>>()
171 else {
172 return Ok(None);
173 };
174 local_s_psi_psi = rotated_s_psi_psi;
175 }
176 let implicit_operator = implicit_operator.map(std::sync::Arc::new);
177
178 if let Some(ref op) = implicit_operator {
179 if op.p_out() != smooth_term.coeff_range.len() {
180 return Ok(None);
181 }
182 } else {
183 if local_x_psi.ncols() != smooth_term.coeff_range.len() {
184 return Ok(None);
185 }
186 if local_x_psi_psi.ncols() != smooth_term.coeff_range.len() {
187 return Ok(None);
188 }
189 }
190 if local_s_psi.is_empty() || local_s_psi.len() != local_s_psi_psi.len() {
191 return Ok(None);
192 }
193 if local_s_psi.iter().any(|s| {
194 s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
195 }) {
196 return Ok(None);
197 }
198 if local_s_psi_psi.iter().any(|s| {
199 s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
200 }) {
201 return Ok(None);
202 }
203
204 let p_total = design.design.ncols();
205 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
206 let global_range = (smooth_start + smooth_term.coeff_range.start)
207 ..(smooth_start + smooth_term.coeff_range.end);
208
209 Ok(Some((
210 global_range,
211 p_total,
212 local_x_psi,
213 local_s_psi.iter().fold(
214 Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
215 |acc, m| acc + m,
216 ),
217 local_x_psi_psi,
218 local_s_psi_psi.iter().fold(
219 Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
220 |acc, m| acc + m,
221 ),
222 local_s_psi,
223 local_s_psi_psi,
224 implicit_operator,
225 )))
226}
227
228fn try_build_spatial_log_kappa_hyper_dirs(
229 data: ArrayView2<'_, f64>,
230 resolvedspec: &TermCollectionSpec,
231 design: &TermCollectionDesign,
232 spatial_terms: &[usize],
233) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
234 let Some(info_list) =
241 try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, spatial_terms)?
242 else {
243 return Ok(None);
244 };
245 Ok(Some(spatial_log_kappa_hyper_dirs_frominfo_list(info_list)?))
246}
247
248pub(crate) fn try_build_latent_coord_hyper_dirs(
249 latent: std::sync::Arc<gam_terms::latent::LatentCoordValues>,
250 resolvedspec: &TermCollectionSpec,
251 design: &TermCollectionDesign,
252 latent_terms: &[gam_problem::types::SmoothTermIdx],
253 analytic_rho_count: usize,
254) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
255 if latent_terms.is_empty() || latent.is_empty() {
256 return Ok(None);
257 }
258 if latent_terms.len() != 1 {
259 crate::bail_invalid_estim!(
260 "LatentCoord standard-fit hyper_dirs currently require exactly one latent smooth term"
261 .to_string(),
262 );
263 }
264 let term_idx = latent_terms[0];
265 let smooth_term = design.smooth.terms.get(term_idx.get()).ok_or_else(|| {
266 EstimationError::InvalidInput(format!(
267 "LatentCoord term index {term_idx} out of bounds for realized smooth design"
268 ))
269 })?;
270 let termspec = resolvedspec
271 .smooth_terms
272 .get(term_idx.get())
273 .ok_or_else(|| {
274 EstimationError::InvalidInput(format!(
275 "LatentCoord term index {term_idx} out of bounds for resolved smooth spec"
276 ))
277 })?;
278 let p_total = design.design.ncols();
279 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
280 let global_range = (smooth_start + smooth_term.coeff_range.start)
281 ..(smooth_start + smooth_term.coeff_range.end);
282
283 let operator = match (&termspec.basis, &smooth_term.metadata) {
288 (
289 SmoothBasisSpec::Matern { .. },
290 BasisMetadata::Matern {
291 centers,
292 length_scale,
293 nu,
294 include_intercept,
295 identifiability_transform,
296 ..
297 },
298 ) => gam_terms::basis::LatentCoordDesignDerivative::new_matern(
299 latent.clone(),
300 std::sync::Arc::new(centers.clone()),
301 *length_scale,
302 *nu,
303 *include_intercept,
304 identifiability_transform.clone(),
305 )
306 .map_err(EstimationError::from)?,
307 (
308 SmoothBasisSpec::Duchon { .. },
309 BasisMetadata::Duchon {
310 centers,
311 length_scale,
312 power,
313 nullspace_order,
314 identifiability_transform,
315 ..
316 },
317 ) => gam_terms::basis::LatentCoordDesignDerivative::new_duchon(
318 latent.clone(),
319 std::sync::Arc::new(centers.clone()),
320 *length_scale,
321 *power,
322 *nullspace_order,
323 identifiability_transform.clone(),
324 )
325 .map_err(EstimationError::from)?,
326 (
327 SmoothBasisSpec::Sphere { .. },
328 BasisMetadata::Sphere {
329 centers,
330 penalty_order,
331 method,
332 constraint_transform,
333 ..
334 },
335 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
336 gam_terms::basis::LatentCoordDesignDerivative::new_sphere(
337 latent.clone(),
338 std::sync::Arc::new(centers.clone()),
339 *penalty_order,
340 constraint_transform.clone(),
341 )
342 .map_err(EstimationError::from)?
343 }
344 (
345 SmoothBasisSpec::BSpline1D { spec, .. },
346 BasisMetadata::BSpline1D {
347 knots,
348 identifiability_transform,
349 periodic,
350 degree: meta_degree,
351 ..
352 },
353 ) => {
354 let effective_degree = meta_degree.unwrap_or(spec.degree);
358 if let Some((domain_start, period, num_basis)) = periodic {
359 gam_terms::basis::LatentCoordDesignDerivative::new_periodic_bspline(
360 latent.clone(),
361 (*domain_start, *domain_start + *period),
362 effective_degree,
363 *num_basis,
364 identifiability_transform.clone(),
365 )
366 .map_err(EstimationError::from)?
367 } else {
368 gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
369 latent.clone(),
370 vec![knots.clone()],
371 vec![effective_degree],
372 identifiability_transform.clone(),
373 )
374 .map_err(EstimationError::from)?
375 }
376 }
377 (
378 SmoothBasisSpec::TensorBSpline { .. },
379 BasisMetadata::TensorBSpline {
380 knots,
381 degrees,
382 identifiability_transform,
383 ..
384 },
385 ) => gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
386 latent.clone(),
387 knots.clone(),
388 degrees.clone(),
389 identifiability_transform.clone(),
390 )
391 .map_err(EstimationError::from)?,
392 (SmoothBasisSpec::Pca { .. }, BasisMetadata::Pca { basis_matrix, .. }) => {
393 gam_terms::basis::LatentCoordDesignDerivative::new_pca(
394 latent.clone(),
395 std::sync::Arc::new(basis_matrix.clone()),
396 )
397 .map_err(EstimationError::from)?
398 }
399 _ => return Ok(None),
400 };
401 if operator.p_out() != global_range.len() {
402 crate::bail_invalid_estim!(
403 "LatentCoord derivative width mismatch for term '{}': operator p={}, coeff range={}",
404 smooth_term.name,
405 operator.p_out(),
406 global_range.len()
407 );
408 }
409 let operator = std::sync::Arc::new(operator);
410 let mut hyper_dirs = Vec::with_capacity(operator.n_axes());
411 for flat_axis in 0..operator.n_axes() {
412 let dir = DirectionalHyperParam::new_compact(
413 gam_solve::estimate::reml::HyperDesignDerivative::from_latent_coord(
414 operator.clone(),
415 flat_axis,
416 global_range.clone(),
417 p_total,
418 ),
419 Vec::new(),
420 None,
421 None,
422 )?
423 .not_penalty_like();
424 hyper_dirs.push(dir);
425 }
426 let direct_dim = latent_coord_direct_hyper_count(latent.id_mode(), latent.latent_dim());
427 if analytic_rho_count + direct_dim > 0 {
428 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::from(Array2::<f64>::zeros((
429 design.design.nrows(),
430 p_total,
431 )));
432 for _ in 0..analytic_rho_count {
433 hyper_dirs.push(
434 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
435 .not_penalty_like(),
436 );
437 }
438 for _ in 0..direct_dim {
439 hyper_dirs.push(
440 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
441 .not_penalty_like(),
442 );
443 }
444 }
445 Ok(Some(hyper_dirs))
446}
447
448fn latent_coord_direct_hyper_count(
449 id_mode: &gam_terms::latent::LatentIdMode,
450 latent_dim: usize,
451) -> usize {
452 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
453 match id_mode {
454 LatentIdMode::AuxPrior { strength, .. } => match strength {
455 AuxPriorStrength::Auto => 1,
456 AuxPriorStrength::Fixed(_) => 0,
457 },
458 LatentIdMode::AuxPriorDimSelection { strength, .. } => {
459 latent_dim
460 + match strength {
461 AuxPriorStrength::Auto => 1,
462 AuxPriorStrength::Fixed(_) => 0,
463 }
464 }
465 LatentIdMode::DimSelection { .. } => latent_dim,
466 LatentIdMode::IsometryToReference { strength, .. } => match strength {
469 AuxPriorStrength::Auto => 1,
470 AuxPriorStrength::Fixed(_) => 0,
471 },
472 LatentIdMode::AuxOutcome { head, .. } => head.n_coeffs(latent_dim) + latent_dim,
475 LatentIdMode::None => 0,
476 }
477}
478
479fn latent_coord_initial_direct_hypers(
480 id_mode: &gam_terms::latent::LatentIdMode,
481 latent_dim: usize,
482) -> Result<Array1<f64>, EstimationError> {
483 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
484 let mut values = Vec::with_capacity(latent_coord_direct_hyper_count(id_mode, latent_dim));
485 match id_mode {
486 LatentIdMode::AuxPrior { strength, .. } => {
487 if matches!(strength, AuxPriorStrength::Auto) {
488 values.push(0.0);
489 }
490 }
491 LatentIdMode::AuxPriorDimSelection {
492 strength,
493 init_log_precision,
494 ..
495 } => {
496 if matches!(strength, AuxPriorStrength::Auto) {
497 values.push(0.0);
498 }
499 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
500 }
501 LatentIdMode::DimSelection { init_log_precision } => {
502 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
503 }
504 LatentIdMode::IsometryToReference { strength, .. } => {
505 if matches!(strength, AuxPriorStrength::Auto) {
506 values.push(0.0);
507 }
508 }
509 LatentIdMode::AuxOutcome {
510 head,
511 init_log_precision,
512 } => {
513 values.extend(std::iter::repeat_n(0.0, head.n_coeffs(latent_dim)));
517 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
518 }
519 LatentIdMode::None => {}
520 }
521 Ok(Array1::from_vec(values))
522}
523
524fn append_latent_ard_seed(
525 values: &mut Vec<f64>,
526 init: Option<&Array1<f64>>,
527 latent_dim: usize,
528) -> Result<(), EstimationError> {
529 if let Some(init) = init {
530 if init.len() != latent_dim {
531 crate::bail_invalid_estim!(
532 "latent dim_selection init_log_precision length mismatch: got {}, expected {}",
533 init.len(),
534 latent_dim
535 );
536 }
537 values.extend(init.iter().copied());
538 } else {
539 values.extend(std::iter::repeat_n(0.0, latent_dim));
540 }
541 Ok(())
542}
543
544struct LatentIdObjectiveContribution {
545 cost: f64,
546 gradient: Array1<f64>,
547}
548
549fn latent_id_objective_contribution(
550 theta: &Array1<f64>,
551 rho_dim: usize,
552 analytic_rho_count: usize,
553 latent: &gam_terms::latent::LatentCoordValues,
554) -> Result<LatentIdObjectiveContribution, EstimationError> {
555 use gam_terms::latent::{AuxPriorStrength, LatentIdMode, aux_prior_targets};
556 let n_obs = latent.n_obs();
557 let latent_dim = latent.latent_dim();
558 let flat_len = latent.len();
559 let mut gradient = Array1::<f64>::zeros(theta.len());
560 let t_start = rho_dim;
561 let direct_start = t_start + flat_len + analytic_rho_count;
562 if theta.len() < direct_start {
563 crate::bail_invalid_estim!(
564 "latent-coordinate theta too short for id objective: got {}, need at least {}",
565 theta.len(),
566 direct_start
567 );
568 }
569 let t = latent.as_matrix();
570 let mut cost = 0.0;
571 let mut cursor = direct_start;
572
573 match latent.id_mode() {
574 LatentIdMode::AuxPrior {
575 u,
576 family,
577 strength,
578 }
579 | LatentIdMode::AuxPriorDimSelection {
580 u,
581 family,
582 strength,
583 ..
584 } => {
585 let (log_mu, mu) = match strength {
586 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
587 AuxPriorStrength::Auto => {
588 let log_mu = theta[cursor];
589 cursor += 1;
590 (log_mu, log_mu.exp())
591 }
592 };
593 let targets = aux_prior_targets(t.view(), u.view(), *family)
594 .map_err(EstimationError::InvalidInput)?;
595 let residual = &t - &targets;
596 let q = residual.iter().map(|v| v * v).sum::<f64>();
597 let k = (n_obs * latent_dim) as f64;
604 cost += 0.5 * mu * q - 0.5 * k * log_mu;
605
606 let projected_residual = aux_prior_targets(residual.view(), u.view(), *family)
607 .map_err(EstimationError::InvalidInput)?;
608 let grad_base = residual - projected_residual;
609 for n in 0..n_obs {
610 for axis in 0..latent_dim {
611 gradient[t_start + n * latent_dim + axis] += mu * grad_base[[n, axis]];
612 }
613 }
614 if matches!(strength, AuxPriorStrength::Auto) {
615 gradient[direct_start] += 0.5 * mu * q - 0.5 * k;
616 }
617 }
618 LatentIdMode::IsometryToReference { reference, strength } => {
619 if reference.dim() != (n_obs, latent_dim) {
626 crate::bail_invalid_estim!(
627 "IsometryToReference reference shape {:?} must equal (n_obs, latent_dim) = ({}, {})",
628 reference.dim(),
629 n_obs,
630 latent_dim
631 );
632 }
633 let mu_slot = cursor;
634 let (log_mu, mu) = match strength {
635 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
636 AuxPriorStrength::Auto => {
637 let log_mu = theta[cursor];
638 cursor += 1;
639 (log_mu, log_mu.exp())
640 }
641 };
642 let residual = &t - reference;
643 let q = residual.iter().map(|v| v * v).sum::<f64>();
644 let k = (n_obs * latent_dim) as f64;
648 cost += 0.5 * mu * q - 0.5 * k * log_mu;
649 for n in 0..n_obs {
650 for axis in 0..latent_dim {
651 gradient[t_start + n * latent_dim + axis] += mu * residual[[n, axis]];
652 }
653 }
654 if matches!(strength, AuxPriorStrength::Auto) {
655 gradient[mu_slot] += 0.5 * mu * q - 0.5 * k;
656 }
657 }
658 LatentIdMode::AuxOutcome { head, .. } => {
659 let n_coeffs = head.n_coeffs(latent_dim);
667 let coeffs = theta
668 .slice(ndarray::s![cursor..cursor + n_coeffs])
669 .to_owned();
670 let (head_nll, grad_coeffs, grad_t) = head
671 .neg_loglik_and_grad(t.view(), coeffs.view())
672 .map_err(EstimationError::InvalidInput)?;
673 cost += head_nll;
674 for (offset, &g) in grad_coeffs.iter().enumerate() {
675 gradient[cursor + offset] += g;
676 }
677 for n in 0..n_obs {
678 for axis in 0..latent_dim {
679 gradient[t_start + n * latent_dim + axis] += grad_t[[n, axis]];
680 }
681 }
682 cursor += n_coeffs;
683 }
684 LatentIdMode::DimSelection { .. } | LatentIdMode::None => {}
685 }
686
687 match latent.id_mode() {
688 LatentIdMode::AuxPriorDimSelection { .. }
689 | LatentIdMode::DimSelection { .. }
690 | LatentIdMode::AuxOutcome { .. } => {
691 for axis in 0..latent_dim {
692 let log_alpha = theta[cursor + axis];
693 let alpha = log_alpha.exp();
694 let mut q_axis = 0.0;
695 for n in 0..n_obs {
696 let flat_idx = n * latent_dim + axis;
697 let value = latent.as_flat()[flat_idx];
698 q_axis += value * value;
699 gradient[t_start + flat_idx] += alpha * value;
700 }
701 cost += 0.5 * alpha * q_axis - 0.5 * n_obs as f64 * log_alpha;
702 gradient[cursor + axis] += 0.5 * alpha * q_axis - 0.5 * n_obs as f64;
703 }
704 cursor += latent_dim;
705 }
706 LatentIdMode::AuxPrior { .. }
707 | LatentIdMode::IsometryToReference { .. }
708 | LatentIdMode::None => {}
709 }
710
711 if cursor != theta.len() {
712 crate::bail_invalid_estim!(
713 "latent-coordinate direct hyperparameter length mismatch: consumed {}, theta len {}",
714 cursor,
715 theta.len()
716 );
717 }
718 Ok(LatentIdObjectiveContribution { cost, gradient })
719}
720
721fn add_latent_id_objective_to_eval(
722 theta: &Array1<f64>,
723 rho_dim: usize,
724 analytic_rho_count: usize,
725 latent: &gam_terms::latent::LatentCoordValues,
726 eval: &mut (
727 f64,
728 Array1<f64>,
729 gam_problem::HessianResult,
730 ),
731) -> Result<(), EstimationError> {
732 let contribution =
733 latent_id_objective_contribution(theta, rho_dim, analytic_rho_count, latent)?;
734 eval.0 += contribution.cost;
735 if eval.1.len() != contribution.gradient.len() {
736 crate::bail_invalid_estim!(
737 "latent-coordinate REML gradient length mismatch: base={}, id={}",
738 eval.1.len(),
739 contribution.gradient.len()
740 );
741 }
742 eval.1 += &contribution.gradient;
743 if eval.2.is_analytic() {
744 eval.2 = gam_problem::HessianResult::Unavailable;
745 }
746 Ok(())
747}
748
749fn analytic_penalty_objective_contribution(
750 theta: &Array1<f64>,
751 rho_dim: usize,
752 latent: &gam_terms::latent::LatentCoordValues,
753 registry: &gam_terms::AnalyticPenaltyRegistry,
754) -> Result<LatentIdObjectiveContribution, EstimationError> {
755 let flat_len = latent.len();
756 let t_start = rho_dim;
757 let t_end = t_start + flat_len;
758 let rho_start = t_end;
759 let rho_end = rho_start + registry.total_rho_count();
760 if theta.len() < rho_end {
761 crate::bail_invalid_estim!(
762 "latent-coordinate theta too short for analytic penalties: got {}, need at least {}",
763 theta.len(),
764 rho_end
765 );
766 }
767 let target_t = theta.slice(s![t_start..t_end]);
768 let rho = theta.slice(s![rho_start..rho_end]);
769 let mut cost = 0.0_f64;
770 let mut gradient = Array1::<f64>::zeros(theta.len());
771 for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(registry.rho_layout()) {
772 let rho_local = rho.slice(s![rho_slice.clone()]);
773 match tier {
774 gam_terms::PenaltyTier::Psi => {
775 cost += penalty.value(target_t.view(), rho_local);
776 let grad = penalty.grad_target(target_t.view(), rho_local);
777 if grad.len() != flat_len {
778 crate::bail_invalid_estim!(
779 "analytic penalty {name:?} gradient length mismatch: got {}, expected {}",
780 grad.len(),
781 flat_len
782 );
783 }
784 for i in 0..flat_len {
785 gradient[t_start + i] += grad[i];
786 }
787 let grad_rho_local = penalty.grad_rho(target_t.view(), rho_local);
788 if grad_rho_local.len() != rho_slice.len() {
789 crate::bail_invalid_estim!(
790 "analytic penalty {name:?} rho-gradient length mismatch: got {}, expected {}",
791 grad_rho_local.len(),
792 rho_slice.len()
793 );
794 }
795 for local_idx in 0..grad_rho_local.len() {
796 gradient[rho_start + rho_slice.start + local_idx] += grad_rho_local[local_idx];
797 }
798 }
799 gam_terms::PenaltyTier::Beta => {}
800 gam_terms::PenaltyTier::Rho => {}
801 }
802 }
803 Ok(LatentIdObjectiveContribution { cost, gradient })
804}
805
806fn add_analytic_penalty_hessian_to_eval(
807 theta: &Array1<f64>,
808 rho_dim: usize,
809 latent: &gam_terms::latent::LatentCoordValues,
810 registry: &gam_terms::AnalyticPenaltyRegistry,
811 eval: &mut (
812 f64,
813 Array1<f64>,
814 gam_problem::HessianResult,
815 ),
816) -> Result<(), EstimationError> {
817 let flat_len = latent.len();
818 let t_start = rho_dim;
819 let t_end = t_start + flat_len;
820 let rho_start = t_end;
821 let rho_end = rho_start + registry.total_rho_count();
822 if theta.len() < rho_end {
823 crate::bail_invalid_estim!(
824 "latent-coordinate theta too short for analytic penalty Hessian: got {}, need at least {}",
825 theta.len(),
826 rho_end
827 );
828 }
829 let gam_problem::HessianResult::Analytic(hessian) = &mut eval.2 else {
830 if eval.2.is_analytic() {
831 eval.2 = gam_problem::HessianResult::Unavailable;
832 }
833 return Ok(());
834 };
835 if hessian.dim() != (theta.len(), theta.len()) {
836 crate::bail_invalid_estim!(
837 "analytic penalty Hessian target shape mismatch: got {}x{}, expected {}x{}",
838 hessian.nrows(),
839 hessian.ncols(),
840 theta.len(),
841 theta.len()
842 );
843 }
844 let target_t = theta.slice(s![t_start..t_end]);
845 let rho = theta.slice(s![rho_start..rho_end]);
846 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(registry.rho_layout())
847 {
848 let rho_local = rho.slice(s![rho_slice]);
849 if !matches!(tier, gam_terms::PenaltyTier::Psi) {
850 continue;
851 }
852 if let Some(diag) = penalty.hessian_diag(target_t.view(), rho_local) {
853 if diag.len() != flat_len {
854 crate::bail_invalid_estim!(
855 "analytic penalty Hessian diagonal length mismatch: got {}, expected {}",
856 diag.len(),
857 flat_len
858 );
859 }
860 for i in 0..flat_len {
861 hessian[[t_start + i, t_start + i]] += diag[i];
862 }
863 continue;
864 }
865 let mut probe = Array1::<f64>::zeros(flat_len);
866 for col in 0..flat_len {
867 probe[col] = 1.0;
868 let hv = penalty.hvp(target_t.view(), rho_local, probe.view());
869 if hv.len() != flat_len {
870 crate::bail_invalid_estim!(
871 "analytic penalty Hessian-vector length mismatch: got {}, expected {}",
872 hv.len(),
873 flat_len
874 );
875 }
876 for row in 0..flat_len {
877 hessian[[t_start + row, t_start + col]] += hv[row];
878 }
879 probe[col] = 0.0;
880 }
881 }
882 Ok(())
883}
884
885fn add_analytic_penalty_objective_to_eval(
886 theta: &Array1<f64>,
887 rho_dim: usize,
888 latent: &gam_terms::latent::LatentCoordValues,
889 registry: &gam_terms::AnalyticPenaltyRegistry,
890 eval: &mut (
891 f64,
892 Array1<f64>,
893 gam_problem::HessianResult,
894 ),
895) -> Result<(), EstimationError> {
896 let contribution = analytic_penalty_objective_contribution(theta, rho_dim, latent, registry)?;
897 eval.0 += contribution.cost;
898 if eval.1.len() != contribution.gradient.len() {
899 crate::bail_invalid_estim!(
900 "latent-coordinate REML gradient length mismatch: base={}, analytic_penalty={}",
901 eval.1.len(),
902 contribution.gradient.len()
903 );
904 }
905 eval.1 += &contribution.gradient;
906 add_analytic_penalty_hessian_to_eval(theta, rho_dim, latent, registry, eval)?;
907 Ok(())
908}
909
910fn spatial_log_kappa_hyper_dirs_frominfo_list(
911 info_list: Vec<SpatialPsiDerivative>,
912) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
913 use gam_solve::estimate::reml::ImplicitDerivLevel;
914 use std::collections::HashMap;
915
916 let log_kappa_dim = info_list.len();
917 let group_ids: Vec<Option<usize>> = info_list.iter().map(|e| e.aniso_group_id).collect();
923 let mut group_indices_map: HashMap<usize, Vec<usize>> = HashMap::new();
924 for (idx, gid) in group_ids.iter().enumerate() {
925 if let Some(g) = gid {
926 group_indices_map.entry(*g).or_default().push(idx);
927 }
928 }
929
930 let mut hyper_dirs = Vec::with_capacity(log_kappa_dim);
931 for (i, info) in info_list.into_iter().enumerate() {
932 let SpatialPsiDerivative {
933 penalty_index: _,
934 penalty_indices,
935 global_range,
936 total_p,
937 x_psi_local,
938 s_psi_components_local,
939 x_psi_psi_local,
940 s_psi_psi_components_local,
941 aniso_group_id,
942 aniso_cross_designs,
943 aniso_cross_penalty_provider,
944 implicit_operator,
945 implicit_axis,
946 } = info;
947
948 let mut xsecond = vec![None; log_kappa_dim];
949 xsecond[i] = Some(if let Some(ref op) = implicit_operator {
951 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
952 op.clone(),
953 ImplicitDerivLevel::SecondDiag(implicit_axis),
954 global_range.clone(),
955 total_p,
956 )
957 } else {
958 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
959 x_psi_psi_local,
960 global_range.clone(),
961 total_p,
962 )
963 });
964 if let Some(cross_designs) = aniso_cross_designs {
966 if let Some(gid) = aniso_group_id {
970 let base = group_indices_map
971 .get(&gid)
972 .and_then(|v| v.first().copied())
973 .unwrap_or(i);
974 for (b_axis, cross_mat) in cross_designs.into_iter() {
975 let j = base + b_axis;
976 if j < log_kappa_dim {
977 xsecond[j] = Some(if let Some(ref op) = implicit_operator {
978 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
979 op.clone(),
980 ImplicitDerivLevel::SecondCross(implicit_axis, b_axis),
981 global_range.clone(),
982 total_p,
983 )
984 } else {
985 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
986 cross_mat,
987 global_range.clone(),
988 total_p,
989 )
990 });
991 }
992 }
993 }
994 }
995 let s_components = penalty_indices
996 .iter()
997 .copied()
998 .zip(s_psi_components_local.into_iter().map(|local| {
999 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1000 local,
1001 global_range.clone(),
1002 total_p,
1003 )
1004 }))
1005 .collect::<Vec<_>>();
1006 let s2_components = penalty_indices
1007 .iter()
1008 .copied()
1009 .zip(s_psi_psi_components_local.into_iter().map(|local| {
1010 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1011 local,
1012 global_range.clone(),
1013 total_p,
1014 )
1015 }))
1016 .collect::<Vec<_>>();
1017 let mut ssecond_components = vec![None; log_kappa_dim];
1018 ssecond_components[i] = Some(s2_components);
1019 let mut penaltysecond_partner_indices: Option<Vec<usize>> = None;
1020 let penaltysecond_component_provider =
1021 if let (Some(provider), Some(gid)) = (aniso_cross_penalty_provider, aniso_group_id) {
1022 let group_indices = group_indices_map.get(&gid).cloned().unwrap_or_default();
1023 let axis_in_group =
1024 group_indices
1025 .iter()
1026 .position(|&idx| idx == i)
1027 .ok_or_else(|| {
1028 EstimationError::InvalidInput(format!(
1029 "missing spatial hyper axis {} in anisotropy group {}",
1030 i, gid
1031 ))
1032 })?;
1033 penaltysecond_partner_indices = Some(
1034 group_indices
1035 .iter()
1036 .copied()
1037 .filter(|&idx| idx != i)
1038 .collect(),
1039 );
1040 let penalty_indices_inner = penalty_indices.clone();
1041 let global_range_inner = global_range.clone();
1042 let total_p_inner = total_p;
1043 let group_indices_inner = group_indices;
1044 Some(std::sync::Arc::new(
1045 move |j: usize| -> Result<
1046 Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
1047 EstimationError,
1048 > {
1049 let Some(other_axis_in_group) =
1050 group_indices_inner.iter().position(|&idx| idx == j)
1051 else {
1052 return Ok(None);
1053 };
1054 if other_axis_in_group == axis_in_group {
1055 return Ok(None);
1056 }
1057 let cross_pens = provider(other_axis_in_group)?;
1058 if cross_pens.is_empty() {
1059 return Ok(None);
1060 }
1061 Ok(Some(
1062 penalty_indices_inner
1063 .iter()
1064 .copied()
1065 .zip(cross_pens.into_iter().map(|local| {
1066 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1067 local,
1068 global_range_inner.clone(),
1069 total_p_inner,
1070 )
1071 }))
1072 .map(|(penalty_index, matrix)| {
1073 gam_solve::estimate::reml::PenaltyDerivativeComponent {
1074 penalty_index,
1075 matrix,
1076 }
1077 })
1078 .collect(),
1079 ))
1080 },
1081 )
1082 as std::sync::Arc<
1083 dyn Fn(
1084 usize,
1085 ) -> Result<
1086 Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
1087 EstimationError,
1088 > + Send
1089 + Sync
1090 + 'static,
1091 >)
1092 } else {
1093 None
1094 };
1095 let x_first_hyper = if let Some(ref op) = implicit_operator {
1098 gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
1099 op.clone(),
1100 ImplicitDerivLevel::First(implicit_axis),
1101 global_range.clone(),
1102 total_p,
1103 )
1104 } else {
1105 gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
1106 x_psi_local,
1107 global_range.clone(),
1108 total_p,
1109 )
1110 };
1111 let mut dir = DirectionalHyperParam::new_compact(
1112 x_first_hyper,
1113 s_components,
1114 Some(xsecond),
1115 Some(ssecond_components),
1116 )?
1117 .not_penalty_like();
1118 if let Some(provider) = penaltysecond_component_provider {
1119 dir = dir.with_penaltysecond_component_provider(provider);
1120 }
1121 if let Some(partner_indices) = penaltysecond_partner_indices {
1122 dir = dir.with_penaltysecond_partner_indices(partner_indices);
1123 }
1124 hyper_dirs.push(dir);
1125 }
1126 Ok(hyper_dirs)
1127}
1128
1129pub(crate) fn spatial_dims_per_term(
1135 resolvedspec: &TermCollectionSpec,
1136 spatial_terms: &[usize],
1137) -> Vec<usize> {
1138 spatial_terms
1139 .iter()
1140 .map(|&term_idx| {
1141 if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
1142 measure_jet_psi_dim(mj)
1145 } else if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
1146 get_spatial_feature_dim(resolvedspec, term_idx).unwrap_or(1)
1147 } else {
1148 1
1149 }
1150 })
1151 .collect()
1152}
1153
1154fn has_aniso_terms(resolvedspec: &TermCollectionSpec, spatial_terms: &[usize]) -> bool {
1158 spatial_terms
1159 .iter()
1160 .any(|&term_idx| spatial_term_uses_per_axis_psi(resolvedspec, term_idx))
1161}
1162
/// Expands to the memoization trio `memoized_cost`, `memoized_eval` and `store_eval` inside an
/// `impl` block whose type has `current_theta`, `last_cost` and `last_eval` fields.
///
/// Memoized values are served only when the requested `theta` matches `current_theta` under
/// `theta_values_match`; `store_eval` records both the full evaluation and its cost.
macro_rules! impl_exact_joint_theta_memo {
    () => {
        /// Cached cost at `theta`: the full evaluation's cost if present, else the bare cost.
        fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
            let theta_is_current = matches!(
                &self.current_theta,
                Some(cached) if theta_values_match(cached, theta)
            );
            if !theta_is_current {
                return None;
            }
            match &self.last_eval {
                Some((cost, _, _)) => Some(*cost),
                None => self.last_cost,
            }
        }

        /// Cached full evaluation at `theta`, if one was stored for the current theta.
        fn memoized_eval(
            &self,
            theta: &Array1<f64>,
        ) -> Option<(
            f64,
            Array1<f64>,
            gam_problem::HessianResult,
        )> {
            match &self.current_theta {
                Some(cached) if theta_values_match(cached, theta) => self.last_eval.clone(),
                _ => None,
            }
        }

        /// Stores a full evaluation and mirrors its cost into `last_cost`.
        fn store_eval(
            &mut self,
            eval: (
                f64,
                Array1<f64>,
                gam_problem::HessianResult,
            ),
        ) {
            let cost = eval.0;
            self.last_eval = Some(eval);
            self.last_cost = Some(cost);
        }
    };
}
1217
/// Outer-loop state for single-block exact-joint spatial optimization: an incremental design
/// realizer plus a memo of the most recent evaluation, keyed by `theta`.
struct SingleBlockExactJointDesignCache<'d> {
    /// Incremental realizer holding the current spec/design; `ensure_theta` applies new
    /// log-κ coordinates to it.
    realizer: FrozenTermCollectionIncrementalRealizer<'d>,
    /// `theta` most recently applied to `realizer` by `ensure_theta`.
    current_theta: Option<Array1<f64>>,
    /// `theta` at which `last_cost` / `last_eval` were recorded (`store_eval_at` /
    /// `store_cost_at`); memo lookups compare against this, not `current_theta`.
    last_eval_theta: Option<Array1<f64>>,
    /// Cost recorded at `last_eval_theta`, possibly without a full evaluation.
    last_cost: Option<f64>,
    /// Full evaluation recorded at `last_eval_theta`; the first element is the cost.
    /// NOTE(review): the `Array1<f64>` is presumably the gradient — confirm.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Hyper-directions tagged with the realizer design revision they were built for.
    cached_hyper_dirs: Option<(u64, Vec<DirectionalHyperParam>)>,
    /// Smooth-term indices treated as spatial terms.
    spatial_terms: Vec<usize>,
    /// Number of leading ρ entries in `theta`; entries from this index on form the ψ tail.
    rho_dim: usize,
    /// ψ dimension for each entry of `spatial_terms`, passed to
    /// `SpatialLogKappaCoords::from_theta_tail_with_dims`.
    dims_per_term: Vec<usize>,
}
1250
1251impl<'d> SingleBlockExactJointDesignCache<'d> {
1252 fn new(
1253 data: ArrayView2<'d, f64>,
1254 spec: TermCollectionSpec,
1255 design: TermCollectionDesign,
1256 spatial_terms: Vec<usize>,
1257 rho_dim: usize,
1258 dims_per_term: Vec<usize>,
1259 ) -> Result<Self, String> {
1260 Ok(Self {
1261 realizer: FrozenTermCollectionIncrementalRealizer::new(data, spec, design)?,
1262 current_theta: None,
1263 last_eval_theta: None,
1264 last_cost: None,
1265 last_eval: None,
1266 cached_hyper_dirs: None,
1267 spatial_terms,
1268 rho_dim,
1269 dims_per_term,
1270 })
1271 }
1272
1273 fn design_revision(&self) -> u64 {
1274 self.realizer.design_revision()
1275 }
1276
1277 fn hyper_dirs_for_current_design(
1287 &mut self,
1288 data: ArrayView2<'_, f64>,
1289 kind: SpatialHyperKind,
1290 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1291 let revision = self.realizer.design_revision();
1292 if let Some((cached_rev, dirs)) = self.cached_hyper_dirs.as_ref()
1293 && *cached_rev == revision
1294 {
1295 return Ok(dirs.clone());
1296 }
1297 let dirs = try_build_spatial_log_kappa_hyper_dirs(
1298 data,
1299 self.realizer.spec(),
1300 self.realizer.design(),
1301 &self.spatial_terms,
1302 )?
1303 .ok_or_else(|| {
1304 EstimationError::InvalidInput(format!(
1305 "failed to build {} hyper_dirs at current {}",
1306 kind.adjective(),
1307 kind.coord_name(),
1308 ))
1309 })?;
1310 self.cached_hyper_dirs = Some((revision, dirs.clone()));
1311 Ok(dirs)
1312 }
1313
1314 fn nfree_tensor_gradient_hyper_dirs(
1315 &mut self,
1316 theta: &Array1<f64>,
1317 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1318 let psi = &theta.as_slice().ok_or_else(|| {
1319 EstimationError::InvalidInput(
1320 "nfree_tensor_gradient_hyper_dirs: theta is not contiguous".to_string(),
1321 )
1322 })?[self.rho_dim..];
1323 let (global_range, p_total, s_psi_components) = self
1324 .realizer
1325 .canonical_penalty_derivatives_at_psi(&self.spatial_terms, psi)
1326 .map_err(EstimationError::InvalidInput)?;
1327 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::zero(
1328 self.realizer.design().design.nrows(),
1329 p_total,
1330 );
1331 let components = s_psi_components
1332 .into_iter()
1333 .enumerate()
1334 .map(|(penalty_index, local)| {
1335 (
1336 penalty_index,
1337 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1338 local,
1339 global_range.clone(),
1340 p_total,
1341 ),
1342 )
1343 })
1344 .collect::<Vec<_>>();
1345 Ok(DirectionalHyperParam::new_compact(zero_x, components, None, None)?.not_penalty_like())
1346 .map(|dir| vec![dir])
1347 }
1348
1349 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1350 if self
1351 .current_theta
1352 .as_ref()
1353 .is_some_and(|cached| theta_values_match(cached, theta))
1354 {
1355 return Ok(());
1356 }
1357 let t_ensure = std::time::Instant::now();
1358 let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
1359 theta,
1360 self.rho_dim,
1361 self.dims_per_term.clone(),
1362 );
1363 self.realizer
1364 .apply_log_kappa(&log_kappa, &self.spatial_terms)?;
1365 log::info!(
1366 "[STAGE] ensure_theta (apply_log_kappa, {} terms): {:.3}s",
1367 self.spatial_terms.len(),
1368 t_ensure.elapsed().as_secs_f64(),
1369 );
1370 self.current_theta = Some(theta.clone());
1371 self.last_eval_theta = None;
1372 self.last_cost = None;
1373 self.last_eval = None;
1374 Ok(())
1375 }
1376
1377 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1384 if self
1385 .last_eval_theta
1386 .as_ref()
1387 .is_some_and(|cached| theta_values_match(cached, theta))
1388 {
1389 self.last_eval
1390 .as_ref()
1391 .map(|cached| cached.0)
1392 .or(self.last_cost)
1393 } else {
1394 None
1395 }
1396 }
1397
1398 fn memoized_eval(
1399 &self,
1400 theta: &Array1<f64>,
1401 ) -> Option<(
1402 f64,
1403 Array1<f64>,
1404 gam_problem::HessianResult,
1405 )> {
1406 if self
1407 .last_eval_theta
1408 .as_ref()
1409 .is_some_and(|cached| theta_values_match(cached, theta))
1410 {
1411 self.last_eval.clone()
1412 } else {
1413 None
1414 }
1415 }
1416
1417 fn store_eval_at(
1421 &mut self,
1422 theta: &Array1<f64>,
1423 eval: (
1424 f64,
1425 Array1<f64>,
1426 gam_problem::HessianResult,
1427 ),
1428 ) {
1429 self.last_eval_theta = Some(theta.clone());
1430 self.last_cost = Some(eval.0);
1431 self.last_eval = Some(eval);
1432 }
1433
1434 fn store_cost_at(&mut self, theta: &Array1<f64>, cost: f64) {
1437 self.last_eval_theta = Some(theta.clone());
1438 self.last_cost = Some(cost);
1439 self.last_eval = None;
1443 }
1444
1445 fn spec(&self) -> &TermCollectionSpec {
1446 self.realizer.spec()
1447 }
1448
1449 fn design(&self) -> &TermCollectionDesign {
1450 self.realizer.design()
1451 }
1452
1453 fn supports_nfree_penalty_rekey(&self) -> bool {
1459 self.realizer
1460 .supports_nfree_penalty_rekey(&self.spatial_terms)
1461 }
1462
1463 fn supports_nfree_gradient_only_routing(&self) -> bool {
1464 self.realizer
1465 .supports_nfree_gradient_only_routing(&self.spatial_terms)
1466 }
1467
1468 fn canonical_penalties_at(
1478 &mut self,
1479 theta: &Array1<f64>,
1480 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
1481 let psi = &theta
1482 .as_slice()
1483 .ok_or_else(|| "canonical_penalties_at: theta is not contiguous".to_string())?
1484 [self.rho_dim..];
1485 self.realizer
1486 .canonical_penalties_at_psi(&self.spatial_terms, psi)
1487 }
1488}
1489
/// Outer-loop state for optimizing the latent coordinates that feed one smooth term.
///
/// The latent values live inside `theta`. `ensure_theta` writes them into `data`, and the
/// design is rebuilt through `latent_design_cache` when they change.
struct SingleBlockLatentCoordDesignCache {
    /// Working copy of the data; `ensure_theta` overwrites the `feature_cols` columns.
    data: Array2<f64>,
    /// Resolved spec used to rebuild the design.
    spec: TermCollectionSpec,
    /// Design realized for `current_theta` (initially the design passed to `new`).
    design: TermCollectionDesign,
    /// `theta` most recently realized by `ensure_theta`.
    current_theta: Option<Array1<f64>>,
    /// Latent coordinates decoded from `current_theta`.
    current_latent: Option<std::sync::Arc<gam_terms::latent::LatentCoordValues>>,
    /// Hyper-directions built alongside `design`.
    current_hyper_dirs: Option<Vec<gam_solve::estimate::reml::DirectionalHyperParam>>,
    /// Entry id in `latent_design_cache` that produced `design`.
    current_design_cache_id: Option<u64>,
    /// Cache of rebuilt designs and hyper-directions.
    latent_design_cache: gam_solve::latent_cache::LatentDesignCache,
    /// Memoized cost, served only while `last_outer_iter` equals the current outer iteration.
    last_cost: Option<f64>,
    /// Memoized full evaluation (cost first), with the same outer-iteration gating.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Smooth term whose inputs are the latent coordinates.
    term_index: gam_problem::types::SmoothTermIdx,
    /// Data columns receiving the latent coordinates, one per latent axis.
    feature_cols: Vec<usize>,
    /// Number of leading ρ entries in `theta`; the latent block starts right after them.
    rho_dim: usize,
    /// Number of observations (rows of latent coordinates).
    n_obs: usize,
    /// Number of latent coordinates per observation.
    latent_dim: usize,
    /// Id mode, manifold, retraction registry and id copied from the config, so that
    /// `ensure_theta` can rebuild `LatentCoordValues` from a flat `theta` slice.
    id_mode: gam_terms::latent::LatentIdMode,
    manifold: gam_terms::latent::LatentManifold,
    retraction_registry: gam_solve::latent_cache::LatentRetractionRegistry,
    latent_id: u64,
    /// Optional analytic penalty registry from the config.
    analytic_penalties: Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>>,
    /// `total_rho_count()` of `analytic_penalties` (0 when absent); counted in `theta`'s length.
    analytic_rho_count: usize,
    /// Counter bumped whenever the realized design changes.
    design_revision: u64,
    /// Outer iteration at which `last_cost` / `last_eval` were stored.
    last_outer_iter: Option<u64>,
}
1522
1523impl SingleBlockLatentCoordDesignCache {
1524 fn new(
1525 data: Array2<f64>,
1526 spec: TermCollectionSpec,
1527 design: TermCollectionDesign,
1528 latent: &StandardLatentCoordConfig,
1529 rho_dim: usize,
1530 ) -> Result<Self, String> {
1531 if latent.term_index.get() >= spec.smooth_terms.len() {
1532 return Err(SmoothError::dimension_mismatch(format!(
1533 "latent-coordinate term index {} out of bounds for {} smooth terms",
1534 latent.term_index,
1535 spec.smooth_terms.len()
1536 ))
1537 .into());
1538 }
1539 if latent.feature_cols.len() != latent.values.latent_dim() {
1540 return Err(SmoothError::dimension_mismatch(format!(
1541 "latent-coordinate feature width mismatch: feature_cols={}, latent_dim={}",
1542 latent.feature_cols.len(),
1543 latent.values.latent_dim()
1544 ))
1545 .into());
1546 }
1547 if latent.values.n_obs() != data.nrows() {
1548 return Err(SmoothError::dimension_mismatch(format!(
1549 "latent-coordinate row mismatch: latent n={}, data n={}",
1550 latent.values.n_obs(),
1551 data.nrows()
1552 ))
1553 .into());
1554 }
1555 let analytic_rho_count = latent
1556 .analytic_penalties
1557 .as_ref()
1558 .map_or(0, |registry| registry.total_rho_count());
1559 Ok(Self {
1560 data,
1561 spec,
1562 design,
1563 current_theta: None,
1564 current_latent: None,
1565 current_hyper_dirs: None,
1566 current_design_cache_id: None,
1567 latent_design_cache: gam_solve::latent_cache::LatentDesignCache::default(),
1568 last_cost: None,
1569 last_eval: None,
1570 term_index: latent.term_index,
1571 feature_cols: latent.feature_cols.clone(),
1572 rho_dim,
1573 n_obs: latent.values.n_obs(),
1574 latent_dim: latent.values.latent_dim(),
1575 id_mode: latent.values.id_mode().clone(),
1576 manifold: latent.values.manifold().clone(),
1577 retraction_registry: latent.values.retraction_registry().clone(),
1578 latent_id: latent.values.latent_id(),
1579 analytic_penalties: latent.analytic_penalties.clone(),
1580 analytic_rho_count,
1581 design_revision: 0,
1582 last_outer_iter: None,
1583 })
1584 }
1585
1586 fn design_revision(&self) -> u64 {
1587 self.design_revision
1588 }
1589
1590 fn design(&self) -> &TermCollectionDesign {
1591 &self.design
1592 }
1593
1594 fn latent(&self) -> Result<std::sync::Arc<gam_terms::latent::LatentCoordValues>, String> {
1595 self.current_latent
1596 .as_ref()
1597 .cloned()
1598 .ok_or_else(|| "latent-coordinate cache has not been realized".to_string())
1599 }
1600
1601 fn analytic_penalties(&self) -> Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>> {
1602 self.analytic_penalties.clone()
1603 }
1604
1605 fn analytic_penalty_rho_count(&self) -> usize {
1606 self.analytic_rho_count
1607 }
1608
1609 fn hyper_dirs(&self) -> Result<Vec<gam_solve::estimate::reml::DirectionalHyperParam>, String> {
1610 self.current_hyper_dirs
1611 .as_ref()
1612 .cloned()
1613 .ok_or_else(|| "latent-coordinate hyper_dirs cache has not been realized".to_string())
1614 }
1615
    /// Describes the realized latent smooth basis as a `LatentBasisKind`, used to key
    /// `latent_design_cache`.
    ///
    /// Pairs the resolved spec's basis with the realized term metadata for `term_index`.
    /// Supported pairs are Matérn, Duchon, Wahba-method Sphere, 1-D (periodic or
    /// non-periodic) B-spline, tensor B-spline and PCA.
    ///
    /// # Errors
    /// Returns a message in these cases:
    /// - `term_index` is out of bounds for the design or the spec;
    /// - the basis/metadata pair is not one listed above (including non-Wahba spheres);
    /// - a centered PCA basis without `pca_basis_path` lacks `center_mean`.
    fn latent_basis_kind(&self) -> Result<gam_solve::latent_cache::LatentBasisKind, String> {
        let smooth_term = self
            .design
            .smooth
            .terms
            .get(self.term_index.get())
            .ok_or_else(|| {
                SmoothError::dimension_mismatch(format!(
                    "LatentCoord term index {} out of bounds for realized smooth design",
                    self.term_index
                ))
            })?;
        let termspec = self
            .spec
            .smooth_terms
            .get(self.term_index.get())
            .ok_or_else(|| {
                SmoothError::dimension_mismatch(format!(
                    "LatentCoord term index {} out of bounds for resolved smooth spec",
                    self.term_index
                ))
            })?;
        // Both the spec variant and the realized metadata variant must agree.
        match (&termspec.basis, &smooth_term.metadata) {
            (
                SmoothBasisSpec::Matern { .. },
                BasisMetadata::Matern {
                    centers,
                    length_scale,
                    nu,
                    aniso_log_scales,
                    ..
                },
            ) => Ok(gam_solve::latent_cache::LatentBasisKind::Matern {
                centers: centers.clone(),
                length_scale: *length_scale,
                nu: *nu,
                // Missing anisotropy is keyed as all-zero log scales, one per center column.
                aniso_log_scales: aniso_log_scales
                    .clone()
                    .unwrap_or_else(|| vec![0.0; centers.ncols()]),
                chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
                    self.n_obs,
                    centers.nrows(),
                ),
            }),
            (
                SmoothBasisSpec::Duchon { .. },
                BasisMetadata::Duchon {
                    centers,
                    length_scale,
                    power,
                    nullspace_order,
                    aniso_log_scales,
                    ..
                },
            ) => Ok(gam_solve::latent_cache::LatentBasisKind::Duchon {
                centers: centers.clone(),
                length_scale: *length_scale,
                power: *power,
                nullspace_order: *nullspace_order,
                aniso_log_scales: aniso_log_scales
                    .clone()
                    .unwrap_or_else(|| vec![0.0; centers.ncols()]),
            }),
            // Only the Wahba sphere method is keyable; other sphere methods fall to the error arm.
            (
                SmoothBasisSpec::Sphere { .. },
                BasisMetadata::Sphere {
                    centers,
                    penalty_order,
                    method,
                    ..
                },
            ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
                Ok(gam_solve::latent_cache::LatentBasisKind::Sphere {
                    centers: centers.clone(),
                    penalty_order: *penalty_order,
                    chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
                        self.n_obs,
                        centers.nrows(),
                    ),
                })
            }
            (
                SmoothBasisSpec::BSpline1D { spec, .. },
                BasisMetadata::BSpline1D {
                    knots,
                    periodic,
                    degree: meta_degree,
                    ..
                },
            ) => {
                // The realized metadata degree wins over the spec degree when present.
                let effective_degree = meta_degree.unwrap_or(spec.degree);
                if let Some((domain_start, period, num_basis)) = periodic {
                    Ok(
                        gam_solve::latent_cache::LatentBasisKind::PeriodicBspline {
                            domain_start: *domain_start,
                            period: *period,
                            degree: effective_degree,
                            num_basis: *num_basis,
                            chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
                                self.n_obs, *num_basis,
                            ),
                        },
                    )
                } else {
                    // Standard B-spline count: #knots - (degree + 1), saturating at 0.
                    let num_basis_est = knots.len().saturating_sub(effective_degree + 1);
                    // A non-periodic 1-D B-spline is keyed as a one-axis tensor B-spline.
                    Ok(
                        gam_solve::latent_cache::LatentBasisKind::TensorBspline {
                            knots: vec![knots.clone()],
                            degrees: vec![effective_degree],
                            chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
                                self.n_obs,
                                num_basis_est,
                            ),
                        },
                    )
                }
            }
            (
                SmoothBasisSpec::TensorBSpline { .. },
                BasisMetadata::TensorBSpline { knots, degrees, .. },
            ) => Ok(
                gam_solve::latent_cache::LatentBasisKind::TensorBspline {
                    knots: knots.clone(),
                    degrees: degrees.clone(),
                    chunk_size: None,
                },
            ),
            (
                SmoothBasisSpec::Pca { .. },
                BasisMetadata::Pca {
                    basis_matrix,
                    centered,
                    smooth_penalty,
                    center_mean,
                    pca_basis_path,
                    chunk_size,
                    ..
                },
            ) => {
                // A centered basis without `pca_basis_path` is keyed by a fingerprint of its
                // center mean, which must therefore be present.
                let center_mean_fingerprint = if *centered && pca_basis_path.is_none() {
                    let mean = center_mean.as_ref().ok_or_else(|| {
                        SmoothError::invalid_config(
                            "latent-coordinate Pca cache key requires center_mean when centered",
                        )
                    })?;
                    Some(gam_solve::latent_cache::pca_center_mean_fingerprint(
                        mean,
                    ))
                } else {
                    None
                };
                Ok(gam_solve::latent_cache::LatentBasisKind::Pca {
                    basis_matrix: basis_matrix.clone(),
                    centered: *centered,
                    center_mean_fingerprint,
                    smooth_penalty: *smooth_penalty,
                    pca_basis_path: pca_basis_path.clone(),
                    chunk_size: *chunk_size,
                })
            }
            _ => Err(SmoothError::invalid_config(
                "latent-coordinate design cache could not key the realized latent smooth basis"
                    .to_string(),
            )
            .into()),
        }
    }
1786
    /// Realizes the latent design for `theta` unless `theta` is already current.
    ///
    /// `theta` must have `rho_dim + n_obs * latent_dim + analytic_rho_count + direct_hypers`
    /// entries, and the latent block immediately follows the ρ block. The steps are:
    /// 1. Decode the latent block into `LatentCoordValues`.
    /// 2. Write the coordinates into `data[.., feature_cols]`.
    /// 3. Look up or rebuild the design and hyper-directions in `latent_design_cache`.
    /// 4. Commit the result and clear the memoized cost/evaluation.
    ///
    /// # Errors
    /// Returns a message in these cases:
    /// - `theta` has the wrong length;
    /// - the basis cannot be keyed;
    /// - the rebuild fails;
    /// - the rebuilt design changes column count.
    ///
    /// NOTE(review): on such an error, `data` (and, if the latents changed, the cache and
    /// revision counter) may already have been updated while `current_theta` still names
    /// the previous point.
    fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
        if self
            .current_theta
            .as_ref()
            .is_some_and(|cached| theta_values_match(cached, theta))
        {
            return Ok(());
        }
        let latent_flat_len = self.n_obs * self.latent_dim;
        let direct_hyper_count = latent_coord_direct_hyper_count(&self.id_mode, self.latent_dim);
        let expected =
            self.rho_dim + latent_flat_len + self.analytic_rho_count + direct_hyper_count;
        if theta.len() != expected {
            return Err(SmoothError::dimension_mismatch(format!(
                "latent-coordinate theta length mismatch: got {}, expected {} (rho_dim={}, n={}, d={}, analytic_rhos={}, direct_hypers={})",
                theta.len(),
                expected,
                self.rho_dim,
                self.n_obs,
                self.latent_dim,
                self.analytic_rho_count,
                direct_hyper_count
            ))
            .into());
        }
        // Latent block of theta, flattened observation-major (n_obs x latent_dim).
        let flat = theta
            .slice(s![self.rho_dim..self.rho_dim + latent_flat_len])
            .to_owned();
        let latent = std::sync::Arc::new(
            gam_terms::latent::LatentCoordValues::from_flat_with_manifold_and_retraction_and_id(
                flat,
                self.n_obs,
                self.latent_dim,
                self.id_mode.clone(),
                self.manifold.clone(),
                self.retraction_registry.clone(),
                self.latent_id,
            ),
        );
        // With no previous latents, treat them as changed.
        let latent_values_changed = self
            .current_latent
            .as_ref()
            .map(|cached| !latent_values_match(cached.as_flat(), latent.as_flat()))
            .unwrap_or(true);
        if latent_values_changed {
            // New latent values: drop every cached design and mark the design as new.
            self.latent_design_cache.invalidate_all();
            self.current_design_cache_id = None;
            self.design_revision = self.design_revision.wrapping_add(1);
        }
        // Copy the latent coordinates into the data columns the smooth reads from.
        for n in 0..self.n_obs {
            for axis in 0..self.latent_dim {
                let col = self.feature_cols[axis];
                self.data[[n, col]] = latent.as_flat()[n * self.latent_dim + axis];
            }
        }

        let basis_kind = self.latent_basis_kind()?;
        // Locals let the rebuild closure capture owned copies instead of `self`.
        let rebuilt_width = self.design.design.ncols();
        let spec = self.spec.clone();
        let term_index = self.term_index;
        let analytic_rho_count = self.analytic_rho_count;
        let data = self.data.view();
        let design_context_digest =
            gam_solve::latent_cache::latent_design_context_cache_digest(
                data,
                &spec,
                term_index,
                analytic_rho_count,
                &self.feature_cols,
            )
            .map_err(|e| e.to_string())?;
        // On a cache miss, rebuild the full design and its hyper-directions.
        let lookup = self
            .latent_design_cache
            .lookup_or_compute(latent.clone(), basis_kind, design_context_digest, || {
                let rebuilt = build_term_collection_design(data, &spec).map_err(|e| {
                    EstimationError::InvalidInput(format!(
                        "failed to rebuild latent-coordinate design: {e}"
                    ))
                })?;
                if rebuilt.design.ncols() != rebuilt_width {
                    crate::bail_invalid_estim!(
                        "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
                        rebuilt.design.ncols(),
                        rebuilt_width
                    );
                }
                let hyper_dirs = try_build_latent_coord_hyper_dirs(
                    latent.clone(),
                    &spec,
                    &rebuilt,
                    &[term_index],
                    analytic_rho_count,
                )?
                .ok_or_else(|| {
                    EstimationError::InvalidInput(
                        "failed to build latent-coordinate hyper_dirs".to_string(),
                    )
                })?;
                Ok(gam_solve::latent_cache::ComputedLatentDesign {
                    design: rebuilt,
                    hyper_dirs,
                })
            })
            .map_err(|e| e.to_string())?;
        // Cache hits skip the closure's width check, so re-check here.
        if lookup.cached.design.design.ncols() != self.design.design.ncols() {
            return Err(SmoothError::dimension_mismatch(format!(
                "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
                lookup.cached.design.design.ncols(),
                self.design.design.ncols()
            ))
            .into());
        }
        self.design = lookup.cached.design.clone();
        self.current_hyper_dirs = Some(lookup.cached.hyper_dirs.clone());
        self.current_latent = Some(latent);
        self.current_theta = Some(theta.clone());
        self.last_cost = None;
        self.last_eval = None;
        self.last_outer_iter = None;
        // Same latents but a different cache entry still means a different design; bump
        // the revision (the changed-latents case already bumped above).
        if !latent_values_changed && self.current_design_cache_id != Some(lookup.entry_id) {
            self.design_revision = self.design_revision.wrapping_add(1);
        }
        self.current_design_cache_id = Some(lookup.entry_id);
        Ok(())
    }
1912
1913 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1914 if self
1915 .current_theta
1916 .as_ref()
1917 .is_some_and(|cached| theta_values_match(cached, theta))
1918 && self.last_outer_iter
1919 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
1920 {
1921 self.last_eval
1922 .as_ref()
1923 .map(|cached| cached.0)
1924 .or(self.last_cost)
1925 } else {
1926 None
1927 }
1928 }
1929
1930 fn memoized_eval(
1931 &self,
1932 theta: &Array1<f64>,
1933 ) -> Option<(
1934 f64,
1935 Array1<f64>,
1936 gam_problem::HessianResult,
1937 )> {
1938 if self
1939 .current_theta
1940 .as_ref()
1941 .is_some_and(|cached| theta_values_match(cached, theta))
1942 && self.last_outer_iter
1943 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
1944 {
1945 self.last_eval.clone()
1946 } else {
1947 None
1948 }
1949 }
1950
1951 fn store_eval(
1952 &mut self,
1953 eval: (
1954 f64,
1955 Array1<f64>,
1956 gam_problem::HessianResult,
1957 ),
1958 ) {
1959 self.last_cost = Some(eval.0);
1960 self.last_eval = Some(eval);
1961 self.last_outer_iter =
1962 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
1963 }
1964
1965 fn store_cost(&mut self, cost: f64) {
1966 self.last_cost = Some(cost);
1967 self.last_outer_iter =
1968 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
1969 }
1970
1971 fn reset(&mut self) {
1972 self.current_theta = None;
1973 self.current_latent = None;
1974 self.current_hyper_dirs = None;
1975 self.current_design_cache_id = None;
1976 self.latent_design_cache.invalidate();
1977 self.last_cost = None;
1978 self.last_eval = None;
1979 self.last_outer_iter = None;
1980 }
1981}
1982
1983pub fn fixed_kappa_profiled_reml_score(
1999 data: ArrayView2<'_, f64>,
2000 y: ArrayView1<'_, f64>,
2001 weights: ArrayView1<'_, f64>,
2002 offset: ArrayView1<'_, f64>,
2003 resolvedspec: &TermCollectionSpec,
2004 term_idx: usize,
2005 kappa: f64,
2006 family: LikelihoodSpec,
2007 options: &FitOptions,
2008) -> Result<f64, EstimationError> {
2009 if !kappa.is_finite() {
2010 crate::bail_invalid_estim!("fixed-κ profiled score probed a non-finite κ = {kappa}");
2011 }
2012 let (feature_cols, mut probe_basis) = match resolvedspec
2015 .smooth_terms
2016 .get(term_idx)
2017 .map(|t| &t.basis)
2018 {
2019 Some(SmoothBasisSpec::ConstantCurvature {
2020 feature_cols, spec, ..
2021 }) => (feature_cols.clone(), spec.clone()),
2022 _ => {
2023 crate::bail_invalid_estim!(
2024 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2025 )
2026 }
2027 };
2028 probe_basis.kappa = kappa;
2029
2030 let is_unweighted = weights.iter().all(|&w| (w - 1.0).abs() <= 1e-12);
2050 let is_zero_offset = offset.iter().all(|&o| o.abs() <= 1e-12);
2051 if family == LikelihoodSpec::gaussian_identity() && is_unweighted && is_zero_offset {
2052 let x_term = select_columns(data, &feature_cols).map_err(EstimationError::from)?;
2053 let score =
2054 gam_terms::basis::constant_curvature_honest_profiled_reml_score(x_term.view(), y, &probe_basis)
2055 .map_err(|e| {
2056 EstimationError::InvalidInput(format!(
2057 "fixed-κ honest profiled-REML score at κ={kappa} failed: {e}"
2058 ))
2059 })?;
2060 if !score.is_finite() {
2061 crate::bail_invalid_estim!(
2062 "fixed-κ honest profiled-REML score at κ={kappa} is non-finite"
2063 );
2064 }
2065 return Ok(score);
2066 }
2067
2068 let mut probe_spec = resolvedspec.clone();
2070 match probe_spec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis) {
2071 Some(SmoothBasisSpec::ConstantCurvature { spec, .. }) => spec.kappa = kappa,
2072 _ => {
2073 crate::bail_invalid_estim!(
2074 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2075 )
2076 }
2077 }
2078 let fixed_kappa_options = SpatialLengthScaleOptimizationOptions {
2079 enabled: false,
2080 ..SpatialLengthScaleOptimizationOptions::default()
2081 };
2082 let fit = fit_term_collectionwith_spatial_length_scale_optimization(
2083 data,
2084 y.to_owned(),
2085 weights.to_owned(),
2086 offset.to_owned(),
2087 &probe_spec,
2088 family,
2089 options,
2090 &fixed_kappa_options,
2091 )?;
2092 let score = fit_score(&fit.fit);
2093 if !score.is_finite() {
2094 crate::bail_invalid_estim!("fixed-κ profiled fit at κ={kappa} returned a non-finite score");
2095 }
2096 Ok(score)
2097}
2098
2099fn constant_curvature_kappa_fair_argmin(
2124 data: ArrayView2<'_, f64>,
2125 y: ArrayView1<'_, f64>,
2126 resolvedspec: &TermCollectionSpec,
2127 term_idx: usize,
2128) -> Option<f64> {
2129 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2130 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2131 return None;
2132 }
2133 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2134 Some(SmoothBasisSpec::ConstantCurvature {
2135 feature_cols, spec, ..
2136 }) => (feature_cols, spec.clone()),
2137 _ => return None,
2138 };
2139 let x_term = match select_columns(data, feature_cols) {
2140 Ok(x) => x,
2141 Err(e) => {
2142 log::info!("[spatial-kappa] #1464 κ-fair argmin column select failed ({e}); skipping");
2143 return None;
2144 }
2145 };
2146 const GRID_STEPS: usize = 24;
2152 let mut best: Option<(f64, f64)> = None; for i in 0..=GRID_STEPS {
2154 let t = i as f64 / GRID_STEPS as f64;
2155 let kappa = kappa_min + (kappa_max - kappa_min) * t;
2156 let mut probe_spec = base_spec.clone();
2157 probe_spec.kappa = kappa;
2158 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec) {
2159 Ok(score) => {
2160 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2161 best = Some((score, kappa));
2162 }
2163 }
2164 Err(e) => {
2165 log::info!(
2166 "[spatial-kappa] #1464 κ-fair argmin probe at κ={kappa:.4} failed ({e}); skipping"
2167 );
2168 }
2169 }
2170 }
2171 best.map(|(score, kappa)| {
2172 log::info!(
2173 "[spatial-kappa] #1464 κ-fair argmin κ̂={kappa:.4} (κ-fair score={score:.6e}) for term {term_idx}"
2174 );
2175 kappa
2176 })
2177}
2178
2179fn select_constant_curvature_kappa_sign_seed(
2187 data: ArrayView2<'_, f64>,
2188 y: ArrayView1<'_, f64>,
2189 resolvedspec: &TermCollectionSpec,
2190 term_idx: usize,
2191) -> Option<f64> {
2192 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2193 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2194 return None;
2195 }
2196 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2208 Some(SmoothBasisSpec::ConstantCurvature {
2209 feature_cols, spec, ..
2210 }) => (feature_cols, spec.clone()),
2211 _ => return None,
2212 };
2213 let x_term = match select_columns(data, feature_cols) {
2214 Ok(x) => x,
2215 Err(e) => {
2216 log::info!("[spatial-kappa] #1464 sign-basin scan column select failed ({e}); skipping");
2217 return None;
2218 }
2219 };
2220 let probes = [
2224 kappa_min,
2225 0.5 * kappa_min,
2226 0.0,
2227 0.5 * kappa_max,
2228 kappa_max,
2229 ];
2230 let mut best: Option<(f64, f64)> = None; for &kappa in &probes {
2232 let mut probe_spec = base_spec.clone();
2233 probe_spec.kappa = kappa;
2234 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(
2235 x_term.view(),
2236 y,
2237 &probe_spec,
2238 ) {
2239 Ok(score) => {
2240 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2241 best = Some((score, kappa));
2242 }
2243 }
2244 Err(e) => {
2245 log::info!(
2246 "[spatial-kappa] #1464 sign-basin probe at κ={kappa:.4} failed ({e}); skipping"
2247 );
2248 }
2249 }
2250 }
2251 best.map(|(score, kappa)| {
2252 log::info!(
2253 "[spatial-kappa] #1464 κ-fair sign-basin scan selected κ_seed={kappa:.4} \
2254 (κ-fair score={score:.6e}) for term {term_idx}"
2255 );
2256 kappa
2257 })
2258}
2259
2260const SPATIAL_RANGE_PRESCAN_GRID: usize = 7;
2263
2264fn prescan_isotropic_spatial_range_seed(
2296 data: ArrayView2<'_, f64>,
2297 y: ArrayView1<'_, f64>,
2298 weights: ArrayView1<'_, f64>,
2299 offset: ArrayView1<'_, f64>,
2300 resolvedspec: &TermCollectionSpec,
2301 baseline_score: f64,
2302 family: &LikelihoodSpec,
2303 options: &FitOptions,
2304 kappa_options: &SpatialLengthScaleOptimizationOptions,
2305 spatial_terms: &[usize],
2306) -> Result<Vec<(usize, f64)>, EstimationError> {
2307 if has_aniso_terms(resolvedspec, spatial_terms)
2309 || !constant_curvature_term_indices(resolvedspec).is_empty()
2310 {
2311 return Ok(Vec::new());
2312 }
2313 let dims = spatial_dims_per_term(resolvedspec, spatial_terms);
2314 let mut working = resolvedspec.clone();
2318 let mut best_score = if baseline_score.is_finite() {
2319 baseline_score
2320 } else {
2321 f64::INFINITY
2322 };
2323 let mut overrides: Vec<(usize, f64)> = Vec::new();
2324 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2325 if dims.get(slot).copied().unwrap_or(1) != 1 {
2328 continue;
2329 }
2330 if get_spatial_length_scale(&working, term_idx).is_none() {
2333 continue;
2334 }
2335 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &working, term_idx, kappa_options);
2336 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2337 continue;
2338 }
2339 let mut term_best: Option<f64> = None;
2340 for g in 0..SPATIAL_RANGE_PRESCAN_GRID {
2341 let frac = g as f64 / (SPATIAL_RANGE_PRESCAN_GRID - 1) as f64;
2342 let psi = psi_lo + (psi_hi - psi_lo) * frac;
2343 let ls = (-psi).exp();
2347 if !ls.is_finite() || ls <= 0.0 {
2348 continue;
2349 }
2350 let mut probe = working.clone();
2351 if set_spatial_length_scale(&mut probe, term_idx, ls).is_err() {
2352 continue;
2353 }
2354 let fit = match fit_term_collection_forspec(
2355 data,
2356 y,
2357 weights,
2358 offset,
2359 &probe,
2360 family.clone(),
2361 options,
2362 ) {
2363 Ok(fit) => fit,
2364 Err(_) => continue,
2367 };
2368 let score = fit_score(&fit.fit);
2369 if score.is_finite() && score < best_score - 1e-7 * best_score.abs().max(1.0) {
2372 best_score = score;
2373 term_best = Some(ls);
2374 }
2375 }
2376 if let Some(ls) = term_best {
2377 set_spatial_length_scale(&mut working, term_idx, ls)?;
2378 overrides.push((term_idx, ls));
2379 log::info!(
2380 "[spatial-kappa] #1074 range pre-scan: term {term_idx} re-seeded at \
2381 length_scale={ls:.5} (profiled REML {best_score:.5}, was {baseline_score:.5})"
2382 );
2383 }
2384 }
2385 Ok(overrides)
2386}
2387
2388fn try_exact_joint_spatial_length_scale_optimization(
2389 data: ArrayView2<'_, f64>,
2390 y: ArrayView1<'_, f64>,
2391 weights: ArrayView1<'_, f64>,
2392 offset: ArrayView1<'_, f64>,
2393 resolvedspec: &TermCollectionSpec,
2394 best: &FittedTermCollection,
2395 family: LikelihoodSpec,
2396 options: &FitOptions,
2397 kappa_options: &SpatialLengthScaleOptimizationOptions,
2398 spatial_terms: &[usize],
2399) -> Result<Option<FittedTermCollectionWithSpec>, EstimationError> {
2400 if spatial_terms.is_empty() {
2401 return Ok(None);
2402 }
2403 kappa_options
2408 .validate()
2409 .map_err(EstimationError::InvalidInput)?;
2410
2411 let cc_term_set = constant_curvature_term_indices(resolvedspec);
2431 let all_spatial_are_cc =
2432 !cc_term_set.is_empty() && spatial_terms.iter().all(|t| cc_term_set.contains(t));
2433 if all_spatial_are_cc {
2434 let mut fixed_kappa_spec = resolvedspec.clone();
2435 let mut any_kappa_chosen = false;
2436 for &term_idx in spatial_terms {
2437 if let Some(kappa_hat) =
2448 constant_curvature_kappa_fair_argmin(data, y, resolvedspec, term_idx)
2449 .filter(|&k| k < 0.0)
2450 {
2451 if let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) = fixed_kappa_spec
2452 .smooth_terms
2453 .get_mut(term_idx)
2454 .map(|t| &mut t.basis)
2455 {
2456 cc.kappa = kappa_hat;
2457 any_kappa_chosen = true;
2458 log::info!(
2459 "[spatial-kappa] #1464 term {term_idx}: fixed κ̂ = {kappa_hat:.4} from κ-fair argmin (hyperbolic basin; profiling ρ only)"
2460 );
2461 }
2462 }
2463 }
2464 if any_kappa_chosen {
2465 let baseline_score = fit_score(&best.fit);
2469 let fitted = fit_term_collection_forspec(
2470 data,
2471 y,
2472 weights,
2473 offset,
2474 &fixed_kappa_spec,
2475 family.clone(),
2476 options,
2477 )?;
2478 let frozen_spec =
2479 freeze_term_collection_from_design(&fixed_kappa_spec, &fitted.design)?;
2480 let mut fit = fitted.fit;
2481 fit.reml_score = baseline_score;
2493 return Ok(Some(FittedTermCollectionWithSpec {
2494 fit,
2495 design: fitted.design,
2496 resolvedspec: frozen_spec,
2497 adaptive_diagnostics: fitted.adaptive_diagnostics,
2498 kappa_timing: None,
2499 }));
2500 }
2501 }
2502
2503 if try_build_spatial_log_kappa_hyper_dirs(data, resolvedspec, &best.design, spatial_terms)?
2504 .is_none()
2505 {
2506 if !constant_curvature_term_indices(resolvedspec).is_empty() {
2507 log::info!(
2508 "[#1464-trace] try_exact_joint RETURNED None (hyper_dirs unavailable); \
2509 κ̂ comes from a NON-joint path"
2510 );
2511 }
2512 return Ok(None);
2513 }
2514 if !constant_curvature_term_indices(resolvedspec).is_empty() {
2515 log::info!(
2516 "[#1464-trace] try_exact_joint ENTERED for {} spatial term(s); CC present",
2517 spatial_terms.len()
2518 );
2519 }
2520
2521 const JOINT_RHO_BOUND: f64 = 12.0;
2522 let rho_dim = best.fit.lambdas.len();
2523
2524 let has_constant_curvature_term = !constant_curvature_term_indices(resolvedspec).is_empty();
2538 let rho_upper_bound = if has_constant_curvature_term {
2539 gam_solve::estimate::RHO_BOUND
2540 } else {
2541 JOINT_RHO_BOUND
2542 };
2543
2544 let dims_per_term = spatial_dims_per_term(resolvedspec, spatial_terms);
2546 let use_aniso = has_aniso_terms(resolvedspec, spatial_terms);
2547
2548 let log_kappa0 = if use_aniso {
2553 SpatialLogKappaCoords::from_length_scales_aniso(resolvedspec, spatial_terms, kappa_options)
2554 } else {
2555 SpatialLogKappaCoords::from_length_scales(resolvedspec, spatial_terms, kappa_options)
2556 };
2557 let mut log_kappa0 =
2560 log_kappa0.reseed_from_data(data, resolvedspec, spatial_terms, kappa_options);
2561 let mut cc_sign_seeds: Vec<(usize, f64)> = Vec::new();
2577 if has_constant_curvature_term {
2578 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2579 if constant_curvature_term_spec(resolvedspec, term_idx).is_none() {
2580 continue;
2581 }
2582 let scan = select_constant_curvature_kappa_sign_seed(
2583 data,
2584 y,
2585 resolvedspec,
2586 term_idx,
2587 );
2588 match scan {
2593 Some(kappa_seed) => {
2594 log::info!(
2595 "[#1464-trace] term {term_idx}: κ-fair sign-basin scan picked κ_seed = {kappa_seed}"
2596 );
2597 log_kappa0.set_scalar_slot(slot, kappa_seed);
2598 cc_sign_seeds.push((slot, kappa_seed));
2599 }
2600 None => {
2601 log::info!(
2602 "[#1464-trace] term {term_idx}: fixed-κ sign-basin scan returned NONE (no seed applied)"
2603 );
2604 }
2605 }
2606 }
2607 }
2608 let log_kappa_lower = if use_aniso {
2609 SpatialLogKappaCoords::lower_bounds_aniso_from_data(
2610 data,
2611 resolvedspec,
2612 spatial_terms,
2613 &dims_per_term,
2614 kappa_options,
2615 )
2616 } else {
2617 SpatialLogKappaCoords::lower_bounds_from_data(
2618 data,
2619 resolvedspec,
2620 spatial_terms,
2621 kappa_options,
2622 )
2623 };
2624 let log_kappa_upper = if use_aniso {
2625 SpatialLogKappaCoords::upper_bounds_aniso_from_data(
2626 data,
2627 resolvedspec,
2628 spatial_terms,
2629 &dims_per_term,
2630 kappa_options,
2631 )
2632 } else {
2633 SpatialLogKappaCoords::upper_bounds_from_data(
2634 data,
2635 resolvedspec,
2636 spatial_terms,
2637 kappa_options,
2638 )
2639 };
2640 let mut log_kappa_lower = log_kappa_lower;
2664 let mut log_kappa_upper = log_kappa_upper;
2665 for &(slot, kappa_seed) in &cc_sign_seeds {
2666 if kappa_seed != 0.0 {
2667 log_kappa_lower.set_scalar_slot(slot, kappa_seed);
2668 log_kappa_upper.set_scalar_slot(slot, kappa_seed);
2669 }
2670 log::info!(
2671 "[#1464-trace] slot {slot}: FROZE joint ψ coordinate at κ_seed={kappa_seed} \
2672 (window [{}, {}]); raw fit_score is sign-blind so the κ-fair scan is authoritative",
2673 log_kappa_lower.as_array()[log_kappa_lower.dims_per_term()[..slot].iter().sum::<usize>()],
2674 log_kappa_upper.as_array()[log_kappa_upper.dims_per_term()[..slot].iter().sum::<usize>()],
2675 );
2676 }
2677 let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
2680 let setup = ExactJointHyperSetup::new(
2681 best.fit.lambdas.mapv(f64::ln),
2682 Array1::<f64>::from_elem(rho_dim, -JOINT_RHO_BOUND),
2683 Array1::<f64>::from_elem(rho_dim, rho_upper_bound),
2684 log_kappa0,
2685 log_kappa_lower,
2686 log_kappa_upper,
2687 );
2688
2689 let theta0 = setup.theta0();
2690 let lower = setup.lower();
2691 let upper = setup.upper();
2692
2693 let kind = if use_aniso {
2705 SpatialHyperKind::Anisotropic
2706 } else {
2707 SpatialHyperKind::Isotropic
2708 };
2709 let (outcome, kappa_timing) = run_exact_joint_spatial_optimization(
2710 kind,
2711 data,
2712 y,
2713 weights,
2714 offset,
2715 resolvedspec,
2716 &best.design,
2717 family.clone(),
2718 options,
2719 spatial_terms,
2720 &dims_per_term,
2721 &theta0,
2722 &lower,
2723 &upper,
2724 rho_dim,
2725 kappa_options,
2726 )?;
2727
2728 let baseline_score = fit_score(&best.fit);
2729
2730 let (theta_star, joint_final_value) = match outcome {
2740 SpatialJointOutcome::Optimized {
2741 theta_star,
2742 final_value,
2743 } => (theta_star, final_value),
2744 SpatialJointOutcome::NonConverged {
2745 iterations,
2746 final_value,
2747 final_grad_norm,
2748 } => {
2749 if has_constant_curvature_term {
2750 log::info!(
2751 "[#1464-trace] joint solve NONCONVERGED (iters={iterations}, \
2752 final_value={final_value}); returning FROZEN BASELINE geometry \
2753 (κ̂ = spec default, NOT the joint candidate)"
2754 );
2755 }
2756 log::info!(
2757 "[spatial-kappa] joint spatial optimization did not converge \
2758 (iterations={}, final_objective={:.6e}, final_grad_norm={}); \
2759 keeping the frozen baseline geometry",
2760 iterations,
2761 final_value,
2762 final_grad_norm.map_or_else(|| "n/a".to_string(), |g| format!("{g:.3e}")),
2763 );
2764 return Ok(Some(fit_frozen_baseline_geometry(
2765 data,
2766 y,
2767 weights,
2768 offset,
2769 resolvedspec,
2770 best,
2771 family,
2772 options,
2773 baseline_score,
2774 Some(kappa_timing),
2775 )?));
2776 }
2777 };
2778
2779 let accept_tol = options.tol.max(1e-8 * baseline_score.abs()).max(1e-12);
2784 if joint_final_value > baseline_score + accept_tol {
2785 if has_constant_curvature_term {
2786 log::info!(
2787 "[#1464-trace] joint candidate WORSENED score (joint={joint_final_value}, \
2788 baseline={baseline_score}); returning FROZEN BASELINE geometry \
2789 (κ̂ = spec default, NOT the joint candidate)"
2790 );
2791 }
2792 log::info!(
2793 "[spatial-kappa] exact joint spatial candidate worsened the profiled score (joint={:.6e}, baseline={:.6e}, tol={:.2e}); keeping the frozen baseline geometry",
2794 joint_final_value,
2795 baseline_score,
2796 accept_tol,
2797 );
2798 return Ok(Some(fit_frozen_baseline_geometry(
2799 data,
2800 y,
2801 weights,
2802 offset,
2803 resolvedspec,
2804 best,
2805 family,
2806 options,
2807 baseline_score,
2808 Some(kappa_timing),
2809 )?));
2810 }
2811
2812 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
2813 let log_kappa_star =
2814 SpatialLogKappaCoords::from_theta_tail_with_dims(&theta_star, rho_dim, dims_per_term);
2815 if has_constant_curvature_term {
2821 let star = log_kappa_star.as_array();
2822 let dims = log_kappa_star.dims_per_term();
2823 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2824 if constant_curvature_term_spec(resolvedspec, term_idx).is_some() {
2825 let off: usize = dims[..slot].iter().sum();
2826 log::info!(
2827 "[#1464-trace] term {term_idx}: joint solver CONVERGED ψ-tail κ = {} \
2828 (this is the optimised candidate; joint_final_value={joint_final_value})",
2829 star[off]
2830 );
2831 }
2832 }
2833 }
2834 let baseline_spec = resolvedspec;
2838 let optimized_spec = log_kappa_star.apply_tospec(resolvedspec, spatial_terms)?;
2839 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
2840 data,
2841 y,
2842 weights,
2843 offset,
2844 &optimized_spec,
2845 rho_star.as_slice(),
2846 family.clone(),
2847 options,
2848 )?;
2849
2850 let optimized_edf = optimized.fit.inference.as_ref().map(|inf| inf.edf_total);
2864 if let Some(opt_edf) = optimized_edf
2865 && opt_edf < SPATIAL_COLLAPSE_EDF_FLOOR
2866 {
2867 let baseline = fit_frozen_baseline_geometry(
2868 data,
2869 y,
2870 weights,
2871 offset,
2872 baseline_spec,
2873 best,
2874 family.clone(),
2875 options,
2876 baseline_score,
2877 Some(kappa_timing),
2878 )?;
2879 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
2880 if let Some(base_edf) = baseline_edf
2881 && base_edf >= opt_edf + SPATIAL_COLLAPSE_EDF_MARGIN
2882 {
2883 log::info!(
2884 "[spatial-kappa] joint candidate collapsed to the null (edf={opt_edf:.3}); \
2885 baseline geometry retains edf={base_edf:.3} — keeping the frozen baseline",
2886 );
2887 return Ok(Some(baseline));
2888 }
2889 }
2892
2893 let mut fit = optimized.fit;
2897 fit.reml_score = joint_final_value;
2898 let optimized_result = FittedTermCollectionWithSpec {
2899 fit,
2900 design: optimized.design,
2901 resolvedspec: optimized_spec,
2902 adaptive_diagnostics: optimized.adaptive_diagnostics,
2903 kappa_timing: Some(kappa_timing),
2904 };
2905
2906 Ok(Some(optimized_result))
2907}
2908
/// Total-EDF threshold below which a spatial fit is treated as having
/// collapsed toward the null model by the joint-κ acceptance checks.
const SPATIAL_COLLAPSE_EDF_FLOOR: f64 = 2.5;
2913
/// Minimum EDF advantage the alternative fit must hold over a collapsed fit
/// before the collapse guards switch to it.
const SPATIAL_COLLAPSE_EDF_MARGIN: f64 = 1.0;
2919
2920fn fit_frozen_baseline_geometry(
2956 data: ArrayView2<'_, f64>,
2957 y: ArrayView1<'_, f64>,
2958 weights: ArrayView1<'_, f64>,
2959 offset: ArrayView1<'_, f64>,
2960 resolvedspec: &TermCollectionSpec,
2961 best: &FittedTermCollection,
2962 family: LikelihoodSpec,
2963 options: &FitOptions,
2964 baseline_score: f64,
2965 kappa_timing: Option<SpatialLengthScaleOptimizationTiming>,
2966) -> Result<FittedTermCollectionWithSpec, EstimationError> {
2967 let baseline = fit_term_collection_forspecwith_heuristic_lambdas(
2968 data,
2969 y,
2970 weights,
2971 offset,
2972 resolvedspec,
2973 best.fit.lambdas.as_slice(),
2974 family.clone(),
2975 options,
2976 )?;
2977 let best_edf = best.fit.inference.as_ref().map(|inf| inf.edf_total);
2982 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
2983 let baseline = match (best_edf, baseline_edf) {
2984 (Some(best_edf), Some(base_edf))
2985 if base_edf < SPATIAL_COLLAPSE_EDF_FLOOR
2986 && best_edf >= base_edf + SPATIAL_COLLAPSE_EDF_MARGIN =>
2987 {
2988 log::info!(
2989 "[spatial-kappa] warm-started frozen baseline collapsed (edf={base_edf:.3}) \
2990 below the certified baseline (edf={best_edf:.3}); refitting from scratch",
2991 );
2992 fit_term_collection_forspec(data, y, weights, offset, resolvedspec, family, options)?
2993 }
2994 _ => baseline,
2995 };
2996 let mut fit = baseline.fit;
2997 fit.reml_score = baseline_score;
2998 Ok(FittedTermCollectionWithSpec {
2999 fit,
3000 design: baseline.design,
3001 resolvedspec: resolvedspec.clone(),
3002 adaptive_diagnostics: baseline.adaptive_diagnostics,
3003 kappa_timing,
3004 })
3005}
3006
/// Parameterization of the spatial hyperparameters optimized jointly with ρ.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum SpatialHyperKind {
    /// Anisotropic spatial coordinates (named "psi").
    Anisotropic,
    /// Isotropic spatial coordinates (named "kappa").
    Isotropic,
}

impl SpatialHyperKind {
    /// All display names for this kind, as `(label, adjective, coord_name)`.
    fn name_table(self) -> (&'static str, &'static str, &'static str) {
        match self {
            Self::Anisotropic => ("spatial-aniso-joint", "anisotropic", "psi"),
            Self::Isotropic => ("spatial-iso-joint", "isotropic", "kappa"),
        }
    }

    /// Short tag used as the log/stage label for the joint solve.
    fn label(self) -> &'static str {
        self.name_table().0
    }

    /// Adjective used in human-readable error messages.
    fn adjective(self) -> &'static str {
        self.name_table().1
    }

    /// Name of the spatial coordinate for this parameterization.
    fn coord_name(self) -> &'static str {
        self.name_table().2
    }
}
3051
3052struct SpatialFrozenGlmInputs {
3058 y: Array1<f64>,
3059 weights: Array1<f64>,
3060 offset: Array1<f64>,
3061 family: LikelihoodSpec,
3062}
3063
3064fn frozen_glm_tensor_eligible_family(family: &LikelihoodSpec) -> bool {
3081 !family.is_gaussian_identity()
3082 && matches!(
3083 &family.response,
3084 ResponseFamily::Binomial
3085 | ResponseFamily::Poisson
3086 | ResponseFamily::Gamma
3087 | ResponseFamily::NegativeBinomial { .. }
3088 )
3089}
3090
3091struct SpatialJointContext<'d> {
3092 data: ArrayView2<'d, f64>,
3093 rho_dim: usize,
3094 kind: SpatialHyperKind,
3095 cache: SingleBlockExactJointDesignCache<'d>,
3096 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
3097 frozen_glm_inputs: Option<SpatialFrozenGlmInputs>,
3098 frozen_glm_psi_bounds: Option<(f64, f64)>,
3099 frozen_glm_tensor: Option<gam_solve::glm_sufficient_lane::FrozenWeightGramTensor>,
3100 frozen_glm_tensor_attempted: bool,
3101 frozen_glm_weight_memo: Option<(Array1<f64>, Array1<f64>)>,
3113}
3114
/// Snapshot of which preconditions for the n-free design-realization skip
/// hold at a trial θ.
#[derive(Clone, Copy, Debug, Default)]
struct NfreeSkipGateStatus {
    /// θ has exactly one spatial coordinate.
    shape: bool,
    /// The certified ψ-gram tensor covers ψ for value (and skip) evaluation.
    value: bool,
    /// Gradient coverage holds, or no gradient was requested.
    gradient: bool,
    /// The evaluator supports the n-free penalty re-key.
    penalty: bool,
    /// An n-free fast-path design revision is available.
    revision: bool,
    /// Second-order (Hessian) work was requested; this blocks the skip.
    second_order: bool,
}

impl NfreeSkipGateStatus {
    /// Whether every gate is open, so the skip would be taken.
    ///
    /// The gradient gate is consulted only when `require_gradient` is true.
    fn would_skip(self, require_gradient: bool) -> bool {
        let gradient_ok = self.gradient || !require_gradient;
        let core_gates_open = [self.shape, self.value, self.penalty, self.revision]
            .iter()
            .all(|&open| open);
        core_gates_open && gradient_ok && !self.second_order
    }
}
3135
3136impl<'d> SpatialJointContext<'d> {
3137 fn nfree_skip_gate_status(
3138 &self,
3139 theta: &Array1<f64>,
3140 allow_second_order: bool,
3141 require_gradient: bool,
3142 ) -> NfreeSkipGateStatus {
3143 let shape = theta.len() == self.rho_dim + 1;
3144 let (value, gradient) = if shape {
3145 let psi = theta[self.rho_dim];
3146 (
3147 self.evaluator.psi_gram_tensor_covers(psi)
3148 && self.evaluator.psi_gram_tensor_covers_skip(psi),
3149 !require_gradient || self.evaluator.psi_gram_tensor_covers_gradient(psi),
3150 )
3151 } else {
3152 (false, false)
3153 };
3154 NfreeSkipGateStatus {
3155 shape,
3156 value,
3157 gradient,
3158 penalty: self.evaluator.supports_nfree_penalty_rekey(),
3159 revision: self.evaluator.nfree_fast_path_revision().is_some(),
3160 second_order: allow_second_order,
3161 }
3162 }
3163
3164 fn frozen_glm_working_state(
3165 &self,
3166 beta: &Array1<f64>,
3167 ) -> Result<Option<(Array1<f64>, Array1<f64>)>, EstimationError> {
3168 let Some(inputs) = self.frozen_glm_inputs.as_ref() else {
3169 return Ok(None);
3170 };
3171 if beta.len() != self.cache.design().design.ncols() {
3172 return Ok(None);
3173 }
3174 let mut eta = self.cache.design().design.matrixvectormultiply(beta);
3175 if eta.len() != inputs.offset.len() {
3176 crate::bail_invalid_estim!(
3177 "frozen GLM tensor warm-state row mismatch: eta={}, offset={}",
3178 eta.len(),
3179 inputs.offset.len()
3180 );
3181 }
3182 eta += &inputs.offset;
3183 let obs = evaluate_standard_familyobservations(
3184 inputs.family.clone(),
3185 None,
3186 None,
3187 None,
3188 &inputs.y,
3189 &inputs.weights,
3190 &eta,
3191 )?;
3192 let mut working_response = obs.eta.clone();
3193 for i in 0..working_response.len() {
3194 let wi = obs.fisherweight[i].max(1e-12);
3195 working_response[i] += obs.score[i] / wi;
3196 }
3197 Ok(Some((obs.fisherweight, working_response)))
3198 }
3199
3200 fn frozen_glm_trial_weights(
3209 &mut self,
3210 beta: &Array1<f64>,
3211 ) -> Result<Option<Array1<f64>>, EstimationError> {
3212 if let Some((memo_beta, memo_w)) = self.frozen_glm_weight_memo.as_ref()
3213 && memo_beta.len() == beta.len()
3214 && memo_beta
3215 .iter()
3216 .zip(beta.iter())
3217 .all(|(a, b)| a.to_bits() == b.to_bits())
3218 {
3219 return Ok(Some(memo_w.clone()));
3220 }
3221 match self.frozen_glm_working_state(beta)? {
3222 Some((current_w, _)) => {
3223 self.frozen_glm_weight_memo = Some((beta.clone(), current_w.clone()));
3224 Ok(Some(current_w))
3225 }
3226 None => Ok(None),
3227 }
3228 }
3229
3230 fn ensure_frozen_glm_tensor(
3231 &mut self,
3232 theta: &Array1<f64>,
3233 warm_beta: Option<&Array1<f64>>,
3234 ) -> Result<(), EstimationError> {
3235 if self.frozen_glm_tensor.is_some() || self.frozen_glm_tensor_attempted {
3236 return Ok(());
3237 }
3238 let Some((psi_lo, psi_hi)) = self.frozen_glm_psi_bounds else {
3239 return Ok(());
3240 };
3241 if theta.len() != self.rho_dim + 1 {
3242 self.frozen_glm_tensor_attempted = true;
3243 return Ok(());
3244 }
3245 let Some(beta) = warm_beta else {
3246 return Ok(());
3247 };
3248 let Some((frozen_w, working_z)) = self.frozen_glm_working_state(beta)? else {
3249 self.frozen_glm_tensor_attempted = true;
3250 return Ok(());
3251 };
3252 let theta_probe_base = theta.clone();
3253 let rho_dim = self.rho_dim;
3254 let Self {
3261 cache, evaluator, ..
3262 } = self;
3263 let tensor = evaluator.build_frozen_glm_gram_tensor(
3264 |psi| {
3265 let mut theta_probe = theta_probe_base.clone();
3266 theta_probe[rho_dim] = psi;
3267 cache.ensure_theta(&theta_probe)?;
3268 Ok(cache.design().design.clone())
3269 },
3270 frozen_w.view(),
3271 working_z.view(),
3272 psi_lo,
3273 psi_hi,
3274 );
3275 self.cache
3276 .ensure_theta(theta)
3277 .map_err(EstimationError::InvalidInput)?;
3278 self.frozen_glm_tensor_attempted = true;
3279 if let Some(tensor) = tensor {
3280 self.frozen_glm_tensor = Some(tensor);
3281 log::info!(
3282 "[STAGE] {} certified frozen-W GLM ψ tensor over [{psi_lo:.3}, {psi_hi:.3}]",
3283 self.kind.label(),
3284 );
3285 } else {
3286 log::info!(
3287 "[STAGE] {} frozen-W GLM ψ tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]",
3288 self.kind.label(),
3289 );
3290 }
3291 Ok(())
3292 }
3293
3294 fn stage_frozen_glm_trial_statistics(
3295 &mut self,
3296 theta: &Array1<f64>,
3297 warm_beta: Option<&Array1<f64>>,
3298 allow_gradient: bool,
3299 ) -> Result<(), EstimationError> {
3300 let kind = self.kind;
3301 let mut staged_gram: Option<Array2<f64>> = None;
3302 let mut staged_deriv: Option<(Array2<f64>, Array1<f64>)> = None;
3303 if theta.len() == self.rho_dim + 1 {
3304 let psi = theta[self.rho_dim];
3305 let tensor_covers = self
3312 .frozen_glm_tensor
3313 .as_ref()
3314 .is_some_and(|t| t.contains(psi));
3315 let current_w = if tensor_covers {
3316 match warm_beta {
3317 Some(beta) => self.frozen_glm_trial_weights(beta)?,
3318 None => None,
3319 }
3320 } else {
3321 None
3322 };
3323 if let (Some(tensor), Some(current_w)) =
3324 (self.frozen_glm_tensor.as_ref(), current_w.as_ref())
3325 {
3326 const FROZEN_GLM_WEIGHT_DRIFT_RTOL: f64 = 1e-3;
3327 if tensor.weight_drift_within(current_w.view(), FROZEN_GLM_WEIGHT_DRIFT_RTOL) {
3328 staged_gram = Some(tensor.gram_at(psi));
3329 log::debug!(
3330 "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
3331 first-Fisher-step XᵀWX n-free (weight drift within tol)",
3332 kind.label(),
3333 );
3334 }
3335 if allow_gradient
3336 && tensor.contains_for_gradient(psi)
3337 && let Some((dgram_dpsi, drhs_dpsi)) =
3338 tensor.gradient_pair_if_sound(psi, current_w.view())
3339 {
3340 staged_deriv = Some((dgram_dpsi, drhs_dpsi));
3341 log::debug!(
3342 "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
3343 ψ-gradient (∂G/∂ψ, ∂b/∂ψ) n-free (gradient weight drift within \
3344 tight tol); B_j stays exact",
3345 kind.label(),
3346 );
3347 }
3348 }
3349 }
3350 self.evaluator.stage_glm_first_step_gram(staged_gram);
3351 self.evaluator.stage_glm_psi_gram_deriv(staged_deriv);
3352 Ok(())
3353 }
3354
    /// Full joint-REML outer evaluation (value, gradient and, when permitted,
    /// the outer Hessian) at `theta`.
    ///
    /// Second-order work is requested only when `order` is
    /// `ValueGradientHessian` *and* `analytic_outer_hessian_available`.
    /// Otherwise the request is downgraded to `ValueAndGradient`. A memoized
    /// evaluation at `theta` is returned when it satisfies the effective order.
    /// Successful results are stored back into the cache.
    ///
    /// # Errors
    /// Propagates design-realization failures (as `InvalidInput`), frozen-GLM
    /// tensor/staging failures, hyper-direction construction failures and
    /// evaluator errors.
    fn eval_full(
        &mut self,
        theta: &Array1<f64>,
        order: gam_solve::rho_optimizer::OuterEvalOrder,
        analytic_outer_hessian_available: bool,
    ) -> Result<
        (
            f64,
            Array1<f64>,
            gam_problem::HessianResult,
        ),
        EstimationError,
    > {
        use gam_solve::rho_optimizer::OuterEvalOrder;
        let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
            && analytic_outer_hessian_available;
        // A cached result satisfies a second-order request only if its Hessian
        // is analytic.
        if let Some(eval) = self.cache.memoized_eval(theta) {
            let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
            if cached_satisfies_order {
                return Ok(eval);
            }
        }
        let kind = self.kind;
        let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
        // Skip re-realizing the n×k design only for first-order work on a
        // single spatial coordinate, when the ψ-gram tensor covers ψ for value,
        // gradient and skip, the penalty can be re-keyed n-free, and an n-free
        // design revision exists.
        let skip_design_realization = !allow_second_order && theta.len() == self.rho_dim + 1 && {
            let psi = theta[self.rho_dim];
            self.evaluator.psi_gram_tensor_covers(psi)
                && self.evaluator.psi_gram_tensor_covers_gradient(psi)
                && self.evaluator.psi_gram_tensor_covers_skip(psi)
                && self.evaluator.supports_nfree_penalty_rekey()
                && nfree_fast_path_revision.is_some()
        };
        if skip_design_realization {
            log::debug!(
                "[STAGE] {} eval_full at psi={:.6}: skipping n×k design re-realization \
                 + reconditioning — criterion/gradient/inner-solve served n-free from \
                 the certified ψ-gram tensor (GaussianFixedCache + k-space ψ-derivatives)",
                kind.label(),
                theta[self.rho_dim],
            );
        } else {
            self.cache
                .ensure_theta(theta)
                .map_err(EstimationError::InvalidInput)?;
        }
        let warm_beta = self.evaluator.current_beta();
        // Frozen-W GLM staging; gradient statistics are staged only for
        // first-order evaluations.
        self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref())?;
        self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), !allow_second_order)?;
        // Hyper-directions come from the tensor on the skip path, and from the
        // freshly realized design otherwise.
        let hyper_dirs = if skip_design_realization {
            self.cache.nfree_tensor_gradient_hyper_dirs(theta)?
        } else {
            self.cache.hyper_dirs_for_current_design(self.data, kind)?
        };

        let design_revision = if skip_design_realization {
            nfree_fast_path_revision
        } else {
            Some(self.cache.design_revision())
        };
        // Stage the n-free penalty S(ψ). On failure, clear the stage so the
        // evaluator falls back to its slow path.
        if self.evaluator.supports_nfree_penalty_rekey() {
            match self.cache.canonical_penalties_at(theta) {
                Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
                Err(e) => {
                    log::warn!(
                        "[STAGE] {} eval_full at psi={:.6}: exact n-free S(ψ) rebuild failed \
                         ({e}); clearing stage (eval falls to slow path)",
                        kind.label(),
                        theta[self.rho_dim],
                    );
                    self.evaluator.stage_fast_path_penalty(None);
                }
            }
        }
        let eval = evaluate_joint_reml_outer_eval_at_theta(
            &mut self.evaluator,
            self.cache.design(),
            theta,
            self.rho_dim,
            hyper_dirs,
            warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
            if allow_second_order {
                order
            } else {
                OuterEvalOrder::ValueAndGradient
            },
            design_revision,
        );
        if let Ok(ref value) = eval {
            self.cache.store_eval_at(theta, value.clone());
        }
        eval
    }
3541
3542 fn eval_efs(
3543 &mut self,
3544 theta: &Array1<f64>,
3545 ) -> Result<gam_problem::EfsEval, EstimationError> {
3546 self.cache
3547 .ensure_theta(theta)
3548 .map_err(EstimationError::InvalidInput)?;
3549 let kind = self.kind;
3550 let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
3551 self.data,
3552 self.cache.spec(),
3553 self.cache.design(),
3554 &self.cache.spatial_terms,
3555 )?
3556 .ok_or_else(|| {
3557 EstimationError::InvalidInput(format!(
3558 "failed to build {} hyper_dirs for exact-joint EFS",
3559 kind.adjective(),
3560 ))
3561 })?;
3562 let design_revision = Some(self.cache.design_revision());
3563 let warm_beta = self.evaluator.current_beta();
3564 evaluate_joint_reml_efs_at_theta(
3565 &mut self.evaluator,
3566 self.cache.design(),
3567 theta,
3568 self.rho_dim,
3569 hyper_dirs,
3570 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3571 design_revision,
3572 )
3573 }
3574
3575 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
3581 if let Some(cost) = self.cache.memoized_cost(theta) {
3582 return cost;
3583 }
3584 let probe_start = std::time::Instant::now();
3599 let psi_distance = self
3600 .cache
3601 .current_theta
3602 .as_ref()
3603 .filter(|reference| reference.len() == theta.len())
3604 .map(|reference| {
3605 reference
3606 .iter()
3607 .zip(theta.iter())
3608 .map(|(a, b)| (a - b) * (a - b))
3609 .sum::<f64>()
3610 .sqrt()
3611 })
3612 .unwrap_or(f64::NAN);
3613 let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
3627 let skip_value_realization = theta.len() == self.rho_dim + 1 && {
3628 let psi = theta[self.rho_dim];
3629 self.evaluator.psi_gram_tensor_covers(psi)
3630 && self.evaluator.psi_gram_tensor_covers_skip(psi)
3639 && self.evaluator.supports_nfree_penalty_rekey()
3644 && nfree_fast_path_revision.is_some()
3645 };
3646 if theta.len() == self.rho_dim + 1
3647 && self.evaluator.has_psi_gram_tensor()
3648 && !self.evaluator.psi_gram_tensor_covers(theta[self.rho_dim])
3649 {
3650 self.cache.store_cost_at(theta, f64::INFINITY);
3651 return f64::INFINITY;
3652 }
3653 if !skip_value_realization && self.cache.ensure_theta(theta).is_err() {
3654 return f64::INFINITY;
3655 }
3656 if self.evaluator.supports_nfree_penalty_rekey() {
3662 match self.cache.canonical_penalties_at(theta) {
3663 Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
3664 Err(_) => self.evaluator.stage_fast_path_penalty(None),
3665 }
3666 }
3667 let warm_beta = self.evaluator.current_beta();
3668 if let Err(err) = self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref()) {
3669 log::warn!(
3670 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM tensor setup failed ({err}); \
3671 falling back to exact streamed Gram",
3672 self.kind.label(),
3673 if theta.len() > self.rho_dim {
3674 theta[self.rho_dim]
3675 } else {
3676 f64::NAN
3677 },
3678 );
3679 self.evaluator.stage_glm_first_step_gram(None);
3680 self.evaluator.stage_glm_psi_gram_deriv(None);
3681 } else if let Err(err) =
3682 self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), false)
3683 {
3684 log::warn!(
3685 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM staging failed ({err}); \
3686 falling back to exact streamed Gram",
3687 self.kind.label(),
3688 if theta.len() > self.rho_dim {
3689 theta[self.rho_dim]
3690 } else {
3691 f64::NAN
3692 },
3693 );
3694 self.evaluator.stage_glm_first_step_gram(None);
3695 self.evaluator.stage_glm_psi_gram_deriv(None);
3696 }
3697 let design_revision = if skip_value_realization {
3698 nfree_fast_path_revision
3699 } else {
3700 Some(self.cache.design_revision())
3701 };
3702 let cost_label = self.kind.label();
3703 let result = {
3704 let design = self.cache.design();
3705 self.evaluator.evaluate_cost_only(
3706 &design.design,
3707 &design.penalties,
3708 &design.nullspace_dims,
3709 design.linear_constraints.clone(),
3710 theta,
3711 self.rho_dim,
3712 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3713 cost_label,
3714 design_revision,
3715 )
3716 };
3717 match result {
3718 Ok(cost) => {
3719 log::debug!(
3720 "[STAGE] {cost_label} value-probe (order=Value): elapsed={:.3}s \
3721 cost={cost:.6e} trial_theta_distance={psi_distance:.3e}",
3722 probe_start.elapsed().as_secs_f64(),
3723 );
3724 self.cache.store_cost_at(theta, cost);
3725 cost
3726 }
3727 Err(_) => f64::INFINITY,
3728 }
3729 }
3730
3731 fn reset(&mut self) {
3732 self.cache.current_theta = None;
3733 self.cache.last_eval_theta = None;
3734 self.cache.last_cost = None;
3735 self.cache.last_eval = None;
3736 }
3737}
3738
3739enum SpatialJointOutcome {
3772 Optimized {
3776 theta_star: Array1<f64>,
3777 final_value: f64,
3778 },
3779 NonConverged {
3783 iterations: usize,
3784 final_value: f64,
3785 final_grad_norm: Option<f64>,
3786 },
3787}
3788
3789fn kphase_log_norms(theta: &Array1<f64>, rho_dim: usize) -> (f64, f64) {
3790 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
3791 let log_kappa_norm = theta
3792 .iter()
3793 .skip(rho_dim)
3794 .map(|v| v * v)
3795 .sum::<f64>()
3796 .sqrt();
3797 (theta_norm, log_kappa_norm)
3798}
3799
3800fn run_exact_joint_spatial_optimization(
3801 kind: SpatialHyperKind,
3802 data: ArrayView2<'_, f64>,
3803 y: ArrayView1<'_, f64>,
3804 weights: ArrayView1<'_, f64>,
3805 offset: ArrayView1<'_, f64>,
3806 resolvedspec: &TermCollectionSpec,
3807 baseline_design: &TermCollectionDesign,
3808 family: LikelihoodSpec,
3809 options: &FitOptions,
3810 spatial_terms: &[usize],
3811 dims_per_term: &[usize],
3812 theta0: &Array1<f64>,
3813 lower: &Array1<f64>,
3814 upper: &Array1<f64>,
3815 rho_dim: usize,
3816 kappa_options: &SpatialLengthScaleOptimizationOptions,
3817) -> Result<(SpatialJointOutcome, SpatialLengthScaleOptimizationTiming), EstimationError> {
3818 let label = kind.label();
3819 assert!(
3821 lower.len() == theta0.len() && upper.len() == theta0.len(),
3822 "spatial hyperparameter bounds must match theta length: lower_len={}, upper_len={}, theta_len={}",
3823 lower.len(),
3824 upper.len(),
3825 theta0.len()
3826 );
3827 assert!(
3828 baseline_design.smooth.terms.len() >= spatial_terms.len(),
3829 "baseline design must have at least one smooth term per spatial term: baseline_terms={}, spatial_terms={}",
3830 baseline_design.smooth.terms.len(),
3831 spatial_terms.len()
3832 );
3833 use gam_solve::rho_optimizer::OuterEvalOrder;
3834 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
3835
3836 let theta_dim = theta0.len();
3837 let coord_dim = theta_dim - rho_dim;
3840 let analytic_outer_hessian_available =
3850 exact_joint_spatial_outer_hessian_available(&family, baseline_design);
3851 if !analytic_outer_hessian_available {
3852 log::info!(
3853 "[{label}] analytic outer Hessian unavailable for family/design; routing without second-order geometry (coord_dim={coord_dim})"
3854 );
3855 }
3856 let mut prefer_gradient_only = theta_dim > EXACT_JOINT_SECOND_ORDER_THETA_CAP;
3862 if prefer_gradient_only {
3863 log::info!(
3864 "[{label}] joint θ-dim {theta_dim} exceeds the exact pair-Hessian budget \
3865 ({EXACT_JOINT_SECOND_ORDER_THETA_CAP}); routing gradient-only quasi-Newton"
3866 );
3867 }
3868 let mut suppress_outer_hessian_for_nfree = false;
3878
3879 log::trace!(
3880 "[{}] starting analytic optimization: rho_dim={}, coord_dim={}, dims_per_term={:?}",
3881 label,
3882 rho_dim,
3883 coord_dim,
3884 dims_per_term,
3885 );
3886
3887 let mut ctx = SpatialJointContext {
3888 data,
3889 rho_dim,
3890 kind,
3891 cache: SingleBlockExactJointDesignCache::new(
3892 data,
3893 resolvedspec.clone(),
3894 baseline_design.clone(),
3895 spatial_terms.to_vec(),
3896 rho_dim,
3897 dims_per_term.to_vec(),
3898 )
3899 .map_err(EstimationError::InvalidInput)?,
3900 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
3901 y,
3902 weights,
3903 &baseline_design.design,
3904 offset,
3905 &baseline_design.penalties,
3906 &external_opts_for_design(&family, baseline_design, options),
3907 label,
3908 )?,
3909 frozen_glm_inputs: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
3910 Some(SpatialFrozenGlmInputs {
3911 y: y.to_owned(),
3912 weights: weights.to_owned(),
3913 offset: offset.to_owned(),
3914 family: family.clone(),
3915 })
3916 } else {
3917 None
3918 },
3919 frozen_glm_psi_bounds: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
3920 Some((lower[rho_dim], upper[rho_dim]))
3921 } else {
3922 None
3923 },
3924 frozen_glm_tensor: None,
3925 frozen_glm_tensor_attempted: false,
3926 frozen_glm_weight_memo: None,
3927 };
3928
3929 let mut psi_rank_stable_floor: Option<f64> = None;
3952 let mut psi_rank_stable_ceiling: Option<f64> = None;
3961 let nfree_penalty_capable = coord_dim == 1
3962 && family.is_gaussian_identity()
3963 && ctx.cache.supports_nfree_penalty_rekey();
3964 if nfree_penalty_capable {
3965 let psi_lo = lower[rho_dim];
3966 let psi_hi = upper[rho_dim];
3967 let z = Array1::from_iter(y.iter().zip(offset.iter()).map(|(yi, oi)| yi - oi));
3968 let theta_probe_base = theta0.clone();
3969 let SpatialJointContext {
3972 cache, evaluator, ..
3973 } = &mut ctx;
3974 let attached = evaluator.build_and_set_psi_gram_tensor(
3975 |psi| {
3976 let mut theta_probe = theta_probe_base.clone();
3977 theta_probe[rho_dim] = psi;
3978 cache.ensure_theta(&theta_probe)?;
3979 Ok(cache.design().design.clone())
3980 },
3981 weights,
3982 z.view(),
3983 psi_lo,
3984 psi_hi,
3985 );
3986 if attached {
3987 log::info!(
3988 "[{label}] certified ψ-gram tensor over [{psi_lo:.3}, {psi_hi:.3}]: \
3989 in-window trials assemble Gaussian sufficient statistics n-free"
3990 );
3991 let psi_anchor = theta0[rho_dim];
3996 psi_rank_stable_floor = evaluator
3997 .psi_gram_rank_stable_floor(psi_anchor)
3998 .filter(|&f| f.is_finite() && f > psi_lo && f < psi_anchor);
3999 log::info!(
4000 "[KAPPA-PHASE-FLOOR] n_rows={} psi_lo={psi_lo:.6} psi_anchor={psi_anchor:.6} \
4001 rank_stable_floor={:?} lifted={}",
4002 data.nrows(),
4003 evaluator.psi_gram_rank_stable_floor(psi_anchor),
4004 psi_rank_stable_floor.is_some(),
4005 );
4006 if let Some(floor) = psi_rank_stable_floor {
4007 log::info!(
4008 "[{label}] rank-stable κ-floor ψ_floor={floor:.6} > window floor \
4009 ψ_lo={psi_lo:.6}: lifting the optimizer lower bound to keep every \
4010 in-window trial on the n-free design-realization skip (#1033). The \
4011 conditioned Gram is rank-deficient below ψ_floor (longest-length-scale \
4012 radial mode collapses into the nullspace), where the skip is soundly \
4013 refused; that band drifts with n via the sample-std standardization, \
4014 so this n-free k-space floor is the n-independent fix."
4015 );
4016 }
4017 psi_rank_stable_ceiling = evaluator
4026 .psi_gram_rank_stable_ceiling(psi_anchor)
4027 .filter(|&c| c.is_finite() && c < psi_hi && c > psi_anchor);
4028 log::info!(
4029 "[KAPPA-PHASE-CEIL] n_rows={} psi_hi={psi_hi:.6} psi_anchor={psi_anchor:.6} \
4030 rank_stable_ceiling={:?} clamped={}",
4031 data.nrows(),
4032 evaluator.psi_gram_rank_stable_ceiling(psi_anchor),
4033 psi_rank_stable_ceiling.is_some(),
4034 );
4035 if let Some(ceiling) = psi_rank_stable_ceiling {
4036 log::info!(
4037 "[{label}] rank-stable κ-ceiling ψ_ceil={ceiling:.6} < window ceiling \
4038 ψ_hi={psi_hi:.6}: clamping the optimizer upper bound to keep every \
4039 in-window trial on the n-free design-realization skip (#1033). The \
4040 conditioned Gram is rank-deficient above ψ_ceil (longest-frequency \
4041 radial mode goes collinear), where the skip is soundly refused; a \
4042 line-search overshoot there trips the O(n) reset_surface lane (and the \
4043 deficient pinning ψ it records resets the next in-band trial too)."
4044 );
4045 }
4046 let gradient_covers_full_window = evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4047 && evaluator.psi_gram_tensor_covers_gradient(psi_hi);
4048 if gradient_covers_full_window {
4049 log::info!(
4050 "[{label}] certified ψ-gram tensor gradient lane covers the full \
4051 optimizer window [{psi_lo:.3}, {psi_hi:.3}]"
4052 );
4053 } else {
4054 log::info!(
4055 "[{label}] ψ-gram tensor value lane certified, but the gradient lane \
4056 does not cover the full optimizer window [{psi_lo:.3}, {psi_hi:.3}]; \
4057 keeping exact streamed kappa routing"
4058 );
4059 }
4060 evaluator.set_supports_nfree_penalty_rekey(true);
4080 log::info!(
4081 "[{label}] exact n-free ψ-penalty re-key enabled over [{psi_lo:.3}, \
4082 {psi_hi:.3}]: in-window fast-path trials rebuild S(ψ) n-free from frozen \
4083 geometry (no reset_surface)"
4084 );
4085 } else {
4086 log::info!(
4087 "[{label}] ψ-gram tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]; \
4088 keeping the exact per-trial path"
4089 );
4090 }
4091 if attached
4112 && evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4113 && evaluator.psi_gram_tensor_covers_gradient(psi_hi)
4114 && evaluator.supports_nfree_penalty_rekey()
4115 && cache.supports_nfree_gradient_only_routing()
4116 {
4117 suppress_outer_hessian_for_nfree = true;
4118 prefer_gradient_only = true;
4119 log::info!(
4120 "[{label}] n-free Gaussian ψ-lane armed; suppressing the analytic outer \
4121 Hessian and routing gradient-only (BFGS) so the κ outer loop never realizes \
4122 the O(n) second-order slab — n-independent outer loop (#1033)"
4123 );
4124 }
4125 } else if coord_dim == 1 && family.is_gaussian_identity() {
4126 log::info!(
4127 "[{label}] exact n-free ψ-penalty re-key unavailable; skipping ψ-gram tensor \
4128 attachment so value, gradient, and Hessian remain on the same exact streamed \
4129 objective"
4130 );
4131 }
4132
4133 const OUTER_FD_AUDIT_MAX_N: usize = 4_000; const OUTER_FD_AUDIT_MAX_THETA_DIM: usize = 32; let n_total = data.nrows();
4150 let outer_fd_audit_eligible = analytic_outer_hessian_available && n_total <= OUTER_FD_AUDIT_MAX_N && theta_dim <= OUTER_FD_AUDIT_MAX_THETA_DIM; log::warn!(
4154 "[OUTER-FD-AUDIT/spatial-exact-joint] gate eligible={outer_fd_audit_eligible} \
4155 analytic_grad={analytic_outer_hessian_available} n_total={n_total} \
4156 theta_dim={theta_dim} rho_dim={rho_dim} psi_dim={coord_dim}"
4157 );
4158 if outer_fd_audit_eligible {
4159 let audit = (|| -> Result<gam_solve::rho_optimizer::OuterGradientFdAudit, String> {
4161 let mut eval_at = |theta: &Array1<f64>,
4162 mode: gam_solve::estimate::reml::reml_outer_engine::EvalMode|
4163 -> Result<
4164 (
4165 f64,
4166 Array1<f64>,
4167 gam_problem::HessianResult,
4168 ),
4169 String,
4170 > {
4171 use gam_solve::estimate::reml::reml_outer_engine::EvalMode;
4172 let order = if matches!(mode, EvalMode::ValueGradientHessian) {
4173 OuterEvalOrder::ValueGradientHessian
4174 } else {
4175 OuterEvalOrder::Value
4176 };
4177 ctx.eval_full(theta, order, analytic_outer_hessian_available)
4178 .map_err(|e| format!("fd-audit eval_full: {e}"))
4179 };
4180 let rho_dim_audit = rho_dim;
4181 let label_fn = move |i: usize| -> String {
4182 if i < rho_dim_audit {
4183 format!("rho[{i}]")
4184 } else {
4185 format!("psi_kappa[{}]", i - rho_dim_audit)
4186 }
4187 };
4188 gam_solve::rho_optimizer::outer_gradient_fd_audit(
4189 theta0,
4191 1e-4,
4192 label_fn,
4193 &mut eval_at,
4194 )
4195 })();
4196 match audit {
4198 Ok(audit) => audit.log_verdict("spatial-exact-joint"),
4199 Err(e) => log::warn!("[OUTER-FD-AUDIT/spatial-exact-joint] skipped: {e}"),
4200 }
4201 }
4202
4203 let kphase_prime_order = if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4204 OuterEvalOrder::ValueGradientHessian
4205 } else {
4206 OuterEvalOrder::ValueAndGradient
4207 };
4208 let kphase_prime_start = std::time::Instant::now();
4209 drop(ctx.eval_full(theta0, kphase_prime_order, analytic_outer_hessian_available)?);
4210 log::info!(
4211 "[KAPPA-PHASE-PRIME] n_rows={} order={:?} elapsed_s={:.4} slow_path_resets_total={} design_revision={}",
4212 data.nrows(),
4213 kphase_prime_order,
4214 kphase_prime_start.elapsed().as_secs_f64(),
4215 ctx.evaluator.slow_path_reset_count(),
4216 ctx.cache.design_revision(),
4217 );
4218
4219 let kphase_cost_calls = std::cell::Cell::new(0usize);
4220 let kphase_eval_calls = std::cell::Cell::new(0usize);
4221 let kphase_efs_calls = std::cell::Cell::new(0usize);
4222 let kphase_cost_total_s = std::cell::Cell::new(0.0);
4223 let kphase_eval_total_s = std::cell::Cell::new(0.0);
4224 let kphase_efs_total_s = std::cell::Cell::new(0.0);
4225 let kphase_nfree_miss_shape = std::cell::Cell::new(0u64);
4226 let kphase_nfree_miss_value = std::cell::Cell::new(0u64);
4227 let kphase_nfree_miss_gradient = std::cell::Cell::new(0u64);
4228 let kphase_nfree_miss_penalty = std::cell::Cell::new(0u64);
4229 let kphase_nfree_miss_revision = std::cell::Cell::new(0u64);
4230 let kphase_nfree_miss_second_order = std::cell::Cell::new(0u64);
4231 let kphase_nfree_miss_other = std::cell::Cell::new(0u64);
4232 let kphase_optim_start = std::time::Instant::now();
4233 let kphase_log_kappa_dim = coord_dim;
4234 let kphase_slow_resets_start = ctx.evaluator.slow_path_reset_count();
4235 let kphase_design_revision_start = ctx.cache.design_revision();
4236
4237 let lower_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_floor {
4244 Some(floor) if coord_dim == 1 && floor > lower[rho_dim] => {
4245 let mut lifted = lower.clone();
4246 lifted[rho_dim] = floor;
4247 std::borrow::Cow::Owned(lifted)
4248 }
4249 _ => std::borrow::Cow::Borrowed(lower),
4250 };
4251 let lower = lower_effective.as_ref();
4252
4253 let upper_effective: std::borrow::Cow<'_, Array1<f64>> = match psi_rank_stable_ceiling {
4261 Some(ceiling) if coord_dim == 1 && ceiling < upper[rho_dim] => {
4262 let mut clamped = upper.clone();
4263 clamped[rho_dim] = ceiling;
4264 std::borrow::Cow::Owned(clamped)
4265 }
4266 _ => std::borrow::Cow::Borrowed(upper),
4267 };
4268 let upper = upper_effective.as_ref();
4269
4270 let problem = exact_joint_multistart_outer_problem(
4271 theta0,
4272 lower,
4273 upper,
4274 rho_dim,
4275 coord_dim,
4276 theta_dim,
4277 Derivative::Analytic,
4278 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4279 DeclaredHessianForm::Either
4280 } else {
4281 DeclaredHessianForm::Unavailable
4286 },
4287 prefer_gradient_only,
4288 suppress_outer_hessian_for_nfree,
4299 seed_risk_profile_for_likelihood_family(&family),
4300 kappa_options.rel_tol.max(1e-6),
4301 kappa_options.max_outer_iter.max(1),
4302 Some(5.0),
4306 Some(kappa_options.log_step.clamp(0.25, 1.0)),
4308 None,
4309 Some((data.nrows(), baseline_design.design.ncols())),
4314 !constant_curvature_term_indices(resolvedspec).is_empty(),
4318 );
4319
4320 let eval_outer = |ctx: &mut &mut SpatialJointContext<'_>,
4321 theta: &Array1<f64>,
4322 order: OuterEvalOrder|
4323 -> Result<OuterEval, EstimationError> {
4324 let t0 = std::time::Instant::now();
4325 let allow_second_order_for_call = matches!(order, OuterEvalOrder::ValueGradientHessian)
4326 && analytic_outer_hessian_available;
4327 let gate = ctx.nfree_skip_gate_status(theta, allow_second_order_for_call, true);
4328 let resets_before = ctx.evaluator.slow_path_reset_count();
4329 let raw = ctx.eval_full(theta, order, analytic_outer_hessian_available);
4330 let reset_delta = ctx
4331 .evaluator
4332 .slow_path_reset_count()
4333 .saturating_sub(resets_before);
4334 if reset_delta > 0 {
4335 if !gate.shape {
4336 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4337 }
4338 if gate.shape && !gate.value {
4339 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4340 }
4341 if gate.shape && gate.value && !gate.gradient {
4342 kphase_nfree_miss_gradient.set(kphase_nfree_miss_gradient.get() + reset_delta);
4343 }
4344 if gate.shape && gate.value && gate.gradient && !gate.penalty {
4345 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4346 }
4347 if gate.shape && gate.value && gate.gradient && gate.penalty && !gate.revision {
4348 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4349 }
4350 if gate.shape
4351 && gate.value
4352 && gate.gradient
4353 && gate.penalty
4354 && gate.revision
4355 && gate.second_order
4356 {
4357 kphase_nfree_miss_second_order
4358 .set(kphase_nfree_miss_second_order.get() + reset_delta);
4359 }
4360 if gate.would_skip(true) {
4361 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4362 }
4363 }
4364 let elapsed_s = t0.elapsed().as_secs_f64();
4365 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
4366 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
4367 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4368 log::info!(
4369 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4370 kphase_eval_calls.get(),
4371 order,
4372 Some(ctx.cache.design_revision()),
4373 theta_norm,
4374 log_kappa_norm,
4375 elapsed_s,
4376 );
4377 match raw {
4378 Ok((cost, grad, hess)) => Ok(OuterEval {
4379 cost,
4380 gradient: grad,
4381 hessian: hess,
4382 inner_beta_hint: None,
4383 }),
4384 Err(err) if is_recoverable_trial_point_error(&err) => {
4392 log::debug!(
4393 "[{label}] trial point infeasible (kernel design \
4394 not constructible at theta={theta:?}): {err}; retreating",
4395 );
4396 Ok(OuterEval::infeasible(theta_dim))
4397 }
4398 Err(err) => Err(err),
4399 }
4400 };
4401
4402 let mut obj = problem.build_objective_with_eval_order(
4403 &mut ctx,
4404 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4405 let t0 = std::time::Instant::now();
4406 let gate = ctx.nfree_skip_gate_status(theta, false, false);
4407 let resets_before = ctx.evaluator.slow_path_reset_count();
4408 let cost = ctx.eval_cost(theta);
4409 let reset_delta = ctx
4410 .evaluator
4411 .slow_path_reset_count()
4412 .saturating_sub(resets_before);
4413 if reset_delta > 0 {
4414 if !gate.shape {
4415 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4416 }
4417 if gate.shape && !gate.value {
4418 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4419 }
4420 if gate.shape && gate.value && !gate.penalty {
4421 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4422 }
4423 if gate.shape && gate.value && gate.penalty && !gate.revision {
4424 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4425 }
4426 if gate.would_skip(false) {
4427 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4428 }
4429 }
4430 let elapsed_s = t0.elapsed().as_secs_f64();
4431 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
4432 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
4433 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4434 log::info!(
4435 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4436 kphase_cost_calls.get(),
4437 Some(ctx.cache.design_revision()),
4438 theta_norm,
4439 log_kappa_norm,
4440 elapsed_s,
4441 );
4442 Ok(cost)
4443 },
4444 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4445 eval_outer(
4446 ctx,
4447 theta,
4448 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4458 OuterEvalOrder::ValueGradientHessian
4459 } else {
4460 OuterEvalOrder::ValueAndGradient
4461 },
4462 )
4463 },
4464 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
4465 eval_outer(ctx, theta, order)
4466 },
4467 Some(|ctx: &mut &mut SpatialJointContext<'_>| {
4468 ctx.reset();
4469 }),
4470 Some(|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4471 let t0 = std::time::Instant::now();
4472 let eval = ctx.eval_efs(theta);
4473 let elapsed_s = t0.elapsed().as_secs_f64();
4474 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
4475 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
4476 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4477 log::info!(
4478 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4479 kphase_efs_calls.get(),
4480 Some(ctx.cache.design_revision()),
4481 theta_norm,
4482 log_kappa_norm,
4483 elapsed_s,
4484 );
4485 eval
4486 }),
4487 );
4488
4489 let run_label = match kind {
4490 SpatialHyperKind::Anisotropic => "aniso-psi joint REML",
4491 SpatialHyperKind::Isotropic => "iso-kappa joint REML",
4492 };
4493 let result = problem.run(&mut obj, run_label).map_err(|e| {
4494 EstimationError::InvalidInput(format!(
4495 "{} analytic optimization failed after exhausting strategy fallbacks: {e}",
4496 kind.adjective(),
4497 ))
4498 })?;
4499 drop(obj);
4500 let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
4501 let kphase_slow_resets = ctx
4502 .evaluator
4503 .slow_path_reset_count()
4504 .saturating_sub(kphase_slow_resets_start);
4505 let kphase_design_revision_delta = ctx
4506 .cache
4507 .design_revision()
4508 .saturating_sub(kphase_design_revision_start);
4509 log::info!(
4510 "[KAPPA-PHASE-SUMMARY] n_rows={} log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} slow_path_resets={} design_revision_delta={} nfree_miss_shape={} nfree_miss_value={} nfree_miss_gradient={} nfree_miss_penalty={} nfree_miss_revision={} nfree_miss_second_order={} nfree_miss_other={} optim_total_s={:.4}",
4511 data.nrows(),
4512 kphase_log_kappa_dim,
4513 kphase_cost_calls.get(),
4514 kphase_cost_total_s.get(),
4515 kphase_eval_calls.get(),
4516 kphase_eval_total_s.get(),
4517 kphase_efs_calls.get(),
4518 kphase_efs_total_s.get(),
4519 kphase_slow_resets,
4520 kphase_design_revision_delta,
4521 kphase_nfree_miss_shape.get(),
4522 kphase_nfree_miss_value.get(),
4523 kphase_nfree_miss_gradient.get(),
4524 kphase_nfree_miss_penalty.get(),
4525 kphase_nfree_miss_revision.get(),
4526 kphase_nfree_miss_second_order.get(),
4527 kphase_nfree_miss_other.get(),
4528 kphase_total_s,
4529 );
4530 let timing = SpatialLengthScaleOptimizationTiming {
4531 log_kappa_dim: kphase_log_kappa_dim,
4532 cost_calls: kphase_cost_calls.get(),
4533 cost_total_s: kphase_cost_total_s.get(),
4534 eval_calls: kphase_eval_calls.get(),
4535 eval_total_s: kphase_eval_total_s.get(),
4536 efs_calls: kphase_efs_calls.get(),
4537 efs_total_s: kphase_efs_total_s.get(),
4538 slow_path_resets: kphase_slow_resets,
4539 design_revision_delta: kphase_design_revision_delta,
4540 nfree_miss_shape: kphase_nfree_miss_shape.get(),
4541 nfree_miss_value: kphase_nfree_miss_value.get(),
4542 nfree_miss_gradient: kphase_nfree_miss_gradient.get(),
4543 nfree_miss_penalty: kphase_nfree_miss_penalty.get(),
4544 nfree_miss_revision: kphase_nfree_miss_revision.get(),
4545 nfree_miss_second_order: kphase_nfree_miss_second_order.get(),
4546 nfree_miss_other: kphase_nfree_miss_other.get(),
4547 optim_total_s: kphase_total_s,
4548 };
4549 if !result.converged {
4550 let rel_to_cost_threshold = options.tol * (1.0_f64 + result.final_value.abs());
4561 if let Some(final_grad) = result
4562 .final_grad_norm
4563 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
4564 {
4565 log::info!(
4566 "[{}] outer optimization hit max_iter={} but \
4567 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
4568 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
4569 relative-to-cost REML convergence criterion.",
4570 label,
4571 result.iterations,
4572 final_grad,
4573 rel_to_cost_threshold,
4574 options.tol,
4575 result.final_value.abs(),
4576 );
4577 } else if result.final_value.is_finite() {
4578 log::warn!(
4593 "[{}] {} did not converge after {} iterations \
4594 (final_objective={:.6e}, final_grad_norm={}); keeping the \
4595 frozen baseline geometry instead of aborting the fit.",
4596 label,
4597 kind.adjective(),
4598 result.iterations,
4599 result.final_value,
4600 result.final_grad_norm_report(),
4601 );
4602 return Ok((
4603 SpatialJointOutcome::NonConverged {
4604 iterations: result.iterations,
4605 final_value: result.final_value,
4606 final_grad_norm: result.final_grad_norm,
4607 },
4608 timing,
4609 ));
4610 } else {
4611 crate::bail_invalid_estim!(
4616 "{} analytic optimization diverged after {} iterations (final_objective={:.6e}, final_grad_norm={})",
4617 kind.adjective(),
4618 result.iterations,
4619 result.final_value,
4620 result.final_grad_norm_report(),
4621 );
4622 }
4623 }
4624 log::trace!(
4625 "[{}] converged in {} iterations, final_value={:.6e}, grad_norm={}",
4626 label,
4627 result.iterations,
4628 result.final_value,
4629 result.final_grad_norm_report(),
4630 );
4631 let theta_star = result.rho;
4635 Ok((
4636 SpatialJointOutcome::Optimized {
4637 theta_star,
4638 final_value: result.final_value,
4639 },
4640 timing,
4641 ))
4642}
4643
4644fn set_single_term_spatial_length_scale(
4648 term: &mut SmoothTermSpec,
4649 length_scale: f64,
4650) -> Result<(), EstimationError> {
4651 match &mut term.basis {
4652 SmoothBasisSpec::ThinPlate { spec, .. } => {
4653 spec.length_scale = length_scale;
4654 Ok(())
4655 }
4656 SmoothBasisSpec::Matern { spec, .. } => {
4657 spec.length_scale = length_scale;
4658 Ok(())
4659 }
4660 SmoothBasisSpec::Duchon { spec, .. } => {
4661 spec.length_scale = Some(length_scale);
4662 Ok(())
4663 }
4664 _ => Err(EstimationError::InvalidInput(format!(
4665 "term '{}' does not expose a spatial length scale",
4666 term.name
4667 ))),
4668 }
4669}
4670
4671fn set_single_term_spatial_aniso_log_scales(
4675 term: &mut SmoothTermSpec,
4676 eta: Vec<f64>,
4677) -> Result<(), EstimationError> {
4678 let eta = center_aniso_log_scales(&eta);
4679 match &mut term.basis {
4680 SmoothBasisSpec::Matern { spec, .. } => {
4681 spec.aniso_log_scales = Some(eta);
4682 Ok(())
4683 }
4684 SmoothBasisSpec::Duchon { spec, .. } => {
4685 spec.aniso_log_scales = Some(eta);
4686 Ok(())
4687 }
4688 _ => Err(EstimationError::InvalidInput(format!(
4689 "term '{}' does not support aniso_log_scales",
4690 term.name
4691 ))),
4692 }
4693}
4694
4695pub fn get_constant_curvature_kappa(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
4714 constant_curvature_term_spec(spec, term_idx).map(|cc| cc.kappa)
4715}
4716
4717pub fn constant_curvature_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
4719 (0..spec.smooth_terms.len())
4720 .filter(|&idx| constant_curvature_term_spec(spec, idx).is_some())
4721 .collect()
4722}
4723
4724
/// One smooth term realized on its own: its local design block, the built
/// `SmoothTerm`, and the penalty blocks reported as dropped while building it.
#[derive(Debug, Clone)]
struct SingleSmoothTermRealization {
    /// Design matrix holding only this term's columns.
    design_local: DesignMatrix,
    /// The realized term (penalties, penalty info, metadata, bounds, constraints).
    term: SmoothTerm,
    /// Penalty blocks recorded as dropped by the constructor that built this value.
    dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
}
4731
4732impl SingleSmoothTermRealization {
4733 fn active_penaltyinfo(&self) -> Vec<PenaltyInfo> {
4734 self.term
4735 .penaltyinfo_local
4736 .iter()
4737 .filter(|info| info.active)
4738 .cloned()
4739 .collect()
4740 }
4741}
4742
4743fn build_single_smooth_term_realization(
4744 data: ArrayView2<'_, f64>,
4745 termspec: &SmoothTermSpec,
4746) -> Result<SingleSmoothTermRealization, BasisError> {
4747 let raw = build_smooth_design(data, std::slice::from_ref(termspec))?;
4748 finish_single_smooth_term_realization(raw)
4749}
4750
4751fn finish_single_smooth_term_realization(
4752 raw: RawSmoothDesign,
4753) -> Result<SingleSmoothTermRealization, BasisError> {
4754 let RawSmoothDesign {
4755 term_designs,
4756 dropped_penaltyinfo,
4757 terms,
4758 ..
4759 } = raw;
4760 let term = terms.into_iter().next().ok_or_else(|| {
4761 BasisError::InvalidInput("single-term smooth build returned no term".to_string())
4762 })?;
4763 let design = term_designs.into_iter().next().ok_or_else(|| {
4764 BasisError::InvalidInput("single-term smooth build returned no term design".to_string())
4765 })?;
4766
4767 Ok(SingleSmoothTermRealization {
4768 design_local: design,
4769 term,
4770 dropped_penaltyinfo,
4771 })
4772}
4773
4774fn wrap_local_build_as_realization(
4781 mut local: LocalSmoothTermBuild,
4782 termspec: &SmoothTermSpec,
4783) -> Result<SingleSmoothTermRealization, String> {
4784 let p_local = local.dim;
4785 let lb_local = if local.box_reparam {
4786 shape_lower_bounds_local(termspec.shape, p_local)
4787 } else {
4788 None
4789 };
4790
4791 let active_count = local.penaltyinfo.iter().filter(|info| info.active).count();
4792 if active_count != local.penalties.len() {
4793 return Err(format!(
4794 "internal penalty info mismatch for term '{}': active_infos={}, penalties={}",
4795 termspec.name,
4796 active_count,
4797 local.penalties.len()
4798 ));
4799 }
4800
4801 let mut dropped_penaltyinfo = Vec::<DroppedPenaltyBlockInfo>::new();
4802 for info in local.penaltyinfo.iter().filter(|info| !info.active) {
4803 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
4804 termname: Some(termspec.name.clone()),
4805 penalty: info.clone(),
4806 });
4807 }
4808 for info in &local.pre_dropped_penaltyinfo {
4809 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
4810 termname: Some(termspec.name.clone()),
4811 penalty: info.clone(),
4812 });
4813 }
4814
4815 let applied_rotation: Option<gam_terms::basis::JointNullRotation> = match (
4819 local.joint_null_rotation.take(),
4820 lb_local.is_some(),
4821 local.linear_constraints.is_some(),
4822 ) {
4823 (Some(rot), false, false) => {
4824 let q = &rot.rotation;
4825 let dense = local
4826 .design
4827 .try_to_dense_by_chunks("joint-null absorption rotation (single realization)")
4828 .map_err(|e| {
4829 format!(
4830 "joint-null absorption rotation: dense conversion failed for term '{}': {}",
4831 termspec.name, e
4832 )
4833 })?;
4834 let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
4835 local.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
4836 local.penalties = local
4837 .penalties
4838 .into_iter()
4839 .map(|s_local| {
4840 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
4841 gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
4842 })
4843 .collect();
4844 local.ops = vec![None; local.penalties.len()];
4845 local.kronecker_factored = None;
4846 Some(rot)
4847 }
4848 (Some(_), _, _) => None,
4849 (None, _, _) => None,
4850 };
4851
4852 let smooth_term = SmoothTerm {
4853 name: termspec.name.clone(),
4854 coeff_range: 0..p_local,
4855 shape: termspec.shape,
4856 penalties_local: local.penalties.clone(),
4857 nullspace_dims: local.nullspaces.clone(),
4858 penaltyinfo_local: local.penaltyinfo.clone(),
4859 metadata: local.metadata.clone(),
4860 lower_bounds_local: lb_local,
4861 linear_constraints_local: local.linear_constraints.clone(),
4862 kronecker_factored: local.kronecker_factored.take(),
4863 joint_null_rotation: applied_rotation,
4864 unabsorbed_global_orthogonality: None,
4867 };
4868
4869 Ok(SingleSmoothTermRealization {
4870 design_local: local.design,
4871 term: smooth_term,
4872 dropped_penaltyinfo,
4873 })
4874}
4875
4876fn freeze_geometry_from_metadata(
4887 termspec: &SmoothTermSpec,
4888 metadata: &BasisMetadata,
4889) -> Option<SmoothTermSpec> {
4890 let mut frozen = termspec.clone();
4891 match (&mut frozen.basis, metadata) {
4892 (
4893 SmoothBasisSpec::Matern {
4894 spec,
4895 input_scales: spec_scales,
4896 ..
4897 },
4898 BasisMetadata::Matern {
4899 centers,
4900 input_scales: meta_scales,
4901 identifiability_transform,
4902 nullspace_shrinkage_survived,
4903 ..
4904 },
4905 ) => {
4906 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
4907 if spec_scales.is_none()
4908 && let Some(s) = meta_scales.clone()
4909 {
4910 *spec_scales = Some(s);
4911 }
4912 if let Some(transform) = identifiability_transform.clone() {
4930 spec.identifiability = MaternIdentifiability::FrozenTransform {
4931 transform,
4932 nullspace_shrinkage_survived: Some(*nullspace_shrinkage_survived),
4933 };
4934 }
4935 Some(frozen)
4936 }
4937 (
4938 SmoothBasisSpec::Duchon {
4939 spec,
4940 input_scales: spec_scales,
4941 ..
4942 },
4943 BasisMetadata::Duchon {
4944 centers,
4945 input_scales: meta_scales,
4946 ..
4947 },
4948 ) => {
4949 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
4950 if spec_scales.is_none()
4951 && let Some(s) = meta_scales.clone()
4952 {
4953 *spec_scales = Some(s);
4954 }
4955 Some(frozen)
4956 }
4957 (
4958 SmoothBasisSpec::ThinPlate {
4959 spec,
4960 input_scales: spec_scales,
4961 ..
4962 },
4963 BasisMetadata::ThinPlate {
4964 centers,
4965 input_scales: meta_scales,
4966 ..
4967 },
4968 ) => {
4969 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
4970 if spec_scales.is_none()
4971 && let Some(s) = meta_scales.clone()
4972 {
4973 *spec_scales = Some(s);
4974 }
4975 Some(frozen)
4976 }
4977 _ => None,
4980 }
4981}
4982
4983fn rebuild_smooth_auxiliary_state(
4984 smooth: &mut SmoothDesign,
4985 dropped_penaltyinfo_by_term: &[Vec<DroppedPenaltyBlockInfo>],
4986) -> Result<(), String> {
4987 if dropped_penaltyinfo_by_term.len() != smooth.terms.len() {
4988 return Err(SmoothError::dimension_mismatch(format!(
4989 "smooth dropped-penalty cache mismatch: terms={}, dropped_sets={}",
4990 smooth.terms.len(),
4991 dropped_penaltyinfo_by_term.len()
4992 ))
4993 .into());
4994 }
4995
4996 let total_p = smooth.total_smooth_cols();
4997 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
4998 let mut any_bounds = false;
4999 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5000 let mut linear_constraint_b: Vec<f64> = Vec::new();
5001
5002 for term in &smooth.terms {
5003 let range = term.coeff_range.clone();
5004 if let Some(lb_local) = term.lower_bounds_local.as_ref() {
5005 if lb_local.len() != range.len() {
5006 return Err(SmoothError::dimension_mismatch(format!(
5007 "smooth lower-bound cache mismatch for term '{}': bounds={}, coeffs={}",
5008 term.name,
5009 lb_local.len(),
5010 range.len()
5011 ))
5012 .into());
5013 }
5014 coefficient_lower_bounds
5015 .slice_mut(s![range.clone()])
5016 .assign(lb_local);
5017 any_bounds = true;
5018 }
5019 if let Some(lin_local) = term.linear_constraints_local.as_ref() {
5020 if lin_local.a.ncols() != range.len() {
5021 return Err(SmoothError::dimension_mismatch(format!(
5022 "smooth linear-constraint cache mismatch for term '{}': cols={}, coeffs={}",
5023 term.name,
5024 lin_local.a.ncols(),
5025 range.len()
5026 ))
5027 .into());
5028 }
5029 for r in 0..lin_local.a.nrows() {
5030 let mut row = Array1::<f64>::zeros(total_p);
5031 row.slice_mut(s![range.clone()]).assign(&lin_local.a.row(r));
5032 linear_constraintrows.push(row);
5033 linear_constraint_b.push(lin_local.b[r]);
5034 }
5035 }
5036 }
5037
5038 smooth.coefficient_lower_bounds = if any_bounds {
5039 Some(coefficient_lower_bounds)
5040 } else {
5041 None
5042 };
5043 smooth.linear_constraints = if linear_constraintrows.is_empty() {
5044 None
5045 } else {
5046 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), total_p));
5047 for (i, row) in linear_constraintrows.iter().enumerate() {
5048 a.row_mut(i).assign(row);
5049 }
5050 Some(LinearInequalityConstraints {
5051 a,
5052 b: Array1::from_vec(linear_constraint_b),
5053 })
5054 };
5055 smooth.dropped_penaltyinfo = dropped_penaltyinfo_by_term
5056 .iter()
5057 .flat_map(|infos| infos.iter().cloned())
5058 .collect();
5059 Ok(())
5060}
5061
5062fn rebuild_term_collection_auxiliary_state(
5063 spec: &TermCollectionSpec,
5064 design: &mut TermCollectionDesign,
5065) -> Result<(), String> {
5066 if spec.linear_terms.len() != design.linear_ranges.len() {
5067 return Err(SmoothError::dimension_mismatch(format!(
5068 "term-collection linear bookkeeping mismatch: spec_terms={}, design_ranges={}",
5069 spec.linear_terms.len(),
5070 design.linear_ranges.len()
5071 ))
5072 .into());
5073 }
5074
5075 let p_total = design.design.ncols();
5076 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
5077 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
5078 let mut any_bounds = false;
5079 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
5080 let mut linear_constraint_b: Vec<f64> = Vec::new();
5081
5082 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
5083 if range.len() != 1 {
5084 return Err(SmoothError::dimension_mismatch(format!(
5085 "linear term '{}' expected one coefficient column, found {}",
5086 linear.name,
5087 range.len()
5088 ))
5089 .into());
5090 }
5091 let col = range.start;
5092 if let Some(lb) = linear.coefficient_min {
5093 let mut row = Array1::<f64>::zeros(p_total);
5094 row[col] = 1.0;
5095 linear_constraintrows.push(row);
5096 linear_constraint_b.push(lb);
5097 }
5098 if let Some(ub) = linear.coefficient_max {
5099 let mut row = Array1::<f64>::zeros(p_total);
5100 row[col] = -1.0;
5101 linear_constraintrows.push(row);
5102 linear_constraint_b.push(-ub);
5103 }
5104 }
5105
5106 if let Some(lb_smooth) = design.smooth.coefficient_lower_bounds.as_ref() {
5107 if lb_smooth.len() != design.smooth.total_smooth_cols() {
5108 return Err(SmoothError::dimension_mismatch(format!(
5109 "smooth lower-bound width mismatch: bounds={}, smooth_cols={}",
5110 lb_smooth.len(),
5111 design.smooth.total_smooth_cols()
5112 ))
5113 .into());
5114 }
5115 coefficient_lower_bounds
5116 .slice_mut(s![
5117 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5118 ])
5119 .assign(lb_smooth);
5120 any_bounds = true;
5121 }
5122 if let Some(lin_smooth) = design.smooth.linear_constraints.as_ref() {
5123 if lin_smooth.a.ncols() != design.smooth.total_smooth_cols() {
5124 return Err(SmoothError::dimension_mismatch(format!(
5125 "smooth linear-constraint width mismatch: cols={}, smooth_cols={}",
5126 lin_smooth.a.ncols(),
5127 design.smooth.total_smooth_cols()
5128 ))
5129 .into());
5130 }
5131 let mut a_global = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
5132 a_global
5133 .slice_mut(s![
5134 ..,
5135 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5136 ])
5137 .assign(&lin_smooth.a);
5138 for r in 0..a_global.nrows() {
5139 linear_constraintrows.push(a_global.row(r).to_owned());
5140 linear_constraint_b.push(lin_smooth.b[r]);
5141 }
5142 }
5143
5144 let lower_bound_constraints = if any_bounds {
5145 linear_constraints_from_lower_bounds_global(&coefficient_lower_bounds)
5146 } else {
5147 None
5148 };
5149 let explicit_linear_constraints = if linear_constraintrows.is_empty() {
5150 None
5151 } else {
5152 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), p_total));
5153 for (i, row) in linear_constraintrows.iter().enumerate() {
5154 a.row_mut(i).assign(row);
5155 }
5156 Some(LinearInequalityConstraints {
5157 a,
5158 b: Array1::from_vec(linear_constraint_b),
5159 })
5160 };
5161
5162 design.coefficient_lower_bounds = if any_bounds {
5163 Some(coefficient_lower_bounds)
5164 } else {
5165 None
5166 };
5167 design.linear_constraints =
5168 merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
5169 design.dropped_penaltyinfo = design.smooth.dropped_penaltyinfo.clone();
5170 Ok(())
5171}
5172
5173fn theta_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5174 left.len() == right.len()
5175 && left
5176 .iter()
5177 .zip(right.iter())
5178 .all(|(&l, &r)| l.to_bits() == r.to_bits())
5179}
5180
/// Bit-exact equality of two latent-coordinate vectors.
///
/// Deliberately delegates to `theta_values_match` so latent and θ vectors are
/// compared under exactly the same rule (equal length, identical bit patterns).
fn latent_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
    theta_values_match(left, right)
}
5184
/// Bit-exact comparison of two optional anisotropy log-scale vectors.
///
/// Both absent gives `true`, and exactly one absent gives `false`. When both are
/// present, they must have equal length and identical IEEE-754 bit patterns
/// element-wise (so `0.0` vs `-0.0` differ, and a NaN matches the same NaN bits).
fn spatial_aniso_matches(left: Option<&[f64]>, right: Option<&[f64]>) -> bool {
    fn bit_patterns(values: &[f64]) -> impl Iterator<Item = u64> + '_ {
        values.iter().map(|v| v.to_bits())
    }
    match (left, right) {
        // `Iterator::eq` also requires both sequences to have the same length.
        (Some(a), Some(b)) => bit_patterns(a).eq(bit_patterns(b)),
        (None, None) => true,
        (Some(_), None) | (None, Some(_)) => false,
    }
}
5197
/// Bit-exact comparison of two optional length-scales: both absent match, exactly
/// one absent does not, and present values must share their IEEE-754 bit pattern.
fn spatial_length_scale_matches(left: Option<f64>, right: Option<f64>) -> bool {
    left.map(f64::to_bits) == right.map(f64::to_bits)
}
5205
/// Incremental re-realizer for a term collection.
///
/// When spatial hyperparameters (ψ / log κ) change, the affected smooth terms are
/// rebuilt one at a time and spliced into the cached design. The non-smooth blocks
/// are built once at construction and reused.
struct FrozenTermCollectionIncrementalRealizer<'d> {
    /// Borrowed data matrix; every rebuilt smooth term is evaluated against it.
    data: ArrayView2<'d, f64>,
    /// Live term-collection spec; spatial hyperparameters are written back here by
    /// `apply_log_kappa_to_term`.
    spec: TermCollectionSpec,
    /// Cached realized design, patched in place term by term.
    design: TermCollectionDesign,
    /// Intercept / linear / random-effect blocks built at construction and placed
    /// ahead of the smooth blocks on every design refresh.
    fixed_blocks: Vec<DesignBlock>,
    /// Per smooth term, the dropped-penalty info from its latest realization.
    dropped_penaltyinfo_by_term: Vec<Vec<DroppedPenaltyBlockInfo>>,
    /// Per smooth term, its index range into `design.smooth.penalties`.
    smooth_penalty_ranges: Vec<Range<usize>>,
    /// Per smooth term, the same range shifted into `design.penalties`.
    full_penalty_ranges: Vec<Range<usize>>,
    /// Workspace passed by `&mut` to the basis/penalty builders on every rebuild.
    basisworkspace: gam_terms::basis::BasisWorkspace,
    /// Per smooth term, the geometry frozen at its first incremental rebuild (`None`
    /// until then); later rebuilds start from this cached spec.
    spatial_realization_geometry: Vec<Option<SmoothTermSpec>>,
    /// Counter bumped (wrapping) each time `apply_log_kappa` changes the design.
    design_revision: u64,
}
5238
5239impl<'d> std::fmt::Debug for FrozenTermCollectionIncrementalRealizer<'d> {
5240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5241 f.debug_struct("FrozenTermCollectionIncrementalRealizer")
5242 .field("data_shape", &(self.data.nrows(), self.data.ncols()))
5243 .field("fixed_blocks", &self.fixed_blocks.len())
5244 .finish_non_exhaustive()
5245 }
5246}
5247
5248impl<'d> FrozenTermCollectionIncrementalRealizer<'d> {
5249 fn new(
5250 data: ArrayView2<'d, f64>,
5251 spec: TermCollectionSpec,
5252 design: TermCollectionDesign,
5253 ) -> Result<Self, String> {
5254 if spec.smooth_terms.len() != design.smooth.terms.len() {
5255 return Err(SmoothError::dimension_mismatch(format!(
5256 "incremental realizer smooth term mismatch: spec_terms={}, design_terms={}",
5257 spec.smooth_terms.len(),
5258 design.smooth.terms.len()
5259 ))
5260 .into());
5261 }
5262
5263 let mut smooth_cursor = 0usize;
5264 let mut smooth_penalty_ranges = Vec::with_capacity(design.smooth.terms.len());
5265 for term in &design.smooth.terms {
5266 let next = smooth_cursor + term.penalties_local.len();
5267 smooth_penalty_ranges.push(smooth_cursor..next);
5268 smooth_cursor = next;
5269 }
5270 if smooth_cursor != design.smooth.penalties.len() {
5271 return Err(SmoothError::dimension_mismatch(format!(
5272 "incremental realizer smooth penalty mismatch: ranged={}, actual={}",
5273 smooth_cursor,
5274 design.smooth.penalties.len()
5275 ))
5276 .into());
5277 }
5278
5279 let fixed_penalty_offset = design
5280 .penalties
5281 .len()
5282 .checked_sub(design.smooth.penalties.len())
5283 .ok_or_else(|| {
5284 "incremental realizer encountered invalid penalty bookkeeping".to_string()
5285 })?;
5286 let full_penalty_ranges = smooth_penalty_ranges
5287 .iter()
5288 .map(|range| (fixed_penalty_offset + range.start)..(fixed_penalty_offset + range.end))
5289 .collect::<Vec<_>>();
5290 let fixed_blocks = build_term_collection_fixed_blocks(data, &spec)
5291 .map_err(|e| format!("failed to cache fixed term-collection blocks: {e}"))?;
5292
5293 let mut dropped_penaltyinfo_by_term = Vec::with_capacity(spec.smooth_terms.len());
5294 for (term_idx, termspec) in spec.smooth_terms.iter().enumerate() {
5295 let realization =
5296 build_single_smooth_term_realization(data, termspec).map_err(|e| {
5297 format!(
5298 "failed to build cached realization for smooth term '{}' (index {}): {e}",
5299 termspec.name, term_idx
5300 )
5301 })?;
5302 let expected_cols = design.smooth.terms[term_idx].coeff_range.len();
5303 if realization.design_local.ncols() != expected_cols {
5304 return Err(SmoothError::dimension_mismatch(format!(
5305 "cached realization width mismatch for term '{}': cached_cols={}, design_cols={}",
5306 termspec.name,
5307 realization.design_local.ncols(),
5308 expected_cols
5309 ))
5310 .into());
5311 }
5312 if realization.active_penaltyinfo().len()
5313 != design.smooth.terms[term_idx].penalties_local.len()
5314 {
5315 return Err(SmoothError::dimension_mismatch(format!(
5316 "cached realization penalty mismatch for term '{}': cached_penalties={}, design_penalties={}",
5317 termspec.name,
5318 realization.active_penaltyinfo().len(),
5319 design.smooth.terms[term_idx].penalties_local.len()
5320 ))
5321 .into());
5322 }
5323 dropped_penaltyinfo_by_term.push(realization.dropped_penaltyinfo);
5324 }
5325
5326 let geometry_slots = spec.smooth_terms.len();
5327 Ok(Self {
5328 data,
5329 spec,
5330 design,
5331 fixed_blocks,
5332 dropped_penaltyinfo_by_term,
5333 smooth_penalty_ranges,
5334 full_penalty_ranges,
5335 basisworkspace: gam_terms::basis::BasisWorkspace::new(),
5336 spatial_realization_geometry: vec![None; geometry_slots],
5337 design_revision: 0,
5338 })
5339 }
5340
5341 fn design_revision(&self) -> u64 {
5342 self.design_revision
5343 }
5344
5345 fn spec(&self) -> &TermCollectionSpec {
5346 &self.spec
5347 }
5348
5349 fn design(&self) -> &TermCollectionDesign {
5350 &self.design
5351 }
5352
5353 fn supports_nfree_penalty_rekey(&self, spatial_terms: &[usize]) -> bool {
5368 if spatial_terms.len() != 1 {
5369 return false;
5370 }
5371 let term_idx = spatial_terms[0];
5372 matches!(
5373 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5374 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5375 )
5376 }
5377
5378 fn supports_nfree_gradient_only_routing(&self, spatial_terms: &[usize]) -> bool {
5387 if spatial_terms.len() != 1 {
5388 return false;
5389 }
5390 let term_idx = spatial_terms[0];
5391 matches!(
5392 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5393 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5394 )
5395 }
5396
    /// Recompute the penalty blocks of a single spatial term at hyperparameters `psi`
    /// and canonicalize them, without touching the data rows or the cached design.
    ///
    /// Penalties are rebuilt only from the term's stored metadata (centers,
    /// identifiability transform, ...), using the length-scale and anisotropy decoded
    /// from `psi`. Each new local block is paired with the corresponding entry of
    /// `self.design.penalties` to reuse its column range, prior mean, structure hint
    /// and operator. The resulting specs are then passed to
    /// `canonicalize_penalty_specs` over the full design width.
    ///
    /// # Errors
    /// - `spatial_terms` does not contain exactly one index, or the index is out of
    ///   range for the spec or the realized design;
    /// - a Matérn or thin-plate term has no length-scale for this `psi`;
    /// - the term's metadata kind is not Duchon, Matérn or thin-plate;
    /// - the number of rebuilt blocks differs from `self.design.penalties.len()`;
    /// - any underlying penalty builder or the canonicalization fails.
    fn canonical_penalties_at_psi(
        &mut self,
        spatial_terms: &[usize],
        psi: &[f64],
    ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
        if spatial_terms.len() != 1 {
            return Err(format!(
                "n-free penalty re-key requires exactly one spatial term, found {}",
                spatial_terms.len()
            ));
        }
        let term_idx = spatial_terms[0];
        // Decode ψ into an optional length-scale and optional anisotropy log-scales.
        let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
        let termspec =
            self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
                format!("spatial term {term_idx} out of range for n-free penalty")
            })?;
        let term = self
            .design
            .smooth
            .terms
            .get(term_idx)
            .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
        let p_total = self.design.design.ncols();
        let (locals, nullspace_dims): (Vec<Array2<f64>>, Vec<usize>) = match &term.metadata {
            BasisMetadata::Duchon {
                centers,
                identifiability_transform,
                operator_collocation_points,
                power,
                nullspace_order,
                aniso_log_scales,
                input_scales,
                radial_reparam,
                ..
            } => {
                // Operator penalties come from the live spec; fall back to the default
                // configuration if the spec is not a Duchon spec.
                let operator_penalties = match &termspec.basis {
                    SmoothBasisSpec::Duchon { spec, .. } => spec.operator_penalties.clone(),
                    _ => gam_terms::basis::DuchonOperatorPenaltySpec::default(),
                };
                // When the term was built on standardized inputs, convert the ψ
                // length-scale with the same compensation helper used at build time.
                let effective_ls = match input_scales.as_deref() {
                    Some(scales) => {
                        compensate_optional_length_scale_for_standardization(ls_opt, scales)
                    }
                    None => ls_opt,
                };
                gam_terms::basis::duchon_penalties_at_length_scale(
                    centers.view(),
                    identifiability_transform.as_ref(),
                    operator_collocation_points.as_ref().map(|p| p.view()),
                    &operator_penalties,
                    *power,
                    *nullspace_order,
                    aniso_log_scales.as_deref(),
                    radial_reparam.as_ref(),
                    effective_ls,
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?
            }
            BasisMetadata::Matern {
                centers,
                periodic,
                nu,
                include_intercept,
                identifiability_transform,
                aniso_log_scales,
                input_scales,
                ..
            } => {
                let ls = ls_opt.ok_or_else(|| {
                    "Matérn n-free penalty re-key requires a finite length-scale".to_string()
                })?;
                let effective_ls = match input_scales.as_deref() {
                    Some(scales) => compensate_length_scale_for_standardization(ls, scales),
                    None => ls,
                };
                // Anisotropy decoded from ψ takes precedence over the stored metadata.
                let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
                let (penalties, nullspace_dims, _info) =
                    matern_operator_penalty_triplet_at_length_scale(
                        centers.view(),
                        periodic.as_deref(),
                        identifiability_transform.as_ref(),
                        *nu,
                        *include_intercept,
                        aniso_for_penalty,
                        effective_ls,
                    )
                    .map_err(|e| e.to_string())?;
                (penalties, nullspace_dims)
            }
            BasisMetadata::ThinPlate {
                centers,
                identifiability_transform,
                radial_reparam,
                ..
            } => {
                let ls = ls_opt.ok_or_else(|| {
                    "thin-plate n-free penalty re-key requires a finite length-scale".to_string()
                })?;
                let double_penalty = match &termspec.basis {
                    SmoothBasisSpec::ThinPlate { spec, .. } => spec.double_penalty,
                    _ => false,
                };
                // NOTE(review): unlike the Duchon/Matérn arms, `ls` is used without
                // input-standardization compensation — confirm thin-plate metadata
                // never carries input scales.
                gam_terms::basis::thin_plate_penalties_at_length_scale(
                    centers.view(),
                    identifiability_transform.as_ref(),
                    radial_reparam.as_ref(),
                    ls,
                    double_penalty,
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?
            }
            other => {
                return Err(format!(
                    "n-free penalty re-key unsupported for basis metadata {:?}",
                    std::mem::discriminant(other)
                ));
            }
        };
        // The whole frozen penalty list is used as the template set, so this path
        // requires the design to carry exactly this term's penalty blocks.
        let templates = &self.design.penalties;
        if templates.len() != locals.len() {
            return Err(format!(
                "n-free penalty re-key produced {} blocks but the frozen design carries {} \
                 — penalty topology is not ψ-stable",
                locals.len(),
                templates.len()
            ));
        }
        let specs: Vec<gam_solve::estimate::PenaltySpec> = templates
            .iter()
            .zip(locals.into_iter())
            .map(|(tmpl, local)| gam_solve::estimate::PenaltySpec::Block {
                local,
                col_range: tmpl.col_range.clone(),
                prior_mean: tmpl.prior_mean.clone(),
                structure_hint: tmpl.structure_hint.clone(),
                op: tmpl.op.clone(),
            })
            .collect();
        gam_terms::construction::canonicalize_penalty_specs(
            &specs,
            &nullspace_dims,
            p_total,
            "nfree-psi-penalty",
        )
        .map_err(|e| e.to_string())
    }
5593
5594 fn canonical_penalty_derivatives_at_psi(
5595 &mut self,
5596 spatial_terms: &[usize],
5597 psi: &[f64],
5598 ) -> Result<(Range<usize>, usize, Vec<Array2<f64>>), String> {
5599 if spatial_terms.len() != 1 {
5600 return Err(format!(
5601 "n-free penalty derivative re-key requires exactly one spatial term, found {}",
5602 spatial_terms.len()
5603 ));
5604 }
5605 let term_idx = spatial_terms[0];
5606 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
5607 let termspec = self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
5608 format!("spatial term {term_idx} out of range for n-free penalty derivative")
5609 })?;
5610 let term = self
5611 .design
5612 .smooth
5613 .terms
5614 .get(term_idx)
5615 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
5616 let p_total = self.design.design.ncols();
5617 let smooth_start = p_total.saturating_sub(self.design.smooth.total_smooth_cols());
5618 let global_range =
5619 (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
5620
5621 let locals = match &term.metadata {
5622 BasisMetadata::Duchon {
5623 centers,
5624 identifiability_transform,
5625 operator_collocation_points,
5626 power,
5627 nullspace_order,
5628 aniso_log_scales,
5629 input_scales,
5630 radial_reparam,
5631 ..
5632 } => {
5633 let mut spec = match &termspec.basis {
5634 SmoothBasisSpec::Duchon { spec, .. } => spec.clone(),
5635 _ => {
5636 return Err(
5637 "Duchon n-free penalty derivative requires a Duchon term spec"
5638 .to_string(),
5639 );
5640 }
5641 };
5642 let effective_ls = match input_scales.as_deref() {
5643 Some(scales) => {
5644 compensate_optional_length_scale_for_standardization(ls_opt, scales)
5645 }
5646 None => ls_opt,
5647 };
5648 spec.length_scale = effective_ls;
5649 spec.power = *power;
5650 spec.nullspace_order = *nullspace_order;
5651 spec.aniso_log_scales = aniso_log_scales.clone();
5652 spec.radial_reparam = radial_reparam.clone();
5655 if spec.length_scale.is_none() {
5656 return Err(
5657 "Duchon n-free penalty derivative requires a hybrid length-scale"
5658 .to_string(),
5659 );
5660 }
5661 let collocation = operator_collocation_points
5662 .as_ref()
5663 .map(|points| points.view())
5664 .unwrap_or_else(|| centers.view());
5665 let (_native_sources, mut first, _native_second) =
5666 gam_terms::basis::build_duchon_native_penalty_psi_derivatives(
5667 centers.view(),
5668 &spec,
5669 identifiability_transform.as_ref(),
5670 &mut self.basisworkspace,
5671 )
5672 .map_err(|e| e.to_string())?;
5673 let (_operator_sources, operator_first, _operator_second) =
5674 gam_terms::basis::build_duchon_operator_penalty_psi_derivatives(
5675 collocation,
5676 centers.view(),
5677 &spec,
5678 identifiability_transform.as_ref(),
5679 &mut self.basisworkspace,
5680 )
5681 .map_err(|e| e.to_string())?;
5682 first.extend(operator_first);
5683 first
5684 }
5685 BasisMetadata::Matern {
5686 centers,
5687 periodic,
5688 nu,
5689 include_intercept,
5690 identifiability_transform,
5691 aniso_log_scales,
5692 input_scales,
5693 ..
5694 } => {
5695 let ls = ls_opt.ok_or_else(|| {
5696 "Matérn n-free penalty derivative requires a finite length-scale".to_string()
5697 })?;
5698 let effective_ls = match input_scales.as_deref() {
5699 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
5700 None => ls,
5701 };
5702 let penalty_centers =
5703 gam_terms::basis::expand_periodic_centers(¢ers.to_owned(), periodic.as_deref())
5704 .map_err(|e| e.to_string())?;
5705 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
5706 let (first, _second) = gam_terms::basis::build_matern_operator_penalty_psi_derivatives(
5707 penalty_centers.view(),
5708 effective_ls,
5709 *nu,
5710 *include_intercept,
5711 identifiability_transform.as_ref(),
5712 aniso_for_penalty,
5713 )
5714 .map_err(|e| e.to_string())?;
5715 first
5716 }
5717 BasisMetadata::ThinPlate {
5718 centers,
5719 identifiability_transform,
5720 radial_reparam,
5721 ..
5722 } => {
5723 let ls = ls_opt.ok_or_else(|| {
5724 "thin-plate n-free penalty derivative requires a finite length-scale"
5725 .to_string()
5726 })?;
5727 let mut spec = match &termspec.basis {
5728 SmoothBasisSpec::ThinPlate { spec, .. } => spec.clone(),
5729 _ => {
5730 return Err(
5731 "thin-plate n-free penalty derivative requires a ThinPlate term spec"
5732 .to_string(),
5733 );
5734 }
5735 };
5736 spec.length_scale = ls;
5737 if spec.radial_reparam.is_none() {
5738 spec.radial_reparam = radial_reparam.clone();
5739 }
5740 let (primary, _primary_second) =
5741 gam_terms::basis::build_thin_plate_penalty_psi_derivativeswithworkspace(
5742 centers.view(),
5743 &spec,
5744 identifiability_transform.as_ref(),
5745 &mut self.basisworkspace,
5746 )
5747 .map_err(|e| e.to_string())?;
5748 if self.design.penalties.len() > 1 {
5749 vec![primary.clone(), Array2::<f64>::zeros(primary.raw_dim())]
5750 } else {
5751 vec![primary]
5752 }
5753 }
5754 other => {
5755 return Err(format!(
5756 "n-free penalty derivative re-key unsupported for basis metadata {:?}",
5757 std::mem::discriminant(other)
5758 ));
5759 }
5760 };
5761 if locals.len() != self.design.penalties.len() {
5762 return Err(format!(
5763 "n-free penalty derivative re-key produced {} blocks but the frozen design carries {} \
5764 — penalty topology is not ψ-stable",
5765 locals.len(),
5766 self.design.penalties.len()
5767 ));
5768 }
5769 Ok((global_range, p_total, locals))
5770 }
5771
5772 fn apply_log_kappa(
5773 &mut self,
5774 log_kappa: &SpatialLogKappaCoords,
5775 term_indices: &[usize],
5776 ) -> Result<(), String> {
5777 if term_indices.len() != log_kappa.dims_per_term().len() {
5778 return Err(SmoothError::dimension_mismatch(format!(
5779 "incremental realizer log-kappa term mismatch: term_indices={}, dims_per_term={}",
5780 term_indices.len(),
5781 log_kappa.dims_per_term().len()
5782 ))
5783 .into());
5784 }
5785
5786 let mut any_changed = false;
5787 for (slot, &term_idx) in term_indices.iter().enumerate() {
5788 any_changed |= self.apply_log_kappa_to_term(term_idx, log_kappa.term_slice(slot))?;
5789 }
5790
5791 if any_changed {
5792 self.refresh_full_design_operator()?;
5793 rebuild_smooth_auxiliary_state(
5794 &mut self.design.smooth,
5795 &self.dropped_penaltyinfo_by_term,
5796 )?;
5797 rebuild_term_collection_auxiliary_state(&self.spec, &mut self.design)?;
5798 self.design_revision = self.design_revision.wrapping_add(1);
5799 }
5800 Ok(())
5801 }
5802
    /// Move one spatial term to the hyperparameter vector `psi` and rebuild its
    /// realization in place.
    ///
    /// Returns `Ok(false)` when nothing changed. For length-scale/anisotropy terms,
    /// "unchanged" means a bit-exact match with the live spec. For measure-jet and
    /// constant-curvature terms, it is whatever their setters report. Returns
    /// `Ok(true)` after the term has been rebuilt and spliced into `self.design` via
    /// `replace_term_realization`. The full design matrix is not reassembled here.
    ///
    /// NOTE(review): `self.spec` is updated before the rebuild. If the rebuild or the
    /// splice fails, the spec already reflects `psi` while the design does not. A
    /// retry at the same `psi` then returns `Ok(false)` without rebuilding — confirm
    /// callers discard the realizer after an error.
    fn apply_log_kappa_to_term(&mut self, term_idx: usize, psi: &[f64]) -> Result<bool, String> {
        if !spatial_term_supports_hyper_optimization(&self.spec, term_idx) {
            return Err(SmoothError::invalid_config(format!(
                "incremental realizer term {term_idx} does not expose spatial hyperparameters"
            ))
            .into());
        }
        let measure_jet_term = measure_jet_term_spec(&self.spec, term_idx).is_some();
        let constant_curvature_term = constant_curvature_term_spec(&self.spec, term_idx).is_some();
        // Only populated on the generic length-scale / anisotropy path.
        let mut next_length_scale = None;
        let mut next_aniso: Option<Vec<f64>> = None;
        // Step 1: write `psi` into the live spec, bailing out early if it is unchanged.
        if measure_jet_term {
            if !set_measure_jet_psi_dials(&mut self.spec, term_idx, psi)
                .map_err(|e| e.to_string())?
            {
                return Ok(false);
            }
        } else if constant_curvature_term {
            if !set_constant_curvature_kappa(&mut self.spec, term_idx, psi)
                .map_err(|e| e.to_string())?
            {
                return Ok(false);
            }
        } else {
            let current_length_scale = get_spatial_length_scale(&self.spec, term_idx);
            let current_aniso = get_spatial_aniso_log_scales(&self.spec, term_idx);
            let (ls, eta) = spatial_term_psi_to_length_scale_and_aniso(psi);
            next_length_scale = ls;
            next_aniso = eta;
            let same_length = spatial_length_scale_matches(current_length_scale, next_length_scale);
            let same_aniso = spatial_aniso_matches(current_aniso.as_deref(), next_aniso.as_deref());
            if same_length && same_aniso {
                return Ok(false);
            }
            // Only present components are written; a `None` leaves the spec value as is.
            if let Some(length_scale) = next_length_scale {
                set_spatial_length_scale(&mut self.spec, term_idx, length_scale)
                    .map_err(|e| e.to_string())?;
            }
            if let Some(eta) = next_aniso.clone() {
                set_spatial_aniso_log_scales(&mut self.spec, term_idx, eta)
                    .map_err(|e| e.to_string())?;
            }
        }

        // Step 2: choose the spec to build from — the frozen geometry if one was cached
        // on an earlier rebuild, otherwise the live term spec — and apply `psi` to it.
        let geometry_slot = self
            .spatial_realization_geometry
            .get(term_idx)
            .ok_or_else(|| format!("incremental realizer geometry slot {term_idx} out of range"))?;
        let mut build_spec = match geometry_slot {
            Some(cached) => cached.clone(),
            None => self
                .spec
                .smooth_terms
                .get(term_idx)
                .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
                .clone(),
        };
        if measure_jet_term {
            set_single_term_measure_jet_psi_dials(&mut build_spec, psi)
                .map_err(|e| e.to_string())?;
        } else if constant_curvature_term {
            set_single_term_constant_curvature_kappa(&mut build_spec, psi)
                .map_err(|e| e.to_string())?;
        } else {
            if let Some(length_scale) = next_length_scale {
                set_single_term_spatial_length_scale(&mut build_spec, length_scale)
                    .map_err(|e| e.to_string())?;
            }
            if let Some(eta) = next_aniso {
                set_single_term_spatial_aniso_log_scales(&mut build_spec, eta)
                    .map_err(|e| e.to_string())?;
            }
        }

        // Step 3: rebuild the term's local basis against the data.
        let termname = build_spec.name.clone();
        let local = build_single_local_smooth_term(
            self.data,
            &build_spec,
            &mut self.basisworkspace,
        )
        .map_err(|e| {
            format!(
                "failed to rebuild smooth term '{termname}' during incremental κ realization: {e}"
            )
        })?;

        // Step 4: on the first rebuild of this term, freeze its geometry for reuse.
        // For Matérn terms, also copy the frozen identifiability and center strategy
        // into the live spec.
        if self.spatial_realization_geometry[term_idx].is_none()
            && let Some(frozen) = freeze_geometry_from_metadata(&build_spec, &local.metadata)
        {
            if let (
                SmoothBasisSpec::Matern {
                    spec: frozen_spec, ..
                },
                Some(SmoothBasisSpec::Matern {
                    spec: live_spec, ..
                }),
            ) = (
                &frozen.basis,
                self.spec
                    .smooth_terms
                    .get_mut(term_idx)
                    .map(|t| &mut t.basis),
            ) {
                live_spec.identifiability = frozen_spec.identifiability.clone();
                live_spec.center_strategy = frozen_spec.center_strategy.clone();
            }
            self.spatial_realization_geometry[term_idx] = Some(frozen);
        }

        // Step 5: splice the rebuilt term into the cached design.
        let realization = wrap_local_build_as_realization(local, &build_spec)?;
        self.replace_term_realization(term_idx, realization)?;
        Ok(true)
    }
5954
5955 fn replace_term_realization(
5956 &mut self,
5957 term_idx: usize,
5958 realization: SingleSmoothTermRealization,
5959 ) -> Result<(), String> {
5960 let t_replace = std::time::Instant::now();
5961 let SingleSmoothTermRealization {
5962 design_local,
5963 term,
5964 dropped_penaltyinfo,
5965 } = realization;
5966 let SmoothTerm {
5967 name,
5968 penalties_local,
5969 nullspace_dims,
5970 penaltyinfo_local,
5971 metadata,
5972 lower_bounds_local,
5973 linear_constraints_local,
5974 joint_null_rotation,
5975 ..
5976 } = term;
5977 let coeff_range = self
5978 .design
5979 .smooth
5980 .terms
5981 .get(term_idx)
5982 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
5983 .coeff_range
5984 .clone();
5985 if design_local.ncols() != coeff_range.len() {
5986 return Err(SmoothError::dimension_mismatch(format!(
5987 "incremental realizer width mismatch for term {}: rebuilt_cols={}, cached_cols={}",
5988 term_idx,
5989 design_local.ncols(),
5990 coeff_range.len()
5991 ))
5992 .into());
5993 }
5994 if design_local.nrows() != self.design.design.nrows() {
5995 return Err(SmoothError::dimension_mismatch(format!(
5996 "incremental realizer row mismatch for term {}: rebuilt_rows={}, design_rows={}",
5997 term_idx,
5998 design_local.nrows(),
5999 self.design.design.nrows()
6000 ))
6001 .into());
6002 }
6003
6004 let active_penaltyinfo = penaltyinfo_local
6005 .iter()
6006 .filter(|info| info.active)
6007 .cloned()
6008 .collect::<Vec<_>>();
6009 let smooth_penalty_range = self
6010 .smooth_penalty_ranges
6011 .get(term_idx)
6012 .ok_or_else(|| {
6013 format!("incremental realizer missing smooth penalty range for term {term_idx}")
6014 })?
6015 .clone();
6016 let full_penalty_range = self
6017 .full_penalty_ranges
6018 .get(term_idx)
6019 .ok_or_else(|| {
6020 format!("incremental realizer missing full penalty range for term {term_idx}")
6021 })?
6022 .clone();
6023 if active_penaltyinfo.len() != smooth_penalty_range.len()
6024 || penalties_local.len() != smooth_penalty_range.len()
6025 || nullspace_dims.len() != smooth_penalty_range.len()
6026 {
6027 return Err(SmoothError::dimension_mismatch(format!(
6028 "incremental realizer topology changed for term '{}': penalties={}, infos={}, nullspaces={}, cached_penalties={}",
6029 name,
6030 penalties_local.len(),
6031 active_penaltyinfo.len(),
6032 nullspace_dims.len(),
6033 smooth_penalty_range.len()
6034 ))
6035 .into());
6036 }
6037
6038 self.design.smooth.term_designs[term_idx] = design_local;
6039
6040 for (offset, penalty_local) in penalties_local.iter().enumerate() {
6041 let smooth_penalty_idx = smooth_penalty_range.start + offset;
6042 let full_penalty_idx = full_penalty_range.start + offset;
6043 let nullspace_dim = nullspace_dims[offset];
6044 let penalty_info = active_penaltyinfo[offset].clone();
6045
6046 if penalty_local.nrows() != coeff_range.len()
6047 || penalty_local.ncols() != coeff_range.len()
6048 {
6049 return Err(SmoothError::dimension_mismatch(format!(
6050 "incremental realizer penalty shape mismatch for term '{}' penalty {}: \
6051 penalty is {}x{} but coeff_range has {} columns",
6052 name,
6053 offset,
6054 penalty_local.nrows(),
6055 penalty_local.ncols(),
6056 coeff_range.len()
6057 ))
6058 .into());
6059 }
6060
6061 let smooth_penalty = self
6062 .design
6063 .smooth
6064 .penalties
6065 .get_mut(smooth_penalty_idx)
6066 .ok_or_else(|| {
6067 format!(
6068 "incremental realizer smooth penalty {} out of range for term {}",
6069 smooth_penalty_idx, term_idx
6070 )
6071 })?;
6072 smooth_penalty.local.assign(penalty_local);
6075
6076 let full_bp = self
6077 .design
6078 .penalties
6079 .get_mut(full_penalty_idx)
6080 .ok_or_else(|| {
6081 format!(
6082 "incremental realizer full penalty {} out of range for term {}",
6083 full_penalty_idx, term_idx
6084 )
6085 })?;
6086 full_bp.local.assign(penalty_local);
6089
6090 self.design.smooth.nullspace_dims[smooth_penalty_idx] = nullspace_dim;
6091 self.design.nullspace_dims[full_penalty_idx] = nullspace_dim;
6092
6093 self.design.smooth.penaltyinfo[smooth_penalty_idx].global_index = smooth_penalty_idx;
6094 self.design.smooth.penaltyinfo[smooth_penalty_idx].termname = Some(name.clone());
6095 self.design.smooth.penaltyinfo[smooth_penalty_idx].penalty = penalty_info.clone();
6096
6097 self.design.penaltyinfo[full_penalty_idx].global_index = full_penalty_idx;
6098 self.design.penaltyinfo[full_penalty_idx].termname = Some(name.clone());
6099 self.design.penaltyinfo[full_penalty_idx].penalty = penalty_info;
6100 }
6101
6102 let target_term = self.design.smooth.terms.get_mut(term_idx).ok_or_else(|| {
6103 format!("incremental realizer smooth term {term_idx} disappeared during replacement")
6104 })?;
6105 target_term.penalties_local = penalties_local;
6106 target_term.nullspace_dims = nullspace_dims;
6107 target_term.penaltyinfo_local = penaltyinfo_local;
6108 target_term.metadata = metadata;
6109 target_term.lower_bounds_local = lower_bounds_local;
6110 target_term.linear_constraints_local = linear_constraints_local;
6111 target_term.joint_null_rotation = joint_null_rotation;
6112 self.dropped_penaltyinfo_by_term[term_idx] = dropped_penaltyinfo;
6113 log::info!(
6114 "[STAGE] smooth basis rebuild (term {}, '{}', cols={}): {:.3}s",
6115 term_idx,
6116 target_term.name,
6117 coeff_range.len(),
6118 t_replace.elapsed().as_secs_f64(),
6119 );
6120 Ok(())
6121 }
6122
6123 fn refresh_full_design_operator(&mut self) -> Result<(), String> {
6124 let mut blocks = Vec::<DesignBlock>::with_capacity(
6125 self.fixed_blocks.len() + self.design.smooth.term_designs.len(),
6126 );
6127 blocks.extend(self.fixed_blocks.iter().cloned());
6128 for term_design in &self.design.smooth.term_designs {
6129 blocks.push(DesignBlock::from(term_design));
6130 }
6131 self.design.design = assemble_term_collection_design_matrix(blocks)
6132 .map_err(|e| format!("failed to refresh term-collection design: {e}"))?;
6133 Ok(())
6134 }
6135}
6136
6137fn build_term_collection_fixed_blocks(
6138 data: ArrayView2<'_, f64>,
6139 spec: &TermCollectionSpec,
6140) -> Result<Vec<DesignBlock>, BasisError> {
6141 let mut blocks = Vec::<DesignBlock>::new();
6142 if !term_collection_has_one_sided_anchored_bspline(spec) {
6143 blocks.push(DesignBlock::Intercept(data.nrows()));
6144 }
6145
6146 if !spec.linear_terms.is_empty() {
6147 let mut linear_block = Array2::<f64>::zeros((data.nrows(), spec.linear_terms.len()));
6148 for (j, linear) in spec.linear_terms.iter().enumerate() {
6149 let column = linear
6153 .realized_design_column(data)
6154 .map_err(BasisError::InvalidInput)?;
6155 linear_block.column_mut(j).assign(&column);
6156 }
6157 blocks.push(DesignBlock::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
6158 linear_block,
6159 )));
6160 }
6161
6162 for term in &spec.random_effect_terms {
6163 let block = build_random_effect_block(data, term)?;
6164 let re_op = RandomEffectOperator::new(block.group_ids, block.num_groups);
6165 blocks.push(DesignBlock::RandomEffect(Arc::new(re_op)));
6166 }
6167
6168 Ok(blocks)
6169}
6170
/// Output of a spatial length-scale optimization run, generic over the fit type.
///
/// NOTE(review): field semantics below are inferred from names; confirm against the
/// producer of this struct.
pub struct SpatialLengthScaleOptimizationResult<FitOut> {
    /// Resolved term-collection specs (presumably one per model block — TODO confirm).
    pub resolved_specs: Vec<TermCollectionSpec>,
    /// Realized designs, presumably parallel to `resolved_specs`.
    pub designs: Vec<TermCollectionDesign>,
    /// The fitted model output, of caller-chosen type `FitOut`.
    pub fit: FitOut,
    /// Optional timing breakdown of the optimization, when collected.
    pub timing: Option<SpatialLengthScaleOptimizationTiming>,
}
6181
/// Seed and box bounds for the joint hyperparameter vector θ.
///
/// θ is laid out as `[ρ | log κ | auxiliary]` by `theta0`, `lower` and `upper`.
#[derive(Debug, Clone)]
pub struct ExactJointHyperSetup {
    /// ρ seed; `new` clamps it into `[rho_lower, rho_upper]` (non-finite entries
    /// become `0.0` clamped into the box).
    rho0: Array1<f64>,
    /// Element-wise lower bounds for ρ.
    rho_lower: Array1<f64>,
    /// Element-wise upper bounds for ρ.
    rho_upper: Array1<f64>,
    /// Spatial log-κ seed coordinates.
    log_kappa0: SpatialLogKappaCoords,
    /// Lower bounds for the log-κ coordinates.
    log_kappa_lower: SpatialLogKappaCoords,
    /// Upper bounds for the log-κ coordinates.
    log_kappa_upper: SpatialLogKappaCoords,
    /// Auxiliary-parameter seed (empty unless set by `with_auxiliary`, which clamps it
    /// like `rho0`).
    auxiliary0: Array1<f64>,
    /// Lower bounds for the auxiliary parameters.
    auxiliary_lower: Array1<f64>,
    /// Upper bounds for the auxiliary parameters.
    auxiliary_upper: Array1<f64>,
}
6195
6196impl ExactJointHyperSetup {
6197 fn sanitize_rho_seed(
6198 rho0: Array1<f64>,
6199 rho_lower: &Array1<f64>,
6200 rho_upper: &Array1<f64>,
6201 ) -> Array1<f64> {
6202 Array1::from_iter(rho0.iter().enumerate().map(|(idx, &value)| {
6203 let lo = rho_lower[idx];
6204 let hi = rho_upper[idx];
6205 let fallback = 0.0_f64.clamp(lo, hi);
6206 if value.is_finite() {
6207 value.clamp(lo, hi)
6208 } else {
6209 fallback
6210 }
6211 }))
6212 }
6213
6214 pub(crate) fn new(
6215 rho0: Array1<f64>,
6216 rho_lower: Array1<f64>,
6217 rho_upper: Array1<f64>,
6218 log_kappa0: SpatialLogKappaCoords,
6219 log_kappa_lower: SpatialLogKappaCoords,
6220 log_kappa_upper: SpatialLogKappaCoords,
6221 ) -> Self {
6222 let rho0 = Self::sanitize_rho_seed(rho0, &rho_lower, &rho_upper);
6223 Self {
6224 rho0,
6225 rho_lower,
6226 rho_upper,
6227 log_kappa0,
6228 log_kappa_lower,
6229 log_kappa_upper,
6230 auxiliary0: Array1::zeros(0),
6231 auxiliary_lower: Array1::zeros(0),
6232 auxiliary_upper: Array1::zeros(0),
6233 }
6234 }
6235
6236 pub(crate) fn with_auxiliary(
6237 mut self,
6238 auxiliary0: Array1<f64>,
6239 auxiliary_lower: Array1<f64>,
6240 auxiliary_upper: Array1<f64>,
6241 ) -> Self {
6242 assert_eq!(
6243 auxiliary0.len(),
6244 auxiliary_lower.len(),
6245 "auxiliary lower bound length mismatch"
6246 );
6247 assert_eq!(
6248 auxiliary0.len(),
6249 auxiliary_upper.len(),
6250 "auxiliary upper bound length mismatch"
6251 );
6252 self.auxiliary0 = Self::sanitize_rho_seed(auxiliary0, &auxiliary_lower, &auxiliary_upper);
6253 self.auxiliary_lower = auxiliary_lower;
6254 self.auxiliary_upper = auxiliary_upper;
6255 self
6256 }
6257
6258 pub(crate) fn rho_dim(&self) -> usize {
6259 self.rho0.len()
6260 }
6261
6262 pub(crate) fn log_kappa_dim(&self) -> usize {
6263 self.log_kappa0.len()
6264 }
6265
6266 pub(crate) fn auxiliary_dim(&self) -> usize {
6267 self.auxiliary0.len()
6268 }
6269
6270 pub(crate) fn theta0(&self) -> Array1<f64> {
6271 let mut out =
6272 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6273 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho0);
6274 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6275 .assign(self.log_kappa0.as_array());
6276 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6277 .assign(&self.auxiliary0);
6278 out
6279 }
6280
6281 pub(crate) fn lower(&self) -> Array1<f64> {
6282 let mut out =
6283 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6284 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_lower);
6285 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6286 .assign(self.log_kappa_lower.as_array());
6287 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6288 .assign(&self.auxiliary_lower);
6289 out
6290 }
6291
6292 pub(crate) fn upper(&self) -> Array1<f64> {
6293 let mut out =
6294 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6295 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_upper);
6296 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6297 .assign(self.log_kappa_upper.as_array());
6298 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6299 .assign(&self.auxiliary_upper);
6300 out
6301 }
6302
6303 pub(crate) fn log_kappa_dims_per_term(&self) -> Vec<usize> {
6305 self.log_kappa0.dims_per_term().to_vec()
6306 }
6307}
6308
6309struct ExactJointDesignCache<'d> {
6315 realizers: Vec<FrozenTermCollectionIncrementalRealizer<'d>>,
6316 block_term_indices: Vec<Vec<usize>>,
6317 current_theta: Option<Array1<f64>>,
6318 last_cost: Option<f64>,
6319 last_eval: Option<(
6320 f64,
6321 Array1<f64>,
6322 gam_problem::HessianResult,
6323 )>,
6324 rho_dim: usize,
6325 all_dims: Vec<usize>,
6326 log_kappa_dim: usize,
6327 block_term_counts: Vec<usize>,
6328}
6329
6330impl<'d> ExactJointDesignCache<'d> {
6331 fn new(
6332 data: ArrayView2<'d, f64>,
6333 blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)>,
6334 rho_dim: usize,
6335 all_dims: Vec<usize>,
6336 ) -> Result<Self, String> {
6337 let n_blocks = blocks.len();
6338 let mut realizers = Vec::with_capacity(n_blocks);
6339 let mut block_term_indices = Vec::with_capacity(n_blocks);
6340 let mut block_term_counts = Vec::with_capacity(n_blocks);
6341
6342 for (spec, design, terms) in blocks {
6343 block_term_counts.push(terms.len());
6344 block_term_indices.push(terms);
6345 realizers.push(FrozenTermCollectionIncrementalRealizer::new(
6346 data, spec, design,
6347 )?);
6348 }
6349
6350 Ok(Self {
6351 realizers,
6352 block_term_indices,
6353 current_theta: None,
6354 last_cost: None,
6355 last_eval: None,
6356 rho_dim,
6357 log_kappa_dim: all_dims.iter().sum(),
6358 all_dims,
6359 block_term_counts,
6360 })
6361 }
6362
6363 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
6364 if self
6365 .current_theta
6366 .as_ref()
6367 .is_some_and(|cached| theta_values_match(cached, theta))
6368 {
6369 return Ok(());
6370 }
6371
6372 let t_ensure = std::time::Instant::now();
6373 let kappa_theta_len = self.rho_dim + self.log_kappa_dim;
6374 if theta.len() < kappa_theta_len {
6375 return Err(SmoothError::dimension_mismatch(format!(
6376 "exact-joint theta length mismatch: got {}, expected at least {} (rho_dim={}, log_kappa_dim={})",
6377 theta.len(),
6378 kappa_theta_len,
6379 self.rho_dim,
6380 self.log_kappa_dim
6381 ))
6382 .into());
6383 }
6384 let theta_kappa = theta.slice(s![..kappa_theta_len]).to_owned();
6385 let full_log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
6386 &theta_kappa,
6387 self.rho_dim,
6388 self.all_dims.clone(),
6389 );
6390
6391 let n = self.realizers.len();
6395 let mut remaining = full_log_kappa;
6396 for block_idx in 0..n {
6397 let count = self.block_term_counts[block_idx];
6398 if block_idx < n - 1 {
6399 let (block_lk, rest) = remaining.split_at(count);
6400 self.realizers[block_idx]
6401 .apply_log_kappa(&block_lk, &self.block_term_indices[block_idx])?;
6402 remaining = rest;
6403 } else {
6404 self.realizers[block_idx]
6406 .apply_log_kappa(&remaining, &self.block_term_indices[block_idx])?;
6407 }
6408 }
6409
6410 log::info!(
6411 "[STAGE] ensure_theta (n-block, {} blocks, {} realizers): {:.3}s",
6412 n,
6413 self.realizers.len(),
6414 t_ensure.elapsed().as_secs_f64(),
6415 );
6416 self.current_theta = Some(theta.clone());
6417 self.last_cost = None;
6418 self.last_eval = None;
6419 Ok(())
6420 }
6421
6422 impl_exact_joint_theta_memo!();
6423
6424 fn store_cost_only(&mut self, theta: &Array1<f64>, cost: f64) {
6430 if self
6431 .current_theta
6432 .as_ref()
6433 .is_some_and(|cached| theta_values_match(cached, theta))
6434 {
6435 self.last_cost = Some(cost);
6436 }
6437 }
6438
6439 fn specs(&self) -> Vec<&TermCollectionSpec> {
6440 self.realizers.iter().map(|r| r.spec()).collect()
6441 }
6442
6443 fn designs(&self) -> Vec<&TermCollectionDesign> {
6444 self.realizers.iter().map(|r| r.design()).collect()
6445 }
6446
6447 fn design_revision(&self) -> u64 {
6457 self.realizers
6458 .iter()
6459 .fold(0u64, |acc, r| acc.wrapping_add(r.design_revision()))
6460 }
6461}
6462
6463pub(crate) fn seed_risk_profile_for_likelihood_family(
6464 family: &LikelihoodSpec,
6465) -> gam_problem::SeedRiskProfile {
6466 match &family.response {
6467 ResponseFamily::Gaussian => gam_problem::SeedRiskProfile::Gaussian,
6468 ResponseFamily::RoystonParmar => gam_problem::SeedRiskProfile::Survival,
6469 ResponseFamily::Binomial
6470 | ResponseFamily::Poisson
6471 | ResponseFamily::Tweedie { .. }
6472 | ResponseFamily::NegativeBinomial { .. }
6473 | ResponseFamily::Beta { .. }
6474 | ResponseFamily::Gamma => gam_problem::SeedRiskProfile::GeneralizedLinear,
6475 }
6476}
6477
/// NOTE(review): not referenced within this part of the file. The name suggests an
/// upper bound on the θ dimension for which second-order (Hessian-based) exact-joint
/// outer evaluation is used. TODO confirm against its use sites.
const EXACT_JOINT_SECOND_ORDER_THETA_CAP: usize = 8;
6486
6487fn exact_joint_seed_config(
6488 risk_profile: gam_problem::SeedRiskProfile,
6489 auxiliary_dim: usize,
6490) -> gam_problem::SeedConfig {
6491 let mut config = gam_problem::SeedConfig {
6492 risk_profile,
6493 num_auxiliary_trailing: auxiliary_dim,
6494 ..Default::default()
6495 };
6496 match risk_profile {
6497 gam_problem::SeedRiskProfile::Gaussian
6498 | gam_problem::SeedRiskProfile::GaussianLocationScale => {
6499 config.max_seeds = 4;
6500 config.seed_budget = 2;
6501 }
6502 gam_problem::SeedRiskProfile::GeneralizedLinear => {
6503 config.max_seeds = 1;
6508 config.seed_budget = 1;
6509 config.screen_max_inner_iterations = 8;
6510 }
6511 gam_problem::SeedRiskProfile::Survival => {
6512 config.max_seeds = 8;
6518 config.seed_budget = 4;
6519 config.screen_max_inner_iterations = 8;
6520 }
6521 }
6522 config
6523}
6524
#[cfg(test)]
mod exact_joint_seed_config_tests {
    use super::*;

    #[test]
    fn exact_joint_marginal_slope_profiles_get_deeper_startup_validation() {
        // Tuple layout: (max_seeds, seed_budget, screen_max_inner_iterations, num_auxiliary_trailing)
        let glm = exact_joint_seed_config(gam_problem::SeedRiskProfile::GeneralizedLinear, 2);
        assert_eq!(
            (
                glm.max_seeds,
                glm.seed_budget,
                glm.screen_max_inner_iterations,
                glm.num_auxiliary_trailing,
            ),
            (1, 1, 8, 2)
        );

        let surv = exact_joint_seed_config(gam_problem::SeedRiskProfile::Survival, 3);
        assert_eq!(
            (
                surv.max_seeds,
                surv.seed_budget,
                surv.screen_max_inner_iterations,
                surv.num_auxiliary_trailing,
            ),
            (8, 4, 8, 3)
        );
    }

    #[test]
    fn exact_joint_gaussian_keeps_tight_historical_multistart_budget() {
        let default_screen = gam_problem::SeedConfig::default().screen_max_inner_iterations;
        let gauss = exact_joint_seed_config(gam_problem::SeedRiskProfile::Gaussian, 1);
        assert_eq!(
            (
                gauss.max_seeds,
                gauss.seed_budget,
                gauss.screen_max_inner_iterations,
                gauss.num_auxiliary_trailing,
            ),
            (4, 2, default_screen, 1)
        );
    }
}
6556
6557pub(crate) fn exact_joint_multistart_outer_problem(
6558 theta0: &Array1<f64>,
6559 lower: &Array1<f64>,
6560 upper: &Array1<f64>,
6561 rho_dim: usize,
6562 auxiliary_dim: usize,
6563 n_params: usize,
6564 gradient: gam_problem::Derivative,
6565 hessian: gam_problem::DeclaredHessianForm,
6566 prefer_gradient_only: bool,
6567 disable_fixed_point: bool,
6568 risk_profile: gam_problem::SeedRiskProfile,
6569 tolerance: f64,
6570 max_iter: usize,
6571 bfgs_step_cap: Option<f64>,
6580 bfgs_step_cap_psi: Option<f64>,
6581 screening_cap: Option<Arc<AtomicUsize>>,
6582 profiled_objective_size: Option<(usize, usize)>,
6603 has_constant_curvature: bool,
6612) -> gam_solve::rho_optimizer::OuterProblem {
6613 let mut seed_heuristic = theta0.to_vec();
6614 for value in &mut seed_heuristic[..rho_dim] {
6615 *value = value.exp();
6616 }
6617 let rho_ceiling = if has_constant_curvature {
6622 gam_solve::estimate::RHO_BOUND
6623 } else {
6624 12.0
6625 };
6626 let mut problem = gam_solve::rho_optimizer::OuterProblem::new(n_params)
6627 .with_gradient(gradient)
6628 .with_hessian(hessian)
6629 .with_prefer_gradient_only(prefer_gradient_only)
6630 .with_disable_fixed_point(disable_fixed_point)
6631 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Automatic)
6641 .with_psi_dim(auxiliary_dim)
6642 .with_tolerance(tolerance)
6643 .with_max_iter(max_iter)
6644 .with_bounds(lower.clone(), upper.clone())
6645 .with_initial_rho(theta0.clone())
6646 .with_bfgs_step_cap(bfgs_step_cap)
6647 .with_bfgs_step_cap_psi(bfgs_step_cap_psi)
6648 .with_seed_config({
6649 let mut sc = exact_joint_seed_config(risk_profile, auxiliary_dim);
6650 if has_constant_curvature {
6651 sc.bounds = (sc.bounds.0, rho_ceiling);
6655 }
6674 sc
6675 })
6676 .with_rho_bound(rho_ceiling)
6677 .with_heuristic_lambdas(seed_heuristic);
6678 if let Some((n_obs, p_cols)) = profiled_objective_size {
6679 problem = problem
6687 .with_objective_scale(Some(n_obs as f64))
6688 .with_problem_size(n_obs, p_cols)
6689 .with_arc_initial_regularization(Some(0.25))
6690 .with_operator_initial_trust_radius(Some(4.0));
6691 }
6692 if let Some(screening_cap) = screening_cap {
6693 problem = problem
6694 .with_screening_cap(screening_cap)
6695 .with_screen_initial_rho(true);
6696 }
6697 problem
6698}
6699
/// Returns `true` when an outer-optimizer error message says the κ phase failed in a
/// way that still allows a fit at the fixed bootstrap κ.
///
/// Recognized failures: no seed passed startup validation, a rho dimension mismatch
/// in the joint hyper setup, or a non-finite objective cost.
fn kappa_phase_failure_is_fixed_kappa_recoverable(message: &str) -> bool {
    // Substring markers of recoverable optimizer errors.
    const RECOVERABLE_MARKERS: [&str; 3] = [
        "no candidate seeds passed outer startup validation",
        "joint hyper rho dimension mismatch",
        "objective returned a non-finite cost",
    ];
    RECOVERABLE_MARKERS
        .iter()
        .any(|&marker| message.contains(marker))
}
6715
6716pub fn optimize_spatial_length_scale_exact_joint<FitOut, FitFn, ExactFn, ExactEfsFn, SeedFn>(
6717 data: ArrayView2<'_, f64>,
6718 block_specs: &[TermCollectionSpec],
6719 block_term_indices: &[Vec<usize>],
6720 kappa_options: &SpatialLengthScaleOptimizationOptions,
6721 joint_setup: &ExactJointHyperSetup,
6722 seed_risk_profile: gam_problem::SeedRiskProfile,
6723 analytic_joint_gradient_available: bool,
6724 analytic_joint_hessian_available: bool,
6725 disable_fixed_point: bool,
6726 screening_cap: Option<Arc<AtomicUsize>>,
6727 outer_derivative_policy: gam_model_api::families::custom_family::OuterDerivativePolicy,
6728 mut fit_fn: FitFn,
6729 mut exact_fn: ExactFn,
6730 mut exact_efs_fn: ExactEfsFn,
6731 mut seed_inner_beta_fn: SeedFn,
6732) -> Result<SpatialLengthScaleOptimizationResult<FitOut>, String>
6733where
6734 FitOut: Clone,
6735 FitFn: FnMut(
6736 &Array1<f64>,
6737 &[TermCollectionSpec],
6738 &[TermCollectionDesign],
6739 ) -> Result<FitOut, String>,
6740 ExactFn: FnMut(
6741 &Array1<f64>,
6742 &[TermCollectionSpec],
6743 &[TermCollectionDesign],
6744 gam_solve::estimate::reml::reml_outer_engine::EvalMode,
6745 &gam_problem::outer_subsample::RowSet,
6746 ) -> Result<
6747 (
6748 f64,
6749 Array1<f64>,
6750 gam_problem::HessianResult,
6751 ),
6752 String,
6753 >,
6754 ExactEfsFn: FnMut(
6755 &Array1<f64>,
6756 &[TermCollectionSpec],
6757 &[TermCollectionDesign],
6758 ) -> Result<gam_problem::EfsEval, String>,
6759 SeedFn:
6760 FnMut(&Array1<f64>) -> Result<gam_solve::rho_optimizer::SeedOutcome, EstimationError>,
6761{
6762 let n_blocks = block_specs.len();
6763 if block_term_indices.len() != n_blocks {
6764 return Err(SmoothError::dimension_mismatch(format!(
6765 "block_specs ({}) and block_term_indices ({}) length mismatch",
6766 n_blocks,
6767 block_term_indices.len()
6768 ))
6769 .into());
6770 }
6771
6772 let log_kappa_dim = joint_setup.log_kappa_dim();
6773
6774 log::warn!(
6775 "[OUTER-FD-AUDIT/spatial-exact-joint] driver entry: aux_dim={} log_kappa_dim={} kappa_enabled={} rho_dim={} theta0_len={}",
6776 joint_setup.auxiliary_dim(),
6777 log_kappa_dim,
6778 kappa_options.enabled,
6779 joint_setup.rho_dim(),
6780 joint_setup.theta0().len()
6781 );
6782
6783 if joint_setup.auxiliary_dim() == 0 && (!kappa_options.enabled || log_kappa_dim == 0) {
6787 log::warn!(
6788 "[OUTER-FD-AUDIT/spatial-exact-joint] taking FAST path (no outer theta optimization in this driver)"
6789 );
6790 let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
6791 data, block_specs,
6792 )
6793 .map_err(|e| {
6794 format!("failed to build and freeze joint block designs during exact joint kappa optimization: {e}")
6795 })?;
6796 let theta0 = joint_setup.theta0();
6797
6798 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
6800 let design_refs: Vec<TermCollectionDesign> = designs.clone();
6801 let fit = fit_fn(&theta0, &spec_refs, &design_refs)?;
6802 return Ok(SpatialLengthScaleOptimizationResult {
6803 resolved_specs,
6804 designs,
6805 fit,
6806 timing: None,
6807 });
6808 }
6809
6810 let theta0 = joint_setup.theta0();
6814 let lower = joint_setup.lower();
6815 let upper = joint_setup.upper();
6816 if theta0.len() < log_kappa_dim || lower.len() != theta0.len() || upper.len() != theta0.len() {
6817 return Err(SmoothError::dimension_mismatch(format!(
6818 "invalid exact joint theta setup: theta0={}, lower={}, upper={}, required_log_kappa_dim={}",
6819 theta0.len(),
6820 lower.len(),
6821 upper.len(),
6822 log_kappa_dim
6823 ))
6824 .into());
6825 }
6826 let rho_dim = joint_setup.rho_dim();
6827 let all_dims = joint_setup.log_kappa_dims_per_term();
6828
6829 let (boot_designs, best_specs) = build_term_collection_designs_and_freeze_joint(
6831 data,
6832 block_specs,
6833 )
6834 .map_err(|e| {
6835 format!(
6836 "failed to build and freeze joint block designs during exact joint kappa bootstrap: {e}"
6837 )
6838 })?;
6839 let policy_hessian_form = outer_derivative_policy.declared_hessian_form();
6849 let analytic_outer_hessian_available = analytic_joint_hessian_available
6850 && matches!(
6851 policy_hessian_form,
6852 gam_problem::DeclaredHessianForm::Either
6853 | gam_problem::DeclaredHessianForm::Dense
6854 | gam_problem::DeclaredHessianForm::Operator { .. }
6855 );
6856 let prefer_gradient_only = !analytic_outer_hessian_available;
6857
6858 let theta_dim = theta0.len();
6859 let psi_dim = theta_dim - rho_dim;
6860
6861 let cache_blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)> = best_specs
6863 .iter()
6864 .zip(boot_designs.iter())
6865 .zip(block_term_indices.iter())
6866 .map(|((spec, design), terms)| (spec.clone(), design.clone(), terms.clone()))
6867 .collect();
6868
6869 struct NBlockExactJointState<'d> {
6870 cache: ExactJointDesignCache<'d>,
6871 }
6872
6873 let mut state = NBlockExactJointState {
6874 cache: ExactJointDesignCache::new(data, cache_blocks, rho_dim, all_dims.clone())?,
6875 };
6876
6877 const KAPPA_PILOT_K: usize = 5_000;
6902 const KAPPA_POLISH_K: usize = 25_000;
6903 const KAPPA_POLISH_TRIGGER_N: usize = 100_000;
6904
6905 let n_total = data.nrows();
6906 let use_staged_kappa = outer_derivative_policy.should_use_staged_kappa(n_total);
6907 if use_staged_kappa {
6908 log::info!(
6909 "[KAPPA-STAGED] auto-engaging pilot+polish schedule: n={} pilot_k={} polish_k={}",
6910 n_total,
6911 KAPPA_PILOT_K,
6912 KAPPA_POLISH_K,
6913 );
6914 }
6915
6916 fn build_uniform_pilot_subsample(
6933 n_total: usize,
6934 k_target: usize,
6935 seed: u64,
6936 ) -> gam_problem::outer_subsample::OuterScoreSubsample {
6937 use gam_problem::outer_subsample::OuterScoreSubsample;
6938 let k = k_target.min(n_total);
6939 if k == 0 || n_total == 0 {
6940 return OuterScoreSubsample::from_uniform_inclusion_mask(Vec::new(), n_total, seed);
6941 }
6942 let mut mask: Vec<usize> = Vec::with_capacity(k);
6946 let mut state = seed.wrapping_add(0x9E3779B97F4A7C15);
6948 let splitmix = |s: &mut u64| -> u64 { gam_linalg::utils::splitmix64(s) };
6949 let mut taken = std::collections::HashSet::with_capacity(k);
6950 for j in (n_total - k)..n_total {
6951 let r = (splitmix(&mut state) % (j as u64 + 1)) as usize;
6952 if !taken.insert(r) {
6953 taken.insert(j);
6954 mask.push(j);
6955 } else {
6956 mask.push(r);
6957 }
6958 }
6959 mask.sort_unstable();
6960 mask.dedup();
6961 OuterScoreSubsample::from_uniform_inclusion_mask(mask, n_total, seed)
6962 }
6963
6964 let current_row_set: std::cell::RefCell<gam_problem::outer_subsample::RowSet> = if use_staged_kappa {
6965 let pilot = build_uniform_pilot_subsample(n_total, KAPPA_PILOT_K, n_total as u64);
6966 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::Subsample {
6967 rows: std::sync::Arc::clone(&pilot.rows),
6968 n_full: n_total,
6969 })
6970 } else {
6971 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::All)
6972 };
6973
6974 let exact_fn_cell = std::cell::RefCell::new(&mut exact_fn);
6975 let exact_efs_fn_cell = std::cell::RefCell::new(&mut exact_efs_fn);
6976
6977 use std::cell::Cell;
6992 let kphase_cost_calls: Cell<usize> = Cell::new(0);
6993 let kphase_cost_total_s: Cell<f64> = Cell::new(0.0);
6994 let kphase_eval_calls: Cell<usize> = Cell::new(0);
6995 let kphase_eval_total_s: Cell<f64> = Cell::new(0.0);
6996 let kphase_efs_calls: Cell<usize> = Cell::new(0);
6997 let kphase_efs_total_s: Cell<f64> = Cell::new(0.0);
6998 let kphase_optim_start = std::time::Instant::now();
6999 let kphase_log_kappa_dim = log_kappa_dim;
7000 let kphase_log_norms = |theta: &Array1<f64>| -> (f64, f64) {
7001 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
7002 let log_kappa_norm = if kphase_log_kappa_dim > 0 && theta.len() >= kphase_log_kappa_dim {
7003 let start = theta.len() - kphase_log_kappa_dim;
7004 theta.iter().skip(start).map(|v| v * v).sum::<f64>().sqrt()
7005 } else {
7006 0.0
7007 };
7008 (theta_norm, log_kappa_norm)
7009 };
7010
7011 use gam_solve::rho_optimizer::OuterEvalOrder;
7012 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7013
7014 let joint_p_cols: usize = boot_designs
7018 .iter()
7019 .map(|d| d.design.ncols())
7020 .sum::<usize>()
7021 .max(1);
7022
7023 let problem = exact_joint_multistart_outer_problem(
7024 &theta0,
7025 &lower,
7026 &upper,
7027 rho_dim,
7028 psi_dim,
7029 theta_dim,
7030 if analytic_joint_gradient_available {
7031 Derivative::Analytic
7032 } else {
7033 Derivative::Unavailable
7034 },
7035 if analytic_outer_hessian_available {
7036 DeclaredHessianForm::Either
7037 } else {
7038 DeclaredHessianForm::Unavailable
7039 },
7040 prefer_gradient_only,
7041 disable_fixed_point,
7042 seed_risk_profile,
7043 kappa_options.rel_tol.max(1e-6),
7044 kappa_options.max_outer_iter.max(1),
7045 Some(5.0),
7047 Some(kappa_options.log_step.clamp(0.25, 1.0)),
7049 screening_cap.clone(),
7050 Some((n_total, joint_p_cols)),
7053 block_specs
7056 .iter()
7057 .any(|s| !constant_curvature_term_indices(s).is_empty()),
7058 );
7059
7060 fn collect_specs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionSpec> {
7062 cache.specs().into_iter().cloned().collect()
7063 }
7064 fn collect_designs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionDesign> {
7065 cache.designs().into_iter().cloned().collect()
7066 }
7067
7068 let result = {
7069 let eval_outer = |ctx: &mut &mut NBlockExactJointState<'_>,
7070 theta: &Array1<f64>,
7071 order: OuterEvalOrder|
7072 -> Result<OuterEval, EstimationError> {
7073 if let Some((cost, grad, hess)) = ctx.cache.memoized_eval(theta) {
7074 let cached_satisfies_order = match order {
7075 OuterEvalOrder::Value => true,
7076 OuterEvalOrder::ValueAndGradient => true,
7077 OuterEvalOrder::ValueGradientHessian => hess.is_analytic(),
7078 };
7079 if cached_satisfies_order {
7080 if !cost.is_finite() {
7081 return Ok(OuterEval::infeasible(theta.len()));
7082 }
7083 if grad.iter().any(|v| !v.is_finite()) {
7096 return Ok(OuterEval::infeasible(theta.len()));
7097 }
7098 return Ok(OuterEval {
7099 cost,
7100 gradient: grad,
7101 hessian: hess,
7102 inner_beta_hint: None,
7103 });
7104 }
7105 }
7106 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7123 return Ok(OuterEval::infeasible(theta.len()));
7124 }
7125 if let Err(err) = ctx.cache.ensure_theta(theta) {
7126 log::warn!(
7127 "[OUTER] n-block exact-joint spatial: ensure_theta failed during gradient evaluation: {err}"
7128 );
7129 return Ok(OuterEval::infeasible(theta.len()));
7130 }
7131 let design_revision = Some(ctx.cache.design_revision());
7132 let specs = collect_specs(&ctx.cache);
7133 let designs = collect_designs(&ctx.cache);
7134 let clamped = outer_derivative_policy.order_for_evaluation(order);
7142 let need_hessian = matches!(clamped, OuterEvalOrder::ValueGradientHessian)
7143 && analytic_outer_hessian_available;
7144 let eval_mode = if need_hessian {
7145 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
7146 } else {
7147 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
7148 };
7149 let t0 = std::time::Instant::now();
7150 let result = {
7151 let row_set_borrow = current_row_set.borrow();
7152 (*exact_fn_cell.borrow_mut())(theta, &specs, &designs, eval_mode, &row_set_borrow)
7153 };
7154 let elapsed_s = t0.elapsed().as_secs_f64();
7155 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
7156 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
7157 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7158 log::info!(
7159 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7160 kphase_eval_calls.get(),
7161 order,
7162 design_revision,
7163 theta_norm,
7164 log_kappa_norm,
7165 elapsed_s,
7166 );
7167 match result {
7168 Ok((cost, grad, hess)) => {
7169 ctx.cache.store_eval((cost, grad.clone(), hess.clone()));
7170 if !cost.is_finite() {
7171 return Ok(OuterEval::infeasible(theta.len()));
7172 }
7173 if grad.iter().any(|v| !v.is_finite()) {
7186 return Ok(OuterEval::infeasible(theta.len()));
7187 }
7188 Ok(OuterEval {
7189 cost,
7190 gradient: grad,
7191 hessian: hess,
7192 inner_beta_hint: None,
7193 })
7194 }
7195 Err(err) => {
7196 log::warn!(
7197 "[OUTER] n-block exact-joint spatial: exact evaluation failed: {err}"
7198 );
7199 Ok(OuterEval::infeasible(theta.len()))
7200 }
7201 }
7202 };
7203
7204 let obj = problem.build_objective_with_eval_order(
7205 &mut state,
7206 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7207 if let Some(cost) = ctx.cache.memoized_cost(theta) {
7208 return Ok(cost);
7209 }
7210 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7218 return Ok(f64::INFINITY);
7219 }
7220 if let Err(err) = ctx.cache.ensure_theta(theta) {
7221 log::warn!(
7222 "[OUTER] n-block exact-joint spatial: ensure_theta failed during cost evaluation: {err}"
7223 );
7224 return Ok(f64::INFINITY);
7225 }
7226 let design_revision = Some(ctx.cache.design_revision());
7227 let specs = collect_specs(&ctx.cache);
7228 let designs = collect_designs(&ctx.cache);
7229 let t0 = std::time::Instant::now();
7236 let result = {
7237 let row_set_borrow = current_row_set.borrow();
7238 (*exact_fn_cell.borrow_mut())(
7239 theta,
7240 &specs,
7241 &designs,
7242 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
7243 &row_set_borrow,
7244 )
7245 };
7246 let elapsed_s = t0.elapsed().as_secs_f64();
7247 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
7248 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
7249 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7250 log::info!(
7251 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7252 kphase_cost_calls.get(),
7253 design_revision,
7254 theta_norm,
7255 log_kappa_norm,
7256 elapsed_s,
7257 );
7258 match result {
7259 Ok((cost, _grad, _hess)) => {
7260 ctx.cache.store_cost_only(theta, cost);
7266 Ok(cost)
7267 }
7268 Err(err) => {
7269 log::warn!(
7270 "[OUTER] n-block exact-joint spatial: exact cost evaluation failed: {err}"
7271 );
7272 Ok(f64::INFINITY)
7273 }
7274 }
7275 },
7276 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7277 eval_outer(
7278 ctx,
7279 theta,
7280 if analytic_outer_hessian_available {
7281 OuterEvalOrder::ValueGradientHessian
7282 } else {
7283 OuterEvalOrder::ValueAndGradient
7284 },
7285 )
7286 },
7287 |ctx: &mut &mut NBlockExactJointState<'_>,
7288 theta: &Array1<f64>,
7289 order: OuterEvalOrder| { eval_outer(ctx, theta, order) },
7290 None::<fn(&mut &mut NBlockExactJointState<'_>)>,
7291 Some(
7292 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7293 ctx.cache
7294 .ensure_theta(theta)
7295 .map_err(EstimationError::InvalidInput)?;
7296 let design_revision = Some(ctx.cache.design_revision());
7297 let specs = collect_specs(&ctx.cache);
7298 let designs = collect_designs(&ctx.cache);
7299 let t0 = std::time::Instant::now();
7300 let eval_result = (*exact_efs_fn_cell.borrow_mut())(
7301 theta,
7302 &specs,
7303 &designs,
7304 );
7305 let elapsed_s = t0.elapsed().as_secs_f64();
7306 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
7307 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
7308 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7309 log::info!(
7310 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7311 kphase_efs_calls.get(),
7312 design_revision,
7313 theta_norm,
7314 log_kappa_norm,
7315 elapsed_s,
7316 );
7317 let eval = eval_result.map_err(EstimationError::RemlOptimizationFailed)?;
7318 Ok(eval)
7319 },
7320 ),
7321 );
7322 let mut obj = obj.with_seed_inner_state(
7323 move |_ctx: &mut &mut NBlockExactJointState<'_>, beta: &Array1<f64>| {
7324 (seed_inner_beta_fn)(beta)
7325 },
7326 );
7327
7328 match problem.run(&mut obj, "n-block exact-joint spatial") {
7329 Ok(result) => result,
7330 Err(e) => {
7331 let message = e.to_string();
7332 if kappa_phase_failure_is_fixed_kappa_recoverable(&message) {
7352 drop(obj);
7353 log::warn!(
7354 "[KAPPA-PHASE] length-scale optimization could not validate any seed \
7355 ({message}); falling back to a FIXED bootstrap κ (skipping κ \
7356 optimization) and fitting there — a real model at the initial \
7357 length-scale rather than raising (gam#787/#860)."
7358 );
7359 let (designs, resolved_specs) =
7360 build_term_collection_designs_and_freeze_joint(data, block_specs).map_err(
7361 |build_err| {
7362 format!(
7363 "fixed-κ fallback failed to build and freeze joint block \
7364 designs after κ optimization could not validate a seed \
7365 ({message}): {build_err}"
7366 )
7367 },
7368 )?;
7369 let fixed_theta0 = joint_setup.theta0();
7370 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7371 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7372 let fit = fit_fn(&fixed_theta0, &spec_refs, &design_refs)?;
7373 return Ok(SpatialLengthScaleOptimizationResult {
7374 resolved_specs,
7375 designs,
7376 fit,
7377 timing: None,
7378 });
7379 }
7380 return Err(message);
7381 }
7382 }
7383 }; let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
7393 log::info!(
7394 "[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} optim_total_s={:.4}",
7395 kphase_log_kappa_dim,
7396 kphase_cost_calls.get(),
7397 kphase_cost_total_s.get(),
7398 kphase_eval_calls.get(),
7399 kphase_eval_total_s.get(),
7400 kphase_efs_calls.get(),
7401 kphase_efs_total_s.get(),
7402 kphase_total_s,
7403 );
7404 let timing = SpatialLengthScaleOptimizationTiming {
7405 log_kappa_dim: kphase_log_kappa_dim,
7406 cost_calls: kphase_cost_calls.get(),
7407 cost_total_s: kphase_cost_total_s.get(),
7408 eval_calls: kphase_eval_calls.get(),
7409 eval_total_s: kphase_eval_total_s.get(),
7410 efs_calls: kphase_efs_calls.get(),
7411 efs_total_s: kphase_efs_total_s.get(),
7412 slow_path_resets: 0,
7413 design_revision_delta: 0,
7414 nfree_miss_shape: 0,
7415 nfree_miss_value: 0,
7416 nfree_miss_gradient: 0,
7417 nfree_miss_penalty: 0,
7418 nfree_miss_revision: 0,
7419 nfree_miss_second_order: 0,
7420 nfree_miss_other: 0,
7421 optim_total_s: kphase_total_s,
7422 };
7423
7424 let theta_star = result.rho;
7425
7426 if use_staged_kappa && n_total >= KAPPA_POLISH_TRIGGER_N {
7443 let polish = build_uniform_pilot_subsample(
7444 n_total,
7445 KAPPA_POLISH_K,
7446 (n_total as u64).wrapping_add(0xA5A5A5A5),
7447 );
7448 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::Subsample {
7449 rows: std::sync::Arc::clone(&polish.rows),
7450 n_full: n_total,
7451 };
7452 log::info!(
7453 "[KAPPA-STAGED] rotating to polish subsample: k={} at theta_star",
7454 polish.rows.len(),
7455 );
7456 state.cache.ensure_theta(&theta_star)?;
7460 let (polish_cost, polish_grad, _) = {
7461 let specs = collect_specs(&state.cache);
7462 let designs = collect_designs(&state.cache);
7463 let row_set_borrow = current_row_set.borrow();
7464 exact_fn(
7465 &theta_star,
7466 &specs,
7467 &designs,
7468 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
7469 &row_set_borrow,
7470 )?
7471 };
7472 if !polish_cost.is_finite() || polish_grad.iter().any(|value| !value.is_finite()) {
7473 return Err(
7474 "polish subsample exact-joint evaluation produced non-finite objective pieces"
7475 .to_string(),
7476 );
7477 }
7478 }
7479 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::All;
7480 if use_staged_kappa {
7481 log::info!(
7482 "[KAPPA-STAGED] rotating to full data for final coefficient fit (n={})",
7483 n_total,
7484 );
7485 }
7486
7487 state.cache.ensure_theta(&theta_star)?;
7488
7489 let resolved_specs: Vec<TermCollectionSpec> = collect_specs(&state.cache);
7490 let designs: Vec<TermCollectionDesign> = collect_designs(&state.cache);
7491
7492 let fit = fit_fn(&theta_star, &resolved_specs, &designs)?;
7493
7494 for spec in &resolved_specs {
7495 log_spatial_aniso_scales(spec);
7496 }
7497
7498 Ok(SpatialLengthScaleOptimizationResult {
7499 resolved_specs,
7500 designs,
7501 fit,
7502 timing: Some(timing),
7503 })
7504}
7505
7506fn try_exact_joint_latent_coord_optimization(
7507 data: ArrayView2<'_, f64>,
7508 y: ArrayView1<'_, f64>,
7509 weights: ArrayView1<'_, f64>,
7510 offset: ArrayView1<'_, f64>,
7511 resolvedspec: &TermCollectionSpec,
7512 best: &FittedTermCollection,
7513 family: LikelihoodSpec,
7514 options: &FitOptions,
7515 latent: &StandardLatentCoordConfig,
7516) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7517 use gam_solve::rho_optimizer::OuterEvalOrder;
7518 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7519
7520 let rho_dim = best.fit.lambdas.len();
7521 let latent_flat_dim = latent.values.len();
7522 if latent_flat_dim == 0 {
7523 crate::bail_invalid_estim!(
7524 "latent-coordinate optimization requires a non-empty latent block"
7525 );
7526 }
7527 let direct_hypers =
7528 latent_coord_initial_direct_hypers(latent.values.id_mode(), latent.values.latent_dim())?;
7529 let analytic_rho_count = latent
7530 .analytic_penalties
7531 .as_ref()
7532 .map_or(0, |registry| registry.total_rho_count());
7533 let latent_coord_ext_dim = latent_flat_dim + analytic_rho_count + direct_hypers.len();
7534
7535 let mut theta0 = Array1::<f64>::zeros(rho_dim + latent_coord_ext_dim);
7536 theta0
7537 .slice_mut(s![..rho_dim])
7538 .assign(&best.fit.lambdas.mapv(f64::ln));
7539 theta0
7540 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
7541 .assign(latent.values.as_flat());
7542 if !direct_hypers.is_empty() {
7543 let direct_start = rho_dim + latent_flat_dim + analytic_rho_count;
7544 theta0
7545 .slice_mut(s![direct_start..direct_start + direct_hypers.len()])
7546 .assign(&direct_hypers);
7547 }
7548
7549 let mut lower = Array1::<f64>::from_elem(theta0.len(), -12.0);
7550 let mut upper = Array1::<f64>::from_elem(theta0.len(), 12.0);
7551 let latent_bound = latent
7552 .values
7553 .as_flat()
7554 .iter()
7555 .fold(1.0_f64, |acc, &v| acc.max(v.abs()))
7556 + 10.0;
7557 for axis in rho_dim..rho_dim + latent_flat_dim {
7558 lower[axis] = -latent_bound;
7559 upper[axis] = latent_bound;
7560 }
7561
7562 struct LatentJointContext<'d> {
7563 rho_dim: usize,
7564 cache: SingleBlockLatentCoordDesignCache,
7565 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
7566 }
7567
7568 impl<'d> LatentJointContext<'d> {
7569 fn eval_full(
7570 &mut self,
7571 theta: &Array1<f64>,
7572 order: OuterEvalOrder,
7573 ) -> Result<
7574 (
7575 f64,
7576 Array1<f64>,
7577 gam_problem::HessianResult,
7578 ),
7579 EstimationError,
7580 > {
7581 if let Some(eval) = self.cache.memoized_eval(theta) {
7582 return Ok(eval);
7583 }
7584 self.cache
7585 .ensure_theta(theta)
7586 .map_err(EstimationError::InvalidInput)?;
7587 let hyper_dirs = self
7588 .cache
7589 .hyper_dirs()
7590 .map_err(EstimationError::InvalidInput)?;
7591 let design_revision = Some(self.cache.design_revision());
7592 let registry_for_key = self.cache.analytic_penalties();
7593 self.evaluator
7594 .set_analytic_penalty_registry(registry_for_key.as_deref());
7595 let mut eval = evaluate_joint_reml_outer_eval_at_theta(
7596 &mut self.evaluator,
7597 self.cache.design(),
7598 theta,
7599 self.rho_dim,
7600 hyper_dirs,
7601 None,
7602 order,
7603 design_revision,
7604 )?;
7605 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
7606 if let Some(registry) = registry_for_key {
7607 let mut registry = registry.as_ref().clone();
7608 registry.apply_weight_schedules(
7609 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
7610 );
7611 add_analytic_penalty_objective_to_eval(
7612 theta,
7613 self.rho_dim,
7614 latent.as_ref(),
7615 ®istry,
7616 &mut eval,
7617 )?;
7618 }
7619 add_latent_id_objective_to_eval(
7620 theta,
7621 self.rho_dim,
7622 self.cache.analytic_penalty_rho_count(),
7623 latent.as_ref(),
7624 &mut eval,
7625 )?;
7626 self.cache.store_eval(eval.clone());
7627 Ok(eval)
7628 }
7629
7630 fn eval_efs(
7631 &mut self,
7632 theta: &Array1<f64>,
7633 ) -> Result<gam_problem::EfsEval, EstimationError> {
7634 self.cache
7635 .ensure_theta(theta)
7636 .map_err(EstimationError::InvalidInput)?;
7637 let hyper_dirs = self
7638 .cache
7639 .hyper_dirs()
7640 .map_err(EstimationError::InvalidInput)?;
7641 let registry_for_key = self.cache.analytic_penalties();
7642 self.evaluator
7643 .set_analytic_penalty_registry(registry_for_key.as_deref());
7644 let mut efs = evaluate_joint_reml_efs_at_theta(
7645 &mut self.evaluator,
7646 self.cache.design(),
7647 theta,
7648 self.rho_dim,
7649 hyper_dirs,
7650 None,
7651 Some(self.cache.design_revision()),
7652 )?;
7653 if let Some(registry) = registry_for_key {
7654 let mut registry = registry.as_ref().clone();
7655 registry.apply_weight_schedules(
7656 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
7657 );
7658 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
7659 let contribution = analytic_penalty_objective_contribution(
7660 theta,
7661 self.rho_dim,
7662 latent.as_ref(),
7663 ®istry,
7664 )?;
7665 efs.cost += contribution.cost;
7666 if let (Some(psi_gradient), Some(psi_indices)) =
7667 (efs.psi_gradient.as_mut(), efs.psi_indices.as_ref())
7668 {
7669 if psi_gradient.len() != psi_indices.len() {
7670 crate::bail_invalid_estim!(
7671 "latent-coordinate analytic penalty EFS psi gradient length mismatch: gradient={}, indices={}",
7672 psi_gradient.len(),
7673 psi_indices.len()
7674 );
7675 }
7676 for (local_idx, &theta_idx) in psi_indices.iter().enumerate() {
7677 psi_gradient[local_idx] += contribution.gradient[theta_idx];
7678 }
7679 }
7680 }
7681 Ok(efs)
7682 }
7683
7684 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
7685 if let Some(cost) = self.cache.memoized_cost(theta) {
7686 return cost;
7687 }
7688 if self.cache.ensure_theta(theta).is_err() {
7689 return f64::INFINITY;
7690 }
7691 let design_revision = Some(self.cache.design_revision());
7692 let registry_for_key = self.cache.analytic_penalties();
7693 self.evaluator
7694 .set_analytic_penalty_registry(registry_for_key.as_deref());
7695 let result = {
7696 let design = self.cache.design();
7697 self.evaluator.evaluate_cost_only(
7698 &design.design,
7699 &design.penalties,
7700 &design.nullspace_dims,
7701 design.linear_constraints.clone(),
7702 theta,
7703 self.rho_dim,
7704 None,
7705 "latent-coordinate-joint cost-only",
7706 design_revision,
7707 )
7708 };
7709 match result {
7710 Ok(cost) => {
7711 let latent = match self.cache.latent() {
7712 Ok(latent) => latent,
7713 Err(_) => return f64::INFINITY,
7714 };
7715 let contribution = match latent_id_objective_contribution(
7716 theta,
7717 self.rho_dim,
7718 self.cache.analytic_penalty_rho_count(),
7719 latent.as_ref(),
7720 ) {
7721 Ok(contribution) => contribution,
7722 Err(_) => return f64::INFINITY,
7723 };
7724 let cost = cost + contribution.cost;
7725 let cost = if let Some(registry) = registry_for_key {
7726 let mut registry = registry.as_ref().clone();
7727 registry.apply_weight_schedules(
7728 gam_solve::estimate::reml::outer_eval::current_outer_iter()
7729 as usize,
7730 );
7731 match analytic_penalty_objective_contribution(
7732 theta,
7733 self.rho_dim,
7734 latent.as_ref(),
7735 ®istry,
7736 ) {
7737 Ok(contribution) => cost + contribution.cost,
7738 Err(_) => return f64::INFINITY,
7739 }
7740 } else {
7741 cost
7742 };
7743 self.cache.store_cost(cost);
7744 cost
7745 }
7746 Err(_) => f64::INFINITY,
7747 }
7748 }
7749 }
7750
7751 let mut ctx = LatentJointContext {
7752 rho_dim,
7753 cache: SingleBlockLatentCoordDesignCache::new(
7754 data.to_owned(),
7755 resolvedspec.clone(),
7756 best.design.clone(),
7757 latent,
7758 rho_dim,
7759 )
7760 .map_err(EstimationError::InvalidInput)?,
7761 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
7762 y,
7763 weights,
7764 &best.design.design,
7765 offset,
7766 &best.design.penalties,
7767 &external_opts_for_design(&family, &best.design, options),
7768 "latent-coordinate-joint",
7769 )?,
7770 };
7771 let registry_for_key = ctx.cache.analytic_penalties();
7772 ctx.evaluator
7773 .set_analytic_penalty_registry(registry_for_key.as_deref());
7774 ctx.evaluator
7775 .set_persistent_latent_values_fingerprint(latent.values.id_mode());
7776 if let Some(cached_t) = ctx
7777 .evaluator
7778 .load_persistent_latent_values(latent.values.n_obs(), latent.values.latent_dim())
7779 {
7780 let cached_t: Array2<f64> = cached_t;
7781 for (dst, src) in theta0
7782 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
7783 .iter_mut()
7784 .zip(cached_t.iter())
7785 {
7786 *dst = *src;
7787 }
7788 }
7789
7790 let problem = exact_joint_multistart_outer_problem(
7791 &theta0,
7792 &lower,
7793 &upper,
7794 rho_dim,
7795 latent_coord_ext_dim,
7796 theta0.len(),
7797 Derivative::Analytic,
7798 DeclaredHessianForm::Unavailable,
7799 false,
7800 false,
7801 seed_risk_profile_for_likelihood_family(&family),
7802 options.tol,
7803 options.max_iter.max(1),
7804 Some(5.0),
7805 Some(0.5),
7806 None,
7807 Some((data.nrows(), best.design.design.ncols().max(1))),
7810 !constant_curvature_term_indices(resolvedspec).is_empty(),
7813 );
7814
7815 let eval_outer = |ctx: &mut &mut LatentJointContext<'_>,
7816 theta: &Array1<f64>,
7817 order: OuterEvalOrder|
7818 -> Result<OuterEval, EstimationError> {
7819 let (cost, gradient, hessian) = ctx.eval_full(theta, order)?;
7820 Ok(OuterEval {
7821 cost,
7822 gradient,
7823 hessian,
7824 inner_beta_hint: None,
7825 })
7826 };
7827
7828 let result = {
7829 let mut obj = problem.build_objective_with_eval_order(
7830 &mut ctx,
7831 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
7832 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| {
7833 eval_outer(ctx, theta, OuterEvalOrder::ValueAndGradient)
7834 },
7835 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
7836 eval_outer(ctx, theta, order)
7837 },
7838 Some(|ctx: &mut &mut LatentJointContext<'_>| {
7839 ctx.cache.reset();
7840 }),
7841 Some(|ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
7842 );
7843
7844 problem
7845 .run(&mut obj, "latent-coordinate joint REML")
7846 .map_err(|e| {
7847 EstimationError::InvalidInput(format!(
7848 "latent-coordinate joint optimization failed after exhausting strategy fallbacks: {e}"
7849 ))
7850 })?
7851 };
7852 if !result.converged {
7853 crate::bail_invalid_estim!(
7854 "latent-coordinate joint optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
7855 result.iterations,
7856 result.final_value,
7857 result.final_grad_norm_report(),
7858 );
7859 }
7860
7861 let theta_star = result.rho;
7862 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
7863 let mut final_data = data.to_owned();
7864 let flat_t = theta_star
7865 .slice(s![rho_dim..rho_dim + latent_flat_dim])
7866 .to_owned();
7867 let mut fitted_latent_values =
7868 Array2::<f64>::zeros((latent.values.n_obs(), latent.values.latent_dim()));
7869 for n in 0..latent.values.n_obs() {
7870 for axis in 0..latent.values.latent_dim() {
7871 let value = flat_t[n * latent.values.latent_dim() + axis];
7872 fitted_latent_values[[n, axis]] = value;
7873 final_data[[n, latent.feature_cols[axis]]] = value;
7874 }
7875 }
7876 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
7877 final_data.view(),
7878 y,
7879 weights,
7880 offset,
7881 resolvedspec,
7882 rho_star.as_slice(),
7883 family,
7884 options,
7885 )?;
7886 ctx.evaluator
7887 .store_persistent_latent_values(&fitted_latent_values);
7888 let mut fit = optimized.fit;
7889 fit.reml_score = result.final_value;
7890 fit.penalized_objective = result.final_value;
7891 Ok(FittedTermCollectionWithSpec {
7892 fit,
7893 design: optimized.design,
7894 resolvedspec: resolvedspec.clone(),
7895 adaptive_diagnostics: optimized.adaptive_diagnostics,
7896 kappa_timing: None,
7897 })
7898}
7899
7900pub fn fit_term_collectionwith_latent_coord_optimization(
7901 data: ArrayView2<'_, f64>,
7902 y: Array1<f64>,
7903 weights: Array1<f64>,
7904 offset: Array1<f64>,
7905 spec: &TermCollectionSpec,
7906 latent: &StandardLatentCoordConfig,
7907 family: LikelihoodSpec,
7908 options: &FitOptions,
7909) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7910 let n = data.nrows();
7911 if !(y.len() == n && weights.len() == n && offset.len() == n) {
7912 crate::bail_invalid_estim!(
7913 "fit_term_collectionwith_latent_coord_optimization row mismatch: n={}, y={}, weights={}, offset={}",
7914 n,
7915 y.len(),
7916 weights.len(),
7917 offset.len()
7918 );
7919 }
7920 let best = fit_term_collection_forspec(
7921 data,
7922 y.view(),
7923 weights.view(),
7924 offset.view(),
7925 spec,
7926 family.clone(),
7927 options,
7928 )?;
7929 let resolvedspec = freeze_term_collection_from_design(spec, &best.design)?;
7930 try_exact_joint_latent_coord_optimization(
7931 data,
7932 y.view(),
7933 weights.view(),
7934 offset.view(),
7935 &resolvedspec,
7936 &best,
7937 family,
7938 options,
7939 latent,
7940 )
7941}
7942
7943pub fn fit_term_collectionwith_spatial_length_scale_optimization(
7944 data: ArrayView2<'_, f64>,
7945 y: Array1<f64>,
7946 weights: Array1<f64>,
7947 offset: Array1<f64>,
7948 spec: &TermCollectionSpec,
7949 family: LikelihoodSpec,
7950 options: &FitOptions,
7951 kappa_options: &SpatialLengthScaleOptimizationOptions,
7952) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7953 let mut resolvedspec = spec.clone();
7969 let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
7970 let n = data.nrows();
7971 if !(y.len() == n && weights.len() == n && offset.len() == n) {
7972 crate::bail_invalid_estim!(
7973 "fit_term_collectionwith_spatial_length_scale_optimization row mismatch: n={}, y={}, weights={}, offset={}",
7974 n,
7975 y.len(),
7976 weights.len(),
7977 offset.len()
7978 );
7979 }
7980 if !kappa_options.enabled || spatial_terms.is_empty() {
7981 let out = fit_term_collection_forspec(
7982 data,
7983 y.view(),
7984 weights.view(),
7985 offset.view(),
7986 &resolvedspec,
7987 family,
7988 options,
7989 )?;
7990 let resolvedspec = freeze_term_collection_from_design(&resolvedspec, &out.design)?;
7991 return Ok(FittedTermCollectionWithSpec {
7992 fit: out.fit,
7993 design: out.design,
7994 resolvedspec,
7995 adaptive_diagnostics: out.adaptive_diagnostics,
7996 kappa_timing: None,
7997 });
7998 }
7999 if kappa_options.max_outer_iter == 0 {
8000 crate::bail_invalid_estim!("spatial kappa optimization requires max_outer_iter >= 1");
8001 }
8002 if !(kappa_options.log_step.is_finite() && kappa_options.log_step > 0.0) {
8003 crate::bail_invalid_estim!("spatial kappa optimization requires log_step > 0");
8004 }
8005 if !(kappa_options.min_length_scale.is_finite()
8006 && kappa_options.max_length_scale.is_finite()
8007 && kappa_options.min_length_scale > 0.0
8008 && kappa_options.max_length_scale >= kappa_options.min_length_scale)
8009 {
8010 crate::bail_invalid_estim!(
8011 "spatial kappa optimization requires valid positive length_scale bounds"
8012 );
8013 }
8014
8015 let pilot_threshold = kappa_options.pilot_subsample_threshold;
8016 if pilot_threshold > 0 && n > pilot_threshold * 2 {
8017 log::info!(
8018 "[spatial-kappa] n={n} exceeds pilot threshold {}; using pilot geometry only for deterministic anisotropy initialization",
8019 pilot_threshold * 2,
8020 );
8021 apply_spatial_anisotropy_pilot_initializer(
8022 data,
8023 &mut resolvedspec,
8024 &spatial_terms,
8025 pilot_threshold,
8026 kappa_options,
8027 );
8028 }
8029
8030 apply_response_aware_anisotropy_seed(data, y.view(), &mut resolvedspec, &spatial_terms);
8039
8040 for term_idx in constant_curvature_term_indices(&resolvedspec) {
8058 if let Some(kappa_seed) =
8059 select_constant_curvature_kappa_sign_seed(data, y.view(), &resolvedspec, term_idx)
8060 && kappa_seed != 0.0
8061 && let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) =
8062 resolvedspec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis)
8063 {
8064 log::info!(
8065 "[#1464] pinned CC term {term_idx} baseline κ to κ-fair scan value {kappa_seed} \
8066 (raw profiled REML is sign-blind; scan is authoritative for the sign)"
8067 );
8068 cc.kappa = kappa_seed;
8069 }
8070 }
8071
8072 let baseline_options = superseded_fit_options(options);
8073 let mut best = fit_term_collection_forspec(
8074 data,
8075 y.view(),
8076 weights.view(),
8077 offset.view(),
8078 &resolvedspec,
8079 family.clone(),
8080 &baseline_options,
8081 )?;
8082 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8083 let mut spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8093 sync_aniso_contrasts_from_metadata(&mut resolvedspec, &best.design.smooth);
8097 let mut prescan_improved = false;
8104 if !spatial_terms.is_empty() {
8105 let baseline_score = fit_score(&best.fit);
8106 let range_overrides = prescan_isotropic_spatial_range_seed(
8107 data,
8108 y.view(),
8109 weights.view(),
8110 offset.view(),
8111 &resolvedspec,
8112 baseline_score,
8113 &family,
8114 &baseline_options,
8115 kappa_options,
8116 &spatial_terms,
8117 )?;
8118 if !range_overrides.is_empty() {
8119 prescan_improved = true;
8120 for (term_idx, length_scale) in range_overrides {
8121 set_spatial_length_scale(&mut resolvedspec, term_idx, length_scale)?;
8122 }
8123 best = fit_term_collection_forspec(
8127 data,
8128 y.view(),
8129 weights.view(),
8130 offset.view(),
8131 &resolvedspec,
8132 family.clone(),
8133 &baseline_options,
8134 )?;
8135 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8136 spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8140 }
8141 }
8142 if spatial_terms.is_empty() {
8143 let fitted = fit_term_collection_forspecwith_heuristic_lambdas(
8144 data,
8145 y.view(),
8146 weights.view(),
8147 offset.view(),
8148 &resolvedspec,
8149 best.fit.lambdas.as_slice(),
8150 family,
8151 options,
8152 )?;
8153 return Ok(FittedTermCollectionWithSpec {
8154 fit: fitted.fit,
8155 design: fitted.design,
8156 resolvedspec,
8157 adaptive_diagnostics: fitted.adaptive_diagnostics,
8158 kappa_timing: None,
8159 });
8160 }
8161 let initial_score = fit_score(&best.fit);
8162 if !initial_score.is_finite() {
8163 log::debug!("[spatial-kappa] initial profiled score is non-finite");
8164 }
8165 let joint_result = try_exact_joint_spatial_length_scale_optimization(
8166 data,
8167 y.view(),
8168 weights.view(),
8169 offset.view(),
8170 &resolvedspec,
8171 &best,
8172 family.clone(),
8173 options,
8174 kappa_options,
8175 &spatial_terms,
8176 )
8177 .map(|opt| {
8178 opt.map(|fit| {
8179 let score = fit_score(&fit.fit);
8180 (fit, score)
8181 })
8182 });
8183 let exact_joint = if prescan_improved && !matches!(joint_result, Ok(Some(_))) {
8193 let reason = match &joint_result {
8194 Err(e) => format!("error: {e}"),
8195 _ => "unavailable".to_string(),
8196 };
8197 log::info!(
8198 "[spatial-kappa] #1074 joint polish yielded no usable candidate \
8199 ({reason}); returning the multi-start pre-scan geometry (REML {initial_score:.5})"
8200 );
8201 FittedTermCollectionWithSpec {
8202 fit: best.fit,
8203 design: best.design,
8204 resolvedspec,
8205 adaptive_diagnostics: best.adaptive_diagnostics,
8206 kappa_timing: None,
8207 }
8208 } else {
8209 require_successful_spatial_optimization_result(initial_score, joint_result)?
8210 };
8211 log_spatial_aniso_scales(&exact_joint.resolvedspec);
8212 Ok(exact_joint)
8213}
8214
/// Profile-based inference for the curvature κ of one constant-curvature smooth,
/// as produced by `curvature_inference_forspec`.
#[derive(Clone, Debug)]
pub struct CurvatureInference {
    /// Index of the constant-curvature smooth term in the resolved spec.
    pub term_idx: usize,
    /// Point estimate κ̂ read from the resolved spec.
    pub kappa_hat: f64,
    /// Profile confidence interval for κ, from `profile_ci_walk` at the requested level.
    pub ci: gam_geometry::curvature_estimand::KappaProfileCi,
    /// Result of `flatness_lr_test` evaluated at κ̂ (presumably testing κ = 0,
    /// i.e. flat geometry — confirm against gam_geometry).
    pub flatness: gam_geometry::curvature_estimand::FlatnessTest,
}
8234
8235pub fn curvature_inference_forspec(
8253 data: ArrayView2<'_, f64>,
8254 y: ArrayView1<'_, f64>,
8255 weights: ArrayView1<'_, f64>,
8256 offset: ArrayView1<'_, f64>,
8257 resolvedspec: &TermCollectionSpec,
8258 term_idx: usize,
8259 family: LikelihoodSpec,
8260 options: &FitOptions,
8261 level: f64,
8262) -> Result<CurvatureInference, EstimationError> {
8263 let kappa_hat = get_constant_curvature_kappa(resolvedspec, term_idx).ok_or_else(|| {
8264 EstimationError::InvalidInput(format!(
8265 "curvature_inference_forspec: term {term_idx} is not a constant-curvature smooth"
8266 ))
8267 })?;
8268 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
8269
8270 let cc_fair_inputs: Option<(Array2<f64>, gam_terms::basis::ConstantCurvatureBasisSpec)> =
8295 if kappa_hat < 0.0 {
8296 match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
8297 Some(SmoothBasisSpec::ConstantCurvature {
8298 feature_cols, spec, ..
8299 }) => select_columns(data, feature_cols)
8300 .ok()
8301 .map(|x| (x, spec.clone())),
8302 _ => None,
8303 }
8304 } else {
8305 None
8306 };
8307
8308 let v_p_cache: std::cell::RefCell<std::collections::HashMap<u64, f64>> =
8313 std::cell::RefCell::new(std::collections::HashMap::new());
8314 let v_p = |kappa: f64| -> Result<f64, String> {
8315 if !kappa.is_finite() {
8316 return Err(format!("V_p probed a non-finite κ = {kappa}"));
8317 }
8318 let key = kappa.to_bits();
8319 if let Some(&cached) = v_p_cache.borrow().get(&key) {
8320 return Ok(cached);
8321 }
8322 let score = if let Some((x_term, base_spec)) = &cc_fair_inputs {
8323 let mut probe_spec = base_spec.clone();
8324 probe_spec.kappa = kappa;
8325 gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec)
8326 .map_err(|e| format!("κ-fair criterion at κ={kappa} failed: {e}"))?
8327 } else {
8328 fixed_kappa_profiled_reml_score(
8329 data,
8330 y,
8331 weights,
8332 offset,
8333 resolvedspec,
8334 term_idx,
8335 kappa,
8336 family.clone(),
8337 options,
8338 )
8339 .map_err(|e| format!("V_p fixed-κ fit at κ={kappa} failed: {e}"))?
8340 };
8341 v_p_cache.borrow_mut().insert(key, score);
8342 Ok(score)
8343 };
8344
8345 let h = (1e-3 * (kappa_max - kappa_min)).max(1e-4);
8349 let v_pp = match (v_p(kappa_hat + h), v_p(kappa_hat), v_p(kappa_hat - h)) {
8350 (Ok(vp), Ok(v0), Ok(vm)) => (vp - 2.0 * v0 + vm) / (h * h),
8351 _ => f64::NAN, };
8353
8354 let ci = gam_geometry::curvature_estimand::profile_ci_walk(
8355 &v_p, kappa_hat, v_pp, kappa_min, kappa_max, level, 1e-4,
8356 )
8357 .map_err(EstimationError::InvalidInput)?;
8358 let flatness = gam_geometry::curvature_estimand::flatness_lr_test(&v_p, kappa_hat)
8359 .map_err(EstimationError::InvalidInput)?;
8360
8361 Ok(CurvatureInference {
8362 term_idx,
8363 kappa_hat,
8364 ci,
8365 flatness,
8366 })
8367}
8368
/// Which small-sample correction was applied to a smooth-term likelihood-ratio
/// statistic in `smooth_term_lr_inference_forspec`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmoothLrCorrection {
    /// Lawley Bartlett factor that also accounts for variation in the estimated
    /// log smoothing parameters, using the fitted ρ covariance.
    LawleyLrEstimatedLambda,
    /// Lawley Bartlett factor computed with the smoothing parameters held fixed.
    LawleyLrFixedLambda,
    /// No correction was applied; the corrected statistic equals the raw one.
    None,
}

impl SmoothLrCorrection {
    /// Returns the snake_case string label for this correction.
    pub fn label(self) -> &'static str {
        match self {
            Self::LawleyLrEstimatedLambda => "lawley_lr_estimated_lambda",
            Self::LawleyLrFixedLambda => "lawley_lr_fixed_lambda",
            Self::None => "none",
        }
    }
}
8398
/// Likelihood-ratio test result for one smooth term, as produced by
/// `smooth_term_lr_inference_forspec` (full model vs. model with the term removed).
#[derive(Clone, Debug)]
pub struct SmoothTermLrInference {
    /// Name of the smooth term in the fitted design.
    pub name: String,
    /// Index of the term among the fitted design's smooth terms.
    pub term_idx: usize,
    /// Raw statistic `W = max(0, 2 (ℓ_full − ℓ_null))`; NaN when the null fit failed.
    pub statistic_lr: f64,
    /// Reference degrees of freedom for the χ² null (Wood-style, falling back to the term's EDF).
    pub ref_df: f64,
    /// Bartlett factor actually applied (1.0 when no correction was applied).
    pub bartlett_factor: f64,
    /// Fixed-λ Bartlett factor, reported only when the estimated-λ correction was applied.
    pub bartlett_factor_conditional: Option<f64>,
    /// Extra mean shift of `W` attributed to ρ variation beyond the fixed-λ shift,
    /// reported only when the estimated-λ correction was applied.
    pub rho_variation_shift: Option<f64>,
    /// `statistic_lr / bartlett_factor`.
    pub statistic_corrected: f64,
    /// Upper-tail χ²(`ref_df`) probability of `statistic_lr`.
    pub p_value_uncorrected: f64,
    /// Upper-tail χ²(`ref_df`) probability of `statistic_corrected`.
    pub p_value_corrected: f64,
    /// Whether the correction moved the factor or the relative p-value by more than
    /// `SMOOTH_LR_MATERIAL_THRESHOLD`; always false when no correction was applied.
    pub material: bool,
    /// Which correction, if any, was applied.
    pub correction: SmoothLrCorrection,
}
8444
/// Threshold above which a smooth-term LR correction is flagged as `material`.
/// It is compared against both |bartlett_factor − 1| and the relative change
/// between the uncorrected and corrected p-values.
pub const SMOOTH_LR_MATERIAL_THRESHOLD: f64 = 0.10;
8449
8450fn fitted_rho_penalty_components(
8456 penalties: &[BlockwisePenalty],
8457 lambdas: &[f64],
8458 p_total: usize,
8459) -> Result<Vec<gam_terms::inference::lawley::RhoPenaltyComponent>, EstimationError> {
8460 if penalties.len() != lambdas.len() {
8461 return Err(EstimationError::InvalidInput(format!(
8462 "smooth_term_lr_inference: penalty/lambda count mismatch ({} penalties, {} lambdas)",
8463 penalties.len(),
8464 lambdas.len()
8465 )));
8466 }
8467 let mut components = Vec::with_capacity(penalties.len());
8468 for (idx, (penalty, &lambda)) in penalties.iter().zip(lambdas.iter()).enumerate() {
8469 if !(lambda.is_finite() && lambda >= 0.0) {
8470 return Err(EstimationError::InvalidInput(format!(
8471 "smooth_term_lr_inference: lambda[{idx}] is invalid: {lambda}"
8472 )));
8473 }
8474 let r = &penalty.col_range;
8475 if r.end > p_total {
8476 return Err(EstimationError::InvalidInput(format!(
8477 "smooth_term_lr_inference: penalty[{idx}] range {:?} exceeds coefficient dimension {p_total}",
8478 r
8479 )));
8480 }
8481 let mut s_component = Array2::<f64>::zeros((p_total, p_total));
8482 s_component
8483 .slice_mut(s![r.start..r.end, r.start..r.end])
8484 .scaled_add(lambda, &penalty.local);
8485 components.push(gam_terms::inference::lawley::RhoPenaltyComponent { s_component });
8486 }
8487 Ok(components)
8488}
8489
8490pub fn smooth_term_lr_inference_forspec(
8531 data: ArrayView2<'_, f64>,
8532 y: ArrayView1<'_, f64>,
8533 weights: ArrayView1<'_, f64>,
8534 offset: ArrayView1<'_, f64>,
8535 resolvedspec: &TermCollectionSpec,
8536 family: LikelihoodSpec,
8537 options: &FitOptions,
8538) -> Result<Vec<SmoothTermLrInference>, EstimationError> {
8539 use gam_terms::inference::lawley::{
8540 LAWLEY_PAIR_MATRIX_MAX_ROWS, known_scale_expected_jets_with_dispersion,
8541 lawley_lr_bartlett_factor, lawley_lr_mean_shift_with_rho_variation,
8542 };
8543
8544 let n = data.nrows();
8545 let full = fit_term_collection_forspec(
8548 data,
8549 y,
8550 weights,
8551 offset,
8552 resolvedspec,
8553 family.clone(),
8554 options,
8555 )?;
8556 let ll_full = full.fit.log_likelihood;
8557 let p_total = full.design.design.ncols();
8558 let lambdas = full.fit.lambdas.as_slice().ok_or_else(|| {
8559 EstimationError::InvalidInput(
8560 "smooth_term_lr_inference: non-contiguous lambda vector".to_string(),
8561 )
8562 })?;
8563 let s_lambda = weighted_blockwise_penalty_sum(&full.design.penalties, lambdas, p_total);
8564 let rho_penalty_components =
8565 fitted_rho_penalty_components(&full.design.penalties, lambdas, p_total)?;
8566 let rho_covariance = full.fit.artifacts.rho_covariance.as_ref().filter(|cov| {
8567 cov.nrows() == rho_penalty_components.len() && cov.ncols() == rho_penalty_components.len()
8568 });
8569 let full_design_dense = full.design.design.to_dense();
8571 let influence = full.fit.coefficient_influence();
8572 let family_disp = lawley_dispersion_for_family(&family, &full.fit);
8573
8574 let mut penalty_cursor = full.design.random_effect_ranges.len();
8577 let mut out = Vec::<SmoothTermLrInference>::new();
8578 for (term_idx, design_term) in full.design.smooth.terms.iter().enumerate() {
8579 let k = design_term.penalties_local.len();
8580 let block_start = penalty_cursor;
8581 penalty_cursor += k;
8582 if design_term.shape != ShapeConstraint::None {
8585 continue;
8586 }
8587 let coeff_range = design_term.coeff_range.clone();
8588 if coeff_range.start >= coeff_range.end || coeff_range.end > p_total {
8589 continue;
8590 }
8591 let edf = full.fit.per_term_edf(coeff_range.clone(), block_start, k);
8603 let ref_df = wood_reference_df(influence, &coeff_range).unwrap_or(edf.max(1e-12));
8604 if !(ref_df.is_finite() && ref_df > 0.0) {
8605 continue;
8606 }
8607
8608 let mut null_spec = resolvedspec.clone();
8611 let Some(spec_pos) = null_spec
8612 .smooth_terms
8613 .iter()
8614 .position(|t| t.name == design_term.name)
8615 else {
8616 continue;
8617 };
8618 null_spec.smooth_terms.remove(spec_pos);
8619 let null_fit = fit_term_collection_forspec(
8620 data,
8621 y,
8622 weights,
8623 offset,
8624 &null_spec,
8625 family.clone(),
8626 options,
8627 );
8628 let (statistic_lr, eta_null) = match null_fit {
8629 Ok(null) if null.fit.log_likelihood.is_finite() => {
8630 let w = (2.0 * (ll_full - null.fit.log_likelihood)).max(0.0);
8631 let mut eta = null.design.design.dot(&null.fit.beta);
8635 eta += &offset;
8636 (w, Some(eta))
8637 }
8638 _ => (f64::NAN, None),
8639 };
8640
8641 let chi2 = statrs::distribution::ChiSquared::new(ref_df).ok();
8642 let p_uncorrected = match (chi2.as_ref(), statistic_lr.is_finite()) {
8643 (Some(dist), true) => {
8644 use statrs::distribution::ContinuousCDF;
8645 (1.0 - dist.cdf(statistic_lr)).clamp(0.0, 1.0)
8646 }
8647 _ => f64::NAN,
8648 };
8649
8650 let mut bartlett_factor = 1.0;
8654 let mut bartlett_factor_conditional = None;
8655 let mut rho_variation_shift = None;
8656 let mut statistic_corrected = statistic_lr;
8657 let mut p_corrected = p_uncorrected;
8658 let mut correction = SmoothLrCorrection::None;
8659 if let (Some(eta), true, true) = (
8660 eta_null.as_ref(),
8661 statistic_lr.is_finite(),
8662 n <= LAWLEY_PAIR_MATRIX_MAX_ROWS,
8663 ) {
8664 let kappas: Option<Vec<_>> = (0..n)
8665 .map(|i| {
8666 known_scale_expected_jets_with_dispersion(&family, eta[i], family_disp)
8667 .and_then(|jets| jets.kappas().ok())
8668 })
8669 .collect();
8670 if let (Some(kappas), Some(dist)) = (kappas, chi2.as_ref()) {
8671 let fixed_factor = lawley_lr_bartlett_factor(
8672 full_design_dense.view(),
8673 &kappas,
8674 Some(s_lambda.view()),
8675 coeff_range.clone(),
8676 ref_df,
8677 );
8678 if let Ok(c_cond) = fixed_factor
8679 && c_cond.is_finite()
8680 && c_cond > 0.0
8681 {
8682 let mut c_applied = c_cond;
8683 correction = SmoothLrCorrection::LawleyLrFixedLambda;
8684 if let Some(cov) = rho_covariance
8685 && let Ok(total_shift) = lawley_lr_mean_shift_with_rho_variation(
8686 full_design_dense.view(),
8687 &kappas,
8688 s_lambda.view(),
8689 coeff_range.clone(),
8690 &rho_penalty_components,
8691 cov.view(),
8692 )
8693 {
8694 let mean_w = ref_df + total_shift;
8695 if let Some(c_est) =
8696 gam_terms::inference::higher_order::bartlett_factor_from_mean(
8697 mean_w, ref_df,
8698 )
8699 && c_est.is_finite()
8700 && c_est > 0.0
8701 {
8702 let conditional_shift = (c_cond - 1.0) * ref_df;
8703 c_applied = c_est;
8704 bartlett_factor_conditional = Some(c_cond);
8705 rho_variation_shift = Some(total_shift - conditional_shift);
8706 correction = SmoothLrCorrection::LawleyLrEstimatedLambda;
8707 }
8708 }
8709 use statrs::distribution::ContinuousCDF;
8710 bartlett_factor = c_applied;
8711 statistic_corrected = statistic_lr / c_applied;
8712 p_corrected = (1.0 - dist.cdf(statistic_corrected)).clamp(0.0, 1.0);
8713 }
8714 }
8715 }
8716
8717 let material = match correction {
8723 SmoothLrCorrection::LawleyLrEstimatedLambda
8724 | SmoothLrCorrection::LawleyLrFixedLambda => {
8725 let factor_move = (bartlett_factor - 1.0).abs();
8726 let p_denom = p_uncorrected.max(p_corrected).max(f64::MIN_POSITIVE);
8727 let p_move = if p_uncorrected.is_finite() && p_corrected.is_finite() {
8728 (p_corrected - p_uncorrected).abs() / p_denom
8729 } else {
8730 0.0
8731 };
8732 factor_move > SMOOTH_LR_MATERIAL_THRESHOLD || p_move > SMOOTH_LR_MATERIAL_THRESHOLD
8733 }
8734 SmoothLrCorrection::None => false,
8735 };
8736
8737 out.push(SmoothTermLrInference {
8738 name: design_term.name.clone(),
8739 term_idx,
8740 statistic_lr,
8741 ref_df,
8742 bartlett_factor,
8743 bartlett_factor_conditional,
8744 rho_variation_shift,
8745 statistic_corrected,
8746 p_value_uncorrected: p_uncorrected,
8747 p_value_corrected: p_corrected,
8748 material,
8749 correction,
8750 });
8751 }
8752 Ok(out)
8753}
8754
8755fn lawley_dispersion_for_family(family: &LikelihoodSpec, fit: &UnifiedFitResult) -> f64 {
8758 match family.response {
8759 gam_spec::ResponseFamily::Gaussian => {
8760 let sd = fit.standard_deviation;
8761 (sd * sd).max(f64::MIN_POSITIVE)
8762 }
8763 gam_spec::ResponseFamily::Gamma => {
8764 let shape = fit.standard_deviation;
8765 if shape.is_finite() && shape > 0.0 {
8766 1.0 / shape
8767 } else {
8768 1.0
8769 }
8770 }
8771 _ => 1.0,
8772 }
8773}
8774
8775fn wood_reference_df(influence: Option<&Array2<f64>>, coeff_range: &Range<usize>) -> Option<f64> {
8781 let f = influence?;
8782 let (start, end) = (coeff_range.start, coeff_range.end);
8783 if start >= end || end > f.nrows() || end > f.ncols() {
8784 return None;
8785 }
8786 let block = f.slice(s![start..end, start..end]);
8787 let tr = (0..block.nrows()).map(|i| block[[i, i]]).sum::<f64>();
8788 let tr2 = block.dot(&block).diag().sum();
8789 (tr.is_finite() && tr2.is_finite() && tr > 0.0 && tr2 > 0.0).then(|| (tr * tr / tr2).max(1e-12))
8790}