1use ndarray::{Array1, Array2, ArrayView2};
88
89use gam_linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback};
90use gam_linalg::matrix::FactorizedSystem;
91use faer::Side;
92
/// Product of two second-order jets `[f, f′, f″]` and `[g, g′, g″]`.
///
/// The result follows the Leibniz rule: `[fg, f′g + fg′, f″g + 2f′g′ + fg″]`.
#[inline]
fn jet_mul(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [f, df, d2f] = a;
    let [g, dg, d2g] = b;
    let value = f * g;
    let slope = f * dg + df * g;
    let curvature = f * d2g + 2.0 * df * dg + d2f * g;
    [value, slope, curvature]
}
102
/// Quotient `a / b` of two second-order jets `[f, f′, f″]`.
///
/// The quotient `q` is obtained by differentiating `a = q·b` twice and solving for
/// `q`, `q′` and `q″` in turn. The caller must ensure `b[0] != 0`; otherwise the
/// result is non-finite.
#[inline]
fn jet_div(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [num, dnum, d2num] = a;
    let [den, dden, d2den] = b;
    let q = num / den;
    let dq = (dnum - q * dden) / den;
    let d2q = (d2num - q * d2den - 2.0 * dq * dden) / den;
    [q, dq, d2q]
}
111
/// Per-row inputs to [`RowExpectedJets::kappas`].
///
/// Holds the derivatives of the mean μ with respect to the linear predictor η, the
/// variance function and its μ-derivatives evaluated at this row's mean, and the
/// dispersion.
#[derive(Debug, Clone, Copy)]
pub struct RowExpectedJets {
    /// dμ/dη.
    pub mu1: f64,
    /// d²μ/dη².
    pub mu2: f64,
    /// d³μ/dη³.
    pub mu3: f64,
    /// Variance function V(μ). `kappas` requires it to be finite and strictly positive.
    pub var: f64,
    /// dV/dμ. `kappas` chains it with `mu1` to obtain dV/dη.
    pub dvar_dmu: f64,
    /// d²V/dμ².
    pub d2var_dmu2: f64,
    /// Dispersion φ. `kappas` requires it to be finite and strictly positive, and
    /// divides every contribution by it.
    pub dispersion: f64,
}
134
/// Per-row cumulant contributions produced by [`RowExpectedJets::kappas`].
///
/// Each design-level cumulant used by `lawley_epsilon` is a sum over rows of one of
/// these fields times a product of that row's design entries. For example,
/// κ_rs = Σᵢ `k2`ᵢ · x_ir · x_is.
#[derive(Debug, Clone, Copy)]
pub struct RowKappas {
    /// Row contribution to κ_rs. `lawley_epsilon` uses `-k2` as the Fisher weight.
    pub k2: f64,
    /// Row contribution to κ_rst.
    pub k3: f64,
    /// Row contribution to κ_rstu.
    pub k4: f64,
    /// η-derivative of `k2`, the row contribution to κ_rs^(t).
    pub k2_1: f64,
    /// Second η-derivative of `k2`, the row contribution to κ_rs^(tu).
    pub k2_11: f64,
    /// η-derivative of `k3`, the row contribution to κ_rst^(u).
    pub k3_1: f64,
}
151
152impl RowKappas {
153 pub fn weighted(self, w: f64) -> Self {
158 Self {
159 k2: self.k2 * w,
160 k3: self.k3 * w,
161 k4: self.k4 * w,
162 k2_1: self.k2_1 * w,
163 k2_11: self.k2_11 * w,
164 k3_1: self.k3_1 * w,
165 }
166 }
167}
168
169impl RowExpectedJets {
170 pub fn kappas(&self) -> Result<RowKappas, String> {
173 let phi = self.dispersion;
174 if !(phi.is_finite() && phi > 0.0) {
175 return Err(format!(
176 "RowExpectedJets::kappas: dispersion must be finite and positive; got {phi}"
177 ));
178 }
179 if !(self.var.is_finite() && self.var > 0.0) {
180 return Err(format!(
181 "RowExpectedJets::kappas: variance function must be finite and positive; got {}",
182 self.var
183 ));
184 }
185 let mu1_jet = [self.mu1, self.mu2, self.mu3];
187 let v_jet = [
188 self.var,
189 self.dvar_dmu * self.mu1,
190 self.d2var_dmu2 * self.mu1 * self.mu1 + self.dvar_dmu * self.mu2,
191 ];
192 let c = jet_div(mu1_jet, v_jet);
193 let u0 = jet_mul(mu1_jet, c);
194 let inv_phi = 1.0 / phi;
195 Ok(RowKappas {
196 k2: -u0[0] * inv_phi,
197 k2_1: -u0[1] * inv_phi,
198 k2_11: -u0[2] * inv_phi,
199 k3: -(u0[1] + self.mu1 * c[1]) * inv_phi,
200 k3_1: -(u0[2] + self.mu2 * c[1] + self.mu1 * c[2]) * inv_phi,
201 k4: -(u0[2] + self.mu2 * c[1] + 2.0 * self.mu1 * c[2]) * inv_phi,
202 })
203 }
204
205 pub fn gaussian_identity(dispersion: f64) -> Self {
207 Self {
208 mu1: 1.0,
209 mu2: 0.0,
210 mu3: 0.0,
211 var: 1.0,
212 dvar_dmu: 0.0,
213 d2var_dmu2: 0.0,
214 dispersion,
215 }
216 }
217
218 pub fn poisson_log(eta: f64) -> Self {
220 let mu = eta.exp();
221 Self {
222 mu1: mu,
223 mu2: mu,
224 mu3: mu,
225 var: mu,
226 dvar_dmu: 1.0,
227 d2var_dmu2: 0.0,
228 dispersion: 1.0,
229 }
230 }
231
232 pub fn binomial_logit(eta: f64) -> Self {
234 let mu = 1.0 / (1.0 + (-eta).exp());
235 let mu1 = mu * (1.0 - mu);
236 let mu2 = mu1 * (1.0 - 2.0 * mu);
237 let mu3 = mu2 * (1.0 - 2.0 * mu) - 2.0 * mu1 * mu1;
238 Self {
239 mu1,
240 mu2,
241 mu3,
242 var: mu1,
243 dvar_dmu: 1.0 - 2.0 * mu,
244 d2var_dmu2: -2.0,
245 dispersion: 1.0,
246 }
247 }
248
249 pub fn gamma_log(eta: f64, dispersion: f64) -> Self {
252 let mu = eta.exp();
253 Self {
254 mu1: mu,
255 mu2: mu,
256 mu3: mu,
257 var: mu * mu,
258 dvar_dmu: 2.0 * mu,
259 d2var_dmu2: 2.0,
260 dispersion,
261 }
262 }
263}
264
265pub fn lawley_epsilon(
274 x: ArrayView2<'_, f64>,
275 kappas: &[RowKappas],
276 penalty: Option<ArrayView2<'_, f64>>,
277) -> Result<f64, String> {
278 let n = x.nrows();
279 let k = x.ncols();
280 if n == 0 || k == 0 {
281 return Err(format!(
282 "lawley_epsilon: empty design ({n} rows, {k} columns)"
283 ));
284 }
285 if kappas.len() != n {
286 return Err(format!(
287 "lawley_epsilon: {} cumulant rows for {n} design rows",
288 kappas.len()
289 ));
290 }
291 let mut j_mat = Array2::<f64>::zeros((k, k));
293 for (i, row_kappas) in kappas.iter().enumerate() {
294 let weight = -row_kappas.k2;
295 if !weight.is_finite() {
296 return Err(format!(
297 "lawley_epsilon: non-finite Fisher weight at row {i}"
298 ));
299 }
300 for r in 0..k {
301 let xr = x[[i, r]] * weight;
302 for s in 0..k {
303 j_mat[[r, s]] += xr * x[[i, s]];
304 }
305 }
306 }
307 if let Some(s_pen) = penalty {
308 if s_pen.nrows() != k || s_pen.ncols() != k {
309 return Err(format!(
310 "lawley_epsilon: penalty is {}×{}, expected {k}×{k}",
311 s_pen.nrows(),
312 s_pen.ncols()
313 ));
314 }
315 j_mat += &s_pen;
316 }
317 let j_view = FaerArrayView::new(&j_mat);
318 let factor = factorize_symmetricwith_fallback(j_view.as_ref(), Side::Lower)
319 .map_err(|e| format!("lawley_epsilon: information factorization failed: {e:?}"))?;
320 let j_inv = FactorizedSystem::solvemulti(&factor, &Array2::<f64>::eye(k))?;
321
322 let e_pairs = x.dot(&j_inv).dot(&x.t());
324 let h = e_pairs.diag().to_owned();
325
326 let mut lambda4 = 0.0;
328 for (i, row_kappas) in kappas.iter().enumerate() {
329 let a_i = row_kappas.k4 / 4.0 - row_kappas.k3_1 + row_kappas.k2_11;
330 lambda4 += a_i * h[i] * h[i];
331 }
332
333 let k3: Array1<f64> = kappas.iter().map(|r| r.k3).collect();
336 let k21: Array1<f64> = kappas.iter().map(|r| r.k2_1).collect();
337 let mut lambda6 = 0.0;
338 for i in 0..n {
339 for j in 0..n {
340 let e_ij = e_pairs[[i, j]];
341 let cross = k3[i] * k3[j];
342 let mixed = -k3[i] * k21[j] + k21[i] * k21[j];
343 lambda6 -= e_ij * e_ij * e_ij * (cross / 6.0 + mixed)
344 + h[i] * h[j] * e_ij * (cross / 4.0 + mixed);
345 }
346 }
347
348 let epsilon = lambda4 - lambda6;
349 if !epsilon.is_finite() {
350 return Err(format!(
351 "lawley_epsilon: non-finite ε (λ₄={lambda4}, λ₆={lambda6})"
352 ));
353 }
354 Ok(epsilon)
355}
356
/// Row-count limit (2048) associated with the row-pair computations in this module.
///
/// NOTE(review): nothing visible here enforces it. `lawley_epsilon` accepts any
/// number of rows and its λ₆ sum visits all n² row pairs. Presumably callers compare
/// `n` against this constant before requesting a correction; confirm at the call sites.
pub const LAWLEY_PAIR_MATRIX_MAX_ROWS: usize = 2048;
362
363pub fn lawley_lr_mean_shift(
380 x: ArrayView2<'_, f64>,
381 kappas: &[RowKappas],
382 penalty: Option<ArrayView2<'_, f64>>,
383 tested: std::ops::Range<usize>,
384) -> Result<f64, String> {
385 let n = x.nrows();
386 let k = x.ncols();
387 if tested.start >= tested.end || tested.end > k {
388 return Err(format!(
389 "lawley_lr_mean_shift: tested block {}..{} out of range for {k} columns",
390 tested.start, tested.end
391 ));
392 }
393 let eps_full = lawley_epsilon(x, kappas, penalty)?;
395 let nuisance: Vec<usize> = (0..k).filter(|c| !tested.contains(c)).collect();
396 if nuisance.is_empty() {
397 return Ok(eps_full);
399 }
400 let m = nuisance.len();
401 let mut x_null = Array2::<f64>::zeros((n, m));
402 for (col_null, &col_full) in nuisance.iter().enumerate() {
403 for i in 0..n {
404 x_null[[i, col_null]] = x[[i, col_full]];
405 }
406 }
407 let penalty_null = penalty.map(|s_pen| {
408 let mut out = Array2::<f64>::zeros((m, m));
409 for (r_null, &r_full) in nuisance.iter().enumerate() {
410 for (c_null, &c_full) in nuisance.iter().enumerate() {
411 out[[r_null, c_null]] = s_pen[[r_full, c_full]];
412 }
413 }
414 out
415 });
416 let eps_null = lawley_epsilon(
417 x_null.view(),
418 kappas,
419 penalty_null.as_ref().map(|s_pen| s_pen.view()),
420 )?;
421 Ok(eps_full - eps_null)
422}
423
424pub fn lawley_lr_bartlett_factor(
429 x: ArrayView2<'_, f64>,
430 kappas: &[RowKappas],
431 penalty: Option<ArrayView2<'_, f64>>,
432 tested: std::ops::Range<usize>,
433 ref_df: f64,
434) -> Result<f64, String> {
435 if !(ref_df.is_finite() && ref_df > 0.0) {
436 return Err(format!(
437 "lawley_lr_bartlett_factor: reference df must be finite and positive; got {ref_df}"
438 ));
439 }
440 let shift = lawley_lr_mean_shift(x, kappas, penalty, tested)?;
441 let mean_w = ref_df + shift;
442 let factor = crate::inference::higher_order::bartlett_factor_from_mean(mean_w, ref_df)
443 .ok_or_else(|| {
444 format!(
445 "lawley_lr_bartlett_factor: degenerate mean {mean_w} (Δε = {shift}, d = {ref_df})"
446 )
447 })?;
448 if !(factor.is_finite() && factor > 0.0) {
449 return Err(format!(
450 "lawley_lr_bartlett_factor: degenerate factor {factor} (Δε = {shift}, d = {ref_df})"
451 ));
452 }
453 Ok(factor)
454}
455
/// One smoothing-parameter component of the total penalty, used by the
/// ρ̂-variation correction.
#[derive(Debug, Clone)]
pub struct RhoPenaltyComponent {
    /// A `k × k` matrix, where `k` is the number of design columns. The ρ̂-variation
    /// functions check this shape.
    ///
    /// Shifting this component's log smoothing parameter by `t` turns the total
    /// penalty into `penalty + (exp(t) − 1) · s_component`. This is therefore
    /// presumably the component already multiplied by its current λ (the tests
    /// build it as λ·S); confirm with callers.
    pub s_component: Array2<f64>,
}
468
/// Step in log smoothing parameter ρ used by the central finite differences in
/// `lawley_lr_mean_shift_with_rho_variation`.
///
/// A step `t` rescales a penalty component by `exp(t)`.
const RHO_VARIATION_STEP: f64 = 0.05;
479
480pub fn lawley_lr_mean_shift_with_rho_variation(
514 x: ArrayView2<'_, f64>,
515 kappas: &[RowKappas],
516 penalty: ArrayView2<'_, f64>,
517 tested: std::ops::Range<usize>,
518 components: &[RhoPenaltyComponent],
519 rho_cov: ArrayView2<'_, f64>,
520) -> Result<f64, String> {
521 let k = x.ncols();
522 let m = components.len();
523 if m == 0 {
524 return Err(
525 "lawley_lr_mean_shift_with_rho_variation: no smoothing-parameter components"
526 .to_string(),
527 );
528 }
529 if rho_cov.nrows() != m || rho_cov.ncols() != m {
530 return Err(format!(
531 "lawley_lr_mean_shift_with_rho_variation: rho_cov is {}×{}, expected {m}×{m}",
532 rho_cov.nrows(),
533 rho_cov.ncols()
534 ));
535 }
536 for b in 0..m {
537 for c in 0..m {
538 let v_bc = rho_cov[[b, c]];
539 if !v_bc.is_finite() {
540 return Err(format!(
541 "lawley_lr_mean_shift_with_rho_variation: rho_cov[{b},{c}] is not finite"
542 ));
543 }
544 let v_cb = rho_cov[[c, b]];
545 let tol = 1e-10 * (1.0 + v_bc.abs().max(v_cb.abs()));
546 if (v_bc - v_cb).abs() > tol {
547 return Err(format!(
548 "lawley_lr_mean_shift_with_rho_variation: rho_cov must be symmetric; \
549 entries [{b},{c}]={v_bc} and [{c},{b}]={v_cb} differ"
550 ));
551 }
552 }
553 }
554 if penalty.nrows() != k || penalty.ncols() != k {
555 return Err(format!(
556 "lawley_lr_mean_shift_with_rho_variation: penalty is {}×{}, expected {k}×{k}",
557 penalty.nrows(),
558 penalty.ncols()
559 ));
560 }
561 for (b, comp) in components.iter().enumerate() {
562 if comp.s_component.nrows() != k || comp.s_component.ncols() != k {
563 return Err(format!(
564 "lawley_lr_mean_shift_with_rho_variation: component {b} is {}×{}, expected {k}×{k}",
565 comp.s_component.nrows(),
566 comp.s_component.ncols()
567 ));
568 }
569 }
570
571 let conditional = lawley_lr_mean_shift(x, kappas, Some(penalty), tested.clone())?;
574
575 let shift_at = |steps: &[(usize, f64)]| -> Result<f64, String> {
579 let mut s = penalty.to_owned();
580 for &(b, t) in steps {
581 let scale = t.exp() - 1.0;
583 s.scaled_add(scale, &components[b].s_component);
584 }
585 lawley_lr_mean_shift(x, kappas, Some(s.view()), tested.clone())
586 };
587
588 let h = RHO_VARIATION_STEP;
589 let mut quad = 0.0; let base = conditional;
594 for b in 0..m {
595 let fp = shift_at(&[(b, h)])?;
596 let fm = shift_at(&[(b, -h)])?;
597 let hbb = (fp - 2.0 * base + fm) / (h * h);
598 if !hbb.is_finite() {
599 return Err(format!(
600 "lawley_lr_mean_shift_with_rho_variation: non-finite curvature H[{b},{b}]"
601 ));
602 }
603 quad += 0.5 * hbb * rho_cov[[b, b]];
604 for c in (b + 1)..m {
605 let fpp = shift_at(&[(b, h), (c, h)])?;
606 let fpm = shift_at(&[(b, h), (c, -h)])?;
607 let fmp = shift_at(&[(b, -h), (c, h)])?;
608 let fmm = shift_at(&[(b, -h), (c, -h)])?;
609 let hbc = (fpp - fpm - fmp + fmm) / (4.0 * h * h);
610 if !hbc.is_finite() {
611 return Err(format!(
612 "lawley_lr_mean_shift_with_rho_variation: non-finite curvature H[{b},{c}]"
613 ));
614 }
615 quad += hbc * rho_cov[[b, c]];
618 }
619 }
620
621 let total = conditional + quad;
622 if !total.is_finite() {
623 return Err(format!(
624 "lawley_lr_mean_shift_with_rho_variation: non-finite total shift \
625 (conditional={conditional}, rho-variation={quad})"
626 ));
627 }
628 Ok(total)
629}
630
631pub fn lawley_lr_bartlett_factor_with_rho_variation(
653 x: ArrayView2<'_, f64>,
654 kappas: &[RowKappas],
655 penalty: ArrayView2<'_, f64>,
656 tested: std::ops::Range<usize>,
657 components: &[RhoPenaltyComponent],
658 rho_cov: ArrayView2<'_, f64>,
659 ref_df: f64,
660) -> Result<f64, String> {
661 if !(ref_df.is_finite() && ref_df > 0.0) {
662 return Err(format!(
663 "lawley_lr_bartlett_factor_with_rho_variation: reference df must be finite and positive; got {ref_df}"
664 ));
665 }
666 let shift =
667 lawley_lr_mean_shift_with_rho_variation(x, kappas, penalty, tested, components, rho_cov)?;
668 let mean_w = ref_df + shift;
669 let factor = crate::inference::higher_order::bartlett_factor_from_mean(mean_w, ref_df)
670 .ok_or_else(|| {
671 format!(
672 "lawley_lr_bartlett_factor_with_rho_variation: degenerate mean {mean_w} \
673 (Δε(ρ̂) = {shift}, d = {ref_df})"
674 )
675 })?;
676 if !(factor.is_finite() && factor > 0.0) {
677 return Err(format!(
678 "lawley_lr_bartlett_factor_with_rho_variation: degenerate factor {factor} \
679 (Δε(ρ̂) = {shift}, d = {ref_df})"
680 ));
681 }
682 Ok(factor)
683}
684
685pub fn known_scale_expected_jets(
690 family: &gam_spec::LikelihoodSpec,
691 eta: f64,
692) -> Option<RowExpectedJets> {
693 known_scale_expected_jets_with_dispersion(family, eta, 1.0)
694}
695
696pub fn known_scale_expected_jets_with_dispersion(
704 family: &gam_spec::LikelihoodSpec,
705 eta: f64,
706 dispersion: f64,
707) -> Option<RowExpectedJets> {
708 use gam_spec::{InverseLink, ResponseFamily, StandardLink};
709 match (&family.response, &family.link) {
710 (ResponseFamily::Poisson, InverseLink::Standard(StandardLink::Log)) => {
711 Some(RowExpectedJets::poisson_log(eta))
712 }
713 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
714 Some(RowExpectedJets::binomial_logit(eta))
715 }
716 (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity)) => {
717 (dispersion.is_finite() && dispersion > 0.0)
718 .then(|| RowExpectedJets::gaussian_identity(dispersion))
719 }
720 (ResponseFamily::Gamma, InverseLink::Standard(StandardLink::Log)) => {
721 (dispersion.is_finite() && dispersion > 0.0)
722 .then(|| RowExpectedJets::gamma_log(eta, dispersion))
723 }
724 _ => None,
725 }
726}
727
728#[cfg(test)]
729mod tests {
730 use super::*;
731 use ndarray::Array2;
732
733 fn lawley_epsilon_index_oracle(
737 x: &Array2<f64>,
738 kappas: &[RowKappas],
739 penalty: Option<&Array2<f64>>,
740 ) -> f64 {
741 let n = x.nrows();
742 let k = x.ncols();
743 let mut kappa2 = Array2::<f64>::zeros((k, k));
745 for i in 0..n {
746 for r in 0..k {
747 for s in 0..k {
748 kappa2[[r, s]] += kappas[i].k2 * x[[i, r]] * x[[i, s]];
749 }
750 }
751 }
752 if let Some(s_pen) = penalty {
753 kappa2 -= s_pen;
754 }
755 let j_view = FaerArrayView::new(&kappa2);
757 let factor = factorize_symmetricwith_fallback(j_view.as_ref(), faer::Side::Lower)
758 .expect("oracle κ_rs factorization");
759 let kappa_up = FactorizedSystem::solvemulti(&factor, &Array2::<f64>::eye(k))
760 .expect("oracle κ_rs inverse");
761
762 let arr3 = |weights: &dyn Fn(usize) -> f64, r: usize, s: usize, t: usize| -> f64 {
764 (0..n)
765 .map(|i| weights(i) * x[[i, r]] * x[[i, s]] * x[[i, t]])
766 .sum()
767 };
768 let arr4 =
769 |weights: &dyn Fn(usize) -> f64, r: usize, s: usize, t: usize, u: usize| -> f64 {
770 (0..n)
771 .map(|i| weights(i) * x[[i, r]] * x[[i, s]] * x[[i, t]] * x[[i, u]])
772 .sum()
773 };
774 let w_k3 = |i: usize| kappas[i].k3;
775 let w_k21 = |i: usize| kappas[i].k2_1;
776 let w_k4 = |i: usize| kappas[i].k4;
777 let w_k31 = |i: usize| kappas[i].k3_1;
778 let w_k211 = |i: usize| kappas[i].k2_11;
779
780 let mut lambda4 = 0.0;
781 for r in 0..k {
782 for s in 0..k {
783 for t in 0..k {
784 for u in 0..k {
785 let braces = arr4(&w_k4, r, s, t, u) / 4.0 - arr4(&w_k31, r, s, t, u)
786 + arr4(&w_k211, r, t, s, u);
787 lambda4 += kappa_up[[r, s]] * kappa_up[[t, u]] * braces;
788 }
789 }
790 }
791 }
792 let mut lambda6 = 0.0;
793 for r in 0..k {
794 for s in 0..k {
795 for t in 0..k {
796 for u in 0..k {
797 for v in 0..k {
798 for w in 0..k {
799 let braces = arr3(&w_k3, r, t, v)
800 * (arr3(&w_k3, s, u, w) / 6.0 - arr3(&w_k21, s, w, u))
801 + arr3(&w_k3, r, t, u)
802 * (arr3(&w_k3, s, v, w) / 4.0 - arr3(&w_k21, s, w, v))
803 + arr3(&w_k21, r, t, v) * arr3(&w_k21, s, w, u)
804 + arr3(&w_k21, r, t, u) * arr3(&w_k21, s, w, v);
805 lambda6 +=
806 kappa_up[[r, s]] * kappa_up[[t, u]] * kappa_up[[v, w]] * braces;
807 }
808 }
809 }
810 }
811 }
812 }
813 lambda4 - lambda6
814 }
815
816 fn intercept_design(n: usize) -> Array2<f64> {
817 Array2::<f64>::ones((n, 1))
818 }
819
820 fn digamma_integer(n: usize) -> f64 {
822 const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
823 -EULER_GAMMA + (1..n).map(|j| 1.0 / j as f64).sum::<f64>()
824 }
825
826 #[test]
827 fn exponential_intercept_matches_exact_digamma_expansion() {
828 let eta = 0.4;
831 let mut residual_prev = f64::INFINITY;
832 for &n in &[8usize, 16, 32] {
833 let jets = RowExpectedJets::gamma_log(eta, 1.0);
834 let kappas = vec![jets.kappas().expect("exponential kappas"); n];
835 let x = intercept_design(n);
836 let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");
837 let analytic = 1.0 / (6.0 * n as f64);
838 assert!(
839 (eps - analytic).abs() < 1e-12,
840 "n={n}: ε={eps} vs analytic 1/(6n)={analytic}"
841 );
842 let exact_mean = 2.0 * n as f64 * ((n as f64).ln() - digamma_integer(n));
843 let residual = (exact_mean - 1.0 - eps).abs();
844 assert!(
845 residual < 0.6 / (n * n) as f64,
846 "n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
847 );
848 assert!(
849 residual < residual_prev,
850 "n={n}: residual {residual} did not shrink from {residual_prev}"
851 );
852 residual_prev = residual;
853 }
854 }
855
856 #[test]
870 fn penalty_shift_term_is_consumed() {
871 let n = 40usize;
874 let eta = 0.2_f64;
875 let jets = RowExpectedJets::poisson_log(eta);
876 let kappas = vec![jets.kappas().expect("poisson kappas"); n];
877 let mut x = Array2::<f64>::ones((n, 2));
878 for i in 0..n {
879 x[[i, 1]] = (i as f64) / (n as f64) - 0.5;
881 }
882 let eps_unpen = lawley_epsilon(x.view(), &kappas, None).expect("ε unpenalized");
883
884 let mut distinct = std::collections::BTreeSet::new();
887 for &lambda in &[0.5_f64, 2.0, 8.0, 32.0] {
888 let mut s = Array2::<f64>::zeros((2, 2));
889 s[[1, 1]] = lambda;
890 let eps_pen = lawley_epsilon(x.view(), &kappas, Some(s.view())).expect("ε penalized");
891 assert!(
893 (eps_pen - eps_unpen).abs() > 1e-9,
894 "λ={lambda}: penalty did not move ε ({eps_pen} vs {eps_unpen}) — S is being dropped"
895 );
896 assert!(
898 eps_pen.is_finite(),
899 "λ={lambda}: ε must be finite, got {eps_pen}"
900 );
901 distinct.insert((eps_pen * 1e9) as i64);
902 }
903 assert!(
906 distinct.len() >= 3,
907 "ε must vary with λ; got {} distinct values",
908 distinct.len()
909 );
910 }
911
912 #[test]
913 fn poisson_intercept_matches_exact_pmf_mean() {
914 let lambda: f64 = 1.7;
917 for &n in &[20usize, 40] {
918 let jets = RowExpectedJets::poisson_log(lambda.ln());
919 let kappas = vec![jets.kappas().expect("poisson kappas"); n];
920 let x = intercept_design(n);
921 let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");
922 let analytic = 1.0 / (6.0 * n as f64 * lambda);
923 assert!(
924 (eps - analytic).abs() < 1e-12,
925 "n={n}: ε={eps} vs analytic 1/(6nλ)={analytic}"
926 );
927 let total_rate = n as f64 * lambda;
928 let mut pmf = (-total_rate).exp();
929 let mut exact_mean = 0.0;
930 let s_max = (total_rate + 60.0 * total_rate.sqrt()).ceil() as usize;
931 for s in 0..=s_max {
932 if s > 0 {
933 pmf *= total_rate / s as f64;
934 }
935 let s_f = s as f64;
936 let w = if s == 0 {
937 2.0 * total_rate
938 } else {
939 2.0 * (total_rate - s_f + s_f * (s_f / total_rate).ln())
940 };
941 exact_mean += pmf * w;
942 }
943 let residual = (exact_mean - 1.0 - eps).abs();
944 assert!(
945 residual < 0.7 / (n * n) as f64,
946 "n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
947 );
948 }
949 }
950
951 #[test]
959 fn gaussian_known_variance_lr_factor_is_exactly_one() {
960 let n = 20;
961 let k = 3;
962 let mut x = Array2::<f64>::zeros((n, k));
963 for i in 0..n {
964 let z = i as f64 / n as f64;
965 x[[i, 0]] = 1.0;
966 x[[i, 1]] = (5.0 * z).sin();
967 x[[i, 2]] = z - 0.5;
968 }
969 let kappas = vec![
970 RowExpectedJets::gaussian_identity(1.7)
971 .kappas()
972 .expect("gaussian kappas");
973 n
974 ];
975 let s_pen = Array2::<f64>::eye(k) * 0.4;
976 for q in [1usize, 2] {
977 let shift = lawley_lr_mean_shift(x.view(), &kappas, Some(s_pen.view()), k - q..k)
978 .expect("shift");
979 assert!(
980 shift.abs() < 1e-13,
981 "Gaussian known-variance Δε must be 0; got {shift}"
982 );
983 let c = lawley_lr_bartlett_factor(
984 x.view(),
985 &kappas,
986 Some(s_pen.view()),
987 k - q..k,
988 q as f64,
989 )
990 .expect("factor");
991 assert!(
992 (c - 1.0).abs() < 1e-13,
993 "Gaussian known-variance Bartlett factor must be exactly 1; got {c}"
994 );
995 }
996 }
997
998 #[test]
1013 fn exponential_rate_lr_factor_is_one_plus_one_sixth_n() {
1014 let eta = -0.7; for &n in &[8usize, 16, 32] {
1016 let jets = RowExpectedJets::gamma_log(eta, 1.0);
1017 let kappas = vec![jets.kappas().expect("exponential kappas"); n];
1018 let x = intercept_design(n);
1019 let c = lawley_lr_bartlett_factor(x.view(), &kappas, None, 0..1, 1.0).expect("factor");
1020 let analytic = 1.0 + 1.0 / (6.0 * n as f64);
1021 assert!(
1022 (c - analytic).abs() < 1e-12,
1023 "n={n}: factor {c} vs analytic 1 + 1/(6n) = {analytic}"
1024 );
1025 let exact_mean = 2.0 * n as f64 * ((n as f64).ln() - digamma_integer(n));
1026 assert!(
1027 (exact_mean - c).abs() < 0.6 / (n * n) as f64,
1028 "n={n}: |E[W] − c| = {} is not O(n⁻²)",
1029 (exact_mean - c).abs()
1030 );
1031 }
1032 }
1033
1034 #[test]
1054 fn bernoulli_logit_intercept_factor_matches_exact_pmf_mean() {
1055 let mu: f64 = 0.3;
1056 let u = mu * (1.0 - mu);
1057 let eta = (mu / (1.0 - mu)).ln();
1058 let mut residual_prev = f64::INFINITY;
1059 for &n in &[24usize, 48, 96] {
1060 let jets = RowExpectedJets::binomial_logit(eta);
1061 let kappas = vec![jets.kappas().expect("bernoulli kappas"); n];
1062 let x = intercept_design(n);
1063 let shift = lawley_lr_mean_shift(x.view(), &kappas, None, 0..1).expect("Δε");
1064 let analytic = (1.0 - u) / (6.0 * n as f64 * u);
1065 assert!(
1066 (shift - analytic).abs() < 1e-12,
1067 "n={n}: Δε = {shift} vs analytic (1−u)/(6nu) = {analytic}"
1068 );
1069 let c = lawley_lr_bartlett_factor(x.view(), &kappas, None, 0..1, 1.0).expect("factor");
1070 assert!(
1071 (c - (1.0 + analytic)).abs() < 1e-12,
1072 "n={n}: factor {c} vs 1 + ε = {}",
1073 1.0 + analytic
1074 );
1075 let nf = n as f64;
1077 let mut pmf = (1.0 - mu).powi(n as i32); let mut exact_mean = 0.0;
1079 for s in 0..=n {
1080 if s > 0 {
1081 pmf *= mu / (1.0 - mu) * (n - s + 1) as f64 / s as f64;
1082 }
1083 let s_f = s as f64;
1084 let t1 = if s == 0 {
1085 0.0
1086 } else {
1087 s_f * (s_f / (nf * mu)).ln()
1088 };
1089 let t2 = if s == n {
1090 0.0
1091 } else {
1092 (nf - s_f) * ((nf - s_f) / (nf * (1.0 - mu))).ln()
1093 };
1094 exact_mean += pmf * 2.0 * (t1 + t2);
1095 }
1096 let residual = (exact_mean - 1.0 - shift).abs();
1097 assert!(
1098 residual < 2.5 / (n * n) as f64,
1099 "n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
1100 );
1101 assert!(
1102 residual < residual_prev,
1103 "n={n}: residual {residual} did not shrink from {residual_prev}"
1104 );
1105 residual_prev = residual;
1106 }
1107 }
1108
1109 #[test]
1112 fn mean_shift_is_full_minus_nuisance_epsilon() {
1113 let n = 19;
1114 let mut x = Array2::<f64>::zeros((n, 2));
1115 let mut kappas = Vec::with_capacity(n);
1116 for i in 0..n {
1117 let z = i as f64 / n as f64;
1118 x[[i, 0]] = 1.0;
1119 x[[i, 1]] = z - 0.5;
1120 let eta = 0.3 - 0.8 * (z - 0.5);
1121 kappas.push(
1122 RowExpectedJets::binomial_logit(eta)
1123 .kappas()
1124 .expect("binomial kappas"),
1125 );
1126 }
1127 let mut s_pen = Array2::<f64>::zeros((2, 2));
1128 s_pen[[1, 1]] = 0.6;
1129 let shift =
1130 lawley_lr_mean_shift(x.view(), &kappas, Some(s_pen.view()), 1..2).expect("shift");
1131 let eps_full = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("ε_full");
1132 let x_null = x.slice(ndarray::s![.., 0..1]).to_owned();
1133 let s_null = s_pen.slice(ndarray::s![0..1, 0..1]).to_owned();
1134 let eps_null = lawley_epsilon(x_null.view(), &kappas, Some(s_null.view())).expect("ε_null");
1135 assert!(
1136 (shift - (eps_full - eps_null)).abs() < 1e-14,
1137 "Δε = {shift} must equal ε_full − ε_null = {}",
1138 eps_full - eps_null
1139 );
1140 let kappas_w: Vec<RowKappas> = kappas.iter().map(|r| r.weighted(2.0)).collect();
1143 let mut x2 = Array2::<f64>::zeros((2 * n, 2));
1144 let mut kappas2 = Vec::with_capacity(2 * n);
1145 for i in 0..n {
1146 for rep in 0..2 {
1147 let row = 2 * i + rep;
1148 x2[[row, 0]] = x[[i, 0]];
1149 x2[[row, 1]] = x[[i, 1]];
1150 kappas2.push(kappas[i]);
1151 }
1152 }
1153 let shift_w = lawley_lr_mean_shift(x.view(), &kappas_w, Some(s_pen.view()), 1..2)
1154 .expect("weighted shift");
1155 let shift_dup = lawley_lr_mean_shift(x2.view(), &kappas2, Some(s_pen.view()), 1..2)
1156 .expect("duplicated shift");
1157 assert!(
1158 (shift_w - shift_dup).abs() < 1e-12 * (1.0 + shift_dup.abs()),
1159 "weight-2 rows ({shift_w}) must equal duplicated rows ({shift_dup})"
1160 );
1161 }
1162
1163 #[test]
1164 fn row_pair_reduction_matches_index_oracle() {
1165 let n = 17;
1168 let k = 3;
1169 let mut x = Array2::<f64>::zeros((n, k));
1170 let mut kappas = Vec::with_capacity(n);
1171 for i in 0..n {
1172 let z = i as f64 / n as f64;
1173 x[[i, 0]] = 1.0;
1174 x[[i, 1]] = (7.3 * z).sin();
1175 x[[i, 2]] = z * z - 0.4;
1176 let eta = 0.2 + 0.5 * x[[i, 1]] - 0.3 * x[[i, 2]];
1177 kappas.push(
1178 RowExpectedJets::gamma_log(eta, 1.3)
1179 .kappas()
1180 .expect("gamma kappas"),
1181 );
1182 }
1183 let fast = lawley_epsilon(x.view(), &kappas, None).expect("hat form");
1184 let oracle = lawley_epsilon_index_oracle(&x, &kappas, None);
1185 assert!(
1186 (fast - oracle).abs() < 1e-10 * (1.0 + oracle.abs()),
1187 "row-pair ε={fast} vs index-form ε={oracle}"
1188 );
1189
1190 let mut s_pen = Array2::<f64>::eye(k);
1192 s_pen[[0, 0]] = 0.0; s_pen *= 0.8;
1194 let fast_pen = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("hat form");
1195 let oracle_pen = lawley_epsilon_index_oracle(&x, &kappas, Some(&s_pen));
1196 assert!(
1197 (fast_pen - oracle_pen).abs() < 1e-10 * (1.0 + oracle_pen.abs()),
1198 "penalized row-pair ε={fast_pen} vs index-form ε={oracle_pen}"
1199 );
1200 assert!(
1201 (fast_pen - fast).abs() > 1e-6,
1202 "penalty must move ε (got {fast} → {fast_pen})"
1203 );
1204 }
1205
1206 #[test]
1207 fn canonical_links_collapse_the_mixed_arrays() {
1208 for eta in [-1.3, 0.0, 0.7] {
1212 for jets in [
1213 RowExpectedJets::poisson_log(eta),
1214 RowExpectedJets::binomial_logit(eta),
1215 ] {
1216 let kappas = jets.kappas().expect("canonical kappas");
1217 assert!(
1218 (kappas.k3 - kappas.k2_1).abs() < 1e-13 * (1.0 + kappas.k3.abs()),
1219 "canonical link must satisfy κ₃ = κ₂' (η={eta}): {kappas:?}"
1220 );
1221 assert!(
1222 (kappas.k4 - kappas.k3_1).abs() < 1e-13 * (1.0 + kappas.k4.abs()),
1223 "canonical link must satisfy κ₄ = κ₃' (η={eta}): {kappas:?}"
1224 );
1225 }
1226 }
1227 }
1228
1229 #[test]
1230 fn gaussian_identity_needs_no_correction_even_penalized() {
1231 let n = 12;
1235 let jets = RowExpectedJets::gaussian_identity(2.3);
1236 let kappas = vec![jets.kappas().expect("gaussian kappas"); n];
1237 let mut x = Array2::<f64>::ones((n, 2));
1238 for i in 0..n {
1239 x[[i, 1]] = i as f64 - 5.0;
1240 }
1241 let s_pen = Array2::<f64>::eye(2) * 0.5;
1242 let eps = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("ε");
1243 assert!(
1244 eps.abs() < 1e-14,
1245 "Gaussian-identity ε must be 0; got {eps}"
1246 );
1247 }
1248
1249 #[test]
1254 fn rho_variation_correction_is_zero_for_gaussian() {
1255 let n = 16usize;
1256 let jets = RowExpectedJets::gaussian_identity(1.3);
1257 let kappas = vec![jets.kappas().expect("gaussian kappas"); n];
1258 let mut x = Array2::<f64>::ones((n, 2));
1259 for i in 0..n {
1260 x[[i, 1]] = i as f64 / n as f64 - 0.5;
1261 }
1262 let mut s_comp = Array2::<f64>::zeros((2, 2));
1264 s_comp[[1, 1]] = 2.0;
1265 let penalty = s_comp.clone();
1266 let components = vec![RhoPenaltyComponent {
1267 s_component: s_comp,
1268 }];
1269 let rho_cov = Array2::from_shape_vec((1, 1), vec![5.0]).unwrap();
1271 let total = lawley_lr_mean_shift_with_rho_variation(
1272 x.view(),
1273 &kappas,
1274 penalty.view(),
1275 1..2,
1276 &components,
1277 rho_cov.view(),
1278 )
1279 .expect("rho-variation shift");
1280 assert!(
1281 total.abs() < 1e-12,
1282 "Gaussian ρ̂-variation total shift must be 0; got {total}"
1283 );
1284 }
1285
1286 #[test]
1299 fn rho_variation_correction_matches_curvature_times_variance() {
1300 let n = 50usize;
1301 let mut x = Array2::<f64>::ones((n, 2));
1302 let mut kappas = Vec::with_capacity(n);
1303 for i in 0..n {
1304 let z = i as f64 / n as f64 - 0.5;
1305 x[[i, 1]] = z;
1306 let eta = 0.3 + 0.6 * z;
1307 kappas.push(
1308 RowExpectedJets::poisson_log(eta)
1309 .kappas()
1310 .expect("poisson kappas"),
1311 );
1312 }
1313 let lambda = 3.0_f64;
1315 let mut s_comp = Array2::<f64>::zeros((2, 2));
1316 s_comp[[1, 1]] = lambda;
1317 let penalty = s_comp.clone();
1318 let components = vec![RhoPenaltyComponent {
1319 s_component: s_comp.clone(),
1320 }];
1321 let tested = 1..2;
1322
1323 let conditional =
1325 lawley_lr_mean_shift(x.view(), &kappas, Some(penalty.view()), tested.clone())
1326 .expect("conditional shift");
1327
1328 let de_at = |t: f64| {
1332 let mut s = Array2::<f64>::zeros((2, 2));
1333 s[[1, 1]] = lambda * t.exp();
1334 lawley_lr_mean_shift(x.view(), &kappas, Some(s.view()), tested.clone())
1335 .expect("perturbed shift")
1336 };
1337 let h = 0.05_f64;
1338 let d2_h = (de_at(h) - 2.0 * conditional + de_at(-h)) / (h * h);
1339 let d2_2h = (de_at(2.0 * h) - 2.0 * conditional + de_at(-2.0 * h)) / (4.0 * h * h);
1340 let curvature = (4.0 * d2_h - d2_2h) / 3.0;
1342 assert!(
1343 curvature.abs() > 1e-9,
1344 "fixture must have non-zero ρ-curvature; got {curvature}"
1345 );
1346
1347 let var_rho = 0.8_f64; let rho_cov = Array2::from_shape_vec((1, 1), vec![var_rho]).unwrap();
1349 let total = lawley_lr_mean_shift_with_rho_variation(
1350 x.view(),
1351 &kappas,
1352 penalty.view(),
1353 tested.clone(),
1354 &components,
1355 rho_cov.view(),
1356 )
1357 .expect("rho-variation shift");
1358
1359 let expected = conditional + 0.5 * curvature * var_rho;
1361 assert!(
1362 (total - expected).abs() < 1e-6 * (1.0 + expected.abs()),
1363 "ρ̂-variation total {total} must equal conditional + ½ H Var = {expected} \
1364 (conditional={conditional}, H={curvature}, Var={var_rho})"
1365 );
1366 assert!(
1368 (total - conditional).abs() > 1e-9,
1369 "ρ̂-variation correction must be non-zero (H={curvature}, Var={var_rho}); \
1370 total={total} conditional={conditional}"
1371 );
1372
1373 let zero_cov = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
1375 let total_zero = lawley_lr_mean_shift_with_rho_variation(
1376 x.view(),
1377 &kappas,
1378 penalty.view(),
1379 tested.clone(),
1380 &components,
1381 zero_cov.view(),
1382 )
1383 .expect("zero-variance shift");
1384 assert!(
1385 (total_zero - conditional).abs() < 1e-12,
1386 "zero ρ-variance must recover the conditional shift: {total_zero} vs {conditional}"
1387 );
1388 }
1389
/// Checks how the estimated-λ Bartlett factor is assembled.
///
/// - It must fold the ρ̂-variation-corrected mean shift into
///   `c = 1 + Δε(ρ̂)/d`.
/// - On a Poisson/log fit it must differ from the fixed-λ factor.
/// - For Gaussian identity rows it must be exactly 1.
/// - It must reject a zero reference degrees of freedom.
#[test]
fn rho_variation_factor_folds_estimated_lambda_into_c() {
    let rows = 50usize;
    // Centred covariate on [-0.5, 0.5); column 0 stays the all-ones intercept.
    let covariate: Vec<f64> = (0..rows)
        .map(|i| i as f64 / rows as f64 - 0.5)
        .collect();
    let mut x = Array2::<f64>::ones((rows, 2));
    for (i, &z) in covariate.iter().enumerate() {
        x[[i, 1]] = z;
    }
    // Poisson/log rows evaluated at the linear predictor 0.3 + 0.6·z.
    let kappas: Vec<_> = covariate
        .iter()
        .map(|&z| {
            RowExpectedJets::poisson_log(0.3 + 0.6 * z)
                .kappas()
                .expect("poisson kappas")
        })
        .collect();

    // One smoothing block penalising only the slope, with λ = 3; the single
    // ρ-component is the whole penalty.
    let lambda = 3.0_f64;
    let mut penalty = Array2::<f64>::zeros((2, 2));
    penalty[[1, 1]] = lambda;
    let components = vec![RhoPenaltyComponent {
        s_component: penalty.clone(),
    }];
    let tested = 1..2;
    let ref_df = 1.0_f64;
    let var_rho = 0.8_f64;
    let rho_cov = Array2::from_shape_vec((1, 1), vec![var_rho]).unwrap();

    let total = lawley_lr_mean_shift_with_rho_variation(
        x.view(),
        &kappas,
        penalty.view(),
        tested.clone(),
        &components,
        rho_cov.view(),
    )
    .expect("total shift");
    let factor = lawley_lr_bartlett_factor_with_rho_variation(
        x.view(),
        &kappas,
        penalty.view(),
        tested.clone(),
        &components,
        rho_cov.view(),
        ref_df,
    )
    .expect("estimated-λ factor");
    let folded = 1.0 + total / ref_df;
    assert!(
        (factor - folded).abs() < 1e-12,
        "estimated-λ factor {factor} must equal 1 + Δε(ρ̂)/d = {}",
        folded
    );

    // The same fit with λ treated as known must give a different factor,
    // otherwise the ρ̂-variation term contributes nothing.
    let conditional_factor = lawley_lr_bartlett_factor(
        x.view(),
        &kappas,
        Some(penalty.view()),
        tested.clone(),
        ref_df,
    )
    .expect("conditional factor");
    assert!(
        (factor - conditional_factor).abs() > 1e-9,
        "estimated-λ factor {factor} must differ from the fixed-λ factor \
         {conditional_factor} (ρ̂-variation is load-bearing)"
    );

    // Gaussian identity rows: the factor is expected to be exactly 1 even
    // under a large ρ-variance.
    let gaussian = RowExpectedJets::gaussian_identity(1.3)
        .kappas()
        .expect("gaussian kappas");
    let g_kappas = vec![gaussian; rows];
    let big_cov = Array2::from_shape_vec((1, 1), vec![5.0]).unwrap();
    let g_factor = lawley_lr_bartlett_factor_with_rho_variation(
        x.view(),
        &g_kappas,
        penalty.view(),
        tested.clone(),
        &components,
        big_cov.view(),
        ref_df,
    )
    .expect("gaussian factor");
    assert!(
        (g_factor - 1.0).abs() < 1e-12,
        "Gaussian known-variance estimated-λ factor must be exactly 1; got {g_factor}"
    );

    // A zero reference degrees of freedom cannot define `1 + Δε/d`.
    let zero_df = lawley_lr_bartlett_factor_with_rho_variation(
        x.view(),
        &kappas,
        penalty.view(),
        tested.clone(),
        &components,
        rho_cov.view(),
        0.0,
    );
    assert!(zero_df.is_err());
}
1512
/// With two smoothing parameters, the ρ̂-variation mean shift must match the
/// second-order expansion `Δε + ½ Σᵢⱼ Hᵢⱼ Cov(ρ̂)ᵢⱼ`.
///
/// `H` is a central finite-difference Hessian of the conditional (fixed-λ)
/// mean shift with respect to log-λ offsets. Because the ρ-covariance is
/// symmetric, the off-diagonal contribution appears as a single `H₀₁·Cov₀₁`
/// term, and the test also checks that dropping it changes the result.
#[test]
fn rho_variation_includes_symmetric_cross_terms() {
    // Diagonal 3×3 penalty with `slope` on coefficient 1 and `quad` on coefficient 2.
    fn diag_penalty(slope: f64, quad: f64) -> Array2<f64> {
        let mut s = Array2::<f64>::zeros((3, 3));
        s[[1, 1]] = slope;
        s[[2, 2]] = quad;
        s
    }

    let rows = 40usize;
    let mut x = Array2::<f64>::ones((rows, 3));
    let mut kappas = Vec::with_capacity(rows);
    for i in 0..rows {
        let z = i as f64 / rows as f64 - 0.5;
        let quad = z * z - 0.1;
        x[[i, 1]] = z;
        x[[i, 2]] = quad;
        kappas.push(
            RowExpectedJets::binomial_logit(0.2 + 0.5 * z - 0.3 * quad)
                .kappas()
                .expect("binomial kappas"),
        );
    }

    // Two ρ-components, one per penalised coefficient; their sum is the penalty.
    let (l1, l2) = (2.0_f64, 4.0_f64);
    let penalty = diag_penalty(l1, l2);
    let components = vec![
        RhoPenaltyComponent {
            s_component: diag_penalty(l1, 0.0),
        },
        RhoPenaltyComponent {
            s_component: diag_penalty(0.0, l2),
        },
    ];
    let tested = 1..3;

    let conditional =
        lawley_lr_mean_shift(x.view(), &kappas, Some(penalty.view()), tested.clone())
            .expect("conditional");

    // Conditional mean shift after moving each log-λ by (t0, t1).
    let shift_at = |t0: f64, t1: f64| {
        let perturbed = diag_penalty(l1 * t0.exp(), l2 * t1.exp());
        lawley_lr_mean_shift(x.view(), &kappas, Some(perturbed.view()), tested.clone())
            .expect("perturbed")
    };
    let h = 0.05_f64;
    let h00 = (shift_at(h, 0.0) - 2.0 * conditional + shift_at(-h, 0.0)) / (h * h);
    let h11 = (shift_at(0.0, h) - 2.0 * conditional + shift_at(0.0, -h)) / (h * h);
    let h01 = (shift_at(h, h) - shift_at(h, -h) - shift_at(-h, h) + shift_at(-h, -h))
        / (4.0 * h * h);

    let rho_cov = Array2::from_shape_vec((2, 2), vec![0.7, 0.2, 0.2, 0.5]).unwrap();
    let total = lawley_lr_mean_shift_with_rho_variation(
        x.view(),
        &kappas,
        penalty.view(),
        tested.clone(),
        &components,
        rho_cov.view(),
    )
    .expect("rho-variation shift");

    // ½·(H00·C00 + H11·C11) plus the doubled-then-halved symmetric cross term.
    let diag_only = conditional + 0.5 * (h00 * rho_cov[[0, 0]] + h11 * rho_cov[[1, 1]]);
    let expected = diag_only + h01 * rho_cov[[0, 1]];
    assert!(
        (total - expected).abs() < 1e-6 * (1.0 + expected.abs()),
        "two-parameter ρ̂-variation {total} must equal {expected} \
         (H00={h00}, H11={h11}, H01={h01})"
    );
    assert!(
        (total - diag_only).abs() > 1e-9,
        "cross term H01·Cov01 must be included (off-diagonal non-zero): \
         total={total} diag_only={diag_only}"
    );
}
1598
/// Input validation for `lawley_lr_mean_shift_with_rho_variation`.
///
/// Every inconsistent `(components, rho_cov)` pairing below must return
/// `Err`. The well-formed single-component case serves as a positive control,
/// so the rejections cannot come from rejecting everything.
#[test]
fn rho_variation_rejects_shape_mismatch() {
    let n = 8usize;
    let jets = RowExpectedJets::poisson_log(0.1);
    let kappas = vec![jets.kappas().expect("kappas"); n];
    let mut x = Array2::<f64>::ones((n, 2));
    for i in 0..n {
        x[[i, 1]] = i as f64 - 4.0;
    }
    let mut s = Array2::<f64>::zeros((2, 2));
    s[[1, 1]] = 1.0;
    let components = vec![RhoPenaltyComponent {
        s_component: s.clone(),
    }];

    // ρ-covariance dimension (2) disagrees with the number of components (1).
    let bad_cov = Array2::<f64>::eye(2);
    assert!(
        lawley_lr_mean_shift_with_rho_variation(
            x.view(),
            &kappas,
            s.view(),
            1..2,
            &components,
            bad_cov.view(),
        )
        .is_err(),
        "2×2 ρ-covariance with one penalty component must be rejected"
    );

    // No components at all, but a 1×1 ρ-covariance.
    let cov1 = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
    assert!(
        lawley_lr_mean_shift_with_rho_variation(
            x.view(),
            &kappas,
            s.view(),
            1..2,
            &[],
            cov1.view(),
        )
        .is_err(),
        "empty component list with a 1×1 ρ-covariance must be rejected"
    );

    // Component matrix is 3×3 while the penalty and design are 2-dimensional.
    let wrong = vec![RhoPenaltyComponent {
        s_component: Array2::<f64>::eye(3),
    }];
    assert!(
        lawley_lr_mean_shift_with_rho_variation(
            x.view(),
            &kappas,
            s.view(),
            1..2,
            &wrong,
            cov1.view(),
        )
        .is_err(),
        "3×3 penalty component for a 2-coefficient design must be rejected"
    );

    // Positive control: a 1×1 ρ-covariance is trivially symmetric and matches
    // the single component, so it must be accepted. `expect` surfaces the
    // returned error if validation is over-eager.
    lawley_lr_mean_shift_with_rho_variation(
        x.view(),
        &kappas,
        s.view(),
        1..2,
        &components,
        cov1.view(),
    )
    .expect("well-formed 1×1 ρ-covariance for one component must be accepted");

    // Two components with a 2×2 ρ-covariance that is not symmetric
    // (0.25 vs 0.20 off the diagonal).
    let components2 = vec![
        RhoPenaltyComponent {
            s_component: s.clone(),
        },
        RhoPenaltyComponent {
            s_component: s.clone(),
        },
    ];
    let bad_sym = Array2::from_shape_vec((2, 2), vec![1.0, 0.25, 0.20, 1.0]).unwrap();
    assert!(
        lawley_lr_mean_shift_with_rho_variation(
            x.view(),
            &kappas,
            s.view(),
            1..2,
            &components2,
            bad_sym.view(),
        )
        .is_err(),
        "asymmetric 2×2 ρ-covariance must be rejected"
    );
}
1691
/// Lawley's ε must not change when the design is reparametrised by an
/// invertible linear map `X ↦ X·T`; here `T` is upper triangular with a
/// non-zero diagonal. The same row kappas are used for both designs.
#[test]
fn epsilon_is_invariant_under_linear_reparametrization() {
    let rows = 15usize;
    // Column 0 is the intercept; columns 1 and 2 are filled per row below.
    let mut x = Array2::<f64>::ones((rows, 3));
    let mut kappas = Vec::with_capacity(rows);
    for i in 0..rows {
        let z = i as f64 / rows as f64;
        let wave = (3.1 * z).cos();
        let centred = z - 0.5;
        x[[i, 1]] = wave;
        x[[i, 2]] = centred;
        kappas.push(
            RowExpectedJets::binomial_logit(-0.1 + 0.6 * wave + 0.4 * centred)
                .kappas()
                .expect("binomial kappas"),
        );
    }

    let reparam = ndarray::arr2(&[[1.0, 0.3, -0.2], [0.0, 1.4, 0.5], [0.0, 0.0, 0.8]]);
    let x_reparam = x.dot(&reparam);

    let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");
    let eps_t = lawley_epsilon(x_reparam.view(), &kappas, None).expect("ε reparam");
    assert!(
        (eps - eps_t).abs() < 1e-9 * (1.0 + eps.abs()),
        "ε not reparametrization-invariant: {eps} vs {eps_t}"
    );
}
1721}