1use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
29
30use gam_linalg::faer_ndarray::{fast_ab, fast_abt, fast_atb};
31
32pub trait CompiledBlockMap {
41 fn raw_from_compiled(&self) -> &Array2<f64>;
43 fn raw_block_ranges(&self) -> &[std::ops::Range<usize>];
45 fn compiled_block_ranges(&self) -> &[std::ops::Range<usize>];
48}
49
/// Affine map from reduced (gauge-fixed) coefficients θ to raw coefficients β.
///
/// Lifting computes `β = t_full · θ + affine_shift` over the concatenated
/// blocks; the two offset vectors record where each block starts on the raw
/// and on the reduced side.
#[derive(Debug, Clone)]
pub struct Gauge {
    /// Dense raw-from-reduced transform. The constructors build it with shape
    /// `(raw_total, reduced_total)`.
    pub t_full: Array2<f64>,
    /// Constant added to `t_full · θ` when lifting; length `raw_total`.
    pub affine_shift: Array1<f64>,
    /// Cumulative raw offsets: entry 0 is 0 and raw block `b` spans
    /// `block_starts_raw[b]..block_starts_raw[b + 1]`.
    pub block_starts_raw: Vec<usize>,
    /// Cumulative reduced offsets, laid out like `block_starts_raw`.
    pub block_starts_reduced: Vec<usize>,
}
64
/// Converts block widths into cumulative offsets `[0, w0, w0 + w1, …]`.
///
/// The result always has `widths.len() + 1` entries and begins with 0, so
/// block `b` occupies `starts[b]..starts[b + 1]`.
fn starts_from_widths(widths: &[usize]) -> Vec<usize> {
    let mut running = 0usize;
    std::iter::once(0)
        .chain(widths.iter().map(|&width| {
            running += width;
            running
        }))
        .collect()
}
73
74pub fn assemble_block_triangular_t(
84 v_per_term: &[Array2<f64>],
85 r_per_term: &[Option<Array2<f64>>],
86) -> Array2<f64> {
87 assert_eq!(
88 v_per_term.len(),
89 r_per_term.len(),
90 "assemble_block_triangular_t: v_per_term len {} != r_per_term len {}",
91 v_per_term.len(),
92 r_per_term.len(),
93 );
94 let raw_widths: Vec<usize> = v_per_term.iter().map(|v| v.nrows()).collect();
95 let kept_widths: Vec<usize> = v_per_term.iter().map(|v| v.ncols()).collect();
96 let row_offsets = starts_from_widths(&raw_widths);
97 let col_offsets = starts_from_widths(&kept_widths);
98 let total_rows = row_offsets.last().copied().unwrap_or(0);
99 let total_cols = col_offsets.last().copied().unwrap_or(0);
100 let mut t = Array2::<f64>::zeros((total_rows, total_cols));
101 for (b, v) in v_per_term.iter().enumerate() {
103 let r = v.nrows();
104 let c = v.ncols();
105 if r > 0 && c > 0 {
106 t.slice_mut(ndarray::s![
107 row_offsets[b]..row_offsets[b] + r,
108 col_offsets[b]..col_offsets[b] + c
109 ])
110 .assign(v);
111 }
112 }
113 for b in 1..v_per_term.len() {
116 let Some(r_stack) = r_per_term[b].as_ref() else {
117 continue;
118 };
119 let kept_b = kept_widths[b];
120 assert_eq!(
121 r_stack.ncols(),
122 kept_b,
123 "assemble_block_triangular_t: r_per_term[{b}] has {} cols, expected {}",
124 r_stack.ncols(),
125 kept_b,
126 );
127 let expected_rows: usize = raw_widths.iter().take(b).sum();
128 assert_eq!(
129 r_stack.nrows(),
130 expected_rows,
131 "assemble_block_triangular_t: r_per_term[{b}] has {} rows, expected {} \
132 (sum of raw_widths[0..{}])",
133 r_stack.nrows(),
134 expected_rows,
135 b,
136 );
137 let mut local_row = 0usize;
138 for a in 0..b {
139 let r_a = raw_widths[a];
140 if r_a == 0 || kept_b == 0 {
141 local_row += r_a;
142 continue;
143 }
144 let block = r_stack.slice(ndarray::s![local_row..local_row + r_a, ..]);
145 let mut dst = t.slice_mut(ndarray::s![
146 row_offsets[a]..row_offsets[a] + r_a,
147 col_offsets[b]..col_offsets[b] + kept_b
148 ]);
149 for i in 0..r_a {
150 for j in 0..kept_b {
151 dst[[i, j]] = -block[[i, j]];
152 }
153 }
154 local_row += r_a;
155 }
156 }
157 t
158}
159
160impl Gauge {
161 pub fn identity(raw_widths: &[usize]) -> Self {
163 let transforms: Vec<Array2<f64>> =
164 raw_widths.iter().map(|&w| Array2::<f64>::eye(w)).collect();
165 Self::from_block_transforms(&transforms)
166 }
167
168 pub fn from_block_transforms(transforms: &[Array2<f64>]) -> Self {
172 let raw_total: usize = transforms.iter().map(|t| t.nrows()).sum();
173 Self::from_block_transforms_with_shift(transforms, Array1::zeros(raw_total))
174 }
175
176 pub fn from_block_transforms_with_shift(
179 transforms: &[Array2<f64>],
180 affine_shift: Array1<f64>,
181 ) -> Self {
182 let r_none: Vec<Option<Array2<f64>>> = transforms.iter().map(|_| None).collect();
183 let mut gauge = Self::from_v_and_r(transforms, &r_none);
184 assert_eq!(
185 affine_shift.len(),
186 gauge.raw_total(),
187 "Gauge::from_block_transforms_with_shift: affine shift len {} != raw width {}",
188 affine_shift.len(),
189 gauge.raw_total(),
190 );
191 gauge.affine_shift = affine_shift;
192 gauge
193 }
194
195 pub fn from_block_transform_with_shift(
197 transform: Array2<f64>,
198 affine_shift: Array1<f64>,
199 ) -> Self {
200 Self::from_block_transforms_with_shift(&[transform], affine_shift)
201 }
202
203 pub fn from_v_and_r(v_per_term: &[Array2<f64>], r_per_term: &[Option<Array2<f64>>]) -> Self {
207 let raw_widths: Vec<usize> = v_per_term.iter().map(|v| v.nrows()).collect();
208 let reduced_widths: Vec<usize> = v_per_term.iter().map(|v| v.ncols()).collect();
209 Self {
210 t_full: assemble_block_triangular_t(v_per_term, r_per_term),
211 affine_shift: Array1::zeros(raw_widths.iter().sum::<usize>()),
212 block_starts_raw: starts_from_widths(&raw_widths),
213 block_starts_reduced: starts_from_widths(&reduced_widths),
214 }
215 }
216
217 pub fn sum_to_zero(z: Array2<f64>) -> Self {
239 let (k, r) = z.dim();
240 assert!(
241 k > 0 && r < k,
242 "Gauge::sum_to_zero: z must be a tall reparametrisation ({k}×{r}); \
243 a centring section removes at least one direction (r < k)",
244 );
245 Self::from_block_transforms(&[z])
246 }
247
248 pub fn from_t(t_full: Array2<f64>, raw_widths: &[usize], reduced_widths: &[usize]) -> Self {
251 let total_raw: usize = raw_widths.iter().sum();
252 Self::from_t_with_shift(t_full, raw_widths, reduced_widths, Array1::zeros(total_raw))
253 }
254
255 pub fn from_t_with_shift(
258 t_full: Array2<f64>,
259 raw_widths: &[usize],
260 reduced_widths: &[usize],
261 affine_shift: Array1<f64>,
262 ) -> Self {
263 assert_eq!(
264 raw_widths.len(),
265 reduced_widths.len(),
266 "Gauge::from_t: raw_widths len {} != reduced_widths len {}",
267 raw_widths.len(),
268 reduced_widths.len(),
269 );
270 let total_raw: usize = raw_widths.iter().sum();
271 let total_reduced: usize = reduced_widths.iter().sum();
272 assert_eq!(
273 t_full.dim(),
274 (total_raw, total_reduced),
275 "Gauge::from_t: T has shape {:?}, expected ({total_raw}, {total_reduced})",
276 t_full.dim(),
277 );
278 assert_eq!(
279 affine_shift.len(),
280 total_raw,
281 "Gauge::from_t_with_shift: affine shift len {} != raw width {total_raw}",
282 affine_shift.len(),
283 );
284 Self {
285 t_full,
286 affine_shift,
287 block_starts_raw: starts_from_widths(raw_widths),
288 block_starts_reduced: starts_from_widths(reduced_widths),
289 }
290 }
291
292 pub fn from_compiled_map<M: CompiledBlockMap, O>(map: &M, ordering: &[O]) -> Self {
298 assert_eq!(
299 map.raw_block_ranges().len(),
300 map.compiled_block_ranges().len(),
301 "Gauge::from_compiled_map: CompiledMap raw_block_ranges len {} != \
302 compiled_block_ranges len {}",
303 map.raw_block_ranges().len(),
304 map.compiled_block_ranges().len(),
305 );
306 assert_eq!(
307 map.raw_block_ranges().len(),
308 ordering.len(),
309 "Gauge::from_compiled_map: ordering len {} != block count {}",
310 ordering.len(),
311 map.raw_block_ranges().len(),
312 );
313 let mut block_starts_raw = Vec::with_capacity(map.raw_block_ranges().len() + 1);
314 block_starts_raw.push(0);
315 for r in map.raw_block_ranges() {
316 block_starts_raw.push(r.end);
317 }
318 let mut block_starts_reduced = Vec::with_capacity(map.compiled_block_ranges().len() + 1);
319 block_starts_reduced.push(0);
320 for r in map.compiled_block_ranges() {
321 block_starts_reduced.push(r.end);
322 }
323 let total_raw = block_starts_raw.last().copied().unwrap_or(0);
324 Self {
325 t_full: map.raw_from_compiled().clone(),
326 affine_shift: Array1::zeros(total_raw),
327 block_starts_raw,
328 block_starts_reduced,
329 }
330 }
331
332 pub fn n_blocks(&self) -> usize {
334 self.block_starts_raw.len().saturating_sub(1)
335 }
336
337 pub fn raw_total(&self) -> usize {
339 self.block_starts_raw.last().copied().unwrap_or(0)
340 }
341
342 pub fn reduced_total(&self) -> usize {
344 self.block_starts_reduced.last().copied().unwrap_or(0)
345 }
346
347 pub fn raw_widths(&self) -> Vec<usize> {
349 self.block_starts_raw
350 .windows(2)
351 .map(|w| w[1] - w[0])
352 .collect()
353 }
354
355 pub fn reduced_widths(&self) -> Vec<usize> {
357 self.block_starts_reduced
358 .windows(2)
359 .map(|w| w[1] - w[0])
360 .collect()
361 }
362
363 pub fn block_transform(&self, b: usize) -> Array2<f64> {
367 assert!(
368 b < self.n_blocks(),
369 "Gauge::block_transform: block {b} out of range {}",
370 self.n_blocks(),
371 );
372 self.t_full
373 .slice(ndarray::s![
374 self.block_starts_raw[b]..self.block_starts_raw[b + 1],
375 self.block_starts_reduced[b]..self.block_starts_reduced[b + 1]
376 ])
377 .to_owned()
378 }
379
380 pub fn restrict_design<S: Data<Elem = f64>>(
382 &self,
383 raw_design: &ArrayBase<S, Ix2>,
384 ) -> Array2<f64> {
385 let raw_total = self.raw_total();
386 assert_eq!(
387 raw_design.ncols(),
388 raw_total,
389 "Gauge::restrict_design: design has {} columns, expected raw width {raw_total}",
390 raw_design.ncols(),
391 );
392 if self.t_full_is_identity() {
400 return raw_design.to_owned();
401 }
402 fast_ab(raw_design, &self.t_full)
403 }
404
405 fn t_full_is_identity(&self) -> bool {
411 let (r, c) = self.t_full.dim();
412 if r != c {
413 return false;
414 }
415 self.t_full
416 .indexed_iter()
417 .all(|((i, j), &v)| v == if i == j { 1.0 } else { 0.0 })
418 }
419
420 pub fn restrict_design_and_offset<S: Data<Elem = f64>>(
423 &self,
424 raw_design: &ArrayBase<S, Ix2>,
425 raw_offset: &Array1<f64>,
426 ) -> (Array2<f64>, Array1<f64>) {
427 assert_eq!(
428 raw_design.nrows(),
429 raw_offset.len(),
430 "Gauge::restrict_design_and_offset: design rows {} != offset len {}",
431 raw_design.nrows(),
432 raw_offset.len(),
433 );
434 let reduced_design = self.restrict_design(raw_design);
435 let reduced_offset = raw_offset + &raw_design.dot(&self.affine_shift);
436 (reduced_design, reduced_offset)
437 }
438
439 pub fn restrict_penalty<S: Data<Elem = f64>>(
442 &self,
443 raw_penalty: &ArrayBase<S, Ix2>,
444 ) -> Array2<f64> {
445 let raw_total = self.raw_total();
446 assert_eq!(
447 raw_penalty.dim(),
448 (raw_total, raw_total),
449 "Gauge::restrict_penalty: matrix has shape {:?}, expected ({raw_total}, {raw_total})",
450 raw_penalty.dim(),
451 );
452 if self.t_full_is_identity() {
455 return raw_penalty.to_owned();
456 }
457 let t_s = fast_atb(&self.t_full, raw_penalty);
458 fast_ab(&t_s, &self.t_full)
459 }
460
461 pub fn extend_with_identity(&self, extra_raw_widths: &[usize]) -> Self {
466 let extra_total: usize = extra_raw_widths.iter().sum();
467 let raw_total = self.raw_total();
468 let reduced_total = self.reduced_total();
469 let mut t = Array2::<f64>::zeros((raw_total + extra_total, reduced_total + extra_total));
470 t.slice_mut(ndarray::s![0..raw_total, 0..reduced_total])
471 .assign(&self.t_full);
472 for k in 0..extra_total {
473 t[[raw_total + k, reduced_total + k]] = 1.0;
474 }
475 let mut block_starts_raw = self.block_starts_raw.clone();
476 let mut block_starts_reduced = self.block_starts_reduced.clone();
477 for &w in extra_raw_widths {
478 block_starts_raw.push(block_starts_raw.last().copied().unwrap() + w);
479 block_starts_reduced.push(block_starts_reduced.last().copied().unwrap() + w);
480 }
481 let mut affine_shift = Array1::<f64>::zeros(raw_total + extra_total);
482 affine_shift
483 .slice_mut(ndarray::s![0..raw_total])
484 .assign(&self.affine_shift);
485 Self {
486 t_full: t,
487 affine_shift,
488 block_starts_raw,
489 block_starts_reduced,
490 }
491 }
492
493 pub fn lift_block_betas(&self, reduced_block_betas: &[Array1<f64>]) -> Vec<Array1<f64>> {
497 let n_blocks = self.n_blocks();
498 assert_eq!(
499 reduced_block_betas.len(),
500 n_blocks,
501 "Gauge::lift_block_betas: got {} reduced block betas, expected {}",
502 reduced_block_betas.len(),
503 n_blocks,
504 );
505 for (b, beta) in reduced_block_betas.iter().enumerate() {
506 let expected = self.block_starts_reduced[b + 1] - self.block_starts_reduced[b];
507 assert_eq!(
508 beta.len(),
509 expected,
510 "Gauge::lift_block_betas: block {b} has β of len {}, expected reduced width {}",
511 beta.len(),
512 expected,
513 );
514 }
515 let mut theta_full = Array1::<f64>::zeros(self.reduced_total());
516 for (b, beta) in reduced_block_betas.iter().enumerate() {
517 let c0 = self.block_starts_reduced[b];
518 let c1 = self.block_starts_reduced[b + 1];
519 theta_full.slice_mut(ndarray::s![c0..c1]).assign(beta);
520 }
521 let beta_full = self.t_full.dot(&theta_full) + &self.affine_shift;
522 let mut out = Vec::with_capacity(n_blocks);
523 for b in 0..n_blocks {
524 let r0 = self.block_starts_raw[b];
525 let r1 = self.block_starts_raw[b + 1];
526 out.push(beta_full.slice(ndarray::s![r0..r1]).to_owned());
527 }
528 out
529 }
530
531 pub fn lift_covariance(&self, m_reduced: &Array2<f64>) -> Array2<f64> {
540 let total_reduced = self.reduced_total();
541 assert_eq!(
542 m_reduced.dim(),
543 (total_reduced, total_reduced),
544 "Gauge::lift_covariance: matrix has shape {:?}, expected ({total_reduced}, {total_reduced})",
545 m_reduced.dim(),
546 );
547 let t_m = fast_ab(&self.t_full, m_reduced);
548 let mut raw = fast_abt(&t_m, &self.t_full);
549 let n = raw.nrows();
550 for i in 0..n {
551 for j in (i + 1)..n {
552 let avg = 0.5 * (raw[[i, j]] + raw[[j, i]]);
553 raw[[i, j]] = avg;
554 raw[[j, i]] = avg;
555 }
556 }
557 raw
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity gauge: lifting β and Σ must reproduce the reduced inputs.
    #[test]
    fn identity_gauge_round_trips_betas_and_covariance() {
        let gauge = Gauge::identity(&[2, 3]);
        assert_eq!(gauge.n_blocks(), 2);
        assert_eq!(gauge.raw_total(), 5);
        assert_eq!(gauge.reduced_total(), 5);
        let theta = vec![
            Array1::from(vec![0.5, -0.25]),
            Array1::from(vec![1.0, 2.0, -3.0]),
        ];
        let raw = gauge.lift_block_betas(&theta);
        assert_eq!(raw[0].as_slice().unwrap(), &[0.5, -0.25]);
        assert_eq!(raw[1].as_slice().unwrap(), &[1.0, 2.0, -3.0]);

        // Entry (0, 3) couples the two blocks, exercising cross-block terms.
        let mut cov = Array2::<f64>::eye(5);
        cov[[0, 3]] = 0.4;
        cov[[3, 0]] = 0.4;
        let lifted = gauge.lift_covariance(&cov);
        for i in 0..5 {
            for j in 0..5 {
                assert!(
                    (lifted[[i, j]] - cov[[i, j]]).abs() < 1e-14,
                    "identity gauge must be a covariance no-op at ({i},{j})",
                );
            }
        }
    }

    /// An exact identity T takes the copy fast path in restrict_design and
    /// restrict_penalty; the test also checks it agrees bit-for-bit with the
    /// GEMM result.
    #[test]
    fn identity_section_short_circuits_restrict_bit_exactly() {
        let gauge = Gauge::identity(&[4]);
        assert!(gauge.t_full_is_identity());

        // Non-round entries, so any rounding difference would be visible.
        let raw_design = Array2::<f64>::from_shape_fn((7, 4), |(i, j)| {
            ((i as f64) * 0.3 - (j as f64) * 1.7).sin() * 1.000000001
        });
        let restricted = gauge.restrict_design(&raw_design);
        assert_eq!(restricted, raw_design);
        // The fast path must match X · I computed through the GEMM.
        let via_gemm = fast_ab(&raw_design, &gauge.t_full);
        assert_eq!(restricted, via_gemm);

        let raw_penalty = Array2::<f64>::from_shape_fn((4, 4), |(i, j)| {
            (i as f64 + 1.0) * (j as f64 + 2.0) * 0.111
        });
        let restricted_pen = gauge.restrict_penalty(&raw_penalty);
        assert_eq!(restricted_pen, raw_penalty);
        let pen_via_gemm = fast_ab(&fast_atb(&gauge.t_full, &raw_penalty), &gauge.t_full);
        assert_eq!(restricted_pen, pen_via_gemm);
    }

    /// One off-diagonal entry is enough to disable the identity fast path.
    #[test]
    fn non_identity_section_is_not_short_circuited() {
        let mut t = Array2::<f64>::eye(3);
        t[[0, 1]] = 0.5;
        let gauge = Gauge::from_t(t.clone(), &[3], &[3]);
        assert!(!gauge.t_full_is_identity());
        let raw = Array2::<f64>::from_shape_fn((5, 3), |(i, j)| i as f64 + j as f64 * 0.25);
        let restricted = gauge.restrict_design(&raw);
        assert_eq!(restricted, fast_ab(&raw, &t));
    }

    /// A non-square T is never treated as the identity, even though this z's
    /// top 2×2 block is I.
    #[test]
    fn rectangular_section_is_not_identity() {
        let z =
            Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, -1.0, -1.0]).unwrap();
        let gauge = Gauge::sum_to_zero(z);
        assert!(!gauge.t_full_is_identity());
    }

    /// β = Tθ + s; folding X·s into the offset must keep η unchanged, and the
    /// shift must not appear in the lifted covariance.
    #[test]
    fn affine_gauge_lifts_betas_and_restricts_offsets() {
        let t = Array2::from_shape_vec((3, 1), vec![2.0, -1.0, 0.5]).unwrap();
        let shift = Array1::from(vec![0.25, 1.5, -0.75]);
        let gauge = Gauge::from_block_transform_with_shift(t.clone(), shift.clone());
        let theta = Array1::from(vec![4.0]);

        let raw = gauge.lift_block_betas(&[theta.clone()]);
        let expected_raw = t.dot(&theta) + &shift;
        assert_eq!(raw[0], expected_raw);

        let x = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 2.0, -1.0, 3.0, 0.5]).unwrap();
        let offset = Array1::from(vec![0.1, -0.2]);
        let (x_reduced, offset_reduced) = gauge.restrict_design_and_offset(&x, &offset);
        assert_eq!(x_reduced, x.dot(&t));
        assert_eq!(offset_reduced, &offset + &x.dot(&shift));

        // The linear predictor must agree in raw and reduced coordinates.
        let eta_raw = x.dot(&expected_raw) + &offset;
        let eta_reduced = x_reduced.dot(&theta) + &offset_reduced;
        for i in 0..eta_raw.len() {
            assert!((eta_raw[i] - eta_reduced[i]).abs() < 1e-14);
        }

        // Covariance lift is T Σ Tᵀ only; the shift plays no part.
        let cov_reduced = Array2::from_elem((1, 1), 3.0);
        let lifted_cov = gauge.lift_covariance(&cov_reduced);
        let expected_cov = t.dot(&cov_reduced).dot(&t.t());
        assert_eq!(lifted_cov, expected_cov);
    }

    /// The lifted covariance must be bit-identical across shift magnitudes.
    /// The second half cross-checks the principle empirically: a sample
    /// covariance of shifted draws equals that of unshifted draws, within
    /// rounding.
    #[test]
    fn affine_shift_leaves_lifted_covariance_invariant() {
        let t =
            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.5, -1.0, 2.0, 0.3, -0.4, 1.5]).unwrap();
        let raw_widths = [4usize];
        let reduced_widths = [2usize];

        // Symmetric positive-definite 2×2 reduced covariance.
        let cov_reduced = Array2::from_shape_vec((2, 2), vec![2.0, -0.7, -0.7, 1.3]).unwrap();

        // Reference: the same T with a zero shift.
        let base =
            Gauge::from_t_with_shift(t.clone(), &raw_widths, &reduced_widths, Array1::zeros(4));
        let reference = base.lift_covariance(&cov_reduced);

        // Shift magnitudes from zero up to large values; results must not move.
        for &mag in &[0.0, 1e-7, 1.0, 1e3, 1e7] {
            let shift = Array1::from(vec![mag, -mag, 0.5 * mag, -2.0 * mag]);
            let gauge = Gauge::from_t_with_shift(t.clone(), &raw_widths, &reduced_widths, shift);
            let lifted = gauge.lift_covariance(&cov_reduced);
            for i in 0..4 {
                for j in 0..4 {
                    assert_eq!(
                        lifted[[i, j]],
                        reference[[i, j]],
                        "affine shift magnitude {mag} must not perturb the lifted covariance \
                         at ({i},{j}) — covariance is offset-invariant",
                    );
                }
            }
        }

        // Hand-rolled lower Cholesky factor L of cov_reduced (L Lᵀ = Σ). It
        // turns the fixed z draws below into correlated θ samples.
        let chol = {
            let l00 = cov_reduced[[0, 0]].sqrt();
            let l10 = cov_reduced[[1, 0]] / l00;
            let l11 = (cov_reduced[[1, 1]] - l10 * l10).sqrt();
            Array2::from_shape_vec((2, 2), vec![l00, 0.0, l10, l11]).unwrap()
        };
        // Fixed pseudo-draws, so the test is deterministic.
        let z_raw = [
            [1.2, -0.4],
            [-0.8, 0.9],
            [0.3, 1.7],
            [-1.5, -0.6],
            [0.6, -1.1],
            [-0.2, 0.3],
            [1.9, 0.2],
            [-1.4, -0.9],
        ];
        // Population (1/n) sample covariance of β = T (L z) + shift over z_raw.
        let sample_cov_for_shift = |shift: &Array1<f64>| -> Array2<f64> {
            let n = z_raw.len();
            let betas: Vec<Array1<f64>> = z_raw
                .iter()
                .map(|z| {
                    let theta = chol.dot(&Array1::from(vec![z[0], z[1]]));
                    t.dot(&theta) + shift
                })
                .collect();
            let mut mean = Array1::<f64>::zeros(4);
            for b in &betas {
                mean = &mean + b;
            }
            mean /= n as f64;
            let mut cov = Array2::<f64>::zeros((4, 4));
            for b in &betas {
                let c = b - &mean;
                for i in 0..4 {
                    for j in 0..4 {
                        cov[[i, j]] += c[i] * c[j] / n as f64;
                    }
                }
            }
            cov
        };
        let cov_small = sample_cov_for_shift(&Array1::zeros(4));
        let cov_big = sample_cov_for_shift(&Array1::from(vec![1e6, -1e6, 5e5, -2e6]));
        // The tolerance absorbs the cancellation error from subtracting ~1e6 means.
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (cov_small[[i, j]] - cov_big[[i, j]]).abs() < 1e-6,
                    "empirical sample covariance must be offset-invariant at ({i},{j}): \
                     small-shift {} vs big-shift {}",
                    cov_small[[i, j]],
                    cov_big[[i, j]],
                );
            }
        }
    }

    /// A block-diagonal gauge lifts each block independently, and
    /// block_transform recovers the inputs.
    #[test]
    fn block_diagonal_gauge_matches_per_block_lift() {
        // t0 keeps raw rows 0 and 2 and leaves raw row 1 at zero.
        let mut t0 = Array2::<f64>::zeros((3, 2));
        t0[[0, 0]] = 1.0;
        t0[[2, 1]] = 1.0;
        let t1 = Array2::<f64>::eye(2);
        let gauge = Gauge::from_block_transforms(&[t0.clone(), t1.clone()]);
        assert_eq!(gauge.raw_widths(), vec![3, 2]);
        assert_eq!(gauge.reduced_widths(), vec![2, 2]);

        let theta = vec![Array1::from(vec![1.5, -2.5]), Array1::from(vec![0.5, 4.0])];
        let raw = gauge.lift_block_betas(&theta);
        assert_eq!(raw[0].as_slice().unwrap(), &[1.5, 0.0, -2.5]);
        assert_eq!(raw[1].as_slice().unwrap(), &[0.5, 4.0]);

        assert_eq!(gauge.block_transform(0), t0);
        assert_eq!(gauge.block_transform(1), t1);
    }

    /// T = [[I₂, -r_ab], [0, v_b]], so β_a = θ_a − r_ab·θ_b and β_b = v_b·θ_b.
    #[test]
    fn triangular_gauge_applies_negative_r_off_diagonal() {
        let v_a = Array2::<f64>::eye(2);
        let mut v_b = Array2::<f64>::zeros((2, 1));
        v_b[[0, 0]] = 1.0;
        let mut r_ab = Array2::<f64>::zeros((2, 1));
        r_ab[[0, 0]] = 0.5;
        r_ab[[1, 0]] = -0.25;
        let gauge = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_ab)]);

        let theta = vec![Array1::from(vec![1.0, 2.0]), Array1::from(vec![4.0])];
        let raw = gauge.lift_block_betas(&theta);
        // 1.0 − 0.5·4.0
        assert!((raw[0][0] - (-1.0)).abs() < 1e-14);
        // 2.0 + 0.25·4.0
        assert!((raw[0][1] - 3.0).abs() < 1e-14);
        assert!((raw[1][0] - 4.0).abs() < 1e-14);
        // v_b's second row is zero.
        assert!((raw[1][1] - 0.0).abs() < 1e-14);
    }

    /// For Σ = θθᵀ, T Σ Tᵀ = (Tθ)(Tθ)ᵀ. With zero shift, this ties
    /// lift_covariance to lift_block_betas on a triangular gauge.
    #[test]
    fn covariance_lift_is_rank1_consistent_with_beta_lift() {
        let v_a = Array2::<f64>::eye(2);
        let mut v_b = Array2::<f64>::zeros((2, 1));
        v_b[[0, 0]] = 1.0;
        let mut r_ab = Array2::<f64>::zeros((2, 1));
        r_ab[[0, 0]] = 0.3;
        r_ab[[1, 0]] = 0.7;
        let gauge = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_ab)]);

        let theta = vec![Array1::from(vec![0.8, -1.2]), Array1::from(vec![2.0])];
        let raw = gauge.lift_block_betas(&theta);
        let beta_full: Vec<f64> = raw.iter().flat_map(|b| b.iter().copied()).collect();

        let theta_full = Array1::from(vec![0.8, -1.2, 2.0]);
        let cov_rank1 = {
            let n = theta_full.len();
            Array2::from_shape_fn((n, n), |(i, j)| theta_full[i] * theta_full[j])
        };
        let lifted = gauge.lift_covariance(&cov_rank1);
        assert_eq!(lifted.dim(), (4, 4));
        for i in 0..4 {
            for j in 0..4 {
                let expected = beta_full[i] * beta_full[j];
                assert!(
                    (lifted[[i, j]] - expected).abs() < 1e-12,
                    "rank-1 covariance lift must equal (Tθ)(Tθ)ᵀ at ({i},{j}): \
                     got {} expected {expected}",
                    lifted[[i, j]],
                );
            }
        }
    }

    /// Centring gauge with orthonormal, zero-sum columns z. Checks that β, η,
    /// covariance and penalty all transform through z.
    #[test]
    fn sum_to_zero_gauge_lifts_via_z_and_preserves_eta() {
        let s = 1.0 / 2.0_f64.sqrt();
        let s6 = 1.0 / 6.0_f64.sqrt();
        // Columns (1, −1, 0)/√2 and (1, 1, −2)/√6: mutually orthogonal, unit
        // norm, and each summing to zero.
        let mut z = Array2::<f64>::zeros((3, 2));
        z[[0, 0]] = s;
        z[[1, 0]] = -s;
        z[[2, 0]] = 0.0;
        z[[0, 1]] = s6;
        z[[1, 1]] = s6;
        z[[2, 1]] = -2.0 * s6;
        for j in 0..2 {
            // Sanity-check the fixture before testing the gauge.
            assert!(
                (z.column(j).sum()).abs() < 1e-14,
                "column {j} must sum to 0"
            );
            assert!(
                (z.column(j).dot(&z.column(j)) - 1.0).abs() < 1e-14,
                "column {j} must be unit norm"
            );
        }

        let gauge = Gauge::sum_to_zero(z.clone());
        assert_eq!(gauge.n_blocks(), 1);
        assert_eq!(gauge.raw_widths(), vec![3]);
        assert_eq!(gauge.reduced_widths(), vec![2]);
        assert_eq!(gauge.block_transform(0), z);

        // Any θ lifts to a β that sums to zero.
        let theta = Array1::from(vec![1.3, -0.7]);
        let raw = gauge.lift_block_betas(&[theta.clone()]);
        let expected_raw = z.dot(&theta);
        for i in 0..3 {
            assert!((raw[0][i] - expected_raw[i]).abs() < 1e-14);
        }
        assert!(raw[0].sum().abs() < 1e-14, "lifted β must be centred");

        // 4×3 raw design, row-major.
        let b = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, -1.0, 0.5, -0.5, 3.0, 2.0, 1.0, 1.0, -1.0, 0.0, 4.0,
            ],
        )
        .unwrap();
        let b_c = fast_ab(&b, &z);
        assert_eq!(gauge.restrict_design(&b), b_c);
        let eta_reduced = b_c.dot(&theta);
        let eta_raw = b.dot(&expected_raw);
        for i in 0..4 {
            assert!(
                (eta_reduced[i] - eta_raw[i]).abs() < 1e-13,
                "η must be invariant under the centring lift at row {i}",
            );
        }

        // Rank-1 check, as in the triangular case: z θθᵀ zᵀ = (zθ)(zθ)ᵀ.
        let cov_rank1 = Array2::from_shape_fn((2, 2), |(i, j)| theta[i] * theta[j]);
        let lifted = gauge.lift_covariance(&cov_rank1);
        assert_eq!(lifted.dim(), (3, 3));
        for i in 0..3 {
            for j in 0..3 {
                let expect = expected_raw[i] * expected_raw[j];
                assert!(
                    (lifted[[i, j]] - expect).abs() < 1e-13,
                    "centring covariance lift must equal (zθ)(zθ)ᵀ at ({i},{j})",
                );
            }
        }

        // The penalty restricts as zᵀ S z, bit-for-bit with the same GEMM order.
        let raw_penalty = Array2::from_shape_vec(
            (3, 3),
            vec![2.0, 0.5, 0.0, 0.5, 3.0, -0.25, 0.0, -0.25, 4.0],
        )
        .unwrap();
        let reduced_penalty = gauge.restrict_penalty(&raw_penalty);
        let expected_reduced_penalty = fast_ab(&fast_atb(&z, &raw_penalty), &z);
        assert_eq!(reduced_penalty, expected_reduced_penalty);
    }

    /// A square z (r == k) removes no direction, so sum_to_zero must panic.
    #[test]
    #[should_panic(expected = "removes at least one direction")]
    fn sum_to_zero_rejects_identity_section() {
        drop(Gauge::sum_to_zero(Array2::<f64>::eye(3)));
    }

    /// Appended identity blocks pass coefficients and covariance straight
    /// through, while the original block's transform is preserved.
    #[test]
    fn extend_with_identity_passes_extra_blocks_through() {
        // 2×1 transform: keeps raw row 0 and zeroes raw row 1.
        let mut t0 = Array2::<f64>::zeros((2, 1));
        t0[[0, 0]] = 1.0;
        let gauge = Gauge::from_block_transforms(&[t0]).extend_with_identity(&[2]);
        assert_eq!(gauge.n_blocks(), 2);
        assert_eq!(gauge.raw_total(), 4);
        assert_eq!(gauge.reduced_total(), 3);

        let theta = vec![Array1::from(vec![3.0]), Array1::from(vec![1.0, -1.0])];
        let raw = gauge.lift_block_betas(&theta);
        assert_eq!(raw[0].as_slice().unwrap(), &[3.0, 0.0]);
        assert_eq!(raw[1].as_slice().unwrap(), &[1.0, -1.0]);

        // Reduced coords 1 and 2 belong to the appended block; cov[1, 2]
        // couples them and must reappear at raw (2, 3).
        let mut cov = Array2::<f64>::eye(3);
        cov[[1, 2]] = 0.25;
        cov[[2, 1]] = 0.25;
        let lifted = gauge.lift_covariance(&cov);
        assert_eq!(lifted.dim(), (4, 4));
        assert!((lifted[[0, 0]] - 1.0).abs() < 1e-14);
        assert!(
            (lifted[[1, 1]] - 0.0).abs() < 1e-14,
            "dropped raw row has zero variance"
        );
        assert!((lifted[[2, 2]] - 1.0).abs() < 1e-14);
        assert!((lifted[[3, 3]] - 1.0).abs() < 1e-14);
        assert!((lifted[[2, 3]] - 0.25).abs() < 1e-14);
    }
}