Skip to main content

fdars_core/
famm.rs

1//! Functional Additive Mixed Models (FAMM).
2//!
3//! Implements functional mixed effects models for repeated functional
4//! measurements with subject-level covariates.
5//!
6//! Model: `Y_ij(t) = μ(t) + X_i'β(t) + b_i(t) + ε_ij(t)`
7//!
8//! Key functions:
9//! - [`fmm`] — Fit a functional mixed model via FPC decomposition
10//! - [`fmm_predict`] — Predict curves for new subjects
11//! - [`fmm_test_fixed`] — Hypothesis test on fixed effects
12
13use crate::error::FdarError;
14use crate::matrix::FdMatrix;
15use crate::regression::fdata_to_pc_1d;
16
/// Result of a functional mixed model fit.
///
/// Produced by [`fmm`]. Dimension conventions: `m` grid points per curve,
/// `p` covariates, `n_subjects` distinct subjects, `n_total` observed curves.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmResult {
    /// Overall mean function μ̂(t) (length m)
    pub mean_function: Vec<f64>,
    /// Fixed effect coefficient functions β̂_j(t) (p × m matrix, one row per covariate)
    pub beta_functions: FdMatrix,
    /// Random effect functions b̂_i(t) per subject (n_subjects × m)
    pub random_effects: FdMatrix,
    /// Fitted values for all observations (n_total × m)
    pub fitted: FdMatrix,
    /// Residuals (n_total × m); fitted + residuals reconstructs the input data
    pub residuals: FdMatrix,
    /// Variance of random effects at each time point (length m)
    pub random_variance: Vec<f64>,
    /// Residual variance estimate (averaged across FPC components)
    pub sigma2_eps: f64,
    /// Random effect variance estimate (per-component)
    pub sigma2_u: Vec<f64>,
    /// Number of FPC components used (may be fewer than requested)
    pub ncomp: usize,
    /// Number of subjects
    pub n_subjects: usize,
    /// FPC eigenvalues (singular values squared / n)
    pub eigenvalues: Vec<f64>,
}
43
/// Result of fixed effect hypothesis test.
///
/// Produced by [`fmm_test_fixed`]; entries are aligned with covariate columns.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// F-statistic per covariate (length p).
    /// Despite the name, these are the observed integrated squared-norm
    /// statistics T_j = ∫ β̂_j(t)² dt used by the permutation test.
    pub f_statistics: Vec<f64>,
    /// P-values per covariate (via permutation, length p)
    pub p_values: Vec<f64>,
}
52
53// ---------------------------------------------------------------------------
54// Core FMM algorithm
55// ---------------------------------------------------------------------------
56
57/// Fit a functional mixed model via FPC decomposition.
58///
59/// # Arguments
60/// * `data` — All observed curves (n_total × m), stacked across subjects and visits
61/// * `subject_ids` — Subject identifier for each curve (length n_total)
62/// * `covariates` — Subject-level covariates (n_total × p).
63///   Each row corresponds to the same curve in `data`.
64///   If a covariate is subject-level, its value should be repeated across visits.
65/// * `ncomp` — Number of FPC components
66///
67/// # Algorithm
68/// 1. Pool curves, compute FPCA
69/// 2. For each FPC score, fit scalar mixed model: ξ_ijk = x_i'γ_k + u_ik + e_ijk
70/// 3. Recover β̂(t) and b̂_i(t) from component coefficients
71///
72/// # Errors
73///
74/// Returns [`FdarError::InvalidDimension`] if `data` is empty (zero rows or
75/// columns), or if `subject_ids.len()` does not match the number of rows.
76/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
77/// Returns [`FdarError::ComputationFailed`] if the underlying FPCA fails.
78#[must_use = "expensive computation whose result should not be discarded"]
79pub fn fmm(
80    data: &FdMatrix,
81    subject_ids: &[usize],
82    covariates: Option<&FdMatrix>,
83    ncomp: usize,
84) -> Result<FmmResult, FdarError> {
85    let n_total = data.nrows();
86    let m = data.ncols();
87    if n_total == 0 || m == 0 {
88        return Err(FdarError::InvalidDimension {
89            parameter: "data",
90            expected: "non-empty matrix".to_string(),
91            actual: format!("{n_total} x {m}"),
92        });
93    }
94    if subject_ids.len() != n_total {
95        return Err(FdarError::InvalidDimension {
96            parameter: "subject_ids",
97            expected: format!("length {n_total}"),
98            actual: format!("length {}", subject_ids.len()),
99        });
100    }
101    if ncomp == 0 {
102        return Err(FdarError::InvalidParameter {
103            parameter: "ncomp",
104            message: "must be >= 1".to_string(),
105        });
106    }
107
108    // Determine unique subjects
109    let (subject_map, n_subjects) = build_subject_map(subject_ids);
110
111    // Step 1: FPCA on pooled data
112    let fpca = fdata_to_pc_1d(data, ncomp)?;
113    let k = fpca.scores.ncols(); // actual number of components
114
115    // Step 2: For each FPC score, fit scalar mixed model
116    // Normalize scores by sqrt(h) to match R's L²-weighted FPCA convention.
117    // This ensures variance components are on the same scale as R's lmer().
118    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
119    let score_scale = h.sqrt();
120
121    let p = covariates.map_or(0, super::matrix::FdMatrix::ncols);
122    let mut gamma = vec![vec![0.0; k]; p]; // gamma[j][k] = fixed effect coeff j for component k
123    let mut u_hat = vec![vec![0.0; k]; n_subjects]; // u_hat[i][k] = random effect for subject i, component k
124    let mut sigma2_u = vec![0.0; k];
125    let mut sigma2_eps_total = 0.0;
126
127    for comp in 0..k {
128        // Scale scores to L²-normalized space
129        let scores: Vec<f64> = (0..n_total)
130            .map(|i| fpca.scores[(i, comp)] * score_scale)
131            .collect();
132        let result = fit_scalar_mixed_model(&scores, &subject_map, n_subjects, covariates, p);
133        // Scale gamma back to original score space for beta reconstruction
134        for j in 0..p {
135            gamma[j][comp] = result.gamma[j] / score_scale;
136        }
137        for s in 0..n_subjects {
138            // Scale u_hat back for random effect reconstruction
139            u_hat[s][comp] = result.u_hat[s] / score_scale;
140        }
141        // Keep variance components in L²-normalized space (matching R)
142        sigma2_u[comp] = result.sigma2_u;
143        sigma2_eps_total += result.sigma2_eps;
144    }
145    let sigma2_eps = sigma2_eps_total / k as f64;
146
147    // Step 3: Recover functional coefficients (using gamma in original scale)
148    let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
149    let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
150
151    // Compute random variance function: Var(b_i(t)) across subjects
152    let random_variance = compute_random_variance(&random_effects, n_subjects, m);
153
154    // Compute fitted and residuals
155    let (fitted, residuals) = compute_fitted_residuals(
156        data,
157        &fpca.mean,
158        &beta_functions,
159        &random_effects,
160        covariates,
161        &subject_map,
162        n_total,
163        m,
164        p,
165    );
166
167    let eigenvalues: Vec<f64> = fpca
168        .singular_values
169        .iter()
170        .map(|&sv| sv * sv / n_total as f64)
171        .collect();
172
173    Ok(FmmResult {
174        mean_function: fpca.mean,
175        beta_functions,
176        random_effects,
177        fitted,
178        residuals,
179        random_variance,
180        sigma2_eps,
181        sigma2_u,
182        ncomp: k,
183        n_subjects,
184        eigenvalues,
185    })
186}
187
/// Build mapping from observation index to subject index (0..n_subjects-1).
///
/// Subject indices follow the ascending order of the original ids, so the
/// smallest id maps to subject 0. Returns the per-observation map and the
/// number of distinct subjects.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    // Sorted, deduplicated ids define the subject ordering.
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    let n_subjects = unique_ids.len();

    // Binary search on the sorted ids: O(n log s) instead of the previous
    // O(n·s) linear `position` scan. Every id is present by construction,
    // so the lookup cannot fail; fall back to 0 defensively as before.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| unique_ids.binary_search(id).unwrap_or(0))
        .collect();

    (map, n_subjects)
}
202
/// Scalar mixed model result for one FPC component.
///
/// Holds the estimates from one scalar fit ξ = x'γ + u + e; the caller maps
/// these back onto the functional basis.
struct ScalarMixedResult {
    gamma: Vec<f64>, // fixed effects (length p)
    u_hat: Vec<f64>, // random effects per subject (length n_subjects)
    sigma2_u: f64,   // random effect variance
    sigma2_eps: f64, // residual variance
}
210
/// Precomputed subject structure for the mixed model.
///
/// `counts[s]` is the number of observations for subject `s`;
/// `obs[s]` lists the observation indices belonging to subject `s`.
struct SubjectStructure {
    counts: Vec<usize>,
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` observations by their subject index.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut counts = vec![0usize; n_subjects];
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for i in 0..n {
            let subject = subject_map[i];
            counts[subject] += 1;
            obs[subject].push(i);
        }
        Self { counts, obs }
    }
}
229
230/// Compute shrinkage weights: w_s = σ²_u / (σ²_u + σ²_e / n_s).
231fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
232    ss.counts
233        .iter()
234        .map(|&c| {
235            let ns = c as f64;
236            if ns < 1.0 {
237                0.0
238            } else {
239                sigma2_u / (sigma2_u + sigma2_e / ns)
240            }
241        })
242        .collect()
243}
244
/// GLS fixed effect update using block-diagonal V^{-1}.
///
/// Computes γ = (X'V⁻¹X)⁻¹ X'V⁻¹y exploiting the balanced random intercept structure.
///
/// `weights[s]` is the shrinkage weight w_s for subject s; subjects with no
/// observations contribute nothing. Returns `None` when the regularized
/// normal equations are not positive definite.
fn gls_update_gamma(
    cov: &FdMatrix,
    p: usize,
    ss: &SubjectStructure,
    weights: &[f64],
    y: &[f64],
    sigma2_e: f64,
) -> Option<Vec<f64>> {
    let n_subjects = ss.counts.len();
    // Accumulators for X'V⁻¹X (p×p, flat row-major) and X'V⁻¹y.
    let mut xtvinvx = vec![0.0; p * p];
    let mut xtvinvy = vec![0.0; p];
    let inv_e = 1.0 / sigma2_e;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            continue;
        }
        // Per-subject column/response totals feed the rank-one correction of V⁻¹.
        let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
        accumulate_gls_terms(
            cov,
            y,
            &ss.obs[s],
            &x_sum,
            y_sum,
            weights[s],
            ns,
            inv_e,
            p,
            &mut xtvinvx,
            &mut xtvinvy,
        );
    }

    // Tiny ridge on the diagonal keeps the Cholesky factorization well-posed.
    for j in 0..p {
        xtvinvx[j * p + j] += 1e-10;
    }
    cholesky_solve(&xtvinvx, &xtvinvy, p)
}
287
288/// Compute subject-level covariate sums and response sum.
289fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
290    let mut x_sum = vec![0.0; p];
291    let mut y_sum = 0.0;
292    for &i in obs {
293        for r in 0..p {
294            x_sum[r] += cov[(i, r)];
295        }
296        y_sum += y[i];
297    }
298    (x_sum, y_sum)
299}
300
/// Accumulate X'V^{-1}X and X'V^{-1}y for one subject.
///
/// Applies the random-intercept identity V⁻¹v = (v − w_s v̄_s)/σ²_e per
/// observation, where v̄_s = (subject total)/ns. `x_sum`/`y_sum` are the
/// subject totals and `ns` the subject's observation count. Only the upper
/// triangle (c ≥ r) is computed directly; the (c, r) mirror is filled with
/// the same value, which is correct once summed over all observations of
/// the subject.
fn accumulate_gls_terms(
    cov: &FdMatrix,
    y: &[f64],
    obs: &[usize],
    x_sum: &[f64],
    y_sum: f64,
    w_s: f64,
    ns: f64,
    inv_e: f64,
    p: usize,
    xtvinvx: &mut [f64],
    xtvinvy: &mut [f64],
) {
    for &i in obs {
        // V⁻¹ applied to the response at observation i.
        let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
        for r in 0..p {
            xtvinvy[r] += cov[(i, r)] * vinv_y;
            for c in r..p {
                // V⁻¹ applied to covariate column c at observation i.
                let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
                let val = cov[(i, r)] * vinv_xc;
                xtvinvx[r * p + c] += val;
                if r != c {
                    xtvinvx[c * p + r] += val;
                }
            }
        }
    }
}
330
/// REML EM update for variance components.
///
/// Returns (σ²_u_new, σ²_e_new) from the conditional expectations.
/// Uses n - p divisor for σ²_e (REML correction where p = number of fixed effects).
fn reml_variance_update(
    residuals: &[f64],
    ss: &SubjectStructure,
    weights: &[f64],
    sigma2_u: f64,
    p: usize,
) -> (f64, f64) {
    let n_subjects = ss.counts.len();
    let n: usize = ss.counts.iter().sum();
    let mut sigma2_u_new = 0.0;
    let mut sigma2_e_new = 0.0;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            continue; // subject without observations contributes nothing
        }
        let w_s = weights[s];
        // BLUP of the subject effect: shrunken subject-mean residual.
        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
        let u_hat_s = w_s * mean_r_s;
        // Conditional (posterior) variance of u_s given the data.
        let cond_var_s = sigma2_u * (1.0 - w_s);

        // E-step: E[u_s²] = û_s² + Var(u_s | y).
        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
        for &i in &ss.obs[s] {
            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
        }
        // Uncertainty in û_s also inflates the residual sum of squares.
        sigma2_e_new += ns * cond_var_s;
    }

    // REML divisor: n - p for residual variance (matches R's lmer)
    let denom_e = (n.saturating_sub(p)).max(1) as f64;

    // Floor both variances so later divisions stay well-defined.
    (
        (sigma2_u_new / n_subjects as f64).max(1e-15),
        (sigma2_e_new / denom_e).max(1e-15),
    )
}
372
/// Fit scalar mixed model: y_ij = x_i'γ + u_i + e_ij.
///
/// Uses iterative GLS for fixed effects + REML EM for variance components,
/// matching R's lmer() behavior. Initializes from Henderson's ANOVA, then
/// iterates until convergence.
///
/// `subject_map[i]` gives the subject index of observation i; `covariates`
/// may be `None` (then p must be 0 for the fixed-effect step to be skipped).
fn fit_scalar_mixed_model(
    y: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    covariates: Option<&FdMatrix>,
    p: usize,
) -> ScalarMixedResult {
    let n = y.len();
    let ss = SubjectStructure::new(subject_map, n_subjects, n);

    // Initialize from OLS + Henderson's ANOVA
    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
    let (mut sigma2_u, mut sigma2_e) =
        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);

    // Floor degenerate starting values so the GLS weights are finite.
    if sigma2_e < 1e-15 {
        sigma2_e = 1e-6;
    }
    if sigma2_u < 1e-15 {
        sigma2_u = sigma2_e * 0.1;
    }

    let mut gamma = gamma_init;

    // Alternate GLS (fixed effects) and EM (variance components), capped at
    // 50 iterations.
    for _iter in 0..50 {
        let sigma2_u_old = sigma2_u;
        let sigma2_e_old = sigma2_e;

        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);

        // A singular GLS system leaves gamma at its previous value.
        if let Some(cov) = covariates.filter(|_| p > 0) {
            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
                gamma = g;
            }
        }

        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);

        // Relative convergence on the sum of absolute variance changes.
        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
            break;
        }
    }

    // BLUPs from the final residuals and converged variance components.
    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
    let u_hat = compute_blup(
        &final_residuals,
        subject_map,
        n_subjects,
        sigma2_u,
        sigma2_e,
    );

    ScalarMixedResult {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps: sigma2_e,
    }
}
440
441/// OLS estimation of fixed effects.
442fn estimate_fixed_effects(
443    y: &[f64],
444    covariates: Option<&FdMatrix>,
445    p: usize,
446    n: usize,
447) -> Vec<f64> {
448    if p == 0 || covariates.is_none() {
449        return Vec::new();
450    }
451    let cov = covariates.expect("checked: covariates is Some");
452
453    // Solve (X'X)γ = X'y via Cholesky
454    let mut xtx = vec![0.0; p * p];
455    let mut xty = vec![0.0; p];
456    for i in 0..n {
457        for r in 0..p {
458            xty[r] += cov[(i, r)] * y[i];
459            for s in r..p {
460                let val = cov[(i, r)] * cov[(i, s)];
461                xtx[r * p + s] += val;
462                if r != s {
463                    xtx[s * p + r] += val;
464                }
465            }
466        }
467    }
468    // Regularize
469    for j in 0..p {
470        xtx[j * p + j] += 1e-8;
471    }
472
473    cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
474}
475
/// Cholesky factorization: A = LL'. Returns L (p×p flat row-major) or None if singular.
///
/// `a` is assumed symmetric; only the lower triangle of L is populated,
/// the strict upper triangle stays zero.
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut lower = vec![0.0; p * p];
    for col in 0..p {
        // Diagonal entry: sqrt(a_jj − Σ_k L_jk²); a non-positive pivot means
        // the matrix is not positive definite.
        let sq_sum: f64 = (0..col)
            .map(|k| lower[col * p + k] * lower[col * p + k])
            .sum();
        let pivot = a[col * p + col] - sq_sum;
        if pivot <= 0.0 {
            return None;
        }
        let diag = pivot.sqrt();
        lower[col * p + col] = diag;
        // Below-diagonal entries of this column.
        for row in (col + 1)..p {
            let dot: f64 = (0..col)
                .map(|k| lower[row * p + k] * lower[col * p + k])
                .sum();
            lower[row * p + col] = (a[row * p + col] - dot) / diag;
        }
    }
    Some(lower)
}
499
/// Solve L z = b (forward substitution) then L' x = z (back substitution).
///
/// `l` is a p×p lower-triangular factor in flat row-major layout.
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = vec![0.0; p];
    // Forward pass: x holds the solution of L z = b.
    for i in 0..p {
        let partial: f64 = (0..i).map(|j| l[i * p + j] * x[j]).sum();
        x[i] = (b[i] - partial) / l[i * p + i];
    }
    // Backward pass: overwrite x with the solution of L' x = z.
    for i in (0..p).rev() {
        let partial: f64 = ((i + 1)..p).map(|j| l[j * p + i] * x[j]).sum();
        x[i] = (x[i] - partial) / l[i * p + i];
    }
    x
}
519
520/// Cholesky solve: A x = b where A is p×p symmetric positive definite.
521fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
522    let l = cholesky_factor_famm(a, p)?;
523    Some(cholesky_triangular_solve(&l, b, p))
524}
525
526/// Compute OLS residuals: r = y - X*gamma.
527fn compute_ols_residuals(
528    y: &[f64],
529    covariates: Option<&FdMatrix>,
530    gamma: &[f64],
531    p: usize,
532    n: usize,
533) -> Vec<f64> {
534    let mut residuals = y.to_vec();
535    if p > 0 {
536        if let Some(cov) = covariates {
537            for i in 0..n {
538                for j in 0..p {
539                    residuals[i] -= cov[(i, j)] * gamma[j];
540                }
541            }
542        }
543    }
544    residuals
545}
546
/// Estimate variance components via method of moments.
///
/// One-way random-effects ANOVA: σ²_ε from the within-subject mean square,
/// σ²_u from the between-subject mean square corrected for σ²_ε.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Per-subject totals and observation counts.
    let mut totals = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for i in 0..n {
        totals[subject_map[i]] += residuals[i];
        counts[subject_map[i]] += 1;
    }
    let means: Vec<f64> = totals
        .iter()
        .zip(&counts)
        .map(|(&total, &c)| if c > 0 { total / c as f64 } else { 0.0 })
        .collect();

    // Within-subject sum of squares and its degrees of freedom.
    let ss_within: f64 = (0..n)
        .map(|i| (residuals[i] - means[subject_map[i]]).powi(2))
        .sum();
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject sum of squares around the grand mean.
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let ss_between: f64 = (0..n_subjects)
        .map(|s| counts[s] as f64 * (means[s] - grand_mean).powi(2))
        .sum();

    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        1e-6
    };

    // Expected mean squares give σ²_u = (MS_between − σ²_ε) / n̄, floored at 0.
    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
599
/// Compute BLUP (Best Linear Unbiased Prediction) for random effects.
///
/// û_i = σ²_u / (σ²_u + σ²_ε/n_i) * (ȳ_i - x̄_i'γ)
///
/// Subjects without observations get a prediction of zero.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    // Per-subject residual totals and counts.
    let mut totals = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for (i, &r) in residuals.iter().enumerate() {
        totals[subject_map[i]] += r;
        counts[subject_map[i]] += 1;
    }

    let mut blup = Vec::with_capacity(n_subjects);
    for s in 0..n_subjects {
        let ni = counts[s] as f64;
        if ni < 1.0 {
            blup.push(0.0); // unseen subject: no data, predict zero
            continue;
        }
        // Shrink the subject-mean residual toward zero; the denominator is
        // floored to guard against division by a (near-)zero variance sum.
        let shrinkage = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
        blup.push(shrinkage * (totals[s] / ni));
    }
    blup
}
630
631// ---------------------------------------------------------------------------
632// Recovery of functional coefficients
633// ---------------------------------------------------------------------------
634
635/// Recover β̂(t) = Σ_k γ̂_jk φ_k(t) for each covariate j.
636fn recover_beta_functions(
637    gamma: &[Vec<f64>],
638    rotation: &FdMatrix,
639    p: usize,
640    m: usize,
641    k: usize,
642) -> FdMatrix {
643    let mut beta = FdMatrix::zeros(p, m);
644    for j in 0..p {
645        for t in 0..m {
646            let mut val = 0.0;
647            for comp in 0..k {
648                val += gamma[j][comp] * rotation[(t, comp)];
649            }
650            beta[(j, t)] = val;
651        }
652    }
653    beta
654}
655
656/// Recover b̂_i(t) = Σ_k û_ik φ_k(t) for each subject i.
657fn recover_random_effects(
658    u_hat: &[Vec<f64>],
659    rotation: &FdMatrix,
660    n_subjects: usize,
661    m: usize,
662    k: usize,
663) -> FdMatrix {
664    let mut re = FdMatrix::zeros(n_subjects, m);
665    for s in 0..n_subjects {
666        for t in 0..m {
667            let mut val = 0.0;
668            for comp in 0..k {
669                val += u_hat[s][comp] * rotation[(t, comp)];
670            }
671            re[(s, t)] = val;
672        }
673    }
674    re
675}
676
677/// Compute random effect variance function: Var_i(b̂_i(t)).
678fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
679    (0..m)
680        .map(|t| {
681            let mean: f64 =
682                (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
683            let var: f64 = (0..n_subjects)
684                .map(|s| (random_effects[(s, t)] - mean).powi(2))
685                .sum::<f64>()
686                / n_subjects.max(1) as f64;
687            var
688        })
689        .collect()
690}
691
692/// Compute fitted values and residuals.
693fn compute_fitted_residuals(
694    data: &FdMatrix,
695    mean_function: &[f64],
696    beta_functions: &FdMatrix,
697    random_effects: &FdMatrix,
698    covariates: Option<&FdMatrix>,
699    subject_map: &[usize],
700    n_total: usize,
701    m: usize,
702    p: usize,
703) -> (FdMatrix, FdMatrix) {
704    let mut fitted = FdMatrix::zeros(n_total, m);
705    let mut residuals = FdMatrix::zeros(n_total, m);
706
707    for i in 0..n_total {
708        let s = subject_map[i];
709        for t in 0..m {
710            let mut val = mean_function[t] + random_effects[(s, t)];
711            if p > 0 {
712                if let Some(cov) = covariates {
713                    for j in 0..p {
714                        val += cov[(i, j)] * beta_functions[(j, t)];
715                    }
716                }
717            }
718            fitted[(i, t)] = val;
719            residuals[(i, t)] = data[(i, t)] - val;
720        }
721    }
722
723    (fitted, residuals)
724}
725
726// ---------------------------------------------------------------------------
727// Prediction
728// ---------------------------------------------------------------------------
729
730/// Predict curves for new subjects.
731///
732/// # Arguments
733/// * `result` — Fitted FMM result
734/// * `new_covariates` — Covariates for new subjects (n_new × p)
735///
736/// Returns predicted curves (n_new × m) using only fixed effects (no random effects for new subjects).
737#[must_use = "prediction result should not be discarded"]
738pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
739    let m = result.mean_function.len();
740    let n_new = new_covariates.map_or(1, super::matrix::FdMatrix::nrows);
741    let p = result.beta_functions.nrows();
742
743    let mut predicted = FdMatrix::zeros(n_new, m);
744    for i in 0..n_new {
745        for t in 0..m {
746            let mut val = result.mean_function[t];
747            if let Some(cov) = new_covariates {
748                for j in 0..p {
749                    val += cov[(i, j)] * result.beta_functions[(j, t)];
750                }
751            }
752            predicted[(i, t)] = val;
753        }
754    }
755    predicted
756}
757
758// ---------------------------------------------------------------------------
759// Hypothesis testing
760// ---------------------------------------------------------------------------
761
762/// Permutation test for fixed effects in functional mixed model.
763///
764/// Tests H₀: β_j(t) = 0 for each covariate j.
765/// Uses integrated squared norm as test statistic: T_j = ∫ β̂_j(t)² dt.
766///
767/// # Arguments
768/// * `data` — All observed curves (n_total × m)
769/// * `subject_ids` — Subject identifiers
770/// * `covariates` — Subject-level covariates (n_total × p)
771/// * `ncomp` — Number of FPC components
772/// * `n_perm` — Number of permutations
773/// * `seed` — Random seed
774///
775/// # Errors
776///
777/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, or
778/// `covariates` has zero columns.
779/// Propagates errors from [`fmm`] (e.g., dimension mismatches or FPCA failure).
780#[must_use = "expensive computation whose result should not be discarded"]
781pub fn fmm_test_fixed(
782    data: &FdMatrix,
783    subject_ids: &[usize],
784    covariates: &FdMatrix,
785    ncomp: usize,
786    n_perm: usize,
787    seed: u64,
788) -> Result<FmmTestResult, FdarError> {
789    let n_total = data.nrows();
790    let m = data.ncols();
791    let p = covariates.ncols();
792    if n_total == 0 {
793        return Err(FdarError::InvalidDimension {
794            parameter: "data",
795            expected: "non-empty matrix".to_string(),
796            actual: format!("{n_total} rows"),
797        });
798    }
799    if p == 0 {
800        return Err(FdarError::InvalidDimension {
801            parameter: "covariates",
802            expected: "at least 1 column".to_string(),
803            actual: "0 columns".to_string(),
804        });
805    }
806
807    // Fit observed model
808    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
809
810    // Observed test statistics: ∫ β̂_j(t)² dt for each covariate
811    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
812
813    // Permutation test
814    let (f_statistics, p_values) = permutation_test(
815        data,
816        subject_ids,
817        covariates,
818        ncomp,
819        n_perm,
820        seed,
821        &observed_stats,
822        p,
823        m,
824    );
825
826    Ok(FmmTestResult {
827        f_statistics,
828        p_values,
829    })
830}
831
832/// Compute ∫ β̂_j(t)² dt for each covariate.
833fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
834    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
835    (0..p)
836        .map(|j| {
837            let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
838            ss * h
839        })
840        .collect()
841}
842
843/// Run permutation test for fixed effects.
844fn permutation_test(
845    data: &FdMatrix,
846    subject_ids: &[usize],
847    covariates: &FdMatrix,
848    ncomp: usize,
849    n_perm: usize,
850    seed: u64,
851    observed_stats: &[f64],
852    p: usize,
853    m: usize,
854) -> (Vec<f64>, Vec<f64>) {
855    use rand::prelude::*;
856    let n_total = data.nrows();
857    let mut rng = StdRng::seed_from_u64(seed);
858    let mut n_ge = vec![0usize; p];
859
860    for _ in 0..n_perm {
861        // Permute covariates across subjects
862        let mut perm_indices: Vec<usize> = (0..n_total).collect();
863        perm_indices.shuffle(&mut rng);
864        let perm_cov = permute_rows(covariates, &perm_indices);
865
866        if let Ok(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
867            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
868            for j in 0..p {
869                if perm_stats[j] >= observed_stats[j] {
870                    n_ge[j] += 1;
871                }
872            }
873        }
874    }
875
876    let p_values: Vec<f64> = n_ge
877        .iter()
878        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
879        .collect();
880    let f_statistics = observed_stats.to_vec();
881
882    (f_statistics, p_values)
883}
884
885/// Permute rows of a matrix according to given indices.
886fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
887    let n = indices.len();
888    let m = mat.ncols();
889    let mut result = FdMatrix::zeros(n, m);
890    for (new_i, &old_i) in indices.iter().enumerate() {
891        for j in 0..m {
892            result[(new_i, j)] = mat[(old_i, j)];
893        }
894    }
895    result
896}
897
898// ---------------------------------------------------------------------------
899// Tests
900// ---------------------------------------------------------------------------
901
902#[cfg(test)]
903mod tests {
904    use super::*;
905    use std::f64::consts::PI;
906
907    fn uniform_grid(m: usize) -> Vec<f64> {
908        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
909    }
910
911    /// Generate repeated measurements: n_subjects × n_visits curves.
912    /// Subject-level covariate z affects the curve amplitude.
913    fn generate_fmm_data(
914        n_subjects: usize,
915        n_visits: usize,
916        m: usize,
917    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
918        let t = uniform_grid(m);
919        let n_total = n_subjects * n_visits;
920        let mut col_major = vec![0.0; n_total * m];
921        let mut subject_ids = vec![0usize; n_total];
922        let mut cov_data = vec![0.0; n_total];
923
924        for s in 0..n_subjects {
925            let z = s as f64 / n_subjects as f64; // covariate in [0, 1)
926            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0); // random-like effect
927
928            for v in 0..n_visits {
929                let obs = s * n_visits + v;
930                subject_ids[obs] = s;
931                cov_data[obs] = z;
932                let noise_scale = 0.05;
933
934                for (j, &tj) in t.iter().enumerate() {
935                    // Y_sv(t) = sin(2πt) + z * t + subject_effect * cos(2πt) + noise
936                    let mu = (2.0 * PI * tj).sin();
937                    let fixed = z * tj * 3.0;
938                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
939                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
940                    col_major[obs + j * n_total] = mu + fixed + random + noise;
941                }
942            }
943        }
944
945        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
946        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
947        (data, subject_ids, covariates, t)
948    }
949
950    #[test]
951    fn test_fmm_basic() {
952        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
953        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
954
955        assert_eq!(result.mean_function.len(), 50);
956        assert_eq!(result.beta_functions.nrows(), 1); // 1 covariate
957        assert_eq!(result.beta_functions.ncols(), 50);
958        assert_eq!(result.random_effects.nrows(), 10);
959        assert_eq!(result.fitted.nrows(), 30);
960        assert_eq!(result.residuals.nrows(), 30);
961        assert_eq!(result.n_subjects, 10);
962    }
963
964    #[test]
965    fn test_fmm_fitted_plus_residuals_equals_data() {
966        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
967        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
968
969        let n = data.nrows();
970        let m = data.ncols();
971        for i in 0..n {
972            for t in 0..m {
973                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
974                assert!(
975                    (reconstructed - data[(i, t)]).abs() < 1e-8,
976                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
977                    i,
978                    t,
979                    reconstructed,
980                    data[(i, t)]
981                );
982            }
983        }
984    }
985
986    #[test]
987    fn test_fmm_random_variance_positive() {
988        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
989        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
990
991        for &v in &result.random_variance {
992            assert!(v >= 0.0, "Random variance should be non-negative");
993        }
994    }
995
996    #[test]
997    fn test_fmm_no_covariates() {
998        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
999        let result = fmm(&data, &subject_ids, None, 3).unwrap();
1000
1001        assert_eq!(result.beta_functions.nrows(), 0);
1002        assert_eq!(result.n_subjects, 8);
1003        assert_eq!(result.fitted.nrows(), 24);
1004    }
1005
1006    #[test]
1007    fn test_fmm_predict() {
1008        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1009        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1010
1011        // Predict for new subjects with covariate = 0.5
1012        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
1013        let predicted = fmm_predict(&result, Some(&new_cov));
1014
1015        assert_eq!(predicted.nrows(), 1);
1016        assert_eq!(predicted.ncols(), 50);
1017
1018        // Predicted curve should be reasonable (not NaN or extreme)
1019        for t in 0..50 {
1020            assert!(predicted[(0, t)].is_finite());
1021            assert!(
1022                predicted[(0, t)].abs() < 20.0,
1023                "Predicted value too extreme at t={}: {}",
1024                t,
1025                predicted[(0, t)]
1026            );
1027        }
1028    }
1029
1030    #[test]
1031    fn test_fmm_test_fixed_detects_effect() {
1032        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);
1033
1034        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
1035
1036        assert_eq!(result.f_statistics.len(), 1);
1037        assert_eq!(result.p_values.len(), 1);
1038        assert!(
1039            result.p_values[0] < 0.1,
1040            "Should detect covariate effect, got p={}",
1041            result.p_values[0]
1042        );
1043    }
1044
1045    #[test]
1046    fn test_fmm_test_fixed_no_effect() {
1047        let n_subjects = 10;
1048        let n_visits = 3;
1049        let m = 40;
1050        let t = uniform_grid(m);
1051        let n_total = n_subjects * n_visits;
1052
1053        // No covariate effect: Y = sin(2πt) + noise
1054        let mut col_major = vec![0.0; n_total * m];
1055        let mut subject_ids = vec![0usize; n_total];
1056        let mut cov_data = vec![0.0; n_total];
1057
1058        for s in 0..n_subjects {
1059            for v in 0..n_visits {
1060                let obs = s * n_visits + v;
1061                subject_ids[obs] = s;
1062                cov_data[obs] = s as f64 / n_subjects as f64;
1063                for (j, &tj) in t.iter().enumerate() {
1064                    col_major[obs + j * n_total] =
1065                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
1066                }
1067            }
1068        }
1069
1070        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
1071        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
1072
1073        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
1074        assert!(
1075            result.p_values[0] > 0.05,
1076            "Should not detect effect, got p={}",
1077            result.p_values[0]
1078        );
1079    }
1080
1081    #[test]
1082    fn test_fmm_invalid_input() {
1083        let data = FdMatrix::zeros(0, 0);
1084        assert!(fmm(&data, &[], None, 1).is_err());
1085
1086        let data = FdMatrix::zeros(10, 50);
1087        let ids = vec![0; 5]; // wrong length
1088        assert!(fmm(&data, &ids, None, 1).is_err());
1089    }
1090
1091    #[test]
1092    fn test_fmm_single_visit_per_subject() {
1093        let n = 10;
1094        let m = 40;
1095        let t = uniform_grid(m);
1096        let mut col_major = vec![0.0; n * m];
1097        let subject_ids: Vec<usize> = (0..n).collect();
1098
1099        for i in 0..n {
1100            for (j, &tj) in t.iter().enumerate() {
1101                col_major[i + j * n] = (2.0 * PI * tj).sin();
1102            }
1103        }
1104        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
1105
1106        // Should still work with 1 visit per subject
1107        let result = fmm(&data, &subject_ids, None, 2).unwrap();
1108        assert_eq!(result.n_subjects, n);
1109        assert_eq!(result.fitted.nrows(), n);
1110    }
1111
1112    #[test]
1113    fn test_build_subject_map() {
1114        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
1115        assert_eq!(n, 3);
1116        assert_eq!(map, vec![0, 0, 1, 1, 2]);
1117    }
1118
1119    #[test]
1120    fn test_variance_components_positive() {
1121        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1122        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1123
1124        assert!(result.sigma2_eps >= 0.0);
1125        for &s in &result.sigma2_u {
1126            assert!(s >= 0.0);
1127        }
1128    }
1129}