Skip to main content

fdars_core/
famm.rs

//! Functional Additive Mixed Models (FAMM).
//!
//! Implements functional mixed effects models for repeated functional
//! measurements with subject-level covariates.
//!
//! Model: `Y_ij(t) = μ(t) + X_i'β(t) + b_i(t) + ε_ij(t)`
//!
//! Key functions:
//! - [`fmm`] — Fit a functional mixed model via FPC decomposition
//! - [`fmm_predict`] — Predict curves for new subjects
//! - [`fmm_test_fixed`] — Hypothesis test on fixed effects
12
13use crate::matrix::FdMatrix;
14use crate::regression::fdata_to_pc_1d;
15
16/// Result of a functional mixed model fit.
17pub struct FmmResult {
18    /// Overall mean function μ̂(t) (length m)
19    pub mean_function: Vec<f64>,
20    /// Fixed effect coefficient functions β̂_j(t) (p × m matrix, one row per covariate)
21    pub beta_functions: FdMatrix,
22    /// Random effect functions b̂_i(t) per subject (n_subjects × m)
23    pub random_effects: FdMatrix,
24    /// Fitted values for all observations (n_total × m)
25    pub fitted: FdMatrix,
26    /// Residuals (n_total × m)
27    pub residuals: FdMatrix,
28    /// Variance of random effects at each time point (length m)
29    pub random_variance: Vec<f64>,
30    /// Residual variance estimate
31    pub sigma2_eps: f64,
32    /// Random effect variance estimate (per-component)
33    pub sigma2_u: Vec<f64>,
34    /// Number of FPC components used
35    pub ncomp: usize,
36    /// Number of subjects
37    pub n_subjects: usize,
38    /// FPC eigenvalues (singular values squared / n)
39    pub eigenvalues: Vec<f64>,
40}
41
/// Result of the fixed-effect permutation test run by [`fmm_test_fixed`].
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// Observed test statistic per covariate (length p).
    ///
    /// Despite the field name, this is not a classical F ratio: it holds the
    /// integrated squared coefficient function `∫ β̂_j(t)² dt` of the observed fit.
    pub f_statistics: Vec<f64>,
    /// Permutation p-value per covariate (length p), computed as
    /// `(count + 1) / (n_perm + 1)` where `count` is the number of permuted
    /// fits whose statistic is at least the observed one.
    pub p_values: Vec<f64>,
}
49
50// ---------------------------------------------------------------------------
51// Core FMM algorithm
52// ---------------------------------------------------------------------------
53
54/// Fit a functional mixed model via FPC decomposition.
55///
56/// # Arguments
57/// * `data` — All observed curves (n_total × m), stacked across subjects and visits
58/// * `subject_ids` — Subject identifier for each curve (length n_total)
59/// * `covariates` — Subject-level covariates (n_total × p).
60///   Each row corresponds to the same curve in `data`.
61///   If a covariate is subject-level, its value should be repeated across visits.
62/// * `ncomp` — Number of FPC components
63///
64/// # Algorithm
65/// 1. Pool curves, compute FPCA
66/// 2. For each FPC score, fit scalar mixed model: ξ_ijk = x_i'γ_k + u_ik + e_ijk
67/// 3. Recover β̂(t) and b̂_i(t) from component coefficients
68pub fn fmm(
69    data: &FdMatrix,
70    subject_ids: &[usize],
71    covariates: Option<&FdMatrix>,
72    ncomp: usize,
73) -> Option<FmmResult> {
74    let n_total = data.nrows();
75    let m = data.ncols();
76    if n_total == 0 || m == 0 || subject_ids.len() != n_total || ncomp == 0 {
77        return None;
78    }
79
80    // Determine unique subjects
81    let (subject_map, n_subjects) = build_subject_map(subject_ids);
82
83    // Step 1: FPCA on pooled data
84    let fpca = fdata_to_pc_1d(data, ncomp)?;
85    let k = fpca.scores.ncols(); // actual number of components
86
87    // Step 2: For each FPC score, fit scalar mixed model
88    // Normalize scores by sqrt(h) to match R's L²-weighted FPCA convention.
89    // This ensures variance components are on the same scale as R's lmer().
90    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
91    let score_scale = h.sqrt();
92
93    let p = covariates.map_or(0, |c| c.ncols());
94    let mut gamma = vec![vec![0.0; k]; p]; // gamma[j][k] = fixed effect coeff j for component k
95    let mut u_hat = vec![vec![0.0; k]; n_subjects]; // u_hat[i][k] = random effect for subject i, component k
96    let mut sigma2_u = vec![0.0; k];
97    let mut sigma2_eps_total = 0.0;
98
99    for comp in 0..k {
100        // Scale scores to L²-normalized space
101        let scores: Vec<f64> = (0..n_total)
102            .map(|i| fpca.scores[(i, comp)] * score_scale)
103            .collect();
104        let result = fit_scalar_mixed_model(&scores, &subject_map, n_subjects, covariates, p);
105        // Scale gamma back to original score space for beta reconstruction
106        for j in 0..p {
107            gamma[j][comp] = result.gamma[j] / score_scale;
108        }
109        for s in 0..n_subjects {
110            // Scale u_hat back for random effect reconstruction
111            u_hat[s][comp] = result.u_hat[s] / score_scale;
112        }
113        // Keep variance components in L²-normalized space (matching R)
114        sigma2_u[comp] = result.sigma2_u;
115        sigma2_eps_total += result.sigma2_eps;
116    }
117    let sigma2_eps = sigma2_eps_total / k as f64;
118
119    // Step 3: Recover functional coefficients (using gamma in original scale)
120    let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
121    let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
122
123    // Compute random variance function: Var(b_i(t)) across subjects
124    let random_variance = compute_random_variance(&random_effects, n_subjects, m);
125
126    // Compute fitted and residuals
127    let (fitted, residuals) = compute_fitted_residuals(
128        data,
129        &fpca.mean,
130        &beta_functions,
131        &random_effects,
132        covariates,
133        &subject_map,
134        n_total,
135        m,
136        p,
137    );
138
139    let eigenvalues: Vec<f64> = fpca
140        .singular_values
141        .iter()
142        .map(|&sv| sv * sv / n_total as f64)
143        .collect();
144
145    Some(FmmResult {
146        mean_function: fpca.mean,
147        beta_functions,
148        random_effects,
149        fitted,
150        residuals,
151        random_variance,
152        sigma2_eps,
153        sigma2_u,
154        ncomp: k,
155        n_subjects,
156        eigenvalues,
157    })
158}
159
/// Build mapping from observation index to subject index (0..n_subjects-1).
///
/// Subject indices follow the ascending order of the distinct ids, so the
/// smallest id maps to 0. Returns the per-observation index vector together
/// with the number of distinct subjects.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();

    // `unique_ids` is sorted and duplicate-free, so a binary search finds each
    // id in O(log n) instead of scanning the whole list per observation.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| {
            unique_ids
                .binary_search(id)
                .expect("every id in subject_ids was copied into unique_ids")
        })
        .collect();

    let n_subjects = unique_ids.len();
    (map, n_subjects)
}
174
/// Estimates from the scalar mixed model fitted to a single FPC score.
struct ScalarMixedResult {
    /// Fixed-effect coefficients, one per covariate (length p).
    gamma: Vec<f64>,
    /// Predicted random intercept for each subject (length n_subjects).
    u_hat: Vec<f64>,
    /// Estimated variance of the random intercepts.
    sigma2_u: f64,
    /// Estimated residual variance.
    sigma2_eps: f64,
}
182
/// Grouping of observations by subject, computed once per scalar model fit.
struct SubjectStructure {
    /// Number of observations belonging to each subject.
    counts: Vec<usize>,
    /// Observation indices belonging to each subject, in input order.
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` entries of `subject_map` by subject index.
    ///
    /// Each entry of `subject_map` must be smaller than `n_subjects`.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut obs = vec![Vec::new(); n_subjects];
        for (i, &s) in subject_map[..n].iter().enumerate() {
            obs[s].push(i);
        }
        // Each count is simply the size of the subject's index list.
        let counts = obs.iter().map(Vec::len).collect();
        Self { counts, obs }
    }
}
201
202/// Compute shrinkage weights: w_s = σ²_u / (σ²_u + σ²_e / n_s).
203fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
204    ss.counts
205        .iter()
206        .map(|&c| {
207            let ns = c as f64;
208            if ns < 1.0 {
209                0.0
210            } else {
211                sigma2_u / (sigma2_u + sigma2_e / ns)
212            }
213        })
214        .collect()
215}
216
217/// GLS fixed effect update using block-diagonal V^{-1}.
218///
219/// Computes γ = (X'V⁻¹X)⁻¹ X'V⁻¹y exploiting the balanced random intercept structure.
220fn gls_update_gamma(
221    cov: &FdMatrix,
222    p: usize,
223    ss: &SubjectStructure,
224    weights: &[f64],
225    y: &[f64],
226    sigma2_e: f64,
227) -> Option<Vec<f64>> {
228    let n_subjects = ss.counts.len();
229    let mut xtvinvx = vec![0.0; p * p];
230    let mut xtvinvy = vec![0.0; p];
231    let inv_e = 1.0 / sigma2_e;
232
233    for s in 0..n_subjects {
234        let ns = ss.counts[s] as f64;
235        if ns < 1.0 {
236            continue;
237        }
238        let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
239        accumulate_gls_terms(
240            cov,
241            y,
242            &ss.obs[s],
243            &x_sum,
244            y_sum,
245            weights[s],
246            ns,
247            inv_e,
248            p,
249            &mut xtvinvx,
250            &mut xtvinvy,
251        );
252    }
253
254    for j in 0..p {
255        xtvinvx[j * p + j] += 1e-10;
256    }
257    cholesky_solve(&xtvinvx, &xtvinvy, p)
258}
259
260/// Compute subject-level covariate sums and response sum.
261fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
262    let mut x_sum = vec![0.0; p];
263    let mut y_sum = 0.0;
264    for &i in obs {
265        for r in 0..p {
266            x_sum[r] += cov[(i, r)];
267        }
268        y_sum += y[i];
269    }
270    (x_sum, y_sum)
271}
272
273/// Accumulate X'V^{-1}X and X'V^{-1}y for one subject.
274fn accumulate_gls_terms(
275    cov: &FdMatrix,
276    y: &[f64],
277    obs: &[usize],
278    x_sum: &[f64],
279    y_sum: f64,
280    w_s: f64,
281    ns: f64,
282    inv_e: f64,
283    p: usize,
284    xtvinvx: &mut [f64],
285    xtvinvy: &mut [f64],
286) {
287    for &i in obs {
288        let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
289        for r in 0..p {
290            xtvinvy[r] += cov[(i, r)] * vinv_y;
291            for c in r..p {
292                let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
293                let val = cov[(i, r)] * vinv_xc;
294                xtvinvx[r * p + c] += val;
295                if r != c {
296                    xtvinvx[c * p + r] += val;
297                }
298            }
299        }
300    }
301}
302
303/// REML EM update for variance components.
304///
305/// Returns (σ²_u_new, σ²_e_new) from the conditional expectations.
306/// Uses n - p divisor for σ²_e (REML correction where p = number of fixed effects).
307fn reml_variance_update(
308    residuals: &[f64],
309    ss: &SubjectStructure,
310    weights: &[f64],
311    sigma2_u: f64,
312    p: usize,
313) -> (f64, f64) {
314    let n_subjects = ss.counts.len();
315    let n: usize = ss.counts.iter().sum();
316    let mut sigma2_u_new = 0.0;
317    let mut sigma2_e_new = 0.0;
318
319    for s in 0..n_subjects {
320        let ns = ss.counts[s] as f64;
321        if ns < 1.0 {
322            continue;
323        }
324        let w_s = weights[s];
325        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
326        let u_hat_s = w_s * mean_r_s;
327        let cond_var_s = sigma2_u * (1.0 - w_s);
328
329        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
330        for &i in &ss.obs[s] {
331            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
332        }
333        sigma2_e_new += ns * cond_var_s;
334    }
335
336    // REML divisor: n - p for residual variance (matches R's lmer)
337    let denom_e = (n.saturating_sub(p)).max(1) as f64;
338
339    (
340        (sigma2_u_new / n_subjects as f64).max(1e-15),
341        (sigma2_e_new / denom_e).max(1e-15),
342    )
343}
344
345/// Fit scalar mixed model: y_ij = x_i'γ + u_i + e_ij.
346///
347/// Uses iterative GLS for fixed effects + REML EM for variance components,
348/// matching R's lmer() behavior. Initializes from Henderson's ANOVA, then
349/// iterates until convergence.
350fn fit_scalar_mixed_model(
351    y: &[f64],
352    subject_map: &[usize],
353    n_subjects: usize,
354    covariates: Option<&FdMatrix>,
355    p: usize,
356) -> ScalarMixedResult {
357    let n = y.len();
358    let ss = SubjectStructure::new(subject_map, n_subjects, n);
359
360    // Initialize from OLS + Henderson's ANOVA
361    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
362    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
363    let (mut sigma2_u, mut sigma2_e) =
364        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);
365
366    if sigma2_e < 1e-15 {
367        sigma2_e = 1e-6;
368    }
369    if sigma2_u < 1e-15 {
370        sigma2_u = sigma2_e * 0.1;
371    }
372
373    let mut gamma = gamma_init;
374
375    for _iter in 0..50 {
376        let sigma2_u_old = sigma2_u;
377        let sigma2_e_old = sigma2_e;
378
379        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);
380
381        if let Some(cov) = covariates.filter(|_| p > 0) {
382            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
383                gamma = g;
384            }
385        }
386
387        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
388        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);
389
390        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
391        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
392            break;
393        }
394    }
395
396    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
397    let u_hat = compute_blup(
398        &final_residuals,
399        subject_map,
400        n_subjects,
401        sigma2_u,
402        sigma2_e,
403    );
404
405    ScalarMixedResult {
406        gamma,
407        u_hat,
408        sigma2_u,
409        sigma2_eps: sigma2_e,
410    }
411}
412
413/// OLS estimation of fixed effects.
414fn estimate_fixed_effects(
415    y: &[f64],
416    covariates: Option<&FdMatrix>,
417    p: usize,
418    n: usize,
419) -> Vec<f64> {
420    if p == 0 || covariates.is_none() {
421        return Vec::new();
422    }
423    let cov = covariates.unwrap();
424
425    // Solve (X'X)γ = X'y via Cholesky
426    let mut xtx = vec![0.0; p * p];
427    let mut xty = vec![0.0; p];
428    for i in 0..n {
429        for r in 0..p {
430            xty[r] += cov[(i, r)] * y[i];
431            for s in r..p {
432                let val = cov[(i, r)] * cov[(i, s)];
433                xtx[r * p + s] += val;
434                if r != s {
435                    xtx[s * p + r] += val;
436                }
437            }
438        }
439    }
440    // Regularize
441    for j in 0..p {
442        xtx[j * p + j] += 1e-8;
443    }
444
445    cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
446}
447
/// Cholesky factorization: A = LL'.
///
/// `a` is a p×p matrix in flat row-major layout; only its lower triangle and
/// diagonal are read. Returns L in the same layout, or `None` when `a` holds
/// fewer than p×p entries or a pivot is not a finite positive number (matrix
/// not positive definite, or NaN/infinite input).
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    if a.len() < p * p {
        return None;
    }
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        let mut sum = 0.0;
        for k in 0..j {
            sum += l[j * p + k] * l[j * p + k];
        }
        let diag = a[j * p + j] - sum;
        // A plain `diag <= 0.0` test lets NaN through (every comparison with
        // NaN is false), which would return a factor full of NaN.
        if !diag.is_finite() || diag <= 0.0 {
            return None;
        }
        l[j * p + j] = diag.sqrt();
        for i in (j + 1)..p {
            let mut s = 0.0;
            for k in 0..j {
                s += l[i * p + k] * l[j * p + k];
            }
            l[i * p + j] = (a[i * p + j] - s) / l[j * p + j];
        }
    }
    Some(l)
}
471
/// Solve (L L') x = b given the lower-triangular factor `l` (p×p, flat
/// row-major): forward substitution with L, then back substitution with L'.
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = vec![0.0; p];

    // Forward pass: L z = b, with z stored in `x`.
    for row in 0..p {
        let partial = (0..row).fold(0.0, |acc, col| acc + l[row * p + col] * x[col]);
        x[row] = (b[row] - partial) / l[row * p + row];
    }

    // Backward pass: L' x = z, overwriting z in place from the last row up.
    for row in (0..p).rev() {
        let partial = ((row + 1)..p).fold(0.0, |acc, below| acc + l[below * p + row] * x[below]);
        x[row] = (x[row] - partial) / l[row * p + row];
    }

    x
}
491
492/// Cholesky solve: A x = b where A is p×p symmetric positive definite.
493fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
494    let l = cholesky_factor_famm(a, p)?;
495    Some(cholesky_triangular_solve(&l, b, p))
496}
497
498/// Compute OLS residuals: r = y - X*gamma.
499fn compute_ols_residuals(
500    y: &[f64],
501    covariates: Option<&FdMatrix>,
502    gamma: &[f64],
503    p: usize,
504    n: usize,
505) -> Vec<f64> {
506    let mut residuals = y.to_vec();
507    if p > 0 {
508        if let Some(cov) = covariates {
509            for i in 0..n {
510                for j in 0..p {
511                    residuals[i] -= cov[(i, j)] * gamma[j];
512                }
513            }
514        }
515    }
516    residuals
517}
518
/// Method-of-moments variance components from a one-way random effects ANOVA.
///
/// Returns `(σ²_u, σ²_ε)`. σ²_ε is the within-subject mean square (1e-6 when
/// there are no within-subject degrees of freedom); σ²_u is the excess of the
/// between-subject mean square over σ²_ε, divided by the mean number of
/// observations per subject and floored at 0.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Per-subject totals and observation counts.
    let mut totals = vec![0.0; n_subjects];
    let mut sizes = vec![0usize; n_subjects];
    for (&r, &s) in residuals[..n].iter().zip(&subject_map[..n]) {
        totals[s] += r;
        sizes[s] += 1;
    }
    let means: Vec<f64> = totals
        .iter()
        .zip(&sizes)
        .map(|(&total, &size)| match size {
            0 => 0.0,
            _ => total / size as f64,
        })
        .collect();

    // Within-subject sum of squares around each subject mean.
    let mut ss_within = 0.0;
    for (&r, &s) in residuals[..n].iter().zip(&subject_map[..n]) {
        ss_within += (r - means[s]).powi(2);
    }

    // Between-subject sum of squares around the grand mean, weighted by size.
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let mut ss_between = 0.0;
    for (&size, &mean) in sizes.iter().zip(&means) {
        ss_between += size as f64 * (mean - grand_mean).powi(2);
    }

    let sigma2_eps = match n.saturating_sub(n_subjects) {
        0 => 1e-6,
        df_within => ss_within / df_within as f64,
    };

    // Mean number of observations per subject.
    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
571
/// Best linear unbiased predictions of the subject random intercepts.
///
/// û_i = σ²_u / max(σ²_u + σ²_ε / n_i, 1e-15) · (mean residual of subject i).
/// Subjects without observations get 0.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    let mut totals = vec![0.0; n_subjects];
    let mut sizes = vec![0usize; n_subjects];
    for i in 0..residuals.len() {
        let s = subject_map[i];
        totals[s] += residuals[i];
        sizes[s] += 1;
    }

    totals
        .iter()
        .zip(&sizes)
        .map(|(&total, &size)| {
            if size == 0 {
                0.0
            } else {
                let ni = size as f64;
                let group_mean = total / ni;
                // The floor applies to the denominator only.
                let denom = (sigma2_u + sigma2_eps / ni).max(1e-15);
                sigma2_u / denom * group_mean
            }
        })
        .collect()
}
602
603// ---------------------------------------------------------------------------
604// Recovery of functional coefficients
605// ---------------------------------------------------------------------------
606
607/// Recover β̂(t) = Σ_k γ̂_jk φ_k(t) for each covariate j.
608fn recover_beta_functions(
609    gamma: &[Vec<f64>],
610    rotation: &FdMatrix,
611    p: usize,
612    m: usize,
613    k: usize,
614) -> FdMatrix {
615    let mut beta = FdMatrix::zeros(p, m);
616    for j in 0..p {
617        for t in 0..m {
618            let mut val = 0.0;
619            for comp in 0..k {
620                val += gamma[j][comp] * rotation[(t, comp)];
621            }
622            beta[(j, t)] = val;
623        }
624    }
625    beta
626}
627
628/// Recover b̂_i(t) = Σ_k û_ik φ_k(t) for each subject i.
629fn recover_random_effects(
630    u_hat: &[Vec<f64>],
631    rotation: &FdMatrix,
632    n_subjects: usize,
633    m: usize,
634    k: usize,
635) -> FdMatrix {
636    let mut re = FdMatrix::zeros(n_subjects, m);
637    for s in 0..n_subjects {
638        for t in 0..m {
639            let mut val = 0.0;
640            for comp in 0..k {
641                val += u_hat[s][comp] * rotation[(t, comp)];
642            }
643            re[(s, t)] = val;
644        }
645    }
646    re
647}
648
649/// Compute random effect variance function: Var_i(b̂_i(t)).
650fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
651    (0..m)
652        .map(|t| {
653            let mean: f64 =
654                (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
655            let var: f64 = (0..n_subjects)
656                .map(|s| (random_effects[(s, t)] - mean).powi(2))
657                .sum::<f64>()
658                / n_subjects.max(1) as f64;
659            var
660        })
661        .collect()
662}
663
664/// Compute fitted values and residuals.
665fn compute_fitted_residuals(
666    data: &FdMatrix,
667    mean_function: &[f64],
668    beta_functions: &FdMatrix,
669    random_effects: &FdMatrix,
670    covariates: Option<&FdMatrix>,
671    subject_map: &[usize],
672    n_total: usize,
673    m: usize,
674    p: usize,
675) -> (FdMatrix, FdMatrix) {
676    let mut fitted = FdMatrix::zeros(n_total, m);
677    let mut residuals = FdMatrix::zeros(n_total, m);
678
679    for i in 0..n_total {
680        let s = subject_map[i];
681        for t in 0..m {
682            let mut val = mean_function[t] + random_effects[(s, t)];
683            if p > 0 {
684                if let Some(cov) = covariates {
685                    for j in 0..p {
686                        val += cov[(i, j)] * beta_functions[(j, t)];
687                    }
688                }
689            }
690            fitted[(i, t)] = val;
691            residuals[(i, t)] = data[(i, t)] - val;
692        }
693    }
694
695    (fitted, residuals)
696}
697
698// ---------------------------------------------------------------------------
699// Prediction
700// ---------------------------------------------------------------------------
701
702/// Predict curves for new subjects.
703///
704/// # Arguments
705/// * `result` — Fitted FMM result
706/// * `new_covariates` — Covariates for new subjects (n_new × p)
707///
708/// Returns predicted curves (n_new × m) using only fixed effects (no random effects for new subjects).
709pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
710    let m = result.mean_function.len();
711    let n_new = new_covariates.map_or(1, |c| c.nrows());
712    let p = result.beta_functions.nrows();
713
714    let mut predicted = FdMatrix::zeros(n_new, m);
715    for i in 0..n_new {
716        for t in 0..m {
717            let mut val = result.mean_function[t];
718            if let Some(cov) = new_covariates {
719                for j in 0..p {
720                    val += cov[(i, j)] * result.beta_functions[(j, t)];
721                }
722            }
723            predicted[(i, t)] = val;
724        }
725    }
726    predicted
727}
728
729// ---------------------------------------------------------------------------
730// Hypothesis testing
731// ---------------------------------------------------------------------------
732
733/// Permutation test for fixed effects in functional mixed model.
734///
735/// Tests H₀: β_j(t) = 0 for each covariate j.
736/// Uses integrated squared norm as test statistic: T_j = ∫ β̂_j(t)² dt.
737///
738/// # Arguments
739/// * `data` — All observed curves (n_total × m)
740/// * `subject_ids` — Subject identifiers
741/// * `covariates` — Subject-level covariates (n_total × p)
742/// * `ncomp` — Number of FPC components
743/// * `n_perm` — Number of permutations
744/// * `seed` — Random seed
745pub fn fmm_test_fixed(
746    data: &FdMatrix,
747    subject_ids: &[usize],
748    covariates: &FdMatrix,
749    ncomp: usize,
750    n_perm: usize,
751    seed: u64,
752) -> Option<FmmTestResult> {
753    let n_total = data.nrows();
754    let m = data.ncols();
755    let p = covariates.ncols();
756    if n_total == 0 || p == 0 {
757        return None;
758    }
759
760    // Fit observed model
761    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
762
763    // Observed test statistics: ∫ β̂_j(t)² dt for each covariate
764    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
765
766    // Permutation test
767    let (f_statistics, p_values) = permutation_test(
768        data,
769        subject_ids,
770        covariates,
771        ncomp,
772        n_perm,
773        seed,
774        &observed_stats,
775        p,
776        m,
777    );
778
779    Some(FmmTestResult {
780        f_statistics,
781        p_values,
782    })
783}
784
785/// Compute ∫ β̂_j(t)² dt for each covariate.
786fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
787    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
788    (0..p)
789        .map(|j| {
790            let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
791            ss * h
792        })
793        .collect()
794}
795
796/// Run permutation test for fixed effects.
797fn permutation_test(
798    data: &FdMatrix,
799    subject_ids: &[usize],
800    covariates: &FdMatrix,
801    ncomp: usize,
802    n_perm: usize,
803    seed: u64,
804    observed_stats: &[f64],
805    p: usize,
806    m: usize,
807) -> (Vec<f64>, Vec<f64>) {
808    use rand::prelude::*;
809    let n_total = data.nrows();
810    let mut rng = StdRng::seed_from_u64(seed);
811    let mut n_ge = vec![0usize; p];
812
813    for _ in 0..n_perm {
814        // Permute covariates across subjects
815        let mut perm_indices: Vec<usize> = (0..n_total).collect();
816        perm_indices.shuffle(&mut rng);
817        let perm_cov = permute_rows(covariates, &perm_indices);
818
819        if let Some(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
820            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
821            for j in 0..p {
822                if perm_stats[j] >= observed_stats[j] {
823                    n_ge[j] += 1;
824                }
825            }
826        }
827    }
828
829    let p_values: Vec<f64> = n_ge
830        .iter()
831        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
832        .collect();
833    let f_statistics = observed_stats.to_vec();
834
835    (f_statistics, p_values)
836}
837
838/// Permute rows of a matrix according to given indices.
839fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
840    let n = indices.len();
841    let m = mat.ncols();
842    let mut result = FdMatrix::zeros(n, m);
843    for (new_i, &old_i) in indices.iter().enumerate() {
844        for j in 0..m {
845            result[(new_i, j)] = mat[(old_i, j)];
846        }
847    }
848    result
849}
850
851// ---------------------------------------------------------------------------
852// Tests
853// ---------------------------------------------------------------------------
854
855#[cfg(test)]
856mod tests {
857    use super::*;
858    use std::f64::consts::PI;
859
860    fn uniform_grid(m: usize) -> Vec<f64> {
861        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
862    }
863
864    /// Generate repeated measurements: n_subjects × n_visits curves.
865    /// Subject-level covariate z affects the curve amplitude.
866    fn generate_fmm_data(
867        n_subjects: usize,
868        n_visits: usize,
869        m: usize,
870    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
871        let t = uniform_grid(m);
872        let n_total = n_subjects * n_visits;
873        let mut col_major = vec![0.0; n_total * m];
874        let mut subject_ids = vec![0usize; n_total];
875        let mut cov_data = vec![0.0; n_total];
876
877        for s in 0..n_subjects {
878            let z = s as f64 / n_subjects as f64; // covariate in [0, 1)
879            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0); // random-like effect
880
881            for v in 0..n_visits {
882                let obs = s * n_visits + v;
883                subject_ids[obs] = s;
884                cov_data[obs] = z;
885                let noise_scale = 0.05;
886
887                for (j, &tj) in t.iter().enumerate() {
888                    // Y_sv(t) = sin(2πt) + z * t + subject_effect * cos(2πt) + noise
889                    let mu = (2.0 * PI * tj).sin();
890                    let fixed = z * tj * 3.0;
891                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
892                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
893                    col_major[obs + j * n_total] = mu + fixed + random + noise;
894                }
895            }
896        }
897
898        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
899        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
900        (data, subject_ids, covariates, t)
901    }
902
903    #[test]
904    fn test_fmm_basic() {
905        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
906        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
907
908        assert_eq!(result.mean_function.len(), 50);
909        assert_eq!(result.beta_functions.nrows(), 1); // 1 covariate
910        assert_eq!(result.beta_functions.ncols(), 50);
911        assert_eq!(result.random_effects.nrows(), 10);
912        assert_eq!(result.fitted.nrows(), 30);
913        assert_eq!(result.residuals.nrows(), 30);
914        assert_eq!(result.n_subjects, 10);
915    }
916
917    #[test]
918    fn test_fmm_fitted_plus_residuals_equals_data() {
919        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
920        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
921
922        let n = data.nrows();
923        let m = data.ncols();
924        for i in 0..n {
925            for t in 0..m {
926                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
927                assert!(
928                    (reconstructed - data[(i, t)]).abs() < 1e-8,
929                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
930                    i,
931                    t,
932                    reconstructed,
933                    data[(i, t)]
934                );
935            }
936        }
937    }
938
939    #[test]
940    fn test_fmm_random_variance_positive() {
941        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
942        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
943
944        for &v in &result.random_variance {
945            assert!(v >= 0.0, "Random variance should be non-negative");
946        }
947    }
948
949    #[test]
950    fn test_fmm_no_covariates() {
951        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
952        let result = fmm(&data, &subject_ids, None, 3).unwrap();
953
954        assert_eq!(result.beta_functions.nrows(), 0);
955        assert_eq!(result.n_subjects, 8);
956        assert_eq!(result.fitted.nrows(), 24);
957    }
958
959    #[test]
960    fn test_fmm_predict() {
961        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
962        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
963
964        // Predict for new subjects with covariate = 0.5
965        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
966        let predicted = fmm_predict(&result, Some(&new_cov));
967
968        assert_eq!(predicted.nrows(), 1);
969        assert_eq!(predicted.ncols(), 50);
970
971        // Predicted curve should be reasonable (not NaN or extreme)
972        for t in 0..50 {
973            assert!(predicted[(0, t)].is_finite());
974            assert!(
975                predicted[(0, t)].abs() < 20.0,
976                "Predicted value too extreme at t={}: {}",
977                t,
978                predicted[(0, t)]
979            );
980        }
981    }
982
983    #[test]
984    fn test_fmm_test_fixed_detects_effect() {
985        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);
986
987        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
988
989        assert_eq!(result.f_statistics.len(), 1);
990        assert_eq!(result.p_values.len(), 1);
991        assert!(
992            result.p_values[0] < 0.1,
993            "Should detect covariate effect, got p={}",
994            result.p_values[0]
995        );
996    }
997
998    #[test]
999    fn test_fmm_test_fixed_no_effect() {
1000        let n_subjects = 10;
1001        let n_visits = 3;
1002        let m = 40;
1003        let t = uniform_grid(m);
1004        let n_total = n_subjects * n_visits;
1005
1006        // No covariate effect: Y = sin(2πt) + noise
1007        let mut col_major = vec![0.0; n_total * m];
1008        let mut subject_ids = vec![0usize; n_total];
1009        let mut cov_data = vec![0.0; n_total];
1010
1011        for s in 0..n_subjects {
1012            for v in 0..n_visits {
1013                let obs = s * n_visits + v;
1014                subject_ids[obs] = s;
1015                cov_data[obs] = s as f64 / n_subjects as f64;
1016                for (j, &tj) in t.iter().enumerate() {
1017                    col_major[obs + j * n_total] =
1018                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
1019                }
1020            }
1021        }
1022
1023        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
1024        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
1025
1026        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
1027        assert!(
1028            result.p_values[0] > 0.05,
1029            "Should not detect effect, got p={}",
1030            result.p_values[0]
1031        );
1032    }
1033
1034    #[test]
1035    fn test_fmm_invalid_input() {
1036        let data = FdMatrix::zeros(0, 0);
1037        assert!(fmm(&data, &[], None, 1).is_none());
1038
1039        let data = FdMatrix::zeros(10, 50);
1040        let ids = vec![0; 5]; // wrong length
1041        assert!(fmm(&data, &ids, None, 1).is_none());
1042    }
1043
1044    #[test]
1045    fn test_fmm_single_visit_per_subject() {
1046        let n = 10;
1047        let m = 40;
1048        let t = uniform_grid(m);
1049        let mut col_major = vec![0.0; n * m];
1050        let subject_ids: Vec<usize> = (0..n).collect();
1051
1052        for i in 0..n {
1053            for (j, &tj) in t.iter().enumerate() {
1054                col_major[i + j * n] = (2.0 * PI * tj).sin();
1055            }
1056        }
1057        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
1058
1059        // Should still work with 1 visit per subject
1060        let result = fmm(&data, &subject_ids, None, 2).unwrap();
1061        assert_eq!(result.n_subjects, n);
1062        assert_eq!(result.fitted.nrows(), n);
1063    }
1064
1065    #[test]
1066    fn test_build_subject_map() {
1067        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
1068        assert_eq!(n, 3);
1069        assert_eq!(map, vec![0, 0, 1, 1, 2]);
1070    }
1071
1072    #[test]
1073    fn test_variance_components_positive() {
1074        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1075        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1076
1077        assert!(result.sigma2_eps >= 0.0);
1078        for &s in &result.sigma2_u {
1079            assert!(s >= 0.0);
1080        }
1081    }
1082}