// fdars_core/famm.rs
//! Functional Additive Mixed Models (FAMM).
//!
//! Implements functional mixed effects models for repeated functional
//! measurements with subject-level covariates.
//!
//! Model: `Y_ij(t) = μ(t) + X_i'β(t) + b_i(t) + ε_ij(t)`
//!
//! Key functions:
//! - [`fmm`] — Fit a functional mixed model via FPC decomposition
//! - [`fmm_predict`] — Predict curves for new subjects
//! - [`fmm_test_fixed`] — Hypothesis test on fixed effects

13use crate::matrix::FdMatrix;
14use crate::regression::fdata_to_pc_1d;
15
/// Result of a functional mixed model fit.
///
/// All functional quantities are evaluated on the same grid of `m` points
/// as the input data. Produced by [`fmm`].
pub struct FmmResult {
    /// Overall mean function μ̂(t) (length m)
    pub mean_function: Vec<f64>,
    /// Fixed effect coefficient functions β̂_j(t) (p × m matrix, one row per covariate)
    pub beta_functions: FdMatrix,
    /// Random effect functions b̂_i(t) per subject (n_subjects × m)
    pub random_effects: FdMatrix,
    /// Fitted values for all observations (n_total × m)
    pub fitted: FdMatrix,
    /// Residuals (n_total × m)
    pub residuals: FdMatrix,
    /// Variance of random effects at each time point (length m)
    pub random_variance: Vec<f64>,
    /// Residual variance estimate (averaged over FPC components)
    pub sigma2_eps: f64,
    /// Random effect variance estimate (per-component)
    pub sigma2_u: Vec<f64>,
    /// Number of FPC components actually used (may be fewer than requested)
    pub ncomp: usize,
    /// Number of distinct subjects
    pub n_subjects: usize,
    /// FPC eigenvalues (singular values squared / n)
    pub eigenvalues: Vec<f64>,
}
41
/// Result of fixed effect hypothesis test.
///
/// Produced by [`fmm_test_fixed`]; entries are ordered like the covariate
/// columns of the design matrix.
pub struct FmmTestResult {
    /// Test statistic per covariate: integrated squared norm ∫ β̂_j(t)² dt (length p)
    pub f_statistics: Vec<f64>,
    /// P-values per covariate (via permutation, length p)
    pub p_values: Vec<f64>,
}
49
50// ---------------------------------------------------------------------------
51// Core FMM algorithm
52// ---------------------------------------------------------------------------
53
54/// Fit a functional mixed model via FPC decomposition.
55///
56/// # Arguments
57/// * `data` — All observed curves (n_total × m), stacked across subjects and visits
58/// * `subject_ids` — Subject identifier for each curve (length n_total)
59/// * `covariates` — Subject-level covariates (n_total × p).
60///   Each row corresponds to the same curve in `data`.
61///   If a covariate is subject-level, its value should be repeated across visits.
62/// * `ncomp` — Number of FPC components
63///
64/// # Algorithm
65/// 1. Pool curves, compute FPCA
66/// 2. For each FPC score, fit scalar mixed model: ξ_ijk = x_i'γ_k + u_ik + e_ijk
67/// 3. Recover β̂(t) and b̂_i(t) from component coefficients
68pub fn fmm(
69    data: &FdMatrix,
70    subject_ids: &[usize],
71    covariates: Option<&FdMatrix>,
72    ncomp: usize,
73) -> Option<FmmResult> {
74    let n_total = data.nrows();
75    let m = data.ncols();
76    if n_total == 0 || m == 0 || subject_ids.len() != n_total || ncomp == 0 {
77        return None;
78    }
79
80    // Determine unique subjects
81    let (subject_map, n_subjects) = build_subject_map(subject_ids);
82
83    // Step 1: FPCA on pooled data
84    let fpca = fdata_to_pc_1d(data, ncomp)?;
85    let k = fpca.scores.ncols(); // actual number of components
86
87    // Step 2: For each FPC score, fit scalar mixed model
88    // Normalize scores by sqrt(h) to match R's L²-weighted FPCA convention.
89    // This ensures variance components are on the same scale as R's lmer().
90    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
91    let score_scale = h.sqrt();
92
93    let p = covariates.map_or(0, |c| c.ncols());
94    let mut gamma = vec![vec![0.0; k]; p]; // gamma[j][k] = fixed effect coeff j for component k
95    let mut u_hat = vec![vec![0.0; k]; n_subjects]; // u_hat[i][k] = random effect for subject i, component k
96    let mut sigma2_u = vec![0.0; k];
97    let mut sigma2_eps_total = 0.0;
98
99    for comp in 0..k {
100        // Scale scores to L²-normalized space
101        let scores: Vec<f64> = (0..n_total)
102            .map(|i| fpca.scores[(i, comp)] * score_scale)
103            .collect();
104        let result = fit_scalar_mixed_model(&scores, &subject_map, n_subjects, covariates, p);
105        // Scale gamma back to original score space for beta reconstruction
106        for j in 0..p {
107            gamma[j][comp] = result.gamma[j] / score_scale;
108        }
109        for s in 0..n_subjects {
110            // Scale u_hat back for random effect reconstruction
111            u_hat[s][comp] = result.u_hat[s] / score_scale;
112        }
113        // Keep variance components in L²-normalized space (matching R)
114        sigma2_u[comp] = result.sigma2_u;
115        sigma2_eps_total += result.sigma2_eps;
116    }
117    let sigma2_eps = sigma2_eps_total / k as f64;
118
119    // Step 3: Recover functional coefficients (using gamma in original scale)
120    let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
121    let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
122
123    // Compute random variance function: Var(b_i(t)) across subjects
124    let random_variance = compute_random_variance(&random_effects, n_subjects, m);
125
126    // Compute fitted and residuals
127    let (fitted, residuals) = compute_fitted_residuals(
128        data,
129        &fpca.mean,
130        &beta_functions,
131        &random_effects,
132        covariates,
133        &subject_map,
134        n_total,
135        m,
136        p,
137    );
138
139    let eigenvalues: Vec<f64> = fpca
140        .singular_values
141        .iter()
142        .map(|&sv| sv * sv / n_total as f64)
143        .collect();
144
145    Some(FmmResult {
146        mean_function: fpca.mean,
147        beta_functions,
148        random_effects,
149        fitted,
150        residuals,
151        random_variance,
152        sigma2_eps,
153        sigma2_u,
154        ncomp: k,
155        n_subjects,
156        eigenvalues,
157    })
158}
159
/// Build mapping from observation index to subject index (0..n_subjects-1).
///
/// Subject indices are assigned in ascending order of subject id. Returns the
/// per-observation index vector together with the number of distinct subjects.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    // Sorted, deduplicated list of the distinct ids.
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    let n_subjects = unique_ids.len();

    // `unique_ids` is sorted, so binary search replaces the previous linear
    // scan per observation (O(n log s) instead of O(n·s)). Every id is
    // guaranteed to be present; the fallback 0 is unreachable.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| unique_ids.binary_search(id).unwrap_or(0))
        .collect();

    (map, n_subjects)
}
174
/// Scalar mixed model result for one FPC component.
///
/// Holds the per-component estimates produced by `fit_scalar_mixed_model`.
struct ScalarMixedResult {
    gamma: Vec<f64>, // fixed effects (length p)
    u_hat: Vec<f64>, // random effects per subject (length n_subjects)
    sigma2_u: f64,   // random effect variance
    sigma2_eps: f64, // residual variance
}
182
/// Precomputed subject structure for the mixed model.
struct SubjectStructure {
    counts: Vec<usize>,
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` observation indices by subject.
    ///
    /// `counts[s]` is the number of observations for subject `s`, and
    /// `obs[s]` lists their indices in original order.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut structure = Self {
            counts: vec![0usize; n_subjects],
            obs: vec![Vec::new(); n_subjects],
        };
        for (i, &subj) in subject_map.iter().take(n).enumerate() {
            structure.counts[subj] += 1;
            structure.obs[subj].push(i);
        }
        structure
    }
}
201
202/// Compute shrinkage weights: w_s = σ²_u / (σ²_u + σ²_e / n_s).
203fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
204    ss.counts
205        .iter()
206        .map(|&c| {
207            let ns = c as f64;
208            if ns < 1.0 {
209                0.0
210            } else {
211                sigma2_u / (sigma2_u + sigma2_e / ns)
212            }
213        })
214        .collect()
215}
216
/// GLS fixed effect update using block-diagonal V^{-1}.
///
/// Computes γ = (X'V⁻¹X)⁻¹ X'V⁻¹y exploiting the balanced random intercept structure.
///
/// For the random-intercept model, V restricted to subject s acts on a vector
/// v as σ²_e·v + σ²_u·(Σv)·1, and its inverse applied to v is
/// (v − w_s·v̄·1)/σ²_e with `w_s = weights[s]`; the loops below use this
/// closed form instead of materializing V⁻¹.
///
/// Returns `None` when the (ridge-regularized) normal-equation matrix is not
/// positive definite.
fn gls_update_gamma(
    cov: &FdMatrix,
    p: usize,
    ss: &SubjectStructure,
    weights: &[f64],
    y: &[f64],
    sigma2_e: f64,
) -> Option<Vec<f64>> {
    let n_subjects = ss.counts.len();
    let mut xtvinvx = vec![0.0; p * p];
    let mut xtvinvy = vec![0.0; p];
    let inv_e = 1.0 / sigma2_e;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            // Subject has no observations — contributes nothing.
            continue;
        }
        let w_s = weights[s];

        // Subject-level sums (needed for the mean-centering term of V⁻¹)
        let mut x_sum = vec![0.0; p];
        let mut y_sum = 0.0;
        for &i in &ss.obs[s] {
            for r in 0..p {
                x_sum[r] += cov[(i, r)];
            }
            y_sum += y[i];
        }

        // Accumulate X'V^{-1}X and X'V^{-1}y
        for &i in &ss.obs[s] {
            // V⁻¹ applied row-wise: (x − w_s·x̄)/σ²_e, likewise for y.
            let mut vinv_x = vec![0.0; p];
            for r in 0..p {
                vinv_x[r] = inv_e * (cov[(i, r)] - w_s * x_sum[r] / ns);
            }
            let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);

            for r in 0..p {
                xtvinvy[r] += cov[(i, r)] * vinv_y;
                // Fill the upper triangle and mirror, so symmetry is exact.
                for c in r..p {
                    let val = cov[(i, r)] * vinv_x[c];
                    xtvinvx[r * p + c] += val;
                    if r != c {
                        xtvinvx[c * p + r] += val;
                    }
                }
            }
        }
    }

    // Tiny ridge keeps the Cholesky factorization well-posed.
    for j in 0..p {
        xtvinvx[j * p + j] += 1e-10;
    }
    cholesky_solve(&xtvinvx, &xtvinvy, p)
}
276
/// REML EM update for variance components.
///
/// Returns (σ²_u_new, σ²_e_new) from the conditional expectations.
///
/// One EM step for the random-intercept model: given the current shrinkage
/// weights, form each subject's conditional mean û_s and conditional variance,
/// then average the expected squared effects (σ²_u) and expected squared
/// residuals (σ²_e). Both outputs are floored at 1e-15 so later divisions
/// stay finite.
fn reml_variance_update(
    residuals: &[f64],
    ss: &SubjectStructure,
    weights: &[f64],
    sigma2_u: f64,
) -> (f64, f64) {
    let n_subjects = ss.counts.len();
    let n: usize = ss.counts.iter().sum();
    let mut sigma2_u_new = 0.0;
    let mut sigma2_e_new = 0.0;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            // Empty subject — contributes nothing.
            continue;
        }
        let w_s = weights[s];
        // Conditional (posterior) mean of the subject effect: shrunken mean residual.
        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
        let u_hat_s = w_s * mean_r_s;
        // Conditional variance of the subject effect given the data.
        let cond_var_s = sigma2_u * (1.0 - w_s);

        // E[u_s² | data] = û_s² + Var(u_s | data)
        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
        for &i in &ss.obs[s] {
            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
        }
        // Uncertainty in û_s is shared by all n_s residuals of the subject.
        sigma2_e_new += ns * cond_var_s;
    }

    (
        (sigma2_u_new / n_subjects as f64).max(1e-15),
        (sigma2_e_new / n as f64).max(1e-15),
    )
}
313
/// Fit scalar mixed model: y_ij = x_i'γ + u_i + e_ij.
///
/// Uses iterative GLS for fixed effects + REML EM for variance components,
/// matching R's lmer() behavior. Initializes from Henderson's ANOVA, then
/// iterates until convergence.
fn fit_scalar_mixed_model(
    y: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    covariates: Option<&FdMatrix>,
    p: usize,
) -> ScalarMixedResult {
    let n = y.len();
    let ss = SubjectStructure::new(subject_map, n_subjects, n);

    // Initialize from OLS + Henderson's ANOVA
    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
    let (mut sigma2_u, mut sigma2_e) =
        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);

    // Floor degenerate starting values so the GLS weights are well-defined.
    if sigma2_e < 1e-15 {
        sigma2_e = 1e-6;
    }
    if sigma2_u < 1e-15 {
        sigma2_u = sigma2_e * 0.1;
    }

    let mut gamma = gamma_init;

    // Alternate GLS fixed-effect updates with EM variance-component updates,
    // capped at 50 iterations.
    for _iter in 0..50 {
        let sigma2_u_old = sigma2_u;
        let sigma2_e_old = sigma2_e;

        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);

        // Fixed-effect GLS step (skipped when there are no covariates or
        // the normal equations are singular).
        if let Some(cov) = covariates.filter(|_| p > 0) {
            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
                gamma = g;
            }
        }

        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u);

        // Converged when the variance components move by a tiny relative amount.
        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
            break;
        }
    }

    // BLUPs from the final residuals under the converged variance components.
    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
    let u_hat = compute_blup(
        &final_residuals,
        subject_map,
        n_subjects,
        sigma2_u,
        sigma2_e,
    );

    ScalarMixedResult {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps: sigma2_e,
    }
}
381
382/// OLS estimation of fixed effects.
383fn estimate_fixed_effects(
384    y: &[f64],
385    covariates: Option<&FdMatrix>,
386    p: usize,
387    n: usize,
388) -> Vec<f64> {
389    if p == 0 || covariates.is_none() {
390        return Vec::new();
391    }
392    let cov = covariates.unwrap();
393
394    // Solve (X'X)γ = X'y via Cholesky
395    let mut xtx = vec![0.0; p * p];
396    let mut xty = vec![0.0; p];
397    for i in 0..n {
398        for r in 0..p {
399            xty[r] += cov[(i, r)] * y[i];
400            for s in r..p {
401                let val = cov[(i, r)] * cov[(i, s)];
402                xtx[r * p + s] += val;
403                if r != s {
404                    xtx[s * p + r] += val;
405                }
406            }
407        }
408    }
409    // Regularize
410    for j in 0..p {
411        xtx[j * p + j] += 1e-8;
412    }
413
414    cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
415}
416
/// Cholesky factorization: A = LL'. Returns L (p×p flat row-major) or None if singular.
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        // Diagonal entry: sqrt(a_jj − Σ_k l_jk²); fails for non-SPD input.
        let sum_sq: f64 = l[j * p..j * p + j].iter().map(|v| v * v).sum();
        let diag = a[j * p + j] - sum_sq;
        if diag <= 0.0 {
            return None;
        }
        let pivot = diag.sqrt();
        l[j * p + j] = pivot;
        // Fill column j below the diagonal.
        for i in (j + 1)..p {
            let dot: f64 = (0..j).map(|k| l[i * p + k] * l[j * p + k]).sum();
            l[i * p + j] = (a[i * p + j] - dot) / pivot;
        }
    }
    Some(l)
}
440
/// Solve L z = b (forward) then L' x = z (back).
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    // Forward substitution: overwrite z with the solution of L z = b.
    let mut z = vec![0.0; p];
    for i in 0..p {
        let dot: f64 = (0..i).map(|j| l[i * p + j] * z[j]).sum();
        z[i] = (b[i] - dot) / l[i * p + i];
    }
    // Back substitution against L' (column i of L read as a row).
    for i in (0..p).rev() {
        let dot: f64 = ((i + 1)..p).map(|j| l[j * p + i] * z[j]).sum();
        z[i] = (z[i] - dot) / l[i * p + i];
    }
    z
}
460
461/// Cholesky solve: A x = b where A is p×p symmetric positive definite.
462fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
463    let l = cholesky_factor_famm(a, p)?;
464    Some(cholesky_triangular_solve(&l, b, p))
465}
466
467/// Compute OLS residuals: r = y - X*gamma.
468fn compute_ols_residuals(
469    y: &[f64],
470    covariates: Option<&FdMatrix>,
471    gamma: &[f64],
472    p: usize,
473    n: usize,
474) -> Vec<f64> {
475    let mut residuals = y.to_vec();
476    if p > 0 {
477        if let Some(cov) = covariates {
478            for i in 0..n {
479                for j in 0..p {
480                    residuals[i] -= cov[(i, j)] * gamma[j];
481                }
482            }
483        }
484    }
485    residuals
486}
487
/// Estimate variance components via method of moments.
///
/// σ²_u and σ²_ε from one-way random effects ANOVA.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Per-subject totals and observation counts.
    let mut sums = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for i in 0..n {
        sums[subject_map[i]] += residuals[i];
        counts[subject_map[i]] += 1;
    }
    let means: Vec<f64> = sums
        .iter()
        .zip(&counts)
        .map(|(&total, &c)| if c > 0 { total / c as f64 } else { 0.0 })
        .collect();

    // Within-subject sum of squares.
    let ss_within: f64 = (0..n)
        .map(|i| (residuals[i] - means[subject_map[i]]).powi(2))
        .sum();
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject sum of squares around the grand mean.
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let ss_between: f64 = (0..n_subjects)
        .map(|s| counts[s] as f64 * (means[s] - grand_mean).powi(2))
        .sum();

    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        1e-6
    };

    // Method-of-moments estimator: (MS_between − σ²_ε) / n̄, floored at zero.
    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
540
/// Compute BLUP (Best Linear Unbiased Prediction) for random effects.
///
/// û_i = σ²_u / (σ²_u + σ²_ε/n_i) * (ȳ_i - x̄_i'γ)
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    // Per-subject residual totals and observation counts.
    let mut totals = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for (i, &r) in residuals.iter().enumerate() {
        let subj = subject_map[i];
        totals[subj] += r;
        counts[subj] += 1;
    }

    totals
        .iter()
        .zip(&counts)
        .map(|(&total, &count)| {
            if count == 0 {
                return 0.0;
            }
            let ni = count as f64;
            let mean_resid = total / ni;
            // Shrink the subject mean toward zero; the denominator is floored
            // to guard against a zero total variance.
            let shrinkage = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
            shrinkage * mean_resid
        })
        .collect()
}
571
572// ---------------------------------------------------------------------------
573// Recovery of functional coefficients
574// ---------------------------------------------------------------------------
575
576/// Recover β̂(t) = Σ_k γ̂_jk φ_k(t) for each covariate j.
577fn recover_beta_functions(
578    gamma: &[Vec<f64>],
579    rotation: &FdMatrix,
580    p: usize,
581    m: usize,
582    k: usize,
583) -> FdMatrix {
584    let mut beta = FdMatrix::zeros(p, m);
585    for j in 0..p {
586        for t in 0..m {
587            let mut val = 0.0;
588            for comp in 0..k {
589                val += gamma[j][comp] * rotation[(t, comp)];
590            }
591            beta[(j, t)] = val;
592        }
593    }
594    beta
595}
596
597/// Recover b̂_i(t) = Σ_k û_ik φ_k(t) for each subject i.
598fn recover_random_effects(
599    u_hat: &[Vec<f64>],
600    rotation: &FdMatrix,
601    n_subjects: usize,
602    m: usize,
603    k: usize,
604) -> FdMatrix {
605    let mut re = FdMatrix::zeros(n_subjects, m);
606    for s in 0..n_subjects {
607        for t in 0..m {
608            let mut val = 0.0;
609            for comp in 0..k {
610                val += u_hat[s][comp] * rotation[(t, comp)];
611            }
612            re[(s, t)] = val;
613        }
614    }
615    re
616}
617
618/// Compute random effect variance function: Var_i(b̂_i(t)).
619fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
620    (0..m)
621        .map(|t| {
622            let mean: f64 =
623                (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
624            let var: f64 = (0..n_subjects)
625                .map(|s| (random_effects[(s, t)] - mean).powi(2))
626                .sum::<f64>()
627                / n_subjects.max(1) as f64;
628            var
629        })
630        .collect()
631}
632
633/// Compute fitted values and residuals.
634fn compute_fitted_residuals(
635    data: &FdMatrix,
636    mean_function: &[f64],
637    beta_functions: &FdMatrix,
638    random_effects: &FdMatrix,
639    covariates: Option<&FdMatrix>,
640    subject_map: &[usize],
641    n_total: usize,
642    m: usize,
643    p: usize,
644) -> (FdMatrix, FdMatrix) {
645    let mut fitted = FdMatrix::zeros(n_total, m);
646    let mut residuals = FdMatrix::zeros(n_total, m);
647
648    for i in 0..n_total {
649        let s = subject_map[i];
650        for t in 0..m {
651            let mut val = mean_function[t] + random_effects[(s, t)];
652            if p > 0 {
653                if let Some(cov) = covariates {
654                    for j in 0..p {
655                        val += cov[(i, j)] * beta_functions[(j, t)];
656                    }
657                }
658            }
659            fitted[(i, t)] = val;
660            residuals[(i, t)] = data[(i, t)] - val;
661        }
662    }
663
664    (fitted, residuals)
665}
666
667// ---------------------------------------------------------------------------
668// Prediction
669// ---------------------------------------------------------------------------
670
671/// Predict curves for new subjects.
672///
673/// # Arguments
674/// * `result` — Fitted FMM result
675/// * `new_covariates` — Covariates for new subjects (n_new × p)
676///
677/// Returns predicted curves (n_new × m) using only fixed effects (no random effects for new subjects).
678pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
679    let m = result.mean_function.len();
680    let n_new = new_covariates.map_or(1, |c| c.nrows());
681    let p = result.beta_functions.nrows();
682
683    let mut predicted = FdMatrix::zeros(n_new, m);
684    for i in 0..n_new {
685        for t in 0..m {
686            let mut val = result.mean_function[t];
687            if let Some(cov) = new_covariates {
688                for j in 0..p {
689                    val += cov[(i, j)] * result.beta_functions[(j, t)];
690                }
691            }
692            predicted[(i, t)] = val;
693        }
694    }
695    predicted
696}
697
698// ---------------------------------------------------------------------------
699// Hypothesis testing
700// ---------------------------------------------------------------------------
701
702/// Permutation test for fixed effects in functional mixed model.
703///
704/// Tests H₀: β_j(t) = 0 for each covariate j.
705/// Uses integrated squared norm as test statistic: T_j = ∫ β̂_j(t)² dt.
706///
707/// # Arguments
708/// * `data` — All observed curves (n_total × m)
709/// * `subject_ids` — Subject identifiers
710/// * `covariates` — Subject-level covariates (n_total × p)
711/// * `ncomp` — Number of FPC components
712/// * `n_perm` — Number of permutations
713/// * `seed` — Random seed
714pub fn fmm_test_fixed(
715    data: &FdMatrix,
716    subject_ids: &[usize],
717    covariates: &FdMatrix,
718    ncomp: usize,
719    n_perm: usize,
720    seed: u64,
721) -> Option<FmmTestResult> {
722    let n_total = data.nrows();
723    let m = data.ncols();
724    let p = covariates.ncols();
725    if n_total == 0 || p == 0 {
726        return None;
727    }
728
729    // Fit observed model
730    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
731
732    // Observed test statistics: ∫ β̂_j(t)² dt for each covariate
733    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
734
735    // Permutation test
736    let (f_statistics, p_values) = permutation_test(
737        data,
738        subject_ids,
739        covariates,
740        ncomp,
741        n_perm,
742        seed,
743        &observed_stats,
744        p,
745        m,
746    );
747
748    Some(FmmTestResult {
749        f_statistics,
750        p_values,
751    })
752}
753
754/// Compute ∫ β̂_j(t)² dt for each covariate.
755fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
756    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
757    (0..p)
758        .map(|j| {
759            let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
760            ss * h
761        })
762        .collect()
763}
764
765/// Run permutation test for fixed effects.
766fn permutation_test(
767    data: &FdMatrix,
768    subject_ids: &[usize],
769    covariates: &FdMatrix,
770    ncomp: usize,
771    n_perm: usize,
772    seed: u64,
773    observed_stats: &[f64],
774    p: usize,
775    m: usize,
776) -> (Vec<f64>, Vec<f64>) {
777    use rand::prelude::*;
778    let n_total = data.nrows();
779    let mut rng = StdRng::seed_from_u64(seed);
780    let mut n_ge = vec![0usize; p];
781
782    for _ in 0..n_perm {
783        // Permute covariates across subjects
784        let mut perm_indices: Vec<usize> = (0..n_total).collect();
785        perm_indices.shuffle(&mut rng);
786        let perm_cov = permute_rows(covariates, &perm_indices);
787
788        if let Some(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
789            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
790            for j in 0..p {
791                if perm_stats[j] >= observed_stats[j] {
792                    n_ge[j] += 1;
793                }
794            }
795        }
796    }
797
798    let p_values: Vec<f64> = n_ge
799        .iter()
800        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
801        .collect();
802    let f_statistics = observed_stats.to_vec();
803
804    (f_statistics, p_values)
805}
806
807/// Permute rows of a matrix according to given indices.
808fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
809    let n = indices.len();
810    let m = mat.ncols();
811    let mut result = FdMatrix::zeros(n, m);
812    for (new_i, &old_i) in indices.iter().enumerate() {
813        for j in 0..m {
814            result[(new_i, j)] = mat[(old_i, j)];
815        }
816    }
817    result
818}
819
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Evenly spaced grid of `m` points on [0, 1].
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// Generate repeated measurements: n_subjects × n_visits curves.
    /// Subject-level covariate z affects the curve amplitude.
    ///
    /// Returns (data, subject_ids, covariates, grid). The noise term is a
    /// deterministic arithmetic pattern so every test run is reproducible.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            let z = s as f64 / n_subjects as f64; // covariate in [0, 1)
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0); // random-like effect

            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    // Y_sv(t) = sin(2πt) + z * t + subject_effect * cos(2πt) + noise
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    // Deterministic pseudo-noise in [0, noise_scale).
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }

    // Shapes of every output field match the input dimensions.
    #[test]
    fn test_fmm_basic() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_eq!(result.mean_function.len(), 50);
        assert_eq!(result.beta_functions.nrows(), 1); // 1 covariate
        assert_eq!(result.beta_functions.ncols(), 50);
        assert_eq!(result.random_effects.nrows(), 10);
        assert_eq!(result.fitted.nrows(), 30);
        assert_eq!(result.residuals.nrows(), 30);
        assert_eq!(result.n_subjects, 10);
    }

    // The decomposition must be exact: data = fitted + residuals.
    #[test]
    fn test_fmm_fitted_plus_residuals_equals_data() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        for i in 0..n {
            for t in 0..m {
                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
                assert!(
                    (reconstructed - data[(i, t)]).abs() < 1e-8,
                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
                    i,
                    t,
                    reconstructed,
                    data[(i, t)]
                );
            }
        }
    }

    #[test]
    fn test_fmm_random_variance_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        for &v in &result.random_variance {
            assert!(v >= 0.0, "Random variance should be non-negative");
        }
    }

    // Covariate-free model: no beta functions, everything else intact.
    #[test]
    fn test_fmm_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, None, 3).unwrap();

        assert_eq!(result.beta_functions.nrows(), 0);
        assert_eq!(result.n_subjects, 8);
        assert_eq!(result.fitted.nrows(), 24);
    }

    #[test]
    fn test_fmm_predict() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        // Predict for new subjects with covariate = 0.5
        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 50);

        // Predicted curve should be reasonable (not NaN or extreme)
        for t in 0..50 {
            assert!(predicted[(0, t)].is_finite());
            assert!(
                predicted[(0, t)].abs() < 20.0,
                "Predicted value too extreme at t={}: {}",
                t,
                predicted[(0, t)]
            );
        }
    }

    // Data generated WITH a covariate effect: the test should reject H₀.
    #[test]
    fn test_fmm_test_fixed_detects_effect() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();

        assert_eq!(result.f_statistics.len(), 1);
        assert_eq!(result.p_values.len(), 1);
        assert!(
            result.p_values[0] < 0.1,
            "Should detect covariate effect, got p={}",
            result.p_values[0]
        );
    }

    // Data generated WITHOUT a covariate effect: the test should not reject.
    #[test]
    fn test_fmm_test_fixed_no_effect() {
        let n_subjects = 10;
        let n_visits = 3;
        let m = 40;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;

        // No covariate effect: Y = sin(2πt) + noise
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = s as f64 / n_subjects as f64;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
        assert!(
            result.p_values[0] > 0.05,
            "Should not detect effect, got p={}",
            result.p_values[0]
        );
    }

    // Degenerate inputs (empty data, mismatched id length) are rejected.
    #[test]
    fn test_fmm_invalid_input() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fmm(&data, &[], None, 1).is_none());

        let data = FdMatrix::zeros(10, 50);
        let ids = vec![0; 5]; // wrong length
        assert!(fmm(&data, &ids, None, 1).is_none());
    }

    #[test]
    fn test_fmm_single_visit_per_subject() {
        let n = 10;
        let m = 40;
        let t = uniform_grid(m);
        let mut col_major = vec![0.0; n * m];
        let subject_ids: Vec<usize> = (0..n).collect();

        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        // Should still work with 1 visit per subject
        let result = fmm(&data, &subject_ids, None, 2).unwrap();
        assert_eq!(result.n_subjects, n);
        assert_eq!(result.fitted.nrows(), n);
    }

    #[test]
    fn test_build_subject_map() {
        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 0, 1, 1, 2]);
    }

    #[test]
    fn test_variance_components_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert!(result.sigma2_eps >= 0.0);
        for &s in &result.sigma2_u {
            assert!(s >= 0.0);
        }
    }
}