1pub fn build_term_collection_designs_joint(
9 data: ArrayView2<'_, f64>,
10 specs: &[TermCollectionSpec],
11) -> Result<Vec<TermCollectionDesign>, BasisError> {
12 for spec in specs {
13 validate_term_collection_finite_inputs(data, spec)?;
14 }
15 let smooth_blocks = specs
16 .iter()
17 .map(|spec| spec.smooth_terms.clone())
18 .collect::<Vec<_>>();
19 let planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &smooth_blocks)?;
20 let mut out = Vec::with_capacity(specs.len());
21 for (spec, planned_terms) in specs.iter().zip(planned_blocks.into_iter()) {
22 let mut planned_spec = spec.clone();
23 planned_spec.smooth_terms = planned_terms;
24 out.push(build_term_collection_design_inner(data, &planned_spec)?);
25 }
26 Ok(out)
27}
28
29pub fn build_term_collection_designs_and_freeze_joint(
30 data: ArrayView2<'_, f64>,
31 specs: &[TermCollectionSpec],
32) -> Result<(Vec<TermCollectionDesign>, Vec<TermCollectionSpec>), EstimationError> {
33 let designs = build_term_collection_designs_joint(data, specs)?;
34 let mut resolved_specs = Vec::with_capacity(specs.len());
35 for (spec, design) in specs.iter().zip(designs.iter()) {
36 resolved_specs.push(freeze_term_collection_from_design(spec, design)?);
37 }
38 Ok((designs, resolved_specs))
39}
40
41pub fn fit_term_collection_forspec(
42 data: ArrayView2<'_, f64>,
43 y: ArrayView1<'_, f64>,
44 weights: ArrayView1<'_, f64>,
45 offset: ArrayView1<'_, f64>,
46 spec: &TermCollectionSpec,
47 family: LikelihoodSpec,
48 options: &FitOptions,
49) -> Result<FittedTermCollection, EstimationError> {
50 fit_term_collection_forspecwith_heuristic_lambdas(
51 data, y, weights, offset, spec, None, family, options,
52 )
53}
54
55pub fn fit_term_collection_with_coefficient_groups(
56 data: ArrayView2<'_, f64>,
57 y: ArrayView1<'_, f64>,
58 weights: ArrayView1<'_, f64>,
59 offset: ArrayView1<'_, f64>,
60 spec: &TermCollectionSpec,
61 groups: &[CoefficientGroupSpec],
62 family: LikelihoodSpec,
63 options: &FitOptions,
64) -> Result<FittedTermCollection, EstimationError> {
65 if groups.is_empty() {
66 return fit_term_collection_forspec(data, y, weights, offset, spec, family, options);
67 }
68 let design = build_term_collection_design(data, spec)?;
69 let base_fit_opts = adaptive_fit_options_base(options, &design);
70 let realized = design
71 .realize_coefficient_groups(groups, &base_fit_opts.rho_prior)
72 .map_err(EstimationError::BasisError)?;
73 let mut grouped_options = base_fit_opts.clone();
74 grouped_options.rho_prior = realized.rho_prior;
75 let fitted = FittedTermCollection {
76 fit: gam_solve::estimate::fit_gam_with_penalty_specs(
77 design.design.clone(),
78 y,
79 weights,
80 offset,
81 realized.penalty_specs,
82 realized.nullspace_dims,
83 family.clone(),
84 &grouped_options,
85 )?,
86 design,
87 adaptive_diagnostics: None,
88 };
89 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
90 Ok(fitted)
91}
92
93pub fn fit_term_collection_with_penalty_block_gamma_prior_callback<F>(
94 data: ArrayView2<'_, f64>,
95 y: ArrayView1<'_, f64>,
96 weights: ArrayView1<'_, f64>,
97 offset: ArrayView1<'_, f64>,
98 spec: &TermCollectionSpec,
99 callback: F,
100 family: LikelihoodSpec,
101 options: &FitOptions,
102) -> Result<FittedTermCollection, EstimationError>
103where
104 F: FnMut(&PenaltyBlockGammaPriorMetadata<'_>) -> Option<(f64, f64)>,
105{
106 let design = build_term_collection_design(data, spec)?;
107 let mut fit_opts = adaptive_fit_options_base(options, &design);
108 fit_opts.rho_prior = realize_penalty_block_gamma_priors(&design, callback)
109 .map_err(EstimationError::BasisError)?;
110 let fitted = FittedTermCollection {
111 fit: fit_gamwith_heuristic_lambdas(
112 design.design.clone(),
113 y,
114 weights,
115 offset,
116 &design.penalties,
117 None,
118 family.clone(),
119 &fit_opts,
120 )?,
121 design,
122 adaptive_diagnostics: None,
123 };
124 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
125 Ok(fitted)
126}
127
128pub fn fit_term_collection_with_penalty_block_gamma_priors(
129 data: ArrayView2<'_, f64>,
130 y: ArrayView1<'_, f64>,
131 weights: ArrayView1<'_, f64>,
132 offset: ArrayView1<'_, f64>,
133 spec: &TermCollectionSpec,
134 priors: &[(String, f64, f64)],
135 family: LikelihoodSpec,
136 options: &FitOptions,
137) -> Result<FittedTermCollection, EstimationError> {
138 let design = build_term_collection_design(data, spec)?;
139 let mut fit_opts = adaptive_fit_options_base(options, &design);
140 fit_opts.rho_prior = realize_keyed_penalty_block_gamma_priors(&design, priors)
141 .map_err(EstimationError::BasisError)?;
142 let fitted = FittedTermCollection {
143 fit: fit_gamwith_heuristic_lambdas(
144 design.design.clone(),
145 y,
146 weights,
147 offset,
148 &design.penalties,
149 None,
150 family.clone(),
151 &fit_opts,
152 )?,
153 design,
154 adaptive_diagnostics: None,
155 };
156 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
157 Ok(fitted)
158}
159
160fn expand_double_penalty_linear_penalty_blocks(
180 design: &TermCollectionDesign,
181 spec: &TermCollectionSpec,
182) -> TermCollectionDesign {
183 let Some(shared_idx) = design.penaltyinfo.iter().position(|info| {
184 info.termname.as_deref() == Some("linear")
185 && matches!(&info.penalty.source, PenaltySource::Other(s) if s == "LinearTermRidge")
186 }) else {
187 return design.clone();
188 };
189
190 let mut new_penalties = Vec::<BlockwisePenalty>::new();
191 let mut new_nullspace = Vec::<usize>::new();
192 let mut new_info = Vec::<PenaltyBlockInfo>::new();
193 for (j, linear) in spec.linear_terms.iter().enumerate() {
194 if !linear.double_penalty {
195 continue;
196 }
197 let Some((_, range)) = design
198 .linear_ranges
199 .iter()
200 .find(|(name, _)| name == &linear.name)
201 else {
202 continue;
203 };
204 new_penalties.push(BlockwisePenalty::ridge(range.clone(), 1.0));
216 new_nullspace.push(0);
217 new_info.push(PenaltyBlockInfo {
218 global_index: 0,
219 termname: Some(linear.name.clone()),
220 penalty: PenaltyInfo {
221 source: PenaltySource::Other("LinearTermRidge".to_string()),
222 original_index: j,
223 active: true,
224 effective_rank: 1,
225 dropped_reason: None,
226 nullspace_dim_hint: 0,
227 normalization_scale: 1.0,
228 kronecker_factors: None,
229 },
230 });
231 }
232
233 if new_penalties.is_empty() {
234 return design.clone();
235 }
236
237 let mut expanded = design.clone();
238 expanded
239 .penalties
240 .splice(shared_idx..=shared_idx, new_penalties);
241 expanded
242 .nullspace_dims
243 .splice(shared_idx..=shared_idx, new_nullspace);
244 expanded
245 .penaltyinfo
246 .splice(shared_idx..=shared_idx, new_info);
247 for (idx, info) in expanded.penaltyinfo.iter_mut().enumerate() {
249 info.global_index = idx;
250 }
251 expanded
252}
253
254pub fn fit_term_collection_with_coefficient_groups_and_penalty_block_gamma_priors(
255 data: ArrayView2<'_, f64>,
256 y: ArrayView1<'_, f64>,
257 weights: ArrayView1<'_, f64>,
258 offset: ArrayView1<'_, f64>,
259 spec: &TermCollectionSpec,
260 groups: &[CoefficientGroupSpec],
261 priors: &[(String, f64, f64)],
262 family: LikelihoodSpec,
263 options: &FitOptions,
264) -> Result<FittedTermCollection, EstimationError> {
265 if groups.is_empty() {
266 return fit_term_collection_with_penalty_block_gamma_priors(
267 data, y, weights, offset, spec, priors, family, options,
268 );
269 }
270 if priors.is_empty() {
271 return fit_term_collection_with_coefficient_groups(
272 data, y, weights, offset, spec, groups, family, options,
273 );
274 }
275
276 let design = build_term_collection_design(data, spec)?;
277 let design = expand_double_penalty_linear_penalty_blocks(&design, spec);
283 let base_fit_opts = adaptive_fit_options_base(options, &design);
284 let base_rho_prior = realize_keyed_penalty_block_gamma_priors(&design, priors)
285 .map_err(EstimationError::BasisError)?;
286 let realized = design
287 .realize_coefficient_groups(groups, &base_rho_prior)
288 .map_err(EstimationError::BasisError)?;
289 let mut grouped_options = base_fit_opts.clone();
290 grouped_options.rho_prior = realized.rho_prior;
291 let fitted = FittedTermCollection {
292 fit: gam_solve::estimate::fit_gam_with_penalty_specs(
293 design.design.clone(),
294 y,
295 weights,
296 offset,
297 realized.penalty_specs,
298 realized.nullspace_dims,
299 family.clone(),
300 &grouped_options,
301 )?,
302 design,
303 adaptive_diagnostics: None,
304 };
305 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
306 Ok(fitted)
307}
308
309fn fit_term_collection_forspecwith_heuristic_lambdas(
310 data: ArrayView2<'_, f64>,
311 y: ArrayView1<'_, f64>,
312 weights: ArrayView1<'_, f64>,
313 offset: ArrayView1<'_, f64>,
314 spec: &TermCollectionSpec,
315 heuristic_lambdas: Option<&[f64]>,
316 family: LikelihoodSpec,
317 options: &FitOptions,
318) -> Result<FittedTermCollection, EstimationError> {
319 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
320 let resolved_spec;
321 let design_spec = if adaptive_opts.enabled {
322 resolved_spec = ensure_matern_adaptive_center_resolution(spec, data.nrows());
323 &resolved_spec
324 } else {
325 spec
326 };
327 let base_design = build_term_collection_design(data, design_spec)?;
328 fit_term_collection_on_realized_design(
329 y,
330 weights,
331 offset,
332 design_spec,
333 &base_design,
334 heuristic_lambdas,
335 family,
336 options,
337 )
338}
339
340fn ensure_matern_adaptive_center_resolution(
341 spec: &TermCollectionSpec,
342 n_rows: usize,
343) -> TermCollectionSpec {
344 let mut out = spec.clone();
345 for term in &mut out.smooth_terms {
346 let gam_terms::smooth::SmoothBasisSpec::Matern {
347 feature_cols,
348 spec: matern,
349 ..
350 } = &mut term.basis
351 else {
352 continue;
353 };
354 if let gam_terms::basis::CenterStrategy::FarthestPoint { num_centers } =
355 &mut matern.center_strategy
356 {
357 let min_centers = (4 * feature_cols.len()).min(n_rows).max(*num_centers);
370 *num_centers = min_centers;
371 }
372 }
373 out
374}
375
376fn has_bounded_linear_terms(spec: &TermCollectionSpec) -> bool {
377 spec.linear_terms.iter().any(|term| {
378 matches!(
379 term.coefficient_geometry,
380 LinearCoefficientGeometry::Bounded { .. }
381 )
382 })
383}
384
385fn fit_term_collection_on_realized_design(
386 y: ArrayView1<'_, f64>,
387 weights: ArrayView1<'_, f64>,
388 offset: ArrayView1<'_, f64>,
389 spec: &TermCollectionSpec,
390 design: &TermCollectionDesign,
391 heuristic_lambdas: Option<&[f64]>,
392 family: LikelihoodSpec,
393 options: &FitOptions,
394) -> Result<FittedTermCollection, EstimationError> {
395 if has_bounded_linear_terms(spec) {
396 return fit_bounded_term_collection_with_design(
397 y,
398 weights,
399 offset,
400 spec,
401 design,
402 heuristic_lambdas,
403 family,
404 options,
405 );
406 }
407 let mut base_fit_opts = adaptive_fit_options_base(options, design);
408 base_fit_opts.rho_prior = relax_smoothing_rho_prior(options, design);
415 let fitted = FittedTermCollection {
416 fit: fit_gamwith_heuristic_lambdas(
417 design.design.clone(),
418 y,
419 weights,
420 offset,
421 &design.penalties,
422 heuristic_lambdas,
423 family.clone(),
424 &base_fit_opts,
425 )?,
426 design: design.clone(),
427 adaptive_diagnostics: None,
428 };
429 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
430
431 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
432 if !adaptive_opts.enabled {
433 return Ok(fitted);
434 }
435 let runtime_caches = extract_spatial_operator_runtime_caches(spec, &fitted.design)?;
436 if runtime_caches.is_empty() {
437 return Ok(fitted);
438 }
439 fit_term_collectionwith_exact_spatial_adaptive_regularization(
446 fitted,
447 y,
448 weights,
449 offset,
450 family,
451 options,
452 &runtime_caches,
453 )
454}
455
456#[derive(Clone)]
457struct SpatialOperatorRuntimeCache {
458 termname: String,
459 feature_cols: Vec<usize>,
460 coeff_global_range: Range<usize>,
461 mass_penalty_global_idx: usize,
462 tension_penalty_global_idx: usize,
463 stiffness_penalty_global_idx: usize,
464 d0: Array2<f64>,
465 d1: Array2<f64>,
466 d2: Array2<f64>,
467 collocation_points: Array2<f64>,
468 dimension: usize,
469}
470
/// Inverse adaptive weights for the three spatial penalty components of one term.
///
/// NOTE(review): not constructed in this chunk. Presumably these hold the inverse weights
/// returned by the Charbonnier `surrogateweights_posterior_snr` helpers, one entry per
/// collocation point — confirm at the construction site.
#[derive(Clone)]
struct SpatialAdaptiveWeights {
    // Presumably inverse weights for the magnitude (value) penalty.
    inv_magweight: Array1<f64>,
    // Presumably inverse weights for the gradient penalty.
    invgradweight: Array1<f64>,
    // Presumably inverse weights for the Laplacian / curvature penalty.
    inv_lapweight: Array1<f64>,
}
477
/// Per-entry Charbonnier penalty state for a scalar signal `t`.
///
/// Each entry contributes `psi(t) = sqrt(t^2 + eps^2) - eps`, and `radius` caches
/// `r = sqrt(t^2 + eps^2)`. The methods return per-entry derivatives of `psi` with respect to
/// `t` and to `log(eps)`.
///
/// When built by `SpatialPenaltyExactState::from_beta_local`, `t = d0 · beta_local`, so these
/// are derivatives along the signal that callers map back to coefficients. NOTE(review): that
/// mapping (e.g. `scalar_operatorgradient`) happens outside this type — confirm at call sites.
#[derive(Clone)]
struct CharbonnierScalarBlockState {
    // Signal values `t`, one per entry.
    signal: Array1<f64>,
    // Cached `sqrt(t^2 + eps^2)` for each entry of `signal`.
    radius: Array1<f64>,
    // Smoothing parameter `eps`; floored at 1e-12 when built via `from_signal`.
    epsilon: f64,
}

impl CharbonnierScalarBlockState {
    /// Builds the state from `signal`, flooring `epsilon` at `1e-12`.
    ///
    /// A NaN `epsilon` also becomes `1e-12`, because `f64::max` returns the non-NaN operand.
    fn from_signal(signal: Array1<f64>, epsilon: f64) -> Self {
        let eps = epsilon.max(1e-12);
        let radius = signal.mapv(|t| (t * t + eps * eps).sqrt());
        Self {
            signal,
            radius,
            epsilon: eps,
        }
    }

    /// Returns `|t|` for each entry.
    fn absolute_signal(&self) -> Array1<f64> {
        self.signal.mapv(f64::abs)
    }

    /// Total penalty `sum(r - eps)`, i.e. `sum(psi(t))`.
    fn penalty_value(&self) -> f64 {
        self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
    }

    /// First derivative `dpsi/dt = t / r` per entry.
    fn betagradient_coeff(&self) -> Array1<f64> {
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| t / r),
        )
    }

    /// Second derivative `d2psi/dt2 = eps^2 / r^3` per entry.
    fn betahessian_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        self.radius.mapv(|r| eps2 / r.powi(3))
    }

    /// `dpsi/dlog(eps) = eps^2 / r - eps` per entry.
    fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        self.radius.mapv(|r| eps2 / r - epsilon)
    }

    /// Mixed derivative `d/dlog(eps) (dpsi/dt) = -eps^2 t / r^3` per entry.
    fn log_epsilon_betagradient_coeff(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| -eps2 * t / r.powi(3)),
        )
    }

    /// `d2psi/dlog(eps)^2 = 2 eps^2 / r - eps^4 / r^3 - eps` per entry.
    fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        let eps4 = eps2 * eps2;
        self.radius
            .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
    }

    /// Surrogate (reweighting) weights from a variance-discounted signal, plus their inverses.
    ///
    /// For each entry: `w = clamp(1 / sqrt(max(t^2 - max(v, 0), 0) + eps^2), floor, ceiling)`.
    /// Returns `(w, 1 / w)`. NOTE(review): `variance` is presumably the posterior variance of
    /// each signal entry (per the method name) — confirm. `zip` truncates to the shorter of
    /// `signal` and `variance`. `f64::clamp` panics if `weight_floor > weight_ceiling` or
    /// either bound is NaN.
    fn surrogateweights_posterior_snr(
        &self,
        variance: &Array1<f64>,
        weight_floor: f64,
        weight_ceiling: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        let eps2 = self.epsilon * self.epsilon;
        let weight = Array1::from_iter(self.signal.iter().zip(variance.iter()).map(|(&t, &v)| {
            // Signal energy left after subtracting the (non-negative) variance, floored at 0.
            let credible2 = (t * t - v.max(0.0)).max(0.0);
            let r = (credible2 + eps2).sqrt();
            (1.0 / r).clamp(weight_floor, weight_ceiling)
        }));
        let invweight = weight.mapv(|u| 1.0 / u);
        (weight, invweight)
    }

    /// Third derivative along a direction: `(d3psi/dt3) q = -3 eps^2 t q / r^5` per entry.
    fn directionalhessian_diag(&self, direction_signal: &Array1<f64>) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(direction_signal.iter())
                .zip(self.radius.iter())
                .map(|((t, q), r)| -3.0 * eps2 * t * q / r.powi(5)),
        )
    }

    /// Fourth derivative along two directions: `(d4psi/dt4) q1 q2` per entry.
    ///
    /// Here `d4psi/dt4 = -3 eps^2 / r^5 + 15 eps^2 t^2 / r^7`.
    fn second_directionalhessian_diag(
        &self,
        direction1_signal: &Array1<f64>,
        direction2_signal: &Array1<f64>,
    ) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(direction1_signal.iter())
                .zip(direction2_signal.iter())
                .zip(self.radius.iter())
                .map(|(((t, q1), q2), r)| {
                    let r2 = r * r;
                    // psi4 = d4psi/dt4; `r.powi(5) * r2` is r^7.
                    let psi4 = -3.0 * eps2 / r.powi(5) + 15.0 * eps2 * t * t / (r.powi(5) * r2);
                    psi4 * q1 * q2
                }),
        )
    }

    /// `d/dlog(eps) (d2psi/dt2) = 2 eps^2 / r^3 - 3 eps^4 / r^5` per entry.
    fn log_epsilon_betahessian_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(_, r)| 2.0 * eps2 / r.powi(3) - 3.0 * eps4 / r.powi(5)),
        )
    }

    /// `d2/dlog(eps)^2 (dpsi/dt) = eps^2 t (eps^2 - 2 t^2) / r^5` per entry.
    fn log_epsilon_beta_mixed_second_coeff(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| eps2 * t * (eps2 - 2.0 * t * t) / r.powi(5)),
        )
    }

    /// `d2/dlog(eps)^2 (d2psi/dt2) = 4 eps^2 / r^3 - 18 eps^4 / r^5 + 15 eps^6 / r^7` per entry.
    fn log_epsilon_betahessian_second_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        let eps6 = eps4 * eps2;
        Array1::from_iter(
            self.radius.iter().map(|r| {
                4.0 * eps2 / r.powi(3) - 18.0 * eps4 / r.powi(5) + 15.0 * eps6 / r.powi(7)
            }),
        )
    }

    /// Directional t-derivative of `log_epsilon_betahessian_diag`, per entry.
    ///
    /// The result is `(-6 eps^2 t / r^5 + 15 eps^4 t / r^7) q`.
    fn log_epsilon_betahessian_directional_diag(
        &self,
        direction_signal: &Array1<f64>,
    ) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(direction_signal.iter())
                .zip(self.radius.iter())
                .map(|((t, q), r)| (-6.0 * eps2 * t / r.powi(5) + 15.0 * eps4 * t / r.powi(7)) * q),
        )
    }
}
713
/// Grouped (vector) Charbonnier penalty state: one vector `v` per row of `signal_blocks`.
///
/// Each row contributes `psi(v) = sqrt(|v|^2 + eps^2) - eps`, with `norm = |v|` and
/// `radius = r = sqrt(|v|^2 + eps^2)` cached per row. Methods return per-row gradients
/// (as rows) and per-row Hessian-type matrices (one `len(v) x len(v)` matrix per row) with
/// respect to `v` and `log(eps)`.
#[derive(Clone)]
struct CharbonnierGroupedBlockState {
    // Euclidean norm of each row of `signal_blocks`.
    norm: Array1<f64>,
    // Cached `sqrt(norm^2 + eps^2)` per row.
    radius: Array1<f64>,
    // One signal vector per row.
    signal_blocks: Array2<f64>,
    // Smoothing parameter `eps`; floored at 1e-12 when built via `from_signal_blocks`.
    epsilon: f64,
}

impl CharbonnierGroupedBlockState {
    /// Builds the state from per-row signal vectors, flooring `epsilon` at `1e-12`.
    ///
    /// A NaN `epsilon` also becomes `1e-12`, because `f64::max` returns the non-NaN operand.
    fn from_signal_blocks(signal_blocks: Array2<f64>, epsilon: f64) -> Self {
        let eps = epsilon.max(1e-12);
        let norm = Array1::from_iter(
            signal_blocks
                .rows()
                .into_iter()
                .map(|row| row.iter().map(|v| v * v).sum::<f64>().sqrt()),
        );
        let radius = norm.mapv(|g| (g * g + eps * eps).sqrt());
        Self {
            norm,
            radius,
            signal_blocks,
            epsilon: eps,
        }
    }

    /// Total penalty `sum(r - eps)` over rows.
    fn penalty_value(&self) -> f64 {
        self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
    }

    /// Returns a copy of the per-row norms `|v|`.
    fn norm_signal(&self) -> Array1<f64> {
        self.norm.clone()
    }

    /// Per-row gradient `dpsi/dv = v / r`, returned with the same shape as `signal_blocks`.
    fn betagradient_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let scale = 1.0 / self.radius[k];
            row.mapv_inplace(|v| v * scale);
        }
        out
    }

    /// Per-row Hessian `I / r - v v^T / r^3`.
    fn betahessian_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|v| v / self.radius[k]);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] -= row[i] * row[j] / self.radius[k].powi(3);
                }
            }
            out.push(block);
        }
        out
    }

    /// `dpsi/dlog(eps) = eps^2 / r - eps` per row.
    fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        self.radius.mapv(|r| eps2 / r - epsilon)
    }

    /// Per-row mixed derivative `d/dlog(eps) (v / r) = -eps^2 v / r^3`.
    fn log_epsilon_betagradient_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        let eps2 = self.epsilon * self.epsilon;
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let scale = -eps2 / self.radius[k].powi(3);
            row.mapv_inplace(|v| v * scale);
        }
        out
    }

    /// `d2psi/dlog(eps)^2 = 2 eps^2 / r - eps^4 / r^3 - eps` per row.
    fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        let eps4 = eps2 * eps2;
        self.radius
            .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
    }

    /// Surrogate (reweighting) weights from variance-discounted row norms, plus their inverses.
    ///
    /// For each row: `w = clamp(1 / sqrt(max(|v|^2 - max(var, 0), 0) + eps^2), floor, ceiling)`.
    /// Returns `(w, 1 / w)`. NOTE(review): `variance` is presumably a per-row posterior
    /// variance of the norm signal — confirm. `zip` truncates to the shorter input.
    /// `f64::clamp` panics if `weight_floor > weight_ceiling` or either bound is NaN.
    fn surrogateweights_posterior_snr(
        &self,
        variance: &Array1<f64>,
        weight_floor: f64,
        weight_ceiling: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        let eps2 = self.epsilon * self.epsilon;
        let weight = Array1::from_iter(self.norm.iter().zip(variance.iter()).map(|(&g, &v)| {
            // Squared norm left after subtracting the (non-negative) variance, floored at 0.
            let credible2 = (g * g - v.max(0.0)).max(0.0);
            let r = (credible2 + eps2).sqrt();
            (1.0 / r).clamp(weight_floor, weight_ceiling)
        }));
        let invweight = weight.mapv(|u| 1.0 / u);
        (weight, invweight)
    }

    /// Per-row directional derivative of the Hessian along `q` (matching row of `direction_blocks`).
    ///
    /// The result is `-(v.q) I / r^3 - (q v^T + v q^T) / r^3 + 3 (v.q) v v^T / r^5`.
    /// NOTE(review): assumes `direction_blocks` has the same shape as `signal_blocks`; extra
    /// rows are ignored by `zip`, and a shorter row would panic on indexing.
    fn directionalhessian_blocks(&self, direction_blocks: &Array2<f64>) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, (v, q)) in self
            .signal_blocks
            .rows()
            .into_iter()
            .zip(direction_blocks.rows().into_iter())
            .enumerate()
        {
            let dim = v.len();
            let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
            let r3 = self.radius[k].powi(3);
            let r5 = self.radius[k].powi(5);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|x| -dot * x / r3);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] -= (q[i] * v[j] + v[i] * q[j]) / r3;
                    block[[i, j]] += 3.0 * dot * v[i] * v[j] / r5;
                }
            }
            out.push(block);
        }
        out
    }

    /// Per-row second directional derivative of the Hessian along `a` then `b`.
    ///
    /// With `sa = v.a`, `sb = v.b`, and `ab = a.b`, the result is:
    /// `(-ab / r^3 + 3 sa sb / r^5) I`
    /// `- (a b^T + b a^T) / r^3`
    /// `+ 3 sb (a v^T + v a^T) / r^5 + 3 sa (b v^T + v b^T) / r^5`
    /// `+ 3 ab v v^T / r^5 - 15 sa sb v v^T / r^7`.
    fn second_directionalhessian_blocks(
        &self,
        direction1_blocks: &Array2<f64>,
        direction2_blocks: &Array2<f64>,
    ) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for ((k, v), (a, b)) in self.signal_blocks.rows().into_iter().enumerate().zip(
            direction1_blocks
                .rows()
                .into_iter()
                .zip(direction2_blocks.rows().into_iter()),
        ) {
            let dim = v.len();
            let dot = |x: ndarray::ArrayView1<'_, f64>, y: ndarray::ArrayView1<'_, f64>| {
                x.iter().zip(y.iter()).map(|(p, q)| p * q).sum::<f64>()
            };
            let sa = dot(v, a);
            let sb = dot(v, b);
            let ab = dot(a, b);
            let r = self.radius[k];
            let r3 = r.powi(3);
            let r5 = r.powi(5);
            let r7 = r5 * r * r;
            let diag = -ab / r3 + 3.0 * sa * sb / r5;
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|x| diag * x);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] -= (a[i] * b[j] + b[i] * a[j]) / r3;
                    block[[i, j]] += 3.0 * sb * (a[i] * v[j] + v[i] * a[j]) / r5;
                    block[[i, j]] += 3.0 * ab * v[i] * v[j] / r5;
                    block[[i, j]] += 3.0 * sa * (b[i] * v[j] + v[i] * b[j]) / r5;
                    block[[i, j]] -= 15.0 * sa * sb * v[i] * v[j] / r7;
                }
            }
            out.push(block);
        }
        out
    }

    /// Per-row `d/dlog(eps)` of the Hessian: `-eps^2 I / r^3 + 3 eps^2 v v^T / r^5`.
    fn log_epsilon_betahessian_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            let r3 = self.radius[k].powi(3);
            let r5 = self.radius[k].powi(5);
            let mut block = Array2::<f64>::eye(dim);
            let eps2 = self.epsilon * self.epsilon;
            block.mapv_inplace(|v| -eps2 * v / r3);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * row[i] * row[j] / r5;
                }
            }
            out.push(block);
        }
        out
    }

    /// Per-row `d2/dlog(eps)^2` of the gradient: `eps^2 (eps^2 - 2 |v|^2) / r^5 * v`.
    fn log_epsilon_beta_mixed_second_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        let eps2 = self.epsilon * self.epsilon;
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let norm2 = self.norm[k] * self.norm[k];
            let scale = eps2 * (eps2 - 2.0 * norm2) / self.radius[k].powi(5);
            row.mapv_inplace(|v| v * scale);
        }
        out
    }

    /// Per-row `d2/dlog(eps)^2` of the Hessian.
    ///
    /// The result is `eps^2 (eps^2 - 2|v|^2) I / r^5 + 3 eps^2 (2|v|^2 - 3 eps^2) v v^T / r^7`.
    fn log_epsilon_betahessian_second_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        let eps2 = self.epsilon * self.epsilon;
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            let norm2 = self.norm[k] * self.norm[k];
            let r5 = self.radius[k].powi(5);
            let r7 = self.radius[k].powi(7);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|v| eps2 * (eps2 - 2.0 * norm2) * v / r5);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * (2.0 * norm2 - 3.0 * eps2) * row[i] * row[j] / r7;
                }
            }
            out.push(block);
        }
        out
    }

    /// Per-row directional derivative (along `q`) of `log_epsilon_betahessian_blocks`.
    ///
    /// The result is `3 eps^2 (v.q) I / r^5 + 3 eps^2 (q v^T + v q^T) / r^5
    /// - 15 eps^2 (v.q) v v^T / r^7`.
    fn log_epsilon_betahessian_directional_blocks(
        &self,
        direction_blocks: &Array2<f64>,
    ) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        let eps2 = self.epsilon * self.epsilon;
        for (k, (v, q)) in self
            .signal_blocks
            .rows()
            .into_iter()
            .zip(direction_blocks.rows().into_iter())
            .enumerate()
        {
            let dim = v.len();
            let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
            let r5 = self.radius[k].powi(5);
            let r7 = self.radius[k].powi(7);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|x| 3.0 * eps2 * dot * x / r5);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * (q[i] * v[j] + v[i] * q[j]) / r5;
                    block[[i, j]] -= 15.0 * eps2 * dot * v[i] * v[j] / r7;
                }
            }
            out.push(block);
        }
        out
    }
}
1040
1041fn scalar_operatorgradient(operator: &Array2<f64>, coeff: &Array1<f64>) -> Array1<f64> {
1042 operator.t().dot(coeff)
1043}
1044
1045fn scalar_operatorhessian(operator: &Array2<f64>, diag: &Array1<f64>) -> Array2<f64> {
1046 let mut weighted = operator.clone();
1047 for (k, &w) in diag.iter().enumerate() {
1048 weighted.row_mut(k).mapv_inplace(|v| v * w);
1049 }
1050 let gram = operator.t().dot(&weighted);
1051 (&gram + &gram.t().to_owned()) * 0.5
1052}
1053
1054fn grouped_operatorgradient(
1055 d1: &Array2<f64>,
1056 dimension: usize,
1057 blocks: &Array2<f64>,
1058) -> Result<Array1<f64>, EstimationError> {
1059 if blocks.ncols() != dimension {
1060 crate::bail_invalid_estim!(
1061 "grouped gradient block dimension mismatch: got {}, expected {dimension}",
1062 blocks.ncols()
1063 );
1064 }
1065 if d1.nrows() != blocks.nrows() * dimension {
1066 crate::bail_invalid_estim!(
1067 "grouped gradient row mismatch: D1 has {} rows, blocks imply {}",
1068 d1.nrows(),
1069 blocks.nrows() * dimension
1070 );
1071 }
1072 let mut out = Array1::<f64>::zeros(d1.ncols());
1073 for k in 0..blocks.nrows() {
1074 let gk = d1
1075 .slice(s![k * dimension..(k + 1) * dimension, ..])
1076 .to_owned();
1077 out += &gk.t().dot(&blocks.row(k));
1078 }
1079 Ok(out)
1080}
1081
1082fn grouped_operatorhessian(
1083 d1: &Array2<f64>,
1084 dimension: usize,
1085 blocks: &[Array2<f64>],
1086) -> Result<Array2<f64>, EstimationError> {
1087 if d1.nrows() != blocks.len() * dimension {
1088 crate::bail_invalid_estim!(
1089 "grouped Hessian row mismatch: D1 has {} rows, blocks imply {}",
1090 d1.nrows(),
1091 blocks.len() * dimension
1092 );
1093 }
1094 let p = d1.ncols();
1095 let mut out = Array2::<f64>::zeros((p, p));
1096 for (k, block) in blocks.iter().enumerate() {
1097 if block.nrows() != dimension || block.ncols() != dimension {
1098 crate::bail_invalid_estim!(
1099 "grouped Hessian block {k} has shape {}x{}, expected {}x{}",
1100 block.nrows(),
1101 block.ncols(),
1102 dimension,
1103 dimension
1104 );
1105 }
1106 let gk = d1
1107 .slice(s![k * dimension..(k + 1) * dimension, ..])
1108 .to_owned();
1109 out += &gk.t().dot(&block.dot(&gk));
1110 }
1111 Ok((&out + &out.t().to_owned()) * 0.5)
1112}
1113
/// Exact Charbonnier penalty states for one spatial term's three collocation signals.
#[derive(Clone)]
struct SpatialPenaltyExactState {
    // Scalar state on the value signal `d0 · beta_local`.
    magnitude: CharbonnierScalarBlockState,
    // Grouped state on `d1 · beta_local`, one `dimension`-length vector per collocation point.
    gradient: CharbonnierGroupedBlockState,
    // Grouped state on `d2 · beta_local`, one `dimension^2`-length vector per collocation point.
    curvature: CharbonnierGroupedBlockState,
}
1120
1121fn collocationgradient_blocks(
1122 gradrows: &Array1<f64>,
1123 dimension: usize,
1124) -> Result<Array2<f64>, EstimationError> {
1125 if dimension == 0 || !gradrows.len().is_multiple_of(dimension) {
1126 crate::bail_invalid_estim!(
1127 "invalid collocation gradient layout: rows={}, dimension={dimension}",
1128 gradrows.len()
1129 );
1130 }
1131 let p = gradrows.len() / dimension;
1132 let mut out = Array2::<f64>::zeros((p, dimension));
1133 for k in 0..p {
1134 for axis in 0..dimension {
1135 out[[k, axis]] = gradrows[k * dimension + axis];
1136 }
1137 }
1138 Ok(out)
1139}
1140
1141fn collocationhessian_blocks(
1142 hessianrows: &Array1<f64>,
1143 dimension: usize,
1144) -> Result<Array2<f64>, EstimationError> {
1145 let block_dim = dimension.checked_mul(dimension).ok_or_else(|| {
1146 EstimationError::InvalidInput("invalid collocation Hessian dimension overflow".to_string())
1147 })?;
1148 if block_dim == 0 || !hessianrows.len().is_multiple_of(block_dim) {
1149 crate::bail_invalid_estim!(
1150 "invalid collocation Hessian layout: rows={}, dimension={dimension}",
1151 hessianrows.len()
1152 );
1153 }
1154 let p = hessianrows.len() / block_dim;
1155 let mut out = Array2::<f64>::zeros((p, block_dim));
1156 for k in 0..p {
1157 for idx in 0..block_dim {
1158 out[[k, idx]] = hessianrows[k * block_dim + idx];
1159 }
1160 }
1161 Ok(out)
1162}
1163
1164impl SpatialPenaltyExactState {
1165 fn from_beta_local(
1166 beta_local: ArrayView1<'_, f64>,
1167 cache: &SpatialOperatorRuntimeCache,
1168 epsilons: [f64; 3],
1169 ) -> Result<Self, EstimationError> {
1170 let gradientrows = cache.d1.dot(&beta_local);
1200 let hessianrows = cache.d2.dot(&beta_local);
1201 Ok(Self {
1202 magnitude: CharbonnierScalarBlockState::from_signal(
1203 cache.d0.dot(&beta_local),
1204 epsilons[0],
1205 ),
1206 gradient: CharbonnierGroupedBlockState::from_signal_blocks(
1207 collocationgradient_blocks(&gradientrows, cache.dimension)?,
1208 epsilons[1],
1209 ),
1210 curvature: CharbonnierGroupedBlockState::from_signal_blocks(
1211 collocationhessian_blocks(&hessianrows, cache.dimension)?,
1212 epsilons[2],
1213 ),
1214 })
1215 }
1216
1217 fn absolute_collocation_magnitudes(&self) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
1218 (
1219 self.magnitude.absolute_signal(),
1220 self.gradient.norm_signal(),
1221 self.curvature.norm_signal(),
1222 )
1223 }
1224}
1225
1226fn robust_epsilon_from_samples(values: &[f64], min_epsilon_cfg: f64) -> f64 {
1227 if values.is_empty() {
1228 return min_epsilon_cfg.max(1e-12);
1229 }
1230 let mut clean = values
1231 .iter()
1232 .copied()
1233 .filter(|v| v.is_finite() && *v >= 0.0)
1234 .collect::<Vec<_>>();
1235 if clean.is_empty() {
1236 return min_epsilon_cfg.max(1e-12);
1237 }
1238 clean.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1239
1240 let n = clean.len();
1241 let median = quantile_from_sorted(&clean, 0.5);
1242 let q75 = quantile_from_sorted(&clean, 0.75);
1243 let q95 = quantile_from_sorted(&clean, 0.95);
1244
1245 let mut abs_dev = clean
1246 .iter()
1247 .map(|v| (v - median).abs())
1248 .filter(|v| v.is_finite())
1249 .collect::<Vec<_>>();
1250 abs_dev.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1251 let mad = 1.4826 * quantile_from_sorted(&abs_dev, 0.5);
1252
1253 let mut scale = median.max(mad).max(q75);
1263
1264 let delta = (f64::EPSILON.sqrt() * q95.max(1.0))
1266 .max(min_epsilon_cfg)
1267 .max(1e-12);
1268 let s_min = min_epsilon_cfg.max(1e-12);
1269
1270 if scale <= delta {
1272 let rms = (clean.iter().map(|v| v * v).sum::<f64>() / n as f64).sqrt();
1273 scale = q95.max(rms);
1274 }
1275 if scale <= delta {
1276 scale = s_min;
1277 }
1278
1279 let kappa = 1.0_f64;
1282 (kappa * scale).max(s_min)
1283}
1284
/// Builds one `SpatialOperatorRuntimeCache` per eligible smooth term.
///
/// Smooth terms in `spec.smooth_terms` are paired positionally with the fitted terms in
/// `design.smooth.terms`. For each pair, this function:
/// 1. locates the term's active mass / tension / stiffness operator penalties;
/// 2. rebuilds the collocation operator matrices `D0`, `D1`, `D2` from the fitted basis
///    metadata (Matérn or Duchon);
/// 3. divides each operator by the square root of its penalty's normalization scale;
/// 4. records the term's global coefficient range and global penalty indices.
///
/// A term is silently skipped (not an error) when any of these hold:
/// - `smooth_term_penalty_index` returns `None` for it;
/// - any one of the three operator penalties is missing or inactive;
/// - the (spec basis, fitted metadata) pair is not Matérn/Matérn, and not Duchon/Duchon
///   with `operator_collocation_points` present.
///
/// # Errors
/// Propagates errors from the Matérn/Duchon collocation-operator builders. Returns an
/// invalid-estimation error when the column count of `D0`, `D1` or `D2` differs from the
/// term's coefficient count.
fn extract_spatial_operator_runtime_caches(
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
) -> Result<Vec<SpatialOperatorRuntimeCache>, EstimationError> {
    // Offset of the first smooth column in the full design: the smooth block is taken to be
    // the trailing `total_smooth_cols()` columns, and `term_fit.coeff_range` is relative to it.
    let smooth_start = design
        .design
        .ncols()
        .saturating_sub(design.smooth.total_smooth_cols());
    let mut out = Vec::<SpatialOperatorRuntimeCache>::new();
    for (term_idx, (termspec, term_fit)) in spec
        .smooth_terms
        .iter()
        .zip(design.smooth.terms.iter())
        .enumerate()
    {
        // Global index of this term's first penalty; terms without one are not cached.
        let Some(global_base_idx) = smooth_term_penalty_index(spec, design, term_idx) else {
            continue;
        };
        // Scan the term's local penalties. Only *active* entries consume a local slot, so
        // `active_local_idx` is the offset from `global_base_idx`. If a source appears more
        // than once, the last active occurrence wins.
        let mut active_local_idx = 0usize;
        let mut mass_local_idx = None;
        let mut tension_local_idx = None;
        let mut stiffness_local_idx = None;
        let mut mass_norm = None;
        let mut tension_norm = None;
        let mut stiffness_norm = None;
        for info in &term_fit.penaltyinfo_local {
            if !info.active {
                continue;
            }
            match info.source {
                PenaltySource::OperatorMass => {
                    mass_local_idx = Some(active_local_idx);
                    mass_norm = Some(info.normalization_scale);
                }
                PenaltySource::OperatorTension => {
                    tension_local_idx = Some(active_local_idx);
                    tension_norm = Some(info.normalization_scale);
                }
                PenaltySource::OperatorStiffness => {
                    stiffness_local_idx = Some(active_local_idx);
                    stiffness_norm = Some(info.normalization_scale);
                }
                _ => {}
            }
            active_local_idx += 1;
        }
        // All three operator penalties (index and normalization) are required; otherwise the
        // term is left to the regular penalty machinery.
        let (
            Some(mass_local),
            Some(tension_local),
            Some(stiffness_local),
            Some(mass_scale),
            Some(tension_scale),
            Some(stiffness_scale),
        ) = (
            mass_local_idx,
            tension_local_idx,
            stiffness_local_idx,
            mass_norm,
            tension_norm,
            stiffness_norm,
        )
        else {
            continue;
        };
        let mass_global_idx = global_base_idx + mass_local;
        let tension_global_idx = global_base_idx + tension_local;
        let stiffness_global_idx = global_base_idx + stiffness_local;

        // Rebuild D0/D1/D2 from the fitted metadata. The last tuple element says whether the
        // rows of D0 should be column-centred below (Duchon only).
        let (feature_cols, mut d0, mut d1, mut d2, collocation_points, dim, center_mass_rows) =
            match (&termspec.basis, &term_fit.metadata) {
                (
                    SmoothBasisSpec::Matern { feature_cols, .. },
                    BasisMetadata::Matern {
                        centers,
                        length_scale,
                        nu,
                        include_intercept,
                        identifiability_transform,
                        aniso_log_scales,
                        input_scales,
                        ..
                    },
                ) => {
                    // When inputs were standardized, adjust the stored length scale before
                    // building operators (see `compensate_length_scale_for_standardization`).
                    let collocation_length_scale = match input_scales.as_deref() {
                        Some(scales) => {
                            compensate_length_scale_for_standardization(*length_scale, scales)
                        }
                        None => *length_scale,
                    };
                    let ops = build_matern_collocation_operator_matrices(
                        centers.view(),
                        None,
                        collocation_length_scale,
                        *nu,
                        *include_intercept,
                        identifiability_transform.as_ref().map(|z| z.view()),
                        aniso_log_scales.as_deref(),
                    )?;
                    (
                        feature_cols.clone(),
                        ops.d0,
                        ops.d1,
                        ops.d2,
                        ops.collocation_points,
                        centers.ncols(),
                        false,
                    )
                }
                (
                    SmoothBasisSpec::Duchon { feature_cols, .. },
                    BasisMetadata::Duchon {
                        centers,
                        length_scale,
                        power,
                        nullspace_order,
                        identifiability_transform,
                        input_scales,
                        aniso_log_scales,
                        operator_collocation_points: Some(collocation_points),
                        ..
                    },
                ) => {
                    // Duchon length scale is optional; compensate it for standardization
                    // only when both the scale and the input scales are present.
                    let collocation_length_scale = match (length_scale, input_scales.as_deref()) {
                        (Some(ls), Some(scales)) => {
                            Some(compensate_length_scale_for_standardization(*ls, scales))
                        }
                        (Some(ls), None) => Some(*ls),
                        (None, _) => None,
                    };
                    // NOTE(review): the literal `2` is interpreted by the callee — presumably
                    // the highest derivative order built (D0..D2); confirm.
                    let ops =
                        gam_terms::basis::build_duchon_collocation_operator_matriceswithworkspace(
                            centers.view(),
                            collocation_points.view(),
                            None,
                            collocation_length_scale,
                            *power,
                            *nullspace_order,
                            aniso_log_scales.as_deref(),
                            identifiability_transform.as_ref().map(|z| z.view()),
                            2,
                            &mut BasisWorkspace::default(),
                        )?;
                    (
                        feature_cols.clone(),
                        ops.d0,
                        ops.d1,
                        ops.d2,
                        ops.collocation_points,
                        centers.ncols(),
                        true,
                    )
                }
                // Other bases (or Duchon without stored collocation points) get no cache.
                _ => continue,
            };
        // Duchon: subtract each column's mean over the collocation rows from D0.
        // NOTE(review): presumably so the magnitude penalty ignores a constant offset — confirm.
        if center_mass_rows && d0.nrows() > 0 && d0.ncols() > 0 {
            let means = d0.sum_axis(Axis(0)).mapv(|v| v / d0.nrows() as f64);
            for mut row in d0.rows_mut() {
                row -= &means;
            }
        }

        // Divide each operator by sqrt(normalization scale), floored at 1e-12, so that
        // DᵀD is scaled by 1/scale (presumably to match the normalized penalty matrices).
        let mass_scale = mass_scale.max(1e-12).sqrt();
        let tension_scale = tension_scale.max(1e-12).sqrt();
        let stiffness_scale = stiffness_scale.max(1e-12).sqrt();
        d0.mapv_inplace(|v| v / mass_scale);
        d1.mapv_inplace(|v| v / tension_scale);
        d2.mapv_inplace(|v| v / stiffness_scale);

        // Map the term's local coefficient range into global design columns and make sure
        // every operator acts on exactly those coefficients.
        let coeff_global_range =
            (smooth_start + term_fit.coeff_range.start)..(smooth_start + term_fit.coeff_range.end);
        if d0.ncols() != coeff_global_range.len()
            || d1.ncols() != coeff_global_range.len()
            || d2.ncols() != coeff_global_range.len()
        {
            crate::bail_invalid_estim!(
                "spatial operator dimension mismatch for term '{}': D0 cols={}, D1 cols={}, D2 cols={}, coeffs={}",
                term_fit.name,
                d0.ncols(),
                d1.ncols(),
                d2.ncols(),
                coeff_global_range.len()
            );
        }
        out.push(SpatialOperatorRuntimeCache {
            termname: term_fit.name.clone(),
            feature_cols,
            coeff_global_range,
            mass_penalty_global_idx: mass_global_idx,
            tension_penalty_global_idx: tension_global_idx,
            stiffness_penalty_global_idx: stiffness_global_idx,
            d0,
            d1,
            d2,
            collocation_points,
            // Number of columns in `centers` (presumably the spatial input dimension).
            dimension: dim,
        });
    }
    Ok(out)
}
1518
1519fn scalar_operator_response_variance(
1531 operator: &Array2<f64>,
1532 cov_local: &Array2<f64>,
1533) -> Array1<f64> {
1534 Array1::from_iter(operator.rows().into_iter().map(|row| {
1535 let s = cov_local.dot(&row);
1536 row.dot(&s).max(0.0)
1537 }))
1538}
1539
1540fn grouped_operator_response_variance(
1551 operator: &Array2<f64>,
1552 block_dim: usize,
1553 cov_local: &Array2<f64>,
1554) -> Result<Array1<f64>, EstimationError> {
1555 if block_dim == 0 || !operator.nrows().is_multiple_of(block_dim) {
1556 crate::bail_invalid_estim!(
1557 "grouped variance row layout invalid: rows={}, block_dim={block_dim}",
1558 operator.nrows()
1559 );
1560 }
1561 let p = operator.nrows() / block_dim;
1562 let mut out = Array1::<f64>::zeros(p);
1563 for k in 0..p {
1564 let mut acc = 0.0;
1565 for axis in 0..block_dim {
1566 let row = operator.row(k * block_dim + axis);
1567 let s = cov_local.dot(&row);
1568 acc += row.dot(&s);
1569 }
1570 out[k] = acc.max(0.0);
1571 }
1572 Ok(out)
1573}
1574
1575fn compute_spatial_adaptiveweights_for_beta(
1576 beta: &Array1<f64>,
1577 caches: &[SpatialOperatorRuntimeCache],
1578 epsilon_0: f64,
1579 epsilon_g: f64,
1580 epsilon_c: f64,
1581 weight_floor: f64,
1582 weight_ceiling: f64,
1583 beta_covariance: Option<&Array2<f64>>,
1584) -> Result<Vec<SpatialAdaptiveWeights>, EstimationError> {
1585 caches
1617 .iter()
1618 .map(|cache| {
1619 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
1620 let exact = SpatialPenaltyExactState::from_beta_local(
1621 beta_local,
1622 cache,
1623 [epsilon_0, epsilon_g, epsilon_c],
1624 )?;
1625 let cov_local = beta_covariance.map(|cov| {
1626 cov.slice(s![
1627 cache.coeff_global_range.clone(),
1628 cache.coeff_global_range.clone()
1629 ])
1630 .to_owned()
1631 });
1632 let dim = cache.dimension;
1633 let (var_0, var_g, var_c) = match cov_local.as_ref() {
1634 Some(cov) => (
1635 scalar_operator_response_variance(&cache.d0, cov),
1636 grouped_operator_response_variance(&cache.d1, dim, cov)?,
1637 grouped_operator_response_variance(&cache.d2, dim * dim, cov)?,
1638 ),
1639 None => (
1640 Array1::<f64>::zeros(exact.magnitude.signal.len()),
1641 Array1::<f64>::zeros(exact.gradient.norm.len()),
1642 Array1::<f64>::zeros(exact.curvature.norm.len()),
1643 ),
1644 };
1645 let (_, inv_0) = exact.magnitude.surrogateweights_posterior_snr(
1646 &var_0,
1647 weight_floor,
1648 weight_ceiling,
1649 );
1650 let (_, inv_g) =
1651 exact
1652 .gradient
1653 .surrogateweights_posterior_snr(&var_g, weight_floor, weight_ceiling);
1654 let (_, inv_c) = exact.curvature.surrogateweights_posterior_snr(
1655 &var_c,
1656 weight_floor,
1657 weight_ceiling,
1658 );
1659 Ok(SpatialAdaptiveWeights {
1660 inv_magweight: inv_0,
1661 invgradweight: inv_g,
1662 inv_lapweight: inv_c,
1663 })
1664 })
1665 .collect()
1666}
1667
1668fn compute_initial_epsilons(
1669 beta: &Array1<f64>,
1670 caches: &[SpatialOperatorRuntimeCache],
1671 min_epsilon: f64,
1672) -> Result<(f64, f64, f64), EstimationError> {
1673 let mut fvals = Vec::<f64>::new();
1674 let mut gvals = Vec::<f64>::new();
1675 let mut cvals = Vec::<f64>::new();
1676 for cache in caches {
1677 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
1678 let exact = SpatialPenaltyExactState::from_beta_local(
1679 beta_local,
1680 cache,
1681 [min_epsilon, min_epsilon, min_epsilon],
1682 )?;
1683 let (f, g, c) = exact.absolute_collocation_magnitudes();
1684 fvals.extend(f.iter().copied());
1685 gvals.extend(g.iter().copied());
1686 cvals.extend(c.iter().copied());
1687 }
1688 let eps_0 = robust_epsilon_from_samples(&fvals, min_epsilon);
1694 let eps_g = robust_epsilon_from_samples(&gvals, min_epsilon);
1695 let eps_c = robust_epsilon_from_samples(&cvals, min_epsilon);
1696 Ok((eps_0, eps_g, eps_c))
1697}
1698
1699fn exact_spatial_adaptive_penalty_index_set(
1700 caches: &[SpatialOperatorRuntimeCache],
1701) -> BTreeSet<usize> {
1702 let mut out = BTreeSet::new();
1703 for cache in caches {
1704 out.insert(cache.mass_penalty_global_idx);
1705 out.insert(cache.tension_penalty_global_idx);
1706 out.insert(cache.stiffness_penalty_global_idx);
1707 }
1708 out
1709}
1710
1711fn build_spatial_adaptive_hyperspecs(cache_count: usize) -> Vec<SpatialAdaptiveHyperSpec> {
1712 let mut out = Vec::with_capacity(cache_count * 3 + 3);
1713 for cache_index in 0..cache_count {
1714 out.push(SpatialAdaptiveHyperSpec {
1715 cache_index,
1716 kind: SpatialAdaptiveHyperKind::LogLambdaMagnitude,
1717 });
1718 out.push(SpatialAdaptiveHyperSpec {
1719 cache_index,
1720 kind: SpatialAdaptiveHyperKind::LogLambdaGradient,
1721 });
1722 out.push(SpatialAdaptiveHyperSpec {
1723 cache_index,
1724 kind: SpatialAdaptiveHyperKind::LogLambdaCurvature,
1725 });
1726 }
1727 out.push(SpatialAdaptiveHyperSpec {
1728 cache_index: 0,
1729 kind: SpatialAdaptiveHyperKind::LogEpsilonMagnitude,
1730 });
1731 out.push(SpatialAdaptiveHyperSpec {
1732 cache_index: 0,
1733 kind: SpatialAdaptiveHyperKind::LogEpsilonGradient,
1734 });
1735 out.push(SpatialAdaptiveHyperSpec {
1736 cache_index: 0,
1737 kind: SpatialAdaptiveHyperKind::LogEpsilonCurvature,
1738 });
1739 out
1740}
1741
1742fn penalty_matrixwith_local_block(
1743 total_dim: usize,
1744 coeff_range: Range<usize>,
1745 local: &Array2<f64>,
1746) -> Array2<f64> {
1747 let mut out = Array2::<f64>::zeros((total_dim, total_dim));
1748 out.slice_mut(s![coeff_range.clone(), coeff_range])
1749 .assign(local);
1750 out
1751}
1752
1753fn fit_term_collectionwith_exact_spatial_adaptive_regularization(
1754 baseline: FittedTermCollection,
1755 y: ArrayView1<'_, f64>,
1756 weights: ArrayView1<'_, f64>,
1757 offset: ArrayView1<'_, f64>,
1758 family: LikelihoodSpec,
1759 options: &FitOptions,
1760 runtime_caches: &[SpatialOperatorRuntimeCache],
1761) -> Result<FittedTermCollection, EstimationError> {
1762 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
1791 let adaptive_penalty_indices = exact_spatial_adaptive_penalty_index_set(runtime_caches);
1792 let p_total = baseline.design.design.ncols();
1793 struct RetainedPenaltySetup {
1794 global_idx: usize,
1795 global_penalty: Array2<f64>,
1796 nullspace_dim: usize,
1797 log_lambda: f64,
1798 col_range: Range<usize>,
1799 hessian_piece: Array2<f64>,
1800 }
1801 use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
1802 let retained_setups = baseline
1803 .design
1804 .penalties
1805 .par_iter()
1806 .enumerate()
1807 .map(|(idx, bp)| {
1808 if adaptive_penalty_indices.contains(&idx) {
1809 return None;
1810 }
1811 let lambda = baseline.fit.lambdas[idx];
1812 Some(RetainedPenaltySetup {
1813 global_idx: idx,
1814 global_penalty: bp.to_global(p_total),
1815 nullspace_dim: baseline
1816 .design
1817 .nullspace_dims
1818 .get(idx)
1819 .copied()
1820 .unwrap_or(0),
1821 log_lambda: lambda.max(1e-12).ln(),
1822 col_range: bp.col_range.clone(),
1823 hessian_piece: bp.local.mapv(|v| lambda * v),
1824 })
1825 })
1826 .collect::<Vec<_>>();
1827 let retained_count = retained_setups
1828 .iter()
1829 .filter(|setup| setup.is_some())
1830 .count();
1831 let mut retained_penalties = Vec::<Array2<f64>>::with_capacity(retained_count);
1832 let mut retained_nullspace_dims = Vec::<usize>::with_capacity(retained_count);
1833 let mut retained_log_lambdas = Vec::<f64>::with_capacity(retained_count);
1834 let mut retained_global_indices = Vec::<usize>::with_capacity(retained_count);
1835 let mut fixed_quadratichessian = Array2::<f64>::zeros((p_total, p_total));
1836 for setup in retained_setups.into_iter().flatten() {
1837 retained_penalties.push(setup.global_penalty);
1838 retained_nullspace_dims.push(setup.nullspace_dim);
1839 retained_log_lambdas.push(setup.log_lambda);
1840 retained_global_indices.push(setup.global_idx);
1841 fixed_quadratichessian
1842 .slice_mut(s![setup.col_range.clone(), setup.col_range])
1843 .scaled_add(1.0, &setup.hessian_piece);
1844 }
1845
1846 let (eps_0_init, eps_g_init, eps_c_init) = compute_initial_epsilons(
1847 &baseline.fit.beta,
1848 runtime_caches,
1849 adaptive_opts.min_epsilon,
1850 )?;
1851 let mut initial_theta =
1852 Array1::<f64>::zeros(retained_penalties.len() + runtime_caches.len() * 3 + 3);
1853 for (idx, value) in retained_log_lambdas.iter().enumerate() {
1854 initial_theta[idx] = *value;
1855 }
1856 let adaptive_log_lambda_components = runtime_caches
1857 .par_iter()
1858 .map(|cache| {
1859 [
1860 baseline.fit.lambdas[cache.mass_penalty_global_idx]
1861 .max(1e-12)
1862 .ln(),
1863 baseline.fit.lambdas[cache.tension_penalty_global_idx]
1864 .max(1e-12)
1865 .ln(),
1866 baseline.fit.lambdas[cache.stiffness_penalty_global_idx]
1867 .max(1e-12)
1868 .ln(),
1869 ]
1870 })
1871 .collect::<Vec<_>>();
1872 let mut at = retained_penalties.len();
1873 for logs in &adaptive_log_lambda_components {
1874 initial_theta[at] = logs[0];
1875 initial_theta[at + 1] = logs[1];
1876 initial_theta[at + 2] = logs[2];
1877 at += 3;
1878 }
1879 initial_theta[at] = eps_0_init.max(adaptive_opts.min_epsilon).ln();
1880 initial_theta[at + 1] = eps_g_init.max(adaptive_opts.min_epsilon).ln();
1881 initial_theta[at + 2] = eps_c_init.max(adaptive_opts.min_epsilon).ln();
1882
1883 let hyperspecs = build_spatial_adaptive_hyperspecs(runtime_caches.len());
1884 let zero_psi_op: std::sync::Arc<dyn gam_custom_family::CustomFamilyPsiDerivativeOperator> =
1885 std::sync::Arc::new(gam_custom_family::ZeroPsiDerivativeOperator::new(
1886 baseline.design.design.nrows(),
1887 baseline.design.design.ncols(),
1888 ));
1889 let derivative_blocks = vec![
1890 hyperspecs
1891 .par_iter()
1892 .map(|_| CustomFamilyBlockPsiDerivative {
1893 penalty_index: None,
1894 x_psi: Array2::<f64>::zeros((0, 0)),
1895 s_psi: Array2::<f64>::zeros((0, 0)),
1896 s_psi_components: None,
1897 s_psi_penalty_components: None,
1898 x_psi_psi: None,
1899 s_psi_psi: None,
1900 s_psi_psi_components: None,
1901 s_psi_psi_penalty_components: None,
1902 implicit_operator: Some(std::sync::Arc::clone(&zero_psi_op)),
1903 implicit_axis: 0,
1904 implicit_group_id: None,
1905 })
1906 .collect::<Vec<_>>(),
1907 ];
1908
1909 let mixture_link_state = options
1910 .mixture_link
1911 .clone()
1912 .as_ref()
1913 .map(state_fromspec)
1914 .transpose()
1915 .map_err(EstimationError::InvalidInput)?;
1916 let sas_link_state = options
1917 .sas_link
1918 .map(|spec| {
1919 if family.is_binomial_beta_logistic() {
1920 state_from_beta_logisticspec(spec)
1921 } else {
1922 state_from_sasspec(spec)
1923 }
1924 })
1925 .transpose()
1926 .map_err(EstimationError::InvalidInput)?;
1927 let latent_cloglog_state = options.latent_cloglog;
1928 let shared_y = Arc::new(y.to_owned());
1929 let sharedweights = Arc::new(weights.to_owned());
1930 let shared_design = baseline
1931 .design
1932 .design
1933 .try_to_dense_arc("spatial adaptive exact hyperfit design")
1934 .map_err(EstimationError::InvalidInput)?;
1935 let shared_offset = Arc::new(offset.to_owned());
1936 let shared_runtime_caches = Arc::new(runtime_caches.to_vec());
1937 let shared_hyperspecs = Arc::new(hyperspecs.clone());
1938 let zero_quadratic = Arc::new(Array2::<f64>::zeros((
1939 baseline.design.design.ncols(),
1940 baseline.design.design.ncols(),
1941 )));
1942 let base_family = SpatialAdaptiveExactFamily {
1943 family: family.clone(),
1944 latent_cloglog_state,
1945 mixture_link_state: mixture_link_state.clone(),
1946 sas_link_state,
1947 y: shared_y.clone(),
1948 weights: sharedweights.clone(),
1949 design: shared_design.clone(),
1950 offset: shared_offset.clone(),
1951 linear_constraints: baseline.design.linear_constraints.clone(),
1952 runtime_caches: shared_runtime_caches.clone(),
1953 adaptive_params: Vec::new(),
1954 fixed_quadratichessian: zero_quadratic.clone(),
1955 hyperspecs: shared_hyperspecs.clone(),
1956 exact_eval_cache: Arc::new(Mutex::new(None)),
1957 };
1958
1959 let rho_dim = retained_penalties.len();
1960 let operator_slots_end = rho_dim + runtime_caches.len() * 3;
1961 const UNIFIED_LOG_WINDOW: f64 = 6.0;
1971 const RETAINED_LAMBDA_LOG_LOWER_FLOOR: f64 = -30.0;
1972 const RETAINED_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
1973 const OPERATOR_LAMBDA_LOG_LOWER_FLOOR: f64 = -10.0;
1974 const OPERATOR_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
1975 let epsilon_floor_log = adaptive_opts.min_epsilon.max(1e-12).ln();
1976 let anchored_bound = |idx: usize, sign: f64| -> f64 {
1977 let raw = initial_theta[idx] + sign * UNIFIED_LOG_WINDOW;
1978 if idx < rho_dim {
1979 raw.clamp(
1980 RETAINED_LAMBDA_LOG_LOWER_FLOOR,
1981 RETAINED_LAMBDA_LOG_UPPER_CAP,
1982 )
1983 } else if idx < operator_slots_end {
1984 raw.clamp(
1985 OPERATOR_LAMBDA_LOG_LOWER_FLOOR,
1986 OPERATOR_LAMBDA_LOG_UPPER_CAP,
1987 )
1988 } else {
1989 raw.max(epsilon_floor_log)
1990 }
1991 };
1992 let eps_lower =
1993 Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, -1.0)));
1994 let eps_upper = Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, 1.0)));
1995 let blockspec = ParameterBlockSpec {
1996 name: "eta".to_string(),
1997 design: baseline.design.design.clone(),
1998 offset: offset.to_owned(),
1999 penalties: retained_penalties
2000 .iter()
2001 .cloned()
2002 .map(PenaltyMatrix::Dense)
2003 .collect(),
2004 nullspace_dims: retained_nullspace_dims.clone(),
2005 initial_log_lambdas: Array1::from_vec(retained_log_lambdas.clone()),
2006 initial_beta: Some(baseline.fit.beta.clone()),
2007 gauge_priority: 100,
2008 jacobian_callback: None,
2009 stacked_design: None,
2010 stacked_offset: None,
2011 };
2012 let screening_cap = Arc::new(AtomicUsize::new(0));
2013 let outer_opts = BlockwiseFitOptions {
2014 inner_max_cycles: options.max_iter,
2015 inner_tol: options.tol,
2016 outer_max_iter: options.max_iter,
2017 outer_tol: options.tol,
2018 compute_covariance: false,
2019 screening_max_inner_iterations: Some(Arc::clone(&screening_cap)),
2020 ..BlockwiseFitOptions::default()
2021 };
2022
2023 use gam_solve::rho_optimizer::OuterProblem;
2024 use gam_problem::{DeclaredHessianForm, Derivative, HessianResult, OuterEval};
2025
2026 struct SpatialAdaptiveOuterState {
2027 warm_cache: Option<CustomFamilyWarmStart>,
2028 last_eval: Option<(
2029 Array1<f64>,
2030 f64,
2031 Array1<f64>,
2032 HessianResult,
2033 CustomFamilyWarmStart,
2034 )>,
2035 }
2036
2037 let n_theta = initial_theta.len();
2038
2039 let theta_bounds = Some((eps_lower.clone(), eps_upper.clone()));
2042 let clamp_theta = {
2043 let lo = eps_lower;
2044 let hi = eps_upper;
2045 move |theta: &Array1<f64>| -> Array1<f64> {
2046 let mut clamped = theta.clone();
2047 for i in 0..clamped.len() {
2048 clamped[i] = clamped[i].clamp(lo[i], hi[i]);
2049 }
2050 clamped
2051 }
2052 };
2053
2054 let decode_theta = |theta: &Array1<f64>| -> (Array1<f64>, Vec<SpatialAdaptiveTermHyperParams>) {
2055 let rho = theta.slice(s![..rho_dim]).to_owned();
2056 let adaptive_lambda_start = rho_dim;
2057 let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
2058 let eps = [
2059 theta[adaptive_lambda_end].exp(),
2060 theta[adaptive_lambda_end + 1].exp(),
2061 theta[adaptive_lambda_end + 2].exp(),
2062 ];
2063 let adaptive_params = runtime_caches
2064 .iter()
2065 .enumerate()
2066 .map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
2067 lambda: [
2068 theta[adaptive_lambda_start + cache_idx * 3].exp(),
2069 theta[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
2070 theta[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
2071 ],
2072 epsilon: eps,
2073 })
2074 .collect::<Vec<_>>();
2075 (rho, adaptive_params)
2076 };
2077 let analytic_outer_hessian_available =
2078 gam_custom_family::joint_exact_analytic_outer_hessian_available()
2079 && base_family
2080 .exact_outer_derivative_order(std::slice::from_ref(&blockspec), &outer_opts)
2081 .has_hessian()
2082 && gam_custom_family::exact_newton_outer_geometry_supports_second_order_solver(
2083 &base_family,
2084 );
2085 let outer_max_iter = gam_custom_family::cost_gated_first_order_max_iter(
2086 options.max_iter,
2087 base_family.coefficient_gradient_cost(std::slice::from_ref(&blockspec)),
2088 analytic_outer_hessian_available,
2089 );
2090 if outer_max_iter < options.max_iter {
2091 log::info!(
2092 "[OUTER] exact spatial adaptive regularization: first-order work gate reduced outer_max_iter {} -> {}",
2093 options.max_iter,
2094 outer_max_iter,
2095 );
2096 }
2097 let problem = OuterProblem::new(n_theta)
2103 .with_gradient(Derivative::Analytic)
2104 .with_hessian(if analytic_outer_hessian_available {
2105 DeclaredHessianForm::Either
2106 } else {
2107 DeclaredHessianForm::Unavailable
2108 })
2109 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Disabled)
2110 .with_psi_dim(n_theta.saturating_sub(rho_dim))
2111 .with_tolerance(options.tol)
2112 .with_max_iter(outer_max_iter)
2113 .with_seed_config(gam_problem::SeedConfig::default())
2114 .with_screening_cap(Arc::clone(&screening_cap))
2115 .with_initial_rho(initial_theta.clone());
2116 let problem = if let Some((lo, hi)) = theta_bounds {
2117 problem.with_bounds(lo, hi)
2118 } else {
2119 problem
2120 };
2121
2122 let eval_outer = |st: &mut SpatialAdaptiveOuterState,
2123 theta: &Array1<f64>,
2124 order: gam_solve::rho_optimizer::OuterEvalOrder|
2125 -> Result<OuterEval, EstimationError> {
2126 let theta = clamp_theta(theta);
2127
2128 if let Some((cached_theta, cached_cost, cached_grad, cached_hess, cached_warm)) =
2129 &st.last_eval
2130 && cached_theta.len() == theta.len()
2131 && cached_theta
2132 .iter()
2133 .zip(theta.iter())
2134 .all(|(&a, &b)| (a - b).abs() <= 1e-12)
2135 && (!matches!(
2136 order,
2137 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2138 ) || analytic_outer_hessian_available)
2139 {
2140 st.warm_cache = Some(cached_warm.clone());
2141 return Ok(OuterEval {
2142 cost: *cached_cost,
2143 gradient: cached_grad.clone(),
2144 hessian: if matches!(
2145 order,
2146 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2147 ) && analytic_outer_hessian_available
2148 {
2149 cached_hess.clone()
2150 } else {
2151 HessianResult::Unavailable
2152 },
2153 inner_beta_hint: None,
2154 });
2155 }
2156
2157 let (rho, adaptive_params) = decode_theta(&theta);
2158 let family_eval = base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2159 let need_hessian = matches!(
2160 order,
2161 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2162 ) && analytic_outer_hessian_available;
2163 let result = evaluate_custom_family_joint_hyper(
2164 &family_eval,
2165 std::slice::from_ref(&blockspec),
2166 &outer_opts,
2167 &rho,
2168 &derivative_blocks,
2169 st.warm_cache.as_ref(),
2170 if need_hessian {
2171 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
2172 } else {
2173 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
2174 },
2175 )
2176 .map_err(|e| {
2177 EstimationError::RemlOptimizationFailed(format!("spatial adaptive eval failed: {e}"))
2178 })?;
2179 if !result.inner_converged {
2180 st.warm_cache = Some(result.warm_start.clone());
2181 return Err(EstimationError::RemlOptimizationFailed(
2182 "exact spatial adaptive inner solve did not converge".to_string(),
2183 ));
2184 }
2185 if !result.objective.is_finite() || result.gradient.iter().any(|v| !v.is_finite()) {
2186 return Err(EstimationError::RemlOptimizationFailed(
2187 "exact spatial adaptive objective returned non-finite values".to_string(),
2188 ));
2189 }
2190 let hessian_result = if need_hessian {
2191 if !result.outer_hessian.is_analytic() {
2192 return Err(EstimationError::RemlOptimizationFailed(
2193 "exact spatial adaptive objective did not return an exact outer Hessian"
2194 .to_string(),
2195 ));
2196 }
2197 match result.outer_hessian.dim() {
2198 Some(dim) if dim == theta.len() => {}
2199 Some(dim) => {
2200 return Err(EstimationError::RemlOptimizationFailed(format!(
2201 "exact spatial adaptive outer Hessian dimension mismatch: got {dim}, expected {}",
2202 theta.len(),
2203 )));
2204 }
2205 None => {
2206 return Err(EstimationError::RemlOptimizationFailed(
2207 "exact spatial adaptive objective did not report an outer Hessian dimension"
2208 .to_string(),
2209 ));
2210 }
2211 }
2212 st.last_eval = Some((
2213 theta.clone(),
2214 result.objective,
2215 result.gradient.clone(),
2216 result.outer_hessian.clone(),
2217 result.warm_start.clone(),
2218 ));
2219 result.outer_hessian
2220 } else {
2221 HessianResult::Unavailable
2222 };
2223 st.warm_cache = Some(result.warm_start);
2224 Ok(OuterEval {
2225 cost: result.objective,
2226 gradient: result.gradient,
2227 hessian: hessian_result,
2228 inner_beta_hint: None,
2229 })
2230 };
2231
2232 let mut obj = problem.build_objective_with_screening_proxy(
2233 SpatialAdaptiveOuterState {
2234 warm_cache: None,
2235 last_eval: None,
2236 },
2237 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2238 let theta = clamp_theta(theta);
2239 let (rho, adaptive_params) = decode_theta(&theta);
2240 let family_eval =
2241 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2242 let result = evaluate_custom_family_joint_hyper(
2243 &family_eval,
2244 std::slice::from_ref(&blockspec),
2245 &outer_opts,
2246 &rho,
2247 &derivative_blocks,
2248 st.warm_cache.as_ref(),
2249 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
2250 )
2251 .map_err(|e| {
2252 EstimationError::RemlOptimizationFailed(format!(
2253 "spatial adaptive cost eval failed: {e}"
2254 ))
2255 })?;
2256 if !result.inner_converged {
2257 st.warm_cache = Some(result.warm_start);
2258 return Err(EstimationError::RemlOptimizationFailed(
2259 "exact spatial adaptive cost inner solve did not converge".to_string(),
2260 ));
2261 }
2262 st.warm_cache = Some(result.warm_start);
2263 Ok(result.objective)
2264 },
2265 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2266 eval_outer(
2267 st,
2268 theta,
2269 if analytic_outer_hessian_available {
2270 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2271 } else {
2272 gam_solve::rho_optimizer::OuterEvalOrder::ValueAndGradient
2273 },
2274 )
2275 },
2276 |st: &mut SpatialAdaptiveOuterState,
2277 theta: &Array1<f64>,
2278 order: gam_solve::rho_optimizer::OuterEvalOrder| {
2279 eval_outer(st, theta, order)
2280 },
2281 Some(|st: &mut SpatialAdaptiveOuterState| {
2282 st.warm_cache = None;
2283 st.last_eval = None;
2284 }),
2285 Some(|st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2286 let theta = clamp_theta(theta);
2287 let (rho, adaptive_params) = decode_theta(&theta);
2288 let family_eval =
2289 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2290 let result = evaluate_custom_family_joint_hyper_efs(
2291 &family_eval,
2292 std::slice::from_ref(&blockspec),
2293 &outer_opts,
2294 &rho,
2295 &derivative_blocks,
2296 st.warm_cache.as_ref(),
2297 )
2298 .map_err(|e| {
2299 EstimationError::RemlOptimizationFailed(format!(
2300 "spatial adaptive EFS eval failed: {e}"
2301 ))
2302 })?;
2303 if !result.inner_converged {
2304 st.warm_cache = Some(result.warm_start);
2305 return Err(EstimationError::RemlOptimizationFailed(
2306 "exact spatial adaptive EFS inner solve did not converge".to_string(),
2307 ));
2308 }
2309 st.warm_cache = Some(result.warm_start);
2310 Ok(result.efs_eval)
2311 }),
2312 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2324 let theta = clamp_theta(theta);
2325 let (rho, adaptive_params) = decode_theta(&theta);
2326 let family_eval =
2327 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2328 let result = evaluate_custom_family_joint_hyper(
2329 &family_eval,
2330 std::slice::from_ref(&blockspec),
2331 &outer_opts,
2332 &rho,
2333 &derivative_blocks,
2334 st.warm_cache.as_ref(),
2335 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
2336 )
2337 .map_err(|e| {
2338 EstimationError::RemlOptimizationFailed(format!(
2339 "spatial adaptive screening eval failed: {e}"
2340 ))
2341 })?;
2342 st.warm_cache = Some(result.warm_start);
2343 Ok(result.objective)
2344 },
2345 );
2346
2347 let outer_result = problem
2348 .run(&mut obj, "exact spatial adaptive regularization")
2349 .map_err(|e| {
2350 EstimationError::InvalidInput(format!(
2351 "exact spatial adaptive outer optimization failed: {e}"
2352 ))
2353 })?;
2354 if !outer_result.converged {
2355 let rel_to_cost_threshold = options.tol * (1.0_f64 + outer_result.final_value.abs());
2372 if let Some(final_grad) = outer_result
2376 .final_grad_norm
2377 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
2378 {
2379 log::info!(
2380 "[spatial-adaptive] outer optimization hit max_iter={} but \
2381 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
2382 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
2383 relative-to-cost REML convergence criterion.",
2384 outer_result.iterations,
2385 final_grad,
2386 rel_to_cost_threshold,
2387 options.tol,
2388 outer_result.final_value.abs(),
2389 );
2390 } else {
2391 crate::bail_invalid_estim!(
2392 "exact spatial adaptive outer optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
2393 outer_result.iterations,
2394 outer_result.final_value,
2395 outer_result.final_grad_norm_report(),
2396 );
2397 }
2398 }
2399 let outer_iterations = outer_result.iterations;
2400 let outer_grad_norm: Option<f64> = outer_result.final_grad_norm;
2403 let theta_star = outer_result.rho;
2404 let rho_star = theta_star.slice(s![..rho_dim]).to_owned();
2405 let adaptive_lambda_start = rho_dim;
2406 let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
2407 let eps_star = [
2408 theta_star[adaptive_lambda_end].exp(),
2409 theta_star[adaptive_lambda_end + 1].exp(),
2410 theta_star[adaptive_lambda_end + 2].exp(),
2411 ];
2412 let adaptive_params = runtime_caches
2413 .iter()
2414 .enumerate()
2415 .map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
2416 lambda: [
2417 theta_star[adaptive_lambda_start + cache_idx * 3].exp(),
2418 theta_star[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
2419 theta_star[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
2420 ],
2421 epsilon: eps_star,
2422 })
2423 .collect::<Vec<_>>();
2424 let mut fixed_total = Array2::<f64>::zeros((
2425 baseline.design.design.ncols(),
2426 baseline.design.design.ncols(),
2427 ));
2428 for (idx, penalty) in retained_penalties.iter().enumerate() {
2429 fixed_total.scaled_add(rho_star[idx].exp(), penalty);
2430 }
2431 let final_family =
2432 base_family.with_adaptive_params(adaptive_params.clone(), Arc::new(fixed_total.clone()));
2433 let final_blockspec = ParameterBlockSpec {
2434 name: "eta".to_string(),
2435 design: baseline.design.design.clone(),
2436 offset: offset.to_owned(),
2437 penalties: vec![],
2438 nullspace_dims: vec![],
2439 initial_log_lambdas: Array1::zeros(0),
2440 initial_beta: Some(baseline.fit.beta.clone()),
2441 gauge_priority: 100,
2442 jacobian_callback: None,
2443 stacked_design: None,
2444 stacked_offset: None,
2445 };
2446 let final_fit = fit_custom_family(
2447 &final_family,
2448 &[final_blockspec],
2449 &BlockwiseFitOptions {
2450 inner_max_cycles: options.max_iter,
2451 inner_tol: options.tol,
2452 outer_max_iter: 1,
2453 outer_tol: options.tol,
2454 compute_covariance: true,
2455 ..BlockwiseFitOptions::default()
2456 },
2457 )
2458 .map_err(EstimationError::CustomFamily)?;
2459 let beta = final_fit.block_states[0].beta.clone();
2460 let final_eval = final_family
2461 .exact_evaluation(&beta)
2462 .map_err(EstimationError::InvalidInput)?;
2463 let penalized_hessian = final_eval
2464 .totalobjectivehessian(&final_family.design)
2465 .map_err(EstimationError::InvalidInput)?;
2466 let beta_covariance = final_fit.covariance_conditional.clone();
2467 let beta_standard_errors = beta_covariance
2468 .as_ref()
2469 .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));
2470
2471 let mut full_lambdas = baseline.fit.lambdas.clone();
2472 for (idx, &global_idx) in retained_global_indices.iter().enumerate() {
2473 full_lambdas[global_idx] = rho_star[idx].exp();
2474 }
2475 for (cache_idx, cache) in runtime_caches.iter().enumerate() {
2476 full_lambdas[cache.mass_penalty_global_idx] = adaptive_params[cache_idx].lambda[0];
2477 full_lambdas[cache.tension_penalty_global_idx] = adaptive_params[cache_idx].lambda[1];
2478 full_lambdas[cache.stiffness_penalty_global_idx] = adaptive_params[cache_idx].lambda[2];
2479 }
2480
2481 let deviance = if family.is_gaussian_identity() {
2482 y.iter()
2483 .zip(final_eval.obs.mu.iter())
2484 .zip(weights.iter())
2485 .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
2486 .sum()
2487 } else {
2488 -2.0 * final_eval.obs.log_likelihood
2489 };
2490 let mut local_penalty_blocks =
2491 Vec::<PenaltySpec>::with_capacity(baseline.design.penalties.len());
2492 for (global_idx, bp) in baseline.design.penalties.iter().enumerate() {
2493 if adaptive_penalty_indices.contains(&global_idx) {
2494 let cache = runtime_caches
2495 .iter()
2496 .find(|cache| {
2497 cache.mass_penalty_global_idx == global_idx
2498 || cache.tension_penalty_global_idx == global_idx
2499 || cache.stiffness_penalty_global_idx == global_idx
2500 })
2501 .ok_or_else(|| {
2502 EstimationError::InvalidInput(format!(
2503 "missing runtime cache for adaptive penalty index {global_idx}"
2504 ))
2505 })?;
2506 let cache_idx = runtime_caches
2507 .iter()
2508 .position(|c| {
2509 c.mass_penalty_global_idx == global_idx
2510 || c.tension_penalty_global_idx == global_idx
2511 || c.stiffness_penalty_global_idx == global_idx
2512 })
2513 .ok_or_else(|| {
2514 EstimationError::InvalidInput(format!(
2515 "missing adaptive cache position for penalty index {global_idx}"
2516 ))
2517 })?;
2518 let state = &final_eval.adaptive_states[cache_idx];
2519 let local = if cache.mass_penalty_global_idx == global_idx {
2520 scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag())
2521 .mapv(|v| adaptive_params[cache_idx].lambda[0] * v)
2522 } else if cache.tension_penalty_global_idx == global_idx {
2523 grouped_operatorhessian(
2524 &cache.d1,
2525 cache.dimension,
2526 &state.gradient.betahessian_blocks(),
2527 )?
2528 .mapv(|v| adaptive_params[cache_idx].lambda[1] * v)
2529 } else {
2530 grouped_operatorhessian(
2531 &cache.d2,
2532 cache.dimension * cache.dimension,
2533 &state.curvature.betahessian_blocks(),
2534 )?
2535 .mapv(|v| adaptive_params[cache_idx].lambda[2] * v)
2536 };
2537 local_penalty_blocks.push(PenaltySpec::Dense(penalty_matrixwith_local_block(
2539 baseline.design.design.ncols(),
2540 cache.coeff_global_range.clone(),
2541 &local,
2542 )));
2543 } else {
2544 local_penalty_blocks.push(PenaltySpec::Dense(
2545 bp.to_global(p_total).mapv(|v| v * full_lambdas[global_idx]),
2546 ));
2547 }
2548 }
2549 let (edf_by_block, penalty_block_trace, edf_total) = if let Some(cov) = beta_covariance.as_ref()
2550 {
2551 exact_bounded_edf(
2552 &local_penalty_blocks,
2553 &Array1::from_elem(local_penalty_blocks.len(), 1.0),
2554 cov,
2555 )?
2556 } else {
2557 (
2558 vec![0.0; local_penalty_blocks.len()],
2559 vec![0.0; local_penalty_blocks.len()],
2560 0.0,
2561 )
2562 };
2563 let stable_penalty_term =
2564 2.0 * final_eval.adaptive_penalty_value + beta.dot(&fixed_total.dot(&beta));
2565 let standard_deviation = if family.is_gaussian_identity() {
2566 let denom = (y.len() as f64 - edf_total).max(1.0);
2567 (deviance / denom).sqrt()
2568 } else {
2569 1.0
2570 };
2571 let maps = compute_spatial_adaptiveweights_for_beta(
2572 &beta,
2573 runtime_caches,
2574 eps_star[0],
2575 eps_star[1],
2576 eps_star[2],
2577 adaptive_opts.weight_floor,
2578 adaptive_opts.weight_ceiling,
2579 beta_covariance.as_ref(),
2583 )?
2584 .into_iter()
2585 .zip(runtime_caches.iter())
2586 .map(|(w, cache)| AdaptiveSpatialMap {
2587 termname: cache.termname.clone(),
2588 feature_cols: cache.feature_cols.clone(),
2589 collocation_points: cache.collocation_points.clone(),
2590 inv_magweight: w.inv_magweight,
2591 invgradweight: w.invgradweight,
2592 inv_lapweight: w.inv_lapweight,
2593 })
2594 .collect::<Vec<_>>();
2595 let fitted_link = if family.is_latent_cloglog() {
2596 FittedLinkState::LatentCLogLog {
2597 state: latent_cloglog_state
2598 .expect("BinomialLatentCLogLog requires an explicit latent-cloglog state"),
2599 }
2600 } else if family.is_binomial_mixture() {
2601 mixture_link_state
2602 .clone()
2603 .map(|state| FittedLinkState::Mixture {
2604 state,
2605 covariance: None,
2606 })
2607 .unwrap_or(FittedLinkState::Standard(None))
2608 } else if family.is_binomial_sas() {
2609 sas_link_state
2610 .map(|state| FittedLinkState::Sas {
2611 state,
2612 covariance: None,
2613 })
2614 .unwrap_or(FittedLinkState::Standard(None))
2615 } else if family.is_binomial_beta_logistic() {
2616 sas_link_state
2617 .map(|state| FittedLinkState::BetaLogistic {
2618 state,
2619 covariance: None,
2620 })
2621 .unwrap_or(FittedLinkState::Standard(None))
2622 } else {
2623 FittedLinkState::Standard(None)
2624 };
2625 let max_abs_eta = final_eval
2626 .obs
2627 .eta
2628 .iter()
2629 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2630 let fitted = FittedTermCollection {
2631 fit: {
2632 let log_lambdas = full_lambdas.mapv(|v| v.max(1e-300).ln());
2633 let inf = FitInference {
2634 edf_by_block,
2635 penalty_block_trace,
2636 edf_total,
2637 smoothing_correction: None,
2638 penalized_hessian: penalized_hessian.clone().into(),
2641 working_weights: final_eval.obs.fisherweight.clone(),
2642 working_response: {
2643 let mut out = final_eval.obs.eta.clone();
2644 for i in 0..out.len() {
2645 let wi = final_eval.obs.fisherweight[i].max(1e-12);
2646 out[i] += final_eval.obs.score[i] / wi;
2647 }
2648 out
2649 },
2650 reparam_qs: None,
2651 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2652 beta_covariance: beta_covariance
2653 .clone()
2654 .map(gam_problem::dispersion_cov::PhiScaledCovariance::from),
2655 beta_standard_errors,
2656 beta_covariance_corrected: None,
2657 beta_standard_errors_corrected: None,
2658 beta_covariance_frequentist: None,
2659 coefficient_influence: None,
2660 weighted_gram: None,
2661 bias_correction_beta: None,
2662 };
2663 let geometry = Some(gam_solve::estimate::FitGeometry {
2664 penalized_hessian: penalized_hessian.into(),
2665 working_weights: inf.working_weights.clone(),
2666 working_response: inf.working_response.clone(),
2667 });
2668 let covariance_conditional = beta_covariance;
2669 let pirls_status_val = if final_fit.outer_converged {
2670 gam_solve::pirls::PirlsStatus::Converged
2671 } else {
2672 gam_solve::pirls::PirlsStatus::StalledAtValidMinimum
2673 };
2674 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
2675 blocks: vec![gam_solve::estimate::FittedBlock {
2676 beta: beta.clone(),
2677 role: gam_problem::BlockRole::Mean,
2678 edf: edf_total,
2679 lambdas: full_lambdas.clone(),
2680 }],
2681 log_lambdas,
2682 lambdas: full_lambdas,
2683 likelihood_scale: family.default_scale_metadata(),
2684 likelihood_family: Some(family),
2685 log_likelihood_normalization:
2686 gam_spec::LogLikelihoodNormalization::UserProvided,
2687 log_likelihood: final_eval.obs.log_likelihood,
2688 deviance,
2689 reml_score: final_fit.penalized_objective,
2690 stable_penalty_term,
2691 penalized_objective: final_fit.penalized_objective,
2692 used_device: false,
2693 outer_iterations,
2694 outer_converged: final_fit.outer_converged,
2695 outer_gradient_norm: outer_grad_norm,
2696 standard_deviation,
2697 covariance_conditional,
2698 covariance_corrected: None,
2699 inference: Some(inf),
2700 fitted_link,
2701 geometry,
2702 block_states: Vec::new(),
2703 pirls_status: pirls_status_val,
2704 max_abs_eta,
2705 constraint_kkt: None,
2706 artifacts: gam_solve::estimate::FitArtifacts {
2707 pirls: None,
2708 ..Default::default()
2709 },
2710 inner_cycles: 0,
2711 })?
2712 },
2713 design: baseline.design,
2714 adaptive_diagnostics: Some(AdaptiveRegularizationDiagnostics {
2715 epsilon_0: eps_star[0],
2716 epsilon_g: eps_star[1],
2717 epsilon_c: eps_star[2],
2718 epsilon_outer_iterations: outer_iterations,
2719 mm_iterations: 0,
2720 converged: final_fit.outer_converged,
2721 maps,
2722 }),
2723 };
2724 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
2725 Ok(fitted)
2726}
2727
2728fn relax_smoothing_rho_prior(
2760 options: &FitOptions,
2761 design: &TermCollectionDesign,
2762) -> gam_spec::RhoPrior {
2763 use gam_terms::basis::BasisMetadata;
2764 let base = &options.rho_prior;
2765 if matches!(
2768 base,
2769 gam_spec::RhoPrior::Flat | gam_spec::RhoPrior::Independent(_)
2770 ) {
2771 return base.clone();
2772 }
2773 let has_link_aux = options.sas_link.is_some()
2793 || options.optimize_sas
2794 || options.mixture_link.is_some()
2795 || options.optimize_mixture;
2796 let has_moving_kappa = design.smooth.terms.iter().any(|t| {
2797 matches!(
2798 t.metadata,
2799 BasisMetadata::Matern { .. }
2800 | BasisMetadata::Duchon { .. }
2801 | BasisMetadata::Sphere { .. }
2802 | BasisMetadata::SphereHarmonics { .. }
2803 | BasisMetadata::ConstantCurvature { .. }
2804 | BasisMetadata::MeasureJet { .. }
2805 )
2806 });
2807 let length_safe = !has_link_aux && !has_moving_kappa;
2814 if !length_safe {
2815 return base.clone();
2816 }
2817 let coords = &design.penaltyinfo;
2818 if coords.is_empty() {
2819 return base.clone();
2820 }
2821 let n_obs = design.design.nrows();
2832 let p_total = design.design.ncols();
2833 let underdetermined = n_obs < 2 * p_total;
2864 let relaxable_terms: std::collections::HashSet<&str> = design
2876 .smooth
2877 .terms
2878 .iter()
2879 .filter(|t| {
2880 matches!(
2881 t.metadata,
2882 BasisMetadata::BSpline1D { .. }
2883 | BasisMetadata::ThinPlate { .. }
2884 | BasisMetadata::TensorBSpline { .. }
2885 )
2886 && matches!(t.shape, gam_terms::smooth::ShapeConstraint::None)
2900 })
2901 .map(|t| t.name.as_str())
2902 .collect();
2903 let any_relaxed = coords.iter().any(|info| {
2904 info.termname
2905 .as_deref()
2906 .is_some_and(|name| relaxable_terms.contains(name))
2907 });
2908 if !any_relaxed {
2909 return base.clone();
2910 }
2911 let relaxed_prior = if underdetermined {
2916 gam_spec::RhoPrior::Normal {
2917 mean: 0.0,
2918 sd: RELAX_UNDERDETERMINED_RHO_SD,
2919 }
2920 } else {
2921 gam_spec::RhoPrior::Flat
2922 };
2923 let nullspace_select_prior = gam_spec::RhoPrior::PenalizedComplexity {
2950 upper: NULLSPACE_SELECT_PC_UPPER,
2951 tail_prob: NULLSPACE_SELECT_PC_TAIL_PROB,
2952 };
2953 let nullspace_degeneracy_prior = gam_spec::RhoPrior::Normal {
2980 mean: 0.0,
2981 sd: NULLSPACE_WELLDET_DEGENERACY_RHO_SD,
2982 };
2983 let per_coord = coords
2984 .iter()
2985 .map(|info| {
2986 let relax = info
2987 .termname
2988 .as_deref()
2989 .is_some_and(|name| relaxable_terms.contains(name));
2990 if !relax {
2991 return base.clone();
2992 }
2993 let is_nullspace =
2994 matches!(info.penalty.source, PenaltySource::DoublePenaltyNullspace);
2995 if is_nullspace {
3034 if underdetermined {
3035 nullspace_select_prior.clone()
3036 } else {
3037 nullspace_degeneracy_prior.clone()
3038 }
3039 } else {
3040 relaxed_prior.clone()
3041 }
3042 })
3043 .collect::<Vec<_>>();
3044 gam_spec::RhoPrior::Independent(per_coord)
3045}
3046
/// Standard deviation of the zero-mean normal `rho` prior that `relax_smoothing_rho_prior`
/// assigns to relaxable, non-nullspace penalty coordinates when the design is
/// underdetermined (`n_obs < 2 * p_total`).
const RELAX_UNDERDETERMINED_RHO_SD: f64 = 1.5e1;

/// `upper` parameter of the penalized-complexity prior used for double-penalty nullspace
/// coordinates of relaxable terms in underdetermined designs.
const NULLSPACE_SELECT_PC_UPPER: f64 = 5e-2;

/// `tail_prob` parameter of the penalized-complexity prior used for double-penalty nullspace
/// coordinates of relaxable terms in underdetermined designs.
const NULLSPACE_SELECT_PC_TAIL_PROB: f64 = 1e-2;
3090
3091fn adaptive_fit_options_base(options: &FitOptions, design: &TermCollectionDesign) -> FitOptions {
3092 FitOptions {
3093 latent_cloglog: options.latent_cloglog,
3094 mixture_link: options.mixture_link.clone(),
3095 optimize_mixture: options.optimize_mixture,
3096 sas_link: options.sas_link,
3097 optimize_sas: options.optimize_sas,
3098 compute_inference: options.compute_inference,
3099 skip_rho_posterior_inference: options.skip_rho_posterior_inference,
3100 max_iter: options.max_iter,
3101 tol: options.tol,
3102 nullspace_dims: design.nullspace_dims.clone(),
3103 linear_constraints: design.linear_constraints.clone(),
3104 firth_bias_reduction: options.firth_bias_reduction,
3105 adaptive_regularization: None,
3106 penalty_shrinkage_floor: options.penalty_shrinkage_floor,
3107 rho_prior: options.rho_prior.clone(),
3110 kronecker_penalty_system: design.kronecker_penalty_system(),
3111 kronecker_factored: design
3112 .smooth
3113 .terms
3114 .iter()
3115 .find_map(|t| t.kronecker_factored.clone()),
3116 persist_warm_start_disk: options.persist_warm_start_disk,
3117 }
3118}
3119
3120fn superseded_fit_options(options: &FitOptions) -> FitOptions {
3121 let mut fit_options = options.clone();
3122 fit_options.skip_rho_posterior_inference = true;
3123 fit_options
3124}
3125
/// Metadata for one linear coefficient constrained to the interval `[min, max]`.
///
/// `BoundedEffectiveJacobian` passes `min` / `max` to `bounded_latent_derivatives`, which maps
/// a latent `theta` to `min + (max - min) * sigmoid(theta)`.
#[derive(Clone)]
struct BoundedLinearTermMeta {
    /// Column of the design matrix (and entry of the coefficient vector) that is bounded.
    col_idx: usize,
    /// Lower bound of the coefficient.
    min: f64,
    /// Upper bound of the coefficient.
    max: f64,
    /// Prior specification for this coefficient (presumably consumed via
    /// `bounded_prior_terms` — confirm at the use site).
    prior: BoundedCoefficientPriorSpec,
}
3133
/// Effective-Jacobian provider for a design containing bounded linear coefficients.
///
/// The Jacobian is the design matrix with each bounded column rescaled by `d beta / d theta`
/// (see the `BlockEffectiveJacobian` impl).
struct BoundedEffectiveJacobian {
    /// Full design matrix; row ranges are sliced out on demand.
    design: Array2<f64>,
    /// Bounded columns whose Jacobian entries are rescaled.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
3161
3162impl BlockEffectiveJacobian for BoundedEffectiveJacobian {
3163 fn effective_jacobian_rows(
3164 &self,
3165 state: &FamilyLinearizationState<'_>,
3166 rows: std::ops::Range<usize>,
3167 ) -> Result<Array2<f64>, String> {
3168 let p = self.design.ncols();
3169 let n = self.design.nrows();
3170 let rows = rows.start.min(n)..rows.end.min(n);
3171 if !state.beta.is_empty() {
3172 if state.beta.len() != p {
3173 return Err(format!(
3174 "BoundedEffectiveJacobian::effective_jacobian_at: beta length {} != design \
3175 ncols {p}",
3176 state.beta.len(),
3177 ));
3178 }
3179 if state.beta.iter().any(|v| v.is_nan()) {
3180 return Err(
3181 "BoundedEffectiveJacobian::effective_jacobian_at: beta contains NaN"
3182 .to_string(),
3183 );
3184 }
3185 }
3186 let mut jac = self
3187 .design
3188 .slice(ndarray::s![rows.start..rows.end, ..])
3189 .to_owned();
3190 for term in &self.bounded_terms {
3191 let theta = if state.beta.is_empty() {
3192 0.0
3193 } else {
3194 state.beta[term.col_idx]
3195 };
3196 let (_, _, db_dtheta, _, _) = bounded_latent_derivatives(theta, term.min, term.max);
3197 jac.column_mut(term.col_idx).mapv_inplace(|v| v * db_dtheta);
3198 }
3199 Ok(jac)
3200 }
3201}
3202
/// Likelihood family for fits containing bounded linear coefficients.
///
/// NOTE(review): the behaviour of this type lives in impls outside this view. The field notes
/// below are partly inferred from names — confirm before relying on them.
#[derive(Clone)]
struct BoundedLinearFamily {
    /// Response family and link.
    family: LikelihoodSpec,
    // Optional link-parameter states (latent cloglog / mixture / SAS); presumably forwarded to
    // `evaluate_standard_familyobservations`, which takes the same three options.
    latent_cloglog_state: Option<LatentCLogLogState>,
    mixture_link_state: Option<MixtureLinkState>,
    sas_link_state: Option<SasLinkState>,
    /// Response values.
    y: Array1<f64>,
    /// Prior observation weights.
    weights: Array1<f64>,
    /// Design matrix.
    design: Array2<f64>,
    /// Presumably `design` with the bounded columns zeroed — TODO confirm.
    designzeroed: Array2<f64>,
    /// Linear-predictor offset (presumably added to `design * beta`).
    offset: Array1<f64>,
    /// Metadata for each bounded coefficient.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
3216
/// Per-observation likelihood quantities on the linear-predictor scale, as produced by
/// `evaluate_standard_familyobservations`.
#[derive(Clone)]
struct StandardFamilyObservationState {
    /// Linear predictor the state was evaluated at.
    eta: Array1<f64>,
    /// Fitted mean of each observation.
    mu: Array1<f64>,
    /// Weighted score `d loglik_i / d eta_i`.
    score: Array1<f64>,
    /// Expected (Fisher) information on the eta scale, floored at a small positive value.
    fisherweight: Array1<f64>,
    /// Observed information `-d^2 loglik_i / d eta_i^2`.
    neghessian_eta: Array1<f64>,
    /// Derivative of `neghessian_eta` with respect to `eta_i`.
    neghessian_eta_derivative: Array1<f64>,
    /// Weighted log-likelihood summed over observations (family constants may be omitted).
    log_likelihood: f64,
}
3227
/// Logit of `z` after clamping it to `[1e-12, 1 - 1e-12]`, so the result is finite for any
/// non-NaN input.
fn bounded_logit(z: f64) -> f64 {
    const EDGE: f64 = 1e-12;
    let prob = z.clamp(EDGE, 1.0 - EDGE);
    let odds = prob / (1.0 - prob);
    odds.ln()
}
3232
/// Logistic sigmoid `1 / (1 + exp(-theta))`.
///
/// It is evaluated so that `exp` only ever sees a non-positive argument and never overflows.
fn stable_sigmoid(theta: f64) -> f64 {
    // exp(-|theta|) is exp(-theta) for theta >= 0 and exp(theta) otherwise.
    let decay = (-theta.abs()).exp();
    if theta >= 0.0 {
        1.0 / (1.0 + decay)
    } else {
        decay / (1.0 + decay)
    }
}
3242
3243fn bounded_latent_to_user(theta: f64, min: f64, max: f64) -> (f64, f64, f64) {
3244 let z = stable_sigmoid(theta);
3245 let width = max - min;
3246 let beta = min + width * z;
3247 let db_dtheta = width * z * (1.0 - z);
3248 (beta, z, db_dtheta)
3249}
3250
3251fn bounded_user_to_latent(beta: f64, min: f64, max: f64) -> f64 {
3262 let width = max - min;
3263 if width <= 0.0 || !width.is_finite() {
3264 return 0.0;
3265 }
3266 let z = (beta - min) / width;
3267 bounded_logit(z)
3268}
3269
/// A coefficient column constrained to `[min, max]`, as consumed by
/// `sample_bounded_latent_posterior_internal`.
#[derive(Debug, Clone, Copy)]
pub struct BoundedSampleColumn {
    /// Index of the bounded coefficient; must be less than the coefficient count.
    pub col_idx: usize,
    /// Lower bound on the user scale.
    pub min: f64,
    /// Upper bound on the user scale.
    pub max: f64,
}
3282
3283pub fn sample_bounded_latent_posterior_internal(
3321 beta_user: &Array1<f64>,
3322 user_hessian: &Array2<f64>,
3323 bounded_columns: &[BoundedSampleColumn],
3324 n_draws: usize,
3325 sqrt_cov_scale: f64,
3326 base_seed: u64,
3327) -> Result<Array2<f64>, EstimationError> {
3328 let p = beta_user.len();
3329 if user_hessian.nrows() != p || user_hessian.ncols() != p {
3330 crate::bail_invalid_estim!(
3331 "bounded posterior sampling dimension mismatch: mode has {p} entries, user Hessian is {}x{}",
3332 user_hessian.nrows(),
3333 user_hessian.ncols()
3334 );
3335 }
3336
3337 let mut theta_mode = beta_user.clone();
3339 let mut jac_diag = Array1::<f64>::ones(p);
3340 for bc in bounded_columns {
3341 if bc.col_idx >= p {
3342 crate::bail_invalid_estim!(
3343 "bounded posterior sampling: bounded column index {} out of range for {p} coefficients",
3344 bc.col_idx
3345 );
3346 }
3347 let theta_i = bounded_user_to_latent(beta_user[bc.col_idx], bc.min, bc.max);
3348 let (_, _, db_dtheta) = bounded_latent_to_user(theta_i, bc.min, bc.max);
3349 theta_mode[bc.col_idx] = theta_i;
3350 jac_diag[bc.col_idx] = db_dtheta.max(1e-12);
3355 }
3356
3357 let mut h_latent = user_hessian.clone();
3360 for i in 0..p {
3361 let ji = jac_diag[i];
3362 if ji != 1.0 {
3363 h_latent.row_mut(i).mapv_inplace(|v| v * ji);
3364 h_latent.column_mut(i).mapv_inplace(|v| v * ji);
3365 }
3366 }
3367
3368 use gam_linalg::faer_ndarray::FaerCholesky as _;
3371 use rand::SeedableRng as _;
3372 let chol = h_latent.cholesky(faer::Side::Lower).map_err(|err| {
3373 EstimationError::InvalidInput(format!(
3374 "bounded posterior sampling: Cholesky of the latent penalized Hessian failed: {err:?}"
3375 ))
3376 })?;
3377 let l = chol.lower_triangular();
3378
3379 let mut draws = Array2::<f64>::zeros((n_draws, p));
3380 let mut eps = Array1::<f64>::zeros(p);
3381 let mut delta = Array1::<f64>::zeros(p);
3382 let mut rng = rand::rngs::StdRng::seed_from_u64(base_seed);
3383 for k in 0..n_draws {
3384 for e in eps.iter_mut() {
3385 *e = standard_normal_draw(&mut rng);
3386 }
3387 solve_lower_transpose_into(&l, &eps, &mut delta);
3388 for i in 0..p {
3389 draws[(k, i)] = theta_mode[i] + sqrt_cov_scale * delta[i];
3392 }
3393 for bc in bounded_columns {
3396 let (beta_draw, _, _) = bounded_latent_to_user(draws[(k, bc.col_idx)], bc.min, bc.max);
3397 draws[(k, bc.col_idx)] = beta_draw;
3398 }
3399 }
3400
3401 Ok(draws)
3402}
3403
3404#[inline]
3407fn standard_normal_draw<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
3408 use rand::RngExt as _;
3409 let u1 = rng.random::<f64>().max(1e-16);
3410 let u2 = rng.random::<f64>();
3411 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
3412}
3413
3414fn solve_lower_transpose_into(l: &Array2<f64>, b: &Array1<f64>, out: &mut Array1<f64>) {
3418 let p = l.nrows();
3419 for i in (0..p).rev() {
3420 let mut acc = b[i];
3421 for j in (i + 1)..p {
3422 acc -= l[(j, i)] * out[j];
3423 }
3424 let diag = l[(i, i)];
3425 out[i] = if diag.abs() > 0.0 { acc / diag } else { 0.0 };
3426 }
3427}
3428
3429fn bounded_latent_derivatives(theta: f64, min: f64, max: f64) -> (f64, f64, f64, f64, f64) {
3430 let z = stable_sigmoid(theta);
3431 let width = max - min;
3432 let s = z * (1.0 - z);
3433 let beta = min + width * z;
3434 let db_dtheta = width * s;
3435 let d2b_dtheta2 = width * s * (1.0 - 2.0 * z);
3436 let d3b_dtheta3 = width * s * (1.0 - 6.0 * z + 6.0 * z * z);
3437 (beta, z, db_dtheta, d2b_dtheta2, d3b_dtheta3)
3438}
3439
3440fn bounded_prior_terms(theta: f64, prior: &BoundedCoefficientPriorSpec) -> (f64, f64, f64, f64) {
3441 let (a, b) = match prior {
3442 BoundedCoefficientPriorSpec::None => return (0.0, 0.0, 0.0, 0.0),
3444 BoundedCoefficientPriorSpec::Uniform => (1.0, 1.0),
3447 BoundedCoefficientPriorSpec::Beta { a, b } => (*a, *b),
3448 };
3449 let z = stable_sigmoid(theta).clamp(1e-12, 1.0 - 1e-12);
3450 let logp = a * z.ln() + b * (1.0 - z).ln();
3451 let grad = a - (a + b) * z;
3452 let neghess = (a + b) * z * (1.0 - z);
3453 let neghess_derivative = (a + b) * z * (1.0 - z) * (1.0 - 2.0 * z);
3454 (logp, grad, neghess, neghess_derivative)
3455}
3456
/// Turns mean-scale log-likelihood derivatives into eta-scale quantities for one observation.
///
/// Inputs are the mean-scale derivatives `lmu`, `lmumu`, `lmumumu` and the inverse-link
/// derivatives `d1 = dmu/deta`, `d2`, `d3`. Returns
/// `(score, fisherweight, neghessian, neghessian_deriv)`, where, for prior weight `w`:
/// - `score = w * lmu * d1`;
/// - `fisherweight = max(w * d1^2 / var, mu_deriv_eps)` (expected information);
/// - `neghessian = -w * (lmumu * d1^2 + lmu * d2)` (observed information);
/// - `neghessian_deriv` is the derivative of `neghessian` with respect to eta.
#[inline]
fn glm_eta_observation_state(
    w: f64,
    lmu: f64,
    lmumu: f64,
    lmumumu: f64,
    var: f64,
    d1: f64,
    d2: f64,
    d3: f64,
    mu_deriv_eps: f64,
) -> (f64, f64, f64, f64) {
    // Chain rule in eta; products are written left-to-right exactly as in the formulas above.
    let eta_curvature = lmumu * d1 * d1 + lmu * d2;
    let eta_third = lmumumu * d1 * d1 * d1 + 3.0 * lmumu * d1 * d2 + lmu * d3;
    let expected_info = w * d1 * d1 / var;
    (
        w * lmu * d1,
        expected_info.max(mu_deriv_eps),
        -w * eta_curvature,
        -w * eta_third,
    )
}
3484fn evaluate_standard_familyobservations(
3485 family: LikelihoodSpec,
3486 latent_cloglog_state: Option<&LatentCLogLogState>,
3487 mixture_link_state: Option<&MixtureLinkState>,
3488 sas_link_state: Option<&SasLinkState>,
3489 y: &Array1<f64>,
3490 weights: &Array1<f64>,
3491 eta: &Array1<f64>,
3492) -> Result<StandardFamilyObservationState, EstimationError> {
3493 const PROB_EPS: f64 = 1e-10;
3494 const MU_DERIV_EPS: f64 = 1e-12;
3495 let n = y.len();
3496 if weights.len() != n || eta.len() != n {
3497 crate::bail_invalid_estim!("bounded family observation size mismatch");
3498 }
3499
3500 let mut mu = Array1::<f64>::zeros(n);
3501 let mut score = Array1::<f64>::zeros(n);
3502 let mut fisherweight = Array1::<f64>::zeros(n);
3503 let mut neghessian_eta = Array1::<f64>::zeros(n);
3504 let mut neghessian_eta_derivative = Array1::<f64>::zeros(n);
3505 let mut log_likelihood = 0.0;
3506
3507 for i in 0..n {
3508 let w = weights[i].max(0.0);
3509 let yi = y[i];
3510 let eta_i = eta[i];
3511 match (&family.response, &family.link) {
3512 (ResponseFamily::Gaussian, _) => {
3513 let resid = yi - eta_i;
3514 mu[i] = eta_i;
3515 score[i] = w * resid;
3516 fisherweight[i] = w.max(MU_DERIV_EPS);
3517 neghessian_eta[i] = w;
3518 neghessian_eta_derivative[i] = 0.0;
3519 log_likelihood += -0.5 * w * resid * resid;
3520 }
3521 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
3522 let jet = logit_inverse_link_jet5(eta_i);
3523 mu[i] = jet.mu;
3524 score[i] = w * (yi - jet.mu);
3525 fisherweight[i] = jet.d1.max(MU_DERIV_EPS);
3526 neghessian_eta[i] = jet.d1;
3527 neghessian_eta_derivative[i] = jet.d2;
3528 let logmu = -gam_linalg::utils::stable_softplus(-eta_i);
3529 let log_one_minusmu = -gam_linalg::utils::stable_softplus(eta_i);
3530 log_likelihood += w * (yi * logmu + (1.0 - yi) * log_one_minusmu);
3531 }
3532 (ResponseFamily::Binomial, _) => {
3533 let inverse_link = if let Some(state) = latent_cloglog_state {
3534 Some(InverseLink::LatentCLogLog(*state))
3535 } else if let Some(state) = mixture_link_state {
3536 Some(InverseLink::Mixture(state.clone()))
3537 } else {
3538 sas_link_state.map(|state| {
3539 if family.is_binomial_beta_logistic() {
3540 InverseLink::BetaLogistic(*state)
3541 } else {
3542 InverseLink::Sas(*state)
3543 }
3544 })
3545 };
3546 let strategy_spec = LikelihoodSpec {
3547 response: family.response.clone(),
3548 link: inverse_link.clone().unwrap_or_else(|| family.link.clone()),
3549 };
3550 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3551 let mu_i_raw = jet.mu;
3552 let dmu_deta_raw = jet.d1;
3553 let mu_i: f64 = mu_i_raw.clamp(PROB_EPS, 1.0 - PROB_EPS);
3554 let dmu_deta = dmu_deta_raw.max(MU_DERIV_EPS);
3555 let d2mu_deta2 = jet.d2;
3556 let d3mu_deta3 = jet.d3;
3557 let var = (mu_i * (1.0 - mu_i)).max(PROB_EPS);
3558 let lmu = (yi - mu_i) / var;
3559 let lmumu = -(yi / (mu_i * mu_i)) - ((1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i)));
3560 let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i)
3561 - 2.0 * (1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i) * (1.0 - mu_i));
3562 mu[i] = mu_i;
3563 score[i] = w * lmu * dmu_deta;
3564 fisherweight[i] = (w * dmu_deta * dmu_deta / var).max(MU_DERIV_EPS);
3565 neghessian_eta[i] = -w * (lmumu * dmu_deta * dmu_deta + lmu * d2mu_deta2);
3566 neghessian_eta_derivative[i] = -w
3567 * (lmumumu * dmu_deta * dmu_deta * dmu_deta
3568 + 3.0 * lmumu * dmu_deta * d2mu_deta2
3569 + lmu * d3mu_deta3);
3570 log_likelihood += w * (yi * mu_i.ln() + (1.0 - yi) * (1.0 - mu_i).ln());
3571 }
3572 (ResponseFamily::Poisson, _) => {
3573 let strategy_spec = LikelihoodSpec {
3576 response: family.response.clone(),
3577 link: family.link.clone(),
3578 };
3579 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3580 let mu_i = jet.mu.max(PROB_EPS);
3581 let d1 = jet.d1.max(MU_DERIV_EPS);
3582 let var = mu_i;
3583 let lmu = yi / mu_i - 1.0;
3584 let lmumu = -yi / (mu_i * mu_i);
3585 let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i);
3586 let (s, f, nh, nhd) = glm_eta_observation_state(
3587 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3588 );
3589 mu[i] = mu_i;
3590 score[i] = s;
3591 fisherweight[i] = f;
3592 neghessian_eta[i] = nh;
3593 neghessian_eta_derivative[i] = nhd;
3594 log_likelihood += w * (yi * mu_i.ln() - mu_i);
3595 }
3596 (ResponseFamily::Tweedie { p }, _) => {
3597 let p = *p;
3602 let strategy_spec = LikelihoodSpec {
3603 response: family.response.clone(),
3604 link: family.link.clone(),
3605 };
3606 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3607 let mu_i = jet.mu.max(PROB_EPS);
3608 let d1 = jet.d1.max(MU_DERIV_EPS);
3609 let var = mu_i.powf(p);
3610 let resid = yi - mu_i;
3611 let lmu = resid / var;
3612 let lmumu = -mu_i.powf(-p) - p * resid * mu_i.powf(-p - 1.0);
3613 let lmumumu =
3614 2.0 * p * mu_i.powf(-p - 1.0) + p * (p + 1.0) * resid * mu_i.powf(-p - 2.0);
3615 let (s, f, nh, nhd) = glm_eta_observation_state(
3616 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3617 );
3618 mu[i] = mu_i;
3619 score[i] = s;
3620 fisherweight[i] = f;
3621 neghessian_eta[i] = nh;
3622 neghessian_eta_derivative[i] = nhd;
3623 log_likelihood += w
3625 * (yi * mu_i.powf(1.0 - p) / (1.0 - p) - mu_i.powf(2.0 - p) / (2.0 - p));
3626 }
3627 (ResponseFamily::NegativeBinomial { theta, .. }, _) => {
3628 let theta = (*theta).max(PROB_EPS);
3632 let strategy_spec = LikelihoodSpec {
3633 response: family.response.clone(),
3634 link: family.link.clone(),
3635 };
3636 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3637 let mu_i = jet.mu.max(PROB_EPS);
3638 let d1 = jet.d1.max(MU_DERIV_EPS);
3639 let mu_plus = mu_i + theta;
3640 let var = mu_i + mu_i * mu_i / theta;
3641 let lmu = yi / mu_i - (yi + theta) / mu_plus;
3642 let lmumu = -yi / (mu_i * mu_i) + (yi + theta) / (mu_plus * mu_plus);
3643 let lmumumu =
3644 2.0 * yi / (mu_i * mu_i * mu_i) - 2.0 * (yi + theta) / (mu_plus * mu_plus * mu_plus);
3645 let (s, f, nh, nhd) = glm_eta_observation_state(
3646 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3647 );
3648 mu[i] = mu_i;
3649 score[i] = s;
3650 fisherweight[i] = f;
3651 neghessian_eta[i] = nh;
3652 neghessian_eta_derivative[i] = nhd;
3653 log_likelihood += w * (yi * mu_i.ln() - (yi + theta) * mu_plus.ln());
3654 }
3655 (ResponseFamily::Beta { .. }, _) => {
3656 crate::bail_invalid_estim!(
3657 "bounded linear terms are not supported for BetaLogit fits"
3658 );
3659 }
3660 (ResponseFamily::Gamma, _) => {
3661 let strategy_spec = LikelihoodSpec {
3665 response: family.response.clone(),
3666 link: family.link.clone(),
3667 };
3668 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3669 let mu_i = jet.mu.max(PROB_EPS);
3670 let d1 = jet.d1.max(MU_DERIV_EPS);
3671 let var = mu_i * mu_i;
3672 let lmu = yi / (mu_i * mu_i) - 1.0 / mu_i;
3673 let lmumu = -2.0 * yi / (mu_i * mu_i * mu_i) + 1.0 / (mu_i * mu_i);
3674 let lmumumu =
3675 6.0 * yi / (mu_i * mu_i * mu_i * mu_i) - 2.0 / (mu_i * mu_i * mu_i);
3676 let (s, f, nh, nhd) = glm_eta_observation_state(
3677 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3678 );
3679 mu[i] = mu_i;
3680 score[i] = s;
3681 fisherweight[i] = f;
3682 neghessian_eta[i] = nh;
3683 neghessian_eta_derivative[i] = nhd;
3684 log_likelihood += w * (-(yi / mu_i) - mu_i.ln());
3685 }
3686 (ResponseFamily::RoystonParmar, _) => {
3687 crate::bail_invalid_estim!(
3688 "bounded linear terms are not supported for survival model fits"
3689 );
3690 }
3691 }
3692 }
3693
3694 Ok(StandardFamilyObservationState {
3695 eta: eta.clone(),
3696 mu,
3697 score,
3698 fisherweight,
3699 neghessian_eta,
3700 neghessian_eta_derivative,
3701 log_likelihood,
3702 })
3703}
3704
/// One hyperparameter of a spatially adaptive term.
///
/// Each variant is either a log smoothing parameter (`LogLambda*`) or a log epsilon
/// (`LogEpsilon*`) for the magnitude, gradient or curvature component.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveHyperKind {
    LogLambdaMagnitude,
    LogLambdaGradient,
    LogLambdaCurvature,
    LogEpsilonMagnitude,
    LogEpsilonGradient,
    LogEpsilonCurvature,
}

impl SpatialAdaptiveHyperKind {
    /// Component this hyperparameter acts on: 0 = magnitude, 1 = gradient, 2 = curvature.
    fn component_index(self) -> usize {
        match self {
            Self::LogLambdaMagnitude | Self::LogEpsilonMagnitude => 0,
            Self::LogLambdaGradient | Self::LogEpsilonGradient => 1,
            Self::LogLambdaCurvature | Self::LogEpsilonCurvature => 2,
        }
    }

    /// `true` for the three log-lambda variants.
    fn is_log_lambda(self) -> bool {
        match self {
            Self::LogLambdaMagnitude | Self::LogLambdaGradient | Self::LogLambdaCurvature => true,
            Self::LogEpsilonMagnitude | Self::LogEpsilonGradient | Self::LogEpsilonCurvature => {
                false
            }
        }
    }

    /// `true` for the three log-epsilon variants.
    fn is_log_epsilon(self) -> bool {
        match self {
            Self::LogEpsilonMagnitude | Self::LogEpsilonGradient | Self::LogEpsilonCurvature => {
                true
            }
            Self::LogLambdaMagnitude | Self::LogLambdaGradient | Self::LogLambdaCurvature => false,
        }
    }
}
3745
/// Identifies one outer hyperparameter of the spatially adaptive fit.
#[derive(Clone, Copy, Debug)]
struct SpatialAdaptiveHyperSpec {
    /// Index of the adaptive term's runtime cache this hyperparameter belongs to
    /// (presumably an index into the runtime-cache list — confirm at construction site).
    cache_index: usize,
    /// Which hyperparameter (log lambda / log epsilon, and which component).
    kind: SpatialAdaptiveHyperKind,
}
3751
/// Classification of a pair of hyperparameters, as returned by
/// `SpatialAdaptiveHyperSpec::explicit_second_order_kind`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveExplicitSecondOrderKind {
    /// The pair acts on different components, or is two log-lambdas from different caches.
    StructuralZero,
    /// Two log-lambdas of the same component and the same cache.
    LocalAlphaAlpha,
    /// One log-lambda and one log-epsilon of the same component (cache index not compared).
    LocalAlphaEta,
    /// Two log-epsilons of the same component (cache index not compared).
    SharedEtaEta,
}
3759
3760#[derive(Clone, Copy, Debug, PartialEq, Eq)]
3765enum AdaptiveComponent {
3766 Magnitude,
3767 Gradient,
3768 Curvature,
3769}
3770
3771impl AdaptiveComponent {
3772 fn from_index(index: usize) -> Result<Self, String> {
3773 match index {
3774 0 => Ok(AdaptiveComponent::Magnitude),
3775 1 => Ok(AdaptiveComponent::Gradient),
3776 2 => Ok(AdaptiveComponent::Curvature),
3777 other => Err(SmoothError::invalid_index(format!(
3778 "invalid adaptive component index {}",
3779 other
3780 ))
3781 .into()),
3782 }
3783 }
3784}
3785
/// Kind of hyperparameter derivative being requested.
///
/// NOTE(review): the consumers of this enum are outside this view. The variant notes below
/// are inferred from their names — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HyperDerivativeKind {
    /// Derivative with respect to a log smoothing parameter (`rho`).
    Rho,
    /// Presumably a first derivative with respect to a log epsilon.
    LogEpsilonFirst,
    /// Presumably a second derivative with respect to a log epsilon.
    LogEpsilonSecond,
}
3799
/// Hyperparameter family that a drift term is taken with respect to.
///
/// NOTE(review): the consumers of this enum are outside this view; the meaning of "drift"
/// should be confirmed there.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HyperDriftKind {
    /// A log smoothing parameter.
    Rho,
    /// A log epsilon.
    LogEpsilon,
}
3809
3810impl SpatialAdaptiveHyperSpec {
3811 fn component_index(self) -> usize {
3812 self.kind.component_index()
3813 }
3814
3815 fn explicit_second_order_kind(self, other: Self) -> SpatialAdaptiveExplicitSecondOrderKind {
3816 if self.component_index() != other.component_index() {
3817 return SpatialAdaptiveExplicitSecondOrderKind::StructuralZero;
3818 }
3819 match (
3820 self.kind.is_log_lambda(),
3821 other.kind.is_log_lambda(),
3822 self.kind.is_log_epsilon(),
3823 other.kind.is_log_epsilon(),
3824 ) {
3825 (true, true, false, false) if self.cache_index == other.cache_index => {
3826 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha
3827 }
3828 (true, false, false, true) | (false, true, true, false) => {
3829 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta
3830 }
3831 (false, false, true, true) => SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta,
3832 _ => SpatialAdaptiveExplicitSecondOrderKind::StructuralZero,
3833 }
3834 }
3835}
3836
/// Current hyperparameter values for one spatially adaptive term.
#[derive(Clone, Debug)]
struct SpatialAdaptiveTermHyperParams {
    /// Smoothing parameters for the magnitude (`[0]`, mass penalty / `d0`), gradient
    /// (`[1]`, tension penalty / `d1`) and curvature (`[2]`, stiffness penalty / `d2`)
    /// components.
    lambda: [f64; 3],
    /// Per-component epsilon values, presumably indexed the same way as `lambda` — confirm.
    epsilon: [f64; 3],
}
3842
/// Likelihood and penalty quantities of the spatially adaptive objective at one coefficient
/// vector.
#[derive(Clone)]
struct SpatialAdaptiveExactEvaluation {
    /// Per-observation likelihood quantities on the eta scale.
    obs: StandardFamilyObservationState,
    /// Exact adaptive-penalty state, one entry per adaptive term's runtime cache.
    adaptive_states: Vec<SpatialPenaltyExactState>,
    /// Value of the adaptive (non-quadratic) penalty.
    adaptive_penalty_value: f64,
    /// Gradient of the adaptive penalty with respect to the coefficients.
    adaptive_penaltygradient: Array1<f64>,
    /// Hessian of the adaptive penalty with respect to the coefficients.
    adaptive_penaltyhessian: Array2<f64>,
    /// Value of the fixed quadratic penalty.
    fixed_quadraticvalue: f64,
    /// Gradient of the fixed quadratic penalty.
    fixed_quadraticgradient: Array1<f64>,
    /// Hessian of the fixed quadratic penalty.
    fixed_quadratichessian: Array2<f64>,
}
3854
/// A memoised exact evaluation together with the coefficient vector it belongs to.
#[derive(Clone)]
struct CachedSpatialAdaptiveExactEvaluation {
    /// Coefficients at which `eval` was computed (presumably the cache key — confirm).
    beta: Array1<f64>,
    /// Shared evaluation result.
    eval: Arc<SpatialAdaptiveExactEvaluation>,
}
3860
3861impl SpatialAdaptiveExactEvaluation {
3862 fn total_penalty_value(&self) -> f64 {
3863 self.adaptive_penalty_value + self.fixed_quadraticvalue
3864 }
3865
3866 fn total_penaltygradient(&self) -> Array1<f64> {
3867 &self.adaptive_penaltygradient + &self.fixed_quadraticgradient
3868 }
3869
3870 fn total_penaltyhessian(&self) -> Array2<f64> {
3871 &self.adaptive_penaltyhessian + &self.fixed_quadratichessian
3872 }
3873
3874 fn totalobjectivehessian(&self, design: &Array2<f64>) -> Result<Array2<f64>, String> {
3875 let mut out = xt_diag_x_dense(design.view(), self.obs.neghessian_eta.view())?;
3876 out += &self.total_penaltyhessian();
3877 Ok(out)
3878 }
3879}
3880
3881#[derive(Clone)]
3882struct SpatialAdaptiveExactFamily {
3883 family: LikelihoodSpec,
3884 latent_cloglog_state: Option<LatentCLogLogState>,
3885 mixture_link_state: Option<MixtureLinkState>,
3886 sas_link_state: Option<SasLinkState>,
3887 y: Arc<Array1<f64>>,
3888 weights: Arc<Array1<f64>>,
3889 design: Arc<Array2<f64>>,
3890 offset: Arc<Array1<f64>>,
3891 linear_constraints: Option<LinearInequalityConstraints>,
3892 runtime_caches: Arc<Vec<SpatialOperatorRuntimeCache>>,
3893 adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
3894 fixed_quadratichessian: Arc<Array2<f64>>,
3895 hyperspecs: Arc<Vec<SpatialAdaptiveHyperSpec>>,
3896 exact_eval_cache: Arc<Mutex<Option<CachedSpatialAdaptiveExactEvaluation>>>,
3897}
3898
3899impl SpatialAdaptiveExactFamily {
3900 fn with_adaptive_params(
3901 &self,
3902 adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
3903 fixed_quadratichessian: Arc<Array2<f64>>,
3904 ) -> Self {
3905 Self {
3906 family: self.family.clone(),
3907 latent_cloglog_state: self.latent_cloglog_state,
3908 mixture_link_state: self.mixture_link_state.clone(),
3909 sas_link_state: self.sas_link_state,
3910 y: self.y.clone(),
3911 weights: self.weights.clone(),
3912 design: self.design.clone(),
3913 offset: self.offset.clone(),
3914 linear_constraints: self.linear_constraints.clone(),
3915 runtime_caches: self.runtime_caches.clone(),
3916 adaptive_params,
3917 fixed_quadratichessian,
3918 hyperspecs: self.hyperspecs.clone(),
3919 exact_eval_cache: Arc::new(Mutex::new(None)),
3920 }
3921 }
3922
3923 fn total_eta(&self, beta: &Array1<f64>) -> Array1<f64> {
3924 gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), beta) + self.offset.as_ref()
3925 }
3926
3927 fn fixed_quadratic_terms(&self, beta: &Array1<f64>) -> (f64, Array1<f64>) {
3928 let grad = self.fixed_quadratichessian.dot(beta);
3929 let value = 0.5 * beta.dot(&grad);
3930 (value, grad)
3931 }
3932
3933 fn adaptive_penalty_value_only(&self, beta: &Array1<f64>) -> Result<f64, String> {
3934 let mut penalty_value = 0.0;
3935 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
3936 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
3937 format!(
3938 "missing adaptive parameter block for cache {}",
3939 cache.termname
3940 )
3941 })?;
3942 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
3943 let state =
3944 SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
3945 .map_err(|e| e.to_string())?;
3946 penalty_value += params.lambda[0] * state.magnitude.penalty_value();
3947 penalty_value += params.lambda[1] * state.gradient.penalty_value();
3948 penalty_value += params.lambda[2] * state.curvature.penalty_value();
3949 }
3950 Ok(penalty_value)
3951 }
3952
3953 fn zero_hyper_parts(&self) -> (Array1<f64>, Array2<f64>) {
3954 let total_dim = self.design.ncols();
3955 (
3956 Array1::<f64>::zeros(total_dim),
3957 Array2::<f64>::zeros((total_dim, total_dim)),
3958 )
3959 }
3960
3961 fn embed_local_hyper_parts(
3962 &self,
3963 coeff_range: &Range<usize>,
3964 local_grad: &Array1<f64>,
3965 local_hess: &Array2<f64>,
3966 ) -> (Array1<f64>, Array2<f64>) {
3967 let (mut beta_mixed, mut betahessian) = self.zero_hyper_parts();
3968 beta_mixed
3969 .slice_mut(s![coeff_range.clone()])
3970 .assign(local_grad);
3971 betahessian
3972 .slice_mut(s![coeff_range.clone(), coeff_range.clone()])
3973 .assign(local_hess);
3974 (beta_mixed, betahessian)
3975 }
3976
3977 fn embed_local_hyper_hessian(
3978 &self,
3979 coeff_range: &Range<usize>,
3980 local_hess: &Array2<f64>,
3981 ) -> Array2<f64> {
3982 let total_dim = self.design.ncols();
3983 let mut out = Array2::<f64>::zeros((total_dim, total_dim));
3984 out.slice_mut(s![coeff_range.clone(), coeff_range.clone()])
3985 .assign(local_hess);
3986 out
3987 }
3988
    /// Contribution of one cache's single penalty `component` to a
    /// hyperparameter derivative, evaluated at the state stored in `eval`.
    ///
    /// Returns `(objective, beta_mixed, betahessian)`, each already multiplied
    /// by the cache's lambda for `component`:
    /// * `HyperDerivativeKind::Rho`: the penalty value with its beta gradient
    ///   and beta Hessian;
    /// * `LogEpsilonFirst` / `LogEpsilonSecond`: the first / second
    ///   log-epsilon derivatives of those three quantities.
    ///
    /// The local gradient and Hessian are zero-padded into full coefficient
    /// space at the cache's `coeff_global_range`.
    ///
    /// # Errors
    /// Fails if `cache_idx` is out of range for the runtime caches, the
    /// hyperparameter blocks, or `eval.adaptive_states`, or if a grouped
    /// operator assembly fails.
    fn adaptive_block_eval(
        &self,
        eval: &SpatialAdaptiveExactEvaluation,
        cache_idx: usize,
        component: AdaptiveComponent,
        derivative: HyperDerivativeKind,
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        let cache = self
            .runtime_caches
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
        let params = self
            .adaptive_params
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
        let state = eval
            .adaptive_states
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;

        let (objective_local, beta_mixed_local, betahessian_local) = match component {
            // Magnitude penalty: scalar collocation operator d0, so derivatives
            // come back as per-point coefficients / Hessian diagonals.
            AdaptiveComponent::Magnitude => {
                let lambda = params.lambda[0];
                let mag = &state.magnitude;
                let (objective, gradient_coeff, hessian_diag) = match derivative {
                    HyperDerivativeKind::Rho => (
                        mag.penalty_value(),
                        mag.betagradient_coeff(),
                        mag.betahessian_diag(),
                    ),
                    HyperDerivativeKind::LogEpsilonFirst => (
                        mag.log_epsilon_gradient_terms().sum(),
                        mag.log_epsilon_betagradient_coeff(),
                        mag.log_epsilon_betahessian_diag(),
                    ),
                    HyperDerivativeKind::LogEpsilonSecond => (
                        mag.log_epsilon_hessian_terms().sum(),
                        mag.log_epsilon_beta_mixed_second_coeff(),
                        mag.log_epsilon_betahessian_second_diag(),
                    ),
                };
                (
                    lambda * objective,
                    lambda * scalar_operatorgradient(&cache.d0, &gradient_coeff),
                    lambda * scalar_operatorhessian(&cache.d0, &hessian_diag),
                )
            }
            // Gradient penalty: operator d1 grouped in blocks of `dimension`
            // rows per collocation point.
            AdaptiveComponent::Gradient => {
                let lambda = params.lambda[1];
                let grad = &state.gradient;
                let (objective, gradient_blocks, hessian_blocks) = match derivative {
                    HyperDerivativeKind::Rho => (
                        grad.penalty_value(),
                        grad.betagradient_blocks(),
                        grad.betahessian_blocks(),
                    ),
                    HyperDerivativeKind::LogEpsilonFirst => (
                        grad.log_epsilon_gradient_terms().sum(),
                        grad.log_epsilon_betagradient_blocks(),
                        grad.log_epsilon_betahessian_blocks(),
                    ),
                    HyperDerivativeKind::LogEpsilonSecond => (
                        grad.log_epsilon_hessian_terms().sum(),
                        grad.log_epsilon_beta_mixed_second_blocks(),
                        grad.log_epsilon_betahessian_second_blocks(),
                    ),
                };
                (
                    lambda * objective,
                    lambda
                        * grouped_operatorgradient(&cache.d1, cache.dimension, &gradient_blocks)
                            .map_err(|e| e.to_string())?,
                    lambda
                        * grouped_operatorhessian(&cache.d1, cache.dimension, &hessian_blocks)
                            .map_err(|e| e.to_string())?,
                )
            }
            // Curvature penalty: operator d2 grouped in blocks of
            // `dimension * dimension` rows per collocation point.
            AdaptiveComponent::Curvature => {
                let lambda = params.lambda[2];
                let group = cache.dimension * cache.dimension;
                let curv = &state.curvature;
                let (objective, gradient_blocks, hessian_blocks) = match derivative {
                    HyperDerivativeKind::Rho => (
                        curv.penalty_value(),
                        curv.betagradient_blocks(),
                        curv.betahessian_blocks(),
                    ),
                    HyperDerivativeKind::LogEpsilonFirst => (
                        curv.log_epsilon_gradient_terms().sum(),
                        curv.log_epsilon_betagradient_blocks(),
                        curv.log_epsilon_betahessian_blocks(),
                    ),
                    HyperDerivativeKind::LogEpsilonSecond => (
                        curv.log_epsilon_hessian_terms().sum(),
                        curv.log_epsilon_beta_mixed_second_blocks(),
                        curv.log_epsilon_betahessian_second_blocks(),
                    ),
                };
                (
                    lambda * objective,
                    lambda
                        * grouped_operatorgradient(&cache.d2, group, &gradient_blocks)
                            .map_err(|e| e.to_string())?,
                    lambda
                        * grouped_operatorhessian(&cache.d2, group, &hessian_blocks)
                            .map_err(|e| e.to_string())?,
                )
            }
        };

        // Lift the cache-local pieces into full coefficient space.
        let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
            &cache.coeff_global_range,
            &beta_mixed_local,
            &betahessian_local,
        );
        Ok((objective_local, beta_mixed, betahessian))
    }
4114
4115 fn adaptive_shared_log_epsilon_parts(
4116 &self,
4117 eval: &SpatialAdaptiveExactEvaluation,
4118 component: usize,
4119 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4120 self.adaptive_shared_block_eval(eval, component, HyperDerivativeKind::LogEpsilonFirst)
4126 }
4127
4128 fn adaptive_shared_log_epsilon_second_parts(
4129 &self,
4130 eval: &SpatialAdaptiveExactEvaluation,
4131 component: usize,
4132 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4133 self.adaptive_shared_block_eval(eval, component, HyperDerivativeKind::LogEpsilonSecond)
4139 }
4140
4141 fn adaptive_shared_block_eval(
4146 &self,
4147 eval: &SpatialAdaptiveExactEvaluation,
4148 component: usize,
4149 derivative: HyperDerivativeKind,
4150 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4151 let component = AdaptiveComponent::from_index(component)?;
4152 let (mut score, mut hessian) = self.zero_hyper_parts();
4153 let mut objective = 0.0;
4154 for cache_idx in 0..self.runtime_caches.len() {
4155 let (local_objective, local_score, local_hessian) =
4156 self.adaptive_block_eval(eval, cache_idx, component, derivative)?;
4157 objective += local_objective;
4158 score += &local_score;
4159 hessian += &local_hessian;
4160 }
4161 Ok((objective, score, hessian))
4162 }
4163
4164 fn adaptive_shared_log_epsilon_drift(
4165 &self,
4166 eval: &SpatialAdaptiveExactEvaluation,
4167 component: usize,
4168 direction: &Array1<f64>,
4169 ) -> Result<Array2<f64>, String> {
4170 let component = AdaptiveComponent::from_index(component)?;
4174 let total_dim = self.design.ncols();
4175 let mut total = Array2::<f64>::zeros((total_dim, total_dim));
4176 for cache_idx in 0..self.runtime_caches.len() {
4177 total += &self.adaptive_block_drift_eval(
4178 eval,
4179 cache_idx,
4180 component,
4181 HyperDriftKind::LogEpsilon,
4182 direction,
4183 )?;
4184 }
4185 Ok(total)
4186 }
4187
4188 fn adaptive_explicit_second_order_parts(
4189 &self,
4190 eval: &SpatialAdaptiveExactEvaluation,
4191 left: SpatialAdaptiveHyperSpec,
4192 right: SpatialAdaptiveHyperSpec,
4193 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4194 match left.explicit_second_order_kind(right) {
4203 SpatialAdaptiveExplicitSecondOrderKind::StructuralZero => {
4204 let (score, hessian) = self.zero_hyper_parts();
4205 Ok((0.0, score, hessian))
4206 }
4207 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha => self.adaptive_block_eval(
4208 eval,
4209 left.cache_index,
4210 AdaptiveComponent::from_index(left.component_index())?,
4211 HyperDerivativeKind::Rho,
4212 ),
4213 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta => {
4214 let local_alpha = if left.kind.is_log_lambda() {
4215 left
4216 } else {
4217 right
4218 };
4219 self.adaptive_block_eval(
4220 eval,
4221 local_alpha.cache_index,
4222 AdaptiveComponent::from_index(local_alpha.component_index())?,
4223 HyperDerivativeKind::LogEpsilonFirst,
4224 )
4225 }
4226 SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta => {
4227 self.adaptive_shared_log_epsilon_second_parts(eval, left.component_index())
4228 }
4229 }
4230 }
4231
    /// Directional derivative ("drift") along `direction` of one cache's
    /// lambda-weighted penalty Hessian for `component`, embedded into a full
    /// `p x p` matrix.
    ///
    /// `HyperDriftKind::Rho` differentiates the penalty Hessian itself.
    /// `HyperDriftKind::LogEpsilon` differentiates its log-epsilon derivative.
    /// Only the cache's `coeff_global_range` slice of `direction` is used; that
    /// slice is first mapped through the component's operator (d0 / d1 / d2).
    ///
    /// # Errors
    /// Fails if `cache_idx` is out of range for the runtime caches, the
    /// hyperparameter blocks, or `eval.adaptive_states`, or if block
    /// reshaping or grouped operator assembly fails.
    fn adaptive_block_drift_eval(
        &self,
        eval: &SpatialAdaptiveExactEvaluation,
        cache_idx: usize,
        component: AdaptiveComponent,
        drift: HyperDriftKind,
        direction: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let cache = self
            .runtime_caches
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
        let params = self
            .adaptive_params
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
        let state = eval
            .adaptive_states
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
        // Only this cache's coefficients influence its penalty.
        let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);

        let local_hessian = match component {
            AdaptiveComponent::Magnitude => {
                let d0_u = cache.d0.dot(&direction_local);
                let mag = &state.magnitude;
                let diag = match drift {
                    HyperDriftKind::Rho => mag.directionalhessian_diag(&d0_u),
                    HyperDriftKind::LogEpsilon => {
                        mag.log_epsilon_betahessian_directional_diag(&d0_u)
                    }
                };
                params.lambda[0] * scalar_operatorhessian(&cache.d0, &diag)
            }
            AdaptiveComponent::Gradient => {
                let d1_u = cache.d1.dot(&direction_local);
                // Regroup the d1 image into per-point blocks of `dimension`.
                let direction_blocks = collocationgradient_blocks(&d1_u, cache.dimension)
                    .map_err(|e| e.to_string())?;
                let grad = &state.gradient;
                let blocks = match drift {
                    HyperDriftKind::Rho => grad.directionalhessian_blocks(&direction_blocks),
                    HyperDriftKind::LogEpsilon => {
                        grad.log_epsilon_betahessian_directional_blocks(&direction_blocks)
                    }
                };
                params.lambda[1]
                    * grouped_operatorhessian(&cache.d1, cache.dimension, &blocks)
                        .map_err(|e| e.to_string())?
            }
            AdaptiveComponent::Curvature => {
                let group = cache.dimension * cache.dimension;
                let d2_u = cache.d2.dot(&direction_local);
                // Regroup the d2 image into per-point Hessian blocks.
                let direction_blocks =
                    collocationhessian_blocks(&d2_u, cache.dimension).map_err(|e| e.to_string())?;
                let curv = &state.curvature;
                let blocks = match drift {
                    HyperDriftKind::Rho => curv.directionalhessian_blocks(&direction_blocks),
                    HyperDriftKind::LogEpsilon => {
                        curv.log_epsilon_betahessian_directional_blocks(&direction_blocks)
                    }
                };
                params.lambda[2]
                    * grouped_operatorhessian(&cache.d2, group, &blocks)
                        .map_err(|e| e.to_string())?
            }
        };

        Ok(self.embed_local_hyper_hessian(&cache.coeff_global_range, &local_hessian))
    }
4308
4309 fn adaptive_hyper_parts(
4310 &self,
4311 eval: &SpatialAdaptiveExactEvaluation,
4312 hyper: SpatialAdaptiveHyperSpec,
4313 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4314 match hyper.kind {
4315 SpatialAdaptiveHyperKind::LogLambdaMagnitude
4318 | SpatialAdaptiveHyperKind::LogLambdaGradient
4319 | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_eval(
4320 eval,
4321 hyper.cache_index,
4322 AdaptiveComponent::from_index(hyper.component_index())?,
4323 HyperDerivativeKind::Rho,
4324 ),
4325 SpatialAdaptiveHyperKind::LogEpsilonMagnitude
4327 | SpatialAdaptiveHyperKind::LogEpsilonGradient
4328 | SpatialAdaptiveHyperKind::LogEpsilonCurvature => {
4329 self.adaptive_shared_log_epsilon_parts(eval, hyper.component_index())
4330 }
4331 }
4332 }
4333
4334 fn exact_evaluation_uncached(
4335 &self,
4336 beta: &Array1<f64>,
4337 ) -> Result<SpatialAdaptiveExactEvaluation, String> {
4338 let eta = self.total_eta(beta);
4339 let obs = evaluate_standard_familyobservations(
4340 self.family.clone(),
4341 self.latent_cloglog_state.as_ref(),
4342 self.mixture_link_state.as_ref(),
4343 self.sas_link_state.as_ref(),
4344 &self.y,
4345 &self.weights,
4346 &eta,
4347 )
4348 .map_err(|e| e.to_string())?;
4349 let p = beta.len();
4350 let mut penalty_value = 0.0;
4351 let mut penaltygradient = Array1::<f64>::zeros(p);
4352 let mut penaltyhessian = Array2::<f64>::zeros((p, p));
4353 let mut adaptive_states = Vec::with_capacity(self.runtime_caches.len());
4354
4355 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4356 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4357 format!(
4358 "missing adaptive parameter block for cache {}",
4359 cache.termname
4360 )
4361 })?;
4362 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
4363 let state =
4364 SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
4365 .map_err(|e| e.to_string())?;
4366
4367 let g0 = scalar_operatorgradient(&cache.d0, &state.magnitude.betagradient_coeff());
4368 let gg = grouped_operatorgradient(
4369 &cache.d1,
4370 cache.dimension,
4371 &state.gradient.betagradient_blocks(),
4372 )
4373 .map_err(|e| e.to_string())?;
4374 let gc = grouped_operatorgradient(
4375 &cache.d2,
4376 cache.dimension * cache.dimension,
4377 &state.curvature.betagradient_blocks(),
4378 )
4379 .map_err(|e| e.to_string())?;
4380 let h0 = scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag());
4381 let hg = grouped_operatorhessian(
4382 &cache.d1,
4383 cache.dimension,
4384 &state.gradient.betahessian_blocks(),
4385 )
4386 .map_err(|e| e.to_string())?;
4387 let hc = grouped_operatorhessian(
4388 &cache.d2,
4389 cache.dimension * cache.dimension,
4390 &state.curvature.betahessian_blocks(),
4391 )
4392 .map_err(|e| e.to_string())?;
4393
4394 let lambda0 = params.lambda[0];
4395 let lambdag = params.lambda[1];
4396 let lambdac = params.lambda[2];
4397
4398 penalty_value += lambda0 * state.magnitude.penalty_value();
4399 penalty_value += lambdag * state.gradient.penalty_value();
4400 penalty_value += lambdac * state.curvature.penalty_value();
4401
4402 let range = cache.coeff_global_range.clone();
4403 {
4404 let mut grad_local = penaltygradient.slice_mut(s![range.clone()]);
4405 grad_local += &(g0.mapv(|v| lambda0 * v));
4406 grad_local += &(gg.mapv(|v| lambdag * v));
4407 grad_local += &(gc.mapv(|v| lambdac * v));
4408 }
4409 {
4410 let mut h_local = penaltyhessian.slice_mut(s![range.clone(), range]);
4411 h_local += &h0.mapv(|v| lambda0 * v);
4412 h_local += &hg.mapv(|v| lambdag * v);
4413 h_local += &hc.mapv(|v| lambdac * v);
4414 }
4415
4416 adaptive_states.push(state);
4417 }
4418
4419 let (fixed_quadraticvalue, fixed_quadraticgradient) = self.fixed_quadratic_terms(beta);
4420 Ok(SpatialAdaptiveExactEvaluation {
4421 obs,
4422 adaptive_states,
4423 adaptive_penalty_value: penalty_value,
4424 adaptive_penaltygradient: penaltygradient,
4425 adaptive_penaltyhessian: penaltyhessian,
4426 fixed_quadraticvalue,
4427 fixed_quadraticgradient,
4428 fixed_quadratichessian: self.fixed_quadratichessian.as_ref().clone(),
4429 })
4430 }
4431
4432 fn exact_evaluation(
4433 &self,
4434 beta: &Array1<f64>,
4435 ) -> Result<Arc<SpatialAdaptiveExactEvaluation>, String> {
4436 {
4437 let cache = self
4438 .exact_eval_cache
4439 .lock()
4440 .map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
4441 if let Some(cached) = cache.as_ref()
4442 && cached.beta.len() == beta.len()
4443 && cached
4444 .beta
4445 .iter()
4446 .zip(beta.iter())
4447 .all(|(&left, &right)| left == right)
4448 {
4449 return Ok(Arc::clone(&cached.eval));
4450 }
4451 }
4452
4453 let eval = Arc::new(self.exact_evaluation_uncached(beta)?);
4454 let mut cache = self
4455 .exact_eval_cache
4456 .lock()
4457 .map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
4458 *cache = Some(CachedSpatialAdaptiveExactEvaluation {
4459 beta: beta.clone(),
4460 eval: Arc::clone(&eval),
4461 });
4462 Ok(eval)
4463 }
4464
4465 fn exacthessian_directional_derivative_from_evaluation(
4466 &self,
4467 beta: &Array1<f64>,
4468 eval: &SpatialAdaptiveExactEvaluation,
4469 direction: &Array1<f64>,
4470 ) -> Result<Array2<f64>, String> {
4471 assert_eq!(
4472 beta.len(),
4473 direction.len(),
4474 "beta/direction length mismatch",
4475 );
4476 let d_eta = gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), direction);
4477 let mut total = xt_diag_x_dense(
4478 self.design.view(),
4479 (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
4480 )?;
4481 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4482 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4483 format!(
4484 "missing adaptive parameter block for cache {}",
4485 cache.termname
4486 )
4487 })?;
4488 let state = eval
4489 .adaptive_states
4490 .get(cache_idx)
4491 .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
4492 let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);
4493 let d0_u = cache.d0.dot(&direction_local);
4494 let d1_u = cache.d1.dot(&direction_local);
4495 let d2_u = cache.d2.dot(&direction_local);
4496 let h0 =
4497 scalar_operatorhessian(&cache.d0, &state.magnitude.directionalhessian_diag(&d0_u))
4498 .mapv(|v| params.lambda[0] * v);
4499 let hg = grouped_operatorhessian(
4500 &cache.d1,
4501 cache.dimension,
4502 &state.gradient.directionalhessian_blocks(
4503 &collocationgradient_blocks(&d1_u, cache.dimension)
4504 .map_err(|e| e.to_string())?,
4505 ),
4506 )
4507 .map_err(|e| e.to_string())?
4508 .mapv(|v| params.lambda[1] * v);
4509 let hc = grouped_operatorhessian(
4510 &cache.d2,
4511 cache.dimension * cache.dimension,
4512 &state.curvature.directionalhessian_blocks(
4513 &collocationhessian_blocks(&d2_u, cache.dimension)
4514 .map_err(|e| e.to_string())?,
4515 ),
4516 )
4517 .map_err(|e| e.to_string())?
4518 .mapv(|v| params.lambda[2] * v);
4519 let range = cache.coeff_global_range.clone();
4520 let mut local = total.slice_mut(s![range.clone(), range]);
4521 local += &h0;
4522 local += &hg;
4523 local += &hc;
4524 }
4525 Ok(total)
4526 }
4527
4528 fn exacthessian_second_directional_derivative_from_evaluation(
4549 &self,
4550 eval: &SpatialAdaptiveExactEvaluation,
4551 direction_u: &Array1<f64>,
4552 direction_v: &Array1<f64>,
4553 ) -> Result<Option<Array2<f64>>, String> {
4554 let p = self.design.ncols();
4555 if eval.obs.neghessian_eta_derivative.iter().any(|&w| w != 0.0) {
4557 return Ok(None);
4558 }
4559 let mut total = Array2::<f64>::zeros((p, p));
4560 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4561 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4562 format!(
4563 "missing adaptive parameter block for cache {}",
4564 cache.termname
4565 )
4566 })?;
4567 let state = eval
4568 .adaptive_states
4569 .get(cache_idx)
4570 .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
4571 let u_local = direction_u.slice(s![cache.coeff_global_range.clone()]);
4572 let v_local = direction_v.slice(s![cache.coeff_global_range.clone()]);
4573
4574 let q0_u = cache.d0.dot(&u_local);
4576 let q0_v = cache.d0.dot(&v_local);
4577 let h0 = scalar_operatorhessian(
4578 &cache.d0,
4579 &state.magnitude.second_directionalhessian_diag(&q0_u, &q0_v),
4580 )
4581 .mapv(|x| params.lambda[0] * x);
4582
4583 let a1 = collocationgradient_blocks(&cache.d1.dot(&u_local), cache.dimension)
4585 .map_err(|e| e.to_string())?;
4586 let b1 = collocationgradient_blocks(&cache.d1.dot(&v_local), cache.dimension)
4587 .map_err(|e| e.to_string())?;
4588 let hg = grouped_operatorhessian(
4589 &cache.d1,
4590 cache.dimension,
4591 &state.gradient.second_directionalhessian_blocks(&a1, &b1),
4592 )
4593 .map_err(|e| e.to_string())?
4594 .mapv(|x| params.lambda[1] * x);
4595
4596 let a2 = collocationhessian_blocks(&cache.d2.dot(&u_local), cache.dimension)
4598 .map_err(|e| e.to_string())?;
4599 let b2 = collocationhessian_blocks(&cache.d2.dot(&v_local), cache.dimension)
4600 .map_err(|e| e.to_string())?;
4601 let hc = grouped_operatorhessian(
4602 &cache.d2,
4603 cache.dimension * cache.dimension,
4604 &state.curvature.second_directionalhessian_blocks(&a2, &b2),
4605 )
4606 .map_err(|e| e.to_string())?
4607 .mapv(|x| params.lambda[2] * x);
4608
4609 let range = cache.coeff_global_range.clone();
4610 let mut local = total.slice_mut(s![range.clone(), range]);
4611 local += &h0;
4612 local += &hg;
4613 local += &hc;
4614 }
4615 Ok(Some(total))
4616 }
4617}
4618
4619impl CustomFamily for SpatialAdaptiveExactFamily {
4620 fn joint_jeffreys_term_required(&self) -> bool {
4624 true
4625 }
4626
4627 fn joint_jeffreys_information_with_specs(
4664 &self,
4665 block_states: &[ParameterBlockState],
4666 specs: &[ParameterBlockSpec],
4667 ) -> Result<Option<Array2<f64>>, String> {
4668 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4669 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4670 if spec.design.ncols() != beta.len() {
4671 return Err(SmoothError::dimension_mismatch(format!(
4672 "spatial adaptive Jeffreys information: spec design has {} columns, beta has {}",
4673 spec.design.ncols(),
4674 beta.len()
4675 ))
4676 .into());
4677 }
4678 let eval = self.exact_evaluation(beta)?;
4679 Ok(Some(xt_diag_x_dense(
4680 self.design.view(),
4681 eval.obs.neghessian_eta.view(),
4682 )?))
4683 }
4684
4685 fn joint_jeffreys_information_directional_derivative_with_specs(
4686 &self,
4687 block_states: &[ParameterBlockState],
4688 specs: &[ParameterBlockSpec],
4689 d_beta_flat: &Array1<f64>,
4690 ) -> Result<Option<Array2<f64>>, String> {
4691 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4697 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4698 if spec.design.ncols() != d_beta_flat.len() {
4699 return Err(SmoothError::dimension_mismatch(format!(
4700 "spatial adaptive Jeffreys directional derivative: spec design has {} columns, direction has {}",
4701 spec.design.ncols(),
4702 d_beta_flat.len()
4703 ))
4704 .into());
4705 }
4706 let eval = self.exact_evaluation(beta)?;
4707 let d_eta = gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), d_beta_flat);
4708 Ok(Some(xt_diag_x_dense(
4709 self.design.view(),
4710 (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
4711 )?))
4712 }
4713
4714 fn joint_jeffreys_information_second_directional_derivative_with_specs(
4715 &self,
4716 block_states: &[ParameterBlockState],
4717 specs: &[ParameterBlockSpec],
4718 d_beta_u_flat: &Array1<f64>,
4719 d_betav_flat: &Array1<f64>,
4720 ) -> Result<Option<Array2<f64>>, String> {
4721 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4728 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4729 if spec.design.ncols() != beta.len()
4730 || d_beta_u_flat.len() != beta.len()
4731 || d_betav_flat.len() != beta.len()
4732 {
4733 return Err(SmoothError::dimension_mismatch(format!(
4734 "spatial adaptive Jeffreys second-direction length mismatch: spec cols={}, dirs=({}, {}), expected {}",
4735 spec.design.ncols(),
4736 d_beta_u_flat.len(),
4737 d_betav_flat.len(),
4738 beta.len()
4739 ))
4740 .into());
4741 }
4742 let eval = self.exact_evaluation(beta)?;
4743 if eval.obs.neghessian_eta_derivative.iter().any(|&w| w != 0.0) {
4744 return Ok(None);
4745 }
4746 Ok(Some(Array2::<f64>::zeros((beta.len(), beta.len()))))
4747 }
4748
4749 fn joint_jeffreys_information_matches_observed_hessian(&self) -> bool {
4750 false
4755 }
4756
4757 fn joint_jeffreys_information_depends_on_psi(&self) -> bool {
4758 false
4767 }
4768
4769 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4770 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4771 let eval = self.exact_evaluation(beta)?;
4772 let mut gradient = fast_atv(&self.design, &eval.obs.score);
4773 gradient -= &eval.total_penaltygradient();
4774 let mut hessian = xt_diag_x_dense(self.design.view(), eval.obs.neghessian_eta.view())?;
4775 hessian += &eval.total_penaltyhessian();
4776 Ok(FamilyEvaluation {
4777 log_likelihood: eval.obs.log_likelihood - eval.total_penalty_value(),
4778 blockworking_sets: vec![BlockWorkingSet::ExactNewton {
4779 gradient,
4780 hessian: SymmetricMatrix::Dense(hessian),
4781 }],
4782 })
4783 }
4784
4785 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4786 let state = expect_single_block_state(block_states, "spatial adaptive exact family")?;
4787 let beta = &state.beta;
4788 let obs = evaluate_standard_familyobservations(
4789 self.family.clone(),
4790 self.latent_cloglog_state.as_ref(),
4791 self.mixture_link_state.as_ref(),
4792 self.sas_link_state.as_ref(),
4793 &self.y,
4794 &self.weights,
4795 &state.eta,
4796 )
4797 .map_err(|e| e.to_string())?;
4798 let adaptive_penalty = self.adaptive_penalty_value_only(beta)?;
4799 let (fixed_quadratic, _) = self.fixed_quadratic_terms(beta);
4800 Ok(obs.log_likelihood - adaptive_penalty - fixed_quadratic)
4801 }
4802
4803 fn exact_newton_outerobjective(&self) -> ExactNewtonOuterObjective {
4804 ExactNewtonOuterObjective::StrictPseudoLaplace
4805 }
4806
4807 fn exact_newton_joint_hessian(
4808 &self,
4809 block_states: &[ParameterBlockState],
4810 ) -> Result<Option<Array2<f64>>, String> {
4811 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4812 let eval = self.exact_evaluation(beta)?;
4813 Ok(Some(eval.totalobjectivehessian(&self.design)?))
4814 }
4815
4816 fn exact_newton_hessian_directional_derivative(
4817 &self,
4818 block_states: &[ParameterBlockState],
4819 block_idx: usize,
4820 d_beta: &Array1<f64>,
4821 ) -> Result<Option<Array2<f64>>, String> {
4822 expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
4823 self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
4824 }
4825
4826 fn exact_newton_joint_hessian_directional_derivative(
4827 &self,
4828 block_states: &[ParameterBlockState],
4829 d_beta_flat: &Array1<f64>,
4830 ) -> Result<Option<Array2<f64>>, String> {
4831 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4832 if d_beta_flat.len() != beta.len() {
4833 return Err(SmoothError::dimension_mismatch(format!(
4834 "spatial adaptive exact family direction length mismatch: got {}, expected {}",
4835 d_beta_flat.len(),
4836 beta.len()
4837 ))
4838 .into());
4839 }
4840 let eval = self.exact_evaluation(beta)?;
4841 Ok(Some(
4842 self.exacthessian_directional_derivative_from_evaluation(beta, &eval, d_beta_flat)?,
4843 ))
4844 }
4845
4846 fn exact_newton_joint_hessiansecond_directional_derivative(
4847 &self,
4848 block_states: &[ParameterBlockState],
4849 d_beta_u_flat: &Array1<f64>,
4850 d_betav_flat: &Array1<f64>,
4851 ) -> Result<Option<Array2<f64>>, String> {
4852 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4853 if d_beta_u_flat.len() != beta.len() || d_betav_flat.len() != beta.len() {
4854 return Err(SmoothError::dimension_mismatch(format!(
4855 "spatial adaptive exact family second-direction length mismatch: got ({}, {}), expected {}",
4856 d_beta_u_flat.len(),
4857 d_betav_flat.len(),
4858 beta.len()
4859 ))
4860 .into());
4861 }
4862 let eval = self.exact_evaluation(beta)?;
4863 self.exacthessian_second_directional_derivative_from_evaluation(
4864 &eval,
4865 d_beta_u_flat,
4866 d_betav_flat,
4867 )
4868 }
4869
4870 fn block_linear_constraints(
4871 &self,
4872 block_states: &[ParameterBlockState],
4873 block_idx: usize,
4874 block_spec: &ParameterBlockSpec,
4875 ) -> Result<Option<LinearInequalityConstraints>, String> {
4876 assert!(!block_states.is_empty(), "block_states must be non-empty");
4877 assert!(
4878 !block_spec.name.is_empty(),
4879 "block spec name must be non-empty",
4880 );
4881 expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
4882 Ok(self.linear_constraints.clone())
4883 }
4884
4885 fn exact_newton_joint_psi_terms(
4886 &self,
4887 block_states: &[ParameterBlockState],
4888 specs: &[ParameterBlockSpec],
4889 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4890 psi_index: usize,
4891 ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
4892 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4893 return Err(SmoothError::dimension_mismatch(format!(
4894 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4895 block_states.len(),
4896 specs.len(),
4897 derivative_blocks.len()
4898 ))
4899 .into());
4900 }
4901 derivative_blocks[0]
4902 .get(psi_index)
4903 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4904 let hyper = self
4905 .hyperspecs
4906 .get(psi_index)
4907 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4908 let beta = &block_states[0].beta;
4909 let eval = self.exact_evaluation(beta)?;
4910 let (direct, beta_mixed, betahessian_explicit) =
4911 self.adaptive_hyper_parts(&eval, *hyper)?;
4912
4913 Ok(Some(ExactNewtonJointPsiTerms {
4934 objective_psi: direct,
4935 score_psi: beta_mixed,
4936 hessian_psi: betahessian_explicit,
4937 hessian_psi_operator: None,
4938 }))
4939 }
4940
4941 fn exact_newton_joint_psisecond_order_terms(
4942 &self,
4943 block_states: &[ParameterBlockState],
4944 specs: &[ParameterBlockSpec],
4945 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4946 psi_i: usize,
4947 psi_j: usize,
4948 ) -> Result<Option<gam_problem::ExactNewtonJointPsiSecondOrderTerms>, String> {
4949 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4950 return Err(SmoothError::dimension_mismatch(format!(
4951 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4952 block_states.len(),
4953 specs.len(),
4954 derivative_blocks.len()
4955 ))
4956 .into());
4957 }
4958 derivative_blocks[0]
4959 .get(psi_i)
4960 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
4961 derivative_blocks[0]
4962 .get(psi_j)
4963 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
4964 let hyper_i = self
4965 .hyperspecs
4966 .get(psi_i)
4967 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
4968 let hyper_j = self
4969 .hyperspecs
4970 .get(psi_j)
4971 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
4972 let beta = &block_states[0].beta;
4973 let eval = self.exact_evaluation(beta)?;
4974 let (objective_psi_psi, score_psi_psi, hessian_psi_psi) =
4975 self.adaptive_explicit_second_order_parts(&eval, *hyper_i, *hyper_j)?;
4976
4977 Ok(Some(
4978 gam_problem::ExactNewtonJointPsiSecondOrderTerms {
4979 objective_psi_psi,
4980 score_psi_psi,
4981 hessian_psi_psi,
4982 hessian_psi_psi_operator: None,
4983 },
4984 ))
4985 }
4986
4987 fn exact_newton_joint_psihessian_directional_derivative(
4988 &self,
4989 block_states: &[ParameterBlockState],
4990 specs: &[ParameterBlockSpec],
4991 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4992 psi_index: usize,
4993 direction: &Array1<f64>,
4994 ) -> Result<Option<Array2<f64>>, String> {
4995 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4996 return Err(SmoothError::dimension_mismatch(format!(
4997 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4998 block_states.len(),
4999 specs.len(),
5000 derivative_blocks.len()
5001 ))
5002 .into());
5003 }
5004 let beta = &block_states[0].beta;
5005 if direction.len() != beta.len() {
5006 return Err(SmoothError::dimension_mismatch(format!(
5007 "spatial adaptive exact family direction length mismatch: got {}, expected {}",
5008 direction.len(),
5009 beta.len()
5010 ))
5011 .into());
5012 }
5013 derivative_blocks[0]
5014 .get(psi_index)
5015 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
5016 let hyper = self
5017 .hyperspecs
5018 .get(psi_index)
5019 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
5020 let eval = self.exact_evaluation(beta)?;
5021 let drift = match hyper.kind {
5022 SpatialAdaptiveHyperKind::LogLambdaMagnitude
5023 | SpatialAdaptiveHyperKind::LogLambdaGradient
5024 | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_drift_eval(
5025 &eval,
5026 hyper.cache_index,
5027 AdaptiveComponent::from_index(hyper.kind.component_index())?,
5028 HyperDriftKind::Rho,
5029 direction,
5030 )?,
5031 SpatialAdaptiveHyperKind::LogEpsilonMagnitude
5032 | SpatialAdaptiveHyperKind::LogEpsilonGradient
5033 | SpatialAdaptiveHyperKind::LogEpsilonCurvature => self
5034 .adaptive_shared_log_epsilon_drift(
5035 &eval,
5036 hyper.kind.component_index(),
5037 direction,
5038 )?,
5039 };
5040 Ok(Some(drift))
5041 }
5042}
5043
5044fn expect_single_block_state<'a>(
5045 block_states: &'a [ParameterBlockState],
5046 family_name: &str,
5047) -> Result<&'a ParameterBlockState, String> {
5048 crate::block_layout::block_count::validate_block_count::<SmoothError>(
5049 family_name,
5050 1,
5051 block_states.len(),
5052 )?;
5053 Ok(&block_states[0])
5054}
5055
5056fn expect_single_blockspec<'a>(
5057 specs: &'a [ParameterBlockSpec],
5058 family_name: &str,
5059) -> Result<&'a ParameterBlockSpec, String> {
5060 crate::block_layout::block_count::validate_block_count::<SmoothError>(
5061 family_name,
5062 1,
5063 specs.len(),
5064 )?;
5065 Ok(&specs[0])
5066}
5067
5068fn expect_block_idx_zero(block_idx: usize, family_name: &str, context: &str) -> Result<(), String> {
5069 if block_idx != 0 {
5070 return Err(SmoothError::invalid_index(format!(
5071 "{family_name} expects block_idx 0{context}, got {block_idx}"
5072 ))
5073 .into());
5074 }
5075 Ok::<(), _>(())
5076}
5077
5078impl BoundedLinearFamily {
5079 fn bounded_term_derivative_data(
5080 &self,
5081 latent_beta: &Array1<f64>,
5082 ) -> (
5083 Array1<f64>,
5084 Array1<f64>,
5085 Array1<f64>,
5086 Array1<f64>,
5087 Array1<f64>,
5088 ) {
5089 let p = latent_beta.len();
5090 let mut beta_user = latent_beta.clone();
5091 let mut jac_diag = Array1::<f64>::ones(p);
5092 let mut second_diag = Array1::<f64>::zeros(p);
5093 let mut third_diag = Array1::<f64>::zeros(p);
5094 let mut priorthird = Array1::<f64>::zeros(p);
5095 for term in &self.bounded_terms {
5096 let (beta, _, db_dtheta, d2b_dtheta2, d3b_dtheta3) =
5097 bounded_latent_derivatives(latent_beta[term.col_idx], term.min, term.max);
5098 beta_user[term.col_idx] = beta;
5099 jac_diag[term.col_idx] = db_dtheta;
5100 second_diag[term.col_idx] = d2b_dtheta2;
5101 third_diag[term.col_idx] = d3b_dtheta3;
5102 let (_, _, _, prior_neghess_derivative) =
5103 bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
5104 priorthird[term.col_idx] = prior_neghess_derivative;
5105 }
5106 (beta_user, jac_diag, second_diag, third_diag, priorthird)
5107 }
5108
5109 fn user_beta_and_jacobian(&self, latent_beta: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
5110 let (beta_user, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
5111 (beta_user, jac_diag)
5112 }
5113
5114 fn nonlinear_offset_from_latent(&self, latent_beta: &Array1<f64>) -> Array1<f64> {
5115 let mut offset = self.offset.clone();
5116 for term in &self.bounded_terms {
5117 let (beta, _, _) =
5118 bounded_latent_to_user(latent_beta[term.col_idx], term.min, term.max);
5119 offset.scaled_add(beta, &self.design.column(term.col_idx));
5120 }
5121 offset
5122 }
5123
5124 fn effective_design_for_latent(&self, jac_diag: &Array1<f64>) -> Array2<f64> {
5125 let mut x_eff = self.design.clone();
5126 for term in &self.bounded_terms {
5127 x_eff
5128 .column_mut(term.col_idx)
5129 .mapv_inplace(|v| v * jac_diag[term.col_idx]);
5130 }
5131 x_eff
5132 }
5133
5134 fn exacthessian_andgradient(
5135 &self,
5136 latent_beta: &Array1<f64>,
5137 ) -> Result<
5138 (
5139 StandardFamilyObservationState,
5140 Array2<f64>,
5141 Array1<f64>,
5142 f64,
5143 Array1<f64>,
5144 Array1<f64>,
5145 Array1<f64>,
5146 ),
5147 String,
5148 > {
5149 let (_, jac_diag, second_diag, third_diag, priorthird) =
5150 self.bounded_term_derivative_data(latent_beta);
5151 let x_eff = self.effective_design_for_latent(&jac_diag);
5152 let eta =
5153 self.designzeroed.dot(latent_beta) + self.nonlinear_offset_from_latent(latent_beta);
5154 let obs = evaluate_standard_familyobservations(
5155 self.family.clone(),
5156 self.latent_cloglog_state.as_ref(),
5157 self.mixture_link_state.as_ref(),
5158 self.sas_link_state.as_ref(),
5159 &self.y,
5160 &self.weights,
5161 &eta,
5162 )
5163 .map_err(|e| e.to_string())?;
5164
5165 let mut priorgrad = Array1::<f64>::zeros(latent_beta.len());
5166 let mut prior_neghess = Array2::<f64>::zeros((latent_beta.len(), latent_beta.len()));
5167 let mut prior_loglik = 0.0;
5168 for term in &self.bounded_terms {
5169 let (logp, grad, neghess, _) =
5170 bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
5171 prior_loglik += logp;
5172 priorgrad[term.col_idx] += grad;
5173 prior_neghess[[term.col_idx, term.col_idx]] += neghess;
5174 }
5175
5176 let mut hessian = xt_diag_x_dense(x_eff.view(), obs.neghessian_eta.view())?;
5177 let mut gradient = fast_atv(&x_eff, &obs.score);
5178 for term in &self.bounded_terms {
5179 let score_beta = self.design.column(term.col_idx).dot(&obs.score);
5180 hessian[[term.col_idx, term.col_idx]] -= score_beta * second_diag[term.col_idx];
5181 }
5182 hessian += &prior_neghess;
5183 gradient += &priorgrad;
5184
5185 Ok((
5186 obs,
5187 hessian,
5188 gradient,
5189 prior_loglik,
5190 second_diag,
5191 third_diag,
5192 priorthird,
5193 ))
5194 }
5195
5196 fn evaluation_from_latent(
5197 &self,
5198 latent_beta: &Array1<f64>,
5199 ) -> Result<
5200 (
5201 StandardFamilyObservationState,
5202 Array2<f64>,
5203 Array1<f64>,
5204 f64,
5205 ),
5206 String,
5207 > {
5208 let (obs, hessian, gradient, prior_loglik, _, _, _) =
5209 self.exacthessian_andgradient(latent_beta)?;
5210 Ok((obs, hessian, gradient, prior_loglik))
5211 }
5212}
5213
5214impl CustomFamily for BoundedLinearFamily {
5215 fn joint_jeffreys_term_required(&self) -> bool {
5219 true
5220 }
5221
5222 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
5223 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5224 let (obs, hessian, gradient, prior_loglik) = self.evaluation_from_latent(latent_beta)?;
5225 Ok(FamilyEvaluation {
5226 log_likelihood: obs.log_likelihood + prior_loglik,
5227 blockworking_sets: vec![BlockWorkingSet::ExactNewton {
5228 gradient,
5229 hessian: SymmetricMatrix::Dense(hessian),
5230 }],
5231 })
5232 }
5233
5234 fn exact_newton_joint_hessian(
5235 &self,
5236 block_states: &[ParameterBlockState],
5237 ) -> Result<Option<Array2<f64>>, String> {
5238 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5239 let (_, hessian, _, _) = self.evaluation_from_latent(latent_beta)?;
5240 Ok(Some(hessian))
5241 }
5242
5243 fn exact_newton_hessian_directional_derivative(
5244 &self,
5245 block_states: &[ParameterBlockState],
5246 block_idx: usize,
5247 d_beta: &Array1<f64>,
5248 ) -> Result<Option<Array2<f64>>, String> {
5249 expect_block_idx_zero(block_idx, "bounded linear family", "")?;
5250 self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
5251 }
5252
5253 fn exact_newton_joint_hessian_directional_derivative(
5254 &self,
5255 block_states: &[ParameterBlockState],
5256 d_beta_flat: &Array1<f64>,
5257 ) -> Result<Option<Array2<f64>>, String> {
5258 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5259 if d_beta_flat.len() != latent_beta.len() {
5260 return Err(SmoothError::dimension_mismatch(format!(
5261 "bounded linear family directional derivative length mismatch: got {}, expected {}",
5262 d_beta_flat.len(),
5263 latent_beta.len()
5264 ))
5265 .into());
5266 }
5267
5268 let (obs, _, _, _, second_diag, third_diag, priorthird) =
5269 self.exacthessian_andgradient(latent_beta)?;
5270
5271 let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
5272 let x_eff = self.effective_design_for_latent(&jac_diag);
5273 let deta = x_eff.dot(d_beta_flat);
5274 let d_neghess_eta = &obs.neghessian_eta_derivative * &deta;
5275
5276 let mut dx_eff = Array2::<f64>::zeros(x_eff.raw_dim());
5277 for term in &self.bounded_terms {
5278 let scale = second_diag[term.col_idx] * d_beta_flat[term.col_idx];
5279 if scale != 0.0 {
5280 let mut col = dx_eff.column_mut(term.col_idx);
5281 col.assign(&self.design.column(term.col_idx));
5282 col.mapv_inplace(|v| v * scale);
5283 }
5284 }
5285
5286 let mut dhessian = xt_diag_x_dense(x_eff.view(), d_neghess_eta.view())?;
5287 let mut wxdx = Array2::<f64>::zeros((x_eff.ncols(), x_eff.ncols()));
5288 for i in 0..x_eff.nrows() {
5289 let wi = obs.neghessian_eta[i];
5290 if wi == 0.0 {
5291 continue;
5292 }
5293 for a in 0..x_eff.ncols() {
5294 let xa = x_eff[[i, a]];
5295 for b in 0..x_eff.ncols() {
5296 wxdx[[a, b]] += wi * (dx_eff[[i, a]] * x_eff[[i, b]] + xa * dx_eff[[i, b]]);
5297 }
5298 }
5299 }
5300 dhessian += &wxdx;
5301
5302 let d_score = -&obs.neghessian_eta * &deta;
5303 for term in &self.bounded_terms {
5304 let score_beta = self.design.column(term.col_idx).dot(&obs.score);
5305 let d_score_beta = self.design.column(term.col_idx).dot(&d_score);
5306 dhessian[[term.col_idx, term.col_idx]] -= d_score_beta * second_diag[term.col_idx]
5307 + score_beta * third_diag[term.col_idx] * d_beta_flat[term.col_idx];
5308 dhessian[[term.col_idx, term.col_idx]] +=
5309 priorthird[term.col_idx] * d_beta_flat[term.col_idx];
5310 }
5311
5312 Ok(Some(dhessian))
5313 }
5314
5315 fn block_geometry(
5316 &self,
5317 block_states: &[ParameterBlockState],
5318 spec: &ParameterBlockSpec,
5319 ) -> Result<(DesignMatrix, Array1<f64>), String> {
5320 if block_states.is_empty() {
5321 return Ok((
5322 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
5323 self.designzeroed.clone(),
5324 )),
5325 self.offset.clone(),
5326 ));
5327 }
5328 let offset = self.nonlinear_offset_from_latent(
5329 &expect_single_block_state(block_states, "bounded linear family")?.beta,
5330 );
5331 let x = if spec.design.ncols() == self.designzeroed.ncols() {
5332 self.designzeroed.clone()
5333 } else {
5334 return Err(SmoothError::dimension_mismatch(
5335 "bounded linear family design column mismatch",
5336 )
5337 .into());
5338 };
5339 Ok((
5340 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
5341 offset,
5342 ))
5343 }
5344
5345 fn block_geometry_is_dynamic(&self) -> bool {
5346 true
5347 }
5348
5349 fn block_geometry_directional_derivative(
5350 &self,
5351 block_states: &[ParameterBlockState],
5352 block_idx: usize,
5353 spec: &ParameterBlockSpec,
5354 d_beta: &Array1<f64>,
5355 ) -> Result<Option<BlockGeometryDirectionalDerivative>, String> {
5356 expect_block_idx_zero(
5357 block_idx,
5358 "bounded linear family",
5359 " for geometry derivative",
5360 )?;
5361 expect_single_block_state(block_states, "bounded linear family")?;
5362 if d_beta.len() != spec.design.ncols() {
5363 return Err(SmoothError::dimension_mismatch(format!(
5364 "bounded linear family geometry derivative direction mismatch: got {}, expected {}",
5365 d_beta.len(),
5366 spec.design.ncols()
5367 ))
5368 .into());
5369 }
5370 let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(&block_states[0].beta);
5371 let mut d_offset = Array1::<f64>::zeros(self.offset.len());
5372 let has_drift = self
5373 .bounded_terms
5374 .iter()
5375 .any(|term| jac_diag[term.col_idx] != 0.0 && d_beta[term.col_idx] != 0.0);
5376 if !has_drift {
5377 return Ok(Some(BlockGeometryDirectionalDerivative {
5378 d_design: None,
5379 d_offset,
5380 }));
5381 }
5382 for term in &self.bounded_terms {
5383 let col = term.col_idx;
5384 let drift = jac_diag[col] * d_beta[col];
5385 if drift != 0.0 {
5386 d_offset.scaled_add(drift, &self.design.column(col));
5387 }
5388 }
5389 Ok(Some(BlockGeometryDirectionalDerivative {
5390 d_design: None,
5391 d_offset,
5392 }))
5393 }
5394}
5395
#[inline]
/// Row count for one streaming chunk of `xt_diag_x_dense` with `p` columns.
///
/// Targets roughly 2 MiB of `f64` scratch per chunk and keeps the result within
/// `[512, 2048]` rows. `p == 0` is treated as one column.
fn dense_diag_gram_chunkrows(p: usize) -> usize {
    const TARGET_BYTES: usize = 2 * 1024 * 1024;
    const ROW_FLOOR: usize = 512;
    const ROW_CEILING: usize = 2048;
    let row_bytes = std::mem::size_of::<f64>() * p.max(1);
    let rows_for_budget = TARGET_BYTES / row_bytes;
    rows_for_budget.max(ROW_FLOOR).min(ROW_CEILING)
}
5404
5405fn xt_diag_x_dense(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
5406 if x.nrows() != w.len() {
5407 return Err(SmoothError::dimension_mismatch("xt_diag_x_dense row mismatch").into());
5408 }
5409 let (n, p) = x.dim();
5410 if n == 0 || p == 0 {
5411 return Ok(Array2::<f64>::zeros((p, p)));
5412 }
5413
5414 const STREAMING_BYTES_THRESHOLD: usize = 8 * 1024 * 1024;
5415 let dense_work_bytes = n
5416 .checked_mul(p)
5417 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
5418 .unwrap_or(usize::MAX);
5419 if dense_work_bytes <= STREAMING_BYTES_THRESHOLD {
5420 let mut weighted = x.to_owned();
5421 ndarray::Zip::from(weighted.rows_mut())
5422 .and(w)
5423 .par_for_each(|mut row, wi| row *= *wi);
5424 return Ok(fast_atb(&x, &weighted));
5425 }
5426
5427 let chunkrows = dense_diag_gram_chunkrows(p).min(n);
5428 let mut weighted_chunk = Array2::<f64>::zeros((chunkrows, p));
5429 let mut out = Array2::<f64>::zeros((p, p));
5430 for row_start in (0..n).step_by(chunkrows) {
5431 let rows = (n - row_start).min(chunkrows);
5432 let x_chunk = x.slice(s![row_start..row_start + rows, ..]);
5433 {
5434 let mut chunk = weighted_chunk.slice_mut(s![0..rows, ..]);
5435 for local_row in 0..rows {
5436 let scale = w[row_start + local_row];
5437 if scale == 0.0 {
5438 chunk.row_mut(local_row).fill(0.0);
5439 continue;
5440 }
5441 for col in 0..p {
5442 chunk[[local_row, col]] = x_chunk[[local_row, col]] * scale;
5443 }
5444 }
5445 }
5446 out += &fast_atb(&x_chunk, &weighted_chunk.slice(s![0..rows, ..]));
5447 }
5448 Ok(out)
5449}
5450
5451fn trace_of_dense_product(a: &Array2<f64>, b: &Array2<f64>) -> Result<f64, String> {
5452 if a.nrows() != a.ncols() || b.nrows() != b.ncols() || a.nrows() != b.nrows() {
5453 return Err(
5454 SmoothError::dimension_mismatch("trace_of_dense_product dimension mismatch").into(),
5455 );
5456 }
5457 let mut trace = 0.0;
5458 for i in 0..a.nrows() {
5459 for j in 0..a.ncols() {
5460 trace += a[[i, j]] * b[[j, i]];
5461 }
5462 }
5463 Ok(trace)
5464}
5465
5466fn exact_bounded_edf(
5467 penalties: &[PenaltySpec],
5468 lambdas: &Array1<f64>,
5469 latent_cov: &Array2<f64>,
5470) -> Result<(Vec<f64>, Vec<f64>, f64), EstimationError> {
5471 if penalties.len() != lambdas.len() {
5472 crate::bail_invalid_estim!(
5473 "bounded EDF penalty/lambda mismatch: {} penalties vs {} lambdas",
5474 penalties.len(),
5475 lambdas.len()
5476 );
5477 }
5478 if latent_cov.nrows() != latent_cov.ncols() {
5479 crate::bail_invalid_estim!("bounded EDF covariance must be square");
5480 }
5481
5482 let p = latent_cov.nrows();
5483 let mut s_lambda = Array2::<f64>::zeros((p, p));
5484 let mut edf_by_block = Vec::with_capacity(penalties.len());
5485 let mut penalty_block_trace = Vec::with_capacity(penalties.len());
5487 let mut trace_sum = 0.0;
5488
5489 for (k, ps) in penalties.iter().enumerate() {
5490 let lambda_k = lambdas[k];
5491 match ps {
5492 PenaltySpec::Block {
5493 local, col_range, ..
5494 } => {
5495 s_lambda
5496 .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
5497 .scaled_add(lambda_k, local);
5498 let penalty_rank =
5500 local
5501 .nrows()
5502 .saturating_sub(estimate_penalty_nullity(local).map_err(|e| {
5503 EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
5504 })?);
5505 let cov_block = latent_cov.slice(ndarray::s![col_range.clone(), col_range.clone()]);
5507 let trace_k = lambda_k
5508 * trace_of_dense_product(&cov_block.to_owned(), local)
5509 .map_err(EstimationError::InvalidInput)?;
5510 trace_sum += trace_k;
5511 penalty_block_trace.push(trace_k);
5512 let p_k = penalty_rank as f64;
5513 edf_by_block.push((p_k - trace_k).clamp(0.0, p_k));
5514 }
5515 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
5516 s_lambda.scaled_add(lambda_k, m);
5517 let penalty_rank = p.saturating_sub(estimate_penalty_nullity(m).map_err(|e| {
5518 EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
5519 })?);
5520 let trace_k = lambda_k
5521 * trace_of_dense_product(latent_cov, m)
5522 .map_err(EstimationError::InvalidInput)?;
5523 trace_sum += trace_k;
5524 penalty_block_trace.push(trace_k);
5525 let p_k = penalty_rank as f64;
5526 edf_by_block.push((p_k - trace_k).clamp(0.0, p_k));
5527 }
5528 }
5529 }
5530
5531 let nullity_total = estimate_penalty_nullity(&s_lambda)
5532 .map_err(|e| EstimationError::InvalidInput(format!("bounded EDF nullity failed: {e}")))?
5533 as f64;
5534 let edf_total = (p as f64 - trace_sum).clamp(nullity_total, p as f64);
5535 Ok((edf_by_block, penalty_block_trace, edf_total))
5536}
5537
5538fn symmetric_positive_definite_inverse_or_pseudo(
5550 precision: &Array2<f64>,
5551) -> Result<Array2<f64>, EstimationError> {
5552 use gam_linalg::faer_ndarray::FaerEigh;
5553 let p = precision.nrows();
5554 if precision.ncols() != p {
5555 crate::bail_invalid_estim!(
5556 "posterior precision inverse requires a square matrix, got {}x{}",
5557 precision.nrows(),
5558 precision.ncols()
5559 );
5560 }
5561 if p == 0 {
5562 return Ok(Array2::<f64>::zeros((0, 0)));
5563 }
5564 let symmetric = (precision + &precision.t().to_owned()) * 0.5;
5565 let (evals, evecs) = symmetric.eigh(faer::Side::Lower).map_err(|e| {
5566 EstimationError::InvalidInput(format!(
5567 "posterior precision eigendecomposition failed: {e}"
5568 ))
5569 })?;
5570 let max_abs_eval = evals.iter().fold(0.0_f64, |acc, &ev| acc.max(ev.abs()));
5571 let tol =
5572 (10.0 * f64::EPSILON * (p as f64) * (p as f64) * max_abs_eval).max(100.0 * f64::EPSILON);
5573 if let Some(&min_eval) = evals
5574 .iter()
5575 .filter(|&&ev| ev < -tol)
5576 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
5577 {
5578 crate::bail_invalid_estim!(
5579 "bounded posterior precision is non-PD at the converged optimum (min eigenvalue \
5580 {min_eval:.6e} < -tol={tol:.6e}); the reported mode is not a strict posterior \
5581 maximum, so a covariance would be meaningless"
5582 );
5583 }
5584 let mut scaled = evecs.clone();
5586 for (j, &ev) in evals.iter().enumerate() {
5587 let inv = if ev > tol { 1.0 / ev } else { 0.0 };
5588 scaled.column_mut(j).mapv_inplace(|v| v * inv);
5589 }
5590 let cov = scaled.dot(&evecs.t());
5591 Ok((&cov + &cov.t().to_owned()) * 0.5)
5592}
5593
5594fn transform_bounded_latent_precision_to_user_internal(
5595 latent_precision: &Array2<f64>,
5596 jac_diag: &Array1<f64>,
5597) -> Result<Array2<f64>, EstimationError> {
5598 let p = latent_precision.nrows();
5599 if latent_precision.ncols() != p || jac_diag.len() != p {
5600 crate::bail_invalid_estim!(
5601 "bounded precision transform dimension mismatch: precision is {}x{}, jacobian has {} entries",
5602 latent_precision.nrows(),
5603 latent_precision.ncols(),
5604 jac_diag.len()
5605 );
5606 }
5607 let mut out = latent_precision.clone();
5608 for i in 0..p {
5609 let scale = jac_diag[i];
5610 if !scale.is_finite() || scale <= 0.0 {
5611 crate::bail_invalid_estim!(
5612 "bounded precision transform requires a positive finite coefficient jacobian; column {i} has {scale}"
5613 );
5614 }
5615 if scale != 1.0 {
5616 out.row_mut(i).mapv_inplace(|v| v / scale);
5617 out.column_mut(i).mapv_inplace(|v| v / scale);
5618 }
5619 }
5620 Ok(out)
5621}
5622
5623fn fit_bounded_term_collection_with_design(
5624 y: ArrayView1<'_, f64>,
5625 weights: ArrayView1<'_, f64>,
5626 offset: ArrayView1<'_, f64>,
5627 spec: &TermCollectionSpec,
5628 design: &TermCollectionDesign,
5629 heuristic_lambdas: Option<&[f64]>,
5630 family: LikelihoodSpec,
5631 options: &FitOptions,
5632) -> Result<FittedTermCollection, EstimationError> {
5633 let conditioning_cols: Vec<usize> = spec
5634 .linear_terms
5635 .iter()
5636 .enumerate()
5637 .filter_map(|(j, linear)| {
5638 (!linear.double_penalty).then_some(design.intercept_range.end + j)
5639 })
5640 .collect();
5641 let conditioning = LinearFitConditioning::from_columns(design, &conditioning_cols);
5642 let dense_design = design.design.to_dense_cow();
5643 let fit_design = conditioning.apply_to_design(&dense_design);
5644 let fit_penalties = conditioning
5645 .transform_blockwise_penalties_to_internal(&design.penalties, design.design.ncols());
5646 if design.linear_constraints.is_some() {
5647 crate::bail_invalid_estim!(
5648 "bounded() terms are not yet compatible with explicit linear constraints"
5649 );
5650 }
5651 let mut bounded_terms = Vec::<BoundedLinearTermMeta>::new();
5652 for (j, term) in spec.linear_terms.iter().enumerate() {
5653 if term.double_penalty
5654 && matches!(
5655 term.coefficient_geometry,
5656 LinearCoefficientGeometry::Bounded { .. }
5657 )
5658 {
5659 crate::bail_invalid_estim!(
5660 "bounded linear term '{}' cannot also use double_penalty",
5661 term.name
5662 );
5663 }
5664 if let LinearCoefficientGeometry::Bounded { min, max, prior } =
5665 term.coefficient_geometry.clone()
5666 {
5667 let col_idx = design.intercept_range.end + j;
5668 let (min_internal, max_internal) = conditioning.internal_bounds_for(col_idx, min, max);
5669 bounded_terms.push(BoundedLinearTermMeta {
5670 col_idx,
5671 min: min_internal,
5672 max: max_internal,
5673 prior,
5674 });
5675 }
5676 }
5677 if bounded_terms.is_empty() {
5678 crate::bail_invalid_estim!("internal bounded fit path called with no bounded terms");
5679 }
5680
5681 let mut designzeroed = fit_design.clone();
5682 let mut initial_beta = Array1::<f64>::zeros(fit_design.ncols());
5683 for term in &bounded_terms {
5684 designzeroed.column_mut(term.col_idx).fill(0.0);
5685 initial_beta[term.col_idx] = bounded_logit(0.5);
5686 }
5687
5688 let initial_log_lambdas = heuristic_lambdas
5689 .map(|vals| Array1::from_vec(vals.to_vec()))
5690 .unwrap_or_else(|| Array1::zeros(fit_penalties.len()));
5691 if initial_log_lambdas.len() != fit_penalties.len() {
5692 crate::bail_invalid_estim!(
5693 "heuristic lambda length mismatch for bounded model: got {}, expected {}",
5694 initial_log_lambdas.len(),
5695 fit_penalties.len()
5696 );
5697 }
5698
5699 let is_beta_logistic = family.is_binomial_beta_logistic();
5700 let family_adapter = BoundedLinearFamily {
5701 family: family.clone(),
5702 latent_cloglog_state: options.latent_cloglog,
5703 mixture_link_state: options
5704 .mixture_link
5705 .clone()
5706 .as_ref()
5707 .map(state_fromspec)
5708 .transpose()
5709 .map_err(EstimationError::InvalidInput)?,
5710 sas_link_state: options
5711 .sas_link
5712 .map(|spec| {
5713 if is_beta_logistic {
5714 state_from_beta_logisticspec(spec)
5715 } else {
5716 state_from_sasspec(spec)
5717 }
5718 })
5719 .transpose()
5720 .map_err(EstimationError::InvalidInput)?,
5721 y: y.to_owned(),
5722 weights: weights.to_owned(),
5723 design: fit_design.clone(),
5724 designzeroed: designzeroed.clone(),
5725 offset: offset.to_owned(),
5726 bounded_terms: bounded_terms.clone(),
5727 };
5728 let blockspec = ParameterBlockSpec {
5729 name: "eta".to_string(),
5730 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(designzeroed)),
5731 offset: offset.to_owned(),
5732 penalties: fit_penalties
5733 .iter()
5734 .map(|ps| match ps {
5735 PenaltySpec::Block {
5736 local, col_range, ..
5737 } => PenaltyMatrix::Blockwise {
5738 local: local.clone(),
5739 col_range: col_range.clone(),
5740 total_dim: design.design.ncols(),
5741 },
5742 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
5743 PenaltyMatrix::Dense(m.clone())
5744 }
5745 })
5746 .collect(),
5747 nullspace_dims: design.nullspace_dims.clone(),
5748 initial_log_lambdas,
5749 initial_beta: Some(initial_beta),
5750 gauge_priority: 100,
5751 jacobian_callback: Some(Arc::new(BoundedEffectiveJacobian {
5757 design: fit_design.clone(),
5758 bounded_terms: bounded_terms.clone(),
5759 })),
5760 stacked_design: None,
5761 stacked_offset: None,
5762 };
5763 let fit = fit_custom_family(
5764 &family_adapter,
5765 &[blockspec],
5766 &BlockwiseFitOptions {
5767 inner_max_cycles: options.max_iter,
5768 inner_tol: options.tol,
5769 outer_max_iter: options.max_iter,
5770 outer_tol: options.tol,
5771 compute_covariance: false,
5781 ..BlockwiseFitOptions::default()
5782 },
5783 )
5784 .map_err(EstimationError::CustomFamily)?;
5785
5786 let latent_beta = fit.block_states[0].beta.clone();
5787 let (beta_user_internal, jac_diag) = family_adapter.user_beta_and_jacobian(&latent_beta);
5788 let beta_user = conditioning.backtransform_beta(&beta_user_internal);
5789
5790 let (eta_state, h_data, _, _) = family_adapter
5791 .evaluation_from_latent(&latent_beta)
5792 .map_err(EstimationError::InvalidInput)?;
5793 let p_fit = fit_design.ncols();
5794 let mut s_lambda_internal = Array2::<f64>::zeros((p_fit, p_fit));
5795 for (k, penalty) in fit_penalties.iter().enumerate() {
5796 match penalty {
5797 PenaltySpec::Block {
5798 local, col_range, ..
5799 } => {
5800 s_lambda_internal
5801 .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
5802 .scaled_add(fit.lambdas[k], local);
5803 }
5804 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
5805 s_lambda_internal.scaled_add(fit.lambdas[k], m);
5806 }
5807 }
5808 }
5809 let mut latent_precision = h_data.clone();
5810 latent_precision += &s_lambda_internal;
5811 let user_precision_internal =
5812 transform_bounded_latent_precision_to_user_internal(&latent_precision, &jac_diag)?;
5813 let penalized_hessian =
5814 conditioning.transform_penalized_hessian_to_original(&user_precision_internal);
5815
5816 let beta_covariance_unscaled = if options.compute_inference {
5844 Some(symmetric_positive_definite_inverse_or_pseudo(
5845 &penalized_hessian,
5846 )?)
5847 } else {
5848 None
5849 };
5850 let latent_cov = if options.compute_inference {
5856 Some(symmetric_positive_definite_inverse_or_pseudo(
5857 &latent_precision,
5858 )?)
5859 } else {
5860 None
5861 };
5862 let s_lambda_original = weighted_blockwise_penalty_sum(
5863 &design.penalties,
5864 fit.lambdas.as_slice().unwrap(),
5865 design.design.ncols(),
5866 );
5867 let penalty_term = beta_user.dot(&s_lambda_original.dot(&beta_user));
5868 let deviance = if family.is_gaussian_identity() {
5869 y.iter()
5870 .zip(eta_state.mu.iter())
5871 .zip(weights.iter())
5872 .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
5873 .sum()
5874 } else {
5875 -2.0 * eta_state.log_likelihood
5876 };
5877 let (edf_by_block, penalty_block_trace, edf_total) = if let Some(cov) = latent_cov.as_ref() {
5878 exact_bounded_edf(&fit_penalties, &fit.lambdas, cov)?
5879 } else {
5880 (
5881 vec![0.0; fit_penalties.len()],
5882 vec![0.0; fit_penalties.len()],
5883 0.0,
5884 )
5885 };
5886
5887 let glm_likelihood = gam_spec::GlmLikelihoodSpec::canonical(family.clone());
5899 let standard_deviation = if family.is_gaussian_identity() {
5900 let denom = if options.compute_inference {
5901 (y.len() as f64 - edf_total).max(1.0)
5902 } else {
5903 (y.len() as f64).max(1.0)
5904 };
5905 (deviance / denom).sqrt()
5906 } else {
5907 1.0
5908 };
5909 let cov_scale = glm_likelihood
5910 .coefficient_covariance_scale(standard_deviation * standard_deviation)
5911 .max(f64::MIN_POSITIVE);
5912 let dispersion = gam_solve::estimate::dispersion_from_likelihood(&glm_likelihood, standard_deviation);
5913 let beta_covariance = beta_covariance_unscaled.map(|mut cov| {
5919 if cov_scale != 1.0 {
5920 cov.mapv_inplace(|v| v * cov_scale);
5921 }
5922 cov
5923 });
5924 let beta_standard_errors = beta_covariance
5925 .as_ref()
5926 .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));
5927
5928 let geometry = Some(gam_solve::estimate::FitGeometry {
5929 penalized_hessian: penalized_hessian.clone().into(),
5930 working_weights: eta_state.fisherweight.clone(),
5931 working_response: {
5932 let mut working_response = eta_state.eta.clone();
5933 for i in 0..working_response.len() {
5934 let wi = eta_state.fisherweight[i].max(1e-12);
5935 working_response[i] += eta_state.score[i] / wi;
5936 }
5937 working_response
5938 },
5939 });
5940 let max_abs_eta = eta_state
5941 .eta
5942 .iter()
5943 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
5944 Ok(FittedTermCollection {
5945 fit: {
5946 let log_lambdas = fit.lambdas.mapv(|v| v.max(1e-300).ln());
5947 let inf = FitInference {
5948 edf_by_block,
5949 penalty_block_trace,
5950 edf_total,
5951 smoothing_correction: None,
5952 penalized_hessian: penalized_hessian.clone().into(),
5955 working_weights: eta_state.fisherweight.clone(),
5956 working_response: {
5957 let mut working_response = eta_state.eta.clone();
5958 for i in 0..working_response.len() {
5959 let wi = eta_state.fisherweight[i].max(1e-12);
5960 working_response[i] += eta_state.score[i] / wi;
5961 }
5962 working_response
5963 },
5964 reparam_qs: None,
5965 dispersion,
5966 beta_covariance: beta_covariance
5967 .clone()
5968 .map(gam_problem::dispersion_cov::PhiScaledCovariance::from),
5969 beta_standard_errors,
5970 beta_covariance_corrected: None,
5971 beta_standard_errors_corrected: None,
5972 beta_covariance_frequentist: None,
5973 coefficient_influence: None,
5974 weighted_gram: None,
5975 bias_correction_beta: None,
5976 };
5977 let covariance_conditional = beta_covariance;
5978 let pirls_status_val = if fit.outer_converged {
5979 gam_solve::pirls::PirlsStatus::Converged
5980 } else {
5981 gam_solve::pirls::PirlsStatus::StalledAtValidMinimum
5982 };
5983 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
5984 blocks: vec![gam_solve::estimate::FittedBlock {
5985 beta: beta_user.clone(),
5986 role: gam_problem::BlockRole::Mean,
5987 edf: edf_total,
5988 lambdas: fit.lambdas.clone(),
5989 }],
5990 log_lambdas,
5991 lambdas: fit.lambdas,
5992 likelihood_scale: family.default_scale_metadata(),
5993 likelihood_family: Some(family),
5994 log_likelihood_normalization:
5995 gam_spec::LogLikelihoodNormalization::UserProvided,
5996 log_likelihood: eta_state.log_likelihood,
5997 deviance,
5998 reml_score: fit.penalized_objective,
5999 stable_penalty_term: penalty_term,
6000 penalized_objective: fit.penalized_objective,
6001 used_device: false,
6002 outer_iterations: fit.outer_iterations,
6003 outer_converged: fit.outer_converged,
6004 outer_gradient_norm: fit.outer_gradient_norm,
6005 standard_deviation,
6006 covariance_conditional,
6007 covariance_corrected: None,
6008 inference: Some(inf),
6009 fitted_link: gam_solve::estimate::FittedLinkState::Standard(None),
6010 geometry,
6011 block_states: Vec::new(),
6012 pirls_status: pirls_status_val,
6013 max_abs_eta,
6014 constraint_kkt: None,
6015 artifacts: gam_solve::estimate::FitArtifacts {
6016 pirls: None,
6017 ..Default::default()
6018 },
6019 inner_cycles: 0,
6020 })?
6021 },
6022 design: design.clone(),
6023 adaptive_diagnostics: None,
6024 })
6025}
6026
6027fn enforce_term_constraint_feasibility(
6028 design: &TermCollectionDesign,
6029 fit: &UnifiedFitResult,
6030) -> Result<(), EstimationError> {
6031 const CONSTRAINT_FEASIBILITY_RAW_TOL: f64 = 1e-7;
6045 let tol = CONSTRAINT_FEASIBILITY_RAW_TOL;
6046 let smooth_start = design
6047 .design
6048 .ncols()
6049 .saturating_sub(design.smooth.total_smooth_cols());
6050 let mut violations: Vec<String> = Vec::new();
6051 for term in &design.smooth.terms {
6052 let gr = (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
6053 let beta_local = fit.beta.slice(s![gr.clone()]).to_owned();
6054 if let Some(lb) = term.lower_bounds_local.as_ref() {
6055 let mut worst = 0.0_f64;
6056 let mut worst_idx = 0usize;
6057 for i in 0..lb.len().min(beta_local.len()) {
6058 if lb[i].is_finite() {
6059 let viol = (lb[i] - beta_local[i]).max(0.0);
6060 if viol > worst {
6061 worst = viol;
6062 worst_idx = i;
6063 }
6064 }
6065 }
6066 if worst > tol {
6067 violations.push(format!(
6068 "term='{}' kind=lower-bound maxviolation={:.3e} coeff_index={}",
6069 term.name, worst, worst_idx
6070 ));
6071 }
6072 }
6073 if let Some(lin) = term.linear_constraints_local.as_ref() {
6074 let mut worst = 0.0_f64;
6075 let mut worstrow = 0usize;
6076 for i in 0..lin.a.nrows() {
6077 let norm = lin.a.row(i).dot(&lin.a.row(i)).sqrt();
6078 let inv = if norm > 0.0 { 1.0 / norm } else { 0.0 };
6079 let s = (lin.a.row(i).dot(&beta_local) - lin.b[i]) * inv;
6080 let viol = (-s).max(0.0);
6081 if viol > worst {
6082 worst = viol;
6083 worstrow = i;
6084 }
6085 }
6086 if worst > tol {
6087 violations.push(format!(
6088 "term='{}' kind=linear-inequality maxviolation={:.3e} row={}",
6089 term.name, worst, worstrow
6090 ));
6091 }
6092 }
6093 }
6094
6095 if !violations.is_empty() {
6096 let mut msg = format!(
6097 "constraint violation after fit ({} violating term constraints): {}",
6098 violations.len(),
6099 violations.join(" | ")
6100 );
6101 if let Some(kkt) = fit.constraint_kkt.as_ref() {
6102 msg.push_str(&format!(
6103 "; KKT[primal={:.3e}, dual={:.3e}, comp={:.3e}, stat={:.3e}]",
6104 kkt.primal_feasibility, kkt.dual_feasibility, kkt.complementarity, kkt.stationarity
6105 ));
6106 }
6107 return Err(EstimationError::ParameterConstraintViolation(msg));
6108 }
6109 Ok(())
6110}
6111
6112fn stratified_spatial_subsample(
6113 data: ArrayView2<'_, f64>,
6114 spec: &TermCollectionSpec,
6115 target_size: usize,
6116) -> Vec<usize> {
6117 use rand::SeedableRng;
6118 use rand::rngs::StdRng;
6119 use rand::seq::SliceRandom;
6120
6121 let n = data.nrows();
6122 if n <= target_size {
6123 return (0..n).collect();
6124 }
6125
6126 let spatial_cols: Option<Vec<usize>> =
6127 spec.smooth_terms.iter().find_map(|term| match &term.basis {
6128 SmoothBasisSpec::ThinPlate { feature_cols, .. }
6129 | SmoothBasisSpec::Matern { feature_cols, .. }
6130 | SmoothBasisSpec::Duchon { feature_cols, .. } => {
6131 if !feature_cols.is_empty() {
6132 Some(feature_cols.clone())
6133 } else {
6134 None
6135 }
6136 }
6137 _ => None,
6138 });
6139
6140 let cols = match spatial_cols {
6141 Some(c) if !c.is_empty() => c,
6142 _ => {
6143 let mut rng = StdRng::seed_from_u64(spatial_subsample_seed(data, &[], target_size));
6144 let mut indices: Vec<usize> = (0..n).collect();
6145 indices.shuffle(&mut rng);
6146 indices.truncate(target_size);
6147 indices.sort_unstable();
6148 return indices;
6149 }
6150 };
6151 let mut rng = StdRng::seed_from_u64(spatial_subsample_seed(data, &cols, target_size));
6152
6153 let d = cols.len();
6154 let mut mins = vec![f64::INFINITY; d];
6155 let mut maxs = vec![f64::NEG_INFINITY; d];
6156 for i in 0..n {
6157 for (ax, &col) in cols.iter().enumerate() {
6158 let v = data[[i, col]];
6159 if v < mins[ax] {
6160 mins[ax] = v;
6161 }
6162 if v > maxs[ax] {
6163 maxs[ax] = v;
6164 }
6165 }
6166 }
6167
6168 const TARGET_POINTS_PER_CELL: usize = 5;
6172 let total_cells_target = (target_size / TARGET_POINTS_PER_CELL).max(1);
6173 let cells_per_axis = ((total_cells_target as f64).powf(1.0 / d as f64)).ceil() as usize;
6174 let cells_per_axis = cells_per_axis.max(1);
6175
6176 let mut cell_members: std::collections::HashMap<Vec<usize>, Vec<usize>> =
6177 std::collections::HashMap::new();
6178 for i in 0..n {
6179 let mut cell_key = Vec::with_capacity(d);
6180 for (ax, &col) in cols.iter().enumerate() {
6181 let range = maxs[ax] - mins[ax];
6182 let cell = if range <= 0.0 {
6183 0
6184 } else {
6185 let frac = (data[[i, col]] - mins[ax]) / range;
6186 (frac * cells_per_axis as f64).floor() as usize
6187 };
6188 cell_key.push(cell.min(cells_per_axis - 1));
6189 }
6190 cell_members.entry(cell_key).or_default().push(i);
6191 }
6192
6193 let mut selected: Vec<usize> = Vec::with_capacity(target_size);
6194 let mut remaining_budget = target_size;
6195 let mut remaining_population = n;
6196
6197 let mut cells: Vec<(Vec<usize>, Vec<usize>)> = cell_members.into_iter().collect();
6198 cells.sort_by(|a, b| a.0.cmp(&b.0));
6199
6200 for (_, members) in &mut cells {
6201 if remaining_budget == 0 {
6202 break;
6203 }
6204 let alloc = ((members.len() as f64 / remaining_population as f64) * remaining_budget as f64)
6205 .round() as usize;
6206 let alloc = alloc.max(1).min(members.len()).min(remaining_budget);
6207 members.shuffle(&mut rng);
6208 selected.extend_from_slice(&members[..alloc]);
6209 remaining_budget = remaining_budget.saturating_sub(alloc);
6210 remaining_population = remaining_population.saturating_sub(members.len());
6211 }
6212
6213 if selected.len() > target_size {
6214 selected.shuffle(&mut rng);
6215 selected.truncate(target_size);
6216 }
6217
6218 selected.sort_unstable();
6219 selected
6220}
6221
6222fn spatial_subsample_seed(
6223 data: ArrayView2<'_, f64>,
6224 spatial_cols: &[usize],
6225 target_size: usize,
6226) -> u64 {
6227 let mut state = 0x5350_4154_4941_4C53_u64;
6228 spatial_seed_mix(&mut state, data.nrows() as u64);
6229 spatial_seed_mix(&mut state, data.ncols() as u64);
6230 spatial_seed_mix(&mut state, target_size as u64);
6231 spatial_seed_mix(&mut state, spatial_cols.len() as u64);
6232 for &col in spatial_cols {
6233 spatial_seed_mix(&mut state, col as u64);
6234 }
6235
6236 if data.nrows() > 0 {
6237 let mid = data.nrows() / 2;
6238 let last = data.nrows() - 1;
6239 for &row in &[0usize, mid, last] {
6240 for &col in spatial_cols {
6241 let value = data[[row, col]];
6242 spatial_seed_mix(&mut state, value.to_bits());
6243 }
6244 }
6245 }
6246 state
6247}
6248
6249#[inline]
6250fn spatial_seed_mix(state: &mut u64, value: u64) {
6251 let mut s = value.wrapping_add(*state);
6254 let z = gam_linalg::utils::splitmix64(&mut s);
6255 *state ^= z;
6256 *state = (*state).rotate_left(27).wrapping_mul(0x3C79_AC49_2BA7_B653);
6257}
6258
6259fn sampled_rows(data: ArrayView2<'_, f64>, indices: &[usize]) -> Array2<f64> {
6260 let mut sampled = Array2::<f64>::zeros((indices.len(), data.ncols()));
6261 for (new_row, &orig_row) in indices.iter().enumerate() {
6262 sampled.row_mut(new_row).assign(&data.row(orig_row));
6263 }
6264 sampled
6265}
6266
6267fn spatial_term_user_centers(term: &SmoothTermSpec) -> Option<ArrayView2<'_, f64>> {
6268 match spatial_term_center_strategy(term) {
6269 Some(CenterStrategy::UserProvided(centers)) => Some(centers.view()),
6270 _ => None,
6271 }
6272}
6273
6274fn finite_centered_axis_contrasts(values: &[f64], expected_dim: usize) -> Option<Vec<f64>> {
6275 if values.len() != expected_dim || expected_dim <= 1 {
6276 return None;
6277 }
6278 if values.iter().any(|value| !value.is_finite()) {
6279 return None;
6280 }
6281 Some(center_aniso_log_scales(values))
6282}
6283
6284fn blended_pilot_axis_contrasts(
6285 pilot_data: ArrayView2<'_, f64>,
6286 term: &SmoothTermSpec,
6287 centers: ArrayView2<'_, f64>,
6288) -> Option<Vec<f64>> {
6289 let d = centers.ncols();
6290 if d <= 1 {
6291 return None;
6292 }
6293 let center_eta = initial_aniso_contrasts(centers);
6294 let data_eta = standardized_spatial_term_data(pilot_data, term)
6295 .ok()
6296 .and_then(|x| finite_centered_axis_contrasts(&initial_aniso_contrasts(x.view()), d));
6297 let center_eta = finite_centered_axis_contrasts(¢er_eta, d)?;
6298 let blended = match data_eta {
6299 Some(data_eta) => center_eta
6300 .iter()
6301 .zip(data_eta.iter())
6302 .map(|(&from_centers, &from_data)| 0.5 * (from_centers + from_data))
6303 .collect::<Vec<_>>(),
6304 None => center_eta,
6305 };
6306 finite_centered_axis_contrasts(&blended, d)
6307}
6308
6309fn apply_pilot_spatial_psi_reseed(
6310 pilot_data: ArrayView2<'_, f64>,
6311 spec: &TermCollectionSpec,
6312 spatial_terms: &[usize],
6313 kappa_options: &SpatialLengthScaleOptimizationOptions,
6314) -> Result<TermCollectionSpec, EstimationError> {
6315 let dims_per_term = spatial_dims_per_term(spec, spatial_terms);
6316 let use_aniso = has_aniso_terms(spec, spatial_terms);
6317 let log_kappa0 = if use_aniso {
6318 SpatialLogKappaCoords::from_length_scales_aniso(spec, spatial_terms, kappa_options)
6319 } else {
6320 SpatialLogKappaCoords::from_length_scales(spec, spatial_terms, kappa_options)
6321 };
6322 let log_kappa0 = log_kappa0.reseed_from_data(pilot_data, spec, spatial_terms, kappa_options);
6323 let log_kappa_lower = if use_aniso {
6324 SpatialLogKappaCoords::lower_bounds_aniso_from_data(
6325 pilot_data,
6326 spec,
6327 spatial_terms,
6328 &dims_per_term,
6329 kappa_options,
6330 )
6331 } else {
6332 SpatialLogKappaCoords::lower_bounds_from_data(
6333 pilot_data,
6334 spec,
6335 spatial_terms,
6336 kappa_options,
6337 )
6338 };
6339 let log_kappa_upper = if use_aniso {
6340 SpatialLogKappaCoords::upper_bounds_aniso_from_data(
6341 pilot_data,
6342 spec,
6343 spatial_terms,
6344 &dims_per_term,
6345 kappa_options,
6346 )
6347 } else {
6348 SpatialLogKappaCoords::upper_bounds_from_data(
6349 pilot_data,
6350 spec,
6351 spatial_terms,
6352 kappa_options,
6353 )
6354 };
6355 log_kappa0
6356 .clamp_to_bounds(&log_kappa_lower, &log_kappa_upper)
6357 .apply_tospec(spec, spatial_terms)
6358}
6359
6360pub(crate) fn apply_spatial_anisotropy_pilot_initializer(
6361 data: ArrayView2<'_, f64>,
6362 spec: &mut TermCollectionSpec,
6363 spatial_terms: &[usize],
6364 target_size: usize,
6365 kappa_options: &SpatialLengthScaleOptimizationOptions,
6366) -> usize {
6367 if target_size == 0 || data.nrows() <= target_size.saturating_mul(2) || spatial_terms.is_empty()
6368 {
6369 return 0;
6370 }
6371 if !has_aniso_terms(spec, spatial_terms) {
6372 return 0;
6373 }
6374 let indices = stratified_spatial_subsample(data, spec, target_size);
6375 let pilot_data = sampled_rows(data, &indices);
6376 let mut working = spec.clone();
6377 let mut updated_terms = 0usize;
6378 const GEOMETRY_UPDATES: usize = 2;
6379
6380 for pass in 0..GEOMETRY_UPDATES {
6381 let planned_terms = match plan_joint_spatial_centers_for_term_blocks(
6382 pilot_data.view(),
6383 &[working.smooth_terms.clone()],
6384 )
6385 .and_then(|mut blocks| {
6386 blocks.pop().ok_or_else(|| {
6387 BasisError::InvalidInput(
6388 "pilot geometry initializer produced no smooth-term block".to_string(),
6389 )
6390 })
6391 }) {
6392 Ok(terms) => terms,
6393 Err(err) => {
6394 log::warn!(
6395 "[spatial-kappa] pilot geometry initializer skipped after center planning failed: {err}"
6396 );
6397 return updated_terms;
6398 }
6399 };
6400
6401 for &term_idx in spatial_terms {
6402 let Some(current_eta) = get_spatial_aniso_log_scales(&working, term_idx) else {
6403 continue;
6404 };
6405 let Some(d) = get_spatial_feature_dim(&working, term_idx) else {
6406 continue;
6407 };
6408 if d <= 1 || current_eta.len() != d {
6409 continue;
6410 }
6411 let Some(planned_term) = planned_terms.get(term_idx) else {
6412 continue;
6413 };
6414 let Some(centers) = spatial_term_user_centers(planned_term) else {
6415 continue;
6416 };
6417 let Some(eta) = blended_pilot_axis_contrasts(pilot_data.view(), planned_term, centers)
6418 else {
6419 continue;
6420 };
6421 if set_spatial_aniso_log_scales(&mut working, term_idx, eta).is_ok() {
6422 updated_terms += usize::from(pass == 0);
6423 }
6424 }
6425
6426 match apply_pilot_spatial_psi_reseed(
6427 pilot_data.view(),
6428 &working,
6429 spatial_terms,
6430 kappa_options,
6431 ) {
6432 Ok(updated) => {
6433 working = updated;
6434 }
6435 Err(err) => {
6436 log::warn!(
6437 "[spatial-kappa] pilot geometry ψ reseed skipped after deterministic initializer error: {err}"
6438 );
6439 break;
6440 }
6441 }
6442 }
6443
6444 if updated_terms > 0 {
6445 log::info!(
6446 "[spatial-kappa] initialized anisotropy from {}-row pilot geometry for {} spatial term(s); proceeding to full-data optimization",
6447 indices.len(),
6448 updated_terms
6449 );
6450 *spec = working;
6451 }
6452 updated_terms
6453}
6454
6455pub(crate) fn spatial_length_scale_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
6456 spec.smooth_terms
6457 .iter()
6458 .enumerate()
6459 .filter_map(|(idx, _)| spatial_term_supports_hyper_optimization(spec, idx).then_some(idx))
6460 .collect()
6461}
6462
6463fn fit_score(fit: &UnifiedFitResult) -> f64 {
6475 if fit.reml_score.is_finite() {
6476 return fit.reml_score;
6477 }
6478 let score = 0.5 * fit.deviance + 0.5 * fit.stable_penalty_term;
6479 if score.is_finite() {
6480 score
6481 } else {
6482 f64::INFINITY
6483 }
6484}
6485
6486fn is_recoverable_trial_point_error(err: &EstimationError) -> bool {
6508 matches!(err, EstimationError::BasisError(_))
6509 || err.is_inner_solve_retreat()
6510 || is_recoverable_fit_inference_finiteness_error(err)
6511}
6512
6513fn is_recoverable_fit_inference_finiteness_error(err: &EstimationError) -> bool {
6514 let EstimationError::InvalidInput(message) = err else {
6515 return false;
6516 };
6517
6518 message.contains("must be finite")
6519 && [
6520 "fit_result.beta_covariance_frequentist",
6521 "fit_result.coefficient_influence",
6522 "fit_result.weighted_gram",
6523 ]
6524 .iter()
6525 .any(|field| message.contains(field))
6526}
6527
#[cfg(test)]
mod spatial_trial_recovery_tests {
    use super::*;

    #[test]
    fn nonfinite_frequentist_covariance_is_recoverable_trial_point() {
        let err = EstimationError::InvalidInput(
            "fit_result.beta_covariance_frequentist[0] must be finite, got NaN".to_string(),
        );

        assert!(
            is_recoverable_trial_point_error(&err),
            "singular trial-point curvature should make spatial κ retreat, not abort"
        );
    }

    #[test]
    fn arbitrary_invalid_input_remains_fatal_trial_point_error() {
        let err = EstimationError::InvalidInput("outer rho bounds are invalid".to_string());

        assert!(
            !is_recoverable_trial_point_error(&err),
            "the spatial κ recovery gate must not mask unrelated invalid inputs"
        );
    }

    /// Every inference field on the allow-list is recoverable when non-finite,
    /// not just the frequentist covariance.
    #[test]
    fn each_listed_inference_field_is_recoverable_when_nonfinite() {
        for field in [
            "fit_result.beta_covariance_frequentist",
            "fit_result.coefficient_influence",
            "fit_result.weighted_gram",
        ] {
            let err =
                EstimationError::InvalidInput(format!("{field}[3, 1] must be finite, got inf"));
            assert!(
                is_recoverable_fit_inference_finiteness_error(&err),
                "{field} must be recognized as a recoverable non-finite inference field"
            );
        }
    }

    /// A listed field is only recoverable for finiteness failures, not for
    /// other validation errors.
    #[test]
    fn listed_field_without_finiteness_failure_is_not_recoverable() {
        let err =
            EstimationError::InvalidInput("fit_result.weighted_gram has wrong shape".to_string());
        assert!(!is_recoverable_fit_inference_finiteness_error(&err));
    }

    /// A finiteness failure in a field outside the allow-list stays fatal.
    #[test]
    fn nonfinite_unlisted_field_is_not_recoverable() {
        let err = EstimationError::InvalidInput(
            "fit_result.beta[0] must be finite, got NaN".to_string(),
        );
        assert!(!is_recoverable_fit_inference_finiteness_error(&err));
    }
}
6554
6555fn require_successful_spatial_optimization_result<T>(
6556 initial_score: f64,
6557 result: Result<Option<(T, f64)>, EstimationError>,
6558) -> Result<T, EstimationError> {
6559 match result {
6560 Ok(Some((value, exact_score))) => {
6561 const SCORE_DRIFT_ABS_TOL: f64 = 1e-6;
6570 const SCORE_DRIFT_REL_TOL: f64 = 1e-8;
6571 let tol = SCORE_DRIFT_ABS_TOL.max(initial_score.abs() * SCORE_DRIFT_REL_TOL);
6572 if exact_score <= initial_score + tol {
6573 Ok(value)
6574 } else {
6575 Err(EstimationError::RemlOptimizationFailed(format!(
6576 "spatial kappa optimization made REML score worse ({initial_score:.6e} -> {exact_score:.6e})"
6577 )))
6578 }
6579 }
6580 Ok(None) => Err(EstimationError::RemlOptimizationFailed(
6581 "spatial kappa optimization is unavailable for one or more eligible spatial terms"
6582 .to_string(),
6583 )),
6584 Err(err) => Err(EstimationError::RemlOptimizationFailed(format!(
6585 "spatial kappa optimization failed: {err}"
6586 ))),
6587 }
6588}
6589
6590fn external_opts_for_design(
6591 family: &LikelihoodSpec,
6592 design: &TermCollectionDesign,
6593 options: &FitOptions,
6594) -> ExternalOptimOptions {
6595 ExternalOptimOptions {
6596 family: family.clone(),
6597 latent_cloglog: options.latent_cloglog,
6598 mixture_link: options.mixture_link.clone(),
6599 optimize_mixture: options.optimize_mixture,
6600 sas_link: options.sas_link,
6601 optimize_sas: options.optimize_sas,
6602 compute_inference: options.compute_inference,
6603 skip_rho_posterior_inference: options.skip_rho_posterior_inference,
6604 max_iter: options.max_iter,
6605 tol: options.tol,
6606 nullspace_dims: design.nullspace_dims.clone(),
6607 linear_constraints: design.linear_constraints.clone(),
6608 firth_bias_reduction: Some(options.firth_bias_reduction),
6609 penalty_shrinkage_floor: options.penalty_shrinkage_floor,
6610 rho_prior: options.rho_prior.clone(),
6611 kronecker_penalty_system: design.kronecker_penalty_system(),
6614 kronecker_factored: design
6615 .smooth
6616 .terms
6617 .iter()
6618 .find_map(|t| t.kronecker_factored.clone()),
6619 persist_warm_start_disk: options.persist_warm_start_disk,
6620 }
6621}
6622
6623fn evaluate_joint_reml_outer_eval_at_theta(
6631 evaluator: &mut gam_solve::estimate::ExternalJointHyperEvaluator<'_>,
6632 design: &TermCollectionDesign,
6633 theta: &Array1<f64>,
6634 rho_dim: usize,
6635 hyper_dirs: Vec<gam_solve::estimate::reml::DirectionalHyperParam>,
6636 warm_start_beta: Option<ArrayView1<'_, f64>>,
6637 order: gam_solve::rho_optimizer::OuterEvalOrder,
6638 design_revision: Option<u64>,
6639) -> Result<
6640 (
6641 f64,
6642 Array1<f64>,
6643 gam_problem::HessianResult,
6644 ),
6645 EstimationError,
6646> {
6647 evaluator.evaluate_with_order(
6648 &design.design,
6649 &design.penalties,
6650 &design.nullspace_dims,
6651 design.linear_constraints.clone(),
6652 theta,
6653 rho_dim,
6654 hyper_dirs,
6655 warm_start_beta,
6656 "evaluate_joint_reml_outer_eval_at_theta",
6657 order,
6658 design_revision,
6659 )
6660}
6661
6662fn evaluate_joint_reml_efs_at_theta(
6663 evaluator: &mut gam_solve::estimate::ExternalJointHyperEvaluator<'_>,
6664 design: &TermCollectionDesign,
6665 theta: &Array1<f64>,
6666 rho_dim: usize,
6667 hyper_dirs: Vec<gam_solve::estimate::reml::DirectionalHyperParam>,
6668 warm_start_beta: Option<ArrayView1<'_, f64>>,
6669 design_revision: Option<u64>,
6670) -> Result<gam_problem::EfsEval, EstimationError> {
6671 evaluator.evaluate_efs(
6672 &design.design,
6673 &design.penalties,
6674 &design.nullspace_dims,
6675 design.linear_constraints.clone(),
6676 theta,
6677 rho_dim,
6678 hyper_dirs,
6679 warm_start_beta,
6680 "evaluate_joint_reml_efs_at_theta",
6681 design_revision,
6682 )
6683}
6684
6685fn exact_joint_spatial_outer_hessian_available(
6686 family: &LikelihoodSpec,
6687 design: &TermCollectionDesign,
6688) -> bool {
6689 let family_supported = match &family.response {
6712 ResponseFamily::Gaussian
6713 | ResponseFamily::Binomial
6714 | ResponseFamily::Poisson
6715 | ResponseFamily::Tweedie { .. }
6716 | ResponseFamily::NegativeBinomial { .. }
6717 | ResponseFamily::Beta { .. }
6718 | ResponseFamily::Gamma
6719 | ResponseFamily::RoystonParmar => true,
6720 };
6721 family_supported && design.design.ncols() > 0
6724}
6725
6726fn smooth_term_penalty_index(
6727 spec: &TermCollectionSpec,
6728 design: &TermCollectionDesign,
6729 term_idx: usize,
6730) -> Option<usize> {
6731 if term_idx >= design.smooth.terms.len() || term_idx >= spec.smooth_terms.len() {
6732 return None;
6733 }
6734 if design.smooth.terms[term_idx].penalties_local.is_empty() {
6735 return None;
6736 }
6737 let linear_penalties = spec
6738 .linear_terms
6739 .iter()
6740 .filter(|t| t.double_penalty)
6741 .count()
6742 * 2;
6743 let random_penalties = design
6744 .random_effect_ranges
6745 .iter()
6746 .filter(|(_, range)| !range.is_empty())
6747 .count();
6748 let smooth_offset = linear_penalties + random_penalties;
6749 let local_offset = design
6750 .smooth
6751 .terms
6752 .iter()
6753 .take(term_idx)
6754 .map(|term| term.penalties_local.len())
6755 .sum::<usize>();
6756 Some(smooth_offset + local_offset)
6757}
6758
6759fn try_build_spatial_term_log_kappa_derivativeinfo(
6760 data: ArrayView2<'_, f64>,
6761 resolvedspec: &TermCollectionSpec,
6762 design: &TermCollectionDesign,
6763 term_idx: usize,
6764) -> Result<Option<SpatialPsiDerivative>, EstimationError> {
6765 let Some((
6766 global_range,
6767 total_p,
6768 x_psi_local,
6769 s_psi_local_check,
6770 x_psi_psi_local,
6771 s_psi_psi_local,
6772 s_psi_components_local,
6773 s_psi_psi_components_local,
6774 implicit_operator,
6775 )) = try_build_spatial_term_log_kappa_derivative(data, resolvedspec, design, term_idx)?
6776 else {
6777 return Ok(None);
6778 };
6779 let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
6780 return Ok(None);
6781 };
6782 if s_psi_components_local.is_empty() || s_psi_psi_components_local.is_empty() {
6783 return Ok(None);
6784 }
6785 if s_psi_components_local.len() != s_psi_psi_components_local.len() {
6786 return Ok(None);
6787 }
6788 let penalty_indices = (0..s_psi_components_local.len())
6789 .map(|j| penalty_start + j)
6790 .collect::<Vec<_>>();
6791 let penalty_index = penalty_indices[0];
6792 if s_psi_local_check.nrows() == 0 || s_psi_psi_local.nrows() == 0 {
6793 return Ok(None);
6794 }
6795 Ok(Some(SpatialPsiDerivative {
6796 penalty_index,
6797 penalty_indices,
6798 global_range,
6799 total_p,
6800 x_psi_local,
6801 s_psi_components_local,
6802 x_psi_psi_local,
6803 s_psi_psi_components_local,
6804 aniso_group_id: None,
6805 aniso_cross_designs: None,
6806 aniso_cross_penalty_provider: None,
6807 implicit_operator,
6808 implicit_axis: 0,
6809 }))
6810}
6811
6812pub(crate) fn try_build_spatial_log_kappa_derivativeinfo_list(
6813 data: ArrayView2<'_, f64>,
6814 resolvedspec: &TermCollectionSpec,
6815 design: &TermCollectionDesign,
6816 spatial_terms: &[usize],
6817) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
6818 let mut out = Vec::new();
6819 let mut aniso_gid = 0usize;
6820 for &term_idx in spatial_terms {
6821 if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
6822 if let Some(entries) = try_build_spatial_term_log_kappa_aniso_derivativeinfos(
6823 data,
6824 resolvedspec,
6825 design,
6826 term_idx,
6827 aniso_gid,
6828 )? {
6829 aniso_gid += 1;
6830 out.extend(entries);
6831 continue;
6832 } else {
6833 return Ok(None);
6834 }
6835 }
6836 let Some(info) =
6837 try_build_spatial_term_log_kappa_derivativeinfo(data, resolvedspec, design, term_idx)?
6838 else {
6839 return Ok(None);
6840 };
6841 out.push(info);
6842 }
6843 Ok(Some(out))
6844}
6845
6846fn try_build_spatial_term_log_kappa_aniso_derivativeinfos(
6848 data: ArrayView2<'_, f64>,
6849 resolvedspec: &TermCollectionSpec,
6850 design: &TermCollectionDesign,
6851 term_idx: usize,
6852 aniso_group_id: usize,
6853) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
6854 let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
6855 return Ok(None);
6856 };
6857 let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
6858 return Ok(None);
6859 };
6860 let mut aniso_result = match &termspec.basis {
6861 SmoothBasisSpec::Sphere { .. } => return Ok(None),
6862 SmoothBasisSpec::Matern {
6863 feature_cols,
6864 spec,
6865 input_scales,
6866 } => {
6867 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
6868 if let Some(s) = input_scales {
6869 apply_input_standardization(&mut x, s);
6870 }
6871 let mut spec_operator = spec.clone();
6880 spec_operator.double_penalty = false;
6881 build_matern_basis_log_kappa_aniso_derivatives(x.view(), &spec_operator)
6882 .map_err(EstimationError::from)?
6883 }
6884 SmoothBasisSpec::MeasureJet {
6890 feature_cols,
6891 spec,
6892 input_scales,
6893 } => {
6894 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
6895 if let Some(s) = input_scales {
6896 apply_input_standardization(&mut x, s);
6897 }
6898 build_measure_jet_basis_psi_derivatives(x.view(), spec)
6899 .map_err(EstimationError::from)?
6900 }
6901 _ => return Ok(None),
6902 };
6903 let d = if let Some(ref op) = aniso_result.implicit_operator {
6906 op.n_axes()
6907 } else if !aniso_result.design_first.is_empty() {
6908 aniso_result.design_first.len()
6909 } else {
6910 0
6911 };
6912 if d == 0 {
6913 return Ok(None);
6914 }
6915 let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
6916 return Ok(None);
6917 };
6918 let p_total = design.design.ncols();
6919 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
6920 let global_range = (smooth_start + smooth_term.coeff_range.start)
6921 ..(smooth_start + smooth_term.coeff_range.end);
6922 let num_penalties = aniso_result.penalties_first[0].len();
6923 let penalty_indices: Vec<usize> = (0..num_penalties).map(|j| penalty_start + j).collect();
6924 let penalties_cross_provider = aniso_result.penalties_cross_provider.clone();
6925
6926 let use_implicit_design = aniso_result.design_first.is_empty();
6930 let implicit_op_arc = aniso_result
6931 .implicit_operator
6932 .as_ref()
6933 .map(|op| std::sync::Arc::new(op.clone()));
6934
6935 let mut entries = Vec::with_capacity(d);
6936 for a in 0..d {
6937 let (x_psi_local, x_psi_psi_local) = if use_implicit_design {
6938 (Array2::<f64>::zeros((0, 0)), Array2::<f64>::zeros((0, 0)))
6944 } else {
6945 let x_first = std::mem::take(&mut aniso_result.design_first[a]);
6950 let x_second = std::mem::take(&mut aniso_result.design_second_diag[a]);
6951 if x_first.ncols() != smooth_term.coeff_range.len() {
6952 return Ok(None);
6953 }
6954 (x_first, x_second)
6955 };
6956 let s_psi_components = std::mem::take(&mut aniso_result.penalties_first[a]);
6957 let s_psi_psi_components = std::mem::take(&mut aniso_result.penalties_second_diag[a]);
6958 let cross_designs = if implicit_op_arc.is_some() {
6964 let mut cd = Vec::with_capacity(d - 1);
6965 for b in 0..d {
6966 if b == a {
6967 continue;
6968 }
6969 cd.push((b, Array2::<f64>::zeros((0, 0))));
6970 }
6971 cd
6972 } else if !aniso_result.design_second_cross.is_empty() {
6973 let mut cd = Vec::new();
6974 for (cross_idx, &(pa, pb)) in aniso_result.design_second_cross_pairs.iter().enumerate()
6975 {
6976 if pa == a {
6977 cd.push((pb, aniso_result.design_second_cross[cross_idx].clone()));
6978 } else if pb == a {
6979 cd.push((pa, aniso_result.design_second_cross[cross_idx].clone()));
6980 }
6981 }
6982 cd
6983 } else {
6984 Vec::new()
6985 };
6986 let cross_penalty_provider = if d > 1 {
6987 let penalties_cross_provider = penalties_cross_provider.clone();
6988 Some(std::sync::Arc::new(
6989 move |b_axis: usize| -> Result<Vec<Array2<f64>>, EstimationError> {
6990 if b_axis == a {
6991 return Ok(Vec::new());
6992 }
6993 let (axis_lo, axis_hi) = if a < b_axis { (a, b_axis) } else { (b_axis, a) };
6994 if let Some(provider) = penalties_cross_provider.as_ref() {
6995 provider
6996 .evaluate(axis_lo, axis_hi)
6997 .map_err(EstimationError::from)
6998 } else {
6999 Ok(Vec::new())
7003 }
7004 },
7005 )
7006 as std::sync::Arc<
7007 dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError>
7008 + Send
7009 + Sync
7010 + 'static,
7011 >)
7012 } else {
7013 None
7014 };
7015
7016 entries.push(SpatialPsiDerivative {
7017 penalty_index: penalty_indices[0],
7018 penalty_indices: penalty_indices.clone(),
7019 global_range: global_range.clone(),
7020 total_p: p_total,
7021 x_psi_local,
7022 s_psi_components_local: s_psi_components,
7023 x_psi_psi_local,
7024 s_psi_psi_components_local: s_psi_psi_components,
7025 aniso_group_id: Some(aniso_group_id),
7026 aniso_cross_designs: if cross_designs.is_empty() {
7027 None
7028 } else {
7029 Some(cross_designs)
7030 },
7031 aniso_cross_penalty_provider: cross_penalty_provider,
7032 implicit_operator: implicit_op_arc.clone(),
7033 implicit_axis: a,
7034 });
7035 }
7036 Ok(Some(entries))
7037}
7038
#[cfg(test)]
mod glm_eta_observation_fd_tests {
    use super::*;

    /// Observation state for a single observation with response `y`, unit
    /// weight and linear predictor `eta`.
    fn single_observation(
        spec: &LikelihoodSpec,
        y: f64,
        eta: f64,
    ) -> StandardFamilyObservationState {
        let response = Array1::from_vec(vec![y]);
        let weight = Array1::from_vec(vec![1.0]);
        let linear_predictor = Array1::from_vec(vec![eta]);
        evaluate_standard_familyobservations(
            spec.clone(),
            None,
            None,
            None,
            &response,
            &weight,
            &linear_predictor,
        )
        .expect("standard family observation state assembles")
    }

    /// Central finite-difference check, with step `1e-5` in eta, of three
    /// analytic quantities:
    /// * `score` against the log-likelihood (relative tolerance 1e-4);
    /// * `neghessian_eta` against the score (relative tolerance 1e-3);
    /// * `neghessian_eta_derivative` against the negative Hessian (relative
    ///   tolerance 1e-2).
    fn check_fd(label: &str, spec: &LikelihoodSpec, y: f64, eta: f64) {
        const STEP: f64 = 1e-5;
        let central_diff = |upper: f64, lower: f64| (upper - lower) / (2.0 * STEP);

        let center = single_observation(spec, y, eta);
        let plus = single_observation(spec, y, eta + STEP);
        let minus = single_observation(spec, y, eta - STEP);

        let score_fd = central_diff(plus.log_likelihood, minus.log_likelihood);
        let score = center.score[0];
        assert!(
            (score - score_fd).abs() <= 1e-4 * (1.0 + score.abs()),
            "{label}: score {score} vs FD {score_fd}"
        );

        let neghess_fd = -central_diff(plus.score[0], minus.score[0]);
        let neghess = center.neghessian_eta[0];
        assert!(
            (neghess - neghess_fd).abs() <= 1e-3 * (1.0 + neghess.abs()),
            "{label}: neghessian_eta {neghess} vs FD {neghess_fd}"
        );

        let nhd_fd = central_diff(plus.neghessian_eta[0], minus.neghessian_eta[0]);
        let nhd = center.neghessian_eta_derivative[0];
        assert!(
            (nhd - nhd_fd).abs() <= 1e-2 * (1.0 + nhd.abs()),
            "{label}: neghessian_eta_derivative {nhd} vs FD {nhd_fd}"
        );
    }

    #[test]
    fn poisson_gamma_nb_tweedie_arms_match_finite_differences_1615_1616() {
        let log = InverseLink::Standard(StandardLink::Log);
        let with_log_link = |response: ResponseFamily| LikelihoodSpec {
            response,
            link: log.clone(),
        };

        let poisson = with_log_link(ResponseFamily::Poisson);
        let gamma = with_log_link(ResponseFamily::Gamma);
        let nb = with_log_link(ResponseFamily::NegativeBinomial {
            theta: 1.5,
            theta_fixed: true,
        });
        let tweedie = with_log_link(ResponseFamily::Tweedie { p: 1.5 });

        let cases = [
            ("poisson y=3", &poisson, 3.0, 0.4),
            ("poisson y=0", &poisson, 0.0, -0.2),
            ("gamma y=2.5", &gamma, 2.5, 0.3),
            ("gamma y=0.7", &gamma, 0.7, -0.1),
            ("negbin y=4", &nb, 4.0, 0.5),
            ("negbin y=0", &nb, 0.0, -0.3),
            ("tweedie y=2", &tweedie, 2.0, 0.25),
            ("tweedie y=0.5", &tweedie, 0.5, -0.15),
        ];
        for (label, likelihood, y, eta) in cases {
            check_fd(label, likelihood, y, eta);
        }
    }
}