1use std::sync::Arc;
30
31use ndarray::{Array1, Array2, Array3};
32
33use gam_identifiability::families::compiler::{
34 BlockOrder, RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with,
35};
36use gam_problem::gauge::assemble_block_triangular_t;
37use faer::Side;
38use gam_linalg::faer_ndarray::FaerEigh;
39use gam_linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
40use gam_problem::{FamilyChannelHessian, PenaltyMatrix};
41
/// Number of per-row primary channels, ordered `(q0, q1, qd1, g)`: channels 0–2 are driven by
/// the time/marginal designs and channel 3 by the log-slope design.
const K_SURVIVAL: usize = 4;

/// Coefficients whose magnitude does not exceed this value count as zero when deciding
/// whether `channel_hessian_at` may reuse the pilot Hessian without family scalars.
const BETA_NONTRIVIAL_ABS_THRESHOLD: f64 = 1.0e-12;
50
51pub struct SurvivalRowHessian {
58 h: Array3<f64>,
61}
62
63impl SurvivalRowHessian {
64 pub fn from_pilot_primary_state(
68 q0: &Array1<f64>,
69 q1: &Array1<f64>,
70 qd1: &Array1<f64>,
71 g: &Array1<f64>,
72 z: &Array1<f64>,
73 weights: &Array1<f64>,
74 event: &Array1<f64>,
75 derivative_guard: f64,
76 probit_scale: f64,
77 ) -> Result<Self, String> {
78 let n = q0.len();
79 if [
80 q1.len(),
81 qd1.len(),
82 g.len(),
83 z.len(),
84 weights.len(),
85 event.len(),
86 ]
87 .iter()
88 .any(|&l| l != n)
89 {
90 return Err(format!(
91 "SurvivalRowHessian: length mismatch \
92 q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
93 q1.len(),
94 qd1.len(),
95 g.len(),
96 z.len(),
97 weights.len(),
98 event.len()
99 ));
100 }
101 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
102 for i in 0..n {
103 let (_, _grad, hess) =
104 crate::survival::marginal_slope::row_primary_for_compiler(
105 q0[i],
106 q1[i],
107 qd1[i],
108 g[i],
109 z[i],
110 weights[i],
111 event[i],
112 derivative_guard,
113 probit_scale,
114 )?;
115 let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
117 for a in 0..K_SURVIVAL {
118 for b in 0..K_SURVIVAL {
119 h_i[[a, b]] = hess[a][b];
120 }
121 }
122 let clamped = psd_clamp_4x4(&h_i);
123 for a in 0..K_SURVIVAL {
124 for b in 0..K_SURVIVAL {
125 h_full[[i, a, b]] = clamped[[a, b]];
126 }
127 }
128 }
129 Ok(Self { h: h_full })
130 }
131
132 pub fn from_full(h: Array3<f64>) -> Self {
135 assert_eq!(h.shape()[1], K_SURVIVAL);
136 assert_eq!(h.shape()[2], K_SURVIVAL);
137 Self { h }
138 }
139}
140
141impl RowHessian for SurvivalRowHessian {
142 fn k(&self) -> usize {
143 K_SURVIVAL
144 }
145 fn nrows(&self) -> usize {
146 self.h.shape()[0]
147 }
148 fn fill_row(&self, row: usize, out: &mut [f64]) {
149 assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
150 for a in 0..K_SURVIVAL {
151 for b in 0..K_SURVIVAL {
152 out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
153 }
154 }
155 }
156 fn evaluate_full(&self) -> Array3<f64> {
157 self.h.clone()
158 }
159}
160
161impl FamilyChannelHessian for SurvivalRowHessian {
197 fn n_outputs(&self) -> usize {
198 K_SURVIVAL
199 }
200
201 fn n_subjects(&self) -> usize {
202 self.h.shape()[0]
203 }
204
205 fn fill_subject(&self, i: usize, out: &mut [f64]) {
206 assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
207 for a in 0..K_SURVIVAL {
208 for b in 0..K_SURVIVAL {
209 out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
210 }
211 }
212 }
213
214 fn evaluate_full(&self) -> ndarray::Array3<f64> {
215 self.h.clone()
216 }
217
218 fn channel_hessian_at(
219 &self,
220 beta: &[f64],
221 family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
222 ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
223 use crate::survival::marginal_slope::SurvivalMarginalSlopeFamilyScalars;
224
225 let scalars_opt =
226 family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
227
228 let beta_nontrivial = beta
230 .iter()
231 .any(|&b| b.abs() > BETA_NONTRIVIAL_ABS_THRESHOLD);
232
233 match scalars_opt {
234 None if beta_nontrivial => {
235 Err(
238 "SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
239 family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
240 via FamilyLinearizationState::family_scalars to evaluate W(β) \
241 correctly (same contract as T26 Jacobian callbacks)."
242 .to_string(),
243 )
244 }
245 None => {
246 Ok(Arc::new(gam_problem::TensorChannelHessian {
248 h: self.h.clone(),
249 }))
250 }
251 Some(sc) => {
252 let n = self.h.shape()[0];
253 if sc.q0_i.len() != n
254 || sc.q1_i.len() != n
255 || sc.qd1_i.len() != n
256 || sc.g_i.len() != n
257 || sc.z_i.len() != n
258 {
259 return Err(format!(
260 "SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
261 (expected n={n}, got q0={} q1={} qd1={} g={} z={})",
262 sc.q0_i.len(),
263 sc.q1_i.len(),
264 sc.qd1_i.len(),
265 sc.g_i.len(),
266 sc.z_i.len(),
267 ));
268 }
269 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
286 for i in 0..n {
287 let q0 = sc.q0_i[i];
288 let q1 = sc.q1_i[i];
289 let qd1 = sc.qd1_i[i];
290 let g = sc.g_i[i];
291 let z = sc.z_i[i];
292 match crate::survival::marginal_slope::row_primary_for_compiler(
295 q0, q1, qd1, g, z, 1.0, 1.0, crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
298 sc.s, ) {
300 Ok((_nll, _grad, hess)) => {
301 let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
302 for a in 0..K_SURVIVAL {
303 for b in 0..K_SURVIVAL {
304 h_i[[a, b]] = hess[a][b];
305 }
306 }
307 let clamped = psd_clamp_4x4(&h_i);
308 for a in 0..K_SURVIVAL {
309 for b in 0..K_SURVIVAL {
310 h_full[[i, a, b]] = clamped[[a, b]];
311 }
312 }
313 }
314 Err(_) => {
315 for a in 0..K_SURVIVAL {
318 for b in 0..K_SURVIVAL {
319 h_full[[i, a, b]] = self.h[[i, a, b]];
320 }
321 }
322 }
323 }
324 }
325 Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
326 }
327 }
328 }
329}
330
331fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
336 let k = m.nrows();
337 let (evals, evecs) = match m.eigh(Side::Lower) {
338 Ok(pair) => pair,
339 Err(_) => {
340 let mut out = Array2::<f64>::zeros((k, k));
341 for i in 0..k {
342 out[[i, i]] = m[[i, i]].max(0.0);
343 }
344 return out;
345 }
346 };
347 let mut out = Array2::<f64>::zeros((k, k));
348 for i in 0..k {
349 for j in 0..k {
350 let mut acc = 0.0;
351 for l in 0..k {
352 acc += evecs[[i, l]] * evals[l].max(0.0) * evecs[[j, l]];
353 }
354 out[[i, j]] = acc;
355 }
356 }
357 out
358}
359
360pub struct TimeBlockOperator {
363 dq0: Array2<f64>,
364 dq1: Array2<f64>,
365 dqd1: Array2<f64>,
366}
367
368impl TimeBlockOperator {
369 pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
370 assert_eq!(dq0.dim(), dq1.dim());
371 assert_eq!(dq0.dim(), dqd1.dim());
372 Self { dq0, dq1, dqd1 }
373 }
374}
375
376impl RowJacobianOperator for TimeBlockOperator {
377 fn k(&self) -> usize {
378 K_SURVIVAL
379 }
380 fn ncols(&self) -> usize {
381 self.dq0.ncols()
382 }
383 fn nrows(&self) -> usize {
384 self.dq0.nrows()
385 }
386 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
387 assert_eq!(out.len(), K_SURVIVAL);
388 assert_eq!(delta_beta.len(), self.dq0.ncols());
389 let mut acc = [0.0_f64; K_SURVIVAL];
390 for (j, &b) in delta_beta.iter().enumerate() {
391 acc[0] += self.dq0[[row, j]] * b;
392 acc[1] += self.dq1[[row, j]] * b;
393 acc[2] += self.dqd1[[row, j]] * b;
394 }
395 out.copy_from_slice(&acc);
396 }
397 fn evaluate_full(&self) -> Array3<f64> {
398 let n = self.dq0.nrows();
399 let p = self.dq0.ncols();
400 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
401 for i in 0..n {
402 for j in 0..p {
403 out[[i, j, 0]] = self.dq0[[i, j]];
404 out[[i, j, 1]] = self.dq1[[i, j]];
405 out[[i, j, 2]] = self.dqd1[[i, j]];
406 }
407 }
408 out
409 }
410 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
411 let n = self.dq0.nrows();
417 let p = self.dq0.ncols();
418 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
419 0 => self.dq0[[i, a]],
420 1 => self.dq1[[i, a]],
421 2 => self.dqd1[[i, a]],
422 _ => 0.0,
423 })
424 }
425}
426
427pub struct QChannelBlockOperator {
432 dq: Array2<f64>,
433 dqd1: Array2<f64>,
434}
435
436impl QChannelBlockOperator {
437 pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
438 assert_eq!(dq.dim(), dqd1.dim());
439 Self { dq, dqd1 }
440 }
441}
442
443impl RowJacobianOperator for QChannelBlockOperator {
444 fn k(&self) -> usize {
445 K_SURVIVAL
446 }
447 fn ncols(&self) -> usize {
448 self.dq.ncols()
449 }
450 fn nrows(&self) -> usize {
451 self.dq.nrows()
452 }
453 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
454 assert_eq!(out.len(), K_SURVIVAL);
455 assert_eq!(delta_beta.len(), self.dq.ncols());
456 let mut dq_acc = 0.0;
457 let mut dqd_acc = 0.0;
458 for (j, &b) in delta_beta.iter().enumerate() {
459 dq_acc += self.dq[[row, j]] * b;
460 dqd_acc += self.dqd1[[row, j]] * b;
461 }
462 out[0] = dq_acc;
463 out[1] = dq_acc;
464 out[2] = dqd_acc;
465 out[3] = 0.0;
466 }
467 fn evaluate_full(&self) -> Array3<f64> {
468 let n = self.dq.nrows();
469 let p = self.dq.ncols();
470 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
471 for i in 0..n {
472 for j in 0..p {
473 let v = self.dq[[i, j]];
474 out[[i, j, 0]] = v;
475 out[[i, j, 1]] = v;
476 out[[i, j, 2]] = self.dqd1[[i, j]];
477 }
478 }
479 out
480 }
481 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
482 let n = self.dq.nrows();
486 let p = self.dq.ncols();
487 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
488 0 | 1 => self.dq[[i, a]],
489 2 => self.dqd1[[i, a]],
490 _ => 0.0,
491 })
492 }
493}
494
495pub struct LogslopeBlockOperator {
498 dg: Array2<f64>,
499}
500
501impl LogslopeBlockOperator {
502 pub fn new(dg: Array2<f64>) -> Self {
503 Self { dg }
504 }
505}
506
507impl RowJacobianOperator for LogslopeBlockOperator {
508 fn k(&self) -> usize {
509 K_SURVIVAL
510 }
511 fn ncols(&self) -> usize {
512 self.dg.ncols()
513 }
514 fn nrows(&self) -> usize {
515 self.dg.nrows()
516 }
517 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
518 assert_eq!(out.len(), K_SURVIVAL);
519 assert_eq!(delta_beta.len(), self.dg.ncols());
520 let mut acc = 0.0;
521 for (j, &b) in delta_beta.iter().enumerate() {
522 acc += self.dg[[row, j]] * b;
523 }
524 out[0] = 0.0;
525 out[1] = 0.0;
526 out[2] = 0.0;
527 out[3] = acc;
528 }
529 fn evaluate_full(&self) -> Array3<f64> {
530 let n = self.dg.nrows();
531 let p = self.dg.ncols();
532 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
533 for i in 0..n {
534 for j in 0..p {
535 out[[i, j, 3]] = self.dg[[i, j]];
536 }
537 }
538 out
539 }
540 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
541 let n = self.dg.nrows();
546 let p = self.dg.ncols();
547 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| {
548 if c == 3 { self.dg[[i, a]] } else { 0.0 }
549 })
550 }
551}
552
553pub struct SurvivalCompilerInputs {
557 pub operators: Vec<Arc<dyn RowJacobianOperator>>,
558 pub ordering: Vec<BlockOrder>,
559}
560
561pub struct SurvivalParametricCompiled {
577 pub v_time: Array2<f64>,
578 pub v_marginal: Array2<f64>,
579 pub v_logslope: Array2<f64>,
580 pub drops_by_block: (usize, usize, usize),
585}
586
587fn wrap_design_with_transform(
588 raw: DesignMatrix,
589 v: &Array2<f64>,
590 context: &str,
591) -> Result<DesignMatrix, String> {
592 if raw.ncols() != v.nrows() {
593 return Err(format!(
594 "{context}: raw design has {} cols but V has {} rows (V is {}×{})",
595 raw.ncols(),
596 v.nrows(),
597 v.nrows(),
598 v.ncols(),
599 ));
600 }
601 let inner_dense = match raw {
602 DesignMatrix::Dense(d) => d,
603 DesignMatrix::Sparse(_) => {
604 let dense = raw
605 .try_to_dense_by_chunks(&format!("{context} sparse→dense for V apply"))
606 .map_err(|reason| format!("{context}: densify failed: {reason}"))?;
607 DenseDesignMatrix::from(dense)
608 }
609 };
610 let op = CoefficientTransformOperator::new(inner_dense, v.clone())
611 .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))?;
612 Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(op))))
613}
614
615pub struct SurvivalParametricCompiledPerTerm {
623 pub v_time_per_term: Vec<Array2<f64>>,
624 pub v_marginal_per_term: Vec<Array2<f64>>,
625 pub v_logslope_per_term: Vec<Array2<f64>>,
626 pub r_lw_per_term: Vec<Option<Array2<f64>>>,
633 pub drops_by_block: (usize, usize, usize),
635}
636
637pub fn compile_survival_parametric_designs_per_term(
656 time_dq0: Array2<f64>,
657 time_dq1: Array2<f64>,
658 time_dqd1: Array2<f64>,
659 time_partition: &[std::ops::Range<usize>],
660 marginal_dq: Array2<f64>,
661 marginal_dqd1: Array2<f64>,
662 marginal_partition: &[std::ops::Range<usize>],
663 logslope_dg: Array2<f64>,
664 logslope_partition: &[std::ops::Range<usize>],
665 row_hess: &dyn RowHessian,
666 protect_time: bool,
667) -> Result<SurvivalParametricCompiledPerTerm, String> {
668 use gam_identifiability::families::compiler::compile_protected;
669
670 let p_time = time_dq0.ncols();
671 let p_marg = marginal_dq.ncols();
672 let p_log = logslope_dg.ncols();
673 validate_partition(time_partition, p_time, "time")?;
674 validate_partition(marginal_partition, p_marg, "marginal")?;
675 validate_partition(logslope_partition, p_log, "logslope")?;
676
677 let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::new();
681 let mut ordering: Vec<BlockOrder> = Vec::new();
682 for range in time_partition {
683 let dq0 = time_dq0.slice(ndarray::s![.., range.clone()]).to_owned();
684 let dq1 = time_dq1.slice(ndarray::s![.., range.clone()]).to_owned();
685 let dqd1 = time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned();
686 operators.push(Arc::new(TimeBlockOperator::new(dq0, dq1, dqd1)));
687 ordering.push(BlockOrder::Time);
688 }
689 for range in marginal_partition {
690 let dq = marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned();
691 let dqd1 = marginal_dqd1
692 .slice(ndarray::s![.., range.clone()])
693 .to_owned();
694 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
695 ordering.push(BlockOrder::Marginal);
696 }
697 for range in logslope_partition {
698 let dg = logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned();
699 operators.push(Arc::new(LogslopeBlockOperator::new(dg)));
700 ordering.push(BlockOrder::Logslope);
701 }
702
703 let n_time = time_partition.len();
712 let protected: Vec<bool> = if protect_time {
713 (0..operators.len()).map(|i| i < n_time).collect()
714 } else {
715 Vec::new()
716 };
717 let compiled =
718 compile_protected(&operators, row_hess, &ordering, &protected).map_err(|e| {
719 format!("identifiability::families::compiler::compile (per-term) failed: {e}")
720 })?;
721 let blocks = compiled.blocks;
722 let n_marg = marginal_partition.len();
723 let n_log = logslope_partition.len();
724 if blocks.len() != n_time + n_marg + n_log {
725 return Err(format!(
726 "per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
727 n_time + n_marg + n_log,
728 n_time,
729 n_marg,
730 n_log,
731 blocks.len(),
732 ));
733 }
734 let mut iter = blocks.into_iter();
735 let mut v_time_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_time);
736 let mut r_time_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time);
737 for _ in 0..n_time {
738 let blk = iter.next().unwrap();
739 v_time_per_term.push(blk.t_lw);
740 r_time_per_term.push(blk.r_lw);
741 }
742 let mut v_marginal_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_marg);
743 let mut r_marginal_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_marg);
744 for _ in 0..n_marg {
745 let blk = iter.next().unwrap();
746 v_marginal_per_term.push(blk.t_lw);
747 r_marginal_per_term.push(blk.r_lw);
748 }
749 let mut v_logslope_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_log);
750 let mut r_logslope_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_log);
751 for _ in 0..n_log {
752 let blk = iter.next().unwrap();
753 v_logslope_per_term.push(blk.t_lw);
754 r_logslope_per_term.push(blk.r_lw);
755 }
756 let mut r_lw_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time + n_marg + n_log);
757 r_lw_per_term.extend(r_time_per_term);
758 r_lw_per_term.extend(r_marginal_per_term);
759 r_lw_per_term.extend(r_logslope_per_term);
760 let drops_time: usize = time_partition
761 .iter()
762 .zip(v_time_per_term.iter())
763 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
764 .sum();
765 let drops_marg: usize = marginal_partition
766 .iter()
767 .zip(v_marginal_per_term.iter())
768 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
769 .sum();
770 let drops_log: usize = logslope_partition
771 .iter()
772 .zip(v_logslope_per_term.iter())
773 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
774 .sum();
775 Ok(SurvivalParametricCompiledPerTerm {
776 v_time_per_term,
777 v_marginal_per_term,
778 v_logslope_per_term,
779 r_lw_per_term,
780 drops_by_block: (drops_time, drops_marg, drops_log),
781 })
782}
783
/// Checks that `partition` tiles `[0, p_block)` with contiguous, non-empty ranges.
///
/// An empty partition is accepted only for a zero-width block.
///
/// # Errors
/// Returns a message prefixed with `label` describing the first violation found: a wrong
/// start, a wrong end, a gap or overlap, or an empty range.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let last = match partition.last() {
        Some(range) => range,
        None => {
            return if p_block == 0 {
                Ok(())
            } else {
                Err(format!(
                    "{label} partition empty but block has p={p_block} columns"
                ))
            };
        }
    };
    let first = &partition[0];
    if first.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            first.start
        ));
    }
    if last.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            last.end
        ));
    }
    partition.windows(2).try_for_each(|pair| {
        let (lhs, rhs) = (&pair[0], &pair[1]);
        if lhs.end != rhs.start {
            Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                lhs.start, lhs.end, rhs.start, rhs.end
            ))
        } else if lhs.is_empty() {
            Err(format!(
                "{label} partition has empty range [{}..{})",
                lhs.start, lhs.end
            ))
        } else {
            Ok(())
        }
    })?;
    if last.is_empty() {
        Err(format!("{label} partition's final range is empty"))
    } else {
        Ok(())
    }
}
828
/// Splits `[0, p_block)` at every penalty-range boundary.
///
/// Boundaries beyond `p_block` are clamped to it. The result is a contiguous, non-empty set
/// of ranges that tiles the block; columns outside every penalty range form their own
/// pieces. A zero-width block yields an empty partition.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    use std::collections::BTreeSet;

    let cuts: BTreeSet<usize> = penalty_ranges
        .iter()
        .flat_map(|range| [range.start, range.end])
        .map(|edge| edge.min(p_block))
        .chain([0, p_block])
        .collect();
    let ordered: Vec<usize> = cuts.into_iter().collect();
    ordered
        .windows(2)
        .filter(|pair| pair[0] < pair[1])
        .map(|pair| pair[0]..pair[1])
        .collect()
}
851
852pub fn pull_back_blockwise_penalty_through_block_v(
875 pen: &gam_terms::smooth::BlockwisePenalty,
876 v_block: &Array2<f64>,
877) -> Result<PenaltyMatrix, String> {
878 let raw_p = v_block.nrows();
879 let compiled_p = v_block.ncols();
880 let block_p = pen.col_range.len();
881 let embed_start = pen.col_range.start;
882 let embed_end = pen.col_range.end;
883 if embed_end > raw_p {
884 return Err(format!(
885 "pull_back_blockwise_penalty_through_block_v: penalty col_range {embed_start}..{embed_end} \
886 exceeds block raw width {raw_p}"
887 ));
888 }
889 if pen.local.nrows() != block_p || pen.local.ncols() != block_p {
890 return Err(format!(
891 "pull_back_blockwise_penalty_through_block_v: penalty local is {}x{} but col_range \
892 width is {block_p}",
893 pen.local.nrows(),
894 pen.local.ncols(),
895 ));
896 }
897 let mut embedded = Array2::<f64>::zeros((raw_p, raw_p));
898 if block_p > 0 {
899 let mut dst =
900 embedded.slice_mut(ndarray::s![embed_start..embed_end, embed_start..embed_end]);
901 for i in 0..block_p {
902 for j in 0..block_p {
903 dst[[i, j]] = pen.local[[i, j]];
904 }
905 }
906 }
907 let temp = embedded.dot(v_block);
909 let pulled = v_block.t().dot(&temp);
910 let mut sym = Array2::<f64>::zeros((compiled_p, compiled_p));
911 for i in 0..compiled_p {
912 for j in 0..compiled_p {
913 sym[[i, j]] = 0.5 * (pulled[[i, j]] + pulled[[j, i]]);
914 }
915 }
916 Ok(PenaltyMatrix::Dense(sym))
917}
918
919pub fn compiled_map_from_per_term(
941 compiled: &SurvivalParametricCompiledPerTerm,
942) -> gam_identifiability::families::compiler::CompiledMap {
943 let mut v_all: Vec<Array2<f64>> = Vec::new();
946 v_all.extend(compiled.v_time_per_term.iter().cloned());
947 v_all.extend(compiled.v_marginal_per_term.iter().cloned());
948 v_all.extend(compiled.v_logslope_per_term.iter().cloned());
949
950 let t_full = assemble_block_triangular_t(&v_all, &compiled.r_lw_per_term);
951
952 let raw_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.nrows()).sum() };
954 let kept_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.ncols()).sum() };
955 let raw_time = raw_w(&compiled.v_time_per_term);
956 let raw_marg = raw_w(&compiled.v_marginal_per_term);
957 let raw_log = raw_w(&compiled.v_logslope_per_term);
958 let kept_time = kept_w(&compiled.v_time_per_term);
959 let kept_marg = kept_w(&compiled.v_marginal_per_term);
960 let kept_log = kept_w(&compiled.v_logslope_per_term);
961
962 let raw_block_ranges = vec![
963 0..raw_time,
964 raw_time..(raw_time + raw_marg),
965 (raw_time + raw_marg)..(raw_time + raw_marg + raw_log),
966 ];
967 let compiled_block_ranges = vec![
968 0..kept_time,
969 kept_time..(kept_time + kept_marg),
970 (kept_time + kept_marg)..(kept_time + kept_marg + kept_log),
971 ];
972
973 gam_identifiability::families::compiler::CompiledMap {
974 raw_from_compiled: t_full,
975 compiled_block_ranges,
976 raw_block_ranges,
977 }
978}
979
980pub fn survival_reduced_logslope_transform_effective(
1035 marginal_dq: ndarray::ArrayView2<'_, f64>,
1036 logslope_dg: ndarray::ArrayView2<'_, f64>,
1037 row_hess: &SurvivalRowHessian,
1038) -> Result<Option<Array2<f64>>, String> {
1039 use crate::bms::block_specs::LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
1040 use gam_linalg::faer_ndarray::{
1041 FaerArrayView, factorize_symmetricwith_fallback, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
1042 };
1043
1044 let n = marginal_dq.nrows();
1045 let p_m = marginal_dq.ncols();
1046 let p_log = logslope_dg.ncols();
1047 if p_m == 0 || p_log == 0 {
1048 return Ok(None);
1049 }
1050 if logslope_dg.nrows() != n || row_hess.h.shape()[0] != n {
1051 return Err(format!(
1052 "survival reduced logslope: row mismatch marginal={n}, logslope={}, row_hess={}",
1053 logslope_dg.nrows(),
1054 row_hess.h.shape()[0],
1055 ));
1056 }
1057
1058 let mut w_mm = Array1::<f64>::zeros(n);
1061 let mut w_mg = Array1::<f64>::zeros(n);
1062 let mut w_gg = Array1::<f64>::zeros(n);
1063 for i in 0..n {
1064 w_mm[i] = row_hess.h[[i, 0, 0]] + row_hess.h[[i, 1, 1]];
1065 w_mg[i] = row_hess.h[[i, 0, 3]] + row_hess.h[[i, 1, 3]];
1066 w_gg[i] = row_hess.h[[i, 3, 3]];
1067 if !(w_mm[i].is_finite() && w_mg[i].is_finite() && w_gg[i].is_finite()) {
1068 return Err("survival reduced logslope: non-finite row Hessian weight".to_string());
1069 }
1070 }
1071
1072 let marg = marginal_dq.to_owned();
1073 let log = logslope_dg.to_owned();
1074
1075 let c_gram = fast_xt_diag_x(&log, &w_gg);
1078 let energy_scale = (0..p_log).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
1079 if !energy_scale.is_finite() || energy_scale <= 0.0 {
1080 return Ok(None);
1081 }
1082
1083 let mut a_gram = fast_xt_diag_x(&marg, &w_mm);
1087 let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
1088 let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
1089 for i in 0..p_m {
1090 a_gram[[i, i]] += a_ridge;
1091 }
1092
1093 let b_cross = fast_xt_diag_y(&marg, &w_mg, &log);
1095 let a_view = FaerArrayView::new(&a_gram);
1096 let a_factor = factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower).map_err(|e| {
1097 format!("survival reduced logslope: marginal effective Gram factorization failed: {e}")
1098 })?;
1099 let b_view = FaerArrayView::new(&b_cross);
1100 let solved = a_factor.solve(b_view.as_ref()); let a_inv_b = Array2::from_shape_fn((p_m, p_log), |(i, j)| solved[(i, j)]);
1102 let schur = fast_atb(&b_cross, &a_inv_b); let mut stt = &c_gram - &schur;
1104 stt = (&stt + &stt.t()) * 0.5;
1105 if stt.iter().any(|v| !v.is_finite()) {
1106 return Err(
1107 "survival reduced logslope: effective Schur Gram produced non-finite entries"
1108 .to_string(),
1109 );
1110 }
1111
1112 let (evals, evecs) = stt
1113 .eigh(Side::Lower)
1114 .map_err(|e| format!("survival reduced logslope: eigendecomposition failed: {e:?}"))?;
1115 let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
1119 let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
1120 kept.sort_by(|&a, &b| {
1121 evals[b]
1122 .partial_cmp(&evals[a])
1123 .unwrap_or(std::cmp::Ordering::Equal)
1124 });
1125 let r = kept.len();
1126 if r == p_log || r == 0 {
1129 return Ok(None);
1130 }
1131 let mut transform = Array2::<f64>::zeros((p_log, r));
1132 for (out_col, &src) in kept.iter().enumerate() {
1133 transform.column_mut(out_col).assign(&evecs.column(src));
1134 }
1135 if transform.iter().any(|v| !v.is_finite()) {
1136 return Err(
1137 "survival reduced logslope: reduced transform produced non-finite entries".to_string(),
1138 );
1139 }
1140 Ok(Some(transform))
1141}
1142
1143pub fn survival_block_diagonal_logslope_map(
1160 p_time: usize,
1161 p_marg: usize,
1162 t_log: &Array2<f64>,
1163) -> gam_identifiability::families::compiler::CompiledMap {
1164 let p_log = t_log.nrows();
1165 let r = t_log.ncols();
1166 let raw_total = p_time + p_marg + p_log;
1167 let compiled_total = p_time + p_marg + r;
1168 let mut t_full = Array2::<f64>::zeros((raw_total, compiled_total));
1169 for i in 0..p_time {
1170 t_full[[i, i]] = 1.0;
1171 }
1172 for i in 0..p_marg {
1173 t_full[[p_time + i, p_time + i]] = 1.0;
1174 }
1175 for ri in 0..p_log {
1176 for cj in 0..r {
1177 t_full[[p_time + p_marg + ri, p_time + p_marg + cj]] = t_log[[ri, cj]];
1178 }
1179 }
1180 gam_identifiability::families::compiler::CompiledMap {
1181 raw_from_compiled: t_full,
1182 compiled_block_ranges: vec![
1183 0..p_time,
1184 p_time..(p_time + p_marg),
1185 (p_time + p_marg)..compiled_total,
1186 ],
1187 raw_block_ranges: vec![
1188 0..p_time,
1189 p_time..(p_time + p_marg),
1190 (p_time + p_marg)..raw_total,
1191 ],
1192 }
1193}
1194
1195pub fn apply_compiled_map_to_designs(
1223 map: &gam_identifiability::families::compiler::CompiledMap,
1224 time_design_entry: DesignMatrix,
1225 time_design_exit: DesignMatrix,
1226 time_design_derivative_exit: DesignMatrix,
1227 marginal_design: DesignMatrix,
1228 logslope_design: DesignMatrix,
1229 time_penalties: &[gam_terms::smooth::BlockwisePenalty],
1230 marginal_penalties: &[gam_terms::smooth::BlockwisePenalty],
1231 logslope_penalties: &[gam_terms::smooth::BlockwisePenalty],
1232) -> Result<CompiledSurvivalDesignsVMExact, String> {
1233 if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
1234 return Err(format!(
1235 "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
1236 got {} raw / {} compiled",
1237 map.raw_block_ranges.len(),
1238 map.compiled_block_ranges.len(),
1239 ));
1240 }
1241 let time_raw = map.raw_block_ranges[0].clone();
1242 let marg_raw = map.raw_block_ranges[1].clone();
1243 let log_raw = map.raw_block_ranges[2].clone();
1244 let time_compiled = map.compiled_block_ranges[0].clone();
1245 let marg_compiled = map.compiled_block_ranges[1].clone();
1246 let log_compiled = map.compiled_block_ranges[2].clone();
1247
1248 let t = &map.raw_from_compiled;
1249 let raw_total = t.nrows();
1250 let compiled_total = t.ncols();
1251 let expected_raw_total = log_raw.end;
1252 if raw_total != expected_raw_total {
1253 return Err(format!(
1254 "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
1255 {expected_raw_total}"
1256 ));
1257 }
1258 let expected_compiled_total = log_compiled.end;
1259 if compiled_total != expected_compiled_total {
1260 return Err(format!(
1261 "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
1262 sum to {expected_compiled_total}"
1263 ));
1264 }
1265
1266 let v_time = t
1267 .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
1268 .to_owned();
1269 let v_marg = t
1270 .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
1271 .to_owned();
1272 let v_log = t
1273 .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
1274 .to_owned();
1275
1276 let time_entry_out =
1277 wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
1278 let time_exit_out =
1279 wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
1280 let time_deriv_out = wrap_design_with_transform(
1281 time_design_derivative_exit,
1282 &v_time,
1283 "compiled-map: time derivative_exit",
1284 )?;
1285 let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
1286 let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
1287
1288 let pull_set = |pens: &[gam_terms::smooth::BlockwisePenalty],
1309 v_block: &Array2<f64>,
1310 channel: &str|
1311 -> Result<Vec<PenaltyMatrix>, String> {
1312 pens.iter()
1313 .map(|p| {
1314 pull_back_blockwise_penalty_through_block_v(p, v_block).map_err(|e| {
1315 format!("apply_compiled_map_to_designs: {channel} penalty pullback: {e}")
1316 })
1317 })
1318 .collect()
1319 };
1320
1321 let time_penalties = pull_set(time_penalties, &v_time, "time")?;
1322 let marginal_penalties = pull_set(marginal_penalties, &v_marg, "marginal")?;
1323 let logslope_penalties = pull_set(logslope_penalties, &v_log, "logslope")?;
1324 validate_block_penalty_shapes("time", time_exit_out.ncols(), &time_penalties)?;
1325 validate_block_penalty_shapes("marginal", marg_out.ncols(), &marginal_penalties)?;
1326 validate_block_penalty_shapes("logslope", log_out.ncols(), &logslope_penalties)?;
1327
1328 Ok(CompiledSurvivalDesignsVMExact {
1329 time_design_entry: time_entry_out,
1330 time_design_exit: time_exit_out,
1331 time_design_derivative_exit: time_deriv_out,
1332 marginal_design: marg_out,
1333 logslope_design: log_out,
1334 time_penalties,
1335 marginal_penalties,
1336 logslope_penalties,
1337 })
1338}
1339
1340fn validate_block_penalty_shapes(
1341 block: &str,
1342 width: usize,
1343 penalties: &[PenaltyMatrix],
1344) -> Result<(), String> {
1345 for (idx, penalty) in penalties.iter().enumerate() {
1346 let shape = penalty.shape();
1347 if shape != (width, width) {
1348 return Err(format!(
1349 "apply_compiled_map_to_designs: {block} penalty {idx} must be {width}x{width}, got {}x{}",
1350 shape.0, shape.1
1351 ));
1352 }
1353 }
1354 Ok(())
1355}
1356
1357pub fn compile_survival_parametric_designs(
1385 time_dq0: Array2<f64>,
1386 time_dq1: Array2<f64>,
1387 time_dqd1: Array2<f64>,
1388 marginal_dq: Array2<f64>,
1389 marginal_dqd1: Array2<f64>,
1390 logslope_dg: Array2<f64>,
1391 row_hess: &dyn RowHessian,
1392) -> Result<SurvivalParametricCompiled, String> {
1393 use gam_identifiability::families::compiler::compile;
1394
1395 let p_time_raw = time_dq0.ncols();
1396 let p_marg_raw = marginal_dq.ncols();
1397 let p_log_raw = logslope_dg.ncols();
1398
1399 let inputs = build_survival_compiler_inputs(
1400 time_dq0,
1401 time_dq1,
1402 time_dqd1,
1403 marginal_dq,
1404 marginal_dqd1,
1405 logslope_dg,
1406 None,
1407 None,
1408 );
1409 if inputs.operators.len() != 3 {
1410 return Err(format!(
1411 "compile_survival_parametric_designs: expected exactly 3 parametric operators \
1412 (time, marginal, logslope); got {}",
1413 inputs.operators.len(),
1414 ));
1415 }
1416 let compiled = compile(&inputs.operators, row_hess, &inputs.ordering)
1417 .map_err(|e| format!("identifiability::families::compiler::compile failed: {e}"))?;
1418 if compiled.blocks.len() != 3 {
1419 return Err(format!(
1420 "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
1421 compiled.blocks.len(),
1422 ));
1423 }
1424 let v_time = compiled.blocks[0].t_lw.clone();
1425 let v_marginal = compiled.blocks[1].t_lw.clone();
1426 let v_logslope = compiled.blocks[2].t_lw.clone();
1427 let drops_by_block = (
1428 p_time_raw.saturating_sub(v_time.ncols()),
1429 p_marg_raw.saturating_sub(v_marginal.ncols()),
1430 p_log_raw.saturating_sub(v_logslope.ncols()),
1431 );
1432 Ok(SurvivalParametricCompiled {
1433 v_time,
1434 v_marginal,
1435 v_logslope,
1436 drops_by_block,
1437 })
1438}
1439
1440pub fn build_survival_compiler_inputs(
1452 time_dq0: Array2<f64>,
1453 time_dq1: Array2<f64>,
1454 time_dqd1: Array2<f64>,
1455 marginal_dq: Array2<f64>,
1456 marginal_dqd1: Array2<f64>,
1457 logslope_dg: Array2<f64>,
1458 score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1459 link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1460) -> SurvivalCompilerInputs {
1461 let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(5);
1462 let mut ordering: Vec<BlockOrder> = Vec::with_capacity(5);
1463
1464 operators.push(Arc::new(TimeBlockOperator::new(
1465 time_dq0, time_dq1, time_dqd1,
1466 )));
1467 ordering.push(BlockOrder::Time);
1468
1469 operators.push(Arc::new(QChannelBlockOperator::new(
1470 marginal_dq,
1471 marginal_dqd1,
1472 )));
1473 ordering.push(BlockOrder::Marginal);
1474
1475 operators.push(Arc::new(LogslopeBlockOperator::new(logslope_dg)));
1476 ordering.push(BlockOrder::Logslope);
1477
1478 if let Some((dq, dqd1)) = score_warp_dq_dqd1 {
1479 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1480 ordering.push(BlockOrder::ScoreWarp);
1481 }
1482 if let Some((dq, dqd1)) = link_dev_dq_dqd1 {
1483 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1484 ordering.push(BlockOrder::LinkDev);
1485 }
1486
1487 SurvivalCompilerInputs {
1488 operators,
1489 ordering,
1490 }
1491}
1492
/// Survival designs and penalties re-expressed in compiled coordinates, as
/// returned by `apply_compiled_map_to_designs`.
///
/// Every design and penalty here is sized to the compiled width of its *own*
/// block, never to the joint compiled width. For example, a marginal block
/// compiled from 3 raw to 2 columns yields a 2-column design and 2×2
/// penalties.
pub struct CompiledSurvivalDesignsVMExact {
    /// Time-block design for the entry time, in compiled time coordinates.
    pub time_design_entry: DesignMatrix,
    /// Time-block design for the exit time, in compiled time coordinates.
    pub time_design_exit: DesignMatrix,
    /// Time-block derivative design at exit, in compiled time coordinates.
    pub time_design_derivative_exit: DesignMatrix,
    /// Marginal-block design in compiled marginal coordinates.
    pub marginal_design: DesignMatrix,
    /// Logslope-block design in compiled logslope coordinates.
    pub logslope_design: DesignMatrix,
    /// Time-block penalties pulled back to the compiled time width.
    pub time_penalties: Vec<PenaltyMatrix>,
    /// Marginal-block penalties pulled back to the compiled marginal width.
    pub marginal_penalties: Vec<PenaltyMatrix>,
    /// Logslope-block penalties pulled back to the compiled logslope width.
    pub logslope_penalties: Vec<PenaltyMatrix>,
}
1528
1529#[cfg(test)]
1530mod tests {
1531 use super::*;
1532 use gam_problem::Gauge;
1533
1534 #[test]
1535 fn psd_clamp_zeros_negative_eigenvalues() {
1536 let mut m = Array2::<f64>::zeros((4, 4));
1540 m[[0, 0]] = 2.0;
1543 m[[1, 1]] = -1.0;
1544 m[[2, 2]] = 0.5;
1545 m[[3, 3]] = -0.25;
1546 let clamped = psd_clamp_4x4(&m);
1547 assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
1548 assert!(clamped[[1, 1]].abs() < 1e-12);
1549 assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
1550 assert!(clamped[[3, 3]].abs() < 1e-12);
1551 }
1552
1553 #[test]
1554 fn time_block_operator_evaluate_full_shape() {
1555 let n = 6;
1556 let p = 3;
1557 let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
1558 let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
1559 let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
1560 let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
1561 let full = op.evaluate_full();
1562 assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
1563 for i in 0..n {
1564 for j in 0..p {
1565 assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
1566 assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
1567 assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
1568 assert_eq!(full[[i, j, 3]], 0.0);
1569 }
1570 }
1571 }
1572
1573 #[test]
1574 fn q_channel_block_apply_row_shares_q0_q1() {
1575 let n = 5;
1576 let p = 2;
1577 let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
1578 let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
1579 let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
1580 let mut out = [0.0_f64; K_SURVIVAL];
1581 let delta = [1.0_f64, -0.5];
1582 op.apply_row(3, &delta, &mut out);
1583 let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
1584 let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
1585 assert!((out[0] - want_q).abs() < 1e-12);
1586 assert!((out[1] - want_q).abs() < 1e-12);
1587 assert!((out[2] - want_qd).abs() < 1e-12);
1588 assert_eq!(out[3], 0.0);
1589 }
1590
1591 #[test]
1592 fn logslope_block_writes_only_g_channel() {
1593 let n = 4;
1594 let p = 2;
1595 let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
1596 let op = LogslopeBlockOperator::new(dg.clone());
1597 let mut out = [0.0_f64; K_SURVIVAL];
1598 let delta = [2.0_f64, -1.0];
1599 op.apply_row(1, &delta, &mut out);
1600 assert_eq!(out[0], 0.0);
1601 assert_eq!(out[1], 0.0);
1602 assert_eq!(out[2], 0.0);
1603 let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
1604 assert!((out[3] - want).abs() < 1e-12);
1605 }
1606
1607 #[test]
1608 fn extract_term_partition_simple_cases() {
1609 let full = 0..5usize;
1610 let part = extract_term_partition_from_penalty_ranges(5, &[]);
1612 assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1613 let part = extract_term_partition_from_penalty_ranges(5, std::slice::from_ref(&full));
1615 assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1616 let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
1618 assert_eq!(part, vec![0..3, 3..6, 6..10]);
1619 let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
1621 assert_eq!(part, vec![0..3, 3..6]);
1622 let part = extract_term_partition_from_penalty_ranges(0, &[]);
1624 assert!(part.is_empty());
1625 }
1626
1627 #[test]
1628 fn assemble_block_triangular_t_identity_when_v_eye_and_r_none() {
1629 let v_a = Array2::<f64>::eye(2);
1630 let v_b = Array2::<f64>::eye(2);
1631 let t = assemble_block_triangular_t(&[v_a, v_b], &[None, None]);
1632 assert_eq!(t.dim(), (4, 4));
1633 let eye4 = Array2::<f64>::eye(4);
1634 for i in 0..4 {
1635 for j in 0..4 {
1636 assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
1637 }
1638 }
1639 }
1640
1641 #[test]
1642 fn assemble_block_triangular_t_with_drops_and_nonzero_r() {
1643 let mut v_a = Array2::<f64>::zeros((3, 2));
1644 v_a[[0, 0]] = 1.0;
1645 v_a[[1, 0]] = 0.5;
1646 v_a[[2, 1]] = 1.0;
1647 let v_b = Array2::<f64>::eye(2);
1648 let r_ab =
1649 Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
1650 let t =
1651 assemble_block_triangular_t(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
1652 assert_eq!(t.dim(), (5, 4));
1653 for i in 0..3 {
1654 for j in 0..2 {
1655 assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
1656 }
1657 }
1658 for i in 0..2 {
1659 for j in 0..2 {
1660 assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
1661 }
1662 }
1663 for i in 0..3 {
1664 for j in 0..2 {
1665 assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
1666 }
1667 }
1668 for i in 0..2 {
1669 for j in 0..2 {
1670 assert_eq!(t[[3 + i, j]], 0.0);
1671 }
1672 }
1673 }
1674
1675 #[test]
1676 fn validate_partition_rejects_bad_partitions() {
1677 let bad_start = 1..5usize;
1678 let short_cover = 0..3usize;
1679 let full_cover = 0..5usize;
1680 assert!(validate_partition(std::slice::from_ref(&bad_start), 5, "test").is_err());
1682 assert!(validate_partition(std::slice::from_ref(&short_cover), 5, "test").is_err());
1684 assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
1686 assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
1688 assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
1690 assert!(validate_partition(&[], 0, "test").is_ok());
1692 assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
1694 assert!(validate_partition(std::slice::from_ref(&full_cover), 5, "test").is_ok());
1695 }
1696
    /// Builds a 3-block `CompiledMap` whose T is block-triangular with
    /// non-zero residual blocks. It then checks that `apply_compiled_map_to_designs`
    /// sizes every design and penalty at its own block's compiled width.
    ///
    /// It also checks that penalties are pulled back through the diagonal V
    /// blocks only (time: θᵀPθ = (Vθ)ᵀS(Vθ); marginal: P = VᵀSV). That is,
    /// the non-zero R blocks do not leak into the penalties.
    #[test]
    fn compiled_map_penalty_pullback_is_per_block_width_with_nonzero_residual() {
        use gam_identifiability::families::compiler::CompiledMap;
        use gam_terms::smooth::BlockwisePenalty;

        let n = 10;
        // Per-block V: time 3→3, marginal 3→2, logslope 2→2.
        let v_time =
            Array2::<f64>::from_shape_fn(
                (3, 3),
                |(i, j)| {
                    if i == j { 1.0 } else { 0.1 * ((i + j) as f64) }
                },
            );
        let v_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
            0.5 + 0.3 * (i as f64) - 0.2 * (j as f64)
        });
        let v_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 1.2 } else { 0.4 });
        // Residual blocks: rows = raw width of all preceding blocks
        // (marginal: 3 time rows; logslope: 3 time + 3 marginal = 6 rows),
        // cols = this block's compiled width.
        let r_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 0.7 - 0.1 * ((i + j) as f64));
        let r_log =
            Array2::<f64>::from_shape_fn((6, 2), |(i, j)| 0.3 + 0.05 * ((i * 2 + j) as f64));

        let t = assemble_block_triangular_t(
            &[v_time.clone(), v_marg.clone(), v_log.clone()],
            &[None, Some(r_marg.clone()), Some(r_log.clone())],
        );
        assert_eq!(t.dim(), (8, 7), "joint raw 8 × joint compiled 7");

        let map = CompiledMap {
            raw_from_compiled: t.clone(),
            compiled_block_ranges: vec![0..3, 3..5, 5..7],
            raw_block_ranges: vec![0..3, 3..6, 6..8],
        };

        // Raw designs; the three time designs share one matrix.
        let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
            Array2::<f64>::from_shape_fn((n, 3), |(i, j)| 1.0 + (i as f64) * 0.1 + (j as f64)),
        ));
        let raw_time_exit = raw_time_entry.clone();
        let raw_time_deriv = raw_time_entry.clone();
        let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
            (n, 3),
            |(i, j)| 0.2 * (i as f64) - 0.3 * (j as f64),
        )));
        let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
            (n, 2),
            |(i, j)| 0.5 + (i as f64) * (j as f64 + 1.0),
        )));

        // One raw-width penalty per block.
        let s_time =
            Array2::<f64>::from_shape_fn(
                (3, 3),
                |(i, j)| if i == j { (i + 2) as f64 } else { 0.3 },
            );
        let s_marg =
            Array2::<f64>::from_shape_fn(
                (3, 3),
                |(i, j)| if i == j { 1.5 + i as f64 } else { 0.2 },
            );
        let s_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 2.0 } else { 0.5 });
        let time_pens = vec![BlockwisePenalty::new(0..3, s_time.clone())];
        let marg_pens = vec![BlockwisePenalty::new(0..3, s_marg.clone())];
        let log_pens = vec![BlockwisePenalty::new(0..2, s_log.clone())];

        let out = apply_compiled_map_to_designs(
            &map,
            raw_time_entry,
            raw_time_exit,
            raw_time_deriv,
            raw_marg,
            raw_log,
            &time_pens,
            &marg_pens,
            &log_pens,
        )
        .expect("apply_compiled_map_to_designs must succeed");

        // Designs carry the per-block compiled widths (3, 2, 2).
        assert_eq!(out.time_design_entry.ncols(), 3);
        assert_eq!(out.marginal_design.ncols(), 2);
        assert_eq!(out.logslope_design.ncols(), 2);

        // Penalties are per-block square, never joint-width (7).
        for s in &out.time_penalties {
            assert_eq!(
                s.as_dense_cow().dim(),
                (3, 3),
                "time penalty must be per-block 3×3, not joint-width"
            );
        }
        for s in &out.marginal_penalties {
            assert_eq!(
                s.as_dense_cow().dim(),
                (2, 2),
                "marginal penalty must match reduced compiled width 2, not joint 7"
            );
        }
        for s in &out.logslope_penalties {
            assert_eq!(s.as_dense_cow().dim(), (2, 2));
        }

        // Time: the quadratic form in compiled θ must equal the raw-space form
        // at γ = V_time θ.
        let p_time_dense = out.time_penalties[0].as_dense_cow().into_owned();
        let theta_time = Array1::<f64>::from_shape_fn(3, |k| 0.4 + 0.7 * (k as f64));
        let gamma_time = v_time.dot(&theta_time);
        let lhs = theta_time.dot(&p_time_dense.dot(&theta_time));
        let rhs = gamma_time.dot(&s_time.dot(&gamma_time));
        assert!(
            (lhs - rhs).abs() < 1e-10,
            "time-block per-block pullback must be exact: lhs={lhs}, rhs={rhs}"
        );

        // Marginal: entry-wise comparison against V_margᵀ S_marg V_marg,
        // i.e. r_marg plays no part in the pulled-back penalty.
        let p_marg_dense = out.marginal_penalties[0].as_dense_cow().into_owned();
        let want_marg = v_marg.t().dot(&s_marg.dot(&v_marg));
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (p_marg_dense[[i, j]] - want_marg[[i, j]]).abs() < 1e-12,
                    "marginal penalty must be V_margᵀ S_marg V_marg at ({i},{j})"
                );
            }
        }
    }
1843
1844 #[test]
1851 fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
1852 let n = 24;
1853 let p_time = 3;
1854 let p_marginal = 3;
1855 let p_logslope = 2;
1856 let x: Vec<f64> = (0..n)
1857 .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1858 .collect();
1859 let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1860 let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1861 let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1862 let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1863 let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1864 let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1865 for i in 0..n {
1866 time_dq0[[i, 0]] = 1.0;
1867 time_dq0[[i, 1]] = x[i];
1868 time_dq0[[i, 2]] = x[i] * x[i];
1869 time_dq1[[i, 0]] = 1.0;
1870 time_dq1[[i, 1]] = x[i];
1871 time_dq1[[i, 2]] = x[i] * x[i];
1872 time_dqd1[[i, 0]] = 0.0;
1873 time_dqd1[[i, 1]] = 1.0;
1874 time_dqd1[[i, 2]] = 2.0 * x[i];
1875 marg_dq[[i, 0]] = 1.0; marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1877 marg_dq[[i, 2]] = x[i].sin();
1878 log_dg[[i, 0]] = (2.0 * x[i]).cos();
1879 log_dg[[i, 1]] = x[i].tanh();
1880 }
1881 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1882 for i in 0..n {
1883 for k in 0..K_SURVIVAL {
1884 h_full[[i, k, k]] = 1.0;
1885 }
1886 }
1887 let row_hess = SurvivalRowHessian::from_full(h_full);
1888 let out = compile_survival_parametric_designs(
1889 time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
1890 )
1891 .expect("Phase-4b parametric compile must succeed on single-direction alias");
1892 assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
1893 assert_eq!(
1894 out.v_marginal.ncols(),
1895 p_marginal - 1,
1896 "marginal loses exactly the shared-constant direction"
1897 );
1898 assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
1899 assert_eq!(
1900 out.drops_by_block,
1901 (0, 1, 0),
1902 "attribution: zero from time/logslope, one from marginal",
1903 );
1904 }
1905
    /// End-to-end compile of the three parametric survival blocks where the
    /// time and marginal blocks both carry a constant column on (q0, q1).
    ///
    /// With an identity row Hessian and the time block first in the ordering,
    /// exactly one direction (the shared constant) must be dropped, and the
    /// drop must be charged to the marginal block. The test then checks that a
    /// `Gauge` built from the compiled V blocks reproduces the compiled
    /// block boundaries, and that every kept direction is finite and
    /// non-degenerate.
    #[test]
    fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
        use gam_identifiability::families::compiler::compile;

        let n = 32;
        let p_time = 3;
        let p_marginal = 3;
        let p_logslope = 2;

        // Evenly spaced grid on [-1, 1].
        let x: Vec<f64> = (0..n)
            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
            .collect();
        // Time block: [1, x, x²] on q0 and q1, derivative [0, 1, 2x] on qd1.
        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
        for i in 0..n {
            time_dq0[[i, 0]] = 1.0;
            time_dq0[[i, 1]] = x[i];
            time_dq0[[i, 2]] = x[i] * x[i];
            time_dq1[[i, 0]] = 1.0;
            time_dq1[[i, 1]] = x[i];
            time_dq1[[i, 2]] = x[i] * x[i];
            time_dqd1[[i, 0]] = 0.0;
            time_dqd1[[i, 1]] = 1.0;
            time_dqd1[[i, 2]] = 2.0 * x[i];
        }

        // Marginal block: [1, x³, sin x]. Its q Jacobian feeds both q0 and q1,
        // so its constant column duplicates the time block's constant.
        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
        for i in 0..n {
            marg_dq[[i, 0]] = 1.0;
            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
            marg_dq[[i, 2]] = x[i].sin();
        }

        // Logslope block: [cos 2x, tanh x] on the g channel (no shared direction).
        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
        for i in 0..n {
            log_dg[[i, 0]] = (2.0 * x[i]).cos();
            log_dg[[i, 1]] = x[i].tanh();
        }

        let inputs = build_survival_compiler_inputs(
            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
        );

        // Identity per-row Hessian: all four channels weighted equally.
        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
        for i in 0..n {
            for k in 0..K_SURVIVAL {
                h_full[[i, k, k]] = 1.0;
            }
        }
        let row_hess = SurvivalRowHessian::from_full(h_full);

        let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
            .expect("survival 3-block compile must succeed; aliasing is single-direction");

        assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");

        // Time comes first in the ordering, so it keeps the shared constant.
        let v_time = &compiled.blocks[0].t_lw;
        assert_eq!(
            v_time.ncols(),
            p_time,
            "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
            v_time.dim(),
        );

        let v_marg = &compiled.blocks[1].t_lw;
        assert_eq!(
            v_marg.ncols(),
            p_marginal - 1,
            "marginal block must lose exactly the shared-constant direction; \
             V_marginal cols = {}, expected {}",
            v_marg.ncols(),
            p_marginal - 1,
        );

        let v_log = &compiled.blocks[2].t_lw;
        assert_eq!(
            v_log.ncols(),
            p_logslope,
            "logslope block (no shared direction) must retain all {p_logslope} columns",
        );

        // Exactly one direction is aliased across the whole model.
        let raw_total = p_time + p_marginal + p_logslope;
        let kept_total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
        assert_eq!(
            kept_total,
            raw_total - 1,
            "joint kept = raw_total − aliased; got {kept_total}, expected {}",
            raw_total - 1,
        );
        assert_eq!(
            compiled.joint_rank, kept_total,
            "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
        );

        // A Gauge built from the compiled V blocks (no residual coupling) must
        // agree with the compiled block layout.
        let v_per_term: Vec<Array2<f64>> = compiled.blocks.iter().map(|b| b.t_lw.clone()).collect();
        let r_per_term: Vec<Option<Array2<f64>>> = vec![None; v_per_term.len()];
        let gauge = Gauge::from_v_and_r(&v_per_term, &r_per_term);

        // Cumulative block starts: t_lw is (raw width × kept width) per block.
        let mut expected_reduced = vec![0usize];
        let mut expected_raw = vec![0usize];
        for b in &compiled.blocks {
            let prev_reduced = *expected_reduced.last().unwrap();
            expected_reduced.push(prev_reduced + b.t_lw.ncols());
            let prev_raw = *expected_raw.last().unwrap();
            expected_raw.push(prev_raw + b.t_lw.nrows());
        }
        assert_eq!(
            *gauge.block_starts_reduced.last().unwrap(),
            compiled.joint_rank,
            "SMGS lift reduced dimension must equal the compiled joint_rank",
        );
        assert_eq!(
            gauge.block_starts_reduced, expected_reduced,
            "SMGS lift reduced block boundaries must match the compiled kept widths",
        );
        assert_eq!(
            gauge.block_starts_raw, expected_raw,
            "SMGS lift raw block boundaries must match the compiled per-block raw widths",
        );

        // Every kept direction must be finite and have a non-negligible norm.
        for (bi, block) in compiled.blocks.iter().enumerate() {
            for j in 0..block.t_lw.ncols() {
                let col = block.t_lw.column(j);
                assert!(
                    col.iter().all(|v| v.is_finite()),
                    "block {bi} kept direction {j} has a non-finite entry",
                );
                let norm = col.dot(&col).sqrt();
                assert!(
                    norm > 1e-10,
                    "block {bi} kept direction {j} is degenerate (norm {norm:.3e})",
                );
            }
        }
    }
2115
2116 #[test]
2119 fn smgs_lift_via_t_identity_passes_through() {
2120 let v0 = Array2::<f64>::eye(3);
2121 let v1 = Array2::<f64>::eye(2);
2122 let v_per_term = vec![v0, v1];
2123 let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
2124 let lift = Gauge::from_v_and_r(&v_per_term, &r_per_term);
2125 assert_eq!(lift.t_full.dim(), (5, 5));
2126 assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
2127 assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
2128 for i in 0..5 {
2129 for j in 0..5 {
2130 let want = if i == j { 1.0 } else { 0.0 };
2131 assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
2132 }
2133 }
2134 let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
2135 let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
2136 let lifted = lift.lift_block_betas(&[theta_0.clone(), theta_1.clone()]);
2137 assert_eq!(lifted.len(), 2);
2138 for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
2139 assert!((a - b).abs() < 1e-14);
2140 }
2141 for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
2142 assert!((a - b).abs() < 1e-14);
2143 }
2144 }
2145
2146 #[test]
2150 fn smgs_lift_via_t_two_block_with_residualisation() {
2151 let v_a = Array2::<f64>::eye(3);
2152 let mut v_b = Array2::<f64>::zeros((3, 2));
2153 v_b[[0, 0]] = 1.0;
2154 v_b[[2, 1]] = 1.0;
2155 let mut r_b = Array2::<f64>::zeros((3, 2));
2156 r_b[[0, 0]] = 0.4;
2157 r_b[[0, 1]] = -0.1;
2158 r_b[[1, 0]] = 0.7;
2159 r_b[[1, 1]] = 1.3;
2160 r_b[[2, 0]] = -0.2;
2161 r_b[[2, 1]] = 0.5;
2162 let lift = Gauge::from_v_and_r(&[v_a.clone(), v_b.clone()], &[None, Some(r_b.clone())]);
2163 assert_eq!(lift.t_full.dim(), (6, 5));
2164 assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
2165 assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);
2166
2167 let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2168 let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2169 let lifted = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2170 let r_theta_b = r_b.dot(&theta_b);
2171 let expected_a = &theta_a - &r_theta_b;
2172 assert_eq!(lifted[0].len(), 3);
2173 for (got, want) in lifted[0].iter().zip(expected_a.iter()) {
2174 assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
2175 }
2176 assert_eq!(lifted[1].len(), 3);
2177 assert!((lifted[1][0] - theta_b[0]).abs() < 1e-12);
2178 assert!(lifted[1][1].abs() < 1e-12);
2179 assert!((lifted[1][2] - theta_b[1]).abs() < 1e-12);
2180 }
2181
    /// Covariance lifting through the gauge T.
    ///
    /// 1. With T = I the lift must be a no-op.
    /// 2. With a residualised two-block T, lifting the rank-1 covariance θθᵀ
    ///    must equal (Tθ)(Tθ)ᵀ, where Tθ is computed independently via
    ///    `lift_block_betas`. The result must also be symmetric.
    #[test]
    fn smgs_lift_covariance_identity_and_rank1_consistency() {
        let lift_id = Gauge::from_v_and_r(
            &[Array2::<f64>::eye(2), Array2::<f64>::eye(2)],
            &[None, None],
        );
        // Symmetric test covariance with decaying off-diagonal entries.
        let mut cov = Array2::<f64>::zeros((4, 4));
        for i in 0..4 {
            for j in 0..4 {
                cov[[i, j]] = 1.0 / (1.0 + (i as f64 - j as f64).abs());
            }
        }
        let lifted_id = lift_id.lift_covariance(&cov);
        assert_eq!(lifted_id.dim(), (4, 4));
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (lifted_id[[i, j]] - cov[[i, j]]).abs() < 1e-12,
                    "identity-T covariance lift must be a no-op at [{i},{j}]",
                );
            }
        }

        // Two-block gauge: block B keeps raw columns 0 and 2, with residual
        // coupling R_b against block A's 3 raw rows.
        let v_a = Array2::<f64>::eye(3);
        let mut v_b = Array2::<f64>::zeros((3, 2));
        v_b[[0, 0]] = 1.0;
        v_b[[2, 1]] = 1.0;
        let mut r_b = Array2::<f64>::zeros((3, 2));
        r_b[[0, 0]] = 0.4;
        r_b[[0, 1]] = -0.1;
        r_b[[1, 0]] = 0.7;
        r_b[[1, 1]] = 1.3;
        r_b[[2, 0]] = -0.2;
        r_b[[2, 1]] = 0.5;
        let lift = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_b)]);

        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
        let theta_full = Array1::from(vec![
            theta_a[0], theta_a[1], theta_a[2], theta_b[0], theta_b[1],
        ]);
        // Rank-1 covariance θθᵀ in compiled (5-dim) coordinates.
        let mut cov_rank1 = Array2::<f64>::zeros((5, 5));
        for i in 0..5 {
            for j in 0..5 {
                cov_rank1[[i, j]] = theta_full[i] * theta_full[j];
            }
        }
        let lifted_cov = lift.lift_covariance(&cov_rank1);
        // Independent reference: Tθ assembled from the per-block beta lift.
        let lifted_blocks = lift.lift_block_betas(&[theta_a, theta_b]);
        let beta_raw = Array1::from(
            lifted_blocks
                .iter()
                .flat_map(|b| b.iter().copied())
                .collect::<Vec<f64>>(),
        );
        assert_eq!(lifted_cov.dim(), (6, 6));
        assert_eq!(beta_raw.len(), 6);
        for i in 0..6 {
            for j in 0..6 {
                let want = beta_raw[i] * beta_raw[j];
                assert!(
                    (lifted_cov[[i, j]] - want).abs() < 1e-10,
                    "rank-1 covariance pushforward must equal (Tθ)(Tθ)ᵀ at [{i},{j}]: got {}, want {want}",
                    lifted_cov[[i, j]],
                );
            }
        }
        // The lifted covariance must stay symmetric.
        for i in 0..6 {
            for j in 0..6 {
                assert!((lifted_cov[[i, j]] - lifted_cov[[j, i]]).abs() < 1e-14);
            }
        }
    }
2276
2277 #[test]
2280 fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
2281 let mut v_a = Array2::<f64>::zeros((3, 2));
2282 v_a[[0, 0]] = 0.6;
2283 v_a[[1, 0]] = -0.8;
2284 v_a[[1, 1]] = 0.3;
2285 v_a[[2, 1]] = 0.9;
2286 let mut v_b = Array2::<f64>::zeros((4, 3));
2287 v_b[[0, 0]] = 1.0;
2288 v_b[[1, 1]] = -0.4;
2289 v_b[[2, 0]] = 0.2;
2290 v_b[[2, 2]] = 0.7;
2291 v_b[[3, 2]] = -1.1;
2292 let v_per_term = vec![v_a.clone(), v_b.clone()];
2293 let lift = Gauge::from_v_and_r(&v_per_term, &[None, None]);
2294 let theta_a = Array1::from(vec![0.3_f64, -1.4]);
2295 let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);
2296 let via_t = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2297 let ref_a = v_a.dot(&theta_a);
2298 let ref_b = v_b.dot(&theta_b);
2299 assert_eq!(via_t[0].len(), ref_a.len());
2300 for (g, w) in via_t[0].iter().zip(ref_a.iter()) {
2301 assert!((g - w).abs() < 1e-12);
2302 }
2303 assert_eq!(via_t[1].len(), ref_b.len());
2304 for (g, w) in via_t[1].iter().zip(ref_b.iter()) {
2305 assert!((g - w).abs() < 1e-12);
2306 }
2307 }
2308
    /// Shows that the compiled drop pattern depends on the row Hessian, so a
    /// compile using structural curvature can disagree with one using
    /// data-adaptive curvature.
    ///
    /// Design: the time block has a constant on q0 only (its q1 Jacobian is
    /// zero). The marginal block's single q Jacobian is a constant that feeds
    /// both q0 and q1.
    /// * Identity H weights q1 too, so the marginal direction (1 on q0, 1 on
    ///   q1) is distinct from the time direction: no marginal drop.
    /// * H supported on q0 only hides the q1 difference, so the two columns
    ///   alias on q0 and the marginal block must drop one column.
    #[test]
    fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
        let n = 6usize;
        let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
        let time_dq1 = Array2::<f64>::zeros((n, 1));
        let time_dqd1 = Array2::<f64>::zeros((n, 1));
        let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
        let marg_dqd1 = Array2::<f64>::zeros((n, 1));
        // Empty logslope block (zero columns, empty partition).
        let log_dg = Array2::<f64>::zeros((n, 0));
        let mut time_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
        time_partition.push(0..1);
        let mut marg_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
        marg_partition.push(0..1);
        let log_partition: Vec<std::ops::Range<usize>> = Vec::new();

        // Structural curvature: identity on all four channels.
        let mut h_ident = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
        for i in 0..n {
            for k in 0..K_SURVIVAL {
                h_ident[[i, k, k]] = 1.0;
            }
        }
        let row_hess_ident = SurvivalRowHessian::from_full(h_ident);
        let compiled_ident = compile_survival_parametric_designs_per_term(
            time_dq0.clone(),
            time_dq1.clone(),
            time_dqd1.clone(),
            &time_partition,
            marg_dq.clone(),
            marg_dqd1.clone(),
            &marg_partition,
            log_dg.clone(),
            &log_partition,
            &row_hess_ident,
            false,
        )
        .expect("identity-H compile must succeed");

        // Data-adaptive stand-in: curvature only on the q0 channel.
        let mut h_q0_only = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
        for i in 0..n {
            h_q0_only[[i, 0, 0]] = 1.0;
        }
        let row_hess_q0 = SurvivalRowHessian::from_full(h_q0_only);
        let compiled_q0 = compile_survival_parametric_designs_per_term(
            time_dq0,
            time_dq1,
            time_dqd1,
            &time_partition,
            marg_dq,
            marg_dqd1,
            &marg_partition,
            log_dg,
            &log_partition,
            &row_hess_q0,
            false,
        )
        .expect("q0-only-H compile must succeed");

        assert_ne!(
            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
            "structural-H and data-adaptive-H compiles must produce different \
             drops_by_block on the constructed pilot-curvature-trap design; \
             identity={:?} q0-only={:?}",
            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
        );
        // `.1` is the marginal block's drop count.
        assert_eq!(
            compiled_ident.drops_by_block.1, 0,
            "identity-H marg drops expected 0, got {:?}",
            compiled_ident.drops_by_block,
        );
        assert_eq!(
            compiled_q0.drops_by_block.1, 1,
            "q0-only-H marg drops expected 1, got {:?}",
            compiled_q0.drops_by_block,
        );
    }
2412
    /// Builds a `CompiledMap` from per-term compiled pieces. It checks the
    /// block ranges and that the diagonal V blocks land in the right slices.
    /// It then checks that a `Gauge` built from the map matches one built
    /// directly from the same V and R blocks.
    #[test]
    fn compiled_map_from_per_term_partitions_and_lift_round_trip() {
        // Per-block V: time 2→2, marginal 2→1, logslope 1→1.
        let v_time = Array2::<f64>::eye(2);
        let mut v_marg = Array2::<f64>::zeros((2, 1));
        v_marg[[0, 0]] = 1.0;
        v_marg[[1, 0]] = 0.5;
        let v_log = Array2::<f64>::eye(1);
        // Residuals: rows = raw width of the preceding blocks (2, then 2 + 2).
        let r_marg = Array2::<f64>::from_shape_fn((2, 1), |(i, _)| 0.25 + i as f64);
        let r_log = Array2::<f64>::from_shape_fn((4, 1), |(i, _)| 0.1 * (i as f64 + 1.0));
        let per_term = SurvivalParametricCompiledPerTerm {
            v_time_per_term: vec![v_time.clone()],
            v_marginal_per_term: vec![v_marg.clone()],
            v_logslope_per_term: vec![v_log.clone()],
            r_lw_per_term: vec![None, Some(r_marg.clone()), Some(r_log.clone())],
            drops_by_block: (0, 1, 0),
        };

        let map = compiled_map_from_per_term(&per_term);

        assert_eq!(map.raw_block_ranges, vec![0..2, 2..4, 4..5]);
        assert_eq!(map.compiled_block_ranges, vec![0..2, 2..3, 3..4]);
        assert_eq!(map.raw_from_compiled.dim(), (5, 4));

        // The diagonal (raw range × compiled range) slices must be the V blocks.
        let v_time_slice = map
            .raw_from_compiled
            .slice(ndarray::s![0..2, 0..2])
            .to_owned();
        let v_marg_slice = map
            .raw_from_compiled
            .slice(ndarray::s![2..4, 2..3])
            .to_owned();
        let v_log_slice = map
            .raw_from_compiled
            .slice(ndarray::s![4..5, 3..4])
            .to_owned();
        for i in 0..2 {
            for j in 0..2 {
                assert!((v_time_slice[[i, j]] - v_time[[i, j]]).abs() < 1e-14);
            }
            assert!((v_marg_slice[[i, 0]] - v_marg[[i, 0]]).abs() < 1e-14);
        }
        assert!((v_log_slice[[0, 0]] - v_log[[0, 0]]).abs() < 1e-14);

        // Round trip: map-derived gauge T must equal the directly assembled T.
        let ordering = [
            gam_identifiability::families::compiler::BlockOrder::Time,
            gam_identifiability::families::compiler::BlockOrder::Marginal,
            gam_identifiability::families::compiler::BlockOrder::Logslope,
        ];
        let lift_from_map = Gauge::from_compiled_map(&map, &ordering);
        let v_all = vec![v_time, v_marg, v_log];
        let lift_direct = Gauge::from_v_and_r(&v_all, &[None, Some(r_marg), Some(r_log)]);
        assert_eq!(lift_from_map.t_full.dim(), lift_direct.t_full.dim());
        for i in 0..lift_from_map.t_full.nrows() {
            for j in 0..lift_from_map.t_full.ncols() {
                assert!(
                    (lift_from_map.t_full[[i, j]] - lift_direct.t_full[[i, j]]).abs() < 1e-14,
                    "T mismatch at ({i},{j}): map={} direct={}",
                    lift_from_map.t_full[[i, j]],
                    lift_direct.t_full[[i, j]],
                );
            }
        }
    }
2487
2488 fn const_row_hess_q0g(n: usize, h00: f64, h03: f64, h33: f64) -> SurvivalRowHessian {
2504 let mut h = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2505 for i in 0..n {
2506 h[[i, 0, 0]] = h00;
2507 h[[i, 0, 3]] = h03;
2508 h[[i, 3, 0]] = h03;
2509 h[[i, 3, 3]] = h33;
2510 }
2511 SurvivalRowHessian::from_full(h)
2512 }
2513
2514 #[test]
2515 fn survival_reduced_logslope_drops_confounded_keeps_free_979() {
2516 let n = 4;
2522 let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0); let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2524 let log =
2527 Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
2528 .unwrap();
2529 let t = survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2530 .expect("contraction must succeed")
2531 .expect("a partial confound must yield a reduced transform");
2532 assert_eq!(t.dim(), (2, 1), "exactly one logslope direction survives");
2533 assert!(
2536 t[[0, 0]].abs() < 1e-6,
2537 "confounded (e1) direction must be dropped, got {}",
2538 t[[0, 0]]
2539 );
2540 assert!(
2541 (t[[1, 0]].abs() - 1.0).abs() < 1e-6,
2542 "free (e2) direction must be kept as a unit vector, got {}",
2543 t[[1, 0]]
2544 );
2545 }
2546
2547 #[test]
2548 fn survival_reduced_logslope_fully_confounded_returns_none_979() {
2549 let n = 4;
2555 let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0);
2556 let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2557 let log = marg.clone();
2558 let out =
2559 survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2560 .expect("contraction must succeed");
2561 assert!(
2562 out.is_none(),
2563 "a fully marginal-explained logslope column reduces to nothing → keep raw"
2564 );
2565 }
2566
2567 #[test]
2568 fn survival_reduced_logslope_no_confound_returns_none_979() {
2569 let n = 4;
2573 let row_hess = const_row_hess_q0g(n, 2.0, 0.0, 2.0);
2574 let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2575 let log =
2576 Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
2577 .unwrap();
2578 let out =
2579 survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2580 .expect("contraction must succeed");
2581 assert!(out.is_none(), "W-orthogonal channels need no reduction → keep raw");
2582 }
2583
2584 #[test]
2585 fn survival_block_diagonal_logslope_map_is_identity_on_time_and_marginal_979() {
2586 let p_time = 2;
2589 let p_marg = 3;
2590 let t_log = Array2::from_shape_fn((4, 2), |(i, j)| 1.0 + (i * 2 + j) as f64);
2591 let map = survival_block_diagonal_logslope_map(p_time, p_marg, &t_log);
2592
2593 assert_eq!(map.raw_block_ranges, vec![0..2, 2..5, 5..9]);
2594 assert_eq!(map.compiled_block_ranges, vec![0..2, 2..5, 5..7]);
2595 assert_eq!(map.raw_from_compiled.dim(), (9, 7));
2596
2597 let t = &map.raw_from_compiled;
2598 for i in 0..p_time {
2600 for j in 0..p_time {
2601 let want = if i == j { 1.0 } else { 0.0 };
2602 assert!((t[[i, j]] - want).abs() < 1e-14, "V_time[{i},{j}]");
2603 }
2604 }
2605 for i in 0..p_marg {
2607 for j in 0..p_marg {
2608 let want = if i == j { 1.0 } else { 0.0 };
2609 assert!((t[[p_time + i, p_time + j]] - want).abs() < 1e-14, "V_marg[{i},{j}]");
2610 }
2611 }
2612 for i in 0..4 {
2614 for j in 0..2 {
2615 assert!(
2616 (t[[p_time + p_marg + i, p_time + p_marg + j]] - t_log[[i, j]]).abs() < 1e-14,
2617 "V_log[{i},{j}]"
2618 );
2619 }
2620 }
2621 let nnz = t.iter().filter(|&&v| v != 0.0).count();
2624 assert_eq!(nnz, p_time + p_marg + t_log.iter().filter(|&&v| v != 0.0).count());
2625 }
2626}