1use crate::error::FdarError;
14use crate::iter_maybe_parallel;
15use crate::matrix::FdMatrix;
16use crate::regression::fdata_to_pc_1d;
17#[cfg(feature = "parallel")]
18use rayon::iter::ParallelIterator;
19
/// Output of [`fmm`]: a functional mixed model fit expressed both in
/// principal-component space and back in function space.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmResult {
    /// Estimated mean function, one value per grid point.
    pub mean_function: Vec<f64>,
    /// Fixed-effect coefficient functions, one row per covariate.
    pub beta_functions: FdMatrix,
    /// Subject-level random-effect functions, one row per subject.
    pub random_effects: FdMatrix,
    /// Fitted curves, one row per observation.
    pub fitted: FdMatrix,
    /// Residual curves (data minus fitted), one row per observation.
    pub residuals: FdMatrix,
    /// Pointwise variance of the random-effect functions across subjects.
    pub random_variance: Vec<f64>,
    /// Residual variance, averaged across components.
    pub sigma2_eps: f64,
    /// Random-intercept variance per principal component.
    pub sigma2_u: Vec<f64>,
    /// Number of principal components actually used (the FPCA may return
    /// fewer than requested).
    pub ncomp: usize,
    /// Number of distinct subjects.
    pub n_subjects: usize,
    /// Eigenvalues of the empirical covariance (squared singular values
    /// divided by the number of observations).
    pub eigenvalues: Vec<f64>,
}
46
/// Output of [`fmm_test_fixed`]: one entry per covariate column.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// Observed test statistic per covariate (integrated squared beta).
    pub f_statistics: Vec<f64>,
    /// Permutation p-value per covariate, computed as
    /// (count + 1) / (n_perm + 1).
    pub p_values: Vec<f64>,
}
55
56#[must_use = "expensive computation whose result should not be discarded"]
82pub fn fmm(
83 data: &FdMatrix,
84 subject_ids: &[usize],
85 covariates: Option<&FdMatrix>,
86 ncomp: usize,
87) -> Result<FmmResult, FdarError> {
88 let n_total = data.nrows();
89 let m = data.ncols();
90 if n_total == 0 || m == 0 {
91 return Err(FdarError::InvalidDimension {
92 parameter: "data",
93 expected: "non-empty matrix".to_string(),
94 actual: format!("{n_total} x {m}"),
95 });
96 }
97 if subject_ids.len() != n_total {
98 return Err(FdarError::InvalidDimension {
99 parameter: "subject_ids",
100 expected: format!("length {n_total}"),
101 actual: format!("length {}", subject_ids.len()),
102 });
103 }
104 if ncomp == 0 {
105 return Err(FdarError::InvalidParameter {
106 parameter: "ncomp",
107 message: "must be >= 1".to_string(),
108 });
109 }
110
111 let (subject_map, n_subjects) = build_subject_map(subject_ids);
113
114 let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1).max(1) as f64).collect();
116 let fpca = fdata_to_pc_1d(data, ncomp, &argvals)?;
117 let k = fpca.scores.ncols(); let p = covariates.map_or(0, super::matrix::FdMatrix::ncols);
121 let ComponentResults {
122 gamma,
123 u_hat,
124 sigma2_u,
125 sigma2_eps,
126 } = fit_all_components(
127 &fpca.scores,
128 &subject_map,
129 n_subjects,
130 covariates,
131 p,
132 k,
133 n_total,
134 m,
135 );
136
137 let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
139 let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
140
141 let random_variance = compute_random_variance(&random_effects, n_subjects, m);
143
144 let (fitted, residuals) = compute_fitted_residuals(
146 data,
147 &fpca.mean,
148 &beta_functions,
149 &random_effects,
150 covariates,
151 &subject_map,
152 n_total,
153 m,
154 p,
155 );
156
157 let eigenvalues: Vec<f64> = fpca
158 .singular_values
159 .iter()
160 .map(|&sv| sv * sv / n_total as f64)
161 .collect();
162
163 Ok(FmmResult {
164 mean_function: fpca.mean,
165 beta_functions,
166 random_effects,
167 fitted,
168 residuals,
169 random_variance,
170 sigma2_eps,
171 sigma2_u,
172 ncomp: k,
173 n_subjects,
174 eigenvalues,
175 })
176}
177
/// Maps arbitrary subject identifiers to dense indices `0..n_subjects`.
///
/// Returns, for each observation, the index of its subject identifier in the
/// sorted list of unique identifiers, together with the number of distinct
/// subjects. The empty input yields an empty map and zero subjects.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    let n_subjects = unique_ids.len();

    // Binary search on the sorted, deduplicated list: O(n log k) instead of
    // the O(n * k) linear `position` scan per observation. Every id is
    // present by construction, so the fallback 0 is unreachable.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| unique_ids.binary_search(id).unwrap_or(0))
        .collect();

    (map, n_subjects)
}
192
/// Per-component fits gathered across all principal components.
struct ComponentResults {
    /// Fixed-effect coefficients: `gamma[j][comp]` for covariate `j`.
    gamma: Vec<Vec<f64>>,
    /// Random-intercept BLUPs: `u_hat[s][comp]` for subject `s`.
    u_hat: Vec<Vec<f64>>,
    /// Random-intercept variance per component.
    sigma2_u: Vec<f64>,
    /// Residual variance averaged across components.
    sigma2_eps: f64,
}
200
/// Fits an independent scalar mixed model to each FPCA component score and
/// collects the results.
///
/// Scores are rescaled by sqrt(h), where h is the grid spacing, before
/// fitting, and the estimated coefficients are scaled back afterwards; the
/// variance components are kept on the rescaled scale. Iteration over
/// components runs in parallel when the `parallel` feature is enabled (via
/// `iter_maybe_parallel!`).
#[allow(clippy::too_many_arguments)]
fn fit_all_components(
    scores: &FdMatrix,
    subject_map: &[usize],
    n_subjects: usize,
    covariates: Option<&FdMatrix>,
    p: usize,
    k: usize,
    n_total: usize,
    m: usize,
) -> ComponentResults {
    // Grid spacing of the uniform evaluation grid; 1.0 for a single point.
    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
    let score_scale = h.sqrt();

    let per_comp: Vec<ScalarMixedResult> = iter_maybe_parallel!(0..k)
        .map(|comp| {
            let comp_scores: Vec<f64> = (0..n_total)
                .map(|i| scores[(i, comp)] * score_scale)
                .collect();
            fit_scalar_mixed_model(&comp_scores, subject_map, n_subjects, covariates, p)
        })
        .collect();

    // Reassemble per-component fits into coefficient-by-component layout.
    let mut gamma = vec![vec![0.0; k]; p];
    let mut u_hat = vec![vec![0.0; k]; n_subjects];
    let mut sigma2_u = vec![0.0; k];
    let mut sigma2_eps_total = 0.0;

    for (comp, result) in per_comp.iter().enumerate() {
        for j in 0..p {
            // Undo the sqrt(h) scaling applied to the scores above.
            gamma[j][comp] = result.gamma[j] / score_scale;
        }
        for s in 0..n_subjects {
            u_hat[s][comp] = result.u_hat[s] / score_scale;
        }
        sigma2_u[comp] = result.sigma2_u;
        sigma2_eps_total += result.sigma2_eps;
    }
    // Pool the residual variance by averaging across components.
    let sigma2_eps = sigma2_eps_total / k as f64;

    ComponentResults {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps,
    }
}
256
/// Result of one scalar random-intercept mixed-model fit (one component).
struct ScalarMixedResult {
    /// Fixed-effect coefficients (length p; empty without covariates).
    gamma: Vec<f64>,
    /// Subject random-intercept BLUPs (length n_subjects).
    u_hat: Vec<f64>,
    /// Random-intercept variance.
    sigma2_u: f64,
    /// Residual variance.
    sigma2_eps: f64,
}
264
/// Observation indices grouped by subject.
struct SubjectStructure {
    /// Number of observations per subject.
    counts: Vec<usize>,
    /// Observation indices belonging to each subject.
    obs: Vec<Vec<usize>>,
}
270
271impl SubjectStructure {
272 fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
273 let mut counts = vec![0usize; n_subjects];
274 let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
275 for i in 0..n {
276 let s = subject_map[i];
277 counts[s] += 1;
278 obs[s].push(i);
279 }
280 Self { counts, obs }
281 }
282}
283
284fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
286 ss.counts
287 .iter()
288 .map(|&c| {
289 let ns = c as f64;
290 if ns < 1.0 {
291 0.0
292 } else {
293 sigma2_u / (sigma2_u + sigma2_e / ns)
294 }
295 })
296 .collect()
297}
298
/// One generalized-least-squares update of the fixed effects given the
/// current variance estimates.
///
/// Accumulates X' V^{-1} X and X' V^{-1} y subject by subject, where the
/// action of V^{-1} for a random-intercept model is applied in closed form
/// (see `accumulate_gls_terms`). Returns `None` when the ridge-regularized
/// normal equations cannot be Cholesky-factorized.
fn gls_update_gamma(
    cov: &FdMatrix,
    p: usize,
    ss: &SubjectStructure,
    weights: &[f64],
    y: &[f64],
    sigma2_e: f64,
) -> Option<Vec<f64>> {
    let n_subjects = ss.counts.len();
    let mut xtvinvx = vec![0.0; p * p];
    let mut xtvinvy = vec![0.0; p];
    let inv_e = 1.0 / sigma2_e;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            // A subject with no observations contributes nothing.
            continue;
        }
        let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
        accumulate_gls_terms(
            cov,
            y,
            &ss.obs[s],
            &x_sum,
            y_sum,
            weights[s],
            ns,
            inv_e,
            p,
            &mut xtvinvx,
            &mut xtvinvy,
        );
    }

    // Tiny ridge keeps the system positive definite for the factorization.
    for j in 0..p {
        xtvinvx[j * p + j] += 1e-10;
    }
    cholesky_solve(&xtvinvx, &xtvinvy, p)
}
341
342fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
344 let mut x_sum = vec![0.0; p];
345 let mut y_sum = 0.0;
346 for &i in obs {
347 for r in 0..p {
348 x_sum[r] += cov[(i, r)];
349 }
350 y_sum += y[i];
351 }
352 (x_sum, y_sum)
353}
354
/// Adds one subject's contribution to X' V^{-1} X and X' V^{-1} y.
///
/// For this subject, V^{-1} applied to a vector v at observation i is
/// `inv_e * (v_i - w_s * sum(v) / ns)` — the closed-form inverse of a
/// compound-symmetry (random-intercept) covariance block. Only the upper
/// triangle of X' V^{-1} X is computed and mirrored into the lower half.
fn accumulate_gls_terms(
    cov: &FdMatrix,
    y: &[f64],
    obs: &[usize],
    x_sum: &[f64],
    y_sum: f64,
    w_s: f64,
    ns: f64,
    inv_e: f64,
    p: usize,
    xtvinvx: &mut [f64],
    xtvinvy: &mut [f64],
) {
    for &i in obs {
        // V^{-1} y at observation i.
        let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
        for r in 0..p {
            xtvinvy[r] += cov[(i, r)] * vinv_y;
            for c in r..p {
                // V^{-1} applied to covariate column c at observation i.
                let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
                let val = cov[(i, r)] * vinv_xc;
                xtvinvx[r * p + c] += val;
                if r != c {
                    // Mirror into the symmetric lower-triangle entry.
                    xtvinvx[c * p + r] += val;
                }
            }
        }
    }
}
384
/// EM-style update of the variance components given current residuals.
///
/// For each subject, the random-intercept BLUP `u_hat_s = w_s * mean(r_s)`
/// and its conditional variance `sigma2_u * (1 - w_s)` feed the expected
/// sufficient statistics for both components. The residual degrees of
/// freedom are reduced by the `p` fixed effects (REML-style correction).
/// Returns the updated `(sigma2_u, sigma2_e)`, floored at 1e-15 so both
/// stay strictly positive.
fn reml_variance_update(
    residuals: &[f64],
    ss: &SubjectStructure,
    weights: &[f64],
    sigma2_u: f64,
    p: usize,
) -> (f64, f64) {
    let n_subjects = ss.counts.len();
    let n: usize = ss.counts.iter().sum();
    let mut sigma2_u_new = 0.0;
    let mut sigma2_e_new = 0.0;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            // Subjects without observations are skipped (but still counted
            // in the sigma2_u denominator below).
            continue;
        }
        let w_s = weights[s];
        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
        let u_hat_s = w_s * mean_r_s;
        let cond_var_s = sigma2_u * (1.0 - w_s);

        // E[u_s^2 | data] = u_hat_s^2 + conditional variance.
        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
        for &i in &ss.obs[s] {
            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
        }
        // Each of the ns observations inherits the BLUP's uncertainty.
        sigma2_e_new += ns * cond_var_s;
    }

    // Degrees of freedom adjusted for the p fixed effects, floored at 1.
    let denom_e = (n.saturating_sub(p)).max(1) as f64;

    (
        (sigma2_u_new / n_subjects as f64).max(1e-15),
        (sigma2_e_new / denom_e).max(1e-15),
    )
}
426
/// Fits a scalar random-intercept mixed model
/// `y = X gamma + u_subject + eps` for a single component score.
///
/// Starts from OLS fixed effects and ANOVA moment estimates of the variance
/// components, then alternates GLS updates of `gamma` with EM-style variance
/// updates until convergence (or 50 iterations), and finally computes the
/// subject BLUPs from the converged fit.
fn fit_scalar_mixed_model(
    y: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    covariates: Option<&FdMatrix>,
    p: usize,
) -> ScalarMixedResult {
    let n = y.len();
    let ss = SubjectStructure::new(subject_map, n_subjects, n);

    // Initialization: OLS fixed effects, then moment-based variances.
    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
    let (mut sigma2_u, mut sigma2_e) =
        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);

    // Keep both variances strictly positive so the shrinkage weights and
    // the GLS system stay well-defined.
    if sigma2_e < 1e-15 {
        sigma2_e = 1e-6;
    }
    if sigma2_u < 1e-15 {
        sigma2_u = sigma2_e * 0.1;
    }

    let mut gamma = gamma_init;

    for _iter in 0..50 {
        let sigma2_u_old = sigma2_u;
        let sigma2_e_old = sigma2_e;

        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);

        if let Some(cov) = covariates.filter(|_| p > 0) {
            // Keep the previous gamma if the GLS system is singular.
            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
                gamma = g;
            }
        }

        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);

        // Relative convergence check on the total variance movement.
        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
            break;
        }
    }

    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
    let u_hat = compute_blup(
        &final_residuals,
        subject_map,
        n_subjects,
        sigma2_u,
        sigma2_e,
    );

    ScalarMixedResult {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps: sigma2_e,
    }
}
494
495fn estimate_fixed_effects(
497 y: &[f64],
498 covariates: Option<&FdMatrix>,
499 p: usize,
500 n: usize,
501) -> Vec<f64> {
502 if p == 0 || covariates.is_none() {
503 return Vec::new();
504 }
505 let cov = covariates.expect("checked: covariates is Some");
506
507 let mut xtx = vec![0.0; p * p];
509 let mut xty = vec![0.0; p];
510 for i in 0..n {
511 for r in 0..p {
512 xty[r] += cov[(i, r)] * y[i];
513 for s in r..p {
514 let val = cov[(i, r)] * cov[(i, s)];
515 xtx[r * p + s] += val;
516 if r != s {
517 xtx[s * p + r] += val;
518 }
519 }
520 }
521 }
522 for j in 0..p {
524 xtx[j * p + j] += 1e-8;
525 }
526
527 cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
528}
529
/// Cholesky factorization A = L L' of a row-major `p x p` matrix.
///
/// Returns the lower-triangular factor L in row-major order, or `None` when
/// A is not (numerically) positive definite.
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for col in 0..p {
        // Diagonal entry: a_jj minus the squared norm of the row so far.
        let mut sq_sum = 0.0;
        for k in 0..col {
            sq_sum += l[col * p + k] * l[col * p + k];
        }
        let diag = a[col * p + col] - sq_sum;
        if diag <= 0.0 {
            return None;
        }
        let pivot = diag.sqrt();
        l[col * p + col] = pivot;
        // Fill the entries below the diagonal in this column.
        for row in (col + 1)..p {
            let mut dot = 0.0;
            for k in 0..col {
                dot += l[row * p + k] * l[col * p + k];
            }
            l[row * p + col] = (a[row * p + col] - dot) / pivot;
        }
    }
    Some(l)
}
553
/// Solves L L' x = b given the row-major lower-triangular Cholesky factor L,
/// by forward substitution followed by back substitution.
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = vec![0.0; p];
    // Forward pass: L z = b.
    for row in 0..p {
        let mut acc = 0.0;
        for col in 0..row {
            acc += l[row * p + col] * x[col];
        }
        x[row] = (b[row] - acc) / l[row * p + row];
    }
    // Backward pass: L' x = z (transposed indexing l[col * p + row]).
    for row in (0..p).rev() {
        let mut acc = 0.0;
        for col in (row + 1)..p {
            acc += l[col * p + row] * x[col];
        }
        x[row] = (x[row] - acc) / l[row * p + row];
    }
    x
}
573
574fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
576 let l = cholesky_factor_famm(a, p)?;
577 Some(cholesky_triangular_solve(&l, b, p))
578}
579
580fn compute_ols_residuals(
582 y: &[f64],
583 covariates: Option<&FdMatrix>,
584 gamma: &[f64],
585 p: usize,
586 n: usize,
587) -> Vec<f64> {
588 let mut residuals = y.to_vec();
589 if p > 0 {
590 if let Some(cov) = covariates {
591 for i in 0..n {
592 for j in 0..p {
593 residuals[i] -= cov[(i, j)] * gamma[j];
594 }
595 }
596 }
597 }
598 residuals
599}
600
/// Method-of-moments (one-way ANOVA) starting values for the variance
/// components of a random-intercept model.
///
/// Returns `(sigma2_u, sigma2_eps)`; `sigma2_u` is floored at zero and
/// `sigma2_eps` falls back to 1e-6 when there are no within-subject
/// degrees of freedom.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Per-subject totals and counts in one pass.
    let mut totals = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for i in 0..n {
        totals[subject_map[i]] += residuals[i];
        counts[subject_map[i]] += 1;
    }
    let means: Vec<f64> = totals
        .iter()
        .zip(&counts)
        .map(|(&total, &c)| if c > 0 { total / c as f64 } else { 0.0 })
        .collect();

    // Within-subject sum of squares.
    let mut ss_within = 0.0;
    for i in 0..n {
        ss_within += (residuals[i] - means[subject_map[i]]).powi(2);
    }
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject sum of squares around the grand mean.
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let mut ss_between = 0.0;
    for (s, &mean_s) in means.iter().enumerate() {
        ss_between += counts[s] as f64 * (mean_s - grand_mean).powi(2);
    }

    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        1e-6
    };

    // ANOVA estimator: (MS_between - sigma2_eps) / average group size.
    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
653
/// Best linear unbiased predictors of the subject random intercepts: each
/// subject's mean residual shrunk by sigma2_u / (sigma2_u + sigma2_eps/n_i).
/// Subjects without observations get a zero effect.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    let mut totals = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for (i, &r) in residuals.iter().enumerate() {
        totals[subject_map[i]] += r;
        counts[subject_map[i]] += 1;
    }

    totals
        .iter()
        .zip(&counts)
        .map(|(&total, &count)| {
            if count == 0 {
                0.0
            } else {
                let ni = count as f64;
                // The .max(1e-15) guards the denominator against zero.
                let shrink = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
                shrink * (total / ni)
            }
        })
        .collect()
}
684
685fn recover_beta_functions(
691 gamma: &[Vec<f64>],
692 rotation: &FdMatrix,
693 p: usize,
694 m: usize,
695 k: usize,
696) -> FdMatrix {
697 let mut beta = FdMatrix::zeros(p, m);
698 for j in 0..p {
699 for t in 0..m {
700 let mut val = 0.0;
701 for comp in 0..k {
702 val += gamma[j][comp] * rotation[(t, comp)];
703 }
704 beta[(j, t)] = val;
705 }
706 }
707 beta
708}
709
710fn recover_random_effects(
712 u_hat: &[Vec<f64>],
713 rotation: &FdMatrix,
714 n_subjects: usize,
715 m: usize,
716 k: usize,
717) -> FdMatrix {
718 let mut re = FdMatrix::zeros(n_subjects, m);
719 for s in 0..n_subjects {
720 for t in 0..m {
721 let mut val = 0.0;
722 for comp in 0..k {
723 val += u_hat[s][comp] * rotation[(t, comp)];
724 }
725 re[(s, t)] = val;
726 }
727 }
728 re
729}
730
731fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
733 (0..m)
734 .map(|t| {
735 let mean: f64 =
736 (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
737 let var: f64 = (0..n_subjects)
738 .map(|s| (random_effects[(s, t)] - mean).powi(2))
739 .sum::<f64>()
740 / n_subjects.max(1) as f64;
741 var
742 })
743 .collect()
744}
745
746fn compute_fitted_residuals(
748 data: &FdMatrix,
749 mean_function: &[f64],
750 beta_functions: &FdMatrix,
751 random_effects: &FdMatrix,
752 covariates: Option<&FdMatrix>,
753 subject_map: &[usize],
754 n_total: usize,
755 m: usize,
756 p: usize,
757) -> (FdMatrix, FdMatrix) {
758 let mut fitted = FdMatrix::zeros(n_total, m);
759 let mut residuals = FdMatrix::zeros(n_total, m);
760
761 for i in 0..n_total {
762 let s = subject_map[i];
763 for t in 0..m {
764 let mut val = mean_function[t] + random_effects[(s, t)];
765 if p > 0 {
766 if let Some(cov) = covariates {
767 for j in 0..p {
768 val += cov[(i, j)] * beta_functions[(j, t)];
769 }
770 }
771 }
772 fitted[(i, t)] = val;
773 residuals[(i, t)] = data[(i, t)] - val;
774 }
775 }
776
777 (fitted, residuals)
778}
779
780#[must_use = "prediction result should not be discarded"]
792pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
793 let m = result.mean_function.len();
794 let n_new = new_covariates.map_or(1, super::matrix::FdMatrix::nrows);
795 let p = result.beta_functions.nrows();
796
797 let mut predicted = FdMatrix::zeros(n_new, m);
798 for i in 0..n_new {
799 for t in 0..m {
800 let mut val = result.mean_function[t];
801 if let Some(cov) = new_covariates {
802 for j in 0..p {
803 val += cov[(i, j)] * result.beta_functions[(j, t)];
804 }
805 }
806 predicted[(i, t)] = val;
807 }
808 }
809 predicted
810}
811
/// Permutation test for the functional fixed effects.
///
/// Fits the model once, measures each covariate's effect as the integrated
/// squared coefficient function, then compares against `n_perm` refits with
/// the covariate rows randomly shuffled (seeded for reproducibility).
///
/// # Errors
///
/// Returns `FdarError::InvalidDimension` when `data` has no rows or
/// `covariates` has no columns; errors from the initial `fmm` fit are
/// propagated.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fmm_test_fixed(
    data: &FdMatrix,
    subject_ids: &[usize],
    covariates: &FdMatrix,
    ncomp: usize,
    n_perm: usize,
    seed: u64,
) -> Result<FmmTestResult, FdarError> {
    let n_total = data.nrows();
    let m = data.ncols();
    let p = covariates.ncols();
    if n_total == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "non-empty matrix".to_string(),
            actual: format!("{n_total} rows"),
        });
    }
    if p == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "covariates",
            expected: "at least 1 column".to_string(),
            actual: "0 columns".to_string(),
        });
    }

    // Fit once on the observed covariate assignment.
    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;

    // Effect size per covariate: integral of beta_j(t)^2 over the grid.
    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);

    let (f_statistics, p_values) = permutation_test(
        data,
        subject_ids,
        covariates,
        ncomp,
        n_perm,
        seed,
        &observed_stats,
        p,
        m,
    );

    Ok(FmmTestResult {
        f_statistics,
        p_values,
    })
}
885
886fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
888 let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
889 (0..p)
890 .map(|j| {
891 let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
892 ss * h
893 })
894 .collect()
895}
896
/// Runs the covariate-row permutation loop for `fmm_test_fixed`.
///
/// Each iteration shuffles the covariate rows (breaking any association
/// between covariates and curves while keeping the subject grouping) and
/// refits; `n_ge[j]` counts permutations whose integrated effect for
/// covariate j reaches the observed one. P-values use the standard
/// (count + 1) / (n_perm + 1) correction so they are never exactly zero.
fn permutation_test(
    data: &FdMatrix,
    subject_ids: &[usize],
    covariates: &FdMatrix,
    ncomp: usize,
    n_perm: usize,
    seed: u64,
    observed_stats: &[f64],
    p: usize,
    m: usize,
) -> (Vec<f64>, Vec<f64>) {
    use rand::prelude::*;
    let n_total = data.nrows();
    // Seeded RNG makes the permutation sequence reproducible.
    let mut rng = StdRng::seed_from_u64(seed);
    let mut n_ge = vec![0usize; p];

    for _ in 0..n_perm {
        let mut perm_indices: Vec<usize> = (0..n_total).collect();
        perm_indices.shuffle(&mut rng);
        let perm_cov = permute_rows(covariates, &perm_indices);

        // NOTE(review): a failed permutation fit is silently skipped while
        // the denominator below still uses n_perm, which biases p-values
        // downward in that case — confirm this is intended.
        if let Ok(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
            for j in 0..p {
                if perm_stats[j] >= observed_stats[j] {
                    n_ge[j] += 1;
                }
            }
        }
    }

    let p_values: Vec<f64> = n_ge
        .iter()
        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
        .collect();
    // The reported "F statistics" are the observed integrated effects.
    let f_statistics = observed_stats.to_vec();

    (f_statistics, p_values)
}
938
939fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
941 let n = indices.len();
942 let m = mat.ncols();
943 let mut result = FdMatrix::zeros(n, m);
944 for (new_i, &old_i) in indices.iter().enumerate() {
945 for j in 0..m {
946 result[(new_i, j)] = mat[(old_i, j)];
947 }
948 }
949 result
950}
951
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Builds synthetic longitudinal data: sinusoidal mean, a covariate
    /// effect linear in t, a cosine-shaped subject effect, and deterministic
    /// pseudo-noise so every test run is reproducible.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            let z = s as f64 / n_subjects as f64;
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0);
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    // Deterministic "noise" derived from the indices.
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }

    /// All output dimensions follow from 10 subjects x 3 visits x 50 points.
    #[test]
    fn test_fmm_basic() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_eq!(result.mean_function.len(), 50);
        assert_eq!(result.beta_functions.nrows(), 1);
        assert_eq!(result.beta_functions.ncols(), 50);
        assert_eq!(result.random_effects.nrows(), 10);
        assert_eq!(result.fitted.nrows(), 30);
        assert_eq!(result.residuals.nrows(), 30);
        assert_eq!(result.n_subjects, 10);
    }

    /// The fitted/residual decomposition must reconstruct the data exactly.
    #[test]
    fn test_fmm_fitted_plus_residuals_equals_data() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        for i in 0..n {
            for t in 0..m {
                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
                assert!(
                    (reconstructed - data[(i, t)]).abs() < 1e-8,
                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
                    i,
                    t,
                    reconstructed,
                    data[(i, t)]
                );
            }
        }
    }

    #[test]
    fn test_fmm_random_variance_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        for &v in &result.random_variance {
            assert!(v >= 0.0, "Random variance should be non-negative");
        }
    }

    /// Without covariates the beta matrix has zero rows but the rest of the
    /// fit still has full dimensions.
    #[test]
    fn test_fmm_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, None, 3).unwrap();

        assert_eq!(result.beta_functions.nrows(), 0);
        assert_eq!(result.n_subjects, 8);
        assert_eq!(result.fitted.nrows(), 24);
    }

    #[test]
    fn test_fmm_predict() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 50);

        // Predictions should be finite and of the same magnitude as the data.
        for t in 0..50 {
            assert!(predicted[(0, t)].is_finite());
            assert!(
                predicted[(0, t)].abs() < 20.0,
                "Predicted value too extreme at t={}: {}",
                t,
                predicted[(0, t)]
            );
        }
    }

    /// The generated data has a real covariate effect; the permutation test
    /// should flag it.
    #[test]
    fn test_fmm_test_fixed_detects_effect() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();

        assert_eq!(result.f_statistics.len(), 1);
        assert_eq!(result.p_values.len(), 1);
        assert!(
            result.p_values[0] < 0.1,
            "Should detect covariate effect, got p={}",
            result.p_values[0]
        );
    }

    /// Data generated with no covariate effect: the test should not reject.
    #[test]
    fn test_fmm_test_fixed_no_effect() {
        let n_subjects = 10;
        let n_visits = 3;
        let m = 40;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;

        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = s as f64 / n_subjects as f64;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
        assert!(
            result.p_values[0] > 0.05,
            "Should not detect effect, got p={}",
            result.p_values[0]
        );
    }

    /// Empty data and mismatched subject-id length must both be rejected.
    #[test]
    fn test_fmm_invalid_input() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fmm(&data, &[], None, 1).is_err());

        let data = FdMatrix::zeros(10, 50);
        let ids = vec![0; 5];
        assert!(fmm(&data, &ids, None, 1).is_err());
    }

    /// One observation per subject is a degenerate but valid design.
    #[test]
    fn test_fmm_single_visit_per_subject() {
        let n = 10;
        let m = 40;
        let t = uniform_grid(m);
        let mut col_major = vec![0.0; n * m];
        let subject_ids: Vec<usize> = (0..n).collect();

        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        let result = fmm(&data, &subject_ids, None, 2).unwrap();
        assert_eq!(result.n_subjects, n);
        assert_eq!(result.fitted.nrows(), n);
    }

    #[test]
    fn test_build_subject_map() {
        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 0, 1, 1, 2]);
    }

    #[test]
    fn test_variance_components_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert!(result.sigma2_eps >= 0.0);
        for &s in &result.sigma2_u {
            assert!(s >= 0.0);
        }
    }

    /// ncomp == 0 must produce InvalidParameter specifically for "ncomp".
    #[test]
    fn test_fmm_ncomp_zero_returns_error() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 2, 20);
        let err = fmm(&data, &subject_ids, None, 0).unwrap_err();
        match err {
            FdarError::InvalidParameter { parameter, .. } => {
                assert_eq!(parameter, "ncomp");
            }
            other => panic!("Expected InvalidParameter, got {:?}", other),
        }
    }

    /// A single component still yields a consistent decomposition.
    #[test]
    fn test_fmm_single_component() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 30);
        let result = fmm(&data, &subject_ids, Some(&covariates), 1).unwrap();

        assert_eq!(result.ncomp, 1);
        assert_eq!(result.sigma2_u.len(), 1);
        assert_eq!(result.eigenvalues.len(), 1);
        assert_eq!(result.mean_function.len(), 30);
        for i in 0..data.nrows() {
            for t in 0..data.ncols() {
                let diff = (result.fitted[(i, t)] + result.residuals[(i, t)] - data[(i, t)]).abs();
                assert!(diff < 1e-8);
            }
        }
    }

    /// Minimal multi-subject design: two subjects, five visits each.
    #[test]
    fn test_fmm_two_subjects() {
        let n_subjects = 2;
        let n_visits = 5;
        let m = 20;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + (s as f64) * 0.5 + 0.01 * v as f64;
                }
            }
        }
        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let result = fmm(&data, &subject_ids, None, 2).unwrap();

        assert_eq!(result.n_subjects, 2);
        assert_eq!(result.random_effects.nrows(), 2);
        assert_eq!(result.fitted.nrows(), n_total);
    }

    /// Without covariates, prediction is exactly the mean function.
    #[test]
    fn test_fmm_predict_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(6, 3, 30);
        let result = fmm(&data, &subject_ids, None, 2).unwrap();

        let predicted = fmm_predict(&result, None);
        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 30);
        for t in 0..30 {
            let diff = (predicted[(0, t)] - result.mean_function[t]).abs();
            assert!(
                diff < 1e-12,
                "Without covariates, prediction should equal mean"
            );
        }
    }

    /// One prediction row per new-covariate row; distinct covariates should
    /// give distinct curves.
    #[test]
    fn test_fmm_predict_multiple_new_subjects() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let new_cov = FdMatrix::from_column_major(vec![0.1, 0.5, 0.9], 3, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 3);
        assert_eq!(predicted.ncols(), 40);

        for i in 0..3 {
            for t in 0..40 {
                assert!(predicted[(i, t)].is_finite());
            }
        }

        let diff_01: f64 = (0..40)
            .map(|t| (predicted[(0, t)] - predicted[(1, t)]).powi(2))
            .sum();
        assert!(
            diff_01 > 1e-10,
            "Different covariates should yield different predictions"
        );
    }

    /// Eigenvalues come from sorted singular values, so they must be
    /// non-increasing.
    #[test]
    fn test_fmm_eigenvalues_decreasing() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, None, 5).unwrap();

        for i in 1..result.eigenvalues.len() {
            assert!(
                result.eigenvalues[i] <= result.eigenvalues[i - 1] + 1e-10,
                "Eigenvalues should be non-increasing: {} > {}",
                result.eigenvalues[i],
                result.eigenvalues[i - 1]
            );
        }
    }

    /// Shrinkage keeps the random effects roughly centered at each t.
    #[test]
    fn test_fmm_random_effects_sum_near_zero() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(20, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let m = result.mean_function.len();
        for t in 0..m {
            let sum: f64 = (0..result.n_subjects)
                .map(|s| result.random_effects[(s, t)])
                .sum();
            let mean_abs: f64 = (0..result.n_subjects)
                .map(|s| result.random_effects[(s, t)].abs())
                .sum::<f64>()
                / result.n_subjects as f64;
            // Only check where the effects are non-trivially sized.
            if mean_abs > 1e-10 {
                assert!(
                    (sum / result.n_subjects as f64).abs() < mean_abs * 2.0,
                    "Random effects should roughly center around zero at t={}: sum={}, mean_abs={}",
                    t,
                    sum,
                    mean_abs
                );
            }
        }
    }

    /// Wrong subject-id length reports the "subject_ids" parameter.
    #[test]
    fn test_fmm_subject_ids_mismatch_error() {
        let data = FdMatrix::zeros(10, 20);
        let ids = vec![0; 7];
        let err = fmm(&data, &ids, None, 1).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "subject_ids");
            }
            other => panic!("Expected InvalidDimension, got {:?}", other),
        }
    }

    #[test]
    fn test_fmm_test_fixed_empty_data_error() {
        let data = FdMatrix::zeros(0, 0);
        let covariates = FdMatrix::zeros(0, 1);
        let err = fmm_test_fixed(&data, &[], &covariates, 1, 10, 42).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "data");
            }
            other => panic!("Expected InvalidDimension for data, got {:?}", other),
        }
    }

    #[test]
    fn test_fmm_test_fixed_zero_covariates_error() {
        let data = FdMatrix::zeros(10, 20);
        let ids = vec![0; 10];
        let covariates = FdMatrix::zeros(10, 0);
        let err = fmm_test_fixed(&data, &ids, &covariates, 1, 10, 42).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "covariates");
            }
            other => panic!("Expected InvalidDimension for covariates, got {:?}", other),
        }
    }

    #[test]
    fn test_build_subject_map_single_subject() {
        let (map, n) = build_subject_map(&[42, 42, 42]);
        assert_eq!(n, 1);
        assert_eq!(map, vec![0, 0, 0]);
    }

    /// IDs are remapped in sorted order regardless of their raw values.
    #[test]
    fn test_build_subject_map_non_contiguous_ids() {
        let (map, n) = build_subject_map(&[100, 200, 100, 300, 200]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 1, 0, 2, 1]);
    }

    /// Requesting more components than the data supports must clamp rather
    /// than fail.
    #[test]
    fn test_fmm_many_components_clamped() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 3, 20);
        let n_total = data.nrows();
        let result = fmm(&data, &subject_ids, None, 100).unwrap();
        assert!(
            result.ncomp <= n_total.min(20),
            "ncomp should be clamped: got {}",
            result.ncomp
        );
        assert!(result.ncomp >= 1);
    }

    /// With enough components most of the data variation is explained.
    #[test]
    fn test_fmm_residuals_small_with_enough_components() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 30);
        let result = fmm(&data, &subject_ids, Some(&covariates), 5).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        let mut data_ss = 0.0_f64;
        let mut resid_ss = 0.0_f64;
        for i in 0..n {
            for t in 0..m {
                data_ss += data[(i, t)].powi(2);
                resid_ss += result.residuals[(i, t)].powi(2);
            }
        }

        let r_squared = 1.0 - resid_ss / data_ss;
        assert!(
            r_squared > 0.5,
            "R-squared should be high with enough components: {}",
            r_squared
        );
    }
}