1use crate::matrix::FdMatrix;
14use crate::regression::fdata_to_pc_1d;
15
16pub struct FmmResult {
18 pub mean_function: Vec<f64>,
20 pub beta_functions: FdMatrix,
22 pub random_effects: FdMatrix,
24 pub fitted: FdMatrix,
26 pub residuals: FdMatrix,
28 pub random_variance: Vec<f64>,
30 pub sigma2_eps: f64,
32 pub sigma2_u: Vec<f64>,
34 pub ncomp: usize,
36 pub n_subjects: usize,
38 pub eigenvalues: Vec<f64>,
40}
41
/// Outcome of the permutation test run by [`fmm_test_fixed`].
///
/// Both vectors have one entry per covariate column.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// Observed statistic per covariate: the integrated squared coefficient curve
    /// `∫ β_j(t)² dt`. The field keeps its historical name, but the value is not an F ratio.
    pub f_statistics: Vec<f64>,
    /// Permutation p-value per covariate, computed as `(#{perm ≥ observed} + 1) / (#perms + 1)`.
    pub p_values: Vec<f64>,
}
49
50pub fn fmm(
69 data: &FdMatrix,
70 subject_ids: &[usize],
71 covariates: Option<&FdMatrix>,
72 ncomp: usize,
73) -> Option<FmmResult> {
74 let n_total = data.nrows();
75 let m = data.ncols();
76 if n_total == 0 || m == 0 || subject_ids.len() != n_total || ncomp == 0 {
77 return None;
78 }
79
80 let (subject_map, n_subjects) = build_subject_map(subject_ids);
82
83 let fpca = fdata_to_pc_1d(data, ncomp)?;
85 let k = fpca.scores.ncols(); let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
91 let score_scale = h.sqrt();
92
93 let p = covariates.map_or(0, |c| c.ncols());
94 let mut gamma = vec![vec![0.0; k]; p]; let mut u_hat = vec![vec![0.0; k]; n_subjects]; let mut sigma2_u = vec![0.0; k];
97 let mut sigma2_eps_total = 0.0;
98
99 for comp in 0..k {
100 let scores: Vec<f64> = (0..n_total)
102 .map(|i| fpca.scores[(i, comp)] * score_scale)
103 .collect();
104 let result = fit_scalar_mixed_model(&scores, &subject_map, n_subjects, covariates, p);
105 for j in 0..p {
107 gamma[j][comp] = result.gamma[j] / score_scale;
108 }
109 for s in 0..n_subjects {
110 u_hat[s][comp] = result.u_hat[s] / score_scale;
112 }
113 sigma2_u[comp] = result.sigma2_u;
115 sigma2_eps_total += result.sigma2_eps;
116 }
117 let sigma2_eps = sigma2_eps_total / k as f64;
118
119 let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
121 let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
122
123 let random_variance = compute_random_variance(&random_effects, n_subjects, m);
125
126 let (fitted, residuals) = compute_fitted_residuals(
128 data,
129 &fpca.mean,
130 &beta_functions,
131 &random_effects,
132 covariates,
133 &subject_map,
134 n_total,
135 m,
136 p,
137 );
138
139 let eigenvalues: Vec<f64> = fpca
140 .singular_values
141 .iter()
142 .map(|&sv| sv * sv / n_total as f64)
143 .collect();
144
145 Some(FmmResult {
146 mean_function: fpca.mean,
147 beta_functions,
148 random_effects,
149 fitted,
150 residuals,
151 random_variance,
152 sigma2_eps,
153 sigma2_u,
154 ncomp: k,
155 n_subjects,
156 eigenvalues,
157 })
158}
159
/// Map arbitrary subject identifiers to dense indices `0..n_subjects`.
///
/// Dense indices follow the ascending order of the identifier values, so the
/// smallest id maps to 0. Returns the per-observation index vector together with
/// the number of distinct subjects.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();

    // `unique_ids` is sorted and deduplicated, so a binary search gives the dense index
    // in O(log u). The previous linear `position` scan cost O(u) per observation.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| {
            unique_ids
                .binary_search(id)
                .expect("every id was inserted into unique_ids above")
        })
        .collect();

    (map, unique_ids.len())
}
174
/// Output of one scalar random-intercept fit (one FPC component).
struct ScalarMixedResult {
    /// Fixed-effect coefficients; empty when the model has no covariates.
    gamma: Vec<f64>,
    /// Predicted random intercept of each subject.
    u_hat: Vec<f64>,
    /// Random-intercept variance.
    sigma2_u: f64,
    /// Residual variance.
    sigma2_eps: f64,
}
182
/// Observation indices grouped by (dense) subject index.
struct SubjectStructure {
    /// Number of observations belonging to each subject.
    counts: Vec<usize>,
    /// Observation indices of each subject, in ascending order.
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` entries of `subject_map` by their subject index.
    ///
    /// Panics if `n > subject_map.len()` or if a mapped index is `>= n_subjects`.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for (row, &subject) in subject_map[..n].iter().enumerate() {
            obs[subject].push(row);
        }
        let counts: Vec<usize> = obs.iter().map(Vec::len).collect();
        Self { counts, obs }
    }
}
201
202fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
204 ss.counts
205 .iter()
206 .map(|&c| {
207 let ns = c as f64;
208 if ns < 1.0 {
209 0.0
210 } else {
211 sigma2_u / (sigma2_u + sigma2_e / ns)
212 }
213 })
214 .collect()
215}
216
217fn gls_update_gamma(
221 cov: &FdMatrix,
222 p: usize,
223 ss: &SubjectStructure,
224 weights: &[f64],
225 y: &[f64],
226 sigma2_e: f64,
227) -> Option<Vec<f64>> {
228 let n_subjects = ss.counts.len();
229 let mut xtvinvx = vec![0.0; p * p];
230 let mut xtvinvy = vec![0.0; p];
231 let inv_e = 1.0 / sigma2_e;
232
233 for s in 0..n_subjects {
234 let ns = ss.counts[s] as f64;
235 if ns < 1.0 {
236 continue;
237 }
238 let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
239 accumulate_gls_terms(
240 cov,
241 y,
242 &ss.obs[s],
243 &x_sum,
244 y_sum,
245 weights[s],
246 ns,
247 inv_e,
248 p,
249 &mut xtvinvx,
250 &mut xtvinvy,
251 );
252 }
253
254 for j in 0..p {
255 xtvinvx[j * p + j] += 1e-10;
256 }
257 cholesky_solve(&xtvinvx, &xtvinvy, p)
258}
259
260fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
262 let mut x_sum = vec![0.0; p];
263 let mut y_sum = 0.0;
264 for &i in obs {
265 for r in 0..p {
266 x_sum[r] += cov[(i, r)];
267 }
268 y_sum += y[i];
269 }
270 (x_sum, y_sum)
271}
272
273fn accumulate_gls_terms(
275 cov: &FdMatrix,
276 y: &[f64],
277 obs: &[usize],
278 x_sum: &[f64],
279 y_sum: f64,
280 w_s: f64,
281 ns: f64,
282 inv_e: f64,
283 p: usize,
284 xtvinvx: &mut [f64],
285 xtvinvy: &mut [f64],
286) {
287 for &i in obs {
288 let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
289 for r in 0..p {
290 xtvinvy[r] += cov[(i, r)] * vinv_y;
291 for c in r..p {
292 let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
293 let val = cov[(i, r)] * vinv_xc;
294 xtvinvx[r * p + c] += val;
295 if r != c {
296 xtvinvx[c * p + r] += val;
297 }
298 }
299 }
300 }
301}
302
303fn reml_variance_update(
308 residuals: &[f64],
309 ss: &SubjectStructure,
310 weights: &[f64],
311 sigma2_u: f64,
312 p: usize,
313) -> (f64, f64) {
314 let n_subjects = ss.counts.len();
315 let n: usize = ss.counts.iter().sum();
316 let mut sigma2_u_new = 0.0;
317 let mut sigma2_e_new = 0.0;
318
319 for s in 0..n_subjects {
320 let ns = ss.counts[s] as f64;
321 if ns < 1.0 {
322 continue;
323 }
324 let w_s = weights[s];
325 let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
326 let u_hat_s = w_s * mean_r_s;
327 let cond_var_s = sigma2_u * (1.0 - w_s);
328
329 sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
330 for &i in &ss.obs[s] {
331 sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
332 }
333 sigma2_e_new += ns * cond_var_s;
334 }
335
336 let denom_e = (n.saturating_sub(p)).max(1) as f64;
338
339 (
340 (sigma2_u_new / n_subjects as f64).max(1e-15),
341 (sigma2_e_new / denom_e).max(1e-15),
342 )
343}
344
345fn fit_scalar_mixed_model(
351 y: &[f64],
352 subject_map: &[usize],
353 n_subjects: usize,
354 covariates: Option<&FdMatrix>,
355 p: usize,
356) -> ScalarMixedResult {
357 let n = y.len();
358 let ss = SubjectStructure::new(subject_map, n_subjects, n);
359
360 let gamma_init = estimate_fixed_effects(y, covariates, p, n);
362 let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
363 let (mut sigma2_u, mut sigma2_e) =
364 estimate_variance_components(&residuals_init, subject_map, n_subjects, n);
365
366 if sigma2_e < 1e-15 {
367 sigma2_e = 1e-6;
368 }
369 if sigma2_u < 1e-15 {
370 sigma2_u = sigma2_e * 0.1;
371 }
372
373 let mut gamma = gamma_init;
374
375 for _iter in 0..50 {
376 let sigma2_u_old = sigma2_u;
377 let sigma2_e_old = sigma2_e;
378
379 let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);
380
381 if let Some(cov) = covariates.filter(|_| p > 0) {
382 if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
383 gamma = g;
384 }
385 }
386
387 let r = compute_ols_residuals(y, covariates, &gamma, p, n);
388 (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);
389
390 let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
391 if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
392 break;
393 }
394 }
395
396 let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
397 let u_hat = compute_blup(
398 &final_residuals,
399 subject_map,
400 n_subjects,
401 sigma2_u,
402 sigma2_e,
403 );
404
405 ScalarMixedResult {
406 gamma,
407 u_hat,
408 sigma2_u,
409 sigma2_eps: sigma2_e,
410 }
411}
412
413fn estimate_fixed_effects(
415 y: &[f64],
416 covariates: Option<&FdMatrix>,
417 p: usize,
418 n: usize,
419) -> Vec<f64> {
420 if p == 0 || covariates.is_none() {
421 return Vec::new();
422 }
423 let cov = covariates.unwrap();
424
425 let mut xtx = vec![0.0; p * p];
427 let mut xty = vec![0.0; p];
428 for i in 0..n {
429 for r in 0..p {
430 xty[r] += cov[(i, r)] * y[i];
431 for s in r..p {
432 let val = cov[(i, r)] * cov[(i, s)];
433 xtx[r * p + s] += val;
434 if r != s {
435 xtx[s * p + r] += val;
436 }
437 }
438 }
439 }
440 for j in 0..p {
442 xtx[j * p + j] += 1e-8;
443 }
444
445 cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
446}
447
/// Cholesky factorization `A = L Lᵀ` of a row-major `p × p` matrix.
///
/// Returns the lower-triangular factor in row-major order. Only the lower triangle
/// of `a` is read. Returns `None` when a pivot is not a finite positive number,
/// i.e. when the matrix is not numerically positive definite or contains NaN/inf.
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        let mut sum = 0.0;
        for k in 0..j {
            sum += l[j * p + k] * l[j * p + k];
        }
        let diag = a[j * p + j] - sum;
        // `diag <= 0.0` on its own lets NaN through, because every comparison with NaN
        // is false. That would silently propagate NaN into the factor and the solution.
        if !diag.is_finite() || diag <= 0.0 {
            return None;
        }
        l[j * p + j] = diag.sqrt();
        for i in (j + 1)..p {
            let mut s = 0.0;
            for k in 0..j {
                s += l[i * p + k] * l[j * p + k];
            }
            l[i * p + j] = (a[i * p + j] - s) / l[j * p + j];
        }
    }
    Some(l)
}
471
/// Solve `L Lᵀ z = b`, given the row-major lower-triangular factor `l` of size `p × p`.
///
/// First solves `L w = b` by forward substitution, then `Lᵀ z = w` by back substitution.
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut z = vec![0.0; p];

    for row in 0..p {
        let known = &l[row * p..row * p + row];
        let dot = known
            .iter()
            .zip(&z)
            .fold(0.0, |acc, (&coef, &val)| acc + coef * val);
        z[row] = (b[row] - dot) / l[row * p + row];
    }

    for row in (0..p).rev() {
        // Column `row` of L (below the diagonal) is row `row` of Lᵀ.
        let dot = ((row + 1)..p).fold(0.0, |acc, j| acc + l[j * p + row] * z[j]);
        z[row] = (z[row] - dot) / l[row * p + row];
    }

    z
}
491
492fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
494 let l = cholesky_factor_famm(a, p)?;
495 Some(cholesky_triangular_solve(&l, b, p))
496}
497
498fn compute_ols_residuals(
500 y: &[f64],
501 covariates: Option<&FdMatrix>,
502 gamma: &[f64],
503 p: usize,
504 n: usize,
505) -> Vec<f64> {
506 let mut residuals = y.to_vec();
507 if p > 0 {
508 if let Some(cov) = covariates {
509 for i in 0..n {
510 for j in 0..p {
511 residuals[i] -= cov[(i, j)] * gamma[j];
512 }
513 }
514 }
515 }
516 residuals
517}
518
/// One-way ANOVA (method-of-moments) starting values `(σ_u², σ_e²)`.
///
/// Both estimates come from the within- and between-subject sums of squares:
/// - `σ_e²` is the within mean square, or `1e-6` when there are no within-subject
///   degrees of freedom.
/// - `σ_u² = max((MS_between − σ_e²) / n̄, 0)`, where `n̄` is the average number of
///   observations per subject.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    let mut totals = vec![0.0; n_subjects];
    let mut sizes = vec![0usize; n_subjects];
    for (&r, &s) in residuals[..n].iter().zip(&subject_map[..n]) {
        totals[s] += r;
        sizes[s] += 1;
    }
    let means: Vec<f64> = totals
        .iter()
        .zip(&sizes)
        .map(|(&total, &size)| if size == 0 { 0.0 } else { total / size as f64 })
        .collect();

    let ss_within = residuals[..n]
        .iter()
        .zip(&subject_map[..n])
        .fold(0.0, |acc, (&r, &s)| acc + (r - means[s]).powi(2));
    let df_within = n.saturating_sub(n_subjects);

    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let ss_between = sizes
        .iter()
        .zip(&means)
        .fold(0.0, |acc, (&size, &mu)| acc + size as f64 * (mu - grand_mean).powi(2));

    let sigma2_eps = if df_within == 0 {
        1e-6
    } else {
        ss_within / df_within as f64
    };

    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let ms_between = ss_between / n_subjects.saturating_sub(1).max(1) as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
571
/// Best linear unbiased predictors of the random intercepts.
///
/// For each subject the result is `mean(residuals) · σ_u² / (σ_u² + σ_e² / n_s)`. The
/// denominator is floored at `1e-15`, and subjects without observations get 0.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    let mut totals = vec![0.0; n_subjects];
    let mut sizes = vec![0usize; n_subjects];
    for (&r, &subject) in residuals.iter().zip(&subject_map[..residuals.len()]) {
        totals[subject] += r;
        sizes[subject] += 1;
    }

    totals
        .iter()
        .zip(&sizes)
        .map(|(&total, &size)| {
            if size == 0 {
                return 0.0;
            }
            let ni = size as f64;
            let factor = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
            factor * (total / ni)
        })
        .collect()
}
602
603fn recover_beta_functions(
609 gamma: &[Vec<f64>],
610 rotation: &FdMatrix,
611 p: usize,
612 m: usize,
613 k: usize,
614) -> FdMatrix {
615 let mut beta = FdMatrix::zeros(p, m);
616 for j in 0..p {
617 for t in 0..m {
618 let mut val = 0.0;
619 for comp in 0..k {
620 val += gamma[j][comp] * rotation[(t, comp)];
621 }
622 beta[(j, t)] = val;
623 }
624 }
625 beta
626}
627
628fn recover_random_effects(
630 u_hat: &[Vec<f64>],
631 rotation: &FdMatrix,
632 n_subjects: usize,
633 m: usize,
634 k: usize,
635) -> FdMatrix {
636 let mut re = FdMatrix::zeros(n_subjects, m);
637 for s in 0..n_subjects {
638 for t in 0..m {
639 let mut val = 0.0;
640 for comp in 0..k {
641 val += u_hat[s][comp] * rotation[(t, comp)];
642 }
643 re[(s, t)] = val;
644 }
645 }
646 re
647}
648
649fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
651 (0..m)
652 .map(|t| {
653 let mean: f64 =
654 (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
655 let var: f64 = (0..n_subjects)
656 .map(|s| (random_effects[(s, t)] - mean).powi(2))
657 .sum::<f64>()
658 / n_subjects.max(1) as f64;
659 var
660 })
661 .collect()
662}
663
664fn compute_fitted_residuals(
666 data: &FdMatrix,
667 mean_function: &[f64],
668 beta_functions: &FdMatrix,
669 random_effects: &FdMatrix,
670 covariates: Option<&FdMatrix>,
671 subject_map: &[usize],
672 n_total: usize,
673 m: usize,
674 p: usize,
675) -> (FdMatrix, FdMatrix) {
676 let mut fitted = FdMatrix::zeros(n_total, m);
677 let mut residuals = FdMatrix::zeros(n_total, m);
678
679 for i in 0..n_total {
680 let s = subject_map[i];
681 for t in 0..m {
682 let mut val = mean_function[t] + random_effects[(s, t)];
683 if p > 0 {
684 if let Some(cov) = covariates {
685 for j in 0..p {
686 val += cov[(i, j)] * beta_functions[(j, t)];
687 }
688 }
689 }
690 fitted[(i, t)] = val;
691 residuals[(i, t)] = data[(i, t)] - val;
692 }
693 }
694
695 (fitted, residuals)
696}
697
698pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
710 let m = result.mean_function.len();
711 let n_new = new_covariates.map_or(1, |c| c.nrows());
712 let p = result.beta_functions.nrows();
713
714 let mut predicted = FdMatrix::zeros(n_new, m);
715 for i in 0..n_new {
716 for t in 0..m {
717 let mut val = result.mean_function[t];
718 if let Some(cov) = new_covariates {
719 for j in 0..p {
720 val += cov[(i, j)] * result.beta_functions[(j, t)];
721 }
722 }
723 predicted[(i, t)] = val;
724 }
725 }
726 predicted
727}
728
729pub fn fmm_test_fixed(
746 data: &FdMatrix,
747 subject_ids: &[usize],
748 covariates: &FdMatrix,
749 ncomp: usize,
750 n_perm: usize,
751 seed: u64,
752) -> Option<FmmTestResult> {
753 let n_total = data.nrows();
754 let m = data.ncols();
755 let p = covariates.ncols();
756 if n_total == 0 || p == 0 {
757 return None;
758 }
759
760 let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
762
763 let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
765
766 let (f_statistics, p_values) = permutation_test(
768 data,
769 subject_ids,
770 covariates,
771 ncomp,
772 n_perm,
773 seed,
774 &observed_stats,
775 p,
776 m,
777 );
778
779 Some(FmmTestResult {
780 f_statistics,
781 p_values,
782 })
783}
784
785fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
787 let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
788 (0..p)
789 .map(|j| {
790 let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
791 ss * h
792 })
793 .collect()
794}
795
796fn permutation_test(
798 data: &FdMatrix,
799 subject_ids: &[usize],
800 covariates: &FdMatrix,
801 ncomp: usize,
802 n_perm: usize,
803 seed: u64,
804 observed_stats: &[f64],
805 p: usize,
806 m: usize,
807) -> (Vec<f64>, Vec<f64>) {
808 use rand::prelude::*;
809 let n_total = data.nrows();
810 let mut rng = StdRng::seed_from_u64(seed);
811 let mut n_ge = vec![0usize; p];
812
813 for _ in 0..n_perm {
814 let mut perm_indices: Vec<usize> = (0..n_total).collect();
816 perm_indices.shuffle(&mut rng);
817 let perm_cov = permute_rows(covariates, &perm_indices);
818
819 if let Some(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
820 let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
821 for j in 0..p {
822 if perm_stats[j] >= observed_stats[j] {
823 n_ge[j] += 1;
824 }
825 }
826 }
827 }
828
829 let p_values: Vec<f64> = n_ge
830 .iter()
831 .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
832 .collect();
833 let f_statistics = observed_stats.to_vec();
834
835 (f_statistics, p_values)
836}
837
838fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
840 let n = indices.len();
841 let m = mat.ncols();
842 let mut result = FdMatrix::zeros(n, m);
843 for (new_i, &old_i) in indices.iter().enumerate() {
844 for j in 0..m {
845 result[(new_i, j)] = mat[(old_i, j)];
846 }
847 }
848 result
849}
850
// Unit and integration tests for the functional mixed model. Synthetic data is
// generated deterministically (no RNG), so the statistical assertions are reproducible.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points from 0 to 1 inclusive.
    /// Assumes `m >= 2`: with `m == 1` the single point is 0/0 = NaN.
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// Simulate `n_subjects × n_visits` curves on an `m`-point grid on [0, 1].
    ///
    /// Each curve is the sum of four terms:
    /// - `sin(2πt)`;
    /// - a covariate effect `3·z·t`, where `z = s / n_subjects` is constant within subject;
    /// - a subject-specific `cos(2πt)` deviation;
    /// - a small deterministic pseudo-noise term.
    ///
    /// Returns `(data, subject_ids, covariates, grid)`.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            // Subject-level covariate in [0, 1).
            let z = s as f64 / n_subjects as f64;
            // Random-effect amplitude, centred around zero across subjects.
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0);
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    // Deterministic pseudo-noise in [0, noise_scale), so tests are reproducible.
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    // Column-major layout: element (obs, j) is stored at obs + j * n_total.
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }

    // Output dimensions match the input (10 subjects × 3 visits, 50 grid points, 1 covariate).
    #[test]
    fn test_fmm_basic() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_eq!(result.mean_function.len(), 50);
        assert_eq!(result.beta_functions.nrows(), 1);
        assert_eq!(result.beta_functions.ncols(), 50);
        assert_eq!(result.random_effects.nrows(), 10);
        assert_eq!(result.fitted.nrows(), 30);
        assert_eq!(result.residuals.nrows(), 30);
        assert_eq!(result.n_subjects, 10);
    }

    // Residuals are defined as data − fitted, so their sum must reproduce the data.
    #[test]
    fn test_fmm_fitted_plus_residuals_equals_data() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        for i in 0..n {
            for t in 0..m {
                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
                assert!(
                    (reconstructed - data[(i, t)]).abs() < 1e-8,
                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
                    i,
                    t,
                    reconstructed,
                    data[(i, t)]
                );
            }
        }
    }

    #[test]
    fn test_fmm_random_variance_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        for &v in &result.random_variance {
            assert!(v >= 0.0, "Random variance should be non-negative");
        }
    }

    // Without covariates the model has zero coefficient functions but still fits every row.
    #[test]
    fn test_fmm_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, None, 3).unwrap();

        assert_eq!(result.beta_functions.nrows(), 0);
        assert_eq!(result.n_subjects, 8);
        assert_eq!(result.fitted.nrows(), 24);
    }

    // Predictions for a new covariate value are finite and within a loose sanity bound.
    #[test]
    fn test_fmm_predict() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 50);

        for t in 0..50 {
            assert!(predicted[(0, t)].is_finite());
            assert!(
                predicted[(0, t)].abs() < 20.0,
                "Predicted value too extreme at t={}: {}",
                t,
                predicted[(0, t)]
            );
        }
    }

    // The simulated data contains a real covariate effect (3·z·t), so the
    // permutation test (99 permutations, fixed seed) should find it.
    #[test]
    fn test_fmm_test_fixed_detects_effect() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();

        assert_eq!(result.f_statistics.len(), 1);
        assert_eq!(result.p_values.len(), 1);
        assert!(
            result.p_values[0] < 0.1,
            "Should detect covariate effect, got p={}",
            result.p_values[0]
        );
    }

    // Curves depend only on sin(2πt) plus pseudo-noise; the covariate has no effect,
    // so the permutation test should not reject at the 5% level.
    #[test]
    fn test_fmm_test_fixed_no_effect() {
        let n_subjects = 10;
        let n_visits = 3;
        let m = 40;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;

        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = s as f64 / n_subjects as f64;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
        assert!(
            result.p_values[0] > 0.05,
            "Should not detect effect, got p={}",
            result.p_values[0]
        );
    }

    // Empty data and a subject_ids / row-count mismatch must both be rejected.
    #[test]
    fn test_fmm_invalid_input() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fmm(&data, &[], None, 1).is_none());

        let data = FdMatrix::zeros(10, 50);
        let ids = vec![0; 5];
        assert!(fmm(&data, &ids, None, 1).is_none());
    }

    // One observation per subject (no within-subject replication) must still produce a fit.
    #[test]
    fn test_fmm_single_visit_per_subject() {
        let n = 10;
        let m = 40;
        let t = uniform_grid(m);
        let mut col_major = vec![0.0; n * m];
        let subject_ids: Vec<usize> = (0..n).collect();

        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        let result = fmm(&data, &subject_ids, None, 2).unwrap();
        assert_eq!(result.n_subjects, n);
        assert_eq!(result.fitted.nrows(), n);
    }

    // Sparse, non-contiguous ids map to dense indices in ascending id order.
    #[test]
    fn test_build_subject_map() {
        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 0, 1, 1, 2]);
    }

    #[test]
    fn test_variance_components_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert!(result.sigma2_eps >= 0.0);
        for &s in &result.sigma2_u {
            assert!(s >= 0.0);
        }
    }
}
1082}