1use std::sync::Arc;
30
31use ndarray::{Array1, Array2, Array3};
32
33use gam_identifiability::families::compiler::{
34 BlockOrder, RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with,
35};
36use gam_problem::gauge::assemble_block_triangular_t;
37use faer::Side;
38use gam_linalg::faer_ndarray::FaerEigh;
39use gam_linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
40use gam_problem::{FamilyChannelHessian, PenaltyMatrix};
41
/// Number of per-row primary channels in the survival marginal-slope family.
///
/// The row operators in this file write a coefficient step into channel 0 (`dq0`),
/// channel 1 (`dq1`), channel 2 (`dqd1`) and channel 3 (`dg`, the logslope channel).
const K_SURVIVAL: usize = 4;

/// A `beta` counts as non-trivial in `channel_hessian_at` when any coefficient's absolute
/// value exceeds this threshold.
const BETA_NONTRIVIAL_ABS_THRESHOLD: f64 = 1e-12;
50
/// Per-row Hessians of the survival marginal-slope family with respect to its
/// `K_SURVIVAL` primary channels.
pub struct SurvivalRowHessian {
    // Shape (n_rows, K_SURVIVAL, K_SURVIVAL). `from_full` asserts the two trailing axes;
    // `from_pilot_primary_state` fills every row with a PSD-clamped Hessian.
    h: Array3<f64>,
}
62
63impl SurvivalRowHessian {
64 pub fn from_pilot_primary_state(
68 q0: &Array1<f64>,
69 q1: &Array1<f64>,
70 qd1: &Array1<f64>,
71 g: &Array1<f64>,
72 z: &Array1<f64>,
73 weights: &Array1<f64>,
74 event: &Array1<f64>,
75 derivative_guard: f64,
76 probit_scale: f64,
77 ) -> Result<Self, String> {
78 let n = q0.len();
79 if [
80 q1.len(),
81 qd1.len(),
82 g.len(),
83 z.len(),
84 weights.len(),
85 event.len(),
86 ]
87 .iter()
88 .any(|&l| l != n)
89 {
90 return Err(format!(
91 "SurvivalRowHessian: length mismatch \
92 q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
93 q1.len(),
94 qd1.len(),
95 g.len(),
96 z.len(),
97 weights.len(),
98 event.len()
99 ));
100 }
101 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
102 for i in 0..n {
103 let (_, _grad, hess) =
104 crate::survival::marginal_slope::row_primary_for_compiler(
105 q0[i],
106 q1[i],
107 qd1[i],
108 g[i],
109 z[i],
110 weights[i],
111 event[i],
112 derivative_guard,
113 probit_scale,
114 )?;
115 let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
117 for a in 0..K_SURVIVAL {
118 for b in 0..K_SURVIVAL {
119 h_i[[a, b]] = hess[a][b];
120 }
121 }
122 let clamped = psd_clamp_4x4(&h_i);
123 for a in 0..K_SURVIVAL {
124 for b in 0..K_SURVIVAL {
125 h_full[[i, a, b]] = clamped[[a, b]];
126 }
127 }
128 }
129 Ok(Self { h: h_full })
130 }
131
132 pub fn from_full(h: Array3<f64>) -> Self {
135 assert_eq!(h.shape()[1], K_SURVIVAL);
136 assert_eq!(h.shape()[2], K_SURVIVAL);
137 Self { h }
138 }
139}
140
141impl RowHessian for SurvivalRowHessian {
142 fn k(&self) -> usize {
143 K_SURVIVAL
144 }
145 fn nrows(&self) -> usize {
146 self.h.shape()[0]
147 }
148 fn fill_row(&self, row: usize, out: &mut [f64]) {
149 assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
150 for a in 0..K_SURVIVAL {
151 for b in 0..K_SURVIVAL {
152 out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
153 }
154 }
155 }
156 fn evaluate_full(&self) -> Array3<f64> {
157 self.h.clone()
158 }
159}
160
161impl FamilyChannelHessian for SurvivalRowHessian {
197 fn n_outputs(&self) -> usize {
198 K_SURVIVAL
199 }
200
201 fn n_subjects(&self) -> usize {
202 self.h.shape()[0]
203 }
204
205 fn fill_subject(&self, i: usize, out: &mut [f64]) {
206 assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
207 for a in 0..K_SURVIVAL {
208 for b in 0..K_SURVIVAL {
209 out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
210 }
211 }
212 }
213
214 fn evaluate_full(&self) -> ndarray::Array3<f64> {
215 self.h.clone()
216 }
217
218 fn channel_hessian_at(
219 &self,
220 beta: &[f64],
221 family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
222 ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
223 use crate::survival::marginal_slope::SurvivalMarginalSlopeFamilyScalars;
224
225 let scalars_opt =
226 family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
227
228 let beta_nontrivial = beta
230 .iter()
231 .any(|&b| b.abs() > BETA_NONTRIVIAL_ABS_THRESHOLD);
232
233 match scalars_opt {
234 None if beta_nontrivial => {
235 Err(
238 "SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
239 family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
240 via FamilyLinearizationState::family_scalars to evaluate W(β) \
241 correctly (same contract as T26 Jacobian callbacks)."
242 .to_string(),
243 )
244 }
245 None => {
246 Ok(Arc::new(gam_problem::TensorChannelHessian {
248 h: self.h.clone(),
249 }))
250 }
251 Some(sc) => {
252 let n = self.h.shape()[0];
253 if sc.q0_i.len() != n
254 || sc.q1_i.len() != n
255 || sc.qd1_i.len() != n
256 || sc.g_i.len() != n
257 || sc.z_i.len() != n
258 {
259 return Err(format!(
260 "SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
261 (expected n={n}, got q0={} q1={} qd1={} g={} z={})",
262 sc.q0_i.len(),
263 sc.q1_i.len(),
264 sc.qd1_i.len(),
265 sc.g_i.len(),
266 sc.z_i.len(),
267 ));
268 }
269 let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
286 for i in 0..n {
287 let q0 = sc.q0_i[i];
288 let q1 = sc.q1_i[i];
289 let qd1 = sc.qd1_i[i];
290 let g = sc.g_i[i];
291 let z = sc.z_i[i];
292 match crate::survival::marginal_slope::row_primary_for_compiler(
295 q0, q1, qd1, g, z, 1.0, 1.0, crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
298 sc.s, ) {
300 Ok((_nll, _grad, hess)) => {
301 let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
302 for a in 0..K_SURVIVAL {
303 for b in 0..K_SURVIVAL {
304 h_i[[a, b]] = hess[a][b];
305 }
306 }
307 let clamped = psd_clamp_4x4(&h_i);
308 for a in 0..K_SURVIVAL {
309 for b in 0..K_SURVIVAL {
310 h_full[[i, a, b]] = clamped[[a, b]];
311 }
312 }
313 }
314 Err(_) => {
315 for a in 0..K_SURVIVAL {
318 for b in 0..K_SURVIVAL {
319 h_full[[i, a, b]] = self.h[[i, a, b]];
320 }
321 }
322 }
323 }
324 }
325 Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
326 }
327 }
328 }
329}
330
331fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
336 let k = m.nrows();
337 let (evals, evecs) = match m.eigh(Side::Lower) {
338 Ok(pair) => pair,
339 Err(_) => {
340 let mut out = Array2::<f64>::zeros((k, k));
341 for i in 0..k {
342 out[[i, i]] = m[[i, i]].max(0.0);
343 }
344 return out;
345 }
346 };
347 let mut out = Array2::<f64>::zeros((k, k));
348 for i in 0..k {
349 for j in 0..k {
350 let mut acc = 0.0;
351 for l in 0..k {
352 acc += evecs[[i, l]] * evals[l].max(0.0) * evecs[[j, l]];
353 }
354 out[[i, j]] = acc;
355 }
356 }
357 out
358}
359
360pub struct TimeBlockOperator {
363 dq0: Array2<f64>,
364 dq1: Array2<f64>,
365 dqd1: Array2<f64>,
366}
367
368impl TimeBlockOperator {
369 pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
370 assert_eq!(dq0.dim(), dq1.dim());
371 assert_eq!(dq0.dim(), dqd1.dim());
372 Self { dq0, dq1, dqd1 }
373 }
374}
375
376impl RowJacobianOperator for TimeBlockOperator {
377 fn k(&self) -> usize {
378 K_SURVIVAL
379 }
380 fn ncols(&self) -> usize {
381 self.dq0.ncols()
382 }
383 fn nrows(&self) -> usize {
384 self.dq0.nrows()
385 }
386 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
387 assert_eq!(out.len(), K_SURVIVAL);
388 assert_eq!(delta_beta.len(), self.dq0.ncols());
389 let mut acc = [0.0_f64; K_SURVIVAL];
390 for (j, &b) in delta_beta.iter().enumerate() {
391 acc[0] += self.dq0[[row, j]] * b;
392 acc[1] += self.dq1[[row, j]] * b;
393 acc[2] += self.dqd1[[row, j]] * b;
394 }
395 out.copy_from_slice(&acc);
396 }
397 fn evaluate_full(&self) -> Array3<f64> {
398 let n = self.dq0.nrows();
399 let p = self.dq0.ncols();
400 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
401 for i in 0..n {
402 for j in 0..p {
403 out[[i, j, 0]] = self.dq0[[i, j]];
404 out[[i, j, 1]] = self.dq1[[i, j]];
405 out[[i, j, 2]] = self.dqd1[[i, j]];
406 }
407 }
408 out
409 }
410 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
411 let n = self.dq0.nrows();
417 let p = self.dq0.ncols();
418 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
419 0 => self.dq0[[i, a]],
420 1 => self.dq1[[i, a]],
421 2 => self.dqd1[[i, a]],
422 _ => 0.0,
423 })
424 }
425}
426
427pub struct QChannelBlockOperator {
432 dq: Array2<f64>,
433 dqd1: Array2<f64>,
434}
435
436impl QChannelBlockOperator {
437 pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
438 assert_eq!(dq.dim(), dqd1.dim());
439 Self { dq, dqd1 }
440 }
441}
442
443impl RowJacobianOperator for QChannelBlockOperator {
444 fn k(&self) -> usize {
445 K_SURVIVAL
446 }
447 fn ncols(&self) -> usize {
448 self.dq.ncols()
449 }
450 fn nrows(&self) -> usize {
451 self.dq.nrows()
452 }
453 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
454 assert_eq!(out.len(), K_SURVIVAL);
455 assert_eq!(delta_beta.len(), self.dq.ncols());
456 let mut dq_acc = 0.0;
457 let mut dqd_acc = 0.0;
458 for (j, &b) in delta_beta.iter().enumerate() {
459 dq_acc += self.dq[[row, j]] * b;
460 dqd_acc += self.dqd1[[row, j]] * b;
461 }
462 out[0] = dq_acc;
463 out[1] = dq_acc;
464 out[2] = dqd_acc;
465 out[3] = 0.0;
466 }
467 fn evaluate_full(&self) -> Array3<f64> {
468 let n = self.dq.nrows();
469 let p = self.dq.ncols();
470 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
471 for i in 0..n {
472 for j in 0..p {
473 let v = self.dq[[i, j]];
474 out[[i, j, 0]] = v;
475 out[[i, j, 1]] = v;
476 out[[i, j, 2]] = self.dqd1[[i, j]];
477 }
478 }
479 out
480 }
481 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
482 let n = self.dq.nrows();
486 let p = self.dq.ncols();
487 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
488 0 | 1 => self.dq[[i, a]],
489 2 => self.dqd1[[i, a]],
490 _ => 0.0,
491 })
492 }
493}
494
/// Row-wise Jacobian operator for the logslope block. A coefficient step only moves
/// channel 3, through the columns of `dg`.
pub struct LogslopeBlockOperator {
    // (rows, cols) = (observation rows, block coefficients), as reported by nrows()/ncols().
    dg: Array2<f64>,
}

impl LogslopeBlockOperator {
    /// Wraps `dg` without any shape validation.
    pub fn new(dg: Array2<f64>) -> Self {
        Self { dg }
    }
}
506
507impl RowJacobianOperator for LogslopeBlockOperator {
508 fn k(&self) -> usize {
509 K_SURVIVAL
510 }
511 fn ncols(&self) -> usize {
512 self.dg.ncols()
513 }
514 fn nrows(&self) -> usize {
515 self.dg.nrows()
516 }
517 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
518 assert_eq!(out.len(), K_SURVIVAL);
519 assert_eq!(delta_beta.len(), self.dg.ncols());
520 let mut acc = 0.0;
521 for (j, &b) in delta_beta.iter().enumerate() {
522 acc += self.dg[[row, j]] * b;
523 }
524 out[0] = 0.0;
525 out[1] = 0.0;
526 out[2] = 0.0;
527 out[3] = acc;
528 }
529 fn evaluate_full(&self) -> Array3<f64> {
530 let n = self.dg.nrows();
531 let p = self.dg.ncols();
532 let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
533 for i in 0..n {
534 for j in 0..p {
535 out[[i, j, 3]] = self.dg[[i, j]];
536 }
537 }
538 out
539 }
540 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
541 let n = self.dg.nrows();
546 let p = self.dg.ncols();
547 scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| {
548 if c == 3 { self.dg[[i, a]] } else { 0.0 }
549 })
550 }
551}
552
/// Operators and block tags handed to the identifiability compiler.
///
/// The two vectors are index-aligned: `ordering[i]` tags `operators[i]`.
pub struct SurvivalCompilerInputs {
    pub operators: Vec<Arc<dyn RowJacobianOperator>>,
    pub ordering: Vec<BlockOrder>,
}
560
/// Result of `compile_survival_parametric_designs`: one raw→compiled transform per
/// parametric block.
pub struct SurvivalParametricCompiled {
    // `t_lw` of the compiled time block (raw time width × kept width).
    pub v_time: Array2<f64>,
    // `t_lw` of the compiled marginal block.
    pub v_marginal: Array2<f64>,
    // `t_lw` of the compiled logslope block.
    pub v_logslope: Array2<f64>,
    // (time, marginal, logslope) raw width minus kept width, saturating at zero.
    pub drops_by_block: (usize, usize, usize),
}
586
587fn wrap_design_with_transform(
588 raw: DesignMatrix,
589 v: &Array2<f64>,
590 context: &str,
591) -> Result<DesignMatrix, String> {
592 if raw.ncols() != v.nrows() {
593 return Err(format!(
594 "{context}: raw design has {} cols but V has {} rows (V is {}×{})",
595 raw.ncols(),
596 v.nrows(),
597 v.nrows(),
598 v.ncols(),
599 ));
600 }
601 let inner_dense = match raw {
602 DesignMatrix::Dense(d) => d,
603 DesignMatrix::Sparse(_) => {
604 let dense = raw
605 .try_to_dense_by_chunks(&format!("{context} sparse→dense for V apply"))
606 .map_err(|reason| format!("{context}: densify failed: {reason}"))?;
607 DenseDesignMatrix::from(dense)
608 }
609 };
610 let op = CoefficientTransformOperator::new(inner_dense, v.clone())
611 .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))?;
612 Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(op))))
613}
614
/// Result of `compile_survival_parametric_designs_per_term`: one compiled transform per
/// partition range (term) of each parametric block.
pub struct SurvivalParametricCompiledPerTerm {
    // `t_lw` of each time term, in time-partition order.
    pub v_time_per_term: Vec<Array2<f64>>,
    // `t_lw` of each marginal term, in marginal-partition order.
    pub v_marginal_per_term: Vec<Array2<f64>>,
    // `t_lw` of each logslope term, in logslope-partition order.
    pub v_logslope_per_term: Vec<Array2<f64>>,
    // `r_lw` of every term, concatenated as time terms, then marginal, then logslope.
    pub r_lw_per_term: Vec<Option<Array2<f64>>>,
    // Per block: sum over terms of (range width − kept width), each saturating at zero.
    pub drops_by_block: (usize, usize, usize),
}
636
637pub fn compile_survival_parametric_designs_per_term(
656 time_dq0: Array2<f64>,
657 time_dq1: Array2<f64>,
658 time_dqd1: Array2<f64>,
659 time_partition: &[std::ops::Range<usize>],
660 marginal_dq: Array2<f64>,
661 marginal_dqd1: Array2<f64>,
662 marginal_partition: &[std::ops::Range<usize>],
663 logslope_dg: Array2<f64>,
664 logslope_partition: &[std::ops::Range<usize>],
665 row_hess: &dyn RowHessian,
666) -> Result<SurvivalParametricCompiledPerTerm, String> {
667 use gam_identifiability::families::compiler::compile;
668
669 let p_time = time_dq0.ncols();
670 let p_marg = marginal_dq.ncols();
671 let p_log = logslope_dg.ncols();
672 validate_partition(time_partition, p_time, "time")?;
673 validate_partition(marginal_partition, p_marg, "marginal")?;
674 validate_partition(logslope_partition, p_log, "logslope")?;
675
676 let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::new();
680 let mut ordering: Vec<BlockOrder> = Vec::new();
681 for range in time_partition {
682 let dq0 = time_dq0.slice(ndarray::s![.., range.clone()]).to_owned();
683 let dq1 = time_dq1.slice(ndarray::s![.., range.clone()]).to_owned();
684 let dqd1 = time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned();
685 operators.push(Arc::new(TimeBlockOperator::new(dq0, dq1, dqd1)));
686 ordering.push(BlockOrder::Time);
687 }
688 for range in marginal_partition {
689 let dq = marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned();
690 let dqd1 = marginal_dqd1
691 .slice(ndarray::s![.., range.clone()])
692 .to_owned();
693 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
694 ordering.push(BlockOrder::Marginal);
695 }
696 for range in logslope_partition {
697 let dg = logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned();
698 operators.push(Arc::new(LogslopeBlockOperator::new(dg)));
699 ordering.push(BlockOrder::Logslope);
700 }
701
702 let compiled = compile(&operators, row_hess, &ordering).map_err(|e| {
703 format!("identifiability::families::compiler::compile (per-term) failed: {e}")
704 })?;
705 let blocks = compiled.blocks;
706 let n_time = time_partition.len();
707 let n_marg = marginal_partition.len();
708 let n_log = logslope_partition.len();
709 if blocks.len() != n_time + n_marg + n_log {
710 return Err(format!(
711 "per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
712 n_time + n_marg + n_log,
713 n_time,
714 n_marg,
715 n_log,
716 blocks.len(),
717 ));
718 }
719 let mut iter = blocks.into_iter();
720 let mut v_time_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_time);
721 let mut r_time_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time);
722 for _ in 0..n_time {
723 let blk = iter.next().unwrap();
724 v_time_per_term.push(blk.t_lw);
725 r_time_per_term.push(blk.r_lw);
726 }
727 let mut v_marginal_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_marg);
728 let mut r_marginal_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_marg);
729 for _ in 0..n_marg {
730 let blk = iter.next().unwrap();
731 v_marginal_per_term.push(blk.t_lw);
732 r_marginal_per_term.push(blk.r_lw);
733 }
734 let mut v_logslope_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_log);
735 let mut r_logslope_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_log);
736 for _ in 0..n_log {
737 let blk = iter.next().unwrap();
738 v_logslope_per_term.push(blk.t_lw);
739 r_logslope_per_term.push(blk.r_lw);
740 }
741 let mut r_lw_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time + n_marg + n_log);
742 r_lw_per_term.extend(r_time_per_term);
743 r_lw_per_term.extend(r_marginal_per_term);
744 r_lw_per_term.extend(r_logslope_per_term);
745 let drops_time: usize = time_partition
746 .iter()
747 .zip(v_time_per_term.iter())
748 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
749 .sum();
750 let drops_marg: usize = marginal_partition
751 .iter()
752 .zip(v_marginal_per_term.iter())
753 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
754 .sum();
755 let drops_log: usize = logslope_partition
756 .iter()
757 .zip(v_logslope_per_term.iter())
758 .map(|(r, v)| r.len().saturating_sub(v.ncols()))
759 .sum();
760 Ok(SurvivalParametricCompiledPerTerm {
761 v_time_per_term,
762 v_marginal_per_term,
763 v_logslope_per_term,
764 r_lw_per_term,
765 drops_by_block: (drops_time, drops_marg, drops_log),
766 })
767}
768
/// Checks that `partition` tiles `[0, p_block)` with contiguous, non-empty ranges.
///
/// An empty partition is accepted only when the block has no columns.
///
/// # Errors
/// Returns a message prefixed with `label` for the first violation, checked in this order:
/// 1. the partition is empty while `p_block > 0`;
/// 2. the first range does not start at 0;
/// 3. the last range does not end at `p_block`;
/// 4. adjacent ranges have a gap or overlap, or a non-final range is empty;
/// 5. the final range is empty.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let (first, last) = match (partition.first(), partition.last()) {
        (Some(first), Some(last)) => (first, last),
        _ if p_block == 0 => return Ok(()),
        _ => {
            return Err(format!(
                "{label} partition empty but block has p={p_block} columns"
            ))
        }
    };
    if first.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            first.start
        ));
    }
    if last.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            last.end
        ));
    }
    for pair in partition.windows(2) {
        let (left, right) = (&pair[0], &pair[1]);
        if left.end != right.start {
            return Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                left.start, left.end, right.start, right.end
            ));
        }
        if left.is_empty() {
            return Err(format!(
                "{label} partition has empty range [{}..{})",
                left.start, left.end
            ));
        }
    }
    if last.is_empty() {
        return Err(format!("{label} partition's final range is empty"));
    }
    Ok(())
}
813
/// Splits `[0, p_block)` into terms at every boundary of the given penalty ranges.
///
/// Cut points are `0`, `p_block`, and each range's `start` and `end`, all clamped to
/// `p_block`. After sorting and deduplication, consecutive cut points form the returned
/// ranges. The result is therefore contiguous, non-empty, and empty only when
/// `p_block == 0`.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    let mut cuts: Vec<usize> = Vec::with_capacity(2 + 2 * penalty_ranges.len());
    cuts.push(0);
    cuts.push(p_block);
    cuts.extend(
        penalty_ranges
            .iter()
            .flat_map(|r| [r.start.min(p_block), r.end.min(p_block)]),
    );
    cuts.sort_unstable();
    cuts.dedup();
    // After dedup every window is strictly increasing, so each yields a non-empty range.
    cuts.windows(2).map(|w| w[0]..w[1]).collect()
}
836
837pub fn pull_back_blockwise_penalty_through_block_v(
860 pen: &gam_terms::smooth::BlockwisePenalty,
861 v_block: &Array2<f64>,
862) -> Result<PenaltyMatrix, String> {
863 let raw_p = v_block.nrows();
864 let compiled_p = v_block.ncols();
865 let block_p = pen.col_range.len();
866 let embed_start = pen.col_range.start;
867 let embed_end = pen.col_range.end;
868 if embed_end > raw_p {
869 return Err(format!(
870 "pull_back_blockwise_penalty_through_block_v: penalty col_range {embed_start}..{embed_end} \
871 exceeds block raw width {raw_p}"
872 ));
873 }
874 if pen.local.nrows() != block_p || pen.local.ncols() != block_p {
875 return Err(format!(
876 "pull_back_blockwise_penalty_through_block_v: penalty local is {}x{} but col_range \
877 width is {block_p}",
878 pen.local.nrows(),
879 pen.local.ncols(),
880 ));
881 }
882 let mut embedded = Array2::<f64>::zeros((raw_p, raw_p));
883 if block_p > 0 {
884 let mut dst =
885 embedded.slice_mut(ndarray::s![embed_start..embed_end, embed_start..embed_end]);
886 for i in 0..block_p {
887 for j in 0..block_p {
888 dst[[i, j]] = pen.local[[i, j]];
889 }
890 }
891 }
892 let temp = embedded.dot(v_block);
894 let pulled = v_block.t().dot(&temp);
895 let mut sym = Array2::<f64>::zeros((compiled_p, compiled_p));
896 for i in 0..compiled_p {
897 for j in 0..compiled_p {
898 sym[[i, j]] = 0.5 * (pulled[[i, j]] + pulled[[j, i]]);
899 }
900 }
901 Ok(PenaltyMatrix::Dense(sym))
902}
903
904pub fn compiled_map_from_per_term(
926 compiled: &SurvivalParametricCompiledPerTerm,
927) -> gam_identifiability::families::compiler::CompiledMap {
928 let mut v_all: Vec<Array2<f64>> = Vec::new();
931 v_all.extend(compiled.v_time_per_term.iter().cloned());
932 v_all.extend(compiled.v_marginal_per_term.iter().cloned());
933 v_all.extend(compiled.v_logslope_per_term.iter().cloned());
934
935 let t_full = assemble_block_triangular_t(&v_all, &compiled.r_lw_per_term);
936
937 let raw_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.nrows()).sum() };
939 let kept_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.ncols()).sum() };
940 let raw_time = raw_w(&compiled.v_time_per_term);
941 let raw_marg = raw_w(&compiled.v_marginal_per_term);
942 let raw_log = raw_w(&compiled.v_logslope_per_term);
943 let kept_time = kept_w(&compiled.v_time_per_term);
944 let kept_marg = kept_w(&compiled.v_marginal_per_term);
945 let kept_log = kept_w(&compiled.v_logslope_per_term);
946
947 let raw_block_ranges = vec![
948 0..raw_time,
949 raw_time..(raw_time + raw_marg),
950 (raw_time + raw_marg)..(raw_time + raw_marg + raw_log),
951 ];
952 let compiled_block_ranges = vec![
953 0..kept_time,
954 kept_time..(kept_time + kept_marg),
955 (kept_time + kept_marg)..(kept_time + kept_marg + kept_log),
956 ];
957
958 gam_identifiability::families::compiler::CompiledMap {
959 raw_from_compiled: t_full,
960 compiled_block_ranges,
961 raw_block_ranges,
962 }
963}
964
/// Computes a reduced basis for the logslope block after projecting out what the marginal
/// block already explains under the row-Hessian metric.
///
/// Per-row weights are collapsed from `row_hess.h`:
/// `w_mm = h[i,0,0] + h[i,1,1]`, `w_mg = h[i,0,3] + h[i,1,3]`, `w_gg = h[i,3,3]`.
///
/// Going by the `fast_*` helper names, the Gram matrices are (write `M = marginal_dq`,
/// `G = logslope_dg`):
/// * `C = Gᵀ diag(w_gg) G`
/// * `A = Mᵀ diag(w_mm) M + ridge·I`
/// * `B = Mᵀ diag(w_mg) G`
///
/// The symmetrized Schur complement `S = C − Bᵀ A⁻¹ B` is then eigendecomposed. Eigenvectors
/// whose eigenvalue exceeds `max_j C[j,j] · LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL` become the
/// columns of the returned `p_log × r` transform, ordered by decreasing eigenvalue.
///
/// Returns `Ok(None)` (no reduction) when:
/// * either block has zero columns;
/// * the largest diagonal entry of `C` is non-finite or non-positive;
/// * the kept rank `r` is `0` or `p_log`.
///
/// # Errors
/// Returns `Err` on any of:
/// * a row-count mismatch;
/// * non-finite weights;
/// * a failed factorization of `A`;
/// * non-finite Schur entries;
/// * a failed eigendecomposition;
/// * a non-finite transform.
///
/// NOTE(review): `QChannelBlockOperator` feeds the marginal `dq` into both channel 0 and
/// channel 1, so the exact marginal weight would be `h00 + h01 + h10 + h11`. This code uses
/// `h00 + h11` and also ignores the `qd1` channel. Confirm the cross entries are
/// structurally zero or intentionally omitted.
pub fn survival_reduced_logslope_transform_effective(
    marginal_dq: ndarray::ArrayView2<'_, f64>,
    logslope_dg: ndarray::ArrayView2<'_, f64>,
    row_hess: &SurvivalRowHessian,
) -> Result<Option<Array2<f64>>, String> {
    use crate::bms::block_specs::LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
    use gam_linalg::faer_ndarray::{
        FaerArrayView, factorize_symmetricwith_fallback, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
    };

    let n = marginal_dq.nrows();
    let p_m = marginal_dq.ncols();
    let p_log = logslope_dg.ncols();
    // Nothing to reduce when either block has no columns.
    if p_m == 0 || p_log == 0 {
        return Ok(None);
    }
    if logslope_dg.nrows() != n || row_hess.h.shape()[0] != n {
        return Err(format!(
            "survival reduced logslope: row mismatch marginal={n}, logslope={}, row_hess={}",
            logslope_dg.nrows(),
            row_hess.h.shape()[0],
        ));
    }

    // Per-row weights collapsed from the 4×4 row Hessian:
    // w_mm couples the marginal q channels (0, 1), w_mg couples them with the logslope
    // channel (3), and w_gg is the logslope self-weight.
    let mut w_mm = Array1::<f64>::zeros(n);
    let mut w_mg = Array1::<f64>::zeros(n);
    let mut w_gg = Array1::<f64>::zeros(n);
    for i in 0..n {
        w_mm[i] = row_hess.h[[i, 0, 0]] + row_hess.h[[i, 1, 1]];
        w_mg[i] = row_hess.h[[i, 0, 3]] + row_hess.h[[i, 1, 3]];
        w_gg[i] = row_hess.h[[i, 3, 3]];
        if !(w_mm[i].is_finite() && w_mg[i].is_finite() && w_gg[i].is_finite()) {
            return Err("survival reduced logslope: non-finite row Hessian weight".to_string());
        }
    }

    let marg = marginal_dq.to_owned();
    let log = logslope_dg.to_owned();

    // Logslope self-Gram C; its largest diagonal entry sets the rank tolerance below.
    let c_gram = fast_xt_diag_x(&log, &w_gg);
    let energy_scale = (0..p_log).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
    if !energy_scale.is_finite() || energy_scale <= 0.0 {
        return Ok(None);
    }

    // Marginal Gram A plus a relative ridge (at least f64::EPSILON) so the factorization
    // below is well posed.
    let mut a_gram = fast_xt_diag_x(&marg, &w_mm);
    let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
    let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
    for i in 0..p_m {
        a_gram[[i, i]] += a_ridge;
    }

    // Cross Gram B (p_m × p_log), then A⁻¹B and the Schur complement C − Bᵀ(A⁻¹B).
    let b_cross = fast_xt_diag_y(&marg, &w_mg, &log);
    let a_view = FaerArrayView::new(&a_gram);
    let a_factor = factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower).map_err(|e| {
        format!("survival reduced logslope: marginal effective Gram factorization failed: {e}")
    })?;
    let b_view = FaerArrayView::new(&b_cross);
    let solved = a_factor.solve(b_view.as_ref());
    let a_inv_b = Array2::from_shape_fn((p_m, p_log), |(i, j)| solved[(i, j)]);
    let schur = fast_atb(&b_cross, &a_inv_b);
    let mut stt = &c_gram - &schur;
    // Symmetrize to remove round-off asymmetry before the symmetric eigensolver.
    stt = (&stt + &stt.t()) * 0.5;
    if stt.iter().any(|v| !v.is_finite()) {
        return Err(
            "survival reduced logslope: effective Schur Gram produced non-finite entries"
                .to_string(),
        );
    }

    let (evals, evecs) = stt
        .eigh(Side::Lower)
        .map_err(|e| format!("survival reduced logslope: eigendecomposition failed: {e:?}"))?;
    let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
    // Keep directions with eigenvalue above tol, ordered by decreasing eigenvalue.
    let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
    kept.sort_by(|&a, &b| {
        evals[b]
            .partial_cmp(&evals[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let r = kept.len();
    // Full rank means there is nothing to reduce. Zero rank also returns None rather than
    // an empty transform.
    if r == p_log || r == 0 {
        return Ok(None);
    }
    let mut transform = Array2::<f64>::zeros((p_log, r));
    for (out_col, &src) in kept.iter().enumerate() {
        transform.column_mut(out_col).assign(&evecs.column(src));
    }
    if transform.iter().any(|v| !v.is_finite()) {
        return Err(
            "survival reduced logslope: reduced transform produced non-finite entries".to_string(),
        );
    }
    Ok(Some(transform))
}
1128
1129pub fn survival_block_diagonal_logslope_map(
1146 p_time: usize,
1147 p_marg: usize,
1148 t_log: &Array2<f64>,
1149) -> gam_identifiability::families::compiler::CompiledMap {
1150 let p_log = t_log.nrows();
1151 let r = t_log.ncols();
1152 let raw_total = p_time + p_marg + p_log;
1153 let compiled_total = p_time + p_marg + r;
1154 let mut t_full = Array2::<f64>::zeros((raw_total, compiled_total));
1155 for i in 0..p_time {
1156 t_full[[i, i]] = 1.0;
1157 }
1158 for i in 0..p_marg {
1159 t_full[[p_time + i, p_time + i]] = 1.0;
1160 }
1161 for ri in 0..p_log {
1162 for cj in 0..r {
1163 t_full[[p_time + p_marg + ri, p_time + p_marg + cj]] = t_log[[ri, cj]];
1164 }
1165 }
1166 gam_identifiability::families::compiler::CompiledMap {
1167 raw_from_compiled: t_full,
1168 compiled_block_ranges: vec![
1169 0..p_time,
1170 p_time..(p_time + p_marg),
1171 (p_time + p_marg)..compiled_total,
1172 ],
1173 raw_block_ranges: vec![
1174 0..p_time,
1175 p_time..(p_time + p_marg),
1176 (p_time + p_marg)..raw_total,
1177 ],
1178 }
1179}
1180
1181pub fn apply_compiled_map_to_designs(
1209 map: &gam_identifiability::families::compiler::CompiledMap,
1210 time_design_entry: DesignMatrix,
1211 time_design_exit: DesignMatrix,
1212 time_design_derivative_exit: DesignMatrix,
1213 marginal_design: DesignMatrix,
1214 logslope_design: DesignMatrix,
1215 time_penalties: &[gam_terms::smooth::BlockwisePenalty],
1216 marginal_penalties: &[gam_terms::smooth::BlockwisePenalty],
1217 logslope_penalties: &[gam_terms::smooth::BlockwisePenalty],
1218) -> Result<CompiledSurvivalDesignsVMExact, String> {
1219 if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
1220 return Err(format!(
1221 "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
1222 got {} raw / {} compiled",
1223 map.raw_block_ranges.len(),
1224 map.compiled_block_ranges.len(),
1225 ));
1226 }
1227 let time_raw = map.raw_block_ranges[0].clone();
1228 let marg_raw = map.raw_block_ranges[1].clone();
1229 let log_raw = map.raw_block_ranges[2].clone();
1230 let time_compiled = map.compiled_block_ranges[0].clone();
1231 let marg_compiled = map.compiled_block_ranges[1].clone();
1232 let log_compiled = map.compiled_block_ranges[2].clone();
1233
1234 let t = &map.raw_from_compiled;
1235 let raw_total = t.nrows();
1236 let compiled_total = t.ncols();
1237 let expected_raw_total = log_raw.end;
1238 if raw_total != expected_raw_total {
1239 return Err(format!(
1240 "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
1241 {expected_raw_total}"
1242 ));
1243 }
1244 let expected_compiled_total = log_compiled.end;
1245 if compiled_total != expected_compiled_total {
1246 return Err(format!(
1247 "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
1248 sum to {expected_compiled_total}"
1249 ));
1250 }
1251
1252 let v_time = t
1253 .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
1254 .to_owned();
1255 let v_marg = t
1256 .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
1257 .to_owned();
1258 let v_log = t
1259 .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
1260 .to_owned();
1261
1262 let time_entry_out =
1263 wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
1264 let time_exit_out =
1265 wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
1266 let time_deriv_out = wrap_design_with_transform(
1267 time_design_derivative_exit,
1268 &v_time,
1269 "compiled-map: time derivative_exit",
1270 )?;
1271 let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
1272 let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
1273
1274 let pull_set = |pens: &[gam_terms::smooth::BlockwisePenalty],
1295 v_block: &Array2<f64>,
1296 channel: &str|
1297 -> Result<Vec<PenaltyMatrix>, String> {
1298 pens.iter()
1299 .map(|p| {
1300 pull_back_blockwise_penalty_through_block_v(p, v_block).map_err(|e| {
1301 format!("apply_compiled_map_to_designs: {channel} penalty pullback: {e}")
1302 })
1303 })
1304 .collect()
1305 };
1306
1307 let time_penalties = pull_set(time_penalties, &v_time, "time")?;
1308 let marginal_penalties = pull_set(marginal_penalties, &v_marg, "marginal")?;
1309 let logslope_penalties = pull_set(logslope_penalties, &v_log, "logslope")?;
1310 validate_block_penalty_shapes("time", time_exit_out.ncols(), &time_penalties)?;
1311 validate_block_penalty_shapes("marginal", marg_out.ncols(), &marginal_penalties)?;
1312 validate_block_penalty_shapes("logslope", log_out.ncols(), &logslope_penalties)?;
1313
1314 Ok(CompiledSurvivalDesignsVMExact {
1315 time_design_entry: time_entry_out,
1316 time_design_exit: time_exit_out,
1317 time_design_derivative_exit: time_deriv_out,
1318 marginal_design: marg_out,
1319 logslope_design: log_out,
1320 time_penalties,
1321 marginal_penalties,
1322 logslope_penalties,
1323 })
1324}
1325
1326fn validate_block_penalty_shapes(
1327 block: &str,
1328 width: usize,
1329 penalties: &[PenaltyMatrix],
1330) -> Result<(), String> {
1331 for (idx, penalty) in penalties.iter().enumerate() {
1332 let shape = penalty.shape();
1333 if shape != (width, width) {
1334 return Err(format!(
1335 "apply_compiled_map_to_designs: {block} penalty {idx} must be {width}x{width}, got {}x{}",
1336 shape.0, shape.1
1337 ));
1338 }
1339 }
1340 Ok(())
1341}
1342
1343pub fn compile_survival_parametric_designs(
1371 time_dq0: Array2<f64>,
1372 time_dq1: Array2<f64>,
1373 time_dqd1: Array2<f64>,
1374 marginal_dq: Array2<f64>,
1375 marginal_dqd1: Array2<f64>,
1376 logslope_dg: Array2<f64>,
1377 row_hess: &dyn RowHessian,
1378) -> Result<SurvivalParametricCompiled, String> {
1379 use gam_identifiability::families::compiler::compile;
1380
1381 let p_time_raw = time_dq0.ncols();
1382 let p_marg_raw = marginal_dq.ncols();
1383 let p_log_raw = logslope_dg.ncols();
1384
1385 let inputs = build_survival_compiler_inputs(
1386 time_dq0,
1387 time_dq1,
1388 time_dqd1,
1389 marginal_dq,
1390 marginal_dqd1,
1391 logslope_dg,
1392 None,
1393 None,
1394 );
1395 if inputs.operators.len() != 3 {
1396 return Err(format!(
1397 "compile_survival_parametric_designs: expected exactly 3 parametric operators \
1398 (time, marginal, logslope); got {}",
1399 inputs.operators.len(),
1400 ));
1401 }
1402 let compiled = compile(&inputs.operators, row_hess, &inputs.ordering)
1403 .map_err(|e| format!("identifiability::families::compiler::compile failed: {e}"))?;
1404 if compiled.blocks.len() != 3 {
1405 return Err(format!(
1406 "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
1407 compiled.blocks.len(),
1408 ));
1409 }
1410 let v_time = compiled.blocks[0].t_lw.clone();
1411 let v_marginal = compiled.blocks[1].t_lw.clone();
1412 let v_logslope = compiled.blocks[2].t_lw.clone();
1413 let drops_by_block = (
1414 p_time_raw.saturating_sub(v_time.ncols()),
1415 p_marg_raw.saturating_sub(v_marginal.ncols()),
1416 p_log_raw.saturating_sub(v_logslope.ncols()),
1417 );
1418 Ok(SurvivalParametricCompiled {
1419 v_time,
1420 v_marginal,
1421 v_logslope,
1422 drops_by_block,
1423 })
1424}
1425
1426pub fn build_survival_compiler_inputs(
1438 time_dq0: Array2<f64>,
1439 time_dq1: Array2<f64>,
1440 time_dqd1: Array2<f64>,
1441 marginal_dq: Array2<f64>,
1442 marginal_dqd1: Array2<f64>,
1443 logslope_dg: Array2<f64>,
1444 score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1445 link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1446) -> SurvivalCompilerInputs {
1447 let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(5);
1448 let mut ordering: Vec<BlockOrder> = Vec::with_capacity(5);
1449
1450 operators.push(Arc::new(TimeBlockOperator::new(
1451 time_dq0, time_dq1, time_dqd1,
1452 )));
1453 ordering.push(BlockOrder::Time);
1454
1455 operators.push(Arc::new(QChannelBlockOperator::new(
1456 marginal_dq,
1457 marginal_dqd1,
1458 )));
1459 ordering.push(BlockOrder::Marginal);
1460
1461 operators.push(Arc::new(LogslopeBlockOperator::new(logslope_dg)));
1462 ordering.push(BlockOrder::Logslope);
1463
1464 if let Some((dq, dqd1)) = score_warp_dq_dqd1 {
1465 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1466 ordering.push(BlockOrder::ScoreWarp);
1467 }
1468 if let Some((dq, dqd1)) = link_dev_dq_dqd1 {
1469 operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1470 ordering.push(BlockOrder::LinkDev);
1471 }
1472
1473 SurvivalCompilerInputs {
1474 operators,
1475 ordering,
1476 }
1477}
1478
/// Parametric survival designs and penalties after a compiled
/// identifiability map has been applied.
///
/// NOTE(review): this is presumably what `apply_compiled_map_to_designs`
/// returns. Its unit test reads these exact field names and expects every
/// design and penalty to use the per-block *compiled* width rather than the
/// joint width. Confirm against that function's signature.
pub struct CompiledSurvivalDesignsVMExact {
    /// Time-block design for the entry row.
    pub time_design_entry: DesignMatrix,
    /// Time-block design for the exit row.
    pub time_design_exit: DesignMatrix,
    /// Time-block derivative design at exit.
    pub time_design_derivative_exit: DesignMatrix,
    /// Marginal-block design.
    pub marginal_design: DesignMatrix,
    /// Logslope-block design.
    pub logslope_design: DesignMatrix,
    /// Penalties acting on the time-block coefficients.
    pub time_penalties: Vec<PenaltyMatrix>,
    /// Penalties acting on the marginal-block coefficients.
    pub marginal_penalties: Vec<PenaltyMatrix>,
    /// Penalties acting on the logslope-block coefficients.
    pub logslope_penalties: Vec<PenaltyMatrix>,
}
1514
#[cfg(test)]
mod tests {
    use super::*;
    use gam_problem::Gauge;

    /// `psd_clamp_4x4` keeps the non-negative eigenvalues of a diagonal matrix
    /// and zeroes the negative ones.
    #[test]
    fn psd_clamp_zeros_negative_eigenvalues() {
        let mut m = Array2::<f64>::zeros((4, 4));
        m[[0, 0]] = 2.0;
        m[[1, 1]] = -1.0;
        m[[2, 2]] = 0.5;
        m[[3, 3]] = -0.25;
        let clamped = psd_clamp_4x4(&m);
        assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
        assert!(clamped[[1, 1]].abs() < 1e-12);
        assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
        assert!(clamped[[3, 3]].abs() < 1e-12);
    }

    /// `TimeBlockOperator::evaluate_full` yields an `(n, p, K_SURVIVAL)` tensor
    /// whose channels 0..3 are `dq0`, `dq1`, `dqd1` and whose channel 3 is zero.
    #[test]
    fn time_block_operator_evaluate_full_shape() {
        let n = 6;
        let p = 3;
        let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
        let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
        let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
        let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
        let full = op.evaluate_full();
        assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
        for i in 0..n {
            for j in 0..p {
                assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
                assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
                assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
                assert_eq!(full[[i, j, 3]], 0.0);
            }
        }
    }

    /// A q-channel block writes the same `dq · δ` into channels 0 and 1, writes
    /// `dqd1 · δ` into channel 2, and leaves channel 3 at zero.
    #[test]
    fn q_channel_block_apply_row_shares_q0_q1() {
        let n = 5;
        let p = 2;
        let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
        let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
        let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
        let mut out = [0.0_f64; K_SURVIVAL];
        let delta = [1.0_f64, -0.5];
        op.apply_row(3, &delta, &mut out);
        let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
        let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
        assert!((out[0] - want_q).abs() < 1e-12);
        assert!((out[1] - want_q).abs() < 1e-12);
        assert!((out[2] - want_qd).abs() < 1e-12);
        assert_eq!(out[3], 0.0);
    }

    /// The logslope block contributes only to channel 3 (`g`).
    #[test]
    fn logslope_block_writes_only_g_channel() {
        let n = 4;
        let p = 2;
        let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
        let op = LogslopeBlockOperator::new(dg.clone());
        let mut out = [0.0_f64; K_SURVIVAL];
        let delta = [2.0_f64, -1.0];
        op.apply_row(1, &delta, &mut out);
        assert_eq!(out[0], 0.0);
        assert_eq!(out[1], 0.0);
        assert_eq!(out[2], 0.0);
        let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
        assert!((out[3] - want).abs() < 1e-12);
    }

    /// Term partitions are derived from penalty ranges as follows:
    /// - no ranges give one full-width term;
    /// - a gap between ranges becomes its own term;
    /// - duplicate ranges collapse;
    /// - width 0 yields no terms.
    #[test]
    fn extract_term_partition_simple_cases() {
        let full = 0..5usize;
        let part = extract_term_partition_from_penalty_ranges(5, &[]);
        assert_eq!(part.as_slice(), std::slice::from_ref(&full));
        let part = extract_term_partition_from_penalty_ranges(5, std::slice::from_ref(&full));
        assert_eq!(part.as_slice(), std::slice::from_ref(&full));
        let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
        assert_eq!(part, vec![0..3, 3..6, 6..10]);
        let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
        assert_eq!(part, vec![0..3, 3..6]);
        let part = extract_term_partition_from_penalty_ranges(0, &[]);
        assert!(part.is_empty());
    }

    /// With identity `V` blocks and no residual couplings, `T` is the identity.
    #[test]
    fn assemble_block_triangular_t_identity_when_v_eye_and_r_none() {
        let v_a = Array2::<f64>::eye(2);
        let v_b = Array2::<f64>::eye(2);
        let t = assemble_block_triangular_t(&[v_a, v_b], &[None, None]);
        assert_eq!(t.dim(), (4, 4));
        let eye4 = Array2::<f64>::eye(4);
        for i in 0..4 {
            for j in 0..4 {
                assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
            }
        }
    }

    /// Checks the block structure of `T` when a column is dropped and a
    /// residual coupling is present:
    /// - diagonal blocks of `T` are the per-block `V`;
    /// - the upper off-diagonal block is `-R`;
    /// - the lower off-diagonal block is zero;
    /// - a dropped column shrinks the compiled width.
    #[test]
    fn assemble_block_triangular_t_with_drops_and_nonzero_r() {
        let mut v_a = Array2::<f64>::zeros((3, 2));
        v_a[[0, 0]] = 1.0;
        v_a[[1, 0]] = 0.5;
        v_a[[2, 1]] = 1.0;
        let v_b = Array2::<f64>::eye(2);
        let r_ab =
            Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
        let t =
            assemble_block_triangular_t(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
        assert_eq!(t.dim(), (5, 4));
        for i in 0..3 {
            for j in 0..2 {
                assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
            }
        }
        for i in 0..2 {
            for j in 0..2 {
                assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
            }
        }
        for i in 0..3 {
            for j in 0..2 {
                assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
            }
        }
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(t[[3 + i, j]], 0.0);
            }
        }
    }

    /// `validate_partition` rejects partitions that:
    /// - do not start at 0;
    /// - fall short of the width;
    /// - have gaps;
    /// - overlap;
    /// - contain empty ranges.
    ///
    /// It accepts the empty partition of width 0 and any contiguous cover.
    #[test]
    fn validate_partition_rejects_bad_partitions() {
        let bad_start = 1..5usize;
        let short_cover = 0..3usize;
        let full_cover = 0..5usize;
        assert!(validate_partition(std::slice::from_ref(&bad_start), 5, "test").is_err());
        assert!(validate_partition(std::slice::from_ref(&short_cover), 5, "test").is_err());
        assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
        assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
        assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
        assert!(validate_partition(&[], 0, "test").is_ok());
        assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
        assert!(validate_partition(std::slice::from_ref(&full_cover), 5, "test").is_ok());
    }

    /// `apply_compiled_map_to_designs` pulls each block's penalties back to that
    /// block's compiled width (not the joint width), even when the map carries
    /// nonzero residual couplings. The residual `R` does not enter the
    /// per-block penalty: the marginal penalty is exactly `V_margᵀ S_marg V_marg`.
    #[test]
    fn compiled_map_penalty_pullback_is_per_block_width_with_nonzero_residual() {
        use gam_identifiability::families::compiler::CompiledMap;
        use gam_terms::smooth::BlockwisePenalty;

        let n = 10;
        let v_time = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| {
            if i == j { 1.0 } else { 0.1 * ((i + j) as f64) }
        });
        let v_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
            0.5 + 0.3 * (i as f64) - 0.2 * (j as f64)
        });
        let v_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 1.2 } else { 0.4 });
        let r_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 0.7 - 0.1 * ((i + j) as f64));
        let r_log =
            Array2::<f64>::from_shape_fn((6, 2), |(i, j)| 0.3 + 0.05 * ((i * 2 + j) as f64));

        let t = assemble_block_triangular_t(
            &[v_time.clone(), v_marg.clone(), v_log.clone()],
            &[None, Some(r_marg.clone()), Some(r_log.clone())],
        );
        assert_eq!(t.dim(), (8, 7), "joint raw 8 × joint compiled 7");

        let map = CompiledMap {
            raw_from_compiled: t.clone(),
            compiled_block_ranges: vec![0..3, 3..5, 5..7],
            raw_block_ranges: vec![0..3, 3..6, 6..8],
        };

        let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
            Array2::<f64>::from_shape_fn((n, 3), |(i, j)| 1.0 + (i as f64) * 0.1 + (j as f64)),
        ));
        let raw_time_exit = raw_time_entry.clone();
        let raw_time_deriv = raw_time_entry.clone();
        let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
            (n, 3),
            |(i, j)| 0.2 * (i as f64) - 0.3 * (j as f64),
        )));
        let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
            (n, 2),
            |(i, j)| 0.5 + (i as f64) * (j as f64 + 1.0),
        )));

        let s_time = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| {
            if i == j { (i + 2) as f64 } else { 0.3 }
        });
        let s_marg = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| {
            if i == j { 1.5 + i as f64 } else { 0.2 }
        });
        let s_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 2.0 } else { 0.5 });
        let time_pens = vec![BlockwisePenalty::new(0..3, s_time.clone())];
        let marg_pens = vec![BlockwisePenalty::new(0..3, s_marg.clone())];
        let log_pens = vec![BlockwisePenalty::new(0..2, s_log.clone())];

        let out = apply_compiled_map_to_designs(
            &map,
            raw_time_entry,
            raw_time_exit,
            raw_time_deriv,
            raw_marg,
            raw_log,
            &time_pens,
            &marg_pens,
            &log_pens,
        )
        .expect("apply_compiled_map_to_designs must succeed");

        assert_eq!(out.time_design_entry.ncols(), 3);
        assert_eq!(out.marginal_design.ncols(), 2);
        assert_eq!(out.logslope_design.ncols(), 2);

        for s in &out.time_penalties {
            assert_eq!(
                s.as_dense_cow().dim(),
                (3, 3),
                "time penalty must be per-block 3×3, not joint-width"
            );
        }
        for s in &out.marginal_penalties {
            assert_eq!(
                s.as_dense_cow().dim(),
                (2, 2),
                "marginal penalty must match reduced compiled width 2, not joint 7"
            );
        }
        for s in &out.logslope_penalties {
            assert_eq!(s.as_dense_cow().dim(), (2, 2));
        }

        // Quadratic-form check: θᵀ P θ must equal (Vθ)ᵀ S (Vθ) for the time block.
        let p_time_dense = out.time_penalties[0].as_dense_cow().into_owned();
        let theta_time = Array1::<f64>::from_shape_fn(3, |k| 0.4 + 0.7 * (k as f64));
        let gamma_time = v_time.dot(&theta_time);
        let lhs = theta_time.dot(&p_time_dense.dot(&theta_time));
        let rhs = gamma_time.dot(&s_time.dot(&gamma_time));
        assert!(
            (lhs - rhs).abs() < 1e-10,
            "time-block per-block pullback must be exact: lhs={lhs}, rhs={rhs}"
        );

        let p_marg_dense = out.marginal_penalties[0].as_dense_cow().into_owned();
        let want_marg = v_marg.t().dot(&s_marg.dot(&v_marg));
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (p_marg_dense[[i, j]] - want_marg[[i, j]]).abs() < 1e-12,
                    "marginal penalty must be V_margᵀ S_marg V_marg at ({i},{j})"
                );
            }
        }
    }

    /// The marginal block's constant column aliases the time block's constant
    /// column on the q0/q1 channels. The helper must attribute the single
    /// resulting drop to the marginal block, which comes later in the
    /// compile order.
    #[test]
    fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
        let n = 24;
        let p_time = 3;
        let p_marginal = 3;
        let p_logslope = 2;
        let x: Vec<f64> = (0..n)
            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
            .collect();
        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
        for i in 0..n {
            time_dq0[[i, 0]] = 1.0;
            time_dq0[[i, 1]] = x[i];
            time_dq0[[i, 2]] = x[i] * x[i];
            time_dq1[[i, 0]] = 1.0;
            time_dq1[[i, 1]] = x[i];
            time_dq1[[i, 2]] = x[i] * x[i];
            time_dqd1[[i, 0]] = 0.0;
            time_dqd1[[i, 1]] = 1.0;
            time_dqd1[[i, 2]] = 2.0 * x[i];
            marg_dq[[i, 0]] = 1.0;
            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
            marg_dq[[i, 2]] = x[i].sin();
            log_dg[[i, 0]] = (2.0 * x[i]).cos();
            log_dg[[i, 1]] = x[i].tanh();
        }
        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
        for i in 0..n {
            for k in 0..K_SURVIVAL {
                h_full[[i, k, k]] = 1.0;
            }
        }
        let row_hess = SurvivalRowHessian::from_full(h_full);
        let out = compile_survival_parametric_designs(
            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
        )
        .expect("Phase-4b parametric compile must succeed on single-direction alias");
        assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
        assert_eq!(
            out.v_marginal.ncols(),
            p_marginal - 1,
            "marginal loses exactly the shared-constant direction"
        );
        assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
        assert_eq!(
            out.drops_by_block,
            (0, 1, 0),
            "attribution: zero from time/logslope, one from marginal",
        );
    }

    /// Same aliasing setup as above, driven through the raw compiler.
    ///
    /// Checks:
    /// - per-block kept widths and the joint rank;
    /// - a `Gauge` built from the kept `V`s reproduces the compiled block
    ///   boundaries;
    /// - every kept direction is finite and non-degenerate.
    #[test]
    fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
        use gam_identifiability::families::compiler::compile;

        let n = 32;
        let p_time = 3;
        let p_marginal = 3;
        let p_logslope = 2;

        let x: Vec<f64> = (0..n)
            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
            .collect();
        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
        for i in 0..n {
            time_dq0[[i, 0]] = 1.0;
            time_dq0[[i, 1]] = x[i];
            time_dq0[[i, 2]] = x[i] * x[i];
            time_dq1[[i, 0]] = 1.0;
            time_dq1[[i, 1]] = x[i];
            time_dq1[[i, 2]] = x[i] * x[i];
            time_dqd1[[i, 0]] = 0.0;
            time_dqd1[[i, 1]] = 1.0;
            time_dqd1[[i, 2]] = 2.0 * x[i];
        }

        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
        for i in 0..n {
            marg_dq[[i, 0]] = 1.0;
            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
            marg_dq[[i, 2]] = x[i].sin();
        }

        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
        for i in 0..n {
            log_dg[[i, 0]] = (2.0 * x[i]).cos();
            log_dg[[i, 1]] = x[i].tanh();
        }

        let inputs = build_survival_compiler_inputs(
            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
        );

        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
        for i in 0..n {
            for k in 0..K_SURVIVAL {
                h_full[[i, k, k]] = 1.0;
            }
        }
        let row_hess = SurvivalRowHessian::from_full(h_full);

        let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
            .expect("survival 3-block compile must succeed; aliasing is single-direction");

        assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");

        let v_time = &compiled.blocks[0].t_lw;
        assert_eq!(
            v_time.ncols(),
            p_time,
            "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
            v_time.dim(),
        );

        let v_marg = &compiled.blocks[1].t_lw;
        assert_eq!(
            v_marg.ncols(),
            p_marginal - 1,
            "marginal block must lose exactly the shared-constant direction; \
             V_marginal cols = {}, expected {}",
            v_marg.ncols(),
            p_marginal - 1,
        );

        let v_log = &compiled.blocks[2].t_lw;
        assert_eq!(
            v_log.ncols(),
            p_logslope,
            "logslope block (no shared direction) must retain all {p_logslope} columns",
        );

        let raw_total = p_time + p_marginal + p_logslope;
        let kept_total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
        assert_eq!(
            kept_total,
            raw_total - 1,
            "joint kept = raw_total − aliased; got {kept_total}, expected {}",
            raw_total - 1,
        );
        assert_eq!(
            compiled.joint_rank, kept_total,
            "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
        );

        let v_per_term: Vec<Array2<f64>> = compiled.blocks.iter().map(|b| b.t_lw.clone()).collect();
        let r_per_term: Vec<Option<Array2<f64>>> = vec![None; v_per_term.len()];
        let gauge = Gauge::from_v_and_r(&v_per_term, &r_per_term);

        let mut expected_reduced = vec![0usize];
        let mut expected_raw = vec![0usize];
        for b in &compiled.blocks {
            let prev_reduced = *expected_reduced.last().unwrap();
            expected_reduced.push(prev_reduced + b.t_lw.ncols());
            let prev_raw = *expected_raw.last().unwrap();
            expected_raw.push(prev_raw + b.t_lw.nrows());
        }
        assert_eq!(
            *gauge.block_starts_reduced.last().unwrap(),
            compiled.joint_rank,
            "SMGS lift reduced dimension must equal the compiled joint_rank",
        );
        assert_eq!(
            gauge.block_starts_reduced, expected_reduced,
            "SMGS lift reduced block boundaries must match the compiled kept widths",
        );
        assert_eq!(
            gauge.block_starts_raw, expected_raw,
            "SMGS lift raw block boundaries must match the compiled per-block raw widths",
        );

        for (bi, block) in compiled.blocks.iter().enumerate() {
            for j in 0..block.t_lw.ncols() {
                let col = block.t_lw.column(j);
                assert!(
                    col.iter().all(|v| v.is_finite()),
                    "block {bi} kept direction {j} has a non-finite entry",
                );
                let norm = col.dot(&col).sqrt();
                assert!(
                    norm > 1e-10,
                    "block {bi} kept direction {j} is degenerate (norm {norm:.3e})",
                );
            }
        }
    }

    /// With identity `V` blocks and no residuals, the gauge is the identity and
    /// lifting block coefficients is a no-op.
    #[test]
    fn smgs_lift_via_t_identity_passes_through() {
        let v0 = Array2::<f64>::eye(3);
        let v1 = Array2::<f64>::eye(2);
        let v_per_term = vec![v0, v1];
        let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
        let lift = Gauge::from_v_and_r(&v_per_term, &r_per_term);
        assert_eq!(lift.t_full.dim(), (5, 5));
        assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
        assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
        for i in 0..5 {
            for j in 0..5 {
                let want = if i == j { 1.0 } else { 0.0 };
                assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
            }
        }
        let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
        let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
        let lifted = lift.lift_block_betas(&[theta_0.clone(), theta_1.clone()]);
        assert_eq!(lifted.len(), 2);
        for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
            assert!((a - b).abs() < 1e-14);
        }
        for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
            assert!((a - b).abs() < 1e-14);
        }
    }

    /// With a residual coupling on block B, the lifted block-A coefficients are
    /// `θ_a − R θ_b`. Block B is lifted through its own `V_b`.
    #[test]
    fn smgs_lift_via_t_two_block_with_residualisation() {
        let v_a = Array2::<f64>::eye(3);
        let mut v_b = Array2::<f64>::zeros((3, 2));
        v_b[[0, 0]] = 1.0;
        v_b[[2, 1]] = 1.0;
        let mut r_b = Array2::<f64>::zeros((3, 2));
        r_b[[0, 0]] = 0.4;
        r_b[[0, 1]] = -0.1;
        r_b[[1, 0]] = 0.7;
        r_b[[1, 1]] = 1.3;
        r_b[[2, 0]] = -0.2;
        r_b[[2, 1]] = 0.5;
        let lift = Gauge::from_v_and_r(&[v_a.clone(), v_b.clone()], &[None, Some(r_b.clone())]);
        assert_eq!(lift.t_full.dim(), (6, 5));
        assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
        assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);

        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
        let lifted = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
        let r_theta_b = r_b.dot(&theta_b);
        let expected_a = &theta_a - &r_theta_b;
        assert_eq!(lifted[0].len(), 3);
        for (got, want) in lifted[0].iter().zip(expected_a.iter()) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
        assert_eq!(lifted[1].len(), 3);
        assert!((lifted[1][0] - theta_b[0]).abs() < 1e-12);
        assert!(lifted[1][1].abs() < 1e-12);
        assert!((lifted[1][2] - theta_b[1]).abs() < 1e-12);
    }

    /// Covariance lifting properties:
    /// - with an identity gauge it is a no-op;
    /// - for a rank-1 covariance `θθᵀ` it equals `(Tθ)(Tθ)ᵀ`, consistent with
    ///   the coefficient lift;
    /// - the result is symmetric.
    #[test]
    fn smgs_lift_covariance_identity_and_rank1_consistency() {
        let lift_id = Gauge::from_v_and_r(
            &[Array2::<f64>::eye(2), Array2::<f64>::eye(2)],
            &[None, None],
        );
        let mut cov = Array2::<f64>::zeros((4, 4));
        for i in 0..4 {
            for j in 0..4 {
                cov[[i, j]] = 1.0 / (1.0 + (i as f64 - j as f64).abs());
            }
        }
        let lifted_id = lift_id.lift_covariance(&cov);
        assert_eq!(lifted_id.dim(), (4, 4));
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (lifted_id[[i, j]] - cov[[i, j]]).abs() < 1e-12,
                    "identity-T covariance lift must be a no-op at [{i},{j}]",
                );
            }
        }

        let v_a = Array2::<f64>::eye(3);
        let mut v_b = Array2::<f64>::zeros((3, 2));
        v_b[[0, 0]] = 1.0;
        v_b[[2, 1]] = 1.0;
        let mut r_b = Array2::<f64>::zeros((3, 2));
        r_b[[0, 0]] = 0.4;
        r_b[[0, 1]] = -0.1;
        r_b[[1, 0]] = 0.7;
        r_b[[1, 1]] = 1.3;
        r_b[[2, 0]] = -0.2;
        r_b[[2, 1]] = 0.5;
        let lift = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_b)]);

        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
        let theta_full = Array1::from(vec![
            theta_a[0], theta_a[1], theta_a[2], theta_b[0], theta_b[1],
        ]);
        let mut cov_rank1 = Array2::<f64>::zeros((5, 5));
        for i in 0..5 {
            for j in 0..5 {
                cov_rank1[[i, j]] = theta_full[i] * theta_full[j];
            }
        }
        let lifted_cov = lift.lift_covariance(&cov_rank1);
        let lifted_blocks = lift.lift_block_betas(&[theta_a, theta_b]);
        let beta_raw = Array1::from(
            lifted_blocks
                .iter()
                .flat_map(|b| b.iter().copied())
                .collect::<Vec<f64>>(),
        );
        assert_eq!(lifted_cov.dim(), (6, 6));
        assert_eq!(beta_raw.len(), 6);
        for i in 0..6 {
            for j in 0..6 {
                let want = beta_raw[i] * beta_raw[j];
                assert!(
                    (lifted_cov[[i, j]] - want).abs() < 1e-10,
                    "rank-1 covariance pushforward must equal (Tθ)(Tθ)ᵀ at [{i},{j}]: got {}, want {want}",
                    lifted_cov[[i, j]],
                );
            }
        }
        for i in 0..6 {
            for j in 0..6 {
                assert!((lifted_cov[[i, j]] - lifted_cov[[j, i]]).abs() < 1e-14);
            }
        }
    }

    /// With no residual couplings, lifting through `T` equals lifting each
    /// block through its own `V`.
    #[test]
    fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
        let mut v_a = Array2::<f64>::zeros((3, 2));
        v_a[[0, 0]] = 0.6;
        v_a[[1, 0]] = -0.8;
        v_a[[1, 1]] = 0.3;
        v_a[[2, 1]] = 0.9;
        let mut v_b = Array2::<f64>::zeros((4, 3));
        v_b[[0, 0]] = 1.0;
        v_b[[1, 1]] = -0.4;
        v_b[[2, 0]] = 0.2;
        v_b[[2, 2]] = 0.7;
        v_b[[3, 2]] = -1.1;
        let v_per_term = vec![v_a.clone(), v_b.clone()];
        let lift = Gauge::from_v_and_r(&v_per_term, &[None, None]);
        let theta_a = Array1::from(vec![0.3_f64, -1.4]);
        let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);
        let via_t = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
        let ref_a = v_a.dot(&theta_a);
        let ref_b = v_b.dot(&theta_b);
        assert_eq!(via_t[0].len(), ref_a.len());
        for (g, w) in via_t[0].iter().zip(ref_a.iter()) {
            assert!((g - w).abs() < 1e-12);
        }
        assert_eq!(via_t[1].len(), ref_b.len());
        for (g, w) in via_t[1].iter().zip(ref_b.iter()) {
            assert!((g - w).abs() < 1e-12);
        }
    }

    /// Drop attribution depends on the row Hessian. The time column moves only
    /// q0, while the marginal column moves q0 and q1:
    /// - under an identity Hessian the two directions are distinguishable, so
    ///   nothing is dropped;
    /// - under a q0-only Hessian they coincide, so the marginal block drops one.
    #[test]
    fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
        let n = 6usize;
        let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
        let time_dq1 = Array2::<f64>::zeros((n, 1));
        let time_dqd1 = Array2::<f64>::zeros((n, 1));
        let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
        let marg_dqd1 = Array2::<f64>::zeros((n, 1));
        let log_dg = Array2::<f64>::zeros((n, 0));
        // Built via `iter::once` to avoid both `vec_init_then_push` and
        // `single_range_in_vec_init` clippy lints.
        let time_partition: Vec<std::ops::Range<usize>> = std::iter::once(0..1).collect();
        let marg_partition: Vec<std::ops::Range<usize>> = std::iter::once(0..1).collect();
        let log_partition: Vec<std::ops::Range<usize>> = Vec::new();

        let mut h_ident = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
        for i in 0..n {
            for k in 0..K_SURVIVAL {
                h_ident[[i, k, k]] = 1.0;
            }
        }
        let row_hess_ident = SurvivalRowHessian::from_full(h_ident);
        let compiled_ident = compile_survival_parametric_designs_per_term(
            time_dq0.clone(),
            time_dq1.clone(),
            time_dqd1.clone(),
            &time_partition,
            marg_dq.clone(),
            marg_dqd1.clone(),
            &marg_partition,
            log_dg.clone(),
            &log_partition,
            &row_hess_ident,
        )
        .expect("identity-H compile must succeed");

        let mut h_q0_only = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
        for i in 0..n {
            h_q0_only[[i, 0, 0]] = 1.0;
        }
        let row_hess_q0 = SurvivalRowHessian::from_full(h_q0_only);
        let compiled_q0 = compile_survival_parametric_designs_per_term(
            time_dq0,
            time_dq1,
            time_dqd1,
            &time_partition,
            marg_dq,
            marg_dqd1,
            &marg_partition,
            log_dg,
            &log_partition,
            &row_hess_q0,
        )
        .expect("q0-only-H compile must succeed");

        assert_ne!(
            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
            "structural-H and data-adaptive-H compiles must produce different \
             drops_by_block on the constructed pilot-curvature-trap design; \
             identity={:?} q0-only={:?}",
            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
        );
        assert_eq!(
            compiled_ident.drops_by_block.1, 0,
            "identity-H marg drops expected 0, got {:?}",
            compiled_ident.drops_by_block,
        );
        assert_eq!(
            compiled_q0.drops_by_block.1, 1,
            "q0-only-H marg drops expected 1, got {:?}",
            compiled_q0.drops_by_block,
        );
    }

    /// `compiled_map_from_per_term` builds the following from per-term kept
    /// bases and residuals:
    /// - a `CompiledMap` whose raw and compiled block ranges follow the
    ///   per-block widths;
    /// - diagonal blocks equal to the per-term `V`s;
    /// - a gauge identical to one built directly from the same `V`s and `R`s.
    #[test]
    fn compiled_map_from_per_term_partitions_and_lift_round_trip() {
        let v_time = Array2::<f64>::eye(2);
        let mut v_marg = Array2::<f64>::zeros((2, 1));
        v_marg[[0, 0]] = 1.0;
        v_marg[[1, 0]] = 0.5;
        let v_log = Array2::<f64>::eye(1);
        let r_marg = Array2::<f64>::from_shape_fn((2, 1), |(i, _)| 0.25 + i as f64);
        let r_log = Array2::<f64>::from_shape_fn((4, 1), |(i, _)| 0.1 * (i as f64 + 1.0));
        let per_term = SurvivalParametricCompiledPerTerm {
            v_time_per_term: vec![v_time.clone()],
            v_marginal_per_term: vec![v_marg.clone()],
            v_logslope_per_term: vec![v_log.clone()],
            r_lw_per_term: vec![None, Some(r_marg.clone()), Some(r_log.clone())],
            drops_by_block: (0, 1, 0),
        };

        let map = compiled_map_from_per_term(&per_term);

        assert_eq!(map.raw_block_ranges, vec![0..2, 2..4, 4..5]);
        assert_eq!(map.compiled_block_ranges, vec![0..2, 2..3, 3..4]);
        assert_eq!(map.raw_from_compiled.dim(), (5, 4));

        let v_time_slice = map
            .raw_from_compiled
            .slice(ndarray::s![0..2, 0..2])
            .to_owned();
        let v_marg_slice = map
            .raw_from_compiled
            .slice(ndarray::s![2..4, 2..3])
            .to_owned();
        let v_log_slice = map
            .raw_from_compiled
            .slice(ndarray::s![4..5, 3..4])
            .to_owned();
        for i in 0..2 {
            for j in 0..2 {
                assert!((v_time_slice[[i, j]] - v_time[[i, j]]).abs() < 1e-14);
            }
            assert!((v_marg_slice[[i, 0]] - v_marg[[i, 0]]).abs() < 1e-14);
        }
        assert!((v_log_slice[[0, 0]] - v_log[[0, 0]]).abs() < 1e-14);

        let ordering = [
            gam_identifiability::families::compiler::BlockOrder::Time,
            gam_identifiability::families::compiler::BlockOrder::Marginal,
            gam_identifiability::families::compiler::BlockOrder::Logslope,
        ];
        let lift_from_map = Gauge::from_compiled_map(&map, &ordering);
        let v_all = vec![v_time, v_marg, v_log];
        let lift_direct = Gauge::from_v_and_r(&v_all, &[None, Some(r_marg), Some(r_log)]);
        assert_eq!(lift_from_map.t_full.dim(), lift_direct.t_full.dim());
        for i in 0..lift_from_map.t_full.nrows() {
            for j in 0..lift_from_map.t_full.ncols() {
                assert!(
                    (lift_from_map.t_full[[i, j]] - lift_direct.t_full[[i, j]]).abs() < 1e-14,
                    "T mismatch at ({i},{j}): map={} direct={}",
                    lift_from_map.t_full[[i, j]],
                    lift_direct.t_full[[i, j]],
                );
            }
        }
    }

    /// Builds a row Hessian that is identical on every row. Its only nonzero
    /// entries couple channel 0 (q0) and channel 3 (g), with `h03` placed
    /// symmetrically.
    fn const_row_hess_q0g(n: usize, h00: f64, h03: f64, h33: f64) -> SurvivalRowHessian {
        let mut h = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
        for i in 0..n {
            h[[i, 0, 0]] = h00;
            h[[i, 0, 3]] = h03;
            h[[i, 3, 0]] = h03;
            h[[i, 3, 3]] = h33;
        }
        SurvivalRowHessian::from_full(h)
    }

    /// When one logslope column is explained by the marginal column through the
    /// q0–g coupling and the other is not, the reduced transform drops the
    /// confounded direction and keeps the free one as a unit vector.
    #[test]
    fn survival_reduced_logslope_drops_confounded_keeps_free_979() {
        let n = 4;
        let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0);
        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
        let log =
            Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
                .unwrap();
        let t = survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
            .expect("contraction must succeed")
            .expect("a partial confound must yield a reduced transform");
        assert_eq!(t.dim(), (2, 1), "exactly one logslope direction survives");
        assert!(
            t[[0, 0]].abs() < 1e-6,
            "confounded (e1) direction must be dropped, got {}",
            t[[0, 0]]
        );
        assert!(
            (t[[1, 0]].abs() - 1.0).abs() < 1e-6,
            "free (e2) direction must be kept as a unit vector, got {}",
            t[[1, 0]]
        );
    }

    /// A logslope block fully explained by the marginal block yields `None`, so
    /// the caller keeps the raw design.
    #[test]
    fn survival_reduced_logslope_fully_confounded_returns_none_979() {
        let n = 4;
        let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0);
        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
        let log = marg.clone();
        let out =
            survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
                .expect("contraction must succeed");
        assert!(
            out.is_none(),
            "a fully marginal-explained logslope column reduces to nothing → keep raw"
        );
    }

    /// With no q0–g coupling in the row Hessian there is nothing to reduce, and
    /// the transform is `None`.
    #[test]
    fn survival_reduced_logslope_no_confound_returns_none_979() {
        let n = 4;
        let row_hess = const_row_hess_q0g(n, 2.0, 0.0, 2.0);
        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
        let log =
            Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
                .unwrap();
        let out =
            survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
                .expect("contraction must succeed");
        assert!(out.is_none(), "W-orthogonal channels need no reduction → keep raw");
    }

    /// The block-diagonal logslope map is identity on the time and marginal
    /// blocks, carries `t_log` on the logslope block, and has no other nonzeros.
    #[test]
    fn survival_block_diagonal_logslope_map_is_identity_on_time_and_marginal_979() {
        let p_time = 2;
        let p_marg = 3;
        let t_log = Array2::from_shape_fn((4, 2), |(i, j)| 1.0 + (i * 2 + j) as f64);
        let map = survival_block_diagonal_logslope_map(p_time, p_marg, &t_log);

        assert_eq!(map.raw_block_ranges, vec![0..2, 2..5, 5..9]);
        assert_eq!(map.compiled_block_ranges, vec![0..2, 2..5, 5..7]);
        assert_eq!(map.raw_from_compiled.dim(), (9, 7));

        let t = &map.raw_from_compiled;
        for i in 0..p_time {
            for j in 0..p_time {
                let want = if i == j { 1.0 } else { 0.0 };
                assert!((t[[i, j]] - want).abs() < 1e-14, "V_time[{i},{j}]");
            }
        }
        for i in 0..p_marg {
            for j in 0..p_marg {
                let want = if i == j { 1.0 } else { 0.0 };
                assert!((t[[p_time + i, p_time + j]] - want).abs() < 1e-14, "V_marg[{i},{j}]");
            }
        }
        for i in 0..4 {
            for j in 0..2 {
                assert!(
                    (t[[p_time + p_marg + i, p_time + p_marg + j]] - t_log[[i, j]]).abs() < 1e-14,
                    "V_log[{i},{j}]"
                );
            }
        }
        let nnz = t.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(nnz, p_time + p_marg + t_log.iter().filter(|&&v| v != 0.0).count());
    }
}