1use std::ops::Range;
15use std::sync::Arc;
16
17use ndarray::{Array1, Array2, Array3, Axis, s};
18
19use faer::Side;
20use gam_linalg::faer_ndarray::{
21 FaerEigh, default_rrqr_rank_alpha, fast_ab, fast_ata, fast_atb, fast_xt_diag_y,
22 rrqr_with_permutation,
23};
24
/// Safety multiplier for the rank-reveal thresholds in this file.
///
/// `solve_psd_system` and `keep_positive_eigenspace` both treat an eigenvalue as
/// numerically zero when it does not exceed
/// `scale * RANK_REVEAL_EPS_SLACK * dimension * f64::EPSILON`.
const RANK_REVEAL_EPS_SLACK: f64 = 64.0;
34
35pub trait RowJacobianOperator: Send + Sync {
42 fn k(&self) -> usize;
44
45 fn ncols(&self) -> usize;
47
48 fn nrows(&self) -> usize;
50
51 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]);
53
54 fn evaluate_full(&self) -> Array3<f64>;
56
57 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
73 scale_block_by_sqrt_h(&self.evaluate_full(), h_full)
74 }
75
76 fn channel_flattened_column(&self, col: usize, out: &mut [f64]) {
90 let k = self.k();
91 let n = self.nrows();
92 assert!(
93 col < self.ncols(),
94 "channel_flattened_column col {col} out of range {}",
95 self.ncols()
96 );
97 assert_eq!(
98 out.len(),
99 n * k,
100 "channel_flattened_column out length {} != n*k = {}*{}",
101 out.len(),
102 n,
103 k
104 );
105 let full = self.evaluate_full();
106 for i in 0..n {
107 for ch in 0..k {
108 out[i * k + ch] = full[[i, col, ch]];
109 }
110 }
111 }
112
113 fn channel_flattened_rows(&self, rows: Range<usize>, out: &mut Array2<f64>) {
120 let n = self.nrows();
121 let start = rows.start.min(n);
122 let end = rows.end.min(n);
123 let chunk = end - start;
124 let k = self.k();
125 let p = self.ncols();
126 assert_eq!(out.shape(), &[chunk * k, p]);
127 let full = self.evaluate_full();
128 for local_i in 0..chunk {
129 let row = start + local_i;
130 for ch in 0..k {
131 for col in 0..p {
132 out[[local_i * k + ch, col]] = full[[row, col, ch]];
133 }
134 }
135 }
136 }
137}
138
/// Per-row `k x k` metric (Hessian) used to weight Jacobian rows.
pub trait RowHessian: Send + Sync {
    /// Side length of each per-row matrix.
    fn k(&self) -> usize;
    /// Number of rows.
    fn nrows(&self) -> usize;
    /// Writes the matrix of row `row` into `out`.
    ///
    /// The in-file `IdentityRowHessian` implementation expects
    /// `out.len() == k * k` and writes entry `(i, j)` at `i * k + j`;
    /// NOTE(review): other implementations presumably use the same layout — confirm.
    fn fill_row(&self, row: usize, out: &mut [f64]);
    /// Materialises every row as a dense tensor; callers in this file expect
    /// shape `[nrows(), k(), k()]`.
    fn evaluate_full(&self) -> Array3<f64>;
}
148
/// Row metric whose every per-row matrix is the `k x k` identity.
pub struct IdentityRowHessian {
    // Number of rows.
    n: usize,
    // Side length of each per-row identity matrix.
    k: usize,
}
159
160impl IdentityRowHessian {
161 pub fn new(n: usize, k: usize) -> Self {
164 Self { n, k }
165 }
166}
167
168impl RowHessian for IdentityRowHessian {
169 fn k(&self) -> usize {
170 self.k
171 }
172 fn nrows(&self) -> usize {
173 self.n
174 }
175 fn fill_row(&self, row: usize, out: &mut [f64]) {
176 assert!(
177 row < self.n,
178 "IdentityRowHessian::fill_row row {row} out of range {n}",
179 n = self.n
180 );
181 assert_eq!(out.len(), self.k * self.k);
182 for i in 0..self.k {
183 for j in 0..self.k {
184 out[i * self.k + j] = if i == j { 1.0 } else { 0.0 };
185 }
186 }
187 }
188 fn evaluate_full(&self) -> Array3<f64> {
189 let mut out = Array3::<f64>::zeros((self.n, self.k, self.k));
190 for i in 0..self.n {
191 for c in 0..self.k {
192 out[[i, c, c]] = 1.0;
193 }
194 }
195 out
196 }
197}
198
/// Result of compiling one block in `compile_with_dual_metric_protected`.
pub struct CompiledBlock {
    /// Raw-to-kept transform of the block: raw width rows by kept-direction
    /// columns. Zero columns when the block was dropped; `(0, 0)` for a
    /// zero-width block.
    pub t_lw: Array2<f64>,
    /// Projection coefficients onto earlier blocks, with one row per raw anchor
    /// column accumulated before this block. `None` when no anchor existed when
    /// the block was processed.
    pub anchor_correction: Option<Array2<f64>>,
    /// Set to the same value as `anchor_correction` everywhere it is
    /// constructed in this file.
    pub r_lw: Option<Array2<f64>>,
}
218
/// Output of the operator-based compile entry points.
pub struct CompiledBlocks {
    /// One entry per input operator, in input order.
    pub blocks: Vec<CompiledBlock>,
    /// Sum of `t_lw.ncols()` over `blocks`.
    pub joint_rank: usize,
    /// `(block index, column index)` pairs recorded as dropped: demotions made
    /// during the block walk first, followed by drops from the final audit.
    pub dropped: Vec<(usize, usize)>,
}
229
/// Classification attached to a [`PenalizedDirectionAnnotation`].
///
/// NOTE(review): the code that assigns these variants is not visible in this
/// part of the file; the per-variant notes below are inferred from the names
/// only — confirm against the producer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenalizedDirectionAnnotationKind {
    /// Presumably: none of the block's directions were absorbed.
    Independent,
    /// Presumably: some, but not all, directions were absorbed by a
    /// higher-priority block.
    PartiallyAbsorbedByHigherPriority,
    /// Presumably: every direction was absorbed by a higher-priority block.
    FullyAbsorbedByHigherPriority,
}
245
/// Per-block width bookkeeping carried in `BlockOrthogonalization`.
///
/// NOTE(review): the producer is not visible in this part of the file; field
/// meanings below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PenalizedDirectionAnnotation {
    /// Index of the block this annotation describes.
    pub block_idx: usize,
    /// Presumably the block's column count before orthogonalization.
    pub raw_width: usize,
    /// Presumably the number of directions retained.
    pub kept_width: usize,
    /// Presumably the number of directions absorbed by other blocks.
    pub absorbed_width: usize,
    /// Classification of the outcome for this block.
    pub kind: PenalizedDirectionAnnotationKind,
}
255
/// Errors reported by the compile routines in this file.
#[derive(Debug)]
pub enum CompilerError {
    /// Inputs disagree on a length or shape; the payload describes the mismatch.
    DimensionMismatch(String),
    /// Block `block_idx` retained no direction and cannot be compiled.
    FullyAliased { block_idx: usize, reason: String },
    /// A linear-algebra routine failed or produced an unusable result.
    LinalgFailure(String),
}

impl std::fmt::Display for CompilerError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DimensionMismatch(detail) => {
                f.write_fmt(format_args!("dimension mismatch: {detail}"))
            }
            Self::FullyAliased { block_idx, reason } => {
                f.write_fmt(format_args!("block {block_idx} fully aliased: {reason}"))
            }
            Self::LinalgFailure(detail) => f.write_fmt(format_args!("linalg failure: {detail}")),
        }
    }
}

impl std::error::Error for CompilerError {}
281
/// Tag identifying which kind of block occupies a position in the compile order.
///
/// In the compile functions visible in this file the `ordering` slice is only
/// checked for length against the block list; blocks are processed in slice
/// order. NOTE(review): the modelling meaning of each variant is not
/// established by the visible code — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockOrder {
    Time,
    Marginal,
    Logslope,
    ScoreWarp,
    LinkDev,
}
294
295pub fn compile(
305 operators: &[Arc<dyn RowJacobianOperator>],
306 row_hess: &dyn RowHessian,
307 ordering: &[BlockOrder],
308) -> Result<CompiledBlocks, CompilerError> {
309 compile_protected(operators, row_hess, ordering, &[])
310}
311
312pub fn compile_protected(
323 operators: &[Arc<dyn RowJacobianOperator>],
324 row_hess: &dyn RowHessian,
325 ordering: &[BlockOrder],
326 protected: &[bool],
327) -> Result<CompiledBlocks, CompilerError> {
328 let n = row_hess.nrows();
334 let k = row_hess.k();
335 let id_struct = IdentityRowHessian::new(n, k);
336 compile_with_dual_metric_protected(operators, row_hess, &id_struct, ordering, protected)
337}
338
339pub fn compile_with_dual_metric(
370 operators: &[Arc<dyn RowJacobianOperator>],
371 row_hess: &dyn RowHessian,
372 row_structural: &dyn RowHessian,
373 ordering: &[BlockOrder],
374) -> Result<CompiledBlocks, CompilerError> {
375 compile_with_dual_metric_protected(operators, row_hess, row_structural, ordering, &[])
376}
377
/// Walks the blocks in slice order and, for each one, keeps only the directions
/// that are not already spanned by earlier blocks, first under the structural
/// metric and then under the curvature metric.
///
/// * `operators` — one Jacobian operator per block.
/// * `row_hess` — curvature row metric.
/// * `row_structural` — structural row metric; must agree with `row_hess` on
///   `k()` and `nrows()`.
/// * `ordering` — only its length is checked here (must equal `operators.len()`).
/// * `protected` — `protected[i] == true` keeps every column of block `i`
///   (identity transforms instead of eigenspace trimming); missing entries are
///   treated as `false`.
///
/// # Errors
/// * `DimensionMismatch` when lengths, `k` or `nrows` disagree, or when the
///   anchor-correction matrix has an unexpected row count.
/// * `FullyAliased` when a block keeps no direction and no anchor exists yet.
/// * `LinalgFailure` propagated from the eigen/RRQR helpers.
pub fn compile_with_dual_metric_protected(
    operators: &[Arc<dyn RowJacobianOperator>],
    row_hess: &dyn RowHessian,
    row_structural: &dyn RowHessian,
    ordering: &[BlockOrder],
    protected: &[bool],
) -> Result<CompiledBlocks, CompilerError> {
    if operators.len() != ordering.len() {
        return Err(CompilerError::DimensionMismatch(format!(
            "operators ({}) and ordering ({}) length mismatch",
            operators.len(),
            ordering.len()
        )));
    }
    // Nothing to compile: return an empty result rather than an error.
    if operators.is_empty() {
        return Ok(CompiledBlocks {
            blocks: Vec::new(),
            joint_rank: 0,
            dropped: Vec::new(),
        });
    }

    // Both metrics and every operator must share the same (n, k).
    let k = row_hess.k();
    let n = row_hess.nrows();
    if row_structural.k() != k {
        return Err(CompilerError::DimensionMismatch(format!(
            "structural row metric has K={} but curvature row Hessian has K={k}",
            row_structural.k()
        )));
    }
    if row_structural.nrows() != n {
        return Err(CompilerError::DimensionMismatch(format!(
            "structural row metric has nrows={} but curvature row Hessian has nrows={n}",
            row_structural.nrows()
        )));
    }
    for (idx, op) in operators.iter().enumerate() {
        if op.k() != k {
            return Err(CompilerError::DimensionMismatch(format!(
                "operator {idx} has K={} but row Hessian has K={k}",
                op.k()
            )));
        }
        if op.nrows() != n {
            return Err(CompilerError::DimensionMismatch(format!(
                "operator {idx} has nrows={} but row Hessian has nrows={n}",
                op.nrows()
            )));
        }
    }

    // Dense per-row metrics, evaluated once and shared by all blocks.
    let h_full = row_hess.evaluate_full();
    let s_full = row_structural.evaluate_full();

    // Each block's design scaled by the square root of each metric.
    let scaled_h: Vec<Array2<f64>> = operators
        .iter()
        .map(|op| op.scaled_design_by_sqrt_h(&h_full))
        .collect();
    let scaled_s: Vec<Array2<f64>> = operators
        .iter()
        .map(|op| op.scaled_design_by_sqrt_h(&s_full))
        .collect();

    let mut compiled: Vec<CompiledBlock> = Vec::with_capacity(operators.len());
    // (block, column) pairs recorded for blocks that keep no direction.
    let mut walk_demotions: Vec<(usize, usize)> = Vec::new();
    // Kept residual columns of earlier blocks, in the curvature / structural metric.
    let mut anchor_h: Array2<f64> = Array2::zeros((n * k, 0));
    let mut anchor_s: Array2<f64> = Array2::zeros((n * k, 0));
    // Raw curvature-scaled columns of every non-empty block processed so far,
    // including blocks that were dropped.
    let mut raw_anchor_h: Array2<f64> = Array2::zeros((n * k, 0));

    for idx in 0..operators.len() {
        let w_h = &scaled_h[idx];
        let w_s = &scaled_s[idx];
        let p_b = w_h.ncols();
        let block_protected = protected.get(idx).copied().unwrap_or(false);

        // Zero-width block: emit an empty placeholder and leave the anchors alone.
        if p_b == 0 {
            compiled.push(CompiledBlock {
                t_lw: Array2::<f64>::zeros((0, 0)),
                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
            });
            continue;
        }

        // Stage 1 (structural metric): remove the part of the block spanned by
        // the structural anchor and keep the eigen-directions of what is left.
        let (residual_s, _) = residualise_in_metric(&anchor_s, w_s)?;
        let g_s = fast_atb(&residual_s, &residual_s);
        let g_s_bb = fast_atb(w_s, w_s);
        // Trace of the un-residualised Gram, used as a scale for the threshold.
        let g_s_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
        let d = if block_protected {
            Array2::<f64>::eye(p_b)
        } else {
            keep_positive_eigenspace(&g_s, n, k, g_s_trace)?
        };
        if d.ncols() == 0 {
            // With no earlier anchor there is nothing that could have absorbed
            // the block, so this is reported as an error.
            if anchor_h.ncols() == 0 {
                return Err(CompilerError::FullyAliased {
                    block_idx: idx,
                    reason: format!(
                        "structural residual Gram has no positive eigenspace (block of width {p_b} has zero structural span before any anchor exists)"
                    ),
                });
            }
            compiled.push(CompiledBlock {
                t_lw: Array2::<f64>::zeros((p_b, 0)),
                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
            });
            // Every raw column of the block is recorded as dropped.
            for c in 0..p_b {
                walk_demotions.push((idx, c));
            }
            // The raw columns still join the raw anchor.
            raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
            continue;
        }

        // Stage 2 (curvature metric): repeat the residualisation inside the
        // structurally kept basis `d`.
        let w_h_d = fast_ab(w_h, &d);
        let (residual_h, m_h_inner_opt) = residualise_in_metric(&anchor_h, &w_h_d)?;
        let g_h = fast_atb(&residual_h, &residual_h);
        let p_d = d.ncols();
        let g_h_dd = fast_atb(&w_h_d, &w_h_d);
        let g_h_trace: f64 = (0..p_d).map(|i| g_h_dd[[i, i]].max(0.0)).sum();
        let t_inner = if block_protected {
            Array2::<f64>::eye(p_d)
        } else {
            keep_positive_eigenspace(&g_h, n, k, g_h_trace)?
        };
        if t_inner.ncols() == 0 {
            if anchor_h.ncols() == 0 {
                return Err(CompilerError::FullyAliased {
                    block_idx: idx,
                    reason: format!(
                        "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {p_d}) before any anchor exists"
                    ),
                });
            }
            compiled.push(CompiledBlock {
                t_lw: Array2::<f64>::zeros((p_b, 0)),
                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
            });
            // Here the recorded indices run over the structurally kept
            // directions (0..p_d), not over the raw width p_b.
            for c in 0..p_d {
                walk_demotions.push((idx, c));
            }
            raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
            continue;
        }

        // Raw-to-kept transform: structural basis composed with curvature basis.
        let v = fast_ab(&d, &t_inner);

        // Snapshots of the anchors as they were before this block, needed for
        // the anchor correction below.
        let prior_anchor_h = anchor_h.clone();
        let prior_raw_anchor_h = raw_anchor_h.clone();

        // Extend both anchors with the kept residual directions.
        let residual_h_t = fast_ab(&residual_h, &t_inner);
        anchor_h = concat_cols(&anchor_h, &residual_h_t);
        let residual_s_v = fast_ab(&residual_s, &v);
        anchor_s = concat_cols(&anchor_s, &residual_s_v);

        let m_compiled = match m_h_inner_opt.as_ref() {
            Some(m) => {
                // Projection coefficients restricted to the kept directions,
                // one row per prior kept-anchor column.
                let m_kept = fast_ab(m, &t_inner);
                if m_kept.nrows() != prior_anchor_h.ncols() {
                    return Err(CompilerError::DimensionMismatch(format!(
                        "anchor correction must be indexed by prior-block kept anchor directions: \
                         m_kept has {} rows but prior_anchor_h has {} columns",
                        m_kept.nrows(),
                        prior_anchor_h.ncols()
                    )));
                }
                // `z` solves the normal equations (raw^T raw) z = raw^T anchor,
                // re-expressing the kept anchor columns in terms of the raw
                // anchor columns; `z * m_kept` is therefore indexed by raw
                // anchor columns.
                let g_raw = fast_atb(&prior_raw_anchor_h, &prior_raw_anchor_h);
                let z_rhs = fast_atb(&prior_raw_anchor_h, &prior_anchor_h);
                let z = solve_psd_system(&g_raw, &z_rhs)?;
                Some(fast_ab(&z, &m_kept))
            }
            // No anchor existed, so there is nothing to correct against.
            None => None,
        };
        compiled.push(CompiledBlock {
            t_lw: v,
            anchor_correction: m_compiled.clone(),
            r_lw: m_compiled,
        });

        raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
    }

    // Final joint rank audit over all kept curvature-metric columns; may trim
    // trailing columns of the last block.
    let audit_dropped = audit_and_drop_trailing_pivots(&anchor_h, &mut compiled)?;
    let mut dropped = walk_demotions;
    dropped.extend(audit_dropped);
    let joint_rank: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();

    Ok(CompiledBlocks {
        blocks: compiled,
        joint_rank,
        dropped,
    })
}
698
699fn scale_block_by_sqrt_h(jb: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
703 let n = jb.shape()[0];
704 let p = jb.shape()[1];
705 let k = jb.shape()[2];
706 scale_jacobian_by_sqrt_h_with(n, p, k, h_full, |i, a, c| jb[[i, a, c]])
707}
708
709pub fn scale_jacobian_by_sqrt_h_with(
723 n: usize,
724 p: usize,
725 k: usize,
726 h_full: &Array3<f64>,
727 jac: impl Fn(usize, usize, usize) -> f64,
728) -> Array2<f64> {
729 assert_eq!(h_full.shape(), &[n, k, k]);
730 let mut out = Array2::<f64>::zeros((n * k, p));
731 let mut sqrt_h = Array2::<f64>::zeros((k, k));
732 let mut scratch_jrow = Array2::<f64>::zeros((p, k));
733 for i in 0..n {
734 let h_i = h_full.index_axis(Axis(0), i).to_owned();
736 sqrt_h.fill(0.0);
737 symmetric_sqrt_into(&h_i, &mut sqrt_h);
738 for a in 0..p {
742 for c in 0..k {
743 scratch_jrow[[a, c]] = jac(i, a, c);
744 }
745 }
746 for c in 0..k {
747 for a in 0..p {
748 let mut acc = 0.0;
749 for cp in 0..k {
750 acc += sqrt_h[[c, cp]] * scratch_jrow[[a, cp]];
751 }
752 out[[i * k + c, a]] = acc;
753 }
754 }
755 }
756 out
757}
758
759pub(crate) fn symmetric_sqrt_into(m: &Array2<f64>, out: &mut Array2<f64>) {
762 let k = m.nrows();
763 assert_eq!(m.ncols(), k);
764 assert_eq!(out.shape(), &[k, k]);
765 if k == 1 {
766 out[[0, 0]] = m[[0, 0]].max(0.0).sqrt();
767 return;
768 }
769 let (evals, evecs) = match m.eigh(Side::Lower) {
770 Ok(pair) => pair,
771 Err(_) => {
772 out.fill(0.0);
775 for i in 0..k {
776 out[[i, i]] = m[[i, i]].max(0.0).sqrt();
777 }
778 return;
779 }
780 };
781 let mut scaled = evecs.clone();
783 for j in 0..k {
784 let s = evals[j].max(0.0).sqrt();
785 for i in 0..k {
786 scaled[[i, j]] *= s;
787 }
788 }
789 out.assign(&fast_atb(&evecs.t().to_owned(), &scaled.t().to_owned()));
790 out.fill(0.0);
794 for i in 0..k {
795 for j in 0..k {
796 let mut acc = 0.0;
797 for l in 0..k {
798 acc += evecs[[i, l]] * evals[l].max(0.0).sqrt() * evecs[[j, l]];
799 }
800 out[[i, j]] = acc;
801 }
802 }
803}
804
805fn residualise_in_metric(
809 a_scaled: &Array2<f64>,
810 b_scaled: &Array2<f64>,
811) -> Result<(Array2<f64>, Option<Array2<f64>>), CompilerError> {
812 let d = a_scaled.ncols();
813 if d == 0 {
814 return Ok((b_scaled.clone(), None));
815 }
816 let g_aa = fast_atb(a_scaled, a_scaled);
817 let g_ab = fast_atb(a_scaled, b_scaled);
818 let m = solve_psd_system(&g_aa, &g_ab)?;
819 let a_m = fast_ab(a_scaled, &m);
820 let residual = b_scaled - &a_m;
821 Ok((residual, Some(m)))
822}
823
824fn solve_psd_system(g: &Array2<f64>, r: &Array2<f64>) -> Result<Array2<f64>, CompilerError> {
829 let n = g.nrows();
830 if n == 0 {
831 return Ok(Array2::zeros((0, r.ncols())));
832 }
833 let (evals, evecs) = g
834 .eigh(Side::Lower)
835 .map_err(|err| CompilerError::LinalgFailure(format!("Gram eigh failed: {err:?}")))?;
836 let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
837 let tol = lambda_max * RANK_REVEAL_EPS_SLACK * (n.max(1) as f64) * f64::EPSILON;
838 let u_t_r = fast_atb(&evecs, r);
840 let mut scaled = u_t_r.clone();
841 for i in 0..n {
842 let lam = evals[i];
843 let inv = if lam > tol { 1.0 / lam } else { 0.0 };
844 for j in 0..scaled.ncols() {
845 scaled[[i, j]] *= inv;
846 }
847 }
848 let m = fast_ab(&evecs, &scaled);
849 Ok(m)
850}
851
852fn keep_positive_eigenspace(
856 g_tilde: &Array2<f64>,
857 n: usize,
858 k: usize,
859 g_bb_trace: f64,
860) -> Result<Array2<f64>, CompilerError> {
861 let p = g_tilde.nrows();
862 if p == 0 {
863 return Ok(Array2::zeros((0, 0)));
864 }
865 let (evals, evecs) = g_tilde.eigh(Side::Lower).map_err(|err| {
866 CompilerError::LinalgFailure(format!("residual Gram eigh failed: {err:?}"))
867 })?;
868 let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
869 let scale = lambda_max.max(g_bb_trace);
870 let nk = (n.saturating_mul(k)).max(p).max(1) as f64;
871 let tau = scale * RANK_REVEAL_EPS_SLACK * nk * f64::EPSILON;
872 let mut kept: Vec<usize> = (0..p).filter(|&i| evals[i] > tau).collect();
874 kept.sort_by(|&a, &b| {
876 evals[b]
877 .partial_cmp(&evals[a])
878 .unwrap_or(std::cmp::Ordering::Equal)
879 });
880 let mut v = Array2::<f64>::zeros((p, kept.len()));
881 for (out_col, &src_col) in kept.iter().enumerate() {
882 for row in 0..p {
883 v[[row, out_col]] = evecs[[row, src_col]];
884 }
885 }
886 Ok(v)
887}
888
889fn concat_cols(left: &Array2<f64>, right: &Array2<f64>) -> Array2<f64> {
891 let nrows = left.nrows().max(right.nrows());
892 let lc = left.ncols();
893 let rc = right.ncols();
894 let mut out = Array2::<f64>::zeros((nrows, lc + rc));
895 if lc > 0 {
896 out.slice_mut(s![.., ..lc]).assign(left);
897 }
898 if rc > 0 {
899 out.slice_mut(s![.., lc..]).assign(right);
900 }
901 out
902}
903
904fn audit_and_drop_trailing_pivots(
908 w_joint: &Array2<f64>,
909 compiled: &mut [CompiledBlock],
910) -> Result<Vec<(usize, usize)>, CompilerError> {
911 let p_total: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();
912 if p_total == 0 || w_joint.nrows() == 0 {
913 return Ok(Vec::new());
914 }
915
916 let rrqr = rrqr_with_permutation(w_joint, default_rrqr_rank_alpha())
918 .map_err(|err| CompilerError::LinalgFailure(format!("audit RRQR failed: {err:?}")))?;
919 let rank = rrqr.rank;
920 if rank >= p_total {
921 return Ok(Vec::new());
922 }
923
924 let drop_count = p_total - rank;
931 let latest_idx = compiled.len() - 1;
932 let latest = &mut compiled[latest_idx];
933 let kept_local = latest.t_lw.ncols().saturating_sub(drop_count);
934 let dropped_locals: Vec<(usize, usize)> = (kept_local..latest.t_lw.ncols())
935 .map(|c| (latest_idx, c))
936 .collect();
937 latest.t_lw = latest.t_lw.slice(s![.., ..kept_local]).to_owned();
946 if let Some(m) = latest.anchor_correction.as_ref() {
947 latest.anchor_correction = Some(m.slice(s![.., ..kept_local]).to_owned());
948 }
949 if let Some(r) = latest.r_lw.as_ref() {
950 latest.r_lw = Some(r.slice(s![.., ..kept_local]).to_owned());
951 }
952 Ok(dropped_locals)
953}
954
/// Per-block, per-channel design matrices used to assemble raw Gram matrices.
pub struct PrimaryChannelBlocks {
    /// `blocks[b][c]` is the design of block `b` on channel `c`, or `None` when
    /// the block has no contribution on that channel (such slots are skipped by
    /// the Gram builders). `build_raw_grams_from_channel_blocks` requires each
    /// present matrix to have one row per Hessian row and one column per raw
    /// coefficient of block `b`.
    pub blocks: Vec<Vec<Option<Array2<f64>>>>,
}
968
969pub fn build_raw_grams_from_channel_blocks(
981 channel_blocks: &PrimaryChannelBlocks,
982 row_hess: &dyn RowHessian,
983 raw_block_ranges: &[std::ops::Range<usize>],
984) -> Result<Array2<f64>, CompilerError> {
985 let num_blocks = channel_blocks.blocks.len();
986 if num_blocks != raw_block_ranges.len() {
987 return Err(CompilerError::DimensionMismatch(format!(
988 "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
989 raw_block_ranges.len()
990 )));
991 }
992 if num_blocks == 0 {
993 return Ok(Array2::<f64>::zeros((0, 0)));
994 }
995 let k = row_hess.k();
996 let n = row_hess.nrows();
997 let p_total: usize = raw_block_ranges.iter().map(|r| r.end - r.start).sum();
998 let expected_total = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
999 if expected_total != p_total {
1000 return Err(CompilerError::DimensionMismatch(format!(
1001 "raw_block_ranges must be contiguous from 0; got p_total={p_total} but last end={expected_total}"
1002 )));
1003 }
1004 for (b, slots) in channel_blocks.blocks.iter().enumerate() {
1006 if slots.len() != k {
1007 return Err(CompilerError::DimensionMismatch(format!(
1008 "block {b}: expected {k} channel slots, got {}",
1009 slots.len()
1010 )));
1011 }
1012 let p_b = raw_block_ranges[b].end - raw_block_ranges[b].start;
1013 for (c, mat) in slots.iter().enumerate() {
1014 if let Some(x) = mat.as_ref() {
1015 if x.nrows() != n {
1016 return Err(CompilerError::DimensionMismatch(format!(
1017 "block {b} channel {c}: nrows={} but row Hessian nrows={n}",
1018 x.nrows()
1019 )));
1020 }
1021 if x.ncols() != p_b {
1022 return Err(CompilerError::DimensionMismatch(format!(
1023 "block {b} channel {c}: ncols={} but block width={p_b}",
1024 x.ncols()
1025 )));
1026 }
1027 }
1028 }
1029 }
1030
1031 let h_full = row_hess.evaluate_full();
1033 if h_full.shape() != &[n, k, k] {
1034 return Err(CompilerError::DimensionMismatch(format!(
1035 "row Hessian evaluate_full shape {:?} != [n={n}, k={k}, k={k}]",
1036 h_full.shape()
1037 )));
1038 }
1039 let mut h_pairs: Vec<Array1<f64>> = Vec::with_capacity(k * k);
1041 for c in 0..k {
1042 for d in 0..k {
1043 let mut v = Array1::<f64>::zeros(n);
1044 for i in 0..n {
1045 v[i] = h_full[[i, c, d]];
1046 }
1047 h_pairs.push(v);
1048 }
1049 }
1050
1051 let mut gram = Array2::<f64>::zeros((p_total, p_total));
1052 for a in 0..num_blocks {
1054 let range_a = raw_block_ranges[a].clone();
1055 for b in a..num_blocks {
1056 let range_b = raw_block_ranges[b].clone();
1057 let mut block_acc =
1058 Array2::<f64>::zeros((range_a.end - range_a.start, range_b.end - range_b.start));
1059 for c in 0..k {
1060 let Some(x_a_c) = channel_blocks.blocks[a][c].as_ref() else {
1061 continue;
1062 };
1063 for d in 0..k {
1064 let Some(x_b_d) = channel_blocks.blocks[b][d].as_ref() else {
1065 continue;
1066 };
1067 let h_cd = &h_pairs[c * k + d];
1068 let contrib = fast_xt_diag_y(x_a_c, h_cd, x_b_d);
1070 block_acc += &contrib;
1071 }
1072 }
1073 gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
1075 .assign(&block_acc);
1076 }
1077 }
1078 for i in 0..p_total {
1081 for j in 0..i {
1082 let v = gram[[j, i]];
1083 gram[[i, j]] = v;
1084 }
1085 }
1086 Ok(gram)
1087}
1088
1089pub fn build_raw_grams_structural(
1096 channel_blocks: &PrimaryChannelBlocks,
1097 raw_block_ranges: &[std::ops::Range<usize>],
1098) -> Array2<f64> {
1099 let num_blocks = channel_blocks.blocks.len();
1100 assert_eq!(
1101 num_blocks,
1102 raw_block_ranges.len(),
1103 "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
1104 raw_block_ranges.len()
1105 );
1106 if num_blocks == 0 {
1107 return Array2::<f64>::zeros((0, 0));
1108 }
1109 let p_total = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
1110 let mut gram = Array2::<f64>::zeros((p_total, p_total));
1111 for a in 0..num_blocks {
1112 let range_a = raw_block_ranges[a].clone();
1113 for b in a..num_blocks {
1114 let range_b = raw_block_ranges[b].clone();
1115 let p_a = range_a.end - range_a.start;
1116 let p_b = range_b.end - range_b.start;
1117 let k_a = channel_blocks.blocks[a].len();
1118 let k_b = channel_blocks.blocks[b].len();
1119 assert_eq!(
1120 k_a, k_b,
1121 "structural Gram: block {a} has {k_a} channels but block {b} has {k_b}",
1122 );
1123 let mut block_acc = Array2::<f64>::zeros((p_a, p_b));
1124 for c in 0..k_a {
1125 let (Some(x_a_c), Some(x_b_c)) = (
1126 channel_blocks.blocks[a][c].as_ref(),
1127 channel_blocks.blocks[b][c].as_ref(),
1128 ) else {
1129 continue;
1130 };
1131 let contrib = if a == b {
1132 fast_ata(x_a_c)
1134 } else {
1135 fast_atb(x_a_c, x_b_c)
1136 };
1137 block_acc += &contrib;
1138 }
1139 gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
1140 .assign(&block_acc);
1141 }
1142 }
1143 for i in 0..p_total {
1144 for j in 0..i {
1145 let v = gram[[j, i]];
1146 gram[[i, j]] = v;
1147 }
1148 }
1149 gram
1150}
1151
1152pub fn build_primary_grams_gpu_or_cpu(
1165 channel_blocks: &PrimaryChannelBlocks,
1166 row_hess: &dyn RowHessian,
1167 raw_block_ranges: &[std::ops::Range<usize>],
1168) -> Result<(Array2<f64>, Array2<f64>), CompilerError> {
1169 let k = row_hess.k();
1170 if k == crate::families::gpu::CHANNELS {
1171 let gpu_blocks: Vec<Vec<Option<Array2<f64>>>> = channel_blocks
1172 .blocks
1173 .iter()
1174 .map(|slots| slots.iter().cloned().collect())
1175 .collect();
1176 if let Some(h_packed) = pack_row_hessian_symmetric(row_hess) {
1177 if let Some(bundle) = crate::families::gpu::try_primary_state_gram_cuda(
1178 &gpu_blocks,
1179 &h_packed,
1180 raw_block_ranges,
1181 ) {
1182 log::info!("[identifiability_compile] gram path = gpu");
1183 return Ok((bundle.gram_h, bundle.gram_struct));
1184 }
1185 }
1186 }
1187 log::info!("[identifiability_compile] gram path = cpu");
1188 let gram_h = build_raw_grams_from_channel_blocks(channel_blocks, row_hess, raw_block_ranges)?;
1189 let gram_struct = build_raw_grams_structural(channel_blocks, raw_block_ranges);
1190 Ok((gram_h, gram_struct))
1191}
1192
1193fn pack_row_hessian_symmetric(row_hess: &dyn RowHessian) -> Option<Array2<f64>> {
1197 use crate::families::gpu::{CHANNELS, PACKED_LEN, packed_index};
1198 if row_hess.k() != CHANNELS {
1199 return None;
1200 }
1201 let n = row_hess.nrows();
1202 let h_full = row_hess.evaluate_full();
1203 if h_full.shape() != [n, CHANNELS, CHANNELS] {
1204 return None;
1205 }
1206 let mut packed = Array2::<f64>::zeros((n, PACKED_LEN));
1207 for i in 0..n {
1208 for c in 0..CHANNELS {
1209 for d in c..CHANNELS {
1210 packed[[i, packed_index(c, d)]] = h_full[[i, c, d]];
1211 }
1212 }
1213 }
1214 Some(packed)
1215}
1216
/// Linear map from compiled coordinates back to raw coordinates, together with
/// the block layout on both sides.
#[derive(Debug)]
pub struct CompiledMap {
    /// `p_raw x p_compiled` matrix; multiplying a compiled coefficient vector
    /// by it yields raw coefficients (see `lift_coefficients`).
    pub raw_from_compiled: Array2<f64>,
    /// Column range of each block in compiled coordinates; empty (`at..at`)
    /// for blocks that kept no direction.
    pub compiled_block_ranges: Vec<std::ops::Range<usize>>,
    /// Row range of each block in raw coordinates, copied from the input.
    pub raw_block_ranges: Vec<std::ops::Range<usize>>,
}
1235
1236impl gam_problem::gauge::CompiledBlockMap for CompiledMap {
1242 fn raw_from_compiled(&self) -> &Array2<f64> {
1243 &self.raw_from_compiled
1244 }
1245 fn raw_block_ranges(&self) -> &[std::ops::Range<usize>] {
1246 &self.raw_block_ranges
1247 }
1248 fn compiled_block_ranges(&self) -> &[std::ops::Range<usize>] {
1249 &self.compiled_block_ranges
1250 }
1251}
1252
1253pub fn compile_from_raw_grams(
1279 gram_h: &Array2<f64>,
1280 gram_struct: &Array2<f64>,
1281 raw_block_ranges: &[std::ops::Range<usize>],
1282 ordering: &[BlockOrder],
1283) -> Result<CompiledMap, CompilerError> {
1284 compile_from_raw_grams_protected(gram_h, gram_struct, raw_block_ranges, ordering, &[])
1285}
1286
1287pub fn compile_from_raw_grams_protected(
1313 gram_h: &Array2<f64>,
1314 gram_struct: &Array2<f64>,
1315 raw_block_ranges: &[std::ops::Range<usize>],
1316 ordering: &[BlockOrder],
1317 protected: &[bool],
1318) -> Result<CompiledMap, CompilerError> {
1319 if raw_block_ranges.len() != ordering.len() {
1320 return Err(CompilerError::DimensionMismatch(format!(
1321 "raw_block_ranges ({}) and ordering ({}) length mismatch",
1322 raw_block_ranges.len(),
1323 ordering.len()
1324 )));
1325 }
1326 let p_raw = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
1327 if gram_h.shape() != [p_raw, p_raw] {
1328 return Err(CompilerError::DimensionMismatch(format!(
1329 "gram_h shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
1330 gram_h.shape()
1331 )));
1332 }
1333 if gram_struct.shape() != [p_raw, p_raw] {
1334 return Err(CompilerError::DimensionMismatch(format!(
1335 "gram_struct shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
1336 gram_struct.shape()
1337 )));
1338 }
1339 if raw_block_ranges.is_empty() {
1340 return Ok(CompiledMap {
1341 raw_from_compiled: Array2::<f64>::zeros((0, 0)),
1342 compiled_block_ranges: Vec::new(),
1343 raw_block_ranges: Vec::new(),
1344 });
1345 }
1346 let mut expected_start = 0usize;
1348 for (b, r) in raw_block_ranges.iter().enumerate() {
1349 if r.start != expected_start {
1350 return Err(CompilerError::DimensionMismatch(format!(
1351 "raw_block_ranges must be contiguous from 0; block {b} starts at {} expected {expected_start}",
1352 r.start
1353 )));
1354 }
1355 expected_start = r.end;
1356 }
1357
1358 let mut t_cum: Array2<f64> = Array2::<f64>::zeros((p_raw, 0));
1360 let mut compiled_block_ranges: Vec<std::ops::Range<usize>> =
1361 Vec::with_capacity(raw_block_ranges.len());
1362
1363 for (idx, range_b) in raw_block_ranges.iter().enumerate() {
1364 let p_b = range_b.end - range_b.start;
1365 let block_protected = protected.get(idx).copied().unwrap_or(false);
1366 if p_b == 0 {
1376 let at = t_cum.ncols();
1377 compiled_block_ranges.push(at..at);
1378 continue;
1379 }
1380 let ks_t = fast_ab(gram_struct, &t_cum);
1385 let g_s_aa = fast_atb(&t_cum, &ks_t);
1387 let ks_pb = gram_struct
1389 .slice(s![.., range_b.start..range_b.end])
1390 .to_owned();
1391 let g_s_ab = fast_atb(&t_cum, &ks_pb);
1392 let g_s_bb = gram_struct
1394 .slice(s![range_b.start..range_b.end, range_b.start..range_b.end])
1395 .to_owned();
1396 let r_s = solve_psd_system(&g_s_aa, &g_s_ab)?;
1398 let g_s_res_raw = &g_s_bb - &fast_atb(&g_s_ab, &r_s);
1400 let g_s_res = symmetrise(&g_s_res_raw);
1401 let g_s_bb_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
1403 let q_plus = if block_protected {
1408 Array2::<f64>::eye(p_b)
1409 } else {
1410 keep_positive_eigenspace(&g_s_res, p_raw, 1, g_s_bb_trace)?
1411 };
1412 if q_plus.ncols() == 0 {
1413 if t_cum.ncols() == 0 {
1414 return Err(CompilerError::FullyAliased {
1415 block_idx: idx,
1416 reason: format!(
1417 "structural residual Gram has no positive eigenspace (block of width {p_b} has zero structural span before any anchor exists)"
1418 ),
1419 });
1420 }
1421 let at = t_cum.ncols();
1422 compiled_block_ranges.push(at..at);
1423 continue;
1424 }
1425 let mut diff = Array2::<f64>::zeros((p_raw, p_b));
1430 if t_cum.ncols() > 0 {
1431 let t_rs = fast_ab(&t_cum, &r_s);
1433 for i in 0..p_raw {
1434 for j in 0..p_b {
1435 diff[[i, j]] = -t_rs[[i, j]];
1436 }
1437 }
1438 }
1439 for j in 0..p_b {
1440 diff[[range_b.start + j, j]] += 1.0;
1441 }
1442 let d_mat = fast_ab(&diff, &q_plus);
1443
1444 let kh_t = fast_ab(gram_h, &t_cum);
1447 let g_h_aa = fast_atb(&t_cum, &kh_t);
1448 let kh_d = fast_ab(gram_h, &d_mat);
1449 let g_h_ad = fast_atb(&t_cum, &kh_d);
1450 let r_h = solve_psd_system(&g_h_aa, &g_h_ad)?;
1451 let d_t_kh_d = fast_atb(&d_mat, &kh_d);
1453 let g_h_res_raw = &d_t_kh_d - &fast_atb(&g_h_ad, &r_h);
1454 let g_h_res = symmetrise(&g_h_res_raw);
1455 let k_kept = q_plus.ncols();
1456 let g_h_dd_trace: f64 = (0..k_kept).map(|i| d_t_kh_d[[i, i]].max(0.0)).sum();
1457 let u_mat = if block_protected {
1461 Array2::<f64>::eye(k_kept)
1462 } else {
1463 keep_positive_eigenspace(&g_h_res, p_raw, 1, g_h_dd_trace)?
1464 };
1465 if u_mat.ncols() == 0 {
1466 if t_cum.ncols() == 0 {
1467 return Err(CompilerError::FullyAliased {
1468 block_idx: idx,
1469 reason: format!(
1470 "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {k_kept}) before any anchor exists"
1471 ),
1472 });
1473 }
1474 let at = t_cum.ncols();
1475 compiled_block_ranges.push(at..at);
1476 continue;
1477 }
1478 let mut e_mat = d_mat.clone();
1480 if t_cum.ncols() > 0 {
1481 let t_rh = fast_ab(&t_cum, &r_h);
1482 e_mat = &e_mat - &t_rh;
1483 }
1484 let t_b = fast_ab(&e_mat, &u_mat);
1485
1486 let start = t_cum.ncols();
1487 let end = start + t_b.ncols();
1488 compiled_block_ranges.push(start..end);
1489 t_cum = concat_cols(&t_cum, &t_b);
1490 }
1491
1492 for v in t_cum.iter() {
1494 if !v.is_finite() {
1495 return Err(CompilerError::LinalgFailure(
1496 "compile_from_raw_grams produced non-finite entry in raw_from_compiled".to_string(),
1497 ));
1498 }
1499 }
1500
1501 Ok(CompiledMap {
1502 raw_from_compiled: t_cum,
1503 compiled_block_ranges,
1504 raw_block_ranges: raw_block_ranges.to_vec(),
1505 })
1506}
1507
1508impl CompiledMap {
1509 pub fn p_raw(&self) -> usize {
1511 self.raw_from_compiled.nrows()
1512 }
1513
1514 pub fn p_compiled(&self) -> usize {
1516 self.raw_from_compiled.ncols()
1517 }
1518
1519 pub fn reduce_design(&self, raw_design: &Array2<f64>) -> Result<Array2<f64>, String> {
1527 if raw_design.ncols() != self.p_raw() {
1528 return Err(format!(
1529 "CompiledMap::reduce_design: raw_design has {} columns, expected p_raw {}",
1530 raw_design.ncols(),
1531 self.p_raw()
1532 ));
1533 }
1534 Ok(fast_ab(raw_design, &self.raw_from_compiled))
1535 }
1536
1537 pub fn lift_coefficients(&self, beta_compiled: &Array1<f64>) -> Result<Array1<f64>, String> {
1544 if beta_compiled.len() != self.p_compiled() {
1545 return Err(format!(
1546 "CompiledMap::lift_coefficients: beta_compiled len {} != p_compiled {}",
1547 beta_compiled.len(),
1548 self.p_compiled()
1549 ));
1550 }
1551 Ok(self.raw_from_compiled.dot(beta_compiled))
1552 }
1553
1554 fn raw_block_rows(&self, block_idx: usize) -> Result<Array2<f64>, String> {
1559 let range = self.raw_block_ranges.get(block_idx).ok_or_else(|| {
1560 format!(
1561 "CompiledMap::raw_block_rows: block {block_idx} out of range {}",
1562 self.raw_block_ranges.len()
1563 )
1564 })?;
1565 Ok(self
1566 .raw_from_compiled
1567 .slice(s![range.start..range.end, ..])
1568 .to_owned())
1569 }
1570}
1571
1572pub fn reduce_penalties_with_map(
1591 map: &CompiledMap,
1592 raw_penalties: &[Option<Array2<f64>>],
1593) -> Result<Vec<Option<Array2<f64>>>, String> {
1594 if raw_penalties.len() != map.raw_block_ranges.len() {
1595 return Err(format!(
1596 "reduce_penalties_with_map: raw_penalties ({}) != blocks ({})",
1597 raw_penalties.len(),
1598 map.raw_block_ranges.len()
1599 ));
1600 }
1601 let p_compiled = map.p_compiled();
1602 let mut reduced: Vec<Option<Array2<f64>>> = Vec::with_capacity(raw_penalties.len());
1603 for (block_idx, raw_penalty) in raw_penalties.iter().enumerate() {
1604 let Some(s_b) = raw_penalty.as_ref() else {
1605 reduced.push(None);
1606 continue;
1607 };
1608 let p_b_raw = map.raw_block_ranges[block_idx].len();
1609 if s_b.shape() != [p_b_raw, p_b_raw] {
1610 return Err(format!(
1611 "reduce_penalties_with_map: block {block_idx} penalty shape {:?} != [{p_b_raw}, {p_b_raw}]",
1612 s_b.shape()
1613 ));
1614 }
1615 let t_b = map.raw_block_rows(block_idx)?;
1617 let s_t_b = fast_ab(s_b, &t_b); let s_compiled_raw = fast_atb(&t_b, &s_t_b); let mut s_compiled = symmetrise(&s_compiled_raw);
1621 if s_compiled.shape() != [p_compiled, p_compiled] {
1622 return Err(format!(
1623 "reduce_penalties_with_map: block {block_idx} reduced penalty shape {:?} != [{p_compiled}, {p_compiled}]",
1624 s_compiled.shape()
1625 ));
1626 }
1627 for v in s_compiled.iter_mut() {
1628 if !v.is_finite() {
1629 return Err(format!(
1630 "reduce_penalties_with_map: block {block_idx} reduced penalty has non-finite entry"
1631 ));
1632 }
1633 }
1634 reduced.push(Some(s_compiled));
1635 }
1636 Ok(reduced)
1637}
1638
1639pub struct BlockOrthogonalization {
1650 pub block_transforms: Vec<Array2<f64>>,
1653 pub dropped: Vec<(usize, usize)>,
1658 pub direction_annotations: Vec<PenalizedDirectionAnnotation>,
1666}
1667
1668pub fn orthogonalize_design_blocks(
1689 block_designs: &[Array2<f64>],
1690 priority: &[u32],
1691 weight: &[f64],
1692) -> Result<BlockOrthogonalization, CompilerError> {
1693 if block_designs.len() != priority.len() {
1694 return Err(CompilerError::DimensionMismatch(format!(
1695 "block_designs ({}) and priority ({}) length mismatch",
1696 block_designs.len(),
1697 priority.len()
1698 )));
1699 }
1700 if block_designs.is_empty() {
1701 return Ok(BlockOrthogonalization {
1702 block_transforms: Vec::new(),
1703 dropped: Vec::new(),
1704 direction_annotations: Vec::new(),
1705 });
1706 }
1707 let n = block_designs[0].nrows();
1708 for (b, x) in block_designs.iter().enumerate() {
1709 if x.nrows() != n {
1710 return Err(CompilerError::DimensionMismatch(format!(
1711 "block {b} design has {} rows but block 0 has {n}",
1712 x.nrows()
1713 )));
1714 }
1715 }
1716 if weight.len() != n {
1717 return Err(CompilerError::DimensionMismatch(format!(
1718 "weight length {} != n {n}",
1719 weight.len()
1720 )));
1721 }
1722 let mut sqrt_w = Array1::<f64>::zeros(n);
1725 for i in 0..n {
1726 let wi = weight[i].max(0.0);
1727 sqrt_w[i] = wi.sqrt();
1728 }
1729
1730 let mut order: Vec<usize> = (0..block_designs.len()).collect();
1734 order.sort_by(|&a, &b| priority[b].cmp(&priority[a]));
1735
1736 let mut anchor: Array2<f64> = Array2::<f64>::zeros((n, 0));
1738
1739 let mut block_transforms: Vec<Option<Array2<f64>>> = vec![None; block_designs.len()];
1741 let mut direction_annotations: Vec<Option<PenalizedDirectionAnnotation>> =
1742 vec![None; block_designs.len()];
1743 let mut dropped: Vec<(usize, usize)> = Vec::new();
1744
1745 for &b in order.iter() {
1746 let x_b = &block_designs[b];
1747 let p_b = x_b.ncols();
1748 let mut w_b = x_b.clone();
1750 for i in 0..n {
1751 let s = sqrt_w[i];
1752 for j in 0..p_b {
1753 w_b[[i, j]] *= s;
1754 }
1755 }
1756 let (residual, _correction) = residualise_in_metric(&anchor, &w_b)?;
1762 let g_res = symmetrise(&fast_atb(&residual, &residual));
1763 let g_bb = fast_atb(&w_b, &w_b);
1772 let g_bb_trace: f64 = (0..p_b).map(|i| g_bb[[i, i]].max(0.0)).sum();
1773 let v_b = keep_positive_eigenspace(&g_res, n, 1, g_bb_trace)?;
1774 let r_b = v_b.ncols();
1775 let absorbed_width = p_b - r_b;
1776 let kind = if absorbed_width == 0 {
1777 PenalizedDirectionAnnotationKind::Independent
1778 } else if r_b == 0 {
1779 PenalizedDirectionAnnotationKind::FullyAbsorbedByHigherPriority
1780 } else {
1781 PenalizedDirectionAnnotationKind::PartiallyAbsorbedByHigherPriority
1782 };
1783 direction_annotations[b] = Some(PenalizedDirectionAnnotation {
1784 block_idx: b,
1785 raw_width: p_b,
1786 kept_width: r_b,
1787 absorbed_width,
1788 kind,
1789 });
1790 if absorbed_width > 0 {
1791 dropped.push((b, absorbed_width));
1792 }
1793 let kept_weighted = fast_ab(&residual, &v_b);
1799 anchor = concat_cols(&anchor, &kept_weighted);
1800 block_transforms[b] = Some(v_b);
1801 }
1802
1803 let block_transforms: Vec<Array2<f64>> = block_transforms
1804 .into_iter()
1805 .enumerate()
1806 .map(|(b, t)| {
1807 t.ok_or_else(|| {
1808 CompilerError::LinalgFailure(format!(
1809 "orthogonalize_design_blocks: block {b} transform was never assigned"
1810 ))
1811 })
1812 })
1813 .collect::<Result<Vec<_>, _>>()?;
1814 let direction_annotations: Vec<PenalizedDirectionAnnotation> = direction_annotations
1815 .into_iter()
1816 .enumerate()
1817 .map(|(b, annotation)| {
1818 annotation.ok_or_else(|| {
1819 CompilerError::LinalgFailure(format!(
1820 "orthogonalize_design_blocks: block {b} direction annotation was never assigned"
1821 ))
1822 })
1823 })
1824 .collect::<Result<Vec<_>, _>>()?;
1825
1826 for (b, v) in block_transforms.iter().enumerate() {
1828 for value in v.iter() {
1829 if !value.is_finite() {
1830 return Err(CompilerError::LinalgFailure(format!(
1831 "orthogonalize_design_blocks: block {b} transform has a non-finite entry"
1832 )));
1833 }
1834 }
1835 }
1836
1837 Ok(BlockOrthogonalization {
1838 block_transforms,
1839 dropped,
1840 direction_annotations,
1841 })
1842}
1843
1844fn symmetrise(m: &Array2<f64>) -> Array2<f64> {
1846 let (r, c) = m.dim();
1847 assert_eq!(r, c, "symmetrise expects square matrix");
1848 let mut out = Array2::<f64>::zeros((r, c));
1849 for i in 0..r {
1850 for j in 0..c {
1851 out[[i, j]] = 0.5 * (m[[i, j]] + m[[j, i]]);
1852 }
1853 }
1854 out
1855}
1856
1857#[cfg(test)]
1858mod tests {
1859 use super::*;
1860 use ndarray::{Array1, Array2};
1861
1862 struct DenseScalarOperator {
1866 design: Array2<f64>,
1867 }
1868
1869 impl DenseScalarOperator {
1870 fn new(design: Array2<f64>) -> Self {
1871 Self { design }
1872 }
1873 }
1874
1875 impl RowJacobianOperator for DenseScalarOperator {
1876 fn k(&self) -> usize {
1877 1
1878 }
1879 fn ncols(&self) -> usize {
1880 self.design.ncols()
1881 }
1882 fn nrows(&self) -> usize {
1883 self.design.nrows()
1884 }
1885 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
1886 assert_eq!(out.len(), 1);
1887 let mut acc = 0.0;
1888 for (j, &b) in delta_beta.iter().enumerate() {
1889 acc += self.design[[row, j]] * b;
1890 }
1891 out[0] = acc;
1892 }
1893 fn evaluate_full(&self) -> Array3<f64> {
1894 let n = self.design.nrows();
1895 let p = self.design.ncols();
1896 let mut out = Array3::<f64>::zeros((n, p, 1));
1897 for i in 0..n {
1898 for j in 0..p {
1899 out[[i, j, 0]] = self.design[[i, j]];
1900 }
1901 }
1902 out
1903 }
1904 }
1905
1906 struct DiagonalScalarRowHessian {
1912 w: Array1<f64>,
1913 }
1914
1915 impl DiagonalScalarRowHessian {
1916 fn new(w: Array1<f64>) -> Self {
1917 Self { w }
1918 }
1919 }
1920
1921 impl RowHessian for DiagonalScalarRowHessian {
1922 fn k(&self) -> usize {
1923 1
1924 }
1925 fn nrows(&self) -> usize {
1926 self.w.len()
1927 }
1928 fn fill_row(&self, row: usize, out: &mut [f64]) {
1929 assert_eq!(out.len(), 1);
1930 out[0] = self.w[row];
1931 }
1932 fn evaluate_full(&self) -> Array3<f64> {
1933 let n = self.w.len();
1934 let mut out = Array3::<f64>::zeros((n, 1, 1));
1935 for i in 0..n {
1936 out[[i, 0, 0]] = self.w[i];
1937 }
1938 out
1939 }
1940 }
1941
1942 fn op(design: Array2<f64>) -> Arc<dyn RowJacobianOperator> {
1943 Arc::new(DenseScalarOperator::new(design))
1944 }
1945
1946 #[test]
1950 fn compile_two_block_orthogonalises_under_metric() {
1951 let n = 50;
1952 let a = Array2::from_shape_fn((n, 3), |(i, j)| ((i + 1) as f64).sin().powi((j + 1) as i32));
1953 let b = Array2::from_shape_fn((n, 2), |(i, j)| {
1955 0.5 * a[[i, 0]] + ((i as f64) * 0.13 + j as f64).cos()
1956 });
1957 let hess = IdentityRowHessian::new(n, 1);
1958 let ops = vec![op(a.clone()), op(b.clone())];
1959 let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
1960 .expect("compile should succeed");
1961 let v_b = &compiled.blocks[1].t_lw;
1963 let m_b = compiled.blocks[1]
1964 .anchor_correction
1965 .as_ref()
1966 .expect("second block must carry an anchor correction");
1967 let b_v = b.dot(v_b);
1968 let a_m = a.dot(m_b);
1969 let b_compiled = &b_v - &a_m;
1970 let cross = a.t().dot(&b_compiled);
1972 let max_err = cross.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
1973 assert!(
1974 max_err < 1e-10,
1975 "orthogonality residual too large: {max_err:e}"
1976 );
1977 }
1978
1979 #[test]
1981 fn compile_three_block_chain() {
1982 let n = 80;
1983 let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.1 + j as f64).sin());
1984 let b = Array2::from_shape_fn((n, 2), |(i, j)| {
1985 0.3 * a[[i, 0]] + (j as f64) * (i as f64).cos()
1986 });
1987 let c = Array2::from_shape_fn((n, 2), |(i, j)| {
1988 0.2 * a[[i, 1]] + 0.4 * b[[i, 0]] + ((i + j) as f64).tan().min(5.0).max(-5.0)
1989 });
1990 let hess = IdentityRowHessian::new(n, 1);
1991 let ops = vec![op(a), op(b), op(c)];
1992 let compiled = compile(
1993 &ops,
1994 &hess,
1995 &[
1996 BlockOrder::Marginal,
1997 BlockOrder::Logslope,
1998 BlockOrder::LinkDev,
1999 ],
2000 )
2001 .expect("compile should succeed");
2002 let total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
2003 assert_eq!(
2004 compiled.joint_rank, total,
2005 "audit must report full rank on synthetic full-rank design"
2006 );
2007 }
2008
2009 #[test]
2015 fn compile_protected_keeps_rank_deficient_first_block_full_width() {
2016 let n = 40;
2017 let a = Array2::from_shape_fn((n, 2), |(i, _)| ((i + 1) as f64 * 0.31).sin());
2020 let b = Array2::from_shape_fn((n, 2), |(i, j)| ((i as f64) * 0.17 + j as f64).cos());
2021 let hess = IdentityRowHessian::new(n, 1);
2022 let ordering = [BlockOrder::Time, BlockOrder::Marginal];
2023
2024 let unprotected = compile(&[op(a.clone()), op(b.clone())], &hess, &ordering)
2025 .expect("unprotected compile");
2026 assert_eq!(
2027 unprotected.blocks[0].t_lw.ncols(),
2028 1,
2029 "unprotected first block drops its duplicate column"
2030 );
2031
2032 let protected = compile_protected(
2033 &[op(a.clone()), op(b.clone())],
2034 &hess,
2035 &ordering,
2036 &[true, false],
2037 )
2038 .expect("protected compile");
2039 let v_a = &protected.blocks[0].t_lw;
2040 assert_eq!(
2041 v_a.ncols(),
2042 2,
2043 "protected first block retains its full raw width"
2044 );
2045 for i in 0..2 {
2048 for j in 0..2 {
2049 let expect = if i == j { 1.0 } else { 0.0 };
2050 assert!(
2051 (v_a[[i, j]] - expect).abs() <= 1e-12,
2052 "protected first block V must be identity, got [{i},{j}]={}",
2053 v_a[[i, j]]
2054 );
2055 }
2056 }
2057 }
2058
2059 #[test]
2063 fn compile_weighted_metric_nontrivial() {
2064 let n = 32;
2065 let a: Array2<f64> = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64 + 1.0).sqrt());
2066 let b: Array2<f64> =
2067 Array2::from_shape_fn((n, 1), |(i, _)| 0.7 * a[[i, 0]] + (i as f64 * 0.05).cos());
2068 let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.2).sin().abs());
2069 let hess = DiagonalScalarRowHessian::new(w.clone());
2070 let ops = vec![op(a.clone()), op(b.clone())];
2071 let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2072 .expect("compile should succeed");
2073 let m = compiled.blocks[1]
2074 .anchor_correction
2075 .as_ref()
2076 .expect("anchor correction present");
2077 let analytic_num: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * b[[i, 0]]).sum();
2078 let analytic_den: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * a[[i, 0]]).sum();
2079 let analytic = analytic_num / analytic_den;
2080 assert!(m.dim() == (1, 1));
2081 assert!(
2082 (m[[0, 0]] - analytic).abs() < 1e-10,
2083 "weighted projection mismatch: got {got}, analytic {analytic}",
2084 got = m[[0, 0]]
2085 );
2086 }
2087
2088 #[test]
2098 fn compile_emits_anchor_correction_in_raw_column_coordinates() {
2099 let n = 64;
2100 let a: Array2<f64> = Array2::from_shape_fn((n, 3), |(i, j)| {
2104 let c0 = (i as f64 * 0.07 + 1.0).ln();
2105 let c1 = (i as f64 * 0.13).sin();
2106 match j {
2107 0 => c0,
2108 1 => c1,
2109 _ => 2.0 * c0 - 0.5 * c1,
2110 }
2111 });
2112 let c: Array2<f64> = Array2::from_shape_fn((n, 2), |(i, j)| {
2114 0.4 * a[[i, 0]] + (j as f64) * (i as f64 * 0.05).cos() + (i as f64 * 0.011).tanh()
2115 });
2116 let w = Array1::from_shape_fn(n, |i| 0.3 + (i as f64 * 0.17).sin().abs());
2117 let hess = DiagonalScalarRowHessian::new(w.clone());
2118 let ops = vec![op(a.clone()), op(c.clone())];
2119 let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::LinkDev])
2120 .expect("compile should succeed");
2121
2122 let v = &compiled.blocks[1].t_lw;
2123 let m = compiled.blocks[1]
2124 .anchor_correction
2125 .as_ref()
2126 .expect("candidate block must carry an anchor correction");
2127 let k_kept = v.ncols();
2128 assert!(k_kept >= 1, "candidate must keep at least one direction");
2129
2130 assert_eq!(
2133 m.nrows(),
2134 a.ncols(),
2135 "anchor_correction must be indexed by raw anchor columns (d_total), \
2136 got {} rows for {} raw anchor columns",
2137 m.nrows(),
2138 a.ncols(),
2139 );
2140 assert_eq!(m.ncols(), k_kept, "anchor_correction width must match V");
2141
2142 let c_v = c.dot(v);
2146 let a_m = a.dot(m);
2147 let c_tilde = &c_v - &a_m;
2148 let mut max_cross = 0.0_f64;
2149 for ac in 0..a.ncols() {
2150 for cc in 0..c_tilde.ncols() {
2151 let mut acc = 0.0;
2152 for i in 0..n {
2153 acc += w[i] * a[[i, ac]] * c_tilde[[i, cc]];
2154 }
2155 max_cross = max_cross.max(acc.abs());
2156 }
2157 }
2158 assert!(
2159 max_cross < 1e-9,
2160 "raw-coordinate anchor correction must W-orthogonalise the candidate \
2161 against the raw anchor span; max |Aᵀ W C̃| = {max_cross:e}"
2162 );
2163 }
2164
2165 #[test]
2168 fn compile_drops_trailing_pivots_from_latest_block() {
2169 let n = 40;
2170 let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
2171 let c = Array2::from_shape_fn((n, 2), |(i, j)| {
2176 if j == 0 {
2177 a[[i, 0]]
2178 } else {
2179 (i as f64 * 0.1).cos()
2180 }
2181 });
2182 let hess = IdentityRowHessian::new(n, 1);
2183 let ops = vec![op(a), op(c)];
2184 let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2190 .expect("compile should succeed");
2191 let v1_cols = compiled.blocks[1].t_lw.ncols();
2195 assert!(
2196 v1_cols < 2 || !compiled.dropped.is_empty(),
2197 "expected rank loss attributed to block 1, got v1_cols={v1_cols}, dropped={dropped:?}",
2198 dropped = compiled.dropped
2199 );
2200 for (block_idx, _) in &compiled.dropped {
2201 assert_eq!(
2202 *block_idx, 1,
2203 "audit drops must come from the latest block only"
2204 );
2205 }
2206 }
2207
2208 #[test]
2220 fn audit_truncation_keeps_t_lw_and_anchor_correction_in_lockstep() {
2221 let n = 40;
2222 let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
2223 let c = Array2::from_shape_fn((n, 2), |(i, j)| {
2224 if j == 0 {
2225 a[[i, 0]]
2226 } else {
2227 (i as f64 * 0.1).cos()
2228 }
2229 });
2230 let hess = IdentityRowHessian::new(n, 1);
2231 let ops = vec![op(a), op(c)];
2232 let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2233 .expect("compile should succeed");
2234 for (idx, block) in compiled.blocks.iter().enumerate() {
2235 let k_kept = block.t_lw.ncols();
2236 if let Some(m) = block.anchor_correction.as_ref() {
2237 assert_eq!(
2238 m.ncols(),
2239 k_kept,
2240 "block {idx}: anchor_correction.ncols()={ac} must equal t_lw.ncols()={k_kept} \
2241 after audit truncation",
2242 ac = m.ncols(),
2243 );
2244 }
2245 if let Some(r) = block.r_lw.as_ref() {
2246 assert_eq!(
2247 r.ncols(),
2248 k_kept,
2249 "block {idx}: r_lw.ncols()={r_cols} must equal t_lw.ncols()={k_kept} \
2250 after audit truncation",
2251 r_cols = r.ncols(),
2252 );
2253 }
2254 }
2255 }
2256
2257 #[test]
2262 fn compile_flex_anchor_is_first_class() {
2263 let n = 60;
2264 let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.07 + j as f64).sin());
2270 let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2271 0.4 * a[[i, 0]] + (j as f64) * (i as f64 + 1.0).ln()
2272 });
2273 let hess = IdentityRowHessian::new(n, 1);
2274
2275 let ops_param = vec![op(a.clone()), op(b.clone())];
2276 let compiled_param = compile(
2277 &ops_param,
2278 &hess,
2279 &[BlockOrder::Marginal, BlockOrder::Logslope],
2280 )
2281 .expect("compile should succeed");
2282
2283 let ops_flex = vec![op(a.clone()), op(b.clone())];
2287 let compiled_flex = compile(
2288 &ops_flex,
2289 &hess,
2290 &[BlockOrder::ScoreWarp, BlockOrder::LinkDev],
2291 )
2292 .expect("compile should succeed");
2293
2294 let m_param = compiled_param.blocks[1].anchor_correction.as_ref().unwrap();
2295 let m_flex = compiled_flex.blocks[1].anchor_correction.as_ref().unwrap();
2296 assert_eq!(m_param.dim(), m_flex.dim());
2297 let max_diff = (m_param - m_flex)
2298 .iter()
2299 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2300 assert!(
2301 max_diff < 1e-12,
2302 "flex vs parametric anchor correction mismatch: {max_diff:e}"
2303 );
2304 }
2305
2306 #[test]
2310 fn bernoulli_row_hessian_matches_irls_weight() {
2311 let w = Array1::from(vec![0.1, 0.5, 0.9, 0.25, 0.75]);
2312 let hess = DiagonalScalarRowHessian::new(w.clone());
2313 let full = hess.evaluate_full();
2314 assert_eq!(full.shape(), &[5, 1, 1]);
2315 for i in 0..5 {
2316 assert_eq!(full[[i, 0, 0]], w[i]);
2317 let mut buf = [0.0_f64; 1];
2318 hess.fill_row(i, &mut buf);
2319 assert_eq!(buf[0], w[i]);
2320 }
2321 }
2322
2323 #[test]
2327 fn compiler_predict_path_roundtrip() {
2328 let n = 24;
2329 let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.21).cos() + j as f64);
2330 let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2331 0.3 * a[[i, 0]] + (i as f64 + j as f64).sqrt()
2332 });
2333 let hess = IdentityRowHessian::new(n, 1);
2334 let ops = vec![op(a.clone()), op(b.clone())];
2335 let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2336 .expect("compile should succeed");
2337 let v_b = &compiled.blocks[1].t_lw;
2338 let m_b = compiled.blocks[1].anchor_correction.as_ref().unwrap();
2339 let predict_design = b.dot(v_b) - a.dot(m_b);
2341 assert_eq!(predict_design.nrows(), n);
2346 assert_eq!(predict_design.ncols(), v_b.ncols());
2347 for &val in predict_design.iter() {
2349 assert!(val.is_finite(), "predict design produced non-finite entry");
2350 }
2351 }
2352
2353 #[test]
2359 fn compile_exposes_r_lw_equal_to_m_dot_v() {
2360 let n = 40;
2361 let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.17 + j as f64).sin());
2362 let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2364 0.6 * a[[i, 0]] + ((i as f64) * 0.11 + j as f64).cos()
2365 });
2366 let hess = IdentityRowHessian::new(n, 1);
2367 let ops = vec![op(a.clone()), op(b.clone())];
2368 let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2369 .expect("compile should succeed");
2370
2371 assert!(compiled.blocks[0].r_lw.is_none());
2373 assert!(compiled.blocks[0].anchor_correction.is_none());
2374
2375 let v_a = &compiled.blocks[0].t_lw;
2378 let v_b = &compiled.blocks[1].t_lw;
2379 let m_compiled = compiled.blocks[1]
2380 .anchor_correction
2381 .as_ref()
2382 .expect("second block must carry an anchor correction");
2383 let r_lw = compiled.blocks[1]
2384 .r_lw
2385 .as_ref()
2386 .expect("second block must expose r_lw");
2387 let p_a_kept = v_a.ncols();
2388 let p_b_kept = v_b.ncols();
2389 assert_eq!(
2390 m_compiled.dim(),
2391 (p_a_kept, p_b_kept),
2392 "anchor_correction must be at compiled width"
2393 );
2394 assert_eq!(r_lw.dim(), (p_a_kept, p_b_kept));
2395 let diff = r_lw - m_compiled;
2397 let max_diff = diff.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
2398 assert!(
2399 max_diff == 0.0,
2400 "r_lw and anchor_correction must be identical"
2401 );
2402
2403 let b_compiled = b.dot(v_b) - a.dot(m_compiled);
2408 let cross = a.t().dot(&b_compiled);
2409 let max_cross = cross.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
2410 assert!(
2411 max_cross < 1e-10,
2412 "compiled B-design must be H-orthogonal to A: max cross = {max_cross:e}"
2413 );
2414 }
2415
2416 struct DenseRowHessian {
2418 h: Array3<f64>,
2419 }
2420
2421 impl RowHessian for DenseRowHessian {
2422 fn k(&self) -> usize {
2423 self.h.shape()[1]
2424 }
2425 fn nrows(&self) -> usize {
2426 self.h.shape()[0]
2427 }
2428 fn fill_row(&self, row: usize, out: &mut [f64]) {
2429 let k = self.k();
2430 assert_eq!(out.len(), k * k);
2431 for c in 0..k {
2432 for d in 0..k {
2433 out[c * k + d] = self.h[[row, c, d]];
2434 }
2435 }
2436 }
2437 fn evaluate_full(&self) -> Array3<f64> {
2438 self.h.clone()
2439 }
2440 }
2441
2442 fn reference_gram_from_w(j_full: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
2445 let w = scale_block_by_sqrt_h(j_full, h_full);
2446 fast_ata(&w)
2447 }
2448
2449 #[test]
2452 fn closed_form_gram_matches_reference_two_block_k4() {
2453 let n = 17;
2454 let k = 4;
2455 let p_a = 3;
2456 let p_b = 2;
2457
2458 let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2460 (0..4)
2461 .map(|c| {
2462 let m = Array2::from_shape_fn((n, p), |(i, j)| {
2463 ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
2464 });
2465 Some(m)
2466 })
2467 .collect()
2468 };
2469 let block_a = make_block(0.3, n, p_a);
2470 let block_b = make_block(1.1, n, p_b);
2471
2472 let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2474 let mut acc = 0.0;
2475 for r in 0..k {
2476 let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.13).cos();
2477 let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.13).cos();
2478 acc += mc * md;
2479 }
2480 acc + if c == d { 0.5 } else { 0.0 }
2481 });
2482 let row_hess = DenseRowHessian { h: h.clone() };
2483
2484 let channel_blocks = PrimaryChannelBlocks {
2485 blocks: vec![block_a.clone(), block_b.clone()],
2486 };
2487 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2488
2489 let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2490 .expect("closed-form Gram should succeed");
2491
2492 let p_total = p_a + p_b;
2495 let mut j_full = Array3::<f64>::zeros((n, p_total, k));
2496 for c in 0..k {
2497 if let Some(xa) = block_a[c].as_ref() {
2498 for i in 0..n {
2499 for j in 0..p_a {
2500 j_full[[i, j, c]] = xa[[i, j]];
2501 }
2502 }
2503 }
2504 if let Some(xb) = block_b[c].as_ref() {
2505 for i in 0..n {
2506 for j in 0..p_b {
2507 j_full[[i, p_a + j, c]] = xb[[i, j]];
2508 }
2509 }
2510 }
2511 }
2512 let ref_gram = reference_gram_from_w(&j_full, &h);
2513
2514 let diff = &gram - &ref_gram;
2515 let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2516 let scale = ref_gram.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2517 assert!(
2518 max_err < 1e-9 * scale.max(1.0),
2519 "closed-form Gram mismatches reference: max_err={max_err:e}, scale={scale:e}"
2520 );
2521
2522 for i in 0..p_total {
2524 for j in 0..p_total {
2525 assert!(
2526 (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
2527 "closed-form Gram not symmetric at ({i},{j})"
2528 );
2529 }
2530 }
2531 }
2532
2533 #[test]
2538 fn closed_form_gram_channel_sparsity() {
2539 let n = 13;
2540 let k = 4;
2541 let p_a = 2;
2542 let p_b = 2;
2543
2544 let xa = Array2::from_shape_fn((n, p_a), |(i, j)| ((i + 1) as f64 * 0.21 + j as f64).cos());
2545 let xb = Array2::from_shape_fn((n, p_b), |(i, j)| {
2546 ((i + 1) as f64 * 0.17 + j as f64).sin() + 0.5
2547 });
2548
2549 let block_a: Vec<Option<Array2<f64>>> = vec![Some(xa.clone()), None, None, None];
2550 let block_b: Vec<Option<Array2<f64>>> = vec![None, None, None, Some(xb.clone())];
2551
2552 let h_03_vec = Array1::from_shape_fn(n, |i| 0.7 + 0.3 * ((i as f64) * 0.4).sin());
2555 let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2556 if (c, d) == (0, 3) || (c, d) == (3, 0) {
2559 h_03_vec[i]
2560 } else if c == d {
2561 2.0
2562 } else {
2563 0.0
2564 }
2565 });
2566 let row_hess = DenseRowHessian { h: h.clone() };
2567
2568 let channel_blocks = PrimaryChannelBlocks {
2569 blocks: vec![block_a.clone(), block_b.clone()],
2570 };
2571 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2572 let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2573 .expect("closed-form Gram should succeed");
2574
2575 let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
2577 let expected = fast_xt_diag_y(&xa, &h_03_vec, &xb);
2579 let diff = &cross - &expected;
2580 let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2581 assert!(
2582 max_err < 1e-12,
2583 "cross-block Gram must equal Xaᵀ·diag(h_03)·Xb: max_err={max_err:e}"
2584 );
2585
2586 let h_zero = Array3::from_shape_fn((n, k, k), |(_, c, d)| if c == d { 2.0 } else { 0.0 });
2588 let row_hess_zero = DenseRowHessian { h: h_zero };
2589 let gram_zero =
2590 build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess_zero, &raw_ranges)
2591 .expect("closed-form Gram should succeed");
2592 let cross_zero = gram_zero.slice(s![0..p_a, p_a..(p_a + p_b)]);
2593 let max_zero = cross_zero.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2594 assert!(
2595 max_zero < 1e-12,
2596 "cross-block Gram must vanish when coupling channel pair is zero: got {max_zero:e}"
2597 );
2598 }
2599
2600 #[test]
2603 fn structural_gram_matches_within_channel_sum() {
2604 let n = 11;
2605 let p_a = 2;
2606 let p_b = 3;
2607 let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2608 (0..4)
2609 .map(|c| {
2610 if c == 1 {
2611 return None;
2613 }
2614 Some(Array2::from_shape_fn((n, p), |(i, j)| {
2615 ((i as f64 + 1.0) * (j as f64 + 1.0) + seed * (c as f64 + 1.0)).sin()
2616 }))
2617 })
2618 .collect()
2619 };
2620 let block_a = make_block(0.1, n, p_a);
2621 let block_b = make_block(0.7, n, p_b);
2622 let channel_blocks = PrimaryChannelBlocks {
2623 blocks: vec![block_a.clone(), block_b.clone()],
2624 };
2625 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2626 let gram = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2627
2628 let mut expected_cross = Array2::<f64>::zeros((p_a, p_b));
2631 for c in 0..4 {
2632 if let (Some(xa), Some(xb)) = (block_a[c].as_ref(), block_b[c].as_ref()) {
2633 expected_cross += &fast_atb(xa, xb);
2634 }
2635 }
2636 let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
2637 let diff = &cross - &expected_cross;
2638 let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2639 assert!(
2640 max_err < 1e-12,
2641 "structural cross-block must equal Σ_c Xaᵀ·Xb: max_err={max_err:e}"
2642 );
2643
2644 for i in 0..(p_a + p_b) {
2646 for j in 0..(p_a + p_b) {
2647 assert!(
2648 (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
2649 "structural Gram not symmetric at ({i},{j})"
2650 );
2651 }
2652 }
2653 }
2654
2655 fn diag_hess(w: Array1<f64>) -> DiagonalScalarRowHessian {
2659 DiagonalScalarRowHessian::new(w)
2660 }
2661
2662 #[test]
2666 fn dual_metric_with_equal_metrics_matches_single_metric() {
2667 let n = 36;
2668 let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.13 + j as f64).sin());
2669 let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2671 0.4 * a[[i, 0]] + (i as f64 * 0.07 + j as f64).cos()
2672 });
2673 let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.17).sin().abs());
2674 let curvature = diag_hess(w.clone());
2675 let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2676
2677 let ops_single = vec![op(a.clone()), op(b.clone())];
2678 let single = compile(&ops_single, &curvature, &ordering)
2679 .expect("single-metric compile should succeed");
2680
2681 let structural_same = diag_hess(w.clone());
2684 let ops_dual = vec![op(a.clone()), op(b.clone())];
2685 let dual = compile_with_dual_metric(&ops_dual, &curvature, &structural_same, &ordering)
2686 .expect("dual-metric compile should succeed");
2687
2688 assert_eq!(single.blocks.len(), dual.blocks.len());
2689 for (idx, (sb, db)) in single.blocks.iter().zip(dual.blocks.iter()).enumerate() {
2690 assert_eq!(sb.t_lw.dim(), db.t_lw.dim(), "block {idx}: V dims differ");
2691 let max_v = (&sb.t_lw - &db.t_lw)
2692 .iter()
2693 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2694 assert!(max_v < 1e-10, "block {idx}: V mismatch {max_v:e}");
2695 match (sb.anchor_correction.as_ref(), db.anchor_correction.as_ref()) {
2696 (None, None) => {}
2697 (Some(s), Some(d)) => {
2698 assert_eq!(s.dim(), d.dim());
2699 let max_m = (s - d).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2700 assert!(max_m < 1e-10, "block {idx}: M mismatch {max_m:e}");
2701 }
2702 _ => panic!("block {idx}: one side has anchor correction, the other does not"),
2703 }
2704 }
2705 assert_eq!(single.joint_rank, dual.joint_rank);
2706 }
2707
2708 #[test]
2723 fn dual_metric_resists_pilot_curvature_alias() {
2724 let n = 12;
2725 let a = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) + 1.0);
2728 let b = Array2::from_shape_fn((n, 1), |(i, _)| {
2729 if i < 6 {
2730 2.0 * a[[i, 0]]
2731 } else {
2732 ((i as f64) * 0.3).cos() + 0.5
2733 }
2734 });
2735
2736 let mut w_vec = vec![0.0_f64; n];
2741 for w in &mut w_vec[..6] {
2742 *w = 1.0;
2743 }
2744 let w = Array1::from(w_vec);
2745 let curvature = diag_hess(w.clone());
2746
2747 let id_struct = IdentityRowHessian::new(n, 1);
2751 let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2752
2753 let ops_dual = vec![op(a.clone()), op(b.clone())];
2757 let dual = compile_with_dual_metric(&ops_dual, &curvature, &id_struct, &ordering);
2758
2759 let ops_h_only = vec![op(a.clone()), op(b.clone())];
2763 let h_only = compile_with_dual_metric(&ops_h_only, &curvature, &curvature, &ordering);
2764
2765 match h_only {
2768 Err(CompilerError::FullyAliased { block_idx, .. }) => {
2769 assert_eq!(block_idx, 1, "H-only path must alias block 1");
2770 }
2771 Ok(out) => {
2772 let v1_cols = out.blocks[1].t_lw.ncols();
2777 assert!(
2778 v1_cols == 0 || !out.dropped.is_empty(),
2779 "H-only path should reject B's curvature-aliased column; v1_cols={v1_cols}, dropped={dropped:?}",
2780 dropped = out.dropped,
2781 );
2782 }
2783 Err(other) => panic!("unexpected H-only error: {other:?}"),
2784 }
2785
2786 let dual =
2787 dual.expect("dual-metric must succeed: identity-structural sees B as independent");
2788 assert_eq!(dual.blocks.len(), 2);
2797 assert_eq!(dual.blocks[0].t_lw.ncols(), 1, "A must keep its column");
2798 let v1_post_audit = dual.blocks[1].t_lw.ncols();
2803 let dropped_count = dual.dropped.len();
2804 assert_eq!(
2805 v1_post_audit + dropped_count,
2806 1,
2807 "structural pass kept B's column; audit may demote it but the pre-audit width was 1"
2808 );
2809 }
2810
2811 #[test]
2820 fn dual_metric_identity_structural_preserves_full_rank() {
2821 let n = 24;
2822 let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 + j as f64).sqrt());
2823 let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2824 ((i + 1) as f64).ln() + (i as f64 * 0.1 + j as f64).cos()
2825 });
2826 let w = Array1::from_shape_fn(n, |i| 0.4 + (i as f64 * 0.05).sin().powi(2));
2827 let curvature = diag_hess(w.clone());
2828 let id_struct = IdentityRowHessian::new(n, 1);
2829 let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2830
2831 let ops = vec![op(a.clone()), op(b.clone())];
2832 let out =
2833 compile_with_dual_metric(&ops, &curvature, &id_struct, &ordering).expect("compile");
2834 assert_eq!(out.blocks[0].t_lw.ncols(), 2);
2836 assert_eq!(out.blocks[1].t_lw.ncols(), 2);
2837 assert_eq!(out.dropped.len(), 0);
2838 assert_eq!(out.joint_rank, 4);
2839 }
2840
#[test]
/// Checks that `build_primary_grams_gpu_or_cpu` returns the same curvature
/// and structural Gram matrices as the CPU builders
/// (`build_raw_grams_from_channel_blocks` / `build_raw_grams_structural`)
/// for two blocks with `k = 4` channels, to a relative tolerance of 1e-9.
fn build_primary_grams_gpu_or_cpu_two_block_k4_matches_cpu() {
    let n = 11;
    let k: usize = 4;
    let p_a = 2;
    let p_b = 3;

    // One dense (rows x cols) design per channel. The channel count follows
    // `k` so it cannot drift away from the (n, k, k) Hessian built below.
    let make_block = |seed: f64, rows: usize, cols: usize| -> Vec<Option<Array2<f64>>> {
        (0..k)
            .map(|c| {
                Some(Array2::from_shape_fn((rows, cols), |(i, j)| {
                    ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
                }))
            })
            .collect()
    };
    let block_a = make_block(0.7, n, p_a);
    let block_b = make_block(-0.4, n, p_b);

    // Per-row k x k Hessian: a sum of outer products plus 0.25 on the
    // diagonal, so every row block is symmetric with a positive diagonal shift.
    let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
        let mut acc = 0.0;
        for r in 0..k {
            let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.11).cos();
            let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.11).cos();
            acc += mc * md;
        }
        acc + if c == d { 0.25 } else { 0.0 }
    });
    let row_hess = DenseRowHessian { h };

    let channel_blocks = PrimaryChannelBlocks {
        blocks: vec![block_a, block_b],
    };
    let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];

    let (gram_h, gram_struct) =
        build_primary_grams_gpu_or_cpu(&channel_blocks, &row_hess, &raw_ranges)
            .expect("dispatch helper should succeed");

    let cpu_h = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
        .expect("CPU curvature Gram should succeed");
    let cpu_s = build_raw_grams_structural(&channel_blocks, &raw_ranges);

    let tol = 1e-9_f64;

    // The element-wise loops only visit the reference's indices, so compare
    // the shapes first; otherwise extra helper entries would go unchecked.
    assert_eq!(
        gram_h.dim(),
        cpu_h.dim(),
        "gram_h shape differs from the CPU reference"
    );
    for (idx, &expected) in cpu_h.indexed_iter() {
        let actual = gram_h[idx];
        let diff = (actual - expected).abs();
        let scale = expected.abs().max(1.0);
        assert!(
            diff <= tol * scale,
            "gram_h mismatch at {idx:?}: helper={} cpu={}",
            actual,
            expected
        );
    }

    assert_eq!(
        gram_struct.dim(),
        cpu_s.dim(),
        "gram_struct shape differs from the CPU reference"
    );
    for (idx, &expected) in cpu_s.indexed_iter() {
        let actual = gram_struct[idx];
        let diff = (actual - expected).abs();
        let scale = expected.abs().max(1.0);
        assert!(
            diff <= tol * scale,
            "gram_struct mismatch at {idx:?}: helper={} cpu={}",
            actual,
            expected
        );
    }
}
2912
2913 fn scalar_grams_two_block(
2919 a: &Array2<f64>,
2920 b: &Array2<f64>,
2921 w: &Array1<f64>,
2922 ) -> (Array2<f64>, Array2<f64>, Vec<std::ops::Range<usize>>) {
2923 let p_a = a.ncols();
2924 let p_b = b.ncols();
2925 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2926 let channel_blocks = PrimaryChannelBlocks {
2927 blocks: vec![vec![Some(a.clone())], vec![Some(b.clone())]],
2928 };
2929 let row_hess = DiagonalScalarRowHessian::new(w.clone());
2930 let gram_h =
2931 build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges).unwrap();
2932 let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2933 (gram_h, gram_struct, raw_ranges)
2934 }
2935
#[test]
/// The second block is the first block multiplied by a 2x2 matrix, so every
/// one of its columns is a linear combination of the first block's columns.
/// The compile must succeed, give the second block zero width, and leave
/// only (near-)zero entries in that block's rows of `raw_from_compiled`.
fn compile_from_raw_grams_full_structural_alias() {
    let n = 10;
    let base = Array2::from_shape_fn((n, 2), |(row, col)| {
        ((row + 1) as f64 * (col + 1) as f64).sin()
    });
    let mix = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, -0.25, 1.0]).unwrap();
    let aliased = base.dot(&mix);
    let unit_weights = Array1::ones(n);

    let (gram_h, gram_struct, raw_ranges) =
        scalar_grams_two_block(&base, &aliased, &unit_weights);
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
    let res = compile_from_raw_grams(&gram_h, &gram_struct, &raw_ranges, &ordering)
        .expect("lower-priority full alias should compile to zero width");

    assert_eq!(res.compiled_block_ranges[0].len(), 2);
    assert_eq!(res.compiled_block_ranges[1].len(), 0);
    assert_eq!(res.raw_from_compiled.dim(), (4, 2));

    // Rows of the map that belong to the aliased block's raw coefficients.
    let aliased_rows = res.raw_from_compiled.slice(s![raw_ranges[1].clone(), ..]);
    let rows_are_zero = aliased_rows.iter().all(|entry| entry.abs() <= 1.0e-12);
    assert!(
        rows_are_zero,
        "zero-width block must not retain raw coefficient directions in T"
    );
}
2966
#[test]
/// A first block with zero columns must compile without error: it keeps zero
/// columns, the second block keeps both of its columns, and the resulting
/// map is 2x2.
fn compile_from_raw_grams_zero_width_first_block_is_identifiable() {
    let n = 12;
    let no_columns = Array2::<f64>::zeros((n, 0));
    let second = Array2::from_shape_fn((n, 2), |(row, col)| {
        ((row + 1) as f64 * (col + 1) as f64 * 0.23).cos()
    });
    let unit_weights = Array1::ones(n);

    let (gram_h, gram_struct, raw_ranges) =
        scalar_grams_two_block(&no_columns, &second, &unit_weights);
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
    let map = compile_from_raw_grams(&gram_h, &gram_struct, &raw_ranges, &ordering)
        .expect("zero-width first block must be trivially identifiable, not FullyAliased");

    let first_width = map.compiled_block_ranges[0].len();
    assert_eq!(first_width, 0, "empty first block keeps zero columns");

    let second_width = map.compiled_block_ranges[1].len();
    assert_eq!(
        second_width, 2,
        "the second block keeps its full structural rank"
    );

    assert_eq!(map.raw_from_compiled.dim(), (2, 2));
}
3001
#[test]
/// The first block has two identical columns. The unprotected compile keeps
/// one column of it. With the protection flags `[true, false]` the block
/// keeps both columns and its part of `raw_from_compiled` is the 2x2
/// identity.
fn compile_from_raw_grams_protected_keeps_full_rank_deficient_first_block() {
    let n = 14;
    // The column index is ignored, so both columns hold the same values.
    let duplicated = Array2::from_shape_fn((n, 2), |(row, _)| ((row + 1) as f64 * 0.37).sin());
    let other = Array2::from_shape_fn((n, 2), |(row, col)| {
        let freq = 0.29 + col as f64 * 0.11;
        ((row + 1) as f64 * freq).cos()
    });
    let unit_weights = Array1::ones(n);
    let (gram_h, gram_struct, raw_ranges) =
        scalar_grams_two_block(&duplicated, &other, &unit_weights);
    let ordering = [BlockOrder::Time, BlockOrder::Marginal];

    let unprotected = compile_from_raw_grams(&gram_h, &gram_struct, &raw_ranges, &ordering)
        .expect("unprotected compile");
    assert_eq!(
        unprotected.compiled_block_ranges[0].len(),
        1,
        "unprotected first block drops its structural-null direction"
    );

    let protection = [true, false];
    let protected = compile_from_raw_grams_protected(
        &gram_h,
        &gram_struct,
        &raw_ranges,
        &ordering,
        &protection,
    )
    .expect("protected compile");
    assert_eq!(
        protected.compiled_block_ranges[0].len(),
        2,
        "protected first block retains its full raw width"
    );

    // The width was just asserted to be 2, so this view is 2x2 and walking
    // it visits exactly the entries (0,0), (0,1), (1,0), (1,1).
    let first_block_map = protected
        .raw_from_compiled
        .slice(s![0..2, protected.compiled_block_ranges[0].clone()]);
    for ((i, j), &entry) in first_block_map.indexed_iter() {
        let expect = match i == j {
            true => 1.0,
            false => 0.0,
        };
        assert!(
            (entry - expect).abs() <= 1e-12,
            "protected first block map must be identity, got [{i},{j}]={}",
            entry
        );
    }
}
3060
#[test]
/// Runs `orthogonalize_design_blocks` on three blocks: an anchor, an exact
/// copy of the anchor, and an unrelated single column. Expects the anchor
/// and the unrelated column to be annotated `Independent`, the copy to be
/// annotated `FullyAbsorbedByHigherPriority` with all of its width absorbed,
/// and `dropped` to equal `[(1, 2)]`.
fn orthogonalization_annotates_independent_and_fully_absorbed_blocks() {
    let n = 18;
    let anchor = Array2::from_shape_fn((n, 2), |(row, col)| {
        let freq = 0.19 + col as f64 * 0.07;
        ((row + 1) as f64 * freq).sin()
    });
    let duplicate = anchor.clone();
    let independent = Array2::from_shape_fn((n, 1), |(row, _)| ((row + 1) as f64 * 0.43).cos());
    let weight = vec![1.0; n];

    let designs = [anchor, duplicate, independent];
    let ortho = orthogonalize_design_blocks(&designs, &[200, 100, 50], &weight)
        .expect("structural annotation compile");
    let notes = &ortho.direction_annotations;

    // Block 0: the anchor.
    assert_eq!(
        notes[0].kind,
        PenalizedDirectionAnnotationKind::Independent
    );
    assert_eq!(notes[0].absorbed_width, 0);

    // Block 1: the exact copy of the anchor.
    assert_eq!(
        notes[1].kind,
        PenalizedDirectionAnnotationKind::FullyAbsorbedByHigherPriority,
        "a duplicated lower-priority block is the same realized-design direction"
    );
    assert_eq!(notes[1].raw_width, 2);
    assert_eq!(notes[1].kept_width, 0);
    assert_eq!(notes[1].absorbed_width, 2);

    // Block 2: the unrelated column.
    assert_eq!(
        notes[2].kind,
        PenalizedDirectionAnnotationKind::Independent,
        "a genuinely new realized-design direction keeps its own penalty block"
    );
    assert_eq!(notes[2].raw_width, 1);
    assert_eq!(notes[2].kept_width, 1);

    assert_eq!(ortho.dropped, vec![(1, 2)]);
}
3099
#[test]
/// Three blocks where the last one (logslope) is an exact copy of the second
/// (marginal). The compile must succeed, give the logslope block zero
/// width, and yield a compiled design `X_raw * T` whose RRQR rank equals its
/// column count.
fn compile_from_raw_grams_three_block_full_logslope_alias_keeps_fast_path() {
    let n = 24;
    let time = Array2::from_shape_fn((n, 2), |(row, col)| {
        ((row + 1) as f64 * (col + 2) as f64 * 0.17).sin()
    });
    let marginal = Array2::from_shape_fn((n, 1), |(row, _)| ((row + 3) as f64 * 0.11).cos());
    let logslope = marginal.clone();

    // Raw column layout: time | marginal | logslope.
    let (p_time, p_marg, p_log) = (time.ncols(), marginal.ncols(), logslope.ncols());
    let marg_end = p_time + p_marg;
    let p_raw = marg_end + p_log;
    let raw_ranges = vec![0..p_time, p_time..marg_end, marg_end..p_raw];

    let channel_blocks = PrimaryChannelBlocks {
        blocks: vec![
            vec![Some(time.clone())],
            vec![Some(marginal.clone())],
            vec![Some(logslope.clone())],
        ],
    };
    let row_hess = DiagonalScalarRowHessian::new(Array1::ones(n));
    let gram_h =
        build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges).unwrap();
    let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);

    let ordering = [BlockOrder::Time, BlockOrder::Marginal, BlockOrder::Logslope];
    let map = compile_from_raw_grams(&gram_h, &gram_struct, &raw_ranges, &ordering)
        .expect("fully aliased logslope block should not skip the compiled-map path");

    assert_eq!(map.compiled_block_ranges[0].len(), p_time);
    assert_eq!(map.compiled_block_ranges[1].len(), p_marg);
    assert_eq!(map.compiled_block_ranges[2].len(), 0);
    assert_eq!(map.raw_from_compiled.dim(), (p_raw, marg_end));

    // Assemble the raw design by writing each block into its column range.
    let mut x_raw = Array2::<f64>::zeros((n, p_raw));
    for (range, design) in raw_ranges.iter().zip([&time, &marginal, &logslope]) {
        x_raw.slice_mut(s![.., range.clone()]).assign(design);
    }

    let x_compiled = fast_ab(&x_raw, &map.raw_from_compiled);
    let rrqr = rrqr_with_permutation(&x_compiled, default_rrqr_rank_alpha()).unwrap();
    assert_eq!(rrqr.rank, x_compiled.ncols());
}
3156
#[test]
/// Block B shares its first column with block A and has one new column.
/// Checks that `compile_from_raw_grams` keeps all of A plus one column of
/// B, that the Gram of the compiled design has numerical rank `p_a + 1`, and
/// that `compile_with_dual_metric` (diagonal curvature, identity structural)
/// arrives at the same total width.
fn compile_from_raw_grams_partial_alias_matches_w_reference() {
    let n = 25;
    let a = Array2::from_shape_fn((n, 2), |(row, col)| {
        ((row + 1) as f64 * (col + 1) as f64 * 0.3).sin()
    });
    // Column 0 repeats A's column 0; column 1 is a fresh cosine.
    let b = Array2::from_shape_fn((n, 2), |(row, col)| {
        if col == 0 {
            a[[row, 0]]
        } else {
            ((row + 1) as f64 * 0.7).cos()
        }
    });
    let w = Array1::from_shape_fn(n, |row| 1.0 + 0.1 * (row as f64));

    let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
    let compiled = compile_from_raw_grams(&gram_h, &gram_struct, &raw_ranges, &ordering)
        .expect("closed-form compile must succeed");

    let p_a = a.ncols();
    let p_b = b.ncols();
    assert_eq!(compiled.raw_from_compiled.nrows(), p_a + p_b);
    assert_eq!(
        compiled.raw_from_compiled.ncols(),
        p_a + 1,
        "partial alias should leave compiled width = p_a + 1 (one column dropped from B)"
    );
    assert_eq!(compiled.compiled_block_ranges[0], 0..p_a);
    assert_eq!(compiled.compiled_block_ranges[1].len(), 1);

    // Raw design [A | B], pushed through the compiled map.
    let mut x_raw = Array2::<f64>::zeros((n, p_a + p_b));
    x_raw.slice_mut(s![.., 0..p_a]).assign(&a);
    x_raw.slice_mut(s![.., p_a..(p_a + p_b)]).assign(&b);
    let x_compiled = fast_ab(&x_raw, &compiled.raw_from_compiled);
    let g_compiled = fast_ata(&x_compiled);

    // Numerical rank: count eigenvalues above a threshold relative to the
    // largest one.
    let (evals, _) = g_compiled.eigh(Side::Lower).unwrap();
    let lam_max = evals.iter().copied().fold(0.0_f64, f64::max);
    let tol = lam_max * 64.0 * (g_compiled.nrows() as f64) * f64::EPSILON;
    let rank_compiled = evals.iter().filter(|&&lam| lam > tol).count();
    assert_eq!(
        rank_compiled,
        p_a + 1,
        "compiled design column rank must equal p_a + 1 after dropping the alias"
    );

    // Reference path through the operator-based dual-metric compiler.
    let curvature = DiagonalScalarRowHessian::new(w);
    let id_struct = IdentityRowHessian::new(n, 1);
    let ops_dual: Vec<Arc<dyn RowJacobianOperator>> = vec![op(a), op(b)];
    let dual = compile_with_dual_metric(&ops_dual, &curvature, &id_struct, &ordering)
        .expect("dual metric compile should succeed");
    let dual_total: usize = dual.blocks.iter().map(|blk| blk.t_lw.ncols()).sum();
    assert_eq!(dual_total, p_a + 1, "W-reference total width should match");
}
3238
#[test]
/// Three 2-column blocks where B repeats A's column 0 and C repeats A's
/// column 1. Compiles them once in the order A, B, C and once in the order
/// B, A, C, checks the per-block widths asserted below, and checks that both
/// orders keep 4 columns in total.
fn compile_from_raw_grams_three_block_ordering_matters() {
    let n = 30;
    let a = Array2::from_shape_fn((n, 2), |(row, col)| {
        ((row + 1) as f64 * (col + 2) as f64 * 0.2).sin()
    });
    // B: a fresh cosine column followed by a copy of A's column 0.
    let b = Array2::from_shape_fn((n, 2), |(row, col)| match col {
        0 => ((row + 1) as f64 * 0.4).cos(),
        _ => a[[row, 0]],
    });
    // C: a fresh sine column followed by a copy of A's column 1.
    let c = Array2::from_shape_fn((n, 2), |(row, col)| match col {
        0 => ((row + 1) as f64 * 0.55).sin(),
        _ => a[[row, 1]],
    });
    let w = Array1::<f64>::ones(n);

    // Gram matrices and raw ranges for three blocks laid out in the order given.
    let build = |b0: &Array2<f64>, b1: &Array2<f64>, b2: &Array2<f64>| {
        let end0 = b0.ncols();
        let end1 = end0 + b1.ncols();
        let end2 = end1 + b2.ncols();
        let raw_ranges = vec![0..end0, end0..end1, end1..end2];
        let channel_blocks = PrimaryChannelBlocks {
            blocks: vec![
                vec![Some(b0.clone())],
                vec![Some(b1.clone())],
                vec![Some(b2.clone())],
            ],
        };
        let row_hess = DiagonalScalarRowHessian::new(w.clone());
        let gram_h = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
            .unwrap();
        let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
        (gram_h, gram_struct, raw_ranges)
    };

    let ordering = [
        BlockOrder::Marginal,
        BlockOrder::Logslope,
        BlockOrder::LinkDev,
    ];

    // Order A, B, C.
    let (gh, gs, rr) = build(&a, &b, &c);
    let order_abc = compile_from_raw_grams(&gh, &gs, &rr, &ordering).expect("ABC compile");
    assert_eq!(order_abc.compiled_block_ranges[0].len(), 2);
    assert_eq!(order_abc.compiled_block_ranges[1].len(), 1);
    assert_eq!(order_abc.compiled_block_ranges[2].len(), 1);

    // Order B, A, C.
    let (gh2, gs2, rr2) = build(&b, &a, &c);
    let order_bac = compile_from_raw_grams(&gh2, &gs2, &rr2, &ordering).expect("BAC compile");
    assert_eq!(order_bac.compiled_block_ranges[0].len(), 2);
    assert_eq!(order_bac.compiled_block_ranges[1].len(), 1);

    let total_abc: usize = order_abc
        .compiled_block_ranges
        .iter()
        .fold(0, |acc, r| acc + r.len());
    let total_bac: usize = order_bac
        .compiled_block_ranges
        .iter()
        .fold(0, |acc, r| acc + r.len());
    assert_eq!(total_abc, total_bac);
    assert_eq!(total_abc, 4);
}
3329
3330 fn k1_grams(x: &Array2<f64>, w: &Array1<f64>) -> (Array2<f64>, Array2<f64>) {
3335 let gram_struct = fast_atb(x, x);
3336 let xw = fast_xt_diag_y(x, w, x);
3337 (xw, gram_struct)
3338 }
3339
#[test]
/// Compiles a 4-column, two-block design that is expected to keep every
/// column, solves the normal equations `(TᵀT) θ = Tᵀ β` for a chosen raw
/// coefficient vector `β`, and checks that `lift_coefficients(θ)` reproduces
/// `β` to within 1e-8.
fn compiled_map_lift_coefficients_roundtrips_full_rank() {
    let n = 21;
    let p_a = 2;
    let p_b = 2;
    let p_total = p_a + p_b;

    let x = Array2::from_shape_fn((n, p_total), |(row, col)| {
        ((row as f64 + 1.0) * (0.21 + 0.17 * col as f64)).sin() + 0.11 * (col as f64)
    });
    let w = Array1::from_shape_fn(n, |row| 0.5 + 0.5 * ((row as f64) * 0.3).cos().abs());
    let (gh, gs) = k1_grams(&x, &w);
    let raw_ranges = vec![0..p_a, p_a..p_total];
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
    let map = compile_from_raw_grams(&gh, &gs, &raw_ranges, &ordering).expect("full-rank compile");

    assert_eq!(map.p_compiled(), p_total);
    assert_eq!(map.p_raw(), p_total);

    // Least-squares pre-image of beta_raw under T, via the normal equations.
    let beta_raw = Array1::from_shape_fn(p_total, |col| 0.4 * (col as f64) - 0.7);
    let t = &map.raw_from_compiled;
    let tt = fast_atb(t, t);
    let rhs = t.t().dot(&beta_raw).insert_axis(Axis(1));
    let solution = solve_psd_system(&tt, &rhs).expect("normal-equation solve");
    let theta = solution.column(0).to_owned();

    let lifted = map.lift_coefficients(&theta).expect("lift");
    let residual = &lifted - &beta_raw;
    let max_err = residual.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
    assert!(
        max_err < 1e-8,
        "lift round-trip error {max_err:e} (full-rank reduction must be exactly invertible)"
    );
}
3392
#[test]
/// With one column of the second block copied from the first block, checks
/// that predicting with the reduced design (`reduce_design(x) · θ`) matches
/// predicting with the raw design and lifted coefficients
/// (`x · lift_coefficients(θ)`) to within 1e-9.
fn compiled_map_reduce_design_matches_lifted_raw_predictor() {
    let n = 23;
    let p_a = 3;
    let p_b = 3;
    let p_total = p_a + p_b;

    let mut x = Array2::from_shape_fn((n, p_total), |(row, col)| {
        ((row as f64 + 1.0) * 0.41 + (col as f64 + 1.0) * 0.7).sin() + 0.05 * (row % 3) as f64
    });
    // Overwrite the second block's column 1 with the first block's column 1.
    let shared = x.column(1).to_owned();
    x.column_mut(p_a + 1).assign(&shared);

    let w = Array1::from_shape_fn(n, |row| 0.6 + 0.4 * ((row as f64) * 0.25).cos().abs());
    let (gh, gs) = k1_grams(&x, &w);
    let raw_ranges = vec![0..p_a, p_a..p_total];
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
    let map = compile_from_raw_grams(&gh, &gs, &raw_ranges, &ordering).expect("compile");

    let x_compiled = map.reduce_design(&x).expect("reduce_design");
    assert_eq!(x_compiled.ncols(), map.p_compiled());

    let theta = Array1::from_shape_fn(map.p_compiled(), |col| 0.3 * (col as f64) - 0.5);
    let pred_compiled = x_compiled.dot(&theta);
    let beta_raw = map.lift_coefficients(&theta).expect("lift");
    let pred_raw = x.dot(&beta_raw);

    let gap = &pred_compiled - &pred_raw;
    let max_err = gap.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
    assert!(
        max_err < 1e-9,
        "compiled-design predictor diverges from lifted raw predictor: {max_err:e}"
    );
}
3434
#[test]
/// With the second block's first column copied from the first block, checks
/// that for each block the raw penalty energy `βᵀ S β` of the lifted
/// coefficients equals the reduced penalty energy `θᵀ S_reduced θ`, using
/// identity penalties on both blocks.
fn reduce_penalties_with_map_preserves_energy_on_lift() {
    let n = 19;
    let p_a = 3;
    let p_b = 2;
    let p_total = p_a + p_b;

    let mut x = Array2::from_shape_fn((n, p_total), |(row, col)| {
        ((row as f64 + 1.0) * 0.29 + (col as f64 + 1.0) * 0.9).cos()
    });
    // The second block's first column repeats the first block's column 0.
    let shared = x.column(0).to_owned();
    x.column_mut(p_a).assign(&shared);

    let w = Array1::from_shape_fn(n, |row| 0.7 + 0.3 * ((row as f64) * 0.2).sin().abs());
    let (gh, gs) = k1_grams(&x, &w);
    let raw_ranges = vec![0..p_a, p_a..p_total];
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
    let map =
        compile_from_raw_grams(&gh, &gs, &raw_ranges, &ordering).expect("compile with alias");
    assert!(
        map.p_compiled() < p_a + p_b,
        "expected at least one absorbed column, got p_compiled={}",
        map.p_compiled()
    );

    let s_a = Array2::<f64>::eye(p_a);
    let s_b = Array2::<f64>::eye(p_b);
    let reduced = reduce_penalties_with_map(&map, &[Some(s_a.clone()), Some(s_b.clone())])
        .expect("reduce penalties");

    let theta = Array1::from_shape_fn(map.p_compiled(), |col| {
        0.6 * (col as f64) - 0.3 + 0.05 * (col % 2) as f64
    });
    let beta = map.lift_coefficients(&theta).expect("lift");

    for (block_idx, s_raw) in [&s_a, &s_b].iter().enumerate() {
        // Raw-side energy uses only this block's slice of the lifted vector.
        let range = &map.raw_block_ranges[block_idx];
        let beta_b = beta.slice(s![range.start..range.end]).to_owned();
        let raw_energy = beta_b.dot(&s_raw.dot(&beta_b));

        // The reduced penalty acts on the whole compiled coefficient vector.
        let s_reduced = reduced[block_idx]
            .as_ref()
            .expect("reduced penalty present");
        let reduced_energy = theta.dot(&s_reduced.dot(&theta));

        assert!(
            (raw_energy - reduced_energy).abs() < 1e-8 * raw_energy.abs().max(1.0),
            "block {block_idx} energy mismatch: raw={raw_energy:e} reduced={reduced_energy:e}"
        );
    }
}
3494}