1use crate::error::FdarError;
14use crate::matrix::FdMatrix;
15use crate::regression::fdata_to_pc_1d;
16
/// Result of fitting a functional mixed model with [`fmm`].
///
/// Below, `m` is the number of columns of the input data, `n_total` its number
/// of rows, and `p` the number of covariate columns (0 without covariates).
#[derive(Debug, Clone, PartialEq)]
pub struct FmmResult {
    /// Mean function returned by the FPCA step.
    pub mean_function: Vec<f64>,
    /// Fixed-effect coefficient functions, one row per covariate (`p x m`).
    pub beta_functions: FdMatrix,
    /// Subject-level random-effect functions (`n_subjects x m`).
    pub random_effects: FdMatrix,
    /// Fitted curves: mean + fixed effects + subject random effect (`n_total x m`).
    pub fitted: FdMatrix,
    /// `data - fitted` (`n_total x m`).
    pub residuals: FdMatrix,
    /// Pointwise variance of `random_effects` across subjects (length `m`,
    /// divided by the number of subjects, not by `n_subjects - 1`).
    pub random_variance: Vec<f64>,
    /// Residual variance, averaged over the per-component scalar fits.
    pub sigma2_eps: f64,
    /// Random-effect variance of each component's scalar fit.
    pub sigma2_u: Vec<f64>,
    /// Number of components actually used (columns of the FPCA score matrix).
    pub ncomp: usize,
    /// Number of distinct subject ids.
    pub n_subjects: usize,
    /// Squared FPCA singular values divided by `n_total`.
    pub eigenvalues: Vec<f64>,
}
43
/// Result of the permutation test run by [`fmm_test_fixed`].
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// Observed test statistic per covariate: the integrated squared
    /// coefficient function (despite the name, not a classical F ratio).
    pub f_statistics: Vec<f64>,
    /// Permutation p-value per covariate.
    pub p_values: Vec<f64>,
}
52
53#[must_use = "expensive computation whose result should not be discarded"]
79pub fn fmm(
80 data: &FdMatrix,
81 subject_ids: &[usize],
82 covariates: Option<&FdMatrix>,
83 ncomp: usize,
84) -> Result<FmmResult, FdarError> {
85 let n_total = data.nrows();
86 let m = data.ncols();
87 if n_total == 0 || m == 0 {
88 return Err(FdarError::InvalidDimension {
89 parameter: "data",
90 expected: "non-empty matrix".to_string(),
91 actual: format!("{n_total} x {m}"),
92 });
93 }
94 if subject_ids.len() != n_total {
95 return Err(FdarError::InvalidDimension {
96 parameter: "subject_ids",
97 expected: format!("length {n_total}"),
98 actual: format!("length {}", subject_ids.len()),
99 });
100 }
101 if ncomp == 0 {
102 return Err(FdarError::InvalidParameter {
103 parameter: "ncomp",
104 message: "must be >= 1".to_string(),
105 });
106 }
107
108 let (subject_map, n_subjects) = build_subject_map(subject_ids);
110
111 let fpca = fdata_to_pc_1d(data, ncomp)?;
113 let k = fpca.scores.ncols(); let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
119 let score_scale = h.sqrt();
120
121 let p = covariates.map_or(0, super::matrix::FdMatrix::ncols);
122 let mut gamma = vec![vec![0.0; k]; p]; let mut u_hat = vec![vec![0.0; k]; n_subjects]; let mut sigma2_u = vec![0.0; k];
125 let mut sigma2_eps_total = 0.0;
126
127 for comp in 0..k {
128 let scores: Vec<f64> = (0..n_total)
130 .map(|i| fpca.scores[(i, comp)] * score_scale)
131 .collect();
132 let result = fit_scalar_mixed_model(&scores, &subject_map, n_subjects, covariates, p);
133 for j in 0..p {
135 gamma[j][comp] = result.gamma[j] / score_scale;
136 }
137 for s in 0..n_subjects {
138 u_hat[s][comp] = result.u_hat[s] / score_scale;
140 }
141 sigma2_u[comp] = result.sigma2_u;
143 sigma2_eps_total += result.sigma2_eps;
144 }
145 let sigma2_eps = sigma2_eps_total / k as f64;
146
147 let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
149 let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
150
151 let random_variance = compute_random_variance(&random_effects, n_subjects, m);
153
154 let (fitted, residuals) = compute_fitted_residuals(
156 data,
157 &fpca.mean,
158 &beta_functions,
159 &random_effects,
160 covariates,
161 &subject_map,
162 n_total,
163 m,
164 p,
165 );
166
167 let eigenvalues: Vec<f64> = fpca
168 .singular_values
169 .iter()
170 .map(|&sv| sv * sv / n_total as f64)
171 .collect();
172
173 Ok(FmmResult {
174 mean_function: fpca.mean,
175 beta_functions,
176 random_effects,
177 fitted,
178 residuals,
179 random_variance,
180 sigma2_eps,
181 sigma2_u,
182 ncomp: k,
183 n_subjects,
184 eigenvalues,
185 })
186}
187
/// Map arbitrary subject ids to dense indices `0..n_subjects`.
///
/// Indices follow the ascending order of the distinct ids. Returns the
/// per-observation index vector (same length as `subject_ids`) and the number
/// of distinct subjects.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();

    // `unique_ids` is sorted and duplicate-free, so a binary search returns the
    // same position a linear scan would, in O(log s) per observation.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| {
            unique_ids
                .binary_search(id)
                .expect("id comes from the slice the unique list was built from")
        })
        .collect();

    let n_subjects = unique_ids.len();
    (map, n_subjects)
}
202
/// Estimates produced by `fit_scalar_mixed_model` for one response vector.
struct ScalarMixedResult {
    /// Fixed-effect coefficients, one per covariate column (empty without covariates).
    gamma: Vec<f64>,
    /// Predicted random intercept of each subject.
    u_hat: Vec<f64>,
    /// Estimated random-intercept variance.
    sigma2_u: f64,
    /// Estimated residual variance.
    sigma2_eps: f64,
}
210
/// Grouping of observation rows by subject.
struct SubjectStructure {
    /// Number of observations belonging to each subject.
    counts: Vec<usize>,
    /// Row indices of each subject's observations, in ascending order.
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` entries of `subject_map` by subject index.
    ///
    /// Panics if `n` exceeds `subject_map.len()` or if an entry is not below
    /// `n_subjects`.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for (row, &subject) in subject_map[..n].iter().enumerate() {
            obs[subject].push(row);
        }
        // Each count is simply the size of the subject's row list.
        let counts: Vec<usize> = obs.iter().map(Vec::len).collect();
        Self { counts, obs }
    }
}
229
230fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
232 ss.counts
233 .iter()
234 .map(|&c| {
235 let ns = c as f64;
236 if ns < 1.0 {
237 0.0
238 } else {
239 sigma2_u / (sigma2_u + sigma2_e / ns)
240 }
241 })
242 .collect()
243}
244
245fn gls_update_gamma(
249 cov: &FdMatrix,
250 p: usize,
251 ss: &SubjectStructure,
252 weights: &[f64],
253 y: &[f64],
254 sigma2_e: f64,
255) -> Option<Vec<f64>> {
256 let n_subjects = ss.counts.len();
257 let mut xtvinvx = vec![0.0; p * p];
258 let mut xtvinvy = vec![0.0; p];
259 let inv_e = 1.0 / sigma2_e;
260
261 for s in 0..n_subjects {
262 let ns = ss.counts[s] as f64;
263 if ns < 1.0 {
264 continue;
265 }
266 let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
267 accumulate_gls_terms(
268 cov,
269 y,
270 &ss.obs[s],
271 &x_sum,
272 y_sum,
273 weights[s],
274 ns,
275 inv_e,
276 p,
277 &mut xtvinvx,
278 &mut xtvinvy,
279 );
280 }
281
282 for j in 0..p {
283 xtvinvx[j * p + j] += 1e-10;
284 }
285 cholesky_solve(&xtvinvx, &xtvinvy, p)
286}
287
288fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
290 let mut x_sum = vec![0.0; p];
291 let mut y_sum = 0.0;
292 for &i in obs {
293 for r in 0..p {
294 x_sum[r] += cov[(i, r)];
295 }
296 y_sum += y[i];
297 }
298 (x_sum, y_sum)
299}
300
301fn accumulate_gls_terms(
303 cov: &FdMatrix,
304 y: &[f64],
305 obs: &[usize],
306 x_sum: &[f64],
307 y_sum: f64,
308 w_s: f64,
309 ns: f64,
310 inv_e: f64,
311 p: usize,
312 xtvinvx: &mut [f64],
313 xtvinvy: &mut [f64],
314) {
315 for &i in obs {
316 let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
317 for r in 0..p {
318 xtvinvy[r] += cov[(i, r)] * vinv_y;
319 for c in r..p {
320 let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
321 let val = cov[(i, r)] * vinv_xc;
322 xtvinvx[r * p + c] += val;
323 if r != c {
324 xtvinvx[c * p + r] += val;
325 }
326 }
327 }
328 }
329}
330
331fn reml_variance_update(
336 residuals: &[f64],
337 ss: &SubjectStructure,
338 weights: &[f64],
339 sigma2_u: f64,
340 p: usize,
341) -> (f64, f64) {
342 let n_subjects = ss.counts.len();
343 let n: usize = ss.counts.iter().sum();
344 let mut sigma2_u_new = 0.0;
345 let mut sigma2_e_new = 0.0;
346
347 for s in 0..n_subjects {
348 let ns = ss.counts[s] as f64;
349 if ns < 1.0 {
350 continue;
351 }
352 let w_s = weights[s];
353 let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
354 let u_hat_s = w_s * mean_r_s;
355 let cond_var_s = sigma2_u * (1.0 - w_s);
356
357 sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
358 for &i in &ss.obs[s] {
359 sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
360 }
361 sigma2_e_new += ns * cond_var_s;
362 }
363
364 let denom_e = (n.saturating_sub(p)).max(1) as f64;
366
367 (
368 (sigma2_u_new / n_subjects as f64).max(1e-15),
369 (sigma2_e_new / denom_e).max(1e-15),
370 )
371}
372
373fn fit_scalar_mixed_model(
379 y: &[f64],
380 subject_map: &[usize],
381 n_subjects: usize,
382 covariates: Option<&FdMatrix>,
383 p: usize,
384) -> ScalarMixedResult {
385 let n = y.len();
386 let ss = SubjectStructure::new(subject_map, n_subjects, n);
387
388 let gamma_init = estimate_fixed_effects(y, covariates, p, n);
390 let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
391 let (mut sigma2_u, mut sigma2_e) =
392 estimate_variance_components(&residuals_init, subject_map, n_subjects, n);
393
394 if sigma2_e < 1e-15 {
395 sigma2_e = 1e-6;
396 }
397 if sigma2_u < 1e-15 {
398 sigma2_u = sigma2_e * 0.1;
399 }
400
401 let mut gamma = gamma_init;
402
403 for _iter in 0..50 {
404 let sigma2_u_old = sigma2_u;
405 let sigma2_e_old = sigma2_e;
406
407 let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);
408
409 if let Some(cov) = covariates.filter(|_| p > 0) {
410 if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
411 gamma = g;
412 }
413 }
414
415 let r = compute_ols_residuals(y, covariates, &gamma, p, n);
416 (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);
417
418 let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
419 if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
420 break;
421 }
422 }
423
424 let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
425 let u_hat = compute_blup(
426 &final_residuals,
427 subject_map,
428 n_subjects,
429 sigma2_u,
430 sigma2_e,
431 );
432
433 ScalarMixedResult {
434 gamma,
435 u_hat,
436 sigma2_u,
437 sigma2_eps: sigma2_e,
438 }
439}
440
441fn estimate_fixed_effects(
443 y: &[f64],
444 covariates: Option<&FdMatrix>,
445 p: usize,
446 n: usize,
447) -> Vec<f64> {
448 if p == 0 || covariates.is_none() {
449 return Vec::new();
450 }
451 let cov = covariates.expect("checked: covariates is Some");
452
453 let mut xtx = vec![0.0; p * p];
455 let mut xty = vec![0.0; p];
456 for i in 0..n {
457 for r in 0..p {
458 xty[r] += cov[(i, r)] * y[i];
459 for s in r..p {
460 let val = cov[(i, r)] * cov[(i, s)];
461 xtx[r * p + s] += val;
462 if r != s {
463 xtx[s * p + r] += val;
464 }
465 }
466 }
467 }
468 for j in 0..p {
470 xtx[j * p + j] += 1e-8;
471 }
472
473 cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
474}
475
/// Cholesky factorisation `A = L L'` of a row-major `p x p` matrix.
///
/// Only the lower triangle of `a` is read. Returns the row-major lower
/// triangular factor `L`, or `None` when a pivot is not a strictly positive
/// finite number (matrix not positive definite, or NaN/infinite entries).
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        let sq_sum = l[j * p..j * p + j]
            .iter()
            .fold(0.0_f64, |acc, &v| acc + v * v);
        let diag = a[j * p + j] - sq_sum;
        // `diag <= 0.0` alone is false for NaN, which would let NaN (and
        // infinities) propagate into the factor and every later solve.
        if !diag.is_finite() || diag <= 0.0 {
            return None;
        }
        let pivot = diag.sqrt();
        l[j * p + j] = pivot;
        for i in (j + 1)..p {
            let dot = (0..j).fold(0.0_f64, |acc, k| acc + l[i * p + k] * l[j * p + k]);
            l[i * p + j] = (a[i * p + j] - dot) / pivot;
        }
    }
    Some(l)
}
499
/// Solve `L L' x = b` given the row-major lower-triangular factor `l`.
///
/// Performs forward substitution with `L`, then back substitution with `L'`,
/// reusing a single buffer for the intermediate and final vectors.
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = vec![0.0; p];

    // Forward pass: L z = b.
    for row in 0..p {
        let partial = (0..row).fold(0.0_f64, |acc, col| acc + l[row * p + col] * x[col]);
        x[row] = (b[row] - partial) / l[row * p + row];
    }

    // Backward pass: L' x = z, reading L column-wise to get its transpose.
    for row in (0..p).rev() {
        let partial =
            ((row + 1)..p).fold(0.0_f64, |acc, below| acc + l[below * p + row] * x[below]);
        x[row] = (x[row] - partial) / l[row * p + row];
    }

    x
}
519
520fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
522 let l = cholesky_factor_famm(a, p)?;
523 Some(cholesky_triangular_solve(&l, b, p))
524}
525
526fn compute_ols_residuals(
528 y: &[f64],
529 covariates: Option<&FdMatrix>,
530 gamma: &[f64],
531 p: usize,
532 n: usize,
533) -> Vec<f64> {
534 let mut residuals = y.to_vec();
535 if p > 0 {
536 if let Some(cov) = covariates {
537 for i in 0..n {
538 for j in 0..p {
539 residuals[i] -= cov[(i, j)] * gamma[j];
540 }
541 }
542 }
543 }
544 residuals
545}
546
/// Moment (one-way ANOVA style) starting values for the variance components.
///
/// The residual variance is the within-subject mean square (or `1e-6` when
/// there are no within-subject degrees of freedom). The random-effect variance
/// is `(MS_between - sigma2_eps) / n_bar`, floored at 0, where `n_bar` is the
/// average number of observations per subject. Returns `(sigma2_u, sigma2_eps)`.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Per-subject totals and sizes.
    let mut totals = vec![0.0_f64; n_subjects];
    let mut sizes = vec![0usize; n_subjects];
    for (&subject, &resid) in subject_map[..n].iter().zip(&residuals[..n]) {
        totals[subject] += resid;
        sizes[subject] += 1;
    }
    let means: Vec<f64> = totals
        .iter()
        .zip(&sizes)
        .map(|(&total, &size)| match size {
            0 => 0.0,
            _ => total / size as f64,
        })
        .collect();

    // Within-subject sum of squares.
    let ss_within = subject_map[..n]
        .iter()
        .zip(&residuals[..n])
        .fold(0.0_f64, |acc, (&subject, &resid)| {
            acc + (resid - means[subject]).powi(2)
        });
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject sum of squares around the grand mean.
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let ss_between = sizes
        .iter()
        .zip(&means)
        .fold(0.0_f64, |acc, (&size, &mean)| {
            acc + size as f64 * (mean - grand_mean).powi(2)
        });

    let sigma2_eps = match df_within {
        0 => 1e-6,
        df => ss_within / df as f64,
    };

    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
599
/// Predict each subject's random intercept as a shrunken mean residual.
///
/// The shrinkage factor is `sigma2_u / max(sigma2_u + sigma2_eps / n_i, 1e-15)`.
/// Subjects without observations get 0.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    let mut totals = vec![0.0_f64; n_subjects];
    let mut sizes = vec![0usize; n_subjects];
    for (idx, resid) in residuals.iter().enumerate() {
        let subject = subject_map[idx];
        totals[subject] += *resid;
        sizes[subject] += 1;
    }

    totals
        .iter()
        .zip(&sizes)
        .map(|(&total, &size)| {
            if size == 0 {
                0.0
            } else {
                let ni = size as f64;
                let mean_resid = total / ni;
                let shrinkage = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
                shrinkage * mean_resid
            }
        })
        .collect()
}
630
631fn recover_beta_functions(
637 gamma: &[Vec<f64>],
638 rotation: &FdMatrix,
639 p: usize,
640 m: usize,
641 k: usize,
642) -> FdMatrix {
643 let mut beta = FdMatrix::zeros(p, m);
644 for j in 0..p {
645 for t in 0..m {
646 let mut val = 0.0;
647 for comp in 0..k {
648 val += gamma[j][comp] * rotation[(t, comp)];
649 }
650 beta[(j, t)] = val;
651 }
652 }
653 beta
654}
655
656fn recover_random_effects(
658 u_hat: &[Vec<f64>],
659 rotation: &FdMatrix,
660 n_subjects: usize,
661 m: usize,
662 k: usize,
663) -> FdMatrix {
664 let mut re = FdMatrix::zeros(n_subjects, m);
665 for s in 0..n_subjects {
666 for t in 0..m {
667 let mut val = 0.0;
668 for comp in 0..k {
669 val += u_hat[s][comp] * rotation[(t, comp)];
670 }
671 re[(s, t)] = val;
672 }
673 }
674 re
675}
676
677fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
679 (0..m)
680 .map(|t| {
681 let mean: f64 =
682 (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
683 let var: f64 = (0..n_subjects)
684 .map(|s| (random_effects[(s, t)] - mean).powi(2))
685 .sum::<f64>()
686 / n_subjects.max(1) as f64;
687 var
688 })
689 .collect()
690}
691
692fn compute_fitted_residuals(
694 data: &FdMatrix,
695 mean_function: &[f64],
696 beta_functions: &FdMatrix,
697 random_effects: &FdMatrix,
698 covariates: Option<&FdMatrix>,
699 subject_map: &[usize],
700 n_total: usize,
701 m: usize,
702 p: usize,
703) -> (FdMatrix, FdMatrix) {
704 let mut fitted = FdMatrix::zeros(n_total, m);
705 let mut residuals = FdMatrix::zeros(n_total, m);
706
707 for i in 0..n_total {
708 let s = subject_map[i];
709 for t in 0..m {
710 let mut val = mean_function[t] + random_effects[(s, t)];
711 if p > 0 {
712 if let Some(cov) = covariates {
713 for j in 0..p {
714 val += cov[(i, j)] * beta_functions[(j, t)];
715 }
716 }
717 }
718 fitted[(i, t)] = val;
719 residuals[(i, t)] = data[(i, t)] - val;
720 }
721 }
722
723 (fitted, residuals)
724}
725
726#[must_use = "prediction result should not be discarded"]
738pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
739 let m = result.mean_function.len();
740 let n_new = new_covariates.map_or(1, super::matrix::FdMatrix::nrows);
741 let p = result.beta_functions.nrows();
742
743 let mut predicted = FdMatrix::zeros(n_new, m);
744 for i in 0..n_new {
745 for t in 0..m {
746 let mut val = result.mean_function[t];
747 if let Some(cov) = new_covariates {
748 for j in 0..p {
749 val += cov[(i, j)] * result.beta_functions[(j, t)];
750 }
751 }
752 predicted[(i, t)] = val;
753 }
754 }
755 predicted
756}
757
758#[must_use = "expensive computation whose result should not be discarded"]
781pub fn fmm_test_fixed(
782 data: &FdMatrix,
783 subject_ids: &[usize],
784 covariates: &FdMatrix,
785 ncomp: usize,
786 n_perm: usize,
787 seed: u64,
788) -> Result<FmmTestResult, FdarError> {
789 let n_total = data.nrows();
790 let m = data.ncols();
791 let p = covariates.ncols();
792 if n_total == 0 {
793 return Err(FdarError::InvalidDimension {
794 parameter: "data",
795 expected: "non-empty matrix".to_string(),
796 actual: format!("{n_total} rows"),
797 });
798 }
799 if p == 0 {
800 return Err(FdarError::InvalidDimension {
801 parameter: "covariates",
802 expected: "at least 1 column".to_string(),
803 actual: "0 columns".to_string(),
804 });
805 }
806
807 let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
809
810 let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
812
813 let (f_statistics, p_values) = permutation_test(
815 data,
816 subject_ids,
817 covariates,
818 ncomp,
819 n_perm,
820 seed,
821 &observed_stats,
822 p,
823 m,
824 );
825
826 Ok(FmmTestResult {
827 f_statistics,
828 p_values,
829 })
830}
831
832fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
834 let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
835 (0..p)
836 .map(|j| {
837 let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
838 ss * h
839 })
840 .collect()
841}
842
843fn permutation_test(
845 data: &FdMatrix,
846 subject_ids: &[usize],
847 covariates: &FdMatrix,
848 ncomp: usize,
849 n_perm: usize,
850 seed: u64,
851 observed_stats: &[f64],
852 p: usize,
853 m: usize,
854) -> (Vec<f64>, Vec<f64>) {
855 use rand::prelude::*;
856 let n_total = data.nrows();
857 let mut rng = StdRng::seed_from_u64(seed);
858 let mut n_ge = vec![0usize; p];
859
860 for _ in 0..n_perm {
861 let mut perm_indices: Vec<usize> = (0..n_total).collect();
863 perm_indices.shuffle(&mut rng);
864 let perm_cov = permute_rows(covariates, &perm_indices);
865
866 if let Ok(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
867 let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
868 for j in 0..p {
869 if perm_stats[j] >= observed_stats[j] {
870 n_ge[j] += 1;
871 }
872 }
873 }
874 }
875
876 let p_values: Vec<f64> = n_ge
877 .iter()
878 .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
879 .collect();
880 let f_statistics = observed_stats.to_vec();
881
882 (f_statistics, p_values)
883}
884
885fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
887 let n = indices.len();
888 let m = mat.ncols();
889 let mut result = FdMatrix::zeros(n, m);
890 for (new_i, &old_i) in indices.iter().enumerate() {
891 for j in 0..m {
892 result[(new_i, j)] = mat[(old_i, j)];
893 }
894 }
895 result
896}
897
// Unit tests for the functional mixed model. All data are synthetic and
// deterministic (pseudo-noise comes from modular arithmetic, not an RNG).
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Equally spaced grid of `m` points on [0, 1]; callers pass `m >= 2`.
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// Build a balanced data set: `n_subjects` subjects with `n_visits` curves
    /// each on a grid of `m` points. Returns the curves, the subject ids, a
    /// one-column covariate matrix and the grid.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            // Subject-level covariate and random-effect amplitude.
            let z = s as f64 / n_subjects as f64;
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0);
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    // curve = mean + covariate effect + subject effect + pseudo-noise
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }

    // Output dimensions match the inputs.
    #[test]
    fn test_fmm_basic() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_eq!(result.mean_function.len(), 50);
        assert_eq!(result.beta_functions.nrows(), 1);
        assert_eq!(result.beta_functions.ncols(), 50);
        assert_eq!(result.random_effects.nrows(), 10);
        assert_eq!(result.fitted.nrows(), 30);
        assert_eq!(result.residuals.nrows(), 30);
        assert_eq!(result.n_subjects, 10);
    }

    // fitted + residuals must reconstruct the data.
    #[test]
    fn test_fmm_fitted_plus_residuals_equals_data() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        for i in 0..n {
            for t in 0..m {
                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
                assert!(
                    (reconstructed - data[(i, t)]).abs() < 1e-8,
                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
                    i,
                    t,
                    reconstructed,
                    data[(i, t)]
                );
            }
        }
    }

    // Pointwise random-effect variances are never negative.
    #[test]
    fn test_fmm_random_variance_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        for &v in &result.random_variance {
            assert!(v >= 0.0, "Random variance should be non-negative");
        }
    }

    // The model can be fitted without covariates.
    #[test]
    fn test_fmm_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, None, 3).unwrap();

        assert_eq!(result.beta_functions.nrows(), 0);
        assert_eq!(result.n_subjects, 8);
        assert_eq!(result.fitted.nrows(), 24);
    }

    // Prediction for one new covariate row is finite and of plausible size.
    #[test]
    fn test_fmm_predict() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 50);

        for t in 0..50 {
            assert!(predicted[(0, t)].is_finite());
            assert!(
                predicted[(0, t)].abs() < 20.0,
                "Predicted value too extreme at t={}: {}",
                t,
                predicted[(0, t)]
            );
        }
    }

    // The generated data contain a covariate effect, so the test should reject.
    #[test]
    fn test_fmm_test_fixed_detects_effect() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();

        assert_eq!(result.f_statistics.len(), 1);
        assert_eq!(result.p_values.len(), 1);
        assert!(
            result.p_values[0] < 0.1,
            "Should detect covariate effect, got p={}",
            result.p_values[0]
        );
    }

    // Curves here do not depend on the covariate, so the test should not reject.
    #[test]
    fn test_fmm_test_fixed_no_effect() {
        let n_subjects = 10;
        let n_visits = 3;
        let m = 40;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;

        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = s as f64 / n_subjects as f64;
                for (j, &tj) in t.iter().enumerate() {
                    // Mean curve plus pseudo-noise only.
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
        assert!(
            result.p_values[0] > 0.05,
            "Should not detect effect, got p={}",
            result.p_values[0]
        );
    }

    // Empty data and a subject-id length mismatch are both rejected.
    #[test]
    fn test_fmm_invalid_input() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fmm(&data, &[], None, 1).is_err());

        let data = FdMatrix::zeros(10, 50);
        let ids = vec![0; 5];
        assert!(fmm(&data, &ids, None, 1).is_err());
    }

    // One observation per subject (no within-subject replication) still fits.
    #[test]
    fn test_fmm_single_visit_per_subject() {
        let n = 10;
        let m = 40;
        let t = uniform_grid(m);
        let mut col_major = vec![0.0; n * m];
        let subject_ids: Vec<usize> = (0..n).collect();

        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        let result = fmm(&data, &subject_ids, None, 2).unwrap();
        assert_eq!(result.n_subjects, n);
        assert_eq!(result.fitted.nrows(), n);
    }

    // Arbitrary ids are mapped to dense indices in ascending id order.
    #[test]
    fn test_build_subject_map() {
        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 0, 1, 1, 2]);
    }

    // Estimated variance components are never negative.
    #[test]
    fn test_variance_components_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert!(result.sigma2_eps >= 0.0);
        for &s in &result.sigma2_u {
            assert!(s >= 0.0);
        }
    }
}