1pub fn build_term_collection_designs_joint(
9 data: ArrayView2<'_, f64>,
10 specs: &[TermCollectionSpec],
11) -> Result<Vec<TermCollectionDesign>, BasisError> {
12 for spec in specs {
13 validate_term_collection_finite_inputs(data, spec)?;
14 }
15 let smooth_blocks = specs
16 .iter()
17 .map(|spec| spec.smooth_terms.clone())
18 .collect::<Vec<_>>();
19 let planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &smooth_blocks)?;
20 let mut out = Vec::with_capacity(specs.len());
21 for (spec, planned_terms) in specs.iter().zip(planned_blocks.into_iter()) {
22 let mut planned_spec = spec.clone();
23 planned_spec.smooth_terms = planned_terms;
24 out.push(build_term_collection_design_inner(data, &planned_spec)?);
25 }
26 Ok(out)
27}
28
29pub fn build_term_collection_designs_and_freeze_joint(
30 data: ArrayView2<'_, f64>,
31 specs: &[TermCollectionSpec],
32) -> Result<(Vec<TermCollectionDesign>, Vec<TermCollectionSpec>), EstimationError> {
33 let designs = build_term_collection_designs_joint(data, specs)?;
34 let mut resolved_specs = Vec::with_capacity(specs.len());
35 for (spec, design) in specs.iter().zip(designs.iter()) {
36 resolved_specs.push(freeze_term_collection_from_design(spec, design)?);
37 }
38 Ok((designs, resolved_specs))
39}
40
41pub fn fit_term_collection_forspec(
42 data: ArrayView2<'_, f64>,
43 y: ArrayView1<'_, f64>,
44 weights: ArrayView1<'_, f64>,
45 offset: ArrayView1<'_, f64>,
46 spec: &TermCollectionSpec,
47 family: LikelihoodSpec,
48 options: &FitOptions,
49) -> Result<FittedTermCollection, EstimationError> {
50 fit_term_collection_forspecwith_heuristic_lambdas(
51 data, y, weights, offset, spec, None, family, options,
52 )
53}
54
55pub fn fit_term_collection_with_coefficient_groups(
56 data: ArrayView2<'_, f64>,
57 y: ArrayView1<'_, f64>,
58 weights: ArrayView1<'_, f64>,
59 offset: ArrayView1<'_, f64>,
60 spec: &TermCollectionSpec,
61 groups: &[CoefficientGroupSpec],
62 family: LikelihoodSpec,
63 options: &FitOptions,
64) -> Result<FittedTermCollection, EstimationError> {
65 if groups.is_empty() {
66 return fit_term_collection_forspec(data, y, weights, offset, spec, family, options);
67 }
68 let design = build_term_collection_design(data, spec)?;
69 let base_fit_opts = adaptive_fit_options_base(options, &design);
70 let realized = design
71 .realize_coefficient_groups(groups, &base_fit_opts.rho_prior)
72 .map_err(EstimationError::BasisError)?;
73 let mut grouped_options = base_fit_opts.clone();
74 grouped_options.rho_prior = realized.rho_prior;
75 let fitted = FittedTermCollection {
76 fit: gam_solve::estimate::fit_gam_with_penalty_specs(
77 design.design.clone(),
78 y,
79 weights,
80 offset,
81 realized.penalty_specs,
82 realized.nullspace_dims,
83 family.clone(),
84 &grouped_options,
85 )?,
86 design,
87 adaptive_diagnostics: None,
88 };
89 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
90 Ok(fitted)
91}
92
93pub fn fit_term_collection_with_penalty_block_gamma_prior_callback<F>(
94 data: ArrayView2<'_, f64>,
95 y: ArrayView1<'_, f64>,
96 weights: ArrayView1<'_, f64>,
97 offset: ArrayView1<'_, f64>,
98 spec: &TermCollectionSpec,
99 callback: F,
100 family: LikelihoodSpec,
101 options: &FitOptions,
102) -> Result<FittedTermCollection, EstimationError>
103where
104 F: FnMut(&PenaltyBlockGammaPriorMetadata<'_>) -> Option<(f64, f64)>,
105{
106 let design = build_term_collection_design(data, spec)?;
107 let mut fit_opts = adaptive_fit_options_base(options, &design);
108 fit_opts.rho_prior = realize_penalty_block_gamma_priors(&design, callback)
109 .map_err(EstimationError::BasisError)?;
110 let fitted = FittedTermCollection {
111 fit: fit_gamwith_heuristic_lambdas(
112 design.design.clone(),
113 y,
114 weights,
115 offset,
116 &design.penalties,
117 None,
118 family.clone(),
119 &fit_opts,
120 )?,
121 design,
122 adaptive_diagnostics: None,
123 };
124 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
125 Ok(fitted)
126}
127
128pub fn fit_term_collection_with_penalty_block_gamma_priors(
129 data: ArrayView2<'_, f64>,
130 y: ArrayView1<'_, f64>,
131 weights: ArrayView1<'_, f64>,
132 offset: ArrayView1<'_, f64>,
133 spec: &TermCollectionSpec,
134 priors: &[(String, f64, f64)],
135 family: LikelihoodSpec,
136 options: &FitOptions,
137) -> Result<FittedTermCollection, EstimationError> {
138 let design = build_term_collection_design(data, spec)?;
139 let mut fit_opts = adaptive_fit_options_base(options, &design);
140 fit_opts.rho_prior = realize_keyed_penalty_block_gamma_priors(&design, priors)
141 .map_err(EstimationError::BasisError)?;
142 let fitted = FittedTermCollection {
143 fit: fit_gamwith_heuristic_lambdas(
144 design.design.clone(),
145 y,
146 weights,
147 offset,
148 &design.penalties,
149 None,
150 family.clone(),
151 &fit_opts,
152 )?,
153 design,
154 adaptive_diagnostics: None,
155 };
156 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
157 Ok(fitted)
158}
159
160fn expand_double_penalty_linear_penalty_blocks(
173 design: &TermCollectionDesign,
174 spec: &TermCollectionSpec,
175) -> TermCollectionDesign {
176 let Some(shared_idx) = design.penaltyinfo.iter().position(|info| {
177 info.termname.as_deref() == Some("linear")
178 && matches!(&info.penalty.source, PenaltySource::Other(s) if s == "LinearTermRidge")
179 }) else {
180 return design.clone();
181 };
182
183 let mut new_penalties = Vec::<BlockwisePenalty>::new();
184 let mut new_nullspace = Vec::<usize>::new();
185 let mut new_info = Vec::<PenaltyBlockInfo>::new();
186 for (j, linear) in spec.linear_terms.iter().enumerate() {
187 if !linear.double_penalty {
188 continue;
189 }
190 let Some((_, range)) = design
191 .linear_ranges
192 .iter()
193 .find(|(name, _)| name == &linear.name)
194 else {
195 continue;
196 };
197 new_penalties.push(BlockwisePenalty::ridge(range.clone(), 1.0));
200 new_nullspace.push(0);
201 new_info.push(PenaltyBlockInfo {
202 global_index: 0,
203 termname: Some(linear.name.clone()),
204 penalty: PenaltyInfo {
205 source: PenaltySource::Other("LinearTermRidge".to_string()),
206 original_index: j,
207 active: true,
208 effective_rank: 1,
209 dropped_reason: None,
210 nullspace_dim_hint: 0,
211 normalization_scale: 1.0,
212 kronecker_factors: None,
213 },
214 });
215 new_penalties.push(BlockwisePenalty::ridge(range.clone(), 1.0));
218 new_nullspace.push(0);
219 new_info.push(PenaltyBlockInfo {
220 global_index: 0,
221 termname: None,
222 penalty: PenaltyInfo {
223 source: PenaltySource::DoublePenaltyNullspace,
224 original_index: j,
225 active: true,
226 effective_rank: 1,
227 dropped_reason: None,
228 nullspace_dim_hint: 0,
229 normalization_scale: 1.0,
230 kronecker_factors: None,
231 },
232 });
233 }
234
235 if new_penalties.is_empty() {
236 return design.clone();
237 }
238
239 let mut expanded = design.clone();
240 expanded
241 .penalties
242 .splice(shared_idx..=shared_idx, new_penalties);
243 expanded
244 .nullspace_dims
245 .splice(shared_idx..=shared_idx, new_nullspace);
246 expanded
247 .penaltyinfo
248 .splice(shared_idx..=shared_idx, new_info);
249 for (idx, info) in expanded.penaltyinfo.iter_mut().enumerate() {
251 info.global_index = idx;
252 }
253 expanded
254}
255
256pub fn fit_term_collection_with_coefficient_groups_and_penalty_block_gamma_priors(
257 data: ArrayView2<'_, f64>,
258 y: ArrayView1<'_, f64>,
259 weights: ArrayView1<'_, f64>,
260 offset: ArrayView1<'_, f64>,
261 spec: &TermCollectionSpec,
262 groups: &[CoefficientGroupSpec],
263 priors: &[(String, f64, f64)],
264 family: LikelihoodSpec,
265 options: &FitOptions,
266) -> Result<FittedTermCollection, EstimationError> {
267 if groups.is_empty() {
268 return fit_term_collection_with_penalty_block_gamma_priors(
269 data, y, weights, offset, spec, priors, family, options,
270 );
271 }
272 if priors.is_empty() {
273 return fit_term_collection_with_coefficient_groups(
274 data, y, weights, offset, spec, groups, family, options,
275 );
276 }
277
278 let design = build_term_collection_design(data, spec)?;
279 let design = expand_double_penalty_linear_penalty_blocks(&design, spec);
285 let base_fit_opts = adaptive_fit_options_base(options, &design);
286 let base_rho_prior = realize_keyed_penalty_block_gamma_priors(&design, priors)
287 .map_err(EstimationError::BasisError)?;
288 let realized = design
289 .realize_coefficient_groups(groups, &base_rho_prior)
290 .map_err(EstimationError::BasisError)?;
291 let mut grouped_options = base_fit_opts.clone();
292 grouped_options.rho_prior = realized.rho_prior;
293 let fitted = FittedTermCollection {
294 fit: gam_solve::estimate::fit_gam_with_penalty_specs(
295 design.design.clone(),
296 y,
297 weights,
298 offset,
299 realized.penalty_specs,
300 realized.nullspace_dims,
301 family.clone(),
302 &grouped_options,
303 )?,
304 design,
305 adaptive_diagnostics: None,
306 };
307 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
308 Ok(fitted)
309}
310
311fn fit_term_collection_forspecwith_heuristic_lambdas(
312 data: ArrayView2<'_, f64>,
313 y: ArrayView1<'_, f64>,
314 weights: ArrayView1<'_, f64>,
315 offset: ArrayView1<'_, f64>,
316 spec: &TermCollectionSpec,
317 heuristic_lambdas: Option<&[f64]>,
318 family: LikelihoodSpec,
319 options: &FitOptions,
320) -> Result<FittedTermCollection, EstimationError> {
321 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
322 let resolved_spec;
323 let design_spec = if adaptive_opts.enabled {
324 resolved_spec = ensure_matern_adaptive_center_resolution(spec, data.nrows());
325 &resolved_spec
326 } else {
327 spec
328 };
329 let base_design = build_term_collection_design(data, design_spec)?;
330 fit_term_collection_on_realized_design(
331 y,
332 weights,
333 offset,
334 design_spec,
335 &base_design,
336 heuristic_lambdas,
337 family,
338 options,
339 )
340}
341
342fn ensure_matern_adaptive_center_resolution(
343 spec: &TermCollectionSpec,
344 n_rows: usize,
345) -> TermCollectionSpec {
346 let mut out = spec.clone();
347 for term in &mut out.smooth_terms {
348 let gam_terms::smooth::SmoothBasisSpec::Matern {
349 feature_cols,
350 spec: matern,
351 ..
352 } = &mut term.basis
353 else {
354 continue;
355 };
356 if let gam_terms::basis::CenterStrategy::FarthestPoint { num_centers } =
357 &mut matern.center_strategy
358 {
359 let min_centers = (4 * feature_cols.len()).min(n_rows).max(*num_centers);
372 *num_centers = min_centers;
373 }
374 }
375 out
376}
377
378fn has_bounded_linear_terms(spec: &TermCollectionSpec) -> bool {
379 spec.linear_terms.iter().any(|term| {
380 matches!(
381 term.coefficient_geometry,
382 LinearCoefficientGeometry::Bounded { .. }
383 )
384 })
385}
386
387fn fit_term_collection_on_realized_design(
388 y: ArrayView1<'_, f64>,
389 weights: ArrayView1<'_, f64>,
390 offset: ArrayView1<'_, f64>,
391 spec: &TermCollectionSpec,
392 design: &TermCollectionDesign,
393 heuristic_lambdas: Option<&[f64]>,
394 family: LikelihoodSpec,
395 options: &FitOptions,
396) -> Result<FittedTermCollection, EstimationError> {
397 if has_bounded_linear_terms(spec) {
398 return fit_bounded_term_collection_with_design(
399 y,
400 weights,
401 offset,
402 spec,
403 design,
404 heuristic_lambdas,
405 family,
406 options,
407 );
408 }
409 let mut base_fit_opts = adaptive_fit_options_base(options, design);
410 base_fit_opts.rho_prior = relax_smoothing_rho_prior(options, design);
417 let fitted = FittedTermCollection {
418 fit: fit_gamwith_heuristic_lambdas(
419 design.design.clone(),
420 y,
421 weights,
422 offset,
423 &design.penalties,
424 heuristic_lambdas,
425 family.clone(),
426 &base_fit_opts,
427 )?,
428 design: design.clone(),
429 adaptive_diagnostics: None,
430 };
431 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
432
433 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
434 if !adaptive_opts.enabled {
435 return Ok(fitted);
436 }
437 let runtime_caches = extract_spatial_operator_runtime_caches(spec, &fitted.design)?;
438 if runtime_caches.is_empty() {
439 return Ok(fitted);
440 }
441 fit_term_collectionwith_exact_spatial_adaptive_regularization(
448 fitted,
449 y,
450 weights,
451 offset,
452 family,
453 options,
454 &runtime_caches,
455 )
456}
457
/// Per-term cache of collocation operator matrices and penalty indices.
///
/// `SpatialPenaltyExactState::from_beta_local` applies `d0`, `d1` and `d2` to
/// a term's local coefficient vector and uses `dimension` to regroup the
/// `d1`/`d2` products into per-point blocks.
///
/// NOTE(review): the code that populates this struct lies outside the visible
/// part of the file, so the field notes below that are marked "presumably"
/// are inferred from names and should be confirmed against the constructor.
#[derive(Clone)]
struct SpatialOperatorRuntimeCache {
    // Presumably the name of the smooth term this cache belongs to.
    termname: String,
    // Presumably the data columns feeding the term's basis.
    feature_cols: Vec<usize>,
    // Presumably the term's coefficient span in the global coefficient vector.
    coeff_global_range: Range<usize>,
    // Presumably the global penalty indices of the operator mass / tension /
    // stiffness penalties for this term.
    mass_penalty_global_idx: usize,
    tension_penalty_global_idx: usize,
    stiffness_penalty_global_idx: usize,
    // Operator whose product with the local coefficients is the magnitude signal.
    d0: Array2<f64>,
    // Operator whose product is regrouped into blocks of `dimension` entries.
    d1: Array2<f64>,
    // Operator whose product is regrouped into blocks of `dimension * dimension` entries.
    d2: Array2<f64>,
    // Presumably the points at which the operators are evaluated (one row per point).
    collocation_points: Array2<f64>,
    // Block size used when regrouping the `d1` and `d2` products.
    dimension: usize,
}
472
/// Three weight vectors used by the spatial adaptive regularization.
///
/// NOTE(review): nothing in the visible part of the file reads or writes this
/// struct. Judging by the field names, each vector presumably holds the
/// reciprocal of a per-point weight for the magnitude, gradient and Laplacian
/// penalties respectively — confirm against the code that builds it.
#[derive(Clone)]
struct SpatialAdaptiveWeights {
    inv_magweight: Array1<f64>,
    invgradweight: Array1<f64>,
    inv_lapweight: Array1<f64>,
}
479
/// Cached quantities for a Charbonnier penalty applied entrywise to a signal.
///
/// With `t` an entry of `signal` and `eps = epsilon`, the per-entry penalty is
/// `sqrt(t^2 + eps^2) - eps`. `from_signal` is the constructor that fills the
/// fields consistently.
#[derive(Clone)]
struct CharbonnierScalarBlockState {
    // The raw signal entries `t`.
    signal: Array1<f64>,
    // `sqrt(t^2 + eps^2)` for each entry of `signal`.
    radius: Array1<f64>,
    // Smoothing scale; `from_signal` stores it floored at 1e-12.
    epsilon: f64,
}
486
// Notation used in the method docs below: `t` is a signal entry,
// `eps = self.epsilon`, `r = sqrt(t^2 + eps^2)` is the cached radius, and
// `psi(t) = r - eps` is the per-entry penalty. "log-eps derivative" means the
// derivative with respect to `log(eps)`, i.e. `eps * d/d(eps)`.
impl CharbonnierScalarBlockState {
    /// Builds the state for `signal`, caching `r = sqrt(t^2 + eps^2)` per entry.
    ///
    /// `epsilon` is floored at `1e-12` before use, and the floored value is
    /// the one stored.
    fn from_signal(signal: Array1<f64>, epsilon: f64) -> Self {
        let eps = epsilon.max(1e-12);
        let radius = signal.mapv(|t| (t * t + eps * eps).sqrt());
        Self {
            signal,
            radius,
            epsilon: eps,
        }
    }

    /// Returns `|t|` for every signal entry.
    fn absolute_signal(&self) -> Array1<f64> {
        self.signal.mapv(f64::abs)
    }

    /// Returns the total penalty `sum(r - eps)` over all entries.
    fn penalty_value(&self) -> f64 {
        self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
    }

    /// Returns `psi'(t) = t / r` for every entry.
    fn betagradient_coeff(&self) -> Array1<f64> {
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| t / r),
        )
    }

    /// Returns `psi''(t) = eps^2 / r^3` for every entry.
    fn betahessian_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        self.radius.mapv(|r| eps2 / r.powi(3))
    }

    /// Returns the log-eps derivative of each penalty term: `eps^2 / r - eps`.
    fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        self.radius.mapv(|r| eps2 / r - epsilon)
    }

    /// Returns the log-eps derivative of `psi'(t)`: `-eps^2 * t / r^3`.
    fn log_epsilon_betagradient_coeff(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| -eps2 * t / r.powi(3)),
        )
    }

    /// Returns the second log-eps derivative of each penalty term:
    /// `2 eps^2 / r - eps^4 / r^3 - eps`.
    fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        let eps4 = eps2 * eps2;
        self.radius
            .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
    }

    /// Computes clamped surrogate weights and their reciprocals.
    ///
    /// For each entry the squared signal is reduced by the matching
    /// `variance` value (negative variances count as zero) and floored at
    /// zero; the weight is `1 / sqrt(that + eps^2)` clamped to
    /// `[weight_floor, weight_ceiling]`. Returns `(weight, 1 / weight)`.
    ///
    /// Entries are paired with `zip`, so the output has the length of the
    /// shorter of `signal` and `variance`.
    ///
    /// # Panics
    ///
    /// `f64::clamp` panics if `weight_floor > weight_ceiling` or either bound
    /// is NaN.
    fn surrogateweights_posterior_snr(
        &self,
        variance: &Array1<f64>,
        weight_floor: f64,
        weight_ceiling: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        let eps2 = self.epsilon * self.epsilon;
        let weight = Array1::from_iter(self.signal.iter().zip(variance.iter()).map(|(&t, &v)| {
            let credible2 = (t * t - v.max(0.0)).max(0.0);
            let r = (credible2 + eps2).sqrt();
            (1.0 / r).clamp(weight_floor, weight_ceiling)
        }));
        let invweight = weight.mapv(|u| 1.0 / u);
        (weight, invweight)
    }

    /// Returns the derivative of `psi''` along `direction_signal`:
    /// `psi'''(t) * q = -3 eps^2 t q / r^5`, with `q` the direction entry.
    fn directionalhessian_diag(&self, direction_signal: &Array1<f64>) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(direction_signal.iter())
                .zip(self.radius.iter())
                .map(|((t, q), r)| -3.0 * eps2 * t * q / r.powi(5)),
        )
    }

    /// Returns the second derivative of `psi''` along two directions:
    /// `psi''''(t) * q1 * q2`, where
    /// `psi''''(t) = -3 eps^2 / r^5 + 15 eps^2 t^2 / r^7`.
    fn second_directionalhessian_diag(
        &self,
        direction1_signal: &Array1<f64>,
        direction2_signal: &Array1<f64>,
    ) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(direction1_signal.iter())
                .zip(direction2_signal.iter())
                .zip(self.radius.iter())
                .map(|(((t, q1), q2), r)| {
                    let r2 = r * r;
                    // r^5 * r^2 is the r^7 denominator of the second term.
                    let psi4 = -3.0 * eps2 / r.powi(5) + 15.0 * eps2 * t * t / (r.powi(5) * r2);
                    psi4 * q1 * q2
                }),
        )
    }

    /// Returns the log-eps derivative of `psi''(t)`:
    /// `2 eps^2 / r^3 - 3 eps^4 / r^5`.
    fn log_epsilon_betahessian_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(_, r)| 2.0 * eps2 / r.powi(3) - 3.0 * eps4 / r.powi(5)),
        )
    }

    /// Returns the second log-eps derivative of `psi'(t)`:
    /// `eps^2 t (eps^2 - 2 t^2) / r^5`.
    fn log_epsilon_beta_mixed_second_coeff(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(self.radius.iter())
                .map(|(t, r)| eps2 * t * (eps2 - 2.0 * t * t) / r.powi(5)),
        )
    }

    /// Returns the second log-eps derivative of `psi''(t)`:
    /// `4 eps^2 / r^3 - 18 eps^4 / r^5 + 15 eps^6 / r^7`.
    fn log_epsilon_betahessian_second_diag(&self) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        let eps6 = eps4 * eps2;
        Array1::from_iter(
            self.radius.iter().map(|r| {
                4.0 * eps2 / r.powi(3) - 18.0 * eps4 / r.powi(5) + 15.0 * eps6 / r.powi(7)
            }),
        )
    }

    /// Returns the log-eps derivative of the directional term
    /// `psi'''(t) * q`: `(-6 eps^2 t / r^5 + 15 eps^4 t / r^7) * q`.
    fn log_epsilon_betahessian_directional_diag(
        &self,
        direction_signal: &Array1<f64>,
    ) -> Array1<f64> {
        let eps2 = self.epsilon * self.epsilon;
        let eps4 = eps2 * eps2;
        Array1::from_iter(
            self.signal
                .iter()
                .zip(direction_signal.iter())
                .zip(self.radius.iter())
                .map(|((t, q), r)| (-6.0 * eps2 * t / r.powi(5) + 15.0 * eps4 * t / r.powi(7)) * q),
        )
    }
}
715
/// Cached quantities for a Charbonnier penalty applied to grouped signal rows.
///
/// Each row `v` of `signal_blocks` is one group; its penalty is
/// `sqrt(|v|^2 + eps^2) - eps`. `from_signal_blocks` is the constructor that
/// fills the fields consistently.
#[derive(Clone)]
struct CharbonnierGroupedBlockState {
    // Euclidean norm `|v|` of each row of `signal_blocks`.
    norm: Array1<f64>,
    // `sqrt(|v|^2 + eps^2)` for each row.
    radius: Array1<f64>,
    // One group per row.
    signal_blocks: Array2<f64>,
    // Smoothing scale; `from_signal_blocks` stores it floored at 1e-12.
    epsilon: f64,
}
723
// Notation used in the method docs below: `v` is one row of `signal_blocks`,
// `g = |v|` its Euclidean norm, `eps = self.epsilon`,
// `r = sqrt(g^2 + eps^2)` the cached radius, `I` the identity of the row's
// length, and `v v^T` the outer product. "log-eps derivative" means the
// derivative with respect to `log(eps)`, i.e. `eps * d/d(eps)`.
impl CharbonnierGroupedBlockState {
    /// Builds the state for `signal_blocks`, caching each row's norm and radius.
    ///
    /// `epsilon` is floored at `1e-12` before use, and the floored value is
    /// the one stored.
    fn from_signal_blocks(signal_blocks: Array2<f64>, epsilon: f64) -> Self {
        let eps = epsilon.max(1e-12);
        let norm = Array1::from_iter(
            signal_blocks
                .rows()
                .into_iter()
                .map(|row| row.iter().map(|v| v * v).sum::<f64>().sqrt()),
        );
        let radius = norm.mapv(|g| (g * g + eps * eps).sqrt());
        Self {
            norm,
            radius,
            signal_blocks,
            epsilon: eps,
        }
    }

    /// Returns the total penalty `sum(r - eps)` over all rows.
    fn penalty_value(&self) -> f64 {
        self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
    }

    /// Returns a copy of the per-row norms `g`.
    fn norm_signal(&self) -> Array1<f64> {
        self.norm.clone()
    }

    /// Returns the penalty gradient per row: `v / r`.
    fn betagradient_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let scale = 1.0 / self.radius[k];
            row.mapv_inplace(|v| v * scale);
        }
        out
    }

    /// Returns the penalty Hessian per row: `I / r - v v^T / r^3`.
    fn betahessian_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|v| v / self.radius[k]);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] -= row[i] * row[j] / self.radius[k].powi(3);
                }
            }
            out.push(block);
        }
        out
    }

    /// Returns the log-eps derivative of each penalty term: `eps^2 / r - eps`.
    fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        self.radius.mapv(|r| eps2 / r - epsilon)
    }

    /// Returns the log-eps derivative of the gradient per row: `-eps^2 v / r^3`.
    fn log_epsilon_betagradient_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        let eps2 = self.epsilon * self.epsilon;
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let scale = -eps2 / self.radius[k].powi(3);
            row.mapv_inplace(|v| v * scale);
        }
        out
    }

    /// Returns the second log-eps derivative of each penalty term:
    /// `2 eps^2 / r - eps^4 / r^3 - eps`.
    fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
        let epsilon = self.epsilon;
        let eps2 = epsilon * epsilon;
        let eps4 = eps2 * eps2;
        self.radius
            .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
    }

    /// Computes clamped surrogate weights and their reciprocals.
    ///
    /// For each row the squared norm is reduced by the matching `variance`
    /// value (negative variances count as zero) and floored at zero; the
    /// weight is `1 / sqrt(that + eps^2)` clamped to
    /// `[weight_floor, weight_ceiling]`. Returns `(weight, 1 / weight)`.
    ///
    /// Entries are paired with `zip`, so the output has the length of the
    /// shorter of `norm` and `variance`.
    ///
    /// # Panics
    ///
    /// `f64::clamp` panics if `weight_floor > weight_ceiling` or either bound
    /// is NaN.
    fn surrogateweights_posterior_snr(
        &self,
        variance: &Array1<f64>,
        weight_floor: f64,
        weight_ceiling: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        let eps2 = self.epsilon * self.epsilon;
        let weight = Array1::from_iter(self.norm.iter().zip(variance.iter()).map(|(&g, &v)| {
            let credible2 = (g * g - v.max(0.0)).max(0.0);
            let r = (credible2 + eps2).sqrt();
            (1.0 / r).clamp(weight_floor, weight_ceiling)
        }));
        let invweight = weight.mapv(|u| 1.0 / u);
        (weight, invweight)
    }

    /// Returns the derivative of each Hessian block along `direction_blocks`.
    ///
    /// With `q` the direction row and `s = v . q`, each block is
    /// `-s I / r^3 - (q v^T + v q^T) / r^3 + 3 s v v^T / r^5`.
    fn directionalhessian_blocks(&self, direction_blocks: &Array2<f64>) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, (v, q)) in self
            .signal_blocks
            .rows()
            .into_iter()
            .zip(direction_blocks.rows().into_iter())
            .enumerate()
        {
            let dim = v.len();
            let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
            let r3 = self.radius[k].powi(3);
            let r5 = self.radius[k].powi(5);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|x| -dot * x / r3);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] -= (q[i] * v[j] + v[i] * q[j]) / r3;
                    block[[i, j]] += 3.0 * dot * v[i] * v[j] / r5;
                }
            }
            out.push(block);
        }
        out
    }

    /// Returns the second derivative of each Hessian block along two directions.
    ///
    /// With direction rows `a` and `b`, `sa = v . a`, `sb = v . b` and
    /// `ab = a . b`, each block is
    /// `(-ab / r^3 + 3 sa sb / r^5) I - (a b^T + b a^T) / r^3
    ///  + 3 sb (a v^T + v a^T) / r^5 + 3 sa (b v^T + v b^T) / r^5
    ///  + 3 ab v v^T / r^5 - 15 sa sb v v^T / r^7`.
    fn second_directionalhessian_blocks(
        &self,
        direction1_blocks: &Array2<f64>,
        direction2_blocks: &Array2<f64>,
    ) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for ((k, v), (a, b)) in self.signal_blocks.rows().into_iter().enumerate().zip(
            direction1_blocks
                .rows()
                .into_iter()
                .zip(direction2_blocks.rows().into_iter()),
        ) {
            let dim = v.len();
            let dot = |x: ndarray::ArrayView1<'_, f64>, y: ndarray::ArrayView1<'_, f64>| {
                x.iter().zip(y.iter()).map(|(p, q)| p * q).sum::<f64>()
            };
            let sa = dot(v, a);
            let sb = dot(v, b);
            let ab = dot(a, b);
            let r = self.radius[k];
            let r3 = r.powi(3);
            let r5 = r.powi(5);
            let r7 = r5 * r * r;
            // Coefficient of the identity part of the block.
            let diag = -ab / r3 + 3.0 * sa * sb / r5;
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|x| diag * x);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] -= (a[i] * b[j] + b[i] * a[j]) / r3;
                    block[[i, j]] += 3.0 * sb * (a[i] * v[j] + v[i] * a[j]) / r5;
                    block[[i, j]] += 3.0 * ab * v[i] * v[j] / r5;
                    block[[i, j]] += 3.0 * sa * (b[i] * v[j] + v[i] * b[j]) / r5;
                    block[[i, j]] -= 15.0 * sa * sb * v[i] * v[j] / r7;
                }
            }
            out.push(block);
        }
        out
    }

    /// Returns the log-eps derivative of each Hessian block:
    /// `-eps^2 I / r^3 + 3 eps^2 v v^T / r^5`.
    fn log_epsilon_betahessian_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            let r3 = self.radius[k].powi(3);
            let r5 = self.radius[k].powi(5);
            let mut block = Array2::<f64>::eye(dim);
            let eps2 = self.epsilon * self.epsilon;
            block.mapv_inplace(|v| -eps2 * v / r3);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * row[i] * row[j] / r5;
                }
            }
            out.push(block);
        }
        out
    }

    /// Returns the second log-eps derivative of the gradient per row:
    /// `eps^2 (eps^2 - 2 g^2) v / r^5`.
    fn log_epsilon_beta_mixed_second_blocks(&self) -> Array2<f64> {
        let mut out = self.signal_blocks.clone();
        let eps2 = self.epsilon * self.epsilon;
        for (k, mut row) in out.rows_mut().into_iter().enumerate() {
            let norm2 = self.norm[k] * self.norm[k];
            let scale = eps2 * (eps2 - 2.0 * norm2) / self.radius[k].powi(5);
            row.mapv_inplace(|v| v * scale);
        }
        out
    }

    /// Returns the second log-eps derivative of each Hessian block:
    /// `eps^2 (eps^2 - 2 g^2) I / r^5 + 3 eps^2 (2 g^2 - 3 eps^2) v v^T / r^7`.
    fn log_epsilon_betahessian_second_blocks(&self) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        let eps2 = self.epsilon * self.epsilon;
        for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
            let dim = row.len();
            let norm2 = self.norm[k] * self.norm[k];
            let r5 = self.radius[k].powi(5);
            let r7 = self.radius[k].powi(7);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|v| eps2 * (eps2 - 2.0 * norm2) * v / r5);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * (2.0 * norm2 - 3.0 * eps2) * row[i] * row[j] / r7;
                }
            }
            out.push(block);
        }
        out
    }

    /// Returns the log-eps derivative of each directional Hessian block.
    ///
    /// With `q` the direction row and `s = v . q`, each block is
    /// `3 eps^2 s I / r^5 + 3 eps^2 (q v^T + v q^T) / r^5
    ///  - 15 eps^2 s v v^T / r^7`.
    fn log_epsilon_betahessian_directional_blocks(
        &self,
        direction_blocks: &Array2<f64>,
    ) -> Vec<Array2<f64>> {
        let mut out = Vec::with_capacity(self.signal_blocks.nrows());
        let eps2 = self.epsilon * self.epsilon;
        for (k, (v, q)) in self
            .signal_blocks
            .rows()
            .into_iter()
            .zip(direction_blocks.rows().into_iter())
            .enumerate()
        {
            let dim = v.len();
            let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
            let r5 = self.radius[k].powi(5);
            let r7 = self.radius[k].powi(7);
            let mut block = Array2::<f64>::eye(dim);
            block.mapv_inplace(|x| 3.0 * eps2 * dot * x / r5);
            for i in 0..dim {
                for j in 0..dim {
                    block[[i, j]] += 3.0 * eps2 * (q[i] * v[j] + v[i] * q[j]) / r5;
                    block[[i, j]] -= 15.0 * eps2 * dot * v[i] * v[j] / r7;
                }
            }
            out.push(block);
        }
        out
    }
}
1042
1043fn scalar_operatorgradient(operator: &Array2<f64>, coeff: &Array1<f64>) -> Array1<f64> {
1044 operator.t().dot(coeff)
1045}
1046
1047fn scalar_operatorhessian(operator: &Array2<f64>, diag: &Array1<f64>) -> Array2<f64> {
1048 let mut weighted = operator.clone();
1049 for (k, &w) in diag.iter().enumerate() {
1050 weighted.row_mut(k).mapv_inplace(|v| v * w);
1051 }
1052 let gram = operator.t().dot(&weighted);
1053 (&gram + &gram.t().to_owned()) * 0.5
1054}
1055
1056fn grouped_operatorgradient(
1057 d1: &Array2<f64>,
1058 dimension: usize,
1059 blocks: &Array2<f64>,
1060) -> Result<Array1<f64>, EstimationError> {
1061 if blocks.ncols() != dimension {
1062 crate::bail_invalid_estim!(
1063 "grouped gradient block dimension mismatch: got {}, expected {dimension}",
1064 blocks.ncols()
1065 );
1066 }
1067 if d1.nrows() != blocks.nrows() * dimension {
1068 crate::bail_invalid_estim!(
1069 "grouped gradient row mismatch: D1 has {} rows, blocks imply {}",
1070 d1.nrows(),
1071 blocks.nrows() * dimension
1072 );
1073 }
1074 let mut out = Array1::<f64>::zeros(d1.ncols());
1075 for k in 0..blocks.nrows() {
1076 let gk = d1
1077 .slice(s![k * dimension..(k + 1) * dimension, ..])
1078 .to_owned();
1079 out += &gk.t().dot(&blocks.row(k));
1080 }
1081 Ok(out)
1082}
1083
1084fn grouped_operatorhessian(
1085 d1: &Array2<f64>,
1086 dimension: usize,
1087 blocks: &[Array2<f64>],
1088) -> Result<Array2<f64>, EstimationError> {
1089 if d1.nrows() != blocks.len() * dimension {
1090 crate::bail_invalid_estim!(
1091 "grouped Hessian row mismatch: D1 has {} rows, blocks imply {}",
1092 d1.nrows(),
1093 blocks.len() * dimension
1094 );
1095 }
1096 let p = d1.ncols();
1097 let mut out = Array2::<f64>::zeros((p, p));
1098 for (k, block) in blocks.iter().enumerate() {
1099 if block.nrows() != dimension || block.ncols() != dimension {
1100 crate::bail_invalid_estim!(
1101 "grouped Hessian block {k} has shape {}x{}, expected {}x{}",
1102 block.nrows(),
1103 block.ncols(),
1104 dimension,
1105 dimension
1106 );
1107 }
1108 let gk = d1
1109 .slice(s![k * dimension..(k + 1) * dimension, ..])
1110 .to_owned();
1111 out += &gk.t().dot(&block.dot(&gk));
1112 }
1113 Ok((&out + &out.t().to_owned()) * 0.5)
1114}
1115
/// Charbonnier states for the three spatial penalties of one term.
///
/// `from_beta_local` builds these from the products of a cache's `d0`, `d1`
/// and `d2` operators with the term's local coefficients.
#[derive(Clone)]
struct SpatialPenaltyExactState {
    // Entrywise state built from the `d0` product.
    magnitude: CharbonnierScalarBlockState,
    // Grouped state built from the `d1` product, `dimension` entries per group.
    gradient: CharbonnierGroupedBlockState,
    // Grouped state built from the `d2` product, `dimension * dimension` entries per group.
    curvature: CharbonnierGroupedBlockState,
}
1122
1123fn collocationgradient_blocks(
1124 gradrows: &Array1<f64>,
1125 dimension: usize,
1126) -> Result<Array2<f64>, EstimationError> {
1127 if dimension == 0 || !gradrows.len().is_multiple_of(dimension) {
1128 crate::bail_invalid_estim!(
1129 "invalid collocation gradient layout: rows={}, dimension={dimension}",
1130 gradrows.len()
1131 );
1132 }
1133 let p = gradrows.len() / dimension;
1134 let mut out = Array2::<f64>::zeros((p, dimension));
1135 for k in 0..p {
1136 for axis in 0..dimension {
1137 out[[k, axis]] = gradrows[k * dimension + axis];
1138 }
1139 }
1140 Ok(out)
1141}
1142
1143fn collocationhessian_blocks(
1144 hessianrows: &Array1<f64>,
1145 dimension: usize,
1146) -> Result<Array2<f64>, EstimationError> {
1147 let block_dim = dimension.checked_mul(dimension).ok_or_else(|| {
1148 EstimationError::InvalidInput("invalid collocation Hessian dimension overflow".to_string())
1149 })?;
1150 if block_dim == 0 || !hessianrows.len().is_multiple_of(block_dim) {
1151 crate::bail_invalid_estim!(
1152 "invalid collocation Hessian layout: rows={}, dimension={dimension}",
1153 hessianrows.len()
1154 );
1155 }
1156 let p = hessianrows.len() / block_dim;
1157 let mut out = Array2::<f64>::zeros((p, block_dim));
1158 for k in 0..p {
1159 for idx in 0..block_dim {
1160 out[[k, idx]] = hessianrows[k * block_dim + idx];
1161 }
1162 }
1163 Ok(out)
1164}
1165
1166impl SpatialPenaltyExactState {
1167 fn from_beta_local(
1168 beta_local: ArrayView1<'_, f64>,
1169 cache: &SpatialOperatorRuntimeCache,
1170 epsilons: [f64; 3],
1171 ) -> Result<Self, EstimationError> {
1172 let gradientrows = cache.d1.dot(&beta_local);
1202 let hessianrows = cache.d2.dot(&beta_local);
1203 Ok(Self {
1204 magnitude: CharbonnierScalarBlockState::from_signal(
1205 cache.d0.dot(&beta_local),
1206 epsilons[0],
1207 ),
1208 gradient: CharbonnierGroupedBlockState::from_signal_blocks(
1209 collocationgradient_blocks(&gradientrows, cache.dimension)?,
1210 epsilons[1],
1211 ),
1212 curvature: CharbonnierGroupedBlockState::from_signal_blocks(
1213 collocationhessian_blocks(&hessianrows, cache.dimension)?,
1214 epsilons[2],
1215 ),
1216 })
1217 }
1218
1219 fn absolute_collocation_magnitudes(&self) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
1220 (
1221 self.magnitude.absolute_signal(),
1222 self.gradient.norm_signal(),
1223 self.curvature.norm_signal(),
1224 )
1225 }
1226}
1227
1228fn robust_epsilon_from_samples(values: &[f64], min_epsilon_cfg: f64) -> f64 {
1229 if values.is_empty() {
1230 return min_epsilon_cfg.max(1e-12);
1231 }
1232 let mut clean = values
1233 .iter()
1234 .copied()
1235 .filter(|v| v.is_finite() && *v >= 0.0)
1236 .collect::<Vec<_>>();
1237 if clean.is_empty() {
1238 return min_epsilon_cfg.max(1e-12);
1239 }
1240 clean.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1241
1242 let n = clean.len();
1243 let median = quantile_from_sorted(&clean, 0.5);
1244 let q75 = quantile_from_sorted(&clean, 0.75);
1245 let q95 = quantile_from_sorted(&clean, 0.95);
1246
1247 let mut abs_dev = clean
1248 .iter()
1249 .map(|v| (v - median).abs())
1250 .filter(|v| v.is_finite())
1251 .collect::<Vec<_>>();
1252 abs_dev.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1253 let mad = 1.4826 * quantile_from_sorted(&abs_dev, 0.5);
1254
1255 let mut scale = median.max(mad).max(q75);
1265
1266 let delta = (f64::EPSILON.sqrt() * q95.max(1.0))
1268 .max(min_epsilon_cfg)
1269 .max(1e-12);
1270 let s_min = min_epsilon_cfg.max(1e-12);
1271
1272 if scale <= delta {
1274 let rms = (clean.iter().map(|v| v * v).sum::<f64>() / n as f64).sqrt();
1275 scale = q95.max(rms);
1276 }
1277 if scale <= delta {
1278 scale = s_min;
1279 }
1280
1281 let kappa = 1.0_f64;
1284 (kappa * scale).max(s_min)
1285}
1286
1287fn extract_spatial_operator_runtime_caches(
1288 spec: &TermCollectionSpec,
1289 design: &TermCollectionDesign,
1290) -> Result<Vec<SpatialOperatorRuntimeCache>, EstimationError> {
1291 let smooth_start = design
1292 .design
1293 .ncols()
1294 .saturating_sub(design.smooth.total_smooth_cols());
1295 let mut out = Vec::<SpatialOperatorRuntimeCache>::new();
1296 for (term_idx, (termspec, term_fit)) in spec
1297 .smooth_terms
1298 .iter()
1299 .zip(design.smooth.terms.iter())
1300 .enumerate()
1301 {
1302 let Some(global_base_idx) = smooth_term_penalty_index(spec, design, term_idx) else {
1303 continue;
1304 };
1305 let mut active_local_idx = 0usize;
1306 let mut mass_local_idx = None;
1307 let mut tension_local_idx = None;
1308 let mut stiffness_local_idx = None;
1309 let mut mass_norm = None;
1310 let mut tension_norm = None;
1311 let mut stiffness_norm = None;
1312 for info in &term_fit.penaltyinfo_local {
1313 if !info.active {
1314 continue;
1315 }
1316 match info.source {
1317 PenaltySource::OperatorMass => {
1318 mass_local_idx = Some(active_local_idx);
1319 mass_norm = Some(info.normalization_scale);
1320 }
1321 PenaltySource::OperatorTension => {
1322 tension_local_idx = Some(active_local_idx);
1323 tension_norm = Some(info.normalization_scale);
1324 }
1325 PenaltySource::OperatorStiffness => {
1326 stiffness_local_idx = Some(active_local_idx);
1327 stiffness_norm = Some(info.normalization_scale);
1328 }
1329 _ => {}
1330 }
1331 active_local_idx += 1;
1332 }
1333 let (
1346 Some(mass_local),
1347 Some(tension_local),
1348 Some(stiffness_local),
1349 Some(mass_scale),
1350 Some(tension_scale),
1351 Some(stiffness_scale),
1352 ) = (
1353 mass_local_idx,
1354 tension_local_idx,
1355 stiffness_local_idx,
1356 mass_norm,
1357 tension_norm,
1358 stiffness_norm,
1359 )
1360 else {
1361 continue;
1362 };
1363 let mass_global_idx = global_base_idx + mass_local;
1364 let tension_global_idx = global_base_idx + tension_local;
1365 let stiffness_global_idx = global_base_idx + stiffness_local;
1366
1367 let (feature_cols, mut d0, mut d1, mut d2, collocation_points, dim, center_mass_rows) =
1368 match (&termspec.basis, &term_fit.metadata) {
1369 (
1370 SmoothBasisSpec::Matern { feature_cols, .. },
1371 BasisMetadata::Matern {
1372 centers,
1373 length_scale,
1374 nu,
1375 include_intercept,
1376 identifiability_transform,
1377 aniso_log_scales,
1378 input_scales,
1379 ..
1380 },
1381 ) => {
1382 let collocation_length_scale = match input_scales.as_deref() {
1388 Some(scales) => {
1389 compensate_length_scale_for_standardization(*length_scale, scales)
1390 }
1391 None => *length_scale,
1392 };
1393 let ops = build_matern_collocation_operator_matrices(
1394 centers.view(),
1395 None,
1396 collocation_length_scale,
1397 *nu,
1398 *include_intercept,
1399 identifiability_transform.as_ref().map(|z| z.view()),
1400 aniso_log_scales.as_deref(),
1401 )?;
1402 (
1403 feature_cols.clone(),
1404 ops.d0,
1405 ops.d1,
1406 ops.d2,
1407 ops.collocation_points,
1408 centers.ncols(),
1409 false,
1410 )
1411 }
1412 (
1413 SmoothBasisSpec::Duchon { feature_cols, .. },
1414 BasisMetadata::Duchon {
1415 centers,
1416 length_scale,
1417 power,
1418 nullspace_order,
1419 identifiability_transform,
1420 input_scales,
1421 aniso_log_scales,
1422 operator_collocation_points: Some(collocation_points),
1423 ..
1424 },
1425 ) => {
1426 let collocation_length_scale = match (length_scale, input_scales.as_deref()) {
1427 (Some(ls), Some(scales)) => {
1428 Some(compensate_length_scale_for_standardization(*ls, scales))
1429 }
1430 (Some(ls), None) => Some(*ls),
1431 (None, _) => None,
1432 };
1433 let ops =
1434 gam_terms::basis::build_duchon_collocation_operator_matriceswithworkspace(
1435 centers.view(),
1436 collocation_points.view(),
1437 None,
1438 collocation_length_scale,
1439 *power,
1440 *nullspace_order,
1441 aniso_log_scales.as_deref(),
1442 identifiability_transform.as_ref().map(|z| z.view()),
1443 2,
1444 &mut BasisWorkspace::default(),
1445 )?;
1446 (
1447 feature_cols.clone(),
1448 ops.d0,
1449 ops.d1,
1450 ops.d2,
1451 ops.collocation_points,
1452 centers.ncols(),
1453 true,
1454 )
1455 }
1456 _ => continue,
1457 };
1458 if center_mass_rows && d0.nrows() > 0 && d0.ncols() > 0 {
1459 let means = d0.sum_axis(Axis(0)).mapv(|v| v / d0.nrows() as f64);
1460 for mut row in d0.rows_mut() {
1461 row -= &means;
1462 }
1463 }
1464
1465 let mass_scale = mass_scale.max(1e-12).sqrt();
1483 let tension_scale = tension_scale.max(1e-12).sqrt();
1484 let stiffness_scale = stiffness_scale.max(1e-12).sqrt();
1485 d0.mapv_inplace(|v| v / mass_scale);
1486 d1.mapv_inplace(|v| v / tension_scale);
1487 d2.mapv_inplace(|v| v / stiffness_scale);
1488
1489 let coeff_global_range =
1490 (smooth_start + term_fit.coeff_range.start)..(smooth_start + term_fit.coeff_range.end);
1491 if d0.ncols() != coeff_global_range.len()
1492 || d1.ncols() != coeff_global_range.len()
1493 || d2.ncols() != coeff_global_range.len()
1494 {
1495 crate::bail_invalid_estim!(
1496 "spatial operator dimension mismatch for term '{}': D0 cols={}, D1 cols={}, D2 cols={}, coeffs={}",
1497 term_fit.name,
1498 d0.ncols(),
1499 d1.ncols(),
1500 d2.ncols(),
1501 coeff_global_range.len()
1502 );
1503 }
1504 out.push(SpatialOperatorRuntimeCache {
1505 termname: term_fit.name.clone(),
1506 feature_cols,
1507 coeff_global_range,
1508 mass_penalty_global_idx: mass_global_idx,
1509 tension_penalty_global_idx: tension_global_idx,
1510 stiffness_penalty_global_idx: stiffness_global_idx,
1511 d0,
1512 d1,
1513 d2,
1514 collocation_points,
1515 dimension: dim,
1516 });
1517 }
1518 Ok(out)
1519}
1520
1521fn scalar_operator_response_variance(
1533 operator: &Array2<f64>,
1534 cov_local: &Array2<f64>,
1535) -> Array1<f64> {
1536 Array1::from_iter(operator.rows().into_iter().map(|row| {
1537 let s = cov_local.dot(&row);
1538 row.dot(&s).max(0.0)
1539 }))
1540}
1541
1542fn grouped_operator_response_variance(
1553 operator: &Array2<f64>,
1554 block_dim: usize,
1555 cov_local: &Array2<f64>,
1556) -> Result<Array1<f64>, EstimationError> {
1557 if block_dim == 0 || !operator.nrows().is_multiple_of(block_dim) {
1558 crate::bail_invalid_estim!(
1559 "grouped variance row layout invalid: rows={}, block_dim={block_dim}",
1560 operator.nrows()
1561 );
1562 }
1563 let p = operator.nrows() / block_dim;
1564 let mut out = Array1::<f64>::zeros(p);
1565 for k in 0..p {
1566 let mut acc = 0.0;
1567 for axis in 0..block_dim {
1568 let row = operator.row(k * block_dim + axis);
1569 let s = cov_local.dot(&row);
1570 acc += row.dot(&s);
1571 }
1572 out[k] = acc.max(0.0);
1573 }
1574 Ok(out)
1575}
1576
1577fn compute_spatial_adaptiveweights_for_beta(
1578 beta: &Array1<f64>,
1579 caches: &[SpatialOperatorRuntimeCache],
1580 epsilon_0: f64,
1581 epsilon_g: f64,
1582 epsilon_c: f64,
1583 weight_floor: f64,
1584 weight_ceiling: f64,
1585 beta_covariance: Option<&Array2<f64>>,
1586) -> Result<Vec<SpatialAdaptiveWeights>, EstimationError> {
1587 caches
1619 .iter()
1620 .map(|cache| {
1621 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
1622 let exact = SpatialPenaltyExactState::from_beta_local(
1623 beta_local,
1624 cache,
1625 [epsilon_0, epsilon_g, epsilon_c],
1626 )?;
1627 let cov_local = beta_covariance.map(|cov| {
1628 cov.slice(s![
1629 cache.coeff_global_range.clone(),
1630 cache.coeff_global_range.clone()
1631 ])
1632 .to_owned()
1633 });
1634 let dim = cache.dimension;
1635 let (var_0, var_g, var_c) = match cov_local.as_ref() {
1636 Some(cov) => (
1637 scalar_operator_response_variance(&cache.d0, cov),
1638 grouped_operator_response_variance(&cache.d1, dim, cov)?,
1639 grouped_operator_response_variance(&cache.d2, dim * dim, cov)?,
1640 ),
1641 None => (
1642 Array1::<f64>::zeros(exact.magnitude.signal.len()),
1643 Array1::<f64>::zeros(exact.gradient.norm.len()),
1644 Array1::<f64>::zeros(exact.curvature.norm.len()),
1645 ),
1646 };
1647 let (_, inv_0) = exact.magnitude.surrogateweights_posterior_snr(
1648 &var_0,
1649 weight_floor,
1650 weight_ceiling,
1651 );
1652 let (_, inv_g) =
1653 exact
1654 .gradient
1655 .surrogateweights_posterior_snr(&var_g, weight_floor, weight_ceiling);
1656 let (_, inv_c) = exact.curvature.surrogateweights_posterior_snr(
1657 &var_c,
1658 weight_floor,
1659 weight_ceiling,
1660 );
1661 Ok(SpatialAdaptiveWeights {
1662 inv_magweight: inv_0,
1663 invgradweight: inv_g,
1664 inv_lapweight: inv_c,
1665 })
1666 })
1667 .collect()
1668}
1669
1670fn compute_initial_epsilons(
1671 beta: &Array1<f64>,
1672 caches: &[SpatialOperatorRuntimeCache],
1673 min_epsilon: f64,
1674) -> Result<(f64, f64, f64), EstimationError> {
1675 let mut fvals = Vec::<f64>::new();
1676 let mut gvals = Vec::<f64>::new();
1677 let mut cvals = Vec::<f64>::new();
1678 for cache in caches {
1679 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
1680 let exact = SpatialPenaltyExactState::from_beta_local(
1681 beta_local,
1682 cache,
1683 [min_epsilon, min_epsilon, min_epsilon],
1684 )?;
1685 let (f, g, c) = exact.absolute_collocation_magnitudes();
1686 fvals.extend(f.iter().copied());
1687 gvals.extend(g.iter().copied());
1688 cvals.extend(c.iter().copied());
1689 }
1690 let eps_0 = robust_epsilon_from_samples(&fvals, min_epsilon);
1696 let eps_g = robust_epsilon_from_samples(&gvals, min_epsilon);
1697 let eps_c = robust_epsilon_from_samples(&cvals, min_epsilon);
1698 Ok((eps_0, eps_g, eps_c))
1699}
1700
1701fn exact_spatial_adaptive_penalty_index_set(
1702 caches: &[SpatialOperatorRuntimeCache],
1703) -> BTreeSet<usize> {
1704 let mut out = BTreeSet::new();
1705 for cache in caches {
1706 out.insert(cache.mass_penalty_global_idx);
1707 out.insert(cache.tension_penalty_global_idx);
1708 out.insert(cache.stiffness_penalty_global_idx);
1709 }
1710 out
1711}
1712
1713fn build_spatial_adaptive_hyperspecs(cache_count: usize) -> Vec<SpatialAdaptiveHyperSpec> {
1714 let mut out = Vec::with_capacity(cache_count * 3 + 3);
1715 for cache_index in 0..cache_count {
1716 out.push(SpatialAdaptiveHyperSpec {
1717 cache_index,
1718 kind: SpatialAdaptiveHyperKind::LogLambdaMagnitude,
1719 });
1720 out.push(SpatialAdaptiveHyperSpec {
1721 cache_index,
1722 kind: SpatialAdaptiveHyperKind::LogLambdaGradient,
1723 });
1724 out.push(SpatialAdaptiveHyperSpec {
1725 cache_index,
1726 kind: SpatialAdaptiveHyperKind::LogLambdaCurvature,
1727 });
1728 }
1729 out.push(SpatialAdaptiveHyperSpec {
1730 cache_index: 0,
1731 kind: SpatialAdaptiveHyperKind::LogEpsilonMagnitude,
1732 });
1733 out.push(SpatialAdaptiveHyperSpec {
1734 cache_index: 0,
1735 kind: SpatialAdaptiveHyperKind::LogEpsilonGradient,
1736 });
1737 out.push(SpatialAdaptiveHyperSpec {
1738 cache_index: 0,
1739 kind: SpatialAdaptiveHyperKind::LogEpsilonCurvature,
1740 });
1741 out
1742}
1743
1744fn penalty_matrixwith_local_block(
1745 total_dim: usize,
1746 coeff_range: Range<usize>,
1747 local: &Array2<f64>,
1748) -> Array2<f64> {
1749 let mut out = Array2::<f64>::zeros((total_dim, total_dim));
1750 out.slice_mut(s![coeff_range.clone(), coeff_range])
1751 .assign(local);
1752 out
1753}
1754
1755fn fit_term_collectionwith_exact_spatial_adaptive_regularization(
1756 baseline: FittedTermCollection,
1757 y: ArrayView1<'_, f64>,
1758 weights: ArrayView1<'_, f64>,
1759 offset: ArrayView1<'_, f64>,
1760 family: LikelihoodSpec,
1761 options: &FitOptions,
1762 runtime_caches: &[SpatialOperatorRuntimeCache],
1763) -> Result<FittedTermCollection, EstimationError> {
1764 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
1793 let adaptive_penalty_indices = exact_spatial_adaptive_penalty_index_set(runtime_caches);
1794 let p_total = baseline.design.design.ncols();
1795 struct RetainedPenaltySetup {
1796 global_idx: usize,
1797 global_penalty: Array2<f64>,
1798 nullspace_dim: usize,
1799 log_lambda: f64,
1800 col_range: Range<usize>,
1801 hessian_piece: Array2<f64>,
1802 }
1803 use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
1804 let retained_setups = baseline
1805 .design
1806 .penalties
1807 .par_iter()
1808 .enumerate()
1809 .map(|(idx, bp)| {
1810 if adaptive_penalty_indices.contains(&idx) {
1811 return None;
1812 }
1813 let lambda = baseline.fit.lambdas[idx];
1814 Some(RetainedPenaltySetup {
1815 global_idx: idx,
1816 global_penalty: bp.to_global(p_total),
1817 nullspace_dim: baseline
1818 .design
1819 .nullspace_dims
1820 .get(idx)
1821 .copied()
1822 .unwrap_or(0),
1823 log_lambda: lambda.max(1e-12).ln(),
1824 col_range: bp.col_range.clone(),
1825 hessian_piece: bp.local.mapv(|v| lambda * v),
1826 })
1827 })
1828 .collect::<Vec<_>>();
1829 let retained_count = retained_setups
1830 .iter()
1831 .filter(|setup| setup.is_some())
1832 .count();
1833 let mut retained_penalties = Vec::<Array2<f64>>::with_capacity(retained_count);
1834 let mut retained_nullspace_dims = Vec::<usize>::with_capacity(retained_count);
1835 let mut retained_log_lambdas = Vec::<f64>::with_capacity(retained_count);
1836 let mut retained_global_indices = Vec::<usize>::with_capacity(retained_count);
1837 let mut fixed_quadratichessian = Array2::<f64>::zeros((p_total, p_total));
1838 for setup in retained_setups.into_iter().flatten() {
1839 retained_penalties.push(setup.global_penalty);
1840 retained_nullspace_dims.push(setup.nullspace_dim);
1841 retained_log_lambdas.push(setup.log_lambda);
1842 retained_global_indices.push(setup.global_idx);
1843 fixed_quadratichessian
1844 .slice_mut(s![setup.col_range.clone(), setup.col_range])
1845 .scaled_add(1.0, &setup.hessian_piece);
1846 }
1847
1848 let (eps_0_init, eps_g_init, eps_c_init) = compute_initial_epsilons(
1849 &baseline.fit.beta,
1850 runtime_caches,
1851 adaptive_opts.min_epsilon,
1852 )?;
1853 let mut initial_theta =
1854 Array1::<f64>::zeros(retained_penalties.len() + runtime_caches.len() * 3 + 3);
1855 for (idx, value) in retained_log_lambdas.iter().enumerate() {
1856 initial_theta[idx] = *value;
1857 }
1858 let adaptive_log_lambda_components = runtime_caches
1859 .par_iter()
1860 .map(|cache| {
1861 [
1862 baseline.fit.lambdas[cache.mass_penalty_global_idx]
1863 .max(1e-12)
1864 .ln(),
1865 baseline.fit.lambdas[cache.tension_penalty_global_idx]
1866 .max(1e-12)
1867 .ln(),
1868 baseline.fit.lambdas[cache.stiffness_penalty_global_idx]
1869 .max(1e-12)
1870 .ln(),
1871 ]
1872 })
1873 .collect::<Vec<_>>();
1874 let mut at = retained_penalties.len();
1875 for logs in &adaptive_log_lambda_components {
1876 initial_theta[at] = logs[0];
1877 initial_theta[at + 1] = logs[1];
1878 initial_theta[at + 2] = logs[2];
1879 at += 3;
1880 }
1881 initial_theta[at] = eps_0_init.max(adaptive_opts.min_epsilon).ln();
1882 initial_theta[at + 1] = eps_g_init.max(adaptive_opts.min_epsilon).ln();
1883 initial_theta[at + 2] = eps_c_init.max(adaptive_opts.min_epsilon).ln();
1884
1885 let hyperspecs = build_spatial_adaptive_hyperspecs(runtime_caches.len());
1886 let zero_psi_op: std::sync::Arc<dyn gam_custom_family::CustomFamilyPsiDerivativeOperator> =
1887 std::sync::Arc::new(gam_custom_family::ZeroPsiDerivativeOperator::new(
1888 baseline.design.design.nrows(),
1889 baseline.design.design.ncols(),
1890 ));
1891 let derivative_blocks = vec![
1892 hyperspecs
1893 .par_iter()
1894 .map(|_| CustomFamilyBlockPsiDerivative {
1895 penalty_index: None,
1896 x_psi: Array2::<f64>::zeros((0, 0)),
1897 s_psi: Array2::<f64>::zeros((0, 0)),
1898 s_psi_components: None,
1899 s_psi_penalty_components: None,
1900 x_psi_psi: None,
1901 s_psi_psi: None,
1902 s_psi_psi_components: None,
1903 s_psi_psi_penalty_components: None,
1904 implicit_operator: Some(std::sync::Arc::clone(&zero_psi_op)),
1905 implicit_axis: 0,
1906 implicit_group_id: None,
1907 })
1908 .collect::<Vec<_>>(),
1909 ];
1910
1911 let mixture_link_state = options
1912 .mixture_link
1913 .clone()
1914 .as_ref()
1915 .map(state_fromspec)
1916 .transpose()
1917 .map_err(EstimationError::InvalidInput)?;
1918 let sas_link_state = options
1919 .sas_link
1920 .map(|spec| {
1921 if family.is_binomial_beta_logistic() {
1922 state_from_beta_logisticspec(spec)
1923 } else {
1924 state_from_sasspec(spec)
1925 }
1926 })
1927 .transpose()
1928 .map_err(EstimationError::InvalidInput)?;
1929 let latent_cloglog_state = options.latent_cloglog;
1930 let shared_y = Arc::new(y.to_owned());
1931 let sharedweights = Arc::new(weights.to_owned());
1932 let shared_design = baseline
1933 .design
1934 .design
1935 .try_to_dense_arc("spatial adaptive exact hyperfit design")
1936 .map_err(EstimationError::InvalidInput)?;
1937 let shared_offset = Arc::new(offset.to_owned());
1938 let shared_runtime_caches = Arc::new(runtime_caches.to_vec());
1939 let shared_hyperspecs = Arc::new(hyperspecs.clone());
1940 let zero_quadratic = Arc::new(Array2::<f64>::zeros((
1941 baseline.design.design.ncols(),
1942 baseline.design.design.ncols(),
1943 )));
1944 let base_family = SpatialAdaptiveExactFamily {
1945 family: family.clone(),
1946 latent_cloglog_state,
1947 mixture_link_state: mixture_link_state.clone(),
1948 sas_link_state,
1949 y: shared_y.clone(),
1950 weights: sharedweights.clone(),
1951 design: shared_design.clone(),
1952 offset: shared_offset.clone(),
1953 linear_constraints: baseline.design.linear_constraints.clone(),
1954 runtime_caches: shared_runtime_caches.clone(),
1955 adaptive_params: Vec::new(),
1956 fixed_quadratichessian: zero_quadratic.clone(),
1957 hyperspecs: shared_hyperspecs.clone(),
1958 exact_eval_cache: Arc::new(Mutex::new(None)),
1959 };
1960
1961 let rho_dim = retained_penalties.len();
1962 let operator_slots_end = rho_dim + runtime_caches.len() * 3;
1963 const UNIFIED_LOG_WINDOW: f64 = 6.0;
1973 const RETAINED_LAMBDA_LOG_LOWER_FLOOR: f64 = -30.0;
1974 const RETAINED_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
1975 const OPERATOR_LAMBDA_LOG_LOWER_FLOOR: f64 = -10.0;
1976 const OPERATOR_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
1977 let epsilon_floor_log = adaptive_opts.min_epsilon.max(1e-12).ln();
1978 let anchored_bound = |idx: usize, sign: f64| -> f64 {
1979 let raw = initial_theta[idx] + sign * UNIFIED_LOG_WINDOW;
1980 if idx < rho_dim {
1981 raw.clamp(
1982 RETAINED_LAMBDA_LOG_LOWER_FLOOR,
1983 RETAINED_LAMBDA_LOG_UPPER_CAP,
1984 )
1985 } else if idx < operator_slots_end {
1986 raw.clamp(
1987 OPERATOR_LAMBDA_LOG_LOWER_FLOOR,
1988 OPERATOR_LAMBDA_LOG_UPPER_CAP,
1989 )
1990 } else {
1991 raw.max(epsilon_floor_log)
1992 }
1993 };
1994 let eps_lower =
1995 Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, -1.0)));
1996 let eps_upper = Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, 1.0)));
1997 let blockspec = ParameterBlockSpec {
1998 name: "eta".to_string(),
1999 design: baseline.design.design.clone(),
2000 offset: offset.to_owned(),
2001 penalties: retained_penalties
2002 .iter()
2003 .cloned()
2004 .map(PenaltyMatrix::Dense)
2005 .collect(),
2006 nullspace_dims: retained_nullspace_dims.clone(),
2007 initial_log_lambdas: Array1::from_vec(retained_log_lambdas.clone()),
2008 initial_beta: Some(baseline.fit.beta.clone()),
2009 gauge_priority: 100,
2010 jacobian_callback: None,
2011 stacked_design: None,
2012 stacked_offset: None,
2013 };
2014 let screening_cap = Arc::new(AtomicUsize::new(0));
2015 let outer_opts = BlockwiseFitOptions {
2016 inner_max_cycles: options.max_iter,
2017 inner_tol: options.tol,
2018 outer_max_iter: options.max_iter,
2019 outer_tol: options.tol,
2020 compute_covariance: false,
2021 screening_max_inner_iterations: Some(Arc::clone(&screening_cap)),
2022 ..BlockwiseFitOptions::default()
2023 };
2024
2025 use gam_solve::rho_optimizer::OuterProblem;
2026 use gam_problem::{DeclaredHessianForm, Derivative, HessianResult, OuterEval};
2027
2028 struct SpatialAdaptiveOuterState {
2029 warm_cache: Option<CustomFamilyWarmStart>,
2030 last_eval: Option<(
2031 Array1<f64>,
2032 f64,
2033 Array1<f64>,
2034 HessianResult,
2035 CustomFamilyWarmStart,
2036 )>,
2037 }
2038
2039 let n_theta = initial_theta.len();
2040
2041 let theta_bounds = Some((eps_lower.clone(), eps_upper.clone()));
2044 let clamp_theta = {
2045 let lo = eps_lower;
2046 let hi = eps_upper;
2047 move |theta: &Array1<f64>| -> Array1<f64> {
2048 let mut clamped = theta.clone();
2049 for i in 0..clamped.len() {
2050 clamped[i] = clamped[i].clamp(lo[i], hi[i]);
2051 }
2052 clamped
2053 }
2054 };
2055
2056 let decode_theta = |theta: &Array1<f64>| -> (Array1<f64>, Vec<SpatialAdaptiveTermHyperParams>) {
2057 let rho = theta.slice(s![..rho_dim]).to_owned();
2058 let adaptive_lambda_start = rho_dim;
2059 let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
2060 let eps = [
2061 theta[adaptive_lambda_end].exp(),
2062 theta[adaptive_lambda_end + 1].exp(),
2063 theta[adaptive_lambda_end + 2].exp(),
2064 ];
2065 let adaptive_params = runtime_caches
2066 .iter()
2067 .enumerate()
2068 .map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
2069 lambda: [
2070 theta[adaptive_lambda_start + cache_idx * 3].exp(),
2071 theta[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
2072 theta[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
2073 ],
2074 epsilon: eps,
2075 })
2076 .collect::<Vec<_>>();
2077 (rho, adaptive_params)
2078 };
2079 let analytic_outer_hessian_available =
2080 gam_custom_family::joint_exact_analytic_outer_hessian_available()
2081 && base_family
2082 .exact_outer_derivative_order(std::slice::from_ref(&blockspec), &outer_opts)
2083 .has_hessian()
2084 && gam_custom_family::exact_newton_outer_geometry_supports_second_order_solver(
2085 &base_family,
2086 );
2087 let outer_max_iter = gam_custom_family::cost_gated_first_order_max_iter(
2088 options.max_iter,
2089 base_family.coefficient_gradient_cost(std::slice::from_ref(&blockspec)),
2090 analytic_outer_hessian_available,
2091 );
2092 if outer_max_iter < options.max_iter {
2093 log::info!(
2094 "[OUTER] exact spatial adaptive regularization: first-order work gate reduced outer_max_iter {} -> {}",
2095 options.max_iter,
2096 outer_max_iter,
2097 );
2098 }
2099 let problem = OuterProblem::new(n_theta)
2105 .with_gradient(Derivative::Analytic)
2106 .with_hessian(if analytic_outer_hessian_available {
2107 DeclaredHessianForm::Either
2108 } else {
2109 DeclaredHessianForm::Unavailable
2110 })
2111 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Disabled)
2112 .with_psi_dim(n_theta.saturating_sub(rho_dim))
2113 .with_tolerance(options.tol)
2114 .with_max_iter(outer_max_iter)
2115 .with_seed_config(gam_problem::SeedConfig::default())
2116 .with_screening_cap(Arc::clone(&screening_cap))
2117 .with_initial_rho(initial_theta.clone());
2118 let problem = if let Some((lo, hi)) = theta_bounds {
2119 problem.with_bounds(lo, hi)
2120 } else {
2121 problem
2122 };
2123
2124 let eval_outer = |st: &mut SpatialAdaptiveOuterState,
2125 theta: &Array1<f64>,
2126 order: gam_solve::rho_optimizer::OuterEvalOrder|
2127 -> Result<OuterEval, EstimationError> {
2128 let theta = clamp_theta(theta);
2129
2130 if let Some((cached_theta, cached_cost, cached_grad, cached_hess, cached_warm)) =
2131 &st.last_eval
2132 && cached_theta.len() == theta.len()
2133 && cached_theta
2134 .iter()
2135 .zip(theta.iter())
2136 .all(|(&a, &b)| (a - b).abs() <= 1e-12)
2137 && (!matches!(
2138 order,
2139 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2140 ) || analytic_outer_hessian_available)
2141 {
2142 st.warm_cache = Some(cached_warm.clone());
2143 return Ok(OuterEval {
2144 cost: *cached_cost,
2145 gradient: cached_grad.clone(),
2146 hessian: if matches!(
2147 order,
2148 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2149 ) && analytic_outer_hessian_available
2150 {
2151 cached_hess.clone()
2152 } else {
2153 HessianResult::Unavailable
2154 },
2155 inner_beta_hint: None,
2156 });
2157 }
2158
2159 let (rho, adaptive_params) = decode_theta(&theta);
2160 let family_eval = base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2161 let need_hessian = matches!(
2162 order,
2163 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2164 ) && analytic_outer_hessian_available;
2165 let result = evaluate_custom_family_joint_hyper(
2166 &family_eval,
2167 std::slice::from_ref(&blockspec),
2168 &outer_opts,
2169 &rho,
2170 &derivative_blocks,
2171 st.warm_cache.as_ref(),
2172 if need_hessian {
2173 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
2174 } else {
2175 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
2176 },
2177 )
2178 .map_err(|e| {
2179 EstimationError::RemlOptimizationFailed(format!("spatial adaptive eval failed: {e}"))
2180 })?;
2181 if !result.inner_converged {
2182 st.warm_cache = Some(result.warm_start.clone());
2183 return Err(EstimationError::RemlOptimizationFailed(
2184 "exact spatial adaptive inner solve did not converge".to_string(),
2185 ));
2186 }
2187 if !result.objective.is_finite() || result.gradient.iter().any(|v| !v.is_finite()) {
2188 return Err(EstimationError::RemlOptimizationFailed(
2189 "exact spatial adaptive objective returned non-finite values".to_string(),
2190 ));
2191 }
2192 let hessian_result = if need_hessian {
2193 if !result.outer_hessian.is_analytic() {
2194 return Err(EstimationError::RemlOptimizationFailed(
2195 "exact spatial adaptive objective did not return an exact outer Hessian"
2196 .to_string(),
2197 ));
2198 }
2199 match result.outer_hessian.dim() {
2200 Some(dim) if dim == theta.len() => {}
2201 Some(dim) => {
2202 return Err(EstimationError::RemlOptimizationFailed(format!(
2203 "exact spatial adaptive outer Hessian dimension mismatch: got {dim}, expected {}",
2204 theta.len(),
2205 )));
2206 }
2207 None => {
2208 return Err(EstimationError::RemlOptimizationFailed(
2209 "exact spatial adaptive objective did not report an outer Hessian dimension"
2210 .to_string(),
2211 ));
2212 }
2213 }
2214 st.last_eval = Some((
2215 theta.clone(),
2216 result.objective,
2217 result.gradient.clone(),
2218 result.outer_hessian.clone(),
2219 result.warm_start.clone(),
2220 ));
2221 result.outer_hessian
2222 } else {
2223 HessianResult::Unavailable
2224 };
2225 st.warm_cache = Some(result.warm_start);
2226 Ok(OuterEval {
2227 cost: result.objective,
2228 gradient: result.gradient,
2229 hessian: hessian_result,
2230 inner_beta_hint: None,
2231 })
2232 };
2233
2234 let mut obj = problem.build_objective_with_screening_proxy(
2235 SpatialAdaptiveOuterState {
2236 warm_cache: None,
2237 last_eval: None,
2238 },
2239 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2240 let theta = clamp_theta(theta);
2241 let (rho, adaptive_params) = decode_theta(&theta);
2242 let family_eval =
2243 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2244 let result = evaluate_custom_family_joint_hyper(
2245 &family_eval,
2246 std::slice::from_ref(&blockspec),
2247 &outer_opts,
2248 &rho,
2249 &derivative_blocks,
2250 st.warm_cache.as_ref(),
2251 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
2252 )
2253 .map_err(|e| {
2254 EstimationError::RemlOptimizationFailed(format!(
2255 "spatial adaptive cost eval failed: {e}"
2256 ))
2257 })?;
2258 if !result.inner_converged {
2259 st.warm_cache = Some(result.warm_start);
2260 return Err(EstimationError::RemlOptimizationFailed(
2261 "exact spatial adaptive cost inner solve did not converge".to_string(),
2262 ));
2263 }
2264 st.warm_cache = Some(result.warm_start);
2265 Ok(result.objective)
2266 },
2267 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2268 eval_outer(
2269 st,
2270 theta,
2271 if analytic_outer_hessian_available {
2272 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2273 } else {
2274 gam_solve::rho_optimizer::OuterEvalOrder::ValueAndGradient
2275 },
2276 )
2277 },
2278 |st: &mut SpatialAdaptiveOuterState,
2279 theta: &Array1<f64>,
2280 order: gam_solve::rho_optimizer::OuterEvalOrder| {
2281 eval_outer(st, theta, order)
2282 },
2283 Some(|st: &mut SpatialAdaptiveOuterState| {
2284 st.warm_cache = None;
2285 st.last_eval = None;
2286 }),
2287 Some(|st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2288 let theta = clamp_theta(theta);
2289 let (rho, adaptive_params) = decode_theta(&theta);
2290 let family_eval =
2291 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2292 let result = evaluate_custom_family_joint_hyper_efs(
2293 &family_eval,
2294 std::slice::from_ref(&blockspec),
2295 &outer_opts,
2296 &rho,
2297 &derivative_blocks,
2298 st.warm_cache.as_ref(),
2299 )
2300 .map_err(|e| {
2301 EstimationError::RemlOptimizationFailed(format!(
2302 "spatial adaptive EFS eval failed: {e}"
2303 ))
2304 })?;
2305 if !result.inner_converged {
2306 st.warm_cache = Some(result.warm_start);
2307 return Err(EstimationError::RemlOptimizationFailed(
2308 "exact spatial adaptive EFS inner solve did not converge".to_string(),
2309 ));
2310 }
2311 st.warm_cache = Some(result.warm_start);
2312 Ok(result.efs_eval)
2313 }),
2314 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2326 let theta = clamp_theta(theta);
2327 let (rho, adaptive_params) = decode_theta(&theta);
2328 let family_eval =
2329 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2330 let result = evaluate_custom_family_joint_hyper(
2331 &family_eval,
2332 std::slice::from_ref(&blockspec),
2333 &outer_opts,
2334 &rho,
2335 &derivative_blocks,
2336 st.warm_cache.as_ref(),
2337 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
2338 )
2339 .map_err(|e| {
2340 EstimationError::RemlOptimizationFailed(format!(
2341 "spatial adaptive screening eval failed: {e}"
2342 ))
2343 })?;
2344 st.warm_cache = Some(result.warm_start);
2345 Ok(result.objective)
2346 },
2347 );
2348
2349 let outer_result = problem
2350 .run(&mut obj, "exact spatial adaptive regularization")
2351 .map_err(|e| {
2352 EstimationError::InvalidInput(format!(
2353 "exact spatial adaptive outer optimization failed: {e}"
2354 ))
2355 })?;
2356 if !outer_result.converged {
2357 let rel_to_cost_threshold = options.tol * (1.0_f64 + outer_result.final_value.abs());
2374 if let Some(final_grad) = outer_result
2378 .final_grad_norm
2379 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
2380 {
2381 log::info!(
2382 "[spatial-adaptive] outer optimization hit max_iter={} but \
2383 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
2384 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
2385 relative-to-cost REML convergence criterion.",
2386 outer_result.iterations,
2387 final_grad,
2388 rel_to_cost_threshold,
2389 options.tol,
2390 outer_result.final_value.abs(),
2391 );
2392 } else {
2393 crate::bail_invalid_estim!(
2394 "exact spatial adaptive outer optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
2395 outer_result.iterations,
2396 outer_result.final_value,
2397 outer_result.final_grad_norm_report(),
2398 );
2399 }
2400 }
2401 let outer_iterations = outer_result.iterations;
2402 let outer_grad_norm: Option<f64> = outer_result.final_grad_norm;
2405 let theta_star = outer_result.rho;
2406 let rho_star = theta_star.slice(s![..rho_dim]).to_owned();
2407 let adaptive_lambda_start = rho_dim;
2408 let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
2409 let eps_star = [
2410 theta_star[adaptive_lambda_end].exp(),
2411 theta_star[adaptive_lambda_end + 1].exp(),
2412 theta_star[adaptive_lambda_end + 2].exp(),
2413 ];
2414 let adaptive_params = runtime_caches
2415 .iter()
2416 .enumerate()
2417 .map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
2418 lambda: [
2419 theta_star[adaptive_lambda_start + cache_idx * 3].exp(),
2420 theta_star[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
2421 theta_star[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
2422 ],
2423 epsilon: eps_star,
2424 })
2425 .collect::<Vec<_>>();
2426 let mut fixed_total = Array2::<f64>::zeros((
2427 baseline.design.design.ncols(),
2428 baseline.design.design.ncols(),
2429 ));
2430 for (idx, penalty) in retained_penalties.iter().enumerate() {
2431 fixed_total.scaled_add(rho_star[idx].exp(), penalty);
2432 }
2433 let final_family =
2434 base_family.with_adaptive_params(adaptive_params.clone(), Arc::new(fixed_total.clone()));
2435 let final_blockspec = ParameterBlockSpec {
2436 name: "eta".to_string(),
2437 design: baseline.design.design.clone(),
2438 offset: offset.to_owned(),
2439 penalties: vec![],
2440 nullspace_dims: vec![],
2441 initial_log_lambdas: Array1::zeros(0),
2442 initial_beta: Some(baseline.fit.beta.clone()),
2443 gauge_priority: 100,
2444 jacobian_callback: None,
2445 stacked_design: None,
2446 stacked_offset: None,
2447 };
2448 let final_fit = fit_custom_family(
2449 &final_family,
2450 &[final_blockspec],
2451 &BlockwiseFitOptions {
2452 inner_max_cycles: options.max_iter,
2453 inner_tol: options.tol,
2454 outer_max_iter: 1,
2455 outer_tol: options.tol,
2456 compute_covariance: true,
2457 ..BlockwiseFitOptions::default()
2458 },
2459 )
2460 .map_err(EstimationError::CustomFamily)?;
2461 let beta = final_fit.block_states[0].beta.clone();
2462 let final_eval = final_family
2463 .exact_evaluation(&beta)
2464 .map_err(EstimationError::InvalidInput)?;
2465 let penalized_hessian = final_eval
2466 .totalobjectivehessian(&final_family.design)
2467 .map_err(EstimationError::InvalidInput)?;
2468 let beta_covariance = final_fit.covariance_conditional.clone();
2469 let beta_standard_errors = beta_covariance
2470 .as_ref()
2471 .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));
2472
2473 let mut full_lambdas = baseline.fit.lambdas.clone();
2474 for (idx, &global_idx) in retained_global_indices.iter().enumerate() {
2475 full_lambdas[global_idx] = rho_star[idx].exp();
2476 }
2477 for (cache_idx, cache) in runtime_caches.iter().enumerate() {
2478 full_lambdas[cache.mass_penalty_global_idx] = adaptive_params[cache_idx].lambda[0];
2479 full_lambdas[cache.tension_penalty_global_idx] = adaptive_params[cache_idx].lambda[1];
2480 full_lambdas[cache.stiffness_penalty_global_idx] = adaptive_params[cache_idx].lambda[2];
2481 }
2482
2483 let deviance = if family.is_gaussian_identity() {
2484 y.iter()
2485 .zip(final_eval.obs.mu.iter())
2486 .zip(weights.iter())
2487 .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
2488 .sum()
2489 } else {
2490 -2.0 * final_eval.obs.log_likelihood
2491 };
2492 let mut local_penalty_blocks =
2493 Vec::<PenaltySpec>::with_capacity(baseline.design.penalties.len());
2494 for (global_idx, bp) in baseline.design.penalties.iter().enumerate() {
2495 if adaptive_penalty_indices.contains(&global_idx) {
2496 let cache = runtime_caches
2497 .iter()
2498 .find(|cache| {
2499 cache.mass_penalty_global_idx == global_idx
2500 || cache.tension_penalty_global_idx == global_idx
2501 || cache.stiffness_penalty_global_idx == global_idx
2502 })
2503 .ok_or_else(|| {
2504 EstimationError::InvalidInput(format!(
2505 "missing runtime cache for adaptive penalty index {global_idx}"
2506 ))
2507 })?;
2508 let cache_idx = runtime_caches
2509 .iter()
2510 .position(|c| {
2511 c.mass_penalty_global_idx == global_idx
2512 || c.tension_penalty_global_idx == global_idx
2513 || c.stiffness_penalty_global_idx == global_idx
2514 })
2515 .ok_or_else(|| {
2516 EstimationError::InvalidInput(format!(
2517 "missing adaptive cache position for penalty index {global_idx}"
2518 ))
2519 })?;
2520 let state = &final_eval.adaptive_states[cache_idx];
2521 let local = if cache.mass_penalty_global_idx == global_idx {
2522 scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag())
2523 .mapv(|v| adaptive_params[cache_idx].lambda[0] * v)
2524 } else if cache.tension_penalty_global_idx == global_idx {
2525 grouped_operatorhessian(
2526 &cache.d1,
2527 cache.dimension,
2528 &state.gradient.betahessian_blocks(),
2529 )?
2530 .mapv(|v| adaptive_params[cache_idx].lambda[1] * v)
2531 } else {
2532 grouped_operatorhessian(
2533 &cache.d2,
2534 cache.dimension * cache.dimension,
2535 &state.curvature.betahessian_blocks(),
2536 )?
2537 .mapv(|v| adaptive_params[cache_idx].lambda[2] * v)
2538 };
2539 local_penalty_blocks.push(PenaltySpec::Dense(penalty_matrixwith_local_block(
2541 baseline.design.design.ncols(),
2542 cache.coeff_global_range.clone(),
2543 &local,
2544 )));
2545 } else {
2546 local_penalty_blocks.push(PenaltySpec::Dense(
2547 bp.to_global(p_total).mapv(|v| v * full_lambdas[global_idx]),
2548 ));
2549 }
2550 }
2551 let (edf_by_block, penalty_block_trace, edf_total) = if let Some(cov) = beta_covariance.as_ref()
2552 {
2553 exact_bounded_edf(
2554 &local_penalty_blocks,
2555 &Array1::from_elem(local_penalty_blocks.len(), 1.0),
2556 cov,
2557 )?
2558 } else {
2559 (
2560 vec![0.0; local_penalty_blocks.len()],
2561 vec![0.0; local_penalty_blocks.len()],
2562 0.0,
2563 )
2564 };
2565 let stable_penalty_term =
2566 2.0 * final_eval.adaptive_penalty_value + beta.dot(&fixed_total.dot(&beta));
2567 let standard_deviation = if family.is_gaussian_identity() {
2568 let denom = (y.len() as f64 - edf_total).max(1.0);
2569 (deviance / denom).sqrt()
2570 } else {
2571 1.0
2572 };
2573 let maps = compute_spatial_adaptiveweights_for_beta(
2574 &beta,
2575 runtime_caches,
2576 eps_star[0],
2577 eps_star[1],
2578 eps_star[2],
2579 adaptive_opts.weight_floor,
2580 adaptive_opts.weight_ceiling,
2581 beta_covariance.as_ref(),
2585 )?
2586 .into_iter()
2587 .zip(runtime_caches.iter())
2588 .map(|(w, cache)| AdaptiveSpatialMap {
2589 termname: cache.termname.clone(),
2590 feature_cols: cache.feature_cols.clone(),
2591 collocation_points: cache.collocation_points.clone(),
2592 inv_magweight: w.inv_magweight,
2593 invgradweight: w.invgradweight,
2594 inv_lapweight: w.inv_lapweight,
2595 })
2596 .collect::<Vec<_>>();
2597 let fitted_link = if family.is_latent_cloglog() {
2598 FittedLinkState::LatentCLogLog {
2599 state: latent_cloglog_state
2600 .expect("BinomialLatentCLogLog requires an explicit latent-cloglog state"),
2601 }
2602 } else if family.is_binomial_mixture() {
2603 mixture_link_state
2604 .clone()
2605 .map(|state| FittedLinkState::Mixture {
2606 state,
2607 covariance: None,
2608 })
2609 .unwrap_or(FittedLinkState::Standard(None))
2610 } else if family.is_binomial_sas() {
2611 sas_link_state
2612 .map(|state| FittedLinkState::Sas {
2613 state,
2614 covariance: None,
2615 })
2616 .unwrap_or(FittedLinkState::Standard(None))
2617 } else if family.is_binomial_beta_logistic() {
2618 sas_link_state
2619 .map(|state| FittedLinkState::BetaLogistic {
2620 state,
2621 covariance: None,
2622 })
2623 .unwrap_or(FittedLinkState::Standard(None))
2624 } else {
2625 FittedLinkState::Standard(None)
2626 };
2627 let max_abs_eta = final_eval
2628 .obs
2629 .eta
2630 .iter()
2631 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2632 let fitted = FittedTermCollection {
2633 fit: {
2634 let log_lambdas = full_lambdas.mapv(|v| v.max(1e-300).ln());
2635 let inf = FitInference {
2636 edf_by_block,
2637 penalty_block_trace,
2638 edf_total,
2639 smoothing_correction: None,
2640 penalized_hessian: penalized_hessian.clone().into(),
2643 working_weights: final_eval.obs.fisherweight.clone(),
2644 working_response: {
2645 let mut out = final_eval.obs.eta.clone();
2646 for i in 0..out.len() {
2647 let wi = final_eval.obs.fisherweight[i].max(1e-12);
2648 out[i] += final_eval.obs.score[i] / wi;
2649 }
2650 out
2651 },
2652 reparam_qs: None,
2653 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2654 beta_covariance: beta_covariance
2655 .clone()
2656 .map(gam_problem::dispersion_cov::PhiScaledCovariance::from),
2657 beta_standard_errors,
2658 beta_covariance_corrected: None,
2659 beta_standard_errors_corrected: None,
2660 beta_covariance_frequentist: None,
2661 coefficient_influence: None,
2662 weighted_gram: None,
2663 bias_correction_beta: None,
2664 };
2665 let geometry = Some(gam_solve::estimate::FitGeometry {
2666 penalized_hessian: penalized_hessian.into(),
2667 working_weights: inf.working_weights.clone(),
2668 working_response: inf.working_response.clone(),
2669 });
2670 let covariance_conditional = beta_covariance;
2671 let pirls_status_val = if final_fit.outer_converged {
2672 gam_solve::pirls::PirlsStatus::Converged
2673 } else {
2674 gam_solve::pirls::PirlsStatus::StalledAtValidMinimum
2675 };
2676 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
2677 blocks: vec![gam_solve::estimate::FittedBlock {
2678 beta: beta.clone(),
2679 role: gam_problem::BlockRole::Mean,
2680 edf: edf_total,
2681 lambdas: full_lambdas.clone(),
2682 }],
2683 log_lambdas,
2684 lambdas: full_lambdas,
2685 likelihood_scale: family.default_scale_metadata(),
2686 likelihood_family: Some(family),
2687 log_likelihood_normalization:
2688 gam_spec::LogLikelihoodNormalization::UserProvided,
2689 log_likelihood: final_eval.obs.log_likelihood,
2690 deviance,
2691 reml_score: final_fit.penalized_objective,
2692 stable_penalty_term,
2693 penalized_objective: final_fit.penalized_objective,
2694 used_device: false,
2695 outer_iterations,
2696 outer_converged: final_fit.outer_converged,
2697 outer_gradient_norm: outer_grad_norm,
2698 standard_deviation,
2699 covariance_conditional,
2700 covariance_corrected: None,
2701 inference: Some(inf),
2702 fitted_link,
2703 geometry,
2704 block_states: Vec::new(),
2705 pirls_status: pirls_status_val,
2706 max_abs_eta,
2707 constraint_kkt: None,
2708 artifacts: gam_solve::estimate::FitArtifacts {
2709 pirls: None,
2710 ..Default::default()
2711 },
2712 inner_cycles: 0,
2713 })?
2714 },
2715 design: baseline.design,
2716 adaptive_diagnostics: Some(AdaptiveRegularizationDiagnostics {
2717 epsilon_0: eps_star[0],
2718 epsilon_g: eps_star[1],
2719 epsilon_c: eps_star[2],
2720 epsilon_outer_iterations: outer_iterations,
2721 mm_iterations: 0,
2722 converged: final_fit.outer_converged,
2723 maps,
2724 }),
2725 };
2726 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
2727 Ok(fitted)
2728}
2729
2730fn relax_smoothing_rho_prior(
2762 options: &FitOptions,
2763 design: &TermCollectionDesign,
2764) -> gam_spec::RhoPrior {
2765 use gam_terms::basis::BasisMetadata;
2766 let base = &options.rho_prior;
2767 if matches!(
2770 base,
2771 gam_spec::RhoPrior::Flat | gam_spec::RhoPrior::Independent(_)
2772 ) {
2773 return base.clone();
2774 }
2775 let has_link_aux = options.sas_link.is_some()
2795 || options.optimize_sas
2796 || options.mixture_link.is_some()
2797 || options.optimize_mixture;
2798 let has_moving_kappa = design.smooth.terms.iter().any(|t| {
2799 matches!(
2800 t.metadata,
2801 BasisMetadata::Matern { .. }
2802 | BasisMetadata::Duchon { .. }
2803 | BasisMetadata::Sphere { .. }
2804 | BasisMetadata::SphereHarmonics { .. }
2805 | BasisMetadata::ConstantCurvature { .. }
2806 | BasisMetadata::MeasureJet { .. }
2807 )
2808 });
2809 let length_safe = !has_link_aux && !has_moving_kappa;
2816 if !length_safe {
2817 return base.clone();
2818 }
2819 let coords = &design.penaltyinfo;
2820 if coords.is_empty() {
2821 return base.clone();
2822 }
2823 let n_obs = design.design.nrows();
2834 let p_total = design.design.ncols();
2835 let underdetermined = n_obs < 2 * p_total;
2866 let relaxable_terms: std::collections::HashSet<&str> = design
2878 .smooth
2879 .terms
2880 .iter()
2881 .filter(|t| {
2882 matches!(
2883 t.metadata,
2884 BasisMetadata::BSpline1D { .. }
2885 | BasisMetadata::ThinPlate { .. }
2886 | BasisMetadata::TensorBSpline { .. }
2887 )
2888 && matches!(t.shape, gam_terms::smooth::ShapeConstraint::None)
2902 })
2903 .map(|t| t.name.as_str())
2904 .collect();
2905 let any_relaxed = coords.iter().any(|info| {
2906 info.termname
2907 .as_deref()
2908 .is_some_and(|name| relaxable_terms.contains(name))
2909 });
2910 if !any_relaxed {
2911 return base.clone();
2912 }
2913 let relaxed_prior = if underdetermined {
2918 gam_spec::RhoPrior::Normal {
2919 mean: 0.0,
2920 sd: RELAX_UNDERDETERMINED_RHO_SD,
2921 }
2922 } else {
2923 gam_spec::RhoPrior::Flat
2924 };
2925 let nullspace_select_prior = gam_spec::RhoPrior::PenalizedComplexity {
2952 upper: NULLSPACE_SELECT_PC_UPPER,
2953 tail_prob: NULLSPACE_SELECT_PC_TAIL_PROB,
2954 };
2955 let nullspace_degeneracy_prior = gam_spec::RhoPrior::Normal {
2982 mean: 0.0,
2983 sd: NULLSPACE_WELLDET_DEGENERACY_RHO_SD,
2984 };
2985 let per_coord = coords
2986 .iter()
2987 .map(|info| {
2988 let relax = info
2989 .termname
2990 .as_deref()
2991 .is_some_and(|name| relaxable_terms.contains(name));
2992 if !relax {
2993 return base.clone();
2994 }
2995 let is_nullspace =
2996 matches!(info.penalty.source, PenaltySource::DoublePenaltyNullspace);
2997 if is_nullspace {
3036 if underdetermined {
3037 nullspace_select_prior.clone()
3038 } else {
3039 nullspace_degeneracy_prior.clone()
3040 }
3041 } else {
3042 relaxed_prior.clone()
3043 }
3044 })
3045 .collect::<Vec<_>>();
3046 gam_spec::RhoPrior::Independent(per_coord)
3047}
3048
/// Standard deviation of the zero-mean `RhoPrior::Normal` that
/// `relax_smoothing_rho_prior` assigns to relaxable, non-nullspace penalty
/// coordinates when the design is underdetermined (`n_obs < 2 * p_total`).
const RELAX_UNDERDETERMINED_RHO_SD: f64 = 15.0;

/// `upper` parameter of the `RhoPrior::PenalizedComplexity` prior that
/// `relax_smoothing_rho_prior` assigns to double-penalty nullspace coordinates
/// of relaxable terms when the design is underdetermined.
const NULLSPACE_SELECT_PC_UPPER: f64 = 0.05;

/// `tail_prob` parameter of the same `RhoPrior::PenalizedComplexity` prior
/// (paired with `NULLSPACE_SELECT_PC_UPPER`).
const NULLSPACE_SELECT_PC_TAIL_PROB: f64 = 0.01;
3092
3093fn adaptive_fit_options_base(options: &FitOptions, design: &TermCollectionDesign) -> FitOptions {
3094 FitOptions {
3095 latent_cloglog: options.latent_cloglog,
3096 mixture_link: options.mixture_link.clone(),
3097 optimize_mixture: options.optimize_mixture,
3098 sas_link: options.sas_link,
3099 optimize_sas: options.optimize_sas,
3100 compute_inference: options.compute_inference,
3101 skip_rho_posterior_inference: options.skip_rho_posterior_inference,
3102 max_iter: options.max_iter,
3103 tol: options.tol,
3104 nullspace_dims: design.nullspace_dims.clone(),
3105 linear_constraints: design.linear_constraints.clone(),
3106 firth_bias_reduction: options.firth_bias_reduction,
3107 adaptive_regularization: None,
3108 penalty_shrinkage_floor: options.penalty_shrinkage_floor,
3109 rho_prior: options.rho_prior.clone(),
3112 kronecker_penalty_system: design.kronecker_penalty_system(),
3113 kronecker_factored: design
3114 .smooth
3115 .terms
3116 .iter()
3117 .find_map(|t| t.kronecker_factored.clone()),
3118 persist_warm_start_disk: options.persist_warm_start_disk,
3119 }
3120}
3121
3122fn superseded_fit_options(options: &FitOptions) -> FitOptions {
3123 let mut fit_options = options.clone();
3124 fit_options.skip_rho_posterior_inference = true;
3125 fit_options
3126}
3127
/// Metadata for one bounded linear coefficient.
///
/// `BoundedEffectiveJacobian` reads `col_idx`, `min` and `max` to rescale the
/// matching design column by the derivative of the bounded transform.
#[derive(Clone)]
struct BoundedLinearTermMeta {
    /// Index of the coefficient, used both as a design column index and as an
    /// index into the coefficient vector.
    col_idx: usize,
    /// Lower bound handed to the bounded latent transform.
    min: f64,
    /// Upper bound handed to the bounded latent transform.
    max: f64,
    /// Prior specification attached to this coefficient.
    prior: BoundedCoefficientPriorSpec,
}
3135
/// Row-sliceable Jacobian for a design that contains bounded linear terms.
///
/// Its `BlockEffectiveJacobian` implementation returns rows of `design` with
/// each bounded column multiplied by the derivative of the bounded coefficient
/// with respect to its latent parameter.
struct BoundedEffectiveJacobian {
    /// Dense design matrix whose rows are sliced on request.
    design: Array2<f64>,
    /// Bounded terms whose design columns are rescaled.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
3163
3164impl BlockEffectiveJacobian for BoundedEffectiveJacobian {
3165 fn effective_jacobian_rows(
3166 &self,
3167 state: &FamilyLinearizationState<'_>,
3168 rows: std::ops::Range<usize>,
3169 ) -> Result<Array2<f64>, String> {
3170 let p = self.design.ncols();
3171 let n = self.design.nrows();
3172 let rows = rows.start.min(n)..rows.end.min(n);
3173 if !state.beta.is_empty() {
3174 if state.beta.len() != p {
3175 return Err(format!(
3176 "BoundedEffectiveJacobian::effective_jacobian_at: beta length {} != design \
3177 ncols {p}",
3178 state.beta.len(),
3179 ));
3180 }
3181 if state.beta.iter().any(|v| v.is_nan()) {
3182 return Err(
3183 "BoundedEffectiveJacobian::effective_jacobian_at: beta contains NaN"
3184 .to_string(),
3185 );
3186 }
3187 }
3188 let mut jac = self
3189 .design
3190 .slice(ndarray::s![rows.start..rows.end, ..])
3191 .to_owned();
3192 for term in &self.bounded_terms {
3193 let theta = if state.beta.is_empty() {
3194 0.0
3195 } else {
3196 state.beta[term.col_idx]
3197 };
3198 let (_, _, db_dtheta, _, _) = bounded_latent_derivatives(theta, term.min, term.max);
3199 jac.column_mut(term.col_idx).mapv_inplace(|v| v * db_dtheta);
3200 }
3201 Ok(jac)
3202 }
3203}
3204
/// Data bundle for a likelihood family whose design contains bounded linear
/// coefficients.
///
/// NOTE(review): the methods that consume these fields are outside this
/// section; field notes below describe only what the types show.
#[derive(Clone)]
struct BoundedLinearFamily {
    /// Response/link specification of the likelihood.
    family: LikelihoodSpec,
    /// Optional latent-cloglog link state.
    latent_cloglog_state: Option<LatentCLogLogState>,
    /// Optional mixture link state.
    mixture_link_state: Option<MixtureLinkState>,
    /// Optional SAS link state.
    sas_link_state: Option<SasLinkState>,
    /// Response vector.
    y: Array1<f64>,
    /// Observation weights.
    weights: Array1<f64>,
    /// Dense design matrix.
    design: Array2<f64>,
    /// Second design matrix; presumably `design` with the bounded columns
    /// zeroed -- confirm against the constructor.
    designzeroed: Array2<f64>,
    /// Offset vector.
    offset: Array1<f64>,
    /// Metadata for each bounded coefficient.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
3218
/// Per-observation likelihood quantities produced by
/// `evaluate_standard_familyobservations`.
#[derive(Clone)]
struct StandardFamilyObservationState {
    /// Linear predictor the state was evaluated at (a copy of the input).
    eta: Array1<f64>,
    /// Mean response for each observation.
    mu: Array1<f64>,
    /// Derivative of the log-likelihood contribution with respect to `eta`.
    score: Array1<f64>,
    /// Fisher-type working weight, floored at a small positive constant.
    fisherweight: Array1<f64>,
    /// Negative second derivative of the log-likelihood with respect to `eta`.
    neghessian_eta: Array1<f64>,
    /// Derivative of `neghessian_eta` with respect to `eta`.
    neghessian_eta_derivative: Array1<f64>,
    /// Weighted log-likelihood kernel summed over all observations.
    log_likelihood: f64,
}
3229
/// Logit of `z` after clamping `z` to `[1e-12, 1 - 1e-12]`, so the result is
/// finite for any non-NaN input.
fn bounded_logit(z: f64) -> f64 {
    const EDGE: f64 = 1e-12;
    let clamped = z.clamp(EDGE, 1.0 - EDGE);
    let odds = clamped / (1.0 - clamped);
    odds.ln()
}
3234
/// Logistic sigmoid `1 / (1 + exp(-theta))`, evaluated so that the exponential
/// argument is never positive (no overflow for large `|theta|`).
fn stable_sigmoid(theta: f64) -> f64 {
    if theta < 0.0 {
        // Negative branch: exp(theta) <= 1.
        let growth = theta.exp();
        return growth / (1.0 + growth);
    }
    // Non-negative (and NaN) branch: exp(-theta) <= 1.
    let decay = (-theta).exp();
    1.0 / (1.0 + decay)
}
3244
3245fn bounded_latent_to_user(theta: f64, min: f64, max: f64) -> (f64, f64, f64) {
3246 let z = stable_sigmoid(theta);
3247 let width = max - min;
3248 let beta = min + width * z;
3249 let db_dtheta = width * z * (1.0 - z);
3250 (beta, z, db_dtheta)
3251}
3252
3253fn bounded_user_to_latent(beta: f64, min: f64, max: f64) -> f64 {
3264 let width = max - min;
3265 if width <= 0.0 || !width.is_finite() {
3266 return 0.0;
3267 }
3268 let z = (beta - min) / width;
3269 bounded_logit(z)
3270}
3271
/// Describes one bounded coefficient for
/// `sample_bounded_latent_posterior_internal`.
#[derive(Debug, Clone, Copy)]
pub struct BoundedSampleColumn {
    /// Index of the bounded coefficient in the coefficient vector; must be
    /// smaller than the number of coefficients.
    pub col_idx: usize,
    /// Lower bound of the coefficient.
    pub min: f64,
    /// Upper bound of the coefficient.
    pub max: f64,
}
3284
3285pub fn sample_bounded_latent_posterior_internal(
3323 beta_user: &Array1<f64>,
3324 user_hessian: &Array2<f64>,
3325 bounded_columns: &[BoundedSampleColumn],
3326 n_draws: usize,
3327 sqrt_cov_scale: f64,
3328 base_seed: u64,
3329) -> Result<Array2<f64>, EstimationError> {
3330 let p = beta_user.len();
3331 if user_hessian.nrows() != p || user_hessian.ncols() != p {
3332 crate::bail_invalid_estim!(
3333 "bounded posterior sampling dimension mismatch: mode has {p} entries, user Hessian is {}x{}",
3334 user_hessian.nrows(),
3335 user_hessian.ncols()
3336 );
3337 }
3338
3339 let mut theta_mode = beta_user.clone();
3341 let mut jac_diag = Array1::<f64>::ones(p);
3342 for bc in bounded_columns {
3343 if bc.col_idx >= p {
3344 crate::bail_invalid_estim!(
3345 "bounded posterior sampling: bounded column index {} out of range for {p} coefficients",
3346 bc.col_idx
3347 );
3348 }
3349 let theta_i = bounded_user_to_latent(beta_user[bc.col_idx], bc.min, bc.max);
3350 let (_, _, db_dtheta) = bounded_latent_to_user(theta_i, bc.min, bc.max);
3351 theta_mode[bc.col_idx] = theta_i;
3352 jac_diag[bc.col_idx] = db_dtheta.max(1e-12);
3357 }
3358
3359 let mut h_latent = user_hessian.clone();
3362 for i in 0..p {
3363 let ji = jac_diag[i];
3364 if ji != 1.0 {
3365 h_latent.row_mut(i).mapv_inplace(|v| v * ji);
3366 h_latent.column_mut(i).mapv_inplace(|v| v * ji);
3367 }
3368 }
3369
3370 use gam_linalg::faer_ndarray::FaerCholesky as _;
3373 use rand::SeedableRng as _;
3374 let chol = h_latent.cholesky(faer::Side::Lower).map_err(|err| {
3375 EstimationError::InvalidInput(format!(
3376 "bounded posterior sampling: Cholesky of the latent penalized Hessian failed: {err:?}"
3377 ))
3378 })?;
3379 let l = chol.lower_triangular();
3380
3381 let mut draws = Array2::<f64>::zeros((n_draws, p));
3382 let mut eps = Array1::<f64>::zeros(p);
3383 let mut delta = Array1::<f64>::zeros(p);
3384 let mut rng = rand::rngs::StdRng::seed_from_u64(base_seed);
3385 for k in 0..n_draws {
3386 for e in eps.iter_mut() {
3387 *e = standard_normal_draw(&mut rng);
3388 }
3389 solve_lower_transpose_into(&l, &eps, &mut delta);
3390 for i in 0..p {
3391 draws[(k, i)] = theta_mode[i] + sqrt_cov_scale * delta[i];
3394 }
3395 for bc in bounded_columns {
3398 let (beta_draw, _, _) = bounded_latent_to_user(draws[(k, bc.col_idx)], bc.min, bc.max);
3399 draws[(k, bc.col_idx)] = beta_draw;
3400 }
3401 }
3402
3403 Ok(draws)
3404}
3405
3406#[inline]
3409fn standard_normal_draw<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
3410 use rand::RngExt as _;
3411 let u1 = rng.random::<f64>().max(1e-16);
3412 let u2 = rng.random::<f64>();
3413 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
3414}
3415
3416fn solve_lower_transpose_into(l: &Array2<f64>, b: &Array1<f64>, out: &mut Array1<f64>) {
3420 let p = l.nrows();
3421 for i in (0..p).rev() {
3422 let mut acc = b[i];
3423 for j in (i + 1)..p {
3424 acc -= l[(j, i)] * out[j];
3425 }
3426 let diag = l[(i, i)];
3427 out[i] = if diag.abs() > 0.0 { acc / diag } else { 0.0 };
3428 }
3429}
3430
3431fn bounded_latent_derivatives(theta: f64, min: f64, max: f64) -> (f64, f64, f64, f64, f64) {
3432 let z = stable_sigmoid(theta);
3433 let width = max - min;
3434 let s = z * (1.0 - z);
3435 let beta = min + width * z;
3436 let db_dtheta = width * s;
3437 let d2b_dtheta2 = width * s * (1.0 - 2.0 * z);
3438 let d3b_dtheta3 = width * s * (1.0 - 6.0 * z + 6.0 * z * z);
3439 (beta, z, db_dtheta, d2b_dtheta2, d3b_dtheta3)
3440}
3441
3442fn bounded_prior_terms(theta: f64, prior: &BoundedCoefficientPriorSpec) -> (f64, f64, f64, f64) {
3443 let (a, b) = match prior {
3444 BoundedCoefficientPriorSpec::None => return (0.0, 0.0, 0.0, 0.0),
3446 BoundedCoefficientPriorSpec::Uniform => (1.0, 1.0),
3449 BoundedCoefficientPriorSpec::Beta { a, b } => (*a, *b),
3450 };
3451 let z = stable_sigmoid(theta).clamp(1e-12, 1.0 - 1e-12);
3452 let logp = a * z.ln() + b * (1.0 - z).ln();
3453 let grad = a - (a + b) * z;
3454 let neghess = (a + b) * z * (1.0 - z);
3455 let neghess_derivative = (a + b) * z * (1.0 - z) * (1.0 - 2.0 * z);
3456 (logp, grad, neghess, neghess_derivative)
3457}
3458
/// Converts mean-scale log-likelihood derivatives into `eta`-scale quantities
/// for one observation via the chain rule.
///
/// * `w` - prior weight.
/// * `lmu`, `lmumu`, `lmumumu` - first, second and third derivative of the
///   log-likelihood with respect to `mu`.
/// * `var` - variance function value used for the Fisher weight.
/// * `d1`, `d2`, `d3` - first, second and third derivative of `mu` in `eta`.
/// * `mu_deriv_eps` - floor applied to the Fisher weight.
///
/// Returns `(score, fisherweight, neghessian, neghessian_deriv)`.
#[inline]
fn glm_eta_observation_state(
    w: f64,
    lmu: f64,
    lmumu: f64,
    lmumumu: f64,
    var: f64,
    d1: f64,
    d2: f64,
    d3: f64,
    mu_deriv_eps: f64,
) -> (f64, f64, f64, f64) {
    // d^2 l / d eta^2 and d^3 l / d eta^3 before the sign and weight are applied.
    let second_in_eta = lmumu * d1 * d1 + lmu * d2;
    let third_in_eta = lmumumu * d1 * d1 * d1 + 3.0 * lmumu * d1 * d2 + lmu * d3;
    (
        w * lmu * d1,
        (w * d1 * d1 / var).max(mu_deriv_eps),
        -w * second_in_eta,
        -w * third_in_eta,
    )
}
3485
3486fn evaluate_standard_familyobservations(
3487 family: LikelihoodSpec,
3488 latent_cloglog_state: Option<&LatentCLogLogState>,
3489 mixture_link_state: Option<&MixtureLinkState>,
3490 sas_link_state: Option<&SasLinkState>,
3491 y: &Array1<f64>,
3492 weights: &Array1<f64>,
3493 eta: &Array1<f64>,
3494) -> Result<StandardFamilyObservationState, EstimationError> {
3495 const PROB_EPS: f64 = 1e-10;
3496 const MU_DERIV_EPS: f64 = 1e-12;
3497 let n = y.len();
3498 if weights.len() != n || eta.len() != n {
3499 crate::bail_invalid_estim!("bounded family observation size mismatch");
3500 }
3501
3502 let mut mu = Array1::<f64>::zeros(n);
3503 let mut score = Array1::<f64>::zeros(n);
3504 let mut fisherweight = Array1::<f64>::zeros(n);
3505 let mut neghessian_eta = Array1::<f64>::zeros(n);
3506 let mut neghessian_eta_derivative = Array1::<f64>::zeros(n);
3507 let mut log_likelihood = 0.0;
3508
3509 for i in 0..n {
3510 let w = weights[i].max(0.0);
3511 let yi = y[i];
3512 let eta_i = eta[i];
3513 match (&family.response, &family.link) {
3514 (ResponseFamily::Gaussian, _) => {
3515 let resid = yi - eta_i;
3516 mu[i] = eta_i;
3517 score[i] = w * resid;
3518 fisherweight[i] = w.max(MU_DERIV_EPS);
3519 neghessian_eta[i] = w;
3520 neghessian_eta_derivative[i] = 0.0;
3521 log_likelihood += -0.5 * w * resid * resid;
3522 }
3523 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
3524 let jet = logit_inverse_link_jet5(eta_i);
3525 mu[i] = jet.mu;
3526 score[i] = w * (yi - jet.mu);
3527 fisherweight[i] = jet.d1.max(MU_DERIV_EPS);
3528 neghessian_eta[i] = jet.d1;
3529 neghessian_eta_derivative[i] = jet.d2;
3530 let logmu = -gam_linalg::utils::stable_softplus(-eta_i);
3531 let log_one_minusmu = -gam_linalg::utils::stable_softplus(eta_i);
3532 log_likelihood += w * (yi * logmu + (1.0 - yi) * log_one_minusmu);
3533 }
3534 (ResponseFamily::Binomial, _) => {
3535 let inverse_link = if let Some(state) = latent_cloglog_state {
3536 Some(InverseLink::LatentCLogLog(*state))
3537 } else if let Some(state) = mixture_link_state {
3538 Some(InverseLink::Mixture(state.clone()))
3539 } else {
3540 sas_link_state.map(|state| {
3541 if family.is_binomial_beta_logistic() {
3542 InverseLink::BetaLogistic(*state)
3543 } else {
3544 InverseLink::Sas(*state)
3545 }
3546 })
3547 };
3548 let strategy_spec = LikelihoodSpec {
3549 response: family.response.clone(),
3550 link: inverse_link.clone().unwrap_or_else(|| family.link.clone()),
3551 };
3552 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3553 let mu_i_raw = jet.mu;
3554 let dmu_deta_raw = jet.d1;
3555 let mu_i: f64 = mu_i_raw.clamp(PROB_EPS, 1.0 - PROB_EPS);
3556 let dmu_deta = dmu_deta_raw.max(MU_DERIV_EPS);
3557 let d2mu_deta2 = jet.d2;
3558 let d3mu_deta3 = jet.d3;
3559 let var = (mu_i * (1.0 - mu_i)).max(PROB_EPS);
3560 let lmu = (yi - mu_i) / var;
3561 let lmumu = -(yi / (mu_i * mu_i)) - ((1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i)));
3562 let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i)
3563 - 2.0 * (1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i) * (1.0 - mu_i));
3564 mu[i] = mu_i;
3565 score[i] = w * lmu * dmu_deta;
3566 fisherweight[i] = (w * dmu_deta * dmu_deta / var).max(MU_DERIV_EPS);
3567 neghessian_eta[i] = -w * (lmumu * dmu_deta * dmu_deta + lmu * d2mu_deta2);
3568 neghessian_eta_derivative[i] = -w
3569 * (lmumumu * dmu_deta * dmu_deta * dmu_deta
3570 + 3.0 * lmumu * dmu_deta * d2mu_deta2
3571 + lmu * d3mu_deta3);
3572 log_likelihood += w * (yi * mu_i.ln() + (1.0 - yi) * (1.0 - mu_i).ln());
3573 }
3574 (ResponseFamily::Poisson, _) => {
3575 let strategy_spec = LikelihoodSpec {
3578 response: family.response.clone(),
3579 link: family.link.clone(),
3580 };
3581 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3582 let mu_i = jet.mu.max(PROB_EPS);
3583 let d1 = jet.d1.max(MU_DERIV_EPS);
3584 let var = mu_i;
3585 let lmu = yi / mu_i - 1.0;
3586 let lmumu = -yi / (mu_i * mu_i);
3587 let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i);
3588 let (s, f, nh, nhd) = glm_eta_observation_state(
3589 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3590 );
3591 mu[i] = mu_i;
3592 score[i] = s;
3593 fisherweight[i] = f;
3594 neghessian_eta[i] = nh;
3595 neghessian_eta_derivative[i] = nhd;
3596 log_likelihood += w * (yi * mu_i.ln() - mu_i);
3597 }
3598 (ResponseFamily::Tweedie { p }, _) => {
3599 let p = *p;
3604 let strategy_spec = LikelihoodSpec {
3605 response: family.response.clone(),
3606 link: family.link.clone(),
3607 };
3608 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3609 let mu_i = jet.mu.max(PROB_EPS);
3610 let d1 = jet.d1.max(MU_DERIV_EPS);
3611 let var = mu_i.powf(p);
3612 let resid = yi - mu_i;
3613 let lmu = resid / var;
3614 let lmumu = -mu_i.powf(-p) - p * resid * mu_i.powf(-p - 1.0);
3615 let lmumumu =
3616 2.0 * p * mu_i.powf(-p - 1.0) + p * (p + 1.0) * resid * mu_i.powf(-p - 2.0);
3617 let (s, f, nh, nhd) = glm_eta_observation_state(
3618 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3619 );
3620 mu[i] = mu_i;
3621 score[i] = s;
3622 fisherweight[i] = f;
3623 neghessian_eta[i] = nh;
3624 neghessian_eta_derivative[i] = nhd;
3625 log_likelihood += w
3627 * (yi * mu_i.powf(1.0 - p) / (1.0 - p) - mu_i.powf(2.0 - p) / (2.0 - p));
3628 }
3629 (ResponseFamily::NegativeBinomial { theta, .. }, _) => {
3630 let theta = (*theta).max(PROB_EPS);
3634 let strategy_spec = LikelihoodSpec {
3635 response: family.response.clone(),
3636 link: family.link.clone(),
3637 };
3638 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3639 let mu_i = jet.mu.max(PROB_EPS);
3640 let d1 = jet.d1.max(MU_DERIV_EPS);
3641 let mu_plus = mu_i + theta;
3642 let var = mu_i + mu_i * mu_i / theta;
3643 let lmu = yi / mu_i - (yi + theta) / mu_plus;
3644 let lmumu = -yi / (mu_i * mu_i) + (yi + theta) / (mu_plus * mu_plus);
3645 let lmumumu =
3646 2.0 * yi / (mu_i * mu_i * mu_i) - 2.0 * (yi + theta) / (mu_plus * mu_plus * mu_plus);
3647 let (s, f, nh, nhd) = glm_eta_observation_state(
3648 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3649 );
3650 mu[i] = mu_i;
3651 score[i] = s;
3652 fisherweight[i] = f;
3653 neghessian_eta[i] = nh;
3654 neghessian_eta_derivative[i] = nhd;
3655 log_likelihood += w * (yi * mu_i.ln() - (yi + theta) * mu_plus.ln());
3656 }
3657 (ResponseFamily::Beta { .. }, _) => {
3658 crate::bail_invalid_estim!(
3659 "bounded linear terms are not supported for BetaLogit fits"
3660 );
3661 }
3662 (ResponseFamily::Gamma, _) => {
3663 let strategy_spec = LikelihoodSpec {
3667 response: family.response.clone(),
3668 link: family.link.clone(),
3669 };
3670 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3671 let mu_i = jet.mu.max(PROB_EPS);
3672 let d1 = jet.d1.max(MU_DERIV_EPS);
3673 let var = mu_i * mu_i;
3674 let lmu = yi / (mu_i * mu_i) - 1.0 / mu_i;
3675 let lmumu = -2.0 * yi / (mu_i * mu_i * mu_i) + 1.0 / (mu_i * mu_i);
3676 let lmumumu =
3677 6.0 * yi / (mu_i * mu_i * mu_i * mu_i) - 2.0 / (mu_i * mu_i * mu_i);
3678 let (s, f, nh, nhd) = glm_eta_observation_state(
3679 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3680 );
3681 mu[i] = mu_i;
3682 score[i] = s;
3683 fisherweight[i] = f;
3684 neghessian_eta[i] = nh;
3685 neghessian_eta_derivative[i] = nhd;
3686 log_likelihood += w * (-(yi / mu_i) - mu_i.ln());
3687 }
3688 (ResponseFamily::RoystonParmar, _) => {
3689 crate::bail_invalid_estim!(
3690 "bounded linear terms are not supported for survival model fits"
3691 );
3692 }
3693 }
3694 }
3695
3696 Ok(StandardFamilyObservationState {
3697 eta: eta.clone(),
3698 mu,
3699 score,
3700 fisherweight,
3701 neghessian_eta,
3702 neghessian_eta_derivative,
3703 log_likelihood,
3704 })
3705}
3706
/// Kind of a spatial-adaptive hyperparameter: a log-lambda or a log-epsilon,
/// each for the magnitude, gradient or curvature component.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveHyperKind {
    LogLambdaMagnitude,
    LogLambdaGradient,
    LogLambdaCurvature,
    LogEpsilonMagnitude,
    LogEpsilonGradient,
    LogEpsilonCurvature,
}

impl SpatialAdaptiveHyperKind {
    /// Component slot of this kind: `0` magnitude, `1` gradient, `2` curvature.
    fn component_index(self) -> usize {
        match self {
            Self::LogLambdaMagnitude | Self::LogEpsilonMagnitude => 0,
            Self::LogLambdaGradient | Self::LogEpsilonGradient => 1,
            Self::LogLambdaCurvature | Self::LogEpsilonCurvature => 2,
        }
    }

    /// `true` for the three log-lambda kinds.
    fn is_log_lambda(self) -> bool {
        match self {
            Self::LogLambdaMagnitude | Self::LogLambdaGradient | Self::LogLambdaCurvature => true,
            Self::LogEpsilonMagnitude
            | Self::LogEpsilonGradient
            | Self::LogEpsilonCurvature => false,
        }
    }

    /// `true` for the three log-epsilon kinds.
    fn is_log_epsilon(self) -> bool {
        match self {
            Self::LogEpsilonMagnitude
            | Self::LogEpsilonGradient
            | Self::LogEpsilonCurvature => true,
            Self::LogLambdaMagnitude | Self::LogLambdaGradient | Self::LogLambdaCurvature => false,
        }
    }
}
3747
/// Identifies one spatial-adaptive hyperparameter by the cache it refers to
/// and by its kind.
#[derive(Clone, Copy, Debug)]
struct SpatialAdaptiveHyperSpec {
    /// Cache the hyperparameter refers to; compared between two specs when
    /// classifying lambda-lambda second-order terms. Presumably an index into
    /// the runtime cache list -- confirm against the code that builds specs.
    cache_index: usize,
    /// Which log-lambda / log-epsilon component this hyperparameter is.
    kind: SpatialAdaptiveHyperKind,
}
3753
/// Classification of a pair of spatial-adaptive hyperparameters, as returned by
/// `SpatialAdaptiveHyperSpec::explicit_second_order_kind`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveExplicitSecondOrderKind {
    /// Different components, or two log-lambdas from different caches.
    StructuralZero,
    /// Two log-lambdas of the same component and the same cache.
    LocalAlphaAlpha,
    /// One log-lambda and one log-epsilon of the same component.
    LocalAlphaEta,
    /// Two log-epsilons of the same component.
    SharedEtaEta,
}
3761
/// The three adaptive penalty components; `AdaptiveComponent::from_index`
/// maps indices `0`, `1` and `2` to them in declaration order.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum AdaptiveComponent {
    /// Component index `0`.
    Magnitude,
    /// Component index `1`.
    Gradient,
    /// Component index `2`.
    Curvature,
}
3772
3773impl AdaptiveComponent {
3774 fn from_index(index: usize) -> Result<Self, String> {
3775 match index {
3776 0 => Ok(AdaptiveComponent::Magnitude),
3777 1 => Ok(AdaptiveComponent::Gradient),
3778 2 => Ok(AdaptiveComponent::Curvature),
3779 other => Err(SmoothError::invalid_index(format!(
3780 "invalid adaptive component index {}",
3781 other
3782 ))
3783 .into()),
3784 }
3785 }
3786}
3787
/// Tag selecting which hyperparameter derivative is requested.
///
/// NOTE(review): the consumers of this enum are outside this section; the
/// variant notes below are read off the names only and should be confirmed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HyperDerivativeKind {
    /// Presumably the derivative with respect to a log smoothing parameter.
    Rho,
    /// Presumably the first derivative with respect to log-epsilon.
    LogEpsilonFirst,
    /// Presumably the second derivative with respect to log-epsilon.
    LogEpsilonSecond,
}
3801
/// Tag distinguishing rho-type from log-epsilon-type hyperparameters.
///
/// NOTE(review): the consumers of this enum are outside this section; confirm
/// the exact meaning of "drift" against the code that matches on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HyperDriftKind {
    /// Rho-type hyperparameter.
    Rho,
    /// Log-epsilon-type hyperparameter.
    LogEpsilon,
}
3811
3812impl SpatialAdaptiveHyperSpec {
3813 fn component_index(self) -> usize {
3814 self.kind.component_index()
3815 }
3816
3817 fn explicit_second_order_kind(self, other: Self) -> SpatialAdaptiveExplicitSecondOrderKind {
3818 if self.component_index() != other.component_index() {
3819 return SpatialAdaptiveExplicitSecondOrderKind::StructuralZero;
3820 }
3821 match (
3822 self.kind.is_log_lambda(),
3823 other.kind.is_log_lambda(),
3824 self.kind.is_log_epsilon(),
3825 other.kind.is_log_epsilon(),
3826 ) {
3827 (true, true, false, false) if self.cache_index == other.cache_index => {
3828 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha
3829 }
3830 (true, false, false, true) | (false, true, true, false) => {
3831 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta
3832 }
3833 (false, false, true, true) => SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta,
3834 _ => SpatialAdaptiveExplicitSecondOrderKind::StructuralZero,
3835 }
3836 }
3837}
3838
/// Smoothing and epsilon hyperparameters of one spatially adaptive term.
#[derive(Clone, Debug)]
struct SpatialAdaptiveTermHyperParams {
    /// Smoothing parameters on the natural (not log) scale, ordered as
    /// mass/magnitude, tension/gradient, stiffness/curvature.
    lambda: [f64; 3],
    /// Epsilon parameters in the same component order as `lambda`.
    epsilon: [f64; 3],
}
3844
/// Result of one exact evaluation: the likelihood observation state plus the
/// value, gradient and Hessian of the adaptive penalty and of the fixed
/// quadratic penalty. The `total_*` methods add the two penalty parts.
#[derive(Clone)]
struct SpatialAdaptiveExactEvaluation {
    /// Per-observation likelihood quantities.
    obs: StandardFamilyObservationState,
    /// Exact penalty state for each adaptive term.
    adaptive_states: Vec<SpatialPenaltyExactState>,
    /// Value of the adaptive penalty.
    adaptive_penalty_value: f64,
    /// Gradient of the adaptive penalty.
    adaptive_penaltygradient: Array1<f64>,
    /// Hessian of the adaptive penalty.
    adaptive_penaltyhessian: Array2<f64>,
    /// Value of the fixed quadratic penalty.
    fixed_quadraticvalue: f64,
    /// Gradient of the fixed quadratic penalty.
    fixed_quadraticgradient: Array1<f64>,
    /// Hessian of the fixed quadratic penalty.
    fixed_quadratichessian: Array2<f64>,
}
3856
/// Single-entry cache record pairing a coefficient vector with its evaluation.
#[derive(Clone)]
struct CachedSpatialAdaptiveExactEvaluation {
    /// Coefficients the cached evaluation was computed at; compared element-wise on lookup.
    beta: Array1<f64>,
    /// Shared handle to the evaluation so cache hits avoid copying it.
    eval: Arc<SpatialAdaptiveExactEvaluation>,
}
3862
3863impl SpatialAdaptiveExactEvaluation {
3864 fn total_penalty_value(&self) -> f64 {
3865 self.adaptive_penalty_value + self.fixed_quadraticvalue
3866 }
3867
3868 fn total_penaltygradient(&self) -> Array1<f64> {
3869 &self.adaptive_penaltygradient + &self.fixed_quadraticgradient
3870 }
3871
3872 fn total_penaltyhessian(&self) -> Array2<f64> {
3873 &self.adaptive_penaltyhessian + &self.fixed_quadratichessian
3874 }
3875
3876 fn totalobjectivehessian(&self, design: &Array2<f64>) -> Result<Array2<f64>, String> {
3877 let mut out = xt_diag_x_dense(design.view(), self.obs.neghessian_eta.view())?;
3878 out += &self.total_penaltyhessian();
3879 Ok(out)
3880 }
3881}
3882
/// Custom family that combines a standard likelihood with adaptive spatial
/// penalties and a fixed quadratic penalty, evaluated exactly at a given beta.
#[derive(Clone)]
struct SpatialAdaptiveExactFamily {
    /// Likelihood specification forwarded to `evaluate_standard_familyobservations`.
    family: LikelihoodSpec,
    /// Optional link state forwarded to the observation evaluator.
    latent_cloglog_state: Option<LatentCLogLogState>,
    /// Optional link state forwarded to the observation evaluator.
    mixture_link_state: Option<MixtureLinkState>,
    /// Optional link state forwarded to the observation evaluator.
    sas_link_state: Option<SasLinkState>,
    /// Response vector forwarded to the observation evaluator.
    y: Arc<Array1<f64>>,
    /// Observation weights forwarded to the observation evaluator.
    weights: Arc<Array1<f64>>,
    /// Dense design matrix; its column count sizes every score/Hessian output.
    design: Arc<Array2<f64>>,
    /// Added to the design-times-beta product to form the total linear predictor.
    offset: Arc<Array1<f64>>,
    /// Returned unchanged (cloned) from `block_linear_constraints`.
    linear_constraints: Option<LinearInequalityConstraints>,
    /// One runtime cache per adaptive term (operators and coefficient range).
    runtime_caches: Arc<Vec<SpatialOperatorRuntimeCache>>,
    /// Hyperparameters per adaptive term, indexed in step with `runtime_caches`.
    adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
    /// Hessian of the fixed quadratic penalty term.
    fixed_quadratichessian: Arc<Array2<f64>>,
    /// Hyperparameter descriptors, indexed by psi index.
    hyperspecs: Arc<Vec<SpatialAdaptiveHyperSpec>>,
    /// Single-entry memo of the most recent exact evaluation, keyed by beta.
    exact_eval_cache: Arc<Mutex<Option<CachedSpatialAdaptiveExactEvaluation>>>,
}
3900
3901impl SpatialAdaptiveExactFamily {
3902 fn with_adaptive_params(
3903 &self,
3904 adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
3905 fixed_quadratichessian: Arc<Array2<f64>>,
3906 ) -> Self {
3907 Self {
3908 family: self.family.clone(),
3909 latent_cloglog_state: self.latent_cloglog_state,
3910 mixture_link_state: self.mixture_link_state.clone(),
3911 sas_link_state: self.sas_link_state,
3912 y: self.y.clone(),
3913 weights: self.weights.clone(),
3914 design: self.design.clone(),
3915 offset: self.offset.clone(),
3916 linear_constraints: self.linear_constraints.clone(),
3917 runtime_caches: self.runtime_caches.clone(),
3918 adaptive_params,
3919 fixed_quadratichessian,
3920 hyperspecs: self.hyperspecs.clone(),
3921 exact_eval_cache: Arc::new(Mutex::new(None)),
3922 }
3923 }
3924
3925 fn total_eta(&self, beta: &Array1<f64>) -> Array1<f64> {
3926 gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), beta) + self.offset.as_ref()
3927 }
3928
3929 fn fixed_quadratic_terms(&self, beta: &Array1<f64>) -> (f64, Array1<f64>) {
3930 let grad = self.fixed_quadratichessian.dot(beta);
3931 let value = 0.5 * beta.dot(&grad);
3932 (value, grad)
3933 }
3934
3935 fn adaptive_penalty_value_only(&self, beta: &Array1<f64>) -> Result<f64, String> {
3936 let mut penalty_value = 0.0;
3937 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
3938 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
3939 format!(
3940 "missing adaptive parameter block for cache {}",
3941 cache.termname
3942 )
3943 })?;
3944 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
3945 let state =
3946 SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
3947 .map_err(|e| e.to_string())?;
3948 penalty_value += params.lambda[0] * state.magnitude.penalty_value();
3949 penalty_value += params.lambda[1] * state.gradient.penalty_value();
3950 penalty_value += params.lambda[2] * state.curvature.penalty_value();
3951 }
3952 Ok(penalty_value)
3953 }
3954
3955 fn zero_hyper_parts(&self) -> (Array1<f64>, Array2<f64>) {
3956 let total_dim = self.design.ncols();
3957 (
3958 Array1::<f64>::zeros(total_dim),
3959 Array2::<f64>::zeros((total_dim, total_dim)),
3960 )
3961 }
3962
3963 fn embed_local_hyper_parts(
3964 &self,
3965 coeff_range: &Range<usize>,
3966 local_grad: &Array1<f64>,
3967 local_hess: &Array2<f64>,
3968 ) -> (Array1<f64>, Array2<f64>) {
3969 let (mut beta_mixed, mut betahessian) = self.zero_hyper_parts();
3970 beta_mixed
3971 .slice_mut(s![coeff_range.clone()])
3972 .assign(local_grad);
3973 betahessian
3974 .slice_mut(s![coeff_range.clone(), coeff_range.clone()])
3975 .assign(local_hess);
3976 (beta_mixed, betahessian)
3977 }
3978
3979 fn embed_local_hyper_hessian(
3980 &self,
3981 coeff_range: &Range<usize>,
3982 local_hess: &Array2<f64>,
3983 ) -> Array2<f64> {
3984 let total_dim = self.design.ncols();
3985 let mut out = Array2::<f64>::zeros((total_dim, total_dim));
3986 out.slice_mut(s![coeff_range.clone(), coeff_range.clone()])
3987 .assign(local_hess);
3988 out
3989 }
3990
    /// Evaluates one adaptive component of one cache for a chosen derivative kind.
    ///
    /// Returns `(objective, score, hessian)`. The objective is the
    /// lambda-scaled scalar term. The score and Hessian are the lambda-scaled
    /// operator pull-backs, embedded into full-size arrays at the cache's
    /// coefficient range.
    ///
    /// `component` picks the operator (`d0`, `d1` or `d2`) and the matching
    /// lambda slot. `derivative` picks which family of state accessors is read.
    ///
    /// # Errors
    /// Fails when `cache_idx` is out of bounds for the runtime caches, the
    /// hyperparameter blocks or the evaluation's states, or when a grouped
    /// operator helper reports an error.
    fn adaptive_block_eval(
        &self,
        eval: &SpatialAdaptiveExactEvaluation,
        cache_idx: usize,
        component: AdaptiveComponent,
        derivative: HyperDerivativeKind,
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        // The three per-cache collections are indexed in step; each lookup is checked.
        let cache = self
            .runtime_caches
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
        let params = self
            .adaptive_params
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
        let state = eval
            .adaptive_states
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;

        let (objective_local, beta_mixed_local, betahessian_local) = match component {
            AdaptiveComponent::Magnitude => {
                // Magnitude uses the scalar operator helpers on `d0`.
                let lambda = params.lambda[0];
                let mag = &state.magnitude;
                let (objective, gradient_coeff, hessian_diag) = match derivative {
                    HyperDerivativeKind::Rho => (
                        mag.penalty_value(),
                        mag.betagradient_coeff(),
                        mag.betahessian_diag(),
                    ),
                    HyperDerivativeKind::LogEpsilonFirst => (
                        mag.log_epsilon_gradient_terms().sum(),
                        mag.log_epsilon_betagradient_coeff(),
                        mag.log_epsilon_betahessian_diag(),
                    ),
                    HyperDerivativeKind::LogEpsilonSecond => (
                        mag.log_epsilon_hessian_terms().sum(),
                        mag.log_epsilon_beta_mixed_second_coeff(),
                        mag.log_epsilon_betahessian_second_diag(),
                    ),
                };
                (
                    lambda * objective,
                    lambda * scalar_operatorgradient(&cache.d0, &gradient_coeff),
                    lambda * scalar_operatorhessian(&cache.d0, &hessian_diag),
                )
            }
            AdaptiveComponent::Gradient => {
                // Gradient uses the grouped helpers on `d1` with group size `dimension`.
                let lambda = params.lambda[1];
                let grad = &state.gradient;
                let (objective, gradient_blocks, hessian_blocks) = match derivative {
                    HyperDerivativeKind::Rho => (
                        grad.penalty_value(),
                        grad.betagradient_blocks(),
                        grad.betahessian_blocks(),
                    ),
                    HyperDerivativeKind::LogEpsilonFirst => (
                        grad.log_epsilon_gradient_terms().sum(),
                        grad.log_epsilon_betagradient_blocks(),
                        grad.log_epsilon_betahessian_blocks(),
                    ),
                    HyperDerivativeKind::LogEpsilonSecond => (
                        grad.log_epsilon_hessian_terms().sum(),
                        grad.log_epsilon_beta_mixed_second_blocks(),
                        grad.log_epsilon_betahessian_second_blocks(),
                    ),
                };
                (
                    lambda * objective,
                    lambda
                        * grouped_operatorgradient(&cache.d1, cache.dimension, &gradient_blocks)
                            .map_err(|e| e.to_string())?,
                    lambda
                        * grouped_operatorhessian(&cache.d1, cache.dimension, &hessian_blocks)
                            .map_err(|e| e.to_string())?,
                )
            }
            AdaptiveComponent::Curvature => {
                // Curvature uses the grouped helpers on `d2` with group size `dimension^2`.
                let lambda = params.lambda[2];
                let group = cache.dimension * cache.dimension;
                let curv = &state.curvature;
                let (objective, gradient_blocks, hessian_blocks) = match derivative {
                    HyperDerivativeKind::Rho => (
                        curv.penalty_value(),
                        curv.betagradient_blocks(),
                        curv.betahessian_blocks(),
                    ),
                    HyperDerivativeKind::LogEpsilonFirst => (
                        curv.log_epsilon_gradient_terms().sum(),
                        curv.log_epsilon_betagradient_blocks(),
                        curv.log_epsilon_betahessian_blocks(),
                    ),
                    HyperDerivativeKind::LogEpsilonSecond => (
                        curv.log_epsilon_hessian_terms().sum(),
                        curv.log_epsilon_beta_mixed_second_blocks(),
                        curv.log_epsilon_betahessian_second_blocks(),
                    ),
                };
                (
                    lambda * objective,
                    lambda
                        * grouped_operatorgradient(&cache.d2, group, &gradient_blocks)
                            .map_err(|e| e.to_string())?,
                    lambda
                        * grouped_operatorhessian(&cache.d2, group, &hessian_blocks)
                            .map_err(|e| e.to_string())?,
                )
            }
        };

        // Lift the term-local pieces into full coefficient space.
        let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
            &cache.coeff_global_range,
            &beta_mixed_local,
            &betahessian_local,
        );
        Ok((objective_local, beta_mixed, betahessian))
    }
4116
    /// First-order log-epsilon parts for `component`, summed over every cache.
    ///
    /// Delegates to `adaptive_shared_block_eval` with
    /// `HyperDerivativeKind::LogEpsilonFirst`.
    fn adaptive_shared_log_epsilon_parts(
        &self,
        eval: &SpatialAdaptiveExactEvaluation,
        component: usize,
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        self.adaptive_shared_block_eval(eval, component, HyperDerivativeKind::LogEpsilonFirst)
    }
4129
    /// Second-order log-epsilon parts for `component`, summed over every cache.
    ///
    /// Delegates to `adaptive_shared_block_eval` with
    /// `HyperDerivativeKind::LogEpsilonSecond`.
    fn adaptive_shared_log_epsilon_second_parts(
        &self,
        eval: &SpatialAdaptiveExactEvaluation,
        component: usize,
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        self.adaptive_shared_block_eval(eval, component, HyperDerivativeKind::LogEpsilonSecond)
    }
4142
4143 fn adaptive_shared_block_eval(
4148 &self,
4149 eval: &SpatialAdaptiveExactEvaluation,
4150 component: usize,
4151 derivative: HyperDerivativeKind,
4152 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4153 let component = AdaptiveComponent::from_index(component)?;
4154 let (mut score, mut hessian) = self.zero_hyper_parts();
4155 let mut objective = 0.0;
4156 for cache_idx in 0..self.runtime_caches.len() {
4157 let (local_objective, local_score, local_hessian) =
4158 self.adaptive_block_eval(eval, cache_idx, component, derivative)?;
4159 objective += local_objective;
4160 score += &local_score;
4161 hessian += &local_hessian;
4162 }
4163 Ok((objective, score, hessian))
4164 }
4165
4166 fn adaptive_shared_log_epsilon_drift(
4167 &self,
4168 eval: &SpatialAdaptiveExactEvaluation,
4169 component: usize,
4170 direction: &Array1<f64>,
4171 ) -> Result<Array2<f64>, String> {
4172 let component = AdaptiveComponent::from_index(component)?;
4176 let total_dim = self.design.ncols();
4177 let mut total = Array2::<f64>::zeros((total_dim, total_dim));
4178 for cache_idx in 0..self.runtime_caches.len() {
4179 total += &self.adaptive_block_drift_eval(
4180 eval,
4181 cache_idx,
4182 component,
4183 HyperDriftKind::LogEpsilon,
4184 direction,
4185 )?;
4186 }
4187 Ok(total)
4188 }
4189
4190 fn adaptive_explicit_second_order_parts(
4191 &self,
4192 eval: &SpatialAdaptiveExactEvaluation,
4193 left: SpatialAdaptiveHyperSpec,
4194 right: SpatialAdaptiveHyperSpec,
4195 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4196 match left.explicit_second_order_kind(right) {
4205 SpatialAdaptiveExplicitSecondOrderKind::StructuralZero => {
4206 let (score, hessian) = self.zero_hyper_parts();
4207 Ok((0.0, score, hessian))
4208 }
4209 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha => self.adaptive_block_eval(
4210 eval,
4211 left.cache_index,
4212 AdaptiveComponent::from_index(left.component_index())?,
4213 HyperDerivativeKind::Rho,
4214 ),
4215 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta => {
4216 let local_alpha = if left.kind.is_log_lambda() {
4217 left
4218 } else {
4219 right
4220 };
4221 self.adaptive_block_eval(
4222 eval,
4223 local_alpha.cache_index,
4224 AdaptiveComponent::from_index(local_alpha.component_index())?,
4225 HyperDerivativeKind::LogEpsilonFirst,
4226 )
4227 }
4228 SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta => {
4229 self.adaptive_shared_log_epsilon_second_parts(eval, left.component_index())
4230 }
4231 }
4232 }
4233
    /// Directional Hessian drift of one adaptive component of one cache.
    ///
    /// The cache's slice of `direction` is pushed through the component's
    /// operator (`d0`, `d1` or `d2`). The state then supplies the directional
    /// Hessian weights chosen by `drift`. The lambda-scaled result is embedded
    /// into a full-size matrix at the cache's coefficient range.
    ///
    /// # Errors
    /// Fails when `cache_idx` is out of bounds for the runtime caches, the
    /// hyperparameter blocks or the evaluation's states, or when a collocation
    /// or grouped operator helper reports an error.
    fn adaptive_block_drift_eval(
        &self,
        eval: &SpatialAdaptiveExactEvaluation,
        cache_idx: usize,
        component: AdaptiveComponent,
        drift: HyperDriftKind,
        direction: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let cache = self
            .runtime_caches
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
        let params = self
            .adaptive_params
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
        let state = eval
            .adaptive_states
            .get(cache_idx)
            .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
        // Only this term's coefficients of the direction are used.
        let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);

        let local_hessian = match component {
            AdaptiveComponent::Magnitude => {
                let d0_u = cache.d0.dot(&direction_local);
                let mag = &state.magnitude;
                let diag = match drift {
                    HyperDriftKind::Rho => mag.directionalhessian_diag(&d0_u),
                    HyperDriftKind::LogEpsilon => {
                        mag.log_epsilon_betahessian_directional_diag(&d0_u)
                    }
                };
                params.lambda[0] * scalar_operatorhessian(&cache.d0, &diag)
            }
            AdaptiveComponent::Gradient => {
                let d1_u = cache.d1.dot(&direction_local);
                let direction_blocks = collocationgradient_blocks(&d1_u, cache.dimension)
                    .map_err(|e| e.to_string())?;
                let grad = &state.gradient;
                let blocks = match drift {
                    HyperDriftKind::Rho => grad.directionalhessian_blocks(&direction_blocks),
                    HyperDriftKind::LogEpsilon => {
                        grad.log_epsilon_betahessian_directional_blocks(&direction_blocks)
                    }
                };
                params.lambda[1]
                    * grouped_operatorhessian(&cache.d1, cache.dimension, &blocks)
                        .map_err(|e| e.to_string())?
            }
            AdaptiveComponent::Curvature => {
                // The curvature operator groups `dimension^2` rows per collocation point.
                let group = cache.dimension * cache.dimension;
                let d2_u = cache.d2.dot(&direction_local);
                let direction_blocks =
                    collocationhessian_blocks(&d2_u, cache.dimension).map_err(|e| e.to_string())?;
                let curv = &state.curvature;
                let blocks = match drift {
                    HyperDriftKind::Rho => curv.directionalhessian_blocks(&direction_blocks),
                    HyperDriftKind::LogEpsilon => {
                        curv.log_epsilon_betahessian_directional_blocks(&direction_blocks)
                    }
                };
                params.lambda[2]
                    * grouped_operatorhessian(&cache.d2, group, &blocks)
                        .map_err(|e| e.to_string())?
            }
        };

        Ok(self.embed_local_hyper_hessian(&cache.coeff_global_range, &local_hessian))
    }
4310
4311 fn adaptive_hyper_parts(
4312 &self,
4313 eval: &SpatialAdaptiveExactEvaluation,
4314 hyper: SpatialAdaptiveHyperSpec,
4315 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4316 match hyper.kind {
4317 SpatialAdaptiveHyperKind::LogLambdaMagnitude
4320 | SpatialAdaptiveHyperKind::LogLambdaGradient
4321 | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_eval(
4322 eval,
4323 hyper.cache_index,
4324 AdaptiveComponent::from_index(hyper.component_index())?,
4325 HyperDerivativeKind::Rho,
4326 ),
4327 SpatialAdaptiveHyperKind::LogEpsilonMagnitude
4329 | SpatialAdaptiveHyperKind::LogEpsilonGradient
4330 | SpatialAdaptiveHyperKind::LogEpsilonCurvature => {
4331 self.adaptive_shared_log_epsilon_parts(eval, hyper.component_index())
4332 }
4333 }
4334 }
4335
4336 fn exact_evaluation_uncached(
4337 &self,
4338 beta: &Array1<f64>,
4339 ) -> Result<SpatialAdaptiveExactEvaluation, String> {
4340 let eta = self.total_eta(beta);
4341 let obs = evaluate_standard_familyobservations(
4342 self.family.clone(),
4343 self.latent_cloglog_state.as_ref(),
4344 self.mixture_link_state.as_ref(),
4345 self.sas_link_state.as_ref(),
4346 &self.y,
4347 &self.weights,
4348 &eta,
4349 )
4350 .map_err(|e| e.to_string())?;
4351 let p = beta.len();
4352 let mut penalty_value = 0.0;
4353 let mut penaltygradient = Array1::<f64>::zeros(p);
4354 let mut penaltyhessian = Array2::<f64>::zeros((p, p));
4355 let mut adaptive_states = Vec::with_capacity(self.runtime_caches.len());
4356
4357 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4358 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4359 format!(
4360 "missing adaptive parameter block for cache {}",
4361 cache.termname
4362 )
4363 })?;
4364 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
4365 let state =
4366 SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
4367 .map_err(|e| e.to_string())?;
4368
4369 let g0 = scalar_operatorgradient(&cache.d0, &state.magnitude.betagradient_coeff());
4370 let gg = grouped_operatorgradient(
4371 &cache.d1,
4372 cache.dimension,
4373 &state.gradient.betagradient_blocks(),
4374 )
4375 .map_err(|e| e.to_string())?;
4376 let gc = grouped_operatorgradient(
4377 &cache.d2,
4378 cache.dimension * cache.dimension,
4379 &state.curvature.betagradient_blocks(),
4380 )
4381 .map_err(|e| e.to_string())?;
4382 let h0 = scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag());
4383 let hg = grouped_operatorhessian(
4384 &cache.d1,
4385 cache.dimension,
4386 &state.gradient.betahessian_blocks(),
4387 )
4388 .map_err(|e| e.to_string())?;
4389 let hc = grouped_operatorhessian(
4390 &cache.d2,
4391 cache.dimension * cache.dimension,
4392 &state.curvature.betahessian_blocks(),
4393 )
4394 .map_err(|e| e.to_string())?;
4395
4396 let lambda0 = params.lambda[0];
4397 let lambdag = params.lambda[1];
4398 let lambdac = params.lambda[2];
4399
4400 penalty_value += lambda0 * state.magnitude.penalty_value();
4401 penalty_value += lambdag * state.gradient.penalty_value();
4402 penalty_value += lambdac * state.curvature.penalty_value();
4403
4404 let range = cache.coeff_global_range.clone();
4405 {
4406 let mut grad_local = penaltygradient.slice_mut(s![range.clone()]);
4407 grad_local += &(g0.mapv(|v| lambda0 * v));
4408 grad_local += &(gg.mapv(|v| lambdag * v));
4409 grad_local += &(gc.mapv(|v| lambdac * v));
4410 }
4411 {
4412 let mut h_local = penaltyhessian.slice_mut(s![range.clone(), range]);
4413 h_local += &h0.mapv(|v| lambda0 * v);
4414 h_local += &hg.mapv(|v| lambdag * v);
4415 h_local += &hc.mapv(|v| lambdac * v);
4416 }
4417
4418 adaptive_states.push(state);
4419 }
4420
4421 let (fixed_quadraticvalue, fixed_quadraticgradient) = self.fixed_quadratic_terms(beta);
4422 Ok(SpatialAdaptiveExactEvaluation {
4423 obs,
4424 adaptive_states,
4425 adaptive_penalty_value: penalty_value,
4426 adaptive_penaltygradient: penaltygradient,
4427 adaptive_penaltyhessian: penaltyhessian,
4428 fixed_quadraticvalue,
4429 fixed_quadraticgradient,
4430 fixed_quadratichessian: self.fixed_quadratichessian.as_ref().clone(),
4431 })
4432 }
4433
4434 fn exact_evaluation(
4435 &self,
4436 beta: &Array1<f64>,
4437 ) -> Result<Arc<SpatialAdaptiveExactEvaluation>, String> {
4438 {
4439 let cache = self
4440 .exact_eval_cache
4441 .lock()
4442 .map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
4443 if let Some(cached) = cache.as_ref()
4444 && cached.beta.len() == beta.len()
4445 && cached
4446 .beta
4447 .iter()
4448 .zip(beta.iter())
4449 .all(|(&left, &right)| left == right)
4450 {
4451 return Ok(Arc::clone(&cached.eval));
4452 }
4453 }
4454
4455 let eval = Arc::new(self.exact_evaluation_uncached(beta)?);
4456 let mut cache = self
4457 .exact_eval_cache
4458 .lock()
4459 .map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
4460 *cache = Some(CachedSpatialAdaptiveExactEvaluation {
4461 beta: beta.clone(),
4462 eval: Arc::clone(&eval),
4463 });
4464 Ok(eval)
4465 }
4466
4467 fn exacthessian_directional_derivative_from_evaluation(
4468 &self,
4469 beta: &Array1<f64>,
4470 eval: &SpatialAdaptiveExactEvaluation,
4471 direction: &Array1<f64>,
4472 ) -> Result<Array2<f64>, String> {
4473 assert_eq!(
4474 beta.len(),
4475 direction.len(),
4476 "beta/direction length mismatch",
4477 );
4478 let d_eta = gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), direction);
4479 let mut total = xt_diag_x_dense(
4480 self.design.view(),
4481 (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
4482 )?;
4483 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4484 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4485 format!(
4486 "missing adaptive parameter block for cache {}",
4487 cache.termname
4488 )
4489 })?;
4490 let state = eval
4491 .adaptive_states
4492 .get(cache_idx)
4493 .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
4494 let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);
4495 let d0_u = cache.d0.dot(&direction_local);
4496 let d1_u = cache.d1.dot(&direction_local);
4497 let d2_u = cache.d2.dot(&direction_local);
4498 let h0 =
4499 scalar_operatorhessian(&cache.d0, &state.magnitude.directionalhessian_diag(&d0_u))
4500 .mapv(|v| params.lambda[0] * v);
4501 let hg = grouped_operatorhessian(
4502 &cache.d1,
4503 cache.dimension,
4504 &state.gradient.directionalhessian_blocks(
4505 &collocationgradient_blocks(&d1_u, cache.dimension)
4506 .map_err(|e| e.to_string())?,
4507 ),
4508 )
4509 .map_err(|e| e.to_string())?
4510 .mapv(|v| params.lambda[1] * v);
4511 let hc = grouped_operatorhessian(
4512 &cache.d2,
4513 cache.dimension * cache.dimension,
4514 &state.curvature.directionalhessian_blocks(
4515 &collocationhessian_blocks(&d2_u, cache.dimension)
4516 .map_err(|e| e.to_string())?,
4517 ),
4518 )
4519 .map_err(|e| e.to_string())?
4520 .mapv(|v| params.lambda[2] * v);
4521 let range = cache.coeff_global_range.clone();
4522 let mut local = total.slice_mut(s![range.clone(), range]);
4523 local += &h0;
4524 local += &hg;
4525 local += &hc;
4526 }
4527 Ok(total)
4528 }
4529
4530 fn exacthessian_second_directional_derivative_from_evaluation(
4551 &self,
4552 eval: &SpatialAdaptiveExactEvaluation,
4553 direction_u: &Array1<f64>,
4554 direction_v: &Array1<f64>,
4555 ) -> Result<Option<Array2<f64>>, String> {
4556 let p = self.design.ncols();
4557 if eval.obs.neghessian_eta_derivative.iter().any(|&w| w != 0.0) {
4559 return Ok(None);
4560 }
4561 let mut total = Array2::<f64>::zeros((p, p));
4562 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4563 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4564 format!(
4565 "missing adaptive parameter block for cache {}",
4566 cache.termname
4567 )
4568 })?;
4569 let state = eval
4570 .adaptive_states
4571 .get(cache_idx)
4572 .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
4573 let u_local = direction_u.slice(s![cache.coeff_global_range.clone()]);
4574 let v_local = direction_v.slice(s![cache.coeff_global_range.clone()]);
4575
4576 let q0_u = cache.d0.dot(&u_local);
4578 let q0_v = cache.d0.dot(&v_local);
4579 let h0 = scalar_operatorhessian(
4580 &cache.d0,
4581 &state.magnitude.second_directionalhessian_diag(&q0_u, &q0_v),
4582 )
4583 .mapv(|x| params.lambda[0] * x);
4584
4585 let a1 = collocationgradient_blocks(&cache.d1.dot(&u_local), cache.dimension)
4587 .map_err(|e| e.to_string())?;
4588 let b1 = collocationgradient_blocks(&cache.d1.dot(&v_local), cache.dimension)
4589 .map_err(|e| e.to_string())?;
4590 let hg = grouped_operatorhessian(
4591 &cache.d1,
4592 cache.dimension,
4593 &state.gradient.second_directionalhessian_blocks(&a1, &b1),
4594 )
4595 .map_err(|e| e.to_string())?
4596 .mapv(|x| params.lambda[1] * x);
4597
4598 let a2 = collocationhessian_blocks(&cache.d2.dot(&u_local), cache.dimension)
4600 .map_err(|e| e.to_string())?;
4601 let b2 = collocationhessian_blocks(&cache.d2.dot(&v_local), cache.dimension)
4602 .map_err(|e| e.to_string())?;
4603 let hc = grouped_operatorhessian(
4604 &cache.d2,
4605 cache.dimension * cache.dimension,
4606 &state.curvature.second_directionalhessian_blocks(&a2, &b2),
4607 )
4608 .map_err(|e| e.to_string())?
4609 .mapv(|x| params.lambda[2] * x);
4610
4611 let range = cache.coeff_global_range.clone();
4612 let mut local = total.slice_mut(s![range.clone(), range]);
4613 local += &h0;
4614 local += &hg;
4615 local += &hc;
4616 }
4617 Ok(Some(total))
4618 }
4619}
4620
4621impl CustomFamily for SpatialAdaptiveExactFamily {
    /// Always reports `true`: this family requests the joint Jeffreys term.
    // NOTE(review): the rationale is not visible in this chunk — confirm
    // against the trait documentation.
    fn joint_jeffreys_term_required(&self) -> bool {
        true
    }
4628
4629 fn joint_jeffreys_information_with_specs(
4666 &self,
4667 block_states: &[ParameterBlockState],
4668 specs: &[ParameterBlockSpec],
4669 ) -> Result<Option<Array2<f64>>, String> {
4670 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4671 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4672 if spec.design.ncols() != beta.len() {
4673 return Err(SmoothError::dimension_mismatch(format!(
4674 "spatial adaptive Jeffreys information: spec design has {} columns, beta has {}",
4675 spec.design.ncols(),
4676 beta.len()
4677 ))
4678 .into());
4679 }
4680 let eval = self.exact_evaluation(beta)?;
4681 Ok(Some(xt_diag_x_dense(
4682 self.design.view(),
4683 eval.obs.neghessian_eta.view(),
4684 )?))
4685 }
4686
4687 fn joint_jeffreys_information_directional_derivative_with_specs(
4688 &self,
4689 block_states: &[ParameterBlockState],
4690 specs: &[ParameterBlockSpec],
4691 d_beta_flat: &Array1<f64>,
4692 ) -> Result<Option<Array2<f64>>, String> {
4693 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4699 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4700 if spec.design.ncols() != d_beta_flat.len() {
4701 return Err(SmoothError::dimension_mismatch(format!(
4702 "spatial adaptive Jeffreys directional derivative: spec design has {} columns, direction has {}",
4703 spec.design.ncols(),
4704 d_beta_flat.len()
4705 ))
4706 .into());
4707 }
4708 let eval = self.exact_evaluation(beta)?;
4709 let d_eta = gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), d_beta_flat);
4710 Ok(Some(xt_diag_x_dense(
4711 self.design.view(),
4712 (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
4713 )?))
4714 }
4715
4716 fn joint_jeffreys_information_second_directional_derivative_with_specs(
4717 &self,
4718 block_states: &[ParameterBlockState],
4719 specs: &[ParameterBlockSpec],
4720 d_beta_u_flat: &Array1<f64>,
4721 d_betav_flat: &Array1<f64>,
4722 ) -> Result<Option<Array2<f64>>, String> {
4723 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4730 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4731 if spec.design.ncols() != beta.len()
4732 || d_beta_u_flat.len() != beta.len()
4733 || d_betav_flat.len() != beta.len()
4734 {
4735 return Err(SmoothError::dimension_mismatch(format!(
4736 "spatial adaptive Jeffreys second-direction length mismatch: spec cols={}, dirs=({}, {}), expected {}",
4737 spec.design.ncols(),
4738 d_beta_u_flat.len(),
4739 d_betav_flat.len(),
4740 beta.len()
4741 ))
4742 .into());
4743 }
4744 let eval = self.exact_evaluation(beta)?;
4745 if eval.obs.neghessian_eta_derivative.iter().any(|&w| w != 0.0) {
4746 return Ok(None);
4747 }
4748 Ok(Some(Array2::<f64>::zeros((beta.len(), beta.len()))))
4749 }
4750
    /// Always reports `false`.
    // NOTE(review): presumably because the Jeffreys information methods here
    // exclude the penalty Hessian that the exact-Newton joint Hessian
    // includes — confirm against the trait contract.
    fn joint_jeffreys_information_matches_observed_hessian(&self) -> bool {
        false
    }
4758
    /// Always reports `false`.
    // NOTE(review): the Jeffreys information methods in this impl read only
    // the design and the likelihood weights, not the adaptive
    // hyperparameters — confirm this is the intended meaning of "psi".
    fn joint_jeffreys_information_depends_on_psi(&self) -> bool {
        false
    }
4770
4771 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4772 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4773 let eval = self.exact_evaluation(beta)?;
4774 let mut gradient = fast_atv(&self.design, &eval.obs.score);
4775 gradient -= &eval.total_penaltygradient();
4776 let mut hessian = xt_diag_x_dense(self.design.view(), eval.obs.neghessian_eta.view())?;
4777 hessian += &eval.total_penaltyhessian();
4778 Ok(FamilyEvaluation {
4779 log_likelihood: eval.obs.log_likelihood - eval.total_penalty_value(),
4780 blockworking_sets: vec![BlockWorkingSet::ExactNewton {
4781 gradient,
4782 hessian: SymmetricMatrix::Dense(hessian),
4783 }],
4784 })
4785 }
4786
4787 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4788 let state = expect_single_block_state(block_states, "spatial adaptive exact family")?;
4789 let beta = &state.beta;
4790 let obs = evaluate_standard_familyobservations(
4791 self.family.clone(),
4792 self.latent_cloglog_state.as_ref(),
4793 self.mixture_link_state.as_ref(),
4794 self.sas_link_state.as_ref(),
4795 &self.y,
4796 &self.weights,
4797 &state.eta,
4798 )
4799 .map_err(|e| e.to_string())?;
4800 let adaptive_penalty = self.adaptive_penalty_value_only(beta)?;
4801 let (fixed_quadratic, _) = self.fixed_quadratic_terms(beta);
4802 Ok(obs.log_likelihood - adaptive_penalty - fixed_quadratic)
4803 }
4804
    /// Selects `ExactNewtonOuterObjective::StrictPseudoLaplace` as this
    /// family's outer objective.
    fn exact_newton_outerobjective(&self) -> ExactNewtonOuterObjective {
        ExactNewtonOuterObjective::StrictPseudoLaplace
    }
4808
4809 fn exact_newton_joint_hessian(
4810 &self,
4811 block_states: &[ParameterBlockState],
4812 ) -> Result<Option<Array2<f64>>, String> {
4813 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4814 let eval = self.exact_evaluation(beta)?;
4815 Ok(Some(eval.totalobjectivehessian(&self.design)?))
4816 }
4817
    /// Per-block Hessian directional derivative.
    ///
    /// After `expect_block_idx_zero` accepts `block_idx`, this forwards to the
    /// joint directional derivative with the same direction.
    fn exact_newton_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        d_beta: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
        self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
    }
4827
4828 fn exact_newton_joint_hessian_directional_derivative(
4829 &self,
4830 block_states: &[ParameterBlockState],
4831 d_beta_flat: &Array1<f64>,
4832 ) -> Result<Option<Array2<f64>>, String> {
4833 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4834 if d_beta_flat.len() != beta.len() {
4835 return Err(SmoothError::dimension_mismatch(format!(
4836 "spatial adaptive exact family direction length mismatch: got {}, expected {}",
4837 d_beta_flat.len(),
4838 beta.len()
4839 ))
4840 .into());
4841 }
4842 let eval = self.exact_evaluation(beta)?;
4843 Ok(Some(
4844 self.exacthessian_directional_derivative_from_evaluation(beta, &eval, d_beta_flat)?,
4845 ))
4846 }
4847
4848 fn exact_newton_joint_hessiansecond_directional_derivative(
4849 &self,
4850 block_states: &[ParameterBlockState],
4851 d_beta_u_flat: &Array1<f64>,
4852 d_betav_flat: &Array1<f64>,
4853 ) -> Result<Option<Array2<f64>>, String> {
4854 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4855 if d_beta_u_flat.len() != beta.len() || d_betav_flat.len() != beta.len() {
4856 return Err(SmoothError::dimension_mismatch(format!(
4857 "spatial adaptive exact family second-direction length mismatch: got ({}, {}), expected {}",
4858 d_beta_u_flat.len(),
4859 d_betav_flat.len(),
4860 beta.len()
4861 ))
4862 .into());
4863 }
4864 let eval = self.exact_evaluation(beta)?;
4865 self.exacthessian_second_directional_derivative_from_evaluation(
4866 &eval,
4867 d_beta_u_flat,
4868 d_betav_flat,
4869 )
4870 }
4871
4872 fn block_linear_constraints(
4873 &self,
4874 block_states: &[ParameterBlockState],
4875 block_idx: usize,
4876 block_spec: &ParameterBlockSpec,
4877 ) -> Result<Option<LinearInequalityConstraints>, String> {
4878 assert!(!block_states.is_empty(), "block_states must be non-empty");
4879 assert!(
4880 !block_spec.name.is_empty(),
4881 "block spec name must be non-empty",
4882 );
4883 expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
4884 Ok(self.linear_constraints.clone())
4885 }
4886
4887 fn exact_newton_joint_psi_terms(
4888 &self,
4889 block_states: &[ParameterBlockState],
4890 specs: &[ParameterBlockSpec],
4891 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4892 psi_index: usize,
4893 ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
4894 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4895 return Err(SmoothError::dimension_mismatch(format!(
4896 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4897 block_states.len(),
4898 specs.len(),
4899 derivative_blocks.len()
4900 ))
4901 .into());
4902 }
4903 derivative_blocks[0]
4904 .get(psi_index)
4905 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4906 let hyper = self
4907 .hyperspecs
4908 .get(psi_index)
4909 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4910 let beta = &block_states[0].beta;
4911 let eval = self.exact_evaluation(beta)?;
4912 let (direct, beta_mixed, betahessian_explicit) =
4913 self.adaptive_hyper_parts(&eval, *hyper)?;
4914
4915 Ok(Some(ExactNewtonJointPsiTerms {
4936 objective_psi: direct,
4937 score_psi: beta_mixed,
4938 hessian_psi: betahessian_explicit,
4939 hessian_psi_operator: None,
4940 }))
4941 }
4942
4943 fn exact_newton_joint_psisecond_order_terms(
4944 &self,
4945 block_states: &[ParameterBlockState],
4946 specs: &[ParameterBlockSpec],
4947 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4948 psi_i: usize,
4949 psi_j: usize,
4950 ) -> Result<Option<gam_problem::ExactNewtonJointPsiSecondOrderTerms>, String> {
4951 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4952 return Err(SmoothError::dimension_mismatch(format!(
4953 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4954 block_states.len(),
4955 specs.len(),
4956 derivative_blocks.len()
4957 ))
4958 .into());
4959 }
4960 derivative_blocks[0]
4961 .get(psi_i)
4962 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
4963 derivative_blocks[0]
4964 .get(psi_j)
4965 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
4966 let hyper_i = self
4967 .hyperspecs
4968 .get(psi_i)
4969 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
4970 let hyper_j = self
4971 .hyperspecs
4972 .get(psi_j)
4973 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
4974 let beta = &block_states[0].beta;
4975 let eval = self.exact_evaluation(beta)?;
4976 let (objective_psi_psi, score_psi_psi, hessian_psi_psi) =
4977 self.adaptive_explicit_second_order_parts(&eval, *hyper_i, *hyper_j)?;
4978
4979 Ok(Some(
4980 gam_problem::ExactNewtonJointPsiSecondOrderTerms {
4981 objective_psi_psi,
4982 score_psi_psi,
4983 hessian_psi_psi,
4984 hessian_psi_psi_operator: None,
4985 },
4986 ))
4987 }
4988
4989 fn exact_newton_joint_psihessian_directional_derivative(
4990 &self,
4991 block_states: &[ParameterBlockState],
4992 specs: &[ParameterBlockSpec],
4993 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4994 psi_index: usize,
4995 direction: &Array1<f64>,
4996 ) -> Result<Option<Array2<f64>>, String> {
4997 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4998 return Err(SmoothError::dimension_mismatch(format!(
4999 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
5000 block_states.len(),
5001 specs.len(),
5002 derivative_blocks.len()
5003 ))
5004 .into());
5005 }
5006 let beta = &block_states[0].beta;
5007 if direction.len() != beta.len() {
5008 return Err(SmoothError::dimension_mismatch(format!(
5009 "spatial adaptive exact family direction length mismatch: got {}, expected {}",
5010 direction.len(),
5011 beta.len()
5012 ))
5013 .into());
5014 }
5015 derivative_blocks[0]
5016 .get(psi_index)
5017 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
5018 let hyper = self
5019 .hyperspecs
5020 .get(psi_index)
5021 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
5022 let eval = self.exact_evaluation(beta)?;
5023 let drift = match hyper.kind {
5024 SpatialAdaptiveHyperKind::LogLambdaMagnitude
5025 | SpatialAdaptiveHyperKind::LogLambdaGradient
5026 | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_drift_eval(
5027 &eval,
5028 hyper.cache_index,
5029 AdaptiveComponent::from_index(hyper.kind.component_index())?,
5030 HyperDriftKind::Rho,
5031 direction,
5032 )?,
5033 SpatialAdaptiveHyperKind::LogEpsilonMagnitude
5034 | SpatialAdaptiveHyperKind::LogEpsilonGradient
5035 | SpatialAdaptiveHyperKind::LogEpsilonCurvature => self
5036 .adaptive_shared_log_epsilon_drift(
5037 &eval,
5038 hyper.kind.component_index(),
5039 direction,
5040 )?,
5041 };
5042 Ok(Some(drift))
5043 }
5044}
5045
5046fn expect_single_block_state<'a>(
5047 block_states: &'a [ParameterBlockState],
5048 family_name: &str,
5049) -> Result<&'a ParameterBlockState, String> {
5050 crate::block_layout::block_count::validate_block_count::<SmoothError>(
5051 family_name,
5052 1,
5053 block_states.len(),
5054 )?;
5055 Ok(&block_states[0])
5056}
5057
5058fn expect_single_blockspec<'a>(
5059 specs: &'a [ParameterBlockSpec],
5060 family_name: &str,
5061) -> Result<&'a ParameterBlockSpec, String> {
5062 crate::block_layout::block_count::validate_block_count::<SmoothError>(
5063 family_name,
5064 1,
5065 specs.len(),
5066 )?;
5067 Ok(&specs[0])
5068}
5069
5070fn expect_block_idx_zero(block_idx: usize, family_name: &str, context: &str) -> Result<(), String> {
5071 if block_idx != 0 {
5072 return Err(SmoothError::invalid_index(format!(
5073 "{family_name} expects block_idx 0{context}, got {block_idx}"
5074 ))
5075 .into());
5076 }
5077 Ok::<(), _>(())
5078}
5079
5080impl BoundedLinearFamily {
5081 fn bounded_term_derivative_data(
5082 &self,
5083 latent_beta: &Array1<f64>,
5084 ) -> (
5085 Array1<f64>,
5086 Array1<f64>,
5087 Array1<f64>,
5088 Array1<f64>,
5089 Array1<f64>,
5090 ) {
5091 let p = latent_beta.len();
5092 let mut beta_user = latent_beta.clone();
5093 let mut jac_diag = Array1::<f64>::ones(p);
5094 let mut second_diag = Array1::<f64>::zeros(p);
5095 let mut third_diag = Array1::<f64>::zeros(p);
5096 let mut priorthird = Array1::<f64>::zeros(p);
5097 for term in &self.bounded_terms {
5098 let (beta, _, db_dtheta, d2b_dtheta2, d3b_dtheta3) =
5099 bounded_latent_derivatives(latent_beta[term.col_idx], term.min, term.max);
5100 beta_user[term.col_idx] = beta;
5101 jac_diag[term.col_idx] = db_dtheta;
5102 second_diag[term.col_idx] = d2b_dtheta2;
5103 third_diag[term.col_idx] = d3b_dtheta3;
5104 let (_, _, _, prior_neghess_derivative) =
5105 bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
5106 priorthird[term.col_idx] = prior_neghess_derivative;
5107 }
5108 (beta_user, jac_diag, second_diag, third_diag, priorthird)
5109 }
5110
5111 fn user_beta_and_jacobian(&self, latent_beta: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
5112 let (beta_user, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
5113 (beta_user, jac_diag)
5114 }
5115
5116 fn nonlinear_offset_from_latent(&self, latent_beta: &Array1<f64>) -> Array1<f64> {
5117 let mut offset = self.offset.clone();
5118 for term in &self.bounded_terms {
5119 let (beta, _, _) =
5120 bounded_latent_to_user(latent_beta[term.col_idx], term.min, term.max);
5121 offset.scaled_add(beta, &self.design.column(term.col_idx));
5122 }
5123 offset
5124 }
5125
5126 fn effective_design_for_latent(&self, jac_diag: &Array1<f64>) -> Array2<f64> {
5127 let mut x_eff = self.design.clone();
5128 for term in &self.bounded_terms {
5129 x_eff
5130 .column_mut(term.col_idx)
5131 .mapv_inplace(|v| v * jac_diag[term.col_idx]);
5132 }
5133 x_eff
5134 }
5135
5136 fn exacthessian_andgradient(
5137 &self,
5138 latent_beta: &Array1<f64>,
5139 ) -> Result<
5140 (
5141 StandardFamilyObservationState,
5142 Array2<f64>,
5143 Array1<f64>,
5144 f64,
5145 Array1<f64>,
5146 Array1<f64>,
5147 Array1<f64>,
5148 ),
5149 String,
5150 > {
5151 let (_, jac_diag, second_diag, third_diag, priorthird) =
5152 self.bounded_term_derivative_data(latent_beta);
5153 let x_eff = self.effective_design_for_latent(&jac_diag);
5154 let eta =
5155 self.designzeroed.dot(latent_beta) + self.nonlinear_offset_from_latent(latent_beta);
5156 let obs = evaluate_standard_familyobservations(
5157 self.family.clone(),
5158 self.latent_cloglog_state.as_ref(),
5159 self.mixture_link_state.as_ref(),
5160 self.sas_link_state.as_ref(),
5161 &self.y,
5162 &self.weights,
5163 &eta,
5164 )
5165 .map_err(|e| e.to_string())?;
5166
5167 let mut priorgrad = Array1::<f64>::zeros(latent_beta.len());
5168 let mut prior_neghess = Array2::<f64>::zeros((latent_beta.len(), latent_beta.len()));
5169 let mut prior_loglik = 0.0;
5170 for term in &self.bounded_terms {
5171 let (logp, grad, neghess, _) =
5172 bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
5173 prior_loglik += logp;
5174 priorgrad[term.col_idx] += grad;
5175 prior_neghess[[term.col_idx, term.col_idx]] += neghess;
5176 }
5177
5178 let mut hessian = xt_diag_x_dense(x_eff.view(), obs.neghessian_eta.view())?;
5179 let mut gradient = fast_atv(&x_eff, &obs.score);
5180 for term in &self.bounded_terms {
5181 let score_beta = self.design.column(term.col_idx).dot(&obs.score);
5182 hessian[[term.col_idx, term.col_idx]] -= score_beta * second_diag[term.col_idx];
5183 }
5184 hessian += &prior_neghess;
5185 gradient += &priorgrad;
5186
5187 Ok((
5188 obs,
5189 hessian,
5190 gradient,
5191 prior_loglik,
5192 second_diag,
5193 third_diag,
5194 priorthird,
5195 ))
5196 }
5197
5198 fn evaluation_from_latent(
5199 &self,
5200 latent_beta: &Array1<f64>,
5201 ) -> Result<
5202 (
5203 StandardFamilyObservationState,
5204 Array2<f64>,
5205 Array1<f64>,
5206 f64,
5207 ),
5208 String,
5209 > {
5210 let (obs, hessian, gradient, prior_loglik, _, _, _) =
5211 self.exacthessian_andgradient(latent_beta)?;
5212 Ok((obs, hessian, gradient, prior_loglik))
5213 }
5214}
5215
5216impl CustomFamily for BoundedLinearFamily {
5217 fn joint_jeffreys_term_required(&self) -> bool {
5221 true
5222 }
5223
5224 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
5225 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5226 let (obs, hessian, gradient, prior_loglik) = self.evaluation_from_latent(latent_beta)?;
5227 Ok(FamilyEvaluation {
5228 log_likelihood: obs.log_likelihood + prior_loglik,
5229 blockworking_sets: vec![BlockWorkingSet::ExactNewton {
5230 gradient,
5231 hessian: SymmetricMatrix::Dense(hessian),
5232 }],
5233 })
5234 }
5235
5236 fn exact_newton_joint_hessian(
5237 &self,
5238 block_states: &[ParameterBlockState],
5239 ) -> Result<Option<Array2<f64>>, String> {
5240 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5241 let (_, hessian, _, _) = self.evaluation_from_latent(latent_beta)?;
5242 Ok(Some(hessian))
5243 }
5244
5245 fn exact_newton_hessian_directional_derivative(
5246 &self,
5247 block_states: &[ParameterBlockState],
5248 block_idx: usize,
5249 d_beta: &Array1<f64>,
5250 ) -> Result<Option<Array2<f64>>, String> {
5251 expect_block_idx_zero(block_idx, "bounded linear family", "")?;
5252 self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
5253 }
5254
5255 fn exact_newton_joint_hessian_directional_derivative(
5256 &self,
5257 block_states: &[ParameterBlockState],
5258 d_beta_flat: &Array1<f64>,
5259 ) -> Result<Option<Array2<f64>>, String> {
5260 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5261 if d_beta_flat.len() != latent_beta.len() {
5262 return Err(SmoothError::dimension_mismatch(format!(
5263 "bounded linear family directional derivative length mismatch: got {}, expected {}",
5264 d_beta_flat.len(),
5265 latent_beta.len()
5266 ))
5267 .into());
5268 }
5269
5270 let (obs, _, _, _, second_diag, third_diag, priorthird) =
5271 self.exacthessian_andgradient(latent_beta)?;
5272
5273 let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
5274 let x_eff = self.effective_design_for_latent(&jac_diag);
5275 let deta = x_eff.dot(d_beta_flat);
5276 let d_neghess_eta = &obs.neghessian_eta_derivative * &deta;
5277
5278 let mut dx_eff = Array2::<f64>::zeros(x_eff.raw_dim());
5279 for term in &self.bounded_terms {
5280 let scale = second_diag[term.col_idx] * d_beta_flat[term.col_idx];
5281 if scale != 0.0 {
5282 let mut col = dx_eff.column_mut(term.col_idx);
5283 col.assign(&self.design.column(term.col_idx));
5284 col.mapv_inplace(|v| v * scale);
5285 }
5286 }
5287
5288 let mut dhessian = xt_diag_x_dense(x_eff.view(), d_neghess_eta.view())?;
5289 let mut wxdx = Array2::<f64>::zeros((x_eff.ncols(), x_eff.ncols()));
5290 for i in 0..x_eff.nrows() {
5291 let wi = obs.neghessian_eta[i];
5292 if wi == 0.0 {
5293 continue;
5294 }
5295 for a in 0..x_eff.ncols() {
5296 let xa = x_eff[[i, a]];
5297 for b in 0..x_eff.ncols() {
5298 wxdx[[a, b]] += wi * (dx_eff[[i, a]] * x_eff[[i, b]] + xa * dx_eff[[i, b]]);
5299 }
5300 }
5301 }
5302 dhessian += &wxdx;
5303
5304 let d_score = -&obs.neghessian_eta * &deta;
5305 for term in &self.bounded_terms {
5306 let score_beta = self.design.column(term.col_idx).dot(&obs.score);
5307 let d_score_beta = self.design.column(term.col_idx).dot(&d_score);
5308 dhessian[[term.col_idx, term.col_idx]] -= d_score_beta * second_diag[term.col_idx]
5309 + score_beta * third_diag[term.col_idx] * d_beta_flat[term.col_idx];
5310 dhessian[[term.col_idx, term.col_idx]] +=
5311 priorthird[term.col_idx] * d_beta_flat[term.col_idx];
5312 }
5313
5314 Ok(Some(dhessian))
5315 }
5316
5317 fn block_geometry(
5318 &self,
5319 block_states: &[ParameterBlockState],
5320 spec: &ParameterBlockSpec,
5321 ) -> Result<(DesignMatrix, Array1<f64>), String> {
5322 if block_states.is_empty() {
5323 return Ok((
5324 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
5325 self.designzeroed.clone(),
5326 )),
5327 self.offset.clone(),
5328 ));
5329 }
5330 let offset = self.nonlinear_offset_from_latent(
5331 &expect_single_block_state(block_states, "bounded linear family")?.beta,
5332 );
5333 let x = if spec.design.ncols() == self.designzeroed.ncols() {
5334 self.designzeroed.clone()
5335 } else {
5336 return Err(SmoothError::dimension_mismatch(
5337 "bounded linear family design column mismatch",
5338 )
5339 .into());
5340 };
5341 Ok((
5342 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
5343 offset,
5344 ))
5345 }
5346
5347 fn block_geometry_is_dynamic(&self) -> bool {
5348 true
5349 }
5350
5351 fn block_geometry_directional_derivative(
5352 &self,
5353 block_states: &[ParameterBlockState],
5354 block_idx: usize,
5355 spec: &ParameterBlockSpec,
5356 d_beta: &Array1<f64>,
5357 ) -> Result<Option<BlockGeometryDirectionalDerivative>, String> {
5358 expect_block_idx_zero(
5359 block_idx,
5360 "bounded linear family",
5361 " for geometry derivative",
5362 )?;
5363 expect_single_block_state(block_states, "bounded linear family")?;
5364 if d_beta.len() != spec.design.ncols() {
5365 return Err(SmoothError::dimension_mismatch(format!(
5366 "bounded linear family geometry derivative direction mismatch: got {}, expected {}",
5367 d_beta.len(),
5368 spec.design.ncols()
5369 ))
5370 .into());
5371 }
5372 let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(&block_states[0].beta);
5373 let mut d_offset = Array1::<f64>::zeros(self.offset.len());
5374 let has_drift = self
5375 .bounded_terms
5376 .iter()
5377 .any(|term| jac_diag[term.col_idx] != 0.0 && d_beta[term.col_idx] != 0.0);
5378 if !has_drift {
5379 return Ok(Some(BlockGeometryDirectionalDerivative {
5380 d_design: None,
5381 d_offset,
5382 }));
5383 }
5384 for term in &self.bounded_terms {
5385 let col = term.col_idx;
5386 let drift = jac_diag[col] * d_beta[col];
5387 if drift != 0.0 {
5388 d_offset.scaled_add(drift, &self.design.column(col));
5389 }
5390 }
5391 Ok(Some(BlockGeometryDirectionalDerivative {
5392 d_design: None,
5393 d_offset,
5394 }))
5395 }
5396}
5397
/// Number of rows to process per chunk when streaming a dense weighted Gram
/// product with `p` columns.
///
/// The chunk is sized so that one chunk of `f64` rows occupies roughly
/// `TARGET_BYTES`, then clamped to `[MIN_ROWS, MAX_ROWS]`. `p == 0` is treated
/// as a single column. The per-row byte count saturates instead of
/// overflowing, so an absurdly large `p` yields `MIN_ROWS` rather than an
/// arithmetic-overflow panic (debug builds) or a wrap to zero followed by a
/// division by zero (release builds).
#[inline]
fn dense_diag_gram_chunkrows(p: usize) -> usize {
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 2048;
    const TARGET_BYTES: usize = 2 * 1024 * 1024;
    // `max(1)` keeps the divisor non-zero; `saturating_mul` keeps it from wrapping.
    let bytes_per_row = p.max(1).saturating_mul(std::mem::size_of::<f64>());
    (TARGET_BYTES / bytes_per_row).clamp(MIN_ROWS, MAX_ROWS)
}
5406
5407fn xt_diag_x_dense(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
5408 if x.nrows() != w.len() {
5409 return Err(SmoothError::dimension_mismatch("xt_diag_x_dense row mismatch").into());
5410 }
5411 let (n, p) = x.dim();
5412 if n == 0 || p == 0 {
5413 return Ok(Array2::<f64>::zeros((p, p)));
5414 }
5415
5416 const STREAMING_BYTES_THRESHOLD: usize = 8 * 1024 * 1024;
5417 let dense_work_bytes = n
5418 .checked_mul(p)
5419 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
5420 .unwrap_or(usize::MAX);
5421 if dense_work_bytes <= STREAMING_BYTES_THRESHOLD {
5422 let mut weighted = x.to_owned();
5423 ndarray::Zip::from(weighted.rows_mut())
5424 .and(w)
5425 .par_for_each(|mut row, wi| row *= *wi);
5426 return Ok(fast_atb(&x, &weighted));
5427 }
5428
5429 let chunkrows = dense_diag_gram_chunkrows(p).min(n);
5430 let mut weighted_chunk = Array2::<f64>::zeros((chunkrows, p));
5431 let mut out = Array2::<f64>::zeros((p, p));
5432 for row_start in (0..n).step_by(chunkrows) {
5433 let rows = (n - row_start).min(chunkrows);
5434 let x_chunk = x.slice(s![row_start..row_start + rows, ..]);
5435 {
5436 let mut chunk = weighted_chunk.slice_mut(s![0..rows, ..]);
5437 for local_row in 0..rows {
5438 let scale = w[row_start + local_row];
5439 if scale == 0.0 {
5440 chunk.row_mut(local_row).fill(0.0);
5441 continue;
5442 }
5443 for col in 0..p {
5444 chunk[[local_row, col]] = x_chunk[[local_row, col]] * scale;
5445 }
5446 }
5447 }
5448 out += &fast_atb(&x_chunk, &weighted_chunk.slice(s![0..rows, ..]));
5449 }
5450 Ok(out)
5451}
5452
5453fn trace_of_dense_product(a: &Array2<f64>, b: &Array2<f64>) -> Result<f64, String> {
5454 if a.nrows() != a.ncols() || b.nrows() != b.ncols() || a.nrows() != b.nrows() {
5455 return Err(
5456 SmoothError::dimension_mismatch("trace_of_dense_product dimension mismatch").into(),
5457 );
5458 }
5459 let mut trace = 0.0;
5460 for i in 0..a.nrows() {
5461 for j in 0..a.ncols() {
5462 trace += a[[i, j]] * b[[j, i]];
5463 }
5464 }
5465 Ok(trace)
5466}
5467
5468fn exact_bounded_edf(
5469 penalties: &[PenaltySpec],
5470 lambdas: &Array1<f64>,
5471 latent_cov: &Array2<f64>,
5472) -> Result<(Vec<f64>, Vec<f64>, f64), EstimationError> {
5473 if penalties.len() != lambdas.len() {
5474 crate::bail_invalid_estim!(
5475 "bounded EDF penalty/lambda mismatch: {} penalties vs {} lambdas",
5476 penalties.len(),
5477 lambdas.len()
5478 );
5479 }
5480 if latent_cov.nrows() != latent_cov.ncols() {
5481 crate::bail_invalid_estim!("bounded EDF covariance must be square");
5482 }
5483
5484 let p = latent_cov.nrows();
5485 let mut s_lambda = Array2::<f64>::zeros((p, p));
5486 let mut edf_by_block = Vec::with_capacity(penalties.len());
5487 let mut penalty_block_trace = Vec::with_capacity(penalties.len());
5489 let mut trace_sum = 0.0;
5490
5491 for (k, ps) in penalties.iter().enumerate() {
5492 let lambda_k = lambdas[k];
5493 match ps {
5494 PenaltySpec::Block {
5495 local, col_range, ..
5496 } => {
5497 s_lambda
5498 .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
5499 .scaled_add(lambda_k, local);
5500 let penalty_rank =
5502 local
5503 .nrows()
5504 .saturating_sub(estimate_penalty_nullity(local).map_err(|e| {
5505 EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
5506 })?);
5507 let cov_block = latent_cov.slice(ndarray::s![col_range.clone(), col_range.clone()]);
5509 let trace_k = lambda_k
5510 * trace_of_dense_product(&cov_block.to_owned(), local)
5511 .map_err(EstimationError::InvalidInput)?;
5512 trace_sum += trace_k;
5513 penalty_block_trace.push(trace_k);
5514 let p_k = penalty_rank as f64;
5515 edf_by_block.push((p_k - trace_k).clamp(0.0, p_k));
5516 }
5517 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
5518 s_lambda.scaled_add(lambda_k, m);
5519 let penalty_rank = p.saturating_sub(estimate_penalty_nullity(m).map_err(|e| {
5520 EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
5521 })?);
5522 let trace_k = lambda_k
5523 * trace_of_dense_product(latent_cov, m)
5524 .map_err(EstimationError::InvalidInput)?;
5525 trace_sum += trace_k;
5526 penalty_block_trace.push(trace_k);
5527 let p_k = penalty_rank as f64;
5528 edf_by_block.push((p_k - trace_k).clamp(0.0, p_k));
5529 }
5530 }
5531 }
5532
5533 let nullity_total = estimate_penalty_nullity(&s_lambda)
5534 .map_err(|e| EstimationError::InvalidInput(format!("bounded EDF nullity failed: {e}")))?
5535 as f64;
5536 let edf_total = (p as f64 - trace_sum).clamp(nullity_total, p as f64);
5537 Ok((edf_by_block, penalty_block_trace, edf_total))
5538}
5539
5540fn symmetric_positive_definite_inverse_or_pseudo(
5552 precision: &Array2<f64>,
5553) -> Result<Array2<f64>, EstimationError> {
5554 use gam_linalg::faer_ndarray::FaerEigh;
5555 let p = precision.nrows();
5556 if precision.ncols() != p {
5557 crate::bail_invalid_estim!(
5558 "posterior precision inverse requires a square matrix, got {}x{}",
5559 precision.nrows(),
5560 precision.ncols()
5561 );
5562 }
5563 if p == 0 {
5564 return Ok(Array2::<f64>::zeros((0, 0)));
5565 }
5566 let symmetric = (precision + &precision.t().to_owned()) * 0.5;
5567 let (evals, evecs) = symmetric.eigh(faer::Side::Lower).map_err(|e| {
5568 EstimationError::InvalidInput(format!(
5569 "posterior precision eigendecomposition failed: {e}"
5570 ))
5571 })?;
5572 let max_abs_eval = evals.iter().fold(0.0_f64, |acc, &ev| acc.max(ev.abs()));
5573 let tol =
5574 (10.0 * f64::EPSILON * (p as f64) * (p as f64) * max_abs_eval).max(100.0 * f64::EPSILON);
5575 if let Some(&min_eval) = evals
5576 .iter()
5577 .filter(|&&ev| ev < -tol)
5578 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
5579 {
5580 crate::bail_invalid_estim!(
5581 "bounded posterior precision is non-PD at the converged optimum (min eigenvalue \
5582 {min_eval:.6e} < -tol={tol:.6e}); the reported mode is not a strict posterior \
5583 maximum, so a covariance would be meaningless"
5584 );
5585 }
5586 let mut scaled = evecs.clone();
5588 for (j, &ev) in evals.iter().enumerate() {
5589 let inv = if ev > tol { 1.0 / ev } else { 0.0 };
5590 scaled.column_mut(j).mapv_inplace(|v| v * inv);
5591 }
5592 let cov = scaled.dot(&evecs.t());
5593 Ok((&cov + &cov.t().to_owned()) * 0.5)
5594}
5595
5596fn transform_bounded_latent_precision_to_user_internal(
5597 latent_precision: &Array2<f64>,
5598 jac_diag: &Array1<f64>,
5599) -> Result<Array2<f64>, EstimationError> {
5600 let p = latent_precision.nrows();
5601 if latent_precision.ncols() != p || jac_diag.len() != p {
5602 crate::bail_invalid_estim!(
5603 "bounded precision transform dimension mismatch: precision is {}x{}, jacobian has {} entries",
5604 latent_precision.nrows(),
5605 latent_precision.ncols(),
5606 jac_diag.len()
5607 );
5608 }
5609 let mut out = latent_precision.clone();
5610 for i in 0..p {
5611 let scale = jac_diag[i];
5612 if !scale.is_finite() || scale <= 0.0 {
5613 crate::bail_invalid_estim!(
5614 "bounded precision transform requires a positive finite coefficient jacobian; column {i} has {scale}"
5615 );
5616 }
5617 if scale != 1.0 {
5618 out.row_mut(i).mapv_inplace(|v| v / scale);
5619 out.column_mut(i).mapv_inplace(|v| v / scale);
5620 }
5621 }
5622 Ok(out)
5623}
5624
5625fn fit_bounded_term_collection_with_design(
5626 y: ArrayView1<'_, f64>,
5627 weights: ArrayView1<'_, f64>,
5628 offset: ArrayView1<'_, f64>,
5629 spec: &TermCollectionSpec,
5630 design: &TermCollectionDesign,
5631 heuristic_lambdas: Option<&[f64]>,
5632 family: LikelihoodSpec,
5633 options: &FitOptions,
5634) -> Result<FittedTermCollection, EstimationError> {
5635 let conditioning_cols: Vec<usize> = spec
5636 .linear_terms
5637 .iter()
5638 .enumerate()
5639 .filter_map(|(j, linear)| {
5640 (!linear.double_penalty).then_some(design.intercept_range.end + j)
5641 })
5642 .collect();
5643 let conditioning = LinearFitConditioning::from_columns(design, &conditioning_cols);
5644 let dense_design = design.design.to_dense_cow();
5645 let fit_design = conditioning.apply_to_design(&dense_design);
5646 let fit_penalties = conditioning
5647 .transform_blockwise_penalties_to_internal(&design.penalties, design.design.ncols());
5648 if design.linear_constraints.is_some() {
5649 crate::bail_invalid_estim!(
5650 "bounded() terms are not yet compatible with explicit linear constraints"
5651 );
5652 }
5653 let mut bounded_terms = Vec::<BoundedLinearTermMeta>::new();
5654 for (j, term) in spec.linear_terms.iter().enumerate() {
5655 if term.double_penalty
5656 && matches!(
5657 term.coefficient_geometry,
5658 LinearCoefficientGeometry::Bounded { .. }
5659 )
5660 {
5661 crate::bail_invalid_estim!(
5662 "bounded linear term '{}' cannot also use double_penalty",
5663 term.name
5664 );
5665 }
5666 if let LinearCoefficientGeometry::Bounded { min, max, prior } =
5667 term.coefficient_geometry.clone()
5668 {
5669 let col_idx = design.intercept_range.end + j;
5670 let (min_internal, max_internal) = conditioning.internal_bounds_for(col_idx, min, max);
5671 bounded_terms.push(BoundedLinearTermMeta {
5672 col_idx,
5673 min: min_internal,
5674 max: max_internal,
5675 prior,
5676 });
5677 }
5678 }
5679 if bounded_terms.is_empty() {
5680 crate::bail_invalid_estim!("internal bounded fit path called with no bounded terms");
5681 }
5682
5683 let mut designzeroed = fit_design.clone();
5684 let mut initial_beta = Array1::<f64>::zeros(fit_design.ncols());
5685 for term in &bounded_terms {
5686 designzeroed.column_mut(term.col_idx).fill(0.0);
5687 initial_beta[term.col_idx] = bounded_logit(0.5);
5688 }
5689
5690 let initial_log_lambdas = heuristic_lambdas
5691 .map(|vals| Array1::from_vec(vals.to_vec()))
5692 .unwrap_or_else(|| Array1::zeros(fit_penalties.len()));
5693 if initial_log_lambdas.len() != fit_penalties.len() {
5694 crate::bail_invalid_estim!(
5695 "heuristic lambda length mismatch for bounded model: got {}, expected {}",
5696 initial_log_lambdas.len(),
5697 fit_penalties.len()
5698 );
5699 }
5700
5701 let is_beta_logistic = family.is_binomial_beta_logistic();
5702 let family_adapter = BoundedLinearFamily {
5703 family: family.clone(),
5704 latent_cloglog_state: options.latent_cloglog,
5705 mixture_link_state: options
5706 .mixture_link
5707 .clone()
5708 .as_ref()
5709 .map(state_fromspec)
5710 .transpose()
5711 .map_err(EstimationError::InvalidInput)?,
5712 sas_link_state: options
5713 .sas_link
5714 .map(|spec| {
5715 if is_beta_logistic {
5716 state_from_beta_logisticspec(spec)
5717 } else {
5718 state_from_sasspec(spec)
5719 }
5720 })
5721 .transpose()
5722 .map_err(EstimationError::InvalidInput)?,
5723 y: y.to_owned(),
5724 weights: weights.to_owned(),
5725 design: fit_design.clone(),
5726 designzeroed: designzeroed.clone(),
5727 offset: offset.to_owned(),
5728 bounded_terms: bounded_terms.clone(),
5729 };
5730 let blockspec = ParameterBlockSpec {
5731 name: "eta".to_string(),
5732 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(designzeroed)),
5733 offset: offset.to_owned(),
5734 penalties: fit_penalties
5735 .iter()
5736 .map(|ps| match ps {
5737 PenaltySpec::Block {
5738 local, col_range, ..
5739 } => PenaltyMatrix::Blockwise {
5740 local: local.clone(),
5741 col_range: col_range.clone(),
5742 total_dim: design.design.ncols(),
5743 },
5744 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
5745 PenaltyMatrix::Dense(m.clone())
5746 }
5747 })
5748 .collect(),
5749 nullspace_dims: design.nullspace_dims.clone(),
5750 initial_log_lambdas,
5751 initial_beta: Some(initial_beta),
5752 gauge_priority: 100,
5753 jacobian_callback: Some(Arc::new(BoundedEffectiveJacobian {
5759 design: fit_design.clone(),
5760 bounded_terms: bounded_terms.clone(),
5761 })),
5762 stacked_design: None,
5763 stacked_offset: None,
5764 };
5765 let fit = fit_custom_family(
5766 &family_adapter,
5767 &[blockspec],
5768 &BlockwiseFitOptions {
5769 inner_max_cycles: options.max_iter,
5770 inner_tol: options.tol,
5771 outer_max_iter: options.max_iter,
5772 outer_tol: options.tol,
5773 compute_covariance: false,
5783 ..BlockwiseFitOptions::default()
5784 },
5785 )
5786 .map_err(EstimationError::CustomFamily)?;
5787
5788 let latent_beta = fit.block_states[0].beta.clone();
5789 let (beta_user_internal, jac_diag) = family_adapter.user_beta_and_jacobian(&latent_beta);
5790 let beta_user = conditioning.backtransform_beta(&beta_user_internal);
5791
5792 let (eta_state, h_data, _, _) = family_adapter
5793 .evaluation_from_latent(&latent_beta)
5794 .map_err(EstimationError::InvalidInput)?;
5795 let p_fit = fit_design.ncols();
5796 let mut s_lambda_internal = Array2::<f64>::zeros((p_fit, p_fit));
5797 for (k, penalty) in fit_penalties.iter().enumerate() {
5798 match penalty {
5799 PenaltySpec::Block {
5800 local, col_range, ..
5801 } => {
5802 s_lambda_internal
5803 .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
5804 .scaled_add(fit.lambdas[k], local);
5805 }
5806 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
5807 s_lambda_internal.scaled_add(fit.lambdas[k], m);
5808 }
5809 }
5810 }
5811 let mut latent_precision = h_data.clone();
5812 latent_precision += &s_lambda_internal;
5813 let user_precision_internal =
5814 transform_bounded_latent_precision_to_user_internal(&latent_precision, &jac_diag)?;
5815 let penalized_hessian =
5816 conditioning.transform_penalized_hessian_to_original(&user_precision_internal);
5817
5818 let beta_covariance_unscaled = if options.compute_inference {
5846 Some(symmetric_positive_definite_inverse_or_pseudo(
5847 &penalized_hessian,
5848 )?)
5849 } else {
5850 None
5851 };
5852 let latent_cov = if options.compute_inference {
5858 Some(symmetric_positive_definite_inverse_or_pseudo(
5859 &latent_precision,
5860 )?)
5861 } else {
5862 None
5863 };
5864 let s_lambda_original = weighted_blockwise_penalty_sum(
5865 &design.penalties,
5866 fit.lambdas.as_slice().unwrap(),
5867 design.design.ncols(),
5868 );
5869 let penalty_term = beta_user.dot(&s_lambda_original.dot(&beta_user));
5870 let deviance = if family.is_gaussian_identity() {
5871 y.iter()
5872 .zip(eta_state.mu.iter())
5873 .zip(weights.iter())
5874 .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
5875 .sum()
5876 } else {
5877 -2.0 * eta_state.log_likelihood
5878 };
5879 let (edf_by_block, penalty_block_trace, edf_total) = if let Some(cov) = latent_cov.as_ref() {
5880 exact_bounded_edf(&fit_penalties, &fit.lambdas, cov)?
5881 } else {
5882 (
5883 vec![0.0; fit_penalties.len()],
5884 vec![0.0; fit_penalties.len()],
5885 0.0,
5886 )
5887 };
5888
5889 let glm_likelihood = gam_spec::GlmLikelihoodSpec::canonical(family.clone());
5901 let standard_deviation = if family.is_gaussian_identity() {
5902 let denom = if options.compute_inference {
5903 (y.len() as f64 - edf_total).max(1.0)
5904 } else {
5905 (y.len() as f64).max(1.0)
5906 };
5907 (deviance / denom).sqrt()
5908 } else {
5909 1.0
5910 };
5911 let cov_scale = glm_likelihood
5912 .coefficient_covariance_scale(standard_deviation * standard_deviation)
5913 .max(f64::MIN_POSITIVE);
5914 let dispersion = gam_solve::estimate::dispersion_from_likelihood(&glm_likelihood, standard_deviation);
5915 let beta_covariance = beta_covariance_unscaled.map(|mut cov| {
5921 if cov_scale != 1.0 {
5922 cov.mapv_inplace(|v| v * cov_scale);
5923 }
5924 cov
5925 });
5926 let beta_standard_errors = beta_covariance
5927 .as_ref()
5928 .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));
5929
5930 let geometry = Some(gam_solve::estimate::FitGeometry {
5931 penalized_hessian: penalized_hessian.clone().into(),
5932 working_weights: eta_state.fisherweight.clone(),
5933 working_response: {
5934 let mut working_response = eta_state.eta.clone();
5935 for i in 0..working_response.len() {
5936 let wi = eta_state.fisherweight[i].max(1e-12);
5937 working_response[i] += eta_state.score[i] / wi;
5938 }
5939 working_response
5940 },
5941 });
5942 let max_abs_eta = eta_state
5943 .eta
5944 .iter()
5945 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
5946 Ok(FittedTermCollection {
5947 fit: {
5948 let log_lambdas = fit.lambdas.mapv(|v| v.max(1e-300).ln());
5949 let inf = FitInference {
5950 edf_by_block,
5951 penalty_block_trace,
5952 edf_total,
5953 smoothing_correction: None,
5954 penalized_hessian: penalized_hessian.clone().into(),
5957 working_weights: eta_state.fisherweight.clone(),
5958 working_response: {
5959 let mut working_response = eta_state.eta.clone();
5960 for i in 0..working_response.len() {
5961 let wi = eta_state.fisherweight[i].max(1e-12);
5962 working_response[i] += eta_state.score[i] / wi;
5963 }
5964 working_response
5965 },
5966 reparam_qs: None,
5967 dispersion,
5968 beta_covariance: beta_covariance
5969 .clone()
5970 .map(gam_problem::dispersion_cov::PhiScaledCovariance::from),
5971 beta_standard_errors,
5972 beta_covariance_corrected: None,
5973 beta_standard_errors_corrected: None,
5974 beta_covariance_frequentist: None,
5975 coefficient_influence: None,
5976 weighted_gram: None,
5977 bias_correction_beta: None,
5978 };
5979 let covariance_conditional = beta_covariance;
5980 let pirls_status_val = if fit.outer_converged {
5981 gam_solve::pirls::PirlsStatus::Converged
5982 } else {
5983 gam_solve::pirls::PirlsStatus::StalledAtValidMinimum
5984 };
5985 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
5986 blocks: vec![gam_solve::estimate::FittedBlock {
5987 beta: beta_user.clone(),
5988 role: gam_problem::BlockRole::Mean,
5989 edf: edf_total,
5990 lambdas: fit.lambdas.clone(),
5991 }],
5992 log_lambdas,
5993 lambdas: fit.lambdas,
5994 likelihood_scale: family.default_scale_metadata(),
5995 likelihood_family: Some(family),
5996 log_likelihood_normalization:
5997 gam_spec::LogLikelihoodNormalization::UserProvided,
5998 log_likelihood: eta_state.log_likelihood,
5999 deviance,
6000 reml_score: fit.penalized_objective,
6001 stable_penalty_term: penalty_term,
6002 penalized_objective: fit.penalized_objective,
6003 used_device: false,
6004 outer_iterations: fit.outer_iterations,
6005 outer_converged: fit.outer_converged,
6006 outer_gradient_norm: fit.outer_gradient_norm,
6007 standard_deviation,
6008 covariance_conditional,
6009 covariance_corrected: None,
6010 inference: Some(inf),
6011 fitted_link: gam_solve::estimate::FittedLinkState::Standard(None),
6012 geometry,
6013 block_states: Vec::new(),
6014 pirls_status: pirls_status_val,
6015 max_abs_eta,
6016 constraint_kkt: None,
6017 artifacts: gam_solve::estimate::FitArtifacts {
6018 pirls: None,
6019 ..Default::default()
6020 },
6021 inner_cycles: 0,
6022 })?
6023 },
6024 design: design.clone(),
6025 adaptive_diagnostics: None,
6026 })
6027}
6028
6029fn enforce_term_constraint_feasibility(
6030 design: &TermCollectionDesign,
6031 fit: &UnifiedFitResult,
6032) -> Result<(), EstimationError> {
6033 const CONSTRAINT_FEASIBILITY_RAW_TOL: f64 = 1e-7;
6047 let tol = CONSTRAINT_FEASIBILITY_RAW_TOL;
6048 let smooth_start = design
6049 .design
6050 .ncols()
6051 .saturating_sub(design.smooth.total_smooth_cols());
6052 let mut violations: Vec<String> = Vec::new();
6053 for term in &design.smooth.terms {
6054 let gr = (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
6055 let beta_local = fit.beta.slice(s![gr.clone()]).to_owned();
6056 if let Some(lb) = term.lower_bounds_local.as_ref() {
6057 let mut worst = 0.0_f64;
6058 let mut worst_idx = 0usize;
6059 for i in 0..lb.len().min(beta_local.len()) {
6060 if lb[i].is_finite() {
6061 let viol = (lb[i] - beta_local[i]).max(0.0);
6062 if viol > worst {
6063 worst = viol;
6064 worst_idx = i;
6065 }
6066 }
6067 }
6068 if worst > tol {
6069 violations.push(format!(
6070 "term='{}' kind=lower-bound maxviolation={:.3e} coeff_index={}",
6071 term.name, worst, worst_idx
6072 ));
6073 }
6074 }
6075 if let Some(lin) = term.linear_constraints_local.as_ref() {
6076 let mut worst = 0.0_f64;
6077 let mut worstrow = 0usize;
6078 for i in 0..lin.a.nrows() {
6079 let norm = lin.a.row(i).dot(&lin.a.row(i)).sqrt();
6080 let inv = if norm > 0.0 { 1.0 / norm } else { 0.0 };
6081 let s = (lin.a.row(i).dot(&beta_local) - lin.b[i]) * inv;
6082 let viol = (-s).max(0.0);
6083 if viol > worst {
6084 worst = viol;
6085 worstrow = i;
6086 }
6087 }
6088 if worst > tol {
6089 violations.push(format!(
6090 "term='{}' kind=linear-inequality maxviolation={:.3e} row={}",
6091 term.name, worst, worstrow
6092 ));
6093 }
6094 }
6095 }
6096
6097 if !violations.is_empty() {
6098 let mut msg = format!(
6099 "constraint violation after fit ({} violating term constraints): {}",
6100 violations.len(),
6101 violations.join(" | ")
6102 );
6103 if let Some(kkt) = fit.constraint_kkt.as_ref() {
6104 msg.push_str(&format!(
6105 "; KKT[primal={:.3e}, dual={:.3e}, comp={:.3e}, stat={:.3e}]",
6106 kkt.primal_feasibility, kkt.dual_feasibility, kkt.complementarity, kkt.stationarity
6107 ));
6108 }
6109 return Err(EstimationError::ParameterConstraintViolation(msg));
6110 }
6111 Ok(())
6112}
6113
6114fn stratified_spatial_subsample(
6115 data: ArrayView2<'_, f64>,
6116 spec: &TermCollectionSpec,
6117 target_size: usize,
6118) -> Vec<usize> {
6119 use rand::SeedableRng;
6120 use rand::rngs::StdRng;
6121 use rand::seq::SliceRandom;
6122
6123 let n = data.nrows();
6124 if n <= target_size {
6125 return (0..n).collect();
6126 }
6127
6128 let spatial_cols: Option<Vec<usize>> =
6129 spec.smooth_terms.iter().find_map(|term| match &term.basis {
6130 SmoothBasisSpec::ThinPlate { feature_cols, .. }
6131 | SmoothBasisSpec::Matern { feature_cols, .. }
6132 | SmoothBasisSpec::Duchon { feature_cols, .. } => {
6133 if !feature_cols.is_empty() {
6134 Some(feature_cols.clone())
6135 } else {
6136 None
6137 }
6138 }
6139 _ => None,
6140 });
6141
6142 let cols = match spatial_cols {
6143 Some(c) if !c.is_empty() => c,
6144 _ => {
6145 let mut rng = StdRng::seed_from_u64(spatial_subsample_seed(data, &[], target_size));
6146 let mut indices: Vec<usize> = (0..n).collect();
6147 indices.shuffle(&mut rng);
6148 indices.truncate(target_size);
6149 indices.sort_unstable();
6150 return indices;
6151 }
6152 };
6153 let mut rng = StdRng::seed_from_u64(spatial_subsample_seed(data, &cols, target_size));
6154
6155 let d = cols.len();
6156 let mut mins = vec![f64::INFINITY; d];
6157 let mut maxs = vec![f64::NEG_INFINITY; d];
6158 for i in 0..n {
6159 for (ax, &col) in cols.iter().enumerate() {
6160 let v = data[[i, col]];
6161 if v < mins[ax] {
6162 mins[ax] = v;
6163 }
6164 if v > maxs[ax] {
6165 maxs[ax] = v;
6166 }
6167 }
6168 }
6169
6170 const TARGET_POINTS_PER_CELL: usize = 5;
6174 let total_cells_target = (target_size / TARGET_POINTS_PER_CELL).max(1);
6175 let cells_per_axis = ((total_cells_target as f64).powf(1.0 / d as f64)).ceil() as usize;
6176 let cells_per_axis = cells_per_axis.max(1);
6177
6178 let mut cell_members: std::collections::HashMap<Vec<usize>, Vec<usize>> =
6179 std::collections::HashMap::new();
6180 for i in 0..n {
6181 let mut cell_key = Vec::with_capacity(d);
6182 for (ax, &col) in cols.iter().enumerate() {
6183 let range = maxs[ax] - mins[ax];
6184 let cell = if range <= 0.0 {
6185 0
6186 } else {
6187 let frac = (data[[i, col]] - mins[ax]) / range;
6188 (frac * cells_per_axis as f64).floor() as usize
6189 };
6190 cell_key.push(cell.min(cells_per_axis - 1));
6191 }
6192 cell_members.entry(cell_key).or_default().push(i);
6193 }
6194
6195 let mut selected: Vec<usize> = Vec::with_capacity(target_size);
6196 let mut remaining_budget = target_size;
6197 let mut remaining_population = n;
6198
6199 let mut cells: Vec<(Vec<usize>, Vec<usize>)> = cell_members.into_iter().collect();
6200 cells.sort_by(|a, b| a.0.cmp(&b.0));
6201
6202 for (_, members) in &mut cells {
6203 if remaining_budget == 0 {
6204 break;
6205 }
6206 let alloc = ((members.len() as f64 / remaining_population as f64) * remaining_budget as f64)
6207 .round() as usize;
6208 let alloc = alloc.max(1).min(members.len()).min(remaining_budget);
6209 members.shuffle(&mut rng);
6210 selected.extend_from_slice(&members[..alloc]);
6211 remaining_budget = remaining_budget.saturating_sub(alloc);
6212 remaining_population = remaining_population.saturating_sub(members.len());
6213 }
6214
6215 if selected.len() > target_size {
6216 selected.shuffle(&mut rng);
6217 selected.truncate(target_size);
6218 }
6219
6220 selected.sort_unstable();
6221 selected
6222}
6223
6224fn spatial_subsample_seed(
6225 data: ArrayView2<'_, f64>,
6226 spatial_cols: &[usize],
6227 target_size: usize,
6228) -> u64 {
6229 let mut state = 0x5350_4154_4941_4C53_u64;
6230 spatial_seed_mix(&mut state, data.nrows() as u64);
6231 spatial_seed_mix(&mut state, data.ncols() as u64);
6232 spatial_seed_mix(&mut state, target_size as u64);
6233 spatial_seed_mix(&mut state, spatial_cols.len() as u64);
6234 for &col in spatial_cols {
6235 spatial_seed_mix(&mut state, col as u64);
6236 }
6237
6238 if data.nrows() > 0 {
6239 let mid = data.nrows() / 2;
6240 let last = data.nrows() - 1;
6241 for &row in &[0usize, mid, last] {
6242 for &col in spatial_cols {
6243 let value = data[[row, col]];
6244 spatial_seed_mix(&mut state, value.to_bits());
6245 }
6246 }
6247 }
6248 state
6249}
6250
6251#[inline]
6252fn spatial_seed_mix(state: &mut u64, value: u64) {
6253 let mut s = value.wrapping_add(*state);
6256 let z = gam_linalg::utils::splitmix64(&mut s);
6257 *state ^= z;
6258 *state = (*state).rotate_left(27).wrapping_mul(0x3C79_AC49_2BA7_B653);
6259}
6260
6261fn sampled_rows(data: ArrayView2<'_, f64>, indices: &[usize]) -> Array2<f64> {
6262 let mut sampled = Array2::<f64>::zeros((indices.len(), data.ncols()));
6263 for (new_row, &orig_row) in indices.iter().enumerate() {
6264 sampled.row_mut(new_row).assign(&data.row(orig_row));
6265 }
6266 sampled
6267}
6268
6269fn spatial_term_user_centers(term: &SmoothTermSpec) -> Option<ArrayView2<'_, f64>> {
6270 match spatial_term_center_strategy(term) {
6271 Some(CenterStrategy::UserProvided(centers)) => Some(centers.view()),
6272 _ => None,
6273 }
6274}
6275
6276fn finite_centered_axis_contrasts(values: &[f64], expected_dim: usize) -> Option<Vec<f64>> {
6277 if values.len() != expected_dim || expected_dim <= 1 {
6278 return None;
6279 }
6280 if values.iter().any(|value| !value.is_finite()) {
6281 return None;
6282 }
6283 Some(center_aniso_log_scales(values))
6284}
6285
6286fn blended_pilot_axis_contrasts(
6287 pilot_data: ArrayView2<'_, f64>,
6288 term: &SmoothTermSpec,
6289 centers: ArrayView2<'_, f64>,
6290) -> Option<Vec<f64>> {
6291 let d = centers.ncols();
6292 if d <= 1 {
6293 return None;
6294 }
6295 let center_eta = initial_aniso_contrasts(centers);
6296 let data_eta = standardized_spatial_term_data(pilot_data, term)
6297 .ok()
6298 .and_then(|x| finite_centered_axis_contrasts(&initial_aniso_contrasts(x.view()), d));
6299 let center_eta = finite_centered_axis_contrasts(¢er_eta, d)?;
6300 let blended = match data_eta {
6301 Some(data_eta) => center_eta
6302 .iter()
6303 .zip(data_eta.iter())
6304 .map(|(&from_centers, &from_data)| 0.5 * (from_centers + from_data))
6305 .collect::<Vec<_>>(),
6306 None => center_eta,
6307 };
6308 finite_centered_axis_contrasts(&blended, d)
6309}
6310
6311fn apply_pilot_spatial_psi_reseed(
6312 pilot_data: ArrayView2<'_, f64>,
6313 spec: &TermCollectionSpec,
6314 spatial_terms: &[usize],
6315 kappa_options: &SpatialLengthScaleOptimizationOptions,
6316) -> Result<TermCollectionSpec, EstimationError> {
6317 let dims_per_term = spatial_dims_per_term(spec, spatial_terms);
6318 let use_aniso = has_aniso_terms(spec, spatial_terms);
6319 let log_kappa0 = if use_aniso {
6320 SpatialLogKappaCoords::from_length_scales_aniso(spec, spatial_terms, kappa_options)
6321 } else {
6322 SpatialLogKappaCoords::from_length_scales(spec, spatial_terms, kappa_options)
6323 };
6324 let log_kappa0 = log_kappa0.reseed_from_data(pilot_data, spec, spatial_terms, kappa_options);
6325 let log_kappa_lower = if use_aniso {
6326 SpatialLogKappaCoords::lower_bounds_aniso_from_data(
6327 pilot_data,
6328 spec,
6329 spatial_terms,
6330 &dims_per_term,
6331 kappa_options,
6332 )
6333 } else {
6334 SpatialLogKappaCoords::lower_bounds_from_data(
6335 pilot_data,
6336 spec,
6337 spatial_terms,
6338 kappa_options,
6339 )
6340 };
6341 let log_kappa_upper = if use_aniso {
6342 SpatialLogKappaCoords::upper_bounds_aniso_from_data(
6343 pilot_data,
6344 spec,
6345 spatial_terms,
6346 &dims_per_term,
6347 kappa_options,
6348 )
6349 } else {
6350 SpatialLogKappaCoords::upper_bounds_from_data(
6351 pilot_data,
6352 spec,
6353 spatial_terms,
6354 kappa_options,
6355 )
6356 };
6357 log_kappa0
6358 .clamp_to_bounds(&log_kappa_lower, &log_kappa_upper)
6359 .apply_tospec(spec, spatial_terms)
6360}
6361
6362pub(crate) fn apply_spatial_anisotropy_pilot_initializer(
6363 data: ArrayView2<'_, f64>,
6364 spec: &mut TermCollectionSpec,
6365 spatial_terms: &[usize],
6366 target_size: usize,
6367 kappa_options: &SpatialLengthScaleOptimizationOptions,
6368) -> usize {
6369 if target_size == 0 || data.nrows() <= target_size.saturating_mul(2) || spatial_terms.is_empty()
6370 {
6371 return 0;
6372 }
6373 if !has_aniso_terms(spec, spatial_terms) {
6374 return 0;
6375 }
6376 let indices = stratified_spatial_subsample(data, spec, target_size);
6377 let pilot_data = sampled_rows(data, &indices);
6378 let mut working = spec.clone();
6379 let mut updated_terms = 0usize;
6380 const GEOMETRY_UPDATES: usize = 2;
6381
6382 for pass in 0..GEOMETRY_UPDATES {
6383 let planned_terms = match plan_joint_spatial_centers_for_term_blocks(
6384 pilot_data.view(),
6385 &[working.smooth_terms.clone()],
6386 )
6387 .and_then(|mut blocks| {
6388 blocks.pop().ok_or_else(|| {
6389 BasisError::InvalidInput(
6390 "pilot geometry initializer produced no smooth-term block".to_string(),
6391 )
6392 })
6393 }) {
6394 Ok(terms) => terms,
6395 Err(err) => {
6396 log::warn!(
6397 "[spatial-kappa] pilot geometry initializer skipped after center planning failed: {err}"
6398 );
6399 return updated_terms;
6400 }
6401 };
6402
6403 for &term_idx in spatial_terms {
6404 let Some(current_eta) = get_spatial_aniso_log_scales(&working, term_idx) else {
6405 continue;
6406 };
6407 let Some(d) = get_spatial_feature_dim(&working, term_idx) else {
6408 continue;
6409 };
6410 if d <= 1 || current_eta.len() != d {
6411 continue;
6412 }
6413 let Some(planned_term) = planned_terms.get(term_idx) else {
6414 continue;
6415 };
6416 let Some(centers) = spatial_term_user_centers(planned_term) else {
6417 continue;
6418 };
6419 let Some(eta) = blended_pilot_axis_contrasts(pilot_data.view(), planned_term, centers)
6420 else {
6421 continue;
6422 };
6423 if set_spatial_aniso_log_scales(&mut working, term_idx, eta).is_ok() {
6424 updated_terms += usize::from(pass == 0);
6425 }
6426 }
6427
6428 match apply_pilot_spatial_psi_reseed(
6429 pilot_data.view(),
6430 &working,
6431 spatial_terms,
6432 kappa_options,
6433 ) {
6434 Ok(updated) => {
6435 working = updated;
6436 }
6437 Err(err) => {
6438 log::warn!(
6439 "[spatial-kappa] pilot geometry ψ reseed skipped after deterministic initializer error: {err}"
6440 );
6441 break;
6442 }
6443 }
6444 }
6445
6446 if updated_terms > 0 {
6447 log::info!(
6448 "[spatial-kappa] initialized anisotropy from {}-row pilot geometry for {} spatial term(s); proceeding to full-data optimization",
6449 indices.len(),
6450 updated_terms
6451 );
6452 *spec = working;
6453 }
6454 updated_terms
6455}
6456
6457pub(crate) fn spatial_length_scale_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
6458 spec.smooth_terms
6459 .iter()
6460 .enumerate()
6461 .filter_map(|(idx, _)| spatial_term_supports_hyper_optimization(spec, idx).then_some(idx))
6462 .collect()
6463}
6464
6465fn fit_score(fit: &UnifiedFitResult) -> f64 {
6477 if fit.reml_score.is_finite() {
6478 return fit.reml_score;
6479 }
6480 let score = 0.5 * fit.deviance + 0.5 * fit.stable_penalty_term;
6481 if score.is_finite() {
6482 score
6483 } else {
6484 f64::INFINITY
6485 }
6486}
6487
6488fn is_recoverable_trial_point_error(err: &EstimationError) -> bool {
6510 matches!(err, EstimationError::BasisError(_))
6511 || err.is_inner_solve_retreat()
6512 || is_recoverable_fit_inference_finiteness_error(err)
6513}
6514
6515fn is_recoverable_fit_inference_finiteness_error(err: &EstimationError) -> bool {
6516 let EstimationError::InvalidInput(message) = err else {
6517 return false;
6518 };
6519
6520 message.contains("must be finite")
6521 && [
6522 "fit_result.beta_covariance_frequentist",
6523 "fit_result.coefficient_influence",
6524 "fit_result.weighted_gram",
6525 ]
6526 .iter()
6527 .any(|field| message.contains(field))
6528}
6529
#[cfg(test)]
mod spatial_trial_recovery_tests {
    use super::*;

    /// Wraps `message` in an `InvalidInput` error.
    fn invalid_input(message: &str) -> EstimationError {
        EstimationError::InvalidInput(message.to_string())
    }

    /// A non-finite frequentist covariance entry must be classified as a
    /// recoverable trial-point failure.
    #[test]
    fn nonfinite_frequentist_covariance_is_recoverable_trial_point() {
        let err =
            invalid_input("fit_result.beta_covariance_frequentist[0] must be finite, got NaN");

        let recoverable = is_recoverable_trial_point_error(&err);
        assert!(
            recoverable,
            "singular trial-point curvature should make spatial κ retreat, not abort"
        );
    }

    /// An unrelated `InvalidInput` message must not be classified as recoverable.
    #[test]
    fn arbitrary_invalid_input_remains_fatal_trial_point_error() {
        let err = invalid_input("outer rho bounds are invalid");

        let recoverable = is_recoverable_trial_point_error(&err);
        assert!(
            !recoverable,
            "the spatial κ recovery gate must not mask unrelated invalid inputs"
        );
    }
}
6556
6557fn require_successful_spatial_optimization_result<T>(
6558 initial_score: f64,
6559 result: Result<Option<(T, f64)>, EstimationError>,
6560) -> Result<T, EstimationError> {
6561 match result {
6562 Ok(Some((value, exact_score))) => {
6563 const SCORE_DRIFT_ABS_TOL: f64 = 1e-6;
6572 const SCORE_DRIFT_REL_TOL: f64 = 1e-8;
6573 let tol = SCORE_DRIFT_ABS_TOL.max(initial_score.abs() * SCORE_DRIFT_REL_TOL);
6574 if exact_score <= initial_score + tol {
6575 Ok(value)
6576 } else {
6577 Err(EstimationError::RemlOptimizationFailed(format!(
6578 "spatial kappa optimization made REML score worse ({initial_score:.6e} -> {exact_score:.6e})"
6579 )))
6580 }
6581 }
6582 Ok(None) => Err(EstimationError::RemlOptimizationFailed(
6583 "spatial kappa optimization is unavailable for one or more eligible spatial terms"
6584 .to_string(),
6585 )),
6586 Err(err) => Err(EstimationError::RemlOptimizationFailed(format!(
6587 "spatial kappa optimization failed: {err}"
6588 ))),
6589 }
6590}
6591
6592fn external_opts_for_design(
6593 family: &LikelihoodSpec,
6594 design: &TermCollectionDesign,
6595 options: &FitOptions,
6596) -> ExternalOptimOptions {
6597 ExternalOptimOptions {
6598 family: family.clone(),
6599 latent_cloglog: options.latent_cloglog,
6600 mixture_link: options.mixture_link.clone(),
6601 optimize_mixture: options.optimize_mixture,
6602 sas_link: options.sas_link,
6603 optimize_sas: options.optimize_sas,
6604 compute_inference: options.compute_inference,
6605 skip_rho_posterior_inference: options.skip_rho_posterior_inference,
6606 max_iter: options.max_iter,
6607 tol: options.tol,
6608 nullspace_dims: design.nullspace_dims.clone(),
6609 linear_constraints: design.linear_constraints.clone(),
6610 firth_bias_reduction: Some(options.firth_bias_reduction),
6611 penalty_shrinkage_floor: options.penalty_shrinkage_floor,
6612 rho_prior: options.rho_prior.clone(),
6613 kronecker_penalty_system: design.kronecker_penalty_system(),
6616 kronecker_factored: design
6617 .smooth
6618 .terms
6619 .iter()
6620 .find_map(|t| t.kronecker_factored.clone()),
6621 persist_warm_start_disk: options.persist_warm_start_disk,
6622 }
6623}
6624
6625fn evaluate_joint_reml_outer_eval_at_theta(
6633 evaluator: &mut gam_solve::estimate::ExternalJointHyperEvaluator<'_>,
6634 design: &TermCollectionDesign,
6635 theta: &Array1<f64>,
6636 rho_dim: usize,
6637 hyper_dirs: Vec<gam_solve::estimate::reml::DirectionalHyperParam>,
6638 warm_start_beta: Option<ArrayView1<'_, f64>>,
6639 order: gam_solve::rho_optimizer::OuterEvalOrder,
6640 design_revision: Option<u64>,
6641) -> Result<
6642 (
6643 f64,
6644 Array1<f64>,
6645 gam_problem::HessianResult,
6646 ),
6647 EstimationError,
6648> {
6649 evaluator.evaluate_with_order(
6650 &design.design,
6651 &design.penalties,
6652 &design.nullspace_dims,
6653 design.linear_constraints.clone(),
6654 theta,
6655 rho_dim,
6656 hyper_dirs,
6657 warm_start_beta,
6658 "evaluate_joint_reml_outer_eval_at_theta",
6659 order,
6660 design_revision,
6661 )
6662}
6663
6664fn evaluate_joint_reml_efs_at_theta(
6665 evaluator: &mut gam_solve::estimate::ExternalJointHyperEvaluator<'_>,
6666 design: &TermCollectionDesign,
6667 theta: &Array1<f64>,
6668 rho_dim: usize,
6669 hyper_dirs: Vec<gam_solve::estimate::reml::DirectionalHyperParam>,
6670 warm_start_beta: Option<ArrayView1<'_, f64>>,
6671 design_revision: Option<u64>,
6672) -> Result<gam_problem::EfsEval, EstimationError> {
6673 evaluator.evaluate_efs(
6674 &design.design,
6675 &design.penalties,
6676 &design.nullspace_dims,
6677 design.linear_constraints.clone(),
6678 theta,
6679 rho_dim,
6680 hyper_dirs,
6681 warm_start_beta,
6682 "evaluate_joint_reml_efs_at_theta",
6683 design_revision,
6684 )
6685}
6686
6687fn exact_joint_spatial_outer_hessian_available(
6688 family: &LikelihoodSpec,
6689 design: &TermCollectionDesign,
6690) -> bool {
6691 let family_supported = match &family.response {
6714 ResponseFamily::Gaussian
6715 | ResponseFamily::Binomial
6716 | ResponseFamily::Poisson
6717 | ResponseFamily::Tweedie { .. }
6718 | ResponseFamily::NegativeBinomial { .. }
6719 | ResponseFamily::Beta { .. }
6720 | ResponseFamily::Gamma
6721 | ResponseFamily::RoystonParmar => true,
6722 };
6723 family_supported && design.design.ncols() > 0
6726}
6727
6728fn smooth_term_penalty_index(
6729 spec: &TermCollectionSpec,
6730 design: &TermCollectionDesign,
6731 term_idx: usize,
6732) -> Option<usize> {
6733 if term_idx >= design.smooth.terms.len() || term_idx >= spec.smooth_terms.len() {
6734 return None;
6735 }
6736 if design.smooth.terms[term_idx].penalties_local.is_empty() {
6737 return None;
6738 }
6739 let linear_penalties = spec
6740 .linear_terms
6741 .iter()
6742 .filter(|t| t.double_penalty)
6743 .count()
6744 * 2;
6745 let random_penalties = design
6746 .random_effect_ranges
6747 .iter()
6748 .filter(|(_, range)| !range.is_empty())
6749 .count();
6750 let smooth_offset = linear_penalties + random_penalties;
6751 let local_offset = design
6752 .smooth
6753 .terms
6754 .iter()
6755 .take(term_idx)
6756 .map(|term| term.penalties_local.len())
6757 .sum::<usize>();
6758 Some(smooth_offset + local_offset)
6759}
6760
6761fn try_build_spatial_term_log_kappa_derivativeinfo(
6762 data: ArrayView2<'_, f64>,
6763 resolvedspec: &TermCollectionSpec,
6764 design: &TermCollectionDesign,
6765 term_idx: usize,
6766) -> Result<Option<SpatialPsiDerivative>, EstimationError> {
6767 let Some((
6768 global_range,
6769 total_p,
6770 x_psi_local,
6771 s_psi_local_check,
6772 x_psi_psi_local,
6773 s_psi_psi_local,
6774 s_psi_components_local,
6775 s_psi_psi_components_local,
6776 implicit_operator,
6777 )) = try_build_spatial_term_log_kappa_derivative(data, resolvedspec, design, term_idx)?
6778 else {
6779 return Ok(None);
6780 };
6781 let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
6782 return Ok(None);
6783 };
6784 if s_psi_components_local.is_empty() || s_psi_psi_components_local.is_empty() {
6785 return Ok(None);
6786 }
6787 if s_psi_components_local.len() != s_psi_psi_components_local.len() {
6788 return Ok(None);
6789 }
6790 let penalty_indices = (0..s_psi_components_local.len())
6791 .map(|j| penalty_start + j)
6792 .collect::<Vec<_>>();
6793 let penalty_index = penalty_indices[0];
6794 if s_psi_local_check.nrows() == 0 || s_psi_psi_local.nrows() == 0 {
6795 return Ok(None);
6796 }
6797 Ok(Some(SpatialPsiDerivative {
6798 penalty_index,
6799 penalty_indices,
6800 global_range,
6801 total_p,
6802 x_psi_local,
6803 s_psi_components_local,
6804 x_psi_psi_local,
6805 s_psi_psi_components_local,
6806 aniso_group_id: None,
6807 aniso_cross_designs: None,
6808 aniso_cross_penalty_provider: None,
6809 implicit_operator,
6810 implicit_axis: 0,
6811 }))
6812}
6813
6814pub(crate) fn try_build_spatial_log_kappa_derivativeinfo_list(
6815 data: ArrayView2<'_, f64>,
6816 resolvedspec: &TermCollectionSpec,
6817 design: &TermCollectionDesign,
6818 spatial_terms: &[usize],
6819) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
6820 let mut out = Vec::new();
6821 let mut aniso_gid = 0usize;
6822 for &term_idx in spatial_terms {
6823 if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
6824 if let Some(entries) = try_build_spatial_term_log_kappa_aniso_derivativeinfos(
6825 data,
6826 resolvedspec,
6827 design,
6828 term_idx,
6829 aniso_gid,
6830 )? {
6831 aniso_gid += 1;
6832 out.extend(entries);
6833 continue;
6834 } else {
6835 return Ok(None);
6836 }
6837 }
6838 let Some(info) =
6839 try_build_spatial_term_log_kappa_derivativeinfo(data, resolvedspec, design, term_idx)?
6840 else {
6841 return Ok(None);
6842 };
6843 out.push(info);
6844 }
6845 Ok(Some(out))
6846}
6847
6848fn try_build_spatial_term_log_kappa_aniso_derivativeinfos(
6850 data: ArrayView2<'_, f64>,
6851 resolvedspec: &TermCollectionSpec,
6852 design: &TermCollectionDesign,
6853 term_idx: usize,
6854 aniso_group_id: usize,
6855) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
6856 let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
6857 return Ok(None);
6858 };
6859 let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
6860 return Ok(None);
6861 };
6862 let mut aniso_result = match &termspec.basis {
6863 SmoothBasisSpec::Sphere { .. } => return Ok(None),
6864 SmoothBasisSpec::Matern {
6865 feature_cols,
6866 spec,
6867 input_scales,
6868 } => {
6869 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
6870 if let Some(s) = input_scales {
6871 apply_input_standardization(&mut x, s);
6872 }
6873 let mut spec_operator = spec.clone();
6882 spec_operator.double_penalty = false;
6883 build_matern_basis_log_kappa_aniso_derivatives(x.view(), &spec_operator)
6884 .map_err(EstimationError::from)?
6885 }
6886 SmoothBasisSpec::MeasureJet {
6892 feature_cols,
6893 spec,
6894 input_scales,
6895 } => {
6896 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
6897 if let Some(s) = input_scales {
6898 apply_input_standardization(&mut x, s);
6899 }
6900 build_measure_jet_basis_psi_derivatives(x.view(), spec)
6901 .map_err(EstimationError::from)?
6902 }
6903 _ => return Ok(None),
6904 };
6905 let d = if let Some(ref op) = aniso_result.implicit_operator {
6908 op.n_axes()
6909 } else if !aniso_result.design_first.is_empty() {
6910 aniso_result.design_first.len()
6911 } else {
6912 0
6913 };
6914 if d == 0 {
6915 return Ok(None);
6916 }
6917 let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
6918 return Ok(None);
6919 };
6920 let p_total = design.design.ncols();
6921 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
6922 let global_range = (smooth_start + smooth_term.coeff_range.start)
6923 ..(smooth_start + smooth_term.coeff_range.end);
6924 let num_penalties = aniso_result.penalties_first[0].len();
6925 let penalty_indices: Vec<usize> = (0..num_penalties).map(|j| penalty_start + j).collect();
6926 let penalties_cross_provider = aniso_result.penalties_cross_provider.clone();
6927
6928 let use_implicit_design = aniso_result.design_first.is_empty();
6932 let implicit_op_arc = aniso_result
6933 .implicit_operator
6934 .as_ref()
6935 .map(|op| std::sync::Arc::new(op.clone()));
6936
6937 let mut entries = Vec::with_capacity(d);
6938 for a in 0..d {
6939 let (x_psi_local, x_psi_psi_local) = if use_implicit_design {
6940 (Array2::<f64>::zeros((0, 0)), Array2::<f64>::zeros((0, 0)))
6946 } else {
6947 let x_first = std::mem::take(&mut aniso_result.design_first[a]);
6952 let x_second = std::mem::take(&mut aniso_result.design_second_diag[a]);
6953 if x_first.ncols() != smooth_term.coeff_range.len() {
6954 return Ok(None);
6955 }
6956 (x_first, x_second)
6957 };
6958 let s_psi_components = std::mem::take(&mut aniso_result.penalties_first[a]);
6959 let s_psi_psi_components = std::mem::take(&mut aniso_result.penalties_second_diag[a]);
6960 let cross_designs = if implicit_op_arc.is_some() {
6966 let mut cd = Vec::with_capacity(d - 1);
6967 for b in 0..d {
6968 if b == a {
6969 continue;
6970 }
6971 cd.push((b, Array2::<f64>::zeros((0, 0))));
6972 }
6973 cd
6974 } else if !aniso_result.design_second_cross.is_empty() {
6975 let mut cd = Vec::new();
6976 for (cross_idx, &(pa, pb)) in aniso_result.design_second_cross_pairs.iter().enumerate()
6977 {
6978 if pa == a {
6979 cd.push((pb, aniso_result.design_second_cross[cross_idx].clone()));
6980 } else if pb == a {
6981 cd.push((pa, aniso_result.design_second_cross[cross_idx].clone()));
6982 }
6983 }
6984 cd
6985 } else {
6986 Vec::new()
6987 };
6988 let cross_penalty_provider = if d > 1 {
6989 let penalties_cross_provider = penalties_cross_provider.clone();
6990 Some(std::sync::Arc::new(
6991 move |b_axis: usize| -> Result<Vec<Array2<f64>>, EstimationError> {
6992 if b_axis == a {
6993 return Ok(Vec::new());
6994 }
6995 let (axis_lo, axis_hi) = if a < b_axis { (a, b_axis) } else { (b_axis, a) };
6996 if let Some(provider) = penalties_cross_provider.as_ref() {
6997 provider
6998 .evaluate(axis_lo, axis_hi)
6999 .map_err(EstimationError::from)
7000 } else {
7001 Ok(Vec::new())
7005 }
7006 },
7007 )
7008 as std::sync::Arc<
7009 dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError>
7010 + Send
7011 + Sync
7012 + 'static,
7013 >)
7014 } else {
7015 None
7016 };
7017
7018 entries.push(SpatialPsiDerivative {
7019 penalty_index: penalty_indices[0],
7020 penalty_indices: penalty_indices.clone(),
7021 global_range: global_range.clone(),
7022 total_p: p_total,
7023 x_psi_local,
7024 s_psi_components_local: s_psi_components,
7025 x_psi_psi_local,
7026 s_psi_psi_components_local: s_psi_psi_components,
7027 aniso_group_id: Some(aniso_group_id),
7028 aniso_cross_designs: if cross_designs.is_empty() {
7029 None
7030 } else {
7031 Some(cross_designs)
7032 },
7033 aniso_cross_penalty_provider: cross_penalty_provider,
7034 implicit_operator: implicit_op_arc.clone(),
7035 implicit_axis: a,
7036 });
7037 }
7038 Ok(Some(entries))
7039}
7040
#[cfg(test)]
mod glm_eta_observation_fd_tests {
    use super::*;

    /// Evaluates the standard-family observation state for a single
    /// unit-weight observation `y` at linear predictor `eta`.
    fn one_obs(spec: &LikelihoodSpec, y: f64, eta: f64) -> StandardFamilyObservationState {
        let response = Array1::from_vec(vec![y]);
        let unit_weight = Array1::from_vec(vec![1.0]);
        let predictor = Array1::from_vec(vec![eta]);
        evaluate_standard_familyobservations(
            spec.clone(),
            None,
            None,
            None,
            &response,
            &unit_weight,
            &predictor,
        )
        .expect("standard family observation state assembles")
    }

    /// Compares the analytic score, negative Hessian and its eta-derivative
    /// against central finite differences taken around `eta`.
    fn check_fd(label: &str, spec: &LikelihoodSpec, y: f64, eta: f64) {
        let step = 1e-5;
        let centre = one_obs(spec, y, eta);
        let plus = one_obs(spec, y, eta + step);
        let minus = one_obs(spec, y, eta - step);

        // d(log-likelihood)/d(eta) versus the analytic score.
        let score = centre.score[0];
        let score_fd = (plus.log_likelihood - minus.log_likelihood) / (2.0 * step);
        assert!(
            (score - score_fd).abs() <= 1e-4 * (1.0 + score.abs()),
            "{label}: score {score} vs FD {score_fd}"
        );

        // -d(score)/d(eta) versus the analytic negative Hessian.
        let neghess = centre.neghessian_eta[0];
        let neghess_fd = -(plus.score[0] - minus.score[0]) / (2.0 * step);
        assert!(
            (neghess - neghess_fd).abs() <= 1e-3 * (1.0 + neghess.abs()),
            "{label}: neghessian_eta {neghess} vs FD {neghess_fd}"
        );

        // d(negative Hessian)/d(eta) versus its analytic derivative.
        let nhd = centre.neghessian_eta_derivative[0];
        let nhd_fd = (plus.neghessian_eta[0] - minus.neghessian_eta[0]) / (2.0 * step);
        assert!(
            (nhd - nhd_fd).abs() <= 1e-2 * (1.0 + nhd.abs()),
            "{label}: neghessian_eta_derivative {nhd} vs FD {nhd_fd}"
        );
    }

    #[test]
    fn poisson_gamma_nb_tweedie_arms_match_finite_differences_1615_1616() {
        // Every family under test is paired with the standard log link.
        let with_log_link = |response: ResponseFamily| LikelihoodSpec {
            response,
            link: InverseLink::Standard(StandardLink::Log),
        };

        let poisson = with_log_link(ResponseFamily::Poisson);
        check_fd("poisson y=3", &poisson, 3.0, 0.4);
        check_fd("poisson y=0", &poisson, 0.0, -0.2);

        let gamma = with_log_link(ResponseFamily::Gamma);
        check_fd("gamma y=2.5", &gamma, 2.5, 0.3);
        check_fd("gamma y=0.7", &gamma, 0.7, -0.1);

        let nb = with_log_link(ResponseFamily::NegativeBinomial {
            theta: 1.5,
            theta_fixed: true,
        });
        check_fd("negbin y=4", &nb, 4.0, 0.5);
        check_fd("negbin y=0", &nb, 0.0, -0.3);

        let tweedie = with_log_link(ResponseFamily::Tweedie { p: 1.5 });
        check_fd("tweedie y=2", &tweedie, 2.0, 0.25);
        check_fd("tweedie y=0.5", &tweedie, 0.5, -0.15);
    }
}