1use std::ops::Range;
15use std::sync::Arc;
16
17use ndarray::{Array1, Array2, Array3, Axis, s};
18
19use faer::Side;
20use gam_linalg::faer_ndarray::{
21 FaerEigh, default_rrqr_rank_alpha, fast_ab, fast_ata, fast_atb, fast_xt_diag_y,
22 rrqr_with_permutation,
23};
24
/// Multiplier applied to `f64::EPSILON` when forming rank-revealing eigenvalue cut-offs.
///
/// `solve_psd_system` zeroes the inverse of every eigenvalue not strictly above
/// `lambda_max * RANK_REVEAL_EPS_SLACK * n * f64::EPSILON`, and `keep_positive_eigenspace`
/// keeps only eigenvalues strictly above `scale * RANK_REVEAL_EPS_SLACK * nk * f64::EPSILON`.
const RANK_REVEAL_EPS_SLACK: f64 = 64.0;
34
35pub trait RowJacobianOperator: Send + Sync {
42 fn k(&self) -> usize;
44
45 fn ncols(&self) -> usize;
47
48 fn nrows(&self) -> usize;
50
51 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]);
53
54 fn evaluate_full(&self) -> Array3<f64>;
56
57 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
73 scale_block_by_sqrt_h(&self.evaluate_full(), h_full)
74 }
75
76 fn channel_flattened_column(&self, col: usize, out: &mut [f64]) {
90 let k = self.k();
91 let n = self.nrows();
92 assert!(
93 col < self.ncols(),
94 "channel_flattened_column col {col} out of range {}",
95 self.ncols()
96 );
97 assert_eq!(
98 out.len(),
99 n * k,
100 "channel_flattened_column out length {} != n*k = {}*{}",
101 out.len(),
102 n,
103 k
104 );
105 let full = self.evaluate_full();
106 for i in 0..n {
107 for ch in 0..k {
108 out[i * k + ch] = full[[i, col, ch]];
109 }
110 }
111 }
112
113 fn channel_flattened_rows(&self, rows: Range<usize>, out: &mut Array2<f64>) {
120 let n = self.nrows();
121 let start = rows.start.min(n);
122 let end = rows.end.min(n);
123 let chunk = end - start;
124 let k = self.k();
125 let p = self.ncols();
126 assert_eq!(out.shape(), &[chunk * k, p]);
127 let full = self.evaluate_full();
128 for local_i in 0..chunk {
129 let row = start + local_i;
130 for ch in 0..k {
131 for col in 0..p {
132 out[[local_i * k + ch, col]] = full[[row, col, ch]];
133 }
134 }
135 }
136 }
137}
138
/// Per-row `K x K` weight matrices.
///
/// This module uses the trait both for the curvature Hessian and for the structural metric
/// (see `compile_with_dual_metric`).
pub trait RowHessian: Send + Sync {
    /// Number of channels `K`; each row's matrix is `K x K`.
    fn k(&self) -> usize;
    /// Number of rows `n`.
    fn nrows(&self) -> usize;
    /// Writes the matrix for `row` into `out`.
    ///
    /// NOTE(review): `IdentityRowHessian` requires `out.len() == k * k` and writes row-major;
    /// presumably every implementor follows that layout — confirm.
    fn fill_row(&self, row: usize, out: &mut [f64]);
    /// Materialises all row matrices, indexed `[row, c, d]`. Callers in this module require
    /// the shape `(nrows(), k(), k())`.
    fn evaluate_full(&self) -> Array3<f64>;
}
148
149pub struct IdentityRowHessian {
156 n: usize,
157 k: usize,
158}
159
160impl IdentityRowHessian {
161 pub fn new(n: usize, k: usize) -> Self {
164 Self { n, k }
165 }
166}
167
168impl RowHessian for IdentityRowHessian {
169 fn k(&self) -> usize {
170 self.k
171 }
172 fn nrows(&self) -> usize {
173 self.n
174 }
175 fn fill_row(&self, row: usize, out: &mut [f64]) {
176 assert!(
177 row < self.n,
178 "IdentityRowHessian::fill_row row {row} out of range {n}",
179 n = self.n
180 );
181 assert_eq!(out.len(), self.k * self.k);
182 for i in 0..self.k {
183 for j in 0..self.k {
184 out[i * self.k + j] = if i == j { 1.0 } else { 0.0 };
185 }
186 }
187 }
188 fn evaluate_full(&self) -> Array3<f64> {
189 let mut out = Array3::<f64>::zeros((self.n, self.k, self.k));
190 for i in 0..self.n {
191 for c in 0..self.k {
192 out[[i, c, c]] = 1.0;
193 }
194 }
195 out
196 }
197}
198
/// Per-block output of `compile_with_dual_metric`.
pub struct CompiledBlock {
    /// Map from the block's raw columns to its kept directions, shape `(p_b, kept)`.
    /// It is zero-width when the block keeps nothing: `(p_b, 0)`, or `(0, 0)` for a block
    /// with no columns.
    pub t_lw: Array2<f64>,
    /// Correction for the kept directions' projection onto earlier anchors, expressed over
    /// all earlier blocks' raw columns. Shape `(prior raw columns, kept)`.
    /// It is `None` for the first block that keeps any direction, since no anchor exists yet.
    pub anchor_correction: Option<Array2<f64>>,
    /// The compiler in this module always sets this to the same value as
    /// `anchor_correction`, and the rank audit truncates both together.
    pub r_lw: Option<Array2<f64>>,
}
218
/// Result of `compile` / `compile_with_dual_metric`.
pub struct CompiledBlocks {
    /// One entry per input operator, in input order.
    pub blocks: Vec<CompiledBlock>,
    /// Total number of kept directions: the sum of `t_lw.ncols()` over `blocks` after the
    /// final rank audit.
    pub joint_rank: usize,
    /// `(block_idx, local_index)` pairs demoted during the block walk, followed by those
    /// dropped by the final rank audit.
    pub dropped: Vec<(usize, usize)>,
}
229
/// Classification stored in `PenalizedDirectionAnnotation::kind`.
///
/// NOTE(review): the producer (`orthogonalize_design_blocks`) is not fully visible here, so
/// the variant descriptions below are inferred from their names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenalizedDirectionAnnotationKind {
    /// Presumably: no direction of the block was absorbed by higher-priority blocks.
    Independent,
    /// Presumably: some, but not all, directions were absorbed by higher-priority blocks.
    PartiallyAbsorbedByHigherPriority,
    /// Presumably: every direction was absorbed by higher-priority blocks.
    FullyAbsorbedByHigherPriority,
}
245
/// Per-block annotation collected in `BlockOrthogonalization::direction_annotations`.
///
/// NOTE(review): the field meanings below are inferred from names because the producer is
/// not fully visible. Confirm, e.g., whether `raw_width == kept_width + absorbed_width`
/// always holds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PenalizedDirectionAnnotation {
    /// Index of the annotated block.
    pub block_idx: usize,
    /// Presumably the block's column count before orthogonalization.
    pub raw_width: usize,
    /// Presumably the number of directions kept.
    pub kept_width: usize,
    /// Presumably the number of directions absorbed by higher-priority blocks.
    pub absorbed_width: usize,
    /// Summary classification of the block.
    pub kind: PenalizedDirectionAnnotationKind,
}
255
/// Failure modes of the identifiability compiler.
#[derive(Debug)]
pub enum CompilerError {
    /// Inputs disagree on a size: block/ordering counts, row or channel counts, matrix
    /// shapes, or block-range layout.
    DimensionMismatch(String),
    /// Block `block_idx` kept no direction before any anchor existed; `reason` says which
    /// pass (structural or curvature) came up empty.
    FullyAliased { block_idx: usize, reason: String },
    /// An eigendecomposition or rank-revealing QR failed, or the result was non-finite.
    LinalgFailure(String),
}

impl std::fmt::Display for CompilerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FullyAliased { block_idx, reason } => {
                write!(f, "block {block_idx} fully aliased: {reason}")
            }
            Self::DimensionMismatch(msg) => write!(f, "dimension mismatch: {msg}"),
            Self::LinalgFailure(msg) => write!(f, "linalg failure: {msg}"),
        }
    }
}

impl std::error::Error for CompilerError {}
281
/// Label attached to each block handed to `compile`, `compile_with_dual_metric` or
/// `compile_from_raw_grams`.
///
/// In the code visible here only the *number* of labels is checked (it must equal the
/// number of blocks). The block walk itself follows the order of the block inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockOrder {
    Time,
    Marginal,
    Logslope,
    ScoreWarp,
    LinkDev,
}
294
295pub fn compile(
305 operators: &[Arc<dyn RowJacobianOperator>],
306 row_hess: &dyn RowHessian,
307 ordering: &[BlockOrder],
308) -> Result<CompiledBlocks, CompilerError> {
309 let n = row_hess.nrows();
315 let k = row_hess.k();
316 let id_struct = IdentityRowHessian::new(n, k);
317 compile_with_dual_metric(operators, row_hess, &id_struct, ordering)
318}
319
/// Compiles per-block Jacobian operators into an identifiable basis using two row metrics.
///
/// Blocks are walked in the order of `operators`. In the code here, `ordering` is only
/// checked to have the same length. For each block:
///
/// 1. **Structural pass.** The design scaled by `sqrt` of the `row_structural` matrices is
///    residualised against the structural anchor built from earlier blocks. The positive
///    eigenspace `D` of the residual Gram is kept.
/// 2. **Curvature pass.** The design scaled by `sqrt` of the `row_hess` matrices, restricted
///    to `D`, is residualised against the curvature anchor. The positive eigenspace
///    `T_inner` of that residual Gram is kept.
///
/// The block's kept map is `t_lw = D · T_inner`, and both anchors are extended with the
/// corresponding residual directions.
///
/// When either pass keeps nothing, the block is demoted: it gets a zero-width `t_lw` and
/// its local indices are recorded in `dropped`. If no anchor exists yet, the walk fails
/// instead. A final rank-revealing QR audit of the joint curvature anchor may drop further
/// trailing directions.
///
/// # Errors
/// - `DimensionMismatch` if `operators` and `ordering` lengths differ, or if any operator or
///   the structural metric disagrees with `row_hess` on `K` or the row count.
/// - `FullyAliased` if a block keeps no direction before any anchor exists.
/// - `LinalgFailure` if an eigendecomposition or the audit RRQR fails.
pub fn compile_with_dual_metric(
    operators: &[Arc<dyn RowJacobianOperator>],
    row_hess: &dyn RowHessian,
    row_structural: &dyn RowHessian,
    ordering: &[BlockOrder],
) -> Result<CompiledBlocks, CompilerError> {
    if operators.len() != ordering.len() {
        return Err(CompilerError::DimensionMismatch(format!(
            "operators ({}) and ordering ({}) length mismatch",
            operators.len(),
            ordering.len()
        )));
    }
    if operators.is_empty() {
        return Ok(CompiledBlocks {
            blocks: Vec::new(),
            joint_rank: 0,
            dropped: Vec::new(),
        });
    }

    // Every operator and both metrics must agree on K and n with the curvature Hessian.
    let k = row_hess.k();
    let n = row_hess.nrows();
    if row_structural.k() != k {
        return Err(CompilerError::DimensionMismatch(format!(
            "structural row metric has K={} but curvature row Hessian has K={k}",
            row_structural.k()
        )));
    }
    if row_structural.nrows() != n {
        return Err(CompilerError::DimensionMismatch(format!(
            "structural row metric has nrows={} but curvature row Hessian has nrows={n}",
            row_structural.nrows()
        )));
    }
    for (idx, op) in operators.iter().enumerate() {
        if op.k() != k {
            return Err(CompilerError::DimensionMismatch(format!(
                "operator {idx} has K={} but row Hessian has K={k}",
                op.k()
            )));
        }
        if op.nrows() != n {
            return Err(CompilerError::DimensionMismatch(format!(
                "operator {idx} has nrows={} but row Hessian has nrows={n}",
                op.nrows()
            )));
        }
    }

    let h_full = row_hess.evaluate_full();
    let s_full = row_structural.evaluate_full();

    // Each block's design, flattened to (n*k, p_b) and scaled by the per-row metric root,
    // once for the curvature metric and once for the structural metric.
    let scaled_h: Vec<Array2<f64>> = operators
        .iter()
        .map(|op| op.scaled_design_by_sqrt_h(&h_full))
        .collect();
    let scaled_s: Vec<Array2<f64>> = operators
        .iter()
        .map(|op| op.scaled_design_by_sqrt_h(&s_full))
        .collect();

    let mut compiled: Vec<CompiledBlock> = Vec::with_capacity(operators.len());
    let mut walk_demotions: Vec<(usize, usize)> = Vec::new();
    // anchor_h / anchor_s: residual (kept) directions accumulated so far in each metric.
    // raw_anchor_h: the unreduced curvature-scaled designs of every earlier non-empty block,
    // used to re-express kept anchors in raw-column coordinates.
    let mut anchor_h: Array2<f64> = Array2::zeros((n * k, 0));
    let mut anchor_s: Array2<f64> = Array2::zeros((n * k, 0));
    let mut raw_anchor_h: Array2<f64> = Array2::zeros((n * k, 0));

    for idx in 0..operators.len() {
        let w_h = &scaled_h[idx];
        let w_s = &scaled_s[idx];
        let p_b = w_h.ncols();

        // A block with no columns contributes nothing and extends no anchor.
        if p_b == 0 {
            compiled.push(CompiledBlock {
                t_lw: Array2::<f64>::zeros((0, 0)),
                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
            });
            continue;
        }

        // Structural pass: keep the part of the block not explained by earlier structural
        // anchors. The threshold scale uses the trace of the unresidualised Gram.
        let (residual_s, _) = residualise_in_metric(&anchor_s, w_s)?;
        let g_s = fast_atb(&residual_s, &residual_s);
        let g_s_bb = fast_atb(w_s, w_s);
        let g_s_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
        let d = keep_positive_eigenspace(&g_s, n, k, g_s_trace)?;
        if d.ncols() == 0 {
            if anchor_h.ncols() == 0 {
                return Err(CompilerError::FullyAliased {
                    block_idx: idx,
                    reason: format!(
                        "structural residual Gram has no positive eigenspace (block of width {p_b} has zero structural span before any anchor exists)"
                    ),
                });
            }
            // Structurally aliased: demote every raw column of this block.
            compiled.push(CompiledBlock {
                t_lw: Array2::<f64>::zeros((p_b, 0)),
                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
            });
            for c in 0..p_b {
                walk_demotions.push((idx, c));
            }
            raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
            continue;
        }

        // Curvature pass, restricted to the structurally kept basis D.
        let w_h_d = fast_ab(w_h, &d);
        let (residual_h, m_h_inner_opt) = residualise_in_metric(&anchor_h, &w_h_d)?;
        let g_h = fast_atb(&residual_h, &residual_h);
        let p_d = d.ncols();
        let g_h_dd = fast_atb(&w_h_d, &w_h_d);
        let g_h_trace: f64 = (0..p_d).map(|i| g_h_dd[[i, i]].max(0.0)).sum();
        let t_inner = keep_positive_eigenspace(&g_h, n, k, g_h_trace)?;
        if t_inner.ncols() == 0 {
            if anchor_h.ncols() == 0 {
                return Err(CompilerError::FullyAliased {
                    block_idx: idx,
                    reason: format!(
                        "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {p_d}) before any anchor exists"
                    ),
                });
            }
            compiled.push(CompiledBlock {
                t_lw: Array2::<f64>::zeros((p_b, 0)),
                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
            });
            // NOTE(review): this records p_d (structurally kept) indices, whereas the
            // structural-failure branch above records p_b raw indices — confirm which
            // index space `dropped` is meant to use.
            for c in 0..p_d {
                walk_demotions.push((idx, c));
            }
            raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
            continue;
        }

        // Kept map from the block's raw columns to its compiled directions.
        let v = fast_ab(&d, &t_inner);

        let prior_anchor_h = anchor_h.clone();
        let prior_raw_anchor_h = raw_anchor_h.clone();

        // Extend both anchors with this block's kept residual directions.
        let residual_h_t = fast_ab(&residual_h, &t_inner);
        anchor_h = concat_cols(&anchor_h, &residual_h_t);
        let residual_s_v = fast_ab(&residual_s, &v);
        anchor_s = concat_cols(&anchor_s, &residual_s_v);

        // The projection coefficients `m` (on prior kept anchors) are restricted to the kept
        // directions, then mapped into prior raw-column coordinates. `z` least-squares-solves
        // prior_raw_anchor_h · z ≈ prior_anchor_h. `m` is None when no anchor existed.
        let m_compiled = match m_h_inner_opt.as_ref() {
            Some(m) => {
                let m_kept = fast_ab(m, &t_inner);
                if m_kept.nrows() != prior_anchor_h.ncols() {
                    return Err(CompilerError::DimensionMismatch(format!(
                        "anchor correction must be indexed by prior-block kept anchor directions: \
                         m_kept has {} rows but prior_anchor_h has {} columns",
                        m_kept.nrows(),
                        prior_anchor_h.ncols()
                    )));
                }
                let g_raw = fast_atb(&prior_raw_anchor_h, &prior_raw_anchor_h);
                let z_rhs = fast_atb(&prior_raw_anchor_h, &prior_anchor_h);
                let z = solve_psd_system(&g_raw, &z_rhs)?;
                Some(fast_ab(&z, &m_kept))
            }
            None => None,
        };
        compiled.push(CompiledBlock {
            t_lw: v,
            anchor_correction: m_compiled.clone(),
            r_lw: m_compiled,
        });

        raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
    }

    // Final safety net: numerical rank of the joint curvature anchor.
    let audit_dropped = audit_and_drop_trailing_pivots(&anchor_h, &mut compiled)?;
    let mut dropped = walk_demotions;
    dropped.extend(audit_dropped);
    let joint_rank: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();

    Ok(CompiledBlocks {
        blocks: compiled,
        joint_rank,
        dropped,
    })
}
648
649fn scale_block_by_sqrt_h(jb: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
653 let n = jb.shape()[0];
654 let p = jb.shape()[1];
655 let k = jb.shape()[2];
656 scale_jacobian_by_sqrt_h_with(n, p, k, h_full, |i, a, c| jb[[i, a, c]])
657}
658
659pub fn scale_jacobian_by_sqrt_h_with(
673 n: usize,
674 p: usize,
675 k: usize,
676 h_full: &Array3<f64>,
677 jac: impl Fn(usize, usize, usize) -> f64,
678) -> Array2<f64> {
679 assert_eq!(h_full.shape(), &[n, k, k]);
680 let mut out = Array2::<f64>::zeros((n * k, p));
681 let mut sqrt_h = Array2::<f64>::zeros((k, k));
682 let mut scratch_jrow = Array2::<f64>::zeros((p, k));
683 for i in 0..n {
684 let h_i = h_full.index_axis(Axis(0), i).to_owned();
686 sqrt_h.fill(0.0);
687 symmetric_sqrt_into(&h_i, &mut sqrt_h);
688 for a in 0..p {
692 for c in 0..k {
693 scratch_jrow[[a, c]] = jac(i, a, c);
694 }
695 }
696 for c in 0..k {
697 for a in 0..p {
698 let mut acc = 0.0;
699 for cp in 0..k {
700 acc += sqrt_h[[c, cp]] * scratch_jrow[[a, cp]];
701 }
702 out[[i * k + c, a]] = acc;
703 }
704 }
705 }
706 out
707}
708
709pub(crate) fn symmetric_sqrt_into(m: &Array2<f64>, out: &mut Array2<f64>) {
712 let k = m.nrows();
713 assert_eq!(m.ncols(), k);
714 assert_eq!(out.shape(), &[k, k]);
715 if k == 1 {
716 out[[0, 0]] = m[[0, 0]].max(0.0).sqrt();
717 return;
718 }
719 let (evals, evecs) = match m.eigh(Side::Lower) {
720 Ok(pair) => pair,
721 Err(_) => {
722 out.fill(0.0);
725 for i in 0..k {
726 out[[i, i]] = m[[i, i]].max(0.0).sqrt();
727 }
728 return;
729 }
730 };
731 let mut scaled = evecs.clone();
733 for j in 0..k {
734 let s = evals[j].max(0.0).sqrt();
735 for i in 0..k {
736 scaled[[i, j]] *= s;
737 }
738 }
739 out.assign(&fast_atb(&evecs.t().to_owned(), &scaled.t().to_owned()));
740 out.fill(0.0);
744 for i in 0..k {
745 for j in 0..k {
746 let mut acc = 0.0;
747 for l in 0..k {
748 acc += evecs[[i, l]] * evals[l].max(0.0).sqrt() * evecs[[j, l]];
749 }
750 out[[i, j]] = acc;
751 }
752 }
753}
754
755fn residualise_in_metric(
759 a_scaled: &Array2<f64>,
760 b_scaled: &Array2<f64>,
761) -> Result<(Array2<f64>, Option<Array2<f64>>), CompilerError> {
762 let d = a_scaled.ncols();
763 if d == 0 {
764 return Ok((b_scaled.clone(), None));
765 }
766 let g_aa = fast_atb(a_scaled, a_scaled);
767 let g_ab = fast_atb(a_scaled, b_scaled);
768 let m = solve_psd_system(&g_aa, &g_ab)?;
769 let a_m = fast_ab(a_scaled, &m);
770 let residual = b_scaled - &a_m;
771 Ok((residual, Some(m)))
772}
773
774fn solve_psd_system(g: &Array2<f64>, r: &Array2<f64>) -> Result<Array2<f64>, CompilerError> {
779 let n = g.nrows();
780 if n == 0 {
781 return Ok(Array2::zeros((0, r.ncols())));
782 }
783 let (evals, evecs) = g
784 .eigh(Side::Lower)
785 .map_err(|err| CompilerError::LinalgFailure(format!("Gram eigh failed: {err:?}")))?;
786 let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
787 let tol = lambda_max * RANK_REVEAL_EPS_SLACK * (n.max(1) as f64) * f64::EPSILON;
788 let u_t_r = fast_atb(&evecs, r);
790 let mut scaled = u_t_r.clone();
791 for i in 0..n {
792 let lam = evals[i];
793 let inv = if lam > tol { 1.0 / lam } else { 0.0 };
794 for j in 0..scaled.ncols() {
795 scaled[[i, j]] *= inv;
796 }
797 }
798 let m = fast_ab(&evecs, &scaled);
799 Ok(m)
800}
801
802fn keep_positive_eigenspace(
806 g_tilde: &Array2<f64>,
807 n: usize,
808 k: usize,
809 g_bb_trace: f64,
810) -> Result<Array2<f64>, CompilerError> {
811 let p = g_tilde.nrows();
812 if p == 0 {
813 return Ok(Array2::zeros((0, 0)));
814 }
815 let (evals, evecs) = g_tilde.eigh(Side::Lower).map_err(|err| {
816 CompilerError::LinalgFailure(format!("residual Gram eigh failed: {err:?}"))
817 })?;
818 let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
819 let scale = lambda_max.max(g_bb_trace);
820 let nk = (n.saturating_mul(k)).max(p).max(1) as f64;
821 let tau = scale * RANK_REVEAL_EPS_SLACK * nk * f64::EPSILON;
822 let mut kept: Vec<usize> = (0..p).filter(|&i| evals[i] > tau).collect();
824 kept.sort_by(|&a, &b| {
826 evals[b]
827 .partial_cmp(&evals[a])
828 .unwrap_or(std::cmp::Ordering::Equal)
829 });
830 let mut v = Array2::<f64>::zeros((p, kept.len()));
831 for (out_col, &src_col) in kept.iter().enumerate() {
832 for row in 0..p {
833 v[[row, out_col]] = evecs[[row, src_col]];
834 }
835 }
836 Ok(v)
837}
838
839fn concat_cols(left: &Array2<f64>, right: &Array2<f64>) -> Array2<f64> {
841 let nrows = left.nrows().max(right.nrows());
842 let lc = left.ncols();
843 let rc = right.ncols();
844 let mut out = Array2::<f64>::zeros((nrows, lc + rc));
845 if lc > 0 {
846 out.slice_mut(s![.., ..lc]).assign(left);
847 }
848 if rc > 0 {
849 out.slice_mut(s![.., lc..]).assign(right);
850 }
851 out
852}
853
854fn audit_and_drop_trailing_pivots(
858 w_joint: &Array2<f64>,
859 compiled: &mut [CompiledBlock],
860) -> Result<Vec<(usize, usize)>, CompilerError> {
861 let p_total: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();
862 if p_total == 0 || w_joint.nrows() == 0 {
863 return Ok(Vec::new());
864 }
865
866 let rrqr = rrqr_with_permutation(w_joint, default_rrqr_rank_alpha())
868 .map_err(|err| CompilerError::LinalgFailure(format!("audit RRQR failed: {err:?}")))?;
869 let rank = rrqr.rank;
870 if rank >= p_total {
871 return Ok(Vec::new());
872 }
873
874 let drop_count = p_total - rank;
881 let latest_idx = compiled.len() - 1;
882 let latest = &mut compiled[latest_idx];
883 let kept_local = latest.t_lw.ncols().saturating_sub(drop_count);
884 let dropped_locals: Vec<(usize, usize)> = (kept_local..latest.t_lw.ncols())
885 .map(|c| (latest_idx, c))
886 .collect();
887 latest.t_lw = latest.t_lw.slice(s![.., ..kept_local]).to_owned();
896 if let Some(m) = latest.anchor_correction.as_ref() {
897 latest.anchor_correction = Some(m.slice(s![.., ..kept_local]).to_owned());
898 }
899 if let Some(r) = latest.r_lw.as_ref() {
900 latest.r_lw = Some(r.slice(s![.., ..kept_local]).to_owned());
901 }
902 Ok(dropped_locals)
903}
904
/// Per-block, per-channel primary design matrices used to build raw Grams.
pub struct PrimaryChannelBlocks {
    /// `blocks[b][c]` is block `b`'s design for channel `c`, or `None` when that slot is
    /// absent. Absent slots are skipped, i.e. treated as zero, when building Grams.
    /// `build_raw_grams_from_channel_blocks` requires `K` slots per block and every present
    /// matrix to be `(n, p_b)`.
    pub blocks: Vec<Vec<Option<Array2<f64>>>>,
}
918
919pub fn build_raw_grams_from_channel_blocks(
931 channel_blocks: &PrimaryChannelBlocks,
932 row_hess: &dyn RowHessian,
933 raw_block_ranges: &[std::ops::Range<usize>],
934) -> Result<Array2<f64>, CompilerError> {
935 let num_blocks = channel_blocks.blocks.len();
936 if num_blocks != raw_block_ranges.len() {
937 return Err(CompilerError::DimensionMismatch(format!(
938 "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
939 raw_block_ranges.len()
940 )));
941 }
942 if num_blocks == 0 {
943 return Ok(Array2::<f64>::zeros((0, 0)));
944 }
945 let k = row_hess.k();
946 let n = row_hess.nrows();
947 let p_total: usize = raw_block_ranges.iter().map(|r| r.end - r.start).sum();
948 let expected_total = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
949 if expected_total != p_total {
950 return Err(CompilerError::DimensionMismatch(format!(
951 "raw_block_ranges must be contiguous from 0; got p_total={p_total} but last end={expected_total}"
952 )));
953 }
954 for (b, slots) in channel_blocks.blocks.iter().enumerate() {
956 if slots.len() != k {
957 return Err(CompilerError::DimensionMismatch(format!(
958 "block {b}: expected {k} channel slots, got {}",
959 slots.len()
960 )));
961 }
962 let p_b = raw_block_ranges[b].end - raw_block_ranges[b].start;
963 for (c, mat) in slots.iter().enumerate() {
964 if let Some(x) = mat.as_ref() {
965 if x.nrows() != n {
966 return Err(CompilerError::DimensionMismatch(format!(
967 "block {b} channel {c}: nrows={} but row Hessian nrows={n}",
968 x.nrows()
969 )));
970 }
971 if x.ncols() != p_b {
972 return Err(CompilerError::DimensionMismatch(format!(
973 "block {b} channel {c}: ncols={} but block width={p_b}",
974 x.ncols()
975 )));
976 }
977 }
978 }
979 }
980
981 let h_full = row_hess.evaluate_full();
983 if h_full.shape() != &[n, k, k] {
984 return Err(CompilerError::DimensionMismatch(format!(
985 "row Hessian evaluate_full shape {:?} != [n={n}, k={k}, k={k}]",
986 h_full.shape()
987 )));
988 }
989 let mut h_pairs: Vec<Array1<f64>> = Vec::with_capacity(k * k);
991 for c in 0..k {
992 for d in 0..k {
993 let mut v = Array1::<f64>::zeros(n);
994 for i in 0..n {
995 v[i] = h_full[[i, c, d]];
996 }
997 h_pairs.push(v);
998 }
999 }
1000
1001 let mut gram = Array2::<f64>::zeros((p_total, p_total));
1002 for a in 0..num_blocks {
1004 let range_a = raw_block_ranges[a].clone();
1005 for b in a..num_blocks {
1006 let range_b = raw_block_ranges[b].clone();
1007 let mut block_acc =
1008 Array2::<f64>::zeros((range_a.end - range_a.start, range_b.end - range_b.start));
1009 for c in 0..k {
1010 let Some(x_a_c) = channel_blocks.blocks[a][c].as_ref() else {
1011 continue;
1012 };
1013 for d in 0..k {
1014 let Some(x_b_d) = channel_blocks.blocks[b][d].as_ref() else {
1015 continue;
1016 };
1017 let h_cd = &h_pairs[c * k + d];
1018 let contrib = fast_xt_diag_y(x_a_c, h_cd, x_b_d);
1020 block_acc += &contrib;
1021 }
1022 }
1023 gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
1025 .assign(&block_acc);
1026 }
1027 }
1028 for i in 0..p_total {
1031 for j in 0..i {
1032 let v = gram[[j, i]];
1033 gram[[i, j]] = v;
1034 }
1035 }
1036 Ok(gram)
1037}
1038
1039pub fn build_raw_grams_structural(
1046 channel_blocks: &PrimaryChannelBlocks,
1047 raw_block_ranges: &[std::ops::Range<usize>],
1048) -> Array2<f64> {
1049 let num_blocks = channel_blocks.blocks.len();
1050 assert_eq!(
1051 num_blocks,
1052 raw_block_ranges.len(),
1053 "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
1054 raw_block_ranges.len()
1055 );
1056 if num_blocks == 0 {
1057 return Array2::<f64>::zeros((0, 0));
1058 }
1059 let p_total = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
1060 let mut gram = Array2::<f64>::zeros((p_total, p_total));
1061 for a in 0..num_blocks {
1062 let range_a = raw_block_ranges[a].clone();
1063 for b in a..num_blocks {
1064 let range_b = raw_block_ranges[b].clone();
1065 let p_a = range_a.end - range_a.start;
1066 let p_b = range_b.end - range_b.start;
1067 let k_a = channel_blocks.blocks[a].len();
1068 let k_b = channel_blocks.blocks[b].len();
1069 assert_eq!(
1070 k_a, k_b,
1071 "structural Gram: block {a} has {k_a} channels but block {b} has {k_b}",
1072 );
1073 let mut block_acc = Array2::<f64>::zeros((p_a, p_b));
1074 for c in 0..k_a {
1075 let (Some(x_a_c), Some(x_b_c)) = (
1076 channel_blocks.blocks[a][c].as_ref(),
1077 channel_blocks.blocks[b][c].as_ref(),
1078 ) else {
1079 continue;
1080 };
1081 let contrib = if a == b {
1082 fast_ata(x_a_c)
1084 } else {
1085 fast_atb(x_a_c, x_b_c)
1086 };
1087 block_acc += &contrib;
1088 }
1089 gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
1090 .assign(&block_acc);
1091 }
1092 }
1093 for i in 0..p_total {
1094 for j in 0..i {
1095 let v = gram[[j, i]];
1096 gram[[i, j]] = v;
1097 }
1098 }
1099 gram
1100}
1101
1102pub fn build_primary_grams_gpu_or_cpu(
1115 channel_blocks: &PrimaryChannelBlocks,
1116 row_hess: &dyn RowHessian,
1117 raw_block_ranges: &[std::ops::Range<usize>],
1118) -> Result<(Array2<f64>, Array2<f64>), CompilerError> {
1119 let k = row_hess.k();
1120 if k == crate::families::gpu::CHANNELS {
1121 let gpu_blocks: Vec<Vec<Option<Array2<f64>>>> = channel_blocks
1122 .blocks
1123 .iter()
1124 .map(|slots| slots.iter().cloned().collect())
1125 .collect();
1126 if let Some(h_packed) = pack_row_hessian_symmetric(row_hess) {
1127 if let Some(bundle) = crate::families::gpu::try_primary_state_gram_cuda(
1128 &gpu_blocks,
1129 &h_packed,
1130 raw_block_ranges,
1131 ) {
1132 log::info!("[identifiability_compile] gram path = gpu");
1133 return Ok((bundle.gram_h, bundle.gram_struct));
1134 }
1135 }
1136 }
1137 log::info!("[identifiability_compile] gram path = cpu");
1138 let gram_h = build_raw_grams_from_channel_blocks(channel_blocks, row_hess, raw_block_ranges)?;
1139 let gram_struct = build_raw_grams_structural(channel_blocks, raw_block_ranges);
1140 Ok((gram_h, gram_struct))
1141}
1142
1143fn pack_row_hessian_symmetric(row_hess: &dyn RowHessian) -> Option<Array2<f64>> {
1147 use crate::families::gpu::{CHANNELS, PACKED_LEN, packed_index};
1148 if row_hess.k() != CHANNELS {
1149 return None;
1150 }
1151 let n = row_hess.nrows();
1152 let h_full = row_hess.evaluate_full();
1153 if h_full.shape() != [n, CHANNELS, CHANNELS] {
1154 return None;
1155 }
1156 let mut packed = Array2::<f64>::zeros((n, PACKED_LEN));
1157 for i in 0..n {
1158 for c in 0..CHANNELS {
1159 for d in c..CHANNELS {
1160 packed[[i, packed_index(c, d)]] = h_full[[i, c, d]];
1161 }
1162 }
1163 }
1164 Some(packed)
1165}
1166
/// Raw-to-compiled coefficient map produced by `compile_from_raw_grams`.
#[derive(Debug)]
pub struct CompiledMap {
    /// `(p_raw, p_compiled)` matrix `T`. Raw coefficients are `T · beta_compiled`, and a
    /// raw design `X` reduces to `X · T`.
    pub raw_from_compiled: Array2<f64>,
    /// For each block, its column range in `raw_from_compiled`. The range is empty when
    /// the block kept no direction.
    pub compiled_block_ranges: Vec<std::ops::Range<usize>>,
    /// For each block, its contiguous row range in `raw_from_compiled`.
    pub raw_block_ranges: Vec<std::ops::Range<usize>>,
}
1185
/// Exposes `CompiledMap` through the `gam_problem::gauge::CompiledBlockMap` trait by
/// borrowing its fields directly.
impl gam_problem::gauge::CompiledBlockMap for CompiledMap {
    fn raw_from_compiled(&self) -> &Array2<f64> {
        &self.raw_from_compiled
    }
    fn raw_block_ranges(&self) -> &[std::ops::Range<usize>] {
        &self.raw_block_ranges
    }
    fn compiled_block_ranges(&self) -> &[std::ops::Range<usize>] {
        &self.compiled_block_ranges
    }
}
1202
1203pub fn compile_from_raw_grams(
1229 gram_h: &Array2<f64>,
1230 gram_struct: &Array2<f64>,
1231 raw_block_ranges: &[std::ops::Range<usize>],
1232 ordering: &[BlockOrder],
1233) -> Result<CompiledMap, CompilerError> {
1234 if raw_block_ranges.len() != ordering.len() {
1235 return Err(CompilerError::DimensionMismatch(format!(
1236 "raw_block_ranges ({}) and ordering ({}) length mismatch",
1237 raw_block_ranges.len(),
1238 ordering.len()
1239 )));
1240 }
1241 let p_raw = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
1242 if gram_h.shape() != [p_raw, p_raw] {
1243 return Err(CompilerError::DimensionMismatch(format!(
1244 "gram_h shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
1245 gram_h.shape()
1246 )));
1247 }
1248 if gram_struct.shape() != [p_raw, p_raw] {
1249 return Err(CompilerError::DimensionMismatch(format!(
1250 "gram_struct shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
1251 gram_struct.shape()
1252 )));
1253 }
1254 if raw_block_ranges.is_empty() {
1255 return Ok(CompiledMap {
1256 raw_from_compiled: Array2::<f64>::zeros((0, 0)),
1257 compiled_block_ranges: Vec::new(),
1258 raw_block_ranges: Vec::new(),
1259 });
1260 }
1261 let mut expected_start = 0usize;
1263 for (b, r) in raw_block_ranges.iter().enumerate() {
1264 if r.start != expected_start {
1265 return Err(CompilerError::DimensionMismatch(format!(
1266 "raw_block_ranges must be contiguous from 0; block {b} starts at {} expected {expected_start}",
1267 r.start
1268 )));
1269 }
1270 expected_start = r.end;
1271 }
1272
1273 let mut t_cum: Array2<f64> = Array2::<f64>::zeros((p_raw, 0));
1275 let mut compiled_block_ranges: Vec<std::ops::Range<usize>> =
1276 Vec::with_capacity(raw_block_ranges.len());
1277
1278 for (idx, range_b) in raw_block_ranges.iter().enumerate() {
1279 let p_b = range_b.end - range_b.start;
1280 if p_b == 0 {
1290 let at = t_cum.ncols();
1291 compiled_block_ranges.push(at..at);
1292 continue;
1293 }
1294 let ks_t = fast_ab(gram_struct, &t_cum);
1299 let g_s_aa = fast_atb(&t_cum, &ks_t);
1301 let ks_pb = gram_struct
1303 .slice(s![.., range_b.start..range_b.end])
1304 .to_owned();
1305 let g_s_ab = fast_atb(&t_cum, &ks_pb);
1306 let g_s_bb = gram_struct
1308 .slice(s![range_b.start..range_b.end, range_b.start..range_b.end])
1309 .to_owned();
1310 let r_s = solve_psd_system(&g_s_aa, &g_s_ab)?;
1312 let g_s_res_raw = &g_s_bb - &fast_atb(&g_s_ab, &r_s);
1314 let g_s_res = symmetrise(&g_s_res_raw);
1315 let g_s_bb_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
1317 let q_plus = keep_positive_eigenspace(&g_s_res, p_raw, 1, g_s_bb_trace)?;
1319 if q_plus.ncols() == 0 {
1320 if t_cum.ncols() == 0 {
1321 return Err(CompilerError::FullyAliased {
1322 block_idx: idx,
1323 reason: format!(
1324 "structural residual Gram has no positive eigenspace (block of width {p_b} has zero structural span before any anchor exists)"
1325 ),
1326 });
1327 }
1328 let at = t_cum.ncols();
1329 compiled_block_ranges.push(at..at);
1330 continue;
1331 }
1332 let mut diff = Array2::<f64>::zeros((p_raw, p_b));
1337 if t_cum.ncols() > 0 {
1338 let t_rs = fast_ab(&t_cum, &r_s);
1340 for i in 0..p_raw {
1341 for j in 0..p_b {
1342 diff[[i, j]] = -t_rs[[i, j]];
1343 }
1344 }
1345 }
1346 for j in 0..p_b {
1347 diff[[range_b.start + j, j]] += 1.0;
1348 }
1349 let d_mat = fast_ab(&diff, &q_plus);
1350
1351 let kh_t = fast_ab(gram_h, &t_cum);
1354 let g_h_aa = fast_atb(&t_cum, &kh_t);
1355 let kh_d = fast_ab(gram_h, &d_mat);
1356 let g_h_ad = fast_atb(&t_cum, &kh_d);
1357 let r_h = solve_psd_system(&g_h_aa, &g_h_ad)?;
1358 let d_t_kh_d = fast_atb(&d_mat, &kh_d);
1360 let g_h_res_raw = &d_t_kh_d - &fast_atb(&g_h_ad, &r_h);
1361 let g_h_res = symmetrise(&g_h_res_raw);
1362 let k_kept = q_plus.ncols();
1363 let g_h_dd_trace: f64 = (0..k_kept).map(|i| d_t_kh_d[[i, i]].max(0.0)).sum();
1364 let u_mat = keep_positive_eigenspace(&g_h_res, p_raw, 1, g_h_dd_trace)?;
1365 if u_mat.ncols() == 0 {
1366 if t_cum.ncols() == 0 {
1367 return Err(CompilerError::FullyAliased {
1368 block_idx: idx,
1369 reason: format!(
1370 "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {k_kept}) before any anchor exists"
1371 ),
1372 });
1373 }
1374 let at = t_cum.ncols();
1375 compiled_block_ranges.push(at..at);
1376 continue;
1377 }
1378 let mut e_mat = d_mat.clone();
1380 if t_cum.ncols() > 0 {
1381 let t_rh = fast_ab(&t_cum, &r_h);
1382 e_mat = &e_mat - &t_rh;
1383 }
1384 let t_b = fast_ab(&e_mat, &u_mat);
1385
1386 let start = t_cum.ncols();
1387 let end = start + t_b.ncols();
1388 compiled_block_ranges.push(start..end);
1389 t_cum = concat_cols(&t_cum, &t_b);
1390 }
1391
1392 for v in t_cum.iter() {
1394 if !v.is_finite() {
1395 return Err(CompilerError::LinalgFailure(
1396 "compile_from_raw_grams produced non-finite entry in raw_from_compiled".to_string(),
1397 ));
1398 }
1399 }
1400
1401 Ok(CompiledMap {
1402 raw_from_compiled: t_cum,
1403 compiled_block_ranges,
1404 raw_block_ranges: raw_block_ranges.to_vec(),
1405 })
1406}
1407
1408impl CompiledMap {
1409 pub fn p_raw(&self) -> usize {
1411 self.raw_from_compiled.nrows()
1412 }
1413
1414 pub fn p_compiled(&self) -> usize {
1416 self.raw_from_compiled.ncols()
1417 }
1418
1419 pub fn reduce_design(&self, raw_design: &Array2<f64>) -> Result<Array2<f64>, String> {
1427 if raw_design.ncols() != self.p_raw() {
1428 return Err(format!(
1429 "CompiledMap::reduce_design: raw_design has {} columns, expected p_raw {}",
1430 raw_design.ncols(),
1431 self.p_raw()
1432 ));
1433 }
1434 Ok(fast_ab(raw_design, &self.raw_from_compiled))
1435 }
1436
1437 pub fn lift_coefficients(&self, beta_compiled: &Array1<f64>) -> Result<Array1<f64>, String> {
1444 if beta_compiled.len() != self.p_compiled() {
1445 return Err(format!(
1446 "CompiledMap::lift_coefficients: beta_compiled len {} != p_compiled {}",
1447 beta_compiled.len(),
1448 self.p_compiled()
1449 ));
1450 }
1451 Ok(self.raw_from_compiled.dot(beta_compiled))
1452 }
1453
1454 fn raw_block_rows(&self, block_idx: usize) -> Result<Array2<f64>, String> {
1459 let range = self.raw_block_ranges.get(block_idx).ok_or_else(|| {
1460 format!(
1461 "CompiledMap::raw_block_rows: block {block_idx} out of range {}",
1462 self.raw_block_ranges.len()
1463 )
1464 })?;
1465 Ok(self
1466 .raw_from_compiled
1467 .slice(s![range.start..range.end, ..])
1468 .to_owned())
1469 }
1470}
1471
1472pub fn reduce_penalties_with_map(
1491 map: &CompiledMap,
1492 raw_penalties: &[Option<Array2<f64>>],
1493) -> Result<Vec<Option<Array2<f64>>>, String> {
1494 if raw_penalties.len() != map.raw_block_ranges.len() {
1495 return Err(format!(
1496 "reduce_penalties_with_map: raw_penalties ({}) != blocks ({})",
1497 raw_penalties.len(),
1498 map.raw_block_ranges.len()
1499 ));
1500 }
1501 let p_compiled = map.p_compiled();
1502 let mut reduced: Vec<Option<Array2<f64>>> = Vec::with_capacity(raw_penalties.len());
1503 for (block_idx, raw_penalty) in raw_penalties.iter().enumerate() {
1504 let Some(s_b) = raw_penalty.as_ref() else {
1505 reduced.push(None);
1506 continue;
1507 };
1508 let p_b_raw = map.raw_block_ranges[block_idx].len();
1509 if s_b.shape() != [p_b_raw, p_b_raw] {
1510 return Err(format!(
1511 "reduce_penalties_with_map: block {block_idx} penalty shape {:?} != [{p_b_raw}, {p_b_raw}]",
1512 s_b.shape()
1513 ));
1514 }
1515 let t_b = map.raw_block_rows(block_idx)?;
1517 let s_t_b = fast_ab(s_b, &t_b); let s_compiled_raw = fast_atb(&t_b, &s_t_b); let mut s_compiled = symmetrise(&s_compiled_raw);
1521 if s_compiled.shape() != [p_compiled, p_compiled] {
1522 return Err(format!(
1523 "reduce_penalties_with_map: block {block_idx} reduced penalty shape {:?} != [{p_compiled}, {p_compiled}]",
1524 s_compiled.shape()
1525 ));
1526 }
1527 for v in s_compiled.iter_mut() {
1528 if !v.is_finite() {
1529 return Err(format!(
1530 "reduce_penalties_with_map: block {block_idx} reduced penalty has non-finite entry"
1531 ));
1532 }
1533 }
1534 reduced.push(Some(s_compiled));
1535 }
1536 Ok(reduced)
1537}
1538
/// Result of [`orthogonalize_design_blocks`].
///
/// Every per-block vector is indexed by the block's position in the
/// `block_designs` slice passed to that function, not by processing order.
pub struct BlockOrthogonalization {
    /// Per-block transform `V_b` returned by `keep_positive_eigenspace` for the
    /// block's residualised Gram matrix; its column count is the number of
    /// directions kept for the block (presumably `p_b × r_b` — confirm against
    /// `keep_positive_eigenspace`).
    pub block_transforms: Vec<Array2<f64>>,
    /// `(block index, absorbed width)` for every block whose kept width is smaller
    /// than its raw column count, in processing order (highest priority first).
    pub dropped: Vec<(usize, usize)>,
    /// One annotation per block recording its raw, kept and absorbed widths and
    /// the resulting absorption kind.
    pub direction_annotations: Vec<PenalizedDirectionAnnotation>,
}
1567
1568pub fn orthogonalize_design_blocks(
1589 block_designs: &[Array2<f64>],
1590 priority: &[u32],
1591 weight: &[f64],
1592) -> Result<BlockOrthogonalization, CompilerError> {
1593 if block_designs.len() != priority.len() {
1594 return Err(CompilerError::DimensionMismatch(format!(
1595 "block_designs ({}) and priority ({}) length mismatch",
1596 block_designs.len(),
1597 priority.len()
1598 )));
1599 }
1600 if block_designs.is_empty() {
1601 return Ok(BlockOrthogonalization {
1602 block_transforms: Vec::new(),
1603 dropped: Vec::new(),
1604 direction_annotations: Vec::new(),
1605 });
1606 }
1607 let n = block_designs[0].nrows();
1608 for (b, x) in block_designs.iter().enumerate() {
1609 if x.nrows() != n {
1610 return Err(CompilerError::DimensionMismatch(format!(
1611 "block {b} design has {} rows but block 0 has {n}",
1612 x.nrows()
1613 )));
1614 }
1615 }
1616 if weight.len() != n {
1617 return Err(CompilerError::DimensionMismatch(format!(
1618 "weight length {} != n {n}",
1619 weight.len()
1620 )));
1621 }
1622 let mut sqrt_w = Array1::<f64>::zeros(n);
1625 for i in 0..n {
1626 let wi = weight[i].max(0.0);
1627 sqrt_w[i] = wi.sqrt();
1628 }
1629
1630 let mut order: Vec<usize> = (0..block_designs.len()).collect();
1634 order.sort_by(|&a, &b| priority[b].cmp(&priority[a]));
1635
1636 let mut anchor: Array2<f64> = Array2::<f64>::zeros((n, 0));
1638
1639 let mut block_transforms: Vec<Option<Array2<f64>>> = vec![None; block_designs.len()];
1641 let mut direction_annotations: Vec<Option<PenalizedDirectionAnnotation>> =
1642 vec![None; block_designs.len()];
1643 let mut dropped: Vec<(usize, usize)> = Vec::new();
1644
1645 for &b in order.iter() {
1646 let x_b = &block_designs[b];
1647 let p_b = x_b.ncols();
1648 let mut w_b = x_b.clone();
1650 for i in 0..n {
1651 let s = sqrt_w[i];
1652 for j in 0..p_b {
1653 w_b[[i, j]] *= s;
1654 }
1655 }
1656 let (residual, _correction) = residualise_in_metric(&anchor, &w_b)?;
1662 let g_res = symmetrise(&fast_atb(&residual, &residual));
1663 let g_bb = fast_atb(&w_b, &w_b);
1672 let g_bb_trace: f64 = (0..p_b).map(|i| g_bb[[i, i]].max(0.0)).sum();
1673 let v_b = keep_positive_eigenspace(&g_res, n, 1, g_bb_trace)?;
1674 let r_b = v_b.ncols();
1675 let absorbed_width = p_b - r_b;
1676 let kind = if absorbed_width == 0 {
1677 PenalizedDirectionAnnotationKind::Independent
1678 } else if r_b == 0 {
1679 PenalizedDirectionAnnotationKind::FullyAbsorbedByHigherPriority
1680 } else {
1681 PenalizedDirectionAnnotationKind::PartiallyAbsorbedByHigherPriority
1682 };
1683 direction_annotations[b] = Some(PenalizedDirectionAnnotation {
1684 block_idx: b,
1685 raw_width: p_b,
1686 kept_width: r_b,
1687 absorbed_width,
1688 kind,
1689 });
1690 if absorbed_width > 0 {
1691 dropped.push((b, absorbed_width));
1692 }
1693 let kept_weighted = fast_ab(&residual, &v_b);
1699 anchor = concat_cols(&anchor, &kept_weighted);
1700 block_transforms[b] = Some(v_b);
1701 }
1702
1703 let block_transforms: Vec<Array2<f64>> = block_transforms
1704 .into_iter()
1705 .enumerate()
1706 .map(|(b, t)| {
1707 t.ok_or_else(|| {
1708 CompilerError::LinalgFailure(format!(
1709 "orthogonalize_design_blocks: block {b} transform was never assigned"
1710 ))
1711 })
1712 })
1713 .collect::<Result<Vec<_>, _>>()?;
1714 let direction_annotations: Vec<PenalizedDirectionAnnotation> = direction_annotations
1715 .into_iter()
1716 .enumerate()
1717 .map(|(b, annotation)| {
1718 annotation.ok_or_else(|| {
1719 CompilerError::LinalgFailure(format!(
1720 "orthogonalize_design_blocks: block {b} direction annotation was never assigned"
1721 ))
1722 })
1723 })
1724 .collect::<Result<Vec<_>, _>>()?;
1725
1726 for (b, v) in block_transforms.iter().enumerate() {
1728 for value in v.iter() {
1729 if !value.is_finite() {
1730 return Err(CompilerError::LinalgFailure(format!(
1731 "orthogonalize_design_blocks: block {b} transform has a non-finite entry"
1732 )));
1733 }
1734 }
1735 }
1736
1737 Ok(BlockOrthogonalization {
1738 block_transforms,
1739 dropped,
1740 direction_annotations,
1741 })
1742}
1743
1744fn symmetrise(m: &Array2<f64>) -> Array2<f64> {
1746 let (r, c) = m.dim();
1747 assert_eq!(r, c, "symmetrise expects square matrix");
1748 let mut out = Array2::<f64>::zeros((r, c));
1749 for i in 0..r {
1750 for j in 0..c {
1751 out[[i, j]] = 0.5 * (m[[i, j]] + m[[j, i]]);
1752 }
1753 }
1754 out
1755}
1756
1757#[cfg(test)]
1758mod tests {
1759 use super::*;
1760 use ndarray::{Array1, Array2};
1761
1762 struct DenseScalarOperator {
1766 design: Array2<f64>,
1767 }
1768
1769 impl DenseScalarOperator {
1770 fn new(design: Array2<f64>) -> Self {
1771 Self { design }
1772 }
1773 }
1774
1775 impl RowJacobianOperator for DenseScalarOperator {
1776 fn k(&self) -> usize {
1777 1
1778 }
1779 fn ncols(&self) -> usize {
1780 self.design.ncols()
1781 }
1782 fn nrows(&self) -> usize {
1783 self.design.nrows()
1784 }
1785 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
1786 assert_eq!(out.len(), 1);
1787 let mut acc = 0.0;
1788 for (j, &b) in delta_beta.iter().enumerate() {
1789 acc += self.design[[row, j]] * b;
1790 }
1791 out[0] = acc;
1792 }
1793 fn evaluate_full(&self) -> Array3<f64> {
1794 let n = self.design.nrows();
1795 let p = self.design.ncols();
1796 let mut out = Array3::<f64>::zeros((n, p, 1));
1797 for i in 0..n {
1798 for j in 0..p {
1799 out[[i, j, 0]] = self.design[[i, j]];
1800 }
1801 }
1802 out
1803 }
1804 }
1805
1806 struct DiagonalScalarRowHessian {
1812 w: Array1<f64>,
1813 }
1814
1815 impl DiagonalScalarRowHessian {
1816 fn new(w: Array1<f64>) -> Self {
1817 Self { w }
1818 }
1819 }
1820
1821 impl RowHessian for DiagonalScalarRowHessian {
1822 fn k(&self) -> usize {
1823 1
1824 }
1825 fn nrows(&self) -> usize {
1826 self.w.len()
1827 }
1828 fn fill_row(&self, row: usize, out: &mut [f64]) {
1829 assert_eq!(out.len(), 1);
1830 out[0] = self.w[row];
1831 }
1832 fn evaluate_full(&self) -> Array3<f64> {
1833 let n = self.w.len();
1834 let mut out = Array3::<f64>::zeros((n, 1, 1));
1835 for i in 0..n {
1836 out[[i, 0, 0]] = self.w[i];
1837 }
1838 out
1839 }
1840 }
1841
1842 fn op(design: Array2<f64>) -> Arc<dyn RowJacobianOperator> {
1843 Arc::new(DenseScalarOperator::new(design))
1844 }
1845
    /// Two blocks under the identity row Hessian: the compiled second block
    /// `B·V_b − A·M_b` must be orthogonal to the raw first block `A`
    /// (max |Aᵀ·B̃| < 1e-10). `B` contains an explicit `0.5·A[:, 0]` component,
    /// so the anchor correction has real work to do.
    #[test]
    fn compile_two_block_orthogonalises_under_metric() {
        let n = 50;
        let a = Array2::from_shape_fn((n, 3), |(i, j)| ((i + 1) as f64).sin().powi((j + 1) as i32));
        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
            // Deliberate overlap with A's first column.
            0.5 * a[[i, 0]] + ((i as f64) * 0.13 + j as f64).cos()
        });
        let hess = IdentityRowHessian::new(n, 1);
        let ops = vec![op(a.clone()), op(b.clone())];
        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
            .expect("compile should succeed");
        let v_b = &compiled.blocks[1].t_lw;
        let m_b = compiled.blocks[1]
            .anchor_correction
            .as_ref()
            .expect("second block must carry an anchor correction");
        let b_v = b.dot(v_b);
        let a_m = a.dot(m_b);
        let b_compiled = &b_v - &a_m;
        // Identity metric, so orthogonality is the plain inner product Aᵀ·B̃.
        let cross = a.t().dot(&b_compiled);
        let max_err = cross.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        assert!(
            max_err < 1e-10,
            "orthogonality residual too large: {max_err:e}"
        );
    }
1878
    /// Three-block chain in which each later block mixes in columns of earlier
    /// ones (`B` uses `A[:, 0]`; `C` uses `A[:, 1]` and `B[:, 0]`). The joint rank
    /// reported by `compile` must equal the total number of kept columns.
    #[test]
    fn compile_three_block_chain() {
        let n = 80;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.1 + j as f64).sin());
        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
            0.3 * a[[i, 0]] + (j as f64) * (i as f64).cos()
        });
        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
            // `tan` is clamped to [-5, 5] so values near its poles stay bounded.
            0.2 * a[[i, 1]] + 0.4 * b[[i, 0]] + ((i + j) as f64).tan().min(5.0).max(-5.0)
        });
        let hess = IdentityRowHessian::new(n, 1);
        let ops = vec![op(a), op(b), op(c)];
        let compiled = compile(
            &ops,
            &hess,
            &[
                BlockOrder::Marginal,
                BlockOrder::Logslope,
                BlockOrder::LinkDev,
            ],
        )
        .expect("compile should succeed");
        let total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
        assert_eq!(
            compiled.joint_rank, total,
            "audit must report full rank on synthetic full-rank design"
        );
    }
1908
    /// Single-column blocks under a non-uniform diagonal row Hessian `w`: the
    /// 1×1 anchor correction must equal the weighted least-squares coefficient
    /// `Σ wᵢ aᵢ bᵢ / Σ wᵢ aᵢ²` to within 1e-10.
    #[test]
    fn compile_weighted_metric_nontrivial() {
        let n = 32;
        let a: Array2<f64> = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64 + 1.0).sqrt());
        let b: Array2<f64> =
            Array2::from_shape_fn((n, 1), |(i, _)| 0.7 * a[[i, 0]] + (i as f64 * 0.05).cos());
        // Strictly positive weights in [0.5, 1.5].
        let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.2).sin().abs());
        let hess = DiagonalScalarRowHessian::new(w.clone());
        let ops = vec![op(a.clone()), op(b.clone())];
        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
            .expect("compile should succeed");
        let m = compiled.blocks[1]
            .anchor_correction
            .as_ref()
            .expect("anchor correction present");
        let analytic_num: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * b[[i, 0]]).sum();
        let analytic_den: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * a[[i, 0]]).sum();
        let analytic = analytic_num / analytic_den;
        assert!(m.dim() == (1, 1));
        assert!(
            (m[[0, 0]] - analytic).abs() < 1e-10,
            "weighted projection mismatch: got {got}, analytic {analytic}",
            got = m[[0, 0]]
        );
    }
1937
    /// The anchor block `A` has three raw columns, and the third column is exactly
    /// `2·c0 − 0.5·c1`, so `A` has rank 2. The candidate's anchor correction `M`
    /// must still be indexed by the three raw anchor columns. Its width must match
    /// the candidate's `V`, and `C·V − A·M` must be W-orthogonal to every raw
    /// column of `A` (max |Aᵀ W C̃| < 1e-9).
    #[test]
    fn compile_emits_anchor_correction_in_raw_column_coordinates() {
        let n = 64;
        let a: Array2<f64> = Array2::from_shape_fn((n, 3), |(i, j)| {
            let c0 = (i as f64 * 0.07 + 1.0).ln();
            let c1 = (i as f64 * 0.13).sin();
            match j {
                0 => c0,
                1 => c1,
                // Exact linear combination of the first two columns.
                _ => 2.0 * c0 - 0.5 * c1,
            }
        });
        let c: Array2<f64> = Array2::from_shape_fn((n, 2), |(i, j)| {
            0.4 * a[[i, 0]] + (j as f64) * (i as f64 * 0.05).cos() + (i as f64 * 0.011).tanh()
        });
        let w = Array1::from_shape_fn(n, |i| 0.3 + (i as f64 * 0.17).sin().abs());
        let hess = DiagonalScalarRowHessian::new(w.clone());
        let ops = vec![op(a.clone()), op(c.clone())];
        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::LinkDev])
            .expect("compile should succeed");

        let v = &compiled.blocks[1].t_lw;
        let m = compiled.blocks[1]
            .anchor_correction
            .as_ref()
            .expect("candidate block must carry an anchor correction");
        let k_kept = v.ncols();
        assert!(k_kept >= 1, "candidate must keep at least one direction");

        assert_eq!(
            m.nrows(),
            a.ncols(),
            "anchor_correction must be indexed by raw anchor columns (d_total), \
             got {} rows for {} raw anchor columns",
            m.nrows(),
            a.ncols(),
        );
        assert_eq!(m.ncols(), k_kept, "anchor_correction width must match V");

        let c_v = c.dot(v);
        let a_m = a.dot(m);
        let c_tilde = &c_v - &a_m;
        // Explicit Aᵀ·diag(w)·C̃ so the check is independent of any library helper.
        let mut max_cross = 0.0_f64;
        for ac in 0..a.ncols() {
            for cc in 0..c_tilde.ncols() {
                let mut acc = 0.0;
                for i in 0..n {
                    acc += w[i] * a[[i, ac]] * c_tilde[[i, cc]];
                }
                max_cross = max_cross.max(acc.abs());
            }
        }
        assert!(
            max_cross < 1e-9,
            "raw-coordinate anchor correction must W-orthogonalise the candidate \
             against the raw anchor span; max |Aᵀ W C̃| = {max_cross:e}"
        );
    }
2014
    /// Block 1's first column is an exact copy of `A[:, 0]`, so block 1 must lose
    /// a direction: either `t_lw` keeps fewer than two columns or `dropped` is
    /// non-empty. Every recorded drop must also be attributed to block 1.
    ///
    /// NOTE(review): `A[i, j] = ln(i+1)·(j+1)` makes `A`'s column 1 exactly twice
    /// column 0, so block 0 is itself rank-deficient. The test nevertheless
    /// requires that no drop is attributed to block 0. Confirm that this is the
    /// intended contract for rank loss inside the highest-priority block.
    #[test]
    fn compile_drops_trailing_pivots_from_latest_block() {
        let n = 40;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                a[[i, 0]]
            } else {
                (i as f64 * 0.1).cos()
            }
        });
        let hess = IdentityRowHessian::new(n, 1);
        let ops = vec![op(a), op(c)];
        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
            .expect("compile should succeed");
        let v1_cols = compiled.blocks[1].t_lw.ncols();
        assert!(
            v1_cols < 2 || !compiled.dropped.is_empty(),
            "expected rank loss attributed to block 1, got v1_cols={v1_cols}, dropped={dropped:?}",
            dropped = compiled.dropped
        );
        for (block_idx, _) in &compiled.dropped {
            assert_eq!(
                *block_idx, 1,
                "audit drops must come from the latest block only"
            );
        }
    }
2057
    /// Uses the same rank-deficient fixture as the trailing-pivot test (block 1's
    /// first column copies `A[:, 0]`). For every compiled block, any
    /// `anchor_correction` and any `r_lw` must have exactly as many columns as
    /// `t_lw`, i.e. all three stay in lockstep after audit truncation.
    #[test]
    fn audit_truncation_keeps_t_lw_and_anchor_correction_in_lockstep() {
        let n = 40;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                a[[i, 0]]
            } else {
                (i as f64 * 0.1).cos()
            }
        });
        let hess = IdentityRowHessian::new(n, 1);
        let ops = vec![op(a), op(c)];
        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
            .expect("compile should succeed");
        for (idx, block) in compiled.blocks.iter().enumerate() {
            let k_kept = block.t_lw.ncols();
            if let Some(m) = block.anchor_correction.as_ref() {
                assert_eq!(
                    m.ncols(),
                    k_kept,
                    "block {idx}: anchor_correction.ncols()={ac} must equal t_lw.ncols()={k_kept} \
                     after audit truncation",
                    ac = m.ncols(),
                );
            }
            if let Some(r) = block.r_lw.as_ref() {
                assert_eq!(
                    r.ncols(),
                    k_kept,
                    "block {idx}: r_lw.ncols()={r_cols} must equal t_lw.ncols()={k_kept} \
                     after audit truncation",
                    r_cols = r.ncols(),
                );
            }
        }
    }
2106
    /// The same two designs are compiled twice: once with the orderings
    /// (`Marginal`, `Logslope`) and once with (`ScoreWarp`, `LinkDev`). The
    /// second block's anchor correction must be identical to within 1e-12.
    ///
    /// NOTE(review): this presumably relies on both orderings placing block 0
    /// ahead of block 1. Confirm against `BlockOrder`'s priority mapping.
    #[test]
    fn compile_flex_anchor_is_first_class() {
        let n = 60;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.07 + j as f64).sin());
        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
            0.4 * a[[i, 0]] + (j as f64) * (i as f64 + 1.0).ln()
        });
        let hess = IdentityRowHessian::new(n, 1);

        let ops_param = vec![op(a.clone()), op(b.clone())];
        let compiled_param = compile(
            &ops_param,
            &hess,
            &[BlockOrder::Marginal, BlockOrder::Logslope],
        )
        .expect("compile should succeed");

        let ops_flex = vec![op(a.clone()), op(b.clone())];
        let compiled_flex = compile(
            &ops_flex,
            &hess,
            &[BlockOrder::ScoreWarp, BlockOrder::LinkDev],
        )
        .expect("compile should succeed");

        let m_param = compiled_param.blocks[1].anchor_correction.as_ref().unwrap();
        let m_flex = compiled_flex.blocks[1].anchor_correction.as_ref().unwrap();
        assert_eq!(m_param.dim(), m_flex.dim());
        let max_diff = (m_param - m_flex)
            .iter()
            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        assert!(
            max_diff < 1e-12,
            "flex vs parametric anchor correction mismatch: {max_diff:e}"
        );
    }
2155
    /// `DiagonalScalarRowHessian` must return each weight unchanged through both
    /// `evaluate_full` (shape `[n, 1, 1]`) and `fill_row`. The weights are fixed
    /// values in (0, 1); the test does not derive them from a Bernoulli model.
    #[test]
    fn bernoulli_row_hessian_matches_irls_weight() {
        let w = Array1::from(vec![0.1, 0.5, 0.9, 0.25, 0.75]);
        let hess = DiagonalScalarRowHessian::new(w.clone());
        let full = hess.evaluate_full();
        assert_eq!(full.shape(), &[5, 1, 1]);
        for i in 0..5 {
            // Exact equality: both paths copy the stored weight without arithmetic.
            assert_eq!(full[[i, 0, 0]], w[i]);
            let mut buf = [0.0_f64; 1];
            hess.fill_row(i, &mut buf);
            assert_eq!(buf[0], w[i]);
        }
    }
2172
    /// Rebuilds the prediction-time design for block 1 as `B·V − A·M` from the
    /// compiled output. It must have `n` rows, one column per kept direction, and
    /// only finite entries.
    #[test]
    fn compiler_predict_path_roundtrip() {
        let n = 24;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.21).cos() + j as f64);
        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
            0.3 * a[[i, 0]] + (i as f64 + j as f64).sqrt()
        });
        let hess = IdentityRowHessian::new(n, 1);
        let ops = vec![op(a.clone()), op(b.clone())];
        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
            .expect("compile should succeed");
        let v_b = &compiled.blocks[1].t_lw;
        let m_b = compiled.blocks[1].anchor_correction.as_ref().unwrap();
        let predict_design = b.dot(v_b) - a.dot(m_b);
        assert_eq!(predict_design.nrows(), n);
        assert_eq!(predict_design.ncols(), v_b.ncols());
        for &val in predict_design.iter() {
            assert!(val.is_finite(), "predict design produced non-finite entry");
        }
    }
2202
    /// Checks the relationship between `r_lw` and `anchor_correction`:
    /// - block 0 carries neither field;
    /// - block 1 exposes both, with shape `(kept(A), kept(B))`, and they are
    ///   bit-for-bit identical;
    /// - `B·V_b − A·M` is orthogonal to `A` under the identity metric.
    ///
    /// NOTE(review): the name says "m_dot_v", but the assertion is exact equality
    /// with `anchor_correction`. The "compiled width" message also relies on `A`
    /// keeping both of its columns. In that case compiled width equals raw width,
    /// the coordinate system asserted by
    /// `compile_emits_anchor_correction_in_raw_column_coordinates`.
    #[test]
    fn compile_exposes_r_lw_equal_to_m_dot_v() {
        let n = 40;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.17 + j as f64).sin());
        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
            0.6 * a[[i, 0]] + ((i as f64) * 0.11 + j as f64).cos()
        });
        let hess = IdentityRowHessian::new(n, 1);
        let ops = vec![op(a.clone()), op(b.clone())];
        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
            .expect("compile should succeed");

        assert!(compiled.blocks[0].r_lw.is_none());
        assert!(compiled.blocks[0].anchor_correction.is_none());

        let v_a = &compiled.blocks[0].t_lw;
        let v_b = &compiled.blocks[1].t_lw;
        let m_compiled = compiled.blocks[1]
            .anchor_correction
            .as_ref()
            .expect("second block must carry an anchor correction");
        let r_lw = compiled.blocks[1]
            .r_lw
            .as_ref()
            .expect("second block must expose r_lw");
        let p_a_kept = v_a.ncols();
        let p_b_kept = v_b.ncols();
        assert_eq!(
            m_compiled.dim(),
            (p_a_kept, p_b_kept),
            "anchor_correction must be at compiled width"
        );
        assert_eq!(r_lw.dim(), (p_a_kept, p_b_kept));
        let diff = r_lw - m_compiled;
        let max_diff = diff.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
        assert!(
            max_diff == 0.0,
            "r_lw and anchor_correction must be identical"
        );

        let b_compiled = b.dot(v_b) - a.dot(m_compiled);
        let cross = a.t().dot(&b_compiled);
        let max_cross = cross.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
        assert!(
            max_cross < 1e-10,
            "compiled B-design must be H-orthogonal to A: max cross = {max_cross:e}"
        );
    }
2265
2266 struct DenseRowHessian {
2268 h: Array3<f64>,
2269 }
2270
2271 impl RowHessian for DenseRowHessian {
2272 fn k(&self) -> usize {
2273 self.h.shape()[1]
2274 }
2275 fn nrows(&self) -> usize {
2276 self.h.shape()[0]
2277 }
2278 fn fill_row(&self, row: usize, out: &mut [f64]) {
2279 let k = self.k();
2280 assert_eq!(out.len(), k * k);
2281 for c in 0..k {
2282 for d in 0..k {
2283 out[c * k + d] = self.h[[row, c, d]];
2284 }
2285 }
2286 }
2287 fn evaluate_full(&self) -> Array3<f64> {
2288 self.h.clone()
2289 }
2290 }
2291
2292 fn reference_gram_from_w(j_full: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
2295 let w = scale_block_by_sqrt_h(j_full, h_full);
2296 fast_ata(&w)
2297 }
2298
    /// Dense two-block, four-channel problem. The closed-form Gram from
    /// `build_raw_grams_from_channel_blocks` must match the reference
    /// `Wᵀ·W` built from the full Jacobian to `1e-9 · max(scale, 1)`, and it must
    /// be symmetric to 1e-12. The per-row Hessian is `Σ_r m_c m_d + 0.5·δ_cd`, a
    /// sum of outer products plus a diagonal shift.
    #[test]
    fn closed_form_gram_matches_reference_two_block_k4() {
        let n = 17;
        let k = 4;
        let p_a = 3;
        let p_b = 2;

        let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
            (0..4)
                .map(|c| {
                    let m = Array2::from_shape_fn((n, p), |(i, j)| {
                        ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
                    });
                    Some(m)
                })
                .collect()
        };
        let block_a = make_block(0.3, n, p_a);
        let block_b = make_block(1.1, n, p_b);

        let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
            let mut acc = 0.0;
            for r in 0..k {
                let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.13).cos();
                let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.13).cos();
                acc += mc * md;
            }
            acc + if c == d { 0.5 } else { 0.0 }
        });
        let row_hess = DenseRowHessian { h: h.clone() };

        let channel_blocks = PrimaryChannelBlocks {
            blocks: vec![block_a.clone(), block_b.clone()],
        };
        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];

        let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
            .expect("closed-form Gram should succeed");

        // Assemble the full n × p_total × k Jacobian for the reference path.
        let p_total = p_a + p_b;
        let mut j_full = Array3::<f64>::zeros((n, p_total, k));
        for c in 0..k {
            if let Some(xa) = block_a[c].as_ref() {
                for i in 0..n {
                    for j in 0..p_a {
                        j_full[[i, j, c]] = xa[[i, j]];
                    }
                }
            }
            if let Some(xb) = block_b[c].as_ref() {
                for i in 0..n {
                    for j in 0..p_b {
                        j_full[[i, p_a + j, c]] = xb[[i, j]];
                    }
                }
            }
        }
        let ref_gram = reference_gram_from_w(&j_full, &h);

        let diff = &gram - &ref_gram;
        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        let scale = ref_gram.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        assert!(
            max_err < 1e-9 * scale.max(1.0),
            "closed-form Gram mismatches reference: max_err={max_err:e}, scale={scale:e}"
        );

        for i in 0..p_total {
            for j in 0..p_total {
                assert!(
                    (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
                    "closed-form Gram not symmetric at ({i},{j})"
                );
            }
        }
    }
2382
    /// Block A is active only on channel 0 and block B only on channel 3.
    /// - With a Hessian whose only off-diagonal coupling is `h_03` (plus `2.0` on
    ///   the diagonal), the cross-block Gram must equal `Xaᵀ·diag(h_03)·Xb`.
    /// - With a purely diagonal Hessian, the cross block must vanish.
    #[test]
    fn closed_form_gram_channel_sparsity() {
        let n = 13;
        let k = 4;
        let p_a = 2;
        let p_b = 2;

        let xa = Array2::from_shape_fn((n, p_a), |(i, j)| ((i + 1) as f64 * 0.21 + j as f64).cos());
        let xb = Array2::from_shape_fn((n, p_b), |(i, j)| {
            ((i + 1) as f64 * 0.17 + j as f64).sin() + 0.5
        });

        let block_a: Vec<Option<Array2<f64>>> = vec![Some(xa.clone()), None, None, None];
        let block_b: Vec<Option<Array2<f64>>> = vec![None, None, None, Some(xb.clone())];

        let h_03_vec = Array1::from_shape_fn(n, |i| 0.7 + 0.3 * ((i as f64) * 0.4).sin());
        let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
            if (c, d) == (0, 3) || (c, d) == (3, 0) {
                h_03_vec[i]
            } else if c == d {
                2.0
            } else {
                0.0
            }
        });
        let row_hess = DenseRowHessian { h: h.clone() };

        let channel_blocks = PrimaryChannelBlocks {
            blocks: vec![block_a.clone(), block_b.clone()],
        };
        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
        let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
            .expect("closed-form Gram should succeed");

        let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
        let expected = fast_xt_diag_y(&xa, &h_03_vec, &xb);
        let diff = &cross - &expected;
        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        assert!(
            max_err < 1e-12,
            "cross-block Gram must equal Xaᵀ·diag(h_03)·Xb: max_err={max_err:e}"
        );

        // Same blocks, but no channel coupling at all.
        let h_zero = Array3::from_shape_fn((n, k, k), |(_, c, d)| if c == d { 2.0 } else { 0.0 });
        let row_hess_zero = DenseRowHessian { h: h_zero };
        let gram_zero =
            build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess_zero, &raw_ranges)
                .expect("closed-form Gram should succeed");
        let cross_zero = gram_zero.slice(s![0..p_a, p_a..(p_a + p_b)]);
        let max_zero = cross_zero.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        assert!(
            max_zero < 1e-12,
            "cross-block Gram must vanish when coupling channel pair is zero: got {max_zero:e}"
        );
    }
2449
    /// Both blocks have no design on channel 1. The structural Gram's cross block
    /// must equal `Σ_c Xa_cᵀ·Xb_c` over the channels where both blocks are present,
    /// and the whole Gram must be symmetric to 1e-12.
    #[test]
    fn structural_gram_matches_within_channel_sum() {
        let n = 11;
        let p_a = 2;
        let p_b = 3;
        let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
            (0..4)
                .map(|c| {
                    if c == 1 {
                        return None;
                    }
                    Some(Array2::from_shape_fn((n, p), |(i, j)| {
                        ((i as f64 + 1.0) * (j as f64 + 1.0) + seed * (c as f64 + 1.0)).sin()
                    }))
                })
                .collect()
        };
        let block_a = make_block(0.1, n, p_a);
        let block_b = make_block(0.7, n, p_b);
        let channel_blocks = PrimaryChannelBlocks {
            blocks: vec![block_a.clone(), block_b.clone()],
        };
        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
        let gram = build_raw_grams_structural(&channel_blocks, &raw_ranges);

        let mut expected_cross = Array2::<f64>::zeros((p_a, p_b));
        for c in 0..4 {
            if let (Some(xa), Some(xb)) = (block_a[c].as_ref(), block_b[c].as_ref()) {
                expected_cross += &fast_atb(xa, xb);
            }
        }
        let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
        let diff = &cross - &expected_cross;
        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        assert!(
            max_err < 1e-12,
            "structural cross-block must equal Σ_c Xaᵀ·Xb: max_err={max_err:e}"
        );

        for i in 0..(p_a + p_b) {
            for j in 0..(p_a + p_b) {
                assert!(
                    (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
                    "structural Gram not symmetric at ({i},{j})"
                );
            }
        }
    }
2504
2505 fn diag_hess(w: Array1<f64>) -> DiagonalScalarRowHessian {
2509 DiagonalScalarRowHessian::new(w)
2510 }
2511
    /// When the structural metric is the same diagonal Hessian as the curvature
    /// metric, `compile_with_dual_metric` must reproduce `compile`:
    /// - same number of blocks;
    /// - the same `t_lw` per block (to 1e-10);
    /// - the same presence and values of `anchor_correction` (to 1e-10);
    /// - the same joint rank.
    #[test]
    fn dual_metric_with_equal_metrics_matches_single_metric() {
        let n = 36;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.13 + j as f64).sin());
        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
            0.4 * a[[i, 0]] + (i as f64 * 0.07 + j as f64).cos()
        });
        let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.17).sin().abs());
        let curvature = diag_hess(w.clone());
        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];

        let ops_single = vec![op(a.clone()), op(b.clone())];
        let single = compile(&ops_single, &curvature, &ordering)
            .expect("single-metric compile should succeed");

        // A separate instance with identical weights, used as the structural metric.
        let structural_same = diag_hess(w.clone());
        let ops_dual = vec![op(a.clone()), op(b.clone())];
        let dual = compile_with_dual_metric(&ops_dual, &curvature, &structural_same, &ordering)
            .expect("dual-metric compile should succeed");

        assert_eq!(single.blocks.len(), dual.blocks.len());
        for (idx, (sb, db)) in single.blocks.iter().zip(dual.blocks.iter()).enumerate() {
            assert_eq!(sb.t_lw.dim(), db.t_lw.dim(), "block {idx}: V dims differ");
            let max_v = (&sb.t_lw - &db.t_lw)
                .iter()
                .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
            assert!(max_v < 1e-10, "block {idx}: V mismatch {max_v:e}");
            match (sb.anchor_correction.as_ref(), db.anchor_correction.as_ref()) {
                (None, None) => {}
                (Some(s), Some(d)) => {
                    assert_eq!(s.dim(), d.dim());
                    let max_m = (s - d).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
                    assert!(max_m < 1e-10, "block {idx}: M mismatch {max_m:e}");
                }
                _ => panic!("block {idx}: one side has anchor correction, the other does not"),
            }
        }
        assert_eq!(single.joint_rank, dual.joint_rank);
    }
2557
    /// The curvature weights are 1 on rows 0..6 and 0 on the remaining rows.
    /// On exactly those weighted rows `B = 2·A`, so under the curvature metric
    /// alone `B` is aliased with `A`. Expected outcomes:
    /// - Using the curvature metric for both passes ("H-only"), compilation must
    ///   either fail with `FullyAliased` on block 1, or reject B's column
    ///   (zero width or a recorded drop).
    /// - With an identity structural metric, rows 6.. distinguish `B` from `A`, so
    ///   compilation must succeed. Block 0 keeps its column, and B's kept width
    ///   plus the number of `dropped` entries must equal 1.
    #[test]
    fn dual_metric_resists_pilot_curvature_alias() {
        let n = 12;
        let a = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) + 1.0);
        let b = Array2::from_shape_fn((n, 1), |(i, _)| {
            if i < 6 {
                2.0 * a[[i, 0]]
            } else {
                ((i as f64) * 0.3).cos() + 0.5
            }
        });

        // Curvature is supported only on the rows where B is a multiple of A.
        let mut w_vec = vec![0.0_f64; n];
        for w in &mut w_vec[..6] {
            *w = 1.0;
        }
        let w = Array1::from(w_vec);
        let curvature = diag_hess(w.clone());

        let id_struct = IdentityRowHessian::new(n, 1);
        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];

        let ops_dual = vec![op(a.clone()), op(b.clone())];
        let dual = compile_with_dual_metric(&ops_dual, &curvature, &id_struct, &ordering);

        let ops_h_only = vec![op(a.clone()), op(b.clone())];
        let h_only = compile_with_dual_metric(&ops_h_only, &curvature, &curvature, &ordering);

        match h_only {
            Err(CompilerError::FullyAliased { block_idx, .. }) => {
                assert_eq!(block_idx, 1, "H-only path must alias block 1");
            }
            Ok(out) => {
                let v1_cols = out.blocks[1].t_lw.ncols();
                assert!(
                    v1_cols == 0 || !out.dropped.is_empty(),
                    "H-only path should reject B's curvature-aliased column; v1_cols={v1_cols}, dropped={dropped:?}",
                    dropped = out.dropped,
                );
            }
            Err(other) => panic!("unexpected H-only error: {other:?}"),
        }

        let dual =
            dual.expect("dual-metric must succeed: identity-structural sees B as independent");
        assert_eq!(dual.blocks.len(), 2);
        assert_eq!(dual.blocks[0].t_lw.ncols(), 1, "A must keep its column");
        let v1_post_audit = dual.blocks[1].t_lw.ncols();
        // Number of `dropped` entries, not the sum of their widths.
        let dropped_count = dual.dropped.len();
        assert_eq!(
            v1_post_audit + dropped_count,
            1,
            "structural pass kept B's column; audit may demote it but the pre-audit width was 1"
        );
    }
2660
    /// Two generic full-rank two-column blocks, compiled with a positive diagonal
    /// curvature and an identity structural metric. Both blocks must keep both
    /// columns, nothing may be dropped, and the joint rank must be 4.
    #[test]
    fn dual_metric_identity_structural_preserves_full_rank() {
        let n = 24;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 + j as f64).sqrt());
        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
            ((i + 1) as f64).ln() + (i as f64 * 0.1 + j as f64).cos()
        });
        let w = Array1::from_shape_fn(n, |i| 0.4 + (i as f64 * 0.05).sin().powi(2));
        let curvature = diag_hess(w.clone());
        let id_struct = IdentityRowHessian::new(n, 1);
        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];

        let ops = vec![op(a.clone()), op(b.clone())];
        let out =
            compile_with_dual_metric(&ops, &curvature, &id_struct, &ordering).expect("compile");
        assert_eq!(out.blocks[0].t_lw.ncols(), 2);
        assert_eq!(out.blocks[1].t_lw.ncols(), 2);
        assert_eq!(out.dropped.len(), 0);
        assert_eq!(out.joint_rank, 4);
    }
2690
2691 #[test]
2697 fn build_primary_grams_gpu_or_cpu_two_block_k4_matches_cpu() {
2698 let n = 11;
2699 let k = 4;
2700 let p_a = 2;
2701 let p_b = 3;
2702
2703 let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2704 (0..4)
2705 .map(|c| {
2706 let m = Array2::from_shape_fn((n, p), |(i, j)| {
2707 ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
2708 });
2709 Some(m)
2710 })
2711 .collect()
2712 };
2713 let block_a = make_block(0.7, n, p_a);
2714 let block_b = make_block(-0.4, n, p_b);
2715
2716 let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2717 let mut acc = 0.0;
2718 for r in 0..k {
2719 let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.11).cos();
2720 let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.11).cos();
2721 acc += mc * md;
2722 }
2723 acc + if c == d { 0.25 } else { 0.0 }
2724 });
2725 let row_hess = DenseRowHessian { h: h.clone() };
2726
2727 let channel_blocks = PrimaryChannelBlocks {
2728 blocks: vec![block_a, block_b],
2729 };
2730 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2731
2732 let (gram_h, gram_struct) =
2733 build_primary_grams_gpu_or_cpu(&channel_blocks, &row_hess, &raw_ranges)
2734 .expect("dispatch helper should succeed");
2735
2736 let cpu_h = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2737 .expect("CPU curvature Gram should succeed");
2738 let cpu_s = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2739
2740 let tol = 1e-9_f64;
2741 for idx in cpu_h.indexed_iter().map(|(i, _)| i) {
2742 let diff = (gram_h[idx] - cpu_h[idx]).abs();
2743 let scale = cpu_h[idx].abs().max(1.0);
2744 assert!(
2745 diff <= tol * scale,
2746 "gram_h mismatch at {idx:?}: helper={} cpu={}",
2747 gram_h[idx],
2748 cpu_h[idx]
2749 );
2750 }
2751 for idx in cpu_s.indexed_iter().map(|(i, _)| i) {
2752 let diff = (gram_struct[idx] - cpu_s[idx]).abs();
2753 let scale = cpu_s[idx].abs().max(1.0);
2754 assert!(
2755 diff <= tol * scale,
2756 "gram_struct mismatch at {idx:?}: helper={} cpu={}",
2757 gram_struct[idx],
2758 cpu_s[idx]
2759 );
2760 }
2761 }
2762
2763 fn scalar_grams_two_block(
2769 a: &Array2<f64>,
2770 b: &Array2<f64>,
2771 w: &Array1<f64>,
2772 ) -> (Array2<f64>, Array2<f64>, Vec<std::ops::Range<usize>>) {
2773 let p_a = a.ncols();
2774 let p_b = b.ncols();
2775 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2776 let channel_blocks = PrimaryChannelBlocks {
2777 blocks: vec![vec![Some(a.clone())], vec![Some(b.clone())]],
2778 };
2779 let row_hess = DiagonalScalarRowHessian::new(w.clone());
2780 let gram_h =
2781 build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges).unwrap();
2782 let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2783 (gram_h, gram_struct, raw_ranges)
2784 }
2785
    /// `B = A·L` with an invertible 2×2 `L`, so B lies entirely in A's column
    /// span. `compile_from_raw_grams` must:
    /// - keep both of A's directions and give B zero width;
    /// - produce a 4×2 `raw_from_compiled`;
    /// - leave B's raw rows of that map at (numerically) zero.
    #[test]
    fn compile_from_raw_grams_full_structural_alias() {
        let n = 10;
        let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 * (j + 1) as f64).sin());
        // det(L) = 1·1 − 0.5·(−0.25) = 1.125 ≠ 0.
        let l = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, -0.25, 1.0]).unwrap();
        let b = a.dot(&l);
        let w = Array1::ones(n);
        let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
        let res = compile_from_raw_grams(
            &gram_h,
            &gram_struct,
            &raw_ranges,
            &[BlockOrder::Marginal, BlockOrder::Logslope],
        )
        .expect("lower-priority full alias should compile to zero width");
        assert_eq!(res.compiled_block_ranges[0].len(), 2);
        assert_eq!(res.compiled_block_ranges[1].len(), 0);
        assert_eq!(res.raw_from_compiled.dim(), (4, 2));
        assert!(
            res.raw_from_compiled
                .slice(s![raw_ranges[1].clone(), ..])
                .iter()
                .all(|v| v.abs() <= 1.0e-12),
            "zero-width block must not retain raw coefficient directions in T"
        );
    }
2816
2817 #[test]
2824 fn compile_from_raw_grams_zero_width_first_block_is_identifiable() {
2825 let n = 12;
2826 let empty = Array2::<f64>::zeros((n, 0));
2827 let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2828 ((i + 1) as f64 * (j + 1) as f64 * 0.23).cos()
2829 });
2830 let w = Array1::ones(n);
2831 let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&empty, &b, &w);
2832 let map = compile_from_raw_grams(
2833 &gram_h,
2834 &gram_struct,
2835 &raw_ranges,
2836 &[BlockOrder::Marginal, BlockOrder::Logslope],
2837 )
2838 .expect("zero-width first block must be trivially identifiable, not FullyAliased");
2839 assert_eq!(
2840 map.compiled_block_ranges[0].len(),
2841 0,
2842 "empty first block keeps zero columns"
2843 );
2844 assert_eq!(
2845 map.compiled_block_ranges[1].len(),
2846 2,
2847 "the second block keeps its full structural rank"
2848 );
2849 assert_eq!(map.raw_from_compiled.dim(), (2, 2));
2850 }
2851
2852 #[test]
2853 fn orthogonalization_annotates_independent_and_fully_absorbed_blocks() {
2854 let n = 18;
2855 let anchor = Array2::from_shape_fn((n, 2), |(i, j)| {
2856 ((i + 1) as f64 * (0.19 + j as f64 * 0.07)).sin()
2857 });
2858 let duplicate = anchor.clone();
2859 let independent = Array2::from_shape_fn((n, 1), |(i, _)| ((i + 1) as f64 * 0.43).cos());
2860 let weight = vec![1.0; n];
2861 let ortho = orthogonalize_design_blocks(
2862 &[anchor, duplicate, independent],
2863 &[200, 100, 50],
2864 &weight,
2865 )
2866 .expect("structural annotation compile");
2867
2868 assert_eq!(
2869 ortho.direction_annotations[0].kind,
2870 PenalizedDirectionAnnotationKind::Independent
2871 );
2872 assert_eq!(ortho.direction_annotations[0].absorbed_width, 0);
2873 assert_eq!(
2874 ortho.direction_annotations[1].kind,
2875 PenalizedDirectionAnnotationKind::FullyAbsorbedByHigherPriority,
2876 "a duplicated lower-priority block is the same realized-design direction"
2877 );
2878 assert_eq!(ortho.direction_annotations[1].raw_width, 2);
2879 assert_eq!(ortho.direction_annotations[1].kept_width, 0);
2880 assert_eq!(ortho.direction_annotations[1].absorbed_width, 2);
2881 assert_eq!(
2882 ortho.direction_annotations[2].kind,
2883 PenalizedDirectionAnnotationKind::Independent,
2884 "a genuinely new realized-design direction keeps its own penalty block"
2885 );
2886 assert_eq!(ortho.direction_annotations[2].raw_width, 1);
2887 assert_eq!(ortho.direction_annotations[2].kept_width, 1);
2888 assert_eq!(ortho.dropped, vec![(1, 2)]);
2889 }
2890
2891 #[test]
2892 fn compile_from_raw_grams_three_block_full_logslope_alias_keeps_fast_path() {
2893 let n = 24;
2894 let time = Array2::from_shape_fn((n, 2), |(i, j)| {
2895 ((i + 1) as f64 * (j + 2) as f64 * 0.17).sin()
2896 });
2897 let marginal = Array2::from_shape_fn((n, 1), |(i, _)| ((i + 3) as f64 * 0.11).cos());
2898 let logslope = marginal.clone();
2899 let p_time = time.ncols();
2900 let p_marg = marginal.ncols();
2901 let p_log = logslope.ncols();
2902 let raw_ranges = vec![
2903 0..p_time,
2904 p_time..(p_time + p_marg),
2905 (p_time + p_marg)..(p_time + p_marg + p_log),
2906 ];
2907 let channel_blocks = PrimaryChannelBlocks {
2908 blocks: vec![
2909 vec![Some(time.clone())],
2910 vec![Some(marginal.clone())],
2911 vec![Some(logslope.clone())],
2912 ],
2913 };
2914 let row_hess = DiagonalScalarRowHessian::new(Array1::ones(n));
2915 let gram_h =
2916 build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges).unwrap();
2917 let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2918
2919 let map = compile_from_raw_grams(
2920 &gram_h,
2921 &gram_struct,
2922 &raw_ranges,
2923 &[BlockOrder::Time, BlockOrder::Marginal, BlockOrder::Logslope],
2924 )
2925 .expect("fully aliased logslope block should not skip the compiled-map path");
2926
2927 assert_eq!(map.compiled_block_ranges[0].len(), p_time);
2928 assert_eq!(map.compiled_block_ranges[1].len(), p_marg);
2929 assert_eq!(map.compiled_block_ranges[2].len(), 0);
2930 assert_eq!(
2931 map.raw_from_compiled.dim(),
2932 (p_time + p_marg + p_log, p_time + p_marg)
2933 );
2934 let x_raw = {
2935 let mut out = Array2::<f64>::zeros((n, p_time + p_marg + p_log));
2936 out.slice_mut(s![.., raw_ranges[0].clone()]).assign(&time);
2937 out.slice_mut(s![.., raw_ranges[1].clone()])
2938 .assign(&marginal);
2939 out.slice_mut(s![.., raw_ranges[2].clone()])
2940 .assign(&logslope);
2941 out
2942 };
2943 let x_compiled = fast_ab(&x_raw, &map.raw_from_compiled);
2944 let rrqr = rrqr_with_permutation(&x_compiled, default_rrqr_rank_alpha()).unwrap();
2945 assert_eq!(rrqr.rank, x_compiled.ncols());
2946 }
2947
2948 #[test]
2954 fn compile_from_raw_grams_partial_alias_matches_w_reference() {
2955 let n = 25;
2956 let a = Array2::from_shape_fn((n, 2), |(i, j)| {
2957 ((i + 1) as f64 * (j + 1) as f64 * 0.3).sin()
2958 });
2959 let mut b = Array2::<f64>::zeros((n, 2));
2961 for i in 0..n {
2962 b[[i, 0]] = a[[i, 0]];
2963 b[[i, 1]] = ((i + 1) as f64 * 0.7).cos();
2964 }
2965 let w = Array1::from_shape_fn(n, |i| 1.0 + 0.1 * (i as f64));
2966 let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
2967 let compiled = compile_from_raw_grams(
2968 &gram_h,
2969 &gram_struct,
2970 &raw_ranges,
2971 &[BlockOrder::Marginal, BlockOrder::Logslope],
2972 )
2973 .expect("closed-form compile must succeed");
2974 let p_a = a.ncols();
2975 let p_b = b.ncols();
2976 assert_eq!(compiled.raw_from_compiled.shape()[0], p_a + p_b);
2977 assert_eq!(
2978 compiled.raw_from_compiled.shape()[1],
2979 p_a + 1,
2980 "partial alias should leave compiled width = p_a + 1 (one column dropped from B)"
2981 );
2982 assert_eq!(compiled.compiled_block_ranges[0], 0..p_a);
2984 assert_eq!(
2985 compiled.compiled_block_ranges[1].end - compiled.compiled_block_ranges[1].start,
2986 1
2987 );
2988
2989 let mut x_raw = Array2::<f64>::zeros((n, p_a + p_b));
2993 for i in 0..n {
2994 for j in 0..p_a {
2995 x_raw[[i, j]] = a[[i, j]];
2996 }
2997 for j in 0..p_b {
2998 x_raw[[i, p_a + j]] = b[[i, j]];
2999 }
3000 }
3001 let x_compiled = fast_ab(&x_raw, &compiled.raw_from_compiled);
3002 let g_compiled = fast_ata(&x_compiled);
3004 let (evals, _) = g_compiled.eigh(Side::Lower).unwrap();
3005 let lam_max = evals.iter().cloned().fold(0.0_f64, f64::max);
3006 let tol = lam_max * 64.0 * (g_compiled.nrows() as f64) * f64::EPSILON;
3007 let rank_compiled = evals.iter().filter(|&&l| l > tol).count();
3008 assert_eq!(
3009 rank_compiled,
3010 p_a + 1,
3011 "compiled design column rank must equal p_a + 1 after dropping the alias"
3012 );
3013
3014 let ops_dual: Vec<Arc<dyn RowJacobianOperator>> = vec![op(a.clone()), op(b.clone())];
3017 let curvature = DiagonalScalarRowHessian::new(w.clone());
3018 let id_struct = IdentityRowHessian::new(n, 1);
3019 let dual = compile_with_dual_metric(
3020 &ops_dual,
3021 &curvature,
3022 &id_struct,
3023 &[BlockOrder::Marginal, BlockOrder::Logslope],
3024 )
3025 .expect("dual metric compile should succeed");
3026 let dual_total: usize = dual.blocks.iter().map(|b| b.t_lw.ncols()).sum();
3027 assert_eq!(dual_total, p_a + 1, "W-reference total width should match");
3028 }
3029
3030 #[test]
3033 fn compile_from_raw_grams_three_block_ordering_matters() {
3034 let n = 30;
3035 let a = Array2::from_shape_fn((n, 2), |(i, j)| {
3036 ((i + 1) as f64 * (j + 2) as f64 * 0.2).sin()
3037 });
3038 let mut b = Array2::<f64>::zeros((n, 2));
3040 for i in 0..n {
3041 b[[i, 0]] = ((i + 1) as f64 * 0.4).cos();
3042 b[[i, 1]] = a[[i, 0]];
3043 }
3044 let mut c = Array2::<f64>::zeros((n, 2));
3046 for i in 0..n {
3047 c[[i, 0]] = ((i + 1) as f64 * 0.55).sin();
3048 c[[i, 1]] = a[[i, 1]];
3049 }
3050 let w = Array1::ones(n);
3051
3052 let build = |b0: &Array2<f64>, b1: &Array2<f64>, b2: &Array2<f64>| {
3053 let raw_ranges = vec![
3054 0..b0.ncols(),
3055 b0.ncols()..(b0.ncols() + b1.ncols()),
3056 (b0.ncols() + b1.ncols())..(b0.ncols() + b1.ncols() + b2.ncols()),
3057 ];
3058 let channel_blocks = PrimaryChannelBlocks {
3059 blocks: vec![
3060 vec![Some(b0.clone())],
3061 vec![Some(b1.clone())],
3062 vec![Some(b2.clone())],
3063 ],
3064 };
3065 let row_hess = DiagonalScalarRowHessian::new(w.clone());
3066 let gram_h =
3067 build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
3068 .unwrap();
3069 let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
3070 (gram_h, gram_struct, raw_ranges)
3071 };
3072
3073 let (gh, gs, rr) = build(&a, &b, &c);
3075 let order_abc = compile_from_raw_grams(
3076 &gh,
3077 &gs,
3078 &rr,
3079 &[
3080 BlockOrder::Marginal,
3081 BlockOrder::Logslope,
3082 BlockOrder::LinkDev,
3083 ],
3084 )
3085 .expect("ABC compile");
3086 assert_eq!(order_abc.compiled_block_ranges[0].len(), 2);
3087 assert_eq!(order_abc.compiled_block_ranges[1].len(), 1);
3088 assert_eq!(order_abc.compiled_block_ranges[2].len(), 1);
3089
3090 let (gh2, gs2, rr2) = build(&b, &a, &c);
3093 let order_bac = compile_from_raw_grams(
3094 &gh2,
3095 &gs2,
3096 &rr2,
3097 &[
3098 BlockOrder::Marginal,
3099 BlockOrder::Logslope,
3100 BlockOrder::LinkDev,
3101 ],
3102 )
3103 .expect("BAC compile");
3104 assert_eq!(order_bac.compiled_block_ranges[0].len(), 2);
3105 assert_eq!(order_bac.compiled_block_ranges[1].len(), 1);
3106 let total_abc: usize = order_abc
3108 .compiled_block_ranges
3109 .iter()
3110 .map(|r| r.len())
3111 .sum();
3112 let total_bac: usize = order_bac
3113 .compiled_block_ranges
3114 .iter()
3115 .map(|r| r.len())
3116 .sum();
3117 assert_eq!(total_abc, total_bac);
3118 assert_eq!(total_abc, 4);
3119 }
3120
3121 fn k1_grams(x: &Array2<f64>, w: &Array1<f64>) -> (Array2<f64>, Array2<f64>) {
3126 let gram_struct = fast_atb(x, x);
3127 let xw = fast_xt_diag_y(x, w, x);
3128 (xw, gram_struct)
3129 }
3130
3131 #[test]
3138 fn compiled_map_lift_coefficients_roundtrips_full_rank() {
3139 let n = 21;
3140 let p_a = 2;
3141 let p_b = 2;
3142 let x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3148 ((i as f64 + 1.0) * (0.21 + 0.17 * j as f64)).sin() + 0.11 * (j as f64)
3149 });
3150 let w = Array1::from_shape_fn(n, |i| 0.5 + 0.5 * ((i as f64) * 0.3).cos().abs());
3151 let (gh, gs) = k1_grams(&x, &w);
3152 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3153 let map = compile_from_raw_grams(
3154 &gh,
3155 &gs,
3156 &raw_ranges,
3157 &[BlockOrder::Marginal, BlockOrder::Logslope],
3158 )
3159 .expect("full-rank compile");
3160 assert_eq!(map.p_compiled(), p_a + p_b);
3162 assert_eq!(map.p_raw(), p_a + p_b);
3163 let beta_raw = Array1::from_shape_fn(p_a + p_b, |j| 0.4 * (j as f64) - 0.7);
3166 let tt = fast_atb(&map.raw_from_compiled, &map.raw_from_compiled);
3169 let tb = map.raw_from_compiled.t().dot(&beta_raw);
3170 let theta = solve_psd_system(&tt, &tb.insert_axis(Axis(1)))
3171 .expect("normal-equation solve")
3172 .column(0)
3173 .to_owned();
3174 let lifted = map.lift_coefficients(&theta).expect("lift");
3175 let max_err = (&lifted - &beta_raw)
3176 .iter()
3177 .fold(0.0_f64, |a, &v| a.max(v.abs()));
3178 assert!(
3179 max_err < 1e-8,
3180 "lift round-trip error {max_err:e} (full-rank reduction must be exactly invertible)"
3181 );
3182 }
3183
3184 #[test]
3190 fn compiled_map_reduce_design_matches_lifted_raw_predictor() {
3191 let n = 23;
3192 let p_a = 3;
3193 let p_b = 3;
3194 let mut x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3195 ((i as f64 + 1.0) * 0.41 + (j as f64 + 1.0) * 0.7).sin() + 0.05 * (i % 3) as f64
3196 });
3197 for i in 0..n {
3199 x[[i, p_a + 1]] = x[[i, 1]];
3200 }
3201 let w = Array1::from_shape_fn(n, |i| 0.6 + 0.4 * ((i as f64) * 0.25).cos().abs());
3202 let (gh, gs) = k1_grams(&x, &w);
3203 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3204 let map = compile_from_raw_grams(
3205 &gh,
3206 &gs,
3207 &raw_ranges,
3208 &[BlockOrder::Marginal, BlockOrder::Logslope],
3209 )
3210 .expect("compile");
3211 let x_compiled = map.reduce_design(&x).expect("reduce_design");
3212 assert_eq!(x_compiled.ncols(), map.p_compiled());
3213 let theta = Array1::from_shape_fn(map.p_compiled(), |j| 0.3 * (j as f64) - 0.5);
3214 let pred_compiled = x_compiled.dot(&theta);
3215 let beta_raw = map.lift_coefficients(&theta).expect("lift");
3216 let pred_raw = x.dot(&beta_raw);
3217 let max_err = (&pred_compiled - &pred_raw)
3218 .iter()
3219 .fold(0.0_f64, |a, &v| a.max(v.abs()));
3220 assert!(
3221 max_err < 1e-9,
3222 "compiled-design predictor diverges from lifted raw predictor: {max_err:e}"
3223 );
3224 }
3225
3226 #[test]
3231 fn reduce_penalties_with_map_preserves_energy_on_lift() {
3232 let n = 19;
3233 let p_a = 3;
3234 let p_b = 2;
3235 let mut x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3239 ((i as f64 + 1.0) * 0.29 + (j as f64 + 1.0) * 0.9).cos()
3240 });
3241 for i in 0..n {
3243 x[[i, p_a]] = x[[i, 0]];
3244 }
3245 let w = Array1::from_shape_fn(n, |i| 0.7 + 0.3 * ((i as f64) * 0.2).sin().abs());
3246 let (gh, gs) = k1_grams(&x, &w);
3247 let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3248 let map = compile_from_raw_grams(
3249 &gh,
3250 &gs,
3251 &raw_ranges,
3252 &[BlockOrder::Marginal, BlockOrder::Logslope],
3253 )
3254 .expect("compile with alias");
3255 assert!(
3256 map.p_compiled() < p_a + p_b,
3257 "expected at least one absorbed column, got p_compiled={}",
3258 map.p_compiled()
3259 );
3260 let s_a = Array2::<f64>::eye(p_a);
3262 let s_b = Array2::<f64>::eye(p_b);
3263 let reduced = reduce_penalties_with_map(&map, &[Some(s_a.clone()), Some(s_b.clone())])
3264 .expect("reduce penalties");
3265 let theta = Array1::from_shape_fn(map.p_compiled(), |j| {
3268 0.6 * (j as f64) - 0.3 + 0.05 * (j % 2) as f64
3269 });
3270 let beta = map.lift_coefficients(&theta).expect("lift");
3271 for (block_idx, s_raw) in [(0usize, &s_a), (1usize, &s_b)] {
3272 let range = &map.raw_block_ranges[block_idx];
3273 let beta_b = beta.slice(s![range.start..range.end]).to_owned();
3274 let raw_energy = beta_b.dot(&s_raw.dot(&beta_b));
3275 let s_reduced = reduced[block_idx]
3276 .as_ref()
3277 .expect("reduced penalty present");
3278 let reduced_energy = theta.dot(&s_reduced.dot(&theta));
3279 assert!(
3280 (raw_energy - reduced_energy).abs() < 1e-8 * raw_energy.abs().max(1.0),
3281 "block {block_idx} energy mismatch: raw={raw_energy:e} reduced={reduced_energy:e}"
3282 );
3283 }
3284 }
3285}