1use self::inner_strategy::GeometryBackendKind;
2use super::*;
3use gam_linalg::sparse_exact::SparseExactFactor;
4use crate::pirls::PIRLS_CACHE_BYTE_BUDGET;
5use crate::pirls::assemble_and_factor_sparse_penalized_system;
6use gam_terms::basis::LocalDesignJacobianProvider;
7use gam_problem::SasLinkState;
8use gam_problem::OuterEval;
9use ndarray::{Array1, Array2, s};
10use std::collections::{HashMap, VecDeque};
11use std::ops::Range;
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
13use std::sync::{Arc, RwLock};
14
15pub mod assembly;
16pub mod atoms;
17pub(crate) mod continuation;
18pub(crate) mod eval;
19mod firth;
20pub(super) mod hyper;
21mod inner_strategy;
22pub mod jeffreys_subspace;
25pub mod outer_eval;
26pub mod penalty_logdet;
27pub mod per_atom_efs;
28pub mod reml_outer_engine;
29mod rho_key;
30mod sparse_exact_penalty;
31mod trace;
32
33pub(crate) use sparse_exact_penalty::sparse_penalty_block_count_from_canonical;
34
/// Byte budget (512 MiB) that `exact_tau_tau_hessian_policy_with_firth` reports as
/// `TauTauHessianPolicy::budget_bytes`.
pub(crate) const EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES: usize = 512 * 1024 * 1024;
/// Largest observation count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_OBSERVATIONS: usize = 20_000;
/// Largest coefficient count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_COEFFICIENTS: usize = 256;
/// Upper bound on `n_obs * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_LINEAR_WORK: usize = 2_000_000;
/// Upper bound on `n_obs * p_coeff * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_QUADRATIC_WORK: usize = 100_000_000;
/// Default entry capacity of `PersistentLatentValuesCache`.
pub(crate) const PERSISTENT_LATENT_VALUES_CACHE_CAPACITY: usize = 8;
41
42#[derive(Debug)]
43pub(crate) struct PersistentLatentValuesCache {
44 pub(crate) entries: HashMap<String, Array2<f64>>,
45 pub(crate) lru: VecDeque<String>,
46 pub(crate) capacity: usize,
47}
48
49impl Default for PersistentLatentValuesCache {
50 fn default() -> Self {
51 Self {
52 entries: HashMap::new(),
53 lru: VecDeque::new(),
54 capacity: PERSISTENT_LATENT_VALUES_CACHE_CAPACITY,
55 }
56 }
57}
58
59impl PersistentLatentValuesCache {
60 pub(crate) fn lookup(
61 &mut self,
62 key: &str,
63 n_obs: usize,
64 latent_dim: usize,
65 ) -> Option<Array2<f64>> {
66 let values = self.entries.get(key)?;
67 if values.dim() != (n_obs, latent_dim) {
68 return None;
69 }
70 let values = values.clone();
71 self.touch(key.to_string());
72 Some(values)
73 }
74
75 pub(crate) fn insert(&mut self, key: String, values: Array2<f64>) {
76 if values.iter().any(|value| !value.is_finite()) {
77 return;
78 }
79 self.entries.insert(key.clone(), values);
80 self.touch(key);
81 while self.entries.len() > self.capacity {
82 let Some(evicted) = self.lru.pop_front() else {
83 break;
84 };
85 self.entries.remove(&evicted);
86 }
87 }
88
89 pub(crate) fn touch(&mut self, key: String) {
90 if let Some(index) = self.lru.iter().position(|queued| queued == &key) {
91 self.lru.remove(index);
92 }
93 self.lru.push_back(key);
94 }
95}
96
/// Snapshot of quantities kept for warm starts.
///
/// NOTE(review): the per-field notes below are read off the field names and
/// types only; the producing/consuming code is not visible here — confirm.
#[derive(Clone)]
pub(crate) struct IftWarmStartCache {
    /// Coefficient vector; presumably expressed in the original coordinates.
    pub beta_original: ndarray::Array1<f64>,
    /// Smoothing-parameter vector `rho`; presumably the point the snapshot was taken at.
    pub rho: ndarray::Array1<f64>,
    /// Symmetric matrix; presumably the penalized Hessian in the transformed frame.
    pub penalized_hessian_transformed: gam_linalg::matrix::SymmetricMatrix,
    /// Dense matrix named `qs`; presumably the reparameterization basis — TODO confirm.
    pub qs: ndarray::Array2<f64>,
    /// Flag; presumably true when the frame was the original one, so `qs` is not needed.
    pub frame_was_original: bool,
    /// Optional per-block vectors; presumably the `lambda * S * beta` products — TODO confirm.
    pub lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>>,
}
142
/// Itemized byte estimate for one evaluation plan.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) struct TauTauPlanEstimate {
    /// Bytes for the dense design matrix.
    pub(crate) dense_x_bytes: usize,
    /// Bytes for dense first-order design derivatives.
    pub(crate) first_order_tau_bytes: usize,
    /// Bytes for dense second-order design derivatives.
    pub(crate) second_order_tau_bytes: usize,
    /// Bytes for first-order penalty derivatives.
    pub(crate) penalty_first_bytes: usize,
    /// Bytes for second-order (pair) penalty derivatives.
    pub(crate) penalty_pair_bytes: usize,
    /// Bytes for the rho-tau penalty terms.
    pub(crate) rho_tau_penalty_bytes: usize,
    /// Bytes for cached vectors.
    pub(crate) vector_cache_bytes: usize,
    /// Bytes for weighted scratch storage.
    pub(crate) weighted_scratch_bytes: usize,
}

impl TauTauPlanEstimate {
    /// Sum of every component, clamped at `usize::MAX` instead of wrapping.
    pub(crate) fn total_bytes(self) -> usize {
        let Self {
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        } = self;
        let remaining = [
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        ];
        remaining
            .iter()
            .fold(dense_x_bytes, |running, &bytes| running.saturating_add(bytes))
    }
}
167
168#[derive(Clone, Copy, Debug, PartialEq, Eq)]
169pub(crate) struct TauTauHessianPolicy {
170 pub(crate) any_has_implicit: bool,
171 pub(crate) implicit_multidim_duchon: bool,
172 pub(crate) estimated_dense_tau_cache_bytes: usize,
173 pub(crate) gradient_plan: TauTauPlanEstimate,
174 pub(crate) hessian_plan: TauTauPlanEstimate,
175 pub(crate) budget_bytes: usize,
176 pub(crate) firth_pair_terms_unavailable: bool,
177}
178
179impl TauTauHessianPolicy {
180 pub(crate) fn prefer_gradient_only(self) -> bool {
208 self.firth_pair_terms_unavailable
209 }
210}
211
212pub(crate) fn exact_tau_tau_hessian_policy_with_firth(
213 n_obs: usize,
214 p_coeff: usize,
215 hyper_dirs: &[DirectionalHyperParam],
216 firth_pair_terms_unavailable: bool,
217) -> TauTauHessianPolicy {
218 let f64_bytes = std::mem::size_of::<f64>();
219 let dense_matrix_bytes =
220 |rows: usize, cols: usize| -> usize { rows.saturating_mul(cols).saturating_mul(f64_bytes) };
221 let dense_design_bytes = dense_matrix_bytes(n_obs, p_coeff);
222 let dense_penalty_bytes = dense_matrix_bytes(p_coeff, p_coeff);
223 let psi_dim = hyper_dirs.len();
224 let implicit_n_axes = hyper_dirs
225 .iter()
226 .find_map(DirectionalHyperParam::implicit_axis_count_hint)
227 .unwrap_or(0);
228 let gradient_uses_implicit_design = hyper_dirs
229 .iter()
230 .any(DirectionalHyperParam::has_implicit_operator)
231 && gam_terms::basis::should_use_implicit_operators_with_policy(
232 n_obs,
233 p_coeff,
234 implicit_n_axes,
235 &gam_runtime::resource::ResourcePolicy::default_library(),
236 );
237 let dense_first_order_count = hyper_dirs
238 .iter()
239 .filter(|dir| !dir.has_implicit_operator())
240 .count();
241 let first_penalty_component_count = hyper_dirs
242 .iter()
243 .map(DirectionalHyperParam::penalty_first_component_count)
244 .sum::<usize>();
245
246 let mut dense_second_order_count = 0usize;
247 let mut penalty_pair_count = 0usize;
248 for i in 0..psi_dim {
249 for j in i..psi_dim {
250 if hyper_dirs[i]
251 .x_tau_tau_entry_at(j)
252 .or_else(|| hyper_dirs[j].x_tau_tau_entry_at(i))
253 .is_some_and(|entry| !entry.uses_implicit_storage())
254 {
255 dense_second_order_count += if i == j { 1 } else { 2 };
256 }
257 if hyper_dirs[i].has_penaltysecond_pair_at(j)
258 || hyper_dirs[j].has_penaltysecond_pair_at(i)
259 {
260 penalty_pair_count += if i == j { 1 } else { 2 };
261 }
262 }
263 }
264
265 let gradient_dense_first_order_count = if gradient_uses_implicit_design {
266 dense_first_order_count
267 } else {
268 psi_dim
269 };
270 let gradient_needs_dense_x =
271 firth_pair_terms_unavailable || gradient_dense_first_order_count > 0;
272 let gradient_plan = TauTauPlanEstimate {
273 dense_x_bytes: if gradient_needs_dense_x {
274 dense_design_bytes
275 } else {
276 0
277 },
278 first_order_tau_bytes: if gradient_dense_first_order_count > 0 {
279 dense_design_bytes
280 } else {
281 0
282 },
283 second_order_tau_bytes: 0,
284 penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
285 penalty_pair_bytes: 0,
286 rho_tau_penalty_bytes: 0,
287 vector_cache_bytes: n_obs.saturating_mul(f64_bytes),
288 weighted_scratch_bytes: dense_penalty_bytes,
289 };
290 let hessian_plan = TauTauPlanEstimate {
291 dense_x_bytes: if psi_dim > 0 { dense_design_bytes } else { 0 },
292 first_order_tau_bytes: dense_first_order_count.saturating_mul(dense_design_bytes),
293 second_order_tau_bytes: dense_second_order_count.saturating_mul(dense_design_bytes),
294 penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
295 penalty_pair_bytes: penalty_pair_count.saturating_mul(dense_penalty_bytes),
296 rho_tau_penalty_bytes: first_penalty_component_count
297 .saturating_mul(2)
298 .saturating_mul(dense_penalty_bytes),
299 vector_cache_bytes: psi_dim.saturating_mul(n_obs).saturating_mul(f64_bytes),
300 weighted_scratch_bytes: dense_penalty_bytes,
301 };
302 let any_has_implicit = hyper_dirs
303 .iter()
304 .any(DirectionalHyperParam::has_implicit_operator);
305 let implicit_multidim_duchon = hyper_dirs
306 .iter()
307 .any(DirectionalHyperParam::has_implicit_multidim_duchon);
308 let estimated_dense_tau_cache_bytes = hessian_plan
309 .first_order_tau_bytes
310 .saturating_add(hessian_plan.second_order_tau_bytes);
311 TauTauHessianPolicy {
312 any_has_implicit,
313 implicit_multidim_duchon,
314 estimated_dense_tau_cache_bytes,
315 gradient_plan,
316 hessian_plan,
317 budget_bytes: EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES,
318 firth_pair_terms_unavailable: firth_pair_terms_unavailable && !hyper_dirs.is_empty(),
319 }
320}
321
322pub(crate) fn firth_problem_scale_allows(n_obs: usize, p_coeff: usize) -> bool {
323 let linear_work = n_obs.saturating_mul(p_coeff);
324 let quadratic_work = linear_work.saturating_mul(p_coeff);
325 n_obs <= FIRTH_MAX_OBSERVATIONS
326 && p_coeff <= FIRTH_MAX_COEFFICIENTS
327 && linear_work <= FIRTH_MAX_LINEAR_WORK
328 && quadratic_work <= FIRTH_MAX_QUADRATIC_WORK
329}
330
331#[cfg(test)]
332mod tests {
333 use super::{
334 DirectionalHyperParam, EvalCacheManager, EvalShared, HyperDesignDerivative,
335 HyperPenaltyDerivative, ImplicitDerivLevel, RemlConfig, RemlState,
336 };
337 use crate::estimate::EstimationError;
338 use gam_linalg::faer_ndarray::FaerCholesky;
339 use gam_linalg::matrix::symmetrize_in_place;
340 use crate::pirls::PirlsCoordinateFrame;
341 use gam_terms::basis::{ImplicitDesignPsiDerivative, RadialScalarKind};
342 use gam_problem::{
343 GlmLikelihoodSpec, InverseLink, LikelihoodSpec, ResponseFamily, StandardLink,
344 };
345 use faer::Side;
346 use gam_problem::{HessianResult, OuterEval};
347 use ndarray::{Array1, Array2, array, s};
348 use std::sync::Arc;
349
350 pub(crate) fn binomial_logit_glm_spec() -> GlmLikelihoodSpec {
353 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
354 ResponseFamily::Binomial,
355 InverseLink::Standard(StandardLink::Logit),
356 ))
357 }
358
359 pub(crate) fn gaussian_identity_glm_spec() -> GlmLikelihoodSpec {
362 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
363 ResponseFamily::Gaussian,
364 InverseLink::Standard(StandardLink::Identity),
365 ))
366 }
367
368 impl DirectionalHyperParam {
369 pub(super) fn new(
370 x_tau_original: Array2<f64>,
371 penalty_first_components: Vec<(usize, Array2<f64>)>,
372 x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
373 penaltysecond_components: Option<Vec<Option<Vec<(usize, Array2<f64>)>>>>,
374 ) -> Result<Self, EstimationError> {
375 let x_tau_tau_original = x_tau_tau_original.map(|rows| {
376 rows.into_iter()
377 .map(|entry| entry.map(HyperDesignDerivative::from))
378 .collect::<Vec<_>>()
379 });
380 let penalty_first_components = penalty_first_components
381 .into_iter()
382 .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
383 .collect();
384 let penaltysecond_components = penaltysecond_components.map(|rows| {
385 rows.into_iter()
386 .map(|row| {
387 row.map(|components| {
388 components
389 .into_iter()
390 .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
391 .collect::<Vec<_>>()
392 })
393 })
394 .collect::<Vec<_>>()
395 });
396 Self::new_compact(
397 HyperDesignDerivative::from(x_tau_original),
398 penalty_first_components,
399 x_tau_tau_original,
400 penaltysecond_components,
401 )
402 }
403
404 pub(super) fn single_penalty(
405 penalty_index: usize,
406 x_tau_original: Array2<f64>,
407 s_tau_original: Array2<f64>,
408 x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
409 s_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
410 ) -> Result<Self, EstimationError> {
411 let penaltysecond_components = s_tau_tau_original.map(|rows| {
412 rows.into_iter()
413 .map(|mat| mat.map(|mat| vec![(penalty_index, mat)]))
414 .collect::<Vec<_>>()
415 });
416 Self::new(
417 x_tau_original,
418 vec![(penalty_index, s_tau_original)],
419 x_tau_tau_original,
420 penaltysecond_components,
421 )
422 }
423 }
424
425 #[test]
426 pub(crate) fn firth_problem_scale_gate_blocks_large_quadratic_work() {
427 assert!(super::firth_problem_scale_allows(2_000, 200));
428 assert!(!super::firth_problem_scale_allows(4_800, 241));
429 assert!(!super::firth_problem_scale_allows(4_800, 433));
430 }
431
432 #[test]
433 pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_implicit_tau() {
434 let operator = ImplicitDesignPsiDerivative::new(
435 array![1.0, 2.0, 3.0, 4.0],
436 array![0.5, -1.0, 1.5, 2.0],
437 array![0.1, 0.2, 0.3, 0.4],
438 array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
439 None,
440 None,
441 2,
442 2,
443 1,
444 2,
445 );
446 let dir = DirectionalHyperParam::new_compact(
447 HyperDesignDerivative::from_implicit(
448 Arc::new(operator),
449 ImplicitDerivLevel::First(0),
450 1..4,
451 5,
452 ),
453 Vec::new(),
454 None,
455 None,
456 )
457 .expect("implicit directional hyperparam");
458 let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
459 assert!(policy.any_has_implicit);
460 assert_eq!(
461 policy.gradient_plan.dense_x_bytes,
462 10 * 5 * std::mem::size_of::<f64>()
463 );
464 assert!(!policy.prefer_gradient_only());
465 }
466
467 #[test]
468 pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_for_implicit_multidim_duchon()
469 {
470 let operator = ImplicitDesignPsiDerivative::new_streaming(
478 Arc::new(array![[0.0, 0.0], [1.0, 0.2]]),
479 Arc::new(array![[0.0, 0.0], [1.0, 1.0]]),
480 vec![0.0, 0.0],
481 RadialScalarKind::PureDuchon {
482 block_order: 1,
483 p_order: 0,
484 s_order: 0,
485 dim: 2,
486 },
487 None,
488 None,
489 0,
490 );
491 let dir = DirectionalHyperParam::new_compact(
492 HyperDesignDerivative::from_implicit(
493 Arc::new(operator),
494 ImplicitDerivLevel::First(0),
495 0..2,
496 2,
497 ),
498 Vec::new(),
499 None,
500 None,
501 )
502 .expect("implicit duchon directional hyperparam");
503 let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
504 assert!(policy.any_has_implicit);
505 assert!(policy.implicit_multidim_duchon);
506 assert!(!policy.prefer_gradient_only());
507 }
508
509 #[test]
510 pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_when_cache_budget_is_exceeded()
511 {
512 let dirs = (0..16)
518 .map(|_| {
519 DirectionalHyperParam::new_compact(
520 HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
521 Vec::new(),
522 None,
523 None,
524 )
525 .expect("dense directional hyperparam")
526 })
527 .collect::<Vec<_>>();
528 let policy = super::exact_tau_tau_hessian_policy_with_firth(320_000, 71, &dirs, false);
529 assert!(!policy.any_has_implicit);
530 assert!(policy.hessian_plan.total_bytes() > policy.budget_bytes);
531 assert!(policy.hessian_plan.total_bytes() > policy.gradient_plan.total_bytes());
532 assert!(!policy.prefer_gradient_only());
533 }
534
535 #[test]
536 pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_firth_pair_gap() {
537 let dir = DirectionalHyperParam::new_compact(
538 HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
539 Vec::new(),
540 None,
541 None,
542 )
543 .expect("dense directional hyperparam");
544 let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], true);
545 assert!(policy.firth_pair_terms_unavailable);
546 assert!(policy.prefer_gradient_only());
547 }
548
549 trait LogitDesignMotionFixture {
557 fn y(&self) -> &Array1<f64>;
558 fn w(&self) -> &Array1<f64>;
559 fn x(&self) -> &Array2<f64>;
560 fn s0(&self) -> &Array2<f64>;
561 fn cfg(&self) -> &RemlConfig;
562 fn rho(&self) -> &Array1<f64>;
563
564 fn state(&self) -> RemlState<'_> {
565 build_logit_state(self.y(), self.w(), self.x(), self.s0(), self.cfg())
566 }
567
568 fn state_perturbed(
569 &self,
570 x_tau: &Array2<f64>,
571 s_tau: &Array2<f64>,
572 eps: f64,
573 ) -> (RemlState<'_>, RemlState<'_>) {
574 let x_plus = self.x() + &x_tau.mapv(|v| eps * v);
575 let x_minus = self.x() - &x_tau.mapv(|v| eps * v);
576 let s_plus = self.s0() + &s_tau.mapv(|v| eps * v);
577 let s_minus = self.s0() - &s_tau.mapv(|v| eps * v);
578 (
579 build_logit_state(self.y(), self.w(), &x_plus, &s_plus, self.cfg()),
580 build_logit_state(self.y(), self.w(), &x_minus, &s_minus, self.cfg()),
581 )
582 }
583
584 fn fd_directional_gradient(&self, x_tau: &Array2<f64>, s_tau: &Array2<f64>) -> f64 {
586 let h = 2e-5;
587 let (state_plus, state_minus) = self.state_perturbed(x_tau, s_tau, h);
588 let v_plus = state_plus.compute_cost(self.rho()).expect("cost+");
589 let v_minus = state_minus.compute_cost(self.rho()).expect("cost-");
590 (v_plus - v_minus) / (2.0 * h)
591 }
592 }
593
594 pub(crate) fn build_logit_state<'a>(
595 y: &'a Array1<f64>,
596 w: &'a Array1<f64>,
597 x: &Array2<f64>,
598 s: &Array2<f64>,
599 cfg: &'a RemlConfig,
600 ) -> RemlState<'a> {
601 use crate::estimate::PenaltySpec;
602 let p = x.ncols();
603 let offset = Array1::<f64>::zeros(y.len());
604 let spec = PenaltySpec::Dense(s.clone());
605 let canonical = gam_terms::construction::canonicalize_penalty_specs(&[spec], &[1], p, "test")
606 .map(|(canonical, _)| canonical)
607 .expect("canonicalize");
608 RemlState::newwith_offset(
609 y.view(),
610 x.clone(),
611 w.view(),
612 offset.view(),
613 canonical,
614 p,
615 cfg,
616 Some(vec![1]),
617 None,
618 None,
619 )
620 .expect("state")
621 }
622
623 #[test]
624 fn repeated_penalty_ranges_keep_analytic_outer_hessian() {
625 let y = array![0.2, -0.1, 0.3, 0.0];
626 let w = Array1::<f64>::ones(y.len());
627 let x = array![[1.0, -0.7], [1.0, -0.2], [1.0, 0.3], [1.0, 0.9]];
628 let offset = Array1::<f64>::zeros(y.len());
629 let cfg = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
630 let p = x.ncols();
631 let canonical = vec![
632 gam_terms::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0]], p),
633 gam_terms::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0]], p),
634 ];
635 let state = RemlState::newwith_offset(
636 y.view(),
637 x,
638 w.view(),
639 offset.view(),
640 canonical,
641 p,
642 &cfg,
643 Some(vec![1, 1]),
644 None,
645 None,
646 )
647 .expect("state");
648
649 assert!(
650 state.analytic_outer_hessian_enabled(),
651 "double-penalty-style repeated coefficient ranges must still route to exact Hessian"
652 );
653 }
654
655 pub(crate) fn poisson_log_glm_spec() -> GlmLikelihoodSpec {
656 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
657 ResponseFamily::Poisson,
658 InverseLink::Standard(StandardLink::Log),
659 ))
660 }
661
662 #[test]
676 pub(crate) fn fixed_dispersion_laml_surface_is_replication_invariant() {
677 let n = 200usize;
678 let p = 8usize;
679 let c = 3usize;
680 let mut x = Array2::<f64>::zeros((n, p));
681 let mut y = Array1::<f64>::zeros(n);
682 for i in 0..n {
683 let t = (i as f64) / ((n - 1) as f64);
684 let tau = std::f64::consts::TAU;
685 x[[i, 0]] = 1.0;
686 x[[i, 1]] = t;
687 x[[i, 2]] = (tau * t).sin();
688 x[[i, 3]] = (tau * t).cos();
689 x[[i, 4]] = (2.0 * tau * t).sin();
690 x[[i, 5]] = (2.0 * tau * t).cos();
691 x[[i, 6]] = (3.0 * tau * t).sin();
692 x[[i, 7]] = (3.0 * tau * t).cos();
693 let eta = 0.3 + 0.9 * (1.4 * (t - 0.5)).sin();
694 y[i] = (eta.exp() + 0.5 * ((i as f64) * 2.399_963).sin())
696 .round()
697 .max(0.0);
698 }
699 let mut s = Array2::<f64>::zeros((p, p));
700 for j in 1..p {
701 s[[j, j]] = 1.0;
702 }
703
704 let mut x_rep = Array2::<f64>::zeros((n * c, p));
706 let mut y_rep = Array1::<f64>::zeros(n * c);
707 for r in 0..c {
708 for i in 0..n {
709 let row = r * n + i;
710 for j in 0..p {
711 x_rep[[row, j]] = x[[i, j]];
712 }
713 y_rep[row] = y[i];
714 }
715 }
716
717 let w_weighted = Array1::<f64>::from_elem(n, c as f64);
718 let w_rep = Array1::<f64>::ones(n * c);
719
720 let cfg = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
721 let st_w = build_logit_state(&y, &w_weighted, &x, &s, &cfg);
722 let st_r = build_logit_state(&y_rep, &w_rep, &x_rep, &s, &cfg);
723
724 for &rho in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
725 let r = Array1::from_elem(1, rho);
726 let cw = st_w.compute_cost(&r).expect("weighted cost");
727 let cr = st_r.compute_cost(&r).expect("replicated cost");
728 let gw = st_w.compute_gradient(&r).expect("weighted grad");
729 let gr = st_r.compute_gradient(&r).expect("replicated grad");
730 assert!(
733 (cw - cr).abs() <= 1e-9 * (1.0 + cw.abs()),
734 "LAML cost differs between w=c and c× replication at rho={rho}: \
735 cost_w={cw:.12e} cost_r={cr:.12e} diff={:.3e}",
736 cw - cr
737 );
738 assert!(
739 (gw[0] - gr[0]).abs() <= 1e-9 * (1.0 + gw[0].abs()),
740 "LAML gradient differs between w=c and c× replication at rho={rho}: \
741 g_w={:.12e} g_r={:.12e} diff={:.3e}",
742 gw[0],
743 gr[0],
744 gw[0] - gr[0]
745 );
746 }
747 }
748
749 #[test]
758 pub(crate) fn rho_weight_anchor_is_zero_for_fixed_dispersion() {
759 let n = 50usize;
760 let p = 3usize;
761 let mut x = Array2::<f64>::zeros((n, p));
762 let mut y = Array1::<f64>::zeros(n);
763 for i in 0..n {
764 let t = (i as f64) / ((n - 1) as f64);
765 x[[i, 0]] = 1.0;
766 x[[i, 1]] = t;
767 x[[i, 2]] = t * t;
768 y[i] = (1.0 + (3.0 * t).sin()).round().max(0.0);
769 }
770 let mut s = Array2::<f64>::zeros((p, p));
771 s[[2, 2]] = 1.0;
772 let c = 4.0_f64;
774 let w = Array1::<f64>::from_elem(n, c);
775
776 let cfg_pois = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
777 let st_pois = build_logit_state(&y, &w, &x, &s, &cfg_pois);
778 assert_eq!(
779 st_pois.rho_weight_anchor(),
780 0.0,
781 "fixed-dispersion (Poisson) anchor must be 0, not the geometric-mean log-weight"
782 );
783
784 let cfg_gauss = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
785 let st_gauss = build_logit_state(&y, &w, &x, &s, &cfg_gauss);
786 assert!(
787 (st_gauss.rho_weight_anchor() - c.ln()).abs() <= 1e-12,
788 "Gaussian-identity (profiled) anchor must be the geometric-mean log-weight ln(c)={:.6}, got {:.6}",
789 c.ln(),
790 st_gauss.rho_weight_anchor()
791 );
792 }
793
794 pub(crate) fn beta_original_from_bundle(bundle: &EvalShared) -> Array1<f64> {
795 let pr = bundle.pirls_result.as_ref();
796 match pr.coordinate_frame {
797 PirlsCoordinateFrame::OriginalSparseNative => pr.beta_transformed.as_ref().clone(),
798 PirlsCoordinateFrame::TransformedQs => {
799 pr.reparam_result.qs.dot(pr.beta_transformed.as_ref())
800 }
801 }
802 }
803
804 pub(crate) fn compute_joint_hypercostgradienthessian(
805 state: &RemlState<'_>,
806 theta: &Array1<f64>,
807 rho_dim: usize,
808 hyper_dirs: &[DirectionalHyperParam],
809 ) -> Result<(f64, Array1<f64>, Array2<f64>), EstimationError> {
810 let (cost, gradient, hessian) = state.compute_joint_hyper_eval_with_order(
811 theta,
812 rho_dim,
813 hyper_dirs,
814 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
815 )?;
816 Ok((
817 cost,
818 gradient,
819 hessian
820 .materialize_dense()
821 .map_err(EstimationError::RemlOptimizationFailed)?
822 .ok_or_else(|| {
823 EstimationError::RemlOptimizationFailed(
824 "joint hyper Hessian requested but unavailable".to_string(),
825 )
826 })?,
827 ))
828 }
829
830 pub(crate) fn h_original_from_bundle(bundle: &EvalShared) -> Array2<f64> {
831 let pr = bundle.pirls_result.as_ref();
832 match pr.coordinate_frame {
833 PirlsCoordinateFrame::OriginalSparseNative => bundle.h_total.as_ref().clone(),
834 PirlsCoordinateFrame::TransformedQs => {
835 let qs = &pr.reparam_result.qs;
836 let tmp = gam_linalg::faer_ndarray::fast_ab(qs, bundle.h_total.as_ref());
837 gam_linalg::faer_ndarray::fast_abt(&tmp, qs)
838 }
839 }
840 }
841
842 pub(crate) fn single_directional_tau_gradient(
843 state: &RemlState<'_>,
844 rho: &Array1<f64>,
845 hyper: DirectionalHyperParam,
846 ) -> Result<f64, EstimationError> {
847 let mut theta = Array1::<f64>::zeros(rho.len() + 1);
848 theta.slice_mut(s![..rho.len()]).assign(rho);
849 let (_, gradient, _) = state.compute_joint_hyper_eval_with_order(
850 &theta,
851 rho.len(),
852 &[hyper],
853 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
854 )?;
855 Ok(gradient[rho.len()])
856 }
857
858 pub(crate) fn fd_directional_tau_cost_gradient(
859 y: &Array1<f64>,
860 w: &Array1<f64>,
861 x: &Array2<f64>,
862 s0: &Array2<f64>,
863 cfg: &RemlConfig,
864 rho: &Array1<f64>,
865 x_tau: &Array2<f64>,
866 s_tau: &Array2<f64>,
867 ) -> f64 {
868 let h = 2e-5;
869 let x_plus = x + &x_tau.mapv(|v| h * v);
870 let x_minus = x - &x_tau.mapv(|v| h * v);
871 let s_plus = s0 + &s_tau.mapv(|v| h * v);
872 let s_minus = s0 - &s_tau.mapv(|v| h * v);
873 let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
874 let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
875 let v_plus = state_plus.compute_cost(rho).expect("cost+");
876 let v_minus = state_minus.compute_cost(rho).expect("cost-");
877 (v_plus - v_minus) / (2.0 * h)
878 }
879
880 pub(crate) fn directional_tau_hessian_fd_reference(
881 y: &Array1<f64>,
882 w: &Array1<f64>,
883 x: &Array2<f64>,
884 s0: &Array2<f64>,
885 cfg: &RemlConfig,
886 rho: &Array1<f64>,
887 hyper_dirs: &[DirectionalHyperParam],
888 x_tau_mats: &[Array2<f64>],
889 s_tau_mats: &[Array2<f64>],
890 ) -> Array2<f64> {
891 assert_eq!(hyper_dirs.len(), x_tau_mats.len());
892 assert_eq!(hyper_dirs.len(), s_tau_mats.len());
893
894 const TARGET_PHYSICAL_STEP: f64 = 1e-5;
895
896 let n_dirs = hyper_dirs.len();
897 let mut h_ttfd = Array2::<f64>::zeros((n_dirs, n_dirs));
898 for j in 0..n_dirs {
899 let direction_scale = x_tau_mats[j]
900 .iter()
901 .chain(s_tau_mats[j].iter())
902 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
903 let h = if direction_scale > 0.0 {
904 TARGET_PHYSICAL_STEP / direction_scale
905 } else {
906 TARGET_PHYSICAL_STEP
907 };
908
909 let x_plus = x + &x_tau_mats[j].mapv(|v| h * v);
910 let x_minus = x - &x_tau_mats[j].mapv(|v| h * v);
911 let s_plus = s0 + &s_tau_mats[j].mapv(|v| h * v);
912 let s_minus = s0 - &s_tau_mats[j].mapv(|v| h * v);
913
914 let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
915 let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
916 for i in 0..n_dirs {
917 let g_plus =
918 single_directional_tau_gradient(&state_plus, rho, hyper_dirs[i].clone())
919 .expect("g+ for FD");
920 let g_minus =
921 single_directional_tau_gradient(&state_minus, rho, hyper_dirs[i].clone())
922 .expect("g- for FD");
923 h_ttfd[[i, j]] = (g_plus - g_minus) / (2.0 * h);
924 }
925 }
926 symmetrize_in_place(&mut h_ttfd);
927 h_ttfd
928 }
929
930 #[test]
931 pub(crate) fn eval_cache_manager_stores_first_order_outer_eval() {
932 let cache = EvalCacheManager::new();
933 let rho = array![0.25, -0.0];
934 let rho_key = EvalCacheManager::sanitized_rhokey(&rho);
935 let eval = OuterEval {
936 cost: 3.5,
937 gradient: array![1.0, -2.0],
938 hessian: HessianResult::Unavailable,
939 inner_beta_hint: None,
940 };
941
942 cache.store_outer_eval(&rho_key, &eval);
943
944 let cached = cache
945 .cached_outer_eval(&rho_key)
946 .expect("first-order outer eval should be cached");
947 assert_eq!(cached.cost, eval.cost);
948 assert_eq!(cached.gradient, eval.gradient);
949 assert!(matches!(cached.hessian, HessianResult::Unavailable));
950
951 cache.invalidate_eval_bundle();
952 assert!(
953 cache.cached_outer_eval(&rho_key).is_none(),
954 "invalidating the bundle should clear the outer-eval cache too"
955 );
956 }
957
958 #[test]
959 pub(crate) fn reset_outer_seed_state_clears_pirls_cache() {
960 let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
966 let w = Array1::<f64>::ones(y.len());
967 let x = array![
968 [1.0, -1.0, 0.2],
969 [1.0, -0.5, -0.4],
970 [1.0, 0.0, 0.7],
971 [1.0, 0.4, -0.3],
972 [1.0, 0.9, 0.1],
973 [1.0, 1.3, -0.6],
974 ];
975 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
976 let rho = array![0.0];
977 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
978 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
979
980 state
983 .compute_outer_eval_with_order(
984 &rho,
985 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
986 )
987 .expect("outer eval should succeed");
988
989 let populated_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
990 assert!(
991 populated_len > 0,
992 "evaluating the outer objective should populate the PIRLS LRU, got {populated_len}"
993 );
994
995 state.reset_outer_seed_state();
996
997 let cleared_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
998 assert_eq!(
999 cleared_len, 0,
1000 "reset_outer_seed_state must clear the cross-call PIRLS LRU; got {cleared_len} entries"
1001 );
1002 }
1003
1004 #[test]
1005 pub(crate) fn reset_outer_seed_state_preserves_frozen_negbin_theta_1448() {
1006 use std::sync::atomic::Ordering;
1023
1024 let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1025 let w = Array1::<f64>::ones(y.len());
1026 let x = array![
1027 [1.0, -1.0, 0.2],
1028 [1.0, -0.5, -0.4],
1029 [1.0, 0.0, 0.7],
1030 [1.0, 0.4, -0.3],
1031 [1.0, 0.9, 0.1],
1032 [1.0, 1.3, -0.6],
1033 ];
1034 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1035 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1036 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1037
1038 let theta_final_bits = 2.5_f64.to_bits();
1040 state
1041 .frozen_negbin_theta
1042 .store(theta_final_bits, Ordering::Relaxed);
1043 assert_eq!(
1044 state.frozen_negbin_theta.load(Ordering::Relaxed),
1045 theta_final_bits,
1046 "precondition: the re-freeze stores θ_final into the frozen slot"
1047 );
1048
1049 state.reset_outer_seed_state();
1051
1052 assert_eq!(
1053 state.frozen_negbin_theta.load(Ordering::Relaxed),
1054 theta_final_bits,
1055 "reset_outer_seed_state (alternation-round reset) must PRESERVE the \
1056 re-frozen NB θ; clearing it would defeat the #1448 θ↔λ alternation \
1057 (the next ρ search would re-derive θ from the seed and never reach \
1058 the joint fixed point)"
1059 );
1060 }
1061
1062 #[test]
1063 pub(crate) fn implicit_hyper_design_derivative_respects_full_model_embedding() {
1064 let operator = ImplicitDesignPsiDerivative::new(
1065 array![1.0, 2.0, 3.0, 4.0],
1066 array![0.5, -1.0, 1.5, 2.0],
1067 array![0.1, 0.2, 0.3, 0.4],
1068 array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
1069 None,
1070 None,
1071 2,
1072 2,
1073 1,
1074 2,
1075 );
1076 let local = operator
1077 .materialize_first(0)
1078 .expect("materialized first derivative");
1079 assert_eq!(
1080 local.ncols(),
1081 3,
1082 "operator-local derivative should stay smooth-local"
1083 );
1084
1085 let implicit = HyperDesignDerivative::from_implicit(
1086 Arc::new(operator),
1087 ImplicitDerivLevel::First(0),
1088 1..4,
1089 5,
1090 );
1091 let embedded = HyperDesignDerivative::from_embedded(local.clone(), 1..4, 5);
1092
1093 assert_eq!(implicit.nrows(), embedded.nrows());
1094 assert_eq!(implicit.ncols(), 5);
1095 assert_eq!(implicit.materialize(), embedded.materialize());
1096
1097 let u = array![7.0, 1.5, -2.0, 0.25, -3.0];
1098 let v = array![0.75, -1.25];
1099 assert_eq!(
1100 implicit.forward_mul_original(&u).expect("implicit forward"),
1101 embedded.forward_mul_original(&u).expect("embedded forward")
1102 );
1103 assert_eq!(
1104 implicit
1105 .transpose_mul_original(&v)
1106 .expect("implicit transpose"),
1107 embedded
1108 .transpose_mul_original(&v)
1109 .expect("embedded transpose")
1110 );
1111
1112 let qs = array![
1113 [1.0, 0.0, 0.0],
1114 [0.0, 1.0, 0.0],
1115 [0.0, 0.5, 0.5],
1116 [0.0, 0.0, 1.0],
1117 [0.0, 0.0, 0.0],
1118 ];
1119 assert_eq!(
1120 implicit
1121 .transformed(&qs, None)
1122 .expect("implicit transformed"),
1123 embedded
1124 .transformed(&qs, None)
1125 .expect("embedded transformed")
1126 );
1127 let u_transformed = array![1.0, -0.5, 2.0];
1128 assert_eq!(
1129 implicit
1130 .transformed_forward_mul(&qs, None, &u_transformed)
1131 .expect("implicit transformed forward"),
1132 embedded
1133 .transformed_forward_mul(&qs, None, &u_transformed)
1134 .expect("embedded transformed forward")
1135 );
1136 assert_eq!(
1137 implicit
1138 .transformed_transpose_mul(&qs, None, &v)
1139 .expect("implicit transformed transpose"),
1140 embedded
1141 .transformed_transpose_mul(&qs, None, &v)
1142 .expect("embedded transformed transpose")
1143 );
1144 }
1145
/// Checks the analytic directional hyper-derivative identities of a logit fit against
/// central finite differences.
///
/// The direction has an all-zero design derivative (`x_tau`) and a non-zero penalty
/// derivative (`s_tau`). Three analytic quantities are compared with finite-difference
/// counterparts obtained by refitting at `±h` along the direction:
///
/// * `B`      - the implicit coefficient derivative, from a Cholesky solve against `H`;
/// * `H_tau`  - the directional derivative of the penalized Hessian;
/// * `V_tau`  - the directional derivative of the outer cost.
///
/// Each comparison asserts both element-wise sign agreement and a relative-error bound.
/// A final assertion checks that `-ell_beta^T B + beta^T S B` is numerically zero.
#[test]
pub(crate) fn directional_hyper_identities_match_finite_differences_logit() {
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.2, 0.3],
        [1.0, -0.8, -0.4],
        [1.0, -0.3, 0.7],
        [1.0, 0.1, -0.9],
        [1.0, 0.5, 0.2],
        [1.0, 0.9, -0.1],
        [1.0, 1.3, 0.8],
        [1.0, 1.7, -0.6],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];

    // The design derivative is identically zero, so only the penalty moves along this
    // direction; every `x_tau` term below therefore contributes zeros.
    let x_tau = Array2::<f64>::zeros(x.raw_dim());
    let s_tau = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15],];
    let hyper =
        DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
            .expect("single-penalty hyper direction");
    let rho = array![0.0];

    let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);
    let bundle = state.obtain_eval_bundle(&rho).expect("bundle");
    let pr = bundle.pirls_result.as_ref();

    let beta = beta_original_from_bundle(&bundle);
    let h_orig = h_original_from_bundle(&bundle);
    // Weights times (working response - eta).
    // NOTE(review): presumably the score on the linear-predictor scale - confirm against PIRLS.
    let u = &pr.solveweights * &(&pr.solveworking_response - &pr.final_eta);

    // Right-hand side of the implicit solve: X_tau^T u - X^T (W X_tau beta) - S_tau beta.
    // NOTE(review): assumes `fast_av(a, v)` is a·v and `fast_atv(a, v)` is aᵀ·v, as the
    // names suggest - confirm.
    let x_tau_beta = gam_linalg::faer_ndarray::fast_av(&x_tau, &beta);
    let weighted_x_tau_beta = &pr.finalweights * &x_tau_beta;
    let rhs = gam_linalg::faer_ndarray::fast_atv(&x_tau, &u)
        - gam_linalg::faer_ndarray::fast_atv(&x, &weighted_x_tau_beta)
        - s_tau.dot(&beta);
    // B = H^{-1} rhs via a lower Cholesky factor of the original-basis Hessian.
    let chol = h_orig.cholesky(Side::Lower).expect("chol(H)");
    let b_analytic = chol.solvevec(&rhs);

    // Directional derivative of the linear predictor: X_tau beta + X B.
    let eta_dot = &x_tau_beta + &gam_linalg::faer_ndarray::fast_av(&x, &b_analytic);
    let w_direction = crate::pirls::directionalworking_curvature_from_c_array(
        &pr.solve_c_array,
        &pr.finalweights,
        &eta_dot,
    );
    let wx = RemlState::row_scale(&x, &pr.finalweights);
    let wx_tau = RemlState::row_scale(&x_tau, &pr.finalweights);
    // Row-scale X by the directional curvature diagonal. The match is kept exhaustive so
    // that adding a variant to `DirectionalWorkingCurvature` is a compile error here.
    let mut xwtau_x = x.clone();
    match w_direction {
        crate::pirls::DirectionalWorkingCurvature::Diagonal(diag) => {
            xwtau_x = RemlState::row_scale(&xwtau_x, &diag);
        }
    }
    // H_tau = X_tau^T W X + X^T W X_tau + X^T diag(w_direction) X + S_tau.
    // NOTE(review): assumes `fast_atb(a, b)` is aᵀ·b - confirm.
    let mut h_tau_analytic = gam_linalg::faer_ndarray::fast_atb(&x_tau, &wx);
    h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &wx_tau);
    h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &xwtau_x);
    h_tau_analytic += &s_tau;

    // `s_eff` is whatever remains of H after removing the weighted Gram matrix X^T W X;
    // `cancellation` is the quantity asserted to vanish at the end of the test.
    let ell_beta = gam_linalg::faer_ndarray::fast_atv(&x, &u);
    let s_eff = &h_orig - &gam_linalg::faer_ndarray::fast_atb(&x, &wx);
    let cancellation = -ell_beta.dot(&b_analytic) + beta.dot(&s_eff.dot(&b_analytic));

    // Central finite differences: refit with the design and penalty shifted by ±h along
    // the direction.
    let h = 2e-5;
    let x_plus = &x + &(x_tau.mapv(|v| h * v));
    let x_minus = &x - &(x_tau.mapv(|v| h * v));
    let s_plus = &s0 + &(s_tau.mapv(|v| h * v));
    let s_minus = &s0 - &(s_tau.mapv(|v| h * v));

    let state_plus = build_logit_state(&y, &w, &x_plus, &s_plus, &cfg);
    let state_minus = build_logit_state(&y, &w, &x_minus, &s_minus, &cfg);
    let bundle_plus = state_plus.obtain_eval_bundle(&rho).expect("bundle+");
    let bundle_minus = state_minus.obtain_eval_bundle(&rho).expect("bundle-");
    let beta_plus = beta_original_from_bundle(&bundle_plus);
    let beta_minus = beta_original_from_bundle(&bundle_minus);
    let bfd = (&beta_plus - &beta_minus).mapv(|v| v / (2.0 * h));

    let h_plus = h_original_from_bundle(&bundle_plus);
    let h_minus = h_original_from_bundle(&bundle_minus);
    let h_taufd = (&h_plus - &h_minus).mapv(|v| v / (2.0 * h));

    let v_plus = state_plus.compute_cost(&rho).expect("cost+");
    let v_minus = state_minus.compute_cost(&rho).expect("cost-");
    let v_taufd = (v_plus - v_minus) / (2.0 * h);

    let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper.clone())
        .expect("analytic directional gradient");

    // --- B: relative L2 error plus per-component sign agreement. ---
    let b_num = (&b_analytic - &bfd).mapv(|v| v * v).sum().sqrt();
    let b_den = bfd.mapv(|v| v * v).sum().sqrt().max(1e-12);
    let b_rel = b_num / b_den;
    for i in 0..b_analytic.len() {
        assert_eq!(
            b_analytic[i].signum(),
            bfd[i].signum(),
            "B sign mismatch at i={i}: analytic={} fd={}",
            b_analytic[i],
            bfd[i]
        );
    }
    assert!(
        b_rel < 2e-2,
        "B implicit solve mismatch vs FD: rel={b_rel:.3e}, num={b_num:.3e}, den={b_den:.3e}"
    );

    // --- H_tau: relative Frobenius error plus per-entry sign agreement. ---
    let dh_num = (&h_tau_analytic - &h_taufd).mapv(|v| v * v).sum().sqrt();
    let dh_den = h_taufd.mapv(|v| v * v).sum().sqrt().max(1e-12);
    let dh_rel = dh_num / dh_den;
    for i in 0..h_tau_analytic.nrows() {
        for j in 0..h_tau_analytic.ncols() {
            assert_eq!(
                h_tau_analytic[[i, j]].signum(),
                h_taufd[[i, j]].signum(),
                "H_tau sign mismatch at ({i},{j}): analytic={} fd={}",
                h_tau_analytic[[i, j]],
                h_taufd[[i, j]]
            );
        }
    }
    assert!(
        dh_rel < 3e-2,
        "H_tau mismatch vs FD: rel={dh_rel:.3e}, num={dh_num:.3e}, den={dh_den:.3e}"
    );

    // --- V_tau: scalar sign agreement and relative error. ---
    let v_abs = (v_tau_analytic - v_taufd).abs();
    let v_rel = v_abs / v_taufd.abs().max(1e-10);
    assert_eq!(
        v_tau_analytic.signum(),
        v_taufd.signum(),
        "V_tau sign mismatch: analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
    );
    assert!(
        v_rel < 2e-2,
        "V_tau mismatch vs FD: rel={v_rel:.3e}, abs={v_abs:.3e}, analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
    );

    assert!(
        cancellation.abs() < 1e-10,
        "stationarity cancellation failed: | -ell_beta^T B + beta^T S B | = {:.3e}",
        cancellation.abs()
    );
}
1305
/// Smoke test for the exact outer Hessian of a two-penalty logit model configured with the
/// `true` flag on `RemlConfig::external` (the Firth switch, judging by the assertion
/// messages below).
///
/// The fourth design column is exactly twice the second, so the design is rank deficient.
/// The test requires that analytic outer-Hessian planning is enabled, that the outer
/// evaluation returns an analytic Hessian, and that the exact Hessian rebuilt from the
/// evaluation bundle is a finite 2x2 matrix.
#[test]
pub(crate) fn firth_exacthessian_includes_analytic_tk_second_derivatives() {
    use crate::estimate::PenaltySpec;
    use crate::rho_optimizer::OuterEvalOrder;

    let response = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
    let design = array![
        [1.0, -1.2, 0.4, -2.4],
        [1.0, -0.9, -0.1, -1.8],
        [1.0, -0.6, 0.3, -1.2],
        [1.0, -0.2, -0.4, -0.4],
        [1.0, 0.1, 0.5, 0.2],
        [1.0, 0.4, -0.6, 0.8],
        [1.0, 0.8, 0.2, 1.6],
        [1.0, 1.1, -0.3, 2.2],
        [1.0, 1.4, 0.7, 2.8],
        [1.0, 1.7, -0.2, 3.4],
    ];
    let n_coef = design.ncols();
    let prior_weights = Array1::<f64>::ones(response.len());
    let offset = Array1::<f64>::zeros(response.len());

    // Two dense penalties, each contributing one smoothing parameter.
    let specs = vec![
        PenaltySpec::Dense(array![
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 1.5, 0.2, 0.0],
            [0.0, 0.2, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.5],
        ]),
        PenaltySpec::Dense(array![
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.8, -0.1, 0.0],
            [0.0, -0.1, 0.6, 0.0],
            [0.0, 0.0, 0.0, 0.3],
        ]),
    ];
    let (canonical, _) =
        gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], n_coef, "test")
            .expect("canonicalize");

    let cfg =
        RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
    let state = RemlState::newwith_offset(
        response.view(),
        design.clone(),
        prior_weights.view(),
        offset.view(),
        canonical,
        n_coef,
        &cfg,
        Some(vec![1, 1]),
        None,
        None,
    )
    .expect("state");

    let rho = array![0.1, -0.2];
    assert!(
        state.analytic_outer_hessian_enabled(),
        "Firth logit should no longer disable analytic outer Hessian planning"
    );

    // The planner itself must produce an analytic Hessian at this point.
    let outer = state
        .compute_outer_eval_with_order(&rho, OuterEvalOrder::ValueGradientHessian)
        .expect("outer Hessian eval should succeed");
    assert!(
        outer.hessian.is_analytic(),
        "outer planner should request and return an analytic Hessian"
    );

    // The exact Hessian assembled from the bundle must be well-formed.
    let bundle = state.obtain_eval_bundle(&rho).expect("exact firth bundle");
    let h_dense = state
        .compute_lamlhessian_exact_from_bundle(&rho, &bundle)
        .expect("Firth exact Hessian should include analytic TK second derivatives");
    assert_eq!(h_dense.raw_dim(), ndarray::Ix2(2, 2));
    assert!(
        h_dense.iter().all(|value| value.is_finite()),
        "Hessian should be finite: {h_dense:?}"
    );
}
1384
/// Verifies the dense analytic outer Hessian of a two-penalty logit model (built with the
/// `true` flag on `RemlConfig::external`) against central finite differences of the
/// analytic outer gradient.
///
/// Each column of the Hessian is compared with `(g(rho + delta e_col) - g(rho - delta e_col))
/// / (2 delta)`, entry by entry, using a relative tolerance of `2e-3`.
#[test]
pub(crate) fn firth_outer_hessian_matches_gradient_finite_difference_with_tk_terms() {
    use crate::estimate::PenaltySpec;
    use crate::rho_optimizer::OuterEvalOrder;

    let response = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
    let design = array![
        [1.0, -1.0, 0.3],
        [1.0, -0.7, -0.2],
        [1.0, -0.3, 0.4],
        [1.0, 0.0, -0.5],
        [1.0, 0.2, 0.6],
        [1.0, 0.6, -0.4],
        [1.0, 0.9, 0.2],
        [1.0, 1.3, -0.1],
    ];
    let p_dim = design.ncols();
    let prior_weights = Array1::<f64>::ones(response.len());
    let offset = Array1::<f64>::zeros(response.len());

    let specs = vec![
        PenaltySpec::Dense(array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.1], [0.0, 0.1, 0.7],]),
        PenaltySpec::Dense(array![[0.0, 0.0, 0.0], [0.0, 0.4, -0.05], [0.0, -0.05, 0.9],]),
    ];
    let (canonical, _) =
        gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p_dim, "test")
            .expect("canonicalize");

    let cfg =
        RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
    let state = RemlState::newwith_offset(
        response.view(),
        design.clone(),
        prior_weights.view(),
        offset.view(),
        canonical,
        p_dim,
        &cfg,
        Some(vec![1, 1]),
        None,
        None,
    )
    .expect("state");

    let rho = array![0.15, -0.25];
    let eval = state
        .compute_outer_eval_with_order(&rho, OuterEvalOrder::ValueGradientHessian)
        .expect("analytic Hessian eval");
    // Only a dense analytic Hessian is acceptable; the other variants are listed
    // explicitly so a new variant forces this test to be revisited.
    let h = match eval.hessian {
        HessianResult::Operator(_) | HessianResult::Unavailable => {
            panic!("expected dense analytic Hessian")
        }
        HessianResult::Analytic(hessian) => hessian,
    };

    let delta = 2.0e-5;
    for col in 0..rho.len() {
        // Perturb a single coordinate of rho in both directions.
        let mut rp = rho.clone();
        rp[col] += delta;
        let mut rm = rho.clone();
        rm[col] -= delta;

        let gp = state
            .compute_outer_eval_with_order(&rp, OuterEvalOrder::ValueAndGradient)
            .expect("plus grad")
            .gradient;
        let gm = state
            .compute_outer_eval_with_order(&rm, OuterEvalOrder::ValueAndGradient)
            .expect("minus grad")
            .gradient;

        for row in 0..rho.len() {
            let an = h[[row, col]];
            let fd = (gp[row] - gm[row]) / (2.0 * delta);
            let rel = (fd - an).abs() / fd.abs().max(an.abs()).max(1e-6);
            assert!(
                rel < 2.0e-3,
                "Hessian mismatch ({row},{col}): analytic={an:.9e}, fd={fd:.9e}, rel={rel:.3e}"
            );
        }
    }
}
1468
/// Checks that the vector `0.5 * X^T (w1 ∘ h_diag)` built from the dense Firth operator is
/// unchanged by projection onto the operator's `q_basis`.
///
/// The fourth design column is exactly twice the second, so the design is rank deficient.
/// The relative residual after projecting with `q_basis · q_basis^T` must stay below `1e-10`.
#[test]
pub(crate) fn firthgradient_lives_in_design_column_space_under_rank_deficiency() {
    let x = array![
        [1.0, -1.2, 0.4, -2.4],
        [1.0, -0.9, -0.1, -1.8],
        [1.0, -0.6, 0.3, -1.2],
        [1.0, -0.2, -0.4, -0.4],
        [1.0, 0.1, 0.5, 0.2],
        [1.0, 0.4, -0.6, 0.8],
        [1.0, 0.8, 0.2, 1.6],
        [1.0, 1.1, -0.3, 2.2],
    ];
    let coefficients = array![0.1, -0.2, 0.3, 0.05];
    let eta = x.dot(&coefficients);

    let link = gam_problem::InverseLink::Standard(gam_problem::StandardLink::Logit);
    let op = super::RemlState::build_firth_dense_operator_for_link(
        &link,
        &x,
        &eta,
        ndarray::Array1::ones(x.nrows()).view(),
    )
    .expect("firth operator");

    // Candidate gradient vector in coefficient space.
    let weighted_diag = &op.w1 * &op.h_diag;
    let gradphi = 0.5 * x.t().dot(&weighted_diag);

    // Project onto the span of `q_basis` and measure what is left over.
    let coords = op.q_basis.t().dot(&gradphi);
    let proj = op.q_basis.dot(&coords);
    let resid = &gradphi - &proj;

    let resid_norm = resid.mapv(|v| v * v).sum().sqrt();
    let grad_norm = gradphi.mapv(|v| v * v).sum().sqrt();
    let rel = resid_norm / grad_norm.max(1e-12);
    assert!(
        rel < 1e-10,
        "Firth gradient should lie in Col(Xᵀ): rel residual={rel:.3e}"
    );
}
1507
/// For a penalty-only direction (zero design derivative) on a logit model built with the
/// `true` flag on `RemlConfig::external`, checks that:
///
/// * the analytic directional gradient is finite and agrees with the finite-difference
///   helper to a relative tolerance of `1e-3`;
/// * the EFS step computation accepts the same direction and reports a finite cost.
#[test]
pub(crate) fn firth_logit_directional_hypergradient_accepts_penalty_only_with_full_tk_gradient()
{
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.1, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];

    // One shared definition of the direction: no design motion, penalty motion only.
    let design_motion = Array2::<f64>::zeros((x.nrows(), x.ncols()));
    let penalty_motion = array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],];

    let hyper = DirectionalHyperParam::single_penalty(
        0,
        design_motion.clone(),
        penalty_motion.clone(),
        None,
        None,
    )
    .expect("single-penalty hyper direction");

    let rho = array![0.0];
    let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);

    // Analytic gradient versus the finite-difference reference.
    let gradient = single_directional_tau_gradient(&state, &rho, hyper)
        .expect("Firth penalty-only directional gradient should use analytic TK propagation");
    assert!(gradient.is_finite(), "gradient={gradient}");
    let fd = fd_directional_tau_cost_gradient(
        &y,
        &w,
        &x,
        &s0,
        &cfg,
        &rho,
        &design_motion,
        &penalty_motion,
    );
    let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
    assert!(
        rel < 1.0e-3,
        "Firth penalty-only directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
    );

    // The EFS path must accept an identical direction.
    let efs_hyper =
        DirectionalHyperParam::single_penalty(0, design_motion, penalty_motion, None, None)
            .expect("single-penalty EFS hyper direction");
    let efs = state
        .compute_efs_steps_with_psi_ext(&rho, &[efs_hyper])
        .expect("Firth penalty-only EFS should use analytic TK propagation");
    assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
}
1565
/// For a design-moving direction (constant `1e-3` design derivative, zero penalty
/// derivative) on a logit model built with the `true` flag on `RemlConfig::external`,
/// checks that the analytic directional gradient is finite and agrees with the
/// finite-difference helper to a relative tolerance of `2e-2`.
#[test]
pub(crate) fn firth_logit_directional_hypergradient_accepts_design_moving_with_full_tk_gradient()
{
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.1, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];

    // The direction is defined once and shared by the analytic and FD evaluations.
    let x_tau = Array2::from_elem((x.nrows(), x.ncols()), 1e-3);
    let s_tau = Array2::<f64>::zeros((x.ncols(), x.ncols()));
    let hyper =
        DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
            .expect("single-penalty hyper direction");

    let rho = array![0.0];
    let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);

    let gradient = single_directional_tau_gradient(&state, &rho, hyper)
        .expect("Firth design-moving directional gradient should use analytic TK propagation");
    assert!(gradient.is_finite(), "gradient={gradient}");

    let fd = fd_directional_tau_cost_gradient(&y, &w, &x, &s0, &cfg, &rho, &x_tau, &s_tau);
    let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
    assert!(
        rel < 2.0e-2,
        "Firth design-moving directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
    );
}
1603
/// Smoke test: with a single design-moving direction on a logit model built with the
/// `true` flag on `RemlConfig::external`, both the unified value-and-gradient evaluation
/// and the EFS step computation must succeed and return finite results.
#[test]
pub(crate) fn firth_logit_hybrid_efs_accepts_full_tk_psi_gradient() {
    use crate::estimate::reml::reml_outer_engine::EvalMode;

    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.1, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];

    // Design derivative whose entries grow with both the row and the column index;
    // the penalty derivative is zero.
    let design_motion = Array2::from_shape_fn((x.nrows(), x.ncols()), |(i, j)| {
        1e-3 * ((i + 1) as f64) * ((j + 2) as f64)
    });
    let penalty_motion = Array2::<f64>::zeros((x.ncols(), x.ncols()));
    let direction =
        DirectionalHyperParam::single_penalty(0, design_motion, penalty_motion, None, None)
            .expect("design-moving hyper direction");
    let hyper_dirs = vec![direction];

    let rho = array![0.0];
    let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);

    // Full (non-EFS) evaluation: cost and every gradient entry must be finite.
    let full = state
        .evaluate_unified_with_psi_ext(&rho, None, EvalMode::ValueAndGradient, &hyper_dirs)
        .expect("full Firth psi gradient should use analytic TK propagation");
    assert!(full.cost.is_finite(), "full cost={}", full.cost);
    let full_grad = full.gradient.expect("gradient should be present");
    assert!(
        full_grad.iter().all(|value| value.is_finite()),
        "full gradient={full_grad:?}"
    );

    // EFS evaluation over the same directions.
    let efs = state
        .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
        .expect("hybrid EFS should use analytic TK propagation");
    assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
}
1653
/// Structural check of the joint (rho, directions) Hessian for one smoothing parameter
/// and two directions (one penalty-only, one design-moving).
///
/// The Hessian must be `theta.len()` square, finite everywhere, symmetric to within `1e-6`,
/// and in particular the entries coupling index 0 with indices 1 and 2 must be finite.
#[test]
pub(crate) fn joint_hyperhessianwires_mixed_blocks() {
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.2, 0.3],
        [1.0, -0.8, -0.4],
        [1.0, -0.3, 0.7],
        [1.0, 0.1, -0.9],
        [1.0, 0.5, 0.2],
        [1.0, 0.9, -0.1],
        [1.0, 1.3, 0.8],
        [1.0, 1.7, -0.6],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
    let cfg =
        RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);
    let rho = array![0.0];
    let theta = array![0.0, 0.0, 0.0];

    // Direction 1 moves only the penalty; direction 2 moves only the design.
    let penalty_dir = DirectionalHyperParam::single_penalty(
        0,
        Array2::<f64>::zeros((x.nrows(), x.ncols())),
        array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
        None,
        None,
    )
    .expect("single-penalty hyper direction");
    let design_dir = DirectionalHyperParam::single_penalty(
        0,
        Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
        Array2::<f64>::zeros((x.ncols(), x.ncols())),
        None,
        None,
    )
    .expect("single-penalty hyper direction");
    let hyper_dirs = vec![penalty_dir, design_dir];

    let (_, _, h) =
        compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
            .expect("joint hyper cost+gradient+hessian");
    assert_eq!(h.nrows(), theta.len());
    assert_eq!(h.ncols(), theta.len());
    assert!(h.iter().all(|v| v.is_finite()));

    // Walk the strict lower triangle and compare each entry with its mirror image.
    let lower_triangle = (0..h.nrows()).flat_map(|i| (0..i).map(move |j| (i, j)));
    for (i, j) in lower_triangle {
        let diff = (h[[i, j]] - h[[j, i]]).abs();
        assert!(
            diff < 1e-6,
            "joint hessian asymmetry at ({i},{j}): {diff:.3e}"
        );
    }

    let mixed_0 = h[[0, 1]];
    let mixed_1 = h[[0, 2]];
    assert!(
        mixed_0.is_finite() && mixed_1.is_finite(),
        "mixed blocks must be finite"
    );
}
1716
/// Compares the direction-direction block of the analytic joint Hessian with the
/// finite-difference reference when the direction coordinates are non-zero
/// (`psi = [0.7, -0.4]`).
///
/// The two directions are a penalty-only direction and a design-moving direction; the
/// relative Frobenius error of the 2x2 block must stay below `1e-4`.
#[test]
pub(crate) fn joint_tau_tau_linear_dirs_matchfd_reference_away_fromzero_psi() {
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.2, 0.3],
        [1.0, -0.8, -0.4],
        [1.0, -0.3, 0.7],
        [1.0, 0.1, -0.9],
        [1.0, 0.5, 0.2],
        [1.0, 0.9, -0.1],
        [1.0, 1.3, 0.8],
        [1.0, 1.7, -0.6],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
    let cfg =
        RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);
    let rho = array![0.0];
    let psi = array![0.7, -0.4];
    let theta = array![rho[0], psi[0], psi[1]];

    // Per-direction design and penalty derivatives; the directions below are built from
    // exactly these matrices, which are also what the FD reference receives.
    let x_tau_mats: Vec<Array2<f64>> = vec![
        Array2::<f64>::zeros((x.nrows(), x.ncols())),
        Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
    ];
    let s_tau_mats: Vec<Array2<f64>> = vec![
        array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
        Array2::<f64>::zeros((x.ncols(), x.ncols())),
    ];
    let hyper_dirs: Vec<_> = x_tau_mats
        .iter()
        .zip(s_tau_mats.iter())
        .map(|(x_tau, s_tau)| {
            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
                .expect("linear tau direction")
        })
        .collect();

    let (_, _, h_full) =
        compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
            .expect("joint hyper cost+gradient+hessian");
    // Keep only the trailing block that belongs to the directions.
    let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();

    let h_ttfd = directional_tau_hessian_fd_reference(
        &y,
        &w,
        &x,
        &s0,
        &cfg,
        &rho,
        &hyper_dirs,
        &x_tau_mats,
        &s_tau_mats,
    );

    let gap = &h_tt_analytic - &h_ttfd;
    let num = gap.iter().map(|v| v * v).sum::<f64>().sqrt();
    let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    let rel = num / den;
    assert!(
        rel < 1e-4,
        "linear-dir joint tau-tau block deviates from FD reference away from zero psi: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
    );
}
1799
/// Negative test: a direction whose second-order penalty component refers to penalty
/// index 1, on a state built from a single penalty matrix, must be rejected by the joint
/// cost/gradient/Hessian routine with an error mentioning "out of bounds" or
/// "penalty_index".
#[test]
pub(crate) fn joint_hypervalidation_rejects_out_of_boundssecond_order_penalty_index() {
    let y = array![0.0, 1.0, 0.0, 1.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -0.5, 0.2],
        [1.0, -0.1, -0.3],
        [1.0, 0.4, 0.6],
        [1.0, 0.9, -0.2],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
    let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);
    let theta = array![0.0, 0.0];

    // First-order part is valid (penalty index 0); the second-order part deliberately
    // names penalty index 1.
    let first_order = vec![(0, Array2::<f64>::zeros((x.ncols(), x.ncols())))];
    let second_order = Some(vec![Some(vec![(1, Array2::<f64>::eye(x.ncols()))])]);
    let bad_direction = DirectionalHyperParam::new(
        Array2::<f64>::zeros((x.nrows(), x.ncols())),
        first_order,
        None,
        second_order,
    )
    .expect("hyper direction with invalid second-order penalty index");
    let hyper_dirs = vec![bad_direction];

    let outcome = compute_joint_hypercostgradienthessian(&state, &theta, 1, &hyper_dirs);
    let msg = if let Err(err) = outcome {
        err.to_string()
    } else {
        panic!("invalid second-order penalty index should be rejected")
    };
    assert!(
        msg.contains("out of bounds") || msg.contains("penalty_index"),
        "unexpected validation error: {msg}"
    );
}
1849
/// Compares the direction-direction block of the analytic joint Hessian with the
/// finite-difference reference at `psi = [0, 0]`.
///
/// The two directions are a penalty-only direction and a design-moving direction. The
/// block must be `hyper_dirs.len()` square and its relative Frobenius error against the
/// reference must stay below `1e-4`.
#[test]
pub(crate) fn joint_tau_tau_analytic_matchesfd_reference() {
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.2, 0.3],
        [1.0, -0.8, -0.4],
        [1.0, -0.3, 0.7],
        [1.0, 0.1, -0.9],
        [1.0, 0.5, 0.2],
        [1.0, 0.9, -0.1],
        [1.0, 1.3, 0.8],
        [1.0, 1.7, -0.6],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
    let cfg =
        RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);
    let rho = array![0.0];
    let psi = array![0.0, 0.0];

    // Per-direction design and penalty derivatives; the directions below are built from
    // exactly these matrices, which are also what the FD reference receives.
    let x_tau_mats: Vec<Array2<f64>> = vec![
        Array2::<f64>::zeros((x.nrows(), x.ncols())),
        Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
    ];
    let s_tau_mats: Vec<Array2<f64>> = vec![
        array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
        Array2::<f64>::zeros((x.ncols(), x.ncols())),
    ];
    let hyper_dirs: Vec<_> = x_tau_mats
        .iter()
        .zip(s_tau_mats.iter())
        .map(|(x_tau, s_tau)| {
            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
                .expect("single-penalty hyper direction")
        })
        .collect();

    // theta = [rho, psi], filled slice by slice.
    let mut theta = Array1::<f64>::zeros(rho.len() + psi.len());
    theta.slice_mut(s![..rho.len()]).assign(&rho);
    theta.slice_mut(s![rho.len()..]).assign(&psi);

    let (_, _, h_full) =
        compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
            .expect("joint hyper cost+gradient+hessian");
    // Keep only the trailing block that belongs to the directions.
    let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
    assert_eq!(h_tt_analytic.nrows(), hyper_dirs.len());
    assert_eq!(h_tt_analytic.ncols(), hyper_dirs.len());

    let h_ttfd = directional_tau_hessian_fd_reference(
        &y,
        &w,
        &x,
        &s0,
        &cfg,
        &rho,
        &hyper_dirs,
        &x_tau_mats,
        &s_tau_mats,
    );

    let gap = &h_tt_analytic - &h_ttfd;
    let num = gap.iter().map(|v| v * v).sum::<f64>().sqrt();
    let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    let rel = num / den;
    assert!(
        rel < 1e-4,
        "analytic tau-tau block deviates from FD reference: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
    );
}
1939
/// Shared inputs for the Gaussian REML tests: one small regression problem plus the two
/// derivative matrices the tests use as hyper-parameter directions.
///
/// All fields are populated with fixed literals by `GaussianRemlFixture::new`.
pub(crate) struct GaussianRemlFixture {
    /// Observation vector (ten entries in `new`).
    pub(crate) y: Array1<f64>,
    /// Per-observation weights; `new` sets every entry to one.
    pub(crate) w: Array1<f64>,
    /// Design matrix (ten rows, three columns in `new`; the first column is all ones).
    pub(crate) x: Array2<f64>,
    /// Base penalty matrix (3x3 in `new`, with a zero first row and column).
    pub(crate) s0: Array2<f64>,
    /// Fit configuration; `new` builds it from `gaussian_identity_glm_spec()`.
    pub(crate) cfg: RemlConfig,
    /// Point at which the tests evaluate the outer objective (`[0.0]` in `new`).
    pub(crate) rho: Array1<f64>,
    /// Design-derivative matrix used as the design-moving direction; same shape as `x`,
    /// with a zero first column.
    pub(crate) x_tau_design: Array2<f64>,
    /// Penalty-derivative matrix used as the penalty-only direction; same shape as `s0`.
    pub(crate) s_tau_penalty: Array2<f64>,
}
1961
1962 impl GaussianRemlFixture {
1963 pub(crate) fn new() -> Self {
1964 let y = array![0.5, 1.2, -0.3, 0.8, 1.1, -0.6, 0.9, 0.1, -0.2, 0.7];
1965 let x = array![
1966 [1.0, -1.2, 0.3],
1967 [1.0, -0.8, -0.4],
1968 [1.0, -0.3, 0.7],
1969 [1.0, 0.1, -0.9],
1970 [1.0, 0.5, 0.2],
1971 [1.0, 0.9, -0.1],
1972 [1.0, 1.3, 0.8],
1973 [1.0, 1.7, -0.6],
1974 [1.0, -0.5, 0.5],
1975 [1.0, 0.3, -0.3],
1976 ];
1977 Self {
1978 w: Array1::<f64>::ones(y.len()),
1979 y,
1980 x: x.clone(),
1981 s0: array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]],
1982 cfg: RemlConfig::external(gaussian_identity_glm_spec(), 1e-14, false),
1983 rho: array![0.0],
1984 x_tau_design: array![
1985 [0.0, 1e-3, -2e-3],
1986 [0.0, -3e-3, 1e-3],
1987 [0.0, 2e-3, 0.5e-3],
1988 [0.0, -1e-3, 3e-3],
1989 [0.0, 0.5e-3, -1e-3],
1990 [0.0, 1.5e-3, 2e-3],
1991 [0.0, -2e-3, -0.5e-3],
1992 [0.0, 3e-3, 1e-3],
1993 [0.0, -0.5e-3, 2e-3],
1994 [0.0, 1e-3, -1.5e-3],
1995 ],
1996 s_tau_penalty: array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
1997 }
1998 }
1999 }
2000
/// Exposes the fixture's stored problem data through the `LogitDesignMotionFixture`
/// accessors. Every method simply borrows the field of the same name.
///
/// NOTE(review): the tests call `state()` and `fd_directional_gradient(..)` on this
/// fixture, neither of which is defined here - presumably they are provided methods of
/// the trait built on top of these accessors; confirm against the trait definition.
impl LogitDesignMotionFixture for GaussianRemlFixture {
    // Borrow of the observation vector.
    fn y(&self) -> &Array1<f64> {
        &self.y
    }
    // Borrow of the per-observation weights.
    fn w(&self) -> &Array1<f64> {
        &self.w
    }
    // Borrow of the design matrix.
    fn x(&self) -> &Array2<f64> {
        &self.x
    }
    // Borrow of the base penalty matrix.
    fn s0(&self) -> &Array2<f64> {
        &self.s0
    }
    // Borrow of the fit configuration.
    fn cfg(&self) -> &RemlConfig {
        &self.cfg
    }
    // Borrow of the evaluation point.
    fn rho(&self) -> &Array1<f64> {
        &self.rho
    }
}
2021
/// Gaussian fixture, design-moving direction: the analytic directional gradient must
/// agree with the fixture's finite-difference gradient to a relative tolerance of `1e-3`.
///
/// The direction uses the fixture's `x_tau_design` with a zero 3x3 penalty derivative.
#[test]
pub(crate) fn profiled_gaussian_design_moving_gradient_matches_fd() {
    let fixture = GaussianRemlFixture::new();
    let state = fixture.state();

    let zero_penalty_motion = Array2::<f64>::zeros((3, 3));
    let direction = DirectionalHyperParam::single_penalty(
        0,
        fixture.x_tau_design.clone(),
        zero_penalty_motion.clone(),
        None,
        None,
    )
    .expect("design-moving hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &fixture.rho, direction)
        .expect("analytic directional gradient");
    let v_taufd = fixture.fd_directional_gradient(&fixture.x_tau_design, &zero_penalty_motion);

    let abs_gap = (v_tau_analytic - v_taufd).abs();
    let v_rel = abs_gap / v_taufd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Gaussian REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
    );
}
2047
/// Gaussian fixture, penalty-only direction: the analytic directional gradient must agree
/// with the fixture's finite-difference gradient to a relative tolerance of `1e-3`.
///
/// The direction uses a zero design derivative with the fixture's `s_tau_penalty`.
#[test]
pub(crate) fn profiled_gaussian_penalty_only_gradient_matches_fd() {
    let fixture = GaussianRemlFixture::new();
    let state = fixture.state();

    let zero_design_motion = Array2::<f64>::zeros(fixture.x.raw_dim());
    let direction = DirectionalHyperParam::single_penalty(
        0,
        zero_design_motion.clone(),
        fixture.s_tau_penalty.clone(),
        None,
        None,
    )
    .expect("penalty-only hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &fixture.rho, direction)
        .expect("analytic directional gradient");
    let v_taufd = fixture.fd_directional_gradient(&zero_design_motion, &fixture.s_tau_penalty);

    let abs_gap = (v_tau_analytic - v_taufd).abs();
    let v_rel = abs_gap / v_taufd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Gaussian REML penalty-only V_tau mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
    );
}
2073
/// Gaussian fixture: the direction-direction block of the analytic joint Hessian must
/// match the finite-difference reference to a relative Frobenius error below `1e-4`.
///
/// Direction 0 is penalty-only (zero design derivative, fixture `s_tau_penalty`);
/// direction 1 is design-moving (fixture `x_tau_design`, zero penalty derivative).
#[test]
pub(crate) fn profiled_gaussian_joint_hessian_matches_fd() {
    let f = GaussianRemlFixture::new();

    // Index 0: penalty-only direction. Index 1: design-moving direction.
    let x_tau_mats = vec![
        Array2::<f64>::zeros(f.x.raw_dim()),
        f.x_tau_design.clone(),
    ];
    let s_tau_mats = vec![f.s_tau_penalty.clone(), Array2::<f64>::zeros((3, 3))];

    let hyper_dirs = vec![
        DirectionalHyperParam::single_penalty(
            0,
            x_tau_mats[0].clone(),
            s_tau_mats[0].clone(),
            None,
            None,
        )
        .expect("penalty-only direction"),
        DirectionalHyperParam::single_penalty(
            0,
            x_tau_mats[1].clone(),
            s_tau_mats[1].clone(),
            None,
            None,
        )
        .expect("design-moving direction"),
    ];

    let state = f.state();
    // theta = [rho, 0, 0]: the direction coordinates are left at zero.
    let n_rho = f.rho.len();
    let mut theta = Array1::<f64>::zeros(n_rho + hyper_dirs.len());
    theta.slice_mut(s![..n_rho]).assign(&f.rho);

    let (_, _, h_full) =
        compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
            .expect("joint cost+gradient+hessian");
    let h_tt_analytic = h_full.slice(s![n_rho.., n_rho..]).to_owned();

    let h_ttfd = directional_tau_hessian_fd_reference(
        &f.y,
        &f.w,
        &f.x,
        &f.s0,
        &f.cfg,
        &f.rho,
        &hyper_dirs,
        &x_tau_mats,
        &s_tau_mats,
    );

    let gap = &h_tt_analytic - &h_ttfd;
    let num = gap.iter().map(|v| v * v).sum::<f64>().sqrt();
    let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
    let rel = num / den;
    assert!(
        rel < 1e-4,
        "Gaussian REML tau-tau Hessian mismatch: rel={rel:.3e}, \
         analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
    );
}
2128
/// Logit model, design-moving direction: the analytic directional gradient must agree
/// with a central finite difference of the outer cost to a relative tolerance of `1e-3`.
///
/// The finite difference refits the model with the design shifted by `±h · x_tau`
/// (`h = 2e-5`) while the penalty is left unchanged.
#[test]
pub(crate) fn logit_design_moving_gradient_matches_fd() {
    let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
    let w = Array1::<f64>::ones(y.len());
    let x = array![
        [1.0, -1.2, 0.3],
        [1.0, -0.8, -0.4],
        [1.0, -0.3, 0.7],
        [1.0, 0.1, -0.9],
        [1.0, 0.5, 0.2],
        [1.0, 0.9, -0.1],
        [1.0, 1.3, 0.8],
        [1.0, 1.7, -0.6],
        [1.0, -0.5, 0.5],
        [1.0, 0.3, -0.3],
    ];
    let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
    let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
    let rho = array![0.0];
    let state = build_logit_state(&y, &w, &x, &s0, &cfg);

    // Design derivative: zero in the first column, small perturbations elsewhere.
    let x_tau = array![
        [0.0, 1e-3, -2e-3],
        [0.0, -3e-3, 1e-3],
        [0.0, 2e-3, 0.5e-3],
        [0.0, -1e-3, 3e-3],
        [0.0, 0.5e-3, -1e-3],
        [0.0, 1.5e-3, 2e-3],
        [0.0, -2e-3, -0.5e-3],
        [0.0, 3e-3, 1e-3],
        [0.0, -0.5e-3, 2e-3],
        [0.0, 1e-3, -1.5e-3],
    ];
    let s_tau = Array2::<f64>::zeros((3, 3));
    let direction =
        DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
            .expect("design-moving hyper direction");

    let v_tau_analytic = single_directional_tau_gradient(&state, &rho, direction)
        .expect("analytic directional gradient");

    // Central difference of the cost along the design direction.
    let h = 2e-5;
    let design_shift = x_tau.mapv(|v| h * v);
    let state_plus = build_logit_state(&y, &w, &(&x + &design_shift), &s0, &cfg);
    let state_minus = build_logit_state(&y, &w, &(&x - &design_shift), &s0, &cfg);
    let v_plus = state_plus.compute_cost(&rho).expect("cost+");
    let v_minus = state_minus.compute_cost(&rho).expect("cost-");
    let v_taufd = (v_plus - v_minus) / (2.0 * h);

    let abs_gap = (v_tau_analytic - v_taufd).abs();
    let v_rel = abs_gap / v_taufd.abs().max(1e-10);
    assert!(
        v_rel < 1e-3,
        "Logit REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
         analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
    );
}
2200
2201 #[test]
2202 pub(crate) fn logit_design_moving_hessian_matches_fd() {
2203 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2208 let w = Array1::<f64>::ones(y.len());
2209 let x = array![
2210 [1.0, -1.2, 0.3],
2211 [1.0, -0.8, -0.4],
2212 [1.0, -0.3, 0.7],
2213 [1.0, 0.1, -0.9],
2214 [1.0, 0.5, 0.2],
2215 [1.0, 0.9, -0.1],
2216 [1.0, 1.3, 0.8],
2217 [1.0, 1.7, -0.6],
2218 [1.0, -0.5, 0.5],
2219 [1.0, 0.3, -0.3],
2220 ];
2221 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2222 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2223 let rho = array![0.0];
2224
2225 let x_tau_0 = Array2::<f64>::zeros(x.raw_dim());
2227 let s_tau_0 = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]];
2228 let x_tau_1 = array![
2229 [0.0, 1e-3, -2e-3],
2230 [0.0, -3e-3, 1e-3],
2231 [0.0, 2e-3, 0.5e-3],
2232 [0.0, -1e-3, 3e-3],
2233 [0.0, 0.5e-3, -1e-3],
2234 [0.0, 1.5e-3, 2e-3],
2235 [0.0, -2e-3, -0.5e-3],
2236 [0.0, 3e-3, 1e-3],
2237 [0.0, -0.5e-3, 2e-3],
2238 [0.0, 1e-3, -1.5e-3],
2239 ];
2240 let s_tau_1 = Array2::<f64>::zeros((3, 3));
2241
2242 let hyper_dirs = vec![
2243 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2244 .expect("penalty-only direction"),
2245 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2246 .expect("design-moving direction"),
2247 ];
2248
2249 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2250 let mut theta = Array1::<f64>::zeros(rho.len() + hyper_dirs.len());
2251 theta.slice_mut(s![..rho.len()]).assign(&rho);
2252 let (_, _, h_full) =
2253 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2254 .expect("joint cost+gradient+hessian");
2255 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2256
2257 let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2258 let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2259 let h_ttfd = directional_tau_hessian_fd_reference(
2260 &y,
2261 &w,
2262 &x,
2263 &s0,
2264 &cfg,
2265 &rho,
2266 &hyper_dirs,
2267 &x_tau_mats,
2268 &s_tau_mats,
2269 );
2270
2271 let num = (&h_tt_analytic - &h_ttfd)
2272 .iter()
2273 .map(|v| v * v)
2274 .sum::<f64>()
2275 .sqrt();
2276 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2277 let rel = num / den;
2278 assert!(
2279 rel < 1e-4,
2280 "Logit REML design-moving tau-tau Hessian mismatch: rel={rel:.3e}, \
2281 analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2282 );
2283 }
2284
    /// Shared inputs for the binomial-logit design-motion derivative tests.
    ///
    /// `new()` fills this with a fixed 30-observation, 5-coefficient data set
    /// together with one design direction and one penalty direction.
    pub(crate) struct BinomialLogitDesignMotionFixture {
        /// Binary response vector.
        pub(crate) y: Array1<f64>,
        /// Prior weights (all ones in `new()`).
        pub(crate) w: Array1<f64>,
        /// Design matrix.
        pub(crate) x: Array2<f64>,
        /// Base penalty matrix.
        pub(crate) s0: Array2<f64>,
        /// REML configuration used to build model states.
        pub(crate) cfg: RemlConfig,
        /// Log smoothing parameter(s) at which the tests evaluate.
        pub(crate) rho: Array1<f64>,
        /// Design derivative used as the "design-moving" hyper direction.
        pub(crate) x_tau_design: Array2<f64>,
        /// Penalty derivative used as the "penalty-only" hyper direction.
        pub(crate) s_tau_penalty: Array2<f64>,
    }
2306
2307 impl BinomialLogitDesignMotionFixture {
2308 pub(crate) fn new() -> Self {
2309 let y = array![
2311 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
2312 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
2313 ];
2314 let x = array![
2316 [1.0, -1.50, 0.42, 0.88, -0.31],
2317 [1.0, -1.12, -0.65, 0.14, 1.23],
2318 [1.0, -0.80, 1.10, -0.53, 0.07],
2319 [1.0, -0.55, -0.22, 1.40, -0.90],
2320 [1.0, -0.30, 0.73, -1.05, 0.44],
2321 [1.0, -0.05, -1.33, 0.60, 0.81],
2322 [1.0, 0.18, 0.55, -0.27, -1.15],
2323 [1.0, 0.42, -0.90, 1.12, 0.33],
2324 [1.0, 0.70, 1.28, -0.78, -0.56],
2325 [1.0, 0.95, -0.18, 0.45, 1.40],
2326 [1.0, 1.20, 0.66, -1.30, -0.02],
2327 [1.0, 1.45, -1.05, 0.22, 0.68],
2328 [1.0, -1.35, 0.90, 0.55, -0.43],
2329 [1.0, -0.98, -0.40, -0.88, 1.05],
2330 [1.0, -0.62, 1.42, 0.30, -0.70],
2331 [1.0, -0.28, -0.77, -1.18, 0.52],
2332 [1.0, 0.05, 0.15, 0.95, -1.35],
2333 [1.0, 0.33, -1.20, -0.40, 0.18],
2334 [1.0, 0.60, 0.82, 1.25, -0.85],
2335 [1.0, 0.88, -0.50, -0.65, 1.10],
2336 [1.0, 1.15, 1.05, 0.10, -0.22],
2337 [1.0, -1.22, -0.95, 0.72, 0.90],
2338 [1.0, -0.75, 0.38, -1.42, 0.15],
2339 [1.0, -0.42, -1.15, 0.50, -1.08],
2340 [1.0, -0.10, 0.60, -0.15, 0.75],
2341 [1.0, 0.25, -0.28, 1.05, -0.48],
2342 [1.0, 0.52, 1.35, -0.92, 0.30],
2343 [1.0, 0.80, -0.70, 0.38, 1.20],
2344 [1.0, 1.08, 0.48, -0.60, -0.95],
2345 [1.0, 1.35, -0.55, 0.85, 0.42]
2346 ];
2347 let s0 = array![
2349 [0.0, 0.0, 0.0, 0.0, 0.0],
2350 [0.0, 1.40, 0.15, 0.05, -0.10],
2351 [0.0, 0.15, 1.10, -0.20, 0.08],
2352 [0.0, 0.05, -0.20, 0.95, 0.12],
2353 [0.0, -0.10, 0.08, 0.12, 1.25]
2354 ];
2355 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2356 let x_tau_design = array![
2359 [0.0, 1.2e-3, -0.8e-3, 0.5e-3, -1.5e-3],
2360 [0.0, -2.0e-3, 1.4e-3, -0.3e-3, 0.9e-3],
2361 [0.0, 0.6e-3, -1.1e-3, 1.8e-3, -0.4e-3],
2362 [0.0, -1.3e-3, 0.7e-3, -1.0e-3, 2.1e-3],
2363 [0.0, 0.9e-3, -0.5e-3, 0.2e-3, -0.8e-3],
2364 [0.0, -0.4e-3, 1.8e-3, -1.5e-3, 0.3e-3],
2365 [0.0, 1.5e-3, -1.3e-3, 0.8e-3, -1.1e-3],
2366 [0.0, -0.7e-3, 0.4e-3, -2.0e-3, 1.6e-3],
2367 [0.0, 2.2e-3, -0.9e-3, 1.3e-3, -0.6e-3],
2368 [0.0, -1.0e-3, 1.6e-3, -0.7e-3, 0.5e-3],
2369 [0.0, 0.3e-3, -2.1e-3, 1.1e-3, -1.8e-3],
2370 [0.0, -1.8e-3, 0.2e-3, -0.4e-3, 1.3e-3],
2371 [0.0, 1.1e-3, -1.5e-3, 2.0e-3, -0.2e-3],
2372 [0.0, -0.5e-3, 0.9e-3, -1.2e-3, 0.7e-3],
2373 [0.0, 1.7e-3, -0.3e-3, 0.6e-3, -2.0e-3],
2374 [0.0, -1.4e-3, 1.1e-3, -0.9e-3, 0.4e-3],
2375 [0.0, 0.8e-3, -1.7e-3, 1.5e-3, -0.1e-3],
2376 [0.0, -0.2e-3, 0.6e-3, -1.8e-3, 1.0e-3],
2377 [0.0, 1.4e-3, -0.4e-3, 0.3e-3, -1.3e-3],
2378 [0.0, -0.9e-3, 2.0e-3, -0.5e-3, 0.8e-3],
2379 [0.0, 0.5e-3, -1.0e-3, 1.6e-3, -0.7e-3],
2380 [0.0, -2.1e-3, 0.3e-3, -0.8e-3, 1.5e-3],
2381 [0.0, 0.7e-3, -1.8e-3, 0.9e-3, -0.3e-3],
2382 [0.0, -0.6e-3, 1.3e-3, -2.2e-3, 1.1e-3],
2383 [0.0, 1.9e-3, -0.7e-3, 0.4e-3, -0.9e-3],
2384 [0.0, -1.1e-3, 0.5e-3, -1.4e-3, 2.2e-3],
2385 [0.0, 0.4e-3, -1.6e-3, 1.2e-3, -0.5e-3],
2386 [0.0, -1.6e-3, 0.8e-3, -0.1e-3, 0.6e-3],
2387 [0.0, 1.3e-3, -2.2e-3, 0.7e-3, -1.4e-3],
2388 [0.0, -0.3e-3, 1.0e-3, -1.6e-3, 1.8e-3]
2389 ];
2390 let s_tau_penalty = array![
2392 [0.0, 0.0, 0.0, 0.0, 0.0],
2393 [0.0, 0.30, 0.05, -0.02, 0.04],
2394 [0.0, 0.05, 0.22, 0.03, -0.01],
2395 [0.0, -0.02, 0.03, 0.18, 0.06],
2396 [0.0, 0.04, -0.01, 0.06, 0.26]
2397 ];
2398 Self {
2399 w: Array1::<f64>::ones(y.len()),
2400 y,
2401 x,
2402 s0,
2403 cfg,
2404 rho: array![0.0],
2405 x_tau_design,
2406 s_tau_penalty,
2407 }
2408 }
2409 }
2410
    /// Exposes the fixture's stored fields through the `LogitDesignMotionFixture`
    /// accessor methods; each accessor simply borrows the matching field.
    impl LogitDesignMotionFixture for BinomialLogitDesignMotionFixture {
        /// Response vector.
        fn y(&self) -> &Array1<f64> {
            &self.y
        }
        /// Prior weights.
        fn w(&self) -> &Array1<f64> {
            &self.w
        }
        /// Design matrix.
        fn x(&self) -> &Array2<f64> {
            &self.x
        }
        /// Base penalty matrix.
        fn s0(&self) -> &Array2<f64> {
            &self.s0
        }
        /// REML configuration.
        fn cfg(&self) -> &RemlConfig {
            &self.cfg
        }
        /// Log smoothing parameter vector.
        fn rho(&self) -> &Array1<f64> {
            &self.rho
        }
    }
2431
2432 #[test]
2435 pub(crate) fn binomial_logit_n30_design_moving_gradient_matches_fd() {
2436 let f = BinomialLogitDesignMotionFixture::new();
2443 let state = f.state();
2444 let s_tau = Array2::<f64>::zeros((5, 5));
2445 let hyper = DirectionalHyperParam::single_penalty(
2446 0,
2447 f.x_tau_design.clone(),
2448 s_tau.clone(),
2449 None,
2450 None,
2451 )
2452 .expect("design-moving hyper direction");
2453
2454 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2455 .expect("analytic directional gradient");
2456 let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2457
2458 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2459 assert!(
2460 v_rel < 1e-3,
2461 "Binomial-logit n=30 design-moving gradient mismatch: rel={v_rel:.3e}, \
2462 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2463 );
2464 }
2465
2466 #[test]
2467 pub(crate) fn binomial_logit_n30_penalty_only_gradient_matches_fd() {
2468 let f = BinomialLogitDesignMotionFixture::new();
2473 let state = f.state();
2474 let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2475 let hyper = DirectionalHyperParam::single_penalty(
2476 0,
2477 x_tau.clone(),
2478 f.s_tau_penalty.clone(),
2479 None,
2480 None,
2481 )
2482 .expect("penalty-only hyper direction");
2483
2484 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2485 .expect("analytic directional gradient");
2486 let v_tau_fd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2487
2488 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2489 assert!(
2490 v_rel < 1e-3,
2491 "Binomial-logit n=30 penalty-only gradient mismatch: rel={v_rel:.3e}, \
2492 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2493 );
2494 }
2495
2496 #[test]
2497 pub(crate) fn binomial_logit_n30_joint_design_penalty_gradient_matches_fd() {
2498 let f = BinomialLogitDesignMotionFixture::new();
2503 let state = f.state();
2504 let hyper = DirectionalHyperParam::single_penalty(
2505 0,
2506 f.x_tau_design.clone(),
2507 f.s_tau_penalty.clone(),
2508 None,
2509 None,
2510 )
2511 .expect("joint design+penalty hyper direction");
2512
2513 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2514 .expect("analytic directional gradient");
2515 let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &f.s_tau_penalty);
2516
2517 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2518 assert!(
2519 v_rel < 1e-3,
2520 "Binomial-logit n=30 joint design+penalty gradient mismatch: rel={v_rel:.3e}, \
2521 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2522 );
2523 }
2524
2525 #[test]
2526 pub(crate) fn binomial_logit_n30_design_moving_hessian_matches_fd() {
2527 let f = BinomialLogitDesignMotionFixture::new();
2532 let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2533 let s_tau_0 = f.s_tau_penalty.clone();
2534 let x_tau_1 = f.x_tau_design.clone();
2535 let s_tau_1 = Array2::<f64>::zeros((5, 5));
2536
2537 let hyper_dirs = vec![
2538 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2539 .expect("penalty-only direction"),
2540 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2541 .expect("design-moving direction"),
2542 ];
2543
2544 let state = f.state();
2545 let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2546 theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2547 let (_, _, h_full) =
2548 compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2549 .expect("joint cost+gradient+hessian");
2550 let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2551
2552 let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2553 let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2554 let h_tt_fd = directional_tau_hessian_fd_reference(
2555 &f.y,
2556 &f.w,
2557 &f.x,
2558 &f.s0,
2559 &f.cfg,
2560 &f.rho,
2561 &hyper_dirs,
2562 &x_tau_mats,
2563 &s_tau_mats,
2564 );
2565
2566 let num = (&h_tt_analytic - &h_tt_fd)
2567 .iter()
2568 .map(|v| v * v)
2569 .sum::<f64>()
2570 .sqrt();
2571 let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2572 let rel = num / den;
2573 assert!(
2574 rel < 1e-4,
2575 "Binomial-logit n=30 tau-tau Hessian mismatch: rel={rel:.3e}, \
2576 analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2577 );
2578 }
2579
2580 #[test]
2581 pub(crate) fn binomial_logit_n30_nonzero_rho_design_moving_gradient_matches_fd() {
2582 let f = BinomialLogitDesignMotionFixture::new();
2586 let rho = array![1.5];
2587 let s_tau = Array2::<f64>::zeros((5, 5));
2588
2589 let state = f.state();
2590 let hyper = DirectionalHyperParam::single_penalty(
2591 0,
2592 f.x_tau_design.clone(),
2593 s_tau.clone(),
2594 None,
2595 None,
2596 )
2597 .expect("design-moving hyper direction");
2598
2599 let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2600 .expect("analytic directional gradient");
2601
2602 let h = 2e-5;
2604 let (state_plus, state_minus) = f.state_perturbed(&f.x_tau_design, &s_tau, h);
2605 let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2606 let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2607 let v_tau_fd = (v_plus - v_minus) / (2.0 * h);
2608
2609 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2610 assert!(
2611 v_rel < 1e-3,
2612 "Binomial-logit n=30 rho=1.5 design-moving gradient mismatch: rel={v_rel:.3e}, \
2613 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2614 );
2615 }
2616
2617 #[test]
2618 pub(crate) fn binomial_logit_n30_rank_deficient_hessian_matches_cost_fd() {
2619 let f = BinomialLogitDesignMotionFixture::new();
2654 let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2655 let s_tau_0 = f.s_tau_penalty.clone();
2656 let x_tau_1 = f.x_tau_design.clone();
2657 let s_tau_1 = Array2::<f64>::zeros((5, 5));
2658
2659 let hyper_dirs = vec![
2660 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2661 .expect("penalty-only direction"),
2662 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2663 .expect("design-moving direction"),
2664 ];
2665
2666 let state = f.state();
2668 let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2669 theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2670 let (_, _, h_full) =
2671 compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2672 .expect("joint cost+gradient+hessian");
2673 let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2674
2675 const TARGET_PHYSICAL_STEP: f64 = 1e-5;
2679 let x_tau_mats = [&x_tau_0, &x_tau_1];
2680 let s_tau_mats = [&s_tau_0, &s_tau_1];
2681 let steps: [f64; 2] = {
2682 let mut steps = [0.0; 2];
2683 for (j, step) in steps.iter_mut().enumerate() {
2684 let scale = x_tau_mats[j]
2685 .iter()
2686 .chain(s_tau_mats[j].iter())
2687 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2688 *step = if scale > 0.0 {
2689 TARGET_PHYSICAL_STEP / scale
2690 } else {
2691 TARGET_PHYSICAL_STEP
2692 };
2693 }
2694 steps
2695 };
2696
2697 let eval_cost = |a: f64, b: f64| -> f64 {
2699 let x_eval = &f.x
2700 + &x_tau_mats[0].mapv(|v| a * steps[0] * v)
2701 + &x_tau_mats[1].mapv(|v| b * steps[1] * v);
2702 let s_eval = &f.s0
2703 + &s_tau_mats[0].mapv(|v| a * steps[0] * v)
2704 + &s_tau_mats[1].mapv(|v| b * steps[1] * v);
2705 let st = build_logit_state(&f.y, &f.w, &x_eval, &s_eval, &f.cfg);
2706 st.compute_cost(&f.rho).expect("cost eval")
2707 };
2708
2709 let v_00 = eval_cost(0.0, 0.0);
2710 let v_p0 = eval_cost(1.0, 0.0);
2711 let v_m0 = eval_cost(-1.0, 0.0);
2712 let v_0p = eval_cost(0.0, 1.0);
2713 let v_0m = eval_cost(0.0, -1.0);
2714 let v_pp = eval_cost(1.0, 1.0);
2715 let v_pm = eval_cost(1.0, -1.0);
2716 let v_mp = eval_cost(-1.0, 1.0);
2717 let v_mm = eval_cost(-1.0, -1.0);
2718
2719 let h00_fd = (v_p0 - 2.0 * v_00 + v_m0) / (steps[0] * steps[0]);
2720 let h11_fd = (v_0p - 2.0 * v_00 + v_0m) / (steps[1] * steps[1]);
2721 let h01_fd = (v_pp - v_pm - v_mp + v_mm) / (4.0 * steps[0] * steps[1]);
2722
2723 let h_tt_fd = array![[h00_fd, h01_fd], [h01_fd, h11_fd]];
2724
2725 let num = (&h_tt_analytic - &h_tt_fd)
2726 .iter()
2727 .map(|v| v * v)
2728 .sum::<f64>()
2729 .sqrt();
2730 let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2731 let rel = num / den;
2732
2733 assert!(
2734 rel < 3e-3,
2735 "Binomial-logit n=30 rank-deficient Hessian vs cost-FD mismatch: rel={rel:.3e}, \
2736 analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2737 );
2738 }
2739}
2740
/// Selects which geometry the REML code works in.
// NOTE(review): variant meanings are inferred from their names only; confirm against the
// code that matches on this enum.
#[derive(Clone, Copy, Debug)]
pub(crate) enum RemlGeometry {
    /// Dense spectral geometry (presumably eigen-decomposition based — confirm).
    DenseSpectral,
    /// Sparse exact SPD geometry (presumably a sparse factorization path — confirm).
    SparseExactSpd,
}
2746
/// Implemented by penalized-geometry types so they can report which backend they are.
trait PenalizedGeometry {
    /// Returns the `GeometryBackendKind` tag for this geometry.
    fn backend_kind(&self) -> GeometryBackendKind;
}
2750
/// Storage representations for a hyper-parameter derivative matrix.
///
/// Every payload type is dispatched uniformly through `storage_dispatch!`.
#[derive(Clone)]
pub(crate) enum DerivativeMatrixStorage {
    /// Fully materialized dense matrix.
    Dense(Array2<f64>),
    /// All-zero matrix represented only by its shape.
    Zero(ZeroDerivativeMatrix),
    /// Dense local block occupying a column range of a wider matrix.
    Embedded(EmbeddedDerivativeMatrix),
    /// Operator-backed derivative of an implicit design (materialized on demand).
    Implicit(ImplicitDerivativeOp),
    /// Operator-backed latent-coordinate derivative (materialized on demand).
    LatentCoord(LatentCoordDerivativeOp),
}
2759
/// Common interface over the payload types of `DerivativeMatrixStorage`.
///
/// A backend can be queried as a design derivative (`design_*`) or as a penalty
/// derivative (`penalty_*`). The "transformed" methods take a matrix `qs` and an
/// optional free basis and work in the coordinates those define.
trait DerivativeStorageBackend {
    /// Bytes of dense data currently held (the dense backend reports
    /// `len * size_of::<f64>()`, the zero backend reports 0).
    fn resident_byte_count(&self) -> usize;
    /// Row count when viewed as a design derivative.
    fn design_nrows(&self) -> usize;
    /// Column count when viewed as a design derivative.
    fn design_ncols(&self) -> usize;
    /// Side length when viewed as a penalty derivative (presumably square — confirm).
    fn penalty_dim(&self) -> usize;
    /// Whether the backend is operator-backed rather than stored explicitly.
    fn uses_implicit_storage(&self) -> bool;
    /// Whether any entry is non-zero.
    fn any_nonzero(&self) -> bool;
    /// Returns an owned dense copy of the full matrix.
    fn materialize(&self) -> Array2<f64>;
    /// For implicit backends, the underlying operator together with a `usize`
    /// (presumably the first-derivative axis — confirm); `None` otherwise.
    fn implicit_first_axis_info(
        &self,
    ) -> Option<(
        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
        usize,
    )>;
    /// Optional hint about the number of implicit axes; `None` for explicit backends.
    fn implicit_axis_count_hint(&self) -> Option<usize>;
    /// Design derivative times `u` in the original (untransformed) coordinates.
    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;
    /// Transposed design derivative times `v` in the original coordinates.
    fn design_transpose_mul_original(
        &self,
        v: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError>;
    /// Design derivative expressed in the coordinates given by `qs` and, when
    /// present, the free basis.
    fn design_transformed(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
    ) -> Result<Array2<f64>, EstimationError>;
    /// Transformed design derivative times `u`.
    ///
    /// The default materializes `design_transformed` and multiplies; backends
    /// may override it to avoid building the dense matrix.
    fn design_transformed_forward_mul(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
        u: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        Ok(self.design_transformed(qs, free_basis_opt)?.dot(u))
    }
    /// Transposed transformed design derivative times `v`.
    ///
    /// The default materializes `design_transformed` and multiplies by its
    /// transpose; backends may override it to avoid building the dense matrix.
    fn design_transformed_transpose_mul(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
        v: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        Ok(self.design_transformed(qs, free_basis_opt)?.t().dot(v))
    }
    /// Penalty derivative expressed in the coordinates given by `qs` and, when
    /// present, the free basis.
    fn penalty_transformed(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
    ) -> Result<Array2<f64>, EstimationError>;
    /// Adds `amp` times the penalty derivative into `target` in place.
    fn penalty_scaled_add_to(
        &self,
        target: &mut Array2<f64>,
        amp: f64,
    ) -> Result<(), EstimationError>;
}
2828
/// Dispatches over every `DerivativeMatrixStorage` variant.
///
/// `storage_dispatch!(scrutinee, backend => body)` expands to a `match` that
/// binds each variant's payload to `backend` and evaluates the same `body` in
/// every arm, so `body` must type-check for all payload types.
macro_rules! storage_dispatch {
    ($scrutinee:expr, $backend:ident => $body:expr) => {
        // The match is exhaustive: adding a variant forces an update here.
        match $scrutinee {
            DerivativeMatrixStorage::Dense($backend) => $body,
            DerivativeMatrixStorage::Zero($backend) => $body,
            DerivativeMatrixStorage::Embedded($backend) => $body,
            DerivativeMatrixStorage::Implicit($backend) => $body,
            DerivativeMatrixStorage::LatentCoord($backend) => $body,
        }
    };
}
2844
/// An all-zero derivative matrix that stores only its shape.
#[derive(Clone)]
pub(crate) struct ZeroDerivativeMatrix {
    /// Number of rows of the (virtual) zero matrix.
    rows: usize,
    /// Number of columns of the (virtual) zero matrix.
    cols: usize,
}

impl ZeroDerivativeMatrix {
    /// Creates a zero matrix descriptor with the given shape; nothing is allocated.
    pub(crate) fn new(rows: usize, cols: usize) -> Self {
        Self { rows, cols }
    }
}
2856
/// Identifies which derivative of an implicit design an `ImplicitDerivativeOp` represents.
#[derive(Clone, Copy, Debug)]
pub enum ImplicitDerivLevel {
    /// First derivative along the given axis.
    First(usize),
    /// Second derivative taken twice along the given axis.
    SecondDiag(usize),
    /// Mixed second derivative along the two given axes.
    SecondCross(usize, usize),
}
2867
/// Operator-backed derivative of an implicit design block, embedded into a
/// wider coefficient space.
#[derive(Clone)]
pub(crate) struct ImplicitDerivativeOp {
    /// Shared operator that evaluates the local derivative block.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
    /// Which derivative (first / second diagonal / second cross) this op represents.
    pub(crate) level: ImplicitDerivLevel,
    /// Columns of the full matrix occupied by the local block.
    pub(crate) global_range: Range<usize>,
    /// Total number of columns of the full matrix.
    pub(crate) total_dim: usize,
    /// Cache for the dense embedded matrix, filled on first use by `materialize_dense`.
    /// It sits behind an `Arc`, so clones of this op share the same cache.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
2887
/// Operator-backed latent-coordinate design derivative for one flattened axis,
/// embedded into a wider coefficient space.
#[derive(Clone)]
pub(crate) struct LatentCoordDerivativeOp {
    /// Shared operator that evaluates the local derivative block.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
    /// Flattened axis index passed to the operator's `*_axis` methods.
    pub(crate) flat_axis: usize,
    /// Columns of the full matrix occupied by the local block.
    pub(crate) global_range: Range<usize>,
    /// Total number of columns of the full matrix.
    pub(crate) total_dim: usize,
    /// Cache for the dense embedded matrix, filled on first use by `materialize_dense`.
    /// It sits behind an `Arc`, so clones of this op share the same cache.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
2896
2897impl LatentCoordDerivativeOp {
2898 pub(crate) fn materialize_local(&self) -> Array2<f64> {
2899 self.operator.materialize_axis(self.flat_axis).expect(
2900 "radial scalar evaluation failed during latent-coordinate derivative materialization",
2901 )
2902 }
2903
2904 pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
2905 self.cached_dense.get_or_compute(|| {
2906 let local = self.materialize_local();
2907 let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
2908 out.slice_mut(s![.., self.global_range.clone()])
2909 .assign(&local);
2910 out
2911 })
2912 }
2913
2914 pub(crate) fn nrows(&self) -> usize {
2915 self.operator.n_data()
2916 }
2917
2918 pub(crate) fn ncols(&self) -> usize {
2919 self.total_dim
2920 }
2921
2922 pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
2923 let local = self
2924 .operator
2925 .transpose_mul_axis(self.flat_axis, &v.view())
2926 .expect(
2927 "radial scalar evaluation failed during latent-coordinate derivative transpose_mul",
2928 );
2929 let mut out = Array1::<f64>::zeros(self.total_dim);
2930 out.slice_mut(s![self.global_range.clone()]).assign(&local);
2931 out
2932 }
2933
2934 pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
2935 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
2936 self.operator
2937 .forward_mul_axis(self.flat_axis, &u_local.view())
2938 .expect(
2939 "radial scalar evaluation failed during latent-coordinate derivative forward_mul",
2940 )
2941 }
2942}
2943
2944impl ImplicitDerivativeOp {
2945 pub(crate) fn materialize_local(&self) -> Array2<f64> {
2946 match self.level {
2947 ImplicitDerivLevel::First(axis) => self.operator.materialize_first(axis).expect(
2948 "radial scalar evaluation failed during implicit derivative materialization",
2949 ),
2950 ImplicitDerivLevel::SecondDiag(axis) => {
2951 self.operator.materialize_second_diag(axis).expect(
2952 "radial scalar evaluation failed during implicit derivative materialization",
2953 )
2954 }
2955 ImplicitDerivLevel::SecondCross(d, e) => {
2956 self.operator.materialize_second_cross(d, e).expect(
2957 "radial scalar evaluation failed during implicit derivative materialization",
2958 )
2959 }
2960 }
2961 }
2962
2963 pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
2964 self.cached_dense.get_or_compute(|| {
2965 let local = self.materialize_local();
2966 let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
2967 out.slice_mut(s![.., self.global_range.clone()])
2968 .assign(&local);
2969 out
2970 })
2971 }
2972
2973 pub(crate) fn nrows(&self) -> usize {
2974 self.operator.n_data()
2975 }
2976
2977 pub(crate) fn ncols(&self) -> usize {
2978 self.total_dim
2979 }
2980
2981 pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
2982 let local = match self.level {
2983 ImplicitDerivLevel::First(axis) => self
2984 .operator
2985 .transpose_mul(axis, &v.view())
2986 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
2987 ImplicitDerivLevel::SecondDiag(axis) => self
2988 .operator
2989 .transpose_mul_second_diag(axis, &v.view())
2990 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
2991 ImplicitDerivLevel::SecondCross(d, e) => self
2992 .operator
2993 .transpose_mul_second_cross(d, e, &v.view())
2994 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
2995 };
2996 let mut out = Array1::<f64>::zeros(self.total_dim);
2997 out.slice_mut(s![self.global_range.clone()]).assign(&local);
2998 out
2999 }
3000
3001 pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3002 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3003 match self.level {
3004 ImplicitDerivLevel::First(axis) => self
3005 .operator
3006 .forward_mul(axis, &u_local.view())
3007 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3008 ImplicitDerivLevel::SecondDiag(axis) => self
3009 .operator
3010 .forward_mul_second_diag(axis, &u_local.view())
3011 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3012 ImplicitDerivLevel::SecondCross(d, e) => self
3013 .operator
3014 .forward_mul_second_cross(d, e, &u_local.view())
3015 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3016 }
3017 }
3018}
3019
/// A dense local block that occupies a column range of a wider, otherwise-zero matrix.
#[derive(Clone)]
pub(crate) struct EmbeddedDerivativeMatrix {
    /// The dense local block.
    pub(crate) local: Array2<f64>,
    /// Columns of the full matrix that `local` occupies.
    pub(crate) global_range: Range<usize>,
    /// Total number of columns of the full matrix.
    pub(crate) total_dim: usize,
}

impl EmbeddedDerivativeMatrix {
    /// Wraps `local` as the block at columns `global_range` of a matrix with
    /// `total_dim` columns. No validation is performed here.
    pub(crate) fn new(local: Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
        Self {
            local,
            global_range,
            total_dim,
        }
    }
}
3036
3037impl DerivativeStorageBackend for Array2<f64> {
3038 fn resident_byte_count(&self) -> usize {
3039 self.len().saturating_mul(std::mem::size_of::<f64>())
3040 }
3041 fn design_nrows(&self) -> usize {
3042 Array2::nrows(self)
3043 }
3044 fn design_ncols(&self) -> usize {
3045 Array2::ncols(self)
3046 }
3047 fn penalty_dim(&self) -> usize {
3048 Array2::nrows(self)
3049 }
3050 fn uses_implicit_storage(&self) -> bool {
3051 false
3052 }
3053 fn any_nonzero(&self) -> bool {
3054 self.iter().any(|v| *v != 0.0)
3055 }
3056 fn materialize(&self) -> Array2<f64> {
3057 self.clone()
3058 }
3059 fn implicit_first_axis_info(
3060 &self,
3061 ) -> Option<(
3062 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3063 usize,
3064 )> {
3065 None
3066 }
3067 fn implicit_axis_count_hint(&self) -> Option<usize> {
3068 None
3069 }
3070
3071 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3072 if Array2::ncols(self) != u.len() {
3073 crate::bail_invalid_estim!(
3074 "dense hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3075 Array2::nrows(self),
3076 Array2::ncols(self),
3077 u.len()
3078 );
3079 }
3080 Ok(self.dot(u))
3081 }
3082
3083 fn design_transpose_mul_original(
3084 &self,
3085 v: &Array1<f64>,
3086 ) -> Result<Array1<f64>, EstimationError> {
3087 if Array2::nrows(self) != v.len() {
3088 crate::bail_invalid_estim!(
3089 "dense hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3090 Array2::nrows(self),
3091 Array2::ncols(self),
3092 v.len()
3093 );
3094 }
3095 Ok(self.t().dot(v))
3096 }
3097
3098 fn design_transformed(
3099 &self,
3100 qs: &Array2<f64>,
3101 free_basis_opt: Option<&Array2<f64>>,
3102 ) -> Result<Array2<f64>, EstimationError> {
3103 Ok(gam_linalg::matrix::DenseRightProductView::new(self)
3104 .with_factor(qs)
3105 .with_optional_factor(free_basis_opt)
3106 .materialize())
3107 }
3108
3109 fn penalty_transformed(
3110 &self,
3111 qs: &Array2<f64>,
3112 free_basis_opt: Option<&Array2<f64>>,
3113 ) -> Result<Array2<f64>, EstimationError> {
3114 let mut transformed = qs.t().dot(self).dot(qs);
3115 if let Some(z) = free_basis_opt {
3116 transformed = z.t().dot(&transformed).dot(z);
3117 }
3118 Ok(transformed)
3119 }
3120
3121 fn penalty_scaled_add_to(
3122 &self,
3123 target: &mut Array2<f64>,
3124 amp: f64,
3125 ) -> Result<(), EstimationError> {
3126 if target.raw_dim() != self.raw_dim() {
3127 crate::bail_invalid_estim!(
3128 "dense hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3129 target.nrows(),
3130 target.ncols(),
3131 Array2::nrows(self),
3132 Array2::ncols(self)
3133 );
3134 }
3135 target.scaled_add(amp, self);
3136 Ok(())
3137 }
3138}
3139
3140impl DerivativeStorageBackend for ZeroDerivativeMatrix {
3141 fn resident_byte_count(&self) -> usize {
3142 0
3143 }
3144 fn design_nrows(&self) -> usize {
3145 self.rows
3146 }
3147 fn design_ncols(&self) -> usize {
3148 self.cols
3149 }
3150 fn penalty_dim(&self) -> usize {
3151 self.cols
3152 }
3153 fn uses_implicit_storage(&self) -> bool {
3154 false
3155 }
3156 fn any_nonzero(&self) -> bool {
3157 false
3158 }
3159 fn materialize(&self) -> Array2<f64> {
3160 Array2::<f64>::zeros((self.rows, self.cols))
3161 }
3162 fn implicit_first_axis_info(
3163 &self,
3164 ) -> Option<(
3165 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3166 usize,
3167 )> {
3168 None
3169 }
3170 fn implicit_axis_count_hint(&self) -> Option<usize> {
3171 None
3172 }
3173
3174 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3175 if self.cols != u.len() {
3176 crate::bail_invalid_estim!(
3177 "zero hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3178 self.rows,
3179 self.cols,
3180 u.len()
3181 );
3182 }
3183 Ok(Array1::<f64>::zeros(self.rows))
3184 }
3185
3186 fn design_transpose_mul_original(
3187 &self,
3188 v: &Array1<f64>,
3189 ) -> Result<Array1<f64>, EstimationError> {
3190 if self.rows != v.len() {
3191 crate::bail_invalid_estim!(
3192 "zero hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3193 self.rows,
3194 self.cols,
3195 v.len()
3196 );
3197 }
3198 Ok(Array1::<f64>::zeros(self.cols))
3199 }
3200
3201 fn design_transformed(
3202 &self,
3203 qs: &Array2<f64>,
3204 free_basis_opt: Option<&Array2<f64>>,
3205 ) -> Result<Array2<f64>, EstimationError> {
3206 if self.cols != qs.nrows() {
3207 crate::bail_invalid_estim!(
3208 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3209 self.cols,
3210 qs.nrows()
3211 );
3212 }
3213 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3214 Ok(Array2::<f64>::zeros((self.rows, cols)))
3215 }
3216
3217 fn design_transformed_forward_mul(
3218 &self,
3219 qs: &Array2<f64>,
3220 free_basis_opt: Option<&Array2<f64>>,
3221 u: &Array1<f64>,
3222 ) -> Result<Array1<f64>, EstimationError> {
3223 if self.cols != qs.nrows() {
3224 crate::bail_invalid_estim!(
3225 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3226 self.cols,
3227 qs.nrows()
3228 );
3229 }
3230 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3231 if u.len() != cols {
3232 crate::bail_invalid_estim!(
3233 "zero design derivative transformed forward width mismatch: expected {}, vector={}",
3234 cols,
3235 u.len()
3236 );
3237 }
3238 Ok(Array1::<f64>::zeros(self.rows))
3239 }
3240
3241 fn design_transformed_transpose_mul(
3242 &self,
3243 qs: &Array2<f64>,
3244 free_basis_opt: Option<&Array2<f64>>,
3245 v: &Array1<f64>,
3246 ) -> Result<Array1<f64>, EstimationError> {
3247 if self.rows != v.len() {
3248 crate::bail_invalid_estim!(
3249 "zero design derivative transpose height mismatch: matrix rows={}, vector={}",
3250 self.rows,
3251 v.len()
3252 );
3253 }
3254 if self.cols != qs.nrows() {
3255 crate::bail_invalid_estim!(
3256 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3257 self.cols,
3258 qs.nrows()
3259 );
3260 }
3261 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3262 Ok(Array1::<f64>::zeros(cols))
3263 }
3264
3265 fn penalty_transformed(
3266 &self,
3267 qs: &Array2<f64>,
3268 free_basis_opt: Option<&Array2<f64>>,
3269 ) -> Result<Array2<f64>, EstimationError> {
3270 if self.cols != qs.nrows() {
3271 crate::bail_invalid_estim!(
3272 "zero penalty derivative width mismatch: total_dim={}, qs rows={}",
3273 self.cols,
3274 qs.nrows()
3275 );
3276 }
3277 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3278 Ok(Array2::<f64>::zeros((cols, cols)))
3279 }
3280
3281 fn penalty_scaled_add_to(
3282 &self,
3283 target: &mut Array2<f64>,
3284 amp: f64,
3285 ) -> Result<(), EstimationError> {
3286 if !amp.is_finite() {
3290 crate::bail_invalid_estim!(
3291 "zero hyper penalty derivative received non-finite amp={amp}"
3292 );
3293 }
3294 if target.nrows() != self.cols || target.ncols() != self.cols {
3295 crate::bail_invalid_estim!(
3296 "zero hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3297 target.nrows(),
3298 target.ncols(),
3299 self.cols,
3300 self.cols
3301 );
3302 }
3303 Ok(())
3304 }
3305}
3306
3307impl DerivativeStorageBackend for EmbeddedDerivativeMatrix {
3308 fn resident_byte_count(&self) -> usize {
3309 self.local.len().saturating_mul(std::mem::size_of::<f64>())
3310 }
3311 fn design_nrows(&self) -> usize {
3312 self.local.nrows()
3313 }
3314 fn design_ncols(&self) -> usize {
3315 self.total_dim
3316 }
3317 fn penalty_dim(&self) -> usize {
3318 self.total_dim
3319 }
3320 fn uses_implicit_storage(&self) -> bool {
3321 false
3322 }
3323 fn any_nonzero(&self) -> bool {
3324 self.local.iter().any(|v| *v != 0.0)
3325 }
3326 fn materialize(&self) -> Array2<f64> {
3327 let mut dense = Array2::<f64>::zeros((self.local.nrows(), self.total_dim));
3328 dense
3329 .slice_mut(s![.., self.global_range.clone()])
3330 .assign(&self.local);
3331 dense
3332 }
3333 fn implicit_first_axis_info(
3334 &self,
3335 ) -> Option<(
3336 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3337 usize,
3338 )> {
3339 None
3340 }
3341 fn implicit_axis_count_hint(&self) -> Option<usize> {
3342 None
3343 }
3344
3345 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3346 if self.total_dim != u.len() {
3347 crate::bail_invalid_estim!(
3348 "embedded hyper design derivative forward_mul_original width mismatch: total_dim={}, vector={}",
3349 self.total_dim,
3350 u.len()
3351 );
3352 }
3353 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3354 Ok(self.local.dot(&u_local))
3355 }
3356
3357 fn design_transpose_mul_original(
3358 &self,
3359 v: &Array1<f64>,
3360 ) -> Result<Array1<f64>, EstimationError> {
3361 if self.local.nrows() != v.len() {
3362 crate::bail_invalid_estim!(
3363 "embedded hyper design derivative transpose_mul_original height mismatch: local_rows={}, vector={}",
3364 self.local.nrows(),
3365 v.len()
3366 );
3367 }
3368 let mut out = Array1::<f64>::zeros(self.total_dim);
3369 let pulled = self.local.t().dot(v);
3370 out.slice_mut(s![self.global_range.clone()]).assign(&pulled);
3371 Ok(out)
3372 }
3373
3374 fn design_transformed(
3375 &self,
3376 qs: &Array2<f64>,
3377 free_basis_opt: Option<&Array2<f64>>,
3378 ) -> Result<Array2<f64>, EstimationError> {
3379 if self.total_dim != qs.nrows() {
3380 crate::bail_invalid_estim!(
3381 "embedded design derivative width mismatch: total_cols={}, qs rows={}",
3382 self.total_dim,
3383 qs.nrows()
3384 );
3385 }
3386 let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3387 let mut transformed = self.local.dot(&qs_local);
3388 if let Some(z) = free_basis_opt {
3389 transformed = transformed.dot(z);
3390 }
3391 Ok(transformed)
3392 }
3393
3394 fn penalty_transformed(
3395 &self,
3396 qs: &Array2<f64>,
3397 free_basis_opt: Option<&Array2<f64>>,
3398 ) -> Result<Array2<f64>, EstimationError> {
3399 if self.total_dim != qs.nrows() {
3400 crate::bail_invalid_estim!(
3401 "embedded penalty derivative width mismatch: total_dim={}, qs rows={}",
3402 self.total_dim,
3403 qs.nrows()
3404 );
3405 }
3406 let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3407 let mut transformed = qs_local.t().dot(&self.local).dot(&qs_local);
3408 if let Some(z) = free_basis_opt {
3409 transformed = z.t().dot(&transformed).dot(z);
3410 }
3411 Ok(transformed)
3412 }
3413
3414 fn penalty_scaled_add_to(
3415 &self,
3416 target: &mut Array2<f64>,
3417 amp: f64,
3418 ) -> Result<(), EstimationError> {
3419 if target.nrows() != self.total_dim || target.ncols() != self.total_dim {
3420 crate::bail_invalid_estim!(
3421 "embedded hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3422 target.nrows(),
3423 target.ncols(),
3424 self.total_dim,
3425 self.total_dim
3426 );
3427 }
3428 target
3429 .slice_mut(s![self.global_range.clone(), self.global_range.clone()])
3430 .scaled_add(amp, &self.local);
3431 Ok(())
3432 }
3433}
3434
3435impl DerivativeStorageBackend for ImplicitDerivativeOp {
3436 fn resident_byte_count(&self) -> usize {
3437 0
3438 }
3439 fn design_nrows(&self) -> usize {
3440 self.nrows()
3441 }
3442 fn design_ncols(&self) -> usize {
3443 self.ncols()
3444 }
3445 fn penalty_dim(&self) -> usize {
3446 self.nrows()
3447 }
3448 fn uses_implicit_storage(&self) -> bool {
3449 true
3450 }
3451 fn any_nonzero(&self) -> bool {
3452 true
3453 }
3454 fn materialize(&self) -> Array2<f64> {
3455 self.materialize_dense().clone()
3456 }
3457 fn implicit_first_axis_info(
3458 &self,
3459 ) -> Option<(
3460 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3461 usize,
3462 )> {
3463 match self.level {
3464 ImplicitDerivLevel::First(axis) => Some((self.operator.clone(), axis)),
3465 _ => None,
3466 }
3467 }
3468 fn implicit_axis_count_hint(&self) -> Option<usize> {
3469 Some(self.operator.n_axes())
3470 }
3471
3472 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3473 if self.ncols() != u.len() {
3474 crate::bail_invalid_estim!(
3475 "implicit hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3476 self.ncols(),
3477 u.len()
3478 );
3479 }
3480 Ok(self.forward_mul(u))
3481 }
3482
3483 fn design_transpose_mul_original(
3484 &self,
3485 v: &Array1<f64>,
3486 ) -> Result<Array1<f64>, EstimationError> {
3487 if self.nrows() != v.len() {
3488 crate::bail_invalid_estim!(
3489 "implicit hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3490 self.nrows(),
3491 v.len()
3492 );
3493 }
3494 Ok(self.transpose_mul(v))
3495 }
3496
3497 fn design_transformed(
3498 &self,
3499 qs: &Array2<f64>,
3500 free_basis_opt: Option<&Array2<f64>>,
3501 ) -> Result<Array2<f64>, EstimationError> {
3502 let dense = self.materialize_dense();
3503 Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3504 .with_factor(qs)
3505 .with_optional_factor(free_basis_opt)
3506 .materialize())
3507 }
3508
3509 fn design_transformed_forward_mul(
3510 &self,
3511 qs: &Array2<f64>,
3512 free_basis_opt: Option<&Array2<f64>>,
3513 u: &Array1<f64>,
3514 ) -> Result<Array1<f64>, EstimationError> {
3515 let mut right = if let Some(z) = free_basis_opt {
3516 z.dot(u)
3517 } else {
3518 u.clone()
3519 };
3520 right = qs.dot(&right);
3521 Ok(self.forward_mul(&right))
3522 }
3523
3524 fn design_transformed_transpose_mul(
3525 &self,
3526 qs: &Array2<f64>,
3527 free_basis_opt: Option<&Array2<f64>>,
3528 v: &Array1<f64>,
3529 ) -> Result<Array1<f64>, EstimationError> {
3530 let mut pulled = qs.t().dot(&self.transpose_mul(v));
3531 if let Some(z) = free_basis_opt {
3532 pulled = z.t().dot(&pulled);
3533 }
3534 Ok(pulled)
3535 }
3536
3537 fn penalty_transformed(
3538 &self,
3539 qs: &Array2<f64>,
3540 free_basis_opt: Option<&Array2<f64>>,
3541 ) -> Result<Array2<f64>, EstimationError> {
3542 let dense = self.materialize_dense();
3543 let mut transformed = qs.t().dot(dense).dot(qs);
3544 if let Some(z) = free_basis_opt {
3545 transformed = z.t().dot(&transformed).dot(z);
3546 }
3547 Ok(transformed)
3548 }
3549
3550 fn penalty_scaled_add_to(
3551 &self,
3552 target: &mut Array2<f64>,
3553 amp: f64,
3554 ) -> Result<(), EstimationError> {
3555 let dense = self.materialize_dense();
3556 if target.raw_dim() != dense.raw_dim() {
3557 crate::bail_invalid_estim!(
3558 "implicit hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3559 target.nrows(),
3560 target.ncols(),
3561 dense.nrows(),
3562 dense.ncols()
3563 );
3564 }
3565 target.scaled_add(amp, dense);
3566 Ok(())
3567 }
3568}
3569
3570impl DerivativeStorageBackend for LatentCoordDerivativeOp {
3571 fn resident_byte_count(&self) -> usize {
3572 0
3573 }
3574 fn design_nrows(&self) -> usize {
3575 self.nrows()
3576 }
3577 fn design_ncols(&self) -> usize {
3578 self.ncols()
3579 }
3580 fn penalty_dim(&self) -> usize {
3581 self.nrows()
3582 }
3583 fn uses_implicit_storage(&self) -> bool {
3584 true
3585 }
3586 fn any_nonzero(&self) -> bool {
3587 true
3588 }
3589 fn materialize(&self) -> Array2<f64> {
3590 self.materialize_dense().clone()
3591 }
3592 fn implicit_first_axis_info(
3593 &self,
3594 ) -> Option<(
3595 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3596 usize,
3597 )> {
3598 None
3599 }
3600 fn implicit_axis_count_hint(&self) -> Option<usize> {
3601 Some(self.operator.n_axes())
3602 }
3603
3604 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3605 if self.ncols() != u.len() {
3606 crate::bail_invalid_estim!(
3607 "latent-coordinate hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3608 self.ncols(),
3609 u.len()
3610 );
3611 }
3612 Ok(self.forward_mul(u))
3613 }
3614
3615 fn design_transpose_mul_original(
3616 &self,
3617 v: &Array1<f64>,
3618 ) -> Result<Array1<f64>, EstimationError> {
3619 if self.nrows() != v.len() {
3620 crate::bail_invalid_estim!(
3621 "latent-coordinate hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3622 self.nrows(),
3623 v.len()
3624 );
3625 }
3626 Ok(self.transpose_mul(v))
3627 }
3628
3629 fn design_transformed(
3630 &self,
3631 qs: &Array2<f64>,
3632 free_basis_opt: Option<&Array2<f64>>,
3633 ) -> Result<Array2<f64>, EstimationError> {
3634 let dense = self.materialize_dense();
3635 Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3636 .with_factor(qs)
3637 .with_optional_factor(free_basis_opt)
3638 .materialize())
3639 }
3640
3641 fn design_transformed_forward_mul(
3642 &self,
3643 qs: &Array2<f64>,
3644 free_basis_opt: Option<&Array2<f64>>,
3645 u: &Array1<f64>,
3646 ) -> Result<Array1<f64>, EstimationError> {
3647 let mut right = if let Some(z) = free_basis_opt {
3648 z.dot(u)
3649 } else {
3650 u.clone()
3651 };
3652 right = qs.dot(&right);
3653 Ok(self.forward_mul(&right))
3654 }
3655
3656 fn design_transformed_transpose_mul(
3657 &self,
3658 qs: &Array2<f64>,
3659 free_basis_opt: Option<&Array2<f64>>,
3660 v: &Array1<f64>,
3661 ) -> Result<Array1<f64>, EstimationError> {
3662 let mut pulled = qs.t().dot(&self.transpose_mul(v));
3663 if let Some(z) = free_basis_opt {
3664 pulled = z.t().dot(&pulled);
3665 }
3666 Ok(pulled)
3667 }
3668
3669 fn penalty_transformed(
3670 &self,
3671 qs: &Array2<f64>,
3672 free_basis_opt: Option<&Array2<f64>>,
3673 ) -> Result<Array2<f64>, EstimationError> {
3674 let dense = self.materialize_dense();
3675 let mut transformed = qs.t().dot(dense).dot(qs);
3676 if let Some(z) = free_basis_opt {
3677 transformed = z.t().dot(&transformed).dot(z);
3678 }
3679 Ok(transformed)
3680 }
3681
3682 fn penalty_scaled_add_to(
3683 &self,
3684 target: &mut Array2<f64>,
3685 amp: f64,
3686 ) -> Result<(), EstimationError> {
3687 let dense = self.materialize_dense();
3688 if target.raw_dim() != dense.raw_dim() {
3689 crate::bail_invalid_estim!(
3690 "latent-coordinate hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3691 target.nrows(),
3692 target.ncols(),
3693 dense.nrows(),
3694 dense.ncols()
3695 );
3696 }
3697 target.scaled_add(amp, dense);
3698 Ok(())
3699 }
3700}
3701
/// Design-side derivative matrix, held in one of the
/// `DerivativeMatrixStorage` representations (zero, dense, embedded block,
/// implicit operator, or latent-coordinate operator).
#[derive(Clone)]
pub struct HyperDesignDerivative {
    /// Backing representation; every method dispatches on this value.
    pub(crate) storage: DerivativeMatrixStorage,
}
3706
3707impl HyperDesignDerivative {
3708 pub fn zero(nrows: usize, ncols: usize) -> Self {
3709 Self {
3710 storage: DerivativeMatrixStorage::Zero(ZeroDerivativeMatrix::new(nrows, ncols)),
3711 }
3712 }
3713
3714 pub fn from_embedded(
3715 local: Array2<f64>,
3716 global_range: Range<usize>,
3717 total_cols: usize,
3718 ) -> Self {
3719 Self {
3720 storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3721 local,
3722 global_range,
3723 total_cols,
3724 )),
3725 }
3726 }
3727
3728 pub fn from_implicit(
3729 operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3730 level: ImplicitDerivLevel,
3731 global_range: Range<usize>,
3732 total_cols: usize,
3733 ) -> Self {
3734 Self {
3735 storage: DerivativeMatrixStorage::Implicit(ImplicitDerivativeOp {
3736 operator,
3737 level,
3738 global_range,
3739 total_dim: total_cols,
3740 cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3741 }),
3742 }
3743 }
3744
3745 pub fn from_latent_coord(
3746 operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
3747 flat_axis: usize,
3748 global_range: Range<usize>,
3749 total_cols: usize,
3750 ) -> Self {
3751 Self {
3752 storage: DerivativeMatrixStorage::LatentCoord(LatentCoordDerivativeOp {
3753 operator,
3754 flat_axis,
3755 global_range,
3756 total_dim: total_cols,
3757 cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3758 }),
3759 }
3760 }
3761
3762 pub(crate) fn resident_byte_count(&self) -> usize {
3763 storage_dispatch!(&self.storage, b => b.resident_byte_count())
3764 }
3765
3766 pub(crate) fn nrows(&self) -> usize {
3767 storage_dispatch!(&self.storage, b => b.design_nrows())
3768 }
3769
3770 pub(crate) fn ncols(&self) -> usize {
3771 storage_dispatch!(&self.storage, b => b.design_ncols())
3772 }
3773
3774 pub(crate) fn uses_implicit_storage(&self) -> bool {
3775 storage_dispatch!(&self.storage, b => b.uses_implicit_storage())
3776 }
3777
3778 pub(crate) fn materialize(&self) -> Array2<f64> {
3779 storage_dispatch!(&self.storage, b => b.materialize())
3780 }
3781
3782 pub(crate) fn any_nonzero(&self) -> bool {
3783 storage_dispatch!(&self.storage, b => b.any_nonzero())
3784 }
3785
3786 pub(crate) fn forward_mul_original(
3787 &self,
3788 u: &Array1<f64>,
3789 ) -> Result<Array1<f64>, EstimationError> {
3790 storage_dispatch!(&self.storage, b => b.design_forward_mul_original(u))
3791 }
3792
3793 pub(crate) fn transpose_mul_original(
3794 &self,
3795 v: &Array1<f64>,
3796 ) -> Result<Array1<f64>, EstimationError> {
3797 storage_dispatch!(&self.storage, b => b.design_transpose_mul_original(v))
3798 }
3799
3800 pub(crate) fn transformed(
3801 &self,
3802 qs: &Array2<f64>,
3803 free_basis_opt: Option<&Array2<f64>>,
3804 ) -> Result<Array2<f64>, EstimationError> {
3805 storage_dispatch!(&self.storage, b => b.design_transformed(qs, free_basis_opt))
3806 }
3807
3808 pub(crate) fn transformed_forward_mul(
3809 &self,
3810 qs: &Array2<f64>,
3811 free_basis_opt: Option<&Array2<f64>>,
3812 u: &Array1<f64>,
3813 ) -> Result<Array1<f64>, EstimationError> {
3814 storage_dispatch!(&self.storage, b => b.design_transformed_forward_mul(qs, free_basis_opt, u))
3815 }
3816
3817 pub(crate) fn transformed_transpose_mul(
3818 &self,
3819 qs: &Array2<f64>,
3820 free_basis_opt: Option<&Array2<f64>>,
3821 v: &Array1<f64>,
3822 ) -> Result<Array1<f64>, EstimationError> {
3823 storage_dispatch!(&self.storage, b => b.design_transformed_transpose_mul(qs, free_basis_opt, v))
3824 }
3825
3826 pub(crate) fn implicit_first_axis_info(
3831 &self,
3832 ) -> Option<(
3833 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3834 usize,
3835 )> {
3836 storage_dispatch!(&self.storage, b => b.implicit_first_axis_info())
3837 }
3838
3839 pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
3840 storage_dispatch!(&self.storage, b => b.implicit_axis_count_hint())
3841 }
3842}
3843
3844impl From<Array2<f64>> for HyperDesignDerivative {
3845 fn from(value: Array2<f64>) -> Self {
3846 Self {
3847 storage: DerivativeMatrixStorage::Dense(value),
3848 }
3849 }
3850}
3851
/// Penalty-side derivative matrix, held in one of the
/// `DerivativeMatrixStorage` representations. It is treated as square:
/// `ncols()` is defined as `nrows()`.
#[derive(Clone)]
pub struct HyperPenaltyDerivative {
    /// Backing representation; every method dispatches on this value.
    pub(crate) storage: DerivativeMatrixStorage,
}
3856
3857impl HyperPenaltyDerivative {
3858 pub fn from_embedded(
3859 local: Array2<f64>,
3860 global_range: Range<usize>,
3861 total_dim: usize,
3862 ) -> Self {
3863 Self {
3864 storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3865 local,
3866 global_range,
3867 total_dim,
3868 )),
3869 }
3870 }
3871
3872 pub(crate) fn resident_byte_count(&self) -> usize {
3873 storage_dispatch!(&self.storage, b => b.resident_byte_count())
3874 }
3875
3876 pub(crate) fn nrows(&self) -> usize {
3877 storage_dispatch!(&self.storage, b => b.penalty_dim())
3878 }
3879
3880 pub(crate) fn ncols(&self) -> usize {
3881 self.nrows()
3882 }
3883
3884 pub(crate) fn scaled_materialize(&self, amp: f64) -> Array2<f64> {
3885 let mut out = Array2::<f64>::zeros((self.nrows(), self.ncols()));
3886 self.scaled_add_to(&mut out, amp)
3887 .expect("scaled materialize uses matching target shape");
3888 out
3889 }
3890
3891 pub(crate) fn transformed(
3892 &self,
3893 qs: &Array2<f64>,
3894 free_basis_opt: Option<&Array2<f64>>,
3895 ) -> Result<Array2<f64>, EstimationError> {
3896 storage_dispatch!(&self.storage, b => b.penalty_transformed(qs, free_basis_opt))
3897 }
3898
3899 pub(crate) fn scaled_add_to(
3900 &self,
3901 target: &mut Array2<f64>,
3902 amp: f64,
3903 ) -> Result<(), EstimationError> {
3904 storage_dispatch!(&self.storage, b => b.penalty_scaled_add_to(target, amp))
3905 }
3906}
3907
3908impl From<Array2<f64>> for HyperPenaltyDerivative {
3909 fn from(value: Array2<f64>) -> Self {
3910 Self {
3911 storage: DerivativeMatrixStorage::Dense(value),
3912 }
3913 }
3914}
3915
/// One penalty derivative matrix tagged with the penalty it belongs to.
#[derive(Clone)]
pub struct PenaltyDerivativeComponent {
    /// Index of the penalty; `penalty_total_at` uses it to index into `rho`.
    pub penalty_index: usize,
    /// The derivative matrix for that penalty.
    pub matrix: HyperPenaltyDerivative,
}
3921
/// Derivative data attached to one hyperparameter direction: the design
/// derivative, per-penalty first derivatives, and optional second-order
/// entries indexed by a partner direction `j`.
#[derive(Clone)]
pub struct DirectionalHyperParam {
    /// First-order design derivative in the original coordinates.
    pub(crate) x_tau_original: HyperDesignDerivative,
    /// First-order penalty derivative components, at most one per penalty
    /// index when built through `new_compact`.
    pub(crate) penalty_first_components: Vec<PenaltyDerivativeComponent>,
    /// Optional second-order design derivatives, indexed by partner `j`;
    /// a `None` entry means nothing is stored for that partner.
    pub(crate) x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
    /// Optional stored second-order penalty components, indexed by partner
    /// `j`; a `None` row means nothing is stored for that partner.
    pub(crate) penaltysecond_components: Option<Vec<Option<Vec<PenaltyDerivativeComponent>>>>,
    /// Optional callback queried by `penaltysecond_components_for` when no
    /// stored row exists for the requested partner.
    pub(crate) penaltysecond_component_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
                + Send
                + Sync
                + 'static,
        >,
    >,
    /// Optional list of partner indices that `has_penaltysecond_pair_at`
    /// treats as having a second-order pair even without a stored row.
    pub(crate) penaltysecond_partner_indices: Option<std::sync::Arc<[usize]>>,
    /// Set by `new_compact` to `true` when the design derivative reports no
    /// non-zero content; `not_penalty_like` forces it back to `false`.
    pub(crate) is_penalty_like: bool,
}
3949
3950impl DirectionalHyperParam {
3951 pub(crate) fn resident_byte_count(&self) -> usize {
3952 let mut bytes = self.x_tau_original.resident_byte_count();
3953 for component in &self.penalty_first_components {
3954 bytes = bytes.saturating_add(component.matrix.resident_byte_count());
3955 }
3956 if let Some(entries) = self.x_tau_tau_original.as_ref() {
3957 for entry in entries.iter().flatten() {
3958 bytes = bytes.saturating_add(entry.resident_byte_count());
3959 }
3960 }
3961 if let Some(rows) = self.penaltysecond_components.as_ref() {
3962 for components in rows.iter().flatten() {
3963 for component in components {
3964 bytes = bytes.saturating_add(component.matrix.resident_byte_count());
3965 }
3966 }
3967 }
3968 bytes
3969 }
3970
3971 pub(crate) fn canonicalize_penalty_components(
3972 components: Vec<(usize, HyperPenaltyDerivative)>,
3973 ) -> Result<Vec<PenaltyDerivativeComponent>, EstimationError> {
3974 let mut out: Vec<PenaltyDerivativeComponent> = Vec::with_capacity(components.len());
3975 for (penalty_index, matrix) in components {
3976 if out.iter().any(|c| c.penalty_index == penalty_index) {
3977 crate::bail_invalid_estim!(
3978 "duplicate penalty derivative component for penalty {}",
3979 penalty_index
3980 );
3981 }
3982 out.push(PenaltyDerivativeComponent {
3983 penalty_index,
3984 matrix,
3985 });
3986 }
3987 Ok(out)
3988 }
3989
3990 pub fn new_compact(
3991 x_tau_original: HyperDesignDerivative,
3992 penalty_first_components: Vec<(usize, HyperPenaltyDerivative)>,
3993 x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
3994 penaltysecond_components: Option<Vec<Option<Vec<(usize, HyperPenaltyDerivative)>>>>,
3995 ) -> Result<Self, EstimationError> {
3996 let is_penalty_like = !x_tau_original.any_nonzero();
3997 let penalty_first_components =
3998 Self::canonicalize_penalty_components(penalty_first_components)?;
3999 let penaltysecond_components = match penaltysecond_components {
4000 Some(rows) => {
4001 let mut out = Vec::with_capacity(rows.len());
4002 for row in rows {
4003 out.push(match row {
4004 Some(components) => {
4005 Some(Self::canonicalize_penalty_components(components)?)
4006 }
4007 None => None,
4008 });
4009 }
4010 Some(out)
4011 }
4012 None => None,
4013 };
4014 Ok(Self {
4015 x_tau_original,
4016 penalty_first_components,
4017 x_tau_tau_original,
4018 penaltysecond_components,
4019 penaltysecond_component_provider: None,
4020 penaltysecond_partner_indices: None,
4021 is_penalty_like,
4022 })
4023 }
4024
4025 pub fn not_penalty_like(mut self) -> Self {
4028 self.is_penalty_like = false;
4029 self
4030 }
4031
4032 pub fn with_penaltysecond_component_provider(
4033 mut self,
4034 provider: std::sync::Arc<
4035 dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
4036 + Send
4037 + Sync
4038 + 'static,
4039 >,
4040 ) -> Self {
4041 self.penaltysecond_component_provider = Some(provider);
4042 self
4043 }
4044
4045 pub fn with_penaltysecond_partner_indices(mut self, partners: Vec<usize>) -> Self {
4046 self.penaltysecond_partner_indices = Some(std::sync::Arc::from(partners));
4047 self
4048 }
4049
4050 pub(crate) fn x_tau_dense(&self) -> Array2<f64> {
4051 self.x_tau_original.materialize()
4052 }
4053
4054 pub(crate) fn transformed_x_tau(
4055 &self,
4056 qs: &Array2<f64>,
4057 free_basis_opt: Option<&Array2<f64>>,
4058 ) -> Result<Array2<f64>, EstimationError> {
4059 self.x_tau_original.transformed(qs, free_basis_opt)
4060 }
4061
4062 pub(crate) fn x_tau_tau_entry_at(&self, j: usize) -> Option<HyperDesignDerivative> {
4063 self.x_tau_tau_original
4064 .as_ref()
4065 .and_then(|rows| rows.get(j))
4066 .and_then(|entry| entry.clone())
4067 }
4068
4069 pub(crate) fn has_implicit_operator(&self) -> bool {
4072 self.x_tau_original.uses_implicit_storage()
4073 }
4074
4075 pub(crate) fn has_implicit_multidim_duchon(&self) -> bool {
4076 self.implicit_first_axis_info()
4077 .is_some_and(|(op, _)| op.n_axes() > 1 && op.is_duchon_family())
4078 }
4079
4080 pub(crate) fn implicit_first_axis_info(
4082 &self,
4083 ) -> Option<(
4084 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4085 usize,
4086 )> {
4087 self.x_tau_original.implicit_first_axis_info()
4088 }
4089
4090 pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4091 self.x_tau_original.implicit_axis_count_hint()
4092 }
4093
4094 pub(crate) fn penalty_first_components(&self) -> &[PenaltyDerivativeComponent] {
4095 &self.penalty_first_components
4096 }
4097
4098 pub(crate) fn penalty_total_at(
4099 &self,
4100 rho: &Array1<f64>,
4101 p: usize,
4102 ) -> Result<Array2<f64>, EstimationError> {
4103 let mut out = Array2::<f64>::zeros((p, p));
4104 for component in &self.penalty_first_components {
4105 if component.matrix.nrows() != p || component.matrix.ncols() != p {
4106 crate::bail_invalid_estim!(
4107 "S_tau shape mismatch for penalty {}: expected {}x{}, got {}x{}",
4108 component.penalty_index,
4109 p,
4110 p,
4111 component.matrix.nrows(),
4112 component.matrix.ncols()
4113 );
4114 }
4115 if component.penalty_index >= rho.len() {
4116 crate::bail_invalid_estim!(
4117 "penalty_index {} out of bounds for rho dimension {}",
4118 component.penalty_index,
4119 rho.len()
4120 );
4121 }
4122 component
4123 .matrix
4124 .scaled_add_to(&mut out, rho[component.penalty_index].exp())?;
4125 }
4126 Ok(out)
4127 }
4128
4129 pub(crate) fn penaltysecond_components_for(
4130 &self,
4131 j: usize,
4132 ) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError> {
4133 if let Some(components) = self
4134 .penaltysecond_components
4135 .as_ref()
4136 .and_then(|rows| rows.get(j))
4137 .and_then(|row| row.clone())
4138 {
4139 return Ok(Some(components));
4140 }
4141 if let Some(provider) = self.penaltysecond_component_provider.as_ref() {
4142 return provider(j);
4143 }
4144 Ok(None)
4145 }
4146
4147 pub(crate) fn penaltysecond_componentrows(
4148 &self,
4149 ) -> Option<&[Option<Vec<PenaltyDerivativeComponent>>]> {
4150 self.penaltysecond_components.as_deref()
4151 }
4152
4153 pub(crate) fn penalty_first_component_count(&self) -> usize {
4154 self.penalty_first_components.len()
4155 }
4156
4157 pub(crate) fn has_penaltysecond_pair_at(&self, j: usize) -> bool {
4158 self.penaltysecond_components
4159 .as_ref()
4160 .and_then(|rows| rows.get(j))
4161 .is_some_and(Option::is_some)
4162 || self
4163 .penaltysecond_partner_indices
4164 .as_ref()
4165 .is_some_and(|partners| partners.contains(&j))
4166 }
4167}
4168
/// Record of a REML geometry choice together with the figures that
/// accompanied it.
// NOTE(review): the field meanings below are read off the names; the code
// that fills them in is outside this chunk — confirm.
#[derive(Clone, Debug)]
pub(crate) struct SparseRemlDecision {
    /// The geometry that was selected.
    pub(crate) geometry: RemlGeometry,
    /// Static label describing why the geometry was selected.
    pub(crate) reason: &'static str,
    /// Presumably the coefficient dimension — confirm.
    pub(crate) p: usize,
    /// Presumably the non-zero count of the design matrix — confirm.
    pub(crate) nnz_x: usize,
    /// Presumably an estimate of non-zeros in the upper triangle of the
    /// Hessian, when available — confirm.
    pub(crate) nnz_h_upper_est: Option<usize>,
    /// Presumably the density corresponding to `nnz_h_upper_est` — confirm.
    pub(crate) density_h_upper_est: Option<f64>,
}
4178
/// Bundle of shared results for the sparse-exact evaluation path.
// NOTE(review): the field meanings below are read off the names and types;
// the producing code is outside this chunk — confirm.
#[derive(Clone)]
pub(crate) struct SparseExactEvalData {
    /// Shared sparse factorization.
    pub(crate) factor: Arc<SparseExactFactor>,
    /// Shared Takahashi inverse, when one is available.
    pub(crate) takahashi: Option<Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    /// Presumably log-determinant of the penalized Hessian — confirm.
    pub(crate) logdet_h: f64,
    /// Presumably the penalty log-determinant on its positive subspace — confirm.
    pub(crate) logdet_s_pos: f64,
    /// Presumably the rank of the penalty — confirm.
    pub(crate) penalty_rank: usize,
    /// Shared vector of first-derivative values (`det1`); exact definition
    /// not visible here.
    pub(crate) det1_values: Arc<Array1<f64>>,
}
4188
/// Dense working state for the Firth-related computations.
// NOTE(review): none of the fields are written or read in this chunk; the
// notes below are guesses from names and types — confirm against the `firth`
// module before relying on them.
#[derive(Clone)]
pub struct FirthDenseOperator {
    /// Presumably the dense design matrix — confirm.
    pub(crate) x_dense: Array2<f64>,
    /// Presumably the transpose of `x_dense`, stored separately — confirm.
    pub(crate) x_dense_t: Array2<f64>,
    /// Presumably a basis for the reduced coordinates — confirm.
    pub(crate) q_basis: Array2<f64>,
    /// Presumably the design expressed in the reduced coordinates — confirm.
    pub(crate) x_reduced: Array2<f64>,
    /// Optional per-observation square-root weights — confirm.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    /// Reduced-space matrix `K`; definition not visible here.
    pub(crate) k_reduced: Array2<f64>,
    /// Diagonal of an inverse reduced metric; definition not visible here.
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    /// Presumably one half of a log-determinant — confirm.
    pub(crate) half_log_det: f64,
    /// Presumably the hat-matrix diagonal — confirm.
    pub(crate) h_diag: Array1<f64>,
    /// Weight vector and, presumably, its successive derivatives `w1`..`w4`
    /// — confirm.
    pub(crate) w: Array1<f64>,
    pub(crate) w1: Array1<f64>,
    pub(crate) w2: Array1<f64>,
    pub(crate) w3: Array1<f64>,
    pub(crate) w4: Array1<f64>,
    /// Base matrix `B`; definition not visible here.
    pub(crate) b_base: Array2<f64>,
    /// Matrix derived from `b_base`; definition not visible here.
    pub(crate) p_b_base: Array2<f64>,
}
4254
/// Precomputed quantities for one Firth direction.
// NOTE(review): field semantics are not visible in this chunk; the notes are
// guesses from the names — confirm.
#[derive(Clone)]
pub(crate) struct FirthDirection {
    /// Presumably the directional change in the linear predictor — confirm.
    pub(crate) deta: Array1<f64>,
    /// Reduced-space matrices for the direction; definitions not visible here.
    pub(crate) g_u_reduced: Array2<f64>,
    pub(crate) a_u_reduced: Array2<f64>,
    /// Presumably the directional change of the hat diagonal — confirm.
    pub(crate) dh: Array1<f64>,
    /// Direction-dependent vector; definition not visible here.
    pub(crate) b_uvec: Array1<f64>,
}
4264
/// Partial-derivative kernel pieces for a single hyperparameter `tau`.
// NOTE(review): field semantics are not visible in this chunk; the notes are
// guesses from the names — confirm.
#[derive(Clone)]
pub(crate) struct FirthTauPartialKernel {
    /// Presumably the partial derivative of the linear predictor — confirm.
    pub(super) deta_partial: Array1<f64>,
    /// Presumably `tau`-derivatives of the weight derivatives `w1`/`w2` — confirm.
    pub(crate) dotw1: Array1<f64>,
    pub(crate) dotw2: Array1<f64>,
    /// Presumably the partial derivative of the hat diagonal — confirm.
    pub(crate) dot_h_partial: Array1<f64>,
    /// Presumably the design derivative in reduced coordinates — confirm.
    pub(crate) x_tau_reduced: Array2<f64>,
    /// Matrix-valued partial derivatives; definitions not visible here.
    pub(super) dot_i_partial: Array2<f64>,
    pub(crate) dot_k_reduced: Array2<f64>,
}
4280
/// Exact first-order `tau` kernel: a vector term, a scalar term, and an
/// optional partial kernel.
// NOTE(review): field semantics are not visible in this chunk — confirm.
#[derive(Clone)]
pub(crate) struct FirthTauExactKernel {
    /// Vector-valued `tau` term; definition not visible here.
    pub(crate) gphi_tau: Array1<f64>,
    /// Scalar partial `tau` term; definition not visible here.
    pub(crate) phi_tau_partial: f64,
    /// Partial kernel pieces, when they were computed.
    pub(crate) tau_kernel: Option<FirthTauPartialKernel>,
}
4287
/// Exact second-order `tau`-`tau` kernel: a scalar term, a vector term, and
/// an optional partial kernel.
// NOTE(review): field semantics are not visible in this chunk — confirm.
#[derive(Clone)]
pub(crate) struct FirthTauTauExactKernel {
    /// Scalar second-order partial term; definition not visible here.
    pub(super) phi_tau_tau_partial: f64,
    /// Vector-valued second-order term; definition not visible here.
    pub(super) gphi_tau_tau: Array1<f64>,
    /// Partial kernel pieces, when they were computed.
    pub(super) tau_tau_kernel: Option<FirthTauTauPartialKernel>,
}
4305
/// Partial-derivative kernel pieces for a pair of hyperparameters `(i, j)`.
///
/// Fields with an `_i` / `_j` infix come in pairs, one per member of the
/// pair; the two `Option` fields hold mixed second-order terms when present.
// NOTE(review): field semantics are not visible in this chunk; the grouping
// above is read off the names — confirm.
#[derive(Clone, Default)]
pub(crate) struct FirthTauTauPartialKernel {
    pub(super) x_tau_i_reduced: Array2<f64>,
    pub(super) x_tau_j_reduced: Array2<f64>,
    pub(super) deta_i_partial: Array1<f64>,
    pub(super) deta_j_partial: Array1<f64>,
    pub(super) dot_h_i_partial: Array1<f64>,
    pub(super) dot_h_j_partial: Array1<f64>,
    pub(super) dot_k_i_reduced: Array2<f64>,
    pub(super) dot_k_j_reduced: Array2<f64>,
    pub(super) dot_i_i_partial: Array2<f64>,
    pub(super) dot_i_j_partial: Array2<f64>,
    /// Presumably the second-order design derivative in reduced coordinates,
    /// when one exists — confirm.
    pub(super) x_tau_tau_reduced: Option<Array2<f64>>,
    /// Presumably the mixed partial of the linear predictor, when one exists
    /// — confirm.
    pub(super) deta_ij_partial: Option<Array1<f64>>,
}
4333
/// Partial-derivative kernel pieces for the mixed `tau`-`beta` term.
// NOTE(review): field semantics are not visible in this chunk. From the
// names, the first five fields mirror `FirthTauPartialKernel`, the `_v`
// fields belong to a coefficient-space direction `v`, and the `d_beta_*`
// fields are coefficient derivatives of the `tau` partials — confirm.
#[derive(Clone, Default)]
pub(crate) struct FirthTauBetaPartialKernel {
    pub(super) x_tau_reduced: Array2<f64>,
    pub(super) deta_partial: Array1<f64>,
    pub(super) dot_h_partial: Array1<f64>,
    pub(super) dot_i_partial: Array2<f64>,
    pub(super) dot_k_reduced: Array2<f64>,
    pub(super) deta_v: Array1<f64>,
    pub(super) deta_tau_v: Array1<f64>,
    pub(super) a_v_reduced: Array2<f64>,
    pub(super) dh_v: Array1<f64>,
    pub(super) b_vvec: Array1<f64>,
    pub(super) d_beta_dot_k: Array2<f64>,
    pub(super) d_beta_dot_h: Array1<f64>,
}
4356
/// State shared between evaluations that refer to the same cache key.
#[derive(Clone)]
pub(crate) struct EvalShared {
    /// Identity of this bundle; `matches` compares it with a requested key.
    pub(crate) key: Option<Vec<u64>>,
    /// Shared inner-solver (PIRLS) result.
    pub(crate) pirls_result: Arc<PirlsResult>,
    /// Ridge settings; `penalty_pseudologdet_original` reads the penalty
    /// log-determinant ridge from it.
    pub(crate) ridge_passport: RidgePassport,
    /// Geometry of this evaluation; mapped to a backend kind by the
    /// `PenalizedGeometry` impl.
    pub(crate) geometry: RemlGeometry,
    /// Shared dense matrix; presumably the total penalized Hessian — confirm.
    pub(crate) h_total: Arc<Array2<f64>>,
    /// Sparse-exact data, when present.
    pub(crate) sparse_exact: Option<Arc<SparseExactEvalData>>,
    /// Firth dense operator, when present.
    pub(crate) firth_dense_operator: Option<Arc<FirthDenseOperator>>,
    /// Second Firth dense operator slot; presumably built in the original
    /// (untransformed) coordinates — confirm.
    pub(crate) firth_dense_operator_original: Option<Arc<FirthDenseOperator>>,
    /// Write-once cache filled by `penalty_pseudologdet_original`.
    pub(crate) penalty_pseudologdet: std::sync::OnceLock<Arc<penalty_logdet::PenaltyPseudologdet>>,
    /// Write-once cache of per-penalty score vectors; populated outside this
    /// chunk.
    pub(crate) penalty_scores_at_mode: std::sync::OnceLock<Arc<Vec<Array1<f64>>>>,
    /// Write-once cache of correction terms paired with a `usize` tag;
    /// populated outside this chunk.
    pub(crate) block_local_correction:
        std::sync::OnceLock<(usize, Arc<outer_eval::TkCorrectionTerms>)>,
}
4429
4430impl EvalShared {
4431 pub(crate) fn matches(&self, key: &Option<Vec<u64>>) -> bool {
4432 match (&self.key, key) {
4433 (None, None) => true,
4434 (Some(a), Some(b)) => a == b,
4435 _ => false,
4436 }
4437 }
4438
4439 pub(crate) fn penalty_pseudologdet_original(
4454 &self,
4455 canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
4456 lambdas: &[f64],
4457 p: usize,
4458 ) -> Result<Arc<penalty_logdet::PenaltyPseudologdet>, EstimationError> {
4459 if let Some(pld) = self.penalty_pseudologdet.get() {
4460 if pld.dim() != p {
4461 return Err(EstimationError::LayoutError(format!(
4462 "shared penalty pseudo-logdet frame mismatch: cached p={}, requested p={}",
4463 pld.dim(),
4464 p
4465 )));
4466 }
4467 return Ok(Arc::clone(pld));
4468 }
4469 let pld = Arc::new(
4470 penalty_logdet::PenaltyPseudologdet::from_penalties(
4471 canonical_penalties,
4472 lambdas,
4473 self.ridge_passport.penalty_logdet_ridge(),
4474 p,
4475 )
4476 .map_err(EstimationError::InvalidInput)?,
4477 );
4478 match self.penalty_pseudologdet.set(Arc::clone(&pld)) {
4479 Ok(()) => Ok(pld),
4480 Err(_) => Ok(Arc::clone(
4484 self.penalty_pseudologdet
4485 .get()
4486 .expect("OnceLock set raced, so it is initialized"),
4487 )),
4488 }
4489 }
4490}
4491
4492impl PenalizedGeometry for EvalShared {
4493 fn backend_kind(&self) -> GeometryBackendKind {
4494 match self.geometry {
4495 RemlGeometry::DenseSpectral => GeometryBackendKind::DenseSpectral,
4496 RemlGeometry::SparseExactSpd => GeometryBackendKind::SparseExactSpd,
4497 }
4498 }
4499}
4500
/// Byte-budgeted cache of PIRLS results with least-recently-used eviction.
pub(crate) struct PirlsLruCache {
    /// Cached entries: `(result, clock value at last use, charged bytes)`.
    pub(crate) map: HashMap<Vec<u64>, (Arc<PirlsResult>, u64, usize)>,
    /// Upper bound on `current_bytes`; `new` clamps it to at least 1.
    pub(crate) byte_budget: usize,
    /// Sum of the charged bytes of all entries currently in `map`.
    pub(crate) current_bytes: usize,
    /// Logical clock, advanced on every hit and every insert, used to order
    /// entries by recency.
    pub(crate) clock: u64,
}
4517
4518impl PirlsLruCache {
4519 pub(crate) fn new(byte_budget: usize) -> Self {
4520 Self {
4521 map: HashMap::new(),
4522 byte_budget: byte_budget.max(1),
4523 current_bytes: 0,
4524 clock: 0,
4525 }
4526 }
4527
4528 pub(crate) fn get(&mut self, key: &Vec<u64>) -> Option<Arc<PirlsResult>> {
4529 if let Some(entry) = self.map.get_mut(key) {
4530 self.clock += 1;
4531 entry.1 = self.clock;
4532 Some(entry.0.clone())
4533 } else {
4534 None
4535 }
4536 }
4537
4538 pub(crate) fn insert(&mut self, key: Vec<u64>, value: Arc<PirlsResult>) {
4539 self.clock += 1;
4540 let bytes = pirls_result_cache_bytes(&value);
4541 if bytes > self.byte_budget {
4545 if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4546 self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4547 }
4548 return;
4549 }
4550 if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4551 self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4552 }
4553 while self.current_bytes + bytes > self.byte_budget {
4554 let evict_key = self
4555 .map
4556 .iter()
4557 .min_by_key(|(_, (_, ts, _))| *ts)
4558 .map(|(k, _)| k.clone());
4559 match evict_key {
4560 Some(k) => {
4561 if let Some((_, _, evict_bytes)) = self.map.remove(&k) {
4562 self.current_bytes = self.current_bytes.saturating_sub(evict_bytes);
4563 }
4564 }
4565 None => break,
4566 }
4567 }
4568 self.current_bytes += bytes;
4569 self.map.insert(key, (value, self.clock, bytes));
4570 }
4571
4572 pub(crate) fn clear(&mut self) {
4573 self.map.clear();
4574 self.current_bytes = 0;
4575 }
4576}
4577
/// Identity of a cached penalty subspace: one fingerprint for the penalty
/// matrix contents and one for the ridge settings.
///
/// `Debug` and `Hash` are derived in addition to the comparison traits so the
/// key can be printed in diagnostics and assertions and used directly as a
/// key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PenaltySubspaceCacheKey {
    /// Hash of the penalty matrix shape and entries.
    pub(crate) penalty_matrix_fingerprint: u64,
    /// Hash of the ridge passport fields that affect the subspace.
    pub(crate) ridge_passport_signature: u64,
}
4583
/// Single-slot cache for a penalty subspace: it remembers only the most
/// recently inserted `(key, value)` pair.
pub(crate) struct PenaltySubspaceCache {
    /// The cached pair, or `None` when the cache is empty.
    pub(crate) entry: Option<(PenaltySubspaceCacheKey, Arc<outer_eval::PenaltySubspace>)>,
}
4587
4588impl PenaltySubspaceCache {
4589 pub(crate) fn new() -> Self {
4590 Self { entry: None }
4591 }
4592
4593 pub(crate) fn get(
4594 &self,
4595 key: &PenaltySubspaceCacheKey,
4596 ) -> Option<Arc<outer_eval::PenaltySubspace>> {
4597 self.entry
4598 .as_ref()
4599 .filter(|(cached_key, _)| cached_key == key)
4600 .map(|(_, value)| value.clone())
4601 }
4602
4603 pub(crate) fn insert(
4604 &mut self,
4605 key: PenaltySubspaceCacheKey,
4606 value: Arc<outer_eval::PenaltySubspace>,
4607 ) {
4608 self.entry = Some((key, value));
4609 }
4610
4611 pub(crate) fn clear(&mut self) {
4612 self.entry = None;
4613 }
4614}
4615
4616impl PenaltySubspaceCacheKey {
4617 pub(crate) fn from_inputs(
4622 e_transformed: &ndarray::Array2<f64>,
4623 ridge_passport: &gam_problem::RidgePassport,
4624 ) -> Self {
4625 use std::collections::hash_map::DefaultHasher;
4626 use std::hash::{Hash, Hasher};
4627 let mut hasher = DefaultHasher::new();
4628 e_transformed.nrows().hash(&mut hasher);
4629 e_transformed.ncols().hash(&mut hasher);
4630 for value in e_transformed.iter() {
4631 value.to_bits().hash(&mut hasher);
4632 }
4633 let penalty_matrix_fingerprint = hasher.finish();
4634 let mut ridge_hasher = DefaultHasher::new();
4635 ridge_passport.delta.to_bits().hash(&mut ridge_hasher);
4636 (ridge_passport.matrix_form as u8).hash(&mut ridge_hasher);
4637 ridge_passport
4638 .policy
4639 .include_penalty_logdet
4640 .hash(&mut ridge_hasher);
4641 ridge_passport
4642 .policy
4643 .include_laplacehessian
4644 .hash(&mut ridge_hasher);
4645 let ridge_passport_signature = ridge_hasher.finish();
4646 Self {
4647 penalty_matrix_fingerprint,
4648 ridge_passport_signature,
4649 }
4650 }
4651}
4652
4653pub(crate) fn pirls_result_cache_bytes(result: &PirlsResult) -> usize {
4668 use std::mem::size_of;
4669 let n_array_elems = result.final_eta.len()
4670 + result.solveweights.len()
4671 + result.solveworking_response.len()
4672 + result.solvemu.len()
4673 + result.solve_c_array.len()
4674 + result.solve_d_array.len();
4675 let p = result.beta_transformed.0.len();
4676 let pen_h = symmetric_matrix_cache_bytes(&result.penalized_hessian_transformed);
4677 let stab_h = symmetric_matrix_cache_bytes(&result.stabilizedhessian_transformed);
4678 let reparam = (result.reparam_result.s_transformed.len()
4679 + result.reparam_result.qs.len()
4680 + result.reparam_result.e_transformed.len()
4681 + result.reparam_result.det1.len())
4682 * size_of::<f64>();
4683 n_array_elems * size_of::<f64>() + p * size_of::<f64>() + pen_h + stab_h + reparam + 1024
4684}
4685
4686pub(crate) fn symmetric_matrix_cache_bytes(m: &gam_linalg::matrix::SymmetricMatrix) -> usize {
4687 use gam_linalg::matrix::SymmetricMatrix;
4688 use std::mem::size_of;
4689 match m {
4690 SymmetricMatrix::Dense(a) => a.len() * size_of::<f64>(),
4691 SymmetricMatrix::Sparse(s) => {
4692 let (symbolic, values) = s.parts();
4694 values.len() * (size_of::<f64>() + size_of::<usize>())
4695 + std::mem::size_of_val(symbolic.col_ptr())
4696 }
4697 }
4698}
4699
/// Groups the evaluation caches, each behind its own `RwLock`, together with
/// an atomic flag for the PIRLS cache.
pub(crate) struct EvalCacheManager {
    /// Byte-budgeted LRU cache of PIRLS results.
    pub(crate) pirls_cache: RwLock<PirlsLruCache>,
    /// Single-slot cache of the most recently built penalty subspace.
    pub(crate) penalty_subspace_cache: RwLock<PenaltySubspaceCache>,
    /// The most recently stored evaluation bundle, if any.
    pub(crate) current_eval_bundle: RwLock<Option<EvalShared>>,
    /// The most recently stored outer evaluation together with the key it was
    /// stored under, if any.
    pub(crate) current_outer_eval: RwLock<Option<(Vec<u64>, OuterEval)>>,
    /// Initialized to `true` by `new`; presumably consulted by callers before
    /// using `pirls_cache` — TODO confirm, it is not read in this impl.
    pub(crate) pirls_cache_enabled: AtomicBool,
}
4711
4712impl EvalCacheManager {
4713 pub(crate) fn new() -> Self {
4714 Self {
4715 pirls_cache: RwLock::new(PirlsLruCache::new(PIRLS_CACHE_BYTE_BUDGET)),
4716 penalty_subspace_cache: RwLock::new(PenaltySubspaceCache::new()),
4717 current_eval_bundle: RwLock::new(None),
4718 current_outer_eval: RwLock::new(None),
4719 pirls_cache_enabled: AtomicBool::new(true),
4720 }
4721 }
4722
4723 pub(crate) fn sanitized_rhokey(rho: &Array1<f64>) -> Option<Vec<u64>> {
4727 self::rho_key::sanitized_rhokey(rho)
4728 }
4729
4730 pub(super) fn cached_penalty_subspace<F>(
4737 &self,
4738 e_transformed: &ndarray::Array2<f64>,
4739 ridge_passport: &gam_problem::RidgePassport,
4740 build: F,
4741 ) -> Result<Arc<outer_eval::PenaltySubspace>, EstimationError>
4742 where
4743 F: FnOnce() -> Result<outer_eval::PenaltySubspace, EstimationError>,
4744 {
4745 let key = PenaltySubspaceCacheKey::from_inputs(e_transformed, ridge_passport);
4746 if let Some(hit) = self.penalty_subspace_cache.read().unwrap().get(&key) {
4747 return Ok(hit);
4748 }
4749 let value = Arc::new(build()?);
4750 self.penalty_subspace_cache
4751 .write()
4752 .unwrap()
4753 .insert(key, value.clone());
4754 Ok(value)
4755 }
4756
4757 pub(crate) fn cached_eval_bundle(&self, key: &Option<Vec<u64>>) -> Option<EvalShared> {
4758 let guard = self.current_eval_bundle.read().unwrap();
4759 let bundle: &EvalShared = guard.as_ref()?;
4760 bundle.matches(key).then(|| bundle.clone())
4761 }
4762
4763 pub(crate) fn store_eval_bundle(&self, bundle: EvalShared) {
4764 *self.current_eval_bundle.write().unwrap() = Some(bundle);
4765 }
4766
4767 pub(crate) fn cached_outer_eval(&self, key: &Option<Vec<u64>>) -> Option<OuterEval> {
4768 let key = key.as_ref()?;
4769 let guard = self.current_outer_eval.read().unwrap();
4770 let (cached_key, eval): &(Vec<u64>, OuterEval) = guard.as_ref()?;
4771 (cached_key == key).then(|| eval.clone())
4772 }
4773
4774 pub(crate) fn store_outer_eval(&self, key: &Option<Vec<u64>>, eval: &OuterEval) {
4775 if let Some(key) = key.clone() {
4776 *self.current_outer_eval.write().unwrap() = Some((key, eval.clone()));
4777 }
4778 }
4779
4780 pub(crate) fn invalidate_eval_bundle(&self) {
4781 self.current_eval_bundle.write().unwrap().take();
4782 self.current_outer_eval.write().unwrap().take();
4783 }
4784
4785 pub(crate) fn clear_eval_and_factor_caches(&self) {
4786 self.invalidate_eval_bundle();
4787 self.penalty_subspace_cache.write().unwrap().clear();
4788 }
4789}
4790
/// Small bundle of shared counters/flags for REML evaluation.
pub(crate) struct RemlArena {
    /// Counter behind a lock; starts at zero. Presumably the number of cost
    /// evaluations performed — TODO confirm against the call sites.
    pub(crate) cost_eval_count: RwLock<u64>,
    /// Flag that starts `false`. Presumably records whether the most recent
    /// gradient used a stochastic fallback — TODO confirm.
    pub(crate) lastgradient_used_stochastic_fallback: AtomicBool,
}
4797
4798impl RemlArena {
4799 pub(crate) fn new() -> Self {
4800 Self {
4801 cost_eval_count: RwLock::new(0),
4802 lastgradient_used_stochastic_fallback: AtomicBool::new(false),
4803 }
4804 }
4805}
4806
/// Snapshot of nuisance quantities kept in `RemlState::alo_frozen_nuisance`.
///
/// NOTE(review): field meanings below are inferred from names only — confirm
/// against the code that fills this struct.
pub(crate) struct AloFrozenNuisance {
    /// Presumably the number of observations the snapshot was taken for.
    pub(crate) n_obs: usize,
    /// Presumably one scale factor per observation (length `n_obs`) — TODO confirm.
    pub(crate) influence_scale: Vec<f64>,
    /// Presumably the dispersion parameter — TODO confirm.
    pub(crate) phi: f64,
}
4812
/// Shared state for REML evaluation: borrowed data, penalty structure,
/// configuration, and a set of interior-mutable caches and bookkeeping values.
///
/// `y` and `weights` are views borrowed for `'a`; everything else is owned or
/// reference-counted. Fields wrapped in `RwLock` or atomics can be updated
/// through a shared reference.
///
/// NOTE(review): the per-field notes below are inferred from names and types
/// only; confirm against the constructor and call sites before relying on them.
pub(crate) struct RemlState<'a> {
    // --- Problem data (presumably response, design, prior weights, offset) ---
    pub(crate) y: ArrayView1<'a, f64>,
    pub(crate) x: DesignMatrix,
    pub(crate) weights: ArrayView1<'a, f64>,
    pub(crate) offset: Array1<f64>,
    // --- Penalty structure ---
    /// Shared, immutable list of canonical penalties.
    pub(crate) canonical_penalties: Arc<Vec<gam_terms::construction::CanonicalPenalty>>,
    pub(crate) balanced_penalty_root: Array2<f64>,
    pub(crate) reparam_invariant: ReparamInvariant,
    /// `None` presumably means the sparse penalty path does not apply — TODO confirm.
    pub(crate) sparse_penalty_block_count: Option<usize>,
    /// Presumably the number of coefficients — TODO confirm.
    pub(crate) p: usize,
    // --- Configuration and optional model features ---
    pub(crate) config: Arc<RemlConfig>,
    pub(crate) runtime_mixture_link_state: Option<gam_problem::MixtureLinkState>,
    pub(crate) runtime_sas_link_state: Option<SasLinkState>,
    pub(crate) nullspace_dims: Vec<usize>,
    pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
    pub(crate) linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    pub(crate) penalty_shrinkage_floor: Option<f64>,
    pub(crate) rho_prior: gam_problem::RhoPrior,

    // --- Caches and warm-start state (interior-mutable) ---
    pub(crate) cache_manager: EvalCacheManager,
    pub(crate) arena: RemlArena,
    /// Current warm-start coefficients and the `rho` they belong to, if any.
    pub(crate) warm_start_beta: RwLock<Option<Coefficients>>,
    pub(crate) warm_start_rho: RwLock<Option<Array1<f64>>>,
    /// Presumably the warm start preceding the current one — TODO confirm.
    pub(crate) prev_warm_start_beta: RwLock<Option<Coefficients>>,
    pub(crate) prev_warm_start_rho: RwLock<Option<Array1<f64>>>,
    pub(crate) warm_start_enabled: AtomicBool,
    /// Shared atomics; presumably caps on inner-solver iterations — TODO confirm.
    pub(crate) screening_max_inner_iterations: Arc<AtomicUsize>,
    pub(crate) outer_inner_cap: Arc<AtomicUsize>,

    // Presumably diagnostics recorded by the most recent inner solve.
    pub(crate) last_inner_iters: Arc<AtomicUsize>,
    pub(crate) last_inner_converged: Arc<AtomicBool>,

    pub(crate) ift_warm_start_cache: RwLock<Option<IftWarmStartCache>>,

    // The `AtomicU64` fields below presumably hold `f64` values as raw bits
    // (`to_bits`/`from_bits`) — TODO confirm at the load/store sites.
    pub(crate) last_pirls_lm_lambda: Arc<AtomicU64>,

    pub(crate) frozen_negbin_theta: Arc<AtomicU64>,

    pub(crate) last_ift_prediction_residual: Arc<AtomicU64>,

    pub(crate) last_pirls_accept_rho: Arc<AtomicU64>,

    /// Optionally cached factorization, shared as a trait object.
    pub(crate) ift_cached_factor: RwLock<Option<Arc<dyn gam_linalg::matrix::FactorizedSystem>>>,

    // --- Optional Kronecker-structured representations ---
    pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
    pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,

    // --- Lazily filled, shareable intermediate results ---
    pub(crate) gaussian_fixed_cache: RwLock<Option<Arc<crate::pirls::GaussianFixedCache>>>,
    /// Cached (matrix, vector) pair; empty until something stores into it.
    pub(crate) gaussian_psi_gram_deriv:
        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
    /// Cached (matrix, vector) pair; empty until something stores into it.
    pub(crate) glm_psi_gram_deriv:
        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
    pub(crate) glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
    pub(crate) flat_glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
    pub(crate) alo_frozen_nuisance: RwLock<Option<AloFrozenNuisance>>,

    // --- Persistent warm-start bookkeeping ---
    pub(crate) persistent_warm_start_key: RwLock<Option<String>>,
    pub(crate) persistent_latent_values_fingerprint: Option<u64>,
    /// String-keyed LRU of latent value matrices (see `PersistentLatentValuesCache`).
    pub(crate) persistent_latent_values_cache: RwLock<PersistentLatentValuesCache>,
    pub(crate) analytic_penalty_registry_fingerprint: u64,
    pub(crate) persistent_warm_start_loaded: AtomicBool,
    /// Counters rather than flags; presumably nesting depths of "suppress"
    /// scopes — TODO confirm.
    pub(crate) persistent_warm_start_store_suppression: AtomicUsize,
    pub(crate) alo_stabilization_suppression: AtomicUsize,
    pub(crate) persistent_warm_start_disk_enabled: AtomicBool,
}