//! Functional Additive Mixed Models (FAMM).
//!
//! Implements functional mixed effects models for repeated functional
//! measurements with subject-level covariates.
//!
//! Model: `Y_ij(t) = μ(t) + X_i'β(t) + b_i(t) + ε_ij(t)`
//!
//! Key functions:
//! - [`fmm`] — Fit a functional mixed model via FPC decomposition
//! - [`fmm_predict`] — Predict curves for new subjects
//! - [`fmm_test_fixed`] — Hypothesis test on fixed effects

use crate::matrix::FdMatrix;
use crate::regression::fdata_to_pc_1d;

16/// Result of a functional mixed model fit.
17pub struct FmmResult {
18    /// Overall mean function μ̂(t) (length m)
19    pub mean_function: Vec<f64>,
20    /// Fixed effect coefficient functions β̂_j(t) (p × m matrix, one row per covariate)
21    pub beta_functions: FdMatrix,
22    /// Random effect functions b̂_i(t) per subject (n_subjects × m)
23    pub random_effects: FdMatrix,
24    /// Fitted values for all observations (n_total × m)
25    pub fitted: FdMatrix,
26    /// Residuals (n_total × m)
27    pub residuals: FdMatrix,
28    /// Variance of random effects at each time point (length m)
29    pub random_variance: Vec<f64>,
30    /// Residual variance estimate
31    pub sigma2_eps: f64,
32    /// Random effect variance estimate (per-component)
33    pub sigma2_u: Vec<f64>,
34    /// Number of FPC components used
35    pub ncomp: usize,
36    /// Number of subjects
37    pub n_subjects: usize,
38    /// FPC eigenvalues (singular values squared / n)
39    pub eigenvalues: Vec<f64>,
40}
41
/// Result of fixed effect hypothesis test.
#[derive(Debug, Clone)]
pub struct FmmTestResult {
    /// F-statistic per covariate (length p)
    pub f_statistics: Vec<f64>,
    /// P-values per covariate (via permutation, length p)
    pub p_values: Vec<f64>,
}

// ---------------------------------------------------------------------------
// Core FMM algorithm
// ---------------------------------------------------------------------------

54/// Fit a functional mixed model via FPC decomposition.
55///
56/// # Arguments
57/// * `data` — All observed curves (n_total × m), stacked across subjects and visits
58/// * `subject_ids` — Subject identifier for each curve (length n_total)
59/// * `covariates` — Subject-level covariates (n_total × p).
60///   Each row corresponds to the same curve in `data`.
61///   If a covariate is subject-level, its value should be repeated across visits.
62/// * `ncomp` — Number of FPC components
63///
64/// # Algorithm
65/// 1. Pool curves, compute FPCA
66/// 2. For each FPC score, fit scalar mixed model: ξ_ijk = x_i'γ_k + u_ik + e_ijk
67/// 3. Recover β̂(t) and b̂_i(t) from component coefficients
68pub fn fmm(
69    data: &FdMatrix,
70    subject_ids: &[usize],
71    covariates: Option<&FdMatrix>,
72    ncomp: usize,
73) -> Option<FmmResult> {
74    let n_total = data.nrows();
75    let m = data.ncols();
76    if n_total == 0 || m == 0 || subject_ids.len() != n_total || ncomp == 0 {
77        return None;
78    }
79
80    // Determine unique subjects
81    let (subject_map, n_subjects) = build_subject_map(subject_ids);
82
83    // Step 1: FPCA on pooled data
84    let fpca = fdata_to_pc_1d(data, ncomp)?;
85    let k = fpca.scores.ncols(); // actual number of components
86
87    // Step 2: For each FPC score, fit scalar mixed model
88    // Normalize scores by sqrt(h) to match R's L²-weighted FPCA convention.
89    // This ensures variance components are on the same scale as R's lmer().
90    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
91    let score_scale = h.sqrt();
92
93    let p = covariates.map_or(0, |c| c.ncols());
94    let mut gamma = vec![vec![0.0; k]; p]; // gamma[j][k] = fixed effect coeff j for component k
95    let mut u_hat = vec![vec![0.0; k]; n_subjects]; // u_hat[i][k] = random effect for subject i, component k
96    let mut sigma2_u = vec![0.0; k];
97    let mut sigma2_eps_total = 0.0;
98
99    for comp in 0..k {
100        // Scale scores to L²-normalized space
101        let scores: Vec<f64> = (0..n_total)
102            .map(|i| fpca.scores[(i, comp)] * score_scale)
103            .collect();
104        let result = fit_scalar_mixed_model(&scores, &subject_map, n_subjects, covariates, p);
105        // Scale gamma back to original score space for beta reconstruction
106        for j in 0..p {
107            gamma[j][comp] = result.gamma[j] / score_scale;
108        }
109        for s in 0..n_subjects {
110            // Scale u_hat back for random effect reconstruction
111            u_hat[s][comp] = result.u_hat[s] / score_scale;
112        }
113        // Keep variance components in L²-normalized space (matching R)
114        sigma2_u[comp] = result.sigma2_u;
115        sigma2_eps_total += result.sigma2_eps;
116    }
117    let sigma2_eps = sigma2_eps_total / k as f64;
118
119    // Step 3: Recover functional coefficients (using gamma in original scale)
120    let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
121    let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
122
123    // Compute random variance function: Var(b_i(t)) across subjects
124    let random_variance = compute_random_variance(&random_effects, n_subjects, m);
125
126    // Compute fitted and residuals
127    let (fitted, residuals) = compute_fitted_residuals(
128        data,
129        &fpca.mean,
130        &beta_functions,
131        &random_effects,
132        covariates,
133        &subject_map,
134        n_total,
135        m,
136        p,
137    );
138
139    let eigenvalues: Vec<f64> = fpca
140        .singular_values
141        .iter()
142        .map(|&sv| sv * sv / n_total as f64)
143        .collect();
144
145    Some(FmmResult {
146        mean_function: fpca.mean,
147        beta_functions,
148        random_effects,
149        fitted,
150        residuals,
151        random_variance,
152        sigma2_eps,
153        sigma2_u,
154        ncomp: k,
155        n_subjects,
156        eigenvalues,
157    })
158}
159
/// Build mapping from observation index to subject index (0..n_subjects-1).
///
/// Subject slots are assigned in ascending order of the original ids, so the
/// mapping is deterministic regardless of observation order.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    // Sorted, deduplicated list of the distinct subject ids.
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    let n_subjects = unique_ids.len();

    // `unique_ids` is sorted, so binary search resolves each observation's
    // slot in O(log u) instead of a linear scan per observation. Every id is
    // present by construction, so the lookup always succeeds.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| unique_ids.binary_search(id).unwrap_or(0))
        .collect();

    (map, n_subjects)
}

/// Scalar mixed model result for one FPC component.
struct ScalarMixedResult {
    /// Fixed effects γ (length p).
    gamma: Vec<f64>,
    /// Random effects per subject (length n_subjects).
    u_hat: Vec<f64>,
    /// Random effect variance σ²_u.
    sigma2_u: f64,
    /// Residual variance σ²_ε.
    sigma2_eps: f64,
}

/// Precomputed subject structure for the mixed model.
struct SubjectStructure {
    /// Number of observations per subject (length n_subjects).
    counts: Vec<usize>,
    /// Observation indices belonging to each subject.
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` observations by subject according to `subject_map`.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut counts = vec![0usize; n_subjects];
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for i in 0..n {
            let s = subject_map[i];
            counts[s] += 1;
            obs[s].push(i);
        }
        Self { counts, obs }
    }
}

202/// Compute shrinkage weights: w_s = σ²_u / (σ²_u + σ²_e / n_s).
203fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
204    ss.counts
205        .iter()
206        .map(|&c| {
207            let ns = c as f64;
208            if ns < 1.0 {
209                0.0
210            } else {
211                sigma2_u / (sigma2_u + sigma2_e / ns)
212            }
213        })
214        .collect()
215}
216
217/// GLS fixed effect update using block-diagonal V^{-1}.
218///
219/// Computes γ = (X'V⁻¹X)⁻¹ X'V⁻¹y exploiting the balanced random intercept structure.
220fn gls_update_gamma(
221    cov: &FdMatrix,
222    p: usize,
223    ss: &SubjectStructure,
224    weights: &[f64],
225    y: &[f64],
226    sigma2_e: f64,
227) -> Option<Vec<f64>> {
228    let n_subjects = ss.counts.len();
229    let mut xtvinvx = vec![0.0; p * p];
230    let mut xtvinvy = vec![0.0; p];
231    let inv_e = 1.0 / sigma2_e;
232
233    for s in 0..n_subjects {
234        let ns = ss.counts[s] as f64;
235        if ns < 1.0 {
236            continue;
237        }
238        let w_s = weights[s];
239
240        // Subject-level sums
241        let mut x_sum = vec![0.0; p];
242        let mut y_sum = 0.0;
243        for &i in &ss.obs[s] {
244            for r in 0..p {
245                x_sum[r] += cov[(i, r)];
246            }
247            y_sum += y[i];
248        }
249
250        // Accumulate X'V^{-1}X and X'V^{-1}y
251        for &i in &ss.obs[s] {
252            let mut vinv_x = vec![0.0; p];
253            for r in 0..p {
254                vinv_x[r] = inv_e * (cov[(i, r)] - w_s * x_sum[r] / ns);
255            }
256            let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
257
258            for r in 0..p {
259                xtvinvy[r] += cov[(i, r)] * vinv_y;
260                for c in r..p {
261                    let val = cov[(i, r)] * vinv_x[c];
262                    xtvinvx[r * p + c] += val;
263                    if r != c {
264                        xtvinvx[c * p + r] += val;
265                    }
266                }
267            }
268        }
269    }
270
271    for j in 0..p {
272        xtvinvx[j * p + j] += 1e-10;
273    }
274    cholesky_solve(&xtvinvx, &xtvinvy, p)
275}
276
277/// REML EM update for variance components.
278///
279/// Returns (σ²_u_new, σ²_e_new) from the conditional expectations.
280/// Uses n - p divisor for σ²_e (REML correction where p = number of fixed effects).
281fn reml_variance_update(
282    residuals: &[f64],
283    ss: &SubjectStructure,
284    weights: &[f64],
285    sigma2_u: f64,
286    p: usize,
287) -> (f64, f64) {
288    let n_subjects = ss.counts.len();
289    let n: usize = ss.counts.iter().sum();
290    let mut sigma2_u_new = 0.0;
291    let mut sigma2_e_new = 0.0;
292
293    for s in 0..n_subjects {
294        let ns = ss.counts[s] as f64;
295        if ns < 1.0 {
296            continue;
297        }
298        let w_s = weights[s];
299        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
300        let u_hat_s = w_s * mean_r_s;
301        let cond_var_s = sigma2_u * (1.0 - w_s);
302
303        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
304        for &i in &ss.obs[s] {
305            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
306        }
307        sigma2_e_new += ns * cond_var_s;
308    }
309
310    // REML divisor: n - p for residual variance (matches R's lmer)
311    let denom_e = (n.saturating_sub(p)).max(1) as f64;
312
313    (
314        (sigma2_u_new / n_subjects as f64).max(1e-15),
315        (sigma2_e_new / denom_e).max(1e-15),
316    )
317}
318
319/// Fit scalar mixed model: y_ij = x_i'γ + u_i + e_ij.
320///
321/// Uses iterative GLS for fixed effects + REML EM for variance components,
322/// matching R's lmer() behavior. Initializes from Henderson's ANOVA, then
323/// iterates until convergence.
324fn fit_scalar_mixed_model(
325    y: &[f64],
326    subject_map: &[usize],
327    n_subjects: usize,
328    covariates: Option<&FdMatrix>,
329    p: usize,
330) -> ScalarMixedResult {
331    let n = y.len();
332    let ss = SubjectStructure::new(subject_map, n_subjects, n);
333
334    // Initialize from OLS + Henderson's ANOVA
335    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
336    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
337    let (mut sigma2_u, mut sigma2_e) =
338        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);
339
340    if sigma2_e < 1e-15 {
341        sigma2_e = 1e-6;
342    }
343    if sigma2_u < 1e-15 {
344        sigma2_u = sigma2_e * 0.1;
345    }
346
347    let mut gamma = gamma_init;
348
349    for _iter in 0..50 {
350        let sigma2_u_old = sigma2_u;
351        let sigma2_e_old = sigma2_e;
352
353        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);
354
355        if let Some(cov) = covariates.filter(|_| p > 0) {
356            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
357                gamma = g;
358            }
359        }
360
361        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
362        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);
363
364        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
365        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
366            break;
367        }
368    }
369
370    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
371    let u_hat = compute_blup(
372        &final_residuals,
373        subject_map,
374        n_subjects,
375        sigma2_u,
376        sigma2_e,
377    );
378
379    ScalarMixedResult {
380        gamma,
381        u_hat,
382        sigma2_u,
383        sigma2_eps: sigma2_e,
384    }
385}
386
387/// OLS estimation of fixed effects.
388fn estimate_fixed_effects(
389    y: &[f64],
390    covariates: Option<&FdMatrix>,
391    p: usize,
392    n: usize,
393) -> Vec<f64> {
394    if p == 0 || covariates.is_none() {
395        return Vec::new();
396    }
397    let cov = covariates.unwrap();
398
399    // Solve (X'X)γ = X'y via Cholesky
400    let mut xtx = vec![0.0; p * p];
401    let mut xty = vec![0.0; p];
402    for i in 0..n {
403        for r in 0..p {
404            xty[r] += cov[(i, r)] * y[i];
405            for s in r..p {
406                let val = cov[(i, r)] * cov[(i, s)];
407                xtx[r * p + s] += val;
408                if r != s {
409                    xtx[s * p + r] += val;
410                }
411            }
412        }
413    }
414    // Regularize
415    for j in 0..p {
416        xtx[j * p + j] += 1e-8;
417    }
418
419    cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
420}
421
/// Cholesky factorization: A = LL'. Returns L (p×p flat row-major) or `None`
/// if A is not symmetric positive definite (a non-positive pivot appears).
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        // Diagonal entry: l_jj = sqrt(a_jj - Σ_k l_jk²)
        let mut sum = 0.0;
        for k in 0..j {
            sum += l[j * p + k] * l[j * p + k];
        }
        let diag = a[j * p + j] - sum;
        if diag <= 0.0 {
            return None; // not SPD
        }
        l[j * p + j] = diag.sqrt();
        // Below-diagonal column j: l_ij = (a_ij - Σ_k l_ik l_jk) / l_jj
        for i in (j + 1)..p {
            let mut s = 0.0;
            for k in 0..j {
                s += l[i * p + k] * l[j * p + k];
            }
            l[i * p + j] = (a[i * p + j] - s) / l[j * p + j];
        }
    }
    Some(l)
}

/// Solve L z = b (forward substitution) then L' x = z (back substitution).
///
/// `l` is the lower-triangular Cholesky factor in p×p flat row-major layout.
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    // Forward: L z = b
    let mut z = vec![0.0; p];
    for i in 0..p {
        let mut s = 0.0;
        for j in 0..i {
            s += l[i * p + j] * z[j];
        }
        z[i] = (b[i] - s) / l[i * p + i];
    }
    // Backward: L' x = z (solved in place over z)
    for i in (0..p).rev() {
        let mut s = 0.0;
        for j in (i + 1)..p {
            s += l[j * p + i] * z[j];
        }
        z[i] = (z[i] - s) / l[i * p + i];
    }
    z
}

466/// Cholesky solve: A x = b where A is p×p symmetric positive definite.
467fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
468    let l = cholesky_factor_famm(a, p)?;
469    Some(cholesky_triangular_solve(&l, b, p))
470}
471
472/// Compute OLS residuals: r = y - X*gamma.
473fn compute_ols_residuals(
474    y: &[f64],
475    covariates: Option<&FdMatrix>,
476    gamma: &[f64],
477    p: usize,
478    n: usize,
479) -> Vec<f64> {
480    let mut residuals = y.to_vec();
481    if p > 0 {
482        if let Some(cov) = covariates {
483            for i in 0..n {
484                for j in 0..p {
485                    residuals[i] -= cov[(i, j)] * gamma[j];
486                }
487            }
488        }
489    }
490    residuals
491}
492
/// Estimate variance components via method of moments.
///
/// σ²_u and σ²_ε from one-way random effects ANOVA:
/// σ²_ε = MS_within, σ²_u = (MS_between − σ²_ε) / n̄ clamped at 0.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Compute subject sums/counts in a single pass.
    let mut subject_sums = vec![0.0; n_subjects];
    let mut subject_counts = vec![0usize; n_subjects];
    for i in 0..n {
        let s = subject_map[i];
        subject_sums[s] += residuals[i];
        subject_counts[s] += 1;
    }
    let subject_means: Vec<f64> = subject_sums
        .iter()
        .zip(&subject_counts)
        .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 })
        .collect();

    // Within-subject sum of squares
    let mut ss_within = 0.0;
    for i in 0..n {
        let s = subject_map[i];
        ss_within += (residuals[i] - subject_means[s]).powi(2);
    }
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject sum of squares
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let mut ss_between = 0.0;
    for s in 0..n_subjects {
        ss_between += subject_counts[s] as f64 * (subject_means[s] - grand_mean).powi(2);
    }

    // Fallback to a small positive value when there is only one visit per
    // subject (df_within == 0) and σ²_ε is not identifiable.
    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        1e-6
    };

    // Mean number of observations per subject
    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    // Clamp at zero: the moment estimator can go negative with small samples.
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}

/// Compute BLUP (Best Linear Unbiased Prediction) for random effects.
///
/// û_i = σ²_u / (σ²_u + σ²_ε/n_i) * (ȳ_i - x̄_i'γ)
///
/// Subjects with no observations get û_i = 0.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    let mut subject_sums = vec![0.0; n_subjects];
    let mut subject_counts = vec![0usize; n_subjects];
    for (i, &r) in residuals.iter().enumerate() {
        let s = subject_map[i];
        subject_sums[s] += r;
        subject_counts[s] += 1;
    }

    (0..n_subjects)
        .map(|s| {
            let ni = subject_counts[s] as f64;
            if ni < 1.0 {
                return 0.0;
            }
            let mean_r = subject_sums[s] / ni;
            // Floor on the denominator guards against both variances being ~0.
            let shrinkage = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
            shrinkage * mean_r
        })
        .collect()
}

// ---------------------------------------------------------------------------
// Recovery of functional coefficients
// ---------------------------------------------------------------------------

581/// Recover β̂(t) = Σ_k γ̂_jk φ_k(t) for each covariate j.
582fn recover_beta_functions(
583    gamma: &[Vec<f64>],
584    rotation: &FdMatrix,
585    p: usize,
586    m: usize,
587    k: usize,
588) -> FdMatrix {
589    let mut beta = FdMatrix::zeros(p, m);
590    for j in 0..p {
591        for t in 0..m {
592            let mut val = 0.0;
593            for comp in 0..k {
594                val += gamma[j][comp] * rotation[(t, comp)];
595            }
596            beta[(j, t)] = val;
597        }
598    }
599    beta
600}
601
602/// Recover b̂_i(t) = Σ_k û_ik φ_k(t) for each subject i.
603fn recover_random_effects(
604    u_hat: &[Vec<f64>],
605    rotation: &FdMatrix,
606    n_subjects: usize,
607    m: usize,
608    k: usize,
609) -> FdMatrix {
610    let mut re = FdMatrix::zeros(n_subjects, m);
611    for s in 0..n_subjects {
612        for t in 0..m {
613            let mut val = 0.0;
614            for comp in 0..k {
615                val += u_hat[s][comp] * rotation[(t, comp)];
616            }
617            re[(s, t)] = val;
618        }
619    }
620    re
621}
622
623/// Compute random effect variance function: Var_i(b̂_i(t)).
624fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
625    (0..m)
626        .map(|t| {
627            let mean: f64 =
628                (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
629            let var: f64 = (0..n_subjects)
630                .map(|s| (random_effects[(s, t)] - mean).powi(2))
631                .sum::<f64>()
632                / n_subjects.max(1) as f64;
633            var
634        })
635        .collect()
636}
637
638/// Compute fitted values and residuals.
639fn compute_fitted_residuals(
640    data: &FdMatrix,
641    mean_function: &[f64],
642    beta_functions: &FdMatrix,
643    random_effects: &FdMatrix,
644    covariates: Option<&FdMatrix>,
645    subject_map: &[usize],
646    n_total: usize,
647    m: usize,
648    p: usize,
649) -> (FdMatrix, FdMatrix) {
650    let mut fitted = FdMatrix::zeros(n_total, m);
651    let mut residuals = FdMatrix::zeros(n_total, m);
652
653    for i in 0..n_total {
654        let s = subject_map[i];
655        for t in 0..m {
656            let mut val = mean_function[t] + random_effects[(s, t)];
657            if p > 0 {
658                if let Some(cov) = covariates {
659                    for j in 0..p {
660                        val += cov[(i, j)] * beta_functions[(j, t)];
661                    }
662                }
663            }
664            fitted[(i, t)] = val;
665            residuals[(i, t)] = data[(i, t)] - val;
666        }
667    }
668
669    (fitted, residuals)
670}
671
// ---------------------------------------------------------------------------
// Prediction
// ---------------------------------------------------------------------------

676/// Predict curves for new subjects.
677///
678/// # Arguments
679/// * `result` — Fitted FMM result
680/// * `new_covariates` — Covariates for new subjects (n_new × p)
681///
682/// Returns predicted curves (n_new × m) using only fixed effects (no random effects for new subjects).
683pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
684    let m = result.mean_function.len();
685    let n_new = new_covariates.map_or(1, |c| c.nrows());
686    let p = result.beta_functions.nrows();
687
688    let mut predicted = FdMatrix::zeros(n_new, m);
689    for i in 0..n_new {
690        for t in 0..m {
691            let mut val = result.mean_function[t];
692            if let Some(cov) = new_covariates {
693                for j in 0..p {
694                    val += cov[(i, j)] * result.beta_functions[(j, t)];
695                }
696            }
697            predicted[(i, t)] = val;
698        }
699    }
700    predicted
701}
702
// ---------------------------------------------------------------------------
// Hypothesis testing
// ---------------------------------------------------------------------------

707/// Permutation test for fixed effects in functional mixed model.
708///
709/// Tests H₀: β_j(t) = 0 for each covariate j.
710/// Uses integrated squared norm as test statistic: T_j = ∫ β̂_j(t)² dt.
711///
712/// # Arguments
713/// * `data` — All observed curves (n_total × m)
714/// * `subject_ids` — Subject identifiers
715/// * `covariates` — Subject-level covariates (n_total × p)
716/// * `ncomp` — Number of FPC components
717/// * `n_perm` — Number of permutations
718/// * `seed` — Random seed
719pub fn fmm_test_fixed(
720    data: &FdMatrix,
721    subject_ids: &[usize],
722    covariates: &FdMatrix,
723    ncomp: usize,
724    n_perm: usize,
725    seed: u64,
726) -> Option<FmmTestResult> {
727    let n_total = data.nrows();
728    let m = data.ncols();
729    let p = covariates.ncols();
730    if n_total == 0 || p == 0 {
731        return None;
732    }
733
734    // Fit observed model
735    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
736
737    // Observed test statistics: ∫ β̂_j(t)² dt for each covariate
738    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
739
740    // Permutation test
741    let (f_statistics, p_values) = permutation_test(
742        data,
743        subject_ids,
744        covariates,
745        ncomp,
746        n_perm,
747        seed,
748        &observed_stats,
749        p,
750        m,
751    );
752
753    Some(FmmTestResult {
754        f_statistics,
755        p_values,
756    })
757}
758
759/// Compute ∫ β̂_j(t)² dt for each covariate.
760fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
761    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
762    (0..p)
763        .map(|j| {
764            let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
765            ss * h
766        })
767        .collect()
768}
769
770/// Run permutation test for fixed effects.
771fn permutation_test(
772    data: &FdMatrix,
773    subject_ids: &[usize],
774    covariates: &FdMatrix,
775    ncomp: usize,
776    n_perm: usize,
777    seed: u64,
778    observed_stats: &[f64],
779    p: usize,
780    m: usize,
781) -> (Vec<f64>, Vec<f64>) {
782    use rand::prelude::*;
783    let n_total = data.nrows();
784    let mut rng = StdRng::seed_from_u64(seed);
785    let mut n_ge = vec![0usize; p];
786
787    for _ in 0..n_perm {
788        // Permute covariates across subjects
789        let mut perm_indices: Vec<usize> = (0..n_total).collect();
790        perm_indices.shuffle(&mut rng);
791        let perm_cov = permute_rows(covariates, &perm_indices);
792
793        if let Some(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
794            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
795            for j in 0..p {
796                if perm_stats[j] >= observed_stats[j] {
797                    n_ge[j] += 1;
798                }
799            }
800        }
801    }
802
803    let p_values: Vec<f64> = n_ge
804        .iter()
805        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
806        .collect();
807    let f_statistics = observed_stats.to_vec();
808
809    (f_statistics, p_values)
810}
811
812/// Permute rows of a matrix according to given indices.
813fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
814    let n = indices.len();
815    let m = mat.ncols();
816    let mut result = FdMatrix::zeros(n, m);
817    for (new_i, &old_i) in indices.iter().enumerate() {
818        for j in 0..m {
819            result[(new_i, j)] = mat[(old_i, j)];
820        }
821    }
822    result
823}
824
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Uniform grid of m points on [0, 1].
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// Generate repeated measurements: n_subjects × n_visits curves.
    /// Subject-level covariate z affects the curve amplitude.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            let z = s as f64 / n_subjects as f64; // covariate in [0, 1)
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0); // random-like effect

            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    // Y_sv(t) = sin(2πt) + z * t + subject_effect * cos(2πt) + noise
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    // Deterministic pseudo-noise so the tests are reproducible.
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }

    #[test]
    fn test_fmm_basic() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_eq!(result.mean_function.len(), 50);
        assert_eq!(result.beta_functions.nrows(), 1); // 1 covariate
        assert_eq!(result.beta_functions.ncols(), 50);
        assert_eq!(result.random_effects.nrows(), 10);
        assert_eq!(result.fitted.nrows(), 30);
        assert_eq!(result.residuals.nrows(), 30);
        assert_eq!(result.n_subjects, 10);
    }

    #[test]
    fn test_fmm_fitted_plus_residuals_equals_data() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let n = data.nrows();
        let m = data.ncols();
        for i in 0..n {
            for t in 0..m {
                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
                assert!(
                    (reconstructed - data[(i, t)]).abs() < 1e-8,
                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
                    i,
                    t,
                    reconstructed,
                    data[(i, t)]
                );
            }
        }
    }

    #[test]
    fn test_fmm_random_variance_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        for &v in &result.random_variance {
            assert!(v >= 0.0, "Random variance should be non-negative");
        }
    }

    #[test]
    fn test_fmm_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, None, 3).unwrap();

        assert_eq!(result.beta_functions.nrows(), 0);
        assert_eq!(result.n_subjects, 8);
        assert_eq!(result.fitted.nrows(), 24);
    }

    #[test]
    fn test_fmm_predict() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        // Predict for new subjects with covariate = 0.5
        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 50);

        // Predicted curve should be reasonable (not NaN or extreme)
        for t in 0..50 {
            assert!(predicted[(0, t)].is_finite());
            assert!(
                predicted[(0, t)].abs() < 20.0,
                "Predicted value too extreme at t={}: {}",
                t,
                predicted[(0, t)]
            );
        }
    }

    #[test]
    fn test_fmm_test_fixed_detects_effect() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();

        assert_eq!(result.f_statistics.len(), 1);
        assert_eq!(result.p_values.len(), 1);
        assert!(
            result.p_values[0] < 0.1,
            "Should detect covariate effect, got p={}",
            result.p_values[0]
        );
    }

    #[test]
    fn test_fmm_test_fixed_no_effect() {
        let n_subjects = 10;
        let n_visits = 3;
        let m = 40;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;

        // No covariate effect: Y = sin(2πt) + noise
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = s as f64 / n_subjects as f64;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
        assert!(
            result.p_values[0] > 0.05,
            "Should not detect effect, got p={}",
            result.p_values[0]
        );
    }

    #[test]
    fn test_fmm_invalid_input() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fmm(&data, &[], None, 1).is_none());

        let data = FdMatrix::zeros(10, 50);
        let ids = vec![0; 5]; // wrong length
        assert!(fmm(&data, &ids, None, 1).is_none());
    }

    #[test]
    fn test_fmm_single_visit_per_subject() {
        let n = 10;
        let m = 40;
        let t = uniform_grid(m);
        let mut col_major = vec![0.0; n * m];
        let subject_ids: Vec<usize> = (0..n).collect();

        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        // Should still work with 1 visit per subject
        let result = fmm(&data, &subject_ids, None, 2).unwrap();
        assert_eq!(result.n_subjects, n);
        assert_eq!(result.fitted.nrows(), n);
    }

    #[test]
    fn test_build_subject_map() {
        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 0, 1, 1, 2]);
    }

    #[test]
    fn test_variance_components_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert!(result.sigma2_eps >= 0.0);
        for &s in &result.sigma2_u {
            assert!(s >= 0.0);
        }
    }
}