1use std::sync::Arc;
30
31use ndarray::{Array1, Array2, Array3};
32
33use gam_identifiability::families::compiler::{
34 BlockOrder, RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with,
35};
36use gam_problem::gauge::assemble_block_triangular_t;
37use faer::Side;
38use gam_linalg::faer_ndarray::FaerEigh;
39use gam_linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
40use gam_problem::{FamilyChannelHessian, PenaltyMatrix};
41
/// Number of per-row primary channels handled by the survival row Hessian and
/// the block Jacobian operators in this file (channels 0, 1, 2 and 3).
const K_SURVIVAL: usize = 4;

/// Magnitude above which any single coefficient makes `beta` count as
/// "non-trivial" (i.e. away from the pilot point) in `channel_hessian_at`.
const BETA_NONTRIVIAL_ABS_THRESHOLD: f64 = 1.0e-12;
50
51pub struct SurvivalRowHessian {
58 h: Array3<f64>,
61}
62
63impl SurvivalRowHessian {
64 pub fn from_pilot_primary_state(
68 q0: &Array1<f64>,
69 q1: &Array1<f64>,
70 qd1: &Array1<f64>,
71 g: &Array1<f64>,
72 z: &Array1<f64>,
73 weights: &Array1<f64>,
74 event: &Array1<f64>,
75 derivative_guard: f64,
76 probit_scale: f64,
77 ) -> Result<Self, String> {
78 let n = q0.len();
79 if [
80 q1.len(),
81 qd1.len(),
82 g.len(),
83 z.len(),
84 weights.len(),
85 event.len(),
86 ]
87 .iter()
88 .any(|&l| l != n)
89 {
90 return Err(format!(
91 "SurvivalRowHessian: length mismatch \
92 q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
93 q1.len(),
94 qd1.len(),
95 g.len(),
96 z.len(),
97 weights.len(),
98 event.len()
99 ));
100 }
101 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
102 for i in 0..n {
103 let (_, _grad, hess) =
104 crate::survival::marginal_slope::row_primary_for_compiler(
105 q0[i],
106 q1[i],
107 qd1[i],
108 g[i],
109 z[i],
110 weights[i],
111 event[i],
112 derivative_guard,
113 probit_scale,
114 )?;
115 let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
117 for a in 0..K_SURVIVAL {
118 for b in 0..K_SURVIVAL {
119 h_i[[a, b]] = hess[a][b];
120 }
121 }
122 let clamped = psd_clamp_4x4(&h_i);
123 for a in 0..K_SURVIVAL {
124 for b in 0..K_SURVIVAL {
125 h_full[[i, a, b]] = clamped[[a, b]];
126 }
127 }
128 }
129 Ok(Self { h: h_full })
130 }
131
132 pub fn from_full(h: Array3<f64>) -> Self {
135 assert_eq!(h.shape()[1], K_SURVIVAL);
136 assert_eq!(h.shape()[2], K_SURVIVAL);
137 Self { h }
138 }
139}
140
141impl RowHessian for SurvivalRowHessian {
142 fn k(&self) -> usize {
143 K_SURVIVAL
144 }
145 fn nrows(&self) -> usize {
146 self.h.shape()[0]
147 }
148 fn fill_row(&self, row: usize, out: &mut [f64]) {
149 assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
150 for a in 0..K_SURVIVAL {
151 for b in 0..K_SURVIVAL {
152 out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
153 }
154 }
155 }
156 fn evaluate_full(&self) -> Array3<f64> {
157 self.h.clone()
158 }
159}
160
161impl FamilyChannelHessian for SurvivalRowHessian {
197 fn n_outputs(&self) -> usize {
198 K_SURVIVAL
199 }
200
201 fn n_subjects(&self) -> usize {
202 self.h.shape()[0]
203 }
204
205 fn fill_subject(&self, i: usize, out: &mut [f64]) {
206 assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
207 for a in 0..K_SURVIVAL {
208 for b in 0..K_SURVIVAL {
209 out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
210 }
211 }
212 }
213
214 fn evaluate_full(&self) -> ndarray::Array3<f64> {
215 self.h.clone()
216 }
217
218 fn channel_hessian_at(
219 &self,
220 beta: &[f64],
221 family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
222 ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
223 use crate::survival::marginal_slope::SurvivalMarginalSlopeFamilyScalars;
224
225 let scalars_opt =
226 family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
227
228 let beta_nontrivial = beta
230 .iter()
231 .any(|&b| b.abs() > BETA_NONTRIVIAL_ABS_THRESHOLD);
232
233 match scalars_opt {
234 None if beta_nontrivial => {
235 Err(
238 "SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
239 family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
240 via FamilyLinearizationState::family_scalars to evaluate W(β) \
241 correctly (same contract as T26 Jacobian callbacks)."
242 .to_string(),
243 )
244 }
245 None => {
246 Ok(Arc::new(gam_problem::TensorChannelHessian {
248 h: self.h.clone(),
249 }))
250 }
251 Some(sc) => {
252 let n = self.h.shape()[0];
253 if sc.q0_i.len() != n
254 || sc.q1_i.len() != n
255 || sc.qd1_i.len() != n
256 || sc.g_i.len() != n
257 || sc.z_i.len() != n
258 {
259 return Err(format!(
260 "SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
261 (expected n={n}, got q0={} q1={} qd1={} g={} z={})",
262 sc.q0_i.len(),
263 sc.q1_i.len(),
264 sc.qd1_i.len(),
265 sc.g_i.len(),
266 sc.z_i.len(),
267 ));
268 }
269 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
286 for i in 0..n {
287 let q0 = sc.q0_i[i];
288 let q1 = sc.q1_i[i];
289 let qd1 = sc.qd1_i[i];
290 let g = sc.g_i[i];
291 let z = sc.z_i[i];
292 match crate::survival::marginal_slope::row_primary_for_compiler(
295 q0, q1, qd1, g, z, 1.0, 1.0, crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
298 sc.s, ) {
300 Ok((_nll, _grad, hess)) => {
301 let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
302 for a in 0..K_SURVIVAL {
303 for b in 0..K_SURVIVAL {
304 h_i[[a, b]] = hess[a][b];
305 }
306 }
307 let clamped = psd_clamp_4x4(&h_i);
308 for a in 0..K_SURVIVAL {
309 for b in 0..K_SURVIVAL {
310 h_full[[i, a, b]] = clamped[[a, b]];
311 }
312 }
313 }
314 Err(_) => {
315 for a in 0..K_SURVIVAL {
318 for b in 0..K_SURVIVAL {
319 h_full[[i, a, b]] = self.h[[i, a, b]];
320 }
321 }
322 }
323 }
324 }
325 Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
326 }
327 }
328 }
329}
330
331fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
336 let k = m.nrows();
337 let (evals, evecs) = match m.eigh(Side::Lower) {
338 Ok(pair) => pair,
339 Err(_) => {
340 let mut out = Array2::<f64>::zeros((k, k));
341 for i in 0..k {
342 out[[i, i]] = m[[i, i]].max(0.0);
343 }
344 return out;
345 }
346 };
347 let mut out = Array2::<f64>::zeros((k, k));
348 for i in 0..k {
349 for j in 0..k {
350 let mut acc = 0.0;
351 for l in 0..k {
352 acc += evecs[[i, l]] * evals[l].max(0.0) * evecs[[j, l]];
353 }
354 out[[i, j]] = acc;
355 }
356 }
357 out
358}
359
360pub struct TimeBlockOperator {
363 dq0: Array2<f64>,
364 dq1: Array2<f64>,
365 dqd1: Array2<f64>,
366}
367
368impl TimeBlockOperator {
369 pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
370 assert_eq!(dq0.dim(), dq1.dim());
371 assert_eq!(dq0.dim(), dqd1.dim());
372 Self { dq0, dq1, dqd1 }
373 }
374}
375
376impl RowJacobianOperator for TimeBlockOperator {
377 fn k(&self) -> usize {
378 K_SURVIVAL
379 }
380 fn ncols(&self) -> usize {
381 self.dq0.ncols()
382 }
383 fn nrows(&self) -> usize {
384 self.dq0.nrows()
385 }
386 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
387 assert_eq!(out.len(), K_SURVIVAL);
388 assert_eq!(delta_beta.len(), self.dq0.ncols());
389 let mut acc = [0.0_f64; K_SURVIVAL];
390 for (j, &b) in delta_beta.iter().enumerate() {
391 acc[0] += self.dq0[[row, j]] * b;
392 acc[1] += self.dq1[[row, j]] * b;
393 acc[2] += self.dqd1[[row, j]] * b;
394 }
395 out.copy_from_slice(&acc);
396 }
397 fn evaluate_full(&self) -> Array3<f64> {
398 let n = self.dq0.nrows();
399 let p = self.dq0.ncols();
400 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
401 for i in 0..n {
402 for j in 0..p {
403 out[[i, j, 0]] = self.dq0[[i, j]];
404 out[[i, j, 1]] = self.dq1[[i, j]];
405 out[[i, j, 2]] = self.dqd1[[i, j]];
406 }
407 }
408 out
409 }
410 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
411 let n = self.dq0.nrows();
417 let p = self.dq0.ncols();
418 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
419 0 => self.dq0[[i, a]],
420 1 => self.dq1[[i, a]],
421 2 => self.dqd1[[i, a]],
422 _ => 0.0,
423 })
424 }
425}
426
427pub struct QChannelBlockOperator {
432 dq: Array2<f64>,
433 dqd1: Array2<f64>,
434}
435
436impl QChannelBlockOperator {
437 pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
438 assert_eq!(dq.dim(), dqd1.dim());
439 Self { dq, dqd1 }
440 }
441}
442
443impl RowJacobianOperator for QChannelBlockOperator {
444 fn k(&self) -> usize {
445 K_SURVIVAL
446 }
447 fn ncols(&self) -> usize {
448 self.dq.ncols()
449 }
450 fn nrows(&self) -> usize {
451 self.dq.nrows()
452 }
453 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
454 assert_eq!(out.len(), K_SURVIVAL);
455 assert_eq!(delta_beta.len(), self.dq.ncols());
456 let mut dq_acc = 0.0;
457 let mut dqd_acc = 0.0;
458 for (j, &b) in delta_beta.iter().enumerate() {
459 dq_acc += self.dq[[row, j]] * b;
460 dqd_acc += self.dqd1[[row, j]] * b;
461 }
462 out[0] = dq_acc;
463 out[1] = dq_acc;
464 out[2] = dqd_acc;
465 out[3] = 0.0;
466 }
467 fn evaluate_full(&self) -> Array3<f64> {
468 let n = self.dq.nrows();
469 let p = self.dq.ncols();
470 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
471 for i in 0..n {
472 for j in 0..p {
473 let v = self.dq[[i, j]];
474 out[[i, j, 0]] = v;
475 out[[i, j, 1]] = v;
476 out[[i, j, 2]] = self.dqd1[[i, j]];
477 }
478 }
479 out
480 }
481 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
482 let n = self.dq.nrows();
486 let p = self.dq.ncols();
487 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
488 0 | 1 => self.dq[[i, a]],
489 2 => self.dqd1[[i, a]],
490 _ => 0.0,
491 })
492 }
493}
494
495pub struct LogslopeBlockOperator {
498 dg: Array2<f64>,
499}
500
501impl LogslopeBlockOperator {
502 pub fn new(dg: Array2<f64>) -> Self {
503 Self { dg }
504 }
505}
506
507impl RowJacobianOperator for LogslopeBlockOperator {
508 fn k(&self) -> usize {
509 K_SURVIVAL
510 }
511 fn ncols(&self) -> usize {
512 self.dg.ncols()
513 }
514 fn nrows(&self) -> usize {
515 self.dg.nrows()
516 }
517 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
518 assert_eq!(out.len(), K_SURVIVAL);
519 assert_eq!(delta_beta.len(), self.dg.ncols());
520 let mut acc = 0.0;
521 for (j, &b) in delta_beta.iter().enumerate() {
522 acc += self.dg[[row, j]] * b;
523 }
524 out[0] = 0.0;
525 out[1] = 0.0;
526 out[2] = 0.0;
527 out[3] = acc;
528 }
529 fn evaluate_full(&self) -> Array3<f64> {
530 let n = self.dg.nrows();
531 let p = self.dg.ncols();
532 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
533 for i in 0..n {
534 for j in 0..p {
535 out[[i, j, 3]] = self.dg[[i, j]];
536 }
537 }
538 out
539 }
540 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
541 let n = self.dg.nrows();
546 let p = self.dg.ncols();
547 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| {
548 if c == 3 { self.dg[[i, a]] } else { 0.0 }
549 })
550 }
551}
552
553pub struct SurvivalCompilerInputs {
557 pub operators: Vec<Arc<dyn RowJacobianOperator>>,
558 pub ordering: Vec<BlockOrder>,
559}
560
561pub struct SurvivalParametricCompiled {
577 pub v_time: Array2<f64>,
578 pub v_marginal: Array2<f64>,
579 pub v_logslope: Array2<f64>,
580 pub drops_by_block: (usize, usize, usize),
585}
586
587fn wrap_design_with_transform(
588 raw: DesignMatrix,
589 v: &Array2<f64>,
590 context: &str,
591) -> Result<DesignMatrix, String> {
592 if raw.ncols() != v.nrows() {
593 return Err(format!(
594 "{context}: raw design has {} cols but V has {} rows (V is {}×{})",
595 raw.ncols(),
596 v.nrows(),
597 v.nrows(),
598 v.ncols(),
599 ));
600 }
601 let inner_dense = match raw {
602 DesignMatrix::Dense(d) => d,
603 DesignMatrix::Sparse(_) => {
604 let dense = raw
605 .try_to_dense_by_chunks(&format!("{context} sparse→dense for V apply"))
606 .map_err(|reason| format!("{context}: densify failed: {reason}"))?;
607 DenseDesignMatrix::from(dense)
608 }
609 };
610 let op = CoefficientTransformOperator::new(inner_dense, v.clone())
611 .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))?;
612 Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(op))))
613}
614
615pub struct SurvivalParametricCompiledPerTerm {
623 pub v_time_per_term: Vec<Array2<f64>>,
624 pub v_marginal_per_term: Vec<Array2<f64>>,
625 pub v_logslope_per_term: Vec<Array2<f64>>,
626 pub r_lw_per_term: Vec<Option<Array2<f64>>>,
633 pub drops_by_block: (usize, usize, usize),
635}
636
637pub fn compile_survival_parametric_designs_per_term(
656 time_dq0: Array2<f64>,
657 time_dq1: Array2<f64>,
658 time_dqd1: Array2<f64>,
659 time_partition: &[std::ops::Range<usize>],
660 marginal_dq: Array2<f64>,
661 marginal_dqd1: Array2<f64>,
662 marginal_partition: &[std::ops::Range<usize>],
663 logslope_dg: Array2<f64>,
664 logslope_partition: &[std::ops::Range<usize>],
665 row_hess: &dyn RowHessian,
666 protect_time: bool,
667) -> Result<SurvivalParametricCompiledPerTerm, String> {
668 use gam_identifiability::families::compiler::compile_protected;
669
670 let p_time = time_dq0.ncols();
671 let p_marg = marginal_dq.ncols();
672 let p_log = logslope_dg.ncols();
673 validate_partition(time_partition, p_time, "time")?;
674 validate_partition(marginal_partition, p_marg, "marginal")?;
675 validate_partition(logslope_partition, p_log, "logslope")?;
676
677 let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::new();
681 let mut ordering: Vec<BlockOrder> = Vec::new();
682 for range in time_partition {
683 let dq0 = time_dq0.slice(ndarray::s![.., range.clone()]).to_owned();
684 let dq1 = time_dq1.slice(ndarray::s![.., range.clone()]).to_owned();
685 let dqd1 = time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned();
686 operators.push(Arc::new(TimeBlockOperator::new(dq0, dq1, dqd1)));
687 ordering.push(BlockOrder::Time);
688 }
689 for range in marginal_partition {
690 let dq = marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned();
691 let dqd1 = marginal_dqd1
692 .slice(ndarray::s![.., range.clone()])
693 .to_owned();
694 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
695 ordering.push(BlockOrder::Marginal);
696 }
697 for range in logslope_partition {
698 let dg = logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned();
699 operators.push(Arc::new(LogslopeBlockOperator::new(dg)));
700 ordering.push(BlockOrder::Logslope);
701 }
702
703 let n_time = time_partition.len();
712 let protected: Vec<bool> = if protect_time {
713 (0..operators.len()).map(|i| i < n_time).collect()
714 } else {
715 Vec::new()
716 };
717 let compiled =
718 compile_protected(&operators, row_hess, &ordering, &protected).map_err(|e| {
719 format!("identifiability::families::compiler::compile (per-term) failed: {e}")
720 })?;
721 let blocks = compiled.blocks;
722 let n_marg = marginal_partition.len();
723 let n_log = logslope_partition.len();
724 if blocks.len() != n_time + n_marg + n_log {
725 return Err(format!(
726 "per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
727 n_time + n_marg + n_log,
728 n_time,
729 n_marg,
730 n_log,
731 blocks.len(),
732 ));
733 }
734 let mut iter = blocks.into_iter();
735 let mut v_time_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_time);
736 let mut r_time_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time);
737 for _ in 0..n_time {
738 let blk = iter.next().unwrap();
739 v_time_per_term.push(blk.t_lw);
740 r_time_per_term.push(blk.r_lw);
741 }
742 let mut v_marginal_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_marg);
743 let mut r_marginal_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_marg);
744 for _ in 0..n_marg {
745 let blk = iter.next().unwrap();
746 v_marginal_per_term.push(blk.t_lw);
747 r_marginal_per_term.push(blk.r_lw);
748 }
749 let mut v_logslope_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_log);
750 let mut r_logslope_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_log);
751 for _ in 0..n_log {
752 let blk = iter.next().unwrap();
753 v_logslope_per_term.push(blk.t_lw);
754 r_logslope_per_term.push(blk.r_lw);
755 }
756 let mut r_lw_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time + n_marg + n_log);
757 r_lw_per_term.extend(r_time_per_term);
758 r_lw_per_term.extend(r_marginal_per_term);
759 r_lw_per_term.extend(r_logslope_per_term);
760 let drops_time: usize = time_partition
761 .iter()
762 .zip(v_time_per_term.iter())
763 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
764 .sum();
765 let drops_marg: usize = marginal_partition
766 .iter()
767 .zip(v_marginal_per_term.iter())
768 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
769 .sum();
770 let drops_log: usize = logslope_partition
771 .iter()
772 .zip(v_logslope_per_term.iter())
773 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
774 .sum();
775 Ok(SurvivalParametricCompiledPerTerm {
776 v_time_per_term,
777 v_marginal_per_term,
778 v_logslope_per_term,
779 r_lw_per_term,
780 drops_by_block: (drops_time, drops_marg, drops_log),
781 })
782}
783
/// Checks that `partition` tiles `[0, p_block)` with contiguous, non-empty
/// ranges. An empty partition is accepted only when `p_block == 0`.
///
/// Checks run in a fixed order, and the first failure is reported:
/// 1. emptiness;
/// 2. start at 0;
/// 3. end at `p_block`;
/// 4. per adjacent pair, gap/overlap then emptiness of the left range;
/// 5. emptiness of the final range.
///
/// `label` prefixes every error message.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let (first, last) = match (partition.first(), partition.last()) {
        (Some(first), Some(last)) => (first, last),
        _ if p_block == 0 => return Ok(()),
        _ => {
            return Err(format!(
                "{label} partition empty but block has p={p_block} columns"
            ))
        }
    };
    if first.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            first.start
        ));
    }
    if last.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            last.end
        ));
    }
    for pair in partition.windows(2) {
        let (left, right) = (&pair[0], &pair[1]);
        if left.end != right.start {
            return Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                left.start, left.end, right.start, right.end
            ));
        }
        if left.is_empty() {
            return Err(format!(
                "{label} partition has empty range [{}..{})",
                left.start, left.end
            ));
        }
    }
    if last.is_empty() {
        return Err(format!("{label} partition's final range is empty",));
    }
    Ok(())
}
828
/// Splits `[0, p_block)` at every penalty-range boundary into contiguous,
/// non-empty ranges.
///
/// Boundaries are the endpoints 0 and `p_block` plus each penalty range's
/// `start` and `end`, each clamped to `p_block`. Duplicate boundaries collapse,
/// so no empty range is emitted. `p_block == 0` yields an empty partition.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    let cut_points: std::collections::BTreeSet<usize> = [0, p_block]
        .iter()
        .copied()
        .chain(
            penalty_ranges
                .iter()
                .flat_map(|range| [range.start.min(p_block), range.end.min(p_block)]),
        )
        .collect();
    // A BTreeSet iterates in strictly increasing order, so every adjacent pair
    // forms a non-empty range.
    let ordered: Vec<usize> = cut_points.into_iter().collect();
    ordered.windows(2).map(|pair| pair[0]..pair[1]).collect()
}
851
852pub fn pull_back_blockwise_penalty_through_block_v(
875 pen: &gam_terms::smooth::BlockwisePenalty,
876 v_block: &Array2<f64>,
877) -> Result<PenaltyMatrix, String> {
878 let raw_p = v_block.nrows();
879 let compiled_p = v_block.ncols();
880 let block_p = pen.col_range.len();
881 let embed_start = pen.col_range.start;
882 let embed_end = pen.col_range.end;
883 if embed_end > raw_p {
884 return Err(format!(
885 "pull_back_blockwise_penalty_through_block_v: penalty col_range {embed_start}..{embed_end} \
886 exceeds block raw width {raw_p}"
887 ));
888 }
889 if pen.local.nrows() != block_p || pen.local.ncols() != block_p {
890 return Err(format!(
891 "pull_back_blockwise_penalty_through_block_v: penalty local is {}x{} but col_range \
892 width is {block_p}",
893 pen.local.nrows(),
894 pen.local.ncols(),
895 ));
896 }
897 let mut embedded = Array2::<f64>::zeros((raw_p, raw_p));
898 if block_p > 0 {
899 let mut dst =
900 embedded.slice_mut(ndarray::s![embed_start..embed_end, embed_start..embed_end]);
901 for i in 0..block_p {
902 for j in 0..block_p {
903 dst[[i, j]] = pen.local[[i, j]];
904 }
905 }
906 }
907 let temp = embedded.dot(v_block);
909 let pulled = v_block.t().dot(&temp);
910 let mut sym = Array2::<f64>::zeros((compiled_p, compiled_p));
911 for i in 0..compiled_p {
912 for j in 0..compiled_p {
913 sym[[i, j]] = 0.5 * (pulled[[i, j]] + pulled[[j, i]]);
914 }
915 }
916 Ok(PenaltyMatrix::Dense(sym))
917}
918
919pub fn compiled_map_from_per_term(
941 compiled: &SurvivalParametricCompiledPerTerm,
942) -> gam_identifiability::families::compiler::CompiledMap {
943 let mut v_all: Vec<Array2<f64>> = Vec::new();
946 v_all.extend(compiled.v_time_per_term.iter().cloned());
947 v_all.extend(compiled.v_marginal_per_term.iter().cloned());
948 v_all.extend(compiled.v_logslope_per_term.iter().cloned());
949
950 let t_full = assemble_block_triangular_t(&v_all, &compiled.r_lw_per_term);
951
952 let raw_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.nrows()).sum() };
954 let kept_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.ncols()).sum() };
955 let raw_time = raw_w(&compiled.v_time_per_term);
956 let raw_marg = raw_w(&compiled.v_marginal_per_term);
957 let raw_log = raw_w(&compiled.v_logslope_per_term);
958 let kept_time = kept_w(&compiled.v_time_per_term);
959 let kept_marg = kept_w(&compiled.v_marginal_per_term);
960 let kept_log = kept_w(&compiled.v_logslope_per_term);
961
962 let raw_block_ranges = vec![
963 0..raw_time,
964 raw_time..(raw_time + raw_marg),
965 (raw_time + raw_marg)..(raw_time + raw_marg + raw_log),
966 ];
967 let compiled_block_ranges = vec![
968 0..kept_time,
969 kept_time..(kept_time + kept_marg),
970 (kept_time + kept_marg)..(kept_time + kept_marg + kept_log),
971 ];
972
973 gam_identifiability::families::compiler::CompiledMap {
974 raw_from_compiled: t_full,
975 compiled_block_ranges,
976 raw_block_ranges,
977 }
978}
979
/// Computes an optional reduced basis for the logslope block.
///
/// The basis comes from the logslope Gram after eliminating the marginal
/// block (a Schur complement).
///
/// Per-row weights come from `row_hess.h`:
/// * `w_mm = H00 + H11`
/// * `w_mg = H03 + H13`
/// * `w_gg = H33`
///
/// With M = `marginal_dq` and G = `logslope_dg`, the function forms:
/// * `C` from G and `w_gg`;
/// * `A` from M and `w_mm`, plus a diagonal ridge;
/// * `B` from M, `w_mg` and G.
///
/// It then takes the symmetrised `S = C − Bᵀ A⁻¹ B`. (This assumes the
/// `fast_xt_diag_x` / `fast_xt_diag_y` / `fast_atb` helpers compute
/// `Xᵀ diag(w) X`, `Xᵀ diag(w) Y` and `Aᵀ B`, as their names suggest —
/// TODO confirm.)
///
/// Eigenvectors of `S` with eigenvalue above
/// `energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL` become the columns of
/// the returned `p_log x r` transform, in decreasing eigenvalue order.
/// `energy_scale` is the largest diagonal entry of `C`.
///
/// Returns `Ok(None)` (no reduction) when:
/// * either block has zero columns;
/// * `energy_scale` is non-finite or non-positive;
/// * every direction is kept;
/// * no direction is kept.
///
/// # Errors
/// Returns an error if:
/// * the row counts disagree;
/// * a row weight is non-finite;
/// * the factorization of `A` fails;
/// * `S` has non-finite entries;
/// * the eigendecomposition fails;
/// * the transform has non-finite entries.
pub fn survival_reduced_logslope_transform_effective(
    marginal_dq: ndarray::ArrayView2<'_, f64>,
    logslope_dg: ndarray::ArrayView2<'_, f64>,
    row_hess: &SurvivalRowHessian,
) -> Result<Option<Array2<f64>>, String> {
    use crate::bms::block_specs::LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
    use gam_linalg::faer_ndarray::{
        FaerArrayView, factorize_symmetricwith_fallback, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
    };

    let n = marginal_dq.nrows();
    let p_m = marginal_dq.ncols();
    let p_log = logslope_dg.ncols();
    if p_m == 0 || p_log == 0 {
        return Ok(None);
    }
    if logslope_dg.nrows() != n || row_hess.h.shape()[0] != n {
        return Err(format!(
            "survival reduced logslope: row mismatch marginal={n}, logslope={}, row_hess={}",
            logslope_dg.nrows(),
            row_hess.h.shape()[0],
        ));
    }

    // Per-row scalar weights projected from the 4x4 channel Hessian.
    // NOTE(review): `QChannelBlockOperator` feeds the marginal `dq` into both
    // channels 0 and 1, so the exact weight would be
    // (e0+e1)ᵀ H (e0+e1) = H00 + H01 + H10 + H11. `w_mm` omits the H01/H10
    // cross terms, while `w_mg` does sum both q-channels. Confirm this is
    // intended. Channel 2 (`dqd1`) is also not used here.
    let mut w_mm = Array1::<f64>::zeros(n);
    let mut w_mg = Array1::<f64>::zeros(n);
    let mut w_gg = Array1::<f64>::zeros(n);
    for i in 0..n {
        w_mm[i] = row_hess.h[[i, 0, 0]] + row_hess.h[[i, 1, 1]];
        w_mg[i] = row_hess.h[[i, 0, 3]] + row_hess.h[[i, 1, 3]];
        w_gg[i] = row_hess.h[[i, 3, 3]];
        if !(w_mm[i].is_finite() && w_mg[i].is_finite() && w_gg[i].is_finite()) {
            return Err("survival reduced logslope: non-finite row Hessian weight".to_string());
        }
    }

    let marg = marginal_dq.to_owned();
    let log = logslope_dg.to_owned();

    // Logslope Gram C. Its largest diagonal entry sets the scale of the
    // eigenvalue cut-off below.
    let c_gram = fast_xt_diag_x(&log, &w_gg);
    let energy_scale = (0..p_log).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
    if !energy_scale.is_finite() || energy_scale <= 0.0 {
        return Ok(None);
    }

    // Marginal Gram A, ridged by max(scale * tol, EPSILON) so the solve is
    // well posed even when A is singular.
    let mut a_gram = fast_xt_diag_x(&marg, &w_mm);
    let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
    let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
    for i in 0..p_m {
        a_gram[[i, i]] += a_ridge;
    }

    // Cross Gram B, then the Schur complement S = C - Bᵀ (A⁻¹ B), symmetrised.
    let b_cross = fast_xt_diag_y(&marg, &w_mg, &log);
    let a_view = FaerArrayView::new(&a_gram);
    let a_factor = factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower).map_err(|e| {
        format!("survival reduced logslope: marginal effective Gram factorization failed: {e}")
    })?;
    let b_view = FaerArrayView::new(&b_cross);
    let solved = a_factor.solve(b_view.as_ref());
    let a_inv_b = Array2::from_shape_fn((p_m, p_log), |(i, j)| solved[(i, j)]);
    let schur = fast_atb(&b_cross, &a_inv_b);
    let mut stt = &c_gram - &schur;
    stt = (&stt + &stt.t()) * 0.5;
    if stt.iter().any(|v| !v.is_finite()) {
        return Err(
            "survival reduced logslope: effective Schur Gram produced non-finite entries"
                .to_string(),
        );
    }

    // Keep eigen-directions above the relative cut-off, largest first.
    let (evals, evecs) = stt
        .eigh(Side::Lower)
        .map_err(|e| format!("survival reduced logslope: eigendecomposition failed: {e:?}"))?;
    let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
    let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
    kept.sort_by(|&a, &b| {
        evals[b]
            .partial_cmp(&evals[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let r = kept.len();
    // Full rank means there is nothing to reduce. Zero kept directions also
    // returns None rather than an empty (p_log x 0) transform.
    if r == p_log || r == 0 {
        return Ok(None);
    }
    let mut transform = Array2::<f64>::zeros((p_log, r));
    for (out_col, &src) in kept.iter().enumerate() {
        transform.column_mut(out_col).assign(&evecs.column(src));
    }
    if transform.iter().any(|v| !v.is_finite()) {
        return Err(
            "survival reduced logslope: reduced transform produced non-finite entries".to_string(),
        );
    }
    Ok(Some(transform))
}
1143
1144pub fn survival_block_diagonal_logslope_map(
1161 p_time: usize,
1162 p_marg: usize,
1163 t_log: &Array2<f64>,
1164) -> gam_identifiability::families::compiler::CompiledMap {
1165 let p_log = t_log.nrows();
1166 let r = t_log.ncols();
1167 let raw_total = p_time + p_marg + p_log;
1168 let compiled_total = p_time + p_marg + r;
1169 let mut t_full = Array2::<f64>::zeros((raw_total, compiled_total));
1170 for i in 0..p_time {
1171 t_full[[i, i]] = 1.0;
1172 }
1173 for i in 0..p_marg {
1174 t_full[[p_time + i, p_time + i]] = 1.0;
1175 }
1176 for ri in 0..p_log {
1177 for cj in 0..r {
1178 t_full[[p_time + p_marg + ri, p_time + p_marg + cj]] = t_log[[ri, cj]];
1179 }
1180 }
1181 gam_identifiability::families::compiler::CompiledMap {
1182 raw_from_compiled: t_full,
1183 compiled_block_ranges: vec![
1184 0..p_time,
1185 p_time..(p_time + p_marg),
1186 (p_time + p_marg)..compiled_total,
1187 ],
1188 raw_block_ranges: vec![
1189 0..p_time,
1190 p_time..(p_time + p_marg),
1191 (p_time + p_marg)..raw_total,
1192 ],
1193 }
1194}
1195
1196pub fn apply_compiled_map_to_designs(
1224 map: &gam_identifiability::families::compiler::CompiledMap,
1225 time_design_entry: DesignMatrix,
1226 time_design_exit: DesignMatrix,
1227 time_design_derivative_exit: DesignMatrix,
1228 marginal_design: DesignMatrix,
1229 logslope_design: DesignMatrix,
1230 time_penalties: &[gam_terms::smooth::BlockwisePenalty],
1231 marginal_penalties: &[gam_terms::smooth::BlockwisePenalty],
1232 logslope_penalties: &[gam_terms::smooth::BlockwisePenalty],
1233) -> Result<CompiledSurvivalDesignsVMExact, String> {
1234 if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
1235 return Err(format!(
1236 "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
1237 got {} raw / {} compiled",
1238 map.raw_block_ranges.len(),
1239 map.compiled_block_ranges.len(),
1240 ));
1241 }
1242 let time_raw = map.raw_block_ranges[0].clone();
1243 let marg_raw = map.raw_block_ranges[1].clone();
1244 let log_raw = map.raw_block_ranges[2].clone();
1245 let time_compiled = map.compiled_block_ranges[0].clone();
1246 let marg_compiled = map.compiled_block_ranges[1].clone();
1247 let log_compiled = map.compiled_block_ranges[2].clone();
1248
1249 let t = &map.raw_from_compiled;
1250 let raw_total = t.nrows();
1251 let compiled_total = t.ncols();
1252 let expected_raw_total = log_raw.end;
1253 if raw_total != expected_raw_total {
1254 return Err(format!(
1255 "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
1256 {expected_raw_total}"
1257 ));
1258 }
1259 let expected_compiled_total = log_compiled.end;
1260 if compiled_total != expected_compiled_total {
1261 return Err(format!(
1262 "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
1263 sum to {expected_compiled_total}"
1264 ));
1265 }
1266
1267 let v_time = t
1268 .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
1269 .to_owned();
1270 let v_marg = t
1271 .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
1272 .to_owned();
1273 let v_log = t
1274 .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
1275 .to_owned();
1276
1277 let time_entry_out =
1278 wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
1279 let time_exit_out =
1280 wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
1281 let time_deriv_out = wrap_design_with_transform(
1282 time_design_derivative_exit,
1283 &v_time,
1284 "compiled-map: time derivative_exit",
1285 )?;
1286 let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
1287 let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
1288
1289 let pull_set = |pens: &[gam_terms::smooth::BlockwisePenalty],
1310 v_block: &Array2<f64>,
1311 channel: &str|
1312 -> Result<Vec<PenaltyMatrix>, String> {
1313 pens.iter()
1314 .map(|p| {
1315 pull_back_blockwise_penalty_through_block_v(p, v_block).map_err(|e| {
1316 format!("apply_compiled_map_to_designs: {channel} penalty pullback: {e}")
1317 })
1318 })
1319 .collect()
1320 };
1321
1322 let time_penalties = pull_set(time_penalties, &v_time, "time")?;
1323 let marginal_penalties = pull_set(marginal_penalties, &v_marg, "marginal")?;
1324 let logslope_penalties = pull_set(logslope_penalties, &v_log, "logslope")?;
1325 validate_block_penalty_shapes("time", time_exit_out.ncols(), &time_penalties)?;
1326 validate_block_penalty_shapes("marginal", marg_out.ncols(), &marginal_penalties)?;
1327 validate_block_penalty_shapes("logslope", log_out.ncols(), &logslope_penalties)?;
1328
1329 Ok(CompiledSurvivalDesignsVMExact {
1330 time_design_entry: time_entry_out,
1331 time_design_exit: time_exit_out,
1332 time_design_derivative_exit: time_deriv_out,
1333 marginal_design: marg_out,
1334 logslope_design: log_out,
1335 time_penalties,
1336 marginal_penalties,
1337 logslope_penalties,
1338 })
1339}
1340
1341fn validate_block_penalty_shapes(
1342 block: &str,
1343 width: usize,
1344 penalties: &[PenaltyMatrix],
1345) -> Result<(), String> {
1346 for (idx, penalty) in penalties.iter().enumerate() {
1347 let shape = penalty.shape();
1348 if shape != (width, width) {
1349 return Err(format!(
1350 "apply_compiled_map_to_designs: {block} penalty {idx} must be {width}x{width}, got {}x{}",
1351 shape.0, shape.1
1352 ));
1353 }
1354 }
1355 Ok(())
1356}
1357
1358pub fn compile_survival_parametric_designs(
1386 time_dq0: Array2<f64>,
1387 time_dq1: Array2<f64>,
1388 time_dqd1: Array2<f64>,
1389 marginal_dq: Array2<f64>,
1390 marginal_dqd1: Array2<f64>,
1391 logslope_dg: Array2<f64>,
1392 row_hess: &dyn RowHessian,
1393) -> Result<SurvivalParametricCompiled, String> {
1394 use gam_identifiability::families::compiler::compile;
1395
1396 let p_time_raw = time_dq0.ncols();
1397 let p_marg_raw = marginal_dq.ncols();
1398 let p_log_raw = logslope_dg.ncols();
1399
1400 let inputs = build_survival_compiler_inputs(
1401 time_dq0,
1402 time_dq1,
1403 time_dqd1,
1404 marginal_dq,
1405 marginal_dqd1,
1406 logslope_dg,
1407 None,
1408 None,
1409 );
1410 if inputs.operators.len() != 3 {
1411 return Err(format!(
1412 "compile_survival_parametric_designs: expected exactly 3 parametric operators \
1413 (time, marginal, logslope); got {}",
1414 inputs.operators.len(),
1415 ));
1416 }
1417 let compiled = compile(&inputs.operators, row_hess, &inputs.ordering)
1418 .map_err(|e| format!("identifiability::families::compiler::compile failed: {e}"))?;
1419 if compiled.blocks.len() != 3 {
1420 return Err(format!(
1421 "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
1422 compiled.blocks.len(),
1423 ));
1424 }
1425 let v_time = compiled.blocks[0].t_lw.clone();
1426 let v_marginal = compiled.blocks[1].t_lw.clone();
1427 let v_logslope = compiled.blocks[2].t_lw.clone();
1428 let drops_by_block = (
1429 p_time_raw.saturating_sub(v_time.ncols()),
1430 p_marg_raw.saturating_sub(v_marginal.ncols()),
1431 p_log_raw.saturating_sub(v_logslope.ncols()),
1432 );
1433 Ok(SurvivalParametricCompiled {
1434 v_time,
1435 v_marginal,
1436 v_logslope,
1437 drops_by_block,
1438 })
1439}
1440
1441pub fn build_survival_compiler_inputs(
1453 time_dq0: Array2<f64>,
1454 time_dq1: Array2<f64>,
1455 time_dqd1: Array2<f64>,
1456 marginal_dq: Array2<f64>,
1457 marginal_dqd1: Array2<f64>,
1458 logslope_dg: Array2<f64>,
1459 score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1460 link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1461) -> SurvivalCompilerInputs {
1462 let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(5);
1463 let mut ordering: Vec<BlockOrder> = Vec::with_capacity(5);
1464
1465 operators.push(Arc::new(TimeBlockOperator::new(
1466 time_dq0, time_dq1, time_dqd1,
1467 )));
1468 ordering.push(BlockOrder::Time);
1469
1470 operators.push(Arc::new(QChannelBlockOperator::new(
1471 marginal_dq,
1472 marginal_dqd1,
1473 )));
1474 ordering.push(BlockOrder::Marginal);
1475
1476 operators.push(Arc::new(LogslopeBlockOperator::new(logslope_dg)));
1477 ordering.push(BlockOrder::Logslope);
1478
1479 if let Some((dq, dqd1)) = score_warp_dq_dqd1 {
1480 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1481 ordering.push(BlockOrder::ScoreWarp);
1482 }
1483 if let Some((dq, dqd1)) = link_dev_dq_dqd1 {
1484 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1485 ordering.push(BlockOrder::LinkDev);
1486 }
1487
1488 SurvivalCompilerInputs {
1489 operators,
1490 ordering,
1491 }
1492}
1493
/// Survival design matrices and their penalties re-expressed in compiled
/// (identifiability-reduced) coordinates.
///
/// NOTE(review): this appears to be the value returned by `apply_compiled_map_to_designs`.
/// The unit tests read these fields from that function's result; confirm against the
/// producer. In those tests:
/// - each design's column count equals its block's compiled width;
/// - each penalty is sized to its own block's compiled width, not the joint width.
pub struct CompiledSurvivalDesignsVMExact {
    /// Time-block design used for the entry evaluation (per the field name).
    pub time_design_entry: DesignMatrix,
    /// Time-block design used for the exit evaluation (per the field name).
    pub time_design_exit: DesignMatrix,
    /// Time-block derivative design at exit (per the field name).
    pub time_design_derivative_exit: DesignMatrix,
    /// Marginal-block design in compiled marginal coordinates.
    pub marginal_design: DesignMatrix,
    /// Logslope-block design in compiled logslope coordinates.
    pub logslope_design: DesignMatrix,
    /// Time-block penalties, pulled back to the time block's compiled width.
    pub time_penalties: Vec<PenaltyMatrix>,
    /// Marginal-block penalties, pulled back to the marginal block's compiled width.
    pub marginal_penalties: Vec<PenaltyMatrix>,
    /// Logslope-block penalties, pulled back to the logslope block's compiled width.
    pub logslope_penalties: Vec<PenaltyMatrix>,
}
1529
1530#[cfg(test)]
1531mod tests {
1532 use super::*;
1533 use gam_problem::Gauge;
1534
1535 #[test]
1536 fn psd_clamp_zeros_negative_eigenvalues() {
1537 let mut m = Array2::<f64>::zeros((4, 4));
1541 m[[0, 0]] = 2.0;
1544 m[[1, 1]] = -1.0;
1545 m[[2, 2]] = 0.5;
1546 m[[3, 3]] = -0.25;
1547 let clamped = psd_clamp_4x4(&m);
1548 assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
1549 assert!(clamped[[1, 1]].abs() < 1e-12);
1550 assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
1551 assert!(clamped[[3, 3]].abs() < 1e-12);
1552 }
1553
1554 #[test]
1555 fn time_block_operator_evaluate_full_shape() {
1556 let n = 6;
1557 let p = 3;
1558 let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
1559 let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
1560 let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
1561 let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
1562 let full = op.evaluate_full();
1563 assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
1564 for i in 0..n {
1565 for j in 0..p {
1566 assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
1567 assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
1568 assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
1569 assert_eq!(full[[i, j, 3]], 0.0);
1570 }
1571 }
1572 }
1573
1574 #[test]
1575 fn q_channel_block_apply_row_shares_q0_q1() {
1576 let n = 5;
1577 let p = 2;
1578 let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
1579 let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
1580 let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
1581 let mut out = [0.0_f64; K_SURVIVAL];
1582 let delta = [1.0_f64, -0.5];
1583 op.apply_row(3, &delta, &mut out);
1584 let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
1585 let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
1586 assert!((out[0] - want_q).abs() < 1e-12);
1587 assert!((out[1] - want_q).abs() < 1e-12);
1588 assert!((out[2] - want_qd).abs() < 1e-12);
1589 assert_eq!(out[3], 0.0);
1590 }
1591
1592 #[test]
1593 fn logslope_block_writes_only_g_channel() {
1594 let n = 4;
1595 let p = 2;
1596 let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
1597 let op = LogslopeBlockOperator::new(dg.clone());
1598 let mut out = [0.0_f64; K_SURVIVAL];
1599 let delta = [2.0_f64, -1.0];
1600 op.apply_row(1, &delta, &mut out);
1601 assert_eq!(out[0], 0.0);
1602 assert_eq!(out[1], 0.0);
1603 assert_eq!(out[2], 0.0);
1604 let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
1605 assert!((out[3] - want).abs() < 1e-12);
1606 }
1607
1608 #[test]
1609 fn extract_term_partition_simple_cases() {
1610 let full = 0..5usize;
1611 let part = extract_term_partition_from_penalty_ranges(5, &[]);
1613 assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1614 let part = extract_term_partition_from_penalty_ranges(5, std::slice::from_ref(&full));
1616 assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1617 let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
1619 assert_eq!(part, vec![0..3, 3..6, 6..10]);
1620 let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
1622 assert_eq!(part, vec![0..3, 3..6]);
1623 let part = extract_term_partition_from_penalty_ranges(0, &[]);
1625 assert!(part.is_empty());
1626 }
1627
1628 #[test]
1629 fn assemble_block_triangular_t_identity_when_v_eye_and_r_none() {
1630 let v_a = Array2::<f64>::eye(2);
1631 let v_b = Array2::<f64>::eye(2);
1632 let t = assemble_block_triangular_t(&[v_a, v_b], &[None, None]);
1633 assert_eq!(t.dim(), (4, 4));
1634 let eye4 = Array2::<f64>::eye(4);
1635 for i in 0..4 {
1636 for j in 0..4 {
1637 assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
1638 }
1639 }
1640 }
1641
1642 #[test]
1643 fn assemble_block_triangular_t_with_drops_and_nonzero_r() {
1644 let mut v_a = Array2::<f64>::zeros((3, 2));
1645 v_a[[0, 0]] = 1.0;
1646 v_a[[1, 0]] = 0.5;
1647 v_a[[2, 1]] = 1.0;
1648 let v_b = Array2::<f64>::eye(2);
1649 let r_ab =
1650 Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
1651 let t =
1652 assemble_block_triangular_t(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
1653 assert_eq!(t.dim(), (5, 4));
1654 for i in 0..3 {
1655 for j in 0..2 {
1656 assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
1657 }
1658 }
1659 for i in 0..2 {
1660 for j in 0..2 {
1661 assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
1662 }
1663 }
1664 for i in 0..3 {
1665 for j in 0..2 {
1666 assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
1667 }
1668 }
1669 for i in 0..2 {
1670 for j in 0..2 {
1671 assert_eq!(t[[3 + i, j]], 0.0);
1672 }
1673 }
1674 }
1675
1676 #[test]
1677 fn validate_partition_rejects_bad_partitions() {
1678 let bad_start = 1..5usize;
1679 let short_cover = 0..3usize;
1680 let full_cover = 0..5usize;
1681 assert!(validate_partition(std::slice::from_ref(&bad_start), 5, "test").is_err());
1683 assert!(validate_partition(std::slice::from_ref(&short_cover), 5, "test").is_err());
1685 assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
1687 assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
1689 assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
1691 assert!(validate_partition(&[], 0, "test").is_ok());
1693 assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
1695 assert!(validate_partition(std::slice::from_ref(&full_cover), 5, "test").is_ok());
1696 }
1697
1698 #[test]
1709 fn compiled_map_penalty_pullback_is_per_block_width_with_nonzero_residual() {
1710 use gam_identifiability::families::compiler::CompiledMap;
1711 use gam_terms::smooth::BlockwisePenalty;
1712
1713 let n = 10;
1714 let v_time =
1718 Array2::<f64>::from_shape_fn(
1719 (3, 3),
1720 |(i, j)| {
1721 if i == j { 1.0 } else { 0.1 * ((i + j) as f64) }
1722 },
1723 );
1724 let v_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
1725 0.5 + 0.3 * (i as f64) - 0.2 * (j as f64)
1726 });
1727 let v_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 1.2 } else { 0.4 });
1728 let r_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 0.7 - 0.1 * ((i + j) as f64));
1730 let r_log =
1735 Array2::<f64>::from_shape_fn((6, 2), |(i, j)| 0.3 + 0.05 * ((i * 2 + j) as f64));
1736
1737 let t = assemble_block_triangular_t(
1738 &[v_time.clone(), v_marg.clone(), v_log.clone()],
1739 &[None, Some(r_marg.clone()), Some(r_log.clone())],
1740 );
1741 assert_eq!(t.dim(), (8, 7), "joint raw 8 × joint compiled 7");
1742
1743 let map = CompiledMap {
1744 raw_from_compiled: t.clone(),
1745 compiled_block_ranges: vec![0..3, 3..5, 5..7],
1746 raw_block_ranges: vec![0..3, 3..6, 6..8],
1747 };
1748
1749 let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
1751 Array2::<f64>::from_shape_fn((n, 3), |(i, j)| 1.0 + (i as f64) * 0.1 + (j as f64)),
1752 ));
1753 let raw_time_exit = raw_time_entry.clone();
1754 let raw_time_deriv = raw_time_entry.clone();
1755 let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1756 (n, 3),
1757 |(i, j)| 0.2 * (i as f64) - 0.3 * (j as f64),
1758 )));
1759 let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1760 (n, 2),
1761 |(i, j)| 0.5 + (i as f64) * (j as f64 + 1.0),
1762 )));
1763
1764 let s_time =
1766 Array2::<f64>::from_shape_fn(
1767 (3, 3),
1768 |(i, j)| if i == j { (i + 2) as f64 } else { 0.3 },
1769 );
1770 let s_marg =
1771 Array2::<f64>::from_shape_fn(
1772 (3, 3),
1773 |(i, j)| if i == j { 1.5 + i as f64 } else { 0.2 },
1774 );
1775 let s_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 2.0 } else { 0.5 });
1776 let time_pens = vec![BlockwisePenalty::new(0..3, s_time.clone())];
1777 let marg_pens = vec![BlockwisePenalty::new(0..3, s_marg.clone())];
1778 let log_pens = vec![BlockwisePenalty::new(0..2, s_log.clone())];
1779
1780 let out = apply_compiled_map_to_designs(
1781 &map,
1782 raw_time_entry,
1783 raw_time_exit,
1784 raw_time_deriv,
1785 raw_marg,
1786 raw_log,
1787 &time_pens,
1788 &marg_pens,
1789 &log_pens,
1790 )
1791 .expect("apply_compiled_map_to_designs must succeed");
1792
1793 assert_eq!(out.time_design_entry.ncols(), 3);
1795 assert_eq!(out.marginal_design.ncols(), 2);
1796 assert_eq!(out.logslope_design.ncols(), 2);
1797
1798 for s in &out.time_penalties {
1801 assert_eq!(
1802 s.as_dense_cow().dim(),
1803 (3, 3),
1804 "time penalty must be per-block 3×3, not joint-width"
1805 );
1806 }
1807 for s in &out.marginal_penalties {
1808 assert_eq!(
1809 s.as_dense_cow().dim(),
1810 (2, 2),
1811 "marginal penalty must match reduced compiled width 2, not joint 7"
1812 );
1813 }
1814 for s in &out.logslope_penalties {
1815 assert_eq!(s.as_dense_cow().dim(), (2, 2));
1816 }
1817
1818 let p_time_dense = out.time_penalties[0].as_dense_cow().into_owned();
1822 let theta_time = Array1::<f64>::from_shape_fn(3, |k| 0.4 + 0.7 * (k as f64));
1823 let gamma_time = v_time.dot(&theta_time);
1824 let lhs = theta_time.dot(&p_time_dense.dot(&theta_time));
1825 let rhs = gamma_time.dot(&s_time.dot(&gamma_time));
1826 assert!(
1827 (lhs - rhs).abs() < 1e-10,
1828 "time-block per-block pullback must be exact: lhs={lhs}, rhs={rhs}"
1829 );
1830
1831 let p_marg_dense = out.marginal_penalties[0].as_dense_cow().into_owned();
1834 let want_marg = v_marg.t().dot(&s_marg.dot(&v_marg));
1835 for i in 0..2 {
1836 for j in 0..2 {
1837 assert!(
1838 (p_marg_dense[[i, j]] - want_marg[[i, j]]).abs() < 1e-12,
1839 "marginal penalty must be V_margᵀ S_marg V_marg at ({i},{j})"
1840 );
1841 }
1842 }
1843 }
1844
1845 #[test]
1852 fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
1853 let n = 24;
1854 let p_time = 3;
1855 let p_marginal = 3;
1856 let p_logslope = 2;
1857 let x: Vec<f64> = (0..n)
1858 .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1859 .collect();
1860 let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1861 let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1862 let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1863 let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1864 let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1865 let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1866 for i in 0..n {
1867 time_dq0[[i, 0]] = 1.0;
1868 time_dq0[[i, 1]] = x[i];
1869 time_dq0[[i, 2]] = x[i] * x[i];
1870 time_dq1[[i, 0]] = 1.0;
1871 time_dq1[[i, 1]] = x[i];
1872 time_dq1[[i, 2]] = x[i] * x[i];
1873 time_dqd1[[i, 0]] = 0.0;
1874 time_dqd1[[i, 1]] = 1.0;
1875 time_dqd1[[i, 2]] = 2.0 * x[i];
1876 marg_dq[[i, 0]] = 1.0; marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1878 marg_dq[[i, 2]] = x[i].sin();
1879 log_dg[[i, 0]] = (2.0 * x[i]).cos();
1880 log_dg[[i, 1]] = x[i].tanh();
1881 }
1882 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1883 for i in 0..n {
1884 for k in 0..K_SURVIVAL {
1885 h_full[[i, k, k]] = 1.0;
1886 }
1887 }
1888 let row_hess = SurvivalRowHessian::from_full(h_full);
1889 let out = compile_survival_parametric_designs(
1890 time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
1891 )
1892 .expect("Phase-4b parametric compile must succeed on single-direction alias");
1893 assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
1894 assert_eq!(
1895 out.v_marginal.ncols(),
1896 p_marginal - 1,
1897 "marginal loses exactly the shared-constant direction"
1898 );
1899 assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
1900 assert_eq!(
1901 out.drops_by_block,
1902 (0, 1, 0),
1903 "attribution: zero from time/logslope, one from marginal",
1904 );
1905 }
1906
1907 #[test]
1928 fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
1929 use gam_identifiability::families::compiler::compile;
1930
1931 let n = 32;
1932 let p_time = 3;
1933 let p_marginal = 3;
1934 let p_logslope = 2;
1935
1936 let x: Vec<f64> = (0..n)
1947 .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1948 .collect();
1949 let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1950 let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1951 let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1952 for i in 0..n {
1953 time_dq0[[i, 0]] = 1.0;
1954 time_dq0[[i, 1]] = x[i];
1955 time_dq0[[i, 2]] = x[i] * x[i];
1956 time_dq1[[i, 0]] = 1.0;
1957 time_dq1[[i, 1]] = x[i];
1958 time_dq1[[i, 2]] = x[i] * x[i];
1959 time_dqd1[[i, 0]] = 0.0;
1961 time_dqd1[[i, 1]] = 1.0;
1962 time_dqd1[[i, 2]] = 2.0 * x[i];
1963 }
1964
1965 let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1971 let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1972 for i in 0..n {
1973 marg_dq[[i, 0]] = 1.0;
1974 marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1975 marg_dq[[i, 2]] = x[i].sin();
1976 }
1977
1978 let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1982 for i in 0..n {
1983 log_dg[[i, 0]] = (2.0 * x[i]).cos();
1984 log_dg[[i, 1]] = x[i].tanh();
1985 }
1986
1987 let inputs = build_survival_compiler_inputs(
1988 time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
1989 );
1990
1991 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1997 for i in 0..n {
1998 for k in 0..K_SURVIVAL {
1999 h_full[[i, k, k]] = 1.0;
2000 }
2001 }
2002 let row_hess = SurvivalRowHessian::from_full(h_full);
2003
2004 let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
2005 .expect("survival 3-block compile must succeed; aliasing is single-direction");
2006
2007 assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");
2009
2010 let v_time = &compiled.blocks[0].t_lw;
2015 assert_eq!(
2016 v_time.ncols(),
2017 p_time,
2018 "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
2019 v_time.dim(),
2020 );
2021
2022 let v_marg = &compiled.blocks[1].t_lw;
2029 assert_eq!(
2030 v_marg.ncols(),
2031 p_marginal - 1,
2032 "marginal block must lose exactly the shared-constant direction; \
2033 V_marginal cols = {}, expected {}",
2034 v_marg.ncols(),
2035 p_marginal - 1,
2036 );
2037
2038 let v_log = &compiled.blocks[2].t_lw;
2041 assert_eq!(
2042 v_log.ncols(),
2043 p_logslope,
2044 "logslope block (no shared direction) must retain all {p_logslope} columns",
2045 );
2046
2047 let raw_total = p_time + p_marginal + p_logslope;
2050 let kept_total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
2051 assert_eq!(
2052 kept_total,
2053 raw_total - 1,
2054 "joint kept = raw_total − aliased; got {kept_total}, expected {}",
2055 raw_total - 1,
2056 );
2057 assert_eq!(
2058 compiled.joint_rank, kept_total,
2059 "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
2060 );
2061
2062 let v_per_term: Vec<Array2<f64>> = compiled.blocks.iter().map(|b| b.t_lw.clone()).collect();
2072 let r_per_term: Vec<Option<Array2<f64>>> = vec![None; v_per_term.len()];
2073 let gauge = Gauge::from_v_and_r(&v_per_term, &r_per_term);
2074
2075 let mut expected_reduced = vec![0usize];
2076 let mut expected_raw = vec![0usize];
2077 for b in &compiled.blocks {
2078 let prev_reduced = *expected_reduced.last().unwrap();
2079 expected_reduced.push(prev_reduced + b.t_lw.ncols());
2080 let prev_raw = *expected_raw.last().unwrap();
2081 expected_raw.push(prev_raw + b.t_lw.nrows());
2082 }
2083 assert_eq!(
2084 *gauge.block_starts_reduced.last().unwrap(),
2085 compiled.joint_rank,
2086 "SMGS lift reduced dimension must equal the compiled joint_rank",
2087 );
2088 assert_eq!(
2089 gauge.block_starts_reduced, expected_reduced,
2090 "SMGS lift reduced block boundaries must match the compiled kept widths",
2091 );
2092 assert_eq!(
2093 gauge.block_starts_raw, expected_raw,
2094 "SMGS lift raw block boundaries must match the compiled per-block raw widths",
2095 );
2096
2097 for (bi, block) in compiled.blocks.iter().enumerate() {
2102 for j in 0..block.t_lw.ncols() {
2103 let col = block.t_lw.column(j);
2104 assert!(
2105 col.iter().all(|v| v.is_finite()),
2106 "block {bi} kept direction {j} has a non-finite entry",
2107 );
2108 let norm = col.dot(&col).sqrt();
2109 assert!(
2110 norm > 1e-10,
2111 "block {bi} kept direction {j} is degenerate (norm {norm:.3e})",
2112 );
2113 }
2114 }
2115 }
2116
2117 #[test]
2120 fn smgs_lift_via_t_identity_passes_through() {
2121 let v0 = Array2::<f64>::eye(3);
2122 let v1 = Array2::<f64>::eye(2);
2123 let v_per_term = vec![v0, v1];
2124 let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
2125 let lift = Gauge::from_v_and_r(&v_per_term, &r_per_term);
2126 assert_eq!(lift.t_full.dim(), (5, 5));
2127 assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
2128 assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
2129 for i in 0..5 {
2130 for j in 0..5 {
2131 let want = if i == j { 1.0 } else { 0.0 };
2132 assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
2133 }
2134 }
2135 let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
2136 let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
2137 let lifted = lift.lift_block_betas(&[theta_0.clone(), theta_1.clone()]);
2138 assert_eq!(lifted.len(), 2);
2139 for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
2140 assert!((a - b).abs() < 1e-14);
2141 }
2142 for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
2143 assert!((a - b).abs() < 1e-14);
2144 }
2145 }
2146
2147 #[test]
2151 fn smgs_lift_via_t_two_block_with_residualisation() {
2152 let v_a = Array2::<f64>::eye(3);
2153 let mut v_b = Array2::<f64>::zeros((3, 2));
2154 v_b[[0, 0]] = 1.0;
2155 v_b[[2, 1]] = 1.0;
2156 let mut r_b = Array2::<f64>::zeros((3, 2));
2157 r_b[[0, 0]] = 0.4;
2158 r_b[[0, 1]] = -0.1;
2159 r_b[[1, 0]] = 0.7;
2160 r_b[[1, 1]] = 1.3;
2161 r_b[[2, 0]] = -0.2;
2162 r_b[[2, 1]] = 0.5;
2163 let lift = Gauge::from_v_and_r(&[v_a.clone(), v_b.clone()], &[None, Some(r_b.clone())]);
2164 assert_eq!(lift.t_full.dim(), (6, 5));
2165 assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
2166 assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);
2167
2168 let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2169 let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2170 let lifted = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2171 let r_theta_b = r_b.dot(&theta_b);
2172 let expected_a = &theta_a - &r_theta_b;
2173 assert_eq!(lifted[0].len(), 3);
2174 for (got, want) in lifted[0].iter().zip(expected_a.iter()) {
2175 assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
2176 }
2177 assert_eq!(lifted[1].len(), 3);
2178 assert!((lifted[1][0] - theta_b[0]).abs() < 1e-12);
2179 assert!(lifted[1][1].abs() < 1e-12);
2180 assert!((lifted[1][2] - theta_b[1]).abs() < 1e-12);
2181 }
2182
2183 #[test]
2195 fn smgs_lift_covariance_identity_and_rank1_consistency() {
2196 let lift_id = Gauge::from_v_and_r(
2198 &[Array2::<f64>::eye(2), Array2::<f64>::eye(2)],
2199 &[None, None],
2200 );
2201 let mut cov = Array2::<f64>::zeros((4, 4));
2202 for i in 0..4 {
2204 for j in 0..4 {
2205 cov[[i, j]] = 1.0 / (1.0 + (i as f64 - j as f64).abs());
2206 }
2207 }
2208 let lifted_id = lift_id.lift_covariance(&cov);
2209 assert_eq!(lifted_id.dim(), (4, 4));
2210 for i in 0..4 {
2211 for j in 0..4 {
2212 assert!(
2213 (lifted_id[[i, j]] - cov[[i, j]]).abs() < 1e-12,
2214 "identity-T covariance lift must be a no-op at [{i},{j}]",
2215 );
2216 }
2217 }
2218
2219 let v_a = Array2::<f64>::eye(3);
2224 let mut v_b = Array2::<f64>::zeros((3, 2));
2225 v_b[[0, 0]] = 1.0;
2226 v_b[[2, 1]] = 1.0;
2227 let mut r_b = Array2::<f64>::zeros((3, 2));
2228 r_b[[0, 0]] = 0.4;
2229 r_b[[0, 1]] = -0.1;
2230 r_b[[1, 0]] = 0.7;
2231 r_b[[1, 1]] = 1.3;
2232 r_b[[2, 0]] = -0.2;
2233 r_b[[2, 1]] = 0.5;
2234 let lift = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_b)]);
2235
2236 let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2237 let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2238 let theta_full = Array1::from(vec![
2240 theta_a[0], theta_a[1], theta_a[2], theta_b[0], theta_b[1],
2241 ]);
2242 let mut cov_rank1 = Array2::<f64>::zeros((5, 5));
2244 for i in 0..5 {
2245 for j in 0..5 {
2246 cov_rank1[[i, j]] = theta_full[i] * theta_full[j];
2247 }
2248 }
2249 let lifted_cov = lift.lift_covariance(&cov_rank1);
2250 let lifted_blocks = lift.lift_block_betas(&[theta_a, theta_b]);
2252 let beta_raw = Array1::from(
2253 lifted_blocks
2254 .iter()
2255 .flat_map(|b| b.iter().copied())
2256 .collect::<Vec<f64>>(),
2257 );
2258 assert_eq!(lifted_cov.dim(), (6, 6));
2259 assert_eq!(beta_raw.len(), 6);
2260 for i in 0..6 {
2261 for j in 0..6 {
2262 let want = beta_raw[i] * beta_raw[j];
2263 assert!(
2264 (lifted_cov[[i, j]] - want).abs() < 1e-10,
2265 "rank-1 covariance pushforward must equal (Tθ)(Tθ)ᵀ at [{i},{j}]: got {}, want {want}",
2266 lifted_cov[[i, j]],
2267 );
2268 }
2269 }
2270 for i in 0..6 {
2272 for j in 0..6 {
2273 assert!((lifted_cov[[i, j]] - lifted_cov[[j, i]]).abs() < 1e-14);
2274 }
2275 }
2276 }
2277
2278 #[test]
2281 fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
2282 let mut v_a = Array2::<f64>::zeros((3, 2));
2283 v_a[[0, 0]] = 0.6;
2284 v_a[[1, 0]] = -0.8;
2285 v_a[[1, 1]] = 0.3;
2286 v_a[[2, 1]] = 0.9;
2287 let mut v_b = Array2::<f64>::zeros((4, 3));
2288 v_b[[0, 0]] = 1.0;
2289 v_b[[1, 1]] = -0.4;
2290 v_b[[2, 0]] = 0.2;
2291 v_b[[2, 2]] = 0.7;
2292 v_b[[3, 2]] = -1.1;
2293 let v_per_term = vec![v_a.clone(), v_b.clone()];
2294 let lift = Gauge::from_v_and_r(&v_per_term, &[None, None]);
2295 let theta_a = Array1::from(vec![0.3_f64, -1.4]);
2296 let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);
2297 let via_t = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2298 let ref_a = v_a.dot(&theta_a);
2299 let ref_b = v_b.dot(&theta_b);
2300 assert_eq!(via_t[0].len(), ref_a.len());
2301 for (g, w) in via_t[0].iter().zip(ref_a.iter()) {
2302 assert!((g - w).abs() < 1e-12);
2303 }
2304 assert_eq!(via_t[1].len(), ref_b.len());
2305 for (g, w) in via_t[1].iter().zip(ref_b.iter()) {
2306 assert!((g - w).abs() < 1e-12);
2307 }
2308 }
2309
2310 #[test]
2320 fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
2321 let n = 6usize;
2322 let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
2326 let time_dq1 = Array2::<f64>::zeros((n, 1));
2327 let time_dqd1 = Array2::<f64>::zeros((n, 1));
2328 let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
2333 let marg_dqd1 = Array2::<f64>::zeros((n, 1));
2334 let log_dg = Array2::<f64>::zeros((n, 0));
2336 let mut time_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2337 time_partition.push(0..1);
2338 let mut marg_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2339 marg_partition.push(0..1);
2340 let log_partition: Vec<std::ops::Range<usize>> = Vec::new();
2341
2342 let mut h_ident = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2346 for i in 0..n {
2347 for k in 0..K_SURVIVAL {
2348 h_ident[[i, k, k]] = 1.0;
2349 }
2350 }
2351 let row_hess_ident = SurvivalRowHessian::from_full(h_ident);
2352 let compiled_ident = compile_survival_parametric_designs_per_term(
2353 time_dq0.clone(),
2354 time_dq1.clone(),
2355 time_dqd1.clone(),
2356 &time_partition,
2357 marg_dq.clone(),
2358 marg_dqd1.clone(),
2359 &marg_partition,
2360 log_dg.clone(),
2361 &log_partition,
2362 &row_hess_ident,
2363 false,
2364 )
2365 .expect("identity-H compile must succeed");
2366
2367 let mut h_q0_only = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2371 for i in 0..n {
2372 h_q0_only[[i, 0, 0]] = 1.0;
2373 }
2374 let row_hess_q0 = SurvivalRowHessian::from_full(h_q0_only);
2375 let compiled_q0 = compile_survival_parametric_designs_per_term(
2376 time_dq0,
2377 time_dq1,
2378 time_dqd1,
2379 &time_partition,
2380 marg_dq,
2381 marg_dqd1,
2382 &marg_partition,
2383 log_dg,
2384 &log_partition,
2385 &row_hess_q0,
2386 false,
2387 )
2388 .expect("q0-only-H compile must succeed");
2389
2390 assert_ne!(
2394 compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2395 "structural-H and data-adaptive-H compiles must produce different \
2396 drops_by_block on the constructed pilot-curvature-trap design; \
2397 identity={:?} q0-only={:?}",
2398 compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2399 );
2400 assert_eq!(
2402 compiled_ident.drops_by_block.1, 0,
2403 "identity-H marg drops expected 0, got {:?}",
2404 compiled_ident.drops_by_block,
2405 );
2406 assert_eq!(
2408 compiled_q0.drops_by_block.1, 1,
2409 "q0-only-H marg drops expected 1, got {:?}",
2410 compiled_q0.drops_by_block,
2411 );
2412 }
2413
2414 #[test]
2415 fn compiled_map_from_per_term_partitions_and_lift_round_trip() {
2416 let v_time = Array2::<f64>::eye(2);
2420 let mut v_marg = Array2::<f64>::zeros((2, 1));
2421 v_marg[[0, 0]] = 1.0;
2422 v_marg[[1, 0]] = 0.5;
2423 let v_log = Array2::<f64>::eye(1);
2424 let r_marg = Array2::<f64>::from_shape_fn((2, 1), |(i, _)| 0.25 + i as f64);
2427 let r_log = Array2::<f64>::from_shape_fn((4, 1), |(i, _)| 0.1 * (i as f64 + 1.0));
2428 let per_term = SurvivalParametricCompiledPerTerm {
2429 v_time_per_term: vec![v_time.clone()],
2430 v_marginal_per_term: vec![v_marg.clone()],
2431 v_logslope_per_term: vec![v_log.clone()],
2432 r_lw_per_term: vec![None, Some(r_marg.clone()), Some(r_log.clone())],
2433 drops_by_block: (0, 1, 0),
2434 };
2435
2436 let map = compiled_map_from_per_term(&per_term);
2437
2438 assert_eq!(map.raw_block_ranges, vec![0..2, 2..4, 4..5]);
2440 assert_eq!(map.compiled_block_ranges, vec![0..2, 2..3, 3..4]);
2442 assert_eq!(map.raw_from_compiled.dim(), (5, 4));
2443
2444 let v_time_slice = map
2447 .raw_from_compiled
2448 .slice(ndarray::s![0..2, 0..2])
2449 .to_owned();
2450 let v_marg_slice = map
2451 .raw_from_compiled
2452 .slice(ndarray::s![2..4, 2..3])
2453 .to_owned();
2454 let v_log_slice = map
2455 .raw_from_compiled
2456 .slice(ndarray::s![4..5, 3..4])
2457 .to_owned();
2458 for i in 0..2 {
2459 for j in 0..2 {
2460 assert!((v_time_slice[[i, j]] - v_time[[i, j]]).abs() < 1e-14);
2461 }
2462 assert!((v_marg_slice[[i, 0]] - v_marg[[i, 0]]).abs() < 1e-14);
2463 }
2464 assert!((v_log_slice[[0, 0]] - v_log[[0, 0]]).abs() < 1e-14);
2465
2466 let ordering = [
2469 gam_identifiability::families::compiler::BlockOrder::Time,
2470 gam_identifiability::families::compiler::BlockOrder::Marginal,
2471 gam_identifiability::families::compiler::BlockOrder::Logslope,
2472 ];
2473 let lift_from_map = Gauge::from_compiled_map(&map, &ordering);
2474 let v_all = vec![v_time, v_marg, v_log];
2475 let lift_direct = Gauge::from_v_and_r(&v_all, &[None, Some(r_marg), Some(r_log)]);
2476 assert_eq!(lift_from_map.t_full.dim(), lift_direct.t_full.dim());
2477 for i in 0..lift_from_map.t_full.nrows() {
2478 for j in 0..lift_from_map.t_full.ncols() {
2479 assert!(
2480 (lift_from_map.t_full[[i, j]] - lift_direct.t_full[[i, j]]).abs() < 1e-14,
2481 "T mismatch at ({i},{j}): map={} direct={}",
2482 lift_from_map.t_full[[i, j]],
2483 lift_direct.t_full[[i, j]],
2484 );
2485 }
2486 }
2487 }
2488
2489 fn const_row_hess_q0g(n: usize, h00: f64, h03: f64, h33: f64) -> SurvivalRowHessian {
2505 let mut h = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2506 for i in 0..n {
2507 h[[i, 0, 0]] = h00;
2508 h[[i, 0, 3]] = h03;
2509 h[[i, 3, 0]] = h03;
2510 h[[i, 3, 3]] = h33;
2511 }
2512 SurvivalRowHessian::from_full(h)
2513 }
2514
2515 #[test]
2516 fn survival_reduced_logslope_drops_confounded_keeps_free_979() {
2517 let n = 4;
2523 let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0); let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2525 let log =
2528 Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
2529 .unwrap();
2530 let t = survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2531 .expect("contraction must succeed")
2532 .expect("a partial confound must yield a reduced transform");
2533 assert_eq!(t.dim(), (2, 1), "exactly one logslope direction survives");
2534 assert!(
2537 t[[0, 0]].abs() < 1e-6,
2538 "confounded (e1) direction must be dropped, got {}",
2539 t[[0, 0]]
2540 );
2541 assert!(
2542 (t[[1, 0]].abs() - 1.0).abs() < 1e-6,
2543 "free (e2) direction must be kept as a unit vector, got {}",
2544 t[[1, 0]]
2545 );
2546 }
2547
2548 #[test]
2549 fn survival_reduced_logslope_fully_confounded_returns_none_979() {
2550 let n = 4;
2556 let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0);
2557 let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2558 let log = marg.clone();
2559 let out =
2560 survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2561 .expect("contraction must succeed");
2562 assert!(
2563 out.is_none(),
2564 "a fully marginal-explained logslope column reduces to nothing → keep raw"
2565 );
2566 }
2567
2568 #[test]
2569 fn survival_reduced_logslope_no_confound_returns_none_979() {
2570 let n = 4;
2574 let row_hess = const_row_hess_q0g(n, 2.0, 0.0, 2.0);
2575 let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2576 let log =
2577 Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
2578 .unwrap();
2579 let out =
2580 survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2581 .expect("contraction must succeed");
2582 assert!(out.is_none(), "W-orthogonal channels need no reduction → keep raw");
2583 }
2584
2585 #[test]
2586 fn survival_block_diagonal_logslope_map_is_identity_on_time_and_marginal_979() {
2587 let p_time = 2;
2590 let p_marg = 3;
2591 let t_log = Array2::from_shape_fn((4, 2), |(i, j)| 1.0 + (i * 2 + j) as f64);
2592 let map = survival_block_diagonal_logslope_map(p_time, p_marg, &t_log);
2593
2594 assert_eq!(map.raw_block_ranges, vec![0..2, 2..5, 5..9]);
2595 assert_eq!(map.compiled_block_ranges, vec![0..2, 2..5, 5..7]);
2596 assert_eq!(map.raw_from_compiled.dim(), (9, 7));
2597
2598 let t = &map.raw_from_compiled;
2599 for i in 0..p_time {
2601 for j in 0..p_time {
2602 let want = if i == j { 1.0 } else { 0.0 };
2603 assert!((t[[i, j]] - want).abs() < 1e-14, "V_time[{i},{j}]");
2604 }
2605 }
2606 for i in 0..p_marg {
2608 for j in 0..p_marg {
2609 let want = if i == j { 1.0 } else { 0.0 };
2610 assert!((t[[p_time + i, p_time + j]] - want).abs() < 1e-14, "V_marg[{i},{j}]");
2611 }
2612 }
2613 for i in 0..4 {
2615 for j in 0..2 {
2616 assert!(
2617 (t[[p_time + p_marg + i, p_time + p_marg + j]] - t_log[[i, j]]).abs() < 1e-14,
2618 "V_log[{i},{j}]"
2619 );
2620 }
2621 }
2622 let nnz = t.iter().filter(|&&v| v != 0.0).count();
2625 assert_eq!(nnz, p_time + p_marg + t_log.iter().filter(|&&v| v != 0.0).count());
2626 }
2627}