1use self::inner_strategy::GeometryBackendKind;
2use super::*;
3use gam_linalg::sparse_exact::SparseExactFactor;
4use crate::pirls::PIRLS_CACHE_BYTE_BUDGET;
5use crate::pirls::assemble_and_factor_sparse_penalized_system;
6use gam_terms::basis::LocalDesignJacobianProvider;
7use gam_problem::SasLinkState;
8use gam_problem::OuterEval;
9use ndarray::{Array1, Array2, s};
10use std::collections::{HashMap, VecDeque};
11use std::ops::Range;
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
13use std::sync::{Arc, RwLock};
14
15pub mod assembly;
16pub mod atoms;
17pub(crate) mod continuation;
18pub(crate) mod eval;
19mod firth;
20pub(super) mod hyper;
21mod inner_strategy;
22pub mod jeffreys_subspace;
25pub mod outer_eval;
26pub mod penalty_logdet;
27pub mod per_atom_efs;
28pub mod reml_outer_engine;
29mod rho_key;
30mod sparse_exact_penalty;
31mod trace;
32
33pub(crate) use sparse_exact_penalty::sparse_penalty_block_count_from_canonical;
34
/// Byte budget (512 MiB) recorded as `budget_bytes` on every tau-tau Hessian policy.
pub(crate) const EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES: usize = 512 << 20;
/// Largest observation count for which `firth_problem_scale_allows` accepts Firth work.
pub(crate) const FIRTH_MAX_OBSERVATIONS: usize = 20_000;
/// Largest coefficient count for which `firth_problem_scale_allows` accepts Firth work.
pub(crate) const FIRTH_MAX_COEFFICIENTS: usize = 256;
/// Upper bound on `n_obs * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_LINEAR_WORK: usize = 2_000_000;
/// Upper bound on `n_obs * p_coeff * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_QUADRATIC_WORK: usize = 100_000_000;
/// Default number of entries kept by `PersistentLatentValuesCache`.
pub(crate) const PERSISTENT_LATENT_VALUES_CACHE_CAPACITY: usize = 8;
41
42#[derive(Debug)]
43pub(crate) struct PersistentLatentValuesCache {
44 pub(crate) entries: HashMap<String, Array2<f64>>,
45 pub(crate) lru: VecDeque<String>,
46 pub(crate) capacity: usize,
47}
48
49impl Default for PersistentLatentValuesCache {
50 fn default() -> Self {
51 Self {
52 entries: HashMap::new(),
53 lru: VecDeque::new(),
54 capacity: PERSISTENT_LATENT_VALUES_CACHE_CAPACITY,
55 }
56 }
57}
58
59impl PersistentLatentValuesCache {
60 pub(crate) fn lookup(
61 &mut self,
62 key: &str,
63 n_obs: usize,
64 latent_dim: usize,
65 ) -> Option<Array2<f64>> {
66 let values = self.entries.get(key)?;
67 if values.dim() != (n_obs, latent_dim) {
68 return None;
69 }
70 let values = values.clone();
71 self.touch(key.to_string());
72 Some(values)
73 }
74
75 pub(crate) fn insert(&mut self, key: String, values: Array2<f64>) {
76 if values.iter().any(|value| !value.is_finite()) {
77 return;
78 }
79 self.entries.insert(key.clone(), values);
80 self.touch(key);
81 while self.entries.len() > self.capacity {
82 let Some(evicted) = self.lru.pop_front() else {
83 break;
84 };
85 self.entries.remove(&evicted);
86 }
87 }
88
89 pub(crate) fn touch(&mut self, key: String) {
90 if let Some(index) = self.lru.iter().position(|queued| queued == &key) {
91 self.lru.remove(index);
92 }
93 self.lru.push_back(key);
94 }
95}
96
97#[derive(Clone)]
102pub(crate) struct IftWarmStartCache {
103 pub beta_original: ndarray::Array1<f64>,
109 pub rho: ndarray::Array1<f64>,
112 pub penalized_hessian_transformed: gam_linalg::matrix::SymmetricMatrix,
116 pub qs: ndarray::Array2<f64>,
120 pub frame_was_original: bool,
124 pub lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>>,
141}
142
/// Per-category byte estimate for one tau-derivative evaluation plan.
///
/// Each field costs one kind of buffer; `exact_tau_tau_hessian_policy_with_firth`
/// fills them in for its gradient-only and full-Hessian plans.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) struct TauTauPlanEstimate {
    /// Dense design-matrix copy (`n_obs × p_coeff` f64s when present).
    pub(crate) dense_x_bytes: usize,
    /// Dense first-order design derivatives.
    pub(crate) first_order_tau_bytes: usize,
    /// Dense second-order design derivatives.
    pub(crate) second_order_tau_bytes: usize,
    /// First-order penalty derivatives (`p_coeff × p_coeff` each).
    pub(crate) penalty_first_bytes: usize,
    /// Second-order penalty pair derivatives.
    pub(crate) penalty_pair_bytes: usize,
    /// Mixed rho/tau penalty buffers.
    pub(crate) rho_tau_penalty_bytes: usize,
    /// Length-`n_obs` vector caches.
    pub(crate) vector_cache_bytes: usize,
    /// Weighted `p_coeff × p_coeff` scratch space.
    pub(crate) weighted_scratch_bytes: usize,
}

impl TauTauPlanEstimate {
    /// Sum of every category, clamped at `usize::MAX` instead of overflowing.
    pub(crate) fn total_bytes(self) -> usize {
        // Exhaustive destructuring: adding a field without counting it here fails to compile.
        let Self {
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        } = self;
        [
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        ]
        .iter()
        .fold(0usize, |running, &part| running.saturating_add(part))
    }
}
167
168#[derive(Clone, Copy, Debug, PartialEq, Eq)]
169pub(crate) struct TauTauHessianPolicy {
170 pub(crate) any_has_implicit: bool,
171 pub(crate) implicit_multidim_duchon: bool,
172 pub(crate) estimated_dense_tau_cache_bytes: usize,
173 pub(crate) gradient_plan: TauTauPlanEstimate,
174 pub(crate) hessian_plan: TauTauPlanEstimate,
175 pub(crate) budget_bytes: usize,
176 pub(crate) firth_pair_terms_unavailable: bool,
177}
178
179impl TauTauHessianPolicy {
180 pub(crate) fn prefer_gradient_only(self) -> bool {
208 self.firth_pair_terms_unavailable
209 }
210}
211
212pub(crate) fn exact_tau_tau_hessian_policy_with_firth(
213 n_obs: usize,
214 p_coeff: usize,
215 hyper_dirs: &[DirectionalHyperParam],
216 firth_pair_terms_unavailable: bool,
217) -> TauTauHessianPolicy {
218 let f64_bytes = std::mem::size_of::<f64>();
219 let dense_matrix_bytes =
220 |rows: usize, cols: usize| -> usize { rows.saturating_mul(cols).saturating_mul(f64_bytes) };
221 let dense_design_bytes = dense_matrix_bytes(n_obs, p_coeff);
222 let dense_penalty_bytes = dense_matrix_bytes(p_coeff, p_coeff);
223 let psi_dim = hyper_dirs.len();
224 let implicit_n_axes = hyper_dirs
225 .iter()
226 .find_map(DirectionalHyperParam::implicit_axis_count_hint)
227 .unwrap_or(0);
228 let gradient_uses_implicit_design = hyper_dirs
229 .iter()
230 .any(DirectionalHyperParam::has_implicit_operator)
231 && gam_terms::basis::should_use_implicit_operators_with_policy(
232 n_obs,
233 p_coeff,
234 implicit_n_axes,
235 &gam_runtime::resource::ResourcePolicy::default_library(),
236 );
237 let dense_first_order_count = hyper_dirs
238 .iter()
239 .filter(|dir| !dir.has_implicit_operator())
240 .count();
241 let first_penalty_component_count = hyper_dirs
242 .iter()
243 .map(DirectionalHyperParam::penalty_first_component_count)
244 .sum::<usize>();
245
246 let mut dense_second_order_count = 0usize;
247 let mut penalty_pair_count = 0usize;
248 for i in 0..psi_dim {
249 for j in i..psi_dim {
250 if hyper_dirs[i]
251 .x_tau_tau_entry_at(j)
252 .or_else(|| hyper_dirs[j].x_tau_tau_entry_at(i))
253 .is_some_and(|entry| !entry.uses_implicit_storage())
254 {
255 dense_second_order_count += if i == j { 1 } else { 2 };
256 }
257 if hyper_dirs[i].has_penaltysecond_pair_at(j)
258 || hyper_dirs[j].has_penaltysecond_pair_at(i)
259 {
260 penalty_pair_count += if i == j { 1 } else { 2 };
261 }
262 }
263 }
264
265 let gradient_dense_first_order_count = if gradient_uses_implicit_design {
266 dense_first_order_count
267 } else {
268 psi_dim
269 };
270 let gradient_needs_dense_x =
271 firth_pair_terms_unavailable || gradient_dense_first_order_count > 0;
272 let gradient_plan = TauTauPlanEstimate {
273 dense_x_bytes: if gradient_needs_dense_x {
274 dense_design_bytes
275 } else {
276 0
277 },
278 first_order_tau_bytes: if gradient_dense_first_order_count > 0 {
279 dense_design_bytes
280 } else {
281 0
282 },
283 second_order_tau_bytes: 0,
284 penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
285 penalty_pair_bytes: 0,
286 rho_tau_penalty_bytes: 0,
287 vector_cache_bytes: n_obs.saturating_mul(f64_bytes),
288 weighted_scratch_bytes: dense_penalty_bytes,
289 };
290 let hessian_plan = TauTauPlanEstimate {
291 dense_x_bytes: if psi_dim > 0 { dense_design_bytes } else { 0 },
292 first_order_tau_bytes: dense_first_order_count.saturating_mul(dense_design_bytes),
293 second_order_tau_bytes: dense_second_order_count.saturating_mul(dense_design_bytes),
294 penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
295 penalty_pair_bytes: penalty_pair_count.saturating_mul(dense_penalty_bytes),
296 rho_tau_penalty_bytes: first_penalty_component_count
297 .saturating_mul(2)
298 .saturating_mul(dense_penalty_bytes),
299 vector_cache_bytes: psi_dim.saturating_mul(n_obs).saturating_mul(f64_bytes),
300 weighted_scratch_bytes: dense_penalty_bytes,
301 };
302 let any_has_implicit = hyper_dirs
303 .iter()
304 .any(DirectionalHyperParam::has_implicit_operator);
305 let implicit_multidim_duchon = hyper_dirs
306 .iter()
307 .any(DirectionalHyperParam::has_implicit_multidim_duchon);
308 let estimated_dense_tau_cache_bytes = hessian_plan
309 .first_order_tau_bytes
310 .saturating_add(hessian_plan.second_order_tau_bytes);
311 TauTauHessianPolicy {
312 any_has_implicit,
313 implicit_multidim_duchon,
314 estimated_dense_tau_cache_bytes,
315 gradient_plan,
316 hessian_plan,
317 budget_bytes: EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES,
318 firth_pair_terms_unavailable: firth_pair_terms_unavailable && !hyper_dirs.is_empty(),
319 }
320}
321
322pub(crate) fn firth_problem_scale_allows(n_obs: usize, p_coeff: usize) -> bool {
323 let linear_work = n_obs.saturating_mul(p_coeff);
324 let quadratic_work = linear_work.saturating_mul(p_coeff);
325 n_obs <= FIRTH_MAX_OBSERVATIONS
326 && p_coeff <= FIRTH_MAX_COEFFICIENTS
327 && linear_work <= FIRTH_MAX_LINEAR_WORK
328 && quadratic_work <= FIRTH_MAX_QUADRATIC_WORK
329}
330
331#[cfg(test)]
332mod tests {
333 use super::{
334 DirectionalHyperParam, EvalCacheManager, EvalShared, HyperDesignDerivative,
335 HyperPenaltyDerivative, ImplicitDerivLevel, RemlConfig, RemlState,
336 };
337 use crate::estimate::EstimationError;
338 use gam_linalg::faer_ndarray::FaerCholesky;
339 use gam_linalg::matrix::symmetrize_in_place;
340 use crate::pirls::PirlsCoordinateFrame;
341 use gam_terms::basis::{ImplicitDesignPsiDerivative, RadialScalarKind};
342 use gam_problem::{
343 GlmLikelihoodSpec, InverseLink, LikelihoodSpec, ResponseFamily, StandardLink,
344 };
345 use faer::Side;
346 use gam_problem::{HessianResult, OuterEval};
347 use ndarray::{Array1, Array2, array, s};
348 use std::sync::Arc;
349
350 pub(crate) fn binomial_logit_glm_spec() -> GlmLikelihoodSpec {
353 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
354 ResponseFamily::Binomial,
355 InverseLink::Standard(StandardLink::Logit),
356 ))
357 }
358
359 pub(crate) fn gaussian_identity_glm_spec() -> GlmLikelihoodSpec {
362 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
363 ResponseFamily::Gaussian,
364 InverseLink::Standard(StandardLink::Identity),
365 ))
366 }
367
368 impl DirectionalHyperParam {
369 pub(super) fn new(
370 x_tau_original: Array2<f64>,
371 penalty_first_components: Vec<(usize, Array2<f64>)>,
372 x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
373 penaltysecond_components: Option<Vec<Option<Vec<(usize, Array2<f64>)>>>>,
374 ) -> Result<Self, EstimationError> {
375 let x_tau_tau_original = x_tau_tau_original.map(|rows| {
376 rows.into_iter()
377 .map(|entry| entry.map(HyperDesignDerivative::from))
378 .collect::<Vec<_>>()
379 });
380 let penalty_first_components = penalty_first_components
381 .into_iter()
382 .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
383 .collect();
384 let penaltysecond_components = penaltysecond_components.map(|rows| {
385 rows.into_iter()
386 .map(|row| {
387 row.map(|components| {
388 components
389 .into_iter()
390 .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
391 .collect::<Vec<_>>()
392 })
393 })
394 .collect::<Vec<_>>()
395 });
396 Self::new_compact(
397 HyperDesignDerivative::from(x_tau_original),
398 penalty_first_components,
399 x_tau_tau_original,
400 penaltysecond_components,
401 )
402 }
403
404 pub(super) fn single_penalty(
405 penalty_index: usize,
406 x_tau_original: Array2<f64>,
407 s_tau_original: Array2<f64>,
408 x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
409 s_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
410 ) -> Result<Self, EstimationError> {
411 let penaltysecond_components = s_tau_tau_original.map(|rows| {
412 rows.into_iter()
413 .map(|mat| mat.map(|mat| vec![(penalty_index, mat)]))
414 .collect::<Vec<_>>()
415 });
416 Self::new(
417 x_tau_original,
418 vec![(penalty_index, s_tau_original)],
419 x_tau_tau_original,
420 penaltysecond_components,
421 )
422 }
423 }
424
425 #[test]
426 pub(crate) fn firth_problem_scale_gate_blocks_large_quadratic_work() {
427 assert!(super::firth_problem_scale_allows(2_000, 200));
428 assert!(!super::firth_problem_scale_allows(4_800, 241));
429 assert!(!super::firth_problem_scale_allows(4_800, 433));
430 }
431
    /// Builds one first-order implicit tau direction on a tiny 10×5 problem and
    /// checks the resulting policy.
    ///
    /// NOTE(review): despite `prefers_gradient_only` in the name, the last assertion
    /// requires `prefer_gradient_only()` to be false. The test pins that an
    /// implicit direction alone does not force gradient-only mode.
    #[test]
    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_implicit_tau() {
        let operator = ImplicitDesignPsiDerivative::new(
            array![1.0, 2.0, 3.0, 4.0],
            array![0.5, -1.0, 1.5, 2.0],
            array![0.1, 0.2, 0.3, 0.4],
            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
            None,
            None,
            2,
            2,
            1,
            2,
        );
        let dir = DirectionalHyperParam::new_compact(
            HyperDesignDerivative::from_implicit(
                Arc::new(operator),
                ImplicitDerivLevel::First(0),
                1..4,
                5,
            ),
            Vec::new(),
            None,
            None,
        )
        .expect("implicit directional hyperparam");
        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
        assert!(policy.any_has_implicit);
        // The gradient plan still costs one dense n×p design copy (10·5 f64s).
        assert_eq!(
            policy.gradient_plan.dense_x_bytes,
            10 * 5 * std::mem::size_of::<f64>()
        );
        assert!(!policy.prefer_gradient_only());
    }
466
    /// A streaming pure-Duchon (`dim: 2`) implicit direction must be reported as
    /// implicit and as implicit multi-dimensional Duchon. It must still not force
    /// gradient-only mode.
    #[test]
    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_for_implicit_multidim_duchon()
    {
        let operator = ImplicitDesignPsiDerivative::new_streaming(
            Arc::new(array![[0.0, 0.0], [1.0, 0.2]]),
            Arc::new(array![[0.0, 0.0], [1.0, 1.0]]),
            vec![0.0, 0.0],
            RadialScalarKind::PureDuchon {
                block_order: 1,
                p_order: 0,
                s_order: 0,
                dim: 2,
            },
            None,
            None,
            0,
        );
        let dir = DirectionalHyperParam::new_compact(
            HyperDesignDerivative::from_implicit(
                Arc::new(operator),
                ImplicitDerivLevel::First(0),
                0..2,
                2,
            ),
            Vec::new(),
            None,
            None,
        )
        .expect("implicit duchon directional hyperparam");
        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
        assert!(policy.any_has_implicit);
        assert!(policy.implicit_multidim_duchon);
        assert!(!policy.prefer_gradient_only());
    }
508
509 #[test]
510 pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_when_cache_budget_is_exceeded()
511 {
512 let dirs = (0..16)
518 .map(|_| {
519 DirectionalHyperParam::new_compact(
520 HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
521 Vec::new(),
522 None,
523 None,
524 )
525 .expect("dense directional hyperparam")
526 })
527 .collect::<Vec<_>>();
528 let policy = super::exact_tau_tau_hessian_policy_with_firth(320_000, 71, &dirs, false);
529 assert!(!policy.any_has_implicit);
530 assert!(policy.hessian_plan.total_bytes() > policy.budget_bytes);
531 assert!(policy.hessian_plan.total_bytes() > policy.gradient_plan.total_bytes());
532 assert!(!policy.prefer_gradient_only());
533 }
534
535 #[test]
536 pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_firth_pair_gap() {
537 let dir = DirectionalHyperParam::new_compact(
538 HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
539 Vec::new(),
540 None,
541 None,
542 )
543 .expect("dense directional hyperparam");
544 let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], true);
545 assert!(policy.firth_pair_terms_unavailable);
546 assert!(policy.prefer_gradient_only());
547 }
548
549 trait LogitDesignMotionFixture {
557 fn y(&self) -> &Array1<f64>;
558 fn w(&self) -> &Array1<f64>;
559 fn x(&self) -> &Array2<f64>;
560 fn s0(&self) -> &Array2<f64>;
561 fn cfg(&self) -> &RemlConfig;
562 fn rho(&self) -> &Array1<f64>;
563
564 fn state(&self) -> RemlState<'_> {
565 build_logit_state(self.y(), self.w(), self.x(), self.s0(), self.cfg())
566 }
567
568 fn state_perturbed(
569 &self,
570 x_tau: &Array2<f64>,
571 s_tau: &Array2<f64>,
572 eps: f64,
573 ) -> (RemlState<'_>, RemlState<'_>) {
574 let x_plus = self.x() + &x_tau.mapv(|v| eps * v);
575 let x_minus = self.x() - &x_tau.mapv(|v| eps * v);
576 let s_plus = self.s0() + &s_tau.mapv(|v| eps * v);
577 let s_minus = self.s0() - &s_tau.mapv(|v| eps * v);
578 (
579 build_logit_state(self.y(), self.w(), &x_plus, &s_plus, self.cfg()),
580 build_logit_state(self.y(), self.w(), &x_minus, &s_minus, self.cfg()),
581 )
582 }
583
584 fn fd_directional_gradient(&self, x_tau: &Array2<f64>, s_tau: &Array2<f64>) -> f64 {
586 let h = 2e-5;
587 let (state_plus, state_minus) = self.state_perturbed(x_tau, s_tau, h);
588 let v_plus = state_plus.compute_cost(self.rho()).expect("cost+");
589 let v_minus = state_minus.compute_cost(self.rho()).expect("cost-");
590 (v_plus - v_minus) / (2.0 * h)
591 }
592 }
593
594 pub(crate) fn build_logit_state<'a>(
595 y: &'a Array1<f64>,
596 w: &'a Array1<f64>,
597 x: &Array2<f64>,
598 s: &Array2<f64>,
599 cfg: &'a RemlConfig,
600 ) -> RemlState<'a> {
601 use crate::estimate::PenaltySpec;
602 let p = x.ncols();
603 let offset = Array1::<f64>::zeros(y.len());
604 let spec = PenaltySpec::Dense(s.clone());
605 let canonical = gam_terms::construction::canonicalize_penalty_specs(&[spec], &[1], p, "test")
606 .map(|(canonical, _)| canonical)
607 .expect("canonicalize");
608 RemlState::newwith_offset(
609 y.view(),
610 x.clone(),
611 w.view(),
612 offset.view(),
613 canonical,
614 p,
615 cfg,
616 Some(vec![1]),
617 None,
618 None,
619 )
620 .expect("state")
621 }
622
    /// Two rank-one penalties (dense roots `[0, 1]` and `[1, 0]`) over a 2-column
    /// Gaussian design, registered with `Some(vec![1, 1])`. The analytic outer
    /// Hessian must stay enabled for this repeated-range layout.
    #[test]
    fn repeated_penalty_ranges_keep_analytic_outer_hessian() {
        let y = array![0.2, -0.1, 0.3, 0.0];
        let w = Array1::<f64>::ones(y.len());
        let x = array![[1.0, -0.7], [1.0, -0.2], [1.0, 0.3], [1.0, 0.9]];
        let offset = Array1::<f64>::zeros(y.len());
        let cfg = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
        let p = x.ncols();
        let canonical = vec![
            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0]], p),
            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0]], p),
        ];
        let state = RemlState::newwith_offset(
            y.view(),
            x,
            w.view(),
            offset.view(),
            canonical,
            p,
            &cfg,
            Some(vec![1, 1]),
            None,
            None,
        )
        .expect("state");

        assert!(
            state.analytic_outer_hessian_enabled(),
            "double-penalty-style repeated coefficient ranges must still route to exact Hessian"
        );
    }
654
    /// A 2000×28 logistic fixture must sit outside
    /// `firth_tk_exact_hessian_scale_allows`, and its state must not enable the
    /// analytic outer Hessian.
    ///
    /// NOTE(review): the trailing `true` passed to `RemlConfig::external` presumably
    /// enables Firth (per the test name); confirm.
    #[test]
    fn canonical_logit_firth_declines_exact_tk_hessian_when_row_pair_work_is_large() {
        let n = 2_000usize;
        let p = 28usize;
        // Every third response is a success.
        let y = Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1.0 } else { 0.0 }));
        let w = Array1::<f64>::ones(n);
        // Intercept column plus sine/cosine mixtures of t in (0, 1).
        let mut x = Array2::<f64>::zeros((n, p));
        for i in 0..n {
            let t = (i as f64 + 0.5) / n as f64;
            x[[i, 0]] = 1.0;
            for j in 1..p {
                x[[i, j]] = ((j as f64) * std::f64::consts::TAU * t).sin()
                    + 0.25 * (((j + 1) as f64) * std::f64::consts::TAU * t).cos();
            }
        }
        // Identity penalty on every column except column 0 (the intercept).
        let mut s = Array2::<f64>::zeros((p, p));
        for j in 1..p {
            s[[j, j]] = 1.0;
        }
        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
        let state = build_logit_state(&y, &w, &x, &s, &cfg);

        assert!(
            !RemlState::firth_tk_exact_hessian_scale_allows(n, p),
            "fixture must sit beyond the O(n²·p) exact-Hessian budget"
        );
        assert!(
            !state.analytic_outer_hessian_enabled(),
            "large canonical-logit Firth fits should keep exact value/gradient but route outer curvature to BFGS"
        );
    }
686
    /// A small 40×6 polynomial logistic fixture must stay inside the exact-Hessian
    /// scale gate and keep the analytic outer Hessian enabled. Its responses are
    /// 0 for the first half of the rows and 1 for the second, so they are separable
    /// by `t`.
    #[test]
    fn canonical_logit_firth_keeps_exact_tk_hessian_for_small_separation_guards() {
        let n = 40usize;
        let p = 6usize;
        let y = Array1::from_iter((0..n).map(|i| if i >= n / 2 { 1.0 } else { 0.0 }));
        let w = Array1::<f64>::ones(n);
        // Polynomial basis 1, t, t², …, t⁵ on t in [0, 1].
        let mut x = Array2::<f64>::zeros((n, p));
        for i in 0..n {
            let t = (i as f64) / (n - 1) as f64;
            x[[i, 0]] = 1.0;
            for j in 1..p {
                x[[i, j]] = t.powi(j as i32);
            }
        }
        // Identity penalty on every column except column 0 (the intercept).
        let mut s = Array2::<f64>::zeros((p, p));
        for j in 1..p {
            s[[j, j]] = 1.0;
        }
        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
        let state = build_logit_state(&y, &w, &x, &s, &cfg);

        assert!(RemlState::firth_tk_exact_hessian_scale_allows(n, p));
        assert!(
            state.analytic_outer_hessian_enabled(),
            "small Firth rescue fits should keep exact TK Hessian curvature"
        );
    }
714
715 pub(crate) fn poisson_log_glm_spec() -> GlmLikelihoodSpec {
716 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
717 ResponseFamily::Poisson,
718 InverseLink::Standard(StandardLink::Log),
719 ))
720 }
721
    /// A Poisson fit with every prior weight equal to `c` must give the same LAML
    /// cost and rho-gradient as an unweighted fit on the data replicated `c` times.
    /// Checked on a grid of rho values with relative tolerance 1e-9.
    #[test]
    pub(crate) fn fixed_dispersion_laml_surface_is_replication_invariant() {
        let n = 200usize;
        let p = 8usize;
        let c = 3usize;
        // Intercept, linear term, and three Fourier harmonics of t in [0, 1].
        let mut x = Array2::<f64>::zeros((n, p));
        let mut y = Array1::<f64>::zeros(n);
        for i in 0..n {
            let t = (i as f64) / ((n - 1) as f64);
            let tau = std::f64::consts::TAU;
            x[[i, 0]] = 1.0;
            x[[i, 1]] = t;
            x[[i, 2]] = (tau * t).sin();
            x[[i, 3]] = (tau * t).cos();
            x[[i, 4]] = (2.0 * tau * t).sin();
            x[[i, 5]] = (2.0 * tau * t).cos();
            x[[i, 6]] = (3.0 * tau * t).sin();
            x[[i, 7]] = (3.0 * tau * t).cos();
            let eta = 0.3 + 0.9 * (1.4 * (t - 0.5)).sin();
            // Deterministic non-negative integer counts near exp(eta).
            y[i] = (eta.exp() + 0.5 * ((i as f64) * 2.399_963).sin())
                .round()
                .max(0.0);
        }
        // Identity penalty on every column except column 0.
        let mut s = Array2::<f64>::zeros((p, p));
        for j in 1..p {
            s[[j, j]] = 1.0;
        }

        // Stack `c` identical copies of (x, y) one after another.
        let mut x_rep = Array2::<f64>::zeros((n * c, p));
        let mut y_rep = Array1::<f64>::zeros(n * c);
        for r in 0..c {
            for i in 0..n {
                let row = r * n + i;
                for j in 0..p {
                    x_rep[[row, j]] = x[[i, j]];
                }
                y_rep[row] = y[i];
            }
        }

        let w_weighted = Array1::<f64>::from_elem(n, c as f64);
        let w_rep = Array1::<f64>::ones(n * c);

        // `build_logit_state` only wires data and penalty; the family comes from `cfg`.
        let cfg = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
        let st_w = build_logit_state(&y, &w_weighted, &x, &s, &cfg);
        let st_r = build_logit_state(&y_rep, &w_rep, &x_rep, &s, &cfg);

        for &rho in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
            let r = Array1::from_elem(1, rho);
            let cw = st_w.compute_cost(&r).expect("weighted cost");
            let cr = st_r.compute_cost(&r).expect("replicated cost");
            let gw = st_w.compute_gradient(&r).expect("weighted grad");
            let gr = st_r.compute_gradient(&r).expect("replicated grad");
            assert!(
                (cw - cr).abs() <= 1e-9 * (1.0 + cw.abs()),
                "LAML cost differs between w=c and c× replication at rho={rho}: \
                 cost_w={cw:.12e} cost_r={cr:.12e} diff={:.3e}",
                cw - cr
            );
            assert!(
                (gw[0] - gr[0]).abs() <= 1e-9 * (1.0 + gw[0].abs()),
                "LAML gradient differs between w=c and c× replication at rho={rho}: \
                 g_w={:.12e} g_r={:.12e} diff={:.3e}",
                gw[0],
                gr[0],
                gw[0] - gr[0]
            );
        }
    }
808
    /// With every prior weight equal to `c = 4`, `rho_weight_anchor()` must be
    /// exactly 0 for Poisson (fixed dispersion). For Gaussian/identity it must be
    /// `ln(c)`, the geometric-mean log-weight, to within 1e-12.
    #[test]
    pub(crate) fn rho_weight_anchor_is_zero_for_fixed_dispersion() {
        let n = 50usize;
        let p = 3usize;
        // Quadratic polynomial design with small non-negative integer responses.
        let mut x = Array2::<f64>::zeros((n, p));
        let mut y = Array1::<f64>::zeros(n);
        for i in 0..n {
            let t = (i as f64) / ((n - 1) as f64);
            x[[i, 0]] = 1.0;
            x[[i, 1]] = t;
            x[[i, 2]] = t * t;
            y[i] = (1.0 + (3.0 * t).sin()).round().max(0.0);
        }
        // Only the quadratic coefficient is penalised.
        let mut s = Array2::<f64>::zeros((p, p));
        s[[2, 2]] = 1.0;
        let c = 4.0_f64;
        let w = Array1::<f64>::from_elem(n, c);

        let cfg_pois = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
        let st_pois = build_logit_state(&y, &w, &x, &s, &cfg_pois);
        assert_eq!(
            st_pois.rho_weight_anchor(),
            0.0,
            "fixed-dispersion (Poisson) anchor must be 0, not the geometric-mean log-weight"
        );

        let cfg_gauss = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
        let st_gauss = build_logit_state(&y, &w, &x, &s, &cfg_gauss);
        assert!(
            (st_gauss.rho_weight_anchor() - c.ln()).abs() <= 1e-12,
            "Gaussian-identity (profiled) anchor must be the geometric-mean log-weight ln(c)={:.6}, got {:.6}",
            c.ln(),
            st_gauss.rho_weight_anchor()
        );
    }
853
854 pub(crate) fn beta_original_from_bundle(bundle: &EvalShared) -> Array1<f64> {
855 let pr = bundle.pirls_result.as_ref();
856 match pr.coordinate_frame {
857 PirlsCoordinateFrame::OriginalSparseNative => pr.beta_transformed.as_ref().clone(),
858 PirlsCoordinateFrame::TransformedQs => {
859 pr.reparam_result.qs.dot(pr.beta_transformed.as_ref())
860 }
861 }
862 }
863
864 pub(crate) fn compute_joint_hypercostgradienthessian(
865 state: &RemlState<'_>,
866 theta: &Array1<f64>,
867 rho_dim: usize,
868 hyper_dirs: &[DirectionalHyperParam],
869 ) -> Result<(f64, Array1<f64>, Array2<f64>), EstimationError> {
870 let (cost, gradient, hessian) = state.compute_joint_hyper_eval_with_order(
871 theta,
872 rho_dim,
873 hyper_dirs,
874 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
875 )?;
876 Ok((
877 cost,
878 gradient,
879 hessian
880 .materialize_dense()
881 .map_err(EstimationError::RemlOptimizationFailed)?
882 .ok_or_else(|| {
883 EstimationError::RemlOptimizationFailed(
884 "joint hyper Hessian requested but unavailable".to_string(),
885 )
886 })?,
887 ))
888 }
889
890 pub(crate) fn h_original_from_bundle(bundle: &EvalShared) -> Array2<f64> {
891 let pr = bundle.pirls_result.as_ref();
892 match pr.coordinate_frame {
893 PirlsCoordinateFrame::OriginalSparseNative => bundle.h_total.as_ref().clone(),
894 PirlsCoordinateFrame::TransformedQs => {
895 let qs = &pr.reparam_result.qs;
896 let tmp = gam_linalg::faer_ndarray::fast_ab(qs, bundle.h_total.as_ref());
897 gam_linalg::faer_ndarray::fast_abt(&tmp, qs)
898 }
899 }
900 }
901
902 pub(crate) fn single_directional_tau_gradient(
903 state: &RemlState<'_>,
904 rho: &Array1<f64>,
905 hyper: DirectionalHyperParam,
906 ) -> Result<f64, EstimationError> {
907 let mut theta = Array1::<f64>::zeros(rho.len() + 1);
908 theta.slice_mut(s![..rho.len()]).assign(rho);
909 let (_, gradient, _) = state.compute_joint_hyper_eval_with_order(
910 &theta,
911 rho.len(),
912 &[hyper],
913 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
914 )?;
915 Ok(gradient[rho.len()])
916 }
917
918 pub(crate) fn fd_directional_tau_cost_gradient(
919 y: &Array1<f64>,
920 w: &Array1<f64>,
921 x: &Array2<f64>,
922 s0: &Array2<f64>,
923 cfg: &RemlConfig,
924 rho: &Array1<f64>,
925 x_tau: &Array2<f64>,
926 s_tau: &Array2<f64>,
927 ) -> f64 {
928 let h = 2e-5;
929 let x_plus = x + &x_tau.mapv(|v| h * v);
930 let x_minus = x - &x_tau.mapv(|v| h * v);
931 let s_plus = s0 + &s_tau.mapv(|v| h * v);
932 let s_minus = s0 - &s_tau.mapv(|v| h * v);
933 let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
934 let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
935 let v_plus = state_plus.compute_cost(rho).expect("cost+");
936 let v_minus = state_minus.compute_cost(rho).expect("cost-");
937 (v_plus - v_minus) / (2.0 * h)
938 }
939
940 pub(crate) fn directional_tau_hessian_fd_reference(
941 y: &Array1<f64>,
942 w: &Array1<f64>,
943 x: &Array2<f64>,
944 s0: &Array2<f64>,
945 cfg: &RemlConfig,
946 rho: &Array1<f64>,
947 hyper_dirs: &[DirectionalHyperParam],
948 x_tau_mats: &[Array2<f64>],
949 s_tau_mats: &[Array2<f64>],
950 ) -> Array2<f64> {
951 assert_eq!(hyper_dirs.len(), x_tau_mats.len());
952 assert_eq!(hyper_dirs.len(), s_tau_mats.len());
953
954 const TARGET_PHYSICAL_STEP: f64 = 1e-5;
955
956 let n_dirs = hyper_dirs.len();
957 let mut h_ttfd = Array2::<f64>::zeros((n_dirs, n_dirs));
958 for j in 0..n_dirs {
959 let direction_scale = x_tau_mats[j]
960 .iter()
961 .chain(s_tau_mats[j].iter())
962 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
963 let h = if direction_scale > 0.0 {
964 TARGET_PHYSICAL_STEP / direction_scale
965 } else {
966 TARGET_PHYSICAL_STEP
967 };
968
969 let x_plus = x + &x_tau_mats[j].mapv(|v| h * v);
970 let x_minus = x - &x_tau_mats[j].mapv(|v| h * v);
971 let s_plus = s0 + &s_tau_mats[j].mapv(|v| h * v);
972 let s_minus = s0 - &s_tau_mats[j].mapv(|v| h * v);
973
974 let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
975 let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
976 for i in 0..n_dirs {
977 let g_plus =
978 single_directional_tau_gradient(&state_plus, rho, hyper_dirs[i].clone())
979 .expect("g+ for FD");
980 let g_minus =
981 single_directional_tau_gradient(&state_minus, rho, hyper_dirs[i].clone())
982 .expect("g- for FD");
983 h_ttfd[[i, j]] = (g_plus - g_minus) / (2.0 * h);
984 }
985 }
986 symmetrize_in_place(&mut h_ttfd);
987 h_ttfd
988 }
989
    /// An `OuterEval` without a Hessian must be cached and returned with identical
    /// cost and gradient. `invalidate_eval_bundle` must also clear the outer-eval
    /// cache.
    #[test]
    pub(crate) fn eval_cache_manager_stores_first_order_outer_eval() {
        let cache = EvalCacheManager::new();
        // NOTE(review): the `-0.0` component presumably exercises key sanitisation
        // of negative zero; confirm against `sanitized_rhokey`.
        let rho = array![0.25, -0.0];
        let rho_key = EvalCacheManager::sanitized_rhokey(&rho);
        let eval = OuterEval {
            cost: 3.5,
            gradient: array![1.0, -2.0],
            hessian: HessianResult::Unavailable,
            inner_beta_hint: None,
        };

        cache.store_outer_eval(&rho_key, &eval);

        let cached = cache
            .cached_outer_eval(&rho_key)
            .expect("first-order outer eval should be cached");
        assert_eq!(cached.cost, eval.cost);
        assert_eq!(cached.gradient, eval.gradient);
        assert!(matches!(cached.hessian, HessianResult::Unavailable));

        cache.invalidate_eval_bundle();
        assert!(
            cache.cached_outer_eval(&rho_key).is_none(),
            "invalidating the bundle should clear the outer-eval cache too"
        );
    }
1017
    /// Exercises the outer-eval LRU (`OUTER_EVAL_LRU_CAPACITY` entries). It checks that:
    /// * a hit returns the stored cost and gradient bit-for-bit;
    /// * nearby rho vectors get distinct keys and do not alias;
    /// * capacity stays bounded;
    /// * the least recently used key is evicted (a miss), not served stale.
    ///
    /// Statement order matters: eviction depends on the exact store/lookup sequence.
    #[test]
    pub(crate) fn outer_eval_lru_hit_is_bit_identical_and_evicts_honestly_1575() {
        use super::OUTER_EVAL_LRU_CAPACITY;

        // Deterministic, seed-dependent eval payload.
        let make_eval = |seed: f64| OuterEval {
            cost: (seed * std::f64::consts::PI).sin() / 3.0 - seed,
            gradient: array![seed, -seed * 2.0, seed.recip()],
            hessian: HessianResult::Unavailable,
            inner_beta_hint: Some(array![seed + 0.5, seed - 0.5]),
        };
        // Bitwise comparison of cost and gradient only; the hessian and
        // inner_beta_hint are not compared by this helper.
        let bits_eq = |a: &OuterEval, b: &OuterEval| -> bool {
            a.cost.to_bits() == b.cost.to_bits()
                && a.gradient.len() == b.gradient.len()
                && a.gradient
                    .iter()
                    .zip(b.gradient.iter())
                    .all(|(x, y)| x.to_bits() == y.to_bits())
        };

        let cache = EvalCacheManager::new();

        // Store-then-hit round trip for a single key.
        let rho_a = array![0.25, -1.5];
        let key_a = EvalCacheManager::sanitized_rhokey(&rho_a);
        let eval_a = make_eval(0.25);
        cache.store_outer_eval(&key_a, &eval_a);
        let hit_a = cache
            .cached_outer_eval(&key_a)
            .expect("stored rho_a must hit");
        assert!(
            bits_eq(&hit_a, &eval_a),
            "cache hit must be bit-identical (cost+gradient) to the stored miss-path eval"
        );
        assert_eq!(
            hit_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
            eval_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
            "inner_beta_hint must round-trip unchanged"
        );

        // rho_b's second component differs from -1.5 by about one ulp. Its key must
        // be distinct and the two evals must not alias.
        let rho_b = array![0.25, -1.4999999999999998];
        let key_b = EvalCacheManager::sanitized_rhokey(&rho_b);
        assert_ne!(key_a, key_b, "the two rho-keys must differ");
        let eval_b = make_eval(7.0);
        cache.store_outer_eval(&key_b, &eval_b);
        assert!(
            bits_eq(
                &cache.cached_outer_eval(&key_b).expect("rho_b must hit"),
                &eval_b
            ),
            "rho_b must return its own eval, not rho_a's"
        );
        assert!(
            bits_eq(
                &cache.cached_outer_eval(&key_a).expect("rho_a must still hit"),
                &eval_a
            ),
            "rho_a must be unaffected by the rho_b insert"
        );

        // Fill a fresh cache exactly to capacity; keys[0] is the oldest entry and
        // no lookup touches it before the overflow insert below.
        let cache = EvalCacheManager::new();
        let mut keys = Vec::new();
        let mut evals = Vec::new();
        for i in 0..OUTER_EVAL_LRU_CAPACITY {
            let rho = array![i as f64, -(i as f64)];
            let key = EvalCacheManager::sanitized_rhokey(&rho);
            let eval = make_eval(i as f64 + 0.123);
            cache.store_outer_eval(&key, &eval);
            keys.push(key);
            evals.push(eval);
        }
        assert_eq!(
            cache.outer_eval_lru.read().unwrap().entries.len(),
            OUTER_EVAL_LRU_CAPACITY
        );
        // One more insert overflows the cache and must evict keys[0].
        let rho_overflow = array![999.0, -999.0];
        let key_overflow = EvalCacheManager::sanitized_rhokey(&rho_overflow);
        let eval_overflow = make_eval(42.0);
        cache.store_outer_eval(&key_overflow, &eval_overflow);
        assert_eq!(
            cache.outer_eval_lru.read().unwrap().entries.len(),
            OUTER_EVAL_LRU_CAPACITY,
            "capacity must stay bounded"
        );
        assert!(
            cache.cached_outer_eval(&keys[0]).is_none(),
            "the least-recently-used key must be evicted and now MISS (recompute), not return stale"
        );
        assert!(
            bits_eq(
                &cache
                    .cached_outer_eval(&keys[1])
                    .expect("a still-resident key must hit"),
                &evals[1]
            ),
            "a still-resident key must return its exact stored bits"
        );
        assert!(
            bits_eq(
                &cache
                    .cached_outer_eval(&key_overflow)
                    .expect("the freshest key must hit"),
                &eval_overflow
            ),
            "the freshest key must hit with its own eval"
        );
    }
1143
1144 #[test]
1145 pub(crate) fn reset_outer_seed_state_clears_pirls_cache() {
1146 let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1152 let w = Array1::<f64>::ones(y.len());
1153 let x = array![
1154 [1.0, -1.0, 0.2],
1155 [1.0, -0.5, -0.4],
1156 [1.0, 0.0, 0.7],
1157 [1.0, 0.4, -0.3],
1158 [1.0, 0.9, 0.1],
1159 [1.0, 1.3, -0.6],
1160 ];
1161 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1162 let rho = array![0.0];
1163 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1164 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1165
1166 state
1169 .compute_outer_eval_with_order(
1170 &rho,
1171 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1172 )
1173 .expect("outer eval should succeed");
1174
1175 let populated_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1176 assert!(
1177 populated_len > 0,
1178 "evaluating the outer objective should populate the PIRLS LRU, got {populated_len}"
1179 );
1180
1181 state.reset_outer_seed_state();
1182
1183 let cleared_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1184 assert_eq!(
1185 cleared_len, 0,
1186 "reset_outer_seed_state must clear the cross-call PIRLS LRU; got {cleared_len} entries"
1187 );
1188 }
1189
1190 #[test]
1191 pub(crate) fn reset_outer_seed_state_preserves_frozen_negbin_theta_1448() {
1192 use std::sync::atomic::Ordering;
1209
1210 let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1211 let w = Array1::<f64>::ones(y.len());
1212 let x = array![
1213 [1.0, -1.0, 0.2],
1214 [1.0, -0.5, -0.4],
1215 [1.0, 0.0, 0.7],
1216 [1.0, 0.4, -0.3],
1217 [1.0, 0.9, 0.1],
1218 [1.0, 1.3, -0.6],
1219 ];
1220 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1221 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1222 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1223
1224 let theta_final_bits = 2.5_f64.to_bits();
1226 state
1227 .frozen_negbin_theta
1228 .store(theta_final_bits, Ordering::Relaxed);
1229 assert_eq!(
1230 state.frozen_negbin_theta.load(Ordering::Relaxed),
1231 theta_final_bits,
1232 "precondition: the re-freeze stores θ_final into the frozen slot"
1233 );
1234
1235 state.reset_outer_seed_state();
1237
1238 assert_eq!(
1239 state.frozen_negbin_theta.load(Ordering::Relaxed),
1240 theta_final_bits,
1241 "reset_outer_seed_state (alternation-round reset) must PRESERVE the \
1242 re-frozen NB θ; clearing it would defeat the #1448 θ↔λ alternation \
1243 (the next ρ search would re-derive θ from the seed and never reach \
1244 the joint fixed point)"
1245 );
1246 }
1247
1248 #[test]
1249 pub(crate) fn implicit_hyper_design_derivative_respects_full_model_embedding() {
1250 let operator = ImplicitDesignPsiDerivative::new(
1251 array![1.0, 2.0, 3.0, 4.0],
1252 array![0.5, -1.0, 1.5, 2.0],
1253 array![0.1, 0.2, 0.3, 0.4],
1254 array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
1255 None,
1256 None,
1257 2,
1258 2,
1259 1,
1260 2,
1261 );
1262 let local = operator
1263 .materialize_first(0)
1264 .expect("materialized first derivative");
1265 assert_eq!(
1266 local.ncols(),
1267 3,
1268 "operator-local derivative should stay smooth-local"
1269 );
1270
1271 let implicit = HyperDesignDerivative::from_implicit(
1272 Arc::new(operator),
1273 ImplicitDerivLevel::First(0),
1274 1..4,
1275 5,
1276 );
1277 let embedded = HyperDesignDerivative::from_embedded(local.clone(), 1..4, 5);
1278
1279 assert_eq!(implicit.nrows(), embedded.nrows());
1280 assert_eq!(implicit.ncols(), 5);
1281 assert_eq!(implicit.materialize(), embedded.materialize());
1282
1283 let u = array![7.0, 1.5, -2.0, 0.25, -3.0];
1284 let v = array![0.75, -1.25];
1285 assert_eq!(
1286 implicit.forward_mul_original(&u).expect("implicit forward"),
1287 embedded.forward_mul_original(&u).expect("embedded forward")
1288 );
1289 assert_eq!(
1290 implicit
1291 .transpose_mul_original(&v)
1292 .expect("implicit transpose"),
1293 embedded
1294 .transpose_mul_original(&v)
1295 .expect("embedded transpose")
1296 );
1297
1298 let qs = array![
1299 [1.0, 0.0, 0.0],
1300 [0.0, 1.0, 0.0],
1301 [0.0, 0.5, 0.5],
1302 [0.0, 0.0, 1.0],
1303 [0.0, 0.0, 0.0],
1304 ];
1305 assert_eq!(
1306 implicit
1307 .transformed(&qs, None)
1308 .expect("implicit transformed"),
1309 embedded
1310 .transformed(&qs, None)
1311 .expect("embedded transformed")
1312 );
1313 let u_transformed = array![1.0, -0.5, 2.0];
1314 assert_eq!(
1315 implicit
1316 .transformed_forward_mul(&qs, None, &u_transformed)
1317 .expect("implicit transformed forward"),
1318 embedded
1319 .transformed_forward_mul(&qs, None, &u_transformed)
1320 .expect("embedded transformed forward")
1321 );
1322 assert_eq!(
1323 implicit
1324 .transformed_transpose_mul(&qs, None, &v)
1325 .expect("implicit transformed transpose"),
1326 embedded
1327 .transformed_transpose_mul(&qs, None, &v)
1328 .expect("embedded transformed transpose")
1329 );
1330 }
1331
1332 #[test]
1333 pub(crate) fn directional_hyper_identities_match_finite_differences_logit() {
1334 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1335 let w = Array1::<f64>::ones(y.len());
1336 let x = array![
1337 [1.0, -1.2, 0.3],
1338 [1.0, -0.8, -0.4],
1339 [1.0, -0.3, 0.7],
1340 [1.0, 0.1, -0.9],
1341 [1.0, 0.5, 0.2],
1342 [1.0, 0.9, -0.1],
1343 [1.0, 1.3, 0.8],
1344 [1.0, 1.7, -0.6],
1345 ];
1346 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1347
1348 let x_tau = Array2::<f64>::zeros(x.raw_dim());
1353 let s_tau = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15],];
1354 let hyper =
1355 DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
1356 .expect("single-penalty hyper direction");
1357 let rho = array![0.0];
1358
1359 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
1363 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1364 let bundle = state.obtain_eval_bundle(&rho).expect("bundle");
1365 let pr = bundle.pirls_result.as_ref();
1366
1367 let beta = beta_original_from_bundle(&bundle);
1368 let h_orig = h_original_from_bundle(&bundle);
1369 let u = &pr.solveweights * &(&pr.solveworking_response - &pr.final_eta);
1370
1371 let x_tau_beta = gam_linalg::faer_ndarray::fast_av(&x_tau, &beta);
1374 let weighted_x_tau_beta = &pr.finalweights * &x_tau_beta;
1375 let rhs = gam_linalg::faer_ndarray::fast_atv(&x_tau, &u)
1376 - gam_linalg::faer_ndarray::fast_atv(&x, &weighted_x_tau_beta)
1377 - s_tau.dot(&beta);
1378 let chol = h_orig.cholesky(Side::Lower).expect("chol(H)");
1379 let b_analytic = chol.solvevec(&rhs);
1380
1381 let eta_dot = &x_tau_beta + &gam_linalg::faer_ndarray::fast_av(&x, &b_analytic);
1385 let w_direction = crate::pirls::directionalworking_curvature_from_c_array(
1386 &pr.solve_c_array,
1387 &pr.finalweights,
1388 &eta_dot,
1389 );
1390 let wx = RemlState::row_scale(&x, &pr.finalweights);
1391 let wx_tau = RemlState::row_scale(&x_tau, &pr.finalweights);
1392 let mut xwtau_x = x.clone();
1393 match w_direction {
1394 crate::pirls::DirectionalWorkingCurvature::Diagonal(diag) => {
1395 xwtau_x = RemlState::row_scale(&xwtau_x, &diag);
1396 }
1397 }
1398 let mut h_tau_analytic = gam_linalg::faer_ndarray::fast_atb(&x_tau, &wx);
1399 h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &wx_tau);
1400 h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &xwtau_x);
1401 h_tau_analytic += &s_tau;
1402
1403 let ell_beta = gam_linalg::faer_ndarray::fast_atv(&x, &u);
1408 let s_eff = &h_orig - &gam_linalg::faer_ndarray::fast_atb(&x, &wx);
1409 let cancellation = -ell_beta.dot(&b_analytic) + beta.dot(&s_eff.dot(&b_analytic));
1410
1411 let h = 2e-5;
1413 let x_plus = &x + &(x_tau.mapv(|v| h * v));
1414 let x_minus = &x - &(x_tau.mapv(|v| h * v));
1415 let s_plus = &s0 + &(s_tau.mapv(|v| h * v));
1416 let s_minus = &s0 - &(s_tau.mapv(|v| h * v));
1417
1418 let state_plus = build_logit_state(&y, &w, &x_plus, &s_plus, &cfg);
1419 let state_minus = build_logit_state(&y, &w, &x_minus, &s_minus, &cfg);
1420 let bundle_plus = state_plus.obtain_eval_bundle(&rho).expect("bundle+");
1421 let bundle_minus = state_minus.obtain_eval_bundle(&rho).expect("bundle-");
1422 let beta_plus = beta_original_from_bundle(&bundle_plus);
1423 let beta_minus = beta_original_from_bundle(&bundle_minus);
1424 let bfd = (&beta_plus - &beta_minus).mapv(|v| v / (2.0 * h));
1425
1426 let h_plus = h_original_from_bundle(&bundle_plus);
1427 let h_minus = h_original_from_bundle(&bundle_minus);
1428 let h_taufd = (&h_plus - &h_minus).mapv(|v| v / (2.0 * h));
1429
1430 let v_plus = state_plus.compute_cost(&rho).expect("cost+");
1431 let v_minus = state_minus.compute_cost(&rho).expect("cost-");
1432 let v_taufd = (v_plus - v_minus) / (2.0 * h);
1433
1434 let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper.clone())
1435 .expect("analytic directional gradient");
1436
1437 let b_num = (&b_analytic - &bfd).mapv(|v| v * v).sum().sqrt();
1438 let b_den = bfd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1439 let b_rel = b_num / b_den;
1440 for i in 0..b_analytic.len() {
1441 assert_eq!(
1442 b_analytic[i].signum(),
1443 bfd[i].signum(),
1444 "B sign mismatch at i={i}: analytic={} fd={}",
1445 b_analytic[i],
1446 bfd[i]
1447 );
1448 }
1449 assert!(
1450 b_rel < 2e-2,
1451 "B implicit solve mismatch vs FD: rel={b_rel:.3e}, num={b_num:.3e}, den={b_den:.3e}"
1452 );
1453
1454 let dh_num = (&h_tau_analytic - &h_taufd).mapv(|v| v * v).sum().sqrt();
1455 let dh_den = h_taufd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1456 let dh_rel = dh_num / dh_den;
1457 for i in 0..h_tau_analytic.nrows() {
1458 for j in 0..h_tau_analytic.ncols() {
1459 assert_eq!(
1460 h_tau_analytic[[i, j]].signum(),
1461 h_taufd[[i, j]].signum(),
1462 "H_tau sign mismatch at ({i},{j}): analytic={} fd={}",
1463 h_tau_analytic[[i, j]],
1464 h_taufd[[i, j]]
1465 );
1466 }
1467 }
1468 assert!(
1469 dh_rel < 3e-2,
1470 "H_tau mismatch vs FD: rel={dh_rel:.3e}, num={dh_num:.3e}, den={dh_den:.3e}"
1471 );
1472
1473 let v_abs = (v_tau_analytic - v_taufd).abs();
1474 let v_rel = v_abs / v_taufd.abs().max(1e-10);
1475 assert_eq!(
1476 v_tau_analytic.signum(),
1477 v_taufd.signum(),
1478 "V_tau sign mismatch: analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1479 );
1480 assert!(
1481 v_rel < 2e-2,
1482 "V_tau mismatch vs FD: rel={v_rel:.3e}, abs={v_abs:.3e}, analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1483 );
1484
1485 assert!(
1486 cancellation.abs() < 1e-10,
1487 "stationarity cancellation failed: | -ell_beta^T B + beta^T S B | = {:.3e}",
1488 cancellation.abs()
1489 );
1490 }
1491
1492 #[test]
1493 pub(crate) fn firth_exacthessian_includes_analytic_tk_second_derivatives() {
1494 let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
1496 let w = Array1::<f64>::ones(y.len());
1497 let x = array![
1498 [1.0, -1.2, 0.4, -2.4],
1499 [1.0, -0.9, -0.1, -1.8],
1500 [1.0, -0.6, 0.3, -1.2],
1501 [1.0, -0.2, -0.4, -0.4],
1502 [1.0, 0.1, 0.5, 0.2],
1503 [1.0, 0.4, -0.6, 0.8],
1504 [1.0, 0.8, 0.2, 1.6],
1505 [1.0, 1.1, -0.3, 2.2],
1506 [1.0, 1.4, 0.7, 2.8],
1507 [1.0, 1.7, -0.2, 3.4],
1508 ];
1509 let s0 = array![
1510 [0.0, 0.0, 0.0, 0.0],
1511 [0.0, 1.5, 0.2, 0.0],
1512 [0.0, 0.2, 1.0, 0.0],
1513 [0.0, 0.0, 0.0, 0.5],
1514 ];
1515 let s1 = array![
1516 [0.0, 0.0, 0.0, 0.0],
1517 [0.0, 0.8, -0.1, 0.0],
1518 [0.0, -0.1, 0.6, 0.0],
1519 [0.0, 0.0, 0.0, 0.3],
1520 ];
1521 let offset = Array1::<f64>::zeros(y.len());
1522 let cfg =
1525 RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1526 let p = x.ncols();
1527 use crate::estimate::PenaltySpec;
1528 let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1529 let canonical = gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p, "test")
1530 .map(|(canonical, _)| canonical)
1531 .expect("canonicalize");
1532 let state = RemlState::newwith_offset(
1533 y.view(),
1534 x.clone(),
1535 w.view(),
1536 offset.view(),
1537 canonical,
1538 p,
1539 &cfg,
1540 Some(vec![1, 1]),
1541 None,
1542 None,
1543 )
1544 .expect("state");
1545 let rho = array![0.1, -0.2];
1546 assert!(
1547 state.analytic_outer_hessian_enabled(),
1548 "Firth logit should no longer disable analytic outer Hessian planning"
1549 );
1550 let outer = state
1551 .compute_outer_eval_with_order(
1552 &rho,
1553 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1554 )
1555 .expect("outer Hessian eval should succeed");
1556 assert!(
1557 outer.hessian.is_analytic(),
1558 "outer planner should request and return an analytic Hessian"
1559 );
1560 let bundle = state.obtain_eval_bundle(&rho).expect("exact firth bundle");
1561 let h_dense = state
1562 .compute_lamlhessian_exact_from_bundle(&rho, &bundle)
1563 .expect("Firth exact Hessian should include analytic TK second derivatives");
1564 assert_eq!(h_dense.raw_dim(), ndarray::Ix2(2, 2));
1565 assert!(
1566 h_dense.iter().all(|value| value.is_finite()),
1567 "Hessian should be finite: {h_dense:?}"
1568 );
1569 }
1570
1571 #[test]
1572 pub(crate) fn firth_outer_hessian_matches_gradient_finite_difference_with_tk_terms() {
1573 let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
1574 let w = Array1::<f64>::ones(y.len());
1575 let x = array![
1576 [1.0, -1.0, 0.3],
1577 [1.0, -0.7, -0.2],
1578 [1.0, -0.3, 0.4],
1579 [1.0, 0.0, -0.5],
1580 [1.0, 0.2, 0.6],
1581 [1.0, 0.6, -0.4],
1582 [1.0, 0.9, 0.2],
1583 [1.0, 1.3, -0.1],
1584 ];
1585 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.1], [0.0, 0.1, 0.7],];
1586 let s1 = array![[0.0, 0.0, 0.0], [0.0, 0.4, -0.05], [0.0, -0.05, 0.9],];
1587 let cfg =
1588 RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1589 let p_dim = x.ncols();
1590 use crate::estimate::PenaltySpec;
1591 let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1592 let canonical =
1593 gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p_dim, "test")
1594 .map(|(canonical, _)| canonical)
1595 .expect("canonicalize");
1596 let offset = Array1::<f64>::zeros(y.len());
1597 let state = RemlState::newwith_offset(
1598 y.view(),
1599 x.clone(),
1600 w.view(),
1601 offset.view(),
1602 canonical,
1603 p_dim,
1604 &cfg,
1605 Some(vec![1, 1]),
1606 None,
1607 None,
1608 )
1609 .expect("state");
1610 let rho = array![0.15, -0.25];
1611 let eval = state
1612 .compute_outer_eval_with_order(
1613 &rho,
1614 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1615 )
1616 .expect("analytic Hessian eval");
1617 let h = match eval.hessian {
1618 HessianResult::Analytic(hessian) => hessian,
1619 HessianResult::Operator(_) | HessianResult::Unavailable => {
1620 panic!("expected dense analytic Hessian")
1621 }
1622 };
1623 let delta = 2.0e-5;
1624 for col in 0..rho.len() {
1625 let mut rp = rho.clone();
1626 let mut rm = rho.clone();
1627 rp[col] += delta;
1628 rm[col] -= delta;
1629 let gp = state
1630 .compute_outer_eval_with_order(
1631 &rp,
1632 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1633 )
1634 .expect("plus grad")
1635 .gradient;
1636 let gm = state
1637 .compute_outer_eval_with_order(
1638 &rm,
1639 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1640 )
1641 .expect("minus grad")
1642 .gradient;
1643 for row in 0..rho.len() {
1644 let fd = (gp[row] - gm[row]) / (2.0 * delta);
1645 let an = h[[row, col]];
1646 let rel = (fd - an).abs() / fd.abs().max(an.abs()).max(1e-6);
1647 assert!(
1648 rel < 2.0e-3,
1649 "Hessian mismatch ({row},{col}): analytic={an:.9e}, fd={fd:.9e}, rel={rel:.3e}"
1650 );
1651 }
1652 }
1653 }
1654
1655 #[test]
1656 pub(crate) fn firthgradient_lives_in_design_column_space_under_rank_deficiency() {
1657 let x = array![
1659 [1.0, -1.2, 0.4, -2.4],
1660 [1.0, -0.9, -0.1, -1.8],
1661 [1.0, -0.6, 0.3, -1.2],
1662 [1.0, -0.2, -0.4, -0.4],
1663 [1.0, 0.1, 0.5, 0.2],
1664 [1.0, 0.4, -0.6, 0.8],
1665 [1.0, 0.8, 0.2, 1.6],
1666 [1.0, 1.1, -0.3, 2.2],
1667 ];
1668 let beta = array![0.1, -0.2, 0.3, 0.05];
1669 let eta = x.dot(&beta);
1670 let op = super::RemlState::build_firth_dense_operator_for_link(
1671 &gam_problem::InverseLink::Standard(gam_problem::StandardLink::Logit),
1672 &x,
1673 &eta,
1674 ndarray::Array1::ones(x.nrows()).view(),
1675 )
1676 .expect("firth operator");
1677
1678 let gradphi = 0.5 * x.t().dot(&(&op.w1 * &op.h_diag));
1681
1682 let q = &op.q_basis;
1684 let proj = q.dot(&q.t().dot(&gradphi));
1685 let resid = &gradphi - &proj;
1686 let rel =
1687 resid.mapv(|v| v * v).sum().sqrt() / gradphi.mapv(|v| v * v).sum().sqrt().max(1e-12);
1688 assert!(
1689 rel < 1e-10,
1690 "Firth gradient should lie in Col(Xᵀ): rel residual={rel:.3e}"
1691 );
1692 }
1693
1694 #[test]
1695 pub(crate) fn firth_logit_directional_hypergradient_accepts_penalty_only_with_full_tk_gradient()
1696 {
1697 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1698 let w = Array1::<f64>::ones(y.len());
1699 let x = array![
1700 [1.0, -1.1, 0.2],
1701 [1.0, -0.6, -0.3],
1702 [1.0, -0.1, 0.5],
1703 [1.0, 0.3, -0.7],
1704 [1.0, 0.8, 0.1],
1705 [1.0, 1.2, -0.4],
1706 ];
1707 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1708 let hyper = DirectionalHyperParam::single_penalty(
1709 0,
1710 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1711 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1712 None,
1713 None,
1714 )
1715 .expect("single-penalty hyper direction");
1716 let rho = array![0.0];
1717 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1718 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1719 let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1720 .expect("Firth penalty-only directional gradient should use analytic TK propagation");
1721 assert!(gradient.is_finite(), "gradient={gradient}");
1722 let fd = fd_directional_tau_cost_gradient(
1723 &y,
1724 &w,
1725 &x,
1726 &s0,
1727 &cfg,
1728 &rho,
1729 &Array2::<f64>::zeros((x.nrows(), x.ncols())),
1730 &array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1731 );
1732 let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1733 assert!(
1734 rel < 1.0e-3,
1735 "Firth penalty-only directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1736 );
1737
1738 let efs_hyper = DirectionalHyperParam::single_penalty(
1739 0,
1740 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1741 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1742 None,
1743 None,
1744 )
1745 .expect("single-penalty EFS hyper direction");
1746 let efs = state
1747 .compute_efs_steps_with_psi_ext(&rho, &[efs_hyper])
1748 .expect("Firth penalty-only EFS should use analytic TK propagation");
1749 assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1750 }
1751
1752 #[test]
1753 pub(crate) fn firth_logit_directional_hypergradient_accepts_design_moving_with_full_tk_gradient()
1754 {
1755 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1756 let w = Array1::<f64>::ones(y.len());
1757 let x = array![
1758 [1.0, -1.1, 0.2],
1759 [1.0, -0.6, -0.3],
1760 [1.0, -0.1, 0.5],
1761 [1.0, 0.3, -0.7],
1762 [1.0, 0.8, 0.1],
1763 [1.0, 1.2, -0.4],
1764 ];
1765 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1766 let hyper = DirectionalHyperParam::single_penalty(
1767 0,
1768 Array2::from_elem((x.nrows(), x.ncols()), 1e-3),
1769 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1770 None,
1771 None,
1772 )
1773 .expect("single-penalty hyper direction");
1774 let rho = array![0.0];
1775 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1776 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1777 let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1778 .expect("Firth design-moving directional gradient should use analytic TK propagation");
1779 assert!(gradient.is_finite(), "gradient={gradient}");
1780 let x_tau = Array2::from_elem((x.nrows(), x.ncols()), 1e-3);
1781 let s_tau = Array2::<f64>::zeros((x.ncols(), x.ncols()));
1782 let fd = fd_directional_tau_cost_gradient(&y, &w, &x, &s0, &cfg, &rho, &x_tau, &s_tau);
1783 let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1784 assert!(
1785 rel < 2.0e-2,
1786 "Firth design-moving directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1787 );
1788 }
1789
1790 #[test]
1791 pub(crate) fn firth_logit_hybrid_efs_accepts_full_tk_psi_gradient() {
1792 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1793 let w = Array1::<f64>::ones(y.len());
1794 let x = array![
1795 [1.0, -1.1, 0.2],
1796 [1.0, -0.6, -0.3],
1797 [1.0, -0.1, 0.5],
1798 [1.0, 0.3, -0.7],
1799 [1.0, 0.8, 0.1],
1800 [1.0, 1.2, -0.4],
1801 ];
1802 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1803 let hyper_dirs = vec![
1804 DirectionalHyperParam::single_penalty(
1805 0,
1806 Array2::from_shape_fn((x.nrows(), x.ncols()), |(i, j)| {
1807 1e-3 * ((i + 1) as f64) * ((j + 2) as f64)
1808 }),
1809 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1810 None,
1811 None,
1812 )
1813 .expect("design-moving hyper direction"),
1814 ];
1815 let rho = array![0.0];
1816 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1817 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1818
1819 let full = state
1820 .evaluate_unified_with_psi_ext(
1821 &rho,
1822 None,
1823 crate::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
1824 &hyper_dirs,
1825 )
1826 .expect("full Firth psi gradient should use analytic TK propagation");
1827 assert!(full.cost.is_finite(), "full cost={}", full.cost);
1828 let full_grad = full.gradient.expect("gradient should be present");
1829 assert!(
1830 full_grad.iter().all(|value| value.is_finite()),
1831 "full gradient={full_grad:?}"
1832 );
1833
1834 let efs = state
1835 .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
1836 .expect("hybrid EFS should use analytic TK propagation");
1837 assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1838 }
1839
1840 #[test]
1841 pub(crate) fn joint_hyperhessianwires_mixed_blocks() {
1842 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1843 let w = Array1::<f64>::ones(y.len());
1844 let x = array![
1845 [1.0, -1.2, 0.3],
1846 [1.0, -0.8, -0.4],
1847 [1.0, -0.3, 0.7],
1848 [1.0, 0.1, -0.9],
1849 [1.0, 0.5, 0.2],
1850 [1.0, 0.9, -0.1],
1851 [1.0, 1.3, 0.8],
1852 [1.0, 1.7, -0.6],
1853 ];
1854 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1855 let cfg =
1856 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1857 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1858 let rho = array![0.0];
1859 let theta = array![0.0, 0.0, 0.0];
1860 let hyper_dirs = vec![
1861 DirectionalHyperParam::single_penalty(
1862 0,
1863 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1864 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1865 None,
1866 None,
1867 )
1868 .expect("single-penalty hyper direction"),
1869 DirectionalHyperParam::single_penalty(
1870 0,
1871 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1872 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1873 None,
1874 None,
1875 )
1876 .expect("single-penalty hyper direction"),
1877 ];
1878
1879 let (_, _, h) =
1880 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1881 .expect("joint hyper cost+gradient+hessian");
1882 assert_eq!(h.nrows(), theta.len());
1883 assert_eq!(h.ncols(), theta.len());
1884 assert!(h.iter().all(|v| v.is_finite()));
1885 for i in 0..h.nrows() {
1886 for j in 0..i {
1887 let diff = (h[[i, j]] - h[[j, i]]).abs();
1888 assert!(
1889 diff < 1e-6,
1890 "joint hessian asymmetry at ({i},{j}): {diff:.3e}"
1891 );
1892 }
1893 }
1894 let mixed_0 = h[[0, 1]];
1896 let mixed_1 = h[[0, 2]];
1897 assert!(
1898 mixed_0.is_finite() && mixed_1.is_finite(),
1899 "mixed blocks must be finite"
1900 );
1901 }
1902
1903 #[test]
1904 pub(crate) fn joint_tau_tau_linear_dirs_matchfd_reference_away_fromzero_psi() {
1905 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1906 let w = Array1::<f64>::ones(y.len());
1907 let x = array![
1908 [1.0, -1.2, 0.3],
1909 [1.0, -0.8, -0.4],
1910 [1.0, -0.3, 0.7],
1911 [1.0, 0.1, -0.9],
1912 [1.0, 0.5, 0.2],
1913 [1.0, 0.9, -0.1],
1914 [1.0, 1.3, 0.8],
1915 [1.0, 1.7, -0.6],
1916 ];
1917 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1918 let cfg =
1919 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1920 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1921 let rho = array![0.0];
1922 let psi = array![0.7, -0.4];
1923 let theta = array![rho[0], psi[0], psi[1]];
1924 let hyper_dirs = vec![
1925 DirectionalHyperParam::single_penalty(
1926 0,
1927 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1928 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1929 None,
1930 None,
1931 )
1932 .expect("linear tau direction"),
1933 DirectionalHyperParam::single_penalty(
1934 0,
1935 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1936 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1937 None,
1938 None,
1939 )
1940 .expect("linear tau direction"),
1941 ];
1942
1943 let (_, _, h_full) =
1944 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1945 .expect("joint hyper cost+gradient+hessian");
1946 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
1947
1948 let x_tau_mats: Vec<Array2<f64>> = vec![
1953 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1954 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1955 ];
1956 let s_tau_mats: Vec<Array2<f64>> = vec![
1957 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
1958 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1959 ];
1960
1961 let h_ttfd = directional_tau_hessian_fd_reference(
1962 &y,
1963 &w,
1964 &x,
1965 &s0,
1966 &cfg,
1967 &rho,
1968 &hyper_dirs,
1969 &x_tau_mats,
1970 &s_tau_mats,
1971 );
1972
1973 let num = (&h_tt_analytic - &h_ttfd)
1974 .iter()
1975 .map(|v| v * v)
1976 .sum::<f64>()
1977 .sqrt();
1978 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
1979 let rel = num / den;
1980 assert!(
1981 rel < 1e-4,
1982 "linear-dir joint tau-tau block deviates from FD reference away from zero psi: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
1983 );
1984 }
1985
1986 #[test]
1987 pub(crate) fn joint_hypervalidation_rejects_out_of_boundssecond_order_penalty_index() {
1988 let y = array![0.0, 1.0, 0.0, 1.0];
2005 let w = Array1::<f64>::ones(y.len());
2006 let x = array![
2007 [1.0, -0.5, 0.2],
2008 [1.0, -0.1, -0.3],
2009 [1.0, 0.4, 0.6],
2010 [1.0, 0.9, -0.2],
2011 ];
2012 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
2013 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
2014 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2015 let theta = array![0.0, 0.0];
2016 let hyper_dirs = vec![
2017 DirectionalHyperParam::new(
2018 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2019 vec![(0, Array2::<f64>::zeros((x.ncols(), x.ncols())))],
2020 None,
2021 Some(vec![Some(vec![(1, Array2::<f64>::eye(x.ncols()))])]),
2022 )
2023 .expect("hyper direction with invalid second-order penalty index"),
2024 ];
2025
2026 let msg = match compute_joint_hypercostgradienthessian(&state, &theta, 1, &hyper_dirs) {
2027 Ok(_) => panic!("invalid second-order penalty index should be rejected"),
2028 Err(err) => err.to_string(),
2029 };
2030 assert!(
2031 msg.contains("out of bounds") || msg.contains("penalty_index"),
2032 "unexpected validation error: {msg}"
2033 );
2034 }
2035
2036 #[test]
2037 pub(crate) fn joint_tau_tau_analytic_matchesfd_reference() {
2038 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
2039 let w = Array1::<f64>::ones(y.len());
2040 let x = array![
2041 [1.0, -1.2, 0.3],
2042 [1.0, -0.8, -0.4],
2043 [1.0, -0.3, 0.7],
2044 [1.0, 0.1, -0.9],
2045 [1.0, 0.5, 0.2],
2046 [1.0, 0.9, -0.1],
2047 [1.0, 1.3, 0.8],
2048 [1.0, 1.7, -0.6],
2049 ];
2050 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
2051 let cfg =
2052 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
2053 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2054 let rho = array![0.0];
2055 let psi = array![0.0, 0.0];
2056 let hyper_dirs = vec![
2057 DirectionalHyperParam::single_penalty(
2058 0,
2059 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2060 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
2061 None,
2062 None,
2063 )
2064 .expect("single-penalty hyper direction"),
2065 DirectionalHyperParam::single_penalty(
2066 0,
2067 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2068 Array2::<f64>::zeros((x.ncols(), x.ncols())),
2069 None,
2070 None,
2071 )
2072 .expect("single-penalty hyper direction"),
2073 ];
2074
2075 let theta = {
2076 let mut t = Array1::<f64>::zeros(rho.len() + psi.len());
2077 t.slice_mut(s![..rho.len()]).assign(&rho);
2078 t.slice_mut(s![rho.len()..]).assign(&psi);
2079 t
2080 };
2081 let (_, _, h_full) =
2082 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2083 .expect("joint hyper cost+gradient+hessian");
2084 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2085 assert_eq!(h_tt_analytic.nrows(), hyper_dirs.len());
2086 assert_eq!(h_tt_analytic.ncols(), hyper_dirs.len());
2087
2088 let x_tau_mats: Vec<Array2<f64>> = vec![
2093 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2094 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2095 ];
2096 let s_tau_mats: Vec<Array2<f64>> = vec![
2097 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
2098 Array2::<f64>::zeros((x.ncols(), x.ncols())),
2099 ];
2100
2101 let h_ttfd = directional_tau_hessian_fd_reference(
2102 &y,
2103 &w,
2104 &x,
2105 &s0,
2106 &cfg,
2107 &rho,
2108 &hyper_dirs,
2109 &x_tau_mats,
2110 &s_tau_mats,
2111 );
2112
2113 let num = (&h_tt_analytic - &h_ttfd)
2114 .iter()
2115 .map(|v| v * v)
2116 .sum::<f64>()
2117 .sqrt();
2118 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2119 let rel = num / den;
2120 assert!(
2121 rel < 1e-4,
2122 "analytic tau-tau block deviates from FD reference: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2123 );
2124 }
2125
2126 pub(crate) struct GaussianRemlFixture {
2136 pub(crate) y: Array1<f64>,
2137 pub(crate) w: Array1<f64>,
2138 pub(crate) x: Array2<f64>,
2139 pub(crate) s0: Array2<f64>,
2140 pub(crate) cfg: RemlConfig,
2141 pub(crate) rho: Array1<f64>,
2142 pub(crate) x_tau_design: Array2<f64>,
2144 pub(crate) s_tau_penalty: Array2<f64>,
2146 }
2147
2148 impl GaussianRemlFixture {
2149 pub(crate) fn new() -> Self {
2150 let y = array![0.5, 1.2, -0.3, 0.8, 1.1, -0.6, 0.9, 0.1, -0.2, 0.7];
2151 let x = array![
2152 [1.0, -1.2, 0.3],
2153 [1.0, -0.8, -0.4],
2154 [1.0, -0.3, 0.7],
2155 [1.0, 0.1, -0.9],
2156 [1.0, 0.5, 0.2],
2157 [1.0, 0.9, -0.1],
2158 [1.0, 1.3, 0.8],
2159 [1.0, 1.7, -0.6],
2160 [1.0, -0.5, 0.5],
2161 [1.0, 0.3, -0.3],
2162 ];
2163 Self {
2164 w: Array1::<f64>::ones(y.len()),
2165 y,
2166 x: x.clone(),
2167 s0: array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]],
2168 cfg: RemlConfig::external(gaussian_identity_glm_spec(), 1e-14, false),
2169 rho: array![0.0],
2170 x_tau_design: array![
2171 [0.0, 1e-3, -2e-3],
2172 [0.0, -3e-3, 1e-3],
2173 [0.0, 2e-3, 0.5e-3],
2174 [0.0, -1e-3, 3e-3],
2175 [0.0, 0.5e-3, -1e-3],
2176 [0.0, 1.5e-3, 2e-3],
2177 [0.0, -2e-3, -0.5e-3],
2178 [0.0, 3e-3, 1e-3],
2179 [0.0, -0.5e-3, 2e-3],
2180 [0.0, 1e-3, -1.5e-3],
2181 ],
2182 s_tau_penalty: array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
2183 }
2184 }
2185 }
2186
    /// Exposes the Gaussian fixture through the shared `LogitDesignMotionFixture`
    /// accessors, so the trait's helpers (`state`, `fd_directional_gradient`, ...) work
    /// here too.
    ///
    /// NOTE(review): despite the trait name, nothing in these accessors is
    /// logit-specific; presumably the model family comes from `cfg`. Confirm in the
    /// trait's helpers.
    impl LogitDesignMotionFixture for GaussianRemlFixture {
        fn y(&self) -> &Array1<f64> {
            &self.y
        }
        fn w(&self) -> &Array1<f64> {
            &self.w
        }
        fn x(&self) -> &Array2<f64> {
            &self.x
        }
        fn s0(&self) -> &Array2<f64> {
            &self.s0
        }
        fn cfg(&self) -> &RemlConfig {
            &self.cfg
        }
        fn rho(&self) -> &Array1<f64> {
            &self.rho
        }
    }
2207
2208 #[test]
2209 pub(crate) fn profiled_gaussian_design_moving_gradient_matches_fd() {
2210 let f = GaussianRemlFixture::new();
2211 let state = f.state();
2212 let s_tau = Array2::<f64>::zeros((3, 3));
2213 let hyper = DirectionalHyperParam::single_penalty(
2214 0,
2215 f.x_tau_design.clone(),
2216 s_tau.clone(),
2217 None,
2218 None,
2219 )
2220 .expect("design-moving hyper direction");
2221
2222 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2223 .expect("analytic directional gradient");
2224 let v_taufd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2225
2226 let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2227 assert!(
2228 v_rel < 1e-3,
2229 "Gaussian REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2230 analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2231 );
2232 }
2233
2234 #[test]
2235 pub(crate) fn profiled_gaussian_penalty_only_gradient_matches_fd() {
2236 let f = GaussianRemlFixture::new();
2237 let state = f.state();
2238 let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2239 let hyper = DirectionalHyperParam::single_penalty(
2240 0,
2241 x_tau.clone(),
2242 f.s_tau_penalty.clone(),
2243 None,
2244 None,
2245 )
2246 .expect("penalty-only hyper direction");
2247
2248 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2249 .expect("analytic directional gradient");
2250 let v_taufd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2251
2252 let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2253 assert!(
2254 v_rel < 1e-3,
2255 "Gaussian REML penalty-only V_tau mismatch: rel={v_rel:.3e}, \
2256 analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2257 );
2258 }
2259
2260 #[test]
2261 pub(crate) fn profiled_gaussian_joint_hessian_matches_fd() {
2262 let f = GaussianRemlFixture::new();
2265 let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2266 let s_tau_0 = f.s_tau_penalty.clone();
2267 let x_tau_1 = f.x_tau_design.clone();
2268 let s_tau_1 = Array2::<f64>::zeros((3, 3));
2269
2270 let hyper_dirs = vec![
2271 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2272 .expect("penalty-only direction"),
2273 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2274 .expect("design-moving direction"),
2275 ];
2276
2277 let state = f.state();
2278 let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2279 theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2280 let (_, _, h_full) =
2281 compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2282 .expect("joint cost+gradient+hessian");
2283 let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2284
2285 let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2288 let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2289 let h_ttfd = directional_tau_hessian_fd_reference(
2290 &f.y,
2291 &f.w,
2292 &f.x,
2293 &f.s0,
2294 &f.cfg,
2295 &f.rho,
2296 &hyper_dirs,
2297 &x_tau_mats,
2298 &s_tau_mats,
2299 );
2300
2301 let num = (&h_tt_analytic - &h_ttfd)
2302 .iter()
2303 .map(|v| v * v)
2304 .sum::<f64>()
2305 .sqrt();
2306 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2307 let rel = num / den;
2308 assert!(
2309 rel < 1e-4,
2310 "Gaussian REML tau-tau Hessian mismatch: rel={rel:.3e}, \
2311 analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2312 );
2313 }
2314
2315 #[test]
2329 pub(crate) fn logit_design_moving_gradient_matches_fd() {
2330 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2331 let w = Array1::<f64>::ones(y.len());
2332 let x = array![
2333 [1.0, -1.2, 0.3],
2334 [1.0, -0.8, -0.4],
2335 [1.0, -0.3, 0.7],
2336 [1.0, 0.1, -0.9],
2337 [1.0, 0.5, 0.2],
2338 [1.0, 0.9, -0.1],
2339 [1.0, 1.3, 0.8],
2340 [1.0, 1.7, -0.6],
2341 [1.0, -0.5, 0.5],
2342 [1.0, 0.3, -0.3],
2343 ];
2344 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2345 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2346 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2347 let rho = array![0.0];
2348
2349 let x_tau = array![
2351 [0.0, 1e-3, -2e-3],
2352 [0.0, -3e-3, 1e-3],
2353 [0.0, 2e-3, 0.5e-3],
2354 [0.0, -1e-3, 3e-3],
2355 [0.0, 0.5e-3, -1e-3],
2356 [0.0, 1.5e-3, 2e-3],
2357 [0.0, -2e-3, -0.5e-3],
2358 [0.0, 3e-3, 1e-3],
2359 [0.0, -0.5e-3, 2e-3],
2360 [0.0, 1e-3, -1.5e-3],
2361 ];
2362 let s_tau = Array2::<f64>::zeros((3, 3));
2363 let hyper =
2364 DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
2365 .expect("design-moving hyper direction");
2366
2367 let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2368 .expect("analytic directional gradient");
2369
2370 let h = 2e-5;
2371 let x_plus = &x + &x_tau.mapv(|v| h * v);
2372 let x_minus = &x - &x_tau.mapv(|v| h * v);
2373 let state_plus = build_logit_state(&y, &w, &x_plus, &s0, &cfg);
2374 let state_minus = build_logit_state(&y, &w, &x_minus, &s0, &cfg);
2375 let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2376 let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2377 let v_taufd = (v_plus - v_minus) / (2.0 * h);
2378
2379 let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2380 assert!(
2381 v_rel < 1e-3,
2382 "Logit REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2383 analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2384 );
2385 }
2386
2387 #[test]
2388 pub(crate) fn logit_design_moving_hessian_matches_fd() {
2389 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2394 let w = Array1::<f64>::ones(y.len());
2395 let x = array![
2396 [1.0, -1.2, 0.3],
2397 [1.0, -0.8, -0.4],
2398 [1.0, -0.3, 0.7],
2399 [1.0, 0.1, -0.9],
2400 [1.0, 0.5, 0.2],
2401 [1.0, 0.9, -0.1],
2402 [1.0, 1.3, 0.8],
2403 [1.0, 1.7, -0.6],
2404 [1.0, -0.5, 0.5],
2405 [1.0, 0.3, -0.3],
2406 ];
2407 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2408 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2409 let rho = array![0.0];
2410
2411 let x_tau_0 = Array2::<f64>::zeros(x.raw_dim());
2413 let s_tau_0 = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]];
2414 let x_tau_1 = array![
2415 [0.0, 1e-3, -2e-3],
2416 [0.0, -3e-3, 1e-3],
2417 [0.0, 2e-3, 0.5e-3],
2418 [0.0, -1e-3, 3e-3],
2419 [0.0, 0.5e-3, -1e-3],
2420 [0.0, 1.5e-3, 2e-3],
2421 [0.0, -2e-3, -0.5e-3],
2422 [0.0, 3e-3, 1e-3],
2423 [0.0, -0.5e-3, 2e-3],
2424 [0.0, 1e-3, -1.5e-3],
2425 ];
2426 let s_tau_1 = Array2::<f64>::zeros((3, 3));
2427
2428 let hyper_dirs = vec![
2429 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2430 .expect("penalty-only direction"),
2431 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2432 .expect("design-moving direction"),
2433 ];
2434
2435 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2436 let mut theta = Array1::<f64>::zeros(rho.len() + hyper_dirs.len());
2437 theta.slice_mut(s![..rho.len()]).assign(&rho);
2438 let (_, _, h_full) =
2439 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2440 .expect("joint cost+gradient+hessian");
2441 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2442
2443 let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2444 let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2445 let h_ttfd = directional_tau_hessian_fd_reference(
2446 &y,
2447 &w,
2448 &x,
2449 &s0,
2450 &cfg,
2451 &rho,
2452 &hyper_dirs,
2453 &x_tau_mats,
2454 &s_tau_mats,
2455 );
2456
2457 let num = (&h_tt_analytic - &h_ttfd)
2458 .iter()
2459 .map(|v| v * v)
2460 .sum::<f64>()
2461 .sqrt();
2462 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2463 let rel = num / den;
2464 assert!(
2465 rel < 1e-4,
2466 "Logit REML design-moving tau-tau Hessian mismatch: rel={rel:.3e}, \
2467 analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2468 );
2469 }
2470
    /// Larger binomial-logit REML problem (n = 30, p = 5) shared by the
    /// `binomial_logit_n30_*` analytic-vs-finite-difference derivative tests.
    pub(crate) struct BinomialLogitDesignMotionFixture {
        /// 0/1 responses.
        pub(crate) y: Array1<f64>,
        /// Prior weights (all ones in `new`).
        pub(crate) w: Array1<f64>,
        /// Design matrix; the first column is the all-ones intercept in `new`.
        pub(crate) x: Array2<f64>,
        /// Base penalty with zero first row/column (unpenalized intercept).
        pub(crate) s0: Array2<f64>,
        /// REML configuration built with `RemlConfig::external` for the binomial logit spec.
        pub(crate) rho: Array1<f64>,
        /// Direction in which the design moves for design-moving tests (intercept fixed).
        pub(crate) x_tau_design: Array2<f64>,
        /// Direction in which the penalty moves for penalty-only tests.
        pub(crate) s_tau_penalty: Array2<f64>,
    }
2492
    impl BinomialLogitDesignMotionFixture {
        /// Builds the n = 30, p = 5 binomial-logit fixture with unit weights and
        /// `rho = [0.0]`.
        ///
        /// The first design column is an intercept. Neither `x_tau_design` nor the two
        /// penalty matrices touch it.
        pub(crate) fn new() -> Self {
            let y = array![
                // 30 binary responses.
                1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
                1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
            ];
            let x = array![
                // Columns: intercept, then four covariates.
                [1.0, -1.50, 0.42, 0.88, -0.31],
                [1.0, -1.12, -0.65, 0.14, 1.23],
                [1.0, -0.80, 1.10, -0.53, 0.07],
                [1.0, -0.55, -0.22, 1.40, -0.90],
                [1.0, -0.30, 0.73, -1.05, 0.44],
                [1.0, -0.05, -1.33, 0.60, 0.81],
                [1.0, 0.18, 0.55, -0.27, -1.15],
                [1.0, 0.42, -0.90, 1.12, 0.33],
                [1.0, 0.70, 1.28, -0.78, -0.56],
                [1.0, 0.95, -0.18, 0.45, 1.40],
                [1.0, 1.20, 0.66, -1.30, -0.02],
                [1.0, 1.45, -1.05, 0.22, 0.68],
                [1.0, -1.35, 0.90, 0.55, -0.43],
                [1.0, -0.98, -0.40, -0.88, 1.05],
                [1.0, -0.62, 1.42, 0.30, -0.70],
                [1.0, -0.28, -0.77, -1.18, 0.52],
                [1.0, 0.05, 0.15, 0.95, -1.35],
                [1.0, 0.33, -1.20, -0.40, 0.18],
                [1.0, 0.60, 0.82, 1.25, -0.85],
                [1.0, 0.88, -0.50, -0.65, 1.10],
                [1.0, 1.15, 1.05, 0.10, -0.22],
                [1.0, -1.22, -0.95, 0.72, 0.90],
                [1.0, -0.75, 0.38, -1.42, 0.15],
                [1.0, -0.42, -1.15, 0.50, -1.08],
                [1.0, -0.10, 0.60, -0.15, 0.75],
                [1.0, 0.25, -0.28, 1.05, -0.48],
                [1.0, 0.52, 1.35, -0.92, 0.30],
                [1.0, 0.80, -0.70, 0.38, 1.20],
                [1.0, 1.08, 0.48, -0.60, -0.95],
                [1.0, 1.35, -0.55, 0.85, 0.42]
            ];
            let s0 = array![
                // Symmetric base penalty; zero first row/column leaves the intercept free.
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 1.40, 0.15, 0.05, -0.10],
                [0.0, 0.15, 1.10, -0.20, 0.08],
                [0.0, 0.05, -0.20, 0.95, 0.12],
                [0.0, -0.10, 0.08, 0.12, 1.25]
            ];
            let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
            let x_tau_design = array![
                // Design-motion direction: O(1e-3) entries, intercept column held at zero.
                [0.0, 1.2e-3, -0.8e-3, 0.5e-3, -1.5e-3],
                [0.0, -2.0e-3, 1.4e-3, -0.3e-3, 0.9e-3],
                [0.0, 0.6e-3, -1.1e-3, 1.8e-3, -0.4e-3],
                [0.0, -1.3e-3, 0.7e-3, -1.0e-3, 2.1e-3],
                [0.0, 0.9e-3, -0.5e-3, 0.2e-3, -0.8e-3],
                [0.0, -0.4e-3, 1.8e-3, -1.5e-3, 0.3e-3],
                [0.0, 1.5e-3, -1.3e-3, 0.8e-3, -1.1e-3],
                [0.0, -0.7e-3, 0.4e-3, -2.0e-3, 1.6e-3],
                [0.0, 2.2e-3, -0.9e-3, 1.3e-3, -0.6e-3],
                [0.0, -1.0e-3, 1.6e-3, -0.7e-3, 0.5e-3],
                [0.0, 0.3e-3, -2.1e-3, 1.1e-3, -1.8e-3],
                [0.0, -1.8e-3, 0.2e-3, -0.4e-3, 1.3e-3],
                [0.0, 1.1e-3, -1.5e-3, 2.0e-3, -0.2e-3],
                [0.0, -0.5e-3, 0.9e-3, -1.2e-3, 0.7e-3],
                [0.0, 1.7e-3, -0.3e-3, 0.6e-3, -2.0e-3],
                [0.0, -1.4e-3, 1.1e-3, -0.9e-3, 0.4e-3],
                [0.0, 0.8e-3, -1.7e-3, 1.5e-3, -0.1e-3],
                [0.0, -0.2e-3, 0.6e-3, -1.8e-3, 1.0e-3],
                [0.0, 1.4e-3, -0.4e-3, 0.3e-3, -1.3e-3],
                [0.0, -0.9e-3, 2.0e-3, -0.5e-3, 0.8e-3],
                [0.0, 0.5e-3, -1.0e-3, 1.6e-3, -0.7e-3],
                [0.0, -2.1e-3, 0.3e-3, -0.8e-3, 1.5e-3],
                [0.0, 0.7e-3, -1.8e-3, 0.9e-3, -0.3e-3],
                [0.0, -0.6e-3, 1.3e-3, -2.2e-3, 1.1e-3],
                [0.0, 1.9e-3, -0.7e-3, 0.4e-3, -0.9e-3],
                [0.0, -1.1e-3, 0.5e-3, -1.4e-3, 2.2e-3],
                [0.0, 0.4e-3, -1.6e-3, 1.2e-3, -0.5e-3],
                [0.0, -1.6e-3, 0.8e-3, -0.1e-3, 0.6e-3],
                [0.0, 1.3e-3, -2.2e-3, 0.7e-3, -1.4e-3],
                [0.0, -0.3e-3, 1.0e-3, -1.6e-3, 1.8e-3]
            ];
            let s_tau_penalty = array![
                // Symmetric penalty-motion direction with the same zero intercept row/column.
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.30, 0.05, -0.02, 0.04],
                [0.0, 0.05, 0.22, 0.03, -0.01],
                [0.0, -0.02, 0.03, 0.18, 0.06],
                [0.0, 0.04, -0.01, 0.06, 0.26]
            ];
            Self {
                // Evaluated before `y` is moved into the struct on the next line.
                w: Array1::<f64>::ones(y.len()),
                y,
                x,
                s0,
                cfg,
                rho: array![0.0],
                x_tau_design,
                s_tau_penalty,
            }
        }
    }
2596
    /// Exposes the n = 30 fixture through the shared `LogitDesignMotionFixture`
    /// accessors, so the trait's helpers (`state`, `state_perturbed`,
    /// `fd_directional_gradient`) can be used by the `binomial_logit_n30_*` tests.
    impl LogitDesignMotionFixture for BinomialLogitDesignMotionFixture {
        fn y(&self) -> &Array1<f64> {
            &self.y
        }
        fn w(&self) -> &Array1<f64> {
            &self.w
        }
        fn x(&self) -> &Array2<f64> {
            &self.x
        }
        fn s0(&self) -> &Array2<f64> {
            &self.s0
        }
        fn cfg(&self) -> &RemlConfig {
            &self.cfg
        }
        fn rho(&self) -> &Array1<f64> {
            &self.rho
        }
    }
2617
2618 #[test]
2621 pub(crate) fn binomial_logit_n30_design_moving_gradient_matches_fd() {
2622 let f = BinomialLogitDesignMotionFixture::new();
2629 let state = f.state();
2630 let s_tau = Array2::<f64>::zeros((5, 5));
2631 let hyper = DirectionalHyperParam::single_penalty(
2632 0,
2633 f.x_tau_design.clone(),
2634 s_tau.clone(),
2635 None,
2636 None,
2637 )
2638 .expect("design-moving hyper direction");
2639
2640 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2641 .expect("analytic directional gradient");
2642 let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2643
2644 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2645 assert!(
2646 v_rel < 1e-3,
2647 "Binomial-logit n=30 design-moving gradient mismatch: rel={v_rel:.3e}, \
2648 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2649 );
2650 }
2651
2652 #[test]
2653 pub(crate) fn binomial_logit_n30_penalty_only_gradient_matches_fd() {
2654 let f = BinomialLogitDesignMotionFixture::new();
2659 let state = f.state();
2660 let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2661 let hyper = DirectionalHyperParam::single_penalty(
2662 0,
2663 x_tau.clone(),
2664 f.s_tau_penalty.clone(),
2665 None,
2666 None,
2667 )
2668 .expect("penalty-only hyper direction");
2669
2670 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2671 .expect("analytic directional gradient");
2672 let v_tau_fd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2673
2674 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2675 assert!(
2676 v_rel < 1e-3,
2677 "Binomial-logit n=30 penalty-only gradient mismatch: rel={v_rel:.3e}, \
2678 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2679 );
2680 }
2681
2682 #[test]
2683 pub(crate) fn binomial_logit_n30_joint_design_penalty_gradient_matches_fd() {
2684 let f = BinomialLogitDesignMotionFixture::new();
2689 let state = f.state();
2690 let hyper = DirectionalHyperParam::single_penalty(
2691 0,
2692 f.x_tau_design.clone(),
2693 f.s_tau_penalty.clone(),
2694 None,
2695 None,
2696 )
2697 .expect("joint design+penalty hyper direction");
2698
2699 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2700 .expect("analytic directional gradient");
2701 let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &f.s_tau_penalty);
2702
2703 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2704 assert!(
2705 v_rel < 1e-3,
2706 "Binomial-logit n=30 joint design+penalty gradient mismatch: rel={v_rel:.3e}, \
2707 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2708 );
2709 }
2710
2711 #[test]
2712 pub(crate) fn binomial_logit_n30_design_moving_hessian_matches_fd() {
2713 let f = BinomialLogitDesignMotionFixture::new();
2718 let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2719 let s_tau_0 = f.s_tau_penalty.clone();
2720 let x_tau_1 = f.x_tau_design.clone();
2721 let s_tau_1 = Array2::<f64>::zeros((5, 5));
2722
2723 let hyper_dirs = vec![
2724 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2725 .expect("penalty-only direction"),
2726 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2727 .expect("design-moving direction"),
2728 ];
2729
2730 let state = f.state();
2731 let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2732 theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2733 let (_, _, h_full) =
2734 compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2735 .expect("joint cost+gradient+hessian");
2736 let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2737
2738 let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2739 let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2740 let h_tt_fd = directional_tau_hessian_fd_reference(
2741 &f.y,
2742 &f.w,
2743 &f.x,
2744 &f.s0,
2745 &f.cfg,
2746 &f.rho,
2747 &hyper_dirs,
2748 &x_tau_mats,
2749 &s_tau_mats,
2750 );
2751
2752 let num = (&h_tt_analytic - &h_tt_fd)
2753 .iter()
2754 .map(|v| v * v)
2755 .sum::<f64>()
2756 .sqrt();
2757 let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2758 let rel = num / den;
2759 assert!(
2760 rel < 1e-4,
2761 "Binomial-logit n=30 tau-tau Hessian mismatch: rel={rel:.3e}, \
2762 analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2763 );
2764 }
2765
2766 #[test]
2767 pub(crate) fn binomial_logit_n30_nonzero_rho_design_moving_gradient_matches_fd() {
2768 let f = BinomialLogitDesignMotionFixture::new();
2772 let rho = array![1.5];
2773 let s_tau = Array2::<f64>::zeros((5, 5));
2774
2775 let state = f.state();
2776 let hyper = DirectionalHyperParam::single_penalty(
2777 0,
2778 f.x_tau_design.clone(),
2779 s_tau.clone(),
2780 None,
2781 None,
2782 )
2783 .expect("design-moving hyper direction");
2784
2785 let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2786 .expect("analytic directional gradient");
2787
2788 let h = 2e-5;
2790 let (state_plus, state_minus) = f.state_perturbed(&f.x_tau_design, &s_tau, h);
2791 let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2792 let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2793 let v_tau_fd = (v_plus - v_minus) / (2.0 * h);
2794
2795 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2796 assert!(
2797 v_rel < 1e-3,
2798 "Binomial-logit n=30 rho=1.5 design-moving gradient mismatch: rel={v_rel:.3e}, \
2799 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2800 );
2801 }
2802
2803 #[test]
2804 pub(crate) fn binomial_logit_n30_rank_deficient_hessian_matches_cost_fd() {
2805 let f = BinomialLogitDesignMotionFixture::new();
2840 let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2841 let s_tau_0 = f.s_tau_penalty.clone();
2842 let x_tau_1 = f.x_tau_design.clone();
2843 let s_tau_1 = Array2::<f64>::zeros((5, 5));
2844
2845 let hyper_dirs = vec![
2846 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2847 .expect("penalty-only direction"),
2848 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2849 .expect("design-moving direction"),
2850 ];
2851
2852 let state = f.state();
2854 let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2855 theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2856 let (_, _, h_full) =
2857 compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2858 .expect("joint cost+gradient+hessian");
2859 let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2860
2861 const TARGET_PHYSICAL_STEP: f64 = 1e-5;
2865 let x_tau_mats = [&x_tau_0, &x_tau_1];
2866 let s_tau_mats = [&s_tau_0, &s_tau_1];
2867 let steps: [f64; 2] = {
2868 let mut steps = [0.0; 2];
2869 for (j, step) in steps.iter_mut().enumerate() {
2870 let scale = x_tau_mats[j]
2871 .iter()
2872 .chain(s_tau_mats[j].iter())
2873 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2874 *step = if scale > 0.0 {
2875 TARGET_PHYSICAL_STEP / scale
2876 } else {
2877 TARGET_PHYSICAL_STEP
2878 };
2879 }
2880 steps
2881 };
2882
2883 let eval_cost = |a: f64, b: f64| -> f64 {
2885 let x_eval = &f.x
2886 + &x_tau_mats[0].mapv(|v| a * steps[0] * v)
2887 + &x_tau_mats[1].mapv(|v| b * steps[1] * v);
2888 let s_eval = &f.s0
2889 + &s_tau_mats[0].mapv(|v| a * steps[0] * v)
2890 + &s_tau_mats[1].mapv(|v| b * steps[1] * v);
2891 let st = build_logit_state(&f.y, &f.w, &x_eval, &s_eval, &f.cfg);
2892 st.compute_cost(&f.rho).expect("cost eval")
2893 };
2894
2895 let v_00 = eval_cost(0.0, 0.0);
2896 let v_p0 = eval_cost(1.0, 0.0);
2897 let v_m0 = eval_cost(-1.0, 0.0);
2898 let v_0p = eval_cost(0.0, 1.0);
2899 let v_0m = eval_cost(0.0, -1.0);
2900 let v_pp = eval_cost(1.0, 1.0);
2901 let v_pm = eval_cost(1.0, -1.0);
2902 let v_mp = eval_cost(-1.0, 1.0);
2903 let v_mm = eval_cost(-1.0, -1.0);
2904
2905 let h00_fd = (v_p0 - 2.0 * v_00 + v_m0) / (steps[0] * steps[0]);
2906 let h11_fd = (v_0p - 2.0 * v_00 + v_0m) / (steps[1] * steps[1]);
2907 let h01_fd = (v_pp - v_pm - v_mp + v_mm) / (4.0 * steps[0] * steps[1]);
2908
2909 let h_tt_fd = array![[h00_fd, h01_fd], [h01_fd, h11_fd]];
2910
2911 let num = (&h_tt_analytic - &h_tt_fd)
2912 .iter()
2913 .map(|v| v * v)
2914 .sum::<f64>()
2915 .sqrt();
2916 let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2917 let rel = num / den;
2918
2919 assert!(
2920 rel < 3e-3,
2921 "Binomial-logit n=30 rank-deficient Hessian vs cost-FD mismatch: rel={rel:.3e}, \
2922 analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2923 );
2924 }
2925}
2926
/// Selects the geometry backend used for REML evaluation.
///
/// NOTE(review): the variant meanings below are inferred from their names (a dense
/// spectral decomposition vs. an exact sparse SPD factorization). Confirm against the
/// code that matches on this enum.
#[derive(Clone, Copy, Debug)]
pub(crate) enum RemlGeometry {
    DenseSpectral,
    SparseExactSpd,
}
2932
/// A penalized-geometry backend that can report which `GeometryBackendKind` it is.
trait PenalizedGeometry {
    /// Returns the kind of backend this geometry is implemented with.
    fn backend_kind(&self) -> GeometryBackendKind;
}
2936
/// Storage for one hyperparameter derivative matrix.
///
/// The variants range from fully dense to operator-backed forms. Each payload type
/// implements `DerivativeStorageBackend`, which covers both design-derivative and
/// penalty-derivative uses.
#[derive(Clone)]
pub(crate) enum DerivativeMatrixStorage {
    /// Fully materialized dense matrix.
    Dense(Array2<f64>),
    /// Identically zero matrix; only its shape is stored.
    Zero(ZeroDerivativeMatrix),
    /// Dense local block plus the global column range it belongs to.
    Embedded(EmbeddedDerivativeMatrix),
    /// Operator-backed derivative of an implicit design, materialized only on demand.
    Implicit(ImplicitDerivativeOp),
    /// Operator-backed derivative along one latent-coordinate axis, materialized on demand.
    LatentCoord(LatentCoordDerivativeOp),
}
2945
2946trait DerivativeStorageBackend {
2958 fn resident_byte_count(&self) -> usize;
2959 fn design_nrows(&self) -> usize;
2960 fn design_ncols(&self) -> usize;
2961 fn penalty_dim(&self) -> usize;
2962 fn uses_implicit_storage(&self) -> bool;
2963 fn any_nonzero(&self) -> bool;
2964 fn materialize(&self) -> Array2<f64>;
2965 fn implicit_first_axis_info(
2966 &self,
2967 ) -> Option<(
2968 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
2969 usize,
2970 )>;
2971 fn implicit_axis_count_hint(&self) -> Option<usize>;
2972 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;
2973 fn design_transpose_mul_original(
2974 &self,
2975 v: &Array1<f64>,
2976 ) -> Result<Array1<f64>, EstimationError>;
2977 fn design_transformed(
2978 &self,
2979 qs: &Array2<f64>,
2980 free_basis_opt: Option<&Array2<f64>>,
2981 ) -> Result<Array2<f64>, EstimationError>;
2982 fn design_transformed_forward_mul(
2986 &self,
2987 qs: &Array2<f64>,
2988 free_basis_opt: Option<&Array2<f64>>,
2989 u: &Array1<f64>,
2990 ) -> Result<Array1<f64>, EstimationError> {
2991 Ok(self.design_transformed(qs, free_basis_opt)?.dot(u))
2992 }
2993 fn design_transformed_transpose_mul(
2996 &self,
2997 qs: &Array2<f64>,
2998 free_basis_opt: Option<&Array2<f64>>,
2999 v: &Array1<f64>,
3000 ) -> Result<Array1<f64>, EstimationError> {
3001 Ok(self.design_transformed(qs, free_basis_opt)?.t().dot(v))
3002 }
3003 fn penalty_transformed(
3004 &self,
3005 qs: &Array2<f64>,
3006 free_basis_opt: Option<&Array2<f64>>,
3007 ) -> Result<Array2<f64>, EstimationError>;
3008 fn penalty_scaled_add_to(
3009 &self,
3010 target: &mut Array2<f64>,
3011 amp: f64,
3012 ) -> Result<(), EstimationError>;
3013}
3014
// Expands to a `match` over every `DerivativeMatrixStorage` variant. Each arm binds
// that variant's payload to `$backend` and evaluates `$body`.
//
// Because `$body` is pasted into each arm separately, it is type-checked against each
// payload type (`Array2<f64>`, `ZeroDerivativeMatrix`, `EmbeddedDerivativeMatrix`, ...).
// This lets a single expression use methods those types share, e.g. the
// `DerivativeStorageBackend` trait methods.
macro_rules! storage_dispatch {
    ($scrutinee:expr, $backend:ident => $body:expr) => {
        match $scrutinee {
            DerivativeMatrixStorage::Dense($backend) => $body,
            DerivativeMatrixStorage::Zero($backend) => $body,
            DerivativeMatrixStorage::Embedded($backend) => $body,
            DerivativeMatrixStorage::Implicit($backend) => $body,
            DerivativeMatrixStorage::LatentCoord($backend) => $body,
        }
    };
}
3030
/// An identically zero derivative matrix.
///
/// Only the logical shape is stored, so it costs no storage proportional to its size.
#[derive(Clone)]
pub(crate) struct ZeroDerivativeMatrix {
    /// Logical number of rows.
    rows: usize,
    /// Logical number of columns.
    cols: usize,
}
3036
3037impl ZeroDerivativeMatrix {
3038 pub(crate) fn new(rows: usize, cols: usize) -> Self {
3039 Self { rows, cols }
3040 }
3041}
3042
/// Which derivative of an implicit design an `ImplicitDerivativeOp` represents.
///
/// The level selects the operator entry points used: `*_first`, `*_second_diag` or
/// `*_second_cross`.
#[derive(Clone, Copy, Debug)]
pub enum ImplicitDerivLevel {
    /// First derivative along the given axis.
    First(usize),
    /// Second derivative taken twice along the same axis (diagonal term).
    SecondDiag(usize),
    /// Mixed second derivative along axes `(d, e)`.
    SecondCross(usize, usize),
}
3053
/// Derivative block backed by an `ImplicitDesignPsiDerivative` operator.
///
/// Matrix-vector products go through the operator. A dense
/// `nrows x total_dim` copy is built only when `materialize_dense` is called, and it is
/// built once and cached.
#[derive(Clone)]
pub(crate) struct ImplicitDerivativeOp {
    /// Shared operator that evaluates the derivative blocks.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
    /// Which derivative (first / second-diagonal / second-cross) this op represents.
    pub(crate) level: ImplicitDerivLevel,
    /// Columns of the global coefficient vector occupied by this block.
    pub(crate) global_range: Range<usize>,
    /// Width of the global coefficient vector.
    pub(crate) total_dim: usize,
    /// Lazily filled dense embedding. Since it sits behind an `Arc`, clones of this op
    /// share the same cache.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3073
/// Derivative block along one flattened latent-coordinate axis, backed by a
/// `LatentCoordDesignDerivative` operator.
///
/// Products go through the operator. A dense copy is built only on demand and is
/// cached.
#[derive(Clone)]
pub(crate) struct LatentCoordDerivativeOp {
    /// Shared operator that evaluates the per-axis derivative blocks.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
    /// Flattened latent axis this op differentiates along.
    pub(crate) flat_axis: usize,
    /// Columns of the global coefficient vector occupied by this block.
    pub(crate) global_range: Range<usize>,
    /// Width of the global coefficient vector.
    pub(crate) total_dim: usize,
    /// Lazily filled dense embedding, shared across clones through the `Arc`.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3082
3083impl LatentCoordDerivativeOp {
3084 pub(crate) fn materialize_local(&self) -> Array2<f64> {
3085 self.operator.materialize_axis(self.flat_axis).expect(
3086 "radial scalar evaluation failed during latent-coordinate derivative materialization",
3087 )
3088 }
3089
3090 pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3091 self.cached_dense.get_or_compute(|| {
3092 let local = self.materialize_local();
3093 let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3094 out.slice_mut(s![.., self.global_range.clone()])
3095 .assign(&local);
3096 out
3097 })
3098 }
3099
3100 pub(crate) fn nrows(&self) -> usize {
3101 self.operator.n_data()
3102 }
3103
3104 pub(crate) fn ncols(&self) -> usize {
3105 self.total_dim
3106 }
3107
3108 pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3109 let local = self
3110 .operator
3111 .transpose_mul_axis(self.flat_axis, &v.view())
3112 .expect(
3113 "radial scalar evaluation failed during latent-coordinate derivative transpose_mul",
3114 );
3115 let mut out = Array1::<f64>::zeros(self.total_dim);
3116 out.slice_mut(s![self.global_range.clone()]).assign(&local);
3117 out
3118 }
3119
3120 pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3121 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3122 self.operator
3123 .forward_mul_axis(self.flat_axis, &u_local.view())
3124 .expect(
3125 "radial scalar evaluation failed during latent-coordinate derivative forward_mul",
3126 )
3127 }
3128}
3129
3130impl ImplicitDerivativeOp {
3131 pub(crate) fn materialize_local(&self) -> Array2<f64> {
3132 match self.level {
3133 ImplicitDerivLevel::First(axis) => self.operator.materialize_first(axis).expect(
3134 "radial scalar evaluation failed during implicit derivative materialization",
3135 ),
3136 ImplicitDerivLevel::SecondDiag(axis) => {
3137 self.operator.materialize_second_diag(axis).expect(
3138 "radial scalar evaluation failed during implicit derivative materialization",
3139 )
3140 }
3141 ImplicitDerivLevel::SecondCross(d, e) => {
3142 self.operator.materialize_second_cross(d, e).expect(
3143 "radial scalar evaluation failed during implicit derivative materialization",
3144 )
3145 }
3146 }
3147 }
3148
3149 pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3150 self.cached_dense.get_or_compute(|| {
3151 let local = self.materialize_local();
3152 let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3153 out.slice_mut(s![.., self.global_range.clone()])
3154 .assign(&local);
3155 out
3156 })
3157 }
3158
3159 pub(crate) fn nrows(&self) -> usize {
3160 self.operator.n_data()
3161 }
3162
3163 pub(crate) fn ncols(&self) -> usize {
3164 self.total_dim
3165 }
3166
3167 pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3168 let local = match self.level {
3169 ImplicitDerivLevel::First(axis) => self
3170 .operator
3171 .transpose_mul(axis, &v.view())
3172 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3173 ImplicitDerivLevel::SecondDiag(axis) => self
3174 .operator
3175 .transpose_mul_second_diag(axis, &v.view())
3176 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3177 ImplicitDerivLevel::SecondCross(d, e) => self
3178 .operator
3179 .transpose_mul_second_cross(d, e, &v.view())
3180 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3181 };
3182 let mut out = Array1::<f64>::zeros(self.total_dim);
3183 out.slice_mut(s![self.global_range.clone()]).assign(&local);
3184 out
3185 }
3186
3187 pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3188 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3189 match self.level {
3190 ImplicitDerivLevel::First(axis) => self
3191 .operator
3192 .forward_mul(axis, &u_local.view())
3193 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3194 ImplicitDerivLevel::SecondDiag(axis) => self
3195 .operator
3196 .forward_mul_second_diag(axis, &u_local.view())
3197 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3198 ImplicitDerivLevel::SecondCross(d, e) => self
3199 .operator
3200 .forward_mul_second_cross(d, e, &u_local.view())
3201 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3202 }
3203 }
3204}
3205
/// A dense local derivative block together with its position in a wider global
/// coefficient space.
///
/// NOTE(review): by analogy with `ImplicitDerivativeOp::materialize_dense`, the block
/// presumably occupies columns `global_range` of a `total_dim`-wide matrix. Confirm in
/// this type's `DerivativeStorageBackend` impl.
#[derive(Clone)]
pub(crate) struct EmbeddedDerivativeMatrix {
    /// The local (non-embedded) block.
    pub(crate) local: Array2<f64>,
    /// Global index range the block corresponds to.
    pub(crate) global_range: Range<usize>,
    /// Width of the global coefficient space.
    pub(crate) total_dim: usize,
}
3212
3213impl EmbeddedDerivativeMatrix {
3214 pub(crate) fn new(local: Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
3215 Self {
3216 local,
3217 global_range,
3218 total_dim,
3219 }
3220 }
3221}
3222
3223impl DerivativeStorageBackend for Array2<f64> {
3224 fn resident_byte_count(&self) -> usize {
3225 self.len().saturating_mul(std::mem::size_of::<f64>())
3226 }
3227 fn design_nrows(&self) -> usize {
3228 Array2::nrows(self)
3229 }
3230 fn design_ncols(&self) -> usize {
3231 Array2::ncols(self)
3232 }
3233 fn penalty_dim(&self) -> usize {
3234 Array2::nrows(self)
3235 }
3236 fn uses_implicit_storage(&self) -> bool {
3237 false
3238 }
3239 fn any_nonzero(&self) -> bool {
3240 self.iter().any(|v| *v != 0.0)
3241 }
3242 fn materialize(&self) -> Array2<f64> {
3243 self.clone()
3244 }
3245 fn implicit_first_axis_info(
3246 &self,
3247 ) -> Option<(
3248 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3249 usize,
3250 )> {
3251 None
3252 }
3253 fn implicit_axis_count_hint(&self) -> Option<usize> {
3254 None
3255 }
3256
3257 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3258 if Array2::ncols(self) != u.len() {
3259 crate::bail_invalid_estim!(
3260 "dense hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3261 Array2::nrows(self),
3262 Array2::ncols(self),
3263 u.len()
3264 );
3265 }
3266 Ok(self.dot(u))
3267 }
3268
3269 fn design_transpose_mul_original(
3270 &self,
3271 v: &Array1<f64>,
3272 ) -> Result<Array1<f64>, EstimationError> {
3273 if Array2::nrows(self) != v.len() {
3274 crate::bail_invalid_estim!(
3275 "dense hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3276 Array2::nrows(self),
3277 Array2::ncols(self),
3278 v.len()
3279 );
3280 }
3281 Ok(self.t().dot(v))
3282 }
3283
3284 fn design_transformed(
3285 &self,
3286 qs: &Array2<f64>,
3287 free_basis_opt: Option<&Array2<f64>>,
3288 ) -> Result<Array2<f64>, EstimationError> {
3289 Ok(gam_linalg::matrix::DenseRightProductView::new(self)
3290 .with_factor(qs)
3291 .with_optional_factor(free_basis_opt)
3292 .materialize())
3293 }
3294
3295 fn penalty_transformed(
3296 &self,
3297 qs: &Array2<f64>,
3298 free_basis_opt: Option<&Array2<f64>>,
3299 ) -> Result<Array2<f64>, EstimationError> {
3300 let mut transformed = qs.t().dot(self).dot(qs);
3301 if let Some(z) = free_basis_opt {
3302 transformed = z.t().dot(&transformed).dot(z);
3303 }
3304 Ok(transformed)
3305 }
3306
3307 fn penalty_scaled_add_to(
3308 &self,
3309 target: &mut Array2<f64>,
3310 amp: f64,
3311 ) -> Result<(), EstimationError> {
3312 if target.raw_dim() != self.raw_dim() {
3313 crate::bail_invalid_estim!(
3314 "dense hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3315 target.nrows(),
3316 target.ncols(),
3317 Array2::nrows(self),
3318 Array2::ncols(self)
3319 );
3320 }
3321 target.scaled_add(amp, self);
3322 Ok(())
3323 }
3324}
3325
3326impl DerivativeStorageBackend for ZeroDerivativeMatrix {
3327 fn resident_byte_count(&self) -> usize {
3328 0
3329 }
3330 fn design_nrows(&self) -> usize {
3331 self.rows
3332 }
3333 fn design_ncols(&self) -> usize {
3334 self.cols
3335 }
3336 fn penalty_dim(&self) -> usize {
3337 self.cols
3338 }
3339 fn uses_implicit_storage(&self) -> bool {
3340 false
3341 }
3342 fn any_nonzero(&self) -> bool {
3343 false
3344 }
3345 fn materialize(&self) -> Array2<f64> {
3346 Array2::<f64>::zeros((self.rows, self.cols))
3347 }
3348 fn implicit_first_axis_info(
3349 &self,
3350 ) -> Option<(
3351 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3352 usize,
3353 )> {
3354 None
3355 }
3356 fn implicit_axis_count_hint(&self) -> Option<usize> {
3357 None
3358 }
3359
3360 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3361 if self.cols != u.len() {
3362 crate::bail_invalid_estim!(
3363 "zero hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3364 self.rows,
3365 self.cols,
3366 u.len()
3367 );
3368 }
3369 Ok(Array1::<f64>::zeros(self.rows))
3370 }
3371
3372 fn design_transpose_mul_original(
3373 &self,
3374 v: &Array1<f64>,
3375 ) -> Result<Array1<f64>, EstimationError> {
3376 if self.rows != v.len() {
3377 crate::bail_invalid_estim!(
3378 "zero hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3379 self.rows,
3380 self.cols,
3381 v.len()
3382 );
3383 }
3384 Ok(Array1::<f64>::zeros(self.cols))
3385 }
3386
3387 fn design_transformed(
3388 &self,
3389 qs: &Array2<f64>,
3390 free_basis_opt: Option<&Array2<f64>>,
3391 ) -> Result<Array2<f64>, EstimationError> {
3392 if self.cols != qs.nrows() {
3393 crate::bail_invalid_estim!(
3394 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3395 self.cols,
3396 qs.nrows()
3397 );
3398 }
3399 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3400 Ok(Array2::<f64>::zeros((self.rows, cols)))
3401 }
3402
3403 fn design_transformed_forward_mul(
3404 &self,
3405 qs: &Array2<f64>,
3406 free_basis_opt: Option<&Array2<f64>>,
3407 u: &Array1<f64>,
3408 ) -> Result<Array1<f64>, EstimationError> {
3409 if self.cols != qs.nrows() {
3410 crate::bail_invalid_estim!(
3411 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3412 self.cols,
3413 qs.nrows()
3414 );
3415 }
3416 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3417 if u.len() != cols {
3418 crate::bail_invalid_estim!(
3419 "zero design derivative transformed forward width mismatch: expected {}, vector={}",
3420 cols,
3421 u.len()
3422 );
3423 }
3424 Ok(Array1::<f64>::zeros(self.rows))
3425 }
3426
3427 fn design_transformed_transpose_mul(
3428 &self,
3429 qs: &Array2<f64>,
3430 free_basis_opt: Option<&Array2<f64>>,
3431 v: &Array1<f64>,
3432 ) -> Result<Array1<f64>, EstimationError> {
3433 if self.rows != v.len() {
3434 crate::bail_invalid_estim!(
3435 "zero design derivative transpose height mismatch: matrix rows={}, vector={}",
3436 self.rows,
3437 v.len()
3438 );
3439 }
3440 if self.cols != qs.nrows() {
3441 crate::bail_invalid_estim!(
3442 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3443 self.cols,
3444 qs.nrows()
3445 );
3446 }
3447 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3448 Ok(Array1::<f64>::zeros(cols))
3449 }
3450
3451 fn penalty_transformed(
3452 &self,
3453 qs: &Array2<f64>,
3454 free_basis_opt: Option<&Array2<f64>>,
3455 ) -> Result<Array2<f64>, EstimationError> {
3456 if self.cols != qs.nrows() {
3457 crate::bail_invalid_estim!(
3458 "zero penalty derivative width mismatch: total_dim={}, qs rows={}",
3459 self.cols,
3460 qs.nrows()
3461 );
3462 }
3463 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3464 Ok(Array2::<f64>::zeros((cols, cols)))
3465 }
3466
3467 fn penalty_scaled_add_to(
3468 &self,
3469 target: &mut Array2<f64>,
3470 amp: f64,
3471 ) -> Result<(), EstimationError> {
3472 if !amp.is_finite() {
3476 crate::bail_invalid_estim!(
3477 "zero hyper penalty derivative received non-finite amp={amp}"
3478 );
3479 }
3480 if target.nrows() != self.cols || target.ncols() != self.cols {
3481 crate::bail_invalid_estim!(
3482 "zero hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3483 target.nrows(),
3484 target.ncols(),
3485 self.cols,
3486 self.cols
3487 );
3488 }
3489 Ok(())
3490 }
3491}
3492
3493impl DerivativeStorageBackend for EmbeddedDerivativeMatrix {
3494 fn resident_byte_count(&self) -> usize {
3495 self.local.len().saturating_mul(std::mem::size_of::<f64>())
3496 }
3497 fn design_nrows(&self) -> usize {
3498 self.local.nrows()
3499 }
3500 fn design_ncols(&self) -> usize {
3501 self.total_dim
3502 }
3503 fn penalty_dim(&self) -> usize {
3504 self.total_dim
3505 }
3506 fn uses_implicit_storage(&self) -> bool {
3507 false
3508 }
3509 fn any_nonzero(&self) -> bool {
3510 self.local.iter().any(|v| *v != 0.0)
3511 }
3512 fn materialize(&self) -> Array2<f64> {
3513 let mut dense = Array2::<f64>::zeros((self.local.nrows(), self.total_dim));
3514 dense
3515 .slice_mut(s![.., self.global_range.clone()])
3516 .assign(&self.local);
3517 dense
3518 }
3519 fn implicit_first_axis_info(
3520 &self,
3521 ) -> Option<(
3522 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3523 usize,
3524 )> {
3525 None
3526 }
3527 fn implicit_axis_count_hint(&self) -> Option<usize> {
3528 None
3529 }
3530
3531 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3532 if self.total_dim != u.len() {
3533 crate::bail_invalid_estim!(
3534 "embedded hyper design derivative forward_mul_original width mismatch: total_dim={}, vector={}",
3535 self.total_dim,
3536 u.len()
3537 );
3538 }
3539 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3540 Ok(self.local.dot(&u_local))
3541 }
3542
3543 fn design_transpose_mul_original(
3544 &self,
3545 v: &Array1<f64>,
3546 ) -> Result<Array1<f64>, EstimationError> {
3547 if self.local.nrows() != v.len() {
3548 crate::bail_invalid_estim!(
3549 "embedded hyper design derivative transpose_mul_original height mismatch: local_rows={}, vector={}",
3550 self.local.nrows(),
3551 v.len()
3552 );
3553 }
3554 let mut out = Array1::<f64>::zeros(self.total_dim);
3555 let pulled = self.local.t().dot(v);
3556 out.slice_mut(s![self.global_range.clone()]).assign(&pulled);
3557 Ok(out)
3558 }
3559
3560 fn design_transformed(
3561 &self,
3562 qs: &Array2<f64>,
3563 free_basis_opt: Option<&Array2<f64>>,
3564 ) -> Result<Array2<f64>, EstimationError> {
3565 if self.total_dim != qs.nrows() {
3566 crate::bail_invalid_estim!(
3567 "embedded design derivative width mismatch: total_cols={}, qs rows={}",
3568 self.total_dim,
3569 qs.nrows()
3570 );
3571 }
3572 let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3573 let mut transformed = self.local.dot(&qs_local);
3574 if let Some(z) = free_basis_opt {
3575 transformed = transformed.dot(z);
3576 }
3577 Ok(transformed)
3578 }
3579
3580 fn penalty_transformed(
3581 &self,
3582 qs: &Array2<f64>,
3583 free_basis_opt: Option<&Array2<f64>>,
3584 ) -> Result<Array2<f64>, EstimationError> {
3585 if self.total_dim != qs.nrows() {
3586 crate::bail_invalid_estim!(
3587 "embedded penalty derivative width mismatch: total_dim={}, qs rows={}",
3588 self.total_dim,
3589 qs.nrows()
3590 );
3591 }
3592 let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3593 let mut transformed = qs_local.t().dot(&self.local).dot(&qs_local);
3594 if let Some(z) = free_basis_opt {
3595 transformed = z.t().dot(&transformed).dot(z);
3596 }
3597 Ok(transformed)
3598 }
3599
3600 fn penalty_scaled_add_to(
3601 &self,
3602 target: &mut Array2<f64>,
3603 amp: f64,
3604 ) -> Result<(), EstimationError> {
3605 if target.nrows() != self.total_dim || target.ncols() != self.total_dim {
3606 crate::bail_invalid_estim!(
3607 "embedded hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3608 target.nrows(),
3609 target.ncols(),
3610 self.total_dim,
3611 self.total_dim
3612 );
3613 }
3614 target
3615 .slice_mut(s![self.global_range.clone(), self.global_range.clone()])
3616 .scaled_add(amp, &self.local);
3617 Ok(())
3618 }
3619}
3620
3621impl DerivativeStorageBackend for ImplicitDerivativeOp {
3622 fn resident_byte_count(&self) -> usize {
3623 0
3624 }
3625 fn design_nrows(&self) -> usize {
3626 self.nrows()
3627 }
3628 fn design_ncols(&self) -> usize {
3629 self.ncols()
3630 }
3631 fn penalty_dim(&self) -> usize {
3632 self.nrows()
3633 }
3634 fn uses_implicit_storage(&self) -> bool {
3635 true
3636 }
3637 fn any_nonzero(&self) -> bool {
3638 true
3639 }
3640 fn materialize(&self) -> Array2<f64> {
3641 self.materialize_dense().clone()
3642 }
3643 fn implicit_first_axis_info(
3644 &self,
3645 ) -> Option<(
3646 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3647 usize,
3648 )> {
3649 match self.level {
3650 ImplicitDerivLevel::First(axis) => Some((self.operator.clone(), axis)),
3651 _ => None,
3652 }
3653 }
3654 fn implicit_axis_count_hint(&self) -> Option<usize> {
3655 Some(self.operator.n_axes())
3656 }
3657
3658 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3659 if self.ncols() != u.len() {
3660 crate::bail_invalid_estim!(
3661 "implicit hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3662 self.ncols(),
3663 u.len()
3664 );
3665 }
3666 Ok(self.forward_mul(u))
3667 }
3668
3669 fn design_transpose_mul_original(
3670 &self,
3671 v: &Array1<f64>,
3672 ) -> Result<Array1<f64>, EstimationError> {
3673 if self.nrows() != v.len() {
3674 crate::bail_invalid_estim!(
3675 "implicit hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3676 self.nrows(),
3677 v.len()
3678 );
3679 }
3680 Ok(self.transpose_mul(v))
3681 }
3682
3683 fn design_transformed(
3684 &self,
3685 qs: &Array2<f64>,
3686 free_basis_opt: Option<&Array2<f64>>,
3687 ) -> Result<Array2<f64>, EstimationError> {
3688 let dense = self.materialize_dense();
3689 Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3690 .with_factor(qs)
3691 .with_optional_factor(free_basis_opt)
3692 .materialize())
3693 }
3694
3695 fn design_transformed_forward_mul(
3696 &self,
3697 qs: &Array2<f64>,
3698 free_basis_opt: Option<&Array2<f64>>,
3699 u: &Array1<f64>,
3700 ) -> Result<Array1<f64>, EstimationError> {
3701 let mut right = if let Some(z) = free_basis_opt {
3702 z.dot(u)
3703 } else {
3704 u.clone()
3705 };
3706 right = qs.dot(&right);
3707 Ok(self.forward_mul(&right))
3708 }
3709
3710 fn design_transformed_transpose_mul(
3711 &self,
3712 qs: &Array2<f64>,
3713 free_basis_opt: Option<&Array2<f64>>,
3714 v: &Array1<f64>,
3715 ) -> Result<Array1<f64>, EstimationError> {
3716 let mut pulled = qs.t().dot(&self.transpose_mul(v));
3717 if let Some(z) = free_basis_opt {
3718 pulled = z.t().dot(&pulled);
3719 }
3720 Ok(pulled)
3721 }
3722
3723 fn penalty_transformed(
3724 &self,
3725 qs: &Array2<f64>,
3726 free_basis_opt: Option<&Array2<f64>>,
3727 ) -> Result<Array2<f64>, EstimationError> {
3728 let dense = self.materialize_dense();
3729 let mut transformed = qs.t().dot(dense).dot(qs);
3730 if let Some(z) = free_basis_opt {
3731 transformed = z.t().dot(&transformed).dot(z);
3732 }
3733 Ok(transformed)
3734 }
3735
3736 fn penalty_scaled_add_to(
3737 &self,
3738 target: &mut Array2<f64>,
3739 amp: f64,
3740 ) -> Result<(), EstimationError> {
3741 let dense = self.materialize_dense();
3742 if target.raw_dim() != dense.raw_dim() {
3743 crate::bail_invalid_estim!(
3744 "implicit hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3745 target.nrows(),
3746 target.ncols(),
3747 dense.nrows(),
3748 dense.ncols()
3749 );
3750 }
3751 target.scaled_add(amp, dense);
3752 Ok(())
3753 }
3754}
3755
3756impl DerivativeStorageBackend for LatentCoordDerivativeOp {
3757 fn resident_byte_count(&self) -> usize {
3758 0
3759 }
3760 fn design_nrows(&self) -> usize {
3761 self.nrows()
3762 }
3763 fn design_ncols(&self) -> usize {
3764 self.ncols()
3765 }
3766 fn penalty_dim(&self) -> usize {
3767 self.nrows()
3768 }
3769 fn uses_implicit_storage(&self) -> bool {
3770 true
3771 }
3772 fn any_nonzero(&self) -> bool {
3773 true
3774 }
3775 fn materialize(&self) -> Array2<f64> {
3776 self.materialize_dense().clone()
3777 }
3778 fn implicit_first_axis_info(
3779 &self,
3780 ) -> Option<(
3781 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3782 usize,
3783 )> {
3784 None
3785 }
3786 fn implicit_axis_count_hint(&self) -> Option<usize> {
3787 Some(self.operator.n_axes())
3788 }
3789
3790 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3791 if self.ncols() != u.len() {
3792 crate::bail_invalid_estim!(
3793 "latent-coordinate hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3794 self.ncols(),
3795 u.len()
3796 );
3797 }
3798 Ok(self.forward_mul(u))
3799 }
3800
3801 fn design_transpose_mul_original(
3802 &self,
3803 v: &Array1<f64>,
3804 ) -> Result<Array1<f64>, EstimationError> {
3805 if self.nrows() != v.len() {
3806 crate::bail_invalid_estim!(
3807 "latent-coordinate hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3808 self.nrows(),
3809 v.len()
3810 );
3811 }
3812 Ok(self.transpose_mul(v))
3813 }
3814
3815 fn design_transformed(
3816 &self,
3817 qs: &Array2<f64>,
3818 free_basis_opt: Option<&Array2<f64>>,
3819 ) -> Result<Array2<f64>, EstimationError> {
3820 let dense = self.materialize_dense();
3821 Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3822 .with_factor(qs)
3823 .with_optional_factor(free_basis_opt)
3824 .materialize())
3825 }
3826
3827 fn design_transformed_forward_mul(
3828 &self,
3829 qs: &Array2<f64>,
3830 free_basis_opt: Option<&Array2<f64>>,
3831 u: &Array1<f64>,
3832 ) -> Result<Array1<f64>, EstimationError> {
3833 let mut right = if let Some(z) = free_basis_opt {
3834 z.dot(u)
3835 } else {
3836 u.clone()
3837 };
3838 right = qs.dot(&right);
3839 Ok(self.forward_mul(&right))
3840 }
3841
3842 fn design_transformed_transpose_mul(
3843 &self,
3844 qs: &Array2<f64>,
3845 free_basis_opt: Option<&Array2<f64>>,
3846 v: &Array1<f64>,
3847 ) -> Result<Array1<f64>, EstimationError> {
3848 let mut pulled = qs.t().dot(&self.transpose_mul(v));
3849 if let Some(z) = free_basis_opt {
3850 pulled = z.t().dot(&pulled);
3851 }
3852 Ok(pulled)
3853 }
3854
3855 fn penalty_transformed(
3856 &self,
3857 qs: &Array2<f64>,
3858 free_basis_opt: Option<&Array2<f64>>,
3859 ) -> Result<Array2<f64>, EstimationError> {
3860 let dense = self.materialize_dense();
3861 let mut transformed = qs.t().dot(dense).dot(qs);
3862 if let Some(z) = free_basis_opt {
3863 transformed = z.t().dot(&transformed).dot(z);
3864 }
3865 Ok(transformed)
3866 }
3867
3868 fn penalty_scaled_add_to(
3869 &self,
3870 target: &mut Array2<f64>,
3871 amp: f64,
3872 ) -> Result<(), EstimationError> {
3873 let dense = self.materialize_dense();
3874 if target.raw_dim() != dense.raw_dim() {
3875 crate::bail_invalid_estim!(
3876 "latent-coordinate hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3877 target.nrows(),
3878 target.ncols(),
3879 dense.nrows(),
3880 dense.ncols()
3881 );
3882 }
3883 target.scaled_add(amp, dense);
3884 Ok(())
3885 }
3886}
3887
/// Derivative of the design matrix with respect to one hyperparameter direction.
///
/// The entries live in one of several storage backends (dense, zero, embedded block,
/// implicit operator or latent-coordinate operator), selected by the constructors below.
#[derive(Clone)]
pub struct HyperDesignDerivative {
    /// Backend holding (or generating) the derivative entries.
    pub(crate) storage: DerivativeMatrixStorage,
}
3892
3893impl HyperDesignDerivative {
3894 pub fn zero(nrows: usize, ncols: usize) -> Self {
3895 Self {
3896 storage: DerivativeMatrixStorage::Zero(ZeroDerivativeMatrix::new(nrows, ncols)),
3897 }
3898 }
3899
3900 pub fn from_embedded(
3901 local: Array2<f64>,
3902 global_range: Range<usize>,
3903 total_cols: usize,
3904 ) -> Self {
3905 Self {
3906 storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3907 local,
3908 global_range,
3909 total_cols,
3910 )),
3911 }
3912 }
3913
3914 pub fn from_implicit(
3915 operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3916 level: ImplicitDerivLevel,
3917 global_range: Range<usize>,
3918 total_cols: usize,
3919 ) -> Self {
3920 Self {
3921 storage: DerivativeMatrixStorage::Implicit(ImplicitDerivativeOp {
3922 operator,
3923 level,
3924 global_range,
3925 total_dim: total_cols,
3926 cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3927 }),
3928 }
3929 }
3930
3931 pub fn from_latent_coord(
3932 operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
3933 flat_axis: usize,
3934 global_range: Range<usize>,
3935 total_cols: usize,
3936 ) -> Self {
3937 Self {
3938 storage: DerivativeMatrixStorage::LatentCoord(LatentCoordDerivativeOp {
3939 operator,
3940 flat_axis,
3941 global_range,
3942 total_dim: total_cols,
3943 cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3944 }),
3945 }
3946 }
3947
3948 pub(crate) fn resident_byte_count(&self) -> usize {
3949 storage_dispatch!(&self.storage, b => b.resident_byte_count())
3950 }
3951
3952 pub(crate) fn nrows(&self) -> usize {
3953 storage_dispatch!(&self.storage, b => b.design_nrows())
3954 }
3955
3956 pub(crate) fn ncols(&self) -> usize {
3957 storage_dispatch!(&self.storage, b => b.design_ncols())
3958 }
3959
3960 pub(crate) fn uses_implicit_storage(&self) -> bool {
3961 storage_dispatch!(&self.storage, b => b.uses_implicit_storage())
3962 }
3963
3964 pub(crate) fn materialize(&self) -> Array2<f64> {
3965 storage_dispatch!(&self.storage, b => b.materialize())
3966 }
3967
3968 pub(crate) fn any_nonzero(&self) -> bool {
3969 storage_dispatch!(&self.storage, b => b.any_nonzero())
3970 }
3971
3972 pub(crate) fn forward_mul_original(
3973 &self,
3974 u: &Array1<f64>,
3975 ) -> Result<Array1<f64>, EstimationError> {
3976 storage_dispatch!(&self.storage, b => b.design_forward_mul_original(u))
3977 }
3978
3979 pub(crate) fn transpose_mul_original(
3980 &self,
3981 v: &Array1<f64>,
3982 ) -> Result<Array1<f64>, EstimationError> {
3983 storage_dispatch!(&self.storage, b => b.design_transpose_mul_original(v))
3984 }
3985
3986 pub(crate) fn transformed(
3987 &self,
3988 qs: &Array2<f64>,
3989 free_basis_opt: Option<&Array2<f64>>,
3990 ) -> Result<Array2<f64>, EstimationError> {
3991 storage_dispatch!(&self.storage, b => b.design_transformed(qs, free_basis_opt))
3992 }
3993
3994 pub(crate) fn transformed_forward_mul(
3995 &self,
3996 qs: &Array2<f64>,
3997 free_basis_opt: Option<&Array2<f64>>,
3998 u: &Array1<f64>,
3999 ) -> Result<Array1<f64>, EstimationError> {
4000 storage_dispatch!(&self.storage, b => b.design_transformed_forward_mul(qs, free_basis_opt, u))
4001 }
4002
4003 pub(crate) fn transformed_transpose_mul(
4004 &self,
4005 qs: &Array2<f64>,
4006 free_basis_opt: Option<&Array2<f64>>,
4007 v: &Array1<f64>,
4008 ) -> Result<Array1<f64>, EstimationError> {
4009 storage_dispatch!(&self.storage, b => b.design_transformed_transpose_mul(qs, free_basis_opt, v))
4010 }
4011
4012 pub(crate) fn implicit_first_axis_info(
4017 &self,
4018 ) -> Option<(
4019 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4020 usize,
4021 )> {
4022 storage_dispatch!(&self.storage, b => b.implicit_first_axis_info())
4023 }
4024
4025 pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4026 storage_dispatch!(&self.storage, b => b.implicit_axis_count_hint())
4027 }
4028}
4029
4030impl From<Array2<f64>> for HyperDesignDerivative {
4031 fn from(value: Array2<f64>) -> Self {
4032 Self {
4033 storage: DerivativeMatrixStorage::Dense(value),
4034 }
4035 }
4036}
4037
/// Derivative of one penalty matrix with respect to a hyperparameter direction.
///
/// Treated as square with side `nrows()`.
#[derive(Clone)]
pub struct HyperPenaltyDerivative {
    /// Backend holding (or generating) the derivative entries.
    pub(crate) storage: DerivativeMatrixStorage,
}
4042
4043impl HyperPenaltyDerivative {
4044 pub fn from_embedded(
4045 local: Array2<f64>,
4046 global_range: Range<usize>,
4047 total_dim: usize,
4048 ) -> Self {
4049 Self {
4050 storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
4051 local,
4052 global_range,
4053 total_dim,
4054 )),
4055 }
4056 }
4057
4058 pub(crate) fn resident_byte_count(&self) -> usize {
4059 storage_dispatch!(&self.storage, b => b.resident_byte_count())
4060 }
4061
4062 pub(crate) fn nrows(&self) -> usize {
4063 storage_dispatch!(&self.storage, b => b.penalty_dim())
4064 }
4065
4066 pub(crate) fn ncols(&self) -> usize {
4067 self.nrows()
4068 }
4069
4070 pub(crate) fn scaled_materialize(&self, amp: f64) -> Array2<f64> {
4071 let mut out = Array2::<f64>::zeros((self.nrows(), self.ncols()));
4072 self.scaled_add_to(&mut out, amp)
4073 .expect("scaled materialize uses matching target shape");
4074 out
4075 }
4076
4077 pub(crate) fn transformed(
4078 &self,
4079 qs: &Array2<f64>,
4080 free_basis_opt: Option<&Array2<f64>>,
4081 ) -> Result<Array2<f64>, EstimationError> {
4082 storage_dispatch!(&self.storage, b => b.penalty_transformed(qs, free_basis_opt))
4083 }
4084
4085 pub(crate) fn scaled_add_to(
4086 &self,
4087 target: &mut Array2<f64>,
4088 amp: f64,
4089 ) -> Result<(), EstimationError> {
4090 storage_dispatch!(&self.storage, b => b.penalty_scaled_add_to(target, amp))
4091 }
4092}
4093
4094impl From<Array2<f64>> for HyperPenaltyDerivative {
4095 fn from(value: Array2<f64>) -> Self {
4096 Self {
4097 storage: DerivativeMatrixStorage::Dense(value),
4098 }
4099 }
4100}
4101
/// One penalty's contribution to a hyperparameter derivative.
#[derive(Clone)]
pub struct PenaltyDerivativeComponent {
    /// Index of the penalty this component belongs to.
    ///
    /// `penalty_total_at` also uses it to select the matching entry of `rho`.
    pub penalty_index: usize,
    /// Derivative matrix of that penalty.
    pub matrix: HyperPenaltyDerivative,
}
4107
/// Everything needed to differentiate the fit along one hyperparameter direction `tau`.
///
/// Covers the design-matrix derivative, the per-penalty derivatives and, optionally,
/// second-order cross terms indexed by a partner direction `j`.
#[derive(Clone)]
pub struct DirectionalHyperParam {
    /// First derivative of the design matrix, in original coefficient coordinates.
    pub(crate) x_tau_original: HyperDesignDerivative,
    /// First derivatives of the penalties, at most one component per penalty index.
    pub(crate) penalty_first_components: Vec<PenaltyDerivativeComponent>,
    /// Optional second design derivatives, indexed by partner direction `j`.
    ///
    /// `None` entries mean no stored derivative for that partner.
    pub(crate) x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
    /// Optional stored second penalty derivatives, indexed by partner direction `j`.
    pub(crate) penaltysecond_components: Option<Vec<Option<Vec<PenaltyDerivativeComponent>>>>,
    /// Lazy fallback consulted for partner `j` when no stored second-penalty row exists.
    pub(crate) penaltysecond_component_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
                + Send
                + Sync
                + 'static,
        >,
    >,
    /// Partner indices for which a second-penalty pair is reported even without a stored row.
    ///
    /// Presumably these are served by the provider — TODO confirm.
    pub(crate) penaltysecond_partner_indices: Option<std::sync::Arc<[usize]>>,
    /// Set by `new_compact` when `x_tau_original` reports no non-zero entries.
    ///
    /// Cleared by `not_penalty_like`.
    pub(crate) is_penalty_like: bool,
}
4135
4136impl DirectionalHyperParam {
4137 pub(crate) fn resident_byte_count(&self) -> usize {
4138 let mut bytes = self.x_tau_original.resident_byte_count();
4139 for component in &self.penalty_first_components {
4140 bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4141 }
4142 if let Some(entries) = self.x_tau_tau_original.as_ref() {
4143 for entry in entries.iter().flatten() {
4144 bytes = bytes.saturating_add(entry.resident_byte_count());
4145 }
4146 }
4147 if let Some(rows) = self.penaltysecond_components.as_ref() {
4148 for components in rows.iter().flatten() {
4149 for component in components {
4150 bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4151 }
4152 }
4153 }
4154 bytes
4155 }
4156
4157 pub(crate) fn canonicalize_penalty_components(
4158 components: Vec<(usize, HyperPenaltyDerivative)>,
4159 ) -> Result<Vec<PenaltyDerivativeComponent>, EstimationError> {
4160 let mut out: Vec<PenaltyDerivativeComponent> = Vec::with_capacity(components.len());
4161 for (penalty_index, matrix) in components {
4162 if out.iter().any(|c| c.penalty_index == penalty_index) {
4163 crate::bail_invalid_estim!(
4164 "duplicate penalty derivative component for penalty {}",
4165 penalty_index
4166 );
4167 }
4168 out.push(PenaltyDerivativeComponent {
4169 penalty_index,
4170 matrix,
4171 });
4172 }
4173 Ok(out)
4174 }
4175
4176 pub fn new_compact(
4177 x_tau_original: HyperDesignDerivative,
4178 penalty_first_components: Vec<(usize, HyperPenaltyDerivative)>,
4179 x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
4180 penaltysecond_components: Option<Vec<Option<Vec<(usize, HyperPenaltyDerivative)>>>>,
4181 ) -> Result<Self, EstimationError> {
4182 let is_penalty_like = !x_tau_original.any_nonzero();
4183 let penalty_first_components =
4184 Self::canonicalize_penalty_components(penalty_first_components)?;
4185 let penaltysecond_components = match penaltysecond_components {
4186 Some(rows) => {
4187 let mut out = Vec::with_capacity(rows.len());
4188 for row in rows {
4189 out.push(match row {
4190 Some(components) => {
4191 Some(Self::canonicalize_penalty_components(components)?)
4192 }
4193 None => None,
4194 });
4195 }
4196 Some(out)
4197 }
4198 None => None,
4199 };
4200 Ok(Self {
4201 x_tau_original,
4202 penalty_first_components,
4203 x_tau_tau_original,
4204 penaltysecond_components,
4205 penaltysecond_component_provider: None,
4206 penaltysecond_partner_indices: None,
4207 is_penalty_like,
4208 })
4209 }
4210
4211 pub fn not_penalty_like(mut self) -> Self {
4214 self.is_penalty_like = false;
4215 self
4216 }
4217
4218 pub fn with_penaltysecond_component_provider(
4219 mut self,
4220 provider: std::sync::Arc<
4221 dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
4222 + Send
4223 + Sync
4224 + 'static,
4225 >,
4226 ) -> Self {
4227 self.penaltysecond_component_provider = Some(provider);
4228 self
4229 }
4230
4231 pub fn with_penaltysecond_partner_indices(mut self, partners: Vec<usize>) -> Self {
4232 self.penaltysecond_partner_indices = Some(std::sync::Arc::from(partners));
4233 self
4234 }
4235
4236 pub(crate) fn x_tau_dense(&self) -> Array2<f64> {
4237 self.x_tau_original.materialize()
4238 }
4239
4240 pub(crate) fn transformed_x_tau(
4241 &self,
4242 qs: &Array2<f64>,
4243 free_basis_opt: Option<&Array2<f64>>,
4244 ) -> Result<Array2<f64>, EstimationError> {
4245 self.x_tau_original.transformed(qs, free_basis_opt)
4246 }
4247
4248 pub(crate) fn x_tau_tau_entry_at(&self, j: usize) -> Option<HyperDesignDerivative> {
4249 self.x_tau_tau_original
4250 .as_ref()
4251 .and_then(|rows| rows.get(j))
4252 .and_then(|entry| entry.clone())
4253 }
4254
4255 pub(crate) fn has_implicit_operator(&self) -> bool {
4258 self.x_tau_original.uses_implicit_storage()
4259 }
4260
4261 pub(crate) fn has_implicit_multidim_duchon(&self) -> bool {
4262 self.implicit_first_axis_info()
4263 .is_some_and(|(op, _)| op.n_axes() > 1 && op.is_duchon_family())
4264 }
4265
4266 pub(crate) fn implicit_first_axis_info(
4268 &self,
4269 ) -> Option<(
4270 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4271 usize,
4272 )> {
4273 self.x_tau_original.implicit_first_axis_info()
4274 }
4275
4276 pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4277 self.x_tau_original.implicit_axis_count_hint()
4278 }
4279
4280 pub(crate) fn penalty_first_components(&self) -> &[PenaltyDerivativeComponent] {
4281 &self.penalty_first_components
4282 }
4283
4284 pub(crate) fn penalty_total_at(
4285 &self,
4286 rho: &Array1<f64>,
4287 p: usize,
4288 ) -> Result<Array2<f64>, EstimationError> {
4289 let mut out = Array2::<f64>::zeros((p, p));
4290 for component in &self.penalty_first_components {
4291 if component.matrix.nrows() != p || component.matrix.ncols() != p {
4292 crate::bail_invalid_estim!(
4293 "S_tau shape mismatch for penalty {}: expected {}x{}, got {}x{}",
4294 component.penalty_index,
4295 p,
4296 p,
4297 component.matrix.nrows(),
4298 component.matrix.ncols()
4299 );
4300 }
4301 if component.penalty_index >= rho.len() {
4302 crate::bail_invalid_estim!(
4303 "penalty_index {} out of bounds for rho dimension {}",
4304 component.penalty_index,
4305 rho.len()
4306 );
4307 }
4308 component
4309 .matrix
4310 .scaled_add_to(&mut out, rho[component.penalty_index].exp())?;
4311 }
4312 Ok(out)
4313 }
4314
4315 pub(crate) fn penaltysecond_components_for(
4316 &self,
4317 j: usize,
4318 ) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError> {
4319 if let Some(components) = self
4320 .penaltysecond_components
4321 .as_ref()
4322 .and_then(|rows| rows.get(j))
4323 .and_then(|row| row.clone())
4324 {
4325 return Ok(Some(components));
4326 }
4327 if let Some(provider) = self.penaltysecond_component_provider.as_ref() {
4328 return provider(j);
4329 }
4330 Ok(None)
4331 }
4332
4333 pub(crate) fn penaltysecond_componentrows(
4334 &self,
4335 ) -> Option<&[Option<Vec<PenaltyDerivativeComponent>>]> {
4336 self.penaltysecond_components.as_deref()
4337 }
4338
4339 pub(crate) fn penalty_first_component_count(&self) -> usize {
4340 self.penalty_first_components.len()
4341 }
4342
4343 pub(crate) fn has_penaltysecond_pair_at(&self, j: usize) -> bool {
4344 self.penaltysecond_components
4345 .as_ref()
4346 .and_then(|rows| rows.get(j))
4347 .is_some_and(Option::is_some)
4348 || self
4349 .penaltysecond_partner_indices
4350 .as_ref()
4351 .is_some_and(|partners| partners.contains(&j))
4352 }
4353}
4354
/// Records which `RemlGeometry` was chosen, together with a static `reason` and size
/// statistics.
///
/// NOTE(review): the statistics are presumably the inputs to the sparse-vs-dense decision —
/// confirm against the code that builds this.
#[derive(Clone, Debug)]
pub(crate) struct SparseRemlDecision {
    /// Geometry selected for REML evaluation.
    pub(crate) geometry: RemlGeometry,
    /// Static human-readable explanation of the choice.
    pub(crate) reason: &'static str,
    /// Coefficient count (name-based; TODO confirm).
    pub(crate) p: usize,
    /// Presumably the number of non-zeros in the design matrix — TODO confirm.
    pub(crate) nnz_x: usize,
    /// Optional estimate of non-zeros in the upper triangle of the Hessian (name-based).
    pub(crate) nnz_h_upper_est: Option<usize>,
    /// Optional estimated density of that upper triangle (name-based).
    pub(crate) density_h_upper_est: Option<f64>,
}
4364
/// Shared results of a sparse exact factorization, reused across evaluations.
///
/// Large members are held behind `Arc`, so cloning is cheap.
#[derive(Clone)]
pub(crate) struct SparseExactEvalData {
    /// Sparse factorization of the penalized system.
    pub(crate) factor: Arc<SparseExactFactor>,
    /// Optional selected-inverse (Takahashi) data, when it has been computed.
    pub(crate) takahashi: Option<Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    /// Log-determinant of the penalized Hessian (name-based; TODO confirm).
    pub(crate) logdet_h: f64,
    /// Log pseudo-determinant over the positive part of the penalty (name-based).
    pub(crate) logdet_s_pos: f64,
    /// Rank of the total penalty (name-based).
    pub(crate) penalty_rank: usize,
    /// Presumably first derivatives of the determinant terms per hyperparameter — TODO
    /// confirm.
    pub(crate) det1_values: Arc<Array1<f64>>,
}
4374
/// Dense precomputed state for the Firth (Jeffreys-prior) correction at one coefficient
/// vector.
///
/// NOTE(review): the field descriptions below are inferred from names only — confirm them
/// against the code that constructs this operator.
#[derive(Clone)]
pub struct FirthDenseOperator {
    /// Dense design matrix.
    pub(crate) x_dense: Array2<f64>,
    /// Presumably the transpose of `x_dense`, kept to avoid re-transposing.
    pub(crate) x_dense_t: Array2<f64>,
    /// Basis used to reduce the coefficient space (name-based).
    pub(crate) q_basis: Array2<f64>,
    /// Design expressed in the reduced basis (name-based).
    pub(crate) x_reduced: Array2<f64>,
    /// Optional square roots of per-observation weights.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    /// Reduced-space matrix `K` (name-based).
    pub(crate) k_reduced: Array2<f64>,
    /// Inverse diagonal of the reduced metric (name-based).
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    /// Presumably half the log-determinant of the Fisher information — TODO confirm.
    pub(crate) half_log_det: f64,
    /// Presumably hat-matrix diagonal values per observation — TODO confirm.
    pub(crate) h_diag: Array1<f64>,
    /// Working weights.
    pub(crate) w: Array1<f64>,
    /// Presumably successive derivatives of the working weights (`w1`..`w4`) — TODO confirm.
    pub(crate) w1: Array1<f64>,
    pub(crate) w2: Array1<f64>,
    pub(crate) w3: Array1<f64>,
    pub(crate) w4: Array1<f64>,
    /// Base matrix `B` (name-based).
    pub(crate) b_base: Array2<f64>,
    /// Presumably a projection applied to `b_base` — TODO confirm.
    pub(crate) p_b_base: Array2<f64>,
}
4440
/// Design-dependent pieces of the Firth computation.
///
/// NOTE(review): presumably factored out so they can be reused while the working weights
/// change — confirm.
#[derive(Clone)]
pub(crate) struct FirthDesignFactor {
    /// Dense design matrix.
    pub(crate) x_dense: Array2<f64>,
    /// Presumably the transpose of `x_dense`.
    pub(crate) x_dense_t: Array2<f64>,
    /// Basis used to reduce the coefficient space (name-based).
    pub(crate) q_basis: Array2<f64>,
    /// Design expressed in the reduced basis (name-based).
    pub(crate) x_reduced: Array2<f64>,
    /// Optional square roots of per-observation weights.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    /// Spectrum of the design metric (name-based).
    pub(crate) metric_spectrum: Array1<f64>,
    /// Inverse diagonal of the reduced metric (name-based).
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    /// Presumably the reduced dimension (rank) — TODO confirm.
    pub(crate) r: usize,
    /// Presumably the observation count — TODO confirm.
    pub(crate) n: usize,
}
4476
/// Quantities of the Firth term along one coefficient direction `u`.
///
/// NOTE(review): the descriptions below are name-based; confirm against the code that builds
/// this.
#[derive(Clone)]
pub(crate) struct FirthDirection {
    /// Change in the linear predictor along the direction.
    pub(crate) deta: Array1<f64>,
    /// Reduced-space `G` matrix for the direction.
    pub(crate) g_u_reduced: Array2<f64>,
    /// Reduced-space `A` matrix for the direction.
    pub(crate) a_u_reduced: Array2<f64>,
    /// Presumably the change in the hat diagonal.
    pub(crate) dh: Array1<f64>,
    /// Vectorized `B` term for the direction.
    pub(crate) b_uvec: Array1<f64>,
}
4486
/// Partial (fixed-coefficient) derivatives of Firth quantities along a hyperparameter `tau`.
///
/// NOTE(review): the descriptions below are name-based; confirm against the construction
/// site.
#[derive(Clone)]
pub(crate) struct FirthTauPartialKernel {
    /// Partial change in the linear predictor.
    pub(super) deta_partial: Array1<f64>,
    /// Presumably rates of change of the weight derivatives `w1` and `w2`.
    pub(crate) dotw1: Array1<f64>,
    pub(crate) dotw2: Array1<f64>,
    /// Partial change in the hat diagonal.
    pub(crate) dot_h_partial: Array1<f64>,
    /// Design derivative in the reduced basis.
    pub(crate) x_tau_reduced: Array2<f64>,
    /// Partial change in the information matrix.
    pub(super) dot_i_partial: Array2<f64>,
    /// Partial change in the reduced `K` matrix.
    pub(crate) dot_k_reduced: Array2<f64>,
}
4502
/// Exact first-order Firth contribution along a hyperparameter `tau`.
#[derive(Clone)]
pub(crate) struct FirthTauExactKernel {
    /// Presumably the gradient of the Firth penalty `phi` along `tau` — TODO confirm.
    pub(crate) gphi_tau: Array1<f64>,
    /// Partial derivative of `phi` with respect to `tau` (name-based).
    pub(crate) phi_tau_partial: f64,
    /// Optional cached partial kernel for reuse, when it was computed.
    pub(crate) tau_kernel: Option<FirthTauPartialKernel>,
}
4509
/// Exact second-order Firth contribution for a pair of hyperparameter directions.
#[derive(Clone)]
pub(crate) struct FirthTauTauExactKernel {
    /// Second partial derivative of `phi` for the pair (name-based).
    pub(super) phi_tau_tau_partial: f64,
    /// Presumably the mixed gradient term for the pair — TODO confirm.
    pub(super) gphi_tau_tau: Array1<f64>,
    /// Optional cached pairwise partial kernel.
    pub(super) tau_tau_kernel: Option<FirthTauTauPartialKernel>,
}
4527
/// Intermediates for a (tau_i, tau_j) Firth second derivative.
///
/// Each `_i` / `_j` field pair appears to hold the same first-order quantity
/// for each member of the pair. The optional fields presumably hold genuine
/// mixed second-order pieces, and are `None` when they vanish or are
/// unavailable (confirm).
#[derive(Clone, Default)]
pub(crate) struct FirthTauTauPartialKernel {
    pub(super) x_tau_i_reduced: Array2<f64>,
    pub(super) x_tau_j_reduced: Array2<f64>,
    pub(super) deta_i_partial: Array1<f64>,
    pub(super) deta_j_partial: Array1<f64>,
    pub(super) dot_h_i_partial: Array1<f64>,
    pub(super) dot_h_j_partial: Array1<f64>,
    pub(super) dot_k_i_reduced: Array2<f64>,
    pub(super) dot_k_j_reduced: Array2<f64>,
    pub(super) dot_i_i_partial: Array2<f64>,
    pub(super) dot_i_j_partial: Array2<f64>,
    /// Mixed second tau-derivative of the reduced design, if present (assumed).
    pub(super) x_tau_tau_reduced: Option<Array2<f64>>,
    /// Mixed second partial change of the linear predictor, if present (assumed).
    pub(super) deta_ij_partial: Option<Array1<f64>>,
}
4555
/// Intermediates for a mixed (tau, beta-direction `v`) Firth derivative.
///
/// NOTE(review): the first five fields mirror `FirthTauPartialKernel`. The
/// `_v` fields appear to be directional pieces along a coefficient direction
/// `v`, and the `d_beta_*` fields beta-derivatives of the tau pieces; confirm.
#[derive(Clone, Default)]
pub(crate) struct FirthTauBetaPartialKernel {
    pub(super) x_tau_reduced: Array2<f64>,
    pub(super) deta_partial: Array1<f64>,
    pub(super) dot_h_partial: Array1<f64>,
    pub(super) dot_i_partial: Array2<f64>,
    pub(super) dot_k_reduced: Array2<f64>,
    pub(super) deta_v: Array1<f64>,
    pub(super) deta_tau_v: Array1<f64>,
    pub(super) a_v_reduced: Array2<f64>,
    pub(super) dh_v: Array1<f64>,
    pub(super) b_vvec: Array1<f64>,
    pub(super) d_beta_dot_k: Array2<f64>,
    pub(super) d_beta_dot_h: Array1<f64>,
}
4578
/// Bundle of quantities shared by the evaluations performed at one cache key.
///
/// `EvalCacheManager` keeps the most recent bundle. `cached_eval_bundle`
/// hands out a clone only when `matches(key)` holds. The `OnceLock` fields are
/// filled lazily on first use; deriving `Clone` copies any already-initialized
/// contents (the `Arc`s are shared, not deep-copied).
#[derive(Clone)]
pub(crate) struct EvalShared {
    /// Key this bundle was built for; `None` is a valid key that only matches `None`.
    pub(crate) key: Option<Vec<u64>>,
    /// Inner P-IRLS solution this bundle was built from.
    pub(crate) pirls_result: Arc<PirlsResult>,
    /// Ridge bookkeeping; its `penalty_logdet_ridge()` feeds the penalty pseudo-logdet.
    pub(crate) ridge_passport: RidgePassport,
    /// Which linear-algebra backend this bundle uses (see `backend_kind`).
    pub(crate) geometry: RemlGeometry,
    /// Presumably the total penalized Hessian in dense form.
    pub(crate) h_total: Arc<Array2<f64>>,
    /// Sparse-exact evaluation data, presumably present for the sparse geometry only.
    pub(crate) sparse_exact: Option<Arc<SparseExactEvalData>>,
    /// Firth operator in the working frame, when Firth correction is active (assumed).
    pub(crate) firth_dense_operator: Option<Arc<FirthDenseOperator>>,
    /// Firth operator in the original coefficient frame (assumed).
    pub(crate) firth_dense_operator_original: Option<Arc<FirthDenseOperator>>,
    /// Lazily built by `penalty_pseudologdet_original`.
    pub(crate) penalty_pseudologdet: std::sync::OnceLock<Arc<penalty_logdet::PenaltyPseudologdet>>,
    /// Lazily computed per-penalty score vectors at the mode (assumed).
    pub(crate) penalty_scores_at_mode: std::sync::OnceLock<Arc<Vec<Array1<f64>>>>,
    /// Lazily computed correction terms, tagged with a `usize` (meaning to confirm).
    pub(crate) block_local_correction:
        std::sync::OnceLock<(usize, Arc<outer_eval::TkCorrectionTerms>)>,
}
4651
4652impl EvalShared {
4653 pub(crate) fn matches(&self, key: &Option<Vec<u64>>) -> bool {
4654 match (&self.key, key) {
4655 (None, None) => true,
4656 (Some(a), Some(b)) => a == b,
4657 _ => false,
4658 }
4659 }
4660
4661 pub(crate) fn penalty_pseudologdet_original(
4676 &self,
4677 canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
4678 lambdas: &[f64],
4679 p: usize,
4680 ) -> Result<Arc<penalty_logdet::PenaltyPseudologdet>, EstimationError> {
4681 if let Some(pld) = self.penalty_pseudologdet.get() {
4682 if pld.dim() != p {
4683 return Err(EstimationError::LayoutError(format!(
4684 "shared penalty pseudo-logdet frame mismatch: cached p={}, requested p={}",
4685 pld.dim(),
4686 p
4687 )));
4688 }
4689 return Ok(Arc::clone(pld));
4690 }
4691 let pld = Arc::new(
4692 penalty_logdet::PenaltyPseudologdet::from_penalties(
4693 canonical_penalties,
4694 lambdas,
4695 self.ridge_passport.penalty_logdet_ridge(),
4696 p,
4697 )
4698 .map_err(EstimationError::InvalidInput)?,
4699 );
4700 match self.penalty_pseudologdet.set(Arc::clone(&pld)) {
4701 Ok(()) => Ok(pld),
4702 Err(_) => Ok(Arc::clone(
4706 self.penalty_pseudologdet
4707 .get()
4708 .expect("OnceLock set raced, so it is initialized"),
4709 )),
4710 }
4711 }
4712}
4713
4714impl PenalizedGeometry for EvalShared {
4715 fn backend_kind(&self) -> GeometryBackendKind {
4716 match self.geometry {
4717 RemlGeometry::DenseSpectral => GeometryBackendKind::DenseSpectral,
4718 RemlGeometry::SparseExactSpd => GeometryBackendKind::SparseExactSpd,
4719 }
4720 }
4721}
4722
/// Byte-budgeted least-recently-used cache of P-IRLS results.
///
/// Entries are charged their `pirls_result_cache_bytes` estimate. On insert,
/// the entry with the smallest access tick is evicted until the new entry fits.
pub(crate) struct PirlsLruCache {
    /// key -> (cached result, clock tick of last insert/access, estimated bytes).
    pub(crate) map: HashMap<Vec<u64>, (Arc<PirlsResult>, u64, usize)>,
    /// Maximum total estimated bytes; `new` clamps it to at least 1.
    pub(crate) byte_budget: usize,
    /// Sum of the estimated bytes of the entries currently stored.
    pub(crate) current_bytes: usize,
    /// Monotonic tick, advanced on every hit and every insert.
    pub(crate) clock: u64,
}
4739
4740impl PirlsLruCache {
4741 pub(crate) fn new(byte_budget: usize) -> Self {
4742 Self {
4743 map: HashMap::new(),
4744 byte_budget: byte_budget.max(1),
4745 current_bytes: 0,
4746 clock: 0,
4747 }
4748 }
4749
4750 pub(crate) fn get(&mut self, key: &Vec<u64>) -> Option<Arc<PirlsResult>> {
4751 if let Some(entry) = self.map.get_mut(key) {
4752 self.clock += 1;
4753 entry.1 = self.clock;
4754 Some(entry.0.clone())
4755 } else {
4756 None
4757 }
4758 }
4759
4760 pub(crate) fn insert(&mut self, key: Vec<u64>, value: Arc<PirlsResult>) {
4761 self.clock += 1;
4762 let bytes = pirls_result_cache_bytes(&value);
4763 if bytes > self.byte_budget {
4767 if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4768 self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4769 }
4770 return;
4771 }
4772 if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4773 self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4774 }
4775 while self.current_bytes + bytes > self.byte_budget {
4776 let evict_key = self
4777 .map
4778 .iter()
4779 .min_by_key(|(_, (_, ts, _))| *ts)
4780 .map(|(k, _)| k.clone());
4781 match evict_key {
4782 Some(k) => {
4783 if let Some((_, _, evict_bytes)) = self.map.remove(&k) {
4784 self.current_bytes = self.current_bytes.saturating_sub(evict_bytes);
4785 }
4786 }
4787 None => break,
4788 }
4789 }
4790 self.current_bytes += bytes;
4791 self.map.insert(key, (value, self.clock, bytes));
4792 }
4793
4794 pub(crate) fn clear(&mut self) {
4795 self.map.clear();
4796 self.current_bytes = 0;
4797 }
4798}
4799
/// Cheap fingerprint identifying the inputs of a cached penalty subspace.
///
/// Built by `PenaltySubspaceCacheKey::from_inputs` from two `DefaultHasher`
/// digests. Equality of keys is treated as equality of inputs, so a 64-bit
/// hash collision would return a stale subspace.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct PenaltySubspaceCacheKey {
    /// Hash of the penalty matrix shape and the bit patterns of all its entries.
    pub(crate) penalty_matrix_fingerprint: u64,
    /// Hash of the ridge passport fields that affect the subspace.
    pub(crate) ridge_passport_signature: u64,
}
4805
/// Single-slot cache holding the most recently built penalty subspace.
pub(crate) struct PenaltySubspaceCache {
    /// The cached subspace with the key it was built for; `None` when empty.
    pub(crate) entry: Option<(PenaltySubspaceCacheKey, Arc<outer_eval::PenaltySubspace>)>,
}
4809
4810impl PenaltySubspaceCache {
4811 pub(crate) fn new() -> Self {
4812 Self { entry: None }
4813 }
4814
4815 pub(crate) fn get(
4816 &self,
4817 key: &PenaltySubspaceCacheKey,
4818 ) -> Option<Arc<outer_eval::PenaltySubspace>> {
4819 self.entry
4820 .as_ref()
4821 .filter(|(cached_key, _)| cached_key == key)
4822 .map(|(_, value)| value.clone())
4823 }
4824
4825 pub(crate) fn insert(
4826 &mut self,
4827 key: PenaltySubspaceCacheKey,
4828 value: Arc<outer_eval::PenaltySubspace>,
4829 ) {
4830 self.entry = Some((key, value));
4831 }
4832
4833 pub(crate) fn clear(&mut self) {
4834 self.entry = None;
4835 }
4836}
4837
4838impl PenaltySubspaceCacheKey {
4839 pub(crate) fn from_inputs(
4844 e_transformed: &ndarray::Array2<f64>,
4845 ridge_passport: &gam_problem::RidgePassport,
4846 ) -> Self {
4847 use std::collections::hash_map::DefaultHasher;
4848 use std::hash::{Hash, Hasher};
4849 let mut hasher = DefaultHasher::new();
4850 e_transformed.nrows().hash(&mut hasher);
4851 e_transformed.ncols().hash(&mut hasher);
4852 for value in e_transformed.iter() {
4853 value.to_bits().hash(&mut hasher);
4854 }
4855 let penalty_matrix_fingerprint = hasher.finish();
4856 let mut ridge_hasher = DefaultHasher::new();
4857 ridge_passport.delta.to_bits().hash(&mut ridge_hasher);
4858 (ridge_passport.matrix_form as u8).hash(&mut ridge_hasher);
4859 ridge_passport
4860 .policy
4861 .include_penalty_logdet
4862 .hash(&mut ridge_hasher);
4863 ridge_passport
4864 .policy
4865 .include_laplacehessian
4866 .hash(&mut ridge_hasher);
4867 let ridge_passport_signature = ridge_hasher.finish();
4868 Self {
4869 penalty_matrix_fingerprint,
4870 ridge_passport_signature,
4871 }
4872 }
4873}
4874
4875pub(crate) fn pirls_result_cache_bytes(result: &PirlsResult) -> usize {
4890 use std::mem::size_of;
4891 let n_array_elems = result.final_eta.len()
4892 + result.solveweights.len()
4893 + result.solveworking_response.len()
4894 + result.solvemu.len()
4895 + result.solve_c_array.len()
4896 + result.solve_d_array.len();
4897 let p = result.beta_transformed.0.len();
4898 let pen_h = symmetric_matrix_cache_bytes(&result.penalized_hessian_transformed);
4899 let stab_h = symmetric_matrix_cache_bytes(&result.stabilizedhessian_transformed);
4900 let reparam = (result.reparam_result.s_transformed.len()
4901 + result.reparam_result.qs.len()
4902 + result.reparam_result.e_transformed.len()
4903 + result.reparam_result.det1.len())
4904 * size_of::<f64>();
4905 n_array_elems * size_of::<f64>() + p * size_of::<f64>() + pen_h + stab_h + reparam + 1024
4906}
4907
4908pub(crate) fn symmetric_matrix_cache_bytes(m: &gam_linalg::matrix::SymmetricMatrix) -> usize {
4909 use gam_linalg::matrix::SymmetricMatrix;
4910 use std::mem::size_of;
4911 match m {
4912 SymmetricMatrix::Dense(a) => a.len() * size_of::<f64>(),
4913 SymmetricMatrix::Sparse(s) => {
4914 let (symbolic, values) = s.parts();
4916 values.len() * (size_of::<f64>() + size_of::<usize>())
4917 + std::mem::size_of_val(symbolic.col_ptr())
4918 }
4919 }
4920}
4921
/// Number of outer evaluations retained by the `OuterEvalLru` that
/// `EvalCacheManager::new` creates.
pub(crate) const OUTER_EVAL_LRU_CAPACITY: usize = 8;
4930
/// Small least-recently-used list of outer evaluations keyed by `Vec<u64>`.
///
/// Keys are presumably the sanitized rho keys used by `EvalCacheManager`
/// (confirm). Lookups are linear scans, which is fine for a handful of entries.
pub(crate) struct OuterEvalLru {
    /// Maximum number of entries kept; `new` clamps it to at least 1.
    capacity: usize,
    /// Entries ordered from least (front) to most (back) recently used.
    entries: std::collections::VecDeque<(Vec<u64>, OuterEval)>,
}
4949
4950impl OuterEvalLru {
4951 pub(crate) fn new(capacity: usize) -> Self {
4952 Self {
4953 capacity: capacity.max(1),
4954 entries: std::collections::VecDeque::new(),
4955 }
4956 }
4957
4958 pub(crate) fn get(&mut self, key: &[u64]) -> Option<OuterEval> {
4962 let pos = self
4963 .entries
4964 .iter()
4965 .position(|(k, _)| k.as_slice() == key)?;
4966 let entry = self.entries.remove(pos)?;
4967 let eval = entry.1.clone();
4968 self.entries.push_back(entry);
4969 Some(eval)
4970 }
4971
4972 pub(crate) fn insert(&mut self, key: Vec<u64>, eval: OuterEval) {
4975 if let Some(pos) = self
4976 .entries
4977 .iter()
4978 .position(|(k, _)| k.as_slice() == key.as_slice())
4979 {
4980 self.entries.remove(pos);
4981 }
4982 self.entries.push_back((key, eval));
4983 while self.entries.len() > self.capacity {
4984 self.entries.pop_front();
4985 }
4986 }
4987
4988 pub(crate) fn clear(&mut self) {
4989 self.entries.clear();
4990 }
4991}
4992
/// Owner of the per-state evaluation caches, each behind its own `RwLock`.
pub(crate) struct EvalCacheManager {
    /// Byte-budgeted LRU of inner P-IRLS results.
    pub(crate) pirls_cache: RwLock<PirlsLruCache>,
    /// Single-slot cache of the penalty subspace; see `cached_penalty_subspace`.
    pub(crate) penalty_subspace_cache: RwLock<PenaltySubspaceCache>,
    /// Most recent shared evaluation bundle; handed out only on a key match.
    pub(crate) current_eval_bundle: RwLock<Option<EvalShared>>,
    /// Most recently stored outer evaluation together with its key.
    pub(crate) current_outer_eval: RwLock<Option<(Vec<u64>, OuterEval)>>,
    /// Small LRU of recent outer evaluations; `cached_outer_eval` reads from here.
    pub(crate) outer_eval_lru: RwLock<OuterEvalLru>,
    /// Flag for the P-IRLS cache, initialized to `true` by `new`.
    /// NOTE(review): its readers are outside this chunk; presumably it gates use of `pirls_cache`.
    pub(crate) pirls_cache_enabled: AtomicBool,
}
5021
5022impl EvalCacheManager {
5023 pub(crate) fn new() -> Self {
5024 Self {
5025 pirls_cache: RwLock::new(PirlsLruCache::new(PIRLS_CACHE_BYTE_BUDGET)),
5026 penalty_subspace_cache: RwLock::new(PenaltySubspaceCache::new()),
5027 current_eval_bundle: RwLock::new(None),
5028 current_outer_eval: RwLock::new(None),
5029 outer_eval_lru: RwLock::new(OuterEvalLru::new(OUTER_EVAL_LRU_CAPACITY)),
5030 pirls_cache_enabled: AtomicBool::new(true),
5031 }
5032 }
5033
5034 pub(crate) fn sanitized_rhokey(rho: &Array1<f64>) -> Option<Vec<u64>> {
5038 self::rho_key::sanitized_rhokey(rho)
5039 }
5040
5041 pub(super) fn cached_penalty_subspace<F>(
5048 &self,
5049 e_transformed: &ndarray::Array2<f64>,
5050 ridge_passport: &gam_problem::RidgePassport,
5051 build: F,
5052 ) -> Result<Arc<outer_eval::PenaltySubspace>, EstimationError>
5053 where
5054 F: FnOnce() -> Result<outer_eval::PenaltySubspace, EstimationError>,
5055 {
5056 let key = PenaltySubspaceCacheKey::from_inputs(e_transformed, ridge_passport);
5057 if let Some(hit) = self.penalty_subspace_cache.read().unwrap().get(&key) {
5058 return Ok(hit);
5059 }
5060 let value = Arc::new(build()?);
5061 self.penalty_subspace_cache
5062 .write()
5063 .unwrap()
5064 .insert(key, value.clone());
5065 Ok(value)
5066 }
5067
5068 pub(crate) fn cached_eval_bundle(&self, key: &Option<Vec<u64>>) -> Option<EvalShared> {
5069 let guard = self.current_eval_bundle.read().unwrap();
5070 let bundle: &EvalShared = guard.as_ref()?;
5071 bundle.matches(key).then(|| bundle.clone())
5072 }
5073
5074 pub(crate) fn store_eval_bundle(&self, bundle: EvalShared) {
5075 *self.current_eval_bundle.write().unwrap() = Some(bundle);
5076 }
5077
5078 pub(crate) fn cached_outer_eval(&self, key: &Option<Vec<u64>>) -> Option<OuterEval> {
5079 let key = key.as_ref()?;
5080 self.outer_eval_lru.write().unwrap().get(key)
5087 }
5088
5089 pub(crate) fn store_outer_eval(&self, key: &Option<Vec<u64>>, eval: &OuterEval) {
5090 if let Some(key) = key.clone() {
5091 *self.current_outer_eval.write().unwrap() = Some((key.clone(), eval.clone()));
5095 self.outer_eval_lru.write().unwrap().insert(key, eval.clone());
5096 }
5097 }
5098
5099 pub(crate) fn invalidate_eval_bundle(&self) {
5100 self.current_eval_bundle.write().unwrap().take();
5101 self.current_outer_eval.write().unwrap().take();
5102 self.outer_eval_lru.write().unwrap().clear();
5103 }
5104
5105 pub(crate) fn clear_eval_and_factor_caches(&self) {
5106 self.invalidate_eval_bundle();
5107 self.penalty_subspace_cache.write().unwrap().clear();
5108 }
5109}
5110
/// Mutable counters and a status flag carried by `RemlState::arena`.
pub(crate) struct RemlArena {
    /// Count of cost evaluations, kept behind a lock.
    pub(crate) cost_eval_count: RwLock<u64>,
    /// Count of inner P-IRLS solves.
    pub(crate) inner_pirls_solve_count: AtomicU64,
    /// Whether the most recent gradient used a stochastic fallback (per its name; confirm).
    pub(crate) lastgradient_used_stochastic_fallback: AtomicBool,
}

impl RemlArena {
    /// Returns an arena with both counters at zero and the flag cleared.
    pub(crate) fn new() -> Self {
        let cost_eval_count = RwLock::new(0_u64);
        let inner_pirls_solve_count = AtomicU64::default();
        let lastgradient_used_stochastic_fallback = AtomicBool::default();
        RemlArena {
            cost_eval_count,
            inner_pirls_solve_count,
            lastgradient_used_stochastic_fallback,
        }
    }
}
5140
/// Nuisance quantities frozen for ALO computations.
///
/// It is stored in `RemlState::alo_frozen_nuisance`.
/// NOTE(review): field meanings below are inferred from names; confirm.
pub(crate) struct AloFrozenNuisance {
    /// Number of observations the frozen values were computed for (assumed).
    pub(crate) n_obs: usize,
    /// Per-observation influence scale factors (assumed length `n_obs`).
    pub(crate) influence_scale: Vec<f64>,
    /// Frozen dispersion/scale parameter (assumed).
    pub(crate) phi: f64,
}
5146
/// State for REML smoothing-parameter estimation: problem data, configuration,
/// caches, warm starts and cross-call bookkeeping.
///
/// Mutable parts sit behind `RwLock`s, atomics, or `OnceLock`s. Many fields are
/// `Arc`-wrapped atomics, presumably so they can be shared with other
/// components (confirm at the construction site).
pub(crate) struct RemlState<'a> {
    // Problem data and fixed configuration.
    pub(crate) y: ArrayView1<'a, f64>,
    pub(crate) x: DesignMatrix,
    pub(crate) weights: ArrayView1<'a, f64>,
    pub(crate) offset: Array1<f64>,
    pub(crate) canonical_penalties: Arc<Vec<gam_terms::construction::CanonicalPenalty>>,
    pub(crate) balanced_penalty_root: Array2<f64>,
    pub(crate) reparam_invariant: ReparamInvariant,
    pub(crate) sparse_penalty_block_count: Option<usize>,
    /// Presumably the number of coefficients.
    pub(crate) p: usize,
    pub(crate) config: Arc<RemlConfig>,
    pub(crate) runtime_mixture_link_state: Option<gam_problem::MixtureLinkState>,
    pub(crate) runtime_sas_link_state: Option<SasLinkState>,
    pub(crate) nullspace_dims: Vec<usize>,
    pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
    pub(crate) linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    pub(crate) penalty_shrinkage_floor: Option<f64>,
    pub(crate) rho_prior: gam_problem::RhoPrior,

    /// P-IRLS result, penalty-subspace, eval-bundle and outer-eval caches.
    pub(crate) cache_manager: EvalCacheManager,
    /// Evaluation counters and the stochastic-fallback flag.
    pub(crate) arena: RemlArena,
    // Current and previous warm starts for the inner solver.
    pub(crate) warm_start_beta: RwLock<Option<Coefficients>>,
    pub(crate) warm_start_rho: RwLock<Option<Array1<f64>>>,
    pub(crate) prev_warm_start_beta: RwLock<Option<Coefficients>>,
    pub(crate) prev_warm_start_rho: RwLock<Option<Array1<f64>>>,
    pub(crate) warm_start_enabled: AtomicBool,
    pub(crate) screening_max_inner_iterations: Arc<AtomicUsize>,
    pub(crate) outer_inner_cap: Arc<AtomicUsize>,

    // Diagnostics from the most recent inner solve.
    pub(crate) last_inner_iters: Arc<AtomicUsize>,
    pub(crate) last_inner_converged: Arc<AtomicBool>,

    pub(crate) ift_warm_start_cache: RwLock<Option<IftWarmStartCache>>,

    // NOTE(review): the following `AtomicU64` fields presumably hold `f64` bit
    // patterns (via `f64::to_bits`) for lock-free sharing; confirm at use sites.
    pub(crate) last_pirls_lm_lambda: Arc<AtomicU64>,

    pub(crate) frozen_negbin_theta: Arc<AtomicU64>,

    pub(crate) frozen_tweedie_phi: Arc<AtomicU64>,

    pub(crate) frozen_gamma_shape: Arc<AtomicU64>,

    pub(crate) last_ift_prediction_residual: Arc<AtomicU64>,

    pub(crate) last_pirls_accept_rho: Arc<AtomicU64>,

    pub(crate) ift_cached_factor: RwLock<Option<Arc<dyn gam_linalg::matrix::FactorizedSystem>>>,

    pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
    pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,

    // Lazily populated derived quantities, each `None` until first computed (assumed).
    pub(crate) gaussian_fixed_cache: RwLock<Option<Arc<crate::pirls::GaussianFixedCache>>>,
    pub(crate) gaussian_psi_gram_deriv:
        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
    pub(crate) glm_psi_gram_deriv:
        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
    pub(crate) glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
    pub(crate) flat_glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
    pub(crate) alo_frozen_nuisance: RwLock<Option<AloFrozenNuisance>>,

    pub(crate) alo_provably_inactive: RwLock<Option<bool>>,

    // Persistent warm-start bookkeeping.
    pub(crate) persistent_warm_start_key: RwLock<Option<String>>,
    pub(crate) persistent_latent_values_fingerprint: Option<u64>,
    /// Bounded, string-keyed cache of latent-value matrices.
    pub(crate) persistent_latent_values_cache: RwLock<PersistentLatentValuesCache>,
    pub(crate) analytic_penalty_registry_fingerprint: u64,
    pub(crate) persistent_warm_start_loaded: AtomicBool,
    /// Presumably a nesting counter: storing is suppressed while nonzero (confirm).
    pub(crate) persistent_warm_start_store_suppression: AtomicUsize,
    /// Presumably a nesting counter: ALO stabilization is suppressed while nonzero (confirm).
    pub(crate) alo_stabilization_suppression: AtomicUsize,
    pub(crate) persistent_warm_start_disk_enabled: AtomicBool,
    // Write-once scalars computed on first use.
    pub(crate) gaussian_weight_log_sum_half_cache: std::sync::OnceLock<f64>,
    pub(crate) gaussian_dp_floor_scale_cache: std::sync::OnceLock<f64>,
}