/// Builds first- and second-order log-kappa (ψ) derivatives of one spatial
/// smooth term's design and penalty matrices, expressed in the term's realized
/// coefficient basis.
///
/// Only `ThinPlate`, `ConstantCurvature`, `Matern` and `Duchon` bases are
/// handled. Every other basis, and a `term_idx` missing from either the
/// realized design or the resolved spec, yields `Ok(None)`.
///
/// The returned tuple holds, in order:
/// 0. the term's coefficient range within the full design (the smooth block is
///    taken to occupy the trailing columns of the design),
/// 1. the full design width `p_total`,
/// 2. the local first design derivative,
/// 3. the sum of all local first penalty derivatives,
/// 4. the local second design derivative,
/// 5. the sum of all local second penalty derivatives,
/// 6. the individual local first penalty derivatives,
/// 7. the individual local second penalty derivatives,
/// 8. an optional implicit (operator-form) design derivative. When it is
///    present, the explicit design derivatives in slots 2 and 4 are not
///    width-checked here.
///
/// If the realized term carries a joint null-space rotation `Q`:
/// - explicit design derivatives are right-multiplied by `Q`, or `Q` is
///   appended to the implicit operator;
/// - every penalty derivative `S` becomes `QᵀSQ`.
///
/// `Ok(None)` is also returned when:
/// - any derivative's shape disagrees with `Q` or with the term's coefficient
///   width, or
/// - no first penalty derivatives were produced, or
/// - the first and second penalty derivative counts differ.
///
/// # Errors
/// Failures from column selection, from the basis-derivative builders, and
/// from appending the rotation to an implicit operator.
///
/// # Panics
/// If a builder returns an implicit operator inside its `first` or `second`
/// result instead of in the bundle-level `implicit_operator` field.
fn try_build_spatial_term_log_kappa_derivative(
    data: ArrayView2<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    term_idx: usize,
) -> Result<
    Option<(
        Range<usize>,
        usize,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        Vec<Array2<f64>>,
        Vec<Array2<f64>>,
        Option<std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>>,
    )>,
    EstimationError,
> {
    // An index outside the realized design or the resolved spec has no
    // derivative information.
    let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
        return Ok(None);
    };
    let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
        return Ok(None);
    };

    // Dispatch on the basis kind; only the kernel-type bases below produce
    // ψ derivatives.
    let derivative_bundle = match &termspec.basis {
        SmoothBasisSpec::ThinPlate {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            // Re-apply the recorded input standardization and adjust the
            // length scale to match (presumably mirroring how the basis was
            // originally built).
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_length_scale_for_standardization(spec.length_scale, s);
            }
            build_thin_plate_basis_log_kappa_derivatives(x.view(), &spec_local)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Sphere { .. } => return Ok(None),
        SmoothBasisSpec::ConstantCurvature { feature_cols, spec } => {
            let x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            // NOTE(review): this builder is named `..._kappa_derivatives`, not
            // `..._log_kappa_derivatives`; confirm it differentiates in the same
            // ψ coordinate as the other branches.
            build_constant_curvature_basis_kappa_derivatives(x.view(), spec)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::MeasureJet { .. } => return Ok(None),
        SmoothBasisSpec::Matern {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            // Same standardization handling as the ThinPlate branch.
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_length_scale_for_standardization(spec.length_scale, s);
            }
            build_matern_basis_log_kappa_derivatives(x.view(), &spec_local)
                .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::Duchon {
            feature_cols,
            spec,
            input_scales,
        } => {
            let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
            let mut spec_local = spec.clone();
            // Duchon's length scale is optional, hence the `optional` variant.
            if let Some(s) = input_scales {
                apply_input_standardization(&mut x, s);
                spec_local.length_scale =
                    compensate_optional_length_scale_for_standardization(spec.length_scale, s);
            }
            // The derivatives are built against the realized centers,
            // identifiability transform and collocation points stored on the
            // term. A non-Duchon metadata record means there is nothing to build.
            let BasisMetadata::Duchon {
                centers,
                identifiability_transform,
                operator_collocation_points,
                radial_reparam,
                ..
            } = &smooth_term.metadata
            else {
                return Ok(None);
            };
            // Fall back to the realized radial reparameterisation when the spec
            // does not pin one.
            if spec_local.radial_reparam.is_none() {
                spec_local.radial_reparam = radial_reparam.clone();
            }
            gam_terms::basis::build_duchon_basis_log_kappa_derivativeswith_collocationwithworkspace(
                x.view(),
                &spec_local,
                centers.view(),
                identifiability_transform.as_ref(),
                operator_collocation_points
                    .as_ref()
                    .map(|points| points.view()),
                &mut BasisWorkspace::default(),
            )
            .map_err(EstimationError::from)?
        }
        SmoothBasisSpec::BSpline1D { .. }
        | SmoothBasisSpec::TensorBSpline { .. }
        | SmoothBasisSpec::ByVariable { .. }
        | SmoothBasisSpec::FactorSumToZero { .. }
        | SmoothBasisSpec::BySmooth { .. }
        | SmoothBasisSpec::FactorSmooth { .. }
        | SmoothBasisSpec::Pca { .. } => {
            return Ok(None);
        }
    };
    // The implicit operator (if any) is carried on the bundle itself. The
    // per-order results must leave theirs empty (asserted below).
    let mut implicit_operator = derivative_bundle.implicit_operator;
    let BasisPsiDerivativeResult {
        design_derivative: mut local_x_psi,
        penalties_derivative: mut local_s_psi,
        implicit_operator: local_implicit_first_unused,
    } = derivative_bundle.first;
    let BasisPsiSecondDerivativeResult {
        designsecond_derivative: mut local_x_psi_psi,
        penaltiessecond_derivative: mut local_s_psi_psi,
        implicit_operator: local_implicit_second_unused,
    } = derivative_bundle.second;
    assert!(local_implicit_first_unused.is_none());
    assert!(local_implicit_second_unused.is_none());

    // Map the local derivatives into the realized basis when the term carries
    // a joint null-space rotation Q.
    if let Some(rotation) = smooth_term.joint_null_rotation.as_ref() {
        let q = &rotation.rotation;
        if let Some(op) = implicit_operator.take() {
            implicit_operator = Some(op.append_full_transform(q).map_err(EstimationError::from)?);
        } else {
            // X ↦ X·Q requires X to have as many columns as Q has rows.
            if local_x_psi.ncols() != q.nrows() || local_x_psi_psi.ncols() != q.nrows() {
                return Ok(None);
            }
            local_x_psi = fast_ab(&local_x_psi, q);
            local_x_psi_psi = fast_ab(&local_x_psi_psi, q);
        }
        // S ↦ QᵀSQ. Yields `None` unless S is square with Q's row count.
        let rotate_penalty = |s_local: Array2<f64>| -> Option<Array2<f64>> {
            if s_local.nrows() != q.nrows() || s_local.ncols() != q.nrows() {
                return None;
            }
            let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
            Some(gam_linalg::faer_ndarray::fast_ab(&qt_s, q))
        };
        let Some(rotated_s_psi) = local_s_psi
            .into_iter()
            .map(|s| rotate_penalty(s))
            .collect::<Option<Vec<_>>>()
        else {
            return Ok(None);
        };
        local_s_psi = rotated_s_psi;
        let Some(rotated_s_psi_psi) = local_s_psi_psi
            .into_iter()
            .map(|s| rotate_penalty(s))
            .collect::<Option<Vec<_>>>()
        else {
            return Ok(None);
        };
        local_s_psi_psi = rotated_s_psi_psi;
    }
    let implicit_operator = implicit_operator.map(std::sync::Arc::new);

    // Every derivative must match the term's realized coefficient width.
    // With an implicit operator only the operator's output width is checked.
    if let Some(ref op) = implicit_operator {
        if op.p_out() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
    } else {
        if local_x_psi.ncols() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
        if local_x_psi_psi.ncols() != smooth_term.coeff_range.len() {
            return Ok(None);
        }
    }
    // At least one penalty derivative is required, and first/second lists
    // must pair up.
    if local_s_psi.is_empty() || local_s_psi.len() != local_s_psi_psi.len() {
        return Ok(None);
    }
    if local_s_psi.iter().any(|s| {
        s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
    }) {
        return Ok(None);
    }
    if local_s_psi_psi.iter().any(|s| {
        s.nrows() != smooth_term.coeff_range.len() || s.ncols() != smooth_term.coeff_range.len()
    }) {
        return Ok(None);
    }

    // Smooth columns are taken to be the trailing block of the full design,
    // so the term's local range is offset by the smooth block's start.
    let p_total = design.design.ncols();
    let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
    let global_range = (smooth_start + smooth_term.coeff_range.start)
        ..(smooth_start + smooth_term.coeff_range.end);

    Ok(Some((
        global_range,
        p_total,
        local_x_psi,
        // Summed first penalty derivative across all penalty components.
        local_s_psi.iter().fold(
            Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
            |acc, m| acc + m,
        ),
        local_x_psi_psi,
        // Summed second penalty derivative across all penalty components.
        local_s_psi_psi.iter().fold(
            Array2::<f64>::zeros((smooth_term.coeff_range.len(), smooth_term.coeff_range.len())),
            |acc, m| acc + m,
        ),
        local_s_psi,
        local_s_psi_psi,
        implicit_operator,
    )))
}
227
228fn try_build_spatial_log_kappa_hyper_dirs(
229 data: ArrayView2<'_, f64>,
230 resolvedspec: &TermCollectionSpec,
231 design: &TermCollectionDesign,
232 spatial_terms: &[usize],
233) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
234 let Some(info_list) =
241 try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, spatial_terms)?
242 else {
243 return Ok(None);
244 };
245 Ok(Some(spatial_log_kappa_hyper_dirs_frominfo_list(info_list)?))
246}
247
248pub(crate) fn try_build_latent_coord_hyper_dirs(
249 latent: std::sync::Arc<gam_terms::latent::LatentCoordValues>,
250 resolvedspec: &TermCollectionSpec,
251 design: &TermCollectionDesign,
252 latent_terms: &[gam_problem::types::SmoothTermIdx],
253 analytic_rho_count: usize,
254) -> Result<Option<Vec<DirectionalHyperParam>>, EstimationError> {
255 if latent_terms.is_empty() || latent.is_empty() {
256 return Ok(None);
257 }
258 if latent_terms.len() != 1 {
259 crate::bail_invalid_estim!(
260 "LatentCoord standard-fit hyper_dirs currently require exactly one latent smooth term"
261 .to_string(),
262 );
263 }
264 let term_idx = latent_terms[0];
265 let smooth_term = design.smooth.terms.get(term_idx.get()).ok_or_else(|| {
266 EstimationError::InvalidInput(format!(
267 "LatentCoord term index {term_idx} out of bounds for realized smooth design"
268 ))
269 })?;
270 let termspec = resolvedspec
271 .smooth_terms
272 .get(term_idx.get())
273 .ok_or_else(|| {
274 EstimationError::InvalidInput(format!(
275 "LatentCoord term index {term_idx} out of bounds for resolved smooth spec"
276 ))
277 })?;
278 let p_total = design.design.ncols();
279 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
280 let global_range = (smooth_start + smooth_term.coeff_range.start)
281 ..(smooth_start + smooth_term.coeff_range.end);
282
283 let operator = match (&termspec.basis, &smooth_term.metadata) {
288 (
289 SmoothBasisSpec::Matern { .. },
290 BasisMetadata::Matern {
291 centers,
292 length_scale,
293 nu,
294 include_intercept,
295 identifiability_transform,
296 ..
297 },
298 ) => gam_terms::basis::LatentCoordDesignDerivative::new_matern(
299 latent.clone(),
300 std::sync::Arc::new(centers.clone()),
301 *length_scale,
302 *nu,
303 *include_intercept,
304 identifiability_transform.clone(),
305 )
306 .map_err(EstimationError::from)?,
307 (
308 SmoothBasisSpec::Duchon { .. },
309 BasisMetadata::Duchon {
310 centers,
311 length_scale,
312 power,
313 nullspace_order,
314 identifiability_transform,
315 ..
316 },
317 ) => gam_terms::basis::LatentCoordDesignDerivative::new_duchon(
318 latent.clone(),
319 std::sync::Arc::new(centers.clone()),
320 *length_scale,
321 *power,
322 *nullspace_order,
323 identifiability_transform.clone(),
324 )
325 .map_err(EstimationError::from)?,
326 (
327 SmoothBasisSpec::Sphere { .. },
328 BasisMetadata::Sphere {
329 centers,
330 penalty_order,
331 method,
332 constraint_transform,
333 ..
334 },
335 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
336 gam_terms::basis::LatentCoordDesignDerivative::new_sphere(
337 latent.clone(),
338 std::sync::Arc::new(centers.clone()),
339 *penalty_order,
340 constraint_transform.clone(),
341 )
342 .map_err(EstimationError::from)?
343 }
344 (
345 SmoothBasisSpec::BSpline1D { spec, .. },
346 BasisMetadata::BSpline1D {
347 knots,
348 identifiability_transform,
349 periodic,
350 degree: meta_degree,
351 ..
352 },
353 ) => {
354 let effective_degree = meta_degree.unwrap_or(spec.degree);
358 if let Some((domain_start, period, num_basis)) = periodic {
359 gam_terms::basis::LatentCoordDesignDerivative::new_periodic_bspline(
360 latent.clone(),
361 (*domain_start, *domain_start + *period),
362 effective_degree,
363 *num_basis,
364 identifiability_transform.clone(),
365 )
366 .map_err(EstimationError::from)?
367 } else {
368 gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
369 latent.clone(),
370 vec![knots.clone()],
371 vec![effective_degree],
372 identifiability_transform.clone(),
373 )
374 .map_err(EstimationError::from)?
375 }
376 }
377 (
378 SmoothBasisSpec::TensorBSpline { .. },
379 BasisMetadata::TensorBSpline {
380 knots,
381 degrees,
382 identifiability_transform,
383 ..
384 },
385 ) => gam_terms::basis::LatentCoordDesignDerivative::new_tensor_bspline(
386 latent.clone(),
387 knots.clone(),
388 degrees.clone(),
389 identifiability_transform.clone(),
390 )
391 .map_err(EstimationError::from)?,
392 (SmoothBasisSpec::Pca { .. }, BasisMetadata::Pca { basis_matrix, .. }) => {
393 gam_terms::basis::LatentCoordDesignDerivative::new_pca(
394 latent.clone(),
395 std::sync::Arc::new(basis_matrix.clone()),
396 )
397 .map_err(EstimationError::from)?
398 }
399 _ => return Ok(None),
400 };
401 if operator.p_out() != global_range.len() {
402 crate::bail_invalid_estim!(
403 "LatentCoord derivative width mismatch for term '{}': operator p={}, coeff range={}",
404 smooth_term.name,
405 operator.p_out(),
406 global_range.len()
407 );
408 }
409 let operator = std::sync::Arc::new(operator);
410 let mut hyper_dirs = Vec::with_capacity(operator.n_axes());
411 for flat_axis in 0..operator.n_axes() {
412 let dir = DirectionalHyperParam::new_compact(
413 gam_solve::estimate::reml::HyperDesignDerivative::from_latent_coord(
414 operator.clone(),
415 flat_axis,
416 global_range.clone(),
417 p_total,
418 ),
419 Vec::new(),
420 None,
421 None,
422 )?
423 .not_penalty_like();
424 hyper_dirs.push(dir);
425 }
426 let direct_dim = latent_coord_direct_hyper_count(latent.id_mode(), latent.latent_dim());
427 if analytic_rho_count + direct_dim > 0 {
428 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::from(Array2::<f64>::zeros((
429 design.design.nrows(),
430 p_total,
431 )));
432 for _ in 0..analytic_rho_count {
433 hyper_dirs.push(
434 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
435 .not_penalty_like(),
436 );
437 }
438 for _ in 0..direct_dim {
439 hyper_dirs.push(
440 DirectionalHyperParam::new_compact(zero_x.clone(), Vec::new(), None, None)?
441 .not_penalty_like(),
442 );
443 }
444 }
445 Ok(Some(hyper_dirs))
446}
447
448fn latent_coord_direct_hyper_count(
449 id_mode: &gam_terms::latent::LatentIdMode,
450 latent_dim: usize,
451) -> usize {
452 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
453 match id_mode {
454 LatentIdMode::AuxPrior { strength, .. } => match strength {
455 AuxPriorStrength::Auto => 1,
456 AuxPriorStrength::Fixed(_) => 0,
457 },
458 LatentIdMode::AuxPriorDimSelection { strength, .. } => {
459 latent_dim
460 + match strength {
461 AuxPriorStrength::Auto => 1,
462 AuxPriorStrength::Fixed(_) => 0,
463 }
464 }
465 LatentIdMode::DimSelection { .. } => latent_dim,
466 LatentIdMode::IsometryToReference { strength, .. } => match strength {
469 AuxPriorStrength::Auto => 1,
470 AuxPriorStrength::Fixed(_) => 0,
471 },
472 LatentIdMode::AuxOutcome { head, .. } => head.n_coeffs(latent_dim) + latent_dim,
475 LatentIdMode::None => 0,
476 }
477}
478
479fn latent_coord_initial_direct_hypers(
480 id_mode: &gam_terms::latent::LatentIdMode,
481 latent_dim: usize,
482) -> Result<Array1<f64>, EstimationError> {
483 use gam_terms::latent::{AuxPriorStrength, LatentIdMode};
484 let mut values = Vec::with_capacity(latent_coord_direct_hyper_count(id_mode, latent_dim));
485 match id_mode {
486 LatentIdMode::AuxPrior { strength, .. } => {
487 if matches!(strength, AuxPriorStrength::Auto) {
488 values.push(0.0);
489 }
490 }
491 LatentIdMode::AuxPriorDimSelection {
492 strength,
493 init_log_precision,
494 ..
495 } => {
496 if matches!(strength, AuxPriorStrength::Auto) {
497 values.push(0.0);
498 }
499 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
500 }
501 LatentIdMode::DimSelection { init_log_precision } => {
502 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
503 }
504 LatentIdMode::IsometryToReference { strength, .. } => {
505 if matches!(strength, AuxPriorStrength::Auto) {
506 values.push(0.0);
507 }
508 }
509 LatentIdMode::AuxOutcome {
510 head,
511 init_log_precision,
512 } => {
513 values.extend(std::iter::repeat_n(0.0, head.n_coeffs(latent_dim)));
517 append_latent_ard_seed(&mut values, init_log_precision.as_ref(), latent_dim)?;
518 }
519 LatentIdMode::None => {}
520 }
521 Ok(Array1::from_vec(values))
522}
523
524fn append_latent_ard_seed(
525 values: &mut Vec<f64>,
526 init: Option<&Array1<f64>>,
527 latent_dim: usize,
528) -> Result<(), EstimationError> {
529 if let Some(init) = init {
530 if init.len() != latent_dim {
531 crate::bail_invalid_estim!(
532 "latent dim_selection init_log_precision length mismatch: got {}, expected {}",
533 init.len(),
534 latent_dim
535 );
536 }
537 values.extend(init.iter().copied());
538 } else {
539 values.extend(std::iter::repeat_n(0.0, latent_dim));
540 }
541 Ok(())
542}
543
/// A scalar objective contribution together with its gradient over the full θ
/// vector.
///
/// Produced by the latent id-prior and analytic-penalty evaluators and added
/// onto an existing `(cost, gradient, hessian)` evaluation.
struct LatentIdObjectiveContribution {
    /// Amount to add to the objective's cost.
    cost: f64,
    /// Gradient of `cost`, one entry per θ component (same length as θ).
    gradient: Array1<f64>,
}
548
549fn latent_id_objective_contribution(
550 theta: &Array1<f64>,
551 rho_dim: usize,
552 analytic_rho_count: usize,
553 latent: &gam_terms::latent::LatentCoordValues,
554) -> Result<LatentIdObjectiveContribution, EstimationError> {
555 use gam_terms::latent::{AuxPriorStrength, LatentIdMode, aux_prior_targets};
556 let n_obs = latent.n_obs();
557 let latent_dim = latent.latent_dim();
558 let flat_len = latent.len();
559 let mut gradient = Array1::<f64>::zeros(theta.len());
560 let t_start = rho_dim;
561 let direct_start = t_start + flat_len + analytic_rho_count;
562 if theta.len() < direct_start {
563 crate::bail_invalid_estim!(
564 "latent-coordinate theta too short for id objective: got {}, need at least {}",
565 theta.len(),
566 direct_start
567 );
568 }
569 let t = latent.as_matrix();
570 let mut cost = 0.0;
571 let mut cursor = direct_start;
572
573 match latent.id_mode() {
574 LatentIdMode::AuxPrior {
575 u,
576 family,
577 strength,
578 }
579 | LatentIdMode::AuxPriorDimSelection {
580 u,
581 family,
582 strength,
583 ..
584 } => {
585 let (log_mu, mu) = match strength {
586 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
587 AuxPriorStrength::Auto => {
588 let log_mu = theta[cursor];
589 cursor += 1;
590 (log_mu, log_mu.exp())
591 }
592 };
593 let targets = aux_prior_targets(t.view(), u.view(), *family)
594 .map_err(EstimationError::InvalidInput)?;
595 let residual = &t - &targets;
596 let q = residual.iter().map(|v| v * v).sum::<f64>();
597 let k = (n_obs * latent_dim) as f64;
604 cost += 0.5 * mu * q - 0.5 * k * log_mu;
605
606 let projected_residual = aux_prior_targets(residual.view(), u.view(), *family)
607 .map_err(EstimationError::InvalidInput)?;
608 let grad_base = residual - projected_residual;
609 for n in 0..n_obs {
610 for axis in 0..latent_dim {
611 gradient[t_start + n * latent_dim + axis] += mu * grad_base[[n, axis]];
612 }
613 }
614 if matches!(strength, AuxPriorStrength::Auto) {
615 gradient[direct_start] += 0.5 * mu * q - 0.5 * k;
616 }
617 }
618 LatentIdMode::IsometryToReference { reference, strength } => {
619 if reference.dim() != (n_obs, latent_dim) {
626 crate::bail_invalid_estim!(
627 "IsometryToReference reference shape {:?} must equal (n_obs, latent_dim) = ({}, {})",
628 reference.dim(),
629 n_obs,
630 latent_dim
631 );
632 }
633 let mu_slot = cursor;
634 let (log_mu, mu) = match strength {
635 AuxPriorStrength::Fixed(mu) => (mu.ln(), *mu),
636 AuxPriorStrength::Auto => {
637 let log_mu = theta[cursor];
638 cursor += 1;
639 (log_mu, log_mu.exp())
640 }
641 };
642 let residual = &t - reference;
643 let q = residual.iter().map(|v| v * v).sum::<f64>();
644 let k = (n_obs * latent_dim) as f64;
648 cost += 0.5 * mu * q - 0.5 * k * log_mu;
649 for n in 0..n_obs {
650 for axis in 0..latent_dim {
651 gradient[t_start + n * latent_dim + axis] += mu * residual[[n, axis]];
652 }
653 }
654 if matches!(strength, AuxPriorStrength::Auto) {
655 gradient[mu_slot] += 0.5 * mu * q - 0.5 * k;
656 }
657 }
658 LatentIdMode::AuxOutcome { head, .. } => {
659 let n_coeffs = head.n_coeffs(latent_dim);
667 let coeffs = theta
668 .slice(ndarray::s![cursor..cursor + n_coeffs])
669 .to_owned();
670 let (head_nll, grad_coeffs, grad_t) = head
671 .neg_loglik_and_grad(t.view(), coeffs.view())
672 .map_err(EstimationError::InvalidInput)?;
673 cost += head_nll;
674 for (offset, &g) in grad_coeffs.iter().enumerate() {
675 gradient[cursor + offset] += g;
676 }
677 for n in 0..n_obs {
678 for axis in 0..latent_dim {
679 gradient[t_start + n * latent_dim + axis] += grad_t[[n, axis]];
680 }
681 }
682 cursor += n_coeffs;
683 }
684 LatentIdMode::DimSelection { .. } | LatentIdMode::None => {}
685 }
686
687 match latent.id_mode() {
688 LatentIdMode::AuxPriorDimSelection { .. }
689 | LatentIdMode::DimSelection { .. }
690 | LatentIdMode::AuxOutcome { .. } => {
691 for axis in 0..latent_dim {
692 let log_alpha = theta[cursor + axis];
693 let alpha = log_alpha.exp();
694 let mut q_axis = 0.0;
695 for n in 0..n_obs {
696 let flat_idx = n * latent_dim + axis;
697 let value = latent.as_flat()[flat_idx];
698 q_axis += value * value;
699 gradient[t_start + flat_idx] += alpha * value;
700 }
701 cost += 0.5 * alpha * q_axis - 0.5 * n_obs as f64 * log_alpha;
702 gradient[cursor + axis] += 0.5 * alpha * q_axis - 0.5 * n_obs as f64;
703 }
704 cursor += latent_dim;
705 }
706 LatentIdMode::AuxPrior { .. }
707 | LatentIdMode::IsometryToReference { .. }
708 | LatentIdMode::None => {}
709 }
710
711 if cursor != theta.len() {
712 crate::bail_invalid_estim!(
713 "latent-coordinate direct hyperparameter length mismatch: consumed {}, theta len {}",
714 cursor,
715 theta.len()
716 );
717 }
718 Ok(LatentIdObjectiveContribution { cost, gradient })
719}
720
721fn add_latent_id_objective_to_eval(
722 theta: &Array1<f64>,
723 rho_dim: usize,
724 analytic_rho_count: usize,
725 latent: &gam_terms::latent::LatentCoordValues,
726 eval: &mut (
727 f64,
728 Array1<f64>,
729 gam_problem::HessianResult,
730 ),
731) -> Result<(), EstimationError> {
732 let contribution =
733 latent_id_objective_contribution(theta, rho_dim, analytic_rho_count, latent)?;
734 eval.0 += contribution.cost;
735 if eval.1.len() != contribution.gradient.len() {
736 crate::bail_invalid_estim!(
737 "latent-coordinate REML gradient length mismatch: base={}, id={}",
738 eval.1.len(),
739 contribution.gradient.len()
740 );
741 }
742 eval.1 += &contribution.gradient;
743 if eval.2.is_analytic() {
744 eval.2 = gam_problem::HessianResult::Unavailable;
745 }
746 Ok(())
747}
748
749fn analytic_penalty_objective_contribution(
750 theta: &Array1<f64>,
751 rho_dim: usize,
752 latent: &gam_terms::latent::LatentCoordValues,
753 registry: &gam_terms::AnalyticPenaltyRegistry,
754) -> Result<LatentIdObjectiveContribution, EstimationError> {
755 let flat_len = latent.len();
756 let t_start = rho_dim;
757 let t_end = t_start + flat_len;
758 let rho_start = t_end;
759 let rho_end = rho_start + registry.total_rho_count();
760 if theta.len() < rho_end {
761 crate::bail_invalid_estim!(
762 "latent-coordinate theta too short for analytic penalties: got {}, need at least {}",
763 theta.len(),
764 rho_end
765 );
766 }
767 let target_t = theta.slice(s![t_start..t_end]);
768 let rho = theta.slice(s![rho_start..rho_end]);
769 let mut cost = 0.0_f64;
770 let mut gradient = Array1::<f64>::zeros(theta.len());
771 for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(registry.rho_layout()) {
772 let rho_local = rho.slice(s![rho_slice.clone()]);
773 match tier {
774 gam_terms::PenaltyTier::Psi => {
775 cost += penalty.value(target_t.view(), rho_local);
776 let grad = penalty.grad_target(target_t.view(), rho_local);
777 if grad.len() != flat_len {
778 crate::bail_invalid_estim!(
779 "analytic penalty {name:?} gradient length mismatch: got {}, expected {}",
780 grad.len(),
781 flat_len
782 );
783 }
784 for i in 0..flat_len {
785 gradient[t_start + i] += grad[i];
786 }
787 let grad_rho_local = penalty.grad_rho(target_t.view(), rho_local);
788 if grad_rho_local.len() != rho_slice.len() {
789 crate::bail_invalid_estim!(
790 "analytic penalty {name:?} rho-gradient length mismatch: got {}, expected {}",
791 grad_rho_local.len(),
792 rho_slice.len()
793 );
794 }
795 for local_idx in 0..grad_rho_local.len() {
796 gradient[rho_start + rho_slice.start + local_idx] += grad_rho_local[local_idx];
797 }
798 }
799 gam_terms::PenaltyTier::Beta => {}
800 gam_terms::PenaltyTier::Rho => {}
801 }
802 }
803 Ok(LatentIdObjectiveContribution { cost, gradient })
804}
805
806fn add_analytic_penalty_hessian_to_eval(
807 theta: &Array1<f64>,
808 rho_dim: usize,
809 latent: &gam_terms::latent::LatentCoordValues,
810 registry: &gam_terms::AnalyticPenaltyRegistry,
811 eval: &mut (
812 f64,
813 Array1<f64>,
814 gam_problem::HessianResult,
815 ),
816) -> Result<(), EstimationError> {
817 let flat_len = latent.len();
818 let t_start = rho_dim;
819 let t_end = t_start + flat_len;
820 let rho_start = t_end;
821 let rho_end = rho_start + registry.total_rho_count();
822 if theta.len() < rho_end {
823 crate::bail_invalid_estim!(
824 "latent-coordinate theta too short for analytic penalty Hessian: got {}, need at least {}",
825 theta.len(),
826 rho_end
827 );
828 }
829 let gam_problem::HessianResult::Analytic(hessian) = &mut eval.2 else {
830 if eval.2.is_analytic() {
831 eval.2 = gam_problem::HessianResult::Unavailable;
832 }
833 return Ok(());
834 };
835 if hessian.dim() != (theta.len(), theta.len()) {
836 crate::bail_invalid_estim!(
837 "analytic penalty Hessian target shape mismatch: got {}x{}, expected {}x{}",
838 hessian.nrows(),
839 hessian.ncols(),
840 theta.len(),
841 theta.len()
842 );
843 }
844 let target_t = theta.slice(s![t_start..t_end]);
845 let rho = theta.slice(s![rho_start..rho_end]);
846 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(registry.rho_layout())
847 {
848 let rho_local = rho.slice(s![rho_slice]);
849 if !matches!(tier, gam_terms::PenaltyTier::Psi) {
850 continue;
851 }
852 if let Some(diag) = penalty.hessian_diag(target_t.view(), rho_local) {
853 if diag.len() != flat_len {
854 crate::bail_invalid_estim!(
855 "analytic penalty Hessian diagonal length mismatch: got {}, expected {}",
856 diag.len(),
857 flat_len
858 );
859 }
860 for i in 0..flat_len {
861 hessian[[t_start + i, t_start + i]] += diag[i];
862 }
863 continue;
864 }
865 let mut probe = Array1::<f64>::zeros(flat_len);
866 for col in 0..flat_len {
867 probe[col] = 1.0;
868 let hv = penalty.hvp(target_t.view(), rho_local, probe.view());
869 if hv.len() != flat_len {
870 crate::bail_invalid_estim!(
871 "analytic penalty Hessian-vector length mismatch: got {}, expected {}",
872 hv.len(),
873 flat_len
874 );
875 }
876 for row in 0..flat_len {
877 hessian[[t_start + row, t_start + col]] += hv[row];
878 }
879 probe[col] = 0.0;
880 }
881 }
882 Ok(())
883}
884
885fn add_analytic_penalty_objective_to_eval(
886 theta: &Array1<f64>,
887 rho_dim: usize,
888 latent: &gam_terms::latent::LatentCoordValues,
889 registry: &gam_terms::AnalyticPenaltyRegistry,
890 eval: &mut (
891 f64,
892 Array1<f64>,
893 gam_problem::HessianResult,
894 ),
895) -> Result<(), EstimationError> {
896 let contribution = analytic_penalty_objective_contribution(theta, rho_dim, latent, registry)?;
897 eval.0 += contribution.cost;
898 if eval.1.len() != contribution.gradient.len() {
899 crate::bail_invalid_estim!(
900 "latent-coordinate REML gradient length mismatch: base={}, analytic_penalty={}",
901 eval.1.len(),
902 contribution.gradient.len()
903 );
904 }
905 eval.1 += &contribution.gradient;
906 add_analytic_penalty_hessian_to_eval(theta, rho_dim, latent, registry, eval)?;
907 Ok(())
908}
909
/// Converts per-axis spatial ψ derivative records into one
/// `DirectionalHyperParam` per record, in input order.
///
/// For record `i`, the direction carries:
/// - the first design derivative: `x_psi_local` embedded at `global_range`, or
///   the implicit operator's `First(implicit_axis)` level when an operator is
///   present;
/// - the first penalty derivatives, paired with `penalty_indices`;
/// - second design derivatives: the diagonal at slot `i`, plus, for records
///   with both cross designs and a group id, cross entries at slots
///   `base + b_axis`, where `base` is the group's first record index;
/// - second penalty derivatives: the diagonal at slot `i`, plus, for records
///   with both a cross-penalty provider and a group id, a lazily evaluated
///   provider for cross terms with the other group members, and those
///   members' indices.
///
/// Every direction is marked `not_penalty_like`.
///
/// # Errors
/// Propagates errors from `DirectionalHyperParam::new_compact`. The "missing
/// spatial hyper axis" error cannot occur here, because every record with a
/// group id is inserted into that group's index list.
fn spatial_log_kappa_hyper_dirs_frominfo_list(
    info_list: Vec<SpatialPsiDerivative>,
) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
    use gam_solve::estimate::reml::ImplicitDerivLevel;
    use std::collections::HashMap;

    let log_kappa_dim = info_list.len();
    let group_ids: Vec<Option<usize>> = info_list.iter().map(|e| e.aniso_group_id).collect();
    // Anisotropy group id -> indices (into `info_list`) of its records, in
    // input order.
    let mut group_indices_map: HashMap<usize, Vec<usize>> = HashMap::new();
    for (idx, gid) in group_ids.iter().enumerate() {
        if let Some(g) = gid {
            group_indices_map.entry(*g).or_default().push(idx);
        }
    }

    let mut hyper_dirs = Vec::with_capacity(log_kappa_dim);
    for (i, info) in info_list.into_iter().enumerate() {
        let SpatialPsiDerivative {
            penalty_index: _,
            penalty_indices,
            global_range,
            total_p,
            x_psi_local,
            s_psi_components_local,
            x_psi_psi_local,
            s_psi_psi_components_local,
            aniso_group_id,
            aniso_cross_designs,
            aniso_cross_penalty_provider,
            implicit_operator,
            implicit_axis,
        } = info;

        // Second design derivatives, one slot per hyper. Slot `i` holds the
        // diagonal term, taken from the implicit operator when present.
        let mut xsecond = vec![None; log_kappa_dim];
        xsecond[i] = Some(if let Some(ref op) = implicit_operator {
            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                op.clone(),
                ImplicitDerivLevel::SecondDiag(implicit_axis),
                global_range.clone(),
                total_p,
            )
        } else {
            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                x_psi_psi_local,
                global_range.clone(),
                total_p,
            )
        });
        // Cross design derivatives with other axes of the same anisotropy group.
        if let Some(cross_designs) = aniso_cross_designs {
            if let Some(gid) = aniso_group_id {
                // NOTE(review): the partner slot is `base + b_axis`. That assumes
                // the group's records are contiguous starting at its first index.
                // The penalty provider below instead resolves partners by
                // position in the group's index list. Confirm the contiguity
                // invariant holds.
                let base = group_indices_map
                    .get(&gid)
                    .and_then(|v| v.first().copied())
                    .unwrap_or(i);
                for (b_axis, cross_mat) in cross_designs.into_iter() {
                    let j = base + b_axis;
                    // Partners outside the hyper range are silently skipped.
                    if j < log_kappa_dim {
                        xsecond[j] = Some(if let Some(ref op) = implicit_operator {
                            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                                op.clone(),
                                ImplicitDerivLevel::SecondCross(implicit_axis, b_axis),
                                global_range.clone(),
                                total_p,
                            )
                        } else {
                            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                                cross_mat,
                                global_range.clone(),
                                total_p,
                            )
                        });
                    }
                }
            }
        }
        // First and diagonal-second penalty derivatives, embedded into the full
        // coefficient space and tagged with their penalty indices.
        // NOTE(review): `zip` silently truncates if `penalty_indices` and the
        // component lists differ in length.
        let s_components = penalty_indices
            .iter()
            .copied()
            .zip(s_psi_components_local.into_iter().map(|local| {
                gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    total_p,
                )
            }))
            .collect::<Vec<_>>();
        let s2_components = penalty_indices
            .iter()
            .copied()
            .zip(s_psi_psi_components_local.into_iter().map(|local| {
                gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                    local,
                    global_range.clone(),
                    total_p,
                )
            }))
            .collect::<Vec<_>>();
        let mut ssecond_components = vec![None; log_kappa_dim];
        ssecond_components[i] = Some(s2_components);
        // Cross second penalty derivatives are produced lazily, on request for
        // a partner hyper index `j` within the same anisotropy group.
        let mut penaltysecond_partner_indices: Option<Vec<usize>> = None;
        let penaltysecond_component_provider =
            if let (Some(provider), Some(gid)) = (aniso_cross_penalty_provider, aniso_group_id) {
                let group_indices = group_indices_map.get(&gid).cloned().unwrap_or_default();
                let axis_in_group =
                    group_indices
                        .iter()
                        .position(|&idx| idx == i)
                        .ok_or_else(|| {
                            EstimationError::InvalidInput(format!(
                                "missing spatial hyper axis {} in anisotropy group {}",
                                i, gid
                            ))
                        })?;
                // Partners are all other records of the group.
                penaltysecond_partner_indices = Some(
                    group_indices
                        .iter()
                        .copied()
                        .filter(|&idx| idx != i)
                        .collect(),
                );
                let penalty_indices_inner = penalty_indices.clone();
                let global_range_inner = global_range.clone();
                let total_p_inner = total_p;
                let group_indices_inner = group_indices;
                Some(std::sync::Arc::new(
                    move |j: usize| -> Result<
                        Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
                        EstimationError,
                    > {
                        // Hypers outside this group, or this record itself, have
                        // no cross term.
                        let Some(other_axis_in_group) =
                            group_indices_inner.iter().position(|&idx| idx == j)
                        else {
                            return Ok(None);
                        };
                        if other_axis_in_group == axis_in_group {
                            return Ok(None);
                        }
                        let cross_pens = provider(other_axis_in_group)?;
                        if cross_pens.is_empty() {
                            return Ok(None);
                        }
                        Ok(Some(
                            penalty_indices_inner
                                .iter()
                                .copied()
                                .zip(cross_pens.into_iter().map(|local| {
                                    gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
                                        local,
                                        global_range_inner.clone(),
                                        total_p_inner,
                                    )
                                }))
                                .map(|(penalty_index, matrix)| {
                                    gam_solve::estimate::reml::PenaltyDerivativeComponent {
                                        penalty_index,
                                        matrix,
                                    }
                                })
                                .collect(),
                        ))
                    },
                )
                    as std::sync::Arc<
                        dyn Fn(
                                usize,
                            ) -> Result<
                                Option<Vec<gam_solve::estimate::reml::PenaltyDerivativeComponent>>,
                                EstimationError,
                            > + Send
                            + Sync
                            + 'static,
                    >)
            } else {
                None
            };
        // First design derivative, from the implicit operator when present.
        let x_first_hyper = if let Some(ref op) = implicit_operator {
            gam_solve::estimate::reml::HyperDesignDerivative::from_implicit(
                op.clone(),
                ImplicitDerivLevel::First(implicit_axis),
                global_range.clone(),
                total_p,
            )
        } else {
            gam_solve::estimate::reml::HyperDesignDerivative::from_embedded(
                x_psi_local,
                global_range.clone(),
                total_p,
            )
        };
        let mut dir = DirectionalHyperParam::new_compact(
            x_first_hyper,
            s_components,
            Some(xsecond),
            Some(ssecond_components),
        )?
        .not_penalty_like();
        if let Some(provider) = penaltysecond_component_provider {
            dir = dir.with_penaltysecond_component_provider(provider);
        }
        if let Some(partner_indices) = penaltysecond_partner_indices {
            dir = dir.with_penaltysecond_partner_indices(partner_indices);
        }
        hyper_dirs.push(dir);
    }
    Ok(hyper_dirs)
}
1128
1129pub(crate) fn spatial_dims_per_term(
1135 resolvedspec: &TermCollectionSpec,
1136 spatial_terms: &[usize],
1137) -> Vec<usize> {
1138 spatial_terms
1139 .iter()
1140 .map(|&term_idx| {
1141 if let Some(mj) = measure_jet_term_spec(resolvedspec, term_idx) {
1142 measure_jet_psi_dim(mj)
1145 } else if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
1146 get_spatial_feature_dim(resolvedspec, term_idx).unwrap_or(1)
1147 } else {
1148 1
1149 }
1150 })
1151 .collect()
1152}
1153
1154fn has_aniso_terms(resolvedspec: &TermCollectionSpec, spatial_terms: &[usize]) -> bool {
1158 spatial_terms
1159 .iter()
1160 .any(|&term_idx| spatial_term_uses_per_axis_psi(resolvedspec, term_idx))
1161}
1162
/// Expands to θ-keyed memo accessors (`memoized_cost`, `memoized_eval`,
/// `store_eval`) inside an `impl` block.
///
/// The implementing type must provide `current_theta: Option<Array1<f64>>`,
/// `last_cost: Option<f64>` and
/// `last_eval: Option<(f64, Array1<f64>, gam_problem::HessianResult)>`.
/// Lookups only hit when the queried θ matches `current_theta` according to
/// `theta_values_match`.
macro_rules! impl_exact_joint_theta_memo {
    () => {
        fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
            let theta_is_current = self
                .current_theta
                .as_ref()
                .is_some_and(|cached| theta_values_match(cached, theta));
            if !theta_is_current {
                return None;
            }
            // Prefer the cost carried by a full evaluation; otherwise fall back
            // to a cost-only record.
            match self.last_eval.as_ref() {
                Some((cost, _, _)) => Some(*cost),
                None => self.last_cost,
            }
        }

        fn memoized_eval(
            &self,
            theta: &Array1<f64>,
        ) -> Option<(f64, Array1<f64>, gam_problem::HessianResult)> {
            let theta_is_current = self
                .current_theta
                .as_ref()
                .is_some_and(|cached| theta_values_match(cached, theta));
            theta_is_current.then(|| self.last_eval.clone()).flatten()
        }

        fn store_eval(&mut self, eval: (f64, Array1<f64>, gam_problem::HessianResult)) {
            let cost = eval.0;
            self.last_cost = Some(cost);
            self.last_eval = Some(eval);
        }
    };
}
1217
/// Per-θ cache for the single-block exact joint (ρ, ψ) optimisation: holds the
/// incremental design realizer plus memoised cost/evaluation results and the
/// spatial hyper-directions built for the current design.
struct SingleBlockExactJointDesignCache<'d> {
    /// Realizer that re-materialises the design/penalties when log-κ changes.
    realizer: FrozenTermCollectionIncrementalRealizer<'d>,
    /// θ whose ψ tail was last successfully applied to `realizer` by `ensure_theta`.
    current_theta: Option<Array1<f64>>,
    /// θ at which `last_cost` / `last_eval` were recorded; memo lookups key on this.
    last_eval_theta: Option<Array1<f64>>,
    /// Cost recorded at `last_eval_theta` (by `store_eval_at` or `store_cost_at`).
    last_cost: Option<f64>,
    /// Full (cost, gradient, Hessian) evaluation recorded at `last_eval_theta`.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Hyper-directions tagged with the realizer design revision they were built for.
    cached_hyper_dirs: Option<(u64, Vec<DirectionalHyperParam>)>,
    /// Indices of the smooth terms whose log-κ is being optimised.
    spatial_terms: Vec<usize>,
    /// Number of leading ρ (log-λ) entries in θ; the ψ tail starts here.
    rho_dim: usize,
    /// ψ coordinates contributed by each entry of `spatial_terms`.
    dims_per_term: Vec<usize>,
}
1250
1251impl<'d> SingleBlockExactJointDesignCache<'d> {
1252 fn new(
1253 data: ArrayView2<'d, f64>,
1254 spec: TermCollectionSpec,
1255 design: TermCollectionDesign,
1256 spatial_terms: Vec<usize>,
1257 rho_dim: usize,
1258 dims_per_term: Vec<usize>,
1259 ) -> Result<Self, String> {
1260 Ok(Self {
1261 realizer: FrozenTermCollectionIncrementalRealizer::new(data, spec, design)?,
1262 current_theta: None,
1263 last_eval_theta: None,
1264 last_cost: None,
1265 last_eval: None,
1266 cached_hyper_dirs: None,
1267 spatial_terms,
1268 rho_dim,
1269 dims_per_term,
1270 })
1271 }
1272
1273 fn design_revision(&self) -> u64 {
1274 self.realizer.design_revision()
1275 }
1276
1277 fn hyper_dirs_for_current_design(
1287 &mut self,
1288 data: ArrayView2<'_, f64>,
1289 kind: SpatialHyperKind,
1290 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1291 let revision = self.realizer.design_revision();
1292 if let Some((cached_rev, dirs)) = self.cached_hyper_dirs.as_ref()
1293 && *cached_rev == revision
1294 {
1295 return Ok(dirs.clone());
1296 }
1297 let dirs = try_build_spatial_log_kappa_hyper_dirs(
1298 data,
1299 self.realizer.spec(),
1300 self.realizer.design(),
1301 &self.spatial_terms,
1302 )?
1303 .ok_or_else(|| {
1304 EstimationError::InvalidInput(format!(
1305 "failed to build {} hyper_dirs at current {}",
1306 kind.adjective(),
1307 kind.coord_name(),
1308 ))
1309 })?;
1310 self.cached_hyper_dirs = Some((revision, dirs.clone()));
1311 Ok(dirs)
1312 }
1313
1314 fn nfree_tensor_gradient_hyper_dirs(
1315 &mut self,
1316 theta: &Array1<f64>,
1317 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
1318 let psi = &theta.as_slice().ok_or_else(|| {
1319 EstimationError::InvalidInput(
1320 "nfree_tensor_gradient_hyper_dirs: theta is not contiguous".to_string(),
1321 )
1322 })?[self.rho_dim..];
1323 let (global_range, p_total, s_psi_components) = self
1324 .realizer
1325 .canonical_penalty_derivatives_at_psi(&self.spatial_terms, psi)
1326 .map_err(EstimationError::InvalidInput)?;
1327 let zero_x = gam_solve::estimate::reml::HyperDesignDerivative::zero(
1328 self.realizer.design().design.nrows(),
1329 p_total,
1330 );
1331 let components = s_psi_components
1332 .into_iter()
1333 .enumerate()
1334 .map(|(penalty_index, local)| {
1335 (
1336 penalty_index,
1337 gam_solve::estimate::reml::HyperPenaltyDerivative::from_embedded(
1338 local,
1339 global_range.clone(),
1340 p_total,
1341 ),
1342 )
1343 })
1344 .collect::<Vec<_>>();
1345 Ok(DirectionalHyperParam::new_compact(zero_x, components, None, None)?.not_penalty_like())
1346 .map(|dir| vec![dir])
1347 }
1348
1349 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1350 if self
1351 .current_theta
1352 .as_ref()
1353 .is_some_and(|cached| theta_values_match(cached, theta))
1354 {
1355 return Ok(());
1356 }
1357 let t_ensure = std::time::Instant::now();
1358 let log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
1359 theta,
1360 self.rho_dim,
1361 self.dims_per_term.clone(),
1362 );
1363 self.realizer
1364 .apply_log_kappa(&log_kappa, &self.spatial_terms)?;
1365 log::info!(
1366 "[STAGE] ensure_theta (apply_log_kappa, {} terms): {:.3}s",
1367 self.spatial_terms.len(),
1368 t_ensure.elapsed().as_secs_f64(),
1369 );
1370 self.current_theta = Some(theta.clone());
1371 self.last_eval_theta = None;
1372 self.last_cost = None;
1373 self.last_eval = None;
1374 Ok(())
1375 }
1376
1377 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1384 if self
1385 .last_eval_theta
1386 .as_ref()
1387 .is_some_and(|cached| theta_values_match(cached, theta))
1388 {
1389 self.last_eval
1390 .as_ref()
1391 .map(|cached| cached.0)
1392 .or(self.last_cost)
1393 } else {
1394 None
1395 }
1396 }
1397
1398 fn memoized_eval(
1399 &self,
1400 theta: &Array1<f64>,
1401 ) -> Option<(
1402 f64,
1403 Array1<f64>,
1404 gam_problem::HessianResult,
1405 )> {
1406 if self
1407 .last_eval_theta
1408 .as_ref()
1409 .is_some_and(|cached| theta_values_match(cached, theta))
1410 {
1411 self.last_eval.clone()
1412 } else {
1413 None
1414 }
1415 }
1416
1417 fn store_eval_at(
1421 &mut self,
1422 theta: &Array1<f64>,
1423 eval: (
1424 f64,
1425 Array1<f64>,
1426 gam_problem::HessianResult,
1427 ),
1428 ) {
1429 self.last_eval_theta = Some(theta.clone());
1430 self.last_cost = Some(eval.0);
1431 self.last_eval = Some(eval);
1432 }
1433
1434 fn store_cost_at(&mut self, theta: &Array1<f64>, cost: f64) {
1437 self.last_eval_theta = Some(theta.clone());
1438 self.last_cost = Some(cost);
1439 self.last_eval = None;
1443 }
1444
1445 fn spec(&self) -> &TermCollectionSpec {
1446 self.realizer.spec()
1447 }
1448
1449 fn design(&self) -> &TermCollectionDesign {
1450 self.realizer.design()
1451 }
1452
1453 fn supports_nfree_penalty_rekey(&self) -> bool {
1459 self.realizer
1460 .supports_nfree_penalty_rekey(&self.spatial_terms)
1461 }
1462
1463 fn supports_nfree_gradient_only_routing(&self) -> bool {
1464 self.realizer
1465 .supports_nfree_gradient_only_routing(&self.spatial_terms)
1466 }
1467
1468 fn canonical_penalties_at(
1478 &mut self,
1479 theta: &Array1<f64>,
1480 ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
1481 let psi = &theta
1482 .as_slice()
1483 .ok_or_else(|| "canonical_penalties_at: theta is not contiguous".to_string())?
1484 [self.rho_dim..];
1485 self.realizer
1486 .canonical_penalties_at_psi(&self.spatial_terms, psi)
1487 }
1488}
1489
/// Per-θ cache for a single-block fit whose smooth term is driven by learned
/// latent coordinates.
///
/// Holds a working copy of the data (latent columns overwritten from θ), the
/// realized design and hyper-directions, and cost/eval memo slots gated by the
/// outer-iteration counter.
struct SingleBlockLatentCoordDesignCache {
    /// Working data matrix; `feature_cols` are overwritten with the current latent values.
    data: Array2<f64>,
    /// Resolved term-collection spec used to rebuild the design.
    spec: TermCollectionSpec,
    /// Design realized for `current_latent` (initially the design passed to `new`).
    design: TermCollectionDesign,
    /// θ last realized by `ensure_theta`.
    current_theta: Option<Array1<f64>>,
    /// Latent coordinates decoded from `current_theta`.
    current_latent: Option<std::sync::Arc<gam_terms::latent::LatentCoordValues>>,
    /// Hyper-directions matching `design`.
    current_hyper_dirs: Option<Vec<gam_solve::estimate::reml::DirectionalHyperParam>>,
    /// Entry id in `latent_design_cache` that `design` came from.
    current_design_cache_id: Option<u64>,
    /// Cache of rebuilt designs/hyper-directions keyed by latent values, basis kind and context digest.
    latent_design_cache: gam_solve::latent_cache::LatentDesignCache,
    /// Cost-only memo for `current_theta`, valid for the outer iteration in `last_outer_iter`.
    last_cost: Option<f64>,
    /// Full (cost, gradient, Hessian) memo for `current_theta`, valid for `last_outer_iter`.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Index of the latent-driven smooth term.
    term_index: gam_problem::types::SmoothTermIdx,
    /// Data columns that receive the latent coordinates (one per latent axis).
    feature_cols: Vec<usize>,
    /// Number of leading ρ entries in θ; the flattened latent block starts here.
    rho_dim: usize,
    /// Number of observations (rows) in the latent block.
    n_obs: usize,
    /// Latent dimension (columns per observation).
    latent_dim: usize,
    /// Latent identifiability mode copied from the configuration.
    id_mode: gam_terms::latent::LatentIdMode,
    /// Latent manifold copied from the configuration.
    manifold: gam_terms::latent::LatentManifold,
    /// Retraction registry copied from the configuration.
    retraction_registry: gam_solve::latent_cache::LatentRetractionRegistry,
    /// Identifier of the latent-value family copied from the configuration.
    latent_id: u64,
    /// Optional analytic penalty registry supplied with the latent configuration.
    analytic_penalties: Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>>,
    /// Number of θ entries owned by `analytic_penalties` (0 when absent).
    analytic_rho_count: usize,
    /// Counter bumped whenever the realized design may have changed.
    design_revision: u64,
    /// Outer iteration at which `last_cost` / `last_eval` were stored.
    last_outer_iter: Option<u64>,
}
1522
1523impl SingleBlockLatentCoordDesignCache {
1524 fn new(
1525 data: Array2<f64>,
1526 spec: TermCollectionSpec,
1527 design: TermCollectionDesign,
1528 latent: &StandardLatentCoordConfig,
1529 rho_dim: usize,
1530 ) -> Result<Self, String> {
1531 if latent.term_index.get() >= spec.smooth_terms.len() {
1532 return Err(SmoothError::dimension_mismatch(format!(
1533 "latent-coordinate term index {} out of bounds for {} smooth terms",
1534 latent.term_index,
1535 spec.smooth_terms.len()
1536 ))
1537 .into());
1538 }
1539 if latent.feature_cols.len() != latent.values.latent_dim() {
1540 return Err(SmoothError::dimension_mismatch(format!(
1541 "latent-coordinate feature width mismatch: feature_cols={}, latent_dim={}",
1542 latent.feature_cols.len(),
1543 latent.values.latent_dim()
1544 ))
1545 .into());
1546 }
1547 if latent.values.n_obs() != data.nrows() {
1548 return Err(SmoothError::dimension_mismatch(format!(
1549 "latent-coordinate row mismatch: latent n={}, data n={}",
1550 latent.values.n_obs(),
1551 data.nrows()
1552 ))
1553 .into());
1554 }
1555 let analytic_rho_count = latent
1556 .analytic_penalties
1557 .as_ref()
1558 .map_or(0, |registry| registry.total_rho_count());
1559 Ok(Self {
1560 data,
1561 spec,
1562 design,
1563 current_theta: None,
1564 current_latent: None,
1565 current_hyper_dirs: None,
1566 current_design_cache_id: None,
1567 latent_design_cache: gam_solve::latent_cache::LatentDesignCache::default(),
1568 last_cost: None,
1569 last_eval: None,
1570 term_index: latent.term_index,
1571 feature_cols: latent.feature_cols.clone(),
1572 rho_dim,
1573 n_obs: latent.values.n_obs(),
1574 latent_dim: latent.values.latent_dim(),
1575 id_mode: latent.values.id_mode().clone(),
1576 manifold: latent.values.manifold().clone(),
1577 retraction_registry: latent.values.retraction_registry().clone(),
1578 latent_id: latent.values.latent_id(),
1579 analytic_penalties: latent.analytic_penalties.clone(),
1580 analytic_rho_count,
1581 design_revision: 0,
1582 last_outer_iter: None,
1583 })
1584 }
1585
1586 fn design_revision(&self) -> u64 {
1587 self.design_revision
1588 }
1589
1590 fn design(&self) -> &TermCollectionDesign {
1591 &self.design
1592 }
1593
1594 fn latent(&self) -> Result<std::sync::Arc<gam_terms::latent::LatentCoordValues>, String> {
1595 self.current_latent
1596 .as_ref()
1597 .cloned()
1598 .ok_or_else(|| "latent-coordinate cache has not been realized".to_string())
1599 }
1600
1601 fn analytic_penalties(&self) -> Option<std::sync::Arc<gam_terms::AnalyticPenaltyRegistry>> {
1602 self.analytic_penalties.clone()
1603 }
1604
1605 fn analytic_penalty_rho_count(&self) -> usize {
1606 self.analytic_rho_count
1607 }
1608
1609 fn hyper_dirs(&self) -> Result<Vec<gam_solve::estimate::reml::DirectionalHyperParam>, String> {
1610 self.current_hyper_dirs
1611 .as_ref()
1612 .cloned()
1613 .ok_or_else(|| "latent-coordinate hyper_dirs cache has not been realized".to_string())
1614 }
1615
1616 fn latent_basis_kind(&self) -> Result<gam_solve::latent_cache::LatentBasisKind, String> {
1617 let smooth_term = self
1618 .design
1619 .smooth
1620 .terms
1621 .get(self.term_index.get())
1622 .ok_or_else(|| {
1623 SmoothError::dimension_mismatch(format!(
1624 "LatentCoord term index {} out of bounds for realized smooth design",
1625 self.term_index
1626 ))
1627 })?;
1628 let termspec = self
1629 .spec
1630 .smooth_terms
1631 .get(self.term_index.get())
1632 .ok_or_else(|| {
1633 SmoothError::dimension_mismatch(format!(
1634 "LatentCoord term index {} out of bounds for resolved smooth spec",
1635 self.term_index
1636 ))
1637 })?;
1638 match (&termspec.basis, &smooth_term.metadata) {
1639 (
1640 SmoothBasisSpec::Matern { .. },
1641 BasisMetadata::Matern {
1642 centers,
1643 length_scale,
1644 nu,
1645 aniso_log_scales,
1646 ..
1647 },
1648 ) => Ok(gam_solve::latent_cache::LatentBasisKind::Matern {
1649 centers: centers.clone(),
1650 length_scale: *length_scale,
1651 nu: *nu,
1652 aniso_log_scales: aniso_log_scales
1653 .clone()
1654 .unwrap_or_else(|| vec![0.0; centers.ncols()]),
1655 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1656 self.n_obs,
1657 centers.nrows(),
1658 ),
1659 }),
1660 (
1661 SmoothBasisSpec::Duchon { .. },
1662 BasisMetadata::Duchon {
1663 centers,
1664 length_scale,
1665 power,
1666 nullspace_order,
1667 aniso_log_scales,
1668 ..
1669 },
1670 ) => Ok(gam_solve::latent_cache::LatentBasisKind::Duchon {
1671 centers: centers.clone(),
1672 length_scale: *length_scale,
1673 power: *power,
1674 nullspace_order: *nullspace_order,
1675 aniso_log_scales: aniso_log_scales
1676 .clone()
1677 .unwrap_or_else(|| vec![0.0; centers.ncols()]),
1678 }),
1679 (
1680 SmoothBasisSpec::Sphere { .. },
1681 BasisMetadata::Sphere {
1682 centers,
1683 penalty_order,
1684 method,
1685 ..
1686 },
1687 ) if matches!(*method, gam_terms::basis::SphereMethod::Wahba) => {
1688 Ok(gam_solve::latent_cache::LatentBasisKind::Sphere {
1689 centers: centers.clone(),
1690 penalty_order: *penalty_order,
1691 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1692 self.n_obs,
1693 centers.nrows(),
1694 ),
1695 })
1696 }
1697 (
1698 SmoothBasisSpec::BSpline1D { spec, .. },
1699 BasisMetadata::BSpline1D {
1700 knots,
1701 periodic,
1702 degree: meta_degree,
1703 ..
1704 },
1705 ) => {
1706 let effective_degree = meta_degree.unwrap_or(spec.degree);
1710 if let Some((domain_start, period, num_basis)) = periodic {
1711 Ok(
1712 gam_solve::latent_cache::LatentBasisKind::PeriodicBspline {
1713 domain_start: *domain_start,
1714 period: *period,
1715 degree: effective_degree,
1716 num_basis: *num_basis,
1717 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1718 self.n_obs, *num_basis,
1719 ),
1720 },
1721 )
1722 } else {
1723 let num_basis_est = knots.len().saturating_sub(effective_degree + 1);
1724 Ok(
1725 gam_solve::latent_cache::LatentBasisKind::TensorBspline {
1726 knots: vec![knots.clone()],
1727 degrees: vec![effective_degree],
1728 chunk_size: gam_terms::basis::auto_streaming_chunk_size_for_dense(
1729 self.n_obs,
1730 num_basis_est,
1731 ),
1732 },
1733 )
1734 }
1735 }
1736 (
1737 SmoothBasisSpec::TensorBSpline { .. },
1738 BasisMetadata::TensorBSpline { knots, degrees, .. },
1739 ) => Ok(
1740 gam_solve::latent_cache::LatentBasisKind::TensorBspline {
1741 knots: knots.clone(),
1742 degrees: degrees.clone(),
1743 chunk_size: None,
1744 },
1745 ),
1746 (
1747 SmoothBasisSpec::Pca { .. },
1748 BasisMetadata::Pca {
1749 basis_matrix,
1750 centered,
1751 smooth_penalty,
1752 center_mean,
1753 pca_basis_path,
1754 chunk_size,
1755 ..
1756 },
1757 ) => {
1758 let center_mean_fingerprint = if *centered && pca_basis_path.is_none() {
1759 let mean = center_mean.as_ref().ok_or_else(|| {
1760 SmoothError::invalid_config(
1761 "latent-coordinate Pca cache key requires center_mean when centered",
1762 )
1763 })?;
1764 Some(gam_solve::latent_cache::pca_center_mean_fingerprint(
1765 mean,
1766 ))
1767 } else {
1768 None
1769 };
1770 Ok(gam_solve::latent_cache::LatentBasisKind::Pca {
1771 basis_matrix: basis_matrix.clone(),
1772 centered: *centered,
1773 center_mean_fingerprint,
1774 smooth_penalty: *smooth_penalty,
1775 pca_basis_path: pca_basis_path.clone(),
1776 chunk_size: *chunk_size,
1777 })
1778 }
1779 _ => Err(SmoothError::invalid_config(
1780 "latent-coordinate design cache could not key the realized latent smooth basis"
1781 .to_string(),
1782 )
1783 .into()),
1784 }
1785 }
1786
1787 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
1788 if self
1789 .current_theta
1790 .as_ref()
1791 .is_some_and(|cached| theta_values_match(cached, theta))
1792 {
1793 return Ok(());
1794 }
1795 let latent_flat_len = self.n_obs * self.latent_dim;
1796 let direct_hyper_count = latent_coord_direct_hyper_count(&self.id_mode, self.latent_dim);
1797 let expected =
1798 self.rho_dim + latent_flat_len + self.analytic_rho_count + direct_hyper_count;
1799 if theta.len() != expected {
1800 return Err(SmoothError::dimension_mismatch(format!(
1801 "latent-coordinate theta length mismatch: got {}, expected {} (rho_dim={}, n={}, d={}, analytic_rhos={}, direct_hypers={})",
1802 theta.len(),
1803 expected,
1804 self.rho_dim,
1805 self.n_obs,
1806 self.latent_dim,
1807 self.analytic_rho_count,
1808 direct_hyper_count
1809 ))
1810 .into());
1811 }
1812 let flat = theta
1813 .slice(s![self.rho_dim..self.rho_dim + latent_flat_len])
1814 .to_owned();
1815 let latent = std::sync::Arc::new(
1816 gam_terms::latent::LatentCoordValues::from_flat_with_manifold_and_retraction_and_id(
1817 flat,
1818 self.n_obs,
1819 self.latent_dim,
1820 self.id_mode.clone(),
1821 self.manifold.clone(),
1822 self.retraction_registry.clone(),
1823 self.latent_id,
1824 ),
1825 );
1826 let latent_values_changed = self
1827 .current_latent
1828 .as_ref()
1829 .map(|cached| !latent_values_match(cached.as_flat(), latent.as_flat()))
1830 .unwrap_or(true);
1831 if latent_values_changed {
1832 self.latent_design_cache.invalidate_all();
1833 self.current_design_cache_id = None;
1834 self.design_revision = self.design_revision.wrapping_add(1);
1835 }
1836 for n in 0..self.n_obs {
1837 for axis in 0..self.latent_dim {
1838 let col = self.feature_cols[axis];
1839 self.data[[n, col]] = latent.as_flat()[n * self.latent_dim + axis];
1840 }
1841 }
1842
1843 let basis_kind = self.latent_basis_kind()?;
1844 let rebuilt_width = self.design.design.ncols();
1845 let spec = self.spec.clone();
1846 let term_index = self.term_index;
1847 let analytic_rho_count = self.analytic_rho_count;
1848 let data = self.data.view();
1849 let design_context_digest =
1850 gam_solve::latent_cache::latent_design_context_cache_digest(
1851 data,
1852 &spec,
1853 term_index,
1854 analytic_rho_count,
1855 &self.feature_cols,
1856 )
1857 .map_err(|e| e.to_string())?;
1858 let lookup = self
1859 .latent_design_cache
1860 .lookup_or_compute(latent.clone(), basis_kind, design_context_digest, || {
1861 let rebuilt = build_term_collection_design(data, &spec).map_err(|e| {
1862 EstimationError::InvalidInput(format!(
1863 "failed to rebuild latent-coordinate design: {e}"
1864 ))
1865 })?;
1866 if rebuilt.design.ncols() != rebuilt_width {
1867 crate::bail_invalid_estim!(
1868 "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
1869 rebuilt.design.ncols(),
1870 rebuilt_width
1871 );
1872 }
1873 let hyper_dirs = try_build_latent_coord_hyper_dirs(
1874 latent.clone(),
1875 &spec,
1876 &rebuilt,
1877 &[term_index],
1878 analytic_rho_count,
1879 )?
1880 .ok_or_else(|| {
1881 EstimationError::InvalidInput(
1882 "failed to build latent-coordinate hyper_dirs".to_string(),
1883 )
1884 })?;
1885 Ok(gam_solve::latent_cache::ComputedLatentDesign {
1886 design: rebuilt,
1887 hyper_dirs,
1888 })
1889 })
1890 .map_err(|e| e.to_string())?;
1891 if lookup.cached.design.design.ncols() != self.design.design.ncols() {
1892 return Err(SmoothError::dimension_mismatch(format!(
1893 "latent-coordinate design topology changed: rebuilt p={}, cached p={}",
1894 lookup.cached.design.design.ncols(),
1895 self.design.design.ncols()
1896 ))
1897 .into());
1898 }
1899 self.design = lookup.cached.design.clone();
1900 self.current_hyper_dirs = Some(lookup.cached.hyper_dirs.clone());
1901 self.current_latent = Some(latent);
1902 self.current_theta = Some(theta.clone());
1903 self.last_cost = None;
1904 self.last_eval = None;
1905 self.last_outer_iter = None;
1906 if !latent_values_changed && self.current_design_cache_id != Some(lookup.entry_id) {
1907 self.design_revision = self.design_revision.wrapping_add(1);
1908 }
1909 self.current_design_cache_id = Some(lookup.entry_id);
1910 Ok(())
1911 }
1912
1913 fn memoized_cost(&self, theta: &Array1<f64>) -> Option<f64> {
1914 if self
1915 .current_theta
1916 .as_ref()
1917 .is_some_and(|cached| theta_values_match(cached, theta))
1918 && self.last_outer_iter
1919 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
1920 {
1921 self.last_eval
1922 .as_ref()
1923 .map(|cached| cached.0)
1924 .or(self.last_cost)
1925 } else {
1926 None
1927 }
1928 }
1929
1930 fn memoized_eval(
1931 &self,
1932 theta: &Array1<f64>,
1933 ) -> Option<(
1934 f64,
1935 Array1<f64>,
1936 gam_problem::HessianResult,
1937 )> {
1938 if self
1939 .current_theta
1940 .as_ref()
1941 .is_some_and(|cached| theta_values_match(cached, theta))
1942 && self.last_outer_iter
1943 == Some(gam_solve::estimate::reml::outer_eval::current_outer_iter())
1944 {
1945 self.last_eval.clone()
1946 } else {
1947 None
1948 }
1949 }
1950
1951 fn store_eval(
1952 &mut self,
1953 eval: (
1954 f64,
1955 Array1<f64>,
1956 gam_problem::HessianResult,
1957 ),
1958 ) {
1959 self.last_cost = Some(eval.0);
1960 self.last_eval = Some(eval);
1961 self.last_outer_iter =
1962 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
1963 }
1964
1965 fn store_cost(&mut self, cost: f64) {
1966 self.last_cost = Some(cost);
1967 self.last_outer_iter =
1968 Some(gam_solve::estimate::reml::outer_eval::current_outer_iter());
1969 }
1970
1971 fn reset(&mut self) {
1972 self.current_theta = None;
1973 self.current_latent = None;
1974 self.current_hyper_dirs = None;
1975 self.current_design_cache_id = None;
1976 self.latent_design_cache.invalidate();
1977 self.last_cost = None;
1978 self.last_eval = None;
1979 self.last_outer_iter = None;
1980 }
1981}
1982
1983pub fn fixed_kappa_profiled_reml_score(
1999 data: ArrayView2<'_, f64>,
2000 y: ArrayView1<'_, f64>,
2001 weights: ArrayView1<'_, f64>,
2002 offset: ArrayView1<'_, f64>,
2003 resolvedspec: &TermCollectionSpec,
2004 term_idx: usize,
2005 kappa: f64,
2006 family: LikelihoodSpec,
2007 options: &FitOptions,
2008) -> Result<f64, EstimationError> {
2009 if !kappa.is_finite() {
2010 crate::bail_invalid_estim!("fixed-κ profiled score probed a non-finite κ = {kappa}");
2011 }
2012 let (feature_cols, mut probe_basis) = match resolvedspec
2015 .smooth_terms
2016 .get(term_idx)
2017 .map(|t| &t.basis)
2018 {
2019 Some(SmoothBasisSpec::ConstantCurvature {
2020 feature_cols, spec, ..
2021 }) => (feature_cols.clone(), spec.clone()),
2022 _ => {
2023 crate::bail_invalid_estim!(
2024 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2025 )
2026 }
2027 };
2028 probe_basis.kappa = kappa;
2029
2030 let is_unweighted = weights.iter().all(|&w| (w - 1.0).abs() <= 1e-12);
2050 let is_zero_offset = offset.iter().all(|&o| o.abs() <= 1e-12);
2051 if family == LikelihoodSpec::gaussian_identity() && is_unweighted && is_zero_offset {
2052 let x_term = select_columns(data, &feature_cols).map_err(EstimationError::from)?;
2053 let score =
2054 gam_terms::basis::constant_curvature_honest_profiled_reml_score(x_term.view(), y, &probe_basis)
2055 .map_err(|e| {
2056 EstimationError::InvalidInput(format!(
2057 "fixed-κ honest profiled-REML score at κ={kappa} failed: {e}"
2058 ))
2059 })?;
2060 if !score.is_finite() {
2061 crate::bail_invalid_estim!(
2062 "fixed-κ honest profiled-REML score at κ={kappa} is non-finite"
2063 );
2064 }
2065 return Ok(score);
2066 }
2067
2068 let mut probe_spec = resolvedspec.clone();
2070 match probe_spec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis) {
2071 Some(SmoothBasisSpec::ConstantCurvature { spec, .. }) => spec.kappa = kappa,
2072 _ => {
2073 crate::bail_invalid_estim!(
2074 "fixed-κ profiled score: term {term_idx} is not a constant-curvature smooth"
2075 )
2076 }
2077 }
2078 let fixed_kappa_options = SpatialLengthScaleOptimizationOptions {
2079 enabled: false,
2080 ..SpatialLengthScaleOptimizationOptions::default()
2081 };
2082 let fit = fit_term_collectionwith_spatial_length_scale_optimization(
2083 data,
2084 y.to_owned(),
2085 weights.to_owned(),
2086 offset.to_owned(),
2087 &probe_spec,
2088 family,
2089 options,
2090 &fixed_kappa_options,
2091 )?;
2092 let score = fit_score(&fit.fit);
2093 if !score.is_finite() {
2094 crate::bail_invalid_estim!("fixed-κ profiled fit at κ={kappa} returned a non-finite score");
2095 }
2096 Ok(score)
2097}
2098
2099fn constant_curvature_kappa_fair_argmin(
2124 data: ArrayView2<'_, f64>,
2125 y: ArrayView1<'_, f64>,
2126 resolvedspec: &TermCollectionSpec,
2127 term_idx: usize,
2128) -> Option<f64> {
2129 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2130 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2131 return None;
2132 }
2133 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2134 Some(SmoothBasisSpec::ConstantCurvature {
2135 feature_cols, spec, ..
2136 }) => (feature_cols, spec.clone()),
2137 _ => return None,
2138 };
2139 let x_term = match select_columns(data, feature_cols) {
2140 Ok(x) => x,
2141 Err(e) => {
2142 log::info!("[spatial-kappa] #1464 κ-fair argmin column select failed ({e}); skipping");
2143 return None;
2144 }
2145 };
2146 const GRID_STEPS: usize = 24;
2152 let mut best: Option<(f64, f64)> = None; for i in 0..=GRID_STEPS {
2154 let t = i as f64 / GRID_STEPS as f64;
2155 let kappa = kappa_min + (kappa_max - kappa_min) * t;
2156 let mut probe_spec = base_spec.clone();
2157 probe_spec.kappa = kappa;
2158 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec) {
2159 Ok(score) => {
2160 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2161 best = Some((score, kappa));
2162 }
2163 }
2164 Err(e) => {
2165 log::info!(
2166 "[spatial-kappa] #1464 κ-fair argmin probe at κ={kappa:.4} failed ({e}); skipping"
2167 );
2168 }
2169 }
2170 }
2171 best.map(|(score, kappa)| {
2172 log::info!(
2173 "[spatial-kappa] #1464 κ-fair argmin κ̂={kappa:.4} (κ-fair score={score:.6e}) for term {term_idx}"
2174 );
2175 kappa
2176 })
2177}
2178
2179fn select_constant_curvature_kappa_sign_seed(
2187 data: ArrayView2<'_, f64>,
2188 y: ArrayView1<'_, f64>,
2189 resolvedspec: &TermCollectionSpec,
2190 term_idx: usize,
2191) -> Option<f64> {
2192 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
2193 if !(kappa_min.is_finite() && kappa_max.is_finite() && kappa_max > kappa_min) {
2194 return None;
2195 }
2196 let (feature_cols, base_spec) = match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
2208 Some(SmoothBasisSpec::ConstantCurvature {
2209 feature_cols, spec, ..
2210 }) => (feature_cols, spec.clone()),
2211 _ => return None,
2212 };
2213 let x_term = match select_columns(data, feature_cols) {
2214 Ok(x) => x,
2215 Err(e) => {
2216 log::info!("[spatial-kappa] #1464 sign-basin scan column select failed ({e}); skipping");
2217 return None;
2218 }
2219 };
2220 let probes = [
2224 kappa_min,
2225 0.5 * kappa_min,
2226 0.0,
2227 0.5 * kappa_max,
2228 kappa_max,
2229 ];
2230 let mut best: Option<(f64, f64)> = None; for &kappa in &probes {
2232 let mut probe_spec = base_spec.clone();
2233 probe_spec.kappa = kappa;
2234 match gam_terms::basis::constant_curvature_kappa_fair_sign_score(
2235 x_term.view(),
2236 y,
2237 &probe_spec,
2238 ) {
2239 Ok(score) => {
2240 if best.as_ref().is_none_or(|(b, _)| score < *b) {
2241 best = Some((score, kappa));
2242 }
2243 }
2244 Err(e) => {
2245 log::info!(
2246 "[spatial-kappa] #1464 sign-basin probe at κ={kappa:.4} failed ({e}); skipping"
2247 );
2248 }
2249 }
2250 }
2251 best.map(|(score, kappa)| {
2252 log::info!(
2253 "[spatial-kappa] #1464 κ-fair sign-basin scan selected κ_seed={kappa:.4} \
2254 (κ-fair score={score:.6e}) for term {term_idx}"
2255 );
2256 kappa
2257 })
2258}
2259
/// Number of evenly spaced ψ probes (both endpoints included) per term in
/// `prescan_isotropic_spatial_range_seed`. Must be at least 2, because the grid
/// fraction divides by `SPATIAL_RANGE_PRESCAN_GRID - 1`.
const SPATIAL_RANGE_PRESCAN_GRID: usize = 7;
2263
2264fn prescan_isotropic_spatial_range_seed(
2296 data: ArrayView2<'_, f64>,
2297 y: ArrayView1<'_, f64>,
2298 weights: ArrayView1<'_, f64>,
2299 offset: ArrayView1<'_, f64>,
2300 resolvedspec: &TermCollectionSpec,
2301 baseline_score: f64,
2302 family: &LikelihoodSpec,
2303 options: &FitOptions,
2304 kappa_options: &SpatialLengthScaleOptimizationOptions,
2305 spatial_terms: &[usize],
2306) -> Result<Vec<(usize, f64)>, EstimationError> {
2307 if has_aniso_terms(resolvedspec, spatial_terms)
2309 || !constant_curvature_term_indices(resolvedspec).is_empty()
2310 {
2311 return Ok(Vec::new());
2312 }
2313 let dims = spatial_dims_per_term(resolvedspec, spatial_terms);
2314 let mut working = resolvedspec.clone();
2318 let mut best_score = if baseline_score.is_finite() {
2319 baseline_score
2320 } else {
2321 f64::INFINITY
2322 };
2323 let mut overrides: Vec<(usize, f64)> = Vec::new();
2324 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2325 if dims.get(slot).copied().unwrap_or(1) != 1 {
2328 continue;
2329 }
2330 if get_spatial_length_scale(&working, term_idx).is_none() {
2333 continue;
2334 }
2335 let (psi_lo, psi_hi) = spatial_term_psi_bounds(data, &working, term_idx, kappa_options);
2336 if !(psi_lo.is_finite() && psi_hi.is_finite()) || psi_hi <= psi_lo {
2337 continue;
2338 }
2339 let mut term_best: Option<f64> = None;
2340 for g in 0..SPATIAL_RANGE_PRESCAN_GRID {
2341 let frac = g as f64 / (SPATIAL_RANGE_PRESCAN_GRID - 1) as f64;
2342 let psi = psi_lo + (psi_hi - psi_lo) * frac;
2343 let ls = (-psi).exp();
2347 if !ls.is_finite() || ls <= 0.0 {
2348 continue;
2349 }
2350 let mut probe = working.clone();
2351 if set_spatial_length_scale(&mut probe, term_idx, ls).is_err() {
2352 continue;
2353 }
2354 let fit = match fit_term_collection_forspec(
2355 data,
2356 y,
2357 weights,
2358 offset,
2359 &probe,
2360 family.clone(),
2361 options,
2362 ) {
2363 Ok(fit) => fit,
2364 Err(_) => continue,
2367 };
2368 let score = fit_score(&fit.fit);
2369 if score.is_finite() && score < best_score - 1e-7 * best_score.abs().max(1.0) {
2372 best_score = score;
2373 term_best = Some(ls);
2374 }
2375 }
2376 if let Some(ls) = term_best {
2377 set_spatial_length_scale(&mut working, term_idx, ls)?;
2378 overrides.push((term_idx, ls));
2379 log::info!(
2380 "[spatial-kappa] #1074 range pre-scan: term {term_idx} re-seeded at \
2381 length_scale={ls:.5} (profiled REML {best_score:.5}, was {baseline_score:.5})"
2382 );
2383 }
2384 }
2385 Ok(overrides)
2386}
2387
2388fn try_exact_joint_spatial_length_scale_optimization(
2389 data: ArrayView2<'_, f64>,
2390 y: ArrayView1<'_, f64>,
2391 weights: ArrayView1<'_, f64>,
2392 offset: ArrayView1<'_, f64>,
2393 resolvedspec: &TermCollectionSpec,
2394 best: &FittedTermCollection,
2395 family: LikelihoodSpec,
2396 options: &FitOptions,
2397 kappa_options: &SpatialLengthScaleOptimizationOptions,
2398 spatial_terms: &[usize],
2399) -> Result<Option<FittedTermCollectionWithSpec>, EstimationError> {
2400 if spatial_terms.is_empty() {
2401 return Ok(None);
2402 }
2403 kappa_options
2408 .validate()
2409 .map_err(EstimationError::InvalidInput)?;
2410
2411 let cc_term_set = constant_curvature_term_indices(resolvedspec);
2431 let all_spatial_are_cc =
2432 !cc_term_set.is_empty() && spatial_terms.iter().all(|t| cc_term_set.contains(t));
2433 if all_spatial_are_cc {
2434 let mut fixed_kappa_spec = resolvedspec.clone();
2435 let mut any_kappa_chosen = false;
2436 for &term_idx in spatial_terms {
2437 if let Some(kappa_hat) =
2448 constant_curvature_kappa_fair_argmin(data, y, resolvedspec, term_idx)
2449 .filter(|&k| k < 0.0)
2450 {
2451 if let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) = fixed_kappa_spec
2452 .smooth_terms
2453 .get_mut(term_idx)
2454 .map(|t| &mut t.basis)
2455 {
2456 cc.kappa = kappa_hat;
2457 any_kappa_chosen = true;
2458 log::info!(
2459 "[spatial-kappa] #1464 term {term_idx}: fixed κ̂ = {kappa_hat:.4} from κ-fair argmin (hyperbolic basin; profiling ρ only)"
2460 );
2461 }
2462 }
2463 }
2464 if any_kappa_chosen {
2465 let baseline_score = fit_score(&best.fit);
2469 let fitted = fit_term_collection_forspec(
2470 data,
2471 y,
2472 weights,
2473 offset,
2474 &fixed_kappa_spec,
2475 family.clone(),
2476 options,
2477 )?;
2478 let frozen_spec =
2479 freeze_term_collection_from_design(&fixed_kappa_spec, &fitted.design)?;
2480 let mut fit = fitted.fit;
2481 fit.reml_score = baseline_score;
2493 return Ok(Some(FittedTermCollectionWithSpec {
2494 fit,
2495 design: fitted.design,
2496 resolvedspec: frozen_spec,
2497 adaptive_diagnostics: fitted.adaptive_diagnostics,
2498 kappa_timing: None,
2499 }));
2500 }
2501 }
2502
2503 if try_build_spatial_log_kappa_hyper_dirs(data, resolvedspec, &best.design, spatial_terms)?
2504 .is_none()
2505 {
2506 if !constant_curvature_term_indices(resolvedspec).is_empty() {
2507 log::info!(
2508 "[#1464-trace] try_exact_joint RETURNED None (hyper_dirs unavailable); \
2509 κ̂ comes from a NON-joint path"
2510 );
2511 }
2512 return Ok(None);
2513 }
2514 if !constant_curvature_term_indices(resolvedspec).is_empty() {
2515 log::info!(
2516 "[#1464-trace] try_exact_joint ENTERED for {} spatial term(s); CC present",
2517 spatial_terms.len()
2518 );
2519 }
2520
2521 const JOINT_RHO_BOUND: f64 = 12.0;
2522 let rho_dim = best.fit.lambdas.len();
2523
2524 let has_constant_curvature_term = !constant_curvature_term_indices(resolvedspec).is_empty();
2538 let rho_upper_bound = if has_constant_curvature_term {
2539 gam_solve::estimate::RHO_BOUND
2540 } else {
2541 JOINT_RHO_BOUND
2542 };
2543
2544 let dims_per_term = spatial_dims_per_term(resolvedspec, spatial_terms);
2546 let use_aniso = has_aniso_terms(resolvedspec, spatial_terms);
2547
2548 let log_kappa0 = if use_aniso {
2553 SpatialLogKappaCoords::from_length_scales_aniso(resolvedspec, spatial_terms, kappa_options)
2554 } else {
2555 SpatialLogKappaCoords::from_length_scales(resolvedspec, spatial_terms, kappa_options)
2556 };
2557 let mut log_kappa0 =
2560 log_kappa0.reseed_from_data(data, resolvedspec, spatial_terms, kappa_options);
2561 let mut cc_sign_seeds: Vec<(usize, f64)> = Vec::new();
2577 if has_constant_curvature_term {
2578 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2579 if constant_curvature_term_spec(resolvedspec, term_idx).is_none() {
2580 continue;
2581 }
2582 let scan = select_constant_curvature_kappa_sign_seed(
2583 data,
2584 y,
2585 resolvedspec,
2586 term_idx,
2587 );
2588 match scan {
2593 Some(kappa_seed) => {
2594 log::info!(
2595 "[#1464-trace] term {term_idx}: κ-fair sign-basin scan picked κ_seed = {kappa_seed}"
2596 );
2597 log_kappa0.set_scalar_slot(slot, kappa_seed);
2598 cc_sign_seeds.push((slot, kappa_seed));
2599 }
2600 None => {
2601 log::info!(
2602 "[#1464-trace] term {term_idx}: fixed-κ sign-basin scan returned NONE (no seed applied)"
2603 );
2604 }
2605 }
2606 }
2607 }
2608 let log_kappa_lower = if use_aniso {
2609 SpatialLogKappaCoords::lower_bounds_aniso_from_data(
2610 data,
2611 resolvedspec,
2612 spatial_terms,
2613 &dims_per_term,
2614 kappa_options,
2615 )
2616 } else {
2617 SpatialLogKappaCoords::lower_bounds_from_data(
2618 data,
2619 resolvedspec,
2620 spatial_terms,
2621 kappa_options,
2622 )
2623 };
2624 let log_kappa_upper = if use_aniso {
2625 SpatialLogKappaCoords::upper_bounds_aniso_from_data(
2626 data,
2627 resolvedspec,
2628 spatial_terms,
2629 &dims_per_term,
2630 kappa_options,
2631 )
2632 } else {
2633 SpatialLogKappaCoords::upper_bounds_from_data(
2634 data,
2635 resolvedspec,
2636 spatial_terms,
2637 kappa_options,
2638 )
2639 };
2640 let mut log_kappa_lower = log_kappa_lower;
2664 let mut log_kappa_upper = log_kappa_upper;
2665 for &(slot, kappa_seed) in &cc_sign_seeds {
2666 if kappa_seed != 0.0 {
2667 log_kappa_lower.set_scalar_slot(slot, kappa_seed);
2668 log_kappa_upper.set_scalar_slot(slot, kappa_seed);
2669 }
2670 log::info!(
2671 "[#1464-trace] slot {slot}: FROZE joint ψ coordinate at κ_seed={kappa_seed} \
2672 (window [{}, {}]); raw fit_score is sign-blind so the κ-fair scan is authoritative",
2673 log_kappa_lower.as_array()[log_kappa_lower.dims_per_term()[..slot].iter().sum::<usize>()],
2674 log_kappa_upper.as_array()[log_kappa_upper.dims_per_term()[..slot].iter().sum::<usize>()],
2675 );
2676 }
2677 let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
2680 let setup = ExactJointHyperSetup::new(
2681 best.fit.lambdas.mapv(f64::ln),
2682 Array1::<f64>::from_elem(rho_dim, -JOINT_RHO_BOUND),
2683 Array1::<f64>::from_elem(rho_dim, rho_upper_bound),
2684 log_kappa0,
2685 log_kappa_lower,
2686 log_kappa_upper,
2687 );
2688
2689 let theta0 = setup.theta0();
2690 let lower = setup.lower();
2691 let upper = setup.upper();
2692
2693 let kind = if use_aniso {
2705 SpatialHyperKind::Anisotropic
2706 } else {
2707 SpatialHyperKind::Isotropic
2708 };
2709 let (outcome, kappa_timing) = run_exact_joint_spatial_optimization(
2710 kind,
2711 data,
2712 y,
2713 weights,
2714 offset,
2715 resolvedspec,
2716 &best.design,
2717 family.clone(),
2718 options,
2719 spatial_terms,
2720 &dims_per_term,
2721 &theta0,
2722 &lower,
2723 &upper,
2724 rho_dim,
2725 kappa_options,
2726 )?;
2727
2728 let baseline_score = fit_score(&best.fit);
2729
2730 let (theta_star, joint_final_value) = match outcome {
2740 SpatialJointOutcome::Optimized {
2741 theta_star,
2742 final_value,
2743 } => (theta_star, final_value),
2744 SpatialJointOutcome::NonConverged {
2745 iterations,
2746 final_value,
2747 final_grad_norm,
2748 } => {
2749 if has_constant_curvature_term {
2750 log::info!(
2751 "[#1464-trace] joint solve NONCONVERGED (iters={iterations}, \
2752 final_value={final_value}); returning FROZEN BASELINE geometry \
2753 (κ̂ = spec default, NOT the joint candidate)"
2754 );
2755 }
2756 log::info!(
2757 "[spatial-kappa] joint spatial optimization did not converge \
2758 (iterations={}, final_objective={:.6e}, final_grad_norm={}); \
2759 keeping the frozen baseline geometry",
2760 iterations,
2761 final_value,
2762 final_grad_norm.map_or_else(|| "n/a".to_string(), |g| format!("{g:.3e}")),
2763 );
2764 return Ok(Some(fit_frozen_baseline_geometry(
2765 data,
2766 y,
2767 weights,
2768 offset,
2769 resolvedspec,
2770 best,
2771 family,
2772 options,
2773 baseline_score,
2774 Some(kappa_timing),
2775 )?));
2776 }
2777 };
2778
2779 let accept_tol = options.tol.max(1e-8 * baseline_score.abs()).max(1e-12);
2784 if joint_final_value > baseline_score + accept_tol {
2785 if has_constant_curvature_term {
2786 log::info!(
2787 "[#1464-trace] joint candidate WORSENED score (joint={joint_final_value}, \
2788 baseline={baseline_score}); returning FROZEN BASELINE geometry \
2789 (κ̂ = spec default, NOT the joint candidate)"
2790 );
2791 }
2792 log::info!(
2793 "[spatial-kappa] exact joint spatial candidate worsened the profiled score (joint={:.6e}, baseline={:.6e}, tol={:.2e}); keeping the frozen baseline geometry",
2794 joint_final_value,
2795 baseline_score,
2796 accept_tol,
2797 );
2798 return Ok(Some(fit_frozen_baseline_geometry(
2799 data,
2800 y,
2801 weights,
2802 offset,
2803 resolvedspec,
2804 best,
2805 family,
2806 options,
2807 baseline_score,
2808 Some(kappa_timing),
2809 )?));
2810 }
2811
2812 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
2813 let log_kappa_star =
2814 SpatialLogKappaCoords::from_theta_tail_with_dims(&theta_star, rho_dim, dims_per_term);
2815 if has_constant_curvature_term {
2821 let star = log_kappa_star.as_array();
2822 let dims = log_kappa_star.dims_per_term();
2823 for (slot, &term_idx) in spatial_terms.iter().enumerate() {
2824 if constant_curvature_term_spec(resolvedspec, term_idx).is_some() {
2825 let off: usize = dims[..slot].iter().sum();
2826 log::info!(
2827 "[#1464-trace] term {term_idx}: joint solver CONVERGED ψ-tail κ = {} \
2828 (this is the optimised candidate; joint_final_value={joint_final_value})",
2829 star[off]
2830 );
2831 }
2832 }
2833 }
2834 let baseline_spec = resolvedspec;
2838 let optimized_spec = log_kappa_star.apply_tospec(resolvedspec, spatial_terms)?;
2839 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
2840 data,
2841 y,
2842 weights,
2843 offset,
2844 &optimized_spec,
2845 rho_star.as_slice(),
2846 family.clone(),
2847 options,
2848 )?;
2849
2850 let optimized_edf = optimized.fit.inference.as_ref().map(|inf| inf.edf_total);
2864 if let Some(opt_edf) = optimized_edf
2865 && opt_edf < SPATIAL_COLLAPSE_EDF_FLOOR
2866 {
2867 let baseline = fit_frozen_baseline_geometry(
2868 data,
2869 y,
2870 weights,
2871 offset,
2872 baseline_spec,
2873 best,
2874 family.clone(),
2875 options,
2876 baseline_score,
2877 Some(kappa_timing),
2878 )?;
2879 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
2880 if let Some(base_edf) = baseline_edf
2881 && base_edf >= opt_edf + SPATIAL_COLLAPSE_EDF_MARGIN
2882 {
2883 log::info!(
2884 "[spatial-kappa] joint candidate collapsed to the null (edf={opt_edf:.3}); \
2885 baseline geometry retains edf={base_edf:.3} — keeping the frozen baseline",
2886 );
2887 return Ok(Some(baseline));
2888 }
2889 }
2892
2893 let mut fit = optimized.fit;
2897 fit.reml_score = joint_final_value;
2898 let optimized_result = FittedTermCollectionWithSpec {
2899 fit,
2900 design: optimized.design,
2901 resolvedspec: optimized_spec,
2902 adaptive_diagnostics: optimized.adaptive_diagnostics,
2903 kappa_timing: Some(kappa_timing),
2904 };
2905
2906 Ok(Some(optimized_result))
2907}
2908
/// EDF threshold below which a spatial fit is treated as having collapsed
/// toward the null model. A collapsed fit is only replaced when an alternative
/// fit retains at least `SPATIAL_COLLAPSE_EDF_MARGIN` more EDF.
const SPATIAL_COLLAPSE_EDF_FLOOR: f64 = 2.5;

/// Minimum EDF advantage an alternative fit must hold over a collapsed fit
/// (EDF below `SPATIAL_COLLAPSE_EDF_FLOOR`) before the collapsed fit is replaced.
const SPATIAL_COLLAPSE_EDF_MARGIN: f64 = 1.0;
2919
2920fn fit_frozen_baseline_geometry(
2956 data: ArrayView2<'_, f64>,
2957 y: ArrayView1<'_, f64>,
2958 weights: ArrayView1<'_, f64>,
2959 offset: ArrayView1<'_, f64>,
2960 resolvedspec: &TermCollectionSpec,
2961 best: &FittedTermCollection,
2962 family: LikelihoodSpec,
2963 options: &FitOptions,
2964 baseline_score: f64,
2965 kappa_timing: Option<SpatialLengthScaleOptimizationTiming>,
2966) -> Result<FittedTermCollectionWithSpec, EstimationError> {
2967 let baseline = fit_term_collection_forspecwith_heuristic_lambdas(
2968 data,
2969 y,
2970 weights,
2971 offset,
2972 resolvedspec,
2973 best.fit.lambdas.as_slice(),
2974 family.clone(),
2975 options,
2976 )?;
2977 let best_edf = best.fit.inference.as_ref().map(|inf| inf.edf_total);
2982 let baseline_edf = baseline.fit.inference.as_ref().map(|inf| inf.edf_total);
2983 let baseline = match (best_edf, baseline_edf) {
2984 (Some(best_edf), Some(base_edf))
2985 if base_edf < SPATIAL_COLLAPSE_EDF_FLOOR
2986 && best_edf >= base_edf + SPATIAL_COLLAPSE_EDF_MARGIN =>
2987 {
2988 log::info!(
2989 "[spatial-kappa] warm-started frozen baseline collapsed (edf={base_edf:.3}) \
2990 below the certified baseline (edf={best_edf:.3}); refitting from scratch",
2991 );
2992 fit_term_collection_forspec(data, y, weights, offset, resolvedspec, family, options)?
2993 }
2994 _ => baseline,
2995 };
2996 let mut fit = baseline.fit;
2997 fit.reml_score = baseline_score;
2998 Ok(FittedTermCollectionWithSpec {
2999 fit,
3000 design: baseline.design,
3001 resolvedspec: resolvedspec.clone(),
3002 adaptive_diagnostics: baseline.adaptive_diagnostics,
3003 kappa_timing,
3004 })
3005}
3006
/// Which spatial hyperparameter parameterisation the exact joint optimizer is
/// driving: per-axis ψ coordinates (anisotropic) or a κ coordinate (isotropic).
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum SpatialHyperKind {
    /// Anisotropic ψ coordinates.
    Anisotropic,
    /// Isotropic κ coordinates.
    Isotropic,
}

impl SpatialHyperKind {
    /// Single exhaustive table of this kind's display strings, in the order
    /// `(label, adjective, coord_name)`.
    fn name_table(self) -> (&'static str, &'static str, &'static str) {
        match self {
            Self::Anisotropic => ("spatial-aniso-joint", "anisotropic", "psi"),
            Self::Isotropic => ("spatial-iso-joint", "isotropic", "kappa"),
        }
    }

    /// Log/stage label identifying the joint optimizer lane.
    fn label(self) -> &'static str {
        let (label, _, _) = self.name_table();
        label
    }

    /// Human-readable adjective used in error messages.
    fn adjective(self) -> &'static str {
        let (_, adjective, _) = self.name_table();
        adjective
    }

    /// Name of the spatial coordinate this kind optimizes.
    fn coord_name(self) -> &'static str {
        let (_, _, coord) = self.name_table();
        coord
    }
}
3051
/// Owned copies of the GLM inputs needed to re-evaluate the family's
/// observations (Fisher weights, score, working response) at a warm β; read by
/// `SpatialJointContext::frozen_glm_working_state`.
struct SpatialFrozenGlmInputs {
    /// Response vector.
    y: Array1<f64>,
    /// Observation weights passed to the family evaluation.
    weights: Array1<f64>,
    /// Linear-predictor offset added to Xβ before the family evaluation.
    offset: Array1<f64>,
    /// Likelihood family used to evaluate the observations.
    family: LikelihoodSpec,
}
3063
3064fn frozen_glm_tensor_eligible_family(family: &LikelihoodSpec) -> bool {
3081 !family.is_gaussian_identity()
3082 && matches!(
3083 &family.response,
3084 ResponseFamily::Binomial
3085 | ResponseFamily::Poisson
3086 | ResponseFamily::Gamma
3087 | ResponseFamily::NegativeBinomial { .. }
3088 )
3089}
3090
/// Mutable state shared by the exact joint (ρ, spatial-coordinate) outer
/// optimizer. It holds the design cache realized at a trial θ, the joint REML
/// evaluator, and the optional frozen-W GLM ψ-tensor accelerator.
struct SpatialJointContext<'d> {
    /// Raw data matrix; used to rebuild spatial hyper-directions.
    data: ArrayView2<'d, f64>,
    /// Number of leading ρ (log-λ) coordinates in θ; later coordinates are the
    /// spatial log-κ / ψ coordinates.
    rho_dim: usize,
    /// Isotropic vs anisotropic parameterisation (labels and hyper-dirs).
    kind: SpatialHyperKind,
    /// Design cache re-realized at a trial θ via `ensure_theta`.
    cache: SingleBlockExactJointDesignCache<'d>,
    /// Joint REML evaluator (criterion, gradient, Hessian, cost-only probes).
    evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
    /// Owned inputs for recomputing GLM working state; `None` disables the
    /// frozen-W GLM tensor lane.
    frozen_glm_inputs: Option<SpatialFrozenGlmInputs>,
    /// ψ window `(lo, hi)` over which the frozen-W GLM tensor is built.
    frozen_glm_psi_bounds: Option<(f64, f64)>,
    /// Lazily built frozen-W GLM Gram tensor, present only if it certified.
    frozen_glm_tensor: Option<gam_solve::glm_sufficient_lane::FrozenWeightGramTensor>,
    /// Set once a tensor build was attempted or ruled out, so it is not retried.
    frozen_glm_tensor_attempted: bool,
    /// Memo `(key, weights)` of the most recently computed frozen-W trial
    /// weights; see `frozen_glm_trial_weights` for how the key is formed.
    frozen_glm_weight_memo: Option<(Array1<f64>, Array1<f64>)>,
}
3114
/// Per-gate outcome of the "serve this trial n-free" decision. All gates must
/// pass (and second-order geometry must not be requested) before a trial may
/// skip re-realizing the n×k design.
#[derive(Clone, Copy, Debug, Default)]
struct NfreeSkipGateStatus {
    /// θ carries exactly one spatial coordinate after the ρ block.
    shape: bool,
    /// The ψ-gram tensor covers ψ for value evaluation and for skipping.
    value: bool,
    /// The tensor's gradient lane covers ψ (or no gradient was required).
    gradient: bool,
    /// The evaluator supports the exact n-free penalty re-key.
    penalty: bool,
    /// The evaluator exposes an n-free fast-path revision.
    revision: bool,
    /// Second-order geometry was requested; this alone forbids skipping.
    second_order: bool,
}

impl NfreeSkipGateStatus {
    /// Whether every gate passes. The gradient gate is consulted only when
    /// `require_gradient` is set.
    fn would_skip(self, require_gradient: bool) -> bool {
        let gradient_ok = self.gradient || !require_gradient;
        let infrastructure_ready = self.shape && self.penalty && self.revision;
        infrastructure_ready && self.value && gradient_ok && !self.second_order
    }
}
3135
3136impl<'d> SpatialJointContext<'d> {
3137 fn nfree_skip_gate_status(
3138 &self,
3139 theta: &Array1<f64>,
3140 allow_second_order: bool,
3141 require_gradient: bool,
3142 ) -> NfreeSkipGateStatus {
3143 let shape = theta.len() == self.rho_dim + 1;
3144 let (value, gradient) = if shape {
3145 let psi = theta[self.rho_dim];
3146 (
3147 self.evaluator.psi_gram_tensor_covers(psi)
3148 && self.evaluator.psi_gram_tensor_covers_skip(psi),
3149 !require_gradient || self.evaluator.psi_gram_tensor_covers_gradient(psi),
3150 )
3151 } else {
3152 (false, false)
3153 };
3154 NfreeSkipGateStatus {
3155 shape,
3156 value,
3157 gradient,
3158 penalty: self.evaluator.supports_nfree_penalty_rekey(),
3159 revision: self.evaluator.nfree_fast_path_revision().is_some(),
3160 second_order: allow_second_order,
3161 }
3162 }
3163
3164 fn frozen_glm_working_state(
3165 &self,
3166 beta: &Array1<f64>,
3167 ) -> Result<Option<(Array1<f64>, Array1<f64>)>, EstimationError> {
3168 let Some(inputs) = self.frozen_glm_inputs.as_ref() else {
3169 return Ok(None);
3170 };
3171 if beta.len() != self.cache.design().design.ncols() {
3172 return Ok(None);
3173 }
3174 let mut eta = self.cache.design().design.matrixvectormultiply(beta);
3175 if eta.len() != inputs.offset.len() {
3176 crate::bail_invalid_estim!(
3177 "frozen GLM tensor warm-state row mismatch: eta={}, offset={}",
3178 eta.len(),
3179 inputs.offset.len()
3180 );
3181 }
3182 eta += &inputs.offset;
3183 let obs = evaluate_standard_familyobservations(
3184 inputs.family.clone(),
3185 None,
3186 None,
3187 None,
3188 &inputs.y,
3189 &inputs.weights,
3190 &eta,
3191 )?;
3192 let mut working_response = obs.eta.clone();
3193 for i in 0..working_response.len() {
3194 let wi = obs.fisherweight[i].max(1e-12);
3195 working_response[i] += obs.score[i] / wi;
3196 }
3197 Ok(Some((obs.fisherweight, working_response)))
3198 }
3199
3200 fn frozen_glm_trial_weights(
3209 &mut self,
3210 beta: &Array1<f64>,
3211 ) -> Result<Option<Array1<f64>>, EstimationError> {
3212 if let Some((memo_beta, memo_w)) = self.frozen_glm_weight_memo.as_ref()
3213 && memo_beta.len() == beta.len()
3214 && memo_beta
3215 .iter()
3216 .zip(beta.iter())
3217 .all(|(a, b)| a.to_bits() == b.to_bits())
3218 {
3219 return Ok(Some(memo_w.clone()));
3220 }
3221 match self.frozen_glm_working_state(beta)? {
3222 Some((current_w, _)) => {
3223 self.frozen_glm_weight_memo = Some((beta.clone(), current_w.clone()));
3224 Ok(Some(current_w))
3225 }
3226 None => Ok(None),
3227 }
3228 }
3229
3230 fn ensure_frozen_glm_tensor(
3231 &mut self,
3232 theta: &Array1<f64>,
3233 warm_beta: Option<&Array1<f64>>,
3234 ) -> Result<(), EstimationError> {
3235 if self.frozen_glm_tensor.is_some() || self.frozen_glm_tensor_attempted {
3236 return Ok(());
3237 }
3238 let Some((psi_lo, psi_hi)) = self.frozen_glm_psi_bounds else {
3239 return Ok(());
3240 };
3241 if theta.len() != self.rho_dim + 1 {
3242 self.frozen_glm_tensor_attempted = true;
3243 return Ok(());
3244 }
3245 let Some(beta) = warm_beta else {
3246 return Ok(());
3247 };
3248 let Some((frozen_w, working_z)) = self.frozen_glm_working_state(beta)? else {
3249 self.frozen_glm_tensor_attempted = true;
3250 return Ok(());
3251 };
3252 let theta_probe_base = theta.clone();
3253 let rho_dim = self.rho_dim;
3254 let Self {
3261 cache, evaluator, ..
3262 } = self;
3263 let tensor = evaluator.build_frozen_glm_gram_tensor(
3264 |psi| {
3265 let mut theta_probe = theta_probe_base.clone();
3266 theta_probe[rho_dim] = psi;
3267 cache.ensure_theta(&theta_probe)?;
3268 Ok(cache.design().design.clone())
3269 },
3270 frozen_w.view(),
3271 working_z.view(),
3272 psi_lo,
3273 psi_hi,
3274 );
3275 self.cache
3276 .ensure_theta(theta)
3277 .map_err(EstimationError::InvalidInput)?;
3278 self.frozen_glm_tensor_attempted = true;
3279 if let Some(tensor) = tensor {
3280 self.frozen_glm_tensor = Some(tensor);
3281 log::info!(
3282 "[STAGE] {} certified frozen-W GLM ψ tensor over [{psi_lo:.3}, {psi_hi:.3}]",
3283 self.kind.label(),
3284 );
3285 } else {
3286 log::info!(
3287 "[STAGE] {} frozen-W GLM ψ tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]",
3288 self.kind.label(),
3289 );
3290 }
3291 Ok(())
3292 }
3293
3294 fn stage_frozen_glm_trial_statistics(
3295 &mut self,
3296 theta: &Array1<f64>,
3297 warm_beta: Option<&Array1<f64>>,
3298 allow_gradient: bool,
3299 ) -> Result<(), EstimationError> {
3300 let kind = self.kind;
3301 let mut staged_gram: Option<Array2<f64>> = None;
3302 let mut staged_deriv: Option<(Array2<f64>, Array1<f64>)> = None;
3303 if theta.len() == self.rho_dim + 1 {
3304 let psi = theta[self.rho_dim];
3305 let tensor_covers = self
3312 .frozen_glm_tensor
3313 .as_ref()
3314 .is_some_and(|t| t.contains(psi));
3315 let current_w = if tensor_covers {
3316 match warm_beta {
3317 Some(beta) => self.frozen_glm_trial_weights(beta)?,
3318 None => None,
3319 }
3320 } else {
3321 None
3322 };
3323 if let (Some(tensor), Some(current_w)) =
3324 (self.frozen_glm_tensor.as_ref(), current_w.as_ref())
3325 {
3326 const FROZEN_GLM_WEIGHT_DRIFT_RTOL: f64 = 1e-3;
3327 if tensor.weight_drift_within(current_w.view(), FROZEN_GLM_WEIGHT_DRIFT_RTOL) {
3328 staged_gram = Some(tensor.gram_at(psi));
3329 log::debug!(
3330 "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
3331 first-Fisher-step XᵀWX n-free (weight drift within tol)",
3332 kind.label(),
3333 );
3334 }
3335 if allow_gradient
3336 && tensor.contains_for_gradient(psi)
3337 && let Some((dgram_dpsi, drhs_dpsi)) =
3338 tensor.gradient_pair_if_sound(psi, current_w.view())
3339 {
3340 staged_deriv = Some((dgram_dpsi, drhs_dpsi));
3341 log::debug!(
3342 "[STAGE] {} trial at psi={psi:.6}: serving frozen-W GLM \
3343 ψ-gradient (∂G/∂ψ, ∂b/∂ψ) n-free (gradient weight drift within \
3344 tight tol); B_j stays exact",
3345 kind.label(),
3346 );
3347 }
3348 }
3349 }
3350 self.evaluator.stage_glm_first_step_gram(staged_gram);
3351 self.evaluator.stage_glm_psi_gram_deriv(staged_deriv);
3352 Ok(())
3353 }
3354
3355 fn eval_full(
3357 &mut self,
3358 theta: &Array1<f64>,
3359 order: gam_solve::rho_optimizer::OuterEvalOrder,
3360 analytic_outer_hessian_available: bool,
3361 ) -> Result<
3362 (
3363 f64,
3364 Array1<f64>,
3365 gam_problem::HessianResult,
3366 ),
3367 EstimationError,
3368 > {
3369 use gam_solve::rho_optimizer::OuterEvalOrder;
3370 let allow_second_order = matches!(order, OuterEvalOrder::ValueGradientHessian)
3371 && analytic_outer_hessian_available;
3372 if let Some(eval) = self.cache.memoized_eval(theta) {
3373 let cached_satisfies_order = !allow_second_order || eval.2.is_analytic();
3374 if cached_satisfies_order {
3375 return Ok(eval);
3376 }
3377 }
3378 let kind = self.kind;
3379 let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
3415 let skip_design_realization = !allow_second_order && theta.len() == self.rho_dim + 1 && {
3416 let psi = theta[self.rho_dim];
3417 self.evaluator.psi_gram_tensor_covers(psi)
3418 && self.evaluator.psi_gram_tensor_covers_gradient(psi)
3425 && self.evaluator.psi_gram_tensor_covers_skip(psi)
3442 && self.evaluator.supports_nfree_penalty_rekey()
3447 && nfree_fast_path_revision.is_some()
3448 };
3449 if skip_design_realization {
3450 log::debug!(
3451 "[STAGE] {} eval_full at psi={:.6}: skipping n×k design re-realization \
3452 + reconditioning — criterion/gradient/inner-solve served n-free from \
3453 the certified ψ-gram tensor (GaussianFixedCache + k-space ψ-derivatives)",
3454 kind.label(),
3455 theta[self.rho_dim],
3456 );
3457 } else {
3458 self.cache
3459 .ensure_theta(theta)
3460 .map_err(EstimationError::InvalidInput)?;
3461 }
3462 let warm_beta = self.evaluator.current_beta();
3463 self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref())?;
3464 self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), !allow_second_order)?;
3472 let hyper_dirs = if skip_design_realization {
3479 self.cache.nfree_tensor_gradient_hyper_dirs(theta)?
3480 } else {
3481 self.cache.hyper_dirs_for_current_design(self.data, kind)?
3482 };
3483
3484 let design_revision = if skip_design_realization {
3485 nfree_fast_path_revision
3486 } else {
3487 Some(self.cache.design_revision())
3488 };
3489 if self.evaluator.supports_nfree_penalty_rekey() {
3503 match self.cache.canonical_penalties_at(theta) {
3504 Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
3505 Err(e) => {
3506 log::warn!(
3507 "[STAGE] {} eval_full at psi={:.6}: exact n-free S(ψ) rebuild failed \
3508 ({e}); clearing stage (eval falls to slow path)",
3509 kind.label(),
3510 theta[self.rho_dim],
3511 );
3512 self.evaluator.stage_fast_path_penalty(None);
3513 }
3514 }
3515 }
3516 let eval = evaluate_joint_reml_outer_eval_at_theta(
3523 &mut self.evaluator,
3524 self.cache.design(),
3525 theta,
3526 self.rho_dim,
3527 hyper_dirs,
3528 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3529 if allow_second_order {
3530 order
3531 } else {
3532 OuterEvalOrder::ValueAndGradient
3533 },
3534 design_revision,
3535 );
3536 if let Ok(ref value) = eval {
3537 self.cache.store_eval_at(theta, value.clone());
3538 }
3539 eval
3540 }
3541
3542 fn eval_efs(
3543 &mut self,
3544 theta: &Array1<f64>,
3545 ) -> Result<gam_problem::EfsEval, EstimationError> {
3546 self.cache
3547 .ensure_theta(theta)
3548 .map_err(EstimationError::InvalidInput)?;
3549 let kind = self.kind;
3550 let hyper_dirs = try_build_spatial_log_kappa_hyper_dirs(
3551 self.data,
3552 self.cache.spec(),
3553 self.cache.design(),
3554 &self.cache.spatial_terms,
3555 )?
3556 .ok_or_else(|| {
3557 EstimationError::InvalidInput(format!(
3558 "failed to build {} hyper_dirs for exact-joint EFS",
3559 kind.adjective(),
3560 ))
3561 })?;
3562 let design_revision = Some(self.cache.design_revision());
3563 let warm_beta = self.evaluator.current_beta();
3564 evaluate_joint_reml_efs_at_theta(
3565 &mut self.evaluator,
3566 self.cache.design(),
3567 theta,
3568 self.rho_dim,
3569 hyper_dirs,
3570 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3571 design_revision,
3572 )
3573 }
3574
3575 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
3581 if let Some(cost) = self.cache.memoized_cost(theta) {
3582 return cost;
3583 }
3584 let probe_start = std::time::Instant::now();
3599 let psi_distance = self
3600 .cache
3601 .current_theta
3602 .as_ref()
3603 .filter(|reference| reference.len() == theta.len())
3604 .map(|reference| {
3605 reference
3606 .iter()
3607 .zip(theta.iter())
3608 .map(|(a, b)| (a - b) * (a - b))
3609 .sum::<f64>()
3610 .sqrt()
3611 })
3612 .unwrap_or(f64::NAN);
3613 let nfree_fast_path_revision = self.evaluator.nfree_fast_path_revision();
3627 let skip_value_realization = theta.len() == self.rho_dim + 1 && {
3628 let psi = theta[self.rho_dim];
3629 self.evaluator.psi_gram_tensor_covers(psi)
3630 && self.evaluator.psi_gram_tensor_covers_skip(psi)
3639 && self.evaluator.supports_nfree_penalty_rekey()
3644 && nfree_fast_path_revision.is_some()
3645 };
3646 if theta.len() == self.rho_dim + 1
3647 && self.evaluator.has_psi_gram_tensor()
3648 && !self.evaluator.psi_gram_tensor_covers(theta[self.rho_dim])
3649 {
3650 self.cache.store_cost_at(theta, f64::INFINITY);
3651 return f64::INFINITY;
3652 }
3653 if !skip_value_realization && self.cache.ensure_theta(theta).is_err() {
3654 return f64::INFINITY;
3655 }
3656 if self.evaluator.supports_nfree_penalty_rekey() {
3662 match self.cache.canonical_penalties_at(theta) {
3663 Ok(penalty) => self.evaluator.stage_fast_path_penalty(Some(penalty)),
3664 Err(_) => self.evaluator.stage_fast_path_penalty(None),
3665 }
3666 }
3667 let warm_beta = self.evaluator.current_beta();
3668 if let Err(err) = self.ensure_frozen_glm_tensor(theta, warm_beta.as_ref()) {
3669 log::warn!(
3670 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM tensor setup failed ({err}); \
3671 falling back to exact streamed Gram",
3672 self.kind.label(),
3673 if theta.len() > self.rho_dim {
3674 theta[self.rho_dim]
3675 } else {
3676 f64::NAN
3677 },
3678 );
3679 self.evaluator.stage_glm_first_step_gram(None);
3680 self.evaluator.stage_glm_psi_gram_deriv(None);
3681 } else if let Err(err) =
3682 self.stage_frozen_glm_trial_statistics(theta, warm_beta.as_ref(), false)
3683 {
3684 log::warn!(
3685 "[STAGE] {} value-probe at psi={:.6}: frozen-W GLM staging failed ({err}); \
3686 falling back to exact streamed Gram",
3687 self.kind.label(),
3688 if theta.len() > self.rho_dim {
3689 theta[self.rho_dim]
3690 } else {
3691 f64::NAN
3692 },
3693 );
3694 self.evaluator.stage_glm_first_step_gram(None);
3695 self.evaluator.stage_glm_psi_gram_deriv(None);
3696 }
3697 let design_revision = if skip_value_realization {
3698 nfree_fast_path_revision
3699 } else {
3700 Some(self.cache.design_revision())
3701 };
3702 let cost_label = self.kind.label();
3703 let result = {
3704 let design = self.cache.design();
3705 self.evaluator.evaluate_cost_only(
3706 &design.design,
3707 &design.penalties,
3708 &design.nullspace_dims,
3709 design.linear_constraints.clone(),
3710 theta,
3711 self.rho_dim,
3712 warm_beta.as_ref().map(|b: &Array1<f64>| b.view()),
3713 cost_label,
3714 design_revision,
3715 )
3716 };
3717 match result {
3718 Ok(cost) => {
3719 log::debug!(
3720 "[STAGE] {cost_label} value-probe (order=Value): elapsed={:.3}s \
3721 cost={cost:.6e} trial_theta_distance={psi_distance:.3e}",
3722 probe_start.elapsed().as_secs_f64(),
3723 );
3724 self.cache.store_cost_at(theta, cost);
3725 cost
3726 }
3727 Err(_) => f64::INFINITY,
3728 }
3729 }
3730
3731 fn reset(&mut self) {
3732 self.cache.current_theta = None;
3733 self.cache.last_eval_theta = None;
3734 self.cache.last_cost = None;
3735 self.cache.last_eval = None;
3736 }
3737}
3738
/// Result of the exact joint spatial outer optimization.
enum SpatialJointOutcome {
    /// The optimizer converged.
    Optimized {
        /// Optimal θ. The first `rho_dim` entries are ρ = log λ and the
        /// remaining entries are the spatial log-κ / ψ coordinates.
        theta_star: Array1<f64>,
        /// Joint objective at `theta_star`. The visible caller compares it with
        /// the baseline profiled score and treats lower as better.
        final_value: f64,
    },
    /// The optimizer stopped without converging; the visible caller falls back
    /// to the frozen baseline geometry.
    NonConverged {
        /// Iterations performed before stopping.
        iterations: usize,
        /// Objective value at the last iterate.
        final_value: f64,
        /// Gradient norm at the last iterate, when available.
        final_grad_norm: Option<f64>,
    },
}
3788
3789fn kphase_log_norms(theta: &Array1<f64>, rho_dim: usize) -> (f64, f64) {
3790 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
3791 let log_kappa_norm = theta
3792 .iter()
3793 .skip(rho_dim)
3794 .map(|v| v * v)
3795 .sum::<f64>()
3796 .sqrt();
3797 (theta_norm, log_kappa_norm)
3798}
3799
3800fn run_exact_joint_spatial_optimization(
3801 kind: SpatialHyperKind,
3802 data: ArrayView2<'_, f64>,
3803 y: ArrayView1<'_, f64>,
3804 weights: ArrayView1<'_, f64>,
3805 offset: ArrayView1<'_, f64>,
3806 resolvedspec: &TermCollectionSpec,
3807 baseline_design: &TermCollectionDesign,
3808 family: LikelihoodSpec,
3809 options: &FitOptions,
3810 spatial_terms: &[usize],
3811 dims_per_term: &[usize],
3812 theta0: &Array1<f64>,
3813 lower: &Array1<f64>,
3814 upper: &Array1<f64>,
3815 rho_dim: usize,
3816 kappa_options: &SpatialLengthScaleOptimizationOptions,
3817) -> Result<(SpatialJointOutcome, SpatialLengthScaleOptimizationTiming), EstimationError> {
3818 let label = kind.label();
3819 assert!(
3821 lower.len() == theta0.len() && upper.len() == theta0.len(),
3822 "spatial hyperparameter bounds must match theta length: lower_len={}, upper_len={}, theta_len={}",
3823 lower.len(),
3824 upper.len(),
3825 theta0.len()
3826 );
3827 assert!(
3828 baseline_design.smooth.terms.len() >= spatial_terms.len(),
3829 "baseline design must have at least one smooth term per spatial term: baseline_terms={}, spatial_terms={}",
3830 baseline_design.smooth.terms.len(),
3831 spatial_terms.len()
3832 );
3833 use gam_solve::rho_optimizer::OuterEvalOrder;
3834 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
3835
3836 let theta_dim = theta0.len();
3837 let coord_dim = theta_dim - rho_dim;
3840 let analytic_outer_hessian_available =
3850 exact_joint_spatial_outer_hessian_available(&family, baseline_design);
3851 if !analytic_outer_hessian_available {
3852 log::info!(
3853 "[{label}] analytic outer Hessian unavailable for family/design; routing without second-order geometry (coord_dim={coord_dim})"
3854 );
3855 }
3856 let mut prefer_gradient_only = theta_dim > EXACT_JOINT_SECOND_ORDER_THETA_CAP;
3862 if prefer_gradient_only {
3863 log::info!(
3864 "[{label}] joint θ-dim {theta_dim} exceeds the exact pair-Hessian budget \
3865 ({EXACT_JOINT_SECOND_ORDER_THETA_CAP}); routing gradient-only quasi-Newton"
3866 );
3867 }
3868 let mut suppress_outer_hessian_for_nfree = false;
3878
3879 log::trace!(
3880 "[{}] starting analytic optimization: rho_dim={}, coord_dim={}, dims_per_term={:?}",
3881 label,
3882 rho_dim,
3883 coord_dim,
3884 dims_per_term,
3885 );
3886
3887 let mut ctx = SpatialJointContext {
3888 data,
3889 rho_dim,
3890 kind,
3891 cache: SingleBlockExactJointDesignCache::new(
3892 data,
3893 resolvedspec.clone(),
3894 baseline_design.clone(),
3895 spatial_terms.to_vec(),
3896 rho_dim,
3897 dims_per_term.to_vec(),
3898 )
3899 .map_err(EstimationError::InvalidInput)?,
3900 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
3901 y,
3902 weights,
3903 &baseline_design.design,
3904 offset,
3905 &baseline_design.penalties,
3906 &external_opts_for_design(&family, baseline_design, options),
3907 label,
3908 )?,
3909 frozen_glm_inputs: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
3910 Some(SpatialFrozenGlmInputs {
3911 y: y.to_owned(),
3912 weights: weights.to_owned(),
3913 offset: offset.to_owned(),
3914 family: family.clone(),
3915 })
3916 } else {
3917 None
3918 },
3919 frozen_glm_psi_bounds: if coord_dim == 1 && frozen_glm_tensor_eligible_family(&family) {
3920 Some((lower[rho_dim], upper[rho_dim]))
3921 } else {
3922 None
3923 },
3924 frozen_glm_tensor: None,
3925 frozen_glm_tensor_attempted: false,
3926 frozen_glm_weight_memo: None,
3927 };
3928
3929 let nfree_penalty_capable = coord_dim == 1
3942 && family.is_gaussian_identity()
3943 && ctx.cache.supports_nfree_penalty_rekey();
3944 if nfree_penalty_capable {
3945 let psi_lo = lower[rho_dim];
3946 let psi_hi = upper[rho_dim];
3947 let z = Array1::from_iter(y.iter().zip(offset.iter()).map(|(yi, oi)| yi - oi));
3948 let theta_probe_base = theta0.clone();
3949 let SpatialJointContext {
3952 cache, evaluator, ..
3953 } = &mut ctx;
3954 let attached = evaluator.build_and_set_psi_gram_tensor(
3955 |psi| {
3956 let mut theta_probe = theta_probe_base.clone();
3957 theta_probe[rho_dim] = psi;
3958 cache.ensure_theta(&theta_probe)?;
3959 Ok(cache.design().design.clone())
3960 },
3961 weights,
3962 z.view(),
3963 psi_lo,
3964 psi_hi,
3965 );
3966 if attached {
3967 log::info!(
3968 "[{label}] certified ψ-gram tensor over [{psi_lo:.3}, {psi_hi:.3}]: \
3969 in-window trials assemble Gaussian sufficient statistics n-free"
3970 );
3971 let gradient_covers_full_window = evaluator.psi_gram_tensor_covers_gradient(psi_lo)
3972 && evaluator.psi_gram_tensor_covers_gradient(psi_hi);
3973 if gradient_covers_full_window {
3974 log::info!(
3975 "[{label}] certified ψ-gram tensor gradient lane covers the full \
3976 optimizer window [{psi_lo:.3}, {psi_hi:.3}]"
3977 );
3978 } else {
3979 log::info!(
3980 "[{label}] ψ-gram tensor value lane certified, but the gradient lane \
3981 does not cover the full optimizer window [{psi_lo:.3}, {psi_hi:.3}]; \
3982 keeping exact streamed kappa routing"
3983 );
3984 }
3985 evaluator.set_supports_nfree_penalty_rekey(true);
4005 log::info!(
4006 "[{label}] exact n-free ψ-penalty re-key enabled over [{psi_lo:.3}, \
4007 {psi_hi:.3}]: in-window fast-path trials rebuild S(ψ) n-free from frozen \
4008 geometry (no reset_surface)"
4009 );
4010 } else {
4011 log::info!(
4012 "[{label}] ψ-gram tensor did not certify over [{psi_lo:.3}, {psi_hi:.3}]; \
4013 keeping the exact per-trial path"
4014 );
4015 }
4016 if attached
4037 && evaluator.psi_gram_tensor_covers_gradient(psi_lo)
4038 && evaluator.psi_gram_tensor_covers_gradient(psi_hi)
4039 && evaluator.supports_nfree_penalty_rekey()
4040 && cache.supports_nfree_gradient_only_routing()
4041 {
4042 suppress_outer_hessian_for_nfree = true;
4043 prefer_gradient_only = true;
4044 log::info!(
4045 "[{label}] n-free Gaussian ψ-lane armed; suppressing the analytic outer \
4046 Hessian and routing gradient-only (BFGS) so the κ outer loop never realizes \
4047 the O(n) second-order slab — n-independent outer loop (#1033)"
4048 );
4049 }
4050 } else if coord_dim == 1 && family.is_gaussian_identity() {
4051 log::info!(
4052 "[{label}] exact n-free ψ-penalty re-key unavailable; skipping ψ-gram tensor \
4053 attachment so value, gradient, and Hessian remain on the same exact streamed \
4054 objective"
4055 );
4056 }
4057
4058 const OUTER_FD_AUDIT_MAX_N: usize = 4_000; const OUTER_FD_AUDIT_MAX_THETA_DIM: usize = 32; let n_total = data.nrows();
4075 let outer_fd_audit_eligible = analytic_outer_hessian_available && n_total <= OUTER_FD_AUDIT_MAX_N && theta_dim <= OUTER_FD_AUDIT_MAX_THETA_DIM; log::warn!(
4079 "[OUTER-FD-AUDIT/spatial-exact-joint] gate eligible={outer_fd_audit_eligible} \
4080 analytic_grad={analytic_outer_hessian_available} n_total={n_total} \
4081 theta_dim={theta_dim} rho_dim={rho_dim} psi_dim={coord_dim}"
4082 );
4083 if outer_fd_audit_eligible {
4084 let audit = (|| -> Result<gam_solve::rho_optimizer::OuterGradientFdAudit, String> {
4086 let mut eval_at = |theta: &Array1<f64>,
4087 mode: gam_solve::estimate::reml::reml_outer_engine::EvalMode|
4088 -> Result<
4089 (
4090 f64,
4091 Array1<f64>,
4092 gam_problem::HessianResult,
4093 ),
4094 String,
4095 > {
4096 use gam_solve::estimate::reml::reml_outer_engine::EvalMode;
4097 let order = if matches!(mode, EvalMode::ValueGradientHessian) {
4098 OuterEvalOrder::ValueGradientHessian
4099 } else {
4100 OuterEvalOrder::Value
4101 };
4102 ctx.eval_full(theta, order, analytic_outer_hessian_available)
4103 .map_err(|e| format!("fd-audit eval_full: {e}"))
4104 };
4105 let rho_dim_audit = rho_dim;
4106 let label_fn = move |i: usize| -> String {
4107 if i < rho_dim_audit {
4108 format!("rho[{i}]")
4109 } else {
4110 format!("psi_kappa[{}]", i - rho_dim_audit)
4111 }
4112 };
4113 gam_solve::rho_optimizer::outer_gradient_fd_audit(
4114 theta0,
4116 1e-4,
4117 label_fn,
4118 &mut eval_at,
4119 )
4120 })();
4121 match audit {
4123 Ok(audit) => audit.log_verdict("spatial-exact-joint"),
4124 Err(e) => log::warn!("[OUTER-FD-AUDIT/spatial-exact-joint] skipped: {e}"),
4125 }
4126 }
4127
4128 let kphase_prime_order = if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4129 OuterEvalOrder::ValueGradientHessian
4130 } else {
4131 OuterEvalOrder::ValueAndGradient
4132 };
4133 let kphase_prime_start = std::time::Instant::now();
4134 drop(ctx.eval_full(theta0, kphase_prime_order, analytic_outer_hessian_available)?);
4135 log::info!(
4136 "[KAPPA-PHASE-PRIME] order={:?} elapsed_s={:.4} slow_path_resets_total={} design_revision={}",
4137 kphase_prime_order,
4138 kphase_prime_start.elapsed().as_secs_f64(),
4139 ctx.evaluator.slow_path_reset_count(),
4140 ctx.cache.design_revision(),
4141 );
4142
4143 let kphase_cost_calls = std::cell::Cell::new(0usize);
4144 let kphase_eval_calls = std::cell::Cell::new(0usize);
4145 let kphase_efs_calls = std::cell::Cell::new(0usize);
4146 let kphase_cost_total_s = std::cell::Cell::new(0.0);
4147 let kphase_eval_total_s = std::cell::Cell::new(0.0);
4148 let kphase_efs_total_s = std::cell::Cell::new(0.0);
4149 let kphase_nfree_miss_shape = std::cell::Cell::new(0u64);
4150 let kphase_nfree_miss_value = std::cell::Cell::new(0u64);
4151 let kphase_nfree_miss_gradient = std::cell::Cell::new(0u64);
4152 let kphase_nfree_miss_penalty = std::cell::Cell::new(0u64);
4153 let kphase_nfree_miss_revision = std::cell::Cell::new(0u64);
4154 let kphase_nfree_miss_second_order = std::cell::Cell::new(0u64);
4155 let kphase_nfree_miss_other = std::cell::Cell::new(0u64);
4156 let kphase_optim_start = std::time::Instant::now();
4157 let kphase_log_kappa_dim = coord_dim;
4158 let kphase_slow_resets_start = ctx.evaluator.slow_path_reset_count();
4159 let kphase_design_revision_start = ctx.cache.design_revision();
4160
4161 let problem = exact_joint_multistart_outer_problem(
4162 theta0,
4163 lower,
4164 upper,
4165 rho_dim,
4166 coord_dim,
4167 theta_dim,
4168 Derivative::Analytic,
4169 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4170 DeclaredHessianForm::Either
4171 } else {
4172 DeclaredHessianForm::Unavailable
4177 },
4178 prefer_gradient_only,
4179 suppress_outer_hessian_for_nfree,
4190 seed_risk_profile_for_likelihood_family(&family),
4191 kappa_options.rel_tol.max(1e-6),
4192 kappa_options.max_outer_iter.max(1),
4193 Some(5.0),
4197 Some(kappa_options.log_step.clamp(0.25, 1.0)),
4199 None,
4200 Some((data.nrows(), baseline_design.design.ncols())),
4205 !constant_curvature_term_indices(resolvedspec).is_empty(),
4209 );
4210
4211 let eval_outer = |ctx: &mut &mut SpatialJointContext<'_>,
4212 theta: &Array1<f64>,
4213 order: OuterEvalOrder|
4214 -> Result<OuterEval, EstimationError> {
4215 let t0 = std::time::Instant::now();
4216 let allow_second_order_for_call = matches!(order, OuterEvalOrder::ValueGradientHessian)
4217 && analytic_outer_hessian_available;
4218 let gate = ctx.nfree_skip_gate_status(theta, allow_second_order_for_call, true);
4219 let resets_before = ctx.evaluator.slow_path_reset_count();
4220 let raw = ctx.eval_full(theta, order, analytic_outer_hessian_available);
4221 let reset_delta = ctx
4222 .evaluator
4223 .slow_path_reset_count()
4224 .saturating_sub(resets_before);
4225 if reset_delta > 0 {
4226 if !gate.shape {
4227 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4228 }
4229 if gate.shape && !gate.value {
4230 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4231 }
4232 if gate.shape && gate.value && !gate.gradient {
4233 kphase_nfree_miss_gradient.set(kphase_nfree_miss_gradient.get() + reset_delta);
4234 }
4235 if gate.shape && gate.value && gate.gradient && !gate.penalty {
4236 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4237 }
4238 if gate.shape && gate.value && gate.gradient && gate.penalty && !gate.revision {
4239 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4240 }
4241 if gate.shape
4242 && gate.value
4243 && gate.gradient
4244 && gate.penalty
4245 && gate.revision
4246 && gate.second_order
4247 {
4248 kphase_nfree_miss_second_order
4249 .set(kphase_nfree_miss_second_order.get() + reset_delta);
4250 }
4251 if gate.would_skip(true) {
4252 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4253 }
4254 }
4255 let elapsed_s = t0.elapsed().as_secs_f64();
4256 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
4257 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
4258 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4259 log::info!(
4260 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4261 kphase_eval_calls.get(),
4262 order,
4263 Some(ctx.cache.design_revision()),
4264 theta_norm,
4265 log_kappa_norm,
4266 elapsed_s,
4267 );
4268 match raw {
4269 Ok((cost, grad, hess)) => Ok(OuterEval {
4270 cost,
4271 gradient: grad,
4272 hessian: hess,
4273 inner_beta_hint: None,
4274 }),
4275 Err(err) if is_recoverable_trial_point_error(&err) => {
4283 log::debug!(
4284 "[{label}] trial point infeasible (kernel design \
4285 not constructible at theta={theta:?}): {err}; retreating",
4286 );
4287 Ok(OuterEval::infeasible(theta_dim))
4288 }
4289 Err(err) => Err(err),
4290 }
4291 };
4292
4293 let mut obj = problem.build_objective_with_eval_order(
4294 &mut ctx,
4295 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4296 let t0 = std::time::Instant::now();
4297 let gate = ctx.nfree_skip_gate_status(theta, false, false);
4298 let resets_before = ctx.evaluator.slow_path_reset_count();
4299 let cost = ctx.eval_cost(theta);
4300 let reset_delta = ctx
4301 .evaluator
4302 .slow_path_reset_count()
4303 .saturating_sub(resets_before);
4304 if reset_delta > 0 {
4305 if !gate.shape {
4306 kphase_nfree_miss_shape.set(kphase_nfree_miss_shape.get() + reset_delta);
4307 }
4308 if gate.shape && !gate.value {
4309 kphase_nfree_miss_value.set(kphase_nfree_miss_value.get() + reset_delta);
4310 }
4311 if gate.shape && gate.value && !gate.penalty {
4312 kphase_nfree_miss_penalty.set(kphase_nfree_miss_penalty.get() + reset_delta);
4313 }
4314 if gate.shape && gate.value && gate.penalty && !gate.revision {
4315 kphase_nfree_miss_revision.set(kphase_nfree_miss_revision.get() + reset_delta);
4316 }
4317 if gate.would_skip(false) {
4318 kphase_nfree_miss_other.set(kphase_nfree_miss_other.get() + reset_delta);
4319 }
4320 }
4321 let elapsed_s = t0.elapsed().as_secs_f64();
4322 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
4323 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
4324 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4325 log::info!(
4326 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4327 kphase_cost_calls.get(),
4328 Some(ctx.cache.design_revision()),
4329 theta_norm,
4330 log_kappa_norm,
4331 elapsed_s,
4332 );
4333 Ok(cost)
4334 },
4335 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4336 eval_outer(
4337 ctx,
4338 theta,
4339 if analytic_outer_hessian_available && !suppress_outer_hessian_for_nfree {
4349 OuterEvalOrder::ValueGradientHessian
4350 } else {
4351 OuterEvalOrder::ValueAndGradient
4352 },
4353 )
4354 },
4355 |ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
4356 eval_outer(ctx, theta, order)
4357 },
4358 Some(|ctx: &mut &mut SpatialJointContext<'_>| {
4359 ctx.reset();
4360 }),
4361 Some(|ctx: &mut &mut SpatialJointContext<'_>, theta: &Array1<f64>| {
4362 let t0 = std::time::Instant::now();
4363 let eval = ctx.eval_efs(theta);
4364 let elapsed_s = t0.elapsed().as_secs_f64();
4365 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
4366 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
4367 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta, rho_dim);
4368 log::info!(
4369 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
4370 kphase_efs_calls.get(),
4371 Some(ctx.cache.design_revision()),
4372 theta_norm,
4373 log_kappa_norm,
4374 elapsed_s,
4375 );
4376 eval
4377 }),
4378 );
4379
4380 let run_label = match kind {
4381 SpatialHyperKind::Anisotropic => "aniso-psi joint REML",
4382 SpatialHyperKind::Isotropic => "iso-kappa joint REML",
4383 };
4384 let result = problem.run(&mut obj, run_label).map_err(|e| {
4385 EstimationError::InvalidInput(format!(
4386 "{} analytic optimization failed after exhausting strategy fallbacks: {e}",
4387 kind.adjective(),
4388 ))
4389 })?;
4390 drop(obj);
4391 let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
4392 let kphase_slow_resets = ctx
4393 .evaluator
4394 .slow_path_reset_count()
4395 .saturating_sub(kphase_slow_resets_start);
4396 let kphase_design_revision_delta = ctx
4397 .cache
4398 .design_revision()
4399 .saturating_sub(kphase_design_revision_start);
4400 log::info!(
4401 "[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} slow_path_resets={} design_revision_delta={} nfree_miss_shape={} nfree_miss_value={} nfree_miss_gradient={} nfree_miss_penalty={} nfree_miss_revision={} nfree_miss_second_order={} nfree_miss_other={} optim_total_s={:.4}",
4402 kphase_log_kappa_dim,
4403 kphase_cost_calls.get(),
4404 kphase_cost_total_s.get(),
4405 kphase_eval_calls.get(),
4406 kphase_eval_total_s.get(),
4407 kphase_efs_calls.get(),
4408 kphase_efs_total_s.get(),
4409 kphase_slow_resets,
4410 kphase_design_revision_delta,
4411 kphase_nfree_miss_shape.get(),
4412 kphase_nfree_miss_value.get(),
4413 kphase_nfree_miss_gradient.get(),
4414 kphase_nfree_miss_penalty.get(),
4415 kphase_nfree_miss_revision.get(),
4416 kphase_nfree_miss_second_order.get(),
4417 kphase_nfree_miss_other.get(),
4418 kphase_total_s,
4419 );
4420 let timing = SpatialLengthScaleOptimizationTiming {
4421 log_kappa_dim: kphase_log_kappa_dim,
4422 cost_calls: kphase_cost_calls.get(),
4423 cost_total_s: kphase_cost_total_s.get(),
4424 eval_calls: kphase_eval_calls.get(),
4425 eval_total_s: kphase_eval_total_s.get(),
4426 efs_calls: kphase_efs_calls.get(),
4427 efs_total_s: kphase_efs_total_s.get(),
4428 slow_path_resets: kphase_slow_resets,
4429 design_revision_delta: kphase_design_revision_delta,
4430 nfree_miss_shape: kphase_nfree_miss_shape.get(),
4431 nfree_miss_value: kphase_nfree_miss_value.get(),
4432 nfree_miss_gradient: kphase_nfree_miss_gradient.get(),
4433 nfree_miss_penalty: kphase_nfree_miss_penalty.get(),
4434 nfree_miss_revision: kphase_nfree_miss_revision.get(),
4435 nfree_miss_second_order: kphase_nfree_miss_second_order.get(),
4436 nfree_miss_other: kphase_nfree_miss_other.get(),
4437 optim_total_s: kphase_total_s,
4438 };
4439 if !result.converged {
4440 let rel_to_cost_threshold = options.tol * (1.0_f64 + result.final_value.abs());
4451 if let Some(final_grad) = result
4452 .final_grad_norm
4453 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
4454 {
4455 log::info!(
4456 "[{}] outer optimization hit max_iter={} but \
4457 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
4458 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
4459 relative-to-cost REML convergence criterion.",
4460 label,
4461 result.iterations,
4462 final_grad,
4463 rel_to_cost_threshold,
4464 options.tol,
4465 result.final_value.abs(),
4466 );
4467 } else if result.final_value.is_finite() {
4468 log::warn!(
4483 "[{}] {} did not converge after {} iterations \
4484 (final_objective={:.6e}, final_grad_norm={}); keeping the \
4485 frozen baseline geometry instead of aborting the fit.",
4486 label,
4487 kind.adjective(),
4488 result.iterations,
4489 result.final_value,
4490 result.final_grad_norm_report(),
4491 );
4492 return Ok((
4493 SpatialJointOutcome::NonConverged {
4494 iterations: result.iterations,
4495 final_value: result.final_value,
4496 final_grad_norm: result.final_grad_norm,
4497 },
4498 timing,
4499 ));
4500 } else {
4501 crate::bail_invalid_estim!(
4506 "{} analytic optimization diverged after {} iterations (final_objective={:.6e}, final_grad_norm={})",
4507 kind.adjective(),
4508 result.iterations,
4509 result.final_value,
4510 result.final_grad_norm_report(),
4511 );
4512 }
4513 }
4514 log::trace!(
4515 "[{}] converged in {} iterations, final_value={:.6e}, grad_norm={}",
4516 label,
4517 result.iterations,
4518 result.final_value,
4519 result.final_grad_norm_report(),
4520 );
4521 let theta_star = result.rho;
4525 Ok((
4526 SpatialJointOutcome::Optimized {
4527 theta_star,
4528 final_value: result.final_value,
4529 },
4530 timing,
4531 ))
4532}
4533
4534fn set_single_term_spatial_length_scale(
4538 term: &mut SmoothTermSpec,
4539 length_scale: f64,
4540) -> Result<(), EstimationError> {
4541 match &mut term.basis {
4542 SmoothBasisSpec::ThinPlate { spec, .. } => {
4543 spec.length_scale = length_scale;
4544 Ok(())
4545 }
4546 SmoothBasisSpec::Matern { spec, .. } => {
4547 spec.length_scale = length_scale;
4548 Ok(())
4549 }
4550 SmoothBasisSpec::Duchon { spec, .. } => {
4551 spec.length_scale = Some(length_scale);
4552 Ok(())
4553 }
4554 _ => Err(EstimationError::InvalidInput(format!(
4555 "term '{}' does not expose a spatial length scale",
4556 term.name
4557 ))),
4558 }
4559}
4560
4561fn set_single_term_spatial_aniso_log_scales(
4565 term: &mut SmoothTermSpec,
4566 eta: Vec<f64>,
4567) -> Result<(), EstimationError> {
4568 let eta = center_aniso_log_scales(&eta);
4569 match &mut term.basis {
4570 SmoothBasisSpec::Matern { spec, .. } => {
4571 spec.aniso_log_scales = Some(eta);
4572 Ok(())
4573 }
4574 SmoothBasisSpec::Duchon { spec, .. } => {
4575 spec.aniso_log_scales = Some(eta);
4576 Ok(())
4577 }
4578 _ => Err(EstimationError::InvalidInput(format!(
4579 "term '{}' does not support aniso_log_scales",
4580 term.name
4581 ))),
4582 }
4583}
4584
4585pub fn get_constant_curvature_kappa(spec: &TermCollectionSpec, term_idx: usize) -> Option<f64> {
4604 constant_curvature_term_spec(spec, term_idx).map(|cc| cc.kappa)
4605}
4606
4607pub fn constant_curvature_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
4609 (0..spec.smooth_terms.len())
4610 .filter(|&idx| constant_curvature_term_spec(spec, idx).is_some())
4611 .collect()
4612}
4613
4614
/// The realized design and penalty bookkeeping for exactly one smooth term,
/// produced either by building that term on its own
/// (`build_single_smooth_term_realization`) or by wrapping an existing local
/// build (`wrap_local_build_as_realization`).
#[derive(Debug, Clone)]
struct SingleSmoothTermRealization {
    // Design matrix for this term alone (the term's own columns only).
    design_local: DesignMatrix,
    // The realized smooth term: local penalties, penalty infos, metadata, bounds/constraints.
    // NOTE(review): when built via `wrap_local_build_as_realization` its
    // `coeff_range` is `0..dim` (local indexing); presumably the same holds for
    // the single-term `build_smooth_design` path — confirm.
    term: SmoothTerm,
    // Penalty blocks that were dropped (inactive or pre-dropped) for this term.
    dropped_penaltyinfo: Vec<DroppedPenaltyBlockInfo>,
}
4621
4622impl SingleSmoothTermRealization {
4623 fn active_penaltyinfo(&self) -> Vec<PenaltyInfo> {
4624 self.term
4625 .penaltyinfo_local
4626 .iter()
4627 .filter(|info| info.active)
4628 .cloned()
4629 .collect()
4630 }
4631}
4632
4633fn build_single_smooth_term_realization(
4634 data: ArrayView2<'_, f64>,
4635 termspec: &SmoothTermSpec,
4636) -> Result<SingleSmoothTermRealization, BasisError> {
4637 let raw = build_smooth_design(data, std::slice::from_ref(termspec))?;
4638 finish_single_smooth_term_realization(raw)
4639}
4640
4641fn finish_single_smooth_term_realization(
4642 raw: RawSmoothDesign,
4643) -> Result<SingleSmoothTermRealization, BasisError> {
4644 let RawSmoothDesign {
4645 term_designs,
4646 dropped_penaltyinfo,
4647 terms,
4648 ..
4649 } = raw;
4650 let term = terms.into_iter().next().ok_or_else(|| {
4651 BasisError::InvalidInput("single-term smooth build returned no term".to_string())
4652 })?;
4653 let design = term_designs.into_iter().next().ok_or_else(|| {
4654 BasisError::InvalidInput("single-term smooth build returned no term design".to_string())
4655 })?;
4656
4657 Ok(SingleSmoothTermRealization {
4658 design_local: design,
4659 term,
4660 dropped_penaltyinfo,
4661 })
4662}
4663
4664fn wrap_local_build_as_realization(
4671 mut local: LocalSmoothTermBuild,
4672 termspec: &SmoothTermSpec,
4673) -> Result<SingleSmoothTermRealization, String> {
4674 let p_local = local.dim;
4675 let lb_local = if local.box_reparam {
4676 shape_lower_bounds_local(termspec.shape, p_local)
4677 } else {
4678 None
4679 };
4680
4681 let active_count = local.penaltyinfo.iter().filter(|info| info.active).count();
4682 if active_count != local.penalties.len() {
4683 return Err(format!(
4684 "internal penalty info mismatch for term '{}': active_infos={}, penalties={}",
4685 termspec.name,
4686 active_count,
4687 local.penalties.len()
4688 ));
4689 }
4690
4691 let mut dropped_penaltyinfo = Vec::<DroppedPenaltyBlockInfo>::new();
4692 for info in local.penaltyinfo.iter().filter(|info| !info.active) {
4693 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
4694 termname: Some(termspec.name.clone()),
4695 penalty: info.clone(),
4696 });
4697 }
4698 for info in &local.pre_dropped_penaltyinfo {
4699 dropped_penaltyinfo.push(DroppedPenaltyBlockInfo {
4700 termname: Some(termspec.name.clone()),
4701 penalty: info.clone(),
4702 });
4703 }
4704
4705 let applied_rotation: Option<gam_terms::basis::JointNullRotation> = match (
4709 local.joint_null_rotation.take(),
4710 lb_local.is_some(),
4711 local.linear_constraints.is_some(),
4712 ) {
4713 (Some(rot), false, false) => {
4714 let q = &rot.rotation;
4715 let dense = local
4716 .design
4717 .try_to_dense_by_chunks("joint-null absorption rotation (single realization)")
4718 .map_err(|e| {
4719 format!(
4720 "joint-null absorption rotation: dense conversion failed for term '{}': {}",
4721 termspec.name, e
4722 )
4723 })?;
4724 let rotated = gam_linalg::faer_ndarray::fast_ab(&dense, q);
4725 local.design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(rotated));
4726 local.penalties = local
4727 .penalties
4728 .into_iter()
4729 .map(|s_local| {
4730 let qt_s = gam_linalg::faer_ndarray::fast_atb(q, &s_local);
4731 gam_linalg::faer_ndarray::fast_ab(&qt_s, q)
4732 })
4733 .collect();
4734 local.ops = vec![None; local.penalties.len()];
4735 local.kronecker_factored = None;
4736 Some(rot)
4737 }
4738 (Some(_), _, _) => None,
4739 (None, _, _) => None,
4740 };
4741
4742 let smooth_term = SmoothTerm {
4743 name: termspec.name.clone(),
4744 coeff_range: 0..p_local,
4745 shape: termspec.shape,
4746 penalties_local: local.penalties.clone(),
4747 nullspace_dims: local.nullspaces.clone(),
4748 penaltyinfo_local: local.penaltyinfo.clone(),
4749 metadata: local.metadata.clone(),
4750 lower_bounds_local: lb_local,
4751 linear_constraints_local: local.linear_constraints.clone(),
4752 kronecker_factored: local.kronecker_factored.take(),
4753 joint_null_rotation: applied_rotation,
4754 unabsorbed_global_orthogonality: None,
4757 };
4758
4759 Ok(SingleSmoothTermRealization {
4760 design_local: local.design,
4761 term: smooth_term,
4762 dropped_penaltyinfo,
4763 })
4764}
4765
4766fn freeze_geometry_from_metadata(
4777 termspec: &SmoothTermSpec,
4778 metadata: &BasisMetadata,
4779) -> Option<SmoothTermSpec> {
4780 let mut frozen = termspec.clone();
4781 match (&mut frozen.basis, metadata) {
4782 (
4783 SmoothBasisSpec::Matern {
4784 spec,
4785 input_scales: spec_scales,
4786 ..
4787 },
4788 BasisMetadata::Matern {
4789 centers,
4790 input_scales: meta_scales,
4791 identifiability_transform,
4792 nullspace_shrinkage_survived,
4793 ..
4794 },
4795 ) => {
4796 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
4797 if spec_scales.is_none()
4798 && let Some(s) = meta_scales.clone()
4799 {
4800 *spec_scales = Some(s);
4801 }
4802 if let Some(transform) = identifiability_transform.clone() {
4820 spec.identifiability = MaternIdentifiability::FrozenTransform {
4821 transform,
4822 nullspace_shrinkage_survived: Some(*nullspace_shrinkage_survived),
4823 };
4824 }
4825 Some(frozen)
4826 }
4827 (
4828 SmoothBasisSpec::Duchon {
4829 spec,
4830 input_scales: spec_scales,
4831 ..
4832 },
4833 BasisMetadata::Duchon {
4834 centers,
4835 input_scales: meta_scales,
4836 ..
4837 },
4838 ) => {
4839 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
4840 if spec_scales.is_none()
4841 && let Some(s) = meta_scales.clone()
4842 {
4843 *spec_scales = Some(s);
4844 }
4845 Some(frozen)
4846 }
4847 (
4848 SmoothBasisSpec::ThinPlate {
4849 spec,
4850 input_scales: spec_scales,
4851 ..
4852 },
4853 BasisMetadata::ThinPlate {
4854 centers,
4855 input_scales: meta_scales,
4856 ..
4857 },
4858 ) => {
4859 spec.center_strategy = CenterStrategy::UserProvided(centers.clone());
4860 if spec_scales.is_none()
4861 && let Some(s) = meta_scales.clone()
4862 {
4863 *spec_scales = Some(s);
4864 }
4865 Some(frozen)
4866 }
4867 _ => None,
4870 }
4871}
4872
4873fn rebuild_smooth_auxiliary_state(
4874 smooth: &mut SmoothDesign,
4875 dropped_penaltyinfo_by_term: &[Vec<DroppedPenaltyBlockInfo>],
4876) -> Result<(), String> {
4877 if dropped_penaltyinfo_by_term.len() != smooth.terms.len() {
4878 return Err(SmoothError::dimension_mismatch(format!(
4879 "smooth dropped-penalty cache mismatch: terms={}, dropped_sets={}",
4880 smooth.terms.len(),
4881 dropped_penaltyinfo_by_term.len()
4882 ))
4883 .into());
4884 }
4885
4886 let total_p = smooth.total_smooth_cols();
4887 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(total_p, f64::NEG_INFINITY);
4888 let mut any_bounds = false;
4889 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
4890 let mut linear_constraint_b: Vec<f64> = Vec::new();
4891
4892 for term in &smooth.terms {
4893 let range = term.coeff_range.clone();
4894 if let Some(lb_local) = term.lower_bounds_local.as_ref() {
4895 if lb_local.len() != range.len() {
4896 return Err(SmoothError::dimension_mismatch(format!(
4897 "smooth lower-bound cache mismatch for term '{}': bounds={}, coeffs={}",
4898 term.name,
4899 lb_local.len(),
4900 range.len()
4901 ))
4902 .into());
4903 }
4904 coefficient_lower_bounds
4905 .slice_mut(s![range.clone()])
4906 .assign(lb_local);
4907 any_bounds = true;
4908 }
4909 if let Some(lin_local) = term.linear_constraints_local.as_ref() {
4910 if lin_local.a.ncols() != range.len() {
4911 return Err(SmoothError::dimension_mismatch(format!(
4912 "smooth linear-constraint cache mismatch for term '{}': cols={}, coeffs={}",
4913 term.name,
4914 lin_local.a.ncols(),
4915 range.len()
4916 ))
4917 .into());
4918 }
4919 for r in 0..lin_local.a.nrows() {
4920 let mut row = Array1::<f64>::zeros(total_p);
4921 row.slice_mut(s![range.clone()]).assign(&lin_local.a.row(r));
4922 linear_constraintrows.push(row);
4923 linear_constraint_b.push(lin_local.b[r]);
4924 }
4925 }
4926 }
4927
4928 smooth.coefficient_lower_bounds = if any_bounds {
4929 Some(coefficient_lower_bounds)
4930 } else {
4931 None
4932 };
4933 smooth.linear_constraints = if linear_constraintrows.is_empty() {
4934 None
4935 } else {
4936 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), total_p));
4937 for (i, row) in linear_constraintrows.iter().enumerate() {
4938 a.row_mut(i).assign(row);
4939 }
4940 Some(LinearInequalityConstraints {
4941 a,
4942 b: Array1::from_vec(linear_constraint_b),
4943 })
4944 };
4945 smooth.dropped_penaltyinfo = dropped_penaltyinfo_by_term
4946 .iter()
4947 .flat_map(|infos| infos.iter().cloned())
4948 .collect();
4949 Ok(())
4950}
4951
4952fn rebuild_term_collection_auxiliary_state(
4953 spec: &TermCollectionSpec,
4954 design: &mut TermCollectionDesign,
4955) -> Result<(), String> {
4956 if spec.linear_terms.len() != design.linear_ranges.len() {
4957 return Err(SmoothError::dimension_mismatch(format!(
4958 "term-collection linear bookkeeping mismatch: spec_terms={}, design_ranges={}",
4959 spec.linear_terms.len(),
4960 design.linear_ranges.len()
4961 ))
4962 .into());
4963 }
4964
4965 let p_total = design.design.ncols();
4966 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
4967 let mut coefficient_lower_bounds = Array1::<f64>::from_elem(p_total, f64::NEG_INFINITY);
4968 let mut any_bounds = false;
4969 let mut linear_constraintrows: Vec<Array1<f64>> = Vec::new();
4970 let mut linear_constraint_b: Vec<f64> = Vec::new();
4971
4972 for (linear, (_, range)) in spec.linear_terms.iter().zip(design.linear_ranges.iter()) {
4973 if range.len() != 1 {
4974 return Err(SmoothError::dimension_mismatch(format!(
4975 "linear term '{}' expected one coefficient column, found {}",
4976 linear.name,
4977 range.len()
4978 ))
4979 .into());
4980 }
4981 let col = range.start;
4982 if let Some(lb) = linear.coefficient_min {
4983 let mut row = Array1::<f64>::zeros(p_total);
4984 row[col] = 1.0;
4985 linear_constraintrows.push(row);
4986 linear_constraint_b.push(lb);
4987 }
4988 if let Some(ub) = linear.coefficient_max {
4989 let mut row = Array1::<f64>::zeros(p_total);
4990 row[col] = -1.0;
4991 linear_constraintrows.push(row);
4992 linear_constraint_b.push(-ub);
4993 }
4994 }
4995
4996 if let Some(lb_smooth) = design.smooth.coefficient_lower_bounds.as_ref() {
4997 if lb_smooth.len() != design.smooth.total_smooth_cols() {
4998 return Err(SmoothError::dimension_mismatch(format!(
4999 "smooth lower-bound width mismatch: bounds={}, smooth_cols={}",
5000 lb_smooth.len(),
5001 design.smooth.total_smooth_cols()
5002 ))
5003 .into());
5004 }
5005 coefficient_lower_bounds
5006 .slice_mut(s![
5007 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5008 ])
5009 .assign(lb_smooth);
5010 any_bounds = true;
5011 }
5012 if let Some(lin_smooth) = design.smooth.linear_constraints.as_ref() {
5013 if lin_smooth.a.ncols() != design.smooth.total_smooth_cols() {
5014 return Err(SmoothError::dimension_mismatch(format!(
5015 "smooth linear-constraint width mismatch: cols={}, smooth_cols={}",
5016 lin_smooth.a.ncols(),
5017 design.smooth.total_smooth_cols()
5018 ))
5019 .into());
5020 }
5021 let mut a_global = Array2::<f64>::zeros((lin_smooth.a.nrows(), p_total));
5022 a_global
5023 .slice_mut(s![
5024 ..,
5025 smooth_start..(smooth_start + design.smooth.total_smooth_cols())
5026 ])
5027 .assign(&lin_smooth.a);
5028 for r in 0..a_global.nrows() {
5029 linear_constraintrows.push(a_global.row(r).to_owned());
5030 linear_constraint_b.push(lin_smooth.b[r]);
5031 }
5032 }
5033
5034 let lower_bound_constraints = if any_bounds {
5035 linear_constraints_from_lower_bounds_global(&coefficient_lower_bounds)
5036 } else {
5037 None
5038 };
5039 let explicit_linear_constraints = if linear_constraintrows.is_empty() {
5040 None
5041 } else {
5042 let mut a = Array2::<f64>::zeros((linear_constraintrows.len(), p_total));
5043 for (i, row) in linear_constraintrows.iter().enumerate() {
5044 a.row_mut(i).assign(row);
5045 }
5046 Some(LinearInequalityConstraints {
5047 a,
5048 b: Array1::from_vec(linear_constraint_b),
5049 })
5050 };
5051
5052 design.coefficient_lower_bounds = if any_bounds {
5053 Some(coefficient_lower_bounds)
5054 } else {
5055 None
5056 };
5057 design.linear_constraints =
5058 merge_linear_constraints_global(explicit_linear_constraints, lower_bound_constraints);
5059 design.dropped_penaltyinfo = design.smooth.dropped_penaltyinfo.clone();
5060 Ok(())
5061}
5062
5063fn theta_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
5064 left.len() == right.len()
5065 && left
5066 .iter()
5067 .zip(right.iter())
5068 .all(|(&l, &r)| l.to_bits() == r.to_bits())
5069}
5070
/// Bit-for-bit equality test for latent vectors.
///
/// Delegates to `theta_values_match`. Lengths must agree and every element
/// pair must share the same `f64` bit pattern. An identical NaN matches
/// itself, while `0.0` and `-0.0` do not match.
fn latent_values_match(left: &Array1<f64>, right: &Array1<f64>) -> bool {
    theta_values_match(left, right)
}
5074
/// Bit-exact equality of two optional anisotropy log-scale slices.
///
/// Two `None`s match. A `None`/`Some` pair never matches. Two slices match
/// when they have equal length and identical `f64` bit patterns element-wise.
fn spatial_aniso_matches(left: Option<&[f64]>, right: Option<&[f64]>) -> bool {
    let (Some(lhs), Some(rhs)) = (left, right) else {
        return left.is_none() && right.is_none();
    };
    lhs.iter()
        .map(|x| x.to_bits())
        .eq(rhs.iter().map(|y| y.to_bits()))
}
5087
/// Bit-exact equality of two optional length scales.
///
/// Mapping both sides to their `f64` bit patterns makes `None == None` true,
/// a `None`/`Some` pair false, and two values equal only when their bits agree.
fn spatial_length_scale_matches(left: Option<f64>, right: Option<f64>) -> bool {
    left.map(f64::to_bits) == right.map(f64::to_bits)
}
5095
/// Holds a term collection's spec, its realized design, and cached bookkeeping
/// so the design can be re-realized incrementally over the same `data`.
struct FrozenTermCollectionIncrementalRealizer<'d> {
    // Borrowed data matrix that all realizations are built from.
    data: ArrayView2<'d, f64>,
    // Term-collection spec the realizer works from.
    spec: TermCollectionSpec,
    // Realized design matrix, penalties and auxiliary state.
    design: TermCollectionDesign,
    // Blocks built once by `build_term_collection_fixed_blocks(data, &spec)` and
    // cached. NOTE(review): presumably the non-smooth (fixed/linear) blocks — confirm.
    fixed_blocks: Vec<DesignBlock>,
    // Dropped penalty infos, one entry per smooth term.
    dropped_penaltyinfo_by_term: Vec<Vec<DroppedPenaltyBlockInfo>>,
    // Per smooth term: its index range into `design.smooth.penalties`.
    smooth_penalty_ranges: Vec<Range<usize>>,
    // Per smooth term: its index range into `design.penalties`, i.e. the smooth
    // range shifted past the penalties that precede the smooth ones.
    full_penalty_ranges: Vec<Range<usize>>,
    // Reusable basis-construction workspace. NOTE(review): assumed to cache
    // intermediate basis state between realizations — confirm.
    basisworkspace: gam_terms::basis::BasisWorkspace,
    // Per smooth term, an optional spec. NOTE(review): presumably the geometry-frozen
    // spec (see `freeze_geometry_from_metadata`), `None` when not frozen — confirm.
    spatial_realization_geometry: Vec<Option<SmoothTermSpec>>,
    // Revision counter for `design`. NOTE(review): presumably bumped whenever the
    // realized design changes — confirm.
    design_revision: u64,
}
5128
5129impl<'d> std::fmt::Debug for FrozenTermCollectionIncrementalRealizer<'d> {
5130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5131 f.debug_struct("FrozenTermCollectionIncrementalRealizer")
5132 .field("data_shape", &(self.data.nrows(), self.data.ncols()))
5133 .field("fixed_blocks", &self.fixed_blocks.len())
5134 .finish_non_exhaustive()
5135 }
5136}
5137
5138impl<'d> FrozenTermCollectionIncrementalRealizer<'d> {
5139 fn new(
5140 data: ArrayView2<'d, f64>,
5141 spec: TermCollectionSpec,
5142 design: TermCollectionDesign,
5143 ) -> Result<Self, String> {
5144 if spec.smooth_terms.len() != design.smooth.terms.len() {
5145 return Err(SmoothError::dimension_mismatch(format!(
5146 "incremental realizer smooth term mismatch: spec_terms={}, design_terms={}",
5147 spec.smooth_terms.len(),
5148 design.smooth.terms.len()
5149 ))
5150 .into());
5151 }
5152
5153 let mut smooth_cursor = 0usize;
5154 let mut smooth_penalty_ranges = Vec::with_capacity(design.smooth.terms.len());
5155 for term in &design.smooth.terms {
5156 let next = smooth_cursor + term.penalties_local.len();
5157 smooth_penalty_ranges.push(smooth_cursor..next);
5158 smooth_cursor = next;
5159 }
5160 if smooth_cursor != design.smooth.penalties.len() {
5161 return Err(SmoothError::dimension_mismatch(format!(
5162 "incremental realizer smooth penalty mismatch: ranged={}, actual={}",
5163 smooth_cursor,
5164 design.smooth.penalties.len()
5165 ))
5166 .into());
5167 }
5168
5169 let fixed_penalty_offset = design
5170 .penalties
5171 .len()
5172 .checked_sub(design.smooth.penalties.len())
5173 .ok_or_else(|| {
5174 "incremental realizer encountered invalid penalty bookkeeping".to_string()
5175 })?;
5176 let full_penalty_ranges = smooth_penalty_ranges
5177 .iter()
5178 .map(|range| (fixed_penalty_offset + range.start)..(fixed_penalty_offset + range.end))
5179 .collect::<Vec<_>>();
5180 let fixed_blocks = build_term_collection_fixed_blocks(data, &spec)
5181 .map_err(|e| format!("failed to cache fixed term-collection blocks: {e}"))?;
5182
5183 let mut dropped_penaltyinfo_by_term = Vec::with_capacity(spec.smooth_terms.len());
5184 for (term_idx, termspec) in spec.smooth_terms.iter().enumerate() {
5185 let realization =
5186 build_single_smooth_term_realization(data, termspec).map_err(|e| {
5187 format!(
5188 "failed to build cached realization for smooth term '{}' (index {}): {e}",
5189 termspec.name, term_idx
5190 )
5191 })?;
5192 let expected_cols = design.smooth.terms[term_idx].coeff_range.len();
5193 if realization.design_local.ncols() != expected_cols {
5194 return Err(SmoothError::dimension_mismatch(format!(
5195 "cached realization width mismatch for term '{}': cached_cols={}, design_cols={}",
5196 termspec.name,
5197 realization.design_local.ncols(),
5198 expected_cols
5199 ))
5200 .into());
5201 }
5202 if realization.active_penaltyinfo().len()
5203 != design.smooth.terms[term_idx].penalties_local.len()
5204 {
5205 return Err(SmoothError::dimension_mismatch(format!(
5206 "cached realization penalty mismatch for term '{}': cached_penalties={}, design_penalties={}",
5207 termspec.name,
5208 realization.active_penaltyinfo().len(),
5209 design.smooth.terms[term_idx].penalties_local.len()
5210 ))
5211 .into());
5212 }
5213 dropped_penaltyinfo_by_term.push(realization.dropped_penaltyinfo);
5214 }
5215
5216 let geometry_slots = spec.smooth_terms.len();
5217 Ok(Self {
5218 data,
5219 spec,
5220 design,
5221 fixed_blocks,
5222 dropped_penaltyinfo_by_term,
5223 smooth_penalty_ranges,
5224 full_penalty_ranges,
5225 basisworkspace: gam_terms::basis::BasisWorkspace::new(),
5226 spatial_realization_geometry: vec![None; geometry_slots],
5227 design_revision: 0,
5228 })
5229 }
5230
5231 fn design_revision(&self) -> u64 {
5232 self.design_revision
5233 }
5234
5235 fn spec(&self) -> &TermCollectionSpec {
5236 &self.spec
5237 }
5238
5239 fn design(&self) -> &TermCollectionDesign {
5240 &self.design
5241 }
5242
5243 fn supports_nfree_penalty_rekey(&self, spatial_terms: &[usize]) -> bool {
5258 if spatial_terms.len() != 1 {
5259 return false;
5260 }
5261 let term_idx = spatial_terms[0];
5262 matches!(
5263 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5264 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5265 )
5266 }
5267
5268 fn supports_nfree_gradient_only_routing(&self, spatial_terms: &[usize]) -> bool {
5277 if spatial_terms.len() != 1 {
5278 return false;
5279 }
5280 let term_idx = spatial_terms[0];
5281 matches!(
5282 self.design.smooth.terms.get(term_idx).map(|t| &t.metadata),
5283 Some(BasisMetadata::Duchon { .. } | BasisMetadata::ThinPlate { .. })
5284 )
5285 }
5286
    /// Re-evaluates the penalty blocks of a single spatial smooth term at the hyperparameters
    /// encoded in `psi`, without rebuilding the design, and returns them in canonical form.
    ///
    /// `psi` is decoded into an optional length scale and optional anisotropy. The realized
    /// metadata (centers, identifiability transform, ...) is reused, and the penalties are
    /// recomputed for Duchon, Matérn or thin-plate metadata. Each new local block is then paired,
    /// by position, with the corresponding penalty of the frozen design. That penalty supplies
    /// `col_range`, `prior_mean`, `structure_hint` and `op`. The resulting specs are
    /// canonicalized over the full design width.
    ///
    /// # Errors
    /// Returns `Err` in any of these cases:
    /// - `spatial_terms` does not contain exactly one index, or the index is out of range;
    /// - a required length scale is missing;
    /// - the metadata is of any other kind;
    /// - a penalty builder or the canonicalization fails;
    /// - the number of re-evaluated blocks differs from the number of penalties in the frozen
    ///   design (all of `self.design.penalties` is used as the template list).
    fn canonical_penalties_at_psi(
        &mut self,
        spatial_terms: &[usize],
        psi: &[f64],
    ) -> Result<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>), String> {
        if spatial_terms.len() != 1 {
            return Err(format!(
                "n-free penalty re-key requires exactly one spatial term, found {}",
                spatial_terms.len()
            ));
        }
        let term_idx = spatial_terms[0];
        let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
        let termspec =
            self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
                format!("spatial term {term_idx} out of range for n-free penalty")
            })?;
        let term = self
            .design
            .smooth
            .terms
            .get(term_idx)
            .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
        let p_total = self.design.design.ncols();
        let (locals, nullspace_dims): (Vec<Array2<f64>>, Vec<usize>) = match &term.metadata {
            BasisMetadata::Duchon {
                centers,
                identifiability_transform,
                operator_collocation_points,
                power,
                nullspace_order,
                aniso_log_scales,
                input_scales,
                radial_reparam,
                ..
            } => {
                // Operator-penalty configuration comes from the term spec. If the spec is not
                // Duchon, the default configuration is used.
                let operator_penalties = match &termspec.basis {
                    SmoothBasisSpec::Duchon { spec, .. } => spec.operator_penalties.clone(),
                    _ => gam_terms::basis::DuchonOperatorPenaltySpec::default(),
                };
                // Express the ψ length scale in the standardized input units of the realized
                // basis.
                let effective_ls = match input_scales.as_deref() {
                    Some(scales) => {
                        compensate_optional_length_scale_for_standardization(ls_opt, scales)
                    }
                    None => ls_opt,
                };
                // NOTE(review): anisotropy here comes from the realized metadata, not from
                // `aniso_from_psi` (unlike the Matérn arm) — confirm this is intended.
                gam_terms::basis::duchon_penalties_at_length_scale(
                    centers.view(),
                    identifiability_transform.as_ref(),
                    operator_collocation_points.as_ref().map(|p| p.view()),
                    &operator_penalties,
                    *power,
                    *nullspace_order,
                    aniso_log_scales.as_deref(),
                    radial_reparam.as_ref(),
                    effective_ls,
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?
            }
            BasisMetadata::Matern {
                centers,
                periodic,
                nu,
                include_intercept,
                identifiability_transform,
                aniso_log_scales,
                input_scales,
                ..
            } => {
                let ls = ls_opt.ok_or_else(|| {
                    "Matérn n-free penalty re-key requires a finite length-scale".to_string()
                })?;
                let effective_ls = match input_scales.as_deref() {
                    Some(scales) => compensate_length_scale_for_standardization(ls, scales),
                    None => ls,
                };
                // Anisotropy decoded from ψ takes precedence over the realized metadata.
                let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
                let (penalties, nullspace_dims, _info) =
                    matern_operator_penalty_triplet_at_length_scale(
                        centers.view(),
                        periodic.as_deref(),
                        identifiability_transform.as_ref(),
                        *nu,
                        *include_intercept,
                        aniso_for_penalty,
                        effective_ls,
                    )
                    .map_err(|e| e.to_string())?;
                (penalties, nullspace_dims)
            }
            BasisMetadata::ThinPlate {
                centers,
                identifiability_transform,
                radial_reparam,
                ..
            } => {
                // NOTE(review): unlike the Duchon/Matérn arms, no input-scale compensation is
                // applied to `ls` here — confirm thin-plate metadata never carries input scales
                // (or that the builder compensates internally).
                let ls = ls_opt.ok_or_else(|| {
                    "thin-plate n-free penalty re-key requires a finite length-scale".to_string()
                })?;
                let double_penalty = match &termspec.basis {
                    SmoothBasisSpec::ThinPlate { spec, .. } => spec.double_penalty,
                    _ => false,
                };
                gam_terms::basis::thin_plate_penalties_at_length_scale(
                    centers.view(),
                    identifiability_transform.as_ref(),
                    radial_reparam.as_ref(),
                    ls,
                    double_penalty,
                    &mut self.basisworkspace,
                )
                .map_err(|e| e.to_string())?
            }
            other => {
                return Err(format!(
                    "n-free penalty re-key unsupported for basis metadata {:?}",
                    std::mem::discriminant(other)
                ));
            }
        };
        // The frozen design's penalties act as positional templates. The count must match,
        // otherwise the penalty topology changed with ψ.
        let templates = &self.design.penalties;
        if templates.len() != locals.len() {
            return Err(format!(
                "n-free penalty re-key produced {} blocks but the frozen design carries {} \
                 — penalty topology is not ψ-stable",
                locals.len(),
                templates.len()
            ));
        }
        let specs: Vec<gam_solve::estimate::PenaltySpec> = templates
            .iter()
            .zip(locals.into_iter())
            .map(|(tmpl, local)| gam_solve::estimate::PenaltySpec::Block {
                local,
                col_range: tmpl.col_range.clone(),
                prior_mean: tmpl.prior_mean.clone(),
                structure_hint: tmpl.structure_hint.clone(),
                op: tmpl.op.clone(),
            })
            .collect();
        gam_terms::construction::canonicalize_penalty_specs(
            &specs,
            &nullspace_dims,
            p_total,
            "nfree-psi-penalty",
        )
        .map_err(|e| e.to_string())
    }
5483
5484 fn canonical_penalty_derivatives_at_psi(
5485 &mut self,
5486 spatial_terms: &[usize],
5487 psi: &[f64],
5488 ) -> Result<(Range<usize>, usize, Vec<Array2<f64>>), String> {
5489 if spatial_terms.len() != 1 {
5490 return Err(format!(
5491 "n-free penalty derivative re-key requires exactly one spatial term, found {}",
5492 spatial_terms.len()
5493 ));
5494 }
5495 let term_idx = spatial_terms[0];
5496 let (ls_opt, aniso_from_psi) = spatial_term_psi_to_length_scale_and_aniso(psi);
5497 let termspec = self.spec.smooth_terms.get(term_idx).ok_or_else(|| {
5498 format!("spatial term {term_idx} out of range for n-free penalty derivative")
5499 })?;
5500 let term = self
5501 .design
5502 .smooth
5503 .terms
5504 .get(term_idx)
5505 .ok_or_else(|| format!("realized smooth term {term_idx} out of range"))?;
5506 let p_total = self.design.design.ncols();
5507 let smooth_start = p_total.saturating_sub(self.design.smooth.total_smooth_cols());
5508 let global_range =
5509 (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
5510
5511 let locals = match &term.metadata {
5512 BasisMetadata::Duchon {
5513 centers,
5514 identifiability_transform,
5515 operator_collocation_points,
5516 power,
5517 nullspace_order,
5518 aniso_log_scales,
5519 input_scales,
5520 radial_reparam,
5521 ..
5522 } => {
5523 let mut spec = match &termspec.basis {
5524 SmoothBasisSpec::Duchon { spec, .. } => spec.clone(),
5525 _ => {
5526 return Err(
5527 "Duchon n-free penalty derivative requires a Duchon term spec"
5528 .to_string(),
5529 );
5530 }
5531 };
5532 let effective_ls = match input_scales.as_deref() {
5533 Some(scales) => {
5534 compensate_optional_length_scale_for_standardization(ls_opt, scales)
5535 }
5536 None => ls_opt,
5537 };
5538 spec.length_scale = effective_ls;
5539 spec.power = *power;
5540 spec.nullspace_order = *nullspace_order;
5541 spec.aniso_log_scales = aniso_log_scales.clone();
5542 spec.radial_reparam = radial_reparam.clone();
5545 if spec.length_scale.is_none() {
5546 return Err(
5547 "Duchon n-free penalty derivative requires a hybrid length-scale"
5548 .to_string(),
5549 );
5550 }
5551 let collocation = operator_collocation_points
5552 .as_ref()
5553 .map(|points| points.view())
5554 .unwrap_or_else(|| centers.view());
5555 let (_native_sources, mut first, _native_second) =
5556 gam_terms::basis::build_duchon_native_penalty_psi_derivatives(
5557 centers.view(),
5558 &spec,
5559 identifiability_transform.as_ref(),
5560 &mut self.basisworkspace,
5561 )
5562 .map_err(|e| e.to_string())?;
5563 let (_operator_sources, operator_first, _operator_second) =
5564 gam_terms::basis::build_duchon_operator_penalty_psi_derivatives(
5565 collocation,
5566 centers.view(),
5567 &spec,
5568 identifiability_transform.as_ref(),
5569 &mut self.basisworkspace,
5570 )
5571 .map_err(|e| e.to_string())?;
5572 first.extend(operator_first);
5573 first
5574 }
5575 BasisMetadata::Matern {
5576 centers,
5577 periodic,
5578 nu,
5579 include_intercept,
5580 identifiability_transform,
5581 aniso_log_scales,
5582 input_scales,
5583 ..
5584 } => {
5585 let ls = ls_opt.ok_or_else(|| {
5586 "Matérn n-free penalty derivative requires a finite length-scale".to_string()
5587 })?;
5588 let effective_ls = match input_scales.as_deref() {
5589 Some(scales) => compensate_length_scale_for_standardization(ls, scales),
5590 None => ls,
5591 };
5592 let penalty_centers =
5593 gam_terms::basis::expand_periodic_centers(¢ers.to_owned(), periodic.as_deref())
5594 .map_err(|e| e.to_string())?;
5595 let aniso_for_penalty = aniso_from_psi.as_deref().or(aniso_log_scales.as_deref());
5596 let (first, _second) = gam_terms::basis::build_matern_operator_penalty_psi_derivatives(
5597 penalty_centers.view(),
5598 effective_ls,
5599 *nu,
5600 *include_intercept,
5601 identifiability_transform.as_ref(),
5602 aniso_for_penalty,
5603 )
5604 .map_err(|e| e.to_string())?;
5605 first
5606 }
5607 BasisMetadata::ThinPlate {
5608 centers,
5609 identifiability_transform,
5610 radial_reparam,
5611 ..
5612 } => {
5613 let ls = ls_opt.ok_or_else(|| {
5614 "thin-plate n-free penalty derivative requires a finite length-scale"
5615 .to_string()
5616 })?;
5617 let mut spec = match &termspec.basis {
5618 SmoothBasisSpec::ThinPlate { spec, .. } => spec.clone(),
5619 _ => {
5620 return Err(
5621 "thin-plate n-free penalty derivative requires a ThinPlate term spec"
5622 .to_string(),
5623 );
5624 }
5625 };
5626 spec.length_scale = ls;
5627 if spec.radial_reparam.is_none() {
5628 spec.radial_reparam = radial_reparam.clone();
5629 }
5630 let (primary, _primary_second) =
5631 gam_terms::basis::build_thin_plate_penalty_psi_derivativeswithworkspace(
5632 centers.view(),
5633 &spec,
5634 identifiability_transform.as_ref(),
5635 &mut self.basisworkspace,
5636 )
5637 .map_err(|e| e.to_string())?;
5638 if self.design.penalties.len() > 1 {
5639 vec![primary.clone(), Array2::<f64>::zeros(primary.raw_dim())]
5640 } else {
5641 vec![primary]
5642 }
5643 }
5644 other => {
5645 return Err(format!(
5646 "n-free penalty derivative re-key unsupported for basis metadata {:?}",
5647 std::mem::discriminant(other)
5648 ));
5649 }
5650 };
5651 if locals.len() != self.design.penalties.len() {
5652 return Err(format!(
5653 "n-free penalty derivative re-key produced {} blocks but the frozen design carries {} \
5654 — penalty topology is not ψ-stable",
5655 locals.len(),
5656 self.design.penalties.len()
5657 ));
5658 }
5659 Ok((global_range, p_total, locals))
5660 }
5661
5662 fn apply_log_kappa(
5663 &mut self,
5664 log_kappa: &SpatialLogKappaCoords,
5665 term_indices: &[usize],
5666 ) -> Result<(), String> {
5667 if term_indices.len() != log_kappa.dims_per_term().len() {
5668 return Err(SmoothError::dimension_mismatch(format!(
5669 "incremental realizer log-kappa term mismatch: term_indices={}, dims_per_term={}",
5670 term_indices.len(),
5671 log_kappa.dims_per_term().len()
5672 ))
5673 .into());
5674 }
5675
5676 let mut any_changed = false;
5677 for (slot, &term_idx) in term_indices.iter().enumerate() {
5678 any_changed |= self.apply_log_kappa_to_term(term_idx, log_kappa.term_slice(slot))?;
5679 }
5680
5681 if any_changed {
5682 self.refresh_full_design_operator()?;
5683 rebuild_smooth_auxiliary_state(
5684 &mut self.design.smooth,
5685 &self.dropped_penaltyinfo_by_term,
5686 )?;
5687 rebuild_term_collection_auxiliary_state(&self.spec, &mut self.design)?;
5688 self.design_revision = self.design_revision.wrapping_add(1);
5689 }
5690 Ok(())
5691 }
5692
5693 fn apply_log_kappa_to_term(&mut self, term_idx: usize, psi: &[f64]) -> Result<bool, String> {
5694 if !spatial_term_supports_hyper_optimization(&self.spec, term_idx) {
5695 return Err(SmoothError::invalid_config(format!(
5696 "incremental realizer term {term_idx} does not expose spatial hyperparameters"
5697 ))
5698 .into());
5699 }
5700 let measure_jet_term = measure_jet_term_spec(&self.spec, term_idx).is_some();
5704 let constant_curvature_term = constant_curvature_term_spec(&self.spec, term_idx).is_some();
5708 let mut next_length_scale = None;
5709 let mut next_aniso: Option<Vec<f64>> = None;
5710 if measure_jet_term {
5711 if !set_measure_jet_psi_dials(&mut self.spec, term_idx, psi)
5712 .map_err(|e| e.to_string())?
5713 {
5714 return Ok(false);
5715 }
5716 } else if constant_curvature_term {
5717 if !set_constant_curvature_kappa(&mut self.spec, term_idx, psi)
5718 .map_err(|e| e.to_string())?
5719 {
5720 return Ok(false);
5721 }
5722 } else {
5723 let current_length_scale = get_spatial_length_scale(&self.spec, term_idx);
5724 let current_aniso = get_spatial_aniso_log_scales(&self.spec, term_idx);
5725 let (ls, eta) = spatial_term_psi_to_length_scale_and_aniso(psi);
5726 next_length_scale = ls;
5727 next_aniso = eta;
5728 let same_length = spatial_length_scale_matches(current_length_scale, next_length_scale);
5729 let same_aniso = spatial_aniso_matches(current_aniso.as_deref(), next_aniso.as_deref());
5730 if same_length && same_aniso {
5731 return Ok(false);
5732 }
5733 if let Some(length_scale) = next_length_scale {
5734 set_spatial_length_scale(&mut self.spec, term_idx, length_scale)
5735 .map_err(|e| e.to_string())?;
5736 }
5737 if let Some(eta) = next_aniso.clone() {
5738 set_spatial_aniso_log_scales(&mut self.spec, term_idx, eta)
5739 .map_err(|e| e.to_string())?;
5740 }
5741 }
5742
5743 let geometry_slot = self
5754 .spatial_realization_geometry
5755 .get(term_idx)
5756 .ok_or_else(|| format!("incremental realizer geometry slot {term_idx} out of range"))?;
5757 let mut build_spec = match geometry_slot {
5758 Some(cached) => cached.clone(),
5759 None => self
5760 .spec
5761 .smooth_terms
5762 .get(term_idx)
5763 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
5764 .clone(),
5765 };
5766 if measure_jet_term {
5767 set_single_term_measure_jet_psi_dials(&mut build_spec, psi)
5771 .map_err(|e| e.to_string())?;
5772 } else if constant_curvature_term {
5773 set_single_term_constant_curvature_kappa(&mut build_spec, psi)
5778 .map_err(|e| e.to_string())?;
5779 } else {
5780 if let Some(length_scale) = next_length_scale {
5781 set_single_term_spatial_length_scale(&mut build_spec, length_scale)
5782 .map_err(|e| e.to_string())?;
5783 }
5784 if let Some(eta) = next_aniso {
5785 set_single_term_spatial_aniso_log_scales(&mut build_spec, eta)
5786 .map_err(|e| e.to_string())?;
5787 }
5788 }
5789
5790 let termname = build_spec.name.clone();
5791 let local = build_single_local_smooth_term(
5792 self.data,
5793 &build_spec,
5794 &mut self.basisworkspace,
5795 )
5796 .map_err(|e| {
5797 format!(
5798 "failed to rebuild smooth term '{termname}' during incremental κ realization: {e}"
5799 )
5800 })?;
5801
5802 if self.spatial_realization_geometry[term_idx].is_none()
5807 && let Some(frozen) = freeze_geometry_from_metadata(&build_spec, &local.metadata)
5808 {
5809 if let (
5821 SmoothBasisSpec::Matern {
5822 spec: frozen_spec, ..
5823 },
5824 Some(SmoothBasisSpec::Matern {
5825 spec: live_spec, ..
5826 }),
5827 ) = (
5828 &frozen.basis,
5829 self.spec
5830 .smooth_terms
5831 .get_mut(term_idx)
5832 .map(|t| &mut t.basis),
5833 ) {
5834 live_spec.identifiability = frozen_spec.identifiability.clone();
5835 live_spec.center_strategy = frozen_spec.center_strategy.clone();
5836 }
5837 self.spatial_realization_geometry[term_idx] = Some(frozen);
5838 }
5839
5840 let realization = wrap_local_build_as_realization(local, &build_spec)?;
5841 self.replace_term_realization(term_idx, realization)?;
5842 Ok(true)
5843 }
5844
5845 fn replace_term_realization(
5846 &mut self,
5847 term_idx: usize,
5848 realization: SingleSmoothTermRealization,
5849 ) -> Result<(), String> {
5850 let t_replace = std::time::Instant::now();
5851 let SingleSmoothTermRealization {
5852 design_local,
5853 term,
5854 dropped_penaltyinfo,
5855 } = realization;
5856 let SmoothTerm {
5857 name,
5858 penalties_local,
5859 nullspace_dims,
5860 penaltyinfo_local,
5861 metadata,
5862 lower_bounds_local,
5863 linear_constraints_local,
5864 joint_null_rotation,
5865 ..
5866 } = term;
5867 let coeff_range = self
5868 .design
5869 .smooth
5870 .terms
5871 .get(term_idx)
5872 .ok_or_else(|| format!("incremental realizer smooth term {term_idx} out of range"))?
5873 .coeff_range
5874 .clone();
5875 if design_local.ncols() != coeff_range.len() {
5876 return Err(SmoothError::dimension_mismatch(format!(
5877 "incremental realizer width mismatch for term {}: rebuilt_cols={}, cached_cols={}",
5878 term_idx,
5879 design_local.ncols(),
5880 coeff_range.len()
5881 ))
5882 .into());
5883 }
5884 if design_local.nrows() != self.design.design.nrows() {
5885 return Err(SmoothError::dimension_mismatch(format!(
5886 "incremental realizer row mismatch for term {}: rebuilt_rows={}, design_rows={}",
5887 term_idx,
5888 design_local.nrows(),
5889 self.design.design.nrows()
5890 ))
5891 .into());
5892 }
5893
5894 let active_penaltyinfo = penaltyinfo_local
5895 .iter()
5896 .filter(|info| info.active)
5897 .cloned()
5898 .collect::<Vec<_>>();
5899 let smooth_penalty_range = self
5900 .smooth_penalty_ranges
5901 .get(term_idx)
5902 .ok_or_else(|| {
5903 format!("incremental realizer missing smooth penalty range for term {term_idx}")
5904 })?
5905 .clone();
5906 let full_penalty_range = self
5907 .full_penalty_ranges
5908 .get(term_idx)
5909 .ok_or_else(|| {
5910 format!("incremental realizer missing full penalty range for term {term_idx}")
5911 })?
5912 .clone();
5913 if active_penaltyinfo.len() != smooth_penalty_range.len()
5914 || penalties_local.len() != smooth_penalty_range.len()
5915 || nullspace_dims.len() != smooth_penalty_range.len()
5916 {
5917 return Err(SmoothError::dimension_mismatch(format!(
5918 "incremental realizer topology changed for term '{}': penalties={}, infos={}, nullspaces={}, cached_penalties={}",
5919 name,
5920 penalties_local.len(),
5921 active_penaltyinfo.len(),
5922 nullspace_dims.len(),
5923 smooth_penalty_range.len()
5924 ))
5925 .into());
5926 }
5927
5928 self.design.smooth.term_designs[term_idx] = design_local;
5929
5930 for (offset, penalty_local) in penalties_local.iter().enumerate() {
5931 let smooth_penalty_idx = smooth_penalty_range.start + offset;
5932 let full_penalty_idx = full_penalty_range.start + offset;
5933 let nullspace_dim = nullspace_dims[offset];
5934 let penalty_info = active_penaltyinfo[offset].clone();
5935
5936 if penalty_local.nrows() != coeff_range.len()
5937 || penalty_local.ncols() != coeff_range.len()
5938 {
5939 return Err(SmoothError::dimension_mismatch(format!(
5940 "incremental realizer penalty shape mismatch for term '{}' penalty {}: \
5941 penalty is {}x{} but coeff_range has {} columns",
5942 name,
5943 offset,
5944 penalty_local.nrows(),
5945 penalty_local.ncols(),
5946 coeff_range.len()
5947 ))
5948 .into());
5949 }
5950
5951 let smooth_penalty = self
5952 .design
5953 .smooth
5954 .penalties
5955 .get_mut(smooth_penalty_idx)
5956 .ok_or_else(|| {
5957 format!(
5958 "incremental realizer smooth penalty {} out of range for term {}",
5959 smooth_penalty_idx, term_idx
5960 )
5961 })?;
5962 smooth_penalty.local.assign(penalty_local);
5965
5966 let full_bp = self
5967 .design
5968 .penalties
5969 .get_mut(full_penalty_idx)
5970 .ok_or_else(|| {
5971 format!(
5972 "incremental realizer full penalty {} out of range for term {}",
5973 full_penalty_idx, term_idx
5974 )
5975 })?;
5976 full_bp.local.assign(penalty_local);
5979
5980 self.design.smooth.nullspace_dims[smooth_penalty_idx] = nullspace_dim;
5981 self.design.nullspace_dims[full_penalty_idx] = nullspace_dim;
5982
5983 self.design.smooth.penaltyinfo[smooth_penalty_idx].global_index = smooth_penalty_idx;
5984 self.design.smooth.penaltyinfo[smooth_penalty_idx].termname = Some(name.clone());
5985 self.design.smooth.penaltyinfo[smooth_penalty_idx].penalty = penalty_info.clone();
5986
5987 self.design.penaltyinfo[full_penalty_idx].global_index = full_penalty_idx;
5988 self.design.penaltyinfo[full_penalty_idx].termname = Some(name.clone());
5989 self.design.penaltyinfo[full_penalty_idx].penalty = penalty_info;
5990 }
5991
5992 let target_term = self.design.smooth.terms.get_mut(term_idx).ok_or_else(|| {
5993 format!("incremental realizer smooth term {term_idx} disappeared during replacement")
5994 })?;
5995 target_term.penalties_local = penalties_local;
5996 target_term.nullspace_dims = nullspace_dims;
5997 target_term.penaltyinfo_local = penaltyinfo_local;
5998 target_term.metadata = metadata;
5999 target_term.lower_bounds_local = lower_bounds_local;
6000 target_term.linear_constraints_local = linear_constraints_local;
6001 target_term.joint_null_rotation = joint_null_rotation;
6002 self.dropped_penaltyinfo_by_term[term_idx] = dropped_penaltyinfo;
6003 log::info!(
6004 "[STAGE] smooth basis rebuild (term {}, '{}', cols={}): {:.3}s",
6005 term_idx,
6006 target_term.name,
6007 coeff_range.len(),
6008 t_replace.elapsed().as_secs_f64(),
6009 );
6010 Ok(())
6011 }
6012
6013 fn refresh_full_design_operator(&mut self) -> Result<(), String> {
6014 let mut blocks = Vec::<DesignBlock>::with_capacity(
6015 self.fixed_blocks.len() + self.design.smooth.term_designs.len(),
6016 );
6017 blocks.extend(self.fixed_blocks.iter().cloned());
6018 for term_design in &self.design.smooth.term_designs {
6019 blocks.push(DesignBlock::from(term_design));
6020 }
6021 self.design.design = assemble_term_collection_design_matrix(blocks)
6022 .map_err(|e| format!("failed to refresh term-collection design: {e}"))?;
6023 Ok(())
6024 }
6025}
6026
6027fn build_term_collection_fixed_blocks(
6028 data: ArrayView2<'_, f64>,
6029 spec: &TermCollectionSpec,
6030) -> Result<Vec<DesignBlock>, BasisError> {
6031 let mut blocks = Vec::<DesignBlock>::new();
6032 if !term_collection_has_one_sided_anchored_bspline(spec) {
6033 blocks.push(DesignBlock::Intercept(data.nrows()));
6034 }
6035
6036 if !spec.linear_terms.is_empty() {
6037 let mut linear_block = Array2::<f64>::zeros((data.nrows(), spec.linear_terms.len()));
6038 for (j, linear) in spec.linear_terms.iter().enumerate() {
6039 let column = linear
6043 .realized_design_column(data)
6044 .map_err(BasisError::InvalidInput)?;
6045 linear_block.column_mut(j).assign(&column);
6046 }
6047 blocks.push(DesignBlock::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
6048 linear_block,
6049 )));
6050 }
6051
6052 for term in &spec.random_effect_terms {
6053 let block = build_random_effect_block(data, term)?;
6054 let re_op = RandomEffectOperator::new(block.group_ids, block.num_groups);
6055 blocks.push(DesignBlock::RandomEffect(Arc::new(re_op)));
6056 }
6057
6058 Ok(blocks)
6059}
6060
/// Output bundle of a spatial length-scale optimisation.
pub struct SpatialLengthScaleOptimizationResult<FitOut> {
    /// Resolved term-collection specs (presumably one per model block — confirm with the
    /// producer).
    pub resolved_specs: Vec<TermCollectionSpec>,
    /// Realized designs, presumably index-aligned with `resolved_specs`.
    pub designs: Vec<TermCollectionDesign>,
    /// Fit output; its type is chosen by the caller via `FitOut`.
    pub fit: FitOut,
    /// Optional timing information for the optimisation.
    pub timing: Option<SpatialLengthScaleOptimizationTiming>,
}
6071
/// Seed and box bounds for the joint hyperparameter vector θ, laid out as
/// `[rho | log_kappa | auxiliary]` by `theta0`, `lower` and `upper`.
#[derive(Debug, Clone)]
pub struct ExactJointHyperSetup {
    /// ρ seed, clamped into `[rho_lower, rho_upper]` by the constructor.
    rho0: Array1<f64>,
    /// Lower bounds for the ρ segment.
    rho_lower: Array1<f64>,
    /// Upper bounds for the ρ segment.
    rho_upper: Array1<f64>,
    /// Log-κ seed; its length defines the width of the log-κ segment of θ.
    log_kappa0: SpatialLogKappaCoords,
    /// Lower bounds for the log-κ segment.
    log_kappa_lower: SpatialLogKappaCoords,
    /// Upper bounds for the log-κ segment.
    log_kappa_upper: SpatialLogKappaCoords,
    /// Auxiliary seed (empty unless `with_auxiliary` is used), clamped like `rho0`.
    auxiliary0: Array1<f64>,
    /// Lower bounds for the auxiliary segment.
    auxiliary_lower: Array1<f64>,
    /// Upper bounds for the auxiliary segment.
    auxiliary_upper: Array1<f64>,
}
6085
6086impl ExactJointHyperSetup {
6087 fn sanitize_rho_seed(
6088 rho0: Array1<f64>,
6089 rho_lower: &Array1<f64>,
6090 rho_upper: &Array1<f64>,
6091 ) -> Array1<f64> {
6092 Array1::from_iter(rho0.iter().enumerate().map(|(idx, &value)| {
6093 let lo = rho_lower[idx];
6094 let hi = rho_upper[idx];
6095 let fallback = 0.0_f64.clamp(lo, hi);
6096 if value.is_finite() {
6097 value.clamp(lo, hi)
6098 } else {
6099 fallback
6100 }
6101 }))
6102 }
6103
6104 pub(crate) fn new(
6105 rho0: Array1<f64>,
6106 rho_lower: Array1<f64>,
6107 rho_upper: Array1<f64>,
6108 log_kappa0: SpatialLogKappaCoords,
6109 log_kappa_lower: SpatialLogKappaCoords,
6110 log_kappa_upper: SpatialLogKappaCoords,
6111 ) -> Self {
6112 let rho0 = Self::sanitize_rho_seed(rho0, &rho_lower, &rho_upper);
6113 Self {
6114 rho0,
6115 rho_lower,
6116 rho_upper,
6117 log_kappa0,
6118 log_kappa_lower,
6119 log_kappa_upper,
6120 auxiliary0: Array1::zeros(0),
6121 auxiliary_lower: Array1::zeros(0),
6122 auxiliary_upper: Array1::zeros(0),
6123 }
6124 }
6125
6126 pub(crate) fn with_auxiliary(
6127 mut self,
6128 auxiliary0: Array1<f64>,
6129 auxiliary_lower: Array1<f64>,
6130 auxiliary_upper: Array1<f64>,
6131 ) -> Self {
6132 assert_eq!(
6133 auxiliary0.len(),
6134 auxiliary_lower.len(),
6135 "auxiliary lower bound length mismatch"
6136 );
6137 assert_eq!(
6138 auxiliary0.len(),
6139 auxiliary_upper.len(),
6140 "auxiliary upper bound length mismatch"
6141 );
6142 self.auxiliary0 = Self::sanitize_rho_seed(auxiliary0, &auxiliary_lower, &auxiliary_upper);
6143 self.auxiliary_lower = auxiliary_lower;
6144 self.auxiliary_upper = auxiliary_upper;
6145 self
6146 }
6147
6148 pub(crate) fn rho_dim(&self) -> usize {
6149 self.rho0.len()
6150 }
6151
6152 pub(crate) fn log_kappa_dim(&self) -> usize {
6153 self.log_kappa0.len()
6154 }
6155
6156 pub(crate) fn auxiliary_dim(&self) -> usize {
6157 self.auxiliary0.len()
6158 }
6159
6160 pub(crate) fn theta0(&self) -> Array1<f64> {
6161 let mut out =
6162 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6163 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho0);
6164 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6165 .assign(self.log_kappa0.as_array());
6166 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6167 .assign(&self.auxiliary0);
6168 out
6169 }
6170
6171 pub(crate) fn lower(&self) -> Array1<f64> {
6172 let mut out =
6173 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6174 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_lower);
6175 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6176 .assign(self.log_kappa_lower.as_array());
6177 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6178 .assign(&self.auxiliary_lower);
6179 out
6180 }
6181
6182 pub(crate) fn upper(&self) -> Array1<f64> {
6183 let mut out =
6184 Array1::<f64>::zeros(self.rho_dim() + self.log_kappa_dim() + self.auxiliary_dim());
6185 out.slice_mut(s![..self.rho_dim()]).assign(&self.rho_upper);
6186 out.slice_mut(s![self.rho_dim()..self.rho_dim() + self.log_kappa_dim()])
6187 .assign(self.log_kappa_upper.as_array());
6188 out.slice_mut(s![self.rho_dim() + self.log_kappa_dim()..])
6189 .assign(&self.auxiliary_upper);
6190 out
6191 }
6192
6193 pub(crate) fn log_kappa_dims_per_term(&self) -> Vec<usize> {
6195 self.log_kappa0.dims_per_term().to_vec()
6196 }
6197}
6198
/// θ-keyed cache holding one incremental realizer per model block for exact-joint evaluation.
struct ExactJointDesignCache<'d> {
    /// One incremental realizer per block, in block order.
    realizers: Vec<FrozenTermCollectionIncrementalRealizer<'d>>,
    /// Per block, the smooth-term indices (into that block's spec) that receive log-κ values.
    block_term_indices: Vec<Vec<usize>>,
    /// θ the realizers are currently synchronised to, if any.
    current_theta: Option<Array1<f64>>,
    /// Cost recorded for `current_theta`. Cleared whenever θ changes.
    last_cost: Option<f64>,
    /// Cached evaluation for `current_theta`, presumably (cost, gradient, Hessian) — confirm
    /// at the memo macro. Cleared whenever θ changes.
    last_eval: Option<(
        f64,
        Array1<f64>,
        gam_problem::HessianResult,
    )>,
    /// Width of the ρ prefix of θ.
    rho_dim: usize,
    /// Log-κ width of every spatial term across all blocks. It is split per block by
    /// `block_term_counts`.
    all_dims: Vec<usize>,
    /// Sum of `all_dims`: the width of the log-κ segment that follows ρ in θ.
    log_kappa_dim: usize,
    /// Number of spatial terms in each block.
    block_term_counts: Vec<usize>,
}
6219
6220impl<'d> ExactJointDesignCache<'d> {
6221 fn new(
6222 data: ArrayView2<'d, f64>,
6223 blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)>,
6224 rho_dim: usize,
6225 all_dims: Vec<usize>,
6226 ) -> Result<Self, String> {
6227 let n_blocks = blocks.len();
6228 let mut realizers = Vec::with_capacity(n_blocks);
6229 let mut block_term_indices = Vec::with_capacity(n_blocks);
6230 let mut block_term_counts = Vec::with_capacity(n_blocks);
6231
6232 for (spec, design, terms) in blocks {
6233 block_term_counts.push(terms.len());
6234 block_term_indices.push(terms);
6235 realizers.push(FrozenTermCollectionIncrementalRealizer::new(
6236 data, spec, design,
6237 )?);
6238 }
6239
6240 Ok(Self {
6241 realizers,
6242 block_term_indices,
6243 current_theta: None,
6244 last_cost: None,
6245 last_eval: None,
6246 rho_dim,
6247 log_kappa_dim: all_dims.iter().sum(),
6248 all_dims,
6249 block_term_counts,
6250 })
6251 }
6252
6253 fn ensure_theta(&mut self, theta: &Array1<f64>) -> Result<(), String> {
6254 if self
6255 .current_theta
6256 .as_ref()
6257 .is_some_and(|cached| theta_values_match(cached, theta))
6258 {
6259 return Ok(());
6260 }
6261
6262 let t_ensure = std::time::Instant::now();
6263 let kappa_theta_len = self.rho_dim + self.log_kappa_dim;
6264 if theta.len() < kappa_theta_len {
6265 return Err(SmoothError::dimension_mismatch(format!(
6266 "exact-joint theta length mismatch: got {}, expected at least {} (rho_dim={}, log_kappa_dim={})",
6267 theta.len(),
6268 kappa_theta_len,
6269 self.rho_dim,
6270 self.log_kappa_dim
6271 ))
6272 .into());
6273 }
6274 let theta_kappa = theta.slice(s![..kappa_theta_len]).to_owned();
6275 let full_log_kappa = SpatialLogKappaCoords::from_theta_tail_with_dims(
6276 &theta_kappa,
6277 self.rho_dim,
6278 self.all_dims.clone(),
6279 );
6280
6281 let n = self.realizers.len();
6285 let mut remaining = full_log_kappa;
6286 for block_idx in 0..n {
6287 let count = self.block_term_counts[block_idx];
6288 if block_idx < n - 1 {
6289 let (block_lk, rest) = remaining.split_at(count);
6290 self.realizers[block_idx]
6291 .apply_log_kappa(&block_lk, &self.block_term_indices[block_idx])?;
6292 remaining = rest;
6293 } else {
6294 self.realizers[block_idx]
6296 .apply_log_kappa(&remaining, &self.block_term_indices[block_idx])?;
6297 }
6298 }
6299
6300 log::info!(
6301 "[STAGE] ensure_theta (n-block, {} blocks, {} realizers): {:.3}s",
6302 n,
6303 self.realizers.len(),
6304 t_ensure.elapsed().as_secs_f64(),
6305 );
6306 self.current_theta = Some(theta.clone());
6307 self.last_cost = None;
6308 self.last_eval = None;
6309 Ok(())
6310 }
6311
6312 impl_exact_joint_theta_memo!();
6313
6314 fn store_cost_only(&mut self, theta: &Array1<f64>, cost: f64) {
6320 if self
6321 .current_theta
6322 .as_ref()
6323 .is_some_and(|cached| theta_values_match(cached, theta))
6324 {
6325 self.last_cost = Some(cost);
6326 }
6327 }
6328
6329 fn specs(&self) -> Vec<&TermCollectionSpec> {
6330 self.realizers.iter().map(|r| r.spec()).collect()
6331 }
6332
6333 fn designs(&self) -> Vec<&TermCollectionDesign> {
6334 self.realizers.iter().map(|r| r.design()).collect()
6335 }
6336
6337 fn design_revision(&self) -> u64 {
6347 self.realizers
6348 .iter()
6349 .fold(0u64, |acc, r| acc.wrapping_add(r.design_revision()))
6350 }
6351}
6352
6353pub(crate) fn seed_risk_profile_for_likelihood_family(
6354 family: &LikelihoodSpec,
6355) -> gam_problem::SeedRiskProfile {
6356 match &family.response {
6357 ResponseFamily::Gaussian => gam_problem::SeedRiskProfile::Gaussian,
6358 ResponseFamily::RoystonParmar => gam_problem::SeedRiskProfile::Survival,
6359 ResponseFamily::Binomial
6360 | ResponseFamily::Poisson
6361 | ResponseFamily::Tweedie { .. }
6362 | ResponseFamily::NegativeBinomial { .. }
6363 | ResponseFamily::Beta { .. }
6364 | ResponseFamily::Gamma => gam_problem::SeedRiskProfile::GeneralizedLinear,
6365 }
6366}
6367
/// NOTE(review): by its name, a cap on the outer theta dimension for second-order
/// exact-joint handling; no use is visible in this chunk, so confirm the exact semantics
/// at its call sites.
const EXACT_JOINT_SECOND_ORDER_THETA_CAP: usize = 8;
6376
6377fn exact_joint_seed_config(
6378 risk_profile: gam_problem::SeedRiskProfile,
6379 auxiliary_dim: usize,
6380) -> gam_problem::SeedConfig {
6381 let mut config = gam_problem::SeedConfig {
6382 risk_profile,
6383 num_auxiliary_trailing: auxiliary_dim,
6384 ..Default::default()
6385 };
6386 match risk_profile {
6387 gam_problem::SeedRiskProfile::Gaussian
6388 | gam_problem::SeedRiskProfile::GaussianLocationScale => {
6389 config.max_seeds = 4;
6390 config.seed_budget = 2;
6391 }
6392 gam_problem::SeedRiskProfile::GeneralizedLinear => {
6393 config.max_seeds = 1;
6398 config.seed_budget = 1;
6399 config.screen_max_inner_iterations = 8;
6400 }
6401 gam_problem::SeedRiskProfile::Survival => {
6402 config.max_seeds = 8;
6408 config.seed_budget = 4;
6409 config.screen_max_inner_iterations = 8;
6410 }
6411 }
6412 config
6413}
6414
#[cfg(test)]
mod exact_joint_seed_config_tests {
    use super::*;

    // Each assertion compares the tuple
    // (max_seeds, seed_budget, screen_max_inner_iterations, num_auxiliary_trailing).

    #[test]
    fn exact_joint_marginal_slope_profiles_get_deeper_startup_validation() {
        let glm = exact_joint_seed_config(gam_problem::SeedRiskProfile::GeneralizedLinear, 2);
        assert_eq!(
            (
                glm.max_seeds,
                glm.seed_budget,
                glm.screen_max_inner_iterations,
                glm.num_auxiliary_trailing,
            ),
            (1, 1, 8, 2)
        );

        let surv = exact_joint_seed_config(gam_problem::SeedRiskProfile::Survival, 3);
        assert_eq!(
            (
                surv.max_seeds,
                surv.seed_budget,
                surv.screen_max_inner_iterations,
                surv.num_auxiliary_trailing,
            ),
            (8, 4, 8, 3)
        );
    }

    #[test]
    fn exact_joint_gaussian_keeps_tight_historical_multistart_budget() {
        let default_screen = gam_problem::SeedConfig::default().screen_max_inner_iterations;
        let gauss = exact_joint_seed_config(gam_problem::SeedRiskProfile::Gaussian, 1);
        assert_eq!(
            (
                gauss.max_seeds,
                gauss.seed_budget,
                gauss.screen_max_inner_iterations,
                gauss.num_auxiliary_trailing,
            ),
            (4, 2, default_screen, 1)
        );
    }
}
6446
6447pub(crate) fn exact_joint_multistart_outer_problem(
6448 theta0: &Array1<f64>,
6449 lower: &Array1<f64>,
6450 upper: &Array1<f64>,
6451 rho_dim: usize,
6452 auxiliary_dim: usize,
6453 n_params: usize,
6454 gradient: gam_problem::Derivative,
6455 hessian: gam_problem::DeclaredHessianForm,
6456 prefer_gradient_only: bool,
6457 disable_fixed_point: bool,
6458 risk_profile: gam_problem::SeedRiskProfile,
6459 tolerance: f64,
6460 max_iter: usize,
6461 bfgs_step_cap: Option<f64>,
6470 bfgs_step_cap_psi: Option<f64>,
6471 screening_cap: Option<Arc<AtomicUsize>>,
6472 profiled_objective_size: Option<(usize, usize)>,
6493 has_constant_curvature: bool,
6502) -> gam_solve::rho_optimizer::OuterProblem {
6503 let mut seed_heuristic = theta0.to_vec();
6504 for value in &mut seed_heuristic[..rho_dim] {
6505 *value = value.exp();
6506 }
6507 let rho_ceiling = if has_constant_curvature {
6512 gam_solve::estimate::RHO_BOUND
6513 } else {
6514 12.0
6515 };
6516 let mut problem = gam_solve::rho_optimizer::OuterProblem::new(n_params)
6517 .with_gradient(gradient)
6518 .with_hessian(hessian)
6519 .with_prefer_gradient_only(prefer_gradient_only)
6520 .with_disable_fixed_point(disable_fixed_point)
6521 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Automatic)
6531 .with_psi_dim(auxiliary_dim)
6532 .with_tolerance(tolerance)
6533 .with_max_iter(max_iter)
6534 .with_bounds(lower.clone(), upper.clone())
6535 .with_initial_rho(theta0.clone())
6536 .with_bfgs_step_cap(bfgs_step_cap)
6537 .with_bfgs_step_cap_psi(bfgs_step_cap_psi)
6538 .with_seed_config({
6539 let mut sc = exact_joint_seed_config(risk_profile, auxiliary_dim);
6540 if has_constant_curvature {
6541 sc.bounds = (sc.bounds.0, rho_ceiling);
6545 }
6564 sc
6565 })
6566 .with_rho_bound(rho_ceiling)
6567 .with_heuristic_lambdas(seed_heuristic);
6568 if let Some((n_obs, p_cols)) = profiled_objective_size {
6569 problem = problem
6577 .with_objective_scale(Some(n_obs as f64))
6578 .with_problem_size(n_obs, p_cols)
6579 .with_arc_initial_regularization(Some(0.25))
6580 .with_operator_initial_trust_radius(Some(4.0));
6581 }
6582 if let Some(screening_cap) = screening_cap {
6583 problem = problem
6584 .with_screening_cap(screening_cap)
6585 .with_screen_initial_rho(true);
6586 }
6587 problem
6588}
6589
/// Returns `true` when a κ-phase optimizer error message contains one of the known
/// substrings for which the driver falls back to fitting at the fixed bootstrap κ instead
/// of propagating the error.
fn kappa_phase_failure_is_fixed_kappa_recoverable(message: &str) -> bool {
    const RECOVERABLE_MARKERS: [&str; 3] = [
        "no candidate seeds passed outer startup validation",
        "joint hyper rho dimension mismatch",
        "objective returned a non-finite cost",
    ];
    RECOVERABLE_MARKERS
        .iter()
        .any(|&marker| message.contains(marker))
}
6605
6606pub fn optimize_spatial_length_scale_exact_joint<FitOut, FitFn, ExactFn, ExactEfsFn, SeedFn>(
6607 data: ArrayView2<'_, f64>,
6608 block_specs: &[TermCollectionSpec],
6609 block_term_indices: &[Vec<usize>],
6610 kappa_options: &SpatialLengthScaleOptimizationOptions,
6611 joint_setup: &ExactJointHyperSetup,
6612 seed_risk_profile: gam_problem::SeedRiskProfile,
6613 analytic_joint_gradient_available: bool,
6614 analytic_joint_hessian_available: bool,
6615 disable_fixed_point: bool,
6616 screening_cap: Option<Arc<AtomicUsize>>,
6617 outer_derivative_policy: gam_model_api::families::custom_family::OuterDerivativePolicy,
6618 mut fit_fn: FitFn,
6619 mut exact_fn: ExactFn,
6620 mut exact_efs_fn: ExactEfsFn,
6621 mut seed_inner_beta_fn: SeedFn,
6622) -> Result<SpatialLengthScaleOptimizationResult<FitOut>, String>
6623where
6624 FitOut: Clone,
6625 FitFn: FnMut(
6626 &Array1<f64>,
6627 &[TermCollectionSpec],
6628 &[TermCollectionDesign],
6629 ) -> Result<FitOut, String>,
6630 ExactFn: FnMut(
6631 &Array1<f64>,
6632 &[TermCollectionSpec],
6633 &[TermCollectionDesign],
6634 gam_solve::estimate::reml::reml_outer_engine::EvalMode,
6635 &gam_problem::outer_subsample::RowSet,
6636 ) -> Result<
6637 (
6638 f64,
6639 Array1<f64>,
6640 gam_problem::HessianResult,
6641 ),
6642 String,
6643 >,
6644 ExactEfsFn: FnMut(
6645 &Array1<f64>,
6646 &[TermCollectionSpec],
6647 &[TermCollectionDesign],
6648 ) -> Result<gam_problem::EfsEval, String>,
6649 SeedFn:
6650 FnMut(&Array1<f64>) -> Result<gam_solve::rho_optimizer::SeedOutcome, EstimationError>,
6651{
6652 let n_blocks = block_specs.len();
6653 if block_term_indices.len() != n_blocks {
6654 return Err(SmoothError::dimension_mismatch(format!(
6655 "block_specs ({}) and block_term_indices ({}) length mismatch",
6656 n_blocks,
6657 block_term_indices.len()
6658 ))
6659 .into());
6660 }
6661
6662 let log_kappa_dim = joint_setup.log_kappa_dim();
6663
6664 log::warn!(
6665 "[OUTER-FD-AUDIT/spatial-exact-joint] driver entry: aux_dim={} log_kappa_dim={} kappa_enabled={} rho_dim={} theta0_len={}",
6666 joint_setup.auxiliary_dim(),
6667 log_kappa_dim,
6668 kappa_options.enabled,
6669 joint_setup.rho_dim(),
6670 joint_setup.theta0().len()
6671 );
6672
6673 if joint_setup.auxiliary_dim() == 0 && (!kappa_options.enabled || log_kappa_dim == 0) {
6677 log::warn!(
6678 "[OUTER-FD-AUDIT/spatial-exact-joint] taking FAST path (no outer theta optimization in this driver)"
6679 );
6680 let (designs, resolved_specs) = build_term_collection_designs_and_freeze_joint(
6681 data, block_specs,
6682 )
6683 .map_err(|e| {
6684 format!("failed to build and freeze joint block designs during exact joint kappa optimization: {e}")
6685 })?;
6686 let theta0 = joint_setup.theta0();
6687
6688 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
6690 let design_refs: Vec<TermCollectionDesign> = designs.clone();
6691 let fit = fit_fn(&theta0, &spec_refs, &design_refs)?;
6692 return Ok(SpatialLengthScaleOptimizationResult {
6693 resolved_specs,
6694 designs,
6695 fit,
6696 timing: None,
6697 });
6698 }
6699
6700 let theta0 = joint_setup.theta0();
6704 let lower = joint_setup.lower();
6705 let upper = joint_setup.upper();
6706 if theta0.len() < log_kappa_dim || lower.len() != theta0.len() || upper.len() != theta0.len() {
6707 return Err(SmoothError::dimension_mismatch(format!(
6708 "invalid exact joint theta setup: theta0={}, lower={}, upper={}, required_log_kappa_dim={}",
6709 theta0.len(),
6710 lower.len(),
6711 upper.len(),
6712 log_kappa_dim
6713 ))
6714 .into());
6715 }
6716 let rho_dim = joint_setup.rho_dim();
6717 let all_dims = joint_setup.log_kappa_dims_per_term();
6718
6719 let (boot_designs, best_specs) = build_term_collection_designs_and_freeze_joint(
6721 data,
6722 block_specs,
6723 )
6724 .map_err(|e| {
6725 format!(
6726 "failed to build and freeze joint block designs during exact joint kappa bootstrap: {e}"
6727 )
6728 })?;
6729 let policy_hessian_form = outer_derivative_policy.declared_hessian_form();
6739 let analytic_outer_hessian_available = analytic_joint_hessian_available
6740 && matches!(
6741 policy_hessian_form,
6742 gam_problem::DeclaredHessianForm::Either
6743 | gam_problem::DeclaredHessianForm::Dense
6744 | gam_problem::DeclaredHessianForm::Operator { .. }
6745 );
6746 let prefer_gradient_only = !analytic_outer_hessian_available;
6747
6748 let theta_dim = theta0.len();
6749 let psi_dim = theta_dim - rho_dim;
6750
6751 let cache_blocks: Vec<(TermCollectionSpec, TermCollectionDesign, Vec<usize>)> = best_specs
6753 .iter()
6754 .zip(boot_designs.iter())
6755 .zip(block_term_indices.iter())
6756 .map(|((spec, design), terms)| (spec.clone(), design.clone(), terms.clone()))
6757 .collect();
6758
6759 struct NBlockExactJointState<'d> {
6760 cache: ExactJointDesignCache<'d>,
6761 }
6762
6763 let mut state = NBlockExactJointState {
6764 cache: ExactJointDesignCache::new(data, cache_blocks, rho_dim, all_dims.clone())?,
6765 };
6766
6767 const KAPPA_PILOT_K: usize = 5_000;
6792 const KAPPA_POLISH_K: usize = 25_000;
6793 const KAPPA_POLISH_TRIGGER_N: usize = 100_000;
6794
6795 let n_total = data.nrows();
6796 let use_staged_kappa = outer_derivative_policy.should_use_staged_kappa(n_total);
6797 if use_staged_kappa {
6798 log::info!(
6799 "[KAPPA-STAGED] auto-engaging pilot+polish schedule: n={} pilot_k={} polish_k={}",
6800 n_total,
6801 KAPPA_PILOT_K,
6802 KAPPA_POLISH_K,
6803 );
6804 }
6805
6806 fn build_uniform_pilot_subsample(
6823 n_total: usize,
6824 k_target: usize,
6825 seed: u64,
6826 ) -> gam_problem::outer_subsample::OuterScoreSubsample {
6827 use gam_problem::outer_subsample::OuterScoreSubsample;
6828 let k = k_target.min(n_total);
6829 if k == 0 || n_total == 0 {
6830 return OuterScoreSubsample::from_uniform_inclusion_mask(Vec::new(), n_total, seed);
6831 }
6832 let mut mask: Vec<usize> = Vec::with_capacity(k);
6836 let mut state = seed.wrapping_add(0x9E3779B97F4A7C15);
6838 let splitmix = |s: &mut u64| -> u64 { gam_linalg::utils::splitmix64(s) };
6839 let mut taken = std::collections::HashSet::with_capacity(k);
6840 for j in (n_total - k)..n_total {
6841 let r = (splitmix(&mut state) % (j as u64 + 1)) as usize;
6842 if !taken.insert(r) {
6843 taken.insert(j);
6844 mask.push(j);
6845 } else {
6846 mask.push(r);
6847 }
6848 }
6849 mask.sort_unstable();
6850 mask.dedup();
6851 OuterScoreSubsample::from_uniform_inclusion_mask(mask, n_total, seed)
6852 }
6853
6854 let current_row_set: std::cell::RefCell<gam_problem::outer_subsample::RowSet> = if use_staged_kappa {
6855 let pilot = build_uniform_pilot_subsample(n_total, KAPPA_PILOT_K, n_total as u64);
6856 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::Subsample {
6857 rows: std::sync::Arc::clone(&pilot.rows),
6858 n_full: n_total,
6859 })
6860 } else {
6861 std::cell::RefCell::new(gam_problem::outer_subsample::RowSet::All)
6862 };
6863
6864 let exact_fn_cell = std::cell::RefCell::new(&mut exact_fn);
6865 let exact_efs_fn_cell = std::cell::RefCell::new(&mut exact_efs_fn);
6866
6867 use std::cell::Cell;
6882 let kphase_cost_calls: Cell<usize> = Cell::new(0);
6883 let kphase_cost_total_s: Cell<f64> = Cell::new(0.0);
6884 let kphase_eval_calls: Cell<usize> = Cell::new(0);
6885 let kphase_eval_total_s: Cell<f64> = Cell::new(0.0);
6886 let kphase_efs_calls: Cell<usize> = Cell::new(0);
6887 let kphase_efs_total_s: Cell<f64> = Cell::new(0.0);
6888 let kphase_optim_start = std::time::Instant::now();
6889 let kphase_log_kappa_dim = log_kappa_dim;
6890 let kphase_log_norms = |theta: &Array1<f64>| -> (f64, f64) {
6891 let theta_norm = theta.iter().map(|v| v * v).sum::<f64>().sqrt();
6892 let log_kappa_norm = if kphase_log_kappa_dim > 0 && theta.len() >= kphase_log_kappa_dim {
6893 let start = theta.len() - kphase_log_kappa_dim;
6894 theta.iter().skip(start).map(|v| v * v).sum::<f64>().sqrt()
6895 } else {
6896 0.0
6897 };
6898 (theta_norm, log_kappa_norm)
6899 };
6900
6901 use gam_solve::rho_optimizer::OuterEvalOrder;
6902 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
6903
6904 let joint_p_cols: usize = boot_designs
6908 .iter()
6909 .map(|d| d.design.ncols())
6910 .sum::<usize>()
6911 .max(1);
6912
6913 let problem = exact_joint_multistart_outer_problem(
6914 &theta0,
6915 &lower,
6916 &upper,
6917 rho_dim,
6918 psi_dim,
6919 theta_dim,
6920 if analytic_joint_gradient_available {
6921 Derivative::Analytic
6922 } else {
6923 Derivative::Unavailable
6924 },
6925 if analytic_outer_hessian_available {
6926 DeclaredHessianForm::Either
6927 } else {
6928 DeclaredHessianForm::Unavailable
6929 },
6930 prefer_gradient_only,
6931 disable_fixed_point,
6932 seed_risk_profile,
6933 kappa_options.rel_tol.max(1e-6),
6934 kappa_options.max_outer_iter.max(1),
6935 Some(5.0),
6937 Some(kappa_options.log_step.clamp(0.25, 1.0)),
6939 screening_cap.clone(),
6940 Some((n_total, joint_p_cols)),
6943 block_specs
6946 .iter()
6947 .any(|s| !constant_curvature_term_indices(s).is_empty()),
6948 );
6949
6950 fn collect_specs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionSpec> {
6952 cache.specs().into_iter().cloned().collect()
6953 }
6954 fn collect_designs(cache: &ExactJointDesignCache<'_>) -> Vec<TermCollectionDesign> {
6955 cache.designs().into_iter().cloned().collect()
6956 }
6957
6958 let result = {
6959 let eval_outer = |ctx: &mut &mut NBlockExactJointState<'_>,
6960 theta: &Array1<f64>,
6961 order: OuterEvalOrder|
6962 -> Result<OuterEval, EstimationError> {
6963 if let Some((cost, grad, hess)) = ctx.cache.memoized_eval(theta) {
6964 let cached_satisfies_order = match order {
6965 OuterEvalOrder::Value => true,
6966 OuterEvalOrder::ValueAndGradient => true,
6967 OuterEvalOrder::ValueGradientHessian => hess.is_analytic(),
6968 };
6969 if cached_satisfies_order {
6970 if !cost.is_finite() {
6971 return Ok(OuterEval::infeasible(theta.len()));
6972 }
6973 if grad.iter().any(|v| !v.is_finite()) {
6986 return Ok(OuterEval::infeasible(theta.len()));
6987 }
6988 return Ok(OuterEval {
6989 cost,
6990 gradient: grad,
6991 hessian: hess,
6992 inner_beta_hint: None,
6993 });
6994 }
6995 }
6996 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7013 return Ok(OuterEval::infeasible(theta.len()));
7014 }
7015 if let Err(err) = ctx.cache.ensure_theta(theta) {
7016 log::warn!(
7017 "[OUTER] n-block exact-joint spatial: ensure_theta failed during gradient evaluation: {err}"
7018 );
7019 return Ok(OuterEval::infeasible(theta.len()));
7020 }
7021 let design_revision = Some(ctx.cache.design_revision());
7022 let specs = collect_specs(&ctx.cache);
7023 let designs = collect_designs(&ctx.cache);
7024 let clamped = outer_derivative_policy.order_for_evaluation(order);
7032 let need_hessian = matches!(clamped, OuterEvalOrder::ValueGradientHessian)
7033 && analytic_outer_hessian_available;
7034 let eval_mode = if need_hessian {
7035 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
7036 } else {
7037 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
7038 };
7039 let t0 = std::time::Instant::now();
7040 let result = {
7041 let row_set_borrow = current_row_set.borrow();
7042 (*exact_fn_cell.borrow_mut())(theta, &specs, &designs, eval_mode, &row_set_borrow)
7043 };
7044 let elapsed_s = t0.elapsed().as_secs_f64();
7045 kphase_eval_calls.set(kphase_eval_calls.get() + 1);
7046 kphase_eval_total_s.set(kphase_eval_total_s.get() + elapsed_s);
7047 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7048 log::info!(
7049 "[KAPPA-PHASE] phase=eval_outer call={} order={:?} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7050 kphase_eval_calls.get(),
7051 order,
7052 design_revision,
7053 theta_norm,
7054 log_kappa_norm,
7055 elapsed_s,
7056 );
7057 match result {
7058 Ok((cost, grad, hess)) => {
7059 ctx.cache.store_eval((cost, grad.clone(), hess.clone()));
7060 if !cost.is_finite() {
7061 return Ok(OuterEval::infeasible(theta.len()));
7062 }
7063 if grad.iter().any(|v| !v.is_finite()) {
7076 return Ok(OuterEval::infeasible(theta.len()));
7077 }
7078 Ok(OuterEval {
7079 cost,
7080 gradient: grad,
7081 hessian: hess,
7082 inner_beta_hint: None,
7083 })
7084 }
7085 Err(err) => {
7086 log::warn!(
7087 "[OUTER] n-block exact-joint spatial: exact evaluation failed: {err}"
7088 );
7089 Ok(OuterEval::infeasible(theta.len()))
7090 }
7091 }
7092 };
7093
7094 let obj = problem.build_objective_with_eval_order(
7095 &mut state,
7096 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7097 if let Some(cost) = ctx.cache.memoized_cost(theta) {
7098 return Ok(cost);
7099 }
7100 if gam_solve::rho_optimizer::outer_wall_clock_deadline_exceeded() {
7108 return Ok(f64::INFINITY);
7109 }
7110 if let Err(err) = ctx.cache.ensure_theta(theta) {
7111 log::warn!(
7112 "[OUTER] n-block exact-joint spatial: ensure_theta failed during cost evaluation: {err}"
7113 );
7114 return Ok(f64::INFINITY);
7115 }
7116 let design_revision = Some(ctx.cache.design_revision());
7117 let specs = collect_specs(&ctx.cache);
7118 let designs = collect_designs(&ctx.cache);
7119 let t0 = std::time::Instant::now();
7126 let result = {
7127 let row_set_borrow = current_row_set.borrow();
7128 (*exact_fn_cell.borrow_mut())(
7129 theta,
7130 &specs,
7131 &designs,
7132 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
7133 &row_set_borrow,
7134 )
7135 };
7136 let elapsed_s = t0.elapsed().as_secs_f64();
7137 kphase_cost_calls.set(kphase_cost_calls.get() + 1);
7138 kphase_cost_total_s.set(kphase_cost_total_s.get() + elapsed_s);
7139 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7140 log::info!(
7141 "[KAPPA-PHASE] phase=cost call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7142 kphase_cost_calls.get(),
7143 design_revision,
7144 theta_norm,
7145 log_kappa_norm,
7146 elapsed_s,
7147 );
7148 match result {
7149 Ok((cost, _grad, _hess)) => {
7150 ctx.cache.store_cost_only(theta, cost);
7156 Ok(cost)
7157 }
7158 Err(err) => {
7159 log::warn!(
7160 "[OUTER] n-block exact-joint spatial: exact cost evaluation failed: {err}"
7161 );
7162 Ok(f64::INFINITY)
7163 }
7164 }
7165 },
7166 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7167 eval_outer(
7168 ctx,
7169 theta,
7170 if analytic_outer_hessian_available {
7171 OuterEvalOrder::ValueGradientHessian
7172 } else {
7173 OuterEvalOrder::ValueAndGradient
7174 },
7175 )
7176 },
7177 |ctx: &mut &mut NBlockExactJointState<'_>,
7178 theta: &Array1<f64>,
7179 order: OuterEvalOrder| { eval_outer(ctx, theta, order) },
7180 None::<fn(&mut &mut NBlockExactJointState<'_>)>,
7181 Some(
7182 |ctx: &mut &mut NBlockExactJointState<'_>, theta: &Array1<f64>| {
7183 ctx.cache
7184 .ensure_theta(theta)
7185 .map_err(EstimationError::InvalidInput)?;
7186 let design_revision = Some(ctx.cache.design_revision());
7187 let specs = collect_specs(&ctx.cache);
7188 let designs = collect_designs(&ctx.cache);
7189 let t0 = std::time::Instant::now();
7190 let eval_result = (*exact_efs_fn_cell.borrow_mut())(
7191 theta,
7192 &specs,
7193 &designs,
7194 );
7195 let elapsed_s = t0.elapsed().as_secs_f64();
7196 kphase_efs_calls.set(kphase_efs_calls.get() + 1);
7197 kphase_efs_total_s.set(kphase_efs_total_s.get() + elapsed_s);
7198 let (theta_norm, log_kappa_norm) = kphase_log_norms(theta);
7199 log::info!(
7200 "[KAPPA-PHASE] phase=efs call={} design_revision={:?} theta_norm={:.4e} log_kappa_norm={:.4e} elapsed_s={:.4}",
7201 kphase_efs_calls.get(),
7202 design_revision,
7203 theta_norm,
7204 log_kappa_norm,
7205 elapsed_s,
7206 );
7207 let eval = eval_result.map_err(EstimationError::RemlOptimizationFailed)?;
7208 Ok(eval)
7209 },
7210 ),
7211 );
7212 let mut obj = obj.with_seed_inner_state(
7213 move |_ctx: &mut &mut NBlockExactJointState<'_>, beta: &Array1<f64>| {
7214 (seed_inner_beta_fn)(beta)
7215 },
7216 );
7217
7218 match problem.run(&mut obj, "n-block exact-joint spatial") {
7219 Ok(result) => result,
7220 Err(e) => {
7221 let message = e.to_string();
7222 if kappa_phase_failure_is_fixed_kappa_recoverable(&message) {
7242 drop(obj);
7243 log::warn!(
7244 "[KAPPA-PHASE] length-scale optimization could not validate any seed \
7245 ({message}); falling back to a FIXED bootstrap κ (skipping κ \
7246 optimization) and fitting there — a real model at the initial \
7247 length-scale rather than raising (gam#787/#860)."
7248 );
7249 let (designs, resolved_specs) =
7250 build_term_collection_designs_and_freeze_joint(data, block_specs).map_err(
7251 |build_err| {
7252 format!(
7253 "fixed-κ fallback failed to build and freeze joint block \
7254 designs after κ optimization could not validate a seed \
7255 ({message}): {build_err}"
7256 )
7257 },
7258 )?;
7259 let fixed_theta0 = joint_setup.theta0();
7260 let spec_refs: Vec<TermCollectionSpec> = resolved_specs.clone();
7261 let design_refs: Vec<TermCollectionDesign> = designs.clone();
7262 let fit = fit_fn(&fixed_theta0, &spec_refs, &design_refs)?;
7263 return Ok(SpatialLengthScaleOptimizationResult {
7264 resolved_specs,
7265 designs,
7266 fit,
7267 timing: None,
7268 });
7269 }
7270 return Err(message);
7271 }
7272 }
7273 }; let kphase_total_s = kphase_optim_start.elapsed().as_secs_f64();
7283 log::info!(
7284 "[KAPPA-PHASE-SUMMARY] log_kappa_dim={} n_cost={} cost_total_s={:.4} n_eval={} eval_total_s={:.4} n_efs={} efs_total_s={:.4} optim_total_s={:.4}",
7285 kphase_log_kappa_dim,
7286 kphase_cost_calls.get(),
7287 kphase_cost_total_s.get(),
7288 kphase_eval_calls.get(),
7289 kphase_eval_total_s.get(),
7290 kphase_efs_calls.get(),
7291 kphase_efs_total_s.get(),
7292 kphase_total_s,
7293 );
7294 let timing = SpatialLengthScaleOptimizationTiming {
7295 log_kappa_dim: kphase_log_kappa_dim,
7296 cost_calls: kphase_cost_calls.get(),
7297 cost_total_s: kphase_cost_total_s.get(),
7298 eval_calls: kphase_eval_calls.get(),
7299 eval_total_s: kphase_eval_total_s.get(),
7300 efs_calls: kphase_efs_calls.get(),
7301 efs_total_s: kphase_efs_total_s.get(),
7302 slow_path_resets: 0,
7303 design_revision_delta: 0,
7304 nfree_miss_shape: 0,
7305 nfree_miss_value: 0,
7306 nfree_miss_gradient: 0,
7307 nfree_miss_penalty: 0,
7308 nfree_miss_revision: 0,
7309 nfree_miss_second_order: 0,
7310 nfree_miss_other: 0,
7311 optim_total_s: kphase_total_s,
7312 };
7313
7314 let theta_star = result.rho;
7315
7316 if use_staged_kappa && n_total >= KAPPA_POLISH_TRIGGER_N {
7333 let polish = build_uniform_pilot_subsample(
7334 n_total,
7335 KAPPA_POLISH_K,
7336 (n_total as u64).wrapping_add(0xA5A5A5A5),
7337 );
7338 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::Subsample {
7339 rows: std::sync::Arc::clone(&polish.rows),
7340 n_full: n_total,
7341 };
7342 log::info!(
7343 "[KAPPA-STAGED] rotating to polish subsample: k={} at theta_star",
7344 polish.rows.len(),
7345 );
7346 state.cache.ensure_theta(&theta_star)?;
7350 let (polish_cost, polish_grad, _) = {
7351 let specs = collect_specs(&state.cache);
7352 let designs = collect_designs(&state.cache);
7353 let row_set_borrow = current_row_set.borrow();
7354 exact_fn(
7355 &theta_star,
7356 &specs,
7357 &designs,
7358 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
7359 &row_set_borrow,
7360 )?
7361 };
7362 if !polish_cost.is_finite() || polish_grad.iter().any(|value| !value.is_finite()) {
7363 return Err(
7364 "polish subsample exact-joint evaluation produced non-finite objective pieces"
7365 .to_string(),
7366 );
7367 }
7368 }
7369 *current_row_set.borrow_mut() = gam_problem::outer_subsample::RowSet::All;
7370 if use_staged_kappa {
7371 log::info!(
7372 "[KAPPA-STAGED] rotating to full data for final coefficient fit (n={})",
7373 n_total,
7374 );
7375 }
7376
7377 state.cache.ensure_theta(&theta_star)?;
7378
7379 let resolved_specs: Vec<TermCollectionSpec> = collect_specs(&state.cache);
7380 let designs: Vec<TermCollectionDesign> = collect_designs(&state.cache);
7381
7382 let fit = fit_fn(&theta_star, &resolved_specs, &designs)?;
7383
7384 for spec in &resolved_specs {
7385 log_spatial_aniso_scales(spec);
7386 }
7387
7388 Ok(SpatialLengthScaleOptimizationResult {
7389 resolved_specs,
7390 designs,
7391 fit,
7392 timing: Some(timing),
7393 })
7394}
7395
7396fn try_exact_joint_latent_coord_optimization(
7397 data: ArrayView2<'_, f64>,
7398 y: ArrayView1<'_, f64>,
7399 weights: ArrayView1<'_, f64>,
7400 offset: ArrayView1<'_, f64>,
7401 resolvedspec: &TermCollectionSpec,
7402 best: &FittedTermCollection,
7403 family: LikelihoodSpec,
7404 options: &FitOptions,
7405 latent: &StandardLatentCoordConfig,
7406) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7407 use gam_solve::rho_optimizer::OuterEvalOrder;
7408 use gam_problem::{DeclaredHessianForm, Derivative, OuterEval};
7409
7410 let rho_dim = best.fit.lambdas.len();
7411 let latent_flat_dim = latent.values.len();
7412 if latent_flat_dim == 0 {
7413 crate::bail_invalid_estim!(
7414 "latent-coordinate optimization requires a non-empty latent block"
7415 );
7416 }
7417 let direct_hypers =
7418 latent_coord_initial_direct_hypers(latent.values.id_mode(), latent.values.latent_dim())?;
7419 let analytic_rho_count = latent
7420 .analytic_penalties
7421 .as_ref()
7422 .map_or(0, |registry| registry.total_rho_count());
7423 let latent_coord_ext_dim = latent_flat_dim + analytic_rho_count + direct_hypers.len();
7424
7425 let mut theta0 = Array1::<f64>::zeros(rho_dim + latent_coord_ext_dim);
7426 theta0
7427 .slice_mut(s![..rho_dim])
7428 .assign(&best.fit.lambdas.mapv(f64::ln));
7429 theta0
7430 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
7431 .assign(latent.values.as_flat());
7432 if !direct_hypers.is_empty() {
7433 let direct_start = rho_dim + latent_flat_dim + analytic_rho_count;
7434 theta0
7435 .slice_mut(s![direct_start..direct_start + direct_hypers.len()])
7436 .assign(&direct_hypers);
7437 }
7438
7439 let mut lower = Array1::<f64>::from_elem(theta0.len(), -12.0);
7440 let mut upper = Array1::<f64>::from_elem(theta0.len(), 12.0);
7441 let latent_bound = latent
7442 .values
7443 .as_flat()
7444 .iter()
7445 .fold(1.0_f64, |acc, &v| acc.max(v.abs()))
7446 + 10.0;
7447 for axis in rho_dim..rho_dim + latent_flat_dim {
7448 lower[axis] = -latent_bound;
7449 upper[axis] = latent_bound;
7450 }
7451
7452 struct LatentJointContext<'d> {
7453 rho_dim: usize,
7454 cache: SingleBlockLatentCoordDesignCache,
7455 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator<'d>,
7456 }
7457
7458 impl<'d> LatentJointContext<'d> {
7459 fn eval_full(
7460 &mut self,
7461 theta: &Array1<f64>,
7462 order: OuterEvalOrder,
7463 ) -> Result<
7464 (
7465 f64,
7466 Array1<f64>,
7467 gam_problem::HessianResult,
7468 ),
7469 EstimationError,
7470 > {
7471 if let Some(eval) = self.cache.memoized_eval(theta) {
7472 return Ok(eval);
7473 }
7474 self.cache
7475 .ensure_theta(theta)
7476 .map_err(EstimationError::InvalidInput)?;
7477 let hyper_dirs = self
7478 .cache
7479 .hyper_dirs()
7480 .map_err(EstimationError::InvalidInput)?;
7481 let design_revision = Some(self.cache.design_revision());
7482 let registry_for_key = self.cache.analytic_penalties();
7483 self.evaluator
7484 .set_analytic_penalty_registry(registry_for_key.as_deref());
7485 let mut eval = evaluate_joint_reml_outer_eval_at_theta(
7486 &mut self.evaluator,
7487 self.cache.design(),
7488 theta,
7489 self.rho_dim,
7490 hyper_dirs,
7491 None,
7492 order,
7493 design_revision,
7494 )?;
7495 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
7496 if let Some(registry) = registry_for_key {
7497 let mut registry = registry.as_ref().clone();
7498 registry.apply_weight_schedules(
7499 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
7500 );
7501 add_analytic_penalty_objective_to_eval(
7502 theta,
7503 self.rho_dim,
7504 latent.as_ref(),
7505 ®istry,
7506 &mut eval,
7507 )?;
7508 }
7509 add_latent_id_objective_to_eval(
7510 theta,
7511 self.rho_dim,
7512 self.cache.analytic_penalty_rho_count(),
7513 latent.as_ref(),
7514 &mut eval,
7515 )?;
7516 self.cache.store_eval(eval.clone());
7517 Ok(eval)
7518 }
7519
7520 fn eval_efs(
7521 &mut self,
7522 theta: &Array1<f64>,
7523 ) -> Result<gam_problem::EfsEval, EstimationError> {
7524 self.cache
7525 .ensure_theta(theta)
7526 .map_err(EstimationError::InvalidInput)?;
7527 let hyper_dirs = self
7528 .cache
7529 .hyper_dirs()
7530 .map_err(EstimationError::InvalidInput)?;
7531 let registry_for_key = self.cache.analytic_penalties();
7532 self.evaluator
7533 .set_analytic_penalty_registry(registry_for_key.as_deref());
7534 let mut efs = evaluate_joint_reml_efs_at_theta(
7535 &mut self.evaluator,
7536 self.cache.design(),
7537 theta,
7538 self.rho_dim,
7539 hyper_dirs,
7540 None,
7541 Some(self.cache.design_revision()),
7542 )?;
7543 if let Some(registry) = registry_for_key {
7544 let mut registry = registry.as_ref().clone();
7545 registry.apply_weight_schedules(
7546 gam_solve::estimate::reml::outer_eval::current_outer_iter() as usize,
7547 );
7548 let latent = self.cache.latent().map_err(EstimationError::InvalidInput)?;
7549 let contribution = analytic_penalty_objective_contribution(
7550 theta,
7551 self.rho_dim,
7552 latent.as_ref(),
7553 ®istry,
7554 )?;
7555 efs.cost += contribution.cost;
7556 if let (Some(psi_gradient), Some(psi_indices)) =
7557 (efs.psi_gradient.as_mut(), efs.psi_indices.as_ref())
7558 {
7559 if psi_gradient.len() != psi_indices.len() {
7560 crate::bail_invalid_estim!(
7561 "latent-coordinate analytic penalty EFS psi gradient length mismatch: gradient={}, indices={}",
7562 psi_gradient.len(),
7563 psi_indices.len()
7564 );
7565 }
7566 for (local_idx, &theta_idx) in psi_indices.iter().enumerate() {
7567 psi_gradient[local_idx] += contribution.gradient[theta_idx];
7568 }
7569 }
7570 }
7571 Ok(efs)
7572 }
7573
7574 fn eval_cost(&mut self, theta: &Array1<f64>) -> f64 {
7575 if let Some(cost) = self.cache.memoized_cost(theta) {
7576 return cost;
7577 }
7578 if self.cache.ensure_theta(theta).is_err() {
7579 return f64::INFINITY;
7580 }
7581 let design_revision = Some(self.cache.design_revision());
7582 let registry_for_key = self.cache.analytic_penalties();
7583 self.evaluator
7584 .set_analytic_penalty_registry(registry_for_key.as_deref());
7585 let result = {
7586 let design = self.cache.design();
7587 self.evaluator.evaluate_cost_only(
7588 &design.design,
7589 &design.penalties,
7590 &design.nullspace_dims,
7591 design.linear_constraints.clone(),
7592 theta,
7593 self.rho_dim,
7594 None,
7595 "latent-coordinate-joint cost-only",
7596 design_revision,
7597 )
7598 };
7599 match result {
7600 Ok(cost) => {
7601 let latent = match self.cache.latent() {
7602 Ok(latent) => latent,
7603 Err(_) => return f64::INFINITY,
7604 };
7605 let contribution = match latent_id_objective_contribution(
7606 theta,
7607 self.rho_dim,
7608 self.cache.analytic_penalty_rho_count(),
7609 latent.as_ref(),
7610 ) {
7611 Ok(contribution) => contribution,
7612 Err(_) => return f64::INFINITY,
7613 };
7614 let cost = cost + contribution.cost;
7615 let cost = if let Some(registry) = registry_for_key {
7616 let mut registry = registry.as_ref().clone();
7617 registry.apply_weight_schedules(
7618 gam_solve::estimate::reml::outer_eval::current_outer_iter()
7619 as usize,
7620 );
7621 match analytic_penalty_objective_contribution(
7622 theta,
7623 self.rho_dim,
7624 latent.as_ref(),
7625 ®istry,
7626 ) {
7627 Ok(contribution) => cost + contribution.cost,
7628 Err(_) => return f64::INFINITY,
7629 }
7630 } else {
7631 cost
7632 };
7633 self.cache.store_cost(cost);
7634 cost
7635 }
7636 Err(_) => f64::INFINITY,
7637 }
7638 }
7639 }
7640
7641 let mut ctx = LatentJointContext {
7642 rho_dim,
7643 cache: SingleBlockLatentCoordDesignCache::new(
7644 data.to_owned(),
7645 resolvedspec.clone(),
7646 best.design.clone(),
7647 latent,
7648 rho_dim,
7649 )
7650 .map_err(EstimationError::InvalidInput)?,
7651 evaluator: gam_solve::estimate::ExternalJointHyperEvaluator::new(
7652 y,
7653 weights,
7654 &best.design.design,
7655 offset,
7656 &best.design.penalties,
7657 &external_opts_for_design(&family, &best.design, options),
7658 "latent-coordinate-joint",
7659 )?,
7660 };
7661 let registry_for_key = ctx.cache.analytic_penalties();
7662 ctx.evaluator
7663 .set_analytic_penalty_registry(registry_for_key.as_deref());
7664 ctx.evaluator
7665 .set_persistent_latent_values_fingerprint(latent.values.id_mode());
7666 if let Some(cached_t) = ctx
7667 .evaluator
7668 .load_persistent_latent_values(latent.values.n_obs(), latent.values.latent_dim())
7669 {
7670 let cached_t: Array2<f64> = cached_t;
7671 for (dst, src) in theta0
7672 .slice_mut(s![rho_dim..rho_dim + latent_flat_dim])
7673 .iter_mut()
7674 .zip(cached_t.iter())
7675 {
7676 *dst = *src;
7677 }
7678 }
7679
7680 let problem = exact_joint_multistart_outer_problem(
7681 &theta0,
7682 &lower,
7683 &upper,
7684 rho_dim,
7685 latent_coord_ext_dim,
7686 theta0.len(),
7687 Derivative::Analytic,
7688 DeclaredHessianForm::Unavailable,
7689 false,
7690 false,
7691 seed_risk_profile_for_likelihood_family(&family),
7692 options.tol,
7693 options.max_iter.max(1),
7694 Some(5.0),
7695 Some(0.5),
7696 None,
7697 Some((data.nrows(), best.design.design.ncols().max(1))),
7700 !constant_curvature_term_indices(resolvedspec).is_empty(),
7703 );
7704
7705 let eval_outer = |ctx: &mut &mut LatentJointContext<'_>,
7706 theta: &Array1<f64>,
7707 order: OuterEvalOrder|
7708 -> Result<OuterEval, EstimationError> {
7709 let (cost, gradient, hessian) = ctx.eval_full(theta, order)?;
7710 Ok(OuterEval {
7711 cost,
7712 gradient,
7713 hessian,
7714 inner_beta_hint: None,
7715 })
7716 };
7717
7718 let result = {
7719 let mut obj = problem.build_objective_with_eval_order(
7720 &mut ctx,
7721 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| Ok(ctx.eval_cost(theta)),
7722 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| {
7723 eval_outer(ctx, theta, OuterEvalOrder::ValueAndGradient)
7724 },
7725 |ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>, order: OuterEvalOrder| {
7726 eval_outer(ctx, theta, order)
7727 },
7728 Some(|ctx: &mut &mut LatentJointContext<'_>| {
7729 ctx.cache.reset();
7730 }),
7731 Some(|ctx: &mut &mut LatentJointContext<'_>, theta: &Array1<f64>| ctx.eval_efs(theta)),
7732 );
7733
7734 problem
7735 .run(&mut obj, "latent-coordinate joint REML")
7736 .map_err(|e| {
7737 EstimationError::InvalidInput(format!(
7738 "latent-coordinate joint optimization failed after exhausting strategy fallbacks: {e}"
7739 ))
7740 })?
7741 };
7742 if !result.converged {
7743 crate::bail_invalid_estim!(
7744 "latent-coordinate joint optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
7745 result.iterations,
7746 result.final_value,
7747 result.final_grad_norm_report(),
7748 );
7749 }
7750
7751 let theta_star = result.rho;
7752 let rho_star = theta_star.slice(s![..rho_dim]).mapv(f64::exp);
7753 let mut final_data = data.to_owned();
7754 let flat_t = theta_star
7755 .slice(s![rho_dim..rho_dim + latent_flat_dim])
7756 .to_owned();
7757 let mut fitted_latent_values =
7758 Array2::<f64>::zeros((latent.values.n_obs(), latent.values.latent_dim()));
7759 for n in 0..latent.values.n_obs() {
7760 for axis in 0..latent.values.latent_dim() {
7761 let value = flat_t[n * latent.values.latent_dim() + axis];
7762 fitted_latent_values[[n, axis]] = value;
7763 final_data[[n, latent.feature_cols[axis]]] = value;
7764 }
7765 }
7766 let optimized = fit_term_collection_forspecwith_heuristic_lambdas(
7767 final_data.view(),
7768 y,
7769 weights,
7770 offset,
7771 resolvedspec,
7772 rho_star.as_slice(),
7773 family,
7774 options,
7775 )?;
7776 ctx.evaluator
7777 .store_persistent_latent_values(&fitted_latent_values);
7778 let mut fit = optimized.fit;
7779 fit.reml_score = result.final_value;
7780 fit.penalized_objective = result.final_value;
7781 Ok(FittedTermCollectionWithSpec {
7782 fit,
7783 design: optimized.design,
7784 resolvedspec: resolvedspec.clone(),
7785 adaptive_diagnostics: optimized.adaptive_diagnostics,
7786 kappa_timing: None,
7787 })
7788}
7789
7790pub fn fit_term_collectionwith_latent_coord_optimization(
7791 data: ArrayView2<'_, f64>,
7792 y: Array1<f64>,
7793 weights: Array1<f64>,
7794 offset: Array1<f64>,
7795 spec: &TermCollectionSpec,
7796 latent: &StandardLatentCoordConfig,
7797 family: LikelihoodSpec,
7798 options: &FitOptions,
7799) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7800 let n = data.nrows();
7801 if !(y.len() == n && weights.len() == n && offset.len() == n) {
7802 crate::bail_invalid_estim!(
7803 "fit_term_collectionwith_latent_coord_optimization row mismatch: n={}, y={}, weights={}, offset={}",
7804 n,
7805 y.len(),
7806 weights.len(),
7807 offset.len()
7808 );
7809 }
7810 let best = fit_term_collection_forspec(
7811 data,
7812 y.view(),
7813 weights.view(),
7814 offset.view(),
7815 spec,
7816 family.clone(),
7817 options,
7818 )?;
7819 let resolvedspec = freeze_term_collection_from_design(spec, &best.design)?;
7820 try_exact_joint_latent_coord_optimization(
7821 data,
7822 y.view(),
7823 weights.view(),
7824 offset.view(),
7825 &resolvedspec,
7826 &best,
7827 family,
7828 options,
7829 latent,
7830 )
7831}
7832
7833pub fn fit_term_collectionwith_spatial_length_scale_optimization(
7834 data: ArrayView2<'_, f64>,
7835 y: Array1<f64>,
7836 weights: Array1<f64>,
7837 offset: Array1<f64>,
7838 spec: &TermCollectionSpec,
7839 family: LikelihoodSpec,
7840 options: &FitOptions,
7841 kappa_options: &SpatialLengthScaleOptimizationOptions,
7842) -> Result<FittedTermCollectionWithSpec, EstimationError> {
7843 let mut resolvedspec = spec.clone();
7859 let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
7860 let n = data.nrows();
7861 if !(y.len() == n && weights.len() == n && offset.len() == n) {
7862 crate::bail_invalid_estim!(
7863 "fit_term_collectionwith_spatial_length_scale_optimization row mismatch: n={}, y={}, weights={}, offset={}",
7864 n,
7865 y.len(),
7866 weights.len(),
7867 offset.len()
7868 );
7869 }
7870 if !kappa_options.enabled || spatial_terms.is_empty() {
7871 let out = fit_term_collection_forspec(
7872 data,
7873 y.view(),
7874 weights.view(),
7875 offset.view(),
7876 &resolvedspec,
7877 family,
7878 options,
7879 )?;
7880 let resolvedspec = freeze_term_collection_from_design(&resolvedspec, &out.design)?;
7881 return Ok(FittedTermCollectionWithSpec {
7882 fit: out.fit,
7883 design: out.design,
7884 resolvedspec,
7885 adaptive_diagnostics: out.adaptive_diagnostics,
7886 kappa_timing: None,
7887 });
7888 }
7889 if kappa_options.max_outer_iter == 0 {
7890 crate::bail_invalid_estim!("spatial kappa optimization requires max_outer_iter >= 1");
7891 }
7892 if !(kappa_options.log_step.is_finite() && kappa_options.log_step > 0.0) {
7893 crate::bail_invalid_estim!("spatial kappa optimization requires log_step > 0");
7894 }
7895 if !(kappa_options.min_length_scale.is_finite()
7896 && kappa_options.max_length_scale.is_finite()
7897 && kappa_options.min_length_scale > 0.0
7898 && kappa_options.max_length_scale >= kappa_options.min_length_scale)
7899 {
7900 crate::bail_invalid_estim!(
7901 "spatial kappa optimization requires valid positive length_scale bounds"
7902 );
7903 }
7904
7905 let pilot_threshold = kappa_options.pilot_subsample_threshold;
7906 if pilot_threshold > 0 && n > pilot_threshold * 2 {
7907 log::info!(
7908 "[spatial-kappa] n={n} exceeds pilot threshold {}; using pilot geometry only for deterministic anisotropy initialization",
7909 pilot_threshold * 2,
7910 );
7911 apply_spatial_anisotropy_pilot_initializer(
7912 data,
7913 &mut resolvedspec,
7914 &spatial_terms,
7915 pilot_threshold,
7916 kappa_options,
7917 );
7918 }
7919
7920 apply_response_aware_anisotropy_seed(data, y.view(), &mut resolvedspec, &spatial_terms);
7929
7930 for term_idx in constant_curvature_term_indices(&resolvedspec) {
7948 if let Some(kappa_seed) =
7949 select_constant_curvature_kappa_sign_seed(data, y.view(), &resolvedspec, term_idx)
7950 && kappa_seed != 0.0
7951 && let Some(SmoothBasisSpec::ConstantCurvature { spec: cc, .. }) =
7952 resolvedspec.smooth_terms.get_mut(term_idx).map(|t| &mut t.basis)
7953 {
7954 log::info!(
7955 "[#1464] pinned CC term {term_idx} baseline κ to κ-fair scan value {kappa_seed} \
7956 (raw profiled REML is sign-blind; scan is authoritative for the sign)"
7957 );
7958 cc.kappa = kappa_seed;
7959 }
7960 }
7961
7962 let baseline_options = superseded_fit_options(options);
7963 let mut best = fit_term_collection_forspec(
7964 data,
7965 y.view(),
7966 weights.view(),
7967 offset.view(),
7968 &resolvedspec,
7969 family.clone(),
7970 &baseline_options,
7971 )?;
7972 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
7973 let mut spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
7983 sync_aniso_contrasts_from_metadata(&mut resolvedspec, &best.design.smooth);
7987 let mut prescan_improved = false;
7994 if !spatial_terms.is_empty() {
7995 let baseline_score = fit_score(&best.fit);
7996 let range_overrides = prescan_isotropic_spatial_range_seed(
7997 data,
7998 y.view(),
7999 weights.view(),
8000 offset.view(),
8001 &resolvedspec,
8002 baseline_score,
8003 &family,
8004 &baseline_options,
8005 kappa_options,
8006 &spatial_terms,
8007 )?;
8008 if !range_overrides.is_empty() {
8009 prescan_improved = true;
8010 for (term_idx, length_scale) in range_overrides {
8011 set_spatial_length_scale(&mut resolvedspec, term_idx, length_scale)?;
8012 }
8013 best = fit_term_collection_forspec(
8017 data,
8018 y.view(),
8019 weights.view(),
8020 offset.view(),
8021 &resolvedspec,
8022 family.clone(),
8023 &baseline_options,
8024 )?;
8025 resolvedspec = freeze_term_collection_from_design(&resolvedspec, &best.design)?;
8026 spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
8030 }
8031 }
8032 if spatial_terms.is_empty() {
8033 let fitted = fit_term_collection_forspecwith_heuristic_lambdas(
8034 data,
8035 y.view(),
8036 weights.view(),
8037 offset.view(),
8038 &resolvedspec,
8039 best.fit.lambdas.as_slice(),
8040 family,
8041 options,
8042 )?;
8043 return Ok(FittedTermCollectionWithSpec {
8044 fit: fitted.fit,
8045 design: fitted.design,
8046 resolvedspec,
8047 adaptive_diagnostics: fitted.adaptive_diagnostics,
8048 kappa_timing: None,
8049 });
8050 }
8051 let initial_score = fit_score(&best.fit);
8052 if !initial_score.is_finite() {
8053 log::debug!("[spatial-kappa] initial profiled score is non-finite");
8054 }
8055 let joint_result = try_exact_joint_spatial_length_scale_optimization(
8056 data,
8057 y.view(),
8058 weights.view(),
8059 offset.view(),
8060 &resolvedspec,
8061 &best,
8062 family.clone(),
8063 options,
8064 kappa_options,
8065 &spatial_terms,
8066 )
8067 .map(|opt| {
8068 opt.map(|fit| {
8069 let score = fit_score(&fit.fit);
8070 (fit, score)
8071 })
8072 });
8073 let exact_joint = if prescan_improved && !matches!(joint_result, Ok(Some(_))) {
8083 let reason = match &joint_result {
8084 Err(e) => format!("error: {e}"),
8085 _ => "unavailable".to_string(),
8086 };
8087 log::info!(
8088 "[spatial-kappa] #1074 joint polish yielded no usable candidate \
8089 ({reason}); returning the multi-start pre-scan geometry (REML {initial_score:.5})"
8090 );
8091 FittedTermCollectionWithSpec {
8092 fit: best.fit,
8093 design: best.design,
8094 resolvedspec,
8095 adaptive_diagnostics: best.adaptive_diagnostics,
8096 kappa_timing: None,
8097 }
8098 } else {
8099 require_successful_spatial_optimization_result(initial_score, joint_result)?
8100 };
8101 log_spatial_aniso_scales(&exact_joint.resolvedspec);
8102 Ok(exact_joint)
8103}
8104
/// Profile-based inference for the curvature κ of one constant-curvature smooth term,
/// as returned by `curvature_inference_forspec`.
#[derive(Clone, Debug)]
pub struct CurvatureInference {
    /// Index of the smooth term in the resolved spec's `smooth_terms`.
    pub term_idx: usize,
    /// κ read from the resolved spec; the point the interval and test are built around.
    pub kappa_hat: f64,
    /// Profile confidence interval for κ produced by `profile_ci_walk`.
    pub ci: gam_geometry::curvature_estimand::KappaProfileCi,
    /// Result of `flatness_lr_test` at `kappa_hat`. It presumably tests the flat (κ = 0)
    /// hypothesis; confirm against `gam_geometry`.
    pub flatness: gam_geometry::curvature_estimand::FlatnessTest,
}
8124
8125pub fn curvature_inference_forspec(
8143 data: ArrayView2<'_, f64>,
8144 y: ArrayView1<'_, f64>,
8145 weights: ArrayView1<'_, f64>,
8146 offset: ArrayView1<'_, f64>,
8147 resolvedspec: &TermCollectionSpec,
8148 term_idx: usize,
8149 family: LikelihoodSpec,
8150 options: &FitOptions,
8151 level: f64,
8152) -> Result<CurvatureInference, EstimationError> {
8153 let kappa_hat = get_constant_curvature_kappa(resolvedspec, term_idx).ok_or_else(|| {
8154 EstimationError::InvalidInput(format!(
8155 "curvature_inference_forspec: term {term_idx} is not a constant-curvature smooth"
8156 ))
8157 })?;
8158 let (kappa_min, kappa_max) = constant_curvature_kappa_bounds(data, resolvedspec, term_idx);
8159
8160 let cc_fair_inputs: Option<(Array2<f64>, gam_terms::basis::ConstantCurvatureBasisSpec)> =
8185 if kappa_hat < 0.0 {
8186 match resolvedspec.smooth_terms.get(term_idx).map(|t| &t.basis) {
8187 Some(SmoothBasisSpec::ConstantCurvature {
8188 feature_cols, spec, ..
8189 }) => select_columns(data, feature_cols)
8190 .ok()
8191 .map(|x| (x, spec.clone())),
8192 _ => None,
8193 }
8194 } else {
8195 None
8196 };
8197
8198 let v_p_cache: std::cell::RefCell<std::collections::HashMap<u64, f64>> =
8203 std::cell::RefCell::new(std::collections::HashMap::new());
8204 let v_p = |kappa: f64| -> Result<f64, String> {
8205 if !kappa.is_finite() {
8206 return Err(format!("V_p probed a non-finite κ = {kappa}"));
8207 }
8208 let key = kappa.to_bits();
8209 if let Some(&cached) = v_p_cache.borrow().get(&key) {
8210 return Ok(cached);
8211 }
8212 let score = if let Some((x_term, base_spec)) = &cc_fair_inputs {
8213 let mut probe_spec = base_spec.clone();
8214 probe_spec.kappa = kappa;
8215 gam_terms::basis::constant_curvature_kappa_fair_sign_score(x_term.view(), y, &probe_spec)
8216 .map_err(|e| format!("κ-fair criterion at κ={kappa} failed: {e}"))?
8217 } else {
8218 fixed_kappa_profiled_reml_score(
8219 data,
8220 y,
8221 weights,
8222 offset,
8223 resolvedspec,
8224 term_idx,
8225 kappa,
8226 family.clone(),
8227 options,
8228 )
8229 .map_err(|e| format!("V_p fixed-κ fit at κ={kappa} failed: {e}"))?
8230 };
8231 v_p_cache.borrow_mut().insert(key, score);
8232 Ok(score)
8233 };
8234
8235 let h = (1e-3 * (kappa_max - kappa_min)).max(1e-4);
8239 let v_pp = match (v_p(kappa_hat + h), v_p(kappa_hat), v_p(kappa_hat - h)) {
8240 (Ok(vp), Ok(v0), Ok(vm)) => (vp - 2.0 * v0 + vm) / (h * h),
8241 _ => f64::NAN, };
8243
8244 let ci = gam_geometry::curvature_estimand::profile_ci_walk(
8245 &v_p, kappa_hat, v_pp, kappa_min, kappa_max, level, 1e-4,
8246 )
8247 .map_err(EstimationError::InvalidInput)?;
8248 let flatness = gam_geometry::curvature_estimand::flatness_lr_test(&v_p, kappa_hat)
8249 .map_err(EstimationError::InvalidInput)?;
8250
8251 Ok(CurvatureInference {
8252 term_idx,
8253 kappa_hat,
8254 ci,
8255 flatness,
8256 })
8257}
8258
/// Which small-sample correction was applied to a smooth-term likelihood-ratio test.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmoothLrCorrection {
    /// Lawley/Bartlett factor that also accounts for variation in the estimated
    /// log-smoothing parameters, using the fit's ρ covariance.
    LawleyLrEstimatedLambda,
    /// Lawley/Bartlett factor computed with the smoothing parameters held fixed.
    LawleyLrFixedLambda,
    /// No correction applied; the corrected statistic equals the raw one.
    None,
}

impl SmoothLrCorrection {
    /// Stable snake_case identifier for this correction.
    pub fn label(self) -> &'static str {
        match self {
            Self::LawleyLrEstimatedLambda => "lawley_lr_estimated_lambda",
            Self::LawleyLrFixedLambda => "lawley_lr_fixed_lambda",
            Self::None => "none",
        }
    }
}
8288
/// Likelihood-ratio test result for one smooth term, as produced by
/// `smooth_term_lr_inference_forspec`.
#[derive(Clone, Debug)]
pub struct SmoothTermLrInference {
    /// Name of the tested smooth term.
    pub name: String,
    /// Index of the term among the fitted design's smooth terms.
    pub term_idx: usize,
    /// `2 (ℓ_full − ℓ_null)` clamped at zero. NaN when the null fit failed or had a
    /// non-finite log-likelihood.
    pub statistic_lr: f64,
    /// Reference χ² degrees of freedom. Uses the influence-based value from
    /// `wood_reference_df` when available, else the term EDF floored at `1e-12`.
    pub ref_df: f64,
    /// Bartlett factor actually applied (`1.0` when no correction was applied).
    pub bartlett_factor: f64,
    /// Fixed-λ factor. Set only when the estimated-λ correction replaced it.
    pub bartlett_factor_conditional: Option<f64>,
    /// Part of the mean shift attributable to smoothing-parameter variation: the total
    /// shift minus the fixed-λ shift. Set only with the estimated-λ correction.
    pub rho_variation_shift: Option<f64>,
    /// `statistic_lr / bartlett_factor`.
    pub statistic_corrected: f64,
    /// Upper-tail χ²(`ref_df`) probability of `statistic_lr` (NaN when not computable).
    pub p_value_uncorrected: f64,
    /// Upper-tail χ²(`ref_df`) probability of `statistic_corrected`.
    pub p_value_corrected: f64,
    /// True when a correction was applied and it moved the factor away from 1, or
    /// shifted the p-value relatively, by more than `SMOOTH_LR_MATERIAL_THRESHOLD`.
    pub material: bool,
    /// Which correction, if any, was applied.
    pub correction: SmoothLrCorrection,
}
8334
/// Materiality threshold for smooth-term LR corrections. A correction is flagged as
/// material when `|bartlett_factor − 1|` or the relative p-value change exceeds this value.
pub const SMOOTH_LR_MATERIAL_THRESHOLD: f64 = 0.10;
8339
8340fn fitted_rho_penalty_components(
8346 penalties: &[BlockwisePenalty],
8347 lambdas: &[f64],
8348 p_total: usize,
8349) -> Result<Vec<gam_terms::inference::lawley::RhoPenaltyComponent>, EstimationError> {
8350 if penalties.len() != lambdas.len() {
8351 return Err(EstimationError::InvalidInput(format!(
8352 "smooth_term_lr_inference: penalty/lambda count mismatch ({} penalties, {} lambdas)",
8353 penalties.len(),
8354 lambdas.len()
8355 )));
8356 }
8357 let mut components = Vec::with_capacity(penalties.len());
8358 for (idx, (penalty, &lambda)) in penalties.iter().zip(lambdas.iter()).enumerate() {
8359 if !(lambda.is_finite() && lambda >= 0.0) {
8360 return Err(EstimationError::InvalidInput(format!(
8361 "smooth_term_lr_inference: lambda[{idx}] is invalid: {lambda}"
8362 )));
8363 }
8364 let r = &penalty.col_range;
8365 if r.end > p_total {
8366 return Err(EstimationError::InvalidInput(format!(
8367 "smooth_term_lr_inference: penalty[{idx}] range {:?} exceeds coefficient dimension {p_total}",
8368 r
8369 )));
8370 }
8371 let mut s_component = Array2::<f64>::zeros((p_total, p_total));
8372 s_component
8373 .slice_mut(s![r.start..r.end, r.start..r.end])
8374 .scaled_add(lambda, &penalty.local);
8375 components.push(gam_terms::inference::lawley::RhoPenaltyComponent { s_component });
8376 }
8377 Ok(components)
8378}
8379
/// Likelihood-ratio tests for each unconstrained smooth term, with an optional
/// Lawley/Bartlett small-sample correction.
///
/// Fits the full model once. Then, for every smooth term without a shape constraint:
/// - Refit with that term removed from `resolvedspec` (matched by name).
/// - Form `W = 2 (ℓ_full − ℓ_null)`, clamped at zero, and compare it with χ²(`ref_df`).
/// - If `n <= LAWLEY_PAIR_MATRIX_MAX_ROWS` and the expected jets are available at the null
///   linear predictor, apply the fixed-λ Lawley factor.
/// - If the full fit also carries a `rho_covariance` whose size matches the number of
///   penalty components, an estimated-λ factor may replace the fixed-λ one when valid. It
///   is derived from the mean shift that includes ρ variation.
///
/// A term is skipped (no output entry) when any of these hold:
/// - it is shape-constrained;
/// - its coefficient range is empty or out of bounds;
/// - `ref_df` is non-positive or non-finite;
/// - no spec term shares its name.
///
/// A failed null fit does not skip the term. It yields a NaN statistic and NaN p-values,
/// and no correction.
///
/// # Errors
/// Propagates errors from the full fit and from `fitted_rho_penalty_components`. Returns
/// `EstimationError::InvalidInput` when the fitted lambda vector is not contiguous.
pub fn smooth_term_lr_inference_forspec(
    data: ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    resolvedspec: &TermCollectionSpec,
    family: LikelihoodSpec,
    options: &FitOptions,
) -> Result<Vec<SmoothTermLrInference>, EstimationError> {
    use gam_terms::inference::lawley::{
        LAWLEY_PAIR_MATRIX_MAX_ROWS, known_scale_expected_jets_with_dispersion,
        lawley_lr_bartlett_factor, lawley_lr_mean_shift_with_rho_variation,
    };

    let n = data.nrows();
    // Full model: its log-likelihood, design and fitted λ are shared by every term's test.
    let full = fit_term_collection_forspec(
        data,
        y,
        weights,
        offset,
        resolvedspec,
        family.clone(),
        options,
    )?;
    let ll_full = full.fit.log_likelihood;
    let p_total = full.design.design.ncols();
    let lambdas = full.fit.lambdas.as_slice().ok_or_else(|| {
        EstimationError::InvalidInput(
            "smooth_term_lr_inference: non-contiguous lambda vector".to_string(),
        )
    })?;
    let s_lambda = weighted_blockwise_penalty_sum(&full.design.penalties, lambdas, p_total);
    let rho_penalty_components =
        fitted_rho_penalty_components(&full.design.penalties, lambdas, p_total)?;
    // Use the ρ covariance only when it is square with one row per penalty component.
    let rho_covariance = full.fit.artifacts.rho_covariance.as_ref().filter(|cov| {
        cov.nrows() == rho_penalty_components.len() && cov.ncols() == rho_penalty_components.len()
    });
    let full_design_dense = full.design.design.to_dense();
    let influence = full.fit.coefficient_influence();
    let family_disp = lawley_dispersion_for_family(&family, &full.fit);

    // Index of the next smooth term's first penalty. NOTE(review): this assumes the first
    // `random_effect_ranges.len()` penalties are one-per-random-effect; confirm the ordering.
    let mut penalty_cursor = full.design.random_effect_ranges.len();
    let mut out = Vec::<SmoothTermLrInference>::new();
    for (term_idx, design_term) in full.design.smooth.terms.iter().enumerate() {
        let k = design_term.penalties_local.len();
        let block_start = penalty_cursor;
        // Advance before any `continue` so later terms stay aligned with their penalties.
        penalty_cursor += k;
        // Shape-constrained terms are not tested.
        if design_term.shape != ShapeConstraint::None {
            continue;
        }
        let coeff_range = design_term.coeff_range.clone();
        if coeff_range.start >= coeff_range.end || coeff_range.end > p_total {
            continue;
        }
        let edf = full.fit.per_term_edf(coeff_range.clone(), block_start, k);
        // Prefer the influence-based reference df; fall back to the term EDF.
        let ref_df = wood_reference_df(influence, &coeff_range).unwrap_or(edf.max(1e-12));
        if !(ref_df.is_finite() && ref_df > 0.0) {
            continue;
        }

        // Null model: the same spec with this term (matched by name) removed.
        let mut null_spec = resolvedspec.clone();
        let Some(spec_pos) = null_spec
            .smooth_terms
            .iter()
            .position(|t| t.name == design_term.name)
        else {
            continue;
        };
        null_spec.smooth_terms.remove(spec_pos);
        let null_fit = fit_term_collection_forspec(
            data,
            y,
            weights,
            offset,
            &null_spec,
            family.clone(),
            options,
        );
        // A failed or non-finite null fit leaves the statistic NaN and no null predictor,
        // which also disables the correction below.
        let (statistic_lr, eta_null) = match null_fit {
            Ok(null) if null.fit.log_likelihood.is_finite() => {
                let w = (2.0 * (ll_full - null.fit.log_likelihood)).max(0.0);
                // Null-model linear predictor including the offset.
                let mut eta = null.design.design.dot(&null.fit.beta);
                eta += &offset;
                (w, Some(eta))
            }
            _ => (f64::NAN, None),
        };

        let chi2 = statrs::distribution::ChiSquared::new(ref_df).ok();
        // Upper-tail χ² probability of the raw statistic.
        let p_uncorrected = match (chi2.as_ref(), statistic_lr.is_finite()) {
            (Some(dist), true) => {
                use statrs::distribution::ContinuousCDF;
                (1.0 - dist.cdf(statistic_lr)).clamp(0.0, 1.0)
            }
            _ => f64::NAN,
        };

        // Defaults describe "no correction"; overwritten only when a Lawley factor is valid.
        let mut bartlett_factor = 1.0;
        let mut bartlett_factor_conditional = None;
        let mut rho_variation_shift = None;
        let mut statistic_corrected = statistic_lr;
        let mut p_corrected = p_uncorrected;
        let mut correction = SmoothLrCorrection::None;
        if let (Some(eta), true, true) = (
            eta_null.as_ref(),
            statistic_lr.is_finite(),
            n <= LAWLEY_PAIR_MATRIX_MAX_ROWS,
        ) {
            // Per-observation cumulants at the null predictor; any missing one disables the
            // correction (collecting into `Option<Vec<_>>` yields `None`).
            let kappas: Option<Vec<_>> = (0..n)
                .map(|i| {
                    known_scale_expected_jets_with_dispersion(&family, eta[i], family_disp)
                        .and_then(|jets| jets.kappas().ok())
                })
                .collect();
            if let (Some(kappas), Some(dist)) = (kappas, chi2.as_ref()) {
                let fixed_factor = lawley_lr_bartlett_factor(
                    full_design_dense.view(),
                    &kappas,
                    Some(s_lambda.view()),
                    coeff_range.clone(),
                    ref_df,
                );
                if let Ok(c_cond) = fixed_factor
                    && c_cond.is_finite()
                    && c_cond > 0.0
                {
                    let mut c_applied = c_cond;
                    correction = SmoothLrCorrection::LawleyLrFixedLambda;
                    if let Some(cov) = rho_covariance
                        && let Ok(total_shift) = lawley_lr_mean_shift_with_rho_variation(
                            full_design_dense.view(),
                            &kappas,
                            s_lambda.view(),
                            coeff_range.clone(),
                            &rho_penalty_components,
                            cov.view(),
                        )
                    {
                        let mean_w = ref_df + total_shift;
                        if let Some(c_est) =
                            gam_terms::inference::higher_order::bartlett_factor_from_mean(
                                mean_w, ref_df,
                            )
                            && c_est.is_finite()
                            && c_est > 0.0
                        {
                            // Mean shift implied by the fixed-λ factor; subtracting it from
                            // the total isolates the ρ-variation contribution.
                            let conditional_shift = (c_cond - 1.0) * ref_df;
                            c_applied = c_est;
                            bartlett_factor_conditional = Some(c_cond);
                            rho_variation_shift = Some(total_shift - conditional_shift);
                            correction = SmoothLrCorrection::LawleyLrEstimatedLambda;
                        }
                    }
                    use statrs::distribution::ContinuousCDF;
                    bartlett_factor = c_applied;
                    statistic_corrected = statistic_lr / c_applied;
                    p_corrected = (1.0 - dist.cdf(statistic_corrected)).clamp(0.0, 1.0);
                }
            }
        }

        // A correction is "material" when it moves the factor or the relative p-value by
        // more than the threshold; uncorrected results are never material.
        let material = match correction {
            SmoothLrCorrection::LawleyLrEstimatedLambda
            | SmoothLrCorrection::LawleyLrFixedLambda => {
                let factor_move = (bartlett_factor - 1.0).abs();
                let p_denom = p_uncorrected.max(p_corrected).max(f64::MIN_POSITIVE);
                let p_move = if p_uncorrected.is_finite() && p_corrected.is_finite() {
                    (p_corrected - p_uncorrected).abs() / p_denom
                } else {
                    0.0
                };
                factor_move > SMOOTH_LR_MATERIAL_THRESHOLD || p_move > SMOOTH_LR_MATERIAL_THRESHOLD
            }
            SmoothLrCorrection::None => false,
        };

        out.push(SmoothTermLrInference {
            name: design_term.name.clone(),
            term_idx,
            statistic_lr,
            ref_df,
            bartlett_factor,
            bartlett_factor_conditional,
            rho_variation_shift,
            statistic_corrected,
            p_value_uncorrected: p_uncorrected,
            p_value_corrected: p_corrected,
            material,
            correction,
        });
    }
    Ok(out)
}
8644
8645fn lawley_dispersion_for_family(family: &LikelihoodSpec, fit: &UnifiedFitResult) -> f64 {
8648 match family.response {
8649 gam_spec::ResponseFamily::Gaussian => {
8650 let sd = fit.standard_deviation;
8651 (sd * sd).max(f64::MIN_POSITIVE)
8652 }
8653 gam_spec::ResponseFamily::Gamma => {
8654 let shape = fit.standard_deviation;
8655 if shape.is_finite() && shape > 0.0 {
8656 1.0 / shape
8657 } else {
8658 1.0
8659 }
8660 }
8661 _ => 1.0,
8662 }
8663}
8664
8665fn wood_reference_df(influence: Option<&Array2<f64>>, coeff_range: &Range<usize>) -> Option<f64> {
8671 let f = influence?;
8672 let (start, end) = (coeff_range.start, coeff_range.end);
8673 if start >= end || end > f.nrows() || end > f.ncols() {
8674 return None;
8675 }
8676 let block = f.slice(s![start..end, start..end]);
8677 let tr = (0..block.nrows()).map(|i| block[[i, i]]).sum::<f64>();
8678 let tr2 = block.dot(&block).diag().sum();
8679 (tr.is_finite() && tr2.is_finite() && tr > 0.0 && tr2 > 0.0).then(|| (tr * tr / tr2).max(1e-12))
8680}