1use std::sync::Arc;
30
31use ndarray::{Array1, Array2, Array3};
32
33use gam_identifiability::families::compiler::{
34 BlockOrder, RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with,
35};
36use gam_problem::gauge::assemble_block_triangular_t;
37use faer::Side;
38use gam_linalg::faer_ndarray::FaerEigh;
39use gam_linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
40use gam_problem::{FamilyChannelHessian, PenaltyMatrix};
41
/// Number of per-row primary channels. The block operators in this file write
/// the `dq0`/`dq1`/`dqd1` derivatives to channels 0..=2 and `dg` to channel 3.
const K_SURVIVAL: usize = 4;

/// Absolute magnitude above which a coefficient makes `beta` count as
/// non-trivial in `SurvivalRowHessian::channel_hessian_at`.
const BETA_NONTRIVIAL_ABS_THRESHOLD: f64 = 1e-12;
50
/// Per-row Hessian blocks over the `K_SURVIVAL` primary channels, stored as a
/// single dense tensor.
pub struct SurvivalRowHessian {
    /// Tensor of shape `(n, K_SURVIVAL, K_SURVIVAL)`; `h[[i, a, b]]` is entry
    /// `(a, b)` of row `i`'s block.
    h: Array3<f64>,
}
62
63impl SurvivalRowHessian {
64 pub fn from_pilot_primary_state(
68 q0: &Array1<f64>,
69 q1: &Array1<f64>,
70 qd1: &Array1<f64>,
71 g: &Array1<f64>,
72 z: &Array1<f64>,
73 weights: &Array1<f64>,
74 event: &Array1<f64>,
75 derivative_guard: f64,
76 probit_scale: f64,
77 ) -> Result<Self, String> {
78 let n = q0.len();
79 if [
80 q1.len(),
81 qd1.len(),
82 g.len(),
83 z.len(),
84 weights.len(),
85 event.len(),
86 ]
87 .iter()
88 .any(|&l| l != n)
89 {
90 return Err(format!(
91 "SurvivalRowHessian: length mismatch \
92 q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
93 q1.len(),
94 qd1.len(),
95 g.len(),
96 z.len(),
97 weights.len(),
98 event.len()
99 ));
100 }
101 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
102 for i in 0..n {
103 let (_, _grad, hess) =
104 crate::survival::marginal_slope::row_primary_for_compiler(
105 q0[i],
106 q1[i],
107 qd1[i],
108 g[i],
109 z[i],
110 weights[i],
111 event[i],
112 derivative_guard,
113 probit_scale,
114 )?;
115 let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
117 for a in 0..K_SURVIVAL {
118 for b in 0..K_SURVIVAL {
119 h_i[[a, b]] = hess[a][b];
120 }
121 }
122 let clamped = psd_clamp_4x4(&h_i);
123 for a in 0..K_SURVIVAL {
124 for b in 0..K_SURVIVAL {
125 h_full[[i, a, b]] = clamped[[a, b]];
126 }
127 }
128 }
129 Ok(Self { h: h_full })
130 }
131
132 pub fn from_full(h: Array3<f64>) -> Self {
135 assert_eq!(h.shape()[1], K_SURVIVAL);
136 assert_eq!(h.shape()[2], K_SURVIVAL);
137 Self { h }
138 }
139}
140
141impl RowHessian for SurvivalRowHessian {
142 fn k(&self) -> usize {
143 K_SURVIVAL
144 }
145 fn nrows(&self) -> usize {
146 self.h.shape()[0]
147 }
148 fn fill_row(&self, row: usize, out: &mut [f64]) {
149 assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
150 for a in 0..K_SURVIVAL {
151 for b in 0..K_SURVIVAL {
152 out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
153 }
154 }
155 }
156 fn evaluate_full(&self) -> Array3<f64> {
157 self.h.clone()
158 }
159}
160
161impl FamilyChannelHessian for SurvivalRowHessian {
197 fn n_outputs(&self) -> usize {
198 K_SURVIVAL
199 }
200
201 fn n_subjects(&self) -> usize {
202 self.h.shape()[0]
203 }
204
205 fn fill_subject(&self, i: usize, out: &mut [f64]) {
206 assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
207 for a in 0..K_SURVIVAL {
208 for b in 0..K_SURVIVAL {
209 out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
210 }
211 }
212 }
213
214 fn evaluate_full(&self) -> ndarray::Array3<f64> {
215 self.h.clone()
216 }
217
218 fn channel_hessian_at(
219 &self,
220 beta: &[f64],
221 family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
222 ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
223 use crate::survival::marginal_slope::SurvivalMarginalSlopeFamilyScalars;
224
225 let scalars_opt =
226 family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
227
228 let beta_nontrivial = beta
230 .iter()
231 .any(|&b| b.abs() > BETA_NONTRIVIAL_ABS_THRESHOLD);
232
233 match scalars_opt {
234 None if beta_nontrivial => {
235 Err(
238 "SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
239 family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
240 via FamilyLinearizationState::family_scalars to evaluate W(β) \
241 correctly (same contract as T26 Jacobian callbacks)."
242 .to_string(),
243 )
244 }
245 None => {
246 Ok(Arc::new(gam_problem::TensorChannelHessian {
248 h: self.h.clone(),
249 }))
250 }
251 Some(sc) => {
252 let n = self.h.shape()[0];
253 if sc.q0_i.len() != n
254 || sc.q1_i.len() != n
255 || sc.qd1_i.len() != n
256 || sc.g_i.len() != n
257 || sc.z_i.len() != n
258 {
259 return Err(format!(
260 "SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
261 (expected n={n}, got q0={} q1={} qd1={} g={} z={})",
262 sc.q0_i.len(),
263 sc.q1_i.len(),
264 sc.qd1_i.len(),
265 sc.g_i.len(),
266 sc.z_i.len(),
267 ));
268 }
269 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
286 for i in 0..n {
287 let q0 = sc.q0_i[i];
288 let q1 = sc.q1_i[i];
289 let qd1 = sc.qd1_i[i];
290 let g = sc.g_i[i];
291 let z = sc.z_i[i];
292 match crate::survival::marginal_slope::row_primary_for_compiler(
295 q0, q1, qd1, g, z, 1.0, 1.0, crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
298 sc.s, ) {
300 Ok((_nll, _grad, hess)) => {
301 let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
302 for a in 0..K_SURVIVAL {
303 for b in 0..K_SURVIVAL {
304 h_i[[a, b]] = hess[a][b];
305 }
306 }
307 let clamped = psd_clamp_4x4(&h_i);
308 for a in 0..K_SURVIVAL {
309 for b in 0..K_SURVIVAL {
310 h_full[[i, a, b]] = clamped[[a, b]];
311 }
312 }
313 }
314 Err(_) => {
315 for a in 0..K_SURVIVAL {
318 for b in 0..K_SURVIVAL {
319 h_full[[i, a, b]] = self.h[[i, a, b]];
320 }
321 }
322 }
323 }
324 }
325 Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
326 }
327 }
328 }
329}
330
331fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
336 let k = m.nrows();
337 let (evals, evecs) = match m.eigh(Side::Lower) {
338 Ok(pair) => pair,
339 Err(_) => {
340 let mut out = Array2::<f64>::zeros((k, k));
341 for i in 0..k {
342 out[[i, i]] = m[[i, i]].max(0.0);
343 }
344 return out;
345 }
346 };
347 let mut out = Array2::<f64>::zeros((k, k));
348 for i in 0..k {
349 for j in 0..k {
350 let mut acc = 0.0;
351 for l in 0..k {
352 acc += evecs[[i, l]] * evals[l].max(0.0) * evecs[[j, l]];
353 }
354 out[[i, j]] = acc;
355 }
356 }
357 out
358}
359
/// Row Jacobian operator built from three equally shaped matrices; channel 3
/// of its output is always zero.
pub struct TimeBlockOperator {
    /// Entries written to channel 0.
    dq0: Array2<f64>,
    /// Entries written to channel 1.
    dq1: Array2<f64>,
    /// Entries written to channel 2.
    dqd1: Array2<f64>,
}
367
368impl TimeBlockOperator {
369 pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
370 assert_eq!(dq0.dim(), dq1.dim());
371 assert_eq!(dq0.dim(), dqd1.dim());
372 Self { dq0, dq1, dqd1 }
373 }
374}
375
376impl RowJacobianOperator for TimeBlockOperator {
377 fn k(&self) -> usize {
378 K_SURVIVAL
379 }
380 fn ncols(&self) -> usize {
381 self.dq0.ncols()
382 }
383 fn nrows(&self) -> usize {
384 self.dq0.nrows()
385 }
386 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
387 assert_eq!(out.len(), K_SURVIVAL);
388 assert_eq!(delta_beta.len(), self.dq0.ncols());
389 let mut acc = [0.0_f64; K_SURVIVAL];
390 for (j, &b) in delta_beta.iter().enumerate() {
391 acc[0] += self.dq0[[row, j]] * b;
392 acc[1] += self.dq1[[row, j]] * b;
393 acc[2] += self.dqd1[[row, j]] * b;
394 }
395 out.copy_from_slice(&acc);
396 }
397 fn evaluate_full(&self) -> Array3<f64> {
398 let n = self.dq0.nrows();
399 let p = self.dq0.ncols();
400 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
401 for i in 0..n {
402 for j in 0..p {
403 out[[i, j, 0]] = self.dq0[[i, j]];
404 out[[i, j, 1]] = self.dq1[[i, j]];
405 out[[i, j, 2]] = self.dqd1[[i, j]];
406 }
407 }
408 out
409 }
410 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
411 let n = self.dq0.nrows();
417 let p = self.dq0.ncols();
418 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
419 0 => self.dq0[[i, a]],
420 1 => self.dq1[[i, a]],
421 2 => self.dqd1[[i, a]],
422 _ => 0.0,
423 })
424 }
425}
426
/// Row Jacobian operator whose `dq` matrix feeds channels 0 and 1 alike and
/// whose `dqd1` matrix feeds channel 2; channel 3 is always zero.
pub struct QChannelBlockOperator {
    /// Entries written to both channel 0 and channel 1.
    dq: Array2<f64>,
    /// Entries written to channel 2.
    dqd1: Array2<f64>,
}
435
436impl QChannelBlockOperator {
437 pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
438 assert_eq!(dq.dim(), dqd1.dim());
439 Self { dq, dqd1 }
440 }
441}
442
443impl RowJacobianOperator for QChannelBlockOperator {
444 fn k(&self) -> usize {
445 K_SURVIVAL
446 }
447 fn ncols(&self) -> usize {
448 self.dq.ncols()
449 }
450 fn nrows(&self) -> usize {
451 self.dq.nrows()
452 }
453 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
454 assert_eq!(out.len(), K_SURVIVAL);
455 assert_eq!(delta_beta.len(), self.dq.ncols());
456 let mut dq_acc = 0.0;
457 let mut dqd_acc = 0.0;
458 for (j, &b) in delta_beta.iter().enumerate() {
459 dq_acc += self.dq[[row, j]] * b;
460 dqd_acc += self.dqd1[[row, j]] * b;
461 }
462 out[0] = dq_acc;
463 out[1] = dq_acc;
464 out[2] = dqd_acc;
465 out[3] = 0.0;
466 }
467 fn evaluate_full(&self) -> Array3<f64> {
468 let n = self.dq.nrows();
469 let p = self.dq.ncols();
470 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
471 for i in 0..n {
472 for j in 0..p {
473 let v = self.dq[[i, j]];
474 out[[i, j, 0]] = v;
475 out[[i, j, 1]] = v;
476 out[[i, j, 2]] = self.dqd1[[i, j]];
477 }
478 }
479 out
480 }
481 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
482 let n = self.dq.nrows();
486 let p = self.dq.ncols();
487 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
488 0 | 1 => self.dq[[i, a]],
489 2 => self.dqd1[[i, a]],
490 _ => 0.0,
491 })
492 }
493}
494
/// Row Jacobian operator that writes only to channel 3; channels 0..=2 of its
/// output are always zero.
pub struct LogslopeBlockOperator {
    /// Entries written to channel 3.
    dg: Array2<f64>,
}
500
501impl LogslopeBlockOperator {
502 pub fn new(dg: Array2<f64>) -> Self {
503 Self { dg }
504 }
505}
506
507impl RowJacobianOperator for LogslopeBlockOperator {
508 fn k(&self) -> usize {
509 K_SURVIVAL
510 }
511 fn ncols(&self) -> usize {
512 self.dg.ncols()
513 }
514 fn nrows(&self) -> usize {
515 self.dg.nrows()
516 }
517 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
518 assert_eq!(out.len(), K_SURVIVAL);
519 assert_eq!(delta_beta.len(), self.dg.ncols());
520 let mut acc = 0.0;
521 for (j, &b) in delta_beta.iter().enumerate() {
522 acc += self.dg[[row, j]] * b;
523 }
524 out[0] = 0.0;
525 out[1] = 0.0;
526 out[2] = 0.0;
527 out[3] = acc;
528 }
529 fn evaluate_full(&self) -> Array3<f64> {
530 let n = self.dg.nrows();
531 let p = self.dg.ncols();
532 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
533 for i in 0..n {
534 for j in 0..p {
535 out[[i, j, 3]] = self.dg[[i, j]];
536 }
537 }
538 out
539 }
540 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
541 let n = self.dg.nrows();
546 let p = self.dg.ncols();
547 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| {
548 if c == 3 { self.dg[[i, a]] } else { 0.0 }
549 })
550 }
551}
552
/// Operators and their block tags, as assembled by
/// `build_survival_compiler_inputs` and passed to the compiler.
pub struct SurvivalCompilerInputs {
    /// One row Jacobian operator per block.
    pub operators: Vec<Arc<dyn RowJacobianOperator>>,
    /// Block tag for each operator; entry `k` describes `operators[k]`.
    pub ordering: Vec<BlockOrder>,
}
560
/// Result of `compile_survival_parametric_designs`: one coefficient transform
/// per parametric block.
pub struct SurvivalParametricCompiled {
    /// `t_lw` of the compiled time block.
    pub v_time: Array2<f64>,
    /// `t_lw` of the compiled marginal block.
    pub v_marginal: Array2<f64>,
    /// `t_lw` of the compiled logslope block.
    pub v_logslope: Array2<f64>,
    /// Raw column count minus transform column count (saturating at zero),
    /// in the order (time, marginal, logslope).
    pub drops_by_block: (usize, usize, usize),
}
586
587fn wrap_design_with_transform(
588 raw: DesignMatrix,
589 v: &Array2<f64>,
590 context: &str,
591) -> Result<DesignMatrix, String> {
592 if raw.ncols() != v.nrows() {
593 return Err(format!(
594 "{context}: raw design has {} cols but V has {} rows (V is {}×{})",
595 raw.ncols(),
596 v.nrows(),
597 v.nrows(),
598 v.ncols(),
599 ));
600 }
601 let inner_dense = match raw {
602 DesignMatrix::Dense(d) => d,
603 DesignMatrix::Sparse(_) => {
604 let dense = raw
605 .try_to_dense_by_chunks(&format!("{context} sparse→dense for V apply"))
606 .map_err(|reason| format!("{context}: densify failed: {reason}"))?;
607 DenseDesignMatrix::from(dense)
608 }
609 };
610 let op = CoefficientTransformOperator::new(inner_dense, v.clone())
611 .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))?;
612 Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(op))))
613}
614
/// Result of `compile_survival_parametric_designs_per_term`: one coefficient
/// transform per term of each parametric block.
pub struct SurvivalParametricCompiledPerTerm {
    /// `t_lw` of each compiled time term, in partition order.
    pub v_time_per_term: Vec<Array2<f64>>,
    /// `t_lw` of each compiled marginal term, in partition order.
    pub v_marginal_per_term: Vec<Array2<f64>>,
    /// `t_lw` of each compiled logslope term, in partition order.
    pub v_logslope_per_term: Vec<Array2<f64>>,
    /// `r_lw` of every compiled term: time terms first, then marginal terms,
    /// then logslope terms.
    pub r_lw_per_term: Vec<Option<Array2<f64>>>,
    /// Per block, the sum over terms of raw width minus transform column
    /// count (saturating at zero), in the order (time, marginal, logslope).
    pub drops_by_block: (usize, usize, usize),
}
636
637pub fn compile_survival_parametric_designs_per_term(
656 time_dq0: Array2<f64>,
657 time_dq1: Array2<f64>,
658 time_dqd1: Array2<f64>,
659 time_partition: &[std::ops::Range<usize>],
660 marginal_dq: Array2<f64>,
661 marginal_dqd1: Array2<f64>,
662 marginal_partition: &[std::ops::Range<usize>],
663 logslope_dg: Array2<f64>,
664 logslope_partition: &[std::ops::Range<usize>],
665 row_hess: &dyn RowHessian,
666) -> Result<SurvivalParametricCompiledPerTerm, String> {
667 use gam_identifiability::families::compiler::compile;
668
669 let p_time = time_dq0.ncols();
670 let p_marg = marginal_dq.ncols();
671 let p_log = logslope_dg.ncols();
672 validate_partition(time_partition, p_time, "time")?;
673 validate_partition(marginal_partition, p_marg, "marginal")?;
674 validate_partition(logslope_partition, p_log, "logslope")?;
675
676 let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::new();
680 let mut ordering: Vec<BlockOrder> = Vec::new();
681 for range in time_partition {
682 let dq0 = time_dq0.slice(ndarray::s![.., range.clone()]).to_owned();
683 let dq1 = time_dq1.slice(ndarray::s![.., range.clone()]).to_owned();
684 let dqd1 = time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned();
685 operators.push(Arc::new(TimeBlockOperator::new(dq0, dq1, dqd1)));
686 ordering.push(BlockOrder::Time);
687 }
688 for range in marginal_partition {
689 let dq = marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned();
690 let dqd1 = marginal_dqd1
691 .slice(ndarray::s![.., range.clone()])
692 .to_owned();
693 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
694 ordering.push(BlockOrder::Marginal);
695 }
696 for range in logslope_partition {
697 let dg = logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned();
698 operators.push(Arc::new(LogslopeBlockOperator::new(dg)));
699 ordering.push(BlockOrder::Logslope);
700 }
701
702 let compiled = compile(&operators, row_hess, &ordering).map_err(|e| {
703 format!("identifiability::families::compiler::compile (per-term) failed: {e}")
704 })?;
705 let blocks = compiled.blocks;
706 let n_time = time_partition.len();
707 let n_marg = marginal_partition.len();
708 let n_log = logslope_partition.len();
709 if blocks.len() != n_time + n_marg + n_log {
710 return Err(format!(
711 "per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
712 n_time + n_marg + n_log,
713 n_time,
714 n_marg,
715 n_log,
716 blocks.len(),
717 ));
718 }
719 let mut iter = blocks.into_iter();
720 let mut v_time_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_time);
721 let mut r_time_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time);
722 for _ in 0..n_time {
723 let blk = iter.next().unwrap();
724 v_time_per_term.push(blk.t_lw);
725 r_time_per_term.push(blk.r_lw);
726 }
727 let mut v_marginal_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_marg);
728 let mut r_marginal_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_marg);
729 for _ in 0..n_marg {
730 let blk = iter.next().unwrap();
731 v_marginal_per_term.push(blk.t_lw);
732 r_marginal_per_term.push(blk.r_lw);
733 }
734 let mut v_logslope_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_log);
735 let mut r_logslope_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_log);
736 for _ in 0..n_log {
737 let blk = iter.next().unwrap();
738 v_logslope_per_term.push(blk.t_lw);
739 r_logslope_per_term.push(blk.r_lw);
740 }
741 let mut r_lw_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time + n_marg + n_log);
742 r_lw_per_term.extend(r_time_per_term);
743 r_lw_per_term.extend(r_marginal_per_term);
744 r_lw_per_term.extend(r_logslope_per_term);
745 let drops_time: usize = time_partition
746 .iter()
747 .zip(v_time_per_term.iter())
748 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
749 .sum();
750 let drops_marg: usize = marginal_partition
751 .iter()
752 .zip(v_marginal_per_term.iter())
753 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
754 .sum();
755 let drops_log: usize = logslope_partition
756 .iter()
757 .zip(v_logslope_per_term.iter())
758 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
759 .sum();
760 Ok(SurvivalParametricCompiledPerTerm {
761 v_time_per_term,
762 v_marginal_per_term,
763 v_logslope_per_term,
764 r_lw_per_term,
765 drops_by_block: (drops_time, drops_marg, drops_log),
766 })
767}
768
/// Checks that `partition` splits `0..p_block` into consecutive, non-empty
/// ranges. `label` names the block in error messages.
///
/// An empty partition is accepted only when `p_block` is zero.
///
/// # Errors
///
/// Returns `Err` when the first range does not start at 0, the last range
/// does not end at `p_block`, two neighbouring ranges do not meet exactly, or
/// any range is empty.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let (first, last) = match (partition.first(), partition.last()) {
        (Some(first), Some(last)) => (first, last),
        _ => {
            return if p_block == 0 {
                Ok(())
            } else {
                Err(format!(
                    "{label} partition empty but block has p={p_block} columns"
                ))
            };
        }
    };

    if first.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            first.start
        ));
    }
    if last.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            last.end
        ));
    }

    // Walk neighbouring pairs; the final range is checked after the loop.
    for (prev, next) in partition.iter().zip(partition.iter().skip(1)) {
        if prev.end != next.start {
            return Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                prev.start, prev.end, next.start, next.end
            ));
        }
        if prev.is_empty() {
            return Err(format!(
                "{label} partition has empty range [{}..{})",
                prev.start, prev.end
            ));
        }
    }

    if last.is_empty() {
        return Err(format!("{label} partition's final range is empty",));
    }
    Ok(())
}
813
/// Derives a partition of `0..p_block` from penalty column ranges.
///
/// The cut points are 0, `p_block`, and every penalty range's start and end
/// (each capped at `p_block`). The result is the list of non-empty ranges
/// between consecutive distinct cut points, in ascending order; it is empty
/// when `p_block` is zero.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    let mut cuts: Vec<usize> = Vec::with_capacity(2 * penalty_ranges.len() + 2);
    cuts.push(0);
    cuts.push(p_block);
    for range in penalty_ranges {
        cuts.push(range.start.min(p_block));
        cuts.push(range.end.min(p_block));
    }
    // Sorted and de-duplicated, so neighbouring cuts are strictly increasing.
    cuts.sort_unstable();
    cuts.dedup();
    cuts.windows(2).map(|pair| pair[0]..pair[1]).collect()
}
836
837pub fn pull_back_blockwise_penalty_through_block_v(
860 pen: &gam_terms::smooth::BlockwisePenalty,
861 v_block: &Array2<f64>,
862) -> Result<PenaltyMatrix, String> {
863 let raw_p = v_block.nrows();
864 let compiled_p = v_block.ncols();
865 let block_p = pen.col_range.len();
866 let embed_start = pen.col_range.start;
867 let embed_end = pen.col_range.end;
868 if embed_end > raw_p {
869 return Err(format!(
870 "pull_back_blockwise_penalty_through_block_v: penalty col_range {embed_start}..{embed_end} \
871 exceeds block raw width {raw_p}"
872 ));
873 }
874 if pen.local.nrows() != block_p || pen.local.ncols() != block_p {
875 return Err(format!(
876 "pull_back_blockwise_penalty_through_block_v: penalty local is {}x{} but col_range \
877 width is {block_p}",
878 pen.local.nrows(),
879 pen.local.ncols(),
880 ));
881 }
882 let mut embedded = Array2::<f64>::zeros((raw_p, raw_p));
883 if block_p > 0 {
884 let mut dst =
885 embedded.slice_mut(ndarray::s![embed_start..embed_end, embed_start..embed_end]);
886 for i in 0..block_p {
887 for j in 0..block_p {
888 dst[[i, j]] = pen.local[[i, j]];
889 }
890 }
891 }
892 let temp = embedded.dot(v_block);
894 let pulled = v_block.t().dot(&temp);
895 let mut sym = Array2::<f64>::zeros((compiled_p, compiled_p));
896 for i in 0..compiled_p {
897 for j in 0..compiled_p {
898 sym[[i, j]] = 0.5 * (pulled[[i, j]] + pulled[[j, i]]);
899 }
900 }
901 Ok(PenaltyMatrix::Dense(sym))
902}
903
904pub fn compiled_map_from_per_term(
926 compiled: &SurvivalParametricCompiledPerTerm,
927) -> gam_identifiability::families::compiler::CompiledMap {
928 let mut v_all: Vec<Array2<f64>> = Vec::new();
931 v_all.extend(compiled.v_time_per_term.iter().cloned());
932 v_all.extend(compiled.v_marginal_per_term.iter().cloned());
933 v_all.extend(compiled.v_logslope_per_term.iter().cloned());
934
935 let t_full = assemble_block_triangular_t(&v_all, &compiled.r_lw_per_term);
936
937 let raw_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.nrows()).sum() };
939 let kept_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.ncols()).sum() };
940 let raw_time = raw_w(&compiled.v_time_per_term);
941 let raw_marg = raw_w(&compiled.v_marginal_per_term);
942 let raw_log = raw_w(&compiled.v_logslope_per_term);
943 let kept_time = kept_w(&compiled.v_time_per_term);
944 let kept_marg = kept_w(&compiled.v_marginal_per_term);
945 let kept_log = kept_w(&compiled.v_logslope_per_term);
946
947 let raw_block_ranges = vec![
948 0..raw_time,
949 raw_time..(raw_time + raw_marg),
950 (raw_time + raw_marg)..(raw_time + raw_marg + raw_log),
951 ];
952 let compiled_block_ranges = vec![
953 0..kept_time,
954 kept_time..(kept_time + kept_marg),
955 (kept_time + kept_marg)..(kept_time + kept_marg + kept_log),
956 ];
957
958 gam_identifiability::families::compiler::CompiledMap {
959 raw_from_compiled: t_full,
960 compiled_block_ranges,
961 raw_block_ranges,
962 }
963}
964
965pub fn apply_compiled_map_to_designs(
993 map: &gam_identifiability::families::compiler::CompiledMap,
994 time_design_entry: DesignMatrix,
995 time_design_exit: DesignMatrix,
996 time_design_derivative_exit: DesignMatrix,
997 marginal_design: DesignMatrix,
998 logslope_design: DesignMatrix,
999 time_penalties: &[gam_terms::smooth::BlockwisePenalty],
1000 marginal_penalties: &[gam_terms::smooth::BlockwisePenalty],
1001 logslope_penalties: &[gam_terms::smooth::BlockwisePenalty],
1002) -> Result<CompiledSurvivalDesignsVMExact, String> {
1003 if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
1004 return Err(format!(
1005 "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
1006 got {} raw / {} compiled",
1007 map.raw_block_ranges.len(),
1008 map.compiled_block_ranges.len(),
1009 ));
1010 }
1011 let time_raw = map.raw_block_ranges[0].clone();
1012 let marg_raw = map.raw_block_ranges[1].clone();
1013 let log_raw = map.raw_block_ranges[2].clone();
1014 let time_compiled = map.compiled_block_ranges[0].clone();
1015 let marg_compiled = map.compiled_block_ranges[1].clone();
1016 let log_compiled = map.compiled_block_ranges[2].clone();
1017
1018 let t = &map.raw_from_compiled;
1019 let raw_total = t.nrows();
1020 let compiled_total = t.ncols();
1021 let expected_raw_total = log_raw.end;
1022 if raw_total != expected_raw_total {
1023 return Err(format!(
1024 "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
1025 {expected_raw_total}"
1026 ));
1027 }
1028 let expected_compiled_total = log_compiled.end;
1029 if compiled_total != expected_compiled_total {
1030 return Err(format!(
1031 "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
1032 sum to {expected_compiled_total}"
1033 ));
1034 }
1035
1036 let v_time = t
1037 .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
1038 .to_owned();
1039 let v_marg = t
1040 .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
1041 .to_owned();
1042 let v_log = t
1043 .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
1044 .to_owned();
1045
1046 let time_entry_out =
1047 wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
1048 let time_exit_out =
1049 wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
1050 let time_deriv_out = wrap_design_with_transform(
1051 time_design_derivative_exit,
1052 &v_time,
1053 "compiled-map: time derivative_exit",
1054 )?;
1055 let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
1056 let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
1057
1058 let pull_set = |pens: &[gam_terms::smooth::BlockwisePenalty],
1079 v_block: &Array2<f64>,
1080 channel: &str|
1081 -> Result<Vec<PenaltyMatrix>, String> {
1082 pens.iter()
1083 .map(|p| {
1084 pull_back_blockwise_penalty_through_block_v(p, v_block).map_err(|e| {
1085 format!("apply_compiled_map_to_designs: {channel} penalty pullback: {e}")
1086 })
1087 })
1088 .collect()
1089 };
1090
1091 let time_penalties = pull_set(time_penalties, &v_time, "time")?;
1092 let marginal_penalties = pull_set(marginal_penalties, &v_marg, "marginal")?;
1093 let logslope_penalties = pull_set(logslope_penalties, &v_log, "logslope")?;
1094 validate_block_penalty_shapes("time", time_exit_out.ncols(), &time_penalties)?;
1095 validate_block_penalty_shapes("marginal", marg_out.ncols(), &marginal_penalties)?;
1096 validate_block_penalty_shapes("logslope", log_out.ncols(), &logslope_penalties)?;
1097
1098 Ok(CompiledSurvivalDesignsVMExact {
1099 time_design_entry: time_entry_out,
1100 time_design_exit: time_exit_out,
1101 time_design_derivative_exit: time_deriv_out,
1102 marginal_design: marg_out,
1103 logslope_design: log_out,
1104 time_penalties,
1105 marginal_penalties,
1106 logslope_penalties,
1107 })
1108}
1109
1110fn validate_block_penalty_shapes(
1111 block: &str,
1112 width: usize,
1113 penalties: &[PenaltyMatrix],
1114) -> Result<(), String> {
1115 for (idx, penalty) in penalties.iter().enumerate() {
1116 let shape = penalty.shape();
1117 if shape != (width, width) {
1118 return Err(format!(
1119 "apply_compiled_map_to_designs: {block} penalty {idx} must be {width}x{width}, got {}x{}",
1120 shape.0, shape.1
1121 ));
1122 }
1123 }
1124 Ok(())
1125}
1126
1127pub fn compile_survival_parametric_designs(
1155 time_dq0: Array2<f64>,
1156 time_dq1: Array2<f64>,
1157 time_dqd1: Array2<f64>,
1158 marginal_dq: Array2<f64>,
1159 marginal_dqd1: Array2<f64>,
1160 logslope_dg: Array2<f64>,
1161 row_hess: &dyn RowHessian,
1162) -> Result<SurvivalParametricCompiled, String> {
1163 use gam_identifiability::families::compiler::compile;
1164
1165 let p_time_raw = time_dq0.ncols();
1166 let p_marg_raw = marginal_dq.ncols();
1167 let p_log_raw = logslope_dg.ncols();
1168
1169 let inputs = build_survival_compiler_inputs(
1170 time_dq0,
1171 time_dq1,
1172 time_dqd1,
1173 marginal_dq,
1174 marginal_dqd1,
1175 logslope_dg,
1176 None,
1177 None,
1178 );
1179 if inputs.operators.len() != 3 {
1180 return Err(format!(
1181 "compile_survival_parametric_designs: expected exactly 3 parametric operators \
1182 (time, marginal, logslope); got {}",
1183 inputs.operators.len(),
1184 ));
1185 }
1186 let compiled = compile(&inputs.operators, row_hess, &inputs.ordering)
1187 .map_err(|e| format!("identifiability::families::compiler::compile failed: {e}"))?;
1188 if compiled.blocks.len() != 3 {
1189 return Err(format!(
1190 "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
1191 compiled.blocks.len(),
1192 ));
1193 }
1194 let v_time = compiled.blocks[0].t_lw.clone();
1195 let v_marginal = compiled.blocks[1].t_lw.clone();
1196 let v_logslope = compiled.blocks[2].t_lw.clone();
1197 let drops_by_block = (
1198 p_time_raw.saturating_sub(v_time.ncols()),
1199 p_marg_raw.saturating_sub(v_marginal.ncols()),
1200 p_log_raw.saturating_sub(v_logslope.ncols()),
1201 );
1202 Ok(SurvivalParametricCompiled {
1203 v_time,
1204 v_marginal,
1205 v_logslope,
1206 drops_by_block,
1207 })
1208}
1209
1210pub fn build_survival_compiler_inputs(
1222 time_dq0: Array2<f64>,
1223 time_dq1: Array2<f64>,
1224 time_dqd1: Array2<f64>,
1225 marginal_dq: Array2<f64>,
1226 marginal_dqd1: Array2<f64>,
1227 logslope_dg: Array2<f64>,
1228 score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1229 link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1230) -> SurvivalCompilerInputs {
1231 let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(5);
1232 let mut ordering: Vec<BlockOrder> = Vec::with_capacity(5);
1233
1234 operators.push(Arc::new(TimeBlockOperator::new(
1235 time_dq0, time_dq1, time_dqd1,
1236 )));
1237 ordering.push(BlockOrder::Time);
1238
1239 operators.push(Arc::new(QChannelBlockOperator::new(
1240 marginal_dq,
1241 marginal_dqd1,
1242 )));
1243 ordering.push(BlockOrder::Marginal);
1244
1245 operators.push(Arc::new(LogslopeBlockOperator::new(logslope_dg)));
1246 ordering.push(BlockOrder::Logslope);
1247
1248 if let Some((dq, dqd1)) = score_warp_dq_dqd1 {
1249 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1250 ordering.push(BlockOrder::ScoreWarp);
1251 }
1252 if let Some((dq, dqd1)) = link_dev_dq_dqd1 {
1253 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1254 ordering.push(BlockOrder::LinkDev);
1255 }
1256
1257 SurvivalCompilerInputs {
1258 operators,
1259 ordering,
1260 }
1261}
1262
/// Designs and penalties produced by `apply_compiled_map_to_designs`.
///
/// Every design is the corresponding input design wrapped so that it applies
/// its block's sub-matrix of the compiled map; every penalty is the
/// corresponding input penalty pulled back through that same sub-matrix.
pub struct CompiledSurvivalDesignsVMExact {
    /// Time entry design, wrapped with the time sub-matrix.
    pub time_design_entry: DesignMatrix,
    /// Time exit design, wrapped with the time sub-matrix.
    pub time_design_exit: DesignMatrix,
    /// Time derivative-at-exit design, wrapped with the time sub-matrix.
    pub time_design_derivative_exit: DesignMatrix,
    /// Marginal design, wrapped with the marginal sub-matrix.
    pub marginal_design: DesignMatrix,
    /// Logslope design, wrapped with the logslope sub-matrix.
    pub logslope_design: DesignMatrix,
    /// Pulled-back time penalties, each as wide as `time_design_exit`.
    pub time_penalties: Vec<PenaltyMatrix>,
    /// Pulled-back marginal penalties, each as wide as `marginal_design`.
    pub marginal_penalties: Vec<PenaltyMatrix>,
    /// Pulled-back logslope penalties, each as wide as `logslope_design`.
    pub logslope_penalties: Vec<PenaltyMatrix>,
}
1298
1299#[cfg(test)]
1300mod tests {
1301 use super::*;
1302 use gam_problem::Gauge;
1303
1304 #[test]
1305 fn psd_clamp_zeros_negative_eigenvalues() {
1306 let mut m = Array2::<f64>::zeros((4, 4));
1310 m[[0, 0]] = 2.0;
1313 m[[1, 1]] = -1.0;
1314 m[[2, 2]] = 0.5;
1315 m[[3, 3]] = -0.25;
1316 let clamped = psd_clamp_4x4(&m);
1317 assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
1318 assert!(clamped[[1, 1]].abs() < 1e-12);
1319 assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
1320 assert!(clamped[[3, 3]].abs() < 1e-12);
1321 }
1322
1323 #[test]
1324 fn time_block_operator_evaluate_full_shape() {
1325 let n = 6;
1326 let p = 3;
1327 let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
1328 let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
1329 let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
1330 let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
1331 let full = op.evaluate_full();
1332 assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
1333 for i in 0..n {
1334 for j in 0..p {
1335 assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
1336 assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
1337 assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
1338 assert_eq!(full[[i, j, 3]], 0.0);
1339 }
1340 }
1341 }
1342
1343 #[test]
1344 fn q_channel_block_apply_row_shares_q0_q1() {
1345 let n = 5;
1346 let p = 2;
1347 let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
1348 let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
1349 let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
1350 let mut out = [0.0_f64; K_SURVIVAL];
1351 let delta = [1.0_f64, -0.5];
1352 op.apply_row(3, &delta, &mut out);
1353 let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
1354 let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
1355 assert!((out[0] - want_q).abs() < 1e-12);
1356 assert!((out[1] - want_q).abs() < 1e-12);
1357 assert!((out[2] - want_qd).abs() < 1e-12);
1358 assert_eq!(out[3], 0.0);
1359 }
1360
1361 #[test]
1362 fn logslope_block_writes_only_g_channel() {
1363 let n = 4;
1364 let p = 2;
1365 let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
1366 let op = LogslopeBlockOperator::new(dg.clone());
1367 let mut out = [0.0_f64; K_SURVIVAL];
1368 let delta = [2.0_f64, -1.0];
1369 op.apply_row(1, &delta, &mut out);
1370 assert_eq!(out[0], 0.0);
1371 assert_eq!(out[1], 0.0);
1372 assert_eq!(out[2], 0.0);
1373 let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
1374 assert!((out[3] - want).abs() < 1e-12);
1375 }
1376
1377 #[test]
1378 fn extract_term_partition_simple_cases() {
1379 let full = 0..5usize;
1380 let part = extract_term_partition_from_penalty_ranges(5, &[]);
1382 assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1383 let part = extract_term_partition_from_penalty_ranges(5, std::slice::from_ref(&full));
1385 assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1386 let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
1388 assert_eq!(part, vec![0..3, 3..6, 6..10]);
1389 let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
1391 assert_eq!(part, vec![0..3, 3..6]);
1392 let part = extract_term_partition_from_penalty_ranges(0, &[]);
1394 assert!(part.is_empty());
1395 }
1396
1397 #[test]
1398 fn assemble_block_triangular_t_identity_when_v_eye_and_r_none() {
1399 let v_a = Array2::<f64>::eye(2);
1400 let v_b = Array2::<f64>::eye(2);
1401 let t = assemble_block_triangular_t(&[v_a, v_b], &[None, None]);
1402 assert_eq!(t.dim(), (4, 4));
1403 let eye4 = Array2::<f64>::eye(4);
1404 for i in 0..4 {
1405 for j in 0..4 {
1406 assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
1407 }
1408 }
1409 }
1410
1411 #[test]
1412 fn assemble_block_triangular_t_with_drops_and_nonzero_r() {
1413 let mut v_a = Array2::<f64>::zeros((3, 2));
1414 v_a[[0, 0]] = 1.0;
1415 v_a[[1, 0]] = 0.5;
1416 v_a[[2, 1]] = 1.0;
1417 let v_b = Array2::<f64>::eye(2);
1418 let r_ab =
1419 Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
1420 let t =
1421 assemble_block_triangular_t(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
1422 assert_eq!(t.dim(), (5, 4));
1423 for i in 0..3 {
1424 for j in 0..2 {
1425 assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
1426 }
1427 }
1428 for i in 0..2 {
1429 for j in 0..2 {
1430 assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
1431 }
1432 }
1433 for i in 0..3 {
1434 for j in 0..2 {
1435 assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
1436 }
1437 }
1438 for i in 0..2 {
1439 for j in 0..2 {
1440 assert_eq!(t[[3 + i, j]], 0.0);
1441 }
1442 }
1443 }
1444
1445 #[test]
1446 fn validate_partition_rejects_bad_partitions() {
1447 let bad_start = 1..5usize;
1448 let short_cover = 0..3usize;
1449 let full_cover = 0..5usize;
1450 assert!(validate_partition(std::slice::from_ref(&bad_start), 5, "test").is_err());
1452 assert!(validate_partition(std::slice::from_ref(&short_cover), 5, "test").is_err());
1454 assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
1456 assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
1458 assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
1460 assert!(validate_partition(&[], 0, "test").is_ok());
1462 assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
1464 assert!(validate_partition(std::slice::from_ref(&full_cover), 5, "test").is_ok());
1465 }
1466
1467 #[test]
1478 fn compiled_map_penalty_pullback_is_per_block_width_with_nonzero_residual() {
1479 use gam_identifiability::families::compiler::CompiledMap;
1480 use gam_terms::smooth::BlockwisePenalty;
1481
1482 let n = 10;
1483 let v_time =
1487 Array2::<f64>::from_shape_fn(
1488 (3, 3),
1489 |(i, j)| {
1490 if i == j { 1.0 } else { 0.1 * ((i + j) as f64) }
1491 },
1492 );
1493 let v_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
1494 0.5 + 0.3 * (i as f64) - 0.2 * (j as f64)
1495 });
1496 let v_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 1.2 } else { 0.4 });
1497 let r_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 0.7 - 0.1 * ((i + j) as f64));
1499 let r_log =
1504 Array2::<f64>::from_shape_fn((6, 2), |(i, j)| 0.3 + 0.05 * ((i * 2 + j) as f64));
1505
1506 let t = assemble_block_triangular_t(
1507 &[v_time.clone(), v_marg.clone(), v_log.clone()],
1508 &[None, Some(r_marg.clone()), Some(r_log.clone())],
1509 );
1510 assert_eq!(t.dim(), (8, 7), "joint raw 8 × joint compiled 7");
1511
1512 let map = CompiledMap {
1513 raw_from_compiled: t.clone(),
1514 compiled_block_ranges: vec![0..3, 3..5, 5..7],
1515 raw_block_ranges: vec![0..3, 3..6, 6..8],
1516 };
1517
1518 let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
1520 Array2::<f64>::from_shape_fn((n, 3), |(i, j)| 1.0 + (i as f64) * 0.1 + (j as f64)),
1521 ));
1522 let raw_time_exit = raw_time_entry.clone();
1523 let raw_time_deriv = raw_time_entry.clone();
1524 let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1525 (n, 3),
1526 |(i, j)| 0.2 * (i as f64) - 0.3 * (j as f64),
1527 )));
1528 let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1529 (n, 2),
1530 |(i, j)| 0.5 + (i as f64) * (j as f64 + 1.0),
1531 )));
1532
1533 let s_time =
1535 Array2::<f64>::from_shape_fn(
1536 (3, 3),
1537 |(i, j)| if i == j { (i + 2) as f64 } else { 0.3 },
1538 );
1539 let s_marg =
1540 Array2::<f64>::from_shape_fn(
1541 (3, 3),
1542 |(i, j)| if i == j { 1.5 + i as f64 } else { 0.2 },
1543 );
1544 let s_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 2.0 } else { 0.5 });
1545 let time_pens = vec![BlockwisePenalty::new(0..3, s_time.clone())];
1546 let marg_pens = vec![BlockwisePenalty::new(0..3, s_marg.clone())];
1547 let log_pens = vec![BlockwisePenalty::new(0..2, s_log.clone())];
1548
1549 let out = apply_compiled_map_to_designs(
1550 &map,
1551 raw_time_entry,
1552 raw_time_exit,
1553 raw_time_deriv,
1554 raw_marg,
1555 raw_log,
1556 &time_pens,
1557 &marg_pens,
1558 &log_pens,
1559 )
1560 .expect("apply_compiled_map_to_designs must succeed");
1561
1562 assert_eq!(out.time_design_entry.ncols(), 3);
1564 assert_eq!(out.marginal_design.ncols(), 2);
1565 assert_eq!(out.logslope_design.ncols(), 2);
1566
1567 for s in &out.time_penalties {
1570 assert_eq!(
1571 s.as_dense_cow().dim(),
1572 (3, 3),
1573 "time penalty must be per-block 3×3, not joint-width"
1574 );
1575 }
1576 for s in &out.marginal_penalties {
1577 assert_eq!(
1578 s.as_dense_cow().dim(),
1579 (2, 2),
1580 "marginal penalty must match reduced compiled width 2, not joint 7"
1581 );
1582 }
1583 for s in &out.logslope_penalties {
1584 assert_eq!(s.as_dense_cow().dim(), (2, 2));
1585 }
1586
1587 let p_time_dense = out.time_penalties[0].as_dense_cow().into_owned();
1591 let theta_time = Array1::<f64>::from_shape_fn(3, |k| 0.4 + 0.7 * (k as f64));
1592 let gamma_time = v_time.dot(&theta_time);
1593 let lhs = theta_time.dot(&p_time_dense.dot(&theta_time));
1594 let rhs = gamma_time.dot(&s_time.dot(&gamma_time));
1595 assert!(
1596 (lhs - rhs).abs() < 1e-10,
1597 "time-block per-block pullback must be exact: lhs={lhs}, rhs={rhs}"
1598 );
1599
1600 let p_marg_dense = out.marginal_penalties[0].as_dense_cow().into_owned();
1603 let want_marg = v_marg.t().dot(&s_marg.dot(&v_marg));
1604 for i in 0..2 {
1605 for j in 0..2 {
1606 assert!(
1607 (p_marg_dense[[i, j]] - want_marg[[i, j]]).abs() < 1e-12,
1608 "marginal penalty must be V_margᵀ S_marg V_marg at ({i},{j})"
1609 );
1610 }
1611 }
1612 }
1613
1614 #[test]
1621 fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
1622 let n = 24;
1623 let p_time = 3;
1624 let p_marginal = 3;
1625 let p_logslope = 2;
1626 let x: Vec<f64> = (0..n)
1627 .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1628 .collect();
1629 let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1630 let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1631 let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1632 let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1633 let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1634 let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1635 for i in 0..n {
1636 time_dq0[[i, 0]] = 1.0;
1637 time_dq0[[i, 1]] = x[i];
1638 time_dq0[[i, 2]] = x[i] * x[i];
1639 time_dq1[[i, 0]] = 1.0;
1640 time_dq1[[i, 1]] = x[i];
1641 time_dq1[[i, 2]] = x[i] * x[i];
1642 time_dqd1[[i, 0]] = 0.0;
1643 time_dqd1[[i, 1]] = 1.0;
1644 time_dqd1[[i, 2]] = 2.0 * x[i];
1645 marg_dq[[i, 0]] = 1.0; marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1647 marg_dq[[i, 2]] = x[i].sin();
1648 log_dg[[i, 0]] = (2.0 * x[i]).cos();
1649 log_dg[[i, 1]] = x[i].tanh();
1650 }
1651 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1652 for i in 0..n {
1653 for k in 0..K_SURVIVAL {
1654 h_full[[i, k, k]] = 1.0;
1655 }
1656 }
1657 let row_hess = SurvivalRowHessian::from_full(h_full);
1658 let out = compile_survival_parametric_designs(
1659 time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
1660 )
1661 .expect("Phase-4b parametric compile must succeed on single-direction alias");
1662 assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
1663 assert_eq!(
1664 out.v_marginal.ncols(),
1665 p_marginal - 1,
1666 "marginal loses exactly the shared-constant direction"
1667 );
1668 assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
1669 assert_eq!(
1670 out.drops_by_block,
1671 (0, 1, 0),
1672 "attribution: zero from time/logslope, one from marginal",
1673 );
1674 }
1675
1676 #[test]
1697 fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
1698 use gam_identifiability::families::compiler::compile;
1699
1700 let n = 32;
1701 let p_time = 3;
1702 let p_marginal = 3;
1703 let p_logslope = 2;
1704
1705 let x: Vec<f64> = (0..n)
1716 .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1717 .collect();
1718 let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1719 let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1720 let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1721 for i in 0..n {
1722 time_dq0[[i, 0]] = 1.0;
1723 time_dq0[[i, 1]] = x[i];
1724 time_dq0[[i, 2]] = x[i] * x[i];
1725 time_dq1[[i, 0]] = 1.0;
1726 time_dq1[[i, 1]] = x[i];
1727 time_dq1[[i, 2]] = x[i] * x[i];
1728 time_dqd1[[i, 0]] = 0.0;
1730 time_dqd1[[i, 1]] = 1.0;
1731 time_dqd1[[i, 2]] = 2.0 * x[i];
1732 }
1733
1734 let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1740 let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1741 for i in 0..n {
1742 marg_dq[[i, 0]] = 1.0;
1743 marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1744 marg_dq[[i, 2]] = x[i].sin();
1745 }
1746
1747 let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1751 for i in 0..n {
1752 log_dg[[i, 0]] = (2.0 * x[i]).cos();
1753 log_dg[[i, 1]] = x[i].tanh();
1754 }
1755
1756 let inputs = build_survival_compiler_inputs(
1757 time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
1758 );
1759
1760 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1766 for i in 0..n {
1767 for k in 0..K_SURVIVAL {
1768 h_full[[i, k, k]] = 1.0;
1769 }
1770 }
1771 let row_hess = SurvivalRowHessian::from_full(h_full);
1772
1773 let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
1774 .expect("survival 3-block compile must succeed; aliasing is single-direction");
1775
1776 assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");
1778
1779 let v_time = &compiled.blocks[0].t_lw;
1784 assert_eq!(
1785 v_time.ncols(),
1786 p_time,
1787 "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
1788 v_time.dim(),
1789 );
1790
1791 let v_marg = &compiled.blocks[1].t_lw;
1798 assert_eq!(
1799 v_marg.ncols(),
1800 p_marginal - 1,
1801 "marginal block must lose exactly the shared-constant direction; \
1802 V_marginal cols = {}, expected {}",
1803 v_marg.ncols(),
1804 p_marginal - 1,
1805 );
1806
1807 let v_log = &compiled.blocks[2].t_lw;
1810 assert_eq!(
1811 v_log.ncols(),
1812 p_logslope,
1813 "logslope block (no shared direction) must retain all {p_logslope} columns",
1814 );
1815
1816 let raw_total = p_time + p_marginal + p_logslope;
1819 let kept_total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
1820 assert_eq!(
1821 kept_total,
1822 raw_total - 1,
1823 "joint kept = raw_total − aliased; got {kept_total}, expected {}",
1824 raw_total - 1,
1825 );
1826 assert_eq!(
1827 compiled.joint_rank, kept_total,
1828 "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
1829 );
1830
1831 let v_per_term: Vec<Array2<f64>> = compiled.blocks.iter().map(|b| b.t_lw.clone()).collect();
1841 let r_per_term: Vec<Option<Array2<f64>>> = vec![None; v_per_term.len()];
1842 let gauge = Gauge::from_v_and_r(&v_per_term, &r_per_term);
1843
1844 let mut expected_reduced = vec![0usize];
1845 let mut expected_raw = vec![0usize];
1846 for b in &compiled.blocks {
1847 let prev_reduced = *expected_reduced.last().unwrap();
1848 expected_reduced.push(prev_reduced + b.t_lw.ncols());
1849 let prev_raw = *expected_raw.last().unwrap();
1850 expected_raw.push(prev_raw + b.t_lw.nrows());
1851 }
1852 assert_eq!(
1853 *gauge.block_starts_reduced.last().unwrap(),
1854 compiled.joint_rank,
1855 "SMGS lift reduced dimension must equal the compiled joint_rank",
1856 );
1857 assert_eq!(
1858 gauge.block_starts_reduced, expected_reduced,
1859 "SMGS lift reduced block boundaries must match the compiled kept widths",
1860 );
1861 assert_eq!(
1862 gauge.block_starts_raw, expected_raw,
1863 "SMGS lift raw block boundaries must match the compiled per-block raw widths",
1864 );
1865
1866 for (bi, block) in compiled.blocks.iter().enumerate() {
1871 for j in 0..block.t_lw.ncols() {
1872 let col = block.t_lw.column(j);
1873 assert!(
1874 col.iter().all(|v| v.is_finite()),
1875 "block {bi} kept direction {j} has a non-finite entry",
1876 );
1877 let norm = col.dot(&col).sqrt();
1878 assert!(
1879 norm > 1e-10,
1880 "block {bi} kept direction {j} is degenerate (norm {norm:.3e})",
1881 );
1882 }
1883 }
1884 }
1885
1886 #[test]
1889 fn smgs_lift_via_t_identity_passes_through() {
1890 let v0 = Array2::<f64>::eye(3);
1891 let v1 = Array2::<f64>::eye(2);
1892 let v_per_term = vec![v0, v1];
1893 let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
1894 let lift = Gauge::from_v_and_r(&v_per_term, &r_per_term);
1895 assert_eq!(lift.t_full.dim(), (5, 5));
1896 assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
1897 assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
1898 for i in 0..5 {
1899 for j in 0..5 {
1900 let want = if i == j { 1.0 } else { 0.0 };
1901 assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
1902 }
1903 }
1904 let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
1905 let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
1906 let lifted = lift.lift_block_betas(&[theta_0.clone(), theta_1.clone()]);
1907 assert_eq!(lifted.len(), 2);
1908 for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
1909 assert!((a - b).abs() < 1e-14);
1910 }
1911 for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
1912 assert!((a - b).abs() < 1e-14);
1913 }
1914 }
1915
1916 #[test]
1920 fn smgs_lift_via_t_two_block_with_residualisation() {
1921 let v_a = Array2::<f64>::eye(3);
1922 let mut v_b = Array2::<f64>::zeros((3, 2));
1923 v_b[[0, 0]] = 1.0;
1924 v_b[[2, 1]] = 1.0;
1925 let mut r_b = Array2::<f64>::zeros((3, 2));
1926 r_b[[0, 0]] = 0.4;
1927 r_b[[0, 1]] = -0.1;
1928 r_b[[1, 0]] = 0.7;
1929 r_b[[1, 1]] = 1.3;
1930 r_b[[2, 0]] = -0.2;
1931 r_b[[2, 1]] = 0.5;
1932 let lift = Gauge::from_v_and_r(&[v_a.clone(), v_b.clone()], &[None, Some(r_b.clone())]);
1933 assert_eq!(lift.t_full.dim(), (6, 5));
1934 assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
1935 assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);
1936
1937 let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
1938 let theta_b = Array1::from(vec![0.5_f64, -0.25]);
1939 let lifted = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
1940 let r_theta_b = r_b.dot(&theta_b);
1941 let expected_a = &theta_a - &r_theta_b;
1942 assert_eq!(lifted[0].len(), 3);
1943 for (got, want) in lifted[0].iter().zip(expected_a.iter()) {
1944 assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
1945 }
1946 assert_eq!(lifted[1].len(), 3);
1947 assert!((lifted[1][0] - theta_b[0]).abs() < 1e-12);
1948 assert!(lifted[1][1].abs() < 1e-12);
1949 assert!((lifted[1][2] - theta_b[1]).abs() < 1e-12);
1950 }
1951
1952 #[test]
1964 fn smgs_lift_covariance_identity_and_rank1_consistency() {
1965 let lift_id = Gauge::from_v_and_r(
1967 &[Array2::<f64>::eye(2), Array2::<f64>::eye(2)],
1968 &[None, None],
1969 );
1970 let mut cov = Array2::<f64>::zeros((4, 4));
1971 for i in 0..4 {
1973 for j in 0..4 {
1974 cov[[i, j]] = 1.0 / (1.0 + (i as f64 - j as f64).abs());
1975 }
1976 }
1977 let lifted_id = lift_id.lift_covariance(&cov);
1978 assert_eq!(lifted_id.dim(), (4, 4));
1979 for i in 0..4 {
1980 for j in 0..4 {
1981 assert!(
1982 (lifted_id[[i, j]] - cov[[i, j]]).abs() < 1e-12,
1983 "identity-T covariance lift must be a no-op at [{i},{j}]",
1984 );
1985 }
1986 }
1987
1988 let v_a = Array2::<f64>::eye(3);
1993 let mut v_b = Array2::<f64>::zeros((3, 2));
1994 v_b[[0, 0]] = 1.0;
1995 v_b[[2, 1]] = 1.0;
1996 let mut r_b = Array2::<f64>::zeros((3, 2));
1997 r_b[[0, 0]] = 0.4;
1998 r_b[[0, 1]] = -0.1;
1999 r_b[[1, 0]] = 0.7;
2000 r_b[[1, 1]] = 1.3;
2001 r_b[[2, 0]] = -0.2;
2002 r_b[[2, 1]] = 0.5;
2003 let lift = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_b)]);
2004
2005 let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2006 let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2007 let theta_full = Array1::from(vec![
2009 theta_a[0], theta_a[1], theta_a[2], theta_b[0], theta_b[1],
2010 ]);
2011 let mut cov_rank1 = Array2::<f64>::zeros((5, 5));
2013 for i in 0..5 {
2014 for j in 0..5 {
2015 cov_rank1[[i, j]] = theta_full[i] * theta_full[j];
2016 }
2017 }
2018 let lifted_cov = lift.lift_covariance(&cov_rank1);
2019 let lifted_blocks = lift.lift_block_betas(&[theta_a, theta_b]);
2021 let beta_raw = Array1::from(
2022 lifted_blocks
2023 .iter()
2024 .flat_map(|b| b.iter().copied())
2025 .collect::<Vec<f64>>(),
2026 );
2027 assert_eq!(lifted_cov.dim(), (6, 6));
2028 assert_eq!(beta_raw.len(), 6);
2029 for i in 0..6 {
2030 for j in 0..6 {
2031 let want = beta_raw[i] * beta_raw[j];
2032 assert!(
2033 (lifted_cov[[i, j]] - want).abs() < 1e-10,
2034 "rank-1 covariance pushforward must equal (Tθ)(Tθ)ᵀ at [{i},{j}]: got {}, want {want}",
2035 lifted_cov[[i, j]],
2036 );
2037 }
2038 }
2039 for i in 0..6 {
2041 for j in 0..6 {
2042 assert!((lifted_cov[[i, j]] - lifted_cov[[j, i]]).abs() < 1e-14);
2043 }
2044 }
2045 }
2046
2047 #[test]
2050 fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
2051 let mut v_a = Array2::<f64>::zeros((3, 2));
2052 v_a[[0, 0]] = 0.6;
2053 v_a[[1, 0]] = -0.8;
2054 v_a[[1, 1]] = 0.3;
2055 v_a[[2, 1]] = 0.9;
2056 let mut v_b = Array2::<f64>::zeros((4, 3));
2057 v_b[[0, 0]] = 1.0;
2058 v_b[[1, 1]] = -0.4;
2059 v_b[[2, 0]] = 0.2;
2060 v_b[[2, 2]] = 0.7;
2061 v_b[[3, 2]] = -1.1;
2062 let v_per_term = vec![v_a.clone(), v_b.clone()];
2063 let lift = Gauge::from_v_and_r(&v_per_term, &[None, None]);
2064 let theta_a = Array1::from(vec![0.3_f64, -1.4]);
2065 let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);
2066 let via_t = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2067 let ref_a = v_a.dot(&theta_a);
2068 let ref_b = v_b.dot(&theta_b);
2069 assert_eq!(via_t[0].len(), ref_a.len());
2070 for (g, w) in via_t[0].iter().zip(ref_a.iter()) {
2071 assert!((g - w).abs() < 1e-12);
2072 }
2073 assert_eq!(via_t[1].len(), ref_b.len());
2074 for (g, w) in via_t[1].iter().zip(ref_b.iter()) {
2075 assert!((g - w).abs() < 1e-12);
2076 }
2077 }
2078
2079 #[test]
2089 fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
2090 let n = 6usize;
2091 let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
2095 let time_dq1 = Array2::<f64>::zeros((n, 1));
2096 let time_dqd1 = Array2::<f64>::zeros((n, 1));
2097 let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
2102 let marg_dqd1 = Array2::<f64>::zeros((n, 1));
2103 let log_dg = Array2::<f64>::zeros((n, 0));
2105 let mut time_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2106 time_partition.push(0..1);
2107 let mut marg_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2108 marg_partition.push(0..1);
2109 let log_partition: Vec<std::ops::Range<usize>> = Vec::new();
2110
2111 let mut h_ident = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2115 for i in 0..n {
2116 for k in 0..K_SURVIVAL {
2117 h_ident[[i, k, k]] = 1.0;
2118 }
2119 }
2120 let row_hess_ident = SurvivalRowHessian::from_full(h_ident);
2121 let compiled_ident = compile_survival_parametric_designs_per_term(
2122 time_dq0.clone(),
2123 time_dq1.clone(),
2124 time_dqd1.clone(),
2125 &time_partition,
2126 marg_dq.clone(),
2127 marg_dqd1.clone(),
2128 &marg_partition,
2129 log_dg.clone(),
2130 &log_partition,
2131 &row_hess_ident,
2132 )
2133 .expect("identity-H compile must succeed");
2134
2135 let mut h_q0_only = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2139 for i in 0..n {
2140 h_q0_only[[i, 0, 0]] = 1.0;
2141 }
2142 let row_hess_q0 = SurvivalRowHessian::from_full(h_q0_only);
2143 let compiled_q0 = compile_survival_parametric_designs_per_term(
2144 time_dq0,
2145 time_dq1,
2146 time_dqd1,
2147 &time_partition,
2148 marg_dq,
2149 marg_dqd1,
2150 &marg_partition,
2151 log_dg,
2152 &log_partition,
2153 &row_hess_q0,
2154 )
2155 .expect("q0-only-H compile must succeed");
2156
2157 assert_ne!(
2161 compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2162 "structural-H and data-adaptive-H compiles must produce different \
2163 drops_by_block on the constructed pilot-curvature-trap design; \
2164 identity={:?} q0-only={:?}",
2165 compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2166 );
2167 assert_eq!(
2169 compiled_ident.drops_by_block.1, 0,
2170 "identity-H marg drops expected 0, got {:?}",
2171 compiled_ident.drops_by_block,
2172 );
2173 assert_eq!(
2175 compiled_q0.drops_by_block.1, 1,
2176 "q0-only-H marg drops expected 1, got {:?}",
2177 compiled_q0.drops_by_block,
2178 );
2179 }
2180
2181 #[test]
2182 fn compiled_map_from_per_term_partitions_and_lift_round_trip() {
2183 let v_time = Array2::<f64>::eye(2);
2187 let mut v_marg = Array2::<f64>::zeros((2, 1));
2188 v_marg[[0, 0]] = 1.0;
2189 v_marg[[1, 0]] = 0.5;
2190 let v_log = Array2::<f64>::eye(1);
2191 let r_marg = Array2::<f64>::from_shape_fn((2, 1), |(i, _)| 0.25 + i as f64);
2194 let r_log = Array2::<f64>::from_shape_fn((4, 1), |(i, _)| 0.1 * (i as f64 + 1.0));
2195 let per_term = SurvivalParametricCompiledPerTerm {
2196 v_time_per_term: vec![v_time.clone()],
2197 v_marginal_per_term: vec![v_marg.clone()],
2198 v_logslope_per_term: vec![v_log.clone()],
2199 r_lw_per_term: vec![None, Some(r_marg.clone()), Some(r_log.clone())],
2200 drops_by_block: (0, 1, 0),
2201 };
2202
2203 let map = compiled_map_from_per_term(&per_term);
2204
2205 assert_eq!(map.raw_block_ranges, vec![0..2, 2..4, 4..5]);
2207 assert_eq!(map.compiled_block_ranges, vec![0..2, 2..3, 3..4]);
2209 assert_eq!(map.raw_from_compiled.dim(), (5, 4));
2210
2211 let v_time_slice = map
2214 .raw_from_compiled
2215 .slice(ndarray::s![0..2, 0..2])
2216 .to_owned();
2217 let v_marg_slice = map
2218 .raw_from_compiled
2219 .slice(ndarray::s![2..4, 2..3])
2220 .to_owned();
2221 let v_log_slice = map
2222 .raw_from_compiled
2223 .slice(ndarray::s![4..5, 3..4])
2224 .to_owned();
2225 for i in 0..2 {
2226 for j in 0..2 {
2227 assert!((v_time_slice[[i, j]] - v_time[[i, j]]).abs() < 1e-14);
2228 }
2229 assert!((v_marg_slice[[i, 0]] - v_marg[[i, 0]]).abs() < 1e-14);
2230 }
2231 assert!((v_log_slice[[0, 0]] - v_log[[0, 0]]).abs() < 1e-14);
2232
2233 let ordering = [
2236 gam_identifiability::families::compiler::BlockOrder::Time,
2237 gam_identifiability::families::compiler::BlockOrder::Marginal,
2238 gam_identifiability::families::compiler::BlockOrder::Logslope,
2239 ];
2240 let lift_from_map = Gauge::from_compiled_map(&map, &ordering);
2241 let v_all = vec![v_time, v_marg, v_log];
2242 let lift_direct = Gauge::from_v_and_r(&v_all, &[None, Some(r_marg), Some(r_log)]);
2243 assert_eq!(lift_from_map.t_full.dim(), lift_direct.t_full.dim());
2244 for i in 0..lift_from_map.t_full.nrows() {
2245 for j in 0..lift_from_map.t_full.ncols() {
2246 assert!(
2247 (lift_from_map.t_full[[i, j]] - lift_direct.t_full[[i, j]]).abs() < 1e-14,
2248 "T mismatch at ({i},{j}): map={} direct={}",
2249 lift_from_map.t_full[[i, j]],
2250 lift_direct.t_full[[i, j]],
2251 );
2252 }
2253 }
2254 }
2255}