1use crate::error::FdarError;
14use crate::iter_maybe_parallel;
15use crate::matrix::FdMatrix;
16use crate::regression::fdata_to_pc_1d;
17#[cfg(feature = "parallel")]
18use rayon::iter::ParallelIterator;
19
/// Output of a functional mixed model fit (see [`fmm`]).
#[derive(Debug, Clone, PartialEq)]
pub struct FmmResult {
    /// Pointwise mean function over the evaluation grid (length `m`).
    pub mean_function: Vec<f64>,
    /// Fixed-effect coefficient functions, one row per covariate (`p x m`).
    pub beta_functions: FdMatrix,
    /// Subject-specific random-effect curves (`n_subjects x m`).
    pub random_effects: FdMatrix,
    /// Fitted values for every observation row (`n_total x m`).
    pub fitted: FdMatrix,
    /// Residuals `data - fitted` (`n_total x m`).
    pub residuals: FdMatrix,
    /// Pointwise variance of the random-effect curves (length `m`).
    pub random_variance: Vec<f64>,
    /// Residual (measurement-error) variance, averaged over components.
    pub sigma2_eps: f64,
    /// Random-effect variance per principal component (length `ncomp`).
    pub sigma2_u: Vec<f64>,
    /// Number of principal components actually used (may be fewer than requested).
    pub ncomp: usize,
    /// Number of distinct subjects.
    pub n_subjects: usize,
    /// FPCA eigenvalues: squared singular values divided by `n_total`.
    pub eigenvalues: Vec<f64>,
}
46
/// Permutation-test output for the fixed effects of an [`fmm`] fit.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// Observed statistic per covariate (integrated squared beta function).
    pub f_statistics: Vec<f64>,
    /// Permutation p-value per covariate (add-one corrected).
    pub p_values: Vec<f64>,
}
55
56#[must_use = "expensive computation whose result should not be discarded"]
82pub fn fmm(
83 data: &FdMatrix,
84 subject_ids: &[usize],
85 covariates: Option<&FdMatrix>,
86 ncomp: usize,
87) -> Result<FmmResult, FdarError> {
88 let n_total = data.nrows();
89 let m = data.ncols();
90 if n_total == 0 || m == 0 {
91 return Err(FdarError::InvalidDimension {
92 parameter: "data",
93 expected: "non-empty matrix".to_string(),
94 actual: format!("{n_total} x {m}"),
95 });
96 }
97 if subject_ids.len() != n_total {
98 return Err(FdarError::InvalidDimension {
99 parameter: "subject_ids",
100 expected: format!("length {n_total}"),
101 actual: format!("length {}", subject_ids.len()),
102 });
103 }
104 if ncomp == 0 {
105 return Err(FdarError::InvalidParameter {
106 parameter: "ncomp",
107 message: "must be >= 1".to_string(),
108 });
109 }
110
111 let (subject_map, n_subjects) = build_subject_map(subject_ids);
113
114 let fpca = fdata_to_pc_1d(data, ncomp)?;
116 let k = fpca.scores.ncols(); let p = covariates.map_or(0, super::matrix::FdMatrix::ncols);
120 let ComponentResults {
121 gamma,
122 u_hat,
123 sigma2_u,
124 sigma2_eps,
125 } = fit_all_components(
126 &fpca.scores,
127 &subject_map,
128 n_subjects,
129 covariates,
130 p,
131 k,
132 n_total,
133 m,
134 );
135
136 let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
138 let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
139
140 let random_variance = compute_random_variance(&random_effects, n_subjects, m);
142
143 let (fitted, residuals) = compute_fitted_residuals(
145 data,
146 &fpca.mean,
147 &beta_functions,
148 &random_effects,
149 covariates,
150 &subject_map,
151 n_total,
152 m,
153 p,
154 );
155
156 let eigenvalues: Vec<f64> = fpca
157 .singular_values
158 .iter()
159 .map(|&sv| sv * sv / n_total as f64)
160 .collect();
161
162 Ok(FmmResult {
163 mean_function: fpca.mean,
164 beta_functions,
165 random_effects,
166 fitted,
167 residuals,
168 random_variance,
169 sigma2_eps,
170 sigma2_u,
171 ncomp: k,
172 n_subjects,
173 eigenvalues,
174 })
175}
176
/// Maps arbitrary subject IDs to dense indices `0..n_subjects`.
///
/// Dense indices follow the sorted order of the distinct IDs, so numeric
/// ordering of the original IDs is preserved. Returns the per-observation
/// dense index vector and the number of distinct subjects.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    use std::collections::HashMap;

    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    let n_subjects = unique_ids.len();

    // O(1) lookup per observation instead of a linear scan (the previous
    // `position` search made this O(n^2) over observations).
    let index_of: HashMap<usize, usize> = unique_ids
        .iter()
        .enumerate()
        .map(|(idx, &id)| (id, idx))
        .collect();

    let map: Vec<usize> = subject_ids
        .iter()
        // Every id is present by construction; keep the original 0 fallback.
        .map(|id| index_of.get(id).copied().unwrap_or(0))
        .collect();

    (map, n_subjects)
}
191
/// Intermediate per-component estimates pooled across all FPCA components.
struct ComponentResults {
    // Fixed-effect coefficients, indexed [covariate][component].
    gamma: Vec<Vec<f64>>,
    // Subject random-effect scores, indexed [subject][component].
    u_hat: Vec<Vec<f64>>,
    // Random-effect variance per component.
    sigma2_u: Vec<f64>,
    // Residual variance averaged over components.
    sigma2_eps: f64,
}
199
/// Fits one scalar mixed model per principal component (in parallel when the
/// `parallel` feature is enabled) and gathers the per-component estimates.
///
/// Scores are multiplied by `sqrt(h)` before fitting and point estimates are
/// divided by it afterwards; presumably this puts the variance components on
/// the integrated-curve scale used elsewhere in the file — TODO confirm
/// against the conventions of `fdata_to_pc_1d`.
#[allow(clippy::too_many_arguments)]
fn fit_all_components(
    scores: &FdMatrix,
    subject_map: &[usize],
    n_subjects: usize,
    covariates: Option<&FdMatrix>,
    p: usize,
    k: usize,
    n_total: usize,
    m: usize,
) -> ComponentResults {
    // Grid spacing for m evaluation points on a unit interval.
    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
    let score_scale = h.sqrt();

    // Each component is an independent scalar mixed-model fit.
    let per_comp: Vec<ScalarMixedResult> = iter_maybe_parallel!(0..k)
        .map(|comp| {
            let comp_scores: Vec<f64> = (0..n_total)
                .map(|i| scores[(i, comp)] * score_scale)
                .collect();
            fit_scalar_mixed_model(&comp_scores, subject_map, n_subjects, covariates, p)
        })
        .collect();

    // Undo the score scaling on the point estimates; per-component random
    // variances are kept as fitted, and the residual variance is averaged.
    let mut gamma = vec![vec![0.0; k]; p];
    let mut u_hat = vec![vec![0.0; k]; n_subjects];
    let mut sigma2_u = vec![0.0; k];
    let mut sigma2_eps_total = 0.0;

    for (comp, result) in per_comp.iter().enumerate() {
        for j in 0..p {
            gamma[j][comp] = result.gamma[j] / score_scale;
        }
        for s in 0..n_subjects {
            u_hat[s][comp] = result.u_hat[s] / score_scale;
        }
        sigma2_u[comp] = result.sigma2_u;
        sigma2_eps_total += result.sigma2_eps;
    }
    let sigma2_eps = sigma2_eps_total / k as f64;

    ComponentResults {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps,
    }
}
255
/// Result of one scalar random-intercept mixed-model fit.
struct ScalarMixedResult {
    // Fixed-effect coefficients (length p; empty when no covariates).
    gamma: Vec<f64>,
    // BLUPs of the subject intercepts (length n_subjects).
    u_hat: Vec<f64>,
    // Random-intercept variance.
    sigma2_u: f64,
    // Residual variance.
    sigma2_eps: f64,
}
263
/// Per-subject bookkeeping: observation counts and observation row indices.
struct SubjectStructure {
    /// Number of observations per subject.
    counts: Vec<usize>,
    /// Row indices belonging to each subject, in input order.
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Groups the first `n` rows by `subject_map[row]` into per-subject
    /// index lists.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut counts = vec![0usize; n_subjects];
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for (row, &subject) in subject_map[..n].iter().enumerate() {
            counts[subject] += 1;
            obs[subject].push(row);
        }
        Self { counts, obs }
    }
}
282
283fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
285 ss.counts
286 .iter()
287 .map(|&c| {
288 let ns = c as f64;
289 if ns < 1.0 {
290 0.0
291 } else {
292 sigma2_u / (sigma2_u + sigma2_e / ns)
293 }
294 })
295 .collect()
296}
297
/// One generalized-least-squares update of the fixed effects.
///
/// Accumulates X' V^-1 X and X' V^-1 y subject by subject, where V is the
/// compound-symmetric within-subject covariance. Its inverse is applied in
/// closed form inside `accumulate_gls_terms` via the shrinkage weights, so V
/// is never materialized. Returns `None` when the ridge-stabilized normal
/// equations cannot be Cholesky-solved.
fn gls_update_gamma(
    cov: &FdMatrix,
    p: usize,
    ss: &SubjectStructure,
    weights: &[f64],
    y: &[f64],
    sigma2_e: f64,
) -> Option<Vec<f64>> {
    let n_subjects = ss.counts.len();
    let mut xtvinvx = vec![0.0; p * p];
    let mut xtvinvy = vec![0.0; p];
    let inv_e = 1.0 / sigma2_e;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            // Subjects without observations contribute nothing.
            continue;
        }
        // Per-subject column sums feed the closed-form V^-1 application.
        let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
        accumulate_gls_terms(
            cov,
            y,
            &ss.obs[s],
            &x_sum,
            y_sum,
            weights[s],
            ns,
            inv_e,
            p,
            &mut xtvinvx,
            &mut xtvinvy,
        );
    }

    // Tiny ridge keeps the system positive definite for the Cholesky solve.
    for j in 0..p {
        xtvinvx[j * p + j] += 1e-10;
    }
    cholesky_solve(&xtvinvx, &xtvinvy, p)
}
340
341fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
343 let mut x_sum = vec![0.0; p];
344 let mut y_sum = 0.0;
345 for &i in obs {
346 for r in 0..p {
347 x_sum[r] += cov[(i, r)];
348 }
349 y_sum += y[i];
350 }
351 (x_sum, y_sum)
352}
353
/// Adds one subject's contribution to X' V^-1 X and X' V^-1 y.
///
/// For compound symmetry, V^-1 applied to a subject vector v is
/// `(v - w_s * mean(v)) / sigma2_e`, which is computed directly here without
/// forming V. Only the upper triangle of X' V^-1 X is computed; the mirror
/// entry is filled in the same pass.
fn accumulate_gls_terms(
    cov: &FdMatrix,
    y: &[f64],
    obs: &[usize],
    x_sum: &[f64],
    y_sum: f64,
    w_s: f64,
    ns: f64,
    inv_e: f64,
    p: usize,
    xtvinvx: &mut [f64],
    xtvinvy: &mut [f64],
) {
    for &i in obs {
        // V^-1 y for this row: subtract the shrunken subject mean.
        let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
        for r in 0..p {
            xtvinvy[r] += cov[(i, r)] * vinv_y;
            for c in r..p {
                // V^-1 applied to covariate column c on the same row.
                let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
                let val = cov[(i, r)] * vinv_xc;
                xtvinvx[r * p + c] += val;
                if r != c {
                    // Mirror into the lower triangle.
                    xtvinvx[c * p + r] += val;
                }
            }
        }
    }
}
383
/// One EM-style update of the variance components.
///
/// Each subject's random intercept is predicted as the shrunken mean
/// residual (`u_hat_s = w_s * mean_r_s`) with conditional variance
/// `sigma2_u * (1 - w_s)`; the expected sufficient statistics then give the
/// updated `(sigma2_u, sigma2_e)`. Both outputs are floored at 1e-15 so the
/// next iteration's divisions stay well-defined.
fn reml_variance_update(
    residuals: &[f64],
    ss: &SubjectStructure,
    weights: &[f64],
    sigma2_u: f64,
    p: usize,
) -> (f64, f64) {
    let n_subjects = ss.counts.len();
    let n: usize = ss.counts.iter().sum();
    let mut sigma2_u_new = 0.0;
    let mut sigma2_e_new = 0.0;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            // Subjects without observations contribute nothing.
            continue;
        }
        let w_s = weights[s];
        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
        let u_hat_s = w_s * mean_r_s;
        let cond_var_s = sigma2_u * (1.0 - w_s);

        // E[u_s^2 | data] = u_hat_s^2 + Var(u_s | data).
        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
        for &i in &ss.obs[s] {
            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
        }
        // Each of the ns observations inherits the conditional variance.
        sigma2_e_new += ns * cond_var_s;
    }

    // REML-style degrees of freedom: subtract the p fixed effects.
    let denom_e = (n.saturating_sub(p)).max(1) as f64;

    (
        (sigma2_u_new / n_subjects as f64).max(1e-15),
        (sigma2_e_new / denom_e).max(1e-15),
    )
}
425
/// Fits a scalar random-intercept mixed model `y = X gamma + u_subject + eps`.
///
/// Starts from OLS fixed effects and ANOVA-type variance components, then
/// alternates GLS updates of `gamma` with EM-style variance updates for up
/// to 50 iterations or until the variance components stop changing.
fn fit_scalar_mixed_model(
    y: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    covariates: Option<&FdMatrix>,
    p: usize,
) -> ScalarMixedResult {
    let n = y.len();
    let ss = SubjectStructure::new(subject_map, n_subjects, n);

    // Initial estimates: OLS fixed effects, method-of-moments variances.
    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
    let (mut sigma2_u, mut sigma2_e) =
        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);

    // Guard against degenerate starting values.
    if sigma2_e < 1e-15 {
        sigma2_e = 1e-6;
    }
    if sigma2_u < 1e-15 {
        sigma2_u = sigma2_e * 0.1;
    }

    let mut gamma = gamma_init;

    for _iter in 0..50 {
        let sigma2_u_old = sigma2_u;
        let sigma2_e_old = sigma2_e;

        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);

        // GLS step for gamma; on a failed solve keep the previous estimate.
        if let Some(cov) = covariates.filter(|_| p > 0) {
            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
                gamma = g;
            }
        }

        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);

        // Relative convergence check on the variance components.
        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
            break;
        }
    }

    // BLUPs of the subject intercepts from the final residuals.
    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
    let u_hat = compute_blup(
        &final_residuals,
        subject_map,
        n_subjects,
        sigma2_u,
        sigma2_e,
    );

    ScalarMixedResult {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps: sigma2_e,
    }
}
493
494fn estimate_fixed_effects(
496 y: &[f64],
497 covariates: Option<&FdMatrix>,
498 p: usize,
499 n: usize,
500) -> Vec<f64> {
501 if p == 0 || covariates.is_none() {
502 return Vec::new();
503 }
504 let cov = covariates.expect("checked: covariates is Some");
505
506 let mut xtx = vec![0.0; p * p];
508 let mut xty = vec![0.0; p];
509 for i in 0..n {
510 for r in 0..p {
511 xty[r] += cov[(i, r)] * y[i];
512 for s in r..p {
513 let val = cov[(i, r)] * cov[(i, s)];
514 xtx[r * p + s] += val;
515 if r != s {
516 xtx[s * p + r] += val;
517 }
518 }
519 }
520 }
521 for j in 0..p {
523 xtx[j * p + j] += 1e-8;
524 }
525
526 cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
527}
528
/// Lower-triangular Cholesky factor `L` (row-major, `p x p`) of the
/// symmetric matrix `a`, or `None` if a non-positive pivot shows the matrix
/// is not positive definite.
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        // Pivot: a[j][j] minus the squared entries already placed in row j.
        let sq_sum: f64 = l[j * p..j * p + j].iter().map(|v| v * v).sum();
        let pivot = a[j * p + j] - sq_sum;
        if pivot <= 0.0 {
            return None;
        }
        let diag = pivot.sqrt();
        l[j * p + j] = diag;
        // Fill column j below the diagonal.
        for i in (j + 1)..p {
            let dot: f64 = (0..j).map(|k| l[i * p + k] * l[j * p + k]).sum();
            l[i * p + j] = (a[i * p + j] - dot) / diag;
        }
    }
    Some(l)
}
552
/// Solves `L L^T z = b` given the lower-triangular factor `l` (row-major),
/// by forward then backward substitution.
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut z = vec![0.0; p];
    // Forward pass: L w = b.
    for i in 0..p {
        let dot: f64 = (0..i).map(|j| l[i * p + j] * z[j]).sum();
        z[i] = (b[i] - dot) / l[i * p + i];
    }
    // Backward pass: L^T z = w; entry (i, j) of L^T is l[j * p + i].
    for i in (0..p).rev() {
        let dot: f64 = ((i + 1)..p).map(|j| l[j * p + i] * z[j]).sum();
        z[i] = (z[i] - dot) / l[i * p + i];
    }
    z
}
572
573fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
575 let l = cholesky_factor_famm(a, p)?;
576 Some(cholesky_triangular_solve(&l, b, p))
577}
578
579fn compute_ols_residuals(
581 y: &[f64],
582 covariates: Option<&FdMatrix>,
583 gamma: &[f64],
584 p: usize,
585 n: usize,
586) -> Vec<f64> {
587 let mut residuals = y.to_vec();
588 if p > 0 {
589 if let Some(cov) = covariates {
590 for i in 0..n {
591 for j in 0..p {
592 residuals[i] -= cov[(i, j)] * gamma[j];
593 }
594 }
595 }
596 }
597 residuals
598}
599
/// Method-of-moments (one-way ANOVA) starting values for the variance
/// components: within-subject mean square for `sigma2_eps`, and the excess
/// between-subject mean square (scaled by the average group size) for
/// `sigma2_u`, floored at zero.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Per-subject residual totals and observation counts.
    let mut sums = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for i in 0..n {
        let subj = subject_map[i];
        sums[subj] += residuals[i];
        counts[subj] += 1;
    }
    let means: Vec<f64> = sums
        .iter()
        .zip(&counts)
        .map(|(&total, &c)| if c > 0 { total / c as f64 } else { 0.0 })
        .collect();

    // Within-subject sum of squares.
    let ss_within: f64 = (0..n)
        .map(|i| (residuals[i] - means[subject_map[i]]).powi(2))
        .sum();
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject sum of squares around the grand mean.
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let ss_between: f64 = (0..n_subjects)
        .map(|s| counts[s] as f64 * (means[s] - grand_mean).powi(2))
        .sum();

    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        // No within-subject replication: fall back to a small positive value.
        1e-6
    };

    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let ms_between = ss_between / n_subjects.saturating_sub(1).max(1) as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
652
/// Best linear unbiased predictors of the subject random intercepts:
/// shrunken per-subject mean residuals.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    let mut sums = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for (row, &r) in residuals.iter().enumerate() {
        let subj = subject_map[row];
        sums[subj] += r;
        counts[subj] += 1;
    }

    let mut blups = Vec::with_capacity(n_subjects);
    for s in 0..n_subjects {
        if counts[s] == 0 {
            // An unobserved subject is predicted at the prior mean (zero).
            blups.push(0.0);
            continue;
        }
        let ni = counts[s] as f64;
        let mean_r = sums[s] / ni;
        // Shrink toward zero; the denominator is floored to stay finite.
        let shrinkage = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
        blups.push(shrinkage * mean_r);
    }
    blups
}
683
684fn recover_beta_functions(
690 gamma: &[Vec<f64>],
691 rotation: &FdMatrix,
692 p: usize,
693 m: usize,
694 k: usize,
695) -> FdMatrix {
696 let mut beta = FdMatrix::zeros(p, m);
697 for j in 0..p {
698 for t in 0..m {
699 let mut val = 0.0;
700 for comp in 0..k {
701 val += gamma[j][comp] * rotation[(t, comp)];
702 }
703 beta[(j, t)] = val;
704 }
705 }
706 beta
707}
708
709fn recover_random_effects(
711 u_hat: &[Vec<f64>],
712 rotation: &FdMatrix,
713 n_subjects: usize,
714 m: usize,
715 k: usize,
716) -> FdMatrix {
717 let mut re = FdMatrix::zeros(n_subjects, m);
718 for s in 0..n_subjects {
719 for t in 0..m {
720 let mut val = 0.0;
721 for comp in 0..k {
722 val += u_hat[s][comp] * rotation[(t, comp)];
723 }
724 re[(s, t)] = val;
725 }
726 }
727 re
728}
729
730fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
732 (0..m)
733 .map(|t| {
734 let mean: f64 =
735 (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
736 let var: f64 = (0..n_subjects)
737 .map(|s| (random_effects[(s, t)] - mean).powi(2))
738 .sum::<f64>()
739 / n_subjects.max(1) as f64;
740 var
741 })
742 .collect()
743}
744
745fn compute_fitted_residuals(
747 data: &FdMatrix,
748 mean_function: &[f64],
749 beta_functions: &FdMatrix,
750 random_effects: &FdMatrix,
751 covariates: Option<&FdMatrix>,
752 subject_map: &[usize],
753 n_total: usize,
754 m: usize,
755 p: usize,
756) -> (FdMatrix, FdMatrix) {
757 let mut fitted = FdMatrix::zeros(n_total, m);
758 let mut residuals = FdMatrix::zeros(n_total, m);
759
760 for i in 0..n_total {
761 let s = subject_map[i];
762 for t in 0..m {
763 let mut val = mean_function[t] + random_effects[(s, t)];
764 if p > 0 {
765 if let Some(cov) = covariates {
766 for j in 0..p {
767 val += cov[(i, j)] * beta_functions[(j, t)];
768 }
769 }
770 }
771 fitted[(i, t)] = val;
772 residuals[(i, t)] = data[(i, t)] - val;
773 }
774 }
775
776 (fitted, residuals)
777}
778
779#[must_use = "prediction result should not be discarded"]
791pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
792 let m = result.mean_function.len();
793 let n_new = new_covariates.map_or(1, super::matrix::FdMatrix::nrows);
794 let p = result.beta_functions.nrows();
795
796 let mut predicted = FdMatrix::zeros(n_new, m);
797 for i in 0..n_new {
798 for t in 0..m {
799 let mut val = result.mean_function[t];
800 if let Some(cov) = new_covariates {
801 for j in 0..p {
802 val += cov[(i, j)] * result.beta_functions[(j, t)];
803 }
804 }
805 predicted[(i, t)] = val;
806 }
807 }
808 predicted
809}
810
/// Permutation test for the fixed covariate effects of a functional mixed
/// model.
///
/// Fits the model once to obtain the observed integrated-squared-beta
/// statistics, then refits `n_perm` times with the covariate rows permuted
/// (reproducibly seeded by `seed`) to obtain p-values.
///
/// # Errors
///
/// Returns an error when `data` is empty, `covariates` has no columns, or
/// the underlying [`fmm`] fit fails.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fmm_test_fixed(
    data: &FdMatrix,
    subject_ids: &[usize],
    covariates: &FdMatrix,
    ncomp: usize,
    n_perm: usize,
    seed: u64,
) -> Result<FmmTestResult, FdarError> {
    let n_total = data.nrows();
    let m = data.ncols();
    let p = covariates.ncols();
    if n_total == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "non-empty matrix".to_string(),
            actual: format!("{n_total} rows"),
        });
    }
    if p == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "covariates",
            expected: "at least 1 column".to_string(),
            actual: "0 columns".to_string(),
        });
    }

    // Observed fit and its per-covariate test statistics.
    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;

    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);

    // Null distribution via covariate-row permutation.
    let (f_statistics, p_values) = permutation_test(
        data,
        subject_ids,
        covariates,
        ncomp,
        n_perm,
        seed,
        &observed_stats,
        p,
        m,
    );

    Ok(FmmTestResult {
        f_statistics,
        p_values,
    })
}
884
885fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
887 let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
888 (0..p)
889 .map(|j| {
890 let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
891 ss * h
892 })
893 .collect()
894}
895
/// Runs the permutation loop for [`fmm_test_fixed`].
///
/// Each iteration shuffles the covariate rows, refits the model, and counts
/// how often the permuted statistic meets or exceeds the observed one.
/// Permutations whose refit fails are silently skipped. P-values use the
/// standard add-one correction `(count + 1) / (n_perm + 1)`.
fn permutation_test(
    data: &FdMatrix,
    subject_ids: &[usize],
    covariates: &FdMatrix,
    ncomp: usize,
    n_perm: usize,
    seed: u64,
    observed_stats: &[f64],
    p: usize,
    m: usize,
) -> (Vec<f64>, Vec<f64>) {
    use rand::prelude::*;
    let n_total = data.nrows();
    // Seeded RNG makes the test reproducible across runs.
    let mut rng = StdRng::seed_from_u64(seed);
    let mut n_ge = vec![0usize; p];

    for _ in 0..n_perm {
        let mut perm_indices: Vec<usize> = (0..n_total).collect();
        perm_indices.shuffle(&mut rng);
        let perm_cov = permute_rows(covariates, &perm_indices);

        if let Ok(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
            for j in 0..p {
                if perm_stats[j] >= observed_stats[j] {
                    n_ge[j] += 1;
                }
            }
        }
    }

    // Add-one-corrected p-values; the statistics are the observed values.
    let p_values: Vec<f64> = n_ge
        .iter()
        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
        .collect();
    let f_statistics = observed_stats.to_vec();

    (f_statistics, p_values)
}
937
938fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
940 let n = indices.len();
941 let m = mat.ncols();
942 let mut result = FdMatrix::zeros(n, m);
943 for (new_i, &old_i) in indices.iter().enumerate() {
944 for j in 0..m {
945 result[(new_i, j)] = mat[(old_i, j)];
946 }
947 }
948 result
949}
950
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Builds a synthetic longitudinal data set: sinusoidal mean, a
    /// covariate-driven fixed effect, a per-subject cosine random effect,
    /// and deterministic pseudo-noise.
    /// Returns `(data, subject_ids, covariates, grid)`.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            // Covariate value and subject-level deviation for subject s.
            let z = s as f64 / n_subjects as f64;
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0);
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    // Deterministic "noise" so tests are reproducible.
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }

    #[test]
    fn test_fmm_basic() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_eq!(result.mean_function.len(), 50);
        assert_eq!(result.beta_functions.nrows(), 1);
        assert_eq!(result.beta_functions.ncols(), 50);
        assert_eq!(result.random_effects.nrows(), 10);
        assert_eq!(result.fitted.nrows(), 30);
        assert_eq!(result.residuals.nrows(), 30);
        assert_eq!(result.n_subjects, 10);
    }

    #[test]
    fn test_fmm_fitted_plus_residuals_equals_data() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        for i in 0..n {
            for t in 0..m {
                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
                assert!(
                    (reconstructed - data[(i, t)]).abs() < 1e-8,
                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
                    i,
                    t,
                    reconstructed,
                    data[(i, t)]
                );
            }
        }
    }

    #[test]
    fn test_fmm_random_variance_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        for &v in &result.random_variance {
            assert!(v >= 0.0, "Random variance should be non-negative");
        }
    }

    #[test]
    fn test_fmm_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, None, 3).unwrap();

        // With no covariates the beta matrix has zero rows.
        assert_eq!(result.beta_functions.nrows(), 0);
        assert_eq!(result.n_subjects, 8);
        assert_eq!(result.fitted.nrows(), 24);
    }

    #[test]
    fn test_fmm_predict() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 50);

        for t in 0..50 {
            assert!(predicted[(0, t)].is_finite());
            assert!(
                predicted[(0, t)].abs() < 20.0,
                "Predicted value too extreme at t={}: {}",
                t,
                predicted[(0, t)]
            );
        }
    }

    #[test]
    fn test_fmm_test_fixed_detects_effect() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();

        assert_eq!(result.f_statistics.len(), 1);
        assert_eq!(result.p_values.len(), 1);
        assert!(
            result.p_values[0] < 0.1,
            "Should detect covariate effect, got p={}",
            result.p_values[0]
        );
    }

    #[test]
    fn test_fmm_test_fixed_no_effect() {
        // Data generated without any covariate effect: p-value should be large.
        let n_subjects = 10;
        let n_visits = 3;
        let m = 40;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;

        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = s as f64 / n_subjects as f64;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
        assert!(
            result.p_values[0] > 0.05,
            "Should not detect effect, got p={}",
            result.p_values[0]
        );
    }

    #[test]
    fn test_fmm_invalid_input() {
        // Empty data matrix.
        let data = FdMatrix::zeros(0, 0);
        assert!(fmm(&data, &[], None, 1).is_err());

        // Subject-id length mismatch.
        let data = FdMatrix::zeros(10, 50);
        let ids = vec![0; 5];
        assert!(fmm(&data, &ids, None, 1).is_err());
    }

    #[test]
    fn test_fmm_single_visit_per_subject() {
        let n = 10;
        let m = 40;
        let t = uniform_grid(m);
        let mut col_major = vec![0.0; n * m];
        let subject_ids: Vec<usize> = (0..n).collect();

        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        let result = fmm(&data, &subject_ids, None, 2).unwrap();
        assert_eq!(result.n_subjects, n);
        assert_eq!(result.fitted.nrows(), n);
    }

    #[test]
    fn test_build_subject_map() {
        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 0, 1, 1, 2]);
    }

    #[test]
    fn test_variance_components_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert!(result.sigma2_eps >= 0.0);
        for &s in &result.sigma2_u {
            assert!(s >= 0.0);
        }
    }

    #[test]
    fn test_fmm_ncomp_zero_returns_error() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 2, 20);
        let err = fmm(&data, &subject_ids, None, 0).unwrap_err();
        match err {
            FdarError::InvalidParameter { parameter, .. } => {
                assert_eq!(parameter, "ncomp");
            }
            other => panic!("Expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_fmm_single_component() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 30);
        let result = fmm(&data, &subject_ids, Some(&covariates), 1).unwrap();

        assert_eq!(result.ncomp, 1);
        assert_eq!(result.sigma2_u.len(), 1);
        assert_eq!(result.eigenvalues.len(), 1);
        assert_eq!(result.mean_function.len(), 30);
        // Reconstruction identity must hold even with a single component.
        for i in 0..data.nrows() {
            for t in 0..data.ncols() {
                let diff = (result.fitted[(i, t)] + result.residuals[(i, t)] - data[(i, t)]).abs();
                assert!(diff < 1e-8);
            }
        }
    }

    #[test]
    fn test_fmm_two_subjects() {
        // Minimal multi-subject case.
        let n_subjects = 2;
        let n_visits = 5;
        let m = 20;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + (s as f64) * 0.5 + 0.01 * v as f64;
                }
            }
        }
        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let result = fmm(&data, &subject_ids, None, 2).unwrap();

        assert_eq!(result.n_subjects, 2);
        assert_eq!(result.random_effects.nrows(), 2);
        assert_eq!(result.fitted.nrows(), n_total);
    }

    #[test]
    fn test_fmm_predict_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(6, 3, 30);
        let result = fmm(&data, &subject_ids, None, 2).unwrap();

        let predicted = fmm_predict(&result, None);
        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 30);
        for t in 0..30 {
            let diff = (predicted[(0, t)] - result.mean_function[t]).abs();
            assert!(
                diff < 1e-12,
                "Without covariates, prediction should equal mean"
            );
        }
    }

    #[test]
    fn test_fmm_predict_multiple_new_subjects() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let new_cov = FdMatrix::from_column_major(vec![0.1, 0.5, 0.9], 3, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 3);
        assert_eq!(predicted.ncols(), 40);

        for i in 0..3 {
            for t in 0..40 {
                assert!(predicted[(i, t)].is_finite());
            }
        }

        // Different covariate values must produce different curves.
        let diff_01: f64 = (0..40)
            .map(|t| (predicted[(0, t)] - predicted[(1, t)]).powi(2))
            .sum();
        assert!(
            diff_01 > 1e-10,
            "Different covariates should yield different predictions"
        );
    }

    #[test]
    fn test_fmm_eigenvalues_decreasing() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, None, 5).unwrap();

        for i in 1..result.eigenvalues.len() {
            assert!(
                result.eigenvalues[i] <= result.eigenvalues[i - 1] + 1e-10,
                "Eigenvalues should be non-increasing: {} > {}",
                result.eigenvalues[i],
                result.eigenvalues[i - 1]
            );
        }
    }

    #[test]
    fn test_fmm_random_effects_sum_near_zero() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(20, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let m = result.mean_function.len();
        for t in 0..m {
            let sum: f64 = (0..result.n_subjects)
                .map(|s| result.random_effects[(s, t)])
                .sum();
            let mean_abs: f64 = (0..result.n_subjects)
                .map(|s| result.random_effects[(s, t)].abs())
                .sum::<f64>()
                / result.n_subjects as f64;
            // Only check centering where the effects are non-negligible.
            if mean_abs > 1e-10 {
                assert!(
                    (sum / result.n_subjects as f64).abs() < mean_abs * 2.0,
                    "Random effects should roughly center around zero at t={}: sum={}, mean_abs={}",
                    t,
                    sum,
                    mean_abs
                );
            }
        }
    }

    #[test]
    fn test_fmm_subject_ids_mismatch_error() {
        let data = FdMatrix::zeros(10, 20);
        let ids = vec![0; 7];
        let err = fmm(&data, &ids, None, 1).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "subject_ids");
            }
            other => panic!("Expected InvalidDimension, got {:?}", other),
        }
    }

    #[test]
    fn test_fmm_test_fixed_empty_data_error() {
        let data = FdMatrix::zeros(0, 0);
        let covariates = FdMatrix::zeros(0, 1);
        let err = fmm_test_fixed(&data, &[], &covariates, 1, 10, 42).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "data");
            }
            other => panic!("Expected InvalidDimension for data, got {:?}", other),
        }
    }

    #[test]
    fn test_fmm_test_fixed_zero_covariates_error() {
        let data = FdMatrix::zeros(10, 20);
        let ids = vec![0; 10];
        let covariates = FdMatrix::zeros(10, 0);
        let err = fmm_test_fixed(&data, &ids, &covariates, 1, 10, 42).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "covariates");
            }
            other => panic!("Expected InvalidDimension for covariates, got {:?}", other),
        }
    }

    #[test]
    fn test_build_subject_map_single_subject() {
        let (map, n) = build_subject_map(&[42, 42, 42]);
        assert_eq!(n, 1);
        assert_eq!(map, vec![0, 0, 0]);
    }

    #[test]
    fn test_build_subject_map_non_contiguous_ids() {
        let (map, n) = build_subject_map(&[100, 200, 100, 300, 200]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 1, 0, 2, 1]);
    }

    #[test]
    fn test_fmm_many_components_clamped() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 3, 20);
        let n_total = data.nrows();
        let result = fmm(&data, &subject_ids, None, 100).unwrap();
        assert!(
            result.ncomp <= n_total.min(20),
            "ncomp should be clamped: got {}",
            result.ncomp
        );
        assert!(result.ncomp >= 1);
    }

    #[test]
    fn test_fmm_residuals_small_with_enough_components() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 30);
        let result = fmm(&data, &subject_ids, Some(&covariates), 5).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        let mut data_ss = 0.0_f64;
        let mut resid_ss = 0.0_f64;
        for i in 0..n {
            for t in 0..m {
                data_ss += data[(i, t)].powi(2);
                resid_ss += result.residuals[(i, t)].powi(2);
            }
        }

        let r_squared = 1.0 - resid_ss / data_ss;
        assert!(
            r_squared > 0.5,
            "R-squared should be high with enough components: {}",
            r_squared
        );
    }
}