1use self::inner_strategy::GeometryBackendKind;
2use super::*;
3use gam_linalg::sparse_exact::SparseExactFactor;
4use crate::pirls::PIRLS_CACHE_BYTE_BUDGET;
5use crate::pirls::assemble_and_factor_sparse_penalized_system;
6use gam_terms::basis::LocalDesignJacobianProvider;
7use gam_problem::SasLinkState;
8use gam_problem::OuterEval;
9use ndarray::{Array1, Array2, s};
10use std::collections::{HashMap, VecDeque};
11use std::ops::Range;
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
13use std::sync::{Arc, RwLock};
14
15pub mod assembly;
16pub mod atoms;
17pub(crate) mod continuation;
18pub(crate) mod eval;
19mod firth;
20pub(super) mod hyper;
21mod inner_strategy;
22pub mod jeffreys_subspace;
25pub mod outer_eval;
26pub mod penalty_logdet;
27pub mod per_atom_efs;
28pub mod reml_outer_engine;
29mod rho_key;
30mod sparse_exact_penalty;
31mod trace;
32
33pub(crate) use sparse_exact_penalty::sparse_penalty_block_count_from_canonical;
34
/// Byte budget (512 MiB) reported as `budget_bytes` by
/// `exact_tau_tau_hessian_policy_with_firth`.
pub(crate) const EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES: usize = 512 * 1024 * 1024;
/// Largest observation count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_OBSERVATIONS: usize = 20_000;
/// Largest coefficient count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_COEFFICIENTS: usize = 256;
/// Cap on `n_obs * p_coeff` checked by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_LINEAR_WORK: usize = 2_000_000;
/// Cap on `n_obs * p_coeff * p_coeff` checked by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_QUADRATIC_WORK: usize = 100_000_000;
/// Capacity installed by `PersistentLatentValuesCache::default()`.
pub(crate) const PERSISTENT_LATENT_VALUES_CACHE_CAPACITY: usize = 8;
41
/// String-keyed store of `f64` matrices with least-recently-used eviction.
#[derive(Debug)]
pub(crate) struct PersistentLatentValuesCache {
    /// Stored matrices, keyed by a caller-supplied string.
    pub(crate) entries: HashMap<String, Array2<f64>>,
    /// Recency queue of keys: `insert` evicts from the front, `touch` pushes
    /// to the back.
    pub(crate) lru: VecDeque<String>,
    /// Maximum number of entries `insert` leaves in `entries`.
    pub(crate) capacity: usize,
}
48
49impl Default for PersistentLatentValuesCache {
50 fn default() -> Self {
51 Self {
52 entries: HashMap::new(),
53 lru: VecDeque::new(),
54 capacity: PERSISTENT_LATENT_VALUES_CACHE_CAPACITY,
55 }
56 }
57}
58
59impl PersistentLatentValuesCache {
60 pub(crate) fn lookup(
61 &mut self,
62 key: &str,
63 n_obs: usize,
64 latent_dim: usize,
65 ) -> Option<Array2<f64>> {
66 let values = self.entries.get(key)?;
67 if values.dim() != (n_obs, latent_dim) {
68 return None;
69 }
70 let values = values.clone();
71 self.touch(key.to_string());
72 Some(values)
73 }
74
75 pub(crate) fn insert(&mut self, key: String, values: Array2<f64>) {
76 if values.iter().any(|value| !value.is_finite()) {
77 return;
78 }
79 self.entries.insert(key.clone(), values);
80 self.touch(key);
81 while self.entries.len() > self.capacity {
82 let Some(evicted) = self.lru.pop_front() else {
83 break;
84 };
85 self.entries.remove(&evicted);
86 }
87 }
88
89 pub(crate) fn touch(&mut self, key: String) {
90 if let Some(index) = self.lru.iter().position(|queued| queued == &key) {
91 self.lru.remove(index);
92 }
93 self.lru.push_back(key);
94 }
95}
96
/// Cloneable bundle of arrays kept for warm-starting.
///
/// NOTE(review): nothing in the visible part of this file reads or writes
/// these fields, so the per-field notes below are inferred from names and
/// types only — confirm against the code that fills and consumes the cache.
#[derive(Clone)]
pub(crate) struct IftWarmStartCache {
    /// Coefficient vector; presumably in the original (untransformed) basis.
    pub beta_original: ndarray::Array1<f64>,
    /// Presumably the log smoothing parameters the snapshot was taken at.
    pub rho: ndarray::Array1<f64>,
    /// Presumably the penalized Hessian in the transformed coordinate frame.
    pub penalized_hessian_transformed: gam_linalg::matrix::SymmetricMatrix,
    /// Presumably the reparameterization matrix relating the two frames.
    pub qs: ndarray::Array2<f64>,
    /// Presumably true when the solve ran in the original frame — TODO confirm.
    pub frame_was_original: bool,
    /// Optional per-block vectors; presumably `lambda_k * S_k * beta` — TODO confirm.
    pub lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>>,
}
142
/// Itemized byte estimate for one evaluation plan; every field is a byte
/// count for one category of storage.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) struct TauTauPlanEstimate {
    pub(crate) dense_x_bytes: usize,
    pub(crate) first_order_tau_bytes: usize,
    pub(crate) second_order_tau_bytes: usize,
    pub(crate) penalty_first_bytes: usize,
    pub(crate) penalty_pair_bytes: usize,
    pub(crate) rho_tau_penalty_bytes: usize,
    pub(crate) vector_cache_bytes: usize,
    pub(crate) weighted_scratch_bytes: usize,
}

impl TauTauPlanEstimate {
    /// Sum of all categories, saturating at `usize::MAX` instead of wrapping.
    pub(crate) fn total_bytes(self) -> usize {
        let Self {
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        } = self;
        [
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        ]
        .iter()
        .fold(0usize, |acc, &bytes| acc.saturating_add(bytes))
    }
}
167
168#[derive(Clone, Copy, Debug, PartialEq, Eq)]
169pub(crate) struct TauTauHessianPolicy {
170 pub(crate) any_has_implicit: bool,
171 pub(crate) implicit_multidim_duchon: bool,
172 pub(crate) estimated_dense_tau_cache_bytes: usize,
173 pub(crate) gradient_plan: TauTauPlanEstimate,
174 pub(crate) hessian_plan: TauTauPlanEstimate,
175 pub(crate) budget_bytes: usize,
176 pub(crate) firth_pair_terms_unavailable: bool,
177}
178
179impl TauTauHessianPolicy {
180 pub(crate) fn prefer_gradient_only(self) -> bool {
208 self.firth_pair_terms_unavailable
209 }
210}
211
212pub(crate) fn exact_tau_tau_hessian_policy_with_firth(
213 n_obs: usize,
214 p_coeff: usize,
215 hyper_dirs: &[DirectionalHyperParam],
216 firth_pair_terms_unavailable: bool,
217) -> TauTauHessianPolicy {
218 let f64_bytes = std::mem::size_of::<f64>();
219 let dense_matrix_bytes =
220 |rows: usize, cols: usize| -> usize { rows.saturating_mul(cols).saturating_mul(f64_bytes) };
221 let dense_design_bytes = dense_matrix_bytes(n_obs, p_coeff);
222 let dense_penalty_bytes = dense_matrix_bytes(p_coeff, p_coeff);
223 let psi_dim = hyper_dirs.len();
224 let implicit_n_axes = hyper_dirs
225 .iter()
226 .find_map(DirectionalHyperParam::implicit_axis_count_hint)
227 .unwrap_or(0);
228 let gradient_uses_implicit_design = hyper_dirs
229 .iter()
230 .any(DirectionalHyperParam::has_implicit_operator)
231 && gam_terms::basis::should_use_implicit_operators_with_policy(
232 n_obs,
233 p_coeff,
234 implicit_n_axes,
235 &gam_runtime::resource::ResourcePolicy::default_library(),
236 );
237 let dense_first_order_count = hyper_dirs
238 .iter()
239 .filter(|dir| !dir.has_implicit_operator())
240 .count();
241 let first_penalty_component_count = hyper_dirs
242 .iter()
243 .map(DirectionalHyperParam::penalty_first_component_count)
244 .sum::<usize>();
245
246 let mut dense_second_order_count = 0usize;
247 let mut penalty_pair_count = 0usize;
248 for i in 0..psi_dim {
249 for j in i..psi_dim {
250 if hyper_dirs[i]
251 .x_tau_tau_entry_at(j)
252 .or_else(|| hyper_dirs[j].x_tau_tau_entry_at(i))
253 .is_some_and(|entry| !entry.uses_implicit_storage())
254 {
255 dense_second_order_count += if i == j { 1 } else { 2 };
256 }
257 if hyper_dirs[i].has_penaltysecond_pair_at(j)
258 || hyper_dirs[j].has_penaltysecond_pair_at(i)
259 {
260 penalty_pair_count += if i == j { 1 } else { 2 };
261 }
262 }
263 }
264
265 let gradient_dense_first_order_count = if gradient_uses_implicit_design {
266 dense_first_order_count
267 } else {
268 psi_dim
269 };
270 let gradient_needs_dense_x =
271 firth_pair_terms_unavailable || gradient_dense_first_order_count > 0;
272 let gradient_plan = TauTauPlanEstimate {
273 dense_x_bytes: if gradient_needs_dense_x {
274 dense_design_bytes
275 } else {
276 0
277 },
278 first_order_tau_bytes: if gradient_dense_first_order_count > 0 {
279 dense_design_bytes
280 } else {
281 0
282 },
283 second_order_tau_bytes: 0,
284 penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
285 penalty_pair_bytes: 0,
286 rho_tau_penalty_bytes: 0,
287 vector_cache_bytes: n_obs.saturating_mul(f64_bytes),
288 weighted_scratch_bytes: dense_penalty_bytes,
289 };
290 let hessian_plan = TauTauPlanEstimate {
291 dense_x_bytes: if psi_dim > 0 { dense_design_bytes } else { 0 },
292 first_order_tau_bytes: dense_first_order_count.saturating_mul(dense_design_bytes),
293 second_order_tau_bytes: dense_second_order_count.saturating_mul(dense_design_bytes),
294 penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
295 penalty_pair_bytes: penalty_pair_count.saturating_mul(dense_penalty_bytes),
296 rho_tau_penalty_bytes: first_penalty_component_count
297 .saturating_mul(2)
298 .saturating_mul(dense_penalty_bytes),
299 vector_cache_bytes: psi_dim.saturating_mul(n_obs).saturating_mul(f64_bytes),
300 weighted_scratch_bytes: dense_penalty_bytes,
301 };
302 let any_has_implicit = hyper_dirs
303 .iter()
304 .any(DirectionalHyperParam::has_implicit_operator);
305 let implicit_multidim_duchon = hyper_dirs
306 .iter()
307 .any(DirectionalHyperParam::has_implicit_multidim_duchon);
308 let estimated_dense_tau_cache_bytes = hessian_plan
309 .first_order_tau_bytes
310 .saturating_add(hessian_plan.second_order_tau_bytes);
311 TauTauHessianPolicy {
312 any_has_implicit,
313 implicit_multidim_duchon,
314 estimated_dense_tau_cache_bytes,
315 gradient_plan,
316 hessian_plan,
317 budget_bytes: EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES,
318 firth_pair_terms_unavailable: firth_pair_terms_unavailable && !hyper_dirs.is_empty(),
319 }
320}
321
322pub(crate) fn firth_problem_scale_allows(n_obs: usize, p_coeff: usize) -> bool {
323 let linear_work = n_obs.saturating_mul(p_coeff);
324 let quadratic_work = linear_work.saturating_mul(p_coeff);
325 n_obs <= FIRTH_MAX_OBSERVATIONS
326 && p_coeff <= FIRTH_MAX_COEFFICIENTS
327 && linear_work <= FIRTH_MAX_LINEAR_WORK
328 && quadratic_work <= FIRTH_MAX_QUADRATIC_WORK
329}
330
331#[cfg(test)]
332mod tests {
333 use super::{
334 DirectionalHyperParam, EvalCacheManager, EvalShared, HyperDesignDerivative,
335 HyperPenaltyDerivative, ImplicitDerivLevel, RemlConfig, RemlState,
336 };
337 use crate::estimate::EstimationError;
338 use gam_linalg::faer_ndarray::FaerCholesky;
339 use gam_linalg::matrix::symmetrize_in_place;
340 use crate::pirls::PirlsCoordinateFrame;
341 use gam_terms::basis::{ImplicitDesignPsiDerivative, RadialScalarKind};
342 use gam_problem::{
343 GlmLikelihoodSpec, InverseLink, LikelihoodSpec, ResponseFamily, StandardLink,
344 };
345 use faer::Side;
346 use gam_problem::{HessianResult, OuterEval};
347 use ndarray::{Array1, Array2, array, s};
348 use std::sync::Arc;
349
350 pub(crate) fn binomial_logit_glm_spec() -> GlmLikelihoodSpec {
353 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
354 ResponseFamily::Binomial,
355 InverseLink::Standard(StandardLink::Logit),
356 ))
357 }
358
359 pub(crate) fn gaussian_identity_glm_spec() -> GlmLikelihoodSpec {
362 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
363 ResponseFamily::Gaussian,
364 InverseLink::Standard(StandardLink::Identity),
365 ))
366 }
367
368 impl DirectionalHyperParam {
369 pub(super) fn new(
370 x_tau_original: Array2<f64>,
371 penalty_first_components: Vec<(usize, Array2<f64>)>,
372 x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
373 penaltysecond_components: Option<Vec<Option<Vec<(usize, Array2<f64>)>>>>,
374 ) -> Result<Self, EstimationError> {
375 let x_tau_tau_original = x_tau_tau_original.map(|rows| {
376 rows.into_iter()
377 .map(|entry| entry.map(HyperDesignDerivative::from))
378 .collect::<Vec<_>>()
379 });
380 let penalty_first_components = penalty_first_components
381 .into_iter()
382 .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
383 .collect();
384 let penaltysecond_components = penaltysecond_components.map(|rows| {
385 rows.into_iter()
386 .map(|row| {
387 row.map(|components| {
388 components
389 .into_iter()
390 .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
391 .collect::<Vec<_>>()
392 })
393 })
394 .collect::<Vec<_>>()
395 });
396 Self::new_compact(
397 HyperDesignDerivative::from(x_tau_original),
398 penalty_first_components,
399 x_tau_tau_original,
400 penaltysecond_components,
401 )
402 }
403
404 pub(super) fn single_penalty(
405 penalty_index: usize,
406 x_tau_original: Array2<f64>,
407 s_tau_original: Array2<f64>,
408 x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
409 s_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
410 ) -> Result<Self, EstimationError> {
411 let penaltysecond_components = s_tau_tau_original.map(|rows| {
412 rows.into_iter()
413 .map(|mat| mat.map(|mat| vec![(penalty_index, mat)]))
414 .collect::<Vec<_>>()
415 });
416 Self::new(
417 x_tau_original,
418 vec![(penalty_index, s_tau_original)],
419 x_tau_tau_original,
420 penaltysecond_components,
421 )
422 }
423 }
424
425 #[test]
426 pub(crate) fn firth_problem_scale_gate_blocks_large_quadratic_work() {
427 assert!(super::firth_problem_scale_allows(2_000, 200));
428 assert!(!super::firth_problem_scale_allows(4_800, 241));
429 assert!(!super::firth_problem_scale_allows(4_800, 433));
430 }
431
    /// A single direction backed by an implicit design-derivative operator,
    /// evaluated at `n_obs = 10`, `p_coeff = 5` with the Firth flag off.
    ///
    /// The policy must report the implicit operator, still charge a dense
    /// `10 x 5` design matrix to the gradient plan, and not request
    /// gradient-only mode.
    ///
    /// NOTE(review): the test name says "prefers gradient only" but the final
    /// assertion checks the opposite — presumably the name predates a policy
    /// change.
    #[test]
    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_implicit_tau() {
        // Positional constructor arguments; their meaning is defined by
        // `ImplicitDesignPsiDerivative::new`, which is not visible here.
        let operator = ImplicitDesignPsiDerivative::new(
            array![1.0, 2.0, 3.0, 4.0],
            array![0.5, -1.0, 1.5, 2.0],
            array![0.1, 0.2, 0.3, 0.4],
            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
            None,
            None,
            2,
            2,
            1,
            2,
        );
        // First-order implicit derivative with no penalty components and no
        // second-order data.
        let dir = DirectionalHyperParam::new_compact(
            HyperDesignDerivative::from_implicit(
                Arc::new(operator),
                ImplicitDerivLevel::First(0),
                1..4,
                5,
            ),
            Vec::new(),
            None,
            None,
        )
        .expect("implicit directional hyperparam");
        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
        assert!(policy.any_has_implicit);
        assert_eq!(
            policy.gradient_plan.dense_x_bytes,
            10 * 5 * std::mem::size_of::<f64>()
        );
        assert!(!policy.prefer_gradient_only());
    }
466
    /// A single direction backed by a streaming implicit operator with a
    /// `PureDuchon` radial kind of `dim: 2`.
    ///
    /// The policy must flag both `any_has_implicit` and
    /// `implicit_multidim_duchon`, yet still not request gradient-only mode
    /// when the Firth flag is off.
    #[test]
    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_for_implicit_multidim_duchon()
    {
        // Positional constructor arguments; their meaning is defined by
        // `ImplicitDesignPsiDerivative::new_streaming`, not visible here.
        let operator = ImplicitDesignPsiDerivative::new_streaming(
            Arc::new(array![[0.0, 0.0], [1.0, 0.2]]),
            Arc::new(array![[0.0, 0.0], [1.0, 1.0]]),
            vec![0.0, 0.0],
            RadialScalarKind::PureDuchon {
                block_order: 1,
                p_order: 0,
                s_order: 0,
                dim: 2,
            },
            None,
            None,
            0,
        );
        let dir = DirectionalHyperParam::new_compact(
            HyperDesignDerivative::from_implicit(
                Arc::new(operator),
                ImplicitDerivLevel::First(0),
                0..2,
                2,
            ),
            Vec::new(),
            None,
            None,
        )
        .expect("implicit duchon directional hyperparam");
        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
        assert!(policy.any_has_implicit);
        assert!(policy.implicit_multidim_duchon);
        assert!(!policy.prefer_gradient_only());
    }
508
509 #[test]
510 pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_when_cache_budget_is_exceeded()
511 {
512 let dirs = (0..16)
518 .map(|_| {
519 DirectionalHyperParam::new_compact(
520 HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
521 Vec::new(),
522 None,
523 None,
524 )
525 .expect("dense directional hyperparam")
526 })
527 .collect::<Vec<_>>();
528 let policy = super::exact_tau_tau_hessian_policy_with_firth(320_000, 71, &dirs, false);
529 assert!(!policy.any_has_implicit);
530 assert!(policy.hessian_plan.total_bytes() > policy.budget_bytes);
531 assert!(policy.hessian_plan.total_bytes() > policy.gradient_plan.total_bytes());
532 assert!(!policy.prefer_gradient_only());
533 }
534
535 #[test]
536 pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_firth_pair_gap() {
537 let dir = DirectionalHyperParam::new_compact(
538 HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
539 Vec::new(),
540 None,
541 None,
542 )
543 .expect("dense directional hyperparam");
544 let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], true);
545 assert!(policy.firth_pair_terms_unavailable);
546 assert!(policy.prefer_gradient_only());
547 }
548
549 trait LogitDesignMotionFixture {
557 fn y(&self) -> &Array1<f64>;
558 fn w(&self) -> &Array1<f64>;
559 fn x(&self) -> &Array2<f64>;
560 fn s0(&self) -> &Array2<f64>;
561 fn cfg(&self) -> &RemlConfig;
562 fn rho(&self) -> &Array1<f64>;
563
564 fn state(&self) -> RemlState<'_> {
565 build_logit_state(self.y(), self.w(), self.x(), self.s0(), self.cfg())
566 }
567
568 fn state_perturbed(
569 &self,
570 x_tau: &Array2<f64>,
571 s_tau: &Array2<f64>,
572 eps: f64,
573 ) -> (RemlState<'_>, RemlState<'_>) {
574 let x_plus = self.x() + &x_tau.mapv(|v| eps * v);
575 let x_minus = self.x() - &x_tau.mapv(|v| eps * v);
576 let s_plus = self.s0() + &s_tau.mapv(|v| eps * v);
577 let s_minus = self.s0() - &s_tau.mapv(|v| eps * v);
578 (
579 build_logit_state(self.y(), self.w(), &x_plus, &s_plus, self.cfg()),
580 build_logit_state(self.y(), self.w(), &x_minus, &s_minus, self.cfg()),
581 )
582 }
583
584 fn fd_directional_gradient(&self, x_tau: &Array2<f64>, s_tau: &Array2<f64>) -> f64 {
586 let h = 2e-5;
587 let (state_plus, state_minus) = self.state_perturbed(x_tau, s_tau, h);
588 let v_plus = state_plus.compute_cost(self.rho()).expect("cost+");
589 let v_minus = state_minus.compute_cost(self.rho()).expect("cost-");
590 (v_plus - v_minus) / (2.0 * h)
591 }
592 }
593
594 pub(crate) fn build_logit_state<'a>(
595 y: &'a Array1<f64>,
596 w: &'a Array1<f64>,
597 x: &Array2<f64>,
598 s: &Array2<f64>,
599 cfg: &'a RemlConfig,
600 ) -> RemlState<'a> {
601 use crate::estimate::PenaltySpec;
602 let p = x.ncols();
603 let offset = Array1::<f64>::zeros(y.len());
604 let spec = PenaltySpec::Dense(s.clone());
605 let canonical = gam_terms::construction::canonicalize_penalty_specs(&[spec], &[1], p, "test")
606 .map(|(canonical, _)| canonical)
607 .expect("canonicalize");
608 RemlState::newwith_offset(
609 y.view(),
610 x.clone(),
611 w.view(),
612 offset.view(),
613 canonical,
614 p,
615 cfg,
616 Some(vec![1]),
617 None,
618 None,
619 )
620 .expect("state")
621 }
622
    /// Two canonical penalties built from the dense roots `[0, 1]` and
    /// `[1, 0]` on a two-column Gaussian-identity problem must still leave
    /// the analytic outer Hessian enabled.
    #[test]
    fn repeated_penalty_ranges_keep_analytic_outer_hessian() {
        let y = array![0.2, -0.1, 0.3, 0.0];
        let w = Array1::<f64>::ones(y.len());
        let x = array![[1.0, -0.7], [1.0, -0.2], [1.0, 0.3], [1.0, 0.9]];
        let offset = Array1::<f64>::zeros(y.len());
        let cfg = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
        let p = x.ncols();
        // Each root is a single row, one per coefficient.
        let canonical = vec![
            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0]], p),
            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0]], p),
        ];
        let state = RemlState::newwith_offset(
            y.view(),
            x,
            w.view(),
            offset.view(),
            canonical,
            p,
            &cfg,
            Some(vec![1, 1]),
            None,
            None,
        )
        .expect("state");

        assert!(
            state.analytic_outer_hessian_enabled(),
            "double-penalty-style repeated coefficient ranges must still route to exact Hessian"
        );
    }
654
655 #[test]
656 fn canonical_logit_firth_declines_exact_tk_hessian_when_row_pair_work_is_large() {
657 let n = 2_000usize;
658 let p = 28usize;
659 let y = Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1.0 } else { 0.0 }));
660 let w = Array1::<f64>::ones(n);
661 let mut x = Array2::<f64>::zeros((n, p));
662 for i in 0..n {
663 let t = (i as f64 + 0.5) / n as f64;
664 x[[i, 0]] = 1.0;
665 for j in 1..p {
666 x[[i, j]] = ((j as f64) * std::f64::consts::TAU * t).sin()
667 + 0.25 * (((j + 1) as f64) * std::f64::consts::TAU * t).cos();
668 }
669 }
670 let mut s = Array2::<f64>::zeros((p, p));
671 for j in 1..p {
672 s[[j, j]] = 1.0;
673 }
674 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
675 let state = build_logit_state(&y, &w, &x, &s, &cfg);
676
677 assert!(
678 !RemlState::firth_tk_exact_hessian_scale_allows(n, p),
679 "fixture must sit beyond the O(n²·p) exact-Hessian budget"
680 );
681 assert!(
682 !state.analytic_outer_hessian_enabled(),
683 "large canonical-logit Firth fits should keep exact value/gradient but route outer curvature to BFGS"
684 );
685 }
686
687 #[test]
688 fn canonical_logit_firth_keeps_exact_tk_hessian_for_small_separation_guards() {
689 let n = 40usize;
690 let p = 6usize;
691 let y = Array1::from_iter((0..n).map(|i| if i >= n / 2 { 1.0 } else { 0.0 }));
692 let w = Array1::<f64>::ones(n);
693 let mut x = Array2::<f64>::zeros((n, p));
694 for i in 0..n {
695 let t = (i as f64) / (n - 1) as f64;
696 x[[i, 0]] = 1.0;
697 for j in 1..p {
698 x[[i, j]] = t.powi(j as i32);
699 }
700 }
701 let mut s = Array2::<f64>::zeros((p, p));
702 for j in 1..p {
703 s[[j, j]] = 1.0;
704 }
705 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
706 let state = build_logit_state(&y, &w, &x, &s, &cfg);
707
708 assert!(RemlState::firth_tk_exact_hessian_scale_allows(n, p));
709 assert!(
710 state.analytic_outer_hessian_enabled(),
711 "small Firth rescue fits should keep exact TK Hessian curvature"
712 );
713 }
714
715 pub(crate) fn poisson_log_glm_spec() -> GlmLikelihoodSpec {
716 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
717 ResponseFamily::Poisson,
718 InverseLink::Standard(StandardLink::Log),
719 ))
720 }
721
    /// For a Poisson-log fit, weighting every row by `c` must give the same
    /// cost and gradient as stacking `c` copies of the data with unit
    /// weights, across a grid of `rho` values (relative tolerance `1e-9`).
    #[test]
    pub(crate) fn fixed_dispersion_laml_surface_is_replication_invariant() {
        let n = 200usize;
        let p = 8usize;
        let c = 3usize;
        // Design: intercept, linear term, and three sine/cosine harmonic pairs.
        let mut x = Array2::<f64>::zeros((n, p));
        let mut y = Array1::<f64>::zeros(n);
        for i in 0..n {
            let t = (i as f64) / ((n - 1) as f64);
            let tau = std::f64::consts::TAU;
            x[[i, 0]] = 1.0;
            x[[i, 1]] = t;
            x[[i, 2]] = (tau * t).sin();
            x[[i, 3]] = (tau * t).cos();
            x[[i, 4]] = (2.0 * tau * t).sin();
            x[[i, 5]] = (2.0 * tau * t).cos();
            x[[i, 6]] = (3.0 * tau * t).sin();
            x[[i, 7]] = (3.0 * tau * t).cos();
            let eta = 0.3 + 0.9 * (1.4 * (t - 0.5)).sin();
            // Deterministic pseudo-noise, rounded and clamped so the
            // responses are non-negative whole numbers.
            y[i] = (eta.exp() + 0.5 * ((i as f64) * 2.399_963).sin())
                .round()
                .max(0.0);
        }
        // Identity penalty on every column except the intercept.
        let mut s = Array2::<f64>::zeros((p, p));
        for j in 1..p {
            s[[j, j]] = 1.0;
        }

        // Replicated data set: `c` stacked copies of (x, y).
        let mut x_rep = Array2::<f64>::zeros((n * c, p));
        let mut y_rep = Array1::<f64>::zeros(n * c);
        for r in 0..c {
            for i in 0..n {
                let row = r * n + i;
                for j in 0..p {
                    x_rep[[row, j]] = x[[i, j]];
                }
                y_rep[row] = y[i];
            }
        }

        let w_weighted = Array1::<f64>::from_elem(n, c as f64);
        let w_rep = Array1::<f64>::ones(n * c);

        let cfg = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
        let st_w = build_logit_state(&y, &w_weighted, &x, &s, &cfg);
        let st_r = build_logit_state(&y_rep, &w_rep, &x_rep, &s, &cfg);

        for &rho in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
            let r = Array1::from_elem(1, rho);
            let cw = st_w.compute_cost(&r).expect("weighted cost");
            let cr = st_r.compute_cost(&r).expect("replicated cost");
            let gw = st_w.compute_gradient(&r).expect("weighted grad");
            let gr = st_r.compute_gradient(&r).expect("replicated grad");
            assert!(
                (cw - cr).abs() <= 1e-9 * (1.0 + cw.abs()),
                "LAML cost differs between w=c and c× replication at rho={rho}: \
                 cost_w={cw:.12e} cost_r={cr:.12e} diff={:.3e}",
                cw - cr
            );
            assert!(
                (gw[0] - gr[0]).abs() <= 1e-9 * (1.0 + gw[0].abs()),
                "LAML gradient differs between w=c and c× replication at rho={rho}: \
                 g_w={:.12e} g_r={:.12e} diff={:.3e}",
                gw[0],
                gr[0],
                gw[0] - gr[0]
            );
        }
    }
808
    /// With every prior weight equal to `c = 4`, `rho_weight_anchor()` must
    /// be exactly `0` for the Poisson-log configuration and `ln(c)` (within
    /// `1e-12`) for the Gaussian-identity configuration.
    #[test]
    pub(crate) fn rho_weight_anchor_is_zero_for_fixed_dispersion() {
        let n = 50usize;
        let p = 3usize;
        // Quadratic polynomial design with rounded, non-negative responses.
        let mut x = Array2::<f64>::zeros((n, p));
        let mut y = Array1::<f64>::zeros(n);
        for i in 0..n {
            let t = (i as f64) / ((n - 1) as f64);
            x[[i, 0]] = 1.0;
            x[[i, 1]] = t;
            x[[i, 2]] = t * t;
            y[i] = (1.0 + (3.0 * t).sin()).round().max(0.0);
        }
        // Penalize only the quadratic coefficient.
        let mut s = Array2::<f64>::zeros((p, p));
        s[[2, 2]] = 1.0;
        let c = 4.0_f64;
        let w = Array1::<f64>::from_elem(n, c);

        let cfg_pois = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
        let st_pois = build_logit_state(&y, &w, &x, &s, &cfg_pois);
        assert_eq!(
            st_pois.rho_weight_anchor(),
            0.0,
            "fixed-dispersion (Poisson) anchor must be 0, not the geometric-mean log-weight"
        );

        let cfg_gauss = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
        let st_gauss = build_logit_state(&y, &w, &x, &s, &cfg_gauss);
        assert!(
            (st_gauss.rho_weight_anchor() - c.ln()).abs() <= 1e-12,
            "Gaussian-identity (profiled) anchor must be the geometric-mean log-weight ln(c)={:.6}, got {:.6}",
            c.ln(),
            st_gauss.rho_weight_anchor()
        );
    }
853
854 pub(crate) fn beta_original_from_bundle(bundle: &EvalShared) -> Array1<f64> {
855 let pr = bundle.pirls_result.as_ref();
856 match pr.coordinate_frame {
857 PirlsCoordinateFrame::OriginalSparseNative => pr.beta_transformed.as_ref().clone(),
858 PirlsCoordinateFrame::TransformedQs => {
859 pr.reparam_result.qs.dot(pr.beta_transformed.as_ref())
860 }
861 }
862 }
863
864 pub(crate) fn compute_joint_hypercostgradienthessian(
865 state: &RemlState<'_>,
866 theta: &Array1<f64>,
867 rho_dim: usize,
868 hyper_dirs: &[DirectionalHyperParam],
869 ) -> Result<(f64, Array1<f64>, Array2<f64>), EstimationError> {
870 let (cost, gradient, hessian) = state.compute_joint_hyper_eval_with_order(
871 theta,
872 rho_dim,
873 hyper_dirs,
874 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
875 )?;
876 Ok((
877 cost,
878 gradient,
879 hessian
880 .materialize_dense()
881 .map_err(EstimationError::RemlOptimizationFailed)?
882 .ok_or_else(|| {
883 EstimationError::RemlOptimizationFailed(
884 "joint hyper Hessian requested but unavailable".to_string(),
885 )
886 })?,
887 ))
888 }
889
890 pub(crate) fn h_original_from_bundle(bundle: &EvalShared) -> Array2<f64> {
891 let pr = bundle.pirls_result.as_ref();
892 match pr.coordinate_frame {
893 PirlsCoordinateFrame::OriginalSparseNative => bundle.h_total.as_ref().clone(),
894 PirlsCoordinateFrame::TransformedQs => {
895 let qs = &pr.reparam_result.qs;
896 let tmp = gam_linalg::faer_ndarray::fast_ab(qs, bundle.h_total.as_ref());
897 gam_linalg::faer_ndarray::fast_abt(&tmp, qs)
898 }
899 }
900 }
901
902 pub(crate) fn single_directional_tau_gradient(
903 state: &RemlState<'_>,
904 rho: &Array1<f64>,
905 hyper: DirectionalHyperParam,
906 ) -> Result<f64, EstimationError> {
907 let mut theta = Array1::<f64>::zeros(rho.len() + 1);
908 theta.slice_mut(s![..rho.len()]).assign(rho);
909 let (_, gradient, _) = state.compute_joint_hyper_eval_with_order(
910 &theta,
911 rho.len(),
912 &[hyper],
913 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
914 )?;
915 Ok(gradient[rho.len()])
916 }
917
918 pub(crate) fn fd_directional_tau_cost_gradient(
919 y: &Array1<f64>,
920 w: &Array1<f64>,
921 x: &Array2<f64>,
922 s0: &Array2<f64>,
923 cfg: &RemlConfig,
924 rho: &Array1<f64>,
925 x_tau: &Array2<f64>,
926 s_tau: &Array2<f64>,
927 ) -> f64 {
928 let h = 2e-5;
929 let x_plus = x + &x_tau.mapv(|v| h * v);
930 let x_minus = x - &x_tau.mapv(|v| h * v);
931 let s_plus = s0 + &s_tau.mapv(|v| h * v);
932 let s_minus = s0 - &s_tau.mapv(|v| h * v);
933 let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
934 let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
935 let v_plus = state_plus.compute_cost(rho).expect("cost+");
936 let v_minus = state_minus.compute_cost(rho).expect("cost-");
937 (v_plus - v_minus) / (2.0 * h)
938 }
939
    /// Finite-difference reference for the tau-tau Hessian block.
    ///
    /// Column `j` is obtained by perturbing the design and penalty along
    /// direction `j` (`x_tau_mats[j]`, `s_tau_mats[j]`) and central-
    /// differencing the analytic gradient of every direction `i`. The result
    /// is symmetrized before it is returned.
    ///
    /// Panics if the three slices differ in length, or if any perturbed
    /// state cannot be built or evaluated.
    pub(crate) fn directional_tau_hessian_fd_reference(
        y: &Array1<f64>,
        w: &Array1<f64>,
        x: &Array2<f64>,
        s0: &Array2<f64>,
        cfg: &RemlConfig,
        rho: &Array1<f64>,
        hyper_dirs: &[DirectionalHyperParam],
        x_tau_mats: &[Array2<f64>],
        s_tau_mats: &[Array2<f64>],
    ) -> Array2<f64> {
        assert_eq!(hyper_dirs.len(), x_tau_mats.len());
        assert_eq!(hyper_dirs.len(), s_tau_mats.len());

        // Desired size of the largest entry of the applied perturbation.
        const TARGET_PHYSICAL_STEP: f64 = 1e-5;

        let n_dirs = hyper_dirs.len();
        let mut h_ttfd = Array2::<f64>::zeros((n_dirs, n_dirs));
        for j in 0..n_dirs {
            // Largest absolute entry across both direction matrices.
            let direction_scale = x_tau_mats[j]
                .iter()
                .chain(s_tau_mats[j].iter())
                .fold(0.0_f64, |acc, value| acc.max(value.abs()));
            // Scale the step so `h * direction` peaks at the target; an
            // all-zero direction falls back to the raw target step.
            let h = if direction_scale > 0.0 {
                TARGET_PHYSICAL_STEP / direction_scale
            } else {
                TARGET_PHYSICAL_STEP
            };

            let x_plus = x + &x_tau_mats[j].mapv(|v| h * v);
            let x_minus = x - &x_tau_mats[j].mapv(|v| h * v);
            let s_plus = s0 + &s_tau_mats[j].mapv(|v| h * v);
            let s_minus = s0 - &s_tau_mats[j].mapv(|v| h * v);

            let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
            let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
            for i in 0..n_dirs {
                let g_plus =
                    single_directional_tau_gradient(&state_plus, rho, hyper_dirs[i].clone())
                        .expect("g+ for FD");
                let g_minus =
                    single_directional_tau_gradient(&state_minus, rho, hyper_dirs[i].clone())
                        .expect("g- for FD");
                h_ttfd[[i, j]] = (g_plus - g_minus) / (2.0 * h);
            }
        }
        symmetrize_in_place(&mut h_ttfd);
        h_ttfd
    }
989
    /// An `OuterEval` whose Hessian is `Unavailable` must round-trip through
    /// `store_outer_eval` / `cached_outer_eval`, and must be gone after
    /// `invalidate_eval_bundle`.
    #[test]
    pub(crate) fn eval_cache_manager_stores_first_order_outer_eval() {
        let cache = EvalCacheManager::new();
        let rho = array![0.25, -0.0];
        let rho_key = EvalCacheManager::sanitized_rhokey(&rho);
        let eval = OuterEval {
            cost: 3.5,
            gradient: array![1.0, -2.0],
            hessian: HessianResult::Unavailable,
            inner_beta_hint: None,
        };

        cache.store_outer_eval(&rho_key, &eval);

        let cached = cache
            .cached_outer_eval(&rho_key)
            .expect("first-order outer eval should be cached");
        assert_eq!(cached.cost, eval.cost);
        assert_eq!(cached.gradient, eval.gradient);
        assert!(matches!(cached.hessian, HessianResult::Unavailable));

        cache.invalidate_eval_bundle();
        assert!(
            cache.cached_outer_eval(&rho_key).is_none(),
            "invalidating the bundle should clear the outer-eval cache too"
        );
    }
1017
    /// Exercises the outer-eval LRU in three phases:
    ///
    /// 1. a stored evaluation is returned bit-for-bit (cost, gradient) and
    ///    with its `inner_beta_hint` unchanged;
    /// 2. two `rho` vectors differing in the last bits get distinct keys and
    ///    do not overwrite each other;
    /// 3. inserting one entry beyond `OUTER_EVAL_LRU_CAPACITY` keeps the
    ///    entry count at capacity and evicts the first-inserted key.
    #[test]
    pub(crate) fn outer_eval_lru_hit_is_bit_identical_and_evicts_honestly_1575() {
        use super::OUTER_EVAL_LRU_CAPACITY;

        // Builds an evaluation whose numbers depend on `seed`.
        let make_eval = |seed: f64| OuterEval {
            cost: (seed * std::f64::consts::PI).sin() / 3.0 - seed,
            gradient: array![seed, -seed * 2.0, seed.recip()],
            hessian: HessianResult::Unavailable,
            inner_beta_hint: Some(array![seed + 0.5, seed - 0.5]),
        };
        // Bitwise equality of cost and gradient; the Hessian and the beta
        // hint are not compared here.
        let bits_eq = |a: &OuterEval, b: &OuterEval| -> bool {
            a.cost.to_bits() == b.cost.to_bits()
                && a.gradient.len() == b.gradient.len()
                && a.gradient
                    .iter()
                    .zip(b.gradient.iter())
                    .all(|(x, y)| x.to_bits() == y.to_bits())
        };

        let cache = EvalCacheManager::new();

        // Phase 1: a hit returns exactly what was stored.
        let rho_a = array![0.25, -1.5];
        let key_a = EvalCacheManager::sanitized_rhokey(&rho_a);
        let eval_a = make_eval(0.25);
        cache.store_outer_eval(&key_a, &eval_a);
        let hit_a = cache
            .cached_outer_eval(&key_a)
            .expect("stored rho_a must hit");
        assert!(
            bits_eq(&hit_a, &eval_a),
            "cache hit must be bit-identical (cost+gradient) to the stored miss-path eval"
        );
        assert_eq!(
            hit_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
            eval_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
            "inner_beta_hint must round-trip unchanged"
        );

        // Phase 2: a nearly identical rho must map to a different key.
        let rho_b = array![0.25, -1.4999999999999998];
        let key_b = EvalCacheManager::sanitized_rhokey(&rho_b);
        assert_ne!(key_a, key_b, "the two rho-keys must differ");
        let eval_b = make_eval(7.0);
        cache.store_outer_eval(&key_b, &eval_b);
        assert!(
            bits_eq(
                &cache.cached_outer_eval(&key_b).expect("rho_b must hit"),
                &eval_b
            ),
            "rho_b must return its own eval, not rho_a's"
        );
        assert!(
            bits_eq(
                &cache.cached_outer_eval(&key_a).expect("rho_a must still hit"),
                &eval_a
            ),
            "rho_a must be unaffected by the rho_b insert"
        );

        // Phase 3: fill a fresh cache to capacity, then overflow by one.
        let cache = EvalCacheManager::new();
        let mut keys = Vec::new();
        let mut evals = Vec::new();
        for i in 0..OUTER_EVAL_LRU_CAPACITY {
            let rho = array![i as f64, -(i as f64)];
            let key = EvalCacheManager::sanitized_rhokey(&rho);
            let eval = make_eval(i as f64 + 0.123);
            cache.store_outer_eval(&key, &eval);
            keys.push(key);
            evals.push(eval);
        }
        assert_eq!(
            cache.outer_eval_lru.read().unwrap().entries.len(),
            OUTER_EVAL_LRU_CAPACITY
        );
        let rho_overflow = array![999.0, -999.0];
        let key_overflow = EvalCacheManager::sanitized_rhokey(&rho_overflow);
        let eval_overflow = make_eval(42.0);
        cache.store_outer_eval(&key_overflow, &eval_overflow);
        assert_eq!(
            cache.outer_eval_lru.read().unwrap().entries.len(),
            OUTER_EVAL_LRU_CAPACITY,
            "capacity must stay bounded"
        );
        // `keys[0]` was inserted first and never read back.
        assert!(
            cache.cached_outer_eval(&keys[0]).is_none(),
            "the least-recently-used key must be evicted and now MISS (recompute), not return stale"
        );
        assert!(
            bits_eq(
                &cache
                    .cached_outer_eval(&keys[1])
                    .expect("a still-resident key must hit"),
                &evals[1]
            ),
            "a still-resident key must return its exact stored bits"
        );
        assert!(
            bits_eq(
                &cache
                    .cached_outer_eval(&key_overflow)
                    .expect("the freshest key must hit"),
                &eval_overflow
            ),
            "the freshest key must hit with its own eval"
        );
    }
1143
1144 #[test]
1145 pub(crate) fn reset_outer_seed_state_clears_pirls_cache() {
1146 let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1152 let w = Array1::<f64>::ones(y.len());
1153 let x = array![
1154 [1.0, -1.0, 0.2],
1155 [1.0, -0.5, -0.4],
1156 [1.0, 0.0, 0.7],
1157 [1.0, 0.4, -0.3],
1158 [1.0, 0.9, 0.1],
1159 [1.0, 1.3, -0.6],
1160 ];
1161 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1162 let rho = array![0.0];
1163 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1164 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1165
1166 state
1169 .compute_outer_eval_with_order(
1170 &rho,
1171 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1172 )
1173 .expect("outer eval should succeed");
1174
1175 let populated_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1176 assert!(
1177 populated_len > 0,
1178 "evaluating the outer objective should populate the PIRLS LRU, got {populated_len}"
1179 );
1180
1181 state.reset_outer_seed_state();
1182
1183 let cleared_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1184 assert_eq!(
1185 cleared_len, 0,
1186 "reset_outer_seed_state must clear the cross-call PIRLS LRU; got {cleared_len} entries"
1187 );
1188 }
1189
1190 #[test]
1191 pub(crate) fn reset_outer_seed_state_preserves_frozen_negbin_theta_1448() {
1192 use std::sync::atomic::Ordering;
1209
1210 let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1211 let w = Array1::<f64>::ones(y.len());
1212 let x = array![
1213 [1.0, -1.0, 0.2],
1214 [1.0, -0.5, -0.4],
1215 [1.0, 0.0, 0.7],
1216 [1.0, 0.4, -0.3],
1217 [1.0, 0.9, 0.1],
1218 [1.0, 1.3, -0.6],
1219 ];
1220 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1221 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1222 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1223
1224 let theta_final_bits = 2.5_f64.to_bits();
1226 state
1227 .frozen_negbin_theta
1228 .store(theta_final_bits, Ordering::Relaxed);
1229 assert_eq!(
1230 state.frozen_negbin_theta.load(Ordering::Relaxed),
1231 theta_final_bits,
1232 "precondition: the re-freeze stores θ_final into the frozen slot"
1233 );
1234
1235 state.reset_outer_seed_state();
1237
1238 assert_eq!(
1239 state.frozen_negbin_theta.load(Ordering::Relaxed),
1240 theta_final_bits,
1241 "reset_outer_seed_state (alternation-round reset) must PRESERVE the \
1242 re-frozen NB θ; clearing it would defeat the #1448 θ↔λ alternation \
1243 (the next ρ search would re-derive θ from the seed and never reach \
1244 the joint fixed point)"
1245 );
1246 }
1247
1248 #[test]
1249 pub(crate) fn implicit_hyper_design_derivative_respects_full_model_embedding() {
1250 let operator = ImplicitDesignPsiDerivative::new(
1251 array![1.0, 2.0, 3.0, 4.0],
1252 array![0.5, -1.0, 1.5, 2.0],
1253 array![0.1, 0.2, 0.3, 0.4],
1254 array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
1255 None,
1256 None,
1257 2,
1258 2,
1259 1,
1260 2,
1261 );
1262 let local = operator
1263 .materialize_first(0)
1264 .expect("materialized first derivative");
1265 assert_eq!(
1266 local.ncols(),
1267 3,
1268 "operator-local derivative should stay smooth-local"
1269 );
1270
1271 let implicit = HyperDesignDerivative::from_implicit(
1272 Arc::new(operator),
1273 ImplicitDerivLevel::First(0),
1274 1..4,
1275 5,
1276 );
1277 let embedded = HyperDesignDerivative::from_embedded(local.clone(), 1..4, 5);
1278
1279 assert_eq!(implicit.nrows(), embedded.nrows());
1280 assert_eq!(implicit.ncols(), 5);
1281 assert_eq!(implicit.materialize(), embedded.materialize());
1282
1283 let u = array![7.0, 1.5, -2.0, 0.25, -3.0];
1284 let v = array![0.75, -1.25];
1285 assert_eq!(
1286 implicit.forward_mul_original(&u).expect("implicit forward"),
1287 embedded.forward_mul_original(&u).expect("embedded forward")
1288 );
1289 assert_eq!(
1290 implicit
1291 .transpose_mul_original(&v)
1292 .expect("implicit transpose"),
1293 embedded
1294 .transpose_mul_original(&v)
1295 .expect("embedded transpose")
1296 );
1297
1298 let qs = array![
1299 [1.0, 0.0, 0.0],
1300 [0.0, 1.0, 0.0],
1301 [0.0, 0.5, 0.5],
1302 [0.0, 0.0, 1.0],
1303 [0.0, 0.0, 0.0],
1304 ];
1305 assert_eq!(
1306 implicit
1307 .transformed(&qs, None)
1308 .expect("implicit transformed"),
1309 embedded
1310 .transformed(&qs, None)
1311 .expect("embedded transformed")
1312 );
1313 let u_transformed = array![1.0, -0.5, 2.0];
1314 assert_eq!(
1315 implicit
1316 .transformed_forward_mul(&qs, None, &u_transformed)
1317 .expect("implicit transformed forward"),
1318 embedded
1319 .transformed_forward_mul(&qs, None, &u_transformed)
1320 .expect("embedded transformed forward")
1321 );
1322 assert_eq!(
1323 implicit
1324 .transformed_transpose_mul(&qs, None, &v)
1325 .expect("implicit transformed transpose"),
1326 embedded
1327 .transformed_transpose_mul(&qs, None, &v)
1328 .expect("embedded transformed transpose")
1329 );
1330 }
1331
1332 #[test]
1333 pub(crate) fn directional_hyper_identities_match_finite_differences_logit() {
1334 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1335 let w = Array1::<f64>::ones(y.len());
1336 let x = array![
1337 [1.0, -1.2, 0.3],
1338 [1.0, -0.8, -0.4],
1339 [1.0, -0.3, 0.7],
1340 [1.0, 0.1, -0.9],
1341 [1.0, 0.5, 0.2],
1342 [1.0, 0.9, -0.1],
1343 [1.0, 1.3, 0.8],
1344 [1.0, 1.7, -0.6],
1345 ];
1346 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1347
1348 let x_tau = Array2::<f64>::zeros(x.raw_dim());
1353 let s_tau = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15],];
1354 let hyper =
1355 DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
1356 .expect("single-penalty hyper direction");
1357 let rho = array![0.0];
1358
1359 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
1363 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1364 let bundle = state.obtain_eval_bundle(&rho).expect("bundle");
1365 let pr = bundle.pirls_result.as_ref();
1366
1367 let beta = beta_original_from_bundle(&bundle);
1368 let h_orig = h_original_from_bundle(&bundle);
1369 let u = &pr.solveweights * &(&pr.solveworking_response - &pr.final_eta);
1370
1371 let x_tau_beta = gam_linalg::faer_ndarray::fast_av(&x_tau, &beta);
1374 let weighted_x_tau_beta = &pr.finalweights * &x_tau_beta;
1375 let rhs = gam_linalg::faer_ndarray::fast_atv(&x_tau, &u)
1376 - gam_linalg::faer_ndarray::fast_atv(&x, &weighted_x_tau_beta)
1377 - s_tau.dot(&beta);
1378 let chol = h_orig.cholesky(Side::Lower).expect("chol(H)");
1379 let b_analytic = chol.solvevec(&rhs);
1380
1381 let eta_dot = &x_tau_beta + &gam_linalg::faer_ndarray::fast_av(&x, &b_analytic);
1385 let w_direction = crate::pirls::directionalworking_curvature_from_c_array(
1386 &pr.solve_c_array,
1387 &pr.finalweights,
1388 &eta_dot,
1389 );
1390 let wx = RemlState::row_scale(&x, &pr.finalweights);
1391 let wx_tau = RemlState::row_scale(&x_tau, &pr.finalweights);
1392 let mut xwtau_x = x.clone();
1393 match w_direction {
1394 crate::pirls::DirectionalWorkingCurvature::Diagonal(diag) => {
1395 xwtau_x = RemlState::row_scale(&xwtau_x, &diag);
1396 }
1397 }
1398 let mut h_tau_analytic = gam_linalg::faer_ndarray::fast_atb(&x_tau, &wx);
1399 h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &wx_tau);
1400 h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &xwtau_x);
1401 h_tau_analytic += &s_tau;
1402
1403 let ell_beta = gam_linalg::faer_ndarray::fast_atv(&x, &u);
1408 let s_eff = &h_orig - &gam_linalg::faer_ndarray::fast_atb(&x, &wx);
1409 let cancellation = -ell_beta.dot(&b_analytic) + beta.dot(&s_eff.dot(&b_analytic));
1410
1411 let h = 2e-5;
1413 let x_plus = &x + &(x_tau.mapv(|v| h * v));
1414 let x_minus = &x - &(x_tau.mapv(|v| h * v));
1415 let s_plus = &s0 + &(s_tau.mapv(|v| h * v));
1416 let s_minus = &s0 - &(s_tau.mapv(|v| h * v));
1417
1418 let state_plus = build_logit_state(&y, &w, &x_plus, &s_plus, &cfg);
1419 let state_minus = build_logit_state(&y, &w, &x_minus, &s_minus, &cfg);
1420 let bundle_plus = state_plus.obtain_eval_bundle(&rho).expect("bundle+");
1421 let bundle_minus = state_minus.obtain_eval_bundle(&rho).expect("bundle-");
1422 let beta_plus = beta_original_from_bundle(&bundle_plus);
1423 let beta_minus = beta_original_from_bundle(&bundle_minus);
1424 let bfd = (&beta_plus - &beta_minus).mapv(|v| v / (2.0 * h));
1425
1426 let h_plus = h_original_from_bundle(&bundle_plus);
1427 let h_minus = h_original_from_bundle(&bundle_minus);
1428 let h_taufd = (&h_plus - &h_minus).mapv(|v| v / (2.0 * h));
1429
1430 let v_plus = state_plus.compute_cost(&rho).expect("cost+");
1431 let v_minus = state_minus.compute_cost(&rho).expect("cost-");
1432 let v_taufd = (v_plus - v_minus) / (2.0 * h);
1433
1434 let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper.clone())
1435 .expect("analytic directional gradient");
1436
1437 let b_num = (&b_analytic - &bfd).mapv(|v| v * v).sum().sqrt();
1438 let b_den = bfd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1439 let b_rel = b_num / b_den;
1440 for i in 0..b_analytic.len() {
1441 assert_eq!(
1442 b_analytic[i].signum(),
1443 bfd[i].signum(),
1444 "B sign mismatch at i={i}: analytic={} fd={}",
1445 b_analytic[i],
1446 bfd[i]
1447 );
1448 }
1449 assert!(
1450 b_rel < 2e-2,
1451 "B implicit solve mismatch vs FD: rel={b_rel:.3e}, num={b_num:.3e}, den={b_den:.3e}"
1452 );
1453
1454 let dh_num = (&h_tau_analytic - &h_taufd).mapv(|v| v * v).sum().sqrt();
1455 let dh_den = h_taufd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1456 let dh_rel = dh_num / dh_den;
1457 for i in 0..h_tau_analytic.nrows() {
1458 for j in 0..h_tau_analytic.ncols() {
1459 assert_eq!(
1460 h_tau_analytic[[i, j]].signum(),
1461 h_taufd[[i, j]].signum(),
1462 "H_tau sign mismatch at ({i},{j}): analytic={} fd={}",
1463 h_tau_analytic[[i, j]],
1464 h_taufd[[i, j]]
1465 );
1466 }
1467 }
1468 assert!(
1469 dh_rel < 3e-2,
1470 "H_tau mismatch vs FD: rel={dh_rel:.3e}, num={dh_num:.3e}, den={dh_den:.3e}"
1471 );
1472
1473 let v_abs = (v_tau_analytic - v_taufd).abs();
1474 let v_rel = v_abs / v_taufd.abs().max(1e-10);
1475 assert_eq!(
1476 v_tau_analytic.signum(),
1477 v_taufd.signum(),
1478 "V_tau sign mismatch: analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1479 );
1480 assert!(
1481 v_rel < 2e-2,
1482 "V_tau mismatch vs FD: rel={v_rel:.3e}, abs={v_abs:.3e}, analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1483 );
1484
1485 assert!(
1486 cancellation.abs() < 1e-10,
1487 "stationarity cancellation failed: | -ell_beta^T B + beta^T S B | = {:.3e}",
1488 cancellation.abs()
1489 );
1490 }
1491
1492 #[test]
1493 pub(crate) fn firth_exacthessian_includes_analytic_tk_second_derivatives() {
1494 let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
1496 let w = Array1::<f64>::ones(y.len());
1497 let x = array![
1498 [1.0, -1.2, 0.4, -2.4],
1499 [1.0, -0.9, -0.1, -1.8],
1500 [1.0, -0.6, 0.3, -1.2],
1501 [1.0, -0.2, -0.4, -0.4],
1502 [1.0, 0.1, 0.5, 0.2],
1503 [1.0, 0.4, -0.6, 0.8],
1504 [1.0, 0.8, 0.2, 1.6],
1505 [1.0, 1.1, -0.3, 2.2],
1506 [1.0, 1.4, 0.7, 2.8],
1507 [1.0, 1.7, -0.2, 3.4],
1508 ];
1509 let s0 = array![
1510 [0.0, 0.0, 0.0, 0.0],
1511 [0.0, 1.5, 0.2, 0.0],
1512 [0.0, 0.2, 1.0, 0.0],
1513 [0.0, 0.0, 0.0, 0.5],
1514 ];
1515 let s1 = array![
1516 [0.0, 0.0, 0.0, 0.0],
1517 [0.0, 0.8, -0.1, 0.0],
1518 [0.0, -0.1, 0.6, 0.0],
1519 [0.0, 0.0, 0.0, 0.3],
1520 ];
1521 let offset = Array1::<f64>::zeros(y.len());
1522 let cfg =
1525 RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1526 let p = x.ncols();
1527 use crate::estimate::PenaltySpec;
1528 let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1529 let canonical = gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p, "test")
1530 .map(|(canonical, _)| canonical)
1531 .expect("canonicalize");
1532 let state = RemlState::newwith_offset(
1533 y.view(),
1534 x.clone(),
1535 w.view(),
1536 offset.view(),
1537 canonical,
1538 p,
1539 &cfg,
1540 Some(vec![1, 1]),
1541 None,
1542 None,
1543 )
1544 .expect("state");
1545 let rho = array![0.1, -0.2];
1546 assert!(
1547 state.analytic_outer_hessian_enabled(),
1548 "Firth logit should no longer disable analytic outer Hessian planning"
1549 );
1550 let outer = state
1551 .compute_outer_eval_with_order(
1552 &rho,
1553 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1554 )
1555 .expect("outer Hessian eval should succeed");
1556 assert!(
1557 outer.hessian.is_analytic(),
1558 "outer planner should request and return an analytic Hessian"
1559 );
1560 let bundle = state.obtain_eval_bundle(&rho).expect("exact firth bundle");
1561 let h_dense = state
1562 .compute_lamlhessian_exact_from_bundle(&rho, &bundle)
1563 .expect("Firth exact Hessian should include analytic TK second derivatives");
1564 assert_eq!(h_dense.raw_dim(), ndarray::Ix2(2, 2));
1565 assert!(
1566 h_dense.iter().all(|value| value.is_finite()),
1567 "Hessian should be finite: {h_dense:?}"
1568 );
1569 }
1570
1571 #[test]
1572 pub(crate) fn firth_outer_hessian_matches_gradient_finite_difference_with_tk_terms() {
1573 let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
1574 let w = Array1::<f64>::ones(y.len());
1575 let x = array![
1576 [1.0, -1.0, 0.3],
1577 [1.0, -0.7, -0.2],
1578 [1.0, -0.3, 0.4],
1579 [1.0, 0.0, -0.5],
1580 [1.0, 0.2, 0.6],
1581 [1.0, 0.6, -0.4],
1582 [1.0, 0.9, 0.2],
1583 [1.0, 1.3, -0.1],
1584 ];
1585 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.1], [0.0, 0.1, 0.7],];
1586 let s1 = array![[0.0, 0.0, 0.0], [0.0, 0.4, -0.05], [0.0, -0.05, 0.9],];
1587 let cfg =
1588 RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1589 let p_dim = x.ncols();
1590 use crate::estimate::PenaltySpec;
1591 let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1592 let canonical =
1593 gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p_dim, "test")
1594 .map(|(canonical, _)| canonical)
1595 .expect("canonicalize");
1596 let offset = Array1::<f64>::zeros(y.len());
1597 let state = RemlState::newwith_offset(
1598 y.view(),
1599 x.clone(),
1600 w.view(),
1601 offset.view(),
1602 canonical,
1603 p_dim,
1604 &cfg,
1605 Some(vec![1, 1]),
1606 None,
1607 None,
1608 )
1609 .expect("state");
1610 let rho = array![0.15, -0.25];
1611 let eval = state
1612 .compute_outer_eval_with_order(
1613 &rho,
1614 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1615 )
1616 .expect("analytic Hessian eval");
1617 let h = match eval.hessian {
1618 HessianResult::Analytic(hessian) => hessian,
1619 HessianResult::Operator(_) | HessianResult::Unavailable => {
1620 panic!("expected dense analytic Hessian")
1621 }
1622 };
1623 let delta = 2.0e-5;
1624 for col in 0..rho.len() {
1625 let mut rp = rho.clone();
1626 let mut rm = rho.clone();
1627 rp[col] += delta;
1628 rm[col] -= delta;
1629 let gp = state
1630 .compute_outer_eval_with_order(
1631 &rp,
1632 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1633 )
1634 .expect("plus grad")
1635 .gradient;
1636 let gm = state
1637 .compute_outer_eval_with_order(
1638 &rm,
1639 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1640 )
1641 .expect("minus grad")
1642 .gradient;
1643 for row in 0..rho.len() {
1644 let fd = (gp[row] - gm[row]) / (2.0 * delta);
1645 let an = h[[row, col]];
1646 let rel = (fd - an).abs() / fd.abs().max(an.abs()).max(1e-6);
1647 assert!(
1648 rel < 2.0e-3,
1649 "Hessian mismatch ({row},{col}): analytic={an:.9e}, fd={fd:.9e}, rel={rel:.3e}"
1650 );
1651 }
1652 }
1653 }
1654
1655 #[test]
1656 pub(crate) fn firthgradient_lives_in_design_column_space_under_rank_deficiency() {
1657 let x = array![
1659 [1.0, -1.2, 0.4, -2.4],
1660 [1.0, -0.9, -0.1, -1.8],
1661 [1.0, -0.6, 0.3, -1.2],
1662 [1.0, -0.2, -0.4, -0.4],
1663 [1.0, 0.1, 0.5, 0.2],
1664 [1.0, 0.4, -0.6, 0.8],
1665 [1.0, 0.8, 0.2, 1.6],
1666 [1.0, 1.1, -0.3, 2.2],
1667 ];
1668 let beta = array![0.1, -0.2, 0.3, 0.05];
1669 let eta = x.dot(&beta);
1670 let op = super::RemlState::build_firth_dense_operator_for_link(
1671 &gam_problem::InverseLink::Standard(gam_problem::StandardLink::Logit),
1672 &x,
1673 &eta,
1674 ndarray::Array1::ones(x.nrows()).view(),
1675 )
1676 .expect("firth operator");
1677
1678 let gradphi = 0.5 * x.t().dot(&(&op.w1 * &op.h_diag));
1681
1682 let q = &op.q_basis;
1684 let proj = q.dot(&q.t().dot(&gradphi));
1685 let resid = &gradphi - &proj;
1686 let rel =
1687 resid.mapv(|v| v * v).sum().sqrt() / gradphi.mapv(|v| v * v).sum().sqrt().max(1e-12);
1688 assert!(
1689 rel < 1e-10,
1690 "Firth gradient should lie in Col(Xᵀ): rel residual={rel:.3e}"
1691 );
1692 }
1693
1694 #[test]
1695 pub(crate) fn firth_logit_directional_hypergradient_accepts_penalty_only_with_full_tk_gradient()
1696 {
1697 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1698 let w = Array1::<f64>::ones(y.len());
1699 let x = array![
1700 [1.0, -1.1, 0.2],
1701 [1.0, -0.6, -0.3],
1702 [1.0, -0.1, 0.5],
1703 [1.0, 0.3, -0.7],
1704 [1.0, 0.8, 0.1],
1705 [1.0, 1.2, -0.4],
1706 ];
1707 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1708 let hyper = DirectionalHyperParam::single_penalty(
1709 0,
1710 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1711 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1712 None,
1713 None,
1714 )
1715 .expect("single-penalty hyper direction");
1716 let rho = array![0.0];
1717 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1718 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1719 let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1720 .expect("Firth penalty-only directional gradient should use analytic TK propagation");
1721 assert!(gradient.is_finite(), "gradient={gradient}");
1722 let fd = fd_directional_tau_cost_gradient(
1723 &y,
1724 &w,
1725 &x,
1726 &s0,
1727 &cfg,
1728 &rho,
1729 &Array2::<f64>::zeros((x.nrows(), x.ncols())),
1730 &array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1731 );
1732 let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1733 assert!(
1734 rel < 1.0e-3,
1735 "Firth penalty-only directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1736 );
1737
1738 let efs_hyper = DirectionalHyperParam::single_penalty(
1739 0,
1740 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1741 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1742 None,
1743 None,
1744 )
1745 .expect("single-penalty EFS hyper direction");
1746 let efs = state
1747 .compute_efs_steps_with_psi_ext(&rho, &[efs_hyper])
1748 .expect("Firth penalty-only EFS should use analytic TK propagation");
1749 assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1750 }
1751
1752 #[test]
1772 pub(crate) fn firth_logit_rho_gradient_matches_finite_difference_through_inner_solve() {
1773 let x = array![[1.0, -6.0], [1.0, 0.2], [1.0, 5.8]];
1777 let y = array![0.0, 0.0, 1.0];
1778 let w = Array1::<f64>::ones(y.len());
1779 let s0 = array![[1.0, 0.0], [0.0, 1.0]];
1781 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-12, true);
1784 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1785 let delta = 1e-4_f64;
1786 for &rho in &[-0.6_f64, -0.3, 0.0, 0.3, 0.6] {
1787 let r = array![rho];
1788 let analytic = state
1789 .compute_gradient(&r)
1790 .expect("Firth LAML ρ-gradient should evaluate")[0];
1791 let cost_plus = state
1792 .compute_cost(&array![rho + delta])
1793 .expect("Firth LAML cost(ρ+δ) should evaluate");
1794 let cost_minus = state
1795 .compute_cost(&array![rho - delta])
1796 .expect("Firth LAML cost(ρ−δ) should evaluate");
1797 let fd = (cost_plus - cost_minus) / (2.0 * delta);
1798 let rel = (fd - analytic).abs() / fd.abs().max(1e-3);
1799 assert!(
1800 analytic.is_finite() && fd.is_finite(),
1801 "non-finite Firth ρ-gradient at rho={rho:+.3}: fd={fd:+.6e}, analytic={analytic:+.6e}"
1802 );
1803 assert!(
1804 rel < 1e-4,
1805 "Firth ρ-gradient FD desync at rho={rho:+.3}: fd={fd:+.6e}, analytic={analytic:+.6e}, rel={rel:.3e} (>= 1e-4). \
1806 The inner P-IRLS likely converged off the Firth-KKT mode (gam#1821)."
1807 );
1808 }
1809 }
1810
1811 #[test]
1812 pub(crate) fn firth_logit_directional_hypergradient_accepts_design_moving_with_full_tk_gradient()
1813 {
1814 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1815 let w = Array1::<f64>::ones(y.len());
1816 let x = array![
1817 [1.0, -1.1, 0.2],
1818 [1.0, -0.6, -0.3],
1819 [1.0, -0.1, 0.5],
1820 [1.0, 0.3, -0.7],
1821 [1.0, 0.8, 0.1],
1822 [1.0, 1.2, -0.4],
1823 ];
1824 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1825 let hyper = DirectionalHyperParam::single_penalty(
1826 0,
1827 Array2::from_elem((x.nrows(), x.ncols()), 1e-3),
1828 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1829 None,
1830 None,
1831 )
1832 .expect("single-penalty hyper direction");
1833 let rho = array![0.0];
1834 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1835 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1836 let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1837 .expect("Firth design-moving directional gradient should use analytic TK propagation");
1838 assert!(gradient.is_finite(), "gradient={gradient}");
1839 let x_tau = Array2::from_elem((x.nrows(), x.ncols()), 1e-3);
1840 let s_tau = Array2::<f64>::zeros((x.ncols(), x.ncols()));
1841 let fd = fd_directional_tau_cost_gradient(&y, &w, &x, &s0, &cfg, &rho, &x_tau, &s_tau);
1842 let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1843 assert!(
1844 rel < 2.0e-2,
1845 "Firth design-moving directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1846 );
1847 }
1848
1849 #[test]
1850 pub(crate) fn firth_logit_hybrid_efs_accepts_full_tk_psi_gradient() {
1851 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1852 let w = Array1::<f64>::ones(y.len());
1853 let x = array![
1854 [1.0, -1.1, 0.2],
1855 [1.0, -0.6, -0.3],
1856 [1.0, -0.1, 0.5],
1857 [1.0, 0.3, -0.7],
1858 [1.0, 0.8, 0.1],
1859 [1.0, 1.2, -0.4],
1860 ];
1861 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1862 let hyper_dirs = vec![
1863 DirectionalHyperParam::single_penalty(
1864 0,
1865 Array2::from_shape_fn((x.nrows(), x.ncols()), |(i, j)| {
1866 1e-3 * ((i + 1) as f64) * ((j + 2) as f64)
1867 }),
1868 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1869 None,
1870 None,
1871 )
1872 .expect("design-moving hyper direction"),
1873 ];
1874 let rho = array![0.0];
1875 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1876 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1877
1878 let full = state
1879 .evaluate_unified_with_psi_ext(
1880 &rho,
1881 None,
1882 crate::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
1883 &hyper_dirs,
1884 )
1885 .expect("full Firth psi gradient should use analytic TK propagation");
1886 assert!(full.cost.is_finite(), "full cost={}", full.cost);
1887 let full_grad = full.gradient.expect("gradient should be present");
1888 assert!(
1889 full_grad.iter().all(|value| value.is_finite()),
1890 "full gradient={full_grad:?}"
1891 );
1892
1893 let efs = state
1894 .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
1895 .expect("hybrid EFS should use analytic TK propagation");
1896 assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1897 }
1898
1899 #[test]
1900 pub(crate) fn joint_hyperhessianwires_mixed_blocks() {
1901 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1902 let w = Array1::<f64>::ones(y.len());
1903 let x = array![
1904 [1.0, -1.2, 0.3],
1905 [1.0, -0.8, -0.4],
1906 [1.0, -0.3, 0.7],
1907 [1.0, 0.1, -0.9],
1908 [1.0, 0.5, 0.2],
1909 [1.0, 0.9, -0.1],
1910 [1.0, 1.3, 0.8],
1911 [1.0, 1.7, -0.6],
1912 ];
1913 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1914 let cfg =
1915 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1916 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1917 let rho = array![0.0];
1918 let theta = array![0.0, 0.0, 0.0];
1919 let hyper_dirs = vec![
1920 DirectionalHyperParam::single_penalty(
1921 0,
1922 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1923 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1924 None,
1925 None,
1926 )
1927 .expect("single-penalty hyper direction"),
1928 DirectionalHyperParam::single_penalty(
1929 0,
1930 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1931 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1932 None,
1933 None,
1934 )
1935 .expect("single-penalty hyper direction"),
1936 ];
1937
1938 let (_, _, h) =
1939 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1940 .expect("joint hyper cost+gradient+hessian");
1941 assert_eq!(h.nrows(), theta.len());
1942 assert_eq!(h.ncols(), theta.len());
1943 assert!(h.iter().all(|v| v.is_finite()));
1944 for i in 0..h.nrows() {
1945 for j in 0..i {
1946 let diff = (h[[i, j]] - h[[j, i]]).abs();
1947 assert!(
1948 diff < 1e-6,
1949 "joint hessian asymmetry at ({i},{j}): {diff:.3e}"
1950 );
1951 }
1952 }
1953 let mixed_0 = h[[0, 1]];
1955 let mixed_1 = h[[0, 2]];
1956 assert!(
1957 mixed_0.is_finite() && mixed_1.is_finite(),
1958 "mixed blocks must be finite"
1959 );
1960 }
1961
1962 #[test]
1963 pub(crate) fn joint_tau_tau_linear_dirs_matchfd_reference_away_fromzero_psi() {
1964 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1965 let w = Array1::<f64>::ones(y.len());
1966 let x = array![
1967 [1.0, -1.2, 0.3],
1968 [1.0, -0.8, -0.4],
1969 [1.0, -0.3, 0.7],
1970 [1.0, 0.1, -0.9],
1971 [1.0, 0.5, 0.2],
1972 [1.0, 0.9, -0.1],
1973 [1.0, 1.3, 0.8],
1974 [1.0, 1.7, -0.6],
1975 ];
1976 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1977 let cfg =
1978 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1979 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1980 let rho = array![0.0];
1981 let psi = array![0.7, -0.4];
1982 let theta = array![rho[0], psi[0], psi[1]];
1983 let hyper_dirs = vec![
1984 DirectionalHyperParam::single_penalty(
1985 0,
1986 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1987 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1988 None,
1989 None,
1990 )
1991 .expect("linear tau direction"),
1992 DirectionalHyperParam::single_penalty(
1993 0,
1994 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1995 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1996 None,
1997 None,
1998 )
1999 .expect("linear tau direction"),
2000 ];
2001
2002 let (_, _, h_full) =
2003 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2004 .expect("joint hyper cost+gradient+hessian");
2005 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2006
2007 let x_tau_mats: Vec<Array2<f64>> = vec![
2012 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2013 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2014 ];
2015 let s_tau_mats: Vec<Array2<f64>> = vec![
2016 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
2017 Array2::<f64>::zeros((x.ncols(), x.ncols())),
2018 ];
2019
2020 let h_ttfd = directional_tau_hessian_fd_reference(
2021 &y,
2022 &w,
2023 &x,
2024 &s0,
2025 &cfg,
2026 &rho,
2027 &hyper_dirs,
2028 &x_tau_mats,
2029 &s_tau_mats,
2030 );
2031
2032 let num = (&h_tt_analytic - &h_ttfd)
2033 .iter()
2034 .map(|v| v * v)
2035 .sum::<f64>()
2036 .sqrt();
2037 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2038 let rel = num / den;
2039 assert!(
2040 rel < 1e-4,
2041 "linear-dir joint tau-tau block deviates from FD reference away from zero psi: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2042 );
2043 }
2044
2045 #[test]
2046 pub(crate) fn joint_hypervalidation_rejects_out_of_boundssecond_order_penalty_index() {
2047 let y = array![0.0, 1.0, 0.0, 1.0];
2064 let w = Array1::<f64>::ones(y.len());
2065 let x = array![
2066 [1.0, -0.5, 0.2],
2067 [1.0, -0.1, -0.3],
2068 [1.0, 0.4, 0.6],
2069 [1.0, 0.9, -0.2],
2070 ];
2071 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
2072 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
2073 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2074 let theta = array![0.0, 0.0];
2075 let hyper_dirs = vec![
2076 DirectionalHyperParam::new(
2077 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2078 vec![(0, Array2::<f64>::zeros((x.ncols(), x.ncols())))],
2079 None,
2080 Some(vec![Some(vec![(1, Array2::<f64>::eye(x.ncols()))])]),
2081 )
2082 .expect("hyper direction with invalid second-order penalty index"),
2083 ];
2084
2085 let msg = match compute_joint_hypercostgradienthessian(&state, &theta, 1, &hyper_dirs) {
2086 Ok(_) => panic!("invalid second-order penalty index should be rejected"),
2087 Err(err) => err.to_string(),
2088 };
2089 assert!(
2090 msg.contains("out of bounds") || msg.contains("penalty_index"),
2091 "unexpected validation error: {msg}"
2092 );
2093 }
2094
2095 #[test]
2096 pub(crate) fn joint_tau_tau_analytic_matchesfd_reference() {
2097 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
2098 let w = Array1::<f64>::ones(y.len());
2099 let x = array![
2100 [1.0, -1.2, 0.3],
2101 [1.0, -0.8, -0.4],
2102 [1.0, -0.3, 0.7],
2103 [1.0, 0.1, -0.9],
2104 [1.0, 0.5, 0.2],
2105 [1.0, 0.9, -0.1],
2106 [1.0, 1.3, 0.8],
2107 [1.0, 1.7, -0.6],
2108 ];
2109 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
2110 let cfg =
2111 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
2112 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2113 let rho = array![0.0];
2114 let psi = array![0.0, 0.0];
2115 let hyper_dirs = vec![
2116 DirectionalHyperParam::single_penalty(
2117 0,
2118 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2119 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
2120 None,
2121 None,
2122 )
2123 .expect("single-penalty hyper direction"),
2124 DirectionalHyperParam::single_penalty(
2125 0,
2126 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2127 Array2::<f64>::zeros((x.ncols(), x.ncols())),
2128 None,
2129 None,
2130 )
2131 .expect("single-penalty hyper direction"),
2132 ];
2133
2134 let theta = {
2135 let mut t = Array1::<f64>::zeros(rho.len() + psi.len());
2136 t.slice_mut(s![..rho.len()]).assign(&rho);
2137 t.slice_mut(s![rho.len()..]).assign(&psi);
2138 t
2139 };
2140 let (_, _, h_full) =
2141 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2142 .expect("joint hyper cost+gradient+hessian");
2143 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2144 assert_eq!(h_tt_analytic.nrows(), hyper_dirs.len());
2145 assert_eq!(h_tt_analytic.ncols(), hyper_dirs.len());
2146
2147 let x_tau_mats: Vec<Array2<f64>> = vec![
2152 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2153 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2154 ];
2155 let s_tau_mats: Vec<Array2<f64>> = vec![
2156 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
2157 Array2::<f64>::zeros((x.ncols(), x.ncols())),
2158 ];
2159
2160 let h_ttfd = directional_tau_hessian_fd_reference(
2161 &y,
2162 &w,
2163 &x,
2164 &s0,
2165 &cfg,
2166 &rho,
2167 &hyper_dirs,
2168 &x_tau_mats,
2169 &s_tau_mats,
2170 );
2171
2172 let num = (&h_tt_analytic - &h_ttfd)
2173 .iter()
2174 .map(|v| v * v)
2175 .sum::<f64>()
2176 .sqrt();
2177 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2178 let rel = num / den;
2179 assert!(
2180 rel < 1e-4,
2181 "analytic tau-tau block deviates from FD reference: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2182 );
2183 }
2184
/// Hard-coded Gaussian/identity REML problem reused by the
/// `profiled_gaussian_*` finite-difference tests.
///
/// Every field is filled in by `GaussianRemlFixture::new`; the sizes quoted
/// below are the ones that constructor uses.
pub(crate) struct GaussianRemlFixture {
    /// Response values (10 entries).
    pub(crate) y: Array1<f64>,
    /// Per-observation weights, all ones.
    pub(crate) w: Array1<f64>,
    /// 10x3 design matrix whose first column is constant 1.0.
    pub(crate) x: Array2<f64>,
    /// 3x3 base penalty matrix; its first row and column are zero.
    pub(crate) s0: Array2<f64>,
    /// REML configuration built from `gaussian_identity_glm_spec()`.
    pub(crate) cfg: RemlConfig,
    /// `rho` vector handed to the gradient/Hessian routines (a single 0.0).
    pub(crate) rho: Array1<f64>,
    /// 10x3 design-side direction used by the design-moving hyper
    /// directions; its first column is zero.
    pub(crate) x_tau_design: Array2<f64>,
    /// 3x3 penalty-side direction used by the penalty-only hyper directions.
    pub(crate) s_tau_penalty: Array2<f64>,
}
2206
2207 impl GaussianRemlFixture {
2208 pub(crate) fn new() -> Self {
2209 let y = array![0.5, 1.2, -0.3, 0.8, 1.1, -0.6, 0.9, 0.1, -0.2, 0.7];
2210 let x = array![
2211 [1.0, -1.2, 0.3],
2212 [1.0, -0.8, -0.4],
2213 [1.0, -0.3, 0.7],
2214 [1.0, 0.1, -0.9],
2215 [1.0, 0.5, 0.2],
2216 [1.0, 0.9, -0.1],
2217 [1.0, 1.3, 0.8],
2218 [1.0, 1.7, -0.6],
2219 [1.0, -0.5, 0.5],
2220 [1.0, 0.3, -0.3],
2221 ];
2222 Self {
2223 w: Array1::<f64>::ones(y.len()),
2224 y,
2225 x: x.clone(),
2226 s0: array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]],
2227 cfg: RemlConfig::external(gaussian_identity_glm_spec(), 1e-14, false),
2228 rho: array![0.0],
2229 x_tau_design: array![
2230 [0.0, 1e-3, -2e-3],
2231 [0.0, -3e-3, 1e-3],
2232 [0.0, 2e-3, 0.5e-3],
2233 [0.0, -1e-3, 3e-3],
2234 [0.0, 0.5e-3, -1e-3],
2235 [0.0, 1.5e-3, 2e-3],
2236 [0.0, -2e-3, -0.5e-3],
2237 [0.0, 3e-3, 1e-3],
2238 [0.0, -0.5e-3, 2e-3],
2239 [0.0, 1e-3, -1.5e-3],
2240 ],
2241 s_tau_penalty: array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
2242 }
2243 }
2244 }
2245
/// Exposes the fixture's stored data through the `LogitDesignMotionFixture`
/// accessors. The other trait methods the tests call on this fixture
/// (`state`, `fd_directional_gradient`) are not overridden here, so they come
/// from the trait itself.
impl LogitDesignMotionFixture for GaussianRemlFixture {
    /// Borrows the response vector.
    fn y(&self) -> &Array1<f64> {
        &self.y
    }
    /// Borrows the observation weights.
    fn w(&self) -> &Array1<f64> {
        &self.w
    }
    /// Borrows the design matrix.
    fn x(&self) -> &Array2<f64> {
        &self.x
    }
    /// Borrows the base penalty matrix.
    fn s0(&self) -> &Array2<f64> {
        &self.s0
    }
    /// Borrows the REML configuration.
    fn cfg(&self) -> &RemlConfig {
        &self.cfg
    }
    /// Borrows the `rho` vector.
    fn rho(&self) -> &Array1<f64> {
        &self.rho
    }
}
2266
/// Gaussian fixture, design-moving direction: the analytic directional
/// gradient must agree with the fixture's `fd_directional_gradient`
/// reference to a relative error below 1e-3.
#[test]
pub(crate) fn profiled_gaussian_design_moving_gradient_matches_fd() {
    let fixture = GaussianRemlFixture::new();
    let state = fixture.state();

    // The penalty does not move along this direction; only the design does.
    let zero_penalty_dir = Array2::<f64>::zeros((3, 3));
    let direction = DirectionalHyperParam::single_penalty(
        0,
        fixture.x_tau_design.clone(),
        zero_penalty_dir.clone(),
        None,
        None,
    )
    .expect("design-moving hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &fixture.rho, direction)
        .expect("analytic directional gradient");
    let v_taufd = fixture.fd_directional_gradient(&fixture.x_tau_design, &zero_penalty_dir);

    // Relative error, with the denominator floored so a vanishing reference
    // cannot blow the ratio up.
    let abs_err = (v_tau_analytic - v_taufd).abs();
    let v_rel = abs_err / v_taufd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Gaussian REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
    );
}
2292
/// Gaussian fixture, penalty-only direction: the analytic directional
/// gradient must agree with the fixture's `fd_directional_gradient`
/// reference to a relative error below 1e-3.
#[test]
pub(crate) fn profiled_gaussian_penalty_only_gradient_matches_fd() {
    let fixture = GaussianRemlFixture::new();
    let state = fixture.state();

    // The design does not move along this direction; only the penalty does.
    let zero_design_dir = Array2::<f64>::zeros(fixture.x.raw_dim());
    let direction = DirectionalHyperParam::single_penalty(
        0,
        zero_design_dir.clone(),
        fixture.s_tau_penalty.clone(),
        None,
        None,
    )
    .expect("penalty-only hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &fixture.rho, direction)
        .expect("analytic directional gradient");
    let v_taufd = fixture.fd_directional_gradient(&zero_design_dir, &fixture.s_tau_penalty);

    // Relative error with a floored denominator.
    let abs_err = (v_tau_analytic - v_taufd).abs();
    let v_rel = abs_err / v_taufd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Gaussian REML penalty-only V_tau mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
    );
}
2318
/// Gaussian fixture: the tau-tau block of the analytic joint Hessian (one
/// penalty-only and one design-moving direction) must match
/// `directional_tau_hessian_fd_reference` to a relative Frobenius error
/// below 1e-4.
#[test]
pub(crate) fn profiled_gaussian_joint_hessian_matches_fd() {
    let fixture = GaussianRemlFixture::new();
    let n_rho = fixture.rho.len();

    // Index 0 moves only the penalty; index 1 moves only the design.
    let x_tau_mats = vec![
        Array2::<f64>::zeros(fixture.x.raw_dim()),
        fixture.x_tau_design.clone(),
    ];
    let s_tau_mats = vec![fixture.s_tau_penalty.clone(), Array2::<f64>::zeros((3, 3))];
    let labels = ["penalty-only direction", "design-moving direction"];

    let hyper_dirs: Vec<_> = x_tau_mats
        .iter()
        .zip(s_tau_mats.iter())
        .zip(labels)
        .map(|((x_tau, s_tau), label)| {
            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
                .expect(label)
        })
        .collect();

    let state = fixture.state();
    // theta = [rho, tau] with every tau coordinate at zero.
    let mut theta = Array1::<f64>::zeros(n_rho + hyper_dirs.len());
    theta.slice_mut(s![..n_rho]).assign(&fixture.rho);
    let (_, _, h_full) =
        compute_joint_hypercostgradienthessian(&state, &theta, n_rho, &hyper_dirs)
            .expect("joint cost+gradient+hessian");
    let h_tt_analytic = h_full.slice(s![n_rho.., n_rho..]).to_owned();

    let h_ttfd = directional_tau_hessian_fd_reference(
        &fixture.y,
        &fixture.w,
        &fixture.x,
        &fixture.s0,
        &fixture.cfg,
        &fixture.rho,
        &hyper_dirs,
        &x_tau_mats,
        &s_tau_mats,
    );

    // Relative Frobenius error with a floored denominator.
    let residual = &h_tt_analytic - &h_ttfd;
    let num = residual.iter().map(|v| v * v).sum::<f64>().sqrt();
    let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    let rel = num / den;
    assert!(
        rel < 1e-4,
        "Gaussian REML tau-tau Hessian mismatch: rel={rel:.3e}, \
         analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
    );
}
2373
/// Binomial-logit problem with ten observations: the analytic design-moving
/// directional gradient must match a central finite difference of
/// `compute_cost` along the same design direction (step 2e-5, relative
/// tolerance 1e-3).
#[test]
pub(crate) fn logit_design_moving_gradient_matches_fd() {
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.2, 0.3],
        [1.0, -0.8, -0.4],
        [1.0, -0.3, 0.7],
        [1.0, 0.1, -0.9],
        [1.0, 0.5, 0.2],
        [1.0, 0.9, -0.1],
        [1.0, 1.3, 0.8],
        [1.0, 1.7, -0.6],
        [1.0, -0.5, 0.5],
        [1.0, 0.3, -0.3],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
    let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
    let base_state = build_logit_state(&y, &w, &x, &s0, &cfg);
    let rho = array![0.0];

    // Design direction; the first column is zero so it does not move.
    let x_tau = array![
        [0.0, 1e-3, -2e-3],
        [0.0, -3e-3, 1e-3],
        [0.0, 2e-3, 0.5e-3],
        [0.0, -1e-3, 3e-3],
        [0.0, 0.5e-3, -1e-3],
        [0.0, 1.5e-3, 2e-3],
        [0.0, -2e-3, -0.5e-3],
        [0.0, 3e-3, 1e-3],
        [0.0, -0.5e-3, 2e-3],
        [0.0, 1e-3, -1.5e-3],
    ];
    let direction = DirectionalHyperParam::single_penalty(
        0,
        x_tau.clone(),
        Array2::<f64>::zeros((3, 3)),
        None,
        None,
    )
    .expect("design-moving hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&base_state, &rho, direction)
        .expect("analytic directional gradient");

    // Central difference of the cost at x +/- h * x_tau.
    let h = 2e-5;
    let offset = x_tau.mapv(|v| h * v);
    let x_plus = &x + &offset;
    let x_minus = &x - &offset;
    let state_plus = build_logit_state(&y, &w, &x_plus, &s0, &cfg);
    let state_minus = build_logit_state(&y, &w, &x_minus, &s0, &cfg);
    let v_plus = state_plus.compute_cost(&rho).expect("cost+");
    let v_minus = state_minus.compute_cost(&rho).expect("cost-");
    let v_taufd = (v_plus - v_minus) / (2.0 * h);

    let abs_err = (v_tau_analytic - v_taufd).abs();
    let v_rel = abs_err / v_taufd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Logit REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
    );
}
2445
/// Binomial-logit problem with ten observations: the tau-tau block of the
/// analytic joint Hessian (one penalty-only and one design-moving direction)
/// must match `directional_tau_hessian_fd_reference` to a relative Frobenius
/// error below 1e-4.
#[test]
pub(crate) fn logit_design_moving_hessian_matches_fd() {
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.2, 0.3],
        [1.0, -0.8, -0.4],
        [1.0, -0.3, 0.7],
        [1.0, 0.1, -0.9],
        [1.0, 0.5, 0.2],
        [1.0, 0.9, -0.1],
        [1.0, 1.3, 0.8],
        [1.0, 1.7, -0.6],
        [1.0, -0.5, 0.5],
        [1.0, 0.3, -0.3],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
    let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
    let rho = array![0.0];
    let n_rho = rho.len();

    // Index 0 moves only the penalty; index 1 moves only the design.
    let x_tau_mats = vec![
        Array2::<f64>::zeros(x.raw_dim()),
        array![
            [0.0, 1e-3, -2e-3],
            [0.0, -3e-3, 1e-3],
            [0.0, 2e-3, 0.5e-3],
            [0.0, -1e-3, 3e-3],
            [0.0, 0.5e-3, -1e-3],
            [0.0, 1.5e-3, 2e-3],
            [0.0, -2e-3, -0.5e-3],
            [0.0, 3e-3, 1e-3],
            [0.0, -0.5e-3, 2e-3],
            [0.0, 1e-3, -1.5e-3],
        ],
    ];
    let s_tau_mats = vec![
        array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
        Array2::<f64>::zeros((3, 3)),
    ];
    let labels = ["penalty-only direction", "design-moving direction"];

    let hyper_dirs: Vec<_> = x_tau_mats
        .iter()
        .zip(s_tau_mats.iter())
        .zip(labels)
        .map(|((x_tau, s_tau), label)| {
            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
                .expect(label)
        })
        .collect();

    let state = build_logit_state(&y, &w, &x, &s0, &cfg);
    // theta = [rho, tau] with every tau coordinate at zero.
    let mut theta = Array1::<f64>::zeros(n_rho + hyper_dirs.len());
    theta.slice_mut(s![..n_rho]).assign(&rho);
    let (_, _, h_full) =
        compute_joint_hypercostgradienthessian(&state, &theta, n_rho, &hyper_dirs)
            .expect("joint cost+gradient+hessian");
    let h_tt_analytic = h_full.slice(s![n_rho.., n_rho..]).to_owned();

    let h_ttfd = directional_tau_hessian_fd_reference(
        &y,
        &w,
        &x,
        &s0,
        &cfg,
        &rho,
        &hyper_dirs,
        &x_tau_mats,
        &s_tau_mats,
    );

    // Relative Frobenius error with a floored denominator.
    let residual = &h_tt_analytic - &h_ttfd;
    let num = residual.iter().map(|v| v * v).sum::<f64>().sqrt();
    let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    let rel = num / den;
    assert!(
        rel < 1e-4,
        "Logit REML design-moving tau-tau Hessian mismatch: rel={rel:.3e}, \
         analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
    );
}
2529
/// Hard-coded binomial-logit REML problem (30 observations, 5 coefficients)
/// reused by the `binomial_logit_n30_*` finite-difference tests.
///
/// Every field is filled in by `BinomialLogitDesignMotionFixture::new`; the
/// sizes quoted below are the ones that constructor uses.
pub(crate) struct BinomialLogitDesignMotionFixture {
    /// 0/1 response values (30 entries).
    pub(crate) y: Array1<f64>,
    /// Per-observation weights, all ones.
    pub(crate) w: Array1<f64>,
    /// 30x5 design matrix whose first column is constant 1.0.
    pub(crate) x: Array2<f64>,
    /// 5x5 base penalty matrix; its first row and column are zero.
    pub(crate) s0: Array2<f64>,
    /// REML configuration built from `binomial_logit_glm_spec()`.
    pub(crate) cfg: RemlConfig,
    /// `rho` vector handed to the gradient/Hessian routines (a single 0.0).
    pub(crate) rho: Array1<f64>,
    /// 30x5 design-side direction used by the design-moving hyper
    /// directions; its first column is zero.
    pub(crate) x_tau_design: Array2<f64>,
    /// 5x5 penalty-side direction used by the penalty-moving hyper
    /// directions; its first row and column are zero.
    pub(crate) s_tau_penalty: Array2<f64>,
}
2551
2552 impl BinomialLogitDesignMotionFixture {
2553 pub(crate) fn new() -> Self {
2554 let y = array![
2556 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
2557 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
2558 ];
2559 let x = array![
2561 [1.0, -1.50, 0.42, 0.88, -0.31],
2562 [1.0, -1.12, -0.65, 0.14, 1.23],
2563 [1.0, -0.80, 1.10, -0.53, 0.07],
2564 [1.0, -0.55, -0.22, 1.40, -0.90],
2565 [1.0, -0.30, 0.73, -1.05, 0.44],
2566 [1.0, -0.05, -1.33, 0.60, 0.81],
2567 [1.0, 0.18, 0.55, -0.27, -1.15],
2568 [1.0, 0.42, -0.90, 1.12, 0.33],
2569 [1.0, 0.70, 1.28, -0.78, -0.56],
2570 [1.0, 0.95, -0.18, 0.45, 1.40],
2571 [1.0, 1.20, 0.66, -1.30, -0.02],
2572 [1.0, 1.45, -1.05, 0.22, 0.68],
2573 [1.0, -1.35, 0.90, 0.55, -0.43],
2574 [1.0, -0.98, -0.40, -0.88, 1.05],
2575 [1.0, -0.62, 1.42, 0.30, -0.70],
2576 [1.0, -0.28, -0.77, -1.18, 0.52],
2577 [1.0, 0.05, 0.15, 0.95, -1.35],
2578 [1.0, 0.33, -1.20, -0.40, 0.18],
2579 [1.0, 0.60, 0.82, 1.25, -0.85],
2580 [1.0, 0.88, -0.50, -0.65, 1.10],
2581 [1.0, 1.15, 1.05, 0.10, -0.22],
2582 [1.0, -1.22, -0.95, 0.72, 0.90],
2583 [1.0, -0.75, 0.38, -1.42, 0.15],
2584 [1.0, -0.42, -1.15, 0.50, -1.08],
2585 [1.0, -0.10, 0.60, -0.15, 0.75],
2586 [1.0, 0.25, -0.28, 1.05, -0.48],
2587 [1.0, 0.52, 1.35, -0.92, 0.30],
2588 [1.0, 0.80, -0.70, 0.38, 1.20],
2589 [1.0, 1.08, 0.48, -0.60, -0.95],
2590 [1.0, 1.35, -0.55, 0.85, 0.42]
2591 ];
2592 let s0 = array![
2594 [0.0, 0.0, 0.0, 0.0, 0.0],
2595 [0.0, 1.40, 0.15, 0.05, -0.10],
2596 [0.0, 0.15, 1.10, -0.20, 0.08],
2597 [0.0, 0.05, -0.20, 0.95, 0.12],
2598 [0.0, -0.10, 0.08, 0.12, 1.25]
2599 ];
2600 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2601 let x_tau_design = array![
2604 [0.0, 1.2e-3, -0.8e-3, 0.5e-3, -1.5e-3],
2605 [0.0, -2.0e-3, 1.4e-3, -0.3e-3, 0.9e-3],
2606 [0.0, 0.6e-3, -1.1e-3, 1.8e-3, -0.4e-3],
2607 [0.0, -1.3e-3, 0.7e-3, -1.0e-3, 2.1e-3],
2608 [0.0, 0.9e-3, -0.5e-3, 0.2e-3, -0.8e-3],
2609 [0.0, -0.4e-3, 1.8e-3, -1.5e-3, 0.3e-3],
2610 [0.0, 1.5e-3, -1.3e-3, 0.8e-3, -1.1e-3],
2611 [0.0, -0.7e-3, 0.4e-3, -2.0e-3, 1.6e-3],
2612 [0.0, 2.2e-3, -0.9e-3, 1.3e-3, -0.6e-3],
2613 [0.0, -1.0e-3, 1.6e-3, -0.7e-3, 0.5e-3],
2614 [0.0, 0.3e-3, -2.1e-3, 1.1e-3, -1.8e-3],
2615 [0.0, -1.8e-3, 0.2e-3, -0.4e-3, 1.3e-3],
2616 [0.0, 1.1e-3, -1.5e-3, 2.0e-3, -0.2e-3],
2617 [0.0, -0.5e-3, 0.9e-3, -1.2e-3, 0.7e-3],
2618 [0.0, 1.7e-3, -0.3e-3, 0.6e-3, -2.0e-3],
2619 [0.0, -1.4e-3, 1.1e-3, -0.9e-3, 0.4e-3],
2620 [0.0, 0.8e-3, -1.7e-3, 1.5e-3, -0.1e-3],
2621 [0.0, -0.2e-3, 0.6e-3, -1.8e-3, 1.0e-3],
2622 [0.0, 1.4e-3, -0.4e-3, 0.3e-3, -1.3e-3],
2623 [0.0, -0.9e-3, 2.0e-3, -0.5e-3, 0.8e-3],
2624 [0.0, 0.5e-3, -1.0e-3, 1.6e-3, -0.7e-3],
2625 [0.0, -2.1e-3, 0.3e-3, -0.8e-3, 1.5e-3],
2626 [0.0, 0.7e-3, -1.8e-3, 0.9e-3, -0.3e-3],
2627 [0.0, -0.6e-3, 1.3e-3, -2.2e-3, 1.1e-3],
2628 [0.0, 1.9e-3, -0.7e-3, 0.4e-3, -0.9e-3],
2629 [0.0, -1.1e-3, 0.5e-3, -1.4e-3, 2.2e-3],
2630 [0.0, 0.4e-3, -1.6e-3, 1.2e-3, -0.5e-3],
2631 [0.0, -1.6e-3, 0.8e-3, -0.1e-3, 0.6e-3],
2632 [0.0, 1.3e-3, -2.2e-3, 0.7e-3, -1.4e-3],
2633 [0.0, -0.3e-3, 1.0e-3, -1.6e-3, 1.8e-3]
2634 ];
2635 let s_tau_penalty = array![
2637 [0.0, 0.0, 0.0, 0.0, 0.0],
2638 [0.0, 0.30, 0.05, -0.02, 0.04],
2639 [0.0, 0.05, 0.22, 0.03, -0.01],
2640 [0.0, -0.02, 0.03, 0.18, 0.06],
2641 [0.0, 0.04, -0.01, 0.06, 0.26]
2642 ];
2643 Self {
2644 w: Array1::<f64>::ones(y.len()),
2645 y,
2646 x,
2647 s0,
2648 cfg,
2649 rho: array![0.0],
2650 x_tau_design,
2651 s_tau_penalty,
2652 }
2653 }
2654 }
2655
/// Exposes the fixture's stored data through the `LogitDesignMotionFixture`
/// accessors. The other trait methods the tests call on this fixture
/// (`state`, `fd_directional_gradient`, `state_perturbed`) are not overridden
/// here, so they come from the trait itself.
impl LogitDesignMotionFixture for BinomialLogitDesignMotionFixture {
    /// Borrows the response vector.
    fn y(&self) -> &Array1<f64> {
        &self.y
    }
    /// Borrows the observation weights.
    fn w(&self) -> &Array1<f64> {
        &self.w
    }
    /// Borrows the design matrix.
    fn x(&self) -> &Array2<f64> {
        &self.x
    }
    /// Borrows the base penalty matrix.
    fn s0(&self) -> &Array2<f64> {
        &self.s0
    }
    /// Borrows the REML configuration.
    fn cfg(&self) -> &RemlConfig {
        &self.cfg
    }
    /// Borrows the `rho` vector.
    fn rho(&self) -> &Array1<f64> {
        &self.rho
    }
}
2676
/// n=30 binomial-logit fixture, design-moving direction: the analytic
/// directional gradient must agree with the fixture's
/// `fd_directional_gradient` reference to a relative error below 1e-3.
#[test]
pub(crate) fn binomial_logit_n30_design_moving_gradient_matches_fd() {
    let fixture = BinomialLogitDesignMotionFixture::new();
    let state = fixture.state();

    // The penalty does not move along this direction; only the design does.
    let zero_penalty_dir = Array2::<f64>::zeros((5, 5));
    let direction = DirectionalHyperParam::single_penalty(
        0,
        fixture.x_tau_design.clone(),
        zero_penalty_dir.clone(),
        None,
        None,
    )
    .expect("design-moving hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &fixture.rho, direction)
        .expect("analytic directional gradient");
    let v_tau_fd = fixture.fd_directional_gradient(&fixture.x_tau_design, &zero_penalty_dir);

    // Relative error with a floored denominator.
    let abs_err = (v_tau_analytic - v_tau_fd).abs();
    let v_rel = abs_err / v_tau_fd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Binomial-logit n=30 design-moving gradient mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
    );
}
2710
/// n=30 binomial-logit fixture, penalty-only direction: the analytic
/// directional gradient must agree with the fixture's
/// `fd_directional_gradient` reference to a relative error below 1e-3.
#[test]
pub(crate) fn binomial_logit_n30_penalty_only_gradient_matches_fd() {
    let fixture = BinomialLogitDesignMotionFixture::new();
    let state = fixture.state();

    // The design does not move along this direction; only the penalty does.
    let zero_design_dir = Array2::<f64>::zeros(fixture.x.raw_dim());
    let direction = DirectionalHyperParam::single_penalty(
        0,
        zero_design_dir.clone(),
        fixture.s_tau_penalty.clone(),
        None,
        None,
    )
    .expect("penalty-only hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &fixture.rho, direction)
        .expect("analytic directional gradient");
    let v_tau_fd = fixture.fd_directional_gradient(&zero_design_dir, &fixture.s_tau_penalty);

    // Relative error with a floored denominator.
    let abs_err = (v_tau_analytic - v_tau_fd).abs();
    let v_rel = abs_err / v_tau_fd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Binomial-logit n=30 penalty-only gradient mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
    );
}
2740
/// n=30 binomial-logit fixture, a single direction that moves the design and
/// the penalty together: the analytic directional gradient must agree with
/// the fixture's `fd_directional_gradient` reference to a relative error
/// below 1e-3.
#[test]
pub(crate) fn binomial_logit_n30_joint_design_penalty_gradient_matches_fd() {
    let fixture = BinomialLogitDesignMotionFixture::new();
    let state = fixture.state();

    let direction = DirectionalHyperParam::single_penalty(
        0,
        fixture.x_tau_design.clone(),
        fixture.s_tau_penalty.clone(),
        None,
        None,
    )
    .expect("joint design+penalty hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &fixture.rho, direction)
        .expect("analytic directional gradient");
    let v_tau_fd =
        fixture.fd_directional_gradient(&fixture.x_tau_design, &fixture.s_tau_penalty);

    // Relative error with a floored denominator.
    let abs_err = (v_tau_analytic - v_tau_fd).abs();
    let v_rel = abs_err / v_tau_fd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Binomial-logit n=30 joint design+penalty gradient mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
    );
}
2769
/// n=30 binomial-logit fixture: the tau-tau block of the analytic joint
/// Hessian (one penalty-only and one design-moving direction) must match
/// `directional_tau_hessian_fd_reference` to a relative Frobenius error
/// below 1e-4.
#[test]
pub(crate) fn binomial_logit_n30_design_moving_hessian_matches_fd() {
    let fixture = BinomialLogitDesignMotionFixture::new();
    let n_rho = fixture.rho.len();

    // Index 0 moves only the penalty; index 1 moves only the design.
    let x_tau_mats = vec![
        Array2::<f64>::zeros(fixture.x.raw_dim()),
        fixture.x_tau_design.clone(),
    ];
    let s_tau_mats = vec![fixture.s_tau_penalty.clone(), Array2::<f64>::zeros((5, 5))];
    let labels = ["penalty-only direction", "design-moving direction"];

    let hyper_dirs: Vec<_> = x_tau_mats
        .iter()
        .zip(s_tau_mats.iter())
        .zip(labels)
        .map(|((x_tau, s_tau), label)| {
            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
                .expect(label)
        })
        .collect();

    let state = fixture.state();
    // theta = [rho, tau] with every tau coordinate at zero.
    let mut theta = Array1::<f64>::zeros(n_rho + hyper_dirs.len());
    theta.slice_mut(s![..n_rho]).assign(&fixture.rho);
    let (_, _, h_full) =
        compute_joint_hypercostgradienthessian(&state, &theta, n_rho, &hyper_dirs)
            .expect("joint cost+gradient+hessian");
    let h_tt_analytic = h_full.slice(s![n_rho.., n_rho..]).to_owned();

    let h_tt_fd = directional_tau_hessian_fd_reference(
        &fixture.y,
        &fixture.w,
        &fixture.x,
        &fixture.s0,
        &fixture.cfg,
        &fixture.rho,
        &hyper_dirs,
        &x_tau_mats,
        &s_tau_mats,
    );

    // Relative Frobenius error with a floored denominator.
    let residual = &h_tt_analytic - &h_tt_fd;
    let num = residual.iter().map(|v| v * v).sum::<f64>().sqrt();
    let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    let rel = num / den;
    assert!(
        rel < 1e-4,
        "Binomial-logit n=30 tau-tau Hessian mismatch: rel={rel:.3e}, \
         analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
    );
}
2824
/// n=30 binomial-logit fixture evaluated at `rho = [1.5]` instead of the
/// fixture's own `rho`: the analytic design-moving directional gradient must
/// match a central difference of `compute_cost` over the two states returned
/// by `state_perturbed` (step 2e-5, relative tolerance 1e-3).
#[test]
pub(crate) fn binomial_logit_n30_nonzero_rho_design_moving_gradient_matches_fd() {
    let fixture = BinomialLogitDesignMotionFixture::new();
    let rho = array![1.5];
    // The penalty does not move along this direction; only the design does.
    let zero_penalty_dir = Array2::<f64>::zeros((5, 5));

    let state = fixture.state();
    let direction = DirectionalHyperParam::single_penalty(
        0,
        fixture.x_tau_design.clone(),
        zero_penalty_dir.clone(),
        None,
        None,
    )
    .expect("design-moving hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &rho, direction)
        .expect("analytic directional gradient");

    let h = 2e-5;
    let (state_plus, state_minus) =
        fixture.state_perturbed(&fixture.x_tau_design, &zero_penalty_dir, h);
    let v_plus = state_plus.compute_cost(&rho).expect("cost+");
    let v_minus = state_minus.compute_cost(&rho).expect("cost-");
    let v_tau_fd = (v_plus - v_minus) / (2.0 * h);

    // Relative error with a floored denominator.
    let abs_err = (v_tau_analytic - v_tau_fd).abs();
    let v_rel = abs_err / v_tau_fd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Binomial-logit n=30 rho=1.5 design-moving gradient mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
    );
}
2861
/// n=30 binomial-logit fixture: the tau-tau block of the analytic joint
/// Hessian is compared with second-order central differences taken directly
/// on `compute_cost` (nine cost evaluations on a 3x3 stencil), with a
/// relative Frobenius tolerance of 3e-3.
///
/// NOTE(review): "rank deficient" presumably refers to the fixture penalty,
/// whose first row and column are zero; confirm against the original intent.
#[test]
pub(crate) fn binomial_logit_n30_rank_deficient_hessian_matches_cost_fd() {
    let f = BinomialLogitDesignMotionFixture::new();
    let n_rho = f.rho.len();

    // Direction 0 moves only the penalty; direction 1 moves only the design.
    let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
    let s_tau_0 = f.s_tau_penalty.clone();
    let x_tau_1 = f.x_tau_design.clone();
    let s_tau_1 = Array2::<f64>::zeros((5, 5));

    let hyper_dirs = vec![
        DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
            .expect("penalty-only direction"),
        DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
            .expect("design-moving direction"),
    ];

    let state = f.state();
    // theta = [rho, tau] with every tau coordinate at zero.
    let mut theta = Array1::<f64>::zeros(n_rho + hyper_dirs.len());
    theta.slice_mut(s![..n_rho]).assign(&f.rho);
    let (_, _, h_full) =
        compute_joint_hypercostgradienthessian(&state, &theta, n_rho, &hyper_dirs)
            .expect("joint cost+gradient+hessian");
    let h_tt_analytic = h_full.slice(s![n_rho.., n_rho..]).to_owned();

    const TARGET_PHYSICAL_STEP: f64 = 1e-5;
    // Per-direction step: scaled so the largest perturbed entry moves by
    // TARGET_PHYSICAL_STEP; an all-zero direction falls back to the target.
    let step_for = |design_dir: &Array2<f64>, penalty_dir: &Array2<f64>| -> f64 {
        let max_abs = design_dir
            .iter()
            .chain(penalty_dir.iter())
            .fold(0.0_f64, |acc, value| acc.max(value.abs()));
        if max_abs > 0.0 {
            TARGET_PHYSICAL_STEP / max_abs
        } else {
            TARGET_PHYSICAL_STEP
        }
    };
    let steps: [f64; 2] = [step_for(&x_tau_0, &s_tau_0), step_for(&x_tau_1, &s_tau_1)];

    // Cost at the point displaced by `a` steps along direction 0 and `b`
    // steps along direction 1.
    let eval_cost = |a: f64, b: f64| -> f64 {
        let amp0 = a * steps[0];
        let amp1 = b * steps[1];
        let x_eval = &f.x + &x_tau_0.mapv(|v| amp0 * v) + &x_tau_1.mapv(|v| amp1 * v);
        let s_eval = &f.s0 + &s_tau_0.mapv(|v| amp0 * v) + &s_tau_1.mapv(|v| amp1 * v);
        let st = build_logit_state(&f.y, &f.w, &x_eval, &s_eval, &f.cfg);
        st.compute_cost(&f.rho).expect("cost eval")
    };

    let v_00 = eval_cost(0.0, 0.0);
    let v_p0 = eval_cost(1.0, 0.0);
    let v_m0 = eval_cost(-1.0, 0.0);
    let v_0p = eval_cost(0.0, 1.0);
    let v_0m = eval_cost(0.0, -1.0);
    let v_pp = eval_cost(1.0, 1.0);
    let v_pm = eval_cost(1.0, -1.0);
    let v_mp = eval_cost(-1.0, 1.0);
    let v_mm = eval_cost(-1.0, -1.0);

    // Standard central second differences (diagonal and mixed).
    let h00_fd = (v_p0 - 2.0 * v_00 + v_m0) / (steps[0] * steps[0]);
    let h11_fd = (v_0p - 2.0 * v_00 + v_0m) / (steps[1] * steps[1]);
    let h01_fd = (v_pp - v_pm - v_mp + v_mm) / (4.0 * steps[0] * steps[1]);
    let h_tt_fd = array![[h00_fd, h01_fd], [h01_fd, h11_fd]];

    // Relative Frobenius error with a floored denominator.
    let residual = &h_tt_analytic - &h_tt_fd;
    let num = residual.iter().map(|v| v * v).sum::<f64>().sqrt();
    let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    let rel = num / den;

    assert!(
        rel < 3e-3,
        "Binomial-logit n=30 rank-deficient Hessian vs cost-FD mismatch: rel={rel:.3e}, \
         analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
    );
}
2984}
2985
/// Geometry choices available to the REML code.
///
/// NOTE(review): the consumers of this enum are outside this section;
/// presumably it selects between a dense spectral path and a sparse exact
/// SPD factorization path — confirm against the call sites.
#[derive(Clone, Copy, Debug)]
pub(crate) enum RemlGeometry {
    /// Dense spectral geometry.
    DenseSpectral,
    /// Sparse exact symmetric-positive-definite geometry.
    SparseExactSpd,
}
2991
/// Implemented by penalized-geometry objects so callers can ask which
/// `GeometryBackendKind` they represent.
trait PenalizedGeometry {
    /// Returns the backend kind of this geometry.
    fn backend_kind(&self) -> GeometryBackendKind;
}
2995
/// Storage alternatives for a derivative matrix. Each variant wraps one
/// payload type; `storage_dispatch!` runs the same body against whichever
/// payload is present.
#[derive(Clone)]
pub(crate) enum DerivativeMatrixStorage {
    /// Fully materialized dense matrix.
    Dense(Array2<f64>),
    /// Shape-only placeholder (`ZeroDerivativeMatrix` stores just rows/cols).
    Zero(ZeroDerivativeMatrix),
    /// Local dense block plus the index range and total dimension it is
    /// embedded in.
    Embedded(EmbeddedDerivativeMatrix),
    /// Derivative backed by an `ImplicitDesignPsiDerivative` operator.
    Implicit(ImplicitDerivativeOp),
    /// Derivative backed by a `LatentCoordDesignDerivative` operator.
    LatentCoord(LatentCoordDerivativeOp),
}
3004
3005trait DerivativeStorageBackend {
3017 fn resident_byte_count(&self) -> usize;
3018 fn design_nrows(&self) -> usize;
3019 fn design_ncols(&self) -> usize;
3020 fn penalty_dim(&self) -> usize;
3021 fn uses_implicit_storage(&self) -> bool;
3022 fn any_nonzero(&self) -> bool;
3023 fn materialize(&self) -> Array2<f64>;
3024 fn implicit_first_axis_info(
3025 &self,
3026 ) -> Option<(
3027 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3028 usize,
3029 )>;
3030 fn implicit_axis_count_hint(&self) -> Option<usize>;
3031 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;
3032 fn design_transpose_mul_original(
3033 &self,
3034 v: &Array1<f64>,
3035 ) -> Result<Array1<f64>, EstimationError>;
3036 fn design_transformed(
3037 &self,
3038 qs: &Array2<f64>,
3039 free_basis_opt: Option<&Array2<f64>>,
3040 ) -> Result<Array2<f64>, EstimationError>;
3041 fn design_transformed_forward_mul(
3045 &self,
3046 qs: &Array2<f64>,
3047 free_basis_opt: Option<&Array2<f64>>,
3048 u: &Array1<f64>,
3049 ) -> Result<Array1<f64>, EstimationError> {
3050 Ok(self.design_transformed(qs, free_basis_opt)?.dot(u))
3051 }
3052 fn design_transformed_transpose_mul(
3055 &self,
3056 qs: &Array2<f64>,
3057 free_basis_opt: Option<&Array2<f64>>,
3058 v: &Array1<f64>,
3059 ) -> Result<Array1<f64>, EstimationError> {
3060 Ok(self.design_transformed(qs, free_basis_opt)?.t().dot(v))
3061 }
3062 fn penalty_transformed(
3063 &self,
3064 qs: &Array2<f64>,
3065 free_basis_opt: Option<&Array2<f64>>,
3066 ) -> Result<Array2<f64>, EstimationError>;
3067 fn penalty_scaled_add_to(
3068 &self,
3069 target: &mut Array2<f64>,
3070 amp: f64,
3071 ) -> Result<(), EstimationError>;
3072}
3073
// Expands to a `match` over every `DerivativeMatrixStorage` variant, binding
// the variant's payload to `$backend` and evaluating the same `$body` in each
// arm. Because the payload type differs per arm, `$body` must type-check
// against every payload type.
macro_rules! storage_dispatch {
    ($scrutinee:expr, $backend:ident => $body:expr) => {
        match $scrutinee {
            DerivativeMatrixStorage::Dense($backend) => $body,
            DerivativeMatrixStorage::Zero($backend) => $body,
            DerivativeMatrixStorage::Embedded($backend) => $body,
            DerivativeMatrixStorage::Implicit($backend) => $body,
            DerivativeMatrixStorage::LatentCoord($backend) => $body,
        }
    };
}
3089
/// Shape-only stand-in for a derivative matrix: it records a row and column
/// count and stores no entries.
#[derive(Clone)]
pub(crate) struct ZeroDerivativeMatrix {
    /// Recorded number of rows.
    rows: usize,
    /// Recorded number of columns.
    cols: usize,
}

impl ZeroDerivativeMatrix {
    /// Creates a placeholder with the given `rows` x `cols` shape.
    pub(crate) fn new(rows: usize, cols: usize) -> Self {
        ZeroDerivativeMatrix { rows, cols }
    }
}
3101
/// Selects which derivative of the implicit operator an
/// `ImplicitDerivativeOp` evaluates.
#[derive(Clone, Copy, Debug)]
pub enum ImplicitDerivLevel {
    /// First derivative along the given axis (routed to `materialize_first`,
    /// `forward_mul`, `transpose_mul`).
    First(usize),
    /// Second derivative taken twice along the given axis (routed to the
    /// `*_second_diag` operator methods).
    SecondDiag(usize),
    /// Mixed second derivative along the two given axes (routed to the
    /// `*_second_cross` operator methods).
    SecondCross(usize, usize),
}
3112
/// Derivative matrix represented by an `ImplicitDesignPsiDerivative`
/// operator, embedded into a wider coefficient space.
#[derive(Clone)]
pub(crate) struct ImplicitDerivativeOp {
    /// Shared operator that evaluates the local derivative block.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
    /// Which derivative (first, second diagonal, second cross) to evaluate.
    pub(crate) level: ImplicitDerivLevel,
    /// Column range of the full space occupied by the local block.
    pub(crate) global_range: Range<usize>,
    /// Column count of the full space.
    pub(crate) total_dim: usize,
    /// Lazily filled dense embedding; populated on the first
    /// `materialize_dense` call and shared across clones through the `Arc`.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3132
/// Derivative matrix represented by a `LatentCoordDesignDerivative`
/// operator along one flat axis, embedded into a wider coefficient space.
#[derive(Clone)]
pub(crate) struct LatentCoordDerivativeOp {
    /// Shared operator that evaluates the local derivative block.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
    /// Axis index passed to the operator's `*_axis` methods.
    pub(crate) flat_axis: usize,
    /// Column range of the full space occupied by the local block.
    pub(crate) global_range: Range<usize>,
    /// Column count of the full space.
    pub(crate) total_dim: usize,
    /// Lazily filled dense embedding; populated on the first
    /// `materialize_dense` call and shared across clones through the `Arc`.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3141
3142impl LatentCoordDerivativeOp {
3143 pub(crate) fn materialize_local(&self) -> Array2<f64> {
3144 self.operator.materialize_axis(self.flat_axis).expect(
3145 "radial scalar evaluation failed during latent-coordinate derivative materialization",
3146 )
3147 }
3148
3149 pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3150 self.cached_dense.get_or_compute(|| {
3151 let local = self.materialize_local();
3152 let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3153 out.slice_mut(s![.., self.global_range.clone()])
3154 .assign(&local);
3155 out
3156 })
3157 }
3158
3159 pub(crate) fn nrows(&self) -> usize {
3160 self.operator.n_data()
3161 }
3162
3163 pub(crate) fn ncols(&self) -> usize {
3164 self.total_dim
3165 }
3166
3167 pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3168 let local = self
3169 .operator
3170 .transpose_mul_axis(self.flat_axis, &v.view())
3171 .expect(
3172 "radial scalar evaluation failed during latent-coordinate derivative transpose_mul",
3173 );
3174 let mut out = Array1::<f64>::zeros(self.total_dim);
3175 out.slice_mut(s![self.global_range.clone()]).assign(&local);
3176 out
3177 }
3178
3179 pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3180 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3181 self.operator
3182 .forward_mul_axis(self.flat_axis, &u_local.view())
3183 .expect(
3184 "radial scalar evaluation failed during latent-coordinate derivative forward_mul",
3185 )
3186 }
3187}
3188
3189impl ImplicitDerivativeOp {
3190 pub(crate) fn materialize_local(&self) -> Array2<f64> {
3191 match self.level {
3192 ImplicitDerivLevel::First(axis) => self.operator.materialize_first(axis).expect(
3193 "radial scalar evaluation failed during implicit derivative materialization",
3194 ),
3195 ImplicitDerivLevel::SecondDiag(axis) => {
3196 self.operator.materialize_second_diag(axis).expect(
3197 "radial scalar evaluation failed during implicit derivative materialization",
3198 )
3199 }
3200 ImplicitDerivLevel::SecondCross(d, e) => {
3201 self.operator.materialize_second_cross(d, e).expect(
3202 "radial scalar evaluation failed during implicit derivative materialization",
3203 )
3204 }
3205 }
3206 }
3207
3208 pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3209 self.cached_dense.get_or_compute(|| {
3210 let local = self.materialize_local();
3211 let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3212 out.slice_mut(s![.., self.global_range.clone()])
3213 .assign(&local);
3214 out
3215 })
3216 }
3217
3218 pub(crate) fn nrows(&self) -> usize {
3219 self.operator.n_data()
3220 }
3221
3222 pub(crate) fn ncols(&self) -> usize {
3223 self.total_dim
3224 }
3225
3226 pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3227 let local = match self.level {
3228 ImplicitDerivLevel::First(axis) => self
3229 .operator
3230 .transpose_mul(axis, &v.view())
3231 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3232 ImplicitDerivLevel::SecondDiag(axis) => self
3233 .operator
3234 .transpose_mul_second_diag(axis, &v.view())
3235 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3236 ImplicitDerivLevel::SecondCross(d, e) => self
3237 .operator
3238 .transpose_mul_second_cross(d, e, &v.view())
3239 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3240 };
3241 let mut out = Array1::<f64>::zeros(self.total_dim);
3242 out.slice_mut(s![self.global_range.clone()]).assign(&local);
3243 out
3244 }
3245
3246 pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3247 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3248 match self.level {
3249 ImplicitDerivLevel::First(axis) => self
3250 .operator
3251 .forward_mul(axis, &u_local.view())
3252 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3253 ImplicitDerivLevel::SecondDiag(axis) => self
3254 .operator
3255 .forward_mul_second_diag(axis, &u_local.view())
3256 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3257 ImplicitDerivLevel::SecondCross(d, e) => self
3258 .operator
3259 .forward_mul_second_cross(d, e, &u_local.view())
3260 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3261 }
3262 }
3263}
3264
3265#[derive(Clone)]
3266pub(crate) struct EmbeddedDerivativeMatrix {
3267 pub(crate) local: Array2<f64>,
3268 pub(crate) global_range: Range<usize>,
3269 pub(crate) total_dim: usize,
3270}
3271
3272impl EmbeddedDerivativeMatrix {
3273 pub(crate) fn new(local: Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
3274 Self {
3275 local,
3276 global_range,
3277 total_dim,
3278 }
3279 }
3280}
3281
3282impl DerivativeStorageBackend for Array2<f64> {
3283 fn resident_byte_count(&self) -> usize {
3284 self.len().saturating_mul(std::mem::size_of::<f64>())
3285 }
3286 fn design_nrows(&self) -> usize {
3287 Array2::nrows(self)
3288 }
3289 fn design_ncols(&self) -> usize {
3290 Array2::ncols(self)
3291 }
3292 fn penalty_dim(&self) -> usize {
3293 Array2::nrows(self)
3294 }
3295 fn uses_implicit_storage(&self) -> bool {
3296 false
3297 }
3298 fn any_nonzero(&self) -> bool {
3299 self.iter().any(|v| *v != 0.0)
3300 }
3301 fn materialize(&self) -> Array2<f64> {
3302 self.clone()
3303 }
3304 fn implicit_first_axis_info(
3305 &self,
3306 ) -> Option<(
3307 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3308 usize,
3309 )> {
3310 None
3311 }
3312 fn implicit_axis_count_hint(&self) -> Option<usize> {
3313 None
3314 }
3315
3316 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3317 if Array2::ncols(self) != u.len() {
3318 crate::bail_invalid_estim!(
3319 "dense hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3320 Array2::nrows(self),
3321 Array2::ncols(self),
3322 u.len()
3323 );
3324 }
3325 Ok(self.dot(u))
3326 }
3327
3328 fn design_transpose_mul_original(
3329 &self,
3330 v: &Array1<f64>,
3331 ) -> Result<Array1<f64>, EstimationError> {
3332 if Array2::nrows(self) != v.len() {
3333 crate::bail_invalid_estim!(
3334 "dense hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3335 Array2::nrows(self),
3336 Array2::ncols(self),
3337 v.len()
3338 );
3339 }
3340 Ok(self.t().dot(v))
3341 }
3342
3343 fn design_transformed(
3344 &self,
3345 qs: &Array2<f64>,
3346 free_basis_opt: Option<&Array2<f64>>,
3347 ) -> Result<Array2<f64>, EstimationError> {
3348 Ok(gam_linalg::matrix::DenseRightProductView::new(self)
3349 .with_factor(qs)
3350 .with_optional_factor(free_basis_opt)
3351 .materialize())
3352 }
3353
3354 fn penalty_transformed(
3355 &self,
3356 qs: &Array2<f64>,
3357 free_basis_opt: Option<&Array2<f64>>,
3358 ) -> Result<Array2<f64>, EstimationError> {
3359 let mut transformed = qs.t().dot(self).dot(qs);
3360 if let Some(z) = free_basis_opt {
3361 transformed = z.t().dot(&transformed).dot(z);
3362 }
3363 Ok(transformed)
3364 }
3365
3366 fn penalty_scaled_add_to(
3367 &self,
3368 target: &mut Array2<f64>,
3369 amp: f64,
3370 ) -> Result<(), EstimationError> {
3371 if target.raw_dim() != self.raw_dim() {
3372 crate::bail_invalid_estim!(
3373 "dense hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3374 target.nrows(),
3375 target.ncols(),
3376 Array2::nrows(self),
3377 Array2::ncols(self)
3378 );
3379 }
3380 target.scaled_add(amp, self);
3381 Ok(())
3382 }
3383}
3384
3385impl DerivativeStorageBackend for ZeroDerivativeMatrix {
3386 fn resident_byte_count(&self) -> usize {
3387 0
3388 }
3389 fn design_nrows(&self) -> usize {
3390 self.rows
3391 }
3392 fn design_ncols(&self) -> usize {
3393 self.cols
3394 }
3395 fn penalty_dim(&self) -> usize {
3396 self.cols
3397 }
3398 fn uses_implicit_storage(&self) -> bool {
3399 false
3400 }
3401 fn any_nonzero(&self) -> bool {
3402 false
3403 }
3404 fn materialize(&self) -> Array2<f64> {
3405 Array2::<f64>::zeros((self.rows, self.cols))
3406 }
3407 fn implicit_first_axis_info(
3408 &self,
3409 ) -> Option<(
3410 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3411 usize,
3412 )> {
3413 None
3414 }
3415 fn implicit_axis_count_hint(&self) -> Option<usize> {
3416 None
3417 }
3418
3419 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3420 if self.cols != u.len() {
3421 crate::bail_invalid_estim!(
3422 "zero hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3423 self.rows,
3424 self.cols,
3425 u.len()
3426 );
3427 }
3428 Ok(Array1::<f64>::zeros(self.rows))
3429 }
3430
3431 fn design_transpose_mul_original(
3432 &self,
3433 v: &Array1<f64>,
3434 ) -> Result<Array1<f64>, EstimationError> {
3435 if self.rows != v.len() {
3436 crate::bail_invalid_estim!(
3437 "zero hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3438 self.rows,
3439 self.cols,
3440 v.len()
3441 );
3442 }
3443 Ok(Array1::<f64>::zeros(self.cols))
3444 }
3445
3446 fn design_transformed(
3447 &self,
3448 qs: &Array2<f64>,
3449 free_basis_opt: Option<&Array2<f64>>,
3450 ) -> Result<Array2<f64>, EstimationError> {
3451 if self.cols != qs.nrows() {
3452 crate::bail_invalid_estim!(
3453 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3454 self.cols,
3455 qs.nrows()
3456 );
3457 }
3458 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3459 Ok(Array2::<f64>::zeros((self.rows, cols)))
3460 }
3461
3462 fn design_transformed_forward_mul(
3463 &self,
3464 qs: &Array2<f64>,
3465 free_basis_opt: Option<&Array2<f64>>,
3466 u: &Array1<f64>,
3467 ) -> Result<Array1<f64>, EstimationError> {
3468 if self.cols != qs.nrows() {
3469 crate::bail_invalid_estim!(
3470 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3471 self.cols,
3472 qs.nrows()
3473 );
3474 }
3475 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3476 if u.len() != cols {
3477 crate::bail_invalid_estim!(
3478 "zero design derivative transformed forward width mismatch: expected {}, vector={}",
3479 cols,
3480 u.len()
3481 );
3482 }
3483 Ok(Array1::<f64>::zeros(self.rows))
3484 }
3485
3486 fn design_transformed_transpose_mul(
3487 &self,
3488 qs: &Array2<f64>,
3489 free_basis_opt: Option<&Array2<f64>>,
3490 v: &Array1<f64>,
3491 ) -> Result<Array1<f64>, EstimationError> {
3492 if self.rows != v.len() {
3493 crate::bail_invalid_estim!(
3494 "zero design derivative transpose height mismatch: matrix rows={}, vector={}",
3495 self.rows,
3496 v.len()
3497 );
3498 }
3499 if self.cols != qs.nrows() {
3500 crate::bail_invalid_estim!(
3501 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3502 self.cols,
3503 qs.nrows()
3504 );
3505 }
3506 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3507 Ok(Array1::<f64>::zeros(cols))
3508 }
3509
3510 fn penalty_transformed(
3511 &self,
3512 qs: &Array2<f64>,
3513 free_basis_opt: Option<&Array2<f64>>,
3514 ) -> Result<Array2<f64>, EstimationError> {
3515 if self.cols != qs.nrows() {
3516 crate::bail_invalid_estim!(
3517 "zero penalty derivative width mismatch: total_dim={}, qs rows={}",
3518 self.cols,
3519 qs.nrows()
3520 );
3521 }
3522 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3523 Ok(Array2::<f64>::zeros((cols, cols)))
3524 }
3525
3526 fn penalty_scaled_add_to(
3527 &self,
3528 target: &mut Array2<f64>,
3529 amp: f64,
3530 ) -> Result<(), EstimationError> {
3531 if !amp.is_finite() {
3535 crate::bail_invalid_estim!(
3536 "zero hyper penalty derivative received non-finite amp={amp}"
3537 );
3538 }
3539 if target.nrows() != self.cols || target.ncols() != self.cols {
3540 crate::bail_invalid_estim!(
3541 "zero hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3542 target.nrows(),
3543 target.ncols(),
3544 self.cols,
3545 self.cols
3546 );
3547 }
3548 Ok(())
3549 }
3550}
3551
3552impl DerivativeStorageBackend for EmbeddedDerivativeMatrix {
3553 fn resident_byte_count(&self) -> usize {
3554 self.local.len().saturating_mul(std::mem::size_of::<f64>())
3555 }
3556 fn design_nrows(&self) -> usize {
3557 self.local.nrows()
3558 }
3559 fn design_ncols(&self) -> usize {
3560 self.total_dim
3561 }
3562 fn penalty_dim(&self) -> usize {
3563 self.total_dim
3564 }
3565 fn uses_implicit_storage(&self) -> bool {
3566 false
3567 }
3568 fn any_nonzero(&self) -> bool {
3569 self.local.iter().any(|v| *v != 0.0)
3570 }
3571 fn materialize(&self) -> Array2<f64> {
3572 let mut dense = Array2::<f64>::zeros((self.local.nrows(), self.total_dim));
3573 dense
3574 .slice_mut(s![.., self.global_range.clone()])
3575 .assign(&self.local);
3576 dense
3577 }
3578 fn implicit_first_axis_info(
3579 &self,
3580 ) -> Option<(
3581 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3582 usize,
3583 )> {
3584 None
3585 }
3586 fn implicit_axis_count_hint(&self) -> Option<usize> {
3587 None
3588 }
3589
3590 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3591 if self.total_dim != u.len() {
3592 crate::bail_invalid_estim!(
3593 "embedded hyper design derivative forward_mul_original width mismatch: total_dim={}, vector={}",
3594 self.total_dim,
3595 u.len()
3596 );
3597 }
3598 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3599 Ok(self.local.dot(&u_local))
3600 }
3601
3602 fn design_transpose_mul_original(
3603 &self,
3604 v: &Array1<f64>,
3605 ) -> Result<Array1<f64>, EstimationError> {
3606 if self.local.nrows() != v.len() {
3607 crate::bail_invalid_estim!(
3608 "embedded hyper design derivative transpose_mul_original height mismatch: local_rows={}, vector={}",
3609 self.local.nrows(),
3610 v.len()
3611 );
3612 }
3613 let mut out = Array1::<f64>::zeros(self.total_dim);
3614 let pulled = self.local.t().dot(v);
3615 out.slice_mut(s![self.global_range.clone()]).assign(&pulled);
3616 Ok(out)
3617 }
3618
3619 fn design_transformed(
3620 &self,
3621 qs: &Array2<f64>,
3622 free_basis_opt: Option<&Array2<f64>>,
3623 ) -> Result<Array2<f64>, EstimationError> {
3624 if self.total_dim != qs.nrows() {
3625 crate::bail_invalid_estim!(
3626 "embedded design derivative width mismatch: total_cols={}, qs rows={}",
3627 self.total_dim,
3628 qs.nrows()
3629 );
3630 }
3631 let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3632 let mut transformed = self.local.dot(&qs_local);
3633 if let Some(z) = free_basis_opt {
3634 transformed = transformed.dot(z);
3635 }
3636 Ok(transformed)
3637 }
3638
3639 fn penalty_transformed(
3640 &self,
3641 qs: &Array2<f64>,
3642 free_basis_opt: Option<&Array2<f64>>,
3643 ) -> Result<Array2<f64>, EstimationError> {
3644 if self.total_dim != qs.nrows() {
3645 crate::bail_invalid_estim!(
3646 "embedded penalty derivative width mismatch: total_dim={}, qs rows={}",
3647 self.total_dim,
3648 qs.nrows()
3649 );
3650 }
3651 let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3652 let mut transformed = qs_local.t().dot(&self.local).dot(&qs_local);
3653 if let Some(z) = free_basis_opt {
3654 transformed = z.t().dot(&transformed).dot(z);
3655 }
3656 Ok(transformed)
3657 }
3658
3659 fn penalty_scaled_add_to(
3660 &self,
3661 target: &mut Array2<f64>,
3662 amp: f64,
3663 ) -> Result<(), EstimationError> {
3664 if target.nrows() != self.total_dim || target.ncols() != self.total_dim {
3665 crate::bail_invalid_estim!(
3666 "embedded hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3667 target.nrows(),
3668 target.ncols(),
3669 self.total_dim,
3670 self.total_dim
3671 );
3672 }
3673 target
3674 .slice_mut(s![self.global_range.clone(), self.global_range.clone()])
3675 .scaled_add(amp, &self.local);
3676 Ok(())
3677 }
3678}
3679
3680impl DerivativeStorageBackend for ImplicitDerivativeOp {
3681 fn resident_byte_count(&self) -> usize {
3682 0
3683 }
3684 fn design_nrows(&self) -> usize {
3685 self.nrows()
3686 }
3687 fn design_ncols(&self) -> usize {
3688 self.ncols()
3689 }
3690 fn penalty_dim(&self) -> usize {
3691 self.nrows()
3692 }
3693 fn uses_implicit_storage(&self) -> bool {
3694 true
3695 }
3696 fn any_nonzero(&self) -> bool {
3697 true
3698 }
3699 fn materialize(&self) -> Array2<f64> {
3700 self.materialize_dense().clone()
3701 }
3702 fn implicit_first_axis_info(
3703 &self,
3704 ) -> Option<(
3705 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3706 usize,
3707 )> {
3708 match self.level {
3709 ImplicitDerivLevel::First(axis) => Some((self.operator.clone(), axis)),
3710 _ => None,
3711 }
3712 }
3713 fn implicit_axis_count_hint(&self) -> Option<usize> {
3714 Some(self.operator.n_axes())
3715 }
3716
3717 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3718 if self.ncols() != u.len() {
3719 crate::bail_invalid_estim!(
3720 "implicit hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3721 self.ncols(),
3722 u.len()
3723 );
3724 }
3725 Ok(self.forward_mul(u))
3726 }
3727
3728 fn design_transpose_mul_original(
3729 &self,
3730 v: &Array1<f64>,
3731 ) -> Result<Array1<f64>, EstimationError> {
3732 if self.nrows() != v.len() {
3733 crate::bail_invalid_estim!(
3734 "implicit hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3735 self.nrows(),
3736 v.len()
3737 );
3738 }
3739 Ok(self.transpose_mul(v))
3740 }
3741
3742 fn design_transformed(
3743 &self,
3744 qs: &Array2<f64>,
3745 free_basis_opt: Option<&Array2<f64>>,
3746 ) -> Result<Array2<f64>, EstimationError> {
3747 let dense = self.materialize_dense();
3748 Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3749 .with_factor(qs)
3750 .with_optional_factor(free_basis_opt)
3751 .materialize())
3752 }
3753
3754 fn design_transformed_forward_mul(
3755 &self,
3756 qs: &Array2<f64>,
3757 free_basis_opt: Option<&Array2<f64>>,
3758 u: &Array1<f64>,
3759 ) -> Result<Array1<f64>, EstimationError> {
3760 let mut right = if let Some(z) = free_basis_opt {
3761 z.dot(u)
3762 } else {
3763 u.clone()
3764 };
3765 right = qs.dot(&right);
3766 Ok(self.forward_mul(&right))
3767 }
3768
3769 fn design_transformed_transpose_mul(
3770 &self,
3771 qs: &Array2<f64>,
3772 free_basis_opt: Option<&Array2<f64>>,
3773 v: &Array1<f64>,
3774 ) -> Result<Array1<f64>, EstimationError> {
3775 let mut pulled = qs.t().dot(&self.transpose_mul(v));
3776 if let Some(z) = free_basis_opt {
3777 pulled = z.t().dot(&pulled);
3778 }
3779 Ok(pulled)
3780 }
3781
3782 fn penalty_transformed(
3783 &self,
3784 qs: &Array2<f64>,
3785 free_basis_opt: Option<&Array2<f64>>,
3786 ) -> Result<Array2<f64>, EstimationError> {
3787 let dense = self.materialize_dense();
3788 let mut transformed = qs.t().dot(dense).dot(qs);
3789 if let Some(z) = free_basis_opt {
3790 transformed = z.t().dot(&transformed).dot(z);
3791 }
3792 Ok(transformed)
3793 }
3794
3795 fn penalty_scaled_add_to(
3796 &self,
3797 target: &mut Array2<f64>,
3798 amp: f64,
3799 ) -> Result<(), EstimationError> {
3800 let dense = self.materialize_dense();
3801 if target.raw_dim() != dense.raw_dim() {
3802 crate::bail_invalid_estim!(
3803 "implicit hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3804 target.nrows(),
3805 target.ncols(),
3806 dense.nrows(),
3807 dense.ncols()
3808 );
3809 }
3810 target.scaled_add(amp, dense);
3811 Ok(())
3812 }
3813}
3814
3815impl DerivativeStorageBackend for LatentCoordDerivativeOp {
3816 fn resident_byte_count(&self) -> usize {
3817 0
3818 }
3819 fn design_nrows(&self) -> usize {
3820 self.nrows()
3821 }
3822 fn design_ncols(&self) -> usize {
3823 self.ncols()
3824 }
3825 fn penalty_dim(&self) -> usize {
3826 self.nrows()
3827 }
3828 fn uses_implicit_storage(&self) -> bool {
3829 true
3830 }
3831 fn any_nonzero(&self) -> bool {
3832 true
3833 }
3834 fn materialize(&self) -> Array2<f64> {
3835 self.materialize_dense().clone()
3836 }
3837 fn implicit_first_axis_info(
3838 &self,
3839 ) -> Option<(
3840 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3841 usize,
3842 )> {
3843 None
3844 }
3845 fn implicit_axis_count_hint(&self) -> Option<usize> {
3846 Some(self.operator.n_axes())
3847 }
3848
3849 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3850 if self.ncols() != u.len() {
3851 crate::bail_invalid_estim!(
3852 "latent-coordinate hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3853 self.ncols(),
3854 u.len()
3855 );
3856 }
3857 Ok(self.forward_mul(u))
3858 }
3859
3860 fn design_transpose_mul_original(
3861 &self,
3862 v: &Array1<f64>,
3863 ) -> Result<Array1<f64>, EstimationError> {
3864 if self.nrows() != v.len() {
3865 crate::bail_invalid_estim!(
3866 "latent-coordinate hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3867 self.nrows(),
3868 v.len()
3869 );
3870 }
3871 Ok(self.transpose_mul(v))
3872 }
3873
3874 fn design_transformed(
3875 &self,
3876 qs: &Array2<f64>,
3877 free_basis_opt: Option<&Array2<f64>>,
3878 ) -> Result<Array2<f64>, EstimationError> {
3879 let dense = self.materialize_dense();
3880 Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3881 .with_factor(qs)
3882 .with_optional_factor(free_basis_opt)
3883 .materialize())
3884 }
3885
3886 fn design_transformed_forward_mul(
3887 &self,
3888 qs: &Array2<f64>,
3889 free_basis_opt: Option<&Array2<f64>>,
3890 u: &Array1<f64>,
3891 ) -> Result<Array1<f64>, EstimationError> {
3892 let mut right = if let Some(z) = free_basis_opt {
3893 z.dot(u)
3894 } else {
3895 u.clone()
3896 };
3897 right = qs.dot(&right);
3898 Ok(self.forward_mul(&right))
3899 }
3900
3901 fn design_transformed_transpose_mul(
3902 &self,
3903 qs: &Array2<f64>,
3904 free_basis_opt: Option<&Array2<f64>>,
3905 v: &Array1<f64>,
3906 ) -> Result<Array1<f64>, EstimationError> {
3907 let mut pulled = qs.t().dot(&self.transpose_mul(v));
3908 if let Some(z) = free_basis_opt {
3909 pulled = z.t().dot(&pulled);
3910 }
3911 Ok(pulled)
3912 }
3913
3914 fn penalty_transformed(
3915 &self,
3916 qs: &Array2<f64>,
3917 free_basis_opt: Option<&Array2<f64>>,
3918 ) -> Result<Array2<f64>, EstimationError> {
3919 let dense = self.materialize_dense();
3920 let mut transformed = qs.t().dot(dense).dot(qs);
3921 if let Some(z) = free_basis_opt {
3922 transformed = z.t().dot(&transformed).dot(z);
3923 }
3924 Ok(transformed)
3925 }
3926
3927 fn penalty_scaled_add_to(
3928 &self,
3929 target: &mut Array2<f64>,
3930 amp: f64,
3931 ) -> Result<(), EstimationError> {
3932 let dense = self.materialize_dense();
3933 if target.raw_dim() != dense.raw_dim() {
3934 crate::bail_invalid_estim!(
3935 "latent-coordinate hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3936 target.nrows(),
3937 target.ncols(),
3938 dense.nrows(),
3939 dense.ncols()
3940 );
3941 }
3942 target.scaled_add(amp, dense);
3943 Ok(())
3944 }
3945}
3946
3947#[derive(Clone)]
3948pub struct HyperDesignDerivative {
3949 pub(crate) storage: DerivativeMatrixStorage,
3950}
3951
3952impl HyperDesignDerivative {
3953 pub fn zero(nrows: usize, ncols: usize) -> Self {
3954 Self {
3955 storage: DerivativeMatrixStorage::Zero(ZeroDerivativeMatrix::new(nrows, ncols)),
3956 }
3957 }
3958
3959 pub fn from_embedded(
3960 local: Array2<f64>,
3961 global_range: Range<usize>,
3962 total_cols: usize,
3963 ) -> Self {
3964 Self {
3965 storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3966 local,
3967 global_range,
3968 total_cols,
3969 )),
3970 }
3971 }
3972
3973 pub fn from_implicit(
3974 operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3975 level: ImplicitDerivLevel,
3976 global_range: Range<usize>,
3977 total_cols: usize,
3978 ) -> Self {
3979 Self {
3980 storage: DerivativeMatrixStorage::Implicit(ImplicitDerivativeOp {
3981 operator,
3982 level,
3983 global_range,
3984 total_dim: total_cols,
3985 cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3986 }),
3987 }
3988 }
3989
3990 pub fn from_latent_coord(
3991 operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
3992 flat_axis: usize,
3993 global_range: Range<usize>,
3994 total_cols: usize,
3995 ) -> Self {
3996 Self {
3997 storage: DerivativeMatrixStorage::LatentCoord(LatentCoordDerivativeOp {
3998 operator,
3999 flat_axis,
4000 global_range,
4001 total_dim: total_cols,
4002 cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
4003 }),
4004 }
4005 }
4006
4007 pub(crate) fn resident_byte_count(&self) -> usize {
4008 storage_dispatch!(&self.storage, b => b.resident_byte_count())
4009 }
4010
4011 pub(crate) fn nrows(&self) -> usize {
4012 storage_dispatch!(&self.storage, b => b.design_nrows())
4013 }
4014
4015 pub(crate) fn ncols(&self) -> usize {
4016 storage_dispatch!(&self.storage, b => b.design_ncols())
4017 }
4018
4019 pub(crate) fn uses_implicit_storage(&self) -> bool {
4020 storage_dispatch!(&self.storage, b => b.uses_implicit_storage())
4021 }
4022
4023 pub(crate) fn materialize(&self) -> Array2<f64> {
4024 storage_dispatch!(&self.storage, b => b.materialize())
4025 }
4026
4027 pub(crate) fn any_nonzero(&self) -> bool {
4028 storage_dispatch!(&self.storage, b => b.any_nonzero())
4029 }
4030
4031 pub(crate) fn forward_mul_original(
4032 &self,
4033 u: &Array1<f64>,
4034 ) -> Result<Array1<f64>, EstimationError> {
4035 storage_dispatch!(&self.storage, b => b.design_forward_mul_original(u))
4036 }
4037
4038 pub(crate) fn transpose_mul_original(
4039 &self,
4040 v: &Array1<f64>,
4041 ) -> Result<Array1<f64>, EstimationError> {
4042 storage_dispatch!(&self.storage, b => b.design_transpose_mul_original(v))
4043 }
4044
4045 pub(crate) fn transformed(
4046 &self,
4047 qs: &Array2<f64>,
4048 free_basis_opt: Option<&Array2<f64>>,
4049 ) -> Result<Array2<f64>, EstimationError> {
4050 storage_dispatch!(&self.storage, b => b.design_transformed(qs, free_basis_opt))
4051 }
4052
4053 pub(crate) fn transformed_forward_mul(
4054 &self,
4055 qs: &Array2<f64>,
4056 free_basis_opt: Option<&Array2<f64>>,
4057 u: &Array1<f64>,
4058 ) -> Result<Array1<f64>, EstimationError> {
4059 storage_dispatch!(&self.storage, b => b.design_transformed_forward_mul(qs, free_basis_opt, u))
4060 }
4061
4062 pub(crate) fn transformed_transpose_mul(
4063 &self,
4064 qs: &Array2<f64>,
4065 free_basis_opt: Option<&Array2<f64>>,
4066 v: &Array1<f64>,
4067 ) -> Result<Array1<f64>, EstimationError> {
4068 storage_dispatch!(&self.storage, b => b.design_transformed_transpose_mul(qs, free_basis_opt, v))
4069 }
4070
4071 pub(crate) fn implicit_first_axis_info(
4076 &self,
4077 ) -> Option<(
4078 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4079 usize,
4080 )> {
4081 storage_dispatch!(&self.storage, b => b.implicit_first_axis_info())
4082 }
4083
4084 pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4085 storage_dispatch!(&self.storage, b => b.implicit_axis_count_hint())
4086 }
4087}
4088
4089impl From<Array2<f64>> for HyperDesignDerivative {
4090 fn from(value: Array2<f64>) -> Self {
4091 Self {
4092 storage: DerivativeMatrixStorage::Dense(value),
4093 }
4094 }
4095}
4096
4097#[derive(Clone)]
4098pub struct HyperPenaltyDerivative {
4099 pub(crate) storage: DerivativeMatrixStorage,
4100}
4101
4102impl HyperPenaltyDerivative {
4103 pub fn from_embedded(
4104 local: Array2<f64>,
4105 global_range: Range<usize>,
4106 total_dim: usize,
4107 ) -> Self {
4108 Self {
4109 storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
4110 local,
4111 global_range,
4112 total_dim,
4113 )),
4114 }
4115 }
4116
4117 pub(crate) fn resident_byte_count(&self) -> usize {
4118 storage_dispatch!(&self.storage, b => b.resident_byte_count())
4119 }
4120
4121 pub(crate) fn nrows(&self) -> usize {
4122 storage_dispatch!(&self.storage, b => b.penalty_dim())
4123 }
4124
4125 pub(crate) fn ncols(&self) -> usize {
4126 self.nrows()
4127 }
4128
4129 pub(crate) fn scaled_materialize(&self, amp: f64) -> Array2<f64> {
4130 let mut out = Array2::<f64>::zeros((self.nrows(), self.ncols()));
4131 self.scaled_add_to(&mut out, amp)
4132 .expect("scaled materialize uses matching target shape");
4133 out
4134 }
4135
4136 pub(crate) fn transformed(
4137 &self,
4138 qs: &Array2<f64>,
4139 free_basis_opt: Option<&Array2<f64>>,
4140 ) -> Result<Array2<f64>, EstimationError> {
4141 storage_dispatch!(&self.storage, b => b.penalty_transformed(qs, free_basis_opt))
4142 }
4143
4144 pub(crate) fn scaled_add_to(
4145 &self,
4146 target: &mut Array2<f64>,
4147 amp: f64,
4148 ) -> Result<(), EstimationError> {
4149 storage_dispatch!(&self.storage, b => b.penalty_scaled_add_to(target, amp))
4150 }
4151}
4152
4153impl From<Array2<f64>> for HyperPenaltyDerivative {
4154 fn from(value: Array2<f64>) -> Self {
4155 Self {
4156 storage: DerivativeMatrixStorage::Dense(value),
4157 }
4158 }
4159}
4160
/// One penalty-derivative matrix tagged with the penalty it belongs to.
#[derive(Clone)]
pub struct PenaltyDerivativeComponent {
    /// Index of the penalty; `penalty_total_at` uses it to index `rho`.
    pub penalty_index: usize,
    /// Derivative matrix for that penalty.
    pub matrix: HyperPenaltyDerivative,
}
4166
/// Derivative data attached to one hyperparameter direction: the design
/// derivative, per-penalty first derivatives, and optional second-order
/// entries indexed by a partner direction `j`.
#[derive(Clone)]
pub struct DirectionalHyperParam {
    /// First derivative of the design in the original coordinates.
    pub(crate) x_tau_original: HyperDesignDerivative,
    /// First-order penalty derivatives, at most one per `penalty_index`.
    pub(crate) penalty_first_components: Vec<PenaltyDerivativeComponent>,
    /// Optional second design derivatives, looked up by partner index;
    /// individual entries may be absent.
    pub(crate) x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
    /// Optional stored second-order penalty components, looked up by partner
    /// index; individual rows may be absent.
    pub(crate) penaltysecond_components: Option<Vec<Option<Vec<PenaltyDerivativeComponent>>>>,
    /// Optional callback consulted by `penaltysecond_components_for` when no
    /// stored row exists for the requested partner index.
    pub(crate) penaltysecond_component_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
                + Send
                + Sync
                + 'static,
        >,
    >,
    /// Optional list of partner indices; `has_penaltysecond_pair_at` reports
    /// `true` for any index it contains.
    pub(crate) penaltysecond_partner_indices: Option<std::sync::Arc<[usize]>>,
    /// Set by `new_compact` to `true` when the design derivative reports no
    /// nonzero content; `not_penalty_like` clears it.
    pub(crate) is_penalty_like: bool,
}
4194
4195impl DirectionalHyperParam {
4196 pub(crate) fn resident_byte_count(&self) -> usize {
4197 let mut bytes = self.x_tau_original.resident_byte_count();
4198 for component in &self.penalty_first_components {
4199 bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4200 }
4201 if let Some(entries) = self.x_tau_tau_original.as_ref() {
4202 for entry in entries.iter().flatten() {
4203 bytes = bytes.saturating_add(entry.resident_byte_count());
4204 }
4205 }
4206 if let Some(rows) = self.penaltysecond_components.as_ref() {
4207 for components in rows.iter().flatten() {
4208 for component in components {
4209 bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4210 }
4211 }
4212 }
4213 bytes
4214 }
4215
4216 pub(crate) fn canonicalize_penalty_components(
4217 components: Vec<(usize, HyperPenaltyDerivative)>,
4218 ) -> Result<Vec<PenaltyDerivativeComponent>, EstimationError> {
4219 let mut out: Vec<PenaltyDerivativeComponent> = Vec::with_capacity(components.len());
4220 for (penalty_index, matrix) in components {
4221 if out.iter().any(|c| c.penalty_index == penalty_index) {
4222 crate::bail_invalid_estim!(
4223 "duplicate penalty derivative component for penalty {}",
4224 penalty_index
4225 );
4226 }
4227 out.push(PenaltyDerivativeComponent {
4228 penalty_index,
4229 matrix,
4230 });
4231 }
4232 Ok(out)
4233 }
4234
4235 pub fn new_compact(
4236 x_tau_original: HyperDesignDerivative,
4237 penalty_first_components: Vec<(usize, HyperPenaltyDerivative)>,
4238 x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
4239 penaltysecond_components: Option<Vec<Option<Vec<(usize, HyperPenaltyDerivative)>>>>,
4240 ) -> Result<Self, EstimationError> {
4241 let is_penalty_like = !x_tau_original.any_nonzero();
4242 let penalty_first_components =
4243 Self::canonicalize_penalty_components(penalty_first_components)?;
4244 let penaltysecond_components = match penaltysecond_components {
4245 Some(rows) => {
4246 let mut out = Vec::with_capacity(rows.len());
4247 for row in rows {
4248 out.push(match row {
4249 Some(components) => {
4250 Some(Self::canonicalize_penalty_components(components)?)
4251 }
4252 None => None,
4253 });
4254 }
4255 Some(out)
4256 }
4257 None => None,
4258 };
4259 Ok(Self {
4260 x_tau_original,
4261 penalty_first_components,
4262 x_tau_tau_original,
4263 penaltysecond_components,
4264 penaltysecond_component_provider: None,
4265 penaltysecond_partner_indices: None,
4266 is_penalty_like,
4267 })
4268 }
4269
4270 pub fn not_penalty_like(mut self) -> Self {
4273 self.is_penalty_like = false;
4274 self
4275 }
4276
4277 pub fn with_penaltysecond_component_provider(
4278 mut self,
4279 provider: std::sync::Arc<
4280 dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
4281 + Send
4282 + Sync
4283 + 'static,
4284 >,
4285 ) -> Self {
4286 self.penaltysecond_component_provider = Some(provider);
4287 self
4288 }
4289
4290 pub fn with_penaltysecond_partner_indices(mut self, partners: Vec<usize>) -> Self {
4291 self.penaltysecond_partner_indices = Some(std::sync::Arc::from(partners));
4292 self
4293 }
4294
4295 pub(crate) fn x_tau_dense(&self) -> Array2<f64> {
4296 self.x_tau_original.materialize()
4297 }
4298
4299 pub(crate) fn transformed_x_tau(
4300 &self,
4301 qs: &Array2<f64>,
4302 free_basis_opt: Option<&Array2<f64>>,
4303 ) -> Result<Array2<f64>, EstimationError> {
4304 self.x_tau_original.transformed(qs, free_basis_opt)
4305 }
4306
4307 pub(crate) fn x_tau_tau_entry_at(&self, j: usize) -> Option<HyperDesignDerivative> {
4308 self.x_tau_tau_original
4309 .as_ref()
4310 .and_then(|rows| rows.get(j))
4311 .and_then(|entry| entry.clone())
4312 }
4313
4314 pub(crate) fn has_implicit_operator(&self) -> bool {
4317 self.x_tau_original.uses_implicit_storage()
4318 }
4319
4320 pub(crate) fn has_implicit_multidim_duchon(&self) -> bool {
4321 self.implicit_first_axis_info()
4322 .is_some_and(|(op, _)| op.n_axes() > 1 && op.is_duchon_family())
4323 }
4324
4325 pub(crate) fn implicit_first_axis_info(
4327 &self,
4328 ) -> Option<(
4329 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4330 usize,
4331 )> {
4332 self.x_tau_original.implicit_first_axis_info()
4333 }
4334
4335 pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4336 self.x_tau_original.implicit_axis_count_hint()
4337 }
4338
4339 pub(crate) fn penalty_first_components(&self) -> &[PenaltyDerivativeComponent] {
4340 &self.penalty_first_components
4341 }
4342
4343 pub(crate) fn penalty_total_at(
4344 &self,
4345 rho: &Array1<f64>,
4346 p: usize,
4347 ) -> Result<Array2<f64>, EstimationError> {
4348 let mut out = Array2::<f64>::zeros((p, p));
4349 for component in &self.penalty_first_components {
4350 if component.matrix.nrows() != p || component.matrix.ncols() != p {
4351 crate::bail_invalid_estim!(
4352 "S_tau shape mismatch for penalty {}: expected {}x{}, got {}x{}",
4353 component.penalty_index,
4354 p,
4355 p,
4356 component.matrix.nrows(),
4357 component.matrix.ncols()
4358 );
4359 }
4360 if component.penalty_index >= rho.len() {
4361 crate::bail_invalid_estim!(
4362 "penalty_index {} out of bounds for rho dimension {}",
4363 component.penalty_index,
4364 rho.len()
4365 );
4366 }
4367 component
4368 .matrix
4369 .scaled_add_to(&mut out, rho[component.penalty_index].exp())?;
4370 }
4371 Ok(out)
4372 }
4373
4374 pub(crate) fn penaltysecond_components_for(
4375 &self,
4376 j: usize,
4377 ) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError> {
4378 if let Some(components) = self
4379 .penaltysecond_components
4380 .as_ref()
4381 .and_then(|rows| rows.get(j))
4382 .and_then(|row| row.clone())
4383 {
4384 return Ok(Some(components));
4385 }
4386 if let Some(provider) = self.penaltysecond_component_provider.as_ref() {
4387 return provider(j);
4388 }
4389 Ok(None)
4390 }
4391
4392 pub(crate) fn penaltysecond_componentrows(
4393 &self,
4394 ) -> Option<&[Option<Vec<PenaltyDerivativeComponent>>]> {
4395 self.penaltysecond_components.as_deref()
4396 }
4397
4398 pub(crate) fn penalty_first_component_count(&self) -> usize {
4399 self.penalty_first_components.len()
4400 }
4401
4402 pub(crate) fn has_penaltysecond_pair_at(&self, j: usize) -> bool {
4403 self.penaltysecond_components
4404 .as_ref()
4405 .and_then(|rows| rows.get(j))
4406 .is_some_and(Option::is_some)
4407 || self
4408 .penaltysecond_partner_indices
4409 .as_ref()
4410 .is_some_and(|partners| partners.contains(&j))
4411 }
4412}
4413
/// Record of a REML geometry choice together with the figures attached to it.
///
/// NOTE(review): the code that fills this struct is outside this excerpt;
/// field notes marked "presumably" are read from the names only.
#[derive(Clone, Debug)]
pub(crate) struct SparseRemlDecision {
    /// Geometry that was selected.
    pub(crate) geometry: RemlGeometry,
    /// Static description of why it was selected.
    pub(crate) reason: &'static str,
    /// Presumably the coefficient count — confirm.
    pub(crate) p: usize,
    /// Presumably the nonzero count of the design matrix — confirm.
    pub(crate) nnz_x: usize,
    /// Optional estimate; presumably nonzeros in the upper triangle of H — confirm.
    pub(crate) nnz_h_upper_est: Option<usize>,
    /// Optional estimate; presumably the matching density — confirm.
    pub(crate) density_h_upper_est: Option<f64>,
}
4423
/// Shared handles and scalar summaries from a sparse exact factorization.
///
/// NOTE(review): producers and consumers are outside this excerpt; field
/// notes marked "presumably" are read from the names only.
#[derive(Clone)]
pub(crate) struct SparseExactEvalData {
    /// Shared sparse exact factor.
    pub(crate) factor: Arc<SparseExactFactor>,
    /// Optional shared Takahashi inverse.
    pub(crate) takahashi: Option<Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    /// Presumably log|H| — confirm.
    pub(crate) logdet_h: f64,
    /// Presumably the log-determinant of the penalty over its positive part — confirm.
    pub(crate) logdet_s_pos: f64,
    /// Presumably the rank of the penalty — confirm.
    pub(crate) penalty_rank: usize,
    /// Shared vector of first-derivative values; exact meaning not visible here.
    pub(crate) det1_values: Arc<Array1<f64>>,
}
4433
/// Dense precomputed arrays used by the Firth code paths.
///
/// NOTE(review): the code that builds and reads this struct is outside this
/// excerpt. Field notes marked "presumably" are read from the names only;
/// fields without a note have no meaning that can be grounded here.
#[derive(Clone)]
pub struct FirthDenseOperator {
    /// Presumably the dense design matrix — confirm.
    pub(crate) x_dense: Array2<f64>,
    /// Presumably the transpose of `x_dense` — confirm.
    pub(crate) x_dense_t: Array2<f64>,
    /// Presumably a basis used to reduce the design — confirm.
    pub(crate) q_basis: Array2<f64>,
    /// Presumably the design expressed in the reduced basis — confirm.
    pub(crate) x_reduced: Array2<f64>,
    /// Optional per-observation values; presumably square roots of observation weights — confirm.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    pub(crate) k_reduced: Array2<f64>,
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    /// Presumably half of a log-determinant — confirm which matrix.
    pub(crate) half_log_det: f64,
    pub(crate) h_diag: Array1<f64>,
    pub(crate) w: Array1<f64>,
    pub(crate) w1: Array1<f64>,
    pub(crate) w2: Array1<f64>,
    pub(crate) w3: Array1<f64>,
    pub(crate) w4: Array1<f64>,
    pub(crate) b_base: Array2<f64>,
    pub(crate) p_b_base: Array2<f64>,
}
4499
/// Design-side arrays shared by the Firth code paths.
///
/// NOTE(review): producers and consumers are outside this excerpt. Field
/// notes marked "presumably" are read from the names only.
#[derive(Clone)]
pub(crate) struct FirthDesignFactor {
    /// Presumably the dense design matrix — confirm.
    pub(crate) x_dense: Array2<f64>,
    /// Presumably the transpose of `x_dense` — confirm.
    pub(crate) x_dense_t: Array2<f64>,
    /// Presumably a basis used to reduce the design — confirm.
    pub(crate) q_basis: Array2<f64>,
    /// Presumably the design expressed in the reduced basis — confirm.
    pub(crate) x_reduced: Array2<f64>,
    /// Optional per-observation values; presumably square roots of observation weights — confirm.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    /// Presumably the spectrum of a metric matrix — confirm.
    pub(crate) metric_spectrum: Array1<f64>,
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    /// Presumably the reduced dimension — confirm.
    pub(crate) r: usize,
    /// Presumably the observation count — confirm.
    pub(crate) n: usize,
}
4535
/// Bundle of vectors and reduced-space matrices for a single direction on the
/// Firth path.
///
/// NOTE(review): the producer is outside this section. The names suggest
/// directional derivatives (`deta`, `dh`) and their reduced-space companions;
/// confirm against the code that fills this struct.
#[derive(Clone)]
pub(crate) struct FirthDirection {
    /// Vector; presumably the change in the linear predictor — confirm.
    pub(crate) deta: Array1<f64>,
    /// Reduced-space matrices associated with the direction `u`.
    pub(crate) g_u_reduced: Array2<f64>,
    pub(crate) a_u_reduced: Array2<f64>,
    /// Vector; presumably the change in the `h` diagonal — confirm.
    pub(crate) dh: Array1<f64>,
    /// Vector companion of the direction (definition lives with the producer).
    pub(crate) b_uvec: Array1<f64>,
}
4545
/// Partial-derivative quantities for a single `tau` on the Firth path.
///
/// Visibility is mixed on purpose in the original: `deta_partial` and
/// `dot_i_partial` are `pub(super)`, the rest are `pub(crate)`.
///
/// NOTE(review): the producer is outside this section; meanings are inferred
/// from names only — confirm.
#[derive(Clone)]
pub(crate) struct FirthTauPartialKernel {
    /// Vector; presumably the partial of the linear predictor — confirm.
    pub(super) deta_partial: Array1<f64>,
    /// Vectors; presumably `tau`-derivatives of weight terms — confirm.
    pub(crate) dotw1: Array1<f64>,
    pub(crate) dotw2: Array1<f64>,
    /// Vector; presumably the partial of the `h` diagonal — confirm.
    pub(crate) dot_h_partial: Array1<f64>,
    /// Reduced-space matrix tied to the `tau` derivative of the design
    /// (assumption — confirm).
    pub(crate) x_tau_reduced: Array2<f64>,
    /// Matrix partial `dot_i` (definition lives with the producer).
    pub(super) dot_i_partial: Array2<f64>,
    /// Reduced-space matrix `dot_k` (definition lives with the producer).
    pub(crate) dot_k_reduced: Array2<f64>,
}
4561
/// First-order `tau` quantities for the Firth path: a vector, a scalar, and an
/// optional `FirthTauPartialKernel`.
///
/// NOTE(review): the producer is outside this section; the conditions under
/// which `tau_kernel` is `None` are not visible here.
#[derive(Clone)]
pub(crate) struct FirthTauExactKernel {
    /// Vector term (presumably a gradient-like quantity — confirm).
    pub(crate) gphi_tau: Array1<f64>,
    /// Scalar partial term.
    pub(crate) phi_tau_partial: f64,
    /// Optional partial kernel carried along with the two terms above.
    pub(crate) tau_kernel: Option<FirthTauPartialKernel>,
}
4568
/// Second-order (`tau`,`tau`) counterpart of `FirthTauExactKernel`: a scalar,
/// a vector, and an optional `FirthTauTauPartialKernel`.
///
/// All fields are `pub(super)`.
///
/// NOTE(review): the producer is outside this section; the conditions under
/// which `tau_tau_kernel` is `None` are not visible here.
#[derive(Clone)]
pub(crate) struct FirthTauTauExactKernel {
    /// Scalar partial term.
    pub(super) phi_tau_tau_partial: f64,
    /// Vector term (presumably a gradient-like quantity — confirm).
    pub(super) gphi_tau_tau: Array1<f64>,
    /// Optional partial kernel carried along with the two terms above.
    pub(super) tau_tau_kernel: Option<FirthTauTauPartialKernel>,
}
4586
/// Partial-derivative quantities for a pair of `tau` indices (`i`, `j`) on the
/// Firth path.
///
/// Most members come in `_i` / `_j` pairs of identical type. The two trailing
/// members are optional. `Default` is derived, so a default value has empty
/// arrays and `None` for the optional members.
///
/// NOTE(review): the producer is outside this section; meanings are inferred
/// from names only — confirm.
#[derive(Clone, Default)]
pub(crate) struct FirthTauTauPartialKernel {
    /// Reduced-space design-derivative matrices for `i` and `j`.
    pub(super) x_tau_i_reduced: Array2<f64>,
    pub(super) x_tau_j_reduced: Array2<f64>,
    /// Linear-predictor partial vectors for `i` and `j`.
    pub(super) deta_i_partial: Array1<f64>,
    pub(super) deta_j_partial: Array1<f64>,
    /// `h`-diagonal partial vectors for `i` and `j`.
    pub(super) dot_h_i_partial: Array1<f64>,
    pub(super) dot_h_j_partial: Array1<f64>,
    /// Reduced-space `dot_k` matrices for `i` and `j`.
    pub(super) dot_k_i_reduced: Array2<f64>,
    pub(super) dot_k_j_reduced: Array2<f64>,
    /// `dot_i` partial matrices for `i` and `j`.
    pub(super) dot_i_i_partial: Array2<f64>,
    pub(super) dot_i_j_partial: Array2<f64>,
    /// Optional mixed second-derivative terms; `None` when not supplied.
    pub(super) x_tau_tau_reduced: Option<Array2<f64>>,
    pub(super) deta_ij_partial: Option<Array1<f64>>,
}
4614
/// Partial-derivative quantities for a mixed (`tau`, `beta`-direction) term on
/// the Firth path.
///
/// The first five members share names and types with fields of
/// `FirthTauPartialKernel`; the remainder carry `_v` / `d_beta_` names.
/// `Default` is derived, so a default value has empty arrays.
///
/// NOTE(review): the producer is outside this section; meanings are inferred
/// from names only — confirm.
#[derive(Clone, Default)]
pub(crate) struct FirthTauBetaPartialKernel {
    // `tau`-side quantities.
    pub(super) x_tau_reduced: Array2<f64>,
    pub(super) deta_partial: Array1<f64>,
    pub(super) dot_h_partial: Array1<f64>,
    pub(super) dot_i_partial: Array2<f64>,
    pub(super) dot_k_reduced: Array2<f64>,
    // Quantities tied to the direction `v` (presumably a coefficient-space
    // direction — confirm).
    pub(super) deta_v: Array1<f64>,
    pub(super) deta_tau_v: Array1<f64>,
    pub(super) a_v_reduced: Array2<f64>,
    pub(super) dh_v: Array1<f64>,
    pub(super) b_vvec: Array1<f64>,
    // Mixed terms; the names suggest `beta`-derivatives of `dot_k` / `dot_h`.
    pub(super) d_beta_dot_k: Array2<f64>,
    pub(super) d_beta_dot_h: Array1<f64>,
}
4637
/// State shared across the pieces of a single evaluation, identified by an
/// optional key.
///
/// `EvalShared::matches` compares `key` with a requested key. The three
/// `OnceLock` members start empty and are filled at most once each.
/// `penalty_pseudologdet` is filled by
/// `EvalShared::penalty_pseudologdet_original`; the code that fills the other
/// two is outside this section.
#[derive(Clone)]
pub(crate) struct EvalShared {
    /// Key this bundle was built for; `None` means it was built without one.
    pub(crate) key: Option<Vec<u64>>,
    /// Shared inner-solve result.
    pub(crate) pirls_result: Arc<PirlsResult>,
    /// Ridge passport; supplies `penalty_logdet_ridge()` when the penalty
    /// pseudo-logdet is built.
    pub(crate) ridge_passport: RidgePassport,
    /// Geometry of this evaluation; determines the reported backend kind.
    pub(crate) geometry: RemlGeometry,
    /// Shared dense matrix (presumably the total Hessian — confirm).
    pub(crate) h_total: Arc<Array2<f64>>,
    /// Sparse exact data, when present.
    pub(crate) sparse_exact: Option<Arc<SparseExactEvalData>>,
    /// Firth dense operator, when present.
    pub(crate) firth_dense_operator: Option<Arc<FirthDenseOperator>>,
    /// Second Firth dense operator slot; the name suggests it is expressed in
    /// the original coefficient frame — confirm.
    pub(crate) firth_dense_operator_original: Option<Arc<FirthDenseOperator>>,
    /// Lazily built penalty pseudo-logdet, set at most once.
    pub(crate) penalty_pseudologdet: std::sync::OnceLock<Arc<penalty_logdet::PenaltyPseudologdet>>,
    /// Write-once list of vectors (producer not visible in this section).
    pub(crate) penalty_scores_at_mode: std::sync::OnceLock<Arc<Vec<Array1<f64>>>>,
    /// Write-once pair of a `usize` tag and shared correction terms
    /// (producer not visible in this section).
    pub(crate) block_local_correction:
        std::sync::OnceLock<(usize, Arc<outer_eval::TkCorrectionTerms>)>,
}
4710
4711impl EvalShared {
4712 pub(crate) fn matches(&self, key: &Option<Vec<u64>>) -> bool {
4713 match (&self.key, key) {
4714 (None, None) => true,
4715 (Some(a), Some(b)) => a == b,
4716 _ => false,
4717 }
4718 }
4719
4720 pub(crate) fn penalty_pseudologdet_original(
4735 &self,
4736 canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
4737 lambdas: &[f64],
4738 p: usize,
4739 ) -> Result<Arc<penalty_logdet::PenaltyPseudologdet>, EstimationError> {
4740 if let Some(pld) = self.penalty_pseudologdet.get() {
4741 if pld.dim() != p {
4742 return Err(EstimationError::LayoutError(format!(
4743 "shared penalty pseudo-logdet frame mismatch: cached p={}, requested p={}",
4744 pld.dim(),
4745 p
4746 )));
4747 }
4748 return Ok(Arc::clone(pld));
4749 }
4750 let pld = Arc::new(
4751 penalty_logdet::PenaltyPseudologdet::from_penalties(
4752 canonical_penalties,
4753 lambdas,
4754 self.ridge_passport.penalty_logdet_ridge(),
4755 p,
4756 )
4757 .map_err(EstimationError::InvalidInput)?,
4758 );
4759 match self.penalty_pseudologdet.set(Arc::clone(&pld)) {
4760 Ok(()) => Ok(pld),
4761 Err(_) => Ok(Arc::clone(
4765 self.penalty_pseudologdet
4766 .get()
4767 .expect("OnceLock set raced, so it is initialized"),
4768 )),
4769 }
4770 }
4771}
4772
4773impl PenalizedGeometry for EvalShared {
4774 fn backend_kind(&self) -> GeometryBackendKind {
4775 match self.geometry {
4776 RemlGeometry::DenseSpectral => GeometryBackendKind::DenseSpectral,
4777 RemlGeometry::SparseExactSpd => GeometryBackendKind::SparseExactSpd,
4778 }
4779 }
4780}
4781
/// Byte-budgeted cache of inner-solve results with least-recently-used
/// eviction.
///
/// Recency is tracked with a logical clock rather than a linked list:
/// `insert` always advances `clock`, `get` advances it on a hit, and the new
/// value is stamped on the touched entry. Eviction removes the entry with the
/// smallest stamp.
pub(crate) struct PirlsLruCache {
    /// Key -> (cached result, clock value at last touch, byte size charged
    /// for the entry).
    pub(crate) map: HashMap<Vec<u64>, (Arc<PirlsResult>, u64, usize)>,
    /// Upper bound for `current_bytes`; `new` clamps it to at least 1.
    pub(crate) byte_budget: usize,
    /// Sum of the byte sizes charged for the entries currently in `map`.
    pub(crate) current_bytes: usize,
    /// Logical clock used as the recency stamp.
    pub(crate) clock: u64,
}
4798
4799impl PirlsLruCache {
4800 pub(crate) fn new(byte_budget: usize) -> Self {
4801 Self {
4802 map: HashMap::new(),
4803 byte_budget: byte_budget.max(1),
4804 current_bytes: 0,
4805 clock: 0,
4806 }
4807 }
4808
4809 pub(crate) fn get(&mut self, key: &Vec<u64>) -> Option<Arc<PirlsResult>> {
4810 if let Some(entry) = self.map.get_mut(key) {
4811 self.clock += 1;
4812 entry.1 = self.clock;
4813 Some(entry.0.clone())
4814 } else {
4815 None
4816 }
4817 }
4818
4819 pub(crate) fn insert(&mut self, key: Vec<u64>, value: Arc<PirlsResult>) {
4820 self.clock += 1;
4821 let bytes = pirls_result_cache_bytes(&value);
4822 if bytes > self.byte_budget {
4826 if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4827 self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4828 }
4829 return;
4830 }
4831 if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4832 self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4833 }
4834 while self.current_bytes + bytes > self.byte_budget {
4835 let evict_key = self
4836 .map
4837 .iter()
4838 .min_by_key(|(_, (_, ts, _))| *ts)
4839 .map(|(k, _)| k.clone());
4840 match evict_key {
4841 Some(k) => {
4842 if let Some((_, _, evict_bytes)) = self.map.remove(&k) {
4843 self.current_bytes = self.current_bytes.saturating_sub(evict_bytes);
4844 }
4845 }
4846 None => break,
4847 }
4848 }
4849 self.current_bytes += bytes;
4850 self.map.insert(key, (value, self.clock, bytes));
4851 }
4852
4853 pub(crate) fn clear(&mut self) {
4854 self.map.clear();
4855 self.current_bytes = 0;
4856 }
4857}
4858
/// Identity of a cached penalty subspace: two 64-bit hashes, one over the
/// penalty matrix and one over the ridge passport settings.
///
/// Equality compares the hashes only, so two different inputs that hash to
/// the same pair are treated as the same key.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct PenaltySubspaceCacheKey {
    /// Hash of the matrix shape and of every element's bit pattern.
    pub(crate) penalty_matrix_fingerprint: u64,
    /// Hash of the ridge passport fields that `from_inputs` reads.
    pub(crate) ridge_passport_signature: u64,
}
4864
/// Single-slot cache for a penalty subspace: it remembers only the most
/// recently inserted key and value.
pub(crate) struct PenaltySubspaceCache {
    /// The one cached (key, value) pair, or `None` when the cache is empty.
    pub(crate) entry: Option<(PenaltySubspaceCacheKey, Arc<outer_eval::PenaltySubspace>)>,
}
4868
4869impl PenaltySubspaceCache {
4870 pub(crate) fn new() -> Self {
4871 Self { entry: None }
4872 }
4873
4874 pub(crate) fn get(
4875 &self,
4876 key: &PenaltySubspaceCacheKey,
4877 ) -> Option<Arc<outer_eval::PenaltySubspace>> {
4878 self.entry
4879 .as_ref()
4880 .filter(|(cached_key, _)| cached_key == key)
4881 .map(|(_, value)| value.clone())
4882 }
4883
4884 pub(crate) fn insert(
4885 &mut self,
4886 key: PenaltySubspaceCacheKey,
4887 value: Arc<outer_eval::PenaltySubspace>,
4888 ) {
4889 self.entry = Some((key, value));
4890 }
4891
4892 pub(crate) fn clear(&mut self) {
4893 self.entry = None;
4894 }
4895}
4896
4897impl PenaltySubspaceCacheKey {
4898 pub(crate) fn from_inputs(
4903 e_transformed: &ndarray::Array2<f64>,
4904 ridge_passport: &gam_problem::RidgePassport,
4905 ) -> Self {
4906 use std::collections::hash_map::DefaultHasher;
4907 use std::hash::{Hash, Hasher};
4908 let mut hasher = DefaultHasher::new();
4909 e_transformed.nrows().hash(&mut hasher);
4910 e_transformed.ncols().hash(&mut hasher);
4911 for value in e_transformed.iter() {
4912 value.to_bits().hash(&mut hasher);
4913 }
4914 let penalty_matrix_fingerprint = hasher.finish();
4915 let mut ridge_hasher = DefaultHasher::new();
4916 ridge_passport.delta.to_bits().hash(&mut ridge_hasher);
4917 (ridge_passport.matrix_form as u8).hash(&mut ridge_hasher);
4918 ridge_passport
4919 .policy
4920 .include_penalty_logdet
4921 .hash(&mut ridge_hasher);
4922 ridge_passport
4923 .policy
4924 .include_laplacehessian
4925 .hash(&mut ridge_hasher);
4926 let ridge_passport_signature = ridge_hasher.finish();
4927 Self {
4928 penalty_matrix_fingerprint,
4929 ridge_passport_signature,
4930 }
4931 }
4932}
4933
4934pub(crate) fn pirls_result_cache_bytes(result: &PirlsResult) -> usize {
4949 use std::mem::size_of;
4950 let n_array_elems = result.final_eta.len()
4951 + result.solveweights.len()
4952 + result.solveworking_response.len()
4953 + result.solvemu.len()
4954 + result.solve_c_array.len()
4955 + result.solve_d_array.len();
4956 let p = result.beta_transformed.0.len();
4957 let pen_h = symmetric_matrix_cache_bytes(&result.penalized_hessian_transformed);
4958 let stab_h = symmetric_matrix_cache_bytes(&result.stabilizedhessian_transformed);
4959 let reparam = (result.reparam_result.s_transformed.len()
4960 + result.reparam_result.qs.len()
4961 + result.reparam_result.e_transformed.len()
4962 + result.reparam_result.det1.len())
4963 * size_of::<f64>();
4964 n_array_elems * size_of::<f64>() + p * size_of::<f64>() + pen_h + stab_h + reparam + 1024
4965}
4966
4967pub(crate) fn symmetric_matrix_cache_bytes(m: &gam_linalg::matrix::SymmetricMatrix) -> usize {
4968 use gam_linalg::matrix::SymmetricMatrix;
4969 use std::mem::size_of;
4970 match m {
4971 SymmetricMatrix::Dense(a) => a.len() * size_of::<f64>(),
4972 SymmetricMatrix::Sparse(s) => {
4973 let (symbolic, values) = s.parts();
4975 values.len() * (size_of::<f64>() + size_of::<usize>())
4976 + std::mem::size_of_val(symbolic.col_ptr())
4977 }
4978 }
4979}
4980
/// Number of entries the outer-evaluation LRU keeps; `EvalCacheManager::new`
/// passes this to `OuterEvalLru::new`.
pub(crate) const OUTER_EVAL_LRU_CAPACITY: usize = 8;
4989
/// Small least-recently-used cache of outer evaluations keyed by `Vec<u64>`.
///
/// Lookups are a linear scan over `entries`.
pub(crate) struct OuterEvalLru {
    /// Maximum number of entries retained; `new` clamps it to at least 1.
    capacity: usize,
    /// Entries ordered from least recently used (front) to most recently used
    /// (back): hits and inserts push to the back, eviction pops the front.
    entries: std::collections::VecDeque<(Vec<u64>, OuterEval)>,
}
5008
5009impl OuterEvalLru {
5010 pub(crate) fn new(capacity: usize) -> Self {
5011 Self {
5012 capacity: capacity.max(1),
5013 entries: std::collections::VecDeque::new(),
5014 }
5015 }
5016
5017 pub(crate) fn get(&mut self, key: &[u64]) -> Option<OuterEval> {
5021 let pos = self
5022 .entries
5023 .iter()
5024 .position(|(k, _)| k.as_slice() == key)?;
5025 let entry = self.entries.remove(pos)?;
5026 let eval = entry.1.clone();
5027 self.entries.push_back(entry);
5028 Some(eval)
5029 }
5030
5031 pub(crate) fn insert(&mut self, key: Vec<u64>, eval: OuterEval) {
5034 if let Some(pos) = self
5035 .entries
5036 .iter()
5037 .position(|(k, _)| k.as_slice() == key.as_slice())
5038 {
5039 self.entries.remove(pos);
5040 }
5041 self.entries.push_back((key, eval));
5042 while self.entries.len() > self.capacity {
5043 self.entries.pop_front();
5044 }
5045 }
5046
5047 pub(crate) fn clear(&mut self) {
5048 self.entries.clear();
5049 }
5050}
5051
/// Owner of the evaluation-level caches, each behind its own lock so the
/// manager can be used through a shared reference.
pub(crate) struct EvalCacheManager {
    /// Byte-budgeted LRU of inner-solve results.
    pub(crate) pirls_cache: RwLock<PirlsLruCache>,
    /// Single-slot cache used by `cached_penalty_subspace`.
    pub(crate) penalty_subspace_cache: RwLock<PenaltySubspaceCache>,
    /// Most recently stored evaluation bundle, if any.
    pub(crate) current_eval_bundle: RwLock<Option<EvalShared>>,
    /// Most recently stored (key, outer evaluation) pair, if any. Written by
    /// `store_outer_eval` and cleared by `invalidate_eval_bundle`.
    pub(crate) current_outer_eval: RwLock<Option<(Vec<u64>, OuterEval)>>,
    /// LRU of outer evaluations; this is what `cached_outer_eval` reads.
    pub(crate) outer_eval_lru: RwLock<OuterEvalLru>,
    /// Flag initialised to `true` by `new`. NOTE(review): its readers are
    /// outside this section; presumably it gates use of `pirls_cache` — confirm.
    pub(crate) pirls_cache_enabled: AtomicBool,
}
5080
5081impl EvalCacheManager {
5082 pub(crate) fn new() -> Self {
5083 Self {
5084 pirls_cache: RwLock::new(PirlsLruCache::new(PIRLS_CACHE_BYTE_BUDGET)),
5085 penalty_subspace_cache: RwLock::new(PenaltySubspaceCache::new()),
5086 current_eval_bundle: RwLock::new(None),
5087 current_outer_eval: RwLock::new(None),
5088 outer_eval_lru: RwLock::new(OuterEvalLru::new(OUTER_EVAL_LRU_CAPACITY)),
5089 pirls_cache_enabled: AtomicBool::new(true),
5090 }
5091 }
5092
5093 pub(crate) fn sanitized_rhokey(rho: &Array1<f64>) -> Option<Vec<u64>> {
5097 self::rho_key::sanitized_rhokey(rho)
5098 }
5099
5100 pub(super) fn cached_penalty_subspace<F>(
5107 &self,
5108 e_transformed: &ndarray::Array2<f64>,
5109 ridge_passport: &gam_problem::RidgePassport,
5110 build: F,
5111 ) -> Result<Arc<outer_eval::PenaltySubspace>, EstimationError>
5112 where
5113 F: FnOnce() -> Result<outer_eval::PenaltySubspace, EstimationError>,
5114 {
5115 let key = PenaltySubspaceCacheKey::from_inputs(e_transformed, ridge_passport);
5116 if let Some(hit) = self.penalty_subspace_cache.read().unwrap().get(&key) {
5117 return Ok(hit);
5118 }
5119 let value = Arc::new(build()?);
5120 self.penalty_subspace_cache
5121 .write()
5122 .unwrap()
5123 .insert(key, value.clone());
5124 Ok(value)
5125 }
5126
5127 pub(crate) fn cached_eval_bundle(&self, key: &Option<Vec<u64>>) -> Option<EvalShared> {
5128 let guard = self.current_eval_bundle.read().unwrap();
5129 let bundle: &EvalShared = guard.as_ref()?;
5130 bundle.matches(key).then(|| bundle.clone())
5131 }
5132
5133 pub(crate) fn store_eval_bundle(&self, bundle: EvalShared) {
5134 *self.current_eval_bundle.write().unwrap() = Some(bundle);
5135 }
5136
5137 pub(crate) fn cached_outer_eval(&self, key: &Option<Vec<u64>>) -> Option<OuterEval> {
5138 let key = key.as_ref()?;
5139 self.outer_eval_lru.write().unwrap().get(key)
5146 }
5147
5148 pub(crate) fn store_outer_eval(&self, key: &Option<Vec<u64>>, eval: &OuterEval) {
5149 if let Some(key) = key.clone() {
5150 *self.current_outer_eval.write().unwrap() = Some((key.clone(), eval.clone()));
5154 self.outer_eval_lru.write().unwrap().insert(key, eval.clone());
5155 }
5156 }
5157
5158 pub(crate) fn invalidate_eval_bundle(&self) {
5159 self.current_eval_bundle.write().unwrap().take();
5160 self.current_outer_eval.write().unwrap().take();
5161 self.outer_eval_lru.write().unwrap().clear();
5162 }
5163
5164 pub(crate) fn clear_eval_and_factor_caches(&self) {
5165 self.invalidate_eval_bundle();
5166 self.penalty_subspace_cache.write().unwrap().clear();
5167 }
5168}
5169
/// Counters and flags that are updated through a shared reference.
pub(crate) struct RemlArena {
    /// Counter behind a lock; starts at 0.
    pub(crate) cost_eval_count: RwLock<u64>,
    /// Atomic counter; starts at 0.
    pub(crate) inner_pirls_solve_count: AtomicU64,
    /// Atomic flag; starts `false`.
    pub(crate) lastgradient_used_stochastic_fallback: AtomicBool,
}

impl RemlArena {
    /// Creates an arena with both counters at zero and the flag cleared.
    pub(crate) fn new() -> Self {
        let cost_eval_count = RwLock::new(0_u64);
        let inner_pirls_solve_count = AtomicU64::new(0);
        let lastgradient_used_stochastic_fallback = AtomicBool::new(false);
        Self {
            cost_eval_count,
            inner_pirls_solve_count,
            lastgradient_used_stochastic_fallback,
        }
    }
}
5199
/// Snapshot of nuisance quantities held in `RemlState::alo_frozen_nuisance`.
///
/// NOTE(review): the code that fills and reads this snapshot is outside this
/// section; the field notes are inferred from names and types — confirm.
pub(crate) struct AloFrozenNuisance {
    /// presumably the observation count the snapshot was taken for — confirm.
    pub(crate) n_obs: usize,
    /// Vector of scale factors (presumably one per observation — confirm).
    pub(crate) influence_scale: Vec<f64>,
    /// Scalar `phi` (presumably a dispersion value — confirm).
    pub(crate) phi: f64,
}
5205
/// Central state for the REML computation: problem data, configuration, and a
/// large set of caches and diagnostics that can be updated through `&self`
/// (they sit behind `RwLock`, atomics, or `OnceLock`).
///
/// The lifetime `'a` is the lifetime of the borrowed `y` and `weights` views;
/// everything else is owned or shared via `Arc`.
///
/// NOTE(review): the code that reads and writes these fields is outside this
/// section. Notes below that go beyond the declared types are inferred from
/// names and are marked as assumptions.
pub(crate) struct RemlState<'a> {
    // --- Problem data and configuration ---
    /// Borrowed 1-D view (presumably the response — confirm).
    pub(crate) y: ArrayView1<'a, f64>,
    /// Design matrix.
    pub(crate) x: DesignMatrix,
    /// Borrowed 1-D view (presumably prior observation weights — confirm).
    pub(crate) weights: ArrayView1<'a, f64>,
    /// Owned offset vector.
    pub(crate) offset: Array1<f64>,
    /// Shared list of canonical penalties.
    pub(crate) canonical_penalties: Arc<Vec<gam_terms::construction::CanonicalPenalty>>,
    pub(crate) balanced_penalty_root: Array2<f64>,
    pub(crate) reparam_invariant: ReparamInvariant,
    /// `None` when no sparse penalty block count is available.
    pub(crate) sparse_penalty_block_count: Option<usize>,
    /// presumably the coefficient dimension — confirm.
    pub(crate) p: usize,
    /// Shared configuration.
    pub(crate) config: Arc<RemlConfig>,
    pub(crate) runtime_mixture_link_state: Option<gam_problem::MixtureLinkState>,
    pub(crate) runtime_sas_link_state: Option<SasLinkState>,
    pub(crate) nullspace_dims: Vec<usize>,
    /// Optional constraints on the coefficients.
    pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
    pub(crate) linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    pub(crate) penalty_shrinkage_floor: Option<f64>,
    pub(crate) rho_prior: gam_problem::RhoPrior,

    // --- Caches and warm-start state ---
    /// Evaluation-level caches.
    pub(crate) cache_manager: EvalCacheManager,
    /// Counters and flags.
    pub(crate) arena: RemlArena,
    /// Current warm-start coefficients and the `rho` they belong to
    /// (pairing assumed from the names — confirm).
    pub(crate) warm_start_beta: RwLock<Option<Coefficients>>,
    pub(crate) warm_start_rho: RwLock<Option<Array1<f64>>>,
    /// Previous warm-start pair (assumption as above).
    pub(crate) prev_warm_start_beta: RwLock<Option<Coefficients>>,
    pub(crate) prev_warm_start_rho: RwLock<Option<Array1<f64>>>,
    pub(crate) warm_start_enabled: AtomicBool,
    /// Shared atomic iteration limits; `Arc` lets other owners observe or
    /// change them.
    pub(crate) screening_max_inner_iterations: Arc<AtomicUsize>,
    pub(crate) outer_inner_cap: Arc<AtomicUsize>,

    // --- Diagnostics from the most recent inner solve (assumed from names) ---
    pub(crate) last_inner_iters: Arc<AtomicUsize>,
    pub(crate) last_inner_converged: Arc<AtomicBool>,

    pub(crate) ift_warm_start_cache: RwLock<Option<IftWarmStartCache>>,

    // The `Arc<AtomicU64>` members below carry names of real-valued
    // quantities. NOTE(review): presumably each stores an `f64` bit pattern —
    // confirm against the code that loads and stores them.
    pub(crate) last_pirls_lm_lambda: Arc<AtomicU64>,

    pub(crate) frozen_negbin_theta: Arc<AtomicU64>,

    pub(crate) frozen_tweedie_phi: Arc<AtomicU64>,

    pub(crate) frozen_gamma_shape: Arc<AtomicU64>,

    pub(crate) last_ift_prediction_residual: Arc<AtomicU64>,

    pub(crate) last_pirls_accept_rho: Arc<AtomicU64>,

    /// Optional shared factorization behind a trait object.
    pub(crate) ift_cached_factor: RwLock<Option<Arc<dyn gam_linalg::matrix::FactorizedSystem>>>,

    // --- Optional Kronecker structure ---
    pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
    pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,

    // --- Lazily filled caches, each `None` until populated ---
    pub(crate) gaussian_fixed_cache: RwLock<Option<Arc<crate::pirls::GaussianFixedCache>>>,
    /// Shared (matrix, vector) pair.
    pub(crate) gaussian_psi_gram_deriv:
        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
    /// Shared (matrix, vector) pair.
    pub(crate) glm_psi_gram_deriv:
        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
    pub(crate) glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
    pub(crate) flat_glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
    pub(crate) alo_frozen_nuisance: RwLock<Option<AloFrozenNuisance>>,

    /// Tri-state: `None` until decided, then `Some(true)` / `Some(false)`.
    pub(crate) alo_provably_inactive: RwLock<Option<bool>>,

    // --- Persistent warm-start bookkeeping ---
    pub(crate) persistent_warm_start_key: RwLock<Option<String>>,
    pub(crate) persistent_latent_values_fingerprint: Option<u64>,
    pub(crate) persistent_latent_values_cache: RwLock<PersistentLatentValuesCache>,
    pub(crate) analytic_penalty_registry_fingerprint: u64,
    pub(crate) persistent_warm_start_loaded: AtomicBool,
    /// Atomic counters. NOTE(review): presumably nesting depths where a
    /// nonzero value suppresses the named behaviour — confirm.
    pub(crate) persistent_warm_start_store_suppression: AtomicUsize,
    pub(crate) alo_stabilization_suppression: AtomicUsize,
    pub(crate) persistent_warm_start_disk_enabled: AtomicBool,
    /// Write-once scalar caches.
    pub(crate) gaussian_weight_log_sum_half_cache: std::sync::OnceLock<f64>,
    pub(crate) gaussian_dp_floor_scale_cache: std::sync::OnceLock<f64>,
}