1pub fn build_term_collection_designs_joint(
9 data: ArrayView2<'_, f64>,
10 specs: &[TermCollectionSpec],
11) -> Result<Vec<TermCollectionDesign>, BasisError> {
12 for spec in specs {
13 validate_term_collection_finite_inputs(data, spec)?;
14 }
15 let smooth_blocks = specs
16 .iter()
17 .map(|spec| spec.smooth_terms.clone())
18 .collect::<Vec<_>>();
19 let planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &smooth_blocks)?;
20 let mut out = Vec::with_capacity(specs.len());
21 for (spec, planned_terms) in specs.iter().zip(planned_blocks.into_iter()) {
22 let mut planned_spec = spec.clone();
23 planned_spec.smooth_terms = planned_terms;
24 out.push(build_term_collection_design_inner(data, &planned_spec)?);
25 }
26 Ok(out)
27}
28
29pub fn build_term_collection_designs_and_freeze_joint(
30 data: ArrayView2<'_, f64>,
31 specs: &[TermCollectionSpec],
32) -> Result<(Vec<TermCollectionDesign>, Vec<TermCollectionSpec>), EstimationError> {
33 let designs = build_term_collection_designs_joint(data, specs)?;
34 let mut resolved_specs = Vec::with_capacity(specs.len());
35 for (spec, design) in specs.iter().zip(designs.iter()) {
36 resolved_specs.push(freeze_term_collection_from_design(spec, design)?);
37 }
38 Ok((designs, resolved_specs))
39}
40
41pub fn fit_term_collection_forspec(
42 data: ArrayView2<'_, f64>,
43 y: ArrayView1<'_, f64>,
44 weights: ArrayView1<'_, f64>,
45 offset: ArrayView1<'_, f64>,
46 spec: &TermCollectionSpec,
47 family: LikelihoodSpec,
48 options: &FitOptions,
49) -> Result<FittedTermCollection, EstimationError> {
50 fit_term_collection_forspecwith_heuristic_lambdas(
51 data, y, weights, offset, spec, None, family, options,
52 )
53}
54
55pub fn fit_term_collection_with_coefficient_groups(
56 data: ArrayView2<'_, f64>,
57 y: ArrayView1<'_, f64>,
58 weights: ArrayView1<'_, f64>,
59 offset: ArrayView1<'_, f64>,
60 spec: &TermCollectionSpec,
61 groups: &[CoefficientGroupSpec],
62 family: LikelihoodSpec,
63 options: &FitOptions,
64) -> Result<FittedTermCollection, EstimationError> {
65 if groups.is_empty() {
66 return fit_term_collection_forspec(data, y, weights, offset, spec, family, options);
67 }
68 let design = build_term_collection_design(data, spec)?;
69 let base_fit_opts = adaptive_fit_options_base(options, &design);
70 let realized = design
71 .realize_coefficient_groups(groups, &base_fit_opts.rho_prior)
72 .map_err(EstimationError::BasisError)?;
73 let mut grouped_options = base_fit_opts.clone();
74 grouped_options.rho_prior = realized.rho_prior;
75 let fitted = FittedTermCollection {
76 fit: gam_solve::estimate::fit_gam_with_penalty_specs(
77 design.design.clone(),
78 y,
79 weights,
80 offset,
81 realized.penalty_specs,
82 realized.nullspace_dims,
83 family.clone(),
84 &grouped_options,
85 )?,
86 design,
87 adaptive_diagnostics: None,
88 };
89 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
90 Ok(fitted)
91}
92
93pub fn fit_term_collection_with_penalty_block_gamma_prior_callback<F>(
94 data: ArrayView2<'_, f64>,
95 y: ArrayView1<'_, f64>,
96 weights: ArrayView1<'_, f64>,
97 offset: ArrayView1<'_, f64>,
98 spec: &TermCollectionSpec,
99 callback: F,
100 family: LikelihoodSpec,
101 options: &FitOptions,
102) -> Result<FittedTermCollection, EstimationError>
103where
104 F: FnMut(&PenaltyBlockGammaPriorMetadata<'_>) -> Option<(f64, f64)>,
105{
106 let design = build_term_collection_design(data, spec)?;
107 let mut fit_opts = adaptive_fit_options_base(options, &design);
108 fit_opts.rho_prior = realize_penalty_block_gamma_priors(&design, callback)
109 .map_err(EstimationError::BasisError)?;
110 let fitted = FittedTermCollection {
111 fit: fit_gamwith_heuristic_lambdas(
112 design.design.clone(),
113 y,
114 weights,
115 offset,
116 &design.penalties,
117 None,
118 family.clone(),
119 &fit_opts,
120 )?,
121 design,
122 adaptive_diagnostics: None,
123 };
124 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
125 Ok(fitted)
126}
127
128pub fn fit_term_collection_with_penalty_block_gamma_priors(
129 data: ArrayView2<'_, f64>,
130 y: ArrayView1<'_, f64>,
131 weights: ArrayView1<'_, f64>,
132 offset: ArrayView1<'_, f64>,
133 spec: &TermCollectionSpec,
134 priors: &[(String, f64, f64)],
135 family: LikelihoodSpec,
136 options: &FitOptions,
137) -> Result<FittedTermCollection, EstimationError> {
138 let design = build_term_collection_design(data, spec)?;
139 let mut fit_opts = adaptive_fit_options_base(options, &design);
140 fit_opts.rho_prior = realize_keyed_penalty_block_gamma_priors(&design, priors)
141 .map_err(EstimationError::BasisError)?;
142 let fitted = FittedTermCollection {
143 fit: fit_gamwith_heuristic_lambdas(
144 design.design.clone(),
145 y,
146 weights,
147 offset,
148 &design.penalties,
149 None,
150 family.clone(),
151 &fit_opts,
152 )?,
153 design,
154 adaptive_diagnostics: None,
155 };
156 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
157 Ok(fitted)
158}
159
160pub fn fit_term_collection_with_coefficient_groups_and_penalty_block_gamma_priors(
161 data: ArrayView2<'_, f64>,
162 y: ArrayView1<'_, f64>,
163 weights: ArrayView1<'_, f64>,
164 offset: ArrayView1<'_, f64>,
165 spec: &TermCollectionSpec,
166 groups: &[CoefficientGroupSpec],
167 priors: &[(String, f64, f64)],
168 family: LikelihoodSpec,
169 options: &FitOptions,
170) -> Result<FittedTermCollection, EstimationError> {
171 if groups.is_empty() {
172 return fit_term_collection_with_penalty_block_gamma_priors(
173 data, y, weights, offset, spec, priors, family, options,
174 );
175 }
176 if priors.is_empty() {
177 return fit_term_collection_with_coefficient_groups(
178 data, y, weights, offset, spec, groups, family, options,
179 );
180 }
181
182 let design = build_term_collection_design(data, spec)?;
183 let base_fit_opts = adaptive_fit_options_base(options, &design);
184 let base_rho_prior = realize_keyed_penalty_block_gamma_priors(&design, priors)
185 .map_err(EstimationError::BasisError)?;
186 let realized = design
187 .realize_coefficient_groups(groups, &base_rho_prior)
188 .map_err(EstimationError::BasisError)?;
189 let mut grouped_options = base_fit_opts.clone();
190 grouped_options.rho_prior = realized.rho_prior;
191 let fitted = FittedTermCollection {
192 fit: gam_solve::estimate::fit_gam_with_penalty_specs(
193 design.design.clone(),
194 y,
195 weights,
196 offset,
197 realized.penalty_specs,
198 realized.nullspace_dims,
199 family.clone(),
200 &grouped_options,
201 )?,
202 design,
203 adaptive_diagnostics: None,
204 };
205 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
206 Ok(fitted)
207}
208
209fn fit_term_collection_forspecwith_heuristic_lambdas(
210 data: ArrayView2<'_, f64>,
211 y: ArrayView1<'_, f64>,
212 weights: ArrayView1<'_, f64>,
213 offset: ArrayView1<'_, f64>,
214 spec: &TermCollectionSpec,
215 heuristic_lambdas: Option<&[f64]>,
216 family: LikelihoodSpec,
217 options: &FitOptions,
218) -> Result<FittedTermCollection, EstimationError> {
219 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
220 let resolved_spec;
221 let design_spec = if adaptive_opts.enabled {
222 resolved_spec = ensure_matern_adaptive_center_resolution(spec, data.nrows());
223 &resolved_spec
224 } else {
225 spec
226 };
227 let base_design = build_term_collection_design(data, design_spec)?;
228 fit_term_collection_on_realized_design(
229 y,
230 weights,
231 offset,
232 design_spec,
233 &base_design,
234 heuristic_lambdas,
235 family,
236 options,
237 )
238}
239
240fn ensure_matern_adaptive_center_resolution(
241 spec: &TermCollectionSpec,
242 n_rows: usize,
243) -> TermCollectionSpec {
244 let mut out = spec.clone();
245 for term in &mut out.smooth_terms {
246 let gam_terms::smooth::SmoothBasisSpec::Matern {
247 feature_cols,
248 spec: matern,
249 ..
250 } = &mut term.basis
251 else {
252 continue;
253 };
254 if let gam_terms::basis::CenterStrategy::FarthestPoint { num_centers } =
255 &mut matern.center_strategy
256 {
257 let min_centers = (4 * feature_cols.len()).min(n_rows).max(*num_centers);
270 *num_centers = min_centers;
271 }
272 }
273 out
274}
275
276fn has_bounded_linear_terms(spec: &TermCollectionSpec) -> bool {
277 spec.linear_terms.iter().any(|term| {
278 matches!(
279 term.coefficient_geometry,
280 LinearCoefficientGeometry::Bounded { .. }
281 )
282 })
283}
284
285fn fit_term_collection_on_realized_design(
286 y: ArrayView1<'_, f64>,
287 weights: ArrayView1<'_, f64>,
288 offset: ArrayView1<'_, f64>,
289 spec: &TermCollectionSpec,
290 design: &TermCollectionDesign,
291 heuristic_lambdas: Option<&[f64]>,
292 family: LikelihoodSpec,
293 options: &FitOptions,
294) -> Result<FittedTermCollection, EstimationError> {
295 if has_bounded_linear_terms(spec) {
296 return fit_bounded_term_collection_with_design(
297 y,
298 weights,
299 offset,
300 spec,
301 design,
302 heuristic_lambdas,
303 family,
304 options,
305 );
306 }
307 let mut base_fit_opts = adaptive_fit_options_base(options, design);
308 base_fit_opts.rho_prior = relax_smoothing_rho_prior(options, design);
315 let fitted = FittedTermCollection {
316 fit: fit_gamwith_heuristic_lambdas(
317 design.design.clone(),
318 y,
319 weights,
320 offset,
321 &design.penalties,
322 heuristic_lambdas,
323 family.clone(),
324 &base_fit_opts,
325 )?,
326 design: design.clone(),
327 adaptive_diagnostics: None,
328 };
329 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
330
331 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
332 if !adaptive_opts.enabled {
333 return Ok(fitted);
334 }
335 let runtime_caches = extract_spatial_operator_runtime_caches(spec, &fitted.design)?;
336 if runtime_caches.is_empty() {
337 return Ok(fitted);
338 }
339 fit_term_collectionwith_exact_spatial_adaptive_regularization(
346 fitted,
347 y,
348 weights,
349 offset,
350 family,
351 options,
352 &runtime_caches,
353 )
354}
355
356#[derive(Clone)]
357struct SpatialOperatorRuntimeCache {
358 termname: String,
359 feature_cols: Vec<usize>,
360 coeff_global_range: Range<usize>,
361 mass_penalty_global_idx: usize,
362 tension_penalty_global_idx: usize,
363 stiffness_penalty_global_idx: usize,
364 d0: Array2<f64>,
365 d1: Array2<f64>,
366 d2: Array2<f64>,
367 collocation_points: Array2<f64>,
368 dimension: usize,
369}
370
371#[derive(Clone)]
372struct SpatialAdaptiveWeights {
373 inv_magweight: Array1<f64>,
374 invgradweight: Array1<f64>,
375 inv_lapweight: Array1<f64>,
376}
377
378#[derive(Clone)]
379struct CharbonnierScalarBlockState {
380 signal: Array1<f64>,
381 radius: Array1<f64>,
382 epsilon: f64,
383}
384
385impl CharbonnierScalarBlockState {
386 fn from_signal(signal: Array1<f64>, epsilon: f64) -> Self {
387 let eps = epsilon.max(1e-12);
388 let radius = signal.mapv(|t| (t * t + eps * eps).sqrt());
389 Self {
390 signal,
391 radius,
392 epsilon: eps,
393 }
394 }
395
396 fn absolute_signal(&self) -> Array1<f64> {
397 self.signal.mapv(f64::abs)
398 }
399
400 fn penalty_value(&self) -> f64 {
401 self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
402 }
403
404 fn betagradient_coeff(&self) -> Array1<f64> {
405 Array1::from_iter(
406 self.signal
407 .iter()
408 .zip(self.radius.iter())
409 .map(|(t, r)| t / r),
410 )
411 }
412
413 fn betahessian_diag(&self) -> Array1<f64> {
414 let eps2 = self.epsilon * self.epsilon;
415 self.radius.mapv(|r| eps2 / r.powi(3))
416 }
417
418 fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
419 let epsilon = self.epsilon;
420 let eps2 = epsilon * epsilon;
421 self.radius.mapv(|r| eps2 / r - epsilon)
422 }
423
424 fn log_epsilon_betagradient_coeff(&self) -> Array1<f64> {
425 let eps2 = self.epsilon * self.epsilon;
426 Array1::from_iter(
427 self.signal
428 .iter()
429 .zip(self.radius.iter())
430 .map(|(t, r)| -eps2 * t / r.powi(3)),
431 )
432 }
433
434 fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
435 let epsilon = self.epsilon;
436 let eps2 = epsilon * epsilon;
437 let eps4 = eps2 * eps2;
438 self.radius
439 .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
440 }
441
442 fn surrogateweights_posterior_snr(
443 &self,
444 variance: &Array1<f64>,
445 weight_floor: f64,
446 weight_ceiling: f64,
447 ) -> (Array1<f64>, Array1<f64>) {
448 let eps2 = self.epsilon * self.epsilon;
506 let weight = Array1::from_iter(self.signal.iter().zip(variance.iter()).map(|(&t, &v)| {
507 let credible2 = (t * t - v.max(0.0)).max(0.0);
508 let r = (credible2 + eps2).sqrt();
509 (1.0 / r).clamp(weight_floor, weight_ceiling)
510 }));
511 let invweight = weight.mapv(|u| 1.0 / u);
512 (weight, invweight)
513 }
514
515 fn directionalhessian_diag(&self, direction_signal: &Array1<f64>) -> Array1<f64> {
516 let eps2 = self.epsilon * self.epsilon;
531 Array1::from_iter(
532 self.signal
533 .iter()
534 .zip(direction_signal.iter())
535 .zip(self.radius.iter())
536 .map(|((t, q), r)| -3.0 * eps2 * t * q / r.powi(5)),
537 )
538 }
539
540 fn second_directionalhessian_diag(
547 &self,
548 direction1_signal: &Array1<f64>,
549 direction2_signal: &Array1<f64>,
550 ) -> Array1<f64> {
551 let eps2 = self.epsilon * self.epsilon;
552 Array1::from_iter(
553 self.signal
554 .iter()
555 .zip(direction1_signal.iter())
556 .zip(direction2_signal.iter())
557 .zip(self.radius.iter())
558 .map(|(((t, q1), q2), r)| {
559 let r2 = r * r;
560 let psi4 = -3.0 * eps2 / r.powi(5) + 15.0 * eps2 * t * t / (r.powi(5) * r2);
561 psi4 * q1 * q2
562 }),
563 )
564 }
565
566 fn log_epsilon_betahessian_diag(&self) -> Array1<f64> {
567 let eps2 = self.epsilon * self.epsilon;
568 let eps4 = eps2 * eps2;
569 Array1::from_iter(
570 self.signal
571 .iter()
572 .zip(self.radius.iter())
573 .map(|(_, r)| 2.0 * eps2 / r.powi(3) - 3.0 * eps4 / r.powi(5)),
574 )
575 }
576
577 fn log_epsilon_beta_mixed_second_coeff(&self) -> Array1<f64> {
578 let eps2 = self.epsilon * self.epsilon;
579 Array1::from_iter(
580 self.signal
581 .iter()
582 .zip(self.radius.iter())
583 .map(|(t, r)| eps2 * t * (eps2 - 2.0 * t * t) / r.powi(5)),
584 )
585 }
586
587 fn log_epsilon_betahessian_second_diag(&self) -> Array1<f64> {
588 let eps2 = self.epsilon * self.epsilon;
589 let eps4 = eps2 * eps2;
590 let eps6 = eps4 * eps2;
591 Array1::from_iter(
592 self.radius.iter().map(|r| {
593 4.0 * eps2 / r.powi(3) - 18.0 * eps4 / r.powi(5) + 15.0 * eps6 / r.powi(7)
594 }),
595 )
596 }
597
598 fn log_epsilon_betahessian_directional_diag(
599 &self,
600 direction_signal: &Array1<f64>,
601 ) -> Array1<f64> {
602 let eps2 = self.epsilon * self.epsilon;
603 let eps4 = eps2 * eps2;
604 Array1::from_iter(
605 self.signal
606 .iter()
607 .zip(direction_signal.iter())
608 .zip(self.radius.iter())
609 .map(|((t, q), r)| (-6.0 * eps2 * t / r.powi(5) + 15.0 * eps4 * t / r.powi(7)) * q),
610 )
611 }
612}
613
614#[derive(Clone)]
615struct CharbonnierGroupedBlockState {
616 norm: Array1<f64>,
617 radius: Array1<f64>,
618 signal_blocks: Array2<f64>,
619 epsilon: f64,
620}
621
622impl CharbonnierGroupedBlockState {
623 fn from_signal_blocks(signal_blocks: Array2<f64>, epsilon: f64) -> Self {
624 let eps = epsilon.max(1e-12);
625 let norm = Array1::from_iter(
626 signal_blocks
627 .rows()
628 .into_iter()
629 .map(|row| row.iter().map(|v| v * v).sum::<f64>().sqrt()),
630 );
631 let radius = norm.mapv(|g| (g * g + eps * eps).sqrt());
632 Self {
633 norm,
634 radius,
635 signal_blocks,
636 epsilon: eps,
637 }
638 }
639
640 fn penalty_value(&self) -> f64 {
641 self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
642 }
643
644 fn norm_signal(&self) -> Array1<f64> {
645 self.norm.clone()
646 }
647
648 fn betagradient_blocks(&self) -> Array2<f64> {
649 let mut out = self.signal_blocks.clone();
650 for (k, mut row) in out.rows_mut().into_iter().enumerate() {
651 let scale = 1.0 / self.radius[k];
652 row.mapv_inplace(|v| v * scale);
653 }
654 out
655 }
656
657 fn betahessian_blocks(&self) -> Vec<Array2<f64>> {
658 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
659 for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
660 let dim = row.len();
661 let mut block = Array2::<f64>::eye(dim);
662 block.mapv_inplace(|v| v / self.radius[k]);
663 for i in 0..dim {
664 for j in 0..dim {
665 block[[i, j]] -= row[i] * row[j] / self.radius[k].powi(3);
666 }
667 }
668 out.push(block);
669 }
670 out
671 }
672
673 fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
674 let epsilon = self.epsilon;
675 let eps2 = epsilon * epsilon;
676 self.radius.mapv(|r| eps2 / r - epsilon)
677 }
678
679 fn log_epsilon_betagradient_blocks(&self) -> Array2<f64> {
680 let mut out = self.signal_blocks.clone();
681 let eps2 = self.epsilon * self.epsilon;
682 for (k, mut row) in out.rows_mut().into_iter().enumerate() {
683 let scale = -eps2 / self.radius[k].powi(3);
684 row.mapv_inplace(|v| v * scale);
685 }
686 out
687 }
688
689 fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
690 let epsilon = self.epsilon;
691 let eps2 = epsilon * epsilon;
692 let eps4 = eps2 * eps2;
693 self.radius
694 .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
695 }
696
697 fn surrogateweights_posterior_snr(
698 &self,
699 variance: &Array1<f64>,
700 weight_floor: f64,
701 weight_ceiling: f64,
702 ) -> (Array1<f64>, Array1<f64>) {
703 let eps2 = self.epsilon * self.epsilon;
745 let weight = Array1::from_iter(self.norm.iter().zip(variance.iter()).map(|(&g, &v)| {
746 let credible2 = (g * g - v.max(0.0)).max(0.0);
747 let r = (credible2 + eps2).sqrt();
748 (1.0 / r).clamp(weight_floor, weight_ceiling)
749 }));
750 let invweight = weight.mapv(|u| 1.0 / u);
751 (weight, invweight)
752 }
753
754 fn directionalhessian_blocks(&self, direction_blocks: &Array2<f64>) -> Vec<Array2<f64>> {
755 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
780 for (k, (v, q)) in self
781 .signal_blocks
782 .rows()
783 .into_iter()
784 .zip(direction_blocks.rows().into_iter())
785 .enumerate()
786 {
787 let dim = v.len();
788 let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
789 let r3 = self.radius[k].powi(3);
790 let r5 = self.radius[k].powi(5);
791 let mut block = Array2::<f64>::eye(dim);
792 block.mapv_inplace(|x| -dot * x / r3);
793 for i in 0..dim {
794 for j in 0..dim {
795 block[[i, j]] -= (q[i] * v[j] + v[i] * q[j]) / r3;
796 block[[i, j]] += 3.0 * dot * v[i] * v[j] / r5;
797 }
798 }
799 out.push(block);
800 }
801 out
802 }
803
804 fn second_directionalhessian_blocks(
821 &self,
822 direction1_blocks: &Array2<f64>,
823 direction2_blocks: &Array2<f64>,
824 ) -> Vec<Array2<f64>> {
825 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
826 for ((k, v), (a, b)) in self.signal_blocks.rows().into_iter().enumerate().zip(
827 direction1_blocks
828 .rows()
829 .into_iter()
830 .zip(direction2_blocks.rows().into_iter()),
831 ) {
832 let dim = v.len();
833 let dot = |x: ndarray::ArrayView1<'_, f64>, y: ndarray::ArrayView1<'_, f64>| {
834 x.iter().zip(y.iter()).map(|(p, q)| p * q).sum::<f64>()
835 };
836 let sa = dot(v, a);
837 let sb = dot(v, b);
838 let ab = dot(a, b);
839 let r = self.radius[k];
840 let r3 = r.powi(3);
841 let r5 = r.powi(5);
842 let r7 = r5 * r * r;
843 let diag = -ab / r3 + 3.0 * sa * sb / r5;
844 let mut block = Array2::<f64>::eye(dim);
845 block.mapv_inplace(|x| diag * x);
846 for i in 0..dim {
847 for j in 0..dim {
848 block[[i, j]] -= (a[i] * b[j] + b[i] * a[j]) / r3;
849 block[[i, j]] += 3.0 * sb * (a[i] * v[j] + v[i] * a[j]) / r5;
850 block[[i, j]] += 3.0 * ab * v[i] * v[j] / r5;
851 block[[i, j]] += 3.0 * sa * (b[i] * v[j] + v[i] * b[j]) / r5;
852 block[[i, j]] -= 15.0 * sa * sb * v[i] * v[j] / r7;
853 }
854 }
855 out.push(block);
856 }
857 out
858 }
859
860 fn log_epsilon_betahessian_blocks(&self) -> Vec<Array2<f64>> {
861 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
862 for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
863 let dim = row.len();
864 let r3 = self.radius[k].powi(3);
865 let r5 = self.radius[k].powi(5);
866 let mut block = Array2::<f64>::eye(dim);
867 let eps2 = self.epsilon * self.epsilon;
868 block.mapv_inplace(|v| -eps2 * v / r3);
869 for i in 0..dim {
870 for j in 0..dim {
871 block[[i, j]] += 3.0 * eps2 * row[i] * row[j] / r5;
872 }
873 }
874 out.push(block);
875 }
876 out
877 }
878
879 fn log_epsilon_beta_mixed_second_blocks(&self) -> Array2<f64> {
880 let mut out = self.signal_blocks.clone();
881 let eps2 = self.epsilon * self.epsilon;
882 for (k, mut row) in out.rows_mut().into_iter().enumerate() {
883 let norm2 = self.norm[k] * self.norm[k];
884 let scale = eps2 * (eps2 - 2.0 * norm2) / self.radius[k].powi(5);
885 row.mapv_inplace(|v| v * scale);
886 }
887 out
888 }
889
890 fn log_epsilon_betahessian_second_blocks(&self) -> Vec<Array2<f64>> {
891 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
892 let eps2 = self.epsilon * self.epsilon;
893 for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
894 let dim = row.len();
895 let norm2 = self.norm[k] * self.norm[k];
896 let r5 = self.radius[k].powi(5);
897 let r7 = self.radius[k].powi(7);
898 let mut block = Array2::<f64>::eye(dim);
899 block.mapv_inplace(|v| eps2 * (eps2 - 2.0 * norm2) * v / r5);
900 for i in 0..dim {
901 for j in 0..dim {
902 block[[i, j]] += 3.0 * eps2 * (2.0 * norm2 - 3.0 * eps2) * row[i] * row[j] / r7;
903 }
904 }
905 out.push(block);
906 }
907 out
908 }
909
910 fn log_epsilon_betahessian_directional_blocks(
911 &self,
912 direction_blocks: &Array2<f64>,
913 ) -> Vec<Array2<f64>> {
914 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
915 let eps2 = self.epsilon * self.epsilon;
916 for (k, (v, q)) in self
917 .signal_blocks
918 .rows()
919 .into_iter()
920 .zip(direction_blocks.rows().into_iter())
921 .enumerate()
922 {
923 let dim = v.len();
924 let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
925 let r5 = self.radius[k].powi(5);
926 let r7 = self.radius[k].powi(7);
927 let mut block = Array2::<f64>::eye(dim);
928 block.mapv_inplace(|x| 3.0 * eps2 * dot * x / r5);
929 for i in 0..dim {
930 for j in 0..dim {
931 block[[i, j]] += 3.0 * eps2 * (q[i] * v[j] + v[i] * q[j]) / r5;
932 block[[i, j]] -= 15.0 * eps2 * dot * v[i] * v[j] / r7;
933 }
934 }
935 out.push(block);
936 }
937 out
938 }
939}
940
941fn scalar_operatorgradient(operator: &Array2<f64>, coeff: &Array1<f64>) -> Array1<f64> {
942 operator.t().dot(coeff)
943}
944
945fn scalar_operatorhessian(operator: &Array2<f64>, diag: &Array1<f64>) -> Array2<f64> {
946 let mut weighted = operator.clone();
947 for (k, &w) in diag.iter().enumerate() {
948 weighted.row_mut(k).mapv_inplace(|v| v * w);
949 }
950 let gram = operator.t().dot(&weighted);
951 (&gram + &gram.t().to_owned()) * 0.5
952}
953
954fn grouped_operatorgradient(
955 d1: &Array2<f64>,
956 dimension: usize,
957 blocks: &Array2<f64>,
958) -> Result<Array1<f64>, EstimationError> {
959 if blocks.ncols() != dimension {
960 crate::bail_invalid_estim!(
961 "grouped gradient block dimension mismatch: got {}, expected {dimension}",
962 blocks.ncols()
963 );
964 }
965 if d1.nrows() != blocks.nrows() * dimension {
966 crate::bail_invalid_estim!(
967 "grouped gradient row mismatch: D1 has {} rows, blocks imply {}",
968 d1.nrows(),
969 blocks.nrows() * dimension
970 );
971 }
972 let mut out = Array1::<f64>::zeros(d1.ncols());
973 for k in 0..blocks.nrows() {
974 let gk = d1
975 .slice(s![k * dimension..(k + 1) * dimension, ..])
976 .to_owned();
977 out += &gk.t().dot(&blocks.row(k));
978 }
979 Ok(out)
980}
981
982fn grouped_operatorhessian(
983 d1: &Array2<f64>,
984 dimension: usize,
985 blocks: &[Array2<f64>],
986) -> Result<Array2<f64>, EstimationError> {
987 if d1.nrows() != blocks.len() * dimension {
988 crate::bail_invalid_estim!(
989 "grouped Hessian row mismatch: D1 has {} rows, blocks imply {}",
990 d1.nrows(),
991 blocks.len() * dimension
992 );
993 }
994 let p = d1.ncols();
995 let mut out = Array2::<f64>::zeros((p, p));
996 for (k, block) in blocks.iter().enumerate() {
997 if block.nrows() != dimension || block.ncols() != dimension {
998 crate::bail_invalid_estim!(
999 "grouped Hessian block {k} has shape {}x{}, expected {}x{}",
1000 block.nrows(),
1001 block.ncols(),
1002 dimension,
1003 dimension
1004 );
1005 }
1006 let gk = d1
1007 .slice(s![k * dimension..(k + 1) * dimension, ..])
1008 .to_owned();
1009 out += &gk.t().dot(&block.dot(&gk));
1010 }
1011 Ok((&out + &out.t().to_owned()) * 0.5)
1012}
1013
1014#[derive(Clone)]
1015struct SpatialPenaltyExactState {
1016 magnitude: CharbonnierScalarBlockState,
1017 gradient: CharbonnierGroupedBlockState,
1018 curvature: CharbonnierGroupedBlockState,
1019}
1020
1021fn collocationgradient_blocks(
1022 gradrows: &Array1<f64>,
1023 dimension: usize,
1024) -> Result<Array2<f64>, EstimationError> {
1025 if dimension == 0 || !gradrows.len().is_multiple_of(dimension) {
1026 crate::bail_invalid_estim!(
1027 "invalid collocation gradient layout: rows={}, dimension={dimension}",
1028 gradrows.len()
1029 );
1030 }
1031 let p = gradrows.len() / dimension;
1032 let mut out = Array2::<f64>::zeros((p, dimension));
1033 for k in 0..p {
1034 for axis in 0..dimension {
1035 out[[k, axis]] = gradrows[k * dimension + axis];
1036 }
1037 }
1038 Ok(out)
1039}
1040
1041fn collocationhessian_blocks(
1042 hessianrows: &Array1<f64>,
1043 dimension: usize,
1044) -> Result<Array2<f64>, EstimationError> {
1045 let block_dim = dimension.checked_mul(dimension).ok_or_else(|| {
1046 EstimationError::InvalidInput("invalid collocation Hessian dimension overflow".to_string())
1047 })?;
1048 if block_dim == 0 || !hessianrows.len().is_multiple_of(block_dim) {
1049 crate::bail_invalid_estim!(
1050 "invalid collocation Hessian layout: rows={}, dimension={dimension}",
1051 hessianrows.len()
1052 );
1053 }
1054 let p = hessianrows.len() / block_dim;
1055 let mut out = Array2::<f64>::zeros((p, block_dim));
1056 for k in 0..p {
1057 for idx in 0..block_dim {
1058 out[[k, idx]] = hessianrows[k * block_dim + idx];
1059 }
1060 }
1061 Ok(out)
1062}
1063
1064impl SpatialPenaltyExactState {
1065 fn from_beta_local(
1066 beta_local: ArrayView1<'_, f64>,
1067 cache: &SpatialOperatorRuntimeCache,
1068 epsilons: [f64; 3],
1069 ) -> Result<Self, EstimationError> {
1070 let gradientrows = cache.d1.dot(&beta_local);
1100 let hessianrows = cache.d2.dot(&beta_local);
1101 Ok(Self {
1102 magnitude: CharbonnierScalarBlockState::from_signal(
1103 cache.d0.dot(&beta_local),
1104 epsilons[0],
1105 ),
1106 gradient: CharbonnierGroupedBlockState::from_signal_blocks(
1107 collocationgradient_blocks(&gradientrows, cache.dimension)?,
1108 epsilons[1],
1109 ),
1110 curvature: CharbonnierGroupedBlockState::from_signal_blocks(
1111 collocationhessian_blocks(&hessianrows, cache.dimension)?,
1112 epsilons[2],
1113 ),
1114 })
1115 }
1116
1117 fn absolute_collocation_magnitudes(&self) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
1118 (
1119 self.magnitude.absolute_signal(),
1120 self.gradient.norm_signal(),
1121 self.curvature.norm_signal(),
1122 )
1123 }
1124}
1125
1126fn robust_epsilon_from_samples(values: &[f64], min_epsilon_cfg: f64) -> f64 {
1127 if values.is_empty() {
1128 return min_epsilon_cfg.max(1e-12);
1129 }
1130 let mut clean = values
1131 .iter()
1132 .copied()
1133 .filter(|v| v.is_finite() && *v >= 0.0)
1134 .collect::<Vec<_>>();
1135 if clean.is_empty() {
1136 return min_epsilon_cfg.max(1e-12);
1137 }
1138 clean.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1139
1140 let n = clean.len();
1141 let median = quantile_from_sorted(&clean, 0.5);
1142 let q75 = quantile_from_sorted(&clean, 0.75);
1143 let q95 = quantile_from_sorted(&clean, 0.95);
1144
1145 let mut abs_dev = clean
1146 .iter()
1147 .map(|v| (v - median).abs())
1148 .filter(|v| v.is_finite())
1149 .collect::<Vec<_>>();
1150 abs_dev.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1151 let mad = 1.4826 * quantile_from_sorted(&abs_dev, 0.5);
1152
1153 let mut scale = median.max(mad).max(q75);
1163
1164 let delta = (f64::EPSILON.sqrt() * q95.max(1.0))
1166 .max(min_epsilon_cfg)
1167 .max(1e-12);
1168 let s_min = min_epsilon_cfg.max(1e-12);
1169
1170 if scale <= delta {
1172 let rms = (clean.iter().map(|v| v * v).sum::<f64>() / n as f64).sqrt();
1173 scale = q95.max(rms);
1174 }
1175 if scale <= delta {
1176 scale = s_min;
1177 }
1178
1179 let kappa = 1.0_f64;
1182 (kappa * scale).max(s_min)
1183}
1184
1185fn extract_spatial_operator_runtime_caches(
1186 spec: &TermCollectionSpec,
1187 design: &TermCollectionDesign,
1188) -> Result<Vec<SpatialOperatorRuntimeCache>, EstimationError> {
1189 let smooth_start = design
1190 .design
1191 .ncols()
1192 .saturating_sub(design.smooth.total_smooth_cols());
1193 let mut out = Vec::<SpatialOperatorRuntimeCache>::new();
1194 for (term_idx, (termspec, term_fit)) in spec
1195 .smooth_terms
1196 .iter()
1197 .zip(design.smooth.terms.iter())
1198 .enumerate()
1199 {
1200 let Some(global_base_idx) = smooth_term_penalty_index(spec, design, term_idx) else {
1201 continue;
1202 };
1203 let mut active_local_idx = 0usize;
1204 let mut mass_local_idx = None;
1205 let mut tension_local_idx = None;
1206 let mut stiffness_local_idx = None;
1207 let mut mass_norm = None;
1208 let mut tension_norm = None;
1209 let mut stiffness_norm = None;
1210 for info in &term_fit.penaltyinfo_local {
1211 if !info.active {
1212 continue;
1213 }
1214 match info.source {
1215 PenaltySource::OperatorMass => {
1216 mass_local_idx = Some(active_local_idx);
1217 mass_norm = Some(info.normalization_scale);
1218 }
1219 PenaltySource::OperatorTension => {
1220 tension_local_idx = Some(active_local_idx);
1221 tension_norm = Some(info.normalization_scale);
1222 }
1223 PenaltySource::OperatorStiffness => {
1224 stiffness_local_idx = Some(active_local_idx);
1225 stiffness_norm = Some(info.normalization_scale);
1226 }
1227 _ => {}
1228 }
1229 active_local_idx += 1;
1230 }
1231 let (
1244 Some(mass_local),
1245 Some(tension_local),
1246 Some(stiffness_local),
1247 Some(mass_scale),
1248 Some(tension_scale),
1249 Some(stiffness_scale),
1250 ) = (
1251 mass_local_idx,
1252 tension_local_idx,
1253 stiffness_local_idx,
1254 mass_norm,
1255 tension_norm,
1256 stiffness_norm,
1257 )
1258 else {
1259 continue;
1260 };
1261 let mass_global_idx = global_base_idx + mass_local;
1262 let tension_global_idx = global_base_idx + tension_local;
1263 let stiffness_global_idx = global_base_idx + stiffness_local;
1264
1265 let (feature_cols, mut d0, mut d1, mut d2, collocation_points, dim, center_mass_rows) =
1266 match (&termspec.basis, &term_fit.metadata) {
1267 (
1268 SmoothBasisSpec::Matern { feature_cols, .. },
1269 BasisMetadata::Matern {
1270 centers,
1271 length_scale,
1272 nu,
1273 include_intercept,
1274 identifiability_transform,
1275 aniso_log_scales,
1276 input_scales,
1277 ..
1278 },
1279 ) => {
1280 let collocation_length_scale = match input_scales.as_deref() {
1286 Some(scales) => {
1287 compensate_length_scale_for_standardization(*length_scale, scales)
1288 }
1289 None => *length_scale,
1290 };
1291 let ops = build_matern_collocation_operator_matrices(
1292 centers.view(),
1293 None,
1294 collocation_length_scale,
1295 *nu,
1296 *include_intercept,
1297 identifiability_transform.as_ref().map(|z| z.view()),
1298 aniso_log_scales.as_deref(),
1299 )?;
1300 (
1301 feature_cols.clone(),
1302 ops.d0,
1303 ops.d1,
1304 ops.d2,
1305 ops.collocation_points,
1306 centers.ncols(),
1307 false,
1308 )
1309 }
1310 (
1311 SmoothBasisSpec::Duchon { feature_cols, .. },
1312 BasisMetadata::Duchon {
1313 centers,
1314 length_scale,
1315 power,
1316 nullspace_order,
1317 identifiability_transform,
1318 input_scales,
1319 aniso_log_scales,
1320 operator_collocation_points: Some(collocation_points),
1321 ..
1322 },
1323 ) => {
1324 let collocation_length_scale = match (length_scale, input_scales.as_deref()) {
1325 (Some(ls), Some(scales)) => {
1326 Some(compensate_length_scale_for_standardization(*ls, scales))
1327 }
1328 (Some(ls), None) => Some(*ls),
1329 (None, _) => None,
1330 };
1331 let ops =
1332 gam_terms::basis::build_duchon_collocation_operator_matriceswithworkspace(
1333 centers.view(),
1334 collocation_points.view(),
1335 None,
1336 collocation_length_scale,
1337 *power,
1338 *nullspace_order,
1339 aniso_log_scales.as_deref(),
1340 identifiability_transform.as_ref().map(|z| z.view()),
1341 2,
1342 &mut BasisWorkspace::default(),
1343 )?;
1344 (
1345 feature_cols.clone(),
1346 ops.d0,
1347 ops.d1,
1348 ops.d2,
1349 ops.collocation_points,
1350 centers.ncols(),
1351 true,
1352 )
1353 }
1354 _ => continue,
1355 };
1356 if center_mass_rows && d0.nrows() > 0 && d0.ncols() > 0 {
1357 let means = d0.sum_axis(Axis(0)).mapv(|v| v / d0.nrows() as f64);
1358 for mut row in d0.rows_mut() {
1359 row -= &means;
1360 }
1361 }
1362
1363 let mass_scale = mass_scale.max(1e-12).sqrt();
1381 let tension_scale = tension_scale.max(1e-12).sqrt();
1382 let stiffness_scale = stiffness_scale.max(1e-12).sqrt();
1383 d0.mapv_inplace(|v| v / mass_scale);
1384 d1.mapv_inplace(|v| v / tension_scale);
1385 d2.mapv_inplace(|v| v / stiffness_scale);
1386
1387 let coeff_global_range =
1388 (smooth_start + term_fit.coeff_range.start)..(smooth_start + term_fit.coeff_range.end);
1389 if d0.ncols() != coeff_global_range.len()
1390 || d1.ncols() != coeff_global_range.len()
1391 || d2.ncols() != coeff_global_range.len()
1392 {
1393 crate::bail_invalid_estim!(
1394 "spatial operator dimension mismatch for term '{}': D0 cols={}, D1 cols={}, D2 cols={}, coeffs={}",
1395 term_fit.name,
1396 d0.ncols(),
1397 d1.ncols(),
1398 d2.ncols(),
1399 coeff_global_range.len()
1400 );
1401 }
1402 out.push(SpatialOperatorRuntimeCache {
1403 termname: term_fit.name.clone(),
1404 feature_cols,
1405 coeff_global_range,
1406 mass_penalty_global_idx: mass_global_idx,
1407 tension_penalty_global_idx: tension_global_idx,
1408 stiffness_penalty_global_idx: stiffness_global_idx,
1409 d0,
1410 d1,
1411 d2,
1412 collocation_points,
1413 dimension: dim,
1414 });
1415 }
1416 Ok(out)
1417}
1418
1419fn scalar_operator_response_variance(
1431 operator: &Array2<f64>,
1432 cov_local: &Array2<f64>,
1433) -> Array1<f64> {
1434 Array1::from_iter(operator.rows().into_iter().map(|row| {
1435 let s = cov_local.dot(&row);
1436 row.dot(&s).max(0.0)
1437 }))
1438}
1439
1440fn grouped_operator_response_variance(
1451 operator: &Array2<f64>,
1452 block_dim: usize,
1453 cov_local: &Array2<f64>,
1454) -> Result<Array1<f64>, EstimationError> {
1455 if block_dim == 0 || !operator.nrows().is_multiple_of(block_dim) {
1456 crate::bail_invalid_estim!(
1457 "grouped variance row layout invalid: rows={}, block_dim={block_dim}",
1458 operator.nrows()
1459 );
1460 }
1461 let p = operator.nrows() / block_dim;
1462 let mut out = Array1::<f64>::zeros(p);
1463 for k in 0..p {
1464 let mut acc = 0.0;
1465 for axis in 0..block_dim {
1466 let row = operator.row(k * block_dim + axis);
1467 let s = cov_local.dot(&row);
1468 acc += row.dot(&s);
1469 }
1470 out[k] = acc.max(0.0);
1471 }
1472 Ok(out)
1473}
1474
1475fn compute_spatial_adaptiveweights_for_beta(
1476 beta: &Array1<f64>,
1477 caches: &[SpatialOperatorRuntimeCache],
1478 epsilon_0: f64,
1479 epsilon_g: f64,
1480 epsilon_c: f64,
1481 weight_floor: f64,
1482 weight_ceiling: f64,
1483 beta_covariance: Option<&Array2<f64>>,
1484) -> Result<Vec<SpatialAdaptiveWeights>, EstimationError> {
1485 caches
1517 .iter()
1518 .map(|cache| {
1519 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
1520 let exact = SpatialPenaltyExactState::from_beta_local(
1521 beta_local,
1522 cache,
1523 [epsilon_0, epsilon_g, epsilon_c],
1524 )?;
1525 let cov_local = beta_covariance.map(|cov| {
1526 cov.slice(s![
1527 cache.coeff_global_range.clone(),
1528 cache.coeff_global_range.clone()
1529 ])
1530 .to_owned()
1531 });
1532 let dim = cache.dimension;
1533 let (var_0, var_g, var_c) = match cov_local.as_ref() {
1534 Some(cov) => (
1535 scalar_operator_response_variance(&cache.d0, cov),
1536 grouped_operator_response_variance(&cache.d1, dim, cov)?,
1537 grouped_operator_response_variance(&cache.d2, dim * dim, cov)?,
1538 ),
1539 None => (
1540 Array1::<f64>::zeros(exact.magnitude.signal.len()),
1541 Array1::<f64>::zeros(exact.gradient.norm.len()),
1542 Array1::<f64>::zeros(exact.curvature.norm.len()),
1543 ),
1544 };
1545 let (_, inv_0) = exact.magnitude.surrogateweights_posterior_snr(
1546 &var_0,
1547 weight_floor,
1548 weight_ceiling,
1549 );
1550 let (_, inv_g) =
1551 exact
1552 .gradient
1553 .surrogateweights_posterior_snr(&var_g, weight_floor, weight_ceiling);
1554 let (_, inv_c) = exact.curvature.surrogateweights_posterior_snr(
1555 &var_c,
1556 weight_floor,
1557 weight_ceiling,
1558 );
1559 Ok(SpatialAdaptiveWeights {
1560 inv_magweight: inv_0,
1561 invgradweight: inv_g,
1562 inv_lapweight: inv_c,
1563 })
1564 })
1565 .collect()
1566}
1567
1568fn compute_initial_epsilons(
1569 beta: &Array1<f64>,
1570 caches: &[SpatialOperatorRuntimeCache],
1571 min_epsilon: f64,
1572) -> Result<(f64, f64, f64), EstimationError> {
1573 let mut fvals = Vec::<f64>::new();
1574 let mut gvals = Vec::<f64>::new();
1575 let mut cvals = Vec::<f64>::new();
1576 for cache in caches {
1577 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
1578 let exact = SpatialPenaltyExactState::from_beta_local(
1579 beta_local,
1580 cache,
1581 [min_epsilon, min_epsilon, min_epsilon],
1582 )?;
1583 let (f, g, c) = exact.absolute_collocation_magnitudes();
1584 fvals.extend(f.iter().copied());
1585 gvals.extend(g.iter().copied());
1586 cvals.extend(c.iter().copied());
1587 }
1588 let eps_0 = robust_epsilon_from_samples(&fvals, min_epsilon);
1594 let eps_g = robust_epsilon_from_samples(&gvals, min_epsilon);
1595 let eps_c = robust_epsilon_from_samples(&cvals, min_epsilon);
1596 Ok((eps_0, eps_g, eps_c))
1597}
1598
1599fn exact_spatial_adaptive_penalty_index_set(
1600 caches: &[SpatialOperatorRuntimeCache],
1601) -> BTreeSet<usize> {
1602 let mut out = BTreeSet::new();
1603 for cache in caches {
1604 out.insert(cache.mass_penalty_global_idx);
1605 out.insert(cache.tension_penalty_global_idx);
1606 out.insert(cache.stiffness_penalty_global_idx);
1607 }
1608 out
1609}
1610
1611fn build_spatial_adaptive_hyperspecs(cache_count: usize) -> Vec<SpatialAdaptiveHyperSpec> {
1612 let mut out = Vec::with_capacity(cache_count * 3 + 3);
1613 for cache_index in 0..cache_count {
1614 out.push(SpatialAdaptiveHyperSpec {
1615 cache_index,
1616 kind: SpatialAdaptiveHyperKind::LogLambdaMagnitude,
1617 });
1618 out.push(SpatialAdaptiveHyperSpec {
1619 cache_index,
1620 kind: SpatialAdaptiveHyperKind::LogLambdaGradient,
1621 });
1622 out.push(SpatialAdaptiveHyperSpec {
1623 cache_index,
1624 kind: SpatialAdaptiveHyperKind::LogLambdaCurvature,
1625 });
1626 }
1627 out.push(SpatialAdaptiveHyperSpec {
1628 cache_index: 0,
1629 kind: SpatialAdaptiveHyperKind::LogEpsilonMagnitude,
1630 });
1631 out.push(SpatialAdaptiveHyperSpec {
1632 cache_index: 0,
1633 kind: SpatialAdaptiveHyperKind::LogEpsilonGradient,
1634 });
1635 out.push(SpatialAdaptiveHyperSpec {
1636 cache_index: 0,
1637 kind: SpatialAdaptiveHyperKind::LogEpsilonCurvature,
1638 });
1639 out
1640}
1641
1642fn penalty_matrixwith_local_block(
1643 total_dim: usize,
1644 coeff_range: Range<usize>,
1645 local: &Array2<f64>,
1646) -> Array2<f64> {
1647 let mut out = Array2::<f64>::zeros((total_dim, total_dim));
1648 out.slice_mut(s![coeff_range.clone(), coeff_range])
1649 .assign(local);
1650 out
1651}
1652
1653fn fit_term_collectionwith_exact_spatial_adaptive_regularization(
1654 baseline: FittedTermCollection,
1655 y: ArrayView1<'_, f64>,
1656 weights: ArrayView1<'_, f64>,
1657 offset: ArrayView1<'_, f64>,
1658 family: LikelihoodSpec,
1659 options: &FitOptions,
1660 runtime_caches: &[SpatialOperatorRuntimeCache],
1661) -> Result<FittedTermCollection, EstimationError> {
1662 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
1691 let adaptive_penalty_indices = exact_spatial_adaptive_penalty_index_set(runtime_caches);
1692 let p_total = baseline.design.design.ncols();
1693 struct RetainedPenaltySetup {
1694 global_idx: usize,
1695 global_penalty: Array2<f64>,
1696 nullspace_dim: usize,
1697 log_lambda: f64,
1698 col_range: Range<usize>,
1699 hessian_piece: Array2<f64>,
1700 }
1701 use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
1702 let retained_setups = baseline
1703 .design
1704 .penalties
1705 .par_iter()
1706 .enumerate()
1707 .map(|(idx, bp)| {
1708 if adaptive_penalty_indices.contains(&idx) {
1709 return None;
1710 }
1711 let lambda = baseline.fit.lambdas[idx];
1712 Some(RetainedPenaltySetup {
1713 global_idx: idx,
1714 global_penalty: bp.to_global(p_total),
1715 nullspace_dim: baseline
1716 .design
1717 .nullspace_dims
1718 .get(idx)
1719 .copied()
1720 .unwrap_or(0),
1721 log_lambda: lambda.max(1e-12).ln(),
1722 col_range: bp.col_range.clone(),
1723 hessian_piece: bp.local.mapv(|v| lambda * v),
1724 })
1725 })
1726 .collect::<Vec<_>>();
1727 let retained_count = retained_setups
1728 .iter()
1729 .filter(|setup| setup.is_some())
1730 .count();
1731 let mut retained_penalties = Vec::<Array2<f64>>::with_capacity(retained_count);
1732 let mut retained_nullspace_dims = Vec::<usize>::with_capacity(retained_count);
1733 let mut retained_log_lambdas = Vec::<f64>::with_capacity(retained_count);
1734 let mut retained_global_indices = Vec::<usize>::with_capacity(retained_count);
1735 let mut fixed_quadratichessian = Array2::<f64>::zeros((p_total, p_total));
1736 for setup in retained_setups.into_iter().flatten() {
1737 retained_penalties.push(setup.global_penalty);
1738 retained_nullspace_dims.push(setup.nullspace_dim);
1739 retained_log_lambdas.push(setup.log_lambda);
1740 retained_global_indices.push(setup.global_idx);
1741 fixed_quadratichessian
1742 .slice_mut(s![setup.col_range.clone(), setup.col_range])
1743 .scaled_add(1.0, &setup.hessian_piece);
1744 }
1745
1746 let (eps_0_init, eps_g_init, eps_c_init) = compute_initial_epsilons(
1747 &baseline.fit.beta,
1748 runtime_caches,
1749 adaptive_opts.min_epsilon,
1750 )?;
1751 let mut initial_theta =
1752 Array1::<f64>::zeros(retained_penalties.len() + runtime_caches.len() * 3 + 3);
1753 for (idx, value) in retained_log_lambdas.iter().enumerate() {
1754 initial_theta[idx] = *value;
1755 }
1756 let adaptive_log_lambda_components = runtime_caches
1757 .par_iter()
1758 .map(|cache| {
1759 [
1760 baseline.fit.lambdas[cache.mass_penalty_global_idx]
1761 .max(1e-12)
1762 .ln(),
1763 baseline.fit.lambdas[cache.tension_penalty_global_idx]
1764 .max(1e-12)
1765 .ln(),
1766 baseline.fit.lambdas[cache.stiffness_penalty_global_idx]
1767 .max(1e-12)
1768 .ln(),
1769 ]
1770 })
1771 .collect::<Vec<_>>();
1772 let mut at = retained_penalties.len();
1773 for logs in &adaptive_log_lambda_components {
1774 initial_theta[at] = logs[0];
1775 initial_theta[at + 1] = logs[1];
1776 initial_theta[at + 2] = logs[2];
1777 at += 3;
1778 }
1779 initial_theta[at] = eps_0_init.max(adaptive_opts.min_epsilon).ln();
1780 initial_theta[at + 1] = eps_g_init.max(adaptive_opts.min_epsilon).ln();
1781 initial_theta[at + 2] = eps_c_init.max(adaptive_opts.min_epsilon).ln();
1782
1783 let hyperspecs = build_spatial_adaptive_hyperspecs(runtime_caches.len());
1784 let zero_psi_op: std::sync::Arc<dyn gam_custom_family::CustomFamilyPsiDerivativeOperator> =
1785 std::sync::Arc::new(gam_custom_family::ZeroPsiDerivativeOperator::new(
1786 baseline.design.design.nrows(),
1787 baseline.design.design.ncols(),
1788 ));
1789 let derivative_blocks = vec![
1790 hyperspecs
1791 .par_iter()
1792 .map(|_| CustomFamilyBlockPsiDerivative {
1793 penalty_index: None,
1794 x_psi: Array2::<f64>::zeros((0, 0)),
1795 s_psi: Array2::<f64>::zeros((0, 0)),
1796 s_psi_components: None,
1797 s_psi_penalty_components: None,
1798 x_psi_psi: None,
1799 s_psi_psi: None,
1800 s_psi_psi_components: None,
1801 s_psi_psi_penalty_components: None,
1802 implicit_operator: Some(std::sync::Arc::clone(&zero_psi_op)),
1803 implicit_axis: 0,
1804 implicit_group_id: None,
1805 })
1806 .collect::<Vec<_>>(),
1807 ];
1808
1809 let mixture_link_state = options
1810 .mixture_link
1811 .clone()
1812 .as_ref()
1813 .map(state_fromspec)
1814 .transpose()
1815 .map_err(EstimationError::InvalidInput)?;
1816 let sas_link_state = options
1817 .sas_link
1818 .map(|spec| {
1819 if family.is_binomial_beta_logistic() {
1820 state_from_beta_logisticspec(spec)
1821 } else {
1822 state_from_sasspec(spec)
1823 }
1824 })
1825 .transpose()
1826 .map_err(EstimationError::InvalidInput)?;
1827 let latent_cloglog_state = options.latent_cloglog;
1828 let shared_y = Arc::new(y.to_owned());
1829 let sharedweights = Arc::new(weights.to_owned());
1830 let shared_design = baseline
1831 .design
1832 .design
1833 .try_to_dense_arc("spatial adaptive exact hyperfit design")
1834 .map_err(EstimationError::InvalidInput)?;
1835 let shared_offset = Arc::new(offset.to_owned());
1836 let shared_runtime_caches = Arc::new(runtime_caches.to_vec());
1837 let shared_hyperspecs = Arc::new(hyperspecs.clone());
1838 let zero_quadratic = Arc::new(Array2::<f64>::zeros((
1839 baseline.design.design.ncols(),
1840 baseline.design.design.ncols(),
1841 )));
1842 let base_family = SpatialAdaptiveExactFamily {
1843 family: family.clone(),
1844 latent_cloglog_state,
1845 mixture_link_state: mixture_link_state.clone(),
1846 sas_link_state,
1847 y: shared_y.clone(),
1848 weights: sharedweights.clone(),
1849 design: shared_design.clone(),
1850 offset: shared_offset.clone(),
1851 linear_constraints: baseline.design.linear_constraints.clone(),
1852 runtime_caches: shared_runtime_caches.clone(),
1853 adaptive_params: Vec::new(),
1854 fixed_quadratichessian: zero_quadratic.clone(),
1855 hyperspecs: shared_hyperspecs.clone(),
1856 exact_eval_cache: Arc::new(Mutex::new(None)),
1857 };
1858
1859 let rho_dim = retained_penalties.len();
1860 let operator_slots_end = rho_dim + runtime_caches.len() * 3;
1861 const UNIFIED_LOG_WINDOW: f64 = 6.0;
1871 const RETAINED_LAMBDA_LOG_LOWER_FLOOR: f64 = -30.0;
1872 const RETAINED_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
1873 const OPERATOR_LAMBDA_LOG_LOWER_FLOOR: f64 = -10.0;
1874 const OPERATOR_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
1875 let epsilon_floor_log = adaptive_opts.min_epsilon.max(1e-12).ln();
1876 let anchored_bound = |idx: usize, sign: f64| -> f64 {
1877 let raw = initial_theta[idx] + sign * UNIFIED_LOG_WINDOW;
1878 if idx < rho_dim {
1879 raw.clamp(
1880 RETAINED_LAMBDA_LOG_LOWER_FLOOR,
1881 RETAINED_LAMBDA_LOG_UPPER_CAP,
1882 )
1883 } else if idx < operator_slots_end {
1884 raw.clamp(
1885 OPERATOR_LAMBDA_LOG_LOWER_FLOOR,
1886 OPERATOR_LAMBDA_LOG_UPPER_CAP,
1887 )
1888 } else {
1889 raw.max(epsilon_floor_log)
1890 }
1891 };
1892 let eps_lower =
1893 Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, -1.0)));
1894 let eps_upper = Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, 1.0)));
1895 let blockspec = ParameterBlockSpec {
1896 name: "eta".to_string(),
1897 design: baseline.design.design.clone(),
1898 offset: offset.to_owned(),
1899 penalties: retained_penalties
1900 .iter()
1901 .cloned()
1902 .map(PenaltyMatrix::Dense)
1903 .collect(),
1904 nullspace_dims: retained_nullspace_dims.clone(),
1905 initial_log_lambdas: Array1::from_vec(retained_log_lambdas.clone()),
1906 initial_beta: Some(baseline.fit.beta.clone()),
1907 gauge_priority: 100,
1908 jacobian_callback: None,
1909 stacked_design: None,
1910 stacked_offset: None,
1911 };
1912 let screening_cap = Arc::new(AtomicUsize::new(0));
1913 let outer_opts = BlockwiseFitOptions {
1914 inner_max_cycles: options.max_iter,
1915 inner_tol: options.tol,
1916 outer_max_iter: options.max_iter,
1917 outer_tol: options.tol,
1918 compute_covariance: false,
1919 screening_max_inner_iterations: Some(Arc::clone(&screening_cap)),
1920 ..BlockwiseFitOptions::default()
1921 };
1922
1923 use gam_solve::rho_optimizer::OuterProblem;
1924 use gam_problem::{DeclaredHessianForm, Derivative, HessianResult, OuterEval};
1925
1926 struct SpatialAdaptiveOuterState {
1927 warm_cache: Option<CustomFamilyWarmStart>,
1928 last_eval: Option<(
1929 Array1<f64>,
1930 f64,
1931 Array1<f64>,
1932 HessianResult,
1933 CustomFamilyWarmStart,
1934 )>,
1935 }
1936
1937 let n_theta = initial_theta.len();
1938
1939 let theta_bounds = Some((eps_lower.clone(), eps_upper.clone()));
1942 let clamp_theta = {
1943 let lo = eps_lower;
1944 let hi = eps_upper;
1945 move |theta: &Array1<f64>| -> Array1<f64> {
1946 let mut clamped = theta.clone();
1947 for i in 0..clamped.len() {
1948 clamped[i] = clamped[i].clamp(lo[i], hi[i]);
1949 }
1950 clamped
1951 }
1952 };
1953
1954 let decode_theta = |theta: &Array1<f64>| -> (Array1<f64>, Vec<SpatialAdaptiveTermHyperParams>) {
1955 let rho = theta.slice(s![..rho_dim]).to_owned();
1956 let adaptive_lambda_start = rho_dim;
1957 let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
1958 let eps = [
1959 theta[adaptive_lambda_end].exp(),
1960 theta[adaptive_lambda_end + 1].exp(),
1961 theta[adaptive_lambda_end + 2].exp(),
1962 ];
1963 let adaptive_params = runtime_caches
1964 .iter()
1965 .enumerate()
1966 .map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
1967 lambda: [
1968 theta[adaptive_lambda_start + cache_idx * 3].exp(),
1969 theta[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
1970 theta[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
1971 ],
1972 epsilon: eps,
1973 })
1974 .collect::<Vec<_>>();
1975 (rho, adaptive_params)
1976 };
1977 let analytic_outer_hessian_available =
1978 gam_custom_family::joint_exact_analytic_outer_hessian_available()
1979 && base_family
1980 .exact_outer_derivative_order(std::slice::from_ref(&blockspec), &outer_opts)
1981 .has_hessian()
1982 && gam_custom_family::exact_newton_outer_geometry_supports_second_order_solver(
1983 &base_family,
1984 );
1985 let outer_max_iter = gam_custom_family::cost_gated_first_order_max_iter(
1986 options.max_iter,
1987 base_family.coefficient_gradient_cost(std::slice::from_ref(&blockspec)),
1988 analytic_outer_hessian_available,
1989 );
1990 if outer_max_iter < options.max_iter {
1991 log::info!(
1992 "[OUTER] exact spatial adaptive regularization: first-order work gate reduced outer_max_iter {} -> {}",
1993 options.max_iter,
1994 outer_max_iter,
1995 );
1996 }
1997 let problem = OuterProblem::new(n_theta)
2003 .with_gradient(Derivative::Analytic)
2004 .with_hessian(if analytic_outer_hessian_available {
2005 DeclaredHessianForm::Either
2006 } else {
2007 DeclaredHessianForm::Unavailable
2008 })
2009 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Disabled)
2010 .with_psi_dim(n_theta.saturating_sub(rho_dim))
2011 .with_tolerance(options.tol)
2012 .with_max_iter(outer_max_iter)
2013 .with_seed_config(gam_problem::SeedConfig::default())
2014 .with_screening_cap(Arc::clone(&screening_cap))
2015 .with_initial_rho(initial_theta.clone());
2016 let problem = if let Some((lo, hi)) = theta_bounds {
2017 problem.with_bounds(lo, hi)
2018 } else {
2019 problem
2020 };
2021
2022 let eval_outer = |st: &mut SpatialAdaptiveOuterState,
2023 theta: &Array1<f64>,
2024 order: gam_solve::rho_optimizer::OuterEvalOrder|
2025 -> Result<OuterEval, EstimationError> {
2026 let theta = clamp_theta(theta);
2027
2028 if let Some((cached_theta, cached_cost, cached_grad, cached_hess, cached_warm)) =
2029 &st.last_eval
2030 && cached_theta.len() == theta.len()
2031 && cached_theta
2032 .iter()
2033 .zip(theta.iter())
2034 .all(|(&a, &b)| (a - b).abs() <= 1e-12)
2035 && (!matches!(
2036 order,
2037 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2038 ) || analytic_outer_hessian_available)
2039 {
2040 st.warm_cache = Some(cached_warm.clone());
2041 return Ok(OuterEval {
2042 cost: *cached_cost,
2043 gradient: cached_grad.clone(),
2044 hessian: if matches!(
2045 order,
2046 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2047 ) && analytic_outer_hessian_available
2048 {
2049 cached_hess.clone()
2050 } else {
2051 HessianResult::Unavailable
2052 },
2053 inner_beta_hint: None,
2054 });
2055 }
2056
2057 let (rho, adaptive_params) = decode_theta(&theta);
2058 let family_eval = base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2059 let need_hessian = matches!(
2060 order,
2061 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2062 ) && analytic_outer_hessian_available;
2063 let result = evaluate_custom_family_joint_hyper(
2064 &family_eval,
2065 std::slice::from_ref(&blockspec),
2066 &outer_opts,
2067 &rho,
2068 &derivative_blocks,
2069 st.warm_cache.as_ref(),
2070 if need_hessian {
2071 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
2072 } else {
2073 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
2074 },
2075 )
2076 .map_err(|e| {
2077 EstimationError::RemlOptimizationFailed(format!("spatial adaptive eval failed: {e}"))
2078 })?;
2079 if !result.inner_converged {
2080 st.warm_cache = Some(result.warm_start.clone());
2081 return Err(EstimationError::RemlOptimizationFailed(
2082 "exact spatial adaptive inner solve did not converge".to_string(),
2083 ));
2084 }
2085 if !result.objective.is_finite() || result.gradient.iter().any(|v| !v.is_finite()) {
2086 return Err(EstimationError::RemlOptimizationFailed(
2087 "exact spatial adaptive objective returned non-finite values".to_string(),
2088 ));
2089 }
2090 let hessian_result = if need_hessian {
2091 if !result.outer_hessian.is_analytic() {
2092 return Err(EstimationError::RemlOptimizationFailed(
2093 "exact spatial adaptive objective did not return an exact outer Hessian"
2094 .to_string(),
2095 ));
2096 }
2097 match result.outer_hessian.dim() {
2098 Some(dim) if dim == theta.len() => {}
2099 Some(dim) => {
2100 return Err(EstimationError::RemlOptimizationFailed(format!(
2101 "exact spatial adaptive outer Hessian dimension mismatch: got {dim}, expected {}",
2102 theta.len(),
2103 )));
2104 }
2105 None => {
2106 return Err(EstimationError::RemlOptimizationFailed(
2107 "exact spatial adaptive objective did not report an outer Hessian dimension"
2108 .to_string(),
2109 ));
2110 }
2111 }
2112 st.last_eval = Some((
2113 theta.clone(),
2114 result.objective,
2115 result.gradient.clone(),
2116 result.outer_hessian.clone(),
2117 result.warm_start.clone(),
2118 ));
2119 result.outer_hessian
2120 } else {
2121 HessianResult::Unavailable
2122 };
2123 st.warm_cache = Some(result.warm_start);
2124 Ok(OuterEval {
2125 cost: result.objective,
2126 gradient: result.gradient,
2127 hessian: hessian_result,
2128 inner_beta_hint: None,
2129 })
2130 };
2131
2132 let mut obj = problem.build_objective_with_screening_proxy(
2133 SpatialAdaptiveOuterState {
2134 warm_cache: None,
2135 last_eval: None,
2136 },
2137 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2138 let theta = clamp_theta(theta);
2139 let (rho, adaptive_params) = decode_theta(&theta);
2140 let family_eval =
2141 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2142 let result = evaluate_custom_family_joint_hyper(
2143 &family_eval,
2144 std::slice::from_ref(&blockspec),
2145 &outer_opts,
2146 &rho,
2147 &derivative_blocks,
2148 st.warm_cache.as_ref(),
2149 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
2150 )
2151 .map_err(|e| {
2152 EstimationError::RemlOptimizationFailed(format!(
2153 "spatial adaptive cost eval failed: {e}"
2154 ))
2155 })?;
2156 if !result.inner_converged {
2157 st.warm_cache = Some(result.warm_start);
2158 return Err(EstimationError::RemlOptimizationFailed(
2159 "exact spatial adaptive cost inner solve did not converge".to_string(),
2160 ));
2161 }
2162 st.warm_cache = Some(result.warm_start);
2163 Ok(result.objective)
2164 },
2165 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2166 eval_outer(
2167 st,
2168 theta,
2169 if analytic_outer_hessian_available {
2170 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2171 } else {
2172 gam_solve::rho_optimizer::OuterEvalOrder::ValueAndGradient
2173 },
2174 )
2175 },
2176 |st: &mut SpatialAdaptiveOuterState,
2177 theta: &Array1<f64>,
2178 order: gam_solve::rho_optimizer::OuterEvalOrder| {
2179 eval_outer(st, theta, order)
2180 },
2181 Some(|st: &mut SpatialAdaptiveOuterState| {
2182 st.warm_cache = None;
2183 st.last_eval = None;
2184 }),
2185 Some(|st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2186 let theta = clamp_theta(theta);
2187 let (rho, adaptive_params) = decode_theta(&theta);
2188 let family_eval =
2189 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2190 let result = evaluate_custom_family_joint_hyper_efs(
2191 &family_eval,
2192 std::slice::from_ref(&blockspec),
2193 &outer_opts,
2194 &rho,
2195 &derivative_blocks,
2196 st.warm_cache.as_ref(),
2197 )
2198 .map_err(|e| {
2199 EstimationError::RemlOptimizationFailed(format!(
2200 "spatial adaptive EFS eval failed: {e}"
2201 ))
2202 })?;
2203 if !result.inner_converged {
2204 st.warm_cache = Some(result.warm_start);
2205 return Err(EstimationError::RemlOptimizationFailed(
2206 "exact spatial adaptive EFS inner solve did not converge".to_string(),
2207 ));
2208 }
2209 st.warm_cache = Some(result.warm_start);
2210 Ok(result.efs_eval)
2211 }),
2212 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2224 let theta = clamp_theta(theta);
2225 let (rho, adaptive_params) = decode_theta(&theta);
2226 let family_eval =
2227 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2228 let result = evaluate_custom_family_joint_hyper(
2229 &family_eval,
2230 std::slice::from_ref(&blockspec),
2231 &outer_opts,
2232 &rho,
2233 &derivative_blocks,
2234 st.warm_cache.as_ref(),
2235 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
2236 )
2237 .map_err(|e| {
2238 EstimationError::RemlOptimizationFailed(format!(
2239 "spatial adaptive screening eval failed: {e}"
2240 ))
2241 })?;
2242 st.warm_cache = Some(result.warm_start);
2243 Ok(result.objective)
2244 },
2245 );
2246
2247 let outer_result = problem
2248 .run(&mut obj, "exact spatial adaptive regularization")
2249 .map_err(|e| {
2250 EstimationError::InvalidInput(format!(
2251 "exact spatial adaptive outer optimization failed: {e}"
2252 ))
2253 })?;
2254 if !outer_result.converged {
2255 let rel_to_cost_threshold = options.tol * (1.0_f64 + outer_result.final_value.abs());
2272 if let Some(final_grad) = outer_result
2276 .final_grad_norm
2277 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
2278 {
2279 log::info!(
2280 "[spatial-adaptive] outer optimization hit max_iter={} but \
2281 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
2282 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
2283 relative-to-cost REML convergence criterion.",
2284 outer_result.iterations,
2285 final_grad,
2286 rel_to_cost_threshold,
2287 options.tol,
2288 outer_result.final_value.abs(),
2289 );
2290 } else {
2291 crate::bail_invalid_estim!(
2292 "exact spatial adaptive outer optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
2293 outer_result.iterations,
2294 outer_result.final_value,
2295 outer_result.final_grad_norm_report(),
2296 );
2297 }
2298 }
2299 let outer_iterations = outer_result.iterations;
2300 let outer_grad_norm: Option<f64> = outer_result.final_grad_norm;
2303 let theta_star = outer_result.rho;
2304 let rho_star = theta_star.slice(s![..rho_dim]).to_owned();
2305 let adaptive_lambda_start = rho_dim;
2306 let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
2307 let eps_star = [
2308 theta_star[adaptive_lambda_end].exp(),
2309 theta_star[adaptive_lambda_end + 1].exp(),
2310 theta_star[adaptive_lambda_end + 2].exp(),
2311 ];
2312 let adaptive_params = runtime_caches
2313 .iter()
2314 .enumerate()
2315 .map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
2316 lambda: [
2317 theta_star[adaptive_lambda_start + cache_idx * 3].exp(),
2318 theta_star[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
2319 theta_star[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
2320 ],
2321 epsilon: eps_star,
2322 })
2323 .collect::<Vec<_>>();
2324 let mut fixed_total = Array2::<f64>::zeros((
2325 baseline.design.design.ncols(),
2326 baseline.design.design.ncols(),
2327 ));
2328 for (idx, penalty) in retained_penalties.iter().enumerate() {
2329 fixed_total.scaled_add(rho_star[idx].exp(), penalty);
2330 }
2331 let final_family =
2332 base_family.with_adaptive_params(adaptive_params.clone(), Arc::new(fixed_total.clone()));
2333 let final_blockspec = ParameterBlockSpec {
2334 name: "eta".to_string(),
2335 design: baseline.design.design.clone(),
2336 offset: offset.to_owned(),
2337 penalties: vec![],
2338 nullspace_dims: vec![],
2339 initial_log_lambdas: Array1::zeros(0),
2340 initial_beta: Some(baseline.fit.beta.clone()),
2341 gauge_priority: 100,
2342 jacobian_callback: None,
2343 stacked_design: None,
2344 stacked_offset: None,
2345 };
2346 let final_fit = fit_custom_family(
2347 &final_family,
2348 &[final_blockspec],
2349 &BlockwiseFitOptions {
2350 inner_max_cycles: options.max_iter,
2351 inner_tol: options.tol,
2352 outer_max_iter: 1,
2353 outer_tol: options.tol,
2354 compute_covariance: true,
2355 ..BlockwiseFitOptions::default()
2356 },
2357 )
2358 .map_err(EstimationError::CustomFamily)?;
2359 let beta = final_fit.block_states[0].beta.clone();
2360 let final_eval = final_family
2361 .exact_evaluation(&beta)
2362 .map_err(EstimationError::InvalidInput)?;
2363 let penalized_hessian = final_eval
2364 .totalobjectivehessian(&final_family.design)
2365 .map_err(EstimationError::InvalidInput)?;
2366 let beta_covariance = final_fit.covariance_conditional.clone();
2367 let beta_standard_errors = beta_covariance
2368 .as_ref()
2369 .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));
2370
2371 let mut full_lambdas = baseline.fit.lambdas.clone();
2372 for (idx, &global_idx) in retained_global_indices.iter().enumerate() {
2373 full_lambdas[global_idx] = rho_star[idx].exp();
2374 }
2375 for (cache_idx, cache) in runtime_caches.iter().enumerate() {
2376 full_lambdas[cache.mass_penalty_global_idx] = adaptive_params[cache_idx].lambda[0];
2377 full_lambdas[cache.tension_penalty_global_idx] = adaptive_params[cache_idx].lambda[1];
2378 full_lambdas[cache.stiffness_penalty_global_idx] = adaptive_params[cache_idx].lambda[2];
2379 }
2380
2381 let deviance = if family.is_gaussian_identity() {
2382 y.iter()
2383 .zip(final_eval.obs.mu.iter())
2384 .zip(weights.iter())
2385 .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
2386 .sum()
2387 } else {
2388 -2.0 * final_eval.obs.log_likelihood
2389 };
2390 let mut local_penalty_blocks =
2391 Vec::<PenaltySpec>::with_capacity(baseline.design.penalties.len());
2392 for (global_idx, bp) in baseline.design.penalties.iter().enumerate() {
2393 if adaptive_penalty_indices.contains(&global_idx) {
2394 let cache = runtime_caches
2395 .iter()
2396 .find(|cache| {
2397 cache.mass_penalty_global_idx == global_idx
2398 || cache.tension_penalty_global_idx == global_idx
2399 || cache.stiffness_penalty_global_idx == global_idx
2400 })
2401 .ok_or_else(|| {
2402 EstimationError::InvalidInput(format!(
2403 "missing runtime cache for adaptive penalty index {global_idx}"
2404 ))
2405 })?;
2406 let cache_idx = runtime_caches
2407 .iter()
2408 .position(|c| {
2409 c.mass_penalty_global_idx == global_idx
2410 || c.tension_penalty_global_idx == global_idx
2411 || c.stiffness_penalty_global_idx == global_idx
2412 })
2413 .ok_or_else(|| {
2414 EstimationError::InvalidInput(format!(
2415 "missing adaptive cache position for penalty index {global_idx}"
2416 ))
2417 })?;
2418 let state = &final_eval.adaptive_states[cache_idx];
2419 let local = if cache.mass_penalty_global_idx == global_idx {
2420 scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag())
2421 .mapv(|v| adaptive_params[cache_idx].lambda[0] * v)
2422 } else if cache.tension_penalty_global_idx == global_idx {
2423 grouped_operatorhessian(
2424 &cache.d1,
2425 cache.dimension,
2426 &state.gradient.betahessian_blocks(),
2427 )?
2428 .mapv(|v| adaptive_params[cache_idx].lambda[1] * v)
2429 } else {
2430 grouped_operatorhessian(
2431 &cache.d2,
2432 cache.dimension * cache.dimension,
2433 &state.curvature.betahessian_blocks(),
2434 )?
2435 .mapv(|v| adaptive_params[cache_idx].lambda[2] * v)
2436 };
2437 local_penalty_blocks.push(PenaltySpec::Dense(penalty_matrixwith_local_block(
2439 baseline.design.design.ncols(),
2440 cache.coeff_global_range.clone(),
2441 &local,
2442 )));
2443 } else {
2444 local_penalty_blocks.push(PenaltySpec::Dense(
2445 bp.to_global(p_total).mapv(|v| v * full_lambdas[global_idx]),
2446 ));
2447 }
2448 }
2449 let (edf_by_block, penalty_block_trace, edf_total) = if let Some(cov) = beta_covariance.as_ref()
2450 {
2451 exact_bounded_edf(
2452 &local_penalty_blocks,
2453 &Array1::from_elem(local_penalty_blocks.len(), 1.0),
2454 cov,
2455 )?
2456 } else {
2457 (
2458 vec![0.0; local_penalty_blocks.len()],
2459 vec![0.0; local_penalty_blocks.len()],
2460 0.0,
2461 )
2462 };
2463 let stable_penalty_term =
2464 2.0 * final_eval.adaptive_penalty_value + beta.dot(&fixed_total.dot(&beta));
2465 let standard_deviation = if family.is_gaussian_identity() {
2466 let denom = (y.len() as f64 - edf_total).max(1.0);
2467 (deviance / denom).sqrt()
2468 } else {
2469 1.0
2470 };
2471 let maps = compute_spatial_adaptiveweights_for_beta(
2472 &beta,
2473 runtime_caches,
2474 eps_star[0],
2475 eps_star[1],
2476 eps_star[2],
2477 adaptive_opts.weight_floor,
2478 adaptive_opts.weight_ceiling,
2479 beta_covariance.as_ref(),
2483 )?
2484 .into_iter()
2485 .zip(runtime_caches.iter())
2486 .map(|(w, cache)| AdaptiveSpatialMap {
2487 termname: cache.termname.clone(),
2488 feature_cols: cache.feature_cols.clone(),
2489 collocation_points: cache.collocation_points.clone(),
2490 inv_magweight: w.inv_magweight,
2491 invgradweight: w.invgradweight,
2492 inv_lapweight: w.inv_lapweight,
2493 })
2494 .collect::<Vec<_>>();
2495 let fitted_link = if family.is_latent_cloglog() {
2496 FittedLinkState::LatentCLogLog {
2497 state: latent_cloglog_state
2498 .expect("BinomialLatentCLogLog requires an explicit latent-cloglog state"),
2499 }
2500 } else if family.is_binomial_mixture() {
2501 mixture_link_state
2502 .clone()
2503 .map(|state| FittedLinkState::Mixture {
2504 state,
2505 covariance: None,
2506 })
2507 .unwrap_or(FittedLinkState::Standard(None))
2508 } else if family.is_binomial_sas() {
2509 sas_link_state
2510 .map(|state| FittedLinkState::Sas {
2511 state,
2512 covariance: None,
2513 })
2514 .unwrap_or(FittedLinkState::Standard(None))
2515 } else if family.is_binomial_beta_logistic() {
2516 sas_link_state
2517 .map(|state| FittedLinkState::BetaLogistic {
2518 state,
2519 covariance: None,
2520 })
2521 .unwrap_or(FittedLinkState::Standard(None))
2522 } else {
2523 FittedLinkState::Standard(None)
2524 };
2525 let max_abs_eta = final_eval
2526 .obs
2527 .eta
2528 .iter()
2529 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2530 let fitted = FittedTermCollection {
2531 fit: {
2532 let log_lambdas = full_lambdas.mapv(|v| v.max(1e-300).ln());
2533 let inf = FitInference {
2534 edf_by_block,
2535 penalty_block_trace,
2536 edf_total,
2537 smoothing_correction: None,
2538 penalized_hessian: penalized_hessian.clone().into(),
2541 working_weights: final_eval.obs.fisherweight.clone(),
2542 working_response: {
2543 let mut out = final_eval.obs.eta.clone();
2544 for i in 0..out.len() {
2545 let wi = final_eval.obs.fisherweight[i].max(1e-12);
2546 out[i] += final_eval.obs.score[i] / wi;
2547 }
2548 out
2549 },
2550 reparam_qs: None,
2551 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2552 beta_covariance: beta_covariance
2553 .clone()
2554 .map(gam_problem::dispersion_cov::PhiScaledCovariance::from),
2555 beta_standard_errors,
2556 beta_covariance_corrected: None,
2557 beta_standard_errors_corrected: None,
2558 beta_covariance_frequentist: None,
2559 coefficient_influence: None,
2560 weighted_gram: None,
2561 bias_correction_beta: None,
2562 };
2563 let geometry = Some(gam_solve::estimate::FitGeometry {
2564 penalized_hessian: penalized_hessian.into(),
2565 working_weights: inf.working_weights.clone(),
2566 working_response: inf.working_response.clone(),
2567 });
2568 let covariance_conditional = beta_covariance;
2569 let pirls_status_val = if final_fit.outer_converged {
2570 gam_solve::pirls::PirlsStatus::Converged
2571 } else {
2572 gam_solve::pirls::PirlsStatus::StalledAtValidMinimum
2573 };
2574 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
2575 blocks: vec![gam_solve::estimate::FittedBlock {
2576 beta: beta.clone(),
2577 role: gam_problem::BlockRole::Mean,
2578 edf: edf_total,
2579 lambdas: full_lambdas.clone(),
2580 }],
2581 log_lambdas,
2582 lambdas: full_lambdas,
2583 likelihood_scale: family.default_scale_metadata(),
2584 likelihood_family: Some(family),
2585 log_likelihood_normalization:
2586 gam_spec::LogLikelihoodNormalization::UserProvided,
2587 log_likelihood: final_eval.obs.log_likelihood,
2588 deviance,
2589 reml_score: final_fit.penalized_objective,
2590 stable_penalty_term,
2591 penalized_objective: final_fit.penalized_objective,
2592 used_device: false,
2593 outer_iterations,
2594 outer_converged: final_fit.outer_converged,
2595 outer_gradient_norm: outer_grad_norm,
2596 standard_deviation,
2597 covariance_conditional,
2598 covariance_corrected: None,
2599 inference: Some(inf),
2600 fitted_link,
2601 geometry,
2602 block_states: Vec::new(),
2603 pirls_status: pirls_status_val,
2604 max_abs_eta,
2605 constraint_kkt: None,
2606 artifacts: gam_solve::estimate::FitArtifacts {
2607 pirls: None,
2608 ..Default::default()
2609 },
2610 inner_cycles: 0,
2611 })?
2612 },
2613 design: baseline.design,
2614 adaptive_diagnostics: Some(AdaptiveRegularizationDiagnostics {
2615 epsilon_0: eps_star[0],
2616 epsilon_g: eps_star[1],
2617 epsilon_c: eps_star[2],
2618 epsilon_outer_iterations: outer_iterations,
2619 mm_iterations: 0,
2620 converged: final_fit.outer_converged,
2621 maps,
2622 }),
2623 };
2624 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
2625 Ok(fitted)
2626}
2627
2628fn relax_smoothing_rho_prior(
2660 options: &FitOptions,
2661 design: &TermCollectionDesign,
2662) -> gam_spec::RhoPrior {
2663 use gam_terms::basis::BasisMetadata;
2664 let base = &options.rho_prior;
2665 if matches!(
2668 base,
2669 gam_spec::RhoPrior::Flat | gam_spec::RhoPrior::Independent(_)
2670 ) {
2671 return base.clone();
2672 }
2673 let has_link_aux = options.sas_link.is_some()
2693 || options.optimize_sas
2694 || options.mixture_link.is_some()
2695 || options.optimize_mixture;
2696 let has_moving_kappa = design.smooth.terms.iter().any(|t| {
2697 matches!(
2698 t.metadata,
2699 BasisMetadata::Matern { .. }
2700 | BasisMetadata::Duchon { .. }
2701 | BasisMetadata::Sphere { .. }
2702 | BasisMetadata::SphereHarmonics { .. }
2703 | BasisMetadata::ConstantCurvature { .. }
2704 | BasisMetadata::MeasureJet { .. }
2705 )
2706 });
2707 let length_safe = !has_link_aux && !has_moving_kappa;
2714 if !length_safe {
2715 return base.clone();
2716 }
2717 let coords = &design.penaltyinfo;
2718 if coords.is_empty() {
2719 return base.clone();
2720 }
2721 let n_obs = design.design.nrows();
2732 let p_total = design.design.ncols();
2733 let underdetermined = n_obs < 2 * p_total;
2764 let relaxable_terms: std::collections::HashSet<&str> = design
2776 .smooth
2777 .terms
2778 .iter()
2779 .filter(|t| {
2780 matches!(
2781 t.metadata,
2782 BasisMetadata::BSpline1D { .. }
2783 | BasisMetadata::ThinPlate { .. }
2784 | BasisMetadata::TensorBSpline { .. }
2785 )
2786 && matches!(t.shape, gam_terms::smooth::ShapeConstraint::None)
2800 })
2801 .map(|t| t.name.as_str())
2802 .collect();
2803 let any_relaxed = coords.iter().any(|info| {
2804 info.termname
2805 .as_deref()
2806 .is_some_and(|name| relaxable_terms.contains(name))
2807 });
2808 if !any_relaxed {
2809 return base.clone();
2810 }
2811 let relaxed_prior = if underdetermined {
2816 gam_spec::RhoPrior::Normal {
2817 mean: 0.0,
2818 sd: RELAX_UNDERDETERMINED_RHO_SD,
2819 }
2820 } else {
2821 gam_spec::RhoPrior::Flat
2822 };
2823 let nullspace_select_prior = gam_spec::RhoPrior::PenalizedComplexity {
2850 upper: NULLSPACE_SELECT_PC_UPPER,
2851 tail_prob: NULLSPACE_SELECT_PC_TAIL_PROB,
2852 };
2853 let nullspace_degeneracy_prior = gam_spec::RhoPrior::Normal {
2880 mean: 0.0,
2881 sd: NULLSPACE_WELLDET_DEGENERACY_RHO_SD,
2882 };
2883 let per_coord = coords
2884 .iter()
2885 .map(|info| {
2886 let relax = info
2887 .termname
2888 .as_deref()
2889 .is_some_and(|name| relaxable_terms.contains(name));
2890 if !relax {
2891 return base.clone();
2892 }
2893 let is_nullspace =
2894 matches!(info.penalty.source, PenaltySource::DoublePenaltyNullspace);
2895 if is_nullspace {
2934 if underdetermined {
2935 nullspace_select_prior.clone()
2936 } else {
2937 nullspace_degeneracy_prior.clone()
2938 }
2939 } else {
2940 relaxed_prior.clone()
2941 }
2942 })
2943 .collect::<Vec<_>>();
2944 gam_spec::RhoPrior::Independent(per_coord)
2945}
2946
/// Standard deviation of the zero-mean `Normal` rho prior that
/// `relax_smoothing_rho_prior` assigns to relaxable, non-nullspace penalty
/// coordinates when the design is underdetermined (`n_obs < 2 * p_total`).
const RELAX_UNDERDETERMINED_RHO_SD: f64 = 15.0;

/// `upper` parameter of the `PenalizedComplexity` rho prior that
/// `relax_smoothing_rho_prior` assigns to relaxable double-penalty nullspace
/// coordinates when the design is underdetermined.
const NULLSPACE_SELECT_PC_UPPER: f64 = 0.05;

/// `tail_prob` parameter of the same `PenalizedComplexity` prior as
/// `NULLSPACE_SELECT_PC_UPPER`.
const NULLSPACE_SELECT_PC_TAIL_PROB: f64 = 0.01;
2990
2991fn adaptive_fit_options_base(options: &FitOptions, design: &TermCollectionDesign) -> FitOptions {
2992 FitOptions {
2993 latent_cloglog: options.latent_cloglog,
2994 mixture_link: options.mixture_link.clone(),
2995 optimize_mixture: options.optimize_mixture,
2996 sas_link: options.sas_link,
2997 optimize_sas: options.optimize_sas,
2998 compute_inference: options.compute_inference,
2999 skip_rho_posterior_inference: options.skip_rho_posterior_inference,
3000 max_iter: options.max_iter,
3001 tol: options.tol,
3002 nullspace_dims: design.nullspace_dims.clone(),
3003 linear_constraints: design.linear_constraints.clone(),
3004 firth_bias_reduction: options.firth_bias_reduction,
3005 adaptive_regularization: None,
3006 penalty_shrinkage_floor: options.penalty_shrinkage_floor,
3007 rho_prior: options.rho_prior.clone(),
3010 kronecker_penalty_system: design.kronecker_penalty_system(),
3011 kronecker_factored: design
3012 .smooth
3013 .terms
3014 .iter()
3015 .find_map(|t| t.kronecker_factored.clone()),
3016 persist_warm_start_disk: options.persist_warm_start_disk,
3017 }
3018}
3019
3020fn superseded_fit_options(options: &FitOptions) -> FitOptions {
3021 let mut fit_options = options.clone();
3022 fit_options.skip_rho_posterior_inference = true;
3023 fit_options
3024}
3025
/// Metadata for one bounded linear coefficient.
#[derive(Clone)]
struct BoundedLinearTermMeta {
    /// Column of the design matrix (and index into the coefficient vector)
    /// that this bounded term occupies.
    col_idx: usize,
    /// Lower bound of the coefficient.
    min: f64,
    /// Upper bound of the coefficient.
    max: f64,
    /// Prior specification attached to the bounded coefficient.
    prior: BoundedCoefficientPriorSpec,
}
3033
/// Design matrix plus the bounded-term metadata needed to produce an
/// effective Jacobian: the columns of bounded terms are rescaled by the
/// derivative of the bounded coefficient with respect to its latent value
/// (see the `BlockEffectiveJacobian` impl).
struct BoundedEffectiveJacobian {
    /// Dense design matrix whose rows are sliced to build the Jacobian.
    design: Array2<f64>,
    /// Bounded columns of `design` and their bounds.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
3061
3062impl BlockEffectiveJacobian for BoundedEffectiveJacobian {
3063 fn effective_jacobian_rows(
3064 &self,
3065 state: &FamilyLinearizationState<'_>,
3066 rows: std::ops::Range<usize>,
3067 ) -> Result<Array2<f64>, String> {
3068 let p = self.design.ncols();
3069 let n = self.design.nrows();
3070 let rows = rows.start.min(n)..rows.end.min(n);
3071 if !state.beta.is_empty() {
3072 if state.beta.len() != p {
3073 return Err(format!(
3074 "BoundedEffectiveJacobian::effective_jacobian_at: beta length {} != design \
3075 ncols {p}",
3076 state.beta.len(),
3077 ));
3078 }
3079 if state.beta.iter().any(|v| v.is_nan()) {
3080 return Err(
3081 "BoundedEffectiveJacobian::effective_jacobian_at: beta contains NaN"
3082 .to_string(),
3083 );
3084 }
3085 }
3086 let mut jac = self
3087 .design
3088 .slice(ndarray::s![rows.start..rows.end, ..])
3089 .to_owned();
3090 for term in &self.bounded_terms {
3091 let theta = if state.beta.is_empty() {
3092 0.0
3093 } else {
3094 state.beta[term.col_idx]
3095 };
3096 let (_, _, db_dtheta, _, _) = bounded_latent_derivatives(theta, term.min, term.max);
3097 jac.column_mut(term.col_idx).mapv_inplace(|v| v * db_dtheta);
3098 }
3099 Ok(jac)
3100 }
3101}
3102
/// Bundle of likelihood, link state, data and bounded-term metadata for a fit
/// with bounded linear coefficients.
///
/// NOTE(review): the methods that consume these fields are outside this
/// chunk; field notes below are limited to what the types show.
#[derive(Clone)]
struct BoundedLinearFamily {
    /// Likelihood specification (response family + link).
    family: LikelihoodSpec,
    /// Optional latent-cloglog link state.
    latent_cloglog_state: Option<LatentCLogLogState>,
    /// Optional mixture link state.
    mixture_link_state: Option<MixtureLinkState>,
    /// Optional SAS link state.
    sas_link_state: Option<SasLinkState>,
    /// Response vector.
    y: Array1<f64>,
    /// Observation weights.
    weights: Array1<f64>,
    /// Dense design matrix.
    design: Array2<f64>,
    /// Second dense matrix; presumably `design` with the bounded columns
    /// zeroed — TODO confirm against the constructor.
    designzeroed: Array2<f64>,
    /// Offset vector.
    offset: Array1<f64>,
    /// Metadata of the bounded columns.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
3116
/// Per-observation likelihood quantities on the linear-predictor scale, as
/// produced by `evaluate_standard_familyobservations`.
#[derive(Clone)]
struct StandardFamilyObservationState {
    /// Linear predictor the state was evaluated at (a copy of the input).
    eta: Array1<f64>,
    /// Fitted mean per observation.
    mu: Array1<f64>,
    /// Derivative of the weighted log-likelihood term with respect to eta.
    score: Array1<f64>,
    /// Fisher (expected-information) weight per observation, floored at a
    /// small positive constant by the evaluator.
    fisherweight: Array1<f64>,
    /// Negative second derivative of the weighted log-likelihood term with
    /// respect to eta.
    neghessian_eta: Array1<f64>,
    /// Derivative of `neghessian_eta` with respect to eta.
    neghessian_eta_derivative: Array1<f64>,
    /// Sum over observations of the weighted log-likelihood terms.
    log_likelihood: f64,
}
3127
/// Log-odds of `z` after clamping it into `[1e-12, 1 - 1e-12]`, so inputs at
/// or beyond the unit interval's ends still yield a finite value.
fn bounded_logit(z: f64) -> f64 {
    const EDGE: f64 = 1e-12;
    let clamped = z.clamp(EDGE, 1.0 - EDGE);
    let odds = clamped / (1.0 - clamped);
    odds.ln()
}
3132
/// Logistic sigmoid evaluated so that `exp` is only ever called on a
/// non-positive argument, avoiding overflow for large `|theta|`.
fn stable_sigmoid(theta: f64) -> f64 {
    if theta >= 0.0 {
        // exp(-theta) <= 1 on this branch.
        return 1.0 / (1.0 + (-theta).exp());
    }
    // exp(theta) < 1 on this branch.
    let e = theta.exp();
    e / (1.0 + e)
}
3142
3143fn bounded_latent_to_user(theta: f64, min: f64, max: f64) -> (f64, f64, f64) {
3144 let z = stable_sigmoid(theta);
3145 let width = max - min;
3146 let beta = min + width * z;
3147 let db_dtheta = width * z * (1.0 - z);
3148 (beta, z, db_dtheta)
3149}
3150
3151fn bounded_user_to_latent(beta: f64, min: f64, max: f64) -> f64 {
3162 let width = max - min;
3163 if width <= 0.0 || !width.is_finite() {
3164 return 0.0;
3165 }
3166 let z = (beta - min) / width;
3167 bounded_logit(z)
3168}
3169
/// Describes one bounded coefficient for
/// `sample_bounded_latent_posterior_internal`.
#[derive(Debug, Clone, Copy)]
pub struct BoundedSampleColumn {
    /// Index of the bounded coefficient in the coefficient vector.
    pub col_idx: usize,
    /// Lower bound of the coefficient on the user scale.
    pub min: f64,
    /// Upper bound of the coefficient on the user scale.
    pub max: f64,
}
3182
/// Draws coefficient samples around a mode, sampling bounded coefficients on
/// their latent (logit) scale and mapping them back into their bounds.
///
/// Steps, as implemented below:
/// 1. Bounded entries of `beta_user` are converted to latent values and the
///    diagonal Jacobian `d beta / d theta` is recorded (floored at `1e-12`);
///    unbounded entries keep a Jacobian of `1`.
/// 2. The user-scale Hessian is rescaled to `J * H * J`.
/// 3. With `L` the lower Cholesky factor of that matrix, each draw is
///    `theta_mode + sqrt_cov_scale * L^{-T} * eps` with `eps` standard normal.
/// 4. Bounded columns of each draw are mapped back to the user scale.
///
/// The generator is seeded from `base_seed`, so the output is reproducible
/// for identical inputs.
///
/// # Returns
/// An `n_draws x p` array of draws, `p = beta_user.len()`.
///
/// # Errors
/// Fails when `user_hessian` is not `p x p`, when a bounded column index is
/// out of range, or when the Cholesky factorization fails.
pub fn sample_bounded_latent_posterior_internal(
    beta_user: &Array1<f64>,
    user_hessian: &Array2<f64>,
    bounded_columns: &[BoundedSampleColumn],
    n_draws: usize,
    sqrt_cov_scale: f64,
    base_seed: u64,
) -> Result<Array2<f64>, EstimationError> {
    let p = beta_user.len();
    if user_hessian.nrows() != p || user_hessian.ncols() != p {
        crate::bail_invalid_estim!(
            "bounded posterior sampling dimension mismatch: mode has {p} entries, user Hessian is {}x{}",
            user_hessian.nrows(),
            user_hessian.ncols()
        );
    }

    // Latent-scale mode; unbounded coordinates are identical on both scales.
    let mut theta_mode = beta_user.clone();
    // Diagonal of d beta / d theta (1 for unbounded coordinates).
    let mut jac_diag = Array1::<f64>::ones(p);
    for bc in bounded_columns {
        if bc.col_idx >= p {
            crate::bail_invalid_estim!(
                "bounded posterior sampling: bounded column index {} out of range for {p} coefficients",
                bc.col_idx
            );
        }
        let theta_i = bounded_user_to_latent(beta_user[bc.col_idx], bc.min, bc.max);
        let (_, _, db_dtheta) = bounded_latent_to_user(theta_i, bc.min, bc.max);
        theta_mode[bc.col_idx] = theta_i;
        // Floor keeps the rescaled Hessian from collapsing at a bound.
        jac_diag[bc.col_idx] = db_dtheta.max(1e-12);
    }

    // Row-then-column scaling yields J * H * J (diagonal entries scale by ji^2).
    let mut h_latent = user_hessian.clone();
    for i in 0..p {
        let ji = jac_diag[i];
        if ji != 1.0 {
            h_latent.row_mut(i).mapv_inplace(|v| v * ji);
            h_latent.column_mut(i).mapv_inplace(|v| v * ji);
        }
    }

    use gam_linalg::faer_ndarray::FaerCholesky as _;
    use rand::SeedableRng as _;
    let chol = h_latent.cholesky(faer::Side::Lower).map_err(|err| {
        EstimationError::InvalidInput(format!(
            "bounded posterior sampling: Cholesky of the latent penalized Hessian failed: {err:?}"
        ))
    })?;
    let l = chol.lower_triangular();

    let mut draws = Array2::<f64>::zeros((n_draws, p));
    // Scratch buffers reused across draws.
    let mut eps = Array1::<f64>::zeros(p);
    let mut delta = Array1::<f64>::zeros(p);
    let mut rng = rand::rngs::StdRng::seed_from_u64(base_seed);
    for k in 0..n_draws {
        for e in eps.iter_mut() {
            *e = standard_normal_draw(&mut rng);
        }
        // delta = L^{-T} eps, so delta has covariance (L L^T)^{-1}.
        solve_lower_transpose_into(&l, &eps, &mut delta);
        for i in 0..p {
            draws[(k, i)] = theta_mode[i] + sqrt_cov_scale * delta[i];
        }
        // Map bounded coordinates from the latent scale back into [min, max].
        for bc in bounded_columns {
            let (beta_draw, _, _) = bounded_latent_to_user(draws[(k, bc.col_idx)], bc.min, bc.max);
            draws[(k, bc.col_idx)] = beta_draw;
        }
    }

    Ok(draws)
}
3303
3304#[inline]
3307fn standard_normal_draw<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
3308 use rand::RngExt as _;
3309 let u1 = rng.random::<f64>().max(1e-16);
3310 let u2 = rng.random::<f64>();
3311 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
3312}
3313
3314fn solve_lower_transpose_into(l: &Array2<f64>, b: &Array1<f64>, out: &mut Array1<f64>) {
3318 let p = l.nrows();
3319 for i in (0..p).rev() {
3320 let mut acc = b[i];
3321 for j in (i + 1)..p {
3322 acc -= l[(j, i)] * out[j];
3323 }
3324 let diag = l[(i, i)];
3325 out[i] = if diag.abs() > 0.0 { acc / diag } else { 0.0 };
3326 }
3327}
3328
3329fn bounded_latent_derivatives(theta: f64, min: f64, max: f64) -> (f64, f64, f64, f64, f64) {
3330 let z = stable_sigmoid(theta);
3331 let width = max - min;
3332 let s = z * (1.0 - z);
3333 let beta = min + width * z;
3334 let db_dtheta = width * s;
3335 let d2b_dtheta2 = width * s * (1.0 - 2.0 * z);
3336 let d3b_dtheta3 = width * s * (1.0 - 6.0 * z + 6.0 * z * z);
3337 (beta, z, db_dtheta, d2b_dtheta2, d3b_dtheta3)
3338}
3339
3340fn bounded_prior_terms(theta: f64, prior: &BoundedCoefficientPriorSpec) -> (f64, f64, f64, f64) {
3341 let (a, b) = match prior {
3342 BoundedCoefficientPriorSpec::None => return (0.0, 0.0, 0.0, 0.0),
3344 BoundedCoefficientPriorSpec::Uniform => (1.0, 1.0),
3347 BoundedCoefficientPriorSpec::Beta { a, b } => (*a, *b),
3348 };
3349 let z = stable_sigmoid(theta).clamp(1e-12, 1.0 - 1e-12);
3350 let logp = a * z.ln() + b * (1.0 - z).ln();
3351 let grad = a - (a + b) * z;
3352 let neghess = (a + b) * z * (1.0 - z);
3353 let neghess_derivative = (a + b) * z * (1.0 - z) * (1.0 - 2.0 * z);
3354 (logp, grad, neghess, neghess_derivative)
3355}
3356
/// Chain-rule conversion of mean-scale log-likelihood derivatives into
/// eta-scale quantities for one observation.
///
/// Inputs: prior weight `w`; first/second/third derivatives of the
/// log-likelihood with respect to the mean (`lmu`, `lmumu`, `lmumumu`); the
/// variance function value `var`; and the inverse-link derivatives
/// `d1`, `d2`, `d3` with respect to eta.
///
/// Returns `(score, fisherweight, neghessian, neghessian_deriv)` where
/// `fisherweight = w·d1²/var` is floored at `mu_deriv_eps`.
#[inline]
fn glm_eta_observation_state(
    w: f64,
    lmu: f64,
    lmumu: f64,
    lmumumu: f64,
    var: f64,
    d1: f64,
    d2: f64,
    d3: f64,
    mu_deriv_eps: f64,
) -> (f64, f64, f64, f64) {
    // d²ℓ/dη² and d³ℓ/dη³ before weighting and sign flip.
    let second_order = lmumu * d1 * d1 + lmu * d2;
    let third_order = lmumumu * d1 * d1 * d1 + 3.0 * lmumu * d1 * d2 + lmu * d3;
    let expected_information = w * d1 * d1 / var;
    (
        w * lmu * d1,
        expected_information.max(mu_deriv_eps),
        -w * second_order,
        -w * third_order,
    )
}
3383
3384fn evaluate_standard_familyobservations(
3385 family: LikelihoodSpec,
3386 latent_cloglog_state: Option<&LatentCLogLogState>,
3387 mixture_link_state: Option<&MixtureLinkState>,
3388 sas_link_state: Option<&SasLinkState>,
3389 y: &Array1<f64>,
3390 weights: &Array1<f64>,
3391 eta: &Array1<f64>,
3392) -> Result<StandardFamilyObservationState, EstimationError> {
3393 const PROB_EPS: f64 = 1e-10;
3394 const MU_DERIV_EPS: f64 = 1e-12;
3395 let n = y.len();
3396 if weights.len() != n || eta.len() != n {
3397 crate::bail_invalid_estim!("bounded family observation size mismatch");
3398 }
3399
3400 let mut mu = Array1::<f64>::zeros(n);
3401 let mut score = Array1::<f64>::zeros(n);
3402 let mut fisherweight = Array1::<f64>::zeros(n);
3403 let mut neghessian_eta = Array1::<f64>::zeros(n);
3404 let mut neghessian_eta_derivative = Array1::<f64>::zeros(n);
3405 let mut log_likelihood = 0.0;
3406
3407 for i in 0..n {
3408 let w = weights[i].max(0.0);
3409 let yi = y[i];
3410 let eta_i = eta[i];
3411 match (&family.response, &family.link) {
3412 (ResponseFamily::Gaussian, _) => {
3413 let resid = yi - eta_i;
3414 mu[i] = eta_i;
3415 score[i] = w * resid;
3416 fisherweight[i] = w.max(MU_DERIV_EPS);
3417 neghessian_eta[i] = w;
3418 neghessian_eta_derivative[i] = 0.0;
3419 log_likelihood += -0.5 * w * resid * resid;
3420 }
3421 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
3422 let jet = logit_inverse_link_jet5(eta_i);
3423 mu[i] = jet.mu;
3424 score[i] = w * (yi - jet.mu);
3425 fisherweight[i] = jet.d1.max(MU_DERIV_EPS);
3426 neghessian_eta[i] = jet.d1;
3427 neghessian_eta_derivative[i] = jet.d2;
3428 let logmu = -gam_linalg::utils::stable_softplus(-eta_i);
3429 let log_one_minusmu = -gam_linalg::utils::stable_softplus(eta_i);
3430 log_likelihood += w * (yi * logmu + (1.0 - yi) * log_one_minusmu);
3431 }
3432 (ResponseFamily::Binomial, _) => {
3433 let inverse_link = if let Some(state) = latent_cloglog_state {
3434 Some(InverseLink::LatentCLogLog(*state))
3435 } else if let Some(state) = mixture_link_state {
3436 Some(InverseLink::Mixture(state.clone()))
3437 } else {
3438 sas_link_state.map(|state| {
3439 if family.is_binomial_beta_logistic() {
3440 InverseLink::BetaLogistic(*state)
3441 } else {
3442 InverseLink::Sas(*state)
3443 }
3444 })
3445 };
3446 let strategy_spec = LikelihoodSpec {
3447 response: family.response.clone(),
3448 link: inverse_link.clone().unwrap_or_else(|| family.link.clone()),
3449 };
3450 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3451 let mu_i_raw = jet.mu;
3452 let dmu_deta_raw = jet.d1;
3453 let mu_i: f64 = mu_i_raw.clamp(PROB_EPS, 1.0 - PROB_EPS);
3454 let dmu_deta = dmu_deta_raw.max(MU_DERIV_EPS);
3455 let d2mu_deta2 = jet.d2;
3456 let d3mu_deta3 = jet.d3;
3457 let var = (mu_i * (1.0 - mu_i)).max(PROB_EPS);
3458 let lmu = (yi - mu_i) / var;
3459 let lmumu = -(yi / (mu_i * mu_i)) - ((1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i)));
3460 let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i)
3461 - 2.0 * (1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i) * (1.0 - mu_i));
3462 mu[i] = mu_i;
3463 score[i] = w * lmu * dmu_deta;
3464 fisherweight[i] = (w * dmu_deta * dmu_deta / var).max(MU_DERIV_EPS);
3465 neghessian_eta[i] = -w * (lmumu * dmu_deta * dmu_deta + lmu * d2mu_deta2);
3466 neghessian_eta_derivative[i] = -w
3467 * (lmumumu * dmu_deta * dmu_deta * dmu_deta
3468 + 3.0 * lmumu * dmu_deta * d2mu_deta2
3469 + lmu * d3mu_deta3);
3470 log_likelihood += w * (yi * mu_i.ln() + (1.0 - yi) * (1.0 - mu_i).ln());
3471 }
3472 (ResponseFamily::Poisson, _) => {
3473 let strategy_spec = LikelihoodSpec {
3476 response: family.response.clone(),
3477 link: family.link.clone(),
3478 };
3479 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3480 let mu_i = jet.mu.max(PROB_EPS);
3481 let d1 = jet.d1.max(MU_DERIV_EPS);
3482 let var = mu_i;
3483 let lmu = yi / mu_i - 1.0;
3484 let lmumu = -yi / (mu_i * mu_i);
3485 let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i);
3486 let (s, f, nh, nhd) = glm_eta_observation_state(
3487 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3488 );
3489 mu[i] = mu_i;
3490 score[i] = s;
3491 fisherweight[i] = f;
3492 neghessian_eta[i] = nh;
3493 neghessian_eta_derivative[i] = nhd;
3494 log_likelihood += w * (yi * mu_i.ln() - mu_i);
3495 }
3496 (ResponseFamily::Tweedie { p }, _) => {
3497 let p = *p;
3502 let strategy_spec = LikelihoodSpec {
3503 response: family.response.clone(),
3504 link: family.link.clone(),
3505 };
3506 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3507 let mu_i = jet.mu.max(PROB_EPS);
3508 let d1 = jet.d1.max(MU_DERIV_EPS);
3509 let var = mu_i.powf(p);
3510 let resid = yi - mu_i;
3511 let lmu = resid / var;
3512 let lmumu = -mu_i.powf(-p) - p * resid * mu_i.powf(-p - 1.0);
3513 let lmumumu =
3514 2.0 * p * mu_i.powf(-p - 1.0) + p * (p + 1.0) * resid * mu_i.powf(-p - 2.0);
3515 let (s, f, nh, nhd) = glm_eta_observation_state(
3516 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3517 );
3518 mu[i] = mu_i;
3519 score[i] = s;
3520 fisherweight[i] = f;
3521 neghessian_eta[i] = nh;
3522 neghessian_eta_derivative[i] = nhd;
3523 log_likelihood += w
3525 * (yi * mu_i.powf(1.0 - p) / (1.0 - p) - mu_i.powf(2.0 - p) / (2.0 - p));
3526 }
3527 (ResponseFamily::NegativeBinomial { theta, .. }, _) => {
3528 let theta = (*theta).max(PROB_EPS);
3532 let strategy_spec = LikelihoodSpec {
3533 response: family.response.clone(),
3534 link: family.link.clone(),
3535 };
3536 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3537 let mu_i = jet.mu.max(PROB_EPS);
3538 let d1 = jet.d1.max(MU_DERIV_EPS);
3539 let mu_plus = mu_i + theta;
3540 let var = mu_i + mu_i * mu_i / theta;
3541 let lmu = yi / mu_i - (yi + theta) / mu_plus;
3542 let lmumu = -yi / (mu_i * mu_i) + (yi + theta) / (mu_plus * mu_plus);
3543 let lmumumu =
3544 2.0 * yi / (mu_i * mu_i * mu_i) - 2.0 * (yi + theta) / (mu_plus * mu_plus * mu_plus);
3545 let (s, f, nh, nhd) = glm_eta_observation_state(
3546 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3547 );
3548 mu[i] = mu_i;
3549 score[i] = s;
3550 fisherweight[i] = f;
3551 neghessian_eta[i] = nh;
3552 neghessian_eta_derivative[i] = nhd;
3553 log_likelihood += w * (yi * mu_i.ln() - (yi + theta) * mu_plus.ln());
3554 }
3555 (ResponseFamily::Beta { .. }, _) => {
3556 crate::bail_invalid_estim!(
3557 "bounded linear terms are not supported for BetaLogit fits"
3558 );
3559 }
3560 (ResponseFamily::Gamma, _) => {
3561 let strategy_spec = LikelihoodSpec {
3565 response: family.response.clone(),
3566 link: family.link.clone(),
3567 };
3568 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3569 let mu_i = jet.mu.max(PROB_EPS);
3570 let d1 = jet.d1.max(MU_DERIV_EPS);
3571 let var = mu_i * mu_i;
3572 let lmu = yi / (mu_i * mu_i) - 1.0 / mu_i;
3573 let lmumu = -2.0 * yi / (mu_i * mu_i * mu_i) + 1.0 / (mu_i * mu_i);
3574 let lmumumu =
3575 6.0 * yi / (mu_i * mu_i * mu_i * mu_i) - 2.0 / (mu_i * mu_i * mu_i);
3576 let (s, f, nh, nhd) = glm_eta_observation_state(
3577 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3578 );
3579 mu[i] = mu_i;
3580 score[i] = s;
3581 fisherweight[i] = f;
3582 neghessian_eta[i] = nh;
3583 neghessian_eta_derivative[i] = nhd;
3584 log_likelihood += w * (-(yi / mu_i) - mu_i.ln());
3585 }
3586 (ResponseFamily::RoystonParmar, _) => {
3587 crate::bail_invalid_estim!(
3588 "bounded linear terms are not supported for survival model fits"
3589 );
3590 }
3591 }
3592 }
3593
3594 Ok(StandardFamilyObservationState {
3595 eta: eta.clone(),
3596 mu,
3597 score,
3598 fisherweight,
3599 neghessian_eta,
3600 neghessian_eta_derivative,
3601 log_likelihood,
3602 })
3603}
3604
/// Kind of a spatial-adaptive hyperparameter: a log-lambda or a log-epsilon,
/// each for one of the three components (magnitude, gradient, curvature).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveHyperKind {
    LogLambdaMagnitude,
    LogLambdaGradient,
    LogLambdaCurvature,
    LogEpsilonMagnitude,
    LogEpsilonGradient,
    LogEpsilonCurvature,
}

impl SpatialAdaptiveHyperKind {
    /// Component slot of this kind: `0` magnitude, `1` gradient, `2` curvature.
    fn component_index(self) -> usize {
        use SpatialAdaptiveHyperKind::*;
        match self {
            LogLambdaMagnitude | LogEpsilonMagnitude => 0,
            LogLambdaGradient | LogEpsilonGradient => 1,
            LogLambdaCurvature | LogEpsilonCurvature => 2,
        }
    }

    /// `true` for the three log-lambda kinds.
    fn is_log_lambda(self) -> bool {
        match self {
            Self::LogLambdaMagnitude | Self::LogLambdaGradient | Self::LogLambdaCurvature => true,
            Self::LogEpsilonMagnitude
            | Self::LogEpsilonGradient
            | Self::LogEpsilonCurvature => false,
        }
    }

    /// `true` for the three log-epsilon kinds; every kind is exactly one of
    /// log-lambda or log-epsilon.
    fn is_log_epsilon(self) -> bool {
        !self.is_log_lambda()
    }
}
3645
/// One spatial-adaptive hyperparameter: which cache it belongs to and what
/// kind of parameter it is.
#[derive(Clone, Copy, Debug)]
struct SpatialAdaptiveHyperSpec {
    /// Index of the associated runtime cache (compared between two specs when
    /// classifying second-order interactions).
    cache_index: usize,
    /// Log-lambda or log-epsilon, and the component it acts on.
    kind: SpatialAdaptiveHyperKind,
}
3651
/// Classification of the explicit second-order interaction between two
/// spatial-adaptive hyperparameters, as returned by
/// `SpatialAdaptiveHyperSpec::explicit_second_order_kind`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveExplicitSecondOrderKind {
    /// Different components, or two log-lambdas from different caches.
    StructuralZero,
    /// Two log-lambdas of the same component and the same cache.
    LocalAlphaAlpha,
    /// One log-lambda and one log-epsilon of the same component.
    LocalAlphaEta,
    /// Two log-epsilons of the same component.
    SharedEtaEta,
}
3659
/// The three adaptive penalty components; `from_index` maps `0`, `1`, `2` to
/// `Magnitude`, `Gradient`, `Curvature` respectively.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum AdaptiveComponent {
    Magnitude,
    Gradient,
    Curvature,
}
3670
3671impl AdaptiveComponent {
3672 fn from_index(index: usize) -> Result<Self, String> {
3673 match index {
3674 0 => Ok(AdaptiveComponent::Magnitude),
3675 1 => Ok(AdaptiveComponent::Gradient),
3676 2 => Ok(AdaptiveComponent::Curvature),
3677 other => Err(SmoothError::invalid_index(format!(
3678 "invalid adaptive component index {}",
3679 other
3680 ))
3681 .into()),
3682 }
3683 }
3684}
3685
/// Tag selecting which hyperparameter derivative is requested.
///
/// NOTE(review): the consumers of this enum are outside this chunk; the
/// variant notes below are inferred from the names only — confirm at the
/// call sites.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HyperDerivativeKind {
    /// Presumably the derivative with respect to rho (log-lambda).
    Rho,
    /// Presumably the first derivative with respect to log-epsilon.
    LogEpsilonFirst,
    /// Presumably the second derivative with respect to log-epsilon.
    LogEpsilonSecond,
}
3699
/// Tag distinguishing a rho hyperparameter from a log-epsilon hyperparameter.
///
/// NOTE(review): usage is outside this chunk; what "drift" refers to should
/// be confirmed at the call sites.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HyperDriftKind {
    Rho,
    LogEpsilon,
}
3709
3710impl SpatialAdaptiveHyperSpec {
3711 fn component_index(self) -> usize {
3712 self.kind.component_index()
3713 }
3714
3715 fn explicit_second_order_kind(self, other: Self) -> SpatialAdaptiveExplicitSecondOrderKind {
3716 if self.component_index() != other.component_index() {
3717 return SpatialAdaptiveExplicitSecondOrderKind::StructuralZero;
3718 }
3719 match (
3720 self.kind.is_log_lambda(),
3721 other.kind.is_log_lambda(),
3722 self.kind.is_log_epsilon(),
3723 other.kind.is_log_epsilon(),
3724 ) {
3725 (true, true, false, false) if self.cache_index == other.cache_index => {
3726 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha
3727 }
3728 (true, false, false, true) | (false, true, true, false) => {
3729 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta
3730 }
3731 (false, false, true, true) => SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta,
3732 _ => SpatialAdaptiveExplicitSecondOrderKind::StructuralZero,
3733 }
3734 }
3735}
3736
/// Hyperparameters of one spatially adaptive term.
#[derive(Clone, Debug)]
struct SpatialAdaptiveTermHyperParams {
    /// Penalty multipliers, indexed `0` = magnitude, `1` = gradient,
    /// `2` = curvature.
    lambda: [f64; 3],
    /// Epsilon values passed to `SpatialPenaltyExactState::from_beta_local`
    /// when the exact penalty state is built.
    /// NOTE(review): presumably in the same component order as `lambda` — confirm.
    epsilon: [f64; 3],
}
3742
/// Result of evaluating the likelihood and penalties at one coefficient
/// vector: observation-level state plus value, gradient and Hessian of the
/// adaptive penalty and of the fixed quadratic penalty.
#[derive(Clone)]
struct SpatialAdaptiveExactEvaluation {
    /// Per-observation likelihood quantities.
    obs: StandardFamilyObservationState,
    /// Exact penalty state per adaptive term; indexed by runtime-cache position.
    adaptive_states: Vec<SpatialPenaltyExactState>,
    /// Value of the adaptive penalty.
    adaptive_penalty_value: f64,
    /// Gradient of the adaptive penalty with respect to the coefficients.
    adaptive_penaltygradient: Array1<f64>,
    /// Hessian of the adaptive penalty with respect to the coefficients.
    adaptive_penaltyhessian: Array2<f64>,
    /// Value of the fixed quadratic penalty.
    fixed_quadraticvalue: f64,
    /// Gradient of the fixed quadratic penalty.
    fixed_quadraticgradient: Array1<f64>,
    /// Hessian of the fixed quadratic penalty.
    fixed_quadratichessian: Array2<f64>,
}
3754
/// A cached exact evaluation paired with a coefficient vector.
#[derive(Clone)]
struct CachedSpatialAdaptiveExactEvaluation {
    /// Coefficient vector stored with the evaluation; presumably the point
    /// `eval` was computed at — confirm where the cache is filled.
    beta: Array1<f64>,
    /// Shared handle to the evaluation.
    eval: Arc<SpatialAdaptiveExactEvaluation>,
}
3760
3761impl SpatialAdaptiveExactEvaluation {
3762 fn total_penalty_value(&self) -> f64 {
3763 self.adaptive_penalty_value + self.fixed_quadraticvalue
3764 }
3765
3766 fn total_penaltygradient(&self) -> Array1<f64> {
3767 &self.adaptive_penaltygradient + &self.fixed_quadraticgradient
3768 }
3769
3770 fn total_penaltyhessian(&self) -> Array2<f64> {
3771 &self.adaptive_penaltyhessian + &self.fixed_quadratichessian
3772 }
3773
3774 fn totalobjectivehessian(&self, design: &Array2<f64>) -> Result<Array2<f64>, String> {
3775 let mut out = xt_diag_x_dense(design.view(), self.obs.neghessian_eta.view())?;
3776 out += &self.total_penaltyhessian();
3777 Ok(out)
3778 }
3779}
3780
/// Likelihood, data, penalty structure and hyperparameters for a fit with
/// spatially adaptive penalties. Data and structural members are held behind
/// `Arc`, so clones share them.
#[derive(Clone)]
struct SpatialAdaptiveExactFamily {
    /// Likelihood specification (response family + link).
    family: LikelihoodSpec,
    /// Optional latent-cloglog link state.
    latent_cloglog_state: Option<LatentCLogLogState>,
    /// Optional mixture link state.
    mixture_link_state: Option<MixtureLinkState>,
    /// Optional SAS link state.
    sas_link_state: Option<SasLinkState>,
    /// Response vector.
    y: Arc<Array1<f64>>,
    /// Observation weights.
    weights: Arc<Array1<f64>>,
    /// Dense design matrix; the linear predictor is `design * beta + offset`.
    design: Arc<Array2<f64>>,
    /// Offset added to the linear predictor.
    offset: Arc<Array1<f64>>,
    /// Optional linear inequality constraints.
    linear_constraints: Option<LinearInequalityConstraints>,
    /// One runtime cache per adaptive term.
    runtime_caches: Arc<Vec<SpatialOperatorRuntimeCache>>,
    /// Hyperparameters per adaptive term; looked up by runtime-cache position.
    adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
    /// Hessian `H` of the fixed quadratic penalty `0.5 * betaᵀ H beta`.
    fixed_quadratichessian: Arc<Array2<f64>>,
    /// Specifications of the adaptive hyperparameters.
    hyperspecs: Arc<Vec<SpatialAdaptiveHyperSpec>>,
    /// Slot holding at most one cached exact evaluation; reset to empty by
    /// `with_adaptive_params`.
    exact_eval_cache: Arc<Mutex<Option<CachedSpatialAdaptiveExactEvaluation>>>,
}
3798
3799impl SpatialAdaptiveExactFamily {
3800 fn with_adaptive_params(
3801 &self,
3802 adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
3803 fixed_quadratichessian: Arc<Array2<f64>>,
3804 ) -> Self {
3805 Self {
3806 family: self.family.clone(),
3807 latent_cloglog_state: self.latent_cloglog_state,
3808 mixture_link_state: self.mixture_link_state.clone(),
3809 sas_link_state: self.sas_link_state,
3810 y: self.y.clone(),
3811 weights: self.weights.clone(),
3812 design: self.design.clone(),
3813 offset: self.offset.clone(),
3814 linear_constraints: self.linear_constraints.clone(),
3815 runtime_caches: self.runtime_caches.clone(),
3816 adaptive_params,
3817 fixed_quadratichessian,
3818 hyperspecs: self.hyperspecs.clone(),
3819 exact_eval_cache: Arc::new(Mutex::new(None)),
3820 }
3821 }
3822
3823 fn total_eta(&self, beta: &Array1<f64>) -> Array1<f64> {
3824 gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), beta) + self.offset.as_ref()
3825 }
3826
3827 fn fixed_quadratic_terms(&self, beta: &Array1<f64>) -> (f64, Array1<f64>) {
3828 let grad = self.fixed_quadratichessian.dot(beta);
3829 let value = 0.5 * beta.dot(&grad);
3830 (value, grad)
3831 }
3832
3833 fn adaptive_penalty_value_only(&self, beta: &Array1<f64>) -> Result<f64, String> {
3834 let mut penalty_value = 0.0;
3835 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
3836 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
3837 format!(
3838 "missing adaptive parameter block for cache {}",
3839 cache.termname
3840 )
3841 })?;
3842 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
3843 let state =
3844 SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
3845 .map_err(|e| e.to_string())?;
3846 penalty_value += params.lambda[0] * state.magnitude.penalty_value();
3847 penalty_value += params.lambda[1] * state.gradient.penalty_value();
3848 penalty_value += params.lambda[2] * state.curvature.penalty_value();
3849 }
3850 Ok(penalty_value)
3851 }
3852
3853 fn zero_hyper_parts(&self) -> (Array1<f64>, Array2<f64>) {
3854 let total_dim = self.design.ncols();
3855 (
3856 Array1::<f64>::zeros(total_dim),
3857 Array2::<f64>::zeros((total_dim, total_dim)),
3858 )
3859 }
3860
3861 fn embed_local_hyper_parts(
3862 &self,
3863 coeff_range: &Range<usize>,
3864 local_grad: &Array1<f64>,
3865 local_hess: &Array2<f64>,
3866 ) -> (Array1<f64>, Array2<f64>) {
3867 let (mut beta_mixed, mut betahessian) = self.zero_hyper_parts();
3868 beta_mixed
3869 .slice_mut(s![coeff_range.clone()])
3870 .assign(local_grad);
3871 betahessian
3872 .slice_mut(s![coeff_range.clone(), coeff_range.clone()])
3873 .assign(local_hess);
3874 (beta_mixed, betahessian)
3875 }
3876
3877 fn embed_local_hyper_hessian(
3878 &self,
3879 coeff_range: &Range<usize>,
3880 local_hess: &Array2<f64>,
3881 ) -> Array2<f64> {
3882 let total_dim = self.design.ncols();
3883 let mut out = Array2::<f64>::zeros((total_dim, total_dim));
3884 out.slice_mut(s![coeff_range.clone(), coeff_range.clone()])
3885 .assign(local_hess);
3886 out
3887 }
3888
/// Evaluates one penalty component of one adaptive term for a given
/// hyperparameter-derivative kind.
///
/// `cache_idx` selects the term (runtime cache, hyperparameter block and
/// cached exact state all share this index). `component` picks magnitude,
/// gradient or curvature, and `derivative` picks which family of state
/// quantities is read.
///
/// Returns `(scalar, vector, matrix)`: the scalar objective contribution, the
/// coefficient-space vector and the coefficient-space matrix. Each is
/// multiplied by the component's `lambda`, and the vector and matrix are
/// zero-padded to the full coefficient dimension.
///
/// # Errors
/// Fails when `cache_idx` is out of bounds for any of the three per-term
/// collections, or when a grouped operator helper reports an error.
fn adaptive_block_eval(
    &self,
    eval: &SpatialAdaptiveExactEvaluation,
    cache_idx: usize,
    component: AdaptiveComponent,
    derivative: HyperDerivativeKind,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    let cache = self
        .runtime_caches
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
    let params = self
        .adaptive_params
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
    let state = eval
        .adaptive_states
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;

    // Every arm has the same shape: read (value, first-order, second-order)
    // quantities from the state, map them through the component's operator
    // (`d0`, `d1` or `d2`) into the term's local coefficient space, and scale
    // by that component's lambda.
    let (objective_local, beta_mixed_local, betahessian_local) = match component {
        AdaptiveComponent::Magnitude => {
            let lambda = params.lambda[0];
            let mag = &state.magnitude;
            let (objective, gradient_coeff, hessian_diag) = match derivative {
                // `Rho` reads the penalty's own value/gradient/Hessian
                // (presumably because d/d(log lambda) of lambda*P is lambda*P — confirm).
                HyperDerivativeKind::Rho => (
                    mag.penalty_value(),
                    mag.betagradient_coeff(),
                    mag.betahessian_diag(),
                ),
                HyperDerivativeKind::LogEpsilonFirst => (
                    mag.log_epsilon_gradient_terms().sum(),
                    mag.log_epsilon_betagradient_coeff(),
                    mag.log_epsilon_betahessian_diag(),
                ),
                HyperDerivativeKind::LogEpsilonSecond => (
                    mag.log_epsilon_hessian_terms().sum(),
                    mag.log_epsilon_beta_mixed_second_coeff(),
                    mag.log_epsilon_betahessian_second_diag(),
                ),
            };
            (
                lambda * objective,
                lambda * scalar_operatorgradient(&cache.d0, &gradient_coeff),
                lambda * scalar_operatorhessian(&cache.d0, &hessian_diag),
            )
        }
        AdaptiveComponent::Gradient => {
            let lambda = params.lambda[1];
            let grad = &state.gradient;
            let (objective, gradient_blocks, hessian_blocks) = match derivative {
                HyperDerivativeKind::Rho => (
                    grad.penalty_value(),
                    grad.betagradient_blocks(),
                    grad.betahessian_blocks(),
                ),
                HyperDerivativeKind::LogEpsilonFirst => (
                    grad.log_epsilon_gradient_terms().sum(),
                    grad.log_epsilon_betagradient_blocks(),
                    grad.log_epsilon_betahessian_blocks(),
                ),
                HyperDerivativeKind::LogEpsilonSecond => (
                    grad.log_epsilon_hessian_terms().sum(),
                    grad.log_epsilon_beta_mixed_second_blocks(),
                    grad.log_epsilon_betahessian_second_blocks(),
                ),
            };
            (
                lambda * objective,
                lambda
                    * grouped_operatorgradient(&cache.d1, cache.dimension, &gradient_blocks)
                        .map_err(|e| e.to_string())?,
                lambda
                    * grouped_operatorhessian(&cache.d1, cache.dimension, &hessian_blocks)
                        .map_err(|e| e.to_string())?,
            )
        }
        AdaptiveComponent::Curvature => {
            let lambda = params.lambda[2];
            // The curvature operator is grouped in blocks of dimension^2 rows.
            let group = cache.dimension * cache.dimension;
            let curv = &state.curvature;
            let (objective, gradient_blocks, hessian_blocks) = match derivative {
                HyperDerivativeKind::Rho => (
                    curv.penalty_value(),
                    curv.betagradient_blocks(),
                    curv.betahessian_blocks(),
                ),
                HyperDerivativeKind::LogEpsilonFirst => (
                    curv.log_epsilon_gradient_terms().sum(),
                    curv.log_epsilon_betagradient_blocks(),
                    curv.log_epsilon_betahessian_blocks(),
                ),
                HyperDerivativeKind::LogEpsilonSecond => (
                    curv.log_epsilon_hessian_terms().sum(),
                    curv.log_epsilon_beta_mixed_second_blocks(),
                    curv.log_epsilon_betahessian_second_blocks(),
                ),
            };
            (
                lambda * objective,
                lambda
                    * grouped_operatorgradient(&cache.d2, group, &gradient_blocks)
                        .map_err(|e| e.to_string())?,
                lambda
                    * grouped_operatorhessian(&cache.d2, group, &hessian_blocks)
                        .map_err(|e| e.to_string())?,
            )
        }
    };

    // Lift the term-local results into the full coefficient space.
    let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
        &cache.coeff_global_range,
        &beta_mixed_local,
        &betahessian_local,
    );
    Ok((objective_local, beta_mixed, betahessian))
}
4014
4015 fn adaptive_shared_log_epsilon_parts(
4016 &self,
4017 eval: &SpatialAdaptiveExactEvaluation,
4018 component: usize,
4019 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4020 self.adaptive_shared_block_eval(eval, component, HyperDerivativeKind::LogEpsilonFirst)
4026 }
4027
4028 fn adaptive_shared_log_epsilon_second_parts(
4029 &self,
4030 eval: &SpatialAdaptiveExactEvaluation,
4031 component: usize,
4032 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4033 self.adaptive_shared_block_eval(eval, component, HyperDerivativeKind::LogEpsilonSecond)
4039 }
4040
4041 fn adaptive_shared_block_eval(
4046 &self,
4047 eval: &SpatialAdaptiveExactEvaluation,
4048 component: usize,
4049 derivative: HyperDerivativeKind,
4050 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4051 let component = AdaptiveComponent::from_index(component)?;
4052 let (mut score, mut hessian) = self.zero_hyper_parts();
4053 let mut objective = 0.0;
4054 for cache_idx in 0..self.runtime_caches.len() {
4055 let (local_objective, local_score, local_hessian) =
4056 self.adaptive_block_eval(eval, cache_idx, component, derivative)?;
4057 objective += local_objective;
4058 score += &local_score;
4059 hessian += &local_hessian;
4060 }
4061 Ok((objective, score, hessian))
4062 }
4063
4064 fn adaptive_shared_log_epsilon_drift(
4065 &self,
4066 eval: &SpatialAdaptiveExactEvaluation,
4067 component: usize,
4068 direction: &Array1<f64>,
4069 ) -> Result<Array2<f64>, String> {
4070 let component = AdaptiveComponent::from_index(component)?;
4074 let total_dim = self.design.ncols();
4075 let mut total = Array2::<f64>::zeros((total_dim, total_dim));
4076 for cache_idx in 0..self.runtime_caches.len() {
4077 total += &self.adaptive_block_drift_eval(
4078 eval,
4079 cache_idx,
4080 component,
4081 HyperDriftKind::LogEpsilon,
4082 direction,
4083 )?;
4084 }
4085 Ok(total)
4086 }
4087
4088 fn adaptive_explicit_second_order_parts(
4089 &self,
4090 eval: &SpatialAdaptiveExactEvaluation,
4091 left: SpatialAdaptiveHyperSpec,
4092 right: SpatialAdaptiveHyperSpec,
4093 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4094 match left.explicit_second_order_kind(right) {
4103 SpatialAdaptiveExplicitSecondOrderKind::StructuralZero => {
4104 let (score, hessian) = self.zero_hyper_parts();
4105 Ok((0.0, score, hessian))
4106 }
4107 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha => self.adaptive_block_eval(
4108 eval,
4109 left.cache_index,
4110 AdaptiveComponent::from_index(left.component_index())?,
4111 HyperDerivativeKind::Rho,
4112 ),
4113 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta => {
4114 let local_alpha = if left.kind.is_log_lambda() {
4115 left
4116 } else {
4117 right
4118 };
4119 self.adaptive_block_eval(
4120 eval,
4121 local_alpha.cache_index,
4122 AdaptiveComponent::from_index(local_alpha.component_index())?,
4123 HyperDerivativeKind::LogEpsilonFirst,
4124 )
4125 }
4126 SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta => {
4127 self.adaptive_shared_log_epsilon_second_parts(eval, left.component_index())
4128 }
4129 }
4130 }
4131
/// Directional "drift" matrix of one penalty component of one adaptive term.
///
/// The local slice of `direction` is pushed through the component's operator
/// (`d0`, `d1` or `d2`). The state then supplies directional Hessian
/// quantities for the requested `drift` kind, which are mapped back through
/// the operator, scaled by the component's `lambda` and embedded in a
/// zero-padded matrix of the full coefficient dimension.
///
/// # Errors
/// Fails when `cache_idx` is out of bounds for any per-term collection, or
/// when a collocation/grouped-operator helper reports an error.
///
/// NOTE(review): `direction` is sliced with the term's global coefficient
/// range without a prior length check, so this relies on callers validating it.
fn adaptive_block_drift_eval(
    &self,
    eval: &SpatialAdaptiveExactEvaluation,
    cache_idx: usize,
    component: AdaptiveComponent,
    drift: HyperDriftKind,
    direction: &Array1<f64>,
) -> Result<Array2<f64>, String> {
    let cache = self
        .runtime_caches
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
    let params = self
        .adaptive_params
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
    let state = eval
        .adaptive_states
        .get(cache_idx)
        .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
    // Only this term's coefficients take part in its penalty.
    let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);

    let local_hessian = match component {
        AdaptiveComponent::Magnitude => {
            let d0_u = cache.d0.dot(&direction_local);
            let mag = &state.magnitude;
            let diag = match drift {
                HyperDriftKind::Rho => mag.directionalhessian_diag(&d0_u),
                HyperDriftKind::LogEpsilon => {
                    mag.log_epsilon_betahessian_directional_diag(&d0_u)
                }
            };
            params.lambda[0] * scalar_operatorhessian(&cache.d0, &diag)
        }
        AdaptiveComponent::Gradient => {
            let d1_u = cache.d1.dot(&direction_local);
            let direction_blocks = collocationgradient_blocks(&d1_u, cache.dimension)
                .map_err(|e| e.to_string())?;
            let grad = &state.gradient;
            let blocks = match drift {
                HyperDriftKind::Rho => grad.directionalhessian_blocks(&direction_blocks),
                HyperDriftKind::LogEpsilon => {
                    grad.log_epsilon_betahessian_directional_blocks(&direction_blocks)
                }
            };
            params.lambda[1]
                * grouped_operatorhessian(&cache.d1, cache.dimension, &blocks)
                    .map_err(|e| e.to_string())?
        }
        AdaptiveComponent::Curvature => {
            // The curvature operator is grouped in blocks of dimension^2 rows.
            let group = cache.dimension * cache.dimension;
            let d2_u = cache.d2.dot(&direction_local);
            let direction_blocks =
                collocationhessian_blocks(&d2_u, cache.dimension).map_err(|e| e.to_string())?;
            let curv = &state.curvature;
            let blocks = match drift {
                HyperDriftKind::Rho => curv.directionalhessian_blocks(&direction_blocks),
                HyperDriftKind::LogEpsilon => {
                    curv.log_epsilon_betahessian_directional_blocks(&direction_blocks)
                }
            };
            params.lambda[2]
                * grouped_operatorhessian(&cache.d2, group, &blocks)
                    .map_err(|e| e.to_string())?
        }
    };

    Ok(self.embed_local_hyper_hessian(&cache.coeff_global_range, &local_hessian))
}
4208
4209 fn adaptive_hyper_parts(
4210 &self,
4211 eval: &SpatialAdaptiveExactEvaluation,
4212 hyper: SpatialAdaptiveHyperSpec,
4213 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4214 match hyper.kind {
4215 SpatialAdaptiveHyperKind::LogLambdaMagnitude
4218 | SpatialAdaptiveHyperKind::LogLambdaGradient
4219 | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_eval(
4220 eval,
4221 hyper.cache_index,
4222 AdaptiveComponent::from_index(hyper.component_index())?,
4223 HyperDerivativeKind::Rho,
4224 ),
4225 SpatialAdaptiveHyperKind::LogEpsilonMagnitude
4227 | SpatialAdaptiveHyperKind::LogEpsilonGradient
4228 | SpatialAdaptiveHyperKind::LogEpsilonCurvature => {
4229 self.adaptive_shared_log_epsilon_parts(eval, hyper.component_index())
4230 }
4231 }
4232 }
4233
/// Computes the full exact evaluation at `beta`, bypassing the cache.
///
/// The result holds the likelihood observations at the linear predictor, the
/// per-term exact penalty states, the adaptive penalty's value, gradient and
/// Hessian in full coefficient space, and the fixed quadratic penalty's
/// value, gradient and Hessian.
///
/// # Errors
/// Fails when the observation evaluator fails, when a term has no matching
/// hyperparameter block, or when any penalty-state or grouped-operator
/// helper reports an error.
fn exact_evaluation_uncached(
    &self,
    beta: &Array1<f64>,
) -> Result<SpatialAdaptiveExactEvaluation, String> {
    let eta = self.total_eta(beta);
    let obs = evaluate_standard_familyobservations(
        self.family.clone(),
        self.latent_cloglog_state.as_ref(),
        self.mixture_link_state.as_ref(),
        self.sas_link_state.as_ref(),
        &self.y,
        &self.weights,
        &eta,
    )
    .map_err(|e| e.to_string())?;
    let p = beta.len();
    let mut penalty_value = 0.0;
    let mut penaltygradient = Array1::<f64>::zeros(p);
    let mut penaltyhessian = Array2::<f64>::zeros((p, p));
    let mut adaptive_states = Vec::with_capacity(self.runtime_caches.len());

    for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
        let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
            format!(
                "missing adaptive parameter block for cache {}",
                cache.termname
            )
        })?;
        // Each term's penalty only sees that term's coefficients.
        let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
        let state =
            SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
                .map_err(|e| e.to_string())?;

        // Unscaled local gradients for magnitude (g0), gradient (gg), curvature (gc).
        let g0 = scalar_operatorgradient(&cache.d0, &state.magnitude.betagradient_coeff());
        let gg = grouped_operatorgradient(
            &cache.d1,
            cache.dimension,
            &state.gradient.betagradient_blocks(),
        )
        .map_err(|e| e.to_string())?;
        let gc = grouped_operatorgradient(
            &cache.d2,
            cache.dimension * cache.dimension,
            &state.curvature.betagradient_blocks(),
        )
        .map_err(|e| e.to_string())?;
        // Unscaled local Hessians for the same three components.
        let h0 = scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag());
        let hg = grouped_operatorhessian(
            &cache.d1,
            cache.dimension,
            &state.gradient.betahessian_blocks(),
        )
        .map_err(|e| e.to_string())?;
        let hc = grouped_operatorhessian(
            &cache.d2,
            cache.dimension * cache.dimension,
            &state.curvature.betahessian_blocks(),
        )
        .map_err(|e| e.to_string())?;

        let lambda0 = params.lambda[0];
        let lambdag = params.lambda[1];
        let lambdac = params.lambda[2];

        penalty_value += lambda0 * state.magnitude.penalty_value();
        penalty_value += lambdag * state.gradient.penalty_value();
        penalty_value += lambdac * state.curvature.penalty_value();

        // Accumulate the lambda-weighted pieces into the term's slot of the
        // global gradient and its diagonal block of the global Hessian.
        let range = cache.coeff_global_range.clone();
        {
            let mut grad_local = penaltygradient.slice_mut(s![range.clone()]);
            grad_local += &(g0.mapv(|v| lambda0 * v));
            grad_local += &(gg.mapv(|v| lambdag * v));
            grad_local += &(gc.mapv(|v| lambdac * v));
        }
        {
            let mut h_local = penaltyhessian.slice_mut(s![range.clone(), range]);
            h_local += &h0.mapv(|v| lambda0 * v);
            h_local += &hg.mapv(|v| lambdag * v);
            h_local += &hc.mapv(|v| lambdac * v);
        }

        // Keep the state so hyperparameter derivatives can reuse it later.
        adaptive_states.push(state);
    }

    let (fixed_quadraticvalue, fixed_quadraticgradient) = self.fixed_quadratic_terms(beta);
    Ok(SpatialAdaptiveExactEvaluation {
        obs,
        adaptive_states,
        adaptive_penalty_value: penalty_value,
        adaptive_penaltygradient: penaltygradient,
        adaptive_penaltyhessian: penaltyhessian,
        fixed_quadraticvalue,
        fixed_quadraticgradient,
        // The evaluation owns its own copy of the fixed Hessian.
        fixed_quadratichessian: self.fixed_quadratichessian.as_ref().clone(),
    })
}
4331
4332 fn exact_evaluation(
4333 &self,
4334 beta: &Array1<f64>,
4335 ) -> Result<Arc<SpatialAdaptiveExactEvaluation>, String> {
4336 {
4337 let cache = self
4338 .exact_eval_cache
4339 .lock()
4340 .map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
4341 if let Some(cached) = cache.as_ref()
4342 && cached.beta.len() == beta.len()
4343 && cached
4344 .beta
4345 .iter()
4346 .zip(beta.iter())
4347 .all(|(&left, &right)| left == right)
4348 {
4349 return Ok(Arc::clone(&cached.eval));
4350 }
4351 }
4352
4353 let eval = Arc::new(self.exact_evaluation_uncached(beta)?);
4354 let mut cache = self
4355 .exact_eval_cache
4356 .lock()
4357 .map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
4358 *cache = Some(CachedSpatialAdaptiveExactEvaluation {
4359 beta: beta.clone(),
4360 eval: Arc::clone(&eval),
4361 });
4362 Ok(eval)
4363 }
4364
4365 fn exacthessian_directional_derivative_from_evaluation(
4366 &self,
4367 beta: &Array1<f64>,
4368 eval: &SpatialAdaptiveExactEvaluation,
4369 direction: &Array1<f64>,
4370 ) -> Result<Array2<f64>, String> {
4371 assert_eq!(
4372 beta.len(),
4373 direction.len(),
4374 "beta/direction length mismatch",
4375 );
4376 let d_eta = gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), direction);
4377 let mut total = xt_diag_x_dense(
4378 self.design.view(),
4379 (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
4380 )?;
4381 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4382 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4383 format!(
4384 "missing adaptive parameter block for cache {}",
4385 cache.termname
4386 )
4387 })?;
4388 let state = eval
4389 .adaptive_states
4390 .get(cache_idx)
4391 .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
4392 let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);
4393 let d0_u = cache.d0.dot(&direction_local);
4394 let d1_u = cache.d1.dot(&direction_local);
4395 let d2_u = cache.d2.dot(&direction_local);
4396 let h0 =
4397 scalar_operatorhessian(&cache.d0, &state.magnitude.directionalhessian_diag(&d0_u))
4398 .mapv(|v| params.lambda[0] * v);
4399 let hg = grouped_operatorhessian(
4400 &cache.d1,
4401 cache.dimension,
4402 &state.gradient.directionalhessian_blocks(
4403 &collocationgradient_blocks(&d1_u, cache.dimension)
4404 .map_err(|e| e.to_string())?,
4405 ),
4406 )
4407 .map_err(|e| e.to_string())?
4408 .mapv(|v| params.lambda[1] * v);
4409 let hc = grouped_operatorhessian(
4410 &cache.d2,
4411 cache.dimension * cache.dimension,
4412 &state.curvature.directionalhessian_blocks(
4413 &collocationhessian_blocks(&d2_u, cache.dimension)
4414 .map_err(|e| e.to_string())?,
4415 ),
4416 )
4417 .map_err(|e| e.to_string())?
4418 .mapv(|v| params.lambda[2] * v);
4419 let range = cache.coeff_global_range.clone();
4420 let mut local = total.slice_mut(s![range.clone(), range]);
4421 local += &h0;
4422 local += &hg;
4423 local += &hc;
4424 }
4425 Ok(total)
4426 }
4427
/// Second directional derivative of the penalized-objective Hessian along
/// `direction_u` and `direction_v`, penalty part only.
///
/// Returns `Ok(None)` when any entry of `obs.neghessian_eta_derivative` is
/// nonzero; in that case no likelihood contribution is computed here and the
/// caller gets no matrix. Otherwise returns the sum over adaptive terms of
/// the lambda-weighted second directional Hessians of the magnitude,
/// gradient and curvature penalties, each added into its term's diagonal block.
///
/// # Errors
/// Fails when a term lacks a hyperparameter block or cached state, or when a
/// collocation/grouped-operator helper reports an error.
///
/// NOTE(review): both directions are sliced with each term's global
/// coefficient range without a length check here; callers validate lengths.
fn exacthessian_second_directional_derivative_from_evaluation(
    &self,
    eval: &SpatialAdaptiveExactEvaluation,
    direction_u: &Array1<f64>,
    direction_v: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    let p = self.design.ncols();
    // Bail out unless the likelihood's Hessian weights have zero derivative.
    if eval.obs.neghessian_eta_derivative.iter().any(|&w| w != 0.0) {
        return Ok(None);
    }
    let mut total = Array2::<f64>::zeros((p, p));
    for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
        let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
            format!(
                "missing adaptive parameter block for cache {}",
                cache.termname
            )
        })?;
        let state = eval
            .adaptive_states
            .get(cache_idx)
            .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
        let u_local = direction_u.slice(s![cache.coeff_global_range.clone()]);
        let v_local = direction_v.slice(s![cache.coeff_global_range.clone()]);

        // Magnitude component: both directions through d0.
        let q0_u = cache.d0.dot(&u_local);
        let q0_v = cache.d0.dot(&v_local);
        let h0 = scalar_operatorhessian(
            &cache.d0,
            &state.magnitude.second_directionalhessian_diag(&q0_u, &q0_v),
        )
        .mapv(|x| params.lambda[0] * x);

        // Gradient component: both directions through d1, regrouped into blocks.
        let a1 = collocationgradient_blocks(&cache.d1.dot(&u_local), cache.dimension)
            .map_err(|e| e.to_string())?;
        let b1 = collocationgradient_blocks(&cache.d1.dot(&v_local), cache.dimension)
            .map_err(|e| e.to_string())?;
        let hg = grouped_operatorhessian(
            &cache.d1,
            cache.dimension,
            &state.gradient.second_directionalhessian_blocks(&a1, &b1),
        )
        .map_err(|e| e.to_string())?
        .mapv(|x| params.lambda[1] * x);

        // Curvature component: both directions through d2, regrouped into blocks.
        let a2 = collocationhessian_blocks(&cache.d2.dot(&u_local), cache.dimension)
            .map_err(|e| e.to_string())?;
        let b2 = collocationhessian_blocks(&cache.d2.dot(&v_local), cache.dimension)
            .map_err(|e| e.to_string())?;
        let hc = grouped_operatorhessian(
            &cache.d2,
            cache.dimension * cache.dimension,
            &state.curvature.second_directionalhessian_blocks(&a2, &b2),
        )
        .map_err(|e| e.to_string())?
        .mapv(|x| params.lambda[2] * x);

        let range = cache.coeff_global_range.clone();
        let mut local = total.slice_mut(s![range.clone(), range]);
        local += &h0;
        local += &hg;
        local += &hc;
    }
    Ok(Some(total))
}
4517}
4518
4519impl CustomFamily for SpatialAdaptiveExactFamily {
/// Always `true`: this family reports that the joint Jeffreys term is required.
fn joint_jeffreys_term_required(&self) -> bool {
    true
}
4526
4527 fn joint_jeffreys_information_with_specs(
4564 &self,
4565 block_states: &[ParameterBlockState],
4566 specs: &[ParameterBlockSpec],
4567 ) -> Result<Option<Array2<f64>>, String> {
4568 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4569 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4570 if spec.design.ncols() != beta.len() {
4571 return Err(SmoothError::dimension_mismatch(format!(
4572 "spatial adaptive Jeffreys information: spec design has {} columns, beta has {}",
4573 spec.design.ncols(),
4574 beta.len()
4575 ))
4576 .into());
4577 }
4578 let eval = self.exact_evaluation(beta)?;
4579 Ok(Some(xt_diag_x_dense(
4580 self.design.view(),
4581 eval.obs.neghessian_eta.view(),
4582 )?))
4583 }
4584
4585 fn joint_jeffreys_information_directional_derivative_with_specs(
4586 &self,
4587 block_states: &[ParameterBlockState],
4588 specs: &[ParameterBlockSpec],
4589 d_beta_flat: &Array1<f64>,
4590 ) -> Result<Option<Array2<f64>>, String> {
4591 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4597 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4598 if spec.design.ncols() != d_beta_flat.len() {
4599 return Err(SmoothError::dimension_mismatch(format!(
4600 "spatial adaptive Jeffreys directional derivative: spec design has {} columns, direction has {}",
4601 spec.design.ncols(),
4602 d_beta_flat.len()
4603 ))
4604 .into());
4605 }
4606 let eval = self.exact_evaluation(beta)?;
4607 let d_eta = gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), d_beta_flat);
4608 Ok(Some(xt_diag_x_dense(
4609 self.design.view(),
4610 (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
4611 )?))
4612 }
4613
4614 fn joint_jeffreys_information_second_directional_derivative_with_specs(
4615 &self,
4616 block_states: &[ParameterBlockState],
4617 specs: &[ParameterBlockSpec],
4618 d_beta_u_flat: &Array1<f64>,
4619 d_betav_flat: &Array1<f64>,
4620 ) -> Result<Option<Array2<f64>>, String> {
4621 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4628 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4629 if spec.design.ncols() != beta.len()
4630 || d_beta_u_flat.len() != beta.len()
4631 || d_betav_flat.len() != beta.len()
4632 {
4633 return Err(SmoothError::dimension_mismatch(format!(
4634 "spatial adaptive Jeffreys second-direction length mismatch: spec cols={}, dirs=({}, {}), expected {}",
4635 spec.design.ncols(),
4636 d_beta_u_flat.len(),
4637 d_betav_flat.len(),
4638 beta.len()
4639 ))
4640 .into());
4641 }
4642 let eval = self.exact_evaluation(beta)?;
4643 if eval.obs.neghessian_eta_derivative.iter().any(|&w| w != 0.0) {
4644 return Ok(None);
4645 }
4646 Ok(Some(Array2::<f64>::zeros((beta.len(), beta.len()))))
4647 }
4648
/// Always `false`. The Jeffreys information returned by this family is built
/// from the likelihood weights only, whereas the exact-Newton Hessian also
/// adds the penalty Hessian.
fn joint_jeffreys_information_matches_observed_hessian(&self) -> bool {
    false
}
4656
/// Always `false`: this family reports that its Jeffreys information does not
/// depend on the psi hyperparameters.
fn joint_jeffreys_information_depends_on_psi(&self) -> bool {
    false
}
4668
4669 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4670 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4671 let eval = self.exact_evaluation(beta)?;
4672 let mut gradient = fast_atv(&self.design, &eval.obs.score);
4673 gradient -= &eval.total_penaltygradient();
4674 let mut hessian = xt_diag_x_dense(self.design.view(), eval.obs.neghessian_eta.view())?;
4675 hessian += &eval.total_penaltyhessian();
4676 Ok(FamilyEvaluation {
4677 log_likelihood: eval.obs.log_likelihood - eval.total_penalty_value(),
4678 blockworking_sets: vec![BlockWorkingSet::ExactNewton {
4679 gradient,
4680 hessian: SymmetricMatrix::Dense(hessian),
4681 }],
4682 })
4683 }
4684
4685 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4686 let state = expect_single_block_state(block_states, "spatial adaptive exact family")?;
4687 let beta = &state.beta;
4688 let obs = evaluate_standard_familyobservations(
4689 self.family.clone(),
4690 self.latent_cloglog_state.as_ref(),
4691 self.mixture_link_state.as_ref(),
4692 self.sas_link_state.as_ref(),
4693 &self.y,
4694 &self.weights,
4695 &state.eta,
4696 )
4697 .map_err(|e| e.to_string())?;
4698 let adaptive_penalty = self.adaptive_penalty_value_only(beta)?;
4699 let (fixed_quadratic, _) = self.fixed_quadratic_terms(beta);
4700 Ok(obs.log_likelihood - adaptive_penalty - fixed_quadratic)
4701 }
4702
/// Selects the `StrictPseudoLaplace` outer objective for this family.
fn exact_newton_outerobjective(&self) -> ExactNewtonOuterObjective {
    ExactNewtonOuterObjective::StrictPseudoLaplace
}
4706
4707 fn exact_newton_joint_hessian(
4708 &self,
4709 block_states: &[ParameterBlockState],
4710 ) -> Result<Option<Array2<f64>>, String> {
4711 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4712 let eval = self.exact_evaluation(beta)?;
4713 Ok(Some(eval.totalobjectivehessian(&self.design)?))
4714 }
4715
4716 fn exact_newton_hessian_directional_derivative(
4717 &self,
4718 block_states: &[ParameterBlockState],
4719 block_idx: usize,
4720 d_beta: &Array1<f64>,
4721 ) -> Result<Option<Array2<f64>>, String> {
4722 expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
4723 self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
4724 }
4725
4726 fn exact_newton_joint_hessian_directional_derivative(
4727 &self,
4728 block_states: &[ParameterBlockState],
4729 d_beta_flat: &Array1<f64>,
4730 ) -> Result<Option<Array2<f64>>, String> {
4731 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4732 if d_beta_flat.len() != beta.len() {
4733 return Err(SmoothError::dimension_mismatch(format!(
4734 "spatial adaptive exact family direction length mismatch: got {}, expected {}",
4735 d_beta_flat.len(),
4736 beta.len()
4737 ))
4738 .into());
4739 }
4740 let eval = self.exact_evaluation(beta)?;
4741 Ok(Some(
4742 self.exacthessian_directional_derivative_from_evaluation(beta, &eval, d_beta_flat)?,
4743 ))
4744 }
4745
4746 fn exact_newton_joint_hessiansecond_directional_derivative(
4747 &self,
4748 block_states: &[ParameterBlockState],
4749 d_beta_u_flat: &Array1<f64>,
4750 d_betav_flat: &Array1<f64>,
4751 ) -> Result<Option<Array2<f64>>, String> {
4752 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4753 if d_beta_u_flat.len() != beta.len() || d_betav_flat.len() != beta.len() {
4754 return Err(SmoothError::dimension_mismatch(format!(
4755 "spatial adaptive exact family second-direction length mismatch: got ({}, {}), expected {}",
4756 d_beta_u_flat.len(),
4757 d_betav_flat.len(),
4758 beta.len()
4759 ))
4760 .into());
4761 }
4762 let eval = self.exact_evaluation(beta)?;
4763 self.exacthessian_second_directional_derivative_from_evaluation(
4764 &eval,
4765 d_beta_u_flat,
4766 d_betav_flat,
4767 )
4768 }
4769
4770 fn block_linear_constraints(
4771 &self,
4772 block_states: &[ParameterBlockState],
4773 block_idx: usize,
4774 block_spec: &ParameterBlockSpec,
4775 ) -> Result<Option<LinearInequalityConstraints>, String> {
4776 assert!(!block_states.is_empty(), "block_states must be non-empty");
4777 assert!(
4778 !block_spec.name.is_empty(),
4779 "block spec name must be non-empty",
4780 );
4781 expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
4782 Ok(self.linear_constraints.clone())
4783 }
4784
4785 fn exact_newton_joint_psi_terms(
4786 &self,
4787 block_states: &[ParameterBlockState],
4788 specs: &[ParameterBlockSpec],
4789 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4790 psi_index: usize,
4791 ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
4792 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4793 return Err(SmoothError::dimension_mismatch(format!(
4794 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4795 block_states.len(),
4796 specs.len(),
4797 derivative_blocks.len()
4798 ))
4799 .into());
4800 }
4801 derivative_blocks[0]
4802 .get(psi_index)
4803 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4804 let hyper = self
4805 .hyperspecs
4806 .get(psi_index)
4807 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4808 let beta = &block_states[0].beta;
4809 let eval = self.exact_evaluation(beta)?;
4810 let (direct, beta_mixed, betahessian_explicit) =
4811 self.adaptive_hyper_parts(&eval, *hyper)?;
4812
4813 Ok(Some(ExactNewtonJointPsiTerms {
4834 objective_psi: direct,
4835 score_psi: beta_mixed,
4836 hessian_psi: betahessian_explicit,
4837 hessian_psi_operator: None,
4838 }))
4839 }
4840
4841 fn exact_newton_joint_psisecond_order_terms(
4842 &self,
4843 block_states: &[ParameterBlockState],
4844 specs: &[ParameterBlockSpec],
4845 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4846 psi_i: usize,
4847 psi_j: usize,
4848 ) -> Result<Option<gam_problem::ExactNewtonJointPsiSecondOrderTerms>, String> {
4849 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4850 return Err(SmoothError::dimension_mismatch(format!(
4851 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4852 block_states.len(),
4853 specs.len(),
4854 derivative_blocks.len()
4855 ))
4856 .into());
4857 }
4858 derivative_blocks[0]
4859 .get(psi_i)
4860 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
4861 derivative_blocks[0]
4862 .get(psi_j)
4863 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
4864 let hyper_i = self
4865 .hyperspecs
4866 .get(psi_i)
4867 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
4868 let hyper_j = self
4869 .hyperspecs
4870 .get(psi_j)
4871 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
4872 let beta = &block_states[0].beta;
4873 let eval = self.exact_evaluation(beta)?;
4874 let (objective_psi_psi, score_psi_psi, hessian_psi_psi) =
4875 self.adaptive_explicit_second_order_parts(&eval, *hyper_i, *hyper_j)?;
4876
4877 Ok(Some(
4878 gam_problem::ExactNewtonJointPsiSecondOrderTerms {
4879 objective_psi_psi,
4880 score_psi_psi,
4881 hessian_psi_psi,
4882 hessian_psi_psi_operator: None,
4883 },
4884 ))
4885 }
4886
4887 fn exact_newton_joint_psihessian_directional_derivative(
4888 &self,
4889 block_states: &[ParameterBlockState],
4890 specs: &[ParameterBlockSpec],
4891 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4892 psi_index: usize,
4893 direction: &Array1<f64>,
4894 ) -> Result<Option<Array2<f64>>, String> {
4895 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4896 return Err(SmoothError::dimension_mismatch(format!(
4897 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4898 block_states.len(),
4899 specs.len(),
4900 derivative_blocks.len()
4901 ))
4902 .into());
4903 }
4904 let beta = &block_states[0].beta;
4905 if direction.len() != beta.len() {
4906 return Err(SmoothError::dimension_mismatch(format!(
4907 "spatial adaptive exact family direction length mismatch: got {}, expected {}",
4908 direction.len(),
4909 beta.len()
4910 ))
4911 .into());
4912 }
4913 derivative_blocks[0]
4914 .get(psi_index)
4915 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4916 let hyper = self
4917 .hyperspecs
4918 .get(psi_index)
4919 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4920 let eval = self.exact_evaluation(beta)?;
4921 let drift = match hyper.kind {
4922 SpatialAdaptiveHyperKind::LogLambdaMagnitude
4923 | SpatialAdaptiveHyperKind::LogLambdaGradient
4924 | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_drift_eval(
4925 &eval,
4926 hyper.cache_index,
4927 AdaptiveComponent::from_index(hyper.kind.component_index())?,
4928 HyperDriftKind::Rho,
4929 direction,
4930 )?,
4931 SpatialAdaptiveHyperKind::LogEpsilonMagnitude
4932 | SpatialAdaptiveHyperKind::LogEpsilonGradient
4933 | SpatialAdaptiveHyperKind::LogEpsilonCurvature => self
4934 .adaptive_shared_log_epsilon_drift(
4935 &eval,
4936 hyper.kind.component_index(),
4937 direction,
4938 )?,
4939 };
4940 Ok(Some(drift))
4941 }
4942}
4943
4944fn expect_single_block_state<'a>(
4945 block_states: &'a [ParameterBlockState],
4946 family_name: &str,
4947) -> Result<&'a ParameterBlockState, String> {
4948 crate::block_layout::block_count::validate_block_count::<SmoothError>(
4949 family_name,
4950 1,
4951 block_states.len(),
4952 )?;
4953 Ok(&block_states[0])
4954}
4955
4956fn expect_single_blockspec<'a>(
4957 specs: &'a [ParameterBlockSpec],
4958 family_name: &str,
4959) -> Result<&'a ParameterBlockSpec, String> {
4960 crate::block_layout::block_count::validate_block_count::<SmoothError>(
4961 family_name,
4962 1,
4963 specs.len(),
4964 )?;
4965 Ok(&specs[0])
4966}
4967
4968fn expect_block_idx_zero(block_idx: usize, family_name: &str, context: &str) -> Result<(), String> {
4969 if block_idx != 0 {
4970 return Err(SmoothError::invalid_index(format!(
4971 "{family_name} expects block_idx 0{context}, got {block_idx}"
4972 ))
4973 .into());
4974 }
4975 Ok::<(), _>(())
4976}
4977
4978impl BoundedLinearFamily {
4979 fn bounded_term_derivative_data(
4980 &self,
4981 latent_beta: &Array1<f64>,
4982 ) -> (
4983 Array1<f64>,
4984 Array1<f64>,
4985 Array1<f64>,
4986 Array1<f64>,
4987 Array1<f64>,
4988 ) {
4989 let p = latent_beta.len();
4990 let mut beta_user = latent_beta.clone();
4991 let mut jac_diag = Array1::<f64>::ones(p);
4992 let mut second_diag = Array1::<f64>::zeros(p);
4993 let mut third_diag = Array1::<f64>::zeros(p);
4994 let mut priorthird = Array1::<f64>::zeros(p);
4995 for term in &self.bounded_terms {
4996 let (beta, _, db_dtheta, d2b_dtheta2, d3b_dtheta3) =
4997 bounded_latent_derivatives(latent_beta[term.col_idx], term.min, term.max);
4998 beta_user[term.col_idx] = beta;
4999 jac_diag[term.col_idx] = db_dtheta;
5000 second_diag[term.col_idx] = d2b_dtheta2;
5001 third_diag[term.col_idx] = d3b_dtheta3;
5002 let (_, _, _, prior_neghess_derivative) =
5003 bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
5004 priorthird[term.col_idx] = prior_neghess_derivative;
5005 }
5006 (beta_user, jac_diag, second_diag, third_diag, priorthird)
5007 }
5008
5009 fn user_beta_and_jacobian(&self, latent_beta: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
5010 let (beta_user, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
5011 (beta_user, jac_diag)
5012 }
5013
5014 fn nonlinear_offset_from_latent(&self, latent_beta: &Array1<f64>) -> Array1<f64> {
5015 let mut offset = self.offset.clone();
5016 for term in &self.bounded_terms {
5017 let (beta, _, _) =
5018 bounded_latent_to_user(latent_beta[term.col_idx], term.min, term.max);
5019 offset.scaled_add(beta, &self.design.column(term.col_idx));
5020 }
5021 offset
5022 }
5023
5024 fn effective_design_for_latent(&self, jac_diag: &Array1<f64>) -> Array2<f64> {
5025 let mut x_eff = self.design.clone();
5026 for term in &self.bounded_terms {
5027 x_eff
5028 .column_mut(term.col_idx)
5029 .mapv_inplace(|v| v * jac_diag[term.col_idx]);
5030 }
5031 x_eff
5032 }
5033
5034 fn exacthessian_andgradient(
5035 &self,
5036 latent_beta: &Array1<f64>,
5037 ) -> Result<
5038 (
5039 StandardFamilyObservationState,
5040 Array2<f64>,
5041 Array1<f64>,
5042 f64,
5043 Array1<f64>,
5044 Array1<f64>,
5045 Array1<f64>,
5046 ),
5047 String,
5048 > {
5049 let (_, jac_diag, second_diag, third_diag, priorthird) =
5050 self.bounded_term_derivative_data(latent_beta);
5051 let x_eff = self.effective_design_for_latent(&jac_diag);
5052 let eta =
5053 self.designzeroed.dot(latent_beta) + self.nonlinear_offset_from_latent(latent_beta);
5054 let obs = evaluate_standard_familyobservations(
5055 self.family.clone(),
5056 self.latent_cloglog_state.as_ref(),
5057 self.mixture_link_state.as_ref(),
5058 self.sas_link_state.as_ref(),
5059 &self.y,
5060 &self.weights,
5061 &eta,
5062 )
5063 .map_err(|e| e.to_string())?;
5064
5065 let mut priorgrad = Array1::<f64>::zeros(latent_beta.len());
5066 let mut prior_neghess = Array2::<f64>::zeros((latent_beta.len(), latent_beta.len()));
5067 let mut prior_loglik = 0.0;
5068 for term in &self.bounded_terms {
5069 let (logp, grad, neghess, _) =
5070 bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
5071 prior_loglik += logp;
5072 priorgrad[term.col_idx] += grad;
5073 prior_neghess[[term.col_idx, term.col_idx]] += neghess;
5074 }
5075
5076 let mut hessian = xt_diag_x_dense(x_eff.view(), obs.neghessian_eta.view())?;
5077 let mut gradient = fast_atv(&x_eff, &obs.score);
5078 for term in &self.bounded_terms {
5079 let score_beta = self.design.column(term.col_idx).dot(&obs.score);
5080 hessian[[term.col_idx, term.col_idx]] -= score_beta * second_diag[term.col_idx];
5081 }
5082 hessian += &prior_neghess;
5083 gradient += &priorgrad;
5084
5085 Ok((
5086 obs,
5087 hessian,
5088 gradient,
5089 prior_loglik,
5090 second_diag,
5091 third_diag,
5092 priorthird,
5093 ))
5094 }
5095
5096 fn evaluation_from_latent(
5097 &self,
5098 latent_beta: &Array1<f64>,
5099 ) -> Result<
5100 (
5101 StandardFamilyObservationState,
5102 Array2<f64>,
5103 Array1<f64>,
5104 f64,
5105 ),
5106 String,
5107 > {
5108 let (obs, hessian, gradient, prior_loglik, _, _, _) =
5109 self.exacthessian_andgradient(latent_beta)?;
5110 Ok((obs, hessian, gradient, prior_loglik))
5111 }
5112}
5113
/// `CustomFamily` adapter for a single-block model in which some linear
/// coefficients are reparameterised through a bounded latent transform.
/// All entry points expect exactly one parameter block.
impl CustomFamily for BoundedLinearFamily {
    /// This family always requests the joint Jeffreys term.
    // NOTE(review): the rationale is not visible in this block — confirm with
    // the trait documentation.
    fn joint_jeffreys_term_required(&self) -> bool {
        true
    }

    /// Evaluates the family at the single block's latent coefficients.
    ///
    /// The reported log-likelihood is the observation log-likelihood plus the
    /// bounded-term prior log-density; the working set is an exact-Newton
    /// gradient with a dense Hessian.
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
        let (obs, hessian, gradient, prior_loglik) = self.evaluation_from_latent(latent_beta)?;
        Ok(FamilyEvaluation {
            log_likelihood: obs.log_likelihood + prior_loglik,
            blockworking_sets: vec![BlockWorkingSet::ExactNewton {
                gradient,
                hessian: SymmetricMatrix::Dense(hessian),
            }],
        })
    }

    /// Dense Hessian from `evaluation_from_latent` for the single block.
    fn exact_newton_joint_hessian(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<Option<Array2<f64>>, String> {
        let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
        let (_, hessian, _, _) = self.evaluation_from_latent(latent_beta)?;
        Ok(Some(hessian))
    }

    /// Per-block variant; only block 0 exists, so this validates the index
    /// and delegates to the joint directional derivative.
    fn exact_newton_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        d_beta: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        expect_block_idx_zero(block_idx, "bounded linear family", "")?;
        self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
    }

    /// Directional derivative of the Hessian assembled in
    /// `exacthessian_andgradient` along the latent direction `d_beta_flat`.
    ///
    /// Returns a dimension-mismatch error when the direction length differs
    /// from the latent coefficient length.
    fn exact_newton_joint_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
        if d_beta_flat.len() != latent_beta.len() {
            return Err(SmoothError::dimension_mismatch(format!(
                "bounded linear family directional derivative length mismatch: got {}, expected {}",
                d_beta_flat.len(),
                latent_beta.len()
            ))
            .into());
        }

        let (obs, _, _, _, second_diag, third_diag, priorthird) =
            self.exacthessian_andgradient(latent_beta)?;

        // NOTE(review): the derivative data is recomputed here only to recover
        // `jac_diag`, which `exacthessian_andgradient` does not return.
        let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
        let x_eff = self.effective_design_for_latent(&jac_diag);
        // Change in the linear predictor along the direction, and the induced
        // change in the per-observation weights.
        let deta = x_eff.dot(d_beta_flat);
        let d_neghess_eta = &obs.neghessian_eta_derivative * &deta;

        // Derivative of the effective design: only bounded columns move, each
        // by (second derivative of the map) * (direction entry) * raw column.
        let mut dx_eff = Array2::<f64>::zeros(x_eff.raw_dim());
        for term in &self.bounded_terms {
            let scale = second_diag[term.col_idx] * d_beta_flat[term.col_idx];
            if scale != 0.0 {
                let mut col = dx_eff.column_mut(term.col_idx);
                col.assign(&self.design.column(term.col_idx));
                col.mapv_inplace(|v| v * scale);
            }
        }

        // X_eff' diag(dW) X_eff ...
        let mut dhessian = xt_diag_x_dense(x_eff.view(), d_neghess_eta.view())?;
        // ... plus dX_eff' W X_eff + X_eff' W dX_eff, accumulated row by row.
        let mut wxdx = Array2::<f64>::zeros((x_eff.ncols(), x_eff.ncols()));
        for i in 0..x_eff.nrows() {
            let wi = obs.neghessian_eta[i];
            if wi == 0.0 {
                continue;
            }
            for a in 0..x_eff.ncols() {
                let xa = x_eff[[i, a]];
                for b in 0..x_eff.ncols() {
                    wxdx[[a, b]] += wi * (dx_eff[[i, a]] * x_eff[[i, b]] + xa * dx_eff[[i, b]]);
                }
            }
        }
        dhessian += &wxdx;

        // Derivative of the diagonal correction `score_beta * second_diag`
        // (product rule), plus the prior's contribution via `priorthird`.
        let d_score = -&obs.neghessian_eta * &deta;
        for term in &self.bounded_terms {
            let score_beta = self.design.column(term.col_idx).dot(&obs.score);
            let d_score_beta = self.design.column(term.col_idx).dot(&d_score);
            dhessian[[term.col_idx, term.col_idx]] -= d_score_beta * second_diag[term.col_idx]
                + score_beta * third_diag[term.col_idx] * d_beta_flat[term.col_idx];
            dhessian[[term.col_idx, term.col_idx]] +=
                priorthird[term.col_idx] * d_beta_flat[term.col_idx];
        }

        Ok(Some(dhessian))
    }

    /// Design and offset seen by the solver for the current state.
    ///
    /// With no block state yet, returns the zeroed design and the base offset.
    /// Otherwise the offset absorbs the bounded terms via
    /// `nonlinear_offset_from_latent`, and the zeroed design is returned
    /// provided its column count matches `spec.design`.
    fn block_geometry(
        &self,
        block_states: &[ParameterBlockState],
        spec: &ParameterBlockSpec,
    ) -> Result<(DesignMatrix, Array1<f64>), String> {
        if block_states.is_empty() {
            return Ok((
                DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
                    self.designzeroed.clone(),
                )),
                self.offset.clone(),
            ));
        }
        let offset = self.nonlinear_offset_from_latent(
            &expect_single_block_state(block_states, "bounded linear family")?.beta,
        );
        let x = if spec.design.ncols() == self.designzeroed.ncols() {
            self.designzeroed.clone()
        } else {
            return Err(SmoothError::dimension_mismatch(
                "bounded linear family design column mismatch",
            )
            .into());
        };
        Ok((
            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
            offset,
        ))
    }

    /// The geometry depends on the current coefficients (through the offset).
    fn block_geometry_is_dynamic(&self) -> bool {
        true
    }

    /// Directional derivative of the block geometry along `d_beta`.
    ///
    /// The design never changes (`d_design` is `None`); the offset moves by
    /// `jac_diag[col] * d_beta[col]` times the raw design column for each
    /// bounded term.
    fn block_geometry_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        spec: &ParameterBlockSpec,
        d_beta: &Array1<f64>,
    ) -> Result<Option<BlockGeometryDirectionalDerivative>, String> {
        expect_block_idx_zero(
            block_idx,
            "bounded linear family",
            " for geometry derivative",
        )?;
        expect_single_block_state(block_states, "bounded linear family")?;
        if d_beta.len() != spec.design.ncols() {
            return Err(SmoothError::dimension_mismatch(format!(
                "bounded linear family geometry derivative direction mismatch: got {}, expected {}",
                d_beta.len(),
                spec.design.ncols()
            ))
            .into());
        }
        let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(&block_states[0].beta);
        let mut d_offset = Array1::<f64>::zeros(self.offset.len());
        // Short-circuit: when no bounded term contributes, return the zero
        // offset derivative without walking the columns.
        let has_drift = self
            .bounded_terms
            .iter()
            .any(|term| jac_diag[term.col_idx] != 0.0 && d_beta[term.col_idx] != 0.0);
        if !has_drift {
            return Ok(Some(BlockGeometryDirectionalDerivative {
                d_design: None,
                d_offset,
            }));
        }
        for term in &self.bounded_terms {
            let col = term.col_idx;
            let drift = jac_diag[col] * d_beta[col];
            if drift != 0.0 {
                d_offset.scaled_add(drift, &self.design.column(col));
            }
        }
        Ok(Some(BlockGeometryDirectionalDerivative {
            d_design: None,
            d_offset,
        }))
    }
}
5295
/// Number of rows to process per chunk when streaming a weighted Gram
/// product over a dense design with `p` columns.
///
/// Aims for roughly 2 MiB of `f64` data per chunk and clamps the result to
/// the range `[512, 2048]`. `p == 0` is treated as one column. The byte
/// width of a row saturates instead of overflowing, so an absurdly large `p`
/// yields the minimum chunk size rather than a panic or a division by zero.
#[inline]
fn dense_diag_gram_chunkrows(p: usize) -> usize {
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 2048;
    const TARGET_BYTES: usize = 2 * 1024 * 1024;
    // `p.max(1)` keeps the divisor non-zero; `saturating_mul` keeps it from
    // wrapping back to zero for huge `p`.
    let bytes_per_row = p.max(1).saturating_mul(std::mem::size_of::<f64>());
    (TARGET_BYTES / bytes_per_row).clamp(MIN_ROWS, MAX_ROWS)
}
5304
5305fn xt_diag_x_dense(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
5306 if x.nrows() != w.len() {
5307 return Err(SmoothError::dimension_mismatch("xt_diag_x_dense row mismatch").into());
5308 }
5309 let (n, p) = x.dim();
5310 if n == 0 || p == 0 {
5311 return Ok(Array2::<f64>::zeros((p, p)));
5312 }
5313
5314 const STREAMING_BYTES_THRESHOLD: usize = 8 * 1024 * 1024;
5315 let dense_work_bytes = n
5316 .checked_mul(p)
5317 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
5318 .unwrap_or(usize::MAX);
5319 if dense_work_bytes <= STREAMING_BYTES_THRESHOLD {
5320 let mut weighted = x.to_owned();
5321 ndarray::Zip::from(weighted.rows_mut())
5322 .and(w)
5323 .par_for_each(|mut row, wi| row *= *wi);
5324 return Ok(fast_atb(&x, &weighted));
5325 }
5326
5327 let chunkrows = dense_diag_gram_chunkrows(p).min(n);
5328 let mut weighted_chunk = Array2::<f64>::zeros((chunkrows, p));
5329 let mut out = Array2::<f64>::zeros((p, p));
5330 for row_start in (0..n).step_by(chunkrows) {
5331 let rows = (n - row_start).min(chunkrows);
5332 let x_chunk = x.slice(s![row_start..row_start + rows, ..]);
5333 {
5334 let mut chunk = weighted_chunk.slice_mut(s![0..rows, ..]);
5335 for local_row in 0..rows {
5336 let scale = w[row_start + local_row];
5337 if scale == 0.0 {
5338 chunk.row_mut(local_row).fill(0.0);
5339 continue;
5340 }
5341 for col in 0..p {
5342 chunk[[local_row, col]] = x_chunk[[local_row, col]] * scale;
5343 }
5344 }
5345 }
5346 out += &fast_atb(&x_chunk, &weighted_chunk.slice(s![0..rows, ..]));
5347 }
5348 Ok(out)
5349}
5350
5351fn trace_of_dense_product(a: &Array2<f64>, b: &Array2<f64>) -> Result<f64, String> {
5352 if a.nrows() != a.ncols() || b.nrows() != b.ncols() || a.nrows() != b.nrows() {
5353 return Err(
5354 SmoothError::dimension_mismatch("trace_of_dense_product dimension mismatch").into(),
5355 );
5356 }
5357 let mut trace = 0.0;
5358 for i in 0..a.nrows() {
5359 for j in 0..a.ncols() {
5360 trace += a[[i, j]] * b[[j, i]];
5361 }
5362 }
5363 Ok(trace)
5364}
5365
5366fn exact_bounded_edf(
5367 penalties: &[PenaltySpec],
5368 lambdas: &Array1<f64>,
5369 latent_cov: &Array2<f64>,
5370) -> Result<(Vec<f64>, Vec<f64>, f64), EstimationError> {
5371 if penalties.len() != lambdas.len() {
5372 crate::bail_invalid_estim!(
5373 "bounded EDF penalty/lambda mismatch: {} penalties vs {} lambdas",
5374 penalties.len(),
5375 lambdas.len()
5376 );
5377 }
5378 if latent_cov.nrows() != latent_cov.ncols() {
5379 crate::bail_invalid_estim!("bounded EDF covariance must be square");
5380 }
5381
5382 let p = latent_cov.nrows();
5383 let mut s_lambda = Array2::<f64>::zeros((p, p));
5384 let mut edf_by_block = Vec::with_capacity(penalties.len());
5385 let mut penalty_block_trace = Vec::with_capacity(penalties.len());
5387 let mut trace_sum = 0.0;
5388
5389 for (k, ps) in penalties.iter().enumerate() {
5390 let lambda_k = lambdas[k];
5391 match ps {
5392 PenaltySpec::Block {
5393 local, col_range, ..
5394 } => {
5395 s_lambda
5396 .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
5397 .scaled_add(lambda_k, local);
5398 let penalty_rank =
5400 local
5401 .nrows()
5402 .saturating_sub(estimate_penalty_nullity(local).map_err(|e| {
5403 EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
5404 })?);
5405 let cov_block = latent_cov.slice(ndarray::s![col_range.clone(), col_range.clone()]);
5407 let trace_k = lambda_k
5408 * trace_of_dense_product(&cov_block.to_owned(), local)
5409 .map_err(EstimationError::InvalidInput)?;
5410 trace_sum += trace_k;
5411 penalty_block_trace.push(trace_k);
5412 let p_k = penalty_rank as f64;
5413 edf_by_block.push((p_k - trace_k).clamp(0.0, p_k));
5414 }
5415 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
5416 s_lambda.scaled_add(lambda_k, m);
5417 let penalty_rank = p.saturating_sub(estimate_penalty_nullity(m).map_err(|e| {
5418 EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
5419 })?);
5420 let trace_k = lambda_k
5421 * trace_of_dense_product(latent_cov, m)
5422 .map_err(EstimationError::InvalidInput)?;
5423 trace_sum += trace_k;
5424 penalty_block_trace.push(trace_k);
5425 let p_k = penalty_rank as f64;
5426 edf_by_block.push((p_k - trace_k).clamp(0.0, p_k));
5427 }
5428 }
5429 }
5430
5431 let nullity_total = estimate_penalty_nullity(&s_lambda)
5432 .map_err(|e| EstimationError::InvalidInput(format!("bounded EDF nullity failed: {e}")))?
5433 as f64;
5434 let edf_total = (p as f64 - trace_sum).clamp(nullity_total, p as f64);
5435 Ok((edf_by_block, penalty_block_trace, edf_total))
5436}
5437
5438fn symmetric_positive_definite_inverse_or_pseudo(
5450 precision: &Array2<f64>,
5451) -> Result<Array2<f64>, EstimationError> {
5452 use gam_linalg::faer_ndarray::FaerEigh;
5453 let p = precision.nrows();
5454 if precision.ncols() != p {
5455 crate::bail_invalid_estim!(
5456 "posterior precision inverse requires a square matrix, got {}x{}",
5457 precision.nrows(),
5458 precision.ncols()
5459 );
5460 }
5461 if p == 0 {
5462 return Ok(Array2::<f64>::zeros((0, 0)));
5463 }
5464 let symmetric = (precision + &precision.t().to_owned()) * 0.5;
5465 let (evals, evecs) = symmetric.eigh(faer::Side::Lower).map_err(|e| {
5466 EstimationError::InvalidInput(format!(
5467 "posterior precision eigendecomposition failed: {e}"
5468 ))
5469 })?;
5470 let max_abs_eval = evals.iter().fold(0.0_f64, |acc, &ev| acc.max(ev.abs()));
5471 let tol =
5472 (10.0 * f64::EPSILON * (p as f64) * (p as f64) * max_abs_eval).max(100.0 * f64::EPSILON);
5473 if let Some(&min_eval) = evals
5474 .iter()
5475 .filter(|&&ev| ev < -tol)
5476 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
5477 {
5478 crate::bail_invalid_estim!(
5479 "bounded posterior precision is non-PD at the converged optimum (min eigenvalue \
5480 {min_eval:.6e} < -tol={tol:.6e}); the reported mode is not a strict posterior \
5481 maximum, so a covariance would be meaningless"
5482 );
5483 }
5484 let mut scaled = evecs.clone();
5486 for (j, &ev) in evals.iter().enumerate() {
5487 let inv = if ev > tol { 1.0 / ev } else { 0.0 };
5488 scaled.column_mut(j).mapv_inplace(|v| v * inv);
5489 }
5490 let cov = scaled.dot(&evecs.t());
5491 Ok((&cov + &cov.t().to_owned()) * 0.5)
5492}
5493
5494fn transform_bounded_latent_precision_to_user_internal(
5495 latent_precision: &Array2<f64>,
5496 jac_diag: &Array1<f64>,
5497) -> Result<Array2<f64>, EstimationError> {
5498 let p = latent_precision.nrows();
5499 if latent_precision.ncols() != p || jac_diag.len() != p {
5500 crate::bail_invalid_estim!(
5501 "bounded precision transform dimension mismatch: precision is {}x{}, jacobian has {} entries",
5502 latent_precision.nrows(),
5503 latent_precision.ncols(),
5504 jac_diag.len()
5505 );
5506 }
5507 let mut out = latent_precision.clone();
5508 for i in 0..p {
5509 let scale = jac_diag[i];
5510 if !scale.is_finite() || scale <= 0.0 {
5511 crate::bail_invalid_estim!(
5512 "bounded precision transform requires a positive finite coefficient jacobian; column {i} has {scale}"
5513 );
5514 }
5515 if scale != 1.0 {
5516 out.row_mut(i).mapv_inplace(|v| v / scale);
5517 out.column_mut(i).mapv_inplace(|v| v / scale);
5518 }
5519 }
5520 Ok(out)
5521}
5522
/// Fits a term collection that contains at least one bounded linear term,
/// using `BoundedLinearFamily` through `fit_custom_family`, and packages the
/// result as a `FittedTermCollection`.
///
/// Steps, in order:
/// 1. Build the fit-scale design and penalties through `LinearFitConditioning`.
/// 2. Collect the bounded linear terms and their internal bounds.
/// 3. Run the custom-family fit on a design whose bounded columns are zeroed.
/// 4. Map the latent solution back to user coefficients and assemble the
///    penalized Hessian, optional covariances, EDF, deviance and dispersion.
///
/// # Errors
/// Returns an error when the design carries explicit linear constraints, when
/// a bounded term also requests `double_penalty`, when no bounded term is
/// present, when `heuristic_lambdas` has the wrong length, or when any of the
/// fitting / linear-algebra steps fails.
fn fit_bounded_term_collection_with_design(
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    heuristic_lambdas: Option<&[f64]>,
    family: LikelihoodSpec,
    options: &FitOptions,
) -> Result<FittedTermCollection, EstimationError> {
    // Linear terms without `double_penalty` are the columns handed to the
    // conditioning transform; linear columns start right after the intercept.
    let conditioning_cols: Vec<usize> = spec
        .linear_terms
        .iter()
        .enumerate()
        .filter_map(|(j, linear)| {
            (!linear.double_penalty).then_some(design.intercept_range.end + j)
        })
        .collect();
    let conditioning = LinearFitConditioning::from_columns(design, &conditioning_cols);
    let dense_design = design.design.to_dense_cow();
    let fit_design = conditioning.apply_to_design(&dense_design);
    let fit_penalties = conditioning
        .transform_blockwise_penalties_to_internal(&design.penalties, design.design.ncols());
    // NOTE(review): this rejection could run before the design is densified
    // and conditioned above, so unsupported inputs fail without that work.
    if design.linear_constraints.is_some() {
        crate::bail_invalid_estim!(
            "bounded() terms are not yet compatible with explicit linear constraints"
        );
    }
    // Gather bounded terms, translating their bounds to the conditioned scale.
    let mut bounded_terms = Vec::<BoundedLinearTermMeta>::new();
    for (j, term) in spec.linear_terms.iter().enumerate() {
        if term.double_penalty
            && matches!(
                term.coefficient_geometry,
                LinearCoefficientGeometry::Bounded { .. }
            )
        {
            crate::bail_invalid_estim!(
                "bounded linear term '{}' cannot also use double_penalty",
                term.name
            );
        }
        if let LinearCoefficientGeometry::Bounded { min, max, prior } =
            term.coefficient_geometry.clone()
        {
            let col_idx = design.intercept_range.end + j;
            let (min_internal, max_internal) = conditioning.internal_bounds_for(col_idx, min, max);
            bounded_terms.push(BoundedLinearTermMeta {
                col_idx,
                min: min_internal,
                max: max_internal,
                prior,
            });
        }
    }
    if bounded_terms.is_empty() {
        crate::bail_invalid_estim!("internal bounded fit path called with no bounded terms");
    }

    // Bounded columns are zeroed in the solver-facing design; the family adds
    // their contribution through its offset. Their latent start value is
    // `bounded_logit(0.5)` (presumably the midpoint of the interval — confirm).
    let mut designzeroed = fit_design.clone();
    let mut initial_beta = Array1::<f64>::zeros(fit_design.ncols());
    for term in &bounded_terms {
        designzeroed.column_mut(term.col_idx).fill(0.0);
        initial_beta[term.col_idx] = bounded_logit(0.5);
    }

    // NOTE(review): `heuristic_lambdas` is used verbatim as the initial
    // log-lambdas — confirm callers pass values on the log scale.
    let initial_log_lambdas = heuristic_lambdas
        .map(|vals| Array1::from_vec(vals.to_vec()))
        .unwrap_or_else(|| Array1::zeros(fit_penalties.len()));
    if initial_log_lambdas.len() != fit_penalties.len() {
        crate::bail_invalid_estim!(
            "heuristic lambda length mismatch for bounded model: got {}, expected {}",
            initial_log_lambdas.len(),
            fit_penalties.len()
        );
    }

    let is_beta_logistic = family.is_binomial_beta_logistic();
    let family_adapter = BoundedLinearFamily {
        family: family.clone(),
        latent_cloglog_state: options.latent_cloglog,
        // NOTE(review): `.clone().as_ref()` clones only to borrow; borrowing
        // `options.mixture_link` directly looks equivalent.
        mixture_link_state: options
            .mixture_link
            .clone()
            .as_ref()
            .map(state_fromspec)
            .transpose()
            .map_err(EstimationError::InvalidInput)?,
        // The SAS link spec is interpreted as a beta-logistic spec when the
        // family is binomial beta-logistic.
        sas_link_state: options
            .sas_link
            .map(|spec| {
                if is_beta_logistic {
                    state_from_beta_logisticspec(spec)
                } else {
                    state_from_sasspec(spec)
                }
            })
            .transpose()
            .map_err(EstimationError::InvalidInput)?,
        y: y.to_owned(),
        weights: weights.to_owned(),
        design: fit_design.clone(),
        designzeroed: designzeroed.clone(),
        offset: offset.to_owned(),
        bounded_terms: bounded_terms.clone(),
    };
    let blockspec = ParameterBlockSpec {
        name: "eta".to_string(),
        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(designzeroed)),
        offset: offset.to_owned(),
        // Re-express the fit-scale penalties in the solver's penalty type.
        penalties: fit_penalties
            .iter()
            .map(|ps| match ps {
                PenaltySpec::Block {
                    local, col_range, ..
                } => PenaltyMatrix::Blockwise {
                    local: local.clone(),
                    col_range: col_range.clone(),
                    total_dim: design.design.ncols(),
                },
                PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
                    PenaltyMatrix::Dense(m.clone())
                }
            })
            .collect(),
        nullspace_dims: design.nullspace_dims.clone(),
        initial_log_lambdas,
        initial_beta: Some(initial_beta),
        gauge_priority: 100,
        // The callback is built from the full (non-zeroed) fit design and the
        // bounded-term metadata.
        jacobian_callback: Some(Arc::new(BoundedEffectiveJacobian {
            design: fit_design.clone(),
            bounded_terms: bounded_terms.clone(),
        })),
        stacked_design: None,
        stacked_offset: None,
    };
    let fit = fit_custom_family(
        &family_adapter,
        &[blockspec],
        &BlockwiseFitOptions {
            inner_max_cycles: options.max_iter,
            inner_tol: options.tol,
            outer_max_iter: options.max_iter,
            outer_tol: options.tol,
            // Covariances are derived below from the exact latent precision.
            compute_covariance: false,
            ..BlockwiseFitOptions::default()
        },
    )
    .map_err(EstimationError::CustomFamily)?;

    // Latent solution -> conditioned user scale -> original user scale.
    let latent_beta = fit.block_states[0].beta.clone();
    let (beta_user_internal, jac_diag) = family_adapter.user_beta_and_jacobian(&latent_beta);
    let beta_user = conditioning.backtransform_beta(&beta_user_internal);

    let (eta_state, h_data, _, _) = family_adapter
        .evaluation_from_latent(&latent_beta)
        .map_err(EstimationError::InvalidInput)?;
    // Weighted sum of the fit-scale penalties at the fitted lambdas.
    let p_fit = fit_design.ncols();
    let mut s_lambda_internal = Array2::<f64>::zeros((p_fit, p_fit));
    for (k, penalty) in fit_penalties.iter().enumerate() {
        match penalty {
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                s_lambda_internal
                    .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
                    .scaled_add(fit.lambdas[k], local);
            }
            PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
                s_lambda_internal.scaled_add(fit.lambdas[k], m);
            }
        }
    }
    // Latent precision = family Hessian + penalty; then undo the coefficient
    // Jacobian and the conditioning to reach the original parameterisation.
    let mut latent_precision = h_data.clone();
    latent_precision += &s_lambda_internal;
    let user_precision_internal =
        transform_bounded_latent_precision_to_user_internal(&latent_precision, &jac_diag)?;
    let penalized_hessian =
        conditioning.transform_penalized_hessian_to_original(&user_precision_internal);

    // Both inverses are only computed when inference is requested.
    let beta_covariance_unscaled = if options.compute_inference {
        Some(symmetric_positive_definite_inverse_or_pseudo(
            &penalized_hessian,
        )?)
    } else {
        None
    };
    let latent_cov = if options.compute_inference {
        Some(symmetric_positive_definite_inverse_or_pseudo(
            &latent_precision,
        )?)
    } else {
        None
    };
    let s_lambda_original = weighted_blockwise_penalty_sum(
        &design.penalties,
        fit.lambdas.as_slice().unwrap(),
        design.design.ncols(),
    );
    // Quadratic penalty evaluated with the original penalties and user beta.
    let penalty_term = beta_user.dot(&s_lambda_original.dot(&beta_user));
    // Gaussian identity: weighted residual sum of squares (negative weights
    // count as zero); otherwise minus twice the observation log-likelihood.
    let deviance = if family.is_gaussian_identity() {
        y.iter()
            .zip(eta_state.mu.iter())
            .zip(weights.iter())
            .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
            .sum()
    } else {
        -2.0 * eta_state.log_likelihood
    };
    // Without inference there is no latent covariance, so EDF values are
    // reported as zeros.
    let (edf_by_block, penalty_block_trace, edf_total) = if let Some(cov) = latent_cov.as_ref() {
        exact_bounded_edf(&fit_penalties, &fit.lambdas, cov)?
    } else {
        (
            vec![0.0; fit_penalties.len()],
            vec![0.0; fit_penalties.len()],
            0.0,
        )
    };

    let glm_likelihood = gam_spec::GlmLikelihoodSpec::canonical(family.clone());
    // Residual standard deviation for the Gaussian identity family: the
    // denominator is n - EDF when inference ran, n otherwise, floored at 1.
    // Every other family uses 1.0.
    let standard_deviation = if family.is_gaussian_identity() {
        let denom = if options.compute_inference {
            (y.len() as f64 - edf_total).max(1.0)
        } else {
            (y.len() as f64).max(1.0)
        };
        (deviance / denom).sqrt()
    } else {
        1.0
    };
    let cov_scale = glm_likelihood
        .coefficient_covariance_scale(standard_deviation * standard_deviation)
        .max(f64::MIN_POSITIVE);
    let dispersion = gam_solve::estimate::dispersion_from_likelihood(&glm_likelihood, standard_deviation);
    let beta_covariance = beta_covariance_unscaled.map(|mut cov| {
        if cov_scale != 1.0 {
            cov.mapv_inplace(|v| v * cov_scale);
        }
        cov
    });
    // Standard errors from the covariance diagonal, clipped at zero first.
    let beta_standard_errors = beta_covariance
        .as_ref()
        .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));

    // NOTE(review): the working response below is built twice with identical
    // code (here and in `FitInference`); it could be computed once and cloned.
    let geometry = Some(gam_solve::estimate::FitGeometry {
        penalized_hessian: penalized_hessian.clone().into(),
        working_weights: eta_state.fisherweight.clone(),
        working_response: {
            // eta + score / weight, with the weight floored at 1e-12.
            let mut working_response = eta_state.eta.clone();
            for i in 0..working_response.len() {
                let wi = eta_state.fisherweight[i].max(1e-12);
                working_response[i] += eta_state.score[i] / wi;
            }
            working_response
        },
    });
    let max_abs_eta = eta_state
        .eta
        .iter()
        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    Ok(FittedTermCollection {
        fit: {
            // Lambdas are floored at 1e-300 before taking logs.
            let log_lambdas = fit.lambdas.mapv(|v| v.max(1e-300).ln());
            let inf = FitInference {
                edf_by_block,
                penalty_block_trace,
                edf_total,
                smoothing_correction: None,
                penalized_hessian: penalized_hessian.clone().into(),
                working_weights: eta_state.fisherweight.clone(),
                working_response: {
                    let mut working_response = eta_state.eta.clone();
                    for i in 0..working_response.len() {
                        let wi = eta_state.fisherweight[i].max(1e-12);
                        working_response[i] += eta_state.score[i] / wi;
                    }
                    working_response
                },
                reparam_qs: None,
                dispersion,
                beta_covariance: beta_covariance
                    .clone()
                    .map(gam_problem::dispersion_cov::PhiScaledCovariance::from),
                beta_standard_errors,
                beta_covariance_corrected: None,
                beta_standard_errors_corrected: None,
                beta_covariance_frequentist: None,
                coefficient_influence: None,
                weighted_gram: None,
                bias_correction_beta: None,
            };
            let covariance_conditional = beta_covariance;
            // A non-converged outer loop is reported as a stall at a valid
            // minimum rather than as a failure.
            let pirls_status_val = if fit.outer_converged {
                gam_solve::pirls::PirlsStatus::Converged
            } else {
                gam_solve::pirls::PirlsStatus::StalledAtValidMinimum
            };
            UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
                blocks: vec![gam_solve::estimate::FittedBlock {
                    beta: beta_user.clone(),
                    role: gam_problem::BlockRole::Mean,
                    edf: edf_total,
                    lambdas: fit.lambdas.clone(),
                }],
                log_lambdas,
                lambdas: fit.lambdas,
                likelihood_scale: family.default_scale_metadata(),
                likelihood_family: Some(family),
                log_likelihood_normalization:
                    gam_spec::LogLikelihoodNormalization::UserProvided,
                // Observation log-likelihood only; the bounded prior is not
                // included here.
                log_likelihood: eta_state.log_likelihood,
                deviance,
                // Both fields carry the fit's penalized objective.
                reml_score: fit.penalized_objective,
                stable_penalty_term: penalty_term,
                penalized_objective: fit.penalized_objective,
                used_device: false,
                outer_iterations: fit.outer_iterations,
                outer_converged: fit.outer_converged,
                outer_gradient_norm: fit.outer_gradient_norm,
                standard_deviation,
                covariance_conditional,
                covariance_corrected: None,
                inference: Some(inf),
                fitted_link: gam_solve::estimate::FittedLinkState::Standard(None),
                geometry,
                block_states: Vec::new(),
                pirls_status: pirls_status_val,
                max_abs_eta,
                constraint_kkt: None,
                artifacts: gam_solve::estimate::FitArtifacts {
                    pirls: None,
                    ..Default::default()
                },
                inner_cycles: 0,
            })?
        },
        design: design.clone(),
        adaptive_diagnostics: None,
    })
}
5926
5927fn enforce_term_constraint_feasibility(
5928 design: &TermCollectionDesign,
5929 fit: &UnifiedFitResult,
5930) -> Result<(), EstimationError> {
5931 const CONSTRAINT_FEASIBILITY_RAW_TOL: f64 = 1e-7;
5945 let tol = CONSTRAINT_FEASIBILITY_RAW_TOL;
5946 let smooth_start = design
5947 .design
5948 .ncols()
5949 .saturating_sub(design.smooth.total_smooth_cols());
5950 let mut violations: Vec<String> = Vec::new();
5951 for term in &design.smooth.terms {
5952 let gr = (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
5953 let beta_local = fit.beta.slice(s![gr.clone()]).to_owned();
5954 if let Some(lb) = term.lower_bounds_local.as_ref() {
5955 let mut worst = 0.0_f64;
5956 let mut worst_idx = 0usize;
5957 for i in 0..lb.len().min(beta_local.len()) {
5958 if lb[i].is_finite() {
5959 let viol = (lb[i] - beta_local[i]).max(0.0);
5960 if viol > worst {
5961 worst = viol;
5962 worst_idx = i;
5963 }
5964 }
5965 }
5966 if worst > tol {
5967 violations.push(format!(
5968 "term='{}' kind=lower-bound maxviolation={:.3e} coeff_index={}",
5969 term.name, worst, worst_idx
5970 ));
5971 }
5972 }
5973 if let Some(lin) = term.linear_constraints_local.as_ref() {
5974 let mut worst = 0.0_f64;
5975 let mut worstrow = 0usize;
5976 for i in 0..lin.a.nrows() {
5977 let norm = lin.a.row(i).dot(&lin.a.row(i)).sqrt();
5978 let inv = if norm > 0.0 { 1.0 / norm } else { 0.0 };
5979 let s = (lin.a.row(i).dot(&beta_local) - lin.b[i]) * inv;
5980 let viol = (-s).max(0.0);
5981 if viol > worst {
5982 worst = viol;
5983 worstrow = i;
5984 }
5985 }
5986 if worst > tol {
5987 violations.push(format!(
5988 "term='{}' kind=linear-inequality maxviolation={:.3e} row={}",
5989 term.name, worst, worstrow
5990 ));
5991 }
5992 }
5993 }
5994
5995 if !violations.is_empty() {
5996 let mut msg = format!(
5997 "constraint violation after fit ({} violating term constraints): {}",
5998 violations.len(),
5999 violations.join(" | ")
6000 );
6001 if let Some(kkt) = fit.constraint_kkt.as_ref() {
6002 msg.push_str(&format!(
6003 "; KKT[primal={:.3e}, dual={:.3e}, comp={:.3e}, stat={:.3e}]",
6004 kkt.primal_feasibility, kkt.dual_feasibility, kkt.complementarity, kkt.stationarity
6005 ));
6006 }
6007 return Err(EstimationError::ParameterConstraintViolation(msg));
6008 }
6009 Ok(())
6010}
6011
6012fn stratified_spatial_subsample(
6013 data: ArrayView2<'_, f64>,
6014 spec: &TermCollectionSpec,
6015 target_size: usize,
6016) -> Vec<usize> {
6017 use rand::SeedableRng;
6018 use rand::rngs::StdRng;
6019 use rand::seq::SliceRandom;
6020
6021 let n = data.nrows();
6022 if n <= target_size {
6023 return (0..n).collect();
6024 }
6025
6026 let spatial_cols: Option<Vec<usize>> =
6027 spec.smooth_terms.iter().find_map(|term| match &term.basis {
6028 SmoothBasisSpec::ThinPlate { feature_cols, .. }
6029 | SmoothBasisSpec::Matern { feature_cols, .. }
6030 | SmoothBasisSpec::Duchon { feature_cols, .. } => {
6031 if !feature_cols.is_empty() {
6032 Some(feature_cols.clone())
6033 } else {
6034 None
6035 }
6036 }
6037 _ => None,
6038 });
6039
6040 let cols = match spatial_cols {
6041 Some(c) if !c.is_empty() => c,
6042 _ => {
6043 let mut rng = StdRng::seed_from_u64(spatial_subsample_seed(data, &[], target_size));
6044 let mut indices: Vec<usize> = (0..n).collect();
6045 indices.shuffle(&mut rng);
6046 indices.truncate(target_size);
6047 indices.sort_unstable();
6048 return indices;
6049 }
6050 };
6051 let mut rng = StdRng::seed_from_u64(spatial_subsample_seed(data, &cols, target_size));
6052
6053 let d = cols.len();
6054 let mut mins = vec![f64::INFINITY; d];
6055 let mut maxs = vec![f64::NEG_INFINITY; d];
6056 for i in 0..n {
6057 for (ax, &col) in cols.iter().enumerate() {
6058 let v = data[[i, col]];
6059 if v < mins[ax] {
6060 mins[ax] = v;
6061 }
6062 if v > maxs[ax] {
6063 maxs[ax] = v;
6064 }
6065 }
6066 }
6067
6068 const TARGET_POINTS_PER_CELL: usize = 5;
6072 let total_cells_target = (target_size / TARGET_POINTS_PER_CELL).max(1);
6073 let cells_per_axis = ((total_cells_target as f64).powf(1.0 / d as f64)).ceil() as usize;
6074 let cells_per_axis = cells_per_axis.max(1);
6075
6076 let mut cell_members: std::collections::HashMap<Vec<usize>, Vec<usize>> =
6077 std::collections::HashMap::new();
6078 for i in 0..n {
6079 let mut cell_key = Vec::with_capacity(d);
6080 for (ax, &col) in cols.iter().enumerate() {
6081 let range = maxs[ax] - mins[ax];
6082 let cell = if range <= 0.0 {
6083 0
6084 } else {
6085 let frac = (data[[i, col]] - mins[ax]) / range;
6086 (frac * cells_per_axis as f64).floor() as usize
6087 };
6088 cell_key.push(cell.min(cells_per_axis - 1));
6089 }
6090 cell_members.entry(cell_key).or_default().push(i);
6091 }
6092
6093 let mut selected: Vec<usize> = Vec::with_capacity(target_size);
6094 let mut remaining_budget = target_size;
6095 let mut remaining_population = n;
6096
6097 let mut cells: Vec<(Vec<usize>, Vec<usize>)> = cell_members.into_iter().collect();
6098 cells.sort_by(|a, b| a.0.cmp(&b.0));
6099
6100 for (_, members) in &mut cells {
6101 if remaining_budget == 0 {
6102 break;
6103 }
6104 let alloc = ((members.len() as f64 / remaining_population as f64) * remaining_budget as f64)
6105 .round() as usize;
6106 let alloc = alloc.max(1).min(members.len()).min(remaining_budget);
6107 members.shuffle(&mut rng);
6108 selected.extend_from_slice(&members[..alloc]);
6109 remaining_budget = remaining_budget.saturating_sub(alloc);
6110 remaining_population = remaining_population.saturating_sub(members.len());
6111 }
6112
6113 if selected.len() > target_size {
6114 selected.shuffle(&mut rng);
6115 selected.truncate(target_size);
6116 }
6117
6118 selected.sort_unstable();
6119 selected
6120}
6121
6122fn spatial_subsample_seed(
6123 data: ArrayView2<'_, f64>,
6124 spatial_cols: &[usize],
6125 target_size: usize,
6126) -> u64 {
6127 let mut state = 0x5350_4154_4941_4C53_u64;
6128 spatial_seed_mix(&mut state, data.nrows() as u64);
6129 spatial_seed_mix(&mut state, data.ncols() as u64);
6130 spatial_seed_mix(&mut state, target_size as u64);
6131 spatial_seed_mix(&mut state, spatial_cols.len() as u64);
6132 for &col in spatial_cols {
6133 spatial_seed_mix(&mut state, col as u64);
6134 }
6135
6136 if data.nrows() > 0 {
6137 let mid = data.nrows() / 2;
6138 let last = data.nrows() - 1;
6139 for &row in &[0usize, mid, last] {
6140 for &col in spatial_cols {
6141 let value = data[[row, col]];
6142 spatial_seed_mix(&mut state, value.to_bits());
6143 }
6144 }
6145 }
6146 state
6147}
6148
6149#[inline]
6150fn spatial_seed_mix(state: &mut u64, value: u64) {
6151 let mut s = value.wrapping_add(*state);
6154 let z = gam_linalg::utils::splitmix64(&mut s);
6155 *state ^= z;
6156 *state = (*state).rotate_left(27).wrapping_mul(0x3C79_AC49_2BA7_B653);
6157}
6158
6159fn sampled_rows(data: ArrayView2<'_, f64>, indices: &[usize]) -> Array2<f64> {
6160 let mut sampled = Array2::<f64>::zeros((indices.len(), data.ncols()));
6161 for (new_row, &orig_row) in indices.iter().enumerate() {
6162 sampled.row_mut(new_row).assign(&data.row(orig_row));
6163 }
6164 sampled
6165}
6166
6167fn spatial_term_user_centers(term: &SmoothTermSpec) -> Option<ArrayView2<'_, f64>> {
6168 match spatial_term_center_strategy(term) {
6169 Some(CenterStrategy::UserProvided(centers)) => Some(centers.view()),
6170 _ => None,
6171 }
6172}
6173
6174fn finite_centered_axis_contrasts(values: &[f64], expected_dim: usize) -> Option<Vec<f64>> {
6175 if values.len() != expected_dim || expected_dim <= 1 {
6176 return None;
6177 }
6178 if values.iter().any(|value| !value.is_finite()) {
6179 return None;
6180 }
6181 Some(center_aniso_log_scales(values))
6182}
6183
6184fn blended_pilot_axis_contrasts(
6185 pilot_data: ArrayView2<'_, f64>,
6186 term: &SmoothTermSpec,
6187 centers: ArrayView2<'_, f64>,
6188) -> Option<Vec<f64>> {
6189 let d = centers.ncols();
6190 if d <= 1 {
6191 return None;
6192 }
6193 let center_eta = initial_aniso_contrasts(centers);
6194 let data_eta = standardized_spatial_term_data(pilot_data, term)
6195 .ok()
6196 .and_then(|x| finite_centered_axis_contrasts(&initial_aniso_contrasts(x.view()), d));
6197 let center_eta = finite_centered_axis_contrasts(¢er_eta, d)?;
6198 let blended = match data_eta {
6199 Some(data_eta) => center_eta
6200 .iter()
6201 .zip(data_eta.iter())
6202 .map(|(&from_centers, &from_data)| 0.5 * (from_centers + from_data))
6203 .collect::<Vec<_>>(),
6204 None => center_eta,
6205 };
6206 finite_centered_axis_contrasts(&blended, d)
6207}
6208
6209fn apply_pilot_spatial_psi_reseed(
6210 pilot_data: ArrayView2<'_, f64>,
6211 spec: &TermCollectionSpec,
6212 spatial_terms: &[usize],
6213 kappa_options: &SpatialLengthScaleOptimizationOptions,
6214) -> Result<TermCollectionSpec, EstimationError> {
6215 let dims_per_term = spatial_dims_per_term(spec, spatial_terms);
6216 let use_aniso = has_aniso_terms(spec, spatial_terms);
6217 let log_kappa0 = if use_aniso {
6218 SpatialLogKappaCoords::from_length_scales_aniso(spec, spatial_terms, kappa_options)
6219 } else {
6220 SpatialLogKappaCoords::from_length_scales(spec, spatial_terms, kappa_options)
6221 };
6222 let log_kappa0 = log_kappa0.reseed_from_data(pilot_data, spec, spatial_terms, kappa_options);
6223 let log_kappa_lower = if use_aniso {
6224 SpatialLogKappaCoords::lower_bounds_aniso_from_data(
6225 pilot_data,
6226 spec,
6227 spatial_terms,
6228 &dims_per_term,
6229 kappa_options,
6230 )
6231 } else {
6232 SpatialLogKappaCoords::lower_bounds_from_data(
6233 pilot_data,
6234 spec,
6235 spatial_terms,
6236 kappa_options,
6237 )
6238 };
6239 let log_kappa_upper = if use_aniso {
6240 SpatialLogKappaCoords::upper_bounds_aniso_from_data(
6241 pilot_data,
6242 spec,
6243 spatial_terms,
6244 &dims_per_term,
6245 kappa_options,
6246 )
6247 } else {
6248 SpatialLogKappaCoords::upper_bounds_from_data(
6249 pilot_data,
6250 spec,
6251 spatial_terms,
6252 kappa_options,
6253 )
6254 };
6255 log_kappa0
6256 .clamp_to_bounds(&log_kappa_lower, &log_kappa_upper)
6257 .apply_tospec(spec, spatial_terms)
6258}
6259
6260pub(crate) fn apply_spatial_anisotropy_pilot_initializer(
6261 data: ArrayView2<'_, f64>,
6262 spec: &mut TermCollectionSpec,
6263 spatial_terms: &[usize],
6264 target_size: usize,
6265 kappa_options: &SpatialLengthScaleOptimizationOptions,
6266) -> usize {
6267 if target_size == 0 || data.nrows() <= target_size.saturating_mul(2) || spatial_terms.is_empty()
6268 {
6269 return 0;
6270 }
6271 if !has_aniso_terms(spec, spatial_terms) {
6272 return 0;
6273 }
6274 let indices = stratified_spatial_subsample(data, spec, target_size);
6275 let pilot_data = sampled_rows(data, &indices);
6276 let mut working = spec.clone();
6277 let mut updated_terms = 0usize;
6278 const GEOMETRY_UPDATES: usize = 2;
6279
6280 for pass in 0..GEOMETRY_UPDATES {
6281 let planned_terms = match plan_joint_spatial_centers_for_term_blocks(
6282 pilot_data.view(),
6283 &[working.smooth_terms.clone()],
6284 )
6285 .and_then(|mut blocks| {
6286 blocks.pop().ok_or_else(|| {
6287 BasisError::InvalidInput(
6288 "pilot geometry initializer produced no smooth-term block".to_string(),
6289 )
6290 })
6291 }) {
6292 Ok(terms) => terms,
6293 Err(err) => {
6294 log::warn!(
6295 "[spatial-kappa] pilot geometry initializer skipped after center planning failed: {err}"
6296 );
6297 return updated_terms;
6298 }
6299 };
6300
6301 for &term_idx in spatial_terms {
6302 let Some(current_eta) = get_spatial_aniso_log_scales(&working, term_idx) else {
6303 continue;
6304 };
6305 let Some(d) = get_spatial_feature_dim(&working, term_idx) else {
6306 continue;
6307 };
6308 if d <= 1 || current_eta.len() != d {
6309 continue;
6310 }
6311 let Some(planned_term) = planned_terms.get(term_idx) else {
6312 continue;
6313 };
6314 let Some(centers) = spatial_term_user_centers(planned_term) else {
6315 continue;
6316 };
6317 let Some(eta) = blended_pilot_axis_contrasts(pilot_data.view(), planned_term, centers)
6318 else {
6319 continue;
6320 };
6321 if set_spatial_aniso_log_scales(&mut working, term_idx, eta).is_ok() {
6322 updated_terms += usize::from(pass == 0);
6323 }
6324 }
6325
6326 match apply_pilot_spatial_psi_reseed(
6327 pilot_data.view(),
6328 &working,
6329 spatial_terms,
6330 kappa_options,
6331 ) {
6332 Ok(updated) => {
6333 working = updated;
6334 }
6335 Err(err) => {
6336 log::warn!(
6337 "[spatial-kappa] pilot geometry ψ reseed skipped after deterministic initializer error: {err}"
6338 );
6339 break;
6340 }
6341 }
6342 }
6343
6344 if updated_terms > 0 {
6345 log::info!(
6346 "[spatial-kappa] initialized anisotropy from {}-row pilot geometry for {} spatial term(s); proceeding to full-data optimization",
6347 indices.len(),
6348 updated_terms
6349 );
6350 *spec = working;
6351 }
6352 updated_terms
6353}
6354
6355pub(crate) fn spatial_length_scale_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
6356 spec.smooth_terms
6357 .iter()
6358 .enumerate()
6359 .filter_map(|(idx, _)| spatial_term_supports_hyper_optimization(spec, idx).then_some(idx))
6360 .collect()
6361}
6362
6363fn fit_score(fit: &UnifiedFitResult) -> f64 {
6375 if fit.reml_score.is_finite() {
6376 return fit.reml_score;
6377 }
6378 let score = 0.5 * fit.deviance + 0.5 * fit.stable_penalty_term;
6379 if score.is_finite() {
6380 score
6381 } else {
6382 f64::INFINITY
6383 }
6384}
6385
6386fn is_recoverable_trial_point_error(err: &EstimationError) -> bool {
6408 matches!(err, EstimationError::BasisError(_))
6409 || err.is_inner_solve_retreat()
6410 || is_recoverable_fit_inference_finiteness_error(err)
6411}
6412
6413fn is_recoverable_fit_inference_finiteness_error(err: &EstimationError) -> bool {
6414 let EstimationError::InvalidInput(message) = err else {
6415 return false;
6416 };
6417
6418 message.contains("must be finite")
6419 && [
6420 "fit_result.beta_covariance_frequentist",
6421 "fit_result.coefficient_influence",
6422 "fit_result.weighted_gram",
6423 ]
6424 .iter()
6425 .any(|field| message.contains(field))
6426}
6427
/// Tests for the trial-point error classification used by the spatial κ
/// search.
#[cfg(test)]
mod spatial_trial_recovery_tests {
    use super::*;

    /// Builds an `InvalidInput` error carrying `message`.
    fn invalid_input(message: &str) -> EstimationError {
        EstimationError::InvalidInput(message.to_string())
    }

    /// A non-finite frequentist covariance entry must be classified as
    /// recoverable.
    #[test]
    fn nonfinite_frequentist_covariance_is_recoverable_trial_point() {
        let singular_curvature =
            invalid_input("fit_result.beta_covariance_frequentist[0] must be finite, got NaN");

        assert!(
            is_recoverable_trial_point_error(&singular_curvature),
            "singular trial-point curvature should make spatial κ retreat, not abort"
        );
    }

    /// An `InvalidInput` unrelated to the inference fields must stay fatal.
    #[test]
    fn arbitrary_invalid_input_remains_fatal_trial_point_error() {
        let unrelated = invalid_input("outer rho bounds are invalid");

        assert!(
            !is_recoverable_trial_point_error(&unrelated),
            "the spatial κ recovery gate must not mask unrelated invalid inputs"
        );
    }
}
6454
6455fn require_successful_spatial_optimization_result<T>(
6456 initial_score: f64,
6457 result: Result<Option<(T, f64)>, EstimationError>,
6458) -> Result<T, EstimationError> {
6459 match result {
6460 Ok(Some((value, exact_score))) => {
6461 const SCORE_DRIFT_ABS_TOL: f64 = 1e-6;
6470 const SCORE_DRIFT_REL_TOL: f64 = 1e-8;
6471 let tol = SCORE_DRIFT_ABS_TOL.max(initial_score.abs() * SCORE_DRIFT_REL_TOL);
6472 if exact_score <= initial_score + tol {
6473 Ok(value)
6474 } else {
6475 Err(EstimationError::RemlOptimizationFailed(format!(
6476 "spatial kappa optimization made REML score worse ({initial_score:.6e} -> {exact_score:.6e})"
6477 )))
6478 }
6479 }
6480 Ok(None) => Err(EstimationError::RemlOptimizationFailed(
6481 "spatial kappa optimization is unavailable for one or more eligible spatial terms"
6482 .to_string(),
6483 )),
6484 Err(err) => Err(EstimationError::RemlOptimizationFailed(format!(
6485 "spatial kappa optimization failed: {err}"
6486 ))),
6487 }
6488}
6489
6490fn external_opts_for_design(
6491 family: &LikelihoodSpec,
6492 design: &TermCollectionDesign,
6493 options: &FitOptions,
6494) -> ExternalOptimOptions {
6495 ExternalOptimOptions {
6496 family: family.clone(),
6497 latent_cloglog: options.latent_cloglog,
6498 mixture_link: options.mixture_link.clone(),
6499 optimize_mixture: options.optimize_mixture,
6500 sas_link: options.sas_link,
6501 optimize_sas: options.optimize_sas,
6502 compute_inference: options.compute_inference,
6503 skip_rho_posterior_inference: options.skip_rho_posterior_inference,
6504 max_iter: options.max_iter,
6505 tol: options.tol,
6506 nullspace_dims: design.nullspace_dims.clone(),
6507 linear_constraints: design.linear_constraints.clone(),
6508 firth_bias_reduction: Some(options.firth_bias_reduction),
6509 penalty_shrinkage_floor: options.penalty_shrinkage_floor,
6510 rho_prior: options.rho_prior.clone(),
6511 kronecker_penalty_system: design.kronecker_penalty_system(),
6514 kronecker_factored: design
6515 .smooth
6516 .terms
6517 .iter()
6518 .find_map(|t| t.kronecker_factored.clone()),
6519 persist_warm_start_disk: options.persist_warm_start_disk,
6520 }
6521}
6522
6523fn evaluate_joint_reml_outer_eval_at_theta(
6531 evaluator: &mut gam_solve::estimate::ExternalJointHyperEvaluator<'_>,
6532 design: &TermCollectionDesign,
6533 theta: &Array1<f64>,
6534 rho_dim: usize,
6535 hyper_dirs: Vec<gam_solve::estimate::reml::DirectionalHyperParam>,
6536 warm_start_beta: Option<ArrayView1<'_, f64>>,
6537 order: gam_solve::rho_optimizer::OuterEvalOrder,
6538 design_revision: Option<u64>,
6539) -> Result<
6540 (
6541 f64,
6542 Array1<f64>,
6543 gam_problem::HessianResult,
6544 ),
6545 EstimationError,
6546> {
6547 evaluator.evaluate_with_order(
6548 &design.design,
6549 &design.penalties,
6550 &design.nullspace_dims,
6551 design.linear_constraints.clone(),
6552 theta,
6553 rho_dim,
6554 hyper_dirs,
6555 warm_start_beta,
6556 "evaluate_joint_reml_outer_eval_at_theta",
6557 order,
6558 design_revision,
6559 )
6560}
6561
6562fn evaluate_joint_reml_efs_at_theta(
6563 evaluator: &mut gam_solve::estimate::ExternalJointHyperEvaluator<'_>,
6564 design: &TermCollectionDesign,
6565 theta: &Array1<f64>,
6566 rho_dim: usize,
6567 hyper_dirs: Vec<gam_solve::estimate::reml::DirectionalHyperParam>,
6568 warm_start_beta: Option<ArrayView1<'_, f64>>,
6569 design_revision: Option<u64>,
6570) -> Result<gam_problem::EfsEval, EstimationError> {
6571 evaluator.evaluate_efs(
6572 &design.design,
6573 &design.penalties,
6574 &design.nullspace_dims,
6575 design.linear_constraints.clone(),
6576 theta,
6577 rho_dim,
6578 hyper_dirs,
6579 warm_start_beta,
6580 "evaluate_joint_reml_efs_at_theta",
6581 design_revision,
6582 )
6583}
6584
6585fn exact_joint_spatial_outer_hessian_available(
6586 family: &LikelihoodSpec,
6587 design: &TermCollectionDesign,
6588) -> bool {
6589 let family_supported = match &family.response {
6612 ResponseFamily::Gaussian
6613 | ResponseFamily::Binomial
6614 | ResponseFamily::Poisson
6615 | ResponseFamily::Tweedie { .. }
6616 | ResponseFamily::NegativeBinomial { .. }
6617 | ResponseFamily::Beta { .. }
6618 | ResponseFamily::Gamma
6619 | ResponseFamily::RoystonParmar => true,
6620 };
6621 family_supported && design.design.ncols() > 0
6624}
6625
6626fn smooth_term_penalty_index(
6627 spec: &TermCollectionSpec,
6628 design: &TermCollectionDesign,
6629 term_idx: usize,
6630) -> Option<usize> {
6631 if term_idx >= design.smooth.terms.len() || term_idx >= spec.smooth_terms.len() {
6632 return None;
6633 }
6634 if design.smooth.terms[term_idx].penalties_local.is_empty() {
6635 return None;
6636 }
6637 let linear_penalties = spec
6638 .linear_terms
6639 .iter()
6640 .filter(|t| t.double_penalty)
6641 .count()
6642 * 2;
6643 let random_penalties = design
6644 .random_effect_ranges
6645 .iter()
6646 .filter(|(_, range)| !range.is_empty())
6647 .count();
6648 let smooth_offset = linear_penalties + random_penalties;
6649 let local_offset = design
6650 .smooth
6651 .terms
6652 .iter()
6653 .take(term_idx)
6654 .map(|term| term.penalties_local.len())
6655 .sum::<usize>();
6656 Some(smooth_offset + local_offset)
6657}
6658
6659fn try_build_spatial_term_log_kappa_derivativeinfo(
6660 data: ArrayView2<'_, f64>,
6661 resolvedspec: &TermCollectionSpec,
6662 design: &TermCollectionDesign,
6663 term_idx: usize,
6664) -> Result<Option<SpatialPsiDerivative>, EstimationError> {
6665 let Some((
6666 global_range,
6667 total_p,
6668 x_psi_local,
6669 s_psi_local_check,
6670 x_psi_psi_local,
6671 s_psi_psi_local,
6672 s_psi_components_local,
6673 s_psi_psi_components_local,
6674 implicit_operator,
6675 )) = try_build_spatial_term_log_kappa_derivative(data, resolvedspec, design, term_idx)?
6676 else {
6677 return Ok(None);
6678 };
6679 let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
6680 return Ok(None);
6681 };
6682 if s_psi_components_local.is_empty() || s_psi_psi_components_local.is_empty() {
6683 return Ok(None);
6684 }
6685 if s_psi_components_local.len() != s_psi_psi_components_local.len() {
6686 return Ok(None);
6687 }
6688 let penalty_indices = (0..s_psi_components_local.len())
6689 .map(|j| penalty_start + j)
6690 .collect::<Vec<_>>();
6691 let penalty_index = penalty_indices[0];
6692 if s_psi_local_check.nrows() == 0 || s_psi_psi_local.nrows() == 0 {
6693 return Ok(None);
6694 }
6695 Ok(Some(SpatialPsiDerivative {
6696 penalty_index,
6697 penalty_indices,
6698 global_range,
6699 total_p,
6700 x_psi_local,
6701 s_psi_components_local,
6702 x_psi_psi_local,
6703 s_psi_psi_components_local,
6704 aniso_group_id: None,
6705 aniso_cross_designs: None,
6706 aniso_cross_penalty_provider: None,
6707 implicit_operator,
6708 implicit_axis: 0,
6709 }))
6710}
6711
6712pub(crate) fn try_build_spatial_log_kappa_derivativeinfo_list(
6713 data: ArrayView2<'_, f64>,
6714 resolvedspec: &TermCollectionSpec,
6715 design: &TermCollectionDesign,
6716 spatial_terms: &[usize],
6717) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
6718 let mut out = Vec::new();
6719 let mut aniso_gid = 0usize;
6720 for &term_idx in spatial_terms {
6721 if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
6722 if let Some(entries) = try_build_spatial_term_log_kappa_aniso_derivativeinfos(
6723 data,
6724 resolvedspec,
6725 design,
6726 term_idx,
6727 aniso_gid,
6728 )? {
6729 aniso_gid += 1;
6730 out.extend(entries);
6731 continue;
6732 } else {
6733 return Ok(None);
6734 }
6735 }
6736 let Some(info) =
6737 try_build_spatial_term_log_kappa_derivativeinfo(data, resolvedspec, design, term_idx)?
6738 else {
6739 return Ok(None);
6740 };
6741 out.push(info);
6742 }
6743 Ok(Some(out))
6744}
6745
6746fn try_build_spatial_term_log_kappa_aniso_derivativeinfos(
6748 data: ArrayView2<'_, f64>,
6749 resolvedspec: &TermCollectionSpec,
6750 design: &TermCollectionDesign,
6751 term_idx: usize,
6752 aniso_group_id: usize,
6753) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
6754 let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
6755 return Ok(None);
6756 };
6757 let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
6758 return Ok(None);
6759 };
6760 let mut aniso_result = match &termspec.basis {
6761 SmoothBasisSpec::Sphere { .. } => return Ok(None),
6762 SmoothBasisSpec::Matern {
6763 feature_cols,
6764 spec,
6765 input_scales,
6766 } => {
6767 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
6768 if let Some(s) = input_scales {
6769 apply_input_standardization(&mut x, s);
6770 }
6771 let mut spec_operator = spec.clone();
6780 spec_operator.double_penalty = false;
6781 build_matern_basis_log_kappa_aniso_derivatives(x.view(), &spec_operator)
6782 .map_err(EstimationError::from)?
6783 }
6784 SmoothBasisSpec::MeasureJet {
6790 feature_cols,
6791 spec,
6792 input_scales,
6793 } => {
6794 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
6795 if let Some(s) = input_scales {
6796 apply_input_standardization(&mut x, s);
6797 }
6798 build_measure_jet_basis_psi_derivatives(x.view(), spec)
6799 .map_err(EstimationError::from)?
6800 }
6801 _ => return Ok(None),
6802 };
6803 let d = if let Some(ref op) = aniso_result.implicit_operator {
6806 op.n_axes()
6807 } else if !aniso_result.design_first.is_empty() {
6808 aniso_result.design_first.len()
6809 } else {
6810 0
6811 };
6812 if d == 0 {
6813 return Ok(None);
6814 }
6815 let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
6816 return Ok(None);
6817 };
6818 let p_total = design.design.ncols();
6819 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
6820 let global_range = (smooth_start + smooth_term.coeff_range.start)
6821 ..(smooth_start + smooth_term.coeff_range.end);
6822 let num_penalties = aniso_result.penalties_first[0].len();
6823 let penalty_indices: Vec<usize> = (0..num_penalties).map(|j| penalty_start + j).collect();
6824 let penalties_cross_provider = aniso_result.penalties_cross_provider.clone();
6825
6826 let use_implicit_design = aniso_result.design_first.is_empty();
6830 let implicit_op_arc = aniso_result
6831 .implicit_operator
6832 .as_ref()
6833 .map(|op| std::sync::Arc::new(op.clone()));
6834
6835 let mut entries = Vec::with_capacity(d);
6836 for a in 0..d {
6837 let (x_psi_local, x_psi_psi_local) = if use_implicit_design {
6838 (Array2::<f64>::zeros((0, 0)), Array2::<f64>::zeros((0, 0)))
6844 } else {
6845 let x_first = std::mem::take(&mut aniso_result.design_first[a]);
6850 let x_second = std::mem::take(&mut aniso_result.design_second_diag[a]);
6851 if x_first.ncols() != smooth_term.coeff_range.len() {
6852 return Ok(None);
6853 }
6854 (x_first, x_second)
6855 };
6856 let s_psi_components = std::mem::take(&mut aniso_result.penalties_first[a]);
6857 let s_psi_psi_components = std::mem::take(&mut aniso_result.penalties_second_diag[a]);
6858 let cross_designs = if implicit_op_arc.is_some() {
6864 let mut cd = Vec::with_capacity(d - 1);
6865 for b in 0..d {
6866 if b == a {
6867 continue;
6868 }
6869 cd.push((b, Array2::<f64>::zeros((0, 0))));
6870 }
6871 cd
6872 } else if !aniso_result.design_second_cross.is_empty() {
6873 let mut cd = Vec::new();
6874 for (cross_idx, &(pa, pb)) in aniso_result.design_second_cross_pairs.iter().enumerate()
6875 {
6876 if pa == a {
6877 cd.push((pb, aniso_result.design_second_cross[cross_idx].clone()));
6878 } else if pb == a {
6879 cd.push((pa, aniso_result.design_second_cross[cross_idx].clone()));
6880 }
6881 }
6882 cd
6883 } else {
6884 Vec::new()
6885 };
6886 let cross_penalty_provider = if d > 1 {
6887 let penalties_cross_provider = penalties_cross_provider.clone();
6888 Some(std::sync::Arc::new(
6889 move |b_axis: usize| -> Result<Vec<Array2<f64>>, EstimationError> {
6890 if b_axis == a {
6891 return Ok(Vec::new());
6892 }
6893 let (axis_lo, axis_hi) = if a < b_axis { (a, b_axis) } else { (b_axis, a) };
6894 if let Some(provider) = penalties_cross_provider.as_ref() {
6895 provider
6896 .evaluate(axis_lo, axis_hi)
6897 .map_err(EstimationError::from)
6898 } else {
6899 Ok(Vec::new())
6903 }
6904 },
6905 )
6906 as std::sync::Arc<
6907 dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError>
6908 + Send
6909 + Sync
6910 + 'static,
6911 >)
6912 } else {
6913 None
6914 };
6915
6916 entries.push(SpatialPsiDerivative {
6917 penalty_index: penalty_indices[0],
6918 penalty_indices: penalty_indices.clone(),
6919 global_range: global_range.clone(),
6920 total_p: p_total,
6921 x_psi_local,
6922 s_psi_components_local: s_psi_components,
6923 x_psi_psi_local,
6924 s_psi_psi_components_local: s_psi_psi_components,
6925 aniso_group_id: Some(aniso_group_id),
6926 aniso_cross_designs: if cross_designs.is_empty() {
6927 None
6928 } else {
6929 Some(cross_designs)
6930 },
6931 aniso_cross_penalty_provider: cross_penalty_provider,
6932 implicit_operator: implicit_op_arc.clone(),
6933 implicit_axis: a,
6934 });
6935 }
6936 Ok(Some(entries))
6937}
6938
/// Finite-difference checks of the per-observation derivatives with respect to
/// the linear predictor `eta` for several standard families under a log link.
#[cfg(test)]
mod glm_eta_observation_fd_tests {
    use super::*;

    /// Evaluates the family observation state for a single observation with
    /// response `y`, unit weight and linear predictor `eta`.
    fn one_obs(spec: &LikelihoodSpec, y: f64, eta: f64) -> StandardFamilyObservationState {
        let response = Array1::from_vec(vec![y]);
        let unit_weight = Array1::from_vec(vec![1.0]);
        let predictor = Array1::from_vec(vec![eta]);
        evaluate_standard_familyobservations(
            spec.clone(),
            None,
            None,
            None,
            &response,
            &unit_weight,
            &predictor,
        )
        .expect("standard family observation state assembles")
    }

    /// Compares the analytic score, negative Hessian and its derivative at
    /// `eta` against central differences of the next-lower quantity.
    fn check_fd(label: &str, spec: &LikelihoodSpec, y: f64, eta: f64) {
        let step = 1e-5;
        let span = 2.0 * step;
        let at_eta = one_obs(spec, y, eta);
        let ahead = one_obs(spec, y, eta + step);
        let behind = one_obs(spec, y, eta - step);

        // d(log-likelihood)/d(eta) against the reported score.
        let score = at_eta.score[0];
        let score_fd = (ahead.log_likelihood - behind.log_likelihood) / span;
        assert!(
            (score - score_fd).abs() <= 1e-4 * (1.0 + score.abs()),
            "{label}: score {score} vs FD {score_fd}"
        );

        // -d(score)/d(eta) against the reported negative Hessian.
        let neghess = at_eta.neghessian_eta[0];
        let neghess_fd = -(ahead.score[0] - behind.score[0]) / span;
        assert!(
            (neghess - neghess_fd).abs() <= 1e-3 * (1.0 + neghess.abs()),
            "{label}: neghessian_eta {neghess} vs FD {neghess_fd}"
        );

        // d(negative Hessian)/d(eta) against the reported derivative.
        let nhd = at_eta.neghessian_eta_derivative[0];
        let nhd_fd = (ahead.neghessian_eta[0] - behind.neghessian_eta[0]) / span;
        assert!(
            (nhd - nhd_fd).abs() <= 1e-2 * (1.0 + nhd.abs()),
            "{label}: neghessian_eta_derivative {nhd} vs FD {nhd_fd}"
        );
    }

    #[test]
    fn poisson_gamma_nb_tweedie_arms_match_finite_differences_1615_1616() {
        let log = InverseLink::Standard(StandardLink::Log);
        let with_log_link = |response: ResponseFamily| LikelihoodSpec {
            response,
            link: log.clone(),
        };

        let poisson = with_log_link(ResponseFamily::Poisson);
        let gamma = with_log_link(ResponseFamily::Gamma);
        let nb = with_log_link(ResponseFamily::NegativeBinomial {
            theta: 1.5,
            theta_fixed: true,
        });
        let tweedie = with_log_link(ResponseFamily::Tweedie { p: 1.5 });

        // (label, family, response y, linear predictor eta)
        let cases: [(&str, &LikelihoodSpec, f64, f64); 8] = [
            ("poisson y=3", &poisson, 3.0, 0.4),
            ("poisson y=0", &poisson, 0.0, -0.2),
            ("gamma y=2.5", &gamma, 2.5, 0.3),
            ("gamma y=0.7", &gamma, 0.7, -0.1),
            ("negbin y=4", &nb, 4.0, 0.5),
            ("negbin y=0", &nb, 0.0, -0.3),
            ("tweedie y=2", &tweedie, 2.0, 0.25),
            ("tweedie y=0.5", &tweedie, 0.5, -0.15),
        ];
        for (label, spec, y, eta) in cases {
            check_fd(label, spec, y, eta);
        }
    }
}