1pub fn build_term_collection_designs_joint(
9 data: ArrayView2<'_, f64>,
10 specs: &[TermCollectionSpec],
11) -> Result<Vec<TermCollectionDesign>, BasisError> {
12 for spec in specs {
13 validate_term_collection_finite_inputs(data, spec)?;
14 }
15 let smooth_blocks = specs
16 .iter()
17 .map(|spec| spec.smooth_terms.clone())
18 .collect::<Vec<_>>();
19 let planned_blocks = plan_joint_spatial_centers_for_term_blocks(data, &smooth_blocks)?;
20 let mut out = Vec::with_capacity(specs.len());
21 for (spec, planned_terms) in specs.iter().zip(planned_blocks.into_iter()) {
22 let mut planned_spec = spec.clone();
23 planned_spec.smooth_terms = planned_terms;
24 out.push(build_term_collection_design_inner(data, &planned_spec)?);
25 }
26 Ok(out)
27}
28
29pub fn build_term_collection_designs_and_freeze_joint(
30 data: ArrayView2<'_, f64>,
31 specs: &[TermCollectionSpec],
32) -> Result<(Vec<TermCollectionDesign>, Vec<TermCollectionSpec>), EstimationError> {
33 let designs = build_term_collection_designs_joint(data, specs)?;
34 let mut resolved_specs = Vec::with_capacity(specs.len());
35 for (spec, design) in specs.iter().zip(designs.iter()) {
36 resolved_specs.push(freeze_term_collection_from_design(spec, design)?);
37 }
38 Ok((designs, resolved_specs))
39}
40
41pub fn fit_term_collection_forspec(
42 data: ArrayView2<'_, f64>,
43 y: ArrayView1<'_, f64>,
44 weights: ArrayView1<'_, f64>,
45 offset: ArrayView1<'_, f64>,
46 spec: &TermCollectionSpec,
47 family: LikelihoodSpec,
48 options: &FitOptions,
49) -> Result<FittedTermCollection, EstimationError> {
50 fit_term_collection_forspecwith_heuristic_lambdas(
51 data, y, weights, offset, spec, None, family, options,
52 )
53}
54
55pub fn fit_term_collection_with_coefficient_groups(
56 data: ArrayView2<'_, f64>,
57 y: ArrayView1<'_, f64>,
58 weights: ArrayView1<'_, f64>,
59 offset: ArrayView1<'_, f64>,
60 spec: &TermCollectionSpec,
61 groups: &[CoefficientGroupSpec],
62 family: LikelihoodSpec,
63 options: &FitOptions,
64) -> Result<FittedTermCollection, EstimationError> {
65 if groups.is_empty() {
66 return fit_term_collection_forspec(data, y, weights, offset, spec, family, options);
67 }
68 let design = build_term_collection_design(data, spec)?;
69 let base_fit_opts = adaptive_fit_options_base(options, &design);
70 let realized = design
71 .realize_coefficient_groups(groups, &base_fit_opts.rho_prior)
72 .map_err(EstimationError::BasisError)?;
73 let mut grouped_options = base_fit_opts.clone();
74 grouped_options.rho_prior = realized.rho_prior;
75 let fitted = FittedTermCollection {
76 fit: gam_solve::estimate::fit_gam_with_penalty_specs(
77 design.design.clone(),
78 y,
79 weights,
80 offset,
81 realized.penalty_specs,
82 realized.nullspace_dims,
83 family.clone(),
84 &grouped_options,
85 )?,
86 design,
87 adaptive_diagnostics: None,
88 };
89 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
90 Ok(fitted)
91}
92
93pub fn fit_term_collection_with_penalty_block_gamma_prior_callback<F>(
94 data: ArrayView2<'_, f64>,
95 y: ArrayView1<'_, f64>,
96 weights: ArrayView1<'_, f64>,
97 offset: ArrayView1<'_, f64>,
98 spec: &TermCollectionSpec,
99 callback: F,
100 family: LikelihoodSpec,
101 options: &FitOptions,
102) -> Result<FittedTermCollection, EstimationError>
103where
104 F: FnMut(&PenaltyBlockGammaPriorMetadata<'_>) -> Option<(f64, f64)>,
105{
106 let design = build_term_collection_design(data, spec)?;
107 let mut fit_opts = adaptive_fit_options_base(options, &design);
108 fit_opts.rho_prior = realize_penalty_block_gamma_priors(&design, callback)
109 .map_err(EstimationError::BasisError)?;
110 let fitted = FittedTermCollection {
111 fit: fit_gamwith_heuristic_lambdas(
112 design.design.clone(),
113 y,
114 weights,
115 offset,
116 &design.penalties,
117 None,
118 family.clone(),
119 &fit_opts,
120 )?,
121 design,
122 adaptive_diagnostics: None,
123 };
124 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
125 Ok(fitted)
126}
127
128pub fn fit_term_collection_with_penalty_block_gamma_priors(
129 data: ArrayView2<'_, f64>,
130 y: ArrayView1<'_, f64>,
131 weights: ArrayView1<'_, f64>,
132 offset: ArrayView1<'_, f64>,
133 spec: &TermCollectionSpec,
134 priors: &[(String, f64, f64)],
135 family: LikelihoodSpec,
136 options: &FitOptions,
137) -> Result<FittedTermCollection, EstimationError> {
138 let design = build_term_collection_design(data, spec)?;
139 let mut fit_opts = adaptive_fit_options_base(options, &design);
140 fit_opts.rho_prior = realize_keyed_penalty_block_gamma_priors(&design, priors)
141 .map_err(EstimationError::BasisError)?;
142 let fitted = FittedTermCollection {
143 fit: fit_gamwith_heuristic_lambdas(
144 design.design.clone(),
145 y,
146 weights,
147 offset,
148 &design.penalties,
149 None,
150 family.clone(),
151 &fit_opts,
152 )?,
153 design,
154 adaptive_diagnostics: None,
155 };
156 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
157 Ok(fitted)
158}
159
160pub fn fit_term_collection_with_coefficient_groups_and_penalty_block_gamma_priors(
161 data: ArrayView2<'_, f64>,
162 y: ArrayView1<'_, f64>,
163 weights: ArrayView1<'_, f64>,
164 offset: ArrayView1<'_, f64>,
165 spec: &TermCollectionSpec,
166 groups: &[CoefficientGroupSpec],
167 priors: &[(String, f64, f64)],
168 family: LikelihoodSpec,
169 options: &FitOptions,
170) -> Result<FittedTermCollection, EstimationError> {
171 if groups.is_empty() {
172 return fit_term_collection_with_penalty_block_gamma_priors(
173 data, y, weights, offset, spec, priors, family, options,
174 );
175 }
176 if priors.is_empty() {
177 return fit_term_collection_with_coefficient_groups(
178 data, y, weights, offset, spec, groups, family, options,
179 );
180 }
181
182 let design = build_term_collection_design(data, spec)?;
183 let base_fit_opts = adaptive_fit_options_base(options, &design);
184 let base_rho_prior = realize_keyed_penalty_block_gamma_priors(&design, priors)
185 .map_err(EstimationError::BasisError)?;
186 let realized = design
187 .realize_coefficient_groups(groups, &base_rho_prior)
188 .map_err(EstimationError::BasisError)?;
189 let mut grouped_options = base_fit_opts.clone();
190 grouped_options.rho_prior = realized.rho_prior;
191 let fitted = FittedTermCollection {
192 fit: gam_solve::estimate::fit_gam_with_penalty_specs(
193 design.design.clone(),
194 y,
195 weights,
196 offset,
197 realized.penalty_specs,
198 realized.nullspace_dims,
199 family.clone(),
200 &grouped_options,
201 )?,
202 design,
203 adaptive_diagnostics: None,
204 };
205 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
206 Ok(fitted)
207}
208
209fn fit_term_collection_forspecwith_heuristic_lambdas(
210 data: ArrayView2<'_, f64>,
211 y: ArrayView1<'_, f64>,
212 weights: ArrayView1<'_, f64>,
213 offset: ArrayView1<'_, f64>,
214 spec: &TermCollectionSpec,
215 heuristic_lambdas: Option<&[f64]>,
216 family: LikelihoodSpec,
217 options: &FitOptions,
218) -> Result<FittedTermCollection, EstimationError> {
219 let base_design = build_term_collection_design(data, spec)?;
220 fit_term_collection_on_realized_design(
221 y,
222 weights,
223 offset,
224 spec,
225 &base_design,
226 heuristic_lambdas,
227 family,
228 options,
229 )
230}
231
232fn has_bounded_linear_terms(spec: &TermCollectionSpec) -> bool {
233 spec.linear_terms.iter().any(|term| {
234 matches!(
235 term.coefficient_geometry,
236 LinearCoefficientGeometry::Bounded { .. }
237 )
238 })
239}
240
241fn fit_term_collection_on_realized_design(
242 y: ArrayView1<'_, f64>,
243 weights: ArrayView1<'_, f64>,
244 offset: ArrayView1<'_, f64>,
245 spec: &TermCollectionSpec,
246 design: &TermCollectionDesign,
247 heuristic_lambdas: Option<&[f64]>,
248 family: LikelihoodSpec,
249 options: &FitOptions,
250) -> Result<FittedTermCollection, EstimationError> {
251 if has_bounded_linear_terms(spec) {
252 return fit_bounded_term_collection_with_design(
253 y,
254 weights,
255 offset,
256 spec,
257 design,
258 heuristic_lambdas,
259 family,
260 options,
261 );
262 }
263 let mut base_fit_opts = adaptive_fit_options_base(options, design);
264 base_fit_opts.rho_prior = relax_smoothing_rho_prior(options, design);
271 let fitted = FittedTermCollection {
272 fit: fit_gamwith_heuristic_lambdas(
273 design.design.clone(),
274 y,
275 weights,
276 offset,
277 &design.penalties,
278 heuristic_lambdas,
279 family.clone(),
280 &base_fit_opts,
281 )?,
282 design: design.clone(),
283 adaptive_diagnostics: None,
284 };
285 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
286
287 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
288 if !adaptive_opts.enabled {
289 return Ok(fitted);
290 }
291 let runtime_caches = extract_spatial_operator_runtime_caches(spec, &fitted.design)?;
292 if runtime_caches.is_empty() {
293 return Ok(fitted);
294 }
295 fit_term_collectionwith_exact_spatial_adaptive_regularization(
302 fitted,
303 y,
304 weights,
305 offset,
306 family,
307 options,
308 &runtime_caches,
309 )
310}
311
/// Per-term collocation operators and penalty bookkeeping for one spatial smooth.
///
/// Instances are produced by `extract_spatial_operator_runtime_caches`; the field
/// notes below describe what that function stores.
#[derive(Clone)]
struct SpatialOperatorRuntimeCache {
    /// Name of the fitted smooth term this cache belongs to.
    termname: String,
    /// Feature column indices copied from the term's basis spec.
    feature_cols: Vec<usize>,
    /// Range of this term's coefficients within the full design's columns.
    coeff_global_range: Range<usize>,
    /// Global penalty index of the term's `OperatorMass` penalty.
    mass_penalty_global_idx: usize,
    /// Global penalty index of the term's `OperatorTension` penalty.
    tension_penalty_global_idx: usize,
    /// Global penalty index of the term's `OperatorStiffness` penalty.
    stiffness_penalty_global_idx: usize,
    /// `d0` operator matrix, divided by the square root of the mass normalization scale.
    d0: Array2<f64>,
    /// `d1` operator matrix, divided by the square root of the tension normalization scale.
    d1: Array2<f64>,
    /// `d2` operator matrix, divided by the square root of the stiffness normalization scale.
    d2: Array2<f64>,
    /// Collocation points returned by the operator builder.
    collocation_points: Array2<f64>,
    /// Number of columns of the basis centers.
    dimension: usize,
}
326
/// Three per-point weight vectors used by spatial adaptive regularization.
///
/// NOTE(review): the field names suggest inverse weights for the magnitude,
/// gradient, and Laplacian/curvature penalties respectively — confirm against the
/// code that constructs this struct.
#[derive(Clone)]
struct SpatialAdaptiveWeights {
    inv_magweight: Array1<f64>,
    invgradweight: Array1<f64>,
    inv_lapweight: Array1<f64>,
}
333
/// Cached quantities for a Charbonnier penalty applied element-wise to a signal.
///
/// For each entry `t` the penalty is `sqrt(t^2 + epsilon^2) - epsilon`.
#[derive(Clone)]
struct CharbonnierScalarBlockState {
    /// Signal values `t` the penalty is evaluated at.
    signal: Array1<f64>,
    /// `sqrt(t^2 + epsilon^2)` for every signal entry.
    radius: Array1<f64>,
    /// Smoothing parameter after flooring at `1e-12`.
    epsilon: f64,
}
340
341impl CharbonnierScalarBlockState {
342 fn from_signal(signal: Array1<f64>, epsilon: f64) -> Self {
343 let eps = epsilon.max(1e-12);
344 let radius = signal.mapv(|t| (t * t + eps * eps).sqrt());
345 Self {
346 signal,
347 radius,
348 epsilon: eps,
349 }
350 }
351
352 fn absolute_signal(&self) -> Array1<f64> {
353 self.signal.mapv(f64::abs)
354 }
355
356 fn penalty_value(&self) -> f64 {
357 self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
358 }
359
360 fn betagradient_coeff(&self) -> Array1<f64> {
361 Array1::from_iter(
362 self.signal
363 .iter()
364 .zip(self.radius.iter())
365 .map(|(t, r)| t / r),
366 )
367 }
368
369 fn betahessian_diag(&self) -> Array1<f64> {
370 let eps2 = self.epsilon * self.epsilon;
371 self.radius.mapv(|r| eps2 / r.powi(3))
372 }
373
374 fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
375 let epsilon = self.epsilon;
376 let eps2 = epsilon * epsilon;
377 self.radius.mapv(|r| eps2 / r - epsilon)
378 }
379
380 fn log_epsilon_betagradient_coeff(&self) -> Array1<f64> {
381 let eps2 = self.epsilon * self.epsilon;
382 Array1::from_iter(
383 self.signal
384 .iter()
385 .zip(self.radius.iter())
386 .map(|(t, r)| -eps2 * t / r.powi(3)),
387 )
388 }
389
390 fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
391 let epsilon = self.epsilon;
392 let eps2 = epsilon * epsilon;
393 let eps4 = eps2 * eps2;
394 self.radius
395 .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
396 }
397
398 fn surrogateweights_posterior_snr(
399 &self,
400 variance: &Array1<f64>,
401 weight_floor: f64,
402 weight_ceiling: f64,
403 ) -> (Array1<f64>, Array1<f64>) {
404 let eps2 = self.epsilon * self.epsilon;
462 let weight = Array1::from_iter(self.signal.iter().zip(variance.iter()).map(|(&t, &v)| {
463 let credible2 = (t * t - v.max(0.0)).max(0.0);
464 let r = (credible2 + eps2).sqrt();
465 (1.0 / r).clamp(weight_floor, weight_ceiling)
466 }));
467 let invweight = weight.mapv(|u| 1.0 / u);
468 (weight, invweight)
469 }
470
471 fn directionalhessian_diag(&self, direction_signal: &Array1<f64>) -> Array1<f64> {
472 let eps2 = self.epsilon * self.epsilon;
487 Array1::from_iter(
488 self.signal
489 .iter()
490 .zip(direction_signal.iter())
491 .zip(self.radius.iter())
492 .map(|((t, q), r)| -3.0 * eps2 * t * q / r.powi(5)),
493 )
494 }
495
496 fn second_directionalhessian_diag(
503 &self,
504 direction1_signal: &Array1<f64>,
505 direction2_signal: &Array1<f64>,
506 ) -> Array1<f64> {
507 let eps2 = self.epsilon * self.epsilon;
508 Array1::from_iter(
509 self.signal
510 .iter()
511 .zip(direction1_signal.iter())
512 .zip(direction2_signal.iter())
513 .zip(self.radius.iter())
514 .map(|(((t, q1), q2), r)| {
515 let r2 = r * r;
516 let psi4 = -3.0 * eps2 / r.powi(5) + 15.0 * eps2 * t * t / (r.powi(5) * r2);
517 psi4 * q1 * q2
518 }),
519 )
520 }
521
522 fn log_epsilon_betahessian_diag(&self) -> Array1<f64> {
523 let eps2 = self.epsilon * self.epsilon;
524 let eps4 = eps2 * eps2;
525 Array1::from_iter(
526 self.signal
527 .iter()
528 .zip(self.radius.iter())
529 .map(|(_, r)| 2.0 * eps2 / r.powi(3) - 3.0 * eps4 / r.powi(5)),
530 )
531 }
532
533 fn log_epsilon_beta_mixed_second_coeff(&self) -> Array1<f64> {
534 let eps2 = self.epsilon * self.epsilon;
535 Array1::from_iter(
536 self.signal
537 .iter()
538 .zip(self.radius.iter())
539 .map(|(t, r)| eps2 * t * (eps2 - 2.0 * t * t) / r.powi(5)),
540 )
541 }
542
543 fn log_epsilon_betahessian_second_diag(&self) -> Array1<f64> {
544 let eps2 = self.epsilon * self.epsilon;
545 let eps4 = eps2 * eps2;
546 let eps6 = eps4 * eps2;
547 Array1::from_iter(
548 self.radius.iter().map(|r| {
549 4.0 * eps2 / r.powi(3) - 18.0 * eps4 / r.powi(5) + 15.0 * eps6 / r.powi(7)
550 }),
551 )
552 }
553
554 fn log_epsilon_betahessian_directional_diag(
555 &self,
556 direction_signal: &Array1<f64>,
557 ) -> Array1<f64> {
558 let eps2 = self.epsilon * self.epsilon;
559 let eps4 = eps2 * eps2;
560 Array1::from_iter(
561 self.signal
562 .iter()
563 .zip(direction_signal.iter())
564 .zip(self.radius.iter())
565 .map(|((t, q), r)| (-6.0 * eps2 * t / r.powi(5) + 15.0 * eps4 * t / r.powi(7)) * q),
566 )
567 }
568}
569
/// Cached quantities for a Charbonnier penalty applied to the Euclidean norm of
/// each row of a block matrix.
///
/// For each row `v` the penalty is `sqrt(|v|^2 + epsilon^2) - epsilon`.
#[derive(Clone)]
struct CharbonnierGroupedBlockState {
    /// Euclidean norm of every row of `signal_blocks`.
    norm: Array1<f64>,
    /// `sqrt(norm^2 + epsilon^2)` for every row.
    radius: Array1<f64>,
    /// One signal vector per row.
    signal_blocks: Array2<f64>,
    /// Smoothing parameter after flooring at `1e-12`.
    epsilon: f64,
}
577
578impl CharbonnierGroupedBlockState {
579 fn from_signal_blocks(signal_blocks: Array2<f64>, epsilon: f64) -> Self {
580 let eps = epsilon.max(1e-12);
581 let norm = Array1::from_iter(
582 signal_blocks
583 .rows()
584 .into_iter()
585 .map(|row| row.iter().map(|v| v * v).sum::<f64>().sqrt()),
586 );
587 let radius = norm.mapv(|g| (g * g + eps * eps).sqrt());
588 Self {
589 norm,
590 radius,
591 signal_blocks,
592 epsilon: eps,
593 }
594 }
595
596 fn penalty_value(&self) -> f64 {
597 self.radius.iter().map(|r| r - self.epsilon).sum::<f64>()
598 }
599
600 fn norm_signal(&self) -> Array1<f64> {
601 self.norm.clone()
602 }
603
604 fn betagradient_blocks(&self) -> Array2<f64> {
605 let mut out = self.signal_blocks.clone();
606 for (k, mut row) in out.rows_mut().into_iter().enumerate() {
607 let scale = 1.0 / self.radius[k];
608 row.mapv_inplace(|v| v * scale);
609 }
610 out
611 }
612
613 fn betahessian_blocks(&self) -> Vec<Array2<f64>> {
614 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
615 for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
616 let dim = row.len();
617 let mut block = Array2::<f64>::eye(dim);
618 block.mapv_inplace(|v| v / self.radius[k]);
619 for i in 0..dim {
620 for j in 0..dim {
621 block[[i, j]] -= row[i] * row[j] / self.radius[k].powi(3);
622 }
623 }
624 out.push(block);
625 }
626 out
627 }
628
629 fn log_epsilon_gradient_terms(&self) -> Array1<f64> {
630 let epsilon = self.epsilon;
631 let eps2 = epsilon * epsilon;
632 self.radius.mapv(|r| eps2 / r - epsilon)
633 }
634
635 fn log_epsilon_betagradient_blocks(&self) -> Array2<f64> {
636 let mut out = self.signal_blocks.clone();
637 let eps2 = self.epsilon * self.epsilon;
638 for (k, mut row) in out.rows_mut().into_iter().enumerate() {
639 let scale = -eps2 / self.radius[k].powi(3);
640 row.mapv_inplace(|v| v * scale);
641 }
642 out
643 }
644
645 fn log_epsilon_hessian_terms(&self) -> Array1<f64> {
646 let epsilon = self.epsilon;
647 let eps2 = epsilon * epsilon;
648 let eps4 = eps2 * eps2;
649 self.radius
650 .mapv(|r| 2.0 * eps2 / r - eps4 / r.powi(3) - epsilon)
651 }
652
653 fn surrogateweights_posterior_snr(
654 &self,
655 variance: &Array1<f64>,
656 weight_floor: f64,
657 weight_ceiling: f64,
658 ) -> (Array1<f64>, Array1<f64>) {
659 let eps2 = self.epsilon * self.epsilon;
701 let weight = Array1::from_iter(self.norm.iter().zip(variance.iter()).map(|(&g, &v)| {
702 let credible2 = (g * g - v.max(0.0)).max(0.0);
703 let r = (credible2 + eps2).sqrt();
704 (1.0 / r).clamp(weight_floor, weight_ceiling)
705 }));
706 let invweight = weight.mapv(|u| 1.0 / u);
707 (weight, invweight)
708 }
709
710 fn directionalhessian_blocks(&self, direction_blocks: &Array2<f64>) -> Vec<Array2<f64>> {
711 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
736 for (k, (v, q)) in self
737 .signal_blocks
738 .rows()
739 .into_iter()
740 .zip(direction_blocks.rows().into_iter())
741 .enumerate()
742 {
743 let dim = v.len();
744 let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
745 let r3 = self.radius[k].powi(3);
746 let r5 = self.radius[k].powi(5);
747 let mut block = Array2::<f64>::eye(dim);
748 block.mapv_inplace(|x| -dot * x / r3);
749 for i in 0..dim {
750 for j in 0..dim {
751 block[[i, j]] -= (q[i] * v[j] + v[i] * q[j]) / r3;
752 block[[i, j]] += 3.0 * dot * v[i] * v[j] / r5;
753 }
754 }
755 out.push(block);
756 }
757 out
758 }
759
760 fn second_directionalhessian_blocks(
777 &self,
778 direction1_blocks: &Array2<f64>,
779 direction2_blocks: &Array2<f64>,
780 ) -> Vec<Array2<f64>> {
781 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
782 for ((k, v), (a, b)) in self.signal_blocks.rows().into_iter().enumerate().zip(
783 direction1_blocks
784 .rows()
785 .into_iter()
786 .zip(direction2_blocks.rows().into_iter()),
787 ) {
788 let dim = v.len();
789 let dot = |x: ndarray::ArrayView1<'_, f64>, y: ndarray::ArrayView1<'_, f64>| {
790 x.iter().zip(y.iter()).map(|(p, q)| p * q).sum::<f64>()
791 };
792 let sa = dot(v, a);
793 let sb = dot(v, b);
794 let ab = dot(a, b);
795 let r = self.radius[k];
796 let r3 = r.powi(3);
797 let r5 = r.powi(5);
798 let r7 = r5 * r * r;
799 let diag = -ab / r3 + 3.0 * sa * sb / r5;
800 let mut block = Array2::<f64>::eye(dim);
801 block.mapv_inplace(|x| diag * x);
802 for i in 0..dim {
803 for j in 0..dim {
804 block[[i, j]] -= (a[i] * b[j] + b[i] * a[j]) / r3;
805 block[[i, j]] += 3.0 * sb * (a[i] * v[j] + v[i] * a[j]) / r5;
806 block[[i, j]] += 3.0 * ab * v[i] * v[j] / r5;
807 block[[i, j]] += 3.0 * sa * (b[i] * v[j] + v[i] * b[j]) / r5;
808 block[[i, j]] -= 15.0 * sa * sb * v[i] * v[j] / r7;
809 }
810 }
811 out.push(block);
812 }
813 out
814 }
815
816 fn log_epsilon_betahessian_blocks(&self) -> Vec<Array2<f64>> {
817 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
818 for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
819 let dim = row.len();
820 let r3 = self.radius[k].powi(3);
821 let r5 = self.radius[k].powi(5);
822 let mut block = Array2::<f64>::eye(dim);
823 let eps2 = self.epsilon * self.epsilon;
824 block.mapv_inplace(|v| -eps2 * v / r3);
825 for i in 0..dim {
826 for j in 0..dim {
827 block[[i, j]] += 3.0 * eps2 * row[i] * row[j] / r5;
828 }
829 }
830 out.push(block);
831 }
832 out
833 }
834
835 fn log_epsilon_beta_mixed_second_blocks(&self) -> Array2<f64> {
836 let mut out = self.signal_blocks.clone();
837 let eps2 = self.epsilon * self.epsilon;
838 for (k, mut row) in out.rows_mut().into_iter().enumerate() {
839 let norm2 = self.norm[k] * self.norm[k];
840 let scale = eps2 * (eps2 - 2.0 * norm2) / self.radius[k].powi(5);
841 row.mapv_inplace(|v| v * scale);
842 }
843 out
844 }
845
846 fn log_epsilon_betahessian_second_blocks(&self) -> Vec<Array2<f64>> {
847 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
848 let eps2 = self.epsilon * self.epsilon;
849 for (k, row) in self.signal_blocks.rows().into_iter().enumerate() {
850 let dim = row.len();
851 let norm2 = self.norm[k] * self.norm[k];
852 let r5 = self.radius[k].powi(5);
853 let r7 = self.radius[k].powi(7);
854 let mut block = Array2::<f64>::eye(dim);
855 block.mapv_inplace(|v| eps2 * (eps2 - 2.0 * norm2) * v / r5);
856 for i in 0..dim {
857 for j in 0..dim {
858 block[[i, j]] += 3.0 * eps2 * (2.0 * norm2 - 3.0 * eps2) * row[i] * row[j] / r7;
859 }
860 }
861 out.push(block);
862 }
863 out
864 }
865
866 fn log_epsilon_betahessian_directional_blocks(
867 &self,
868 direction_blocks: &Array2<f64>,
869 ) -> Vec<Array2<f64>> {
870 let mut out = Vec::with_capacity(self.signal_blocks.nrows());
871 let eps2 = self.epsilon * self.epsilon;
872 for (k, (v, q)) in self
873 .signal_blocks
874 .rows()
875 .into_iter()
876 .zip(direction_blocks.rows().into_iter())
877 .enumerate()
878 {
879 let dim = v.len();
880 let dot = v.iter().zip(q.iter()).map(|(a, b)| a * b).sum::<f64>();
881 let r5 = self.radius[k].powi(5);
882 let r7 = self.radius[k].powi(7);
883 let mut block = Array2::<f64>::eye(dim);
884 block.mapv_inplace(|x| 3.0 * eps2 * dot * x / r5);
885 for i in 0..dim {
886 for j in 0..dim {
887 block[[i, j]] += 3.0 * eps2 * (q[i] * v[j] + v[i] * q[j]) / r5;
888 block[[i, j]] -= 15.0 * eps2 * dot * v[i] * v[j] / r7;
889 }
890 }
891 out.push(block);
892 }
893 out
894 }
895}
896
897fn scalar_operatorgradient(operator: &Array2<f64>, coeff: &Array1<f64>) -> Array1<f64> {
898 operator.t().dot(coeff)
899}
900
901fn scalar_operatorhessian(operator: &Array2<f64>, diag: &Array1<f64>) -> Array2<f64> {
902 let mut weighted = operator.clone();
903 for (k, &w) in diag.iter().enumerate() {
904 weighted.row_mut(k).mapv_inplace(|v| v * w);
905 }
906 let gram = operator.t().dot(&weighted);
907 (&gram + &gram.t().to_owned()) * 0.5
908}
909
910fn grouped_operatorgradient(
911 d1: &Array2<f64>,
912 dimension: usize,
913 blocks: &Array2<f64>,
914) -> Result<Array1<f64>, EstimationError> {
915 if blocks.ncols() != dimension {
916 crate::bail_invalid_estim!(
917 "grouped gradient block dimension mismatch: got {}, expected {dimension}",
918 blocks.ncols()
919 );
920 }
921 if d1.nrows() != blocks.nrows() * dimension {
922 crate::bail_invalid_estim!(
923 "grouped gradient row mismatch: D1 has {} rows, blocks imply {}",
924 d1.nrows(),
925 blocks.nrows() * dimension
926 );
927 }
928 let mut out = Array1::<f64>::zeros(d1.ncols());
929 for k in 0..blocks.nrows() {
930 let gk = d1
931 .slice(s![k * dimension..(k + 1) * dimension, ..])
932 .to_owned();
933 out += &gk.t().dot(&blocks.row(k));
934 }
935 Ok(out)
936}
937
938fn grouped_operatorhessian(
939 d1: &Array2<f64>,
940 dimension: usize,
941 blocks: &[Array2<f64>],
942) -> Result<Array2<f64>, EstimationError> {
943 if d1.nrows() != blocks.len() * dimension {
944 crate::bail_invalid_estim!(
945 "grouped Hessian row mismatch: D1 has {} rows, blocks imply {}",
946 d1.nrows(),
947 blocks.len() * dimension
948 );
949 }
950 let p = d1.ncols();
951 let mut out = Array2::<f64>::zeros((p, p));
952 for (k, block) in blocks.iter().enumerate() {
953 if block.nrows() != dimension || block.ncols() != dimension {
954 crate::bail_invalid_estim!(
955 "grouped Hessian block {k} has shape {}x{}, expected {}x{}",
956 block.nrows(),
957 block.ncols(),
958 dimension,
959 dimension
960 );
961 }
962 let gk = d1
963 .slice(s![k * dimension..(k + 1) * dimension, ..])
964 .to_owned();
965 out += &gk.t().dot(&block.dot(&gk));
966 }
967 Ok((&out + &out.t().to_owned()) * 0.5)
968}
969
/// Charbonnier penalty states for the three spatial operators of one term.
#[derive(Clone)]
struct SpatialPenaltyExactState {
    /// Element-wise state built from `d0 * beta`.
    magnitude: CharbonnierScalarBlockState,
    /// Grouped state built from `d1 * beta`, one row of `dimension` entries per group.
    gradient: CharbonnierGroupedBlockState,
    /// Grouped state built from `d2 * beta`, one row of `dimension^2` entries per group.
    curvature: CharbonnierGroupedBlockState,
}
976
977fn collocationgradient_blocks(
978 gradrows: &Array1<f64>,
979 dimension: usize,
980) -> Result<Array2<f64>, EstimationError> {
981 if dimension == 0 || !gradrows.len().is_multiple_of(dimension) {
982 crate::bail_invalid_estim!(
983 "invalid collocation gradient layout: rows={}, dimension={dimension}",
984 gradrows.len()
985 );
986 }
987 let p = gradrows.len() / dimension;
988 let mut out = Array2::<f64>::zeros((p, dimension));
989 for k in 0..p {
990 for axis in 0..dimension {
991 out[[k, axis]] = gradrows[k * dimension + axis];
992 }
993 }
994 Ok(out)
995}
996
997fn collocationhessian_blocks(
998 hessianrows: &Array1<f64>,
999 dimension: usize,
1000) -> Result<Array2<f64>, EstimationError> {
1001 let block_dim = dimension.checked_mul(dimension).ok_or_else(|| {
1002 EstimationError::InvalidInput("invalid collocation Hessian dimension overflow".to_string())
1003 })?;
1004 if block_dim == 0 || !hessianrows.len().is_multiple_of(block_dim) {
1005 crate::bail_invalid_estim!(
1006 "invalid collocation Hessian layout: rows={}, dimension={dimension}",
1007 hessianrows.len()
1008 );
1009 }
1010 let p = hessianrows.len() / block_dim;
1011 let mut out = Array2::<f64>::zeros((p, block_dim));
1012 for k in 0..p {
1013 for idx in 0..block_dim {
1014 out[[k, idx]] = hessianrows[k * block_dim + idx];
1015 }
1016 }
1017 Ok(out)
1018}
1019
1020impl SpatialPenaltyExactState {
1021 fn from_beta_local(
1022 beta_local: ArrayView1<'_, f64>,
1023 cache: &SpatialOperatorRuntimeCache,
1024 epsilons: [f64; 3],
1025 ) -> Result<Self, EstimationError> {
1026 let gradientrows = cache.d1.dot(&beta_local);
1056 let hessianrows = cache.d2.dot(&beta_local);
1057 Ok(Self {
1058 magnitude: CharbonnierScalarBlockState::from_signal(
1059 cache.d0.dot(&beta_local),
1060 epsilons[0],
1061 ),
1062 gradient: CharbonnierGroupedBlockState::from_signal_blocks(
1063 collocationgradient_blocks(&gradientrows, cache.dimension)?,
1064 epsilons[1],
1065 ),
1066 curvature: CharbonnierGroupedBlockState::from_signal_blocks(
1067 collocationhessian_blocks(&hessianrows, cache.dimension)?,
1068 epsilons[2],
1069 ),
1070 })
1071 }
1072
1073 fn absolute_collocation_magnitudes(&self) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
1074 (
1075 self.magnitude.absolute_signal(),
1076 self.gradient.norm_signal(),
1077 self.curvature.norm_signal(),
1078 )
1079 }
1080}
1081
1082fn robust_epsilon_from_samples(values: &[f64], min_epsilon_cfg: f64) -> f64 {
1083 if values.is_empty() {
1084 return min_epsilon_cfg.max(1e-12);
1085 }
1086 let mut clean = values
1087 .iter()
1088 .copied()
1089 .filter(|v| v.is_finite() && *v >= 0.0)
1090 .collect::<Vec<_>>();
1091 if clean.is_empty() {
1092 return min_epsilon_cfg.max(1e-12);
1093 }
1094 clean.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1095
1096 let n = clean.len();
1097 let median = quantile_from_sorted(&clean, 0.5);
1098 let q75 = quantile_from_sorted(&clean, 0.75);
1099 let q95 = quantile_from_sorted(&clean, 0.95);
1100
1101 let mut abs_dev = clean
1102 .iter()
1103 .map(|v| (v - median).abs())
1104 .filter(|v| v.is_finite())
1105 .collect::<Vec<_>>();
1106 abs_dev.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1107 let mad = 1.4826 * quantile_from_sorted(&abs_dev, 0.5);
1108
1109 let mut scale = median.max(mad).max(q75);
1119
1120 let delta = (f64::EPSILON.sqrt() * q95.max(1.0))
1122 .max(min_epsilon_cfg)
1123 .max(1e-12);
1124 let s_min = min_epsilon_cfg.max(1e-12);
1125
1126 if scale <= delta {
1128 let rms = (clean.iter().map(|v| v * v).sum::<f64>() / n as f64).sqrt();
1129 scale = q95.max(rms);
1130 }
1131 if scale <= delta {
1132 scale = s_min;
1133 }
1134
1135 let kappa = 1.0_f64;
1138 (kappa * scale).max(s_min)
1139}
1140
/// Builds one `SpatialOperatorRuntimeCache` per smooth term that carries all three
/// operator penalties (mass, tension, stiffness) and has a supported basis.
///
/// A term is skipped (no cache, no error) when:
/// - `smooth_term_penalty_index` returns `None` for it,
/// - any of the three operator penalties is missing among its active penalties, or
/// - its basis/metadata pair is neither Matern nor Duchon with stored operator
///   collocation points.
///
/// # Errors
/// Propagates failures from the collocation operator builders, and fails when the
/// column count of any operator differs from the term's coefficient count.
fn extract_spatial_operator_runtime_caches(
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
) -> Result<Vec<SpatialOperatorRuntimeCache>, EstimationError> {
    // Column offset applied to each term's local coefficient range: total design
    // columns minus the total number of smooth columns.
    let smooth_start = design
        .design
        .ncols()
        .saturating_sub(design.smooth.total_smooth_cols());
    let mut out = Vec::<SpatialOperatorRuntimeCache>::new();
    for (term_idx, (termspec, term_fit)) in spec
        .smooth_terms
        .iter()
        .zip(design.smooth.terms.iter())
        .enumerate()
    {
        let Some(global_base_idx) = smooth_term_penalty_index(spec, design, term_idx) else {
            continue;
        };
        // Position among the term's *active* penalties; inactive ones do not count.
        let mut active_local_idx = 0usize;
        let mut mass_local_idx = None;
        let mut tension_local_idx = None;
        let mut stiffness_local_idx = None;
        let mut mass_norm = None;
        let mut tension_norm = None;
        let mut stiffness_norm = None;
        for info in &term_fit.penaltyinfo_local {
            if !info.active {
                continue;
            }
            match info.source {
                PenaltySource::OperatorMass => {
                    mass_local_idx = Some(active_local_idx);
                    mass_norm = Some(info.normalization_scale);
                }
                PenaltySource::OperatorTension => {
                    tension_local_idx = Some(active_local_idx);
                    tension_norm = Some(info.normalization_scale);
                }
                PenaltySource::OperatorStiffness => {
                    stiffness_local_idx = Some(active_local_idx);
                    stiffness_norm = Some(info.normalization_scale);
                }
                _ => {}
            }
            active_local_idx += 1;
        }
        // All three operator penalties must be present, otherwise the term is skipped.
        let (
            Some(mass_local),
            Some(tension_local),
            Some(stiffness_local),
            Some(mass_scale),
            Some(tension_scale),
            Some(stiffness_scale),
        ) = (
            mass_local_idx,
            tension_local_idx,
            stiffness_local_idx,
            mass_norm,
            tension_norm,
            stiffness_norm,
        )
        else {
            continue;
        };
        let mass_global_idx = global_base_idx + mass_local;
        let tension_global_idx = global_base_idx + tension_local;
        let stiffness_global_idx = global_base_idx + stiffness_local;

        // The last tuple element (`center_mass_rows`) is `false` for Matern and
        // `true` for Duchon; it controls the column-mean centering of `d0` below.
        let (feature_cols, mut d0, mut d1, mut d2, collocation_points, dim, center_mass_rows) =
            match (&termspec.basis, &term_fit.metadata) {
                (
                    SmoothBasisSpec::Matern { feature_cols, .. },
                    BasisMetadata::Matern {
                        centers,
                        length_scale,
                        nu,
                        include_intercept,
                        identifiability_transform,
                        aniso_log_scales,
                        input_scales,
                        ..
                    },
                ) => {
                    // When input scales are stored, the length scale is adjusted by
                    // `compensate_length_scale_for_standardization` before use.
                    let collocation_length_scale = match input_scales.as_deref() {
                        Some(scales) => {
                            compensate_length_scale_for_standardization(*length_scale, scales)
                        }
                        None => *length_scale,
                    };
                    let ops = build_matern_collocation_operator_matrices(
                        centers.view(),
                        None,
                        collocation_length_scale,
                        *nu,
                        *include_intercept,
                        identifiability_transform.as_ref().map(|z| z.view()),
                        aniso_log_scales.as_deref(),
                    )?;
                    (
                        feature_cols.clone(),
                        ops.d0,
                        ops.d1,
                        ops.d2,
                        ops.collocation_points,
                        centers.ncols(),
                        false,
                    )
                }
                (
                    SmoothBasisSpec::Duchon { feature_cols, .. },
                    BasisMetadata::Duchon {
                        centers,
                        length_scale,
                        power,
                        nullspace_order,
                        identifiability_transform,
                        input_scales,
                        aniso_log_scales,
                        // Duchon terms without stored collocation points fall through
                        // to the catch-all arm and are skipped.
                        operator_collocation_points: Some(collocation_points),
                        ..
                    },
                ) => {
                    // An optional length scale is adjusted only when both it and the
                    // input scales are present.
                    let collocation_length_scale = match (length_scale, input_scales.as_deref()) {
                        (Some(ls), Some(scales)) => {
                            Some(compensate_length_scale_for_standardization(*ls, scales))
                        }
                        (Some(ls), None) => Some(*ls),
                        (None, _) => None,
                    };
                    let ops =
                        gam_terms::basis::build_duchon_collocation_operator_matriceswithworkspace(
                            centers.view(),
                            collocation_points.view(),
                            None,
                            collocation_length_scale,
                            *power,
                            *nullspace_order,
                            aniso_log_scales.as_deref(),
                            identifiability_transform.as_ref().map(|z| z.view()),
                            2,
                            &mut BasisWorkspace::default(),
                        )?;
                    (
                        feature_cols.clone(),
                        ops.d0,
                        ops.d1,
                        ops.d2,
                        ops.collocation_points,
                        centers.ncols(),
                        true,
                    )
                }
                _ => continue,
            };
        // Subtract each column's mean from every row of `d0`.
        if center_mass_rows && d0.nrows() > 0 && d0.ncols() > 0 {
            let means = d0.sum_axis(Axis(0)).mapv(|v| v / d0.nrows() as f64);
            for mut row in d0.rows_mut() {
                row -= &means;
            }
        }

        // Divide each operator by the square root of its penalty's normalization
        // scale; the scale is floored at 1e-12 before the square root.
        let mass_scale = mass_scale.max(1e-12).sqrt();
        let tension_scale = tension_scale.max(1e-12).sqrt();
        let stiffness_scale = stiffness_scale.max(1e-12).sqrt();
        d0.mapv_inplace(|v| v / mass_scale);
        d1.mapv_inplace(|v| v / tension_scale);
        d2.mapv_inplace(|v| v / stiffness_scale);

        let coeff_global_range =
            (smooth_start + term_fit.coeff_range.start)..(smooth_start + term_fit.coeff_range.end);
        // Every operator must have exactly one column per coefficient of the term.
        if d0.ncols() != coeff_global_range.len()
            || d1.ncols() != coeff_global_range.len()
            || d2.ncols() != coeff_global_range.len()
        {
            crate::bail_invalid_estim!(
                "spatial operator dimension mismatch for term '{}': D0 cols={}, D1 cols={}, D2 cols={}, coeffs={}",
                term_fit.name,
                d0.ncols(),
                d1.ncols(),
                d2.ncols(),
                coeff_global_range.len()
            );
        }
        out.push(SpatialOperatorRuntimeCache {
            termname: term_fit.name.clone(),
            feature_cols,
            coeff_global_range,
            mass_penalty_global_idx: mass_global_idx,
            tension_penalty_global_idx: tension_global_idx,
            stiffness_penalty_global_idx: stiffness_global_idx,
            d0,
            d1,
            d2,
            collocation_points,
            dimension: dim,
        });
    }
    Ok(out)
}
1374
1375fn scalar_operator_response_variance(
1387 operator: &Array2<f64>,
1388 cov_local: &Array2<f64>,
1389) -> Array1<f64> {
1390 Array1::from_iter(operator.rows().into_iter().map(|row| {
1391 let s = cov_local.dot(&row);
1392 row.dot(&s).max(0.0)
1393 }))
1394}
1395
1396fn grouped_operator_response_variance(
1407 operator: &Array2<f64>,
1408 block_dim: usize,
1409 cov_local: &Array2<f64>,
1410) -> Result<Array1<f64>, EstimationError> {
1411 if block_dim == 0 || !operator.nrows().is_multiple_of(block_dim) {
1412 crate::bail_invalid_estim!(
1413 "grouped variance row layout invalid: rows={}, block_dim={block_dim}",
1414 operator.nrows()
1415 );
1416 }
1417 let p = operator.nrows() / block_dim;
1418 let mut out = Array1::<f64>::zeros(p);
1419 for k in 0..p {
1420 let mut acc = 0.0;
1421 for axis in 0..block_dim {
1422 let row = operator.row(k * block_dim + axis);
1423 let s = cov_local.dot(&row);
1424 acc += row.dot(&s);
1425 }
1426 out[k] = acc.max(0.0);
1427 }
1428 Ok(out)
1429}
1430
1431fn compute_spatial_adaptiveweights_for_beta(
1432 beta: &Array1<f64>,
1433 caches: &[SpatialOperatorRuntimeCache],
1434 epsilon_0: f64,
1435 epsilon_g: f64,
1436 epsilon_c: f64,
1437 weight_floor: f64,
1438 weight_ceiling: f64,
1439 beta_covariance: Option<&Array2<f64>>,
1440) -> Result<Vec<SpatialAdaptiveWeights>, EstimationError> {
1441 caches
1473 .iter()
1474 .map(|cache| {
1475 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
1476 let exact = SpatialPenaltyExactState::from_beta_local(
1477 beta_local,
1478 cache,
1479 [epsilon_0, epsilon_g, epsilon_c],
1480 )?;
1481 let cov_local = beta_covariance.map(|cov| {
1482 cov.slice(s![
1483 cache.coeff_global_range.clone(),
1484 cache.coeff_global_range.clone()
1485 ])
1486 .to_owned()
1487 });
1488 let dim = cache.dimension;
1489 let (var_0, var_g, var_c) = match cov_local.as_ref() {
1490 Some(cov) => (
1491 scalar_operator_response_variance(&cache.d0, cov),
1492 grouped_operator_response_variance(&cache.d1, dim, cov)?,
1493 grouped_operator_response_variance(&cache.d2, dim * dim, cov)?,
1494 ),
1495 None => (
1496 Array1::<f64>::zeros(exact.magnitude.signal.len()),
1497 Array1::<f64>::zeros(exact.gradient.norm.len()),
1498 Array1::<f64>::zeros(exact.curvature.norm.len()),
1499 ),
1500 };
1501 let (_, inv_0) = exact.magnitude.surrogateweights_posterior_snr(
1502 &var_0,
1503 weight_floor,
1504 weight_ceiling,
1505 );
1506 let (_, inv_g) =
1507 exact
1508 .gradient
1509 .surrogateweights_posterior_snr(&var_g, weight_floor, weight_ceiling);
1510 let (_, inv_c) = exact.curvature.surrogateweights_posterior_snr(
1511 &var_c,
1512 weight_floor,
1513 weight_ceiling,
1514 );
1515 Ok(SpatialAdaptiveWeights {
1516 inv_magweight: inv_0,
1517 invgradweight: inv_g,
1518 inv_lapweight: inv_c,
1519 })
1520 })
1521 .collect()
1522}
1523
1524fn compute_initial_epsilons(
1525 beta: &Array1<f64>,
1526 caches: &[SpatialOperatorRuntimeCache],
1527 min_epsilon: f64,
1528) -> Result<(f64, f64, f64), EstimationError> {
1529 let mut fvals = Vec::<f64>::new();
1530 let mut gvals = Vec::<f64>::new();
1531 let mut cvals = Vec::<f64>::new();
1532 for cache in caches {
1533 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
1534 let exact = SpatialPenaltyExactState::from_beta_local(
1535 beta_local,
1536 cache,
1537 [min_epsilon, min_epsilon, min_epsilon],
1538 )?;
1539 let (f, g, c) = exact.absolute_collocation_magnitudes();
1540 fvals.extend(f.iter().copied());
1541 gvals.extend(g.iter().copied());
1542 cvals.extend(c.iter().copied());
1543 }
1544 let eps_0 = robust_epsilon_from_samples(&fvals, min_epsilon);
1550 let eps_g = robust_epsilon_from_samples(&gvals, min_epsilon);
1551 let eps_c = robust_epsilon_from_samples(&cvals, min_epsilon);
1552 Ok((eps_0, eps_g, eps_c))
1553}
1554
1555fn exact_spatial_adaptive_penalty_index_set(
1556 caches: &[SpatialOperatorRuntimeCache],
1557) -> BTreeSet<usize> {
1558 let mut out = BTreeSet::new();
1559 for cache in caches {
1560 out.insert(cache.mass_penalty_global_idx);
1561 out.insert(cache.tension_penalty_global_idx);
1562 out.insert(cache.stiffness_penalty_global_idx);
1563 }
1564 out
1565}
1566
1567fn build_spatial_adaptive_hyperspecs(cache_count: usize) -> Vec<SpatialAdaptiveHyperSpec> {
1568 let mut out = Vec::with_capacity(cache_count * 3 + 3);
1569 for cache_index in 0..cache_count {
1570 out.push(SpatialAdaptiveHyperSpec {
1571 cache_index,
1572 kind: SpatialAdaptiveHyperKind::LogLambdaMagnitude,
1573 });
1574 out.push(SpatialAdaptiveHyperSpec {
1575 cache_index,
1576 kind: SpatialAdaptiveHyperKind::LogLambdaGradient,
1577 });
1578 out.push(SpatialAdaptiveHyperSpec {
1579 cache_index,
1580 kind: SpatialAdaptiveHyperKind::LogLambdaCurvature,
1581 });
1582 }
1583 out.push(SpatialAdaptiveHyperSpec {
1584 cache_index: 0,
1585 kind: SpatialAdaptiveHyperKind::LogEpsilonMagnitude,
1586 });
1587 out.push(SpatialAdaptiveHyperSpec {
1588 cache_index: 0,
1589 kind: SpatialAdaptiveHyperKind::LogEpsilonGradient,
1590 });
1591 out.push(SpatialAdaptiveHyperSpec {
1592 cache_index: 0,
1593 kind: SpatialAdaptiveHyperKind::LogEpsilonCurvature,
1594 });
1595 out
1596}
1597
1598fn penalty_matrixwith_local_block(
1599 total_dim: usize,
1600 coeff_range: Range<usize>,
1601 local: &Array2<f64>,
1602) -> Array2<f64> {
1603 let mut out = Array2::<f64>::zeros((total_dim, total_dim));
1604 out.slice_mut(s![coeff_range.clone(), coeff_range])
1605 .assign(local);
1606 out
1607}
1608
1609fn fit_term_collectionwith_exact_spatial_adaptive_regularization(
1610 baseline: FittedTermCollection,
1611 y: ArrayView1<'_, f64>,
1612 weights: ArrayView1<'_, f64>,
1613 offset: ArrayView1<'_, f64>,
1614 family: LikelihoodSpec,
1615 options: &FitOptions,
1616 runtime_caches: &[SpatialOperatorRuntimeCache],
1617) -> Result<FittedTermCollection, EstimationError> {
1618 let adaptive_opts = options.adaptive_regularization.clone().unwrap_or_default();
1647 let adaptive_penalty_indices = exact_spatial_adaptive_penalty_index_set(runtime_caches);
1648 let p_total = baseline.design.design.ncols();
1649 struct RetainedPenaltySetup {
1650 global_idx: usize,
1651 global_penalty: Array2<f64>,
1652 nullspace_dim: usize,
1653 log_lambda: f64,
1654 col_range: Range<usize>,
1655 hessian_piece: Array2<f64>,
1656 }
1657 use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
1658 let retained_setups = baseline
1659 .design
1660 .penalties
1661 .par_iter()
1662 .enumerate()
1663 .map(|(idx, bp)| {
1664 if adaptive_penalty_indices.contains(&idx) {
1665 return None;
1666 }
1667 let lambda = baseline.fit.lambdas[idx];
1668 Some(RetainedPenaltySetup {
1669 global_idx: idx,
1670 global_penalty: bp.to_global(p_total),
1671 nullspace_dim: baseline
1672 .design
1673 .nullspace_dims
1674 .get(idx)
1675 .copied()
1676 .unwrap_or(0),
1677 log_lambda: lambda.max(1e-12).ln(),
1678 col_range: bp.col_range.clone(),
1679 hessian_piece: bp.local.mapv(|v| lambda * v),
1680 })
1681 })
1682 .collect::<Vec<_>>();
1683 let retained_count = retained_setups
1684 .iter()
1685 .filter(|setup| setup.is_some())
1686 .count();
1687 let mut retained_penalties = Vec::<Array2<f64>>::with_capacity(retained_count);
1688 let mut retained_nullspace_dims = Vec::<usize>::with_capacity(retained_count);
1689 let mut retained_log_lambdas = Vec::<f64>::with_capacity(retained_count);
1690 let mut retained_global_indices = Vec::<usize>::with_capacity(retained_count);
1691 let mut fixed_quadratichessian = Array2::<f64>::zeros((p_total, p_total));
1692 for setup in retained_setups.into_iter().flatten() {
1693 retained_penalties.push(setup.global_penalty);
1694 retained_nullspace_dims.push(setup.nullspace_dim);
1695 retained_log_lambdas.push(setup.log_lambda);
1696 retained_global_indices.push(setup.global_idx);
1697 fixed_quadratichessian
1698 .slice_mut(s![setup.col_range.clone(), setup.col_range])
1699 .scaled_add(1.0, &setup.hessian_piece);
1700 }
1701
1702 let (eps_0_init, eps_g_init, eps_c_init) = compute_initial_epsilons(
1703 &baseline.fit.beta,
1704 runtime_caches,
1705 adaptive_opts.min_epsilon,
1706 )?;
1707 let mut initial_theta =
1708 Array1::<f64>::zeros(retained_penalties.len() + runtime_caches.len() * 3 + 3);
1709 for (idx, value) in retained_log_lambdas.iter().enumerate() {
1710 initial_theta[idx] = *value;
1711 }
1712 let adaptive_log_lambda_components = runtime_caches
1713 .par_iter()
1714 .map(|cache| {
1715 [
1716 baseline.fit.lambdas[cache.mass_penalty_global_idx]
1717 .max(1e-12)
1718 .ln(),
1719 baseline.fit.lambdas[cache.tension_penalty_global_idx]
1720 .max(1e-12)
1721 .ln(),
1722 baseline.fit.lambdas[cache.stiffness_penalty_global_idx]
1723 .max(1e-12)
1724 .ln(),
1725 ]
1726 })
1727 .collect::<Vec<_>>();
1728 let mut at = retained_penalties.len();
1729 for logs in &adaptive_log_lambda_components {
1730 initial_theta[at] = logs[0];
1731 initial_theta[at + 1] = logs[1];
1732 initial_theta[at + 2] = logs[2];
1733 at += 3;
1734 }
1735 initial_theta[at] = eps_0_init.max(adaptive_opts.min_epsilon).ln();
1736 initial_theta[at + 1] = eps_g_init.max(adaptive_opts.min_epsilon).ln();
1737 initial_theta[at + 2] = eps_c_init.max(adaptive_opts.min_epsilon).ln();
1738
1739 let hyperspecs = build_spatial_adaptive_hyperspecs(runtime_caches.len());
1740 let zero_psi_op: std::sync::Arc<dyn gam_custom_family::CustomFamilyPsiDerivativeOperator> =
1741 std::sync::Arc::new(gam_custom_family::ZeroPsiDerivativeOperator::new(
1742 baseline.design.design.nrows(),
1743 baseline.design.design.ncols(),
1744 ));
1745 let derivative_blocks = vec![
1746 hyperspecs
1747 .par_iter()
1748 .map(|_| CustomFamilyBlockPsiDerivative {
1749 penalty_index: None,
1750 x_psi: Array2::<f64>::zeros((0, 0)),
1751 s_psi: Array2::<f64>::zeros((0, 0)),
1752 s_psi_components: None,
1753 s_psi_penalty_components: None,
1754 x_psi_psi: None,
1755 s_psi_psi: None,
1756 s_psi_psi_components: None,
1757 s_psi_psi_penalty_components: None,
1758 implicit_operator: Some(std::sync::Arc::clone(&zero_psi_op)),
1759 implicit_axis: 0,
1760 implicit_group_id: None,
1761 })
1762 .collect::<Vec<_>>(),
1763 ];
1764
1765 let mixture_link_state = options
1766 .mixture_link
1767 .clone()
1768 .as_ref()
1769 .map(state_fromspec)
1770 .transpose()
1771 .map_err(EstimationError::InvalidInput)?;
1772 let sas_link_state = options
1773 .sas_link
1774 .map(|spec| {
1775 if family.is_binomial_beta_logistic() {
1776 state_from_beta_logisticspec(spec)
1777 } else {
1778 state_from_sasspec(spec)
1779 }
1780 })
1781 .transpose()
1782 .map_err(EstimationError::InvalidInput)?;
1783 let latent_cloglog_state = options.latent_cloglog;
1784 let shared_y = Arc::new(y.to_owned());
1785 let sharedweights = Arc::new(weights.to_owned());
1786 let shared_design = baseline
1787 .design
1788 .design
1789 .try_to_dense_arc("spatial adaptive exact hyperfit design")
1790 .map_err(EstimationError::InvalidInput)?;
1791 let shared_offset = Arc::new(offset.to_owned());
1792 let shared_runtime_caches = Arc::new(runtime_caches.to_vec());
1793 let shared_hyperspecs = Arc::new(hyperspecs.clone());
1794 let zero_quadratic = Arc::new(Array2::<f64>::zeros((
1795 baseline.design.design.ncols(),
1796 baseline.design.design.ncols(),
1797 )));
1798 let base_family = SpatialAdaptiveExactFamily {
1799 family: family.clone(),
1800 latent_cloglog_state,
1801 mixture_link_state: mixture_link_state.clone(),
1802 sas_link_state,
1803 y: shared_y.clone(),
1804 weights: sharedweights.clone(),
1805 design: shared_design.clone(),
1806 offset: shared_offset.clone(),
1807 linear_constraints: baseline.design.linear_constraints.clone(),
1808 runtime_caches: shared_runtime_caches.clone(),
1809 adaptive_params: Vec::new(),
1810 fixed_quadratichessian: zero_quadratic.clone(),
1811 hyperspecs: shared_hyperspecs.clone(),
1812 exact_eval_cache: Arc::new(Mutex::new(None)),
1813 };
1814
1815 let rho_dim = retained_penalties.len();
1816 let operator_slots_end = rho_dim + runtime_caches.len() * 3;
1817 const UNIFIED_LOG_WINDOW: f64 = 6.0;
1827 const RETAINED_LAMBDA_LOG_LOWER_FLOOR: f64 = -30.0;
1828 const RETAINED_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
1829 const OPERATOR_LAMBDA_LOG_LOWER_FLOOR: f64 = -10.0;
1830 const OPERATOR_LAMBDA_LOG_UPPER_CAP: f64 = 30.0;
1831 let epsilon_floor_log = adaptive_opts.min_epsilon.max(1e-12).ln();
1832 let anchored_bound = |idx: usize, sign: f64| -> f64 {
1833 let raw = initial_theta[idx] + sign * UNIFIED_LOG_WINDOW;
1834 if idx < rho_dim {
1835 raw.clamp(
1836 RETAINED_LAMBDA_LOG_LOWER_FLOOR,
1837 RETAINED_LAMBDA_LOG_UPPER_CAP,
1838 )
1839 } else if idx < operator_slots_end {
1840 raw.clamp(
1841 OPERATOR_LAMBDA_LOG_LOWER_FLOOR,
1842 OPERATOR_LAMBDA_LOG_UPPER_CAP,
1843 )
1844 } else {
1845 raw.max(epsilon_floor_log)
1846 }
1847 };
1848 let eps_lower =
1849 Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, -1.0)));
1850 let eps_upper = Array1::from_iter((0..initial_theta.len()).map(|idx| anchored_bound(idx, 1.0)));
1851 let blockspec = ParameterBlockSpec {
1852 name: "eta".to_string(),
1853 design: baseline.design.design.clone(),
1854 offset: offset.to_owned(),
1855 penalties: retained_penalties
1856 .iter()
1857 .cloned()
1858 .map(PenaltyMatrix::Dense)
1859 .collect(),
1860 nullspace_dims: retained_nullspace_dims.clone(),
1861 initial_log_lambdas: Array1::from_vec(retained_log_lambdas.clone()),
1862 initial_beta: Some(baseline.fit.beta.clone()),
1863 gauge_priority: 100,
1864 jacobian_callback: None,
1865 stacked_design: None,
1866 stacked_offset: None,
1867 };
1868 let screening_cap = Arc::new(AtomicUsize::new(0));
1869 let outer_opts = BlockwiseFitOptions {
1870 inner_max_cycles: options.max_iter,
1871 inner_tol: options.tol,
1872 outer_max_iter: options.max_iter,
1873 outer_tol: options.tol,
1874 compute_covariance: false,
1875 screening_max_inner_iterations: Some(Arc::clone(&screening_cap)),
1876 ..BlockwiseFitOptions::default()
1877 };
1878
1879 use gam_solve::rho_optimizer::OuterProblem;
1880 use gam_problem::{DeclaredHessianForm, Derivative, HessianResult, OuterEval};
1881
1882 struct SpatialAdaptiveOuterState {
1883 warm_cache: Option<CustomFamilyWarmStart>,
1884 last_eval: Option<(
1885 Array1<f64>,
1886 f64,
1887 Array1<f64>,
1888 HessianResult,
1889 CustomFamilyWarmStart,
1890 )>,
1891 }
1892
1893 let n_theta = initial_theta.len();
1894
1895 let theta_bounds = Some((eps_lower.clone(), eps_upper.clone()));
1898 let clamp_theta = {
1899 let lo = eps_lower;
1900 let hi = eps_upper;
1901 move |theta: &Array1<f64>| -> Array1<f64> {
1902 let mut clamped = theta.clone();
1903 for i in 0..clamped.len() {
1904 clamped[i] = clamped[i].clamp(lo[i], hi[i]);
1905 }
1906 clamped
1907 }
1908 };
1909
1910 let decode_theta = |theta: &Array1<f64>| -> (Array1<f64>, Vec<SpatialAdaptiveTermHyperParams>) {
1911 let rho = theta.slice(s![..rho_dim]).to_owned();
1912 let adaptive_lambda_start = rho_dim;
1913 let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
1914 let eps = [
1915 theta[adaptive_lambda_end].exp(),
1916 theta[adaptive_lambda_end + 1].exp(),
1917 theta[adaptive_lambda_end + 2].exp(),
1918 ];
1919 let adaptive_params = runtime_caches
1920 .iter()
1921 .enumerate()
1922 .map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
1923 lambda: [
1924 theta[adaptive_lambda_start + cache_idx * 3].exp(),
1925 theta[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
1926 theta[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
1927 ],
1928 epsilon: eps,
1929 })
1930 .collect::<Vec<_>>();
1931 (rho, adaptive_params)
1932 };
1933 let analytic_outer_hessian_available =
1934 gam_custom_family::joint_exact_analytic_outer_hessian_available()
1935 && base_family
1936 .exact_outer_derivative_order(std::slice::from_ref(&blockspec), &outer_opts)
1937 .has_hessian()
1938 && gam_custom_family::exact_newton_outer_geometry_supports_second_order_solver(
1939 &base_family,
1940 );
1941 let outer_max_iter = gam_custom_family::cost_gated_first_order_max_iter(
1942 options.max_iter,
1943 base_family.coefficient_gradient_cost(std::slice::from_ref(&blockspec)),
1944 analytic_outer_hessian_available,
1945 );
1946 if outer_max_iter < options.max_iter {
1947 log::info!(
1948 "[OUTER] exact spatial adaptive regularization: first-order work gate reduced outer_max_iter {} -> {}",
1949 options.max_iter,
1950 outer_max_iter,
1951 );
1952 }
1953 let problem = OuterProblem::new(n_theta)
1959 .with_gradient(Derivative::Analytic)
1960 .with_hessian(if analytic_outer_hessian_available {
1961 DeclaredHessianForm::Either
1962 } else {
1963 DeclaredHessianForm::Unavailable
1964 })
1965 .with_fallback_policy(gam_solve::rho_optimizer::FallbackPolicy::Disabled)
1966 .with_psi_dim(n_theta.saturating_sub(rho_dim))
1967 .with_tolerance(options.tol)
1968 .with_max_iter(outer_max_iter)
1969 .with_seed_config(gam_problem::SeedConfig::default())
1970 .with_screening_cap(Arc::clone(&screening_cap))
1971 .with_initial_rho(initial_theta.clone());
1972 let problem = if let Some((lo, hi)) = theta_bounds {
1973 problem.with_bounds(lo, hi)
1974 } else {
1975 problem
1976 };
1977
1978 let eval_outer = |st: &mut SpatialAdaptiveOuterState,
1979 theta: &Array1<f64>,
1980 order: gam_solve::rho_optimizer::OuterEvalOrder|
1981 -> Result<OuterEval, EstimationError> {
1982 let theta = clamp_theta(theta);
1983
1984 if let Some((cached_theta, cached_cost, cached_grad, cached_hess, cached_warm)) =
1985 &st.last_eval
1986 && cached_theta.len() == theta.len()
1987 && cached_theta
1988 .iter()
1989 .zip(theta.iter())
1990 .all(|(&a, &b)| (a - b).abs() <= 1e-12)
1991 && (!matches!(
1992 order,
1993 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
1994 ) || analytic_outer_hessian_available)
1995 {
1996 st.warm_cache = Some(cached_warm.clone());
1997 return Ok(OuterEval {
1998 cost: *cached_cost,
1999 gradient: cached_grad.clone(),
2000 hessian: if matches!(
2001 order,
2002 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2003 ) && analytic_outer_hessian_available
2004 {
2005 cached_hess.clone()
2006 } else {
2007 HessianResult::Unavailable
2008 },
2009 inner_beta_hint: None,
2010 });
2011 }
2012
2013 let (rho, adaptive_params) = decode_theta(&theta);
2014 let family_eval = base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2015 let need_hessian = matches!(
2016 order,
2017 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2018 ) && analytic_outer_hessian_available;
2019 let result = evaluate_custom_family_joint_hyper(
2020 &family_eval,
2021 std::slice::from_ref(&blockspec),
2022 &outer_opts,
2023 &rho,
2024 &derivative_blocks,
2025 st.warm_cache.as_ref(),
2026 if need_hessian {
2027 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueGradientHessian
2028 } else {
2029 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient
2030 },
2031 )
2032 .map_err(|e| {
2033 EstimationError::RemlOptimizationFailed(format!("spatial adaptive eval failed: {e}"))
2034 })?;
2035 if !result.inner_converged {
2036 st.warm_cache = Some(result.warm_start.clone());
2037 return Err(EstimationError::RemlOptimizationFailed(
2038 "exact spatial adaptive inner solve did not converge".to_string(),
2039 ));
2040 }
2041 if !result.objective.is_finite() || result.gradient.iter().any(|v| !v.is_finite()) {
2042 return Err(EstimationError::RemlOptimizationFailed(
2043 "exact spatial adaptive objective returned non-finite values".to_string(),
2044 ));
2045 }
2046 let hessian_result = if need_hessian {
2047 if !result.outer_hessian.is_analytic() {
2048 return Err(EstimationError::RemlOptimizationFailed(
2049 "exact spatial adaptive objective did not return an exact outer Hessian"
2050 .to_string(),
2051 ));
2052 }
2053 match result.outer_hessian.dim() {
2054 Some(dim) if dim == theta.len() => {}
2055 Some(dim) => {
2056 return Err(EstimationError::RemlOptimizationFailed(format!(
2057 "exact spatial adaptive outer Hessian dimension mismatch: got {dim}, expected {}",
2058 theta.len(),
2059 )));
2060 }
2061 None => {
2062 return Err(EstimationError::RemlOptimizationFailed(
2063 "exact spatial adaptive objective did not report an outer Hessian dimension"
2064 .to_string(),
2065 ));
2066 }
2067 }
2068 st.last_eval = Some((
2069 theta.clone(),
2070 result.objective,
2071 result.gradient.clone(),
2072 result.outer_hessian.clone(),
2073 result.warm_start.clone(),
2074 ));
2075 result.outer_hessian
2076 } else {
2077 HessianResult::Unavailable
2078 };
2079 st.warm_cache = Some(result.warm_start);
2080 Ok(OuterEval {
2081 cost: result.objective,
2082 gradient: result.gradient,
2083 hessian: hessian_result,
2084 inner_beta_hint: None,
2085 })
2086 };
2087
2088 let mut obj = problem.build_objective_with_screening_proxy(
2089 SpatialAdaptiveOuterState {
2090 warm_cache: None,
2091 last_eval: None,
2092 },
2093 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2094 let theta = clamp_theta(theta);
2095 let (rho, adaptive_params) = decode_theta(&theta);
2096 let family_eval =
2097 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2098 let result = evaluate_custom_family_joint_hyper(
2099 &family_eval,
2100 std::slice::from_ref(&blockspec),
2101 &outer_opts,
2102 &rho,
2103 &derivative_blocks,
2104 st.warm_cache.as_ref(),
2105 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
2106 )
2107 .map_err(|e| {
2108 EstimationError::RemlOptimizationFailed(format!(
2109 "spatial adaptive cost eval failed: {e}"
2110 ))
2111 })?;
2112 if !result.inner_converged {
2113 st.warm_cache = Some(result.warm_start);
2114 return Err(EstimationError::RemlOptimizationFailed(
2115 "exact spatial adaptive cost inner solve did not converge".to_string(),
2116 ));
2117 }
2118 st.warm_cache = Some(result.warm_start);
2119 Ok(result.objective)
2120 },
2121 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2122 eval_outer(
2123 st,
2124 theta,
2125 if analytic_outer_hessian_available {
2126 gam_solve::rho_optimizer::OuterEvalOrder::ValueGradientHessian
2127 } else {
2128 gam_solve::rho_optimizer::OuterEvalOrder::ValueAndGradient
2129 },
2130 )
2131 },
2132 |st: &mut SpatialAdaptiveOuterState,
2133 theta: &Array1<f64>,
2134 order: gam_solve::rho_optimizer::OuterEvalOrder| {
2135 eval_outer(st, theta, order)
2136 },
2137 Some(|st: &mut SpatialAdaptiveOuterState| {
2138 st.warm_cache = None;
2139 st.last_eval = None;
2140 }),
2141 Some(|st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2142 let theta = clamp_theta(theta);
2143 let (rho, adaptive_params) = decode_theta(&theta);
2144 let family_eval =
2145 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2146 let result = evaluate_custom_family_joint_hyper_efs(
2147 &family_eval,
2148 std::slice::from_ref(&blockspec),
2149 &outer_opts,
2150 &rho,
2151 &derivative_blocks,
2152 st.warm_cache.as_ref(),
2153 )
2154 .map_err(|e| {
2155 EstimationError::RemlOptimizationFailed(format!(
2156 "spatial adaptive EFS eval failed: {e}"
2157 ))
2158 })?;
2159 if !result.inner_converged {
2160 st.warm_cache = Some(result.warm_start);
2161 return Err(EstimationError::RemlOptimizationFailed(
2162 "exact spatial adaptive EFS inner solve did not converge".to_string(),
2163 ));
2164 }
2165 st.warm_cache = Some(result.warm_start);
2166 Ok(result.efs_eval)
2167 }),
2168 |st: &mut SpatialAdaptiveOuterState, theta: &Array1<f64>| {
2180 let theta = clamp_theta(theta);
2181 let (rho, adaptive_params) = decode_theta(&theta);
2182 let family_eval =
2183 base_family.with_adaptive_params(adaptive_params, zero_quadratic.clone());
2184 let result = evaluate_custom_family_joint_hyper(
2185 &family_eval,
2186 std::slice::from_ref(&blockspec),
2187 &outer_opts,
2188 &rho,
2189 &derivative_blocks,
2190 st.warm_cache.as_ref(),
2191 gam_solve::estimate::reml::reml_outer_engine::EvalMode::ValueOnly,
2192 )
2193 .map_err(|e| {
2194 EstimationError::RemlOptimizationFailed(format!(
2195 "spatial adaptive screening eval failed: {e}"
2196 ))
2197 })?;
2198 st.warm_cache = Some(result.warm_start);
2199 Ok(result.objective)
2200 },
2201 );
2202
2203 let outer_result = problem
2204 .run(&mut obj, "exact spatial adaptive regularization")
2205 .map_err(|e| {
2206 EstimationError::InvalidInput(format!(
2207 "exact spatial adaptive outer optimization failed: {e}"
2208 ))
2209 })?;
2210 if !outer_result.converged {
2211 let rel_to_cost_threshold = options.tol * (1.0_f64 + outer_result.final_value.abs());
2228 if let Some(final_grad) = outer_result
2232 .final_grad_norm
2233 .filter(|v| v.is_finite() && *v <= rel_to_cost_threshold)
2234 {
2235 log::info!(
2236 "[spatial-adaptive] outer optimization hit max_iter={} but \
2237 projected gradient norm {:.3e} ≤ τ·(1+|f|) = {:.3e} \
2238 (τ={:.3e}, |f|={:.3e}); accepting iterate under the mgcv-style \
2239 relative-to-cost REML convergence criterion.",
2240 outer_result.iterations,
2241 final_grad,
2242 rel_to_cost_threshold,
2243 options.tol,
2244 outer_result.final_value.abs(),
2245 );
2246 } else {
2247 crate::bail_invalid_estim!(
2248 "exact spatial adaptive outer optimization did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
2249 outer_result.iterations,
2250 outer_result.final_value,
2251 outer_result.final_grad_norm_report(),
2252 );
2253 }
2254 }
2255 let outer_iterations = outer_result.iterations;
2256 let outer_grad_norm: Option<f64> = outer_result.final_grad_norm;
2259 let theta_star = outer_result.rho;
2260 let rho_star = theta_star.slice(s![..rho_dim]).to_owned();
2261 let adaptive_lambda_start = rho_dim;
2262 let adaptive_lambda_end = adaptive_lambda_start + runtime_caches.len() * 3;
2263 let eps_star = [
2264 theta_star[adaptive_lambda_end].exp(),
2265 theta_star[adaptive_lambda_end + 1].exp(),
2266 theta_star[adaptive_lambda_end + 2].exp(),
2267 ];
2268 let adaptive_params = runtime_caches
2269 .iter()
2270 .enumerate()
2271 .map(|(cache_idx, _)| SpatialAdaptiveTermHyperParams {
2272 lambda: [
2273 theta_star[adaptive_lambda_start + cache_idx * 3].exp(),
2274 theta_star[adaptive_lambda_start + cache_idx * 3 + 1].exp(),
2275 theta_star[adaptive_lambda_start + cache_idx * 3 + 2].exp(),
2276 ],
2277 epsilon: eps_star,
2278 })
2279 .collect::<Vec<_>>();
2280 let mut fixed_total = Array2::<f64>::zeros((
2281 baseline.design.design.ncols(),
2282 baseline.design.design.ncols(),
2283 ));
2284 for (idx, penalty) in retained_penalties.iter().enumerate() {
2285 fixed_total.scaled_add(rho_star[idx].exp(), penalty);
2286 }
2287 let final_family =
2288 base_family.with_adaptive_params(adaptive_params.clone(), Arc::new(fixed_total.clone()));
2289 let final_blockspec = ParameterBlockSpec {
2290 name: "eta".to_string(),
2291 design: baseline.design.design.clone(),
2292 offset: offset.to_owned(),
2293 penalties: vec![],
2294 nullspace_dims: vec![],
2295 initial_log_lambdas: Array1::zeros(0),
2296 initial_beta: Some(baseline.fit.beta.clone()),
2297 gauge_priority: 100,
2298 jacobian_callback: None,
2299 stacked_design: None,
2300 stacked_offset: None,
2301 };
2302 let final_fit = fit_custom_family(
2303 &final_family,
2304 &[final_blockspec],
2305 &BlockwiseFitOptions {
2306 inner_max_cycles: options.max_iter,
2307 inner_tol: options.tol,
2308 outer_max_iter: 1,
2309 outer_tol: options.tol,
2310 compute_covariance: true,
2311 ..BlockwiseFitOptions::default()
2312 },
2313 )
2314 .map_err(EstimationError::CustomFamily)?;
2315 let beta = final_fit.block_states[0].beta.clone();
2316 let final_eval = final_family
2317 .exact_evaluation(&beta)
2318 .map_err(EstimationError::InvalidInput)?;
2319 let penalized_hessian = final_eval
2320 .totalobjectivehessian(&final_family.design)
2321 .map_err(EstimationError::InvalidInput)?;
2322 let beta_covariance = final_fit.covariance_conditional.clone();
2323 let beta_standard_errors = beta_covariance
2324 .as_ref()
2325 .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));
2326
2327 let mut full_lambdas = baseline.fit.lambdas.clone();
2328 for (idx, &global_idx) in retained_global_indices.iter().enumerate() {
2329 full_lambdas[global_idx] = rho_star[idx].exp();
2330 }
2331 for (cache_idx, cache) in runtime_caches.iter().enumerate() {
2332 full_lambdas[cache.mass_penalty_global_idx] = adaptive_params[cache_idx].lambda[0];
2333 full_lambdas[cache.tension_penalty_global_idx] = adaptive_params[cache_idx].lambda[1];
2334 full_lambdas[cache.stiffness_penalty_global_idx] = adaptive_params[cache_idx].lambda[2];
2335 }
2336
2337 let deviance = if family.is_gaussian_identity() {
2338 y.iter()
2339 .zip(final_eval.obs.mu.iter())
2340 .zip(weights.iter())
2341 .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
2342 .sum()
2343 } else {
2344 -2.0 * final_eval.obs.log_likelihood
2345 };
2346 let mut local_penalty_blocks =
2347 Vec::<PenaltySpec>::with_capacity(baseline.design.penalties.len());
2348 for (global_idx, bp) in baseline.design.penalties.iter().enumerate() {
2349 if adaptive_penalty_indices.contains(&global_idx) {
2350 let cache = runtime_caches
2351 .iter()
2352 .find(|cache| {
2353 cache.mass_penalty_global_idx == global_idx
2354 || cache.tension_penalty_global_idx == global_idx
2355 || cache.stiffness_penalty_global_idx == global_idx
2356 })
2357 .ok_or_else(|| {
2358 EstimationError::InvalidInput(format!(
2359 "missing runtime cache for adaptive penalty index {global_idx}"
2360 ))
2361 })?;
2362 let cache_idx = runtime_caches
2363 .iter()
2364 .position(|c| {
2365 c.mass_penalty_global_idx == global_idx
2366 || c.tension_penalty_global_idx == global_idx
2367 || c.stiffness_penalty_global_idx == global_idx
2368 })
2369 .ok_or_else(|| {
2370 EstimationError::InvalidInput(format!(
2371 "missing adaptive cache position for penalty index {global_idx}"
2372 ))
2373 })?;
2374 let state = &final_eval.adaptive_states[cache_idx];
2375 let local = if cache.mass_penalty_global_idx == global_idx {
2376 scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag())
2377 .mapv(|v| adaptive_params[cache_idx].lambda[0] * v)
2378 } else if cache.tension_penalty_global_idx == global_idx {
2379 grouped_operatorhessian(
2380 &cache.d1,
2381 cache.dimension,
2382 &state.gradient.betahessian_blocks(),
2383 )?
2384 .mapv(|v| adaptive_params[cache_idx].lambda[1] * v)
2385 } else {
2386 grouped_operatorhessian(
2387 &cache.d2,
2388 cache.dimension * cache.dimension,
2389 &state.curvature.betahessian_blocks(),
2390 )?
2391 .mapv(|v| adaptive_params[cache_idx].lambda[2] * v)
2392 };
2393 local_penalty_blocks.push(PenaltySpec::Dense(penalty_matrixwith_local_block(
2395 baseline.design.design.ncols(),
2396 cache.coeff_global_range.clone(),
2397 &local,
2398 )));
2399 } else {
2400 local_penalty_blocks.push(PenaltySpec::Dense(
2401 bp.to_global(p_total).mapv(|v| v * full_lambdas[global_idx]),
2402 ));
2403 }
2404 }
2405 let (edf_by_block, penalty_block_trace, edf_total) = if let Some(cov) = beta_covariance.as_ref()
2406 {
2407 exact_bounded_edf(
2408 &local_penalty_blocks,
2409 &Array1::from_elem(local_penalty_blocks.len(), 1.0),
2410 cov,
2411 )?
2412 } else {
2413 (
2414 vec![0.0; local_penalty_blocks.len()],
2415 vec![0.0; local_penalty_blocks.len()],
2416 0.0,
2417 )
2418 };
2419 let stable_penalty_term =
2420 2.0 * final_eval.adaptive_penalty_value + beta.dot(&fixed_total.dot(&beta));
2421 let standard_deviation = if family.is_gaussian_identity() {
2422 let denom = (y.len() as f64 - edf_total).max(1.0);
2423 (deviance / denom).sqrt()
2424 } else {
2425 1.0
2426 };
2427 let maps = compute_spatial_adaptiveweights_for_beta(
2428 &beta,
2429 runtime_caches,
2430 eps_star[0],
2431 eps_star[1],
2432 eps_star[2],
2433 adaptive_opts.weight_floor,
2434 adaptive_opts.weight_ceiling,
2435 beta_covariance.as_ref(),
2439 )?
2440 .into_iter()
2441 .zip(runtime_caches.iter())
2442 .map(|(w, cache)| AdaptiveSpatialMap {
2443 termname: cache.termname.clone(),
2444 feature_cols: cache.feature_cols.clone(),
2445 collocation_points: cache.collocation_points.clone(),
2446 inv_magweight: w.inv_magweight,
2447 invgradweight: w.invgradweight,
2448 inv_lapweight: w.inv_lapweight,
2449 })
2450 .collect::<Vec<_>>();
2451 let fitted_link = if family.is_latent_cloglog() {
2452 FittedLinkState::LatentCLogLog {
2453 state: latent_cloglog_state
2454 .expect("BinomialLatentCLogLog requires an explicit latent-cloglog state"),
2455 }
2456 } else if family.is_binomial_mixture() {
2457 mixture_link_state
2458 .clone()
2459 .map(|state| FittedLinkState::Mixture {
2460 state,
2461 covariance: None,
2462 })
2463 .unwrap_or(FittedLinkState::Standard(None))
2464 } else if family.is_binomial_sas() {
2465 sas_link_state
2466 .map(|state| FittedLinkState::Sas {
2467 state,
2468 covariance: None,
2469 })
2470 .unwrap_or(FittedLinkState::Standard(None))
2471 } else if family.is_binomial_beta_logistic() {
2472 sas_link_state
2473 .map(|state| FittedLinkState::BetaLogistic {
2474 state,
2475 covariance: None,
2476 })
2477 .unwrap_or(FittedLinkState::Standard(None))
2478 } else {
2479 FittedLinkState::Standard(None)
2480 };
2481 let max_abs_eta = final_eval
2482 .obs
2483 .eta
2484 .iter()
2485 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2486 let fitted = FittedTermCollection {
2487 fit: {
2488 let log_lambdas = full_lambdas.mapv(|v| v.max(1e-300).ln());
2489 let inf = FitInference {
2490 edf_by_block,
2491 penalty_block_trace,
2492 edf_total,
2493 smoothing_correction: None,
2494 penalized_hessian: penalized_hessian.clone().into(),
2497 working_weights: final_eval.obs.fisherweight.clone(),
2498 working_response: {
2499 let mut out = final_eval.obs.eta.clone();
2500 for i in 0..out.len() {
2501 let wi = final_eval.obs.fisherweight[i].max(1e-12);
2502 out[i] += final_eval.obs.score[i] / wi;
2503 }
2504 out
2505 },
2506 reparam_qs: None,
2507 dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2508 beta_covariance: beta_covariance
2509 .clone()
2510 .map(gam_problem::dispersion_cov::PhiScaledCovariance::from),
2511 beta_standard_errors,
2512 beta_covariance_corrected: None,
2513 beta_standard_errors_corrected: None,
2514 beta_covariance_frequentist: None,
2515 coefficient_influence: None,
2516 weighted_gram: None,
2517 bias_correction_beta: None,
2518 };
2519 let geometry = Some(gam_solve::estimate::FitGeometry {
2520 penalized_hessian: penalized_hessian.into(),
2521 working_weights: inf.working_weights.clone(),
2522 working_response: inf.working_response.clone(),
2523 });
2524 let covariance_conditional = beta_covariance;
2525 let pirls_status_val = if final_fit.outer_converged {
2526 gam_solve::pirls::PirlsStatus::Converged
2527 } else {
2528 gam_solve::pirls::PirlsStatus::StalledAtValidMinimum
2529 };
2530 UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
2531 blocks: vec![gam_solve::estimate::FittedBlock {
2532 beta: beta.clone(),
2533 role: gam_problem::BlockRole::Mean,
2534 edf: edf_total,
2535 lambdas: full_lambdas.clone(),
2536 }],
2537 log_lambdas,
2538 lambdas: full_lambdas,
2539 likelihood_scale: family.default_scale_metadata(),
2540 likelihood_family: Some(family),
2541 log_likelihood_normalization:
2542 gam_spec::LogLikelihoodNormalization::UserProvided,
2543 log_likelihood: final_eval.obs.log_likelihood,
2544 deviance,
2545 reml_score: final_fit.penalized_objective,
2546 stable_penalty_term,
2547 penalized_objective: final_fit.penalized_objective,
2548 used_device: false,
2549 outer_iterations,
2550 outer_converged: final_fit.outer_converged,
2551 outer_gradient_norm: outer_grad_norm,
2552 standard_deviation,
2553 covariance_conditional,
2554 covariance_corrected: None,
2555 inference: Some(inf),
2556 fitted_link,
2557 geometry,
2558 block_states: Vec::new(),
2559 pirls_status: pirls_status_val,
2560 max_abs_eta,
2561 constraint_kkt: None,
2562 artifacts: gam_solve::estimate::FitArtifacts {
2563 pirls: None,
2564 ..Default::default()
2565 },
2566 inner_cycles: 0,
2567 })?
2568 },
2569 design: baseline.design,
2570 adaptive_diagnostics: Some(AdaptiveRegularizationDiagnostics {
2571 epsilon_0: eps_star[0],
2572 epsilon_g: eps_star[1],
2573 epsilon_c: eps_star[2],
2574 epsilon_outer_iterations: outer_iterations,
2575 mm_iterations: 0,
2576 converged: final_fit.outer_converged,
2577 maps,
2578 }),
2579 };
2580 enforce_term_constraint_feasibility(&fitted.design, &fitted.fit)?;
2581 Ok(fitted)
2582}
2583
2584fn relax_smoothing_rho_prior(
2616 options: &FitOptions,
2617 design: &TermCollectionDesign,
2618) -> gam_spec::RhoPrior {
2619 use gam_terms::basis::BasisMetadata;
2620 let base = &options.rho_prior;
2621 if matches!(
2624 base,
2625 gam_spec::RhoPrior::Flat | gam_spec::RhoPrior::Independent(_)
2626 ) {
2627 return base.clone();
2628 }
2629 let has_link_aux = options.sas_link.is_some()
2649 || options.optimize_sas
2650 || options.mixture_link.is_some()
2651 || options.optimize_mixture;
2652 let has_moving_kappa = design.smooth.terms.iter().any(|t| {
2653 matches!(
2654 t.metadata,
2655 BasisMetadata::Matern { .. }
2656 | BasisMetadata::Duchon { .. }
2657 | BasisMetadata::Sphere { .. }
2658 | BasisMetadata::SphereHarmonics { .. }
2659 | BasisMetadata::ConstantCurvature { .. }
2660 | BasisMetadata::MeasureJet { .. }
2661 )
2662 });
2663 let length_safe = !has_link_aux && !has_moving_kappa;
2670 if !length_safe {
2671 return base.clone();
2672 }
2673 let coords = &design.penaltyinfo;
2674 if coords.is_empty() {
2675 return base.clone();
2676 }
2677 let n_obs = design.design.nrows();
2688 let p_total = design.design.ncols();
2689 let underdetermined = n_obs < 2 * p_total;
2720 let relaxable_terms: std::collections::HashSet<&str> = design
2732 .smooth
2733 .terms
2734 .iter()
2735 .filter(|t| {
2736 matches!(
2737 t.metadata,
2738 BasisMetadata::BSpline1D { .. }
2739 | BasisMetadata::ThinPlate { .. }
2740 | BasisMetadata::TensorBSpline { .. }
2741 )
2742 && matches!(t.shape, gam_terms::smooth::ShapeConstraint::None)
2756 })
2757 .map(|t| t.name.as_str())
2758 .collect();
2759 let any_relaxed = coords.iter().any(|info| {
2760 info.termname
2761 .as_deref()
2762 .is_some_and(|name| relaxable_terms.contains(name))
2763 });
2764 if !any_relaxed {
2765 return base.clone();
2766 }
2767 let relaxed_prior = if underdetermined {
2772 gam_spec::RhoPrior::Normal {
2773 mean: 0.0,
2774 sd: RELAX_UNDERDETERMINED_RHO_SD,
2775 }
2776 } else {
2777 gam_spec::RhoPrior::Flat
2778 };
2779 let nullspace_select_prior = gam_spec::RhoPrior::PenalizedComplexity {
2806 upper: NULLSPACE_SELECT_PC_UPPER,
2807 tail_prob: NULLSPACE_SELECT_PC_TAIL_PROB,
2808 };
2809 let nullspace_degeneracy_prior = gam_spec::RhoPrior::Normal {
2836 mean: 0.0,
2837 sd: NULLSPACE_WELLDET_DEGENERACY_RHO_SD,
2838 };
2839 let per_coord = coords
2840 .iter()
2841 .map(|info| {
2842 let relax = info
2843 .termname
2844 .as_deref()
2845 .is_some_and(|name| relaxable_terms.contains(name));
2846 if !relax {
2847 return base.clone();
2848 }
2849 let is_nullspace =
2850 matches!(info.penalty.source, PenaltySource::DoublePenaltyNullspace);
2851 if is_nullspace {
2890 if underdetermined {
2891 nullspace_select_prior.clone()
2892 } else {
2893 nullspace_degeneracy_prior.clone()
2894 }
2895 } else {
2896 relaxed_prior.clone()
2897 }
2898 })
2899 .collect::<Vec<_>>();
2900 gam_spec::RhoPrior::Independent(per_coord)
2901}
2902
/// Standard deviation of the zero-mean `RhoPrior::Normal` that `relax_smoothing_rho_prior`
/// assigns to relaxed, non-nullspace penalty coordinates when the design has fewer than
/// `2 * ncols` rows.
const RELAX_UNDERDETERMINED_RHO_SD: f64 = 15.0;

/// `upper` field of the `RhoPrior::PenalizedComplexity` prior that
/// `relax_smoothing_rho_prior` assigns to relaxed double-penalty nullspace coordinates when
/// the design has fewer than `2 * ncols` rows.
const NULLSPACE_SELECT_PC_UPPER: f64 = 0.05;

/// `tail_prob` field of the same `RhoPrior::PenalizedComplexity` prior as
/// `NULLSPACE_SELECT_PC_UPPER`.
const NULLSPACE_SELECT_PC_TAIL_PROB: f64 = 0.01;
2946
2947fn adaptive_fit_options_base(options: &FitOptions, design: &TermCollectionDesign) -> FitOptions {
2948 FitOptions {
2949 latent_cloglog: options.latent_cloglog,
2950 mixture_link: options.mixture_link.clone(),
2951 optimize_mixture: options.optimize_mixture,
2952 sas_link: options.sas_link,
2953 optimize_sas: options.optimize_sas,
2954 compute_inference: options.compute_inference,
2955 skip_rho_posterior_inference: options.skip_rho_posterior_inference,
2956 max_iter: options.max_iter,
2957 tol: options.tol,
2958 nullspace_dims: design.nullspace_dims.clone(),
2959 linear_constraints: design.linear_constraints.clone(),
2960 firth_bias_reduction: options.firth_bias_reduction,
2961 adaptive_regularization: None,
2962 penalty_shrinkage_floor: options.penalty_shrinkage_floor,
2963 rho_prior: options.rho_prior.clone(),
2966 kronecker_penalty_system: design.kronecker_penalty_system(),
2967 kronecker_factored: design
2968 .smooth
2969 .terms
2970 .iter()
2971 .find_map(|t| t.kronecker_factored.clone()),
2972 persist_warm_start_disk: options.persist_warm_start_disk,
2973 }
2974}
2975
2976fn superseded_fit_options(options: &FitOptions) -> FitOptions {
2977 let mut fit_options = options.clone();
2978 fit_options.skip_rho_posterior_inference = true;
2979 fit_options
2980}
2981
/// Description of one bounded linear coefficient.
#[derive(Clone)]
struct BoundedLinearTermMeta {
    /// Index of the coefficient in the coefficient vector and of its column in the design.
    col_idx: usize,
    /// Lower bound handed to the latent-to-bounded transform.
    min: f64,
    /// Upper bound handed to the latent-to-bounded transform.
    max: f64,
    /// Prior specification attached to the bounded coefficient.
    prior: BoundedCoefficientPriorSpec,
}
2989
/// Design matrix plus the bounded-term metadata needed to form the effective Jacobian,
/// in which each bounded column is scaled by the derivative of its bounded transform.
struct BoundedEffectiveJacobian {
    /// Dense design matrix; its rows are sliced and its bounded columns rescaled.
    design: Array2<f64>,
    /// Bounded coefficients whose design columns are rescaled.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
3017
3018impl BlockEffectiveJacobian for BoundedEffectiveJacobian {
3019 fn effective_jacobian_rows(
3020 &self,
3021 state: &FamilyLinearizationState<'_>,
3022 rows: std::ops::Range<usize>,
3023 ) -> Result<Array2<f64>, String> {
3024 let p = self.design.ncols();
3025 let n = self.design.nrows();
3026 let rows = rows.start.min(n)..rows.end.min(n);
3027 if !state.beta.is_empty() {
3028 if state.beta.len() != p {
3029 return Err(format!(
3030 "BoundedEffectiveJacobian::effective_jacobian_at: beta length {} != design \
3031 ncols {p}",
3032 state.beta.len(),
3033 ));
3034 }
3035 if state.beta.iter().any(|v| v.is_nan()) {
3036 return Err(
3037 "BoundedEffectiveJacobian::effective_jacobian_at: beta contains NaN"
3038 .to_string(),
3039 );
3040 }
3041 }
3042 let mut jac = self
3043 .design
3044 .slice(ndarray::s![rows.start..rows.end, ..])
3045 .to_owned();
3046 for term in &self.bounded_terms {
3047 let theta = if state.beta.is_empty() {
3048 0.0
3049 } else {
3050 state.beta[term.col_idx]
3051 };
3052 let (_, _, db_dtheta, _, _) = bounded_latent_derivatives(theta, term.min, term.max);
3053 jac.column_mut(term.col_idx).mapv_inplace(|v| v * db_dtheta);
3054 }
3055 Ok(jac)
3056 }
3057}
3058
/// Data bundle for a fit that contains bounded linear coefficients: the likelihood
/// specification, optional link states, the observations and the bounded-term metadata.
#[derive(Clone)]
struct BoundedLinearFamily {
    /// Response family and link of the likelihood.
    family: LikelihoodSpec,
    /// Optional latent complementary-log-log link state.
    latent_cloglog_state: Option<LatentCLogLogState>,
    /// Optional mixture link state.
    mixture_link_state: Option<MixtureLinkState>,
    /// Optional SAS link state.
    sas_link_state: Option<SasLinkState>,
    /// Response vector.
    y: Array1<f64>,
    /// Observation weights.
    weights: Array1<f64>,
    /// Design matrix.
    design: Array2<f64>,
    // NOTE(review): presumably `design` with the bounded columns zeroed — confirm against
    // the code that constructs this struct.
    designzeroed: Array2<f64>,
    /// Offset vector.
    offset: Array1<f64>,
    /// Metadata for each bounded coefficient.
    bounded_terms: Vec<BoundedLinearTermMeta>,
}
3072
/// Per-observation likelihood quantities at a given linear predictor, as produced by
/// `evaluate_standard_familyobservations`. All arrays have one entry per observation.
#[derive(Clone)]
struct StandardFamilyObservationState {
    /// Linear predictor the state was evaluated at.
    eta: Array1<f64>,
    /// Mean obtained from the inverse link (clamped away from the boundary where needed).
    mu: Array1<f64>,
    /// Weighted derivative of the log-likelihood with respect to `eta`.
    score: Array1<f64>,
    /// Fisher (expected-information) weight, floored at a small positive value.
    fisherweight: Array1<f64>,
    /// Negative second derivative of the weighted log-likelihood with respect to `eta`.
    neghessian_eta: Array1<f64>,
    /// Derivative of `neghessian_eta` with respect to `eta`.
    neghessian_eta_derivative: Array1<f64>,
    /// Sum of the weighted per-observation log-likelihood contributions.
    log_likelihood: f64,
}
3083
/// Logit of `z` after clamping it into `[1e-12, 1 - 1e-12]`, so the result is always finite
/// for non-NaN input.
fn bounded_logit(z: f64) -> f64 {
    let clamped = z.clamp(1e-12, 1.0 - 1e-12);
    let odds = clamped / (1.0 - clamped);
    odds.ln()
}
3088
/// Logistic function `1 / (1 + exp(-theta))`, evaluated so that `exp` is only ever called
/// with a non-positive argument.
fn stable_sigmoid(theta: f64) -> f64 {
    if theta < 0.0 {
        let growth = theta.exp();
        return growth / (1.0 + growth);
    }
    let decay = (-theta).exp();
    1.0 / (1.0 + decay)
}
3098
3099fn bounded_latent_to_user(theta: f64, min: f64, max: f64) -> (f64, f64, f64) {
3100 let z = stable_sigmoid(theta);
3101 let width = max - min;
3102 let beta = min + width * z;
3103 let db_dtheta = width * z * (1.0 - z);
3104 (beta, z, db_dtheta)
3105}
3106
3107fn bounded_user_to_latent(beta: f64, min: f64, max: f64) -> f64 {
3118 let width = max - min;
3119 if width <= 0.0 || !width.is_finite() {
3120 return 0.0;
3121 }
3122 let z = (beta - min) / width;
3123 bounded_logit(z)
3124}
3125
/// Identifies one bounded coefficient for `sample_bounded_latent_posterior_internal`.
#[derive(Debug, Clone, Copy)]
pub struct BoundedSampleColumn {
    /// Index of the bounded coefficient within the coefficient vector.
    pub col_idx: usize,
    /// Lower bound of the coefficient.
    pub min: f64,
    /// Upper bound of the coefficient.
    pub max: f64,
}
3138
3139pub fn sample_bounded_latent_posterior_internal(
3177 beta_user: &Array1<f64>,
3178 user_hessian: &Array2<f64>,
3179 bounded_columns: &[BoundedSampleColumn],
3180 n_draws: usize,
3181 sqrt_cov_scale: f64,
3182 base_seed: u64,
3183) -> Result<Array2<f64>, EstimationError> {
3184 let p = beta_user.len();
3185 if user_hessian.nrows() != p || user_hessian.ncols() != p {
3186 crate::bail_invalid_estim!(
3187 "bounded posterior sampling dimension mismatch: mode has {p} entries, user Hessian is {}x{}",
3188 user_hessian.nrows(),
3189 user_hessian.ncols()
3190 );
3191 }
3192
3193 let mut theta_mode = beta_user.clone();
3195 let mut jac_diag = Array1::<f64>::ones(p);
3196 for bc in bounded_columns {
3197 if bc.col_idx >= p {
3198 crate::bail_invalid_estim!(
3199 "bounded posterior sampling: bounded column index {} out of range for {p} coefficients",
3200 bc.col_idx
3201 );
3202 }
3203 let theta_i = bounded_user_to_latent(beta_user[bc.col_idx], bc.min, bc.max);
3204 let (_, _, db_dtheta) = bounded_latent_to_user(theta_i, bc.min, bc.max);
3205 theta_mode[bc.col_idx] = theta_i;
3206 jac_diag[bc.col_idx] = db_dtheta.max(1e-12);
3211 }
3212
3213 let mut h_latent = user_hessian.clone();
3216 for i in 0..p {
3217 let ji = jac_diag[i];
3218 if ji != 1.0 {
3219 h_latent.row_mut(i).mapv_inplace(|v| v * ji);
3220 h_latent.column_mut(i).mapv_inplace(|v| v * ji);
3221 }
3222 }
3223
3224 use gam_linalg::faer_ndarray::FaerCholesky as _;
3227 use rand::SeedableRng as _;
3228 let chol = h_latent.cholesky(faer::Side::Lower).map_err(|err| {
3229 EstimationError::InvalidInput(format!(
3230 "bounded posterior sampling: Cholesky of the latent penalized Hessian failed: {err:?}"
3231 ))
3232 })?;
3233 let l = chol.lower_triangular();
3234
3235 let mut draws = Array2::<f64>::zeros((n_draws, p));
3236 let mut eps = Array1::<f64>::zeros(p);
3237 let mut delta = Array1::<f64>::zeros(p);
3238 let mut rng = rand::rngs::StdRng::seed_from_u64(base_seed);
3239 for k in 0..n_draws {
3240 for e in eps.iter_mut() {
3241 *e = standard_normal_draw(&mut rng);
3242 }
3243 solve_lower_transpose_into(&l, &eps, &mut delta);
3244 for i in 0..p {
3245 draws[(k, i)] = theta_mode[i] + sqrt_cov_scale * delta[i];
3248 }
3249 for bc in bounded_columns {
3252 let (beta_draw, _, _) = bounded_latent_to_user(draws[(k, bc.col_idx)], bc.min, bc.max);
3253 draws[(k, bc.col_idx)] = beta_draw;
3254 }
3255 }
3256
3257 Ok(draws)
3258}
3259
3260#[inline]
3263fn standard_normal_draw<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
3264 use rand::RngExt as _;
3265 let u1 = rng.random::<f64>().max(1e-16);
3266 let u2 = rng.random::<f64>();
3267 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
3268}
3269
3270fn solve_lower_transpose_into(l: &Array2<f64>, b: &Array1<f64>, out: &mut Array1<f64>) {
3274 let p = l.nrows();
3275 for i in (0..p).rev() {
3276 let mut acc = b[i];
3277 for j in (i + 1)..p {
3278 acc -= l[(j, i)] * out[j];
3279 }
3280 let diag = l[(i, i)];
3281 out[i] = if diag.abs() > 0.0 { acc / diag } else { 0.0 };
3282 }
3283}
3284
3285fn bounded_latent_derivatives(theta: f64, min: f64, max: f64) -> (f64, f64, f64, f64, f64) {
3286 let z = stable_sigmoid(theta);
3287 let width = max - min;
3288 let s = z * (1.0 - z);
3289 let beta = min + width * z;
3290 let db_dtheta = width * s;
3291 let d2b_dtheta2 = width * s * (1.0 - 2.0 * z);
3292 let d3b_dtheta3 = width * s * (1.0 - 6.0 * z + 6.0 * z * z);
3293 (beta, z, db_dtheta, d2b_dtheta2, d3b_dtheta3)
3294}
3295
3296fn bounded_prior_terms(theta: f64, prior: &BoundedCoefficientPriorSpec) -> (f64, f64, f64, f64) {
3297 let (a, b) = match prior {
3298 BoundedCoefficientPriorSpec::None => return (0.0, 0.0, 0.0, 0.0),
3300 BoundedCoefficientPriorSpec::Uniform => (1.0, 1.0),
3303 BoundedCoefficientPriorSpec::Beta { a, b } => (*a, *b),
3304 };
3305 let z = stable_sigmoid(theta).clamp(1e-12, 1.0 - 1e-12);
3306 let logp = a * z.ln() + b * (1.0 - z).ln();
3307 let grad = a - (a + b) * z;
3308 let neghess = (a + b) * z * (1.0 - z);
3309 let neghess_derivative = (a + b) * z * (1.0 - z) * (1.0 - 2.0 * z);
3310 (logp, grad, neghess, neghess_derivative)
3311}
3312
/// Converts derivatives of a per-observation log-likelihood in `mu` into derivatives in the
/// linear predictor `eta` via the chain rule.
///
/// Inputs: weight `w`; first, second and third derivatives of the log-likelihood in `mu`
/// (`lmu`, `lmumu`, `lmumumu`); the variance function value `var`; the first three
/// derivatives of `mu` in `eta` (`d1`, `d2`, `d3`); and the floor `mu_deriv_eps` applied to
/// the Fisher weight.
///
/// Returns `(score, fisherweight, neghessian, neghessian_deriv)`.
#[inline]
fn glm_eta_observation_state(
    w: f64,
    lmu: f64,
    lmumu: f64,
    lmumumu: f64,
    var: f64,
    d1: f64,
    d2: f64,
    d3: f64,
    mu_deriv_eps: f64,
) -> (f64, f64, f64, f64) {
    let neg_w = -w;
    let score = w * lmu * d1;
    let fisherweight = {
        let expected_information = w * d1 * d1 / var;
        expected_information.max(mu_deriv_eps)
    };
    // d²l/dη² and d³l/dη³ before the sign flip and weighting.
    let second_in_eta = lmumu * d1 * d1 + lmu * d2;
    let third_in_eta = lmumumu * d1 * d1 * d1 + 3.0 * lmumu * d1 * d2 + lmu * d3;
    (
        score,
        fisherweight,
        neg_w * second_in_eta,
        neg_w * third_in_eta,
    )
}
3339
3340fn evaluate_standard_familyobservations(
3341 family: LikelihoodSpec,
3342 latent_cloglog_state: Option<&LatentCLogLogState>,
3343 mixture_link_state: Option<&MixtureLinkState>,
3344 sas_link_state: Option<&SasLinkState>,
3345 y: &Array1<f64>,
3346 weights: &Array1<f64>,
3347 eta: &Array1<f64>,
3348) -> Result<StandardFamilyObservationState, EstimationError> {
3349 const PROB_EPS: f64 = 1e-10;
3350 const MU_DERIV_EPS: f64 = 1e-12;
3351 let n = y.len();
3352 if weights.len() != n || eta.len() != n {
3353 crate::bail_invalid_estim!("bounded family observation size mismatch");
3354 }
3355
3356 let mut mu = Array1::<f64>::zeros(n);
3357 let mut score = Array1::<f64>::zeros(n);
3358 let mut fisherweight = Array1::<f64>::zeros(n);
3359 let mut neghessian_eta = Array1::<f64>::zeros(n);
3360 let mut neghessian_eta_derivative = Array1::<f64>::zeros(n);
3361 let mut log_likelihood = 0.0;
3362
3363 for i in 0..n {
3364 let w = weights[i].max(0.0);
3365 let yi = y[i];
3366 let eta_i = eta[i];
3367 match (&family.response, &family.link) {
3368 (ResponseFamily::Gaussian, _) => {
3369 let resid = yi - eta_i;
3370 mu[i] = eta_i;
3371 score[i] = w * resid;
3372 fisherweight[i] = w.max(MU_DERIV_EPS);
3373 neghessian_eta[i] = w;
3374 neghessian_eta_derivative[i] = 0.0;
3375 log_likelihood += -0.5 * w * resid * resid;
3376 }
3377 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
3378 let jet = logit_inverse_link_jet5(eta_i);
3379 mu[i] = jet.mu;
3380 score[i] = w * (yi - jet.mu);
3381 fisherweight[i] = jet.d1.max(MU_DERIV_EPS);
3382 neghessian_eta[i] = jet.d1;
3383 neghessian_eta_derivative[i] = jet.d2;
3384 let logmu = -gam_linalg::utils::stable_softplus(-eta_i);
3385 let log_one_minusmu = -gam_linalg::utils::stable_softplus(eta_i);
3386 log_likelihood += w * (yi * logmu + (1.0 - yi) * log_one_minusmu);
3387 }
3388 (ResponseFamily::Binomial, _) => {
3389 let inverse_link = if let Some(state) = latent_cloglog_state {
3390 Some(InverseLink::LatentCLogLog(*state))
3391 } else if let Some(state) = mixture_link_state {
3392 Some(InverseLink::Mixture(state.clone()))
3393 } else {
3394 sas_link_state.map(|state| {
3395 if family.is_binomial_beta_logistic() {
3396 InverseLink::BetaLogistic(*state)
3397 } else {
3398 InverseLink::Sas(*state)
3399 }
3400 })
3401 };
3402 let strategy_spec = LikelihoodSpec {
3403 response: family.response.clone(),
3404 link: inverse_link.clone().unwrap_or_else(|| family.link.clone()),
3405 };
3406 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3407 let mu_i_raw = jet.mu;
3408 let dmu_deta_raw = jet.d1;
3409 let mu_i: f64 = mu_i_raw.clamp(PROB_EPS, 1.0 - PROB_EPS);
3410 let dmu_deta = dmu_deta_raw.max(MU_DERIV_EPS);
3411 let d2mu_deta2 = jet.d2;
3412 let d3mu_deta3 = jet.d3;
3413 let var = (mu_i * (1.0 - mu_i)).max(PROB_EPS);
3414 let lmu = (yi - mu_i) / var;
3415 let lmumu = -(yi / (mu_i * mu_i)) - ((1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i)));
3416 let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i)
3417 - 2.0 * (1.0 - yi) / ((1.0 - mu_i) * (1.0 - mu_i) * (1.0 - mu_i));
3418 mu[i] = mu_i;
3419 score[i] = w * lmu * dmu_deta;
3420 fisherweight[i] = (w * dmu_deta * dmu_deta / var).max(MU_DERIV_EPS);
3421 neghessian_eta[i] = -w * (lmumu * dmu_deta * dmu_deta + lmu * d2mu_deta2);
3422 neghessian_eta_derivative[i] = -w
3423 * (lmumumu * dmu_deta * dmu_deta * dmu_deta
3424 + 3.0 * lmumu * dmu_deta * d2mu_deta2
3425 + lmu * d3mu_deta3);
3426 log_likelihood += w * (yi * mu_i.ln() + (1.0 - yi) * (1.0 - mu_i).ln());
3427 }
3428 (ResponseFamily::Poisson, _) => {
3429 let strategy_spec = LikelihoodSpec {
3432 response: family.response.clone(),
3433 link: family.link.clone(),
3434 };
3435 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3436 let mu_i = jet.mu.max(PROB_EPS);
3437 let d1 = jet.d1.max(MU_DERIV_EPS);
3438 let var = mu_i;
3439 let lmu = yi / mu_i - 1.0;
3440 let lmumu = -yi / (mu_i * mu_i);
3441 let lmumumu = 2.0 * yi / (mu_i * mu_i * mu_i);
3442 let (s, f, nh, nhd) = glm_eta_observation_state(
3443 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3444 );
3445 mu[i] = mu_i;
3446 score[i] = s;
3447 fisherweight[i] = f;
3448 neghessian_eta[i] = nh;
3449 neghessian_eta_derivative[i] = nhd;
3450 log_likelihood += w * (yi * mu_i.ln() - mu_i);
3451 }
3452 (ResponseFamily::Tweedie { p }, _) => {
3453 let p = *p;
3458 let strategy_spec = LikelihoodSpec {
3459 response: family.response.clone(),
3460 link: family.link.clone(),
3461 };
3462 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3463 let mu_i = jet.mu.max(PROB_EPS);
3464 let d1 = jet.d1.max(MU_DERIV_EPS);
3465 let var = mu_i.powf(p);
3466 let resid = yi - mu_i;
3467 let lmu = resid / var;
3468 let lmumu = -mu_i.powf(-p) - p * resid * mu_i.powf(-p - 1.0);
3469 let lmumumu =
3470 2.0 * p * mu_i.powf(-p - 1.0) + p * (p + 1.0) * resid * mu_i.powf(-p - 2.0);
3471 let (s, f, nh, nhd) = glm_eta_observation_state(
3472 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3473 );
3474 mu[i] = mu_i;
3475 score[i] = s;
3476 fisherweight[i] = f;
3477 neghessian_eta[i] = nh;
3478 neghessian_eta_derivative[i] = nhd;
3479 log_likelihood += w
3481 * (yi * mu_i.powf(1.0 - p) / (1.0 - p) - mu_i.powf(2.0 - p) / (2.0 - p));
3482 }
3483 (ResponseFamily::NegativeBinomial { theta, .. }, _) => {
3484 let theta = (*theta).max(PROB_EPS);
3488 let strategy_spec = LikelihoodSpec {
3489 response: family.response.clone(),
3490 link: family.link.clone(),
3491 };
3492 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3493 let mu_i = jet.mu.max(PROB_EPS);
3494 let d1 = jet.d1.max(MU_DERIV_EPS);
3495 let mu_plus = mu_i + theta;
3496 let var = mu_i + mu_i * mu_i / theta;
3497 let lmu = yi / mu_i - (yi + theta) / mu_plus;
3498 let lmumu = -yi / (mu_i * mu_i) + (yi + theta) / (mu_plus * mu_plus);
3499 let lmumumu =
3500 2.0 * yi / (mu_i * mu_i * mu_i) - 2.0 * (yi + theta) / (mu_plus * mu_plus * mu_plus);
3501 let (s, f, nh, nhd) = glm_eta_observation_state(
3502 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3503 );
3504 mu[i] = mu_i;
3505 score[i] = s;
3506 fisherweight[i] = f;
3507 neghessian_eta[i] = nh;
3508 neghessian_eta_derivative[i] = nhd;
3509 log_likelihood += w * (yi * mu_i.ln() - (yi + theta) * mu_plus.ln());
3510 }
3511 (ResponseFamily::Beta { .. }, _) => {
3512 crate::bail_invalid_estim!(
3513 "bounded linear terms are not supported for BetaLogit fits"
3514 );
3515 }
3516 (ResponseFamily::Gamma, _) => {
3517 let strategy_spec = LikelihoodSpec {
3521 response: family.response.clone(),
3522 link: family.link.clone(),
3523 };
3524 let jet = strategy_for_spec(&strategy_spec).inverse_link_jet(eta_i)?;
3525 let mu_i = jet.mu.max(PROB_EPS);
3526 let d1 = jet.d1.max(MU_DERIV_EPS);
3527 let var = mu_i * mu_i;
3528 let lmu = yi / (mu_i * mu_i) - 1.0 / mu_i;
3529 let lmumu = -2.0 * yi / (mu_i * mu_i * mu_i) + 1.0 / (mu_i * mu_i);
3530 let lmumumu =
3531 6.0 * yi / (mu_i * mu_i * mu_i * mu_i) - 2.0 / (mu_i * mu_i * mu_i);
3532 let (s, f, nh, nhd) = glm_eta_observation_state(
3533 w, lmu, lmumu, lmumumu, var, d1, jet.d2, jet.d3, MU_DERIV_EPS,
3534 );
3535 mu[i] = mu_i;
3536 score[i] = s;
3537 fisherweight[i] = f;
3538 neghessian_eta[i] = nh;
3539 neghessian_eta_derivative[i] = nhd;
3540 log_likelihood += w * (-(yi / mu_i) - mu_i.ln());
3541 }
3542 (ResponseFamily::RoystonParmar, _) => {
3543 crate::bail_invalid_estim!(
3544 "bounded linear terms are not supported for survival model fits"
3545 );
3546 }
3547 }
3548 }
3549
3550 Ok(StandardFamilyObservationState {
3551 eta: eta.clone(),
3552 mu,
3553 score,
3554 fisherweight,
3555 neghessian_eta,
3556 neghessian_eta_derivative,
3557 log_likelihood,
3558 })
3559}
3560
/// Kind of a spatial-adaptive hyperparameter: a log-lambda or a log-epsilon, each for the
/// magnitude, gradient or curvature component.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveHyperKind {
    LogLambdaMagnitude,
    LogLambdaGradient,
    LogLambdaCurvature,
    LogEpsilonMagnitude,
    LogEpsilonGradient,
    LogEpsilonCurvature,
}

impl SpatialAdaptiveHyperKind {
    /// Component slot of this hyperparameter: 0 = magnitude, 1 = gradient, 2 = curvature.
    fn component_index(self) -> usize {
        match self {
            Self::LogLambdaMagnitude | Self::LogEpsilonMagnitude => 0,
            Self::LogLambdaGradient | Self::LogEpsilonGradient => 1,
            Self::LogLambdaCurvature | Self::LogEpsilonCurvature => 2,
        }
    }

    /// `true` for the three log-lambda kinds.
    fn is_log_lambda(self) -> bool {
        match self {
            Self::LogLambdaMagnitude | Self::LogLambdaGradient | Self::LogLambdaCurvature => true,
            Self::LogEpsilonMagnitude
            | Self::LogEpsilonGradient
            | Self::LogEpsilonCurvature => false,
        }
    }

    /// `true` for the three log-epsilon kinds.
    fn is_log_epsilon(self) -> bool {
        match self {
            Self::LogEpsilonMagnitude
            | Self::LogEpsilonGradient
            | Self::LogEpsilonCurvature => true,
            Self::LogLambdaMagnitude | Self::LogLambdaGradient | Self::LogLambdaCurvature => false,
        }
    }
}
3601
/// One spatial-adaptive hyperparameter: which cache (term) it belongs to and what kind of
/// parameter it is.
#[derive(Clone, Copy, Debug)]
struct SpatialAdaptiveHyperSpec {
    /// Index of the owning runtime cache / adaptive parameter block.
    cache_index: usize,
    /// Whether this is a log-lambda or log-epsilon, and for which component.
    kind: SpatialAdaptiveHyperKind,
}
3607
/// Classification of the explicit second-order interaction between two spatial-adaptive
/// hyperparameters, as returned by `SpatialAdaptiveHyperSpec::explicit_second_order_kind`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpatialAdaptiveExplicitSecondOrderKind {
    /// Different components, or two log-lambdas of different caches.
    StructuralZero,
    /// Two log-lambdas of the same component and the same cache.
    LocalAlphaAlpha,
    /// One log-lambda and one log-epsilon of the same component.
    LocalAlphaEta,
    /// Two log-epsilons of the same component.
    SharedEtaEta,
}
3615
3616#[derive(Clone, Copy, Debug, PartialEq, Eq)]
3621enum AdaptiveComponent {
3622 Magnitude,
3623 Gradient,
3624 Curvature,
3625}
3626
3627impl AdaptiveComponent {
3628 fn from_index(index: usize) -> Result<Self, String> {
3629 match index {
3630 0 => Ok(AdaptiveComponent::Magnitude),
3631 1 => Ok(AdaptiveComponent::Gradient),
3632 2 => Ok(AdaptiveComponent::Curvature),
3633 other => Err(SmoothError::invalid_index(format!(
3634 "invalid adaptive component index {}",
3635 other
3636 ))
3637 .into()),
3638 }
3639 }
3640}
3641
/// Selects which hyperparameter derivative `adaptive_block_eval` evaluates.
// NOTE(review): the per-variant meanings below are inferred from the variant names only;
// confirm against the body of `adaptive_block_eval`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HyperDerivativeKind {
    // Presumably the derivative with respect to rho (log-lambda).
    Rho,
    // Presumably the first derivative with respect to log-epsilon.
    LogEpsilonFirst,
    // Presumably the second derivative with respect to log-epsilon.
    LogEpsilonSecond,
}
3655
/// Distinguishes rho-type from log-epsilon-type hyperparameter drift.
// NOTE(review): no use of this enum is visible in this part of the file; confirm the exact
// semantics at its call sites.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HyperDriftKind {
    Rho,
    LogEpsilon,
}
3665
3666impl SpatialAdaptiveHyperSpec {
3667 fn component_index(self) -> usize {
3668 self.kind.component_index()
3669 }
3670
3671 fn explicit_second_order_kind(self, other: Self) -> SpatialAdaptiveExplicitSecondOrderKind {
3672 if self.component_index() != other.component_index() {
3673 return SpatialAdaptiveExplicitSecondOrderKind::StructuralZero;
3674 }
3675 match (
3676 self.kind.is_log_lambda(),
3677 other.kind.is_log_lambda(),
3678 self.kind.is_log_epsilon(),
3679 other.kind.is_log_epsilon(),
3680 ) {
3681 (true, true, false, false) if self.cache_index == other.cache_index => {
3682 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha
3683 }
3684 (true, false, false, true) | (false, true, true, false) => {
3685 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta
3686 }
3687 (false, false, true, true) => SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta,
3688 _ => SpatialAdaptiveExplicitSecondOrderKind::StructuralZero,
3689 }
3690 }
3691}
3692
/// Adaptive hyperparameters of one spatial term.
#[derive(Clone, Debug)]
struct SpatialAdaptiveTermHyperParams {
    /// Penalty strengths, indexed 0 = magnitude, 1 = gradient, 2 = curvature.
    lambda: [f64; 3],
    /// Epsilon values handed to `SpatialPenaltyExactState::from_beta_local`.
    // NOTE(review): presumably indexed by the same three components as `lambda` — confirm.
    epsilon: [f64; 3],
}
3698
/// Result of evaluating the spatial-adaptive objective at one coefficient vector:
/// observation-level likelihood state plus value, gradient and Hessian of the adaptive and
/// fixed quadratic penalties.
#[derive(Clone)]
struct SpatialAdaptiveExactEvaluation {
    /// Per-observation likelihood quantities.
    obs: StandardFamilyObservationState,
    /// Exact penalty state of each adaptive term, indexed by cache.
    adaptive_states: Vec<SpatialPenaltyExactState>,
    /// Value of the adaptive penalty.
    adaptive_penalty_value: f64,
    /// Gradient of the adaptive penalty with respect to the coefficients.
    adaptive_penaltygradient: Array1<f64>,
    /// Hessian of the adaptive penalty with respect to the coefficients.
    adaptive_penaltyhessian: Array2<f64>,
    /// Value of the fixed quadratic penalty.
    fixed_quadraticvalue: f64,
    /// Gradient of the fixed quadratic penalty.
    fixed_quadraticgradient: Array1<f64>,
    /// Hessian of the fixed quadratic penalty.
    fixed_quadratichessian: Array2<f64>,
}
3710
/// A shared exact evaluation paired with a coefficient vector.
// NOTE(review): presumably `beta` is the point at which `eval` was computed and serves as
// the cache key — confirm at the cache lookup site.
#[derive(Clone)]
struct CachedSpatialAdaptiveExactEvaluation {
    beta: Array1<f64>,
    eval: Arc<SpatialAdaptiveExactEvaluation>,
}
3716
3717impl SpatialAdaptiveExactEvaluation {
3718 fn total_penalty_value(&self) -> f64 {
3719 self.adaptive_penalty_value + self.fixed_quadraticvalue
3720 }
3721
3722 fn total_penaltygradient(&self) -> Array1<f64> {
3723 &self.adaptive_penaltygradient + &self.fixed_quadraticgradient
3724 }
3725
3726 fn total_penaltyhessian(&self) -> Array2<f64> {
3727 &self.adaptive_penaltyhessian + &self.fixed_quadratichessian
3728 }
3729
3730 fn totalobjectivehessian(&self, design: &Array2<f64>) -> Result<Array2<f64>, String> {
3731 let mut out = xt_diag_x_dense(design.view(), self.obs.neghessian_eta.view())?;
3732 out += &self.total_penaltyhessian();
3733 Ok(out)
3734 }
3735}
3736
/// Everything needed to evaluate the exact spatial-adaptive objective: likelihood
/// specification and link states, shared observation data, the adaptive runtime caches and
/// hyperparameters, the fixed quadratic penalty, and a one-slot evaluation cache.
#[derive(Clone)]
struct SpatialAdaptiveExactFamily {
    /// Response family and link.
    family: LikelihoodSpec,
    /// Optional latent complementary-log-log link state.
    latent_cloglog_state: Option<LatentCLogLogState>,
    /// Optional mixture link state.
    mixture_link_state: Option<MixtureLinkState>,
    /// Optional SAS link state.
    sas_link_state: Option<SasLinkState>,
    /// Response vector (shared).
    y: Arc<Array1<f64>>,
    /// Observation weights (shared).
    weights: Arc<Array1<f64>>,
    /// Design matrix (shared); its column count is the total coefficient dimension.
    design: Arc<Array2<f64>>,
    /// Offset added to the linear predictor (shared).
    offset: Arc<Array1<f64>>,
    /// Optional linear inequality constraints on the coefficients.
    linear_constraints: Option<LinearInequalityConstraints>,
    /// One runtime cache per adaptive spatial term (shared).
    runtime_caches: Arc<Vec<SpatialOperatorRuntimeCache>>,
    /// Hyperparameters per adaptive term, indexed like `runtime_caches`.
    adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
    /// Hessian of the fixed quadratic penalty (shared).
    fixed_quadratichessian: Arc<Array2<f64>>,
    /// Specifications of the adaptive hyperparameters (shared).
    hyperspecs: Arc<Vec<SpatialAdaptiveHyperSpec>>,
    /// Holds at most one cached exact evaluation; starts empty in instances created by
    /// `with_adaptive_params`.
    exact_eval_cache: Arc<Mutex<Option<CachedSpatialAdaptiveExactEvaluation>>>,
}
3754
3755impl SpatialAdaptiveExactFamily {
3756 fn with_adaptive_params(
3757 &self,
3758 adaptive_params: Vec<SpatialAdaptiveTermHyperParams>,
3759 fixed_quadratichessian: Arc<Array2<f64>>,
3760 ) -> Self {
3761 Self {
3762 family: self.family.clone(),
3763 latent_cloglog_state: self.latent_cloglog_state,
3764 mixture_link_state: self.mixture_link_state.clone(),
3765 sas_link_state: self.sas_link_state,
3766 y: self.y.clone(),
3767 weights: self.weights.clone(),
3768 design: self.design.clone(),
3769 offset: self.offset.clone(),
3770 linear_constraints: self.linear_constraints.clone(),
3771 runtime_caches: self.runtime_caches.clone(),
3772 adaptive_params,
3773 fixed_quadratichessian,
3774 hyperspecs: self.hyperspecs.clone(),
3775 exact_eval_cache: Arc::new(Mutex::new(None)),
3776 }
3777 }
3778
3779 fn total_eta(&self, beta: &Array1<f64>) -> Array1<f64> {
3780 gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), beta) + self.offset.as_ref()
3781 }
3782
3783 fn fixed_quadratic_terms(&self, beta: &Array1<f64>) -> (f64, Array1<f64>) {
3784 let grad = self.fixed_quadratichessian.dot(beta);
3785 let value = 0.5 * beta.dot(&grad);
3786 (value, grad)
3787 }
3788
3789 fn adaptive_penalty_value_only(&self, beta: &Array1<f64>) -> Result<f64, String> {
3790 let mut penalty_value = 0.0;
3791 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
3792 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
3793 format!(
3794 "missing adaptive parameter block for cache {}",
3795 cache.termname
3796 )
3797 })?;
3798 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
3799 let state =
3800 SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
3801 .map_err(|e| e.to_string())?;
3802 penalty_value += params.lambda[0] * state.magnitude.penalty_value();
3803 penalty_value += params.lambda[1] * state.gradient.penalty_value();
3804 penalty_value += params.lambda[2] * state.curvature.penalty_value();
3805 }
3806 Ok(penalty_value)
3807 }
3808
3809 fn zero_hyper_parts(&self) -> (Array1<f64>, Array2<f64>) {
3810 let total_dim = self.design.ncols();
3811 (
3812 Array1::<f64>::zeros(total_dim),
3813 Array2::<f64>::zeros((total_dim, total_dim)),
3814 )
3815 }
3816
3817 fn embed_local_hyper_parts(
3818 &self,
3819 coeff_range: &Range<usize>,
3820 local_grad: &Array1<f64>,
3821 local_hess: &Array2<f64>,
3822 ) -> (Array1<f64>, Array2<f64>) {
3823 let (mut beta_mixed, mut betahessian) = self.zero_hyper_parts();
3824 beta_mixed
3825 .slice_mut(s![coeff_range.clone()])
3826 .assign(local_grad);
3827 betahessian
3828 .slice_mut(s![coeff_range.clone(), coeff_range.clone()])
3829 .assign(local_hess);
3830 (beta_mixed, betahessian)
3831 }
3832
3833 fn embed_local_hyper_hessian(
3834 &self,
3835 coeff_range: &Range<usize>,
3836 local_hess: &Array2<f64>,
3837 ) -> Array2<f64> {
3838 let total_dim = self.design.ncols();
3839 let mut out = Array2::<f64>::zeros((total_dim, total_dim));
3840 out.slice_mut(s![coeff_range.clone(), coeff_range.clone()])
3841 .assign(local_hess);
3842 out
3843 }
3844
3845 fn adaptive_block_eval(
3854 &self,
3855 eval: &SpatialAdaptiveExactEvaluation,
3856 cache_idx: usize,
3857 component: AdaptiveComponent,
3858 derivative: HyperDerivativeKind,
3859 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3860 let cache = self
3861 .runtime_caches
3862 .get(cache_idx)
3863 .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
3864 let params = self
3865 .adaptive_params
3866 .get(cache_idx)
3867 .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
3868 let state = eval
3869 .adaptive_states
3870 .get(cache_idx)
3871 .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
3872
3873 let (objective_local, beta_mixed_local, betahessian_local) = match component {
3874 AdaptiveComponent::Magnitude => {
3875 let lambda = params.lambda[0];
3876 let mag = &state.magnitude;
3877 let (objective, gradient_coeff, hessian_diag) = match derivative {
3878 HyperDerivativeKind::Rho => (
3879 mag.penalty_value(),
3880 mag.betagradient_coeff(),
3881 mag.betahessian_diag(),
3882 ),
3883 HyperDerivativeKind::LogEpsilonFirst => (
3884 mag.log_epsilon_gradient_terms().sum(),
3885 mag.log_epsilon_betagradient_coeff(),
3886 mag.log_epsilon_betahessian_diag(),
3887 ),
3888 HyperDerivativeKind::LogEpsilonSecond => (
3889 mag.log_epsilon_hessian_terms().sum(),
3890 mag.log_epsilon_beta_mixed_second_coeff(),
3891 mag.log_epsilon_betahessian_second_diag(),
3892 ),
3893 };
3894 (
3895 lambda * objective,
3896 lambda * scalar_operatorgradient(&cache.d0, &gradient_coeff),
3897 lambda * scalar_operatorhessian(&cache.d0, &hessian_diag),
3898 )
3899 }
3900 AdaptiveComponent::Gradient => {
3901 let lambda = params.lambda[1];
3902 let grad = &state.gradient;
3903 let (objective, gradient_blocks, hessian_blocks) = match derivative {
3904 HyperDerivativeKind::Rho => (
3905 grad.penalty_value(),
3906 grad.betagradient_blocks(),
3907 grad.betahessian_blocks(),
3908 ),
3909 HyperDerivativeKind::LogEpsilonFirst => (
3910 grad.log_epsilon_gradient_terms().sum(),
3911 grad.log_epsilon_betagradient_blocks(),
3912 grad.log_epsilon_betahessian_blocks(),
3913 ),
3914 HyperDerivativeKind::LogEpsilonSecond => (
3915 grad.log_epsilon_hessian_terms().sum(),
3916 grad.log_epsilon_beta_mixed_second_blocks(),
3917 grad.log_epsilon_betahessian_second_blocks(),
3918 ),
3919 };
3920 (
3921 lambda * objective,
3922 lambda
3923 * grouped_operatorgradient(&cache.d1, cache.dimension, &gradient_blocks)
3924 .map_err(|e| e.to_string())?,
3925 lambda
3926 * grouped_operatorhessian(&cache.d1, cache.dimension, &hessian_blocks)
3927 .map_err(|e| e.to_string())?,
3928 )
3929 }
3930 AdaptiveComponent::Curvature => {
3931 let lambda = params.lambda[2];
3932 let group = cache.dimension * cache.dimension;
3933 let curv = &state.curvature;
3934 let (objective, gradient_blocks, hessian_blocks) = match derivative {
3935 HyperDerivativeKind::Rho => (
3936 curv.penalty_value(),
3937 curv.betagradient_blocks(),
3938 curv.betahessian_blocks(),
3939 ),
3940 HyperDerivativeKind::LogEpsilonFirst => (
3941 curv.log_epsilon_gradient_terms().sum(),
3942 curv.log_epsilon_betagradient_blocks(),
3943 curv.log_epsilon_betahessian_blocks(),
3944 ),
3945 HyperDerivativeKind::LogEpsilonSecond => (
3946 curv.log_epsilon_hessian_terms().sum(),
3947 curv.log_epsilon_beta_mixed_second_blocks(),
3948 curv.log_epsilon_betahessian_second_blocks(),
3949 ),
3950 };
3951 (
3952 lambda * objective,
3953 lambda
3954 * grouped_operatorgradient(&cache.d2, group, &gradient_blocks)
3955 .map_err(|e| e.to_string())?,
3956 lambda
3957 * grouped_operatorhessian(&cache.d2, group, &hessian_blocks)
3958 .map_err(|e| e.to_string())?,
3959 )
3960 }
3961 };
3962
3963 let (beta_mixed, betahessian) = self.embed_local_hyper_parts(
3964 &cache.coeff_global_range,
3965 &beta_mixed_local,
3966 &betahessian_local,
3967 );
3968 Ok((objective_local, beta_mixed, betahessian))
3969 }
3970
3971 fn adaptive_shared_log_epsilon_parts(
3972 &self,
3973 eval: &SpatialAdaptiveExactEvaluation,
3974 component: usize,
3975 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3976 self.adaptive_shared_block_eval(eval, component, HyperDerivativeKind::LogEpsilonFirst)
3982 }
3983
3984 fn adaptive_shared_log_epsilon_second_parts(
3985 &self,
3986 eval: &SpatialAdaptiveExactEvaluation,
3987 component: usize,
3988 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3989 self.adaptive_shared_block_eval(eval, component, HyperDerivativeKind::LogEpsilonSecond)
3995 }
3996
3997 fn adaptive_shared_block_eval(
4002 &self,
4003 eval: &SpatialAdaptiveExactEvaluation,
4004 component: usize,
4005 derivative: HyperDerivativeKind,
4006 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4007 let component = AdaptiveComponent::from_index(component)?;
4008 let (mut score, mut hessian) = self.zero_hyper_parts();
4009 let mut objective = 0.0;
4010 for cache_idx in 0..self.runtime_caches.len() {
4011 let (local_objective, local_score, local_hessian) =
4012 self.adaptive_block_eval(eval, cache_idx, component, derivative)?;
4013 objective += local_objective;
4014 score += &local_score;
4015 hessian += &local_hessian;
4016 }
4017 Ok((objective, score, hessian))
4018 }
4019
4020 fn adaptive_shared_log_epsilon_drift(
4021 &self,
4022 eval: &SpatialAdaptiveExactEvaluation,
4023 component: usize,
4024 direction: &Array1<f64>,
4025 ) -> Result<Array2<f64>, String> {
4026 let component = AdaptiveComponent::from_index(component)?;
4030 let total_dim = self.design.ncols();
4031 let mut total = Array2::<f64>::zeros((total_dim, total_dim));
4032 for cache_idx in 0..self.runtime_caches.len() {
4033 total += &self.adaptive_block_drift_eval(
4034 eval,
4035 cache_idx,
4036 component,
4037 HyperDriftKind::LogEpsilon,
4038 direction,
4039 )?;
4040 }
4041 Ok(total)
4042 }
4043
4044 fn adaptive_explicit_second_order_parts(
4045 &self,
4046 eval: &SpatialAdaptiveExactEvaluation,
4047 left: SpatialAdaptiveHyperSpec,
4048 right: SpatialAdaptiveHyperSpec,
4049 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4050 match left.explicit_second_order_kind(right) {
4059 SpatialAdaptiveExplicitSecondOrderKind::StructuralZero => {
4060 let (score, hessian) = self.zero_hyper_parts();
4061 Ok((0.0, score, hessian))
4062 }
4063 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaAlpha => self.adaptive_block_eval(
4064 eval,
4065 left.cache_index,
4066 AdaptiveComponent::from_index(left.component_index())?,
4067 HyperDerivativeKind::Rho,
4068 ),
4069 SpatialAdaptiveExplicitSecondOrderKind::LocalAlphaEta => {
4070 let local_alpha = if left.kind.is_log_lambda() {
4071 left
4072 } else {
4073 right
4074 };
4075 self.adaptive_block_eval(
4076 eval,
4077 local_alpha.cache_index,
4078 AdaptiveComponent::from_index(local_alpha.component_index())?,
4079 HyperDerivativeKind::LogEpsilonFirst,
4080 )
4081 }
4082 SpatialAdaptiveExplicitSecondOrderKind::SharedEtaEta => {
4083 self.adaptive_shared_log_epsilon_second_parts(eval, left.component_index())
4084 }
4085 }
4086 }
4087
4088 fn adaptive_block_drift_eval(
4096 &self,
4097 eval: &SpatialAdaptiveExactEvaluation,
4098 cache_idx: usize,
4099 component: AdaptiveComponent,
4100 drift: HyperDriftKind,
4101 direction: &Array1<f64>,
4102 ) -> Result<Array2<f64>, String> {
4103 let cache = self
4104 .runtime_caches
4105 .get(cache_idx)
4106 .ok_or_else(|| format!("adaptive cache index {} out of bounds", cache_idx))?;
4107 let params = self
4108 .adaptive_params
4109 .get(cache_idx)
4110 .ok_or_else(|| format!("adaptive hyperparameter block {} out of bounds", cache_idx))?;
4111 let state = eval
4112 .adaptive_states
4113 .get(cache_idx)
4114 .ok_or_else(|| format!("adaptive exact state index {} out of bounds", cache_idx))?;
4115 let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);
4116
4117 let local_hessian = match component {
4118 AdaptiveComponent::Magnitude => {
4119 let d0_u = cache.d0.dot(&direction_local);
4120 let mag = &state.magnitude;
4121 let diag = match drift {
4122 HyperDriftKind::Rho => mag.directionalhessian_diag(&d0_u),
4123 HyperDriftKind::LogEpsilon => {
4124 mag.log_epsilon_betahessian_directional_diag(&d0_u)
4125 }
4126 };
4127 params.lambda[0] * scalar_operatorhessian(&cache.d0, &diag)
4128 }
4129 AdaptiveComponent::Gradient => {
4130 let d1_u = cache.d1.dot(&direction_local);
4131 let direction_blocks = collocationgradient_blocks(&d1_u, cache.dimension)
4132 .map_err(|e| e.to_string())?;
4133 let grad = &state.gradient;
4134 let blocks = match drift {
4135 HyperDriftKind::Rho => grad.directionalhessian_blocks(&direction_blocks),
4136 HyperDriftKind::LogEpsilon => {
4137 grad.log_epsilon_betahessian_directional_blocks(&direction_blocks)
4138 }
4139 };
4140 params.lambda[1]
4141 * grouped_operatorhessian(&cache.d1, cache.dimension, &blocks)
4142 .map_err(|e| e.to_string())?
4143 }
4144 AdaptiveComponent::Curvature => {
4145 let group = cache.dimension * cache.dimension;
4146 let d2_u = cache.d2.dot(&direction_local);
4147 let direction_blocks =
4148 collocationhessian_blocks(&d2_u, cache.dimension).map_err(|e| e.to_string())?;
4149 let curv = &state.curvature;
4150 let blocks = match drift {
4151 HyperDriftKind::Rho => curv.directionalhessian_blocks(&direction_blocks),
4152 HyperDriftKind::LogEpsilon => {
4153 curv.log_epsilon_betahessian_directional_blocks(&direction_blocks)
4154 }
4155 };
4156 params.lambda[2]
4157 * grouped_operatorhessian(&cache.d2, group, &blocks)
4158 .map_err(|e| e.to_string())?
4159 }
4160 };
4161
4162 Ok(self.embed_local_hyper_hessian(&cache.coeff_global_range, &local_hessian))
4163 }
4164
4165 fn adaptive_hyper_parts(
4166 &self,
4167 eval: &SpatialAdaptiveExactEvaluation,
4168 hyper: SpatialAdaptiveHyperSpec,
4169 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
4170 match hyper.kind {
4171 SpatialAdaptiveHyperKind::LogLambdaMagnitude
4174 | SpatialAdaptiveHyperKind::LogLambdaGradient
4175 | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_eval(
4176 eval,
4177 hyper.cache_index,
4178 AdaptiveComponent::from_index(hyper.component_index())?,
4179 HyperDerivativeKind::Rho,
4180 ),
4181 SpatialAdaptiveHyperKind::LogEpsilonMagnitude
4183 | SpatialAdaptiveHyperKind::LogEpsilonGradient
4184 | SpatialAdaptiveHyperKind::LogEpsilonCurvature => {
4185 self.adaptive_shared_log_epsilon_parts(eval, hyper.component_index())
4186 }
4187 }
4188 }
4189
4190 fn exact_evaluation_uncached(
4191 &self,
4192 beta: &Array1<f64>,
4193 ) -> Result<SpatialAdaptiveExactEvaluation, String> {
4194 let eta = self.total_eta(beta);
4195 let obs = evaluate_standard_familyobservations(
4196 self.family.clone(),
4197 self.latent_cloglog_state.as_ref(),
4198 self.mixture_link_state.as_ref(),
4199 self.sas_link_state.as_ref(),
4200 &self.y,
4201 &self.weights,
4202 &eta,
4203 )
4204 .map_err(|e| e.to_string())?;
4205 let p = beta.len();
4206 let mut penalty_value = 0.0;
4207 let mut penaltygradient = Array1::<f64>::zeros(p);
4208 let mut penaltyhessian = Array2::<f64>::zeros((p, p));
4209 let mut adaptive_states = Vec::with_capacity(self.runtime_caches.len());
4210
4211 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4212 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4213 format!(
4214 "missing adaptive parameter block for cache {}",
4215 cache.termname
4216 )
4217 })?;
4218 let beta_local = beta.slice(s![cache.coeff_global_range.clone()]);
4219 let state =
4220 SpatialPenaltyExactState::from_beta_local(beta_local, cache, params.epsilon)
4221 .map_err(|e| e.to_string())?;
4222
4223 let g0 = scalar_operatorgradient(&cache.d0, &state.magnitude.betagradient_coeff());
4224 let gg = grouped_operatorgradient(
4225 &cache.d1,
4226 cache.dimension,
4227 &state.gradient.betagradient_blocks(),
4228 )
4229 .map_err(|e| e.to_string())?;
4230 let gc = grouped_operatorgradient(
4231 &cache.d2,
4232 cache.dimension * cache.dimension,
4233 &state.curvature.betagradient_blocks(),
4234 )
4235 .map_err(|e| e.to_string())?;
4236 let h0 = scalar_operatorhessian(&cache.d0, &state.magnitude.betahessian_diag());
4237 let hg = grouped_operatorhessian(
4238 &cache.d1,
4239 cache.dimension,
4240 &state.gradient.betahessian_blocks(),
4241 )
4242 .map_err(|e| e.to_string())?;
4243 let hc = grouped_operatorhessian(
4244 &cache.d2,
4245 cache.dimension * cache.dimension,
4246 &state.curvature.betahessian_blocks(),
4247 )
4248 .map_err(|e| e.to_string())?;
4249
4250 let lambda0 = params.lambda[0];
4251 let lambdag = params.lambda[1];
4252 let lambdac = params.lambda[2];
4253
4254 penalty_value += lambda0 * state.magnitude.penalty_value();
4255 penalty_value += lambdag * state.gradient.penalty_value();
4256 penalty_value += lambdac * state.curvature.penalty_value();
4257
4258 let range = cache.coeff_global_range.clone();
4259 {
4260 let mut grad_local = penaltygradient.slice_mut(s![range.clone()]);
4261 grad_local += &(g0.mapv(|v| lambda0 * v));
4262 grad_local += &(gg.mapv(|v| lambdag * v));
4263 grad_local += &(gc.mapv(|v| lambdac * v));
4264 }
4265 {
4266 let mut h_local = penaltyhessian.slice_mut(s![range.clone(), range]);
4267 h_local += &h0.mapv(|v| lambda0 * v);
4268 h_local += &hg.mapv(|v| lambdag * v);
4269 h_local += &hc.mapv(|v| lambdac * v);
4270 }
4271
4272 adaptive_states.push(state);
4273 }
4274
4275 let (fixed_quadraticvalue, fixed_quadraticgradient) = self.fixed_quadratic_terms(beta);
4276 Ok(SpatialAdaptiveExactEvaluation {
4277 obs,
4278 adaptive_states,
4279 adaptive_penalty_value: penalty_value,
4280 adaptive_penaltygradient: penaltygradient,
4281 adaptive_penaltyhessian: penaltyhessian,
4282 fixed_quadraticvalue,
4283 fixed_quadraticgradient,
4284 fixed_quadratichessian: self.fixed_quadratichessian.as_ref().clone(),
4285 })
4286 }
4287
4288 fn exact_evaluation(
4289 &self,
4290 beta: &Array1<f64>,
4291 ) -> Result<Arc<SpatialAdaptiveExactEvaluation>, String> {
4292 {
4293 let cache = self
4294 .exact_eval_cache
4295 .lock()
4296 .map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
4297 if let Some(cached) = cache.as_ref()
4298 && cached.beta.len() == beta.len()
4299 && cached
4300 .beta
4301 .iter()
4302 .zip(beta.iter())
4303 .all(|(&left, &right)| left == right)
4304 {
4305 return Ok(Arc::clone(&cached.eval));
4306 }
4307 }
4308
4309 let eval = Arc::new(self.exact_evaluation_uncached(beta)?);
4310 let mut cache = self
4311 .exact_eval_cache
4312 .lock()
4313 .map_err(|_| "spatial adaptive exact-evaluation cache lock poisoned".to_string())?;
4314 *cache = Some(CachedSpatialAdaptiveExactEvaluation {
4315 beta: beta.clone(),
4316 eval: Arc::clone(&eval),
4317 });
4318 Ok(eval)
4319 }
4320
4321 fn exacthessian_directional_derivative_from_evaluation(
4322 &self,
4323 beta: &Array1<f64>,
4324 eval: &SpatialAdaptiveExactEvaluation,
4325 direction: &Array1<f64>,
4326 ) -> Result<Array2<f64>, String> {
4327 assert_eq!(
4328 beta.len(),
4329 direction.len(),
4330 "beta/direction length mismatch",
4331 );
4332 let d_eta = gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), direction);
4333 let mut total = xt_diag_x_dense(
4334 self.design.view(),
4335 (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
4336 )?;
4337 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4338 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4339 format!(
4340 "missing adaptive parameter block for cache {}",
4341 cache.termname
4342 )
4343 })?;
4344 let state = eval
4345 .adaptive_states
4346 .get(cache_idx)
4347 .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
4348 let direction_local = direction.slice(s![cache.coeff_global_range.clone()]);
4349 let d0_u = cache.d0.dot(&direction_local);
4350 let d1_u = cache.d1.dot(&direction_local);
4351 let d2_u = cache.d2.dot(&direction_local);
4352 let h0 =
4353 scalar_operatorhessian(&cache.d0, &state.magnitude.directionalhessian_diag(&d0_u))
4354 .mapv(|v| params.lambda[0] * v);
4355 let hg = grouped_operatorhessian(
4356 &cache.d1,
4357 cache.dimension,
4358 &state.gradient.directionalhessian_blocks(
4359 &collocationgradient_blocks(&d1_u, cache.dimension)
4360 .map_err(|e| e.to_string())?,
4361 ),
4362 )
4363 .map_err(|e| e.to_string())?
4364 .mapv(|v| params.lambda[1] * v);
4365 let hc = grouped_operatorhessian(
4366 &cache.d2,
4367 cache.dimension * cache.dimension,
4368 &state.curvature.directionalhessian_blocks(
4369 &collocationhessian_blocks(&d2_u, cache.dimension)
4370 .map_err(|e| e.to_string())?,
4371 ),
4372 )
4373 .map_err(|e| e.to_string())?
4374 .mapv(|v| params.lambda[2] * v);
4375 let range = cache.coeff_global_range.clone();
4376 let mut local = total.slice_mut(s![range.clone(), range]);
4377 local += &h0;
4378 local += &hg;
4379 local += &hc;
4380 }
4381 Ok(total)
4382 }
4383
4384 fn exacthessian_second_directional_derivative_from_evaluation(
4405 &self,
4406 eval: &SpatialAdaptiveExactEvaluation,
4407 direction_u: &Array1<f64>,
4408 direction_v: &Array1<f64>,
4409 ) -> Result<Option<Array2<f64>>, String> {
4410 let p = self.design.ncols();
4411 if eval.obs.neghessian_eta_derivative.iter().any(|&w| w != 0.0) {
4413 return Ok(None);
4414 }
4415 let mut total = Array2::<f64>::zeros((p, p));
4416 for (cache_idx, cache) in self.runtime_caches.iter().enumerate() {
4417 let params = self.adaptive_params.get(cache_idx).ok_or_else(|| {
4418 format!(
4419 "missing adaptive parameter block for cache {}",
4420 cache.termname
4421 )
4422 })?;
4423 let state = eval
4424 .adaptive_states
4425 .get(cache_idx)
4426 .ok_or_else(|| format!("missing adaptive state for cache {}", cache.termname))?;
4427 let u_local = direction_u.slice(s![cache.coeff_global_range.clone()]);
4428 let v_local = direction_v.slice(s![cache.coeff_global_range.clone()]);
4429
4430 let q0_u = cache.d0.dot(&u_local);
4432 let q0_v = cache.d0.dot(&v_local);
4433 let h0 = scalar_operatorhessian(
4434 &cache.d0,
4435 &state.magnitude.second_directionalhessian_diag(&q0_u, &q0_v),
4436 )
4437 .mapv(|x| params.lambda[0] * x);
4438
4439 let a1 = collocationgradient_blocks(&cache.d1.dot(&u_local), cache.dimension)
4441 .map_err(|e| e.to_string())?;
4442 let b1 = collocationgradient_blocks(&cache.d1.dot(&v_local), cache.dimension)
4443 .map_err(|e| e.to_string())?;
4444 let hg = grouped_operatorhessian(
4445 &cache.d1,
4446 cache.dimension,
4447 &state.gradient.second_directionalhessian_blocks(&a1, &b1),
4448 )
4449 .map_err(|e| e.to_string())?
4450 .mapv(|x| params.lambda[1] * x);
4451
4452 let a2 = collocationhessian_blocks(&cache.d2.dot(&u_local), cache.dimension)
4454 .map_err(|e| e.to_string())?;
4455 let b2 = collocationhessian_blocks(&cache.d2.dot(&v_local), cache.dimension)
4456 .map_err(|e| e.to_string())?;
4457 let hc = grouped_operatorhessian(
4458 &cache.d2,
4459 cache.dimension * cache.dimension,
4460 &state.curvature.second_directionalhessian_blocks(&a2, &b2),
4461 )
4462 .map_err(|e| e.to_string())?
4463 .mapv(|x| params.lambda[2] * x);
4464
4465 let range = cache.coeff_global_range.clone();
4466 let mut local = total.slice_mut(s![range.clone(), range]);
4467 local += &h0;
4468 local += &hg;
4469 local += &hc;
4470 }
4471 Ok(Some(total))
4472 }
4473}
4474
4475impl CustomFamily for SpatialAdaptiveExactFamily {
4476 fn joint_jeffreys_term_required(&self) -> bool {
4480 true
4481 }
4482
4483 fn joint_jeffreys_information_with_specs(
4520 &self,
4521 block_states: &[ParameterBlockState],
4522 specs: &[ParameterBlockSpec],
4523 ) -> Result<Option<Array2<f64>>, String> {
4524 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4525 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4526 if spec.design.ncols() != beta.len() {
4527 return Err(SmoothError::dimension_mismatch(format!(
4528 "spatial adaptive Jeffreys information: spec design has {} columns, beta has {}",
4529 spec.design.ncols(),
4530 beta.len()
4531 ))
4532 .into());
4533 }
4534 let eval = self.exact_evaluation(beta)?;
4535 Ok(Some(xt_diag_x_dense(
4536 self.design.view(),
4537 eval.obs.neghessian_eta.view(),
4538 )?))
4539 }
4540
4541 fn joint_jeffreys_information_directional_derivative_with_specs(
4542 &self,
4543 block_states: &[ParameterBlockState],
4544 specs: &[ParameterBlockSpec],
4545 d_beta_flat: &Array1<f64>,
4546 ) -> Result<Option<Array2<f64>>, String> {
4547 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4553 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4554 if spec.design.ncols() != d_beta_flat.len() {
4555 return Err(SmoothError::dimension_mismatch(format!(
4556 "spatial adaptive Jeffreys directional derivative: spec design has {} columns, direction has {}",
4557 spec.design.ncols(),
4558 d_beta_flat.len()
4559 ))
4560 .into());
4561 }
4562 let eval = self.exact_evaluation(beta)?;
4563 let d_eta = gam_linalg::faer_ndarray::fast_av(self.design.as_ref(), d_beta_flat);
4564 Ok(Some(xt_diag_x_dense(
4565 self.design.view(),
4566 (&eval.obs.neghessian_eta_derivative * &d_eta).view(),
4567 )?))
4568 }
4569
4570 fn joint_jeffreys_information_second_directional_derivative_with_specs(
4571 &self,
4572 block_states: &[ParameterBlockState],
4573 specs: &[ParameterBlockSpec],
4574 d_beta_u_flat: &Array1<f64>,
4575 d_betav_flat: &Array1<f64>,
4576 ) -> Result<Option<Array2<f64>>, String> {
4577 let spec = expect_single_blockspec(specs, "spatial adaptive exact family")?;
4584 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4585 if spec.design.ncols() != beta.len()
4586 || d_beta_u_flat.len() != beta.len()
4587 || d_betav_flat.len() != beta.len()
4588 {
4589 return Err(SmoothError::dimension_mismatch(format!(
4590 "spatial adaptive Jeffreys second-direction length mismatch: spec cols={}, dirs=({}, {}), expected {}",
4591 spec.design.ncols(),
4592 d_beta_u_flat.len(),
4593 d_betav_flat.len(),
4594 beta.len()
4595 ))
4596 .into());
4597 }
4598 let eval = self.exact_evaluation(beta)?;
4599 if eval.obs.neghessian_eta_derivative.iter().any(|&w| w != 0.0) {
4600 return Ok(None);
4601 }
4602 Ok(Some(Array2::<f64>::zeros((beta.len(), beta.len()))))
4603 }
4604
4605 fn joint_jeffreys_information_matches_observed_hessian(&self) -> bool {
4606 false
4611 }
4612
4613 fn joint_jeffreys_information_depends_on_psi(&self) -> bool {
4614 false
4623 }
4624
4625 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4626 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4627 let eval = self.exact_evaluation(beta)?;
4628 let mut gradient = fast_atv(&self.design, &eval.obs.score);
4629 gradient -= &eval.total_penaltygradient();
4630 let mut hessian = xt_diag_x_dense(self.design.view(), eval.obs.neghessian_eta.view())?;
4631 hessian += &eval.total_penaltyhessian();
4632 Ok(FamilyEvaluation {
4633 log_likelihood: eval.obs.log_likelihood - eval.total_penalty_value(),
4634 blockworking_sets: vec![BlockWorkingSet::ExactNewton {
4635 gradient,
4636 hessian: SymmetricMatrix::Dense(hessian),
4637 }],
4638 })
4639 }
4640
4641 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4642 let state = expect_single_block_state(block_states, "spatial adaptive exact family")?;
4643 let beta = &state.beta;
4644 let obs = evaluate_standard_familyobservations(
4645 self.family.clone(),
4646 self.latent_cloglog_state.as_ref(),
4647 self.mixture_link_state.as_ref(),
4648 self.sas_link_state.as_ref(),
4649 &self.y,
4650 &self.weights,
4651 &state.eta,
4652 )
4653 .map_err(|e| e.to_string())?;
4654 let adaptive_penalty = self.adaptive_penalty_value_only(beta)?;
4655 let (fixed_quadratic, _) = self.fixed_quadratic_terms(beta);
4656 Ok(obs.log_likelihood - adaptive_penalty - fixed_quadratic)
4657 }
4658
4659 fn exact_newton_outerobjective(&self) -> ExactNewtonOuterObjective {
4660 ExactNewtonOuterObjective::StrictPseudoLaplace
4661 }
4662
4663 fn exact_newton_joint_hessian(
4664 &self,
4665 block_states: &[ParameterBlockState],
4666 ) -> Result<Option<Array2<f64>>, String> {
4667 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4668 let eval = self.exact_evaluation(beta)?;
4669 Ok(Some(eval.totalobjectivehessian(&self.design)?))
4670 }
4671
4672 fn exact_newton_hessian_directional_derivative(
4673 &self,
4674 block_states: &[ParameterBlockState],
4675 block_idx: usize,
4676 d_beta: &Array1<f64>,
4677 ) -> Result<Option<Array2<f64>>, String> {
4678 expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
4679 self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
4680 }
4681
4682 fn exact_newton_joint_hessian_directional_derivative(
4683 &self,
4684 block_states: &[ParameterBlockState],
4685 d_beta_flat: &Array1<f64>,
4686 ) -> Result<Option<Array2<f64>>, String> {
4687 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4688 if d_beta_flat.len() != beta.len() {
4689 return Err(SmoothError::dimension_mismatch(format!(
4690 "spatial adaptive exact family direction length mismatch: got {}, expected {}",
4691 d_beta_flat.len(),
4692 beta.len()
4693 ))
4694 .into());
4695 }
4696 let eval = self.exact_evaluation(beta)?;
4697 Ok(Some(
4698 self.exacthessian_directional_derivative_from_evaluation(beta, &eval, d_beta_flat)?,
4699 ))
4700 }
4701
4702 fn exact_newton_joint_hessiansecond_directional_derivative(
4703 &self,
4704 block_states: &[ParameterBlockState],
4705 d_beta_u_flat: &Array1<f64>,
4706 d_betav_flat: &Array1<f64>,
4707 ) -> Result<Option<Array2<f64>>, String> {
4708 let beta = &expect_single_block_state(block_states, "spatial adaptive exact family")?.beta;
4709 if d_beta_u_flat.len() != beta.len() || d_betav_flat.len() != beta.len() {
4710 return Err(SmoothError::dimension_mismatch(format!(
4711 "spatial adaptive exact family second-direction length mismatch: got ({}, {}), expected {}",
4712 d_beta_u_flat.len(),
4713 d_betav_flat.len(),
4714 beta.len()
4715 ))
4716 .into());
4717 }
4718 let eval = self.exact_evaluation(beta)?;
4719 self.exacthessian_second_directional_derivative_from_evaluation(
4720 &eval,
4721 d_beta_u_flat,
4722 d_betav_flat,
4723 )
4724 }
4725
4726 fn block_linear_constraints(
4727 &self,
4728 block_states: &[ParameterBlockState],
4729 block_idx: usize,
4730 block_spec: &ParameterBlockSpec,
4731 ) -> Result<Option<LinearInequalityConstraints>, String> {
4732 assert!(!block_states.is_empty(), "block_states must be non-empty");
4733 assert!(
4734 !block_spec.name.is_empty(),
4735 "block spec name must be non-empty",
4736 );
4737 expect_block_idx_zero(block_idx, "spatial adaptive exact family", "")?;
4738 Ok(self.linear_constraints.clone())
4739 }
4740
4741 fn exact_newton_joint_psi_terms(
4742 &self,
4743 block_states: &[ParameterBlockState],
4744 specs: &[ParameterBlockSpec],
4745 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4746 psi_index: usize,
4747 ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
4748 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4749 return Err(SmoothError::dimension_mismatch(format!(
4750 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4751 block_states.len(),
4752 specs.len(),
4753 derivative_blocks.len()
4754 ))
4755 .into());
4756 }
4757 derivative_blocks[0]
4758 .get(psi_index)
4759 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4760 let hyper = self
4761 .hyperspecs
4762 .get(psi_index)
4763 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4764 let beta = &block_states[0].beta;
4765 let eval = self.exact_evaluation(beta)?;
4766 let (direct, beta_mixed, betahessian_explicit) =
4767 self.adaptive_hyper_parts(&eval, *hyper)?;
4768
4769 Ok(Some(ExactNewtonJointPsiTerms {
4790 objective_psi: direct,
4791 score_psi: beta_mixed,
4792 hessian_psi: betahessian_explicit,
4793 hessian_psi_operator: None,
4794 }))
4795 }
4796
4797 fn exact_newton_joint_psisecond_order_terms(
4798 &self,
4799 block_states: &[ParameterBlockState],
4800 specs: &[ParameterBlockSpec],
4801 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4802 psi_i: usize,
4803 psi_j: usize,
4804 ) -> Result<Option<gam_problem::ExactNewtonJointPsiSecondOrderTerms>, String> {
4805 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4806 return Err(SmoothError::dimension_mismatch(format!(
4807 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4808 block_states.len(),
4809 specs.len(),
4810 derivative_blocks.len()
4811 ))
4812 .into());
4813 }
4814 derivative_blocks[0]
4815 .get(psi_i)
4816 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
4817 derivative_blocks[0]
4818 .get(psi_j)
4819 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
4820 let hyper_i = self
4821 .hyperspecs
4822 .get(psi_i)
4823 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_i))?;
4824 let hyper_j = self
4825 .hyperspecs
4826 .get(psi_j)
4827 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_j))?;
4828 let beta = &block_states[0].beta;
4829 let eval = self.exact_evaluation(beta)?;
4830 let (objective_psi_psi, score_psi_psi, hessian_psi_psi) =
4831 self.adaptive_explicit_second_order_parts(&eval, *hyper_i, *hyper_j)?;
4832
4833 Ok(Some(
4834 gam_problem::ExactNewtonJointPsiSecondOrderTerms {
4835 objective_psi_psi,
4836 score_psi_psi,
4837 hessian_psi_psi,
4838 hessian_psi_psi_operator: None,
4839 },
4840 ))
4841 }
4842
4843 fn exact_newton_joint_psihessian_directional_derivative(
4844 &self,
4845 block_states: &[ParameterBlockState],
4846 specs: &[ParameterBlockSpec],
4847 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
4848 psi_index: usize,
4849 direction: &Array1<f64>,
4850 ) -> Result<Option<Array2<f64>>, String> {
4851 if block_states.len() != 1 || specs.len() != 1 || derivative_blocks.len() != 1 {
4852 return Err(SmoothError::dimension_mismatch(format!(
4853 "spatial adaptive exact family expects one block/state/spec/psi payload, got states={} specs={} deriv_blocks={}",
4854 block_states.len(),
4855 specs.len(),
4856 derivative_blocks.len()
4857 ))
4858 .into());
4859 }
4860 let beta = &block_states[0].beta;
4861 if direction.len() != beta.len() {
4862 return Err(SmoothError::dimension_mismatch(format!(
4863 "spatial adaptive exact family direction length mismatch: got {}, expected {}",
4864 direction.len(),
4865 beta.len()
4866 ))
4867 .into());
4868 }
4869 derivative_blocks[0]
4870 .get(psi_index)
4871 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4872 let hyper = self
4873 .hyperspecs
4874 .get(psi_index)
4875 .ok_or_else(|| format!("adaptive psi index {} out of bounds", psi_index))?;
4876 let eval = self.exact_evaluation(beta)?;
4877 let drift = match hyper.kind {
4878 SpatialAdaptiveHyperKind::LogLambdaMagnitude
4879 | SpatialAdaptiveHyperKind::LogLambdaGradient
4880 | SpatialAdaptiveHyperKind::LogLambdaCurvature => self.adaptive_block_drift_eval(
4881 &eval,
4882 hyper.cache_index,
4883 AdaptiveComponent::from_index(hyper.kind.component_index())?,
4884 HyperDriftKind::Rho,
4885 direction,
4886 )?,
4887 SpatialAdaptiveHyperKind::LogEpsilonMagnitude
4888 | SpatialAdaptiveHyperKind::LogEpsilonGradient
4889 | SpatialAdaptiveHyperKind::LogEpsilonCurvature => self
4890 .adaptive_shared_log_epsilon_drift(
4891 &eval,
4892 hyper.kind.component_index(),
4893 direction,
4894 )?,
4895 };
4896 Ok(Some(drift))
4897 }
4898}
4899
4900fn expect_single_block_state<'a>(
4901 block_states: &'a [ParameterBlockState],
4902 family_name: &str,
4903) -> Result<&'a ParameterBlockState, String> {
4904 crate::block_layout::block_count::validate_block_count::<SmoothError>(
4905 family_name,
4906 1,
4907 block_states.len(),
4908 )?;
4909 Ok(&block_states[0])
4910}
4911
4912fn expect_single_blockspec<'a>(
4913 specs: &'a [ParameterBlockSpec],
4914 family_name: &str,
4915) -> Result<&'a ParameterBlockSpec, String> {
4916 crate::block_layout::block_count::validate_block_count::<SmoothError>(
4917 family_name,
4918 1,
4919 specs.len(),
4920 )?;
4921 Ok(&specs[0])
4922}
4923
4924fn expect_block_idx_zero(block_idx: usize, family_name: &str, context: &str) -> Result<(), String> {
4925 if block_idx != 0 {
4926 return Err(SmoothError::invalid_index(format!(
4927 "{family_name} expects block_idx 0{context}, got {block_idx}"
4928 ))
4929 .into());
4930 }
4931 Ok::<(), _>(())
4932}
4933
4934impl BoundedLinearFamily {
4935 fn bounded_term_derivative_data(
4936 &self,
4937 latent_beta: &Array1<f64>,
4938 ) -> (
4939 Array1<f64>,
4940 Array1<f64>,
4941 Array1<f64>,
4942 Array1<f64>,
4943 Array1<f64>,
4944 ) {
4945 let p = latent_beta.len();
4946 let mut beta_user = latent_beta.clone();
4947 let mut jac_diag = Array1::<f64>::ones(p);
4948 let mut second_diag = Array1::<f64>::zeros(p);
4949 let mut third_diag = Array1::<f64>::zeros(p);
4950 let mut priorthird = Array1::<f64>::zeros(p);
4951 for term in &self.bounded_terms {
4952 let (beta, _, db_dtheta, d2b_dtheta2, d3b_dtheta3) =
4953 bounded_latent_derivatives(latent_beta[term.col_idx], term.min, term.max);
4954 beta_user[term.col_idx] = beta;
4955 jac_diag[term.col_idx] = db_dtheta;
4956 second_diag[term.col_idx] = d2b_dtheta2;
4957 third_diag[term.col_idx] = d3b_dtheta3;
4958 let (_, _, _, prior_neghess_derivative) =
4959 bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
4960 priorthird[term.col_idx] = prior_neghess_derivative;
4961 }
4962 (beta_user, jac_diag, second_diag, third_diag, priorthird)
4963 }
4964
4965 fn user_beta_and_jacobian(&self, latent_beta: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
4966 let (beta_user, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
4967 (beta_user, jac_diag)
4968 }
4969
4970 fn nonlinear_offset_from_latent(&self, latent_beta: &Array1<f64>) -> Array1<f64> {
4971 let mut offset = self.offset.clone();
4972 for term in &self.bounded_terms {
4973 let (beta, _, _) =
4974 bounded_latent_to_user(latent_beta[term.col_idx], term.min, term.max);
4975 offset.scaled_add(beta, &self.design.column(term.col_idx));
4976 }
4977 offset
4978 }
4979
4980 fn effective_design_for_latent(&self, jac_diag: &Array1<f64>) -> Array2<f64> {
4981 let mut x_eff = self.design.clone();
4982 for term in &self.bounded_terms {
4983 x_eff
4984 .column_mut(term.col_idx)
4985 .mapv_inplace(|v| v * jac_diag[term.col_idx]);
4986 }
4987 x_eff
4988 }
4989
4990 fn exacthessian_andgradient(
4991 &self,
4992 latent_beta: &Array1<f64>,
4993 ) -> Result<
4994 (
4995 StandardFamilyObservationState,
4996 Array2<f64>,
4997 Array1<f64>,
4998 f64,
4999 Array1<f64>,
5000 Array1<f64>,
5001 Array1<f64>,
5002 ),
5003 String,
5004 > {
5005 let (_, jac_diag, second_diag, third_diag, priorthird) =
5006 self.bounded_term_derivative_data(latent_beta);
5007 let x_eff = self.effective_design_for_latent(&jac_diag);
5008 let eta =
5009 self.designzeroed.dot(latent_beta) + self.nonlinear_offset_from_latent(latent_beta);
5010 let obs = evaluate_standard_familyobservations(
5011 self.family.clone(),
5012 self.latent_cloglog_state.as_ref(),
5013 self.mixture_link_state.as_ref(),
5014 self.sas_link_state.as_ref(),
5015 &self.y,
5016 &self.weights,
5017 &eta,
5018 )
5019 .map_err(|e| e.to_string())?;
5020
5021 let mut priorgrad = Array1::<f64>::zeros(latent_beta.len());
5022 let mut prior_neghess = Array2::<f64>::zeros((latent_beta.len(), latent_beta.len()));
5023 let mut prior_loglik = 0.0;
5024 for term in &self.bounded_terms {
5025 let (logp, grad, neghess, _) =
5026 bounded_prior_terms(latent_beta[term.col_idx], &term.prior);
5027 prior_loglik += logp;
5028 priorgrad[term.col_idx] += grad;
5029 prior_neghess[[term.col_idx, term.col_idx]] += neghess;
5030 }
5031
5032 let mut hessian = xt_diag_x_dense(x_eff.view(), obs.neghessian_eta.view())?;
5033 let mut gradient = fast_atv(&x_eff, &obs.score);
5034 for term in &self.bounded_terms {
5035 let score_beta = self.design.column(term.col_idx).dot(&obs.score);
5036 hessian[[term.col_idx, term.col_idx]] -= score_beta * second_diag[term.col_idx];
5037 }
5038 hessian += &prior_neghess;
5039 gradient += &priorgrad;
5040
5041 Ok((
5042 obs,
5043 hessian,
5044 gradient,
5045 prior_loglik,
5046 second_diag,
5047 third_diag,
5048 priorthird,
5049 ))
5050 }
5051
5052 fn evaluation_from_latent(
5053 &self,
5054 latent_beta: &Array1<f64>,
5055 ) -> Result<
5056 (
5057 StandardFamilyObservationState,
5058 Array2<f64>,
5059 Array1<f64>,
5060 f64,
5061 ),
5062 String,
5063 > {
5064 let (obs, hessian, gradient, prior_loglik, _, _, _) =
5065 self.exacthessian_andgradient(latent_beta)?;
5066 Ok((obs, hessian, gradient, prior_loglik))
5067 }
5068}
5069
5070impl CustomFamily for BoundedLinearFamily {
5071 fn joint_jeffreys_term_required(&self) -> bool {
5075 true
5076 }
5077
5078 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
5079 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5080 let (obs, hessian, gradient, prior_loglik) = self.evaluation_from_latent(latent_beta)?;
5081 Ok(FamilyEvaluation {
5082 log_likelihood: obs.log_likelihood + prior_loglik,
5083 blockworking_sets: vec![BlockWorkingSet::ExactNewton {
5084 gradient,
5085 hessian: SymmetricMatrix::Dense(hessian),
5086 }],
5087 })
5088 }
5089
5090 fn exact_newton_joint_hessian(
5091 &self,
5092 block_states: &[ParameterBlockState],
5093 ) -> Result<Option<Array2<f64>>, String> {
5094 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5095 let (_, hessian, _, _) = self.evaluation_from_latent(latent_beta)?;
5096 Ok(Some(hessian))
5097 }
5098
5099 fn exact_newton_hessian_directional_derivative(
5100 &self,
5101 block_states: &[ParameterBlockState],
5102 block_idx: usize,
5103 d_beta: &Array1<f64>,
5104 ) -> Result<Option<Array2<f64>>, String> {
5105 expect_block_idx_zero(block_idx, "bounded linear family", "")?;
5106 self.exact_newton_joint_hessian_directional_derivative(block_states, d_beta)
5107 }
5108
5109 fn exact_newton_joint_hessian_directional_derivative(
5110 &self,
5111 block_states: &[ParameterBlockState],
5112 d_beta_flat: &Array1<f64>,
5113 ) -> Result<Option<Array2<f64>>, String> {
5114 let latent_beta = &expect_single_block_state(block_states, "bounded linear family")?.beta;
5115 if d_beta_flat.len() != latent_beta.len() {
5116 return Err(SmoothError::dimension_mismatch(format!(
5117 "bounded linear family directional derivative length mismatch: got {}, expected {}",
5118 d_beta_flat.len(),
5119 latent_beta.len()
5120 ))
5121 .into());
5122 }
5123
5124 let (obs, _, _, _, second_diag, third_diag, priorthird) =
5125 self.exacthessian_andgradient(latent_beta)?;
5126
5127 let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(latent_beta);
5128 let x_eff = self.effective_design_for_latent(&jac_diag);
5129 let deta = x_eff.dot(d_beta_flat);
5130 let d_neghess_eta = &obs.neghessian_eta_derivative * &deta;
5131
5132 let mut dx_eff = Array2::<f64>::zeros(x_eff.raw_dim());
5133 for term in &self.bounded_terms {
5134 let scale = second_diag[term.col_idx] * d_beta_flat[term.col_idx];
5135 if scale != 0.0 {
5136 let mut col = dx_eff.column_mut(term.col_idx);
5137 col.assign(&self.design.column(term.col_idx));
5138 col.mapv_inplace(|v| v * scale);
5139 }
5140 }
5141
5142 let mut dhessian = xt_diag_x_dense(x_eff.view(), d_neghess_eta.view())?;
5143 let mut wxdx = Array2::<f64>::zeros((x_eff.ncols(), x_eff.ncols()));
5144 for i in 0..x_eff.nrows() {
5145 let wi = obs.neghessian_eta[i];
5146 if wi == 0.0 {
5147 continue;
5148 }
5149 for a in 0..x_eff.ncols() {
5150 let xa = x_eff[[i, a]];
5151 for b in 0..x_eff.ncols() {
5152 wxdx[[a, b]] += wi * (dx_eff[[i, a]] * x_eff[[i, b]] + xa * dx_eff[[i, b]]);
5153 }
5154 }
5155 }
5156 dhessian += &wxdx;
5157
5158 let d_score = -&obs.neghessian_eta * &deta;
5159 for term in &self.bounded_terms {
5160 let score_beta = self.design.column(term.col_idx).dot(&obs.score);
5161 let d_score_beta = self.design.column(term.col_idx).dot(&d_score);
5162 dhessian[[term.col_idx, term.col_idx]] -= d_score_beta * second_diag[term.col_idx]
5163 + score_beta * third_diag[term.col_idx] * d_beta_flat[term.col_idx];
5164 dhessian[[term.col_idx, term.col_idx]] +=
5165 priorthird[term.col_idx] * d_beta_flat[term.col_idx];
5166 }
5167
5168 Ok(Some(dhessian))
5169 }
5170
5171 fn block_geometry(
5172 &self,
5173 block_states: &[ParameterBlockState],
5174 spec: &ParameterBlockSpec,
5175 ) -> Result<(DesignMatrix, Array1<f64>), String> {
5176 if block_states.is_empty() {
5177 return Ok((
5178 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
5179 self.designzeroed.clone(),
5180 )),
5181 self.offset.clone(),
5182 ));
5183 }
5184 let offset = self.nonlinear_offset_from_latent(
5185 &expect_single_block_state(block_states, "bounded linear family")?.beta,
5186 );
5187 let x = if spec.design.ncols() == self.designzeroed.ncols() {
5188 self.designzeroed.clone()
5189 } else {
5190 return Err(SmoothError::dimension_mismatch(
5191 "bounded linear family design column mismatch",
5192 )
5193 .into());
5194 };
5195 Ok((
5196 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
5197 offset,
5198 ))
5199 }
5200
5201 fn block_geometry_is_dynamic(&self) -> bool {
5202 true
5203 }
5204
5205 fn block_geometry_directional_derivative(
5206 &self,
5207 block_states: &[ParameterBlockState],
5208 block_idx: usize,
5209 spec: &ParameterBlockSpec,
5210 d_beta: &Array1<f64>,
5211 ) -> Result<Option<BlockGeometryDirectionalDerivative>, String> {
5212 expect_block_idx_zero(
5213 block_idx,
5214 "bounded linear family",
5215 " for geometry derivative",
5216 )?;
5217 expect_single_block_state(block_states, "bounded linear family")?;
5218 if d_beta.len() != spec.design.ncols() {
5219 return Err(SmoothError::dimension_mismatch(format!(
5220 "bounded linear family geometry derivative direction mismatch: got {}, expected {}",
5221 d_beta.len(),
5222 spec.design.ncols()
5223 ))
5224 .into());
5225 }
5226 let (_, jac_diag, _, _, _) = self.bounded_term_derivative_data(&block_states[0].beta);
5227 let mut d_offset = Array1::<f64>::zeros(self.offset.len());
5228 let has_drift = self
5229 .bounded_terms
5230 .iter()
5231 .any(|term| jac_diag[term.col_idx] != 0.0 && d_beta[term.col_idx] != 0.0);
5232 if !has_drift {
5233 return Ok(Some(BlockGeometryDirectionalDerivative {
5234 d_design: None,
5235 d_offset,
5236 }));
5237 }
5238 for term in &self.bounded_terms {
5239 let col = term.col_idx;
5240 let drift = jac_diag[col] * d_beta[col];
5241 if drift != 0.0 {
5242 d_offset.scaled_add(drift, &self.design.column(col));
5243 }
5244 }
5245 Ok(Some(BlockGeometryDirectionalDerivative {
5246 d_design: None,
5247 d_offset,
5248 }))
5249 }
5250}
5251
/// Picks how many rows of a `p`-column `f64` matrix to process per chunk.
///
/// The row count targets roughly `TARGET_BYTES` of scratch per chunk and is
/// clamped to `[MIN_ROWS, MAX_ROWS]`. A width of zero is treated as one
/// column so the division is always defined.
///
/// The per-row byte count saturates instead of overflowing, so an extreme
/// `p` yields `MIN_ROWS` rather than an arithmetic panic or a wrapped
/// (possibly zero) divisor.
#[inline]
fn dense_diag_gram_chunkrows(p: usize) -> usize {
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 2048;
    const TARGET_BYTES: usize = 2 * 1024 * 1024;
    // `p.max(1)` keeps this at least `size_of::<f64>()`, hence non-zero.
    let bytes_per_row = p.max(1).saturating_mul(std::mem::size_of::<f64>());
    (TARGET_BYTES / bytes_per_row).clamp(MIN_ROWS, MAX_ROWS)
}
5260
5261fn xt_diag_x_dense(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
5262 if x.nrows() != w.len() {
5263 return Err(SmoothError::dimension_mismatch("xt_diag_x_dense row mismatch").into());
5264 }
5265 let (n, p) = x.dim();
5266 if n == 0 || p == 0 {
5267 return Ok(Array2::<f64>::zeros((p, p)));
5268 }
5269
5270 const STREAMING_BYTES_THRESHOLD: usize = 8 * 1024 * 1024;
5271 let dense_work_bytes = n
5272 .checked_mul(p)
5273 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
5274 .unwrap_or(usize::MAX);
5275 if dense_work_bytes <= STREAMING_BYTES_THRESHOLD {
5276 let mut weighted = x.to_owned();
5277 ndarray::Zip::from(weighted.rows_mut())
5278 .and(w)
5279 .par_for_each(|mut row, wi| row *= *wi);
5280 return Ok(fast_atb(&x, &weighted));
5281 }
5282
5283 let chunkrows = dense_diag_gram_chunkrows(p).min(n);
5284 let mut weighted_chunk = Array2::<f64>::zeros((chunkrows, p));
5285 let mut out = Array2::<f64>::zeros((p, p));
5286 for row_start in (0..n).step_by(chunkrows) {
5287 let rows = (n - row_start).min(chunkrows);
5288 let x_chunk = x.slice(s![row_start..row_start + rows, ..]);
5289 {
5290 let mut chunk = weighted_chunk.slice_mut(s![0..rows, ..]);
5291 for local_row in 0..rows {
5292 let scale = w[row_start + local_row];
5293 if scale == 0.0 {
5294 chunk.row_mut(local_row).fill(0.0);
5295 continue;
5296 }
5297 for col in 0..p {
5298 chunk[[local_row, col]] = x_chunk[[local_row, col]] * scale;
5299 }
5300 }
5301 }
5302 out += &fast_atb(&x_chunk, &weighted_chunk.slice(s![0..rows, ..]));
5303 }
5304 Ok(out)
5305}
5306
5307fn trace_of_dense_product(a: &Array2<f64>, b: &Array2<f64>) -> Result<f64, String> {
5308 if a.nrows() != a.ncols() || b.nrows() != b.ncols() || a.nrows() != b.nrows() {
5309 return Err(
5310 SmoothError::dimension_mismatch("trace_of_dense_product dimension mismatch").into(),
5311 );
5312 }
5313 let mut trace = 0.0;
5314 for i in 0..a.nrows() {
5315 for j in 0..a.ncols() {
5316 trace += a[[i, j]] * b[[j, i]];
5317 }
5318 }
5319 Ok(trace)
5320}
5321
5322fn exact_bounded_edf(
5323 penalties: &[PenaltySpec],
5324 lambdas: &Array1<f64>,
5325 latent_cov: &Array2<f64>,
5326) -> Result<(Vec<f64>, Vec<f64>, f64), EstimationError> {
5327 if penalties.len() != lambdas.len() {
5328 crate::bail_invalid_estim!(
5329 "bounded EDF penalty/lambda mismatch: {} penalties vs {} lambdas",
5330 penalties.len(),
5331 lambdas.len()
5332 );
5333 }
5334 if latent_cov.nrows() != latent_cov.ncols() {
5335 crate::bail_invalid_estim!("bounded EDF covariance must be square");
5336 }
5337
5338 let p = latent_cov.nrows();
5339 let mut s_lambda = Array2::<f64>::zeros((p, p));
5340 let mut edf_by_block = Vec::with_capacity(penalties.len());
5341 let mut penalty_block_trace = Vec::with_capacity(penalties.len());
5343 let mut trace_sum = 0.0;
5344
5345 for (k, ps) in penalties.iter().enumerate() {
5346 let lambda_k = lambdas[k];
5347 match ps {
5348 PenaltySpec::Block {
5349 local, col_range, ..
5350 } => {
5351 s_lambda
5352 .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
5353 .scaled_add(lambda_k, local);
5354 let penalty_rank =
5356 local
5357 .nrows()
5358 .saturating_sub(estimate_penalty_nullity(local).map_err(|e| {
5359 EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
5360 })?);
5361 let cov_block = latent_cov.slice(ndarray::s![col_range.clone(), col_range.clone()]);
5363 let trace_k = lambda_k
5364 * trace_of_dense_product(&cov_block.to_owned(), local)
5365 .map_err(EstimationError::InvalidInput)?;
5366 trace_sum += trace_k;
5367 penalty_block_trace.push(trace_k);
5368 let p_k = penalty_rank as f64;
5369 edf_by_block.push((p_k - trace_k).clamp(0.0, p_k));
5370 }
5371 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
5372 s_lambda.scaled_add(lambda_k, m);
5373 let penalty_rank = p.saturating_sub(estimate_penalty_nullity(m).map_err(|e| {
5374 EstimationError::InvalidInput(format!("bounded EDF rank failed: {e}"))
5375 })?);
5376 let trace_k = lambda_k
5377 * trace_of_dense_product(latent_cov, m)
5378 .map_err(EstimationError::InvalidInput)?;
5379 trace_sum += trace_k;
5380 penalty_block_trace.push(trace_k);
5381 let p_k = penalty_rank as f64;
5382 edf_by_block.push((p_k - trace_k).clamp(0.0, p_k));
5383 }
5384 }
5385 }
5386
5387 let nullity_total = estimate_penalty_nullity(&s_lambda)
5388 .map_err(|e| EstimationError::InvalidInput(format!("bounded EDF nullity failed: {e}")))?
5389 as f64;
5390 let edf_total = (p as f64 - trace_sum).clamp(nullity_total, p as f64);
5391 Ok((edf_by_block, penalty_block_trace, edf_total))
5392}
5393
5394fn symmetric_positive_definite_inverse_or_pseudo(
5406 precision: &Array2<f64>,
5407) -> Result<Array2<f64>, EstimationError> {
5408 use gam_linalg::faer_ndarray::FaerEigh;
5409 let p = precision.nrows();
5410 if precision.ncols() != p {
5411 crate::bail_invalid_estim!(
5412 "posterior precision inverse requires a square matrix, got {}x{}",
5413 precision.nrows(),
5414 precision.ncols()
5415 );
5416 }
5417 if p == 0 {
5418 return Ok(Array2::<f64>::zeros((0, 0)));
5419 }
5420 let symmetric = (precision + &precision.t().to_owned()) * 0.5;
5421 let (evals, evecs) = symmetric.eigh(faer::Side::Lower).map_err(|e| {
5422 EstimationError::InvalidInput(format!(
5423 "posterior precision eigendecomposition failed: {e}"
5424 ))
5425 })?;
5426 let max_abs_eval = evals.iter().fold(0.0_f64, |acc, &ev| acc.max(ev.abs()));
5427 let tol =
5428 (10.0 * f64::EPSILON * (p as f64) * (p as f64) * max_abs_eval).max(100.0 * f64::EPSILON);
5429 if let Some(&min_eval) = evals
5430 .iter()
5431 .filter(|&&ev| ev < -tol)
5432 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
5433 {
5434 crate::bail_invalid_estim!(
5435 "bounded posterior precision is non-PD at the converged optimum (min eigenvalue \
5436 {min_eval:.6e} < -tol={tol:.6e}); the reported mode is not a strict posterior \
5437 maximum, so a covariance would be meaningless"
5438 );
5439 }
5440 let mut scaled = evecs.clone();
5442 for (j, &ev) in evals.iter().enumerate() {
5443 let inv = if ev > tol { 1.0 / ev } else { 0.0 };
5444 scaled.column_mut(j).mapv_inplace(|v| v * inv);
5445 }
5446 let cov = scaled.dot(&evecs.t());
5447 Ok((&cov + &cov.t().to_owned()) * 0.5)
5448}
5449
5450fn transform_bounded_latent_precision_to_user_internal(
5451 latent_precision: &Array2<f64>,
5452 jac_diag: &Array1<f64>,
5453) -> Result<Array2<f64>, EstimationError> {
5454 let p = latent_precision.nrows();
5455 if latent_precision.ncols() != p || jac_diag.len() != p {
5456 crate::bail_invalid_estim!(
5457 "bounded precision transform dimension mismatch: precision is {}x{}, jacobian has {} entries",
5458 latent_precision.nrows(),
5459 latent_precision.ncols(),
5460 jac_diag.len()
5461 );
5462 }
5463 let mut out = latent_precision.clone();
5464 for i in 0..p {
5465 let scale = jac_diag[i];
5466 if !scale.is_finite() || scale <= 0.0 {
5467 crate::bail_invalid_estim!(
5468 "bounded precision transform requires a positive finite coefficient jacobian; column {i} has {scale}"
5469 );
5470 }
5471 if scale != 1.0 {
5472 out.row_mut(i).mapv_inplace(|v| v / scale);
5473 out.column_mut(i).mapv_inplace(|v| v / scale);
5474 }
5475 }
5476 Ok(out)
5477}
5478
/// Fits a term collection that contains at least one bounded linear term,
/// using the `BoundedLinearFamily` custom-family adapter.
///
/// The function conditions the design, builds the adapter and a single
/// parameter block, runs `fit_custom_family`, and then assembles a
/// `FittedTermCollection` (coefficients, precision, optional covariance,
/// EDF, dispersion and bookkeeping) from the converged latent coefficients.
///
/// # Errors
/// Fails when the design carries explicit linear constraints, when a bounded
/// term also requests `double_penalty`, when no bounded term is present,
/// when `heuristic_lambdas` has the wrong length, when link-state
/// construction fails, or when the fit or any post-fit computation fails.
fn fit_bounded_term_collection_with_design(
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    spec: &TermCollectionSpec,
    design: &TermCollectionDesign,
    heuristic_lambdas: Option<&[f64]>,
    family: LikelihoodSpec,
    options: &FitOptions,
) -> Result<FittedTermCollection, EstimationError> {
    // Linear terms without `double_penalty` are the columns handed to the
    // conditioning; linear term `j` sits at column `intercept_range.end + j`.
    let conditioning_cols: Vec<usize> = spec
        .linear_terms
        .iter()
        .enumerate()
        .filter_map(|(j, linear)| {
            (!linear.double_penalty).then_some(design.intercept_range.end + j)
        })
        .collect();
    let conditioning = LinearFitConditioning::from_columns(design, &conditioning_cols);
    let dense_design = design.design.to_dense_cow();
    let fit_design = conditioning.apply_to_design(&dense_design);
    let fit_penalties = conditioning
        .transform_blockwise_penalties_to_internal(&design.penalties, design.design.ncols());
    if design.linear_constraints.is_some() {
        crate::bail_invalid_estim!(
            "bounded() terms are not yet compatible with explicit linear constraints"
        );
    }
    // Collect every bounded linear term, mapping its bounds into the
    // conditioned (internal) coordinates.
    let mut bounded_terms = Vec::<BoundedLinearTermMeta>::new();
    for (j, term) in spec.linear_terms.iter().enumerate() {
        if term.double_penalty
            && matches!(
                term.coefficient_geometry,
                LinearCoefficientGeometry::Bounded { .. }
            )
        {
            crate::bail_invalid_estim!(
                "bounded linear term '{}' cannot also use double_penalty",
                term.name
            );
        }
        if let LinearCoefficientGeometry::Bounded { min, max, prior } =
            term.coefficient_geometry.clone()
        {
            let col_idx = design.intercept_range.end + j;
            let (min_internal, max_internal) = conditioning.internal_bounds_for(col_idx, min, max);
            bounded_terms.push(BoundedLinearTermMeta {
                col_idx,
                min: min_internal,
                max: max_internal,
                prior,
            });
        }
    }
    if bounded_terms.is_empty() {
        crate::bail_invalid_estim!("internal bounded fit path called with no bounded terms");
    }

    // Bounded columns are zeroed in the linear design; the family adapter
    // re-introduces them through its nonlinear offset. Their latent starting
    // value is `bounded_logit(0.5)`.
    let mut designzeroed = fit_design.clone();
    let mut initial_beta = Array1::<f64>::zeros(fit_design.ncols());
    for term in &bounded_terms {
        designzeroed.column_mut(term.col_idx).fill(0.0);
        initial_beta[term.col_idx] = bounded_logit(0.5);
    }

    // NOTE(review): `heuristic_lambdas` is passed through unchanged as the
    // initial *log* lambdas; confirm callers supply log-scale values.
    let initial_log_lambdas = heuristic_lambdas
        .map(|vals| Array1::from_vec(vals.to_vec()))
        .unwrap_or_else(|| Array1::zeros(fit_penalties.len()));
    if initial_log_lambdas.len() != fit_penalties.len() {
        crate::bail_invalid_estim!(
            "heuristic lambda length mismatch for bounded model: got {}, expected {}",
            initial_log_lambdas.len(),
            fit_penalties.len()
        );
    }

    let is_beta_logistic = family.is_binomial_beta_logistic();
    let family_adapter = BoundedLinearFamily {
        family: family.clone(),
        latent_cloglog_state: options.latent_cloglog,
        mixture_link_state: options
            .mixture_link
            .clone()
            .as_ref()
            .map(state_fromspec)
            .transpose()
            .map_err(EstimationError::InvalidInput)?,
        // The SAS link spec is interpreted as a beta-logistic spec when the
        // family is binomial beta-logistic.
        sas_link_state: options
            .sas_link
            .map(|spec| {
                if is_beta_logistic {
                    state_from_beta_logisticspec(spec)
                } else {
                    state_from_sasspec(spec)
                }
            })
            .transpose()
            .map_err(EstimationError::InvalidInput)?,
        y: y.to_owned(),
        weights: weights.to_owned(),
        design: fit_design.clone(),
        designzeroed: designzeroed.clone(),
        offset: offset.to_owned(),
        bounded_terms: bounded_terms.clone(),
    };
    // Single parameter block; penalties are converted one-to-one into the
    // block's penalty representation.
    let blockspec = ParameterBlockSpec {
        name: "eta".to_string(),
        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(designzeroed)),
        offset: offset.to_owned(),
        penalties: fit_penalties
            .iter()
            .map(|ps| match ps {
                PenaltySpec::Block {
                    local, col_range, ..
                } => PenaltyMatrix::Blockwise {
                    local: local.clone(),
                    col_range: col_range.clone(),
                    total_dim: design.design.ncols(),
                },
                PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
                    PenaltyMatrix::Dense(m.clone())
                }
            })
            .collect(),
        nullspace_dims: design.nullspace_dims.clone(),
        initial_log_lambdas,
        initial_beta: Some(initial_beta),
        gauge_priority: 100,
        jacobian_callback: Some(Arc::new(BoundedEffectiveJacobian {
            design: fit_design.clone(),
            bounded_terms: bounded_terms.clone(),
        })),
        stacked_design: None,
        stacked_offset: None,
    };
    // Inner and outer loops share `options.max_iter` / `options.tol`; the
    // covariance is not requested here and is computed below instead.
    let fit = fit_custom_family(
        &family_adapter,
        &[blockspec],
        &BlockwiseFitOptions {
            inner_max_cycles: options.max_iter,
            inner_tol: options.tol,
            outer_max_iter: options.max_iter,
            outer_tol: options.tol,
            compute_covariance: false,
            ..BlockwiseFitOptions::default()
        },
    )
    .map_err(EstimationError::CustomFamily)?;

    // Map the converged latent coefficients to user scale, then undo the
    // conditioning.
    let latent_beta = fit.block_states[0].beta.clone();
    let (beta_user_internal, jac_diag) = family_adapter.user_beta_and_jacobian(&latent_beta);
    let beta_user = conditioning.backtransform_beta(&beta_user_internal);

    let (eta_state, h_data, _, _) = family_adapter
        .evaluation_from_latent(&latent_beta)
        .map_err(EstimationError::InvalidInput)?;
    // Weighted penalty sum in the conditioned coordinates.
    let p_fit = fit_design.ncols();
    let mut s_lambda_internal = Array2::<f64>::zeros((p_fit, p_fit));
    for (k, penalty) in fit_penalties.iter().enumerate() {
        match penalty {
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                s_lambda_internal
                    .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
                    .scaled_add(fit.lambdas[k], local);
            }
            PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
                s_lambda_internal.scaled_add(fit.lambdas[k], m);
            }
        }
    }
    // Latent precision = family Hessian + weighted penalties; it is then
    // rescaled by the coefficient Jacobian and mapped back to the original
    // coordinates.
    let mut latent_precision = h_data.clone();
    latent_precision += &s_lambda_internal;
    let user_precision_internal =
        transform_bounded_latent_precision_to_user_internal(&latent_precision, &jac_diag)?;
    let penalized_hessian =
        conditioning.transform_penalized_hessian_to_original(&user_precision_internal);

    // Covariances are only computed when inference is requested.
    let beta_covariance_unscaled = if options.compute_inference {
        Some(symmetric_positive_definite_inverse_or_pseudo(
            &penalized_hessian,
        )?)
    } else {
        None
    };
    let latent_cov = if options.compute_inference {
        Some(symmetric_positive_definite_inverse_or_pseudo(
            &latent_precision,
        )?)
    } else {
        None
    };
    let s_lambda_original = weighted_blockwise_penalty_sum(
        &design.penalties,
        fit.lambdas.as_slice().unwrap(),
        design.design.ncols(),
    );
    let penalty_term = beta_user.dot(&s_lambda_original.dot(&beta_user));
    // Gaussian identity: weighted residual sum of squares (negative weights
    // are floored at zero); otherwise -2 * log-likelihood.
    let deviance = if family.is_gaussian_identity() {
        y.iter()
            .zip(eta_state.mu.iter())
            .zip(weights.iter())
            .map(|((&yy, &mu), &w)| w.max(0.0) * (yy - mu) * (yy - mu))
            .sum()
    } else {
        -2.0 * eta_state.log_likelihood
    };
    // Without a latent covariance the EDF outputs are reported as zeros.
    let (edf_by_block, penalty_block_trace, edf_total) = if let Some(cov) = latent_cov.as_ref() {
        exact_bounded_edf(&fit_penalties, &fit.lambdas, cov)?
    } else {
        (
            vec![0.0; fit_penalties.len()],
            vec![0.0; fit_penalties.len()],
            0.0,
        )
    };

    let glm_likelihood = gam_spec::GlmLikelihoodSpec::canonical(family.clone());
    // Residual standard deviation is estimated only for the Gaussian identity
    // family; the denominator subtracts the EDF only when inference ran and
    // is floored at one.
    let standard_deviation = if family.is_gaussian_identity() {
        let denom = if options.compute_inference {
            (y.len() as f64 - edf_total).max(1.0)
        } else {
            (y.len() as f64).max(1.0)
        };
        (deviance / denom).sqrt()
    } else {
        1.0
    };
    let cov_scale = glm_likelihood
        .coefficient_covariance_scale(standard_deviation * standard_deviation)
        .max(f64::MIN_POSITIVE);
    let dispersion = gam_solve::estimate::dispersion_from_likelihood(&glm_likelihood, standard_deviation);
    let beta_covariance = beta_covariance_unscaled.map(|mut cov| {
        if cov_scale != 1.0 {
            cov.mapv_inplace(|v| v * cov_scale);
        }
        cov
    });
    // Standard errors are the square roots of the covariance diagonal, with
    // negative diagonal entries floored at zero.
    let beta_standard_errors = beta_covariance
        .as_ref()
        .map(|cov| Array1::from_iter((0..cov.nrows()).map(|i| cov[[i, i]].max(0.0).sqrt())));

    let geometry = Some(gam_solve::estimate::FitGeometry {
        penalized_hessian: penalized_hessian.clone().into(),
        working_weights: eta_state.fisherweight.clone(),
        // Working response: eta + score / weight, with the weight floored at
        // 1e-12 to avoid division by zero.
        working_response: {
            let mut working_response = eta_state.eta.clone();
            for i in 0..working_response.len() {
                let wi = eta_state.fisherweight[i].max(1e-12);
                working_response[i] += eta_state.score[i] / wi;
            }
            working_response
        },
    });
    let max_abs_eta = eta_state
        .eta
        .iter()
        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    Ok(FittedTermCollection {
        fit: {
            // Lambdas are floored at 1e-300 before taking the logarithm.
            let log_lambdas = fit.lambdas.mapv(|v| v.max(1e-300).ln());
            let inf = FitInference {
                edf_by_block,
                penalty_block_trace,
                edf_total,
                smoothing_correction: None,
                penalized_hessian: penalized_hessian.clone().into(),
                working_weights: eta_state.fisherweight.clone(),
                // Same working-response construction as in `geometry` above.
                working_response: {
                    let mut working_response = eta_state.eta.clone();
                    for i in 0..working_response.len() {
                        let wi = eta_state.fisherweight[i].max(1e-12);
                        working_response[i] += eta_state.score[i] / wi;
                    }
                    working_response
                },
                reparam_qs: None,
                dispersion,
                beta_covariance: beta_covariance
                    .clone()
                    .map(gam_problem::dispersion_cov::PhiScaledCovariance::from),
                beta_standard_errors,
                beta_covariance_corrected: None,
                beta_standard_errors_corrected: None,
                beta_covariance_frequentist: None,
                coefficient_influence: None,
                weighted_gram: None,
                bias_correction_beta: None,
            };
            let covariance_conditional = beta_covariance;
            // A non-converged outer loop is reported as a stall at a valid
            // minimum rather than as a failure.
            let pirls_status_val = if fit.outer_converged {
                gam_solve::pirls::PirlsStatus::Converged
            } else {
                gam_solve::pirls::PirlsStatus::StalledAtValidMinimum
            };
            UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
                blocks: vec![gam_solve::estimate::FittedBlock {
                    beta: beta_user.clone(),
                    role: gam_problem::BlockRole::Mean,
                    edf: edf_total,
                    lambdas: fit.lambdas.clone(),
                }],
                log_lambdas,
                lambdas: fit.lambdas,
                likelihood_scale: family.default_scale_metadata(),
                likelihood_family: Some(family),
                log_likelihood_normalization:
                    gam_spec::LogLikelihoodNormalization::UserProvided,
                log_likelihood: eta_state.log_likelihood,
                deviance,
                // Both score fields carry the fit's penalized objective.
                reml_score: fit.penalized_objective,
                stable_penalty_term: penalty_term,
                penalized_objective: fit.penalized_objective,
                used_device: false,
                outer_iterations: fit.outer_iterations,
                outer_converged: fit.outer_converged,
                outer_gradient_norm: fit.outer_gradient_norm,
                standard_deviation,
                covariance_conditional,
                covariance_corrected: None,
                inference: Some(inf),
                fitted_link: gam_solve::estimate::FittedLinkState::Standard(None),
                geometry,
                block_states: Vec::new(),
                pirls_status: pirls_status_val,
                max_abs_eta,
                constraint_kkt: None,
                artifacts: gam_solve::estimate::FitArtifacts {
                    pirls: None,
                    ..Default::default()
                },
                inner_cycles: 0,
            })?
        },
        design: design.clone(),
        adaptive_diagnostics: None,
    })
}
5882
5883fn enforce_term_constraint_feasibility(
5884 design: &TermCollectionDesign,
5885 fit: &UnifiedFitResult,
5886) -> Result<(), EstimationError> {
5887 const CONSTRAINT_FEASIBILITY_RAW_TOL: f64 = 1e-7;
5901 let tol = CONSTRAINT_FEASIBILITY_RAW_TOL;
5902 let smooth_start = design
5903 .design
5904 .ncols()
5905 .saturating_sub(design.smooth.total_smooth_cols());
5906 let mut violations: Vec<String> = Vec::new();
5907 for term in &design.smooth.terms {
5908 let gr = (smooth_start + term.coeff_range.start)..(smooth_start + term.coeff_range.end);
5909 let beta_local = fit.beta.slice(s![gr.clone()]).to_owned();
5910 if let Some(lb) = term.lower_bounds_local.as_ref() {
5911 let mut worst = 0.0_f64;
5912 let mut worst_idx = 0usize;
5913 for i in 0..lb.len().min(beta_local.len()) {
5914 if lb[i].is_finite() {
5915 let viol = (lb[i] - beta_local[i]).max(0.0);
5916 if viol > worst {
5917 worst = viol;
5918 worst_idx = i;
5919 }
5920 }
5921 }
5922 if worst > tol {
5923 violations.push(format!(
5924 "term='{}' kind=lower-bound maxviolation={:.3e} coeff_index={}",
5925 term.name, worst, worst_idx
5926 ));
5927 }
5928 }
5929 if let Some(lin) = term.linear_constraints_local.as_ref() {
5930 let mut worst = 0.0_f64;
5931 let mut worstrow = 0usize;
5932 for i in 0..lin.a.nrows() {
5933 let norm = lin.a.row(i).dot(&lin.a.row(i)).sqrt();
5934 let inv = if norm > 0.0 { 1.0 / norm } else { 0.0 };
5935 let s = (lin.a.row(i).dot(&beta_local) - lin.b[i]) * inv;
5936 let viol = (-s).max(0.0);
5937 if viol > worst {
5938 worst = viol;
5939 worstrow = i;
5940 }
5941 }
5942 if worst > tol {
5943 violations.push(format!(
5944 "term='{}' kind=linear-inequality maxviolation={:.3e} row={}",
5945 term.name, worst, worstrow
5946 ));
5947 }
5948 }
5949 }
5950
5951 if !violations.is_empty() {
5952 let mut msg = format!(
5953 "constraint violation after fit ({} violating term constraints): {}",
5954 violations.len(),
5955 violations.join(" | ")
5956 );
5957 if let Some(kkt) = fit.constraint_kkt.as_ref() {
5958 msg.push_str(&format!(
5959 "; KKT[primal={:.3e}, dual={:.3e}, comp={:.3e}, stat={:.3e}]",
5960 kkt.primal_feasibility, kkt.dual_feasibility, kkt.complementarity, kkt.stationarity
5961 ));
5962 }
5963 return Err(EstimationError::ParameterConstraintViolation(msg));
5964 }
5965 Ok(())
5966}
5967
5968fn stratified_spatial_subsample(
5969 data: ArrayView2<'_, f64>,
5970 spec: &TermCollectionSpec,
5971 target_size: usize,
5972) -> Vec<usize> {
5973 use rand::SeedableRng;
5974 use rand::rngs::StdRng;
5975 use rand::seq::SliceRandom;
5976
5977 let n = data.nrows();
5978 if n <= target_size {
5979 return (0..n).collect();
5980 }
5981
5982 let spatial_cols: Option<Vec<usize>> =
5983 spec.smooth_terms.iter().find_map(|term| match &term.basis {
5984 SmoothBasisSpec::ThinPlate { feature_cols, .. }
5985 | SmoothBasisSpec::Matern { feature_cols, .. }
5986 | SmoothBasisSpec::Duchon { feature_cols, .. } => {
5987 if !feature_cols.is_empty() {
5988 Some(feature_cols.clone())
5989 } else {
5990 None
5991 }
5992 }
5993 _ => None,
5994 });
5995
5996 let cols = match spatial_cols {
5997 Some(c) if !c.is_empty() => c,
5998 _ => {
5999 let mut rng = StdRng::seed_from_u64(spatial_subsample_seed(data, &[], target_size));
6000 let mut indices: Vec<usize> = (0..n).collect();
6001 indices.shuffle(&mut rng);
6002 indices.truncate(target_size);
6003 indices.sort_unstable();
6004 return indices;
6005 }
6006 };
6007 let mut rng = StdRng::seed_from_u64(spatial_subsample_seed(data, &cols, target_size));
6008
6009 let d = cols.len();
6010 let mut mins = vec![f64::INFINITY; d];
6011 let mut maxs = vec![f64::NEG_INFINITY; d];
6012 for i in 0..n {
6013 for (ax, &col) in cols.iter().enumerate() {
6014 let v = data[[i, col]];
6015 if v < mins[ax] {
6016 mins[ax] = v;
6017 }
6018 if v > maxs[ax] {
6019 maxs[ax] = v;
6020 }
6021 }
6022 }
6023
6024 const TARGET_POINTS_PER_CELL: usize = 5;
6028 let total_cells_target = (target_size / TARGET_POINTS_PER_CELL).max(1);
6029 let cells_per_axis = ((total_cells_target as f64).powf(1.0 / d as f64)).ceil() as usize;
6030 let cells_per_axis = cells_per_axis.max(1);
6031
6032 let mut cell_members: std::collections::HashMap<Vec<usize>, Vec<usize>> =
6033 std::collections::HashMap::new();
6034 for i in 0..n {
6035 let mut cell_key = Vec::with_capacity(d);
6036 for (ax, &col) in cols.iter().enumerate() {
6037 let range = maxs[ax] - mins[ax];
6038 let cell = if range <= 0.0 {
6039 0
6040 } else {
6041 let frac = (data[[i, col]] - mins[ax]) / range;
6042 (frac * cells_per_axis as f64).floor() as usize
6043 };
6044 cell_key.push(cell.min(cells_per_axis - 1));
6045 }
6046 cell_members.entry(cell_key).or_default().push(i);
6047 }
6048
6049 let mut selected: Vec<usize> = Vec::with_capacity(target_size);
6050 let mut remaining_budget = target_size;
6051 let mut remaining_population = n;
6052
6053 let mut cells: Vec<(Vec<usize>, Vec<usize>)> = cell_members.into_iter().collect();
6054 cells.sort_by(|a, b| a.0.cmp(&b.0));
6055
6056 for (_, members) in &mut cells {
6057 if remaining_budget == 0 {
6058 break;
6059 }
6060 let alloc = ((members.len() as f64 / remaining_population as f64) * remaining_budget as f64)
6061 .round() as usize;
6062 let alloc = alloc.max(1).min(members.len()).min(remaining_budget);
6063 members.shuffle(&mut rng);
6064 selected.extend_from_slice(&members[..alloc]);
6065 remaining_budget = remaining_budget.saturating_sub(alloc);
6066 remaining_population = remaining_population.saturating_sub(members.len());
6067 }
6068
6069 if selected.len() > target_size {
6070 selected.shuffle(&mut rng);
6071 selected.truncate(target_size);
6072 }
6073
6074 selected.sort_unstable();
6075 selected
6076}
6077
6078fn spatial_subsample_seed(
6079 data: ArrayView2<'_, f64>,
6080 spatial_cols: &[usize],
6081 target_size: usize,
6082) -> u64 {
6083 let mut state = 0x5350_4154_4941_4C53_u64;
6084 spatial_seed_mix(&mut state, data.nrows() as u64);
6085 spatial_seed_mix(&mut state, data.ncols() as u64);
6086 spatial_seed_mix(&mut state, target_size as u64);
6087 spatial_seed_mix(&mut state, spatial_cols.len() as u64);
6088 for &col in spatial_cols {
6089 spatial_seed_mix(&mut state, col as u64);
6090 }
6091
6092 if data.nrows() > 0 {
6093 let mid = data.nrows() / 2;
6094 let last = data.nrows() - 1;
6095 for &row in &[0usize, mid, last] {
6096 for &col in spatial_cols {
6097 let value = data[[row, col]];
6098 spatial_seed_mix(&mut state, value.to_bits());
6099 }
6100 }
6101 }
6102 state
6103}
6104
6105#[inline]
6106fn spatial_seed_mix(state: &mut u64, value: u64) {
6107 let mut s = value.wrapping_add(*state);
6110 let z = gam_linalg::utils::splitmix64(&mut s);
6111 *state ^= z;
6112 *state = (*state).rotate_left(27).wrapping_mul(0x3C79_AC49_2BA7_B653);
6113}
6114
6115fn sampled_rows(data: ArrayView2<'_, f64>, indices: &[usize]) -> Array2<f64> {
6116 let mut sampled = Array2::<f64>::zeros((indices.len(), data.ncols()));
6117 for (new_row, &orig_row) in indices.iter().enumerate() {
6118 sampled.row_mut(new_row).assign(&data.row(orig_row));
6119 }
6120 sampled
6121}
6122
6123fn spatial_term_user_centers(term: &SmoothTermSpec) -> Option<ArrayView2<'_, f64>> {
6124 match spatial_term_center_strategy(term) {
6125 Some(CenterStrategy::UserProvided(centers)) => Some(centers.view()),
6126 _ => None,
6127 }
6128}
6129
6130fn finite_centered_axis_contrasts(values: &[f64], expected_dim: usize) -> Option<Vec<f64>> {
6131 if values.len() != expected_dim || expected_dim <= 1 {
6132 return None;
6133 }
6134 if values.iter().any(|value| !value.is_finite()) {
6135 return None;
6136 }
6137 Some(center_aniso_log_scales(values))
6138}
6139
6140fn blended_pilot_axis_contrasts(
6141 pilot_data: ArrayView2<'_, f64>,
6142 term: &SmoothTermSpec,
6143 centers: ArrayView2<'_, f64>,
6144) -> Option<Vec<f64>> {
6145 let d = centers.ncols();
6146 if d <= 1 {
6147 return None;
6148 }
6149 let center_eta = initial_aniso_contrasts(centers);
6150 let data_eta = standardized_spatial_term_data(pilot_data, term)
6151 .ok()
6152 .and_then(|x| finite_centered_axis_contrasts(&initial_aniso_contrasts(x.view()), d));
6153 let center_eta = finite_centered_axis_contrasts(¢er_eta, d)?;
6154 let blended = match data_eta {
6155 Some(data_eta) => center_eta
6156 .iter()
6157 .zip(data_eta.iter())
6158 .map(|(&from_centers, &from_data)| 0.5 * (from_centers + from_data))
6159 .collect::<Vec<_>>(),
6160 None => center_eta,
6161 };
6162 finite_centered_axis_contrasts(&blended, d)
6163}
6164
6165fn apply_pilot_spatial_psi_reseed(
6166 pilot_data: ArrayView2<'_, f64>,
6167 spec: &TermCollectionSpec,
6168 spatial_terms: &[usize],
6169 kappa_options: &SpatialLengthScaleOptimizationOptions,
6170) -> Result<TermCollectionSpec, EstimationError> {
6171 let dims_per_term = spatial_dims_per_term(spec, spatial_terms);
6172 let use_aniso = has_aniso_terms(spec, spatial_terms);
6173 let log_kappa0 = if use_aniso {
6174 SpatialLogKappaCoords::from_length_scales_aniso(spec, spatial_terms, kappa_options)
6175 } else {
6176 SpatialLogKappaCoords::from_length_scales(spec, spatial_terms, kappa_options)
6177 };
6178 let log_kappa0 = log_kappa0.reseed_from_data(pilot_data, spec, spatial_terms, kappa_options);
6179 let log_kappa_lower = if use_aniso {
6180 SpatialLogKappaCoords::lower_bounds_aniso_from_data(
6181 pilot_data,
6182 spec,
6183 spatial_terms,
6184 &dims_per_term,
6185 kappa_options,
6186 )
6187 } else {
6188 SpatialLogKappaCoords::lower_bounds_from_data(
6189 pilot_data,
6190 spec,
6191 spatial_terms,
6192 kappa_options,
6193 )
6194 };
6195 let log_kappa_upper = if use_aniso {
6196 SpatialLogKappaCoords::upper_bounds_aniso_from_data(
6197 pilot_data,
6198 spec,
6199 spatial_terms,
6200 &dims_per_term,
6201 kappa_options,
6202 )
6203 } else {
6204 SpatialLogKappaCoords::upper_bounds_from_data(
6205 pilot_data,
6206 spec,
6207 spatial_terms,
6208 kappa_options,
6209 )
6210 };
6211 log_kappa0
6212 .clamp_to_bounds(&log_kappa_lower, &log_kappa_upper)
6213 .apply_tospec(spec, spatial_terms)
6214}
6215
6216pub(crate) fn apply_spatial_anisotropy_pilot_initializer(
6217 data: ArrayView2<'_, f64>,
6218 spec: &mut TermCollectionSpec,
6219 spatial_terms: &[usize],
6220 target_size: usize,
6221 kappa_options: &SpatialLengthScaleOptimizationOptions,
6222) -> usize {
6223 if target_size == 0 || data.nrows() <= target_size.saturating_mul(2) || spatial_terms.is_empty()
6224 {
6225 return 0;
6226 }
6227 if !has_aniso_terms(spec, spatial_terms) {
6228 return 0;
6229 }
6230 let indices = stratified_spatial_subsample(data, spec, target_size);
6231 let pilot_data = sampled_rows(data, &indices);
6232 let mut working = spec.clone();
6233 let mut updated_terms = 0usize;
6234 const GEOMETRY_UPDATES: usize = 2;
6235
6236 for pass in 0..GEOMETRY_UPDATES {
6237 let planned_terms = match plan_joint_spatial_centers_for_term_blocks(
6238 pilot_data.view(),
6239 &[working.smooth_terms.clone()],
6240 )
6241 .and_then(|mut blocks| {
6242 blocks.pop().ok_or_else(|| {
6243 BasisError::InvalidInput(
6244 "pilot geometry initializer produced no smooth-term block".to_string(),
6245 )
6246 })
6247 }) {
6248 Ok(terms) => terms,
6249 Err(err) => {
6250 log::warn!(
6251 "[spatial-kappa] pilot geometry initializer skipped after center planning failed: {err}"
6252 );
6253 return updated_terms;
6254 }
6255 };
6256
6257 for &term_idx in spatial_terms {
6258 let Some(current_eta) = get_spatial_aniso_log_scales(&working, term_idx) else {
6259 continue;
6260 };
6261 let Some(d) = get_spatial_feature_dim(&working, term_idx) else {
6262 continue;
6263 };
6264 if d <= 1 || current_eta.len() != d {
6265 continue;
6266 }
6267 let Some(planned_term) = planned_terms.get(term_idx) else {
6268 continue;
6269 };
6270 let Some(centers) = spatial_term_user_centers(planned_term) else {
6271 continue;
6272 };
6273 let Some(eta) = blended_pilot_axis_contrasts(pilot_data.view(), planned_term, centers)
6274 else {
6275 continue;
6276 };
6277 if set_spatial_aniso_log_scales(&mut working, term_idx, eta).is_ok() {
6278 updated_terms += usize::from(pass == 0);
6279 }
6280 }
6281
6282 match apply_pilot_spatial_psi_reseed(
6283 pilot_data.view(),
6284 &working,
6285 spatial_terms,
6286 kappa_options,
6287 ) {
6288 Ok(updated) => {
6289 working = updated;
6290 }
6291 Err(err) => {
6292 log::warn!(
6293 "[spatial-kappa] pilot geometry ψ reseed skipped after deterministic initializer error: {err}"
6294 );
6295 break;
6296 }
6297 }
6298 }
6299
6300 if updated_terms > 0 {
6301 log::info!(
6302 "[spatial-kappa] initialized anisotropy from {}-row pilot geometry for {} spatial term(s); proceeding to full-data optimization",
6303 indices.len(),
6304 updated_terms
6305 );
6306 *spec = working;
6307 }
6308 updated_terms
6309}
6310
6311pub(crate) fn spatial_length_scale_term_indices(spec: &TermCollectionSpec) -> Vec<usize> {
6312 spec.smooth_terms
6313 .iter()
6314 .enumerate()
6315 .filter_map(|(idx, _)| spatial_term_supports_hyper_optimization(spec, idx).then_some(idx))
6316 .collect()
6317}
6318
6319fn fit_score(fit: &UnifiedFitResult) -> f64 {
6331 if fit.reml_score.is_finite() {
6332 return fit.reml_score;
6333 }
6334 let score = 0.5 * fit.deviance + 0.5 * fit.stable_penalty_term;
6335 if score.is_finite() {
6336 score
6337 } else {
6338 f64::INFINITY
6339 }
6340}
6341
6342fn is_recoverable_trial_point_error(err: &EstimationError) -> bool {
6360 matches!(err, EstimationError::BasisError(_))
6361}
6362
6363fn require_successful_spatial_optimization_result<T>(
6364 initial_score: f64,
6365 result: Result<Option<(T, f64)>, EstimationError>,
6366) -> Result<T, EstimationError> {
6367 match result {
6368 Ok(Some((value, exact_score))) => {
6369 const SCORE_DRIFT_ABS_TOL: f64 = 1e-6;
6378 const SCORE_DRIFT_REL_TOL: f64 = 1e-8;
6379 let tol = SCORE_DRIFT_ABS_TOL.max(initial_score.abs() * SCORE_DRIFT_REL_TOL);
6380 if exact_score <= initial_score + tol {
6381 Ok(value)
6382 } else {
6383 Err(EstimationError::RemlOptimizationFailed(format!(
6384 "spatial kappa optimization made REML score worse ({initial_score:.6e} -> {exact_score:.6e})"
6385 )))
6386 }
6387 }
6388 Ok(None) => Err(EstimationError::RemlOptimizationFailed(
6389 "spatial kappa optimization is unavailable for one or more eligible spatial terms"
6390 .to_string(),
6391 )),
6392 Err(err) => Err(EstimationError::RemlOptimizationFailed(format!(
6393 "spatial kappa optimization failed: {err}"
6394 ))),
6395 }
6396}
6397
6398fn external_opts_for_design(
6399 family: &LikelihoodSpec,
6400 design: &TermCollectionDesign,
6401 options: &FitOptions,
6402) -> ExternalOptimOptions {
6403 ExternalOptimOptions {
6404 family: family.clone(),
6405 latent_cloglog: options.latent_cloglog,
6406 mixture_link: options.mixture_link.clone(),
6407 optimize_mixture: options.optimize_mixture,
6408 sas_link: options.sas_link,
6409 optimize_sas: options.optimize_sas,
6410 compute_inference: options.compute_inference,
6411 skip_rho_posterior_inference: options.skip_rho_posterior_inference,
6412 max_iter: options.max_iter,
6413 tol: options.tol,
6414 nullspace_dims: design.nullspace_dims.clone(),
6415 linear_constraints: design.linear_constraints.clone(),
6416 firth_bias_reduction: Some(options.firth_bias_reduction),
6417 penalty_shrinkage_floor: options.penalty_shrinkage_floor,
6418 rho_prior: options.rho_prior.clone(),
6419 kronecker_penalty_system: design.kronecker_penalty_system(),
6422 kronecker_factored: design
6423 .smooth
6424 .terms
6425 .iter()
6426 .find_map(|t| t.kronecker_factored.clone()),
6427 persist_warm_start_disk: options.persist_warm_start_disk,
6428 }
6429}
6430
6431fn evaluate_joint_reml_outer_eval_at_theta(
6439 evaluator: &mut gam_solve::estimate::ExternalJointHyperEvaluator<'_>,
6440 design: &TermCollectionDesign,
6441 theta: &Array1<f64>,
6442 rho_dim: usize,
6443 hyper_dirs: Vec<gam_solve::estimate::reml::DirectionalHyperParam>,
6444 warm_start_beta: Option<ArrayView1<'_, f64>>,
6445 order: gam_solve::rho_optimizer::OuterEvalOrder,
6446 design_revision: Option<u64>,
6447) -> Result<
6448 (
6449 f64,
6450 Array1<f64>,
6451 gam_problem::HessianResult,
6452 ),
6453 EstimationError,
6454> {
6455 evaluator.evaluate_with_order(
6456 &design.design,
6457 &design.penalties,
6458 &design.nullspace_dims,
6459 design.linear_constraints.clone(),
6460 theta,
6461 rho_dim,
6462 hyper_dirs,
6463 warm_start_beta,
6464 "evaluate_joint_reml_outer_eval_at_theta",
6465 order,
6466 design_revision,
6467 )
6468}
6469
6470fn evaluate_joint_reml_efs_at_theta(
6471 evaluator: &mut gam_solve::estimate::ExternalJointHyperEvaluator<'_>,
6472 design: &TermCollectionDesign,
6473 theta: &Array1<f64>,
6474 rho_dim: usize,
6475 hyper_dirs: Vec<gam_solve::estimate::reml::DirectionalHyperParam>,
6476 warm_start_beta: Option<ArrayView1<'_, f64>>,
6477 design_revision: Option<u64>,
6478) -> Result<gam_problem::EfsEval, EstimationError> {
6479 evaluator.evaluate_efs(
6480 &design.design,
6481 &design.penalties,
6482 &design.nullspace_dims,
6483 design.linear_constraints.clone(),
6484 theta,
6485 rho_dim,
6486 hyper_dirs,
6487 warm_start_beta,
6488 "evaluate_joint_reml_efs_at_theta",
6489 design_revision,
6490 )
6491}
6492
6493fn exact_joint_spatial_outer_hessian_available(
6494 family: &LikelihoodSpec,
6495 design: &TermCollectionDesign,
6496) -> bool {
6497 let family_supported = match &family.response {
6520 ResponseFamily::Gaussian
6521 | ResponseFamily::Binomial
6522 | ResponseFamily::Poisson
6523 | ResponseFamily::Tweedie { .. }
6524 | ResponseFamily::NegativeBinomial { .. }
6525 | ResponseFamily::Beta { .. }
6526 | ResponseFamily::Gamma
6527 | ResponseFamily::RoystonParmar => true,
6528 };
6529 family_supported && design.design.ncols() > 0
6532}
6533
6534fn smooth_term_penalty_index(
6535 spec: &TermCollectionSpec,
6536 design: &TermCollectionDesign,
6537 term_idx: usize,
6538) -> Option<usize> {
6539 if term_idx >= design.smooth.terms.len() || term_idx >= spec.smooth_terms.len() {
6540 return None;
6541 }
6542 if design.smooth.terms[term_idx].penalties_local.is_empty() {
6543 return None;
6544 }
6545 let linear_penalties = spec
6546 .linear_terms
6547 .iter()
6548 .filter(|t| t.double_penalty)
6549 .count()
6550 * 2;
6551 let random_penalties = design
6552 .random_effect_ranges
6553 .iter()
6554 .filter(|(_, range)| !range.is_empty())
6555 .count();
6556 let smooth_offset = linear_penalties + random_penalties;
6557 let local_offset = design
6558 .smooth
6559 .terms
6560 .iter()
6561 .take(term_idx)
6562 .map(|term| term.penalties_local.len())
6563 .sum::<usize>();
6564 Some(smooth_offset + local_offset)
6565}
6566
6567fn try_build_spatial_term_log_kappa_derivativeinfo(
6568 data: ArrayView2<'_, f64>,
6569 resolvedspec: &TermCollectionSpec,
6570 design: &TermCollectionDesign,
6571 term_idx: usize,
6572) -> Result<Option<SpatialPsiDerivative>, EstimationError> {
6573 let Some((
6574 global_range,
6575 total_p,
6576 x_psi_local,
6577 s_psi_local_check,
6578 x_psi_psi_local,
6579 s_psi_psi_local,
6580 s_psi_components_local,
6581 s_psi_psi_components_local,
6582 implicit_operator,
6583 )) = try_build_spatial_term_log_kappa_derivative(data, resolvedspec, design, term_idx)?
6584 else {
6585 return Ok(None);
6586 };
6587 let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
6588 return Ok(None);
6589 };
6590 if s_psi_components_local.is_empty() || s_psi_psi_components_local.is_empty() {
6591 return Ok(None);
6592 }
6593 if s_psi_components_local.len() != s_psi_psi_components_local.len() {
6594 return Ok(None);
6595 }
6596 let penalty_indices = (0..s_psi_components_local.len())
6597 .map(|j| penalty_start + j)
6598 .collect::<Vec<_>>();
6599 let penalty_index = penalty_indices[0];
6600 if s_psi_local_check.nrows() == 0 || s_psi_psi_local.nrows() == 0 {
6601 return Ok(None);
6602 }
6603 Ok(Some(SpatialPsiDerivative {
6604 penalty_index,
6605 penalty_indices,
6606 global_range,
6607 total_p,
6608 x_psi_local,
6609 s_psi_components_local,
6610 x_psi_psi_local,
6611 s_psi_psi_components_local,
6612 aniso_group_id: None,
6613 aniso_cross_designs: None,
6614 aniso_cross_penalty_provider: None,
6615 implicit_operator,
6616 implicit_axis: 0,
6617 }))
6618}
6619
6620pub(crate) fn try_build_spatial_log_kappa_derivativeinfo_list(
6621 data: ArrayView2<'_, f64>,
6622 resolvedspec: &TermCollectionSpec,
6623 design: &TermCollectionDesign,
6624 spatial_terms: &[usize],
6625) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
6626 let mut out = Vec::new();
6627 let mut aniso_gid = 0usize;
6628 for &term_idx in spatial_terms {
6629 if spatial_term_uses_per_axis_psi(resolvedspec, term_idx) {
6630 if let Some(entries) = try_build_spatial_term_log_kappa_aniso_derivativeinfos(
6631 data,
6632 resolvedspec,
6633 design,
6634 term_idx,
6635 aniso_gid,
6636 )? {
6637 aniso_gid += 1;
6638 out.extend(entries);
6639 continue;
6640 } else {
6641 return Ok(None);
6642 }
6643 }
6644 let Some(info) =
6645 try_build_spatial_term_log_kappa_derivativeinfo(data, resolvedspec, design, term_idx)?
6646 else {
6647 return Ok(None);
6648 };
6649 out.push(info);
6650 }
6651 Ok(Some(out))
6652}
6653
6654fn try_build_spatial_term_log_kappa_aniso_derivativeinfos(
6656 data: ArrayView2<'_, f64>,
6657 resolvedspec: &TermCollectionSpec,
6658 design: &TermCollectionDesign,
6659 term_idx: usize,
6660 aniso_group_id: usize,
6661) -> Result<Option<Vec<SpatialPsiDerivative>>, EstimationError> {
6662 let Some(smooth_term) = design.smooth.terms.get(term_idx) else {
6663 return Ok(None);
6664 };
6665 let Some(termspec) = resolvedspec.smooth_terms.get(term_idx) else {
6666 return Ok(None);
6667 };
6668 let mut aniso_result = match &termspec.basis {
6669 SmoothBasisSpec::Sphere { .. } => return Ok(None),
6670 SmoothBasisSpec::Matern {
6671 feature_cols,
6672 spec,
6673 input_scales,
6674 } => {
6675 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
6676 if let Some(s) = input_scales {
6677 apply_input_standardization(&mut x, s);
6678 }
6679 let mut spec_operator = spec.clone();
6688 spec_operator.double_penalty = false;
6689 build_matern_basis_log_kappa_aniso_derivatives(x.view(), &spec_operator)
6690 .map_err(EstimationError::from)?
6691 }
6692 SmoothBasisSpec::MeasureJet {
6698 feature_cols,
6699 spec,
6700 input_scales,
6701 } => {
6702 let mut x = select_columns(data, feature_cols).map_err(EstimationError::from)?;
6703 if let Some(s) = input_scales {
6704 apply_input_standardization(&mut x, s);
6705 }
6706 build_measure_jet_basis_psi_derivatives(x.view(), spec)
6707 .map_err(EstimationError::from)?
6708 }
6709 _ => return Ok(None),
6710 };
6711 let d = if let Some(ref op) = aniso_result.implicit_operator {
6714 op.n_axes()
6715 } else if !aniso_result.design_first.is_empty() {
6716 aniso_result.design_first.len()
6717 } else {
6718 0
6719 };
6720 if d == 0 {
6721 return Ok(None);
6722 }
6723 let Some(penalty_start) = smooth_term_penalty_index(resolvedspec, design, term_idx) else {
6724 return Ok(None);
6725 };
6726 let p_total = design.design.ncols();
6727 let smooth_start = p_total.saturating_sub(design.smooth.total_smooth_cols());
6728 let global_range = (smooth_start + smooth_term.coeff_range.start)
6729 ..(smooth_start + smooth_term.coeff_range.end);
6730 let num_penalties = aniso_result.penalties_first[0].len();
6731 let penalty_indices: Vec<usize> = (0..num_penalties).map(|j| penalty_start + j).collect();
6732 let penalties_cross_provider = aniso_result.penalties_cross_provider.clone();
6733
6734 let use_implicit_design = aniso_result.design_first.is_empty();
6738 let implicit_op_arc = aniso_result
6739 .implicit_operator
6740 .as_ref()
6741 .map(|op| std::sync::Arc::new(op.clone()));
6742
6743 let mut entries = Vec::with_capacity(d);
6744 for a in 0..d {
6745 let (x_psi_local, x_psi_psi_local) = if use_implicit_design {
6746 (Array2::<f64>::zeros((0, 0)), Array2::<f64>::zeros((0, 0)))
6752 } else {
6753 let x_first = std::mem::take(&mut aniso_result.design_first[a]);
6758 let x_second = std::mem::take(&mut aniso_result.design_second_diag[a]);
6759 if x_first.ncols() != smooth_term.coeff_range.len() {
6760 return Ok(None);
6761 }
6762 (x_first, x_second)
6763 };
6764 let s_psi_components = std::mem::take(&mut aniso_result.penalties_first[a]);
6765 let s_psi_psi_components = std::mem::take(&mut aniso_result.penalties_second_diag[a]);
6766 let cross_designs = if implicit_op_arc.is_some() {
6772 let mut cd = Vec::with_capacity(d - 1);
6773 for b in 0..d {
6774 if b == a {
6775 continue;
6776 }
6777 cd.push((b, Array2::<f64>::zeros((0, 0))));
6778 }
6779 cd
6780 } else if !aniso_result.design_second_cross.is_empty() {
6781 let mut cd = Vec::new();
6782 for (cross_idx, &(pa, pb)) in aniso_result.design_second_cross_pairs.iter().enumerate()
6783 {
6784 if pa == a {
6785 cd.push((pb, aniso_result.design_second_cross[cross_idx].clone()));
6786 } else if pb == a {
6787 cd.push((pa, aniso_result.design_second_cross[cross_idx].clone()));
6788 }
6789 }
6790 cd
6791 } else {
6792 Vec::new()
6793 };
6794 let cross_penalty_provider = if d > 1 {
6795 let penalties_cross_provider = penalties_cross_provider.clone();
6796 Some(std::sync::Arc::new(
6797 move |b_axis: usize| -> Result<Vec<Array2<f64>>, EstimationError> {
6798 if b_axis == a {
6799 return Ok(Vec::new());
6800 }
6801 let (axis_lo, axis_hi) = if a < b_axis { (a, b_axis) } else { (b_axis, a) };
6802 if let Some(provider) = penalties_cross_provider.as_ref() {
6803 provider
6804 .evaluate(axis_lo, axis_hi)
6805 .map_err(EstimationError::from)
6806 } else {
6807 Ok(Vec::new())
6811 }
6812 },
6813 )
6814 as std::sync::Arc<
6815 dyn Fn(usize) -> Result<Vec<Array2<f64>>, EstimationError>
6816 + Send
6817 + Sync
6818 + 'static,
6819 >)
6820 } else {
6821 None
6822 };
6823
6824 entries.push(SpatialPsiDerivative {
6825 penalty_index: penalty_indices[0],
6826 penalty_indices: penalty_indices.clone(),
6827 global_range: global_range.clone(),
6828 total_p: p_total,
6829 x_psi_local,
6830 s_psi_components_local: s_psi_components,
6831 x_psi_psi_local,
6832 s_psi_psi_components_local: s_psi_psi_components,
6833 aniso_group_id: Some(aniso_group_id),
6834 aniso_cross_designs: if cross_designs.is_empty() {
6835 None
6836 } else {
6837 Some(cross_designs)
6838 },
6839 aniso_cross_penalty_provider: cross_penalty_provider,
6840 implicit_operator: implicit_op_arc.clone(),
6841 implicit_axis: a,
6842 });
6843 }
6844 Ok(Some(entries))
6845}
6846
#[cfg(test)]
mod glm_eta_observation_fd_tests {
    //! Finite-difference checks of the per-observation derivatives with
    //! respect to the linear predictor `eta`.

    use super::*;

    /// Evaluates the observation state for a single observation with unit
    /// weight at response `y` and linear predictor `eta`.
    fn one_obs(spec: &LikelihoodSpec, y: f64, eta: f64) -> StandardFamilyObservationState {
        let response = Array1::from_vec(vec![y]);
        let unit_weight = Array1::from_vec(vec![1.0]);
        let linear_predictor = Array1::from_vec(vec![eta]);
        evaluate_standard_familyobservations(
            spec.clone(),
            None,
            None,
            None,
            &response,
            &unit_weight,
            &linear_predictor,
        )
        .expect("standard family observation state assembles")
    }

    /// Compares `score`, `neghessian_eta` and `neghessian_eta_derivative` at
    /// `eta` against central differences of the next-lower quantity, using
    /// relative tolerances of 1e-4, 1e-3 and 1e-2 respectively.
    fn check_fd(label: &str, spec: &LikelihoodSpec, y: f64, eta: f64) {
        let h = 1e-5;
        let central = |plus: f64, minus: f64| (plus - minus) / (2.0 * h);

        let s0 = one_obs(spec, y, eta);
        let sp = one_obs(spec, y, eta + h);
        let sm = one_obs(spec, y, eta - h);

        // Score is the eta-derivative of the log-likelihood.
        let score = s0.score[0];
        let score_fd = central(sp.log_likelihood, sm.log_likelihood);
        assert!(
            (score - score_fd).abs() <= 1e-4 * (1.0 + score.abs()),
            "{label}: score {score} vs FD {score_fd}"
        );

        // Negative Hessian is minus the eta-derivative of the score.
        let neghess = s0.neghessian_eta[0];
        let neghess_fd = -central(sp.score[0], sm.score[0]);
        assert!(
            (neghess - neghess_fd).abs() <= 1e-3 * (1.0 + neghess.abs()),
            "{label}: neghessian_eta {neghess} vs FD {neghess_fd}"
        );

        // Its derivative is checked against differences of the negative Hessian.
        let nhd = s0.neghessian_eta_derivative[0];
        let nhd_fd = central(sp.neghessian_eta[0], sm.neghessian_eta[0]);
        assert!(
            (nhd - nhd_fd).abs() <= 1e-2 * (1.0 + nhd.abs()),
            "{label}: neghessian_eta_derivative {nhd} vs FD {nhd_fd}"
        );
    }

    #[test]
    fn poisson_gamma_nb_tweedie_arms_match_finite_differences_1615_1616() {
        let log = InverseLink::Standard(StandardLink::Log);
        let with_log_link = |response: ResponseFamily| LikelihoodSpec {
            response,
            link: log.clone(),
        };

        // Each family is probed at two (label, y, eta) points.
        let scenarios = [
            (
                with_log_link(ResponseFamily::Poisson),
                [("poisson y=3", 3.0, 0.4), ("poisson y=0", 0.0, -0.2)],
            ),
            (
                with_log_link(ResponseFamily::Gamma),
                [("gamma y=2.5", 2.5, 0.3), ("gamma y=0.7", 0.7, -0.1)],
            ),
            (
                with_log_link(ResponseFamily::NegativeBinomial {
                    theta: 1.5,
                    theta_fixed: true,
                }),
                [("negbin y=4", 4.0, 0.5), ("negbin y=0", 0.0, -0.3)],
            ),
            (
                with_log_link(ResponseFamily::Tweedie { p: 1.5 }),
                [("tweedie y=2", 2.0, 0.25), ("tweedie y=0.5", 0.5, -0.15)],
            ),
        ];

        for (spec, points) in &scenarios {
            for &(label, y, eta) in points {
                check_fd(label, spec, y, eta);
            }
        }
    }
}