1use self::inner_strategy::GeometryBackendKind;
2use super::*;
3use gam_linalg::sparse_exact::SparseExactFactor;
4use crate::pirls::PIRLS_CACHE_BYTE_BUDGET;
5use crate::pirls::assemble_and_factor_sparse_penalized_system;
6use gam_terms::basis::LocalDesignJacobianProvider;
7use gam_problem::SasLinkState;
8use gam_problem::OuterEval;
9use ndarray::{Array1, Array2, s};
10use std::collections::{HashMap, VecDeque};
11use std::ops::Range;
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
13use std::sync::{Arc, RwLock};
14
15pub mod assembly;
16pub mod atoms;
17pub(crate) mod continuation;
18pub(crate) mod eval;
19mod firth;
20pub(super) mod hyper;
21mod inner_strategy;
22pub mod jeffreys_subspace;
25pub mod outer_eval;
26pub mod penalty_logdet;
27pub mod per_atom_efs;
28pub mod reml_outer_engine;
29mod rho_key;
30mod sparse_exact_penalty;
31mod trace;
32
33pub(crate) use sparse_exact_penalty::sparse_penalty_block_count_from_canonical;
34
/// Byte budget (512 MiB) that the exact tau-tau Hessian policy reports as its
/// `budget_bytes`.
pub(crate) const EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES: usize = 512 << 20;
/// Largest observation count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_OBSERVATIONS: usize = 20_000;
/// Largest coefficient count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_COEFFICIENTS: usize = 256;
/// Upper bound on `n_obs * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_LINEAR_WORK: usize = 2_000_000;
/// Upper bound on `n_obs * p_coeff * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_QUADRATIC_WORK: usize = 100_000_000;
/// Default number of entries kept by a `PersistentLatentValuesCache`.
pub(crate) const PERSISTENT_LATENT_VALUES_CACHE_CAPACITY: usize = 8;
41
42#[derive(Debug)]
43pub(crate) struct PersistentLatentValuesCache {
44 pub(crate) entries: HashMap<String, Array2<f64>>,
45 pub(crate) lru: VecDeque<String>,
46 pub(crate) capacity: usize,
47}
48
49impl Default for PersistentLatentValuesCache {
50 fn default() -> Self {
51 Self {
52 entries: HashMap::new(),
53 lru: VecDeque::new(),
54 capacity: PERSISTENT_LATENT_VALUES_CACHE_CAPACITY,
55 }
56 }
57}
58
59impl PersistentLatentValuesCache {
60 pub(crate) fn lookup(
61 &mut self,
62 key: &str,
63 n_obs: usize,
64 latent_dim: usize,
65 ) -> Option<Array2<f64>> {
66 let values = self.entries.get(key)?;
67 if values.dim() != (n_obs, latent_dim) {
68 return None;
69 }
70 let values = values.clone();
71 self.touch(key.to_string());
72 Some(values)
73 }
74
75 pub(crate) fn insert(&mut self, key: String, values: Array2<f64>) {
76 if values.iter().any(|value| !value.is_finite()) {
77 return;
78 }
79 self.entries.insert(key.clone(), values);
80 self.touch(key);
81 while self.entries.len() > self.capacity {
82 let Some(evicted) = self.lru.pop_front() else {
83 break;
84 };
85 self.entries.remove(&evicted);
86 }
87 }
88
89 pub(crate) fn touch(&mut self, key: String) {
90 if let Some(index) = self.lru.iter().position(|queued| queued == &key) {
91 self.lru.remove(index);
92 }
93 self.lru.push_back(key);
94 }
95}
96
97#[derive(Clone)]
102pub(crate) struct IftWarmStartCache {
103 pub beta_original: ndarray::Array1<f64>,
109 pub rho: ndarray::Array1<f64>,
112 pub penalized_hessian_transformed: gam_linalg::matrix::SymmetricMatrix,
116 pub qs: ndarray::Array2<f64>,
120 pub frame_was_original: bool,
124 pub lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>>,
141}
142
/// Per-category byte estimates for one tau-tau evaluation plan.
///
/// Every field is a byte count; `total_bytes` adds them up.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) struct TauTauPlanEstimate {
    pub(crate) dense_x_bytes: usize,
    pub(crate) first_order_tau_bytes: usize,
    pub(crate) second_order_tau_bytes: usize,
    pub(crate) penalty_first_bytes: usize,
    pub(crate) penalty_pair_bytes: usize,
    pub(crate) rho_tau_penalty_bytes: usize,
    pub(crate) vector_cache_bytes: usize,
    pub(crate) weighted_scratch_bytes: usize,
}

impl TauTauPlanEstimate {
    /// Sum of all eight byte categories, clamped at `usize::MAX` rather than
    /// wrapping on overflow.
    pub(crate) fn total_bytes(self) -> usize {
        let Self {
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        } = self;
        let parts = [
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        ];
        parts.iter().copied().fold(0usize, usize::saturating_add)
    }
}
167
168#[derive(Clone, Copy, Debug, PartialEq, Eq)]
169pub(crate) struct TauTauHessianPolicy {
170 pub(crate) any_has_implicit: bool,
171 pub(crate) implicit_multidim_duchon: bool,
172 pub(crate) estimated_dense_tau_cache_bytes: usize,
173 pub(crate) gradient_plan: TauTauPlanEstimate,
174 pub(crate) hessian_plan: TauTauPlanEstimate,
175 pub(crate) budget_bytes: usize,
176 pub(crate) firth_pair_terms_unavailable: bool,
177}
178
179impl TauTauHessianPolicy {
180 pub(crate) fn prefer_gradient_only(self) -> bool {
208 self.firth_pair_terms_unavailable
209 }
210}
211
212pub(crate) fn exact_tau_tau_hessian_policy_with_firth(
213 n_obs: usize,
214 p_coeff: usize,
215 hyper_dirs: &[DirectionalHyperParam],
216 firth_pair_terms_unavailable: bool,
217) -> TauTauHessianPolicy {
218 let f64_bytes = std::mem::size_of::<f64>();
219 let dense_matrix_bytes =
220 |rows: usize, cols: usize| -> usize { rows.saturating_mul(cols).saturating_mul(f64_bytes) };
221 let dense_design_bytes = dense_matrix_bytes(n_obs, p_coeff);
222 let dense_penalty_bytes = dense_matrix_bytes(p_coeff, p_coeff);
223 let psi_dim = hyper_dirs.len();
224 let implicit_n_axes = hyper_dirs
225 .iter()
226 .find_map(DirectionalHyperParam::implicit_axis_count_hint)
227 .unwrap_or(0);
228 let gradient_uses_implicit_design = hyper_dirs
229 .iter()
230 .any(DirectionalHyperParam::has_implicit_operator)
231 && gam_terms::basis::should_use_implicit_operators_with_policy(
232 n_obs,
233 p_coeff,
234 implicit_n_axes,
235 &gam_runtime::resource::ResourcePolicy::default_library(),
236 );
237 let dense_first_order_count = hyper_dirs
238 .iter()
239 .filter(|dir| !dir.has_implicit_operator())
240 .count();
241 let first_penalty_component_count = hyper_dirs
242 .iter()
243 .map(DirectionalHyperParam::penalty_first_component_count)
244 .sum::<usize>();
245
246 let mut dense_second_order_count = 0usize;
247 let mut penalty_pair_count = 0usize;
248 for i in 0..psi_dim {
249 for j in i..psi_dim {
250 if hyper_dirs[i]
251 .x_tau_tau_entry_at(j)
252 .or_else(|| hyper_dirs[j].x_tau_tau_entry_at(i))
253 .is_some_and(|entry| !entry.uses_implicit_storage())
254 {
255 dense_second_order_count += if i == j { 1 } else { 2 };
256 }
257 if hyper_dirs[i].has_penaltysecond_pair_at(j)
258 || hyper_dirs[j].has_penaltysecond_pair_at(i)
259 {
260 penalty_pair_count += if i == j { 1 } else { 2 };
261 }
262 }
263 }
264
265 let gradient_dense_first_order_count = if gradient_uses_implicit_design {
266 dense_first_order_count
267 } else {
268 psi_dim
269 };
270 let gradient_needs_dense_x =
271 firth_pair_terms_unavailable || gradient_dense_first_order_count > 0;
272 let gradient_plan = TauTauPlanEstimate {
273 dense_x_bytes: if gradient_needs_dense_x {
274 dense_design_bytes
275 } else {
276 0
277 },
278 first_order_tau_bytes: if gradient_dense_first_order_count > 0 {
279 dense_design_bytes
280 } else {
281 0
282 },
283 second_order_tau_bytes: 0,
284 penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
285 penalty_pair_bytes: 0,
286 rho_tau_penalty_bytes: 0,
287 vector_cache_bytes: n_obs.saturating_mul(f64_bytes),
288 weighted_scratch_bytes: dense_penalty_bytes,
289 };
290 let hessian_plan = TauTauPlanEstimate {
291 dense_x_bytes: if psi_dim > 0 { dense_design_bytes } else { 0 },
292 first_order_tau_bytes: dense_first_order_count.saturating_mul(dense_design_bytes),
293 second_order_tau_bytes: dense_second_order_count.saturating_mul(dense_design_bytes),
294 penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
295 penalty_pair_bytes: penalty_pair_count.saturating_mul(dense_penalty_bytes),
296 rho_tau_penalty_bytes: first_penalty_component_count
297 .saturating_mul(2)
298 .saturating_mul(dense_penalty_bytes),
299 vector_cache_bytes: psi_dim.saturating_mul(n_obs).saturating_mul(f64_bytes),
300 weighted_scratch_bytes: dense_penalty_bytes,
301 };
302 let any_has_implicit = hyper_dirs
303 .iter()
304 .any(DirectionalHyperParam::has_implicit_operator);
305 let implicit_multidim_duchon = hyper_dirs
306 .iter()
307 .any(DirectionalHyperParam::has_implicit_multidim_duchon);
308 let estimated_dense_tau_cache_bytes = hessian_plan
309 .first_order_tau_bytes
310 .saturating_add(hessian_plan.second_order_tau_bytes);
311 TauTauHessianPolicy {
312 any_has_implicit,
313 implicit_multidim_duchon,
314 estimated_dense_tau_cache_bytes,
315 gradient_plan,
316 hessian_plan,
317 budget_bytes: EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES,
318 firth_pair_terms_unavailable: firth_pair_terms_unavailable && !hyper_dirs.is_empty(),
319 }
320}
321
322pub(crate) fn firth_problem_scale_allows(n_obs: usize, p_coeff: usize) -> bool {
323 let linear_work = n_obs.saturating_mul(p_coeff);
324 let quadratic_work = linear_work.saturating_mul(p_coeff);
325 n_obs <= FIRTH_MAX_OBSERVATIONS
326 && p_coeff <= FIRTH_MAX_COEFFICIENTS
327 && linear_work <= FIRTH_MAX_LINEAR_WORK
328 && quadratic_work <= FIRTH_MAX_QUADRATIC_WORK
329}
330
331#[cfg(test)]
332mod tests {
333 use super::{
334 DirectionalHyperParam, EvalCacheManager, EvalShared, HyperDesignDerivative,
335 HyperPenaltyDerivative, ImplicitDerivLevel, RemlConfig, RemlState,
336 };
337 use crate::estimate::EstimationError;
338 use gam_linalg::faer_ndarray::FaerCholesky;
339 use gam_linalg::matrix::symmetrize_in_place;
340 use crate::pirls::PirlsCoordinateFrame;
341 use gam_terms::basis::{ImplicitDesignPsiDerivative, RadialScalarKind};
342 use gam_problem::{
343 GlmLikelihoodSpec, InverseLink, LikelihoodSpec, ResponseFamily, StandardLink,
344 };
345 use faer::Side;
346 use gam_problem::{HessianResult, OuterEval};
347 use ndarray::{Array1, Array2, array, s};
348 use std::sync::Arc;
349
350 pub(crate) fn binomial_logit_glm_spec() -> GlmLikelihoodSpec {
353 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
354 ResponseFamily::Binomial,
355 InverseLink::Standard(StandardLink::Logit),
356 ))
357 }
358
359 pub(crate) fn gaussian_identity_glm_spec() -> GlmLikelihoodSpec {
362 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
363 ResponseFamily::Gaussian,
364 InverseLink::Standard(StandardLink::Identity),
365 ))
366 }
367
368 impl DirectionalHyperParam {
369 pub(super) fn new(
370 x_tau_original: Array2<f64>,
371 penalty_first_components: Vec<(usize, Array2<f64>)>,
372 x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
373 penaltysecond_components: Option<Vec<Option<Vec<(usize, Array2<f64>)>>>>,
374 ) -> Result<Self, EstimationError> {
375 let x_tau_tau_original = x_tau_tau_original.map(|rows| {
376 rows.into_iter()
377 .map(|entry| entry.map(HyperDesignDerivative::from))
378 .collect::<Vec<_>>()
379 });
380 let penalty_first_components = penalty_first_components
381 .into_iter()
382 .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
383 .collect();
384 let penaltysecond_components = penaltysecond_components.map(|rows| {
385 rows.into_iter()
386 .map(|row| {
387 row.map(|components| {
388 components
389 .into_iter()
390 .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
391 .collect::<Vec<_>>()
392 })
393 })
394 .collect::<Vec<_>>()
395 });
396 Self::new_compact(
397 HyperDesignDerivative::from(x_tau_original),
398 penalty_first_components,
399 x_tau_tau_original,
400 penaltysecond_components,
401 )
402 }
403
404 pub(super) fn single_penalty(
405 penalty_index: usize,
406 x_tau_original: Array2<f64>,
407 s_tau_original: Array2<f64>,
408 x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
409 s_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
410 ) -> Result<Self, EstimationError> {
411 let penaltysecond_components = s_tau_tau_original.map(|rows| {
412 rows.into_iter()
413 .map(|mat| mat.map(|mat| vec![(penalty_index, mat)]))
414 .collect::<Vec<_>>()
415 });
416 Self::new(
417 x_tau_original,
418 vec![(penalty_index, s_tau_original)],
419 x_tau_tau_original,
420 penaltysecond_components,
421 )
422 }
423 }
424
425 #[test]
426 pub(crate) fn firth_problem_scale_gate_blocks_large_quadratic_work() {
427 assert!(super::firth_problem_scale_allows(2_000, 200));
428 assert!(!super::firth_problem_scale_allows(4_800, 241));
429 assert!(!super::firth_problem_scale_allows(4_800, 433));
430 }
431
432 #[test]
433 pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_implicit_tau() {
434 let operator = ImplicitDesignPsiDerivative::new(
435 array![1.0, 2.0, 3.0, 4.0],
436 array![0.5, -1.0, 1.5, 2.0],
437 array![0.1, 0.2, 0.3, 0.4],
438 array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
439 None,
440 None,
441 2,
442 2,
443 1,
444 2,
445 );
446 let dir = DirectionalHyperParam::new_compact(
447 HyperDesignDerivative::from_implicit(
448 Arc::new(operator),
449 ImplicitDerivLevel::First(0),
450 1..4,
451 5,
452 ),
453 Vec::new(),
454 None,
455 None,
456 )
457 .expect("implicit directional hyperparam");
458 let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
459 assert!(policy.any_has_implicit);
460 assert_eq!(
461 policy.gradient_plan.dense_x_bytes,
462 10 * 5 * std::mem::size_of::<f64>()
463 );
464 assert!(!policy.prefer_gradient_only());
465 }
466
467 #[test]
468 pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_for_implicit_multidim_duchon()
469 {
470 let operator = ImplicitDesignPsiDerivative::new_streaming(
478 Arc::new(array![[0.0, 0.0], [1.0, 0.2]]),
479 Arc::new(array![[0.0, 0.0], [1.0, 1.0]]),
480 vec![0.0, 0.0],
481 RadialScalarKind::PureDuchon {
482 block_order: 1,
483 p_order: 0,
484 s_order: 0,
485 dim: 2,
486 },
487 None,
488 None,
489 0,
490 );
491 let dir = DirectionalHyperParam::new_compact(
492 HyperDesignDerivative::from_implicit(
493 Arc::new(operator),
494 ImplicitDerivLevel::First(0),
495 0..2,
496 2,
497 ),
498 Vec::new(),
499 None,
500 None,
501 )
502 .expect("implicit duchon directional hyperparam");
503 let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
504 assert!(policy.any_has_implicit);
505 assert!(policy.implicit_multidim_duchon);
506 assert!(!policy.prefer_gradient_only());
507 }
508
509 #[test]
510 pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_when_cache_budget_is_exceeded()
511 {
512 let dirs = (0..16)
518 .map(|_| {
519 DirectionalHyperParam::new_compact(
520 HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
521 Vec::new(),
522 None,
523 None,
524 )
525 .expect("dense directional hyperparam")
526 })
527 .collect::<Vec<_>>();
528 let policy = super::exact_tau_tau_hessian_policy_with_firth(320_000, 71, &dirs, false);
529 assert!(!policy.any_has_implicit);
530 assert!(policy.hessian_plan.total_bytes() > policy.budget_bytes);
531 assert!(policy.hessian_plan.total_bytes() > policy.gradient_plan.total_bytes());
532 assert!(!policy.prefer_gradient_only());
533 }
534
535 #[test]
536 pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_firth_pair_gap() {
537 let dir = DirectionalHyperParam::new_compact(
538 HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
539 Vec::new(),
540 None,
541 None,
542 )
543 .expect("dense directional hyperparam");
544 let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], true);
545 assert!(policy.firth_pair_terms_unavailable);
546 assert!(policy.prefer_gradient_only());
547 }
548
549 trait LogitDesignMotionFixture {
557 fn y(&self) -> &Array1<f64>;
558 fn w(&self) -> &Array1<f64>;
559 fn x(&self) -> &Array2<f64>;
560 fn s0(&self) -> &Array2<f64>;
561 fn cfg(&self) -> &RemlConfig;
562 fn rho(&self) -> &Array1<f64>;
563
564 fn state(&self) -> RemlState<'_> {
565 build_logit_state(self.y(), self.w(), self.x(), self.s0(), self.cfg())
566 }
567
568 fn state_perturbed(
569 &self,
570 x_tau: &Array2<f64>,
571 s_tau: &Array2<f64>,
572 eps: f64,
573 ) -> (RemlState<'_>, RemlState<'_>) {
574 let x_plus = self.x() + &x_tau.mapv(|v| eps * v);
575 let x_minus = self.x() - &x_tau.mapv(|v| eps * v);
576 let s_plus = self.s0() + &s_tau.mapv(|v| eps * v);
577 let s_minus = self.s0() - &s_tau.mapv(|v| eps * v);
578 (
579 build_logit_state(self.y(), self.w(), &x_plus, &s_plus, self.cfg()),
580 build_logit_state(self.y(), self.w(), &x_minus, &s_minus, self.cfg()),
581 )
582 }
583
584 fn fd_directional_gradient(&self, x_tau: &Array2<f64>, s_tau: &Array2<f64>) -> f64 {
586 let h = 2e-5;
587 let (state_plus, state_minus) = self.state_perturbed(x_tau, s_tau, h);
588 let v_plus = state_plus.compute_cost(self.rho()).expect("cost+");
589 let v_minus = state_minus.compute_cost(self.rho()).expect("cost-");
590 (v_plus - v_minus) / (2.0 * h)
591 }
592 }
593
594 pub(crate) fn build_logit_state<'a>(
595 y: &'a Array1<f64>,
596 w: &'a Array1<f64>,
597 x: &Array2<f64>,
598 s: &Array2<f64>,
599 cfg: &'a RemlConfig,
600 ) -> RemlState<'a> {
601 use crate::estimate::PenaltySpec;
602 let p = x.ncols();
603 let offset = Array1::<f64>::zeros(y.len());
604 let spec = PenaltySpec::Dense(s.clone());
605 let canonical = gam_terms::construction::canonicalize_penalty_specs(&[spec], &[1], p, "test")
606 .map(|(canonical, _)| canonical)
607 .expect("canonicalize");
608 RemlState::newwith_offset(
609 y.view(),
610 x.clone(),
611 w.view(),
612 offset.view(),
613 canonical,
614 p,
615 cfg,
616 Some(vec![1]),
617 None,
618 None,
619 )
620 .expect("state")
621 }
622
623 #[test]
624 fn repeated_penalty_ranges_keep_analytic_outer_hessian() {
625 let y = array![0.2, -0.1, 0.3, 0.0];
626 let w = Array1::<f64>::ones(y.len());
627 let x = array![[1.0, -0.7], [1.0, -0.2], [1.0, 0.3], [1.0, 0.9]];
628 let offset = Array1::<f64>::zeros(y.len());
629 let cfg = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
630 let p = x.ncols();
631 let canonical = vec![
632 gam_terms::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0]], p),
633 gam_terms::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0]], p),
634 ];
635 let state = RemlState::newwith_offset(
636 y.view(),
637 x,
638 w.view(),
639 offset.view(),
640 canonical,
641 p,
642 &cfg,
643 Some(vec![1, 1]),
644 None,
645 None,
646 )
647 .expect("state");
648
649 assert!(
650 state.analytic_outer_hessian_enabled(),
651 "double-penalty-style repeated coefficient ranges must still route to exact Hessian"
652 );
653 }
654
655 pub(crate) fn poisson_log_glm_spec() -> GlmLikelihoodSpec {
656 GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
657 ResponseFamily::Poisson,
658 InverseLink::Standard(StandardLink::Log),
659 ))
660 }
661
662 #[test]
676 pub(crate) fn fixed_dispersion_laml_surface_is_replication_invariant() {
677 let n = 200usize;
678 let p = 8usize;
679 let c = 3usize;
680 let mut x = Array2::<f64>::zeros((n, p));
681 let mut y = Array1::<f64>::zeros(n);
682 for i in 0..n {
683 let t = (i as f64) / ((n - 1) as f64);
684 let tau = std::f64::consts::TAU;
685 x[[i, 0]] = 1.0;
686 x[[i, 1]] = t;
687 x[[i, 2]] = (tau * t).sin();
688 x[[i, 3]] = (tau * t).cos();
689 x[[i, 4]] = (2.0 * tau * t).sin();
690 x[[i, 5]] = (2.0 * tau * t).cos();
691 x[[i, 6]] = (3.0 * tau * t).sin();
692 x[[i, 7]] = (3.0 * tau * t).cos();
693 let eta = 0.3 + 0.9 * (1.4 * (t - 0.5)).sin();
694 y[i] = (eta.exp() + 0.5 * ((i as f64) * 2.399_963).sin())
696 .round()
697 .max(0.0);
698 }
699 let mut s = Array2::<f64>::zeros((p, p));
700 for j in 1..p {
701 s[[j, j]] = 1.0;
702 }
703
704 let mut x_rep = Array2::<f64>::zeros((n * c, p));
706 let mut y_rep = Array1::<f64>::zeros(n * c);
707 for r in 0..c {
708 for i in 0..n {
709 let row = r * n + i;
710 for j in 0..p {
711 x_rep[[row, j]] = x[[i, j]];
712 }
713 y_rep[row] = y[i];
714 }
715 }
716
717 let w_weighted = Array1::<f64>::from_elem(n, c as f64);
718 let w_rep = Array1::<f64>::ones(n * c);
719
720 let cfg = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
721 let st_w = build_logit_state(&y, &w_weighted, &x, &s, &cfg);
722 let st_r = build_logit_state(&y_rep, &w_rep, &x_rep, &s, &cfg);
723
724 for &rho in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
725 let r = Array1::from_elem(1, rho);
726 let cw = st_w.compute_cost(&r).expect("weighted cost");
727 let cr = st_r.compute_cost(&r).expect("replicated cost");
728 let gw = st_w.compute_gradient(&r).expect("weighted grad");
729 let gr = st_r.compute_gradient(&r).expect("replicated grad");
730 assert!(
733 (cw - cr).abs() <= 1e-9 * (1.0 + cw.abs()),
734 "LAML cost differs between w=c and c× replication at rho={rho}: \
735 cost_w={cw:.12e} cost_r={cr:.12e} diff={:.3e}",
736 cw - cr
737 );
738 assert!(
739 (gw[0] - gr[0]).abs() <= 1e-9 * (1.0 + gw[0].abs()),
740 "LAML gradient differs between w=c and c× replication at rho={rho}: \
741 g_w={:.12e} g_r={:.12e} diff={:.3e}",
742 gw[0],
743 gr[0],
744 gw[0] - gr[0]
745 );
746 }
747 }
748
749 #[test]
758 pub(crate) fn rho_weight_anchor_is_zero_for_fixed_dispersion() {
759 let n = 50usize;
760 let p = 3usize;
761 let mut x = Array2::<f64>::zeros((n, p));
762 let mut y = Array1::<f64>::zeros(n);
763 for i in 0..n {
764 let t = (i as f64) / ((n - 1) as f64);
765 x[[i, 0]] = 1.0;
766 x[[i, 1]] = t;
767 x[[i, 2]] = t * t;
768 y[i] = (1.0 + (3.0 * t).sin()).round().max(0.0);
769 }
770 let mut s = Array2::<f64>::zeros((p, p));
771 s[[2, 2]] = 1.0;
772 let c = 4.0_f64;
774 let w = Array1::<f64>::from_elem(n, c);
775
776 let cfg_pois = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
777 let st_pois = build_logit_state(&y, &w, &x, &s, &cfg_pois);
778 assert_eq!(
779 st_pois.rho_weight_anchor(),
780 0.0,
781 "fixed-dispersion (Poisson) anchor must be 0, not the geometric-mean log-weight"
782 );
783
784 let cfg_gauss = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
785 let st_gauss = build_logit_state(&y, &w, &x, &s, &cfg_gauss);
786 assert!(
787 (st_gauss.rho_weight_anchor() - c.ln()).abs() <= 1e-12,
788 "Gaussian-identity (profiled) anchor must be the geometric-mean log-weight ln(c)={:.6}, got {:.6}",
789 c.ln(),
790 st_gauss.rho_weight_anchor()
791 );
792 }
793
794 pub(crate) fn beta_original_from_bundle(bundle: &EvalShared) -> Array1<f64> {
795 let pr = bundle.pirls_result.as_ref();
796 match pr.coordinate_frame {
797 PirlsCoordinateFrame::OriginalSparseNative => pr.beta_transformed.as_ref().clone(),
798 PirlsCoordinateFrame::TransformedQs => {
799 pr.reparam_result.qs.dot(pr.beta_transformed.as_ref())
800 }
801 }
802 }
803
804 pub(crate) fn compute_joint_hypercostgradienthessian(
805 state: &RemlState<'_>,
806 theta: &Array1<f64>,
807 rho_dim: usize,
808 hyper_dirs: &[DirectionalHyperParam],
809 ) -> Result<(f64, Array1<f64>, Array2<f64>), EstimationError> {
810 let (cost, gradient, hessian) = state.compute_joint_hyper_eval_with_order(
811 theta,
812 rho_dim,
813 hyper_dirs,
814 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
815 )?;
816 Ok((
817 cost,
818 gradient,
819 hessian
820 .materialize_dense()
821 .map_err(EstimationError::RemlOptimizationFailed)?
822 .ok_or_else(|| {
823 EstimationError::RemlOptimizationFailed(
824 "joint hyper Hessian requested but unavailable".to_string(),
825 )
826 })?,
827 ))
828 }
829
830 pub(crate) fn h_original_from_bundle(bundle: &EvalShared) -> Array2<f64> {
831 let pr = bundle.pirls_result.as_ref();
832 match pr.coordinate_frame {
833 PirlsCoordinateFrame::OriginalSparseNative => bundle.h_total.as_ref().clone(),
834 PirlsCoordinateFrame::TransformedQs => {
835 let qs = &pr.reparam_result.qs;
836 let tmp = gam_linalg::faer_ndarray::fast_ab(qs, bundle.h_total.as_ref());
837 gam_linalg::faer_ndarray::fast_abt(&tmp, qs)
838 }
839 }
840 }
841
842 pub(crate) fn single_directional_tau_gradient(
843 state: &RemlState<'_>,
844 rho: &Array1<f64>,
845 hyper: DirectionalHyperParam,
846 ) -> Result<f64, EstimationError> {
847 let mut theta = Array1::<f64>::zeros(rho.len() + 1);
848 theta.slice_mut(s![..rho.len()]).assign(rho);
849 let (_, gradient, _) = state.compute_joint_hyper_eval_with_order(
850 &theta,
851 rho.len(),
852 &[hyper],
853 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
854 )?;
855 Ok(gradient[rho.len()])
856 }
857
858 pub(crate) fn fd_directional_tau_cost_gradient(
859 y: &Array1<f64>,
860 w: &Array1<f64>,
861 x: &Array2<f64>,
862 s0: &Array2<f64>,
863 cfg: &RemlConfig,
864 rho: &Array1<f64>,
865 x_tau: &Array2<f64>,
866 s_tau: &Array2<f64>,
867 ) -> f64 {
868 let h = 2e-5;
869 let x_plus = x + &x_tau.mapv(|v| h * v);
870 let x_minus = x - &x_tau.mapv(|v| h * v);
871 let s_plus = s0 + &s_tau.mapv(|v| h * v);
872 let s_minus = s0 - &s_tau.mapv(|v| h * v);
873 let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
874 let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
875 let v_plus = state_plus.compute_cost(rho).expect("cost+");
876 let v_minus = state_minus.compute_cost(rho).expect("cost-");
877 (v_plus - v_minus) / (2.0 * h)
878 }
879
880 pub(crate) fn directional_tau_hessian_fd_reference(
881 y: &Array1<f64>,
882 w: &Array1<f64>,
883 x: &Array2<f64>,
884 s0: &Array2<f64>,
885 cfg: &RemlConfig,
886 rho: &Array1<f64>,
887 hyper_dirs: &[DirectionalHyperParam],
888 x_tau_mats: &[Array2<f64>],
889 s_tau_mats: &[Array2<f64>],
890 ) -> Array2<f64> {
891 assert_eq!(hyper_dirs.len(), x_tau_mats.len());
892 assert_eq!(hyper_dirs.len(), s_tau_mats.len());
893
894 const TARGET_PHYSICAL_STEP: f64 = 1e-5;
895
896 let n_dirs = hyper_dirs.len();
897 let mut h_ttfd = Array2::<f64>::zeros((n_dirs, n_dirs));
898 for j in 0..n_dirs {
899 let direction_scale = x_tau_mats[j]
900 .iter()
901 .chain(s_tau_mats[j].iter())
902 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
903 let h = if direction_scale > 0.0 {
904 TARGET_PHYSICAL_STEP / direction_scale
905 } else {
906 TARGET_PHYSICAL_STEP
907 };
908
909 let x_plus = x + &x_tau_mats[j].mapv(|v| h * v);
910 let x_minus = x - &x_tau_mats[j].mapv(|v| h * v);
911 let s_plus = s0 + &s_tau_mats[j].mapv(|v| h * v);
912 let s_minus = s0 - &s_tau_mats[j].mapv(|v| h * v);
913
914 let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
915 let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
916 for i in 0..n_dirs {
917 let g_plus =
918 single_directional_tau_gradient(&state_plus, rho, hyper_dirs[i].clone())
919 .expect("g+ for FD");
920 let g_minus =
921 single_directional_tau_gradient(&state_minus, rho, hyper_dirs[i].clone())
922 .expect("g- for FD");
923 h_ttfd[[i, j]] = (g_plus - g_minus) / (2.0 * h);
924 }
925 }
926 symmetrize_in_place(&mut h_ttfd);
927 h_ttfd
928 }
929
930 #[test]
931 pub(crate) fn eval_cache_manager_stores_first_order_outer_eval() {
932 let cache = EvalCacheManager::new();
933 let rho = array![0.25, -0.0];
934 let rho_key = EvalCacheManager::sanitized_rhokey(&rho);
935 let eval = OuterEval {
936 cost: 3.5,
937 gradient: array![1.0, -2.0],
938 hessian: HessianResult::Unavailable,
939 inner_beta_hint: None,
940 };
941
942 cache.store_outer_eval(&rho_key, &eval);
943
944 let cached = cache
945 .cached_outer_eval(&rho_key)
946 .expect("first-order outer eval should be cached");
947 assert_eq!(cached.cost, eval.cost);
948 assert_eq!(cached.gradient, eval.gradient);
949 assert!(matches!(cached.hessian, HessianResult::Unavailable));
950
951 cache.invalidate_eval_bundle();
952 assert!(
953 cache.cached_outer_eval(&rho_key).is_none(),
954 "invalidating the bundle should clear the outer-eval cache too"
955 );
956 }
957
958 #[test]
968 pub(crate) fn outer_eval_lru_hit_is_bit_identical_and_evicts_honestly_1575() {
969 use super::OUTER_EVAL_LRU_CAPACITY;
970
971 let make_eval = |seed: f64| OuterEval {
974 cost: (seed * std::f64::consts::PI).sin() / 3.0 - seed,
975 gradient: array![seed, -seed * 2.0, seed.recip()],
976 hessian: HessianResult::Unavailable,
977 inner_beta_hint: Some(array![seed + 0.5, seed - 0.5]),
978 };
979 let bits_eq = |a: &OuterEval, b: &OuterEval| -> bool {
980 a.cost.to_bits() == b.cost.to_bits()
981 && a.gradient.len() == b.gradient.len()
982 && a.gradient
983 .iter()
984 .zip(b.gradient.iter())
985 .all(|(x, y)| x.to_bits() == y.to_bits())
986 };
987
988 let cache = EvalCacheManager::new();
989
990 let rho_a = array![0.25, -1.5];
993 let key_a = EvalCacheManager::sanitized_rhokey(&rho_a);
994 let eval_a = make_eval(0.25);
995 cache.store_outer_eval(&key_a, &eval_a);
996 let hit_a = cache
997 .cached_outer_eval(&key_a)
998 .expect("stored rho_a must hit");
999 assert!(
1000 bits_eq(&hit_a, &eval_a),
1001 "cache hit must be bit-identical (cost+gradient) to the stored miss-path eval"
1002 );
1003 assert_eq!(
1004 hit_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
1005 eval_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
1006 "inner_beta_hint must round-trip unchanged"
1007 );
1008
1009 let rho_b = array![0.25, -1.4999999999999998];
1012 let key_b = EvalCacheManager::sanitized_rhokey(&rho_b);
1013 assert_ne!(key_a, key_b, "the two rho-keys must differ");
1014 let eval_b = make_eval(7.0);
1015 cache.store_outer_eval(&key_b, &eval_b);
1016 assert!(
1017 bits_eq(
1018 &cache.cached_outer_eval(&key_b).expect("rho_b must hit"),
1019 &eval_b
1020 ),
1021 "rho_b must return its own eval, not rho_a's"
1022 );
1023 assert!(
1024 bits_eq(
1025 &cache.cached_outer_eval(&key_a).expect("rho_a must still hit"),
1026 &eval_a
1027 ),
1028 "rho_a must be unaffected by the rho_b insert"
1029 );
1030
1031 let cache = EvalCacheManager::new();
1035 let mut keys = Vec::new();
1036 let mut evals = Vec::new();
1037 for i in 0..OUTER_EVAL_LRU_CAPACITY {
1038 let rho = array![i as f64, -(i as f64)];
1039 let key = EvalCacheManager::sanitized_rhokey(&rho);
1040 let eval = make_eval(i as f64 + 0.123);
1041 cache.store_outer_eval(&key, &eval);
1042 keys.push(key);
1043 evals.push(eval);
1044 }
1045 assert_eq!(
1047 cache.outer_eval_lru.read().unwrap().entries.len(),
1048 OUTER_EVAL_LRU_CAPACITY
1049 );
1050 let rho_overflow = array![999.0, -999.0];
1052 let key_overflow = EvalCacheManager::sanitized_rhokey(&rho_overflow);
1053 let eval_overflow = make_eval(42.0);
1054 cache.store_outer_eval(&key_overflow, &eval_overflow);
1055 assert_eq!(
1056 cache.outer_eval_lru.read().unwrap().entries.len(),
1057 OUTER_EVAL_LRU_CAPACITY,
1058 "capacity must stay bounded"
1059 );
1060 assert!(
1061 cache.cached_outer_eval(&keys[0]).is_none(),
1062 "the least-recently-used key must be evicted and now MISS (recompute), not return stale"
1063 );
1064 assert!(
1065 bits_eq(
1066 &cache
1067 .cached_outer_eval(&keys[1])
1068 .expect("a still-resident key must hit"),
1069 &evals[1]
1070 ),
1071 "a still-resident key must return its exact stored bits"
1072 );
1073 assert!(
1074 bits_eq(
1075 &cache
1076 .cached_outer_eval(&key_overflow)
1077 .expect("the freshest key must hit"),
1078 &eval_overflow
1079 ),
1080 "the freshest key must hit with its own eval"
1081 );
1082 }
1083
1084 #[test]
1085 pub(crate) fn reset_outer_seed_state_clears_pirls_cache() {
1086 let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1092 let w = Array1::<f64>::ones(y.len());
1093 let x = array![
1094 [1.0, -1.0, 0.2],
1095 [1.0, -0.5, -0.4],
1096 [1.0, 0.0, 0.7],
1097 [1.0, 0.4, -0.3],
1098 [1.0, 0.9, 0.1],
1099 [1.0, 1.3, -0.6],
1100 ];
1101 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1102 let rho = array![0.0];
1103 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1104 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1105
1106 state
1109 .compute_outer_eval_with_order(
1110 &rho,
1111 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1112 )
1113 .expect("outer eval should succeed");
1114
1115 let populated_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1116 assert!(
1117 populated_len > 0,
1118 "evaluating the outer objective should populate the PIRLS LRU, got {populated_len}"
1119 );
1120
1121 state.reset_outer_seed_state();
1122
1123 let cleared_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1124 assert_eq!(
1125 cleared_len, 0,
1126 "reset_outer_seed_state must clear the cross-call PIRLS LRU; got {cleared_len} entries"
1127 );
1128 }
1129
1130 #[test]
1131 pub(crate) fn reset_outer_seed_state_preserves_frozen_negbin_theta_1448() {
1132 use std::sync::atomic::Ordering;
1149
1150 let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1151 let w = Array1::<f64>::ones(y.len());
1152 let x = array![
1153 [1.0, -1.0, 0.2],
1154 [1.0, -0.5, -0.4],
1155 [1.0, 0.0, 0.7],
1156 [1.0, 0.4, -0.3],
1157 [1.0, 0.9, 0.1],
1158 [1.0, 1.3, -0.6],
1159 ];
1160 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1161 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1162 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1163
1164 let theta_final_bits = 2.5_f64.to_bits();
1166 state
1167 .frozen_negbin_theta
1168 .store(theta_final_bits, Ordering::Relaxed);
1169 assert_eq!(
1170 state.frozen_negbin_theta.load(Ordering::Relaxed),
1171 theta_final_bits,
1172 "precondition: the re-freeze stores θ_final into the frozen slot"
1173 );
1174
1175 state.reset_outer_seed_state();
1177
1178 assert_eq!(
1179 state.frozen_negbin_theta.load(Ordering::Relaxed),
1180 theta_final_bits,
1181 "reset_outer_seed_state (alternation-round reset) must PRESERVE the \
1182 re-frozen NB θ; clearing it would defeat the #1448 θ↔λ alternation \
1183 (the next ρ search would re-derive θ from the seed and never reach \
1184 the joint fixed point)"
1185 );
1186 }
1187
1188 #[test]
1189 pub(crate) fn implicit_hyper_design_derivative_respects_full_model_embedding() {
1190 let operator = ImplicitDesignPsiDerivative::new(
1191 array![1.0, 2.0, 3.0, 4.0],
1192 array![0.5, -1.0, 1.5, 2.0],
1193 array![0.1, 0.2, 0.3, 0.4],
1194 array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
1195 None,
1196 None,
1197 2,
1198 2,
1199 1,
1200 2,
1201 );
1202 let local = operator
1203 .materialize_first(0)
1204 .expect("materialized first derivative");
1205 assert_eq!(
1206 local.ncols(),
1207 3,
1208 "operator-local derivative should stay smooth-local"
1209 );
1210
1211 let implicit = HyperDesignDerivative::from_implicit(
1212 Arc::new(operator),
1213 ImplicitDerivLevel::First(0),
1214 1..4,
1215 5,
1216 );
1217 let embedded = HyperDesignDerivative::from_embedded(local.clone(), 1..4, 5);
1218
1219 assert_eq!(implicit.nrows(), embedded.nrows());
1220 assert_eq!(implicit.ncols(), 5);
1221 assert_eq!(implicit.materialize(), embedded.materialize());
1222
1223 let u = array![7.0, 1.5, -2.0, 0.25, -3.0];
1224 let v = array![0.75, -1.25];
1225 assert_eq!(
1226 implicit.forward_mul_original(&u).expect("implicit forward"),
1227 embedded.forward_mul_original(&u).expect("embedded forward")
1228 );
1229 assert_eq!(
1230 implicit
1231 .transpose_mul_original(&v)
1232 .expect("implicit transpose"),
1233 embedded
1234 .transpose_mul_original(&v)
1235 .expect("embedded transpose")
1236 );
1237
1238 let qs = array![
1239 [1.0, 0.0, 0.0],
1240 [0.0, 1.0, 0.0],
1241 [0.0, 0.5, 0.5],
1242 [0.0, 0.0, 1.0],
1243 [0.0, 0.0, 0.0],
1244 ];
1245 assert_eq!(
1246 implicit
1247 .transformed(&qs, None)
1248 .expect("implicit transformed"),
1249 embedded
1250 .transformed(&qs, None)
1251 .expect("embedded transformed")
1252 );
1253 let u_transformed = array![1.0, -0.5, 2.0];
1254 assert_eq!(
1255 implicit
1256 .transformed_forward_mul(&qs, None, &u_transformed)
1257 .expect("implicit transformed forward"),
1258 embedded
1259 .transformed_forward_mul(&qs, None, &u_transformed)
1260 .expect("embedded transformed forward")
1261 );
1262 assert_eq!(
1263 implicit
1264 .transformed_transpose_mul(&qs, None, &v)
1265 .expect("implicit transformed transpose"),
1266 embedded
1267 .transformed_transpose_mul(&qs, None, &v)
1268 .expect("embedded transformed transpose")
1269 );
1270 }
1271
    /// Checks the analytic directional-hyperparameter identities of a logit
    /// fit against central finite differences.
    ///
    /// For a single hyper direction (zero design derivative, non-zero penalty
    /// derivative) the test compares, analytic vs. finite difference:
    /// * `B`, the directional derivative of the fitted coefficients,
    /// * `H_tau`, the directional derivative of the penalized Hessian,
    /// * `V_tau`, the directional derivative of the outer cost,
    /// and finally asserts that the "stationarity cancellation" term is
    /// numerically zero.
    #[test]
    pub(crate) fn directional_hyper_identities_match_finite_differences_logit() {
        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let w = Array1::<f64>::ones(y.len());
        let x = array![
            [1.0, -1.2, 0.3],
            [1.0, -0.8, -0.4],
            [1.0, -0.3, 0.7],
            [1.0, 0.1, -0.9],
            [1.0, 0.5, 0.2],
            [1.0, 0.9, -0.1],
            [1.0, 1.3, 0.8],
            [1.0, 1.7, -0.6],
        ];
        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];

        // Hyper direction: the design derivative is identically zero, so only
        // the penalty matrix moves along tau in this test.
        let x_tau = Array2::<f64>::zeros(x.raw_dim());
        let s_tau = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15],];
        let hyper =
            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
                .expect("single-penalty hyper direction");
        let rho = array![0.0];

        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
        let bundle = state.obtain_eval_bundle(&rho).expect("bundle");
        let pr = bundle.pirls_result.as_ref();

        let beta = beta_original_from_bundle(&bundle);
        let h_orig = h_original_from_bundle(&bundle);
        // u = solveweights * (working response - eta), elementwise.
        let u = &pr.solveweights * &(&pr.solveworking_response - &pr.final_eta);

        // Analytic B: solve H B = X_tau^T u - X^T (W X_tau beta) - S_tau beta
        // through a Cholesky factor of H.
        let x_tau_beta = gam_linalg::faer_ndarray::fast_av(&x_tau, &beta);
        let weighted_x_tau_beta = &pr.finalweights * &x_tau_beta;
        let rhs = gam_linalg::faer_ndarray::fast_atv(&x_tau, &u)
            - gam_linalg::faer_ndarray::fast_atv(&x, &weighted_x_tau_beta)
            - s_tau.dot(&beta);
        let chol = h_orig.cholesky(Side::Lower).expect("chol(H)");
        let b_analytic = chol.solvevec(&rhs);

        // Directional change of the linear predictor: X_tau beta + X B.
        let eta_dot = &x_tau_beta + &gam_linalg::faer_ndarray::fast_av(&x, &b_analytic);
        let w_direction = crate::pirls::directionalworking_curvature_from_c_array(
            &pr.solve_c_array,
            &pr.finalweights,
            &eta_dot,
        );
        let wx = RemlState::row_scale(&x, &pr.finalweights);
        let wx_tau = RemlState::row_scale(&x_tau, &pr.finalweights);
        // Rows of X scaled by the directional working-curvature diagonal.
        let mut xwtau_x = x.clone();
        match w_direction {
            crate::pirls::DirectionalWorkingCurvature::Diagonal(diag) => {
                xwtau_x = RemlState::row_scale(&xwtau_x, &diag);
            }
        }
        // Analytic H_tau = X_tau^T W X + X^T W X_tau + X^T diag(w_dir) X + S_tau.
        let mut h_tau_analytic = gam_linalg::faer_ndarray::fast_atb(&x_tau, &wx);
        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &wx_tau);
        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &xwtau_x);
        h_tau_analytic += &s_tau;

        // Cancellation term: -(X^T u)^T B + beta^T (H - X^T W X) B.
        let ell_beta = gam_linalg::faer_ndarray::fast_atv(&x, &u);
        let s_eff = &h_orig - &gam_linalg::faer_ndarray::fast_atb(&x, &wx);
        let cancellation = -ell_beta.dot(&b_analytic) + beta.dot(&s_eff.dot(&b_analytic));

        // Central finite differences: refit at (X ± h X_tau, S ± h S_tau).
        let h = 2e-5;
        let x_plus = &x + &(x_tau.mapv(|v| h * v));
        let x_minus = &x - &(x_tau.mapv(|v| h * v));
        let s_plus = &s0 + &(s_tau.mapv(|v| h * v));
        let s_minus = &s0 - &(s_tau.mapv(|v| h * v));

        let state_plus = build_logit_state(&y, &w, &x_plus, &s_plus, &cfg);
        let state_minus = build_logit_state(&y, &w, &x_minus, &s_minus, &cfg);
        let bundle_plus = state_plus.obtain_eval_bundle(&rho).expect("bundle+");
        let bundle_minus = state_minus.obtain_eval_bundle(&rho).expect("bundle-");
        let beta_plus = beta_original_from_bundle(&bundle_plus);
        let beta_minus = beta_original_from_bundle(&bundle_minus);
        let bfd = (&beta_plus - &beta_minus).mapv(|v| v / (2.0 * h));

        let h_plus = h_original_from_bundle(&bundle_plus);
        let h_minus = h_original_from_bundle(&bundle_minus);
        let h_taufd = (&h_plus - &h_minus).mapv(|v| v / (2.0 * h));

        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
        let v_taufd = (v_plus - v_minus) / (2.0 * h);

        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper.clone())
            .expect("analytic directional gradient");

        // B: every component must agree in sign, and the relative L2 error
        // against the finite difference must stay below 2e-2.
        let b_num = (&b_analytic - &bfd).mapv(|v| v * v).sum().sqrt();
        let b_den = bfd.mapv(|v| v * v).sum().sqrt().max(1e-12);
        let b_rel = b_num / b_den;
        for i in 0..b_analytic.len() {
            assert_eq!(
                b_analytic[i].signum(),
                bfd[i].signum(),
                "B sign mismatch at i={i}: analytic={} fd={}",
                b_analytic[i],
                bfd[i]
            );
        }
        assert!(
            b_rel < 2e-2,
            "B implicit solve mismatch vs FD: rel={b_rel:.3e}, num={b_num:.3e}, den={b_den:.3e}"
        );

        // H_tau: entrywise sign agreement plus a relative Frobenius bound of 3e-2.
        let dh_num = (&h_tau_analytic - &h_taufd).mapv(|v| v * v).sum().sqrt();
        let dh_den = h_taufd.mapv(|v| v * v).sum().sqrt().max(1e-12);
        let dh_rel = dh_num / dh_den;
        for i in 0..h_tau_analytic.nrows() {
            for j in 0..h_tau_analytic.ncols() {
                assert_eq!(
                    h_tau_analytic[[i, j]].signum(),
                    h_taufd[[i, j]].signum(),
                    "H_tau sign mismatch at ({i},{j}): analytic={} fd={}",
                    h_tau_analytic[[i, j]],
                    h_taufd[[i, j]]
                );
            }
        }
        assert!(
            dh_rel < 3e-2,
            "H_tau mismatch vs FD: rel={dh_rel:.3e}, num={dh_num:.3e}, den={dh_den:.3e}"
        );

        // V_tau: sign agreement plus a relative error bound of 2e-2.
        let v_abs = (v_tau_analytic - v_taufd).abs();
        let v_rel = v_abs / v_taufd.abs().max(1e-10);
        assert_eq!(
            v_tau_analytic.signum(),
            v_taufd.signum(),
            "V_tau sign mismatch: analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
        );
        assert!(
            v_rel < 2e-2,
            "V_tau mismatch vs FD: rel={v_rel:.3e}, abs={v_abs:.3e}, analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
        );

        // The cancellation term must vanish to roughly solver precision.
        assert!(
            cancellation.abs() < 1e-10,
            "stationarity cancellation failed: | -ell_beta^T B + beta^T S B | = {:.3e}",
            cancellation.abs()
        );
    }
1431
1432 #[test]
1433 pub(crate) fn firth_exacthessian_includes_analytic_tk_second_derivatives() {
1434 let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
1436 let w = Array1::<f64>::ones(y.len());
1437 let x = array![
1438 [1.0, -1.2, 0.4, -2.4],
1439 [1.0, -0.9, -0.1, -1.8],
1440 [1.0, -0.6, 0.3, -1.2],
1441 [1.0, -0.2, -0.4, -0.4],
1442 [1.0, 0.1, 0.5, 0.2],
1443 [1.0, 0.4, -0.6, 0.8],
1444 [1.0, 0.8, 0.2, 1.6],
1445 [1.0, 1.1, -0.3, 2.2],
1446 [1.0, 1.4, 0.7, 2.8],
1447 [1.0, 1.7, -0.2, 3.4],
1448 ];
1449 let s0 = array![
1450 [0.0, 0.0, 0.0, 0.0],
1451 [0.0, 1.5, 0.2, 0.0],
1452 [0.0, 0.2, 1.0, 0.0],
1453 [0.0, 0.0, 0.0, 0.5],
1454 ];
1455 let s1 = array![
1456 [0.0, 0.0, 0.0, 0.0],
1457 [0.0, 0.8, -0.1, 0.0],
1458 [0.0, -0.1, 0.6, 0.0],
1459 [0.0, 0.0, 0.0, 0.3],
1460 ];
1461 let offset = Array1::<f64>::zeros(y.len());
1462 let cfg =
1465 RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1466 let p = x.ncols();
1467 use crate::estimate::PenaltySpec;
1468 let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1469 let canonical = gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p, "test")
1470 .map(|(canonical, _)| canonical)
1471 .expect("canonicalize");
1472 let state = RemlState::newwith_offset(
1473 y.view(),
1474 x.clone(),
1475 w.view(),
1476 offset.view(),
1477 canonical,
1478 p,
1479 &cfg,
1480 Some(vec![1, 1]),
1481 None,
1482 None,
1483 )
1484 .expect("state");
1485 let rho = array![0.1, -0.2];
1486 assert!(
1487 state.analytic_outer_hessian_enabled(),
1488 "Firth logit should no longer disable analytic outer Hessian planning"
1489 );
1490 let outer = state
1491 .compute_outer_eval_with_order(
1492 &rho,
1493 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1494 )
1495 .expect("outer Hessian eval should succeed");
1496 assert!(
1497 outer.hessian.is_analytic(),
1498 "outer planner should request and return an analytic Hessian"
1499 );
1500 let bundle = state.obtain_eval_bundle(&rho).expect("exact firth bundle");
1501 let h_dense = state
1502 .compute_lamlhessian_exact_from_bundle(&rho, &bundle)
1503 .expect("Firth exact Hessian should include analytic TK second derivatives");
1504 assert_eq!(h_dense.raw_dim(), ndarray::Ix2(2, 2));
1505 assert!(
1506 h_dense.iter().all(|value| value.is_finite()),
1507 "Hessian should be finite: {h_dense:?}"
1508 );
1509 }
1510
1511 #[test]
1512 pub(crate) fn firth_outer_hessian_matches_gradient_finite_difference_with_tk_terms() {
1513 let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
1514 let w = Array1::<f64>::ones(y.len());
1515 let x = array![
1516 [1.0, -1.0, 0.3],
1517 [1.0, -0.7, -0.2],
1518 [1.0, -0.3, 0.4],
1519 [1.0, 0.0, -0.5],
1520 [1.0, 0.2, 0.6],
1521 [1.0, 0.6, -0.4],
1522 [1.0, 0.9, 0.2],
1523 [1.0, 1.3, -0.1],
1524 ];
1525 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.1], [0.0, 0.1, 0.7],];
1526 let s1 = array![[0.0, 0.0, 0.0], [0.0, 0.4, -0.05], [0.0, -0.05, 0.9],];
1527 let cfg =
1528 RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1529 let p_dim = x.ncols();
1530 use crate::estimate::PenaltySpec;
1531 let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1532 let canonical =
1533 gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p_dim, "test")
1534 .map(|(canonical, _)| canonical)
1535 .expect("canonicalize");
1536 let offset = Array1::<f64>::zeros(y.len());
1537 let state = RemlState::newwith_offset(
1538 y.view(),
1539 x.clone(),
1540 w.view(),
1541 offset.view(),
1542 canonical,
1543 p_dim,
1544 &cfg,
1545 Some(vec![1, 1]),
1546 None,
1547 None,
1548 )
1549 .expect("state");
1550 let rho = array![0.15, -0.25];
1551 let eval = state
1552 .compute_outer_eval_with_order(
1553 &rho,
1554 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1555 )
1556 .expect("analytic Hessian eval");
1557 let h = match eval.hessian {
1558 HessianResult::Analytic(hessian) => hessian,
1559 HessianResult::Operator(_) | HessianResult::Unavailable => {
1560 panic!("expected dense analytic Hessian")
1561 }
1562 };
1563 let delta = 2.0e-5;
1564 for col in 0..rho.len() {
1565 let mut rp = rho.clone();
1566 let mut rm = rho.clone();
1567 rp[col] += delta;
1568 rm[col] -= delta;
1569 let gp = state
1570 .compute_outer_eval_with_order(
1571 &rp,
1572 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1573 )
1574 .expect("plus grad")
1575 .gradient;
1576 let gm = state
1577 .compute_outer_eval_with_order(
1578 &rm,
1579 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1580 )
1581 .expect("minus grad")
1582 .gradient;
1583 for row in 0..rho.len() {
1584 let fd = (gp[row] - gm[row]) / (2.0 * delta);
1585 let an = h[[row, col]];
1586 let rel = (fd - an).abs() / fd.abs().max(an.abs()).max(1e-6);
1587 assert!(
1588 rel < 2.0e-3,
1589 "Hessian mismatch ({row},{col}): analytic={an:.9e}, fd={fd:.9e}, rel={rel:.3e}"
1590 );
1591 }
1592 }
1593 }
1594
1595 #[test]
1596 pub(crate) fn firthgradient_lives_in_design_column_space_under_rank_deficiency() {
1597 let x = array![
1599 [1.0, -1.2, 0.4, -2.4],
1600 [1.0, -0.9, -0.1, -1.8],
1601 [1.0, -0.6, 0.3, -1.2],
1602 [1.0, -0.2, -0.4, -0.4],
1603 [1.0, 0.1, 0.5, 0.2],
1604 [1.0, 0.4, -0.6, 0.8],
1605 [1.0, 0.8, 0.2, 1.6],
1606 [1.0, 1.1, -0.3, 2.2],
1607 ];
1608 let beta = array![0.1, -0.2, 0.3, 0.05];
1609 let eta = x.dot(&beta);
1610 let op = super::RemlState::build_firth_dense_operator_for_link(
1611 &gam_problem::InverseLink::Standard(gam_problem::StandardLink::Logit),
1612 &x,
1613 &eta,
1614 ndarray::Array1::ones(x.nrows()).view(),
1615 )
1616 .expect("firth operator");
1617
1618 let gradphi = 0.5 * x.t().dot(&(&op.w1 * &op.h_diag));
1621
1622 let q = &op.q_basis;
1624 let proj = q.dot(&q.t().dot(&gradphi));
1625 let resid = &gradphi - &proj;
1626 let rel =
1627 resid.mapv(|v| v * v).sum().sqrt() / gradphi.mapv(|v| v * v).sum().sqrt().max(1e-12);
1628 assert!(
1629 rel < 1e-10,
1630 "Firth gradient should lie in Col(Xᵀ): rel residual={rel:.3e}"
1631 );
1632 }
1633
1634 #[test]
1635 pub(crate) fn firth_logit_directional_hypergradient_accepts_penalty_only_with_full_tk_gradient()
1636 {
1637 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1638 let w = Array1::<f64>::ones(y.len());
1639 let x = array![
1640 [1.0, -1.1, 0.2],
1641 [1.0, -0.6, -0.3],
1642 [1.0, -0.1, 0.5],
1643 [1.0, 0.3, -0.7],
1644 [1.0, 0.8, 0.1],
1645 [1.0, 1.2, -0.4],
1646 ];
1647 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1648 let hyper = DirectionalHyperParam::single_penalty(
1649 0,
1650 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1651 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1652 None,
1653 None,
1654 )
1655 .expect("single-penalty hyper direction");
1656 let rho = array![0.0];
1657 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1658 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1659 let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1660 .expect("Firth penalty-only directional gradient should use analytic TK propagation");
1661 assert!(gradient.is_finite(), "gradient={gradient}");
1662 let fd = fd_directional_tau_cost_gradient(
1663 &y,
1664 &w,
1665 &x,
1666 &s0,
1667 &cfg,
1668 &rho,
1669 &Array2::<f64>::zeros((x.nrows(), x.ncols())),
1670 &array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1671 );
1672 let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1673 assert!(
1674 rel < 1.0e-3,
1675 "Firth penalty-only directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1676 );
1677
1678 let efs_hyper = DirectionalHyperParam::single_penalty(
1679 0,
1680 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1681 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1682 None,
1683 None,
1684 )
1685 .expect("single-penalty EFS hyper direction");
1686 let efs = state
1687 .compute_efs_steps_with_psi_ext(&rho, &[efs_hyper])
1688 .expect("Firth penalty-only EFS should use analytic TK propagation");
1689 assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1690 }
1691
1692 #[test]
1693 pub(crate) fn firth_logit_directional_hypergradient_accepts_design_moving_with_full_tk_gradient()
1694 {
1695 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1696 let w = Array1::<f64>::ones(y.len());
1697 let x = array![
1698 [1.0, -1.1, 0.2],
1699 [1.0, -0.6, -0.3],
1700 [1.0, -0.1, 0.5],
1701 [1.0, 0.3, -0.7],
1702 [1.0, 0.8, 0.1],
1703 [1.0, 1.2, -0.4],
1704 ];
1705 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1706 let hyper = DirectionalHyperParam::single_penalty(
1707 0,
1708 Array2::from_elem((x.nrows(), x.ncols()), 1e-3),
1709 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1710 None,
1711 None,
1712 )
1713 .expect("single-penalty hyper direction");
1714 let rho = array![0.0];
1715 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1716 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1717 let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1718 .expect("Firth design-moving directional gradient should use analytic TK propagation");
1719 assert!(gradient.is_finite(), "gradient={gradient}");
1720 let x_tau = Array2::from_elem((x.nrows(), x.ncols()), 1e-3);
1721 let s_tau = Array2::<f64>::zeros((x.ncols(), x.ncols()));
1722 let fd = fd_directional_tau_cost_gradient(&y, &w, &x, &s0, &cfg, &rho, &x_tau, &s_tau);
1723 let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1724 assert!(
1725 rel < 2.0e-2,
1726 "Firth design-moving directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1727 );
1728 }
1729
1730 #[test]
1731 pub(crate) fn firth_logit_hybrid_efs_accepts_full_tk_psi_gradient() {
1732 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1733 let w = Array1::<f64>::ones(y.len());
1734 let x = array![
1735 [1.0, -1.1, 0.2],
1736 [1.0, -0.6, -0.3],
1737 [1.0, -0.1, 0.5],
1738 [1.0, 0.3, -0.7],
1739 [1.0, 0.8, 0.1],
1740 [1.0, 1.2, -0.4],
1741 ];
1742 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1743 let hyper_dirs = vec![
1744 DirectionalHyperParam::single_penalty(
1745 0,
1746 Array2::from_shape_fn((x.nrows(), x.ncols()), |(i, j)| {
1747 1e-3 * ((i + 1) as f64) * ((j + 2) as f64)
1748 }),
1749 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1750 None,
1751 None,
1752 )
1753 .expect("design-moving hyper direction"),
1754 ];
1755 let rho = array![0.0];
1756 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1757 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1758
1759 let full = state
1760 .evaluate_unified_with_psi_ext(
1761 &rho,
1762 None,
1763 crate::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
1764 &hyper_dirs,
1765 )
1766 .expect("full Firth psi gradient should use analytic TK propagation");
1767 assert!(full.cost.is_finite(), "full cost={}", full.cost);
1768 let full_grad = full.gradient.expect("gradient should be present");
1769 assert!(
1770 full_grad.iter().all(|value| value.is_finite()),
1771 "full gradient={full_grad:?}"
1772 );
1773
1774 let efs = state
1775 .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
1776 .expect("hybrid EFS should use analytic TK propagation");
1777 assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1778 }
1779
1780 #[test]
1781 pub(crate) fn joint_hyperhessianwires_mixed_blocks() {
1782 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1783 let w = Array1::<f64>::ones(y.len());
1784 let x = array![
1785 [1.0, -1.2, 0.3],
1786 [1.0, -0.8, -0.4],
1787 [1.0, -0.3, 0.7],
1788 [1.0, 0.1, -0.9],
1789 [1.0, 0.5, 0.2],
1790 [1.0, 0.9, -0.1],
1791 [1.0, 1.3, 0.8],
1792 [1.0, 1.7, -0.6],
1793 ];
1794 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1795 let cfg =
1796 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1797 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1798 let rho = array![0.0];
1799 let theta = array![0.0, 0.0, 0.0];
1800 let hyper_dirs = vec![
1801 DirectionalHyperParam::single_penalty(
1802 0,
1803 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1804 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1805 None,
1806 None,
1807 )
1808 .expect("single-penalty hyper direction"),
1809 DirectionalHyperParam::single_penalty(
1810 0,
1811 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1812 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1813 None,
1814 None,
1815 )
1816 .expect("single-penalty hyper direction"),
1817 ];
1818
1819 let (_, _, h) =
1820 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1821 .expect("joint hyper cost+gradient+hessian");
1822 assert_eq!(h.nrows(), theta.len());
1823 assert_eq!(h.ncols(), theta.len());
1824 assert!(h.iter().all(|v| v.is_finite()));
1825 for i in 0..h.nrows() {
1826 for j in 0..i {
1827 let diff = (h[[i, j]] - h[[j, i]]).abs();
1828 assert!(
1829 diff < 1e-6,
1830 "joint hessian asymmetry at ({i},{j}): {diff:.3e}"
1831 );
1832 }
1833 }
1834 let mixed_0 = h[[0, 1]];
1836 let mixed_1 = h[[0, 2]];
1837 assert!(
1838 mixed_0.is_finite() && mixed_1.is_finite(),
1839 "mixed blocks must be finite"
1840 );
1841 }
1842
1843 #[test]
1844 pub(crate) fn joint_tau_tau_linear_dirs_matchfd_reference_away_fromzero_psi() {
1845 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1846 let w = Array1::<f64>::ones(y.len());
1847 let x = array![
1848 [1.0, -1.2, 0.3],
1849 [1.0, -0.8, -0.4],
1850 [1.0, -0.3, 0.7],
1851 [1.0, 0.1, -0.9],
1852 [1.0, 0.5, 0.2],
1853 [1.0, 0.9, -0.1],
1854 [1.0, 1.3, 0.8],
1855 [1.0, 1.7, -0.6],
1856 ];
1857 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1858 let cfg =
1859 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1860 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1861 let rho = array![0.0];
1862 let psi = array![0.7, -0.4];
1863 let theta = array![rho[0], psi[0], psi[1]];
1864 let hyper_dirs = vec![
1865 DirectionalHyperParam::single_penalty(
1866 0,
1867 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1868 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1869 None,
1870 None,
1871 )
1872 .expect("linear tau direction"),
1873 DirectionalHyperParam::single_penalty(
1874 0,
1875 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1876 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1877 None,
1878 None,
1879 )
1880 .expect("linear tau direction"),
1881 ];
1882
1883 let (_, _, h_full) =
1884 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1885 .expect("joint hyper cost+gradient+hessian");
1886 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
1887
1888 let x_tau_mats: Vec<Array2<f64>> = vec![
1893 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1894 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1895 ];
1896 let s_tau_mats: Vec<Array2<f64>> = vec![
1897 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
1898 Array2::<f64>::zeros((x.ncols(), x.ncols())),
1899 ];
1900
1901 let h_ttfd = directional_tau_hessian_fd_reference(
1902 &y,
1903 &w,
1904 &x,
1905 &s0,
1906 &cfg,
1907 &rho,
1908 &hyper_dirs,
1909 &x_tau_mats,
1910 &s_tau_mats,
1911 );
1912
1913 let num = (&h_tt_analytic - &h_ttfd)
1914 .iter()
1915 .map(|v| v * v)
1916 .sum::<f64>()
1917 .sqrt();
1918 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
1919 let rel = num / den;
1920 assert!(
1921 rel < 1e-4,
1922 "linear-dir joint tau-tau block deviates from FD reference away from zero psi: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
1923 );
1924 }
1925
1926 #[test]
1927 pub(crate) fn joint_hypervalidation_rejects_out_of_boundssecond_order_penalty_index() {
1928 let y = array![0.0, 1.0, 0.0, 1.0];
1945 let w = Array1::<f64>::ones(y.len());
1946 let x = array![
1947 [1.0, -0.5, 0.2],
1948 [1.0, -0.1, -0.3],
1949 [1.0, 0.4, 0.6],
1950 [1.0, 0.9, -0.2],
1951 ];
1952 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1953 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
1954 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1955 let theta = array![0.0, 0.0];
1956 let hyper_dirs = vec![
1957 DirectionalHyperParam::new(
1958 Array2::<f64>::zeros((x.nrows(), x.ncols())),
1959 vec![(0, Array2::<f64>::zeros((x.ncols(), x.ncols())))],
1960 None,
1961 Some(vec![Some(vec![(1, Array2::<f64>::eye(x.ncols()))])]),
1962 )
1963 .expect("hyper direction with invalid second-order penalty index"),
1964 ];
1965
1966 let msg = match compute_joint_hypercostgradienthessian(&state, &theta, 1, &hyper_dirs) {
1967 Ok(_) => panic!("invalid second-order penalty index should be rejected"),
1968 Err(err) => err.to_string(),
1969 };
1970 assert!(
1971 msg.contains("out of bounds") || msg.contains("penalty_index"),
1972 "unexpected validation error: {msg}"
1973 );
1974 }
1975
1976 #[test]
1977 pub(crate) fn joint_tau_tau_analytic_matchesfd_reference() {
1978 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1979 let w = Array1::<f64>::ones(y.len());
1980 let x = array![
1981 [1.0, -1.2, 0.3],
1982 [1.0, -0.8, -0.4],
1983 [1.0, -0.3, 0.7],
1984 [1.0, 0.1, -0.9],
1985 [1.0, 0.5, 0.2],
1986 [1.0, 0.9, -0.1],
1987 [1.0, 1.3, 0.8],
1988 [1.0, 1.7, -0.6],
1989 ];
1990 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1991 let cfg =
1992 RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1993 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1994 let rho = array![0.0];
1995 let psi = array![0.0, 0.0];
1996 let hyper_dirs = vec![
1997 DirectionalHyperParam::single_penalty(
1998 0,
1999 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2000 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
2001 None,
2002 None,
2003 )
2004 .expect("single-penalty hyper direction"),
2005 DirectionalHyperParam::single_penalty(
2006 0,
2007 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2008 Array2::<f64>::zeros((x.ncols(), x.ncols())),
2009 None,
2010 None,
2011 )
2012 .expect("single-penalty hyper direction"),
2013 ];
2014
2015 let theta = {
2016 let mut t = Array1::<f64>::zeros(rho.len() + psi.len());
2017 t.slice_mut(s![..rho.len()]).assign(&rho);
2018 t.slice_mut(s![rho.len()..]).assign(&psi);
2019 t
2020 };
2021 let (_, _, h_full) =
2022 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2023 .expect("joint hyper cost+gradient+hessian");
2024 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2025 assert_eq!(h_tt_analytic.nrows(), hyper_dirs.len());
2026 assert_eq!(h_tt_analytic.ncols(), hyper_dirs.len());
2027
2028 let x_tau_mats: Vec<Array2<f64>> = vec![
2033 Array2::<f64>::zeros((x.nrows(), x.ncols())),
2034 Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2035 ];
2036 let s_tau_mats: Vec<Array2<f64>> = vec![
2037 array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
2038 Array2::<f64>::zeros((x.ncols(), x.ncols())),
2039 ];
2040
2041 let h_ttfd = directional_tau_hessian_fd_reference(
2042 &y,
2043 &w,
2044 &x,
2045 &s0,
2046 &cfg,
2047 &rho,
2048 &hyper_dirs,
2049 &x_tau_mats,
2050 &s_tau_mats,
2051 );
2052
2053 let num = (&h_tt_analytic - &h_ttfd)
2054 .iter()
2055 .map(|v| v * v)
2056 .sum::<f64>()
2057 .sqrt();
2058 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2059 let rel = num / den;
2060 assert!(
2061 rel < 1e-4,
2062 "analytic tau-tau block deviates from FD reference: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2063 );
2064 }
2065
    /// Shared inputs for the Gaussian REML directional-derivative tests:
    /// data, a single penalty, the REML configuration, the evaluation point
    /// `rho`, and one design-moving plus one penalty-only hyper direction.
    pub(crate) struct GaussianRemlFixture {
        /// Response vector.
        pub(crate) y: Array1<f64>,
        /// Observation weights (all ones in `new`).
        pub(crate) w: Array1<f64>,
        /// Design matrix.
        pub(crate) x: Array2<f64>,
        /// Penalty matrix.
        pub(crate) s0: Array2<f64>,
        /// REML configuration built from the Gaussian identity GLM spec.
        pub(crate) cfg: RemlConfig,
        /// Log smoothing parameter(s) at which the tests evaluate.
        pub(crate) rho: Array1<f64>,
        /// Design derivative used as the design-moving direction.
        pub(crate) x_tau_design: Array2<f64>,
        /// Penalty derivative used as the penalty-only direction.
        pub(crate) s_tau_penalty: Array2<f64>,
    }
2087
2088 impl GaussianRemlFixture {
2089 pub(crate) fn new() -> Self {
2090 let y = array![0.5, 1.2, -0.3, 0.8, 1.1, -0.6, 0.9, 0.1, -0.2, 0.7];
2091 let x = array![
2092 [1.0, -1.2, 0.3],
2093 [1.0, -0.8, -0.4],
2094 [1.0, -0.3, 0.7],
2095 [1.0, 0.1, -0.9],
2096 [1.0, 0.5, 0.2],
2097 [1.0, 0.9, -0.1],
2098 [1.0, 1.3, 0.8],
2099 [1.0, 1.7, -0.6],
2100 [1.0, -0.5, 0.5],
2101 [1.0, 0.3, -0.3],
2102 ];
2103 Self {
2104 w: Array1::<f64>::ones(y.len()),
2105 y,
2106 x: x.clone(),
2107 s0: array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]],
2108 cfg: RemlConfig::external(gaussian_identity_glm_spec(), 1e-14, false),
2109 rho: array![0.0],
2110 x_tau_design: array![
2111 [0.0, 1e-3, -2e-3],
2112 [0.0, -3e-3, 1e-3],
2113 [0.0, 2e-3, 0.5e-3],
2114 [0.0, -1e-3, 3e-3],
2115 [0.0, 0.5e-3, -1e-3],
2116 [0.0, 1.5e-3, 2e-3],
2117 [0.0, -2e-3, -0.5e-3],
2118 [0.0, 3e-3, 1e-3],
2119 [0.0, -0.5e-3, 2e-3],
2120 [0.0, 1e-3, -1.5e-3],
2121 ],
2122 s_tau_penalty: array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
2123 }
2124 }
2125 }
2126
    /// Exposes the Gaussian fixture's fields through the
    /// `LogitDesignMotionFixture` accessors; each method simply borrows the
    /// field of the same name.
    impl LogitDesignMotionFixture for GaussianRemlFixture {
        // Response vector.
        fn y(&self) -> &Array1<f64> {
            &self.y
        }
        // Observation weights.
        fn w(&self) -> &Array1<f64> {
            &self.w
        }
        // Design matrix.
        fn x(&self) -> &Array2<f64> {
            &self.x
        }
        // Penalty matrix.
        fn s0(&self) -> &Array2<f64> {
            &self.s0
        }
        // REML configuration.
        fn cfg(&self) -> &RemlConfig {
            &self.cfg
        }
        // Evaluation point for the smoothing parameter(s).
        fn rho(&self) -> &Array1<f64> {
            &self.rho
        }
    }
2147
2148 #[test]
2149 pub(crate) fn profiled_gaussian_design_moving_gradient_matches_fd() {
2150 let f = GaussianRemlFixture::new();
2151 let state = f.state();
2152 let s_tau = Array2::<f64>::zeros((3, 3));
2153 let hyper = DirectionalHyperParam::single_penalty(
2154 0,
2155 f.x_tau_design.clone(),
2156 s_tau.clone(),
2157 None,
2158 None,
2159 )
2160 .expect("design-moving hyper direction");
2161
2162 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2163 .expect("analytic directional gradient");
2164 let v_taufd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2165
2166 let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2167 assert!(
2168 v_rel < 1e-3,
2169 "Gaussian REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2170 analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2171 );
2172 }
2173
2174 #[test]
2175 pub(crate) fn profiled_gaussian_penalty_only_gradient_matches_fd() {
2176 let f = GaussianRemlFixture::new();
2177 let state = f.state();
2178 let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2179 let hyper = DirectionalHyperParam::single_penalty(
2180 0,
2181 x_tau.clone(),
2182 f.s_tau_penalty.clone(),
2183 None,
2184 None,
2185 )
2186 .expect("penalty-only hyper direction");
2187
2188 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2189 .expect("analytic directional gradient");
2190 let v_taufd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2191
2192 let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2193 assert!(
2194 v_rel < 1e-3,
2195 "Gaussian REML penalty-only V_tau mismatch: rel={v_rel:.3e}, \
2196 analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2197 );
2198 }
2199
2200 #[test]
2201 pub(crate) fn profiled_gaussian_joint_hessian_matches_fd() {
2202 let f = GaussianRemlFixture::new();
2205 let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2206 let s_tau_0 = f.s_tau_penalty.clone();
2207 let x_tau_1 = f.x_tau_design.clone();
2208 let s_tau_1 = Array2::<f64>::zeros((3, 3));
2209
2210 let hyper_dirs = vec![
2211 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2212 .expect("penalty-only direction"),
2213 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2214 .expect("design-moving direction"),
2215 ];
2216
2217 let state = f.state();
2218 let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2219 theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2220 let (_, _, h_full) =
2221 compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2222 .expect("joint cost+gradient+hessian");
2223 let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2224
2225 let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2228 let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2229 let h_ttfd = directional_tau_hessian_fd_reference(
2230 &f.y,
2231 &f.w,
2232 &f.x,
2233 &f.s0,
2234 &f.cfg,
2235 &f.rho,
2236 &hyper_dirs,
2237 &x_tau_mats,
2238 &s_tau_mats,
2239 );
2240
2241 let num = (&h_tt_analytic - &h_ttfd)
2242 .iter()
2243 .map(|v| v * v)
2244 .sum::<f64>()
2245 .sqrt();
2246 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2247 let rel = num / den;
2248 assert!(
2249 rel < 1e-4,
2250 "Gaussian REML tau-tau Hessian mismatch: rel={rel:.3e}, \
2251 analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2252 );
2253 }
2254
2255 #[test]
2269 pub(crate) fn logit_design_moving_gradient_matches_fd() {
2270 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2271 let w = Array1::<f64>::ones(y.len());
2272 let x = array![
2273 [1.0, -1.2, 0.3],
2274 [1.0, -0.8, -0.4],
2275 [1.0, -0.3, 0.7],
2276 [1.0, 0.1, -0.9],
2277 [1.0, 0.5, 0.2],
2278 [1.0, 0.9, -0.1],
2279 [1.0, 1.3, 0.8],
2280 [1.0, 1.7, -0.6],
2281 [1.0, -0.5, 0.5],
2282 [1.0, 0.3, -0.3],
2283 ];
2284 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2285 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2286 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2287 let rho = array![0.0];
2288
2289 let x_tau = array![
2291 [0.0, 1e-3, -2e-3],
2292 [0.0, -3e-3, 1e-3],
2293 [0.0, 2e-3, 0.5e-3],
2294 [0.0, -1e-3, 3e-3],
2295 [0.0, 0.5e-3, -1e-3],
2296 [0.0, 1.5e-3, 2e-3],
2297 [0.0, -2e-3, -0.5e-3],
2298 [0.0, 3e-3, 1e-3],
2299 [0.0, -0.5e-3, 2e-3],
2300 [0.0, 1e-3, -1.5e-3],
2301 ];
2302 let s_tau = Array2::<f64>::zeros((3, 3));
2303 let hyper =
2304 DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
2305 .expect("design-moving hyper direction");
2306
2307 let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2308 .expect("analytic directional gradient");
2309
2310 let h = 2e-5;
2311 let x_plus = &x + &x_tau.mapv(|v| h * v);
2312 let x_minus = &x - &x_tau.mapv(|v| h * v);
2313 let state_plus = build_logit_state(&y, &w, &x_plus, &s0, &cfg);
2314 let state_minus = build_logit_state(&y, &w, &x_minus, &s0, &cfg);
2315 let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2316 let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2317 let v_taufd = (v_plus - v_minus) / (2.0 * h);
2318
2319 let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2320 assert!(
2321 v_rel < 1e-3,
2322 "Logit REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2323 analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2324 );
2325 }
2326
2327 #[test]
2328 pub(crate) fn logit_design_moving_hessian_matches_fd() {
2329 let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2334 let w = Array1::<f64>::ones(y.len());
2335 let x = array![
2336 [1.0, -1.2, 0.3],
2337 [1.0, -0.8, -0.4],
2338 [1.0, -0.3, 0.7],
2339 [1.0, 0.1, -0.9],
2340 [1.0, 0.5, 0.2],
2341 [1.0, 0.9, -0.1],
2342 [1.0, 1.3, 0.8],
2343 [1.0, 1.7, -0.6],
2344 [1.0, -0.5, 0.5],
2345 [1.0, 0.3, -0.3],
2346 ];
2347 let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2348 let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2349 let rho = array![0.0];
2350
2351 let x_tau_0 = Array2::<f64>::zeros(x.raw_dim());
2353 let s_tau_0 = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]];
2354 let x_tau_1 = array![
2355 [0.0, 1e-3, -2e-3],
2356 [0.0, -3e-3, 1e-3],
2357 [0.0, 2e-3, 0.5e-3],
2358 [0.0, -1e-3, 3e-3],
2359 [0.0, 0.5e-3, -1e-3],
2360 [0.0, 1.5e-3, 2e-3],
2361 [0.0, -2e-3, -0.5e-3],
2362 [0.0, 3e-3, 1e-3],
2363 [0.0, -0.5e-3, 2e-3],
2364 [0.0, 1e-3, -1.5e-3],
2365 ];
2366 let s_tau_1 = Array2::<f64>::zeros((3, 3));
2367
2368 let hyper_dirs = vec![
2369 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2370 .expect("penalty-only direction"),
2371 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2372 .expect("design-moving direction"),
2373 ];
2374
2375 let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2376 let mut theta = Array1::<f64>::zeros(rho.len() + hyper_dirs.len());
2377 theta.slice_mut(s![..rho.len()]).assign(&rho);
2378 let (_, _, h_full) =
2379 compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2380 .expect("joint cost+gradient+hessian");
2381 let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2382
2383 let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2384 let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2385 let h_ttfd = directional_tau_hessian_fd_reference(
2386 &y,
2387 &w,
2388 &x,
2389 &s0,
2390 &cfg,
2391 &rho,
2392 &hyper_dirs,
2393 &x_tau_mats,
2394 &s_tau_mats,
2395 );
2396
2397 let num = (&h_tt_analytic - &h_ttfd)
2398 .iter()
2399 .map(|v| v * v)
2400 .sum::<f64>()
2401 .sqrt();
2402 let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2403 let rel = num / den;
2404 assert!(
2405 rel < 1e-4,
2406 "Logit REML design-moving tau-tau Hessian mismatch: rel={rel:.3e}, \
2407 analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2408 );
2409 }
2410
    /// Shared inputs for the binomial-logit n=30 derivative tests.
    ///
    /// `new()` fills it with 30 observations and a 5-column design, together
    /// with one design-derivative direction and one penalty-derivative
    /// direction that the tests combine in different ways.
    pub(crate) struct BinomialLogitDesignMotionFixture {
        /// Response vector (0/1 values in `new()`).
        pub(crate) y: Array1<f64>,
        /// Per-observation vector passed as `w` to `build_logit_state`; all ones in `new()`.
        pub(crate) w: Array1<f64>,
        /// Design matrix; the first column is all ones in `new()`.
        pub(crate) x: Array2<f64>,
        /// Base penalty matrix; its first row and column are zero in `new()`.
        pub(crate) s0: Array2<f64>,
        /// REML configuration built from `binomial_logit_glm_spec()`.
        pub(crate) cfg: RemlConfig,
        /// Default `rho` at which the tests evaluate (`[0.0]` in `new()`).
        pub(crate) rho: Array1<f64>,
        /// Design-derivative direction, same shape as `x`; first column zero in `new()`.
        pub(crate) x_tau_design: Array2<f64>,
        /// Penalty-derivative direction, same shape as `s0`; first row/column zero in `new()`.
        pub(crate) s_tau_penalty: Array2<f64>,
    }
2432
    impl BinomialLogitDesignMotionFixture {
        /// Builds the fixed n=30, p=5 fixture. All values are hard-coded so
        /// every test sees exactly the same problem.
        pub(crate) fn new() -> Self {
            // 30 binary responses.
            let y = array![
                1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
                1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
            ];
            // 30x5 design: a column of ones followed by four covariate columns.
            let x = array![
                [1.0, -1.50, 0.42, 0.88, -0.31],
                [1.0, -1.12, -0.65, 0.14, 1.23],
                [1.0, -0.80, 1.10, -0.53, 0.07],
                [1.0, -0.55, -0.22, 1.40, -0.90],
                [1.0, -0.30, 0.73, -1.05, 0.44],
                [1.0, -0.05, -1.33, 0.60, 0.81],
                [1.0, 0.18, 0.55, -0.27, -1.15],
                [1.0, 0.42, -0.90, 1.12, 0.33],
                [1.0, 0.70, 1.28, -0.78, -0.56],
                [1.0, 0.95, -0.18, 0.45, 1.40],
                [1.0, 1.20, 0.66, -1.30, -0.02],
                [1.0, 1.45, -1.05, 0.22, 0.68],
                [1.0, -1.35, 0.90, 0.55, -0.43],
                [1.0, -0.98, -0.40, -0.88, 1.05],
                [1.0, -0.62, 1.42, 0.30, -0.70],
                [1.0, -0.28, -0.77, -1.18, 0.52],
                [1.0, 0.05, 0.15, 0.95, -1.35],
                [1.0, 0.33, -1.20, -0.40, 0.18],
                [1.0, 0.60, 0.82, 1.25, -0.85],
                [1.0, 0.88, -0.50, -0.65, 1.10],
                [1.0, 1.15, 1.05, 0.10, -0.22],
                [1.0, -1.22, -0.95, 0.72, 0.90],
                [1.0, -0.75, 0.38, -1.42, 0.15],
                [1.0, -0.42, -1.15, 0.50, -1.08],
                [1.0, -0.10, 0.60, -0.15, 0.75],
                [1.0, 0.25, -0.28, 1.05, -0.48],
                [1.0, 0.52, 1.35, -0.92, 0.30],
                [1.0, 0.80, -0.70, 0.38, 1.20],
                [1.0, 1.08, 0.48, -0.60, -0.95],
                [1.0, 1.35, -0.55, 0.85, 0.42]
            ];
            // 5x5 symmetric base penalty; the zero first row/column leaves the
            // first coefficient unpenalized.
            let s0 = array![
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 1.40, 0.15, 0.05, -0.10],
                [0.0, 0.15, 1.10, -0.20, 0.08],
                [0.0, 0.05, -0.20, 0.95, 0.12],
                [0.0, -0.10, 0.08, 0.12, 1.25]
            ];
            let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
            // 30x5 design-derivative direction with entries of order 1e-3; the
            // first column does not move.
            let x_tau_design = array![
                [0.0, 1.2e-3, -0.8e-3, 0.5e-3, -1.5e-3],
                [0.0, -2.0e-3, 1.4e-3, -0.3e-3, 0.9e-3],
                [0.0, 0.6e-3, -1.1e-3, 1.8e-3, -0.4e-3],
                [0.0, -1.3e-3, 0.7e-3, -1.0e-3, 2.1e-3],
                [0.0, 0.9e-3, -0.5e-3, 0.2e-3, -0.8e-3],
                [0.0, -0.4e-3, 1.8e-3, -1.5e-3, 0.3e-3],
                [0.0, 1.5e-3, -1.3e-3, 0.8e-3, -1.1e-3],
                [0.0, -0.7e-3, 0.4e-3, -2.0e-3, 1.6e-3],
                [0.0, 2.2e-3, -0.9e-3, 1.3e-3, -0.6e-3],
                [0.0, -1.0e-3, 1.6e-3, -0.7e-3, 0.5e-3],
                [0.0, 0.3e-3, -2.1e-3, 1.1e-3, -1.8e-3],
                [0.0, -1.8e-3, 0.2e-3, -0.4e-3, 1.3e-3],
                [0.0, 1.1e-3, -1.5e-3, 2.0e-3, -0.2e-3],
                [0.0, -0.5e-3, 0.9e-3, -1.2e-3, 0.7e-3],
                [0.0, 1.7e-3, -0.3e-3, 0.6e-3, -2.0e-3],
                [0.0, -1.4e-3, 1.1e-3, -0.9e-3, 0.4e-3],
                [0.0, 0.8e-3, -1.7e-3, 1.5e-3, -0.1e-3],
                [0.0, -0.2e-3, 0.6e-3, -1.8e-3, 1.0e-3],
                [0.0, 1.4e-3, -0.4e-3, 0.3e-3, -1.3e-3],
                [0.0, -0.9e-3, 2.0e-3, -0.5e-3, 0.8e-3],
                [0.0, 0.5e-3, -1.0e-3, 1.6e-3, -0.7e-3],
                [0.0, -2.1e-3, 0.3e-3, -0.8e-3, 1.5e-3],
                [0.0, 0.7e-3, -1.8e-3, 0.9e-3, -0.3e-3],
                [0.0, -0.6e-3, 1.3e-3, -2.2e-3, 1.1e-3],
                [0.0, 1.9e-3, -0.7e-3, 0.4e-3, -0.9e-3],
                [0.0, -1.1e-3, 0.5e-3, -1.4e-3, 2.2e-3],
                [0.0, 0.4e-3, -1.6e-3, 1.2e-3, -0.5e-3],
                [0.0, -1.6e-3, 0.8e-3, -0.1e-3, 0.6e-3],
                [0.0, 1.3e-3, -2.2e-3, 0.7e-3, -1.4e-3],
                [0.0, -0.3e-3, 1.0e-3, -1.6e-3, 1.8e-3]
            ];
            // 5x5 symmetric penalty-derivative direction with the same zero
            // first row/column as `s0`.
            let s_tau_penalty = array![
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.30, 0.05, -0.02, 0.04],
                [0.0, 0.05, 0.22, 0.03, -0.01],
                [0.0, -0.02, 0.03, 0.18, 0.06],
                [0.0, 0.04, -0.01, 0.06, 0.26]
            ];
            Self {
                // `w` is computed before `y` is moved into the struct.
                w: Array1::<f64>::ones(y.len()),
                y,
                x,
                s0,
                cfg,
                rho: array![0.0],
                x_tau_design,
                s_tau_penalty,
            }
        }
    }
2536
    // Accessor plumbing: exposes the fixture's fields through the
    // `LogitDesignMotionFixture` trait. The tests call `state()`,
    // `fd_directional_gradient()` and `state_perturbed()` on this fixture;
    // those are presumably trait-provided helpers built on these accessors
    // (the trait definition is outside this view).
    impl LogitDesignMotionFixture for BinomialLogitDesignMotionFixture {
        /// Borrows the response vector.
        fn y(&self) -> &Array1<f64> {
            &self.y
        }
        /// Borrows the `w` vector.
        fn w(&self) -> &Array1<f64> {
            &self.w
        }
        /// Borrows the design matrix.
        fn x(&self) -> &Array2<f64> {
            &self.x
        }
        /// Borrows the base penalty matrix.
        fn s0(&self) -> &Array2<f64> {
            &self.s0
        }
        /// Borrows the REML configuration.
        fn cfg(&self) -> &RemlConfig {
            &self.cfg
        }
        /// Borrows the default `rho` vector.
        fn rho(&self) -> &Array1<f64> {
            &self.rho
        }
    }
2557
2558 #[test]
2561 pub(crate) fn binomial_logit_n30_design_moving_gradient_matches_fd() {
2562 let f = BinomialLogitDesignMotionFixture::new();
2569 let state = f.state();
2570 let s_tau = Array2::<f64>::zeros((5, 5));
2571 let hyper = DirectionalHyperParam::single_penalty(
2572 0,
2573 f.x_tau_design.clone(),
2574 s_tau.clone(),
2575 None,
2576 None,
2577 )
2578 .expect("design-moving hyper direction");
2579
2580 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2581 .expect("analytic directional gradient");
2582 let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2583
2584 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2585 assert!(
2586 v_rel < 1e-3,
2587 "Binomial-logit n=30 design-moving gradient mismatch: rel={v_rel:.3e}, \
2588 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2589 );
2590 }
2591
2592 #[test]
2593 pub(crate) fn binomial_logit_n30_penalty_only_gradient_matches_fd() {
2594 let f = BinomialLogitDesignMotionFixture::new();
2599 let state = f.state();
2600 let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2601 let hyper = DirectionalHyperParam::single_penalty(
2602 0,
2603 x_tau.clone(),
2604 f.s_tau_penalty.clone(),
2605 None,
2606 None,
2607 )
2608 .expect("penalty-only hyper direction");
2609
2610 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2611 .expect("analytic directional gradient");
2612 let v_tau_fd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2613
2614 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2615 assert!(
2616 v_rel < 1e-3,
2617 "Binomial-logit n=30 penalty-only gradient mismatch: rel={v_rel:.3e}, \
2618 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2619 );
2620 }
2621
2622 #[test]
2623 pub(crate) fn binomial_logit_n30_joint_design_penalty_gradient_matches_fd() {
2624 let f = BinomialLogitDesignMotionFixture::new();
2629 let state = f.state();
2630 let hyper = DirectionalHyperParam::single_penalty(
2631 0,
2632 f.x_tau_design.clone(),
2633 f.s_tau_penalty.clone(),
2634 None,
2635 None,
2636 )
2637 .expect("joint design+penalty hyper direction");
2638
2639 let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2640 .expect("analytic directional gradient");
2641 let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &f.s_tau_penalty);
2642
2643 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2644 assert!(
2645 v_rel < 1e-3,
2646 "Binomial-logit n=30 joint design+penalty gradient mismatch: rel={v_rel:.3e}, \
2647 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2648 );
2649 }
2650
2651 #[test]
2652 pub(crate) fn binomial_logit_n30_design_moving_hessian_matches_fd() {
2653 let f = BinomialLogitDesignMotionFixture::new();
2658 let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2659 let s_tau_0 = f.s_tau_penalty.clone();
2660 let x_tau_1 = f.x_tau_design.clone();
2661 let s_tau_1 = Array2::<f64>::zeros((5, 5));
2662
2663 let hyper_dirs = vec![
2664 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2665 .expect("penalty-only direction"),
2666 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2667 .expect("design-moving direction"),
2668 ];
2669
2670 let state = f.state();
2671 let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2672 theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2673 let (_, _, h_full) =
2674 compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2675 .expect("joint cost+gradient+hessian");
2676 let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2677
2678 let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2679 let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2680 let h_tt_fd = directional_tau_hessian_fd_reference(
2681 &f.y,
2682 &f.w,
2683 &f.x,
2684 &f.s0,
2685 &f.cfg,
2686 &f.rho,
2687 &hyper_dirs,
2688 &x_tau_mats,
2689 &s_tau_mats,
2690 );
2691
2692 let num = (&h_tt_analytic - &h_tt_fd)
2693 .iter()
2694 .map(|v| v * v)
2695 .sum::<f64>()
2696 .sqrt();
2697 let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2698 let rel = num / den;
2699 assert!(
2700 rel < 1e-4,
2701 "Binomial-logit n=30 tau-tau Hessian mismatch: rel={rel:.3e}, \
2702 analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2703 );
2704 }
2705
2706 #[test]
2707 pub(crate) fn binomial_logit_n30_nonzero_rho_design_moving_gradient_matches_fd() {
2708 let f = BinomialLogitDesignMotionFixture::new();
2712 let rho = array![1.5];
2713 let s_tau = Array2::<f64>::zeros((5, 5));
2714
2715 let state = f.state();
2716 let hyper = DirectionalHyperParam::single_penalty(
2717 0,
2718 f.x_tau_design.clone(),
2719 s_tau.clone(),
2720 None,
2721 None,
2722 )
2723 .expect("design-moving hyper direction");
2724
2725 let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2726 .expect("analytic directional gradient");
2727
2728 let h = 2e-5;
2730 let (state_plus, state_minus) = f.state_perturbed(&f.x_tau_design, &s_tau, h);
2731 let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2732 let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2733 let v_tau_fd = (v_plus - v_minus) / (2.0 * h);
2734
2735 let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2736 assert!(
2737 v_rel < 1e-3,
2738 "Binomial-logit n=30 rho=1.5 design-moving gradient mismatch: rel={v_rel:.3e}, \
2739 analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2740 );
2741 }
2742
2743 #[test]
2744 pub(crate) fn binomial_logit_n30_rank_deficient_hessian_matches_cost_fd() {
2745 let f = BinomialLogitDesignMotionFixture::new();
2780 let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2781 let s_tau_0 = f.s_tau_penalty.clone();
2782 let x_tau_1 = f.x_tau_design.clone();
2783 let s_tau_1 = Array2::<f64>::zeros((5, 5));
2784
2785 let hyper_dirs = vec![
2786 DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2787 .expect("penalty-only direction"),
2788 DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2789 .expect("design-moving direction"),
2790 ];
2791
2792 let state = f.state();
2794 let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2795 theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2796 let (_, _, h_full) =
2797 compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2798 .expect("joint cost+gradient+hessian");
2799 let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2800
2801 const TARGET_PHYSICAL_STEP: f64 = 1e-5;
2805 let x_tau_mats = [&x_tau_0, &x_tau_1];
2806 let s_tau_mats = [&s_tau_0, &s_tau_1];
2807 let steps: [f64; 2] = {
2808 let mut steps = [0.0; 2];
2809 for (j, step) in steps.iter_mut().enumerate() {
2810 let scale = x_tau_mats[j]
2811 .iter()
2812 .chain(s_tau_mats[j].iter())
2813 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2814 *step = if scale > 0.0 {
2815 TARGET_PHYSICAL_STEP / scale
2816 } else {
2817 TARGET_PHYSICAL_STEP
2818 };
2819 }
2820 steps
2821 };
2822
2823 let eval_cost = |a: f64, b: f64| -> f64 {
2825 let x_eval = &f.x
2826 + &x_tau_mats[0].mapv(|v| a * steps[0] * v)
2827 + &x_tau_mats[1].mapv(|v| b * steps[1] * v);
2828 let s_eval = &f.s0
2829 + &s_tau_mats[0].mapv(|v| a * steps[0] * v)
2830 + &s_tau_mats[1].mapv(|v| b * steps[1] * v);
2831 let st = build_logit_state(&f.y, &f.w, &x_eval, &s_eval, &f.cfg);
2832 st.compute_cost(&f.rho).expect("cost eval")
2833 };
2834
2835 let v_00 = eval_cost(0.0, 0.0);
2836 let v_p0 = eval_cost(1.0, 0.0);
2837 let v_m0 = eval_cost(-1.0, 0.0);
2838 let v_0p = eval_cost(0.0, 1.0);
2839 let v_0m = eval_cost(0.0, -1.0);
2840 let v_pp = eval_cost(1.0, 1.0);
2841 let v_pm = eval_cost(1.0, -1.0);
2842 let v_mp = eval_cost(-1.0, 1.0);
2843 let v_mm = eval_cost(-1.0, -1.0);
2844
2845 let h00_fd = (v_p0 - 2.0 * v_00 + v_m0) / (steps[0] * steps[0]);
2846 let h11_fd = (v_0p - 2.0 * v_00 + v_0m) / (steps[1] * steps[1]);
2847 let h01_fd = (v_pp - v_pm - v_mp + v_mm) / (4.0 * steps[0] * steps[1]);
2848
2849 let h_tt_fd = array![[h00_fd, h01_fd], [h01_fd, h11_fd]];
2850
2851 let num = (&h_tt_analytic - &h_tt_fd)
2852 .iter()
2853 .map(|v| v * v)
2854 .sum::<f64>()
2855 .sqrt();
2856 let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2857 let rel = num / den;
2858
2859 assert!(
2860 rel < 3e-3,
2861 "Binomial-logit n=30 rank-deficient Hessian vs cost-FD mismatch: rel={rel:.3e}, \
2862 analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2863 );
2864 }
2865}
2866
/// Selects which geometry backend a REML evaluation uses.
///
/// NOTE(review): the variant descriptions below are read off the variant names
/// only; confirm against the code that matches on this enum.
#[derive(Clone, Copy, Debug)]
pub(crate) enum RemlGeometry {
    /// Dense path (presumably based on a spectral decomposition).
    DenseSpectral,
    /// Sparse path (presumably an exact factorization of an SPD system).
    SparseExactSpd,
}
2872
/// Implemented by penalized-geometry objects so callers can ask which backend
/// they represent.
trait PenalizedGeometry {
    /// Reports the backend kind of this geometry.
    fn backend_kind(&self) -> GeometryBackendKind;
}
2876
/// Storage representations for a hyper-parameter derivative matrix.
///
/// Every variant's payload is handled uniformly through `storage_dispatch!`.
#[derive(Clone)]
pub(crate) enum DerivativeMatrixStorage {
    /// Fully materialized dense matrix.
    Dense(Array2<f64>),
    /// Structural zero that records only its shape.
    Zero(ZeroDerivativeMatrix),
    /// Dense local block together with the global range it belongs to.
    Embedded(EmbeddedDerivativeMatrix),
    /// Operator-backed derivative (first or second order) that is materialized on demand.
    Implicit(ImplicitDerivativeOp),
    /// Operator-backed derivative along one flattened latent-coordinate axis.
    LatentCoord(LatentCoordDerivativeOp),
}
2885
2886trait DerivativeStorageBackend {
2898 fn resident_byte_count(&self) -> usize;
2899 fn design_nrows(&self) -> usize;
2900 fn design_ncols(&self) -> usize;
2901 fn penalty_dim(&self) -> usize;
2902 fn uses_implicit_storage(&self) -> bool;
2903 fn any_nonzero(&self) -> bool;
2904 fn materialize(&self) -> Array2<f64>;
2905 fn implicit_first_axis_info(
2906 &self,
2907 ) -> Option<(
2908 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
2909 usize,
2910 )>;
2911 fn implicit_axis_count_hint(&self) -> Option<usize>;
2912 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;
2913 fn design_transpose_mul_original(
2914 &self,
2915 v: &Array1<f64>,
2916 ) -> Result<Array1<f64>, EstimationError>;
2917 fn design_transformed(
2918 &self,
2919 qs: &Array2<f64>,
2920 free_basis_opt: Option<&Array2<f64>>,
2921 ) -> Result<Array2<f64>, EstimationError>;
2922 fn design_transformed_forward_mul(
2926 &self,
2927 qs: &Array2<f64>,
2928 free_basis_opt: Option<&Array2<f64>>,
2929 u: &Array1<f64>,
2930 ) -> Result<Array1<f64>, EstimationError> {
2931 Ok(self.design_transformed(qs, free_basis_opt)?.dot(u))
2932 }
2933 fn design_transformed_transpose_mul(
2936 &self,
2937 qs: &Array2<f64>,
2938 free_basis_opt: Option<&Array2<f64>>,
2939 v: &Array1<f64>,
2940 ) -> Result<Array1<f64>, EstimationError> {
2941 Ok(self.design_transformed(qs, free_basis_opt)?.t().dot(v))
2942 }
2943 fn penalty_transformed(
2944 &self,
2945 qs: &Array2<f64>,
2946 free_basis_opt: Option<&Array2<f64>>,
2947 ) -> Result<Array2<f64>, EstimationError>;
2948 fn penalty_scaled_add_to(
2949 &self,
2950 target: &mut Array2<f64>,
2951 amp: f64,
2952 ) -> Result<(), EstimationError>;
2953}
2954
// Runs the same `$body` for whichever `DerivativeMatrixStorage` variant
// `$scrutinee` holds, binding that variant's payload to `$backend`.
//
// The match has no wildcard arm, so adding a variant to the enum fails to
// compile until an arm is added here as well.
macro_rules! storage_dispatch {
    ($scrutinee:expr, $backend:ident => $body:expr) => {
        match $scrutinee {
            DerivativeMatrixStorage::Dense($backend) => $body,
            DerivativeMatrixStorage::Zero($backend) => $body,
            DerivativeMatrixStorage::Embedded($backend) => $body,
            DerivativeMatrixStorage::Implicit($backend) => $body,
            DerivativeMatrixStorage::LatentCoord($backend) => $body,
        }
    };
}
2970
/// Structural all-zero derivative matrix: it stores only its shape and never
/// allocates entries.
#[derive(Clone)]
pub(crate) struct ZeroDerivativeMatrix {
    /// Number of rows of the represented zero matrix.
    rows: usize,
    /// Number of columns of the represented zero matrix.
    cols: usize,
}

impl ZeroDerivativeMatrix {
    /// Creates a `rows` x `cols` structural zero.
    pub(crate) fn new(rows: usize, cols: usize) -> Self {
        ZeroDerivativeMatrix { rows, cols }
    }
}
2982
/// Which derivative of the implicit design operator an `ImplicitDerivativeOp`
/// represents.
#[derive(Clone, Copy, Debug)]
pub enum ImplicitDerivLevel {
    /// First derivative along the given axis (routed to `materialize_first` / `forward_mul`).
    First(usize),
    /// Second derivative taken twice along the given axis (routed to the `*_second_diag` calls).
    SecondDiag(usize),
    /// Mixed second derivative along the two given axes (routed to the `*_second_cross` calls).
    SecondCross(usize, usize),
}
2993
/// Operator-backed derivative matrix: products are delegated to a shared
/// implicit operator, and the dense form is built only when requested.
#[derive(Clone)]
pub(crate) struct ImplicitDerivativeOp {
    /// Shared operator that evaluates the local derivative block.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
    /// Which derivative (first, second-diagonal, second-cross) this op represents.
    pub(crate) level: ImplicitDerivLevel,
    /// Column range of the local block inside the full-width matrix.
    pub(crate) global_range: Range<usize>,
    /// Column count of the full-width (embedded) matrix.
    pub(crate) total_dim: usize,
    /// Compute-once cache for the embedded dense matrix; because it sits
    /// behind an `Arc`, clones of this op share the same cache.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3013
/// Operator-backed derivative matrix along one flattened latent-coordinate
/// axis; products are delegated to a shared operator and the dense form is
/// built only when requested.
#[derive(Clone)]
pub(crate) struct LatentCoordDerivativeOp {
    /// Shared operator that evaluates the local derivative block.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
    /// Flattened axis index passed to every operator call.
    pub(crate) flat_axis: usize,
    /// Column range of the local block inside the full-width matrix.
    pub(crate) global_range: Range<usize>,
    /// Column count of the full-width (embedded) matrix.
    pub(crate) total_dim: usize,
    /// Compute-once cache for the embedded dense matrix; because it sits
    /// behind an `Arc`, clones of this op share the same cache.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3022
3023impl LatentCoordDerivativeOp {
3024 pub(crate) fn materialize_local(&self) -> Array2<f64> {
3025 self.operator.materialize_axis(self.flat_axis).expect(
3026 "radial scalar evaluation failed during latent-coordinate derivative materialization",
3027 )
3028 }
3029
3030 pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3031 self.cached_dense.get_or_compute(|| {
3032 let local = self.materialize_local();
3033 let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3034 out.slice_mut(s![.., self.global_range.clone()])
3035 .assign(&local);
3036 out
3037 })
3038 }
3039
3040 pub(crate) fn nrows(&self) -> usize {
3041 self.operator.n_data()
3042 }
3043
3044 pub(crate) fn ncols(&self) -> usize {
3045 self.total_dim
3046 }
3047
3048 pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3049 let local = self
3050 .operator
3051 .transpose_mul_axis(self.flat_axis, &v.view())
3052 .expect(
3053 "radial scalar evaluation failed during latent-coordinate derivative transpose_mul",
3054 );
3055 let mut out = Array1::<f64>::zeros(self.total_dim);
3056 out.slice_mut(s![self.global_range.clone()]).assign(&local);
3057 out
3058 }
3059
3060 pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3061 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3062 self.operator
3063 .forward_mul_axis(self.flat_axis, &u_local.view())
3064 .expect(
3065 "radial scalar evaluation failed during latent-coordinate derivative forward_mul",
3066 )
3067 }
3068}
3069
3070impl ImplicitDerivativeOp {
3071 pub(crate) fn materialize_local(&self) -> Array2<f64> {
3072 match self.level {
3073 ImplicitDerivLevel::First(axis) => self.operator.materialize_first(axis).expect(
3074 "radial scalar evaluation failed during implicit derivative materialization",
3075 ),
3076 ImplicitDerivLevel::SecondDiag(axis) => {
3077 self.operator.materialize_second_diag(axis).expect(
3078 "radial scalar evaluation failed during implicit derivative materialization",
3079 )
3080 }
3081 ImplicitDerivLevel::SecondCross(d, e) => {
3082 self.operator.materialize_second_cross(d, e).expect(
3083 "radial scalar evaluation failed during implicit derivative materialization",
3084 )
3085 }
3086 }
3087 }
3088
3089 pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3090 self.cached_dense.get_or_compute(|| {
3091 let local = self.materialize_local();
3092 let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3093 out.slice_mut(s![.., self.global_range.clone()])
3094 .assign(&local);
3095 out
3096 })
3097 }
3098
3099 pub(crate) fn nrows(&self) -> usize {
3100 self.operator.n_data()
3101 }
3102
3103 pub(crate) fn ncols(&self) -> usize {
3104 self.total_dim
3105 }
3106
3107 pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3108 let local = match self.level {
3109 ImplicitDerivLevel::First(axis) => self
3110 .operator
3111 .transpose_mul(axis, &v.view())
3112 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3113 ImplicitDerivLevel::SecondDiag(axis) => self
3114 .operator
3115 .transpose_mul_second_diag(axis, &v.view())
3116 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3117 ImplicitDerivLevel::SecondCross(d, e) => self
3118 .operator
3119 .transpose_mul_second_cross(d, e, &v.view())
3120 .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3121 };
3122 let mut out = Array1::<f64>::zeros(self.total_dim);
3123 out.slice_mut(s![self.global_range.clone()]).assign(&local);
3124 out
3125 }
3126
3127 pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3128 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3129 match self.level {
3130 ImplicitDerivLevel::First(axis) => self
3131 .operator
3132 .forward_mul(axis, &u_local.view())
3133 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3134 ImplicitDerivLevel::SecondDiag(axis) => self
3135 .operator
3136 .forward_mul_second_diag(axis, &u_local.view())
3137 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3138 ImplicitDerivLevel::SecondCross(d, e) => self
3139 .operator
3140 .forward_mul_second_cross(d, e, &u_local.view())
3141 .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3142 }
3143 }
3144}
3145
3146#[derive(Clone)]
3147pub(crate) struct EmbeddedDerivativeMatrix {
3148 pub(crate) local: Array2<f64>,
3149 pub(crate) global_range: Range<usize>,
3150 pub(crate) total_dim: usize,
3151}
3152
3153impl EmbeddedDerivativeMatrix {
3154 pub(crate) fn new(local: Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
3155 Self {
3156 local,
3157 global_range,
3158 total_dim,
3159 }
3160 }
3161}
3162
3163impl DerivativeStorageBackend for Array2<f64> {
3164 fn resident_byte_count(&self) -> usize {
3165 self.len().saturating_mul(std::mem::size_of::<f64>())
3166 }
3167 fn design_nrows(&self) -> usize {
3168 Array2::nrows(self)
3169 }
3170 fn design_ncols(&self) -> usize {
3171 Array2::ncols(self)
3172 }
3173 fn penalty_dim(&self) -> usize {
3174 Array2::nrows(self)
3175 }
3176 fn uses_implicit_storage(&self) -> bool {
3177 false
3178 }
3179 fn any_nonzero(&self) -> bool {
3180 self.iter().any(|v| *v != 0.0)
3181 }
3182 fn materialize(&self) -> Array2<f64> {
3183 self.clone()
3184 }
3185 fn implicit_first_axis_info(
3186 &self,
3187 ) -> Option<(
3188 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3189 usize,
3190 )> {
3191 None
3192 }
3193 fn implicit_axis_count_hint(&self) -> Option<usize> {
3194 None
3195 }
3196
3197 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3198 if Array2::ncols(self) != u.len() {
3199 crate::bail_invalid_estim!(
3200 "dense hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3201 Array2::nrows(self),
3202 Array2::ncols(self),
3203 u.len()
3204 );
3205 }
3206 Ok(self.dot(u))
3207 }
3208
3209 fn design_transpose_mul_original(
3210 &self,
3211 v: &Array1<f64>,
3212 ) -> Result<Array1<f64>, EstimationError> {
3213 if Array2::nrows(self) != v.len() {
3214 crate::bail_invalid_estim!(
3215 "dense hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3216 Array2::nrows(self),
3217 Array2::ncols(self),
3218 v.len()
3219 );
3220 }
3221 Ok(self.t().dot(v))
3222 }
3223
3224 fn design_transformed(
3225 &self,
3226 qs: &Array2<f64>,
3227 free_basis_opt: Option<&Array2<f64>>,
3228 ) -> Result<Array2<f64>, EstimationError> {
3229 Ok(gam_linalg::matrix::DenseRightProductView::new(self)
3230 .with_factor(qs)
3231 .with_optional_factor(free_basis_opt)
3232 .materialize())
3233 }
3234
3235 fn penalty_transformed(
3236 &self,
3237 qs: &Array2<f64>,
3238 free_basis_opt: Option<&Array2<f64>>,
3239 ) -> Result<Array2<f64>, EstimationError> {
3240 let mut transformed = qs.t().dot(self).dot(qs);
3241 if let Some(z) = free_basis_opt {
3242 transformed = z.t().dot(&transformed).dot(z);
3243 }
3244 Ok(transformed)
3245 }
3246
3247 fn penalty_scaled_add_to(
3248 &self,
3249 target: &mut Array2<f64>,
3250 amp: f64,
3251 ) -> Result<(), EstimationError> {
3252 if target.raw_dim() != self.raw_dim() {
3253 crate::bail_invalid_estim!(
3254 "dense hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3255 target.nrows(),
3256 target.ncols(),
3257 Array2::nrows(self),
3258 Array2::ncols(self)
3259 );
3260 }
3261 target.scaled_add(amp, self);
3262 Ok(())
3263 }
3264}
3265
3266impl DerivativeStorageBackend for ZeroDerivativeMatrix {
3267 fn resident_byte_count(&self) -> usize {
3268 0
3269 }
3270 fn design_nrows(&self) -> usize {
3271 self.rows
3272 }
3273 fn design_ncols(&self) -> usize {
3274 self.cols
3275 }
3276 fn penalty_dim(&self) -> usize {
3277 self.cols
3278 }
3279 fn uses_implicit_storage(&self) -> bool {
3280 false
3281 }
3282 fn any_nonzero(&self) -> bool {
3283 false
3284 }
3285 fn materialize(&self) -> Array2<f64> {
3286 Array2::<f64>::zeros((self.rows, self.cols))
3287 }
3288 fn implicit_first_axis_info(
3289 &self,
3290 ) -> Option<(
3291 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3292 usize,
3293 )> {
3294 None
3295 }
3296 fn implicit_axis_count_hint(&self) -> Option<usize> {
3297 None
3298 }
3299
3300 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3301 if self.cols != u.len() {
3302 crate::bail_invalid_estim!(
3303 "zero hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3304 self.rows,
3305 self.cols,
3306 u.len()
3307 );
3308 }
3309 Ok(Array1::<f64>::zeros(self.rows))
3310 }
3311
3312 fn design_transpose_mul_original(
3313 &self,
3314 v: &Array1<f64>,
3315 ) -> Result<Array1<f64>, EstimationError> {
3316 if self.rows != v.len() {
3317 crate::bail_invalid_estim!(
3318 "zero hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3319 self.rows,
3320 self.cols,
3321 v.len()
3322 );
3323 }
3324 Ok(Array1::<f64>::zeros(self.cols))
3325 }
3326
3327 fn design_transformed(
3328 &self,
3329 qs: &Array2<f64>,
3330 free_basis_opt: Option<&Array2<f64>>,
3331 ) -> Result<Array2<f64>, EstimationError> {
3332 if self.cols != qs.nrows() {
3333 crate::bail_invalid_estim!(
3334 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3335 self.cols,
3336 qs.nrows()
3337 );
3338 }
3339 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3340 Ok(Array2::<f64>::zeros((self.rows, cols)))
3341 }
3342
3343 fn design_transformed_forward_mul(
3344 &self,
3345 qs: &Array2<f64>,
3346 free_basis_opt: Option<&Array2<f64>>,
3347 u: &Array1<f64>,
3348 ) -> Result<Array1<f64>, EstimationError> {
3349 if self.cols != qs.nrows() {
3350 crate::bail_invalid_estim!(
3351 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3352 self.cols,
3353 qs.nrows()
3354 );
3355 }
3356 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3357 if u.len() != cols {
3358 crate::bail_invalid_estim!(
3359 "zero design derivative transformed forward width mismatch: expected {}, vector={}",
3360 cols,
3361 u.len()
3362 );
3363 }
3364 Ok(Array1::<f64>::zeros(self.rows))
3365 }
3366
3367 fn design_transformed_transpose_mul(
3368 &self,
3369 qs: &Array2<f64>,
3370 free_basis_opt: Option<&Array2<f64>>,
3371 v: &Array1<f64>,
3372 ) -> Result<Array1<f64>, EstimationError> {
3373 if self.rows != v.len() {
3374 crate::bail_invalid_estim!(
3375 "zero design derivative transpose height mismatch: matrix rows={}, vector={}",
3376 self.rows,
3377 v.len()
3378 );
3379 }
3380 if self.cols != qs.nrows() {
3381 crate::bail_invalid_estim!(
3382 "zero design derivative width mismatch: total_cols={}, qs rows={}",
3383 self.cols,
3384 qs.nrows()
3385 );
3386 }
3387 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3388 Ok(Array1::<f64>::zeros(cols))
3389 }
3390
3391 fn penalty_transformed(
3392 &self,
3393 qs: &Array2<f64>,
3394 free_basis_opt: Option<&Array2<f64>>,
3395 ) -> Result<Array2<f64>, EstimationError> {
3396 if self.cols != qs.nrows() {
3397 crate::bail_invalid_estim!(
3398 "zero penalty derivative width mismatch: total_dim={}, qs rows={}",
3399 self.cols,
3400 qs.nrows()
3401 );
3402 }
3403 let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3404 Ok(Array2::<f64>::zeros((cols, cols)))
3405 }
3406
3407 fn penalty_scaled_add_to(
3408 &self,
3409 target: &mut Array2<f64>,
3410 amp: f64,
3411 ) -> Result<(), EstimationError> {
3412 if !amp.is_finite() {
3416 crate::bail_invalid_estim!(
3417 "zero hyper penalty derivative received non-finite amp={amp}"
3418 );
3419 }
3420 if target.nrows() != self.cols || target.ncols() != self.cols {
3421 crate::bail_invalid_estim!(
3422 "zero hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3423 target.nrows(),
3424 target.ncols(),
3425 self.cols,
3426 self.cols
3427 );
3428 }
3429 Ok(())
3430 }
3431}
3432
3433impl DerivativeStorageBackend for EmbeddedDerivativeMatrix {
3434 fn resident_byte_count(&self) -> usize {
3435 self.local.len().saturating_mul(std::mem::size_of::<f64>())
3436 }
3437 fn design_nrows(&self) -> usize {
3438 self.local.nrows()
3439 }
3440 fn design_ncols(&self) -> usize {
3441 self.total_dim
3442 }
3443 fn penalty_dim(&self) -> usize {
3444 self.total_dim
3445 }
3446 fn uses_implicit_storage(&self) -> bool {
3447 false
3448 }
3449 fn any_nonzero(&self) -> bool {
3450 self.local.iter().any(|v| *v != 0.0)
3451 }
3452 fn materialize(&self) -> Array2<f64> {
3453 let mut dense = Array2::<f64>::zeros((self.local.nrows(), self.total_dim));
3454 dense
3455 .slice_mut(s![.., self.global_range.clone()])
3456 .assign(&self.local);
3457 dense
3458 }
3459 fn implicit_first_axis_info(
3460 &self,
3461 ) -> Option<(
3462 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3463 usize,
3464 )> {
3465 None
3466 }
3467 fn implicit_axis_count_hint(&self) -> Option<usize> {
3468 None
3469 }
3470
3471 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3472 if self.total_dim != u.len() {
3473 crate::bail_invalid_estim!(
3474 "embedded hyper design derivative forward_mul_original width mismatch: total_dim={}, vector={}",
3475 self.total_dim,
3476 u.len()
3477 );
3478 }
3479 let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3480 Ok(self.local.dot(&u_local))
3481 }
3482
3483 fn design_transpose_mul_original(
3484 &self,
3485 v: &Array1<f64>,
3486 ) -> Result<Array1<f64>, EstimationError> {
3487 if self.local.nrows() != v.len() {
3488 crate::bail_invalid_estim!(
3489 "embedded hyper design derivative transpose_mul_original height mismatch: local_rows={}, vector={}",
3490 self.local.nrows(),
3491 v.len()
3492 );
3493 }
3494 let mut out = Array1::<f64>::zeros(self.total_dim);
3495 let pulled = self.local.t().dot(v);
3496 out.slice_mut(s![self.global_range.clone()]).assign(&pulled);
3497 Ok(out)
3498 }
3499
3500 fn design_transformed(
3501 &self,
3502 qs: &Array2<f64>,
3503 free_basis_opt: Option<&Array2<f64>>,
3504 ) -> Result<Array2<f64>, EstimationError> {
3505 if self.total_dim != qs.nrows() {
3506 crate::bail_invalid_estim!(
3507 "embedded design derivative width mismatch: total_cols={}, qs rows={}",
3508 self.total_dim,
3509 qs.nrows()
3510 );
3511 }
3512 let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3513 let mut transformed = self.local.dot(&qs_local);
3514 if let Some(z) = free_basis_opt {
3515 transformed = transformed.dot(z);
3516 }
3517 Ok(transformed)
3518 }
3519
3520 fn penalty_transformed(
3521 &self,
3522 qs: &Array2<f64>,
3523 free_basis_opt: Option<&Array2<f64>>,
3524 ) -> Result<Array2<f64>, EstimationError> {
3525 if self.total_dim != qs.nrows() {
3526 crate::bail_invalid_estim!(
3527 "embedded penalty derivative width mismatch: total_dim={}, qs rows={}",
3528 self.total_dim,
3529 qs.nrows()
3530 );
3531 }
3532 let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3533 let mut transformed = qs_local.t().dot(&self.local).dot(&qs_local);
3534 if let Some(z) = free_basis_opt {
3535 transformed = z.t().dot(&transformed).dot(z);
3536 }
3537 Ok(transformed)
3538 }
3539
3540 fn penalty_scaled_add_to(
3541 &self,
3542 target: &mut Array2<f64>,
3543 amp: f64,
3544 ) -> Result<(), EstimationError> {
3545 if target.nrows() != self.total_dim || target.ncols() != self.total_dim {
3546 crate::bail_invalid_estim!(
3547 "embedded hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3548 target.nrows(),
3549 target.ncols(),
3550 self.total_dim,
3551 self.total_dim
3552 );
3553 }
3554 target
3555 .slice_mut(s![self.global_range.clone(), self.global_range.clone()])
3556 .scaled_add(amp, &self.local);
3557 Ok(())
3558 }
3559}
3560
3561impl DerivativeStorageBackend for ImplicitDerivativeOp {
3562 fn resident_byte_count(&self) -> usize {
3563 0
3564 }
3565 fn design_nrows(&self) -> usize {
3566 self.nrows()
3567 }
3568 fn design_ncols(&self) -> usize {
3569 self.ncols()
3570 }
3571 fn penalty_dim(&self) -> usize {
3572 self.nrows()
3573 }
3574 fn uses_implicit_storage(&self) -> bool {
3575 true
3576 }
3577 fn any_nonzero(&self) -> bool {
3578 true
3579 }
3580 fn materialize(&self) -> Array2<f64> {
3581 self.materialize_dense().clone()
3582 }
3583 fn implicit_first_axis_info(
3584 &self,
3585 ) -> Option<(
3586 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3587 usize,
3588 )> {
3589 match self.level {
3590 ImplicitDerivLevel::First(axis) => Some((self.operator.clone(), axis)),
3591 _ => None,
3592 }
3593 }
3594 fn implicit_axis_count_hint(&self) -> Option<usize> {
3595 Some(self.operator.n_axes())
3596 }
3597
3598 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3599 if self.ncols() != u.len() {
3600 crate::bail_invalid_estim!(
3601 "implicit hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3602 self.ncols(),
3603 u.len()
3604 );
3605 }
3606 Ok(self.forward_mul(u))
3607 }
3608
3609 fn design_transpose_mul_original(
3610 &self,
3611 v: &Array1<f64>,
3612 ) -> Result<Array1<f64>, EstimationError> {
3613 if self.nrows() != v.len() {
3614 crate::bail_invalid_estim!(
3615 "implicit hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3616 self.nrows(),
3617 v.len()
3618 );
3619 }
3620 Ok(self.transpose_mul(v))
3621 }
3622
3623 fn design_transformed(
3624 &self,
3625 qs: &Array2<f64>,
3626 free_basis_opt: Option<&Array2<f64>>,
3627 ) -> Result<Array2<f64>, EstimationError> {
3628 let dense = self.materialize_dense();
3629 Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3630 .with_factor(qs)
3631 .with_optional_factor(free_basis_opt)
3632 .materialize())
3633 }
3634
3635 fn design_transformed_forward_mul(
3636 &self,
3637 qs: &Array2<f64>,
3638 free_basis_opt: Option<&Array2<f64>>,
3639 u: &Array1<f64>,
3640 ) -> Result<Array1<f64>, EstimationError> {
3641 let mut right = if let Some(z) = free_basis_opt {
3642 z.dot(u)
3643 } else {
3644 u.clone()
3645 };
3646 right = qs.dot(&right);
3647 Ok(self.forward_mul(&right))
3648 }
3649
3650 fn design_transformed_transpose_mul(
3651 &self,
3652 qs: &Array2<f64>,
3653 free_basis_opt: Option<&Array2<f64>>,
3654 v: &Array1<f64>,
3655 ) -> Result<Array1<f64>, EstimationError> {
3656 let mut pulled = qs.t().dot(&self.transpose_mul(v));
3657 if let Some(z) = free_basis_opt {
3658 pulled = z.t().dot(&pulled);
3659 }
3660 Ok(pulled)
3661 }
3662
3663 fn penalty_transformed(
3664 &self,
3665 qs: &Array2<f64>,
3666 free_basis_opt: Option<&Array2<f64>>,
3667 ) -> Result<Array2<f64>, EstimationError> {
3668 let dense = self.materialize_dense();
3669 let mut transformed = qs.t().dot(dense).dot(qs);
3670 if let Some(z) = free_basis_opt {
3671 transformed = z.t().dot(&transformed).dot(z);
3672 }
3673 Ok(transformed)
3674 }
3675
3676 fn penalty_scaled_add_to(
3677 &self,
3678 target: &mut Array2<f64>,
3679 amp: f64,
3680 ) -> Result<(), EstimationError> {
3681 let dense = self.materialize_dense();
3682 if target.raw_dim() != dense.raw_dim() {
3683 crate::bail_invalid_estim!(
3684 "implicit hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3685 target.nrows(),
3686 target.ncols(),
3687 dense.nrows(),
3688 dense.ncols()
3689 );
3690 }
3691 target.scaled_add(amp, dense);
3692 Ok(())
3693 }
3694}
3695
3696impl DerivativeStorageBackend for LatentCoordDerivativeOp {
3697 fn resident_byte_count(&self) -> usize {
3698 0
3699 }
3700 fn design_nrows(&self) -> usize {
3701 self.nrows()
3702 }
3703 fn design_ncols(&self) -> usize {
3704 self.ncols()
3705 }
3706 fn penalty_dim(&self) -> usize {
3707 self.nrows()
3708 }
3709 fn uses_implicit_storage(&self) -> bool {
3710 true
3711 }
3712 fn any_nonzero(&self) -> bool {
3713 true
3714 }
3715 fn materialize(&self) -> Array2<f64> {
3716 self.materialize_dense().clone()
3717 }
3718 fn implicit_first_axis_info(
3719 &self,
3720 ) -> Option<(
3721 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3722 usize,
3723 )> {
3724 None
3725 }
3726 fn implicit_axis_count_hint(&self) -> Option<usize> {
3727 Some(self.operator.n_axes())
3728 }
3729
3730 fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3731 if self.ncols() != u.len() {
3732 crate::bail_invalid_estim!(
3733 "latent-coordinate hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3734 self.ncols(),
3735 u.len()
3736 );
3737 }
3738 Ok(self.forward_mul(u))
3739 }
3740
3741 fn design_transpose_mul_original(
3742 &self,
3743 v: &Array1<f64>,
3744 ) -> Result<Array1<f64>, EstimationError> {
3745 if self.nrows() != v.len() {
3746 crate::bail_invalid_estim!(
3747 "latent-coordinate hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3748 self.nrows(),
3749 v.len()
3750 );
3751 }
3752 Ok(self.transpose_mul(v))
3753 }
3754
3755 fn design_transformed(
3756 &self,
3757 qs: &Array2<f64>,
3758 free_basis_opt: Option<&Array2<f64>>,
3759 ) -> Result<Array2<f64>, EstimationError> {
3760 let dense = self.materialize_dense();
3761 Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3762 .with_factor(qs)
3763 .with_optional_factor(free_basis_opt)
3764 .materialize())
3765 }
3766
3767 fn design_transformed_forward_mul(
3768 &self,
3769 qs: &Array2<f64>,
3770 free_basis_opt: Option<&Array2<f64>>,
3771 u: &Array1<f64>,
3772 ) -> Result<Array1<f64>, EstimationError> {
3773 let mut right = if let Some(z) = free_basis_opt {
3774 z.dot(u)
3775 } else {
3776 u.clone()
3777 };
3778 right = qs.dot(&right);
3779 Ok(self.forward_mul(&right))
3780 }
3781
3782 fn design_transformed_transpose_mul(
3783 &self,
3784 qs: &Array2<f64>,
3785 free_basis_opt: Option<&Array2<f64>>,
3786 v: &Array1<f64>,
3787 ) -> Result<Array1<f64>, EstimationError> {
3788 let mut pulled = qs.t().dot(&self.transpose_mul(v));
3789 if let Some(z) = free_basis_opt {
3790 pulled = z.t().dot(&pulled);
3791 }
3792 Ok(pulled)
3793 }
3794
3795 fn penalty_transformed(
3796 &self,
3797 qs: &Array2<f64>,
3798 free_basis_opt: Option<&Array2<f64>>,
3799 ) -> Result<Array2<f64>, EstimationError> {
3800 let dense = self.materialize_dense();
3801 let mut transformed = qs.t().dot(dense).dot(qs);
3802 if let Some(z) = free_basis_opt {
3803 transformed = z.t().dot(&transformed).dot(z);
3804 }
3805 Ok(transformed)
3806 }
3807
3808 fn penalty_scaled_add_to(
3809 &self,
3810 target: &mut Array2<f64>,
3811 amp: f64,
3812 ) -> Result<(), EstimationError> {
3813 let dense = self.materialize_dense();
3814 if target.raw_dim() != dense.raw_dim() {
3815 crate::bail_invalid_estim!(
3816 "latent-coordinate hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3817 target.nrows(),
3818 target.ncols(),
3819 dense.nrows(),
3820 dense.ncols()
3821 );
3822 }
3823 target.scaled_add(amp, dense);
3824 Ok(())
3825 }
3826}
3827
/// Design-matrix derivative with respect to a hyperparameter, held behind one
/// of the `DerivativeMatrixStorage` variants.
#[derive(Clone)]
pub struct HyperDesignDerivative {
    /// Backing representation; every method on this type dispatches on it.
    pub(crate) storage: DerivativeMatrixStorage,
}
3832
/// Constructors for each storage variant, plus accessors that forward to the
/// matching `DerivativeStorageBackend` method through `storage_dispatch!`.
impl HyperDesignDerivative {
    /// Builds an `nrows` x `ncols` derivative backed by `ZeroDerivativeMatrix`.
    pub fn zero(nrows: usize, ncols: usize) -> Self {
        Self {
            storage: DerivativeMatrixStorage::Zero(ZeroDerivativeMatrix::new(nrows, ncols)),
        }
    }

    /// Builds a derivative backed by `EmbeddedDerivativeMatrix::new(local,
    /// global_range, total_cols)`.
    pub fn from_embedded(
        local: Array2<f64>,
        global_range: Range<usize>,
        total_cols: usize,
    ) -> Self {
        Self {
            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
                local,
                global_range,
                total_cols,
            )),
        }
    }

    /// Builds a derivative backed by an `ImplicitDerivativeOp` over the shared
    /// `operator` at the given `level`.
    ///
    /// `total_cols` is stored as the op's `total_dim`, and the dense cache
    /// starts out as a fresh, empty `RayonSafeOnce`.
    pub fn from_implicit(
        operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
        level: ImplicitDerivLevel,
        global_range: Range<usize>,
        total_cols: usize,
    ) -> Self {
        Self {
            storage: DerivativeMatrixStorage::Implicit(ImplicitDerivativeOp {
                operator,
                level,
                global_range,
                total_dim: total_cols,
                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
            }),
        }
    }

    /// Builds a derivative backed by a `LatentCoordDerivativeOp` over the
    /// shared `operator` for the given `flat_axis`.
    ///
    /// `total_cols` is stored as the op's `total_dim`, and the dense cache
    /// starts out as a fresh, empty `RayonSafeOnce`.
    pub fn from_latent_coord(
        operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
        flat_axis: usize,
        global_range: Range<usize>,
        total_cols: usize,
    ) -> Self {
        Self {
            storage: DerivativeMatrixStorage::LatentCoord(LatentCoordDerivativeOp {
                operator,
                flat_axis,
                global_range,
                total_dim: total_cols,
                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
            }),
        }
    }

    /// Forwards to the backend's `resident_byte_count`.
    pub(crate) fn resident_byte_count(&self) -> usize {
        storage_dispatch!(&self.storage, b => b.resident_byte_count())
    }

    /// Forwards to the backend's `design_nrows`.
    pub(crate) fn nrows(&self) -> usize {
        storage_dispatch!(&self.storage, b => b.design_nrows())
    }

    /// Forwards to the backend's `design_ncols`.
    pub(crate) fn ncols(&self) -> usize {
        storage_dispatch!(&self.storage, b => b.design_ncols())
    }

    /// Forwards to the backend's `uses_implicit_storage`.
    pub(crate) fn uses_implicit_storage(&self) -> bool {
        storage_dispatch!(&self.storage, b => b.uses_implicit_storage())
    }

    /// Forwards to the backend's `materialize`, yielding an owned dense matrix.
    pub(crate) fn materialize(&self) -> Array2<f64> {
        storage_dispatch!(&self.storage, b => b.materialize())
    }

    /// Forwards to the backend's `any_nonzero`.
    pub(crate) fn any_nonzero(&self) -> bool {
        storage_dispatch!(&self.storage, b => b.any_nonzero())
    }

    /// Forwards to the backend's `design_forward_mul_original`.
    pub(crate) fn forward_mul_original(
        &self,
        u: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_forward_mul_original(u))
    }

    /// Forwards to the backend's `design_transpose_mul_original`.
    pub(crate) fn transpose_mul_original(
        &self,
        v: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transpose_mul_original(v))
    }

    /// Forwards to the backend's `design_transformed`.
    pub(crate) fn transformed(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
    ) -> Result<Array2<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed(qs, free_basis_opt))
    }

    /// Forwards to the backend's `design_transformed_forward_mul`.
    pub(crate) fn transformed_forward_mul(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
        u: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed_forward_mul(qs, free_basis_opt, u))
    }

    /// Forwards to the backend's `design_transformed_transpose_mul`.
    pub(crate) fn transformed_transpose_mul(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
        v: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed_transpose_mul(qs, free_basis_opt, v))
    }

    /// Forwards to the backend's `implicit_first_axis_info`.
    pub(crate) fn implicit_first_axis_info(
        &self,
    ) -> Option<(
        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
        usize,
    )> {
        storage_dispatch!(&self.storage, b => b.implicit_first_axis_info())
    }

    /// Forwards to the backend's `implicit_axis_count_hint`.
    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
        storage_dispatch!(&self.storage, b => b.implicit_axis_count_hint())
    }
}
3969
3970impl From<Array2<f64>> for HyperDesignDerivative {
3971 fn from(value: Array2<f64>) -> Self {
3972 Self {
3973 storage: DerivativeMatrixStorage::Dense(value),
3974 }
3975 }
3976}
3977
/// Penalty-matrix derivative with respect to a hyperparameter, held behind one
/// of the `DerivativeMatrixStorage` variants.
#[derive(Clone)]
pub struct HyperPenaltyDerivative {
    /// Backing representation; every method on this type dispatches on it.
    pub(crate) storage: DerivativeMatrixStorage,
}
3982
3983impl HyperPenaltyDerivative {
3984 pub fn from_embedded(
3985 local: Array2<f64>,
3986 global_range: Range<usize>,
3987 total_dim: usize,
3988 ) -> Self {
3989 Self {
3990 storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3991 local,
3992 global_range,
3993 total_dim,
3994 )),
3995 }
3996 }
3997
3998 pub(crate) fn resident_byte_count(&self) -> usize {
3999 storage_dispatch!(&self.storage, b => b.resident_byte_count())
4000 }
4001
4002 pub(crate) fn nrows(&self) -> usize {
4003 storage_dispatch!(&self.storage, b => b.penalty_dim())
4004 }
4005
4006 pub(crate) fn ncols(&self) -> usize {
4007 self.nrows()
4008 }
4009
4010 pub(crate) fn scaled_materialize(&self, amp: f64) -> Array2<f64> {
4011 let mut out = Array2::<f64>::zeros((self.nrows(), self.ncols()));
4012 self.scaled_add_to(&mut out, amp)
4013 .expect("scaled materialize uses matching target shape");
4014 out
4015 }
4016
4017 pub(crate) fn transformed(
4018 &self,
4019 qs: &Array2<f64>,
4020 free_basis_opt: Option<&Array2<f64>>,
4021 ) -> Result<Array2<f64>, EstimationError> {
4022 storage_dispatch!(&self.storage, b => b.penalty_transformed(qs, free_basis_opt))
4023 }
4024
4025 pub(crate) fn scaled_add_to(
4026 &self,
4027 target: &mut Array2<f64>,
4028 amp: f64,
4029 ) -> Result<(), EstimationError> {
4030 storage_dispatch!(&self.storage, b => b.penalty_scaled_add_to(target, amp))
4031 }
4032}
4033
4034impl From<Array2<f64>> for HyperPenaltyDerivative {
4035 fn from(value: Array2<f64>) -> Self {
4036 Self {
4037 storage: DerivativeMatrixStorage::Dense(value),
4038 }
4039 }
4040}
4041
/// One penalty's contribution to a penalty derivative: the derivative matrix
/// together with the index of the penalty it belongs to.
#[derive(Clone)]
pub struct PenaltyDerivativeComponent {
    /// Index of the penalty; `DirectionalHyperParam::penalty_total_at` uses it
    /// to look up the matching entry of `rho`.
    pub penalty_index: usize,
    /// Derivative matrix for that penalty.
    pub matrix: HyperPenaltyDerivative,
}
4047
/// Derivative data attached to one hyperparameter direction: first- and
/// (optionally) second-order design and penalty derivatives.
///
/// NOTE(review): the indexing convention of the second-order containers (what
/// the outer index `j` ranges over) is not visible in this part of the file —
/// presumably the partner hyperparameter; confirm against the callers.
#[derive(Clone)]
pub struct DirectionalHyperParam {
    /// First-order design derivative in original coordinates.
    pub(crate) x_tau_original: HyperDesignDerivative,
    /// First-order penalty derivative, one component per affected penalty.
    pub(crate) penalty_first_components: Vec<PenaltyDerivativeComponent>,
    /// Optional second-order design derivatives, indexed by `j`; individual
    /// entries may be absent.
    pub(crate) x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
    /// Optional stored second-order penalty components, indexed by `j`.
    pub(crate) penaltysecond_components: Option<Vec<Option<Vec<PenaltyDerivativeComponent>>>>,
    /// Optional callback consulted by `penaltysecond_components_for` when no
    /// stored components exist for the requested `j`.
    pub(crate) penaltysecond_component_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
                + Send
                + Sync
                + 'static,
        >,
    >,
    /// Optional list of indices `j` that `has_penaltysecond_pair_at` reports
    /// as having a second-order pair even without stored components.
    pub(crate) penaltysecond_partner_indices: Option<std::sync::Arc<[usize]>>,
    /// Set by `new_compact` to `true` exactly when the design derivative has
    /// no non-zero entry; `not_penalty_like` forces it to `false`.
    pub(crate) is_penalty_like: bool,
}
4075
4076impl DirectionalHyperParam {
4077 pub(crate) fn resident_byte_count(&self) -> usize {
4078 let mut bytes = self.x_tau_original.resident_byte_count();
4079 for component in &self.penalty_first_components {
4080 bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4081 }
4082 if let Some(entries) = self.x_tau_tau_original.as_ref() {
4083 for entry in entries.iter().flatten() {
4084 bytes = bytes.saturating_add(entry.resident_byte_count());
4085 }
4086 }
4087 if let Some(rows) = self.penaltysecond_components.as_ref() {
4088 for components in rows.iter().flatten() {
4089 for component in components {
4090 bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4091 }
4092 }
4093 }
4094 bytes
4095 }
4096
4097 pub(crate) fn canonicalize_penalty_components(
4098 components: Vec<(usize, HyperPenaltyDerivative)>,
4099 ) -> Result<Vec<PenaltyDerivativeComponent>, EstimationError> {
4100 let mut out: Vec<PenaltyDerivativeComponent> = Vec::with_capacity(components.len());
4101 for (penalty_index, matrix) in components {
4102 if out.iter().any(|c| c.penalty_index == penalty_index) {
4103 crate::bail_invalid_estim!(
4104 "duplicate penalty derivative component for penalty {}",
4105 penalty_index
4106 );
4107 }
4108 out.push(PenaltyDerivativeComponent {
4109 penalty_index,
4110 matrix,
4111 });
4112 }
4113 Ok(out)
4114 }
4115
4116 pub fn new_compact(
4117 x_tau_original: HyperDesignDerivative,
4118 penalty_first_components: Vec<(usize, HyperPenaltyDerivative)>,
4119 x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
4120 penaltysecond_components: Option<Vec<Option<Vec<(usize, HyperPenaltyDerivative)>>>>,
4121 ) -> Result<Self, EstimationError> {
4122 let is_penalty_like = !x_tau_original.any_nonzero();
4123 let penalty_first_components =
4124 Self::canonicalize_penalty_components(penalty_first_components)?;
4125 let penaltysecond_components = match penaltysecond_components {
4126 Some(rows) => {
4127 let mut out = Vec::with_capacity(rows.len());
4128 for row in rows {
4129 out.push(match row {
4130 Some(components) => {
4131 Some(Self::canonicalize_penalty_components(components)?)
4132 }
4133 None => None,
4134 });
4135 }
4136 Some(out)
4137 }
4138 None => None,
4139 };
4140 Ok(Self {
4141 x_tau_original,
4142 penalty_first_components,
4143 x_tau_tau_original,
4144 penaltysecond_components,
4145 penaltysecond_component_provider: None,
4146 penaltysecond_partner_indices: None,
4147 is_penalty_like,
4148 })
4149 }
4150
4151 pub fn not_penalty_like(mut self) -> Self {
4154 self.is_penalty_like = false;
4155 self
4156 }
4157
4158 pub fn with_penaltysecond_component_provider(
4159 mut self,
4160 provider: std::sync::Arc<
4161 dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
4162 + Send
4163 + Sync
4164 + 'static,
4165 >,
4166 ) -> Self {
4167 self.penaltysecond_component_provider = Some(provider);
4168 self
4169 }
4170
4171 pub fn with_penaltysecond_partner_indices(mut self, partners: Vec<usize>) -> Self {
4172 self.penaltysecond_partner_indices = Some(std::sync::Arc::from(partners));
4173 self
4174 }
4175
4176 pub(crate) fn x_tau_dense(&self) -> Array2<f64> {
4177 self.x_tau_original.materialize()
4178 }
4179
4180 pub(crate) fn transformed_x_tau(
4181 &self,
4182 qs: &Array2<f64>,
4183 free_basis_opt: Option<&Array2<f64>>,
4184 ) -> Result<Array2<f64>, EstimationError> {
4185 self.x_tau_original.transformed(qs, free_basis_opt)
4186 }
4187
4188 pub(crate) fn x_tau_tau_entry_at(&self, j: usize) -> Option<HyperDesignDerivative> {
4189 self.x_tau_tau_original
4190 .as_ref()
4191 .and_then(|rows| rows.get(j))
4192 .and_then(|entry| entry.clone())
4193 }
4194
4195 pub(crate) fn has_implicit_operator(&self) -> bool {
4198 self.x_tau_original.uses_implicit_storage()
4199 }
4200
4201 pub(crate) fn has_implicit_multidim_duchon(&self) -> bool {
4202 self.implicit_first_axis_info()
4203 .is_some_and(|(op, _)| op.n_axes() > 1 && op.is_duchon_family())
4204 }
4205
4206 pub(crate) fn implicit_first_axis_info(
4208 &self,
4209 ) -> Option<(
4210 std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4211 usize,
4212 )> {
4213 self.x_tau_original.implicit_first_axis_info()
4214 }
4215
4216 pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4217 self.x_tau_original.implicit_axis_count_hint()
4218 }
4219
4220 pub(crate) fn penalty_first_components(&self) -> &[PenaltyDerivativeComponent] {
4221 &self.penalty_first_components
4222 }
4223
4224 pub(crate) fn penalty_total_at(
4225 &self,
4226 rho: &Array1<f64>,
4227 p: usize,
4228 ) -> Result<Array2<f64>, EstimationError> {
4229 let mut out = Array2::<f64>::zeros((p, p));
4230 for component in &self.penalty_first_components {
4231 if component.matrix.nrows() != p || component.matrix.ncols() != p {
4232 crate::bail_invalid_estim!(
4233 "S_tau shape mismatch for penalty {}: expected {}x{}, got {}x{}",
4234 component.penalty_index,
4235 p,
4236 p,
4237 component.matrix.nrows(),
4238 component.matrix.ncols()
4239 );
4240 }
4241 if component.penalty_index >= rho.len() {
4242 crate::bail_invalid_estim!(
4243 "penalty_index {} out of bounds for rho dimension {}",
4244 component.penalty_index,
4245 rho.len()
4246 );
4247 }
4248 component
4249 .matrix
4250 .scaled_add_to(&mut out, rho[component.penalty_index].exp())?;
4251 }
4252 Ok(out)
4253 }
4254
4255 pub(crate) fn penaltysecond_components_for(
4256 &self,
4257 j: usize,
4258 ) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError> {
4259 if let Some(components) = self
4260 .penaltysecond_components
4261 .as_ref()
4262 .and_then(|rows| rows.get(j))
4263 .and_then(|row| row.clone())
4264 {
4265 return Ok(Some(components));
4266 }
4267 if let Some(provider) = self.penaltysecond_component_provider.as_ref() {
4268 return provider(j);
4269 }
4270 Ok(None)
4271 }
4272
4273 pub(crate) fn penaltysecond_componentrows(
4274 &self,
4275 ) -> Option<&[Option<Vec<PenaltyDerivativeComponent>>]> {
4276 self.penaltysecond_components.as_deref()
4277 }
4278
4279 pub(crate) fn penalty_first_component_count(&self) -> usize {
4280 self.penalty_first_components.len()
4281 }
4282
4283 pub(crate) fn has_penaltysecond_pair_at(&self, j: usize) -> bool {
4284 self.penaltysecond_components
4285 .as_ref()
4286 .and_then(|rows| rows.get(j))
4287 .is_some_and(Option::is_some)
4288 || self
4289 .penaltysecond_partner_indices
4290 .as_ref()
4291 .is_some_and(|partners| partners.contains(&j))
4292 }
4293}
4294
/// Record pairing a chosen [`RemlGeometry`] with a static `reason` string and
/// the size figures carried alongside it.
///
/// NOTE(review): the code that fills this record is outside this part of the
/// file; the per-field notes below go by the field names — confirm.
#[derive(Clone, Debug)]
pub(crate) struct SparseRemlDecision {
    /// Geometry the decision settled on.
    pub(crate) geometry: RemlGeometry,
    /// Static explanation attached to the decision.
    pub(crate) reason: &'static str,
    /// Presumably the coefficient count — TODO confirm.
    pub(crate) p: usize,
    /// Presumably the non-zero count of the design matrix — TODO confirm.
    pub(crate) nnz_x: usize,
    /// Optional estimate; presumably non-zeros in the upper triangle of H.
    pub(crate) nnz_h_upper_est: Option<usize>,
    /// Optional estimate; presumably the matching density figure.
    pub(crate) density_h_upper_est: Option<f64>,
}
4304
/// Shared, reference-counted results carried for the sparse-exact evaluation
/// path.
///
/// NOTE(review): how each value is computed is not visible in this part of
/// the file; the per-field notes go by the names and types — confirm.
#[derive(Clone)]
pub(crate) struct SparseExactEvalData {
    /// Shared sparse factorization.
    pub(crate) factor: Arc<SparseExactFactor>,
    /// Optional shared `TakahashiInverse`.
    pub(crate) takahashi: Option<Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    /// Presumably log|H| — TODO confirm.
    pub(crate) logdet_h: f64,
    /// Presumably the penalty log-determinant over its positive part — TODO confirm.
    pub(crate) logdet_s_pos: f64,
    /// Presumably the rank of the penalty — TODO confirm.
    pub(crate) penalty_rank: usize,
    /// Shared vector of first-derivative values (`det1`); meaning to be
    /// confirmed against the producer.
    pub(crate) det1_values: Arc<Array1<f64>>,
}
4314
/// Dense matrices and vectors held together for the Firth computations.
///
/// NOTE(review): the definitions of the individual quantities are not visible
/// in this part of the file. The notes below are read off the field names and
/// types only and must be confirmed against the constructor.
#[derive(Clone)]
pub struct FirthDenseOperator {
    /// Dense design matrix.
    pub(crate) x_dense: Array2<f64>,
    /// Presumably the stored transpose of `x_dense` — TODO confirm.
    pub(crate) x_dense_t: Array2<f64>,
    /// Basis matrix `q`; presumably maps to the reduced coordinates.
    pub(crate) q_basis: Array2<f64>,
    /// Presumably the design expressed in the reduced coordinates.
    pub(crate) x_reduced: Array2<f64>,
    /// Optional vector; presumably square roots of observation weights.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    /// Reduced-coordinate matrix `k`.
    pub(crate) k_reduced: Array2<f64>,
    /// Vector named as a diagonal of an inverse reduced metric — TODO confirm.
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    /// Scalar named as half of a log-determinant — TODO confirm which one.
    pub(crate) half_log_det: f64,
    /// Vector named `h_diag`; presumably a hat-matrix diagonal — TODO confirm.
    pub(crate) h_diag: Array1<f64>,
    /// Vector `w`; `w1`..`w4` below are presumably its successive derivatives.
    pub(crate) w: Array1<f64>,
    pub(crate) w1: Array1<f64>,
    pub(crate) w2: Array1<f64>,
    pub(crate) w3: Array1<f64>,
    pub(crate) w4: Array1<f64>,
    /// Matrix `b_base`.
    pub(crate) b_base: Array2<f64>,
    /// Matrix `p_b_base`; presumably a projected form of `b_base`.
    pub(crate) p_b_base: Array2<f64>,
}
4380
/// Per-direction quantities for the Firth computations.
///
/// NOTE(review): field meanings are not established in this part of the file;
/// names suggest directional derivatives (`deta` of the linear predictor,
/// `dh` of `h_diag`) — confirm against the code that builds this struct.
#[derive(Clone)]
pub(crate) struct FirthDirection {
    pub(crate) deta: Array1<f64>,
    pub(crate) g_u_reduced: Array2<f64>,
    pub(crate) a_u_reduced: Array2<f64>,
    pub(crate) dh: Array1<f64>,
    pub(crate) b_uvec: Array1<f64>,
}
4390
/// Partial-derivative quantities for one hyperparameter (`tau`) in the Firth
/// computations.
///
/// NOTE(review): field meanings are not established in this part of the file;
/// the `*_partial` / `dot*` names suggest partial derivatives with respect to
/// `tau` — confirm. Note the mixed visibility: `deta_partial` and
/// `dot_i_partial` are `pub(super)`, the rest `pub(crate)`.
#[derive(Clone)]
pub(crate) struct FirthTauPartialKernel {
    pub(super) deta_partial: Array1<f64>,
    pub(crate) dotw1: Array1<f64>,
    pub(crate) dotw2: Array1<f64>,
    pub(crate) dot_h_partial: Array1<f64>,
    pub(crate) x_tau_reduced: Array2<f64>,
    pub(super) dot_i_partial: Array2<f64>,
    pub(crate) dot_k_reduced: Array2<f64>,
}
4406
/// First-order `tau` results for the Firth term: a vector, a scalar, and an
/// optional [`FirthTauPartialKernel`].
///
/// NOTE(review): the exact definitions (`gphi_tau` presumably a gradient-type
/// vector, `phi_tau_partial` a scalar partial derivative) and the condition
/// under which `tau_kernel` is `None` are not visible here — confirm.
#[derive(Clone)]
pub(crate) struct FirthTauExactKernel {
    pub(crate) gphi_tau: Array1<f64>,
    pub(crate) phi_tau_partial: f64,
    pub(crate) tau_kernel: Option<FirthTauPartialKernel>,
}
4413
/// Second-order (`tau`,`tau`) counterpart of `FirthTauExactKernel`: a scalar,
/// a vector, and an optional [`FirthTauTauPartialKernel`].
///
/// NOTE(review): exact definitions and the condition under which
/// `tau_tau_kernel` is `None` are not visible here — confirm.
#[derive(Clone)]
pub(crate) struct FirthTauTauExactKernel {
    pub(super) phi_tau_tau_partial: f64,
    pub(super) gphi_tau_tau: Array1<f64>,
    pub(super) tau_tau_kernel: Option<FirthTauTauPartialKernel>,
}
4431
/// Partial-derivative quantities for a pair of hyperparameters (`i`, `j`) in
/// the Firth computations.
///
/// Fields come in `_i` / `_j` pairs; the two `Option` fields at the end carry
/// the mixed second-order pieces and may be absent.
///
/// NOTE(review): the definitions of the quantities are not visible in this
/// part of the file — confirm against the code that fills this struct.
#[derive(Clone, Default)]
pub(crate) struct FirthTauTauPartialKernel {
    pub(super) x_tau_i_reduced: Array2<f64>,
    pub(super) x_tau_j_reduced: Array2<f64>,
    pub(super) deta_i_partial: Array1<f64>,
    pub(super) deta_j_partial: Array1<f64>,
    pub(super) dot_h_i_partial: Array1<f64>,
    pub(super) dot_h_j_partial: Array1<f64>,
    pub(super) dot_k_i_reduced: Array2<f64>,
    pub(super) dot_k_j_reduced: Array2<f64>,
    pub(super) dot_i_i_partial: Array2<f64>,
    pub(super) dot_i_j_partial: Array2<f64>,
    pub(super) x_tau_tau_reduced: Option<Array2<f64>>,
    pub(super) deta_ij_partial: Option<Array1<f64>>,
}
4459
/// Partial-derivative quantities for a mixed (`tau`, `beta`-direction) term in
/// the Firth computations.
///
/// The first five fields mirror `FirthTauPartialKernel`; the `_v` fields and
/// `d_beta_*` fields presumably belong to a coefficient direction `v`.
///
/// NOTE(review): the definitions of the quantities are not visible in this
/// part of the file — confirm against the code that fills this struct.
#[derive(Clone, Default)]
pub(crate) struct FirthTauBetaPartialKernel {
    pub(super) x_tau_reduced: Array2<f64>,
    pub(super) deta_partial: Array1<f64>,
    pub(super) dot_h_partial: Array1<f64>,
    pub(super) dot_i_partial: Array2<f64>,
    pub(super) dot_k_reduced: Array2<f64>,
    pub(super) deta_v: Array1<f64>,
    pub(super) deta_tau_v: Array1<f64>,
    pub(super) a_v_reduced: Array2<f64>,
    pub(super) dh_v: Array1<f64>,
    pub(super) b_vvec: Array1<f64>,
    pub(super) d_beta_dot_k: Array2<f64>,
    pub(super) d_beta_dot_h: Array1<f64>,
}
4482
/// Bundle of results shared across the pieces of one evaluation, plus
/// write-once caches that are filled on first use.
#[derive(Clone)]
pub(crate) struct EvalShared {
    /// Identity of the evaluation point; compared by `matches`. `None` is a
    /// valid identity and only matches another `None`.
    pub(crate) key: Option<Vec<u64>>,
    /// Shared P-IRLS result.
    pub(crate) pirls_result: Arc<PirlsResult>,
    /// Supplies the ridge used when building the penalty pseudo-logdet.
    pub(crate) ridge_passport: RidgePassport,
    /// Geometry in use; mapped to a `GeometryBackendKind` by `backend_kind`.
    pub(crate) geometry: RemlGeometry,
    /// Shared dense matrix; presumably the total penalized Hessian — TODO confirm.
    pub(crate) h_total: Arc<Array2<f64>>,
    /// Sparse-exact data, when present.
    pub(crate) sparse_exact: Option<Arc<SparseExactEvalData>>,
    /// Firth operator, when present.
    pub(crate) firth_dense_operator: Option<Arc<FirthDenseOperator>>,
    /// Second Firth operator; presumably in original coordinates — TODO confirm.
    pub(crate) firth_dense_operator_original: Option<Arc<FirthDenseOperator>>,
    /// Write-once cache filled by `penalty_pseudologdet_original`.
    pub(crate) penalty_pseudologdet: std::sync::OnceLock<Arc<penalty_logdet::PenaltyPseudologdet>>,
    /// Write-once cache; its producer is not visible in this part of the file.
    pub(crate) penalty_scores_at_mode: std::sync::OnceLock<Arc<Vec<Array1<f64>>>>,
    /// Write-once cache of a `(usize, correction terms)` pair; the meaning of
    /// the `usize` is not visible in this part of the file.
    pub(crate) block_local_correction:
        std::sync::OnceLock<(usize, Arc<outer_eval::TkCorrectionTerms>)>,
}
4555
4556impl EvalShared {
4557 pub(crate) fn matches(&self, key: &Option<Vec<u64>>) -> bool {
4558 match (&self.key, key) {
4559 (None, None) => true,
4560 (Some(a), Some(b)) => a == b,
4561 _ => false,
4562 }
4563 }
4564
4565 pub(crate) fn penalty_pseudologdet_original(
4580 &self,
4581 canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
4582 lambdas: &[f64],
4583 p: usize,
4584 ) -> Result<Arc<penalty_logdet::PenaltyPseudologdet>, EstimationError> {
4585 if let Some(pld) = self.penalty_pseudologdet.get() {
4586 if pld.dim() != p {
4587 return Err(EstimationError::LayoutError(format!(
4588 "shared penalty pseudo-logdet frame mismatch: cached p={}, requested p={}",
4589 pld.dim(),
4590 p
4591 )));
4592 }
4593 return Ok(Arc::clone(pld));
4594 }
4595 let pld = Arc::new(
4596 penalty_logdet::PenaltyPseudologdet::from_penalties(
4597 canonical_penalties,
4598 lambdas,
4599 self.ridge_passport.penalty_logdet_ridge(),
4600 p,
4601 )
4602 .map_err(EstimationError::InvalidInput)?,
4603 );
4604 match self.penalty_pseudologdet.set(Arc::clone(&pld)) {
4605 Ok(()) => Ok(pld),
4606 Err(_) => Ok(Arc::clone(
4610 self.penalty_pseudologdet
4611 .get()
4612 .expect("OnceLock set raced, so it is initialized"),
4613 )),
4614 }
4615 }
4616}
4617
4618impl PenalizedGeometry for EvalShared {
4619 fn backend_kind(&self) -> GeometryBackendKind {
4620 match self.geometry {
4621 RemlGeometry::DenseSpectral => GeometryBackendKind::DenseSpectral,
4622 RemlGeometry::SparseExactSpd => GeometryBackendKind::SparseExactSpd,
4623 }
4624 }
4625}
4626
/// Byte-budgeted, least-recently-used cache of PIRLS results keyed by `Vec<u64>`.
///
/// `get` stamps the entry it returns with a fresh clock value, and `insert`
/// evicts the entries with the smallest stamps until the new entry fits.
4627pub(crate) struct PirlsLruCache {
    /// Per key: the shared result, the clock value of its last use, and the
    /// number of bytes charged for it when it was inserted.
4637 pub(crate) map: HashMap<Vec<u64>, (Arc<PirlsResult>, u64, usize)>,
    /// Upper bound on `current_bytes`; `new` raises a zero budget to one.
4639 pub(crate) byte_budget: usize,
    /// Sum of the bytes charged for the entries currently in `map`.
4640 pub(crate) current_bytes: usize,
    /// Counter advanced on every `get` hit and every `insert`.
4641 pub(crate) clock: u64,
4642}
4643
4644impl PirlsLruCache {
4645 pub(crate) fn new(byte_budget: usize) -> Self {
4646 Self {
4647 map: HashMap::new(),
4648 byte_budget: byte_budget.max(1),
4649 current_bytes: 0,
4650 clock: 0,
4651 }
4652 }
4653
4654 pub(crate) fn get(&mut self, key: &Vec<u64>) -> Option<Arc<PirlsResult>> {
4655 if let Some(entry) = self.map.get_mut(key) {
4656 self.clock += 1;
4657 entry.1 = self.clock;
4658 Some(entry.0.clone())
4659 } else {
4660 None
4661 }
4662 }
4663
4664 pub(crate) fn insert(&mut self, key: Vec<u64>, value: Arc<PirlsResult>) {
4665 self.clock += 1;
4666 let bytes = pirls_result_cache_bytes(&value);
4667 if bytes > self.byte_budget {
4671 if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4672 self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4673 }
4674 return;
4675 }
4676 if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4677 self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4678 }
4679 while self.current_bytes + bytes > self.byte_budget {
4680 let evict_key = self
4681 .map
4682 .iter()
4683 .min_by_key(|(_, (_, ts, _))| *ts)
4684 .map(|(k, _)| k.clone());
4685 match evict_key {
4686 Some(k) => {
4687 if let Some((_, _, evict_bytes)) = self.map.remove(&k) {
4688 self.current_bytes = self.current_bytes.saturating_sub(evict_bytes);
4689 }
4690 }
4691 None => break,
4692 }
4693 }
4694 self.current_bytes += bytes;
4695 self.map.insert(key, (value, self.clock, bytes));
4696 }
4697
4698 pub(crate) fn clear(&mut self) {
4699 self.map.clear();
4700 self.current_bytes = 0;
4701 }
4702}
4703
/// Two-part key for the penalty-subspace cache; both parts are 64-bit hashes
/// computed by `PenaltySubspaceCacheKey::from_inputs`.
4704#[derive(Clone, Copy, PartialEq, Eq)]
4705pub(crate) struct PenaltySubspaceCacheKey {
    /// Hash of the penalty matrix's row count, column count and entry bit patterns.
4706 pub(crate) penalty_matrix_fingerprint: u64,
    /// Hash of the ridge passport's `delta`, `matrix_form` and two policy flags.
4707 pub(crate) ridge_passport_signature: u64,
4708}
4709
/// Single-slot cache: it remembers at most one penalty subspace together with
/// the key it was built for.
4710pub(crate) struct PenaltySubspaceCache {
    /// The cached `(key, value)` pair, or `None` when the cache is empty.
4711 pub(crate) entry: Option<(PenaltySubspaceCacheKey, Arc<outer_eval::PenaltySubspace>)>,
4712}
4713
4714impl PenaltySubspaceCache {
4715 pub(crate) fn new() -> Self {
4716 Self { entry: None }
4717 }
4718
4719 pub(crate) fn get(
4720 &self,
4721 key: &PenaltySubspaceCacheKey,
4722 ) -> Option<Arc<outer_eval::PenaltySubspace>> {
4723 self.entry
4724 .as_ref()
4725 .filter(|(cached_key, _)| cached_key == key)
4726 .map(|(_, value)| value.clone())
4727 }
4728
4729 pub(crate) fn insert(
4730 &mut self,
4731 key: PenaltySubspaceCacheKey,
4732 value: Arc<outer_eval::PenaltySubspace>,
4733 ) {
4734 self.entry = Some((key, value));
4735 }
4736
4737 pub(crate) fn clear(&mut self) {
4738 self.entry = None;
4739 }
4740}
4741
4742impl PenaltySubspaceCacheKey {
4743 pub(crate) fn from_inputs(
4748 e_transformed: &ndarray::Array2<f64>,
4749 ridge_passport: &gam_problem::RidgePassport,
4750 ) -> Self {
4751 use std::collections::hash_map::DefaultHasher;
4752 use std::hash::{Hash, Hasher};
4753 let mut hasher = DefaultHasher::new();
4754 e_transformed.nrows().hash(&mut hasher);
4755 e_transformed.ncols().hash(&mut hasher);
4756 for value in e_transformed.iter() {
4757 value.to_bits().hash(&mut hasher);
4758 }
4759 let penalty_matrix_fingerprint = hasher.finish();
4760 let mut ridge_hasher = DefaultHasher::new();
4761 ridge_passport.delta.to_bits().hash(&mut ridge_hasher);
4762 (ridge_passport.matrix_form as u8).hash(&mut ridge_hasher);
4763 ridge_passport
4764 .policy
4765 .include_penalty_logdet
4766 .hash(&mut ridge_hasher);
4767 ridge_passport
4768 .policy
4769 .include_laplacehessian
4770 .hash(&mut ridge_hasher);
4771 let ridge_passport_signature = ridge_hasher.finish();
4772 Self {
4773 penalty_matrix_fingerprint,
4774 ridge_passport_signature,
4775 }
4776 }
4777}
4778
4779pub(crate) fn pirls_result_cache_bytes(result: &PirlsResult) -> usize {
4794 use std::mem::size_of;
4795 let n_array_elems = result.final_eta.len()
4796 + result.solveweights.len()
4797 + result.solveworking_response.len()
4798 + result.solvemu.len()
4799 + result.solve_c_array.len()
4800 + result.solve_d_array.len();
4801 let p = result.beta_transformed.0.len();
4802 let pen_h = symmetric_matrix_cache_bytes(&result.penalized_hessian_transformed);
4803 let stab_h = symmetric_matrix_cache_bytes(&result.stabilizedhessian_transformed);
4804 let reparam = (result.reparam_result.s_transformed.len()
4805 + result.reparam_result.qs.len()
4806 + result.reparam_result.e_transformed.len()
4807 + result.reparam_result.det1.len())
4808 * size_of::<f64>();
4809 n_array_elems * size_of::<f64>() + p * size_of::<f64>() + pen_h + stab_h + reparam + 1024
4810}
4811
4812pub(crate) fn symmetric_matrix_cache_bytes(m: &gam_linalg::matrix::SymmetricMatrix) -> usize {
4813 use gam_linalg::matrix::SymmetricMatrix;
4814 use std::mem::size_of;
4815 match m {
4816 SymmetricMatrix::Dense(a) => a.len() * size_of::<f64>(),
4817 SymmetricMatrix::Sparse(s) => {
4818 let (symbolic, values) = s.parts();
4820 values.len() * (size_of::<f64>() + size_of::<usize>())
4821 + std::mem::size_of_val(symbolic.col_ptr())
4822 }
4823 }
4824}
4825
/// Capacity that `EvalCacheManager::new` passes to `OuterEvalLru::new`.
4826pub(crate) const OUTER_EVAL_LRU_CAPACITY: usize = 8;
4834
/// Small least-recently-used cache of `OuterEval` values keyed by `Vec<u64>`.
///
/// Lookups are linear scans over `entries`.
4835pub(crate) struct OuterEvalLru {
    /// Maximum number of entries kept; `new` raises zero to one.
4849 capacity: usize,
    /// Entries ordered from least recently used (front) to most recently used
    /// (back): `get` and `insert` push to the back, eviction pops the front.
4850 entries: std::collections::VecDeque<(Vec<u64>, OuterEval)>,
4852}
4853
4854impl OuterEvalLru {
4855 pub(crate) fn new(capacity: usize) -> Self {
4856 Self {
4857 capacity: capacity.max(1),
4858 entries: std::collections::VecDeque::new(),
4859 }
4860 }
4861
4862 pub(crate) fn get(&mut self, key: &[u64]) -> Option<OuterEval> {
4866 let pos = self
4867 .entries
4868 .iter()
4869 .position(|(k, _)| k.as_slice() == key)?;
4870 let entry = self.entries.remove(pos)?;
4871 let eval = entry.1.clone();
4872 self.entries.push_back(entry);
4873 Some(eval)
4874 }
4875
4876 pub(crate) fn insert(&mut self, key: Vec<u64>, eval: OuterEval) {
4879 if let Some(pos) = self
4880 .entries
4881 .iter()
4882 .position(|(k, _)| k.as_slice() == key.as_slice())
4883 {
4884 self.entries.remove(pos);
4885 }
4886 self.entries.push_back((key, eval));
4887 while self.entries.len() > self.capacity {
4888 self.entries.pop_front();
4889 }
4890 }
4891
4892 pub(crate) fn clear(&mut self) {
4893 self.entries.clear();
4894 }
4895}
4896
/// Owner of the evaluation-side caches; each cache sits behind its own `RwLock`.
4897pub(crate) struct EvalCacheManager {
    /// Byte-budgeted LRU of PIRLS results, sized by `PIRLS_CACHE_BYTE_BUDGET` in `new`.
4902 pub(crate) pirls_cache: RwLock<PirlsLruCache>,
    /// Single-slot cache served through `cached_penalty_subspace`.
4903 pub(crate) penalty_subspace_cache: RwLock<PenaltySubspaceCache>,
    /// Most recently stored evaluation bundle, if any.
4904 pub(crate) current_eval_bundle: RwLock<Option<EvalShared>>,
    /// Key and value of the most recent `store_outer_eval` call with a present key.
    /// `cached_outer_eval` does not read this slot; it consults `outer_eval_lru`.
4905 pub(crate) current_outer_eval: RwLock<Option<(Vec<u64>, OuterEval)>>,
    /// LRU of outer evaluations, sized by `OUTER_EVAL_LRU_CAPACITY` in `new`.
4909 pub(crate) outer_eval_lru: RwLock<OuterEvalLru>,
    /// Starts out `true`.
    // NOTE(review): presumably gates use of `pirls_cache`; the readers are not
    // visible here — confirm.
4923 pub(crate) pirls_cache_enabled: AtomicBool,
4924}
4925
4926impl EvalCacheManager {
4927 pub(crate) fn new() -> Self {
4928 Self {
4929 pirls_cache: RwLock::new(PirlsLruCache::new(PIRLS_CACHE_BYTE_BUDGET)),
4930 penalty_subspace_cache: RwLock::new(PenaltySubspaceCache::new()),
4931 current_eval_bundle: RwLock::new(None),
4932 current_outer_eval: RwLock::new(None),
4933 outer_eval_lru: RwLock::new(OuterEvalLru::new(OUTER_EVAL_LRU_CAPACITY)),
4934 pirls_cache_enabled: AtomicBool::new(true),
4935 }
4936 }
4937
4938 pub(crate) fn sanitized_rhokey(rho: &Array1<f64>) -> Option<Vec<u64>> {
4942 self::rho_key::sanitized_rhokey(rho)
4943 }
4944
4945 pub(super) fn cached_penalty_subspace<F>(
4952 &self,
4953 e_transformed: &ndarray::Array2<f64>,
4954 ridge_passport: &gam_problem::RidgePassport,
4955 build: F,
4956 ) -> Result<Arc<outer_eval::PenaltySubspace>, EstimationError>
4957 where
4958 F: FnOnce() -> Result<outer_eval::PenaltySubspace, EstimationError>,
4959 {
4960 let key = PenaltySubspaceCacheKey::from_inputs(e_transformed, ridge_passport);
4961 if let Some(hit) = self.penalty_subspace_cache.read().unwrap().get(&key) {
4962 return Ok(hit);
4963 }
4964 let value = Arc::new(build()?);
4965 self.penalty_subspace_cache
4966 .write()
4967 .unwrap()
4968 .insert(key, value.clone());
4969 Ok(value)
4970 }
4971
4972 pub(crate) fn cached_eval_bundle(&self, key: &Option<Vec<u64>>) -> Option<EvalShared> {
4973 let guard = self.current_eval_bundle.read().unwrap();
4974 let bundle: &EvalShared = guard.as_ref()?;
4975 bundle.matches(key).then(|| bundle.clone())
4976 }
4977
4978 pub(crate) fn store_eval_bundle(&self, bundle: EvalShared) {
4979 *self.current_eval_bundle.write().unwrap() = Some(bundle);
4980 }
4981
4982 pub(crate) fn cached_outer_eval(&self, key: &Option<Vec<u64>>) -> Option<OuterEval> {
4983 let key = key.as_ref()?;
4984 self.outer_eval_lru.write().unwrap().get(key)
4991 }
4992
4993 pub(crate) fn store_outer_eval(&self, key: &Option<Vec<u64>>, eval: &OuterEval) {
4994 if let Some(key) = key.clone() {
4995 *self.current_outer_eval.write().unwrap() = Some((key.clone(), eval.clone()));
4999 self.outer_eval_lru.write().unwrap().insert(key, eval.clone());
5000 }
5001 }
5002
5003 pub(crate) fn invalidate_eval_bundle(&self) {
5004 self.current_eval_bundle.write().unwrap().take();
5005 self.current_outer_eval.write().unwrap().take();
5006 self.outer_eval_lru.write().unwrap().clear();
5007 }
5008
5009 pub(crate) fn clear_eval_and_factor_caches(&self) {
5010 self.invalidate_eval_bundle();
5011 self.penalty_subspace_cache.write().unwrap().clear();
5012 }
5013}
5014
/// Counters and a flag kept alongside the REML state; `RemlArena::new` starts
/// them at zero and `false`.
// NOTE(review): what increments or reads each member is not visible here; the
// per-field reading is taken from the names only — confirm.
5015pub(crate) struct RemlArena {
5018 pub(crate) cost_eval_count: RwLock<u64>,
5019 pub(crate) inner_pirls_solve_count: AtomicU64,
5032 pub(crate) lastgradient_used_stochastic_fallback: AtomicBool,
5033}
5034
5035impl RemlArena {
5036 pub(crate) fn new() -> Self {
5037 Self {
5038 cost_eval_count: RwLock::new(0),
5039 inner_pirls_solve_count: AtomicU64::new(0),
5040 lastgradient_used_stochastic_fallback: AtomicBool::new(false),
5041 }
5042 }
5043}
5044
/// Plain value bundle: an observation count, a vector of `f64` scales and a
/// scalar `phi`. `RemlState` stores it in its `alo_frozen_nuisance` slot.
// NOTE(review): presumably nuisance quantities frozen for approximate
// leave-one-out (ALO) computations, with `influence_scale` holding one entry
// per observation — the producer is not visible here, confirm.
5045pub(crate) struct AloFrozenNuisance {
5046 pub(crate) n_obs: usize,
5047 pub(crate) influence_scale: Vec<f64>,
5048 pub(crate) phi: f64,
5049}
5050
/// Central state object of the REML machinery in this module.
///
/// It borrows the response and weight vectors for `'a`, owns the design,
/// penalties and configuration, and holds a set of interior-mutable caches and
/// bookkeeping slots (`RwLock`, atomics) that can be updated through `&self`.
// NOTE(review): the per-group remarks below are inferred from field names and
// types only; the code that reads and writes these fields is outside this
// part of the file — confirm before relying on them.
5051pub(crate) struct RemlState<'a> {
    // --- Problem data: borrowed response/weights, owned design and offset. ---
5052 pub(crate) y: ArrayView1<'a, f64>,
5053 pub(crate) x: DesignMatrix,
5054 pub(crate) weights: ArrayView1<'a, f64>,
5055 pub(crate) offset: Array1<f64>,
    // --- Penalty structure and configuration. ---
5056 pub(crate) canonical_penalties: Arc<Vec<gam_terms::construction::CanonicalPenalty>>,
5060 pub(crate) balanced_penalty_root: Array2<f64>,
5061 pub(crate) reparam_invariant: ReparamInvariant,
5062 pub(crate) sparse_penalty_block_count: Option<usize>,
    // presumably the number of coefficients — TODO confirm
5063 pub(crate) p: usize,
5064 pub(crate) config: Arc<RemlConfig>,
5065 pub(crate) runtime_mixture_link_state: Option<gam_problem::MixtureLinkState>,
5066 pub(crate) runtime_sas_link_state: Option<SasLinkState>,
5067 pub(crate) nullspace_dims: Vec<usize>,
    // --- Optional constraints and priors. ---
5068 pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
5069 pub(crate) linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
5070 pub(crate) penalty_shrinkage_floor: Option<f64>,
5072 pub(crate) rho_prior: gam_problem::RhoPrior,
5074
    // --- Evaluation caches, counters and warm-start slots. ---
5075 pub(crate) cache_manager: EvalCacheManager,
5076 pub(crate) arena: RemlArena,
5077 pub(crate) warm_start_beta: RwLock<Option<Coefficients>>,
5078 pub(crate) warm_start_rho: RwLock<Option<Array1<f64>>>,
5088 pub(crate) prev_warm_start_beta: RwLock<Option<Coefficients>>,
5089 pub(crate) prev_warm_start_rho: RwLock<Option<Array1<f64>>>,
5090 pub(crate) warm_start_enabled: AtomicBool,
5091 pub(crate) screening_max_inner_iterations: Arc<AtomicUsize>,
5092 pub(crate) outer_inner_cap: Arc<AtomicUsize>,
5107
5108 pub(crate) last_inner_iters: Arc<AtomicUsize>,
5121 pub(crate) last_inner_converged: Arc<AtomicBool>,
5122
5123 pub(crate) ift_warm_start_cache: RwLock<Option<IftWarmStartCache>>,
5139
    // NOTE(review): the four `Arc<AtomicU64>` fields below presumably store
    // `f64` values as raw bit patterns — confirm against their readers.
5140 pub(crate) last_pirls_lm_lambda: Arc<AtomicU64>,
5152
5153 pub(crate) frozen_negbin_theta: Arc<AtomicU64>,
5165
5166 pub(crate) last_ift_prediction_residual: Arc<AtomicU64>,
5188
5189 pub(crate) last_pirls_accept_rho: Arc<AtomicU64>,
5204
5205 pub(crate) ift_cached_factor: RwLock<Option<Arc<dyn gam_linalg::matrix::FactorizedSystem>>>,
5216
    // --- Optional Kronecker-structured penalty/basis data. ---
5217 pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
5221 pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,
5224
    // --- Lazily filled, shareable numeric caches (each `RwLock<Option<Arc<..>>>`). ---
5225 pub(crate) gaussian_fixed_cache: RwLock<Option<Arc<crate::pirls::GaussianFixedCache>>>,
5235 pub(crate) gaussian_psi_gram_deriv:
5246 RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
5247 pub(crate) glm_psi_gram_deriv:
5265 RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
5266 pub(crate) glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
5285 pub(crate) flat_glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
5295 pub(crate) alo_frozen_nuisance: RwLock<Option<AloFrozenNuisance>>,
5302
    // --- Persistent warm-start bookkeeping and suppression counters. ---
5303 pub(crate) persistent_warm_start_key: RwLock<Option<String>>,
5306 pub(crate) persistent_latent_values_fingerprint: Option<u64>,
5307 pub(crate) persistent_latent_values_cache: RwLock<PersistentLatentValuesCache>,
5308 pub(crate) analytic_penalty_registry_fingerprint: u64,
5309 pub(crate) persistent_warm_start_loaded: AtomicBool,
5311 pub(crate) persistent_warm_start_store_suppression: AtomicUsize,
5315 pub(crate) alo_stabilization_suppression: AtomicUsize,
5325 pub(crate) persistent_warm_start_disk_enabled: AtomicBool,
5339}