1use crate::matrix::FdMatrix;
14use crate::regression::fdata_to_pc_1d;
15
16pub struct FmmResult {
18 pub mean_function: Vec<f64>,
20 pub beta_functions: FdMatrix,
22 pub random_effects: FdMatrix,
24 pub fitted: FdMatrix,
26 pub residuals: FdMatrix,
28 pub random_variance: Vec<f64>,
30 pub sigma2_eps: f64,
32 pub sigma2_u: Vec<f64>,
34 pub ncomp: usize,
36 pub n_subjects: usize,
38 pub eigenvalues: Vec<f64>,
40}
41
/// Result of the permutation test for fixed-effect coefficient functions
/// produced by `fmm_test_fixed`.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// Observed statistic per covariate: the grid approximation of the
    /// integrated squared coefficient function `∫ β_j(t)² dt`.
    pub f_statistics: Vec<f64>,
    /// Permutation p-value per covariate.
    pub p_values: Vec<f64>,
}
49
50pub fn fmm(
69 data: &FdMatrix,
70 subject_ids: &[usize],
71 covariates: Option<&FdMatrix>,
72 ncomp: usize,
73) -> Option<FmmResult> {
74 let n_total = data.nrows();
75 let m = data.ncols();
76 if n_total == 0 || m == 0 || subject_ids.len() != n_total || ncomp == 0 {
77 return None;
78 }
79
80 let (subject_map, n_subjects) = build_subject_map(subject_ids);
82
83 let fpca = fdata_to_pc_1d(data, ncomp)?;
85 let k = fpca.scores.ncols(); let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
91 let score_scale = h.sqrt();
92
93 let p = covariates.map_or(0, |c| c.ncols());
94 let mut gamma = vec![vec![0.0; k]; p]; let mut u_hat = vec![vec![0.0; k]; n_subjects]; let mut sigma2_u = vec![0.0; k];
97 let mut sigma2_eps_total = 0.0;
98
99 for comp in 0..k {
100 let scores: Vec<f64> = (0..n_total)
102 .map(|i| fpca.scores[(i, comp)] * score_scale)
103 .collect();
104 let result = fit_scalar_mixed_model(&scores, &subject_map, n_subjects, covariates, p);
105 for j in 0..p {
107 gamma[j][comp] = result.gamma[j] / score_scale;
108 }
109 for s in 0..n_subjects {
110 u_hat[s][comp] = result.u_hat[s] / score_scale;
112 }
113 sigma2_u[comp] = result.sigma2_u;
115 sigma2_eps_total += result.sigma2_eps;
116 }
117 let sigma2_eps = sigma2_eps_total / k as f64;
118
119 let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
121 let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
122
123 let random_variance = compute_random_variance(&random_effects, n_subjects, m);
125
126 let (fitted, residuals) = compute_fitted_residuals(
128 data,
129 &fpca.mean,
130 &beta_functions,
131 &random_effects,
132 covariates,
133 &subject_map,
134 n_total,
135 m,
136 p,
137 );
138
139 let eigenvalues: Vec<f64> = fpca
140 .singular_values
141 .iter()
142 .map(|&sv| sv * sv / n_total as f64)
143 .collect();
144
145 Some(FmmResult {
146 mean_function: fpca.mean,
147 beta_functions,
148 random_effects,
149 fitted,
150 residuals,
151 random_variance,
152 sigma2_eps,
153 sigma2_u,
154 ncomp: k,
155 n_subjects,
156 eigenvalues,
157 })
158}
159
/// Relabels arbitrary subject ids to dense indices `0..n_subjects`.
///
/// Indices are assigned in ascending order of the original id. Returns the
/// per-observation index vector and the number of distinct subjects.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();

    // `unique_ids` is sorted, so a binary search replaces the former linear
    // scan (O(n log k) instead of O(n·k)).
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| {
            unique_ids
                .binary_search(id)
                .expect("every id occurs in its own deduplicated id set")
        })
        .collect();

    (map, unique_ids.len())
}
174
/// Estimates from one scalar random-intercept model fit.
struct ScalarMixedResult {
    /// Fixed-effect coefficients (length `p`; empty when there are no covariates).
    gamma: Vec<f64>,
    /// Shrunken per-subject random intercepts (BLUPs).
    u_hat: Vec<f64>,
    /// Random-intercept variance.
    sigma2_u: f64,
    /// Residual variance.
    sigma2_eps: f64,
}
182
/// Observation indices grouped by (dense) subject index.
struct SubjectStructure {
    /// Number of observations per subject.
    counts: Vec<usize>,
    /// Observation indices of each subject, in ascending order.
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Groups the first `n` entries of `subject_map` by subject.
    ///
    /// Panics if `n > subject_map.len()` or an entry is `>= n_subjects`.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for (row, &subject) in subject_map[..n].iter().enumerate() {
            obs[subject].push(row);
        }
        let counts: Vec<usize> = obs.iter().map(Vec::len).collect();
        Self { counts, obs }
    }
}
201
202fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
204 ss.counts
205 .iter()
206 .map(|&c| {
207 let ns = c as f64;
208 if ns < 1.0 {
209 0.0
210 } else {
211 sigma2_u / (sigma2_u + sigma2_e / ns)
212 }
213 })
214 .collect()
215}
216
/// Generalised-least-squares update of the fixed effects of a random-intercept model.
///
/// Solves `(Xᵀ V⁻¹ X + 1e-10·I) γ = Xᵀ V⁻¹ y` via `cholesky_solve`. `V⁻¹` is
/// applied per subject without forming any matrix: the code uses
/// `V_s⁻¹ v = (v − w_s · mean_s(v)) / σ²_e`. This is the exact inverse of
/// `σ²_e I + σ²_u 1 1ᵀ` when `w_s = σ²_u / (σ²_u + σ²_e / n_s)`, which is the
/// form computed by `shrinkage_weights`.
///
/// * `cov` – design matrix, read as `cov[(obs, column)]` for the first `p` columns.
/// * `p` – number of covariates.
/// * `ss` – observation indices grouped by subject.
/// * `weights` – per-subject shrinkage weights `w_s`.
/// * `y` – response, one entry per observation.
/// * `sigma2_e` – current residual variance.
///
/// Returns `None` when the Cholesky factorisation of the normal matrix fails.
fn gls_update_gamma(
    cov: &FdMatrix,
    p: usize,
    ss: &SubjectStructure,
    weights: &[f64],
    y: &[f64],
    sigma2_e: f64,
) -> Option<Vec<f64>> {
    let n_subjects = ss.counts.len();
    // Normal equations: Xᵀ V⁻¹ X (row-major p × p) and Xᵀ V⁻¹ y.
    let mut xtvinvx = vec![0.0; p * p];
    let mut xtvinvy = vec![0.0; p];
    let inv_e = 1.0 / sigma2_e;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        // Subjects without observations contribute nothing.
        if ns < 1.0 {
            continue;
        }
        let w_s = weights[s];

        // Per-subject column sums of X and sum of y, needed for the 1 1ᵀ term of V_s⁻¹.
        let mut x_sum = vec![0.0; p];
        let mut y_sum = 0.0;
        for &i in &ss.obs[s] {
            for r in 0..p {
                x_sum[r] += cov[(i, r)];
            }
            y_sum += y[i];
        }

        for &i in &ss.obs[s] {
            // Row i of V_s⁻¹ X and the i-th entry of V_s⁻¹ y.
            let mut vinv_x = vec![0.0; p];
            for r in 0..p {
                vinv_x[r] = inv_e * (cov[(i, r)] - w_s * x_sum[r] / ns);
            }
            let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);

            // Accumulate the upper triangle and mirror it (the matrix is symmetric).
            for r in 0..p {
                xtvinvy[r] += cov[(i, r)] * vinv_y;
                for c in r..p {
                    let val = cov[(i, r)] * vinv_x[c];
                    xtvinvx[r * p + c] += val;
                    if r != c {
                        xtvinvx[c * p + r] += val;
                    }
                }
            }
        }
    }

    // Tiny ridge so near-collinear covariates still give a factorable matrix.
    for j in 0..p {
        xtvinvx[j * p + j] += 1e-10;
    }
    cholesky_solve(&xtvinvx, &xtvinvy, p)
}
276
277fn reml_variance_update(
281 residuals: &[f64],
282 ss: &SubjectStructure,
283 weights: &[f64],
284 sigma2_u: f64,
285) -> (f64, f64) {
286 let n_subjects = ss.counts.len();
287 let n: usize = ss.counts.iter().sum();
288 let mut sigma2_u_new = 0.0;
289 let mut sigma2_e_new = 0.0;
290
291 for s in 0..n_subjects {
292 let ns = ss.counts[s] as f64;
293 if ns < 1.0 {
294 continue;
295 }
296 let w_s = weights[s];
297 let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
298 let u_hat_s = w_s * mean_r_s;
299 let cond_var_s = sigma2_u * (1.0 - w_s);
300
301 sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
302 for &i in &ss.obs[s] {
303 sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
304 }
305 sigma2_e_new += ns * cond_var_s;
306 }
307
308 (
309 (sigma2_u_new / n_subjects as f64).max(1e-15),
310 (sigma2_e_new / n as f64).max(1e-15),
311 )
312}
313
314fn fit_scalar_mixed_model(
320 y: &[f64],
321 subject_map: &[usize],
322 n_subjects: usize,
323 covariates: Option<&FdMatrix>,
324 p: usize,
325) -> ScalarMixedResult {
326 let n = y.len();
327 let ss = SubjectStructure::new(subject_map, n_subjects, n);
328
329 let gamma_init = estimate_fixed_effects(y, covariates, p, n);
331 let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
332 let (mut sigma2_u, mut sigma2_e) =
333 estimate_variance_components(&residuals_init, subject_map, n_subjects, n);
334
335 if sigma2_e < 1e-15 {
336 sigma2_e = 1e-6;
337 }
338 if sigma2_u < 1e-15 {
339 sigma2_u = sigma2_e * 0.1;
340 }
341
342 let mut gamma = gamma_init;
343
344 for _iter in 0..50 {
345 let sigma2_u_old = sigma2_u;
346 let sigma2_e_old = sigma2_e;
347
348 let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);
349
350 if let Some(cov) = covariates.filter(|_| p > 0) {
351 if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
352 gamma = g;
353 }
354 }
355
356 let r = compute_ols_residuals(y, covariates, &gamma, p, n);
357 (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u);
358
359 let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
360 if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
361 break;
362 }
363 }
364
365 let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
366 let u_hat = compute_blup(
367 &final_residuals,
368 subject_map,
369 n_subjects,
370 sigma2_u,
371 sigma2_e,
372 );
373
374 ScalarMixedResult {
375 gamma,
376 u_hat,
377 sigma2_u,
378 sigma2_eps: sigma2_e,
379 }
380}
381
382fn estimate_fixed_effects(
384 y: &[f64],
385 covariates: Option<&FdMatrix>,
386 p: usize,
387 n: usize,
388) -> Vec<f64> {
389 if p == 0 || covariates.is_none() {
390 return Vec::new();
391 }
392 let cov = covariates.unwrap();
393
394 let mut xtx = vec![0.0; p * p];
396 let mut xty = vec![0.0; p];
397 for i in 0..n {
398 for r in 0..p {
399 xty[r] += cov[(i, r)] * y[i];
400 for s in r..p {
401 let val = cov[(i, r)] * cov[(i, s)];
402 xtx[r * p + s] += val;
403 if r != s {
404 xtx[s * p + r] += val;
405 }
406 }
407 }
408 }
409 for j in 0..p {
411 xtx[j * p + j] += 1e-8;
412 }
413
414 cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
415}
416
/// Cholesky factorisation `A = L Lᵀ` of a symmetric positive-definite matrix.
///
/// `a` is a row-major `p × p` matrix, of which only the lower triangle
/// (including the diagonal) is read. Returns the lower-triangular factor `L`,
/// row-major, with a zero upper triangle.
///
/// Returns `None` if any pivot is not a finite positive number. That covers
/// matrices that are not numerically positive definite and inputs containing
/// NaN or infinite entries.
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        let sum_sq: f64 = l[j * p..j * p + j].iter().map(|v| v * v).sum();
        let diag = a[j * p + j] - sum_sq;
        // `diag <= 0.0` alone lets NaN through (every comparison with NaN is
        // false), which would silently poison the whole factor.
        if !diag.is_finite() || diag <= 0.0 {
            return None;
        }
        let pivot = diag.sqrt();
        l[j * p + j] = pivot;
        for i in (j + 1)..p {
            let dot: f64 = (0..j).map(|k| l[i * p + k] * l[j * p + k]).sum();
            l[i * p + j] = (a[i * p + j] - dot) / pivot;
        }
    }
    Some(l)
}
440
/// Solves `L Lᵀ x = b` given the row-major lower-triangular factor `L`.
///
/// Performs a forward substitution for `L z = b`, then a backward
/// substitution for `Lᵀ x = z`, reusing one buffer.
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = vec![0.0; p];

    // Forward pass: L z = b.
    for row in 0..p {
        let dot = (0..row).fold(0.0, |acc, col| acc + l[row * p + col] * x[col]);
        x[row] = (b[row] - dot) / l[row * p + row];
    }

    // Backward pass: Lᵀ x = z, where Lᵀ[row][col] = L[col][row].
    for row in (0..p).rev() {
        let dot = ((row + 1)..p).fold(0.0, |acc, col| acc + l[col * p + row] * x[col]);
        x[row] = (x[row] - dot) / l[row * p + row];
    }

    x
}
460
461fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
463 let l = cholesky_factor_famm(a, p)?;
464 Some(cholesky_triangular_solve(&l, b, p))
465}
466
467fn compute_ols_residuals(
469 y: &[f64],
470 covariates: Option<&FdMatrix>,
471 gamma: &[f64],
472 p: usize,
473 n: usize,
474) -> Vec<f64> {
475 let mut residuals = y.to_vec();
476 if p > 0 {
477 if let Some(cov) = covariates {
478 for i in 0..n {
479 for j in 0..p {
480 residuals[i] -= cov[(i, j)] * gamma[j];
481 }
482 }
483 }
484 }
485 residuals
486}
487
/// One-way ANOVA (method-of-moments) starting values for `(σ²_u, σ²_e)`.
///
/// * `σ²_e` is the within-subject mean square, or `1e-6` when there are no
///   within-subject degrees of freedom.
/// * `σ²_u = max(0, (MS_between − σ²_e) / n₀)`.
///
/// The coefficient `n₀ = (N − Σ n_s² / N) / (k − 1)` is the standard one for
/// unbalanced designs. It reduces to the average group size `N / k` when
/// every subject has the same number of observations, and falls back to
/// `N / k` for a single subject.
///
/// Empty input (`n == 0`) returns `(0.0, 1e-6)`.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    if n == 0 {
        return (0.0, 1e-6);
    }

    let mut subject_sums = vec![0.0; n_subjects];
    let mut subject_counts = vec![0usize; n_subjects];
    for i in 0..n {
        let s = subject_map[i];
        subject_sums[s] += residuals[i];
        subject_counts[s] += 1;
    }
    let subject_means: Vec<f64> = subject_sums
        .iter()
        .zip(&subject_counts)
        .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 })
        .collect();

    let mut ss_within = 0.0;
    for i in 0..n {
        let s = subject_map[i];
        ss_within += (residuals[i] - subject_means[s]).powi(2);
    }
    let df_within = n.saturating_sub(n_subjects);

    let n_f = n as f64;
    let grand_mean = residuals[..n].iter().sum::<f64>() / n_f;
    let mut ss_between = 0.0;
    for s in 0..n_subjects {
        ss_between += subject_counts[s] as f64 * (subject_means[s] - grand_mean).powi(2);
    }

    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        1e-6
    };

    let n_bar = n_f / n_subjects.max(1) as f64;
    let n0 = if n_subjects > 1 {
        let sum_sq: f64 = subject_counts.iter().map(|&c| (c as f64) * (c as f64)).sum();
        (n_f - sum_sq / n_f) / (n_subjects - 1) as f64
    } else {
        n_bar
    };
    // Degenerate layouts (e.g. all observations in one subject) give n₀ <= 0.
    let n0 = if n0 > 0.0 { n0 } else { n_bar };

    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n0).max(0.0);

    (sigma2_u, sigma2_eps)
}
540
/// Best linear unbiased predictors of the subject random intercepts.
///
/// For subject `s` with `n_s` observations the predictor is
/// `σ²_u / max(σ²_u + σ²_ε / n_s, 1e-15) · mean(residuals of s)`.
/// Subjects without observations get `0.0`.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    // (sum, count) of residuals per subject, accumulated in one pass.
    let mut totals = vec![(0.0_f64, 0_usize); n_subjects];
    for (i, &r) in residuals.iter().enumerate() {
        let entry = &mut totals[subject_map[i]];
        entry.0 += r;
        entry.1 += 1;
    }

    totals
        .into_iter()
        .map(|(sum, count)| {
            if count == 0 {
                return 0.0;
            }
            let ni = count as f64;
            let denom = (sigma2_u + sigma2_eps / ni).max(1e-15);
            sigma2_u / denom * (sum / ni)
        })
        .collect()
}
571
572fn recover_beta_functions(
578 gamma: &[Vec<f64>],
579 rotation: &FdMatrix,
580 p: usize,
581 m: usize,
582 k: usize,
583) -> FdMatrix {
584 let mut beta = FdMatrix::zeros(p, m);
585 for j in 0..p {
586 for t in 0..m {
587 let mut val = 0.0;
588 for comp in 0..k {
589 val += gamma[j][comp] * rotation[(t, comp)];
590 }
591 beta[(j, t)] = val;
592 }
593 }
594 beta
595}
596
597fn recover_random_effects(
599 u_hat: &[Vec<f64>],
600 rotation: &FdMatrix,
601 n_subjects: usize,
602 m: usize,
603 k: usize,
604) -> FdMatrix {
605 let mut re = FdMatrix::zeros(n_subjects, m);
606 for s in 0..n_subjects {
607 for t in 0..m {
608 let mut val = 0.0;
609 for comp in 0..k {
610 val += u_hat[s][comp] * rotation[(t, comp)];
611 }
612 re[(s, t)] = val;
613 }
614 }
615 re
616}
617
618fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
620 (0..m)
621 .map(|t| {
622 let mean: f64 =
623 (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
624 let var: f64 = (0..n_subjects)
625 .map(|s| (random_effects[(s, t)] - mean).powi(2))
626 .sum::<f64>()
627 / n_subjects.max(1) as f64;
628 var
629 })
630 .collect()
631}
632
633fn compute_fitted_residuals(
635 data: &FdMatrix,
636 mean_function: &[f64],
637 beta_functions: &FdMatrix,
638 random_effects: &FdMatrix,
639 covariates: Option<&FdMatrix>,
640 subject_map: &[usize],
641 n_total: usize,
642 m: usize,
643 p: usize,
644) -> (FdMatrix, FdMatrix) {
645 let mut fitted = FdMatrix::zeros(n_total, m);
646 let mut residuals = FdMatrix::zeros(n_total, m);
647
648 for i in 0..n_total {
649 let s = subject_map[i];
650 for t in 0..m {
651 let mut val = mean_function[t] + random_effects[(s, t)];
652 if p > 0 {
653 if let Some(cov) = covariates {
654 for j in 0..p {
655 val += cov[(i, j)] * beta_functions[(j, t)];
656 }
657 }
658 }
659 fitted[(i, t)] = val;
660 residuals[(i, t)] = data[(i, t)] - val;
661 }
662 }
663
664 (fitted, residuals)
665}
666
667pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
679 let m = result.mean_function.len();
680 let n_new = new_covariates.map_or(1, |c| c.nrows());
681 let p = result.beta_functions.nrows();
682
683 let mut predicted = FdMatrix::zeros(n_new, m);
684 for i in 0..n_new {
685 for t in 0..m {
686 let mut val = result.mean_function[t];
687 if let Some(cov) = new_covariates {
688 for j in 0..p {
689 val += cov[(i, j)] * result.beta_functions[(j, t)];
690 }
691 }
692 predicted[(i, t)] = val;
693 }
694 }
695 predicted
696}
697
698pub fn fmm_test_fixed(
715 data: &FdMatrix,
716 subject_ids: &[usize],
717 covariates: &FdMatrix,
718 ncomp: usize,
719 n_perm: usize,
720 seed: u64,
721) -> Option<FmmTestResult> {
722 let n_total = data.nrows();
723 let m = data.ncols();
724 let p = covariates.ncols();
725 if n_total == 0 || p == 0 {
726 return None;
727 }
728
729 let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
731
732 let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
734
735 let (f_statistics, p_values) = permutation_test(
737 data,
738 subject_ids,
739 covariates,
740 ncomp,
741 n_perm,
742 seed,
743 &observed_stats,
744 p,
745 m,
746 );
747
748 Some(FmmTestResult {
749 f_statistics,
750 p_values,
751 })
752}
753
754fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
756 let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
757 (0..p)
758 .map(|j| {
759 let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
760 ss * h
761 })
762 .collect()
763}
764
765fn permutation_test(
767 data: &FdMatrix,
768 subject_ids: &[usize],
769 covariates: &FdMatrix,
770 ncomp: usize,
771 n_perm: usize,
772 seed: u64,
773 observed_stats: &[f64],
774 p: usize,
775 m: usize,
776) -> (Vec<f64>, Vec<f64>) {
777 use rand::prelude::*;
778 let n_total = data.nrows();
779 let mut rng = StdRng::seed_from_u64(seed);
780 let mut n_ge = vec![0usize; p];
781
782 for _ in 0..n_perm {
783 let mut perm_indices: Vec<usize> = (0..n_total).collect();
785 perm_indices.shuffle(&mut rng);
786 let perm_cov = permute_rows(covariates, &perm_indices);
787
788 if let Some(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
789 let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
790 for j in 0..p {
791 if perm_stats[j] >= observed_stats[j] {
792 n_ge[j] += 1;
793 }
794 }
795 }
796 }
797
798 let p_values: Vec<f64> = n_ge
799 .iter()
800 .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
801 .collect();
802 let f_statistics = observed_stats.to_vec();
803
804 (f_statistics, p_values)
805}
806
807fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
809 let n = indices.len();
810 let m = mat.ncols();
811 let mut result = FdMatrix::zeros(n, m);
812 for (new_i, &old_i) in indices.iter().enumerate() {
813 for j in 0..m {
814 result[(new_i, j)] = mat[(old_i, j)];
815 }
816 }
817 result
818}
819
// Unit tests for the functional mixed model. All data are generated
// deterministically (no RNG), so the permutation-test thresholds below are
// stable for the fixed seed 42.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points on [0, 1]; callers use `m >= 2` (m = 1 would divide by zero).
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// Deterministic synthetic data: `n_subjects` subjects × `n_visits` curves on `m` points.
    ///
    /// Each curve is `sin(2πt)` plus a fixed effect `3·z·t` (with subject-level
    /// covariate `z = s / n_subjects`), plus a subject effect proportional to
    /// `cos(2πt)`, plus a small deterministic sawtooth "noise" term.
    /// Returns `(data, subject_ids, covariates, grid)`.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            // Subject-level covariate (constant across visits).
            let z = s as f64 / n_subjects as f64;
            // Centred subject offset that scales the random cos(2πt) component.
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0);
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    // Deterministic pseudo-noise in [0, noise_scale).
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    // Column-major layout: entry (obs, j) lives at obs + j * n_total.
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }

    /// Output dimensions match the input layout.
    #[test]
    fn test_fmm_basic() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_eq!(result.mean_function.len(), 50);
        assert_eq!(result.beta_functions.nrows(), 1);
        assert_eq!(result.beta_functions.ncols(), 50);
        assert_eq!(result.random_effects.nrows(), 10);
        assert_eq!(result.fitted.nrows(), 30);
        assert_eq!(result.residuals.nrows(), 30);
        assert_eq!(result.n_subjects, 10);
    }

    /// Residuals are defined as data − fitted, so the two must add back to the data.
    #[test]
    fn test_fmm_fitted_plus_residuals_equals_data() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        for i in 0..n {
            for t in 0..m {
                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
                assert!(
                    (reconstructed - data[(i, t)]).abs() < 1e-8,
                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
                    i,
                    t,
                    reconstructed,
                    data[(i, t)]
                );
            }
        }
    }

    /// Pointwise random-effect variance is never negative.
    #[test]
    fn test_fmm_random_variance_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        for &v in &result.random_variance {
            assert!(v >= 0.0, "Random variance should be non-negative");
        }
    }

    /// Without covariates there are no coefficient functions.
    #[test]
    fn test_fmm_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, None, 3).unwrap();

        assert_eq!(result.beta_functions.nrows(), 0);
        assert_eq!(result.n_subjects, 8);
        assert_eq!(result.fitted.nrows(), 24);
    }

    /// Prediction for one new covariate value has the right shape and sane values.
    #[test]
    fn test_fmm_predict() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 50);

        for t in 0..50 {
            assert!(predicted[(0, t)].is_finite());
            assert!(
                predicted[(0, t)].abs() < 20.0,
                "Predicted value too extreme at t={}: {}",
                t,
                predicted[(0, t)]
            );
        }
    }

    /// The generated data contain a real covariate effect (3·z·t), which the test should detect.
    #[test]
    fn test_fmm_test_fixed_detects_effect() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();

        assert_eq!(result.f_statistics.len(), 1);
        assert_eq!(result.p_values.len(), 1);
        assert!(
            result.p_values[0] < 0.1,
            "Should detect covariate effect, got p={}",
            result.p_values[0]
        );
    }

    /// Curves do not depend on the covariate, so the test should not reject.
    #[test]
    fn test_fmm_test_fixed_no_effect() {
        let n_subjects = 10;
        let n_visits = 3;
        let m = 40;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;

        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = s as f64 / n_subjects as f64;
                // Mean curve plus deterministic pseudo-noise only; no covariate term.
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
        assert!(
            result.p_values[0] > 0.05,
            "Should not detect effect, got p={}",
            result.p_values[0]
        );
    }

    /// Empty data and a subject-id length mismatch are rejected.
    #[test]
    fn test_fmm_invalid_input() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fmm(&data, &[], None, 1).is_none());

        let data = FdMatrix::zeros(10, 50);
        // 5 ids for 10 rows.
        let ids = vec![0; 5];
        assert!(fmm(&data, &ids, None, 1).is_none());
    }

    /// Every subject observed once (no within-subject replication) still fits.
    #[test]
    fn test_fmm_single_visit_per_subject() {
        let n = 10;
        let m = 40;
        let t = uniform_grid(m);
        let mut col_major = vec![0.0; n * m];
        let subject_ids: Vec<usize> = (0..n).collect();

        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        let result = fmm(&data, &subject_ids, None, 2).unwrap();
        assert_eq!(result.n_subjects, n);
        assert_eq!(result.fitted.nrows(), n);
    }

    /// Ids are relabelled densely in ascending order.
    #[test]
    fn test_build_subject_map() {
        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 0, 1, 1, 2]);
    }

    /// Estimated variance components are non-negative.
    #[test]
    fn test_variance_components_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert!(result.sigma2_eps >= 0.0);
        for &s in &result.sigma2_u {
            assert!(s >= 0.0);
        }
    }
}